Spring OAuth2のPrincipalをカスタマイズする方法

Spring OAuth2のPrincipalをカスタマイズする方法を備忘録としてまとめます。

Spring OAuth2を使用することで、AADなどと連携することができます。 その際にidTokenに含まれている情報をPrincipalから取得できますが、場合によってはログイン後に カスタムな値をPrincipalへ含めたいことがあります。

今回はログイン後にPrincipalを更新する方法とそれに伴うテストについて書こうと思います。

実装

Principalを更新するタイミングとしてはログインが完了したタイミングが最適だと考え、 OAuth2が提供しているAuthenticationSuccessHandler を実装したLoginHandlerをクラスを作成します。 また作成した LoginHandlerクラスを作成しConfigへ登録します。

@Component class SuccessHandler: AuthenticationSuccessHandler { override fun onAuthenticationSuccess(request: HttpServletRequest, response: HttpServletResponse, authentication: Authentication?) { TODO("Not yet implemented") } }
@Configuration class SecurityConfig { @Autowired lateinit var successHandler: SuccessHandler @Bean fun filterChain(http: HttpSecurity): SecurityFilterChain { http.oauth2Login { it.successHandler(successHandler) } return http.build() } }

このファイルの中で、SecurityContextを更新することでPrincipalに独自の値を入れることができます。

たとえば gid という独自のリソースを管理するidが idTokenに含まれていると仮定し、 この値を紐づけたユーザーを作成してユーザーidをPrincipalへ格納したいケースを考えてみます。

すると下記のようなコードが書けます。

@Component class SuccessHandler: AuthenticationSuccessHandler { @Autowired lateinit var userRepository: UserRepository override fun onAuthenticationSuccess(request: HttpServletRequest, response: HttpServletResponse, authentication: Authentication?) { val tokenAuthentication = authentication as OAuth2AuthenticationToken val principal = authentication.principal val gid = principal.getAttribute<String>("gid") ?: throw Exception("gid doesn't exist") val name = principal.getAttribute<String>("name") ?: throw Exception("name doesn't exist") val user = userRepository.findOrSaveByGid(gid) SecurityContextHolder.getContext().authentication = OAuth2AuthenticationToken( CustomOauth2User(authorities = principal.authorities, userId = user.userId, name = name), tokenAuthentication.authorities, tokenAuthentication.authorizedClientRegistrationId ) } }

principalの値はimmutableであるため、値を更新するためにはAuthenticationを作り直さなければいけません。

oauth2でログインが完了したときに作成されるauthenticationは OAuth2AuthenticationToken であるため、 引数で渡されるauthenticationをキャストします。

val tokenAuthentication = authentication as OAuth2AuthenticationToken

またOIDCの場合、idTokenに name などの属性が含まれていますが、こちらは getAttribute で取得することができます。

val principal = authentication.principal val gid = principal.getAttribute<String>("gid") ?: throw Exception("gid doesn't exist") val name = principal.getAttribute<String>("name") ?: throw Exception("name doesn't exist")

そして最後にprincipalとauthenticationを作り直し SecurityContextHolder.getContext().authentication へ格納することで principalを変更することができます。

SecurityContextHolder.getContext().authentication = OAuth2AuthenticationToken( CustomOauth2User(authorities = principal.authorities, userId = user.userId, name = name), tokenAuthentication.authorities, tokenAuthentication.authorizedClientRegistrationId )

このとき独自に作成したPrincipalの CustomOauth2User を渡していますが、OAuth2AuthenticationTokenの引数には OAuth2User インタフェースを実装したクラスを渡せるため、 CustomOauth2User の実体は下記のようになっています。

data class CustomOauth2User(val authorities: Collection<GrantedAuthority>, val userId: Long, private val name: String): OAuth2User { private val attributes: Map<String, Any> init { attributes = mapOf("userId" to userId, "name" to name) } override fun getName(): String { return name } override fun getAttributes(): Map<String, Any> { return attributes } override fun getAuthorities(): Collection<GrantedAuthority> { return authorities } }

以上の設定を行うことで、認可の通ったエンドポイントで下記のようにPrincipalへアクセスすることができます。

@RestController("/api/books") class BookController(private val bookService: BookService) { @GetMapping fun getBooks(@AuthenticationPrincipal user: CustomOauth2User): List<String> { return bookService.getBooks(user.userId) } }

テスト

インテグレーションテストでSuccess Handlerのテストが難しそうなため、SuccessHandlerのユニットテストを書きました。 またモックにはMockkを使用しています。

@ExtendWith(MockKExtension::class) class SuccessHandlerTest { @MockK lateinit var userRepository: UserRepository @MockK lateinit var authentication: OAuth2AuthenticationToken lateinit var successHandler: SuccessHandler lateinit var request: MockHttpServletRequest lateinit var response: MockHttpServletResponse @BeforeEach fun setup() { successHandler = SuccessHandler(userRepository) request = MockHttpServletRequest() response = MockHttpServletResponse() every { authentication.principal.getAttribute<String>("gid") } returns "" every { authentication.principal.getAttribute<String>("name") } returns "" every { authentication.principal.authorities } returns emptyList() every { authentication.authorities } returns emptyList() every { authentication.authorizedClientRegistrationId } returns "dummy" every { userRepository.findOrSaveByGid(any()) } returns User(userId = 0) } @Test fun `ユーザーを取得または保存する`() { every { authentication.principal.getAttribute<String>("gid") } returns "12" every { userRepository.findOrSaveByGid("12") } returns User(userId = 12L) successHandler.onAuthenticationSuccess(request = request, response = response, authentication = authentication) verify { userRepository.findOrSaveByGid("12") } } @Test fun `Principalを更新する`() { every { authentication.principal.getAttribute<String>("name") } returns "sample name" every { userRepository.findOrSaveByGid(any()) } returns User(userId = 12L) every { authentication.principal.authorities } returns emptyList() successHandler.onAuthenticationSuccess(request = request, response = response, authentication = authentication) val authentication = SecurityContextHolder.getContext().authentication assertEquals(CustomOauth2User(userId = 12L, name="sample name", authorities = emptyList()), authentication.principal) } @Nested @DisplayName("エラーケース") inner class ErrorCases { @Test fun `名前がない場合エラーを投げる`() { every { authentication.principal.getAttribute<String>("name") } returns null val exceptions = assertThrows<Exception> { successHandler.onAuthenticationSuccess(request = request, response = response, authentication = authentication) } assertEquals("name doesn't exist", exceptions.message) } @Test fun `gidがない場合エラーを投げる`() { every { authentication.principal.getAttribute<String>("gid") } returns null val exceptions = assertThrows<Exception> { successHandler.onAuthenticationSuccess(request = request, response = response, authentication = authentication) } assertEquals("gid doesn't exist", exceptions.message) } } }

またコントローラーでは下記のようにCustomOauth2User をテストで渡すことができます。

mockMVC.perform( get("/api/users") .with(oauth2Login() .oauth2User( CustomOauth2User(userId = 1L, name = "sample name", authorities = emptyList()) ) ) ).andExpect(status().isOk)

サンプルコードはこちらから。

©Tsurutan. All Rights Reserved.