From a6f401041ac82c1ba94a8fea21339acb33904ad0 Mon Sep 17 00:00:00 2001 From: Evgeny Poberezkin <2769109+epoberezkin@users.noreply.github.com> Date: Sun, 17 Jul 2022 10:10:38 +0100 Subject: [PATCH] SMP protocol v4: batching multiple server commands/responses in a transport block (#470) * batch server commands in one transport block * subscribe to multiple queues using batched commands * agent method to subscribe to multiple queues using batched commands * refactor * test for batched subscriptions * delete part of connections in batched test * add resubscribeConnections * remove comment * update SMP protocol doc --- protocol/simplex-messaging.md | 21 +- src/Simplex/Messaging/Agent.hs | 64 +++++- src/Simplex/Messaging/Agent/Client.hs | 78 +++++-- .../Agent/Store/SQLite/Migrations.hs | 4 +- src/Simplex/Messaging/Client.hs | 122 +++++++--- src/Simplex/Messaging/Encoding.hs | 1 + src/Simplex/Messaging/Notifications/Server.hs | 22 +- .../Messaging/Notifications/Transport.hs | 2 +- src/Simplex/Messaging/Protocol.hs | 46 +++- src/Simplex/Messaging/Server.hs | 208 +++++++++--------- src/Simplex/Messaging/Server/Env/STM.hs | 5 +- src/Simplex/Messaging/Transport.hs | 16 +- src/Simplex/Messaging/Util.hs | 1 + tests/AgentTests/FunctionalAPITests.hs | 76 ++++++- tests/AgentTests/NotificationTests.hs | 1 - tests/NtfClient.hs | 5 +- tests/NtfServerTests.hs | 12 +- tests/SMPAgentClient.hs | 6 + tests/SMPClient.hs | 8 +- tests/ServerTests.hs | 95 ++++---- 20 files changed, 545 insertions(+), 248 deletions(-) diff --git a/protocol/simplex-messaging.md b/protocol/simplex-messaging.md index 3c2cad2..1c53fe5 100644 --- a/protocol/simplex-messaging.md +++ b/protocol/simplex-messaging.md @@ -51,6 +51,8 @@ It's designed with the focus on communication security and integrity, under the It is designed as a low level protocol for other application protocols to solve the problem of secure and private message transmission, making [MITM attack][1] very difficult at any part of the message transmission system. +This document describes SMP protocol versions 3 and 4, the previous versions are discontinued. + ## Introduction The objective of Simplex Messaging Protocol (SMP) is to facilitate the secure and private unidirectional transfer of messages from senders to recipients via persistent simplex queues managed by the message broker (server). @@ -362,15 +364,16 @@ The clients can optionally instruct a dedicated push notification server to subs [`SEND` command](#send-message) includes the notification flag to instruct SMP server whether to send the notification - this flag is forwarded to the recepient inside encrypted envelope, together with the timestamp and the message body, so even if TLS is compromised this flag cannot be used for traffic correlation. -## SMP Transmission structure +## SMP Transmission andtransport block structure Each transport block (SMP transmission) has a fixed size of 16384 bytes for traffic uniformity. +From SMP version 4 each block can contain multiple transmissions, version 3 blocks have 1 transmission. Some parts of SMP transmission are padded to a fixed size; this padding is uniformly added as a word16 encoded in network byte order - see `paddedString` syntax. In places where some part of the transmission should be padded, the syntax for `paddedNotation` is used: -``` +```abnf paddedString = originalLength string pad originalLength = 2*2 OCTET pad = N*N"#" ; where N = paddedLength - originalLength - 2 @@ -380,9 +383,9 @@ paddedNotation = ; paddedLength - required length after padding, including 2 bytes for originalLength ``` -Each transmission between the client and the server must have this format/syntax: +Each transmission/block for SMP v3 between the client and the server must have this format/syntax: -``` +```abnf paddedTransmission = transmission = [signature] SP signed signed = sessionIdentifier SP [corrId] SP [queueId] SP smpCommand @@ -399,6 +402,16 @@ encoded = `base64` encoding should be used with padding, as defined in section 4 of [RFC 4648][9] +Transport block for SMP v4 has this syntax: + +```abnf +paddedTransportBlock = +transportBlock = transmissionCount transmissions +transmissionCount = 1*1 OCTET ; equal or greater than 1 +transmissions = transmissionLength transmission [transmissions] +transmissionLength = 2*2 OCTET ; word16 encoded in network byte order +``` + ## SMP commands Commands syntax below is provided using [ABNF][8] with [case-sensitive strings extension][8a]. diff --git a/src/Simplex/Messaging/Agent.hs b/src/Simplex/Messaging/Agent.hs index cd72694..30e745b 100644 --- a/src/Simplex/Messaging/Agent.hs +++ b/src/Simplex/Messaging/Agent.hs @@ -45,9 +45,11 @@ module Simplex.Messaging.Agent acceptContact, rejectContact, subscribeConnection, + subscribeConnections, getConnectionMessage, getNotificationMessage, resubscribeConnection, + resubscribeConnections, sendMessage, ackMessage, suspendConnection, @@ -79,6 +81,7 @@ import Data.Composition ((.:), (.:.)) import Data.Functor (($>)) import Data.List.NonEmpty (NonEmpty (..)) import qualified Data.List.NonEmpty as L +import Data.Map.Strict (Map) import qualified Data.Map.Strict as M import Data.Maybe (isJust) import qualified Data.Text as T @@ -105,10 +108,10 @@ import Simplex.Messaging.Parsers (parse) import Simplex.Messaging.Protocol (BrokerMsg, ErrorType (AUTH), MsgBody, MsgFlags, NtfServer, SMPMsgMeta) import qualified Simplex.Messaging.Protocol as SMP import qualified Simplex.Messaging.TMap as TM -import Simplex.Messaging.Util (bshow, eitherToMaybe, liftE, liftError, tryError, unlessM, whenM, ($>>=)) +import Simplex.Messaging.Util import Simplex.Messaging.Version import System.Random (randomR) -import UnliftIO.Async (async, race_) +import UnliftIO.Async (async, mapConcurrently, race_) import UnliftIO.Concurrent (forkFinally, forkIO, threadDelay) import qualified UnliftIO.Exception as E import UnliftIO.STM @@ -158,6 +161,10 @@ rejectContact c = withAgentEnv c .: rejectContact' c subscribeConnection :: AgentErrorMonad m => AgentClient -> ConnId -> m () subscribeConnection c = withAgentEnv c . subscribeConnection' c +-- | Subscribe to receive connection messages from multiple connections, batching commands when possible +subscribeConnections :: AgentErrorMonad m => AgentClient -> [ConnId] -> m (Map ConnId (Either AgentErrorType ())) +subscribeConnections c = withAgentEnv c . subscribeConnections' c + -- | Get connection message (GET command) getConnectionMessage :: AgentErrorMonad m => AgentClient -> ConnId -> m (Maybe SMPMsgMeta) getConnectionMessage c = withAgentEnv c . getConnectionMessage' c @@ -169,6 +176,9 @@ getNotificationMessage c = withAgentEnv c .: getNotificationMessage' c resubscribeConnection :: AgentErrorMonad m => AgentClient -> ConnId -> m () resubscribeConnection c = withAgentEnv c . resubscribeConnection' c +resubscribeConnections :: AgentErrorMonad m => AgentClient -> [ConnId] -> m (Map ConnId (Either AgentErrorType ())) +resubscribeConnections c = withAgentEnv c . resubscribeConnections' c + -- | Send message to the connection (SEND command) sendMessage :: AgentErrorMonad m => AgentClient -> ConnId -> MsgFlags -> MsgBody -> m AgentMsgId sendMessage c = withAgentEnv c .:. sendMessage' c @@ -393,12 +403,62 @@ subscribeConnection' c connId = ns <- asks ntfSupervisor atomically $ sendNtfSubCommand ns (connId, NSCCreate) +subscribeConnections' :: forall m. AgentMonad m => AgentClient -> [ConnId] -> m (Map ConnId (Either AgentErrorType ())) +subscribeConnections' _ [] = pure M.empty +subscribeConnections' c connIds = do + conns :: Map ConnId (Either StoreError SomeConn) <- M.fromList . zip connIds <$> withStore' c (forM connIds . getConn) + let (errs, cs) = M.mapEither id conns + errs' = M.map (Left . storeError) errs + (sndQs, rcvQs) = M.mapEither rcvOrSndQueue cs + sndRs = M.map (sndSubResult . fst) sndQs + srvRcvQs :: Map SMPServer (Map ConnId (RcvQueue, ConnData)) = M.foldlWithKey' addRcvQueue M.empty rcvQs + mapM_ (mapM_ (uncurry $ resumeMsgDelivery c) . sndQueue) cs + rcvRs <- mapConcurrently subscribe (M.assocs srvRcvQs) + let rs = M.unions $ errs' : sndRs : rcvRs + notifyResultError rs + pure rs + where + rcvOrSndQueue :: SomeConn -> Either (SndQueue, ConnData) (RcvQueue, ConnData) + rcvOrSndQueue = \case + SomeConn _ (DuplexConnection cData rq _) -> Right (rq, cData) + SomeConn _ (SndConnection cData sq) -> Left (sq, cData) + SomeConn _ (RcvConnection cData rq) -> Right (rq, cData) + SomeConn _ (ContactConnection cData rq) -> Right (rq, cData) + sndSubResult :: SndQueue -> Either AgentErrorType () + sndSubResult sq = case status (sq :: SndQueue) of + Confirmed -> Right () + Active -> Left $ CONN SIMPLEX + _ -> Left $ INTERNAL "unexpected queue status" + addRcvQueue :: Map SMPServer (Map ConnId (RcvQueue, ConnData)) -> ConnId -> (RcvQueue, ConnData) -> Map SMPServer (Map ConnId (RcvQueue, ConnData)) + addRcvQueue m connId rq@(RcvQueue {server}, _) = M.alter (Just . maybe (M.singleton connId rq) (M.insert connId rq)) server m + subscribe :: (SMPServer, Map ConnId (RcvQueue, ConnData)) -> m (Map ConnId (Either AgentErrorType ())) + subscribe (srv, qs) = subscribeQueues c srv (M.map fst qs) + sndQueue :: SomeConn -> Maybe (ConnData, SndQueue) + sndQueue = \case + SomeConn _ (DuplexConnection cData _ sq) -> Just (cData, sq) + SomeConn _ (SndConnection cData sq) -> Just (cData, sq) + _ -> Nothing + notifyResultError :: Map ConnId (Either AgentErrorType ()) -> m () + notifyResultError rs = do + let actual = M.size rs + expected = length connIds + when (actual /= expected) . atomically $ + writeTBQueue (subQ c) ("", "", ERR . INTERNAL $ "subscribeConnections result size: " <> show actual <> ", expected " <> show expected) + resubscribeConnection' :: AgentMonad m => AgentClient -> ConnId -> m () resubscribeConnection' c connId = unlessM (atomically $ hasActiveSubscription c connId) (subscribeConnection' c connId) +resubscribeConnections' :: forall m. AgentMonad m => AgentClient -> [ConnId] -> m (Map ConnId (Either AgentErrorType ())) +resubscribeConnections' _ [] = pure M.empty +resubscribeConnections' c connIds = do + let r = M.fromList . zip connIds . repeat $ Right () + connIds' <- filterM (fmap not . atomically . hasActiveSubscription c) connIds + -- union is left-biased, so results returned by subscribeConnections' take precedence + (`M.union` r) <$> subscribeConnections' c connIds' + getConnectionMessage' :: AgentMonad m => AgentClient -> ConnId -> m (Maybe SMPMsgMeta) getConnectionMessage' c connId = do whenM (atomically $ hasActiveSubscription c connId) . throwError $ CMD PROHIBITED diff --git a/src/Simplex/Messaging/Agent/Client.hs b/src/Simplex/Messaging/Agent/Client.hs index fc3a6d4..e899e3c 100644 --- a/src/Simplex/Messaging/Agent/Client.hs +++ b/src/Simplex/Messaging/Agent/Client.hs @@ -21,6 +21,7 @@ module Simplex.Messaging.Agent.Client closeAgentClient, newRcvQueue, subscribeQueue, + subscribeQueues, getQueueMessage, decryptSMPMessage, addSubscription, @@ -64,6 +65,7 @@ module Simplex.Messaging.Agent.Client whenSuspending, withStore, withStore', + storeError, ) where @@ -74,16 +76,19 @@ import Control.Logger.Simple import Control.Monad.Except import Control.Monad.IO.Unlift import Control.Monad.Reader -import Data.Bifunctor (first) +import Data.Bifunctor (bimap, first, second) import Data.ByteString.Base64 import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B +import Data.Either (partitionEithers) import Data.List.NonEmpty (NonEmpty) +import qualified Data.List.NonEmpty as L import Data.Map.Strict (Map) import qualified Data.Map.Strict as M import Data.Maybe (catMaybes) import Data.Set (Set) import Data.Text.Encoding +import Data.Tuple (swap) import Data.Word (Word16) import qualified Database.SQLite.Simple as DB import Simplex.Messaging.Agent.Env.SQLite @@ -103,7 +108,7 @@ import Simplex.Messaging.Protocol (BrokerMsg, ErrorType, MsgFlags (..), MsgId, N import qualified Simplex.Messaging.Protocol as SMP import Simplex.Messaging.TMap (TMap) import qualified Simplex.Messaging.TMap as TM -import Simplex.Messaging.Util (bshow, catchAll_, ifM, liftEitherError, liftError, tryError, unlessM, whenM) +import Simplex.Messaging.Util import Simplex.Messaging.Version import System.Timeout (timeout) import UnliftIO (async, pooledForConcurrentlyN) @@ -476,14 +481,42 @@ subscribeQueue :: AgentMonad m => AgentClient -> RcvQueue -> ConnId -> m () subscribeQueue c rq@RcvQueue {server, rcvPrivateKey, rcvId} connId = do whenM (atomically . TM.member (server, rcvId) $ getMsgLocks c) . throwError $ CMD PROHIBITED atomically $ addPendingSubscription c rq connId - withLogClient c server rcvId "SUB" $ \smp -> do - liftIO (runExceptT $ subscribeSMPQueue smp rcvPrivateKey rcvId) >>= \case - Left e -> do - atomically . when (e /= PCENetworkError && e /= PCEResponseTimeout) $ - removePendingSubscription c server connId - throwError e - Right _ -> do - addSubscription c rq connId + withLogClient c server rcvId "SUB" $ \smp -> + liftIO (runExceptT (subscribeSMPQueue smp rcvPrivateKey rcvId) >>= processSubResult c rq connId) + >>= either throwError pure + +processSubResult :: AgentClient -> RcvQueue -> ConnId -> Either ProtocolClientError () -> IO (Either ProtocolClientError ()) +processSubResult c rq@RcvQueue {server} connId r = do + case r of + Left e -> + atomically . when (e /= PCENetworkError && e /= PCEResponseTimeout) $ + removePendingSubscription c server connId + _ -> addSubscription c rq connId + pure r + +-- | subscribe multiple queues - all passed queues should be on the same server +subscribeQueues :: AgentMonad m => AgentClient -> SMPServer -> Map ConnId RcvQueue -> m (Map ConnId (Either AgentErrorType ())) +subscribeQueues c srv qs = do + (errs, qs_) <- partitionEithers <$> mapM checkQueue (M.assocs qs) + forM_ qs_ $ atomically . uncurry (addPendingSubscription c) . swap + case L.nonEmpty qs_ of + Just qs' -> do + smp_ <- tryError (getSMPServerClient c srv) + M.fromList . (errs <>) <$> case smp_ of + Left e -> pure $ map (second . const $ Left e) qs_ + Right smp -> do + logServer "-->" c srv (bshow (length qs_) <> " queues") "SUB" + let qs2 = L.map (queueCreds . snd) qs' + rs' :: [((ConnId, RcvQueue), Either ProtocolClientError ())] <- + liftIO $ zip qs_ . L.toList <$> subscribeSMPQueues smp qs2 + forM_ rs' $ \((connId, rq), r) -> liftIO $ processSubResult c rq connId r + pure $ map (bimap fst (first $ protocolClientError SMP)) rs' + _ -> pure $ M.fromList errs + where + checkQueue rq@(connId, RcvQueue {rcvId, server}) = do + prohibited <- atomically . TM.member (server, rcvId) $ getMsgLocks c + pure $ if prohibited || srv /= server then Left (connId, Left $ CMD PROHIBITED) else Right rq + queueCreds RcvQueue {rcvPrivateKey, rcvId} = (rcvPrivateKey, rcvId) addSubscription :: MonadIO m => AgentClient -> RcvQueue -> ConnId -> m () addSubscription c rq@RcvQueue {server} connId = atomically $ do @@ -762,15 +795,16 @@ withStore c action = do where handleInternal :: E.SomeException -> IO (Either StoreError a) handleInternal = pure . Left . SEInternal . bshow - storeError :: StoreError -> AgentErrorType - storeError = \case - SEConnNotFound -> CONN NOT_FOUND - SEConnDuplicate -> CONN DUPLICATE - SEBadConnType CRcv -> CONN SIMPLEX - SEBadConnType CSnd -> CONN SIMPLEX - SEInvitationNotFound -> CMD PROHIBITED - -- this error is never reported as store error, - -- it is used to wrap agent operations when "transaction-like" store access is needed - -- NOTE: network IO should NOT be used inside AgentStoreMonad - SEAgentError e -> e - e -> INTERNAL $ show e + +storeError :: StoreError -> AgentErrorType +storeError = \case + SEConnNotFound -> CONN NOT_FOUND + SEConnDuplicate -> CONN DUPLICATE + SEBadConnType CRcv -> CONN SIMPLEX + SEBadConnType CSnd -> CONN SIMPLEX + SEInvitationNotFound -> CMD PROHIBITED + -- this error is never reported as store error, + -- it is used to wrap agent operations when "transaction-like" store access is needed + -- NOTE: network IO should NOT be used inside AgentStoreMonad + SEAgentError e -> e + e -> INTERNAL $ show e diff --git a/src/Simplex/Messaging/Agent/Store/SQLite/Migrations.hs b/src/Simplex/Messaging/Agent/Store/SQLite/Migrations.hs index 6b30617..da6b473 100644 --- a/src/Simplex/Messaging/Agent/Store/SQLite/Migrations.hs +++ b/src/Simplex/Messaging/Agent/Store/SQLite/Migrations.hs @@ -16,8 +16,8 @@ module Simplex.Messaging.Agent.Store.SQLite.Migrations where import Control.Monad (forM_) -import Data.Function (on) import Data.List (intercalate, sortBy) +import Data.Ord (comparing) import Data.Text (Text) import Data.Time.Clock (getCurrentTime) import Database.SQLite.Simple (Connection, Only (..), Query (..)) @@ -44,7 +44,7 @@ schemaMigrations = -- | The list of migrations in ascending order by date app :: [Migration] -app = sortBy (compare `on` name) $ map migration schemaMigrations +app = sortBy (comparing name) $ map migration schemaMigrations where migration (name, query) = Migration {name = name, up = fromQuery query} diff --git a/src/Simplex/Messaging/Client.hs b/src/Simplex/Messaging/Client.hs index a04be51..ee7f2b6 100644 --- a/src/Simplex/Messaging/Client.hs +++ b/src/Simplex/Messaging/Client.hs @@ -6,6 +6,7 @@ {-# LANGUAGE LambdaCase #-} {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE NumericUnderscores #-} +{-# LANGUAGE OverloadedLists #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeApplications #-} @@ -32,11 +33,14 @@ module Simplex.Messaging.Client -- * SMP protocol command functions createSMPQueue, subscribeSMPQueue, + subscribeSMPQueues, getSMPMessage, subscribeSMPQueueNotifications, secureSMPQueue, enableSMPQueueNotifications, disableSMPQueueNotifications, + enableSMPQueuesNtfs, + disableSMPQueuesNtfs, sendSMPMessage, ackSMPMessage, suspendSMPQueue, @@ -60,6 +64,10 @@ import Control.Monad.IO.Class (liftIO) import Control.Monad.Trans.Except import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B +import Data.Either (rights) +import Data.Functor (($>)) +import Data.List.NonEmpty (NonEmpty) +import qualified Data.List.NonEmpty as L import Data.Maybe (fromMaybe) import Network.Socket (ServiceName) import Numeric.Natural @@ -87,13 +95,16 @@ data ProtocolClient msg = ProtocolClient tcpTimeout :: Int, clientCorrId :: TVar Natural, sentCommands :: TMap CorrId (Request msg), - sndQ :: TBQueue SentRawTransmission, - rcvQ :: TBQueue (SignedTransmission msg), + sndQ :: TBQueue (NonEmpty (SentRawTransmission)), + rcvQ :: TBQueue (NonEmpty (SignedTransmission msg)), msgQ :: Maybe (TBQueue (ServerTransmission msg)) } type SMPClient = ProtocolClient SMP.BrokerMsg +-- | Type for client command data +type ClientCommand msg = (Maybe C.APrivateSignKey, QueueId, ProtoCommand msg) + -- | Type synonym for transmission from some SPM server queue. type ServerTransmission msg = (ProtoServer msg, Version, SessionId, QueueId, msg) @@ -208,13 +219,15 @@ getProtocolClient protocolServer cfg@ProtocolClientConfig {qSize, tcpTimeout, tc runExceptT $ sendProtocolCommand c Nothing "" protocolPing process :: ProtocolClient msg -> IO () - process c@ProtocolClient {rcvQ, sentCommands} = forever $ do - (_, _, (corrId, qId, respOrErr)) <- atomically $ readTBQueue rcvQ + process c = forever $ atomically (readTBQueue $ rcvQ c) >>= mapM_ (processMsg c) + + processMsg :: ProtocolClient msg -> SignedTransmission msg -> IO () + processMsg c@ProtocolClient {sentCommands} (_, _, (corrId, qId, respOrErr)) = if B.null $ bs corrId - then sendMsg qId respOrErr + then sendMsg respOrErr else do atomically (TM.lookup corrId sentCommands) >>= \case - Nothing -> sendMsg qId respOrErr + Nothing -> sendMsg respOrErr Just Request {queueId, responseVar} -> atomically $ do TM.delete corrId sentCommands putTMVar responseVar $ @@ -226,8 +239,8 @@ getProtocolClient protocolServer cfg@ProtocolClientConfig {qSize, tcpTimeout, tc _ -> Right r else Left . PCEUnexpectedResponse $ bshow respOrErr where - sendMsg :: QueueId -> Either ErrorType msg -> IO () - sendMsg qId = \case + sendMsg :: Either ErrorType msg -> IO () + sendMsg = \case Right msg -> atomically $ mapM_ (`writeTBQueue` serverTransmission c qId msg) msgQ -- TODO send everything else to errQ and log in agent _ -> return () @@ -285,12 +298,22 @@ subscribeSMPQueue :: SMPClient -> RcvPrivateSignKey -> RecipientId -> ExceptT Pr subscribeSMPQueue c rpKey rId = sendSMPCommand c (Just rpKey) rId SUB >>= \case OK -> return () - cmd@MSG {} -> writeSMPMessage c rId cmd + cmd@MSG {} -> liftIO $ writeSMPMessage c rId cmd r -> throwE . PCEUnexpectedResponse $ bshow r -writeSMPMessage :: SMPClient -> RecipientId -> BrokerMsg -> ExceptT ProtocolClientError IO () -writeSMPMessage c rId msg = - liftIO . atomically $ mapM_ (`writeTBQueue` serverTransmission c rId msg) (msgQ c) +-- | Subscribe to multiple SMP queues batching commands if supported. +subscribeSMPQueues :: SMPClient -> NonEmpty (RcvPrivateSignKey, RecipientId) -> IO (NonEmpty (Either ProtocolClientError ())) +subscribeSMPQueues c qs = sendProtocolCommands c cs >>= mapM response . L.zip qs + where + cs = L.map (\(rpKey, rId) -> (Just rpKey, rId, Cmd SRecipient SUB)) qs + response ((_, rId), r) = case r of + Right OK -> pure $ Right () + Right cmd@MSG {} -> writeSMPMessage c rId cmd $> Right () + Right r' -> pure . Left . PCEUnexpectedResponse $ bshow r' + Left e -> pure $ Left e + +writeSMPMessage :: SMPClient -> RecipientId -> BrokerMsg -> IO () +writeSMPMessage c rId msg = atomically $ mapM_ (`writeTBQueue` serverTransmission c rId msg) (msgQ c) serverTransmission :: ProtocolClient msg -> RecipientId -> msg -> ServerTransmission msg serverTransmission ProtocolClient {protocolServer, thVersion, sessionId} entityId message = @@ -303,9 +326,7 @@ getSMPMessage :: SMPClient -> RcvPrivateSignKey -> RecipientId -> ExceptT Protoc getSMPMessage c rpKey rId = sendSMPCommand c (Just rpKey) rId GET >>= \case OK -> pure Nothing - cmd@(MSG msg) -> do - writeSMPMessage c rId cmd - pure $ Just msg + cmd@(MSG msg) -> liftIO (writeSMPMessage c rId cmd) $> Just msg r -> throwE . PCEUnexpectedResponse $ bshow r -- | Subscribe to the SMP queue notifications. @@ -329,12 +350,32 @@ enableSMPQueueNotifications c rpKey rId notifierKey rcvNtfPublicDhKey = NID nId rcvNtfSrvPublicDhKey -> pure (nId, rcvNtfSrvPublicDhKey) r -> throwE . PCEUnexpectedResponse $ bshow r +-- | Enable notifications for the multiple queues for push notifications server. +enableSMPQueuesNtfs :: SMPClient -> NonEmpty (RcvPrivateSignKey, RecipientId, NtfPublicVerifyKey, RcvNtfPublicDhKey) -> IO (NonEmpty (Either ProtocolClientError (NotifierId, RcvNtfPublicDhKey))) +enableSMPQueuesNtfs c qs = L.map response <$> sendProtocolCommands c cs + where + cs = L.map (\(rpKey, rId, notifierKey, rcvNtfPublicDhKey) -> (Just rpKey, rId, Cmd SRecipient $ NKEY notifierKey rcvNtfPublicDhKey)) qs + response = \case + Right (NID nId rcvNtfSrvPublicDhKey) -> Right (nId, rcvNtfSrvPublicDhKey) + Right r -> Left . PCEUnexpectedResponse $ bshow r + Left e -> Left e + -- | Disable notifications for the queue for push notifications server. -- -- https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#disable-notifications-command disableSMPQueueNotifications :: SMPClient -> RcvPrivateSignKey -> RecipientId -> ExceptT ProtocolClientError IO () disableSMPQueueNotifications = okSMPCommand NDEL +-- | Disable notifications for multiple queues for push notifications server. +disableSMPQueuesNtfs :: SMPClient -> NonEmpty (RcvPrivateSignKey, RecipientId) -> IO (NonEmpty (Either ProtocolClientError ())) +disableSMPQueuesNtfs c qs = L.map response <$> sendProtocolCommands c cs + where + cs = L.map (\(rpKey, rId) -> (Just rpKey, rId, Cmd SRecipient NDEL)) qs + response = \case + Right OK -> Right () + Right r -> Left . PCEUnexpectedResponse $ bshow r + Left e -> Left e + -- | Send SMP message. -- -- https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#send-message @@ -351,7 +392,7 @@ ackSMPMessage :: SMPClient -> RcvPrivateSignKey -> QueueId -> MsgId -> ExceptT P ackSMPMessage c rpKey rId msgId = sendSMPCommand c (Just rpKey) rId (ACK msgId) >>= \case OK -> return () - cmd@MSG {} -> writeSMPMessage c rId cmd + cmd@MSG {} -> liftIO $ writeSMPMessage c rId cmd r -> throwE . PCEUnexpectedResponse $ bshow r -- | Irreversibly suspend SMP queue. @@ -377,37 +418,48 @@ okSMPCommand cmd c pKey qId = sendSMPCommand :: PartyI p => SMPClient -> Maybe C.APrivateSignKey -> QueueId -> Command p -> ExceptT ProtocolClientError IO BrokerMsg sendSMPCommand c pKey qId cmd = sendProtocolCommand c pKey qId (Cmd sParty cmd) +-- | Send multiple commands with batching and collect responses +sendProtocolCommands :: forall msg. ProtocolEncoding (ProtoCommand msg) => ProtocolClient msg -> NonEmpty (ClientCommand msg) -> IO (NonEmpty (Either ProtocolClientError msg)) +sendProtocolCommands c@ProtocolClient {sndQ, tcpTimeout} cs = do + ts <- mapM (runExceptT . mkTransmission c) cs + mapM_ (atomically . writeTBQueue sndQ . L.map fst) . L.nonEmpty . rights $ L.toList ts + forConcurrently ts $ \case + Right (_, r) -> withTimeout . atomically $ takeTMVar r + Left e -> pure $ Left e + where + withTimeout a = fromMaybe (Left PCEResponseTimeout) <$> timeout tcpTimeout a + -- | Send Protocol command sendProtocolCommand :: forall msg. ProtocolEncoding (ProtoCommand msg) => ProtocolClient msg -> Maybe C.APrivateSignKey -> QueueId -> ProtoCommand msg -> ExceptT ProtocolClientError IO msg -sendProtocolCommand ProtocolClient {sndQ, sentCommands, clientCorrId, sessionId, thVersion, tcpTimeout} pKey qId cmd = do - corrId <- lift_ getNextCorrId - t <- signTransmission $ encodeTransmission thVersion sessionId (corrId, qId, cmd) - ExceptT $ sendRecv corrId t +sendProtocolCommand c@ProtocolClient {sndQ, tcpTimeout} pKey qId cmd = do + (t, r) <- mkTransmission c (pKey, qId, cmd) + ExceptT $ sendRecv t r where - lift_ :: STM a -> ExceptT ProtocolClientError IO a - lift_ action = ExceptT $ Right <$> atomically action + -- two separate "atomically" needed to avoid blocking + sendRecv :: SentRawTransmission -> TMVar (Response msg) -> IO (Response msg) + sendRecv t r = atomically (writeTBQueue sndQ [t]) >> withTimeout (atomically $ takeTMVar r) + where + withTimeout a = fromMaybe (Left PCEResponseTimeout) <$> timeout tcpTimeout a +mkTransmission :: forall msg. ProtocolEncoding (ProtoCommand msg) => ProtocolClient msg -> ClientCommand msg -> ExceptT ProtocolClientError IO (SentRawTransmission, TMVar (Response msg)) +mkTransmission ProtocolClient {clientCorrId, sessionId, thVersion, sentCommands} (pKey, qId, cmd) = do + corrId <- liftIO $ atomically getNextCorrId + t <- signTransmission $ encodeTransmission thVersion sessionId (corrId, qId, cmd) + r <- liftIO . atomically $ mkRequest corrId + pure (t, r) + where getNextCorrId :: STM CorrId getNextCorrId = do i <- stateTVar clientCorrId $ \i -> (i, i + 1) pure . CorrId $ bshow i - signTransmission :: ByteString -> ExceptT ProtocolClientError IO SentRawTransmission signTransmission t = case pKey of - Nothing -> return (Nothing, t) + Nothing -> pure (Nothing, t) Just pk -> do sig <- liftError PCESignatureError $ C.sign pk t return (Just sig, t) - - -- two separate "atomically" needed to avoid blocking - sendRecv :: CorrId -> SentRawTransmission -> IO (Response msg) - sendRecv corrId t = atomically (send corrId t) >>= withTimeout . atomically . takeTMVar - where - withTimeout a = fromMaybe (Left PCEResponseTimeout) <$> timeout tcpTimeout a - - send :: CorrId -> SentRawTransmission -> STM (TMVar (Response msg)) - send corrId t = do + mkRequest :: CorrId -> STM (TMVar (Response msg)) + mkRequest corrId = do r <- newEmptyTMVar TM.insert corrId (Request qId r) sentCommands - writeTBQueue sndQ t - return r + pure r diff --git a/src/Simplex/Messaging/Encoding.hs b/src/Simplex/Messaging/Encoding.hs index 6f007d8..5d5dec3 100644 --- a/src/Simplex/Messaging/Encoding.hs +++ b/src/Simplex/Messaging/Encoding.hs @@ -13,6 +13,7 @@ module Simplex.Messaging.Encoding Large (..), smpEncodeList, smpListP, + lenEncode, ) where diff --git a/src/Simplex/Messaging/Notifications/Server.hs b/src/Simplex/Messaging/Notifications/Server.hs index 254e870..b1ed82a 100644 --- a/src/Simplex/Messaging/Notifications/Server.hs +++ b/src/Simplex/Messaging/Notifications/Server.hs @@ -4,6 +4,7 @@ {-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE NamedFieldPuns #-} +{-# LANGUAGE OverloadedLists #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TupleSections #-} @@ -248,22 +249,23 @@ clientDisconnected NtfServerClient {connected} = atomically $ writeTVar connecte receive :: (Transport c, MonadUnliftIO m, MonadReader NtfEnv m) => THandle c -> NtfServerClient -> m () receive th NtfServerClient {rcvQ, sndQ, activeAt} = forever $ do - t@(_, _, (corrId, entId, cmdOrError)) <- tGet th - atomically . writeTVar activeAt =<< liftIO getSystemTime - logDebug "received transmission" - case cmdOrError of - Left e -> write sndQ (corrId, entId, NRErr e) - Right cmd -> - verifyNtfTransmission t cmd >>= \case - VRVerified req -> write rcvQ req - VRFailed -> write sndQ (corrId, entId, NRErr AUTH) + ts <- tGet th + forM_ ts $ \t@(_, _, (corrId, entId, cmdOrError)) -> do + atomically . writeTVar activeAt =<< liftIO getSystemTime + logDebug "received transmission" + case cmdOrError of + Left e -> write sndQ (corrId, entId, NRErr e) + Right cmd -> + verifyNtfTransmission t cmd >>= \case + VRVerified req -> write rcvQ req + VRFailed -> write sndQ (corrId, entId, NRErr AUTH) where write q t = atomically $ writeTBQueue q t send :: (Transport c, MonadUnliftIO m) => THandle c -> NtfServerClient -> m () send h@THandle {thVersion = v} NtfServerClient {sndQ, sessionId, activeAt} = forever $ do t <- atomically $ readTBQueue sndQ - void . liftIO $ tPut h (Nothing, encodeTransmission v sessionId t) + void . liftIO $ tPut h [(Nothing, encodeTransmission v sessionId t)] atomically . writeTVar activeAt =<< liftIO getSystemTime -- instance Show a => Show (TVar a) where diff --git a/src/Simplex/Messaging/Notifications/Transport.hs b/src/Simplex/Messaging/Notifications/Transport.hs index d5a50ee..33abe56 100644 --- a/src/Simplex/Messaging/Notifications/Transport.hs +++ b/src/Simplex/Messaging/Notifications/Transport.hs @@ -69,4 +69,4 @@ ntfClientHandshake c keyHash ntfVRange = do Nothing -> throwError $ TEHandshake VERSION ntfTHandle :: Transport c => c -> THandle c -ntfTHandle c = THandle {connection = c, sessionId = tlsUnique c, blockSize = ntfBlockSize, thVersion = 0} +ntfTHandle c = THandle {connection = c, sessionId = tlsUnique c, blockSize = ntfBlockSize, thVersion = 0, batch = False} diff --git a/src/Simplex/Messaging/Protocol.hs b/src/Simplex/Messaging/Protocol.hs index 42713ad..c167d19 100644 --- a/src/Simplex/Messaging/Protocol.hs +++ b/src/Simplex/Messaging/Protocol.hs @@ -7,6 +7,7 @@ {-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE NamedFieldPuns #-} +{-# LANGUAGE OverloadedLists #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE PolyKinds #-} @@ -130,6 +131,8 @@ import qualified Data.Attoparsec.ByteString.Char8 as A import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B import Data.Kind +import Data.List.NonEmpty (NonEmpty (..)) +import qualified Data.List.NonEmpty as L import Data.Maybe (isNothing) import Data.String import Data.Time.Clock.System (SystemTime (..)) @@ -1010,20 +1013,51 @@ instance Encoding CommandError where _ -> fail "bad command error type" -- | Send signed SMP transmission to TCP transport. -tPut :: Transport c => THandle c -> SentRawTransmission -> IO (Either TransportError ()) -tPut th (sig, t) = tPutBlock th $ smpEncode (C.signatureBytes sig) <> t +tPut :: Transport c => THandle c -> NonEmpty (SentRawTransmission) -> IO (NonEmpty (Either TransportError ())) +tPut th trs + | batch th = tPutBatch [] $ L.map tEncode trs + | otherwise = forM trs $ tPutBlock th . tEncode + where + tPutBatch :: [Either TransportError ()] -> NonEmpty ByteString -> IO (NonEmpty (Either TransportError ())) + tPutBatch rs ts = do + let (n, s, ts_) = encodeBatch 0 "" ts + r <- if n == 0 then pure [Left TELargeMsg] else replicate n <$> tPutBlock th (lenEncode n `B.cons` s) + let rs' = rs <> r + case ts_ of + Just ts' -> tPutBatch rs' ts' + _ -> pure $ L.fromList rs' + encodeBatch :: Int -> ByteString -> NonEmpty ByteString -> (Int, ByteString, Maybe (NonEmpty ByteString)) + encodeBatch n s ts@(t :| ts_) + | n == 255 = (n, s, Just ts) + | otherwise = + let s' = s <> smpEncode (Large t) + n' = n + 1 + in if B.length s' > blockSize th - 1 + then (n,s,) $ if n == 0 then L.nonEmpty ts_ else Just ts + else case L.nonEmpty ts_ of + Just ts' -> encodeBatch n' s' ts' + _ -> (n', s', Nothing) + tEncode (sig, tr) = smpEncode (C.signatureBytes sig) <> tr encodeTransmission :: ProtocolEncoding c => Version -> ByteString -> Transmission c -> ByteString encodeTransmission v sessionId (CorrId corrId, queueId, command) = smpEncode (sessionId, corrId, queueId) <> encodeProtocol v command -- | Receive and parse transmission from the TCP transport (ignoring any trailing padding). -tGetParse :: Transport c => THandle c -> IO (Either TransportError RawTransmission) -tGetParse th = (parse transmissionP TEBadBlock =<<) <$> tGetBlock th +tGetParse :: Transport c => THandle c -> IO (NonEmpty (Either TransportError RawTransmission)) +tGetParse th + | batch th = either ((:| []) . Left) id <$> runExceptT getBatch + | otherwise = (:| []) . (parse transmissionP TEBadBlock =<<) <$> tGetBlock th + where + getBatch :: ExceptT TransportError IO (NonEmpty (Either TransportError RawTransmission)) + getBatch = do + s <- ExceptT $ tGetBlock th + ts <- liftEither $ parse smpP TEBadBlock s + pure $ L.map (\(Large t) -> parse transmissionP TEBadBlock t) ts -- | Receive client and server transmissions (determined by `cmd` type). -tGet :: forall cmd c m. (ProtocolEncoding cmd, Transport c, MonadIO m) => THandle c -> m (SignedTransmission cmd) -tGet th@THandle {sessionId, thVersion = v} = liftIO (tGetParse th) >>= decodeParseValidate +tGet :: forall cmd c m. (ProtocolEncoding cmd, Transport c, MonadIO m) => THandle c -> m (NonEmpty (SignedTransmission cmd)) +tGet th@THandle {sessionId, thVersion = v} = liftIO (tGetParse th) >>= mapM decodeParseValidate where decodeParseValidate :: Either TransportError RawTransmission -> m (SignedTransmission cmd) decodeParseValidate = \case diff --git a/src/Simplex/Messaging/Server.hs b/src/Simplex/Messaging/Server.hs index 9a02216..3474762 100644 --- a/src/Simplex/Messaging/Server.hs +++ b/src/Simplex/Messaging/Server.hs @@ -6,6 +6,7 @@ {-# LANGUAGE LambdaCase #-} {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE NumericUnderscores #-} +{-# LANGUAGE OverloadedLists #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE RankNTypes #-} @@ -43,9 +44,10 @@ import Crypto.Random import Data.Bifunctor (first) import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B -import Data.Either (fromRight) +import Data.Either (fromRight, partitionEithers) import Data.Functor (($>)) import Data.List (intercalate) +import qualified Data.List.NonEmpty as L import qualified Data.Map.Strict as M import Data.Maybe (isNothing) import Data.Set (Set) @@ -145,7 +147,7 @@ smpServer started = do endPreviousSubscriptions :: (QueueId, Client) -> m (Maybe s) endPreviousSubscriptions (qId, c) = do void . forkIO . atomically $ - writeTBQueue (sndQ c) (CorrId "", qId, END) + writeTBQueue (sndQ c) [(CorrId "", qId, END)] atomically $ TM.lookupDelete qId (clientSubs c) expireMessagesThread_ :: ServerConfig -> [m ()] @@ -243,26 +245,36 @@ cancelSub sub = Sub {subThread = SubThread t} -> liftIO $ deRefWeak t >>= mapM_ killThread _ -> return () -receive :: (Transport c, MonadUnliftIO m, MonadReader Env m) => THandle c -> Client -> m () +receive :: forall c m. (Transport c, MonadUnliftIO m, MonadReader Env m) => THandle c -> Client -> m () receive th Client {rcvQ, sndQ, activeAt} = forever $ do - (sig, signed, (corrId, queueId, cmdOrError)) <- tGet th + ts <- L.toList <$> tGet th atomically . writeTVar activeAt =<< liftIO getSystemTime - case cmdOrError of - Left e -> write sndQ (corrId, queueId, ERR e) - Right cmd -> do - verified <- verifyTransmission sig signed queueId cmd - if verified - then write rcvQ (corrId, queueId, cmd) - else write sndQ (corrId, queueId, ERR AUTH) + as <- partitionEithers <$> mapM cmdAction ts + write sndQ $ fst as + write rcvQ $ snd as where - write q t = atomically $ writeTBQueue q t + cmdAction :: SignedTransmission Cmd -> m (Either (Transmission BrokerMsg) (Maybe QueueRec, Transmission Cmd)) + cmdAction (sig, signed, (corrId, queueId, cmdOrError)) = + case cmdOrError of + Left e -> pure $ Left (corrId, queueId, ERR e) + Right cmd -> verified <$> verifyTransmission sig signed queueId cmd + where + verified = \case + VRVerified qr -> Right (qr, (corrId, queueId, cmd)) + VRFailed -> Left (corrId, queueId, ERR AUTH) + write q = mapM_ (atomically . writeTBQueue q) . L.nonEmpty send :: (Transport c, MonadUnliftIO m) => THandle c -> Client -> m () send h@THandle {thVersion = v} Client {sndQ, sessionId, activeAt} = forever $ do - t <- atomically $ readTBQueue sndQ - -- TODO the line below can return Left, but we ignore it and do not disconnect the client - void . liftIO $ tPut h (Nothing, encodeTransmission v sessionId t) + ts <- atomically $ L.sortWith tOrder <$> readTBQueue sndQ + -- TODO the line below can return Lefts, but we ignore it and do not disconnect the client + void . liftIO . tPut h $ L.map ((Nothing,) . encodeTransmission v sessionId) ts atomically . writeTVar activeAt =<< liftIO getSystemTime + where + tOrder :: Transmission BrokerMsg -> Int + tOrder (_, _, cmd) = case cmd of + MSG {} -> 0 + _ -> 1 disconnectTransport :: (Transport c, MonadUnliftIO m) => THandle c -> client -> (client -> TVar SystemTime) -> ExpirationConfig -> m () disconnectTransport THandle {connection} c activeAt expCfg = do @@ -273,23 +285,28 @@ disconnectTransport THandle {connection} c activeAt expCfg = do ts <- readTVarIO $ activeAt c when (systemSeconds ts < old) $ closeConnection connection +data VerificationResult = VRVerified (Maybe QueueRec) | VRFailed + verifyTransmission :: - forall m. (MonadUnliftIO m, MonadReader Env m) => Maybe C.ASignature -> ByteString -> QueueId -> Cmd -> m Bool + forall m. (MonadUnliftIO m, MonadReader Env m) => Maybe C.ASignature -> ByteString -> QueueId -> Cmd -> m VerificationResult verifyTransmission sig_ signed queueId cmd = do case cmd of - Cmd SRecipient (NEW k _) -> pure $ verifyCmdSignature sig_ signed k + Cmd SRecipient (NEW k _) -> pure $ Nothing `verified` verifyCmdSignature sig_ signed k Cmd SRecipient _ -> verifyCmd SRecipient $ verifyCmdSignature sig_ signed . recipientKey Cmd SSender SEND {} -> verifyCmd SSender $ verifyMaybe . senderKey - Cmd SSender PING -> pure True + Cmd SSender PING -> pure $ VRVerified Nothing Cmd SNotifier NSUB -> verifyCmd SNotifier $ verifyMaybe . fmap notifierKey . notifier where - verifyCmd :: SParty p -> (QueueRec -> Bool) -> m Bool + verifyCmd :: SParty p -> (QueueRec -> Bool) -> m VerificationResult verifyCmd party f = do st <- asks queueStore - q <- atomically $ getQueue st party queueId - pure $ either (const $ maybe False (dummyVerifyCmd signed) sig_ `seq` False) f q + q_ <- atomically (getQueue st party queueId) + pure $ case q_ of + Right q -> Just q `verified` f q + _ -> maybe False (dummyVerifyCmd signed) sig_ `seq` VRFailed verifyMaybe :: Maybe C.APublicVerifyKey -> Bool verifyMaybe = maybe (isNothing sig_) $ verifyCmdSignature sig_ signed + verified q cond = if cond then VRVerified q else VRFailed verifyCmdSignature :: Maybe C.ASignature -> ByteString -> C.APublicVerifyKey -> Bool verifyCmdSignature sig_ signed key = maybe False (verify key) sig_ @@ -320,16 +337,16 @@ client :: forall m. (MonadUnliftIO m, MonadReader Env m) => Client -> Server -> client clnt@Client {thVersion, subscriptions, ntfSubscriptions, rcvQ, sndQ} Server {subscribedQ, ntfSubscribedQ, notifiers} = forever $ atomically (readTBQueue rcvQ) - >>= processCommand + >>= mapM processCommand >>= atomically . writeTBQueue sndQ where - processCommand :: Transmission Cmd -> m (Transmission BrokerMsg) - processCommand (corrId, queueId, cmd) = do + processCommand :: (Maybe QueueRec, Transmission Cmd) -> m (Transmission BrokerMsg) + processCommand (qr_, (corrId, queueId, cmd)) = do st <- asks queueStore case cmd of Cmd SSender command -> case command of - SEND flags msgBody -> sendMessage st flags msgBody + SEND flags msgBody -> withQueue $ \qr -> sendMessage qr flags msgBody PING -> pure (corrId, "", PONG) Cmd SNotifier NSUB -> subscribeNotifications Cmd SRecipient command -> @@ -339,9 +356,9 @@ client clnt@Client {thVersion, subscriptions, ntfSubscriptions, rcvQ, sndQ} Serv (asks $ allowNewQueues . config) (createQueue st rKey dhKey) (pure (corrId, queueId, ERR AUTH)) - SUB -> subscribeQueue st queueId - GET -> getMessage st - ACK msgId -> acknowledgeMsg st msgId + SUB -> withQueue (`subscribeQueue` queueId) + GET -> withQueue getMessage + ACK msgId -> withQueue (`acknowledgeMsg` msgId) KEY sKey -> secureQueue_ st sKey NKEY nKey dhKey -> addQueueNotifier_ st nKey dhKey NDEL -> deleteQueueNotifier_ st @@ -371,14 +388,15 @@ client clnt@Client {thVersion, subscriptions, ntfSubscriptions, rcvQ, sndQ} Serv addQueueRetry n qik qRec = do ids@(rId, _) <- getIds -- create QueueRec record with these ids and keys - atomically (addQueue st $ qRec ids) >>= \case + let qr = qRec ids + atomically (addQueue st qr) >>= \case Left DUPLICATE_ -> addQueueRetry (n - 1) qik qRec Left e -> pure $ ERR e Right _ -> do withLog (`logCreateById` rId) stats <- asks serverStats atomically $ modifyTVar (qCreated stats) (+ 1) - subscribeQueue st rId $> IDS (qik ids) + subscribeQueue qr rId $> IDS (qik ids) logCreateById :: StoreLog 'WriteMode -> RecipientId -> IO () logCreateById s rId = @@ -426,8 +444,8 @@ client clnt@Client {thVersion, subscriptions, ntfSubscriptions, rcvQ, sndQ} Serv withLog (`logDeleteQueue` queueId) okResp <$> atomically (suspendQueue st queueId) - subscribeQueue :: QueueStore -> RecipientId -> m (Transmission BrokerMsg) - subscribeQueue st rId = + subscribeQueue :: QueueRec -> RecipientId -> m (Transmission BrokerMsg) + subscribeQueue qr rId = atomically (TM.lookup rId subscriptions) >>= \case Nothing -> atomically newSub >>= deliver @@ -449,10 +467,10 @@ client clnt@Client {thVersion, subscriptions, ntfSubscriptions, rcvQ, sndQ} Serv deliver sub = do q <- getStoreMsgQueue rId msg_ <- atomically $ tryPeekMsg q - deliverMessage st rId sub q msg_ + deliverMessage qr rId sub q msg_ - getMessage :: QueueStore -> m (Transmission BrokerMsg) - getMessage st = + getMessage :: QueueRec -> m (Transmission BrokerMsg) + getMessage qr = atomically (TM.lookup queueId subscriptions) >>= \case Nothing -> atomically newSub >>= getMessage_ @@ -471,7 +489,7 @@ client clnt@Client {thVersion, subscriptions, ntfSubscriptions, rcvQ, sndQ} Serv TM.insert queueId sub subscriptions pure s getMessage_ :: Sub -> m (Transmission BrokerMsg) - getMessage_ s = withRcvQueue st queueId $ \qr -> do + getMessage_ s = do q <- getStoreMsgQueue queueId atomically $ tryPeekMsg q >>= \case @@ -480,11 +498,8 @@ client clnt@Client {thVersion, subscriptions, ntfSubscriptions, rcvQ, sndQ} Serv in setDelivered s msg $> (corrId, queueId, MSG encMsg) _ -> pure (corrId, queueId, OK) - withRcvQueue :: QueueStore -> RecipientId -> (QueueRec -> m (Transmission BrokerMsg)) -> m (Transmission BrokerMsg) - withRcvQueue st rId action = - atomically (getQueue st SRecipient rId) >>= \case - Left e -> pure (corrId, rId, ERR e) - Right qr -> action qr + withQueue :: (QueueRec -> m (Transmission BrokerMsg)) -> m (Transmission BrokerMsg) + withQueue action = maybe (pure $ err AUTH) action qr_ subscribeNotifications :: m (Transmission BrokerMsg) subscribeNotifications = atomically $ do @@ -493,8 +508,8 @@ client clnt@Client {thVersion, subscriptions, ntfSubscriptions, rcvQ, sndQ} Serv TM.insert queueId () ntfSubscriptions pure ok - acknowledgeMsg :: QueueStore -> MsgId -> m (Transmission BrokerMsg) - acknowledgeMsg st msgId = do + acknowledgeMsg :: QueueRec -> MsgId -> m (Transmission BrokerMsg) + acknowledgeMsg qr msgId = do atomically (TM.lookup queueId subscriptions) >>= \case Nothing -> pure $ err NO_MSG Just sub -> @@ -509,7 +524,7 @@ client clnt@Client {thVersion, subscriptions, ntfSubscriptions, rcvQ, sndQ} Serv _ -> do (msgDeleted, msg_) <- atomically $ tryDelPeekMsg q msgId when msgDeleted updateStats - deliverMessage st queueId sub q msg_ + deliverMessage qr queueId sub q msg_ _ -> pure $ err NO_MSG where getDelivered :: TVar Sub -> STM (Maybe Sub) @@ -533,74 +548,69 @@ client clnt@Client {thVersion, subscriptions, ntfSubscriptions, rcvQ, sndQ} Serv where updatePeriod pSel = modifyTVar (pSel stats) (S.insert qId) - sendMessage :: QueueStore -> MsgFlags -> MsgBody -> m (Transmission BrokerMsg) - sendMessage st msgFlags msgBody + sendMessage :: QueueRec -> MsgFlags -> MsgBody -> m (Transmission BrokerMsg) + sendMessage qr msgFlags msgBody | B.length msgBody > maxMessageLength = pure $ err LARGE_MSG - | otherwise = do - qr <- atomically $ getQueue st SSender queueId - either (return . err) storeMessage qr + | otherwise = case status qr of + QueueOff -> return $ err AUTH + QueueActive -> + mapM mkMessage (C.maxLenBS msgBody) >>= \case + Left _ -> pure $ err LARGE_MSG + Right msg -> do + ms <- asks msgStore + ServerConfig {messageExpiration, msgQueueQuota} <- asks config + old <- liftIO $ mapM expireBeforeEpoch messageExpiration + ntfNonceDrg <- asks idsDrg + resp@(_, _, sent) <- atomically $ do + q <- getMsgQueue ms (recipientId qr) msgQueueQuota + mapM_ (deleteExpiredMsgs q) old + ifM (isFull q) (pure $ err QUOTA) $ do + when (notification msgFlags) $ trySendNotification msg ntfNonceDrg + writeMsg q msg + pure ok + when (sent == OK) $ do + stats <- asks serverStats + atomically $ modifyTVar (msgSent stats) (+ 1) + atomically $ updateActiveQueues stats $ recipientId qr + pure resp where - storeMessage :: QueueRec -> m (Transmission BrokerMsg) - storeMessage qr = case status qr of - QueueOff -> return $ err AUTH - QueueActive -> - mapM mkMessage (C.maxLenBS msgBody) >>= \case - Left _ -> pure $ err LARGE_MSG - Right msg -> do - ms <- asks msgStore - ServerConfig {messageExpiration, msgQueueQuota} <- asks config - old <- liftIO $ mapM expireBeforeEpoch messageExpiration - ntfNonceDrg <- asks idsDrg - resp@(_, _, sent) <- atomically $ do - q <- getMsgQueue ms (recipientId qr) msgQueueQuota - mapM_ (deleteExpiredMsgs q) old - ifM (isFull q) (pure $ err QUOTA) $ do - when (notification msgFlags) $ trySendNotification msg ntfNonceDrg - writeMsg q msg - pure ok - when (sent == OK) $ do - stats <- asks serverStats - atomically $ modifyTVar (msgSent stats) (+ 1) - atomically $ updateActiveQueues stats $ recipientId qr - pure resp - where - mkMessage :: C.MaxLenBS MaxMessageLen -> m Message - mkMessage body = do - msgId <- randomId =<< asks (msgIdBytes . config) - msgTs <- liftIO getSystemTime - pure $ Message msgId msgTs msgFlags body + mkMessage :: C.MaxLenBS MaxMessageLen -> m Message + mkMessage body = do + msgId <- randomId =<< asks (msgIdBytes . config) + msgTs <- liftIO getSystemTime + pure $ Message msgId msgTs msgFlags body - trySendNotification :: Message -> TVar ChaChaDRG -> STM () - trySendNotification msg ntfNonceDrg = - forM_ (notifier qr) $ \NtfCreds {notifierId, rcvNtfDhSecret} -> - mapM_ (writeNtf notifierId msg rcvNtfDhSecret ntfNonceDrg) =<< TM.lookup notifierId notifiers + trySendNotification :: Message -> TVar ChaChaDRG -> STM () + trySendNotification msg ntfNonceDrg = + forM_ (notifier qr) $ \NtfCreds {notifierId, rcvNtfDhSecret} -> + mapM_ (writeNtf notifierId msg rcvNtfDhSecret ntfNonceDrg) =<< TM.lookup notifierId notifiers - writeNtf :: NotifierId -> Message -> RcvNtfDhSecret -> TVar ChaChaDRG -> Client -> STM () - writeNtf nId msg rcvNtfDhSecret ntfNonceDrg Client {sndQ = q} = - unlessM (isFullTBQueue sndQ) $ do - (nmsgNonce, encNMsgMeta) <- mkMessageNotification msg rcvNtfDhSecret ntfNonceDrg - writeTBQueue q (CorrId "", nId, NMSG nmsgNonce encNMsgMeta) + writeNtf :: NotifierId -> Message -> RcvNtfDhSecret -> TVar ChaChaDRG -> Client -> STM () + writeNtf nId msg rcvNtfDhSecret ntfNonceDrg Client {sndQ = q} = + unlessM (isFullTBQueue sndQ) $ do + (nmsgNonce, encNMsgMeta) <- mkMessageNotification msg rcvNtfDhSecret ntfNonceDrg + writeTBQueue q [(CorrId "", nId, NMSG nmsgNonce encNMsgMeta)] - mkMessageNotification :: Message -> RcvNtfDhSecret -> TVar ChaChaDRG -> STM (C.CbNonce, EncNMsgMeta) - mkMessageNotification Message {msgId, msgTs} rcvNtfDhSecret ntfNonceDrg = do - cbNonce <- C.pseudoRandomCbNonce ntfNonceDrg - let msgMeta = NMsgMeta {msgId, msgTs} - encNMsgMeta = C.cbEncrypt rcvNtfDhSecret cbNonce (smpEncode msgMeta) 128 - pure . (cbNonce,) $ fromRight "" encNMsgMeta + mkMessageNotification :: Message -> RcvNtfDhSecret -> TVar ChaChaDRG -> STM (C.CbNonce, EncNMsgMeta) + mkMessageNotification Message {msgId, msgTs} rcvNtfDhSecret ntfNonceDrg = do + cbNonce <- C.pseudoRandomCbNonce ntfNonceDrg + let msgMeta = NMsgMeta {msgId, msgTs} + encNMsgMeta = C.cbEncrypt rcvNtfDhSecret cbNonce (smpEncode msgMeta) 128 + pure . (cbNonce,) $ fromRight "" encNMsgMeta - deliverMessage :: QueueStore -> RecipientId -> TVar Sub -> MsgQueue -> Maybe Message -> m (Transmission BrokerMsg) - deliverMessage st rId sub q msg_ = withRcvQueue st rId $ \qr -> do + deliverMessage :: QueueRec -> RecipientId -> TVar Sub -> MsgQueue -> Maybe Message -> m (Transmission BrokerMsg) + deliverMessage qr rId sub q msg_ = do readTVarIO sub >>= \case s@Sub {subThread = NoSub} -> case msg_ of Just msg -> let encMsg = encryptMsg qr msg in atomically (setDelivered s msg) $> (corrId, rId, MSG encMsg) - _ -> forkSub qr $> ok + _ -> forkSub $> ok _ -> pure ok where - forkSub :: QueueRec -> m () - forkSub qr = do + forkSub :: m () + forkSub = do atomically . modifyTVar sub $ \s -> s {subThread = SubPending} t <- mkWeakThreadId =<< forkIO subscriber atomically . modifyTVar sub $ \case @@ -610,7 +620,7 @@ client clnt@Client {thVersion, subscriptions, ntfSubscriptions, rcvQ, sndQ} Serv subscriber = atomically $ do msg <- peekMsg q let encMsg = encryptMsg qr msg - writeTBQueue sndQ (CorrId "", rId, MSG encMsg) + writeTBQueue sndQ [(CorrId "", rId, MSG encMsg)] s <- readTVar sub void $ setDelivered s msg writeTVar sub s {subThread = NoSub} diff --git a/src/Simplex/Messaging/Server/Env/STM.hs b/src/Simplex/Messaging/Server/Env/STM.hs index 2f6db1b..b041ac7 100644 --- a/src/Simplex/Messaging/Server/Env/STM.hs +++ b/src/Simplex/Messaging/Server/Env/STM.hs @@ -9,6 +9,7 @@ import Control.Concurrent (ThreadId) import Control.Monad.IO.Unlift import Crypto.Random import Data.ByteString.Char8 (ByteString) +import Data.List.NonEmpty (NonEmpty) import Data.Map.Strict (Map) import qualified Data.Map.Strict as M import Data.Time.Clock (getCurrentTime) @@ -103,8 +104,8 @@ data Server = Server data Client = Client { subscriptions :: TMap RecipientId (TVar Sub), ntfSubscriptions :: TMap NotifierId (), - rcvQ :: TBQueue (Transmission Cmd), - sndQ :: TBQueue (Transmission BrokerMsg), + rcvQ :: TBQueue (NonEmpty (Maybe QueueRec, Transmission Cmd)), + sndQ :: TBQueue (NonEmpty (Transmission BrokerMsg)), thVersion :: Version, sessionId :: ByteString, connected :: TVar Bool, diff --git a/src/Simplex/Messaging/Transport.hs b/src/Simplex/Messaging/Transport.hs index be4faff..07bfadf 100644 --- a/src/Simplex/Messaging/Transport.hs +++ b/src/Simplex/Messaging/Transport.hs @@ -96,7 +96,7 @@ smpBlockSize :: Int smpBlockSize = 16384 supportedSMPServerVRange :: VersionRange -supportedSMPServerVRange = mkVersionRange 1 3 +supportedSMPServerVRange = mkVersionRange 1 4 simplexMQVersion :: String simplexMQVersion = "3.0.1" @@ -258,7 +258,10 @@ data THandle c = THandle sessionId :: SessionId, blockSize :: Int, -- | agreed server protocol version - thVersion :: Version + thVersion :: Version, + -- | send multiple transmissions in a single block + -- based on protocol and protocol version + batch :: Bool } -- | TLS-unique channel binding @@ -364,7 +367,7 @@ smpServerHandshake c kh smpVRange = do | keyHash /= kh -> throwE $ TEHandshake IDENTITY | smpVersion `isCompatible` smpVRange -> do - pure (th :: THandle c) {thVersion = smpVersion} + pure $ smpThHandle th smpVersion | otherwise -> throwE $ TEHandshake VERSION -- | Client SMP transport handshake. @@ -379,9 +382,12 @@ smpClientHandshake c keyHash smpVRange = do else case smpVersionRange `compatibleVersion` smpVRange of Just (Compatible smpVersion) -> do sendHandshake th $ ClientHandshake {smpVersion, keyHash} - pure (th :: THandle c) {thVersion = smpVersion} + pure $ smpThHandle th smpVersion Nothing -> throwE $ TEHandshake VERSION +smpThHandle :: forall c. THandle c -> Version -> THandle c +smpThHandle th v = (th :: THandle c) {thVersion = v, batch = v >= 4} + sendHandshake :: (Transport c, Encoding smp) => THandle c -> smp -> ExceptT TransportError IO () sendHandshake th = ExceptT . tPutBlock th . smpEncode @@ -389,4 +395,4 @@ getHandshake :: (Transport c, Encoding smp) => THandle c -> ExceptT TransportErr getHandshake th = ExceptT $ (parse smpP (TEHandshake PARSE) =<<) <$> tGetBlock th smpTHandle :: Transport c => c -> THandle c -smpTHandle c = THandle {connection = c, sessionId = tlsUnique c, blockSize = smpBlockSize, thVersion = 0} +smpTHandle c = THandle {connection = c, sessionId = tlsUnique c, blockSize = smpBlockSize, thVersion = 0, batch = False} diff --git a/src/Simplex/Messaging/Util.hs b/src/Simplex/Messaging/Util.hs index c1e4146..b7ce5f9 100644 --- a/src/Simplex/Messaging/Util.hs +++ b/src/Simplex/Messaging/Util.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE LambdaCase #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE ScopedTypeVariables #-} diff --git a/tests/AgentTests/FunctionalAPITests.hs b/tests/AgentTests/FunctionalAPITests.hs index 7a312d7..6f95147 100644 --- a/tests/AgentTests/FunctionalAPITests.hs +++ b/tests/AgentTests/FunctionalAPITests.hs @@ -17,12 +17,16 @@ module AgentTests.FunctionalAPITests ) where -import Control.Concurrent (threadDelay) +import Control.Concurrent (killThread, threadDelay) +import Control.Monad import Control.Monad.Except (ExceptT, runExceptT) import Control.Monad.IO.Unlift +import Data.Int (Int64) +import qualified Data.Map as M +import qualified Data.Set as S import Data.Time.Clock.System (SystemTime (..), getSystemTime) import SMPAgentClient -import SMPClient (cfg, testPort, withSmpServer, withSmpServerConfigOn, withSmpServerStoreLogOn, withSmpServerStoreMsgLogOn) +import SMPClient (cfg, testPort, testPort2, testStoreLogFile2, withSmpServer, withSmpServerConfigOn, withSmpServerStoreLogOn, withSmpServerStoreMsgLogOn) import Simplex.Messaging.Agent import Simplex.Messaging.Agent.Env.SQLite (AgentConfig (..)) import Simplex.Messaging.Agent.Protocol @@ -102,6 +106,9 @@ functionalAPITests t = do testSuspendingAgentCompleteSending t it "should suspend agent on timeout, even if pending messages not sent" $ testSuspendingAgentTimeout t + describe "Batching SMP commands" $ do + it "should subscribe to multiple subscriptions with batching" $ + testBatchedSubscriptions t testAgentClient :: IO () testAgentClient = do @@ -503,13 +510,64 @@ testSuspendingAgentTimeout t = do pure () +testBatchedSubscriptions :: ATransport -> IO () +testBatchedSubscriptions t = do + a <- getSMPAgentClient agentCfg initAgentServers2 + b <- getSMPAgentClient agentCfg {dbFile = testDB2} initAgentServers2 + Right conns <- runServers $ do + conns <- forM [1 .. 200 :: Int] . const $ makeConnection a b + forM_ conns $ \(aId, bId) -> exchangeGreetings a bId b aId + forM_ (take 10 conns) $ \(aId, bId) -> do + deleteConnection a bId + deleteConnection b aId + liftIO $ threadDelay 1000000 + pure conns + ("", "", DOWN {}) <- get a + ("", "", DOWN {}) <- get a + ("", "", DOWN {}) <- get b + ("", "", DOWN {}) <- get b + Right () <- runServers $ do + ("", "", UP {}) <- get a + ("", "", UP {}) <- get a + ("", "", UP {}) <- get b + ("", "", UP {}) <- get b + liftIO $ threadDelay 1000000 + subscribe a $ map snd conns + subscribe b $ map fst conns + forM_ (drop 10 conns) $ \(aId, bId) -> exchangeGreetingsMsgId 6 a bId b aId + pure () + where + subscribe :: AgentClient -> [ConnId] -> ExceptT AgentErrorType IO () + subscribe c cs = do + r <- subscribeConnections c cs + liftIO $ do + let dc = S.fromList $ take 10 cs + all (== Right ()) (M.withoutKeys r dc) `shouldBe` True + all (== Left (CONN NOT_FOUND)) (M.restrictKeys r dc) `shouldBe` True + M.keys r `shouldMatchList` cs + runServers :: ExceptT AgentErrorType IO a -> IO (Either AgentErrorType a) + runServers a = do + withSmpServerStoreLogOn t testPort $ \t1 -> do + res <- withSmpServerConfigOn t cfg {storeLogFile = Just testStoreLogFile2} testPort2 $ \t2 -> do + res <- runExceptT a + killThread t2 + pure res + killThread t1 + pure res + exchangeGreetings :: AgentClient -> ConnId -> AgentClient -> ConnId -> ExceptT AgentErrorType IO () -exchangeGreetings alice bobId bob aliceId = do - 4 <- sendMessage alice bobId SMP.noMsgFlags "hello" - get alice ##> ("", bobId, SENT 4) +exchangeGreetings = exchangeGreetingsMsgId 4 + +exchangeGreetingsMsgId :: Int64 -> AgentClient -> ConnId -> AgentClient -> ConnId -> ExceptT AgentErrorType IO () +exchangeGreetingsMsgId msgId alice bobId bob aliceId = do + msgId1 <- sendMessage alice bobId SMP.noMsgFlags "hello" + liftIO $ msgId1 `shouldBe` msgId + get alice ##> ("", bobId, SENT msgId) get bob =##> \case ("", c, Msg "hello") -> c == aliceId; _ -> False - ackMessage bob aliceId 4 - 5 <- sendMessage bob aliceId SMP.noMsgFlags "hello too" - get bob ##> ("", aliceId, SENT 5) + ackMessage bob aliceId msgId + msgId2 <- sendMessage bob aliceId SMP.noMsgFlags "hello too" + let msgId' = msgId + 1 + liftIO $ msgId2 `shouldBe` msgId' + get bob ##> ("", aliceId, SENT msgId') get alice =##> \case ("", c, Msg "hello too") -> c == bobId; _ -> False - ackMessage alice bobId 5 + ackMessage alice bobId msgId' diff --git a/tests/AgentTests/NotificationTests.hs b/tests/AgentTests/NotificationTests.hs index 10e0ce7..62d8adb 100644 --- a/tests/AgentTests/NotificationTests.hs +++ b/tests/AgentTests/NotificationTests.hs @@ -22,7 +22,6 @@ import NtfClient import SMPAgentClient (agentCfg, initAgentServers, testDB, testDB2) import SMPClient (testPort, withSmpServer, withSmpServerStoreLogOn) import Simplex.Messaging.Agent -import Simplex.Messaging.Agent.Client (AgentClient) import Simplex.Messaging.Agent.Env.SQLite (AgentConfig (..)) import Simplex.Messaging.Agent.Protocol import qualified Simplex.Messaging.Crypto as C diff --git a/tests/NtfClient.hs b/tests/NtfClient.hs index 86ab111..97818aa 100644 --- a/tests/NtfClient.hs +++ b/tests/NtfClient.hs @@ -6,6 +6,7 @@ {-# LANGUAGE LambdaCase #-} {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE NumericUnderscores #-} +{-# LANGUAGE OverloadedLists #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} @@ -140,10 +141,10 @@ ntfServerTest _ t = runNtfTest $ \h -> tPut' h t >> tGet' h where tPut' h (sig, corrId, queueId, smp) = do let t' = smpEncode (sessionId (h :: THandle c), corrId, queueId, smp) - Right () <- tPut h (sig, t') + [Right ()] <- tPut h [(sig, t')] pure () tGet' h = do - (Nothing, _, (CorrId corrId, qId, Right cmd)) <- tGet h + [(Nothing, _, (CorrId corrId, qId, Right cmd))] <- tGet h pure (Nothing, corrId, qId, cmd) ntfTest :: Transport c => TProxy c -> (THandle c -> IO ()) -> Expectation diff --git a/tests/NtfServerTests.hs b/tests/NtfServerTests.hs index f273862..42c4399 100644 --- a/tests/NtfServerTests.hs +++ b/tests/NtfServerTests.hs @@ -25,6 +25,8 @@ import ServerTests samplePubKey, sampleSig, signSendRecv, + tGet1, + tPut1, (#==), _SEND', pattern Resp, @@ -69,15 +71,15 @@ pattern RespNtf corrId queueId command <- (_, _, (corrId, queueId, Right command sendRecvNtf :: forall c e. (Transport c, NtfEntityI e) => THandle c -> (Maybe C.ASignature, ByteString, ByteString, NtfCommand e) -> IO (SignedTransmission NtfResponse) sendRecvNtf h@THandle {thVersion, sessionId} (sgn, corrId, qId, cmd) = do let t = encodeTransmission thVersion sessionId (CorrId corrId, qId, cmd) - Right () <- tPut h (sgn, t) - tGet h + Right () <- tPut1 h (sgn, t) + tGet1 h signSendRecvNtf :: forall c e. (Transport c, NtfEntityI e) => THandle c -> C.APrivateSignKey -> (ByteString, ByteString, NtfCommand e) -> IO (SignedTransmission NtfResponse) signSendRecvNtf h@THandle {thVersion, sessionId} pk (corrId, qId, cmd) = do let t = encodeTransmission thVersion sessionId (CorrId corrId, qId, cmd) Right sig <- runExceptT $ C.sign pk t - Right () <- tPut h (Just sig, t) - tGet h + Right () <- tPut1 h (Just sig, t) + tGet1 h (.->) :: J.Value -> J.Key -> Either String ByteString v .-> key = @@ -132,7 +134,7 @@ testNotificationSubscription (ATransport t) = notifierId `shouldBe` nId send' APNSRespOk -- receive message - Resp "" _ (MSG RcvMessage {msgId = mId1, msgBody = EncRcvMsgBody body}) <- tGet rh + Resp "" _ (MSG RcvMessage {msgId = mId1, msgBody = EncRcvMsgBody body}) <- tGet1 rh Right ClientRcvMsgBody {msgTs = mTs, msgBody} <- pure $ parseAll clientRcvMsgBodyP =<< first show (C.cbDecrypt rcvDhSecret (C.cbNonce mId1) body) mId1 `shouldBe` msgId mTs `shouldBe` msgTs diff --git a/tests/SMPAgentClient.hs b/tests/SMPAgentClient.hs index 9a5bbee..ccadf15 100644 --- a/tests/SMPAgentClient.hs +++ b/tests/SMPAgentClient.hs @@ -158,6 +158,9 @@ smpAgentTest1_1_1 test' = testSMPServer :: SMPServer testSMPServer = "smp://LcJUMfVhwD8yxjAiSaDzzGF3-kLG4Uh0Fl_ZIjrRwjI=@localhost:5001" +testSMPServer2 :: SMPServer +testSMPServer2 = "smp://LcJUMfVhwD8yxjAiSaDzzGF3-kLG4Uh0Fl_ZIjrRwjI=@localhost:5002" + initAgentServers :: InitialAgentServers initAgentServers = InitialAgentServers @@ -165,6 +168,9 @@ initAgentServers = ntf = ["ntf://LcJUMfVhwD8yxjAiSaDzzGF3-kLG4Uh0Fl_ZIjrRwjI=@localhost:6001"] } +initAgentServers2 :: InitialAgentServers +initAgentServers2 = initAgentServers {smp = L.fromList [testSMPServer, testSMPServer2]} + agentCfg :: AgentConfig agentCfg = defaultAgentConfig diff --git a/tests/SMPClient.hs b/tests/SMPClient.hs index 8907e5a..17716cf 100644 --- a/tests/SMPClient.hs +++ b/tests/SMPClient.hs @@ -3,6 +3,7 @@ {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE NumericUnderscores #-} +{-# LANGUAGE OverloadedLists #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeApplications #-} @@ -44,6 +45,9 @@ testKeyHash = "LcJUMfVhwD8yxjAiSaDzzGF3-kLG4Uh0Fl_ZIjrRwjI=" testStoreLogFile :: FilePath testStoreLogFile = "tests/tmp/smp-server-store.log" +testStoreLogFile2 :: FilePath +testStoreLogFile2 = "tests/tmp/smp-server-store.log.2" + testStoreMsgsFile :: FilePath testStoreMsgsFile = "tests/tmp/smp-server-messages.log" @@ -140,10 +144,10 @@ smpServerTest _ t = runSmpTest $ \h -> tPut' h t >> tGet' h where tPut' h (sig, corrId, queueId, smp) = do let t' = smpEncode (sessionId (h :: THandle c), corrId, queueId, smp) - Right () <- tPut h (sig, t') + [Right ()] <- tPut h [(sig, t')] pure () tGet' h = do - (Nothing, _, (CorrId corrId, qId, Right cmd)) <- tGet h + [(Nothing, _, (CorrId corrId, qId, Right cmd))] <- tGet h pure (Nothing, corrId, qId, cmd) smpTest :: Transport c => TProxy c -> (THandle c -> IO ()) -> Expectation diff --git a/tests/ServerTests.hs b/tests/ServerTests.hs index 8c387ee..3ada5f4 100644 --- a/tests/ServerTests.hs +++ b/tests/ServerTests.hs @@ -3,6 +3,7 @@ {-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE NamedFieldPuns #-} +{-# LANGUAGE OverloadedLists #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE ScopedTypeVariables #-} @@ -14,6 +15,7 @@ import Control.Concurrent (ThreadId, killThread, threadDelay) import Control.Concurrent.STM import Control.Exception (SomeException, try) import Control.Monad.Except (forM, forM_, runExceptT) +import Control.Monad.IO.Class import Data.Bifunctor (first) import Data.ByteString.Base64 import Data.ByteString.Char8 (ByteString) @@ -69,15 +71,25 @@ pattern Msg msgId body <- MSG RcvMessage {msgId, msgBody = EncRcvMsgBody body} sendRecv :: forall c p. (Transport c, PartyI p) => THandle c -> (Maybe C.ASignature, ByteString, ByteString, Command p) -> IO (SignedTransmission BrokerMsg) sendRecv h@THandle {thVersion, sessionId} (sgn, corrId, qId, cmd) = do let t = encodeTransmission thVersion sessionId (CorrId corrId, qId, cmd) - Right () <- tPut h (sgn, t) - tGet h + Right () <- tPut1 h (sgn, t) + tGet1 h signSendRecv :: forall c p. (Transport c, PartyI p) => THandle c -> C.APrivateSignKey -> (ByteString, ByteString, Command p) -> IO (SignedTransmission BrokerMsg) signSendRecv h@THandle {thVersion, sessionId} pk (corrId, qId, cmd) = do let t = encodeTransmission thVersion sessionId (CorrId corrId, qId, cmd) Right sig <- runExceptT $ C.sign pk t - Right () <- tPut h (Just sig, t) - tGet h + Right () <- tPut1 h (Just sig, t) + tGet1 h + +tPut1 :: Transport c => THandle c -> SentRawTransmission -> IO (Either TransportError ()) +tPut1 h t = do + [r] <- tPut h [t] + pure r + +tGet1 :: (ProtocolEncoding cmd, Transport c, MonadIO m, MonadFail m) => THandle c -> m (SignedTransmission cmd) +tGet1 h = do + [r] <- tGet h + pure r (#==) :: (HasCallStack, Eq a, Show a) => (a, a) -> String -> Assertion (actual, expected) #== message = assertEqual message expected actual @@ -110,7 +122,7 @@ testCreateSecureV2 _ = (ok1, OK) #== "accepts unsigned SEND" (sId1, sId) #== "same queue ID in response 1" - Resp "" _ (Msg mId1 msg1) <- tGet h + Resp "" _ (Msg mId1 msg1) <- tGet1 h (dec mId1 msg1, Right "hello") #== "delivers message" Resp "cdab" _ ok4 <- signSendRecv h rKey ("cdab", rId, ACK mId1) @@ -140,7 +152,7 @@ testCreateSecureV2 _ = Resp "bcda" _ ok3 <- signSendRecv h sKey ("bcda", sId, _SEND "hello again") (ok3, OK) #== "accepts signed SEND" - Resp "" _ (Msg mId2 msg2) <- tGet h + Resp "" _ (Msg mId2 msg2) <- tGet1 h (dec mId2 msg2, Right "hello again") #== "delivers message 2" Resp "cdab" _ ok5 <- signSendRecv h rKey ("cdab", rId, ACK mId2) @@ -151,7 +163,7 @@ testCreateSecureV2 _ = let maxAllowedMessage = B.replicate maxMessageLength '-' Resp "bcda" _ OK <- signSendRecv h sKey ("bcda", sId, _SEND maxAllowedMessage) - Resp "" _ (Msg mId3 msg3) <- tGet h + Resp "" _ (Msg mId3 msg3) <- tGet1 h (dec mId3 msg3, Right maxAllowedMessage) #== "delivers message of max size" let biggerMessage = B.replicate (maxMessageLength + 1) '-' @@ -172,7 +184,7 @@ testCreateSecure (ATransport t) = (ok1, OK) #== "accepts unsigned SEND" (sId1, sId) #== "same queue ID in response 1" - Resp "" _ (Msg mId1 msg1) <- tGet h + Resp "" _ (Msg mId1 msg1) <- tGet1 h (dec mId1 msg1, Right "hello") #== "delivers message" Resp "cdab" _ ok4 <- signSendRecv h rKey ("cdab", rId, ACK mId1) @@ -202,7 +214,7 @@ testCreateSecure (ATransport t) = Resp "bcda" _ ok3 <- signSendRecv h sKey ("bcda", sId, _SEND "hello again") (ok3, OK) #== "accepts signed SEND" - Resp "" _ (Msg mId2 msg2) <- tGet h + Resp "" _ (Msg mId2 msg2) <- tGet1 h (dec mId2 msg2, Right "hello again") #== "delivers message 2" Resp "cdab" _ ok5 <- signSendRecv h rKey ("cdab", rId, ACK mId2) @@ -213,7 +225,7 @@ testCreateSecure (ATransport t) = let maxAllowedMessage = B.replicate maxMessageLength '-' Resp "bcda" _ OK <- signSendRecv h sKey ("bcda", sId, _SEND maxAllowedMessage) - Resp "" _ (Msg mId3 msg3) <- tGet h + Resp "" _ (Msg mId3 msg3) <- tGet1 h (dec mId3 msg3, Right maxAllowedMessage) #== "delivers message of max size" let biggerMessage = B.replicate (maxMessageLength + 1) '-' @@ -240,7 +252,7 @@ testCreateDelete (ATransport t) = Resp "dabc" _ ok7 <- signSendRecv sh sKey ("dabc", sId, _SEND "hello 2") (ok7, OK) #== "accepts signed SEND 2 - this message is not delivered because the first is not ACKed" - Resp "" _ (Msg mId1 msg1) <- tGet rh + Resp "" _ (Msg mId1 msg1) <- tGet1 rh (dec mId1 msg1, Right "hello") #== "delivers message" Resp "abcd" _ err1 <- sendRecv rh (sampleSig, "abcd", rId, OFF) @@ -296,7 +308,7 @@ stressTest (ATransport t) = smpTest3 t $ \h1 h2 h3 -> do (rPub, rKey) <- C.generateSignatureKeyPair C.SEd25519 (dhPub, _ :: C.PrivateKeyX25519) <- C.generateKeyPair' - rIds <- forM [1 .. 50 :: Int] . const $ do + rIds <- forM ([1 .. 50] :: [Int]) . const $ do Resp "" "" (Ids rId _ _) <- signSendRecv h1 rKey ("", "", NEW rPub dhPub) pure rId let subscribeQueues h = forM_ rIds $ \rId -> do @@ -331,7 +343,7 @@ testDuplex (ATransport t) = Resp "bcda" _ OK <- sendRecv bob ("", "bcda", aSnd, _SEND $ "key " <> strEncode bsPub) -- "key ..." is ad-hoc, not a part of SMP protocol - Resp "" _ (Msg mId1 msg1) <- tGet alice + Resp "" _ (Msg mId1 msg1) <- tGet1 alice Resp "cdab" _ OK <- signSendRecv alice arKey ("cdab", aRcv, ACK mId1) Right ["key", bobKey] <- pure $ B.words <$> aDec mId1 msg1 (bobKey, strEncode bsPub) #== "key received from Bob" @@ -344,7 +356,7 @@ testDuplex (ATransport t) = Resp "bcda" _ OK <- signSendRecv bob bsKey ("bcda", aSnd, _SEND $ "reply_id " <> encode bSnd) -- "reply_id ..." is ad-hoc, not a part of SMP protocol - Resp "" _ (Msg mId2 msg2) <- tGet alice + Resp "" _ (Msg mId2 msg2) <- tGet1 alice Resp "cdab" _ OK <- signSendRecv alice arKey ("cdab", aRcv, ACK mId2) Right ["reply_id", bId] <- pure $ B.words <$> aDec mId2 msg2 (bId, encode bSnd) #== "reply queue ID received from Bob" @@ -353,7 +365,7 @@ testDuplex (ATransport t) = Resp "dabc" _ OK <- sendRecv alice ("", "dabc", bSnd, _SEND $ "key " <> strEncode asPub) -- "key ..." is ad-hoc, not a part of SMP protocol - Resp "" _ (Msg mId3 msg3) <- tGet bob + Resp "" _ (Msg mId3 msg3) <- tGet1 bob Resp "abcd" _ OK <- signSendRecv bob brKey ("abcd", bRcv, ACK mId3) Right ["key", aliceKey] <- pure $ B.words <$> bDec mId3 msg3 (aliceKey, strEncode asPub) #== "key received from Alice" @@ -361,13 +373,13 @@ testDuplex (ATransport t) = Resp "cdab" _ OK <- signSendRecv bob bsKey ("cdab", aSnd, _SEND "hi alice") - Resp "" _ (Msg mId4 msg4) <- tGet alice + Resp "" _ (Msg mId4 msg4) <- tGet1 alice Resp "dabc" _ OK <- signSendRecv alice arKey ("dabc", aRcv, ACK mId4) (aDec mId4 msg4, Right "hi alice") #== "message received from Bob" Resp "abcd" _ OK <- signSendRecv alice asKey ("abcd", bSnd, _SEND "how are you bob") - Resp "" _ (Msg mId5 msg5) <- tGet bob + Resp "" _ (Msg mId5 msg5) <- tGet1 bob Resp "bcda" _ OK <- signSendRecv bob brKey ("bcda", bRcv, ACK mId5) (bDec mId5 msg5, Right "how are you bob") #== "message received from alice" @@ -384,7 +396,7 @@ testSwitchSub (ATransport t) = Resp "cdab" _ ok2 <- sendRecv sh ("", "cdab", sId, _SEND "test2, no ACK") (ok2, OK) #== "sent test message 2" - Resp "" _ (Msg mId1 msg1) <- tGet rh1 + Resp "" _ (Msg mId1 msg1) <- tGet1 rh1 (dec mId1 msg1, Right "test1") #== "test message 1 delivered to the 1st TCP connection" Resp "abcd" _ (Msg mId2 msg2) <- signSendRecv rh1 rKey ("abcd", rId, ACK mId1) (dec mId2 msg2, Right "test2, no ACK") #== "test message 2 delivered, no ACK" @@ -393,12 +405,12 @@ testSwitchSub (ATransport t) = (dec mId2' msg2', Right "test2, no ACK") #== "same simplex queue via another TCP connection, tes2 delivered again (no ACK in 1st queue)" Resp "cdab" _ OK <- signSendRecv rh2 rKey ("cdab", rId, ACK mId2') - Resp "" _ end <- tGet rh1 + Resp "" _ end <- tGet1 rh1 (end, END) #== "unsubscribed the 1st TCP connection" Resp "dabc" _ OK <- sendRecv sh ("", "dabc", sId, _SEND "test3") - Resp "" _ (Msg mId3 msg3) <- tGet rh2 + Resp "" _ (Msg mId3 msg3) <- tGet1 rh2 (dec mId3 msg3, Right "test3") #== "delivered to the 2nd TCP connection" Resp "abcd" _ err <- signSendRecv rh1 rKey ("abcd", rId, ACK mId3) @@ -441,7 +453,7 @@ testGetSubCommands t = Resp "1b" _ OK <- signSendRecv sh sKey ("1b", sId, _SEND "hello 3") Resp "1c" _ OK <- signSendRecv sh sKey ("1c", sId, _SEND "hello 4") -- both get the same if not ACK'd - Resp "" _ (Msg mId1 msg1) <- tGet rh1 + Resp "" _ (Msg mId1 msg1) <- tGet1 rh1 Resp "2" _ (Msg mId1' msg1') <- signSendRecv rh2 rKey ("2", rId, GET) (dec mId1 msg1, Right "hello 1") #== "received from queue via SUB" (dec mId1' msg1', Right "hello 1") #== "retrieved from queue with GET" @@ -503,14 +515,14 @@ testWithStoreLog at@(ATransport t) = writeTVar notifierId nId Resp "dabc" _ OK <- signSendRecv h1 nKey ("dabc", nId, NSUB) Resp "bcda" _ OK <- signSendRecv h sKey1 ("bcda", sId1, _SEND' "hello") - Resp "" _ (Msg mId1 msg1) <- tGet h + Resp "" _ (Msg mId1 msg1) <- tGet1 h (decryptMsgV3 dhShared mId1 msg1, Right "hello") #== "delivered from queue 1" - Resp "" _ (NMSG _ _) <- tGet h1 + Resp "" _ (NMSG _ _) <- tGet1 h1 (sId2, rId2, rKey2, dhShared2) <- createAndSecureQueue h sPub2 atomically $ writeTVar senderId2 sId2 Resp "cdab" _ OK <- signSendRecv h sKey2 ("cdab", sId2, _SEND "hello too") - Resp "" _ (Msg mId2 msg2) <- tGet h + Resp "" _ (Msg mId2 msg2) <- tGet1 h (decryptMsgV3 dhShared2 mId2 msg2, Right "hello too") #== "delivered from queue 2" Resp "dabc" _ OK <- signSendRecv h rKey2 ("dabc", rId2, DEL) @@ -535,7 +547,7 @@ testWithStoreLog at@(ATransport t) = Resp "bcda" _ OK <- signSendRecv h sKey1 ("bcda", sId1, _SEND' "hello") Resp "cdab" _ (Msg mId3 msg3) <- signSendRecv h rKey1 ("cdab", rId1, SUB) (decryptMsgV3 dh1 mId3 msg3, Right "hello") #== "delivered from restored queue" - Resp "" _ (NMSG _ _) <- tGet h1 + Resp "" _ (NMSG _ _) <- tGet1 h1 -- this queue is removed - not restored sId2 <- readTVarIO senderId2 Resp "cdab" _ (ERR AUTH) <- signSendRecv h sKey2 ("cdab", sId2, _SEND "hello too") @@ -576,7 +588,7 @@ testRestoreMessages at@(ATransport t) = writeTVar dhShared $ Just dh writeTVar senderId sId Resp "1" _ OK <- signSendRecv h sKey ("1", sId, _SEND "hello") - Resp "" _ (Msg mId1 msg1) <- tGet h1 + Resp "" _ (Msg mId1 msg1) <- tGet1 h1 Resp "1a" _ OK <- signSendRecv h1 rKey ("1a", rId, ACK mId1) (decryptMsgV3 dh mId1 msg1, Right "hello") #== "message delivered" -- messages below are delivered after server restart @@ -645,7 +657,7 @@ testRestoreMessagesV2 at@(ATransport t) = writeTVar dhShared $ Just dh writeTVar senderId sId Resp "1" _ OK <- signSendRecv h sKey ("1", sId, _SEND "hello") - Resp "" _ (Msg mId1 msg1) <- tGet h1 + Resp "" _ (Msg mId1 msg1) <- tGet1 h1 Resp "1a" _ OK <- signSendRecv h1 rKey ("1a", rId, ACK mId1) (decryptMsgV2 dh mId1 msg1, Right "hello") #== "message delivered" -- messages below are delivered after server restart @@ -710,14 +722,15 @@ testTiming :: ATransport -> Spec testTiming (ATransport t) = it "should have similar time for auth error, whether queue exists or not, for all key sizes" $ smpTest2 t $ \rh sh -> - mapM_ - (testSameTiming rh sh) - [ (32, 32, 200), - (32, 57, 100), - (57, 32, 200), - (57, 57, 100) - ] + mapM_ (testSameTiming rh sh) timingTests where + timingTests :: [(Int, Int, Int)] + timingTests = + [ (32, 32, 200), + (32, 57, 100), + (57, 32, 200), + (57, 57, 100) + ] timeRepeat n = fmap fst . timeItT . forM_ (replicate n ()) . const similarTime t1 t2 = abs (t2 / t1 - 1) < 0.25 `shouldBe` True testSameTiming :: Transport c => THandle c -> THandle c -> (Int, Int, Int) -> Expectation @@ -735,7 +748,7 @@ testTiming (ATransport t) = Resp "dabc" _ OK <- signSendRecv rh rKey ("dabc", rId, KEY sPub) Resp "bcda" _ OK <- signSendRecv sh sKey ("bcda", sId, _SEND "hello") - Resp "" _ (Msg mId msg) <- tGet rh + Resp "" _ (Msg mId msg) <- tGet1 rh (dec mId msg, Right "hello") #== "delivered from queue" runTimingTest sh badKey sId $ _SEND "hello" @@ -774,23 +787,23 @@ testMessageNotifications (ATransport t) = nId' `shouldNotBe` nId Resp "2" _ OK <- signSendRecv nh1 nKey ("2", nId, NSUB) Resp "3" _ OK <- signSendRecv sh sKey ("3", sId, _SEND' "hello") - Resp "" _ (Msg mId1 msg1) <- tGet rh + Resp "" _ (Msg mId1 msg1) <- tGet1 rh (dec mId1 msg1, Right "hello") #== "delivered from queue" Resp "3a" _ OK <- signSendRecv rh rKey ("3a", rId, ACK mId1) - Resp "" _ (NMSG _ _) <- tGet nh1 + Resp "" _ (NMSG _ _) <- tGet1 nh1 Resp "4" _ OK <- signSendRecv nh2 nKey ("4", nId, NSUB) - Resp "" _ END <- tGet nh1 + Resp "" _ END <- tGet1 nh1 Resp "5" _ OK <- signSendRecv sh sKey ("5", sId, _SEND' "hello again") - Resp "" _ (Msg mId2 msg2) <- tGet rh + Resp "" _ (Msg mId2 msg2) <- tGet1 rh Resp "5a" _ OK <- signSendRecv rh rKey ("5a", rId, ACK mId2) (dec mId2 msg2, Right "hello again") #== "delivered from queue again" - Resp "" _ (NMSG _ _) <- tGet nh2 + Resp "" _ (NMSG _ _) <- tGet1 nh2 1000 `timeout` tGet @BrokerMsg nh1 >>= \case Nothing -> pure () Just _ -> error "nothing else should be delivered to the 1st notifier's TCP connection" Resp "6" _ OK <- signSendRecv rh rKey ("6", rId, NDEL) Resp "7" _ OK <- signSendRecv sh sKey ("7", sId, _SEND' "hello there") - Resp "" _ (Msg mId3 msg3) <- tGet rh + Resp "" _ (Msg mId3 msg3) <- tGet1 rh (dec mId3 msg3, Right "hello there") #== "delivered from queue again" 1000 `timeout` tGet @BrokerMsg nh2 >>= \case Nothing -> pure ()