Web sockets with Spring and Spring Security


WebSocket security based on custom message headers is not widely covered online, and examples of how to implement it are even rarer. So I decided to write about WebSockets, and in particular about applying custom security at the socket level.

Any feedback or comments would be appreciated.

Spring Security

In this example I want to authenticate based on a custom message header that contains a token. The token is transformed into a User object, which is then added to the SecurityContext.

What is supported out of the box:

Out of the box, Spring Security secures WebSocket messages and populates the SecurityContext from the simpUser message header. However, you may want to authenticate on a different identifier, in this example a token. Any custom header you attach ends up under the nativeHeaders message header, which holds the custom (native STOMP) headers. The default Spring Security implementation does not read the headers under nativeHeaders, so it cannot perform our authentication. Spring also does not let us extend the default functionality to support our requirements, so we have to write our own configuration.
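To make the nativeHeaders point concrete, here is a small stdlib-only Java model (no Spring on the classpath) of the headers of an inbound CONNECT message. Plain maps stand in for Spring's MessageHeaders. The class NativeHeaderTokenExample, the token value and the in-memory token store are hypothetical, and the store only plays the role of TokenAuthenticationService. The header name X-AUTH-TOKEN is taken from the client code. The model shows that the token can only be reached by looking inside nativeHeaders, and that a missing or unknown token falls back to an anonymous user.

```java
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * Stdlib-only model (no Spring on the classpath): plain maps stand in for
 * Spring's MessageHeaders, and TOKENS is a hypothetical token store playing
 * the role of TokenAuthenticationService.
 */
final class NativeHeaderTokenExample {

    static final String NATIVE_HEADERS = "nativeHeaders";
    static final String AUTH_HEADER = "X-AUTH-TOKEN";
    static final String ANONYMOUS = "anonymous";

    // Hypothetical token store: token -> user name.
    static final Map<String, String> TOKENS = Collections.singletonMap("token-123", "alice");

    /** Builds headers shaped like an inbound STOMP message: custom headers are nested, not top-level. */
    static Map<String, Object> messageHeaders(String token) {
        Map<String, List<String>> nativeHeaders = new HashMap<>();
        if (token != null) {
            nativeHeaders.put(AUTH_HEADER, Collections.singletonList(token));
        }
        Map<String, Object> headers = new HashMap<>();
        headers.put("simpMessageType", "CONNECT"); // a String here; Spring stores an enum value
        headers.put(NATIVE_HEADERS, nativeHeaders);
        return headers;
    }

    /** Reads the first X-AUTH-TOKEN value from nativeHeaders; missing or unknown tokens become anonymous. */
    @SuppressWarnings("unchecked")
    static String resolveUser(Map<String, Object> headers) {
        Map<String, List<String>> nativeHeaders = (Map<String, List<String>>) headers.get(NATIVE_HEADERS);
        if (nativeHeaders == null) {
            return ANONYMOUS;
        }
        List<String> values = nativeHeaders.get(AUTH_HEADER);
        if (values == null || values.isEmpty()) {
            return ANONYMOUS;
        }
        return TOKENS.getOrDefault(values.get(0), ANONYMOUS);
    }

    public static void main(String[] args) {
        System.out.println(resolveUser(messageHeaders("token-123")));     // alice
        System.out.println(resolveUser(messageHeaders("bogus")));         // anonymous
        System.out.println(resolveUser(messageHeaders(null)));            // anonymous
        System.out.println(messageHeaders("token-123").get(AUTH_HEADER)); // null: not a top-level header
    }
}
```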

What is required:

To authenticate on custom headers we need to override Spring Security's default behavior. Because it does not let us override only the parts we want to enhance, we have to create our own configurer, AbstractSecurityWebSocketMessageBrokerConfig. It extends AbstractWebSocketMessageBrokerConfigurer and is based on Spring Security's AbstractSecurityWebSocketMessageBrokerConfigurer.

The three major changes to note are:
1. This class is not marked as final, so it can be extended.
2. configureClientInboundChannel is not marked as final, so it can be overridden.
3. securityContextChannelInterceptor now returns our injected TokenSecurityChannelInterceptor rather than a SecurityContextChannelInterceptor.
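A stdlib-only model of the inbound channel shows why change 3 matters. This is not Spring's API. The interceptors are modelled as UnaryOperator<String> stand-ins, and InboundChainExample, the ThreadLocal "current user" and the hard-coded token are hypothetical. The authorization step only sees an authenticated user when the token interceptor runs before it. This is why configureClientInboundChannel registers the security context interceptor first.

```java
import java.util.Arrays;
import java.util.List;
import java.util.function.UnaryOperator;

/**
 * Stdlib-only model of the inbound channel; NOT Spring's API.
 * A "message" is just the token string it carries.
 */
final class InboundChainExample {

    // Stand-in for SecurityContextHolder: the user for the current thread.
    static final ThreadLocal<String> CURRENT_USER = ThreadLocal.withInitial(() -> "anonymous");

    // Stand-in for TokenSecurityChannelInterceptor (hypothetical token rule).
    static final UnaryOperator<String> TOKEN_INTERCEPTOR = token -> {
        CURRENT_USER.set("token-123".equals(token) ? "alice" : "anonymous");
        return token;
    };

    // Stand-in for ChannelSecurityInterceptor enforcing authenticated().
    static final UnaryOperator<String> AUTHORIZATION = token -> {
        if ("anonymous".equals(CURRENT_USER.get())) {
            throw new SecurityException("Access is denied");
        }
        return token;
    };

    /** Runs the message through the chain in order; returns false if any step denies it. */
    static boolean send(String token, List<UnaryOperator<String>> chain) {
        CURRENT_USER.remove();
        try {
            String message = token;
            for (UnaryOperator<String> interceptor : chain) {
                message = interceptor.apply(message);
            }
            return true;
        } catch (SecurityException denied) {
            return false;
        } finally {
            CURRENT_USER.remove();
        }
    }

    public static void main(String[] args) {
        List<UnaryOperator<String>> tokenFirst = Arrays.asList(TOKEN_INTERCEPTOR, AUTHORIZATION);
        List<UnaryOperator<String>> authFirst = Arrays.asList(AUTHORIZATION, TOKEN_INTERCEPTOR);
        System.out.println(send("token-123", tokenFirst)); // true
        System.out.println(send("token-123", authFirst));  // false: no user yet when checked
        System.out.println(send("wrong", tokenFirst));     // false: falls back to anonymous
    }
}
```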

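import java.util.ArrayList;
import java.util.List;
import java.util.Map;

import org.springframework.beans.factory.SmartInitializingSingleton;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.ApplicationContext;
import org.springframework.context.annotation.Bean;
import org.springframework.messaging.handler.invocation.HandlerMethodArgumentResolver;
import org.springframework.messaging.simp.config.ChannelRegistration;
import org.springframework.messaging.support.ChannelInterceptor;
import org.springframework.messaging.support.ChannelInterceptorAdapter;
import org.springframework.security.access.AccessDecisionVoter;
import org.springframework.security.access.vote.AffirmativeBased;
import org.springframework.security.config.annotation.web.messaging.MessageSecurityMetadataSourceRegistry;
import org.springframework.security.messaging.access.expression.MessageExpressionVoter;
import org.springframework.security.messaging.access.intercept.ChannelSecurityInterceptor;
import org.springframework.security.messaging.access.intercept.MessageSecurityMetadataSource;
import org.springframework.security.messaging.context.AuthenticationPrincipalArgumentResolver;
import org.springframework.security.messaging.web.csrf.CsrfChannelInterceptor;
import org.springframework.security.messaging.web.socket.server.CsrfTokenHandshakeInterceptor;
import org.springframework.web.servlet.handler.SimpleUrlHandlerMapping;
import org.springframework.web.socket.config.annotation.AbstractWebSocketMessageBrokerConfigurer;
import org.springframework.web.socket.config.annotation.StompEndpointRegistry;
import org.springframework.web.socket.server.HandshakeInterceptor;
import org.springframework.web.socket.server.support.WebSocketHttpRequestHandler;
import org.springframework.web.socket.sockjs.SockJsService;
import org.springframework.web.socket.sockjs.support.SockJsHttpRequestHandler;
import org.springframework.web.socket.sockjs.transport.TransportHandlingSockJsService;

public class AbstractSecurityWebSocketMessageBrokerConfig extends AbstractWebSocketMessageBrokerConfigurer implements SmartInitializingSingleton {

    private final WebSocketMessageSecurityMetadataSourceRegistry inboundRegistry = new WebSocketMessageSecurityMetadataSourceRegistry();

    @Autowired
    private ApplicationContext context;

    // Our token-aware replacement for Spring Security's SecurityContextChannelInterceptor.
    @Autowired
    private TokenSecurityChannelInterceptor tokenSecurityChannelInterceptor;

    @Override
    public void registerStompEndpoints(StompEndpointRegistry registry) {
        // Intentionally empty: endpoints are registered in a separate configuration class
        // (WebSocketConfig in the example project).
    }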

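    @Override
    public void addArgumentResolvers(List<HandlerMethodArgumentResolver> argumentResolvers) {
        // Allows @AuthenticationPrincipal on message-handling method parameters.
        argumentResolvers.add(new AuthenticationPrincipalArgumentResolver());
    }

    @Override
    public void configureClientInboundChannel(ChannelRegistration registration) {
        ChannelSecurityInterceptor inboundChannelSecurity = this.inboundChannelSecurity();

        // Order matters: first populate the SecurityContext from the token,
        // then (optionally) validate the CSRF token, then apply the authorization rules.
        List<ChannelInterceptor> interceptors = new ArrayList<>();
        interceptors.add(this.securityContextChannelInterceptor());

        if (!this.sameOriginDisabled()) {
            interceptors.add(this.csrfChannelInterceptor());
        }

        if (this.inboundRegistry.containsMapping()) {
            interceptors.add(inboundChannelSecurity);
        }

        registration.setInterceptors(interceptors.toArray(new ChannelInterceptor[0]));

        this.customizeClientInboundChannel(registration);
    }

    // Returns true so CSRF checks are skipped: the API is stateless and authenticates
    // with tokens rather than cookies. Spring Security's own default is false.
    protected boolean sameOriginDisabled() {
        return true;
    }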

    protected void customizeClientInboundChannel(ChannelRegistration registration) {
    }

    @Bean
    public ChannelInterceptorAdapter securityContextChannelInterceptor() {
        return tokenSecurityChannelInterceptor;
    }

    @Bean
    public CsrfChannelInterceptor csrfChannelInterceptor() {
        return new CsrfChannelInterceptor();
    }

    @Bean
    public ChannelSecurityInterceptor inboundChannelSecurity() {
        ChannelSecurityInterceptor channelSecurityInterceptor = new ChannelSecurityInterceptor(this.inboundMessageSecurityMetadataSource());
        // Evaluates the expressions registered in configureInbound() (authenticated(), permitAll(), ...).
        List<AccessDecisionVoter<? extends Object>> voters = new ArrayList<>();
        voters.add(new MessageExpressionVoter<Object>());
        AffirmativeBased manager = new AffirmativeBased(voters);
        channelSecurityInterceptor.setAccessDecisionManager(manager);
        return channelSecurityInterceptor;
    }

    @Bean
    public MessageSecurityMetadataSource inboundMessageSecurityMetadataSource() {
        this.configureInbound(this.inboundRegistry);
        return this.inboundRegistry.createMetadataSource();
    }

    protected void configureInbound(MessageSecurityMetadataSourceRegistry messages) {
    }

    @Override
    public void afterSingletonsInstantiated() {
        if (this.sameOriginDisabled()) {
            return;
        }

        // Prepend a CsrfTokenHandshakeInterceptor to every STOMP endpoint handler,
        // whether the endpoint is exposed through SockJS or as a plain WebSocket.
        String beanName = "stompWebSocketHandlerMapping";
        SimpleUrlHandlerMapping mapping = this.context.getBean(beanName, SimpleUrlHandlerMapping.class);
        Map<String, Object> mappings = mapping.getHandlerMap();

        for (Object object : mappings.values()) {
            if (object instanceof SockJsHttpRequestHandler) {
                SockJsHttpRequestHandler handler = (SockJsHttpRequestHandler) object;
                SockJsService sockJsService = handler.getSockJsService();
                if (!(sockJsService instanceof TransportHandlingSockJsService)) {
                    throw new IllegalStateException("sockJsService must be instance of TransportHandlingSockJsService got " + sockJsService);
                }

                TransportHandlingSockJsService transportService = (TransportHandlingSockJsService) sockJsService;
                List<HandshakeInterceptor> existing = transportService.getHandshakeInterceptors();
                List<HandshakeInterceptor> interceptorsToSet = new ArrayList<>(existing.size() + 1);
                interceptorsToSet.add(new CsrfTokenHandshakeInterceptor());
                interceptorsToSet.addAll(existing);
                transportService.setHandshakeInterceptors(interceptorsToSet);
            } else if (object instanceof WebSocketHttpRequestHandler) {
                WebSocketHttpRequestHandler handler = (WebSocketHttpRequestHandler) object;
                List<HandshakeInterceptor> existing = handler.getHandshakeInterceptors();
                List<HandshakeInterceptor> interceptorsToSet = new ArrayList<>(existing.size() + 1);
                interceptorsToSet.add(new CsrfTokenHandshakeInterceptor());
                interceptorsToSet.addAll(existing);
                handler.setHandshakeInterceptors(interceptorsToSet);
            } else {
                throw new IllegalStateException("Bean " + beanName + " is expected to contain mappings to either a SockJsHttpRequestHandler or a WebSocketHttpRequestHandler but got " + object);
            }
        }
    }

    private class WebSocketMessageSecurityMetadataSourceRegistry extends MessageSecurityMetadataSourceRegistry {
        private WebSocketMessageSecurityMetadataSourceRegistry() {
        }

        public MessageSecurityMetadataSource createMetadataSource() {
            return super.createMetadataSource();
        }

        protected boolean containsMapping() {
            return super.containsMapping();
        }
    }
}

For the same reason, we cannot simply customize Spring Security's SecurityContextChannelInterceptor to authenticate on custom headers. Instead we create our own version of it, TokenSecurityChannelInterceptor.

The four major changes to note are:
1. This class is not marked as final, so it can be extended.
2. The setup method calls our custom authentication method, getAuthentication.
3. We can now access the custom message headers that appear under nativeHeaders.
4. The class is injectable (it is annotated with @Component).
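The setup/cleanup pair follows a save-and-restore pattern. It pushes whatever context the thread already had, installs the context for the current message, and afterwards pops and restores the saved one. The sketch below models that pattern with plain strings instead of SecurityContext objects. ContextStackExample is hypothetical, and it uses ArrayDeque where the interceptor uses Stack. The trace covers the nested case where setup runs twice on one thread before the matching cleanups.

```java
import java.util.ArrayDeque;
import java.util.Deque;

/**
 * Stdlib-only model of TokenSecurityChannelInterceptor's setup()/cleanup() pattern.
 * Strings stand in for SecurityContext objects; "empty" plays the role of EMPTY_CONTEXT.
 */
final class ContextStackExample {

    static final String EMPTY = "empty";

    // Stand-in for SecurityContextHolder's per-thread context.
    static final ThreadLocal<String> CONTEXT = ThreadLocal.withInitial(() -> EMPTY);

    // Stand-in for ORIGINAL_CONTEXT: contexts saved by setup(), newest on top.
    static final ThreadLocal<Deque<String>> ORIGINAL = ThreadLocal.withInitial(() -> new ArrayDeque<String>());

    static void setup(String user) {
        ORIGINAL.get().push(CONTEXT.get()); // save what the thread had before
        CONTEXT.set(user);                  // install the context for this message
    }

    static void cleanup() {
        Deque<String> saved = ORIGINAL.get();
        String original = saved.isEmpty() ? EMPTY : saved.pop();
        if (EMPTY.equals(original)) {
            // Nothing worth restoring: clear both thread-locals so nothing leaks to pooled threads.
            CONTEXT.remove();
            ORIGINAL.remove();
        } else {
            CONTEXT.set(original);
        }
    }

    /** Two setups before their cleanups on the same thread; records the context after each step. */
    static String trace() {
        StringBuilder seen = new StringBuilder();
        setup("alice");
        setup("bob");
        seen.append(CONTEXT.get());
        cleanup();
        seen.append(',').append(CONTEXT.get());
        cleanup();
        seen.append(',').append(CONTEXT.get());
        return seen.toString();
    }

    public static void main(String[] args) {
        System.out.println(trace()); // bob,alice,empty
    }
}
```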

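import java.util.List;
import java.util.Map;
import java.util.Stack;

import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.MessageHandler;
import org.springframework.messaging.MessageHeaders;
import org.springframework.messaging.support.ChannelInterceptorAdapter;
import org.springframework.messaging.support.ExecutorChannelInterceptor;
import org.springframework.security.authentication.AnonymousAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.authority.AuthorityUtils;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.stereotype.Component;

// Constants and TokenAuthenticationService are classes from the example project.
@Component(value = Constants.SECURITY_TOKEN_SECURITY_CHANNEL_INTERCEPTOR)
public class TokenSecurityChannelInterceptor extends ChannelInterceptorAdapter implements ExecutorChannelInterceptor {

    // Per-thread stack of the contexts that were active before setup(); cleanup() restores them.
    private static final ThreadLocal<Stack<SecurityContext>> ORIGINAL_CONTEXT = new ThreadLocal<>();

    private final SecurityContext EMPTY_CONTEXT;
    private final Authentication anonymous;

    @Autowired
    private TokenAuthenticationService tokenAuthenticationService;

    public TokenSecurityChannelInterceptor() {
        this.EMPTY_CONTEXT = SecurityContextHolder.createEmptyContext();
        // Fallback authentication used when no valid token is present.
        this.anonymous = new AnonymousAuthenticationToken("key", "anonymous", AuthorityUtils.createAuthorityList("ROLE_ANONYMOUS"));
    }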

    @Override
    public Message<?> preSend(Message<?> message, MessageChannel channel) {
        this.setup(message);
        return message;
    }

    @Override
    public void afterSendCompletion(Message<?> message, MessageChannel channel, boolean sent, Exception ex) {
        this.cleanup();
    }

    @Override
    public Message<?> beforeHandle(Message<?> message, MessageChannel channel, MessageHandler handler) {
        this.setup(message);
        return message;
    }

    @Override
    public void afterMessageHandled(Message<?> message, MessageChannel channel, MessageHandler handler, Exception ex) {
        this.cleanup();
    }

    private void setup(Message<?> message) {
        SecurityContext currentContext = SecurityContextHolder.getContext();
        Stack<SecurityContext> contextStack = ORIGINAL_CONTEXT.get();
        if (contextStack == null) {
            contextStack = new Stack<>();
            ORIGINAL_CONTEXT.set(contextStack);
        }

        // Remember the current context so cleanup() can restore it.
        contextStack.push(currentContext);

        SecurityContext context = SecurityContextHolder.createEmptyContext();
        context.setAuthentication(getAuthentication(message.getHeaders()));
        SecurityContextHolder.setContext(context);
    }

    @SuppressWarnings("unchecked")
    private Authentication getAuthentication(MessageHeaders messageHeaders) {
        Authentication authentication = this.anonymous;

        // Custom STOMP headers are not top-level message headers: they are nested
        // under "nativeHeaders" as a map of header name -> list of values.
        Map<String, List<String>> nativeHeaders = (Map<String, List<String>>) messageHeaders.get("nativeHeaders");

        if (nativeHeaders != null) {
            List<String> token = nativeHeaders.get(Constants.HEADER_X_AUTH_TOKEN);

            if (token != null && !token.isEmpty()) {
                Authentication tokenAuthentication = tokenAuthenticationService.getAuthentication(token.get(0));

                if (tokenAuthentication != null) {
                    authentication = tokenAuthentication;
                }
            }
        }

        return authentication;
    }

    private void cleanup() {
        Stack<SecurityContext> contextStack = ORIGINAL_CONTEXT.get();
        if (contextStack != null && !contextStack.isEmpty()) {
            SecurityContext originalContext = contextStack.pop();

            try {
                if (this.EMPTY_CONTEXT.equals(originalContext)) {
                    SecurityContextHolder.clearContext();
                    ORIGINAL_CONTEXT.remove();
                } else {
                    SecurityContextHolder.setContext(originalContext);
                }
            } catch (Throwable t) {
                SecurityContextHolder.clearContext();
            }

        } else {
            SecurityContextHolder.clearContext();
            ORIGINAL_CONTEXT.remove();
        }
    }
}

This is the WebSocket security configuration that defines which messages require authentication and which do not.

With these rules, any CONNECT, SUBSCRIBE or MESSAGE must carry the custom token in its message headers, otherwise access is denied. We cannot pass headers with UNSUBSCRIBE and DISCONNECT, so those are permitted for everyone.
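The rule set can be read as a small decision table. The stdlib-only sketch below mirrors it with a local enum named after Spring's SimpMessageType constants. InboundRulesExample and isAllowed are hypothetical names. In the real application the rules are evaluated by Spring Security's ChannelSecurityInterceptor, not by this code.

```java
import java.util.EnumSet;
import java.util.Set;

/**
 * Stdlib-only decision table mirroring configureInbound(); NOT Spring Security's evaluator.
 */
final class InboundRulesExample {

    // Local enum named after the SimpMessageType constants.
    enum MessageType { CONNECT, CONNECT_ACK, MESSAGE, SUBSCRIBE, UNSUBSCRIBE, HEARTBEAT, DISCONNECT, DISCONNECT_ACK, OTHER }

    static final Set<MessageType> AUTHENTICATED =
            EnumSet.of(MessageType.CONNECT, MessageType.MESSAGE, MessageType.SUBSCRIBE);
    static final Set<MessageType> PERMIT_ALL =
            EnumSet.of(MessageType.UNSUBSCRIBE, MessageType.DISCONNECT);

    /** Rules are checked in declaration order; the first match decides. */
    static boolean isAllowed(MessageType type, boolean authenticated) {
        if (AUTHENTICATED.contains(type)) {
            return authenticated; // .authenticated()
        }
        if (PERMIT_ALL.contains(type)) {
            return true;          // .permitAll()
        }
        return false;             // .anyMessage().denyAll()
    }

    public static void main(String[] args) {
        System.out.println(isAllowed(MessageType.SUBSCRIBE, false));  // false: no valid token
        System.out.println(isAllowed(MessageType.SUBSCRIBE, true));   // true
        System.out.println(isAllowed(MessageType.DISCONNECT, false)); // true
        System.out.println(isAllowed(MessageType.HEARTBEAT, true));   // false: not listed, so denied
    }
}
```

Any type that no rule lists, such as HEARTBEAT in this model, falls through to denyAll.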

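import org.springframework.context.annotation.ComponentScan;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.PropertySource;
import org.springframework.messaging.simp.SimpMessageType;
import org.springframework.security.config.annotation.web.messaging.MessageSecurityMetadataSourceRegistry;

@Configuration
@ComponentScan(basePackages = {"au.com.example.security.service", "au.com.example.security.spring.security"})
@PropertySource("classpath:properties/security.properties")
public class WebSocketSecurityConfig extends AbstractSecurityWebSocketMessageBrokerConfig {

    @Override
    protected void configureInbound(MessageSecurityMetadataSourceRegistry messages) {
        messages
                // These must carry a valid token in the custom header.
                .simpTypeMatchers(
                        SimpMessageType.CONNECT,
                        SimpMessageType.MESSAGE,
                        SimpMessageType.SUBSCRIBE).authenticated()
                // No token is sent with these, so they are always allowed.
                .simpTypeMatchers(
                        SimpMessageType.UNSUBSCRIBE,
                        SimpMessageType.DISCONNECT).permitAll()
                // Everything else is rejected.
                .anyMessage().denyAll();
    }
}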

Client

The client uses AngularJS with SockJS and STOMP.

'use strict';

angular.module('app.services').service('socketService', ['$rootScope', '$stomp', 'storageService', 'storageConstant', 'propertiesConstant',
    function ($rootScope, $stomp, storageService, storageConstant, propertiesConstant) {
        var connection;
        var subscriptions = {};

        this.subscribe = function subscribe() {
            // Send the auth token as a custom STOMP header; on the server it arrives under nativeHeaders.
            var authToken = storageService.getSessionItem(storageConstant.AUTH_TOKEN);
            var headers = (authToken) ? {"X-AUTH-TOKEN": authToken} : {};

            connect(headers);

            connection.then(function (frame) {
                if (!(subscriptions.articles)) {
                    // SUBSCRIBE is secured too, so the same headers are sent again here.
                    subscriptions.articles = $stomp.subscribe('/api/user/articles', function (payload, headers, res) {
                        $rootScope.$apply(function () {
                            $rootScope.articleCount = payload.length;
                        });
                    }, headers);
                }
            });
        };

        this.unsubscribe = function unsubscribe() {
            if (subscriptions.articles) {
                subscriptions.articles.unsubscribe();
            }

            if (connection) {
                $stomp.disconnect(function () {
                    delete $rootScope.articleCount;
                });
                // Forget the closed connection so the next subscribe() reconnects.
                connection = undefined;
            }

            subscriptions = {};
        };

        function connect(headers) {
            if (!(connection)) {
                connection = $stomp.connect(propertiesConstant.WEBSOCKET_API_URL + '/stomp', headers);
            }
        }
    }]);
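Each key in the headers object passed to $stomp.connect and $stomp.subscribe travels as a STOMP frame header, and Spring exposes these under nativeHeaders on the server. The stdlib-only Java sketch below builds a CONNECT frame by hand to show the wire format: command line, header lines, blank line, body and NUL terminator. StompConnectFrameExample is hypothetical. The accept-version and heart-beat values are assumptions, not what any particular client library is known to send. Header-value escaping is omitted.

```java
import java.util.LinkedHashMap;
import java.util.Map;

/** Stdlib-only sketch of a STOMP frame on the wire; escaping of header values is omitted. */
final class StompConnectFrameExample {

    /** Command line, "name:value" header lines, blank line, body, NUL terminator. */
    static String frame(String command, Map<String, String> headers, String body) {
        StringBuilder sb = new StringBuilder(command).append('\n');
        for (Map.Entry<String, String> header : headers.entrySet()) {
            sb.append(header.getKey()).append(':').append(header.getValue()).append('\n');
        }
        return sb.append('\n').append(body).append('\0').toString();
    }

    static String connectFrame(String authToken) {
        Map<String, String> headers = new LinkedHashMap<>(); // keeps insertion order
        headers.put("accept-version", "1.1,1.2");           // assumed value
        headers.put("heart-beat", "10000,10000");           // assumed value
        if (authToken != null) {
            headers.put("X-AUTH-TOKEN", authToken);          // same header name socketService.js sends
        }
        return frame("CONNECT", headers, "");
    }

    public static void main(String[] args) {
        // Show the NUL terminator as ^@ so it is visible when printed.
        System.out.println(connectFrame("abc").replace("\0", "^@"));
    }
}
```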

Example

You can check out the project from the following location:

Project URL: https://github.com/Rob-Leggett/angular_websockets_security
What you can learn:

  • Maven Modules
  • Gulp JS
  • ES6
  • SockJS / STOMP
  • Basic directives
  • Multiple views
  • Security
  • Filtering and sorting lists
  • Data Binding
  • Data retrieval via AJAX
  • Integration of client side MVC with server side MVC
  • Stateless API
  • RESTful entry points
  • Websocket entry points
  • Basic Authentication
  • Token Authentication
  • Data operations using JPA
  • In Memory Databases
  • Unit Tests
  • Integration Tests
  • Jasmine Tests

13 thoughts on “Web sockets with Spring and Spring Security”

      • Hi, Robert,

        One more question. I want to send some messages to a user-specific destination (messages to a particular user). My user destination prefix is set to “/private”, and according to the documentation I should connect to the “/private/queue/messages” channel. I am also using this WebSocket token security configuration from your repo, but unfortunately, after many hours of debugging and experimenting, this didn’t work for me. Maybe I missed something in the WebSocket configuration? Can you advise if possible?

        Thank You

      • Hi Yuriy, sorry for the late response. In my example the client sets a custom header (see socketService.js), and the WebSocket config at the API level checks for this header and authenticates against it. It expects all CONNECT, MESSAGE and SUBSCRIBE socket actions to be authenticated, which in my example is always done via tokens.

        Unfortunately I cannot determine what you may be missing without checking out the project, but the key classes to keep in mind are: WebSocketConfig, WebSocketSecurityConfig, AbstractSecurityWebSocketMessageBrokerConfig and TokenSecurityChannelInterceptor.

        Hope that helps.

  1. I am also facing the same issue as @Yuriy. Suppose that I have two users ‘X’ and ‘Y’ connected to the WebSocket endpoint and I want to send messages to a specific user, let's say ‘X’.

    • Hi

      I spent a few days on this problem and figured out that the root cause is the following:

      – our REST API is a stateless application (only tokens and no headers)
      – the SockJS client JavaScript library and web browsers don’t support passing HTTP headers with the WebSocket connect message (the Spring Java client has such an ability)
      – when resolving a private channel destination, Spring WebSocket uses data from DefaultSimpUserRegistry, which uses data resolved by other classes from the HTTP headers passed at the moment of connect.

  2. Robert, I am trying to figure out the same thing but with a binary WebSocket, i.e. I use neither a message broker nor STOMP. My question to you is: do you know which abstract class I would have to extend if I want to use a pure binary WebSocket?

  3. In this case I was referring to the class AbstractWebSocketMessageBrokerConfigurer, i.e. which class would I have to extend if my WebSocket is a pure binary WebSocket? Many thanks
