From 02bba01c1661e59b5bf6e443d759ca601ffa1cd8 Mon Sep 17 00:00:00 2001 From: Evgeny Poberezkin <2769109+epoberezkin@users.noreply.github.com> Date: Sat, 13 Aug 2022 11:57:36 +0100 Subject: [PATCH] send host events when server hosts are connected and disconnected (#496) --- src/Simplex/Messaging/Agent.hs | 2 +- src/Simplex/Messaging/Agent/Client.hs | 116 ++++++++++++++-------- src/Simplex/Messaging/Agent/Protocol.hs | 22 +++- src/Simplex/Messaging/Client.hs | 14 +-- src/Simplex/Messaging/Client/Agent.hs | 4 +- src/Simplex/Messaging/Protocol.hs | 15 ++- src/Simplex/Messaging/Transport/Client.hs | 4 + tests/AgentTests.hs | 12 ++- tests/AgentTests/FunctionalAPITests.hs | 7 +- tests/SMPAgentClient.hs | 13 ++- 10 files changed, 145 insertions(+), 64 deletions(-) diff --git a/src/Simplex/Messaging/Agent.hs b/src/Simplex/Messaging/Agent.hs index d80503c..11eddb8 100644 --- a/src/Simplex/Messaging/Agent.hs +++ b/src/Simplex/Messaging/Agent.hs @@ -457,7 +457,7 @@ subscribeConnections' c connIds = do addRcvQueue :: Map SMPServer (Map ConnId (RcvQueue, ConnData)) -> ConnId -> (RcvQueue, ConnData) -> Map SMPServer (Map ConnId (RcvQueue, ConnData)) addRcvQueue m connId rq@(RcvQueue {server}, _) = M.alter (Just . maybe (M.singleton connId rq) (M.insert connId rq)) server m subscribe :: (SMPServer, Map ConnId (RcvQueue, ConnData)) -> m (Map ConnId (Either AgentErrorType ())) - subscribe (srv, qs) = subscribeQueues c srv (M.map fst qs) + subscribe (srv, qs) = snd <$> subscribeQueues c srv (M.map fst qs) sendNtfCreate :: NtfSupervisor -> [Map ConnId (Either AgentErrorType ())] -> m () sendNtfCreate ns rcvRs = forM_ (concatMap M.assocs rcvRs) $ \case diff --git a/src/Simplex/Messaging/Agent/Client.hs b/src/Simplex/Messaging/Agent/Client.hs index 9f727c1..ff0d5cb 100644 --- a/src/Simplex/Messaging/Agent/Client.hs +++ b/src/Simplex/Messaging/Agent/Client.hs @@ -106,10 +106,31 @@ import Simplex.Messaging.Notifications.Client import Simplex.Messaging.Notifications.Protocol import Simplex.Messaging.Notifications.Types import Simplex.Messaging.Parsers (parse) -import Simplex.Messaging.Protocol (BrokerMsg, ErrorType, MsgFlags (..), MsgId, NotifierId, NtfPrivateSignKey, NtfPublicVerifyKey, NtfServer, ProtoServer, ProtocolServer (..), QueueId, QueueIdsKeys (..), RcvMessage (..), RcvNtfPublicDhKey, SMPMsgMeta (..), SndPublicVerifyKey) +import Simplex.Messaging.Protocol + ( AProtocolType (..), + BrokerMsg, + ErrorType, + MsgFlags (..), + MsgId, + NotifierId, + NtfPrivateSignKey, + NtfPublicVerifyKey, + NtfServer, + ProtoServer, + Protocol (..), + ProtocolServer (..), + ProtocolTypeI (..), + QueueId, + QueueIdsKeys (..), + RcvMessage (..), + RcvNtfPublicDhKey, + SMPMsgMeta (..), + SndPublicVerifyKey, + ) import qualified Simplex.Messaging.Protocol as SMP import Simplex.Messaging.TMap (TMap) import qualified Simplex.Messaging.TMap as TM +import Simplex.Messaging.Transport.Client (TransportHost) import Simplex.Messaging.Util import Simplex.Messaging.Version import System.Timeout (timeout) @@ -232,46 +253,47 @@ getSMPServerClient c@AgentClient {active, smpClients, msgQ} srv = do u <- askUnliftIO liftEitherError (protocolClientError SMP) (getProtocolClient srv cfg (Just msgQ) $ clientDisconnected u) - clientDisconnected :: UnliftIO m -> IO () - clientDisconnected u = do - removeClientAndSubs >>= (`forM_` serverDown u) + clientDisconnected :: UnliftIO m -> SMPClient -> IO () + clientDisconnected u client = do + removeClientAndSubs >>= (`forM_` serverDown) logInfo . decodeUtf8 $ "Agent disconnected from " <> showServer srv - - removeClientAndSubs :: IO (Maybe (Map ConnId RcvQueue)) - removeClientAndSubs = atomically $ do - TM.delete srv smpClients - TM.lookupDelete srv (subscrSrvrs c) >>= mapM updateSubs where - updateSubs cVar = do - cs <- readTVar cVar - modifyTVar' (subscrConns c) (`M.withoutKeys` M.keysSet cs) - addPendingSubs cVar cs - pure cs + removeClientAndSubs :: IO (Maybe (Map ConnId RcvQueue)) + removeClientAndSubs = atomically $ do + TM.delete srv smpClients + TM.lookupDelete srv (subscrSrvrs c) >>= mapM updateSubs + where + updateSubs cVar = do + cs <- readTVar cVar + modifyTVar' (subscrConns c) (`M.withoutKeys` M.keysSet cs) + addPendingSubs cVar cs + pure cs - addPendingSubs cVar cs = do - let ps = pendingSubscrSrvrs c - TM.lookup srv ps >>= \case - Just v -> TM.union cs v - _ -> TM.insert srv cVar ps + addPendingSubs cVar cs = do + let ps = pendingSubscrSrvrs c + TM.lookup srv ps >>= \case + Just v -> TM.union cs v + _ -> TM.insert srv cVar ps - serverDown :: UnliftIO m -> Map ConnId RcvQueue -> IO () - serverDown u cs = unless (M.null cs) $ - whenM (readTVarIO active) $ do - let conns = M.keys cs - unless (null conns) . notifySub "" $ DOWN srv conns - atomically $ mapM_ (releaseGetLock c) cs - unliftIO u reconnectServer + serverDown :: Map ConnId RcvQueue -> IO () + serverDown cs = unless (M.null cs) $ + whenM (readTVarIO active) $ do + let conns = M.keys cs + notifySub "" $ hostEvent DISCONNECT client + unless (null conns) . notifySub "" $ DOWN srv conns + atomically $ mapM_ (releaseGetLock c) cs + unliftIO u reconnectServer - reconnectServer :: m () - reconnectServer = do - a <- async tryReconnectClient - atomically $ modifyTVar' (reconnections c) (a :) + reconnectServer :: m () + reconnectServer = do + a <- async tryReconnectClient + atomically $ modifyTVar' (reconnections c) (a :) - tryReconnectClient :: m () - tryReconnectClient = do - ri <- asks $ reconnectInterval . config - withRetryInterval ri $ \loop -> - reconnectClient `catchError` const loop + tryReconnectClient :: m () + tryReconnectClient = do + ri <- asks $ reconnectInterval . config + withRetryInterval ri $ \loop -> + reconnectClient `catchError` const loop reconnectClient :: m () reconnectClient = @@ -281,8 +303,11 @@ getSMPServerClient c@AgentClient {active, smpClients, msgQ} srv = do where resubscribe :: Map ConnId RcvQueue -> m () resubscribe qs = do - (errs, oks) <- M.mapEither id <$> subscribeQueues c srv qs - liftIO . unless (M.null oks) . notifySub "" . UP srv $ M.keys oks + (client_, (errs, oks)) <- second (M.mapEither id) <$> subscribeQueues c srv qs + liftIO $ do + mapM_ (notifySub "" . hostEvent CONNECT) client_ + unless (M.null oks) $ do + notifySub "" . UP srv $ M.keys oks let (tempErrs, finalErrs) = M.partition temporaryAgentError errs liftIO . mapM_ (\(connId, e) -> notifySub connId $ ERR e) $ M.assocs finalErrs mapM_ throwError . listToMaybe $ M.elems tempErrs @@ -303,9 +328,10 @@ getNtfServerClient c@AgentClient {active, ntfClients} srv = do cfg <- atomically . updateClientConfig c =<< asks (ntfCfg . config) liftEitherError (protocolClientError NTF) (getProtocolClient srv cfg Nothing clientDisconnected) - clientDisconnected :: IO () - clientDisconnected = do + clientDisconnected :: NtfClient -> IO () + clientDisconnected client = do atomically $ TM.delete srv ntfClients + atomically $ writeTBQueue (subQ c) ("", "", hostEvent DISCONNECT client) logInfo . decodeUtf8 $ "Agent disconnected from " <> showServer srv getClientVar :: forall a s. ProtocolServer s -> TMap (ProtocolServer s) (TMVar a) -> STM (Either (TMVar a) (TMVar a)) @@ -328,7 +354,7 @@ waitForProtocolClient c clientVar = do newProtocolClient :: forall msg m. - AgentMonad m => + (AgentMonad m, ProtocolTypeI (ProtoType msg)) => AgentClient -> ProtoServer msg -> TMap (ProtoServer msg) (ClientVar msg) -> @@ -344,6 +370,7 @@ newProtocolClient c srv clients connectClient reconnectClient clientVar = tryCon Right client -> do logInfo . decodeUtf8 $ "Agent connected to " <> showServer srv atomically $ putTMVar clientVar r + atomically $ writeTBQueue (subQ c) ("", "", hostEvent CONNECT client) successAction client Left e -> do if temporaryAgentError e @@ -361,6 +388,9 @@ newProtocolClient c srv clients connectClient reconnectClient clientVar = tryCon ri <- asks $ reconnectInterval . config withRetryInterval ri $ \loop -> void $ tryConnectClient (const reconnectClient) loop +hostEvent :: forall msg. ProtocolTypeI (ProtoType msg) => (AProtocolType -> TransportHost -> ACommand 'Agent) -> ProtocolClient msg -> ACommand 'Agent +hostEvent event client = event (AProtocolType $ protocolTypeI @(ProtoType msg)) $ transportHost client + updateClientConfig :: AgentClient -> ProtocolClientConfig -> STM ProtocolClientConfig updateClientConfig AgentClient {useNetworkConfig} cfg = do networkConfig <- readTVar useNetworkConfig @@ -505,14 +535,14 @@ temporaryAgentError = \case _ -> False -- | subscribe multiple queues - all passed queues should be on the same server -subscribeQueues :: AgentMonad m => AgentClient -> SMPServer -> Map ConnId RcvQueue -> m (Map ConnId (Either AgentErrorType ())) +subscribeQueues :: AgentMonad m => AgentClient -> SMPServer -> Map ConnId RcvQueue -> m (Maybe SMPClient, Map ConnId (Either AgentErrorType ())) subscribeQueues c srv qs = do (errs, qs_) <- partitionEithers <$> mapM checkQueue (M.assocs qs) forM_ qs_ $ atomically . uncurry (addPendingSubscription c) . swap case L.nonEmpty qs_ of Just qs' -> do smp_ <- tryError (getSMPServerClient c srv) - M.fromList . (errs <>) <$> case smp_ of + (eitherToMaybe smp_,) . M.fromList . (errs <>) <$> case smp_ of Left e -> pure $ map (second . const $ Left e) qs_ Right smp -> do logServer "-->" c srv (bshow (length qs_) <> " queues") "SUB" @@ -521,7 +551,7 @@ subscribeQueues c srv qs = do liftIO $ zip qs_ . L.toList <$> subscribeSMPQueues smp qs2 forM_ rs' $ \((connId, rq), r) -> liftIO $ processSubResult c rq connId r pure $ map (bimap fst (first $ protocolClientError SMP)) rs' - _ -> pure $ M.fromList errs + _ -> pure $ (Nothing, M.fromList errs) where checkQueue rq@(connId, RcvQueue {rcvId, server}) = do prohibited <- atomically . TM.member (server, rcvId) $ getMsgLocks c diff --git a/src/Simplex/Messaging/Agent/Protocol.hs b/src/Simplex/Messaging/Agent/Protocol.hs index 52f4910..6a19282 100644 --- a/src/Simplex/Messaging/Agent/Protocol.hs +++ b/src/Simplex/Messaging/Agent/Protocol.hs @@ -40,6 +40,7 @@ module Simplex.Messaging.Agent.Protocol -- * SMP agent protocol types ConnInfo, ACommand (..), + ACmd (..), AParty (..), SAParty (..), MsgHash, @@ -141,7 +142,8 @@ import Simplex.Messaging.Encoding import Simplex.Messaging.Encoding.String import Simplex.Messaging.Parsers import Simplex.Messaging.Protocol - ( ErrorType, + ( AProtocolType, + ErrorType, MsgBody, MsgFlags, MsgId, @@ -228,6 +230,8 @@ data ACommand (p :: AParty) where CON :: ACommand Agent -- notification that connection is established SUB :: ACommand Client END :: ACommand Agent + CONNECT :: AProtocolType -> TransportHost -> ACommand Agent + DISCONNECT :: AProtocolType -> TransportHost -> ACommand Agent DOWN :: SMPServer -> [ConnId] -> ACommand Agent UP :: SMPServer -> [ConnId] -> ACommand Agent SEND :: MsgFlags -> MsgBody -> ACommand Client @@ -929,8 +933,10 @@ commandP = <|> "INFO " *> infoCmd <|> "SUB" $> ACmd SClient SUB <|> "END" $> ACmd SAgent END - <|> "DOWN " *> downsResp - <|> "UP " *> upsResp + <|> "CONNECT " *> connectResp + <|> "DISCONNECT " *> disconnectResp + <|> "DOWN " *> downResp + <|> "UP " *> upResp <|> "SEND " *> sendCmd <|> "MID " *> msgIdResp <|> "SENT " *> sentResp @@ -954,8 +960,10 @@ commandP = acptCmd = ACmd SClient .: ACPT <$> A.takeTill (== ' ') <* A.space <*> A.takeByteString rjctCmd = ACmd SClient . RJCT <$> A.takeByteString infoCmd = ACmd SAgent . INFO <$> A.takeByteString - downsResp = ACmd SAgent .: DOWN <$> strP_ <*> connections - upsResp = ACmd SAgent .: UP <$> strP_ <*> connections + connectResp = ACmd SAgent .: CONNECT <$> strP_ <*> strP + disconnectResp = ACmd SAgent .: DISCONNECT <$> strP_ <*> strP + downResp = ACmd SAgent .: DOWN <$> strP_ <*> connections + upResp = ACmd SAgent .: UP <$> strP_ <*> connections sendCmd = ACmd SClient .: SEND <$> smpP <* A.space <*> A.takeByteString msgIdResp = ACmd SAgent . MID <$> A.decimal sentResp = ACmd SAgent . SENT <$> A.decimal @@ -990,6 +998,8 @@ serializeCommand = \case INFO cInfo -> "INFO " <> serializeBinary cInfo SUB -> "SUB" END -> "END" + CONNECT p h -> B.unwords ["CONNECT", strEncode p, strEncode h] + DISCONNECT p h -> B.unwords ["DISCONNECT", strEncode p, strEncode h] DOWN srv conns -> B.unwords ["DOWN", strEncode srv, connections conns] UP srv conns -> B.unwords ["UP", strEncode srv, connections conns] SEND msgFlags msgBody -> "SEND " <> smpEncode msgFlags <> " " <> serializeBinary msgBody @@ -1062,6 +1072,8 @@ tGet party h = liftIO (tGetRaw h) >>= tParseLoadBody ACPT {} -> Right cmd -- ERROR response does not always have connId ERR _ -> Right cmd + CONNECT {} -> Right cmd + DISCONNECT {} -> Right cmd DOWN {} -> Right cmd UP {} -> Right cmd -- other responses must have connId diff --git a/src/Simplex/Messaging/Client.hs b/src/Simplex/Messaging/Client.hs index 874c371..e5582fe 100644 --- a/src/Simplex/Messaging/Client.hs +++ b/src/Simplex/Messaging/Client.hs @@ -26,7 +26,7 @@ -- See https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md module Simplex.Messaging.Client ( -- * Connect (disconnect) client to (from) SMP server - ProtocolClient (thVersion, sessionId), + ProtocolClient (thVersion, sessionId, transportHost), SMPClient, getProtocolClient, closeProtocolClient, @@ -101,6 +101,7 @@ data ProtocolClient msg = ProtocolClient sessionId :: SessionId, thVersion :: Version, protocolServer :: ProtoServer msg, + transportHost :: TransportHost, tcpTimeout :: Int, clientCorrId :: TVar Natural, sentCommands :: TMap CorrId (Request msg), @@ -216,17 +217,17 @@ chooseTransportHost NetworkConfig {socksProxy, hostMode, requiredHostMode} hosts -- -- A single queue can be used for multiple 'SMPClient' instances, -- as 'SMPServerTransmission' includes server information. -getProtocolClient :: forall msg. Protocol msg => ProtoServer msg -> ProtocolClientConfig -> Maybe (TBQueue (ServerTransmission msg)) -> IO () -> IO (Either ProtocolClientError (ProtocolClient msg)) +getProtocolClient :: forall msg. Protocol msg => ProtoServer msg -> ProtocolClientConfig -> Maybe (TBQueue (ServerTransmission msg)) -> (ProtocolClient msg -> IO ()) -> IO (Either ProtocolClientError (ProtocolClient msg)) getProtocolClient protocolServer cfg@ProtocolClientConfig {qSize, networkConfig, smpServerVRange} msgQ disconnected = do case chooseTransportHost networkConfig (host protocolServer) of Right useHost -> - (atomically mkProtocolClient >>= runClient useTransport useHost) + (atomically (mkProtocolClient useHost) >>= runClient useTransport useHost) `catch` \(e :: IOException) -> pure . Left $ PCEIOError e Left e -> pure $ Left e where NetworkConfig {tcpConnectTimeout, tcpTimeout, tcpKeepAlive, socksProxy, smpPingInterval} = networkConfig - mkProtocolClient :: STM (ProtocolClient msg) - mkProtocolClient = do + mkProtocolClient :: TransportHost -> STM (ProtocolClient msg) + mkProtocolClient transportHost = do connected <- newTVar False clientCorrId <- newTVar 0 sentCommands <- TM.empty @@ -239,6 +240,7 @@ getProtocolClient protocolServer cfg@ProtocolClientConfig {qSize, networkConfig, thVersion = undefined, connected, protocolServer, + transportHost, tcpTimeout, clientCorrId, sentCommands, @@ -277,7 +279,7 @@ getProtocolClient protocolServer cfg@ProtocolClientConfig {qSize, networkConfig, let c' = c {sessionId, thVersion} :: ProtocolClient msg -- TODO remove ping if 0 is passed (or Nothing?) raceAny_ [send c' th, process c', receive c' th, ping c'] - `finally` disconnected + `finally` disconnected c' send :: Transport c => ProtocolClient msg -> THandle c -> IO () send ProtocolClient {sndQ} h = forever $ atomically (readTBQueue sndQ) >>= tPut h diff --git a/src/Simplex/Messaging/Client/Agent.hs b/src/Simplex/Messaging/Client/Agent.hs index 415dc81..6e6f61d 100644 --- a/src/Simplex/Messaging/Client/Agent.hs +++ b/src/Simplex/Messaging/Client/Agent.hs @@ -162,8 +162,8 @@ getSMPServerClient' ca@SMPClientAgent {agentCfg, smpClients, msgQ} srv = connectClient :: ExceptT ProtocolClientError IO SMPClient connectClient = ExceptT $ getProtocolClient srv (smpCfg agentCfg) (Just msgQ) clientDisconnected - clientDisconnected :: IO () - clientDisconnected = do + clientDisconnected :: SMPClient -> IO () + clientDisconnected _ = do removeClientAndSubs >>= (`forM_` serverDown) logInfo . decodeUtf8 $ "Agent disconnected from " <> showServer srv diff --git a/src/Simplex/Messaging/Protocol.hs b/src/Simplex/Messaging/Protocol.hs index c710eea..066df15 100644 --- a/src/Simplex/Messaging/Protocol.hs +++ b/src/Simplex/Messaging/Protocol.hs @@ -64,6 +64,8 @@ module Simplex.Messaging.Protocol PrivHeader (..), Protocol (..), ProtocolType (..), + AProtocolType (..), + ProtocolTypeI (..), ProtocolServer (..), ProtoServer, SMPServer, @@ -135,7 +137,7 @@ import qualified Data.ByteString.Char8 as B import Data.Kind import Data.List.NonEmpty (NonEmpty (..)) import qualified Data.List.NonEmpty as L -import Data.Maybe (isNothing) +import Data.Maybe (isJust, isNothing) import Data.String import Data.Time.Clock.System (SystemTime (..)) import Data.Type.Equality @@ -578,7 +580,7 @@ instance StrEncoding ProtocolType where PSMP -> "smp" PNTF -> "ntf" strP = - A.takeTill (== ':') >>= \case + A.takeTill (\c -> c == ':' || c == ' ') >>= \case "smp" -> pure PSMP "ntf" -> pure PNTF _ -> fail "bad ProtocolType" @@ -595,6 +597,11 @@ deriving instance Show (SProtocolType p) data AProtocolType = forall p. ProtocolTypeI p => AProtocolType (SProtocolType p) +deriving instance Show AProtocolType + +instance Eq AProtocolType where + AProtocolType p == AProtocolType p' = isJust $ testEquality p p' + instance TestEquality SProtocolType where testEquality SPSMP SPSMP = Just Refl testEquality SPNTF SPNTF = Just Refl @@ -618,6 +625,10 @@ instance StrEncoding AProtocolType where strEncode (AProtocolType p) = strEncode p strP = aProtocolType <$> strP +instance ToJSON AProtocolType where + toEncoding = strToJEncoding + toJSON = strToJSON + checkProtocolType :: forall t p p'. (ProtocolTypeI p, ProtocolTypeI p') => t p' -> Either String (t p) checkProtocolType p = case testEquality (protocolTypeI @p) (protocolTypeI @p') of Just Refl -> Right p diff --git a/src/Simplex/Messaging/Transport/Client.hs b/src/Simplex/Messaging/Transport/Client.hs index 485296f..400d0e0 100644 --- a/src/Simplex/Messaging/Transport/Client.hs +++ b/src/Simplex/Messaging/Transport/Client.hs @@ -73,6 +73,10 @@ instance StrEncoding TransportHost where where ipNum = A.decimal <* A.char '.' +instance ToJSON TransportHost where + toEncoding = strToJEncoding + toJSON = strToJSON + newtype TransportHosts = TransportHosts {thList :: NonEmpty TransportHost} instance StrEncoding TransportHosts where diff --git a/tests/AgentTests.hs b/tests/AgentTests.hs index 83256ff..83aa154 100644 --- a/tests/AgentTests.hs +++ b/tests/AgentTests.hs @@ -78,9 +78,17 @@ agentTests (ATransport t) = do it "should deliver messages if one of connections has quota exceeded" $ smpAgentTest2_2_1 $ testMsgDeliveryQuotaExceeded t +tGetAgent :: Transport c => c -> IO (ATransmissionOrError 'Agent) +tGetAgent h = do + t@(_, _, cmd) <- tGet SAgent h + case cmd of + Right CONNECT {} -> tGetAgent h + Right DISCONNECT {} -> tGetAgent h + _ -> pure t + -- | receive message to handle `h` (<#:) :: Transport c => c -> IO (ATransmissionOrError 'Agent) -(<#:) = tGet SAgent +(<#:) = tGetAgent -- | send transmission `t` to handle `h` and get response (#:) :: Transport c => c -> (ByteString, ByteString, ByteString) -> IO (ATransmissionOrError 'Agent) @@ -114,7 +122,7 @@ h <#= p = (h <#:) >>= (`shouldSatisfy` p . correctTransmission) h #:# err = tryGet `shouldReturn` () where tryGet = - 10000 `timeout` tGet SAgent h >>= \case + 10000 `timeout` tGetAgent h >>= \case Just _ -> error err _ -> return () diff --git a/tests/AgentTests/FunctionalAPITests.hs b/tests/AgentTests/FunctionalAPITests.hs index e7de32c..f146a51 100644 --- a/tests/AgentTests/FunctionalAPITests.hs +++ b/tests/AgentTests/FunctionalAPITests.hs @@ -47,7 +47,12 @@ a ##> t = a >>= \t' -> liftIO (t' `shouldBe` t) a =##> p = a >>= \t -> liftIO (t `shouldSatisfy` p) get :: MonadIO m => AgentClient -> m (ATransmission 'Agent) -get c = atomically (readTBQueue $ subQ c) +get c = do + t@(_, _, cmd) <- atomically (readTBQueue $ subQ c) + case cmd of + CONNECT {} -> get c + DISCONNECT {} -> get c + _ -> pure t pattern Msg :: MsgBody -> ACommand 'Agent pattern Msg msgBody <- MSG MsgMeta {integrity = MsgOk} _ msgBody diff --git a/tests/SMPAgentClient.hs b/tests/SMPAgentClient.hs index 04f7506..65ac5b9 100644 --- a/tests/SMPAgentClient.hs +++ b/tests/SMPAgentClient.hs @@ -1,4 +1,5 @@ {-# LANGUAGE DuplicateRecordFields #-} +{-# LANGUAGE GADTs #-} {-# LANGUAGE NumericUnderscores #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE ScopedTypeVariables #-} @@ -27,6 +28,7 @@ import Simplex.Messaging.Agent.Protocol import Simplex.Messaging.Agent.RetryInterval import Simplex.Messaging.Agent.Server (runSMPAgentBlocking) import Simplex.Messaging.Client (ProtocolClientConfig (..), chooseTransportHost, defaultClientConfig, defaultNetworkConfig) +import Simplex.Messaging.Parsers (parseAll) import Simplex.Messaging.Transport import Simplex.Messaging.Transport.Client import Simplex.Messaging.Transport.KeepAlive @@ -56,7 +58,14 @@ testDB3 :: String testDB3 = "tests/tmp/smp-agent3.test.protocol.db" smpAgentTest :: forall c. Transport c => TProxy c -> ARawTransmission -> IO ARawTransmission -smpAgentTest _ cmd = runSmpAgentTest $ \(h :: c) -> tPutRaw h cmd >> tGetRaw h +smpAgentTest _ cmd = runSmpAgentTest $ \(h :: c) -> tPutRaw h cmd >> get h + where + get h = do + t@(_, _, cmdStr) <- tGetRaw h + case parseAll commandP cmdStr of + Right (ACmd SAgent CONNECT {}) -> get h + Right (ACmd SAgent DISCONNECT {}) -> get h + _ -> pure t runSmpAgentTest :: forall c m a. (Transport c, MonadUnliftIO m, MonadRandom m, MonadFail m) => (c -> m a) -> m a runSmpAgentTest test = withSmpServer t . withSmpAgent t $ testSMPAgentClient test @@ -177,7 +186,7 @@ agentCfg :: AgentConfig agentCfg = defaultAgentConfig { tcpPort = agentTestPort, - tbqSize = 1, + tbqSize = 4, dbFile = testDB, smpCfg = defaultClientConfig