// See the file "COPYING" in the main distribution directory for copyright. // Implementation of a WebSocket server and clients using the IXWebSocket client library. #include "zeek/cluster/websocket/WebSocket.h" #include #include #include "zeek/Reporter.h" #include "ixwebsocket/IXConnectionState.h" #include "ixwebsocket/IXSocketTLSOptions.h" #include "ixwebsocket/IXWebSocket.h" #include "ixwebsocket/IXWebSocketSendData.h" #include "ixwebsocket/IXWebSocketServer.h" namespace zeek::cluster::websocket::detail::ixwebsocket { /** * Implementation of WebSocketClient for the IXWebsocket library. */ class IxWebSocketClient : public WebSocketClient { public: IxWebSocketClient(std::shared_ptr cs, std::shared_ptr ws) : cs(std::move(cs)), ws(std::move(ws)) { if ( ! this->cs || ! this->ws ) throw std::invalid_argument("expected ws and cs to be set"); } bool IsTerminated() const override { if ( cs->isTerminated() ) return true; auto rs = ws->getReadyState(); return rs == ix::ReadyState::Closing || rs == ix::ReadyState::Closed; } void Close(uint16_t code, const std::string& reason) override { ws->close(code, reason); } SendInfo SendText(std::string_view sv) override { if ( cs->isTerminated() ) return {true}; // small lie auto send_info = ws->sendUtf8Text(ix::IXWebSocketSendData{sv.data(), sv.size()}); return SendInfo{send_info.success}; } const std::string& getId() override { return cs->getId(); } const std::string& getRemoteIp() override { return cs->getRemoteIp(); } int getRemotePort() override { return cs->getRemotePort(); } private: std::shared_ptr cs; std::shared_ptr ws; }; /** * Implementation of WebSocketServer using the IXWebsocket library. */ class IXWebSocketServer : public WebSocketServer { public: IXWebSocketServer(std::unique_ptr dispatcher, std::unique_ptr server) : WebSocketServer(std::move(dispatcher)), server(std::move(server)) {} private: void DoTerminate() override { // Stop the server. server->stop(); } std::unique_ptr server; }; std::unique_ptr StartServer(std::unique_ptr dispatcher, const ServerOptions& options) { auto server = std::make_unique(options.port, options.host, ix::SocketServer::kDefaultTcpBacklog, options.max_connections, ix::WebSocketServer::kDefaultHandShakeTimeoutSecs, ix::SocketServer::kDefaultAddressFamily, options.ping_interval_seconds); if ( ! options.per_message_deflate ) server->disablePerMessageDeflate(); const auto& tls_options = options.tls_options; if ( tls_options.TlsEnabled() ) { ix::SocketTLSOptions ix_tls_options{}; ix_tls_options.tls = true; ix_tls_options.certFile = tls_options.cert_file.value(); ix_tls_options.keyFile = tls_options.key_file.value(); if ( tls_options.enable_peer_verification ) { if ( ! tls_options.ca_file.empty() ) ix_tls_options.caFile = tls_options.ca_file; } else { // This is the IXWebSocket library's way of // disabling peer verification. ix_tls_options.caFile = "NONE"; } if ( ! tls_options.ciphers.empty() ) ix_tls_options.ciphers = tls_options.ciphers; server->setTLSOptions(ix_tls_options); } // Using the legacy IXWebsocketAPI API to acquire a shared_ptr to the ix::WebSocket instance. ix::WebSocketServer::OnConnectionCallback connection_callback = [dispatcher = dispatcher.get()](std::weak_ptr websocket, std::shared_ptr cs) -> void { // Hold a shared_ptr to the WebSocket object until we see the close. std::shared_ptr ws = websocket.lock(); // Client already gone or terminated? Weird... if ( ! ws || cs->isTerminated() ) return; auto id = cs->getId(); int remotePort = cs->getRemotePort(); std::string remoteIp = cs->getRemoteIp(); auto ixws = std::make_shared(std::move(cs), ws); // These callbacks run in per client threads. The actual processing happens // on the main thread via a single WebSocketDemux instance. ix::OnMessageCallback message_callback = [dispatcher, id, remotePort, remoteIp, ixws](const ix::WebSocketMessagePtr& msg) mutable { if ( msg->type == ix::WebSocketMessageType::Open ) { dispatcher->QueueForProcessing( WebSocketOpen{id, msg->openInfo.uri, msg->openInfo.protocol, std::move(ixws)}); } else if ( msg->type == ix::WebSocketMessageType::Message ) { dispatcher->QueueForProcessing(WebSocketMessage{id, msg->str}); } else if ( msg->type == ix::WebSocketMessageType::Close ) { dispatcher->QueueForProcessing(WebSocketClose{id}); } else if ( msg->type == ix::WebSocketMessageType::Error ) { dispatcher->QueueForProcessing(WebSocketClose{id}); } }; ws->setOnMessageCallback(message_callback); }; server->setOnConnectionCallback(connection_callback); const auto [success, reason] = server->listen(); if ( ! success ) { zeek::reporter->Error("WebSocket: Unable to listen on %s:%d: %s", options.host.c_str(), options.port, reason.c_str()); return nullptr; } server->start(); return std::make_unique(std::move(dispatcher), std::move(server)); } } // namespace zeek::cluster::websocket::detail::ixwebsocket using namespace zeek::cluster::websocket::detail; std::unique_ptr zeek::cluster::websocket::detail::StartServer( std::unique_ptr dispatcher, const ServerOptions& options) { // Just delegate to the above IXWebSocket specific implementation. return ixwebsocket::StartServer(std::move(dispatcher), options); }