How to get @WebMvcTest to work with OAuth?

ぃ、小莉子 submitted on 2019-11-29 07:59:14

[Edit on May 2019]
The solution below is specific to spring-security-oauth2, which is now deprecated.
I wrote a library to achieve the same goal with Spring 5, parts of which were contributed to spring-security-test 5.2. The maintainers chose to integrate only the JWT flow API, so if you need to test a service (which requires an annotation) or use opaque-token introspection, you may need to browse my repo a bit...

[Edit on July 2019]
I now publish my "spring-addons" libraries for Spring 5 to Maven Central, which greatly improves usability.
Sources and READMEs are still on GitHub.

[solution for spring-security-oauth2]
The solution I iterated towards combines a dummy "Authorization" header in requests with a mocked token service that intercepts it (it took quite a few tries, as the edit history shows).

The complete helper sources are in a library on GitHub, where you can also find a sample OAuth2 controller test.

In short: no Authorization header -> ResourceServerTokenServices is not triggered -> the SecurityContext is anonymous in the OAuth stack (whatever you try to set it to with @WithMockUser or similar).

So there are two cases:

  • you're writing integration tests: provide valid tokens and let the real token service do its job and provide the authentication contained in the token
  • you're writing unit tests (my case): mock the token service so that it returns a mocked authentication
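The chain above can be modelled in a few lines of plain Java. This is a simplified sketch, not Spring's code: `Authentication`, `TokenServices` and `ResourceServerModel.resolve` are hypothetical stand-ins for OAuth2Authentication, ResourceServerTokenServices and the resource-server filter. It shows why a mocked annotation alone is not enough: without a Bearer header, the token service is never asked.

```java
import java.util.Map;
import java.util.Optional;

// Hypothetical stand-in for the OAuth2Authentication returned by the token service.
final class Authentication {
    final String name;

    Authentication(String name) {
        this.name = name;
    }
}

// Hypothetical stand-in for ResourceServerTokenServices: token value -> authentication (null if unknown).
interface TokenServices {
    Authentication loadAuthentication(String tokenValue);
}

final class ResourceServerModel {
    static final Authentication ANONYMOUS = new Authentication("anonymousUser");
    private static final String BEARER = "Bearer ";

    // Simplified model of the resource-server filter chain:
    // no Bearer header -> token service never called -> anonymous, whatever the test annotation set up.
    static Authentication resolve(Map<String, String> headers, TokenServices tokenServices) {
        Optional<String> token = Optional.ofNullable(headers.get("Authorization"))
                .filter(h -> h.startsWith(BEARER))
                .map(h -> h.substring(BEARER.length()));
        if (!token.isPresent()) {
            return ANONYMOUS;
        }
        Authentication auth = tokenServices.loadAuthentication(token.get());
        if (auth == null) {
            // an unknown token is rejected (401 in the real stack), not downgraded to anonymous
            throw new IllegalStateException("Invalid token: " + token.get());
        }
        return auth;
    }

    public static void main(String[] args) {
        // unit-test case: a mocked token service that only knows the dummy token
        TokenServices mocked = t -> "test.fake.jwt".equals(t) ? new Authentication("mocked-client") : null;

        System.out.println(resolve(Map.of(), mocked).name);                                          // anonymousUser
        System.out.println(resolve(Map.of("Authorization", "Bearer test.fake.jwt"), mocked).name); // mocked-client
    }
}
```

In the integration-test case you would plug the real token service in place of `mocked`; nothing else in the chain changes.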

After pulling my hair out for a few days and building this from the ground up, I realized a similar approach had already been described here. I just went further with mocked OAuth2Authentication configuration and tooling for @WebMvcTest.

Sample usage

As this post is long and the solution involves quite a lot of code, let's start with the result so you can decide whether it's worth reading ;)

@WebMvcTest(MyController.class) // Controller to unit-test
@Import(WebSecurityConfig.class) // your class extending WebSecurityConfigurerAdapter
public class MyControllerTest extends OAuth2ControllerTest {

    @Test
    public void testWithUnauthenticatedClient() throws Exception {
        api.post(payload, "/endpoint")
                .andExpect(...);
    }

    @Test
    @WithMockOAuth2Client
    public void testWithDefaultClient() throws Exception {
        api.get("/endpoint")
                .andExpect(...);
    }

    @Test
    @WithMockOAuth2User
    public void testWithDefaultClientOnBehalfDefaultUser() throws Exception {
        MockHttpServletRequestBuilder req = api.postRequestBuilder(null, "/uaa/refresh")
                .header("refresh_token", JWT_REFRESH_TOKEN);

        api.perform(req)
                .andExpect(status().isOk())
                .andExpect(...);
    }

    @Test
    @WithMockOAuth2User(
        client = @WithMockOAuth2Client(
                clientId = "custom-client",
                scope = {"custom-scope", "other-scope"},
                authorities = {"custom-authority", "ROLE_CUSTOM_CLIENT"}),
        user = @WithMockUser(
                username = "custom-username",
                authorities = {"custom-user-authority"}))
    public void testWithCustomClientOnBehalfCustomUser() throws Exception {
        api.get(MediaType.APPLICATION_ATOM_XML, "/endpoint")
                .andExpect(status().isOk())
                .andExpect(xpath(...));
    }
}

Funky, isn't it?

P.S. api is an instance of OAuth2MockMvcHelper, which extends MockMvcHelper, a MockMvc wrapper of my own provided at the end of this post.

@WithMockOAuth2Client to simulate client-only authentication (no end-user involved)

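import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.core.authority.SimpleGrantedAuthority;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.oauth2.provider.OAuth2Authentication;
import org.springframework.security.oauth2.provider.OAuth2Request;
import org.springframework.security.test.context.support.WithSecurityContext;
import org.springframework.security.test.context.support.WithSecurityContextFactory;

@Retention(RetentionPolicy.RUNTIME)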
@WithSecurityContext(factory = WithMockOAuth2Client.WithMockOAuth2ClientSecurityContextFactory.class)
public @interface WithMockOAuth2Client {

    String clientId() default "web-client";

    String[] scope() default {"openid"};

    String[] authorities() default {};

    boolean approved() default true;

    class WithMockOAuth2ClientSecurityContextFactory implements WithSecurityContextFactory<WithMockOAuth2Client> {

        public static OAuth2Request getOAuth2Request(final WithMockOAuth2Client annotation) {
            final Set<? extends GrantedAuthority> authorities = Stream.of(annotation.authorities())
                    .map(auth -> new SimpleGrantedAuthority(auth))
                    .collect(Collectors.toSet());

            final Set<String> scope = Stream.of(annotation.scope())
                    .collect(Collectors.toSet());

            return new OAuth2Request(
                    null,
                    annotation.clientId(),
                    authorities,
                    annotation.approved(),
                    scope,
                    null,
                    null,
                    null,
                    null);
        }

        @Override
        public SecurityContext createSecurityContext(final WithMockOAuth2Client annotation) {
            final SecurityContext ctx = SecurityContextHolder.createEmptyContext();
            ctx.setAuthentication(new OAuth2Authentication(getOAuth2Request(annotation), null));
            SecurityContextHolder.setContext(ctx);
            return ctx;
        }
    }

}
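For readers unfamiliar with @WithSecurityContext: the factory simply receives the annotation instance and turns its attributes into an authentication. The stdlib-only sketch below mimics that mechanism with a hypothetical `@MockClient` annotation read by reflection. The real WithSecurityContextTestExecutionListener does more (class-level and meta-annotation lookup, populating SecurityContextHolder), so treat this as an illustration only.

```java
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.reflect.Method;
import java.util.Arrays;
import java.util.LinkedHashSet;
import java.util.Set;

// Hypothetical, stripped-down analogue of @WithMockOAuth2Client (same attribute defaults).
@Retention(RetentionPolicy.RUNTIME) // without RUNTIME retention, reflection could not see it
@interface MockClient {
    String clientId() default "web-client";

    String[] scope() default {"openid"};

    String[] authorities() default {};
}

// Hypothetical analogue of the OAuth2Request built by getOAuth2Request().
final class ClientAuth {
    final String clientId;
    final Set<String> scope;
    final Set<String> authorities;

    ClientAuth(String clientId, Set<String> scope, Set<String> authorities) {
        this.clientId = clientId;
        this.scope = scope;
        this.authorities = authorities;
    }
}

final class MockClientFactory {
    // Same idea as the security context factory: annotation attributes in, authentication out.
    static ClientAuth fromAnnotation(MockClient a) {
        return new ClientAuth(
                a.clientId(),
                new LinkedHashSet<>(Arrays.asList(a.scope())),
                new LinkedHashSet<>(Arrays.asList(a.authorities())));
    }

    // Roughly what the test listener does before each test: find the annotation on the test method.
    static ClientAuth forTestMethod(Class<?> testClass, String methodName) {
        try {
            Method method = testClass.getDeclaredMethod(methodName);
            MockClient annotation = method.getAnnotation(MockClient.class);
            return annotation == null ? null : fromAnnotation(annotation); // null: no mocked authentication
        } catch (NoSuchMethodException e) {
            throw new IllegalArgumentException(e);
        }
    }

    public static void main(String[] args) {
        System.out.println(forTestMethod(SampleTests.class, "defaultClient").clientId); // web-client
        System.out.println(forTestMethod(SampleTests.class, "customClient").scope);     // [custom-scope, other-scope]
        System.out.println(forTestMethod(SampleTests.class, "anonymous"));              // null
    }
}

// Hypothetical test class carrying the annotations.
class SampleTests {
    @MockClient
    void defaultClient() {
    }

    @MockClient(clientId = "custom-client", scope = {"custom-scope", "other-scope"})
    void customClient() {
    }

    void anonymous() {
    }
}
```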

@WithMockOAuth2User to simulate a client authenticating on behalf of an end-user

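import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import org.springframework.security.authentication.UsernamePasswordAuthenticationToken;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.core.authority.SimpleGrantedAuthority;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.userdetails.User;
import org.springframework.security.oauth2.provider.OAuth2Authentication;
import org.springframework.security.test.context.support.WithMockUser;
import org.springframework.security.test.context.support.WithSecurityContext;
import org.springframework.security.test.context.support.WithSecurityContextFactory;

@Retention(RetentionPolicy.RUNTIME)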
@WithSecurityContext(factory = WithMockOAuth2User.WithMockOAuth2UserSecurityContextFactory.class)
public @interface WithMockOAuth2User {

    WithMockOAuth2Client client() default @WithMockOAuth2Client();

    WithMockUser user() default @WithMockUser();

    class WithMockOAuth2UserSecurityContextFactory implements WithSecurityContextFactory<WithMockOAuth2User> {

        /**
         * Sadly, #WithMockUserSecurityContextFactory is not public,
         * so re-implement mock user authentication creation
         *
         * @param user
         * @return an Authentication with provided user details
         */
        public static UsernamePasswordAuthenticationToken getUserAuthentication(final WithMockUser user) {
            final String principal = user.username().isEmpty() ? user.value() : user.username();

            final Stream<String> grants = user.authorities().length == 0 ?
                    Stream.of(user.roles()).map(r -> "ROLE_" + r) :
                    Stream.of(user.authorities());

            final Set<? extends GrantedAuthority> userAuthorities = grants
                    .map(auth -> new SimpleGrantedAuthority(auth))
                    .collect(Collectors.toSet());

            return new UsernamePasswordAuthenticationToken(
                    new User(principal, user.password(), userAuthorities),
                    principal + ":" + user.password(),
                    userAuthorities);
        }

        @Override
        public SecurityContext createSecurityContext(final WithMockOAuth2User annotation) {
            final SecurityContext ctx = SecurityContextHolder.createEmptyContext();
            ctx.setAuthentication(new OAuth2Authentication(
                    WithMockOAuth2Client.WithMockOAuth2ClientSecurityContextFactory.getOAuth2Request(annotation.client()),
                    getUserAuthentication(annotation.user())));
            SecurityContextHolder.setContext(ctx);
            return ctx;
        }
    }
}
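The authority rule re-implemented in getUserAuthentication is easy to get wrong, so here it is in isolation as plain Java. `MockUserAuthorities` is a hypothetical class working on strings instead of GrantedAuthority: explicit authorities win, otherwise each role gets the `ROLE_` prefix, and username takes precedence over value. The values used in main are the @WithMockUser defaults (value "user", roles {"USER"}, no authorities).

```java
import java.util.Arrays;
import java.util.Set;
import java.util.TreeSet;
import java.util.stream.Collectors;
import java.util.stream.Stream;

final class MockUserAuthorities {
    // Mirrors getUserAuthentication(): explicit authorities win,
    // otherwise each role becomes an authority prefixed with "ROLE_".
    static Set<String> resolve(String[] roles, String[] authorities) {
        Stream<String> grants = authorities.length == 0
                ? Arrays.stream(roles).map(r -> "ROLE_" + r)
                : Arrays.stream(authorities);
        // TreeSet only to get a stable, sorted order when printing
        return grants.collect(Collectors.toCollection(TreeSet::new));
    }

    // Mirrors the principal choice: username() if set, else value().
    static String principal(String value, String username) {
        return username.isEmpty() ? value : username;
    }

    public static void main(String[] args) {
        System.out.println(resolve(new String[] {"USER"}, new String[] {}));                        // [ROLE_USER]
        System.out.println(resolve(new String[] {"USER"}, new String[] {"custom-user-authority"})); // [custom-user-authority]
        System.out.println(principal("user", "custom-username"));                                   // custom-username
    }
}
```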

OAuth2MockMvcHelper helps build test requests with the expected Authorization header

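import java.util.Optional;

import org.springframework.beans.factory.ObjectFactory;
import org.springframework.boot.autoconfigure.http.HttpMessageConverters;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.oauth2.provider.OAuth2Authentication;
import org.springframework.test.web.servlet.MockMvc;
import org.springframework.test.web.servlet.request.MockHttpServletRequestBuilder;

public class OAuth2MockMvcHelper extends MockMvcHelper {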
    public static final String VALID_TEST_TOKEN_VALUE = "test.fake.jwt";

    public OAuth2MockMvcHelper(
            final MockMvc mockMvc,
            final ObjectFactory<HttpMessageConverters> messageConverters,
            final MediaType defaultMediaType) {
        super(mockMvc, messageConverters, defaultMediaType);
    }

    /**
     * Adds OAuth2 support: adds an Authorization header to all request builders
     * if there is an OAuth2Authentication in the test security context.
     *
     * /!\ Make sure your token services recognize this dummy "VALID_TEST_TOKEN_VALUE" token as valid during your tests /!\
     *
     * @param contentType should be present when issuing a request with a body (POST, PUT, PATCH), empty otherwise
     * @param accept      should be present when expecting a response with a body (GET, POST, OPTIONS), empty otherwise
     * @param method
     * @param urlTemplate
     * @param uriVars
     * @return a request builder with minimal info you can tweak further (add headers, cookies, etc.)
     */
    @Override
    public MockHttpServletRequestBuilder requestBuilder(
            Optional<MediaType> contentType,
            Optional<MediaType> accept,
            HttpMethod method,
            String urlTemplate,
            Object... uriVars) {
        final MockHttpServletRequestBuilder builder = super.requestBuilder(contentType, accept, method, urlTemplate, uriVars);
        if (SecurityContextHolder.getContext().getAuthentication() instanceof OAuth2Authentication) {
            builder.header("Authorization", "Bearer " + VALID_TEST_TOKEN_VALUE);
        }
        return builder;
    }
}
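Checking SecurityContextHolder at request-build time works because the @WithMock... annotations populate a thread-local context before the test method runs, and the helper reads it when the request is built. Here is a minimal plain-Java model of that; `FakeContextHolder` and `FakeOAuth2Authentication` are hypothetical stand-ins for SecurityContextHolder (with its default thread-local strategy) and OAuth2Authentication.

```java
import java.util.LinkedHashMap;
import java.util.Map;

// Hypothetical marker type standing in for OAuth2Authentication.
final class FakeOAuth2Authentication {
}

// Minimal thread-local holder, in the spirit of SecurityContextHolder's default strategy.
final class FakeContextHolder {
    private static final ThreadLocal<Object> AUTH = new ThreadLocal<>();

    static void set(Object authentication) { AUTH.set(authentication); }

    static Object get() { return AUTH.get(); }

    static void clear() { AUTH.remove(); }
}

final class BearerHeaderModel {
    static final String VALID_TEST_TOKEN_VALUE = "test.fake.jwt";

    // Model of OAuth2MockMvcHelper.requestBuilder(): the dummy Bearer header is added only
    // when the current thread's context holds an OAuth2 authentication.
    static Map<String, String> requestHeaders(Map<String, String> baseHeaders) {
        Map<String, String> headers = new LinkedHashMap<>(baseHeaders);
        if (FakeContextHolder.get() instanceof FakeOAuth2Authentication) {
            headers.put("Authorization", "Bearer " + VALID_TEST_TOKEN_VALUE);
        }
        return headers;
    }

    public static void main(String[] args) {
        System.out.println(requestHeaders(Map.of("Accept", "application/json"))); // {Accept=application/json}
        FakeContextHolder.set(new FakeOAuth2Authentication()); // what @WithMockOAuth2Client achieves
        System.out.println(requestHeaders(Map.of("Accept", "application/json")));
        // {Accept=application/json, Authorization=Bearer test.fake.jwt}
        FakeContextHolder.clear();
    }
}
```

Tests without an OAuth2 annotation therefore send no Authorization header at all, which is exactly the "unauthenticated client" case of the first sample test.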

OAuth2ControllerTest: a parent for controller unit tests

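import static org.mockito.Mockito.when;

import org.junit.Before;
import org.junit.runner.RunWith;
import org.springframework.beans.factory.ObjectFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.boot.autoconfigure.http.HttpMessageConverters;
import org.springframework.boot.test.context.TestConfiguration;
import org.springframework.boot.test.mock.mockito.MockBean;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Import;
import org.springframework.http.MediaType;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.oauth2.provider.token.ResourceServerTokenServices;
import org.springframework.test.context.junit4.SpringRunner;
import org.springframework.test.web.servlet.MockMvc;

// abstract: this class has no @Test methods, so JUnit must not try to run it on its own
@RunWith(SpringRunner.class)
@Import(OAuth2MockMvcConfig.class)
public abstract class OAuth2ControllerTest {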

    @MockBean
    private ResourceServerTokenServices tokenService;

    @Autowired
    protected OAuth2MockMvcHelper api;

    @Autowired
    protected SerializationHelper conv;

    @Before
    public void setUpTokenService() {
        when(tokenService.loadAuthentication(OAuth2MockMvcHelper.VALID_TEST_TOKEN_VALUE))
                .thenAnswer(invocation -> SecurityContextHolder.getContext().getAuthentication());
    }
}

@TestConfiguration
class OAuth2MockMvcConfig {

    @Bean
    public SerializationHelper serializationHelper(ObjectFactory<HttpMessageConverters> messageConverters) {
        return new SerializationHelper(messageConverters);
    }

    @Bean
    public OAuth2MockMvcHelper mockMvcHelper(
            MockMvc mockMvc,
            ObjectFactory<HttpMessageConverters> messageConverters,
            @Value("${controllers.default-media-type:application/json;charset=UTF-8}") MediaType defaultMediaType) {
        return new OAuth2MockMvcHelper(mockMvc, messageConverters, defaultMediaType);
    }

}
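A note on thenAnswer: the stub is registered once in @Before, but the answer lambda is evaluated each time the token service is called, so it returns whatever authentication the context holds at that moment. The plain-Java sketch below contrasts that with a value captured at stub time (as thenReturn would do). The `eagerStub`/`deferredStub` helpers and the thread-local `CONTEXT` are hypothetical and only stand in for Mockito and SecurityContextHolder.

```java
import java.util.function.Function;
import java.util.function.Supplier;

final class DeferredStubModel {
    // Hypothetical thread-local "security context" for the model.
    static final ThreadLocal<String> CONTEXT = new ThreadLocal<>();

    // Like when(...).thenReturn(value): the value is captured once, when the stub is created.
    static Function<String, String> eagerStub(String expectedToken, String value) {
        return token -> expectedToken.equals(token) ? value : null;
    }

    // Like when(...).thenAnswer(invocation -> ...): the supplier runs on every matching call.
    static Function<String, String> deferredStub(String expectedToken, Supplier<String> answer) {
        return token -> expectedToken.equals(token) ? answer.get() : null;
    }

    public static void main(String[] args) {
        CONTEXT.remove(); // stub registered while the context is still empty
        Function<String, String> eager = eagerStub("test.fake.jwt", CONTEXT.get());
        Function<String, String> deferred = deferredStub("test.fake.jwt", CONTEXT::get);

        CONTEXT.set("custom-client"); // later, e.g. a per-test annotation populates the context
        System.out.println(eager.apply("test.fake.jwt"));    // null
        System.out.println(deferred.apply("test.fake.jwt")); // custom-client
        CONTEXT.remove();
    }
}
```

With the deferred form, a single stub in the parent class serves every test, whatever annotation each one carries.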

Tooling referenced above but not directly related to OAuth2 testing

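import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.request;

import java.util.Optional;

import org.springframework.beans.factory.ObjectFactory;
import org.springframework.boot.autoconfigure.http.HttpMessageConverters;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.test.web.servlet.MockMvc;
import org.springframework.test.web.servlet.ResultActions;
import org.springframework.test.web.servlet.request.MockHttpServletRequestBuilder;

/**
 * Wraps MockMvc to further ease interaction with the tested API. Provides:<ul>
 * <li>many request shortcuts for simple cases (see get, post, put, patch, delete methods)</li>
 * <li>a perform method along with request builder initialisation shortcuts (see getRequestBuilder, etc.) when more control is required (additional headers, ...)</li>
 * </ul>
 */
public class MockMvcHelper {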

    private final MockMvc mockMvc;

    private final MediaType defaultMediaType;

    protected final SerializationHelper conv;

    public MockMvcHelper(MockMvc mockMvc, ObjectFactory<HttpMessageConverters> messageConverters, MediaType defaultMediaType) {
        this.mockMvc = mockMvc;
        this.conv = new SerializationHelper(messageConverters);
        this.defaultMediaType = defaultMediaType;
    }

    /**
     * Generic request builder which adds relevant "Accept" and "Content-Type" headers
     *
     * @param contentType should be present when issuing a request with a body (POST, PUT, PATCH), empty otherwise
     * @param accept      should be present when expecting a response with a body (GET, POST, OPTIONS), empty otherwise
     * @param method
     * @param urlTemplate
     * @param uriVars
     * @return a request builder with minimal info you can tweak further: add headers, cookies, etc.
     */
    public MockHttpServletRequestBuilder requestBuilder(
            Optional<MediaType> contentType,
            Optional<MediaType> accept,
            HttpMethod method,
            String urlTemplate,
            Object... uriVars) {
        final MockHttpServletRequestBuilder builder = request(method, urlTemplate, uriVars);
        contentType.ifPresent(builder::contentType);
        accept.ifPresent(builder::accept);
        return builder;
    }

    public ResultActions perform(MockHttpServletRequestBuilder request) throws Exception {
        return mockMvc.perform(request);
    }

    /* GET */
    public MockHttpServletRequestBuilder getRequestBuilder(MediaType accept, String urlTemplate, Object... uriVars) {
        return requestBuilder(Optional.empty(), Optional.of(accept), HttpMethod.GET, urlTemplate, uriVars);
    }

    public MockHttpServletRequestBuilder getRequestBuilder(String urlTemplate, Object... uriVars) {
        return getRequestBuilder(defaultMediaType, urlTemplate, uriVars);
    }

    public ResultActions get(MediaType accept, String urlTemplate, Object... uriVars) throws Exception {
        return mockMvc.perform(getRequestBuilder(accept, urlTemplate, uriVars));
    }

    public ResultActions get(String urlTemplate, Object... uriVars) throws Exception {
        return mockMvc.perform(getRequestBuilder(urlTemplate, uriVars));
    }

    /* POST */
    public <T> MockHttpServletRequestBuilder postRequestBuilder(final T payload, MediaType contentType, MediaType accept, String urlTemplate, Object... uriVars) throws Exception {
        return feed(
                requestBuilder(Optional.of(contentType), Optional.of(accept), HttpMethod.POST, urlTemplate, uriVars),
                payload,
                contentType);
    }

    public <T> MockHttpServletRequestBuilder postRequestBuilder(final T payload, String urlTemplate, Object... uriVars) throws Exception {
        return postRequestBuilder(payload, defaultMediaType, defaultMediaType, urlTemplate, uriVars);
    }

    public <T> ResultActions post(final T payload, MediaType contentType, MediaType accept, String urlTemplate, Object... uriVars) throws Exception {
        return mockMvc.perform(postRequestBuilder(payload, contentType, accept, urlTemplate, uriVars));
    }

    public <T> ResultActions post(final T payload, String urlTemplate, Object... uriVars) throws Exception {
        return mockMvc.perform(postRequestBuilder(payload, urlTemplate, uriVars));
    }


    /* PUT */
    public <T> MockHttpServletRequestBuilder putRequestBuilder(final T payload, MediaType contentType, String urlTemplate, Object... uriVars) throws Exception {
        return feed(
                requestBuilder(Optional.of(contentType), Optional.empty(), HttpMethod.PUT, urlTemplate, uriVars),
                payload,
                contentType);
    }

    public <T> MockHttpServletRequestBuilder putRequestBuilder(final T payload, String urlTemplate, Object... uriVars) throws Exception {
        return putRequestBuilder(payload, defaultMediaType, urlTemplate, uriVars);
    }

    public <T> ResultActions put(final T payload, MediaType contentType, String urlTemplate, Object... uriVars) throws Exception {
        return mockMvc.perform(putRequestBuilder(payload, contentType, urlTemplate, uriVars));
    }

    public <T> ResultActions put(final T payload, String urlTemplate, Object... uriVars) throws Exception {
        return mockMvc.perform(putRequestBuilder(payload, urlTemplate, uriVars));
    }


    /* PATCH */
    public <T> MockHttpServletRequestBuilder patchRequestBuilder(final T payload, MediaType contentType, String urlTemplate, Object... uriVars) throws Exception {
        return feed(
                requestBuilder(Optional.of(contentType), Optional.empty(), HttpMethod.PATCH, urlTemplate, uriVars),
                payload,
                contentType);
    }

    public <T> MockHttpServletRequestBuilder patchRequestBuilder(final T payload, String urlTemplate, Object... uriVars) throws Exception {
        return patchRequestBuilder(payload, defaultMediaType, urlTemplate, uriVars);
    }

    public <T> ResultActions patch(final T payload, MediaType contentType, String urlTemplate, Object... uriVars) throws Exception {
        return mockMvc.perform(patchRequestBuilder(payload, contentType, urlTemplate, uriVars));
    }

    public <T> ResultActions patch(final T payload, String urlTemplate, Object... uriVars) throws Exception {
        return mockMvc.perform(patchRequestBuilder(payload, urlTemplate, uriVars));
    }


    /* DELETE */
    public MockHttpServletRequestBuilder deleteRequestBuilder(String urlTemplate, Object... uriVars) {
        return requestBuilder(Optional.empty(), Optional.empty(), HttpMethod.DELETE, urlTemplate, uriVars);
    }

    public ResultActions delete(String urlTemplate, Object... uriVars) throws Exception {
        return mockMvc.perform(deleteRequestBuilder(urlTemplate, uriVars));
    }


    /* HEAD */
    public MockHttpServletRequestBuilder headRequestBuilder(String urlTemplate, Object... uriVars) {
        return requestBuilder(Optional.empty(), Optional.empty(), HttpMethod.HEAD, urlTemplate, uriVars);
    }

    public ResultActions head(String urlTemplate, Object... uriVars) throws Exception {
        return mockMvc.perform(headRequestBuilder(urlTemplate, uriVars));
    }


    /* OPTIONS */
    public MockHttpServletRequestBuilder optionRequestBuilder(MediaType accept, String urlTemplate, Object... uriVars) {
        return requestBuilder(Optional.empty(), Optional.of(accept), HttpMethod.OPTIONS, urlTemplate, uriVars);
    }

    public MockHttpServletRequestBuilder optionRequestBuilder(String urlTemplate, Object... uriVars) {
        return requestBuilder(Optional.empty(), Optional.of(defaultMediaType), HttpMethod.OPTIONS, urlTemplate, uriVars);
    }

    public ResultActions option(MediaType accept, String urlTemplate, Object... uriVars) throws Exception {
        return mockMvc.perform(optionRequestBuilder(accept, urlTemplate, uriVars));
    }

    public ResultActions option(String urlTemplate, Object... uriVars) throws Exception {
        return mockMvc.perform(optionRequestBuilder(urlTemplate, uriVars));
    }

    /**
     * Adds serialized payload to request content
     *
     * @param request
     * @param payload
     * @param mediaType
     * @param <T>
     * @return the request with provided payload as content
     * @throws Exception if things go wrong (no registered serializer for payload type and asked MediaType, serialization failure, ...)
     */
    public <T> MockHttpServletRequestBuilder feed(
            MockHttpServletRequestBuilder request,
            final T payload,
            final MediaType mediaType) throws Exception {
        if (payload == null) {
            return request;
        }

        final SerializationHelper.ByteArrayHttpOutputMessage msg = conv.outputMessage(payload, mediaType);
        return request
                .headers(msg.headers)
                .content(msg.out.toByteArray());
    }
}
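The request-builder shortcuts follow a simple convention, as described in the javadoc: Content-Type for verbs that send a body (POST, PUT, PATCH) and Accept for verbs that expect one (GET, POST, OPTIONS). The plain-Java table below is a hypothetical restatement of that convention (strings instead of MediaType, a local `HttpVerb` enum instead of Spring's HttpMethod), handy for checking which default headers a shortcut will send.

```java
import java.util.EnumSet;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Optional;

// Hypothetical local enum standing in for Spring's HttpMethod.
enum HttpVerb { GET, POST, PUT, PATCH, DELETE, HEAD, OPTIONS }

final class DefaultHeadersModel {
    // verbs whose shortcuts set Content-Type (they send a body)
    static final EnumSet<HttpVerb> SENDS_BODY = EnumSet.of(HttpVerb.POST, HttpVerb.PUT, HttpVerb.PATCH);
    // verbs whose shortcuts set Accept (a response body is expected)
    static final EnumSet<HttpVerb> EXPECTS_BODY = EnumSet.of(HttpVerb.GET, HttpVerb.POST, HttpVerb.OPTIONS);

    // Same shape as MockMvcHelper.requestBuilder(): each optional header is applied only when present.
    static Map<String, String> headers(Optional<String> contentType, Optional<String> accept) {
        Map<String, String> result = new LinkedHashMap<>();
        contentType.ifPresent(ct -> result.put("Content-Type", ct));
        accept.ifPresent(a -> result.put("Accept", a));
        return result;
    }

    // Headers a default-media-type shortcut (get(url), post(payload, url), ...) ends up with for a verb.
    static Map<String, String> shortcutHeaders(HttpVerb verb, String defaultMediaType) {
        Optional<String> contentType = SENDS_BODY.contains(verb) ? Optional.of(defaultMediaType) : Optional.<String>empty();
        Optional<String> accept = EXPECTS_BODY.contains(verb) ? Optional.of(defaultMediaType) : Optional.<String>empty();
        return headers(contentType, accept);
    }

    public static void main(String[] args) {
        for (HttpVerb verb : HttpVerb.values()) {
            System.out.println(verb + " -> " + shortcutHeaders(verb, "application/json"));
        }
    }
}
```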
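import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.util.List;
import java.util.stream.Collectors;

import org.springframework.beans.factory.ObjectFactory;
import org.springframework.boot.autoconfigure.http.HttpMessageConverters;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpOutputMessage;
import org.springframework.http.MediaType;
import org.springframework.http.converter.HttpMessageConversionException;
import org.springframework.http.converter.HttpMessageConverter;

/**
 * Serializes objects to a given media type using registered message converters
 */
public class SerializationHelper {

    private final ObjectFactory<HttpMessageConverters> messageConverters;

    public SerializationHelper(ObjectFactory<HttpMessageConverters> messageConverters) {
        this.messageConverters = messageConverters;
    }

    @SuppressWarnings("unchecked")
    public <T> ByteArrayHttpOutputMessage outputMessage(final T payload, final MediaType mediaType) throws Exception {
        if (payload == null) {
            return null;
        }

        List<HttpMessageConverter<?>> relevantConverters = messageConverters.getObject().getConverters().stream()
                .filter(converter -> converter.canWrite(payload.getClass(), mediaType))
                .collect(Collectors.toList());

        for (HttpMessageConverter<?> converter : relevantConverters) {
            // fresh message for each attempt, so a converter failing half-way leaves no partial body or headers behind
            final ByteArrayHttpOutputMessage converted = new ByteArrayHttpOutputMessage();
            try {
                ((HttpMessageConverter<T>) converter).write(payload, mediaType, converted);
                return converted; // stop after the first successful conversion
            } catch (IOException | HttpMessageConversionException e) {
                // swallow the exception so that the next converter is tried
                // (Jackson, for instance, reports serialization failures as HttpMessageNotWritableException)
            }
        }

        throw new Exception("Could not convert " + payload.getClass() + " to " + mediaType);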
    }

    /**
     * Provides a String representation of provided payload
     *
     * @param payload
     * @param mediaType
     * @param <T>
     * @return the serialized payload, or null if payload is null
     * @throws Exception if things go wrong (no registered serializer for payload type and asked MediaType, serialization failure, ...)
     */
    public <T> String asString(T payload, MediaType mediaType) throws Exception {
        return payload == null ?
                null :
                outputMessage(payload, mediaType).out.toString();
    }

    public <T> String asJsonString(T payload) throws Exception {
        return asString(payload, MediaType.APPLICATION_JSON_UTF8);
    }

    public static final class ByteArrayHttpOutputMessage implements HttpOutputMessage {
        public final ByteArrayOutputStream out = new ByteArrayOutputStream();
        public final HttpHeaders headers = new HttpHeaders();

        @Override
        public OutputStream getBody() {
            return out;
        }

        @Override
        public HttpHeaders getHeaders() {
            return headers;
        }
    }
}
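SerializationHelper's strategy is "try each converter that claims it can write the payload, keep the first that succeeds". Stripped of Spring types it looks roughly like the sketch below, where `Converter` is a hypothetical reduced interface, not HttpMessageConverter. Allocating a fresh buffer per attempt keeps a converter that fails half-way from leaving partial output for the next one.

```java
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.List;

// Hypothetical, reduced analogue of HttpMessageConverter: may decline (canWrite) or fail (write).
interface Converter {
    boolean canWrite(Class<?> type, String mediaType);

    void write(Object payload, ByteArrayOutputStream out) throws IOException;
}

final class FirstSuccessfulConverter {
    // Try converters that claim support, in order, and keep the first one that completes.
    static byte[] convert(List<Converter> converters, Object payload, String mediaType) {
        for (Converter c : converters) {
            if (!c.canWrite(payload.getClass(), mediaType)) {
                continue;
            }
            ByteArrayOutputStream out = new ByteArrayOutputStream(); // fresh buffer for each attempt
            try {
                c.write(payload, out);
                return out.toByteArray();
            } catch (IOException e) {
                // swallow and fall through to the next candidate
            }
        }
        throw new IllegalArgumentException("Could not convert " + payload.getClass() + " to " + mediaType);
    }

    public static void main(String[] args) {
        Converter failing = new Converter() {
            public boolean canWrite(Class<?> type, String mediaType) { return true; }

            public void write(Object payload, ByteArrayOutputStream out) throws IOException {
                out.write('x'); // partial output before failing
                throw new IOException("boom");
            }
        };
        Converter plain = new Converter() {
            public boolean canWrite(Class<?> type, String mediaType) { return "text/plain".equals(mediaType); }

            public void write(Object payload, ByteArrayOutputStream out) throws IOException {
                out.write(payload.toString().getBytes(StandardCharsets.UTF_8));
            }
        };
        byte[] bytes = convert(List.of(failing, plain), 42, "text/plain");
        System.out.println(new String(bytes, StandardCharsets.UTF_8)); // 42
    }
}
```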