SMP protocol v4: batching multiple server commands/responses in a transport block (#470)

* batch server commands in one transport block

* subscribe to multiple queues using batched commands

* agent method to subscribe to multiple queues using batched commands

* refactor

* test for batched subscriptions

* delete part of connections in batched test

* add resubscribeConnections

* remove comment

* update SMP protocol doc
This commit is contained in:
Evgeny Poberezkin 2022-07-17 10:10:38 +01:00 committed by GitHub
parent 1670c9c05e
commit a6f401041a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
20 changed files with 545 additions and 248 deletions

View File

@ -51,6 +51,8 @@ It's designed with the focus on communication security and integrity, under the
It is designed as a low level protocol for other application protocols to solve the problem of secure and private message transmission, making [MITM attack][1] very difficult at any part of the message transmission system.
This document describes SMP protocol versions 3 and 4, the previous versions are discontinued.
## Introduction
The objective of Simplex Messaging Protocol (SMP) is to facilitate the secure and private unidirectional transfer of messages from senders to recipients via persistent simplex queues managed by the message broker (server).
@ -362,15 +364,16 @@ The clients can optionally instruct a dedicated push notification server to subs
[`SEND` command](#send-message) includes the notification flag to instruct SMP server whether to send the notification - this flag is forwarded to the recepient inside encrypted envelope, together with the timestamp and the message body, so even if TLS is compromised this flag cannot be used for traffic correlation.
## SMP Transmission structure
## SMP Transmission andtransport block structure
Each transport block (SMP transmission) has a fixed size of 16384 bytes for traffic uniformity.
From SMP version 4 each block can contain multiple transmissions, version 3 blocks have 1 transmission.
Some parts of SMP transmission are padded to a fixed size; this padding is uniformly added as a word16 encoded in network byte order - see `paddedString` syntax.
In places where some part of the transmission should be padded, the syntax for `paddedNotation` is used:
```
```abnf
paddedString = originalLength string pad
originalLength = 2*2 OCTET
pad = N*N"#" ; where N = paddedLength - originalLength - 2
@ -380,9 +383,9 @@ paddedNotation = <padded(string, paddedLength)>
; paddedLength - required length after padding, including 2 bytes for originalLength
```
Each transmission between the client and the server must have this format/syntax:
Each transmission/block for SMP v3 between the client and the server must have this format/syntax:
```
```abnf
paddedTransmission = <padded(transmission), 16384>
transmission = [signature] SP signed
signed = sessionIdentifier SP [corrId] SP [queueId] SP smpCommand
@ -399,6 +402,16 @@ encoded = <base64 encoded binary>
`base64` encoding should be used with padding, as defined in section 4 of [RFC 4648][9]
Transport block for SMP v4 has this syntax:
```abnf
paddedTransportBlock = <padded(transportBlock), 16384>
transportBlock = transmissionCount transmissions
transmissionCount = 1*1 OCTET ; equal or greater than 1
transmissions = transmissionLength transmission [transmissions]
transmissionLength = 2*2 OCTET ; word16 encoded in network byte order
```
## SMP commands
Commands syntax below is provided using [ABNF][8] with [case-sensitive strings extension][8a].

View File

@ -45,9 +45,11 @@ module Simplex.Messaging.Agent
acceptContact,
rejectContact,
subscribeConnection,
subscribeConnections,
getConnectionMessage,
getNotificationMessage,
resubscribeConnection,
resubscribeConnections,
sendMessage,
ackMessage,
suspendConnection,
@ -79,6 +81,7 @@ import Data.Composition ((.:), (.:.))
import Data.Functor (($>))
import Data.List.NonEmpty (NonEmpty (..))
import qualified Data.List.NonEmpty as L
import Data.Map.Strict (Map)
import qualified Data.Map.Strict as M
import Data.Maybe (isJust)
import qualified Data.Text as T
@ -105,10 +108,10 @@ import Simplex.Messaging.Parsers (parse)
import Simplex.Messaging.Protocol (BrokerMsg, ErrorType (AUTH), MsgBody, MsgFlags, NtfServer, SMPMsgMeta)
import qualified Simplex.Messaging.Protocol as SMP
import qualified Simplex.Messaging.TMap as TM
import Simplex.Messaging.Util (bshow, eitherToMaybe, liftE, liftError, tryError, unlessM, whenM, ($>>=))
import Simplex.Messaging.Util
import Simplex.Messaging.Version
import System.Random (randomR)
import UnliftIO.Async (async, race_)
import UnliftIO.Async (async, mapConcurrently, race_)
import UnliftIO.Concurrent (forkFinally, forkIO, threadDelay)
import qualified UnliftIO.Exception as E
import UnliftIO.STM
@ -158,6 +161,10 @@ rejectContact c = withAgentEnv c .: rejectContact' c
subscribeConnection :: AgentErrorMonad m => AgentClient -> ConnId -> m ()
subscribeConnection c = withAgentEnv c . subscribeConnection' c
-- | Subscribe to receive connection messages from multiple connections, batching commands when possible
subscribeConnections :: AgentErrorMonad m => AgentClient -> [ConnId] -> m (Map ConnId (Either AgentErrorType ()))
subscribeConnections c = withAgentEnv c . subscribeConnections' c
-- | Get connection message (GET command)
getConnectionMessage :: AgentErrorMonad m => AgentClient -> ConnId -> m (Maybe SMPMsgMeta)
getConnectionMessage c = withAgentEnv c . getConnectionMessage' c
@ -169,6 +176,9 @@ getNotificationMessage c = withAgentEnv c .: getNotificationMessage' c
resubscribeConnection :: AgentErrorMonad m => AgentClient -> ConnId -> m ()
resubscribeConnection c = withAgentEnv c . resubscribeConnection' c
resubscribeConnections :: AgentErrorMonad m => AgentClient -> [ConnId] -> m (Map ConnId (Either AgentErrorType ()))
resubscribeConnections c = withAgentEnv c . resubscribeConnections' c
-- | Send message to the connection (SEND command)
sendMessage :: AgentErrorMonad m => AgentClient -> ConnId -> MsgFlags -> MsgBody -> m AgentMsgId
sendMessage c = withAgentEnv c .:. sendMessage' c
@ -393,12 +403,62 @@ subscribeConnection' c connId =
ns <- asks ntfSupervisor
atomically $ sendNtfSubCommand ns (connId, NSCCreate)
subscribeConnections' :: forall m. AgentMonad m => AgentClient -> [ConnId] -> m (Map ConnId (Either AgentErrorType ()))
subscribeConnections' _ [] = pure M.empty
subscribeConnections' c connIds = do
conns :: Map ConnId (Either StoreError SomeConn) <- M.fromList . zip connIds <$> withStore' c (forM connIds . getConn)
let (errs, cs) = M.mapEither id conns
errs' = M.map (Left . storeError) errs
(sndQs, rcvQs) = M.mapEither rcvOrSndQueue cs
sndRs = M.map (sndSubResult . fst) sndQs
srvRcvQs :: Map SMPServer (Map ConnId (RcvQueue, ConnData)) = M.foldlWithKey' addRcvQueue M.empty rcvQs
mapM_ (mapM_ (uncurry $ resumeMsgDelivery c) . sndQueue) cs
rcvRs <- mapConcurrently subscribe (M.assocs srvRcvQs)
let rs = M.unions $ errs' : sndRs : rcvRs
notifyResultError rs
pure rs
where
rcvOrSndQueue :: SomeConn -> Either (SndQueue, ConnData) (RcvQueue, ConnData)
rcvOrSndQueue = \case
SomeConn _ (DuplexConnection cData rq _) -> Right (rq, cData)
SomeConn _ (SndConnection cData sq) -> Left (sq, cData)
SomeConn _ (RcvConnection cData rq) -> Right (rq, cData)
SomeConn _ (ContactConnection cData rq) -> Right (rq, cData)
sndSubResult :: SndQueue -> Either AgentErrorType ()
sndSubResult sq = case status (sq :: SndQueue) of
Confirmed -> Right ()
Active -> Left $ CONN SIMPLEX
_ -> Left $ INTERNAL "unexpected queue status"
addRcvQueue :: Map SMPServer (Map ConnId (RcvQueue, ConnData)) -> ConnId -> (RcvQueue, ConnData) -> Map SMPServer (Map ConnId (RcvQueue, ConnData))
addRcvQueue m connId rq@(RcvQueue {server}, _) = M.alter (Just . maybe (M.singleton connId rq) (M.insert connId rq)) server m
subscribe :: (SMPServer, Map ConnId (RcvQueue, ConnData)) -> m (Map ConnId (Either AgentErrorType ()))
subscribe (srv, qs) = subscribeQueues c srv (M.map fst qs)
sndQueue :: SomeConn -> Maybe (ConnData, SndQueue)
sndQueue = \case
SomeConn _ (DuplexConnection cData _ sq) -> Just (cData, sq)
SomeConn _ (SndConnection cData sq) -> Just (cData, sq)
_ -> Nothing
notifyResultError :: Map ConnId (Either AgentErrorType ()) -> m ()
notifyResultError rs = do
let actual = M.size rs
expected = length connIds
when (actual /= expected) . atomically $
writeTBQueue (subQ c) ("", "", ERR . INTERNAL $ "subscribeConnections result size: " <> show actual <> ", expected " <> show expected)
resubscribeConnection' :: AgentMonad m => AgentClient -> ConnId -> m ()
resubscribeConnection' c connId =
unlessM
(atomically $ hasActiveSubscription c connId)
(subscribeConnection' c connId)
resubscribeConnections' :: forall m. AgentMonad m => AgentClient -> [ConnId] -> m (Map ConnId (Either AgentErrorType ()))
resubscribeConnections' _ [] = pure M.empty
resubscribeConnections' c connIds = do
let r = M.fromList . zip connIds . repeat $ Right ()
connIds' <- filterM (fmap not . atomically . hasActiveSubscription c) connIds
-- union is left-biased, so results returned by subscribeConnections' take precedence
(`M.union` r) <$> subscribeConnections' c connIds'
getConnectionMessage' :: AgentMonad m => AgentClient -> ConnId -> m (Maybe SMPMsgMeta)
getConnectionMessage' c connId = do
whenM (atomically $ hasActiveSubscription c connId) . throwError $ CMD PROHIBITED

View File

@ -21,6 +21,7 @@ module Simplex.Messaging.Agent.Client
closeAgentClient,
newRcvQueue,
subscribeQueue,
subscribeQueues,
getQueueMessage,
decryptSMPMessage,
addSubscription,
@ -64,6 +65,7 @@ module Simplex.Messaging.Agent.Client
whenSuspending,
withStore,
withStore',
storeError,
)
where
@ -74,16 +76,19 @@ import Control.Logger.Simple
import Control.Monad.Except
import Control.Monad.IO.Unlift
import Control.Monad.Reader
import Data.Bifunctor (first)
import Data.Bifunctor (bimap, first, second)
import Data.ByteString.Base64
import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as B
import Data.Either (partitionEithers)
import Data.List.NonEmpty (NonEmpty)
import qualified Data.List.NonEmpty as L
import Data.Map.Strict (Map)
import qualified Data.Map.Strict as M
import Data.Maybe (catMaybes)
import Data.Set (Set)
import Data.Text.Encoding
import Data.Tuple (swap)
import Data.Word (Word16)
import qualified Database.SQLite.Simple as DB
import Simplex.Messaging.Agent.Env.SQLite
@ -103,7 +108,7 @@ import Simplex.Messaging.Protocol (BrokerMsg, ErrorType, MsgFlags (..), MsgId, N
import qualified Simplex.Messaging.Protocol as SMP
import Simplex.Messaging.TMap (TMap)
import qualified Simplex.Messaging.TMap as TM
import Simplex.Messaging.Util (bshow, catchAll_, ifM, liftEitherError, liftError, tryError, unlessM, whenM)
import Simplex.Messaging.Util
import Simplex.Messaging.Version
import System.Timeout (timeout)
import UnliftIO (async, pooledForConcurrentlyN)
@ -476,14 +481,42 @@ subscribeQueue :: AgentMonad m => AgentClient -> RcvQueue -> ConnId -> m ()
subscribeQueue c rq@RcvQueue {server, rcvPrivateKey, rcvId} connId = do
whenM (atomically . TM.member (server, rcvId) $ getMsgLocks c) . throwError $ CMD PROHIBITED
atomically $ addPendingSubscription c rq connId
withLogClient c server rcvId "SUB" $ \smp -> do
liftIO (runExceptT $ subscribeSMPQueue smp rcvPrivateKey rcvId) >>= \case
Left e -> do
atomically . when (e /= PCENetworkError && e /= PCEResponseTimeout) $
removePendingSubscription c server connId
throwError e
Right _ -> do
addSubscription c rq connId
withLogClient c server rcvId "SUB" $ \smp ->
liftIO (runExceptT (subscribeSMPQueue smp rcvPrivateKey rcvId) >>= processSubResult c rq connId)
>>= either throwError pure
processSubResult :: AgentClient -> RcvQueue -> ConnId -> Either ProtocolClientError () -> IO (Either ProtocolClientError ())
processSubResult c rq@RcvQueue {server} connId r = do
case r of
Left e ->
atomically . when (e /= PCENetworkError && e /= PCEResponseTimeout) $
removePendingSubscription c server connId
_ -> addSubscription c rq connId
pure r
-- | subscribe multiple queues - all passed queues should be on the same server
subscribeQueues :: AgentMonad m => AgentClient -> SMPServer -> Map ConnId RcvQueue -> m (Map ConnId (Either AgentErrorType ()))
subscribeQueues c srv qs = do
(errs, qs_) <- partitionEithers <$> mapM checkQueue (M.assocs qs)
forM_ qs_ $ atomically . uncurry (addPendingSubscription c) . swap
case L.nonEmpty qs_ of
Just qs' -> do
smp_ <- tryError (getSMPServerClient c srv)
M.fromList . (errs <>) <$> case smp_ of
Left e -> pure $ map (second . const $ Left e) qs_
Right smp -> do
logServer "-->" c srv (bshow (length qs_) <> " queues") "SUB"
let qs2 = L.map (queueCreds . snd) qs'
rs' :: [((ConnId, RcvQueue), Either ProtocolClientError ())] <-
liftIO $ zip qs_ . L.toList <$> subscribeSMPQueues smp qs2
forM_ rs' $ \((connId, rq), r) -> liftIO $ processSubResult c rq connId r
pure $ map (bimap fst (first $ protocolClientError SMP)) rs'
_ -> pure $ M.fromList errs
where
checkQueue rq@(connId, RcvQueue {rcvId, server}) = do
prohibited <- atomically . TM.member (server, rcvId) $ getMsgLocks c
pure $ if prohibited || srv /= server then Left (connId, Left $ CMD PROHIBITED) else Right rq
queueCreds RcvQueue {rcvPrivateKey, rcvId} = (rcvPrivateKey, rcvId)
addSubscription :: MonadIO m => AgentClient -> RcvQueue -> ConnId -> m ()
addSubscription c rq@RcvQueue {server} connId = atomically $ do
@ -762,15 +795,16 @@ withStore c action = do
where
handleInternal :: E.SomeException -> IO (Either StoreError a)
handleInternal = pure . Left . SEInternal . bshow
storeError :: StoreError -> AgentErrorType
storeError = \case
SEConnNotFound -> CONN NOT_FOUND
SEConnDuplicate -> CONN DUPLICATE
SEBadConnType CRcv -> CONN SIMPLEX
SEBadConnType CSnd -> CONN SIMPLEX
SEInvitationNotFound -> CMD PROHIBITED
-- this error is never reported as store error,
-- it is used to wrap agent operations when "transaction-like" store access is needed
-- NOTE: network IO should NOT be used inside AgentStoreMonad
SEAgentError e -> e
e -> INTERNAL $ show e
storeError :: StoreError -> AgentErrorType
storeError = \case
SEConnNotFound -> CONN NOT_FOUND
SEConnDuplicate -> CONN DUPLICATE
SEBadConnType CRcv -> CONN SIMPLEX
SEBadConnType CSnd -> CONN SIMPLEX
SEInvitationNotFound -> CMD PROHIBITED
-- this error is never reported as store error,
-- it is used to wrap agent operations when "transaction-like" store access is needed
-- NOTE: network IO should NOT be used inside AgentStoreMonad
SEAgentError e -> e
e -> INTERNAL $ show e

View File

@ -16,8 +16,8 @@ module Simplex.Messaging.Agent.Store.SQLite.Migrations
where
import Control.Monad (forM_)
import Data.Function (on)
import Data.List (intercalate, sortBy)
import Data.Ord (comparing)
import Data.Text (Text)
import Data.Time.Clock (getCurrentTime)
import Database.SQLite.Simple (Connection, Only (..), Query (..))
@ -44,7 +44,7 @@ schemaMigrations =
-- | The list of migrations in ascending order by date
app :: [Migration]
app = sortBy (compare `on` name) $ map migration schemaMigrations
app = sortBy (comparing name) $ map migration schemaMigrations
where
migration (name, query) = Migration {name = name, up = fromQuery query}

View File

@ -6,6 +6,7 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE NumericUnderscores #-}
{-# LANGUAGE OverloadedLists #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
@ -32,11 +33,14 @@ module Simplex.Messaging.Client
-- * SMP protocol command functions
createSMPQueue,
subscribeSMPQueue,
subscribeSMPQueues,
getSMPMessage,
subscribeSMPQueueNotifications,
secureSMPQueue,
enableSMPQueueNotifications,
disableSMPQueueNotifications,
enableSMPQueuesNtfs,
disableSMPQueuesNtfs,
sendSMPMessage,
ackSMPMessage,
suspendSMPQueue,
@ -60,6 +64,10 @@ import Control.Monad.IO.Class (liftIO)
import Control.Monad.Trans.Except
import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as B
import Data.Either (rights)
import Data.Functor (($>))
import Data.List.NonEmpty (NonEmpty)
import qualified Data.List.NonEmpty as L
import Data.Maybe (fromMaybe)
import Network.Socket (ServiceName)
import Numeric.Natural
@ -87,13 +95,16 @@ data ProtocolClient msg = ProtocolClient
tcpTimeout :: Int,
clientCorrId :: TVar Natural,
sentCommands :: TMap CorrId (Request msg),
sndQ :: TBQueue SentRawTransmission,
rcvQ :: TBQueue (SignedTransmission msg),
sndQ :: TBQueue (NonEmpty (SentRawTransmission)),
rcvQ :: TBQueue (NonEmpty (SignedTransmission msg)),
msgQ :: Maybe (TBQueue (ServerTransmission msg))
}
type SMPClient = ProtocolClient SMP.BrokerMsg
-- | Type for client command data
type ClientCommand msg = (Maybe C.APrivateSignKey, QueueId, ProtoCommand msg)
-- | Type synonym for transmission from some SPM server queue.
type ServerTransmission msg = (ProtoServer msg, Version, SessionId, QueueId, msg)
@ -208,13 +219,15 @@ getProtocolClient protocolServer cfg@ProtocolClientConfig {qSize, tcpTimeout, tc
runExceptT $ sendProtocolCommand c Nothing "" protocolPing
process :: ProtocolClient msg -> IO ()
process c@ProtocolClient {rcvQ, sentCommands} = forever $ do
(_, _, (corrId, qId, respOrErr)) <- atomically $ readTBQueue rcvQ
process c = forever $ atomically (readTBQueue $ rcvQ c) >>= mapM_ (processMsg c)
processMsg :: ProtocolClient msg -> SignedTransmission msg -> IO ()
processMsg c@ProtocolClient {sentCommands} (_, _, (corrId, qId, respOrErr)) =
if B.null $ bs corrId
then sendMsg qId respOrErr
then sendMsg respOrErr
else do
atomically (TM.lookup corrId sentCommands) >>= \case
Nothing -> sendMsg qId respOrErr
Nothing -> sendMsg respOrErr
Just Request {queueId, responseVar} -> atomically $ do
TM.delete corrId sentCommands
putTMVar responseVar $
@ -226,8 +239,8 @@ getProtocolClient protocolServer cfg@ProtocolClientConfig {qSize, tcpTimeout, tc
_ -> Right r
else Left . PCEUnexpectedResponse $ bshow respOrErr
where
sendMsg :: QueueId -> Either ErrorType msg -> IO ()
sendMsg qId = \case
sendMsg :: Either ErrorType msg -> IO ()
sendMsg = \case
Right msg -> atomically $ mapM_ (`writeTBQueue` serverTransmission c qId msg) msgQ
-- TODO send everything else to errQ and log in agent
_ -> return ()
@ -285,12 +298,22 @@ subscribeSMPQueue :: SMPClient -> RcvPrivateSignKey -> RecipientId -> ExceptT Pr
subscribeSMPQueue c rpKey rId =
sendSMPCommand c (Just rpKey) rId SUB >>= \case
OK -> return ()
cmd@MSG {} -> writeSMPMessage c rId cmd
cmd@MSG {} -> liftIO $ writeSMPMessage c rId cmd
r -> throwE . PCEUnexpectedResponse $ bshow r
writeSMPMessage :: SMPClient -> RecipientId -> BrokerMsg -> ExceptT ProtocolClientError IO ()
writeSMPMessage c rId msg =
liftIO . atomically $ mapM_ (`writeTBQueue` serverTransmission c rId msg) (msgQ c)
-- | Subscribe to multiple SMP queues batching commands if supported.
subscribeSMPQueues :: SMPClient -> NonEmpty (RcvPrivateSignKey, RecipientId) -> IO (NonEmpty (Either ProtocolClientError ()))
subscribeSMPQueues c qs = sendProtocolCommands c cs >>= mapM response . L.zip qs
where
cs = L.map (\(rpKey, rId) -> (Just rpKey, rId, Cmd SRecipient SUB)) qs
response ((_, rId), r) = case r of
Right OK -> pure $ Right ()
Right cmd@MSG {} -> writeSMPMessage c rId cmd $> Right ()
Right r' -> pure . Left . PCEUnexpectedResponse $ bshow r'
Left e -> pure $ Left e
writeSMPMessage :: SMPClient -> RecipientId -> BrokerMsg -> IO ()
writeSMPMessage c rId msg = atomically $ mapM_ (`writeTBQueue` serverTransmission c rId msg) (msgQ c)
serverTransmission :: ProtocolClient msg -> RecipientId -> msg -> ServerTransmission msg
serverTransmission ProtocolClient {protocolServer, thVersion, sessionId} entityId message =
@ -303,9 +326,7 @@ getSMPMessage :: SMPClient -> RcvPrivateSignKey -> RecipientId -> ExceptT Protoc
getSMPMessage c rpKey rId =
sendSMPCommand c (Just rpKey) rId GET >>= \case
OK -> pure Nothing
cmd@(MSG msg) -> do
writeSMPMessage c rId cmd
pure $ Just msg
cmd@(MSG msg) -> liftIO (writeSMPMessage c rId cmd) $> Just msg
r -> throwE . PCEUnexpectedResponse $ bshow r
-- | Subscribe to the SMP queue notifications.
@ -329,12 +350,32 @@ enableSMPQueueNotifications c rpKey rId notifierKey rcvNtfPublicDhKey =
NID nId rcvNtfSrvPublicDhKey -> pure (nId, rcvNtfSrvPublicDhKey)
r -> throwE . PCEUnexpectedResponse $ bshow r
-- | Enable notifications for the multiple queues for push notifications server.
enableSMPQueuesNtfs :: SMPClient -> NonEmpty (RcvPrivateSignKey, RecipientId, NtfPublicVerifyKey, RcvNtfPublicDhKey) -> IO (NonEmpty (Either ProtocolClientError (NotifierId, RcvNtfPublicDhKey)))
enableSMPQueuesNtfs c qs = L.map response <$> sendProtocolCommands c cs
where
cs = L.map (\(rpKey, rId, notifierKey, rcvNtfPublicDhKey) -> (Just rpKey, rId, Cmd SRecipient $ NKEY notifierKey rcvNtfPublicDhKey)) qs
response = \case
Right (NID nId rcvNtfSrvPublicDhKey) -> Right (nId, rcvNtfSrvPublicDhKey)
Right r -> Left . PCEUnexpectedResponse $ bshow r
Left e -> Left e
-- | Disable notifications for the queue for push notifications server.
--
-- https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#disable-notifications-command
disableSMPQueueNotifications :: SMPClient -> RcvPrivateSignKey -> RecipientId -> ExceptT ProtocolClientError IO ()
disableSMPQueueNotifications = okSMPCommand NDEL
-- | Disable notifications for multiple queues for push notifications server.
disableSMPQueuesNtfs :: SMPClient -> NonEmpty (RcvPrivateSignKey, RecipientId) -> IO (NonEmpty (Either ProtocolClientError ()))
disableSMPQueuesNtfs c qs = L.map response <$> sendProtocolCommands c cs
where
cs = L.map (\(rpKey, rId) -> (Just rpKey, rId, Cmd SRecipient NDEL)) qs
response = \case
Right OK -> Right ()
Right r -> Left . PCEUnexpectedResponse $ bshow r
Left e -> Left e
-- | Send SMP message.
--
-- https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#send-message
@ -351,7 +392,7 @@ ackSMPMessage :: SMPClient -> RcvPrivateSignKey -> QueueId -> MsgId -> ExceptT P
ackSMPMessage c rpKey rId msgId =
sendSMPCommand c (Just rpKey) rId (ACK msgId) >>= \case
OK -> return ()
cmd@MSG {} -> writeSMPMessage c rId cmd
cmd@MSG {} -> liftIO $ writeSMPMessage c rId cmd
r -> throwE . PCEUnexpectedResponse $ bshow r
-- | Irreversibly suspend SMP queue.
@ -377,37 +418,48 @@ okSMPCommand cmd c pKey qId =
sendSMPCommand :: PartyI p => SMPClient -> Maybe C.APrivateSignKey -> QueueId -> Command p -> ExceptT ProtocolClientError IO BrokerMsg
sendSMPCommand c pKey qId cmd = sendProtocolCommand c pKey qId (Cmd sParty cmd)
-- | Send multiple commands with batching and collect responses
sendProtocolCommands :: forall msg. ProtocolEncoding (ProtoCommand msg) => ProtocolClient msg -> NonEmpty (ClientCommand msg) -> IO (NonEmpty (Either ProtocolClientError msg))
sendProtocolCommands c@ProtocolClient {sndQ, tcpTimeout} cs = do
ts <- mapM (runExceptT . mkTransmission c) cs
mapM_ (atomically . writeTBQueue sndQ . L.map fst) . L.nonEmpty . rights $ L.toList ts
forConcurrently ts $ \case
Right (_, r) -> withTimeout . atomically $ takeTMVar r
Left e -> pure $ Left e
where
withTimeout a = fromMaybe (Left PCEResponseTimeout) <$> timeout tcpTimeout a
-- | Send Protocol command
sendProtocolCommand :: forall msg. ProtocolEncoding (ProtoCommand msg) => ProtocolClient msg -> Maybe C.APrivateSignKey -> QueueId -> ProtoCommand msg -> ExceptT ProtocolClientError IO msg
sendProtocolCommand ProtocolClient {sndQ, sentCommands, clientCorrId, sessionId, thVersion, tcpTimeout} pKey qId cmd = do
corrId <- lift_ getNextCorrId
t <- signTransmission $ encodeTransmission thVersion sessionId (corrId, qId, cmd)
ExceptT $ sendRecv corrId t
sendProtocolCommand c@ProtocolClient {sndQ, tcpTimeout} pKey qId cmd = do
(t, r) <- mkTransmission c (pKey, qId, cmd)
ExceptT $ sendRecv t r
where
lift_ :: STM a -> ExceptT ProtocolClientError IO a
lift_ action = ExceptT $ Right <$> atomically action
-- two separate "atomically" needed to avoid blocking
sendRecv :: SentRawTransmission -> TMVar (Response msg) -> IO (Response msg)
sendRecv t r = atomically (writeTBQueue sndQ [t]) >> withTimeout (atomically $ takeTMVar r)
where
withTimeout a = fromMaybe (Left PCEResponseTimeout) <$> timeout tcpTimeout a
mkTransmission :: forall msg. ProtocolEncoding (ProtoCommand msg) => ProtocolClient msg -> ClientCommand msg -> ExceptT ProtocolClientError IO (SentRawTransmission, TMVar (Response msg))
mkTransmission ProtocolClient {clientCorrId, sessionId, thVersion, sentCommands} (pKey, qId, cmd) = do
corrId <- liftIO $ atomically getNextCorrId
t <- signTransmission $ encodeTransmission thVersion sessionId (corrId, qId, cmd)
r <- liftIO . atomically $ mkRequest corrId
pure (t, r)
where
getNextCorrId :: STM CorrId
getNextCorrId = do
i <- stateTVar clientCorrId $ \i -> (i, i + 1)
pure . CorrId $ bshow i
signTransmission :: ByteString -> ExceptT ProtocolClientError IO SentRawTransmission
signTransmission t = case pKey of
Nothing -> return (Nothing, t)
Nothing -> pure (Nothing, t)
Just pk -> do
sig <- liftError PCESignatureError $ C.sign pk t
return (Just sig, t)
-- two separate "atomically" needed to avoid blocking
sendRecv :: CorrId -> SentRawTransmission -> IO (Response msg)
sendRecv corrId t = atomically (send corrId t) >>= withTimeout . atomically . takeTMVar
where
withTimeout a = fromMaybe (Left PCEResponseTimeout) <$> timeout tcpTimeout a
send :: CorrId -> SentRawTransmission -> STM (TMVar (Response msg))
send corrId t = do
mkRequest :: CorrId -> STM (TMVar (Response msg))
mkRequest corrId = do
r <- newEmptyTMVar
TM.insert corrId (Request qId r) sentCommands
writeTBQueue sndQ t
return r
pure r

View File

@ -13,6 +13,7 @@ module Simplex.Messaging.Encoding
Large (..),
smpEncodeList,
smpListP,
lenEncode,
)
where

View File

@ -4,6 +4,7 @@
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE OverloadedLists #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-}
@ -248,22 +249,23 @@ clientDisconnected NtfServerClient {connected} = atomically $ writeTVar connecte
receive :: (Transport c, MonadUnliftIO m, MonadReader NtfEnv m) => THandle c -> NtfServerClient -> m ()
receive th NtfServerClient {rcvQ, sndQ, activeAt} = forever $ do
t@(_, _, (corrId, entId, cmdOrError)) <- tGet th
atomically . writeTVar activeAt =<< liftIO getSystemTime
logDebug "received transmission"
case cmdOrError of
Left e -> write sndQ (corrId, entId, NRErr e)
Right cmd ->
verifyNtfTransmission t cmd >>= \case
VRVerified req -> write rcvQ req
VRFailed -> write sndQ (corrId, entId, NRErr AUTH)
ts <- tGet th
forM_ ts $ \t@(_, _, (corrId, entId, cmdOrError)) -> do
atomically . writeTVar activeAt =<< liftIO getSystemTime
logDebug "received transmission"
case cmdOrError of
Left e -> write sndQ (corrId, entId, NRErr e)
Right cmd ->
verifyNtfTransmission t cmd >>= \case
VRVerified req -> write rcvQ req
VRFailed -> write sndQ (corrId, entId, NRErr AUTH)
where
write q t = atomically $ writeTBQueue q t
send :: (Transport c, MonadUnliftIO m) => THandle c -> NtfServerClient -> m ()
send h@THandle {thVersion = v} NtfServerClient {sndQ, sessionId, activeAt} = forever $ do
t <- atomically $ readTBQueue sndQ
void . liftIO $ tPut h (Nothing, encodeTransmission v sessionId t)
void . liftIO $ tPut h [(Nothing, encodeTransmission v sessionId t)]
atomically . writeTVar activeAt =<< liftIO getSystemTime
-- instance Show a => Show (TVar a) where

View File

@ -69,4 +69,4 @@ ntfClientHandshake c keyHash ntfVRange = do
Nothing -> throwError $ TEHandshake VERSION
ntfTHandle :: Transport c => c -> THandle c
ntfTHandle c = THandle {connection = c, sessionId = tlsUnique c, blockSize = ntfBlockSize, thVersion = 0}
ntfTHandle c = THandle {connection = c, sessionId = tlsUnique c, blockSize = ntfBlockSize, thVersion = 0, batch = False}

View File

@ -7,6 +7,7 @@
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE OverloadedLists #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE PolyKinds #-}
@ -130,6 +131,8 @@ import qualified Data.Attoparsec.ByteString.Char8 as A
import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as B
import Data.Kind
import Data.List.NonEmpty (NonEmpty (..))
import qualified Data.List.NonEmpty as L
import Data.Maybe (isNothing)
import Data.String
import Data.Time.Clock.System (SystemTime (..))
@ -1010,20 +1013,51 @@ instance Encoding CommandError where
_ -> fail "bad command error type"
-- | Send signed SMP transmission to TCP transport.
tPut :: Transport c => THandle c -> SentRawTransmission -> IO (Either TransportError ())
tPut th (sig, t) = tPutBlock th $ smpEncode (C.signatureBytes sig) <> t
tPut :: Transport c => THandle c -> NonEmpty (SentRawTransmission) -> IO (NonEmpty (Either TransportError ()))
tPut th trs
| batch th = tPutBatch [] $ L.map tEncode trs
| otherwise = forM trs $ tPutBlock th . tEncode
where
tPutBatch :: [Either TransportError ()] -> NonEmpty ByteString -> IO (NonEmpty (Either TransportError ()))
tPutBatch rs ts = do
let (n, s, ts_) = encodeBatch 0 "" ts
r <- if n == 0 then pure [Left TELargeMsg] else replicate n <$> tPutBlock th (lenEncode n `B.cons` s)
let rs' = rs <> r
case ts_ of
Just ts' -> tPutBatch rs' ts'
_ -> pure $ L.fromList rs'
encodeBatch :: Int -> ByteString -> NonEmpty ByteString -> (Int, ByteString, Maybe (NonEmpty ByteString))
encodeBatch n s ts@(t :| ts_)
| n == 255 = (n, s, Just ts)
| otherwise =
let s' = s <> smpEncode (Large t)
n' = n + 1
in if B.length s' > blockSize th - 1
then (n,s,) $ if n == 0 then L.nonEmpty ts_ else Just ts
else case L.nonEmpty ts_ of
Just ts' -> encodeBatch n' s' ts'
_ -> (n', s', Nothing)
tEncode (sig, tr) = smpEncode (C.signatureBytes sig) <> tr
encodeTransmission :: ProtocolEncoding c => Version -> ByteString -> Transmission c -> ByteString
encodeTransmission v sessionId (CorrId corrId, queueId, command) =
smpEncode (sessionId, corrId, queueId) <> encodeProtocol v command
-- | Receive and parse transmission from the TCP transport (ignoring any trailing padding).
tGetParse :: Transport c => THandle c -> IO (Either TransportError RawTransmission)
tGetParse th = (parse transmissionP TEBadBlock =<<) <$> tGetBlock th
tGetParse :: Transport c => THandle c -> IO (NonEmpty (Either TransportError RawTransmission))
tGetParse th
| batch th = either ((:| []) . Left) id <$> runExceptT getBatch
| otherwise = (:| []) . (parse transmissionP TEBadBlock =<<) <$> tGetBlock th
where
getBatch :: ExceptT TransportError IO (NonEmpty (Either TransportError RawTransmission))
getBatch = do
s <- ExceptT $ tGetBlock th
ts <- liftEither $ parse smpP TEBadBlock s
pure $ L.map (\(Large t) -> parse transmissionP TEBadBlock t) ts
-- | Receive client and server transmissions (determined by `cmd` type).
tGet :: forall cmd c m. (ProtocolEncoding cmd, Transport c, MonadIO m) => THandle c -> m (SignedTransmission cmd)
tGet th@THandle {sessionId, thVersion = v} = liftIO (tGetParse th) >>= decodeParseValidate
tGet :: forall cmd c m. (ProtocolEncoding cmd, Transport c, MonadIO m) => THandle c -> m (NonEmpty (SignedTransmission cmd))
tGet th@THandle {sessionId, thVersion = v} = liftIO (tGetParse th) >>= mapM decodeParseValidate
where
decodeParseValidate :: Either TransportError RawTransmission -> m (SignedTransmission cmd)
decodeParseValidate = \case

View File

@ -6,6 +6,7 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE NumericUnderscores #-}
{-# LANGUAGE OverloadedLists #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE RankNTypes #-}
@ -43,9 +44,10 @@ import Crypto.Random
import Data.Bifunctor (first)
import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as B
import Data.Either (fromRight)
import Data.Either (fromRight, partitionEithers)
import Data.Functor (($>))
import Data.List (intercalate)
import qualified Data.List.NonEmpty as L
import qualified Data.Map.Strict as M
import Data.Maybe (isNothing)
import Data.Set (Set)
@ -145,7 +147,7 @@ smpServer started = do
endPreviousSubscriptions :: (QueueId, Client) -> m (Maybe s)
endPreviousSubscriptions (qId, c) = do
void . forkIO . atomically $
writeTBQueue (sndQ c) (CorrId "", qId, END)
writeTBQueue (sndQ c) [(CorrId "", qId, END)]
atomically $ TM.lookupDelete qId (clientSubs c)
expireMessagesThread_ :: ServerConfig -> [m ()]
@ -243,26 +245,36 @@ cancelSub sub =
Sub {subThread = SubThread t} -> liftIO $ deRefWeak t >>= mapM_ killThread
_ -> return ()
receive :: (Transport c, MonadUnliftIO m, MonadReader Env m) => THandle c -> Client -> m ()
receive :: forall c m. (Transport c, MonadUnliftIO m, MonadReader Env m) => THandle c -> Client -> m ()
receive th Client {rcvQ, sndQ, activeAt} = forever $ do
(sig, signed, (corrId, queueId, cmdOrError)) <- tGet th
ts <- L.toList <$> tGet th
atomically . writeTVar activeAt =<< liftIO getSystemTime
case cmdOrError of
Left e -> write sndQ (corrId, queueId, ERR e)
Right cmd -> do
verified <- verifyTransmission sig signed queueId cmd
if verified
then write rcvQ (corrId, queueId, cmd)
else write sndQ (corrId, queueId, ERR AUTH)
as <- partitionEithers <$> mapM cmdAction ts
write sndQ $ fst as
write rcvQ $ snd as
where
write q t = atomically $ writeTBQueue q t
cmdAction :: SignedTransmission Cmd -> m (Either (Transmission BrokerMsg) (Maybe QueueRec, Transmission Cmd))
cmdAction (sig, signed, (corrId, queueId, cmdOrError)) =
case cmdOrError of
Left e -> pure $ Left (corrId, queueId, ERR e)
Right cmd -> verified <$> verifyTransmission sig signed queueId cmd
where
verified = \case
VRVerified qr -> Right (qr, (corrId, queueId, cmd))
VRFailed -> Left (corrId, queueId, ERR AUTH)
write q = mapM_ (atomically . writeTBQueue q) . L.nonEmpty
send :: (Transport c, MonadUnliftIO m) => THandle c -> Client -> m ()
send h@THandle {thVersion = v} Client {sndQ, sessionId, activeAt} = forever $ do
t <- atomically $ readTBQueue sndQ
-- TODO the line below can return Left, but we ignore it and do not disconnect the client
void . liftIO $ tPut h (Nothing, encodeTransmission v sessionId t)
ts <- atomically $ L.sortWith tOrder <$> readTBQueue sndQ
-- TODO the line below can return Lefts, but we ignore it and do not disconnect the client
void . liftIO . tPut h $ L.map ((Nothing,) . encodeTransmission v sessionId) ts
atomically . writeTVar activeAt =<< liftIO getSystemTime
where
tOrder :: Transmission BrokerMsg -> Int
tOrder (_, _, cmd) = case cmd of
MSG {} -> 0
_ -> 1
disconnectTransport :: (Transport c, MonadUnliftIO m) => THandle c -> client -> (client -> TVar SystemTime) -> ExpirationConfig -> m ()
disconnectTransport THandle {connection} c activeAt expCfg = do
@ -273,23 +285,28 @@ disconnectTransport THandle {connection} c activeAt expCfg = do
ts <- readTVarIO $ activeAt c
when (systemSeconds ts < old) $ closeConnection connection
data VerificationResult = VRVerified (Maybe QueueRec) | VRFailed
verifyTransmission ::
forall m. (MonadUnliftIO m, MonadReader Env m) => Maybe C.ASignature -> ByteString -> QueueId -> Cmd -> m Bool
forall m. (MonadUnliftIO m, MonadReader Env m) => Maybe C.ASignature -> ByteString -> QueueId -> Cmd -> m VerificationResult
verifyTransmission sig_ signed queueId cmd = do
case cmd of
Cmd SRecipient (NEW k _) -> pure $ verifyCmdSignature sig_ signed k
Cmd SRecipient (NEW k _) -> pure $ Nothing `verified` verifyCmdSignature sig_ signed k
Cmd SRecipient _ -> verifyCmd SRecipient $ verifyCmdSignature sig_ signed . recipientKey
Cmd SSender SEND {} -> verifyCmd SSender $ verifyMaybe . senderKey
Cmd SSender PING -> pure True
Cmd SSender PING -> pure $ VRVerified Nothing
Cmd SNotifier NSUB -> verifyCmd SNotifier $ verifyMaybe . fmap notifierKey . notifier
where
verifyCmd :: SParty p -> (QueueRec -> Bool) -> m Bool
verifyCmd :: SParty p -> (QueueRec -> Bool) -> m VerificationResult
verifyCmd party f = do
st <- asks queueStore
q <- atomically $ getQueue st party queueId
pure $ either (const $ maybe False (dummyVerifyCmd signed) sig_ `seq` False) f q
q_ <- atomically (getQueue st party queueId)
pure $ case q_ of
Right q -> Just q `verified` f q
_ -> maybe False (dummyVerifyCmd signed) sig_ `seq` VRFailed
verifyMaybe :: Maybe C.APublicVerifyKey -> Bool
verifyMaybe = maybe (isNothing sig_) $ verifyCmdSignature sig_ signed
verified q cond = if cond then VRVerified q else VRFailed
verifyCmdSignature :: Maybe C.ASignature -> ByteString -> C.APublicVerifyKey -> Bool
verifyCmdSignature sig_ signed key = maybe False (verify key) sig_
@ -320,16 +337,16 @@ client :: forall m. (MonadUnliftIO m, MonadReader Env m) => Client -> Server ->
client clnt@Client {thVersion, subscriptions, ntfSubscriptions, rcvQ, sndQ} Server {subscribedQ, ntfSubscribedQ, notifiers} =
forever $
atomically (readTBQueue rcvQ)
>>= processCommand
>>= mapM processCommand
>>= atomically . writeTBQueue sndQ
where
processCommand :: Transmission Cmd -> m (Transmission BrokerMsg)
processCommand (corrId, queueId, cmd) = do
processCommand :: (Maybe QueueRec, Transmission Cmd) -> m (Transmission BrokerMsg)
processCommand (qr_, (corrId, queueId, cmd)) = do
st <- asks queueStore
case cmd of
Cmd SSender command ->
case command of
SEND flags msgBody -> sendMessage st flags msgBody
SEND flags msgBody -> withQueue $ \qr -> sendMessage qr flags msgBody
PING -> pure (corrId, "", PONG)
Cmd SNotifier NSUB -> subscribeNotifications
Cmd SRecipient command ->
@ -339,9 +356,9 @@ client clnt@Client {thVersion, subscriptions, ntfSubscriptions, rcvQ, sndQ} Serv
(asks $ allowNewQueues . config)
(createQueue st rKey dhKey)
(pure (corrId, queueId, ERR AUTH))
SUB -> subscribeQueue st queueId
GET -> getMessage st
ACK msgId -> acknowledgeMsg st msgId
SUB -> withQueue (`subscribeQueue` queueId)
GET -> withQueue getMessage
ACK msgId -> withQueue (`acknowledgeMsg` msgId)
KEY sKey -> secureQueue_ st sKey
NKEY nKey dhKey -> addQueueNotifier_ st nKey dhKey
NDEL -> deleteQueueNotifier_ st
@ -371,14 +388,15 @@ client clnt@Client {thVersion, subscriptions, ntfSubscriptions, rcvQ, sndQ} Serv
addQueueRetry n qik qRec = do
ids@(rId, _) <- getIds
-- create QueueRec record with these ids and keys
atomically (addQueue st $ qRec ids) >>= \case
let qr = qRec ids
atomically (addQueue st qr) >>= \case
Left DUPLICATE_ -> addQueueRetry (n - 1) qik qRec
Left e -> pure $ ERR e
Right _ -> do
withLog (`logCreateById` rId)
stats <- asks serverStats
atomically $ modifyTVar (qCreated stats) (+ 1)
subscribeQueue st rId $> IDS (qik ids)
subscribeQueue qr rId $> IDS (qik ids)
logCreateById :: StoreLog 'WriteMode -> RecipientId -> IO ()
logCreateById s rId =
@ -426,8 +444,8 @@ client clnt@Client {thVersion, subscriptions, ntfSubscriptions, rcvQ, sndQ} Serv
withLog (`logDeleteQueue` queueId)
okResp <$> atomically (suspendQueue st queueId)
subscribeQueue :: QueueStore -> RecipientId -> m (Transmission BrokerMsg)
subscribeQueue st rId =
subscribeQueue :: QueueRec -> RecipientId -> m (Transmission BrokerMsg)
subscribeQueue qr rId =
atomically (TM.lookup rId subscriptions) >>= \case
Nothing ->
atomically newSub >>= deliver
@ -449,10 +467,10 @@ client clnt@Client {thVersion, subscriptions, ntfSubscriptions, rcvQ, sndQ} Serv
deliver sub = do
q <- getStoreMsgQueue rId
msg_ <- atomically $ tryPeekMsg q
deliverMessage st rId sub q msg_
deliverMessage qr rId sub q msg_
getMessage :: QueueStore -> m (Transmission BrokerMsg)
getMessage st =
getMessage :: QueueRec -> m (Transmission BrokerMsg)
getMessage qr =
atomically (TM.lookup queueId subscriptions) >>= \case
Nothing ->
atomically newSub >>= getMessage_
@ -471,7 +489,7 @@ client clnt@Client {thVersion, subscriptions, ntfSubscriptions, rcvQ, sndQ} Serv
TM.insert queueId sub subscriptions
pure s
getMessage_ :: Sub -> m (Transmission BrokerMsg)
getMessage_ s = withRcvQueue st queueId $ \qr -> do
getMessage_ s = do
q <- getStoreMsgQueue queueId
atomically $
tryPeekMsg q >>= \case
@ -480,11 +498,8 @@ client clnt@Client {thVersion, subscriptions, ntfSubscriptions, rcvQ, sndQ} Serv
in setDelivered s msg $> (corrId, queueId, MSG encMsg)
_ -> pure (corrId, queueId, OK)
withRcvQueue :: QueueStore -> RecipientId -> (QueueRec -> m (Transmission BrokerMsg)) -> m (Transmission BrokerMsg)
withRcvQueue st rId action =
atomically (getQueue st SRecipient rId) >>= \case
Left e -> pure (corrId, rId, ERR e)
Right qr -> action qr
withQueue :: (QueueRec -> m (Transmission BrokerMsg)) -> m (Transmission BrokerMsg)
withQueue action = maybe (pure $ err AUTH) action qr_
subscribeNotifications :: m (Transmission BrokerMsg)
subscribeNotifications = atomically $ do
@ -493,8 +508,8 @@ client clnt@Client {thVersion, subscriptions, ntfSubscriptions, rcvQ, sndQ} Serv
TM.insert queueId () ntfSubscriptions
pure ok
acknowledgeMsg :: QueueStore -> MsgId -> m (Transmission BrokerMsg)
acknowledgeMsg st msgId = do
acknowledgeMsg :: QueueRec -> MsgId -> m (Transmission BrokerMsg)
acknowledgeMsg qr msgId = do
atomically (TM.lookup queueId subscriptions) >>= \case
Nothing -> pure $ err NO_MSG
Just sub ->
@ -509,7 +524,7 @@ client clnt@Client {thVersion, subscriptions, ntfSubscriptions, rcvQ, sndQ} Serv
_ -> do
(msgDeleted, msg_) <- atomically $ tryDelPeekMsg q msgId
when msgDeleted updateStats
deliverMessage st queueId sub q msg_
deliverMessage qr queueId sub q msg_
_ -> pure $ err NO_MSG
where
getDelivered :: TVar Sub -> STM (Maybe Sub)
@ -533,74 +548,69 @@ client clnt@Client {thVersion, subscriptions, ntfSubscriptions, rcvQ, sndQ} Serv
where
updatePeriod pSel = modifyTVar (pSel stats) (S.insert qId)
sendMessage :: QueueStore -> MsgFlags -> MsgBody -> m (Transmission BrokerMsg)
sendMessage st msgFlags msgBody
sendMessage :: QueueRec -> MsgFlags -> MsgBody -> m (Transmission BrokerMsg)
sendMessage qr msgFlags msgBody
| B.length msgBody > maxMessageLength = pure $ err LARGE_MSG
| otherwise = do
qr <- atomically $ getQueue st SSender queueId
either (return . err) storeMessage qr
| otherwise = case status qr of
QueueOff -> return $ err AUTH
QueueActive ->
mapM mkMessage (C.maxLenBS msgBody) >>= \case
Left _ -> pure $ err LARGE_MSG
Right msg -> do
ms <- asks msgStore
ServerConfig {messageExpiration, msgQueueQuota} <- asks config
old <- liftIO $ mapM expireBeforeEpoch messageExpiration
ntfNonceDrg <- asks idsDrg
resp@(_, _, sent) <- atomically $ do
q <- getMsgQueue ms (recipientId qr) msgQueueQuota
mapM_ (deleteExpiredMsgs q) old
ifM (isFull q) (pure $ err QUOTA) $ do
when (notification msgFlags) $ trySendNotification msg ntfNonceDrg
writeMsg q msg
pure ok
when (sent == OK) $ do
stats <- asks serverStats
atomically $ modifyTVar (msgSent stats) (+ 1)
atomically $ updateActiveQueues stats $ recipientId qr
pure resp
where
storeMessage :: QueueRec -> m (Transmission BrokerMsg)
storeMessage qr = case status qr of
QueueOff -> return $ err AUTH
QueueActive ->
mapM mkMessage (C.maxLenBS msgBody) >>= \case
Left _ -> pure $ err LARGE_MSG
Right msg -> do
ms <- asks msgStore
ServerConfig {messageExpiration, msgQueueQuota} <- asks config
old <- liftIO $ mapM expireBeforeEpoch messageExpiration
ntfNonceDrg <- asks idsDrg
resp@(_, _, sent) <- atomically $ do
q <- getMsgQueue ms (recipientId qr) msgQueueQuota
mapM_ (deleteExpiredMsgs q) old
ifM (isFull q) (pure $ err QUOTA) $ do
when (notification msgFlags) $ trySendNotification msg ntfNonceDrg
writeMsg q msg
pure ok
when (sent == OK) $ do
stats <- asks serverStats
atomically $ modifyTVar (msgSent stats) (+ 1)
atomically $ updateActiveQueues stats $ recipientId qr
pure resp
where
mkMessage :: C.MaxLenBS MaxMessageLen -> m Message
mkMessage body = do
msgId <- randomId =<< asks (msgIdBytes . config)
msgTs <- liftIO getSystemTime
pure $ Message msgId msgTs msgFlags body
mkMessage :: C.MaxLenBS MaxMessageLen -> m Message
mkMessage body = do
msgId <- randomId =<< asks (msgIdBytes . config)
msgTs <- liftIO getSystemTime
pure $ Message msgId msgTs msgFlags body
trySendNotification :: Message -> TVar ChaChaDRG -> STM ()
trySendNotification msg ntfNonceDrg =
forM_ (notifier qr) $ \NtfCreds {notifierId, rcvNtfDhSecret} ->
mapM_ (writeNtf notifierId msg rcvNtfDhSecret ntfNonceDrg) =<< TM.lookup notifierId notifiers
trySendNotification :: Message -> TVar ChaChaDRG -> STM ()
trySendNotification msg ntfNonceDrg =
forM_ (notifier qr) $ \NtfCreds {notifierId, rcvNtfDhSecret} ->
mapM_ (writeNtf notifierId msg rcvNtfDhSecret ntfNonceDrg) =<< TM.lookup notifierId notifiers
writeNtf :: NotifierId -> Message -> RcvNtfDhSecret -> TVar ChaChaDRG -> Client -> STM ()
writeNtf nId msg rcvNtfDhSecret ntfNonceDrg Client {sndQ = q} =
unlessM (isFullTBQueue sndQ) $ do
(nmsgNonce, encNMsgMeta) <- mkMessageNotification msg rcvNtfDhSecret ntfNonceDrg
writeTBQueue q (CorrId "", nId, NMSG nmsgNonce encNMsgMeta)
writeNtf :: NotifierId -> Message -> RcvNtfDhSecret -> TVar ChaChaDRG -> Client -> STM ()
writeNtf nId msg rcvNtfDhSecret ntfNonceDrg Client {sndQ = q} =
unlessM (isFullTBQueue sndQ) $ do
(nmsgNonce, encNMsgMeta) <- mkMessageNotification msg rcvNtfDhSecret ntfNonceDrg
writeTBQueue q [(CorrId "", nId, NMSG nmsgNonce encNMsgMeta)]
mkMessageNotification :: Message -> RcvNtfDhSecret -> TVar ChaChaDRG -> STM (C.CbNonce, EncNMsgMeta)
mkMessageNotification Message {msgId, msgTs} rcvNtfDhSecret ntfNonceDrg = do
cbNonce <- C.pseudoRandomCbNonce ntfNonceDrg
let msgMeta = NMsgMeta {msgId, msgTs}
encNMsgMeta = C.cbEncrypt rcvNtfDhSecret cbNonce (smpEncode msgMeta) 128
pure . (cbNonce,) $ fromRight "" encNMsgMeta
mkMessageNotification :: Message -> RcvNtfDhSecret -> TVar ChaChaDRG -> STM (C.CbNonce, EncNMsgMeta)
mkMessageNotification Message {msgId, msgTs} rcvNtfDhSecret ntfNonceDrg = do
cbNonce <- C.pseudoRandomCbNonce ntfNonceDrg
let msgMeta = NMsgMeta {msgId, msgTs}
encNMsgMeta = C.cbEncrypt rcvNtfDhSecret cbNonce (smpEncode msgMeta) 128
pure . (cbNonce,) $ fromRight "" encNMsgMeta
deliverMessage :: QueueStore -> RecipientId -> TVar Sub -> MsgQueue -> Maybe Message -> m (Transmission BrokerMsg)
deliverMessage st rId sub q msg_ = withRcvQueue st rId $ \qr -> do
deliverMessage :: QueueRec -> RecipientId -> TVar Sub -> MsgQueue -> Maybe Message -> m (Transmission BrokerMsg)
deliverMessage qr rId sub q msg_ = do
readTVarIO sub >>= \case
s@Sub {subThread = NoSub} ->
case msg_ of
Just msg ->
let encMsg = encryptMsg qr msg
in atomically (setDelivered s msg) $> (corrId, rId, MSG encMsg)
_ -> forkSub qr $> ok
_ -> forkSub $> ok
_ -> pure ok
where
forkSub :: QueueRec -> m ()
forkSub qr = do
forkSub :: m ()
forkSub = do
atomically . modifyTVar sub $ \s -> s {subThread = SubPending}
t <- mkWeakThreadId =<< forkIO subscriber
atomically . modifyTVar sub $ \case
@ -610,7 +620,7 @@ client clnt@Client {thVersion, subscriptions, ntfSubscriptions, rcvQ, sndQ} Serv
subscriber = atomically $ do
msg <- peekMsg q
let encMsg = encryptMsg qr msg
writeTBQueue sndQ (CorrId "", rId, MSG encMsg)
writeTBQueue sndQ [(CorrId "", rId, MSG encMsg)]
s <- readTVar sub
void $ setDelivered s msg
writeTVar sub s {subThread = NoSub}

View File

@ -9,6 +9,7 @@ import Control.Concurrent (ThreadId)
import Control.Monad.IO.Unlift
import Crypto.Random
import Data.ByteString.Char8 (ByteString)
import Data.List.NonEmpty (NonEmpty)
import Data.Map.Strict (Map)
import qualified Data.Map.Strict as M
import Data.Time.Clock (getCurrentTime)
@ -103,8 +104,8 @@ data Server = Server
data Client = Client
{ subscriptions :: TMap RecipientId (TVar Sub),
ntfSubscriptions :: TMap NotifierId (),
rcvQ :: TBQueue (Transmission Cmd),
sndQ :: TBQueue (Transmission BrokerMsg),
rcvQ :: TBQueue (NonEmpty (Maybe QueueRec, Transmission Cmd)),
sndQ :: TBQueue (NonEmpty (Transmission BrokerMsg)),
thVersion :: Version,
sessionId :: ByteString,
connected :: TVar Bool,

View File

@ -96,7 +96,7 @@ smpBlockSize :: Int
smpBlockSize = 16384
supportedSMPServerVRange :: VersionRange
supportedSMPServerVRange = mkVersionRange 1 3
supportedSMPServerVRange = mkVersionRange 1 4
simplexMQVersion :: String
simplexMQVersion = "3.0.1"
@ -258,7 +258,10 @@ data THandle c = THandle
sessionId :: SessionId,
blockSize :: Int,
-- | agreed server protocol version
thVersion :: Version
thVersion :: Version,
-- | send multiple transmissions in a single block
-- based on protocol and protocol version
batch :: Bool
}
-- | TLS-unique channel binding
@ -364,7 +367,7 @@ smpServerHandshake c kh smpVRange = do
| keyHash /= kh ->
throwE $ TEHandshake IDENTITY
| smpVersion `isCompatible` smpVRange -> do
pure (th :: THandle c) {thVersion = smpVersion}
pure $ smpThHandle th smpVersion
| otherwise -> throwE $ TEHandshake VERSION
-- | Client SMP transport handshake.
@ -379,9 +382,12 @@ smpClientHandshake c keyHash smpVRange = do
else case smpVersionRange `compatibleVersion` smpVRange of
Just (Compatible smpVersion) -> do
sendHandshake th $ ClientHandshake {smpVersion, keyHash}
pure (th :: THandle c) {thVersion = smpVersion}
pure $ smpThHandle th smpVersion
Nothing -> throwE $ TEHandshake VERSION
smpThHandle :: forall c. THandle c -> Version -> THandle c
smpThHandle th v = (th :: THandle c) {thVersion = v, batch = v >= 4}
sendHandshake :: (Transport c, Encoding smp) => THandle c -> smp -> ExceptT TransportError IO ()
sendHandshake th = ExceptT . tPutBlock th . smpEncode
@ -389,4 +395,4 @@ getHandshake :: (Transport c, Encoding smp) => THandle c -> ExceptT TransportErr
getHandshake th = ExceptT $ (parse smpP (TEHandshake PARSE) =<<) <$> tGetBlock th
smpTHandle :: Transport c => c -> THandle c
smpTHandle c = THandle {connection = c, sessionId = tlsUnique c, blockSize = smpBlockSize, thVersion = 0}
smpTHandle c = THandle {connection = c, sessionId = tlsUnique c, blockSize = smpBlockSize, thVersion = 0, batch = False}

View File

@ -1,3 +1,4 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}

View File

@ -17,12 +17,16 @@ module AgentTests.FunctionalAPITests
)
where
import Control.Concurrent (threadDelay)
import Control.Concurrent (killThread, threadDelay)
import Control.Monad
import Control.Monad.Except (ExceptT, runExceptT)
import Control.Monad.IO.Unlift
import Data.Int (Int64)
import qualified Data.Map as M
import qualified Data.Set as S
import Data.Time.Clock.System (SystemTime (..), getSystemTime)
import SMPAgentClient
import SMPClient (cfg, testPort, withSmpServer, withSmpServerConfigOn, withSmpServerStoreLogOn, withSmpServerStoreMsgLogOn)
import SMPClient (cfg, testPort, testPort2, testStoreLogFile2, withSmpServer, withSmpServerConfigOn, withSmpServerStoreLogOn, withSmpServerStoreMsgLogOn)
import Simplex.Messaging.Agent
import Simplex.Messaging.Agent.Env.SQLite (AgentConfig (..))
import Simplex.Messaging.Agent.Protocol
@ -102,6 +106,9 @@ functionalAPITests t = do
testSuspendingAgentCompleteSending t
it "should suspend agent on timeout, even if pending messages not sent" $
testSuspendingAgentTimeout t
describe "Batching SMP commands" $ do
it "should subscribe to multiple subscriptions with batching" $
testBatchedSubscriptions t
testAgentClient :: IO ()
testAgentClient = do
@ -503,13 +510,64 @@ testSuspendingAgentTimeout t = do
pure ()
testBatchedSubscriptions :: ATransport -> IO ()
testBatchedSubscriptions t = do
a <- getSMPAgentClient agentCfg initAgentServers2
b <- getSMPAgentClient agentCfg {dbFile = testDB2} initAgentServers2
Right conns <- runServers $ do
conns <- forM [1 .. 200 :: Int] . const $ makeConnection a b
forM_ conns $ \(aId, bId) -> exchangeGreetings a bId b aId
forM_ (take 10 conns) $ \(aId, bId) -> do
deleteConnection a bId
deleteConnection b aId
liftIO $ threadDelay 1000000
pure conns
("", "", DOWN {}) <- get a
("", "", DOWN {}) <- get a
("", "", DOWN {}) <- get b
("", "", DOWN {}) <- get b
Right () <- runServers $ do
("", "", UP {}) <- get a
("", "", UP {}) <- get a
("", "", UP {}) <- get b