send host events when server hosts are connected and disconnected (#496)
This commit is contained in:
parent
68138c08d2
commit
02bba01c16
|
@ -457,7 +457,7 @@ subscribeConnections' c connIds = do
|
|||
addRcvQueue :: Map SMPServer (Map ConnId (RcvQueue, ConnData)) -> ConnId -> (RcvQueue, ConnData) -> Map SMPServer (Map ConnId (RcvQueue, ConnData))
|
||||
addRcvQueue m connId rq@(RcvQueue {server}, _) = M.alter (Just . maybe (M.singleton connId rq) (M.insert connId rq)) server m
|
||||
subscribe :: (SMPServer, Map ConnId (RcvQueue, ConnData)) -> m (Map ConnId (Either AgentErrorType ()))
|
||||
subscribe (srv, qs) = subscribeQueues c srv (M.map fst qs)
|
||||
subscribe (srv, qs) = snd <$> subscribeQueues c srv (M.map fst qs)
|
||||
sendNtfCreate :: NtfSupervisor -> [Map ConnId (Either AgentErrorType ())] -> m ()
|
||||
sendNtfCreate ns rcvRs =
|
||||
forM_ (concatMap M.assocs rcvRs) $ \case
|
||||
|
|
|
@ -106,10 +106,31 @@ import Simplex.Messaging.Notifications.Client
|
|||
import Simplex.Messaging.Notifications.Protocol
|
||||
import Simplex.Messaging.Notifications.Types
|
||||
import Simplex.Messaging.Parsers (parse)
|
||||
import Simplex.Messaging.Protocol (BrokerMsg, ErrorType, MsgFlags (..), MsgId, NotifierId, NtfPrivateSignKey, NtfPublicVerifyKey, NtfServer, ProtoServer, ProtocolServer (..), QueueId, QueueIdsKeys (..), RcvMessage (..), RcvNtfPublicDhKey, SMPMsgMeta (..), SndPublicVerifyKey)
|
||||
import Simplex.Messaging.Protocol
|
||||
( AProtocolType (..),
|
||||
BrokerMsg,
|
||||
ErrorType,
|
||||
MsgFlags (..),
|
||||
MsgId,
|
||||
NotifierId,
|
||||
NtfPrivateSignKey,
|
||||
NtfPublicVerifyKey,
|
||||
NtfServer,
|
||||
ProtoServer,
|
||||
Protocol (..),
|
||||
ProtocolServer (..),
|
||||
ProtocolTypeI (..),
|
||||
QueueId,
|
||||
QueueIdsKeys (..),
|
||||
RcvMessage (..),
|
||||
RcvNtfPublicDhKey,
|
||||
SMPMsgMeta (..),
|
||||
SndPublicVerifyKey,
|
||||
)
|
||||
import qualified Simplex.Messaging.Protocol as SMP
|
||||
import Simplex.Messaging.TMap (TMap)
|
||||
import qualified Simplex.Messaging.TMap as TM
|
||||
import Simplex.Messaging.Transport.Client (TransportHost)
|
||||
import Simplex.Messaging.Util
|
||||
import Simplex.Messaging.Version
|
||||
import System.Timeout (timeout)
|
||||
|
@ -232,46 +253,47 @@ getSMPServerClient c@AgentClient {active, smpClients, msgQ} srv = do
|
|||
u <- askUnliftIO
|
||||
liftEitherError (protocolClientError SMP) (getProtocolClient srv cfg (Just msgQ) $ clientDisconnected u)
|
||||
|
||||
clientDisconnected :: UnliftIO m -> IO ()
|
||||
clientDisconnected u = do
|
||||
removeClientAndSubs >>= (`forM_` serverDown u)
|
||||
clientDisconnected :: UnliftIO m -> SMPClient -> IO ()
|
||||
clientDisconnected u client = do
|
||||
removeClientAndSubs >>= (`forM_` serverDown)
|
||||
logInfo . decodeUtf8 $ "Agent disconnected from " <> showServer srv
|
||||
|
||||
removeClientAndSubs :: IO (Maybe (Map ConnId RcvQueue))
|
||||
removeClientAndSubs = atomically $ do
|
||||
TM.delete srv smpClients
|
||||
TM.lookupDelete srv (subscrSrvrs c) >>= mapM updateSubs
|
||||
where
|
||||
updateSubs cVar = do
|
||||
cs <- readTVar cVar
|
||||
modifyTVar' (subscrConns c) (`M.withoutKeys` M.keysSet cs)
|
||||
addPendingSubs cVar cs
|
||||
pure cs
|
||||
removeClientAndSubs :: IO (Maybe (Map ConnId RcvQueue))
|
||||
removeClientAndSubs = atomically $ do
|
||||
TM.delete srv smpClients
|
||||
TM.lookupDelete srv (subscrSrvrs c) >>= mapM updateSubs
|
||||
where
|
||||
updateSubs cVar = do
|
||||
cs <- readTVar cVar
|
||||
modifyTVar' (subscrConns c) (`M.withoutKeys` M.keysSet cs)
|
||||
addPendingSubs cVar cs
|
||||
pure cs
|
||||
|
||||
addPendingSubs cVar cs = do
|
||||
let ps = pendingSubscrSrvrs c
|
||||
TM.lookup srv ps >>= \case
|
||||
Just v -> TM.union cs v
|
||||
_ -> TM.insert srv cVar ps
|
||||
addPendingSubs cVar cs = do
|
||||
let ps = pendingSubscrSrvrs c
|
||||
TM.lookup srv ps >>= \case
|
||||
Just v -> TM.union cs v
|
||||
_ -> TM.insert srv cVar ps
|
||||
|
||||
serverDown :: UnliftIO m -> Map ConnId RcvQueue -> IO ()
|
||||
serverDown u cs = unless (M.null cs) $
|
||||
whenM (readTVarIO active) $ do
|
||||
let conns = M.keys cs
|
||||
unless (null conns) . notifySub "" $ DOWN srv conns
|
||||
atomically $ mapM_ (releaseGetLock c) cs
|
||||
unliftIO u reconnectServer
|
||||
serverDown :: Map ConnId RcvQueue -> IO ()
|
||||
serverDown cs = unless (M.null cs) $
|
||||
whenM (readTVarIO active) $ do
|
||||
let conns = M.keys cs
|
||||
notifySub "" $ hostEvent DISCONNECT client
|
||||
unless (null conns) . notifySub "" $ DOWN srv conns
|
||||
atomically $ mapM_ (releaseGetLock c) cs
|
||||
unliftIO u reconnectServer
|
||||
|
||||
reconnectServer :: m ()
|
||||
reconnectServer = do
|
||||
a <- async tryReconnectClient
|
||||
atomically $ modifyTVar' (reconnections c) (a :)
|
||||
reconnectServer :: m ()
|
||||
reconnectServer = do
|
||||
a <- async tryReconnectClient
|
||||
atomically $ modifyTVar' (reconnections c) (a :)
|
||||
|
||||
tryReconnectClient :: m ()
|
||||
tryReconnectClient = do
|
||||
ri <- asks $ reconnectInterval . config
|
||||
withRetryInterval ri $ \loop ->
|
||||
reconnectClient `catchError` const loop
|
||||
tryReconnectClient :: m ()
|
||||
tryReconnectClient = do
|
||||
ri <- asks $ reconnectInterval . config
|
||||
withRetryInterval ri $ \loop ->
|
||||
reconnectClient `catchError` const loop
|
||||
|
||||
reconnectClient :: m ()
|
||||
reconnectClient =
|
||||
|
@ -281,8 +303,11 @@ getSMPServerClient c@AgentClient {active, smpClients, msgQ} srv = do
|
|||
where
|
||||
resubscribe :: Map ConnId RcvQueue -> m ()
|
||||
resubscribe qs = do
|
||||
(errs, oks) <- M.mapEither id <$> subscribeQueues c srv qs
|
||||
liftIO . unless (M.null oks) . notifySub "" . UP srv $ M.keys oks
|
||||
(client_, (errs, oks)) <- second (M.mapEither id) <$> subscribeQueues c srv qs
|
||||
liftIO $ do
|
||||
mapM_ (notifySub "" . hostEvent CONNECT) client_
|
||||
unless (M.null oks) $ do
|
||||
notifySub "" . UP srv $ M.keys oks
|
||||
let (tempErrs, finalErrs) = M.partition temporaryAgentError errs
|
||||
liftIO . mapM_ (\(connId, e) -> notifySub connId $ ERR e) $ M.assocs finalErrs
|
||||
mapM_ throwError . listToMaybe $ M.elems tempErrs
|
||||
|
@ -303,9 +328,10 @@ getNtfServerClient c@AgentClient {active, ntfClients} srv = do
|
|||
cfg <- atomically . updateClientConfig c =<< asks (ntfCfg . config)
|
||||
liftEitherError (protocolClientError NTF) (getProtocolClient srv cfg Nothing clientDisconnected)
|
||||
|
||||
clientDisconnected :: IO ()
|
||||
clientDisconnected = do
|
||||
clientDisconnected :: NtfClient -> IO ()
|
||||
clientDisconnected client = do
|
||||
atomically $ TM.delete srv ntfClients
|
||||
atomically $ writeTBQueue (subQ c) ("", "", hostEvent DISCONNECT client)
|
||||
logInfo . decodeUtf8 $ "Agent disconnected from " <> showServer srv
|
||||
|
||||
getClientVar :: forall a s. ProtocolServer s -> TMap (ProtocolServer s) (TMVar a) -> STM (Either (TMVar a) (TMVar a))
|
||||
|
@ -328,7 +354,7 @@ waitForProtocolClient c clientVar = do
|
|||
|
||||
newProtocolClient ::
|
||||
forall msg m.
|
||||
AgentMonad m =>
|
||||
(AgentMonad m, ProtocolTypeI (ProtoType msg)) =>
|
||||
AgentClient ->
|
||||
ProtoServer msg ->
|
||||
TMap (ProtoServer msg) (ClientVar msg) ->
|
||||
|
@ -344,6 +370,7 @@ newProtocolClient c srv clients connectClient reconnectClient clientVar = tryCon
|
|||
Right client -> do
|
||||
logInfo . decodeUtf8 $ "Agent connected to " <> showServer srv
|
||||
atomically $ putTMVar clientVar r
|
||||
atomically $ writeTBQueue (subQ c) ("", "", hostEvent CONNECT client)
|
||||
successAction client
|
||||
Left e -> do
|
||||
if temporaryAgentError e
|
||||
|
@ -361,6 +388,9 @@ newProtocolClient c srv clients connectClient reconnectClient clientVar = tryCon
|
|||
ri <- asks $ reconnectInterval . config
|
||||
withRetryInterval ri $ \loop -> void $ tryConnectClient (const reconnectClient) loop
|
||||
|
||||
hostEvent :: forall msg. ProtocolTypeI (ProtoType msg) => (AProtocolType -> TransportHost -> ACommand 'Agent) -> ProtocolClient msg -> ACommand 'Agent
|
||||
hostEvent event client = event (AProtocolType $ protocolTypeI @(ProtoType msg)) $ transportHost client
|
||||
|
||||
updateClientConfig :: AgentClient -> ProtocolClientConfig -> STM ProtocolClientConfig
|
||||
updateClientConfig AgentClient {useNetworkConfig} cfg = do
|
||||
networkConfig <- readTVar useNetworkConfig
|
||||
|
@ -505,14 +535,14 @@ temporaryAgentError = \case
|
|||
_ -> False
|
||||
|
||||
-- | subscribe multiple queues - all passed queues should be on the same server
|
||||
subscribeQueues :: AgentMonad m => AgentClient -> SMPServer -> Map ConnId RcvQueue -> m (Map ConnId (Either AgentErrorType ()))
|
||||
subscribeQueues :: AgentMonad m => AgentClient -> SMPServer -> Map ConnId RcvQueue -> m (Maybe SMPClient, Map ConnId (Either AgentErrorType ()))
|
||||
subscribeQueues c srv qs = do
|
||||
(errs, qs_) <- partitionEithers <$> mapM checkQueue (M.assocs qs)
|
||||
forM_ qs_ $ atomically . uncurry (addPendingSubscription c) . swap
|
||||
case L.nonEmpty qs_ of
|
||||
Just qs' -> do
|
||||
smp_ <- tryError (getSMPServerClient c srv)
|
||||
M.fromList . (errs <>) <$> case smp_ of
|
||||
(eitherToMaybe smp_,) . M.fromList . (errs <>) <$> case smp_ of
|
||||
Left e -> pure $ map (second . const $ Left e) qs_
|
||||
Right smp -> do
|
||||
logServer "-->" c srv (bshow (length qs_) <> " queues") "SUB"
|
||||
|
@ -521,7 +551,7 @@ subscribeQueues c srv qs = do
|
|||
liftIO $ zip qs_ . L.toList <$> subscribeSMPQueues smp qs2
|
||||
forM_ rs' $ \((connId, rq), r) -> liftIO $ processSubResult c rq connId r
|
||||
pure $ map (bimap fst (first $ protocolClientError SMP)) rs'
|
||||
_ -> pure $ M.fromList errs
|
||||
_ -> pure $ (Nothing, M.fromList errs)
|
||||
where
|
||||
checkQueue rq@(connId, RcvQueue {rcvId, server}) = do
|
||||
prohibited <- atomically . TM.member (server, rcvId) $ getMsgLocks c
|
||||
|
|
|
@ -40,6 +40,7 @@ module Simplex.Messaging.Agent.Protocol
|
|||
-- * SMP agent protocol types
|
||||
ConnInfo,
|
||||
ACommand (..),
|
||||
ACmd (..),
|
||||
AParty (..),
|
||||
SAParty (..),
|
||||
MsgHash,
|
||||
|
@ -141,7 +142,8 @@ import Simplex.Messaging.Encoding
|
|||
import Simplex.Messaging.Encoding.String
|
||||
import Simplex.Messaging.Parsers
|
||||
import Simplex.Messaging.Protocol
|
||||
( ErrorType,
|
||||
( AProtocolType,
|
||||
ErrorType,
|
||||
MsgBody,
|
||||
MsgFlags,
|
||||
MsgId,
|
||||
|
@ -228,6 +230,8 @@ data ACommand (p :: AParty) where
|
|||
CON :: ACommand Agent -- notification that connection is established
|
||||
SUB :: ACommand Client
|
||||
END :: ACommand Agent
|
||||
CONNECT :: AProtocolType -> TransportHost -> ACommand Agent
|
||||
DISCONNECT :: AProtocolType -> TransportHost -> ACommand Agent
|
||||
DOWN :: SMPServer -> [ConnId] -> ACommand Agent
|
||||
UP :: SMPServer -> [ConnId] -> ACommand Agent
|
||||
SEND :: MsgFlags -> MsgBody -> ACommand Client
|
||||
|
@ -929,8 +933,10 @@ commandP =
|
|||
<|> "INFO " *> infoCmd
|
||||
<|> "SUB" $> ACmd SClient SUB
|
||||
<|> "END" $> ACmd SAgent END
|
||||
<|> "DOWN " *> downsResp
|
||||
<|> "UP " *> upsResp
|
||||
<|> "CONNECT " *> connectResp
|
||||
<|> "DISCONNECT " *> disconnectResp
|
||||
<|> "DOWN " *> downResp
|
||||
<|> "UP " *> upResp
|
||||
<|> "SEND " *> sendCmd
|
||||
<|> "MID " *> msgIdResp
|
||||
<|> "SENT " *> sentResp
|
||||
|
@ -954,8 +960,10 @@ commandP =
|
|||
acptCmd = ACmd SClient .: ACPT <$> A.takeTill (== ' ') <* A.space <*> A.takeByteString
|
||||
rjctCmd = ACmd SClient . RJCT <$> A.takeByteString
|
||||
infoCmd = ACmd SAgent . INFO <$> A.takeByteString
|
||||
downsResp = ACmd SAgent .: DOWN <$> strP_ <*> connections
|
||||
upsResp = ACmd SAgent .: UP <$> strP_ <*> connections
|
||||
connectResp = ACmd SAgent .: CONNECT <$> strP_ <*> strP
|
||||
disconnectResp = ACmd SAgent .: DISCONNECT <$> strP_ <*> strP
|
||||
downResp = ACmd SAgent .: DOWN <$> strP_ <*> connections
|
||||
upResp = ACmd SAgent .: UP <$> strP_ <*> connections
|
||||
sendCmd = ACmd SClient .: SEND <$> smpP <* A.space <*> A.takeByteString
|
||||
msgIdResp = ACmd SAgent . MID <$> A.decimal
|
||||
sentResp = ACmd SAgent . SENT <$> A.decimal
|
||||
|
@ -990,6 +998,8 @@ serializeCommand = \case
|
|||
INFO cInfo -> "INFO " <> serializeBinary cInfo
|
||||
SUB -> "SUB"
|
||||
END -> "END"
|
||||
CONNECT p h -> B.unwords ["CONNECT", strEncode p, strEncode h]
|
||||
DISCONNECT p h -> B.unwords ["DISCONNECT", strEncode p, strEncode h]
|
||||
DOWN srv conns -> B.unwords ["DOWN", strEncode srv, connections conns]
|
||||
UP srv conns -> B.unwords ["UP", strEncode srv, connections conns]
|
||||
SEND msgFlags msgBody -> "SEND " <> smpEncode msgFlags <> " " <> serializeBinary msgBody
|
||||
|
@ -1062,6 +1072,8 @@ tGet party h = liftIO (tGetRaw h) >>= tParseLoadBody
|
|||
ACPT {} -> Right cmd
|
||||
-- ERROR response does not always have connId
|
||||
ERR _ -> Right cmd
|
||||
CONNECT {} -> Right cmd
|
||||
DISCONNECT {} -> Right cmd
|
||||
DOWN {} -> Right cmd
|
||||
UP {} -> Right cmd
|
||||
-- other responses must have connId
|
||||
|
|
|
@ -26,7 +26,7 @@
|
|||
-- See https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md
|
||||
module Simplex.Messaging.Client
|
||||
( -- * Connect (disconnect) client to (from) SMP server
|
||||
ProtocolClient (thVersion, sessionId),
|
||||
ProtocolClient (thVersion, sessionId, transportHost),
|
||||
SMPClient,
|
||||
getProtocolClient,
|
||||
closeProtocolClient,
|
||||
|
@ -101,6 +101,7 @@ data ProtocolClient msg = ProtocolClient
|
|||
sessionId :: SessionId,
|
||||
thVersion :: Version,
|
||||
protocolServer :: ProtoServer msg,
|
||||
transportHost :: TransportHost,
|
||||
tcpTimeout :: Int,
|
||||
clientCorrId :: TVar Natural,
|
||||
sentCommands :: TMap CorrId (Request msg),
|
||||
|
@ -216,17 +217,17 @@ chooseTransportHost NetworkConfig {socksProxy, hostMode, requiredHostMode} hosts
|
|||
--
|
||||
-- A single queue can be used for multiple 'SMPClient' instances,
|
||||
-- as 'SMPServerTransmission' includes server information.
|
||||
getProtocolClient :: forall msg. Protocol msg => ProtoServer msg -> ProtocolClientConfig -> Maybe (TBQueue (ServerTransmission msg)) -> IO () -> IO (Either ProtocolClientError (ProtocolClient msg))
|
||||
getProtocolClient :: forall msg. Protocol msg => ProtoServer msg -> ProtocolClientConfig -> Maybe (TBQueue (ServerTransmission msg)) -> (ProtocolClient msg -> IO ()) -> IO (Either ProtocolClientError (ProtocolClient msg))
|
||||
getProtocolClient protocolServer cfg@ProtocolClientConfig {qSize, networkConfig, smpServerVRange} msgQ disconnected = do
|
||||
case chooseTransportHost networkConfig (host protocolServer) of
|
||||
Right useHost ->
|
||||
(atomically mkProtocolClient >>= runClient useTransport useHost)
|
||||
(atomically (mkProtocolClient useHost) >>= runClient useTransport useHost)
|
||||
`catch` \(e :: IOException) -> pure . Left $ PCEIOError e
|
||||
Left e -> pure $ Left e
|
||||
where
|
||||
NetworkConfig {tcpConnectTimeout, tcpTimeout, tcpKeepAlive, socksProxy, smpPingInterval} = networkConfig
|
||||
mkProtocolClient :: STM (ProtocolClient msg)
|
||||
mkProtocolClient = do
|
||||
mkProtocolClient :: TransportHost -> STM (ProtocolClient msg)
|
||||
mkProtocolClient transportHost = do
|
||||
connected <- newTVar False
|
||||
clientCorrId <- newTVar 0
|
||||
sentCommands <- TM.empty
|
||||
|
@ -239,6 +240,7 @@ getProtocolClient protocolServer cfg@ProtocolClientConfig {qSize, networkConfig,
|
|||
thVersion = undefined,
|
||||
connected,
|
||||
protocolServer,
|
||||
transportHost,
|
||||
tcpTimeout,
|
||||
clientCorrId,
|
||||
sentCommands,
|
||||
|
@ -277,7 +279,7 @@ getProtocolClient protocolServer cfg@ProtocolClientConfig {qSize, networkConfig,
|
|||
let c' = c {sessionId, thVersion} :: ProtocolClient msg
|
||||
-- TODO remove ping if 0 is passed (or Nothing?)
|
||||
raceAny_ [send c' th, process c', receive c' th, ping c']
|
||||
`finally` disconnected
|
||||
`finally` disconnected c'
|
||||
|
||||
send :: Transport c => ProtocolClient msg -> THandle c -> IO ()
|
||||
send ProtocolClient {sndQ} h = forever $ atomically (readTBQueue sndQ) >>= tPut h
|
||||
|
|
|
@ -162,8 +162,8 @@ getSMPServerClient' ca@SMPClientAgent {agentCfg, smpClients, msgQ} srv =
|
|||
connectClient :: ExceptT ProtocolClientError IO SMPClient
|
||||
connectClient = ExceptT $ getProtocolClient srv (smpCfg agentCfg) (Just msgQ) clientDisconnected
|
||||
|
||||
clientDisconnected :: IO ()
|
||||
clientDisconnected = do
|
||||
clientDisconnected :: SMPClient -> IO ()
|
||||
clientDisconnected _ = do
|
||||
removeClientAndSubs >>= (`forM_` serverDown)
|
||||
logInfo . decodeUtf8 $ "Agent disconnected from " <> showServer srv
|
||||
|
||||
|
|
|
@ -64,6 +64,8 @@ module Simplex.Messaging.Protocol
|
|||
PrivHeader (..),
|
||||
Protocol (..),
|
||||
ProtocolType (..),
|
||||
AProtocolType (..),
|
||||
ProtocolTypeI (..),
|
||||
ProtocolServer (..),
|
||||
ProtoServer,
|
||||
SMPServer,
|
||||
|
@ -135,7 +137,7 @@ import qualified Data.ByteString.Char8 as B
|
|||
import Data.Kind
|
||||
import Data.List.NonEmpty (NonEmpty (..))
|
||||
import qualified Data.List.NonEmpty as L
|
||||
import Data.Maybe (isNothing)
|
||||
import Data.Maybe (isJust, isNothing)
|
||||
import Data.String
|
||||
import Data.Time.Clock.System (SystemTime (..))
|
||||
import Data.Type.Equality
|
||||
|
@ -578,7 +580,7 @@ instance StrEncoding ProtocolType where
|
|||
PSMP -> "smp"
|
||||
PNTF -> "ntf"
|
||||
strP =
|
||||
A.takeTill (== ':') >>= \case
|
||||
A.takeTill (\c -> c == ':' || c == ' ') >>= \case
|
||||
"smp" -> pure PSMP
|
||||
"ntf" -> pure PNTF
|
||||
_ -> fail "bad ProtocolType"
|
||||
|
@ -595,6 +597,11 @@ deriving instance Show (SProtocolType p)
|
|||
|
||||
data AProtocolType = forall p. ProtocolTypeI p => AProtocolType (SProtocolType p)
|
||||
|
||||
deriving instance Show AProtocolType
|
||||
|
||||
instance Eq AProtocolType where
|
||||
AProtocolType p == AProtocolType p' = isJust $ testEquality p p'
|
||||
|
||||
instance TestEquality SProtocolType where
|
||||
testEquality SPSMP SPSMP = Just Refl
|
||||
testEquality SPNTF SPNTF = Just Refl
|
||||
|
@ -618,6 +625,10 @@ instance StrEncoding AProtocolType where
|
|||
strEncode (AProtocolType p) = strEncode p
|
||||
strP = aProtocolType <$> strP
|
||||
|
||||
instance ToJSON AProtocolType where
|
||||
toEncoding = strToJEncoding
|
||||
toJSON = strToJSON
|
||||
|
||||
checkProtocolType :: forall t p p'. (ProtocolTypeI p, ProtocolTypeI p') => t p' -> Either String (t p)
|
||||
checkProtocolType p = case testEquality (protocolTypeI @p) (protocolTypeI @p') of
|
||||
Just Refl -> Right p
|
||||
|
|
|
@ -73,6 +73,10 @@ instance StrEncoding TransportHost where
|
|||
where
|
||||
ipNum = A.decimal <* A.char '.'
|
||||
|
||||
instance ToJSON TransportHost where
|
||||
toEncoding = strToJEncoding
|
||||
toJSON = strToJSON
|
||||
|
||||
newtype TransportHosts = TransportHosts {thList :: NonEmpty TransportHost}
|
||||
|
||||
instance StrEncoding TransportHosts where
|
||||
|
|
|
@ -78,9 +78,17 @@ agentTests (ATransport t) = do
|
|||
it "should deliver messages if one of connections has quota exceeded" $
|
||||
smpAgentTest2_2_1 $ testMsgDeliveryQuotaExceeded t
|
||||
|
||||
tGetAgent :: Transport c => c -> IO (ATransmissionOrError 'Agent)
|
||||
tGetAgent h = do
|
||||
t@(_, _, cmd) <- tGet SAgent h
|
||||
case cmd of
|
||||
Right CONNECT {} -> tGetAgent h
|
||||
Right DISCONNECT {} -> tGetAgent h
|
||||
_ -> pure t
|
||||
|
||||
-- | receive message to handle `h`
|
||||
(<#:) :: Transport c => c -> IO (ATransmissionOrError 'Agent)
|
||||
(<#:) = tGet SAgent
|
||||
(<#:) = tGetAgent
|
||||
|
||||
-- | send transmission `t` to handle `h` and get response
|
||||
(#:) :: Transport c => c -> (ByteString, ByteString, ByteString) -> IO (ATransmissionOrError 'Agent)
|
||||
|
@ -114,7 +122,7 @@ h <#= p = (h <#:) >>= (`shouldSatisfy` p . correctTransmission)
|
|||
h #:# err = tryGet `shouldReturn` ()
|
||||
where
|
||||
tryGet =
|
||||
10000 `timeout` tGet SAgent h >>= \case
|
||||
10000 `timeout` tGetAgent h >>= \case
|
||||
Just _ -> error err
|
||||
_ -> return ()
|
||||
|
||||
|
|
|
@ -47,7 +47,12 @@ a ##> t = a >>= \t' -> liftIO (t' `shouldBe` t)
|
|||
a =##> p = a >>= \t -> liftIO (t `shouldSatisfy` p)
|
||||
|
||||
get :: MonadIO m => AgentClient -> m (ATransmission 'Agent)
|
||||
get c = atomically (readTBQueue $ subQ c)
|
||||
get c = do
|
||||
t@(_, _, cmd) <- atomically (readTBQueue $ subQ c)
|
||||
case cmd of
|
||||
CONNECT {} -> get c
|
||||
DISCONNECT {} -> get c
|
||||
_ -> pure t
|
||||
|
||||
pattern Msg :: MsgBody -> ACommand 'Agent
|
||||
pattern Msg msgBody <- MSG MsgMeta {integrity = MsgOk} _ msgBody
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
{-# LANGUAGE DuplicateRecordFields #-}
|
||||
{-# LANGUAGE GADTs #-}
|
||||
{-# LANGUAGE NumericUnderscores #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
|
@ -27,6 +28,7 @@ import Simplex.Messaging.Agent.Protocol
|
|||
import Simplex.Messaging.Agent.RetryInterval
|
||||
import Simplex.Messaging.Agent.Server (runSMPAgentBlocking)
|
||||
import Simplex.Messaging.Client (ProtocolClientConfig (..), chooseTransportHost, defaultClientConfig, defaultNetworkConfig)
|
||||
import Simplex.Messaging.Parsers (parseAll)
|
||||
import Simplex.Messaging.Transport
|
||||
import Simplex.Messaging.Transport.Client
|
||||
import Simplex.Messaging.Transport.KeepAlive
|
||||
|
@ -56,7 +58,14 @@ testDB3 :: String
|
|||
testDB3 = "tests/tmp/smp-agent3.test.protocol.db"
|
||||
|
||||
smpAgentTest :: forall c. Transport c => TProxy c -> ARawTransmission -> IO ARawTransmission
|
||||
smpAgentTest _ cmd = runSmpAgentTest $ \(h :: c) -> tPutRaw h cmd >> tGetRaw h
|
||||
smpAgentTest _ cmd = runSmpAgentTest $ \(h :: c) -> tPutRaw h cmd >> get h
|
||||
where
|
||||
get h = do
|
||||
t@(_, _, cmdStr) <- tGetRaw h
|
||||
case parseAll commandP cmdStr of
|
||||
Right (ACmd SAgent CONNECT {}) -> get h
|
||||
Right (ACmd SAgent DISCONNECT {}) -> get h
|
||||
_ -> pure t
|
||||
|
||||
runSmpAgentTest :: forall c m a. (Transport c, MonadUnliftIO m, MonadRandom m, MonadFail m) => (c -> m a) -> m a
|
||||
runSmpAgentTest test = withSmpServer t . withSmpAgent t $ testSMPAgentClient test
|
||||
|
@ -177,7 +186,7 @@ agentCfg :: AgentConfig
|
|||
agentCfg =
|
||||
defaultAgentConfig
|
||||
{ tcpPort = agentTestPort,
|
||||
tbqSize = 1,
|
||||
tbqSize = 4,
|
||||
dbFile = testDB,
|
||||
smpCfg =
|
||||
defaultClientConfig
|
||||
|
|
Reference in New Issue