send host events when server hosts are connected and disconnected (#496)

This commit is contained in:
Evgeny Poberezkin 2022-08-13 11:57:36 +01:00 committed by GitHub
parent 68138c08d2
commit 02bba01c16
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 145 additions and 64 deletions

View File

@ -457,7 +457,7 @@ subscribeConnections' c connIds = do
addRcvQueue :: Map SMPServer (Map ConnId (RcvQueue, ConnData)) -> ConnId -> (RcvQueue, ConnData) -> Map SMPServer (Map ConnId (RcvQueue, ConnData))
addRcvQueue m connId rq@(RcvQueue {server}, _) = M.alter (Just . maybe (M.singleton connId rq) (M.insert connId rq)) server m
subscribe :: (SMPServer, Map ConnId (RcvQueue, ConnData)) -> m (Map ConnId (Either AgentErrorType ()))
subscribe (srv, qs) = subscribeQueues c srv (M.map fst qs)
subscribe (srv, qs) = snd <$> subscribeQueues c srv (M.map fst qs)
sendNtfCreate :: NtfSupervisor -> [Map ConnId (Either AgentErrorType ())] -> m ()
sendNtfCreate ns rcvRs =
forM_ (concatMap M.assocs rcvRs) $ \case

View File

@ -106,10 +106,31 @@ import Simplex.Messaging.Notifications.Client
import Simplex.Messaging.Notifications.Protocol
import Simplex.Messaging.Notifications.Types
import Simplex.Messaging.Parsers (parse)
import Simplex.Messaging.Protocol (BrokerMsg, ErrorType, MsgFlags (..), MsgId, NotifierId, NtfPrivateSignKey, NtfPublicVerifyKey, NtfServer, ProtoServer, ProtocolServer (..), QueueId, QueueIdsKeys (..), RcvMessage (..), RcvNtfPublicDhKey, SMPMsgMeta (..), SndPublicVerifyKey)
import Simplex.Messaging.Protocol
( AProtocolType (..),
BrokerMsg,
ErrorType,
MsgFlags (..),
MsgId,
NotifierId,
NtfPrivateSignKey,
NtfPublicVerifyKey,
NtfServer,
ProtoServer,
Protocol (..),
ProtocolServer (..),
ProtocolTypeI (..),
QueueId,
QueueIdsKeys (..),
RcvMessage (..),
RcvNtfPublicDhKey,
SMPMsgMeta (..),
SndPublicVerifyKey,
)
import qualified Simplex.Messaging.Protocol as SMP
import Simplex.Messaging.TMap (TMap)
import qualified Simplex.Messaging.TMap as TM
import Simplex.Messaging.Transport.Client (TransportHost)
import Simplex.Messaging.Util
import Simplex.Messaging.Version
import System.Timeout (timeout)
@ -232,46 +253,47 @@ getSMPServerClient c@AgentClient {active, smpClients, msgQ} srv = do
u <- askUnliftIO
liftEitherError (protocolClientError SMP) (getProtocolClient srv cfg (Just msgQ) $ clientDisconnected u)
clientDisconnected :: UnliftIO m -> IO ()
clientDisconnected u = do
removeClientAndSubs >>= (`forM_` serverDown u)
clientDisconnected :: UnliftIO m -> SMPClient -> IO ()
clientDisconnected u client = do
removeClientAndSubs >>= (`forM_` serverDown)
logInfo . decodeUtf8 $ "Agent disconnected from " <> showServer srv
removeClientAndSubs :: IO (Maybe (Map ConnId RcvQueue))
removeClientAndSubs = atomically $ do
TM.delete srv smpClients
TM.lookupDelete srv (subscrSrvrs c) >>= mapM updateSubs
where
updateSubs cVar = do
cs <- readTVar cVar
modifyTVar' (subscrConns c) (`M.withoutKeys` M.keysSet cs)
addPendingSubs cVar cs
pure cs
removeClientAndSubs :: IO (Maybe (Map ConnId RcvQueue))
removeClientAndSubs = atomically $ do
TM.delete srv smpClients
TM.lookupDelete srv (subscrSrvrs c) >>= mapM updateSubs
where
updateSubs cVar = do
cs <- readTVar cVar
modifyTVar' (subscrConns c) (`M.withoutKeys` M.keysSet cs)
addPendingSubs cVar cs
pure cs
addPendingSubs cVar cs = do
let ps = pendingSubscrSrvrs c
TM.lookup srv ps >>= \case
Just v -> TM.union cs v
_ -> TM.insert srv cVar ps
addPendingSubs cVar cs = do
let ps = pendingSubscrSrvrs c
TM.lookup srv ps >>= \case
Just v -> TM.union cs v
_ -> TM.insert srv cVar ps
serverDown :: UnliftIO m -> Map ConnId RcvQueue -> IO ()
serverDown u cs = unless (M.null cs) $
whenM (readTVarIO active) $ do
let conns = M.keys cs
unless (null conns) . notifySub "" $ DOWN srv conns
atomically $ mapM_ (releaseGetLock c) cs
unliftIO u reconnectServer
serverDown :: Map ConnId RcvQueue -> IO ()
serverDown cs = unless (M.null cs) $
whenM (readTVarIO active) $ do
let conns = M.keys cs
notifySub "" $ hostEvent DISCONNECT client
unless (null conns) . notifySub "" $ DOWN srv conns
atomically $ mapM_ (releaseGetLock c) cs
unliftIO u reconnectServer
reconnectServer :: m ()
reconnectServer = do
a <- async tryReconnectClient
atomically $ modifyTVar' (reconnections c) (a :)
reconnectServer :: m ()
reconnectServer = do
a <- async tryReconnectClient
atomically $ modifyTVar' (reconnections c) (a :)
tryReconnectClient :: m ()
tryReconnectClient = do
ri <- asks $ reconnectInterval . config
withRetryInterval ri $ \loop ->
reconnectClient `catchError` const loop
tryReconnectClient :: m ()
tryReconnectClient = do
ri <- asks $ reconnectInterval . config
withRetryInterval ri $ \loop ->
reconnectClient `catchError` const loop
reconnectClient :: m ()
reconnectClient =
@ -281,8 +303,11 @@ getSMPServerClient c@AgentClient {active, smpClients, msgQ} srv = do
where
resubscribe :: Map ConnId RcvQueue -> m ()
resubscribe qs = do
(errs, oks) <- M.mapEither id <$> subscribeQueues c srv qs
liftIO . unless (M.null oks) . notifySub "" . UP srv $ M.keys oks
(client_, (errs, oks)) <- second (M.mapEither id) <$> subscribeQueues c srv qs
liftIO $ do
mapM_ (notifySub "" . hostEvent CONNECT) client_
unless (M.null oks) $ do
notifySub "" . UP srv $ M.keys oks
let (tempErrs, finalErrs) = M.partition temporaryAgentError errs
liftIO . mapM_ (\(connId, e) -> notifySub connId $ ERR e) $ M.assocs finalErrs
mapM_ throwError . listToMaybe $ M.elems tempErrs
@ -303,9 +328,10 @@ getNtfServerClient c@AgentClient {active, ntfClients} srv = do
cfg <- atomically . updateClientConfig c =<< asks (ntfCfg . config)
liftEitherError (protocolClientError NTF) (getProtocolClient srv cfg Nothing clientDisconnected)
clientDisconnected :: IO ()
clientDisconnected = do
clientDisconnected :: NtfClient -> IO ()
clientDisconnected client = do
atomically $ TM.delete srv ntfClients
atomically $ writeTBQueue (subQ c) ("", "", hostEvent DISCONNECT client)
logInfo . decodeUtf8 $ "Agent disconnected from " <> showServer srv
getClientVar :: forall a s. ProtocolServer s -> TMap (ProtocolServer s) (TMVar a) -> STM (Either (TMVar a) (TMVar a))
@ -328,7 +354,7 @@ waitForProtocolClient c clientVar = do
newProtocolClient ::
forall msg m.
AgentMonad m =>
(AgentMonad m, ProtocolTypeI (ProtoType msg)) =>
AgentClient ->
ProtoServer msg ->
TMap (ProtoServer msg) (ClientVar msg) ->
@ -344,6 +370,7 @@ newProtocolClient c srv clients connectClient reconnectClient clientVar = tryCon
Right client -> do
logInfo . decodeUtf8 $ "Agent connected to " <> showServer srv
atomically $ putTMVar clientVar r
atomically $ writeTBQueue (subQ c) ("", "", hostEvent CONNECT client)
successAction client
Left e -> do
if temporaryAgentError e
@ -361,6 +388,9 @@ newProtocolClient c srv clients connectClient reconnectClient clientVar = tryCon
ri <- asks $ reconnectInterval . config
withRetryInterval ri $ \loop -> void $ tryConnectClient (const reconnectClient) loop
hostEvent :: forall msg. ProtocolTypeI (ProtoType msg) => (AProtocolType -> TransportHost -> ACommand 'Agent) -> ProtocolClient msg -> ACommand 'Agent
hostEvent event client = event (AProtocolType $ protocolTypeI @(ProtoType msg)) $ transportHost client
updateClientConfig :: AgentClient -> ProtocolClientConfig -> STM ProtocolClientConfig
updateClientConfig AgentClient {useNetworkConfig} cfg = do
networkConfig <- readTVar useNetworkConfig
@ -505,14 +535,14 @@ temporaryAgentError = \case
_ -> False
-- | subscribe multiple queues - all passed queues should be on the same server
subscribeQueues :: AgentMonad m => AgentClient -> SMPServer -> Map ConnId RcvQueue -> m (Map ConnId (Either AgentErrorType ()))
subscribeQueues :: AgentMonad m => AgentClient -> SMPServer -> Map ConnId RcvQueue -> m (Maybe SMPClient, Map ConnId (Either AgentErrorType ()))
subscribeQueues c srv qs = do
(errs, qs_) <- partitionEithers <$> mapM checkQueue (M.assocs qs)
forM_ qs_ $ atomically . uncurry (addPendingSubscription c) . swap
case L.nonEmpty qs_ of
Just qs' -> do
smp_ <- tryError (getSMPServerClient c srv)
M.fromList . (errs <>) <$> case smp_ of
(eitherToMaybe smp_,) . M.fromList . (errs <>) <$> case smp_ of
Left e -> pure $ map (second . const $ Left e) qs_
Right smp -> do
logServer "-->" c srv (bshow (length qs_) <> " queues") "SUB"
@ -521,7 +551,7 @@ subscribeQueues c srv qs = do
liftIO $ zip qs_ . L.toList <$> subscribeSMPQueues smp qs2
forM_ rs' $ \((connId, rq), r) -> liftIO $ processSubResult c rq connId r
pure $ map (bimap fst (first $ protocolClientError SMP)) rs'
_ -> pure $ M.fromList errs
_ -> pure $ (Nothing, M.fromList errs)
where
checkQueue rq@(connId, RcvQueue {rcvId, server}) = do
prohibited <- atomically . TM.member (server, rcvId) $ getMsgLocks c

View File

@ -40,6 +40,7 @@ module Simplex.Messaging.Agent.Protocol
-- * SMP agent protocol types
ConnInfo,
ACommand (..),
ACmd (..),
AParty (..),
SAParty (..),
MsgHash,
@ -141,7 +142,8 @@ import Simplex.Messaging.Encoding
import Simplex.Messaging.Encoding.String
import Simplex.Messaging.Parsers
import Simplex.Messaging.Protocol
( ErrorType,
( AProtocolType,
ErrorType,
MsgBody,
MsgFlags,
MsgId,
@ -228,6 +230,8 @@ data ACommand (p :: AParty) where
CON :: ACommand Agent -- notification that connection is established
SUB :: ACommand Client
END :: ACommand Agent
CONNECT :: AProtocolType -> TransportHost -> ACommand Agent
DISCONNECT :: AProtocolType -> TransportHost -> ACommand Agent
DOWN :: SMPServer -> [ConnId] -> ACommand Agent
UP :: SMPServer -> [ConnId] -> ACommand Agent
SEND :: MsgFlags -> MsgBody -> ACommand Client
@ -929,8 +933,10 @@ commandP =
<|> "INFO " *> infoCmd
<|> "SUB" $> ACmd SClient SUB
<|> "END" $> ACmd SAgent END
<|> "DOWN " *> downsResp
<|> "UP " *> upsResp
<|> "CONNECT " *> connectResp
<|> "DISCONNECT " *> disconnectResp
<|> "DOWN " *> downResp
<|> "UP " *> upResp
<|> "SEND " *> sendCmd
<|> "MID " *> msgIdResp
<|> "SENT " *> sentResp
@ -954,8 +960,10 @@ commandP =
acptCmd = ACmd SClient .: ACPT <$> A.takeTill (== ' ') <* A.space <*> A.takeByteString
rjctCmd = ACmd SClient . RJCT <$> A.takeByteString
infoCmd = ACmd SAgent . INFO <$> A.takeByteString
downsResp = ACmd SAgent .: DOWN <$> strP_ <*> connections
upsResp = ACmd SAgent .: UP <$> strP_ <*> connections
connectResp = ACmd SAgent .: CONNECT <$> strP_ <*> strP
disconnectResp = ACmd SAgent .: DISCONNECT <$> strP_ <*> strP
downResp = ACmd SAgent .: DOWN <$> strP_ <*> connections
upResp = ACmd SAgent .: UP <$> strP_ <*> connections
sendCmd = ACmd SClient .: SEND <$> smpP <* A.space <*> A.takeByteString
msgIdResp = ACmd SAgent . MID <$> A.decimal
sentResp = ACmd SAgent . SENT <$> A.decimal
@ -990,6 +998,8 @@ serializeCommand = \case
INFO cInfo -> "INFO " <> serializeBinary cInfo
SUB -> "SUB"
END -> "END"
CONNECT p h -> B.unwords ["CONNECT", strEncode p, strEncode h]
DISCONNECT p h -> B.unwords ["DISCONNECT", strEncode p, strEncode h]
DOWN srv conns -> B.unwords ["DOWN", strEncode srv, connections conns]
UP srv conns -> B.unwords ["UP", strEncode srv, connections conns]
SEND msgFlags msgBody -> "SEND " <> smpEncode msgFlags <> " " <> serializeBinary msgBody
@ -1062,6 +1072,8 @@ tGet party h = liftIO (tGetRaw h) >>= tParseLoadBody
ACPT {} -> Right cmd
-- ERROR response does not always have connId
ERR _ -> Right cmd
CONNECT {} -> Right cmd
DISCONNECT {} -> Right cmd
DOWN {} -> Right cmd
UP {} -> Right cmd
-- other responses must have connId

View File

@ -26,7 +26,7 @@
-- See https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md
module Simplex.Messaging.Client
( -- * Connect (disconnect) client to (from) SMP server
ProtocolClient (thVersion, sessionId),
ProtocolClient (thVersion, sessionId, transportHost),
SMPClient,
getProtocolClient,
closeProtocolClient,
@ -101,6 +101,7 @@ data ProtocolClient msg = ProtocolClient
sessionId :: SessionId,
thVersion :: Version,
protocolServer :: ProtoServer msg,
transportHost :: TransportHost,
tcpTimeout :: Int,
clientCorrId :: TVar Natural,
sentCommands :: TMap CorrId (Request msg),
@ -216,17 +217,17 @@ chooseTransportHost NetworkConfig {socksProxy, hostMode, requiredHostMode} hosts
--
-- A single queue can be used for multiple 'SMPClient' instances,
-- as 'SMPServerTransmission' includes server information.
getProtocolClient :: forall msg. Protocol msg => ProtoServer msg -> ProtocolClientConfig -> Maybe (TBQueue (ServerTransmission msg)) -> IO () -> IO (Either ProtocolClientError (ProtocolClient msg))
getProtocolClient :: forall msg. Protocol msg => ProtoServer msg -> ProtocolClientConfig -> Maybe (TBQueue (ServerTransmission msg)) -> (ProtocolClient msg -> IO ()) -> IO (Either ProtocolClientError (ProtocolClient msg))
getProtocolClient protocolServer cfg@ProtocolClientConfig {qSize, networkConfig, smpServerVRange} msgQ disconnected = do
case chooseTransportHost networkConfig (host protocolServer) of
Right useHost ->
(atomically mkProtocolClient >>= runClient useTransport useHost)
(atomically (mkProtocolClient useHost) >>= runClient useTransport useHost)
`catch` \(e :: IOException) -> pure . Left $ PCEIOError e
Left e -> pure $ Left e
where
NetworkConfig {tcpConnectTimeout, tcpTimeout, tcpKeepAlive, socksProxy, smpPingInterval} = networkConfig
mkProtocolClient :: STM (ProtocolClient msg)
mkProtocolClient = do
mkProtocolClient :: TransportHost -> STM (ProtocolClient msg)
mkProtocolClient transportHost = do
connected <- newTVar False
clientCorrId <- newTVar 0
sentCommands <- TM.empty
@ -239,6 +240,7 @@ getProtocolClient protocolServer cfg@ProtocolClientConfig {qSize, networkConfig,
thVersion = undefined,
connected,
protocolServer,
transportHost,
tcpTimeout,
clientCorrId,
sentCommands,
@ -277,7 +279,7 @@ getProtocolClient protocolServer cfg@ProtocolClientConfig {qSize, networkConfig,
let c' = c {sessionId, thVersion} :: ProtocolClient msg
-- TODO remove ping if 0 is passed (or Nothing?)
raceAny_ [send c' th, process c', receive c' th, ping c']
`finally` disconnected
`finally` disconnected c'
send :: Transport c => ProtocolClient msg -> THandle c -> IO ()
send ProtocolClient {sndQ} h = forever $ atomically (readTBQueue sndQ) >>= tPut h

View File

@ -162,8 +162,8 @@ getSMPServerClient' ca@SMPClientAgent {agentCfg, smpClients, msgQ} srv =
connectClient :: ExceptT ProtocolClientError IO SMPClient
connectClient = ExceptT $ getProtocolClient srv (smpCfg agentCfg) (Just msgQ) clientDisconnected
clientDisconnected :: IO ()
clientDisconnected = do
clientDisconnected :: SMPClient -> IO ()
clientDisconnected _ = do
removeClientAndSubs >>= (`forM_` serverDown)
logInfo . decodeUtf8 $ "Agent disconnected from " <> showServer srv

View File

@ -64,6 +64,8 @@ module Simplex.Messaging.Protocol
PrivHeader (..),
Protocol (..),
ProtocolType (..),
AProtocolType (..),
ProtocolTypeI (..),
ProtocolServer (..),
ProtoServer,
SMPServer,
@ -135,7 +137,7 @@ import qualified Data.ByteString.Char8 as B
import Data.Kind
import Data.List.NonEmpty (NonEmpty (..))
import qualified Data.List.NonEmpty as L
import Data.Maybe (isNothing)
import Data.Maybe (isJust, isNothing)
import Data.String
import Data.Time.Clock.System (SystemTime (..))
import Data.Type.Equality
@ -578,7 +580,7 @@ instance StrEncoding ProtocolType where
PSMP -> "smp"
PNTF -> "ntf"
strP =
A.takeTill (== ':') >>= \case
A.takeTill (\c -> c == ':' || c == ' ') >>= \case
"smp" -> pure PSMP
"ntf" -> pure PNTF
_ -> fail "bad ProtocolType"
@ -595,6 +597,11 @@ deriving instance Show (SProtocolType p)
data AProtocolType = forall p. ProtocolTypeI p => AProtocolType (SProtocolType p)
deriving instance Show AProtocolType
instance Eq AProtocolType where
AProtocolType p == AProtocolType p' = isJust $ testEquality p p'
instance TestEquality SProtocolType where
testEquality SPSMP SPSMP = Just Refl
testEquality SPNTF SPNTF = Just Refl
@ -618,6 +625,10 @@ instance StrEncoding AProtocolType where
strEncode (AProtocolType p) = strEncode p
strP = aProtocolType <$> strP
instance ToJSON AProtocolType where
toEncoding = strToJEncoding
toJSON = strToJSON
checkProtocolType :: forall t p p'. (ProtocolTypeI p, ProtocolTypeI p') => t p' -> Either String (t p)
checkProtocolType p = case testEquality (protocolTypeI @p) (protocolTypeI @p') of
Just Refl -> Right p

View File

@ -73,6 +73,10 @@ instance StrEncoding TransportHost where
where
ipNum = A.decimal <* A.char '.'
instance ToJSON TransportHost where
toEncoding = strToJEncoding
toJSON = strToJSON
newtype TransportHosts = TransportHosts {thList :: NonEmpty TransportHost}
instance StrEncoding TransportHosts where

View File

@ -78,9 +78,17 @@ agentTests (ATransport t) = do
it "should deliver messages if one of connections has quota exceeded" $
smpAgentTest2_2_1 $ testMsgDeliveryQuotaExceeded t
tGetAgent :: Transport c => c -> IO (ATransmissionOrError 'Agent)
tGetAgent h = do
t@(_, _, cmd) <- tGet SAgent h
case cmd of
Right CONNECT {} -> tGetAgent h
Right DISCONNECT {} -> tGetAgent h
_ -> pure t
-- | receive message to handle `h`
(<#:) :: Transport c => c -> IO (ATransmissionOrError 'Agent)
(<#:) = tGet SAgent
(<#:) = tGetAgent
-- | send transmission `t` to handle `h` and get response
(#:) :: Transport c => c -> (ByteString, ByteString, ByteString) -> IO (ATransmissionOrError 'Agent)
@ -114,7 +122,7 @@ h <#= p = (h <#:) >>= (`shouldSatisfy` p . correctTransmission)
h #:# err = tryGet `shouldReturn` ()
where
tryGet =
10000 `timeout` tGet SAgent h >>= \case
10000 `timeout` tGetAgent h >>= \case
Just _ -> error err
_ -> return ()

View File

@ -47,7 +47,12 @@ a ##> t = a >>= \t' -> liftIO (t' `shouldBe` t)
a =##> p = a >>= \t -> liftIO (t `shouldSatisfy` p)
get :: MonadIO m => AgentClient -> m (ATransmission 'Agent)
get c = atomically (readTBQueue $ subQ c)
get c = do
t@(_, _, cmd) <- atomically (readTBQueue $ subQ c)
case cmd of
CONNECT {} -> get c
DISCONNECT {} -> get c
_ -> pure t
pattern Msg :: MsgBody -> ACommand 'Agent
pattern Msg msgBody <- MSG MsgMeta {integrity = MsgOk} _ msgBody

View File

@ -1,4 +1,5 @@
{-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE NumericUnderscores #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
@ -27,6 +28,7 @@ import Simplex.Messaging.Agent.Protocol
import Simplex.Messaging.Agent.RetryInterval
import Simplex.Messaging.Agent.Server (runSMPAgentBlocking)
import Simplex.Messaging.Client (ProtocolClientConfig (..), chooseTransportHost, defaultClientConfig, defaultNetworkConfig)
import Simplex.Messaging.Parsers (parseAll)
import Simplex.Messaging.Transport
import Simplex.Messaging.Transport.Client
import Simplex.Messaging.Transport.KeepAlive
@ -56,7 +58,14 @@ testDB3 :: String
testDB3 = "tests/tmp/smp-agent3.test.protocol.db"
smpAgentTest :: forall c. Transport c => TProxy c -> ARawTransmission -> IO ARawTransmission
smpAgentTest _ cmd = runSmpAgentTest $ \(h :: c) -> tPutRaw h cmd >> tGetRaw h
smpAgentTest _ cmd = runSmpAgentTest $ \(h :: c) -> tPutRaw h cmd >> get h
where
get h = do
t@(_, _, cmdStr) <- tGetRaw h
case parseAll commandP cmdStr of
Right (ACmd SAgent CONNECT {}) -> get h
Right (ACmd SAgent DISCONNECT {}) -> get h
_ -> pure t
runSmpAgentTest :: forall c m a. (Transport c, MonadUnliftIO m, MonadRandom m, MonadFail m) => (c -> m a) -> m a
runSmpAgentTest test = withSmpServer t . withSmpAgent t $ testSMPAgentClient test
@ -177,7 +186,7 @@ agentCfg :: AgentConfig
agentCfg =
defaultAgentConfig
{ tcpPort = agentTestPort,
tbqSize = 1,
tbqSize = 4,
dbFile = testDB,
smpCfg =
defaultClientConfig