SMP protocol v4: batching multiple server commands/responses in a transport block (#470)
* batch server commands in one transport block * subscribe to multiple queues using batched commands * agent method to subscribe to multiple queues using batched commands * refactor * test for batched subscriptions * delete part of connections in batched test * add resubscribeConnections * remove comment * update SMP protocol doc
This commit is contained in:
parent
1670c9c05e
commit
a6f401041a
|
@ -51,6 +51,8 @@ It's designed with the focus on communication security and integrity, under the
|
|||
|
||||
It is designed as a low level protocol for other application protocols to solve the problem of secure and private message transmission, making [MITM attack][1] very difficult at any part of the message transmission system.
|
||||
|
||||
This document describes SMP protocol versions 3 and 4, the previous versions are discontinued.
|
||||
|
||||
## Introduction
|
||||
|
||||
The objective of Simplex Messaging Protocol (SMP) is to facilitate the secure and private unidirectional transfer of messages from senders to recipients via persistent simplex queues managed by the message broker (server).
|
||||
|
@ -362,15 +364,16 @@ The clients can optionally instruct a dedicated push notification server to subs
|
|||
|
||||
[`SEND` command](#send-message) includes the notification flag to instruct SMP server whether to send the notification - this flag is forwarded to the recepient inside encrypted envelope, together with the timestamp and the message body, so even if TLS is compromised this flag cannot be used for traffic correlation.
|
||||
|
||||
## SMP Transmission structure
|
||||
## SMP Transmission andtransport block structure
|
||||
|
||||
Each transport block (SMP transmission) has a fixed size of 16384 bytes for traffic uniformity.
|
||||
|
||||
From SMP version 4 each block can contain multiple transmissions, version 3 blocks have 1 transmission.
|
||||
Some parts of SMP transmission are padded to a fixed size; this padding is uniformly added as a word16 encoded in network byte order - see `paddedString` syntax.
|
||||
|
||||
In places where some part of the transmission should be padded, the syntax for `paddedNotation` is used:
|
||||
|
||||
```
|
||||
```abnf
|
||||
paddedString = originalLength string pad
|
||||
originalLength = 2*2 OCTET
|
||||
pad = N*N"#" ; where N = paddedLength - originalLength - 2
|
||||
|
@ -380,9 +383,9 @@ paddedNotation = <padded(string, paddedLength)>
|
|||
; paddedLength - required length after padding, including 2 bytes for originalLength
|
||||
```
|
||||
|
||||
Each transmission between the client and the server must have this format/syntax:
|
||||
Each transmission/block for SMP v3 between the client and the server must have this format/syntax:
|
||||
|
||||
```
|
||||
```abnf
|
||||
paddedTransmission = <padded(transmission), 16384>
|
||||
transmission = [signature] SP signed
|
||||
signed = sessionIdentifier SP [corrId] SP [queueId] SP smpCommand
|
||||
|
@ -399,6 +402,16 @@ encoded = <base64 encoded binary>
|
|||
|
||||
`base64` encoding should be used with padding, as defined in section 4 of [RFC 4648][9]
|
||||
|
||||
Transport block for SMP v4 has this syntax:
|
||||
|
||||
```abnf
|
||||
paddedTransportBlock = <padded(transportBlock), 16384>
|
||||
transportBlock = transmissionCount transmissions
|
||||
transmissionCount = 1*1 OCTET ; equal or greater than 1
|
||||
transmissions = transmissionLength transmission [transmissions]
|
||||
transmissionLength = 2*2 OCTET ; word16 encoded in network byte order
|
||||
```
|
||||
|
||||
## SMP commands
|
||||
|
||||
Commands syntax below is provided using [ABNF][8] with [case-sensitive strings extension][8a].
|
||||
|
|
|
@ -45,9 +45,11 @@ module Simplex.Messaging.Agent
|
|||
acceptContact,
|
||||
rejectContact,
|
||||
subscribeConnection,
|
||||
subscribeConnections,
|
||||
getConnectionMessage,
|
||||
getNotificationMessage,
|
||||
resubscribeConnection,
|
||||
resubscribeConnections,
|
||||
sendMessage,
|
||||
ackMessage,
|
||||
suspendConnection,
|
||||
|
@ -79,6 +81,7 @@ import Data.Composition ((.:), (.:.))
|
|||
import Data.Functor (($>))
|
||||
import Data.List.NonEmpty (NonEmpty (..))
|
||||
import qualified Data.List.NonEmpty as L
|
||||
import Data.Map.Strict (Map)
|
||||
import qualified Data.Map.Strict as M
|
||||
import Data.Maybe (isJust)
|
||||
import qualified Data.Text as T
|
||||
|
@ -105,10 +108,10 @@ import Simplex.Messaging.Parsers (parse)
|
|||
import Simplex.Messaging.Protocol (BrokerMsg, ErrorType (AUTH), MsgBody, MsgFlags, NtfServer, SMPMsgMeta)
|
||||
import qualified Simplex.Messaging.Protocol as SMP
|
||||
import qualified Simplex.Messaging.TMap as TM
|
||||
import Simplex.Messaging.Util (bshow, eitherToMaybe, liftE, liftError, tryError, unlessM, whenM, ($>>=))
|
||||
import Simplex.Messaging.Util
|
||||
import Simplex.Messaging.Version
|
||||
import System.Random (randomR)
|
||||
import UnliftIO.Async (async, race_)
|
||||
import UnliftIO.Async (async, mapConcurrently, race_)
|
||||
import UnliftIO.Concurrent (forkFinally, forkIO, threadDelay)
|
||||
import qualified UnliftIO.Exception as E
|
||||
import UnliftIO.STM
|
||||
|
@ -158,6 +161,10 @@ rejectContact c = withAgentEnv c .: rejectContact' c
|
|||
subscribeConnection :: AgentErrorMonad m => AgentClient -> ConnId -> m ()
|
||||
subscribeConnection c = withAgentEnv c . subscribeConnection' c
|
||||
|
||||
-- | Subscribe to receive connection messages from multiple connections, batching commands when possible
|
||||
subscribeConnections :: AgentErrorMonad m => AgentClient -> [ConnId] -> m (Map ConnId (Either AgentErrorType ()))
|
||||
subscribeConnections c = withAgentEnv c . subscribeConnections' c
|
||||
|
||||
-- | Get connection message (GET command)
|
||||
getConnectionMessage :: AgentErrorMonad m => AgentClient -> ConnId -> m (Maybe SMPMsgMeta)
|
||||
getConnectionMessage c = withAgentEnv c . getConnectionMessage' c
|
||||
|
@ -169,6 +176,9 @@ getNotificationMessage c = withAgentEnv c .: getNotificationMessage' c
|
|||
resubscribeConnection :: AgentErrorMonad m => AgentClient -> ConnId -> m ()
|
||||
resubscribeConnection c = withAgentEnv c . resubscribeConnection' c
|
||||
|
||||
resubscribeConnections :: AgentErrorMonad m => AgentClient -> [ConnId] -> m (Map ConnId (Either AgentErrorType ()))
|
||||
resubscribeConnections c = withAgentEnv c . resubscribeConnections' c
|
||||
|
||||
-- | Send message to the connection (SEND command)
|
||||
sendMessage :: AgentErrorMonad m => AgentClient -> ConnId -> MsgFlags -> MsgBody -> m AgentMsgId
|
||||
sendMessage c = withAgentEnv c .:. sendMessage' c
|
||||
|
@ -393,12 +403,62 @@ subscribeConnection' c connId =
|
|||
ns <- asks ntfSupervisor
|
||||
atomically $ sendNtfSubCommand ns (connId, NSCCreate)
|
||||
|
||||
subscribeConnections' :: forall m. AgentMonad m => AgentClient -> [ConnId] -> m (Map ConnId (Either AgentErrorType ()))
|
||||
subscribeConnections' _ [] = pure M.empty
|
||||
subscribeConnections' c connIds = do
|
||||
conns :: Map ConnId (Either StoreError SomeConn) <- M.fromList . zip connIds <$> withStore' c (forM connIds . getConn)
|
||||
let (errs, cs) = M.mapEither id conns
|
||||
errs' = M.map (Left . storeError) errs
|
||||
(sndQs, rcvQs) = M.mapEither rcvOrSndQueue cs
|
||||
sndRs = M.map (sndSubResult . fst) sndQs
|
||||
srvRcvQs :: Map SMPServer (Map ConnId (RcvQueue, ConnData)) = M.foldlWithKey' addRcvQueue M.empty rcvQs
|
||||
mapM_ (mapM_ (uncurry $ resumeMsgDelivery c) . sndQueue) cs
|
||||
rcvRs <- mapConcurrently subscribe (M.assocs srvRcvQs)
|
||||
let rs = M.unions $ errs' : sndRs : rcvRs
|
||||
notifyResultError rs
|
||||
pure rs
|
||||
where
|
||||
rcvOrSndQueue :: SomeConn -> Either (SndQueue, ConnData) (RcvQueue, ConnData)
|
||||
rcvOrSndQueue = \case
|
||||
SomeConn _ (DuplexConnection cData rq _) -> Right (rq, cData)
|
||||
SomeConn _ (SndConnection cData sq) -> Left (sq, cData)
|
||||
SomeConn _ (RcvConnection cData rq) -> Right (rq, cData)
|
||||
SomeConn _ (ContactConnection cData rq) -> Right (rq, cData)
|
||||
sndSubResult :: SndQueue -> Either AgentErrorType ()
|
||||
sndSubResult sq = case status (sq :: SndQueue) of
|
||||
Confirmed -> Right ()
|
||||
Active -> Left $ CONN SIMPLEX
|
||||
_ -> Left $ INTERNAL "unexpected queue status"
|
||||
addRcvQueue :: Map SMPServer (Map ConnId (RcvQueue, ConnData)) -> ConnId -> (RcvQueue, ConnData) -> Map SMPServer (Map ConnId (RcvQueue, ConnData))
|
||||
addRcvQueue m connId rq@(RcvQueue {server}, _) = M.alter (Just . maybe (M.singleton connId rq) (M.insert connId rq)) server m
|
||||
subscribe :: (SMPServer, Map ConnId (RcvQueue, ConnData)) -> m (Map ConnId (Either AgentErrorType ()))
|
||||
subscribe (srv, qs) = subscribeQueues c srv (M.map fst qs)
|
||||
sndQueue :: SomeConn -> Maybe (ConnData, SndQueue)
|
||||
sndQueue = \case
|
||||
SomeConn _ (DuplexConnection cData _ sq) -> Just (cData, sq)
|
||||
SomeConn _ (SndConnection cData sq) -> Just (cData, sq)
|
||||
_ -> Nothing
|
||||
notifyResultError :: Map ConnId (Either AgentErrorType ()) -> m ()
|
||||
notifyResultError rs = do
|
||||
let actual = M.size rs
|
||||
expected = length connIds
|
||||
when (actual /= expected) . atomically $
|
||||
writeTBQueue (subQ c) ("", "", ERR . INTERNAL $ "subscribeConnections result size: " <> show actual <> ", expected " <> show expected)
|
||||
|
||||
resubscribeConnection' :: AgentMonad m => AgentClient -> ConnId -> m ()
|
||||
resubscribeConnection' c connId =
|
||||
unlessM
|
||||
(atomically $ hasActiveSubscription c connId)
|
||||
(subscribeConnection' c connId)
|
||||
|
||||
resubscribeConnections' :: forall m. AgentMonad m => AgentClient -> [ConnId] -> m (Map ConnId (Either AgentErrorType ()))
|
||||
resubscribeConnections' _ [] = pure M.empty
|
||||
resubscribeConnections' c connIds = do
|
||||
let r = M.fromList . zip connIds . repeat $ Right ()
|
||||
connIds' <- filterM (fmap not . atomically . hasActiveSubscription c) connIds
|
||||
-- union is left-biased, so results returned by subscribeConnections' take precedence
|
||||
(`M.union` r) <$> subscribeConnections' c connIds'
|
||||
|
||||
getConnectionMessage' :: AgentMonad m => AgentClient -> ConnId -> m (Maybe SMPMsgMeta)
|
||||
getConnectionMessage' c connId = do
|
||||
whenM (atomically $ hasActiveSubscription c connId) . throwError $ CMD PROHIBITED
|
||||
|
|
|
@ -21,6 +21,7 @@ module Simplex.Messaging.Agent.Client
|
|||
closeAgentClient,
|
||||
newRcvQueue,
|
||||
subscribeQueue,
|
||||
subscribeQueues,
|
||||
getQueueMessage,
|
||||
decryptSMPMessage,
|
||||
addSubscription,
|
||||
|
@ -64,6 +65,7 @@ module Simplex.Messaging.Agent.Client
|
|||
whenSuspending,
|
||||
withStore,
|
||||
withStore',
|
||||
storeError,
|
||||
)
|
||||
where
|
||||
|
||||
|
@ -74,16 +76,19 @@ import Control.Logger.Simple
|
|||
import Control.Monad.Except
|
||||
import Control.Monad.IO.Unlift
|
||||
import Control.Monad.Reader
|
||||
import Data.Bifunctor (first)
|
||||
import Data.Bifunctor (bimap, first, second)
|
||||
import Data.ByteString.Base64
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import Data.Either (partitionEithers)
|
||||
import Data.List.NonEmpty (NonEmpty)
|
||||
import qualified Data.List.NonEmpty as L
|
||||
import Data.Map.Strict (Map)
|
||||
import qualified Data.Map.Strict as M
|
||||
import Data.Maybe (catMaybes)
|
||||
import Data.Set (Set)
|
||||
import Data.Text.Encoding
|
||||
import Data.Tuple (swap)
|
||||
import Data.Word (Word16)
|
||||
import qualified Database.SQLite.Simple as DB
|
||||
import Simplex.Messaging.Agent.Env.SQLite
|
||||
|
@ -103,7 +108,7 @@ import Simplex.Messaging.Protocol (BrokerMsg, ErrorType, MsgFlags (..), MsgId, N
|
|||
import qualified Simplex.Messaging.Protocol as SMP
|
||||
import Simplex.Messaging.TMap (TMap)
|
||||
import qualified Simplex.Messaging.TMap as TM
|
||||
import Simplex.Messaging.Util (bshow, catchAll_, ifM, liftEitherError, liftError, tryError, unlessM, whenM)
|
||||
import Simplex.Messaging.Util
|
||||
import Simplex.Messaging.Version
|
||||
import System.Timeout (timeout)
|
||||
import UnliftIO (async, pooledForConcurrentlyN)
|
||||
|
@ -476,14 +481,42 @@ subscribeQueue :: AgentMonad m => AgentClient -> RcvQueue -> ConnId -> m ()
|
|||
subscribeQueue c rq@RcvQueue {server, rcvPrivateKey, rcvId} connId = do
|
||||
whenM (atomically . TM.member (server, rcvId) $ getMsgLocks c) . throwError $ CMD PROHIBITED
|
||||
atomically $ addPendingSubscription c rq connId
|
||||
withLogClient c server rcvId "SUB" $ \smp -> do
|
||||
liftIO (runExceptT $ subscribeSMPQueue smp rcvPrivateKey rcvId) >>= \case
|
||||
Left e -> do
|
||||
atomically . when (e /= PCENetworkError && e /= PCEResponseTimeout) $
|
||||
removePendingSubscription c server connId
|
||||
throwError e
|
||||
Right _ -> do
|
||||
addSubscription c rq connId
|
||||
withLogClient c server rcvId "SUB" $ \smp ->
|
||||
liftIO (runExceptT (subscribeSMPQueue smp rcvPrivateKey rcvId) >>= processSubResult c rq connId)
|
||||
>>= either throwError pure
|
||||
|
||||
processSubResult :: AgentClient -> RcvQueue -> ConnId -> Either ProtocolClientError () -> IO (Either ProtocolClientError ())
|
||||
processSubResult c rq@RcvQueue {server} connId r = do
|
||||
case r of
|
||||
Left e ->
|
||||
atomically . when (e /= PCENetworkError && e /= PCEResponseTimeout) $
|
||||
removePendingSubscription c server connId
|
||||
_ -> addSubscription c rq connId
|
||||
pure r
|
||||
|
||||
-- | subscribe multiple queues - all passed queues should be on the same server
|
||||
subscribeQueues :: AgentMonad m => AgentClient -> SMPServer -> Map ConnId RcvQueue -> m (Map ConnId (Either AgentErrorType ()))
|
||||
subscribeQueues c srv qs = do
|
||||
(errs, qs_) <- partitionEithers <$> mapM checkQueue (M.assocs qs)
|
||||
forM_ qs_ $ atomically . uncurry (addPendingSubscription c) . swap
|
||||
case L.nonEmpty qs_ of
|
||||
Just qs' -> do
|
||||
smp_ <- tryError (getSMPServerClient c srv)
|
||||
M.fromList . (errs <>) <$> case smp_ of
|
||||
Left e -> pure $ map (second . const $ Left e) qs_
|
||||
Right smp -> do
|
||||
logServer "-->" c srv (bshow (length qs_) <> " queues") "SUB"
|
||||
let qs2 = L.map (queueCreds . snd) qs'
|
||||
rs' :: [((ConnId, RcvQueue), Either ProtocolClientError ())] <-
|
||||
liftIO $ zip qs_ . L.toList <$> subscribeSMPQueues smp qs2
|
||||
forM_ rs' $ \((connId, rq), r) -> liftIO $ processSubResult c rq connId r
|
||||
pure $ map (bimap fst (first $ protocolClientError SMP)) rs'
|
||||
_ -> pure $ M.fromList errs
|
||||
where
|
||||
checkQueue rq@(connId, RcvQueue {rcvId, server}) = do
|
||||
prohibited <- atomically . TM.member (server, rcvId) $ getMsgLocks c
|
||||
pure $ if prohibited || srv /= server then Left (connId, Left $ CMD PROHIBITED) else Right rq
|
||||
queueCreds RcvQueue {rcvPrivateKey, rcvId} = (rcvPrivateKey, rcvId)
|
||||
|
||||
addSubscription :: MonadIO m => AgentClient -> RcvQueue -> ConnId -> m ()
|
||||
addSubscription c rq@RcvQueue {server} connId = atomically $ do
|
||||
|
@ -762,15 +795,16 @@ withStore c action = do
|
|||
where
|
||||
handleInternal :: E.SomeException -> IO (Either StoreError a)
|
||||
handleInternal = pure . Left . SEInternal . bshow
|
||||
storeError :: StoreError -> AgentErrorType
|
||||
storeError = \case
|
||||
SEConnNotFound -> CONN NOT_FOUND
|
||||
SEConnDuplicate -> CONN DUPLICATE
|
||||
SEBadConnType CRcv -> CONN SIMPLEX
|
||||
SEBadConnType CSnd -> CONN SIMPLEX
|
||||
SEInvitationNotFound -> CMD PROHIBITED
|
||||
-- this error is never reported as store error,
|
||||
-- it is used to wrap agent operations when "transaction-like" store access is needed
|
||||
-- NOTE: network IO should NOT be used inside AgentStoreMonad
|
||||
SEAgentError e -> e
|
||||
e -> INTERNAL $ show e
|
||||
|
||||
storeError :: StoreError -> AgentErrorType
|
||||
storeError = \case
|
||||
SEConnNotFound -> CONN NOT_FOUND
|
||||
SEConnDuplicate -> CONN DUPLICATE
|
||||
SEBadConnType CRcv -> CONN SIMPLEX
|
||||
SEBadConnType CSnd -> CONN SIMPLEX
|
||||
SEInvitationNotFound -> CMD PROHIBITED
|
||||
-- this error is never reported as store error,
|
||||
-- it is used to wrap agent operations when "transaction-like" store access is needed
|
||||
-- NOTE: network IO should NOT be used inside AgentStoreMonad
|
||||
SEAgentError e -> e
|
||||
e -> INTERNAL $ show e
|
||||
|
|
|
@ -16,8 +16,8 @@ module Simplex.Messaging.Agent.Store.SQLite.Migrations
|
|||
where
|
||||
|
||||
import Control.Monad (forM_)
|
||||
import Data.Function (on)
|
||||
import Data.List (intercalate, sortBy)
|
||||
import Data.Ord (comparing)
|
||||
import Data.Text (Text)
|
||||
import Data.Time.Clock (getCurrentTime)
|
||||
import Database.SQLite.Simple (Connection, Only (..), Query (..))
|
||||
|
@ -44,7 +44,7 @@ schemaMigrations =
|
|||
|
||||
-- | The list of migrations in ascending order by date
|
||||
app :: [Migration]
|
||||
app = sortBy (compare `on` name) $ map migration schemaMigrations
|
||||
app = sortBy (comparing name) $ map migration schemaMigrations
|
||||
where
|
||||
migration (name, query) = Migration {name = name, up = fromQuery query}
|
||||
|
||||
|
|
|
@ -6,6 +6,7 @@
|
|||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
{-# LANGUAGE NumericUnderscores #-}
|
||||
{-# LANGUAGE OverloadedLists #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-# LANGUAGE TypeApplications #-}
|
||||
|
@ -32,11 +33,14 @@ module Simplex.Messaging.Client
|
|||
-- * SMP protocol command functions
|
||||
createSMPQueue,
|
||||
subscribeSMPQueue,
|
||||
subscribeSMPQueues,
|
||||
getSMPMessage,
|
||||
subscribeSMPQueueNotifications,
|
||||
secureSMPQueue,
|
||||
enableSMPQueueNotifications,
|
||||
disableSMPQueueNotifications,
|
||||
enableSMPQueuesNtfs,
|
||||
disableSMPQueuesNtfs,
|
||||
sendSMPMessage,
|
||||
ackSMPMessage,
|
||||
suspendSMPQueue,
|
||||
|
@ -60,6 +64,10 @@ import Control.Monad.IO.Class (liftIO)
|
|||
import Control.Monad.Trans.Except
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import Data.Either (rights)
|
||||
import Data.Functor (($>))
|
||||
import Data.List.NonEmpty (NonEmpty)
|
||||
import qualified Data.List.NonEmpty as L
|
||||
import Data.Maybe (fromMaybe)
|
||||
import Network.Socket (ServiceName)
|
||||
import Numeric.Natural
|
||||
|
@ -87,13 +95,16 @@ data ProtocolClient msg = ProtocolClient
|
|||
tcpTimeout :: Int,
|
||||
clientCorrId :: TVar Natural,
|
||||
sentCommands :: TMap CorrId (Request msg),
|
||||
sndQ :: TBQueue SentRawTransmission,
|
||||
rcvQ :: TBQueue (SignedTransmission msg),
|
||||
sndQ :: TBQueue (NonEmpty (SentRawTransmission)),
|
||||
rcvQ :: TBQueue (NonEmpty (SignedTransmission msg)),
|
||||
msgQ :: Maybe (TBQueue (ServerTransmission msg))
|
||||
}
|
||||
|
||||
type SMPClient = ProtocolClient SMP.BrokerMsg
|
||||
|
||||
-- | Type for client command data
|
||||
type ClientCommand msg = (Maybe C.APrivateSignKey, QueueId, ProtoCommand msg)
|
||||
|
||||
-- | Type synonym for transmission from some SPM server queue.
|
||||
type ServerTransmission msg = (ProtoServer msg, Version, SessionId, QueueId, msg)
|
||||
|
||||
|
@ -208,13 +219,15 @@ getProtocolClient protocolServer cfg@ProtocolClientConfig {qSize, tcpTimeout, tc
|
|||
runExceptT $ sendProtocolCommand c Nothing "" protocolPing
|
||||
|
||||
process :: ProtocolClient msg -> IO ()
|
||||
process c@ProtocolClient {rcvQ, sentCommands} = forever $ do
|
||||
(_, _, (corrId, qId, respOrErr)) <- atomically $ readTBQueue rcvQ
|
||||
process c = forever $ atomically (readTBQueue $ rcvQ c) >>= mapM_ (processMsg c)
|
||||
|
||||
processMsg :: ProtocolClient msg -> SignedTransmission msg -> IO ()
|
||||
processMsg c@ProtocolClient {sentCommands} (_, _, (corrId, qId, respOrErr)) =
|
||||
if B.null $ bs corrId
|
||||
then sendMsg qId respOrErr
|
||||
then sendMsg respOrErr
|
||||
else do
|
||||
atomically (TM.lookup corrId sentCommands) >>= \case
|
||||
Nothing -> sendMsg qId respOrErr
|
||||
Nothing -> sendMsg respOrErr
|
||||
Just Request {queueId, responseVar} -> atomically $ do
|
||||
TM.delete corrId sentCommands
|
||||
putTMVar responseVar $
|
||||
|
@ -226,8 +239,8 @@ getProtocolClient protocolServer cfg@ProtocolClientConfig {qSize, tcpTimeout, tc
|
|||
_ -> Right r
|
||||
else Left . PCEUnexpectedResponse $ bshow respOrErr
|
||||
where
|
||||
sendMsg :: QueueId -> Either ErrorType msg -> IO ()
|
||||
sendMsg qId = \case
|
||||
sendMsg :: Either ErrorType msg -> IO ()
|
||||
sendMsg = \case
|
||||
Right msg -> atomically $ mapM_ (`writeTBQueue` serverTransmission c qId msg) msgQ
|
||||
-- TODO send everything else to errQ and log in agent
|
||||
_ -> return ()
|
||||
|
@ -285,12 +298,22 @@ subscribeSMPQueue :: SMPClient -> RcvPrivateSignKey -> RecipientId -> ExceptT Pr
|
|||
subscribeSMPQueue c rpKey rId =
|
||||
sendSMPCommand c (Just rpKey) rId SUB >>= \case
|
||||
OK -> return ()
|
||||
cmd@MSG {} -> writeSMPMessage c rId cmd
|
||||
cmd@MSG {} -> liftIO $ writeSMPMessage c rId cmd
|
||||
r -> throwE . PCEUnexpectedResponse $ bshow r
|
||||
|
||||
writeSMPMessage :: SMPClient -> RecipientId -> BrokerMsg -> ExceptT ProtocolClientError IO ()
|
||||
writeSMPMessage c rId msg =
|
||||
liftIO . atomically $ mapM_ (`writeTBQueue` serverTransmission c rId msg) (msgQ c)
|
||||
-- | Subscribe to multiple SMP queues batching commands if supported.
|
||||
subscribeSMPQueues :: SMPClient -> NonEmpty (RcvPrivateSignKey, RecipientId) -> IO (NonEmpty (Either ProtocolClientError ()))
|
||||
subscribeSMPQueues c qs = sendProtocolCommands c cs >>= mapM response . L.zip qs
|
||||
where
|
||||
cs = L.map (\(rpKey, rId) -> (Just rpKey, rId, Cmd SRecipient SUB)) qs
|
||||
response ((_, rId), r) = case r of
|
||||
Right OK -> pure $ Right ()
|
||||
Right cmd@MSG {} -> writeSMPMessage c rId cmd $> Right ()
|
||||
Right r' -> pure . Left . PCEUnexpectedResponse $ bshow r'
|
||||
Left e -> pure $ Left e
|
||||
|
||||
writeSMPMessage :: SMPClient -> RecipientId -> BrokerMsg -> IO ()
|
||||
writeSMPMessage c rId msg = atomically $ mapM_ (`writeTBQueue` serverTransmission c rId msg) (msgQ c)
|
||||
|
||||
serverTransmission :: ProtocolClient msg -> RecipientId -> msg -> ServerTransmission msg
|
||||
serverTransmission ProtocolClient {protocolServer, thVersion, sessionId} entityId message =
|
||||
|
@ -303,9 +326,7 @@ getSMPMessage :: SMPClient -> RcvPrivateSignKey -> RecipientId -> ExceptT Protoc
|
|||
getSMPMessage c rpKey rId =
|
||||
sendSMPCommand c (Just rpKey) rId GET >>= \case
|
||||
OK -> pure Nothing
|
||||
cmd@(MSG msg) -> do
|
||||
writeSMPMessage c rId cmd
|
||||
pure $ Just msg
|
||||
cmd@(MSG msg) -> liftIO (writeSMPMessage c rId cmd) $> Just msg
|
||||
r -> throwE . PCEUnexpectedResponse $ bshow r
|
||||
|
||||
-- | Subscribe to the SMP queue notifications.
|
||||
|
@ -329,12 +350,32 @@ enableSMPQueueNotifications c rpKey rId notifierKey rcvNtfPublicDhKey =
|
|||
NID nId rcvNtfSrvPublicDhKey -> pure (nId, rcvNtfSrvPublicDhKey)
|
||||
r -> throwE . PCEUnexpectedResponse $ bshow r
|
||||
|
||||
-- | Enable notifications for the multiple queues for push notifications server.
|
||||
enableSMPQueuesNtfs :: SMPClient -> NonEmpty (RcvPrivateSignKey, RecipientId, NtfPublicVerifyKey, RcvNtfPublicDhKey) -> IO (NonEmpty (Either ProtocolClientError (NotifierId, RcvNtfPublicDhKey)))
|
||||
enableSMPQueuesNtfs c qs = L.map response <$> sendProtocolCommands c cs
|
||||
where
|
||||
cs = L.map (\(rpKey, rId, notifierKey, rcvNtfPublicDhKey) -> (Just rpKey, rId, Cmd SRecipient $ NKEY notifierKey rcvNtfPublicDhKey)) qs
|
||||
response = \case
|
||||
Right (NID nId rcvNtfSrvPublicDhKey) -> Right (nId, rcvNtfSrvPublicDhKey)
|
||||
Right r -> Left . PCEUnexpectedResponse $ bshow r
|
||||
Left e -> Left e
|
||||
|
||||
-- | Disable notifications for the queue for push notifications server.
|
||||
--
|
||||
-- https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#disable-notifications-command
|
||||
disableSMPQueueNotifications :: SMPClient -> RcvPrivateSignKey -> RecipientId -> ExceptT ProtocolClientError IO ()
|
||||
disableSMPQueueNotifications = okSMPCommand NDEL
|
||||
|
||||
-- | Disable notifications for multiple queues for push notifications server.
|
||||
disableSMPQueuesNtfs :: SMPClient -> NonEmpty (RcvPrivateSignKey, RecipientId) -> IO (NonEmpty (Either ProtocolClientError ()))
|
||||
disableSMPQueuesNtfs c qs = L.map response <$> sendProtocolCommands c cs
|
||||
where
|
||||
cs = L.map (\(rpKey, rId) -> (Just rpKey, rId, Cmd SRecipient NDEL)) qs
|
||||
response = \case
|
||||
Right OK -> Right ()
|
||||
Right r -> Left . PCEUnexpectedResponse $ bshow r
|
||||
Left e -> Left e
|
||||
|
||||
-- | Send SMP message.
|
||||
--
|
||||
-- https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#send-message
|
||||
|
@ -351,7 +392,7 @@ ackSMPMessage :: SMPClient -> RcvPrivateSignKey -> QueueId -> MsgId -> ExceptT P
|
|||
ackSMPMessage c rpKey rId msgId =
|
||||
sendSMPCommand c (Just rpKey) rId (ACK msgId) >>= \case
|
||||
OK -> return ()
|
||||
cmd@MSG {} -> writeSMPMessage c rId cmd
|
||||
cmd@MSG {} -> liftIO $ writeSMPMessage c rId cmd
|
||||
r -> throwE . PCEUnexpectedResponse $ bshow r
|
||||
|
||||
-- | Irreversibly suspend SMP queue.
|
||||
|
@ -377,37 +418,48 @@ okSMPCommand cmd c pKey qId =
|
|||
sendSMPCommand :: PartyI p => SMPClient -> Maybe C.APrivateSignKey -> QueueId -> Command p -> ExceptT ProtocolClientError IO BrokerMsg
|
||||
sendSMPCommand c pKey qId cmd = sendProtocolCommand c pKey qId (Cmd sParty cmd)
|
||||
|
||||
-- | Send multiple commands with batching and collect responses
|
||||
sendProtocolCommands :: forall msg. ProtocolEncoding (ProtoCommand msg) => ProtocolClient msg -> NonEmpty (ClientCommand msg) -> IO (NonEmpty (Either ProtocolClientError msg))
|
||||
sendProtocolCommands c@ProtocolClient {sndQ, tcpTimeout} cs = do
|
||||
ts <- mapM (runExceptT . mkTransmission c) cs
|
||||
mapM_ (atomically . writeTBQueue sndQ . L.map fst) . L.nonEmpty . rights $ L.toList ts
|
||||
forConcurrently ts $ \case
|
||||
Right (_, r) -> withTimeout . atomically $ takeTMVar r
|
||||
Left e -> pure $ Left e
|
||||
where
|
||||
withTimeout a = fromMaybe (Left PCEResponseTimeout) <$> timeout tcpTimeout a
|
||||
|
||||
-- | Send Protocol command
|
||||
sendProtocolCommand :: forall msg. ProtocolEncoding (ProtoCommand msg) => ProtocolClient msg -> Maybe C.APrivateSignKey -> QueueId -> ProtoCommand msg -> ExceptT ProtocolClientError IO msg
|
||||
sendProtocolCommand ProtocolClient {sndQ, sentCommands, clientCorrId, sessionId, thVersion, tcpTimeout} pKey qId cmd = do
|
||||
corrId <- lift_ getNextCorrId
|
||||
t <- signTransmission $ encodeTransmission thVersion sessionId (corrId, qId, cmd)
|
||||
ExceptT $ sendRecv corrId t
|
||||
sendProtocolCommand c@ProtocolClient {sndQ, tcpTimeout} pKey qId cmd = do
|
||||
(t, r) <- mkTransmission c (pKey, qId, cmd)
|
||||
ExceptT $ sendRecv t r
|
||||
where
|
||||
lift_ :: STM a -> ExceptT ProtocolClientError IO a
|
||||
lift_ action = ExceptT $ Right <$> atomically action
|
||||
-- two separate "atomically" needed to avoid blocking
|
||||
sendRecv :: SentRawTransmission -> TMVar (Response msg) -> IO (Response msg)
|
||||
sendRecv t r = atomically (writeTBQueue sndQ [t]) >> withTimeout (atomically $ takeTMVar r)
|
||||
where
|
||||
withTimeout a = fromMaybe (Left PCEResponseTimeout) <$> timeout tcpTimeout a
|
||||
|
||||
mkTransmission :: forall msg. ProtocolEncoding (ProtoCommand msg) => ProtocolClient msg -> ClientCommand msg -> ExceptT ProtocolClientError IO (SentRawTransmission, TMVar (Response msg))
|
||||
mkTransmission ProtocolClient {clientCorrId, sessionId, thVersion, sentCommands} (pKey, qId, cmd) = do
|
||||
corrId <- liftIO $ atomically getNextCorrId
|
||||
t <- signTransmission $ encodeTransmission thVersion sessionId (corrId, qId, cmd)
|
||||
r <- liftIO . atomically $ mkRequest corrId
|
||||
pure (t, r)
|
||||
where
|
||||
getNextCorrId :: STM CorrId
|
||||
getNextCorrId = do
|
||||
i <- stateTVar clientCorrId $ \i -> (i, i + 1)
|
||||
pure . CorrId $ bshow i
|
||||
|
||||
signTransmission :: ByteString -> ExceptT ProtocolClientError IO SentRawTransmission
|
||||
signTransmission t = case pKey of
|
||||
Nothing -> return (Nothing, t)
|
||||
Nothing -> pure (Nothing, t)
|
||||
Just pk -> do
|
||||
sig <- liftError PCESignatureError $ C.sign pk t
|
||||
return (Just sig, t)
|
||||
|
||||
-- two separate "atomically" needed to avoid blocking
|
||||
sendRecv :: CorrId -> SentRawTransmission -> IO (Response msg)
|
||||
sendRecv corrId t = atomically (send corrId t) >>= withTimeout . atomically . takeTMVar
|
||||
where
|
||||
withTimeout a = fromMaybe (Left PCEResponseTimeout) <$> timeout tcpTimeout a
|
||||
|
||||
send :: CorrId -> SentRawTransmission -> STM (TMVar (Response msg))
|
||||
send corrId t = do
|
||||
mkRequest :: CorrId -> STM (TMVar (Response msg))
|
||||
mkRequest corrId = do
|
||||
r <- newEmptyTMVar
|
||||
TM.insert corrId (Request qId r) sentCommands
|
||||
writeTBQueue sndQ t
|
||||
return r
|
||||
pure r
|
||||
|
|
|
@ -13,6 +13,7 @@ module Simplex.Messaging.Encoding
|
|||
Large (..),
|
||||
smpEncodeList,
|
||||
smpListP,
|
||||
lenEncode,
|
||||
)
|
||||
where
|
||||
|
||||
|
|
|
@ -4,6 +4,7 @@
|
|||
{-# LANGUAGE GADTs #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
{-# LANGUAGE OverloadedLists #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-# LANGUAGE TupleSections #-}
|
||||
|
@ -248,22 +249,23 @@ clientDisconnected NtfServerClient {connected} = atomically $ writeTVar connecte
|
|||
|
||||
receive :: (Transport c, MonadUnliftIO m, MonadReader NtfEnv m) => THandle c -> NtfServerClient -> m ()
|
||||
receive th NtfServerClient {rcvQ, sndQ, activeAt} = forever $ do
|
||||
t@(_, _, (corrId, entId, cmdOrError)) <- tGet th
|
||||
atomically . writeTVar activeAt =<< liftIO getSystemTime
|
||||
logDebug "received transmission"
|
||||
case cmdOrError of
|
||||
Left e -> write sndQ (corrId, entId, NRErr e)
|
||||
Right cmd ->
|
||||
verifyNtfTransmission t cmd >>= \case
|
||||
VRVerified req -> write rcvQ req
|
||||
VRFailed -> write sndQ (corrId, entId, NRErr AUTH)
|
||||
ts <- tGet th
|
||||
forM_ ts $ \t@(_, _, (corrId, entId, cmdOrError)) -> do
|
||||
atomically . writeTVar activeAt =<< liftIO getSystemTime
|
||||
logDebug "received transmission"
|
||||
case cmdOrError of
|
||||
Left e -> write sndQ (corrId, entId, NRErr e)
|
||||
Right cmd ->
|
||||
verifyNtfTransmission t cmd >>= \case
|
||||
VRVerified req -> write rcvQ req
|
||||
VRFailed -> write sndQ (corrId, entId, NRErr AUTH)
|
||||
where
|
||||
write q t = atomically $ writeTBQueue q t
|
||||
|
||||
send :: (Transport c, MonadUnliftIO m) => THandle c -> NtfServerClient -> m ()
|
||||
send h@THandle {thVersion = v} NtfServerClient {sndQ, sessionId, activeAt} = forever $ do
|
||||
t <- atomically $ readTBQueue sndQ
|
||||
void . liftIO $ tPut h (Nothing, encodeTransmission v sessionId t)
|
||||
void . liftIO $ tPut h [(Nothing, encodeTransmission v sessionId t)]
|
||||
atomically . writeTVar activeAt =<< liftIO getSystemTime
|
||||
|
||||
-- instance Show a => Show (TVar a) where
|
||||
|
|
|
@ -69,4 +69,4 @@ ntfClientHandshake c keyHash ntfVRange = do
|
|||
Nothing -> throwError $ TEHandshake VERSION
|
||||
|
||||
ntfTHandle :: Transport c => c -> THandle c
|
||||
ntfTHandle c = THandle {connection = c, sessionId = tlsUnique c, blockSize = ntfBlockSize, thVersion = 0}
|
||||
ntfTHandle c = THandle {connection = c, sessionId = tlsUnique c, blockSize = ntfBlockSize, thVersion = 0, batch = False}
|
||||
|
|
|
@ -7,6 +7,7 @@
|
|||
{-# LANGUAGE GADTs #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
{-# LANGUAGE OverloadedLists #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE PatternSynonyms #-}
|
||||
{-# LANGUAGE PolyKinds #-}
|
||||
|
@ -130,6 +131,8 @@ import qualified Data.Attoparsec.ByteString.Char8 as A
|
|||
import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import Data.Kind
|
||||
import Data.List.NonEmpty (NonEmpty (..))
|
||||
import qualified Data.List.NonEmpty as L
|
||||
import Data.Maybe (isNothing)
|
||||
import Data.String
|
||||
import Data.Time.Clock.System (SystemTime (..))
|
||||
|
@ -1010,20 +1013,51 @@ instance Encoding CommandError where
|
|||
_ -> fail "bad command error type"
|
||||
|
||||
-- | Send signed SMP transmission to TCP transport.
|
||||
tPut :: Transport c => THandle c -> SentRawTransmission -> IO (Either TransportError ())
|
||||
tPut th (sig, t) = tPutBlock th $ smpEncode (C.signatureBytes sig) <> t
|
||||
tPut :: Transport c => THandle c -> NonEmpty (SentRawTransmission) -> IO (NonEmpty (Either TransportError ()))
|
||||
tPut th trs
|
||||
| batch th = tPutBatch [] $ L.map tEncode trs
|
||||
| otherwise = forM trs $ tPutBlock th . tEncode
|
||||
where
|
||||
tPutBatch :: [Either TransportError ()] -> NonEmpty ByteString -> IO (NonEmpty (Either TransportError ()))
|
||||
tPutBatch rs ts = do
|
||||
let (n, s, ts_) = encodeBatch 0 "" ts
|
||||
r <- if n == 0 then pure [Left TELargeMsg] else replicate n <$> tPutBlock th (lenEncode n `B.cons` s)
|
||||
let rs' = rs <> r
|
||||
case ts_ of
|
||||
Just ts' -> tPutBatch rs' ts'
|
||||
_ -> pure $ L.fromList rs'
|
||||
encodeBatch :: Int -> ByteString -> NonEmpty ByteString -> (Int, ByteString, Maybe (NonEmpty ByteString))
|
||||
encodeBatch n s ts@(t :| ts_)
|
||||
| n == 255 = (n, s, Just ts)
|
||||
| otherwise =
|
||||
let s' = s <> smpEncode (Large t)
|
||||
n' = n + 1
|
||||
in if B.length s' > blockSize th - 1
|
||||
then (n,s,) $ if n == 0 then L.nonEmpty ts_ else Just ts
|
||||
else case L.nonEmpty ts_ of
|
||||
Just ts' -> encodeBatch n' s' ts'
|
||||
_ -> (n', s', Nothing)
|
||||
tEncode (sig, tr) = smpEncode (C.signatureBytes sig) <> tr
|
||||
|
||||
encodeTransmission :: ProtocolEncoding c => Version -> ByteString -> Transmission c -> ByteString
|
||||
encodeTransmission v sessionId (CorrId corrId, queueId, command) =
|
||||
smpEncode (sessionId, corrId, queueId) <> encodeProtocol v command
|
||||
|
||||
-- | Receive and parse transmission from the TCP transport (ignoring any trailing padding).
|
||||
tGetParse :: Transport c => THandle c -> IO (Either TransportError RawTransmission)
|
||||
tGetParse th = (parse transmissionP TEBadBlock =<<) <$> tGetBlock th
|
||||
tGetParse :: Transport c => THandle c -> IO (NonEmpty (Either TransportError RawTransmission))
|
||||
tGetParse th
|
||||
| batch th = either ((:| []) . Left) id <$> runExceptT getBatch
|
||||
| otherwise = (:| []) . (parse transmissionP TEBadBlock =<<) <$> tGetBlock th
|
||||
where
|
||||
getBatch :: ExceptT TransportError IO (NonEmpty (Either TransportError RawTransmission))
|
||||
getBatch = do
|
||||
s <- ExceptT $ tGetBlock th
|
||||
ts <- liftEither $ parse smpP TEBadBlock s
|
||||
pure $ L.map (\(Large t) -> parse transmissionP TEBadBlock t) ts
|
||||
|
||||
-- | Receive client and server transmissions (determined by `cmd` type).
|
||||
tGet :: forall cmd c m. (ProtocolEncoding cmd, Transport c, MonadIO m) => THandle c -> m (SignedTransmission cmd)
|
||||
tGet th@THandle {sessionId, thVersion = v} = liftIO (tGetParse th) >>= decodeParseValidate
|
||||
tGet :: forall cmd c m. (ProtocolEncoding cmd, Transport c, MonadIO m) => THandle c -> m (NonEmpty (SignedTransmission cmd))
|
||||
tGet th@THandle {sessionId, thVersion = v} = liftIO (tGetParse th) >>= mapM decodeParseValidate
|
||||
where
|
||||
decodeParseValidate :: Either TransportError RawTransmission -> m (SignedTransmission cmd)
|
||||
decodeParseValidate = \case
|
||||
|
|
|
@ -6,6 +6,7 @@
|
|||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
{-# LANGUAGE NumericUnderscores #-}
|
||||
{-# LANGUAGE OverloadedLists #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE PatternSynonyms #-}
|
||||
{-# LANGUAGE RankNTypes #-}
|
||||
|
@ -43,9 +44,10 @@ import Crypto.Random
|
|||
import Data.Bifunctor (first)
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import Data.Either (fromRight)
|
||||
import Data.Either (fromRight, partitionEithers)
|
||||
import Data.Functor (($>))
|
||||
import Data.List (intercalate)
|
||||
import qualified Data.List.NonEmpty as L
|
||||
import qualified Data.Map.Strict as M
|
||||
import Data.Maybe (isNothing)
|
||||
import Data.Set (Set)
|
||||
|
@ -145,7 +147,7 @@ smpServer started = do
|
|||
endPreviousSubscriptions :: (QueueId, Client) -> m (Maybe s)
|
||||
endPreviousSubscriptions (qId, c) = do
|
||||
void . forkIO . atomically $
|
||||
writeTBQueue (sndQ c) (CorrId "", qId, END)
|
||||
writeTBQueue (sndQ c) [(CorrId "", qId, END)]
|
||||
atomically $ TM.lookupDelete qId (clientSubs c)
|
||||
|
||||
expireMessagesThread_ :: ServerConfig -> [m ()]
|
||||
|
@ -243,26 +245,36 @@ cancelSub sub =
|
|||
Sub {subThread = SubThread t} -> liftIO $ deRefWeak t >>= mapM_ killThread
|
||||
_ -> return ()
|
||||
|
||||
receive :: (Transport c, MonadUnliftIO m, MonadReader Env m) => THandle c -> Client -> m ()
|
||||
receive :: forall c m. (Transport c, MonadUnliftIO m, MonadReader Env m) => THandle c -> Client -> m ()
|
||||
receive th Client {rcvQ, sndQ, activeAt} = forever $ do
|
||||
(sig, signed, (corrId, queueId, cmdOrError)) <- tGet th
|
||||
ts <- L.toList <$> tGet th
|
||||
atomically . writeTVar activeAt =<< liftIO getSystemTime
|
||||
case cmdOrError of
|
||||
Left e -> write sndQ (corrId, queueId, ERR e)
|
||||
Right cmd -> do
|
||||
verified <- verifyTransmission sig signed queueId cmd
|
||||
if verified
|
||||
then write rcvQ (corrId, queueId, cmd)
|
||||
else write sndQ (corrId, queueId, ERR AUTH)
|
||||
as <- partitionEithers <$> mapM cmdAction ts
|
||||
write sndQ $ fst as
|
||||
write rcvQ $ snd as
|
||||
where
|
||||
write q t = atomically $ writeTBQueue q t
|
||||
cmdAction :: SignedTransmission Cmd -> m (Either (Transmission BrokerMsg) (Maybe QueueRec, Transmission Cmd))
|
||||
cmdAction (sig, signed, (corrId, queueId, cmdOrError)) =
|
||||
case cmdOrError of
|
||||
Left e -> pure $ Left (corrId, queueId, ERR e)
|
||||
Right cmd -> verified <$> verifyTransmission sig signed queueId cmd
|
||||
where
|
||||
verified = \case
|
||||
VRVerified qr -> Right (qr, (corrId, queueId, cmd))
|
||||
VRFailed -> Left (corrId, queueId, ERR AUTH)
|
||||
write q = mapM_ (atomically . writeTBQueue q) . L.nonEmpty
|
||||
|
||||
send :: (Transport c, MonadUnliftIO m) => THandle c -> Client -> m ()
|
||||
send h@THandle {thVersion = v} Client {sndQ, sessionId, activeAt} = forever $ do
|
||||
t <- atomically $ readTBQueue sndQ
|
||||
-- TODO the line below can return Left, but we ignore it and do not disconnect the client
|
||||
void . liftIO $ tPut h (Nothing, encodeTransmission v sessionId t)
|
||||
ts <- atomically $ L.sortWith tOrder <$> readTBQueue sndQ
|
||||
-- TODO the line below can return Lefts, but we ignore it and do not disconnect the client
|
||||
void . liftIO . tPut h $ L.map ((Nothing,) . encodeTransmission v sessionId) ts
|
||||
atomically . writeTVar activeAt =<< liftIO getSystemTime
|
||||
where
|
||||
tOrder :: Transmission BrokerMsg -> Int
|
||||
tOrder (_, _, cmd) = case cmd of
|
||||
MSG {} -> 0
|
||||
_ -> 1
|
||||
|
||||
disconnectTransport :: (Transport c, MonadUnliftIO m) => THandle c -> client -> (client -> TVar SystemTime) -> ExpirationConfig -> m ()
|
||||
disconnectTransport THandle {connection} c activeAt expCfg = do
|
||||
|
@ -273,23 +285,28 @@ disconnectTransport THandle {connection} c activeAt expCfg = do
|
|||
ts <- readTVarIO $ activeAt c
|
||||
when (systemSeconds ts < old) $ closeConnection connection
|
||||
|
||||
data VerificationResult = VRVerified (Maybe QueueRec) | VRFailed
|
||||
|
||||
verifyTransmission ::
|
||||
forall m. (MonadUnliftIO m, MonadReader Env m) => Maybe C.ASignature -> ByteString -> QueueId -> Cmd -> m Bool
|
||||
forall m. (MonadUnliftIO m, MonadReader Env m) => Maybe C.ASignature -> ByteString -> QueueId -> Cmd -> m VerificationResult
|
||||
verifyTransmission sig_ signed queueId cmd = do
|
||||
case cmd of
|
||||
Cmd SRecipient (NEW k _) -> pure $ verifyCmdSignature sig_ signed k
|
||||
Cmd SRecipient (NEW k _) -> pure $ Nothing `verified` verifyCmdSignature sig_ signed k
|
||||
Cmd SRecipient _ -> verifyCmd SRecipient $ verifyCmdSignature sig_ signed . recipientKey
|
||||
Cmd SSender SEND {} -> verifyCmd SSender $ verifyMaybe . senderKey
|
||||
Cmd SSender PING -> pure True
|
||||
Cmd SSender PING -> pure $ VRVerified Nothing
|
||||
Cmd SNotifier NSUB -> verifyCmd SNotifier $ verifyMaybe . fmap notifierKey . notifier
|
||||
where
|
||||
verifyCmd :: SParty p -> (QueueRec -> Bool) -> m Bool
|
||||
verifyCmd :: SParty p -> (QueueRec -> Bool) -> m VerificationResult
|
||||
verifyCmd party f = do
|
||||
st <- asks queueStore
|
||||
q <- atomically $ getQueue st party queueId
|
||||
pure $ either (const $ maybe False (dummyVerifyCmd signed) sig_ `seq` False) f q
|
||||
q_ <- atomically (getQueue st party queueId)
|
||||
pure $ case q_ of
|
||||
Right q -> Just q `verified` f q
|
||||
_ -> maybe False (dummyVerifyCmd signed) sig_ `seq` VRFailed
|
||||
verifyMaybe :: Maybe C.APublicVerifyKey -> Bool
|
||||
verifyMaybe = maybe (isNothing sig_) $ verifyCmdSignature sig_ signed
|
||||
verified q cond = if cond then VRVerified q else VRFailed
|
||||
|
||||
verifyCmdSignature :: Maybe C.ASignature -> ByteString -> C.APublicVerifyKey -> Bool
|
||||
verifyCmdSignature sig_ signed key = maybe False (verify key) sig_
|
||||
|
@ -320,16 +337,16 @@ client :: forall m. (MonadUnliftIO m, MonadReader Env m) => Client -> Server ->
|
|||
client clnt@Client {thVersion, subscriptions, ntfSubscriptions, rcvQ, sndQ} Server {subscribedQ, ntfSubscribedQ, notifiers} =
|
||||
forever $
|
||||
atomically (readTBQueue rcvQ)
|
||||
>>= processCommand
|
||||
>>= mapM processCommand
|
||||
>>= atomically . writeTBQueue sndQ
|
||||
where
|
||||
processCommand :: Transmission Cmd -> m (Transmission BrokerMsg)
|
||||
processCommand (corrId, queueId, cmd) = do
|
||||
processCommand :: (Maybe QueueRec, Transmission Cmd) -> m (Transmission BrokerMsg)
|
||||
processCommand (qr_, (corrId, queueId, cmd)) = do
|
||||
st <- asks queueStore
|
||||
case cmd of
|
||||
Cmd SSender command ->
|
||||
case command of
|
||||
SEND flags msgBody -> sendMessage st flags msgBody
|
||||
SEND flags msgBody -> withQueue $ \qr -> sendMessage qr flags msgBody
|
||||
PING -> pure (corrId, "", PONG)
|
||||
Cmd SNotifier NSUB -> subscribeNotifications
|
||||
Cmd SRecipient command ->
|
||||
|
@ -339,9 +356,9 @@ client clnt@Client {thVersion, subscriptions, ntfSubscriptions, rcvQ, sndQ} Serv
|
|||
(asks $ allowNewQueues . config)
|
||||
(createQueue st rKey dhKey)
|
||||
(pure (corrId, queueId, ERR AUTH))
|
||||
SUB -> subscribeQueue st queueId
|
||||
GET -> getMessage st
|
||||
ACK msgId -> acknowledgeMsg st msgId
|
||||
SUB -> withQueue (`subscribeQueue` queueId)
|
||||
GET -> withQueue getMessage
|
||||
ACK msgId -> withQueue (`acknowledgeMsg` msgId)
|
||||
KEY sKey -> secureQueue_ st sKey
|
||||
NKEY nKey dhKey -> addQueueNotifier_ st nKey dhKey
|
||||
NDEL -> deleteQueueNotifier_ st
|
||||
|
@ -371,14 +388,15 @@ client clnt@Client {thVersion, subscriptions, ntfSubscriptions, rcvQ, sndQ} Serv
|
|||
addQueueRetry n qik qRec = do
|
||||
ids@(rId, _) <- getIds
|
||||
-- create QueueRec record with these ids and keys
|
||||
atomically (addQueue st $ qRec ids) >>= \case
|
||||
let qr = qRec ids
|
||||
atomically (addQueue st qr) >>= \case
|
||||
Left DUPLICATE_ -> addQueueRetry (n - 1) qik qRec
|
||||
Left e -> pure $ ERR e
|
||||
Right _ -> do
|
||||
withLog (`logCreateById` rId)
|
||||
stats <- asks serverStats
|
||||
atomically $ modifyTVar (qCreated stats) (+ 1)
|
||||
subscribeQueue st rId $> IDS (qik ids)
|
||||
subscribeQueue qr rId $> IDS (qik ids)
|
||||
|
||||
logCreateById :: StoreLog 'WriteMode -> RecipientId -> IO ()
|
||||
logCreateById s rId =
|
||||
|
@ -426,8 +444,8 @@ client clnt@Client {thVersion, subscriptions, ntfSubscriptions, rcvQ, sndQ} Serv
|
|||
withLog (`logDeleteQueue` queueId)
|
||||
okResp <$> atomically (suspendQueue st queueId)
|
||||
|
||||
subscribeQueue :: QueueStore -> RecipientId -> m (Transmission BrokerMsg)
|
||||
subscribeQueue st rId =
|
||||
subscribeQueue :: QueueRec -> RecipientId -> m (Transmission BrokerMsg)
|
||||
subscribeQueue qr rId =
|
||||
atomically (TM.lookup rId subscriptions) >>= \case
|
||||
Nothing ->
|
||||
atomically newSub >>= deliver
|
||||
|
@ -449,10 +467,10 @@ client clnt@Client {thVersion, subscriptions, ntfSubscriptions, rcvQ, sndQ} Serv
|
|||
deliver sub = do
|
||||
q <- getStoreMsgQueue rId
|
||||
msg_ <- atomically $ tryPeekMsg q
|
||||
deliverMessage st rId sub q msg_
|
||||
deliverMessage qr rId sub q msg_
|
||||
|
||||
getMessage :: QueueStore -> m (Transmission BrokerMsg)
|
||||
getMessage st =
|
||||
getMessage :: QueueRec -> m (Transmission BrokerMsg)
|
||||
getMessage qr =
|
||||
atomically (TM.lookup queueId subscriptions) >>= \case
|
||||
Nothing ->
|
||||
atomically newSub >>= getMessage_
|
||||
|
@ -471,7 +489,7 @@ client clnt@Client {thVersion, subscriptions, ntfSubscriptions, rcvQ, sndQ} Serv
|
|||
TM.insert queueId sub subscriptions
|
||||
pure s
|
||||
getMessage_ :: Sub -> m (Transmission BrokerMsg)
|
||||
getMessage_ s = withRcvQueue st queueId $ \qr -> do
|
||||
getMessage_ s = do
|
||||
q <- getStoreMsgQueue queueId
|
||||
atomically $
|
||||
tryPeekMsg q >>= \case
|
||||
|
@ -480,11 +498,8 @@ client clnt@Client {thVersion, subscriptions, ntfSubscriptions, rcvQ, sndQ} Serv
|
|||
in setDelivered s msg $> (corrId, queueId, MSG encMsg)
|
||||
_ -> pure (corrId, queueId, OK)
|
||||
|
||||
withRcvQueue :: QueueStore -> RecipientId -> (QueueRec -> m (Transmission BrokerMsg)) -> m (Transmission BrokerMsg)
|
||||
withRcvQueue st rId action =
|
||||
atomically (getQueue st SRecipient rId) >>= \case
|
||||
Left e -> pure (corrId, rId, ERR e)
|
||||
Right qr -> action qr
|
||||
withQueue :: (QueueRec -> m (Transmission BrokerMsg)) -> m (Transmission BrokerMsg)
|
||||
withQueue action = maybe (pure $ err AUTH) action qr_
|
||||
|
||||
subscribeNotifications :: m (Transmission BrokerMsg)
|
||||
subscribeNotifications = atomically $ do
|
||||
|
@ -493,8 +508,8 @@ client clnt@Client {thVersion, subscriptions, ntfSubscriptions, rcvQ, sndQ} Serv
|
|||
TM.insert queueId () ntfSubscriptions
|
||||
pure ok
|
||||
|
||||
acknowledgeMsg :: QueueStore -> MsgId -> m (Transmission BrokerMsg)
|
||||
acknowledgeMsg st msgId = do
|
||||
acknowledgeMsg :: QueueRec -> MsgId -> m (Transmission BrokerMsg)
|
||||
acknowledgeMsg qr msgId = do
|
||||
atomically (TM.lookup queueId subscriptions) >>= \case
|
||||
Nothing -> pure $ err NO_MSG
|
||||
Just sub ->
|
||||
|
@ -509,7 +524,7 @@ client clnt@Client {thVersion, subscriptions, ntfSubscriptions, rcvQ, sndQ} Serv
|
|||
_ -> do
|
||||
(msgDeleted, msg_) <- atomically $ tryDelPeekMsg q msgId
|
||||
when msgDeleted updateStats
|
||||
deliverMessage st queueId sub q msg_
|
||||
deliverMessage qr queueId sub q msg_
|
||||
_ -> pure $ err NO_MSG
|
||||
where
|
||||
getDelivered :: TVar Sub -> STM (Maybe Sub)
|
||||
|
@ -533,74 +548,69 @@ client clnt@Client {thVersion, subscriptions, ntfSubscriptions, rcvQ, sndQ} Serv
|
|||
where
|
||||
updatePeriod pSel = modifyTVar (pSel stats) (S.insert qId)
|
||||
|
||||
sendMessage :: QueueStore -> MsgFlags -> MsgBody -> m (Transmission BrokerMsg)
|
||||
sendMessage st msgFlags msgBody
|
||||
sendMessage :: QueueRec -> MsgFlags -> MsgBody -> m (Transmission BrokerMsg)
|
||||
sendMessage qr msgFlags msgBody
|
||||
| B.length msgBody > maxMessageLength = pure $ err LARGE_MSG
|
||||
| otherwise = do
|
||||
qr <- atomically $ getQueue st SSender queueId
|
||||
either (return . err) storeMessage qr
|
||||
| otherwise = case status qr of
|
||||
QueueOff -> return $ err AUTH
|
||||
QueueActive ->
|
||||
mapM mkMessage (C.maxLenBS msgBody) >>= \case
|
||||
Left _ -> pure $ err LARGE_MSG
|
||||
Right msg -> do
|
||||
ms <- asks msgStore
|
||||
ServerConfig {messageExpiration, msgQueueQuota} <- asks config
|
||||
old <- liftIO $ mapM expireBeforeEpoch messageExpiration
|
||||
ntfNonceDrg <- asks idsDrg
|
||||
resp@(_, _, sent) <- atomically $ do
|
||||
q <- getMsgQueue ms (recipientId qr) msgQueueQuota
|
||||
mapM_ (deleteExpiredMsgs q) old
|
||||
ifM (isFull q) (pure $ err QUOTA) $ do
|
||||
when (notification msgFlags) $ trySendNotification msg ntfNonceDrg
|
||||
writeMsg q msg
|
||||
pure ok
|
||||
when (sent == OK) $ do
|
||||
stats <- asks serverStats
|
||||
atomically $ modifyTVar (msgSent stats) (+ 1)
|
||||
atomically $ updateActiveQueues stats $ recipientId qr
|
||||
pure resp
|
||||
where
|
||||
storeMessage :: QueueRec -> m (Transmission BrokerMsg)
|
||||
storeMessage qr = case status qr of
|
||||
QueueOff -> return $ err AUTH
|
||||
QueueActive ->
|
||||
mapM mkMessage (C.maxLenBS msgBody) >>= \case
|
||||
Left _ -> pure $ err LARGE_MSG
|
||||
Right msg -> do
|
||||
ms <- asks msgStore
|
||||
ServerConfig {messageExpiration, msgQueueQuota} <- asks config
|
||||
old <- liftIO $ mapM expireBeforeEpoch messageExpiration
|
||||
ntfNonceDrg <- asks idsDrg
|
||||
resp@(_, _, sent) <- atomically $ do
|
||||
q <- getMsgQueue ms (recipientId qr) msgQueueQuota
|
||||
mapM_ (deleteExpiredMsgs q) old
|
||||
ifM (isFull q) (pure $ err QUOTA) $ do
|
||||
when (notification msgFlags) $ trySendNotification msg ntfNonceDrg
|
||||
writeMsg q msg
|
||||
pure ok
|
||||
when (sent == OK) $ do
|
||||
stats <- asks serverStats
|
||||
atomically $ modifyTVar (msgSent stats) (+ 1)
|
||||
atomically $ updateActiveQueues stats $ recipientId qr
|
||||
pure resp
|
||||
where
|
||||
mkMessage :: C.MaxLenBS MaxMessageLen -> m Message
|
||||
mkMessage body = do
|
||||
msgId <- randomId =<< asks (msgIdBytes . config)
|
||||
msgTs <- liftIO getSystemTime
|
||||
pure $ Message msgId msgTs msgFlags body
|
||||
mkMessage :: C.MaxLenBS MaxMessageLen -> m Message
|
||||
mkMessage body = do
|
||||
msgId <- randomId =<< asks (msgIdBytes . config)
|
||||
msgTs <- liftIO getSystemTime
|
||||
pure $ Message msgId msgTs msgFlags body
|
||||
|
||||
trySendNotification :: Message -> TVar ChaChaDRG -> STM ()
|
||||
trySendNotification msg ntfNonceDrg =
|
||||
forM_ (notifier qr) $ \NtfCreds {notifierId, rcvNtfDhSecret} ->
|
||||
mapM_ (writeNtf notifierId msg rcvNtfDhSecret ntfNonceDrg) =<< TM.lookup notifierId notifiers
|
||||
trySendNotification :: Message -> TVar ChaChaDRG -> STM ()
|
||||
trySendNotification msg ntfNonceDrg =
|
||||
forM_ (notifier qr) $ \NtfCreds {notifierId, rcvNtfDhSecret} ->
|
||||
mapM_ (writeNtf notifierId msg rcvNtfDhSecret ntfNonceDrg) =<< TM.lookup notifierId notifiers
|
||||
|
||||
writeNtf :: NotifierId -> Message -> RcvNtfDhSecret -> TVar ChaChaDRG -> Client -> STM ()
|
||||
writeNtf nId msg rcvNtfDhSecret ntfNonceDrg Client {sndQ = q} =
|
||||
unlessM (isFullTBQueue sndQ) $ do
|
||||
(nmsgNonce, encNMsgMeta) <- mkMessageNotification msg rcvNtfDhSecret ntfNonceDrg
|
||||
writeTBQueue q (CorrId "", nId, NMSG nmsgNonce encNMsgMeta)
|
||||
writeNtf :: NotifierId -> Message -> RcvNtfDhSecret -> TVar ChaChaDRG -> Client -> STM ()
|
||||
writeNtf nId msg rcvNtfDhSecret ntfNonceDrg Client {sndQ = q} =
|
||||
unlessM (isFullTBQueue sndQ) $ do
|
||||
(nmsgNonce, encNMsgMeta) <- mkMessageNotification msg rcvNtfDhSecret ntfNonceDrg
|
||||
writeTBQueue q [(CorrId "", nId, NMSG nmsgNonce encNMsgMeta)]
|
||||
|
||||
mkMessageNotification :: Message -> RcvNtfDhSecret -> TVar ChaChaDRG -> STM (C.CbNonce, EncNMsgMeta)
|
||||
mkMessageNotification Message {msgId, msgTs} rcvNtfDhSecret ntfNonceDrg = do
|
||||
cbNonce <- C.pseudoRandomCbNonce ntfNonceDrg
|
||||
let msgMeta = NMsgMeta {msgId, msgTs}
|
||||
encNMsgMeta = C.cbEncrypt rcvNtfDhSecret cbNonce (smpEncode msgMeta) 128
|
||||
pure . (cbNonce,) $ fromRight "" encNMsgMeta
|
||||
mkMessageNotification :: Message -> RcvNtfDhSecret -> TVar ChaChaDRG -> STM (C.CbNonce, EncNMsgMeta)
|
||||
mkMessageNotification Message {msgId, msgTs} rcvNtfDhSecret ntfNonceDrg = do
|
||||
cbNonce <- C.pseudoRandomCbNonce ntfNonceDrg
|
||||
let msgMeta = NMsgMeta {msgId, msgTs}
|
||||
encNMsgMeta = C.cbEncrypt rcvNtfDhSecret cbNonce (smpEncode msgMeta) 128
|
||||
pure . (cbNonce,) $ fromRight "" encNMsgMeta
|
||||
|
||||
deliverMessage :: QueueStore -> RecipientId -> TVar Sub -> MsgQueue -> Maybe Message -> m (Transmission BrokerMsg)
|
||||
deliverMessage st rId sub q msg_ = withRcvQueue st rId $ \qr -> do
|
||||
deliverMessage :: QueueRec -> RecipientId -> TVar Sub -> MsgQueue -> Maybe Message -> m (Transmission BrokerMsg)
|
||||
deliverMessage qr rId sub q msg_ = do
|
||||
readTVarIO sub >>= \case
|
||||
s@Sub {subThread = NoSub} ->
|
||||
case msg_ of
|
||||
Just msg ->
|
||||
let encMsg = encryptMsg qr msg
|
||||
in atomically (setDelivered s msg) $> (corrId, rId, MSG encMsg)
|
||||
_ -> forkSub qr $> ok
|
||||
_ -> forkSub $> ok
|
||||
_ -> pure ok
|
||||
where
|
||||
forkSub :: QueueRec -> m ()
|
||||
forkSub qr = do
|
||||
forkSub :: m ()
|
||||
forkSub = do
|
||||
atomically . modifyTVar sub $ \s -> s {subThread = SubPending}
|
||||
t <- mkWeakThreadId =<< forkIO subscriber
|
||||
atomically . modifyTVar sub $ \case
|
||||
|
@ -610,7 +620,7 @@ client clnt@Client {thVersion, subscriptions, ntfSubscriptions, rcvQ, sndQ} Serv
|
|||
subscriber = atomically $ do
|
||||
msg <- peekMsg q
|
||||
let encMsg = encryptMsg qr msg
|
||||
writeTBQueue sndQ (CorrId "", rId, MSG encMsg)
|
||||
writeTBQueue sndQ [(CorrId "", rId, MSG encMsg)]
|
||||
s <- readTVar sub
|
||||
void $ setDelivered s msg
|
||||
writeTVar sub s {subThread = NoSub}
|
||||
|
|
|
@ -9,6 +9,7 @@ import Control.Concurrent (ThreadId)
|
|||
import Control.Monad.IO.Unlift
|
||||
import Crypto.Random
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import Data.List.NonEmpty (NonEmpty)
|
||||
import Data.Map.Strict (Map)
|
||||
import qualified Data.Map.Strict as M
|
||||
import Data.Time.Clock (getCurrentTime)
|
||||
|
@ -103,8 +104,8 @@ data Server = Server
|
|||
data Client = Client
|
||||
{ subscriptions :: TMap RecipientId (TVar Sub),
|
||||
ntfSubscriptions :: TMap NotifierId (),
|
||||
rcvQ :: TBQueue (Transmission Cmd),
|
||||
sndQ :: TBQueue (Transmission BrokerMsg),
|
||||
rcvQ :: TBQueue (NonEmpty (Maybe QueueRec, Transmission Cmd)),
|
||||
sndQ :: TBQueue (NonEmpty (Transmission BrokerMsg)),
|
||||
thVersion :: Version,
|
||||
sessionId :: ByteString,
|
||||
connected :: TVar Bool,
|
||||
|
|
|
@ -96,7 +96,7 @@ smpBlockSize :: Int
|
|||
smpBlockSize = 16384
|
||||
|
||||
supportedSMPServerVRange :: VersionRange
|
||||
supportedSMPServerVRange = mkVersionRange 1 3
|
||||
supportedSMPServerVRange = mkVersionRange 1 4
|
||||
|
||||
simplexMQVersion :: String
|
||||
simplexMQVersion = "3.0.1"
|
||||
|
@ -258,7 +258,10 @@ data THandle c = THandle
|
|||
sessionId :: SessionId,
|
||||
blockSize :: Int,
|
||||
-- | agreed server protocol version
|
||||
thVersion :: Version
|
||||
thVersion :: Version,
|
||||
-- | send multiple transmissions in a single block
|
||||
-- based on protocol and protocol version
|
||||
batch :: Bool
|
||||
}
|
||||
|
||||
-- | TLS-unique channel binding
|
||||
|
@ -364,7 +367,7 @@ smpServerHandshake c kh smpVRange = do
|
|||
| keyHash /= kh ->
|
||||
throwE $ TEHandshake IDENTITY
|
||||
| smpVersion `isCompatible` smpVRange -> do
|
||||
pure (th :: THandle c) {thVersion = smpVersion}
|
||||
pure $ smpThHandle th smpVersion
|
||||
| otherwise -> throwE $ TEHandshake VERSION
|
||||
|
||||
-- | Client SMP transport handshake.
|
||||
|
@ -379,9 +382,12 @@ smpClientHandshake c keyHash smpVRange = do
|
|||
else case smpVersionRange `compatibleVersion` smpVRange of
|
||||
Just (Compatible smpVersion) -> do
|
||||
sendHandshake th $ ClientHandshake {smpVersion, keyHash}
|
||||
pure (th :: THandle c) {thVersion = smpVersion}
|
||||
pure $ smpThHandle th smpVersion
|
||||
Nothing -> throwE $ TEHandshake VERSION
|
||||
|
||||
smpThHandle :: forall c. THandle c -> Version -> THandle c
|
||||
smpThHandle th v = (th :: THandle c) {thVersion = v, batch = v >= 4}
|
||||
|
||||
sendHandshake :: (Transport c, Encoding smp) => THandle c -> smp -> ExceptT TransportError IO ()
|
||||
sendHandshake th = ExceptT . tPutBlock th . smpEncode
|
||||
|
||||
|
@ -389,4 +395,4 @@ getHandshake :: (Transport c, Encoding smp) => THandle c -> ExceptT TransportErr
|
|||
getHandshake th = ExceptT $ (parse smpP (TEHandshake PARSE) =<<) <$> tGetBlock th
|
||||
|
||||
smpTHandle :: Transport c => c -> THandle c
|
||||
smpTHandle c = THandle {connection = c, sessionId = tlsUnique c, blockSize = smpBlockSize, thVersion = 0}
|
||||
smpTHandle c = THandle {connection = c, sessionId = tlsUnique c, blockSize = smpBlockSize, thVersion = 0, batch = False}
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
|
||||
|
|
|
@ -17,12 +17,16 @@ module AgentTests.FunctionalAPITests
|
|||
)
|
||||
where
|
||||
|
||||
import Control.Concurrent (threadDelay)
|
||||
import Control.Concurrent (killThread, threadDelay)
|
||||
import Control.Monad
|
||||
import Control.Monad.Except (ExceptT, runExceptT)
|
||||
import Control.Monad.IO.Unlift
|
||||
import Data.Int (Int64)
|
||||
import qualified Data.Map as M
|
||||
import qualified Data.Set as S
|
||||
import Data.Time.Clock.System (SystemTime (..), getSystemTime)
|
||||
import SMPAgentClient
|
||||
import SMPClient (cfg, testPort, withSmpServer, withSmpServerConfigOn, withSmpServerStoreLogOn, withSmpServerStoreMsgLogOn)
|
||||
import SMPClient (cfg, testPort, testPort2, testStoreLogFile2, withSmpServer, withSmpServerConfigOn, withSmpServerStoreLogOn, withSmpServerStoreMsgLogOn)
|
||||
import Simplex.Messaging.Agent
|
||||
import Simplex.Messaging.Agent.Env.SQLite (AgentConfig (..))
|
||||
import Simplex.Messaging.Agent.Protocol
|
||||
|
@ -102,6 +106,9 @@ functionalAPITests t = do
|
|||
testSuspendingAgentCompleteSending t
|
||||
it "should suspend agent on timeout, even if pending messages not sent" $
|
||||
testSuspendingAgentTimeout t
|
||||
describe "Batching SMP commands" $ do
|
||||
it "should subscribe to multiple subscriptions with batching" $
|
||||
testBatchedSubscriptions t
|
||||
|
||||
testAgentClient :: IO ()
|
||||
testAgentClient = do
|
||||
|
@ -503,13 +510,64 @@ testSuspendingAgentTimeout t = do
|
|||
|
||||
pure ()
|
||||
|
||||
testBatchedSubscriptions :: ATransport -> IO ()
|
||||
testBatchedSubscriptions t = do
|
||||
a <- getSMPAgentClient agentCfg initAgentServers2
|
||||
b <- getSMPAgentClient agentCfg {dbFile = testDB2} initAgentServers2
|
||||
Right conns <- runServers $ do
|
||||
conns <- forM [1 .. 200 :: Int] . const $ makeConnection a b
|
||||
forM_ conns $ \(aId, bId) -> exchangeGreetings a bId b aId
|
||||
forM_ (take 10 conns) $ \(aId, bId) -> do
|
||||
deleteConnection a bId
|
||||
deleteConnection b aId
|
||||
liftIO $ threadDelay 1000000
|
||||
pure conns
|
||||
("", "", DOWN {}) <- get a
|
||||
("", "", DOWN {}) <- get a
|
||||
("", "", DOWN {}) <- get b
|
||||
("", "", DOWN {}) <- get b
|
||||
Right () <- runServers $ do
|
||||
("", "", UP {}) <- get a
|
||||
("", "", UP {}) <- get a
|
||||
("", "", UP {}) <- get b
|
||||
("", "", UP {}) <- get b
|
||||
liftIO $ threadDelay 1000000
|
||||
subscribe a $ map snd conns
|
||||
subscribe b $ map fst conns
|
||||
forM_ (drop 10 conns) $ \(aId, bId) -> exchangeGreetingsMsgId 6 a bId b aId
|
||||
pure ()
|
||||
where
|
||||
subscribe :: AgentClient -> [ConnId] -> ExceptT AgentErrorType IO ()
|
||||
subscribe c cs = do
|
||||
r <- subscribeConnections c cs
|
||||
liftIO $ do
|
||||
let dc = S.fromList $ take 10 cs
|
||||
all (== Right ()) (M.withoutKeys r dc) `shouldBe` True
|
||||
all (== Left (CONN NOT_FOUND)) (M.restrictKeys r dc) `shouldBe` True
|
||||
M.keys r `shouldMatchList` cs
|
||||
runServers :: ExceptT AgentErrorType IO a -> IO (Either AgentErrorType a)
|
||||
runServers a = do
|
||||
withSmpServerStoreLogOn t testPort $ \t1 -> do
|
||||
res <- withSmpServerConfigOn t cfg {storeLogFile = Just testStoreLogFile2} testPort2 $ \t2 -> do
|
||||
res <- runExceptT a
|
||||
killThread t2
|
||||
pure res
|
||||
killThread t1
|
||||
pure res
|
||||
|
||||
exchangeGreetings :: AgentClient -> ConnId -> AgentClient -> ConnId -> ExceptT AgentErrorType IO ()
|
||||
exchangeGreetings alice bobId bob aliceId = do
|
||||
4 <- sendMessage alice bobId SMP.noMsgFlags "hello"
|
||||
get alice ##> ("", bobId, SENT 4)
|
||||
exchangeGreetings = exchangeGreetingsMsgId 4
|
||||
|
||||
exchangeGreetingsMsgId :: Int64 -> AgentClient -> ConnId -> AgentClient -> ConnId -> ExceptT AgentErrorType IO ()
|
||||
exchangeGreetingsMsgId msgId alice bobId bob aliceId = do
|
||||
msgId1 <- sendMessage alice bobId SMP.noMsgFlags "hello"
|
||||
liftIO $ msgId1 `shouldBe` msgId
|
||||
get alice ##> ("", bobId, SENT msgId)
|
||||
get bob =##> \case ("", c, Msg "hello") -> c == aliceId; _ -> False
|
||||
ackMessage bob aliceId 4
|
||||
5 <- sendMessage bob aliceId SMP.noMsgFlags "hello too"
|
||||
get bob ##> ("", aliceId, SENT 5)
|
||||
ackMessage bob aliceId msgId
|
||||
msgId2 <- sendMessage bob aliceId SMP.noMsgFlags "hello too"
|
||||
let msgId' = msgId + 1
|
||||
liftIO $ msgId2 `shouldBe` msgId'
|
||||
get bob ##> ("", aliceId, SENT msgId')
|
||||
get alice =##> \case ("", c, Msg "hello too") -> c == bobId; _ -> False
|
||||
ackMessage alice bobId 5
|
||||
ackMessage alice bobId msgId'
|
||||
|
|
|
@ -22,7 +22,6 @@ import NtfClient
|
|||
import SMPAgentClient (agentCfg, initAgentServers, testDB, testDB2)
|
||||
import SMPClient (testPort, withSmpServer, withSmpServerStoreLogOn)
|
||||
import Simplex.Messaging.Agent
|
||||
import Simplex.Messaging.Agent.Client (AgentClient)
|
||||
import Simplex.Messaging.Agent.Env.SQLite (AgentConfig (..))
|
||||
import Simplex.Messaging.Agent.Protocol
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
|
|
|
@ -6,6 +6,7 @@
|
|||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
{-# LANGUAGE NumericUnderscores #-}
|
||||
{-# LANGUAGE OverloadedLists #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-# LANGUAGE StandaloneDeriving #-}
|
||||
|
@ -140,10 +141,10 @@ ntfServerTest _ t = runNtfTest $ \h -> tPut' h t >> tGet' h
|
|||
where
|
||||
tPut' h (sig, corrId, queueId, smp) = do
|
||||
let t' = smpEncode (sessionId (h :: THandle c), corrId, queueId, smp)
|
||||
Right () <- tPut h (sig, t')
|
||||
[Right ()] <- tPut h [(sig, t')]
|
||||
pure ()
|
||||
tGet' h = do
|
||||
(Nothing, _, (CorrId corrId, qId, Right cmd)) <- tGet h
|
||||
[(Nothing, _, (CorrId corrId, qId, Right cmd))] <- tGet h
|
||||
pure (Nothing, corrId, qId, cmd)
|
||||
|
||||
ntfTest :: Transport c => TProxy c -> (THandle c -> IO ()) -> Expectation
|
||||
|
|
|
@ -25,6 +25,8 @@ import ServerTests
|
|||
samplePubKey,
|
||||
sampleSig,
|
||||
signSendRecv,
|
||||
tGet1,
|
||||
tPut1,
|
||||
(#==),
|
||||
_SEND',
|
||||
pattern Resp,
|
||||
|
@ -69,15 +71,15 @@ pattern RespNtf corrId queueId command <- (_, _, (corrId, queueId, Right command
|
|||
sendRecvNtf :: forall c e. (Transport c, NtfEntityI e) => THandle c -> (Maybe C.ASignature, ByteString, ByteString, NtfCommand e) -> IO (SignedTransmission NtfResponse)
|
||||
sendRecvNtf h@THandle {thVersion, sessionId} (sgn, corrId, qId, cmd) = do
|
||||
let t = encodeTransmission thVersion sessionId (CorrId corrId, qId, cmd)
|
||||
Right () <- tPut h (sgn, t)
|
||||
tGet h
|
||||
Right () <- tPut1 h (sgn, t)
|
||||
tGet1 h
|
||||
|
||||
signSendRecvNtf :: forall c e. (Transport c, NtfEntityI e) => THandle c -> C.APrivateSignKey -> (ByteString, ByteString, NtfCommand e) -> IO (SignedTransmission NtfResponse)
|
||||
signSendRecvNtf h@THandle {thVersion, sessionId} pk (corrId, qId, cmd) = do
|
||||
let t = encodeTransmission thVersion sessionId (CorrId corrId, qId, cmd)
|
||||
Right sig <- runExceptT $ C.sign pk t
|
||||
Right () <- tPut h (Just sig, t)
|
||||
tGet h
|
||||
Right () <- tPut1 h (Just sig, t)
|
||||
tGet1 h
|
||||
|
||||
(.->) :: J.Value -> J.Key -> Either String ByteString
|
||||
v .-> key =
|
||||
|
@ -132,7 +134,7 @@ testNotificationSubscription (ATransport t) =
|
|||
notifierId `shouldBe` nId
|
||||
send' APNSRespOk
|
||||
-- receive message
|
||||
Resp "" _ (MSG RcvMessage {msgId = mId1, msgBody = EncRcvMsgBody body}) <- tGet rh
|
||||
Resp "" _ (MSG RcvMessage {msgId = mId1, msgBody = EncRcvMsgBody body}) <- tGet1 rh
|
||||
Right ClientRcvMsgBody {msgTs = mTs, msgBody} <- pure $ parseAll clientRcvMsgBodyP =<< first show (C.cbDecrypt rcvDhSecret (C.cbNonce mId1) body)
|
||||
mId1 `shouldBe` msgId
|
||||
mTs `shouldBe` msgTs
|
||||
|
|
|
@ -158,6 +158,9 @@ smpAgentTest1_1_1 test' =
|
|||
testSMPServer :: SMPServer
|
||||
testSMPServer = "smp://LcJUMfVhwD8yxjAiSaDzzGF3-kLG4Uh0Fl_ZIjrRwjI=@localhost:5001"
|
||||
|
||||
testSMPServer2 :: SMPServer
|
||||
testSMPServer2 = "smp://LcJUMfVhwD8yxjAiSaDzzGF3-kLG4Uh0Fl_ZIjrRwjI=@localhost:5002"
|
||||
|
||||
initAgentServers :: InitialAgentServers
|
||||
initAgentServers =
|
||||
InitialAgentServers
|
||||
|
@ -165,6 +168,9 @@ initAgentServers =
|
|||
ntf = ["ntf://LcJUMfVhwD8yxjAiSaDzzGF3-kLG4Uh0Fl_ZIjrRwjI=@localhost:6001"]
|
||||
}
|
||||
|
||||
initAgentServers2 :: InitialAgentServers
|
||||
initAgentServers2 = initAgentServers {smp = L.fromList [testSMPServer, testSMPServer2]}
|
||||
|
||||
agentCfg :: AgentConfig
|
||||
agentCfg =
|
||||
defaultAgentConfig
|
||||
|
|
|
@ -3,6 +3,7 @@
|
|||
{-# LANGUAGE FlexibleContexts #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE NumericUnderscores #-}
|
||||
{-# LANGUAGE OverloadedLists #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-# LANGUAGE TypeApplications #-}
|
||||
|
@ -44,6 +45,9 @@ testKeyHash = "LcJUMfVhwD8yxjAiSaDzzGF3-kLG4Uh0Fl_ZIjrRwjI="
|
|||
testStoreLogFile :: FilePath
|
||||
testStoreLogFile = "tests/tmp/smp-server-store.log"
|
||||
|
||||
testStoreLogFile2 :: FilePath
|
||||
testStoreLogFile2 = "tests/tmp/smp-server-store.log.2"
|
||||
|
||||
testStoreMsgsFile :: FilePath
|
||||
testStoreMsgsFile = "tests/tmp/smp-server-messages.log"
|
||||
|
||||
|
@ -140,10 +144,10 @@ smpServerTest _ t = runSmpTest $ \h -> tPut' h t >> tGet' h
|
|||
where
|
||||
tPut' h (sig, corrId, queueId, smp) = do
|
||||
let t' = smpEncode (sessionId (h :: THandle c), corrId, queueId, smp)
|
||||
Right () <- tPut h (sig, t')
|
||||
[Right ()] <- tPut h [(sig, t')]
|
||||
pure ()
|
||||
tGet' h = do
|
||||
(Nothing, _, (CorrId corrId, qId, Right cmd)) <- tGet h
|
||||
[(Nothing, _, (CorrId corrId, qId, Right cmd))] <- tGet h
|
||||
pure (Nothing, corrId, qId, cmd)
|
||||
|
||||
smpTest :: Transport c => TProxy c -> (THandle c -> IO ()) -> Expectation
|
||||
|
|
|
@ -3,6 +3,7 @@
|
|||
{-# LANGUAGE GADTs #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
{-# LANGUAGE OverloadedLists #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE PatternSynonyms #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
|
@ -14,6 +15,7 @@ import Control.Concurrent (ThreadId, killThread, threadDelay)
|
|||
import Control.Concurrent.STM
|
||||
import Control.Exception (SomeException, try)
|
||||
import Control.Monad.Except (forM, forM_, runExceptT)
|
||||
import Control.Monad.IO.Class
|
||||
import Data.Bifunctor (first)
|
||||
import Data.ByteString.Base64
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
|
@ -69,15 +71,25 @@ pattern Msg msgId body <- MSG RcvMessage {msgId, msgBody = EncRcvMsgBody body}
|
|||
sendRecv :: forall c p. (Transport c, PartyI p) => THandle c -> (Maybe C.ASignature, ByteString, ByteString, Command p) -> IO (SignedTransmission BrokerMsg)
|
||||
sendRecv h@THandle {thVersion, sessionId} (sgn, corrId, qId, cmd) = do
|
||||
let t = encodeTransmission thVersion sessionId (CorrId corrId, qId, cmd)
|
||||
Right () <- tPut h (sgn, t)
|
||||
tGet h
|
||||
Right () <- tPut1 h (sgn, t)
|
||||
tGet1 h
|
||||
|
||||
signSendRecv :: forall c p. (Transport c, PartyI p) => THandle c -> C.APrivateSignKey -> (ByteString, ByteString, Command p) -> IO (SignedTransmission BrokerMsg)
|
||||
signSendRecv h@THandle {thVersion, sessionId} pk (corrId, qId, cmd) = do
|
||||
let t = encodeTransmission thVersion sessionId (CorrId corrId, qId, cmd)
|
||||
Right sig <- runExceptT $ C.sign pk t
|
||||
Right () <- tPut h (Just sig, t)
|
||||
tGet h
|
||||
Right () <- tPut1 h (Just sig, t)
|
||||
tGet1 h
|
||||
|
||||
tPut1 :: Transport c => THandle c -> SentRawTransmission -> IO (Either TransportError ())
|
||||
tPut1 h t = do
|
||||
[r] <- tPut h [t]
|
||||
pure r
|
||||
|
||||
tGet1 :: (ProtocolEncoding cmd, Transport c, MonadIO m, MonadFail m) => THandle c -> m (SignedTransmission cmd)
|
||||
tGet1 h = do
|
||||
[r] <- tGet h
|
||||
pure r
|
||||
|
||||
(#==) :: (HasCallStack, Eq a, Show a) => (a, a) -> String -> Assertion
|
||||
(actual, expected) #== message = assertEqual message expected actual
|
||||
|
@ -110,7 +122,7 @@ testCreateSecureV2 _ =
|
|||
(ok1, OK) #== "accepts unsigned SEND"
|
||||
(sId1, sId) #== "same queue ID in response 1"
|
||||
|
||||
Resp "" _ (Msg mId1 msg1) <- tGet h
|
||||
Resp "" _ (Msg mId1 msg1) <- tGet1 h
|
||||
(dec mId1 msg1, Right "hello") #== "delivers message"
|
||||
|
||||
Resp "cdab" _ ok4 <- signSendRecv h rKey ("cdab", rId, ACK mId1)
|
||||
|
@ -140,7 +152,7 @@ testCreateSecureV2 _ =
|
|||
Resp "bcda" _ ok3 <- signSendRecv h sKey ("bcda", sId, _SEND "hello again")
|
||||
(ok3, OK) #== "accepts signed SEND"
|
||||
|
||||
Resp "" _ (Msg mId2 msg2) <- tGet h
|
||||
Resp "" _ (Msg mId2 msg2) <- tGet1 h
|
||||
(dec mId2 msg2, Right "hello again") #== "delivers message 2"
|
||||
|
||||
Resp "cdab" _ ok5 <- signSendRecv h rKey ("cdab", rId, ACK mId2)
|
||||
|
@ -151,7 +163,7 @@ testCreateSecureV2 _ =
|
|||
|
||||
let maxAllowedMessage = B.replicate maxMessageLength '-'
|
||||
Resp "bcda" _ OK <- signSendRecv h sKey ("bcda", sId, _SEND maxAllowedMessage)
|
||||
Resp "" _ (Msg mId3 msg3) <- tGet h
|
||||
Resp "" _ (Msg mId3 msg3) <- tGet1 h
|
||||
(dec mId3 msg3, Right maxAllowedMessage) #== "delivers message of max size"
|
||||
|
||||
let biggerMessage = B.replicate (maxMessageLength + 1) '-'
|
||||
|
@ -172,7 +184,7 @@ testCreateSecure (ATransport t) =
|
|||
(ok1, OK) #== "accepts unsigned SEND"
|
||||
(sId1, sId) #== "same queue ID in response 1"
|
||||
|
||||
Resp "" _ (Msg mId1 msg1) <- tGet h
|
||||
Resp "" _ (Msg mId1 msg1) <- tGet1 h
|
||||
(dec mId1 msg1, Right "hello") #== "delivers message"
|
||||
|
||||
Resp "cdab" _ ok4 <- signSendRecv h rKey ("cdab", rId, ACK mId1)
|
||||
|
@ -202,7 +214,7 @@ testCreateSecure (ATransport t) =
|
|||
Resp "bcda" _ ok3 <- signSendRecv h sKey ("bcda", sId, _SEND "hello again")
|
||||
(ok3, OK) #== "accepts signed SEND"
|
||||
|
||||
Resp "" _ (Msg mId2 msg2) <- tGet h
|
||||
Resp "" _ (Msg mId2 msg2) <- tGet1 h
|
||||
(dec mId2 msg2, Right "hello again") #== "delivers message 2"
|
||||
|
||||
Resp "cdab" _ ok5 <- signSendRecv h rKey ("cdab", rId, ACK mId2)
|
||||
|
@ -213,7 +225,7 @@ testCreateSecure (ATransport t) =
|
|||
|
||||
let maxAllowedMessage = B.replicate maxMessageLength '-'
|
||||
Resp "bcda" _ OK <- signSendRecv h sKey ("bcda", sId, _SEND maxAllowedMessage)
|
||||
Resp "" _ (Msg mId3 msg3) <- tGet h
|
||||
Resp "" _ (Msg mId3 msg3) <- tGet1 h
|
||||
(dec mId3 msg3, Right maxAllowedMessage) #== "delivers message of max size"
|
||||
|
||||
let biggerMessage = B.replicate (maxMessageLength + 1) '-'
|
||||
|
@ -240,7 +252,7 @@ testCreateDelete (ATransport t) =
|
|||
Resp "dabc" _ ok7 <- signSendRecv sh sKey ("dabc", sId, _SEND "hello 2")
|
||||
(ok7, OK) #== "accepts signed SEND 2 - this message is not delivered because the first is not ACKed"
|
||||
|
||||
Resp "" _ (Msg mId1 msg1) <- tGet rh
|
||||
Resp "" _ (Msg mId1 msg1) <- tGet1 rh
|
||||
(dec mId1 msg1, Right "hello") #== "delivers message"
|
||||
|
||||
Resp "abcd" _ err1 <- sendRecv rh (sampleSig, "abcd", rId, OFF)
|
||||
|
@ -296,7 +308,7 @@ stressTest (ATransport t) =
|
|||
smpTest3 t $ \h1 h2 h3 -> do
|
||||
(rPub, rKey) <- C.generateSignatureKeyPair C.SEd25519
|
||||
(dhPub, _ :: C.PrivateKeyX25519) <- C.generateKeyPair'
|
||||
rIds <- forM [1 .. 50 :: Int] . const $ do
|
||||
rIds <- forM ([1 .. 50] :: [Int]) . const $ do
|
||||
Resp "" "" (Ids rId _ _) <- signSendRecv h1 rKey ("", "", NEW rPub dhPub)
|
||||
pure rId
|
||||
let subscribeQueues h = forM_ rIds $ \rId -> do
|
||||
|
@ -331,7 +343,7 @@ testDuplex (ATransport t) =
|
|||
Resp "bcda" _ OK <- sendRecv bob ("", "bcda", aSnd, _SEND $ "key " <> strEncode bsPub)
|
||||
-- "key ..." is ad-hoc, not a part of SMP protocol
|
||||
|
||||
Resp "" _ (Msg mId1 msg1) <- tGet alice
|
||||
Resp "" _ (Msg mId1 msg1) <- tGet1 alice
|
||||
Resp "cdab" _ OK <- signSendRecv alice arKey ("cdab", aRcv, ACK mId1)
|
||||
Right ["key", bobKey] <- pure $ B.words <$> aDec mId1 msg1
|
||||
(bobKey, strEncode bsPub) #== "key received from Bob"
|
||||
|
@ -344,7 +356,7 @@ testDuplex (ATransport t) =
|
|||
Resp "bcda" _ OK <- signSendRecv bob bsKey ("bcda", aSnd, _SEND $ "reply_id " <> encode bSnd)
|
||||
-- "reply_id ..." is ad-hoc, not a part of SMP protocol
|
||||
|
||||
Resp "" _ (Msg mId2 msg2) <- tGet alice
|
||||
Resp "" _ (Msg mId2 msg2) <- tGet1 alice
|
||||
Resp "cdab" _ OK <- signSendRecv alice arKey ("cdab", aRcv, ACK mId2)
|
||||
Right ["reply_id", bId] <- pure $ B.words <$> aDec mId2 msg2
|
||||
(bId, encode bSnd) #== "reply queue ID received from Bob"
|
||||
|
@ -353,7 +365,7 @@ testDuplex (ATransport t) =
|
|||
Resp "dabc" _ OK <- sendRecv alice ("", "dabc", bSnd, _SEND $ "key " <> strEncode asPub)
|
||||
-- "key ..." is ad-hoc, not a part of SMP protocol
|
||||
|
||||
Resp "" _ (Msg mId3 msg3) <- tGet bob
|
||||
Resp "" _ (Msg mId3 msg3) <- tGet1 bob
|
||||
Resp "abcd" _ OK <- signSendRecv bob brKey ("abcd", bRcv, ACK mId3)
|
||||
Right ["key", aliceKey] <- pure $ B.words <$> bDec mId3 msg3
|
||||
(aliceKey, strEncode asPub) #== "key received from Alice"
|
||||
|
@ -361,13 +373,13 @@ testDuplex (ATransport t) =
|
|||
|
||||
Resp "cdab" _ OK <- signSendRecv bob bsKey ("cdab", aSnd, _SEND "hi alice")
|
||||
|
||||
Resp "" _ (Msg mId4 msg4) <- tGet alice
|
||||
Resp "" _ (Msg mId4 msg4) <- tGet1 alice
|
||||
Resp "dabc" _ OK <- signSendRecv alice arKey ("dabc", aRcv, ACK mId4)
|
||||
(aDec mId4 msg4, Right "hi alice") #== "message received from Bob"
|
||||
|
||||
Resp "abcd" _ OK <- signSendRecv alice asKey ("abcd", bSnd, _SEND "how are you bob")
|
||||
|
||||
Resp "" _ (Msg mId5 msg5) <- tGet bob
|
||||
Resp "" _ (Msg mId5 msg5) <- tGet1 bob
|
||||
Resp "bcda" _ OK <- signSendRecv bob brKey ("bcda", bRcv, ACK mId5)
|
||||
(bDec mId5 msg5, Right "how are you bob") #== "message received from alice"
|
||||
|
||||
|
@ -384,7 +396,7 @@ testSwitchSub (ATransport t) =
|
|||
Resp "cdab" _ ok2 <- sendRecv sh ("", "cdab", sId, _SEND "test2, no ACK")
|
||||
(ok2, OK) #== "sent test message 2"
|
||||
|
||||
Resp "" _ (Msg mId1 msg1) <- tGet rh1
|
||||
Resp "" _ (Msg mId1 msg1) <- tGet1 rh1
|
||||
(dec mId1 msg1, Right "test1") #== "test message 1 delivered to the 1st TCP connection"
|
||||
Resp "abcd" _ (Msg mId2 msg2) <- signSendRecv rh1 rKey ("abcd", rId, ACK mId1)
|
||||
(dec mId2 msg2, Right "test2, no ACK") #== "test message 2 delivered, no ACK"
|
||||
|
@ -393,12 +405,12 @@ testSwitchSub (ATransport t) =
|
|||
(dec mId2' msg2', Right "test2, no ACK") #== "same simplex queue via another TCP connection, tes2 delivered again (no ACK in 1st queue)"
|
||||
Resp "cdab" _ OK <- signSendRecv rh2 rKey ("cdab", rId, ACK mId2')
|
||||
|
||||
Resp "" _ end <- tGet rh1
|
||||
Resp "" _ end <- tGet1 rh1
|
||||
(end, END) #== "unsubscribed the 1st TCP connection"
|
||||
|
||||
Resp "dabc" _ OK <- sendRecv sh ("", "dabc", sId, _SEND "test3")
|
||||
|
||||
Resp "" _ (Msg mId3 msg3) <- tGet rh2
|
||||
Resp "" _ (Msg mId3 msg3) <- tGet1 rh2
|
||||
(dec mId3 msg3, Right "test3") #== "delivered to the 2nd TCP connection"
|
||||
|
||||
Resp "abcd" _ err <- signSendRecv rh1 rKey ("abcd", rId, ACK mId3)
|
||||
|
@ -441,7 +453,7 @@ testGetSubCommands t =
|
|||
Resp "1b" _ OK <- signSendRecv sh sKey ("1b", sId, _SEND "hello 3")
|
||||
Resp "1c" _ OK <- signSendRecv sh sKey ("1c", sId, _SEND "hello 4")
|
||||
-- both get the same if not ACK'd
|
||||
Resp "" _ (Msg mId1 msg1) <- tGet rh1
|
||||
Resp "" _ (Msg mId1 msg1) <- tGet1 rh1
|
||||
Resp "2" _ (Msg mId1' msg1') <- signSendRecv rh2 rKey ("2", rId, GET)
|
||||
(dec mId1 msg1, Right "hello 1") #== "received from queue via SUB"
|
||||
(dec mId1' msg1', Right "hello 1") #== "retrieved from queue with GET"
|
||||
|
@ -503,14 +515,14 @@ testWithStoreLog at@(ATransport t) =
|
|||
writeTVar notifierId nId
|
||||
Resp "dabc" _ OK <- signSendRecv h1 nKey ("dabc", nId, NSUB)
|
||||
Resp "bcda" _ OK <- signSendRecv h sKey1 ("bcda", sId1, _SEND' "hello")
|
||||
Resp "" _ (Msg mId1 msg1) <- tGet h
|
||||
Resp "" _ (Msg mId1 msg1) <- tGet1 h
|
||||
(decryptMsgV3 dhShared mId1 msg1, Right "hello") #== "delivered from queue 1"
|
||||
Resp "" _ (NMSG _ _) <- tGet h1
|
||||
Resp "" _ (NMSG _ _) <- tGet1 h1
|
||||
|
||||
(sId2, rId2, rKey2, dhShared2) <- createAndSecureQueue h sPub2
|
||||
atomically $ writeTVar senderId2 sId2
|
||||
Resp "cdab" _ OK <- signSendRecv h sKey2 ("cdab", sId2, _SEND "hello too")
|
||||
Resp "" _ (Msg mId2 msg2) <- tGet h
|
||||
Resp "" _ (Msg mId2 msg2) <- tGet1 h
|
||||
(decryptMsgV3 dhShared2 mId2 msg2, Right "hello too") #== "delivered from queue 2"
|
||||
|
||||
Resp "dabc" _ OK <- signSendRecv h rKey2 ("dabc", rId2, DEL)
|
||||
|
@ -535,7 +547,7 @@ testWithStoreLog at@(ATransport t) =
|
|||
Resp "bcda" _ OK <- signSendRecv h sKey1 ("bcda", sId1, _SEND' "hello")
|
||||
Resp "cdab" _ (Msg mId3 msg3) <- signSendRecv h rKey1 ("cdab", rId1, SUB)
|
||||
(decryptMsgV3 dh1 mId3 msg3, Right "hello") #== "delivered from restored queue"
|
||||
Resp "" _ (NMSG _ _) <- tGet h1
|
||||
Resp "" _ (NMSG _ _) <- tGet1 h1
|
||||
-- this queue is removed - not restored
|
||||
sId2 <- readTVarIO senderId2
|
||||
Resp "cdab" _ (ERR AUTH) <- signSendRecv h sKey2 ("cdab", sId2, _SEND "hello too")
|
||||
|
@ -576,7 +588,7 @@ testRestoreMessages at@(ATransport t) =
|
|||
writeTVar dhShared $ Just dh
|
||||
writeTVar senderId sId
|
||||
Resp "1" _ OK <- signSendRecv h sKey ("1", sId, _SEND "hello")
|
||||
Resp "" _ (Msg mId1 msg1) <- tGet h1
|
||||
Resp "" _ (Msg mId1 msg1) <- tGet1 h1
|
||||
Resp "1a" _ OK <- signSendRecv h1 rKey ("1a", rId, ACK mId1)
|
||||
(decryptMsgV3 dh mId1 msg1, Right "hello") #== "message delivered"
|
||||
-- messages below are delivered after server restart
|
||||
|
@ -645,7 +657,7 @@ testRestoreMessagesV2 at@(ATransport t) =
|
|||
writeTVar dhShared $ Just dh
|
||||
writeTVar senderId sId
|
||||
Resp "1" _ OK <- signSendRecv h sKey ("1", sId, _SEND "hello")
|
||||
Resp "" _ (Msg mId1 msg1) <- tGet h1
|
||||
Resp "" _ (Msg mId1 msg1) <- tGet1 h1
|
||||
Resp "1a" _ OK <- signSendRecv h1 rKey ("1a", rId, ACK mId1)
|
||||
(decryptMsgV2 dh mId1 msg1, Right "hello") #== "message delivered"
|
||||
-- messages below are delivered after server restart
|
||||
|
@ -710,14 +722,15 @@ testTiming :: ATransport -> Spec
|
|||
testTiming (ATransport t) =
|
||||
it "should have similar time for auth error, whether queue exists or not, for all key sizes" $
|
||||
smpTest2 t $ \rh sh ->
|
||||
mapM_
|
||||
(testSameTiming rh sh)
|
||||
[ (32, 32, 200),
|
||||
(32, 57, 100),
|
||||
(57, 32, 200),
|
||||
(57, 57, 100)
|
||||
]
|
||||
mapM_ (testSameTiming rh sh) timingTests
|
||||
where
|
||||
timingTests :: [(Int, Int, Int)]
|
||||
timingTests =
|
||||
[ (32, 32, 200),
|
||||
(32, 57, 100),
|
||||
(57, 32, 200),
|
||||
(57, 57, 100)
|
||||
]
|
||||
timeRepeat n = fmap fst . timeItT . forM_ (replicate n ()) . const
|
||||
similarTime t1 t2 = abs (t2 / t1 - 1) < 0.25 `shouldBe` True
|
||||
testSameTiming :: Transport c => THandle c -> THandle c -> (Int, Int, Int) -> Expectation
|
||||
|
@ -735,7 +748,7 @@ testTiming (ATransport t) =
|
|||
Resp "dabc" _ OK <- signSendRecv rh rKey ("dabc", rId, KEY sPub)
|
||||
|
||||
Resp "bcda" _ OK <- signSendRecv sh sKey ("bcda", sId, _SEND "hello")
|
||||
Resp "" _ (Msg mId msg) <- tGet rh
|
||||
Resp "" _ (Msg mId msg) <- tGet1 rh
|
||||
(dec mId msg, Right "hello") #== "delivered from queue"
|
||||
|
||||
runTimingTest sh badKey sId $ _SEND "hello"
|
||||
|
@ -774,23 +787,23 @@ testMessageNotifications (ATransport t) =
|
|||
nId' `shouldNotBe` nId
|
||||
Resp "2" _ OK <- signSendRecv nh1 nKey ("2", nId, NSUB)
|
||||
Resp "3" _ OK <- signSendRecv sh sKey ("3", sId, _SEND' "hello")
|
||||
Resp "" _ (Msg mId1 msg1) <- tGet rh
|
||||
Resp "" _ (Msg mId1 msg1) <- tGet1 rh
|
||||
(dec mId1 msg1, Right "hello") #== "delivered from queue"
|
||||
Resp "3a" _ OK <- signSendRecv rh rKey ("3a", rId, ACK mId1)
|
||||
Resp "" _ (NMSG _ _) <- tGet nh1
|
||||
Resp "" _ (NMSG _ _) <- tGet1 nh1
|
||||
Resp "4" _ OK <- signSendRecv nh2 nKey ("4", nId, NSUB)
|
||||
Resp "" _ END <- tGet nh1
|
||||
Resp "" _ END <- tGet1 nh1
|
||||
Resp "5" _ OK <- signSendRecv sh sKey ("5", sId, _SEND' "hello again")
|
||||
Resp "" _ (Msg mId2 msg2) <- tGet rh
|
||||
Resp "" _ (Msg mId2 msg2) <- tGet1 rh
|
||||
Resp "5a" _ OK <- signSendRecv rh rKey ("5a", rId, ACK mId2)
|
||||
(dec mId2 msg2, Right "hello again") #== "delivered from queue again"
|
||||
Resp "" _ (NMSG _ _) <- tGet nh2
|
||||
Resp "" _ (NMSG _ _) <- tGet1 nh2
|
||||
1000 `timeout` tGet @BrokerMsg nh1 >>= \case
|
||||
Nothing -> pure ()
|
||||
Just _ -> error "nothing else should be delivered to the 1st notifier's TCP connection"
|
||||
Resp "6" _ OK <- signSendRecv rh rKey ("6", rId, NDEL)
|
||||
Resp "7" _ OK <- signSendRecv sh sKey ("7", sId, _SEND' "hello there")
|
||||
Resp "" _ (Msg mId3 msg3) <- tGet rh
|
||||
Resp "" _ (Msg mId3 msg3) <- tGet1 rh
|
||||
(dec mId3 msg3, Right "hello there") #== "delivered from queue again"
|
||||
1000 `timeout` tGet @BrokerMsg nh2 >>= \case
|
||||
Nothing -> pure ()
|
||||
|
|
Reference in New Issue