primary tokens (#356)

* primary tokens

* support repeat token registration (TODO fix tests)

* fix notifications tests

* fix/test repeat/new registrations of the same token

* re-register token when subsequent ntf command fails with AUTH error (e.g. when server is re-started)

* cancel periodic notifications when token is deleted on the server

* debug failing test on CI

* fix notification test in CI

* debug CI test

* add delay in notificaitons test after server restart
This commit is contained in:
Evgeny Poberezkin 2022-04-21 17:04:26 +01:00 committed by GitHub
parent 4dc7d9bc77
commit e6fbaf5e50
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 291 additions and 84 deletions

View File

@ -52,6 +52,7 @@ module Simplex.Messaging.Agent
registerNtfToken,
verifyNtfToken,
enableNtfCron,
checkNtfToken,
deleteNtfToken,
logConnection,
)
@ -89,10 +90,10 @@ import Simplex.Messaging.Encoding
import Simplex.Messaging.Notifications.Client
import Simplex.Messaging.Notifications.Protocol (DeviceToken, NtfRegCode (NtfRegCode), NtfTknStatus (..))
import Simplex.Messaging.Parsers (parse)
import Simplex.Messaging.Protocol (BrokerMsg, MsgBody)
import Simplex.Messaging.Protocol (BrokerMsg, ErrorType (AUTH), MsgBody)
import qualified Simplex.Messaging.Protocol as SMP
import qualified Simplex.Messaging.TMap as TM
import Simplex.Messaging.Util (bshow, liftError, tryError, unlessM)
import Simplex.Messaging.Util (bshow, liftError, tryError, unlessM, ($>>=))
import Simplex.Messaging.Version
import System.Random (randomR)
import UnliftIO.Async (async, race_)
@ -172,6 +173,9 @@ verifyNtfToken c = withAgentEnv c .:. verifyNtfToken' c
enableNtfCron :: AgentErrorMonad m => AgentClient -> DeviceToken -> Word16 -> m ()
enableNtfCron c = withAgentEnv c .: enableNtfCron' c
checkNtfToken :: AgentErrorMonad m => AgentClient -> DeviceToken -> m NtfTknStatus
checkNtfToken c = withAgentEnv c . checkNtfToken' c
deleteNtfToken :: AgentErrorMonad m => AgentClient -> DeviceToken -> m ()
deleteNtfToken c = withAgentEnv c . deleteNtfToken' c
@ -552,7 +556,7 @@ registerNtfToken' c deviceToken =
registerToken tkn
_ -> throwError $ CMD PROHIBITED
where
t tkn = withToken tkn Nothing
t tkn = withToken c tkn Nothing
registerToken :: NtfToken -> m ()
registerToken tkn@NtfToken {ntfPubKey, ntfDhKeys = (pubDhKey, privDhKey)} = do
(tknId, srvPubDhKey) <- agentNtfRegisterToken c tkn ntfPubKey pubDhKey
@ -565,7 +569,7 @@ verifyNtfToken' c deviceToken code nonce =
withStore (`getDeviceNtfToken` deviceToken) >>= \case
(Just tkn@NtfToken {ntfTokenId = Just tknId, ntfDhSecret = Just dhSecret}, _) -> do
code' <- liftEither . bimap cryptoError NtfRegCode $ C.cbDecrypt dhSecret nonce code
withToken tkn (Just (NTConfirmed, NTAVerify code')) (NTActive, Just NTACheck) $
withToken c tkn (Just (NTConfirmed, NTAVerify code')) (NTActive, Just NTACheck) $
agentNtfVerifyToken c tknId tkn code'
_ -> throwError $ CMD PROHIBITED
@ -574,7 +578,7 @@ enableNtfCron' c deviceToken interval = do
when (interval < 20) . throwError $ CMD PROHIBITED
withStore (`getDeviceNtfToken` deviceToken) >>= \case
(Just tkn@NtfToken {ntfTokenId = Just tknId, ntfTknStatus = NTActive}, _) ->
withToken tkn (Just (NTActive, NTACron interval)) (cronSuccess interval) $
withToken c tkn (Just (NTActive, NTACron interval)) (cronSuccess interval) $
agentNtfEnableCron c tknId tkn interval
_ -> throwError $ CMD PROHIBITED
@ -583,6 +587,12 @@ cronSuccess interval
| interval == 0 = (NTActive, Just NTACheck)
| otherwise = (NTActive, Just $ NTACron interval)
checkNtfToken' :: AgentMonad m => AgentClient -> DeviceToken -> m NtfTknStatus
checkNtfToken' c deviceToken =
withStore (`getDeviceNtfToken` deviceToken) >>= \case
(Just tkn@NtfToken {ntfTokenId = Just tknId}, _) -> agentNtfCheckToken c tknId tkn
_ -> throwError $ CMD PROHIBITED
deleteNtfToken' :: AgentMonad m => AgentClient -> DeviceToken -> m ()
deleteNtfToken' c deviceToken =
withStore (`getDeviceNtfToken` deviceToken) >>= \case
@ -593,15 +603,23 @@ deleteToken_ :: AgentMonad m => AgentClient -> NtfToken -> m ()
deleteToken_ c tkn@NtfToken {ntfTokenId, ntfTknStatus} = do
forM_ ntfTokenId $ \tknId -> do
withStore $ \st -> updateNtfToken st tkn ntfTknStatus (Just NTADelete)
agentNtfDeleteToken c tknId tkn
agentNtfDeleteToken c tknId tkn `catchError` \case
NTF AUTH -> pure ()
e -> throwError e
withStore $ \st -> removeNtfToken st tkn
withToken :: AgentMonad m => NtfToken -> Maybe (NtfTknStatus, NtfTknAction) -> (NtfTknStatus, Maybe NtfTknAction) -> m a -> m a
withToken tkn from_ (toStatus, toAction_) f = do
withToken :: AgentMonad m => AgentClient -> NtfToken -> Maybe (NtfTknStatus, NtfTknAction) -> (NtfTknStatus, Maybe NtfTknAction) -> m a -> m a
withToken c tkn@NtfToken {deviceToken} from_ (toStatus, toAction_) f = do
forM_ from_ $ \(status, action) -> withStore $ \st -> updateNtfToken st tkn status (Just action)
res <- f
withStore $ \st -> updateNtfToken st tkn toStatus toAction_
pure res
tryError f >>= \case
Right res -> do
withStore $ \st -> updateNtfToken st tkn toStatus toAction_
pure res
Left e@(NTF AUTH) -> do
withStore $ \st -> removeNtfToken st tkn
registerNtfToken' c deviceToken
throwError e
Left e -> throwError e
setNtfServers' :: AgentMonad m => AgentClient -> [NtfServer] -> m ()
setNtfServers' c servers = do
@ -678,7 +696,7 @@ processSMPTransmission c@AgentClient {smpClients, subQ} (srv, sessId, rId, cmd)
_ -> prohibited >> ack
_ -> prohibited >> ack
SMP.END ->
atomically (TM.lookup srv smpClients >>= fmap join . mapM tryReadTMVar >>= processEND)
atomically (TM.lookup srv smpClients $>>= tryReadTMVar >>= processEND)
>>= logServer "<--" c srv rId
where
processEND = \case

View File

@ -1,3 +1,4 @@
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DuplicateRecordFields #-}
@ -8,6 +9,7 @@
{-# LANGUAGE OverloadedLists #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
module Simplex.Messaging.Agent.Client
( AgentClient (..),
@ -66,7 +68,7 @@ import qualified Simplex.Messaging.Crypto as C
import Simplex.Messaging.Encoding
import Simplex.Messaging.Notifications.Client
import Simplex.Messaging.Notifications.Protocol
import Simplex.Messaging.Protocol (BrokerMsg, ProtocolServer (..), QueueId, QueueIdsKeys (..), SndPublicVerifyKey)
import Simplex.Messaging.Protocol (BrokerMsg, ErrorType, ProtocolServer (..), QueueId, QueueIdsKeys (..), SndPublicVerifyKey)
import qualified Simplex.Messaging.Protocol as SMP
import Simplex.Messaging.TMap (TMap)
import qualified Simplex.Messaging.TMap as TM
@ -132,10 +134,15 @@ type AgentMonad m = (MonadUnliftIO m, MonadReader Env m, MonadError AgentErrorTy
class ProtocolServerClient msg where
getProtocolServerClient :: AgentMonad m => AgentClient -> ProtocolServer -> m (ProtocolClient msg)
protocolError :: ErrorType -> AgentErrorType
instance ProtocolServerClient BrokerMsg where getProtocolServerClient = getSMPServerClient
instance ProtocolServerClient BrokerMsg where
getProtocolServerClient = getSMPServerClient
protocolError = SMP
instance ProtocolServerClient NtfResponse where getProtocolServerClient = getNtfServerClient
instance ProtocolServerClient NtfResponse where
getProtocolServerClient = getNtfServerClient
protocolError = NTF
getSMPServerClient :: forall m. AgentMonad m => AgentClient -> SMPServer -> m SMPClient
getSMPServerClient c@AgentClient {smpClients, msgQ} srv =
@ -148,7 +155,7 @@ getSMPServerClient c@AgentClient {smpClients, msgQ} srv =
connectClient = do
cfg <- asks $ smpCfg . config
u <- askUnliftIO
liftEitherError protocolClientError (getProtocolClient srv cfg (Just msgQ) $ clientDisconnected u)
liftEitherError (protocolClientError SMP) (getProtocolClient srv cfg (Just msgQ) $ clientDisconnected u)
clientDisconnected :: UnliftIO m -> IO ()
clientDisconnected u = do
@ -207,7 +214,7 @@ getSMPServerClient c@AgentClient {smpClients, msgQ} srv =
e@PCEResponseTimeout -> throwError e
e@PCENetworkError -> throwError e
e -> do
liftIO $ notifySub (ERR $ protocolClientError e) connId
liftIO $ notifySub (ERR $ protocolClientError SMP e) connId
atomically $ removePendingSubscription c srv connId
notifySub :: ACommand 'Agent -> ConnId -> IO ()
@ -223,7 +230,7 @@ getNtfServerClient c@AgentClient {ntfClients} srv =
connectClient :: m NtfClient
connectClient = do
cfg <- asks $ ntfCfg . config
liftEitherError protocolClientError (getProtocolClient srv cfg Nothing clientDisconnected)
liftEitherError (protocolClientError NTF) (getProtocolClient srv cfg Nothing clientDisconnected)
clientDisconnected :: IO ()
clientDisconnected = do
@ -322,18 +329,18 @@ withLogClient_ c srv qId cmdStr action = do
logServer "<--" c srv qId "OK"
return res
withClient :: (AgentMonad m, ProtocolServerClient msg) => AgentClient -> ProtocolServer -> (ProtocolClient msg -> ExceptT ProtocolClientError IO a) -> m a
withClient c srv action = withClient_ c srv $ liftClient . action
withClient :: forall m msg a. (AgentMonad m, ProtocolServerClient msg) => AgentClient -> ProtocolServer -> (ProtocolClient msg -> ExceptT ProtocolClientError IO a) -> m a
withClient c srv action = withClient_ c srv $ liftClient (protocolError @msg) . action
withLogClient :: (AgentMonad m, ProtocolServerClient msg) => AgentClient -> ProtocolServer -> QueueId -> ByteString -> (ProtocolClient msg -> ExceptT ProtocolClientError IO a) -> m a
withLogClient c srv qId cmdStr action = withLogClient_ c srv qId cmdStr $ liftClient . action
withLogClient :: forall m msg a. (AgentMonad m, ProtocolServerClient msg) => AgentClient -> ProtocolServer -> QueueId -> ByteString -> (ProtocolClient msg -> ExceptT ProtocolClientError IO a) -> m a
withLogClient c srv qId cmdStr action = withLogClient_ c srv qId cmdStr $ liftClient (protocolError @msg) . action
liftClient :: AgentMonad m => ExceptT ProtocolClientError IO a -> m a
liftClient = liftError protocolClientError
liftClient :: AgentMonad m => (ErrorType -> AgentErrorType) -> ExceptT ProtocolClientError IO a -> m a
liftClient = liftError . protocolClientError
protocolClientError :: ProtocolClientError -> AgentErrorType
protocolClientError = \case
PCEProtocolError e -> SMP e
protocolClientError :: (ErrorType -> AgentErrorType) -> ProtocolClientError -> AgentErrorType
protocolClientError protocolError_ = \case
PCEProtocolError e -> protocolError_ e
PCEResponseError e -> BROKER $ RESPONSE e
PCEUnexpectedResponse -> BROKER UNEXPECTED
PCEResponseTimeout -> BROKER TIMEOUT
@ -428,14 +435,14 @@ sendConfirmation c sq@SndQueue {server, sndId, sndPublicKey = Just sndPublicKey,
withLogClient_ c server sndId "SEND <CONF>" $ \smp -> do
let clientMsg = SMP.ClientMessage (SMP.PHConfirmation sndPublicKey) agentConfirmation
msg <- agentCbEncrypt sq e2ePubKey $ smpEncode clientMsg
liftClient $ sendSMPMessage smp Nothing sndId msg
liftClient SMP $ sendSMPMessage smp Nothing sndId msg
sendConfirmation _ _ _ = throwError $ INTERNAL "sendConfirmation called without snd_queue public key(s) in the database"
sendInvitation :: forall m. AgentMonad m => AgentClient -> Compatible SMPQueueInfo -> ConnectionRequestUri 'CMInvitation -> ConnInfo -> m ()
sendInvitation c (Compatible SMPQueueInfo {smpServer, senderId, dhPublicKey}) connReq connInfo =
withLogClient_ c smpServer senderId "SEND <INV>" $ \smp -> do
msg <- mkInvitation
liftClient $ sendSMPMessage smp Nothing senderId msg
liftClient SMP $ sendSMPMessage smp Nothing senderId msg
where
mkInvitation :: m ByteString
-- this is only encrypted with per-queue E2E, not with double ratchet
@ -469,7 +476,7 @@ sendAgentMessage c sq@SndQueue {server, sndId, sndPrivateKey} agentMsg =
withLogClient_ c server sndId "SEND <MSG>" $ \smp -> do
let clientMsg = SMP.ClientMessage SMP.PHEmpty agentMsg
msg <- agentCbEncrypt sq Nothing $ smpEncode clientMsg
liftClient $ sendSMPMessage smp (Just sndPrivateKey) sndId msg
liftClient SMP $ sendSMPMessage smp (Just sndPrivateKey) sndId msg
agentNtfRegisterToken :: AgentMonad m => AgentClient -> NtfToken -> C.APublicVerifyKey -> C.PublicKeyX25519 -> m (NtfTokenId, C.PublicKeyX25519)
agentNtfRegisterToken c NtfToken {deviceToken, ntfServer, ntfPrivKey} ntfPubKey pubDhKey =

View File

@ -686,6 +686,8 @@ data AgentErrorType
CONN {connErr :: ConnectionErrorType}
| -- | SMP protocol errors forwarded to agent clients
SMP {smpErr :: ErrorType}
| -- | NTF protocol errors forwarded to agent clients
NTF {ntfErr :: ErrorType}
| -- | SMP server errors
BROKER {brokerErr :: BrokerErrorType}
| -- | errors of other agents
@ -774,6 +776,7 @@ instance StrEncoding AgentErrorType where
"CMD " *> (CMD <$> parseRead1)
<|> "CONN " *> (CONN <$> parseRead1)
<|> "SMP " *> (SMP <$> strP)
<|> "NTF " *> (NTF <$> strP)
<|> "BROKER RESPONSE " *> (BROKER . RESPONSE <$> strP)
<|> "BROKER TRANSPORT " *> (BROKER . TRANSPORT <$> transportErrorP)
<|> "BROKER " *> (BROKER <$> parseRead1)
@ -783,6 +786,7 @@ instance StrEncoding AgentErrorType where
CMD e -> "CMD " <> bshow e
CONN e -> "CONN " <> bshow e
SMP e -> "SMP " <> strEncode e
NTF e -> "NTF " <> strEncode e
BROKER (RESPONSE e) -> "BROKER RESPONSE " <> strEncode e
BROKER (TRANSPORT e) -> "BROKER TRANSPORT " <> serializeTransportError e
BROKER e -> "BROKER " <> bshow e

View File

@ -28,7 +28,7 @@ import qualified Simplex.Messaging.Crypto as C
import Simplex.Messaging.Protocol (BrokerMsg, ProtocolServer (..), QueueId, SMPServer)
import Simplex.Messaging.TMap (TMap)
import qualified Simplex.Messaging.TMap as TM
import Simplex.Messaging.Util (tryE, whenM)
import Simplex.Messaging.Util (tryE, whenM, ($>>=))
import System.Timeout (timeout)
import UnliftIO (async, forConcurrently_)
import UnliftIO.Exception (Exception)
@ -295,7 +295,7 @@ removeSub_ :: TMap SMPServer (TMap SMPSub C.APrivateSignKey) -> SMPServer -> SMP
removeSub_ subs srv s = TM.lookup srv subs >>= mapM_ (TM.delete s)
getSubKey :: TMap SMPServer (TMap SMPSub C.APrivateSignKey) -> SMPServer -> SMPSub -> STM (Maybe C.APrivateSignKey)
getSubKey subs srv s = fmap join . mapM (TM.lookup s) =<< TM.lookup srv subs
getSubKey subs srv s = TM.lookup srv subs $>>= TM.lookup s
hasSub :: TMap SMPServer (TMap SMPSub C.APrivateSignKey) -> SMPServer -> SMPSub -> STM Bool
hasSub subs srv s = maybe (pure False) (TM.member s) =<< TM.lookup srv subs

View File

@ -264,6 +264,7 @@ data NtfResponse
| NRTkn NtfTknStatus
| NRSub NtfSubStatus
| NRPong
deriving (Show)
instance ProtocolEncoding NtfResponse where
type Tag NtfResponse = NtfResponseTag
@ -361,7 +362,7 @@ data NtfSubStatus
NSEnd
| -- | SMP AUTH error
NSSMPAuth
deriving (Eq)
deriving (Eq, Show)
instance Encoding NtfSubStatus where
smpEncode = \case

View File

@ -117,7 +117,7 @@ ntfPush s@NtfPushServer {pushQ} = liftIO . forever . runExceptT $ do
(_, PNVerification _) -> do
-- TODO check token status
deliverNotification pp tkn ntf
atomically $ writeTVar tknStatus NTConfirmed
atomically $ modifyTVar tknStatus $ \status' -> if status' == NTActive then NTActive else NTConfirmed
(NTActive, PNCheckMessages) -> do
deliverNotification pp tkn ntf
_ -> do
@ -166,26 +166,26 @@ verifyNtfTransmission ::
verifyNtfTransmission (sig_, signed, (corrId, entId, _)) cmd = do
st <- asks store
case cmd of
NtfCmd SToken c@(TNEW n@(NewNtfTkn _ k _)) -> do
r_ <- atomically $ getNtfToken st entId
NtfCmd SToken c@(TNEW tkn@(NewNtfTkn _ k _)) -> do
r_ <- atomically $ getNtfTokenRegistration st tkn
pure $
if verifyCmdSignature sig_ signed k
then case r_ of
Just r@(NtfTkn NtfTknData {tknVerifyKey})
| k == tknVerifyKey -> tknCmd r c
Just t@NtfTknData {tknVerifyKey}
| k == tknVerifyKey -> verifiedTknCmd t c
| otherwise -> VRFailed
_ -> VRVerified (NtfReqNew corrId (ANE SToken n))
_ -> VRVerified (NtfReqNew corrId (ANE SToken tkn))
else VRFailed
NtfCmd SToken c -> do
r_ <- atomically $ getNtfToken st entId
pure $ case r_ of
Just r@(NtfTkn NtfTknData {tknVerifyKey})
| verifyCmdSignature sig_ signed tknVerifyKey -> tknCmd r c
t_ <- atomically $ getNtfToken st entId
pure $ case t_ of
Just t@NtfTknData {tknVerifyKey}
| verifyCmdSignature sig_ signed tknVerifyKey -> verifiedTknCmd t c
| otherwise -> VRFailed
_ -> maybe False (dummyVerifyCmd signed) sig_ `seq` VRFailed
_ -> pure VRFailed
where
tknCmd r c = VRVerified (NtfReqCmd SToken r (corrId, entId, c))
verifiedTknCmd t c = VRVerified (NtfReqCmd SToken (NtfTkn t) (corrId, entId, c))
client :: forall m. (MonadUnliftIO m, MonadReader NtfEnv m) => NtfServerClient -> NtfSubscriber -> NtfPushServer -> m ()
client NtfServerClient {rcvQ, sndQ} NtfSubscriber {subQ = _} NtfPushServer {pushQ, intervalNotifiers} =
@ -204,11 +204,11 @@ client NtfServerClient {rcvQ, sndQ} NtfSubscriber {subQ = _} NtfPushServer {push
tknId <- getId
regCode <- getRegCode
atomically $ do
tkn <- mkNtfTknData newTkn ks dhSecret regCode
tkn <- mkNtfTknData tknId newTkn ks dhSecret regCode
addNtfToken st tknId tkn
writeTBQueue pushQ (tkn, PNVerification regCode)
pure (corrId, "", NRId tknId srvDhPubKey)
NtfReqCmd SToken (NtfTkn tkn@NtfTknData {tknStatus, tknRegCode, tknDhSecret, tknDhKeys = (srvDhPubKey, srvDhPrivKey)}) (corrId, tknId, cmd) -> do
NtfReqCmd SToken (NtfTkn tkn@NtfTknData {ntfTknId, tknStatus, tknRegCode, tknDhSecret, tknDhKeys = (srvDhPubKey, srvDhPrivKey)}) (corrId, tknId, cmd) -> do
status <- readTVarIO tknStatus
(corrId,tknId,) <$> case cmd of
TNEW (NewNtfTkn _ _ dhPubKey) -> do
@ -218,12 +218,15 @@ client NtfServerClient {rcvQ, sndQ} NtfSubscriber {subQ = _} NtfPushServer {push
if tknDhSecret == dhSecret
then do
atomically $ writeTBQueue pushQ (tkn, PNVerification tknRegCode)
pure $ NRId tknId srvDhPubKey
pure $ NRId ntfTknId srvDhPubKey
else pure $ NRErr AUTH
TVFY code -- this allows repeated verification for cases when client connection dropped before server response
| (status == NTRegistered || status == NTConfirmed || status == NTActive) && tknRegCode == code -> do
logDebug "TVFY - token verified"
st <- asks store
atomically $ writeTVar tknStatus NTActive
tIds <- atomically $ removeInactiveTokenRegistrations st tkn
forM_ tIds cancelInvervalNotifications
pure NROk
| otherwise -> do
logDebug "TVFY - incorrect code or token status"
@ -233,12 +236,12 @@ client NtfServerClient {rcvQ, sndQ} NtfSubscriber {subQ = _} NtfPushServer {push
logDebug "TDEL"
st <- asks store
atomically $ deleteNtfToken st tknId
cancelInvervalNotifications tknId
pure NROk
TCRN 0 ->
TCRN 0 -> do
logDebug "TCRN 0"
>> atomically (TM.lookupDelete tknId intervalNotifiers)
>>= mapM_ (uninterruptibleCancel . action)
>> pure NROk
cancelInvervalNotifications tknId
pure NROk
TCRN int
| int < 20 -> pure $ NRErr QUOTA
| otherwise -> do
@ -274,6 +277,10 @@ client NtfServerClient {rcvQ, sndQ} NtfSubscriber {subQ = _} NtfPushServer {push
getRandomBytes n = do
gVar <- asks idsDrg
atomically (C.pseudoRandomBytes n gVar)
cancelInvervalNotifications :: NtfTokenId -> m ()
cancelInvervalNotifications tknId =
atomically (TM.lookupDelete tknId intervalNotifiers)
>>= mapM_ (uninterruptibleCancel . action)
-- NReqCreate corrId tokenId smpQueue -> pure (corrId, "", NROk)
-- do

View File

@ -2,30 +2,38 @@
{-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-}
module Simplex.Messaging.Notifications.Server.Store where
import Control.Concurrent.STM
import Control.Monad
import Data.ByteString.Char8 (ByteString)
import qualified Data.Map.Strict as M
import qualified Simplex.Messaging.Crypto as C
import Simplex.Messaging.Notifications.Protocol
import Simplex.Messaging.TMap (TMap)
import qualified Simplex.Messaging.TMap as TM
import Simplex.Messaging.Util ((<$$>))
import Simplex.Messaging.Util (whenM, ($>>=))
data NtfStore = NtfStore
{ tokens :: TMap NtfTokenId NtfTknData,
tokenIds :: TMap DeviceToken NtfTokenId
tokenRegistrations :: TMap DeviceToken (TMap ByteString NtfTokenId)
}
newNtfStore :: STM NtfStore
newNtfStore = do
tokens <- TM.empty
tokenIds <- TM.empty
pure NtfStore {tokens, tokenIds}
tokenRegistrations <- TM.empty
pure NtfStore {tokens, tokenRegistrations}
data NtfTknData = NtfTknData
{ token :: DeviceToken,
{ ntfTknId :: NtfTokenId,
token :: DeviceToken,
tknStatus :: TVar NtfTknStatus,
tknVerifyKey :: C.APublicVerifyKey,
tknDhKeys :: C.KeyPair 'C.X25519,
@ -33,10 +41,10 @@ data NtfTknData = NtfTknData
tknRegCode :: NtfRegCode
}
mkNtfTknData :: NewNtfEntity 'Token -> C.KeyPair 'C.X25519 -> C.DhSecretX25519 -> NtfRegCode -> STM NtfTknData
mkNtfTknData (NewNtfTkn token tknVerifyKey _) tknDhKeys tknDhSecret tknRegCode = do
mkNtfTknData :: NtfTokenId -> NewNtfEntity 'Token -> C.KeyPair 'C.X25519 -> C.DhSecretX25519 -> NtfRegCode -> STM NtfTknData
mkNtfTknData ntfTknId (NewNtfTkn token tknVerifyKey _) tknDhKeys tknDhSecret tknRegCode = do
tknStatus <- newTVar NTRegistered
pure NtfTknData {token, tknStatus, tknVerifyKey, tknDhKeys, tknDhSecret, tknRegCode}
pure NtfTknData {ntfTknId, token, tknStatus, tknVerifyKey, tknDhKeys, tknDhSecret, tknRegCode}
-- data NtfSubscriptionsStore = NtfSubscriptionsStore
@ -58,19 +66,57 @@ data NtfEntityRec (e :: NtfEntity) where
NtfTkn :: NtfTknData -> NtfEntityRec 'Token
NtfSub :: NtfSubData -> NtfEntityRec 'Subscription
data ANtfEntityRec = forall e. NtfEntityI e => NER (SNtfEntity e) (NtfEntityRec e)
getNtfToken :: NtfStore -> NtfTokenId -> STM (Maybe (NtfEntityRec 'Token))
getNtfToken st tknId = NtfTkn <$$> TM.lookup tknId (tokens st)
getNtfToken :: NtfStore -> NtfTokenId -> STM (Maybe NtfTknData)
getNtfToken st tknId = TM.lookup tknId (tokens st)
addNtfToken :: NtfStore -> NtfTokenId -> NtfTknData -> STM ()
addNtfToken st tknId tkn@NtfTknData {token} = do
addNtfToken st tknId tkn@NtfTknData {token, tknVerifyKey} = do
TM.insert tknId tkn $ tokens st
TM.insert token tknId $ tokenIds st
TM.lookup token regs >>= \case
Just tIds -> TM.insert regKey tknId tIds
_ -> do
tIds <- TM.singleton regKey tknId
TM.insert token tIds regs
where
regs = tokenRegistrations st
regKey = C.toPubKey C.pubKeyBytes tknVerifyKey
getNtfTokenRegistration :: NtfStore -> NewNtfEntity 'Token -> STM (Maybe NtfTknData)
getNtfTokenRegistration st (NewNtfTkn token tknVerifyKey _) =
TM.lookup token (tokenRegistrations st)
$>>= TM.lookup regKey
$>>= (`TM.lookup` tokens st)
where
regKey = C.toPubKey C.pubKeyBytes tknVerifyKey
removeInactiveTokenRegistrations :: NtfStore -> NtfTknData -> STM [NtfTokenId]
removeInactiveTokenRegistrations st NtfTknData {ntfTknId = tId, token} =
TM.lookup token (tokenRegistrations st)
>>= maybe (pure []) removeRegs
where
removeRegs :: TMap ByteString NtfTokenId -> STM [NtfTokenId]
removeRegs tknRegs = do
tIds <- filter ((/= tId) . snd) . M.assocs <$> readTVar tknRegs
forM_ tIds $ \(regKey, tId') -> do
TM.delete regKey tknRegs
TM.delete tId' $ tokens st
pure $ map snd tIds
deleteNtfToken :: NtfStore -> NtfTokenId -> STM ()
deleteNtfToken st tknId = do
TM.lookupDelete tknId (tokens st) >>= mapM_ (\NtfTknData {token} -> TM.delete token $ tokenIds st)
TM.lookupDelete tknId (tokens st)
>>= mapM_
( \NtfTknData {token, tknVerifyKey} ->
TM.lookup token regs
>>= mapM_
( \tIds -> do
TM.delete (regKey tknVerifyKey) tIds
whenM (TM.null tIds) $ TM.delete token regs
)
)
where
regs = tokenRegistrations st
regKey = C.toPubKey C.pubKeyBytes
-- getNtfRec :: NtfStore -> SNtfEntity e -> NtfEntityId -> STM (Maybe (NtfEntityRec e))
-- getNtfRec st ent entId = case ent of

View File

@ -98,7 +98,7 @@ smpServer started = do
m ()
serverThread s subQ subs clientSubs unsub = forever $ do
atomically updateSubscribers
>>= fmap join . mapM endPreviousSubscriptions
$>>= endPreviousSubscriptions
>>= mapM_ unsub
where
updateSubscribers :: STM (Maybe (QueueId, Client))
@ -110,8 +110,7 @@ smpServer started = do
else do
yes <- readTVar $ connected c'
pure $ if yes then Just (qId, c') else Nothing
TM.lookupInsert qId clnt (subs s)
>>= fmap join . mapM clientToBeNotified
TM.lookupInsert qId clnt (subs s) $>>= clientToBeNotified
endPreviousSubscriptions :: (QueueId, Client) -> m (Maybe s)
endPreviousSubscriptions (qId, c) = do
void . forkIO . atomically $

View File

@ -18,7 +18,7 @@ import Simplex.Messaging.Protocol
import Simplex.Messaging.Server.QueueStore
import Simplex.Messaging.TMap (TMap)
import qualified Simplex.Messaging.TMap as TM
import Simplex.Messaging.Util (ifM)
import Simplex.Messaging.Util (ifM, ($>>=))
import UnliftIO.STM
data QueueStore = QueueStore
@ -51,9 +51,8 @@ instance MonadQueueStore QueueStore STM where
where
getVar = case party of
SRecipient -> TM.lookup qId queues
SSender -> TM.lookup qId senders >>= get
SNotifier -> TM.lookup qId notifiers >>= get
get = fmap join . mapM (`TM.lookup` queues)
SSender -> TM.lookup qId senders $>>= (`TM.lookup` queues)
SNotifier -> TM.lookup qId notifiers $>>= (`TM.lookup` queues)
secureQueue :: QueueStore -> RecipientId -> SndPublicVerifyKey -> STM (Either ErrorType QueueRec)
secureQueue QueueStore {queues} rId sKey =
@ -91,4 +90,4 @@ toResult :: Maybe a -> Either ErrorType a
toResult = maybe (Left AUTH) Right
withQueue :: RecipientId -> TMap RecipientId (TVar QueueRec) -> (TVar QueueRec -> STM (Maybe a)) -> STM (Either ErrorType a)
withQueue rId queues f = toResult <$> (TM.lookup rId queues >>= fmap join . mapM f)
withQueue rId queues f = toResult <$> TM.lookup rId queues $>>= f

View File

@ -2,6 +2,7 @@ module Simplex.Messaging.TMap
( TMap,
empty,
singleton,
Simplex.Messaging.TMap.null,
Simplex.Messaging.TMap.lookup,
member,
insert,
@ -30,6 +31,10 @@ singleton :: k -> a -> STM (TMap k a)
singleton k v = newTVar $ M.singleton k v
{-# INLINE singleton #-}
null :: TMap k a -> STM Bool
null m = M.null <$> readTVar m
{-# INLINE null #-}
lookup :: Ord k => k -> TMap k a -> STM (Maybe a)
lookup k m = M.lookup k <$> readTVar m
{-# INLINE lookup #-}

View File

@ -65,3 +65,6 @@ whenM b a = ifM b a $ pure ()
unlessM :: Monad m => m Bool -> m () -> m ()
unlessM b = ifM b $ pure ()
{-# INLINE unlessM #-}
($>>=) :: (Monad m, Monad f, Traversable f) => m (f a) -> (a -> m (f b)) -> m (f b)
f $>>= g = f >>= fmap join . mapM g

View File

@ -3,6 +3,9 @@
module AgentTests.NotificationTests where
-- import Control.Logger.Simple (LogConfig (..), LogLevel (..), setLogLevel, withGlobalLogging)
import Control.Concurrent (threadDelay)
import Control.Monad.Except
import qualified Data.Aeson as J
import qualified Data.Aeson.Types as JT
@ -11,21 +14,36 @@ import qualified Data.ByteString.Base64.URL as U
import Data.ByteString.Char8 (ByteString)
import Data.Text.Encoding (encodeUtf8)
import NtfClient
import SMPAgentClient (agentCfg, initAgentServers)
import SMPAgentClient (agentCfg, initAgentServers, testDB, testDB2)
import Simplex.Messaging.Agent
import Simplex.Messaging.Agent.Env.SQLite (AgentConfig (..))
import Simplex.Messaging.Agent.Protocol
import qualified Simplex.Messaging.Crypto as C
import Simplex.Messaging.Notifications.Protocol
import Simplex.Messaging.Notifications.Server.Push.APNS
import Simplex.Messaging.Protocol (ErrorType (AUTH))
import Simplex.Messaging.Transport (ATransport)
import Simplex.Messaging.Util (tryE)
import System.Directory (removeFile)
import Test.Hspec
import UnliftIO.STM
notificationTests :: ATransport -> Spec
notificationTests t = do
describe "Managing notification tokens" $
it "should register and verify notification token" $
withAPNSMockServer $ \apns -> withNtfServer t $ testNotificationToken apns
notificationTests t =
after_ (removeFile testDB) $
describe "Managing notification tokens" $ do
it "should register and verify notification token" $
withAPNSMockServer $ \apns ->
withNtfServer t $ testNotificationToken apns
it "should allow repeated registration with the same credentials" $ \_ ->
withAPNSMockServer $ \apns ->
withNtfServer t $ testNtfTokenRepeatRegistration apns
it "should allow the second registration with different credentials and delete the first after verification" $ \_ ->
withAPNSMockServer $ \apns ->
withNtfServer t $ testNtfTokenSecondRegistration apns
it "should re-register token when notification server is restarted" $ \_ ->
withAPNSMockServer $ \apns ->
testNtfTokenServerRestart t apns
testNotificationToken :: APNSMockServer -> IO ()
testNotificationToken APNSMockServer {apnsQ} = do
@ -40,10 +58,110 @@ testNotificationToken APNSMockServer {apnsQ} = do
liftIO $ sendApnsResponse APNSRespOk
verifyNtfToken a tkn verification nonce
enableNtfCron a tkn 30
NTActive <- checkNtfToken a tkn
deleteNtfToken a tkn
-- agent deleted this token
Left (CMD PROHIBITED) <- tryE $ checkNtfToken a tkn
pure ()
pure ()
where
(.->) :: J.Value -> J.Key -> ExceptT AgentErrorType IO ByteString
v .-> key = do
J.Object o <- pure v
liftEither . bimap INTERNAL (U.decodeLenient . encodeUtf8) $ JT.parseEither (J..: key) o
(.->) :: J.Value -> J.Key -> ExceptT AgentErrorType IO ByteString
v .-> key = do
J.Object o <- pure v
liftEither . bimap INTERNAL (U.decodeLenient . encodeUtf8) $ JT.parseEither (J..: key) o
-- logCfg :: LogConfig
-- logCfg = LogConfig {lc_file = Nothing, lc_stderr = True}
testNtfTokenRepeatRegistration :: APNSMockServer -> IO ()
testNtfTokenRepeatRegistration APNSMockServer {apnsQ} = do
-- setLogLevel LogError -- LogDebug
-- withGlobalLogging logCfg $ do
a <- getSMPAgentClient agentCfg initAgentServers
Right () <- runExceptT $ do
let tkn = DeviceToken PPApns "abcd"
registerNtfToken a tkn
APNSMockRequest {notification = APNSNotification {aps = APNSBackground _, notificationData = Just ntfData}, sendApnsResponse} <-
atomically $ readTBQueue apnsQ
verification <- ntfData .-> "verification"
nonce <- C.cbNonce <$> ntfData .-> "nonce"
liftIO $ sendApnsResponse APNSRespOk
registerNtfToken a tkn
APNSMockRequest {notification = APNSNotification {aps = APNSBackground _, notificationData = Just ntfData'}, sendApnsResponse = sendApnsResponse'} <-
atomically $ readTBQueue apnsQ
_ <- ntfData' .-> "verification"
_ <- C.cbNonce <$> ntfData' .-> "nonce"
liftIO $ sendApnsResponse' APNSRespOk
-- can still use the first verification code, it is the same after decryption
verifyNtfToken a tkn verification nonce
enableNtfCron a tkn 30
NTActive <- checkNtfToken a tkn
pure ()
pure ()
testNtfTokenSecondRegistration :: APNSMockServer -> IO ()
testNtfTokenSecondRegistration APNSMockServer {apnsQ} = do
-- setLogLevel LogError -- LogDebug
-- withGlobalLogging logCfg $ do
a <- getSMPAgentClient agentCfg initAgentServers
a' <- getSMPAgentClient agentCfg {dbFile = testDB2} initAgentServers
Right () <- runExceptT $ do
let tkn = DeviceToken PPApns "abcd"
registerNtfToken a tkn
APNSMockRequest {notification = APNSNotification {aps = APNSBackground _, notificationData = Just ntfData}, sendApnsResponse} <-
atomically $ readTBQueue apnsQ
verification <- ntfData .-> "verification"
nonce <- C.cbNonce <$> ntfData .-> "nonce"
liftIO $ sendApnsResponse APNSRespOk
verifyNtfToken a tkn verification nonce
registerNtfToken a' tkn
APNSMockRequest {notification = APNSNotification {aps = APNSBackground _, notificationData = Just ntfData'}, sendApnsResponse = sendApnsResponse'} <-
atomically $ readTBQueue apnsQ
verification' <- ntfData' .-> "verification"
nonce' <- C.cbNonce <$> ntfData' .-> "nonce"
liftIO $ sendApnsResponse' APNSRespOk
-- at this point the first token is still active
NTActive <- checkNtfToken a tkn
-- and the second is not yet verified
NTConfirmed <- checkNtfToken a' tkn
-- now the second token registration is verified
verifyNtfToken a' tkn verification' nonce'
-- the first registration is removed
Left (NTF AUTH) <- tryE $ checkNtfToken a tkn
-- and the second is active
NTActive <- checkNtfToken a' tkn
enableNtfCron a' tkn 30
pure ()
pure ()
testNtfTokenServerRestart :: ATransport -> APNSMockServer -> IO ()
testNtfTokenServerRestart t APNSMockServer {apnsQ} = do
a <- getSMPAgentClient agentCfg initAgentServers
let tkn = DeviceToken PPApns "abcd"
Right ntfData <- withNtfServer t . runExceptT $ do
registerNtfToken a tkn
APNSMockRequest {notification = APNSNotification {aps = APNSBackground _, notificationData = Just ntfData}, sendApnsResponse} <-
atomically $ readTBQueue apnsQ
liftIO $ sendApnsResponse APNSRespOk
pure ntfData
-- the new agent is created as otherwise when running the tests in CI the old agent was keeping the connection to the server
threadDelay 1000000
disconnectAgentClient a
a' <- getSMPAgentClient agentCfg initAgentServers
-- server stopped before token is verified, so now the attempt to verify it will return AUTH error but re-register token,
-- so that repeat verification happens without restarting the clients, when notification arrives
Right () <- withNtfServer t . runExceptT $ do
verification <- ntfData .-> "verification"
nonce <- C.cbNonce <$> ntfData .-> "nonce"
Left (NTF AUTH) <- tryE $ verifyNtfToken a' tkn verification nonce
APNSMockRequest {notification = APNSNotification {aps = APNSBackground _, notificationData = Just ntfData'}, sendApnsResponse = sendApnsResponse'} <-
atomically $ readTBQueue apnsQ
verification' <- ntfData' .-> "verification"
nonce' <- C.cbNonce <$> ntfData' .-> "nonce"
liftIO $ sendApnsResponse' APNSRespOk
verifyNtfToken a' tkn verification' nonce'
NTActive <- checkNtfToken a' tkn
enableNtfCron a' tkn 30
pure ()