send host events when server hosts are connected and disconnected (#496)

This commit is contained in:
Evgeny Poberezkin 2022-08-13 11:57:36 +01:00 committed by GitHub
parent 68138c08d2
commit 02bba01c16
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 145 additions and 64 deletions

View File

@ -457,7 +457,7 @@ subscribeConnections' c connIds = do
addRcvQueue :: Map SMPServer (Map ConnId (RcvQueue, ConnData)) -> ConnId -> (RcvQueue, ConnData) -> Map SMPServer (Map ConnId (RcvQueue, ConnData)) addRcvQueue :: Map SMPServer (Map ConnId (RcvQueue, ConnData)) -> ConnId -> (RcvQueue, ConnData) -> Map SMPServer (Map ConnId (RcvQueue, ConnData))
addRcvQueue m connId rq@(RcvQueue {server}, _) = M.alter (Just . maybe (M.singleton connId rq) (M.insert connId rq)) server m addRcvQueue m connId rq@(RcvQueue {server}, _) = M.alter (Just . maybe (M.singleton connId rq) (M.insert connId rq)) server m
subscribe :: (SMPServer, Map ConnId (RcvQueue, ConnData)) -> m (Map ConnId (Either AgentErrorType ())) subscribe :: (SMPServer, Map ConnId (RcvQueue, ConnData)) -> m (Map ConnId (Either AgentErrorType ()))
subscribe (srv, qs) = subscribeQueues c srv (M.map fst qs) subscribe (srv, qs) = snd <$> subscribeQueues c srv (M.map fst qs)
sendNtfCreate :: NtfSupervisor -> [Map ConnId (Either AgentErrorType ())] -> m () sendNtfCreate :: NtfSupervisor -> [Map ConnId (Either AgentErrorType ())] -> m ()
sendNtfCreate ns rcvRs = sendNtfCreate ns rcvRs =
forM_ (concatMap M.assocs rcvRs) $ \case forM_ (concatMap M.assocs rcvRs) $ \case

View File

@ -106,10 +106,31 @@ import Simplex.Messaging.Notifications.Client
import Simplex.Messaging.Notifications.Protocol import Simplex.Messaging.Notifications.Protocol
import Simplex.Messaging.Notifications.Types import Simplex.Messaging.Notifications.Types
import Simplex.Messaging.Parsers (parse) import Simplex.Messaging.Parsers (parse)
import Simplex.Messaging.Protocol (BrokerMsg, ErrorType, MsgFlags (..), MsgId, NotifierId, NtfPrivateSignKey, NtfPublicVerifyKey, NtfServer, ProtoServer, ProtocolServer (..), QueueId, QueueIdsKeys (..), RcvMessage (..), RcvNtfPublicDhKey, SMPMsgMeta (..), SndPublicVerifyKey) import Simplex.Messaging.Protocol
( AProtocolType (..),
BrokerMsg,
ErrorType,
MsgFlags (..),
MsgId,
NotifierId,
NtfPrivateSignKey,
NtfPublicVerifyKey,
NtfServer,
ProtoServer,
Protocol (..),
ProtocolServer (..),
ProtocolTypeI (..),
QueueId,
QueueIdsKeys (..),
RcvMessage (..),
RcvNtfPublicDhKey,
SMPMsgMeta (..),
SndPublicVerifyKey,
)
import qualified Simplex.Messaging.Protocol as SMP import qualified Simplex.Messaging.Protocol as SMP
import Simplex.Messaging.TMap (TMap) import Simplex.Messaging.TMap (TMap)
import qualified Simplex.Messaging.TMap as TM import qualified Simplex.Messaging.TMap as TM
import Simplex.Messaging.Transport.Client (TransportHost)
import Simplex.Messaging.Util import Simplex.Messaging.Util
import Simplex.Messaging.Version import Simplex.Messaging.Version
import System.Timeout (timeout) import System.Timeout (timeout)
@ -232,46 +253,47 @@ getSMPServerClient c@AgentClient {active, smpClients, msgQ} srv = do
u <- askUnliftIO u <- askUnliftIO
liftEitherError (protocolClientError SMP) (getProtocolClient srv cfg (Just msgQ) $ clientDisconnected u) liftEitherError (protocolClientError SMP) (getProtocolClient srv cfg (Just msgQ) $ clientDisconnected u)
clientDisconnected :: UnliftIO m -> IO () clientDisconnected :: UnliftIO m -> SMPClient -> IO ()
clientDisconnected u = do clientDisconnected u client = do
removeClientAndSubs >>= (`forM_` serverDown u) removeClientAndSubs >>= (`forM_` serverDown)
logInfo . decodeUtf8 $ "Agent disconnected from " <> showServer srv logInfo . decodeUtf8 $ "Agent disconnected from " <> showServer srv
removeClientAndSubs :: IO (Maybe (Map ConnId RcvQueue))
removeClientAndSubs = atomically $ do
TM.delete srv smpClients
TM.lookupDelete srv (subscrSrvrs c) >>= mapM updateSubs
where where
updateSubs cVar = do removeClientAndSubs :: IO (Maybe (Map ConnId RcvQueue))
cs <- readTVar cVar removeClientAndSubs = atomically $ do
modifyTVar' (subscrConns c) (`M.withoutKeys` M.keysSet cs) TM.delete srv smpClients
addPendingSubs cVar cs TM.lookupDelete srv (subscrSrvrs c) >>= mapM updateSubs
pure cs where
updateSubs cVar = do
cs <- readTVar cVar
modifyTVar' (subscrConns c) (`M.withoutKeys` M.keysSet cs)
addPendingSubs cVar cs
pure cs
addPendingSubs cVar cs = do addPendingSubs cVar cs = do
let ps = pendingSubscrSrvrs c let ps = pendingSubscrSrvrs c
TM.lookup srv ps >>= \case TM.lookup srv ps >>= \case
Just v -> TM.union cs v Just v -> TM.union cs v
_ -> TM.insert srv cVar ps _ -> TM.insert srv cVar ps
serverDown :: UnliftIO m -> Map ConnId RcvQueue -> IO () serverDown :: Map ConnId RcvQueue -> IO ()
serverDown u cs = unless (M.null cs) $ serverDown cs = unless (M.null cs) $
whenM (readTVarIO active) $ do whenM (readTVarIO active) $ do
let conns = M.keys cs let conns = M.keys cs
unless (null conns) . notifySub "" $ DOWN srv conns notifySub "" $ hostEvent DISCONNECT client
atomically $ mapM_ (releaseGetLock c) cs unless (null conns) . notifySub "" $ DOWN srv conns
unliftIO u reconnectServer atomically $ mapM_ (releaseGetLock c) cs
unliftIO u reconnectServer
reconnectServer :: m () reconnectServer :: m ()
reconnectServer = do reconnectServer = do
a <- async tryReconnectClient a <- async tryReconnectClient
atomically $ modifyTVar' (reconnections c) (a :) atomically $ modifyTVar' (reconnections c) (a :)
tryReconnectClient :: m () tryReconnectClient :: m ()
tryReconnectClient = do tryReconnectClient = do
ri <- asks $ reconnectInterval . config ri <- asks $ reconnectInterval . config
withRetryInterval ri $ \loop -> withRetryInterval ri $ \loop ->
reconnectClient `catchError` const loop reconnectClient `catchError` const loop
reconnectClient :: m () reconnectClient :: m ()
reconnectClient = reconnectClient =
@ -281,8 +303,11 @@ getSMPServerClient c@AgentClient {active, smpClients, msgQ} srv = do
where where
resubscribe :: Map ConnId RcvQueue -> m () resubscribe :: Map ConnId RcvQueue -> m ()
resubscribe qs = do resubscribe qs = do
(errs, oks) <- M.mapEither id <$> subscribeQueues c srv qs (client_, (errs, oks)) <- second (M.mapEither id) <$> subscribeQueues c srv qs
liftIO . unless (M.null oks) . notifySub "" . UP srv $ M.keys oks liftIO $ do
mapM_ (notifySub "" . hostEvent CONNECT) client_
unless (M.null oks) $ do
notifySub "" . UP srv $ M.keys oks
let (tempErrs, finalErrs) = M.partition temporaryAgentError errs let (tempErrs, finalErrs) = M.partition temporaryAgentError errs
liftIO . mapM_ (\(connId, e) -> notifySub connId $ ERR e) $ M.assocs finalErrs liftIO . mapM_ (\(connId, e) -> notifySub connId $ ERR e) $ M.assocs finalErrs
mapM_ throwError . listToMaybe $ M.elems tempErrs mapM_ throwError . listToMaybe $ M.elems tempErrs
@ -303,9 +328,10 @@ getNtfServerClient c@AgentClient {active, ntfClients} srv = do
cfg <- atomically . updateClientConfig c =<< asks (ntfCfg . config) cfg <- atomically . updateClientConfig c =<< asks (ntfCfg . config)
liftEitherError (protocolClientError NTF) (getProtocolClient srv cfg Nothing clientDisconnected) liftEitherError (protocolClientError NTF) (getProtocolClient srv cfg Nothing clientDisconnected)
clientDisconnected :: IO () clientDisconnected :: NtfClient -> IO ()
clientDisconnected = do clientDisconnected client = do
atomically $ TM.delete srv ntfClients atomically $ TM.delete srv ntfClients
atomically $ writeTBQueue (subQ c) ("", "", hostEvent DISCONNECT client)
logInfo . decodeUtf8 $ "Agent disconnected from " <> showServer srv logInfo . decodeUtf8 $ "Agent disconnected from " <> showServer srv
getClientVar :: forall a s. ProtocolServer s -> TMap (ProtocolServer s) (TMVar a) -> STM (Either (TMVar a) (TMVar a)) getClientVar :: forall a s. ProtocolServer s -> TMap (ProtocolServer s) (TMVar a) -> STM (Either (TMVar a) (TMVar a))
@ -328,7 +354,7 @@ waitForProtocolClient c clientVar = do
newProtocolClient :: newProtocolClient ::
forall msg m. forall msg m.
AgentMonad m => (AgentMonad m, ProtocolTypeI (ProtoType msg)) =>
AgentClient -> AgentClient ->
ProtoServer msg -> ProtoServer msg ->
TMap (ProtoServer msg) (ClientVar msg) -> TMap (ProtoServer msg) (ClientVar msg) ->
@ -344,6 +370,7 @@ newProtocolClient c srv clients connectClient reconnectClient clientVar = tryCon
Right client -> do Right client -> do
logInfo . decodeUtf8 $ "Agent connected to " <> showServer srv logInfo . decodeUtf8 $ "Agent connected to " <> showServer srv
atomically $ putTMVar clientVar r atomically $ putTMVar clientVar r
atomically $ writeTBQueue (subQ c) ("", "", hostEvent CONNECT client)
successAction client successAction client
Left e -> do Left e -> do
if temporaryAgentError e if temporaryAgentError e
@ -361,6 +388,9 @@ newProtocolClient c srv clients connectClient reconnectClient clientVar = tryCon
ri <- asks $ reconnectInterval . config ri <- asks $ reconnectInterval . config
withRetryInterval ri $ \loop -> void $ tryConnectClient (const reconnectClient) loop withRetryInterval ri $ \loop -> void $ tryConnectClient (const reconnectClient) loop
hostEvent :: forall msg. ProtocolTypeI (ProtoType msg) => (AProtocolType -> TransportHost -> ACommand 'Agent) -> ProtocolClient msg -> ACommand 'Agent
hostEvent event client = event (AProtocolType $ protocolTypeI @(ProtoType msg)) $ transportHost client
updateClientConfig :: AgentClient -> ProtocolClientConfig -> STM ProtocolClientConfig updateClientConfig :: AgentClient -> ProtocolClientConfig -> STM ProtocolClientConfig
updateClientConfig AgentClient {useNetworkConfig} cfg = do updateClientConfig AgentClient {useNetworkConfig} cfg = do
networkConfig <- readTVar useNetworkConfig networkConfig <- readTVar useNetworkConfig
@ -505,14 +535,14 @@ temporaryAgentError = \case
_ -> False _ -> False
-- | subscribe multiple queues - all passed queues should be on the same server -- | subscribe multiple queues - all passed queues should be on the same server
subscribeQueues :: AgentMonad m => AgentClient -> SMPServer -> Map ConnId RcvQueue -> m (Map ConnId (Either AgentErrorType ())) subscribeQueues :: AgentMonad m => AgentClient -> SMPServer -> Map ConnId RcvQueue -> m (Maybe SMPClient, Map ConnId (Either AgentErrorType ()))
subscribeQueues c srv qs = do subscribeQueues c srv qs = do
(errs, qs_) <- partitionEithers <$> mapM checkQueue (M.assocs qs) (errs, qs_) <- partitionEithers <$> mapM checkQueue (M.assocs qs)
forM_ qs_ $ atomically . uncurry (addPendingSubscription c) . swap forM_ qs_ $ atomically . uncurry (addPendingSubscription c) . swap
case L.nonEmpty qs_ of case L.nonEmpty qs_ of
Just qs' -> do Just qs' -> do
smp_ <- tryError (getSMPServerClient c srv) smp_ <- tryError (getSMPServerClient c srv)
M.fromList . (errs <>) <$> case smp_ of (eitherToMaybe smp_,) . M.fromList . (errs <>) <$> case smp_ of
Left e -> pure $ map (second . const $ Left e) qs_ Left e -> pure $ map (second . const $ Left e) qs_
Right smp -> do Right smp -> do
logServer "-->" c srv (bshow (length qs_) <> " queues") "SUB" logServer "-->" c srv (bshow (length qs_) <> " queues") "SUB"
@ -521,7 +551,7 @@ subscribeQueues c srv qs = do
liftIO $ zip qs_ . L.toList <$> subscribeSMPQueues smp qs2 liftIO $ zip qs_ . L.toList <$> subscribeSMPQueues smp qs2
forM_ rs' $ \((connId, rq), r) -> liftIO $ processSubResult c rq connId r forM_ rs' $ \((connId, rq), r) -> liftIO $ processSubResult c rq connId r
pure $ map (bimap fst (first $ protocolClientError SMP)) rs' pure $ map (bimap fst (first $ protocolClientError SMP)) rs'
_ -> pure $ M.fromList errs _ -> pure $ (Nothing, M.fromList errs)
where where
checkQueue rq@(connId, RcvQueue {rcvId, server}) = do checkQueue rq@(connId, RcvQueue {rcvId, server}) = do
prohibited <- atomically . TM.member (server, rcvId) $ getMsgLocks c prohibited <- atomically . TM.member (server, rcvId) $ getMsgLocks c

View File

@ -40,6 +40,7 @@ module Simplex.Messaging.Agent.Protocol
-- * SMP agent protocol types -- * SMP agent protocol types
ConnInfo, ConnInfo,
ACommand (..), ACommand (..),
ACmd (..),
AParty (..), AParty (..),
SAParty (..), SAParty (..),
MsgHash, MsgHash,
@ -141,7 +142,8 @@ import Simplex.Messaging.Encoding
import Simplex.Messaging.Encoding.String import Simplex.Messaging.Encoding.String
import Simplex.Messaging.Parsers import Simplex.Messaging.Parsers
import Simplex.Messaging.Protocol import Simplex.Messaging.Protocol
( ErrorType, ( AProtocolType,
ErrorType,
MsgBody, MsgBody,
MsgFlags, MsgFlags,
MsgId, MsgId,
@ -228,6 +230,8 @@ data ACommand (p :: AParty) where
CON :: ACommand Agent -- notification that connection is established CON :: ACommand Agent -- notification that connection is established
SUB :: ACommand Client SUB :: ACommand Client
END :: ACommand Agent END :: ACommand Agent
CONNECT :: AProtocolType -> TransportHost -> ACommand Agent
DISCONNECT :: AProtocolType -> TransportHost -> ACommand Agent
DOWN :: SMPServer -> [ConnId] -> ACommand Agent DOWN :: SMPServer -> [ConnId] -> ACommand Agent
UP :: SMPServer -> [ConnId] -> ACommand Agent UP :: SMPServer -> [ConnId] -> ACommand Agent
SEND :: MsgFlags -> MsgBody -> ACommand Client SEND :: MsgFlags -> MsgBody -> ACommand Client
@ -929,8 +933,10 @@ commandP =
<|> "INFO " *> infoCmd <|> "INFO " *> infoCmd
<|> "SUB" $> ACmd SClient SUB <|> "SUB" $> ACmd SClient SUB
<|> "END" $> ACmd SAgent END <|> "END" $> ACmd SAgent END
<|> "DOWN " *> downsResp <|> "CONNECT " *> connectResp
<|> "UP " *> upsResp <|> "DISCONNECT " *> disconnectResp
<|> "DOWN " *> downResp
<|> "UP " *> upResp
<|> "SEND " *> sendCmd <|> "SEND " *> sendCmd
<|> "MID " *> msgIdResp <|> "MID " *> msgIdResp
<|> "SENT " *> sentResp <|> "SENT " *> sentResp
@ -954,8 +960,10 @@ commandP =
acptCmd = ACmd SClient .: ACPT <$> A.takeTill (== ' ') <* A.space <*> A.takeByteString acptCmd = ACmd SClient .: ACPT <$> A.takeTill (== ' ') <* A.space <*> A.takeByteString
rjctCmd = ACmd SClient . RJCT <$> A.takeByteString rjctCmd = ACmd SClient . RJCT <$> A.takeByteString
infoCmd = ACmd SAgent . INFO <$> A.takeByteString infoCmd = ACmd SAgent . INFO <$> A.takeByteString
downsResp = ACmd SAgent .: DOWN <$> strP_ <*> connections connectResp = ACmd SAgent .: CONNECT <$> strP_ <*> strP
upsResp = ACmd SAgent .: UP <$> strP_ <*> connections disconnectResp = ACmd SAgent .: DISCONNECT <$> strP_ <*> strP
downResp = ACmd SAgent .: DOWN <$> strP_ <*> connections
upResp = ACmd SAgent .: UP <$> strP_ <*> connections
sendCmd = ACmd SClient .: SEND <$> smpP <* A.space <*> A.takeByteString sendCmd = ACmd SClient .: SEND <$> smpP <* A.space <*> A.takeByteString
msgIdResp = ACmd SAgent . MID <$> A.decimal msgIdResp = ACmd SAgent . MID <$> A.decimal
sentResp = ACmd SAgent . SENT <$> A.decimal sentResp = ACmd SAgent . SENT <$> A.decimal
@ -990,6 +998,8 @@ serializeCommand = \case
INFO cInfo -> "INFO " <> serializeBinary cInfo INFO cInfo -> "INFO " <> serializeBinary cInfo
SUB -> "SUB" SUB -> "SUB"
END -> "END" END -> "END"
CONNECT p h -> B.unwords ["CONNECT", strEncode p, strEncode h]
DISCONNECT p h -> B.unwords ["DISCONNECT", strEncode p, strEncode h]
DOWN srv conns -> B.unwords ["DOWN", strEncode srv, connections conns] DOWN srv conns -> B.unwords ["DOWN", strEncode srv, connections conns]
UP srv conns -> B.unwords ["UP", strEncode srv, connections conns] UP srv conns -> B.unwords ["UP", strEncode srv, connections conns]
SEND msgFlags msgBody -> "SEND " <> smpEncode msgFlags <> " " <> serializeBinary msgBody SEND msgFlags msgBody -> "SEND " <> smpEncode msgFlags <> " " <> serializeBinary msgBody
@ -1062,6 +1072,8 @@ tGet party h = liftIO (tGetRaw h) >>= tParseLoadBody
ACPT {} -> Right cmd ACPT {} -> Right cmd
-- ERROR response does not always have connId -- ERROR response does not always have connId
ERR _ -> Right cmd ERR _ -> Right cmd
CONNECT {} -> Right cmd
DISCONNECT {} -> Right cmd
DOWN {} -> Right cmd DOWN {} -> Right cmd
UP {} -> Right cmd UP {} -> Right cmd
-- other responses must have connId -- other responses must have connId

View File

@ -26,7 +26,7 @@
-- See https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md -- See https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md
module Simplex.Messaging.Client module Simplex.Messaging.Client
( -- * Connect (disconnect) client to (from) SMP server ( -- * Connect (disconnect) client to (from) SMP server
ProtocolClient (thVersion, sessionId), ProtocolClient (thVersion, sessionId, transportHost),
SMPClient, SMPClient,
getProtocolClient, getProtocolClient,
closeProtocolClient, closeProtocolClient,
@ -101,6 +101,7 @@ data ProtocolClient msg = ProtocolClient
sessionId :: SessionId, sessionId :: SessionId,
thVersion :: Version, thVersion :: Version,
protocolServer :: ProtoServer msg, protocolServer :: ProtoServer msg,
transportHost :: TransportHost,
tcpTimeout :: Int, tcpTimeout :: Int,
clientCorrId :: TVar Natural, clientCorrId :: TVar Natural,
sentCommands :: TMap CorrId (Request msg), sentCommands :: TMap CorrId (Request msg),
@ -216,17 +217,17 @@ chooseTransportHost NetworkConfig {socksProxy, hostMode, requiredHostMode} hosts
-- --
-- A single queue can be used for multiple 'SMPClient' instances, -- A single queue can be used for multiple 'SMPClient' instances,
-- as 'SMPServerTransmission' includes server information. -- as 'SMPServerTransmission' includes server information.
getProtocolClient :: forall msg. Protocol msg => ProtoServer msg -> ProtocolClientConfig -> Maybe (TBQueue (ServerTransmission msg)) -> IO () -> IO (Either ProtocolClientError (ProtocolClient msg)) getProtocolClient :: forall msg. Protocol msg => ProtoServer msg -> ProtocolClientConfig -> Maybe (TBQueue (ServerTransmission msg)) -> (ProtocolClient msg -> IO ()) -> IO (Either ProtocolClientError (ProtocolClient msg))
getProtocolClient protocolServer cfg@ProtocolClientConfig {qSize, networkConfig, smpServerVRange} msgQ disconnected = do getProtocolClient protocolServer cfg@ProtocolClientConfig {qSize, networkConfig, smpServerVRange} msgQ disconnected = do
case chooseTransportHost networkConfig (host protocolServer) of case chooseTransportHost networkConfig (host protocolServer) of
Right useHost -> Right useHost ->
(atomically mkProtocolClient >>= runClient useTransport useHost) (atomically (mkProtocolClient useHost) >>= runClient useTransport useHost)
`catch` \(e :: IOException) -> pure . Left $ PCEIOError e `catch` \(e :: IOException) -> pure . Left $ PCEIOError e
Left e -> pure $ Left e Left e -> pure $ Left e
where where
NetworkConfig {tcpConnectTimeout, tcpTimeout, tcpKeepAlive, socksProxy, smpPingInterval} = networkConfig NetworkConfig {tcpConnectTimeout, tcpTimeout, tcpKeepAlive, socksProxy, smpPingInterval} = networkConfig
mkProtocolClient :: STM (ProtocolClient msg) mkProtocolClient :: TransportHost -> STM (ProtocolClient msg)
mkProtocolClient = do mkProtocolClient transportHost = do
connected <- newTVar False connected <- newTVar False
clientCorrId <- newTVar 0 clientCorrId <- newTVar 0
sentCommands <- TM.empty sentCommands <- TM.empty
@ -239,6 +240,7 @@ getProtocolClient protocolServer cfg@ProtocolClientConfig {qSize, networkConfig,
thVersion = undefined, thVersion = undefined,
connected, connected,
protocolServer, protocolServer,
transportHost,
tcpTimeout, tcpTimeout,
clientCorrId, clientCorrId,
sentCommands, sentCommands,
@ -277,7 +279,7 @@ getProtocolClient protocolServer cfg@ProtocolClientConfig {qSize, networkConfig,
let c' = c {sessionId, thVersion} :: ProtocolClient msg let c' = c {sessionId, thVersion} :: ProtocolClient msg
-- TODO remove ping if 0 is passed (or Nothing?) -- TODO remove ping if 0 is passed (or Nothing?)
raceAny_ [send c' th, process c', receive c' th, ping c'] raceAny_ [send c' th, process c', receive c' th, ping c']
`finally` disconnected `finally` disconnected c'
send :: Transport c => ProtocolClient msg -> THandle c -> IO () send :: Transport c => ProtocolClient msg -> THandle c -> IO ()
send ProtocolClient {sndQ} h = forever $ atomically (readTBQueue sndQ) >>= tPut h send ProtocolClient {sndQ} h = forever $ atomically (readTBQueue sndQ) >>= tPut h

View File

@ -162,8 +162,8 @@ getSMPServerClient' ca@SMPClientAgent {agentCfg, smpClients, msgQ} srv =
connectClient :: ExceptT ProtocolClientError IO SMPClient connectClient :: ExceptT ProtocolClientError IO SMPClient
connectClient = ExceptT $ getProtocolClient srv (smpCfg agentCfg) (Just msgQ) clientDisconnected connectClient = ExceptT $ getProtocolClient srv (smpCfg agentCfg) (Just msgQ) clientDisconnected
clientDisconnected :: IO () clientDisconnected :: SMPClient -> IO ()
clientDisconnected = do clientDisconnected _ = do
removeClientAndSubs >>= (`forM_` serverDown) removeClientAndSubs >>= (`forM_` serverDown)
logInfo . decodeUtf8 $ "Agent disconnected from " <> showServer srv logInfo . decodeUtf8 $ "Agent disconnected from " <> showServer srv

View File

@ -64,6 +64,8 @@ module Simplex.Messaging.Protocol
PrivHeader (..), PrivHeader (..),
Protocol (..), Protocol (..),
ProtocolType (..), ProtocolType (..),
AProtocolType (..),
ProtocolTypeI (..),
ProtocolServer (..), ProtocolServer (..),
ProtoServer, ProtoServer,
SMPServer, SMPServer,
@ -135,7 +137,7 @@ import qualified Data.ByteString.Char8 as B
import Data.Kind import Data.Kind
import Data.List.NonEmpty (NonEmpty (..)) import Data.List.NonEmpty (NonEmpty (..))
import qualified Data.List.NonEmpty as L import qualified Data.List.NonEmpty as L
import Data.Maybe (isNothing) import Data.Maybe (isJust, isNothing)
import Data.String import Data.String
import Data.Time.Clock.System (SystemTime (..)) import Data.Time.Clock.System (SystemTime (..))
import Data.Type.Equality import Data.Type.Equality
@ -578,7 +580,7 @@ instance StrEncoding ProtocolType where
PSMP -> "smp" PSMP -> "smp"
PNTF -> "ntf" PNTF -> "ntf"
strP = strP =
A.takeTill (== ':') >>= \case A.takeTill (\c -> c == ':' || c == ' ') >>= \case
"smp" -> pure PSMP "smp" -> pure PSMP
"ntf" -> pure PNTF "ntf" -> pure PNTF
_ -> fail "bad ProtocolType" _ -> fail "bad ProtocolType"
@ -595,6 +597,11 @@ deriving instance Show (SProtocolType p)
data AProtocolType = forall p. ProtocolTypeI p => AProtocolType (SProtocolType p) data AProtocolType = forall p. ProtocolTypeI p => AProtocolType (SProtocolType p)
deriving instance Show AProtocolType
instance Eq AProtocolType where
AProtocolType p == AProtocolType p' = isJust $ testEquality p p'
instance TestEquality SProtocolType where instance TestEquality SProtocolType where
testEquality SPSMP SPSMP = Just Refl testEquality SPSMP SPSMP = Just Refl
testEquality SPNTF SPNTF = Just Refl testEquality SPNTF SPNTF = Just Refl
@ -618,6 +625,10 @@ instance StrEncoding AProtocolType where
strEncode (AProtocolType p) = strEncode p strEncode (AProtocolType p) = strEncode p
strP = aProtocolType <$> strP strP = aProtocolType <$> strP
instance ToJSON AProtocolType where
toEncoding = strToJEncoding
toJSON = strToJSON
checkProtocolType :: forall t p p'. (ProtocolTypeI p, ProtocolTypeI p') => t p' -> Either String (t p) checkProtocolType :: forall t p p'. (ProtocolTypeI p, ProtocolTypeI p') => t p' -> Either String (t p)
checkProtocolType p = case testEquality (protocolTypeI @p) (protocolTypeI @p') of checkProtocolType p = case testEquality (protocolTypeI @p) (protocolTypeI @p') of
Just Refl -> Right p Just Refl -> Right p

View File

@ -73,6 +73,10 @@ instance StrEncoding TransportHost where
where where
ipNum = A.decimal <* A.char '.' ipNum = A.decimal <* A.char '.'
instance ToJSON TransportHost where
toEncoding = strToJEncoding
toJSON = strToJSON
newtype TransportHosts = TransportHosts {thList :: NonEmpty TransportHost} newtype TransportHosts = TransportHosts {thList :: NonEmpty TransportHost}
instance StrEncoding TransportHosts where instance StrEncoding TransportHosts where

View File

@ -78,9 +78,17 @@ agentTests (ATransport t) = do
it "should deliver messages if one of connections has quota exceeded" $ it "should deliver messages if one of connections has quota exceeded" $
smpAgentTest2_2_1 $ testMsgDeliveryQuotaExceeded t smpAgentTest2_2_1 $ testMsgDeliveryQuotaExceeded t
tGetAgent :: Transport c => c -> IO (ATransmissionOrError 'Agent)
tGetAgent h = do
t@(_, _, cmd) <- tGet SAgent h
case cmd of
Right CONNECT {} -> tGetAgent h
Right DISCONNECT {} -> tGetAgent h
_ -> pure t
-- | receive message to handle `h` -- | receive message to handle `h`
(<#:) :: Transport c => c -> IO (ATransmissionOrError 'Agent) (<#:) :: Transport c => c -> IO (ATransmissionOrError 'Agent)
(<#:) = tGet SAgent (<#:) = tGetAgent
-- | send transmission `t` to handle `h` and get response -- | send transmission `t` to handle `h` and get response
(#:) :: Transport c => c -> (ByteString, ByteString, ByteString) -> IO (ATransmissionOrError 'Agent) (#:) :: Transport c => c -> (ByteString, ByteString, ByteString) -> IO (ATransmissionOrError 'Agent)
@ -114,7 +122,7 @@ h <#= p = (h <#:) >>= (`shouldSatisfy` p . correctTransmission)
h #:# err = tryGet `shouldReturn` () h #:# err = tryGet `shouldReturn` ()
where where
tryGet = tryGet =
10000 `timeout` tGet SAgent h >>= \case 10000 `timeout` tGetAgent h >>= \case
Just _ -> error err Just _ -> error err
_ -> return () _ -> return ()

View File

@ -47,7 +47,12 @@ a ##> t = a >>= \t' -> liftIO (t' `shouldBe` t)
a =##> p = a >>= \t -> liftIO (t `shouldSatisfy` p) a =##> p = a >>= \t -> liftIO (t `shouldSatisfy` p)
get :: MonadIO m => AgentClient -> m (ATransmission 'Agent) get :: MonadIO m => AgentClient -> m (ATransmission 'Agent)
get c = atomically (readTBQueue $ subQ c) get c = do
t@(_, _, cmd) <- atomically (readTBQueue $ subQ c)
case cmd of
CONNECT {} -> get c
DISCONNECT {} -> get c
_ -> pure t
pattern Msg :: MsgBody -> ACommand 'Agent pattern Msg :: MsgBody -> ACommand 'Agent
pattern Msg msgBody <- MSG MsgMeta {integrity = MsgOk} _ msgBody pattern Msg msgBody <- MSG MsgMeta {integrity = MsgOk} _ msgBody

View File

@ -1,4 +1,5 @@
{-# LANGUAGE DuplicateRecordFields #-} {-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE NumericUnderscores #-} {-# LANGUAGE NumericUnderscores #-}
{-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE ScopedTypeVariables #-}
@ -27,6 +28,7 @@ import Simplex.Messaging.Agent.Protocol
import Simplex.Messaging.Agent.RetryInterval import Simplex.Messaging.Agent.RetryInterval
import Simplex.Messaging.Agent.Server (runSMPAgentBlocking) import Simplex.Messaging.Agent.Server (runSMPAgentBlocking)
import Simplex.Messaging.Client (ProtocolClientConfig (..), chooseTransportHost, defaultClientConfig, defaultNetworkConfig) import Simplex.Messaging.Client (ProtocolClientConfig (..), chooseTransportHost, defaultClientConfig, defaultNetworkConfig)
import Simplex.Messaging.Parsers (parseAll)
import Simplex.Messaging.Transport import Simplex.Messaging.Transport
import Simplex.Messaging.Transport.Client import Simplex.Messaging.Transport.Client
import Simplex.Messaging.Transport.KeepAlive import Simplex.Messaging.Transport.KeepAlive
@ -56,7 +58,14 @@ testDB3 :: String
testDB3 = "tests/tmp/smp-agent3.test.protocol.db" testDB3 = "tests/tmp/smp-agent3.test.protocol.db"
smpAgentTest :: forall c. Transport c => TProxy c -> ARawTransmission -> IO ARawTransmission smpAgentTest :: forall c. Transport c => TProxy c -> ARawTransmission -> IO ARawTransmission
smpAgentTest _ cmd = runSmpAgentTest $ \(h :: c) -> tPutRaw h cmd >> tGetRaw h smpAgentTest _ cmd = runSmpAgentTest $ \(h :: c) -> tPutRaw h cmd >> get h
where
get h = do
t@(_, _, cmdStr) <- tGetRaw h
case parseAll commandP cmdStr of
Right (ACmd SAgent CONNECT {}) -> get h
Right (ACmd SAgent DISCONNECT {}) -> get h
_ -> pure t
runSmpAgentTest :: forall c m a. (Transport c, MonadUnliftIO m, MonadRandom m, MonadFail m) => (c -> m a) -> m a runSmpAgentTest :: forall c m a. (Transport c, MonadUnliftIO m, MonadRandom m, MonadFail m) => (c -> m a) -> m a
runSmpAgentTest test = withSmpServer t . withSmpAgent t $ testSMPAgentClient test runSmpAgentTest test = withSmpServer t . withSmpAgent t $ testSMPAgentClient test
@ -177,7 +186,7 @@ agentCfg :: AgentConfig
agentCfg = agentCfg =
defaultAgentConfig defaultAgentConfig
{ tcpPort = agentTestPort, { tcpPort = agentTestPort,
tbqSize = 1, tbqSize = 4,
dbFile = testDB, dbFile = testDB,
smpCfg = smpCfg =
defaultClientConfig defaultClientConfig