Merge pull request #468 from simplex-chat/master
Merge master to stable
This commit is contained in:
commit
e75846aa38
|
@ -19,9 +19,9 @@ jobs:
|
|||
- name: Setup Stack
|
||||
uses: haskell/actions/setup@v1
|
||||
with:
|
||||
ghc-version: '8.10.7'
|
||||
ghc-version: "8.10.7"
|
||||
enable-stack: true
|
||||
stack-version: 'latest'
|
||||
stack-version: "latest"
|
||||
|
||||
- name: Cache dependencies
|
||||
uses: actions/cache@v2
|
||||
|
@ -36,6 +36,7 @@ jobs:
|
|||
stack build --test --force-dirty
|
||||
install_root=$(stack path --local-install-root)
|
||||
mv ${install_root}/bin/smp-server smp-server-ubuntu-20_04-x86-64
|
||||
mv ${install_root}/bin/ntf-server ntf-server-ubuntu-20_04-x86-64
|
||||
|
||||
- name: Build changelog
|
||||
if: startsWith(github.ref, 'refs/tags/v')
|
||||
|
@ -73,6 +74,7 @@ jobs:
|
|||
files: |
|
||||
LICENSE
|
||||
smp-server-ubuntu-20_04-x86-64
|
||||
ntf-server-ubuntu-20_04-x86-64
|
||||
fail_on_unmatched_files: true
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
|
75
CHANGELOG.md
75
CHANGELOG.md
|
@ -1,3 +1,78 @@
|
|||
# 3.0.0
|
||||
|
||||
SMP server:
|
||||
|
||||
- restore undeliverd messages when the server is restarted.
|
||||
- SMP protocol v3 to support push notification:
|
||||
- updated SEND and MSG to add message flags (for notification flag that contros whether the notification is sent and for any future extensions) and to move message meta-data sent to the recipient into the encrypted envelope.
|
||||
- update NKEY and NID to add e2e encryption keys (for the notification meta-data encryption between SMP server and the client), and update NMSG to include this meta-data.
|
||||
- update ACK command to include message ID (to avoid acknowledging unprocessed message).
|
||||
- add NDEL commands to remove notification subscription credentials from SMP queue.
|
||||
- add GET command to receive messages without subscription - to be used in iOS notification service extension to receive messages without terminating app subscriptions.
|
||||
|
||||
SMP agent:
|
||||
|
||||
- new protocol for duplex connection handshake reducing traffic and connection time.
|
||||
- support for SMP notifications server and managing device token.
|
||||
- remove redundant FQDN validation from TLS handshake to prepare for access via Tor.
|
||||
- support for fully stopping agent and for termporary suspending agent operations.
|
||||
- improve management of duplicate message delivery.
|
||||
|
||||
SMP notifications server v1.0:
|
||||
|
||||
- SMP notifications protocol with version negotiation during handshake.
|
||||
- device token registration and verification (via background notification).
|
||||
- SMP notification subscriptions and push notifications via APNS.
|
||||
- restoring notification subscriptions when the server is restarted.
|
||||
|
||||
# 2.3.0
|
||||
|
||||
SMP server:
|
||||
|
||||
- Save and restore undelivered messages, to avoid losing them. To save messages the server has to be stopped with SIGINT signal, if it is stopped with SIGTERM undelivered messages would not be saved.
|
||||
|
||||
# 2.2.0
|
||||
|
||||
SMP server:
|
||||
|
||||
- Fix sockets/threads/memory leak
|
||||
|
||||
SMP agent:
|
||||
|
||||
- Support stopping and resuming agent with `disconnectAgentClient` / `resumeAgentClient`
|
||||
|
||||
# 2.1.1
|
||||
|
||||
SMP server:
|
||||
|
||||
- gracefully close sockets on client disconnection
|
||||
- CLI warning when deleting server configuration
|
||||
|
||||
# 2.1.0
|
||||
|
||||
SMP server:
|
||||
|
||||
- configuration to expire inactive clients in ini file, increased TTL and check interval for client expiration
|
||||
|
||||
# 2.0.0
|
||||
|
||||
Push notifications server (beta):
|
||||
|
||||
- supports APNS
|
||||
- manage device tokens verification via notification delivery
|
||||
- sending periodic background notification to check messages (not more frequent than every 20 min)
|
||||
|
||||
SMP server:
|
||||
|
||||
- disconnect inactive clients after some period
|
||||
- remove undelivered messages after 30 days
|
||||
- log aggregate usage daily stats: only the number of queues created/secured/deleted/used and messages sent/delivered is logged, as one line per day, so we can plan server capacity and diagnose any problems.
|
||||
|
||||
SMP agent:
|
||||
|
||||
- manage device tokens and notification server connection
|
||||
- DOWN/UP events to the agent user about server disconnections/reconnections are now sent once per server
|
||||
|
||||
# 1.1.0
|
||||
|
||||
SMP server:
|
||||
|
|
13
README.md
13
README.md
|
@ -35,6 +35,8 @@ SMP server uses in-memory persistence with an optional append-only log of create
|
|||
|
||||
To enable store log, initialize server using `smp-server -l` command, or modify `smp-server.ini` created during initialization (uncomment `enable: on` option in the store log section). Use `smp-server --help` for other usage tips.
|
||||
|
||||
Starting from version 2.3.0, when store log is enabled, the server would also enable saving undelivered messages on exit and restoring them on start. This can be disabled via a separate setting `restore_messages` in `smp-server.ini` file. Saving messages would only work if the server is stopped with SIGINT signal (keyboard interrupt), if it is stopped with SIGTERM signal the messages would not be saved.
|
||||
|
||||
> **Please note:** On initialization SMP server creates a chain of two certificates: a self-signed CA certificate ("offline") and a server certificate used for TLS handshake ("online"). **You should store CA certificate private key securely and delete it from the server. If server TLS credential is compromised this key can be used to sign a new one, keeping the same server identity and established connections.** CA private key location by default is `/etc/opt/simplex/ca.key`.
|
||||
|
||||
SMP server implements [SMP protocol](https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md).
|
||||
|
@ -61,6 +63,7 @@ Now `openssl version` should be saying "OpenSSL". You can now run `smp-server in
|
|||
### SMP client library
|
||||
|
||||
[SMP client](https://github.com/simplex-chat/simplexmq/blob/master/src/Simplex/Messaging/Client.hs) is a Haskell library to connect to SMP servers that allows to:
|
||||
|
||||
- execute commands with a functional API.
|
||||
- receive messages and other notifications via STM queue.
|
||||
- automatically send keep-alive commands.
|
||||
|
@ -118,11 +121,11 @@ Deployment on Linode is performed via StackScripts, which serve as recipes for L
|
|||
- Create a Linode account or login with an already existing one.
|
||||
- Open [SMP server StackScript](https://cloud.linode.com/stackscripts/748014) and click "Deploy New Linode".
|
||||
- You can optionally configure the following parameters:
|
||||
- SMP Server store log flag for queue persistence on server restart, recommended.
|
||||
- [Linode API token](https://www.linode.com/docs/guides/getting-started-with-the-linode-api#get-an-access-token) to attach server address etc. as tags to Linode and to add A record to your 2nd level domain (e.g. `example.com` [domain should be created](https://cloud.linode.com/domains/create) in your account prior to deployment). The API token access scopes:
|
||||
- read/write for "linodes"
|
||||
- read/write for "domains"
|
||||
- Domain name to use instead of Linode IP address, e.g. `smp1.example.com`.
|
||||
- SMP Server store log flag for queue persistence on server restart, recommended.
|
||||
- [Linode API token](https://www.linode.com/docs/guides/getting-started-with-the-linode-api#get-an-access-token) to attach server address etc. as tags to Linode and to add A record to your 2nd level domain (e.g. `example.com` [domain should be created](https://cloud.linode.com/domains/create) in your account prior to deployment). The API token access scopes:
|
||||
- read/write for "linodes"
|
||||
- read/write for "domains"
|
||||
- Domain name to use instead of Linode IP address, e.g. `smp1.example.com`.
|
||||
- Choose the region and plan, Shared CPU Nanode with 1Gb is sufficient.
|
||||
- Provide ssh key to be able to connect to your Linode via ssh. If you haven't provided a Linode API token this step is required to login to your Linode and get the server's fingerprint either from the welcome message or from the file `/etc/opt/simplex/fingerprint` after server starts. See [Linode's guide on ssh](https://www.linode.com/docs/guides/use-public-key-authentication-with-ssh/) .
|
||||
- Deploy your Linode. After it starts wait for SMP server to start and for tags to appear (if a Linode API token was provided). It may take up to 5 minutes depending on the connection speed on the Linode. Connecting Linode IP address to provided domain name may take some additional time.
|
||||
|
|
|
@ -0,0 +1,77 @@
|
|||
{-# LANGUAGE DuplicateRecordFields #-}
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
|
||||
module Main where
|
||||
|
||||
import Control.Logger.Simple
|
||||
import Simplex.Messaging.Client.Agent (defaultSMPClientAgentConfig)
|
||||
import Simplex.Messaging.Notifications.Server (runNtfServer)
|
||||
import Simplex.Messaging.Notifications.Server.Env (NtfServerConfig (..))
|
||||
import Simplex.Messaging.Notifications.Server.Push.APNS (defaultAPNSPushClientConfig)
|
||||
import Simplex.Messaging.Server.CLI (ServerCLIConfig (..), protocolServerCLI)
|
||||
import System.FilePath (combine)
|
||||
|
||||
cfgPath :: FilePath
|
||||
cfgPath = "/etc/opt/simplex-notifications"
|
||||
|
||||
logPath :: FilePath
|
||||
logPath = "/var/opt/simplex-notifications"
|
||||
|
||||
logCfg :: LogConfig
|
||||
logCfg = LogConfig {lc_file = Nothing, lc_stderr = True}
|
||||
|
||||
main :: IO ()
|
||||
main = do
|
||||
setLogLevel LogDebug -- change to LogError in production
|
||||
withGlobalLogging logCfg $ protocolServerCLI ntfServerCLIConfig runNtfServer
|
||||
|
||||
ntfServerCLIConfig :: ServerCLIConfig NtfServerConfig
|
||||
ntfServerCLIConfig =
|
||||
let caCrtFile = combine cfgPath "ca.crt"
|
||||
serverKeyFile = combine cfgPath "server.key"
|
||||
serverCrtFile = combine cfgPath "server.crt"
|
||||
in ServerCLIConfig
|
||||
{ cfgDir = cfgPath,
|
||||
logDir = logPath,
|
||||
iniFile = combine cfgPath "ntf-server.ini",
|
||||
storeLogFile = combine logPath "ntf-server-store.log",
|
||||
caKeyFile = combine cfgPath "ca.key",
|
||||
caCrtFile,
|
||||
serverKeyFile,
|
||||
serverCrtFile,
|
||||
fingerprintFile = combine cfgPath "fingerprint",
|
||||
defaultServerPort = "443",
|
||||
executableName = "ntf-server",
|
||||
serverVersion = "SMP notifications server v1.0.0",
|
||||
mkIniFile = \enableStoreLog defaultServerPort ->
|
||||
"[STORE_LOG]\n\
|
||||
\# The server uses STM memory for persistence,\n\
|
||||
\# that will be lost on restart (e.g., as with redis).\n\
|
||||
\# This option enables saving memory to append only log,\n\
|
||||
\# and restoring it when the server is started.\n\
|
||||
\# Log is compacted on start (deleted objects are removed).\n\
|
||||
\# The messages are not logged.\n"
|
||||
<> ("enable: " <> (if enableStoreLog then "on" else "off") <> "\n\n")
|
||||
<> "[TRANSPORT]\n\
|
||||
\port: "
|
||||
<> defaultServerPort
|
||||
<> "\n\
|
||||
\websockets: off\n",
|
||||
mkServerConfig = \storeLogFile transports _ ->
|
||||
NtfServerConfig
|
||||
{ transports,
|
||||
subIdBytes = 24,
|
||||
regCodeBytes = 32,
|
||||
clientQSize = 16,
|
||||
subQSize = 64,
|
||||
pushQSize = 128,
|
||||
smpAgentCfg = defaultSMPClientAgentConfig,
|
||||
apnsConfig = defaultAPNSPushClientConfig,
|
||||
inactiveClientExpiration = Nothing,
|
||||
storeLogFile,
|
||||
resubscribeDelay = 50000, -- 50ms
|
||||
caCertificateFile = caCrtFile,
|
||||
privateKeyFile = serverKeyFile,
|
||||
certificateFile = serverCrtFile
|
||||
}
|
||||
}
|
|
@ -11,7 +11,14 @@ import Simplex.Messaging.Agent.Server (runSMPAgent)
|
|||
import Simplex.Messaging.Transport (TLS, Transport (..))
|
||||
|
||||
cfg :: AgentConfig
|
||||
cfg = defaultAgentConfig {initialSMPServers = L.fromList ["smp://bU0K-bRg24xWW__lS0umO1Zdw_SXqpJNtm1_RrPLViE=@localhost:5223"]}
|
||||
cfg = defaultAgentConfig
|
||||
|
||||
servers :: InitialAgentServers
|
||||
servers =
|
||||
InitialAgentServers
|
||||
{ smp = L.fromList ["smp://bU0K-bRg24xWW__lS0umO1Zdw_SXqpJNtm1_RrPLViE=@localhost:5223"],
|
||||
ntf = []
|
||||
}
|
||||
|
||||
logCfg :: LogConfig
|
||||
logCfg = LogConfig {lc_file = Nothing, lc_stderr = True}
|
||||
|
@ -20,4 +27,4 @@ main :: IO ()
|
|||
main = do
|
||||
putStrLn $ "SMP agent listening on port " ++ tcpPort (cfg :: AgentConfig)
|
||||
setLogLevel LogInfo -- LogError
|
||||
withGlobalLogging logCfg $ runSMPAgent (transport @TLS) cfg
|
||||
withGlobalLogging logCfg $ runSMPAgent (transport @TLS) cfg servers
|
||||
|
|
|
@ -1,7 +1,4 @@
|
|||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE DuplicateRecordFields #-}
|
||||
{-# LANGUAGE GADTs #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
{-# LANGUAGE NumericUnderscores #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
|
@ -9,314 +6,104 @@
|
|||
|
||||
module Main where
|
||||
|
||||
import Control.Monad.Except
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import Data.Either (fromRight)
|
||||
import Data.Ini (Ini, lookupValue, readIniFile)
|
||||
import Data.Maybe (fromMaybe)
|
||||
import qualified Data.Text as T
|
||||
import Data.X509.Validation (Fingerprint (..))
|
||||
import Network.Socket (HostName, ServiceName)
|
||||
import Options.Applicative
|
||||
import Simplex.Messaging.Encoding.String
|
||||
import Control.Logger.Simple
|
||||
import Data.Functor (($>))
|
||||
import Data.Ini (lookupValue)
|
||||
import Simplex.Messaging.Server (runSMPServer)
|
||||
import Simplex.Messaging.Server.Env.STM
|
||||
import Simplex.Messaging.Server.StoreLog (StoreLog, openReadStoreLog, storeLogFilePath)
|
||||
import Simplex.Messaging.Transport (ATransport (..), TLS, Transport (..), simplexMQVersion)
|
||||
import Simplex.Messaging.Transport.Server (loadFingerprint)
|
||||
import Simplex.Messaging.Transport.WebSockets (WS)
|
||||
import System.Directory (createDirectoryIfMissing, doesDirectoryExist, doesFileExist, removeDirectoryRecursive)
|
||||
import System.Exit (exitFailure)
|
||||
import Simplex.Messaging.Server.CLI (ServerCLIConfig (..), protocolServerCLI, readStrictIni)
|
||||
import Simplex.Messaging.Server.Env.STM (ServerConfig (..), defaultInactiveClientExpiration, defaultMessageExpiration)
|
||||
import Simplex.Messaging.Server.Expiration
|
||||
import Simplex.Messaging.Transport (simplexMQVersion, supportedSMPServerVRange)
|
||||
import System.FilePath (combine)
|
||||
import System.IO (BufferMode (..), IOMode (..), hGetLine, hSetBuffering, stderr, stdout, withFile)
|
||||
import System.Process (readCreateProcess, shell)
|
||||
import Text.Read (readMaybe)
|
||||
|
||||
cfgDir :: FilePath
|
||||
cfgDir = "/etc/opt/simplex"
|
||||
cfgPath :: FilePath
|
||||
cfgPath = "/etc/opt/simplex"
|
||||
|
||||
logDir :: FilePath
|
||||
logDir = "/var/opt/simplex"
|
||||
logPath :: FilePath
|
||||
logPath = "/var/opt/simplex"
|
||||
|
||||
iniFile :: FilePath
|
||||
iniFile = combine cfgDir "smp-server.ini"
|
||||
|
||||
storeLogFile :: FilePath
|
||||
storeLogFile = combine logDir "smp-server-store.log"
|
||||
|
||||
caKeyFile :: FilePath
|
||||
caKeyFile = combine cfgDir "ca.key"
|
||||
|
||||
caCrtFile :: FilePath
|
||||
caCrtFile = combine cfgDir "ca.crt"
|
||||
|
||||
serverKeyFile :: FilePath
|
||||
serverKeyFile = combine cfgDir "server.key"
|
||||
|
||||
serverCrtFile :: FilePath
|
||||
serverCrtFile = combine cfgDir "server.crt"
|
||||
|
||||
fingerprintFile :: FilePath
|
||||
fingerprintFile = combine cfgDir "fingerprint"
|
||||
logCfg :: LogConfig
|
||||
logCfg = LogConfig {lc_file = Nothing, lc_stderr = True}
|
||||
|
||||
main :: IO ()
|
||||
main = do
|
||||
getCliCommand >>= \case
|
||||
Init opts ->
|
||||
doesFileExist iniFile >>= \case
|
||||
True -> exitError $ "Error: server is already initialized (" <> iniFile <> " exists).\nRun `smp-server start`."
|
||||
_ -> initializeServer opts
|
||||
Start ->
|
||||
doesFileExist iniFile >>= \case
|
||||
True -> readIniFile iniFile >>= either exitError (runServer . mkIniOptions)
|
||||
_ -> exitError $ "Error: server is not initialized (" <> iniFile <> " does not exist).\nRun `smp-server init`."
|
||||
Delete -> cleanup >> putStrLn "Deleted configuration and log files"
|
||||
setLogLevel LogInfo
|
||||
withGlobalLogging logCfg . protocolServerCLI smpServerCLIConfig $ \cfg@ServerConfig {inactiveClientExpiration} -> do
|
||||
putStrLn $ case inactiveClientExpiration of
|
||||
Just ExpirationConfig {ttl, checkInterval} -> "expiring clients inactive for " <> show ttl <> " seconds every " <> show checkInterval <> " seconds"
|
||||
_ -> "not expiring inactive clients"
|
||||
runSMPServer cfg
|
||||
|
||||
exitError :: String -> IO ()
|
||||
exitError msg = putStrLn msg >> exitFailure
|
||||
|
||||
data CliCommand
|
||||
= Init InitOptions
|
||||
| Start
|
||||
| Delete
|
||||
|
||||
data InitOptions = InitOptions
|
||||
{ enableStoreLog :: Bool,
|
||||
signAlgorithm :: SignAlgorithm,
|
||||
ip :: HostName,
|
||||
fqdn :: Maybe HostName
|
||||
}
|
||||
deriving (Show)
|
||||
|
||||
data SignAlgorithm = ED448 | ED25519
|
||||
deriving (Read, Show)
|
||||
|
||||
getCliCommand :: IO CliCommand
|
||||
getCliCommand =
|
||||
customExecParser
|
||||
(prefs showHelpOnEmpty)
|
||||
( info
|
||||
(helper <*> versionOption <*> cliCommandP)
|
||||
(header version <> fullDesc)
|
||||
)
|
||||
where
|
||||
versionOption = infoOption version (long "version" <> short 'v' <> help "Show version")
|
||||
|
||||
cliCommandP :: Parser CliCommand
|
||||
cliCommandP =
|
||||
hsubparser
|
||||
( command "init" (info initP (progDesc $ "Initialize server - creates " <> cfgDir <> " and " <> logDir <> " directories and configuration files"))
|
||||
<> command "start" (info (pure Start) (progDesc $ "Start server (configuration: " <> iniFile <> ")"))
|
||||
<> command "delete" (info (pure Delete) (progDesc "Delete configuration and log files"))
|
||||
)
|
||||
where
|
||||
initP :: Parser CliCommand
|
||||
initP =
|
||||
Init
|
||||
<$> ( InitOptions
|
||||
<$> switch
|
||||
( long "store-log"
|
||||
<> short 'l'
|
||||
<> help "Enable store log for SMP queues persistence"
|
||||
)
|
||||
<*> option
|
||||
(maybeReader readMaybe)
|
||||
( long "sign-algorithm"
|
||||
<> short 'a'
|
||||
<> help "Signature algorithm used for TLS certificates: ED25519, ED448"
|
||||
<> value ED448
|
||||
<> showDefault
|
||||
<> metavar "ALG"
|
||||
)
|
||||
<*> strOption
|
||||
( long "ip"
|
||||
<> help
|
||||
"Server IP address, used as Common Name for TLS online certificate if FQDN is not supplied"
|
||||
<> value "127.0.0.1"
|
||||
<> showDefault
|
||||
<> metavar "IP"
|
||||
)
|
||||
<*> (optional . strOption)
|
||||
( long "fqdn"
|
||||
<> short 'n'
|
||||
<> help "Server FQDN used as Common Name for TLS online certificate"
|
||||
<> showDefault
|
||||
<> metavar "FQDN"
|
||||
)
|
||||
)
|
||||
|
||||
initializeServer :: InitOptions -> IO ()
|
||||
initializeServer InitOptions {enableStoreLog, signAlgorithm, ip, fqdn} = do
|
||||
cleanup
|
||||
createDirectoryIfMissing True cfgDir
|
||||
createDirectoryIfMissing True logDir
|
||||
createX509
|
||||
fp <- saveFingerprint
|
||||
createIni
|
||||
putStrLn $ "Server initialized, you can modify configuration in " <> iniFile <> ".\nRun `smp-server start` to start server."
|
||||
printServiceInfo fp
|
||||
warnCAPrivateKeyFile
|
||||
where
|
||||
createX509 = do
|
||||
createOpensslCaConf
|
||||
createOpensslServerConf
|
||||
-- CA certificate (identity/offline)
|
||||
run $ "openssl genpkey -algorithm " <> show signAlgorithm <> " -out " <> caKeyFile
|
||||
run $ "openssl req -new -x509 -days 999999 -config " <> opensslCaConfFile <> " -extensions v3 -key " <> caKeyFile <> " -out " <> caCrtFile
|
||||
-- server certificate (online)
|
||||
run $ "openssl genpkey -algorithm " <> show signAlgorithm <> " -out " <> serverKeyFile
|
||||
run $ "openssl req -new -config " <> opensslServerConfFile <> " -reqexts v3 -key " <> serverKeyFile <> " -out " <> serverCsrFile
|
||||
run $ "openssl x509 -req -days 999999 -extfile " <> opensslServerConfFile <> " -extensions v3 -in " <> serverCsrFile <> " -CA " <> caCrtFile <> " -CAkey " <> caKeyFile <> " -CAcreateserial -out " <> serverCrtFile
|
||||
where
|
||||
run cmd = void $ readCreateProcess (shell cmd) ""
|
||||
opensslCaConfFile = combine cfgDir "openssl_ca.conf"
|
||||
opensslServerConfFile = combine cfgDir "openssl_server.conf"
|
||||
serverCsrFile = combine cfgDir "server.csr"
|
||||
createOpensslCaConf =
|
||||
writeFile
|
||||
opensslCaConfFile
|
||||
"[req]\n\
|
||||
\distinguished_name = req_distinguished_name\n\
|
||||
\prompt = no\n\n\
|
||||
\[req_distinguished_name]\n\
|
||||
\CN = SMP server CA\n\
|
||||
\O = SimpleX\n\n\
|
||||
\[v3]\n\
|
||||
\subjectKeyIdentifier = hash\n\
|
||||
\authorityKeyIdentifier = keyid:always\n\
|
||||
\basicConstraints = critical,CA:true\n"
|
||||
-- TODO revise https://www.rfc-editor.org/rfc/rfc5280#section-4.2.1.3, https://www.rfc-editor.org/rfc/rfc3279#section-2.3.5
|
||||
-- IP and FQDN can't both be used as server address interchangeably even if IP is added
|
||||
-- as Subject Alternative Name, unless the following validation hook is disabled:
|
||||
-- https://hackage.haskell.org/package/x509-validation-1.6.10/docs/src/Data-X509-Validation.html#validateCertificateName
|
||||
createOpensslServerConf =
|
||||
writeFile
|
||||
opensslServerConfFile
|
||||
( "[req]\n\
|
||||
\distinguished_name = req_distinguished_name\n\
|
||||
\prompt = no\n\n\
|
||||
\[req_distinguished_name]\n"
|
||||
<> ("CN = " <> cn <> "\n\n")
|
||||
<> "[v3]\n\
|
||||
\basicConstraints = CA:FALSE\n\
|
||||
\keyUsage = digitalSignature, nonRepudiation, keyAgreement\n\
|
||||
\extendedKeyUsage = serverAuth\n"
|
||||
)
|
||||
where
|
||||
cn = fromMaybe ip fqdn
|
||||
|
||||
saveFingerprint = do
|
||||
Fingerprint fp <- loadFingerprint caCrtFile
|
||||
withFile fingerprintFile WriteMode (`B.hPutStrLn` strEncode fp)
|
||||
pure fp
|
||||
|
||||
createIni = do
|
||||
writeFile iniFile $
|
||||
"[STORE_LOG]\n\
|
||||
\# The server uses STM memory to store SMP queues and messages,\n\
|
||||
\# that will be lost on restart (e.g., as with redis).\n\
|
||||
\# This option enables saving SMP queues to append only log,\n\
|
||||
\# and restoring them when the server is started.\n\
|
||||
\# Log is compacted on start (deleted queues are removed).\n\
|
||||
\# The messages in the queues are not logged.\n"
|
||||
<> ("enable: " <> (if enableStoreLog then "on" else "off # on") <> "\n\n")
|
||||
<> "[TRANSPORT]\n\
|
||||
\port: 5223\n\
|
||||
\websockets: off\n"
|
||||
|
||||
warnCAPrivateKeyFile =
|
||||
putStrLn $
|
||||
"----------\n\
|
||||
\You should store CA private key securely and delete it from the server.\n\
|
||||
\If server TLS credential is compromised this key can be used to sign a new one, \
|
||||
\keeping the same server identity and established connections.\n\
|
||||
\CA private key location:\n"
|
||||
<> caKeyFile
|
||||
<> "\n----------"
|
||||
|
||||
data IniOptions = IniOptions
|
||||
{ enableStoreLog :: Bool,
|
||||
port :: ServiceName,
|
||||
enableWebsockets :: Bool
|
||||
}
|
||||
|
||||
-- TODO ? properly parse ini as a whole
|
||||
mkIniOptions :: Ini -> IniOptions
|
||||
mkIniOptions ini =
|
||||
IniOptions
|
||||
{ enableStoreLog = (== "on") $ strict "STORE_LOG" "enable",
|
||||
port = T.unpack $ strict "TRANSPORT" "port",
|
||||
enableWebsockets = (== "on") $ strict "TRANSPORT" "websockets"
|
||||
}
|
||||
where
|
||||
strict :: String -> String -> T.Text
|
||||
strict section key =
|
||||
fromRight (error ("no key " <> key <> " in section " <> section)) $
|
||||
lookupValue (T.pack section) (T.pack key) ini
|
||||
|
||||
runServer :: IniOptions -> IO ()
|
||||
runServer IniOptions {enableStoreLog, port, enableWebsockets} = do
|
||||
hSetBuffering stdout LineBuffering
|
||||
hSetBuffering stderr LineBuffering
|
||||
fp <- checkSavedFingerprint
|
||||
printServiceInfo fp
|
||||
storeLog <- openStoreLog
|
||||
let cfg = mkServerConfig storeLog
|
||||
printServerConfig cfg
|
||||
runSMPServer cfg
|
||||
where
|
||||
checkSavedFingerprint = do
|
||||
savedFingerprint <- loadSavedFingerprint
|
||||
Fingerprint fp <- loadFingerprint caCrtFile
|
||||
when (B.pack savedFingerprint /= strEncode fp) $
|
||||
exitError "Stored fingerprint is invalid."
|
||||
pure fp
|
||||
|
||||
mkServerConfig storeLog =
|
||||
ServerConfig
|
||||
{ transports = (port, transport @TLS) : [("80", transport @WS) | enableWebsockets],
|
||||
tbqSize = 16,
|
||||
serverTbqSize = 64,
|
||||
msgQueueQuota = 128,
|
||||
queueIdBytes = 24,
|
||||
msgIdBytes = 24, -- must be at least 24 bytes, it is used as 192-bit nonce for XSalsa20
|
||||
caCertificateFile = caCrtFile,
|
||||
privateKeyFile = serverKeyFile,
|
||||
certificateFile = serverCrtFile,
|
||||
storeLog,
|
||||
allowNewQueues = True,
|
||||
messageTTL = Just $ 7 * 86400, -- 7 days
|
||||
expireMessagesInterval = Just 21600_000000 -- microseconds, 6 hours
|
||||
smpServerCLIConfig :: ServerCLIConfig ServerConfig
|
||||
smpServerCLIConfig =
|
||||
let caCrtFile = combine cfgPath "ca.crt"
|
||||
serverKeyFile = combine cfgPath "server.key"
|
||||
serverCrtFile = combine cfgPath "server.crt"
|
||||
in ServerCLIConfig
|
||||
{ cfgDir = cfgPath,
|
||||
logDir = logPath,
|
||||
iniFile = combine cfgPath "smp-server.ini",
|
||||
storeLogFile = combine logPath "smp-server-store.log",
|
||||
caKeyFile = combine cfgPath "ca.key",
|
||||
caCrtFile,
|
||||
serverKeyFile,
|
||||
serverCrtFile,
|
||||
fingerprintFile = combine cfgPath "fingerprint",
|
||||
defaultServerPort = "5223",
|
||||
executableName = "smp-server",
|
||||
serverVersion = "SMP server v" <> simplexMQVersion,
|
||||
mkIniFile = \enableStoreLog defaultServerPort ->
|
||||
"[STORE_LOG]\n\
|
||||
\# The server uses STM memory for persistence,\n\
|
||||
\# that will be lost on restart (e.g., as with redis).\n\
|
||||
\# This option enables saving memory to append only log,\n\
|
||||
\# and restoring it when the server is started.\n\
|
||||
\# Log is compacted on start (deleted objects are removed).\n"
|
||||
<> ("enable: " <> (if enableStoreLog then "on" else "off") <> "\n")
|
||||
<> "# The messages are optionally saved and restored when the server restarts,\n\
|
||||
\# they are deleted after restarting.\n"
|
||||
<> ("restore_messages: " <> (if enableStoreLog then "on" else "off") <> "\n\n")
|
||||
<> "[TRANSPORT]\n"
|
||||
<> ("port: " <> defaultServerPort <> "\n")
|
||||
<> "websockets: off\n\n"
|
||||
<> "[INACTIVE_CLIENTS]\n\
|
||||
\# TTL and interval to check inactive clients\n\
|
||||
\disconnect: off\n"
|
||||
<> ("# ttl: " <> show (ttl defaultInactiveClientExpiration) <> "\n")
|
||||
<> ("# check_interval: " <> show (checkInterval defaultInactiveClientExpiration) <> "\n"),
|
||||
mkServerConfig = \storeLogFile transports ini ->
|
||||
ServerConfig
|
||||
{ transports,
|
||||
tbqSize = 16,
|
||||
serverTbqSize = 64,
|
||||
msgQueueQuota = 128,
|
||||
queueIdBytes = 24,
|
||||
msgIdBytes = 24, -- must be at least 24 bytes, it is used as 192-bit nonce for XSalsa20
|
||||
caCertificateFile = caCrtFile,
|
||||
privateKeyFile = serverKeyFile,
|
||||
certificateFile = serverCrtFile,
|
||||
storeLogFile,
|
||||
storeMsgsFile =
|
||||
let messagesPath = combine logPath "smp-server-messages.log"
|
||||
in case lookupValue "STORE_LOG" "restore_messages" ini of
|
||||
Right "on" -> Just messagesPath
|
||||
Right _ -> Nothing
|
||||
-- if the setting is not set, it is enabled when store log is enabled
|
||||
_ -> storeLogFile $> messagesPath,
|
||||
allowNewQueues = True,
|
||||
messageExpiration = Just defaultMessageExpiration,
|
||||
inactiveClientExpiration =
|
||||
if lookupValue "INACTIVE_CLIENTS" "disconnect" ini == Right "on"
|
||||
then
|
||||
Just
|
||||
ExpirationConfig
|
||||
{ ttl = readStrictIni "INACTIVE_CLIENTS" "ttl" ini,
|
||||
checkInterval = readStrictIni "INACTIVE_CLIENTS" "check_interval" ini
|
||||
}
|
||||
else Nothing,
|
||||
logStatsInterval = Just 86400, -- seconds
|
||||
logStatsStartTime = 0, -- seconds from 00:00 UTC
|
||||
serverStatsFile = Just $ combine logPath "smp-server-stats.log",
|
||||
smpServerVRange = supportedSMPServerVRange
|
||||
}
|
||||
}
|
||||
|
||||
openStoreLog :: IO (Maybe (StoreLog 'ReadMode))
|
||||
openStoreLog =
|
||||
if enableStoreLog
|
||||
then Just <$> openReadStoreLog storeLogFile
|
||||
else pure Nothing
|
||||
|
||||
printServerConfig ServerConfig {storeLog, transports} = do
|
||||
putStrLn $ case storeLog of
|
||||
Just s -> "Store log: " <> storeLogFilePath s
|
||||
Nothing -> "Store log disabled."
|
||||
forM_ transports $ \(p, ATransport t) ->
|
||||
putStrLn $ "Listening on port " <> p <> " (" <> transportName t <> ")..."
|
||||
|
||||
cleanup :: IO ()
|
||||
cleanup = do
|
||||
deleteDirIfExists cfgDir
|
||||
deleteDirIfExists logDir
|
||||
where
|
||||
deleteDirIfExists path = doesDirectoryExist path >>= (`when` removeDirectoryRecursive path)
|
||||
|
||||
printServiceInfo :: ByteString -> IO ()
|
||||
printServiceInfo fpStr = do
|
||||
putStrLn version
|
||||
B.putStrLn $ "Fingerprint: " <> strEncode fpStr
|
||||
|
||||
version :: String
|
||||
version = "SMP server v" <> simplexMQVersion
|
||||
|
||||
loadSavedFingerprint :: IO String
|
||||
loadSavedFingerprint = withFile fingerprintFile ReadMode hGetLine
|
||||
|
|
24
package.yaml
24
package.yaml
|
@ -1,5 +1,5 @@
|
|||
name: simplexmq
|
||||
version: 1.1.0
|
||||
version: 3.0.0
|
||||
synopsis: SimpleXMQ message broker
|
||||
description: |
|
||||
This package includes <./docs/Simplex-Messaging-Server.html server>,
|
||||
|
@ -28,9 +28,10 @@ dependencies:
|
|||
- asn1-types == 0.3.*
|
||||
- async == 2.2.*
|
||||
- attoparsec == 0.14.*
|
||||
- base >= 4.7 && < 5
|
||||
- base >= 4.14 && < 5
|
||||
- base64-bytestring >= 1.0 && < 1.3
|
||||
- bytestring == 0.10.*
|
||||
- case-insensitive == 1.2.*
|
||||
- composition == 1.0.*
|
||||
- constraints >= 0.12 && < 0.14
|
||||
- containers == 0.6.*
|
||||
|
@ -41,13 +42,17 @@ dependencies:
|
|||
- directory == 1.3.*
|
||||
- filepath == 1.4.*
|
||||
- http-types == 0.12.*
|
||||
- http2 == 3.0.*
|
||||
- generic-random >= 1.3 && < 1.5
|
||||
- ini == 0.4.*
|
||||
- iso8601-time == 0.1.*
|
||||
- memory == 0.15.*
|
||||
- mtl == 2.2.*
|
||||
- network == 3.1.2.*
|
||||
- network >= 3.1.2.7 && < 3.2
|
||||
- network-transport == 0.5.*
|
||||
- optparse-applicative >= 0.15 && < 0.17
|
||||
- QuickCheck == 2.14.*
|
||||
- process == 1.6.*
|
||||
- random >= 1.1 && < 1.3
|
||||
- simple-logger == 0.1.*
|
||||
- sqlite-simple == 0.4.*
|
||||
|
@ -55,6 +60,8 @@ dependencies:
|
|||
- template-haskell == 2.16.*
|
||||
- text == 1.2.*
|
||||
- time == 1.9.*
|
||||
- time-compat == 1.9.*
|
||||
- time-manager == 0.0.*
|
||||
- tls >= 1.5.7 && < 1.6
|
||||
- transformers == 0.5.*
|
||||
- unliftio == 0.2.*
|
||||
|
@ -83,9 +90,14 @@ executables:
|
|||
source-dirs: apps/smp-server
|
||||
main: Main.hs
|
||||
dependencies:
|
||||
- ini == 0.4.*
|
||||
- optparse-applicative >= 0.15 && < 0.17
|
||||
- process == 1.6.*
|
||||
- simplexmq
|
||||
ghc-options:
|
||||
- -threaded
|
||||
|
||||
ntf-server:
|
||||
source-dirs: apps/ntf-server
|
||||
main: Main.hs
|
||||
dependencies:
|
||||
- simplexmq
|
||||
ghc-options:
|
||||
- -threaded
|
||||
|
|
|
@ -0,0 +1,71 @@
|
|||
sequenceDiagram
|
||||
participant A as Alice
|
||||
participant AA as Alice's<br>agent
|
||||
participant AS as Alice's<br>server
|
||||
participant BS as Bob's<br>server
|
||||
participant BA as Bob's<br>agent
|
||||
participant B as Bob
|
||||
|
||||
note over AA, BA: status (receive/send): NONE/NONE
|
||||
|
||||
note over A, AA: 1. request connection<br>from agent
|
||||
A ->> AA: NEW: create<br>duplex connection
|
||||
|
||||
note over AA, AS: 2. create Alice's SMP queue
|
||||
AA ->> AS: NEW: create SMP queue
|
||||
AS ->> AA: IDS: SMP queue IDs
|
||||
note over AA: status: NEW/NONE
|
||||
|
||||
AA ->> A: INV: invitation<br>to connect
|
||||
|
||||
note over A, B: 3. out-of-band invitation
|
||||
A ->> B: OOB: invitation to connect
|
||||
|
||||
note over BA, B: 4. accept connection
|
||||
B ->> BA: JOIN:<br>via invitation info
|
||||
note over BA: status: NONE/NEW
|
||||
|
||||
note over BA, BS: 5. create Bob's SMP queue
|
||||
BA ->> BS: NEW: create SMP queue
|
||||
BS ->> BA: IDS: SMP queue IDs
|
||||
note over BA: status: NEW/NEW
|
||||
|
||||
note over BA, AA: 6. establish Alice's SMP queue
|
||||
BA ->> AS: SEND: Bob's info and sender server key (SMP confirmation with reply queues)
|
||||
note over BA: status: NEW/CONFIRMED
|
||||
|
||||
AS ->> AA: MSG: Bob's info and<br>sender server key
|
||||
note over AA: status: CONFIRMED/NONE
|
||||
AA ->> AS: ACK: confirm message
|
||||
AA ->> A: CONF: connection request ID<br>and Bob's info
|
||||
A ->> AA: LET: accept connection request,<br>send Alice's info
|
||||
AA ->> AS: KEY: secure queue
|
||||
note over AA: status: SECURED/NONE
|
||||
|
||||
AA ->> BS: SEND: Alice's info and sender's server key (SMP confirmation without reply queues)
|
||||
note over AA: status: SECURED/CONFIRMED
|
||||
|
||||
BS ->> BA: MSG: Alice's info and<br>sender's server key
|
||||
note over BA: status: CONFIRMED/CONFIRMED
|
||||
BA ->> B: INFO: Alice's info
|
||||
BA ->> BS: ACK: confirm message
|
||||
BA ->> BS: KEY: secure queue
|
||||
note over BA: status: SECURED/CONFIRMED
|
||||
|
||||
BA ->> AS: SEND: HELLO: only needs to be sent once in v2
|
||||
|
||||
note over BA: status: SECURED/ACTIVE
|
||||
note over BA, B: 7a. notify Bob<br>about connection success
|
||||
BA ->> B: CON: connected
|
||||
|
||||
AS ->> AA: MSG: HELLO: Alice's agent<br>knows Bob can send
|
||||
note over AA: status: SECURED/ACTIVE
|
||||
AA ->> AS: ACK: confirm message
|
||||
note over A, AA: 7a. notify Alice<br>about connection success
|
||||
AA ->> A: CON: connected
|
||||
|
||||
AA ->> BS: SEND: HELLO: only needs to be sent once in v2
|
||||
note over AA: status: ACTIVE/ACTIVE
|
||||
BS ->> BA: MSG: HELLO: Bob's agent<br>knows Alice can send
|
||||
note over BA: status: ACTIVE/ACTIVE
|
||||
BA ->> BS: ACK: confirm message
|
|
@ -33,8 +33,8 @@ sequenceDiagram
|
|||
AS ->> AA: MSG: Bob's info and<br>sender server key
|
||||
note over AA: status: CONFIRMED/NONE
|
||||
AA ->> AS: ACK: confirm message
|
||||
AA ->> A: REQ: connection request ID<br>and Bob's info
|
||||
A ->> AA: ACPT: accept connection request,<br>send Alice's info
|
||||
AA ->> A: CONF: connection request ID<br>and Bob's info
|
||||
A ->> AA: LET: accept connection request,<br>send Alice's info
|
||||
AA ->> AS: KEY: secure queue
|
||||
note over AA: status: SECURED/NONE
|
||||
|
||||
|
|
|
@ -0,0 +1,30 @@
|
|||
sequenceDiagram
|
||||
participant M as mobile app
|
||||
participant C as chat core
|
||||
participant A as agent
|
||||
participant P as push server
|
||||
participant APN as APN
|
||||
|
||||
note over M, APN: get device token
|
||||
M ->> APN: registerForRemoteNotifications()
|
||||
APN ->> M: device token
|
||||
|
||||
note over M, P: register device token with push server
|
||||
M ->> C: /_ntf register <token>
|
||||
C ->> A: registerNtfToken(<token>)
|
||||
A ->> P: TNEW
|
||||
P ->> A: ID (tokenId)
|
||||
A ->> C: registered
|
||||
C ->> M: registered
|
||||
|
||||
note over M, APN: verify device token
|
||||
P ->> APN: E2E encrypted code<br>in background<br>notification
|
||||
APN ->> M: deliver background notification with e2ee verification token
|
||||
M ->> C: /_ntf verify <e2ee code>
|
||||
C ->> A: verifyNtfToken(<e2ee code>)
|
||||
A ->> P: TVFY code
|
||||
P ->> A: OK / ERR
|
||||
A ->> C: verified
|
||||
C ->> M: verified
|
||||
|
||||
note over M, APN: now token ID can be used
|
|
@ -0,0 +1,40 @@
|
|||
sequenceDiagram
|
||||
participant M as mobile app
|
||||
participant C as chat core
|
||||
participant A as agent
|
||||
participant S as SMP server
|
||||
participant N as NTF server
|
||||
participant APN as APN
|
||||
|
||||
note over M, APN: register subscription
|
||||
|
||||
alt register existing
|
||||
M -->> A: on /_ntf register, for subscribed queues
|
||||
else create new connection
|
||||
A -->> S: NEW / JOIN
|
||||
note over A, S: ...<br>Connection handshake<br>...
|
||||
S -->> A: CON
|
||||
end
|
||||
A ->> S: NKEY nKey
|
||||
S ->> A: NID nId
|
||||
A ->> N: SNEW tknId dhKey (smpServer, nId, nKey)
|
||||
N ->> A: ID subId dhKey
|
||||
N ->> S: NSUB nId
|
||||
S ->> N: OK [/ NMSG]
|
||||
|
||||
note over M, APN: notify about message
|
||||
|
||||
S ->> N: NMSG
|
||||
N ->> APN: APNSMutableContent<br>ntfQueue, nonce
|
||||
APN ->> M: UNMutableNotificationContent
|
||||
note over M, S: ...<br>Client awaken, message is received<br>...
|
||||
S ->> M: message
|
||||
note over M: mutate notification
|
||||
|
||||
note over M, APN: change APN token
|
||||
|
||||
APN ->> M: new device token
|
||||
M -->> C: /_ntf_sub update tkn
|
||||
C -->> A: updateNtfToken()
|
||||
A -->> N: TUPD tknId newDeviceToken
|
||||
note over M, N: ...<br>Verify token<br>...
|
|
@ -24,6 +24,7 @@
|
|||
- [Subscribe to queue](#subscribe-to-queue)
|
||||
- [Secure queue command](#secure-queue-command)
|
||||
- [Enable notifications command](#enable-notifications-command)
|
||||
- [Disable notifications command](#disable-notifications-command)
|
||||
- [Acknowledge message delivery](#acknowledge-message-delivery)
|
||||
- [Suspend queue](#suspend-queue)
|
||||
- [Delete queue](#delete-queue)
|
||||
|
@ -146,61 +147,61 @@ To create and start using a simplex queue Alice and Bob follow these steps:
|
|||
|
||||
1. Alice creates a simplex queue on the server:
|
||||
|
||||
1. Decides which SMP server to use (can be the same or different server that Alice uses for other queues) and opens secure encrypted transport connection to the chosen SMP server (see [Appendix A](#appendix-a)).
|
||||
1. Decides which SMP server to use (can be the same or different server that Alice uses for other queues) and opens secure encrypted transport connection to the chosen SMP server (see [Appendix A](#appendix-a)).
|
||||
|
||||
2. Generates a new random public/private key pair (encryption key - `EK`) that she did not use before for Bob to encrypt the messages.
|
||||
2. Generates a new random public/private key pair (encryption key - `EK`) that she did not use before for Bob to encrypt the messages.
|
||||
|
||||
3. Generates another new random public/private key pair (recipient key - `RK`) that she did not use before for her to sign commands and to decrypt the transmissions received from the server.
|
||||
3. Generates another new random public/private key pair (recipient key - `RK`) that she did not use before for her to sign commands and to decrypt the transmissions received from the server.
|
||||
|
||||
4. Generates one more random key pair (recipient DH key - `RDHK`) to negotiate symmetric key that will be used by the server to encrypt message bodies delivered to Alice (to avoid shared cipher-text inside transport connection).
|
||||
4. Generates one more random key pair (recipient DH key - `RDHK`) to negotiate symmetric key that will be used by the server to encrypt message bodies delivered to Alice (to avoid shared cipher-text inside transport connection).
|
||||
|
||||
5. Sends `"NEW"` command to the server to create a simplex queue (see `create` in [Create queue command](#create-queue-command)). This command contains previously generated unique "public" keys `RK` and `RDHK`. `RK` will be used to verify the following commands related to the same queue signed by its private counterpart, for example to subscribe to the messages received to this queue or to update the queue, e.g. by setting the key required to send the messages (initially Alice creates the queue that accepts unsigned messages, so anybody could send the message via this queue if they knew the queue sender's ID and server address).
|
||||
5. Sends `"NEW"` command to the server to create a simplex queue (see `create` in [Create queue command](#create-queue-command)). This command contains previously generated unique "public" keys `RK` and `RDHK`. `RK` will be used to verify the following commands related to the same queue signed by its private counterpart, for example to subscribe to the messages received to this queue or to update the queue, e.g. by setting the key required to send the messages (initially Alice creates the queue that accepts unsigned messages, so anybody could send the message via this queue if they knew the queue sender's ID and server address).
|
||||
|
||||
6. The server sends `"IDS"` response with queue IDs (`queueIds`):
|
||||
6. The server sends `"IDS"` response with queue IDs (`queueIds`):
|
||||
|
||||
- Recipient ID `RID` for Alice to manage the queue and to receive the messages.
|
||||
- Recipient ID `RID` for Alice to manage the queue and to receive the messages.
|
||||
|
||||
- Sender ID `SID` for Bob to send messages to the queue.
|
||||
- Sender ID `SID` for Bob to send messages to the queue.
|
||||
|
||||
- Server public DH key (`SDHK`) to negotiate a shared secret for message body encryption, that Alice uses to derive a shared secret with the server `SS`.
|
||||
- Server public DH key (`SDHK`) to negotiate a shared secret for message body encryption, that Alice uses to derive a shared secret with the server `SS`.
|
||||
|
||||
2. Alice sends an out-of-band message to Bob via the alternative channel that both Alice and Bob trust (see [protocol abstract](#simplex-messaging-protocol-abstract)). The message must include:
|
||||
|
||||
- Unique "public" key (`EK`) that Bob must use for E2E key agreement.
|
||||
- Unique "public" key (`EK`) that Bob must use for E2E key agreement.
|
||||
|
||||
- SMP server hostname and information to open secure encrypted transport connection (see [Appendix A](#appendix-a)).
|
||||
- SMP server hostname and information to open secure encrypted transport connection (see [Appendix A](#appendix-a)).
|
||||
|
||||
- Sender queue ID `SID` for Bob to use.
|
||||
- Sender queue ID `SID` for Bob to use.
|
||||
|
||||
3. Bob, having received the out-of-band message from Alice, connects to the queue:
|
||||
|
||||
1. Generates a new random public/private key pair (sender key - `SK`) that he did not use before for him to sign messages sent to Alice's server.
|
||||
1. Generates a new random public/private key pair (sender key - `SK`) that he did not use before for him to sign messages sent to Alice's server.
|
||||
|
||||
2. Prepares the confirmation message for Alice to secure the queue. This message includes:
|
||||
2. Prepares the confirmation message for Alice to secure the queue. This message includes:
|
||||
|
||||
- Previously generated "public" key `SK` that will be used by Alice's server to authenticate Bob's messages, once the queue is secured.
|
||||
- Previously generated "public" key `SK` that will be used by Alice's server to authenticate Bob's messages, once the queue is secured.
|
||||
|
||||
- Optionally, any additional information (application specific, e.g. Bob's profile name and details).
|
||||
- Optionally, any additional information (application specific, e.g. Bob's profile name and details).
|
||||
|
||||
3. Encrypts the confirmation body with the "public" key `EK` (that Alice provided via the out-of-band message).
|
||||
3. Encrypts the confirmation body with the "public" key `EK` (that Alice provided via the out-of-band message).
|
||||
|
||||
4. Sends the encrypted message to the server with queue ID `SID` (see `send` in [Send message](#send-message)). This initial message to the queue must not be signed - signed messages will be rejected until Alice secures the queue (below).
|
||||
4. Sends the encrypted message to the server with queue ID `SID` (see `send` in [Send message](#send-message)). This initial message to the queue must not be signed - signed messages will be rejected until Alice secures the queue (below).
|
||||
|
||||
4. Alice receives Bob's message from the server using recipient queue ID `RID` (possibly, via the same transport connection she already has opened - see `message` in [Deliver queue message](#deliver-queue-message)):
|
||||
|
||||
1. She decrypts received message body using the secret `SS`.
|
||||
1. She decrypts received message body using the secret `SS`.
|
||||
|
||||
2. She decrypts received message with [key agreed with sender using] "private" key `EK`.
|
||||
2. She decrypts received message with [key agreed with sender using] "private" key `EK`.
|
||||
|
||||
3. Even though anybody could have sent the message to the queue with ID `SID` before it is secured (e.g. if communication is compromised), Alice would ignore all messages until the decryption succeeds (i.e. the result contains the expected message format). Optionally, in the client application, she also may identify Bob using the information provided, but it is out of scope of SMP protocol.
|
||||
3. Even though anybody could have sent the message to the queue with ID `SID` before it is secured (e.g. if communication is compromised), Alice would ignore all messages until the decryption succeeds (i.e. the result contains the expected message format). Optionally, in the client application, she also may identify Bob using the information provided, but it is out of scope of SMP protocol.
|
||||
|
||||
5. Alice secures the queue `RID` with `"KEY"` command so only Bob can send messages to it (see [Secure queue command](#secure-queue-command)):
|
||||
|
||||
1. She sends the `KEY` command with `RID` signed with "private" key `RK` to update the queue to only accept requests signed by "private" key `SK` provided by Bob. This command contains unique "public" key `SK` previously generated by Bob.
|
||||
1. She sends the `KEY` command with `RID` signed with "private" key `RK` to update the queue to only accept requests signed by "private" key `SK` provided by Bob. This command contains unique "public" key `SK` previously generated by Bob.
|
||||
|
||||
2. From this moment the server will accept only signed commands to `SID`, so only Bob will be able to send messages to the queue `SID` (corresponding to `RID` that Alice has).
|
||||
2. From this moment the server will accept only signed commands to `SID`, so only Bob will be able to send messages to the queue `SID` (corresponding to `RID` that Alice has).
|
||||
|
||||
3. Once queue is secured, Alice deletes `SID` and `SK` - even if Alice's client is compromised in the future, the attacker would not be able to send messages pretending to be Bob.
|
||||
3. Once queue is secured, Alice deletes `SID` and `SK` - even if Alice's client is compromised in the future, the attacker would not be able to send messages pretending to be Bob.
|
||||
|
||||
6. The simplex queue `RID` is now ready to be used.
|
||||
|
||||
|
@ -214,21 +215,21 @@ Bob now can securely send messages to Alice:
|
|||
|
||||
1. Bob sends the message:
|
||||
|
||||
1. He encrypts the message to Alice with "public" key `EK` (provided by Alice, only known to Bob, used only for one simplex queue).
|
||||
1. He encrypts the message to Alice with "public" key `EK` (provided by Alice, only known to Bob, used only for one simplex queue).
|
||||
|
||||
2. He signs `"SEND"` command to the server queue `SID` using the "private" key `SK` (that only he knows, used only for this queue).
|
||||
2. He signs `"SEND"` command to the server queue `SID` using the "private" key `SK` (that only he knows, used only for this queue).
|
||||
|
||||
3. He sends the command to the server (see `send` in [Send message](#send-message)), that the server will authenticate using the "public" key `SK` (that Alice earlier received from Bob and provided to the server via `"KEY"` command).
|
||||
3. He sends the command to the server (see `send` in [Send message](#send-message)), that the server will authenticate using the "public" key `SK` (that Alice earlier received from Bob and provided to the server via `"KEY"` command).
|
||||
|
||||
2. Alice receives the message(s):
|
||||
|
||||
1. She signs `"SUB"` command to the server to subscribe to the queue `RID` with the "private" key `RK` (see `subscribe` in [Subscribe to queue](#subscribe-to-queue)).
|
||||
1. She signs `"SUB"` command to the server to subscribe to the queue `RID` with the "private" key `RK` (see `subscribe` in [Subscribe to queue](#subscribe-to-queue)).
|
||||
|
||||
2. The server, having authenticated Alice's command with the "public" key `RK` that she provided, delivers Bob's message(s) (see `message` in [Deliver queue message](#deliver-queue-message)).
|
||||
2. The server, having authenticated Alice's command with the "public" key `RK` that she provided, delivers Bob's message(s) (see `message` in [Deliver queue message](#deliver-queue-message)).
|
||||
|
||||
3. She decrypts Bob's message(s) with the "private" key `EK` (that only she has).
|
||||
3. She decrypts Bob's message(s) with the "private" key `EK` (that only she has).
|
||||
|
||||
4. She acknowledges the message reception to the server with `"ACK"` so that the server can delete the message and deliver the next messages.
|
||||
4. She acknowledges the message reception to the server with `"ACK"` so that the server can delete the message and deliver the next messages.
|
||||
|
||||
This flow is show on sequence diagram below.
|
||||
|
||||
|
@ -355,9 +356,12 @@ To protect the privacy of the recipients, there are several commands in SMP prot
|
|||
The clients can optionally instruct a dedicated push notification server to subscribe to notifications and deliver push notifications to the device, which can then retrieve the messages in the background and send local notifications to the user - this is out of scope of SMP protocol. The commands that SMP protocol provides to allow it:
|
||||
|
||||
- `enableNotifications` (`"NKEY"`) with `notifierId` (`"NID"`) response - see [Enable notifications command](#enable-notifications-command).
|
||||
- `disableNotifications` (`"NDEL"`) - see [Disable notifications command](#disable-notifications-command).
|
||||
- `subscribeNotifications` (`"NSUB"`) - see [Subscribe to queue notifications](#subscribe-to-queue-notifications).
|
||||
- `messageNotification` (`"NMSG"`) - see [Deliver message notification](#deliver-message-notification).
|
||||
|
||||
[`SEND` command](#send-message) includes the notification flag to instruct SMP server whether to send the notification - this flag is forwarded to the recepient inside encrypted envelope, together with the timestamp and the message body, so even if TLS is compromised this flag cannot be used for traffic correlation.
|
||||
|
||||
## SMP Transmission structure
|
||||
|
||||
Each transport block (SMP transmission) has a fixed size of 16384 bytes for traffic uniformity.
|
||||
|
@ -401,7 +405,7 @@ Commands syntax below is provided using [ABNF][8] with [case-sensitive strings e
|
|||
|
||||
```abnf
|
||||
smpCommand = ping / recipientCmd / send / subscribeNotifications / serverMsg
|
||||
recipientCmd = create / subscribe / secure / enableNotifications /
|
||||
recipientCmd = create / subscribe / secure / enableNotifications / disableNotifications /
|
||||
acknowledge / suspend / delete
|
||||
serverMsg = queueIds / message / notifierId / messageNotification /
|
||||
unsubscribed / ok / error
|
||||
|
@ -502,22 +506,42 @@ Once the queue is secured only signed messages can be sent to it.
|
|||
This command is sent by the recipient to the server to add notifier's key to the queue, to allow push notifications server to receive notifications when the message arrives, via a separate queue ID, without receiving message content.
|
||||
|
||||
```abnf
|
||||
enableNotifications = %s"NKEY " notifierKey
|
||||
enableNotifications = %s"NKEY " notifierKey recipientNotificationDhPublicKey
|
||||
notifierKey = length x509encoded
|
||||
; the notifier's Ed25519 or Ed448 public key public key to verify NSUB command for this queue
|
||||
|
||||
recipientNotificationDhPublicKey = length x509encoded
|
||||
; the recipient's Curve25519 key for DH exchange to derive the secret
|
||||
; that the server will use to encrypt notification metadata (encryptedNMsgMeta in NMSG)
|
||||
; using [NaCl crypto_box][16] encryption scheme (curve25519xsalsa20poly1305).
|
||||
```
|
||||
|
||||
The server will respond with `notifierId` response if notifications were enabled and the notifier's key was successfully added to the queue:
|
||||
|
||||
```abnf
|
||||
notifierId = %s"NID " notifierId
|
||||
notifierId = %s"NID " notifierId srvNotificationDhPublicKey
|
||||
notifierId = length *OCTET ; 16-24 bytes
|
||||
srvNotificationDhPublicKey = length x509encoded
|
||||
; the server's Curve25519 key for DH exchange to derive the secret
|
||||
; that the server will use to encrypt notification metadata to the recipient (encryptedNMsgMeta in NMSG)
|
||||
```
|
||||
|
||||
This response is sent with the recipient's queue ID (the third part of the transmission).
|
||||
|
||||
To receive the message notifications, `subscribeNotifications` command ("NSUB") must be sent signed with the notifier's key.
|
||||
|
||||
#### Disable notifications command
|
||||
|
||||
This command is sent by the recipient to the server to remove notifier's credentials from the queue:
|
||||
|
||||
```abnf
|
||||
disableNotifications = %s"NDEL"
|
||||
```
|
||||
|
||||
The server must respond `ok` to this command if it was successful.
|
||||
|
||||
Once notifier's credentials are removed server will no longer send "NMSG" for this queue to notifier.
|
||||
|
||||
#### Acknowledge message delivery
|
||||
|
||||
The recipient should send the acknowledgement of message delivery once the message was stored in the client, to notify the server that the message should be deleted:
|
||||
|
@ -565,7 +589,9 @@ Currently SMP defines only one command that can be used by senders - `send` mess
|
|||
This command is sent to the server by the sender both to confirm the queue after the sender received out-of-band message from the recipient and to send messages after the queue is secured:
|
||||
|
||||
```abnf
|
||||
send = %s"SEND " smpEncMessage
|
||||
send = %s"SEND " msgFlags SP smpEncMessage
|
||||
msgFlags = notificationFlag reserved
|
||||
notificationFlag = %s"T" / %s"F"
|
||||
smpEncMessage = smpPubHeader sentMsgBody ; message up to 16088 bytes
|
||||
smpPubHeader = smpClientVersion ("1" senderPublicDhKey / "0")
|
||||
smpClientVersion = word16
|
||||
|
@ -707,9 +733,11 @@ See its syntax in [Create queue command](#create-queue-command)
|
|||
The server must deliver messages to all subscribed simplex queues on the currently open transport connection. The syntax for the message delivery is:
|
||||
|
||||
```abnf
|
||||
message = %s"MSG " msgId SP timestamp SP encryptedMsgBody
|
||||
message = %s"MSG " msgId encryptedRcvMsgBody
|
||||
encryptedMsgBody = <encrypt paddedSentMsgBody> ; server-encrypted padded sent msgBody
|
||||
paddedSentMsgBody = <padded(sentMsgBody, maxMessageLength + 2)> ; maxMessageLength = 16088
|
||||
encryptedRcvMsgBody = <encrypt rcvMsgBody> ; server-encrypted meta-data and padded sent msgBody
|
||||
rcvMsgBody = timestamp msgFlags SP paddedSentMsgBody
|
||||
msgId = length 24*24OCTET
|
||||
timestamp = 8*8OCTET
|
||||
```
|
||||
|
@ -735,10 +763,19 @@ See its syntax in [Enable notifications command](#enable-notifications-command)
|
|||
The server must deliver message notifications to all simplex queues that were subscribed with `subscribeNotifications` command ("NSUB") on the currently open transport connection. The syntax for the message notification delivery is:
|
||||
|
||||
```abnf
|
||||
messageNotification = %s"NMSG"
|
||||
messageNotification = %s"NMSG " nmsgNonce encryptedNMsgMeta
|
||||
|
||||
encryptedNMsgMeta = <encrypted message metadata passed in notification>
|
||||
; metadata E2E encrypted between server and recipient containing server's message ID and timestamp (allows extension),
|
||||
; to be passed to the recipient by the notifier for them to decrypt
|
||||
; with key negotiated in NKEY and NID commands using nmsgNonce
|
||||
|
||||
nmsgNonce = <nonce used in NaCl crypto_box encryption scheme>
|
||||
; nonce used by the server for encryption of message metadata, to be passed to the recipient by the notifier
|
||||
; for them to use in decryption of E2E encrypted metadata
|
||||
```
|
||||
|
||||
Message notification does not contain any message data or meta-data, it only notifies that the message is available.
|
||||
Message notification does not contain any message data or non E2E encrypted metadata.
|
||||
|
||||
#### Subscription END notification
|
||||
|
||||
|
@ -770,7 +807,7 @@ The syntax for error responses:
|
|||
```abnf
|
||||
error = %s"ERR " errorType
|
||||
errorType = %s"BLOCK" / %s"SESSION" / %s"CMD " cmdError / %s"AUTH" / %s"LARGE_MSG" /%s"INTERNAL"
|
||||
cmdError = %s"SYNTAX" / %s"PROHIBITED" / %s"NO_AUTH" / %s"HAS_AUTH" / %s"NO_QUEUE"
|
||||
cmdError = %s"SYNTAX" / %s"PROHIBITED" / %s"NO_AUTH" / %s"HAS_AUTH" / %s"NO_ENTITY"
|
||||
```
|
||||
|
||||
Server implementations must aim to respond within the same time for each command in all cases when `"ERR AUTH"` response is required to prevent timing attacks (e.g., the server should perform signature verification even when the queue does not exist on the server or the signature of different size is sent, using any RSA key with the same size as the signature size).
|
||||
|
@ -792,6 +829,7 @@ ok = %s"OK"
|
|||
Both the recipient and the sender can use TCP or some other, possibly higher level, transport protocol to communicate with the server. The default TCP port for SMP server is 5223.
|
||||
|
||||
For scenarios when meta-data privacy is critical, it is recommended that clients:
|
||||
|
||||
- communicating over Tor network,
|
||||
- establish a separate connection for each SMP queue,
|
||||
- send noise traffic (using PING command).
|
||||
|
@ -799,12 +837,14 @@ For scenarios when meta-data privacy is critical, it is recommended that clients
|
|||
In addition to that, the servers can be deployed as Tor onion services.
|
||||
|
||||
The transport protocol should provide the following:
|
||||
|
||||
- server authentication (by matching server certificate hash with `serverIdentity`),
|
||||
- forward secrecy (by encrypting the traffic using ephemeral keys agreed during transport handshake),
|
||||
- integrity (preventing data modification by the attacker without detection),
|
||||
- unique channel binding (`sessionIdentifier`) to include in the signed part of SMP transmissions.
|
||||
|
||||
By default, the client and server communicate using [TLS 1.3 protocol][13] restricted to:
|
||||
|
||||
- TLS_CHACHA20_POLY1305_SHA256 cipher suite (for better performance on mobile devices),
|
||||
- ed25519 and ed448 EdDSA algorithms for signatures,
|
||||
- x25519 and x448 ECDHE groups for key exchange.
|
||||
|
|
|
@ -13,6 +13,6 @@ Change controller: Evgeny Poberezkin <ep@simplex.chat>
|
|||
|
||||
References:
|
||||
The syntax for connection requests in the latest version of SimpleX Agent Protocol:
|
||||
https://github.com/simplex-chat/simplexmq/blob/v5/protocol/agent-protocol.md#connection-request
|
||||
https://github.com/simplex-chat/simplexmq/blob/master/protocol/agent-protocol.md#connection-request
|
||||
SimpleX Messaging Protocol:
|
||||
https://github.com/simplex-chat/simplexmq/blob/v5/protocol/simplex-messaging.md
|
||||
https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md
|
||||
|
|
|
@ -13,4 +13,4 @@ Change controller: Evgeny Poberezkin <ep@simplex.chat>
|
|||
|
||||
References:
|
||||
The syntax for message queue URIs in the latest version of SimpleX Messaging Protocol:
|
||||
https://github.com/simplex-chat/simplexmq/blob/v5/protocol/simplex-messaging.md#smp-queue-uri
|
||||
https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#smp-queue-uri
|
||||
|
|
|
@ -0,0 +1,75 @@
|
|||
# DB access and processing messages for iOS notification service extension
|
||||
|
||||
## Problem
|
||||
|
||||
The only way to receive/process notificaitons is via a separate NSE process that requires a concurrent DB and network access.
|
||||
|
||||
SQLite concurenncy does not work, so we need to sync database access.
|
||||
|
||||
The problem is complex, as we do not directly control db access from the app, it can be triggered by the message arriving and it may fail to complete in case the app is suspended in the background.
|
||||
|
||||
So we need to prevent db access from starting when we know the app is about to be suspended.
|
||||
|
||||
The last problem is how to receive and process messages in NSE - should it use recently added GET command or should it subscribe to the connections that receive messages and process messages normally.
|
||||
|
||||
To summarize, 2 problems need to be solved:
|
||||
|
||||
1. sync db access between 2 (or more, if we add share extension) processes
|
||||
|
||||
2. prevent access from starting when the process is due to suspend, only complete operations.
|
||||
|
||||
3. Receiveing and processing agent messages in NSE
|
||||
|
||||
## Proposed solution
|
||||
|
||||
For problem 1, we can use Posix semaphores from our core code, in the same bracket that manages database connection - it would wait for semaphore to be free and unlock it once the db operation is complete.
|
||||
|
||||
For problem 2, we need to communicate from the app when it goes to the background to prevent database access starting and completing before the suspension. This would set some `aboutToSuspend` STM flag (or the opposite) that would prevent operations from progressing (using STM retry, that would block until the flag has the value allowing operation to progress).
|
||||
|
||||
Several possibilities can be considered:
|
||||
|
||||
- use this flag in the bracket that provides DB connections. While simple, it may lock some operations in the middle and may also lead to the situation when network operation succeeds but database access was blocked, and the database is not updated.
|
||||
- use this flag to stop network operations that would require database updates - like sending messages, subscriptions and ACK - all these operations would require database access once they succeed.
|
||||
- use two flags, for both cases above, but set them at different times since going to background - block new network operations as soon as the app goes to the background and block database access once the app is about to be suspended.
|
||||
|
||||
The last option seems more robust. To do it, there will be an api allowing the app to communicat its phase:
|
||||
|
||||
- app going to the background would trigger blocking new network operations and start a new background task - `background` phase.
|
||||
- background task receiving completion warning, or, maybe, some time after it is started - probably 20 seconds - or whatever happens earlier - would trigger call blocking db access - `suspended` phase.
|
||||
- app becoming active would trigger unblocking both flags - `active` phase.
|
||||
|
||||
`/_app phase <phase>` where `phase` can be one of the above values.
|
||||
|
||||
NSE would also use the same phases:
|
||||
|
||||
- sending `active` when it is started (the process starts as active, but it is possible that the new notification arrives to the same process, after the previous one sent background/suspension)
|
||||
- sending `background + suspended` (or `suspended` should set both flags) once it is finished processing the notification, provided no new notification arrived and started processing - this should be tracked in NSE.
|
||||
|
||||
For problem 3, NSE can do one of the following:
|
||||
|
||||
- use SUB and process messages normally - the downside is that the app will have to resubscribe and it has to be tracked.
|
||||
- use GET and process messages by pushing them through the processing function - the downside it that a rewiring of message processing is needed.
|
||||
- use GET but deliver messages through the same queue as when they arrive normally (in which case getMessage agent function should not return the message, but will return a flag showing whether the message was received, or, possibly, or, possibly will return a message but the message would also be sent to the queue?).
|
||||
- process messages in agent manually and in chat via the queue.
|
||||
|
||||
One of the downside of GET is that it requires calling GET again after ACK. We could have two variants of ACK (or additional ACK parameter) - one that never delivers a new message, and another one that does. In this case, if get needs to process the next message (when the current one has no notification flag), it can call ACK that delivers the next message. But, it is probably a premature optimization, and having general support of batched commands would add more value.
|
||||
|
||||
Additional problem is concurrency in NSE - if the new notification from the same queue arrives before the current one finishes processing in the same process one of the following can happen:
|
||||
|
||||
- the 2nd notification naively call GET and receives the same message.
|
||||
- the 2nd notification waits until the first finished processing, in which case it can run out of time.
|
||||
|
||||
The problem is that the app won't know it's the same queue, as nId is encrypted, so the agent should handle this scenario when the new call to getMessage is made before the previous one finished processing, and differentiate between calls made for additional messages (possibly, getMessage should include previous message ID, if it is available) and the first call.
|
||||
|
||||
EDIT: GETs have to be sent from UI to chat and from chat to agent as function calls, but the agent will have to queue get calls to make sure they return different messages. GET call would return message flags (incl. notification flag), so that the UI can send the next GET if needed without waiting.
|
||||
|
||||
Considered alternative: include notification content in the message and have NSE only perform decryption, without any network IO. In this case notification content would be in SEND and in NMSG, e2e encrypted.
|
||||
|
||||
While promising, as it solves network coordination issues and makes GET unnecessary, it creates mutliple other problems, so it was rejected:
|
||||
|
||||
- message content is exposed to centralized ntf and apns servers, creating additional attack vector.
|
||||
- it adds complexity in security critical parts of the stack - double ratchet encryption, as it requires either storing message keys and using different IVs for notifications, or initializing completely separate ratchet for notifications content.
|
||||
- it reduces the size of the message.
|
||||
- it makes user experience worse, as:
|
||||
- it would not accelerate handshake for new contacts and for file delivery - this approach only works for content messages.
|
||||
- it would open the app without the new messages - the users would have to wait until the messages are received. It is also bad for "security optics" - the users might think that the message content was exposed to notifications.
|
|
@ -0,0 +1,61 @@
|
|||
sequenceDiagram
|
||||
participant M as iOS message<br>notification
|
||||
participant S as iOS system
|
||||
participant N as iOS NSE
|
||||
participant U as iOS UI
|
||||
participant C as Core chat
|
||||
participant A as Core agent
|
||||
|
||||
M ->> N: notification
|
||||
S ->> N: get app pref
|
||||
note over N: ignore,<br>app is active
|
||||
|
||||
note over M, A: app going to background
|
||||
S ->> U: phase: background<br>(possibly, "will" method)
|
||||
U ->> S: set app pref "pausing"
|
||||
U ->> C: /_app phase paused, result CRCmdOk
|
||||
C ->> A: pauseAgent<br>(no new network IO)
|
||||
M ->> N: notification
|
||||
S ->> N: get app pref
|
||||
note over N: wait/poll for<br>"paused"/"suspending"/"suspended"<br>event/pref
|
||||
A ->> C: event "IO paused"<br>(after in-flight op completed)<br>PHASE PAUSED
|
||||
C ->> U: event "IO paused" (CRAppPaused)
|
||||
U ->> S: set shared pref "paused"
|
||||
|
||||
note over M, A: process notification
|
||||
M ->> N: notification
|
||||
S ->> N: get app pref<br>continue if<br>"paused"/"suspending"/"suspended"
|
||||
N ->> S: set NSE pref "active"
|
||||
N ->> C: /_get message
|
||||
C ->> A: getMessage
|
||||
A ->> C: msg flags
|
||||
C ->> N: msg flags
|
||||
note over N: get messages<br>until notification flag set
|
||||
A ->> C: MSG/CONF/INFO
|
||||
C ->> N: some event
|
||||
N ->> S: set NSE pref "completed"
|
||||
N ->> S: show notification
|
||||
|
||||
note over M, A: app about to be suspended<br>(or 15-20 sec after background)
|
||||
S ->> U: background task notice
|
||||
U ->> S: set app pref "suspending"
|
||||
U ->> C: /_app phase suspended, response ok
|
||||
C ->> A: suspendAgent<br>(no new DB)
|
||||
A ->> C: event "DB paused"<br>(after in-flight op completed)<br>PHASE SUSPENDED
|
||||
C ->> U: event "DB paused" (CRAppSuspended)
|
||||
U ->> S: set app pref "suspended"
|
||||
|
||||
note over M, A: app about to be activated
|
||||
S ->> U: phase: active<br>(or inactive?)<br><br>(possibly, "will" method)
|
||||
S ->> U: get NSE pref
|
||||
U ->> S: set app pref "activating"
|
||||
alt nse active?
|
||||
U ->> C: /_app phase inactive
|
||||
note over U: poll/wait till NSE pref is "completed"
|
||||
end
|
||||
|
||||
U ->> C: /_app phase active (response result)
|
||||
C ->> A: activateAgent<br>(allow IO/DB)
|
||||
A ->> C: result ()
|
||||
C ->> U: CRCmdOk
|
||||
U ->> S: set app pref "active"
|
|
@ -0,0 +1,91 @@
|
|||
# Notification server
|
||||
|
||||
## Background and motivation
|
||||
|
||||
SimpleX Chat clients should receive message notifications when not being online and/or subscribed to SMP servers.
|
||||
|
||||
To avoid revealing identities of clients directly to SMP servers via any kind of push notification tokens, a new party called SimpleX Notification Server is introduced to act as a service for subscribing to SMP server queue notifications on behalf of clients and sending push notifications to them.
|
||||
|
||||
## Proposal
|
||||
|
||||
TCP service using the same TLS transport as SMP server, with the fixed size blocks (256 bytes?) and the following set of commands:
|
||||
|
||||
### Protocol
|
||||
|
||||
#### Create subscription
|
||||
|
||||
Command:
|
||||
|
||||
`%s"CREATE " ntfSmpQueueURI ntfPrivateKey token subPublicKey`
|
||||
|
||||
Response:
|
||||
|
||||
`s%"OK"`
|
||||
|
||||
#### Check subscription status
|
||||
|
||||
Command:
|
||||
|
||||
`%s"CHECK " ntfSmpQueueURI`
|
||||
|
||||
Response:
|
||||
|
||||
```abnf
|
||||
statusResp = %s"STAT " status
|
||||
status = %s"ERR AUTH" / "ERR SMP AUTH" / %s"ERR SMP TIMEOUT" / %s"ACTIVE" / %s"PENDING"
|
||||
```
|
||||
|
||||
#### Update subscription device token
|
||||
|
||||
Command:
|
||||
|
||||
`%s"TOKEN " ntfSmpQueueURI token`
|
||||
|
||||
Response:
|
||||
|
||||
`s%"OK" / %s"ERR"`
|
||||
|
||||
#### Delete subscription (e.g. when deleting the queue or moving to another notification server)
|
||||
|
||||
Command:
|
||||
|
||||
`%s"DELETE " SP ntfSmpQueueURI`
|
||||
|
||||
Response:
|
||||
|
||||
`s%"OK" / %s"ERR"`
|
||||
|
||||
### Agent schema changes
|
||||
|
||||
See [migration](../src/Simplex/Messaging/Agent/Store/SQLite/Migrations/M20220322_notifications.hs)
|
||||
|
||||
### Agent code
|
||||
|
||||
```haskell
|
||||
data NtfOptions = NtfOptions
|
||||
{ ntfServer :: Server, -- same type as for SMP servers, probably will be renamed
|
||||
ntfToken :: ByteString,
|
||||
ntfInitialCheckDelay :: Int, -- initial check delay after subscription is created, seconds
|
||||
ntfPeriodicCheckInterval :: Int -- subscription check interval, seconds
|
||||
}
|
||||
|
||||
data AgentConfig = AgentConfig {
|
||||
-- ...
|
||||
initialNtfOpts :: Maybe NtfOptions
|
||||
-- ...
|
||||
}
|
||||
|
||||
data AgentClient = AgentClient {
|
||||
-- ...
|
||||
ntfOpts :: TVar (Maybe NtfOptions)
|
||||
-- ...
|
||||
}
|
||||
```
|
||||
|
||||
A configuration parameter `initialNtfOpts :: Maybe NtfOptions` - if it is set or changes the agent would automatically manage subscriptions as SMP queues are subscribed/created/deleted and as the token or server changes.
|
||||
|
||||
There will be a method to update notifications configuration in case token or server changes.
|
||||
|
||||
All subscriptions will be managed in a separate subscription management loop, that would always take the earliest un-updated subscription that requires some action (ntf_sub_action column) and perform this action - the table of subscription would serve both as the table of existing subscriptions and required actions.
|
||||
|
||||
E.g. if the queue is subscribed and there is no notification subscription, it will be created in the table with "create" action, and the loop would create it and schedule "check" action on it.
|
|
@ -0,0 +1,34 @@
|
|||
# SMP protocol changes to support push notifications on iOS
|
||||
|
||||
## Problem
|
||||
|
||||
There are already commands/responses to allow subscriptions to message notifications - NKEY/NID, NSUB/NMSG. These commands will be used by SMP agent (NKEY/NID) and by notification server (NSUB/NMSG) to have message notifications delivered to notification server, so it can forward them to APNS server using device token.
|
||||
|
||||
There are two remaining problems that these commands do not solve.
|
||||
|
||||
1. Receiving the message when notification arrives.
|
||||
|
||||
iOS requires creating a bundled notification service extension (NSE) that runs in isolated container and, if we were to use the existing commands, would have SMP subscription to the same SMP servers as the main app, triggering resubscriptions every time the message reception switches between the app and NSE. That would cause a substantial increase in traffic and battery consumption to the users.
|
||||
|
||||
2. Showing notifications for service messages.
|
||||
|
||||
Users do not expect to see notifications for every single SMP messages - e.g., we currently do not show notifications when messages are edited and deleted, and users do not expect them. NSE requires that for every received push notification there should be some notification shown to the users. So only we would have to show a notification for message deletes and updates, we would have to show it for all service messages - e.g. user accepted file invitation, or file transmission has started, contact profile updates and so on.
|
||||
|
||||
We considered differentiating whether notifications are sent per queue, from the recipient side, so we do not send notifications for file queues. But it seems insufficient, particularly if we add such features as message receipts, status updates, etc.
|
||||
|
||||
## Proposal
|
||||
|
||||
1. To retrieve messages when push notifications arrive, we will add 2 SMP commands:
|
||||
|
||||
- GET: retrieve one message from the connection. Resonse could be either MSG (the same as when MSG is delivered, but with the correlation id) or GMSG (to simplify processing) – TBC during implementation. If message is not available, the response could be ERR NO_MSG
|
||||
- ACK or GACK: acknowledge that the message was processed by the client and can be removed - TBC which one is better. The response is OK or ERR NO_MSG if there was nothing to acknowledge (same as with ACK now)
|
||||
|
||||
This would allow receiving a single message from the queue without subscription, this way avoiding that the main app is unsubscribed from the queue.
|
||||
|
||||
2. The only way to avoid showing unnecessary notifications (status updates, service messages, etc.) is to avoid sending them. That requires instructing SMP server whether notification should be sent both per queue, from the recipient side, and per message - from the sender side. So the notification would only be sent if the queue has them enabled (via NKEY command) and the sender includes an additional flag in SEND command. The same flag should be included into MSG, so when the message is retrieved with GET command, the client knows, on the agent or chat level (or both), whether this message should have notification shown to the user, and if not - retrieve the next one(s) as well.
|
||||
|
||||
This is a substantial change to SMP protocol, that would require client and server upgrade for notifications to be supported.
|
||||
|
||||
We should consider whether to increase the SMP protocol version number to 2, so that the new clients can connect to the old clients but without notifications, or we could keep the old commands in the protocol and instead of adding flags to the existing commands, create new commands.
|
||||
|
||||
We can also consider making commands extensible so that the new flags can be added (and ignored by parsers if not supported) to at least some existing commands.
|
|
@ -0,0 +1,181 @@
|
|||
#!/bin/bash
|
||||
|
||||
# <UDF name="api_token" label="Linode API token - enable Linode to create tags with server address, fingerprint and version. Note: minimal permissions token should have are read/write access to `linodes` (to create tags) and `domains` (to add A record for the third level domain if FQDN is provided)." default="" />
|
||||
# TODO review
|
||||
# <UDF name="fqdn" label="FQDN (Fully Qualified Domain Name) - provide third level domain name (e.g. smp.example.com). If provided use `smp://fingerprint@FQDN` as server address in the client. If FQDN is not provided use `smp://fingerprint@IP` instead." default="" />
|
||||
# <UDF name="apns_key_id" label="APNS key ID." default="" />
|
||||
|
||||
# Log all stdout output to stackscript.log
|
||||
exec &> >(tee -i /var/log/stackscript.log)
|
||||
|
||||
# Uncomment next line to enable debugging features
|
||||
# set -xeo pipefail
|
||||
|
||||
cd $HOME
|
||||
|
||||
# https://superuser.com/questions/1638779/automatic-yess-to-linux-update-upgrade
|
||||
# https://superuser.com/questions/1412054/non-interactive-apt-upgrade
|
||||
sudo DEBIAN_FRONTEND=noninteractive \
|
||||
apt-get \
|
||||
-o Dpkg::Options::=--force-confold \
|
||||
-o Dpkg::Options::=--force-confdef \
|
||||
-y --allow-downgrades --allow-remove-essential --allow-change-held-packages \
|
||||
update
|
||||
|
||||
sudo DEBIAN_FRONTEND=noninteractive \
|
||||
apt-get \
|
||||
-o Dpkg::Options::=--force-confold \
|
||||
-o Dpkg::Options::=--force-confdef \
|
||||
-y --allow-downgrades --allow-remove-essential --allow-change-held-packages \
|
||||
dist-upgrade
|
||||
|
||||
# TODO install unattended-upgrades
|
||||
sudo DEBIAN_FRONTEND=noninteractive \
|
||||
apt-get \
|
||||
-o Dpkg::Options::=--force-confold \
|
||||
-o Dpkg::Options::=--force-confdef \
|
||||
-y --allow-downgrades --allow-remove-essential --allow-change-held-packages \
|
||||
install jq
|
||||
|
||||
# Add firewall
|
||||
echo "y" | ufw enable
|
||||
|
||||
# Open ports
|
||||
ufw allow ssh
|
||||
ufw allow https
|
||||
ufw allow 5223
|
||||
|
||||
# Increase file descriptors limit
|
||||
echo 'fs.file-max = 1000000' >> /etc/sysctl.conf
|
||||
echo 'fs.inode-max = 1000000' >> /etc/sysctl.conf
|
||||
echo 'root soft nofile unlimited' >> /etc/security/limits.conf
|
||||
echo 'root hard nofile unlimited' >> /etc/security/limits.conf
|
||||
|
||||
# Download latest release
|
||||
bin_dir="/opt/simplex-notifications/bin"
|
||||
binary="$bin_dir/ntf-server"
|
||||
mkdir -p $bin_dir
|
||||
curl -L -o $binary https://github.com/simplex-chat/simplexmq/releases/latest/download/ntf-server-ubuntu-20_04-x86-64
|
||||
chmod +x $binary
|
||||
|
||||
# / Add to PATH
|
||||
cat > /etc/profile.d/simplex.sh << EOF
|
||||
#!/bin/bash
|
||||
|
||||
export PATH="$PATH:$bin_dir"
|
||||
|
||||
EOF
|
||||
# Add to PATH /
|
||||
|
||||
# Source and test PATH
|
||||
source /etc/profile.d/simplex.sh
|
||||
ntf-server --version
|
||||
|
||||
# Initialize server
|
||||
init_opts=()
|
||||
|
||||
ip_address=$(curl ifconfig.me)
|
||||
init_opts+=(--ip $ip_address)
|
||||
|
||||
[[ -n "$FQDN" ]] && init_opts+=(-n $FQDN)
|
||||
|
||||
ntf-server init "${init_opts[@]}"
|
||||
|
||||
# Server fingerprint
|
||||
fingerprint=$(cat /etc/opt/simplex-notifications/fingerprint)
|
||||
|
||||
# Determine server address to specify in welcome script and Linode tag
|
||||
# ! If FQDN was provided and used as part of server initialization, server's certificate will not pass validation at client
|
||||
# ! if client tries to connect by server's IP address, so we have to specify FQDN as server address in Linode tag and
|
||||
# ! in welcome script regardless of creation of A record in Linode
|
||||
# ! https://hackage.haskell.org/package/x509-validation-1.6.10/docs/src/Data-X509-Validation.html#validateCertificateName
|
||||
if [[ -n "$FQDN" ]]; then
|
||||
server_address=$FQDN
|
||||
else
|
||||
server_address=$ip_address
|
||||
fi
|
||||
|
||||
# Set up welcome script
|
||||
on_login_script="/opt/simplex-notifications/on_login.sh"
|
||||
|
||||
# TODO fix address
|
||||
# / Welcome script
|
||||
cat > $on_login_script << EOF
|
||||
#!/bin/bash
|
||||
|
||||
fingerprint=\$1
|
||||
server_address=\$2
|
||||
|
||||
cat << EOF2
|
||||
********************************************************************************
|
||||
|
||||
SimpleX notifications server address: smp://\$fingerprint@\$server_address
|
||||
Check server status with: systemctl status ntf-server
|
||||
|
||||
To keep this server secure, the UFW firewall is enabled.
|
||||
All ports are BLOCKED except 22 (SSH), 443 (HTTPS), 5223 (notifications server).
|
||||
|
||||
********************************************************************************
|
||||
To stop seeing this message delete line - bash /opt/simplex-notifications/on_login.sh - from /root/.bashrc
|
||||
EOF2
|
||||
|
||||
EOF
|
||||
# Welcome script /
|
||||
|
||||
chmod +x $on_login_script
|
||||
echo "bash $on_login_script $fingerprint $server_address" >> /root/.bashrc
|
||||
|
||||
# Create A record and update Linode's tags
|
||||
if [[ -n "$API_TOKEN" ]]; then
|
||||
if [[ -n "$FQDN" ]]; then
|
||||
domain_address=$(echo $FQDN | rev | cut -d "." -f 1,2 | rev)
|
||||
domain_id=$(curl -H "Authorization: Bearer $API_TOKEN" https://api.linode.com/v4/domains \
|
||||
| jq --arg da "$domain_address" '.data[] | select( .domain == $da ) | .id')
|
||||
if [[ -n $domain_id ]]; then
|
||||
curl \
|
||||
-s -H "Content-Type: application/json" \
|
||||
-H "Authorization: Bearer $API_TOKEN" \
|
||||
-X POST -d "{\"type\":\"A\",\"name\":\"$FQDN\",\"target\":\"$ip_address\"}" \
|
||||
https://api.linode.com/v4/domains/${domain_id}/records
|
||||
fi
|
||||
fi
|
||||
|
||||
version=$(ntf-server --version | cut -d ' ' -f 3-)
|
||||
|
||||
curl \
|
||||
-s -H "Content-Type: application/json" \
|
||||
-H "Authorization: Bearer $API_TOKEN" \
|
||||
-X PUT -d "{\"tags\":[\"$server_address\",\"$fingerprint\",\"$version\"]}" \
|
||||
https://api.linode.com/v4/linode/instances/$LINODE_ID
|
||||
fi
|
||||
|
||||
# / Create systemd service
|
||||
cat > /etc/systemd/system/ntf-server.service << EOF
|
||||
[Unit]
|
||||
Description=SimpleX notifications server
|
||||
|
||||
[Service]
|
||||
Environment="APNS_KEY_FILE=/etc/opt/simplex-notifications/AuthKey.p8"
|
||||
Environment="APNS_KEY_ID=$APNS_KEY_ID"
|
||||
Type=simple
|
||||
ExecStart=/bin/sh -c "exec $binary start >> /var/opt/simplex-notifications/ntf-server.log 2>&1"
|
||||
KillSignal=SIGINT
|
||||
Restart=always
|
||||
RestartSec=10
|
||||
LimitNOFILE=1000000
|
||||
LimitNOFILESoft=1000000
|
||||
|
||||
[Install]
|
||||
WantedBy=multi-user.target
|
||||
|
||||
EOF
|
||||
# Create systemd service /
|
||||
|
||||
# Start systemd service
|
||||
chmod 644 /etc/systemd/system/ntf-server.service
|
||||
sudo systemctl enable ntf-server
|
||||
# ! APNS key file and certificate have to be copied manually
|
||||
# sudo systemctl start ntf-server
|
||||
|
||||
# Reboot Linode to apply upgrades
|
||||
# sudo reboot
|
|
@ -157,6 +157,8 @@ Description=SMP server
|
|||
[Service]
|
||||
Type=simple
|
||||
ExecStart=/bin/sh -c "exec $binary start >> /var/opt/simplex/smp-server.log 2>&1"
|
||||
KillSignal=SIGINT
|
||||
TimeoutStopSec=infinity
|
||||
Restart=always
|
||||
RestartSec=10
|
||||
LimitNOFILE=1000000
|
||||
|
|
|
@ -0,0 +1,12 @@
|
|||
#!/bin/bash
|
||||
|
||||
# systemd has to be configured to use SIGINT to save and restore undelivered messages after restart.
|
||||
# Add this to [Service] section:
|
||||
# KillSignal=SIGINT
|
||||
curl -L -o /opt/simplex/bin/smp-server-new https://github.com/simplex-chat/simplexmq/releases/latest/download/smp-server-ubuntu-20_04-x86-64
|
||||
systemctl stop smp-server
|
||||
cp /var/opt/simplex/smp-server-store.log /var/opt/simplex/smp-server-store.log.bak
|
||||
mv /opt/simplex/bin/smp-server /opt/simplex/bin/smp-server-old
|
||||
mv /opt/simplex/bin/smp-server-new /opt/simplex/bin/smp-server
|
||||
chmod +x /opt/simplex/bin/smp-server
|
||||
systemctl start smp-server
|
128
simplexmq.cabal
128
simplexmq.cabal
|
@ -5,7 +5,7 @@ cabal-version: 1.12
|
|||
-- see: https://github.com/sol/hpack
|
||||
|
||||
name: simplexmq
|
||||
version: 1.1.0
|
||||
version: 3.0.0
|
||||
synopsis: SimpleXMQ message broker
|
||||
description: This package includes <./docs/Simplex-Messaging-Server.html server>,
|
||||
<./docs/Simplex-Messaging-Client.html client> and
|
||||
|
@ -37,6 +37,7 @@ library
|
|||
Simplex.Messaging.Agent
|
||||
Simplex.Messaging.Agent.Client
|
||||
Simplex.Messaging.Agent.Env.SQLite
|
||||
Simplex.Messaging.Agent.NtfSubSupervisor
|
||||
Simplex.Messaging.Agent.Protocol
|
||||
Simplex.Messaging.Agent.QueryString
|
||||
Simplex.Messaging.Agent.RetryInterval
|
||||
|
@ -46,23 +47,42 @@ library
|
|||
Simplex.Messaging.Agent.Store.SQLite.Migrations
|
||||
Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220101_initial
|
||||
Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220301_snd_queue_keys
|
||||
Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220322_notifications
|
||||
Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220608_v2
|
||||
Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220625_v2_ntf_mode
|
||||
Simplex.Messaging.Client
|
||||
Simplex.Messaging.Client.Agent
|
||||
Simplex.Messaging.Crypto
|
||||
Simplex.Messaging.Crypto.Ratchet
|
||||
Simplex.Messaging.Encoding
|
||||
Simplex.Messaging.Encoding.String
|
||||
Simplex.Messaging.Notifications.Client
|
||||
Simplex.Messaging.Notifications.Protocol
|
||||
Simplex.Messaging.Notifications.Server
|
||||
Simplex.Messaging.Notifications.Server.Env
|
||||
Simplex.Messaging.Notifications.Server.Push.APNS
|
||||
Simplex.Messaging.Notifications.Server.Store
|
||||
Simplex.Messaging.Notifications.Server.StoreLog
|
||||
Simplex.Messaging.Notifications.Transport
|
||||
Simplex.Messaging.Notifications.Types
|
||||
Simplex.Messaging.Parsers
|
||||
Simplex.Messaging.Protocol
|
||||
Simplex.Messaging.Server
|
||||
Simplex.Messaging.Server.CLI
|
||||
Simplex.Messaging.Server.Env.STM
|
||||
Simplex.Messaging.Server.Expiration
|
||||
Simplex.Messaging.Server.MsgStore
|
||||
Simplex.Messaging.Server.MsgStore.STM
|
||||
Simplex.Messaging.Server.QueueStore
|
||||
Simplex.Messaging.Server.QueueStore.STM
|
||||
Simplex.Messaging.Server.Stats
|
||||
Simplex.Messaging.Server.StoreLog
|
||||
Simplex.Messaging.TMap
|
||||
Simplex.Messaging.Transport
|
||||
Simplex.Messaging.Transport.Client
|
||||
Simplex.Messaging.Transport.HTTP2
|
||||
Simplex.Messaging.Transport.HTTP2.Client
|
||||
Simplex.Messaging.Transport.HTTP2.Server
|
||||
Simplex.Messaging.Transport.KeepAlive
|
||||
Simplex.Messaging.Transport.Server
|
||||
Simplex.Messaging.Transport.WebSockets
|
||||
|
@ -81,9 +101,10 @@ library
|
|||
, asn1-types ==0.3.*
|
||||
, async ==2.2.*
|
||||
, attoparsec ==0.14.*
|
||||
, base >=4.7 && <5
|
||||
, base >=4.14 && <5
|
||||
, base64-bytestring >=1.0 && <1.3
|
||||
, bytestring ==0.10.*
|
||||
, case-insensitive ==1.2.*
|
||||
, composition ==1.0.*
|
||||
, constraints >=0.12 && <0.14
|
||||
, containers ==0.6.*
|
||||
|
@ -95,11 +116,15 @@ library
|
|||
, filepath ==1.4.*
|
||||
, generic-random >=1.3 && <1.5
|
||||
, http-types ==0.12.*
|
||||
, http2 ==3.0.*
|
||||
, ini ==0.4.*
|
||||
, iso8601-time ==0.1.*
|
||||
, memory ==0.15.*
|
||||
, mtl ==2.2.*
|
||||
, network ==3.1.2.*
|
||||
, network >=3.1.2.7 && <3.2
|
||||
, network-transport ==0.5.*
|
||||
, optparse-applicative >=0.15 && <0.17
|
||||
, process ==1.6.*
|
||||
, random >=1.1 && <1.3
|
||||
, simple-logger ==0.1.*
|
||||
, sqlite-simple ==0.4.*
|
||||
|
@ -107,6 +132,69 @@ library
|
|||
, template-haskell ==2.16.*
|
||||
, text ==1.2.*
|
||||
, time ==1.9.*
|
||||
, time-compat ==1.9.*
|
||||
, time-manager ==0.0.*
|
||||
, tls >=1.5.7 && <1.6
|
||||
, transformers ==0.5.*
|
||||
, unliftio ==0.2.*
|
||||
, unliftio-core ==0.2.*
|
||||
, websockets ==0.12.*
|
||||
, x509 ==1.7.*
|
||||
, x509-store ==1.6.*
|
||||
, x509-validation ==1.6.*
|
||||
if flag(swift)
|
||||
cpp-options: -DswiftJSON
|
||||
default-language: Haskell2010
|
||||
|
||||
executable ntf-server
|
||||
main-is: Main.hs
|
||||
other-modules:
|
||||
Paths_simplexmq
|
||||
hs-source-dirs:
|
||||
apps/ntf-server
|
||||
ghc-options: -Wall -Wcompat -Werror=incomplete-patterns -Wredundant-constraints -Wincomplete-record-updates -Wincomplete-uni-patterns -Wunused-type-patterns -threaded
|
||||
build-depends:
|
||||
QuickCheck ==2.14.*
|
||||
, aeson ==2.0.*
|
||||
, ansi-terminal >=0.10 && <0.12
|
||||
, asn1-encoding ==0.9.*
|
||||
, asn1-types ==0.3.*
|
||||
, async ==2.2.*
|
||||
, attoparsec ==0.14.*
|
||||
, base >=4.14 && <5
|
||||
, base64-bytestring >=1.0 && <1.3
|
||||
, bytestring ==0.10.*
|
||||
, case-insensitive ==1.2.*
|
||||
, composition ==1.0.*
|
||||
, constraints >=0.12 && <0.14
|
||||
, containers ==0.6.*
|
||||
, cryptonite >=0.27 && <0.30
|
||||
, cryptostore ==0.2.*
|
||||
, data-default ==0.7.*
|
||||
, direct-sqlite ==2.3.*
|
||||
, directory ==1.3.*
|
||||
, filepath ==1.4.*
|
||||
, generic-random >=1.3 && <1.5
|
||||
, http-types ==0.12.*
|
||||
, http2 ==3.0.*
|
||||
, ini ==0.4.*
|
||||
, iso8601-time ==0.1.*
|
||||
, memory ==0.15.*
|
||||
, mtl ==2.2.*
|
||||
, network >=3.1.2.7 && <3.2
|
||||
, network-transport ==0.5.*
|
||||
, optparse-applicative >=0.15 && <0.17
|
||||
, process ==1.6.*
|
||||
, random >=1.1 && <1.3
|
||||
, simple-logger ==0.1.*
|
||||
, simplexmq
|
||||
, sqlite-simple ==0.4.*
|
||||
, stm ==2.5.*
|
||||
, template-haskell ==2.16.*
|
||||
, text ==1.2.*
|
||||
, time ==1.9.*
|
||||
, time-compat ==1.9.*
|
||||
, time-manager ==0.0.*
|
||||
, tls >=1.5.7 && <1.6
|
||||
, transformers ==0.5.*
|
||||
, unliftio ==0.2.*
|
||||
|
@ -134,9 +222,10 @@ executable smp-agent
|
|||
, asn1-types ==0.3.*
|
||||
, async ==2.2.*
|
||||
, attoparsec ==0.14.*
|
||||
, base >=4.7 && <5
|
||||
, base >=4.14 && <5
|
||||
, base64-bytestring >=1.0 && <1.3
|
||||
, bytestring ==0.10.*
|
||||
, case-insensitive ==1.2.*
|
||||
, composition ==1.0.*
|
||||
, constraints >=0.12 && <0.14
|
||||
, containers ==0.6.*
|
||||
|
@ -148,11 +237,15 @@ executable smp-agent
|
|||
, filepath ==1.4.*
|
||||
, generic-random >=1.3 && <1.5
|
||||
, http-types ==0.12.*
|
||||
, http2 ==3.0.*
|
||||
, ini ==0.4.*
|
||||
, iso8601-time ==0.1.*
|
||||
, memory ==0.15.*
|
||||
, mtl ==2.2.*
|
||||
, network ==3.1.2.*
|
||||
, network >=3.1.2.7 && <3.2
|
||||
, network-transport ==0.5.*
|
||||
, optparse-applicative >=0.15 && <0.17
|
||||
, process ==1.6.*
|
||||
, random >=1.1 && <1.3
|
||||
, simple-logger ==0.1.*
|
||||
, simplexmq
|
||||
|
@ -161,6 +254,8 @@ executable smp-agent
|
|||
, template-haskell ==2.16.*
|
||||
, text ==1.2.*
|
||||
, time ==1.9.*
|
||||
, time-compat ==1.9.*
|
||||
, time-manager ==0.0.*
|
||||
, tls >=1.5.7 && <1.6
|
||||
, transformers ==0.5.*
|
||||
, unliftio ==0.2.*
|
||||
|
@ -188,9 +283,10 @@ executable smp-server
|
|||
, asn1-types ==0.3.*
|
||||
, async ==2.2.*
|
||||
, attoparsec ==0.14.*
|
||||
, base >=4.7 && <5
|
||||
, base >=4.14 && <5
|
||||
, base64-bytestring >=1.0 && <1.3
|
||||
, bytestring ==0.10.*
|
||||
, case-insensitive ==1.2.*
|
||||
, composition ==1.0.*
|
||||
, constraints >=0.12 && <0.14
|
||||
, containers ==0.6.*
|
||||
|
@ -202,11 +298,12 @@ executable smp-server
|
|||
, filepath ==1.4.*
|
||||
, generic-random >=1.3 && <1.5
|
||||
, http-types ==0.12.*
|
||||
, http2 ==3.0.*
|
||||
, ini ==0.4.*
|
||||
, iso8601-time ==0.1.*
|
||||
, memory ==0.15.*
|
||||
, mtl ==2.2.*
|
||||
, network ==3.1.2.*
|
||||
, network >=3.1.2.7 && <3.2
|
||||
, network-transport ==0.5.*
|
||||
, optparse-applicative >=0.15 && <0.17
|
||||
, process ==1.6.*
|
||||
|
@ -218,6 +315,8 @@ executable smp-server
|
|||
, template-haskell ==2.16.*
|
||||
, text ==1.2.*
|
||||
, time ==1.9.*
|
||||
, time-compat ==1.9.*
|
||||
, time-manager ==0.0.*
|
||||
, tls >=1.5.7 && <1.6
|
||||
, transformers ==0.5.*
|
||||
, unliftio ==0.2.*
|
||||
|
@ -238,10 +337,14 @@ test-suite smp-server-test
|
|||
AgentTests.ConnectionRequestTests
|
||||
AgentTests.DoubleRatchetTests
|
||||
AgentTests.FunctionalAPITests
|
||||
AgentTests.NotificationTests
|
||||
AgentTests.SchemaDump
|
||||
AgentTests.SQLiteTests
|
||||
CoreTests.EncodingTests
|
||||
CoreTests.ProtocolErrorTests
|
||||
CoreTests.VersionRangeTests
|
||||
NtfClient
|
||||
NtfServerTests
|
||||
ServerTests
|
||||
SMPAgentClient
|
||||
SMPClient
|
||||
|
@ -258,9 +361,10 @@ test-suite smp-server-test
|
|||
, asn1-types ==0.3.*
|
||||
, async ==2.2.*
|
||||
, attoparsec ==0.14.*
|
||||
, base >=4.7 && <5
|
||||
, base >=4.14 && <5
|
||||
, base64-bytestring >=1.0 && <1.3
|
||||
, bytestring ==0.10.*
|
||||
, case-insensitive ==1.2.*
|
||||
, composition ==1.0.*
|
||||
, constraints >=0.12 && <0.14
|
||||
, containers ==0.6.*
|
||||
|
@ -274,11 +378,15 @@ test-suite smp-server-test
|
|||
, hspec ==2.7.*
|
||||
, hspec-core ==2.7.*
|
||||
, http-types ==0.12.*
|
||||
, http2 ==3.0.*
|
||||
, ini ==0.4.*
|
||||
, iso8601-time ==0.1.*
|
||||
, memory ==0.15.*
|
||||
, mtl ==2.2.*
|
||||
, network ==3.1.2.*
|
||||
, network >=3.1.2.7 && <3.2
|
||||
, network-transport ==0.5.*
|
||||
, optparse-applicative >=0.15 && <0.17
|
||||
, process ==1.6.*
|
||||
, random >=1.1 && <1.3
|
||||
, simple-logger ==0.1.*
|
||||
, simplexmq
|
||||
|
@ -287,6 +395,8 @@ test-suite smp-server-test
|
|||
, template-haskell ==2.16.*
|
||||
, text ==1.2.*
|
||||
, time ==1.9.*
|
||||
, time-compat ==1.9.*
|
||||
, time-manager ==0.0.*
|
||||
, timeit ==2.0.*
|
||||
, tls >=1.5.7 && <1.6
|
||||
, transformers ==0.5.*
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -1,30 +1,46 @@
|
|||
{-# LANGUAGE AllowAmbiguousTypes #-}
|
||||
{-# LANGUAGE ConstraintKinds #-}
|
||||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE DuplicateRecordFields #-}
|
||||
{-# LANGUAGE FlexibleContexts #-}
|
||||
{-# LANGUAGE InstanceSigs #-}
|
||||
{-# LANGUAGE FlexibleInstances #-}
|
||||
{-# LANGUAGE GADTs #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
{-# LANGUAGE OverloadedLists #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE RankNTypes #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-# OPTIONS_GHC -fno-warn-orphans #-}
|
||||
{-# LANGUAGE TupleSections #-}
|
||||
{-# LANGUAGE TypeApplications #-}
|
||||
|
||||
module Simplex.Messaging.Agent.Client
|
||||
( AgentClient (..),
|
||||
newAgentClient,
|
||||
AgentMonad,
|
||||
withAgentLock,
|
||||
closeAgentClient,
|
||||
newRcvQueue,
|
||||
subscribeQueue,
|
||||
getQueueMessage,
|
||||
decryptSMPMessage,
|
||||
addSubscription,
|
||||
getSubscriptions,
|
||||
sendConfirmation,
|
||||
sendInvitation,
|
||||
RetryInterval (..),
|
||||
secureQueue,
|
||||
enableQueueNotifications,
|
||||
disableQueueNotifications,
|
||||
sendAgentMessage,
|
||||
agentNtfRegisterToken,
|
||||
agentNtfVerifyToken,
|
||||
agentNtfCheckToken,
|
||||
agentNtfReplaceToken,
|
||||
agentNtfDeleteToken,
|
||||
agentNtfEnableCron,
|
||||
agentNtfCreateSubscription,
|
||||
agentNtfCheckSubscription,
|
||||
agentNtfDeleteSubscription,
|
||||
agentCbEncrypt,
|
||||
agentCbDecrypt,
|
||||
cryptoError,
|
||||
|
@ -33,12 +49,27 @@ module Simplex.Messaging.Agent.Client
|
|||
deleteQueue,
|
||||
logServer,
|
||||
removeSubscription,
|
||||
hasActiveSubscription,
|
||||
agentDbPath,
|
||||
AgentOperation (..),
|
||||
AgentOpState (..),
|
||||
AgentState (..),
|
||||
agentOperations,
|
||||
agentOperationBracket,
|
||||
beginAgentOperation,
|
||||
endAgentOperation,
|
||||
suspendSendingAndDatabase,
|
||||
suspendOperation,
|
||||
notifySuspended,
|
||||
whenSuspending,
|
||||
withStore,
|
||||
withStore',
|
||||
)
|
||||
where
|
||||
|
||||
import Control.Concurrent (forkIO)
|
||||
import Control.Concurrent.Async (Async, uninterruptibleCancel)
|
||||
import Control.Concurrent.STM (stateTVar)
|
||||
import Control.Concurrent.STM (retry, stateTVar)
|
||||
import Control.Logger.Simple
|
||||
import Control.Monad.Except
|
||||
import Control.Monad.IO.Unlift
|
||||
|
@ -50,141 +81,147 @@ import qualified Data.ByteString.Char8 as B
|
|||
import Data.List.NonEmpty (NonEmpty)
|
||||
import Data.Map.Strict (Map)
|
||||
import qualified Data.Map.Strict as M
|
||||
import Data.Maybe (isNothing)
|
||||
import Data.Maybe (catMaybes)
|
||||
import Data.Set (Set)
|
||||
import Data.Text.Encoding
|
||||
import Data.Word (Word16)
|
||||
import qualified Database.SQLite.Simple as DB
|
||||
import Simplex.Messaging.Agent.Env.SQLite
|
||||
import Simplex.Messaging.Agent.Protocol
|
||||
import Simplex.Messaging.Agent.RetryInterval
|
||||
import Simplex.Messaging.Agent.Store
|
||||
import Simplex.Messaging.Agent.Store.SQLite (SQLiteStore (..), withTransaction)
|
||||
import Simplex.Messaging.Client
|
||||
import Simplex.Messaging.Client.Agent ()
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Encoding
|
||||
import Simplex.Messaging.Protocol (QueueId, QueueIdsKeys (..), SndPublicVerifyKey)
|
||||
import Simplex.Messaging.Notifications.Client
|
||||
import Simplex.Messaging.Notifications.Protocol
|
||||
import Simplex.Messaging.Notifications.Types
|
||||
import Simplex.Messaging.Parsers (parse)
|
||||
import Simplex.Messaging.Protocol (BrokerMsg, ErrorType, MsgFlags (..), MsgId, NotifierId, NtfPrivateSignKey, NtfPublicVerifyKey, NtfServer, ProtoServer, ProtocolServer (..), QueueId, QueueIdsKeys (..), RcvMessage (..), RcvNtfPublicDhKey, SMPMsgMeta (..), SndPublicVerifyKey)
|
||||
import qualified Simplex.Messaging.Protocol as SMP
|
||||
import Simplex.Messaging.TMap (TMap)
|
||||
import qualified Simplex.Messaging.TMap as TM
|
||||
import Simplex.Messaging.Util (bshow, liftEitherError, liftError, tryError, whenM)
|
||||
import Simplex.Messaging.Util (bshow, catchAll_, ifM, liftEitherError, liftError, tryError, unlessM, whenM)
|
||||
import Simplex.Messaging.Version
|
||||
import System.Timeout (timeout)
|
||||
import UnliftIO (async, forConcurrently_)
|
||||
import UnliftIO.Exception (Exception, IOException)
|
||||
import UnliftIO (async, pooledForConcurrentlyN)
|
||||
import qualified UnliftIO.Exception as E
|
||||
import UnliftIO.STM
|
||||
|
||||
type ClientVar msg = TMVar (Either AgentErrorType (ProtocolClient msg))
|
||||
|
||||
type SMPClientVar = TMVar (Either AgentErrorType SMPClient)
|
||||
|
||||
type NtfClientVar = TMVar (Either AgentErrorType NtfClient)
|
||||
|
||||
data AgentClient = AgentClient
|
||||
{ rcvQ :: TBQueue (ATransmission 'Client),
|
||||
{ active :: TVar Bool,
|
||||
rcvQ :: TBQueue (ATransmission 'Client),
|
||||
subQ :: TBQueue (ATransmission 'Agent),
|
||||
msgQ :: TBQueue SMPServerTransmission,
|
||||
msgQ :: TBQueue (ServerTransmission BrokerMsg),
|
||||
smpServers :: TVar (NonEmpty SMPServer),
|
||||
smpClients :: TMap SMPServer SMPClientVar,
|
||||
ntfServers :: TVar [NtfServer],
|
||||
ntfClients :: TMap NtfServer NtfClientVar,
|
||||
subscrSrvrs :: TMap SMPServer (TMap ConnId RcvQueue),
|
||||
pendingSubscrSrvrs :: TMap SMPServer (TMap ConnId RcvQueue),
|
||||
subscrConns :: TMap ConnId SMPServer,
|
||||
connMsgsQueued :: TMap ConnId Bool,
|
||||
smpQueueMsgQueues :: TMap (ConnId, SMPServer, SMP.SenderId) (TQueue InternalId),
|
||||
smpQueueMsgDeliveries :: TMap (ConnId, SMPServer, SMP.SenderId) (Async ()),
|
||||
ntfNetworkOp :: TVar AgentOpState,
|
||||
rcvNetworkOp :: TVar AgentOpState,
|
||||
msgDeliveryOp :: TVar AgentOpState,
|
||||
sndNetworkOp :: TVar AgentOpState,
|
||||
databaseOp :: TVar AgentOpState,
|
||||
agentState :: TVar AgentState,
|
||||
getMsgLocks :: TMap (SMPServer, SMP.RecipientId) (TMVar ()),
|
||||
reconnections :: TVar [Async ()],
|
||||
asyncClients :: TVar [Async ()],
|
||||
clientId :: Int,
|
||||
agentEnv :: Env,
|
||||
smpSubscriber :: Async (),
|
||||
lock :: TMVar ()
|
||||
}
|
||||
|
||||
newAgentClient :: Env -> STM AgentClient
|
||||
newAgentClient agentEnv = do
|
||||
data AgentOperation = AONtfNetwork | AORcvNetwork | AOMsgDelivery | AOSndNetwork | AODatabase
|
||||
deriving (Eq, Show)
|
||||
|
||||
agentOpSel :: AgentOperation -> (AgentClient -> TVar AgentOpState)
|
||||
agentOpSel = \case
|
||||
AONtfNetwork -> ntfNetworkOp
|
||||
AORcvNetwork -> rcvNetworkOp
|
||||
AOMsgDelivery -> msgDeliveryOp
|
||||
AOSndNetwork -> sndNetworkOp
|
||||
AODatabase -> databaseOp
|
||||
|
||||
agentOperations :: [AgentClient -> TVar AgentOpState]
|
||||
agentOperations = [ntfNetworkOp, rcvNetworkOp, msgDeliveryOp, sndNetworkOp, databaseOp]
|
||||
|
||||
data AgentOpState = AgentOpState {opSuspended :: Bool, opsInProgress :: Int}
|
||||
|
||||
data AgentState = ASActive | ASSuspending | ASSuspended
|
||||
deriving (Eq, Show)
|
||||
|
||||
newAgentClient :: InitialAgentServers -> Env -> STM AgentClient
|
||||
newAgentClient InitialAgentServers {smp, ntf} agentEnv = do
|
||||
let qSize = tbqSize $ config agentEnv
|
||||
active <- newTVar True
|
||||
rcvQ <- newTBQueue qSize
|
||||
subQ <- newTBQueue qSize
|
||||
msgQ <- newTBQueue qSize
|
||||
smpServers <- newTVar $ initialSMPServers (config agentEnv)
|
||||
smpServers <- newTVar smp
|
||||
smpClients <- TM.empty
|
||||
ntfServers <- newTVar ntf
|
||||
ntfClients <- TM.empty
|
||||
subscrSrvrs <- TM.empty
|
||||
pendingSubscrSrvrs <- TM.empty
|
||||
subscrConns <- TM.empty
|
||||
connMsgsQueued <- TM.empty
|
||||
smpQueueMsgQueues <- TM.empty
|
||||
smpQueueMsgDeliveries <- TM.empty
|
||||
ntfNetworkOp <- newTVar $ AgentOpState False 0
|
||||
rcvNetworkOp <- newTVar $ AgentOpState False 0
|
||||
msgDeliveryOp <- newTVar $ AgentOpState False 0
|
||||
sndNetworkOp <- newTVar $ AgentOpState False 0
|
||||
databaseOp <- newTVar $ AgentOpState False 0
|
||||
agentState <- newTVar ASActive
|
||||
getMsgLocks <- TM.empty
|
||||
reconnections <- newTVar []
|
||||
asyncClients <- newTVar []
|
||||
clientId <- stateTVar (clientCounter agentEnv) $ \i -> (i + 1, i + 1)
|
||||
clientId <- stateTVar (clientCounter agentEnv) $ \i -> let i' = i + 1 in (i', i')
|
||||
lock <- newTMVar ()
|
||||
return AgentClient {rcvQ, subQ, msgQ, smpServers, smpClients, subscrSrvrs, pendingSubscrSrvrs, subscrConns, connMsgsQueued, smpQueueMsgQueues, smpQueueMsgDeliveries, reconnections, asyncClients, clientId, agentEnv, smpSubscriber = undefined, lock}
|
||||
return AgentClient {active, rcvQ, subQ, msgQ, smpServers, smpClients, ntfServers, ntfClients, subscrSrvrs, pendingSubscrSrvrs, subscrConns, connMsgsQueued, smpQueueMsgQueues, smpQueueMsgDeliveries, ntfNetworkOp, rcvNetworkOp, msgDeliveryOp, sndNetworkOp, databaseOp, agentState, getMsgLocks, reconnections, asyncClients, clientId, agentEnv, lock}
|
||||
|
||||
-- | Agent monad with MonadReader Env and MonadError AgentErrorType
|
||||
type AgentMonad m = (MonadUnliftIO m, MonadReader Env m, MonadError AgentErrorType m)
|
||||
agentDbPath :: AgentClient -> FilePath
|
||||
agentDbPath AgentClient {agentEnv = Env {store = SQLiteStore {dbFilePath}}} = dbFilePath
|
||||
|
||||
newtype InternalException e = InternalException {unInternalException :: e}
|
||||
deriving (Eq, Show)
|
||||
class ProtocolServerClient msg where
|
||||
getProtocolServerClient :: AgentMonad m => AgentClient -> ProtoServer msg -> m (ProtocolClient msg)
|
||||
clientProtocolError :: ErrorType -> AgentErrorType
|
||||
|
||||
instance Exception e => Exception (InternalException e)
|
||||
instance ProtocolServerClient BrokerMsg where
|
||||
getProtocolServerClient = getSMPServerClient
|
||||
clientProtocolError = SMP
|
||||
|
||||
instance (MonadUnliftIO m, Exception e) => MonadUnliftIO (ExceptT e m) where
|
||||
withRunInIO :: ((forall a. ExceptT e m a -> IO a) -> IO b) -> ExceptT e m b
|
||||
withRunInIO exceptToIO =
|
||||
withExceptT unInternalException . ExceptT . E.try $
|
||||
withRunInIO $ \run ->
|
||||
exceptToIO $ run . (either (E.throwIO . InternalException) return <=< runExceptT)
|
||||
instance ProtocolServerClient NtfResponse where
|
||||
getProtocolServerClient = getNtfServerClient
|
||||
clientProtocolError = NTF
|
||||
|
||||
getSMPServerClient :: forall m. AgentMonad m => AgentClient -> SMPServer -> m SMPClient
|
||||
getSMPServerClient c@AgentClient {smpClients, msgQ} srv =
|
||||
atomically getClientVar >>= either newSMPClient waitForSMPClient
|
||||
getSMPServerClient c@AgentClient {active, smpClients, msgQ} srv = do
|
||||
unlessM (readTVarIO active) . throwError $ INTERNAL "agent is stopped"
|
||||
atomically (getClientVar srv smpClients)
|
||||
>>= either
|
||||
(newProtocolClient c srv smpClients connectClient reconnectClient)
|
||||
(waitForProtocolClient smpCfg)
|
||||
where
|
||||
getClientVar :: STM (Either SMPClientVar SMPClientVar)
|
||||
getClientVar = maybe (Left <$> newClientVar) (pure . Right) =<< TM.lookup srv smpClients
|
||||
|
||||
newClientVar :: STM SMPClientVar
|
||||
newClientVar = do
|
||||
smpVar <- newEmptyTMVar
|
||||
TM.insert srv smpVar smpClients
|
||||
pure smpVar
|
||||
|
||||
waitForSMPClient :: TMVar (Either AgentErrorType SMPClient) -> m SMPClient
|
||||
waitForSMPClient smpVar = do
|
||||
SMPClientConfig {tcpTimeout} <- asks $ smpCfg . config
|
||||
smpClient_ <- liftIO $ tcpTimeout `timeout` atomically (readTMVar smpVar)
|
||||
liftEither $ case smpClient_ of
|
||||
Just (Right smpClient) -> Right smpClient
|
||||
Just (Left e) -> Left e
|
||||
Nothing -> Left $ BROKER TIMEOUT
|
||||
|
||||
newSMPClient :: TMVar (Either AgentErrorType SMPClient) -> m SMPClient
|
||||
newSMPClient smpVar = tryConnectClient pure tryConnectAsync
|
||||
where
|
||||
tryConnectClient :: (SMPClient -> m a) -> m () -> m a
|
||||
tryConnectClient successAction retryAction =
|
||||
tryError connectClient >>= \r -> case r of
|
||||
Right smp -> do
|
||||
logInfo . decodeUtf8 $ "Agent connected to " <> showServer srv
|
||||
atomically $ putTMVar smpVar r
|
||||
successAction smp
|
||||
Left e -> do
|
||||
if e == BROKER NETWORK || e == BROKER TIMEOUT
|
||||
then retryAction
|
||||
else atomically $ do
|
||||
putTMVar smpVar (Left e)
|
||||
TM.delete srv smpClients
|
||||
throwError e
|
||||
tryConnectAsync :: m ()
|
||||
tryConnectAsync = do
|
||||
a <- async connectAsync
|
||||
atomically $ modifyTVar' (asyncClients c) (a :)
|
||||
connectAsync :: m ()
|
||||
connectAsync = do
|
||||
ri <- asks $ reconnectInterval . config
|
||||
withRetryInterval ri $ \loop -> void $ tryConnectClient (const reconnectClient) loop
|
||||
|
||||
connectClient :: m SMPClient
|
||||
connectClient = do
|
||||
cfg <- asks $ smpCfg . config
|
||||
u <- askUnliftIO
|
||||
liftEitherError smpClientError (getSMPClient srv cfg msgQ $ clientDisconnected u)
|
||||
`E.catch` internalError
|
||||
where
|
||||
internalError :: IOException -> m SMPClient
|
||||
internalError = throwError . INTERNAL . show
|
||||
liftEitherError (protocolClientError SMP) (getProtocolClient srv cfg (Just msgQ) $ clientDisconnected u)
|
||||
|
||||
clientDisconnected :: UnliftIO m -> IO ()
|
||||
clientDisconnected u = do
|
||||
|
@ -194,13 +231,14 @@ getSMPServerClient c@AgentClient {smpClients, msgQ} srv =
|
|||
removeClientAndSubs :: IO (Maybe (Map ConnId RcvQueue))
|
||||
removeClientAndSubs = atomically $ do
|
||||
TM.delete srv smpClients
|
||||
cVar_ <- TM.lookupDelete srv $ subscrSrvrs c
|
||||
forM cVar_ $ \cVar -> do
|
||||
cs <- readTVar cVar
|
||||
modifyTVar' (subscrConns c) (`M.withoutKeys` M.keysSet cs)
|
||||
addPendingSubs cVar cs
|
||||
pure cs
|
||||
TM.lookupDelete srv (subscrSrvrs c) >>= mapM updateSubs
|
||||
where
|
||||
updateSubs cVar = do
|
||||
cs <- readTVar cVar
|
||||
modifyTVar' (subscrConns c) (`M.withoutKeys` M.keysSet cs)
|
||||
addPendingSubs cVar cs
|
||||
pure cs
|
||||
|
||||
addPendingSubs cVar cs = do
|
||||
let ps = pendingSubscrSrvrs c
|
||||
TM.lookup srv ps >>= \case
|
||||
|
@ -208,9 +246,12 @@ getSMPServerClient c@AgentClient {smpClients, msgQ} srv =
|
|||
_ -> TM.insert srv cVar ps
|
||||
|
||||
serverDown :: UnliftIO m -> Map ConnId RcvQueue -> IO ()
|
||||
serverDown u cs = unless (M.null cs) $ do
|
||||
mapM_ (notifySub DOWN) $ M.keysSet cs
|
||||
unliftIO u reconnectServer
|
||||
serverDown u cs = unless (M.null cs) $
|
||||
whenM (readTVarIO active) $ do
|
||||
let conns = M.keys cs
|
||||
unless (null conns) . notifySub "" $ DOWN srv conns
|
||||
atomically $ mapM_ (releaseGetLock c) cs
|
||||
unliftIO u reconnectServer
|
||||
|
||||
reconnectServer :: m ()
|
||||
reconnectServer = do
|
||||
|
@ -224,47 +265,138 @@ getSMPServerClient c@AgentClient {smpClients, msgQ} srv =
|
|||
reconnectClient `catchError` const loop
|
||||
|
||||
reconnectClient :: m ()
|
||||
reconnectClient =
|
||||
withAgentLock c . withSMP c srv $ \smp -> do
|
||||
reconnectClient = do
|
||||
n <- asks $ resubscriptionConcurrency . config
|
||||
withAgentLock c . withClient c srv $ \smp -> do
|
||||
cs <- atomically $ mapM readTVar =<< TM.lookup srv (pendingSubscrSrvrs c)
|
||||
forConcurrently_ (maybe [] M.toList cs) $ \sub@(connId, _) ->
|
||||
whenM (atomically $ isNothing <$> TM.lookup connId (subscrConns c)) $
|
||||
subscribe_ smp sub `catchError` handleError connId
|
||||
-- TODO if any of the subscriptions fails here (e.g. because of timeout), it terminates the whole process for all subscriptions
|
||||
-- instead it should only report successful subscriptions and schedule the next call to reconnectClient to subscribe for the remaining subscriptions
|
||||
-- this way, for each DOWN event there can be several UP events
|
||||
conns <- pooledForConcurrentlyN n (maybe [] M.toList cs) $ \sub@(connId, _) ->
|
||||
ifM
|
||||
(atomically $ hasActiveSubscription c connId)
|
||||
(pure $ Just connId)
|
||||
(subscribe_ smp sub `catchError` handleError connId)
|
||||
liftIO . unless (null conns) . notifySub "" . UP srv $ catMaybes conns
|
||||
where
|
||||
subscribe_ :: SMPClient -> (ConnId, RcvQueue) -> ExceptT SMPClientError IO ()
|
||||
subscribe_ :: SMPClient -> (ConnId, RcvQueue) -> ExceptT ProtocolClientError IO (Maybe ConnId)
|
||||
subscribe_ smp (connId, rq@RcvQueue {rcvPrivateKey, rcvId}) = do
|
||||
subscribeSMPQueue smp rcvPrivateKey rcvId
|
||||
addSubscription c rq connId
|
||||
liftIO $ notifySub UP connId
|
||||
pure $ Just connId
|
||||
|
||||
handleError :: ConnId -> SMPClientError -> ExceptT SMPClientError IO ()
|
||||
handleError :: ConnId -> ProtocolClientError -> ExceptT ProtocolClientError IO (Maybe ConnId)
|
||||
handleError connId = \case
|
||||
e@SMPResponseTimeout -> throwError e
|
||||
e@SMPNetworkError -> throwError e
|
||||
e@PCEResponseTimeout -> throwError e
|
||||
e@PCENetworkError -> throwError e
|
||||
e -> do
|
||||
liftIO $ notifySub (ERR $ smpClientError e) connId
|
||||
liftIO . notifySub connId . ERR $ protocolClientError SMP e
|
||||
atomically $ removePendingSubscription c srv connId
|
||||
pure Nothing
|
||||
|
||||
notifySub :: ACommand 'Agent -> ConnId -> IO ()
|
||||
notifySub cmd connId = atomically $ writeTBQueue (subQ c) ("", connId, cmd)
|
||||
notifySub :: ConnId -> ACommand 'Agent -> IO ()
|
||||
notifySub connId cmd = atomically $ writeTBQueue (subQ c) ("", connId, cmd)
|
||||
|
||||
closeAgentClient :: MonadUnliftIO m => AgentClient -> m ()
|
||||
getNtfServerClient :: forall m. AgentMonad m => AgentClient -> NtfServer -> m NtfClient
|
||||
getNtfServerClient c@AgentClient {active, ntfClients} srv = do
|
||||
unlessM (readTVarIO active) . throwError $ INTERNAL "agent is stopped"
|
||||
atomically (getClientVar srv ntfClients)
|
||||
>>= either
|
||||
(newProtocolClient c srv ntfClients connectClient $ pure ())
|
||||
(waitForProtocolClient ntfCfg)
|
||||
where
|
||||
connectClient :: m NtfClient
|
||||
connectClient = do
|
||||
cfg <- asks $ ntfCfg . config
|
||||
liftEitherError (protocolClientError NTF) (getProtocolClient srv cfg Nothing clientDisconnected)
|
||||
|
||||
clientDisconnected :: IO ()
|
||||
clientDisconnected = do
|
||||
atomically $ TM.delete srv ntfClients
|
||||
logInfo . decodeUtf8 $ "Agent disconnected from " <> showServer srv
|
||||
|
||||
getClientVar :: forall a s. ProtocolServer s -> TMap (ProtocolServer s) (TMVar a) -> STM (Either (TMVar a) (TMVar a))
|
||||
getClientVar srv clients = maybe (Left <$> newClientVar) (pure . Right) =<< TM.lookup srv clients
|
||||
where
|
||||
newClientVar :: STM (TMVar a)
|
||||
newClientVar = do
|
||||
var <- newEmptyTMVar
|
||||
TM.insert srv var clients
|
||||
pure var
|
||||
|
||||
waitForProtocolClient :: AgentMonad m => (AgentConfig -> ProtocolClientConfig) -> ClientVar msg -> m (ProtocolClient msg)
|
||||
waitForProtocolClient clientConfig clientVar = do
|
||||
ProtocolClientConfig {tcpTimeout} <- asks $ clientConfig . config
|
||||
client_ <- liftIO $ tcpTimeout `timeout` atomically (readTMVar clientVar)
|
||||
liftEither $ case client_ of
|
||||
Just (Right smpClient) -> Right smpClient
|
||||
Just (Left e) -> Left e
|
||||
Nothing -> Left $ BROKER TIMEOUT
|
||||
|
||||
newProtocolClient ::
|
||||
forall msg m.
|
||||
AgentMonad m =>
|
||||
AgentClient ->
|
||||
ProtoServer msg ->
|
||||
TMap (ProtoServer msg) (ClientVar msg) ->
|
||||
m (ProtocolClient msg) ->
|
||||
m () ->
|
||||
ClientVar msg ->
|
||||
m (ProtocolClient msg)
|
||||
newProtocolClient c srv clients connectClient reconnectClient clientVar = tryConnectClient pure tryConnectAsync
|
||||
where
|
||||
tryConnectClient :: (ProtocolClient msg -> m a) -> m () -> m a
|
||||
tryConnectClient successAction retryAction =
|
||||
tryError connectClient >>= \r -> case r of
|
||||
Right client -> do
|
||||
logInfo . decodeUtf8 $ "Agent connected to " <> showServer srv
|
||||
atomically $ putTMVar clientVar r
|
||||
successAction client
|
||||
Left e -> do
|
||||
if e == BROKER NETWORK || e == BROKER TIMEOUT
|
||||
then retryAction
|
||||
else atomically $ do
|
||||
putTMVar clientVar (Left e)
|
||||
TM.delete srv clients
|
||||
throwError e
|
||||
tryConnectAsync :: m ()
|
||||
tryConnectAsync = do
|
||||
a <- async connectAsync
|
||||
atomically $ modifyTVar' (asyncClients c) (a :)
|
||||
connectAsync :: m ()
|
||||
connectAsync = do
|
||||
ri <- asks $ reconnectInterval . config
|
||||
withRetryInterval ri $ \loop -> void $ tryConnectClient (const reconnectClient) loop
|
||||
|
||||
closeAgentClient :: MonadIO m => AgentClient -> m ()
|
||||
closeAgentClient c = liftIO $ do
|
||||
closeSMPServerClients c
|
||||
atomically $ writeTVar (active c) False
|
||||
closeProtocolServerClients (clientTimeout smpCfg) $ smpClients c
|
||||
closeProtocolServerClients (clientTimeout ntfCfg) $ ntfClients c
|
||||
cancelActions $ reconnections c
|
||||
cancelActions $ asyncClients c
|
||||
cancelActions $ smpQueueMsgDeliveries c
|
||||
|
||||
closeSMPServerClients :: AgentClient -> IO ()
|
||||
closeSMPServerClients c = readTVarIO (smpClients c) >>= mapM_ (forkIO . closeClient)
|
||||
clear subscrSrvrs
|
||||
clear pendingSubscrSrvrs
|
||||
clear subscrConns
|
||||
clear connMsgsQueued
|
||||
clear smpQueueMsgQueues
|
||||
clear getMsgLocks
|
||||
where
|
||||
closeClient smpVar =
|
||||
atomically (readTMVar smpVar) >>= \case
|
||||
Right smp -> closeSMPClient smp `E.catch` \(_ :: E.SomeException) -> pure ()
|
||||
clientTimeout sel = tcpTimeout . sel . config $ agentEnv c
|
||||
clear :: (AgentClient -> TMap k a) -> IO ()
|
||||
clear sel = atomically $ writeTVar (sel c) M.empty
|
||||
|
||||
closeProtocolServerClients :: Int -> TMap (ProtoServer msg) (ClientVar msg) -> IO ()
|
||||
closeProtocolServerClients tcpTimeout cs = readTVarIO cs >>= mapM_ (forkIO . closeClient) >> atomically (writeTVar cs M.empty)
|
||||
where
|
||||
closeClient cVar =
|
||||
tcpTimeout `timeout` atomically (readTMVar cVar) >>= \case
|
||||
Just (Right client) -> closeProtocolClient client `catchAll_` pure ()
|
||||
_ -> pure ()
|
||||
|
||||
cancelActions :: Foldable f => TVar (f (Async ())) -> IO ()
|
||||
cancelActions as = readTVarIO as >>= mapM_ uninterruptibleCancel
|
||||
cancelActions :: (Foldable f, Monoid (f (Async ()))) => TVar (f (Async ())) -> IO ()
|
||||
cancelActions as = readTVarIO as >>= mapM_ uninterruptibleCancel >> atomically (writeTVar as mempty)
|
||||
|
||||
withAgentLock :: MonadUnliftIO m => AgentClient -> m a -> m a
|
||||
withAgentLock AgentClient {lock} =
|
||||
|
@ -272,40 +404,40 @@ withAgentLock AgentClient {lock} =
|
|||
(void . atomically $ takeTMVar lock)
|
||||
(atomically $ putTMVar lock ())
|
||||
|
||||
withSMP_ :: forall a m. AgentMonad m => AgentClient -> SMPServer -> (SMPClient -> m a) -> m a
|
||||
withSMP_ c srv action =
|
||||
(getSMPServerClient c srv >>= action) `catchError` logServerError
|
||||
withClient_ :: forall a m msg. (AgentMonad m, ProtocolServerClient msg) => AgentClient -> ProtoServer msg -> (ProtocolClient msg -> m a) -> m a
|
||||
withClient_ c srv action = (getProtocolServerClient c srv >>= action) `catchError` logServerError
|
||||
where
|
||||
logServerError :: AgentErrorType -> m a
|
||||
logServerError e = do
|
||||
logServer "<--" c srv "" $ bshow e
|
||||
throwError e
|
||||
|
||||
withLogSMP_ :: AgentMonad m => AgentClient -> SMPServer -> QueueId -> ByteString -> (SMPClient -> m a) -> m a
|
||||
withLogSMP_ c srv qId cmdStr action = do
|
||||
withLogClient_ :: (AgentMonad m, ProtocolServerClient msg) => AgentClient -> ProtoServer msg -> QueueId -> ByteString -> (ProtocolClient msg -> m a) -> m a
|
||||
withLogClient_ c srv qId cmdStr action = do
|
||||
logServer "-->" c srv qId cmdStr
|
||||
res <- withSMP_ c srv action
|
||||
res <- withClient_ c srv action
|
||||
logServer "<--" c srv qId "OK"
|
||||
return res
|
||||
|
||||
withSMP :: AgentMonad m => AgentClient -> SMPServer -> (SMPClient -> ExceptT SMPClientError IO a) -> m a
|
||||
withSMP c srv action = withSMP_ c srv $ liftSMP . action
|
||||
withClient :: forall m msg a. (AgentMonad m, ProtocolServerClient msg) => AgentClient -> ProtoServer msg -> (ProtocolClient msg -> ExceptT ProtocolClientError IO a) -> m a
|
||||
withClient c srv action = withClient_ c srv $ liftClient (clientProtocolError @msg) . action
|
||||
|
||||
withLogSMP :: AgentMonad m => AgentClient -> SMPServer -> QueueId -> ByteString -> (SMPClient -> ExceptT SMPClientError IO a) -> m a
|
||||
withLogSMP c srv qId cmdStr action = withLogSMP_ c srv qId cmdStr $ liftSMP . action
|
||||
withLogClient :: forall m msg a. (AgentMonad m, ProtocolServerClient msg) => AgentClient -> ProtoServer msg -> QueueId -> ByteString -> (ProtocolClient msg -> ExceptT ProtocolClientError IO a) -> m a
|
||||
withLogClient c srv qId cmdStr action = withLogClient_ c srv qId cmdStr $ liftClient (clientProtocolError @msg) . action
|
||||
|
||||
liftSMP :: AgentMonad m => ExceptT SMPClientError IO a -> m a
|
||||
liftSMP = liftError smpClientError
|
||||
liftClient :: AgentMonad m => (ErrorType -> AgentErrorType) -> ExceptT ProtocolClientError IO a -> m a
|
||||
liftClient = liftError . protocolClientError
|
||||
|
||||
smpClientError :: SMPClientError -> AgentErrorType
|
||||
smpClientError = \case
|
||||
SMPServerError e -> SMP e
|
||||
SMPResponseError e -> BROKER $ RESPONSE e
|
||||
SMPUnexpectedResponse -> BROKER UNEXPECTED
|
||||
SMPResponseTimeout -> BROKER TIMEOUT
|
||||
SMPNetworkError -> BROKER NETWORK
|
||||
SMPTransportError e -> BROKER $ TRANSPORT e
|
||||
e -> INTERNAL $ show e
|
||||
protocolClientError :: (ErrorType -> AgentErrorType) -> ProtocolClientError -> AgentErrorType
|
||||
protocolClientError protocolError_ = \case
|
||||
PCEProtocolError e -> protocolError_ e
|
||||
PCEResponseError e -> BROKER $ RESPONSE e
|
||||
PCEUnexpectedResponse _ -> BROKER UNEXPECTED
|
||||
PCEResponseTimeout -> BROKER TIMEOUT
|
||||
PCENetworkError -> BROKER NETWORK
|
||||
PCETransportError e -> BROKER $ TRANSPORT e
|
||||
e@PCESignatureError {} -> INTERNAL $ show e
|
||||
e@PCEIOError {} -> INTERNAL $ show e
|
||||
|
||||
newRcvQueue :: AgentMonad m => AgentClient -> SMPServer -> m (RcvQueue, SMPQueueUri)
|
||||
newRcvQueue c srv =
|
||||
|
@ -324,7 +456,7 @@ newRcvQueue_ a c srv = do
|
|||
(e2eDhKey, e2ePrivKey) <- liftIO C.generateKeyPair'
|
||||
logServer "-->" c srv "" "NEW"
|
||||
QIK {rcvId, sndId, rcvPublicDhKey} <-
|
||||
withSMP c srv $ \smp -> createSMPQueue smp rcvPrivateKey recipientKey dhKey
|
||||
withClient c srv $ \smp -> createSMPQueue smp rcvPrivateKey recipientKey dhKey
|
||||
logServer "<--" c srv "" $ B.unwords ["IDS", logSecret rcvId, logSecret sndId]
|
||||
let rq =
|
||||
RcvQueue
|
||||
|
@ -335,27 +467,33 @@ newRcvQueue_ a c srv = do
|
|||
e2ePrivKey,
|
||||
e2eDhSecret = Nothing,
|
||||
sndId = Just sndId,
|
||||
status = New
|
||||
status = New,
|
||||
clientNtfCreds = Nothing
|
||||
}
|
||||
pure (rq, SMPQueueUri srv sndId SMP.smpClientVRange e2eDhKey)
|
||||
|
||||
subscribeQueue :: AgentMonad m => AgentClient -> RcvQueue -> ConnId -> m ()
|
||||
subscribeQueue c rq@RcvQueue {server, rcvPrivateKey, rcvId} connId = do
|
||||
whenM (atomically . TM.member (server, rcvId) $ getMsgLocks c) . throwError $ CMD PROHIBITED
|
||||
atomically $ addPendingSubscription c rq connId
|
||||
withLogSMP c server rcvId "SUB" $ \smp -> do
|
||||
withLogClient c server rcvId "SUB" $ \smp -> do
|
||||
liftIO (runExceptT $ subscribeSMPQueue smp rcvPrivateKey rcvId) >>= \case
|
||||
Left e -> do
|
||||
atomically . when (e /= SMPNetworkError && e /= SMPResponseTimeout) $
|
||||
atomically . when (e /= PCENetworkError && e /= PCEResponseTimeout) $
|
||||
removePendingSubscription c server connId
|
||||
throwError e
|
||||
Right _ -> addSubscription c rq connId
|
||||
Right _ -> do
|
||||
addSubscription c rq connId
|
||||
|
||||
addSubscription :: MonadUnliftIO m => AgentClient -> RcvQueue -> ConnId -> m ()
|
||||
addSubscription :: MonadIO m => AgentClient -> RcvQueue -> ConnId -> m ()
|
||||
addSubscription c rq@RcvQueue {server} connId = atomically $ do
|
||||
TM.insert connId server $ subscrConns c
|
||||
addSubs_ (subscrSrvrs c) rq connId
|
||||
removePendingSubscription c server connId
|
||||
|
||||
hasActiveSubscription :: AgentClient -> ConnId -> STM Bool
|
||||
hasActiveSubscription c connId = TM.member connId (subscrConns c)
|
||||
|
||||
addPendingSubscription :: AgentClient -> RcvQueue -> ConnId -> STM ()
|
||||
addPendingSubscription = addSubs_ . pendingSubscrSrvrs
|
||||
|
||||
|
@ -377,12 +515,17 @@ removeSubs_ :: TMap SMPServer (TMap ConnId RcvQueue) -> SMPServer -> ConnId -> S
|
|||
removeSubs_ ss server connId =
|
||||
TM.lookup server ss >>= mapM_ (TM.delete connId)
|
||||
|
||||
logServer :: AgentMonad m => ByteString -> AgentClient -> SMPServer -> QueueId -> ByteString -> m ()
|
||||
getSubscriptions :: AgentClient -> STM (Set ConnId)
|
||||
getSubscriptions AgentClient {subscrConns} = do
|
||||
m <- readTVar subscrConns
|
||||
pure $ M.keysSet m
|
||||
|
||||
logServer :: MonadIO m => ByteString -> AgentClient -> ProtocolServer s -> QueueId -> ByteString -> m ()
|
||||
logServer dir AgentClient {clientId} srv qId cmdStr =
|
||||
logInfo . decodeUtf8 $ B.unwords ["A", "(" <> bshow clientId <> ")", dir, showServer srv, ":", logSecret qId, cmdStr]
|
||||
|
||||
showServer :: SMPServer -> ByteString
|
||||
showServer SMPServer {host, port} =
|
||||
showServer :: ProtocolServer s -> ByteString
|
||||
showServer ProtocolServer {host, port} =
|
||||
B.pack $ host <> if null port then "" else ':' : port
|
||||
|
||||
logSecret :: ByteString -> ByteString
|
||||
|
@ -390,51 +533,124 @@ logSecret bs = encode $ B.take 3 bs
|
|||
|
||||
sendConfirmation :: forall m. AgentMonad m => AgentClient -> SndQueue -> ByteString -> m ()
|
||||
sendConfirmation c sq@SndQueue {server, sndId, sndPublicKey = Just sndPublicKey, e2ePubKey = e2ePubKey@Just {}} agentConfirmation =
|
||||
withLogSMP_ c server sndId "SEND <CONF>" $ \smp -> do
|
||||
withLogClient_ c server sndId "SEND <CONF>" $ \smp -> do
|
||||
let clientMsg = SMP.ClientMessage (SMP.PHConfirmation sndPublicKey) agentConfirmation
|
||||
msg <- agentCbEncrypt sq e2ePubKey $ smpEncode clientMsg
|
||||
liftSMP $ sendSMPMessage smp Nothing sndId msg
|
||||
liftClient SMP $ sendSMPMessage smp Nothing sndId (SMP.MsgFlags {notification = True}) msg
|
||||
sendConfirmation _ _ _ = throwError $ INTERNAL "sendConfirmation called without snd_queue public key(s) in the database"
|
||||
|
||||
sendInvitation :: forall m. AgentMonad m => AgentClient -> Compatible SMPQueueInfo -> ConnectionRequestUri 'CMInvitation -> ConnInfo -> m ()
|
||||
sendInvitation c (Compatible SMPQueueInfo {smpServer, senderId, dhPublicKey}) connReq connInfo =
|
||||
withLogSMP_ c smpServer senderId "SEND <INV>" $ \smp -> do
|
||||
sendInvitation :: forall m. AgentMonad m => AgentClient -> Compatible SMPQueueInfo -> Compatible Version -> ConnectionRequestUri 'CMInvitation -> ConnInfo -> m ()
|
||||
sendInvitation c (Compatible SMPQueueInfo {smpServer, senderId, dhPublicKey}) (Compatible agentVersion) connReq connInfo =
|
||||
withLogClient_ c smpServer senderId "SEND <INV>" $ \smp -> do
|
||||
msg <- mkInvitation
|
||||
liftSMP $ sendSMPMessage smp Nothing senderId msg
|
||||
liftClient SMP $ sendSMPMessage smp Nothing senderId MsgFlags {notification = True} msg
|
||||
where
|
||||
mkInvitation :: m ByteString
|
||||
-- this is only encrypted with per-queue E2E, not with double ratchet
|
||||
mkInvitation = do
|
||||
let agentEnvelope = AgentInvitation {agentVersion = smpAgentVersion, connReq, connInfo}
|
||||
let agentEnvelope = AgentInvitation {agentVersion, connReq, connInfo}
|
||||
agentCbEncryptOnce dhPublicKey . smpEncode $
|
||||
SMP.ClientMessage SMP.PHEmpty $ smpEncode agentEnvelope
|
||||
|
||||
getQueueMessage :: AgentMonad m => AgentClient -> RcvQueue -> m (Maybe SMPMsgMeta)
|
||||
getQueueMessage c rq@RcvQueue {server, rcvId, rcvPrivateKey} = do
|
||||
atomically createTakeGetLock
|
||||
(v, msg_) <- withLogClient c server rcvId "GET" $ \smp ->
|
||||
(thVersion smp,) <$> getSMPMessage smp rcvPrivateKey rcvId
|
||||
mapM (decryptMeta v) msg_
|
||||
where
|
||||
decryptMeta v msg@SMP.RcvMessage {msgId} = SMP.rcvMessageMeta msgId <$> decryptSMPMessage v rq msg
|
||||
createTakeGetLock = TM.alterF takeLock (server, rcvId) $ getMsgLocks c
|
||||
where
|
||||
takeLock l_ = do
|
||||
l <- maybe (newTMVar ()) pure l_
|
||||
takeTMVar l
|
||||
pure $ Just l
|
||||
|
||||
decryptSMPMessage :: AgentMonad m => Version -> RcvQueue -> SMP.RcvMessage -> m SMP.ClientRcvMsgBody
|
||||
decryptSMPMessage v rq SMP.RcvMessage {msgId, msgTs, msgFlags, msgBody = SMP.EncRcvMsgBody body}
|
||||
| v == 1 || v == 2 = SMP.ClientRcvMsgBody msgTs msgFlags <$> decrypt body
|
||||
| otherwise = liftEither . parse SMP.clientRcvMsgBodyP (AGENT A_MESSAGE) =<< decrypt body
|
||||
where
|
||||
decrypt = agentCbDecrypt (rcvDhSecret rq) (C.cbNonce msgId)
|
||||
|
||||
secureQueue :: AgentMonad m => AgentClient -> RcvQueue -> SndPublicVerifyKey -> m ()
|
||||
secureQueue c RcvQueue {server, rcvId, rcvPrivateKey} senderKey =
|
||||
withLogSMP c server rcvId "KEY <key>" $ \smp ->
|
||||
withLogClient c server rcvId "KEY <key>" $ \smp ->
|
||||
secureSMPQueue smp rcvPrivateKey rcvId senderKey
|
||||
|
||||
sendAck :: AgentMonad m => AgentClient -> RcvQueue -> m ()
|
||||
sendAck c RcvQueue {server, rcvId, rcvPrivateKey} =
|
||||
withLogSMP c server rcvId "ACK" $ \smp ->
|
||||
ackSMPMessage smp rcvPrivateKey rcvId
|
||||
enableQueueNotifications :: AgentMonad m => AgentClient -> RcvQueue -> NtfPublicVerifyKey -> RcvNtfPublicDhKey -> m (NotifierId, RcvNtfPublicDhKey)
|
||||
enableQueueNotifications c RcvQueue {server, rcvId, rcvPrivateKey} notifierKey rcvNtfPublicDhKey =
|
||||
withLogClient c server rcvId "NKEY <nkey>" $ \smp ->
|
||||
enableSMPQueueNotifications smp rcvPrivateKey rcvId notifierKey rcvNtfPublicDhKey
|
||||
|
||||
disableQueueNotifications :: AgentMonad m => AgentClient -> RcvQueue -> m ()
|
||||
disableQueueNotifications c RcvQueue {server, rcvId, rcvPrivateKey} =
|
||||
withLogClient c server rcvId "NDEL" $ \smp ->
|
||||
disableSMPQueueNotifications smp rcvPrivateKey rcvId
|
||||
|
||||
sendAck :: AgentMonad m => AgentClient -> RcvQueue -> MsgId -> m ()
|
||||
sendAck c rq@RcvQueue {server, rcvId, rcvPrivateKey} msgId = do
|
||||
withLogClient c server rcvId "ACK" $ \smp ->
|
||||
ackSMPMessage smp rcvPrivateKey rcvId msgId
|
||||
atomically $ releaseGetLock c rq
|
||||
|
||||
releaseGetLock :: AgentClient -> RcvQueue -> STM ()
|
||||
releaseGetLock c RcvQueue {server, rcvId} =
|
||||
TM.lookup (server, rcvId) (getMsgLocks c) >>= mapM_ (`tryPutTMVar` ())
|
||||
|
||||
suspendQueue :: AgentMonad m => AgentClient -> RcvQueue -> m ()
|
||||
suspendQueue c RcvQueue {server, rcvId, rcvPrivateKey} =
|
||||
withLogSMP c server rcvId "OFF" $ \smp ->
|
||||
withLogClient c server rcvId "OFF" $ \smp ->
|
||||
suspendSMPQueue smp rcvPrivateKey rcvId
|
||||
|
||||
deleteQueue :: AgentMonad m => AgentClient -> RcvQueue -> m ()
|
||||
deleteQueue c RcvQueue {server, rcvId, rcvPrivateKey} =
|
||||
withLogSMP c server rcvId "DEL" $ \smp ->
|
||||
withLogClient c server rcvId "DEL" $ \smp ->
|
||||
deleteSMPQueue smp rcvPrivateKey rcvId
|
||||
|
||||
sendAgentMessage :: forall m. AgentMonad m => AgentClient -> SndQueue -> ByteString -> m ()
|
||||
sendAgentMessage c sq@SndQueue {server, sndId, sndPrivateKey} agentMsg =
|
||||
withLogSMP_ c server sndId "SEND <MSG>" $ \smp -> do
|
||||
sendAgentMessage :: forall m. AgentMonad m => AgentClient -> SndQueue -> MsgFlags -> ByteString -> m ()
|
||||
sendAgentMessage c sq@SndQueue {server, sndId, sndPrivateKey} msgFlags agentMsg =
|
||||
withLogClient_ c server sndId "SEND <MSG>" $ \smp -> do
|
||||
let clientMsg = SMP.ClientMessage SMP.PHEmpty agentMsg
|
||||
msg <- agentCbEncrypt sq Nothing $ smpEncode clientMsg
|
||||
liftSMP $ sendSMPMessage smp (Just sndPrivateKey) sndId msg
|
||||
liftClient SMP $ sendSMPMessage smp (Just sndPrivateKey) sndId msgFlags msg
|
||||
|
||||
agentNtfRegisterToken :: AgentMonad m => AgentClient -> NtfToken -> C.APublicVerifyKey -> C.PublicKeyX25519 -> m (NtfTokenId, C.PublicKeyX25519)
|
||||
agentNtfRegisterToken c NtfToken {deviceToken, ntfServer, ntfPrivKey} ntfPubKey pubDhKey =
|
||||
withClient c ntfServer $ \ntf -> ntfRegisterToken ntf ntfPrivKey (NewNtfTkn deviceToken ntfPubKey pubDhKey)
|
||||
|
||||
agentNtfVerifyToken :: AgentMonad m => AgentClient -> NtfTokenId -> NtfToken -> NtfRegCode -> m ()
|
||||
agentNtfVerifyToken c tknId NtfToken {ntfServer, ntfPrivKey} code =
|
||||
withLogClient c ntfServer tknId "TVFY" $ \ntf -> ntfVerifyToken ntf ntfPrivKey tknId code
|
||||
|
||||
agentNtfCheckToken :: AgentMonad m => AgentClient -> NtfTokenId -> NtfToken -> m NtfTknStatus
|
||||
agentNtfCheckToken c tknId NtfToken {ntfServer, ntfPrivKey} =
|
||||
withLogClient c ntfServer tknId "TCHK" $ \ntf -> ntfCheckToken ntf ntfPrivKey tknId
|
||||
|
||||
agentNtfReplaceToken :: AgentMonad m => AgentClient -> NtfTokenId -> NtfToken -> DeviceToken -> m ()
|
||||
agentNtfReplaceToken c tknId NtfToken {ntfServer, ntfPrivKey} token =
|
||||
withLogClient c ntfServer tknId "TRPL" $ \ntf -> ntfReplaceToken ntf ntfPrivKey tknId token
|
||||
|
||||
agentNtfDeleteToken :: AgentMonad m => AgentClient -> NtfTokenId -> NtfToken -> m ()
|
||||
agentNtfDeleteToken c tknId NtfToken {ntfServer, ntfPrivKey} =
|
||||
withLogClient c ntfServer tknId "TDEL" $ \ntf -> ntfDeleteToken ntf ntfPrivKey tknId
|
||||
|
||||
agentNtfEnableCron :: AgentMonad m => AgentClient -> NtfTokenId -> NtfToken -> Word16 -> m ()
|
||||
agentNtfEnableCron c tknId NtfToken {ntfServer, ntfPrivKey} interval =
|
||||
withLogClient c ntfServer tknId "TCRN" $ \ntf -> ntfEnableCron ntf ntfPrivKey tknId interval
|
||||
|
||||
agentNtfCreateSubscription :: AgentMonad m => AgentClient -> NtfTokenId -> NtfToken -> SMPQueueNtf -> NtfPrivateSignKey -> m NtfSubscriptionId
|
||||
agentNtfCreateSubscription c tknId NtfToken {ntfServer, ntfPrivKey} smpQueue nKey =
|
||||
withLogClient c ntfServer tknId "SNEW" $ \ntf -> ntfCreateSubscription ntf ntfPrivKey (NewNtfSub tknId smpQueue nKey)
|
||||
|
||||
agentNtfCheckSubscription :: AgentMonad m => AgentClient -> NtfSubscriptionId -> NtfToken -> m NtfSubStatus
|
||||
agentNtfCheckSubscription c subId NtfToken {ntfServer, ntfPrivKey} =
|
||||
withLogClient c ntfServer subId "SCHK" $ \ntf -> ntfCheckSubscription ntf ntfPrivKey subId
|
||||
|
||||
agentNtfDeleteSubscription :: AgentMonad m => AgentClient -> NtfSubscriptionId -> NtfToken -> m ()
|
||||
agentNtfDeleteSubscription c subId NtfToken {ntfServer, ntfPrivKey} =
|
||||
withLogClient c ntfServer subId "SDEL" $ \ntf -> ntfDeleteSubscription ntf ntfPrivKey subId
|
||||
|
||||
agentCbEncrypt :: AgentMonad m => SndQueue -> Maybe C.PublicKeyX25519 -> ByteString -> m ByteString
|
||||
agentCbEncrypt SndQueue {e2eDhSecret} e2ePubKey msg = do
|
||||
|
@ -473,4 +689,88 @@ cryptoError = \case
|
|||
C.CryptoHeaderError _ -> AGENT A_ENCRYPTION
|
||||
C.AESDecryptError -> AGENT A_ENCRYPTION
|
||||
C.CBDecryptError -> AGENT A_ENCRYPTION
|
||||
C.CERatchetDuplicateMessage -> AGENT A_DUPLICATE
|
||||
e -> INTERNAL $ show e
|
||||
|
||||
endAgentOperation :: AgentClient -> AgentOperation -> STM ()
|
||||
endAgentOperation c op = endOperation c op $ case op of
|
||||
AONtfNetwork -> pure ()
|
||||
AORcvNetwork ->
|
||||
suspendOperation c AOMsgDelivery $
|
||||
suspendSendingAndDatabase c
|
||||
AOMsgDelivery ->
|
||||
suspendSendingAndDatabase c
|
||||
AOSndNetwork ->
|
||||
suspendOperation c AODatabase $
|
||||
notifySuspended c
|
||||
AODatabase ->
|
||||
notifySuspended c
|
||||
|
||||
suspendSendingAndDatabase :: AgentClient -> STM ()
|
||||
suspendSendingAndDatabase c =
|
||||
suspendOperation c AOSndNetwork $
|
||||
suspendOperation c AODatabase $
|
||||
notifySuspended c
|
||||
|
||||
suspendOperation :: AgentClient -> AgentOperation -> STM () -> STM ()
|
||||
suspendOperation c op endedAction = do
|
||||
n <- stateTVar (agentOpSel op c) $ \s -> (opsInProgress s, s {opSuspended = True})
|
||||
-- unsafeIOToSTM $ putStrLn $ "suspendOperation_ " <> show op <> " " <> show n
|
||||
when (n == 0) $ whenSuspending c endedAction
|
||||
|
||||
notifySuspended :: AgentClient -> STM ()
|
||||
notifySuspended c = do
|
||||
-- unsafeIOToSTM $ putStrLn "notifySuspended"
|
||||
writeTBQueue (subQ c) ("", "", SUSPENDED)
|
||||
writeTVar (agentState c) ASSuspended
|
||||
|
||||
endOperation :: AgentClient -> AgentOperation -> STM () -> STM ()
|
||||
endOperation c op endedAction = do
|
||||
(suspended, n) <- stateTVar (agentOpSel op c) $ \s ->
|
||||
let n = max 0 (opsInProgress s - 1)
|
||||
in ((opSuspended s, n), s {opsInProgress = n})
|
||||
-- unsafeIOToSTM $ putStrLn $ "endOperation: " <> show op <> " " <> show suspended <> " " <> show n
|
||||
when (suspended && n == 0) $ whenSuspending c endedAction
|
||||
|
||||
whenSuspending :: AgentClient -> STM () -> STM ()
|
||||
whenSuspending c = whenM ((== ASSuspending) <$> readTVar (agentState c))
|
||||
|
||||
beginAgentOperation :: AgentClient -> AgentOperation -> STM ()
|
||||
beginAgentOperation c op = do
|
||||
let opVar = agentOpSel op c
|
||||
s <- readTVar opVar
|
||||
-- unsafeIOToSTM $ putStrLn $ "beginOperation? " <> show op <> " " <> show (opsInProgress s)
|
||||
when (opSuspended s) retry
|
||||
-- unsafeIOToSTM $ putStrLn $ "beginOperation! " <> show op <> " " <> show (opsInProgress s + 1)
|
||||
writeTVar opVar $! s {opsInProgress = opsInProgress s + 1}
|
||||
|
||||
agentOperationBracket :: MonadUnliftIO m => AgentClient -> AgentOperation -> m a -> m a
|
||||
agentOperationBracket c op action =
|
||||
E.bracket
|
||||
(atomically $ beginAgentOperation c op)
|
||||
(\_ -> atomically $ endAgentOperation c op)
|
||||
(const action)
|
||||
|
||||
withStore' :: AgentMonad m => AgentClient -> (DB.Connection -> IO a) -> m a
|
||||
withStore' c action = withStore c $ fmap Right . action
|
||||
|
||||
withStore :: AgentMonad m => AgentClient -> (DB.Connection -> IO (Either StoreError a)) -> m a
|
||||
withStore c action = do
|
||||
st <- asks store
|
||||
liftEitherError storeError . agentOperationBracket c AODatabase $
|
||||
withTransaction st action `E.catch` handleInternal
|
||||
where
|
||||
handleInternal :: E.SomeException -> IO (Either StoreError a)
|
||||
handleInternal = pure . Left . SEInternal . bshow
|
||||
storeError :: StoreError -> AgentErrorType
|
||||
storeError = \case
|
||||
SEConnNotFound -> CONN NOT_FOUND
|
||||
SEConnDuplicate -> CONN DUPLICATE
|
||||
SEBadConnType CRcv -> CONN SIMPLEX
|
||||
SEBadConnType CSnd -> CONN SIMPLEX
|
||||
SEInvitationNotFound -> CMD PROHIBITED
|
||||
-- this error is never reported as store error,
|
||||
-- it is used to wrap agent operations when "transaction-like" store access is needed
|
||||
-- NOTE: network IO should NOT be used inside AgentStoreMonad
|
||||
SEAgentError e -> e
|
||||
e -> INTERNAL $ show e
|
||||
|
|
|
@ -1,89 +1,155 @@
|
|||
{-# LANGUAGE ConstraintKinds #-}
|
||||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE DuplicateRecordFields #-}
|
||||
{-# LANGUAGE FlexibleContexts #-}
|
||||
{-# LANGUAGE GADTs #-}
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
{-# LANGUAGE NumericUnderscores #-}
|
||||
{-# LANGUAGE RankNTypes #-}
|
||||
{-# LANGUAGE TypeApplications #-}
|
||||
{-# OPTIONS_GHC -fno-warn-unticked-promoted-constructors #-}
|
||||
|
||||
module Simplex.Messaging.Agent.Env.SQLite
|
||||
( AgentConfig (..),
|
||||
( AgentMonad,
|
||||
AgentConfig (..),
|
||||
InitialAgentServers (..),
|
||||
defaultAgentConfig,
|
||||
defaultReconnectInterval,
|
||||
Env (..),
|
||||
newSMPAgentEnv,
|
||||
NtfSupervisor (..),
|
||||
NtfSupervisorCommand (..),
|
||||
)
|
||||
where
|
||||
|
||||
import Control.Monad.Except
|
||||
import Control.Monad.IO.Unlift
|
||||
import Control.Monad.Reader
|
||||
import Crypto.Random
|
||||
import Data.List.NonEmpty (NonEmpty)
|
||||
import Data.Time.Clock (NominalDiffTime, nominalDay)
|
||||
import Data.Word (Word16)
|
||||
import Network.Socket
|
||||
import Numeric.Natural
|
||||
import Simplex.Messaging.Agent.Protocol (SMPServer)
|
||||
import Simplex.Messaging.Agent.Protocol
|
||||
import Simplex.Messaging.Agent.RetryInterval
|
||||
import Simplex.Messaging.Agent.Store.SQLite
|
||||
import qualified Simplex.Messaging.Agent.Store.SQLite.Migrations as Migrations
|
||||
import Simplex.Messaging.Client
|
||||
import Simplex.Messaging.Client.Agent ()
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Notifications.Types
|
||||
import Simplex.Messaging.Protocol (NtfServer)
|
||||
import Simplex.Messaging.TMap (TMap)
|
||||
import qualified Simplex.Messaging.TMap as TM
|
||||
import Simplex.Messaging.Transport (TLS, Transport (..))
|
||||
import Simplex.Messaging.Version
|
||||
import System.Random (StdGen, newStdGen)
|
||||
import UnliftIO (Async)
|
||||
import UnliftIO.STM
|
||||
|
||||
-- | Agent monad with MonadReader Env and MonadError AgentErrorType
|
||||
type AgentMonad m = (MonadUnliftIO m, MonadReader Env m, MonadError AgentErrorType m)
|
||||
|
||||
data InitialAgentServers = InitialAgentServers
|
||||
{ smp :: NonEmpty SMPServer,
|
||||
ntf :: [NtfServer]
|
||||
}
|
||||
|
||||
data AgentConfig = AgentConfig
|
||||
{ tcpPort :: ServiceName,
|
||||
initialSMPServers :: NonEmpty SMPServer,
|
||||
cmdSignAlg :: C.SignAlg,
|
||||
connIdBytes :: Int,
|
||||
tbqSize :: Natural,
|
||||
dbFile :: FilePath,
|
||||
dbPoolSize :: Int,
|
||||
yesToMigrations :: Bool,
|
||||
smpCfg :: SMPClientConfig,
|
||||
smpCfg :: ProtocolClientConfig,
|
||||
ntfCfg :: ProtocolClientConfig,
|
||||
reconnectInterval :: RetryInterval,
|
||||
helloTimeout :: NominalDiffTime,
|
||||
resubscriptionConcurrency :: Int,
|
||||
ntfCron :: Word16,
|
||||
ntfWorkerDelay :: Int,
|
||||
ntfSMPWorkerDelay :: Int,
|
||||
ntfSubCheckInterval :: NominalDiffTime,
|
||||
ntfMaxMessages :: Int,
|
||||
caCertificateFile :: FilePath,
|
||||
privateKeyFile :: FilePath,
|
||||
certificateFile :: FilePath
|
||||
certificateFile :: FilePath,
|
||||
smpAgentVersion :: Version,
|
||||
smpAgentVRange :: VersionRange
|
||||
}
|
||||
|
||||
defaultReconnectInterval :: RetryInterval
|
||||
defaultReconnectInterval =
|
||||
RetryInterval
|
||||
{ initialInterval = second,
|
||||
increaseAfter = 10 * second,
|
||||
maxInterval = 10 * second
|
||||
}
|
||||
where
|
||||
second = 1_000_000
|
||||
|
||||
defaultAgentConfig :: AgentConfig
|
||||
defaultAgentConfig =
|
||||
AgentConfig
|
||||
{ tcpPort = "5224",
|
||||
initialSMPServers = undefined, -- TODO move it elsewhere?
|
||||
cmdSignAlg = C.SignAlg C.SEd448,
|
||||
connIdBytes = 12,
|
||||
tbqSize = 64,
|
||||
dbFile = "smp-agent.db",
|
||||
dbPoolSize = 4,
|
||||
yesToMigrations = False,
|
||||
smpCfg = smpDefaultConfig,
|
||||
reconnectInterval =
|
||||
RetryInterval
|
||||
{ initialInterval = second,
|
||||
increaseAfter = 10 * second,
|
||||
maxInterval = 10 * second
|
||||
},
|
||||
smpCfg = defaultClientConfig {defaultTransport = ("5223", transport @TLS)},
|
||||
ntfCfg = defaultClientConfig {defaultTransport = ("443", transport @TLS)},
|
||||
reconnectInterval = defaultReconnectInterval,
|
||||
helloTimeout = 2 * nominalDay,
|
||||
resubscriptionConcurrency = 16,
|
||||
ntfCron = 20, -- minutes
|
||||
ntfWorkerDelay = 100000, -- microseconds
|
||||
ntfSMPWorkerDelay = 500000, -- microseconds
|
||||
ntfSubCheckInterval = nominalDay,
|
||||
ntfMaxMessages = 4,
|
||||
-- CA certificate private key is not needed for initialization
|
||||
-- ! we do not generate these
|
||||
caCertificateFile = "/etc/opt/simplex-agent/ca.crt",
|
||||
privateKeyFile = "/etc/opt/simplex-agent/agent.key",
|
||||
certificateFile = "/etc/opt/simplex-agent/agent.crt"
|
||||
certificateFile = "/etc/opt/simplex-agent/agent.crt",
|
||||
smpAgentVersion = currentSMPAgentVersion,
|
||||
smpAgentVRange = supportedSMPAgentVRange
|
||||
}
|
||||
where
|
||||
second = 1_000_000
|
||||
|
||||
data Env = Env
|
||||
{ config :: AgentConfig,
|
||||
store :: SQLiteStore,
|
||||
idsDrg :: TVar ChaChaDRG,
|
||||
clientCounter :: TVar Int,
|
||||
randomServer :: TVar StdGen
|
||||
randomServer :: TVar StdGen,
|
||||
ntfSupervisor :: NtfSupervisor
|
||||
}
|
||||
|
||||
newSMPAgentEnv :: (MonadUnliftIO m, MonadRandom m) => AgentConfig -> m Env
|
||||
newSMPAgentEnv config@AgentConfig {dbFile, dbPoolSize, yesToMigrations} = do
|
||||
newSMPAgentEnv config@AgentConfig {dbFile, yesToMigrations} = do
|
||||
idsDrg <- newTVarIO =<< drgNew
|
||||
store <- liftIO $ createSQLiteStore dbFile dbPoolSize Migrations.app yesToMigrations
|
||||
store <- liftIO $ createSQLiteStore dbFile Migrations.app yesToMigrations
|
||||
clientCounter <- newTVarIO 0
|
||||
randomServer <- newTVarIO =<< liftIO newStdGen
|
||||
return Env {config, store, idsDrg, clientCounter, randomServer}
|
||||
ntfSupervisor <- atomically . newNtfSubSupervisor $ tbqSize config
|
||||
return Env {config, store, idsDrg, clientCounter, randomServer, ntfSupervisor}
|
||||
|
||||
data NtfSupervisor = NtfSupervisor
|
||||
{ ntfTkn :: TVar (Maybe NtfToken),
|
||||
ntfSubQ :: TBQueue (ConnId, NtfSupervisorCommand),
|
||||
ntfWorkers :: TMap NtfServer (TMVar (), Async ()),
|
||||
ntfSMPWorkers :: TMap SMPServer (TMVar (), Async ())
|
||||
}
|
||||
|
||||
data NtfSupervisorCommand = NSCCreate | NSCDelete | NSCSmpDelete | NSCNtfWorker NtfServer | NSCNtfSMPWorker SMPServer
|
||||
deriving (Show)
|
||||
|
||||
newNtfSubSupervisor :: Natural -> STM NtfSupervisor
|
||||
newNtfSubSupervisor qSize = do
|
||||
ntfTkn <- newTVar Nothing
|
||||
ntfSubQ <- newTBQueue qSize
|
||||
ntfWorkers <- TM.empty
|
||||
ntfSMPWorkers <- TM.empty
|
||||
pure NtfSupervisor {ntfTkn, ntfSubQ, ntfWorkers, ntfSMPWorkers}
|
||||
|
|
|
@ -0,0 +1,357 @@
|
|||
{-# LANGUAGE DuplicateRecordFields #-}
|
||||
{-# LANGUAGE FlexibleContexts #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-# LANGUAGE TupleSections #-}
|
||||
|
||||
module Simplex.Messaging.Agent.NtfSubSupervisor
|
||||
( runNtfSupervisor,
|
||||
nsUpdateToken,
|
||||
nsRemoveNtfToken,
|
||||
sendNtfSubCommand,
|
||||
closeNtfSupervisor,
|
||||
getNtfServer,
|
||||
)
|
||||
where
|
||||
|
||||
import Control.Concurrent.Async (Async, uninterruptibleCancel)
|
||||
import Control.Concurrent.STM (stateTVar)
|
||||
import Control.Logger.Simple (logError, logInfo)
|
||||
import Control.Monad
|
||||
import Control.Monad.Except
|
||||
import Control.Monad.IO.Unlift (MonadUnliftIO)
|
||||
import Control.Monad.Reader
|
||||
import Data.Bifunctor (first)
|
||||
import Data.Fixed (Fixed (MkFixed), Pico)
|
||||
import qualified Data.Map.Strict as M
|
||||
import Data.Text (Text)
|
||||
import Data.Time (UTCTime, addUTCTime, diffUTCTime, getCurrentTime, nominalDiffTimeToSeconds)
|
||||
import Simplex.Messaging.Agent.Client
|
||||
import Simplex.Messaging.Agent.Env.SQLite
|
||||
import Simplex.Messaging.Agent.Protocol (AgentErrorType (..), BrokerErrorType (..), ConnId, NotificationsMode (..))
|
||||
import qualified Simplex.Messaging.Agent.Protocol as AP
|
||||
import Simplex.Messaging.Agent.RetryInterval
|
||||
import Simplex.Messaging.Agent.Store
|
||||
import Simplex.Messaging.Agent.Store.SQLite
|
||||
import Simplex.Messaging.Client.Agent ()
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Notifications.Protocol (NtfSubStatus (..), NtfTknStatus (..), SMPQueueNtf (..))
|
||||
import Simplex.Messaging.Notifications.Types
|
||||
import Simplex.Messaging.Protocol
|
||||
import Simplex.Messaging.TMap (TMap)
|
||||
import qualified Simplex.Messaging.TMap as TM
|
||||
import Simplex.Messaging.Util (tshow, unlessM)
|
||||
import System.Random (randomR)
|
||||
import UnliftIO (async)
|
||||
import UnliftIO.Concurrent (forkIO, threadDelay)
|
||||
import qualified UnliftIO.Exception as E
|
||||
import UnliftIO.STM
|
||||
|
||||
runNtfSupervisor :: forall m. (MonadUnliftIO m, MonadReader Env m) => AgentClient -> m ()
|
||||
runNtfSupervisor c = do
|
||||
ns <- asks ntfSupervisor
|
||||
forever $ do
|
||||
cmd@(connId, _) <- atomically . readTBQueue $ ntfSubQ ns
|
||||
handleError connId . agentOperationBracket c AONtfNetwork $
|
||||
runExceptT (processNtfSub c cmd) >>= \case
|
||||
Left e -> notifyErr connId e
|
||||
Right _ -> return ()
|
||||
where
|
||||
handleError :: ConnId -> m () -> m ()
|
||||
handleError connId = E.handle $ \(e :: E.SomeException) -> do
|
||||
logError $ "runNtfSupervisor error " <> tshow e
|
||||
notifyErr connId e
|
||||
notifyErr connId e = notifyInternalError c connId $ "runNtfSupervisor error " <> show e
|
||||
|
||||
processNtfSub :: forall m. AgentMonad m => AgentClient -> (ConnId, NtfSupervisorCommand) -> m ()
|
||||
processNtfSub c (connId, cmd) = do
|
||||
logInfo $ "processNtfSub - connId = " <> tshow connId <> " - cmd = " <> tshow cmd
|
||||
case cmd of
|
||||
NSCCreate -> do
|
||||
(a, RcvQueue {server = smpServer, clientNtfCreds}) <- withStore c $ \db -> runExceptT $ do
|
||||
a <- liftIO $ getNtfSubscription db connId
|
||||
q <- ExceptT $ getRcvQueue db connId
|
||||
pure (a, q)
|
||||
logInfo $ "processNtfSub, NSCCreate - a = " <> tshow a
|
||||
case a of
|
||||
Nothing -> do
|
||||
withNtfServer c $ \ntfServer -> do
|
||||
case clientNtfCreds of
|
||||
Just ClientNtfCreds {notifierId} -> do
|
||||
let newSub = newNtfSubscription connId smpServer (Just notifierId) ntfServer NASKey
|
||||
ts <- liftIO getCurrentTime
|
||||
withStore' c $ \db -> createNtfSubscription db newSub (NtfSubNTFAction NSACreate) ts
|
||||
addNtfNTFWorker ntfServer
|
||||
Nothing -> do
|
||||
let newSub = newNtfSubscription connId smpServer Nothing ntfServer NASNew
|
||||
ts <- liftIO getCurrentTime
|
||||
withStore' c $ \db -> createNtfSubscription db newSub (NtfSubSMPAction NSASmpKey) ts
|
||||
addNtfSMPWorker smpServer
|
||||
(Just (sub@NtfSubscription {ntfSubStatus, ntfServer = subNtfServer}, action_)) -> do
|
||||
case action_ of
|
||||
-- action was set to NULL after worker internal error
|
||||
Nothing -> resetSubscription
|
||||
Just (action, _)
|
||||
-- subscription was marked for deletion / is being deleted
|
||||
| isDeleteNtfSubAction action -> do
|
||||
if ntfSubStatus == NASNew || ntfSubStatus == NASOff || ntfSubStatus == NASDeleted
|
||||
then resetSubscription
|
||||
else withNtfServer c $ \ntfServer -> do
|
||||
ts <- liftIO getCurrentTime
|
||||
withStore' c $ \db ->
|
||||
supervisorUpdateNtfSubscription db sub {ntfServer} (NtfSubNTFAction NSACreate) ts
|
||||
addNtfNTFWorker ntfServer
|
||||
| otherwise -> case action of
|
||||
NtfSubNTFAction _ -> addNtfNTFWorker subNtfServer
|
||||
NtfSubSMPAction _ -> addNtfSMPWorker smpServer
|
||||
where
|
||||
resetSubscription :: m ()
|
||||
resetSubscription =
|
||||
withNtfServer c $ \ntfServer -> do
|
||||
ts <- liftIO getCurrentTime
|
||||
withStore' c $ \db ->
|
||||
supervisorUpdateNtfSubscription db sub {ntfQueueId = Nothing, ntfServer, ntfSubId = Nothing, ntfSubStatus = NASNew} (NtfSubSMPAction NSASmpKey) ts
|
||||
addNtfSMPWorker smpServer
|
||||
NSCDelete -> do
|
||||
sub_ <- withStore' c $ \db -> do
|
||||
ts <- liftIO getCurrentTime
|
||||
supervisorUpdateNtfSubAction db connId (NtfSubNTFAction NSADelete) ts
|
||||
getNtfSubscription db connId
|
||||
logInfo $ "processNtfSub, NSCDelete - sub_ = " <> tshow sub_
|
||||
case sub_ of
|
||||
(Just (NtfSubscription {ntfServer}, _)) -> addNtfNTFWorker ntfServer
|
||||
_ -> pure () -- err "NSCDelete - no subscription"
|
||||
NSCSmpDelete -> do
|
||||
withStore' c (`getRcvQueue` connId) >>= \case
|
||||
Right rq@RcvQueue {server = smpServer} -> do
|
||||
logInfo $ "processNtfSub, NSCSmpDelete - rq = " <> tshow rq
|
||||
ts <- liftIO getCurrentTime
|
||||
withStore' c $ \db -> supervisorUpdateNtfSubAction db connId (NtfSubSMPAction NSASmpDelete) ts
|
||||
addNtfSMPWorker smpServer
|
||||
_ -> notifyInternalError c connId "NSCSmpDelete - no rcv queue"
|
||||
NSCNtfWorker ntfServer ->
|
||||
addNtfNTFWorker ntfServer
|
||||
NSCNtfSMPWorker smpServer ->
|
||||
addNtfSMPWorker smpServer
|
||||
where
|
||||
addNtfNTFWorker = addWorker ntfWorkers runNtfWorker
|
||||
addNtfSMPWorker = addWorker ntfSMPWorkers runNtfSMPWorker
|
||||
addWorker ::
|
||||
(NtfSupervisor -> TMap (ProtocolServer s) (TMVar (), Async ())) ->
|
||||
(AgentClient -> ProtocolServer s -> TMVar () -> m ()) ->
|
||||
ProtocolServer s ->
|
||||
m ()
|
||||
addWorker wsSel runWorker srv = do
|
||||
ws <- asks $ wsSel . ntfSupervisor
|
||||
atomically (TM.lookup srv ws) >>= \case
|
||||
Nothing -> do
|
||||
doWork <- newTMVarIO ()
|
||||
worker <- async $ runWorker c srv doWork `E.finally` atomically (TM.delete srv ws)
|
||||
atomically $ TM.insert srv (doWork, worker) ws
|
||||
Just (doWork, _) ->
|
||||
void . atomically $ tryPutTMVar doWork ()
|
||||
|
||||
withNtfServer :: AgentMonad m => AgentClient -> (NtfServer -> m ()) -> m ()
|
||||
withNtfServer c action = getNtfServer c >>= mapM_ action
|
||||
|
||||
runNtfWorker :: forall m. AgentMonad m => AgentClient -> NtfServer -> TMVar () -> m ()
|
||||
runNtfWorker c srv doWork = do
|
||||
delay <- asks $ ntfWorkerDelay . config
|
||||
forever $ do
|
||||
void . atomically $ readTMVar doWork
|
||||
agentOperationBracket c AONtfNetwork runNtfOperation
|
||||
threadDelay delay
|
||||
where
|
||||
runNtfOperation :: m ()
|
||||
runNtfOperation = do
|
||||
nextSub_ <- withStore' c (`getNextNtfSubNTFAction` srv)
|
||||
logInfo $ "runNtfWorker, nextSub_ " <> tshow nextSub_
|
||||
case nextSub_ of
|
||||
Nothing -> noWorkToDo
|
||||
Just a@(NtfSubscription {connId}, _, _) -> do
|
||||
ri <- asks $ reconnectInterval . config
|
||||
withRetryInterval ri $ \loop ->
|
||||
processAction a
|
||||
`catchError` retryOnError c "NtfWorker" loop (workerInternalError c connId . show)
|
||||
noWorkToDo = void . atomically $ tryTakeTMVar doWork
|
||||
processAction :: (NtfSubscription, NtfSubNTFAction, NtfActionTs) -> m ()
|
||||
processAction (sub@NtfSubscription {connId, smpServer, ntfSubId}, action, actionTs) = do
|
||||
ts <- liftIO getCurrentTime
|
||||
unlessM (rescheduleAction doWork ts actionTs) $
|
||||
case action of
|
||||
NSACreate ->
|
||||
getNtfToken >>= \case
|
||||
Just tkn@NtfToken {ntfTokenId = Just tknId, ntfTknStatus = NTActive, ntfMode = NMInstant} -> do
|
||||
RcvQueue {clientNtfCreds} <- withStore c (`getRcvQueue` connId)
|
||||
case clientNtfCreds of
|
||||
Just ClientNtfCreds {ntfPrivateKey, notifierId} -> do
|
||||
nSubId <- agentNtfCreateSubscription c tknId tkn (SMPQueueNtf smpServer notifierId) ntfPrivateKey
|
||||
-- TODO smaller retry until Active, less frequently (daily?) once Active
|
||||
let actionTs' = addUTCTime 30 ts
|
||||
withStore' c $ \db ->
|
||||
updateNtfSubscription db sub {ntfSubId = Just nSubId, ntfSubStatus = NASCreated NSNew} (NtfSubNTFAction NSACheck) actionTs'
|
||||
_ -> workerInternalError c connId "NSACreate - no notifier queue credentials"
|
||||
_ -> workerInternalError c connId "NSACreate - no active token"
|
||||
NSACheck ->
|
||||
getNtfToken >>= \case
|
||||
Just tkn ->
|
||||
case ntfSubId of
|
||||
Just nSubId ->
|
||||
agentNtfCheckSubscription c nSubId tkn >>= \case
|
||||
NSAuth -> do
|
||||
getNtfServer c >>= \case
|
||||
Just ntfServer -> do
|
||||
withStore' c $ \db ->
|
||||
updateNtfSubscription db sub {ntfServer, ntfQueueId = Nothing, ntfSubId = Nothing, ntfSubStatus = NASNew} (NtfSubSMPAction NSASmpKey) ts
|
||||
ns <- asks ntfSupervisor
|
||||
atomically $ writeTBQueue (ntfSubQ ns) (connId, NSCNtfSMPWorker smpServer)
|
||||
_ -> workerInternalError c connId "NSACheck - failed to reset subscription, notification server not configured"
|
||||
status -> updateSubNextCheck ts status
|
||||
Nothing -> workerInternalError c connId "NSACheck - no subscription ID"
|
||||
_ -> workerInternalError c connId "NSACheck - no active token"
|
||||
NSADelete -> case ntfSubId of
|
||||
Just nSubId ->
|
||||
(getNtfToken >>= \tkn -> forM_ tkn $ agentNtfDeleteSubscription c nSubId)
|
||||
`E.finally` carryOnWithDeletion
|
||||
Nothing -> carryOnWithDeletion
|
||||
where
|
||||
carryOnWithDeletion :: m ()
|
||||
carryOnWithDeletion = do
|
||||
withStore' c $ \db ->
|
||||
updateNtfSubscription db sub {ntfSubId = Nothing, ntfSubStatus = NASOff} (NtfSubSMPAction NSASmpDelete) ts
|
||||
ns <- asks ntfSupervisor
|
||||
atomically $ writeTBQueue (ntfSubQ ns) (connId, NSCNtfSMPWorker smpServer)
|
||||
where
|
||||
updateSubNextCheck ts toStatus = do
|
||||
checkInterval <- asks $ ntfSubCheckInterval . config
|
||||
let nextCheckTs = addUTCTime checkInterval ts
|
||||
updateSub (NASCreated toStatus) (NtfSubNTFAction NSACheck) nextCheckTs
|
||||
updateSub toStatus toAction actionTs' =
|
||||
withStore' c $ \db ->
|
||||
updateNtfSubscription db sub {ntfSubStatus = toStatus} toAction actionTs'
|
||||
|
||||
runNtfSMPWorker :: forall m. AgentMonad m => AgentClient -> SMPServer -> TMVar () -> m ()
|
||||
runNtfSMPWorker c srv doWork = do
|
||||
delay <- asks $ ntfSMPWorkerDelay . config
|
||||
forever $ do
|
||||
void . atomically $ readTMVar doWork
|
||||
agentOperationBracket c AONtfNetwork runNtfSMPOperation
|
||||
threadDelay delay
|
||||
where
|
||||
runNtfSMPOperation = do
|
||||
nextSub_ <- withStore' c (`getNextNtfSubSMPAction` srv)
|
||||
logInfo $ "runNtfSMPWorker, nextSub_ " <> tshow nextSub_
|
||||
case nextSub_ of
|
||||
Nothing -> noWorkToDo
|
||||
Just a@(NtfSubscription {connId}, _, _) -> do
|
||||
ri <- asks $ reconnectInterval . config
|
||||
withRetryInterval ri $ \loop ->
|
||||
processAction a
|
||||
`catchError` retryOnError c "NtfSMPWorker" loop (workerInternalError c connId . show)
|
||||
noWorkToDo = void . atomically $ tryTakeTMVar doWork
|
||||
processAction :: (NtfSubscription, NtfSubSMPAction, NtfActionTs) -> m ()
|
||||
processAction (sub@NtfSubscription {connId, ntfServer}, smpAction, actionTs) = do
|
||||
ts <- liftIO getCurrentTime
|
||||
unlessM (rescheduleAction doWork ts actionTs) $
|
||||
case smpAction of
|
||||
NSASmpKey ->
|
||||
getNtfToken >>= \case
|
||||
Just NtfToken {ntfTknStatus = NTActive, ntfMode = NMInstant} -> do
|
||||
rq <- withStore c (`getRcvQueue` connId)
|
||||
C.SignAlg a <- asks (cmdSignAlg . config)
|
||||
(ntfPublicKey, ntfPrivateKey) <- liftIO $ C.generateSignatureKeyPair a
|
||||
(rcvNtfPubDhKey, rcvNtfPrivDhKey) <- liftIO C.generateKeyPair'
|
||||
(notifierId, rcvNtfSrvPubDhKey) <- enableQueueNotifications c rq ntfPublicKey rcvNtfPubDhKey
|
||||
let rcvNtfDhSecret = C.dh' rcvNtfSrvPubDhKey rcvNtfPrivDhKey
|
||||
withStore' c $ \db -> do
|
||||
setRcvQueueNtfCreds db connId $ Just ClientNtfCreds {ntfPublicKey, ntfPrivateKey, notifierId, rcvNtfDhSecret}
|
||||
updateNtfSubscription db sub {ntfQueueId = Just notifierId, ntfSubStatus = NASKey} (NtfSubNTFAction NSACreate) ts
|
||||
ns <- asks ntfSupervisor
|
||||
atomically $ sendNtfSubCommand ns (connId, NSCNtfWorker ntfServer)
|
||||
_ -> workerInternalError c connId "NSASmpKey - no active token"
|
||||
NSASmpDelete -> do
|
||||
rq_ <- withStore' c $ \db -> do
|
||||
setRcvQueueNtfCreds db connId Nothing
|
||||
getRcvQueue db connId
|
||||
forM_ rq_ $ \rq -> disableQueueNotifications c rq
|
||||
withStore' c $ \db -> deleteNtfSubscription db connId
|
||||
|
||||
rescheduleAction :: AgentMonad m => TMVar () -> UTCTime -> UTCTime -> m Bool
|
||||
rescheduleAction doWork ts actionTs
|
||||
| actionTs <= ts = pure False
|
||||
| otherwise = do
|
||||
void . atomically $ tryTakeTMVar doWork
|
||||
void . forkIO $ do
|
||||
threadDelay $ diffInMicros actionTs ts
|
||||
void . atomically $ tryPutTMVar doWork ()
|
||||
pure True
|
||||
|
||||
fromPico :: Pico -> Integer
|
||||
fromPico (MkFixed i) = i
|
||||
|
||||
diffInMicros :: UTCTime -> UTCTime -> Int
|
||||
diffInMicros a b = (`div` 1000000) . fromInteger . fromPico . nominalDiffTimeToSeconds $ diffUTCTime a b
|
||||
|
||||
retryOnError :: AgentMonad m => AgentClient -> Text -> m () -> (AgentErrorType -> m ()) -> AgentErrorType -> m ()
|
||||
retryOnError c name loop done e = do
|
||||
logError $ name <> " error: " <> tshow e
|
||||
case e of
|
||||
BROKER NETWORK -> retryLoop
|
||||
BROKER TIMEOUT -> retryLoop
|
||||
_ -> done e
|
||||
where
|
||||
retryLoop = do
|
||||
atomically $ endAgentOperation c AONtfNetwork
|
||||
atomically $ beginAgentOperation c AONtfNetwork
|
||||
loop
|
||||
|
||||
workerInternalError :: AgentMonad m => AgentClient -> ConnId -> String -> m ()
|
||||
workerInternalError c connId internalErrStr = do
|
||||
withStore' c $ \db -> setNullNtfSubscriptionAction db connId
|
||||
notifyInternalError c connId internalErrStr
|
||||
|
||||
notifyInternalError :: (MonadUnliftIO m) => AgentClient -> ConnId -> String -> m ()
|
||||
notifyInternalError AgentClient {subQ} connId internalErrStr = atomically $ writeTBQueue subQ ("", connId, AP.ERR $ AP.INTERNAL internalErrStr)
|
||||
|
||||
getNtfToken :: AgentMonad m => m (Maybe NtfToken)
|
||||
getNtfToken = do
|
||||
tkn <- asks $ ntfTkn . ntfSupervisor
|
||||
readTVarIO tkn
|
||||
|
||||
nsUpdateToken :: NtfSupervisor -> NtfToken -> STM ()
|
||||
nsUpdateToken ns tkn = writeTVar (ntfTkn ns) $ Just tkn
|
||||
|
||||
nsRemoveNtfToken :: NtfSupervisor -> STM ()
|
||||
nsRemoveNtfToken ns = writeTVar (ntfTkn ns) Nothing
|
||||
|
||||
sendNtfSubCommand :: NtfSupervisor -> (ConnId, NtfSupervisorCommand) -> STM ()
|
||||
sendNtfSubCommand ns cmd =
|
||||
readTVar (ntfTkn ns)
|
||||
>>= mapM_
|
||||
( \NtfToken {ntfTknStatus, ntfMode} ->
|
||||
when (ntfTknStatus == NTActive && ntfMode == NMInstant) $
|
||||
writeTBQueue (ntfSubQ ns) cmd
|
||||
)
|
||||
|
||||
closeNtfSupervisor :: NtfSupervisor -> IO ()
|
||||
closeNtfSupervisor ns = do
|
||||
cancelNtfWorkers_ $ ntfWorkers ns
|
||||
cancelNtfWorkers_ $ ntfSMPWorkers ns
|
||||
|
||||
cancelNtfWorkers_ :: TMap (ProtocolServer s) (TMVar (), Async ()) -> IO ()
|
||||
cancelNtfWorkers_ wsVar = do
|
||||
ws <- atomically $ stateTVar wsVar (,M.empty)
|
||||
forM_ ws $ uninterruptibleCancel . snd
|
||||
|
||||
getNtfServer :: AgentMonad m => AgentClient -> m (Maybe NtfServer)
|
||||
getNtfServer c = do
|
||||
ntfServers <- readTVarIO $ ntfServers c
|
||||
case ntfServers of
|
||||
[] -> pure Nothing
|
||||
[srv] -> pure $ Just srv
|
||||
servers -> do
|
||||
gen <- asks randomServer
|
||||
atomically . stateTVar gen $
|
||||
first (Just . (servers !!)) . randomR (0, length servers - 1)
|
|
@ -7,6 +7,7 @@
|
|||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE PatternSynonyms #-}
|
||||
{-# LANGUAGE PolyKinds #-}
|
||||
{-# LANGUAGE RankNTypes #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
|
@ -31,8 +32,8 @@
|
|||
-- See https://github.com/simplex-chat/simplexmq/blob/master/protocol/agent-protocol.md
|
||||
module Simplex.Messaging.Agent.Protocol
|
||||
( -- * Protocol parameters
|
||||
smpAgentVersion,
|
||||
smpAgentVRange,
|
||||
currentSMPAgentVersion,
|
||||
supportedSMPAgentVRange,
|
||||
e2eEncConnInfoLength,
|
||||
e2eEncUserMsgLength,
|
||||
|
||||
|
@ -49,7 +50,8 @@ module Simplex.Messaging.Agent.Protocol
|
|||
AgentMessageType (..),
|
||||
APrivHeader (..),
|
||||
AMessage (..),
|
||||
SMPServer (..),
|
||||
SMPServer,
|
||||
pattern SMPServer,
|
||||
SrvLoc (..),
|
||||
SMPQueueUri (..),
|
||||
SMPQueueInfo (..),
|
||||
|
@ -80,6 +82,8 @@ module Simplex.Messaging.Agent.Protocol
|
|||
QueueStatus (..),
|
||||
ACorrId,
|
||||
AgentMsgId,
|
||||
NotificationsMode (..),
|
||||
NotificationInfo (..),
|
||||
|
||||
-- * Encode/decode
|
||||
serializeCommand,
|
||||
|
@ -108,7 +112,7 @@ import qualified Data.Attoparsec.ByteString.Char8 as A
|
|||
import Data.ByteString.Base64
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import Data.Composition ((.:))
|
||||
import Data.Composition ((.:), (.:.))
|
||||
import Data.Functor (($>))
|
||||
import Data.Int (Int64)
|
||||
import Data.Kind (Type)
|
||||
|
@ -116,9 +120,12 @@ import qualified Data.List.NonEmpty as L
|
|||
import Data.Maybe (isJust)
|
||||
import Data.Text (Text)
|
||||
import Data.Time.Clock (UTCTime)
|
||||
import Data.Time.Clock.System (SystemTime)
|
||||
import Data.Time.ISO8601
|
||||
import Data.Type.Equality
|
||||
import Data.Typeable ()
|
||||
import Database.SQLite.Simple.FromField
|
||||
import Database.SQLite.Simple.ToField
|
||||
import GHC.Generics (Generic)
|
||||
import Generic.Random (genericArbitraryU)
|
||||
import Simplex.Messaging.Agent.QueryString
|
||||
|
@ -130,10 +137,13 @@ import Simplex.Messaging.Parsers
|
|||
import Simplex.Messaging.Protocol
|
||||
( ErrorType,
|
||||
MsgBody,
|
||||
MsgFlags,
|
||||
MsgId,
|
||||
SMPServer (..),
|
||||
NMsgMeta,
|
||||
SMPServer,
|
||||
SndPublicVerifyKey,
|
||||
SrvLoc (..),
|
||||
pattern SMPServer,
|
||||
)
|
||||
import qualified Simplex.Messaging.Protocol as SMP
|
||||
import Simplex.Messaging.Transport (Transport (..), TransportError, serializeTransportError, transportErrorP)
|
||||
|
@ -143,11 +153,11 @@ import Test.QuickCheck (Arbitrary (..))
|
|||
import Text.Read
|
||||
import UnliftIO.Exception (Exception)
|
||||
|
||||
smpAgentVersion :: Version
|
||||
smpAgentVersion = 1
|
||||
currentSMPAgentVersion :: Version
|
||||
currentSMPAgentVersion = 2
|
||||
|
||||
smpAgentVRange :: VersionRange
|
||||
smpAgentVRange = mkVersionRange 1 smpAgentVersion
|
||||
supportedSMPAgentVRange :: VersionRange
|
||||
supportedSMPAgentVRange = mkVersionRange 1 currentSMPAgentVersion
|
||||
|
||||
-- it is shorter to allow all handshake headers,
|
||||
-- including E2E (double-ratchet) parameters and
|
||||
|
@ -207,23 +217,52 @@ data ACommand (p :: AParty) where
|
|||
CON :: ACommand Agent -- notification that connection is established
|
||||
SUB :: ACommand Client
|
||||
END :: ACommand Agent
|
||||
DOWN :: ACommand Agent
|
||||
UP :: ACommand Agent
|
||||
SEND :: MsgBody -> ACommand Client
|
||||
DOWN :: SMPServer -> [ConnId] -> ACommand Agent
|
||||
UP :: SMPServer -> [ConnId] -> ACommand Agent
|
||||
SEND :: MsgFlags -> MsgBody -> ACommand Client
|
||||
MID :: AgentMsgId -> ACommand Agent
|
||||
SENT :: AgentMsgId -> ACommand Agent
|
||||
MERR :: AgentMsgId -> AgentErrorType -> ACommand Agent
|
||||
MSG :: MsgMeta -> MsgBody -> ACommand Agent
|
||||
MSG :: MsgMeta -> MsgFlags -> MsgBody -> ACommand Agent
|
||||
ACK :: AgentMsgId -> ACommand Client
|
||||
OFF :: ACommand Client
|
||||
DEL :: ACommand Client
|
||||
OK :: ACommand Agent
|
||||
ERR :: AgentErrorType -> ACommand Agent
|
||||
SUSPENDED :: ACommand Agent
|
||||
|
||||
deriving instance Eq (ACommand p)
|
||||
|
||||
deriving instance Show (ACommand p)
|
||||
|
||||
data NotificationsMode = NMPeriodic | NMInstant
|
||||
deriving (Eq, Show)
|
||||
|
||||
instance StrEncoding NotificationsMode where
|
||||
strEncode = \case
|
||||
NMPeriodic -> "PERIODIC"
|
||||
NMInstant -> "INSTANT"
|
||||
strP =
|
||||
A.takeTill (== ' ') >>= \case
|
||||
"PERIODIC" -> pure NMPeriodic
|
||||
"INSTANT" -> pure NMInstant
|
||||
_ -> fail "bad NotificationsMode"
|
||||
|
||||
instance ToJSON NotificationsMode where
|
||||
toEncoding = strToJEncoding
|
||||
toJSON = strToJSON
|
||||
|
||||
instance ToField NotificationsMode where toField = toField . strEncode
|
||||
|
||||
instance FromField NotificationsMode where fromField = blobFieldDecoder $ parseAll strP
|
||||
|
||||
data NotificationInfo = NotificationInfo
|
||||
{ ntfConnId :: ConnId,
|
||||
ntfTs :: SystemTime,
|
||||
ntfMsgMeta :: Maybe NMsgMeta
|
||||
}
|
||||
deriving (Show)
|
||||
|
||||
data ConnectionMode = CMInvitation | CMContact
|
||||
deriving (Eq, Show)
|
||||
|
||||
|
@ -284,7 +323,9 @@ data SMPConfirmation = SMPConfirmation
|
|||
-- | sender's DH public key for simple per-queue e2e encryption
|
||||
e2ePubKey :: C.PublicKeyX25519,
|
||||
-- | sender's information to be associated with the connection, e.g. sender's profile information
|
||||
connInfo :: ConnInfo
|
||||
connInfo :: ConnInfo,
|
||||
-- | optional reply queues included in confirmation (added in agent protocol v2)
|
||||
smpReplyQueues :: [SMPQueueInfo]
|
||||
}
|
||||
deriving (Show)
|
||||
|
||||
|
@ -330,31 +371,40 @@ instance Encoding AgentMsgEnvelope where
|
|||
|
||||
-- SMP agent message formats (after double ratchet decryption,
|
||||
-- or in case of AgentInvitation - in plain text body)
|
||||
data AgentMessage = AgentConnInfo ConnInfo | AgentMessage APrivHeader AMessage
|
||||
data AgentMessage
|
||||
= AgentConnInfo ConnInfo
|
||||
| -- AgentConnInfoReply is only used in duplexHandshake mode (v2), allowing to include reply queue(s) in the initial confirmation.
|
||||
-- It makes REPLY message unnecessary.
|
||||
AgentConnInfoReply (L.NonEmpty SMPQueueInfo) ConnInfo
|
||||
| AgentMessage APrivHeader AMessage
|
||||
deriving (Show)
|
||||
|
||||
instance Encoding AgentMessage where
|
||||
smpEncode = \case
|
||||
AgentConnInfo cInfo -> smpEncode ('I', Tail cInfo)
|
||||
AgentConnInfoReply smpQueues cInfo -> smpEncode ('D', smpQueues, Tail cInfo) -- 'D' stands for "duplex"
|
||||
AgentMessage hdr aMsg -> smpEncode ('M', hdr, aMsg)
|
||||
smpP =
|
||||
smpP >>= \case
|
||||
'I' -> AgentConnInfo . unTail <$> smpP
|
||||
'D' -> AgentConnInfoReply <$> smpP <*> (unTail <$> smpP)
|
||||
'M' -> AgentMessage <$> smpP <*> smpP
|
||||
_ -> fail "bad AgentMessage"
|
||||
|
||||
data AgentMessageType = AM_CONN_INFO | AM_HELLO_ | AM_REPLY_ | AM_A_MSG_
|
||||
data AgentMessageType = AM_CONN_INFO | AM_CONN_INFO_REPLY | AM_HELLO_ | AM_REPLY_ | AM_A_MSG_
|
||||
deriving (Eq, Show)
|
||||
|
||||
instance Encoding AgentMessageType where
|
||||
smpEncode = \case
|
||||
AM_CONN_INFO -> "C"
|
||||
AM_CONN_INFO_REPLY -> "D"
|
||||
AM_HELLO_ -> "H"
|
||||
AM_REPLY_ -> "R"
|
||||
AM_A_MSG_ -> "M"
|
||||
smpP =
|
||||
A.anyChar >>= \case
|
||||
'C' -> pure AM_CONN_INFO
|
||||
'D' -> pure AM_CONN_INFO_REPLY
|
||||
'H' -> pure AM_HELLO_
|
||||
'R' -> pure AM_REPLY_
|
||||
'M' -> pure AM_A_MSG_
|
||||
|
@ -363,8 +413,14 @@ instance Encoding AgentMessageType where
|
|||
agentMessageType :: AgentMessage -> AgentMessageType
|
||||
agentMessageType = \case
|
||||
AgentConnInfo _ -> AM_CONN_INFO
|
||||
AgentConnInfoReply {} -> AM_CONN_INFO_REPLY
|
||||
AgentMessage _ aMsg -> case aMsg of
|
||||
-- HELLO is used both in v1 and in v2, but differently.
|
||||
-- - in v1 (and, possibly, in v2 for simplex connections) can be sent multiple times,
|
||||
-- until the queue is secured - the OK response from the server instead of initial AUTH errors confirms it.
|
||||
-- - in v2 duplexHandshake it is sent only once, when it is known that the queue was secured.
|
||||
HELLO -> AM_HELLO_
|
||||
-- REPLY is only used in v1
|
||||
REPLY _ -> AM_REPLY_
|
||||
A_MSG _ -> AM_A_MSG_
|
||||
|
||||
|
@ -539,7 +595,6 @@ data SMPQueueUri = SMPQueueUri
|
|||
}
|
||||
deriving (Eq, Show)
|
||||
|
||||
-- TODO change SMP queue URI format to include version range and allow unknown parameters
|
||||
instance StrEncoding SMPQueueUri where
|
||||
-- v1 uses short SMP queue URI format
|
||||
strEncode SMPQueueUri {smpServer = srv, senderId = qId, clientVRange = _vr, dhPublicKey = k} =
|
||||
|
@ -684,6 +739,8 @@ data AgentErrorType
|
|||
CONN {connErr :: ConnectionErrorType}
|
||||
| -- | SMP protocol errors forwarded to agent clients
|
||||
SMP {smpErr :: ErrorType}
|
||||
| -- | NTF protocol errors forwarded to agent clients
|
||||
NTF {ntfErr :: ErrorType}
|
||||
| -- | SMP server errors
|
||||
BROKER {brokerErr :: BrokerErrorType}
|
||||
| -- | errors of other agents
|
||||
|
@ -761,6 +818,8 @@ data SMPAgentError
|
|||
A_VERSION
|
||||
| -- | cannot decrypt message
|
||||
A_ENCRYPTION
|
||||
| -- | duplicate message - this error is detected by ratchet decryption - this message will be ignored and not shown
|
||||
A_DUPLICATE
|
||||
deriving (Eq, Generic, Read, Show, Exception)
|
||||
|
||||
instance ToJSON SMPAgentError where
|
||||
|
@ -772,6 +831,7 @@ instance StrEncoding AgentErrorType where
|
|||
"CMD " *> (CMD <$> parseRead1)
|
||||
<|> "CONN " *> (CONN <$> parseRead1)
|
||||
<|> "SMP " *> (SMP <$> strP)
|
||||
<|> "NTF " *> (NTF <$> strP)
|
||||
<|> "BROKER RESPONSE " *> (BROKER . RESPONSE <$> strP)
|
||||
<|> "BROKER TRANSPORT " *> (BROKER . TRANSPORT <$> transportErrorP)
|
||||
<|> "BROKER " *> (BROKER <$> parseRead1)
|
||||
|
@ -781,6 +841,7 @@ instance StrEncoding AgentErrorType where
|
|||
CMD e -> "CMD " <> bshow e
|
||||
CONN e -> "CONN " <> bshow e
|
||||
SMP e -> "SMP " <> strEncode e
|
||||
NTF e -> "NTF " <> strEncode e
|
||||
BROKER (RESPONSE e) -> "BROKER RESPONSE " <> strEncode e
|
||||
BROKER (TRANSPORT e) -> "BROKER TRANSPORT " <> serializeTransportError e
|
||||
BROKER e -> "BROKER " <> bshow e
|
||||
|
@ -811,8 +872,8 @@ commandP =
|
|||
<|> "INFO " *> infoCmd
|
||||
<|> "SUB" $> ACmd SClient SUB
|
||||
<|> "END" $> ACmd SAgent END
|
||||
<|> "DOWN" $> ACmd SAgent DOWN
|
||||
<|> "UP" $> ACmd SAgent UP
|
||||
<|> "DOWN " *> downsResp
|
||||
<|> "UP " *> upsResp
|
||||
<|> "SEND " *> sendCmd
|
||||
<|> "MID " *> msgIdResp
|
||||
<|> "SENT " *> sentResp
|
||||
|
@ -834,12 +895,15 @@ commandP =
|
|||
acptCmd = ACmd SClient .: ACPT <$> A.takeTill (== ' ') <* A.space <*> A.takeByteString
|
||||
rjctCmd = ACmd SClient . RJCT <$> A.takeByteString
|
||||
infoCmd = ACmd SAgent . INFO <$> A.takeByteString
|
||||
sendCmd = ACmd SClient . SEND <$> A.takeByteString
|
||||
downsResp = ACmd SAgent .: DOWN <$> strP_ <*> connections
|
||||
upsResp = ACmd SAgent .: UP <$> strP_ <*> connections
|
||||
sendCmd = ACmd SClient .: SEND <$> smpP <* A.space <*> A.takeByteString
|
||||
msgIdResp = ACmd SAgent . MID <$> A.decimal
|
||||
sentResp = ACmd SAgent . SENT <$> A.decimal
|
||||
msgErrResp = ACmd SAgent .: MERR <$> A.decimal <* A.space <*> strP
|
||||
message = ACmd SAgent .: MSG <$> msgMetaP <* A.space <*> A.takeByteString
|
||||
message = ACmd SAgent .:. MSG <$> msgMetaP <* A.space <*> smpP <* A.space <*> A.takeByteString
|
||||
ackCmd = ACmd SClient . ACK <$> A.decimal
|
||||
connections = strP `A.sepBy'` A.char ','
|
||||
msgMetaP = do
|
||||
integrity <- strP
|
||||
recipient <- " R=" *> partyMeta A.decimal
|
||||
|
@ -866,22 +930,25 @@ serializeCommand = \case
|
|||
INFO cInfo -> "INFO " <> serializeBinary cInfo
|
||||
SUB -> "SUB"
|
||||
END -> "END"
|
||||
DOWN -> "DOWN"
|
||||
UP -> "UP"
|
||||
SEND msgBody -> "SEND " <> serializeBinary msgBody
|
||||
DOWN srv conns -> B.unwords ["DOWN", strEncode srv, connections conns]
|
||||
UP srv conns -> B.unwords ["UP", strEncode srv, connections conns]
|
||||
SEND msgFlags msgBody -> "SEND " <> smpEncode msgFlags <> " " <> serializeBinary msgBody
|
||||
MID mId -> "MID " <> bshow mId
|
||||
SENT mId -> "SENT " <> bshow mId
|
||||
MERR mId e -> B.unwords ["MERR", bshow mId, strEncode e]
|
||||
MSG msgMeta msgBody -> B.unwords ["MSG", serializeMsgMeta msgMeta, serializeBinary msgBody]
|
||||
MSG msgMeta msgFlags msgBody -> B.unwords ["MSG", serializeMsgMeta msgMeta, smpEncode msgFlags, serializeBinary msgBody]
|
||||
ACK mId -> "ACK " <> bshow mId
|
||||
OFF -> "OFF"
|
||||
DEL -> "DEL"
|
||||
CON -> "CON"
|
||||
ERR e -> "ERR " <> strEncode e
|
||||
OK -> "OK"
|
||||
SUSPENDED -> "SUSPENDED"
|
||||
where
|
||||
showTs :: UTCTime -> ByteString
|
||||
showTs = B.pack . formatISO8601Millis
|
||||
connections :: [ConnId] -> ByteString
|
||||
connections = B.intercalate "," . map strEncode
|
||||
serializeMsgMeta :: MsgMeta -> ByteString
|
||||
serializeMsgMeta MsgMeta {integrity, recipient = (rmId, rTs), broker = (bmId, bTs), sndMsgId} =
|
||||
B.unwords
|
||||
|
@ -933,6 +1000,8 @@ tGet party h = liftIO (tGetRaw h) >>= tParseLoadBody
|
|||
ACPT {} -> Right cmd
|
||||
-- ERROR response does not always have connId
|
||||
ERR _ -> Right cmd
|
||||
DOWN {} -> Right cmd
|
||||
UP {} -> Right cmd
|
||||
-- other responses must have connId
|
||||
_
|
||||
| B.null connId -> Left $ CMD NO_CONN
|
||||
|
@ -940,8 +1009,8 @@ tGet party h = liftIO (tGetRaw h) >>= tParseLoadBody
|
|||
|
||||
cmdWithMsgBody :: ACommand p -> m (Either AgentErrorType (ACommand p))
|
||||
cmdWithMsgBody = \case
|
||||
SEND body -> SEND <$$> getBody body
|
||||
MSG msgMeta body -> MSG msgMeta <$$> getBody body
|
||||
SEND msgFlags body -> SEND msgFlags <$$> getBody body
|
||||
MSG msgMeta msgFlags body -> MSG msgMeta msgFlags <$$> getBody body
|
||||
JOIN qUri cInfo -> JOIN qUri <$$> getBody cInfo
|
||||
CONF confId cInfo -> CONF confId <$$> getBody cInfo
|
||||
LET confId cInfo -> LET confId <$$> getBody cInfo
|
||||
|
|
|
@ -31,17 +31,17 @@ import UnliftIO.STM
|
|||
-- | Runs an SMP agent as a TCP service using passed configuration.
|
||||
--
|
||||
-- See a full agent executable here: https://github.com/simplex-chat/simplexmq/blob/master/apps/smp-agent/Main.hs
|
||||
runSMPAgent :: (MonadRandom m, MonadUnliftIO m) => ATransport -> AgentConfig -> m ()
|
||||
runSMPAgent t cfg = do
|
||||
runSMPAgent :: (MonadRandom m, MonadUnliftIO m) => ATransport -> AgentConfig -> InitialAgentServers -> m ()
|
||||
runSMPAgent t cfg initServers = do
|
||||
started <- newEmptyTMVarIO
|
||||
runSMPAgentBlocking t started cfg
|
||||
runSMPAgentBlocking t started cfg initServers
|
||||
|
||||
-- | Runs an SMP agent as a TCP service using passed configuration with signalling.
|
||||
--
|
||||
-- This function uses passed TMVar to signal when the server is ready to accept TCP requests (True)
|
||||
-- and when it is disconnected from the TCP socket once the server thread is killed (False).
|
||||
runSMPAgentBlocking :: (MonadRandom m, MonadUnliftIO m) => ATransport -> TMVar Bool -> AgentConfig -> m ()
|
||||
runSMPAgentBlocking (ATransport t) started cfg@AgentConfig {tcpPort, caCertificateFile, certificateFile, privateKeyFile} = do
|
||||
runSMPAgentBlocking :: (MonadRandom m, MonadUnliftIO m) => ATransport -> TMVar Bool -> AgentConfig -> InitialAgentServers -> m ()
|
||||
runSMPAgentBlocking (ATransport t) started cfg@AgentConfig {tcpPort, caCertificateFile, certificateFile, privateKeyFile} initServers = do
|
||||
runReaderT (smpAgent t) =<< newSMPAgentEnv cfg
|
||||
where
|
||||
smpAgent :: forall c m'. (Transport c, MonadUnliftIO m', MonadReader Env m') => TProxy c -> m' ()
|
||||
|
@ -50,7 +50,7 @@ runSMPAgentBlocking (ATransport t) started cfg@AgentConfig {tcpPort, caCertifica
|
|||
tlsServerParams <- liftIO $ loadTLSServerParams caCertificateFile certificateFile privateKeyFile
|
||||
runTransportServer started tcpPort tlsServerParams $ \(h :: c) -> do
|
||||
liftIO . putLn h $ "Welcome to SMP agent v" <> B.pack simplexMQVersion
|
||||
c <- getAgentClient
|
||||
c <- getAgentClient initServers
|
||||
logConnection c True
|
||||
race_ (connectClient h c) (runAgentClient c)
|
||||
`E.finally` disconnectAgentClient c
|
||||
|
|
|
@ -9,9 +9,7 @@
|
|||
|
||||
module Simplex.Messaging.Agent.Store where
|
||||
|
||||
import Control.Concurrent.STM (TVar)
|
||||
import Control.Exception (Exception)
|
||||
import Crypto.Random (ChaChaDRG)
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import Data.Int (Int64)
|
||||
import Data.Kind (Type)
|
||||
|
@ -19,61 +17,21 @@ import Data.Time (UTCTime)
|
|||
import Data.Type.Equality
|
||||
import Simplex.Messaging.Agent.Protocol
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Crypto.Ratchet (RatchetX448, SkippedMsgDiff, SkippedMsgKeys)
|
||||
import Simplex.Messaging.Crypto.Ratchet (RatchetX448)
|
||||
import Simplex.Messaging.Protocol
|
||||
( MsgBody,
|
||||
MsgFlags,
|
||||
MsgId,
|
||||
NotifierId,
|
||||
NtfPrivateSignKey,
|
||||
NtfPublicVerifyKey,
|
||||
RcvDhSecret,
|
||||
RcvNtfDhSecret,
|
||||
RcvPrivateSignKey,
|
||||
SndPrivateSignKey,
|
||||
)
|
||||
import qualified Simplex.Messaging.Protocol as SMP
|
||||
|
||||
-- * Store management
|
||||
|
||||
-- | Store class type. Defines store access methods for implementations.
|
||||
class Monad m => MonadAgentStore s m where
|
||||
-- Queue and Connection management
|
||||
createRcvConn :: s -> TVar ChaChaDRG -> ConnData -> RcvQueue -> SConnectionMode c -> m ConnId
|
||||
createSndConn :: s -> TVar ChaChaDRG -> ConnData -> SndQueue -> m ConnId
|
||||
getConn :: s -> ConnId -> m SomeConn
|
||||
getRcvConn :: s -> SMPServer -> SMP.RecipientId -> m SomeConn
|
||||
deleteConn :: s -> ConnId -> m ()
|
||||
upgradeRcvConnToDuplex :: s -> ConnId -> SndQueue -> m ()
|
||||
upgradeSndConnToDuplex :: s -> ConnId -> RcvQueue -> m ()
|
||||
setRcvQueueStatus :: s -> RcvQueue -> QueueStatus -> m ()
|
||||
setRcvQueueConfirmedE2E :: s -> RcvQueue -> C.DhSecretX25519 -> m ()
|
||||
setSndQueueStatus :: s -> SndQueue -> QueueStatus -> m ()
|
||||
|
||||
-- Confirmations
|
||||
createConfirmation :: s -> TVar ChaChaDRG -> NewConfirmation -> m ConfirmationId
|
||||
acceptConfirmation :: s -> ConfirmationId -> ConnInfo -> m AcceptedConfirmation
|
||||
getAcceptedConfirmation :: s -> ConnId -> m AcceptedConfirmation
|
||||
removeConfirmations :: s -> ConnId -> m ()
|
||||
|
||||
-- Invitations - sent via Contact connections
|
||||
createInvitation :: s -> TVar ChaChaDRG -> NewInvitation -> m InvitationId
|
||||
getInvitation :: s -> InvitationId -> m Invitation
|
||||
acceptInvitation :: s -> InvitationId -> ConnInfo -> m ()
|
||||
deleteInvitation :: s -> ConnId -> InvitationId -> m ()
|
||||
|
||||
-- Msg management
|
||||
updateRcvIds :: s -> ConnId -> m (InternalId, InternalRcvId, PrevExternalSndId, PrevRcvMsgHash)
|
||||
createRcvMsg :: s -> ConnId -> RcvMsgData -> m ()
|
||||
updateSndIds :: s -> ConnId -> m (InternalId, InternalSndId, PrevSndMsgHash)
|
||||
createSndMsg :: s -> ConnId -> SndMsgData -> m ()
|
||||
getPendingMsgData :: s -> ConnId -> InternalId -> m (Maybe RcvQueue, (AgentMessageType, MsgBody, InternalTs))
|
||||
getPendingMsgs :: s -> ConnId -> m [InternalId]
|
||||
checkRcvMsg :: s -> ConnId -> InternalId -> m ()
|
||||
deleteMsg :: s -> ConnId -> InternalId -> m ()
|
||||
|
||||
-- Double ratchet persistence
|
||||
createRatchetX3dhKeys :: s -> ConnId -> C.PrivateKeyX448 -> C.PrivateKeyX448 -> m ()
|
||||
getRatchetX3dhKeys :: s -> ConnId -> m (C.PrivateKeyX448, C.PrivateKeyX448)
|
||||
createRatchet :: s -> ConnId -> RatchetX448 -> m ()
|
||||
getRatchet :: s -> ConnId -> m RatchetX448
|
||||
getSkippedMsgKeys :: s -> ConnId -> m SkippedMsgKeys
|
||||
updateRatchet :: s -> ConnId -> RatchetX448 -> SkippedMsgDiff -> m ()
|
||||
import Simplex.Messaging.Version
|
||||
|
||||
-- * Queue types
|
||||
|
||||
|
@ -93,7 +51,20 @@ data RcvQueue = RcvQueue
|
|||
-- | sender queue ID
|
||||
sndId :: Maybe SMP.SenderId,
|
||||
-- | queue status
|
||||
status :: QueueStatus
|
||||
status :: QueueStatus,
|
||||
-- | credentials used in context of notifications
|
||||
clientNtfCreds :: Maybe ClientNtfCreds
|
||||
}
|
||||
deriving (Eq, Show)
|
||||
|
||||
data ClientNtfCreds = ClientNtfCreds
|
||||
{ -- | key pair to be used by the notification server to sign transmissions
|
||||
ntfPublicKey :: NtfPublicVerifyKey,
|
||||
ntfPrivateKey :: NtfPrivateSignKey,
|
||||
-- | queue ID to be used by the notification server for NSUB command
|
||||
notifierId :: NotifierId,
|
||||
-- | shared DH secret used to encrypt/decrypt notification metadata (NMsgMeta) from server to recipient
|
||||
rcvNtfDhSecret :: RcvNtfDhSecret
|
||||
}
|
||||
deriving (Eq, Show)
|
||||
|
||||
|
@ -173,7 +144,11 @@ instance Eq SomeConn where
|
|||
|
||||
deriving instance Show SomeConn
|
||||
|
||||
newtype ConnData = ConnData {connId :: ConnId}
|
||||
data ConnData = ConnData
|
||||
{ connId :: ConnId,
|
||||
connAgentVersion :: Version,
|
||||
duplexHandshake :: Maybe Bool -- added in agent protocol v2
|
||||
}
|
||||
deriving (Eq, Show)
|
||||
|
||||
-- * Confirmation types
|
||||
|
@ -220,30 +195,42 @@ type PrevRcvMsgHash = MsgHash
|
|||
-- | Corresponds to `last_snd_msg_hash` in `connections` table
|
||||
type PrevSndMsgHash = MsgHash
|
||||
|
||||
-- * Message data containers - used on Msg creation to reduce number of parameters
|
||||
-- * Message data containers
|
||||
|
||||
data RcvMsgData = RcvMsgData
|
||||
{ msgMeta :: MsgMeta,
|
||||
msgType :: AgentMessageType,
|
||||
msgFlags :: MsgFlags,
|
||||
msgBody :: MsgBody,
|
||||
internalRcvId :: InternalRcvId,
|
||||
internalHash :: MsgHash,
|
||||
externalPrevSndHash :: MsgHash
|
||||
}
|
||||
|
||||
data RcvMsg = RcvMsg
|
||||
{ internalId :: InternalId,
|
||||
msgMeta :: MsgMeta,
|
||||
msgBody :: MsgBody,
|
||||
userAck :: Bool
|
||||
}
|
||||
|
||||
data SndMsgData = SndMsgData
|
||||
{ internalId :: InternalId,
|
||||
internalSndId :: InternalSndId,
|
||||
internalTs :: InternalTs,
|
||||
msgType :: AgentMessageType,
|
||||
msgFlags :: MsgFlags,
|
||||
msgBody :: MsgBody,
|
||||
internalHash :: MsgHash,
|
||||
prevMsgHash :: MsgHash
|
||||
}
|
||||
|
||||
data PendingMsg = PendingMsg
|
||||
{ connId :: ConnId,
|
||||
msgId :: InternalId
|
||||
data PendingMsgData = PendingMsgData
|
||||
{ msgId :: InternalId,
|
||||
msgType :: AgentMessageType,
|
||||
msgFlags :: MsgFlags,
|
||||
msgBody :: MsgBody,
|
||||
internalTs :: InternalTs
|
||||
}
|
||||
deriving (Show)
|
||||
|
||||
|
@ -311,4 +298,6 @@ data StoreError
|
|||
SEX3dhKeysNotFound
|
||||
| -- | Used in `getMsg` that is not implemented/used. TODO remove.
|
||||
SENotImplemented
|
||||
| -- | Used to wrap agent errors inside store operations to avoid race conditions
|
||||
SEAgentError AgentErrorType
|
||||
deriving (Eq, Show, Exception)
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -26,6 +26,9 @@ import Database.SQLite.Simple.QQ (sql)
|
|||
import qualified Database.SQLite3 as SQLite3
|
||||
import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220101_initial
|
||||
import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220301_snd_queue_keys
|
||||
import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220322_notifications
|
||||
import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220608_v2
|
||||
import Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220625_v2_ntf_mode
|
||||
|
||||
data Migration = Migration {name :: String, up :: Text}
|
||||
deriving (Show)
|
||||
|
@ -33,7 +36,10 @@ data Migration = Migration {name :: String, up :: Text}
|
|||
schemaMigrations :: [(String, Query)]
|
||||
schemaMigrations =
|
||||
[ ("20220101_initial", m20220101_initial),
|
||||
("20220301_snd_queue_keys", m20220301_snd_queue_keys)
|
||||
("20220301_snd_queue_keys", m20220301_snd_queue_keys),
|
||||
("20220322_notifications", m20220322_notifications),
|
||||
("20220607_v2", m20220608_v2),
|
||||
("m20220625_v2_ntf_mode", m20220625_v2_ntf_mode)
|
||||
]
|
||||
|
||||
-- | The list of migrations in ascending order by date
|
||||
|
|
|
@ -0,0 +1,39 @@
|
|||
{-# LANGUAGE QuasiQuotes #-}
|
||||
|
||||
module Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220322_notifications where
|
||||
|
||||
import Database.SQLite.Simple (Query)
|
||||
import Database.SQLite.Simple.QQ (sql)
|
||||
|
||||
m20220322_notifications :: Query
|
||||
m20220322_notifications =
|
||||
[sql|
|
||||
CREATE TABLE ntf_servers (
|
||||
ntf_host TEXT NOT NULL,
|
||||
ntf_port TEXT NOT NULL,
|
||||
ntf_key_hash BLOB NOT NULL,
|
||||
created_at TEXT NOT NULL DEFAULT (datetime('now')),
|
||||
updated_at TEXT NOT NULL DEFAULT (datetime('now')),
|
||||
PRIMARY KEY (ntf_host, ntf_port)
|
||||
) WITHOUT ROWID;
|
||||
|
||||
CREATE TABLE ntf_tokens (
|
||||
provider TEXT NOT NULL, -- apns
|
||||
device_token TEXT NOT NULL, -- ! this field is mislabeled and is actually saved as binary
|
||||
ntf_host TEXT NOT NULL,
|
||||
ntf_port TEXT NOT NULL,
|
||||
tkn_id BLOB, -- token ID assigned by notifications server
|
||||
tkn_pub_key BLOB NOT NULL, -- client's public key to verify token commands (used by server, for repeat registraions)
|
||||
tkn_priv_key BLOB NOT NULL, -- client's private key to sign token commands
|
||||
tkn_pub_dh_key BLOB NOT NULL, -- client's public DH key (for repeat registraions)
|
||||
tkn_priv_dh_key BLOB NOT NULL, -- client's private DH key (for repeat registraions)
|
||||
tkn_dh_secret BLOB, -- DH secret for e2e encryption of notifications
|
||||
tkn_status TEXT NOT NULL,
|
||||
tkn_action BLOB,
|
||||
created_at TEXT NOT NULL DEFAULT (datetime('now')),
|
||||
updated_at TEXT NOT NULL DEFAULT (datetime('now')), -- this is to check token status periodically to know when it was last checked
|
||||
PRIMARY KEY (provider, device_token, ntf_host, ntf_port),
|
||||
FOREIGN KEY (ntf_host, ntf_port) REFERENCES ntf_servers
|
||||
ON DELETE RESTRICT ON UPDATE CASCADE
|
||||
) WITHOUT ROWID;
|
||||
|]
|
|
@ -0,0 +1,50 @@
|
|||
{-# LANGUAGE QuasiQuotes #-}
|
||||
|
||||
module Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220608_v2 where
|
||||
|
||||
import Database.SQLite.Simple (Query)
|
||||
import Database.SQLite.Simple.QQ (sql)
|
||||
|
||||
m20220608_v2 :: Query
|
||||
m20220608_v2 =
|
||||
[sql|
|
||||
ALTER TABLE messages ADD COLUMN msg_flags TEXT NULL;
|
||||
|
||||
ALTER TABLE conn_confirmations ADD COLUMN smp_reply_queues BLOB NULL;
|
||||
|
||||
ALTER TABLE connections ADD COLUMN duplex_handshake INTEGER NULL DEFAULT 0;
|
||||
|
||||
ALTER TABLE rcv_messages ADD COLUMN user_ack INTEGER NULL DEFAULT 0;
|
||||
|
||||
ALTER TABLE rcv_queues ADD COLUMN ntf_public_key BLOB;
|
||||
|
||||
ALTER TABLE rcv_queues ADD COLUMN ntf_private_key BLOB;
|
||||
|
||||
ALTER TABLE rcv_queues ADD COLUMN ntf_id BLOB;
|
||||
|
||||
ALTER TABLE rcv_queues ADD COLUMN rcv_ntf_dh_secret BLOB;
|
||||
|
||||
CREATE UNIQUE INDEX idx_rcv_queues_ntf ON rcv_queues (host, port, ntf_id);
|
||||
|
||||
CREATE TABLE ntf_subscriptions (
|
||||
conn_id BLOB NOT NULL,
|
||||
smp_host TEXT NULL,
|
||||
smp_port TEXT NULL,
|
||||
smp_ntf_id BLOB,
|
||||
ntf_host TEXT NOT NULL,
|
||||
ntf_port TEXT NOT NULL,
|
||||
ntf_sub_id BLOB,
|
||||
ntf_sub_status TEXT NOT NULL, -- see NtfAgentSubStatus
|
||||
ntf_sub_action TEXT, -- if there is an action required on this subscription: NtfSubNTFAction
|
||||
ntf_sub_smp_action TEXT, -- action with SMP server: NtfSubSMPAction; only one of this and ntf_sub_action can (should) be not null in same record
|
||||
ntf_sub_action_ts TEXT, -- the earliest time for the action, e.g. checks can be scheduled every X hours
|
||||
updated_by_supervisor INTEGER NOT NULL DEFAULT 0, -- to be checked on updates by workers to not overwrite supervisor command (state still should be updated)
|
||||
created_at TEXT NOT NULL DEFAULT (datetime('now')),
|
||||
updated_at TEXT NOT NULL DEFAULT (datetime('now')),
|
||||
PRIMARY KEY (conn_id),
|
||||
FOREIGN KEY (smp_host, smp_port) REFERENCES servers (host, port)
|
||||
ON DELETE SET NULL ON UPDATE CASCADE,
|
||||
FOREIGN KEY (ntf_host, ntf_port) REFERENCES ntf_servers
|
||||
ON DELETE RESTRICT ON UPDATE CASCADE
|
||||
) WITHOUT ROWID;
|
||||
|]
|
|
@ -0,0 +1,14 @@
|
|||
{-# LANGUAGE QuasiQuotes #-}
|
||||
|
||||
module Simplex.Messaging.Agent.Store.SQLite.Migrations.M20220625_v2_ntf_mode where
|
||||
|
||||
import Database.SQLite.Simple (Query)
|
||||
import Database.SQLite.Simple.QQ (sql)
|
||||
|
||||
m20220625_v2_ntf_mode :: Query
|
||||
m20220625_v2_ntf_mode =
|
||||
[sql|
|
||||
ALTER TABLE ntf_tokens ADD COLUMN ntf_mode TEXT NULL;
|
||||
|
||||
DELETE FROM ntf_tokens;
|
||||
|]
|
|
@ -0,0 +1,194 @@
|
|||
CREATE TABLE migrations(
|
||||
name TEXT NOT NULL,
|
||||
ts TEXT NOT NULL,
|
||||
PRIMARY KEY(name)
|
||||
);
|
||||
CREATE TABLE servers(
|
||||
host TEXT NOT NULL,
|
||||
port TEXT NOT NULL,
|
||||
key_hash BLOB NOT NULL,
|
||||
PRIMARY KEY(host, port)
|
||||
) WITHOUT ROWID;
|
||||
CREATE TABLE connections(
|
||||
conn_id BLOB NOT NULL PRIMARY KEY,
|
||||
conn_mode TEXT NOT NULL,
|
||||
last_internal_msg_id INTEGER NOT NULL DEFAULT 0,
|
||||
last_internal_rcv_msg_id INTEGER NOT NULL DEFAULT 0,
|
||||
last_internal_snd_msg_id INTEGER NOT NULL DEFAULT 0,
|
||||
last_external_snd_msg_id INTEGER NOT NULL DEFAULT 0,
|
||||
last_rcv_msg_hash BLOB NOT NULL DEFAULT x'',
|
||||
last_snd_msg_hash BLOB NOT NULL DEFAULT x'',
|
||||
smp_agent_version INTEGER NOT NULL DEFAULT 1
|
||||
,
|
||||
duplex_handshake INTEGER NULL DEFAULT 0
|
||||
) WITHOUT ROWID;
|
||||
CREATE TABLE rcv_queues(
|
||||
host TEXT NOT NULL,
|
||||
port TEXT NOT NULL,
|
||||
rcv_id BLOB NOT NULL,
|
||||
conn_id BLOB NOT NULL REFERENCES connections ON DELETE CASCADE,
|
||||
rcv_private_key BLOB NOT NULL,
|
||||
rcv_dh_secret BLOB NOT NULL,
|
||||
e2e_priv_key BLOB NOT NULL,
|
||||
e2e_dh_secret BLOB,
|
||||
snd_id BLOB NOT NULL,
|
||||
snd_key BLOB,
|
||||
status TEXT NOT NULL,
|
||||
smp_server_version INTEGER NOT NULL DEFAULT 1,
|
||||
smp_client_version INTEGER,
|
||||
ntf_public_key BLOB,
|
||||
ntf_private_key BLOB,
|
||||
ntf_id BLOB,
|
||||
rcv_ntf_dh_secret BLOB,
|
||||
PRIMARY KEY(host, port, rcv_id),
|
||||
FOREIGN KEY(host, port) REFERENCES servers
|
||||
ON DELETE RESTRICT ON UPDATE CASCADE,
|
||||
UNIQUE(host, port, snd_id)
|
||||
) WITHOUT ROWID;
|
||||
CREATE TABLE snd_queues(
|
||||
host TEXT NOT NULL,
|
||||
port TEXT NOT NULL,
|
||||
snd_id BLOB NOT NULL,
|
||||
conn_id BLOB NOT NULL REFERENCES connections ON DELETE CASCADE,
|
||||
snd_private_key BLOB NOT NULL,
|
||||
e2e_dh_secret BLOB NOT NULL,
|
||||
status TEXT NOT NULL,
|
||||
smp_server_version INTEGER NOT NULL DEFAULT 1,
|
||||
smp_client_version INTEGER NOT NULL DEFAULT 1,
|
||||
snd_public_key BLOB,
|
||||
e2e_pub_key BLOB,
|
||||
PRIMARY KEY(host, port, snd_id),
|
||||
FOREIGN KEY(host, port) REFERENCES servers
|
||||
ON DELETE RESTRICT ON UPDATE CASCADE
|
||||
) WITHOUT ROWID;
|
||||
CREATE TABLE messages(
|
||||
conn_id BLOB NOT NULL REFERENCES connections(conn_id)
|
||||
ON DELETE CASCADE,
|
||||
internal_id INTEGER NOT NULL,
|
||||
internal_ts TEXT NOT NULL,
|
||||
internal_rcv_id INTEGER,
|
||||
internal_snd_id INTEGER,
|
||||
msg_type BLOB NOT NULL, --(H)ELLO,(R)EPLY,(D)ELETE. Should SMP confirmation be saved too?
|
||||
msg_body BLOB NOT NULL DEFAULT x'',
|
||||
msg_flags TEXT NULL,
|
||||
PRIMARY KEY(conn_id, internal_id),
|
||||
FOREIGN KEY(conn_id, internal_rcv_id) REFERENCES rcv_messages
|
||||
ON DELETE CASCADE DEFERRABLE INITIALLY DEFERRED,
|
||||
FOREIGN KEY(conn_id, internal_snd_id) REFERENCES snd_messages
|
||||
ON DELETE CASCADE DEFERRABLE INITIALLY DEFERRED
|
||||
) WITHOUT ROWID;
|
||||
CREATE TABLE rcv_messages(
|
||||
conn_id BLOB NOT NULL,
|
||||
internal_rcv_id INTEGER NOT NULL,
|
||||
internal_id INTEGER NOT NULL,
|
||||
external_snd_id INTEGER NOT NULL,
|
||||
broker_id BLOB NOT NULL,
|
||||
broker_ts TEXT NOT NULL,
|
||||
internal_hash BLOB NOT NULL,
|
||||
external_prev_snd_hash BLOB NOT NULL,
|
||||
integrity BLOB NOT NULL,
|
||||
user_ack INTEGER NULL DEFAULT 0,
|
||||
PRIMARY KEY(conn_id, internal_rcv_id),
|
||||
FOREIGN KEY(conn_id, internal_id) REFERENCES messages
|
||||
ON DELETE CASCADE
|
||||
) WITHOUT ROWID;
|
||||
CREATE TABLE snd_messages(
|
||||
conn_id BLOB NOT NULL,
|
||||
internal_snd_id INTEGER NOT NULL,
|
||||
internal_id INTEGER NOT NULL,
|
||||
internal_hash BLOB NOT NULL,
|
||||
previous_msg_hash BLOB NOT NULL DEFAULT x'',
|
||||
PRIMARY KEY(conn_id, internal_snd_id),
|
||||
FOREIGN KEY(conn_id, internal_id) REFERENCES messages
|
||||
ON DELETE CASCADE
|
||||
) WITHOUT ROWID;
|
||||
CREATE TABLE conn_confirmations(
|
||||
confirmation_id BLOB NOT NULL PRIMARY KEY,
|
||||
conn_id BLOB NOT NULL REFERENCES connections ON DELETE CASCADE,
|
||||
e2e_snd_pub_key BLOB NOT NULL, -- TODO per-queue key. Split?
|
||||
sender_key BLOB NOT NULL, -- TODO per-queue key. Split?
|
||||
ratchet_state BLOB NOT NULL,
|
||||
sender_conn_info BLOB NOT NULL,
|
||||
accepted INTEGER NOT NULL,
|
||||
own_conn_info BLOB,
|
||||
created_at TEXT NOT NULL DEFAULT(datetime('now'))
|
||||
,
|
||||
smp_reply_queues BLOB NULL
|
||||
) WITHOUT ROWID;
|
||||
CREATE TABLE conn_invitations(
|
||||
invitation_id BLOB NOT NULL PRIMARY KEY,
|
||||
contact_conn_id BLOB NOT NULL REFERENCES connections ON DELETE CASCADE,
|
||||
cr_invitation BLOB NOT NULL,
|
||||
recipient_conn_info BLOB NOT NULL,
|
||||
accepted INTEGER NOT NULL DEFAULT 0,
|
||||
own_conn_info BLOB,
|
||||
created_at TEXT NOT NULL DEFAULT(datetime('now'))
|
||||
) WITHOUT ROWID;
|
||||
CREATE TABLE ratchets(
|
||||
conn_id BLOB NOT NULL PRIMARY KEY REFERENCES connections
|
||||
ON DELETE CASCADE,
|
||||
-- x3dh keys are not saved on the sending side(the side accepting the connection)
|
||||
x3dh_priv_key_1 BLOB,
|
||||
x3dh_priv_key_2 BLOB,
|
||||
-- ratchet is initially empty on the receiving side(the side offering the connection)
|
||||
ratchet_state BLOB,
|
||||
e2e_version INTEGER NOT NULL DEFAULT 1
|
||||
) WITHOUT ROWID;
|
||||
CREATE TABLE skipped_messages(
|
||||
skipped_message_id INTEGER PRIMARY KEY,
|
||||
conn_id BLOB NOT NULL REFERENCES ratchets
|
||||
ON DELETE CASCADE,
|
||||
header_key BLOB NOT NULL,
|
||||
msg_n INTEGER NOT NULL,
|
||||
msg_key BLOB NOT NULL
|
||||
);
|
||||
CREATE TABLE ntf_servers(
|
||||
ntf_host TEXT NOT NULL,
|
||||
ntf_port TEXT NOT NULL,
|
||||
ntf_key_hash BLOB NOT NULL,
|
||||
created_at TEXT NOT NULL DEFAULT(datetime('now')),
|
||||
updated_at TEXT NOT NULL DEFAULT(datetime('now')),
|
||||
PRIMARY KEY(ntf_host, ntf_port)
|
||||
) WITHOUT ROWID;
|
||||
CREATE TABLE ntf_tokens(
|
||||
provider TEXT NOT NULL, -- apns
|
||||
device_token TEXT NOT NULL, -- ! this field is mislabeled and is actually saved as binary
|
||||
ntf_host TEXT NOT NULL,
|
||||
ntf_port TEXT NOT NULL,
|
||||
tkn_id BLOB, -- token ID assigned by notifications server
|
||||
tkn_pub_key BLOB NOT NULL, -- client's public key to verify token commands(used by server, for repeat registraions)
|
||||
tkn_priv_key BLOB NOT NULL, -- client's private key to sign token commands
|
||||
tkn_pub_dh_key BLOB NOT NULL, -- client's public DH key(for repeat registraions)
|
||||
tkn_priv_dh_key BLOB NOT NULL, -- client's private DH key(for repeat registraions)
|
||||
tkn_dh_secret BLOB, -- DH secret for e2e encryption of notifications
|
||||
tkn_status TEXT NOT NULL,
|
||||
tkn_action BLOB,
|
||||
created_at TEXT NOT NULL DEFAULT(datetime('now')),
|
||||
updated_at TEXT NOT NULL DEFAULT(datetime('now')),
|
||||
ntf_mode TEXT NULL, -- this is to check token status periodically to know when it was last checked
|
||||
PRIMARY KEY(provider, device_token, ntf_host, ntf_port),
|
||||
FOREIGN KEY(ntf_host, ntf_port) REFERENCES ntf_servers
|
||||
ON DELETE RESTRICT ON UPDATE CASCADE
|
||||
) WITHOUT ROWID;
|
||||
CREATE UNIQUE INDEX idx_rcv_queues_ntf ON rcv_queues(host, port, ntf_id);
|
||||
CREATE TABLE ntf_subscriptions(
|
||||
conn_id BLOB NOT NULL,
|
||||
smp_host TEXT NULL,
|
||||
smp_port TEXT NULL,
|
||||
smp_ntf_id BLOB,
|
||||
ntf_host TEXT NOT NULL,
|
||||
ntf_port TEXT NOT NULL,
|
||||
ntf_sub_id BLOB,
|
||||
ntf_sub_status TEXT NOT NULL, -- see NtfAgentSubStatus
|
||||
ntf_sub_action TEXT, -- if there is an action required on this subscription: NtfSubNTFAction
|
||||
ntf_sub_smp_action TEXT, -- action with SMP server: NtfSubSMPAction; only one of this and ntf_sub_action can(should) be not null in same record
|
||||
ntf_sub_action_ts TEXT, -- the earliest time for the action, e.g. checks can be scheduled every X hours
|
||||
updated_by_supervisor INTEGER NOT NULL DEFAULT 0, -- to be checked on updates by workers to not overwrite supervisor command(state still should be updated)
|
||||
created_at TEXT NOT NULL DEFAULT(datetime('now')),
|
||||
updated_at TEXT NOT NULL DEFAULT(datetime('now')),
|
||||
PRIMARY KEY(conn_id),
|
||||
FOREIGN KEY(smp_host, smp_port) REFERENCES servers(host, port)
|
||||
ON DELETE SET NULL ON UPDATE CASCADE,
|
||||
FOREIGN KEY(ntf_host, ntf_port) REFERENCES ntf_servers
|
||||
ON DELETE RESTRICT ON UPDATE CASCADE
|
||||
) WITHOUT ROWID;
|
|
@ -1,6 +1,7 @@
|
|||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE DeriveAnyClass #-}
|
||||
{-# LANGUAGE DuplicateRecordFields #-}
|
||||
{-# LANGUAGE FlexibleContexts #-}
|
||||
{-# LANGUAGE GADTs #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
|
@ -23,27 +24,30 @@
|
|||
-- See https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md
|
||||
module Simplex.Messaging.Client
|
||||
( -- * Connect (disconnect) client to (from) SMP server
|
||||
SMPClient (sessionId),
|
||||
getSMPClient,
|
||||
closeSMPClient,
|
||||
ProtocolClient (thVersion, sessionId),
|
||||
SMPClient,
|
||||
getProtocolClient,
|
||||
closeProtocolClient,
|
||||
|
||||
-- * SMP protocol command functions
|
||||
createSMPQueue,
|
||||
subscribeSMPQueue,
|
||||
getSMPMessage,
|
||||
subscribeSMPQueueNotifications,
|
||||
secureSMPQueue,
|
||||
enableSMPQueueNotifications,
|
||||
disableSMPQueueNotifications,
|
||||
sendSMPMessage,
|
||||
ackSMPMessage,
|
||||
suspendSMPQueue,
|
||||
deleteSMPQueue,
|
||||
sendSMPCommand,
|
||||
sendProtocolCommand,
|
||||
|
||||
-- * Supporting types and client configuration
|
||||
SMPClientError (..),
|
||||
SMPClientConfig (..),
|
||||
smpDefaultConfig,
|
||||
SMPServerTransmission,
|
||||
ProtocolClientError (..),
|
||||
ProtocolClientConfig (..),
|
||||
defaultClientConfig,
|
||||
ServerTransmission,
|
||||
)
|
||||
where
|
||||
|
||||
|
@ -52,7 +56,7 @@ import Control.Concurrent.Async
|
|||
import Control.Concurrent.STM
|
||||
import Control.Exception
|
||||
import Control.Monad
|
||||
import Control.Monad.Trans.Class
|
||||
import Control.Monad.IO.Class (liftIO)
|
||||
import Control.Monad.Trans.Except
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
|
@ -60,7 +64,7 @@ import Data.Maybe (fromMaybe)
|
|||
import Network.Socket (ServiceName)
|
||||
import Numeric.Natural
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Protocol
|
||||
import Simplex.Messaging.Protocol as SMP
|
||||
import Simplex.Messaging.TMap (TMap)
|
||||
import qualified Simplex.Messaging.TMap as TM
|
||||
import Simplex.Messaging.Transport
|
||||
|
@ -68,84 +72,90 @@ import Simplex.Messaging.Transport.Client (runTransportClient)
|
|||
import Simplex.Messaging.Transport.KeepAlive
|
||||
import Simplex.Messaging.Transport.WebSockets (WS)
|
||||
import Simplex.Messaging.Util (bshow, liftError, raceAny_)
|
||||
import Simplex.Messaging.Version
|
||||
import System.Timeout (timeout)
|
||||
|
||||
-- | 'SMPClient' is a handle used to send commands to a specific SMP server.
|
||||
--
|
||||
-- The only exported selector is blockSize that is negotiated
|
||||
-- with the server during the TCP transport handshake.
|
||||
--
|
||||
-- Use 'getSMPClient' to connect to an SMP server and create a client handle.
|
||||
data SMPClient = SMPClient
|
||||
data ProtocolClient msg = ProtocolClient
|
||||
{ action :: Async (),
|
||||
connected :: TVar Bool,
|
||||
sessionId :: SessionId,
|
||||
smpServer :: SMPServer,
|
||||
thVersion :: Version,
|
||||
protocolServer :: ProtoServer msg,
|
||||
tcpTimeout :: Int,
|
||||
clientCorrId :: TVar Natural,
|
||||
sentCommands :: TMap CorrId Request,
|
||||
sentCommands :: TMap CorrId (Request msg),
|
||||
sndQ :: TBQueue SentRawTransmission,
|
||||
rcvQ :: TBQueue (SignedTransmission BrokerMsg),
|
||||
msgQ :: TBQueue SMPServerTransmission
|
||||
rcvQ :: TBQueue (SignedTransmission msg),
|
||||
msgQ :: Maybe (TBQueue (ServerTransmission msg))
|
||||
}
|
||||
|
||||
-- | Type synonym for transmission from some SPM server queue.
|
||||
type SMPServerTransmission = (SMPServer, SessionId, RecipientId, BrokerMsg)
|
||||
type SMPClient = ProtocolClient SMP.BrokerMsg
|
||||
|
||||
-- | SMP client configuration.
|
||||
data SMPClientConfig = SMPClientConfig
|
||||
-- | Type synonym for transmission from some SPM server queue.
|
||||
type ServerTransmission msg = (ProtoServer msg, Version, SessionId, QueueId, msg)
|
||||
|
||||
-- | protocol client configuration.
|
||||
data ProtocolClientConfig = ProtocolClientConfig
|
||||
{ -- | size of TBQueue to use for server commands and responses
|
||||
qSize :: Natural,
|
||||
-- | default SMP server port if port is not specified in SMPServer
|
||||
-- | default server port if port is not specified in ProtocolServer
|
||||
defaultTransport :: (ServiceName, ATransport),
|
||||
-- | timeout of TCP commands (microseconds)
|
||||
tcpTimeout :: Int,
|
||||
-- | TCP keep-alive options, Nothing to skip enabling keep-alive
|
||||
tcpKeepAlive :: Maybe KeepAliveOpts,
|
||||
-- | period for SMP ping commands (microseconds)
|
||||
smpPing :: Int
|
||||
smpPing :: Int,
|
||||
-- | SMP client-server protocol version range
|
||||
smpServerVRange :: VersionRange
|
||||
}
|
||||
|
||||
-- | Default SMP client configuration.
|
||||
smpDefaultConfig :: SMPClientConfig
|
||||
smpDefaultConfig =
|
||||
SMPClientConfig
|
||||
-- | Default protocol client configuration.
|
||||
defaultClientConfig :: ProtocolClientConfig
|
||||
defaultClientConfig =
|
||||
ProtocolClientConfig
|
||||
{ qSize = 64,
|
||||
defaultTransport = ("5223", transport @TLS),
|
||||
defaultTransport = ("443", transport @TLS),
|
||||
tcpTimeout = 5_000_000,
|
||||
tcpKeepAlive = Just defaultKeepAliveOpts,
|
||||
smpPing = 600_000_000 -- 10min
|
||||
smpPing = 600_000_000, -- 10min
|
||||
smpServerVRange = supportedSMPServerVRange
|
||||
}
|
||||
|
||||
data Request = Request
|
||||
data Request msg = Request
|
||||
{ queueId :: QueueId,
|
||||
responseVar :: TMVar Response
|
||||
responseVar :: TMVar (Response msg)
|
||||
}
|
||||
|
||||
type Response = Either SMPClientError BrokerMsg
|
||||
type Response msg = Either ProtocolClientError msg
|
||||
|
||||
-- | Connects to 'SMPServer' using passed client configuration
|
||||
-- | Connects to 'ProtocolServer' using passed client configuration
|
||||
-- and queue for messages and notifications.
|
||||
--
|
||||
-- A single queue can be used for multiple 'SMPClient' instances,
|
||||
-- as 'SMPServerTransmission' includes server information.
|
||||
getSMPClient :: SMPServer -> SMPClientConfig -> TBQueue SMPServerTransmission -> IO () -> IO (Either SMPClientError SMPClient)
|
||||
getSMPClient smpServer cfg@SMPClientConfig {qSize, tcpTimeout, tcpKeepAlive, smpPing} msgQ disconnected =
|
||||
atomically mkSMPClient >>= runClient useTransport
|
||||
getProtocolClient :: forall msg. Protocol msg => ProtoServer msg -> ProtocolClientConfig -> Maybe (TBQueue (ServerTransmission msg)) -> IO () -> IO (Either ProtocolClientError (ProtocolClient msg))
|
||||
getProtocolClient protocolServer cfg@ProtocolClientConfig {qSize, tcpTimeout, tcpKeepAlive, smpPing, smpServerVRange} msgQ disconnected =
|
||||
(atomically mkProtocolClient >>= runClient useTransport)
|
||||
`catch` \(e :: IOException) -> pure . Left $ PCEIOError e
|
||||
where
|
||||
mkSMPClient :: STM SMPClient
|
||||
mkSMPClient = do
|
||||
mkProtocolClient :: STM (ProtocolClient msg)
|
||||
mkProtocolClient = do
|
||||
connected <- newTVar False
|
||||
clientCorrId <- newTVar 0
|
||||
sentCommands <- TM.empty
|
||||
sndQ <- newTBQueue qSize
|
||||
rcvQ <- newTBQueue qSize
|
||||
return
|
||||
SMPClient
|
||||
ProtocolClient
|
||||
{ action = undefined,
|
||||
sessionId = undefined,
|
||||
thVersion = undefined,
|
||||
connected,
|
||||
smpServer,
|
||||
protocolServer,
|
||||
tcpTimeout,
|
||||
clientCorrId,
|
||||
sentCommands,
|
||||
|
@ -154,50 +164,51 @@ getSMPClient smpServer cfg@SMPClientConfig {qSize, tcpTimeout, tcpKeepAlive, smp
|
|||
msgQ
|
||||
}
|
||||
|
||||
runClient :: (ServiceName, ATransport) -> SMPClient -> IO (Either SMPClientError SMPClient)
|
||||
runClient :: (ServiceName, ATransport) -> ProtocolClient msg -> IO (Either ProtocolClientError (ProtocolClient msg))
|
||||
runClient (port', ATransport t) c = do
|
||||
thVar <- newEmptyTMVarIO
|
||||
action <-
|
||||
async $
|
||||
runTransportClient (host smpServer) port' (keyHash smpServer) tcpKeepAlive (client t c thVar)
|
||||
`finally` atomically (putTMVar thVar $ Left SMPNetworkError)
|
||||
runTransportClient (host protocolServer) port' (Just $ keyHash protocolServer) tcpKeepAlive (client t c thVar)
|
||||
`finally` atomically (putTMVar thVar $ Left PCENetworkError)
|
||||
th_ <- tcpTimeout `timeout` atomically (takeTMVar thVar)
|
||||
pure $ case th_ of
|
||||
Just (Right THandle {sessionId}) -> Right c {action, sessionId}
|
||||
Just (Right THandle {sessionId, thVersion}) -> Right c {action, sessionId, thVersion}
|
||||
Just (Left e) -> Left e
|
||||
Nothing -> Left SMPNetworkError
|
||||
Nothing -> Left PCENetworkError
|
||||
|
||||
useTransport :: (ServiceName, ATransport)
|
||||
useTransport = case port smpServer of
|
||||
useTransport = case port protocolServer of
|
||||
"" -> defaultTransport cfg
|
||||
"80" -> ("80", transport @WS)
|
||||
p -> (p, transport @TLS)
|
||||
|
||||
client :: forall c. Transport c => TProxy c -> SMPClient -> TMVar (Either SMPClientError (THandle c)) -> c -> IO ()
|
||||
client :: forall c. Transport c => TProxy c -> ProtocolClient msg -> TMVar (Either ProtocolClientError (THandle c)) -> c -> IO ()
|
||||
client _ c thVar h =
|
||||
runExceptT (clientHandshake h $ keyHash smpServer) >>= \case
|
||||
Left e -> atomically . putTMVar thVar . Left $ SMPTransportError e
|
||||
Right th@THandle {sessionId} -> do
|
||||
runExceptT (protocolClientHandshake @msg h (keyHash protocolServer) smpServerVRange) >>= \case
|
||||
Left e -> atomically . putTMVar thVar . Left $ PCETransportError e
|
||||
Right th@THandle {sessionId, thVersion} -> do
|
||||
atomically $ do
|
||||
writeTVar (connected c) True
|
||||
putTMVar thVar $ Right th
|
||||
let c' = c {sessionId} :: SMPClient
|
||||
let c' = c {sessionId, thVersion} :: ProtocolClient msg
|
||||
-- TODO remove ping if 0 is passed (or Nothing?)
|
||||
raceAny_ [send c' th, process c', receive c' th, ping c']
|
||||
`finally` disconnected
|
||||
|
||||
send :: Transport c => SMPClient -> THandle c -> IO ()
|
||||
send SMPClient {sndQ} h = forever $ atomically (readTBQueue sndQ) >>= tPut h
|
||||
send :: Transport c => ProtocolClient msg -> THandle c -> IO ()
|
||||
send ProtocolClient {sndQ} h = forever $ atomically (readTBQueue sndQ) >>= tPut h
|
||||
|
||||
receive :: Transport c => SMPClient -> THandle c -> IO ()
|
||||
receive SMPClient {rcvQ} h = forever $ tGet h >>= atomically . writeTBQueue rcvQ
|
||||
receive :: Transport c => ProtocolClient msg -> THandle c -> IO ()
|
||||
receive ProtocolClient {rcvQ} h = forever $ tGet h >>= atomically . writeTBQueue rcvQ
|
||||
|
||||
ping :: SMPClient -> IO ()
|
||||
ping :: ProtocolClient msg -> IO ()
|
||||
ping c = forever $ do
|
||||
threadDelay smpPing
|
||||
runExceptT $ sendSMPCommand c Nothing "" PING
|
||||
runExceptT $ sendProtocolCommand c Nothing "" protocolPing
|
||||
|
||||
process :: SMPClient -> IO ()
|
||||
process SMPClient {sessionId, rcvQ, sentCommands} = forever $ do
|
||||
process :: ProtocolClient msg -> IO ()
|
||||
process c@ProtocolClient {rcvQ, sentCommands} = forever $ do
|
||||
(_, _, (corrId, qId, respOrErr)) <- atomically $ readTBQueue rcvQ
|
||||
if B.null $ bs corrId
|
||||
then sendMsg qId respOrErr
|
||||
|
@ -209,45 +220,48 @@ getSMPClient smpServer cfg@SMPClientConfig {qSize, tcpTimeout, tcpKeepAlive, smp
|
|||
putTMVar responseVar $
|
||||
if queueId == qId
|
||||
then case respOrErr of
|
||||
Left e -> Left $ SMPResponseError e
|
||||
Right (ERR e) -> Left $ SMPServerError e
|
||||
Right r -> Right r
|
||||
else Left SMPUnexpectedResponse
|
||||
Left e -> Left $ PCEResponseError e
|
||||
Right r -> case protocolError r of
|
||||
Just e -> Left $ PCEProtocolError e
|
||||
_ -> Right r
|
||||
else Left . PCEUnexpectedResponse $ bshow respOrErr
|
||||
where
|
||||
sendMsg :: QueueId -> Either ErrorType BrokerMsg -> IO ()
|
||||
sendMsg :: QueueId -> Either ErrorType msg -> IO ()
|
||||
sendMsg qId = \case
|
||||
Right cmd -> atomically $ writeTBQueue msgQ (smpServer, sessionId, qId, cmd)
|
||||
Right msg -> atomically $ mapM_ (`writeTBQueue` serverTransmission c qId msg) msgQ
|
||||
-- TODO send everything else to errQ and log in agent
|
||||
_ -> return ()
|
||||
|
||||
-- | Disconnects SMP client from the server and terminates client threads.
|
||||
closeSMPClient :: SMPClient -> IO ()
|
||||
closeSMPClient = uninterruptibleCancel . action
|
||||
-- | Disconnects client from the server and terminates client threads.
|
||||
closeProtocolClient :: ProtocolClient msg -> IO ()
|
||||
closeProtocolClient = uninterruptibleCancel . action
|
||||
|
||||
-- | SMP client error type.
|
||||
data SMPClientError
|
||||
data ProtocolClientError
|
||||
= -- | Correctly parsed SMP server ERR response.
|
||||
-- This error is forwarded to the agent client as `ERR SMP err`.
|
||||
SMPServerError ErrorType
|
||||
PCEProtocolError ErrorType
|
||||
| -- | Invalid server response that failed to parse.
|
||||
-- Forwarded to the agent client as `ERR BROKER RESPONSE`.
|
||||
SMPResponseError ErrorType
|
||||
PCEResponseError ErrorType
|
||||
| -- | Different response from what is expected to a certain SMP command,
|
||||
-- e.g. server should respond `IDS` or `ERR` to `NEW` command,
|
||||
-- other responses would result in this error.
|
||||
-- Forwarded to the agent client as `ERR BROKER UNEXPECTED`.
|
||||
SMPUnexpectedResponse
|
||||
PCEUnexpectedResponse ByteString
|
||||
| -- | Used for TCP connection and command response timeouts.
|
||||
-- Forwarded to the agent client as `ERR BROKER TIMEOUT`.
|
||||
SMPResponseTimeout
|
||||
PCEResponseTimeout
|
||||
| -- | Failure to establish TCP connection.
|
||||
-- Forwarded to the agent client as `ERR BROKER NETWORK`.
|
||||
SMPNetworkError
|
||||
PCENetworkError
|
||||
| -- | TCP transport handshake or some other transport error.
|
||||
-- Forwarded to the agent client as `ERR BROKER TRANSPORT e`.
|
||||
SMPTransportError TransportError
|
||||
PCETransportError TransportError
|
||||
| -- | Error when cryptographically "signing" the command.
|
||||
SMPSignatureError C.CryptoError
|
||||
PCESignatureError C.CryptoError
|
||||
| -- | IO Error
|
||||
PCEIOError IOException
|
||||
deriving (Eq, Show, Exception)
|
||||
|
||||
-- | Create a new SMP queue.
|
||||
|
@ -258,92 +272,119 @@ createSMPQueue ::
|
|||
RcvPrivateSignKey ->
|
||||
RcvPublicVerifyKey ->
|
||||
RcvPublicDhKey ->
|
||||
ExceptT SMPClientError IO QueueIdsKeys
|
||||
ExceptT ProtocolClientError IO QueueIdsKeys
|
||||
createSMPQueue c rpKey rKey dhKey =
|
||||
sendSMPCommand c (Just rpKey) "" (NEW rKey dhKey) >>= \case
|
||||
IDS qik -> pure qik
|
||||
_ -> throwE SMPUnexpectedResponse
|
||||
r -> throwE . PCEUnexpectedResponse $ bshow r
|
||||
|
||||
-- | Subscribe to the SMP queue.
|
||||
--
|
||||
-- https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#subscribe-to-queue
|
||||
subscribeSMPQueue :: SMPClient -> RcvPrivateSignKey -> RecipientId -> ExceptT SMPClientError IO ()
|
||||
subscribeSMPQueue c@SMPClient {smpServer, sessionId, msgQ} rpKey rId =
|
||||
subscribeSMPQueue :: SMPClient -> RcvPrivateSignKey -> RecipientId -> ExceptT ProtocolClientError IO ()
|
||||
subscribeSMPQueue c rpKey rId =
|
||||
sendSMPCommand c (Just rpKey) rId SUB >>= \case
|
||||
OK -> return ()
|
||||
cmd@MSG {} ->
|
||||
lift . atomically $ writeTBQueue msgQ (smpServer, sessionId, rId, cmd)
|
||||
_ -> throwE SMPUnexpectedResponse
|
||||
cmd@MSG {} -> writeSMPMessage c rId cmd
|
||||
r -> throwE . PCEUnexpectedResponse $ bshow r
|
||||
|
||||
writeSMPMessage :: SMPClient -> RecipientId -> BrokerMsg -> ExceptT ProtocolClientError IO ()
|
||||
writeSMPMessage c rId msg =
|
||||
liftIO . atomically $ mapM_ (`writeTBQueue` serverTransmission c rId msg) (msgQ c)
|
||||
|
||||
serverTransmission :: ProtocolClient msg -> RecipientId -> msg -> ServerTransmission msg
|
||||
serverTransmission ProtocolClient {protocolServer, thVersion, sessionId} entityId message =
|
||||
(protocolServer, thVersion, sessionId, entityId, message)
|
||||
|
||||
-- | Get message from SMP queue. The server returns ERR PROHIBITED if a client uses SUB and GET via the same transport connection for the same queue
|
||||
--
|
||||
-- https://github.covm/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#receive-a-message-from-the-queue
|
||||
getSMPMessage :: SMPClient -> RcvPrivateSignKey -> RecipientId -> ExceptT ProtocolClientError IO (Maybe RcvMessage)
|
||||
getSMPMessage c rpKey rId =
|
||||
sendSMPCommand c (Just rpKey) rId GET >>= \case
|
||||
OK -> pure Nothing
|
||||
cmd@(MSG msg) -> do
|
||||
writeSMPMessage c rId cmd
|
||||
pure $ Just msg
|
||||
r -> throwE . PCEUnexpectedResponse $ bshow r
|
||||
|
||||
-- | Subscribe to the SMP queue notifications.
|
||||
--
|
||||
-- https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#subscribe-to-queue-notifications
|
||||
subscribeSMPQueueNotifications :: SMPClient -> NtfPrivateSignKey -> NotifierId -> ExceptT SMPClientError IO ()
|
||||
subscribeSMPQueueNotifications :: SMPClient -> NtfPrivateSignKey -> NotifierId -> ExceptT ProtocolClientError IO ()
|
||||
subscribeSMPQueueNotifications = okSMPCommand NSUB
|
||||
|
||||
-- | Secure the SMP queue by adding a sender public key.
|
||||
--
|
||||
-- https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#secure-queue-command
|
||||
secureSMPQueue :: SMPClient -> RcvPrivateSignKey -> RecipientId -> SndPublicVerifyKey -> ExceptT SMPClientError IO ()
|
||||
secureSMPQueue :: SMPClient -> RcvPrivateSignKey -> RecipientId -> SndPublicVerifyKey -> ExceptT ProtocolClientError IO ()
|
||||
secureSMPQueue c rpKey rId senderKey = okSMPCommand (KEY senderKey) c rpKey rId
|
||||
|
||||
-- | Enable notifications for the queue for push notifications server.
|
||||
--
|
||||
-- https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#enable-notifications-command
|
||||
enableSMPQueueNotifications :: SMPClient -> RcvPrivateSignKey -> RecipientId -> NtfPublicVerifyKey -> ExceptT SMPClientError IO NotifierId
|
||||
enableSMPQueueNotifications c rpKey rId notifierKey =
|
||||
sendSMPCommand c (Just rpKey) rId (NKEY notifierKey) >>= \case
|
||||
NID nId -> pure nId
|
||||
_ -> throwE SMPUnexpectedResponse
|
||||
enableSMPQueueNotifications :: SMPClient -> RcvPrivateSignKey -> RecipientId -> NtfPublicVerifyKey -> RcvNtfPublicDhKey -> ExceptT ProtocolClientError IO (NotifierId, RcvNtfPublicDhKey)
|
||||
enableSMPQueueNotifications c rpKey rId notifierKey rcvNtfPublicDhKey =
|
||||
sendSMPCommand c (Just rpKey) rId (NKEY notifierKey rcvNtfPublicDhKey) >>= \case
|
||||
NID nId rcvNtfSrvPublicDhKey -> pure (nId, rcvNtfSrvPublicDhKey)
|
||||
r -> throwE . PCEUnexpectedResponse $ bshow r
|
||||
|
||||
-- | Disable notifications for the queue for push notifications server.
|
||||
--
|
||||
-- https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#disable-notifications-command
|
||||
disableSMPQueueNotifications :: SMPClient -> RcvPrivateSignKey -> RecipientId -> ExceptT ProtocolClientError IO ()
|
||||
disableSMPQueueNotifications = okSMPCommand NDEL
|
||||
|
||||
-- | Send SMP message.
|
||||
--
|
||||
-- https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#send-message
|
||||
sendSMPMessage :: SMPClient -> Maybe SndPrivateSignKey -> SenderId -> MsgBody -> ExceptT SMPClientError IO ()
|
||||
sendSMPMessage c spKey sId msg =
|
||||
sendSMPCommand c spKey sId (SEND msg) >>= \case
|
||||
sendSMPMessage :: SMPClient -> Maybe SndPrivateSignKey -> SenderId -> MsgFlags -> MsgBody -> ExceptT ProtocolClientError IO ()
|
||||
sendSMPMessage c spKey sId flags msg =
|
||||
sendSMPCommand c spKey sId (SEND flags msg) >>= \case
|
||||
OK -> pure ()
|
||||
_ -> throwE SMPUnexpectedResponse
|
||||
r -> throwE . PCEUnexpectedResponse $ bshow r
|
||||
|
||||
-- | Acknowledge message delivery (server deletes the message).
|
||||
--
|
||||
-- https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#acknowledge-message-delivery
|
||||
ackSMPMessage :: SMPClient -> RcvPrivateSignKey -> QueueId -> ExceptT SMPClientError IO ()
|
||||
ackSMPMessage c@SMPClient {smpServer, sessionId, msgQ} rpKey rId =
|
||||
sendSMPCommand c (Just rpKey) rId ACK >>= \case
|
||||
ackSMPMessage :: SMPClient -> RcvPrivateSignKey -> QueueId -> MsgId -> ExceptT ProtocolClientError IO ()
|
||||
ackSMPMessage c rpKey rId msgId =
|
||||
sendSMPCommand c (Just rpKey) rId (ACK msgId) >>= \case
|
||||
OK -> return ()
|
||||
cmd@MSG {} ->
|
||||
lift . atomically $ writeTBQueue msgQ (smpServer, sessionId, rId, cmd)
|
||||
_ -> throwE SMPUnexpectedResponse
|
||||
cmd@MSG {} -> writeSMPMessage c rId cmd
|
||||
r -> throwE . PCEUnexpectedResponse $ bshow r
|
||||
|
||||
-- | Irreversibly suspend SMP queue.
|
||||
-- The existing messages from the queue will still be delivered.
|
||||
--
|
||||
-- https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#suspend-queue
|
||||
suspendSMPQueue :: SMPClient -> RcvPrivateSignKey -> QueueId -> ExceptT SMPClientError IO ()
|
||||
suspendSMPQueue :: SMPClient -> RcvPrivateSignKey -> QueueId -> ExceptT ProtocolClientError IO ()
|
||||
suspendSMPQueue = okSMPCommand OFF
|
||||
|
||||
-- | Irreversibly delete SMP queue and all messages in it.
|
||||
--
|
||||
-- https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#delete-queue
|
||||
deleteSMPQueue :: SMPClient -> RcvPrivateSignKey -> QueueId -> ExceptT SMPClientError IO ()
|
||||
deleteSMPQueue :: SMPClient -> RcvPrivateSignKey -> QueueId -> ExceptT ProtocolClientError IO ()
|
||||
deleteSMPQueue = okSMPCommand DEL
|
||||
|
||||
okSMPCommand :: PartyI p => Command p -> SMPClient -> C.APrivateSignKey -> QueueId -> ExceptT SMPClientError IO ()
|
||||
okSMPCommand :: PartyI p => Command p -> SMPClient -> C.APrivateSignKey -> QueueId -> ExceptT ProtocolClientError IO ()
|
||||
okSMPCommand cmd c pKey qId =
|
||||
sendSMPCommand c (Just pKey) qId cmd >>= \case
|
||||
OK -> return ()
|
||||
_ -> throwE SMPUnexpectedResponse
|
||||
r -> throwE . PCEUnexpectedResponse $ bshow r
|
||||
|
||||
-- | Send SMP command
|
||||
-- TODO sign all requests (SEND of SMP confirmation would be signed with the same key that is passed to the recipient)
|
||||
sendSMPCommand :: PartyI p => SMPClient -> Maybe C.APrivateSignKey -> QueueId -> Command p -> ExceptT SMPClientError IO BrokerMsg
|
||||
sendSMPCommand SMPClient {sndQ, sentCommands, clientCorrId, sessionId, tcpTimeout} pKey qId cmd = do
|
||||
sendSMPCommand :: PartyI p => SMPClient -> Maybe C.APrivateSignKey -> QueueId -> Command p -> ExceptT ProtocolClientError IO BrokerMsg
|
||||
sendSMPCommand c pKey qId cmd = sendProtocolCommand c pKey qId (Cmd sParty cmd)
|
||||
|
||||
-- | Send Protocol command
|
||||
sendProtocolCommand :: forall msg. ProtocolEncoding (ProtoCommand msg) => ProtocolClient msg -> Maybe C.APrivateSignKey -> QueueId -> ProtoCommand msg -> ExceptT ProtocolClientError IO msg
|
||||
sendProtocolCommand ProtocolClient {sndQ, sentCommands, clientCorrId, sessionId, thVersion, tcpTimeout} pKey qId cmd = do
|
||||
corrId <- lift_ getNextCorrId
|
||||
t <- signTransmission $ encodeTransmission sessionId (corrId, qId, cmd)
|
||||
t <- signTransmission $ encodeTransmission thVersion sessionId (corrId, qId, cmd)
|
||||
ExceptT $ sendRecv corrId t
|
||||
where
|
||||
lift_ :: STM a -> ExceptT SMPClientError IO a
|
||||
lift_ :: STM a -> ExceptT ProtocolClientError IO a
|
||||
lift_ action = ExceptT $ Right <$> atomically action
|
||||
|
||||
getNextCorrId :: STM CorrId
|
||||
|
@ -351,20 +392,20 @@ sendSMPCommand SMPClient {sndQ, sentCommands, clientCorrId, sessionId, tcpTimeou
|
|||
i <- stateTVar clientCorrId $ \i -> (i, i + 1)
|
||||
pure . CorrId $ bshow i
|
||||
|
||||
signTransmission :: ByteString -> ExceptT SMPClientError IO SentRawTransmission
|
||||
signTransmission :: ByteString -> ExceptT ProtocolClientError IO SentRawTransmission
|
||||
signTransmission t = case pKey of
|
||||
Nothing -> return (Nothing, t)
|
||||
Just pk -> do
|
||||
sig <- liftError SMPSignatureError $ C.sign pk t
|
||||
sig <- liftError PCESignatureError $ C.sign pk t
|
||||
return (Just sig, t)
|
||||
|
||||
-- two separate "atomically" needed to avoid blocking
|
||||
sendRecv :: CorrId -> SentRawTransmission -> IO Response
|
||||
sendRecv :: CorrId -> SentRawTransmission -> IO (Response msg)
|
||||
sendRecv corrId t = atomically (send corrId t) >>= withTimeout . atomically . takeTMVar
|
||||
where
|
||||
withTimeout a = fromMaybe (Left SMPResponseTimeout) <$> timeout tcpTimeout a
|
||||
withTimeout a = fromMaybe (Left PCEResponseTimeout) <$> timeout tcpTimeout a
|
||||
|
||||
send :: CorrId -> SentRawTransmission -> STM (TMVar Response)
|
||||
send :: CorrId -> SentRawTransmission -> STM (TMVar (Response msg))
|
||||
send corrId t = do
|
||||
r <- newEmptyTMVar
|
||||
TM.insert corrId (Request qId r) sentCommands
|
||||
|
|
|
@ -0,0 +1,303 @@
|
|||
{-# LANGUAGE DuplicateRecordFields #-}
|
||||
{-# LANGUAGE InstanceSigs #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE RankNTypes #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-# LANGUAGE TypeApplications #-}
|
||||
{-# OPTIONS_GHC -fno-warn-orphans #-}
|
||||
|
||||
module Simplex.Messaging.Client.Agent where
|
||||
|
||||
import Control.Concurrent (forkIO)
|
||||
import Control.Concurrent.Async (Async, uninterruptibleCancel)
|
||||
import Control.Logger.Simple
|
||||
import Control.Monad.Except
|
||||
import Control.Monad.IO.Unlift
|
||||
import Control.Monad.Trans.Except
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import Data.Map.Strict (Map)
|
||||
import qualified Data.Map.Strict as M
|
||||
import Data.Set (Set)
|
||||
import Data.Text.Encoding
|
||||
import Numeric.Natural
|
||||
import Simplex.Messaging.Agent.RetryInterval
|
||||
import Simplex.Messaging.Client
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Protocol (BrokerMsg, ProtocolServer (..), QueueId, SMPServer)
|
||||
import Simplex.Messaging.TMap (TMap)
|
||||
import qualified Simplex.Messaging.TMap as TM
|
||||
import Simplex.Messaging.Transport
|
||||
import Simplex.Messaging.Util (catchAll_, tryE, unlessM, ($>>=))
|
||||
import System.Timeout (timeout)
|
||||
import UnliftIO (async, forConcurrently_)
|
||||
import UnliftIO.Exception (Exception)
|
||||
import qualified UnliftIO.Exception as E
|
||||
import UnliftIO.STM
|
||||
|
||||
type SMPClientVar = TMVar (Either ProtocolClientError SMPClient)
|
||||
|
||||
data SMPClientAgentEvent
|
||||
= CAConnected SMPServer
|
||||
| CADisconnected SMPServer (Set SMPSub)
|
||||
| CAReconnected SMPServer
|
||||
| CAResubscribed SMPServer SMPSub
|
||||
| CASubError SMPServer SMPSub ProtocolClientError
|
||||
|
||||
data SMPSubParty = SPRecipient | SPNotifier
|
||||
deriving (Eq, Ord, Show)
|
||||
|
||||
type SMPSub = (SMPSubParty, QueueId)
|
||||
|
||||
-- type SMPServerSub = (SMPServer, SMPSub)
|
||||
|
||||
data SMPClientAgentConfig = SMPClientAgentConfig
|
||||
{ smpCfg :: ProtocolClientConfig,
|
||||
reconnectInterval :: RetryInterval,
|
||||
msgQSize :: Natural,
|
||||
agentQSize :: Natural
|
||||
}
|
||||
|
||||
defaultSMPClientAgentConfig :: SMPClientAgentConfig
|
||||
defaultSMPClientAgentConfig =
|
||||
SMPClientAgentConfig
|
||||
{ smpCfg = defaultClientConfig {defaultTransport = ("5223", transport @TLS)},
|
||||
reconnectInterval =
|
||||
RetryInterval
|
||||
{ initialInterval = second,
|
||||
increaseAfter = 10 * second,
|
||||
maxInterval = 10 * second
|
||||
},
|
||||
msgQSize = 64,
|
||||
agentQSize = 64
|
||||
}
|
||||
where
|
||||
second = 1000000
|
||||
|
||||
data SMPClientAgent = SMPClientAgent
|
||||
{ agentCfg :: SMPClientAgentConfig,
|
||||
msgQ :: TBQueue (ServerTransmission BrokerMsg),
|
||||
agentQ :: TBQueue SMPClientAgentEvent,
|
||||
smpClients :: TMap SMPServer SMPClientVar,
|
||||
srvSubs :: TMap SMPServer (TMap SMPSub C.APrivateSignKey),
|
||||
pendingSrvSubs :: TMap SMPServer (TMap SMPSub C.APrivateSignKey),
|
||||
reconnections :: TVar [Async ()],
|
||||
asyncClients :: TVar [Async ()]
|
||||
}
|
||||
|
||||
newtype InternalException e = InternalException {unInternalException :: e}
|
||||
deriving (Eq, Show)
|
||||
|
||||
instance Exception e => Exception (InternalException e)
|
||||
|
||||
instance (MonadUnliftIO m, Exception e) => MonadUnliftIO (ExceptT e m) where
|
||||
withRunInIO :: ((forall a. ExceptT e m a -> IO a) -> IO b) -> ExceptT e m b
|
||||
withRunInIO exceptToIO =
|
||||
withExceptT unInternalException . ExceptT . E.try $
|
||||
withRunInIO $ \run ->
|
||||
exceptToIO $ run . (either (E.throwIO . InternalException) return <=< runExceptT)
|
||||
|
||||
newSMPClientAgent :: SMPClientAgentConfig -> STM SMPClientAgent
|
||||
newSMPClientAgent agentCfg@SMPClientAgentConfig {msgQSize, agentQSize} = do
|
||||
msgQ <- newTBQueue msgQSize
|
||||
agentQ <- newTBQueue agentQSize
|
||||
smpClients <- TM.empty
|
||||
srvSubs <- TM.empty
|
||||
pendingSrvSubs <- TM.empty
|
||||
reconnections <- newTVar []
|
||||
asyncClients <- newTVar []
|
||||
pure SMPClientAgent {agentCfg, msgQ, agentQ, smpClients, srvSubs, pendingSrvSubs, reconnections, asyncClients}
|
||||
|
||||
getSMPServerClient' :: SMPClientAgent -> SMPServer -> ExceptT ProtocolClientError IO SMPClient
|
||||
getSMPServerClient' ca@SMPClientAgent {agentCfg, smpClients, msgQ} srv =
|
||||
atomically getClientVar >>= either newSMPClient waitForSMPClient
|
||||
where
|
||||
getClientVar :: STM (Either SMPClientVar SMPClientVar)
|
||||
getClientVar = maybe (Left <$> newClientVar) (pure . Right) =<< TM.lookup srv smpClients
|
||||
|
||||
newClientVar :: STM SMPClientVar
|
||||
newClientVar = do
|
||||
smpVar <- newEmptyTMVar
|
||||
TM.insert srv smpVar smpClients
|
||||
pure smpVar
|
||||
|
||||
waitForSMPClient :: SMPClientVar -> ExceptT ProtocolClientError IO SMPClient
|
||||
waitForSMPClient smpVar = do
|
||||
let ProtocolClientConfig {tcpTimeout} = smpCfg agentCfg
|
||||
smpClient_ <- liftIO $ tcpTimeout `timeout` atomically (readTMVar smpVar)
|
||||
liftEither $ case smpClient_ of
|
||||
Just (Right smpClient) -> Right smpClient
|
||||
Just (Left e) -> Left e
|
||||
Nothing -> Left PCEResponseTimeout
|
||||
|
||||
newSMPClient :: SMPClientVar -> ExceptT ProtocolClientError IO SMPClient
|
||||
newSMPClient smpVar = tryConnectClient pure tryConnectAsync
|
||||
where
|
||||
tryConnectClient :: (SMPClient -> ExceptT ProtocolClientError IO a) -> ExceptT ProtocolClientError IO () -> ExceptT ProtocolClientError IO a
|
||||
tryConnectClient successAction retryAction =
|
||||
tryE connectClient >>= \r -> case r of
|
||||
Right smp -> do
|
||||
logInfo . decodeUtf8 $ "Agent connected to " <> showServer srv
|
||||
atomically $ putTMVar smpVar r
|
||||
successAction smp
|
||||
Left e -> do
|
||||
if e == PCENetworkError || e == PCEResponseTimeout
|
||||
then retryAction
|
||||
else atomically $ do
|
||||
putTMVar smpVar (Left e)
|
||||
TM.delete srv smpClients
|
||||
throwE e
|
||||
tryConnectAsync :: ExceptT ProtocolClientError IO ()
|
||||
tryConnectAsync = do
|
||||
a <- async connectAsync
|
||||
atomically $ modifyTVar' (asyncClients ca) (a :)
|
||||
connectAsync :: ExceptT ProtocolClientError IO ()
|
||||
connectAsync =
|
||||
withRetryInterval (reconnectInterval agentCfg) $ \loop ->
|
||||
void $ tryConnectClient (const reconnectClient) loop
|
||||
|
||||
connectClient :: ExceptT ProtocolClientError IO SMPClient
|
||||
connectClient = ExceptT $ getProtocolClient srv (smpCfg agentCfg) (Just msgQ) clientDisconnected
|
||||
|
||||
clientDisconnected :: IO ()
|
||||
clientDisconnected = do
|
||||
removeClientAndSubs >>= (`forM_` serverDown)
|
||||
logInfo . decodeUtf8 $ "Agent disconnected from " <> showServer srv
|
||||
|
||||
removeClientAndSubs :: IO (Maybe (Map SMPSub C.APrivateSignKey))
|
||||
removeClientAndSubs = atomically $ do
|
||||
TM.delete srv smpClients
|
||||
TM.lookupDelete srv (srvSubs ca) >>= mapM updateSubs
|
||||
where
|
||||
updateSubs sVar = do
|
||||
ss <- readTVar sVar
|
||||
addPendingSubs sVar ss
|
||||
pure ss
|
||||
|
||||
addPendingSubs sVar ss = do
|
||||
let ps = pendingSrvSubs ca
|
||||
TM.lookup srv ps >>= \case
|
||||
Just v -> TM.union ss v
|
||||
_ -> TM.insert srv sVar ps
|
||||
|
||||
serverDown :: Map SMPSub C.APrivateSignKey -> IO ()
|
||||
serverDown ss = unless (M.null ss) . void . runExceptT $ do
|
||||
notify . CADisconnected srv $ M.keysSet ss
|
||||
reconnectServer
|
||||
|
||||
reconnectServer :: ExceptT ProtocolClientError IO ()
|
||||
reconnectServer = do
|
||||
a <- async tryReconnectClient
|
||||
atomically $ modifyTVar' (reconnections ca) (a :)
|
||||
|
||||
tryReconnectClient :: ExceptT ProtocolClientError IO ()
|
||||
tryReconnectClient = do
|
||||
withRetryInterval (reconnectInterval agentCfg) $ \loop ->
|
||||
reconnectClient `catchE` const loop
|
||||
|
||||
reconnectClient :: ExceptT ProtocolClientError IO ()
|
||||
reconnectClient = do
|
||||
withSMP ca srv $ \smp -> do
|
||||
notify $ CAReconnected srv
|
||||
cs <- atomically $ mapM readTVar =<< TM.lookup srv (pendingSrvSubs ca)
|
||||
forConcurrently_ (maybe [] M.assocs cs) $ \sub@(s, _) ->
|
||||
unlessM (atomically $ hasSub (srvSubs ca) srv s) $
|
||||
subscribe_ smp sub `catchE` handleError s
|
||||
where
|
||||
subscribe_ :: SMPClient -> (SMPSub, C.APrivateSignKey) -> ExceptT ProtocolClientError IO ()
|
||||
subscribe_ smp sub@(s, _) = do
|
||||
smpSubscribe smp sub
|
||||
atomically $ addSubscription ca srv sub
|
||||
notify $ CAResubscribed srv s
|
||||
|
||||
handleError :: SMPSub -> ProtocolClientError -> ExceptT ProtocolClientError IO ()
|
||||
handleError s = \case
|
||||
e@PCEResponseTimeout -> throwE e
|
||||
e@PCENetworkError -> throwE e
|
||||
e -> do
|
||||
notify $ CASubError srv s e
|
||||
atomically $ removePendingSubscription ca srv s
|
||||
|
||||
notify :: SMPClientAgentEvent -> ExceptT ProtocolClientError IO ()
|
||||
notify evt = atomically $ writeTBQueue (agentQ ca) evt
|
||||
|
||||
closeSMPClientAgent :: MonadUnliftIO m => SMPClientAgent -> m ()
|
||||
closeSMPClientAgent c = liftIO $ do
|
||||
closeSMPServerClients c
|
||||
cancelActions $ reconnections c
|
||||
cancelActions $ asyncClients c
|
||||
|
||||
closeSMPServerClients :: SMPClientAgent -> IO ()
|
||||
closeSMPServerClients c = readTVarIO (smpClients c) >>= mapM_ (forkIO . closeClient)
|
||||
where
|
||||
closeClient smpVar =
|
||||
atomically (readTMVar smpVar) >>= \case
|
||||
Right smp -> closeProtocolClient smp `catchAll_` pure ()
|
||||
_ -> pure ()
|
||||
|
||||
cancelActions :: Foldable f => TVar (f (Async ())) -> IO ()
|
||||
cancelActions as = readTVarIO as >>= mapM_ uninterruptibleCancel
|
||||
|
||||
withSMP :: SMPClientAgent -> SMPServer -> (SMPClient -> ExceptT ProtocolClientError IO a) -> ExceptT ProtocolClientError IO a
|
||||
withSMP ca srv action = (getSMPServerClient' ca srv >>= action) `catchE` logSMPError
|
||||
where
|
||||
logSMPError :: ProtocolClientError -> ExceptT ProtocolClientError IO a
|
||||
logSMPError e = do
|
||||
liftIO $ putStrLn $ "SMP error (" <> show srv <> "): " <> show e
|
||||
throwE e
|
||||
|
||||
subscribeQueue :: SMPClientAgent -> SMPServer -> (SMPSub, C.APrivateSignKey) -> ExceptT ProtocolClientError IO ()
|
||||
subscribeQueue ca srv sub = do
|
||||
atomically $ addPendingSubscription ca srv sub
|
||||
withSMP ca srv $ \smp -> subscribe_ smp `catchE` handleError
|
||||
where
|
||||
subscribe_ smp = do
|
||||
smpSubscribe smp sub
|
||||
atomically $ addSubscription ca srv sub
|
||||
|
||||
handleError e = do
|
||||
atomically . when (e /= PCENetworkError && e /= PCEResponseTimeout) $
|
||||
removePendingSubscription ca srv $ fst sub
|
||||
throwE e
|
||||
|
||||
showServer :: SMPServer -> ByteString
|
||||
showServer ProtocolServer {host, port} =
|
||||
B.pack $ host <> if null port then "" else ':' : port
|
||||
|
||||
smpSubscribe :: SMPClient -> (SMPSub, C.APrivateSignKey) -> ExceptT ProtocolClientError IO ()
|
||||
smpSubscribe smp ((party, queueId), privKey) = subscribe_ smp privKey queueId
|
||||
where
|
||||
subscribe_ = case party of
|
||||
SPRecipient -> subscribeSMPQueue
|
||||
SPNotifier -> subscribeSMPQueueNotifications
|
||||
|
||||
addSubscription :: SMPClientAgent -> SMPServer -> (SMPSub, C.APrivateSignKey) -> STM ()
|
||||
addSubscription ca srv sub = do
|
||||
addSub_ (srvSubs ca) srv sub
|
||||
removePendingSubscription ca srv $ fst sub
|
||||
|
||||
addPendingSubscription :: SMPClientAgent -> SMPServer -> (SMPSub, C.APrivateSignKey) -> STM ()
|
||||
addPendingSubscription = addSub_ . pendingSrvSubs
|
||||
|
||||
addSub_ :: TMap SMPServer (TMap SMPSub C.APrivateSignKey) -> SMPServer -> (SMPSub, C.APrivateSignKey) -> STM ()
|
||||
addSub_ subs srv (s, key) =
|
||||
TM.lookup srv subs >>= \case
|
||||
Just m -> TM.insert s key m
|
||||
_ -> TM.singleton s key >>= \v -> TM.insert srv v subs
|
||||
|
||||
removeSubscription :: SMPClientAgent -> SMPServer -> SMPSub -> STM ()
|
||||
removeSubscription = removeSub_ . srvSubs
|
||||
|
||||
removePendingSubscription :: SMPClientAgent -> SMPServer -> SMPSub -> STM ()
|
||||
removePendingSubscription = removeSub_ . pendingSrvSubs
|
||||
|
||||
removeSub_ :: TMap SMPServer (TMap SMPSub C.APrivateSignKey) -> SMPServer -> SMPSub -> STM ()
|
||||
removeSub_ subs srv s = TM.lookup srv subs >>= mapM_ (TM.delete s)
|
||||
|
||||
getSubKey :: TMap SMPServer (TMap SMPSub C.APrivateSignKey) -> SMPServer -> SMPSub -> STM (Maybe C.APrivateSignKey)
|
||||
getSubKey subs srv s = TM.lookup srv subs $>>= TM.lookup s
|
||||
|
||||
hasSub :: TMap SMPServer (TMap SMPSub C.APrivateSignKey) -> SMPServer -> SMPSub -> STM Bool
|
||||
hasSub subs srv s = maybe (pure False) (TM.member s) =<< TM.lookup srv subs
|
|
@ -8,6 +8,7 @@
|
|||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE MultiParamTypeClasses #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE PatternSynonyms #-}
|
||||
{-# LANGUAGE RankNTypes #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-# LANGUAGE StandaloneDeriving #-}
|
||||
|
@ -52,6 +53,7 @@ module Simplex.Messaging.Crypto
|
|||
CryptoPublicKey (..),
|
||||
CryptoPrivateKey (..),
|
||||
KeyPair,
|
||||
ASignatureKeyPair,
|
||||
DhSecret (..),
|
||||
DhSecretX25519,
|
||||
ADhSecret (..),
|
||||
|
@ -100,9 +102,14 @@ module Simplex.Messaging.Crypto
|
|||
-- * NaCl crypto_box
|
||||
CbNonce (unCbNonce),
|
||||
cbEncrypt,
|
||||
cbEncryptMaxLenBS,
|
||||
cbDecrypt,
|
||||
cbNonce,
|
||||
randomCbNonce,
|
||||
pseudoRandomCbNonce,
|
||||
|
||||
-- * pseudo-random bytes
|
||||
pseudoRandomBytes,
|
||||
|
||||
-- * SHA256 hash
|
||||
sha256Hash,
|
||||
|
@ -113,9 +120,17 @@ module Simplex.Messaging.Crypto
|
|||
|
||||
-- * Cryptography error type
|
||||
CryptoError (..),
|
||||
|
||||
-- * Limited size ByteStrings
|
||||
MaxLenBS,
|
||||
pattern MaxLenBS,
|
||||
maxLenBS,
|
||||
unsafeMaxLenBS,
|
||||
appendMaxLenBS,
|
||||
)
|
||||
where
|
||||
|
||||
import Control.Concurrent.STM
|
||||
import Control.Exception (Exception)
|
||||
import Control.Monad.Except
|
||||
import Control.Monad.Trans.Except
|
||||
|
@ -129,7 +144,7 @@ import qualified Crypto.PubKey.Curve25519 as X25519
|
|||
import qualified Crypto.PubKey.Curve448 as X448
|
||||
import qualified Crypto.PubKey.Ed25519 as Ed25519
|
||||
import qualified Crypto.PubKey.Ed448 as Ed448
|
||||
import Crypto.Random (getRandomBytes)
|
||||
import Crypto.Random (ChaChaDRG, getRandomBytes, randomBytesGenerate)
|
||||
import Data.ASN1.BinaryEncoding
|
||||
import Data.ASN1.Encoding
|
||||
import Data.ASN1.Types
|
||||
|
@ -147,11 +162,11 @@ import Data.Constraint (Dict (..))
|
|||
import Data.Kind (Constraint, Type)
|
||||
import Data.String
|
||||
import Data.Type.Equality
|
||||
import Data.Typeable (Typeable)
|
||||
import Data.Typeable (Proxy (Proxy), Typeable)
|
||||
import Data.X509
|
||||
import Database.SQLite.Simple.FromField (FromField (..))
|
||||
import Database.SQLite.Simple.ToField (ToField (..))
|
||||
import GHC.TypeLits (ErrorMessage (..), TypeError)
|
||||
import GHC.TypeLits (ErrorMessage (..), KnownNat, Nat, TypeError, natVal, type (+))
|
||||
import Network.Transport.Internal (decodeWord16, encodeWord16)
|
||||
import Simplex.Messaging.Encoding
|
||||
import Simplex.Messaging.Encoding.String
|
||||
|
@ -244,6 +259,12 @@ deriving instance Eq (PrivateKey a)
|
|||
|
||||
deriving instance Show (PrivateKey a)
|
||||
|
||||
instance StrEncoding (PrivateKey X25519) where
|
||||
strEncode = strEncode . encodePrivKey
|
||||
{-# INLINE strEncode #-}
|
||||
strDecode = decodePrivKey
|
||||
{-# INLINE strDecode #-}
|
||||
|
||||
data APrivateKey
|
||||
= forall a.
|
||||
AlgorithmI a =>
|
||||
|
@ -284,6 +305,18 @@ instance Eq APrivateSignKey where
|
|||
|
||||
deriving instance Show APrivateSignKey
|
||||
|
||||
instance Encoding APrivateSignKey where
|
||||
smpEncode = smpEncode . encodePrivKey
|
||||
{-# INLINE smpEncode #-}
|
||||
smpDecode = decodePrivKey
|
||||
{-# INLINE smpDecode #-}
|
||||
|
||||
instance StrEncoding APrivateSignKey where
|
||||
strEncode = strEncode . encodePrivKey
|
||||
{-# INLINE strEncode #-}
|
||||
strDecode = decodePrivKey
|
||||
{-# INLINE strDecode #-}
|
||||
|
||||
data APublicVerifyKey
|
||||
= forall a.
|
||||
(AlgorithmI a, SignatureAlgorithm a) =>
|
||||
|
@ -666,7 +699,9 @@ data CryptoError
|
|||
CERatchetHeader
|
||||
| -- | too many skipped messages
|
||||
CERatchetTooManySkipped
|
||||
| -- | duplicate message number (or, possibly, skipped message that failed to decrypt?)
|
||||
| -- | earlier message number (or, possibly, skipped message that failed to decrypt?)
|
||||
CERatchetEarlierMessage
|
||||
| -- | duplicate message number
|
||||
CERatchetDuplicateMessage
|
||||
deriving (Eq, Show, Exception)
|
||||
|
||||
|
@ -783,6 +818,41 @@ unPad padded
|
|||
(lenWrd, rest) = B.splitAt 2 padded
|
||||
len = fromIntegral $ decodeWord16 lenWrd
|
||||
|
||||
newtype MaxLenBS (i :: Nat) = MLBS {unMaxLenBS :: ByteString}
|
||||
|
||||
pattern MaxLenBS :: ByteString -> MaxLenBS i
|
||||
pattern MaxLenBS s <- MLBS s
|
||||
|
||||
{-# COMPLETE MaxLenBS #-}
|
||||
|
||||
instance KnownNat i => Encoding (MaxLenBS i) where
|
||||
smpEncode (MLBS s) = smpEncode s
|
||||
smpP = first show . maxLenBS <$?> smpP
|
||||
|
||||
instance KnownNat i => StrEncoding (MaxLenBS i) where
|
||||
strEncode (MLBS s) = strEncode s
|
||||
strP = first show . maxLenBS <$?> strP
|
||||
|
||||
maxLenBS :: forall i. KnownNat i => ByteString -> Either CryptoError (MaxLenBS i)
|
||||
maxLenBS s
|
||||
| B.length s > maxLength @i = Left CryptoLargeMsgError
|
||||
| otherwise = Right $ MLBS s
|
||||
|
||||
unsafeMaxLenBS :: forall i. KnownNat i => ByteString -> MaxLenBS i
|
||||
unsafeMaxLenBS = MLBS
|
||||
|
||||
padMaxLenBS :: forall i. KnownNat i => MaxLenBS i -> MaxLenBS (i + 2)
|
||||
padMaxLenBS (MLBS msg) = MLBS $ encodeWord16 (fromIntegral len) <> msg <> B.replicate padLen '#'
|
||||
where
|
||||
len = B.length msg
|
||||
padLen = maxLength @i - len
|
||||
|
||||
appendMaxLenBS :: (KnownNat i, KnownNat j) => MaxLenBS i -> MaxLenBS j -> MaxLenBS (i + j)
|
||||
appendMaxLenBS (MLBS s1) (MLBS s2) = MLBS $ s1 <> s2
|
||||
|
||||
maxLength :: forall i. KnownNat i => Int
|
||||
maxLength = fromIntegral (natVal $ Proxy @i)
|
||||
|
||||
initAEAD :: forall c. AES.BlockCipher c => Key -> IV -> ExceptT CryptoError IO (AES.AEAD c)
|
||||
initAEAD (Key aesKey) (IV ivBytes) = do
|
||||
iv <- makeIV @c ivBytes
|
||||
|
@ -838,12 +908,17 @@ dh' (PublicKeyX448 k) (PrivateKeyX448 pk _) = DhSecretX448 $ X448.dh k pk
|
|||
|
||||
-- | NaCl @crypto_box@ encrypt with a shared DH secret and 192-bit nonce.
|
||||
cbEncrypt :: DhSecret X25519 -> CbNonce -> ByteString -> Int -> Either CryptoError ByteString
|
||||
cbEncrypt secret (CbNonce nonce) msg paddedLen = cryptoBox <$> pad msg paddedLen
|
||||
cbEncrypt secret (CbNonce nonce) msg paddedLen = cryptoBox secret nonce <$> pad msg paddedLen
|
||||
|
||||
-- | NaCl @crypto_box@ encrypt with a shared DH secret and 192-bit nonce.
|
||||
cbEncryptMaxLenBS :: KnownNat i => DhSecret X25519 -> CbNonce -> MaxLenBS i -> ByteString
|
||||
cbEncryptMaxLenBS secret (CbNonce nonce) = cryptoBox secret nonce . unMaxLenBS . padMaxLenBS
|
||||
|
||||
cryptoBox :: DhSecret 'X25519 -> ByteString -> ByteString -> ByteString
|
||||
cryptoBox secret nonce s = BA.convert tag <> c
|
||||
where
|
||||
cryptoBox s = BA.convert tag <> c
|
||||
where
|
||||
(rs, c) = xSalsa20 secret nonce s
|
||||
tag = Poly1305.auth rs c
|
||||
(rs, c) = xSalsa20 secret nonce s
|
||||
tag = Poly1305.auth rs c
|
||||
|
||||
-- | NaCl @crypto_box@ decrypt with a shared DH secret and 192-bit nonce.
|
||||
cbDecrypt :: DhSecret X25519 -> CbNonce -> ByteString -> Either CryptoError ByteString
|
||||
|
@ -857,7 +932,15 @@ cbDecrypt secret (CbNonce nonce) packet
|
|||
tag = Poly1305.auth rs c
|
||||
|
||||
newtype CbNonce = CbNonce {unCbNonce :: ByteString}
|
||||
deriving (Show)
|
||||
deriving (Eq, Show)
|
||||
|
||||
instance StrEncoding CbNonce where
|
||||
strEncode (CbNonce s) = strEncode s
|
||||
strP = cbNonce <$> strP
|
||||
|
||||
instance ToJSON CbNonce where
|
||||
toJSON = strToJSON
|
||||
toEncoding = strToJEncoding
|
||||
|
||||
cbNonce :: ByteString -> CbNonce
|
||||
cbNonce s
|
||||
|
@ -870,6 +953,16 @@ cbNonce s
|
|||
randomCbNonce :: IO CbNonce
|
||||
randomCbNonce = CbNonce <$> getRandomBytes 24
|
||||
|
||||
pseudoRandomCbNonce :: TVar ChaChaDRG -> STM CbNonce
|
||||
pseudoRandomCbNonce gVar = CbNonce <$> pseudoRandomBytes 24 gVar
|
||||
|
||||
pseudoRandomBytes :: Int -> TVar ChaChaDRG -> STM ByteString
|
||||
pseudoRandomBytes n gVar = do
|
||||
g <- readTVar gVar
|
||||
let (bytes, g') = randomBytesGenerate n g
|
||||
writeTVar gVar g'
|
||||
return bytes
|
||||
|
||||
instance Encoding CbNonce where
|
||||
smpEncode = unCbNonce
|
||||
smpP = CbNonce <$> A.take 24
|
||||
|
|
|
@ -416,7 +416,8 @@ rcDecrypt rc@Ratchet {rcRcv, rcAD = Str rcAD} rcMKSkipped msg' = do
|
|||
skipMessageKeys :: Word32 -> Ratchet a -> Either CryptoError (Ratchet a, SkippedMsgKeys)
|
||||
skipMessageKeys _ r@Ratchet {rcRcv = Nothing} = Right (r, M.empty)
|
||||
skipMessageKeys untilN r@Ratchet {rcRcv = Just rr@RcvRatchet {rcCKr, rcHKr}, rcNr}
|
||||
| rcNr > untilN = Left CERatchetDuplicateMessage
|
||||
| rcNr > untilN + 1 = Left CERatchetEarlierMessage
|
||||
| rcNr == untilN + 1 = Left CERatchetDuplicateMessage
|
||||
| rcNr + maxSkip < untilN = Left CERatchetTooManySkipped
|
||||
| rcNr == untilN = Right (r, M.empty)
|
||||
| otherwise =
|
||||
|
|
|
@ -47,66 +47,100 @@ class Encoding a where
|
|||
|
||||
instance Encoding Char where
|
||||
smpEncode = B.singleton
|
||||
{-# INLINE smpEncode #-}
|
||||
smpP = A.anyChar
|
||||
{-# INLINE smpP #-}
|
||||
|
||||
instance Encoding Bool where
|
||||
smpEncode = \case
|
||||
True -> "T"
|
||||
False -> "F"
|
||||
{-# INLINE smpEncode #-}
|
||||
smpP =
|
||||
smpP >>= \case
|
||||
'T' -> pure True
|
||||
'F' -> pure False
|
||||
_ -> fail "invalid Bool"
|
||||
{-# INLINE smpP #-}
|
||||
|
||||
instance Encoding Word16 where
|
||||
smpEncode = encodeWord16
|
||||
{-# INLINE smpEncode #-}
|
||||
smpP = decodeWord16 <$> A.take 2
|
||||
{-# INLINE smpP #-}
|
||||
|
||||
instance Encoding Word32 where
|
||||
smpEncode = encodeWord32
|
||||
{-# INLINE smpEncode #-}
|
||||
smpP = decodeWord32 <$> A.take 4
|
||||
{-# INLINE smpP #-}
|
||||
|
||||
instance Encoding Int64 where
|
||||
smpEncode i = w32 (i `shiftR` 32) <> w32 i
|
||||
{-# INLINE smpEncode #-}
|
||||
smpP = do
|
||||
l <- w32P
|
||||
r <- w32P
|
||||
pure $ (l `shiftL` 32) .|. r
|
||||
{-# INLINE smpP #-}
|
||||
|
||||
w32 :: Int64 -> ByteString
|
||||
w32 = smpEncode @Word32 . fromIntegral
|
||||
{-# INLINE w32 #-}
|
||||
|
||||
w32P :: Parser Int64
|
||||
w32P = fromIntegral <$> smpP @Word32
|
||||
{-# INLINE w32P #-}
|
||||
|
||||
-- ByteStrings are assumed no longer than 255 bytes
|
||||
instance Encoding ByteString where
|
||||
smpEncode s = B.cons (lenEncode $ B.length s) s
|
||||
{-# INLINE smpEncode #-}
|
||||
smpP = A.take =<< lenP
|
||||
{-# INLINE smpP #-}
|
||||
|
||||
lenEncode :: Int -> Char
|
||||
lenEncode = w2c . fromIntegral
|
||||
{-# INLINE lenEncode #-}
|
||||
|
||||
lenP :: Parser Int
|
||||
lenP = fromIntegral . c2w <$> A.anyChar
|
||||
{-# INLINE lenP #-}
|
||||
|
||||
instance Encoding a => Encoding (Maybe a) where
|
||||
smpEncode s = maybe "0" (("1" <>) . smpEncode) s
|
||||
{-# INLINE smpEncode #-}
|
||||
smpP =
|
||||
smpP >>= \case
|
||||
'0' -> pure Nothing
|
||||
'1' -> Just <$> smpP
|
||||
_ -> fail "invalid Maybe tag"
|
||||
{-# INLINE smpP #-}
|
||||
|
||||
newtype Tail = Tail {unTail :: ByteString}
|
||||
|
||||
instance Encoding Tail where
|
||||
smpEncode = unTail
|
||||
{-# INLINE smpEncode #-}
|
||||
smpP = Tail <$> A.takeByteString
|
||||
{-# INLINE smpP #-}
|
||||
|
||||
-- newtype for encoding/decoding ByteStrings over 255 bytes with 2-bytes length prefix
|
||||
newtype Large = Large {unLarge :: ByteString}
|
||||
|
||||
instance Encoding Large where
|
||||
smpEncode (Large s) = smpEncode @Word16 (fromIntegral $ B.length s) <> s
|
||||
{-# INLINE smpEncode #-}
|
||||
smpP = do
|
||||
len <- fromIntegral <$> smpP @Word16
|
||||
Large <$> A.take len
|
||||
{-# INLINE smpP #-}
|
||||
|
||||
instance Encoding SystemTime where
|
||||
smpEncode = smpEncode . systemSeconds
|
||||
{-# INLINE smpEncode #-}
|
||||
smpP = MkSystemTime <$> smpP <*> pure 0
|
||||
{-# INLINE smpP #-}
|
||||
|
||||
-- lists encode/parse as a sequence of items prefixed with list length (as 1 byte)
|
||||
smpEncodeList :: Encoding a => [a] -> ByteString
|
||||
|
@ -117,7 +151,9 @@ smpListP = (`A.count` smpP) =<< lenP
|
|||
|
||||
instance Encoding String where
|
||||
smpEncode = smpEncode . B.pack
|
||||
{-# INLINE smpEncode #-}
|
||||
smpP = B.unpack <$> smpP
|
||||
{-# INLINE smpP #-}
|
||||
|
||||
instance Encoding a => Encoding (L.NonEmpty a) where
|
||||
smpEncode = smpEncodeList . L.toList
|
||||
|
@ -128,16 +164,42 @@ instance Encoding a => Encoding (L.NonEmpty a) where
|
|||
|
||||
instance (Encoding a, Encoding b) => Encoding (a, b) where
|
||||
smpEncode (a, b) = smpEncode a <> smpEncode b
|
||||
{-# INLINE smpEncode #-}
|
||||
smpP = (,) <$> smpP <*> smpP
|
||||
{-# INLINE smpP #-}
|
||||
|
||||
instance (Encoding a, Encoding b, Encoding c) => Encoding (a, b, c) where
|
||||
smpEncode (a, b, c) = smpEncode a <> smpEncode b <> smpEncode c
|
||||
{-# INLINE smpEncode #-}
|
||||
smpP = (,,) <$> smpP <*> smpP <*> smpP
|
||||
{-# INLINE smpP #-}
|
||||
|
||||
instance (Encoding a, Encoding b, Encoding c, Encoding d) => Encoding (a, b, c, d) where
|
||||
smpEncode (a, b, c, d) = smpEncode a <> smpEncode b <> smpEncode c <> smpEncode d
|
||||
{-# INLINE smpEncode #-}
|
||||
smpP = (,,,) <$> smpP <*> smpP <*> smpP <*> smpP
|
||||
{-# INLINE smpP #-}
|
||||
|
||||
instance (Encoding a, Encoding b, Encoding c, Encoding d, Encoding e) => Encoding (a, b, c, d, e) where
|
||||
smpEncode (a, b, c, d, e) = smpEncode a <> smpEncode b <> smpEncode c <> smpEncode d <> smpEncode e
|
||||
{-# INLINE smpEncode #-}
|
||||
smpP = (,,,,) <$> smpP <*> smpP <*> smpP <*> smpP <*> smpP
|
||||
{-# INLINE smpP #-}
|
||||
|
||||
instance (Encoding a, Encoding b, Encoding c, Encoding d, Encoding e, Encoding f) => Encoding (a, b, c, d, e, f) where
|
||||
smpEncode (a, b, c, d, e, f) = smpEncode a <> smpEncode b <> smpEncode c <> smpEncode d <> smpEncode e <> smpEncode f
|
||||
{-# INLINE smpEncode #-}
|
||||
smpP = (,,,,,) <$> smpP <*> smpP <*> smpP <*> smpP <*> smpP <*> smpP
|
||||
{-# INLINE smpP #-}
|
||||
|
||||
instance (Encoding a, Encoding b, Encoding c, Encoding d, Encoding e, Encoding f, Encoding g) => Encoding (a, b, c, d, e, f, g) where
|
||||
smpEncode (a, b, c, d, e, f, g) = smpEncode a <> smpEncode b <> smpEncode c <> smpEncode d <> smpEncode e <> smpEncode f <> smpEncode g
|
||||
{-# INLINE smpEncode #-}
|
||||
smpP = (,,,,,,) <$> smpP <*> smpP <*> smpP <*> smpP <*> smpP <*> smpP <*> smpP
|
||||
{-# INLINE smpP #-}
|
||||
|
||||
instance (Encoding a, Encoding b, Encoding c, Encoding d, Encoding e, Encoding f, Encoding g, Encoding h) => Encoding (a, b, c, d, e, f, g, h) where
|
||||
smpEncode (a, b, c, d, e, f, g, h) = smpEncode a <> smpEncode b <> smpEncode c <> smpEncode d <> smpEncode e <> smpEncode f <> smpEncode g <> smpEncode h
|
||||
{-# INLINE smpEncode #-}
|
||||
smpP = (,,,,,,,) <$> smpP <*> smpP <*> smpP <*> smpP <*> smpP <*> smpP <*> smpP <*> smpP
|
||||
{-# INLINE smpP #-}
|
||||
|
|
|
@ -2,7 +2,8 @@
|
|||
{-# LANGUAGE OverloadedStrings #-}
|
||||
|
||||
module Simplex.Messaging.Encoding.String
|
||||
( StrEncoding (..),
|
||||
( TextEncoding (..),
|
||||
StrEncoding (..),
|
||||
Str (..),
|
||||
strP_,
|
||||
strToJSON,
|
||||
|
@ -25,12 +26,24 @@ import qualified Data.ByteString.Base64.URL as U
|
|||
import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import Data.Char (isAlphaNum)
|
||||
import Data.Int (Int64)
|
||||
import qualified Data.List.NonEmpty as L
|
||||
import Data.Set (Set)
|
||||
import qualified Data.Set as S
|
||||
import Data.Text (Text)
|
||||
import Data.Text.Encoding (decodeLatin1, encodeUtf8)
|
||||
import Data.Time.Clock (UTCTime)
|
||||
import Data.Time.Clock.System (SystemTime (..))
|
||||
import Data.Time.Format.ISO8601
|
||||
import Data.Word (Word16)
|
||||
import Simplex.Messaging.Encoding
|
||||
import Simplex.Messaging.Parsers (parseAll)
|
||||
import Simplex.Messaging.Util ((<$?>))
|
||||
|
||||
class TextEncoding a where
|
||||
textEncode :: a -> Text
|
||||
textDecode :: Text -> Maybe a
|
||||
|
||||
-- | Serializing human-readable and (where possible) URI-friendly strings for SMP and SMP agent protocols
|
||||
class StrEncoding a where
|
||||
{-# MINIMAL strEncode, (strDecode | strP) #-}
|
||||
|
@ -70,11 +83,47 @@ instance FromJSON Str where
|
|||
|
||||
instance StrEncoding a => StrEncoding (Maybe a) where
|
||||
strEncode = maybe "" strEncode
|
||||
{-# INLINE strEncode #-}
|
||||
strP = optional strP
|
||||
{-# INLINE strP #-}
|
||||
|
||||
instance StrEncoding Word16 where
|
||||
strEncode = B.pack . show
|
||||
{-# INLINE strEncode #-}
|
||||
strP = A.decimal
|
||||
{-# INLINE strP #-}
|
||||
|
||||
instance StrEncoding Char where
|
||||
strEncode = smpEncode
|
||||
{-# INLINE strEncode #-}
|
||||
strP = strP
|
||||
{-# INLINE strP #-}
|
||||
|
||||
instance StrEncoding Bool where
|
||||
strEncode = smpEncode
|
||||
{-# INLINE strEncode #-}
|
||||
strP = smpP
|
||||
{-# INLINE strP #-}
|
||||
|
||||
instance StrEncoding Int where
|
||||
strEncode = B.pack . show
|
||||
{-# INLINE strEncode #-}
|
||||
strP = A.decimal
|
||||
{-# INLINE strP #-}
|
||||
|
||||
instance StrEncoding Int64 where
|
||||
strEncode = B.pack . show
|
||||
{-# INLINE strEncode #-}
|
||||
strP = A.decimal
|
||||
{-# INLINE strP #-}
|
||||
|
||||
instance StrEncoding SystemTime where
|
||||
strEncode = strEncode . systemSeconds
|
||||
strP = MkSystemTime <$> strP <*> pure 0
|
||||
|
||||
instance StrEncoding UTCTime where
|
||||
strEncode = B.pack . iso8601Show
|
||||
strP = maybe (Left "bad UTCTime") Right . iso8601ParseM . B.unpack <$?> A.takeTill (\c -> c == ' ' || c == '\n')
|
||||
|
||||
-- lists encode/parse as comma-separated strings
|
||||
strEncodeList :: StrEncoding a => [a] -> ByteString
|
||||
|
@ -88,24 +137,36 @@ instance StrEncoding a => StrEncoding (L.NonEmpty a) where
|
|||
strEncode = strEncodeList . L.toList
|
||||
strP = L.fromList <$> listItem `A.sepBy1'` A.char ','
|
||||
|
||||
instance (StrEncoding a, Ord a) => StrEncoding (Set a) where
|
||||
strEncode = strEncodeList . S.toList
|
||||
strP = S.fromList <$> listItem `A.sepBy'` A.char ','
|
||||
|
||||
listItem :: StrEncoding a => Parser a
|
||||
listItem = parseAll strP <$?> A.takeTill (== ',')
|
||||
listItem = parseAll strP <$?> A.takeTill (\c -> c == ',' || c == ' ' || c == '\n')
|
||||
|
||||
instance (StrEncoding a, StrEncoding b) => StrEncoding (a, b) where
|
||||
strEncode (a, b) = B.unwords [strEncode a, strEncode b]
|
||||
{-# INLINE strEncode #-}
|
||||
strP = (,) <$> strP_ <*> strP
|
||||
{-# INLINE strP #-}
|
||||
|
||||
instance (StrEncoding a, StrEncoding b, StrEncoding c) => StrEncoding (a, b, c) where
|
||||
strEncode (a, b, c) = B.unwords [strEncode a, strEncode b, strEncode c]
|
||||
{-# INLINE strEncode #-}
|
||||
strP = (,,) <$> strP_ <*> strP_ <*> strP
|
||||
{-# INLINE strP #-}
|
||||
|
||||
instance (StrEncoding a, StrEncoding b, StrEncoding c, StrEncoding d) => StrEncoding (a, b, c, d) where
|
||||
strEncode (a, b, c, d) = B.unwords [strEncode a, strEncode b, strEncode c, strEncode d]
|
||||
{-# INLINE strEncode #-}
|
||||
strP = (,,,) <$> strP_ <*> strP_ <*> strP_ <*> strP
|
||||
{-# INLINE strP #-}
|
||||
|
||||
instance (StrEncoding a, StrEncoding b, StrEncoding c, StrEncoding d, StrEncoding e) => StrEncoding (a, b, c, d, e) where
|
||||
strEncode (a, b, c, d, e) = B.unwords [strEncode a, strEncode b, strEncode c, strEncode d, strEncode e]
|
||||
{-# INLINE strEncode #-}
|
||||
strP = (,,,,) <$> strP_ <*> strP_ <*> strP_ <*> strP_ <*> strP
|
||||
{-# INLINE strP #-}
|
||||
|
||||
strP_ :: StrEncoding a => Parser a
|
||||
strP_ = strP <* A.space
|
||||
|
|
|
@ -0,0 +1,64 @@
|
|||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
|
||||
module Simplex.Messaging.Notifications.Client where
|
||||
|
||||
import Control.Monad.Except
|
||||
import Control.Monad.Trans.Except
|
||||
import Data.Word (Word16)
|
||||
import Simplex.Messaging.Client
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Notifications.Protocol
|
||||
import Simplex.Messaging.Util (bshow)
|
||||
|
||||
type NtfClient = ProtocolClient NtfResponse
|
||||
|
||||
ntfRegisterToken :: NtfClient -> C.APrivateSignKey -> NewNtfEntity 'Token -> ExceptT ProtocolClientError IO (NtfTokenId, C.PublicKeyX25519)
|
||||
ntfRegisterToken c pKey newTkn =
|
||||
sendNtfCommand c (Just pKey) "" (TNEW newTkn) >>= \case
|
||||
NRTknId tknId dhKey -> pure (tknId, dhKey)
|
||||
r -> throwE . PCEUnexpectedResponse $ bshow r
|
||||
|
||||
ntfVerifyToken :: NtfClient -> C.APrivateSignKey -> NtfTokenId -> NtfRegCode -> ExceptT ProtocolClientError IO ()
|
||||
ntfVerifyToken c pKey tknId code = okNtfCommand (TVFY code) c pKey tknId
|
||||
|
||||
ntfCheckToken :: NtfClient -> C.APrivateSignKey -> NtfTokenId -> ExceptT ProtocolClientError IO NtfTknStatus
|
||||
ntfCheckToken c pKey tknId =
|
||||
sendNtfCommand c (Just pKey) tknId TCHK >>= \case
|
||||
NRTkn stat -> pure stat
|
||||
r -> throwE . PCEUnexpectedResponse $ bshow r
|
||||
|
||||
ntfReplaceToken :: NtfClient -> C.APrivateSignKey -> NtfTokenId -> DeviceToken -> ExceptT ProtocolClientError IO ()
|
||||
ntfReplaceToken c pKey tknId token = okNtfCommand (TRPL token) c pKey tknId
|
||||
|
||||
ntfDeleteToken :: NtfClient -> C.APrivateSignKey -> NtfTokenId -> ExceptT ProtocolClientError IO ()
|
||||
ntfDeleteToken = okNtfCommand TDEL
|
||||
|
||||
ntfEnableCron :: NtfClient -> C.APrivateSignKey -> NtfTokenId -> Word16 -> ExceptT ProtocolClientError IO ()
|
||||
ntfEnableCron c pKey tknId int = okNtfCommand (TCRN int) c pKey tknId
|
||||
|
||||
ntfCreateSubscription :: NtfClient -> C.APrivateSignKey -> NewNtfEntity 'Subscription -> ExceptT ProtocolClientError IO NtfSubscriptionId
|
||||
ntfCreateSubscription c pKey newSub =
|
||||
sendNtfCommand c (Just pKey) "" (SNEW newSub) >>= \case
|
||||
NRSubId subId -> pure subId
|
||||
r -> throwE . PCEUnexpectedResponse $ bshow r
|
||||
|
||||
ntfCheckSubscription :: NtfClient -> C.APrivateSignKey -> NtfSubscriptionId -> ExceptT ProtocolClientError IO NtfSubStatus
|
||||
ntfCheckSubscription c pKey subId =
|
||||
sendNtfCommand c (Just pKey) subId SCHK >>= \case
|
||||
NRSub stat -> pure stat
|
||||
r -> throwE . PCEUnexpectedResponse $ bshow r
|
||||
|
||||
ntfDeleteSubscription :: NtfClient -> C.APrivateSignKey -> NtfSubscriptionId -> ExceptT ProtocolClientError IO ()
|
||||
ntfDeleteSubscription = okNtfCommand SDEL
|
||||
|
||||
-- | Send notification server command
|
||||
sendNtfCommand :: NtfEntityI e => NtfClient -> Maybe C.APrivateSignKey -> NtfEntityId -> NtfCommand e -> ExceptT ProtocolClientError IO NtfResponse
|
||||
sendNtfCommand c pKey entId cmd = sendProtocolCommand c pKey entId (NtfCmd sNtfEntity cmd)
|
||||
|
||||
okNtfCommand :: NtfEntityI e => NtfCommand e -> NtfClient -> C.APrivateSignKey -> NtfEntityId -> ExceptT ProtocolClientError IO ()
|
||||
okNtfCommand cmd c pKey entId =
|
||||
sendNtfCommand c (Just pKey) entId cmd >>= \case
|
||||
NROk -> return ()
|
||||
r -> throwE . PCEUnexpectedResponse $ bshow r
|
|
@ -0,0 +1,502 @@
|
|||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE GADTs #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-# LANGUAGE StandaloneDeriving #-}
|
||||
{-# LANGUAGE TypeApplications #-}
|
||||
{-# LANGUAGE TypeFamilies #-}
|
||||
|
||||
module Simplex.Messaging.Notifications.Protocol where
|
||||
|
||||
import Data.Aeson (FromJSON (..), ToJSON (..), (.=))
|
||||
import qualified Data.Aeson as J
|
||||
import qualified Data.Aeson.Encoding as JE
|
||||
import qualified Data.Attoparsec.ByteString.Char8 as A
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import Data.Kind
|
||||
import Data.Maybe (isNothing)
|
||||
import Data.Text.Encoding (decodeLatin1, encodeUtf8)
|
||||
import Data.Type.Equality
|
||||
import Data.Word (Word16)
|
||||
import Database.SQLite.Simple.FromField (FromField (..))
|
||||
import Database.SQLite.Simple.ToField (ToField (..))
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Encoding
|
||||
import Simplex.Messaging.Encoding.String
|
||||
import Simplex.Messaging.Notifications.Transport (ntfClientHandshake)
|
||||
import Simplex.Messaging.Parsers (fromTextField_)
|
||||
import Simplex.Messaging.Protocol hiding (Command (..), CommandTag (..))
|
||||
import Simplex.Messaging.Util (eitherToMaybe, (<$?>))
|
||||
|
||||
data NtfEntity = Token | Subscription
|
||||
deriving (Show)
|
||||
|
||||
data SNtfEntity :: NtfEntity -> Type where
|
||||
SToken :: SNtfEntity 'Token
|
||||
SSubscription :: SNtfEntity 'Subscription
|
||||
|
||||
instance TestEquality SNtfEntity where
|
||||
testEquality SToken SToken = Just Refl
|
||||
testEquality SSubscription SSubscription = Just Refl
|
||||
testEquality _ _ = Nothing
|
||||
|
||||
deriving instance Show (SNtfEntity e)
|
||||
|
||||
class NtfEntityI (e :: NtfEntity) where sNtfEntity :: SNtfEntity e
|
||||
|
||||
instance NtfEntityI 'Token where sNtfEntity = SToken
|
||||
|
||||
instance NtfEntityI 'Subscription where sNtfEntity = SSubscription
|
||||
|
||||
data NtfCommandTag (e :: NtfEntity) where
|
||||
TNEW_ :: NtfCommandTag 'Token
|
||||
TVFY_ :: NtfCommandTag 'Token
|
||||
TCHK_ :: NtfCommandTag 'Token
|
||||
TRPL_ :: NtfCommandTag 'Token
|
||||
TDEL_ :: NtfCommandTag 'Token
|
||||
TCRN_ :: NtfCommandTag 'Token
|
||||
SNEW_ :: NtfCommandTag 'Subscription
|
||||
SCHK_ :: NtfCommandTag 'Subscription
|
||||
SDEL_ :: NtfCommandTag 'Subscription
|
||||
PING_ :: NtfCommandTag 'Subscription
|
||||
|
||||
deriving instance Show (NtfCommandTag e)
|
||||
|
||||
data NtfCmdTag = forall e. NtfEntityI e => NCT (SNtfEntity e) (NtfCommandTag e)
|
||||
|
||||
instance NtfEntityI e => Encoding (NtfCommandTag e) where
|
||||
smpEncode = \case
|
||||
TNEW_ -> "TNEW"
|
||||
TVFY_ -> "TVFY"
|
||||
TCHK_ -> "TCHK"
|
||||
TRPL_ -> "TRPL"
|
||||
TDEL_ -> "TDEL"
|
||||
TCRN_ -> "TCRN"
|
||||
SNEW_ -> "SNEW"
|
||||
SCHK_ -> "SCHK"
|
||||
SDEL_ -> "SDEL"
|
||||
PING_ -> "PING"
|
||||
smpP = messageTagP
|
||||
|
||||
instance Encoding NtfCmdTag where
|
||||
smpEncode (NCT _ t) = smpEncode t
|
||||
smpP = messageTagP
|
||||
|
||||
instance ProtocolMsgTag NtfCmdTag where
|
||||
decodeTag = \case
|
||||
"TNEW" -> Just $ NCT SToken TNEW_
|
||||
"TVFY" -> Just $ NCT SToken TVFY_
|
||||
"TCHK" -> Just $ NCT SToken TCHK_
|
||||
"TRPL" -> Just $ NCT SToken TRPL_
|
||||
"TDEL" -> Just $ NCT SToken TDEL_
|
||||
"TCRN" -> Just $ NCT SToken TCRN_
|
||||
"SNEW" -> Just $ NCT SSubscription SNEW_
|
||||
"SCHK" -> Just $ NCT SSubscription SCHK_
|
||||
"SDEL" -> Just $ NCT SSubscription SDEL_
|
||||
"PING" -> Just $ NCT SSubscription PING_
|
||||
_ -> Nothing
|
||||
|
||||
instance NtfEntityI e => ProtocolMsgTag (NtfCommandTag e) where
|
||||
decodeTag s = decodeTag s >>= (\(NCT _ t) -> checkEntity' t)
|
||||
|
||||
newtype NtfRegCode = NtfRegCode ByteString
|
||||
deriving (Eq, Show)
|
||||
|
||||
instance Encoding NtfRegCode where
|
||||
smpEncode (NtfRegCode code) = smpEncode code
|
||||
smpP = NtfRegCode <$> smpP
|
||||
|
||||
instance StrEncoding NtfRegCode where
|
||||
strEncode (NtfRegCode m) = strEncode m
|
||||
strDecode s = NtfRegCode <$> strDecode s
|
||||
strP = NtfRegCode <$> strP
|
||||
|
||||
instance FromJSON NtfRegCode where
|
||||
parseJSON = strParseJSON "NtfRegCode"
|
||||
|
||||
instance ToJSON NtfRegCode where
|
||||
toJSON = strToJSON
|
||||
toEncoding = strToJEncoding
|
||||
|
||||
data NewNtfEntity (e :: NtfEntity) where
|
||||
NewNtfTkn :: DeviceToken -> C.APublicVerifyKey -> C.PublicKeyX25519 -> NewNtfEntity 'Token
|
||||
NewNtfSub :: NtfTokenId -> SMPQueueNtf -> NtfPrivateSignKey -> NewNtfEntity 'Subscription
|
||||
|
||||
deriving instance Show (NewNtfEntity e)
|
||||
|
||||
data ANewNtfEntity = forall e. NtfEntityI e => ANE (SNtfEntity e) (NewNtfEntity e)
|
||||
|
||||
deriving instance Show ANewNtfEntity
|
||||
|
||||
instance NtfEntityI e => Encoding (NewNtfEntity e) where
|
||||
smpEncode = \case
|
||||
NewNtfTkn tkn verifyKey dhPubKey -> smpEncode ('T', tkn, verifyKey, dhPubKey)
|
||||
NewNtfSub tknId smpQueue notifierKey -> smpEncode ('S', tknId, smpQueue, notifierKey)
|
||||
smpP = (\(ANE _ c) -> checkEntity c) <$?> smpP
|
||||
|
||||
instance Encoding ANewNtfEntity where
|
||||
smpEncode (ANE _ e) = smpEncode e
|
||||
smpP =
|
||||
A.anyChar >>= \case
|
||||
'T' -> ANE SToken <$> (NewNtfTkn <$> smpP <*> smpP <*> smpP)
|
||||
'S' -> ANE SSubscription <$> (NewNtfSub <$> smpP <*> smpP <*> smpP)
|
||||
_ -> fail "bad ANewNtfEntity"
|
||||
|
||||
instance Protocol NtfResponse where
|
||||
type ProtoCommand NtfResponse = NtfCmd
|
||||
type ProtoType NtfResponse = 'PNTF
|
||||
protocolClientHandshake = ntfClientHandshake
|
||||
protocolPing = NtfCmd SSubscription PING
|
||||
protocolError = \case
|
||||
NRErr e -> Just e
|
||||
_ -> Nothing
|
||||
|
||||
data NtfCommand (e :: NtfEntity) where
|
||||
-- | register new device token for notifications
|
||||
TNEW :: NewNtfEntity 'Token -> NtfCommand 'Token
|
||||
-- | verify token - uses e2e encrypted random string sent to the device via PN to confirm that the device has the token
|
||||
TVFY :: NtfRegCode -> NtfCommand 'Token
|
||||
-- | check token status
|
||||
TCHK :: NtfCommand 'Token
|
||||
-- | replace device token (while keeping all existing subscriptions)
|
||||
TRPL :: DeviceToken -> NtfCommand 'Token
|
||||
-- | delete token - all subscriptions will be removed and no more notifications will be sent
|
||||
TDEL :: NtfCommand 'Token
|
||||
-- | enable periodic background notification to fetch the new messages - interval is in minutes, minimum is 20, 0 to disable
|
||||
TCRN :: Word16 -> NtfCommand 'Token
|
||||
-- | create SMP subscription
|
||||
SNEW :: NewNtfEntity 'Subscription -> NtfCommand 'Subscription
|
||||
-- | check SMP subscription status (response is SUB)
|
||||
SCHK :: NtfCommand 'Subscription
|
||||
-- | delete SMP subscription
|
||||
SDEL :: NtfCommand 'Subscription
|
||||
-- | keep-alive command
|
||||
PING :: NtfCommand 'Subscription
|
||||
|
||||
deriving instance Show (NtfCommand e)
|
||||
|
||||
data NtfCmd = forall e. NtfEntityI e => NtfCmd (SNtfEntity e) (NtfCommand e)
|
||||
|
||||
deriving instance Show NtfCmd
|
||||
|
||||
instance NtfEntityI e => ProtocolEncoding (NtfCommand e) where
|
||||
type Tag (NtfCommand e) = NtfCommandTag e
|
||||
encodeProtocol _v = \case
|
||||
TNEW newTkn -> e (TNEW_, ' ', newTkn)
|
||||
TVFY code -> e (TVFY_, ' ', code)
|
||||
TCHK -> e TCHK_
|
||||
TRPL tkn -> e (TRPL_, ' ', tkn)
|
||||
TDEL -> e TDEL_
|
||||
TCRN int -> e (TCRN_, ' ', int)
|
||||
SNEW newSub -> e (SNEW_, ' ', newSub)
|
||||
SCHK -> e SCHK_
|
||||
SDEL -> e SDEL_
|
||||
PING -> e PING_
|
||||
where
|
||||
e :: Encoding a => a -> ByteString
|
||||
e = smpEncode
|
||||
|
||||
protocolP _v tag = (\(NtfCmd _ c) -> checkEntity c) <$?> protocolP _v (NCT (sNtfEntity @e) tag)
|
||||
|
||||
checkCredentials (sig, _, entityId, _) cmd = case cmd of
|
||||
-- TNEW and SNEW must have signature but NOT token/subscription IDs
|
||||
TNEW {} -> sigNoEntity
|
||||
SNEW {} -> sigNoEntity
|
||||
PING
|
||||
| isNothing sig && B.null entityId -> Right cmd
|
||||
| otherwise -> Left $ CMD HAS_AUTH
|
||||
-- other client commands must have both signature and entity ID
|
||||
_
|
||||
| isNothing sig || B.null entityId -> Left $ CMD NO_AUTH
|
||||
| otherwise -> Right cmd
|
||||
where
|
||||
sigNoEntity
|
||||
| isNothing sig = Left $ CMD NO_AUTH
|
||||
| not (B.null entityId) = Left $ CMD HAS_AUTH
|
||||
| otherwise = Right cmd
|
||||
|
||||
instance ProtocolEncoding NtfCmd where
|
||||
type Tag NtfCmd = NtfCmdTag
|
||||
encodeProtocol _v (NtfCmd _ c) = encodeProtocol _v c
|
||||
|
||||
protocolP _v = \case
|
||||
NCT SToken tag ->
|
||||
NtfCmd SToken <$> case tag of
|
||||
TNEW_ -> TNEW <$> _smpP
|
||||
TVFY_ -> TVFY <$> _smpP
|
||||
TCHK_ -> pure TCHK
|
||||
TRPL_ -> TRPL <$> _smpP
|
||||
TDEL_ -> pure TDEL
|
||||
TCRN_ -> TCRN <$> _smpP
|
||||
NCT SSubscription tag ->
|
||||
NtfCmd SSubscription <$> case tag of
|
||||
SNEW_ -> SNEW <$> _smpP
|
||||
SCHK_ -> pure SCHK
|
||||
SDEL_ -> pure SDEL
|
||||
PING_ -> pure PING
|
||||
|
||||
checkCredentials t (NtfCmd e c) = NtfCmd e <$> checkCredentials t c
|
||||
|
||||
data NtfResponseTag
|
||||
= NRTknId_
|
||||
| NRSubId_
|
||||
| NROk_
|
||||
| NRErr_
|
||||
| NRTkn_
|
||||
| NRSub_
|
||||
| NRPong_
|
||||
deriving (Show)
|
||||
|
||||
instance Encoding NtfResponseTag where
|
||||
smpEncode = \case
|
||||
NRTknId_ -> "IDTKN" -- it should be "TID", "SID"
|
||||
NRSubId_ -> "IDSUB"
|
||||
NROk_ -> "OK"
|
||||
NRErr_ -> "ERR"
|
||||
NRTkn_ -> "TKN"
|
||||
NRSub_ -> "SUB"
|
||||
NRPong_ -> "PONG"
|
||||
smpP = messageTagP
|
||||
|
||||
instance ProtocolMsgTag NtfResponseTag where
|
||||
decodeTag = \case
|
||||
"IDTKN" -> Just NRTknId_
|
||||
"IDSUB" -> Just NRSubId_
|
||||
"OK" -> Just NROk_
|
||||
"ERR" -> Just NRErr_
|
||||
"TKN" -> Just NRTkn_
|
||||
"SUB" -> Just NRSub_
|
||||
"PONG" -> Just NRPong_
|
||||
_ -> Nothing
|
||||
|
||||
data NtfResponse
|
||||
= NRTknId NtfEntityId C.PublicKeyX25519
|
||||
| NRSubId NtfEntityId
|
||||
| NROk
|
||||
| NRErr ErrorType
|
||||
| NRTkn NtfTknStatus
|
||||
| NRSub NtfSubStatus
|
||||
| NRPong
|
||||
deriving (Show)
|
||||
|
||||
instance ProtocolEncoding NtfResponse where
|
||||
type Tag NtfResponse = NtfResponseTag
|
||||
encodeProtocol _v = \case
|
||||
NRTknId entId dhKey -> e (NRTknId_, ' ', entId, dhKey)
|
||||
NRSubId entId -> e (NRSubId_, ' ', entId)
|
||||
NROk -> e NROk_
|
||||
NRErr err -> e (NRErr_, ' ', err)
|
||||
NRTkn stat -> e (NRTkn_, ' ', stat)
|
||||
NRSub stat -> e (NRSub_, ' ', stat)
|
||||
NRPong -> e NRPong_
|
||||
where
|
||||
e :: Encoding a => a -> ByteString
|
||||
e = smpEncode
|
||||
|
||||
protocolP _v = \case
|
||||
NRTknId_ -> NRTknId <$> _smpP <*> smpP
|
||||
NRSubId_ -> NRSubId <$> _smpP
|
||||
NROk_ -> pure NROk
|
||||
NRErr_ -> NRErr <$> _smpP
|
||||
NRTkn_ -> NRTkn <$> _smpP
|
||||
NRSub_ -> NRSub <$> _smpP
|
||||
NRPong_ -> pure NRPong
|
||||
|
||||
checkCredentials (_, _, entId, _) cmd = case cmd of
|
||||
-- IDTKN response must not have queue ID
|
||||
NRTknId {} -> noEntity
|
||||
-- IDSUB response must not have queue ID
|
||||
NRSubId {} -> noEntity
|
||||
-- ERR response does not always have entity ID
|
||||
NRErr _ -> Right cmd
|
||||
-- PONG response must not have queue ID
|
||||
NRPong -> noEntity
|
||||
-- other server responses must have entity ID
|
||||
_
|
||||
| B.null entId -> Left $ CMD NO_ENTITY
|
||||
| otherwise -> Right cmd
|
||||
where
|
||||
noEntity
|
||||
| B.null entId = Right cmd
|
||||
| otherwise = Left $ CMD HAS_AUTH
|
||||
|
||||
data SMPQueueNtf = SMPQueueNtf
|
||||
{ smpServer :: SMPServer,
|
||||
notifierId :: NotifierId
|
||||
}
|
||||
deriving (Eq, Ord, Show)
|
||||
|
||||
instance Encoding SMPQueueNtf where
|
||||
smpEncode SMPQueueNtf {smpServer, notifierId} = smpEncode (smpServer, notifierId)
|
||||
smpP = do
|
||||
(smpServer, notifierId) <- smpP
|
||||
pure $ SMPQueueNtf {smpServer, notifierId}
|
||||
|
||||
instance StrEncoding SMPQueueNtf where
|
||||
strEncode SMPQueueNtf {smpServer, notifierId} = strEncode smpServer <> "/" <> strEncode notifierId
|
||||
strP = SMPQueueNtf <$> strP <* A.char '/' <*> strP
|
||||
|
||||
data PushProvider = PPApnsDev | PPApnsProd | PPApnsTest
|
||||
deriving (Eq, Ord, Show)
|
||||
|
||||
instance Encoding PushProvider where
|
||||
smpEncode = \case
|
||||
PPApnsDev -> "AD"
|
||||
PPApnsProd -> "AP"
|
||||
PPApnsTest -> "AT"
|
||||
smpP =
|
||||
A.take 2 >>= \case
|
||||
"AD" -> pure PPApnsDev
|
||||
"AP" -> pure PPApnsProd
|
||||
"AT" -> pure PPApnsTest
|
||||
_ -> fail "bad PushProvider"
|
||||
|
||||
instance StrEncoding PushProvider where
|
||||
strEncode = \case
|
||||
PPApnsDev -> "apns_dev"
|
||||
PPApnsProd -> "apns_prod"
|
||||
PPApnsTest -> "apns_test"
|
||||
strP =
|
||||
A.takeTill (== ' ') >>= \case
|
||||
"apns_dev" -> pure PPApnsDev
|
||||
"apns_prod" -> pure PPApnsProd
|
||||
"apns_test" -> pure PPApnsTest
|
||||
_ -> fail "bad PushProvider"
|
||||
|
||||
instance FromField PushProvider where fromField = fromTextField_ $ eitherToMaybe . strDecode . encodeUtf8
|
||||
|
||||
instance ToField PushProvider where toField = toField . decodeLatin1 . strEncode
|
||||
|
||||
data DeviceToken = DeviceToken PushProvider ByteString
|
||||
deriving (Eq, Ord, Show)
|
||||
|
||||
instance Encoding DeviceToken where
|
||||
smpEncode (DeviceToken p t) = smpEncode (p, t)
|
||||
smpP = DeviceToken <$> smpP <*> smpP
|
||||
|
||||
instance StrEncoding DeviceToken where
|
||||
strEncode (DeviceToken p t) = strEncode p <> " " <> t
|
||||
strP = DeviceToken <$> strP <* A.space <*> hexStringP
|
||||
where
|
||||
hexStringP =
|
||||
A.takeWhile (\c -> (c >= '0' && c <= '9') || (c >= 'a' && c <= 'f')) >>= \s ->
|
||||
if even (B.length s) then pure s else fail "odd number of hex characters"
|
||||
|
||||
instance ToJSON DeviceToken where
|
||||
toEncoding (DeviceToken pp t) = J.pairs $ "pushProvider" .= decodeLatin1 (strEncode pp) <> "token" .= decodeLatin1 t
|
||||
toJSON (DeviceToken pp t) = J.object ["pushProvider" .= decodeLatin1 (strEncode pp), "token" .= decodeLatin1 t]
|
||||
|
||||
type NtfEntityId = ByteString
|
||||
|
||||
type NtfSubscriptionId = NtfEntityId
|
||||
|
||||
type NtfTokenId = NtfEntityId
|
||||
|
||||
data NtfSubStatus
|
||||
= -- | state after SNEW
|
||||
NSNew
|
||||
| -- | pending connection/subscription to SMP server
|
||||
NSPending
|
||||
| -- | connected and subscribed to SMP server
|
||||
NSActive
|
||||
| -- | disconnected/unsubscribed from SMP server
|
||||
NSInactive
|
||||
| -- | END received
|
||||
NSEnd
|
||||
| -- | SMP AUTH error
|
||||
NSAuth
|
||||
| -- | SMP error other than AUTH
|
||||
NSErr ByteString
|
||||
deriving (Eq, Show)
|
||||
|
||||
ntfShouldSubscribe :: NtfSubStatus -> Bool
|
||||
ntfShouldSubscribe = \case
|
||||
NSNew -> True
|
||||
NSPending -> True
|
||||
NSActive -> True
|
||||
NSInactive -> True
|
||||
NSEnd -> False
|
||||
NSAuth -> False
|
||||
NSErr _ -> False
|
||||
|
||||
instance Encoding NtfSubStatus where
|
||||
smpEncode = \case
|
||||
NSNew -> "NEW"
|
||||
NSPending -> "PENDING" -- e.g. after SMP server disconnect/timeout while ntf server is retrying to connect
|
||||
NSActive -> "ACTIVE"
|
||||
NSInactive -> "INACTIVE"
|
||||
NSEnd -> "END"
|
||||
NSAuth -> "AUTH"
|
||||
NSErr err -> "ERR " <> err
|
||||
smpP =
|
||||
A.takeTill (== ' ') >>= \case
|
||||
"NEW" -> pure NSNew
|
||||
"PENDING" -> pure NSPending
|
||||
"ACTIVE" -> pure NSActive
|
||||
"INACTIVE" -> pure NSInactive
|
||||
"END" -> pure NSEnd
|
||||
"AUTH" -> pure NSAuth
|
||||
"ERR" -> NSErr <$> (A.space *> A.takeByteString)
|
||||
_ -> fail "bad NtfSubStatus"
|
||||
|
||||
instance StrEncoding NtfSubStatus where
|
||||
strEncode = smpEncode
|
||||
strP = smpP
|
||||
|
||||
data NtfTknStatus
|
||||
= -- | Token created in DB
|
||||
NTNew
|
||||
| -- | state after registration (TNEW)
|
||||
NTRegistered
|
||||
| -- | if initial notification failed (push provider error) or verification failed
|
||||
NTInvalid
|
||||
| -- | Token confirmed via notification (accepted by push provider or verification code received by client)
|
||||
NTConfirmed
|
||||
| -- | after successful verification (TVFY)
|
||||
NTActive
|
||||
| -- | after it is no longer valid (push provider error)
|
||||
NTExpired
|
||||
deriving (Eq, Show)
|
||||
|
||||
instance Encoding NtfTknStatus where
|
||||
smpEncode = \case
|
||||
NTNew -> "NEW"
|
||||
NTRegistered -> "REGISTERED"
|
||||
NTInvalid -> "INVALID"
|
||||
NTConfirmed -> "CONFIRMED"
|
||||
NTActive -> "ACTIVE"
|
||||
NTExpired -> "EXPIRED"
|
||||
smpP =
|
||||
A.takeTill (== ' ') >>= \case
|
||||
"NEW" -> pure NTNew
|
||||
"REGISTERED" -> pure NTRegistered
|
||||
"INVALID" -> pure NTInvalid
|
||||
"CONFIRMED" -> pure NTConfirmed
|
||||
"ACTIVE" -> pure NTActive
|
||||
"EXPIRED" -> pure NTExpired
|
||||
_ -> fail "bad NtfTknStatus"
|
||||
|
||||
instance StrEncoding NtfTknStatus where
|
||||
strEncode = smpEncode
|
||||
strP = smpP
|
||||
|
||||
instance FromField NtfTknStatus where fromField = fromTextField_ $ either (const Nothing) Just . smpDecode . encodeUtf8
|
||||
|
||||
instance ToField NtfTknStatus where toField = toField . decodeLatin1 . smpEncode
|
||||
|
||||
instance ToJSON NtfTknStatus where
|
||||
toEncoding = JE.text . decodeLatin1 . smpEncode
|
||||
toJSON = J.String . decodeLatin1 . smpEncode
|
||||
|
||||
checkEntity :: forall t e e'. (NtfEntityI e, NtfEntityI e') => t e' -> Either String (t e)
|
||||
checkEntity c = case testEquality (sNtfEntity @e) (sNtfEntity @e') of
|
||||
Just Refl -> Right c
|
||||
Nothing -> Left "bad command party"
|
||||
|
||||
checkEntity' :: forall t p p'. (NtfEntityI p, NtfEntityI p') => t p' -> Maybe (t p)
|
||||
checkEntity' c = case testEquality (sNtfEntity @p) (sNtfEntity @p') of
|
||||
Just Refl -> Just c
|
||||
_ -> Nothing
|
|
@ -0,0 +1,471 @@
|
|||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE DuplicateRecordFields #-}
|
||||
{-# LANGUAGE FlexibleContexts #-}
|
||||
{-# LANGUAGE GADTs #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-# LANGUAGE TupleSections #-}
|
||||
|
||||
module Simplex.Messaging.Notifications.Server where
|
||||
|
||||
import Control.Concurrent.STM (stateTVar)
|
||||
import Control.Logger.Simple
|
||||
import Control.Monad.Except
|
||||
import Control.Monad.IO.Unlift (MonadUnliftIO)
|
||||
import Control.Monad.Reader
|
||||
import Crypto.Random (MonadRandom)
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import Data.Functor (($>))
|
||||
import Data.Map.Strict (Map)
|
||||
import qualified Data.Text as T
|
||||
import Data.Time.Clock.System (getSystemTime)
|
||||
import Network.Socket (ServiceName)
|
||||
import Simplex.Messaging.Client (ProtocolClientError (..))
|
||||
import Simplex.Messaging.Client.Agent
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Notifications.Protocol
|
||||
import Simplex.Messaging.Notifications.Server.Env
|
||||
import Simplex.Messaging.Notifications.Server.Push.APNS (PNMessageData (..), PushNotification (..), PushProviderError (..))
|
||||
import Simplex.Messaging.Notifications.Server.Store
|
||||
import Simplex.Messaging.Notifications.Server.StoreLog
|
||||
import Simplex.Messaging.Notifications.Transport
|
||||
import Simplex.Messaging.Protocol (ErrorType (..), ProtocolServer (host), SMPServer, SignedTransmission, Transmission, encodeTransmission, tGet, tPut)
|
||||
import qualified Simplex.Messaging.Protocol as SMP
|
||||
import Simplex.Messaging.Server
|
||||
import qualified Simplex.Messaging.TMap as TM
|
||||
import Simplex.Messaging.Transport (ATransport (..), THandle (..), TProxy, Transport (..))
|
||||
import Simplex.Messaging.Transport.Server (runTransportServer)
|
||||
import Simplex.Messaging.Util
|
||||
import System.Mem.Weak (deRefWeak)
|
||||
import UnliftIO (IOMode (..), async, uninterruptibleCancel)
|
||||
import UnliftIO.Concurrent (forkIO, killThread, mkWeakThreadId, threadDelay)
|
||||
import UnliftIO.Exception
|
||||
import UnliftIO.STM
|
||||
|
||||
runNtfServer :: (MonadRandom m, MonadUnliftIO m) => NtfServerConfig -> m ()
|
||||
runNtfServer cfg = do
|
||||
started <- newEmptyTMVarIO
|
||||
runNtfServerBlocking started cfg
|
||||
|
||||
runNtfServerBlocking :: (MonadRandom m, MonadUnliftIO m) => TMVar Bool -> NtfServerConfig -> m ()
|
||||
runNtfServerBlocking started cfg = runReaderT (ntfServer cfg started) =<< newNtfServerEnv cfg
|
||||
|
||||
ntfServer :: forall m. (MonadUnliftIO m, MonadReader NtfEnv m) => NtfServerConfig -> TMVar Bool -> m ()
|
||||
ntfServer NtfServerConfig {transports} started = do
|
||||
s <- asks subscriber
|
||||
ps <- asks pushServer
|
||||
subs <- readTVarIO =<< asks (subscriptions . store)
|
||||
void . forkIO $ resubscribe s subs
|
||||
raceAny_ (ntfSubscriber s : ntfPush ps : map runServer transports) `finally` stopServer
|
||||
where
|
||||
runServer :: (ServiceName, ATransport) -> m ()
|
||||
runServer (tcpPort, ATransport t) = do
|
||||
serverParams <- asks tlsServerParams
|
||||
runTransportServer started tcpPort serverParams (runClient t)
|
||||
|
||||
runClient :: Transport c => TProxy c -> c -> m ()
|
||||
runClient _ h = do
|
||||
kh <- asks serverIdentity
|
||||
liftIO (runExceptT $ ntfServerHandshake h kh supportedNTFServerVRange) >>= \case
|
||||
Right th -> runNtfClientTransport th
|
||||
Left _ -> pure ()
|
||||
|
||||
stopServer :: m ()
|
||||
stopServer = do
|
||||
withNtfLog closeStoreLog
|
||||
asks (smpSubscribers . subscriber) >>= readTVarIO >>= mapM_ (\SMPSubscriber {subThreadId} -> readTVarIO subThreadId >>= mapM_ (liftIO . deRefWeak >=> mapM_ killThread))
|
||||
|
||||
resubscribe :: (MonadUnliftIO m, MonadReader NtfEnv m) => NtfSubscriber -> Map NtfSubscriptionId NtfSubData -> m ()
|
||||
resubscribe NtfSubscriber {newSubQ} subs = do
|
||||
d <- asks $ resubscribeDelay . config
|
||||
forM_ subs $ \sub@NtfSubData {} ->
|
||||
whenM (ntfShouldSubscribe <$> readTVarIO (subStatus sub)) $ do
|
||||
atomically $ writeTBQueue newSubQ $ NtfSub sub
|
||||
threadDelay d
|
||||
liftIO $ logInfo "SMP connections resubscribed"
|
||||
|
||||
ntfSubscriber :: forall m. (MonadUnliftIO m, MonadReader NtfEnv m) => NtfSubscriber -> m ()
|
||||
ntfSubscriber NtfSubscriber {smpSubscribers, newSubQ, smpAgent = ca@SMPClientAgent {msgQ, agentQ}} = do
|
||||
raceAny_ [subscribe, receiveSMP, receiveAgent]
|
||||
where
|
||||
subscribe :: m ()
|
||||
subscribe =
|
||||
forever $
|
||||
atomically (readTBQueue newSubQ) >>= \case
|
||||
sub@(NtfSub NtfSubData {smpQueue = SMPQueueNtf {smpServer}}) -> do
|
||||
SMPSubscriber {newSubQ = subscriberSubQ} <- getSMPSubscriber smpServer
|
||||
atomically $ writeTQueue subscriberSubQ sub
|
||||
|
||||
getSMPSubscriber :: SMPServer -> m SMPSubscriber
|
||||
getSMPSubscriber smpServer =
|
||||
atomically (TM.lookup smpServer smpSubscribers) >>= maybe createSMPSubscriber pure
|
||||
where
|
||||
createSMPSubscriber = do
|
||||
sub@SMPSubscriber {subThreadId} <- atomically newSMPSubscriber
|
||||
atomically $ TM.insert smpServer sub smpSubscribers
|
||||
tId <- mkWeakThreadId =<< forkIO (runSMPSubscriber sub)
|
||||
atomically . writeTVar subThreadId $ Just tId
|
||||
pure sub
|
||||
|
||||
runSMPSubscriber :: SMPSubscriber -> m ()
|
||||
runSMPSubscriber SMPSubscriber {newSubQ = subscriberSubQ} =
|
||||
forever $
|
||||
atomically (peekTQueue subscriberSubQ)
|
||||
>>= \(NtfSub NtfSubData {smpQueue, notifierKey}) -> do
|
||||
updateSubStatus smpQueue NSPending
|
||||
let SMPQueueNtf {smpServer, notifierId} = smpQueue
|
||||
liftIO (runExceptT $ subscribeQueue ca smpServer ((SPNotifier, notifierId), notifierKey)) >>= \case
|
||||
Right _ -> do
|
||||
updateSubStatus smpQueue NSActive
|
||||
void . atomically $ readTQueue subscriberSubQ
|
||||
Left err -> do
|
||||
handleSubError smpQueue err
|
||||
case err of
|
||||
PCEResponseTimeout -> pure ()
|
||||
PCENetworkError -> pure ()
|
||||
_ -> void . atomically $ readTQueue subscriberSubQ
|
||||
|
||||
receiveSMP :: m ()
|
||||
receiveSMP = forever $ do
|
||||
(srv, _, _, ntfId, msg) <- atomically $ readTBQueue msgQ
|
||||
let smpQueue = SMPQueueNtf srv ntfId
|
||||
case msg of
|
||||
SMP.NMSG nmsgNonce encNMsgMeta -> do
|
||||
ntfTs <- liftIO getSystemTime
|
||||
st <- asks store
|
||||
NtfPushServer {pushQ} <- asks pushServer
|
||||
atomically $
|
||||
findNtfSubscriptionToken st smpQueue
|
||||
>>= mapM_ (\tkn -> writeTBQueue pushQ (tkn, PNMessage PNMessageData {smpQueue, ntfTs, nmsgNonce, encNMsgMeta}))
|
||||
SMP.END -> updateSubStatus smpQueue NSEnd
|
||||
_ -> pure ()
|
||||
pure ()
|
||||
|
||||
receiveAgent =
|
||||
forever $
|
||||
atomically (readTBQueue agentQ) >>= \case
|
||||
CAConnected _ -> pure ()
|
||||
CADisconnected srv subs -> do
|
||||
logInfo . T.pack $ "SMP server disconnected " <> host srv <> " (" <> show (length subs) <> ") subscriptions"
|
||||
forM_ subs $ \(_, ntfId) -> do
|
||||
let smpQueue = SMPQueueNtf srv ntfId
|
||||
updateSubStatus smpQueue NSInactive
|
||||
CAReconnected srv ->
|
||||
logInfo $ "SMP server reconnected " <> T.pack (host srv)
|
||||
CAResubscribed srv sub -> do
|
||||
let ntfId = snd sub
|
||||
smpQueue = SMPQueueNtf srv ntfId
|
||||
updateSubStatus smpQueue NSActive
|
||||
CASubError srv (_, ntfId) err -> do
|
||||
logError . T.pack $ "SMP subscription error on server " <> host srv <> ": " <> show err
|
||||
handleSubError (SMPQueueNtf srv ntfId) err
|
||||
|
||||
handleSubError :: SMPQueueNtf -> ProtocolClientError -> m ()
|
||||
handleSubError smpQueue = \case
|
||||
PCEProtocolError AUTH -> updateSubStatus smpQueue NSAuth
|
||||
PCEProtocolError e -> updateErr "SMP error " e
|
||||
PCEIOError e -> updateErr "IOError " e
|
||||
PCEResponseError e -> updateErr "ResponseError " e
|
||||
PCEUnexpectedResponse r -> updateErr "UnexpectedResponse " r
|
||||
PCETransportError e -> updateErr "TransportError " e
|
||||
PCESignatureError e -> updateErr "SignatureError " e
|
||||
PCEResponseTimeout -> pure ()
|
||||
PCENetworkError -> pure ()
|
||||
where
|
||||
updateErr :: Show e => ByteString -> e -> m ()
|
||||
updateErr errType e = updateSubStatus smpQueue . NSErr $ errType <> bshow e
|
||||
|
||||
updateSubStatus smpQueue status = do
|
||||
st <- asks store
|
||||
atomically (findNtfSubscription st smpQueue)
|
||||
>>= mapM_
|
||||
( \NtfSubData {ntfSubId, subStatus} -> do
|
||||
atomically $ writeTVar subStatus status
|
||||
withNtfLog $ \sl -> logSubscriptionStatus sl ntfSubId status
|
||||
)
|
||||
|
||||
ntfPush :: forall m. (MonadUnliftIO m, MonadReader NtfEnv m) => NtfPushServer -> m ()
|
||||
ntfPush s@NtfPushServer {pushQ} = forever $ do
|
||||
(tkn@NtfTknData {ntfTknId, token = DeviceToken pp _, tknStatus}, ntf) <- atomically (readTBQueue pushQ)
|
||||
liftIO $ logDebug $ "sending push notification to " <> T.pack (show pp)
|
||||
status <- readTVarIO tknStatus
|
||||
case (status, ntf) of
|
||||
(_, PNVerification _) -> do
|
||||
-- TODO check token status
|
||||
deliverNotification pp tkn ntf >>= \case
|
||||
Right _ -> do
|
||||
status_ <- atomically $ stateTVar tknStatus $ \status' -> if status' == NTActive then (Nothing, NTActive) else (Just NTConfirmed, NTConfirmed)
|
||||
forM_ status_ $ \status' -> withNtfLog $ \sl -> logTokenStatus sl ntfTknId status'
|
||||
_ -> pure ()
|
||||
(NTActive, PNCheckMessages) -> do
|
||||
void $ deliverNotification pp tkn ntf
|
||||
(NTActive, PNMessage {}) -> do
|
||||
void $ deliverNotification pp tkn ntf
|
||||
_ -> do
|
||||
liftIO $ logError "bad notification token status"
|
||||
where
|
||||
deliverNotification :: PushProvider -> NtfTknData -> PushNotification -> m (Either PushProviderError ())
|
||||
deliverNotification pp tkn@NtfTknData {ntfTknId, tknStatus} ntf = do
|
||||
deliver <- liftIO $ getPushClient s pp
|
||||
liftIO (runExceptT $ deliver tkn ntf) >>= \case
|
||||
Right _ -> pure $ Right ()
|
||||
Left e -> case e of
|
||||
PPConnection _ -> retryDeliver
|
||||
PPRetryLater -> retryDeliver
|
||||
-- TODO alert
|
||||
PPCryptoError _ -> err e
|
||||
PPResponseError _ _ -> err e
|
||||
PPTokenInvalid -> updateTknStatus NTInvalid >> err e
|
||||
PPPermanentError -> err e
|
||||
where
|
||||
retryDeliver :: m (Either PushProviderError ())
|
||||
retryDeliver = do
|
||||
deliver <- liftIO $ newPushClient s pp
|
||||
liftIO (runExceptT $ deliver tkn ntf) >>= either err (pure . Right)
|
||||
updateTknStatus :: NtfTknStatus -> m ()
|
||||
updateTknStatus status = do
|
||||
atomically $ writeTVar tknStatus status
|
||||
withNtfLog $ \sl -> logTokenStatus sl ntfTknId status
|
||||
err e = logError (T.pack $ "Push provider error (" <> show pp <> "): " <> show e) $> Left e
|
||||
|
||||
runNtfClientTransport :: (Transport c, MonadUnliftIO m, MonadReader NtfEnv m) => THandle c -> m ()
|
||||
runNtfClientTransport th@THandle {sessionId} = do
|
||||
qSize <- asks $ clientQSize . config
|
||||
ts <- liftIO getSystemTime
|
||||
c <- atomically $ newNtfServerClient qSize sessionId ts
|
||||
s <- asks subscriber
|
||||
ps <- asks pushServer
|
||||
expCfg <- asks $ inactiveClientExpiration . config
|
||||
raceAny_ ([send th c, client c s ps, receive th c] <> disconnectThread_ c expCfg)
|
||||
`finally` clientDisconnected c
|
||||
where
|
||||
disconnectThread_ c expCfg = maybe [] ((: []) . disconnectTransport th c activeAt) expCfg
|
||||
|
||||
clientDisconnected :: MonadUnliftIO m => NtfServerClient -> m ()
|
||||
clientDisconnected NtfServerClient {connected} = atomically $ writeTVar connected False
|
||||
|
||||
receive :: (Transport c, MonadUnliftIO m, MonadReader NtfEnv m) => THandle c -> NtfServerClient -> m ()
|
||||
receive th NtfServerClient {rcvQ, sndQ, activeAt} = forever $ do
|
||||
t@(_, _, (corrId, entId, cmdOrError)) <- tGet th
|
||||
atomically . writeTVar activeAt =<< liftIO getSystemTime
|
||||
logDebug "received transmission"
|
||||
case cmdOrError of
|
||||
Left e -> write sndQ (corrId, entId, NRErr e)
|
||||
Right cmd ->
|
||||
verifyNtfTransmission t cmd >>= \case
|
||||
VRVerified req -> write rcvQ req
|
||||
VRFailed -> write sndQ (corrId, entId, NRErr AUTH)
|
||||
where
|
||||
write q t = atomically $ writeTBQueue q t
|
||||
|
||||
send :: (Transport c, MonadUnliftIO m) => THandle c -> NtfServerClient -> m ()
|
||||
send h@THandle {thVersion = v} NtfServerClient {sndQ, sessionId, activeAt} = forever $ do
|
||||
t <- atomically $ readTBQueue sndQ
|
||||
void . liftIO $ tPut h (Nothing, encodeTransmission v sessionId t)
|
||||
atomically . writeTVar activeAt =<< liftIO getSystemTime
|
||||
|
||||
-- instance Show a => Show (TVar a) where
|
||||
-- show x = unsafePerformIO $ show <$> readTVarIO x
|
||||
|
||||
data VerificationResult = VRVerified NtfRequest | VRFailed
|
||||
|
||||
verifyNtfTransmission ::
|
||||
forall m. (MonadUnliftIO m, MonadReader NtfEnv m) => SignedTransmission NtfCmd -> NtfCmd -> m VerificationResult
|
||||
verifyNtfTransmission (sig_, signed, (corrId, entId, _)) cmd = do
|
||||
st <- asks store
|
||||
case cmd of
|
||||
NtfCmd SToken c@(TNEW tkn@(NewNtfTkn _ k _)) -> do
|
||||
r_ <- atomically $ getNtfTokenRegistration st tkn
|
||||
pure $
|
||||
if verifyCmdSignature sig_ signed k
|
||||
then case r_ of
|
||||
Just t@NtfTknData {tknVerifyKey}
|
||||
| k == tknVerifyKey -> verifiedTknCmd t c
|
||||
| otherwise -> VRFailed
|
||||
_ -> VRVerified (NtfReqNew corrId (ANE SToken tkn))
|
||||
else VRFailed
|
||||
NtfCmd SToken c -> do
|
||||
t_ <- atomically $ getNtfToken st entId
|
||||
verifyToken t_ (`verifiedTknCmd` c)
|
||||
NtfCmd SSubscription c@(SNEW sub@(NewNtfSub tknId smpQueue _)) -> do
|
||||
s_ <- atomically $ findNtfSubscription st smpQueue
|
||||
case s_ of
|
||||
Nothing -> do
|
||||
-- TODO move active token check here to differentiate error
|
||||
t_ <- atomically $ getActiveNtfToken st tknId
|
||||
verifyToken' t_ $ VRVerified (NtfReqNew corrId (ANE SSubscription sub))
|
||||
Just s@NtfSubData {tokenId = subTknId} ->
|
||||
if subTknId == tknId
|
||||
then do
|
||||
-- TODO move active token check here to differentiate error
|
||||
t_ <- atomically $ getActiveNtfToken st subTknId
|
||||
verifyToken' t_ $ verifiedSubCmd s c
|
||||
else pure $ maybe False (dummyVerifyCmd signed) sig_ `seq` VRFailed
|
||||
NtfCmd SSubscription c -> do
|
||||
s_ <- atomically $ getNtfSubscription st entId
|
||||
case s_ of
|
||||
Just s@NtfSubData {tokenId = subTknId} -> do
|
||||
-- TODO move active token check here to differentiate error
|
||||
t_ <- atomically $ getActiveNtfToken st subTknId
|
||||
verifyToken' t_ $ verifiedSubCmd s c
|
||||
_ -> pure $ maybe False (dummyVerifyCmd signed) sig_ `seq` VRFailed
|
||||
where
|
||||
verifiedTknCmd t c = VRVerified (NtfReqCmd SToken (NtfTkn t) (corrId, entId, c))
|
||||
verifiedSubCmd s c = VRVerified (NtfReqCmd SSubscription (NtfSub s) (corrId, entId, c))
|
||||
verifyToken :: Maybe NtfTknData -> (NtfTknData -> VerificationResult) -> m VerificationResult
|
||||
verifyToken t_ positiveVerificationResult =
|
||||
pure $ case t_ of
|
||||
Just t@NtfTknData {tknVerifyKey} ->
|
||||
if verifyCmdSignature sig_ signed tknVerifyKey
|
||||
then positiveVerificationResult t
|
||||
else VRFailed
|
||||
_ -> maybe False (dummyVerifyCmd signed) sig_ `seq` VRFailed
|
||||
verifyToken' :: Maybe NtfTknData -> VerificationResult -> m VerificationResult
|
||||
verifyToken' t_ = verifyToken t_ . const
|
||||
|
||||
client :: forall m. (MonadUnliftIO m, MonadReader NtfEnv m) => NtfServerClient -> NtfSubscriber -> NtfPushServer -> m ()
|
||||
client NtfServerClient {rcvQ, sndQ} NtfSubscriber {newSubQ, smpAgent = ca} NtfPushServer {pushQ, intervalNotifiers} =
|
||||
forever $
|
||||
atomically (readTBQueue rcvQ)
|
||||
>>= processCommand
|
||||
>>= atomically . writeTBQueue sndQ
|
||||
where
|
||||
processCommand :: NtfRequest -> m (Transmission NtfResponse)
|
||||
processCommand = \case
|
||||
NtfReqNew corrId (ANE SToken newTkn@(NewNtfTkn _ _ dhPubKey)) -> do
|
||||
logDebug "TNEW - new token"
|
||||
st <- asks store
|
||||
ks@(srvDhPubKey, srvDhPrivKey) <- liftIO C.generateKeyPair'
|
||||
let dhSecret = C.dh' dhPubKey srvDhPrivKey
|
||||
tknId <- getId
|
||||
regCode <- getRegCode
|
||||
tkn <- atomically $ mkNtfTknData tknId newTkn ks dhSecret regCode
|
||||
atomically $ addNtfToken st tknId tkn
|
||||
atomically $ writeTBQueue pushQ (tkn, PNVerification regCode)
|
||||
withNtfLog (`logCreateToken` tkn)
|
||||
pure (corrId, "", NRTknId tknId srvDhPubKey)
|
||||
NtfReqCmd SToken (NtfTkn tkn@NtfTknData {ntfTknId, tknStatus, tknRegCode, tknDhSecret, tknDhKeys = (srvDhPubKey, srvDhPrivKey), tknCronInterval}) (corrId, tknId, cmd) -> do
|
||||
status <- readTVarIO tknStatus
|
||||
(corrId,tknId,) <$> case cmd of
|
||||
TNEW (NewNtfTkn _ _ dhPubKey) -> do
|
||||
logDebug "TNEW - registered token"
|
||||
let dhSecret = C.dh' dhPubKey srvDhPrivKey
|
||||
-- it is required that DH secret is the same, to avoid failed verifications if notification is delaying
|
||||
if tknDhSecret == dhSecret
|
||||
then do
|
||||
atomically $ writeTBQueue pushQ (tkn, PNVerification tknRegCode)
|
||||
pure $ NRTknId ntfTknId srvDhPubKey
|
||||
else pure $ NRErr AUTH
|
||||
TVFY code -- this allows repeated verification for cases when client connection dropped before server response
|
||||
| (status == NTRegistered || status == NTConfirmed || status == NTActive) && tknRegCode == code -> do
|
||||
logDebug "TVFY - token verified"
|
||||
st <- asks store
|
||||
atomically $ writeTVar tknStatus NTActive
|
||||
tIds <- atomically $ removeInactiveTokenRegistrations st tkn
|
||||
forM_ tIds cancelInvervalNotifications
|
||||
withNtfLog $ \s -> logTokenStatus s tknId NTActive
|
||||
pure NROk
|
||||
| otherwise -> do
|
||||
logDebug "TVFY - incorrect code or token status"
|
||||
pure $ NRErr AUTH
|
||||
TCHK -> do
|
||||
logDebug "TCHK"
|
||||
pure $ NRTkn status
|
||||
TRPL token' -> do
|
||||
logDebug "TRPL - replace token"
|
||||
st <- asks store
|
||||
regCode <- getRegCode
|
||||
atomically $ do
|
||||
removeTokenRegistration st tkn
|
||||
writeTVar tknStatus NTRegistered
|
||||
let tkn' = tkn {token = token', tknRegCode = regCode}
|
||||
addNtfToken st tknId tkn'
|
||||
writeTBQueue pushQ (tkn', PNVerification regCode)
|
||||
withNtfLog $ \s -> logUpdateToken s tknId token' regCode
|
||||
pure NROk
|
||||
TDEL -> do
|
||||
logDebug "TDEL"
|
||||
st <- asks store
|
||||
qs <- atomically $ deleteNtfToken st tknId
|
||||
forM_ qs $ \SMPQueueNtf {smpServer, notifierId} ->
|
||||
atomically $ removeSubscription ca smpServer (SPNotifier, notifierId)
|
||||
cancelInvervalNotifications tknId
|
||||
withNtfLog (`logDeleteToken` tknId)
|
||||
pure NROk
|
||||
TCRN 0 -> do
|
||||
logDebug "TCRN 0"
|
||||
atomically $ writeTVar tknCronInterval 0
|
||||
cancelInvervalNotifications tknId
|
||||
withNtfLog $ \s -> logTokenCron s tknId 0
|
||||
pure NROk
|
||||
TCRN int
|
||||
| int < 20 -> pure $ NRErr QUOTA
|
||||
| otherwise -> do
|
||||
logDebug "TCRN"
|
||||
atomically $ writeTVar tknCronInterval int
|
||||
atomically (TM.lookup tknId intervalNotifiers) >>= \case
|
||||
Nothing -> runIntervalNotifier int
|
||||
Just IntervalNotifier {interval, action} ->
|
||||
unless (interval == int) $ do
|
||||
uninterruptibleCancel action
|
||||
runIntervalNotifier int
|
||||
withNtfLog $ \s -> logTokenCron s tknId int
|
||||
pure NROk
|
||||
where
|
||||
runIntervalNotifier interval = do
|
||||
action <- async . intervalNotifier $ fromIntegral interval * 1000000 * 60
|
||||
let notifier = IntervalNotifier {action, token = tkn, interval}
|
||||
atomically $ TM.insert tknId notifier intervalNotifiers
|
||||
where
|
||||
intervalNotifier delay = forever $ do
|
||||
threadDelay delay
|
||||
atomically $ writeTBQueue pushQ (tkn, PNCheckMessages)
|
||||
NtfReqNew corrId (ANE SSubscription newSub) -> do
|
||||
logDebug "SNEW - new subscription"
|
||||
st <- asks store
|
||||
subId <- getId
|
||||
sub <- atomically $ mkNtfSubData subId newSub
|
||||
resp <-
|
||||
atomically (addNtfSubscription st subId sub) >>= \case
|
||||
Just _ -> atomically (writeTBQueue newSubQ $ NtfSub sub) $> NRSubId subId
|
||||
_ -> pure $ NRErr AUTH
|
||||
withNtfLog (`logCreateSubscription` sub)
|
||||
pure (corrId, "", resp)
|
||||
NtfReqCmd SSubscription (NtfSub NtfSubData {smpQueue = SMPQueueNtf {smpServer, notifierId}, notifierKey = registeredNKey, subStatus}) (corrId, subId, cmd) -> do
|
||||
status <- readTVarIO subStatus
|
||||
(corrId,subId,) <$> case cmd of
|
||||
SNEW (NewNtfSub _ _ notifierKey) -> do
|
||||
logDebug "SNEW - existing subscription"
|
||||
-- TODO retry if subscription failed, if pending or AUTH do nothing
|
||||
pure $
|
||||
if notifierKey == registeredNKey
|
||||
then NRSubId subId
|
||||
else NRErr AUTH
|
||||
SCHK -> do
|
||||
logDebug "SCHK"
|
||||
pure $ NRSub status
|
||||
SDEL -> do
|
||||
logDebug "SDEL"
|
||||
st <- asks store
|
||||
atomically $ deleteNtfSubscription st subId
|
||||
atomically $ removeSubscription ca smpServer (SPNotifier, notifierId)
|
||||
withNtfLog (`logDeleteSubscription` subId)
|
||||
pure NROk
|
||||
PING -> pure NRPong
|
||||
getId :: m NtfEntityId
|
||||
getId = getRandomBytes =<< asks (subIdBytes . config)
|
||||
getRegCode :: m NtfRegCode
|
||||
getRegCode = NtfRegCode <$> (getRandomBytes =<< asks (regCodeBytes . config))
|
||||
getRandomBytes :: Int -> m ByteString
|
||||
getRandomBytes n = do
|
||||
gVar <- asks idsDrg
|
||||
atomically (C.pseudoRandomBytes n gVar)
|
||||
cancelInvervalNotifications :: NtfTokenId -> m ()
|
||||
cancelInvervalNotifications tknId =
|
||||
atomically (TM.lookupDelete tknId intervalNotifiers)
|
||||
>>= mapM_ (uninterruptibleCancel . action)
|
||||
|
||||
withNtfLog :: (MonadUnliftIO m, MonadReader NtfEnv m) => (StoreLog 'WriteMode -> IO a) -> m ()
|
||||
withNtfLog action = liftIO . mapM_ action =<< asks storeLog
|
|
@ -0,0 +1,156 @@
|
|||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE DuplicateRecordFields #-}
|
||||
{-# LANGUAGE GADTs #-}
|
||||
{-# LANGUAGE KindSignatures #-}
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
|
||||
module Simplex.Messaging.Notifications.Server.Env where
|
||||
|
||||
import Control.Concurrent (ThreadId)
|
||||
import Control.Concurrent.Async (Async)
|
||||
import Control.Monad.IO.Unlift
|
||||
import Crypto.Random
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import Data.Time.Clock.System (SystemTime)
|
||||
import Data.Word (Word16)
|
||||
import Data.X509.Validation (Fingerprint (..))
|
||||
import Network.Socket
|
||||
import qualified Network.TLS as T
|
||||
import Numeric.Natural
|
||||
import Simplex.Messaging.Client.Agent
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Notifications.Protocol
|
||||
import Simplex.Messaging.Notifications.Server.Push.APNS
|
||||
import Simplex.Messaging.Notifications.Server.Store
|
||||
import Simplex.Messaging.Notifications.Server.StoreLog
|
||||
import Simplex.Messaging.Protocol (CorrId, SMPServer, Transmission)
|
||||
import Simplex.Messaging.Server.Expiration
|
||||
import Simplex.Messaging.TMap (TMap)
|
||||
import qualified Simplex.Messaging.TMap as TM
|
||||
import Simplex.Messaging.Transport (ATransport)
|
||||
import Simplex.Messaging.Transport.Server (loadFingerprint, loadTLSServerParams)
|
||||
import System.IO (IOMode (..))
|
||||
import System.Mem.Weak (Weak)
|
||||
import UnliftIO.STM
|
||||
|
||||
data NtfServerConfig = NtfServerConfig
|
||||
{ transports :: [(ServiceName, ATransport)],
|
||||
subIdBytes :: Int,
|
||||
regCodeBytes :: Int,
|
||||
clientQSize :: Natural,
|
||||
subQSize :: Natural,
|
||||
pushQSize :: Natural,
|
||||
smpAgentCfg :: SMPClientAgentConfig,
|
||||
apnsConfig :: APNSPushClientConfig,
|
||||
inactiveClientExpiration :: Maybe ExpirationConfig,
|
||||
storeLogFile :: Maybe FilePath,
|
||||
resubscribeDelay :: Int, -- microseconds
|
||||
-- CA certificate private key is not needed for initialization
|
||||
caCertificateFile :: FilePath,
|
||||
privateKeyFile :: FilePath,
|
||||
certificateFile :: FilePath
|
||||
}
|
||||
|
||||
defaultInactiveClientExpiration :: ExpirationConfig
|
||||
defaultInactiveClientExpiration =
|
||||
ExpirationConfig
|
||||
{ ttl = 7200, -- 2 hours
|
||||
checkInterval = 3600 -- seconds, 1 hour
|
||||
}
|
||||
|
||||
data NtfEnv = NtfEnv
|
||||
{ config :: NtfServerConfig,
|
||||
subscriber :: NtfSubscriber,
|
||||
pushServer :: NtfPushServer,
|
||||
store :: NtfStore,
|
||||
storeLog :: Maybe (StoreLog 'WriteMode),
|
||||
idsDrg :: TVar ChaChaDRG,
|
||||
serverIdentity :: C.KeyHash,
|
||||
tlsServerParams :: T.ServerParams,
|
||||
serverIdentity :: C.KeyHash
|
||||
}
|
||||
|
||||
newNtfServerEnv :: (MonadUnliftIO m, MonadRandom m) => NtfServerConfig -> m NtfEnv
|
||||
newNtfServerEnv config@NtfServerConfig {subQSize, pushQSize, smpAgentCfg, apnsConfig, storeLogFile, caCertificateFile, certificateFile, privateKeyFile} = do
|
||||
idsDrg <- newTVarIO =<< drgNew
|
||||
store <- atomically newNtfStore
|
||||
storeLog <- liftIO $ mapM (`readWriteNtfStore` store) storeLogFile
|
||||
subscriber <- atomically $ newNtfSubscriber subQSize smpAgentCfg
|
||||
pushServer <- atomically $ newNtfPushServer pushQSize apnsConfig
|
||||
tlsServerParams <- liftIO $ loadTLSServerParams caCertificateFile certificateFile privateKeyFile
|
||||
Fingerprint fp <- liftIO $ loadFingerprint caCertificateFile
|
||||
pure NtfEnv {config, subscriber, pushServer, store, storeLog, idsDrg, tlsServerParams, serverIdentity = C.KeyHash fp}
|
||||
|
||||
data NtfSubscriber = NtfSubscriber
|
||||
{ smpSubscribers :: TMap SMPServer SMPSubscriber,
|
||||
newSubQ :: TBQueue (NtfEntityRec 'Subscription),
|
||||
smpAgent :: SMPClientAgent
|
||||
}
|
||||
|
||||
newNtfSubscriber :: Natural -> SMPClientAgentConfig -> STM NtfSubscriber
|
||||
newNtfSubscriber qSize smpAgentCfg = do
|
||||
smpSubscribers <- TM.empty
|
||||
newSubQ <- newTBQueue qSize
|
||||
smpAgent <- newSMPClientAgent smpAgentCfg
|
||||
pure NtfSubscriber {smpSubscribers, newSubQ, smpAgent}
|
||||
|
||||
data SMPSubscriber = SMPSubscriber
|
||||
{ newSubQ :: TQueue (NtfEntityRec 'Subscription),
|
||||
subThreadId :: TVar (Maybe (Weak ThreadId))
|
||||
}
|
||||
|
||||
newSMPSubscriber :: STM SMPSubscriber
|
||||
newSMPSubscriber = do
|
||||
newSubQ <- newTQueue
|
||||
subThreadId <- newTVar Nothing
|
||||
pure SMPSubscriber {newSubQ, subThreadId}
|
||||
|
||||
data NtfPushServer = NtfPushServer
|
||||
{ pushQ :: TBQueue (NtfTknData, PushNotification),
|
||||
pushClients :: TMap PushProvider PushProviderClient,
|
||||
intervalNotifiers :: TMap NtfTokenId IntervalNotifier,
|
||||
apnsConfig :: APNSPushClientConfig
|
||||
}
|
||||
|
||||
data IntervalNotifier = IntervalNotifier
|
||||
{ action :: Async (),
|
||||
token :: NtfTknData,
|
||||
interval :: Word16
|
||||
}
|
||||
|
||||
newNtfPushServer :: Natural -> APNSPushClientConfig -> STM NtfPushServer
|
||||
newNtfPushServer qSize apnsConfig = do
|
||||
pushQ <- newTBQueue qSize
|
||||
pushClients <- TM.empty
|
||||
intervalNotifiers <- TM.empty
|
||||
pure NtfPushServer {pushQ, pushClients, intervalNotifiers, apnsConfig}
|
||||
|
||||
newPushClient :: NtfPushServer -> PushProvider -> IO PushProviderClient
|
||||
newPushClient NtfPushServer {apnsConfig, pushClients} pp = do
|
||||
c <- apnsPushProviderClient <$> createAPNSPushClient (apnsProviderHost pp) apnsConfig
|
||||
atomically $ TM.insert pp c pushClients
|
||||
pure c
|
||||
|
||||
getPushClient :: NtfPushServer -> PushProvider -> IO PushProviderClient
|
||||
getPushClient s@NtfPushServer {pushClients} pp =
|
||||
atomically (TM.lookup pp pushClients) >>= maybe (newPushClient s pp) pure
|
||||
|
||||
data NtfRequest
|
||||
= NtfReqNew CorrId ANewNtfEntity
|
||||
| forall e. NtfEntityI e => NtfReqCmd (SNtfEntity e) (NtfEntityRec e) (Transmission (NtfCommand e))
|
||||
|
||||
data NtfServerClient = NtfServerClient
|
||||
{ rcvQ :: TBQueue NtfRequest,
|
||||
sndQ :: TBQueue (Transmission NtfResponse),
|
||||
sessionId :: ByteString,
|
||||
connected :: TVar Bool,
|
||||
activeAt :: TVar SystemTime
|
||||
}
|
||||
|
||||
newNtfServerClient :: Natural -> ByteString -> SystemTime -> STM NtfServerClient
|
||||
newNtfServerClient qSize sessionId ts = do
|
||||
rcvQ <- newTBQueue qSize
|
||||
sndQ <- newTBQueue qSize
|
||||
connected <- newTVar True
|
||||
activeAt <- newTVar ts
|
||||
return NtfServerClient {rcvQ, sndQ, sessionId, connected, activeAt}
|
|
@ -0,0 +1 @@
|
|||
local.env
|
|
@ -0,0 +1,374 @@
|
|||
{-# LANGUAGE DeriveAnyClass #-}
|
||||
{-# LANGUAGE DeriveGeneric #-}
|
||||
{-# LANGUAGE DuplicateRecordFields #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-}
|
||||
|
||||
{-# HLINT ignore "Use newtype instead of data" #-}
|
||||
|
||||
module Simplex.Messaging.Notifications.Server.Push.APNS where
|
||||
|
||||
import Control.Exception (Exception)
|
||||
import Control.Logger.Simple
|
||||
import Control.Monad.Except
|
||||
import Crypto.Hash.Algorithms (SHA256 (..))
|
||||
import qualified Crypto.PubKey.ECC.ECDSA as EC
|
||||
import qualified Crypto.PubKey.ECC.Types as ECT
|
||||
import Crypto.Random (ChaChaDRG, drgNew)
|
||||
import qualified Crypto.Store.PKCS8 as PK
|
||||
import Data.ASN1.BinaryEncoding (DER (..))
|
||||
import Data.ASN1.Encoding
|
||||
import Data.ASN1.Types
|
||||
import Data.Aeson (FromJSON, ToJSON, (.=))
|
||||
import qualified Data.Aeson as J
|
||||
import qualified Data.Aeson.Encoding as JE
|
||||
import Data.Bifunctor (first)
|
||||
import qualified Data.ByteString.Base64.URL as U
|
||||
import Data.ByteString.Builder (lazyByteString)
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Lazy.Char8 as LB
|
||||
import qualified Data.CaseInsensitive as CI
|
||||
import Data.Int (Int64)
|
||||
import Data.Map.Strict (Map)
|
||||
import Data.Text (Text)
|
||||
import qualified Data.Text as T
|
||||
import Data.Text.Encoding (decodeUtf8With)
|
||||
import Data.Time.Clock.System
|
||||
import qualified Data.X509 as X
|
||||
import GHC.Generics (Generic)
|
||||
import Network.HTTP.Types (HeaderName, Status)
|
||||
import qualified Network.HTTP.Types as N
|
||||
import Network.HTTP2.Client (Request)
|
||||
import qualified Network.HTTP2.Client as H
|
||||
import Network.Socket (HostName, ServiceName)
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Encoding.String
|
||||
import Simplex.Messaging.Notifications.Protocol
|
||||
import Simplex.Messaging.Notifications.Server.Store (NtfTknData (..))
|
||||
import Simplex.Messaging.Protocol (EncNMsgMeta)
|
||||
import Simplex.Messaging.Transport.HTTP2.Client
|
||||
import System.Environment (getEnv)
|
||||
import UnliftIO.STM
|
||||
|
||||
data JWTHeader = JWTHeader
|
||||
{ alg :: Text, -- key algorithm, ES256 for APNS
|
||||
kid :: Text -- key ID
|
||||
}
|
||||
deriving (Show, Generic)
|
||||
|
||||
instance ToJSON JWTHeader where toEncoding = J.genericToEncoding J.defaultOptions
|
||||
|
||||
data JWTClaims = JWTClaims
|
||||
{ iss :: Text, -- issuer, team ID for APNS
|
||||
iat :: Int64 -- issue time, seconds from epoch
|
||||
}
|
||||
deriving (Show, Generic)
|
||||
|
||||
instance ToJSON JWTClaims where toEncoding = J.genericToEncoding J.defaultOptions
|
||||
|
||||
data JWTToken = JWTToken JWTHeader JWTClaims
|
||||
deriving (Show)
|
||||
|
||||
mkJWTToken :: JWTHeader -> Text -> IO JWTToken
|
||||
mkJWTToken hdr iss = do
|
||||
iat <- systemSeconds <$> getSystemTime
|
||||
pure $ JWTToken hdr JWTClaims {iss, iat}
|
||||
|
||||
type SignedJWTToken = ByteString
|
||||
|
||||
signedJWTToken :: EC.PrivateKey -> JWTToken -> IO SignedJWTToken
|
||||
signedJWTToken pk (JWTToken hdr claims) = do
|
||||
let hc = jwtEncode hdr <> "." <> jwtEncode claims
|
||||
sig <- EC.sign pk SHA256 hc
|
||||
pure $ hc <> "." <> serialize sig
|
||||
where
|
||||
jwtEncode :: ToJSON a => a -> ByteString
|
||||
jwtEncode = U.encodeUnpadded . LB.toStrict . J.encode
|
||||
serialize sig = U.encodeUnpadded $ encodeASN1' DER [Start Sequence, IntVal (EC.sign_r sig), IntVal (EC.sign_s sig), End Sequence]
|
||||
|
||||
readECPrivateKey :: FilePath -> IO EC.PrivateKey
|
||||
readECPrivateKey f = do
|
||||
-- TODO this is specific to APNS key
|
||||
[PK.Unprotected (X.PrivKeyEC X.PrivKeyEC_Named {privkeyEC_name, privkeyEC_priv})] <- PK.readKeyFile f
|
||||
pure EC.PrivateKey {private_curve = ECT.getCurveByName privkeyEC_name, private_d = privkeyEC_priv}
|
||||
|
||||
data PushNotification
|
||||
= PNVerification NtfRegCode
|
||||
| PNMessage PNMessageData
|
||||
| PNAlert Text
|
||||
| PNCheckMessages
|
||||
deriving (Show)
|
||||
|
||||
data PNMessageData = PNMessageData
|
||||
{ smpQueue :: SMPQueueNtf,
|
||||
ntfTs :: SystemTime,
|
||||
nmsgNonce :: C.CbNonce,
|
||||
encNMsgMeta :: EncNMsgMeta
|
||||
}
|
||||
deriving (Show)
|
||||
|
||||
instance StrEncoding PNMessageData where
|
||||
strEncode PNMessageData {smpQueue, ntfTs, nmsgNonce, encNMsgMeta} =
|
||||
strEncode (smpQueue, ntfTs, nmsgNonce, encNMsgMeta)
|
||||
strP = do
|
||||
(smpQueue, ntfTs, nmsgNonce, encNMsgMeta) <- strP
|
||||
pure PNMessageData {smpQueue, ntfTs, nmsgNonce, encNMsgMeta}
|
||||
|
||||
data APNSNotification = APNSNotification {aps :: APNSNotificationBody, notificationData :: Maybe J.Value}
|
||||
deriving (Show, Generic)
|
||||
|
||||
instance ToJSON APNSNotification where
|
||||
toJSON = J.genericToJSON J.defaultOptions {J.omitNothingFields = True}
|
||||
toEncoding = J.genericToEncoding J.defaultOptions {J.omitNothingFields = True}
|
||||
|
||||
data APNSNotificationBody
|
||||
= APNSBackground {contentAvailable :: Int}
|
||||
| APNSMutableContent {mutableContent :: Int, alert :: APNSAlertBody, category :: Maybe Text}
|
||||
| APNSAlert {alert :: APNSAlertBody, badge :: Maybe Int, sound :: Maybe Text, category :: Maybe Text}
|
||||
deriving (Show, Generic)
|
||||
|
||||
apnsJSONOptions :: J.Options
|
||||
apnsJSONOptions = J.defaultOptions {J.omitNothingFields = True, J.sumEncoding = J.UntaggedValue, J.fieldLabelModifier = J.camelTo2 '-'}
|
||||
|
||||
instance ToJSON APNSNotificationBody where
|
||||
toJSON = J.genericToJSON apnsJSONOptions
|
||||
toEncoding = J.genericToEncoding apnsJSONOptions
|
||||
|
||||
type APNSNotificationData = Map Text Text
|
||||
|
||||
data APNSAlertBody = APNSAlertObject {title :: Text, subtitle :: Text, body :: Text} | APNSAlertText Text
|
||||
deriving (Show)
|
||||
|
||||
instance ToJSON APNSAlertBody where
|
||||
toEncoding = \case
|
||||
APNSAlertObject {title, subtitle, body} -> J.pairs $ "title" .= title <> "subtitle" .= subtitle <> "body" .= body
|
||||
APNSAlertText t -> JE.text t
|
||||
toJSON = \case
|
||||
APNSAlertObject {title, subtitle, body} -> J.object ["title" .= title, "subtitle" .= subtitle, "body" .= body]
|
||||
APNSAlertText t -> J.String t
|
||||
|
||||
-- APNS notification types
|
||||
--
|
||||
-- Visible alerts:
|
||||
-- {
|
||||
-- "aps" : {
|
||||
-- "alert" : {
|
||||
-- "title" : "Game Request",
|
||||
-- "subtitle" : "Five Card Draw",
|
||||
-- "body" : "Bob wants to play poker"
|
||||
-- },
|
||||
-- "badge" : 9,
|
||||
-- "sound" : "bingbong.aiff",
|
||||
-- "category" : "GAME_INVITATION"
|
||||
-- },
|
||||
-- "gameID" : "12345678"
|
||||
-- }
|
||||
--
|
||||
-- Simple text alert:
|
||||
-- {"aps":{"alert":"you have a new message"}}
|
||||
--
|
||||
-- Background notification to fetch content
|
||||
-- {"aps":{"content-available":1}}
|
||||
--
|
||||
-- Mutable content notification that must be shown but can be processed before before being shown (up to 30 sec)
|
||||
-- {
|
||||
-- "aps" : {
|
||||
-- "category" : "SECRET",
|
||||
-- "mutable-content" : 1,
|
||||
-- "alert" : {
|
||||
-- "title" : "Secret Message!",
|
||||
-- "body" : "(Encrypted)"
|
||||
-- },
|
||||
-- },
|
||||
-- "ENCRYPTED_DATA" : "Salted__·öîQÊ$UDì_¶Ù∞èΩ^¬%gq∞NÿÒQùw"
|
||||
-- }
|
||||
|
||||
data APNSPushClientConfig = APNSPushClientConfig
|
||||
{ tokenTTL :: Int64,
|
||||
authKeyFileEnv :: String,
|
||||
authKeyAlg :: Text,
|
||||
authKeyIdEnv :: String,
|
||||
paddedNtfLength :: Int,
|
||||
appName :: ByteString,
|
||||
appTeamId :: Text,
|
||||
apnsPort :: ServiceName,
|
||||
http2cfg :: HTTP2ClientConfig
|
||||
}
|
||||
deriving (Show)
|
||||
|
||||
apnsProviderHost :: PushProvider -> HostName
|
||||
apnsProviderHost = \case
|
||||
PPApnsTest -> "localhost"
|
||||
PPApnsDev -> "api.sandbox.push.apple.com"
|
||||
PPApnsProd -> "api.push.apple.com"
|
||||
|
||||
defaultAPNSPushClientConfig :: APNSPushClientConfig
|
||||
defaultAPNSPushClientConfig =
|
||||
APNSPushClientConfig
|
||||
{ tokenTTL = 1800, -- 30 minutes
|
||||
authKeyFileEnv = "APNS_KEY_FILE", -- the environment variables APNS_KEY_FILE and APNS_KEY_ID must be set, or the server would fail to start
|
||||
authKeyAlg = "ES256",
|
||||
authKeyIdEnv = "APNS_KEY_ID",
|
||||
paddedNtfLength = 512,
|
||||
appName = "chat.simplex.app",
|
||||
appTeamId = "5NN7GUYB6T",
|
||||
apnsPort = "443",
|
||||
http2cfg = defaultHTTP2ClientConfig
|
||||
}
|
||||
|
||||
data APNSPushClient = APNSPushClient
|
||||
{ https2Client :: TVar (Maybe HTTP2Client),
|
||||
privateKey :: EC.PrivateKey,
|
||||
jwtHeader :: JWTHeader,
|
||||
jwtToken :: TVar (JWTToken, SignedJWTToken),
|
||||
nonceDrg :: TVar ChaChaDRG,
|
||||
apnsHost :: HostName,
|
||||
apnsCfg :: APNSPushClientConfig
|
||||
}
|
||||
|
||||
createAPNSPushClient :: HostName -> APNSPushClientConfig -> IO APNSPushClient
|
||||
createAPNSPushClient apnsHost apnsCfg@APNSPushClientConfig {authKeyFileEnv, authKeyAlg, authKeyIdEnv, appTeamId} = do
|
||||
https2Client <- newTVarIO Nothing
|
||||
void $ connectHTTPS2 apnsHost apnsCfg https2Client
|
||||
privateKey <- readECPrivateKey =<< getEnv authKeyFileEnv
|
||||
authKeyId <- T.pack <$> getEnv authKeyIdEnv
|
||||
let jwtHeader = JWTHeader {alg = authKeyAlg, kid = authKeyId}
|
||||
jwtToken <- newTVarIO =<< mkApnsJWTToken appTeamId jwtHeader privateKey
|
||||
nonceDrg <- drgNew >>= newTVarIO
|
||||
pure APNSPushClient {https2Client, privateKey, jwtHeader, jwtToken, nonceDrg, apnsHost, apnsCfg}
|
||||
|
||||
getApnsJWTToken :: APNSPushClient -> IO SignedJWTToken
|
||||
getApnsJWTToken APNSPushClient {apnsCfg = APNSPushClientConfig {appTeamId, tokenTTL}, privateKey, jwtHeader, jwtToken} = do
|
||||
(jwt, signedJWT) <- readTVarIO jwtToken
|
||||
age <- jwtTokenAge jwt
|
||||
if age < tokenTTL
|
||||
then pure signedJWT
|
||||
else do
|
||||
t@(_, signedJWT') <- mkApnsJWTToken appTeamId jwtHeader privateKey
|
||||
atomically $ writeTVar jwtToken t
|
||||
pure signedJWT'
|
||||
where
|
||||
jwtTokenAge (JWTToken _ JWTClaims {iat}) = subtract iat . systemSeconds <$> getSystemTime
|
||||
|
||||
mkApnsJWTToken :: Text -> JWTHeader -> EC.PrivateKey -> IO (JWTToken, SignedJWTToken)
|
||||
mkApnsJWTToken appTeamId jwtHeader privateKey = do
|
||||
jwt <- mkJWTToken jwtHeader appTeamId
|
||||
signedJWT <- signedJWTToken privateKey jwt
|
||||
pure (jwt, signedJWT)
|
||||
|
||||
connectHTTPS2 :: HostName -> APNSPushClientConfig -> TVar (Maybe HTTP2Client) -> IO (Either HTTP2ClientError HTTP2Client)
|
||||
connectHTTPS2 apnsHost APNSPushClientConfig {apnsPort, http2cfg} https2Client = do
|
||||
r <- getHTTP2Client apnsHost apnsPort http2cfg disconnected
|
||||
case r of
|
||||
Right client -> atomically . writeTVar https2Client $ Just client
|
||||
Left e -> putStrLn $ "Error connecting to APNS: " <> show e
|
||||
pure r
|
||||
where
|
||||
disconnected = atomically $ writeTVar https2Client Nothing
|
||||
|
||||
getApnsHTTP2Client :: APNSPushClient -> IO (Either HTTP2ClientError HTTP2Client)
|
||||
getApnsHTTP2Client APNSPushClient {https2Client, apnsHost, apnsCfg} =
|
||||
readTVarIO https2Client >>= maybe (connectHTTPS2 apnsHost apnsCfg https2Client) (pure . Right)
|
||||
|
||||
disconnectApnsHTTP2Client :: APNSPushClient -> IO ()
|
||||
disconnectApnsHTTP2Client APNSPushClient {https2Client} =
|
||||
readTVarIO https2Client >>= mapM_ closeHTTP2Client >> atomically (writeTVar https2Client Nothing)
|
||||
|
||||
ntfCategoryCheckMessage :: Text
|
||||
ntfCategoryCheckMessage = "NTF_CAT_CHECK_MESSAGE"
|
||||
|
||||
apnsNotification :: NtfTknData -> C.CbNonce -> Int -> PushNotification -> Either C.CryptoError APNSNotification
|
||||
apnsNotification NtfTknData {tknDhSecret} nonce paddedLen = \case
|
||||
PNVerification (NtfRegCode code) ->
|
||||
encrypt code $ \code' ->
|
||||
apn APNSBackground {contentAvailable = 1} . Just $ J.object ["nonce" .= nonce, "verification" .= code']
|
||||
PNMessage pnMessageData ->
|
||||
encrypt (strEncode pnMessageData) $ \ntfData ->
|
||||
apn apnMutableContent . Just $ J.object ["nonce" .= nonce, "message" .= ntfData]
|
||||
PNAlert text -> Right $ apn (apnAlert $ APNSAlertText text) Nothing
|
||||
PNCheckMessages -> Right $ apn APNSBackground {contentAvailable = 1} . Just $ J.object ["checkMessages" .= True]
|
||||
where
|
||||
encrypt :: ByteString -> (Text -> APNSNotification) -> Either C.CryptoError APNSNotification
|
||||
encrypt ntfData f = f . safeDecodeUtf8 . U.encode <$> C.cbEncrypt tknDhSecret nonce ntfData paddedLen
|
||||
apn aps notificationData = APNSNotification {aps, notificationData}
|
||||
apnMutableContent = APNSMutableContent {mutableContent = 1, alert = APNSAlertText "Encrypted message or another app event", category = Just ntfCategoryCheckMessage}
|
||||
apnAlert alert = APNSAlert {alert, badge = Nothing, sound = Nothing, category = Nothing}
|
||||
safeDecodeUtf8 = decodeUtf8With onError where onError _ _ = Just '?'
|
||||
|
||||
apnsRequest :: APNSPushClient -> ByteString -> APNSNotification -> IO Request
|
||||
apnsRequest c tkn ntf@APNSNotification {aps} = do
|
||||
signedJWT <- getApnsJWTToken c
|
||||
pure $ H.requestBuilder N.methodPost path (headers signedJWT) (lazyByteString $ J.encode ntf)
|
||||
where
|
||||
path = "/3/device/" <> tkn
|
||||
headers signedJWT =
|
||||
[ (hApnsTopic, appName $ apnsCfg (c :: APNSPushClient)),
|
||||
(hApnsPushType, pushType aps),
|
||||
(N.hAuthorization, "bearer " <> signedJWT)
|
||||
]
|
||||
<> [(hApnsPriority, "5") | isBackground aps]
|
||||
isBackground = \case
|
||||
APNSBackground {} -> True
|
||||
_ -> False
|
||||
pushType = \case
|
||||
APNSBackground {} -> "background"
|
||||
_ -> "alert"
|
||||
|
||||
data PushProviderError
|
||||
= PPConnection HTTP2ClientError
|
||||
| PPCryptoError C.CryptoError
|
||||
| PPResponseError (Maybe Status) Text
|
||||
| PPTokenInvalid
|
||||
| PPRetryLater
|
||||
| PPPermanentError
|
||||
deriving (Show, Exception)
|
||||
|
||||
type PushProviderClient = NtfTknData -> PushNotification -> ExceptT PushProviderError IO ()
|
||||
|
||||
-- this is not a newtype on purpose to have a correct JSON encoding as a record
|
||||
data APNSErrorResponse = APNSErrorResponse {reason :: Text}
|
||||
deriving (Generic, FromJSON)
|
||||
|
||||
apnsPushProviderClient :: APNSPushClient -> PushProviderClient
|
||||
apnsPushProviderClient c@APNSPushClient {nonceDrg, apnsCfg} tkn@NtfTknData {token = DeviceToken _ tknStr} pn = do
|
||||
http2 <- liftHTTPS2 $ getApnsHTTP2Client c
|
||||
nonce <- atomically $ C.pseudoRandomCbNonce nonceDrg
|
||||
apnsNtf <- liftEither $ first PPCryptoError $ apnsNotification tkn nonce (paddedNtfLength apnsCfg) pn
|
||||
req <- liftIO $ apnsRequest c tknStr apnsNtf
|
||||
HTTP2Response {response, respBody} <- liftHTTPS2 $ sendRequest http2 req
|
||||
let status = H.responseStatus response
|
||||
reason' = maybe "" reason $ J.decodeStrict' respBody
|
||||
logDebug $ "APNS response: " <> T.pack (show status) <> " " <> reason'
|
||||
result status reason'
|
||||
where
|
||||
result :: Maybe Status -> Text -> ExceptT PushProviderError IO ()
|
||||
result status reason'
|
||||
| status == Just N.ok200 = pure ()
|
||||
| status == Just N.badRequest400 =
|
||||
case reason' of
|
||||
"BadDeviceToken" -> throwError PPTokenInvalid
|
||||
"DeviceTokenNotForTopic" -> throwError PPTokenInvalid
|
||||
"TopicDisallowed" -> throwError PPPermanentError
|
||||
_ -> err status reason'
|
||||
| status == Just N.forbidden403 = case reason' of
|
||||
"ExpiredProviderToken" -> throwError PPPermanentError -- there should be no point retrying it as the token was refreshed
|
||||
"InvalidProviderToken" -> throwError PPPermanentError
|
||||
_ -> err status reason'
|
||||
| status == Just N.gone410 = throwError PPTokenInvalid
|
||||
| status == Just N.serviceUnavailable503 = liftIO (disconnectApnsHTTP2Client c) >> throwError PPRetryLater
|
||||
-- Just tooManyRequests429 -> TODO TooManyRequests - too many requests for the same token
|
||||
| otherwise = err status reason'
|
||||
err :: Maybe Status -> Text -> ExceptT PushProviderError IO ()
|
||||
err s r = throwError $ PPResponseError s r
|
||||
liftHTTPS2 a = ExceptT $ first PPConnection <$> a
|
||||
|
||||
hApnsTopic :: HeaderName
|
||||
hApnsTopic = CI.mk "apns-topic"
|
||||
|
||||
hApnsPushType :: HeaderName
|
||||
hApnsPushType = CI.mk "apns-push-type"
|
||||
|
||||
hApnsPriority :: HeaderName
|
||||
hApnsPriority = CI.mk "apns-priority"
|
|
@ -0,0 +1,26 @@
|
|||
#!/bin/sh
|
||||
|
||||
export TEAM_ID=5NN7GUYB6T
|
||||
# export APNS_KEY_FILE=""
|
||||
# export APNS_KEY_ID=""
|
||||
export TOPIC=chat.simplex.app
|
||||
# export DEVICE_TOKEN=
|
||||
export APNS_HOST_NAME=api.sandbox.push.apple.com
|
||||
|
||||
export JWT_ISSUE_TIME=$(date +%s)
|
||||
export JWT_HEADER=$(printf '{"alg":"ES256","kid":"%s"}' "${APNS_KEY_ID}" | openssl base64 -e -A | tr -- '+/' '-_' | tr -d =)
|
||||
export JWT_CLAIMS=$(printf '{"iss":"%s","iat":%d}' "${TEAM_ID}" "${JWT_ISSUE_TIME}" | openssl base64 -e -A | tr -- '+/' '-_' | tr -d =)
|
||||
export JWT_HEADER_CLAIMS="${JWT_HEADER}.${JWT_CLAIMS}"
|
||||
|
||||
export JWT_SIGNED_HEADER_CLAIMS=$(printf "${JWT_HEADER_CLAIMS}" | openssl dgst -binary -sha256 -sign "${APNS_KEY_FILE}" | openssl base64 -e -A | tr -- '+/' '-_' | tr -d =)
|
||||
export AUTHENTICATION_TOKEN="${JWT_HEADER}.${JWT_CLAIMS}.${JWT_SIGNED_HEADER_CLAIMS}"
|
||||
|
||||
# simple alert
|
||||
# curl -v --header "apns-topic: $TOPIC" --header "apns-push-type: alert" --header "authorization: bearer $AUTHENTICATION_TOKEN" --data '{"aps":{"alert":"you have a new message"},"data":{"test":"123"}}' --http2 https://${APNS_HOST_NAME}/3/device/${DEVICE_TOKEN}
|
||||
|
||||
# background notification
|
||||
# curl -v --header "apns-topic: $TOPIC" --header "apns-push-type: background" --header "apns-priority: 5" --header "authorization: bearer $AUTHENTICATION_TOKEN" --data '{"aps":{"content-available":1}}' --http2 https://${APNS_HOST_NAME}/3/device/${DEVICE_TOKEN}
|
||||
|
||||
# mutable-content notification
|
||||
# NTF_CAT_CHECK_MESSAGE category will not show alert if the app is in foreground
|
||||
curl -v --header "apns-topic: $TOPIC" --header "apns-push-type: alert" --header "authorization: bearer $AUTHENTICATION_TOKEN" --data '{"aps":{"category": "NTF_CAT_CHECK_MESSAGE__SECRET", "mutable-content": 1, "alert":"received encrypted message"}, "data": {"test":"123"}}' --http2 https://${APNS_HOST_NAME}/3/device/${DEVICE_TOKEN}
|
|
@ -0,0 +1,200 @@
|
|||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE DuplicateRecordFields #-}
|
||||
{-# LANGUAGE GADTs #-}
|
||||
{-# LANGUAGE KindSignatures #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
|
||||
module Simplex.Messaging.Notifications.Server.Store where
|
||||
|
||||
import Control.Concurrent.STM
|
||||
import Control.Monad
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import Data.Functor (($>))
|
||||
import qualified Data.Map.Strict as M
|
||||
import Data.Maybe (catMaybes)
|
||||
import Data.Set (Set)
|
||||
import qualified Data.Set as S
|
||||
import Data.Word (Word16)
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Notifications.Protocol
|
||||
import Simplex.Messaging.Protocol (NtfPrivateSignKey)
|
||||
import Simplex.Messaging.TMap (TMap)
|
||||
import qualified Simplex.Messaging.TMap as TM
|
||||
import Simplex.Messaging.Util (whenM, ($>>=))
|
||||
|
||||
data NtfStore = NtfStore
|
||||
{ tokens :: TMap NtfTokenId NtfTknData,
|
||||
-- multiple registrations exist to protect from malicious registrations if token is compromised
|
||||
tokenRegistrations :: TMap DeviceToken (TMap ByteString NtfTokenId),
|
||||
subscriptions :: TMap NtfSubscriptionId NtfSubData,
|
||||
tokenSubscriptions :: TMap NtfTokenId (TVar (Set NtfSubscriptionId)),
|
||||
subscriptionLookup :: TMap SMPQueueNtf NtfSubscriptionId
|
||||
}
|
||||
|
||||
newNtfStore :: STM NtfStore
|
||||
newNtfStore = do
|
||||
tokens <- TM.empty
|
||||
tokenRegistrations <- TM.empty
|
||||
subscriptions <- TM.empty
|
||||
tokenSubscriptions <- TM.empty
|
||||
subscriptionLookup <- TM.empty
|
||||
pure NtfStore {tokens, tokenRegistrations, subscriptions, tokenSubscriptions, subscriptionLookup}
|
||||
|
||||
data NtfTknData = NtfTknData
|
||||
{ ntfTknId :: NtfTokenId,
|
||||
token :: DeviceToken,
|
||||
tknStatus :: TVar NtfTknStatus,
|
||||
tknVerifyKey :: C.APublicVerifyKey,
|
||||
tknDhKeys :: C.KeyPair 'C.X25519,
|
||||
tknDhSecret :: C.DhSecretX25519,
|
||||
tknRegCode :: NtfRegCode,
|
||||
tknCronInterval :: TVar Word16
|
||||
}
|
||||
|
||||
mkNtfTknData :: NtfTokenId -> NewNtfEntity 'Token -> C.KeyPair 'C.X25519 -> C.DhSecretX25519 -> NtfRegCode -> STM NtfTknData
|
||||
mkNtfTknData ntfTknId (NewNtfTkn token tknVerifyKey _) tknDhKeys tknDhSecret tknRegCode = do
|
||||
tknStatus <- newTVar NTRegistered
|
||||
tknCronInterval <- newTVar 0
|
||||
pure NtfTknData {ntfTknId, token, tknStatus, tknVerifyKey, tknDhKeys, tknDhSecret, tknRegCode, tknCronInterval}
|
||||
|
||||
data NtfSubData = NtfSubData
|
||||
{ ntfSubId :: NtfSubscriptionId,
|
||||
smpQueue :: SMPQueueNtf,
|
||||
notifierKey :: NtfPrivateSignKey,
|
||||
tokenId :: NtfTokenId,
|
||||
subStatus :: TVar NtfSubStatus
|
||||
}
|
||||
|
||||
data NtfEntityRec (e :: NtfEntity) where
|
||||
NtfTkn :: NtfTknData -> NtfEntityRec 'Token
|
||||
NtfSub :: NtfSubData -> NtfEntityRec 'Subscription
|
||||
|
||||
getNtfToken :: NtfStore -> NtfTokenId -> STM (Maybe NtfTknData)
|
||||
getNtfToken st tknId = TM.lookup tknId (tokens st)
|
||||
|
||||
addNtfToken :: NtfStore -> NtfTokenId -> NtfTknData -> STM ()
|
||||
addNtfToken st tknId tkn@NtfTknData {token, tknVerifyKey} = do
|
||||
TM.insert tknId tkn $ tokens st
|
||||
TM.lookup token regs >>= \case
|
||||
Just tIds -> TM.insert regKey tknId tIds
|
||||
_ -> do
|
||||
tIds <- TM.singleton regKey tknId
|
||||
TM.insert token tIds regs
|
||||
where
|
||||
regs = tokenRegistrations st
|
||||
regKey = C.toPubKey C.pubKeyBytes tknVerifyKey
|
||||
|
||||
getNtfTokenRegistration :: NtfStore -> NewNtfEntity 'Token -> STM (Maybe NtfTknData)
|
||||
getNtfTokenRegistration st (NewNtfTkn token tknVerifyKey _) =
|
||||
TM.lookup token (tokenRegistrations st)
|
||||
$>>= TM.lookup regKey
|
||||
$>>= (`TM.lookup` tokens st)
|
||||
where
|
||||
regKey = C.toPubKey C.pubKeyBytes tknVerifyKey
|
||||
|
||||
removeInactiveTokenRegistrations :: NtfStore -> NtfTknData -> STM [NtfTokenId]
|
||||
removeInactiveTokenRegistrations st NtfTknData {ntfTknId = tId, token} =
|
||||
TM.lookup token (tokenRegistrations st)
|
||||
>>= maybe (pure []) removeRegs
|
||||
where
|
||||
removeRegs :: TMap ByteString NtfTokenId -> STM [NtfTokenId]
|
||||
removeRegs tknRegs = do
|
||||
tIds <- filter ((/= tId) . snd) . M.assocs <$> readTVar tknRegs
|
||||
forM_ tIds $ \(regKey, tId') -> do
|
||||
TM.delete regKey tknRegs
|
||||
TM.delete tId' $ tokens st
|
||||
pure $ map snd tIds
|
||||
|
||||
removeTokenRegistration :: NtfStore -> NtfTknData -> STM ()
|
||||
removeTokenRegistration st NtfTknData {ntfTknId = tId, token, tknVerifyKey} =
|
||||
TM.lookup token (tokenRegistrations st) >>= mapM_ removeReg
|
||||
where
|
||||
removeReg regs =
|
||||
TM.lookup k regs
|
||||
>>= mapM_ (\tId' -> when (tId == tId') $ TM.delete k regs)
|
||||
k = C.toPubKey C.pubKeyBytes tknVerifyKey
|
||||
|
||||
deleteNtfToken :: NtfStore -> NtfTokenId -> STM [SMPQueueNtf]
|
||||
deleteNtfToken st tknId = do
|
||||
TM.lookupDelete tknId (tokens st)
|
||||
>>= mapM_
|
||||
( \NtfTknData {token, tknVerifyKey} ->
|
||||
TM.lookup token regs
|
||||
>>= mapM_
|
||||
( \tIds -> do
|
||||
TM.delete (regKey tknVerifyKey) tIds
|
||||
whenM (TM.null tIds) $ TM.delete token regs
|
||||
)
|
||||
)
|
||||
|
||||
qs <-
|
||||
TM.lookupDelete tknId (tokenSubscriptions st)
|
||||
>>= mapM
|
||||
( readTVar
|
||||
>=> mapM
|
||||
( \subId -> do
|
||||
TM.lookupDelete subId (subscriptions st)
|
||||
>>= mapM
|
||||
( \NtfSubData {smpQueue} ->
|
||||
TM.delete smpQueue (subscriptionLookup st) $> smpQueue
|
||||
)
|
||||
)
|
||||
. S.toList
|
||||
)
|
||||
pure $ maybe [] catMaybes qs
|
||||
where
|
||||
regs = tokenRegistrations st
|
||||
regKey = C.toPubKey C.pubKeyBytes
|
||||
|
||||
getNtfSubscription :: NtfStore -> NtfSubscriptionId -> STM (Maybe NtfSubData)
|
||||
getNtfSubscription st subId =
|
||||
TM.lookup subId (subscriptions st)
|
||||
|
||||
findNtfSubscription :: NtfStore -> SMPQueueNtf -> STM (Maybe NtfSubData)
|
||||
findNtfSubscription st smpQueue = do
|
||||
TM.lookup smpQueue (subscriptionLookup st)
|
||||
$>>= \subId -> TM.lookup subId (subscriptions st)
|
||||
|
||||
findNtfSubscriptionToken :: NtfStore -> SMPQueueNtf -> STM (Maybe NtfTknData)
|
||||
findNtfSubscriptionToken st smpQueue = do
|
||||
findNtfSubscription st smpQueue
|
||||
$>>= \NtfSubData {tokenId} -> getActiveNtfToken st tokenId
|
||||
|
||||
getActiveNtfToken :: NtfStore -> NtfTokenId -> STM (Maybe NtfTknData)
|
||||
getActiveNtfToken st tknId =
|
||||
getNtfToken st tknId $>>= \tkn@NtfTknData {tknStatus} -> do
|
||||
tStatus <- readTVar tknStatus
|
||||
pure $ if tStatus == NTActive then Just tkn else Nothing
|
||||
|
||||
mkNtfSubData :: NtfSubscriptionId -> NewNtfEntity 'Subscription -> STM NtfSubData
|
||||
mkNtfSubData ntfSubId (NewNtfSub tokenId smpQueue notifierKey) = do
|
||||
subStatus <- newTVar NSNew
|
||||
pure NtfSubData {ntfSubId, smpQueue, tokenId, subStatus, notifierKey}
|
||||
|
||||
addNtfSubscription :: NtfStore -> NtfSubscriptionId -> NtfSubData -> STM (Maybe ())
|
||||
addNtfSubscription st subId sub@NtfSubData {smpQueue, tokenId} =
|
||||
TM.lookup tokenId (tokenSubscriptions st) >>= maybe newTokenSub pure >>= insertSub
|
||||
where
|
||||
newTokenSub = do
|
||||
ts <- newTVar S.empty
|
||||
TM.insert tokenId ts $ tokenSubscriptions st
|
||||
pure ts
|
||||
insertSub ts = do
|
||||
modifyTVar' ts $ S.insert subId
|
||||
TM.insert subId sub $ subscriptions st
|
||||
TM.insert smpQueue subId (subscriptionLookup st)
|
||||
-- return Nothing if subscription existed before
|
||||
pure $ Just ()
|
||||
|
||||
deleteNtfSubscription :: NtfStore -> NtfSubscriptionId -> STM ()
|
||||
deleteNtfSubscription st subId = do
|
||||
TM.lookupDelete subId (subscriptions st)
|
||||
>>= mapM_
|
||||
( \NtfSubData {smpQueue, tokenId} -> do
|
||||
TM.delete smpQueue $ subscriptionLookup st
|
||||
ts_ <- TM.lookup tokenId (tokenSubscriptions st)
|
||||
forM_ ts_ $ \ts -> modifyTVar' ts $ S.delete subId
|
||||
)
|
|
@ -0,0 +1,227 @@
|
|||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE DuplicateRecordFields #-}
|
||||
{-# LANGUAGE GADTs #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
|
||||
module Simplex.Messaging.Notifications.Server.StoreLog
|
||||
( StoreLog,
|
||||
NtfStoreLogRecord (..),
|
||||
readWriteNtfStore,
|
||||
logCreateToken,
|
||||
logTokenStatus,
|
||||
logUpdateToken,
|
||||
logTokenCron,
|
||||
logDeleteToken,
|
||||
logCreateSubscription,
|
||||
logSubscriptionStatus,
|
||||
logDeleteSubscription,
|
||||
closeStoreLog,
|
||||
)
|
||||
where
|
||||
|
||||
import Control.Concurrent.STM
|
||||
import Control.Monad (void)
|
||||
import qualified Data.Attoparsec.ByteString.Char8 as A
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import Data.Word (Word16)
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Encoding.String
|
||||
import Simplex.Messaging.Notifications.Protocol
|
||||
import Simplex.Messaging.Notifications.Server.Store
|
||||
import Simplex.Messaging.Protocol (NtfPrivateSignKey)
|
||||
import Simplex.Messaging.Server.StoreLog
|
||||
import Simplex.Messaging.Util (whenM)
|
||||
import System.Directory (doesFileExist, renameFile)
|
||||
import System.IO
|
||||
|
||||
data NtfStoreLogRecord
|
||||
= CreateToken NtfTknRec
|
||||
| TokenStatus NtfTokenId NtfTknStatus
|
||||
| UpdateToken NtfTokenId DeviceToken NtfRegCode
|
||||
| TokenCron NtfTokenId Word16
|
||||
| DeleteToken NtfTokenId
|
||||
| CreateSubscription NtfSubRec
|
||||
| SubscriptionStatus NtfSubscriptionId NtfSubStatus
|
||||
| DeleteSubscription NtfSubscriptionId
|
||||
|
||||
data NtfTknRec = NtfTknRec
|
||||
{ ntfTknId :: NtfTokenId,
|
||||
token :: DeviceToken,
|
||||
tknStatus :: NtfTknStatus,
|
||||
tknVerifyKey :: C.APublicVerifyKey,
|
||||
tknDhKeys :: C.KeyPair 'C.X25519,
|
||||
tknDhSecret :: C.DhSecretX25519,
|
||||
tknRegCode :: NtfRegCode,
|
||||
tknCronInterval :: Word16
|
||||
}
|
||||
|
||||
mkTknData :: NtfTknRec -> STM NtfTknData
|
||||
mkTknData NtfTknRec {ntfTknId, token, tknStatus = status, tknVerifyKey, tknDhKeys, tknDhSecret, tknRegCode, tknCronInterval = cronInt} = do
|
||||
tknStatus <- newTVar status
|
||||
tknCronInterval <- newTVar cronInt
|
||||
pure NtfTknData {ntfTknId, token, tknStatus, tknVerifyKey, tknDhKeys, tknDhSecret, tknRegCode, tknCronInterval}
|
||||
|
||||
mkTknRec :: NtfTknData -> STM NtfTknRec
|
||||
mkTknRec NtfTknData {ntfTknId, token, tknStatus = status, tknVerifyKey, tknDhKeys, tknDhSecret, tknRegCode, tknCronInterval = cronInt} = do
|
||||
tknStatus <- readTVar status
|
||||
tknCronInterval <- readTVar cronInt
|
||||
pure NtfTknRec {ntfTknId, token, tknStatus, tknVerifyKey, tknDhKeys, tknDhSecret, tknRegCode, tknCronInterval}
|
||||
|
||||
data NtfSubRec = NtfSubRec
|
||||
{ ntfSubId :: NtfSubscriptionId,
|
||||
smpQueue :: SMPQueueNtf,
|
||||
notifierKey :: NtfPrivateSignKey,
|
||||
tokenId :: NtfTokenId,
|
||||
subStatus :: NtfSubStatus
|
||||
}
|
||||
|
||||
mkSubData :: NtfSubRec -> STM NtfSubData
|
||||
mkSubData NtfSubRec {ntfSubId, smpQueue, notifierKey, tokenId, subStatus = status} = do
|
||||
subStatus <- newTVar status
|
||||
pure NtfSubData {ntfSubId, smpQueue, notifierKey, tokenId, subStatus}
|
||||
|
||||
mkSubRec :: NtfSubData -> STM NtfSubRec
|
||||
mkSubRec NtfSubData {ntfSubId, smpQueue, notifierKey, tokenId, subStatus = status} = do
|
||||
subStatus <- readTVar status
|
||||
pure NtfSubRec {ntfSubId, smpQueue, notifierKey, tokenId, subStatus}
|
||||
|
||||
instance StrEncoding NtfStoreLogRecord where
|
||||
strEncode = \case
|
||||
CreateToken tknRec -> strEncode (Str "TCREATE", tknRec)
|
||||
TokenStatus tknId tknStatus -> strEncode (Str "TSTATUS", tknId, tknStatus)
|
||||
UpdateToken tknId token regCode -> strEncode (Str "TUPDATE", tknId, token, regCode)
|
||||
TokenCron tknId cronInt -> strEncode (Str "TCRON", tknId, cronInt)
|
||||
DeleteToken tknId -> strEncode (Str "TDELETE", tknId)
|
||||
CreateSubscription subRec -> strEncode (Str "SCREATE", subRec)
|
||||
SubscriptionStatus subId subStatus -> strEncode (Str "SSTATUS", subId, subStatus)
|
||||
DeleteSubscription subId -> strEncode (Str "SDELETE", subId)
|
||||
strP =
|
||||
A.choice
|
||||
[ "TCREATE " *> (CreateToken <$> strP),
|
||||
"TSTATUS " *> (TokenStatus <$> strP_ <*> strP),
|
||||
"TUPDATE " *> (UpdateToken <$> strP_ <*> strP_ <*> strP),
|
||||
"TCRON " *> (TokenCron <$> strP_ <*> strP),
|
||||
"TDELETE " *> (DeleteToken <$> strP),
|
||||
"SCREATE " *> (CreateSubscription <$> strP),
|
||||
"SSTATUS " *> (SubscriptionStatus <$> strP_ <*> strP),
|
||||
"SDELETE " *> (DeleteSubscription <$> strP)
|
||||
]
|
||||
|
||||
instance StrEncoding NtfTknRec where
|
||||
strEncode NtfTknRec {ntfTknId, token, tknStatus, tknVerifyKey, tknDhKeys, tknDhSecret, tknRegCode, tknCronInterval} =
|
||||
B.unwords
|
||||
[ "tknId=" <> strEncode ntfTknId,
|
||||
"token=" <> strEncode token,
|
||||
"tokenStatus=" <> strEncode tknStatus,
|
||||
"verifyKey=" <> strEncode tknVerifyKey,
|
||||
"dhKeys=" <> strEncode tknDhKeys,
|
||||
"dhSecret=" <> strEncode tknDhSecret,
|
||||
"regCode=" <> strEncode tknRegCode,
|
||||
"cron=" <> strEncode tknCronInterval
|
||||
]
|
||||
strP = do
|
||||
ntfTknId <- "tknId=" *> strP_
|
||||
token <- "token=" *> strP_
|
||||
tknStatus <- "tokenStatus=" *> strP_
|
||||
tknVerifyKey <- "verifyKey=" *> strP_
|
||||
tknDhKeys <- "dhKeys=" *> strP_
|
||||
tknDhSecret <- "dhSecret=" *> strP_
|
||||
tknRegCode <- "regCode=" *> strP_
|
||||
tknCronInterval <- "cron=" *> strP
|
||||
pure NtfTknRec {ntfTknId, token, tknStatus, tknVerifyKey, tknDhKeys, tknDhSecret, tknRegCode, tknCronInterval}
|
||||
|
||||
instance StrEncoding NtfSubRec where
|
||||
strEncode NtfSubRec {ntfSubId, smpQueue, notifierKey, tokenId, subStatus} =
|
||||
B.unwords
|
||||
[ "subId=" <> strEncode ntfSubId,
|
||||
"smpQueue=" <> strEncode smpQueue,
|
||||
"notifierKey=" <> strEncode notifierKey,
|
||||
"tknId=" <> strEncode tokenId,
|
||||
"subStatus=" <> strEncode subStatus
|
||||
]
|
||||
strP = do
|
||||
ntfSubId <- "subId=" *> strP_
|
||||
smpQueue <- "smpQueue=" *> strP_
|
||||
notifierKey <- "notifierKey=" *> strP_
|
||||
tokenId <- "tknId=" *> strP_
|
||||
subStatus <- "subStatus=" *> strP
|
||||
pure NtfSubRec {ntfSubId, smpQueue, notifierKey, tokenId, subStatus}
|
||||
|
||||
logNtfStoreRecord :: StoreLog 'WriteMode -> NtfStoreLogRecord -> IO ()
|
||||
logNtfStoreRecord = writeStoreLogRecord
|
||||
|
||||
logCreateToken :: StoreLog 'WriteMode -> NtfTknData -> IO ()
|
||||
logCreateToken s tkn = logNtfStoreRecord s . CreateToken =<< atomically (mkTknRec tkn)
|
||||
|
||||
logTokenStatus :: StoreLog 'WriteMode -> NtfTokenId -> NtfTknStatus -> IO ()
|
||||
logTokenStatus s tknId tknStatus = logNtfStoreRecord s $ TokenStatus tknId tknStatus
|
||||
|
||||
logUpdateToken :: StoreLog 'WriteMode -> NtfTokenId -> DeviceToken -> NtfRegCode -> IO ()
|
||||
logUpdateToken s tknId token regCode = logNtfStoreRecord s $ UpdateToken tknId token regCode
|
||||
|
||||
logTokenCron :: StoreLog 'WriteMode -> NtfTokenId -> Word16 -> IO ()
|
||||
logTokenCron s tknId cronInt = logNtfStoreRecord s $ TokenCron tknId cronInt
|
||||
|
||||
logDeleteToken :: StoreLog 'WriteMode -> NtfTokenId -> IO ()
|
||||
logDeleteToken s tknId = logNtfStoreRecord s $ DeleteToken tknId
|
||||
|
||||
logCreateSubscription :: StoreLog 'WriteMode -> NtfSubData -> IO ()
|
||||
logCreateSubscription s sub = logNtfStoreRecord s . CreateSubscription =<< atomically (mkSubRec sub)
|
||||
|
||||
logSubscriptionStatus :: StoreLog 'WriteMode -> NtfSubscriptionId -> NtfSubStatus -> IO ()
|
||||
logSubscriptionStatus s subId subStatus = logNtfStoreRecord s $ SubscriptionStatus subId subStatus
|
||||
|
||||
logDeleteSubscription :: StoreLog 'WriteMode -> NtfSubscriptionId -> IO ()
|
||||
logDeleteSubscription s subId = logNtfStoreRecord s $ DeleteSubscription subId
|
||||
|
||||
readWriteNtfStore :: FilePath -> NtfStore -> IO (StoreLog 'WriteMode)
|
||||
readWriteNtfStore f st = do
|
||||
whenM (doesFileExist f) $ do
|
||||
readNtfStore f st
|
||||
renameFile f $ f <> ".bak"
|
||||
s <- openWriteStoreLog f
|
||||
writeNtfStore s st
|
||||
pure s
|
||||
|
||||
readNtfStore :: FilePath -> NtfStore -> IO ()
|
||||
readNtfStore f st = mapM_ addNtfLogRecord . B.lines =<< B.readFile f
|
||||
where
|
||||
addNtfLogRecord s = case strDecode s of
|
||||
Left e -> B.putStrLn $ "Log parsing error (" <> B.pack e <> "): " <> B.take 100 s
|
||||
Right lr -> atomically $ case lr of
|
||||
CreateToken r@NtfTknRec {ntfTknId} -> do
|
||||
tkn <- mkTknData r
|
||||
addNtfToken st ntfTknId tkn
|
||||
TokenStatus tknId status ->
|
||||
getNtfToken st tknId
|
||||
>>= mapM_ (\NtfTknData {tknStatus} -> writeTVar tknStatus status)
|
||||
UpdateToken tknId token' tknRegCode ->
|
||||
getNtfToken st tknId
|
||||
>>= mapM_
|
||||
( \tkn@NtfTknData {tknStatus} -> do
|
||||
removeTokenRegistration st tkn
|
||||
writeTVar tknStatus NTRegistered
|
||||
addNtfToken st tknId tkn {token = token', tknRegCode}
|
||||
)
|
||||
TokenCron tknId cronInt ->
|
||||
getNtfToken st tknId
|
||||
>>= mapM_ (\NtfTknData {tknCronInterval} -> writeTVar tknCronInterval cronInt)
|
||||
DeleteToken tknId ->
|
||||
void $ deleteNtfToken st tknId
|
||||
CreateSubscription r@NtfSubRec {ntfSubId} -> do
|
||||
sub <- mkSubData r
|
||||
void $ addNtfSubscription st ntfSubId sub
|
||||
SubscriptionStatus subId status ->
|
||||
getNtfSubscription st subId
|
||||
>>= mapM_ (\NtfSubData {subStatus} -> writeTVar subStatus status)
|
||||
DeleteSubscription subId ->
|
||||
deleteNtfSubscription st subId
|
||||
|
||||
writeNtfStore :: StoreLog 'WriteMode -> NtfStore -> IO ()
|
||||
writeNtfStore s NtfStore {tokens, subscriptions} = do
|
||||
atomically (readTVar tokens >>= mapM mkTknRec)
|
||||
>>= mapM_ (writeStoreLogRecord s . CreateToken)
|
||||
atomically (readTVar subscriptions >>= mapM mkSubRec)
|
||||
>>= mapM_ (writeStoreLogRecord s . CreateSubscription)
|
|
@ -0,0 +1,72 @@
|
|||
{-# LANGUAGE DuplicateRecordFields #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
|
||||
module Simplex.Messaging.Notifications.Transport where
|
||||
|
||||
import Control.Monad.Except
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Encoding
|
||||
import Simplex.Messaging.Transport
|
||||
import Simplex.Messaging.Version
|
||||
|
||||
ntfBlockSize :: Int
|
||||
ntfBlockSize = 512
|
||||
|
||||
supportedNTFServerVRange :: VersionRange
|
||||
supportedNTFServerVRange = mkVersionRange 1 1
|
||||
|
||||
data NtfServerHandshake = NtfServerHandshake
|
||||
{ ntfVersionRange :: VersionRange,
|
||||
sessionId :: SessionId
|
||||
}
|
||||
|
||||
data NtfClientHandshake = NtfClientHandshake
|
||||
{ -- | agreed SMP notifications server protocol version
|
||||
ntfVersion :: Version,
|
||||
-- | server identity - CA certificate fingerprint
|
||||
keyHash :: C.KeyHash
|
||||
}
|
||||
|
||||
instance Encoding NtfServerHandshake where
|
||||
smpEncode NtfServerHandshake {ntfVersionRange, sessionId} =
|
||||
smpEncode (ntfVersionRange, sessionId)
|
||||
smpP = do
|
||||
(ntfVersionRange, sessionId) <- smpP
|
||||
pure NtfServerHandshake {ntfVersionRange, sessionId}
|
||||
|
||||
instance Encoding NtfClientHandshake where
|
||||
smpEncode NtfClientHandshake {ntfVersion, keyHash} = smpEncode (ntfVersion, keyHash)
|
||||
smpP = do
|
||||
(ntfVersion, keyHash) <- smpP
|
||||
pure NtfClientHandshake {ntfVersion, keyHash}
|
||||
|
||||
-- | Notifcations server transport handshake.
|
||||
ntfServerHandshake :: forall c. Transport c => c -> C.KeyHash -> VersionRange -> ExceptT TransportError IO (THandle c)
|
||||
ntfServerHandshake c kh ntfVRange = do
|
||||
let th@THandle {sessionId} = ntfTHandle c
|
||||
sendHandshake th $ NtfServerHandshake {sessionId, ntfVersionRange = ntfVRange}
|
||||
getHandshake th >>= \case
|
||||
NtfClientHandshake {ntfVersion, keyHash}
|
||||
| keyHash /= kh ->
|
||||
throwError $ TEHandshake IDENTITY
|
||||
| ntfVersion `isCompatible` ntfVRange -> do
|
||||
pure (th :: THandle c) {thVersion = ntfVersion}
|
||||
| otherwise -> throwError $ TEHandshake VERSION
|
||||
|
||||
-- | Notifcations server client transport handshake.
|
||||
ntfClientHandshake :: forall c. Transport c => c -> C.KeyHash -> VersionRange -> ExceptT TransportError IO (THandle c)
|
||||
ntfClientHandshake c keyHash ntfVRange = do
|
||||
let th@THandle {sessionId} = ntfTHandle c
|
||||
NtfServerHandshake {sessionId = sessId, ntfVersionRange} <- getHandshake th
|
||||
if sessionId /= sessId
|
||||
then throwError TEBadSession
|
||||
else case ntfVersionRange `compatibleVersion` ntfVRange of
|
||||
Just (Compatible ntfVersion) -> do
|
||||
sendHandshake th $ NtfClientHandshake {ntfVersion, keyHash}
|
||||
pure (th :: THandle c) {thVersion = ntfVersion}
|
||||
Nothing -> throwError $ TEHandshake VERSION
|
||||
|
||||
ntfTHandle :: Transport c => c -> THandle c
|
||||
ntfTHandle c = THandle {connection = c, sessionId = tlsUnique c, blockSize = ntfBlockSize, thVersion = 0}
|
|
@ -0,0 +1,194 @@
|
|||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE DuplicateRecordFields #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE PatternSynonyms #-}
|
||||
|
||||
module Simplex.Messaging.Notifications.Types where
|
||||
|
||||
import qualified Data.Attoparsec.ByteString.Char8 as A
|
||||
import Data.Text.Encoding (decodeLatin1, encodeUtf8)
|
||||
import Data.Time (UTCTime)
|
||||
import Database.SQLite.Simple.FromField (FromField (..))
|
||||
import Database.SQLite.Simple.ToField (ToField (..))
|
||||
import Simplex.Messaging.Agent.Protocol (ConnId, NotificationsMode (..))
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Encoding
|
||||
import Simplex.Messaging.Notifications.Protocol
|
||||
import Simplex.Messaging.Parsers (blobFieldDecoder, fromTextField_)
|
||||
import Simplex.Messaging.Protocol (NotifierId, NtfServer, SMPServer)
|
||||
|
||||
data NtfTknAction
|
||||
= NTARegister
|
||||
| NTAVerify NtfRegCode -- code to verify token
|
||||
| NTACheck
|
||||
| NTADelete
|
||||
deriving (Show)
|
||||
|
||||
instance Encoding NtfTknAction where
|
||||
smpEncode = \case
|
||||
NTARegister -> "R"
|
||||
NTAVerify code -> smpEncode ('V', code)
|
||||
NTACheck -> "C"
|
||||
NTADelete -> "D"
|
||||
smpP =
|
||||
A.anyChar >>= \case
|
||||
'R' -> pure NTARegister
|
||||
'V' -> NTAVerify <$> smpP
|
||||
'C' -> pure NTACheck
|
||||
'D' -> pure NTADelete
|
||||
_ -> fail "bad NtfTknAction"
|
||||
|
||||
instance FromField NtfTknAction where fromField = blobFieldDecoder smpDecode
|
||||
|
||||
instance ToField NtfTknAction where toField = toField . smpEncode
|
||||
|
||||
data NtfToken = NtfToken
|
||||
{ deviceToken :: DeviceToken,
|
||||
ntfServer :: NtfServer,
|
||||
ntfTokenId :: Maybe NtfTokenId,
|
||||
-- | key used by the ntf server to verify transmissions
|
||||
ntfPubKey :: C.APublicVerifyKey,
|
||||
-- | key used by the ntf client to sign transmissions
|
||||
ntfPrivKey :: C.APrivateSignKey,
|
||||
-- | client's DH keys (to repeat registration if necessary)
|
||||
ntfDhKeys :: C.KeyPair 'C.X25519,
|
||||
-- | shared DH secret used to encrypt/decrypt notifications e2e
|
||||
ntfDhSecret :: Maybe C.DhSecretX25519,
|
||||
-- | token status
|
||||
ntfTknStatus :: NtfTknStatus,
|
||||
-- | pending token action and the earliest time
|
||||
ntfTknAction :: Maybe NtfTknAction,
|
||||
ntfMode :: NotificationsMode
|
||||
}
|
||||
deriving (Show)
|
||||
|
||||
newNtfToken :: DeviceToken -> NtfServer -> C.ASignatureKeyPair -> C.KeyPair 'C.X25519 -> NotificationsMode -> NtfToken
|
||||
newNtfToken deviceToken ntfServer (ntfPubKey, ntfPrivKey) ntfDhKeys ntfMode =
|
||||
NtfToken
|
||||
{ deviceToken,
|
||||
ntfServer,
|
||||
ntfTokenId = Nothing,
|
||||
ntfPubKey,
|
||||
ntfPrivKey,
|
||||
ntfDhKeys,
|
||||
ntfDhSecret = Nothing,
|
||||
ntfTknStatus = NTNew,
|
||||
ntfTknAction = Just NTARegister,
|
||||
ntfMode
|
||||
}
|
||||
|
||||
data NtfSubAction = NtfSubNTFAction NtfSubNTFAction | NtfSubSMPAction NtfSubSMPAction
|
||||
deriving (Show)
|
||||
|
||||
isDeleteNtfSubAction :: NtfSubAction -> Bool
|
||||
isDeleteNtfSubAction = \case
|
||||
NtfSubNTFAction a -> case a of
|
||||
NSACreate -> False
|
||||
NSACheck -> False
|
||||
NSADelete -> True
|
||||
NtfSubSMPAction a -> case a of
|
||||
NSASmpKey -> False
|
||||
NSASmpDelete -> True
|
||||
|
||||
type NtfActionTs = UTCTime
|
||||
|
||||
data NtfSubNTFAction
|
||||
= NSACreate
|
||||
| NSACheck
|
||||
| NSADelete
|
||||
deriving (Show)
|
||||
|
||||
instance Encoding NtfSubNTFAction where
|
||||
smpEncode = \case
|
||||
NSACreate -> "N"
|
||||
NSACheck -> "C"
|
||||
NSADelete -> "D"
|
||||
smpP =
|
||||
A.anyChar >>= \case
|
||||
'N' -> pure NSACreate
|
||||
'C' -> pure NSACheck
|
||||
'D' -> pure NSADelete
|
||||
_ -> fail "bad NtfSubNTFAction"
|
||||
|
||||
instance FromField NtfSubNTFAction where fromField = blobFieldDecoder smpDecode
|
||||
|
||||
instance ToField NtfSubNTFAction where toField = toField . smpEncode
|
||||
|
||||
data NtfSubSMPAction
|
||||
= NSASmpKey
|
||||
| NSASmpDelete
|
||||
deriving (Show)
|
||||
|
||||
instance Encoding NtfSubSMPAction where
|
||||
smpEncode = \case
|
||||
NSASmpKey -> "K"
|
||||
NSASmpDelete -> "D"
|
||||
smpP =
|
||||
A.anyChar >>= \case
|
||||
'K' -> pure NSASmpKey
|
||||
'D' -> pure NSASmpDelete
|
||||
_ -> fail "bad NtfSubSMPAction"
|
||||
|
||||
instance FromField NtfSubSMPAction where fromField = blobFieldDecoder smpDecode
|
||||
|
||||
instance ToField NtfSubSMPAction where toField = toField . smpEncode
|
||||
|
||||
data NtfAgentSubStatus
|
||||
= -- | subscription started
|
||||
NASNew
|
||||
| -- | state after NKEY - notifier ID is assigned to queue on SMP server
|
||||
NASKey
|
||||
| -- | state after SNEW - subscription created on notification server
|
||||
NASCreated NtfSubStatus
|
||||
| -- | state after SDEL (subscription is deleted on notification server)
|
||||
NASOff
|
||||
| -- | state after NDEL (notifier credentials are deleted on SMP server)
|
||||
-- Can only exist transiently - if subscription record was updated by notification supervisor mid worker operation,
|
||||
-- and hence got updated instead of being fully deleted in the database post operation by worker
|
||||
NASDeleted
|
||||
deriving (Eq, Show)
|
||||
|
||||
instance Encoding NtfAgentSubStatus where
|
||||
smpEncode = \case
|
||||
NASNew -> "NEW"
|
||||
NASKey -> "KEY"
|
||||
NASCreated status -> "CREATED " <> smpEncode status
|
||||
NASOff -> "OFF"
|
||||
NASDeleted -> "DELETED"
|
||||
smpP =
|
||||
A.takeTill (== ' ') >>= \case
|
||||
"NEW" -> pure NASNew
|
||||
"KEY" -> pure NASKey
|
||||
"CREATED" -> do
|
||||
_ <- A.space
|
||||
NASCreated <$> smpP
|
||||
"OFF" -> pure NASOff
|
||||
"DELETED" -> pure NASDeleted
|
||||
_ -> fail "bad NtfAgentSubStatus"
|
||||
|
||||
instance FromField NtfAgentSubStatus where fromField = fromTextField_ $ either (const Nothing) Just . smpDecode . encodeUtf8
|
||||
|
||||
instance ToField NtfAgentSubStatus where toField = toField . decodeLatin1 . smpEncode
|
||||
|
||||
data NtfSubscription = NtfSubscription
|
||||
{ connId :: ConnId,
|
||||
smpServer :: SMPServer,
|
||||
ntfQueueId :: Maybe NotifierId,
|
||||
ntfServer :: NtfServer,
|
||||
ntfSubId :: Maybe NtfSubscriptionId,
|
||||
ntfSubStatus :: NtfAgentSubStatus
|
||||
}
|
||||
deriving (Show)
|
||||
|
||||
newNtfSubscription :: ConnId -> SMPServer -> Maybe NotifierId -> NtfServer -> NtfAgentSubStatus -> NtfSubscription
|
||||
newNtfSubscription connId smpServer ntfQueueId ntfServer ntfSubStatus =
|
||||
NtfSubscription
|
||||
{ connId,
|
||||
smpServer,
|
||||
ntfQueueId,
|
||||
ntfServer,
|
||||
ntfSubId = Nothing,
|
||||
ntfSubStatus
|
||||
}
|
|
@ -13,6 +13,8 @@ import Data.ByteString.Base64
|
|||
import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import Data.Char (isAlphaNum, toLower)
|
||||
import Data.Text (Text)
|
||||
import qualified Data.Text as T
|
||||
import Data.Time.Clock (UTCTime)
|
||||
import Data.Time.ISO8601 (parseISO8601)
|
||||
import Data.Typeable (Typeable)
|
||||
|
@ -81,6 +83,14 @@ blobFieldDecoder dec = \case
|
|||
Left e -> returnError ConversionFailed f ("couldn't parse field: " ++ e)
|
||||
f -> returnError ConversionFailed f "expecting SQLBlob column type"
|
||||
|
||||
fromTextField_ :: (Typeable a) => (Text -> Maybe a) -> Field -> Ok a
|
||||
fromTextField_ fromText = \case
|
||||
f@(Field (SQLText t) _) ->
|
||||
case fromText t of
|
||||
Just x -> Ok x
|
||||
_ -> returnError ConversionFailed f ("invalid text: " <> T.unpack t)
|
||||
f -> returnError ConversionFailed f "expecting SQLText column type"
|
||||
|
||||
fstToLower :: String -> String
|
||||
fstToLower "" = ""
|
||||
fstToLower (h : t) = toLower h : t
|
||||
|
|
|
@ -1,23 +1,30 @@
|
|||
{-# LANGUAGE AllowAmbiguousTypes #-}
|
||||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE DeriveGeneric #-}
|
||||
{-# LANGUAGE DuplicateRecordFields #-}
|
||||
{-# LANGUAGE FlexibleContexts #-}
|
||||
{-# LANGUAGE FlexibleInstances #-}
|
||||
{-# LANGUAGE GADTs #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE PatternSynonyms #-}
|
||||
{-# LANGUAGE PolyKinds #-}
|
||||
{-# LANGUAGE RankNTypes #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-# LANGUAGE StandaloneDeriving #-}
|
||||
{-# LANGUAGE TupleSections #-}
|
||||
{-# LANGUAGE TypeApplications #-}
|
||||
{-# LANGUAGE TypeFamilies #-}
|
||||
{-# LANGUAGE TypeFamilyDependencies #-}
|
||||
{-# LANGUAGE TypeOperators #-}
|
||||
{-# LANGUAGE UndecidableInstances #-}
|
||||
{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-}
|
||||
{-# OPTIONS_GHC -fno-warn-unticked-promoted-constructors #-}
|
||||
|
||||
{-# HLINT ignore "Use newtype instead of data" #-}
|
||||
|
||||
-- |
|
||||
-- Module : Simplex.Messaging.Protocol
|
||||
-- Module : Simplex.Messaging.ProtocolEncoding
|
||||
-- Copyright : (c) simplex.chat
|
||||
-- License : AGPL-3
|
||||
--
|
||||
|
@ -37,7 +44,7 @@ module Simplex.Messaging.Protocol
|
|||
e2eEncMessageLength,
|
||||
|
||||
-- * SMP protocol types
|
||||
Protocol,
|
||||
ProtocolEncoding (..),
|
||||
Command (..),
|
||||
Party (..),
|
||||
Cmd (..),
|
||||
|
@ -55,7 +62,14 @@ module Simplex.Messaging.Protocol
|
|||
PubHeader (..),
|
||||
ClientMessage (..),
|
||||
PrivHeader (..),
|
||||
SMPServer (..),
|
||||
Protocol (..),
|
||||
ProtocolType (..),
|
||||
ProtocolServer (..),
|
||||
ProtoServer,
|
||||
SMPServer,
|
||||
pattern SMPServer,
|
||||
NtfServer,
|
||||
pattern NtfServer,
|
||||
SrvLoc (..),
|
||||
CorrId (..),
|
||||
QueueId,
|
||||
|
@ -70,13 +84,32 @@ module Simplex.Messaging.Protocol
|
|||
SndPublicVerifyKey,
|
||||
NtfPrivateSignKey,
|
||||
NtfPublicVerifyKey,
|
||||
RcvNtfPublicDhKey,
|
||||
RcvNtfDhSecret,
|
||||
Message (..),
|
||||
RcvMessage (..),
|
||||
MsgId,
|
||||
MsgBody,
|
||||
MaxMessageLen,
|
||||
MaxRcvMessageLen,
|
||||
EncRcvMsgBody (..),
|
||||
RcvMsgBody (..),
|
||||
ClientRcvMsgBody (..),
|
||||
EncNMsgMeta,
|
||||
SMPMsgMeta (..),
|
||||
NMsgMeta (..),
|
||||
MsgFlags (..),
|
||||
rcvMessageMeta,
|
||||
noMsgFlags,
|
||||
|
||||
-- * Parse and serialize
|
||||
ProtocolMsgTag (..),
|
||||
messageTagP,
|
||||
encodeTransmission,
|
||||
transmissionP,
|
||||
encodeProtocol,
|
||||
_smpP,
|
||||
encodeRcvMsgBody,
|
||||
clientRcvMsgBodyP,
|
||||
|
||||
-- * TCP transport functions
|
||||
tPut,
|
||||
|
@ -99,16 +132,17 @@ import qualified Data.ByteString.Char8 as B
|
|||
import Data.Kind
|
||||
import Data.Maybe (isNothing)
|
||||
import Data.String
|
||||
import Data.Time.Clock.System (SystemTime)
|
||||
import Data.Time.Clock.System (SystemTime (..))
|
||||
import Data.Type.Equality
|
||||
import GHC.Generics (Generic)
|
||||
import GHC.TypeLits (type (+))
|
||||
import Generic.Random (genericArbitraryU)
|
||||
import Network.Socket (HostName, ServiceName)
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Encoding
|
||||
import Simplex.Messaging.Encoding.String
|
||||
import Simplex.Messaging.Parsers
|
||||
import Simplex.Messaging.Transport (SessionId, THandle (..), Transport, TransportError (..), tGetBlock, tPutBlock)
|
||||
import Simplex.Messaging.Transport
|
||||
import Simplex.Messaging.Util (bshow, (<$?>))
|
||||
import Simplex.Messaging.Version
|
||||
import Test.QuickCheck (Arbitrary (..))
|
||||
|
@ -122,6 +156,11 @@ smpClientVRange = mkVersionRange 1 smpClientVersion
|
|||
maxMessageLength :: Int
|
||||
maxMessageLength = 16088
|
||||
|
||||
type MaxMessageLen = 16088
|
||||
|
||||
-- 16 extra bytes: 8 for timestamp and 8 for flags (7 flags and the space, only 1 flag is currently used)
|
||||
type MaxRcvMessageLen = MaxMessageLen + 16 -- 16104, the padded size is 16106
|
||||
|
||||
-- it is shorter to allow per-queue e2e encryption DH key in the "public" header
|
||||
e2eEncConfirmationLength :: Int
|
||||
e2eEncConfirmationLength = 15936
|
||||
|
@ -161,7 +200,7 @@ data Cmd = forall p. PartyI p => Cmd (SParty p) (Command p)
|
|||
deriving instance Show Cmd
|
||||
|
||||
-- | Parsed SMP transmission without signature, size and session ID.
|
||||
type Transmission c = (CorrId, QueueId, c)
|
||||
type Transmission c = (CorrId, EntityId, c)
|
||||
|
||||
-- | signed parsed transmission, with original raw bytes and parsing error.
|
||||
type SignedTransmission c = (Maybe C.ASignature, Signed, Transmission (Either ErrorType c))
|
||||
|
@ -174,9 +213,10 @@ data RawTransmission = RawTransmission
|
|||
signed :: ByteString,
|
||||
sessId :: SessionId,
|
||||
corrId :: ByteString,
|
||||
queueId :: ByteString,
|
||||
entityId :: ByteString,
|
||||
command :: ByteString
|
||||
}
|
||||
deriving (Show)
|
||||
|
||||
-- | unparsed sent SMP transmission with signature, without session ID.
|
||||
type SignedRawTransmission = (Maybe C.ASignature, SessionId, ByteString, ByteString)
|
||||
|
@ -194,7 +234,9 @@ type SenderId = QueueId
|
|||
type NotifierId = QueueId
|
||||
|
||||
-- | SMP queue ID on the server.
|
||||
type QueueId = ByteString
|
||||
type QueueId = EntityId
|
||||
|
||||
type EntityId = ByteString
|
||||
|
||||
-- | Parameterized type for SMP protocol commands from all clients.
|
||||
data Command (p :: Party) where
|
||||
|
@ -202,12 +244,18 @@ data Command (p :: Party) where
|
|||
NEW :: RcvPublicVerifyKey -> RcvPublicDhKey -> Command Recipient
|
||||
SUB :: Command Recipient
|
||||
KEY :: SndPublicVerifyKey -> Command Recipient
|
||||
NKEY :: NtfPublicVerifyKey -> Command Recipient
|
||||
ACK :: Command Recipient
|
||||
NKEY :: NtfPublicVerifyKey -> RcvNtfPublicDhKey -> Command Recipient
|
||||
NDEL :: Command Recipient
|
||||
GET :: Command Recipient
|
||||
-- ACK v1 has to be supported for encoding/decoding
|
||||
-- ACK :: Command Recipient
|
||||
ACK :: MsgId -> Command Recipient
|
||||
OFF :: Command Recipient
|
||||
DEL :: Command Recipient
|
||||
-- SMP sender commands
|
||||
SEND :: MsgBody -> Command Sender
|
||||
-- SEND v1 has to be supported for encoding/decoding
|
||||
-- SEND :: MsgBody -> Command Sender
|
||||
SEND :: MsgFlags -> MsgBody -> Command Sender
|
||||
PING :: Command Sender
|
||||
-- SMP notification subscriber commands
|
||||
NSUB :: Command Notifier
|
||||
|
@ -219,15 +267,139 @@ deriving instance Eq (Command p)
|
|||
data BrokerMsg where
|
||||
-- SMP broker messages (responses, client messages, notifications)
|
||||
IDS :: QueueIdsKeys -> BrokerMsg
|
||||
MSG :: MsgId -> SystemTime -> MsgBody -> BrokerMsg
|
||||
NID :: NotifierId -> BrokerMsg
|
||||
NMSG :: BrokerMsg
|
||||
-- MSG v1/2 has to be supported for encoding/decoding
|
||||
-- v1: MSG :: MsgId -> SystemTime -> MsgBody -> BrokerMsg
|
||||
-- v2: MsgId -> SystemTime -> MsgFlags -> MsgBody -> BrokerMsg
|
||||
MSG :: RcvMessage -> BrokerMsg
|
||||
NID :: NotifierId -> RcvNtfPublicDhKey -> BrokerMsg
|
||||
NMSG :: C.CbNonce -> EncNMsgMeta -> BrokerMsg
|
||||
END :: BrokerMsg
|
||||
OK :: BrokerMsg
|
||||
ERR :: ErrorType -> BrokerMsg
|
||||
PONG :: BrokerMsg
|
||||
deriving (Eq, Show)
|
||||
|
||||
data RcvMessage = RcvMessage
|
||||
{ msgId :: MsgId,
|
||||
msgTs :: SystemTime,
|
||||
msgFlags :: MsgFlags,
|
||||
msgBody :: EncRcvMsgBody -- e2e encrypted, with extra encryption for recipient
|
||||
}
|
||||
deriving (Eq, Show)
|
||||
|
||||
-- | received message without server/recipient encryption
|
||||
data Message = Message
|
||||
{ msgId :: MsgId,
|
||||
msgTs :: SystemTime,
|
||||
msgFlags :: MsgFlags,
|
||||
msgBody :: C.MaxLenBS MaxMessageLen
|
||||
}
|
||||
|
||||
instance StrEncoding RcvMessage where
|
||||
strEncode RcvMessage {msgId, msgTs, msgFlags, msgBody = EncRcvMsgBody body} =
|
||||
B.unwords
|
||||
[ strEncode msgId,
|
||||
strEncode msgTs,
|
||||
"flags=" <> strEncode msgFlags,
|
||||
strEncode body
|
||||
]
|
||||
strP = do
|
||||
msgId <- strP_
|
||||
msgTs <- strP_
|
||||
msgFlags <- ("flags=" *> strP_) <|> pure noMsgFlags
|
||||
msgBody <- EncRcvMsgBody <$> strP
|
||||
pure RcvMessage {msgId, msgTs, msgFlags, msgBody}
|
||||
|
||||
newtype EncRcvMsgBody = EncRcvMsgBody ByteString
|
||||
deriving (Eq, Show)
|
||||
|
||||
data RcvMsgBody = RcvMsgBody
|
||||
{ msgTs :: SystemTime,
|
||||
msgFlags :: MsgFlags,
|
||||
msgBody :: C.MaxLenBS MaxMessageLen
|
||||
}
|
||||
|
||||
encodeRcvMsgBody :: RcvMsgBody -> C.MaxLenBS MaxRcvMessageLen
|
||||
encodeRcvMsgBody RcvMsgBody {msgTs, msgFlags, msgBody} =
|
||||
let rcvMeta :: C.MaxLenBS 16 = C.unsafeMaxLenBS $ smpEncode (msgTs, msgFlags, ' ')
|
||||
in C.appendMaxLenBS rcvMeta msgBody
|
||||
|
||||
data ClientRcvMsgBody = ClientRcvMsgBody
|
||||
{ msgTs :: SystemTime,
|
||||
msgFlags :: MsgFlags,
|
||||
msgBody :: ByteString
|
||||
}
|
||||
|
||||
clientRcvMsgBodyP :: Parser ClientRcvMsgBody
|
||||
clientRcvMsgBodyP = do
|
||||
msgTs <- smpP
|
||||
msgFlags <- smpP
|
||||
Tail msgBody <- _smpP
|
||||
pure ClientRcvMsgBody {msgTs, msgFlags, msgBody}
|
||||
|
||||
instance StrEncoding Message where
|
||||
strEncode Message {msgId, msgTs, msgFlags, msgBody} =
|
||||
B.unwords
|
||||
[ strEncode msgId,
|
||||
strEncode msgTs,
|
||||
"flags=" <> strEncode msgFlags,
|
||||
strEncode msgBody
|
||||
]
|
||||
strP = do
|
||||
msgId <- strP_
|
||||
msgTs <- strP_
|
||||
msgFlags <- ("flags=" *> strP_) <|> pure noMsgFlags
|
||||
msgBody <- strP
|
||||
pure Message {msgId, msgTs, msgFlags, msgBody}
|
||||
|
||||
type EncNMsgMeta = ByteString
|
||||
|
||||
data SMPMsgMeta = SMPMsgMeta
|
||||
{ msgId :: MsgId,
|
||||
msgTs :: SystemTime,
|
||||
msgFlags :: MsgFlags
|
||||
}
|
||||
deriving (Show)
|
||||
|
||||
rcvMessageMeta :: MsgId -> ClientRcvMsgBody -> SMPMsgMeta
|
||||
rcvMessageMeta msgId ClientRcvMsgBody {msgTs, msgFlags} = SMPMsgMeta {msgId, msgTs, msgFlags}
|
||||
|
||||
data NMsgMeta = NMsgMeta
|
||||
{ msgId :: MsgId,
|
||||
msgTs :: SystemTime
|
||||
}
|
||||
deriving (Show)
|
||||
|
||||
instance Encoding NMsgMeta where
|
||||
smpEncode NMsgMeta {msgId, msgTs} =
|
||||
smpEncode (msgId, msgTs)
|
||||
smpP = do
|
||||
-- Tail here is to allow extension in the future clients/servers
|
||||
(msgId, msgTs, Tail _) <- smpP
|
||||
pure NMsgMeta {msgId, msgTs}
|
||||
|
||||
-- it must be data for correct JSON encoding
|
||||
data MsgFlags = MsgFlags {notification :: Bool}
|
||||
deriving (Eq, Show, Generic)
|
||||
|
||||
instance ToJSON MsgFlags where toEncoding = J.genericToEncoding J.defaultOptions
|
||||
|
||||
-- this encoding should not become bigger than 7 bytes (currently it is 1 byte)
|
||||
instance Encoding MsgFlags where
|
||||
smpEncode MsgFlags {notification} = smpEncode notification
|
||||
smpP = do
|
||||
notification <- smpP <* A.takeTill (== ' ')
|
||||
pure MsgFlags {notification}
|
||||
|
||||
instance StrEncoding MsgFlags where
|
||||
strEncode = smpEncode
|
||||
{-# INLINE strEncode #-}
|
||||
strP = smpP
|
||||
{-# INLINE strP #-}
|
||||
|
||||
noMsgFlags :: MsgFlags
|
||||
noMsgFlags = MsgFlags {notification = False}
|
||||
|
||||
-- * SMP command tags
|
||||
|
||||
data CommandTag (p :: Party) where
|
||||
|
@ -235,6 +407,8 @@ data CommandTag (p :: Party) where
|
|||
SUB_ :: CommandTag Recipient
|
||||
KEY_ :: CommandTag Recipient
|
||||
NKEY_ :: CommandTag Recipient
|
||||
NDEL_ :: CommandTag Recipient
|
||||
GET_ :: CommandTag Recipient
|
||||
ACK_ :: CommandTag Recipient
|
||||
OFF_ :: CommandTag Recipient
|
||||
DEL_ :: CommandTag Recipient
|
||||
|
@ -264,7 +438,7 @@ class ProtocolMsgTag t where
|
|||
|
||||
messageTagP :: ProtocolMsgTag t => Parser t
|
||||
messageTagP =
|
||||
maybe (fail "bad command") pure . decodeTag
|
||||
maybe (fail "bad message") pure . decodeTag
|
||||
=<< (A.takeTill (== ' ') <* optional A.space)
|
||||
|
||||
instance PartyI p => Encoding (CommandTag p) where
|
||||
|
@ -273,6 +447,8 @@ instance PartyI p => Encoding (CommandTag p) where
|
|||
SUB_ -> "SUB"
|
||||
KEY_ -> "KEY"
|
||||
NKEY_ -> "NKEY"
|
||||
NDEL_ -> "NDEL"
|
||||
GET_ -> "GET"
|
||||
ACK_ -> "ACK"
|
||||
OFF_ -> "OFF"
|
||||
DEL_ -> "DEL"
|
||||
|
@ -287,6 +463,8 @@ instance ProtocolMsgTag CmdTag where
|
|||
"SUB" -> Just $ CT SRecipient SUB_
|
||||
"KEY" -> Just $ CT SRecipient KEY_
|
||||
"NKEY" -> Just $ CT SRecipient NKEY_
|
||||
"NDEL" -> Just $ CT SRecipient NDEL_
|
||||
"GET" -> Just $ CT SRecipient GET_
|
||||
"ACK" -> Just $ CT SRecipient ACK_
|
||||
"OFF" -> Just $ CT SRecipient OFF_
|
||||
"DEL" -> Just $ CT SRecipient DEL_
|
||||
|
@ -372,34 +550,109 @@ instance Encoding ClientMessage where
|
|||
smpEncode (ClientMessage h msg) = smpEncode h <> msg
|
||||
smpP = ClientMessage <$> smpP <*> A.takeByteString
|
||||
|
||||
-- | SMP server location and transport key digest (hash).
|
||||
data SMPServer = SMPServer
|
||||
{ host :: HostName,
|
||||
type SMPServer = ProtocolServer 'PSMP
|
||||
|
||||
pattern SMPServer :: HostName -> ServiceName -> C.KeyHash -> ProtocolServer 'PSMP
|
||||
pattern SMPServer host port keyHash = ProtocolServer SPSMP host port keyHash
|
||||
|
||||
{-# COMPLETE SMPServer #-}
|
||||
|
||||
type NtfServer = ProtocolServer 'PNTF
|
||||
|
||||
pattern NtfServer :: HostName -> ServiceName -> C.KeyHash -> ProtocolServer 'PNTF
|
||||
pattern NtfServer host port keyHash = ProtocolServer SPNTF host port keyHash
|
||||
|
||||
{-# COMPLETE NtfServer #-}
|
||||
|
||||
data ProtocolType = PSMP | PNTF
|
||||
deriving (Eq, Ord, Show)
|
||||
|
||||
instance StrEncoding ProtocolType where
|
||||
strEncode = \case
|
||||
PSMP -> "smp"
|
||||
PNTF -> "ntf"
|
||||
strP =
|
||||
A.takeTill (== ':') >>= \case
|
||||
"smp" -> pure PSMP
|
||||
"ntf" -> pure PNTF
|
||||
_ -> fail "bad ProtocolType"
|
||||
|
||||
data SProtocolType (p :: ProtocolType) where
|
||||
SPSMP :: SProtocolType 'PSMP
|
||||
SPNTF :: SProtocolType 'PNTF
|
||||
|
||||
deriving instance Eq (SProtocolType p)
|
||||
|
||||
deriving instance Ord (SProtocolType p)
|
||||
|
||||
deriving instance Show (SProtocolType p)
|
||||
|
||||
data AProtocolType = forall p. ProtocolTypeI p => AProtocolType (SProtocolType p)
|
||||
|
||||
instance TestEquality SProtocolType where
|
||||
testEquality SPSMP SPSMP = Just Refl
|
||||
testEquality SPNTF SPNTF = Just Refl
|
||||
testEquality _ _ = Nothing
|
||||
|
||||
protocolType :: SProtocolType p -> ProtocolType
|
||||
protocolType = \case
|
||||
SPSMP -> PSMP
|
||||
SPNTF -> PNTF
|
||||
|
||||
aProtocolType :: ProtocolType -> AProtocolType
|
||||
aProtocolType = \case
|
||||
PSMP -> AProtocolType SPSMP
|
||||
PNTF -> AProtocolType SPNTF
|
||||
|
||||
instance ProtocolTypeI p => StrEncoding (SProtocolType p) where
|
||||
strEncode = strEncode . protocolType
|
||||
strP = (\(AProtocolType p) -> checkProtocolType p) <$?> strP
|
||||
|
||||
instance StrEncoding AProtocolType where
|
||||
strEncode (AProtocolType p) = strEncode p
|
||||
strP = aProtocolType <$> strP
|
||||
|
||||
checkProtocolType :: forall t p p'. (ProtocolTypeI p, ProtocolTypeI p') => t p' -> Either String (t p)
|
||||
checkProtocolType p = case testEquality (protocolTypeI @p) (protocolTypeI @p') of
|
||||
Just Refl -> Right p
|
||||
Nothing -> Left "bad ProtocolType"
|
||||
|
||||
class ProtocolTypeI (p :: ProtocolType) where
|
||||
protocolTypeI :: SProtocolType p
|
||||
|
||||
instance ProtocolTypeI 'PSMP where protocolTypeI = SPSMP
|
||||
|
||||
instance ProtocolTypeI 'PNTF where protocolTypeI = SPNTF
|
||||
|
||||
-- | server location and transport key digest (hash).
|
||||
data ProtocolServer p = ProtocolServer
|
||||
{ scheme :: SProtocolType p,
|
||||
host :: HostName,
|
||||
port :: ServiceName,
|
||||
keyHash :: C.KeyHash
|
||||
}
|
||||
deriving (Eq, Ord, Show)
|
||||
|
||||
instance IsString SMPServer where
|
||||
instance ProtocolTypeI p => IsString (ProtocolServer p) where
|
||||
fromString = parseString strDecode
|
||||
|
||||
instance Encoding SMPServer where
|
||||
smpEncode SMPServer {host, port, keyHash} =
|
||||
instance ProtocolTypeI p => Encoding (ProtocolServer p) where
|
||||
smpEncode ProtocolServer {host, port, keyHash} =
|
||||
smpEncode (host, port, keyHash)
|
||||
smpP = do
|
||||
(host, port, keyHash) <- smpP
|
||||
pure SMPServer {host, port, keyHash}
|
||||
pure ProtocolServer {scheme = protocolTypeI @p, host, port, keyHash}
|
||||
|
||||
instance StrEncoding SMPServer where
|
||||
strEncode SMPServer {host, port, keyHash} =
|
||||
"smp://" <> strEncode keyHash <> "@" <> strEncode (SrvLoc host port)
|
||||
instance ProtocolTypeI p => StrEncoding (ProtocolServer p) where
|
||||
strEncode ProtocolServer {scheme, host, port, keyHash} =
|
||||
strEncode scheme <> "://" <> strEncode keyHash <> "@" <> strEncode (SrvLoc host port)
|
||||
strP = do
|
||||
_ <- "smp://"
|
||||
scheme <- strP <* "://"
|
||||
keyHash <- strP <* A.char '@'
|
||||
SrvLoc host port <- strP
|
||||
pure SMPServer {host, port, keyHash}
|
||||
pure ProtocolServer {scheme, host, port, keyHash}
|
||||
|
||||
instance ToJSON SMPServer where
|
||||
instance ProtocolTypeI p => ToJSON (ProtocolServer p) where
|
||||
toJSON = strToJSON
|
||||
toEncoding = strToJEncoding
|
||||
|
||||
|
@ -458,12 +711,18 @@ type SndPrivateSignKey = C.APrivateSignKey
|
|||
-- | Sender's public key used by SMP server to verify authorization of SMP commands.
|
||||
type SndPublicVerifyKey = C.APublicVerifyKey
|
||||
|
||||
-- | Private key used by push notifications server to authorize (sign) LSTN command.
|
||||
-- | Private key used by push notifications server to authorize (sign) NSUB command.
|
||||
type NtfPrivateSignKey = C.APrivateSignKey
|
||||
|
||||
-- | Public key used by SMP server to verify authorization of LSTN command sent by push notifications server.
|
||||
-- | Public key used by SMP server to verify authorization of NSUB command sent by push notifications server.
|
||||
type NtfPublicVerifyKey = C.APublicVerifyKey
|
||||
|
||||
-- | Public key used for DH exchange to encrypt notification metadata from server to recipient
|
||||
type RcvNtfPublicDhKey = C.PublicKeyX25519
|
||||
|
||||
-- | DH Secret used to encrypt notification metadata from server to recipient
|
||||
type RcvNtfDhSecret = C.DhSecretX25519
|
||||
|
||||
-- | SMP message server ID.
|
||||
type MsgId = ByteString
|
||||
|
||||
|
@ -484,7 +743,7 @@ data ErrorType
|
|||
QUOTA
|
||||
| -- | ACK command is sent without message to be acknowledged
|
||||
NO_MSG
|
||||
| -- | sent message is too large (> maxMessageLength = 16078 bytes)
|
||||
| -- | sent message is too large (> maxMessageLength = 16088 bytes)
|
||||
LARGE_MSG
|
||||
| -- | internal server error
|
||||
INTERNAL
|
||||
|
@ -508,12 +767,14 @@ data CommandError
|
|||
UNKNOWN
|
||||
| -- | error parsing command
|
||||
SYNTAX
|
||||
| -- | command is not allowed (SUB/GET cannot be used with the same queue in the same TCP connection)
|
||||
PROHIBITED
|
||||
| -- | transmission has no required credentials (signature or queue ID)
|
||||
NO_AUTH
|
||||
| -- | transmission has credentials that are not allowed for this command
|
||||
HAS_AUTH
|
||||
| -- | transmission has no required queue ID
|
||||
NO_QUEUE
|
||||
| -- | transmission has no required entity ID (e.g. SMP queue)
|
||||
NO_ENTITY
|
||||
deriving (Eq, Generic, Read, Show)
|
||||
|
||||
instance ToJSON CommandError where
|
||||
|
@ -534,34 +795,58 @@ transmissionP = do
|
|||
trn signature signed = do
|
||||
sessId <- smpP
|
||||
corrId <- smpP
|
||||
queueId <- smpP
|
||||
entityId <- smpP
|
||||
command <- A.takeByteString
|
||||
pure RawTransmission {signature, signed, sessId, corrId, queueId, command}
|
||||
pure RawTransmission {signature, signed, sessId, corrId, entityId, command}
|
||||
|
||||
class Protocol msg where
|
||||
class (ProtocolEncoding msg, ProtocolEncoding (ProtoCommand msg), Show msg) => Protocol msg where
|
||||
type ProtoCommand msg = cmd | cmd -> msg
|
||||
type ProtoType msg = (sch :: ProtocolType) | sch -> msg
|
||||
protocolClientHandshake :: forall c. Transport c => c -> C.KeyHash -> VersionRange -> ExceptT TransportError IO (THandle c)
|
||||
protocolPing :: ProtoCommand msg
|
||||
protocolError :: msg -> Maybe ErrorType
|
||||
|
||||
type ProtoServer msg = ProtocolServer (ProtoType msg)
|
||||
|
||||
instance Protocol BrokerMsg where
|
||||
type ProtoCommand BrokerMsg = Cmd
|
||||
type ProtoType BrokerMsg = 'PSMP
|
||||
protocolClientHandshake = smpClientHandshake
|
||||
protocolPing = Cmd SSender PING
|
||||
protocolError = \case
|
||||
ERR e -> Just e
|
||||
_ -> Nothing
|
||||
|
||||
class ProtocolMsgTag (Tag msg) => ProtocolEncoding msg where
|
||||
type Tag msg
|
||||
encodeProtocol :: msg -> ByteString
|
||||
protocolP :: Tag msg -> Parser msg
|
||||
encodeProtocol :: Version -> msg -> ByteString
|
||||
protocolP :: Version -> Tag msg -> Parser msg
|
||||
checkCredentials :: SignedRawTransmission -> msg -> Either ErrorType msg
|
||||
|
||||
instance PartyI p => Protocol (Command p) where
|
||||
instance PartyI p => ProtocolEncoding (Command p) where
|
||||
type Tag (Command p) = CommandTag p
|
||||
encodeProtocol = \case
|
||||
encodeProtocol v = \case
|
||||
NEW rKey dhKey -> e (NEW_, ' ', rKey, dhKey)
|
||||
SUB -> e SUB_
|
||||
KEY k -> e (KEY_, ' ', k)
|
||||
NKEY k -> e (NKEY_, ' ', k)
|
||||
ACK -> e ACK_
|
||||
NKEY k dhKey -> e (NKEY_, ' ', k, dhKey)
|
||||
NDEL -> e NDEL_
|
||||
GET -> e GET_
|
||||
ACK msgId
|
||||
| v == 1 -> e ACK_
|
||||
| otherwise -> e (ACK_, ' ', msgId)
|
||||
OFF -> e OFF_
|
||||
DEL -> e DEL_
|
||||
SEND msg -> e (SEND_, ' ', Tail msg)
|
||||
SEND flags msg
|
||||
| v == 1 -> e (SEND_, ' ', Tail msg)
|
||||
| otherwise -> e (SEND_, ' ', flags, ' ', Tail msg)
|
||||
PING -> e PING_
|
||||
NSUB -> e NSUB_
|
||||
where
|
||||
e :: Encoding a => a -> ByteString
|
||||
e = smpEncode
|
||||
|
||||
protocolP tag = (\(Cmd _ c) -> checkParty c) <$?> protocolP (CT (sParty @p) tag)
|
||||
protocolP v tag = (\(Cmd _ c) -> checkParty c) <$?> protocolP v (CT (sParty @p) tag)
|
||||
|
||||
checkCredentials (sig, _, queueId, _) cmd = case cmd of
|
||||
-- NEW must have signature but NOT queue ID
|
||||
|
@ -570,8 +855,8 @@ instance PartyI p => Protocol (Command p) where
|
|||
| not (B.null queueId) -> Left $ CMD HAS_AUTH
|
||||
| otherwise -> Right cmd
|
||||
-- SEND must have queue ID, signature is not always required
|
||||
SEND _
|
||||
| B.null queueId -> Left $ CMD NO_QUEUE
|
||||
SEND {}
|
||||
| B.null queueId -> Left $ CMD NO_ENTITY
|
||||
| otherwise -> Right cmd
|
||||
-- PING must not have queue ID or signature
|
||||
PING
|
||||
|
@ -582,35 +867,44 @@ instance PartyI p => Protocol (Command p) where
|
|||
| isNothing sig || B.null queueId -> Left $ CMD NO_AUTH
|
||||
| otherwise -> Right cmd
|
||||
|
||||
instance Protocol Cmd where
|
||||
instance ProtocolEncoding Cmd where
|
||||
type Tag Cmd = CmdTag
|
||||
encodeProtocol (Cmd _ c) = encodeProtocol c
|
||||
encodeProtocol v (Cmd _ c) = encodeProtocol v c
|
||||
|
||||
protocolP = \case
|
||||
protocolP v = \case
|
||||
CT SRecipient tag ->
|
||||
Cmd SRecipient <$> case tag of
|
||||
NEW_ -> NEW <$> _smpP <*> smpP
|
||||
SUB_ -> pure SUB
|
||||
KEY_ -> KEY <$> _smpP
|
||||
NKEY_ -> NKEY <$> _smpP
|
||||
ACK_ -> pure ACK
|
||||
NKEY_ -> NKEY <$> _smpP <*> smpP
|
||||
NDEL_ -> pure NDEL
|
||||
GET_ -> pure GET
|
||||
ACK_
|
||||
| v == 1 -> pure $ ACK ""
|
||||
| otherwise -> ACK <$> _smpP
|
||||
OFF_ -> pure OFF
|
||||
DEL_ -> pure DEL
|
||||
CT SSender tag ->
|
||||
Cmd SSender <$> case tag of
|
||||
SEND_ -> SEND . unTail <$> _smpP
|
||||
SEND_
|
||||
| v == 1 -> SEND <$> pure noMsgFlags <*> (unTail <$> _smpP)
|
||||
| otherwise -> SEND <$> _smpP <*> (unTail <$> _smpP)
|
||||
PING_ -> pure PING
|
||||
CT SNotifier NSUB_ -> pure $ Cmd SNotifier NSUB
|
||||
|
||||
checkCredentials t (Cmd p c) = Cmd p <$> checkCredentials t c
|
||||
|
||||
instance Protocol BrokerMsg where
|
||||
instance ProtocolEncoding BrokerMsg where
|
||||
type Tag BrokerMsg = BrokerMsgTag
|
||||
encodeProtocol = \case
|
||||
encodeProtocol v = \case
|
||||
IDS (QIK rcvId sndId srvDh) -> e (IDS_, ' ', rcvId, sndId, srvDh)
|
||||
MSG msgId ts msgBody -> e (MSG_, ' ', msgId, ts, Tail msgBody)
|
||||
NID nId -> e (NID_, ' ', nId)
|
||||
NMSG -> e NMSG_
|
||||
MSG RcvMessage {msgId, msgTs, msgFlags, msgBody = EncRcvMsgBody body}
|
||||
| v == 1 -> e (MSG_, ' ', msgId, msgTs, Tail body)
|
||||
| v == 2 -> e (MSG_, ' ', msgId, msgTs, msgFlags, ' ', Tail body)
|
||||
| otherwise -> e (MSG_, ' ', msgId, Tail body)
|
||||
NID nId srvNtfDh -> e (NID_, ' ', nId, srvNtfDh)
|
||||
NMSG nmsgNonce encNMsgMeta -> e (NMSG_, ' ', nmsgNonce, encNMsgMeta)
|
||||
END -> e END_
|
||||
OK -> e OK_
|
||||
ERR err -> e (ERR_, ' ', err)
|
||||
|
@ -619,18 +913,25 @@ instance Protocol BrokerMsg where
|
|||
e :: Encoding a => a -> ByteString
|
||||
e = smpEncode
|
||||
|
||||
protocolP = \case
|
||||
MSG_ -> MSG <$> _smpP <*> smpP <*> (unTail <$> smpP)
|
||||
protocolP v = \case
|
||||
MSG_ -> do
|
||||
msgId <- _smpP
|
||||
MSG <$> case v of
|
||||
1 -> RcvMessage msgId <$> smpP <*> pure noMsgFlags <*> bodyP
|
||||
2 -> RcvMessage msgId <$> smpP <*> smpP <*> (A.space *> bodyP)
|
||||
_ -> RcvMessage msgId (MkSystemTime 0 0) noMsgFlags <$> bodyP
|
||||
where
|
||||
bodyP = EncRcvMsgBody . unTail <$> smpP
|
||||
IDS_ -> IDS <$> (QIK <$> _smpP <*> smpP <*> smpP)
|
||||
NID_ -> NID <$> _smpP
|
||||
NMSG_ -> pure NMSG
|
||||
NID_ -> NID <$> _smpP <*> smpP
|
||||
NMSG_ -> NMSG <$> _smpP <*> smpP
|
||||
END_ -> pure END
|
||||
OK_ -> pure OK
|
||||
ERR_ -> ERR <$> _smpP
|
||||
PONG_ -> pure PONG
|
||||
|
||||
checkCredentials (_, _, queueId, _) cmd = case cmd of
|
||||
-- IDS response must not have queue ID
|
||||
-- IDS response should not have queue ID
|
||||
IDS _ -> Right cmd
|
||||
-- ERR response does not always have queue ID
|
||||
ERR _ -> Right cmd
|
||||
|
@ -640,18 +941,18 @@ instance Protocol BrokerMsg where
|
|||
| otherwise -> Left $ CMD HAS_AUTH
|
||||
-- other broker responses must have queue ID
|
||||
_
|
||||
| B.null queueId -> Left $ CMD NO_QUEUE
|
||||
| B.null queueId -> Left $ CMD NO_ENTITY
|
||||
| otherwise -> Right cmd
|
||||
|
||||
_smpP :: Encoding a => Parser a
|
||||
_smpP = A.space *> smpP
|
||||
|
||||
-- | Parse SMP protocol commands and broker messages
|
||||
parseProtocol :: (Protocol msg, ProtocolMsgTag (Tag msg)) => ByteString -> Either ErrorType msg
|
||||
parseProtocol s =
|
||||
parseProtocol :: ProtocolEncoding msg => Version -> ByteString -> Either ErrorType msg
|
||||
parseProtocol v s =
|
||||
let (tag, params) = B.break (== ' ') s
|
||||
in case decodeTag tag of
|
||||
Just cmd -> parse (protocolP cmd) (CMD SYNTAX) params
|
||||
Just cmd -> parse (protocolP v cmd) (CMD SYNTAX) params
|
||||
Nothing -> Left $ CMD UNKNOWN
|
||||
|
||||
checkParty :: forall t p p'. (PartyI p, PartyI p') => t p' -> Either String (t p)
|
||||
|
@ -693,43 +994,42 @@ instance Encoding CommandError where
|
|||
smpEncode e = case e of
|
||||
UNKNOWN -> "UNKNOWN"
|
||||
SYNTAX -> "SYNTAX"
|
||||
PROHIBITED -> "PROHIBITED"
|
||||
NO_AUTH -> "NO_AUTH"
|
||||
HAS_AUTH -> "HAS_AUTH"
|
||||
NO_QUEUE -> "NO_QUEUE"
|
||||
NO_ENTITY -> "NO_ENTITY"
|
||||
smpP =
|
||||
A.takeTill (== ' ') >>= \case
|
||||
"UNKNOWN" -> pure UNKNOWN
|
||||
"SYNTAX" -> pure SYNTAX
|
||||
"PROHIBITED" -> pure PROHIBITED
|
||||
"NO_AUTH" -> pure NO_AUTH
|
||||
"HAS_AUTH" -> pure HAS_AUTH
|
||||
"NO_QUEUE" -> pure NO_QUEUE
|
||||
"NO_ENTITY" -> pure NO_ENTITY
|
||||
"NO_QUEUE" -> pure NO_ENTITY
|
||||
_ -> fail "bad command error type"
|
||||
|
||||
-- | Send signed SMP transmission to TCP transport.
|
||||
tPut :: Transport c => THandle c -> SentRawTransmission -> IO (Either TransportError ())
|
||||
tPut th (sig, t) = tPutBlock th $ smpEncode (C.signatureBytes sig) <> t
|
||||
|
||||
encodeTransmission :: Protocol c => ByteString -> Transmission c -> ByteString
|
||||
encodeTransmission sessionId (CorrId corrId, queueId, command) =
|
||||
smpEncode (sessionId, corrId, queueId) <> encodeProtocol command
|
||||
encodeTransmission :: ProtocolEncoding c => Version -> ByteString -> Transmission c -> ByteString
|
||||
encodeTransmission v sessionId (CorrId corrId, queueId, command) =
|
||||
smpEncode (sessionId, corrId, queueId) <> encodeProtocol v command
|
||||
|
||||
-- | Receive and parse transmission from the TCP transport (ignoring any trailing padding).
|
||||
tGetParse :: Transport c => THandle c -> IO (Either TransportError RawTransmission)
|
||||
tGetParse th = (parse transmissionP TEBadBlock =<<) <$> tGetBlock th
|
||||
|
||||
-- | Receive client and server transmissions (determined by `cmd` type).
|
||||
tGet ::
|
||||
forall cmd c m.
|
||||
(Protocol cmd, ProtocolMsgTag (Tag cmd), Transport c, MonadIO m) =>
|
||||
THandle c ->
|
||||
m (SignedTransmission cmd)
|
||||
tGet th@THandle {sessionId} = liftIO (tGetParse th) >>= decodeParseValidate
|
||||
tGet :: forall cmd c m. (ProtocolEncoding cmd, Transport c, MonadIO m) => THandle c -> m (SignedTransmission cmd)
|
||||
tGet th@THandle {sessionId, thVersion = v} = liftIO (tGetParse th) >>= decodeParseValidate
|
||||
where
|
||||
decodeParseValidate :: Either TransportError RawTransmission -> m (SignedTransmission cmd)
|
||||
decodeParseValidate = \case
|
||||
Right RawTransmission {signature, signed, sessId, corrId, queueId, command}
|
||||
Right RawTransmission {signature, signed, sessId, corrId, entityId, command}
|
||||
| sessId == sessionId ->
|
||||
let decodedTransmission = (,corrId,queueId,command) <$> C.decodeSignature signature
|
||||
let decodedTransmission = (,corrId,entityId,command) <$> C.decodeSignature signature
|
||||
in either (const $ tError corrId) (tParseValidate signed) decodedTransmission
|
||||
| otherwise -> pure (Nothing, "", (CorrId corrId, "", Left SESSION))
|
||||
Left _ -> tError ""
|
||||
|
@ -738,6 +1038,6 @@ tGet th@THandle {sessionId} = liftIO (tGetParse th) >>= decodeParseValidate
|
|||
tError corrId = pure (Nothing, "", (CorrId corrId, "", Left BLOCK))
|
||||
|
||||
tParseValidate :: ByteString -> SignedRawTransmission -> m (SignedTransmission cmd)
|
||||
tParseValidate signed t@(sig, corrId, queueId, command) = do
|
||||
let cmd = parseProtocol command >>= checkCredentials t
|
||||
pure (sig, signed, (CorrId corrId, queueId, cmd))
|
||||
tParseValidate signed t@(sig, corrId, entityId, command) = do
|
||||
let cmd = parseProtocol v command >>= checkCredentials t
|
||||
pure (sig, signed, (CorrId corrId, entityId, cmd))
|
||||
|
|
|
@ -5,7 +5,9 @@
|
|||
{-# LANGUAGE KindSignatures #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
{-# LANGUAGE NumericUnderscores #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE PatternSynonyms #-}
|
||||
{-# LANGUAGE RankNTypes #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-# LANGUAGE TupleSections #-}
|
||||
|
@ -23,36 +25,62 @@
|
|||
-- and optional append only log of SMP queue records.
|
||||
--
|
||||
-- See https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md
|
||||
module Simplex.Messaging.Server (runSMPServer, runSMPServerBlocking) where
|
||||
module Simplex.Messaging.Server
|
||||
( runSMPServer,
|
||||
runSMPServerBlocking,
|
||||
disconnectTransport,
|
||||
verifyCmdSignature,
|
||||
dummyVerifyCmd,
|
||||
)
|
||||
where
|
||||
|
||||
import Control.Logger.Simple
|
||||
import Control.Monad
|
||||
import Control.Monad.Except
|
||||
import Control.Monad.IO.Unlift
|
||||
import Control.Monad.Reader
|
||||
import Crypto.Random
|
||||
import Data.Bifunctor (first)
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import Data.Either (fromRight)
|
||||
import Data.Functor (($>))
|
||||
import Data.Int (Int64)
|
||||
import Data.List (intercalate)
|
||||
import qualified Data.Map.Strict as M
|
||||
import Data.Maybe (isNothing)
|
||||
import Data.Set (Set)
|
||||
import qualified Data.Set as S
|
||||
import qualified Data.Text as T
|
||||
import Data.Text.Encoding (decodeLatin1)
|
||||
import Data.Time.Calendar.Month.Compat (pattern MonthDay)
|
||||
import Data.Time.Calendar.OrdinalDate (mondayStartWeek)
|
||||
import Data.Time.Clock (UTCTime (..), diffTimeToPicoseconds, getCurrentTime)
|
||||
import Data.Time.Clock.System (SystemTime (..), getSystemTime)
|
||||
import Data.Time.Format.ISO8601 (iso8601Show)
|
||||
import Data.Type.Equality
|
||||
import GHC.TypeLits (KnownNat)
|
||||
import Network.Socket (ServiceName)
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Encoding (Encoding (smpEncode))
|
||||
import Simplex.Messaging.Encoding.String
|
||||
import Simplex.Messaging.Protocol
|
||||
import Simplex.Messaging.Server.Env.STM
|
||||
import Simplex.Messaging.Server.Expiration
|
||||
import Simplex.Messaging.Server.MsgStore
|
||||
import Simplex.Messaging.Server.MsgStore.STM (MsgQueue)
|
||||
import Simplex.Messaging.Server.QueueStore
|
||||
import Simplex.Messaging.Server.QueueStore.STM (QueueStore)
|
||||
import Simplex.Messaging.Server.Stats
|
||||
import Simplex.Messaging.Server.StoreLog
|
||||
import Simplex.Messaging.TMap (TMap)
|
||||
import qualified Simplex.Messaging.TMap as TM
|
||||
import Simplex.Messaging.Transport
|
||||
import Simplex.Messaging.Transport.Server
|
||||
import Simplex.Messaging.Util
|
||||
import System.Exit (exitFailure)
|
||||
import System.Mem.Weak (deRefWeak)
|
||||
import UnliftIO.Concurrent
|
||||
import UnliftIO.Directory (doesFileExist, renameFile)
|
||||
import UnliftIO.Exception
|
||||
import UnliftIO.IO
|
||||
import UnliftIO.STM
|
||||
|
@ -76,12 +104,14 @@ smpServer :: forall m. (MonadUnliftIO m, MonadReader Env m) => TMVar Bool -> m (
|
|||
smpServer started = do
|
||||
s <- asks server
|
||||
cfg@ServerConfig {transports} <- asks config
|
||||
restoreServerStats
|
||||
restoreServerMessages
|
||||
raceAny_
|
||||
( serverThread s subscribedQ subscribers subscriptions cancelSub :
|
||||
serverThread s ntfSubscribedQ notifiers ntfSubscriptions (\_ -> pure ()) :
|
||||
map runServer transports <> expireMessagesThread_ cfg
|
||||
map runServer transports <> expireMessagesThread_ cfg <> serverStatsThread_ cfg
|
||||
)
|
||||
`finally` withLog closeStoreLog
|
||||
`finally` (withLog closeStoreLog >> saveServerMessages >> saveServerStats)
|
||||
where
|
||||
runServer :: (ServiceName, ATransport) -> m ()
|
||||
runServer (tcpPort, ATransport t) = do
|
||||
|
@ -98,7 +128,7 @@ smpServer started = do
|
|||
m ()
|
||||
serverThread s subQ subs clientSubs unsub = forever $ do
|
||||
atomically updateSubscribers
|
||||
>>= fmap join . mapM endPreviousSubscriptions
|
||||
$>>= endPreviousSubscriptions
|
||||
>>= mapM_ unsub
|
||||
where
|
||||
updateSubscribers :: STM (Maybe (QueueId, Client))
|
||||
|
@ -110,8 +140,7 @@ smpServer started = do
|
|||
else do
|
||||
yes <- readTVar $ connected c'
|
||||
pure $ if yes then Just (qId, c') else Nothing
|
||||
TM.lookupInsert qId clnt (subs s)
|
||||
>>= fmap join . mapM clientToBeNotified
|
||||
TM.lookupInsert qId clnt (subs s) $>>= clientToBeNotified
|
||||
endPreviousSubscriptions :: (QueueId, Client) -> m (Maybe s)
|
||||
endPreviousSubscriptions (qId, c) = do
|
||||
void . forkIO . atomically $
|
||||
|
@ -119,43 +148,81 @@ smpServer started = do
|
|||
atomically $ TM.lookupDelete qId (clientSubs c)
|
||||
|
||||
expireMessagesThread_ :: ServerConfig -> [m ()]
|
||||
expireMessagesThread_ ServerConfig {messageTTL, expireMessagesInterval} =
|
||||
case (messageTTL, expireMessagesInterval) of
|
||||
(Just ttl, Just int) -> [expireMessages ttl int]
|
||||
_ -> []
|
||||
expireMessagesThread_ ServerConfig {messageExpiration = Just msgExp} = [expireMessages msgExp]
|
||||
expireMessagesThread_ _ = []
|
||||
|
||||
expireMessages :: Int64 -> Int -> m ()
|
||||
expireMessages ttl interval = do
|
||||
expireMessages :: ExpirationConfig -> m ()
|
||||
expireMessages expCfg = do
|
||||
ms <- asks msgStore
|
||||
quota <- asks $ msgQueueQuota . config
|
||||
let interval = checkInterval expCfg * 1000000
|
||||
forever $ do
|
||||
threadDelay interval
|
||||
old <- subtract ttl . systemSeconds <$> liftIO getSystemTime
|
||||
old <- liftIO $ expireBeforeEpoch expCfg
|
||||
rIds <- M.keysSet <$> readTVarIO ms
|
||||
forM_ rIds $ \rId ->
|
||||
atomically (getMsgQueue ms rId quota)
|
||||
>>= atomically . (`deleteExpiredMsgs` old)
|
||||
|
||||
serverStatsThread_ :: ServerConfig -> [m ()]
|
||||
serverStatsThread_ ServerConfig {logStatsInterval = Just interval, logStatsStartTime} =
|
||||
[logServerStats logStatsStartTime interval]
|
||||
serverStatsThread_ _ = []
|
||||
|
||||
logServerStats :: Int -> Int -> m ()
|
||||
logServerStats startAt logInterval = do
|
||||
initialDelay <- (startAt -) . fromIntegral . (`div` 1000000_000000) . diffTimeToPicoseconds . utctDayTime <$> liftIO getCurrentTime
|
||||
logInfo "fromTime,qCreated,qSecured,qDeleted,msgSent,msgRecv,dayMsgQueues,weekMsgQueues,monthMsgQueues"
|
||||
threadDelay $ 1000000 * (initialDelay + if initialDelay < 0 then 86400 else 0)
|
||||
ServerStats {fromTime, qCreated, qSecured, qDeleted, msgSent, msgRecv, dayMsgQueues, weekMsgQueues, monthMsgQueues} <- asks serverStats
|
||||
let interval = 1000000 * logInterval
|
||||
forever $ do
|
||||
ts <- liftIO getCurrentTime
|
||||
fromTime' <- atomically $ swapTVar fromTime ts
|
||||
qCreated' <- atomically $ swapTVar qCreated 0
|
||||
qSecured' <- atomically $ swapTVar qSecured 0
|
||||
qDeleted' <- atomically $ swapTVar qDeleted 0
|
||||
msgSent' <- atomically $ swapTVar msgSent 0
|
||||
msgRecv' <- atomically $ swapTVar msgRecv 0
|
||||
let day = utctDay ts
|
||||
(_, wDay) = mondayStartWeek day
|
||||
MonthDay _ mDay = day
|
||||
(dayMsgQueues', weekMsgQueues', monthMsgQueues') <-
|
||||
atomically $ (,,) <$> periodCount 1 dayMsgQueues <*> periodCount wDay weekMsgQueues <*> periodCount mDay monthMsgQueues
|
||||
logInfo . T.pack $ intercalate "," [iso8601Show fromTime', show qCreated', show qSecured', show qDeleted', show msgSent', show msgRecv', show dayMsgQueues', weekMsgQueues', monthMsgQueues']
|
||||
threadDelay interval
|
||||
where
|
||||
periodCount :: Int -> TVar (Set RecipientId) -> STM String
|
||||
periodCount 1 pVar = show . S.size <$> swapTVar pVar S.empty
|
||||
periodCount _ _ = pure ""
|
||||
|
||||
runClient :: Transport c => TProxy c -> c -> m ()
|
||||
runClient _ h = do
|
||||
kh <- asks serverIdentity
|
||||
liftIO (runExceptT $ serverHandshake h kh) >>= \case
|
||||
smpVRange <- asks $ smpServerVRange . config
|
||||
liftIO (runExceptT $ smpServerHandshake h kh smpVRange) >>= \case
|
||||
Right th -> runClientTransport th
|
||||
Left _ -> pure ()
|
||||
|
||||
runClientTransport :: (Transport c, MonadUnliftIO m, MonadReader Env m) => THandle c -> m ()
|
||||
runClientTransport th@THandle {sessionId} = do
|
||||
runClientTransport th@THandle {thVersion, sessionId} = do
|
||||
q <- asks $ tbqSize . config
|
||||
c <- atomically $ newClient q sessionId
|
||||
ts <- liftIO getSystemTime
|
||||
c <- atomically $ newClient q thVersion sessionId ts
|
||||
s <- asks server
|
||||
raceAny_ [send th c, client c s, receive th c]
|
||||
expCfg <- asks $ inactiveClientExpiration . config
|
||||
raceAny_ ([send th c, client c s, receive th c] <> disconnectThread_ c expCfg)
|
||||
`finally` clientDisconnected c
|
||||
where
|
||||
disconnectThread_ c (Just expCfg) = [disconnectTransport th c activeAt expCfg]
|
||||
disconnectThread_ _ _ = []
|
||||
|
||||
clientDisconnected :: (MonadUnliftIO m, MonadReader Env m) => Client -> m ()
|
||||
clientDisconnected c@Client {subscriptions, connected} = do
|
||||
atomically $ writeTVar connected False
|
||||
subs <- readTVarIO subscriptions
|
||||
mapM_ cancelSub subs
|
||||
atomically $ writeTVar subscriptions M.empty
|
||||
cs <- asks $ subscribers . server
|
||||
atomically . mapM_ (\rId -> TM.update deleteCurrentClient rId cs) $ M.keys subs
|
||||
where
|
||||
|
@ -167,14 +234,16 @@ clientDisconnected c@Client {subscriptions, connected} = do
|
|||
sameClientSession :: Client -> Client -> Bool
|
||||
sameClientSession Client {sessionId} Client {sessionId = s'} = sessionId == s'
|
||||
|
||||
cancelSub :: MonadUnliftIO m => Sub -> m ()
|
||||
cancelSub = \case
|
||||
Sub {subThread = SubThread t} -> killThread t
|
||||
_ -> return ()
|
||||
cancelSub :: MonadUnliftIO m => TVar Sub -> m ()
|
||||
cancelSub sub =
|
||||
readTVarIO sub >>= \case
|
||||
Sub {subThread = SubThread t} -> liftIO $ deRefWeak t >>= mapM_ killThread
|
||||
_ -> return ()
|
||||
|
||||
receive :: (Transport c, MonadUnliftIO m, MonadReader Env m) => THandle c -> Client -> m ()
|
||||
receive th Client {rcvQ, sndQ} = forever $ do
|
||||
receive th Client {rcvQ, sndQ, activeAt} = forever $ do
|
||||
(sig, signed, (corrId, queueId, cmdOrError)) <- tGet th
|
||||
atomically . writeTVar activeAt =<< liftIO getSystemTime
|
||||
case cmdOrError of
|
||||
Left e -> write sndQ (corrId, queueId, ERR e)
|
||||
Right cmd -> do
|
||||
|
@ -186,36 +255,50 @@ receive th Client {rcvQ, sndQ} = forever $ do
|
|||
write q t = atomically $ writeTBQueue q t
|
||||
|
||||
send :: (Transport c, MonadUnliftIO m) => THandle c -> Client -> m ()
|
||||
send h Client {sndQ, sessionId} = forever $ do
|
||||
send h@THandle {thVersion = v} Client {sndQ, sessionId, activeAt} = forever $ do
|
||||
t <- atomically $ readTBQueue sndQ
|
||||
liftIO $ tPut h (Nothing, encodeTransmission sessionId t)
|
||||
-- TODO the line below can return Left, but we ignore it and do not disconnect the client
|
||||
void . liftIO $ tPut h (Nothing, encodeTransmission v sessionId t)
|
||||
atomically . writeTVar activeAt =<< liftIO getSystemTime
|
||||
|
||||
disconnectTransport :: (Transport c, MonadUnliftIO m) => THandle c -> client -> (client -> TVar SystemTime) -> ExpirationConfig -> m ()
|
||||
disconnectTransport THandle {connection} c activeAt expCfg = do
|
||||
let interval = checkInterval expCfg * 1000000
|
||||
forever . liftIO $ do
|
||||
threadDelay interval
|
||||
old <- expireBeforeEpoch expCfg
|
||||
ts <- readTVarIO $ activeAt c
|
||||
when (systemSeconds ts < old) $ closeConnection connection
|
||||
|
||||
verifyTransmission ::
|
||||
forall m. (MonadUnliftIO m, MonadReader Env m) => Maybe C.ASignature -> ByteString -> QueueId -> Cmd -> m Bool
|
||||
verifyTransmission sig_ signed queueId cmd = do
|
||||
case cmd of
|
||||
Cmd SRecipient (NEW k _) -> pure $ verifySignature k
|
||||
Cmd SRecipient _ -> verifyCmd SRecipient $ verifySignature . recipientKey
|
||||
Cmd SSender (SEND _) -> verifyCmd SSender $ verifyMaybe . senderKey
|
||||
Cmd SRecipient (NEW k _) -> pure $ verifyCmdSignature sig_ signed k
|
||||
Cmd SRecipient _ -> verifyCmd SRecipient $ verifyCmdSignature sig_ signed . recipientKey
|
||||
Cmd SSender SEND {} -> verifyCmd SSender $ verifyMaybe . senderKey
|
||||
Cmd SSender PING -> pure True
|
||||
Cmd SNotifier NSUB -> verifyCmd SNotifier $ verifyMaybe . fmap snd . notifier
|
||||
Cmd SNotifier NSUB -> verifyCmd SNotifier $ verifyMaybe . fmap notifierKey . notifier
|
||||
where
|
||||
verifyCmd :: SParty p -> (QueueRec -> Bool) -> m Bool
|
||||
verifyCmd party f = do
|
||||
st <- asks queueStore
|
||||
q <- atomically $ getQueue st party queueId
|
||||
pure $ either (const $ maybe False dummyVerify sig_ `seq` False) f q
|
||||
pure $ either (const $ maybe False (dummyVerifyCmd signed) sig_ `seq` False) f q
|
||||
verifyMaybe :: Maybe C.APublicVerifyKey -> Bool
|
||||
verifyMaybe = maybe (isNothing sig_) verifySignature
|
||||
verifySignature :: C.APublicVerifyKey -> Bool
|
||||
verifySignature key = maybe False (verify key) sig_
|
||||
verifyMaybe = maybe (isNothing sig_) $ verifyCmdSignature sig_ signed
|
||||
|
||||
verifyCmdSignature :: Maybe C.ASignature -> ByteString -> C.APublicVerifyKey -> Bool
|
||||
verifyCmdSignature sig_ signed key = maybe False (verify key) sig_
|
||||
where
|
||||
verify :: C.APublicVerifyKey -> C.ASignature -> Bool
|
||||
verify (C.APublicVerifyKey a k) sig@(C.ASignature a' s) =
|
||||
case (testEquality a a', C.signatureSize k == C.signatureSize s) of
|
||||
(Just Refl, True) -> C.verify' k s signed
|
||||
_ -> dummyVerify sig `seq` False
|
||||
dummyVerify :: C.ASignature -> Bool
|
||||
dummyVerify (C.ASignature _ s) = C.verify' (dummyPublicKey s) s signed
|
||||
_ -> dummyVerifyCmd signed sig `seq` False
|
||||
|
||||
dummyVerifyCmd :: ByteString -> C.ASignature -> Bool
|
||||
dummyVerifyCmd signed (C.ASignature _ s) = C.verify' (dummyPublicKey s) s signed
|
||||
|
||||
-- These dummy keys are used with `dummyVerify` function to mitigate timing attacks
|
||||
-- by having the same time of the response whether a queue exists or nor, for all valid key/signature sizes
|
||||
|
@ -231,7 +314,7 @@ dummyKeyEd448 :: C.PublicKey 'C.Ed448
|
|||
dummyKeyEd448 = "MEMwBQYDK2VxAzoA6ibQc9XpkSLtwrf7PLvp81qW/etiumckVFImCMRdftcG/XopbOSaq9qyLhrgJWKOLyNrQPNVvpMA"
|
||||
|
||||
client :: forall m. (MonadUnliftIO m, MonadReader Env m) => Client -> Server -> m ()
|
||||
client clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ} Server {subscribedQ, ntfSubscribedQ, notifiers} =
|
||||
client clnt@Client {thVersion, subscriptions, ntfSubscriptions, rcvQ, sndQ} Server {subscribedQ, ntfSubscribedQ, notifiers} =
|
||||
forever $
|
||||
atomically (readTBQueue rcvQ)
|
||||
>>= processCommand
|
||||
|
@ -243,7 +326,7 @@ client clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ} Server {subscri
|
|||
case cmd of
|
||||
Cmd SSender command ->
|
||||
case command of
|
||||
SEND msgBody -> sendMessage st msgBody
|
||||
SEND flags msgBody -> sendMessage st flags msgBody
|
||||
PING -> pure (corrId, "", PONG)
|
||||
Cmd SNotifier NSUB -> subscribeNotifications
|
||||
Cmd SRecipient command ->
|
||||
|
@ -253,10 +336,12 @@ client clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ} Server {subscri
|
|||
(asks $ allowNewQueues . config)
|
||||
(createQueue st rKey dhKey)
|
||||
(pure (corrId, queueId, ERR AUTH))
|
||||
SUB -> subscribeQueue queueId
|
||||
ACK -> acknowledgeMsg
|
||||
SUB -> subscribeQueue st queueId
|
||||
GET -> getMessage st
|
||||
ACK msgId -> acknowledgeMsg st msgId
|
||||
KEY sKey -> secureQueue_ st sKey
|
||||
NKEY nKey -> addQueueNotifier_ st nKey
|
||||
NKEY nKey dhKey -> addQueueNotifier_ st nKey dhKey
|
||||
NDEL -> deleteQueueNotifier_ st
|
||||
OFF -> suspendQueue_ st
|
||||
DEL -> delQueueAndMsgs st
|
||||
where
|
||||
|
@ -288,7 +373,9 @@ client clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ} Server {subscri
|
|||
Left e -> pure $ ERR e
|
||||
Right _ -> do
|
||||
withLog (`logCreateById` rId)
|
||||
subscribeQueue rId $> IDS (qik ids)
|
||||
stats <- asks serverStats
|
||||
atomically $ modifyTVar (qCreated stats) (+ 1)
|
||||
subscribeQueue st rId $> IDS (qik ids)
|
||||
|
||||
logCreateById :: StoreLog 'WriteMode -> RecipientId -> IO ()
|
||||
logCreateById s rId =
|
||||
|
@ -304,60 +391,147 @@ client clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ} Server {subscri
|
|||
secureQueue_ :: QueueStore -> SndPublicVerifyKey -> m (Transmission BrokerMsg)
|
||||
secureQueue_ st sKey = do
|
||||
withLog $ \s -> logSecureQueue s queueId sKey
|
||||
stats <- asks serverStats
|
||||
atomically $ modifyTVar (qSecured stats) (+ 1)
|
||||
atomically $ (corrId,queueId,) . either ERR (const OK) <$> secureQueue st queueId sKey
|
||||
|
||||
addQueueNotifier_ :: QueueStore -> NtfPublicVerifyKey -> m (Transmission BrokerMsg)
|
||||
addQueueNotifier_ st nKey = (corrId,queueId,) <$> addNotifierRetry 3
|
||||
addQueueNotifier_ :: QueueStore -> NtfPublicVerifyKey -> RcvNtfPublicDhKey -> m (Transmission BrokerMsg)
|
||||
addQueueNotifier_ st notifierKey dhKey = do
|
||||
(rcvPublicDhKey, privDhKey) <- liftIO C.generateKeyPair'
|
||||
let rcvNtfDhSecret = C.dh' dhKey privDhKey
|
||||
(corrId,queueId,) <$> addNotifierRetry 3 rcvPublicDhKey rcvNtfDhSecret
|
||||
where
|
||||
addNotifierRetry :: Int -> m BrokerMsg
|
||||
addNotifierRetry 0 = pure $ ERR INTERNAL
|
||||
addNotifierRetry n = do
|
||||
nId <- randomId =<< asks (queueIdBytes . config)
|
||||
atomically (addQueueNotifier st queueId nId nKey) >>= \case
|
||||
Left DUPLICATE_ -> addNotifierRetry $ n - 1
|
||||
addNotifierRetry :: Int -> RcvNtfPublicDhKey -> RcvNtfDhSecret -> m BrokerMsg
|
||||
addNotifierRetry 0 _ _ = pure $ ERR INTERNAL
|
||||
addNotifierRetry n rcvPublicDhKey rcvNtfDhSecret = do
|
||||
notifierId <- randomId =<< asks (queueIdBytes . config)
|
||||
let ntfCreds = NtfCreds {notifierId, notifierKey, rcvNtfDhSecret}
|
||||
atomically (addQueueNotifier st queueId ntfCreds) >>= \case
|
||||
Left DUPLICATE_ -> addNotifierRetry (n - 1) rcvPublicDhKey rcvNtfDhSecret
|
||||
Left e -> pure $ ERR e
|
||||
Right _ -> do
|
||||
withLog $ \s -> logAddNotifier s queueId nId nKey
|
||||
pure $ NID nId
|
||||
withLog $ \s -> logAddNotifier s queueId ntfCreds
|
||||
pure $ NID notifierId rcvPublicDhKey
|
||||
|
||||
deleteQueueNotifier_ :: QueueStore -> m (Transmission BrokerMsg)
|
||||
deleteQueueNotifier_ st = do
|
||||
withLog (`logDeleteNotifier` queueId)
|
||||
okResp <$> atomically (deleteQueueNotifier st queueId)
|
||||
|
||||
suspendQueue_ :: QueueStore -> m (Transmission BrokerMsg)
|
||||
suspendQueue_ st = do
|
||||
withLog (`logDeleteQueue` queueId)
|
||||
okResp <$> atomically (suspendQueue st queueId)
|
||||
|
||||
subscribeQueue :: RecipientId -> m (Transmission BrokerMsg)
|
||||
subscribeQueue rId =
|
||||
atomically (getSubscription rId) >>= deliverMessage tryPeekMsg rId
|
||||
|
||||
getSubscription :: RecipientId -> STM Sub
|
||||
getSubscription rId = do
|
||||
TM.lookup rId subscriptions >>= \case
|
||||
Just s -> tryTakeTMVar (delivered s) $> s
|
||||
Nothing -> do
|
||||
subscribeQueue :: QueueStore -> RecipientId -> m (Transmission BrokerMsg)
|
||||
subscribeQueue st rId =
|
||||
atomically (TM.lookup rId subscriptions) >>= \case
|
||||
Nothing ->
|
||||
atomically newSub >>= deliver
|
||||
Just sub ->
|
||||
readTVarIO sub >>= \case
|
||||
Sub {subThread = ProhibitSub} ->
|
||||
-- cannot use SUB in the same connection where GET was used
|
||||
pure (corrId, rId, ERR $ CMD PROHIBITED)
|
||||
s ->
|
||||
atomically (tryTakeTMVar $ delivered s) >> deliver sub
|
||||
where
|
||||
newSub :: STM (TVar Sub)
|
||||
newSub = do
|
||||
writeTBQueue subscribedQ (rId, clnt)
|
||||
s <- newSubscription
|
||||
TM.insert rId s subscriptions
|
||||
return s
|
||||
sub <- newTVar =<< newSubscription NoSub
|
||||
TM.insert rId sub subscriptions
|
||||
pure sub
|
||||
deliver :: TVar Sub -> m (Transmission BrokerMsg)
|
||||
deliver sub = do
|
||||
q <- getStoreMsgQueue rId
|
||||
msg_ <- atomically $ tryPeekMsg q
|
||||
deliverMessage st rId sub q msg_
|
||||
|
||||
getMessage :: QueueStore -> m (Transmission BrokerMsg)
|
||||
getMessage st =
|
||||
atomically (TM.lookup queueId subscriptions) >>= \case
|
||||
Nothing ->
|
||||
atomically newSub >>= getMessage_
|
||||
Just sub ->
|
||||
readTVarIO sub >>= \case
|
||||
s@Sub {subThread = ProhibitSub} ->
|
||||
atomically (tryTakeTMVar $ delivered s)
|
||||
>> getMessage_ s
|
||||
-- cannot use GET in the same connection where there is an active subscription
|
||||
_ -> pure (corrId, queueId, ERR $ CMD PROHIBITED)
|
||||
where
|
||||
newSub :: STM Sub
|
||||
newSub = do
|
||||
s <- newSubscription ProhibitSub
|
||||
sub <- newTVar s
|
||||
TM.insert queueId sub subscriptions
|
||||
pure s
|
||||
getMessage_ :: Sub -> m (Transmission BrokerMsg)
|
||||
getMessage_ s = withRcvQueue st queueId $ \qr -> do
|
||||
q <- getStoreMsgQueue queueId
|
||||
atomically $
|
||||
tryPeekMsg q >>= \case
|
||||
Just msg ->
|
||||
let encMsg = encryptMsg qr msg
|
||||
in setDelivered s msg $> (corrId, queueId, MSG encMsg)
|
||||
_ -> pure (corrId, queueId, OK)
|
||||
|
||||
withRcvQueue :: QueueStore -> RecipientId -> (QueueRec -> m (Transmission BrokerMsg)) -> m (Transmission BrokerMsg)
|
||||
withRcvQueue st rId action =
|
||||
atomically (getQueue st SRecipient rId) >>= \case
|
||||
Left e -> pure (corrId, rId, ERR e)
|
||||
Right qr -> action qr
|
||||
|
||||
subscribeNotifications :: m (Transmission BrokerMsg)
|
||||
subscribeNotifications = atomically $ do
|
||||
whenM (isNothing <$> TM.lookup queueId ntfSubscriptions) $ do
|
||||
unlessM (TM.member queueId ntfSubscriptions) $ do
|
||||
writeTBQueue ntfSubscribedQ (queueId, clnt)
|
||||
TM.insert queueId () ntfSubscriptions
|
||||
pure ok
|
||||
|
||||
acknowledgeMsg :: m (Transmission BrokerMsg)
|
||||
acknowledgeMsg =
|
||||
atomically (withSub queueId $ \s -> const s <$$> tryTakeTMVar (delivered s))
|
||||
>>= \case
|
||||
Just (Just s) -> deliverMessage tryDelPeekMsg queueId s
|
||||
_ -> return $ err NO_MSG
|
||||
acknowledgeMsg :: QueueStore -> MsgId -> m (Transmission BrokerMsg)
|
||||
acknowledgeMsg st msgId = do
|
||||
atomically (TM.lookup queueId subscriptions) >>= \case
|
||||
Nothing -> pure $ err NO_MSG
|
||||
Just sub ->
|
||||
atomically (getDelivered sub) >>= \case
|
||||
Just s -> do
|
||||
q <- getStoreMsgQueue queueId
|
||||
case s of
|
||||
Sub {subThread = ProhibitSub} -> do
|
||||
msgDeleted <- atomically $ tryDelMsg q msgId
|
||||
when msgDeleted updateStats
|
||||
pure ok
|
||||
_ -> do
|
||||
(msgDeleted, msg_) <- atomically $ tryDelPeekMsg q msgId
|
||||
when msgDeleted updateStats
|
||||
deliverMessage st queueId sub q msg_
|
||||
_ -> pure $ err NO_MSG
|
||||
where
|
||||
getDelivered :: TVar Sub -> STM (Maybe Sub)
|
||||
getDelivered sub = do
|
||||
s@Sub {delivered} <- readTVar sub
|
||||
tryTakeTMVar delivered $>>= \msgId' ->
|
||||
if msgId == msgId' || B.null msgId
|
||||
then pure $ Just s
|
||||
else putTMVar delivered msgId' $> Nothing
|
||||
updateStats :: m ()
|
||||
updateStats = do
|
||||
stats <- asks serverStats
|
||||
atomically $ modifyTVar (msgRecv stats) (+ 1)
|
||||
atomically $ updateActiveQueues stats queueId
|
||||
|
||||
withSub :: RecipientId -> (Sub -> STM a) -> STM (Maybe a)
|
||||
withSub rId f = mapM f =<< TM.lookup rId subscriptions
|
||||
updateActiveQueues :: ServerStats -> RecipientId -> STM ()
|
||||
updateActiveQueues stats qId = do
|
||||
updatePeriod dayMsgQueues
|
||||
updatePeriod weekMsgQueues
|
||||
updatePeriod monthMsgQueues
|
||||
where
|
||||
updatePeriod pSel = modifyTVar (pSel stats) (S.insert qId)
|
||||
|
||||
sendMessage :: QueueStore -> MsgBody -> m (Transmission BrokerMsg)
|
||||
sendMessage st msgBody
|
||||
sendMessage :: QueueStore -> MsgFlags -> MsgBody -> m (Transmission BrokerMsg)
|
||||
sendMessage st msgFlags msgBody
|
||||
| B.length msgBody > maxMessageLength = pure $ err LARGE_MSG
|
||||
| otherwise = do
|
||||
qr <- atomically $ getQueue st SSender queueId
|
||||
|
@ -367,76 +541,102 @@ client clnt@Client {subscriptions, ntfSubscriptions, rcvQ, sndQ} Server {subscri
|
|||
storeMessage qr = case status qr of
|
||||
QueueOff -> return $ err AUTH
|
||||
QueueActive ->
|
||||
mkMessage >>= \case
|
||||
mapM mkMessage (C.maxLenBS msgBody) >>= \case
|
||||
Left _ -> pure $ err LARGE_MSG
|
||||
Right msg -> do
|
||||
ms <- asks msgStore
|
||||
ServerConfig {messageTTL, msgQueueQuota} <- asks config
|
||||
old <- forM messageTTL $ \ttl -> subtract ttl . systemSeconds <$> liftIO getSystemTime
|
||||
atomically $ do
|
||||
ServerConfig {messageExpiration, msgQueueQuota} <- asks config
|
||||
old <- liftIO $ mapM expireBeforeEpoch messageExpiration
|
||||
ntfNonceDrg <- asks idsDrg
|
||||
resp@(_, _, sent) <- atomically $ do
|
||||
q <- getMsgQueue ms (recipientId qr) msgQueueQuota
|
||||
mapM_ (deleteExpiredMsgs q) old
|
||||
ifM (isFull q) (pure $ err QUOTA) $ do
|
||||
trySendNotification
|
||||
when (notification msgFlags) $ trySendNotification msg ntfNonceDrg
|
||||
writeMsg q msg
|
||||
pure ok
|
||||
when (sent == OK) $ do
|
||||
stats <- asks serverStats
|
||||
atomically $ modifyTVar (msgSent stats) (+ 1)
|
||||
atomically $ updateActiveQueues stats $ recipientId qr
|
||||
pure resp
|
||||
where
|
||||
mkMessage :: m (Either C.CryptoError Message)
|
||||
mkMessage = do
|
||||
mkMessage :: C.MaxLenBS MaxMessageLen -> m Message
|
||||
mkMessage body = do
|
||||
msgId <- randomId =<< asks (msgIdBytes . config)
|
||||
ts <- liftIO getSystemTime
|
||||
let c = C.cbEncrypt (rcvDhSecret qr) (C.cbNonce msgId) msgBody (maxMessageLength + 2)
|
||||
pure $ Message msgId ts <$> c
|
||||
msgTs <- liftIO getSystemTime
|
||||
pure $ Message msgId msgTs msgFlags body
|
||||
|
||||
trySendNotification :: STM ()
|
||||
trySendNotification =
|
||||
forM_ (notifier qr) $ \(nId, _) ->
|
||||
mapM_ (writeNtf nId) =<< TM.lookup nId notifiers
|
||||
trySendNotification :: Message -> TVar ChaChaDRG -> STM ()
|
||||
trySendNotification msg ntfNonceDrg =
|
||||
forM_ (notifier qr) $ \NtfCreds {notifierId, rcvNtfDhSecret} ->
|
||||
mapM_ (writeNtf notifierId msg rcvNtfDhSecret ntfNonceDrg) =<< TM.lookup notifierId notifiers
|
||||
|
||||
writeNtf :: NotifierId -> Client -> STM ()
|
||||
writeNtf nId Client {sndQ = q} =
|
||||
unlessM (isFullTBQueue sndQ) $
|
||||
writeTBQueue q (CorrId "", nId, NMSG)
|
||||
writeNtf :: NotifierId -> Message -> RcvNtfDhSecret -> TVar ChaChaDRG -> Client -> STM ()
|
||||
writeNtf nId msg rcvNtfDhSecret ntfNonceDrg Client {sndQ = q} =
|
||||
unlessM (isFullTBQueue sndQ) $ do
|
||||
(nmsgNonce, encNMsgMeta) <- mkMessageNotification msg rcvNtfDhSecret ntfNonceDrg
|
||||
writeTBQueue q (CorrId "", nId, NMSG nmsgNonce encNMsgMeta)
|
||||
|
||||
deliverMessage :: (MsgQueue -> STM (Maybe Message)) -> RecipientId -> Sub -> m (Transmission BrokerMsg)
|
||||
deliverMessage tryPeek rId = \case
|
||||
Sub {subThread = NoSub} -> do
|
||||
ms <- asks msgStore
|
||||
quota <- asks $ msgQueueQuota . config
|
||||
q <- atomically $ getMsgQueue ms rId quota
|
||||
atomically (tryPeek q) >>= \case
|
||||
Nothing -> forkSub q $> ok
|
||||
Just msg -> atomically setDelivered $> (corrId, rId, msgCmd msg)
|
||||
_ -> pure ok
|
||||
mkMessageNotification :: Message -> RcvNtfDhSecret -> TVar ChaChaDRG -> STM (C.CbNonce, EncNMsgMeta)
|
||||
mkMessageNotification Message {msgId, msgTs} rcvNtfDhSecret ntfNonceDrg = do
|
||||
cbNonce <- C.pseudoRandomCbNonce ntfNonceDrg
|
||||
let msgMeta = NMsgMeta {msgId, msgTs}
|
||||
encNMsgMeta = C.cbEncrypt rcvNtfDhSecret cbNonce (smpEncode msgMeta) 128
|
||||
pure . (cbNonce,) $ fromRight "" encNMsgMeta
|
||||
|
||||
deliverMessage :: QueueStore -> RecipientId -> TVar Sub -> MsgQueue -> Maybe Message -> m (Transmission BrokerMsg)
|
||||
deliverMessage st rId sub q msg_ = withRcvQueue st rId $ \qr -> do
|
||||
readTVarIO sub >>= \case
|
||||
s@Sub {subThread = NoSub} ->
|
||||
case msg_ of
|
||||
Just msg ->
|
||||
let encMsg = encryptMsg qr msg
|
||||
in atomically (setDelivered s msg) $> (corrId, rId, MSG encMsg)
|
||||
_ -> forkSub qr $> ok
|
||||
_ -> pure ok
|
||||
where
|
||||
forkSub :: MsgQueue -> m ()
|
||||
forkSub q = do
|
||||
atomically . setSub $ \s -> s {subThread = SubPending}
|
||||
t <- forkIO $ subscriber q
|
||||
atomically . setSub $ \case
|
||||
forkSub :: QueueRec -> m ()
|
||||
forkSub qr = do
|
||||
atomically . modifyTVar sub $ \s -> s {subThread = SubPending}
|
||||
t <- mkWeakThreadId =<< forkIO subscriber
|
||||
atomically . modifyTVar sub $ \case
|
||||
s@Sub {subThread = SubPending} -> s {subThread = SubThread t}
|
||||
s -> s
|
||||
where
|
||||
subscriber = atomically $ do
|
||||
msg <- peekMsg q
|
||||
let encMsg = encryptMsg qr msg
|
||||
writeTBQueue sndQ (CorrId "", rId, MSG encMsg)
|
||||
s <- readTVar sub
|
||||
void $ setDelivered s msg
|
||||
writeTVar sub s {subThread = NoSub}
|
||||
|
||||
subscriber :: MsgQueue -> m ()
|
||||
subscriber q = atomically $ do
|
||||
msg <- peekMsg q
|
||||
writeTBQueue sndQ (CorrId "", rId, msgCmd msg)
|
||||
setSub (\s -> s {subThread = NoSub})
|
||||
void setDelivered
|
||||
encryptMsg :: QueueRec -> Message -> RcvMessage
|
||||
encryptMsg qr Message {msgId, msgTs, msgFlags, msgBody}
|
||||
| thVersion == 1 || thVersion == 2 = encrypt msgBody
|
||||
| otherwise = encrypt $ encodeRcvMsgBody RcvMsgBody {msgTs, msgFlags, msgBody}
|
||||
where
|
||||
encrypt :: KnownNat i => C.MaxLenBS i -> RcvMessage
|
||||
encrypt body =
|
||||
let encBody = EncRcvMsgBody $ C.cbEncryptMaxLenBS (rcvDhSecret qr) (C.cbNonce msgId) body
|
||||
in RcvMessage msgId msgTs msgFlags encBody
|
||||
|
||||
setSub :: (Sub -> Sub) -> STM ()
|
||||
setSub f = TM.adjust f rId subscriptions
|
||||
setDelivered :: Sub -> Message -> STM Bool
|
||||
setDelivered s Message {msgId} = tryPutTMVar (delivered s) msgId
|
||||
|
||||
setDelivered :: STM (Maybe Bool)
|
||||
setDelivered = withSub rId $ \s -> tryPutTMVar (delivered s) ()
|
||||
|
||||
msgCmd :: Message -> BrokerMsg
|
||||
msgCmd Message {msgId, ts, msgBody} = MSG msgId ts msgBody
|
||||
getStoreMsgQueue :: RecipientId -> m MsgQueue
|
||||
getStoreMsgQueue rId = do
|
||||
ms <- asks msgStore
|
||||
quota <- asks $ msgQueueQuota . config
|
||||
atomically $ getMsgQueue ms rId quota
|
||||
|
||||
delQueueAndMsgs :: QueueStore -> m (Transmission BrokerMsg)
|
||||
delQueueAndMsgs st = do
|
||||
withLog (`logDeleteQueue` queueId)
|
||||
ms <- asks msgStore
|
||||
stats <- asks serverStats
|
||||
atomically $ modifyTVar (qDeleted stats) (+ 1)
|
||||
atomically $
|
||||
deleteQueue st queueId >>= \case
|
||||
Left e -> pure $ err e
|
||||
|
@ -459,11 +659,78 @@ withLog action = do
|
|||
randomId :: (MonadUnliftIO m, MonadReader Env m) => Int -> m ByteString
|
||||
randomId n = do
|
||||
gVar <- asks idsDrg
|
||||
atomically (randomBytes n gVar)
|
||||
atomically (C.pseudoRandomBytes n gVar)
|
||||
|
||||
randomBytes :: Int -> TVar ChaChaDRG -> STM ByteString
|
||||
randomBytes n gVar = do
|
||||
g <- readTVar gVar
|
||||
let (bytes, g') = randomBytesGenerate n g
|
||||
writeTVar gVar g'
|
||||
return bytes
|
||||
saveServerMessages :: (MonadUnliftIO m, MonadReader Env m) => m ()
|
||||
saveServerMessages = asks (storeMsgsFile . config) >>= mapM_ saveMessages
|
||||
where
|
||||
saveMessages f = do
|
||||
logInfo $ "saving messages to file " <> T.pack f
|
||||
ms <- asks msgStore
|
||||
liftIO . withFile f WriteMode $ \h ->
|
||||
readTVarIO ms >>= mapM_ (saveQueueMsgs ms h) . M.keys
|
||||
logInfo "messages saved"
|
||||
where
|
||||
saveQueueMsgs ms h rId =
|
||||
atomically (flushMsgQueue ms rId)
|
||||
>>= mapM_ (B.hPutStrLn h . strEncode . MLRv3 rId)
|
||||
|
||||
restoreServerMessages :: forall m. (MonadUnliftIO m, MonadReader Env m) => m ()
|
||||
restoreServerMessages = asks (storeMsgsFile . config) >>= mapM_ restoreMessages
|
||||
where
|
||||
restoreMessages f = whenM (doesFileExist f) $ do
|
||||
logInfo $ "restoring messages from file " <> T.pack f
|
||||
st <- asks queueStore
|
||||
ms <- asks msgStore
|
||||
quota <- asks $ msgQueueQuota . config
|
||||
runExceptT (liftIO (B.readFile f) >>= mapM_ (restoreMsg st ms quota) . B.lines) >>= \case
|
||||
Left e -> do
|
||||
logError . T.pack $ "error restoring messages: " <> e
|
||||
liftIO exitFailure
|
||||
_ -> do
|
||||
renameFile f $ f <> ".bak"
|
||||
logInfo "messages restored"
|
||||
where
|
||||
restoreMsg st ms quota s = do
|
||||
r <- liftEither . first (msgErr "parsing") $ strDecode s
|
||||
case r of
|
||||
MLRv3 rId msg -> addToMsgQueue rId msg
|
||||
MLRv1 rId encMsg -> do
|
||||
qr <- liftEitherError (msgErr "queue unknown") . atomically $ getQueue st SRecipient rId
|
||||
msg' <- updateMsgV1toV3 qr encMsg
|
||||
addToMsgQueue rId msg'
|
||||
where
|
||||
addToMsgQueue rId msg = do
|
||||
full <- atomically $ do
|
||||
q <- getMsgQueue ms rId quota
|
||||
ifM (isFull q) (pure True) (writeMsg q msg $> False)
|
||||
when full . logError . decodeLatin1 $ "message queue " <> strEncode rId <> " is full, message not restored: " <> strEncode (msgId (msg :: Message))
|
||||
updateMsgV1toV3 QueueRec {rcvDhSecret} RcvMessage {msgId, msgTs, msgFlags, msgBody = EncRcvMsgBody body} = do
|
||||
let nonce = C.cbNonce msgId
|
||||
msgBody <- liftEither . first (msgErr "v1 message decryption") $ C.maxLenBS =<< C.cbDecrypt rcvDhSecret nonce body
|
||||
pure Message {msgId, msgTs, msgFlags, msgBody}
|
||||
msgErr :: Show e => String -> e -> String
|
||||
msgErr op e = op <> " error (" <> show e <> "): " <> B.unpack (B.take 100 s)
|
||||
|
||||
saveServerStats :: (MonadUnliftIO m, MonadReader Env m) => m ()
|
||||
saveServerStats =
|
||||
asks (serverStatsFile . config)
|
||||
>>= mapM_ (\f -> asks serverStats >>= atomically . getServerStatsData >>= liftIO . saveStats f)
|
||||
where
|
||||
saveStats f stats = do
|
||||
logInfo $ "saving server stats to file " <> T.pack f
|
||||
B.writeFile f $ strEncode stats
|
||||
logInfo "server stats saved"
|
||||
|
||||
restoreServerStats :: (MonadUnliftIO m, MonadReader Env m) => m ()
|
||||
restoreServerStats = asks (serverStatsFile . config) >>= mapM_ restoreStats
|
||||
where
|
||||
restoreStats f = whenM (doesFileExist f) $ do
|
||||
logInfo $ "restoring server stats from file " <> T.pack f
|
||||
liftIO (strDecode <$> B.readFile f) >>= \case
|
||||
Right d -> do
|
||||
s <- asks serverStats
|
||||
atomically $ setServerStatsData s d
|
||||
renameFile f $ f <> ".bak"
|
||||
logInfo "server stats restored"
|
||||
Left e -> logInfo $ "error restoring server stats: " <> T.pack e
|
||||
|
|
|
@ -0,0 +1,284 @@
|
|||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE DuplicateRecordFields #-}
|
||||
{-# LANGUAGE FlexibleContexts #-}
|
||||
{-# LANGUAGE GADTs #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE TypeApplications #-}
|
||||
|
||||
module Simplex.Messaging.Server.CLI where
|
||||
|
||||
import Control.Monad
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import Data.Either (fromRight)
|
||||
import Data.Ini (Ini, lookupValue, readIniFile)
|
||||
import Data.Maybe (fromMaybe)
|
||||
import qualified Data.Text as T
|
||||
import Data.X509.Validation (Fingerprint (..))
|
||||
import Network.Socket (HostName, ServiceName)
|
||||
import Options.Applicative
|
||||
import Simplex.Messaging.Encoding.String
|
||||
import Simplex.Messaging.Transport (ATransport (..), TLS, Transport (..))
|
||||
import Simplex.Messaging.Transport.Server (loadFingerprint)
|
||||
import Simplex.Messaging.Transport.WebSockets (WS)
|
||||
import System.Directory (createDirectoryIfMissing, doesDirectoryExist, doesFileExist, removeDirectoryRecursive)
|
||||
import System.Exit (exitFailure)
|
||||
import System.FilePath (combine)
|
||||
import System.IO (BufferMode (..), IOMode (..), hFlush, hGetLine, hSetBuffering, stderr, stdout, withFile)
|
||||
import System.Process (readCreateProcess, shell)
|
||||
import Text.Read (readMaybe)
|
||||
|
||||
data ServerCLIConfig cfg = ServerCLIConfig
|
||||
{ cfgDir :: FilePath,
|
||||
logDir :: FilePath,
|
||||
iniFile :: FilePath,
|
||||
storeLogFile :: FilePath,
|
||||
caKeyFile :: FilePath,
|
||||
caCrtFile :: FilePath,
|
||||
serverKeyFile :: FilePath,
|
||||
serverCrtFile :: FilePath,
|
||||
fingerprintFile :: FilePath,
|
||||
defaultServerPort :: ServiceName,
|
||||
executableName :: String,
|
||||
serverVersion :: String,
|
||||
mkIniFile :: Bool -> ServiceName -> String,
|
||||
mkServerConfig :: Maybe FilePath -> [(ServiceName, ATransport)] -> Ini -> cfg
|
||||
}
|
||||
|
||||
protocolServerCLI :: ServerCLIConfig cfg -> (cfg -> IO ()) -> IO ()
|
||||
protocolServerCLI cliCfg@ServerCLIConfig {iniFile, executableName} server =
|
||||
getCliCommand cliCfg >>= \case
|
||||
Init opts ->
|
||||
doesFileExist iniFile >>= \case
|
||||
True -> exitError $ "Error: server is already initialized (" <> iniFile <> " exists).\nRun `" <> executableName <> " start`."
|
||||
_ -> initializeServer cliCfg opts
|
||||
Start ->
|
||||
doesFileExist iniFile >>= \case
|
||||
True -> readIniFile iniFile >>= either exitError (runServer cliCfg server)
|
||||
_ -> exitError $ "Error: server is not initialized (" <> iniFile <> " does not exist).\nRun `" <> executableName <> " init`."
|
||||
Delete -> do
|
||||
confirmOrExit "WARNING: deleting the server will make all queues inaccessible, because the server identity (certificate fingerprint) will change.\nTHIS CANNOT BE UNDONE!"
|
||||
cleanup cliCfg
|
||||
putStrLn "Deleted configuration and log files"
|
||||
|
||||
exitError :: String -> IO ()
|
||||
exitError msg = putStrLn msg >> exitFailure
|
||||
|
||||
confirmOrExit :: String -> IO ()
|
||||
confirmOrExit s = do
|
||||
putStrLn s
|
||||
putStr "Continue (Y/n): "
|
||||
hFlush stdout
|
||||
ok <- getLine
|
||||
when (ok /= "Y") exitFailure
|
||||
|
||||
data CliCommand
|
||||
= Init InitOptions
|
||||
| Start
|
||||
| Delete
|
||||
|
||||
data InitOptions = InitOptions
|
||||
{ enableStoreLog :: Bool,
|
||||
signAlgorithm :: SignAlgorithm,
|
||||
ip :: HostName,
|
||||
fqdn :: Maybe HostName
|
||||
}
|
||||
deriving (Show)
|
||||
|
||||
data SignAlgorithm = ED448 | ED25519
|
||||
deriving (Read, Show)
|
||||
|
||||
getCliCommand :: ServerCLIConfig cfg -> IO CliCommand
|
||||
getCliCommand cliCfg =
|
||||
customExecParser
|
||||
(prefs showHelpOnEmpty)
|
||||
( info
|
||||
(helper <*> versionOption <*> cliCommandP cliCfg)
|
||||
(header version <> fullDesc)
|
||||
)
|
||||
where
|
||||
versionOption = infoOption version (long "version" <> short 'v' <> help "Show version")
|
||||
version = serverVersion cliCfg
|
||||
|
||||
cliCommandP :: ServerCLIConfig cfg -> Parser CliCommand
|
||||
cliCommandP ServerCLIConfig {cfgDir, logDir, iniFile} =
|
||||
hsubparser
|
||||
( command "init" (info initP (progDesc $ "Initialize server - creates " <> cfgDir <> " and " <> logDir <> " directories and configuration files"))
|
||||
<> command "start" (info (pure Start) (progDesc $ "Start server (configuration: " <> iniFile <> ")"))
|
||||
<> command "delete" (info (pure Delete) (progDesc "Delete configuration and log files"))
|
||||
)
|
||||
where
|
||||
initP :: Parser CliCommand
|
||||
initP =
|
||||
Init
|
||||
<$> ( InitOptions
|
||||
<$> switch
|
||||
( long "store-log"
|
||||
<> short 'l'
|
||||
<> help "Enable store log for persistence"
|
||||
)
|
||||
<*> option
|
||||
(maybeReader readMaybe)
|
||||
( long "sign-algorithm"
|
||||
<> short 'a'
|
||||
<> help "Signature algorithm used for TLS certificates: ED25519, ED448"
|
||||
<> value ED448
|
||||
<> showDefault
|
||||
<> metavar "ALG"
|
||||
)
|
||||
<*> strOption
|
||||
( long "ip"
|
||||
<> help
|
||||
"Server IP address, used as Common Name for TLS online certificate if FQDN is not supplied"
|
||||
<> value "127.0.0.1"
|
||||
<> showDefault
|
||||
<> metavar "IP"
|
||||
)
|
||||
<*> (optional . strOption)
|
||||
( long "fqdn"
|
||||
<> short 'n'
|
||||
<> help "Server FQDN used as Common Name for TLS online certificate"
|
||||
<> showDefault
|
||||
<> metavar "FQDN"
|
||||
)
|
||||
)
|
||||
|
||||
initializeServer :: ServerCLIConfig cfg -> InitOptions -> IO ()
|
||||
initializeServer cliCfg InitOptions {enableStoreLog, signAlgorithm, ip, fqdn} = do
|
||||
cleanup cliCfg
|
||||
createDirectoryIfMissing True cfgDir
|
||||
createDirectoryIfMissing True logDir
|
||||
createX509
|
||||
fp <- saveFingerprint
|
||||
writeFile iniFile $ mkIniFile enableStoreLog defaultServerPort
|
||||
putStrLn $ "Server initialized, you can modify configuration in " <> iniFile <> ".\nRun `" <> executableName <> " start` to start server."
|
||||
printServiceInfo cliCfg fp
|
||||
warnCAPrivateKeyFile
|
||||
where
|
||||
ServerCLIConfig {cfgDir, logDir, iniFile, executableName, caKeyFile, caCrtFile, serverKeyFile, serverCrtFile, fingerprintFile, defaultServerPort, mkIniFile} = cliCfg
|
||||
createX509 = do
|
||||
createOpensslCaConf
|
||||
createOpensslServerConf
|
||||
-- CA certificate (identity/offline)
|
||||
run $ "openssl genpkey -algorithm " <> show signAlgorithm <> " -out " <> caKeyFile
|
||||
run $ "openssl req -new -x509 -days 999999 -config " <> opensslCaConfFile <> " -extensions v3 -key " <> caKeyFile <> " -out " <> caCrtFile
|
||||
-- server certificate (online)
|
||||
run $ "openssl genpkey -algorithm " <> show signAlgorithm <> " -out " <> serverKeyFile
|
||||
run $ "openssl req -new -config " <> opensslServerConfFile <> " -reqexts v3 -key " <> serverKeyFile <> " -out " <> serverCsrFile
|
||||
run $ "openssl x509 -req -days 999999 -extfile " <> opensslServerConfFile <> " -extensions v3 -in " <> serverCsrFile <> " -CA " <> caCrtFile <> " -CAkey " <> caKeyFile <> " -CAcreateserial -out " <> serverCrtFile
|
||||
where
|
||||
run cmd = void $ readCreateProcess (shell cmd) ""
|
||||
opensslCaConfFile = combine cfgDir "openssl_ca.conf"
|
||||
opensslServerConfFile = combine cfgDir "openssl_server.conf"
|
||||
serverCsrFile = combine cfgDir "server.csr"
|
||||
createOpensslCaConf =
|
||||
writeFile
|
||||
opensslCaConfFile
|
||||
"[req]\n\
|
||||
\distinguished_name = req_distinguished_name\n\
|
||||
\prompt = no\n\n\
|
||||
\[req_distinguished_name]\n\
|
||||
\CN = SMP server CA\n\
|
||||
\O = SimpleX\n\n\
|
||||
\[v3]\n\
|
||||
\subjectKeyIdentifier = hash\n\
|
||||
\authorityKeyIdentifier = keyid:always\n\
|
||||
\basicConstraints = critical,CA:true\n"
|
||||
-- TODO revise https://www.rfc-editor.org/rfc/rfc5280#section-4.2.1.3, https://www.rfc-editor.org/rfc/rfc3279#section-2.3.5
|
||||
-- IP and FQDN can't both be used as server address interchangeably even if IP is added
|
||||
-- as Subject Alternative Name, unless the following validation hook is disabled:
|
||||
-- https://hackage.haskell.org/package/x509-validation-1.6.10/docs/src/Data-X509-Validation.html#validateCertificateName
|
||||
createOpensslServerConf =
|
||||
writeFile
|
||||
opensslServerConfFile
|
||||
( "[req]\n\
|
||||
\distinguished_name = req_distinguished_name\n\
|
||||
\prompt = no\n\n\
|
||||
\[req_distinguished_name]\n"
|
||||
<> ("CN = " <> cn <> "\n\n")
|
||||
<> "[v3]\n\
|
||||
\basicConstraints = CA:FALSE\n\
|
||||
\keyUsage = digitalSignature, nonRepudiation, keyAgreement\n\
|
||||
\extendedKeyUsage = serverAuth\n"
|
||||
)
|
||||
where
|
||||
cn = fromMaybe ip fqdn
|
||||
|
||||
saveFingerprint = do
|
||||
Fingerprint fp <- loadFingerprint caCrtFile
|
||||
withFile fingerprintFile WriteMode (`B.hPutStrLn` strEncode fp)
|
||||
pure fp
|
||||
|
||||
warnCAPrivateKeyFile =
|
||||
putStrLn $
|
||||
"----------\n\
|
||||
\You should store CA private key securely and delete it from the server.\n\
|
||||
\If server TLS credential is compromised this key can be used to sign a new one, \
|
||||
\keeping the same server identity and established connections.\n\
|
||||
\CA private key location:\n"
|
||||
<> caKeyFile
|
||||
<> "\n----------"
|
||||
|
||||
data IniOptions = IniOptions
|
||||
{ enableStoreLog :: Bool,
|
||||
port :: ServiceName,
|
||||
enableWebsockets :: Bool
|
||||
}
|
||||
|
||||
mkIniOptions :: Ini -> IniOptions
|
||||
mkIniOptions ini =
|
||||
IniOptions
|
||||
{ enableStoreLog = (== "on") $ strictIni "STORE_LOG" "enable" ini,
|
||||
port = T.unpack $ strictIni "TRANSPORT" "port" ini,
|
||||
enableWebsockets = (== "on") $ strictIni "TRANSPORT" "websockets" ini
|
||||
}
|
||||
|
||||
strictIni :: String -> String -> Ini -> T.Text
|
||||
strictIni section key ini =
|
||||
fromRight (error ("no key " <> key <> " in section " <> section)) $
|
||||
lookupValue (T.pack section) (T.pack key) ini
|
||||
|
||||
readStrictIni :: Read a => String -> String -> Ini -> a
|
||||
readStrictIni section key = read . T.unpack . strictIni section key
|
||||
|
||||
runServer :: ServerCLIConfig cfg -> (cfg -> IO ()) -> Ini -> IO ()
|
||||
runServer cliCfg server ini = do
|
||||
hSetBuffering stdout LineBuffering
|
||||
hSetBuffering stderr LineBuffering
|
||||
fp <- checkSavedFingerprint
|
||||
printServiceInfo cliCfg fp
|
||||
let IniOptions {enableStoreLog, port, enableWebsockets} = mkIniOptions ini
|
||||
transports = (port, transport @TLS) : [("80", transport @WS) | enableWebsockets]
|
||||
logFile = if enableStoreLog then Just storeLogFile else Nothing
|
||||
cfg = mkServerConfig logFile transports ini
|
||||
printServerConfig logFile transports
|
||||
server cfg
|
||||
where
|
||||
ServerCLIConfig {storeLogFile, caCrtFile, fingerprintFile, mkServerConfig} = cliCfg
|
||||
checkSavedFingerprint = do
|
||||
savedFingerprint <- withFile fingerprintFile ReadMode hGetLine
|
||||
Fingerprint fp <- loadFingerprint caCrtFile
|
||||
when (B.pack savedFingerprint /= strEncode fp) $
|
||||
exitError "Stored fingerprint is invalid."
|
||||
pure fp
|
||||
|
||||
printServerConfig logFile transports = do
|
||||
putStrLn $ case logFile of
|
||||
Just f -> "Store log: " <> f
|
||||
_ -> "Store log disabled."
|
||||
forM_ transports $ \(p, ATransport t) ->
|
||||
putStrLn $ "Listening on port " <> p <> " (" <> transportName t <> ")..."
|
||||
|
||||
cleanup :: ServerCLIConfig cfg -> IO ()
|
||||
cleanup ServerCLIConfig {cfgDir, logDir} = do
|
||||
deleteDirIfExists cfgDir
|
||||
deleteDirIfExists logDir
|
||||
where
|
||||
deleteDirIfExists path = doesDirectoryExist path >>= (`when` removeDirectoryRecursive path)
|
||||
|
||||
printServiceInfo :: ServerCLIConfig cfg -> ByteString -> IO ()
|
||||
printServiceInfo ServerCLIConfig {serverVersion} fpStr = do
|
||||
putStrLn serverVersion
|
||||
B.putStrLn $ "Fingerprint: " <> strEncode fpStr
|
|
@ -9,24 +9,29 @@ import Control.Concurrent (ThreadId)
|
|||
import Control.Monad.IO.Unlift
|
||||
import Crypto.Random
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import Data.Int (Int64)
|
||||
import Data.Map.Strict (Map)
|
||||
import qualified Data.Map.Strict as M
|
||||
import Data.Time.Clock (getCurrentTime)
|
||||
import Data.Time.Clock.System (SystemTime)
|
||||
import Data.X509.Validation (Fingerprint (..))
|
||||
import Network.Socket (ServiceName)
|
||||
import qualified Network.TLS as T
|
||||
import Numeric.Natural
|
||||
import Simplex.Messaging.Crypto (KeyHash (..))
|
||||
import Simplex.Messaging.Protocol
|
||||
import Simplex.Messaging.Server.Expiration
|
||||
import Simplex.Messaging.Server.MsgStore.STM
|
||||
import Simplex.Messaging.Server.QueueStore (QueueRec (..))
|
||||
import Simplex.Messaging.Server.QueueStore (NtfCreds (..), QueueRec (..))
|
||||
import Simplex.Messaging.Server.QueueStore.STM
|
||||
import Simplex.Messaging.Server.Stats
|
||||
import Simplex.Messaging.Server.StoreLog
|
||||
import Simplex.Messaging.TMap (TMap)
|
||||
import qualified Simplex.Messaging.TMap as TM
|
||||
import Simplex.Messaging.Transport (ATransport)
|
||||
import Simplex.Messaging.Transport.Server (loadFingerprint, loadTLSServerParams)
|
||||
import Simplex.Messaging.Version
|
||||
import System.IO (IOMode (..))
|
||||
import System.Mem.Weak (Weak)
|
||||
import UnliftIO.STM
|
||||
|
||||
data ServerConfig = ServerConfig
|
||||
|
@ -36,19 +41,44 @@ data ServerConfig = ServerConfig
|
|||
msgQueueQuota :: Natural,
|
||||
queueIdBytes :: Int,
|
||||
msgIdBytes :: Int,
|
||||
storeLog :: Maybe (StoreLog 'ReadMode),
|
||||
storeLogFile :: Maybe FilePath,
|
||||
storeMsgsFile :: Maybe FilePath,
|
||||
-- | set to False to prohibit creating new queues
|
||||
allowNewQueues :: Bool,
|
||||
-- | time after which the messages can be removed from the queues, seconds
|
||||
messageTTL :: Maybe Int64,
|
||||
-- | interval to periodically remove expired messages (when no messages are sent to the queue), microseconds
|
||||
expireMessagesInterval :: Maybe Int,
|
||||
-- CA certificate private key is not needed for initialization
|
||||
-- | time after which the messages can be removed from the queues and check interval, seconds
|
||||
messageExpiration :: Maybe ExpirationConfig,
|
||||
-- | time after which the socket with inactive client can be disconnected (without any messages or commands, incl. PING),
|
||||
-- and check interval, seconds
|
||||
inactiveClientExpiration :: Maybe ExpirationConfig,
|
||||
-- | log SMP server usage statistics, only aggregates are logged, seconds
|
||||
logStatsInterval :: Maybe Int,
|
||||
-- | time of the day when the stats are logged first, to log at consistent times,
|
||||
-- irrespective of when the server is started (seconds from 00:00 UTC)
|
||||
logStatsStartTime :: Int,
|
||||
-- | file to save and restore stats
|
||||
serverStatsFile :: Maybe FilePath,
|
||||
-- | CA certificate private key is not needed for initialization
|
||||
caCertificateFile :: FilePath,
|
||||
privateKeyFile :: FilePath,
|
||||
certificateFile :: FilePath
|
||||
certificateFile :: FilePath,
|
||||
-- | SMP client-server protocol version range
|
||||
smpServerVRange :: VersionRange
|
||||
}
|
||||
|
||||
defaultMessageExpiration :: ExpirationConfig
|
||||
defaultMessageExpiration =
|
||||
ExpirationConfig
|
||||
{ ttl = 30 * 86400, -- seconds, 30 days
|
||||
checkInterval = 43200 -- seconds, 12 hours
|
||||
}
|
||||
|
||||
defaultInactiveClientExpiration :: ExpirationConfig
|
||||
defaultInactiveClientExpiration =
|
||||
ExpirationConfig
|
||||
{ ttl = 86400, -- seconds, 24 hours
|
||||
checkInterval = 43200 -- seconds, 12 hours
|
||||
}
|
||||
|
||||
data Env = Env
|
||||
{ config :: ServerConfig,
|
||||
server :: Server,
|
||||
|
@ -57,7 +87,8 @@ data Env = Env
|
|||
msgStore :: STMMsgStore,
|
||||
idsDrg :: TVar ChaChaDRG,
|
||||
storeLog :: Maybe (StoreLog 'WriteMode),
|
||||
tlsServerParams :: T.ServerParams
|
||||
tlsServerParams :: T.ServerParams,
|
||||
serverStats :: ServerStats
|
||||
}
|
||||
|
||||
data Server = Server
|
||||
|
@ -68,19 +99,21 @@ data Server = Server
|
|||
}
|
||||
|
||||
data Client = Client
|
||||
{ subscriptions :: TMap RecipientId Sub,
|
||||
{ subscriptions :: TMap RecipientId (TVar Sub),
|
||||
ntfSubscriptions :: TMap NotifierId (),
|
||||
rcvQ :: TBQueue (Transmission Cmd),
|
||||
sndQ :: TBQueue (Transmission BrokerMsg),
|
||||
thVersion :: Version,
|
||||
sessionId :: ByteString,
|
||||
connected :: TVar Bool
|
||||
connected :: TVar Bool,
|
||||
activeAt :: TVar SystemTime
|
||||
}
|
||||
|
||||
data SubscriptionThread = NoSub | SubPending | SubThread ThreadId
|
||||
data SubscriptionThread = NoSub | SubPending | SubThread (Weak ThreadId) | ProhibitSub
|
||||
|
||||
data Sub = Sub
|
||||
{ subThread :: SubscriptionThread,
|
||||
delivered :: TMVar ()
|
||||
delivered :: TMVar MsgId
|
||||
}
|
||||
|
||||
newServer :: Natural -> STM Server
|
||||
|
@ -91,31 +124,34 @@ newServer qSize = do
|
|||
notifiers <- TM.empty
|
||||
return Server {subscribedQ, subscribers, ntfSubscribedQ, notifiers}
|
||||
|
||||
newClient :: Natural -> ByteString -> STM Client
|
||||
newClient qSize sessionId = do
|
||||
newClient :: Natural -> Version -> ByteString -> SystemTime -> STM Client
|
||||
newClient qSize thVersion sessionId ts = do
|
||||
subscriptions <- TM.empty
|
||||
ntfSubscriptions <- TM.empty
|
||||
rcvQ <- newTBQueue qSize
|
||||
sndQ <- newTBQueue qSize
|
||||
connected <- newTVar True
|
||||
return Client {subscriptions, ntfSubscriptions, rcvQ, sndQ, sessionId, connected}
|
||||
activeAt <- newTVar ts
|
||||
return Client {subscriptions, ntfSubscriptions, rcvQ, sndQ, thVersion, sessionId, connected, activeAt}
|
||||
|
||||
newSubscription :: STM Sub
|
||||
newSubscription = do
|
||||
newSubscription :: SubscriptionThread -> STM Sub
|
||||
newSubscription subThread = do
|
||||
delivered <- newEmptyTMVar
|
||||
return Sub {subThread = NoSub, delivered}
|
||||
return Sub {subThread, delivered}
|
||||
|
||||
newEnv :: forall m. (MonadUnliftIO m, MonadRandom m) => ServerConfig -> m Env
|
||||
newEnv config@ServerConfig {caCertificateFile, certificateFile, privateKeyFile} = do
|
||||
newEnv config@ServerConfig {caCertificateFile, certificateFile, privateKeyFile, storeLogFile} = do
|
||||
server <- atomically $ newServer (serverTbqSize config)
|
||||
queueStore <- atomically newQueueStore
|
||||
msgStore <- atomically newMsgStore
|
||||
idsDrg <- drgNew >>= newTVarIO
|
||||
s' <- restoreQueues queueStore `mapM` storeLog (config :: ServerConfig)
|
||||
storeLog <- liftIO $ openReadStoreLog `mapM` storeLogFile
|
||||
s' <- restoreQueues queueStore `mapM` storeLog
|
||||
tlsServerParams <- liftIO $ loadTLSServerParams caCertificateFile certificateFile privateKeyFile
|
||||
Fingerprint fp <- liftIO $ loadFingerprint caCertificateFile
|
||||
let serverIdentity = KeyHash fp
|
||||
return Env {config, server, serverIdentity, queueStore, msgStore, idsDrg, storeLog = s', tlsServerParams}
|
||||
serverStats <- atomically . newServerStats =<< liftIO getCurrentTime
|
||||
return Env {config, server, serverIdentity, queueStore, msgStore, idsDrg, storeLog = s', tlsServerParams, serverStats}
|
||||
where
|
||||
restoreQueues :: QueueStore -> StoreLog 'ReadMode -> m (StoreLog 'WriteMode)
|
||||
restoreQueues QueueStore {queues, senders, notifiers} s = do
|
||||
|
@ -130,4 +166,4 @@ newEnv config@ServerConfig {caCertificateFile, certificateFile, privateKeyFile}
|
|||
addNotifier :: QueueRec -> Map NotifierId RecipientId -> Map NotifierId RecipientId
|
||||
addNotifier q = case notifier q of
|
||||
Nothing -> id
|
||||
Just (nId, _) -> M.insert nId (recipientId q)
|
||||
Just NtfCreds {notifierId} -> M.insert notifierId (recipientId q)
|
||||
|
|
|
@ -0,0 +1,17 @@
|
|||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
|
||||
module Simplex.Messaging.Server.Expiration where
|
||||
|
||||
import Control.Monad.IO.Class
|
||||
import Data.Int (Int64)
|
||||
import Data.Time.Clock.System (SystemTime (..), getSystemTime)
|
||||
|
||||
data ExpirationConfig = ExpirationConfig
|
||||
{ -- time after which the entity can be expired, seconds
|
||||
ttl :: Int64,
|
||||
-- interval to check expiration, seconds
|
||||
checkInterval :: Int
|
||||
}
|
||||
|
||||
expireBeforeEpoch :: ExpirationConfig -> IO Int64
|
||||
expireBeforeEpoch ExpirationConfig {ttl} = subtract ttl . systemSeconds <$> liftIO getSystemTime
|
|
@ -1,26 +1,33 @@
|
|||
{-# LANGUAGE FunctionalDependencies #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
|
||||
module Simplex.Messaging.Server.MsgStore where
|
||||
|
||||
import Control.Applicative ((<|>))
|
||||
import Data.Int (Int64)
|
||||
import Data.Time.Clock.System (SystemTime)
|
||||
import Numeric.Natural
|
||||
import Simplex.Messaging.Protocol (MsgBody, MsgId, RecipientId)
|
||||
import Simplex.Messaging.Encoding.String
|
||||
import Simplex.Messaging.Protocol (Message (..), MsgId, RcvMessage (..), RecipientId)
|
||||
|
||||
data Message = Message
|
||||
{ msgId :: MsgId,
|
||||
ts :: SystemTime,
|
||||
msgBody :: MsgBody
|
||||
}
|
||||
data MsgLogRecord = MLRv3 RecipientId Message | MLRv1 RecipientId RcvMessage
|
||||
|
||||
instance StrEncoding MsgLogRecord where
|
||||
strEncode = \case
|
||||
MLRv3 rId msg -> strEncode (Str "v3", rId, msg)
|
||||
MLRv1 rId msg -> strEncode (rId, msg)
|
||||
strP = "v3 " *> (MLRv3 <$> strP_ <*> strP) <|> MLRv1 <$> strP_ <*> strP
|
||||
|
||||
class MonadMsgStore s q m | s -> q where
|
||||
getMsgQueue :: s -> RecipientId -> Natural -> m q
|
||||
delMsgQueue :: s -> RecipientId -> m ()
|
||||
flushMsgQueue :: s -> RecipientId -> m [Message]
|
||||
|
||||
class MonadMsgQueue q m where
|
||||
isFull :: q -> m Bool
|
||||
writeMsg :: q -> Message -> m () -- non blocking
|
||||
tryPeekMsg :: q -> m (Maybe Message) -- non blocking
|
||||
peekMsg :: q -> m Message -- blocking
|
||||
tryDelPeekMsg :: q -> m (Maybe Message) -- atomic delete (== read) last and peek next message, if available
|
||||
tryDelMsg :: q -> MsgId -> m Bool -- non blocking
|
||||
tryDelPeekMsg :: q -> MsgId -> m (Bool, Maybe Message) -- atomic delete (== read) last and peek next message, if available
|
||||
deleteExpiredMsgs :: q -> Int64 -> m ()
|
||||
|
|
|
@ -2,16 +2,21 @@
|
|||
{-# LANGUAGE FlexibleContexts #-}
|
||||
{-# LANGUAGE FlexibleInstances #-}
|
||||
{-# LANGUAGE InstanceSigs #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE MultiParamTypeClasses #-}
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
{-# LANGUAGE TupleSections #-}
|
||||
|
||||
module Simplex.Messaging.Server.MsgStore.STM where
|
||||
|
||||
import Control.Concurrent.STM.TBQueue (flushTBQueue)
|
||||
import Control.Monad (when)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import Data.Functor (($>))
|
||||
import Data.Int (Int64)
|
||||
import Data.Time.Clock.System (SystemTime (systemSeconds))
|
||||
import Numeric.Natural
|
||||
import Simplex.Messaging.Protocol (RecipientId)
|
||||
import Simplex.Messaging.Protocol (Message (..), MsgId, RecipientId)
|
||||
import Simplex.Messaging.Server.MsgStore
|
||||
import Simplex.Messaging.TMap (TMap)
|
||||
import qualified Simplex.Messaging.TMap as TM
|
||||
|
@ -36,6 +41,9 @@ instance MonadMsgStore STMMsgStore MsgQueue STM where
|
|||
delMsgQueue :: STMMsgStore -> RecipientId -> STM ()
|
||||
delMsgQueue st rId = TM.delete rId st
|
||||
|
||||
flushMsgQueue :: STMMsgStore -> RecipientId -> STM [Message]
|
||||
flushMsgQueue st rId = TM.lookup rId st >>= maybe (pure []) (flushTBQueue . msgQueue)
|
||||
|
||||
instance MonadMsgQueue MsgQueue STM where
|
||||
isFull :: MsgQueue -> STM Bool
|
||||
isFull = isFullTBQueue . msgQueue
|
||||
|
@ -49,14 +57,27 @@ instance MonadMsgQueue MsgQueue STM where
|
|||
peekMsg :: MsgQueue -> STM Message
|
||||
peekMsg = peekTBQueue . msgQueue
|
||||
|
||||
tryDelMsg :: MsgQueue -> MsgId -> STM Bool
|
||||
tryDelMsg (MsgQueue q) msgId' =
|
||||
tryPeekTBQueue q >>= \case
|
||||
Just Message {msgId}
|
||||
| msgId == msgId' || B.null msgId' -> tryReadTBQueue q $> True
|
||||
| otherwise -> pure False
|
||||
_ -> pure False
|
||||
|
||||
-- atomic delete (== read) last and peek next message if available
|
||||
tryDelPeekMsg :: MsgQueue -> STM (Maybe Message)
|
||||
tryDelPeekMsg (MsgQueue q) = tryReadTBQueue q >> tryPeekTBQueue q
|
||||
tryDelPeekMsg :: MsgQueue -> MsgId -> STM (Bool, Maybe Message)
|
||||
tryDelPeekMsg (MsgQueue q) msgId' =
|
||||
tryPeekTBQueue q >>= \case
|
||||
msg_@(Just Message {msgId})
|
||||
| msgId == msgId' || B.null msgId' -> (True,) <$> (tryReadTBQueue q >> tryPeekTBQueue q)
|
||||
| otherwise -> pure (False, msg_)
|
||||
_ -> pure (False, Nothing)
|
||||
|
||||
deleteExpiredMsgs :: MsgQueue -> Int64 -> STM ()
|
||||
deleteExpiredMsgs (MsgQueue q) old = loop
|
||||
where
|
||||
loop = tryPeekTBQueue q >>= mapM_ delOldMsg
|
||||
delOldMsg Message {ts} =
|
||||
when (systemSeconds ts < old) $
|
||||
delOldMsg Message {msgTs} =
|
||||
when (systemSeconds msgTs < old) $
|
||||
tryReadTBQueue q >> loop
|
||||
|
|
|
@ -1,9 +1,11 @@
|
|||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE KindSignatures #-}
|
||||
{-# LANGUAGE MultiParamTypeClasses #-}
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
|
||||
module Simplex.Messaging.Server.QueueStore where
|
||||
|
||||
import Simplex.Messaging.Encoding.String
|
||||
import Simplex.Messaging.Protocol
|
||||
|
||||
data QueueRec = QueueRec
|
||||
|
@ -12,17 +14,31 @@ data QueueRec = QueueRec
|
|||
rcvDhSecret :: RcvDhSecret,
|
||||
senderId :: SenderId,
|
||||
senderKey :: Maybe SndPublicVerifyKey,
|
||||
notifier :: Maybe (NotifierId, NtfPublicVerifyKey),
|
||||
notifier :: Maybe NtfCreds,
|
||||
status :: QueueStatus
|
||||
}
|
||||
deriving (Eq, Show)
|
||||
|
||||
data NtfCreds = NtfCreds
|
||||
{ notifierId :: NotifierId,
|
||||
notifierKey :: NtfPublicVerifyKey,
|
||||
rcvNtfDhSecret :: RcvNtfDhSecret
|
||||
}
|
||||
deriving (Eq, Show)
|
||||
|
||||
instance StrEncoding NtfCreds where
|
||||
strEncode NtfCreds {notifierId, notifierKey, rcvNtfDhSecret} = strEncode (notifierId, notifierKey, rcvNtfDhSecret)
|
||||
strP = do
|
||||
(notifierId, notifierKey, rcvNtfDhSecret) <- strP
|
||||
pure NtfCreds {notifierId, notifierKey, rcvNtfDhSecret}
|
||||
|
||||
data QueueStatus = QueueActive | QueueOff deriving (Eq, Show)
|
||||
|
||||
class MonadQueueStore s m where
|
||||
addQueue :: s -> QueueRec -> m (Either ErrorType ())
|
||||
getQueue :: s -> SParty p -> QueueId -> m (Either ErrorType QueueRec)
|
||||
secureQueue :: s -> RecipientId -> SndPublicVerifyKey -> m (Either ErrorType QueueRec)
|
||||
addQueueNotifier :: s -> RecipientId -> NotifierId -> NtfPublicVerifyKey -> m (Either ErrorType QueueRec)
|
||||
addQueueNotifier :: s -> RecipientId -> NtfCreds -> m (Either ErrorType QueueRec)
|
||||
deleteQueueNotifier :: s -> RecipientId -> m (Either ErrorType ())
|
||||
suspendQueue :: s -> RecipientId -> m (Either ErrorType ())
|
||||
deleteQueue :: s -> RecipientId -> m (Either ErrorType ())
|
||||
|
|
|
@ -18,7 +18,7 @@ import Simplex.Messaging.Protocol
|
|||
import Simplex.Messaging.Server.QueueStore
|
||||
import Simplex.Messaging.TMap (TMap)
|
||||
import qualified Simplex.Messaging.TMap as TM
|
||||
import Simplex.Messaging.Util (ifM)
|
||||
import Simplex.Messaging.Util (ifM, ($>>=))
|
||||
import UnliftIO.STM
|
||||
|
||||
data QueueStore = QueueStore
|
||||
|
@ -51,9 +51,8 @@ instance MonadQueueStore QueueStore STM where
|
|||
where
|
||||
getVar = case party of
|
||||
SRecipient -> TM.lookup qId queues
|
||||
SSender -> TM.lookup qId senders >>= get
|
||||
SNotifier -> TM.lookup qId notifiers >>= get
|
||||
get = fmap join . mapM (`TM.lookup` queues)
|
||||
SSender -> TM.lookup qId senders $>>= (`TM.lookup` queues)
|
||||
SNotifier -> TM.lookup qId notifiers $>>= (`TM.lookup` queues)
|
||||
|
||||
secureQueue :: QueueStore -> RecipientId -> SndPublicVerifyKey -> STM (Either ErrorType QueueRec)
|
||||
secureQueue QueueStore {queues} rId sKey =
|
||||
|
@ -62,16 +61,23 @@ instance MonadQueueStore QueueStore STM where
|
|||
Just _ -> pure Nothing
|
||||
_ -> writeTVar qVar q {senderKey = Just sKey} $> Just q
|
||||
|
||||
addQueueNotifier :: QueueStore -> RecipientId -> NotifierId -> NtfPublicVerifyKey -> STM (Either ErrorType QueueRec)
|
||||
addQueueNotifier QueueStore {queues, notifiers} rId nId nKey = do
|
||||
addQueueNotifier :: QueueStore -> RecipientId -> NtfCreds -> STM (Either ErrorType QueueRec)
|
||||
addQueueNotifier QueueStore {queues, notifiers} rId ntfCreds@NtfCreds {notifierId = nId} = do
|
||||
ifM (TM.member nId notifiers) (pure $ Left DUPLICATE_) $
|
||||
withQueue rId queues $ \qVar ->
|
||||
readTVar qVar >>= \q -> case notifier q of
|
||||
Just _ -> pure Nothing
|
||||
_ -> do
|
||||
writeTVar qVar q {notifier = Just (nId, nKey)}
|
||||
TM.insert nId rId notifiers
|
||||
pure $ Just q
|
||||
withQueue rId queues $ \qVar -> do
|
||||
q <- readTVar qVar
|
||||
forM_ (notifier q) $ (`TM.delete` notifiers) . notifierId
|
||||
writeTVar qVar q {notifier = Just ntfCreds}
|
||||
TM.insert nId rId notifiers
|
||||
pure $ Just q
|
||||
|
||||
deleteQueueNotifier :: QueueStore -> RecipientId -> STM (Either ErrorType ())
|
||||
deleteQueueNotifier QueueStore {queues, notifiers} rId =
|
||||
withQueue rId queues $ \qVar -> do
|
||||
q <- readTVar qVar
|
||||
forM_ (notifier q) $ \NtfCreds {notifierId} -> TM.delete notifierId notifiers
|
||||
writeTVar qVar q {notifier = Nothing}
|
||||
pure $ Just ()
|
||||
|
||||
suspendQueue :: QueueStore -> RecipientId -> STM (Either ErrorType ())
|
||||
suspendQueue QueueStore {queues} rId =
|
||||
|
@ -83,7 +89,7 @@ instance MonadQueueStore QueueStore STM where
|
|||
Just qVar ->
|
||||
readTVar qVar >>= \q -> do
|
||||
TM.delete (senderId q) senders
|
||||
forM_ (notifier q) $ \(nId, _) -> TM.delete nId notifiers
|
||||
forM_ (notifier q) $ \NtfCreds {notifierId} -> TM.delete notifierId notifiers
|
||||
pure $ Right ()
|
||||
_ -> pure $ Left AUTH
|
||||
|
||||
|
@ -91,4 +97,4 @@ toResult :: Maybe a -> Either ErrorType a
|
|||
toResult = maybe (Left AUTH) Right
|
||||
|
||||
withQueue :: RecipientId -> TMap RecipientId (TVar QueueRec) -> (TVar QueueRec -> STM (Maybe a)) -> STM (Either ErrorType a)
|
||||
withQueue rId queues f = toResult <$> (TM.lookup rId queues >>= fmap join . mapM f)
|
||||
withQueue rId queues f = toResult <$> TM.lookup rId queues $>>= f
|
||||
|
|
|
@ -0,0 +1,101 @@
|
|||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
|
||||
module Simplex.Messaging.Server.Stats where
|
||||
|
||||
import Control.Applicative (optional)
|
||||
import qualified Data.Attoparsec.ByteString.Char8 as A
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import Data.Set (Set)
|
||||
import qualified Data.Set as S
|
||||
import Data.Time.Clock (UTCTime)
|
||||
import Simplex.Messaging.Encoding.String
|
||||
import Simplex.Messaging.Protocol (RecipientId)
|
||||
import UnliftIO.STM
|
||||
|
||||
data ServerStats = ServerStats
|
||||
{ fromTime :: TVar UTCTime,
|
||||
qCreated :: TVar Int,
|
||||
qSecured :: TVar Int,
|
||||
qDeleted :: TVar Int,
|
||||
msgSent :: TVar Int,
|
||||
msgRecv :: TVar Int,
|
||||
dayMsgQueues :: TVar (Set RecipientId),
|
||||
weekMsgQueues :: TVar (Set RecipientId),
|
||||
monthMsgQueues :: TVar (Set RecipientId)
|
||||
}
|
||||
|
||||
data ServerStatsData = ServerStatsData
|
||||
{ _fromTime :: UTCTime,
|
||||
_qCreated :: Int,
|
||||
_qSecured :: Int,
|
||||
_qDeleted :: Int,
|
||||
_msgSent :: Int,
|
||||
_msgRecv :: Int,
|
||||
_dayMsgQueues :: Set RecipientId,
|
||||
_weekMsgQueues :: Set RecipientId,
|
||||
_monthMsgQueues :: Set RecipientId
|
||||
}
|
||||
|
||||
newServerStats :: UTCTime -> STM ServerStats
|
||||
newServerStats ts = do
|
||||
fromTime <- newTVar ts
|
||||
qCreated <- newTVar 0
|
||||
qSecured <- newTVar 0
|
||||
qDeleted <- newTVar 0
|
||||
msgSent <- newTVar 0
|
||||
msgRecv <- newTVar 0
|
||||
dayMsgQueues <- newTVar S.empty
|
||||
weekMsgQueues <- newTVar S.empty
|
||||
monthMsgQueues <- newTVar S.empty
|
||||
pure ServerStats {fromTime, qCreated, qSecured, qDeleted, msgSent, msgRecv, dayMsgQueues, weekMsgQueues, monthMsgQueues}
|
||||
|
||||
getServerStatsData :: ServerStats -> STM ServerStatsData
|
||||
getServerStatsData s = do
|
||||
_fromTime <- readTVar $ fromTime s
|
||||
_qCreated <- readTVar $ qCreated s
|
||||
_qSecured <- readTVar $ qSecured s
|
||||
_qDeleted <- readTVar $ qDeleted s
|
||||
_msgSent <- readTVar $ msgSent s
|
||||
_msgRecv <- readTVar $ msgRecv s
|
||||
_dayMsgQueues <- readTVar $ dayMsgQueues s
|
||||
_weekMsgQueues <- readTVar $ weekMsgQueues s
|
||||
_monthMsgQueues <- readTVar $ monthMsgQueues s
|
||||
pure ServerStatsData {_fromTime, _qCreated, _qSecured, _qDeleted, _msgSent, _msgRecv, _dayMsgQueues, _weekMsgQueues, _monthMsgQueues}
|
||||
|
||||
setServerStatsData :: ServerStats -> ServerStatsData -> STM ()
|
||||
setServerStatsData s d = do
|
||||
writeTVar (fromTime s) (_fromTime d)
|
||||
writeTVar (qCreated s) (_qCreated d)
|
||||
writeTVar (qSecured s) (_qSecured d)
|
||||
writeTVar (qDeleted s) (_qDeleted d)
|
||||
writeTVar (msgSent s) (_msgSent d)
|
||||
writeTVar (msgRecv s) (_msgRecv d)
|
||||
writeTVar (dayMsgQueues s) (_dayMsgQueues d)
|
||||
writeTVar (weekMsgQueues s) (_weekMsgQueues d)
|
||||
writeTVar (monthMsgQueues s) (_monthMsgQueues d)
|
||||
|
||||
instance StrEncoding ServerStatsData where
|
||||
strEncode ServerStatsData {_fromTime, _qCreated, _qSecured, _qDeleted, _msgSent, _msgRecv, _dayMsgQueues, _weekMsgQueues, _monthMsgQueues} =
|
||||
B.unlines
|
||||
[ "fromTime=" <> strEncode _fromTime,
|
||||
"qCreated=" <> strEncode _qCreated,
|
||||
"qSecured=" <> strEncode _qSecured,
|
||||
"qDeleted=" <> strEncode _qDeleted,
|
||||
"msgSent=" <> strEncode _msgSent,
|
||||
"msgRecv=" <> strEncode _msgRecv,
|
||||
"dayMsgQueues=" <> strEncode _dayMsgQueues,
|
||||
"weekMsgQueues=" <> strEncode _weekMsgQueues,
|
||||
"monthMsgQueues=" <> strEncode _monthMsgQueues
|
||||
]
|
||||
strP = do
|
||||
_fromTime <- "fromTime=" *> strP <* A.endOfLine
|
||||
_qCreated <- "qCreated=" *> strP <* A.endOfLine
|
||||
_qSecured <- "qSecured=" *> strP <* A.endOfLine
|
||||
_qDeleted <- "qDeleted=" *> strP <* A.endOfLine
|
||||
_msgSent <- "msgSent=" *> strP <* A.endOfLine
|
||||
_msgRecv <- "msgRecv=" *> strP <* A.endOfLine
|
||||
_dayMsgQueues <- "dayMsgQueues=" *> strP <* A.endOfLine
|
||||
_weekMsgQueues <- "weekMsgQueues=" *> strP <* A.endOfLine
|
||||
_monthMsgQueues <- "monthMsgQueues=" *> strP <* optional A.endOfLine
|
||||
pure ServerStatsData {_fromTime, _qCreated, _qSecured, _qDeleted, _msgSent, _msgRecv, _dayMsgQueues, _weekMsgQueues, _monthMsgQueues}
|
|
@ -13,10 +13,12 @@ module Simplex.Messaging.Server.StoreLog
|
|||
openReadStoreLog,
|
||||
storeLogFilePath,
|
||||
closeStoreLog,
|
||||
writeStoreLogRecord,
|
||||
logCreateQueue,
|
||||
logSecureQueue,
|
||||
logAddNotifier,
|
||||
logDeleteQueue,
|
||||
logDeleteNotifier,
|
||||
readWriteStoreLog,
|
||||
)
|
||||
where
|
||||
|
@ -34,7 +36,7 @@ import Data.Map.Strict (Map)
|
|||
import qualified Data.Map.Strict as M
|
||||
import Simplex.Messaging.Encoding.String
|
||||
import Simplex.Messaging.Protocol
|
||||
import Simplex.Messaging.Server.QueueStore (QueueRec (..), QueueStatus (..))
|
||||
import Simplex.Messaging.Server.QueueStore (NtfCreds (..), QueueRec (..), QueueStatus (..))
|
||||
import Simplex.Messaging.Transport (trimCR)
|
||||
import System.Directory (doesFileExist)
|
||||
import System.IO
|
||||
|
@ -48,8 +50,9 @@ data StoreLog (a :: IOMode) where
|
|||
data StoreLogRecord
|
||||
= CreateQueue QueueRec
|
||||
| SecureQueue QueueId SndPublicVerifyKey
|
||||
| AddNotifier QueueId NotifierId NtfPublicVerifyKey
|
||||
| AddNotifier QueueId NtfCreds
|
||||
| DeleteQueue QueueId
|
||||
| DeleteNotifier QueueId
|
||||
|
||||
instance StrEncoding QueueRec where
|
||||
strEncode QueueRec {recipientId, recipientKey, rcvDhSecret, senderId, senderKey, notifier} =
|
||||
|
@ -62,7 +65,7 @@ instance StrEncoding QueueRec where
|
|||
]
|
||||
<> maybe "" notifierStr notifier
|
||||
where
|
||||
notifierStr (nId, nKey) = " nid=" <> strEncode nId <> " nk=" <> strEncode nKey
|
||||
notifierStr ntfCreds = " notifier=" <> strEncode ntfCreds
|
||||
|
||||
strP = do
|
||||
recipientId <- "rid=" *> strP_
|
||||
|
@ -70,24 +73,29 @@ instance StrEncoding QueueRec where
|
|||
rcvDhSecret <- "rdh=" *> strP_
|
||||
senderId <- "sid=" *> strP_
|
||||
senderKey <- "sk=" *> strP
|
||||
notifier <- optional $ (,) <$> (" nid=" *> strP_) <*> ("nk=" *> strP)
|
||||
notifier <- optional $ " notifier=" *> strP
|
||||
pure QueueRec {recipientId, recipientKey, rcvDhSecret, senderId, senderKey, notifier, status = QueueActive}
|
||||
|
||||
instance StrEncoding StoreLogRecord where
|
||||
strEncode = \case
|
||||
CreateQueue q -> strEncode (Str "CREATE", q)
|
||||
SecureQueue rId sKey -> strEncode (Str "SECURE", rId, sKey)
|
||||
AddNotifier rId nId nKey -> strEncode (Str "NOTIFIER", rId, nId, nKey)
|
||||
AddNotifier rId ntfCreds -> strEncode (Str "NOTIFIER", rId, ntfCreds)
|
||||
DeleteQueue rId -> strEncode (Str "DELETE", rId)
|
||||
DeleteNotifier rId -> strEncode (Str "NDELETE", rId)
|
||||
|
||||
strP =
|
||||
"CREATE " *> (CreateQueue <$> strP)
|
||||
<|> "SECURE " *> (SecureQueue <$> strP_ <*> strP)
|
||||
<|> "NOTIFIER " *> (AddNotifier <$> strP_ <*> strP_ <*> strP)
|
||||
<|> "NOTIFIER " *> (AddNotifier <$> strP_ <*> strP)
|
||||
<|> "DELETE " *> (DeleteQueue <$> strP)
|
||||
<|> "NDELETE " *> (DeleteNotifier <$> strP)
|
||||
|
||||
openWriteStoreLog :: FilePath -> IO (StoreLog 'WriteMode)
|
||||
openWriteStoreLog f = WriteStoreLog f <$> openFile f WriteMode
|
||||
openWriteStoreLog f = do
|
||||
h <- openFile f WriteMode
|
||||
hSetBuffering h LineBuffering
|
||||
pure $ WriteStoreLog f h
|
||||
|
||||
openReadStoreLog :: FilePath -> IO (StoreLog 'ReadMode)
|
||||
openReadStoreLog f = do
|
||||
|
@ -104,7 +112,7 @@ closeStoreLog = \case
|
|||
WriteStoreLog _ h -> hClose h
|
||||
ReadStoreLog _ h -> hClose h
|
||||
|
||||
writeStoreLogRecord :: StoreLog 'WriteMode -> StoreLogRecord -> IO ()
|
||||
writeStoreLogRecord :: StrEncoding r => StoreLog 'WriteMode -> r -> IO ()
|
||||
writeStoreLogRecord (WriteStoreLog _ h) r = do
|
||||
B.hPutStrLn h $ strEncode r
|
||||
hFlush h
|
||||
|
@ -115,12 +123,15 @@ logCreateQueue s = writeStoreLogRecord s . CreateQueue
|
|||
logSecureQueue :: StoreLog 'WriteMode -> QueueId -> SndPublicVerifyKey -> IO ()
|
||||
logSecureQueue s qId sKey = writeStoreLogRecord s $ SecureQueue qId sKey
|
||||
|
||||
logAddNotifier :: StoreLog 'WriteMode -> QueueId -> NotifierId -> NtfPublicVerifyKey -> IO ()
|
||||
logAddNotifier s qId nId nKey = writeStoreLogRecord s $ AddNotifier qId nId nKey
|
||||
logAddNotifier :: StoreLog 'WriteMode -> QueueId -> NtfCreds -> IO ()
|
||||
logAddNotifier s qId ntfCreds = writeStoreLogRecord s $ AddNotifier qId ntfCreds
|
||||
|
||||
logDeleteQueue :: StoreLog 'WriteMode -> QueueId -> IO ()
|
||||
logDeleteQueue s = writeStoreLogRecord s . DeleteQueue
|
||||
|
||||
logDeleteNotifier :: StoreLog 'WriteMode -> QueueId -> IO ()
|
||||
logDeleteNotifier s = writeStoreLogRecord s . DeleteNotifier
|
||||
|
||||
readWriteStoreLog :: StoreLog 'ReadMode -> IO (Map RecipientId QueueRec, StoreLog 'WriteMode)
|
||||
readWriteStoreLog s@(ReadStoreLog f _) = do
|
||||
qs <- readQueues s
|
||||
|
@ -149,7 +160,8 @@ readQueues (ReadStoreLog _ h) = LB.hGetContents h >>= returnResult . procStoreLo
|
|||
procLogRecord m = \case
|
||||
CreateQueue q -> M.insert (recipientId q) q m
|
||||
SecureQueue qId sKey -> M.adjust (\q -> q {senderKey = Just sKey}) qId m
|
||||
AddNotifier qId nId nKey -> M.adjust (\q -> q {notifier = Just (nId, nKey)}) qId m
|
||||
AddNotifier qId ntfCreds -> M.adjust (\q -> q {notifier = Just ntfCreds}) qId m
|
||||
DeleteQueue qId -> M.delete qId m
|
||||
DeleteNotifier qId -> M.adjust (\q -> q {notifier = Nothing}) qId m
|
||||
printError :: LogParsingError -> IO ()
|
||||
printError (e, s) = B.putStrLn $ "Error parsing log: " <> B.pack e <> " - " <> s
|
||||
|
|
|
@ -2,6 +2,7 @@ module Simplex.Messaging.TMap
|
|||
( TMap,
|
||||
empty,
|
||||
singleton,
|
||||
Simplex.Messaging.TMap.null,
|
||||
Simplex.Messaging.TMap.lookup,
|
||||
member,
|
||||
insert,
|
||||
|
@ -11,6 +12,7 @@ module Simplex.Messaging.TMap
|
|||
adjust,
|
||||
update,
|
||||
alter,
|
||||
alterF,
|
||||
union,
|
||||
)
|
||||
where
|
||||
|
@ -29,6 +31,10 @@ singleton :: k -> a -> STM (TMap k a)
|
|||
singleton k v = newTVar $ M.singleton k v
|
||||
{-# INLINE singleton #-}
|
||||
|
||||
null :: TMap k a -> STM Bool
|
||||
null m = M.null <$> readTVar m
|
||||
{-# INLINE null #-}
|
||||
|
||||
lookup :: Ord k => k -> TMap k a -> STM (Maybe a)
|
||||
lookup k m = M.lookup k <$> readTVar m
|
||||
{-# INLINE lookup #-}
|
||||
|
@ -65,6 +71,12 @@ alter :: Ord k => (Maybe a -> Maybe a) -> k -> TMap k a -> STM ()
|
|||
alter f k m = modifyTVar' m $ M.alter f k
|
||||
{-# INLINE alter #-}
|
||||
|
||||
alterF :: Ord k => (Maybe a -> STM (Maybe a)) -> k -> TMap k a -> STM ()
|
||||
alterF f k m = do
|
||||
mv <- M.alterF f k =<< readTVar m
|
||||
writeTVar m $! mv
|
||||
{-# INLINE alterF #-}
|
||||
|
||||
union :: Ord k => Map k a -> TMap k a -> STM ()
|
||||
union m' m = modifyTVar' m $ M.union m'
|
||||
{-# INLINE union #-}
|
||||
|
|
|
@ -26,8 +26,7 @@
|
|||
-- See https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#appendix-a
|
||||
module Simplex.Messaging.Transport
|
||||
( -- * SMP transport parameters
|
||||
smpBlockSize,
|
||||
supportedSMPVersions,
|
||||
supportedSMPServerVRange,
|
||||
simplexMQVersion,
|
||||
|
||||
-- * Transport connection class
|
||||
|
@ -47,12 +46,15 @@ module Simplex.Messaging.Transport
|
|||
-- * SMP transport
|
||||
THandle (..),
|
||||
TransportError (..),
|
||||
serverHandshake,
|
||||
clientHandshake,
|
||||
HandshakeError (..),
|
||||
smpServerHandshake,
|
||||
smpClientHandshake,
|
||||
tPutBlock,
|
||||
tGetBlock,
|
||||
serializeTransportError,
|
||||
transportErrorP,
|
||||
sendHandshake,
|
||||
getHandshake,
|
||||
|
||||
-- * Trim trailing CR
|
||||
trimCR,
|
||||
|
@ -81,7 +83,7 @@ import qualified Network.TLS.Extra as TE
|
|||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Encoding
|
||||
import Simplex.Messaging.Parsers (dropPrefix, parse, parseRead1, sumTypeJSON)
|
||||
import Simplex.Messaging.Util (bshow)
|
||||
import Simplex.Messaging.Util (bshow, catchAll, catchAll_)
|
||||
import Simplex.Messaging.Version
|
||||
import Test.QuickCheck (Arbitrary (..))
|
||||
import UnliftIO.Exception (Exception)
|
||||
|
@ -93,11 +95,11 @@ import UnliftIO.STM
|
|||
smpBlockSize :: Int
|
||||
smpBlockSize = 16384
|
||||
|
||||
supportedSMPVersions :: VersionRange
|
||||
supportedSMPVersions = mkVersionRange 1 1
|
||||
supportedSMPServerVRange :: VersionRange
|
||||
supportedSMPServerVRange = mkVersionRange 1 3
|
||||
|
||||
simplexMQVersion :: String
|
||||
simplexMQVersion = "1.1.0"
|
||||
simplexMQVersion = "3.0.0"
|
||||
|
||||
-- * Transport connection class
|
||||
|
||||
|
@ -155,7 +157,7 @@ connectTLS :: T.TLSParams p => p -> Socket -> IO T.Context
|
|||
connectTLS params sock =
|
||||
E.bracketOnError (T.contextNew sock params) closeTLS $ \ctx -> do
|
||||
T.handshake ctx
|
||||
`E.catch` \(e :: E.SomeException) -> putStrLn ("exception: " <> show e) >> E.throwIO e
|
||||
`catchAll` \e -> putStrLn ("exception: " <> show e) >> E.throwIO e
|
||||
pure ctx
|
||||
|
||||
getTLS :: TransportPeer -> T.Context -> IO TLS
|
||||
|
@ -176,8 +178,9 @@ withTlsUnique peer cxt f =
|
|||
|
||||
closeTLS :: T.Context -> IO ()
|
||||
closeTLS ctx =
|
||||
(T.bye ctx >> T.contextClose ctx) -- sometimes socket was closed before 'TLS.bye'
|
||||
`E.catch` (\(_ :: E.SomeException) -> pure ()) -- so we catch the 'Broken pipe' error here
|
||||
T.bye ctx -- sometimes socket was closed before 'TLS.bye' so we catch the 'Broken pipe' error here
|
||||
`E.finally` T.contextClose ctx
|
||||
`catchAll_` pure ()
|
||||
|
||||
supportedParameters :: T.Supported
|
||||
supportedParameters =
|
||||
|
@ -214,10 +217,11 @@ instance Transport TLS where
|
|||
readChunks :: ByteString -> IO ByteString
|
||||
readChunks b
|
||||
| B.length b >= n = pure b
|
||||
| otherwise = readChunks . (b <>) =<< T.recvData tlsContext `E.catch` handleEOF
|
||||
handleEOF = \case
|
||||
T.Error_EOF -> E.throwIO TEBadBlock
|
||||
e -> E.throwIO e
|
||||
| otherwise =
|
||||
T.recvData tlsContext >>= \case
|
||||
-- https://hackage.haskell.org/package/tls-1.6.0/docs/Network-TLS.html#v:recvData
|
||||
"" -> ioe_EOF
|
||||
s -> readChunks $ b <> s
|
||||
|
||||
cPut :: TLS -> ByteString -> IO ()
|
||||
cPut tls = T.sendData (tlsContext tls) . BL.fromStrict
|
||||
|
@ -252,8 +256,9 @@ trimCR s = if B.last s == '\r' then B.init s else s
|
|||
data THandle c = THandle
|
||||
{ connection :: c,
|
||||
sessionId :: SessionId,
|
||||
-- | agreed SMP server protocol version
|
||||
smpVersion :: Version
|
||||
blockSize :: Int,
|
||||
-- | agreed server protocol version
|
||||
thVersion :: Version
|
||||
}
|
||||
|
||||
-- | TLS-unique channel binding
|
||||
|
@ -336,45 +341,45 @@ serializeTransportError = \case
|
|||
|
||||
-- | Pad and send block to SMP transport.
|
||||
tPutBlock :: Transport c => THandle c -> ByteString -> IO (Either TransportError ())
|
||||
tPutBlock THandle {connection = c} block =
|
||||
tPutBlock THandle {connection = c, blockSize} block =
|
||||
bimapM (const $ pure TELargeMsg) (cPut c) $
|
||||
C.pad block smpBlockSize
|
||||
C.pad block blockSize
|
||||
|
||||
-- | Receive block from SMP transport.
|
||||
tGetBlock :: Transport c => THandle c -> IO (Either TransportError ByteString)
|
||||
tGetBlock THandle {connection = c} =
|
||||
cGet c smpBlockSize >>= \case
|
||||
tGetBlock THandle {connection = c, blockSize} =
|
||||
cGet c blockSize >>= \case
|
||||
"" -> ioe_EOF
|
||||
msg -> pure . first (const TELargeMsg) $ C.unPad msg
|
||||
|
||||
-- | Server SMP transport handshake.
|
||||
--
|
||||
-- See https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#appendix-a
|
||||
serverHandshake :: forall c. Transport c => c -> C.KeyHash -> ExceptT TransportError IO (THandle c)
|
||||
serverHandshake c kh = do
|
||||
let th@THandle {sessionId} = tHandle c
|
||||
sendHandshake th $ ServerHandshake {sessionId, smpVersionRange = supportedSMPVersions}
|
||||
smpServerHandshake :: forall c. Transport c => c -> C.KeyHash -> VersionRange -> ExceptT TransportError IO (THandle c)
|
||||
smpServerHandshake c kh smpVRange = do
|
||||
let th@THandle {sessionId} = smpTHandle c
|
||||
sendHandshake th $ ServerHandshake {sessionId, smpVersionRange = smpVRange}
|
||||
getHandshake th >>= \case
|
||||
ClientHandshake {smpVersion, keyHash}
|
||||
| keyHash /= kh ->
|
||||
throwE $ TEHandshake IDENTITY
|
||||
| smpVersion `isCompatible` supportedSMPVersions -> do
|
||||
pure (th :: THandle c) {smpVersion}
|
||||
| smpVersion `isCompatible` smpVRange -> do
|
||||
pure (th :: THandle c) {thVersion = smpVersion}
|
||||
| otherwise -> throwE $ TEHandshake VERSION
|
||||
|
||||
-- | Client SMP transport handshake.
|
||||
--
|
||||
-- See https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#appendix-a
|
||||
clientHandshake :: forall c. Transport c => c -> C.KeyHash -> ExceptT TransportError IO (THandle c)
|
||||
clientHandshake c keyHash = do
|
||||
let th@THandle {sessionId} = tHandle c
|
||||
smpClientHandshake :: forall c. Transport c => c -> C.KeyHash -> VersionRange -> ExceptT TransportError IO (THandle c)
|
||||
smpClientHandshake c keyHash smpVRange = do
|
||||
let th@THandle {sessionId} = smpTHandle c
|
||||
ServerHandshake {sessionId = sessId, smpVersionRange} <- getHandshake th
|
||||
if sessionId /= sessId
|
||||
then throwE TEBadSession
|
||||
else case smpVersionRange `compatibleVersion` supportedSMPVersions of
|
||||
else case smpVersionRange `compatibleVersion` smpVRange of
|
||||
Just (Compatible smpVersion) -> do
|
||||
sendHandshake th $ ClientHandshake {smpVersion, keyHash}
|
||||
pure (th :: THandle c) {smpVersion}
|
||||
pure (th :: THandle c) {thVersion = smpVersion}
|
||||
Nothing -> throwE $ TEHandshake VERSION
|
||||
|
||||
sendHandshake :: (Transport c, Encoding smp) => THandle c -> smp -> ExceptT TransportError IO ()
|
||||
|
@ -383,6 +388,5 @@ sendHandshake th = ExceptT . tPutBlock th . smpEncode
|
|||
getHandshake :: (Transport c, Encoding smp) => THandle c -> ExceptT TransportError IO smp
|
||||
getHandshake th = ExceptT $ (parse smpP (TEHandshake PARSE) =<<) <$> tGetBlock th
|
||||
|
||||
tHandle :: Transport c => c -> THandle c
|
||||
tHandle c =
|
||||
THandle {connection = c, sessionId = tlsUnique c, smpVersion = 0}
|
||||
smpTHandle :: Transport c => c -> THandle c
|
||||
smpTHandle c = THandle {connection = c, sessionId = tlsUnique c, blockSize = smpBlockSize, thVersion = 0}
|
||||
|
|
|
@ -2,7 +2,8 @@
|
|||
|
||||
module Simplex.Messaging.Transport.Client
|
||||
( runTransportClient,
|
||||
clientHandshake,
|
||||
runTLSTransportClient,
|
||||
smpClientHandshake,
|
||||
)
|
||||
where
|
||||
|
||||
|
@ -26,9 +27,12 @@ import UnliftIO.Exception (IOException)
|
|||
import qualified UnliftIO.Exception as E
|
||||
|
||||
-- | Connect to passed TCP host:port and pass handle to the client.
|
||||
runTransportClient :: Transport c => MonadUnliftIO m => HostName -> ServiceName -> C.KeyHash -> Maybe KeepAliveOpts -> (c -> m a) -> m a
|
||||
runTransportClient host port keyHash keepAliveOpts client = do
|
||||
let clientParams = mkTLSClientParams host port keyHash
|
||||
runTransportClient :: (Transport c, MonadUnliftIO m) => HostName -> ServiceName -> Maybe C.KeyHash -> Maybe KeepAliveOpts -> (c -> m a) -> m a
|
||||
runTransportClient = runTLSTransportClient supportedParameters Nothing
|
||||
|
||||
runTLSTransportClient :: (Transport c, MonadUnliftIO m) => T.Supported -> Maybe XS.CertificateStore -> HostName -> ServiceName -> Maybe C.KeyHash -> Maybe KeepAliveOpts -> (c -> m a) -> m a
|
||||
runTLSTransportClient tlsParams caStore_ host port keyHash keepAliveOpts client = do
|
||||
let clientParams = mkTLSClientParams tlsParams caStore_ host port keyHash
|
||||
c <- liftIO $ startTCPClient host port clientParams keepAliveOpts
|
||||
client c `E.finally` liftIO (closeConnection c)
|
||||
|
||||
|
@ -56,13 +60,15 @@ startTCPClient host port clientParams keepAliveOpts = withSocketsDo $ resolve >>
|
|||
ctx <- connectTLS clientParams sock
|
||||
getClientConnection ctx
|
||||
|
||||
mkTLSClientParams :: HostName -> ServiceName -> C.KeyHash -> T.ClientParams
|
||||
mkTLSClientParams host port keyHash = do
|
||||
-- readCertificateStore :: FilePath -> IO (Maybe CertificateStore)
|
||||
|
||||
mkTLSClientParams :: T.Supported -> Maybe XS.CertificateStore -> HostName -> ServiceName -> Maybe C.KeyHash -> T.ClientParams
|
||||
mkTLSClientParams supported caStore_ host port keyHash_ = do
|
||||
let p = B.pack port
|
||||
(T.defaultParamsClient host p)
|
||||
{ T.clientShared = def,
|
||||
T.clientHooks = def {T.onServerCertificate = \_ _ _ -> validateCertificateChain keyHash host p},
|
||||
T.clientSupported = supportedParameters
|
||||
{ T.clientShared = maybe def (\caStore -> def {T.sharedCAStore = caStore}) caStore_,
|
||||
T.clientHooks = maybe def (\keyHash -> def {T.onServerCertificate = \_ _ _ -> validateCertificateChain keyHash host p}) keyHash_,
|
||||
T.clientSupported = supported
|
||||
}
|
||||
|
||||
validateCertificateChain :: C.KeyHash -> HostName -> ByteString -> X.CertificateChain -> IO [XV.FailedReason]
|
||||
|
@ -77,7 +83,7 @@ validateCertificateChain (C.KeyHash kh) host port cc@(X.CertificateChain sc@[_,
|
|||
x509validate = XV.validate X.HashSHA256 hooks checks certStore cache serviceID cc
|
||||
where
|
||||
hooks = XV.defaultHooks
|
||||
checks = XV.defaultChecks
|
||||
checks = XV.defaultChecks {XV.checkFQHN = False}
|
||||
certStore = XS.makeCertificateStore sc
|
||||
cache = XV.exceptionValidationCache [] -- we manually check fingerprint only of the identity certificate (ca.crt)
|
||||
serviceID = (host, port)
|
||||
|
|
|
@ -0,0 +1,36 @@
|
|||
module Simplex.Messaging.Transport.HTTP2 where
|
||||
|
||||
import qualified Control.Exception as E
|
||||
import Data.Default (def)
|
||||
import Foreign (mallocBytes)
|
||||
import Network.HPACK (BufferSize)
|
||||
import Network.HTTP2.Client (Config (..), defaultPositionReadMaker, freeSimpleConfig)
|
||||
import qualified Network.TLS as T
|
||||
import qualified Network.TLS.Extra as TE
|
||||
import Simplex.Messaging.Transport (TLS, Transport (cGet, cPut))
|
||||
import qualified System.TimeManager as TI
|
||||
|
||||
withTlsConfig :: TLS -> BufferSize -> (Config -> IO ()) -> IO ()
|
||||
withTlsConfig c sz = E.bracket (allocTlsConfig c sz) freeSimpleConfig
|
||||
|
||||
allocTlsConfig :: TLS -> BufferSize -> IO Config
|
||||
allocTlsConfig c sz = do
|
||||
buf <- mallocBytes sz
|
||||
tm <- TI.initialize $ 30 * 1000000
|
||||
pure
|
||||
Config
|
||||
{ confWriteBuffer = buf,
|
||||
confBufferSize = sz,
|
||||
confSendAll = cPut c,
|
||||
confReadN = cGet c,
|
||||
confPositionReadMaker = defaultPositionReadMaker,
|
||||
confTimeoutManager = tm
|
||||
}
|
||||
|
||||
http2TLSParams :: T.Supported
|
||||
http2TLSParams =
|
||||
def
|
||||
{ T.supportedVersions = [T.TLS13, T.TLS12],
|
||||
T.supportedCiphers = TE.ciphersuite_strong_det,
|
||||
T.supportedSecureRenegotiation = False
|
||||
}
|
|
@ -0,0 +1,128 @@
|
|||
{-# LANGUAGE DuplicateRecordFields #-}
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
|
||||
module Simplex.Messaging.Transport.HTTP2.Client where
|
||||
|
||||
import Control.Concurrent.Async
|
||||
import Control.Exception (IOException)
|
||||
import qualified Control.Exception as E
|
||||
import Control.Monad.Except
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import Data.Maybe (isNothing)
|
||||
import qualified Data.X509.CertificateStore as XS
|
||||
import Network.HPACK (HeaderTable)
|
||||
import Network.HTTP2.Client (ClientConfig (..), Request, Response)
|
||||
import qualified Network.HTTP2.Client as H
|
||||
import Network.Socket (HostName, ServiceName)
|
||||
import qualified Network.TLS as T
|
||||
import Numeric.Natural (Natural)
|
||||
import Simplex.Messaging.Transport.Client (runTLSTransportClient)
|
||||
import Simplex.Messaging.Transport.HTTP2 (http2TLSParams, withTlsConfig)
|
||||
import Simplex.Messaging.Transport.KeepAlive (KeepAliveOpts)
|
||||
import UnliftIO.STM
|
||||
import UnliftIO.Timeout
|
||||
|
||||
data HTTP2Client = HTTP2Client
|
||||
{ action :: Async (),
|
||||
connected :: TVar Bool,
|
||||
host :: HostName,
|
||||
port :: ServiceName,
|
||||
config :: HTTP2ClientConfig,
|
||||
reqQ :: TBQueue (Request, TMVar HTTP2Response)
|
||||
}
|
||||
|
||||
data HTTP2Response = HTTP2Response
|
||||
{ response :: Response,
|
||||
respBody :: ByteString,
|
||||
respTrailers :: Maybe HeaderTable
|
||||
}
|
||||
|
||||
data HTTP2ClientConfig = HTTP2ClientConfig
|
||||
{ qSize :: Natural,
|
||||
connTimeout :: Int,
|
||||
tcpKeepAlive :: Maybe KeepAliveOpts,
|
||||
caStoreFile :: FilePath,
|
||||
suportedTLSParams :: T.Supported
|
||||
}
|
||||
deriving (Show)
|
||||
|
||||
defaultHTTP2ClientConfig :: HTTP2ClientConfig
|
||||
defaultHTTP2ClientConfig =
|
||||
HTTP2ClientConfig
|
||||
{ qSize = 64,
|
||||
connTimeout = 10000000,
|
||||
tcpKeepAlive = Nothing,
|
||||
caStoreFile = "/etc/ssl/cert.pem",
|
||||
suportedTLSParams = http2TLSParams
|
||||
}
|
||||
|
||||
data HTTP2ClientError = HCResponseTimeout | HCNetworkError | HCNetworkError1 | HCIOError IOException
|
||||
deriving (Show)
|
||||
|
||||
getHTTP2Client :: HostName -> ServiceName -> HTTP2ClientConfig -> IO () -> IO (Either HTTP2ClientError HTTP2Client)
|
||||
getHTTP2Client host port config@HTTP2ClientConfig {tcpKeepAlive, connTimeout, caStoreFile, suportedTLSParams} disconnected =
|
||||
(atomically mkHTTPS2Client >>= runClient)
|
||||
`E.catch` \(e :: IOException) -> pure . Left $ HCIOError e
|
||||
where
|
||||
mkHTTPS2Client :: STM HTTP2Client
|
||||
mkHTTPS2Client = do
|
||||
connected <- newTVar False
|
||||
reqQ <- newTBQueue $ qSize config
|
||||
pure HTTP2Client {action = undefined, connected, host, port, config, reqQ}
|
||||
|
||||
runClient :: HTTP2Client -> IO (Either HTTP2ClientError HTTP2Client)
|
||||
runClient c = do
|
||||
cVar <- newEmptyTMVarIO
|
||||
caStore <- XS.readCertificateStore caStoreFile
|
||||
when (isNothing caStore) . putStrLn $ "Error loading CertificateStore from " <> caStoreFile
|
||||
action <-
|
||||
async $
|
||||
runHTTP2Client suportedTLSParams caStore host port tcpKeepAlive (client c cVar)
|
||||
`E.finally` atomically (putTMVar cVar $ Left HCNetworkError)
|
||||
conn_ <- connTimeout `timeout` atomically (takeTMVar cVar)
|
||||
pure $ case conn_ of
|
||||
Just (Right ()) -> Right c {action}
|
||||
Just (Left e) -> Left e
|
||||
Nothing -> Left HCNetworkError1
|
||||
|
||||
client :: HTTP2Client -> TMVar (Either HTTP2ClientError ()) -> (Request -> (Response -> IO ()) -> IO ()) -> IO ()
|
||||
client c cVar sendReq = do
|
||||
atomically $ do
|
||||
writeTVar (connected c) True
|
||||
putTMVar cVar $ Right ()
|
||||
process c sendReq `E.finally` disconnected
|
||||
|
||||
process :: HTTP2Client -> (Request -> (Response -> IO ()) -> IO ()) -> IO ()
|
||||
process HTTP2Client {reqQ} sendReq = forever $ do
|
||||
(req, respVar) <- atomically $ readTBQueue reqQ
|
||||
sendReq req $ \r -> do
|
||||
let writeResp respBody respTrailers = atomically $ putTMVar respVar HTTP2Response {response = r, respBody, respTrailers}
|
||||
respBody <- getResponseBody r ""
|
||||
respTrailers <- H.getResponseTrailers r
|
||||
writeResp respBody respTrailers
|
||||
|
||||
getResponseBody :: Response -> ByteString -> IO ByteString
|
||||
getResponseBody r s =
|
||||
H.getResponseBodyChunk r >>= \chunk ->
|
||||
if B.null chunk then pure s else getResponseBody r $ s <> chunk
|
||||
|
||||
-- | Disconnects client from the server and terminates client threads.
|
||||
closeHTTP2Client :: HTTP2Client -> IO ()
|
||||
-- TODO disconnect
|
||||
closeHTTP2Client = uninterruptibleCancel . action
|
||||
|
||||
sendRequest :: HTTP2Client -> Request -> IO (Either HTTP2ClientError HTTP2Response)
|
||||
sendRequest HTTP2Client {reqQ, config} req = do
|
||||
resp <- newEmptyTMVarIO
|
||||
atomically $ writeTBQueue reqQ (req, resp)
|
||||
maybe (Left HCResponseTimeout) Right <$> (connTimeout config `timeout` atomically (takeTMVar resp))
|
||||
|
||||
runHTTP2Client :: T.Supported -> Maybe XS.CertificateStore -> HostName -> ServiceName -> Maybe KeepAliveOpts -> ((Request -> (Response -> IO ()) -> IO ()) -> IO ()) -> IO ()
|
||||
runHTTP2Client tlsParams caStore host port keepAliveOpts client =
|
||||
runTLSTransportClient tlsParams caStore host port Nothing keepAliveOpts $ \c ->
|
||||
withTlsConfig c 16384 (`run` client)
|
||||
where
|
||||
run = H.run $ ClientConfig "https" (B.pack host) 20
|
|
@ -0,0 +1,70 @@
|
|||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
|
||||
module Simplex.Messaging.Transport.HTTP2.Server where
|
||||
|
||||
import Control.Concurrent.Async (Async, async, uninterruptibleCancel)
|
||||
import Control.Concurrent.STM
|
||||
import Control.Monad
|
||||
import Data.ByteString (ByteString)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import Network.HPACK (HeaderTable)
|
||||
import Network.HTTP2.Server (Aux, PushPromise, Request, Response)
|
||||
import qualified Network.HTTP2.Server as H
|
||||
import Network.Socket
|
||||
import qualified Network.TLS as T
|
||||
import Numeric.Natural (Natural)
|
||||
import Simplex.Messaging.Transport.HTTP2 (withTlsConfig)
|
||||
import Simplex.Messaging.Transport.Server (loadSupportedTLSServerParams, runTransportServer)
|
||||
|
||||
type HTTP2ServerFunc = (Request -> (Response -> IO ()) -> IO ())
|
||||
|
||||
data HTTP2ServerConfig = HTTP2ServerConfig
|
||||
{ qSize :: Natural,
|
||||
http2Port :: ServiceName,
|
||||
serverSupported :: T.Supported,
|
||||
caCertificateFile :: FilePath,
|
||||
privateKeyFile :: FilePath,
|
||||
certificateFile :: FilePath
|
||||
}
|
||||
deriving (Show)
|
||||
|
||||
data HTTP2Request = HTTP2Request
|
||||
{ request :: Request,
|
||||
reqBody :: ByteString,
|
||||
reqTrailers :: Maybe HeaderTable,
|
||||
sendResponse :: Response -> IO ()
|
||||
}
|
||||
|
||||
data HTTP2Server = HTTP2Server
|
||||
{ action :: Async (),
|
||||
reqQ :: TBQueue HTTP2Request
|
||||
}
|
||||
|
||||
getHTTP2Server :: HTTP2ServerConfig -> IO HTTP2Server
|
||||
getHTTP2Server HTTP2ServerConfig {qSize, http2Port, serverSupported, caCertificateFile, certificateFile, privateKeyFile} = do
|
||||
tlsServerParams <- loadSupportedTLSServerParams serverSupported caCertificateFile certificateFile privateKeyFile
|
||||
started <- newEmptyTMVarIO
|
||||
reqQ <- newTBQueueIO qSize
|
||||
action <- async $
|
||||
runHTTP2Server started http2Port tlsServerParams $ \r sendResponse -> do
|
||||
reqBody <- getRequestBody r ""
|
||||
reqTrailers <- H.getRequestTrailers r
|
||||
atomically $ writeTBQueue reqQ HTTP2Request {request = r, reqBody, reqTrailers, sendResponse}
|
||||
void . atomically $ takeTMVar started
|
||||
pure HTTP2Server {action, reqQ}
|
||||
where
|
||||
getRequestBody :: Request -> ByteString -> IO ByteString
|
||||
getRequestBody r s =
|
||||
H.getRequestBodyChunk r >>= \chunk ->
|
||||
if B.null chunk then pure s else getRequestBody r $ s <> chunk
|
||||
|
||||
closeHTTP2Server :: HTTP2Server -> IO ()
|
||||
closeHTTP2Server = uninterruptibleCancel . action
|
||||
|
||||
runHTTP2Server :: TMVar Bool -> ServiceName -> T.ServerParams -> HTTP2ServerFunc -> IO ()
|
||||
runHTTP2Server started port serverParams http2Server =
|
||||
runTransportServer started port serverParams $ \c -> withTlsConfig c 16384 (`H.run` server)
|
||||
where
|
||||
server :: Request -> Aux -> (Response -> [PushPromise] -> IO ()) -> IO ()
|
||||
server req _aux sendResp = http2Server req (`sendResp` [])
|
|
@ -12,6 +12,7 @@ data KeepAliveOpts = KeepAliveOpts
|
|||
keepIntvl :: Int,
|
||||
keepCnt :: Int
|
||||
}
|
||||
deriving (Show)
|
||||
|
||||
defaultKeepAliveOpts :: KeepAliveOpts
|
||||
defaultKeepAliveOpts =
|
||||
|
|
|
@ -5,25 +5,29 @@
|
|||
module Simplex.Messaging.Transport.Server
|
||||
( runTransportServer,
|
||||
runTCPServer,
|
||||
loadSupportedTLSServerParams,
|
||||
loadTLSServerParams,
|
||||
loadFingerprint,
|
||||
serverHandshake,
|
||||
smpServerHandshake,
|
||||
)
|
||||
where
|
||||
|
||||
import Control.Concurrent.STM (stateTVar)
|
||||
import Control.Monad.Except
|
||||
import Control.Monad.IO.Unlift
|
||||
import qualified Crypto.Store.X509 as SX
|
||||
import Data.Default (def)
|
||||
import Data.Set (Set)
|
||||
import qualified Data.Set as S
|
||||
import qualified Data.X509 as X
|
||||
import Data.X509.Validation (Fingerprint (..))
|
||||
import qualified Data.X509.Validation as XV
|
||||
import Network.Socket
|
||||
import qualified Network.TLS as T
|
||||
import Simplex.Messaging.TMap (TMap)
|
||||
import qualified Simplex.Messaging.TMap as TM
|
||||
import Simplex.Messaging.Transport
|
||||
import Simplex.Messaging.Util (catchAll_)
|
||||
import System.Exit (exitFailure)
|
||||
import System.Mem.Weak (Weak, deRefWeak)
|
||||
import UnliftIO.Concurrent
|
||||
import qualified UnliftIO.Exception as E
|
||||
import UnliftIO.STM
|
||||
|
@ -34,37 +38,30 @@ import UnliftIO.STM
|
|||
runTransportServer :: forall c m. (Transport c, MonadUnliftIO m) => TMVar Bool -> ServiceName -> T.ServerParams -> (c -> m ()) -> m ()
|
||||
runTransportServer started port serverParams server = do
|
||||
u <- askUnliftIO
|
||||
liftIO $ do
|
||||
clients <- newTVarIO S.empty
|
||||
liftIO . runTCPServer started port $ \conn ->
|
||||
E.bracket
|
||||
(startTCPServer started port)
|
||||
(closeServer started clients)
|
||||
$ \sock -> forever $ do
|
||||
(connSock, _) <- accept sock
|
||||
tid <- forkIO $ connectClient u connSock `E.catch` \(_ :: E.SomeException) -> pure ()
|
||||
atomically . modifyTVar' clients $ S.insert tid
|
||||
where
|
||||
connectClient :: UnliftIO m -> Socket -> IO ()
|
||||
connectClient u connSock =
|
||||
E.bracket
|
||||
(connectTLS serverParams connSock >>= getServerConnection)
|
||||
closeConnection
|
||||
(unliftIO u . server)
|
||||
(connectTLS serverParams conn >>= getServerConnection)
|
||||
closeConnection
|
||||
(unliftIO u . server)
|
||||
|
||||
-- | Run TCP server without TLS
|
||||
runTCPServer :: TMVar Bool -> ServiceName -> (Socket -> IO ()) -> IO ()
|
||||
runTCPServer started port server = do
|
||||
clients <- newTVarIO S.empty
|
||||
clients <- atomically TM.empty
|
||||
clientId <- newTVarIO 0
|
||||
E.bracket
|
||||
(startTCPServer started port)
|
||||
(closeServer started clients)
|
||||
$ \sock -> forever $ do
|
||||
(connSock, _) <- accept sock
|
||||
tid <- forkIO $ server connSock `E.catch` \(_ :: E.SomeException) -> pure ()
|
||||
atomically . modifyTVar' clients $ S.insert tid
|
||||
$ \sock -> forever . E.bracketOnError (accept sock) (close . fst) $ \(conn, _peer) -> do
|
||||
-- catchAll_ is needed here in case the connection was closed earlier
|
||||
cId <- atomically $ stateTVar clientId $ \cId -> let cId' = cId + 1 in (cId', cId')
|
||||
let closeConn _ = atomically (TM.delete cId clients) >> gracefulClose conn 5000 `catchAll_` pure ()
|
||||
tId <- mkWeakThreadId =<< server conn `forkFinally` closeConn
|
||||
atomically $ TM.insert cId tId clients
|
||||
|
||||
closeServer :: TMVar Bool -> TVar (Set ThreadId) -> Socket -> IO ()
|
||||
closeServer :: TMVar Bool -> TMap Int (Weak ThreadId) -> Socket -> IO ()
|
||||
closeServer started clients sock = do
|
||||
readTVarIO clients >>= mapM_ killThread
|
||||
readTVarIO clients >>= mapM_ (deRefWeak >=> mapM_ killThread)
|
||||
close sock
|
||||
void . atomically $ tryPutTMVar started False
|
||||
|
||||
|
@ -84,7 +81,10 @@ startTCPServer started port = withSocketsDo $ resolve >>= open >>= setStarted
|
|||
setStarted sock = atomically (tryPutTMVar started True) >> pure sock
|
||||
|
||||
loadTLSServerParams :: FilePath -> FilePath -> FilePath -> IO T.ServerParams
|
||||
loadTLSServerParams caCertificateFile certificateFile privateKeyFile =
|
||||
loadTLSServerParams = loadSupportedTLSServerParams supportedParameters
|
||||
|
||||
loadSupportedTLSServerParams :: T.Supported -> FilePath -> FilePath -> FilePath -> IO T.ServerParams
|
||||
loadSupportedTLSServerParams serverSupported caCertificateFile certificateFile privateKeyFile =
|
||||
fromCredential <$> loadServerCredential
|
||||
where
|
||||
loadServerCredential :: IO T.Credential
|
||||
|
@ -98,7 +98,7 @@ loadTLSServerParams caCertificateFile certificateFile privateKeyFile =
|
|||
{ T.serverWantClientCert = False,
|
||||
T.serverShared = def {T.sharedCredentials = T.Credentials [credential]},
|
||||
T.serverHooks = def,
|
||||
T.serverSupported = supportedParameters
|
||||
T.serverSupported = serverSupported
|
||||
}
|
||||
|
||||
loadFingerprint :: FilePath -> IO Fingerprint
|
||||
|
|
|
@ -1,13 +1,17 @@
|
|||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
|
||||
module Simplex.Messaging.Util where
|
||||
|
||||
import qualified Control.Exception as E
|
||||
import Control.Monad.Except
|
||||
import Control.Monad.IO.Unlift
|
||||
import Control.Monad.Trans.Except
|
||||
import Data.Bifunctor (first)
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import Data.Text (Text)
|
||||
import qualified Data.Text as T
|
||||
import UnliftIO.Async
|
||||
|
||||
raceAny_ :: MonadUnliftIO m => [m a] -> m ()
|
||||
|
@ -30,6 +34,10 @@ bshow :: Show a => a -> ByteString
|
|||
bshow = B.pack . show
|
||||
{-# INLINE bshow #-}
|
||||
|
||||
tshow :: Show a => a -> Text
|
||||
tshow = T.pack . show
|
||||
{-# INLINE tshow #-}
|
||||
|
||||
maybeWord :: (a -> ByteString) -> Maybe a -> ByteString
|
||||
maybeWord f = maybe "" $ B.cons ' ' . f
|
||||
{-# INLINE maybeWord #-}
|
||||
|
@ -54,6 +62,10 @@ tryE :: Monad m => ExceptT e m a -> ExceptT e m (Either e a)
|
|||
tryE m = (Right <$> m) `catchE` (pure . Left)
|
||||
{-# INLINE tryE #-}
|
||||
|
||||
liftE :: (e -> e') -> ExceptT e IO a -> ExceptT e' IO a
|
||||
liftE f a = ExceptT $ first f <$> runExceptT a
|
||||
{-# INLINE liftE #-}
|
||||
|
||||
ifM :: Monad m => m Bool -> m a -> m a -> m a
|
||||
ifM ba t f = ba >>= \b -> if b then t else f
|
||||
{-# INLINE ifM #-}
|
||||
|
@ -65,3 +77,18 @@ whenM b a = ifM b a $ pure ()
|
|||
unlessM :: Monad m => m Bool -> m () -> m ()
|
||||
unlessM b = ifM b $ pure ()
|
||||
{-# INLINE unlessM #-}
|
||||
|
||||
($>>=) :: (Monad m, Monad f, Traversable f) => m (f a) -> (a -> m (f b)) -> m (f b)
|
||||
f $>>= g = f >>= fmap join . mapM g
|
||||
|
||||
catchAll :: IO a -> (E.SomeException -> IO a) -> IO a
|
||||
catchAll = E.catch
|
||||
{-# INLINE catchAll #-}
|
||||
|
||||
catchAll_ :: IO a -> IO a -> IO a
|
||||
catchAll_ a = catchAll a . const
|
||||
{-# INLINE catchAll_ #-}
|
||||
|
||||
eitherToMaybe :: Either a b -> Maybe b
|
||||
eitherToMaybe = either (const Nothing) Just
|
||||
{-# INLINE eitherToMaybe #-}
|
||||
|
|
|
@ -17,7 +17,7 @@
|
|||
#
|
||||
# resolver: ./custom-snapshot.yaml
|
||||
# resolver: https://example.com/snapshots/2018-01-01.yaml
|
||||
resolver: lts-18.21
|
||||
resolver: lts-18.28
|
||||
|
||||
# User packages to be built.
|
||||
# Various formats can be used as shown in the example below.
|
||||
|
@ -39,7 +39,7 @@ extra-deps:
|
|||
- network-3.1.2.7@sha256:e3d78b13db9512aeb106e44a334ab42b7aa48d26c097299084084cb8be5c5568,4888
|
||||
- simple-logger-0.1.0@sha256:be8ede4bd251a9cac776533bae7fb643369ebd826eb948a9a18df1a8dd252ff8,1079
|
||||
- tls-1.5.7@sha256:1cc30253a9696b65a9cafc0317fbf09f7dcea15e3a145ed6c9c0e28c632fa23a,6991
|
||||
# below dependancies are to update Aeson to 2.0.3
|
||||
# below dependancies are to update Aeson to 2.0.3
|
||||
- OneTuple-0.3.1@sha256:a848c096c9d29e82ffdd30a9998aa2931cbccb3a1bc137539d80f6174d31603e,2262
|
||||
- attoparsec-0.14.4@sha256:79584bdada8b730cb5138fca8c35c76fbef75fc1d1e01e6b1d815a5ee9843191,5810
|
||||
- hashable-1.4.0.2@sha256:0cddd0229d1aac305ea0404409c0bbfab81f075817bd74b8b2929eff58333e55,5005
|
||||
|
@ -53,7 +53,6 @@ extra-deps:
|
|||
# commit: f6cc753611f80af300401cfae63846e9d7c40d9e
|
||||
# subdirs:
|
||||
# - core
|
||||
|
||||
# Override default flag values for local packages and extra-deps
|
||||
# flags: {}
|
||||
|
||||
|
|
|
@ -12,14 +12,16 @@ module AgentTests (agentTests) where
|
|||
import AgentTests.ConnectionRequestTests
|
||||
import AgentTests.DoubleRatchetTests (doubleRatchetTests)
|
||||
import AgentTests.FunctionalAPITests (functionalAPITests)
|
||||
import AgentTests.NotificationTests (notificationTests)
|
||||
import AgentTests.SQLiteTests (storeTests)
|
||||
import AgentTests.SchemaDump (schemaDumpTest)
|
||||
import Control.Concurrent
|
||||
import Control.Monad (forM_)
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import Network.HTTP.Types (urlEncode)
|
||||
import SMPAgentClient
|
||||
import SMPClient (testPort, testPort2, testStoreLogFile, withSmpServer, withSmpServerStoreLogOn)
|
||||
import SMPClient (testKeyHash, testPort, testPort2, testStoreLogFile, withSmpServer, withSmpServerStoreLogOn)
|
||||
import Simplex.Messaging.Agent.Protocol
|
||||
import qualified Simplex.Messaging.Agent.Protocol as A
|
||||
import Simplex.Messaging.Encoding.String
|
||||
|
@ -35,7 +37,9 @@ agentTests (ATransport t) = do
|
|||
describe "Connection request" connectionRequestTests
|
||||
describe "Double ratchet tests" doubleRatchetTests
|
||||
describe "Functional API" $ functionalAPITests (ATransport t)
|
||||
describe "Notification tests" $ notificationTests (ATransport t)
|
||||
describe "SQLite store" storeTests
|
||||
describe "SQLite schema dump" schemaDumpTest
|
||||
describe "SMP agent protocol syntax" $ syntaxTests t
|
||||
describe "Establishing duplex connection" $ do
|
||||
it "should connect via one server and one agent" $
|
||||
|
@ -115,7 +119,7 @@ h #:# err = tryGet `shouldReturn` ()
|
|||
_ -> return ()
|
||||
|
||||
pattern Msg :: MsgBody -> ACommand 'Agent
|
||||
pattern Msg msgBody <- MSG MsgMeta {integrity = MsgOk} msgBody
|
||||
pattern Msg msgBody <- MSG MsgMeta {integrity = MsgOk} _ msgBody
|
||||
|
||||
testDuplexConnection :: Transport c => TProxy c -> c -> c -> IO ()
|
||||
testDuplexConnection _ alice bob = do
|
||||
|
@ -128,25 +132,25 @@ testDuplexConnection _ alice bob = do
|
|||
bob <# ("", "alice", CON)
|
||||
alice <# ("", "bob", CON)
|
||||
-- message IDs 1 to 3 get assigned to control messages, so first MSG is assigned ID 4
|
||||
alice #: ("3", "bob", "SEND :hello") #> ("3", "bob", MID 5)
|
||||
alice <# ("", "bob", SENT 5)
|
||||
alice #: ("3", "bob", "SEND F :hello") #> ("3", "bob", MID 4)
|
||||
alice <# ("", "bob", SENT 4)
|
||||
bob <#= \case ("", "alice", Msg "hello") -> True; _ -> False
|
||||
bob #: ("12", "alice", "ACK 5") #> ("12", "alice", OK)
|
||||
alice #: ("4", "bob", "SEND :how are you?") #> ("4", "bob", MID 6)
|
||||
alice <# ("", "bob", SENT 6)
|
||||
bob #: ("12", "alice", "ACK 4") #> ("12", "alice", OK)
|
||||
alice #: ("4", "bob", "SEND F :how are you?") #> ("4", "bob", MID 5)
|
||||
alice <# ("", "bob", SENT 5)
|
||||
bob <#= \case ("", "alice", Msg "how are you?") -> True; _ -> False
|
||||
bob #: ("13", "alice", "ACK 6") #> ("13", "alice", OK)
|
||||
bob #: ("14", "alice", "SEND 9\nhello too") #> ("14", "alice", MID 7)
|
||||
bob <# ("", "alice", SENT 7)
|
||||
bob #: ("13", "alice", "ACK 5") #> ("13", "alice", OK)
|
||||
bob #: ("14", "alice", "SEND F 9\nhello too") #> ("14", "alice", MID 6)
|
||||
bob <# ("", "alice", SENT 6)
|
||||
alice <#= \case ("", "bob", Msg "hello too") -> True; _ -> False
|
||||
alice #: ("3a", "bob", "ACK 7") #> ("3a", "bob", OK)
|
||||
bob #: ("15", "alice", "SEND 9\nmessage 1") #> ("15", "alice", MID 8)
|
||||
bob <# ("", "alice", SENT 8)
|
||||
alice #: ("3a", "bob", "ACK 6") #> ("3a", "bob", OK)
|
||||
bob #: ("15", "alice", "SEND F 9\nmessage 1") #> ("15", "alice", MID 7)
|
||||
bob <# ("", "alice", SENT 7)
|
||||
alice <#= \case ("", "bob", Msg "message 1") -> True; _ -> False
|
||||
alice #: ("4a", "bob", "ACK 8") #> ("4a", "bob", OK)
|
||||
alice #: ("4a", "bob", "ACK 7") #> ("4a", "bob", OK)
|
||||
alice #: ("5", "bob", "OFF") #> ("5", "bob", OK)
|
||||
bob #: ("17", "alice", "SEND 9\nmessage 3") #> ("17", "alice", MID 9)
|
||||
bob <# ("", "alice", MERR 9 (SMP AUTH))
|
||||
bob #: ("17", "alice", "SEND F 9\nmessage 3") #> ("17", "alice", MID 8)
|
||||
bob <# ("", "alice", MERR 8 (SMP AUTH))
|
||||
alice #: ("6", "bob", "DEL") #> ("6", "bob", OK)
|
||||
alice #:# "nothing else should be delivered to alice"
|
||||
|
||||
|
@ -161,25 +165,25 @@ testDuplexConnRandomIds _ alice bob = do
|
|||
bob <# ("", aliceConn, INFO "alice's connInfo")
|
||||
bob <# ("", aliceConn, CON)
|
||||
alice <# ("", bobConn, CON)
|
||||
alice #: ("2", bobConn, "SEND :hello") #> ("2", bobConn, MID 5)
|
||||
alice <# ("", bobConn, SENT 5)
|
||||
alice #: ("2", bobConn, "SEND F :hello") #> ("2", bobConn, MID 4)
|
||||
alice <# ("", bobConn, SENT 4)
|
||||
bob <#= \case ("", c, Msg "hello") -> c == aliceConn; _ -> False
|
||||
bob #: ("12", aliceConn, "ACK 5") #> ("12", aliceConn, OK)
|
||||
alice #: ("3", bobConn, "SEND :how are you?") #> ("3", bobConn, MID 6)
|
||||
alice <# ("", bobConn, SENT 6)
|
||||
bob #: ("12", aliceConn, "ACK 4") #> ("12", aliceConn, OK)
|
||||
alice #: ("3", bobConn, "SEND F :how are you?") #> ("3", bobConn, MID 5)
|
||||
alice <# ("", bobConn, SENT 5)
|
||||
bob <#= \case ("", c, Msg "how are you?") -> c == aliceConn; _ -> False
|
||||
bob #: ("13", aliceConn, "ACK 6") #> ("13", aliceConn, OK)
|
||||
bob #: ("14", aliceConn, "SEND 9\nhello too") #> ("14", aliceConn, MID 7)
|
||||
bob <# ("", aliceConn, SENT 7)
|
||||
bob #: ("13", aliceConn, "ACK 5") #> ("13", aliceConn, OK)
|
||||
bob #: ("14", aliceConn, "SEND F 9\nhello too") #> ("14", aliceConn, MID 6)
|
||||
bob <# ("", aliceConn, SENT 6)
|
||||
alice <#= \case ("", c, Msg "hello too") -> c == bobConn; _ -> False
|
||||
alice #: ("3a", bobConn, "ACK 7") #> ("3a", bobConn, OK)
|
||||
bob #: ("15", aliceConn, "SEND 9\nmessage 1") #> ("15", aliceConn, MID 8)
|
||||
bob <# ("", aliceConn, SENT 8)
|
||||
alice #: ("3a", bobConn, "ACK 6") #> ("3a", bobConn, OK)
|
||||
bob #: ("15", aliceConn, "SEND F 9\nmessage 1") #> ("15", aliceConn, MID 7)
|
||||
bob <# ("", aliceConn, SENT 7)
|
||||
alice <#= \case ("", c, Msg "message 1") -> c == bobConn; _ -> False
|
||||
alice #: ("4a", bobConn, "ACK 8") #> ("4a", bobConn, OK)
|
||||
alice #: ("4a", bobConn, "ACK 7") #> ("4a", bobConn, OK)
|
||||
alice #: ("5", bobConn, "OFF") #> ("5", bobConn, OK)
|
||||
bob #: ("17", aliceConn, "SEND 9\nmessage 3") #> ("17", aliceConn, MID 9)
|
||||
bob <# ("", aliceConn, MERR 9 (SMP AUTH))
|
||||
bob #: ("17", aliceConn, "SEND F 9\nmessage 3") #> ("17", aliceConn, MID 8)
|
||||
bob <# ("", aliceConn, MERR 8 (SMP AUTH))
|
||||
alice #: ("6", bobConn, "DEL") #> ("6", bobConn, OK)
|
||||
alice #:# "nothing else should be delivered to alice"
|
||||
|
||||
|
@ -196,10 +200,10 @@ testContactConnection _ alice bob tom = do
|
|||
alice <# ("", "bob", INFO "bob's connInfo 2")
|
||||
alice <# ("", "bob", CON)
|
||||
bob <# ("", "alice", CON)
|
||||
alice #: ("3", "bob", "SEND :hi") #> ("3", "bob", MID 5)
|
||||
alice <# ("", "bob", SENT 5)
|
||||
alice #: ("3", "bob", "SEND F :hi") #> ("3", "bob", MID 4)
|
||||
alice <# ("", "bob", SENT 4)
|
||||
bob <#= \case ("", "alice", Msg "hi") -> True; _ -> False
|
||||
bob #: ("13", "alice", "ACK 5") #> ("13", "alice", OK)
|
||||
bob #: ("13", "alice", "ACK 4") #> ("13", "alice", OK)
|
||||
|
||||
tom #: ("21", "alice", "JOIN " <> cReq' <> " 14\ntom's connInfo") #> ("21", "alice", OK)
|
||||
("", "alice_contact", Right (REQ aInvId' "tom's connInfo")) <- (alice <#:)
|
||||
|
@ -209,10 +213,10 @@ testContactConnection _ alice bob tom = do
|
|||
alice <# ("", "tom", INFO "tom's connInfo 2")
|
||||
alice <# ("", "tom", CON)
|
||||
tom <# ("", "alice", CON)
|
||||
alice #: ("5", "tom", "SEND :hi there") #> ("5", "tom", MID 5)
|
||||
alice <# ("", "tom", SENT 5)
|
||||
alice #: ("5", "tom", "SEND F :hi there") #> ("5", "tom", MID 4)
|
||||
alice <# ("", "tom", SENT 4)
|
||||
tom <#= \case ("", "alice", Msg "hi there") -> True; _ -> False
|
||||
tom #: ("23", "alice", "ACK 5") #> ("23", "alice", OK)
|
||||
tom #: ("23", "alice", "ACK 4") #> ("23", "alice", OK)
|
||||
|
||||
testContactConnRandomIds :: Transport c => TProxy c -> c -> c -> IO ()
|
||||
testContactConnRandomIds _ alice bob = do
|
||||
|
@ -232,10 +236,10 @@ testContactConnRandomIds _ alice bob = do
|
|||
alice <# ("", bobConn, CON)
|
||||
bob <# ("", aliceConn, CON)
|
||||
|
||||
alice #: ("3", bobConn, "SEND :hi") #> ("3", bobConn, MID 5)
|
||||
alice <# ("", bobConn, SENT 5)
|
||||
alice #: ("3", bobConn, "SEND F :hi") #> ("3", bobConn, MID 4)
|
||||
alice <# ("", bobConn, SENT 4)
|
||||
bob <#= \case ("", c, Msg "hi") -> c == aliceConn; _ -> False
|
||||
bob #: ("13", aliceConn, "ACK 5") #> ("13", aliceConn, OK)
|
||||
bob #: ("13", aliceConn, "ACK 4") #> ("13", aliceConn, OK)
|
||||
|
||||
testRejectContactRequest :: Transport c => TProxy c -> c -> c -> IO ()
|
||||
testRejectContactRequest _ alice bob = do
|
||||
|
@ -252,20 +256,20 @@ testRejectContactRequest _ alice bob = do
|
|||
testSubscription :: Transport c => TProxy c -> c -> c -> c -> IO ()
|
||||
testSubscription _ alice1 alice2 bob = do
|
||||
(alice1, "alice") `connect` (bob, "bob")
|
||||
bob #: ("12", "alice", "SEND 5\nhello") #> ("12", "alice", MID 5)
|
||||
bob <# ("", "alice", SENT 5)
|
||||
bob #: ("12", "alice", "SEND F 5\nhello") #> ("12", "alice", MID 4)
|
||||
bob <# ("", "alice", SENT 4)
|
||||
alice1 <#= \case ("", "bob", Msg "hello") -> True; _ -> False
|
||||
alice1 #: ("1", "bob", "ACK 5") #> ("1", "bob", OK)
|
||||
bob #: ("13", "alice", "SEND 11\nhello again") #> ("13", "alice", MID 6)
|
||||
bob <# ("", "alice", SENT 6)
|
||||
alice1 #: ("1", "bob", "ACK 4") #> ("1", "bob", OK)
|
||||
bob #: ("13", "alice", "SEND F 11\nhello again") #> ("13", "alice", MID 5)
|
||||
bob <# ("", "alice", SENT 5)
|
||||
alice1 <#= \case ("", "bob", Msg "hello again") -> True; _ -> False
|
||||
alice1 #: ("2", "bob", "ACK 6") #> ("2", "bob", OK)
|
||||
alice1 #: ("2", "bob", "ACK 5") #> ("2", "bob", OK)
|
||||
alice2 #: ("21", "bob", "SUB") #> ("21", "bob", OK)
|
||||
alice1 <# ("", "bob", END)
|
||||
bob #: ("14", "alice", "SEND 2\nhi") #> ("14", "alice", MID 7)
|
||||
bob <# ("", "alice", SENT 7)
|
||||
bob #: ("14", "alice", "SEND F 2\nhi") #> ("14", "alice", MID 6)
|
||||
bob <# ("", "alice", SENT 6)
|
||||
alice2 <#= \case ("", "bob", Msg "hi") -> True; _ -> False
|
||||
alice2 #: ("22", "bob", "ACK 7") #> ("22", "bob", OK)
|
||||
alice2 #: ("22", "bob", "ACK 6") #> ("22", "bob", OK)
|
||||
alice1 #:# "nothing else should be delivered to alice1"
|
||||
|
||||
testSubscrNotification :: Transport c => TProxy c -> (ThreadId, ThreadId) -> c -> IO ()
|
||||
|
@ -273,7 +277,7 @@ testSubscrNotification t (server, _) client = do
|
|||
client #: ("1", "conn1", "NEW INV") =#> \case ("1", "conn1", INV {}) -> True; _ -> False
|
||||
client #:# "nothing should be delivered to client before the server is killed"
|
||||
killThread server
|
||||
client <# ("", "conn1", DOWN)
|
||||
client <# ("", "", DOWN testSMPServer ["conn1"])
|
||||
withSmpServer (ATransport t) $
|
||||
client <# ("", "conn1", ERR (SMP AUTH)) -- this new server does not have the queue
|
||||
|
||||
|
@ -281,22 +285,23 @@ testMsgDeliveryServerRestart :: Transport c => TProxy c -> c -> c -> IO ()
|
|||
testMsgDeliveryServerRestart t alice bob = do
|
||||
withServer $ do
|
||||
connect (alice, "alice") (bob, "bob")
|
||||
bob #: ("1", "alice", "SEND 2\nhi") #> ("1", "alice", MID 5)
|
||||
bob <# ("", "alice", SENT 5)
|
||||
bob #: ("1", "alice", "SEND F 2\nhi") #> ("1", "alice", MID 4)
|
||||
bob <# ("", "alice", SENT 4)
|
||||
alice <#= \case ("", "bob", Msg "hi") -> True; _ -> False
|
||||
alice #: ("11", "bob", "ACK 5") #> ("11", "bob", OK)
|
||||
alice #: ("11", "bob", "ACK 4") #> ("11", "bob", OK)
|
||||
alice #:# "nothing else delivered before the server is killed"
|
||||
|
||||
alice <# ("", "bob", DOWN)
|
||||
bob #: ("2", "alice", "SEND 11\nhello again") #> ("2", "alice", MID 6)
|
||||
let server = (SMPServer "localhost" testPort2 testKeyHash)
|
||||
alice <# ("", "", DOWN server ["bob"])
|
||||
bob #: ("2", "alice", "SEND F 11\nhello again") #> ("2", "alice", MID 5)
|
||||
bob #:# "nothing else delivered before the server is restarted"
|
||||
alice #:# "nothing else delivered before the server is restarted"
|
||||
|
||||
withServer $ do
|
||||
bob <# ("", "alice", SENT 6)
|
||||
alice <# ("", "bob", UP)
|
||||
bob <# ("", "alice", SENT 5)
|
||||
alice <# ("", "", UP server ["bob"])
|
||||
alice <#= \case ("", "bob", Msg "hello again") -> True; _ -> False
|
||||
alice #: ("12", "bob", "ACK 6") #> ("12", "bob", OK)
|
||||
alice #: ("12", "bob", "ACK 5") #> ("12", "bob", OK)
|
||||
|
||||
removeFile testStoreLogFile
|
||||
where
|
||||
|
@ -309,9 +314,9 @@ testServerConnectionAfterError t _ = do
|
|||
withServer $ do
|
||||
connect (bob, "bob") (alice, "alice")
|
||||
|
||||
bob <# ("", "alice", DOWN)
|
||||
alice <# ("", "bob", DOWN)
|
||||
alice #: ("1", "bob", "SEND 5\nhello") #> ("1", "bob", MID 5)
|
||||
bob <# ("", "", DOWN server ["alice"])
|
||||
alice <# ("", "", DOWN server ["bob"])
|
||||
alice #: ("1", "bob", "SEND F 5\nhello") #> ("1", "bob", MID 4)
|
||||
alice #:# "nothing else delivered before the server is restarted"
|
||||
bob #:# "nothing else delivered before the server is restarted"
|
||||
|
||||
|
@ -320,19 +325,20 @@ testServerConnectionAfterError t _ = do
|
|||
bob #: ("1", "alice", "SUB") #> ("1", "alice", ERR (BROKER NETWORK))
|
||||
alice #: ("1", "bob", "SUB") #> ("1", "bob", ERR (BROKER NETWORK))
|
||||
withServer $ do
|
||||
alice <#= \case ("", "bob", cmd) -> cmd == UP || cmd == SENT 5; _ -> False
|
||||
alice <#= \case ("", "bob", cmd) -> cmd == UP || cmd == SENT 5; _ -> False
|
||||
bob <# ("", "alice", UP)
|
||||
alice <#= \case ("", "bob", SENT 4) -> True; ("", "", UP s ["bob"]) -> s == server; _ -> False
|
||||
alice <#= \case ("", "bob", SENT 4) -> True; ("", "", UP s ["bob"]) -> s == server; _ -> False
|
||||
bob <# ("", "", UP server ["alice"])
|
||||
bob <#= \case ("", "alice", Msg "hello") -> True; _ -> False
|
||||
bob #: ("2", "alice", "ACK 5") #> ("2", "alice", OK)
|
||||
alice #: ("1", "bob", "SEND 11\nhello again") #> ("1", "bob", MID 6)
|
||||
alice <# ("", "bob", SENT 6)
|
||||
bob #: ("2", "alice", "ACK 4") #> ("2", "alice", OK)
|
||||
alice #: ("1", "bob", "SEND F 11\nhello again") #> ("1", "bob", MID 5)
|
||||
alice <# ("", "bob", SENT 5)
|
||||
bob <#= \case ("", "alice", Msg "hello again") -> True; _ -> False
|
||||
|
||||
removeFile testStoreLogFile
|
||||
removeFile testDB
|
||||
removeFile testDB2
|
||||
where
|
||||
server = SMPServer "localhost" testPort2 testKeyHash
|
||||
withServer test' = withSmpServerStoreLogOn (ATransport t) testPort2 (const test') `shouldReturn` ()
|
||||
withAgent1 = withAgent agentTestPort testDB
|
||||
withAgent2 = withAgent agentTestPort2 testDB2
|
||||
|
@ -341,17 +347,18 @@ testServerConnectionAfterError t _ = do
|
|||
|
||||
testMsgDeliveryAgentRestart :: Transport c => TProxy c -> c -> IO ()
|
||||
testMsgDeliveryAgentRestart t bob = do
|
||||
let server = SMPServer "localhost" testPort2 testKeyHash
|
||||
withAgent $ \alice -> do
|
||||
withServer $ do
|
||||
connect (bob, "bob") (alice, "alice")
|
||||
alice #: ("1", "bob", "SEND 5\nhello") #> ("1", "bob", MID 5)
|
||||
alice <# ("", "bob", SENT 5)
|
||||
alice #: ("1", "bob", "SEND F 5\nhello") #> ("1", "bob", MID 4)
|
||||
alice <# ("", "bob", SENT 4)
|
||||
bob <#= \case ("", "alice", Msg "hello") -> True; _ -> False
|
||||
bob #: ("11", "alice", "ACK 5") #> ("11", "alice", OK)
|
||||
bob #: ("11", "alice", "ACK 4") #> ("11", "alice", OK)
|
||||
bob #:# "nothing else delivered before the server is down"
|
||||
|
||||
bob <# ("", "alice", DOWN)
|
||||
alice #: ("2", "bob", "SEND 11\nhello again") #> ("2", "bob", MID 6)
|
||||
bob <# ("", "", DOWN server ["alice"])
|
||||
alice #: ("2", "bob", "SEND F 11\nhello again") #> ("2", "bob", MID 5)
|
||||
alice #:# "nothing else delivered before the server is restarted"
|
||||
bob #:# "nothing else delivered before the server is restarted"
|
||||
|
||||
|
@ -361,11 +368,11 @@ testMsgDeliveryAgentRestart t bob = do
|
|||
alice <#= \case
|
||||
(corrId, "bob", cmd) ->
|
||||
(corrId == "3" && cmd == OK)
|
||||
|| (corrId == "" && cmd == SENT 6)
|
||||
|| (corrId == "" && cmd == SENT 5)
|
||||
_ -> False
|
||||
bob <# ("", "alice", UP)
|
||||
bob <# ("", "", UP server ["alice"])
|
||||
bob <#= \case ("", "alice", Msg "hello again") -> True; _ -> False
|
||||
bob #: ("12", "alice", "ACK 6") #> ("12", "alice", OK)
|
||||
bob #: ("12", "alice", "ACK 5") #> ("12", "alice", OK)
|
||||
|
||||
removeFile testStoreLogFile
|
||||
removeFile testDB
|
||||
|
@ -389,15 +396,15 @@ testConcurrentMsgDelivery _ alice bob = do
|
|||
|
||||
-- the first connection should not be blocked by the second one
|
||||
sendMessage (alice, "alice") (bob, "bob") "hello"
|
||||
-- alice #: ("2", "bob", "SEND :hello") #> ("2", "bob", MID 1)
|
||||
-- alice #: ("2", "bob", "SEND F :hello") #> ("2", "bob", MID 1)
|
||||
-- alice <# ("", "bob", SENT 1)
|
||||
-- bob <#= \case ("", "alice", Msg "hello") -> True; _ -> False
|
||||
-- bob #: ("12", "alice", "ACK 1") #> ("12", "alice", OK)
|
||||
bob #: ("14", "alice", "SEND 9\nhello too") #> ("14", "alice", MID 6)
|
||||
bob <# ("", "alice", SENT 6)
|
||||
bob #: ("14", "alice", "SEND F 9\nhello too") #> ("14", "alice", MID 5)
|
||||
bob <# ("", "alice", SENT 5)
|
||||
-- if delivery is blocked it won't go further
|
||||
alice <#= \case ("", "bob", Msg "hello too") -> True; _ -> False
|
||||
alice #: ("3", "bob", "ACK 6") #> ("3", "bob", OK)
|
||||
alice #: ("3", "bob", "ACK 5") #> ("3", "bob", OK)
|
||||
|
||||
testMsgDeliveryQuotaExceeded :: Transport c => TProxy c -> c -> c -> IO ()
|
||||
testMsgDeliveryQuotaExceeded _ alice bob = do
|
||||
|
@ -406,13 +413,13 @@ testMsgDeliveryQuotaExceeded _ alice bob = do
|
|||
forM_ [1 .. 4 :: Int] $ \i -> do
|
||||
let corrId = bshow i
|
||||
msg = "message " <> bshow i
|
||||
(_, "bob", Right (MID mId)) <- alice #: (corrId, "bob", "SEND :" <> msg)
|
||||
(_, "bob", Right (MID mId)) <- alice #: (corrId, "bob", "SEND F :" <> msg)
|
||||
alice <#= \case ("", "bob", SENT m) -> m == mId; _ -> False
|
||||
(_, "bob", Right (MID _)) <- alice #: ("5", "bob", "SEND :over quota")
|
||||
(_, "bob", Right (MID _)) <- alice #: ("5", "bob", "SEND F :over quota")
|
||||
|
||||
alice #: ("1", "bob2", "SEND :hello") #> ("1", "bob2", MID 5)
|
||||
alice #: ("1", "bob2", "SEND F :hello") #> ("1", "bob2", MID 4)
|
||||
-- if delivery is blocked it won't go further
|
||||
alice <# ("", "bob2", SENT 5)
|
||||
alice <# ("", "bob2", SENT 4)
|
||||
|
||||
connect :: forall c. Transport c => (c, ByteString) -> (c, ByteString) -> IO ()
|
||||
connect (h1, name1) (h2, name2) = do
|
||||
|
@ -427,10 +434,10 @@ connect (h1, name1) (h2, name2) = do
|
|||
|
||||
sendMessage :: Transport c => (c, ConnId) -> (c, ConnId) -> ByteString -> IO ()
|
||||
sendMessage (h1, name1) (h2, name2) msg = do
|
||||
("m1", name2', Right (MID mId)) <- h1 #: ("m1", name2, "SEND :" <> msg)
|
||||
("m1", name2', Right (MID mId)) <- h1 #: ("m1", name2, "SEND F :" <> msg)
|
||||
name2' `shouldBe` name2
|
||||
h1 <#= \case ("", n, SENT m) -> n == name2 && m == mId; _ -> False
|
||||
("", name1', Right (MSG MsgMeta {recipient = (msgId', _)} msg')) <- (h2 <#:)
|
||||
("", name1', Right (MSG MsgMeta {recipient = (msgId', _)} _ msg')) <- (h2 <#:)
|
||||
name1' `shouldBe` name1
|
||||
msg' `shouldBe` msg
|
||||
h2 #: ("m2", name1, "ACK " <> bshow msgId') =#> \case ("m2", n, OK) -> n == name1; _ -> False
|
||||
|
|
|
@ -11,7 +11,7 @@ import Simplex.Messaging.Agent.Protocol
|
|||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Crypto.Ratchet
|
||||
import Simplex.Messaging.Encoding.String
|
||||
import Simplex.Messaging.Protocol (smpClientVRange)
|
||||
import Simplex.Messaging.Protocol (ProtocolServer (..), smpClientVRange)
|
||||
import Simplex.Messaging.Version
|
||||
import Test.Hspec
|
||||
|
||||
|
@ -19,12 +19,7 @@ uri :: String
|
|||
uri = "smp.simplex.im"
|
||||
|
||||
srv :: SMPServer
|
||||
srv =
|
||||
SMPServer
|
||||
{ host = "smp.simplex.im",
|
||||
port = "5223",
|
||||
keyHash = C.KeyHash "\215m\248\251"
|
||||
}
|
||||
srv = SMPServer "smp.simplex.im" "5223" (C.KeyHash "\215m\248\251")
|
||||
|
||||
queue :: SMPQueueUri
|
||||
queue =
|
||||
|
@ -48,7 +43,7 @@ connReqData :: ConnReqUriData
|
|||
connReqData =
|
||||
ConnReqUriData
|
||||
{ crScheme = simplexChat,
|
||||
crAgentVRange = smpAgentVRange,
|
||||
crAgentVRange = mkVersionRange 1 1,
|
||||
crSmpQueues = [queue]
|
||||
}
|
||||
|
||||
|
@ -70,7 +65,7 @@ connectionRequest12 :: AConnectionRequestUri
|
|||
connectionRequest12 =
|
||||
ACR SCMInvitation $
|
||||
CRInvitationUri
|
||||
connReqData {crAgentVRange = mkVersionRange 1 2, crSmpQueues = [queue, queue]}
|
||||
connReqData {crAgentVRange = supportedSMPAgentVRange, crSmpQueues = [queue, queue]}
|
||||
testE2ERatchetParams13
|
||||
|
||||
connectionRequestTests :: Spec
|
||||
|
|
|
@ -1,24 +1,40 @@
|
|||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE DuplicateRecordFields #-}
|
||||
{-# LANGUAGE GADTs #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE PatternSynonyms #-}
|
||||
{-# OPTIONS_GHC -fno-warn-incomplete-uni-patterns #-}
|
||||
|
||||
module AgentTests.FunctionalAPITests (functionalAPITests) where
|
||||
module AgentTests.FunctionalAPITests
|
||||
( functionalAPITests,
|
||||
makeConnection,
|
||||
get,
|
||||
(##>),
|
||||
(=##>),
|
||||
pattern Msg,
|
||||
)
|
||||
where
|
||||
|
||||
import Control.Concurrent (threadDelay)
|
||||
import Control.Monad.Except (ExceptT, runExceptT)
|
||||
import Control.Monad.IO.Unlift
|
||||
import Data.Time.Clock.System (SystemTime (..), getSystemTime)
|
||||
import SMPAgentClient
|
||||
import SMPClient (testPort, withSmpServer, withSmpServerStoreLogOn)
|
||||
import SMPClient (cfg, testPort, withSmpServer, withSmpServerConfigOn, withSmpServerStoreLogOn, withSmpServerStoreMsgLogOn)
|
||||
import Simplex.Messaging.Agent
|
||||
import Simplex.Messaging.Agent.Env.SQLite (AgentConfig (..))
|
||||
import Simplex.Messaging.Agent.Protocol
|
||||
import Simplex.Messaging.Client (ProtocolClientConfig (..))
|
||||
import Simplex.Messaging.Protocol (ErrorType (..), MsgBody)
|
||||
import qualified Simplex.Messaging.Protocol as SMP
|
||||
import Simplex.Messaging.Server.Env.STM (ServerConfig (..))
|
||||
import Simplex.Messaging.Server.Expiration
|
||||
import Simplex.Messaging.Transport (ATransport (..))
|
||||
import System.Timeout
|
||||
import Simplex.Messaging.Version
|
||||
import Test.Hspec
|
||||
import UnliftIO.STM
|
||||
import UnliftIO
|
||||
|
||||
(##>) :: MonadIO m => m (ATransmission 'Agent) -> ATransmission 'Agent -> m ()
|
||||
a ##> t = a >>= \t' -> liftIO (t' `shouldBe` t)
|
||||
|
@ -30,13 +46,36 @@ get :: MonadIO m => AgentClient -> m (ATransmission 'Agent)
|
|||
get c = atomically (readTBQueue $ subQ c)
|
||||
|
||||
pattern Msg :: MsgBody -> ACommand 'Agent
|
||||
pattern Msg msgBody <- MSG MsgMeta {integrity = MsgOk} msgBody
|
||||
pattern Msg msgBody <- MSG MsgMeta {integrity = MsgOk} _ msgBody
|
||||
|
||||
smpCfgV1 :: ProtocolClientConfig
|
||||
smpCfgV1 = (smpCfg agentCfg) {smpServerVRange = mkVersionRange 1 1}
|
||||
|
||||
agentCfgV1 :: AgentConfig
|
||||
agentCfgV1 = agentCfg {smpAgentVersion = 1, smpAgentVRange = mkVersionRange 1 1, smpCfg = smpCfgV1}
|
||||
|
||||
functionalAPITests :: ATransport -> Spec
|
||||
functionalAPITests t = do
|
||||
describe "Establishing duplex connection" $
|
||||
it "should connect via one server using SMP agent clients" $
|
||||
withSmpServer t testAgentClient
|
||||
describe "Duplex connection between agent versions 1 and 2" $ do
|
||||
it "should connect agent v1 to v1" $
|
||||
withSmpServer t testAgentClientV1toV1
|
||||
it "should connect agent v1 to v2" $
|
||||
withSmpServer t testAgentClientV1toV2
|
||||
it "should connect agent v2 to v1" $
|
||||
withSmpServer t testAgentClientV2toV1
|
||||
describe "Establish duplex connection via contact address" $
|
||||
it "should connect via one server using SMP agent clients" $
|
||||
withSmpServer t testAgentClientContact
|
||||
describe "Duplex connection via contact address between agent versions 1 and 2" $ do
|
||||
it "should connect agent v1 to v1" $
|
||||
withSmpServer t testAgentClientContactV1toV1
|
||||
it "should connect agent v1 to v2" $
|
||||
withSmpServer t testAgentClientContactV1toV2
|
||||
it "should connect agent v2 to v1" $
|
||||
withSmpServer t testAgentClientContactV2toV1
|
||||
describe "Establishing connection asynchronously" $ do
|
||||
it "should connect with initiating client going offline" $
|
||||
withSmpServer t testAsyncInitiatingOffline
|
||||
|
@ -48,11 +87,72 @@ functionalAPITests t = do
|
|||
testAsyncServerOffline t
|
||||
it "should notify after HELLO timeout" $
|
||||
withSmpServer t testAsyncHelloTimeout
|
||||
describe "Duplicate message delivery" $
|
||||
it "should deliver messages to the user once, even if repeat delivery is made by the server (no ACK)" $
|
||||
testDuplicateMessage t
|
||||
describe "Inactive client disconnection" $ do
|
||||
it "should disconnect clients if it was inactive longer than TTL" $
|
||||
testInactiveClientDisconnected t
|
||||
it "should NOT disconnect active clients" $
|
||||
testActiveClientNotDisconnected t
|
||||
describe "Suspending agent" $ do
|
||||
it "should update client when agent is suspended" $
|
||||
withSmpServer t testSuspendingAgent
|
||||
it "should complete sending messages when agent is suspended" $
|
||||
testSuspendingAgentCompleteSending t
|
||||
it "should suspend agent on timeout, even if pending messages not sent" $
|
||||
testSuspendingAgentTimeout t
|
||||
|
||||
testAgentClient :: IO ()
|
||||
testAgentClient = do
|
||||
alice <- getSMPAgentClient cfg
|
||||
bob <- getSMPAgentClient cfg {dbFile = testDB2}
|
||||
alice <- getSMPAgentClient agentCfg initAgentServers
|
||||
bob <- getSMPAgentClient agentCfg {dbFile = testDB2} initAgentServers
|
||||
runAgentClientTest alice bob 3
|
||||
|
||||
testAgentClientV1toV1 :: IO ()
|
||||
testAgentClientV1toV1 = do
|
||||
alice <- getSMPAgentClient agentCfgV1 initAgentServers
|
||||
bob <- getSMPAgentClient agentCfgV1 {dbFile = testDB2} initAgentServers
|
||||
runAgentClientTest alice bob 4
|
||||
|
||||
testAgentClientV1toV2 :: IO ()
|
||||
testAgentClientV1toV2 = do
|
||||
alice <- getSMPAgentClient agentCfgV1 initAgentServers
|
||||
bob <- getSMPAgentClient agentCfg {dbFile = testDB2} initAgentServers
|
||||
runAgentClientTest alice bob 4
|
||||
|
||||
testAgentClientV2toV1 :: IO ()
|
||||
testAgentClientV2toV1 = do
|
||||
alice <- getSMPAgentClient agentCfg initAgentServers
|
||||
bob <- getSMPAgentClient agentCfgV1 {dbFile = testDB2} initAgentServers
|
||||
runAgentClientTest alice bob 4
|
||||
|
||||
testAgentClientContact :: IO ()
|
||||
testAgentClientContact = do
|
||||
alice <- getSMPAgentClient agentCfg initAgentServers
|
||||
bob <- getSMPAgentClient agentCfg {dbFile = testDB2} initAgentServers
|
||||
runAgentClientContactTest alice bob 3
|
||||
|
||||
testAgentClientContactV1toV1 :: IO ()
|
||||
testAgentClientContactV1toV1 = do
|
||||
alice <- getSMPAgentClient agentCfgV1 initAgentServers
|
||||
bob <- getSMPAgentClient agentCfgV1 {dbFile = testDB2} initAgentServers
|
||||
runAgentClientContactTest alice bob 4
|
||||
|
||||
testAgentClientContactV1toV2 :: IO ()
|
||||
testAgentClientContactV1toV2 = do
|
||||
alice <- getSMPAgentClient agentCfg initAgentServers
|
||||
bob <- getSMPAgentClient agentCfgV1 {dbFile = testDB2} initAgentServers
|
||||
runAgentClientContactTest alice bob 4
|
||||
|
||||
testAgentClientContactV2toV1 :: IO ()
|
||||
testAgentClientContactV2toV1 = do
|
||||
alice <- getSMPAgentClient agentCfgV1 initAgentServers
|
||||
bob <- getSMPAgentClient agentCfg {dbFile = testDB2} initAgentServers
|
||||
runAgentClientContactTest alice bob 4
|
||||
|
||||
runAgentClientTest :: AgentClient -> AgentClient -> AgentMsgId -> IO ()
|
||||
runAgentClientTest alice bob baseId = do
|
||||
Right () <- runExceptT $ do
|
||||
(bobId, qInfo) <- createConnection alice SCMInvitation
|
||||
aliceId <- joinConnection bob qInfo "bob's connInfo"
|
||||
|
@ -61,47 +161,87 @@ testAgentClient = do
|
|||
get alice ##> ("", bobId, CON)
|
||||
get bob ##> ("", aliceId, INFO "alice's connInfo")
|
||||
get bob ##> ("", aliceId, CON)
|
||||
-- message IDs 1 to 4 get assigned to control messages, so first MSG is assigned ID 5
|
||||
5 <- sendMessage alice bobId "hello"
|
||||
get alice ##> ("", bobId, SENT 5)
|
||||
6 <- sendMessage alice bobId "how are you?"
|
||||
get alice ##> ("", bobId, SENT 6)
|
||||
-- message IDs 1 to 3 (or 1 to 4 in v1) get assigned to control messages, so first MSG is assigned ID 4
|
||||
1 <- msgId <$> sendMessage alice bobId SMP.noMsgFlags "hello"
|
||||
get alice ##> ("", bobId, SENT $ baseId + 1)
|
||||
2 <- msgId <$> sendMessage alice bobId SMP.noMsgFlags "how are you?"
|
||||
get alice ##> ("", bobId, SENT $ baseId + 2)
|
||||
get bob =##> \case ("", c, Msg "hello") -> c == aliceId; _ -> False
|
||||
ackMessage bob aliceId 5
|
||||
ackMessage bob aliceId $ baseId + 1
|
||||
get bob =##> \case ("", c, Msg "how are you?") -> c == aliceId; _ -> False
|
||||
ackMessage bob aliceId 6
|
||||
7 <- sendMessage bob aliceId "hello too"
|
||||
get bob ##> ("", aliceId, SENT 7)
|
||||
8 <- sendMessage bob aliceId "message 1"
|
||||
get bob ##> ("", aliceId, SENT 8)
|
||||
ackMessage bob aliceId $ baseId + 2
|
||||
3 <- msgId <$> sendMessage bob aliceId SMP.noMsgFlags "hello too"
|
||||
get bob ##> ("", aliceId, SENT $ baseId + 3)
|
||||
4 <- msgId <$> sendMessage bob aliceId SMP.noMsgFlags "message 1"
|
||||
get bob ##> ("", aliceId, SENT $ baseId + 4)
|
||||
get alice =##> \case ("", c, Msg "hello too") -> c == bobId; _ -> False
|
||||
ackMessage alice bobId 7
|
||||
ackMessage alice bobId $ baseId + 3
|
||||
get alice =##> \case ("", c, Msg "message 1") -> c == bobId; _ -> False
|
||||
ackMessage alice bobId 8
|
||||
ackMessage alice bobId $ baseId + 4
|
||||
suspendConnection alice bobId
|
||||
9 <- sendMessage bob aliceId "message 2"
|
||||
get bob ##> ("", aliceId, MERR 9 (SMP AUTH))
|
||||
5 <- msgId <$> sendMessage bob aliceId SMP.noMsgFlags "message 2"
|
||||
get bob ##> ("", aliceId, MERR (baseId + 5) (SMP AUTH))
|
||||
deleteConnection alice bobId
|
||||
liftIO $ noMessages alice "nothing else should be delivered to alice"
|
||||
pure ()
|
||||
where
|
||||
noMessages :: AgentClient -> String -> Expectation
|
||||
noMessages c err = tryGet `shouldReturn` ()
|
||||
where
|
||||
tryGet =
|
||||
10000 `timeout` get c >>= \case
|
||||
Just _ -> error err
|
||||
_ -> return ()
|
||||
msgId = subtract baseId
|
||||
|
||||
runAgentClientContactTest :: AgentClient -> AgentClient -> AgentMsgId -> IO ()
|
||||
runAgentClientContactTest alice bob baseId = do
|
||||
Right () <- runExceptT $ do
|
||||
(_, qInfo) <- createConnection alice SCMContact
|
||||
aliceId <- joinConnection bob qInfo "bob's connInfo"
|
||||
("", _, REQ invId "bob's connInfo") <- get alice
|
||||
bobId <- acceptContact alice invId "alice's connInfo"
|
||||
("", _, CONF confId "alice's connInfo") <- get bob
|
||||
allowConnection bob aliceId confId "bob's connInfo"
|
||||
get alice ##> ("", bobId, INFO "bob's connInfo")
|
||||
get alice ##> ("", bobId, CON)
|
||||
get bob ##> ("", aliceId, CON)
|
||||
-- message IDs 1 to 3 (or 1 to 4 in v1) get assigned to control messages, so first MSG is assigned ID 4
|
||||
1 <- msgId <$> sendMessage alice bobId SMP.noMsgFlags "hello"
|
||||
get alice ##> ("", bobId, SENT $ baseId + 1)
|
||||
2 <- msgId <$> sendMessage alice bobId SMP.noMsgFlags "how are you?"
|
||||
get alice ##> ("", bobId, SENT $ baseId + 2)
|
||||
get bob =##> \case ("", c, Msg "hello") -> c == aliceId; _ -> False
|
||||
ackMessage bob aliceId $ baseId + 1
|
||||
get bob =##> \case ("", c, Msg "how are you?") -> c == aliceId; _ -> False
|
||||
ackMessage bob aliceId $ baseId + 2
|
||||
3 <- msgId <$> sendMessage bob aliceId SMP.noMsgFlags "hello too"
|
||||
get bob ##> ("", aliceId, SENT $ baseId + 3)
|
||||
4 <- msgId <$> sendMessage bob aliceId SMP.noMsgFlags "message 1"
|
||||
get bob ##> ("", aliceId, SENT $ baseId + 4)
|
||||
get alice =##> \case ("", c, Msg "hello too") -> c == bobId; _ -> False
|
||||
ackMessage alice bobId $ baseId + 3
|
||||
get alice =##> \case ("", c, Msg "message 1") -> c == bobId; _ -> False
|
||||
ackMessage alice bobId $ baseId + 4
|
||||
suspendConnection alice bobId
|
||||
5 <- msgId <$> sendMessage bob aliceId SMP.noMsgFlags "message 2"
|
||||
get bob ##> ("", aliceId, MERR (baseId + 5) (SMP AUTH))
|
||||
deleteConnection alice bobId
|
||||
liftIO $ noMessages alice "nothing else should be delivered to alice"
|
||||
pure ()
|
||||
where
|
||||
msgId = subtract baseId
|
||||
|
||||
noMessages :: AgentClient -> String -> Expectation
|
||||
noMessages c err = tryGet `shouldReturn` ()
|
||||
where
|
||||
tryGet =
|
||||
10000 `timeout` get c >>= \case
|
||||
Just _ -> error err
|
||||
_ -> return ()
|
||||
|
||||
testAsyncInitiatingOffline :: IO ()
|
||||
testAsyncInitiatingOffline = do
|
||||
alice <- getSMPAgentClient cfg
|
||||
bob <- getSMPAgentClient cfg {dbFile = testDB2}
|
||||
alice <- getSMPAgentClient agentCfg initAgentServers
|
||||
bob <- getSMPAgentClient agentCfg {dbFile = testDB2} initAgentServers
|
||||
Right () <- runExceptT $ do
|
||||
(bobId, cReq) <- createConnection alice SCMInvitation
|
||||
disconnectAgentClient alice
|
||||
aliceId <- joinConnection bob cReq "bob's connInfo"
|
||||
alice' <- liftIO $ getSMPAgentClient cfg
|
||||
alice' <- liftIO $ getSMPAgentClient agentCfg initAgentServers
|
||||
subscribeConnection alice' bobId
|
||||
("", _, CONF confId "bob's connInfo") <- get alice'
|
||||
allowConnection alice' bobId confId "alice's connInfo"
|
||||
|
@ -113,15 +253,15 @@ testAsyncInitiatingOffline = do
|
|||
|
||||
testAsyncJoiningOfflineBeforeActivation :: IO ()
|
||||
testAsyncJoiningOfflineBeforeActivation = do
|
||||
alice <- getSMPAgentClient cfg
|
||||
bob <- getSMPAgentClient cfg {dbFile = testDB2}
|
||||
alice <- getSMPAgentClient agentCfg initAgentServers
|
||||
bob <- getSMPAgentClient agentCfg {dbFile = testDB2} initAgentServers
|
||||
Right () <- runExceptT $ do
|
||||
(bobId, qInfo) <- createConnection alice SCMInvitation
|
||||
aliceId <- joinConnection bob qInfo "bob's connInfo"
|
||||
disconnectAgentClient bob
|
||||
("", _, CONF confId "bob's connInfo") <- get alice
|
||||
allowConnection alice bobId confId "alice's connInfo"
|
||||
bob' <- liftIO $ getSMPAgentClient cfg {dbFile = testDB2}
|
||||
bob' <- liftIO $ getSMPAgentClient agentCfg {dbFile = testDB2} initAgentServers
|
||||
subscribeConnection bob' aliceId
|
||||
get alice ##> ("", bobId, CON)
|
||||
get bob' ##> ("", aliceId, INFO "alice's connInfo")
|
||||
|
@ -131,18 +271,18 @@ testAsyncJoiningOfflineBeforeActivation = do
|
|||
|
||||
testAsyncBothOffline :: IO ()
|
||||
testAsyncBothOffline = do
|
||||
alice <- getSMPAgentClient cfg
|
||||
bob <- getSMPAgentClient cfg {dbFile = testDB2}
|
||||
alice <- getSMPAgentClient agentCfg initAgentServers
|
||||
bob <- getSMPAgentClient agentCfg {dbFile = testDB2} initAgentServers
|
||||
Right () <- runExceptT $ do
|
||||
(bobId, cReq) <- createConnection alice SCMInvitation
|
||||
disconnectAgentClient alice
|
||||
aliceId <- joinConnection bob cReq "bob's connInfo"
|
||||
disconnectAgentClient bob
|
||||
alice' <- liftIO $ getSMPAgentClient cfg
|
||||
alice' <- liftIO $ getSMPAgentClient agentCfg initAgentServers
|
||||
subscribeConnection alice' bobId
|
||||
("", _, CONF confId "bob's connInfo") <- get alice'
|
||||
allowConnection alice' bobId confId "alice's connInfo"
|
||||
bob' <- liftIO $ getSMPAgentClient cfg {dbFile = testDB2}
|
||||
bob' <- liftIO $ getSMPAgentClient agentCfg {dbFile = testDB2} initAgentServers
|
||||
subscribeConnection bob' aliceId
|
||||
get alice' ##> ("", bobId, CON)
|
||||
get bob' ##> ("", aliceId, INFO "alice's connInfo")
|
||||
|
@ -152,19 +292,22 @@ testAsyncBothOffline = do
|
|||
|
||||
testAsyncServerOffline :: ATransport -> IO ()
|
||||
testAsyncServerOffline t = do
|
||||
alice <- getSMPAgentClient cfg
|
||||
bob <- getSMPAgentClient cfg {dbFile = testDB2}
|
||||
alice <- getSMPAgentClient agentCfg initAgentServers
|
||||
bob <- getSMPAgentClient agentCfg {dbFile = testDB2} initAgentServers
|
||||
-- create connection and shutdown the server
|
||||
Right (bobId, cReq) <- withSmpServerStoreLogOn t testPort $ \_ ->
|
||||
runExceptT $ createConnection alice SCMInvitation
|
||||
-- connection fails
|
||||
Left (BROKER NETWORK) <- runExceptT $ joinConnection bob cReq "bob's connInfo"
|
||||
("", bobId1, DOWN) <- get alice
|
||||
bobId1 `shouldBe` bobId
|
||||
("", "", DOWN srv conns) <- get alice
|
||||
srv `shouldBe` testSMPServer
|
||||
conns `shouldBe` [bobId]
|
||||
-- connection succeeds after server start
|
||||
Right () <- withSmpServerStoreLogOn t testPort $ \_ -> runExceptT $ do
|
||||
("", bobId2, UP) <- get alice
|
||||
liftIO $ bobId2 `shouldBe` bobId
|
||||
("", "", UP srv1 conns1) <- get alice
|
||||
liftIO $ do
|
||||
srv1 `shouldBe` testSMPServer
|
||||
conns1 `shouldBe` [bobId]
|
||||
aliceId <- joinConnection bob cReq "bob's connInfo"
|
||||
("", _, CONF confId "bob's connInfo") <- get alice
|
||||
allowConnection alice bobId confId "alice's connInfo"
|
||||
|
@ -176,8 +319,9 @@ testAsyncServerOffline t = do
|
|||
|
||||
testAsyncHelloTimeout :: IO ()
|
||||
testAsyncHelloTimeout = do
|
||||
alice <- getSMPAgentClient cfg
|
||||
bob <- getSMPAgentClient cfg {dbFile = testDB2, helloTimeout = 1}
|
||||
-- this test would only work if any of the agent is v1, there is no HELLO timeout in v2
|
||||
alice <- getSMPAgentClient agentCfgV1 initAgentServers
|
||||
bob <- getSMPAgentClient agentCfg {dbFile = testDB2, helloTimeout = 1} initAgentServers
|
||||
Right () <- runExceptT $ do
|
||||
(_, cReq) <- createConnection alice SCMInvitation
|
||||
disconnectAgentClient alice
|
||||
|
@ -185,13 +329,187 @@ testAsyncHelloTimeout = do
|
|||
get bob ##> ("", aliceId, ERR $ CONN NOT_ACCEPTED)
|
||||
pure ()
|
||||
|
||||
testDuplicateMessage :: ATransport -> IO ()
|
||||
testDuplicateMessage t = do
|
||||
alice <- getSMPAgentClient agentCfg initAgentServers
|
||||
bob <- getSMPAgentClient agentCfg {dbFile = testDB2} initAgentServers
|
||||
(aliceId, bobId, bob1) <- withSmpServerStoreMsgLogOn t testPort $ \_ -> do
|
||||
Right (aliceId, bobId) <- runExceptT $ makeConnection alice bob
|
||||
Right () <- runExceptT $ do
|
||||
4 <- sendMessage alice bobId SMP.noMsgFlags "hello"
|
||||
get alice ##> ("", bobId, SENT 4)
|
||||
get bob =##> \case ("", c, Msg "hello") -> c == aliceId; _ -> False
|
||||
disconnectAgentClient bob
|
||||
|
||||
-- if the agent user did not send ACK, the message will be delivered again
|
||||
bob1 <- getSMPAgentClient agentCfg {dbFile = testDB2} initAgentServers
|
||||
Right () <- runExceptT $ do
|
||||
subscribeConnection bob1 aliceId
|
||||
get bob1 =##> \case ("", c, Msg "hello") -> c == aliceId; _ -> False
|
||||
ackMessage bob1 aliceId 4
|
||||
5 <- sendMessage alice bobId SMP.noMsgFlags "hello 2"
|
||||
get alice ##> ("", bobId, SENT 5)
|
||||
get bob1 =##> \case ("", c, Msg "hello 2") -> c == aliceId; _ -> False
|
||||
|
||||
pure (aliceId, bobId, bob1)
|
||||
|
||||
get alice =##> \case ("", "", DOWN _ [c]) -> c == bobId; _ -> False
|
||||
get bob1 =##> \case ("", "", DOWN _ [c]) -> c == aliceId; _ -> False
|
||||
-- commenting two lines below and uncommenting further two lines would also pass,
|
||||
-- it is the scenario tested above, when the message was not acknowledged by the user
|
||||
threadDelay 200000
|
||||
Left (BROKER TIMEOUT) <- runExceptT $ ackMessage bob1 aliceId 5
|
||||
|
||||
disconnectAgentClient alice
|
||||
disconnectAgentClient bob1
|
||||
|
||||
alice2 <- getSMPAgentClient agentCfg initAgentServers
|
||||
bob2 <- getSMPAgentClient agentCfg {dbFile = testDB2} initAgentServers
|
||||
|
||||
withSmpServerStoreMsgLogOn t testPort $ \_ -> do
|
||||
Right () <- runExceptT $ do
|
||||
subscribeConnection bob2 aliceId
|
||||
subscribeConnection alice2 bobId
|
||||
-- get bob2 =##> \case ("", c, Msg "hello 2") -> c == aliceId; _ -> False
|
||||
-- ackMessage bob2 aliceId 5
|
||||
-- message 2 is not delivered again, even though it was delivered to the agent
|
||||
6 <- sendMessage alice2 bobId SMP.noMsgFlags "hello 3"
|
||||
get alice2 ##> ("", bobId, SENT 6)
|
||||
get bob2 =##> \case ("", c, Msg "hello 3") -> c == aliceId; _ -> False
|
||||
pure ()
|
||||
|
||||
makeConnection :: AgentClient -> AgentClient -> ExceptT AgentErrorType IO (ConnId, ConnId)
|
||||
makeConnection alice bob = do
|
||||
(bobId, qInfo) <- createConnection alice SCMInvitation
|
||||
aliceId <- joinConnection bob qInfo "bob's connInfo"
|
||||
("", _, CONF confId "bob's connInfo") <- get alice
|
||||
allowConnection alice bobId confId "alice's connInfo"
|
||||
get alice ##> ("", bobId, CON)
|
||||
get bob ##> ("", aliceId, INFO "alice's connInfo")
|
||||
get bob ##> ("", aliceId, CON)
|
||||
pure (aliceId, bobId)
|
||||
|
||||
testInactiveClientDisconnected :: ATransport -> IO ()
|
||||
testInactiveClientDisconnected t = do
|
||||
let cfg' = cfg {inactiveClientExpiration = Just ExpirationConfig {ttl = 1, checkInterval = 1}}
|
||||
withSmpServerConfigOn t cfg' testPort $ \_ -> do
|
||||
alice <- getSMPAgentClient agentCfg initAgentServers
|
||||
Right () <- runExceptT $ do
|
||||
(connId, _cReq) <- createConnection alice SCMInvitation
|
||||
get alice ##> ("", "", DOWN testSMPServer [connId])
|
||||
pure ()
|
||||
|
||||
testActiveClientNotDisconnected :: ATransport -> IO ()
|
||||
testActiveClientNotDisconnected t = do
|
||||
let cfg' = cfg {inactiveClientExpiration = Just ExpirationConfig {ttl = 1, checkInterval = 1}}
|
||||
withSmpServerConfigOn t cfg' testPort $ \_ -> do
|
||||
alice <- getSMPAgentClient agentCfg initAgentServers
|
||||
ts <- getSystemTime
|
||||
Right () <- runExceptT $ do
|
||||
(connId, _cReq) <- createConnection alice SCMInvitation
|
||||
keepSubscribing alice connId ts
|
||||
pure ()
|
||||
where
|
||||
keepSubscribing :: AgentClient -> ConnId -> SystemTime -> ExceptT AgentErrorType IO ()
|
||||
keepSubscribing alice connId ts = do
|
||||
ts' <- liftIO $ getSystemTime
|
||||
if milliseconds ts' - milliseconds ts < 2200
|
||||
then do
|
||||
-- keep sending SUB for 2.2 seconds
|
||||
liftIO $ threadDelay 200000
|
||||
subscribeConnection alice connId
|
||||
keepSubscribing alice connId ts
|
||||
else do
|
||||
-- check that nothing is sent from agent
|
||||
Nothing <- 800000 `timeout` get alice
|
||||
liftIO $ threadDelay 1200000
|
||||
-- and after 2 sec of inactivity DOWN is sent
|
||||
get alice ##> ("", "", DOWN testSMPServer [connId])
|
||||
milliseconds ts = systemSeconds ts * 1000 + fromIntegral (systemNanoseconds ts `div` 1000000)
|
||||
|
||||
testSuspendingAgent :: IO ()
|
||||
testSuspendingAgent = do
|
||||
a <- getSMPAgentClient agentCfg initAgentServers
|
||||
b <- getSMPAgentClient agentCfg {dbFile = testDB2} initAgentServers
|
||||
Right () <- runExceptT $ do
|
||||
(aId, bId) <- makeConnection a b
|
||||
4 <- sendMessage a bId SMP.noMsgFlags "hello"
|
||||
get a ##> ("", bId, SENT 4)
|
||||
get b =##> \case ("", c, Msg "hello") -> c == aId; _ -> False
|
||||
ackMessage b aId 4
|
||||
suspendAgent b 1000000
|
||||
get b ##> ("", "", SUSPENDED)
|
||||
5 <- sendMessage a bId SMP.noMsgFlags "hello 2"
|
||||
get a ##> ("", bId, SENT 5)
|
||||
Nothing <- 100000 `timeout` get b
|
||||
activateAgent b
|
||||
get b =##> \case ("", c, Msg "hello 2") -> c == aId; _ -> False
|
||||
pure ()
|
||||
|
||||
testSuspendingAgentCompleteSending :: ATransport -> IO ()
|
||||
testSuspendingAgentCompleteSending t = do
|
||||
a <- getSMPAgentClient agentCfg initAgentServers
|
||||
b <- getSMPAgentClient agentCfg {dbFile = testDB2} initAgentServers
|
||||
Right (aId, bId) <- withSmpServerStoreLogOn t testPort $ \_ -> runExceptT $ do
|
||||
(aId, bId) <- makeConnection a b
|
||||
4 <- sendMessage a bId SMP.noMsgFlags "hello"
|
||||
get a ##> ("", bId, SENT 4)
|
||||
get b =##> \case ("", c, Msg "hello") -> c == aId; _ -> False
|
||||
ackMessage b aId 4
|
||||
pure (aId, bId)
|
||||
|
||||
Right () <- runExceptT $ do
|
||||
("", "", DOWN {}) <- get a
|
||||
("", "", DOWN {}) <- get b
|
||||
5 <- sendMessage b aId SMP.noMsgFlags "hello too"
|
||||
6 <- sendMessage b aId SMP.noMsgFlags "how are you?"
|
||||
liftIO $ threadDelay 100000
|
||||
suspendAgent b 5000000
|
||||
|
||||
Right () <- withSmpServerStoreLogOn t testPort $ \_ -> runExceptT $ do
|
||||
get b =##> \case ("", c, SENT 5) -> c == aId; ("", "", UP {}) -> True; _ -> False
|
||||
get b =##> \case ("", c, SENT 5) -> c == aId; ("", "", UP {}) -> True; _ -> False
|
||||
get b =##> \case ("", c, SENT 6) -> c == aId; ("", "", UP {}) -> True; _ -> False
|
||||
("", "", SUSPENDED) <- get b
|
||||
|
||||
("", "", UP {}) <- get a
|
||||
get a =##> \case ("", c, Msg "hello too") -> c == bId; _ -> False
|
||||
ackMessage a bId 5
|
||||
get a =##> \case ("", c, Msg "how are you?") -> c == bId; _ -> False
|
||||
ackMessage a bId 6
|
||||
|
||||
pure ()
|
||||
|
||||
testSuspendingAgentTimeout :: ATransport -> IO ()
|
||||
testSuspendingAgentTimeout t = do
|
||||
a <- getSMPAgentClient agentCfg initAgentServers
|
||||
b <- getSMPAgentClient agentCfg {dbFile = testDB2} initAgentServers
|
||||
Right (aId, _) <- withSmpServer t . runExceptT $ do
|
||||
(aId, bId) <- makeConnection a b
|
||||
4 <- sendMessage a bId SMP.noMsgFlags "hello"
|
||||
get a ##> ("", bId, SENT 4)
|
||||
get b =##> \case ("", c, Msg "hello") -> c == aId; _ -> False
|
||||
ackMessage b aId 4
|
||||
pure (aId, bId)
|
||||
|
||||
Right () <- runExceptT $ do
|
||||
("", "", DOWN {}) <- get a
|
||||
("", "", DOWN {}) <- get b
|
||||
5 <- sendMessage b aId SMP.noMsgFlags "hello too"
|
||||
6 <- sendMessage b aId SMP.noMsgFlags "how are you?"
|
||||
suspendAgent b 100000
|
||||
("", "", SUSPENDED) <- get b
|
||||
pure ()
|
||||
|
||||
pure ()
|
||||
|
||||
exchangeGreetings :: AgentClient -> ConnId -> AgentClient -> ConnId -> ExceptT AgentErrorType IO ()
|
||||
exchangeGreetings alice bobId bob aliceId = do
|
||||
5 <- sendMessage alice bobId "hello"
|
||||
get alice ##> ("", bobId, SENT 5)
|
||||
4 <- sendMessage alice bobId SMP.noMsgFlags "hello"
|
||||
get alice ##> ("", bobId, SENT 4)
|
||||
get bob =##> \case ("", c, Msg "hello") -> c == aliceId; _ -> False
|
||||
ackMessage bob aliceId 5
|
||||
6 <- sendMessage bob aliceId "hello too"
|
||||
get bob ##> ("", aliceId, SENT 6)
|
||||
ackMessage bob aliceId 4
|
||||
5 <- sendMessage bob aliceId SMP.noMsgFlags "hello too"
|
||||
get bob ##> ("", aliceId, SENT 5)
|
||||
get alice =##> \case ("", c, Msg "hello too") -> c == bobId; _ -> False
|
||||
ackMessage alice bobId 6
|
||||
ackMessage alice bobId 5
|
||||
|
|
|
@ -0,0 +1,515 @@
|
|||
{-# LANGUAGE GADTs #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE PatternSynonyms #-}
|
||||
{-# OPTIONS_GHC -fno-warn-incomplete-uni-patterns #-}
|
||||
|
||||
module AgentTests.NotificationTests where
|
||||
|
||||
-- import Control.Logger.Simple (LogConfig (..), LogLevel (..), setLogLevel, withGlobalLogging)
|
||||
|
||||
import AgentTests.FunctionalAPITests (get, makeConnection, (##>), (=##>), pattern Msg)
|
||||
import Control.Concurrent (killThread, threadDelay)
|
||||
import Control.Monad.Except
|
||||
import qualified Data.Aeson as J
|
||||
import qualified Data.Aeson.Types as JT
|
||||
import Data.Bifunctor (bimap, first)
|
||||
import qualified Data.ByteString.Base64.URL as U
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import Data.Text.Encoding (encodeUtf8)
|
||||
import NtfClient
|
||||
import SMPAgentClient (agentCfg, initAgentServers, testDB, testDB2)
|
||||
import SMPClient (testPort, withSmpServer, withSmpServerStoreLogOn)
|
||||
import Simplex.Messaging.Agent
|
||||
import Simplex.Messaging.Agent.Client (AgentClient)
|
||||
import Simplex.Messaging.Agent.Env.SQLite (AgentConfig (..))
|
||||
import Simplex.Messaging.Agent.Protocol
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Encoding.String
|
||||
import Simplex.Messaging.Notifications.Protocol
|
||||
import Simplex.Messaging.Notifications.Server.Push.APNS
|
||||
import Simplex.Messaging.Notifications.Types (NtfToken (..))
|
||||
import Simplex.Messaging.Protocol (ErrorType (AUTH), MsgFlags (MsgFlags), SMPMsgMeta (..))
|
||||
import qualified Simplex.Messaging.Protocol as SMP
|
||||
import Simplex.Messaging.Transport (ATransport)
|
||||
import Simplex.Messaging.Util (tryE)
|
||||
import System.Directory (doesFileExist, removeFile)
|
||||
import Test.Hspec
|
||||
import UnliftIO
|
||||
|
||||
removeFileIfExists :: FilePath -> IO ()
|
||||
removeFileIfExists filePath = do
|
||||
fileExists <- doesFileExist filePath
|
||||
when fileExists $ removeFile filePath
|
||||
|
||||
notificationTests :: ATransport -> Spec
|
||||
notificationTests t =
|
||||
after_ (removeFile testDB >> removeFileIfExists testDB2) $ do
|
||||
describe "Managing notification tokens" $ do
|
||||
it "should register and verify notification token" $
|
||||
withAPNSMockServer $ \apns ->
|
||||
withNtfServer t $ testNotificationToken apns
|
||||
it "should allow repeated registration with the same credentials" $ \_ ->
|
||||
withAPNSMockServer $ \apns ->
|
||||
withNtfServer t $ testNtfTokenRepeatRegistration apns
|
||||
it "should allow the second registration with different credentials and delete the first after verification" $ \_ ->
|
||||
withAPNSMockServer $ \apns ->
|
||||
withNtfServer t $ testNtfTokenSecondRegistration apns
|
||||
it "should re-register token when notification server is restarted" $ \_ ->
|
||||
withAPNSMockServer $ \apns ->
|
||||
testNtfTokenServerRestart t apns
|
||||
describe "Managing notification subscriptions" $ do
|
||||
it "should create notification subscription for existing connection" $ \_ ->
|
||||
withSmpServer t $
|
||||
withAPNSMockServer $ \apns ->
|
||||
withNtfServer t $ testNotificationSubscriptionExistingConnection apns
|
||||
it "should create notification subscription for new connection" $ \_ ->
|
||||
withSmpServer t $
|
||||
withAPNSMockServer $ \apns ->
|
||||
withNtfServer t $ testNotificationSubscriptionNewConnection apns
|
||||
it "should change notifications mode" $ \_ ->
|
||||
withSmpServer t $
|
||||
withAPNSMockServer $ \apns ->
|
||||
withNtfServer t $ testChangeNotificationsMode apns
|
||||
it "should change token" $ \_ ->
|
||||
withSmpServer t $
|
||||
withAPNSMockServer $ \apns ->
|
||||
withNtfServer t $ testChangeToken apns
|
||||
describe "Notifications server store log" $
|
||||
it "should save and restore tokens and subscriptions" $ \_ ->
|
||||
withSmpServer t $
|
||||
withAPNSMockServer $ \apns ->
|
||||
testNotificationsStoreLog t apns
|
||||
describe "Notifications after SMP server restart" $
|
||||
it "should resume subscriptions after SMP server is restarted" $ \_ ->
|
||||
withAPNSMockServer $ \apns ->
|
||||
withNtfServer t $ testNotificationsSMPRestart t apns
|
||||
|
||||
testNotificationToken :: APNSMockServer -> IO ()
|
||||
testNotificationToken APNSMockServer {apnsQ} = do
|
||||
a <- getSMPAgentClient agentCfg initAgentServers
|
||||
Right () <- runExceptT $ do
|
||||
let tkn = DeviceToken PPApnsTest "abcd"
|
||||
NTRegistered <- registerNtfToken a tkn NMPeriodic
|
||||
APNSMockRequest {notification = APNSNotification {aps = APNSBackground _, notificationData = Just ntfData}, sendApnsResponse} <-
|
||||
atomically $ readTBQueue apnsQ
|
||||
verification <- ntfData .-> "verification"
|
||||
nonce <- C.cbNonce <$> ntfData .-> "nonce"
|
||||
liftIO $ sendApnsResponse APNSRespOk
|
||||
verifyNtfToken a tkn nonce verification
|
||||
NTActive <- checkNtfToken a tkn
|
||||
deleteNtfToken a tkn
|
||||
-- agent deleted this token
|
||||
Left (CMD PROHIBITED) <- tryE $ checkNtfToken a tkn
|
||||
pure ()
|
||||
pure ()
|
||||
|
||||
(.->) :: J.Value -> J.Key -> ExceptT AgentErrorType IO ByteString
|
||||
v .-> key = do
|
||||
J.Object o <- pure v
|
||||
liftEither . bimap INTERNAL (U.decodeLenient . encodeUtf8) $ JT.parseEither (J..: key) o
|
||||
|
||||
-- logCfg :: LogConfig
|
||||
-- logCfg = LogConfig {lc_file = Nothing, lc_stderr = True}
|
||||
|
||||
testNtfTokenRepeatRegistration :: APNSMockServer -> IO ()
|
||||
testNtfTokenRepeatRegistration APNSMockServer {apnsQ} = do
|
||||
-- setLogLevel LogError -- LogDebug
|
||||
-- withGlobalLogging logCfg $ do
|
||||
a <- getSMPAgentClient agentCfg initAgentServers
|
||||
Right () <- runExceptT $ do
|
||||
let tkn = DeviceToken PPApnsTest "abcd"
|
||||
NTRegistered <- registerNtfToken a tkn NMPeriodic
|
||||
APNSMockRequest {notification = APNSNotification {aps = APNSBackground _, notificationData = Just ntfData}, sendApnsResponse} <-
|
||||
atomically $ readTBQueue apnsQ
|
||||
verification <- ntfData .-> "verification"
|
||||
nonce <- C.cbNonce <$> ntfData .-> "nonce"
|
||||
liftIO $ sendApnsResponse APNSRespOk
|
||||
NTRegistered <- registerNtfToken a tkn NMPeriodic
|
||||
APNSMockRequest {notification = APNSNotification {aps = APNSBackground _, notificationData = Just ntfData'}, sendApnsResponse = sendApnsResponse'} <-
|
||||
atomically $ readTBQueue apnsQ
|
||||
_ <- ntfData' .-> "verification"
|
||||
_ <- C.cbNonce <$> ntfData' .-> "nonce"
|
||||
liftIO $ sendApnsResponse' APNSRespOk
|
||||
-- can still use the first verification code, it is the same after decryption
|
||||
verifyNtfToken a tkn nonce verification
|
||||
NTActive <- checkNtfToken a tkn
|
||||
pure ()
|
||||
pure ()
|
||||
|
||||
testNtfTokenSecondRegistration :: APNSMockServer -> IO ()
|
||||
testNtfTokenSecondRegistration APNSMockServer {apnsQ} = do
|
||||
-- setLogLevel LogError -- LogDebug
|
||||
-- withGlobalLogging logCfg $ do
|
||||
a <- getSMPAgentClient agentCfg initAgentServers
|
||||
a' <- getSMPAgentClient agentCfg {dbFile = testDB2} initAgentServers
|
||||
Right () <- runExceptT $ do
|
||||
let tkn = DeviceToken PPApnsTest "abcd"
|
||||
NTRegistered <- registerNtfToken a tkn NMPeriodic
|
||||
APNSMockRequest {notification = APNSNotification {aps = APNSBackground _, notificationData = Just ntfData}, sendApnsResponse} <-
|
||||
atomically $ readTBQueue apnsQ
|
||||
verification <- ntfData .-> "verification"
|
||||
nonce <- C.cbNonce <$> ntfData .-> "nonce"
|
||||
liftIO $ sendApnsResponse APNSRespOk
|
||||
verifyNtfToken a tkn nonce verification
|
||||
|
||||
NTRegistered <- registerNtfToken a' tkn NMPeriodic
|
||||
APNSMockRequest {notification = APNSNotification {aps = APNSBackground _, notificationData = Just ntfData'}, sendApnsResponse = sendApnsResponse'} <-
|
||||
atomically $ readTBQueue apnsQ
|
||||
verification' <- ntfData' .-> "verification"
|
||||
nonce' <- C.cbNonce <$> ntfData' .-> "nonce"
|
||||
liftIO $ sendApnsResponse' APNSRespOk
|
||||
|
||||
-- at this point the first token is still active
|
||||
NTActive <- checkNtfToken a tkn
|
||||
-- and the second is not yet verified
|
||||
liftIO $ threadDelay 50000
|
||||
NTConfirmed <- checkNtfToken a' tkn
|
||||
-- now the second token registration is verified
|
||||
verifyNtfToken a' tkn nonce' verification'
|
||||
-- the first registration is removed
|
||||
Left (NTF AUTH) <- tryE $ checkNtfToken a tkn
|
||||
-- and the second is active
|
||||
NTActive <- checkNtfToken a' tkn
|
||||
pure ()
|
||||
pure ()
|
||||
|
||||
testNtfTokenServerRestart :: ATransport -> APNSMockServer -> IO ()
|
||||
testNtfTokenServerRestart t APNSMockServer {apnsQ} = do
|
||||
a <- getSMPAgentClient agentCfg initAgentServers
|
||||
let tkn = DeviceToken PPApnsTest "abcd"
|
||||
Right ntfData <- withNtfServer t . runExceptT $ do
|
||||
NTRegistered <- registerNtfToken a tkn NMPeriodic
|
||||
APNSMockRequest {notification = APNSNotification {aps = APNSBackground _, notificationData = Just ntfData}, sendApnsResponse} <-
|
||||
atomically $ readTBQueue apnsQ
|
||||
liftIO $ sendApnsResponse APNSRespOk
|
||||
pure ntfData
|
||||
-- the new agent is created as otherwise when running the tests in CI the old agent was keeping the connection to the server
|
||||
threadDelay 1000000
|
||||
disconnectAgentClient a
|
||||
a' <- getSMPAgentClient agentCfg initAgentServers
|
||||
-- server stopped before token is verified, so now the attempt to verify it will return AUTH error but re-register token,
|
||||
-- so that repeat verification happens without restarting the clients, when notification arrives
|
||||
Right () <- withNtfServer t . runExceptT $ do
|
||||
verification <- ntfData .-> "verification"
|
||||
nonce <- C.cbNonce <$> ntfData .-> "nonce"
|
||||
Left (NTF AUTH) <- tryE $ verifyNtfToken a' tkn nonce verification
|
||||
APNSMockRequest {notification = APNSNotification {aps = APNSBackground _, notificationData = Just ntfData'}, sendApnsResponse = sendApnsResponse'} <-
|
||||
atomically $ readTBQueue apnsQ
|
||||
verification' <- ntfData' .-> "verification"
|
||||
nonce' <- C.cbNonce <$> ntfData' .-> "nonce"
|
||||
liftIO $ sendApnsResponse' APNSRespOk
|
||||
verifyNtfToken a' tkn nonce' verification'
|
||||
NTActive <- checkNtfToken a' tkn
|
||||
pure ()
|
||||
pure ()
|
||||
|
||||
testNotificationSubscriptionExistingConnection :: APNSMockServer -> IO ()
|
||||
testNotificationSubscriptionExistingConnection APNSMockServer {apnsQ} = do
|
||||
alice <- getSMPAgentClient agentCfg initAgentServers
|
||||
bob <- getSMPAgentClient agentCfg {dbFile = testDB2} initAgentServers
|
||||
Right (bobId, aliceId, nonce, message) <- runExceptT $ do
|
||||
-- establish connection
|
||||
(bobId, qInfo) <- createConnection alice SCMInvitation
|
||||
aliceId <- joinConnection bob qInfo "bob's connInfo"
|
||||
("", _, CONF confId "bob's connInfo") <- get alice
|
||||
allowConnection alice bobId confId "alice's connInfo"
|
||||
get bob ##> ("", aliceId, INFO "alice's connInfo")
|
||||
get alice ##> ("", bobId, CON)
|
||||
get bob ##> ("", aliceId, CON)
|
||||
-- register notification token
|
||||
let tkn = DeviceToken PPApnsTest "abcd"
|
||||
NTRegistered <- registerNtfToken alice tkn NMInstant
|
||||
APNSMockRequest {notification = APNSNotification {aps = APNSBackground _, notificationData = Just ntfData}, sendApnsResponse} <-
|
||||
atomically $ readTBQueue apnsQ
|
||||
verification <- ntfData .-> "verification"
|
||||
vNonce <- C.cbNonce <$> ntfData .-> "nonce"
|
||||
liftIO $ sendApnsResponse APNSRespOk
|
||||
verifyNtfToken alice tkn vNonce verification
|
||||
NTActive <- checkNtfToken alice tkn
|
||||
-- send message
|
||||
liftIO $ threadDelay 50000
|
||||
1 <- msgId <$> sendMessage bob aliceId (SMP.MsgFlags True) "hello"
|
||||
get bob ##> ("", aliceId, SENT $ baseId + 1)
|
||||
-- notification
|
||||
(nonce, message) <- messageNotification apnsQ
|
||||
pure (bobId, aliceId, nonce, message)
|
||||
|
||||
-- alice client already has subscription for the connection
|
||||
Left (CMD PROHIBITED) <- runExceptT $ getNotificationMessage alice nonce message
|
||||
|
||||
-- aliceNtf client doesn't have subscription and is allowed to get notification message
|
||||
aliceNtf <- getSMPAgentClient agentCfg initAgentServers
|
||||
Right () <- runExceptT $ do
|
||||
(_, [SMPMsgMeta {msgFlags = MsgFlags True}]) <- getNotificationMessage aliceNtf nonce message
|
||||
pure ()
|
||||
disconnectAgentClient aliceNtf
|
||||
|
||||
Right () <- runExceptT $ do
|
||||
get alice =##> \case ("", c, Msg "hello") -> c == bobId; _ -> False
|
||||
ackMessage alice bobId $ baseId + 1
|
||||
-- delete notification subscription
|
||||
deleteNtfSub alice bobId
|
||||
-- send message
|
||||
2 <- msgId <$> sendMessage bob aliceId (SMP.MsgFlags True) "hello again"
|
||||
get bob ##> ("", aliceId, SENT $ baseId + 2)
|
||||
-- no notifications should follow
|
||||
noNotification apnsQ
|
||||
pure ()
|
||||
where
|
||||
baseId = 3
|
||||
msgId = subtract baseId
|
||||
|
||||
testNotificationSubscriptionNewConnection :: APNSMockServer -> IO ()
|
||||
testNotificationSubscriptionNewConnection APNSMockServer {apnsQ} = do
|
||||
alice <- getSMPAgentClient agentCfg initAgentServers
|
||||
bob <- getSMPAgentClient agentCfg {dbFile = testDB2} initAgentServers
|
||||
Right () <- runExceptT $ do
|
||||
-- alice registers notification token
|
||||
_ <- registerTestToken alice "abcd" NMInstant apnsQ
|
||||
-- bob registers notification token
|
||||
_ <- registerTestToken bob "bcde" NMInstant apnsQ
|
||||
-- establish connection
|
||||
liftIO $ threadDelay 50000
|
||||
(bobId, qInfo) <- createConnection alice SCMInvitation
|
||||
liftIO $ threadDelay 500000
|
||||
aliceId <- joinConnection bob qInfo "bob's connInfo"
|
||||
liftIO $ print 0
|
||||
void $ messageNotification apnsQ
|
||||
("", _, CONF confId "bob's connInfo") <- get alice
|
||||
liftIO $ threadDelay 500000
|
||||
allowConnection alice bobId confId "alice's connInfo"
|
||||
liftIO $ print 1
|
||||
void $ messageNotification apnsQ
|
||||
get bob ##> ("", aliceId, INFO "alice's connInfo")
|
||||
liftIO $ print 2
|
||||
void $ messageNotification apnsQ
|
||||
get alice ##> ("", bobId, CON)
|
||||
liftIO $ print 3
|
||||
void $ messageNotification apnsQ
|
||||
get bob ##> ("", aliceId, CON)
|
||||
-- bob sends message
|
||||
1 <- msgId <$> sendMessage bob aliceId (SMP.MsgFlags True) "hello"
|
||||
get bob ##> ("", aliceId, SENT $ baseId + 1)
|
||||
liftIO $ print 4
|
||||
void $ messageNotification apnsQ
|
||||
get alice =##> \case ("", c, Msg "hello") -> c == bobId; _ -> False
|
||||
ackMessage alice bobId $ baseId + 1
|
||||
-- alice sends message
|
||||
2 <- msgId <$> sendMessage alice bobId (SMP.MsgFlags True) "hey there"
|
||||
get alice ##> ("", bobId, SENT $ baseId + 2)
|
||||
liftIO $ print 5
|
||||
void $ messageNotification apnsQ
|
||||
get bob =##> \case ("", c, Msg "hey there") -> c == aliceId; _ -> False
|
||||
ackMessage bob aliceId $ baseId + 2
|
||||
-- no unexpected notifications should follow
|
||||
noNotification apnsQ
|
||||
pure ()
|
||||
where
|
||||
baseId = 3
|
||||
msgId = subtract baseId
|
||||
|
||||
registerTestToken :: AgentClient -> ByteString -> NotificationsMode -> TBQueue APNSMockRequest -> ExceptT AgentErrorType IO DeviceToken
|
||||
registerTestToken a token mode apnsQ = do
|
||||
let tkn = DeviceToken PPApnsTest token
|
||||
NTRegistered <- registerNtfToken a tkn mode
|
||||
APNSMockRequest {notification = APNSNotification {aps = APNSBackground _, notificationData = Just ntfData'}, sendApnsResponse = sendApnsResponse'} <-
|
||||
atomically $ readTBQueue apnsQ
|
||||
verification' <- ntfData' .-> "verification"
|
||||
nonce' <- C.cbNonce <$> ntfData' .-> "nonce"
|
||||
liftIO $ sendApnsResponse' APNSRespOk
|
||||
verifyNtfToken a tkn nonce' verification'
|
||||
NTActive <- checkNtfToken a tkn
|
||||
pure tkn
|
||||
|
||||
testChangeNotificationsMode :: APNSMockServer -> IO ()
|
||||
testChangeNotificationsMode APNSMockServer {apnsQ} = do
|
||||
alice <- getSMPAgentClient agentCfg initAgentServers
|
||||
bob <- getSMPAgentClient agentCfg {dbFile = testDB2} initAgentServers
|
||||
Right () <- runExceptT $ do
|
||||
-- establish connection
|
||||
(bobId, qInfo) <- createConnection alice SCMInvitation
|
||||
aliceId <- joinConnection bob qInfo "bob's connInfo"
|
||||
("", _, CONF confId "bob's connInfo") <- get alice
|
||||
allowConnection alice bobId confId "alice's connInfo"
|
||||
get bob ##> ("", aliceId, INFO "alice's connInfo")
|
||||
get alice ##> ("", bobId, CON)
|
||||
get bob ##> ("", aliceId, CON)
|
||||
-- register notification token, set mode to NMInstant
|
||||
tkn <- registerTestToken alice "abcd" NMInstant apnsQ
|
||||
-- send message, receive notification
|
||||
liftIO $ threadDelay 500000
|
||||
1 <- msgId <$> sendMessage bob aliceId (SMP.MsgFlags True) "hello"
|
||||
get bob ##> ("", aliceId, SENT $ baseId + 1)
|
||||
void $ messageNotification apnsQ
|
||||
get alice =##> \case ("", c, Msg "hello") -> c == bobId; _ -> False
|
||||
ackMessage alice bobId $ baseId + 1
|
||||
-- set mode to NMPeriodic
|
||||
NTActive <- registerNtfToken alice tkn NMPeriodic
|
||||
-- send message, no notification
|
||||
liftIO $ threadDelay 500000
|
||||
2 <- msgId <$> sendMessage bob aliceId (SMP.MsgFlags True) "hello again"
|
||||
get bob ##> ("", aliceId, SENT $ baseId + 2)
|
||||
noNotification apnsQ
|
||||
get alice =##> \case ("", c, Msg "hello again") -> c == bobId; _ -> False
|
||||
ackMessage alice bobId $ baseId + 2
|
||||
-- set mode to NMInstant
|
||||
NTActive <- registerNtfToken alice tkn NMInstant
|
||||
-- send message, receive notification
|
||||
liftIO $ threadDelay 500000
|
||||
3 <- msgId <$> sendMessage bob aliceId (SMP.MsgFlags True) "hello there"
|
||||
get bob ##> ("", aliceId, SENT $ baseId + 3)
|
||||
void $ messageNotification apnsQ
|
||||
get alice =##> \case ("", c, Msg "hello there") -> c == bobId; _ -> False
|
||||
ackMessage alice bobId $ baseId + 3
|
||||
-- turn off notifications
|
||||
deleteNtfToken alice tkn
|
||||
-- send message, no notification
|
||||
liftIO $ threadDelay 500000
|
||||
4 <- msgId <$> sendMessage bob aliceId (SMP.MsgFlags True) "why hello there"
|
||||
get bob ##> ("", aliceId, SENT $ baseId + 4)
|
||||
noNotification apnsQ
|
||||
get alice =##> \case ("", c, Msg "why hello there") -> c == bobId; _ -> False
|
||||
ackMessage alice bobId $ baseId + 4
|
||||
-- turn on notifications, set mode to NMInstant
|
||||
void $ registerTestToken alice "abcd" NMInstant apnsQ
|
||||
-- send message, receive notification
|
||||
liftIO $ threadDelay 500000
|
||||
5 <- msgId <$> sendMessage bob aliceId (SMP.MsgFlags True) "hey"
|
||||
get bob ##> ("", aliceId, SENT $ baseId + 5)
|
||||
void $ messageNotification apnsQ
|
||||
get alice =##> \case ("", c, Msg "hey") -> c == bobId; _ -> False
|
||||
ackMessage alice bobId $ baseId + 5
|
||||
-- no notifications should follow
|
||||
noNotification apnsQ
|
||||
pure ()
|
||||
where
|
||||
baseId = 3
|
||||
msgId = subtract baseId
|
||||
|
||||
testChangeToken :: APNSMockServer -> IO ()
|
||||
testChangeToken APNSMockServer {apnsQ} = do
|
||||
alice <- getSMPAgentClient agentCfg initAgentServers
|
||||
bob <- getSMPAgentClient agentCfg {dbFile = testDB2} initAgentServers
|
||||
Right (aliceId, bobId) <- runExceptT $ do
|
||||
-- establish connection
|
||||
(bobId, qInfo) <- createConnection alice SCMInvitation
|
||||
aliceId <- joinConnection bob qInfo "bob's connInfo"
|
||||
("", _, CONF confId "bob's connInfo") <- get alice
|
||||
allowConnection alice bobId confId "alice's connInfo"
|
||||
get bob ##> ("", aliceId, INFO "alice's connInfo")
|
||||
get alice ##> ("", bobId, CON)
|
||||
get bob ##> ("", aliceId, CON)
|
||||
-- register notification token, set mode to NMInstant
|
||||
void $ registerTestToken alice "abcd" NMInstant apnsQ
|
||||
-- send message, receive notification
|
||||
liftIO $ threadDelay 500000
|
||||
1 <- msgId <$> sendMessage bob aliceId (SMP.MsgFlags True) "hello"
|
||||
get bob ##> ("", aliceId, SENT $ baseId + 1)
|
||||
void $ messageNotification apnsQ
|
||||
get alice =##> \case ("", c, Msg "hello") -> c == bobId; _ -> False
|
||||
ackMessage alice bobId $ baseId + 1
|
||||
pure (aliceId, bobId)
|
||||
disconnectAgentClient alice
|
||||
|
||||
alice1 <- getSMPAgentClient agentCfg initAgentServers
|
||||
Right () <- runExceptT $ do
|
||||
subscribeConnection alice1 bobId
|
||||
-- change notification token
|
||||
void $ registerTestToken alice1 "bcde" NMInstant apnsQ
|
||||
-- send message, receive notification
|
||||
liftIO $ threadDelay 500000
|
||||
2 <- msgId <$> sendMessage bob aliceId (SMP.MsgFlags True) "hello there"
|
||||
get bob ##> ("", aliceId, SENT $ baseId + 2)
|
||||
void $ messageNotification apnsQ
|
||||
get alice1 =##> \case ("", c, Msg "hello there") -> c == bobId; _ -> False
|
||||
ackMessage alice1 bobId $ baseId + 2
|
||||
-- no notifications should follow
|
||||
noNotification apnsQ
|
||||
pure ()
|
||||
where
|
||||
baseId = 3
|
||||
msgId = subtract baseId
|
||||
|
||||
testNotificationsStoreLog :: ATransport -> APNSMockServer -> IO ()
|
||||
testNotificationsStoreLog t APNSMockServer {apnsQ} = do
|
||||
alice <- getSMPAgentClient agentCfg initAgentServers
|
||||
bob <- getSMPAgentClient agentCfg {dbFile = testDB2} initAgentServers
|
||||
Right (aliceId, bobId) <- withNtfServerStoreLog t $ \threadId -> runExceptT $ do
|
||||
(aliceId, bobId) <- makeConnection alice bob
|
||||
_ <- registerTestToken alice "abcd" NMInstant apnsQ
|
||||
liftIO $ threadDelay 250000
|
||||
4 <- sendMessage bob aliceId (SMP.MsgFlags True) "hello"
|
||||
get bob ##> ("", aliceId, SENT 4)
|
||||
void $ messageNotification apnsQ
|
||||
get alice =##> \case ("", c, Msg "hello") -> c == bobId; _ -> False
|
||||
ackMessage alice bobId 4
|
||||
liftIO $ killThread threadId
|
||||
pure (aliceId, bobId)
|
||||
|
||||
liftIO $ threadDelay 250000
|
||||
|
||||
Right () <- withNtfServerStoreLog t $ \threadId -> runExceptT $ do
|
||||
liftIO $ threadDelay 250000
|
||||
5 <- sendMessage bob aliceId (SMP.MsgFlags True) "hello again"
|
||||
get bob ##> ("", aliceId, SENT 5)
|
||||
void $ messageNotification apnsQ
|
||||
get alice =##> \case ("", c, Msg "hello again") -> c == bobId; _ -> False
|
||||
liftIO $ killThread threadId
|
||||
pure ()
|
||||
|
||||
testNotificationsSMPRestart :: ATransport -> APNSMockServer -> IO ()
|
||||
testNotificationsSMPRestart t APNSMockServer {apnsQ} = do
|
||||
alice <- getSMPAgentClient agentCfg initAgentServers
|
||||
bob <- getSMPAgentClient agentCfg {dbFile = testDB2} initAgentServers
|
||||
Right (aliceId, bobId) <- withSmpServerStoreLogOn t testPort $ \threadId -> runExceptT $ do
|
||||
(aliceId, bobId) <- makeConnection alice bob
|
||||
_ <- registerTestToken alice "abcd" NMInstant apnsQ
|
||||
liftIO $ threadDelay 250000
|
||||
4 <- sendMessage bob aliceId (SMP.MsgFlags True) "hello"
|
||||
get bob ##> ("", aliceId, SENT 4)
|
||||
void $ messageNotification apnsQ
|
||||
get alice =##> \case ("", c, Msg "hello") -> c == bobId; _ -> False
|
||||
ackMessage alice bobId 4
|
||||
liftIO $ killThread threadId
|
||||
pure (aliceId, bobId)
|
||||
|
||||
Right () <- runExceptT $ do
|
||||
get alice =##> \case ("", "", DOWN _ [c]) -> c == bobId; _ -> False
|
||||
get bob =##> \case ("", "", DOWN _ [c]) -> c == aliceId; _ -> False
|
||||
|
||||
Right () <- withSmpServerStoreLogOn t testPort $ \threadId -> runExceptT $ do
|
||||
get alice =##> \case ("", "", UP _ [c]) -> c == bobId; _ -> False
|
||||
get bob =##> \case ("", "", UP _ [c]) -> c == aliceId; _ -> False
|
||||
liftIO $ threadDelay 1000000
|
||||
5 <- sendMessage bob aliceId (SMP.MsgFlags True) "hello again"
|
||||
get bob ##> ("", aliceId, SENT 5)
|
||||
_ <- messageNotificationData alice apnsQ
|
||||
get alice =##> \case ("", c, Msg "hello again") -> c == bobId; _ -> False
|
||||
liftIO $ killThread threadId
|
||||
pure ()
|
||||
|
||||
messageNotification :: TBQueue APNSMockRequest -> ExceptT AgentErrorType IO (C.CbNonce, ByteString)
|
||||
messageNotification apnsQ = do
|
||||
1000000 `timeout` atomically (readTBQueue apnsQ) >>= \case
|
||||
Nothing -> error "no notification"
|
||||
Just APNSMockRequest {notification = APNSNotification {aps = APNSMutableContent {}, notificationData = Just ntfData}, sendApnsResponse} -> do
|
||||
nonce <- C.cbNonce <$> ntfData .-> "nonce"
|
||||
message <- ntfData .-> "message"
|
||||
liftIO $ sendApnsResponse APNSRespOk
|
||||
pure (nonce, message)
|
||||
_ -> error "bad notification"
|
||||
|
||||
messageNotificationData :: AgentClient -> TBQueue APNSMockRequest -> ExceptT AgentErrorType IO PNMessageData
|
||||
messageNotificationData c apnsQ = do
|
||||
(nonce, message) <- messageNotification apnsQ
|
||||
NtfToken {ntfDhSecret = Just dhSecret} <- getNtfTokenData c
|
||||
Right pnMsgData <- liftEither . first INTERNAL $ Right . strDecode =<< first show (C.cbDecrypt dhSecret nonce message)
|
||||
pure pnMsgData
|
||||
|
||||
noNotification :: TBQueue APNSMockRequest -> ExceptT AgentErrorType IO ()
|
||||
noNotification apnsQ = do
|
||||
500000 `timeout` atomically (readTBQueue apnsQ) >>= \case
|
||||
Nothing -> pure ()
|
||||
_ -> error "unexpected notification"
|
|
@ -11,7 +11,6 @@ module AgentTests.SQLiteTests (storeTests) where
|
|||
import Control.Concurrent.Async (concurrently_)
|
||||
import Control.Concurrent.STM
|
||||
import Control.Monad (replicateM_)
|
||||
import Control.Monad.Except (ExceptT, runExceptT)
|
||||
import Crypto.Random (drgNew)
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.Text as T
|
||||
|
@ -27,6 +26,7 @@ import Simplex.Messaging.Agent.Store
|
|||
import Simplex.Messaging.Agent.Store.SQLite
|
||||
import qualified Simplex.Messaging.Agent.Store.SQLite.Migrations as Migrations
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import qualified Simplex.Messaging.Protocol as SMP
|
||||
import System.Random
|
||||
import Test.Hspec
|
||||
import UnliftIO.Directory (removeFile)
|
||||
|
@ -43,7 +43,7 @@ withStore2 = before connect2 . after (removeStore . fst)
|
|||
connect2 :: IO (SQLiteStore, SQLiteStore)
|
||||
connect2 = do
|
||||
s1 <- createStore
|
||||
s2 <- connectSQLiteStore (dbFilePath s1) 4
|
||||
s2 <- connectSQLiteStore (dbFilePath s1)
|
||||
pure (s1, s2)
|
||||
|
||||
createStore :: IO SQLiteStore
|
||||
|
@ -51,21 +51,15 @@ createStore = do
|
|||
-- Randomize DB file name to avoid SQLite IO errors supposedly caused by asynchronous
|
||||
-- IO operations on multiple similarly named files; error seems to be environment specific
|
||||
r <- randomIO :: IO Word32
|
||||
createSQLiteStore (testDB <> show r) 4 Migrations.app True
|
||||
createSQLiteStore (testDB <> show r) Migrations.app True
|
||||
|
||||
removeStore :: SQLiteStore -> IO ()
|
||||
removeStore store = do
|
||||
close store
|
||||
removeFile $ dbFilePath store
|
||||
removeStore db = do
|
||||
close db
|
||||
removeFile $ dbFilePath db
|
||||
where
|
||||
close :: SQLiteStore -> IO ()
|
||||
close st = mapM_ DB.close =<< atomically (flushTBQueue $ dbConnPool st)
|
||||
|
||||
returnsResult :: (Eq a, Eq e, Show a, Show e) => ExceptT e IO a -> a -> Expectation
|
||||
action `returnsResult` r = runExceptT action `shouldReturn` Right r
|
||||
|
||||
throwsError :: (Eq a, Eq e, Show a, Show e) => ExceptT e IO a -> e -> Expectation
|
||||
action `throwsError` e = runExceptT action `shouldReturn` Left e
|
||||
close st = mapM_ DB.close =<< atomically (tryTakeTMVar $ dbConnection st)
|
||||
|
||||
-- TODO add null port tests
|
||||
storeTests :: Spec
|
||||
|
@ -73,10 +67,10 @@ storeTests = do
|
|||
withStore2 $ do
|
||||
describe "stress test" testConcurrentWrites
|
||||
withStore $ do
|
||||
describe "store setup" $ do
|
||||
describe "db setup" $ do
|
||||
testCompiledThreadsafe
|
||||
testForeignKeysEnabled
|
||||
describe "store methods" $ do
|
||||
describe "db methods" $ do
|
||||
describe "Queue and Connection management" $ do
|
||||
describe "createRcvConn" $ do
|
||||
testCreateRcvConn
|
||||
|
@ -111,28 +105,29 @@ testConcurrentWrites :: SpecWith (SQLiteStore, SQLiteStore)
|
|||
testConcurrentWrites =
|
||||
it "should complete multiple concurrent write transactions w/t sqlite busy errors" $ \(s1, s2) -> do
|
||||
g <- newTVarIO =<< drgNew
|
||||
_ <- runExceptT $ createRcvConn s1 g cData1 rcvQueue1 SCMInvitation
|
||||
_ <- withTransaction s1 $ \db ->
|
||||
createRcvConn db g cData1 rcvQueue1 SCMInvitation
|
||||
let ConnData {connId} = cData1
|
||||
concurrently_ (runTest s1 connId) (runTest s2 connId)
|
||||
where
|
||||
runTest :: SQLiteStore -> ConnId -> IO (Either StoreError ())
|
||||
runTest store connId = runExceptT . replicateM_ 100 $ do
|
||||
(internalId, internalRcvId, _, _) <- updateRcvIds store connId
|
||||
runTest :: SQLiteStore -> ConnId -> IO ()
|
||||
runTest st connId = replicateM_ 100 . withTransaction st $ \db -> do
|
||||
(internalId, internalRcvId, _, _) <- updateRcvIds db connId
|
||||
let rcvMsgData = mkRcvMsgData internalId internalRcvId 0 "0" "hash_dummy"
|
||||
createRcvMsg store connId rcvMsgData
|
||||
createRcvMsg db connId rcvMsgData
|
||||
|
||||
testCompiledThreadsafe :: SpecWith SQLiteStore
|
||||
testCompiledThreadsafe =
|
||||
it "compiled sqlite library should be threadsafe" . withStoreConnection $ \db -> do
|
||||
it "compiled sqlite library should be threadsafe" . withStoreTransaction $ \db -> do
|
||||
compileOptions <- DB.query_ db "pragma COMPILE_OPTIONS;" :: IO [[T.Text]]
|
||||
compileOptions `shouldNotContain` [["THREADSAFE=0"]]
|
||||
|
||||
withStoreConnection :: (DB.Connection -> IO a) -> SQLiteStore -> IO a
|
||||
withStoreConnection = flip withConnection
|
||||
withStoreTransaction :: (DB.Connection -> IO a) -> SQLiteStore -> IO a
|
||||
withStoreTransaction = flip withTransaction
|
||||
|
||||
testForeignKeysEnabled :: SpecWith SQLiteStore
|
||||
testForeignKeysEnabled =
|
||||
it "foreign keys should be enabled" . withStoreConnection $ \db -> do
|
||||
it "foreign keys should be enabled" . withStoreTransaction $ \db -> do
|
||||
let inconsistentQuery =
|
||||
[sql|
|
||||
INSERT INTO snd_queues
|
||||
|
@ -144,7 +139,7 @@ testForeignKeysEnabled =
|
|||
`shouldThrow` (\e -> DB.sqlError e == DB.ErrorConstraint)
|
||||
|
||||
cData1 :: ConnData
|
||||
cData1 = ConnData {connId = "conn1"}
|
||||
cData1 = ConnData {connId = "conn1", connAgentVersion = 1, duplexHandshake = Nothing}
|
||||
|
||||
testPrivateSignKey :: C.APrivateSignKey
|
||||
testPrivateSignKey = C.APrivateSignKey C.SEd25519 "MC4CAQAwBQYDK2VwBCIEIDfEfevydXXfKajz3sRkcQ7RPvfWUPoq6pu1TYHV1DEe"
|
||||
|
@ -165,7 +160,8 @@ rcvQueue1 =
|
|||
e2ePrivKey = testPrivDhKey,
|
||||
e2eDhSecret = Nothing,
|
||||
sndId = Just "2345",
|
||||
status = New
|
||||
status = New,
|
||||
clientNtfCreds = Nothing
|
||||
}
|
||||
|
||||
sndQueue1 :: SndQueue
|
||||
|
@ -182,125 +178,125 @@ sndQueue1 =
|
|||
|
||||
testCreateRcvConn :: SpecWith SQLiteStore
|
||||
testCreateRcvConn =
|
||||
it "should create RcvConnection and add SndQueue" $ \store -> do
|
||||
it "should create RcvConnection and add SndQueue" . withStoreTransaction $ \db -> do
|
||||
g <- newTVarIO =<< drgNew
|
||||
createRcvConn store g cData1 rcvQueue1 SCMInvitation
|
||||
`returnsResult` "conn1"
|
||||
getConn store "conn1"
|
||||
`returnsResult` SomeConn SCRcv (RcvConnection cData1 rcvQueue1)
|
||||
upgradeRcvConnToDuplex store "conn1" sndQueue1
|
||||
`returnsResult` ()
|
||||
getConn store "conn1"
|
||||
`returnsResult` SomeConn SCDuplex (DuplexConnection cData1 rcvQueue1 sndQueue1)
|
||||
createRcvConn db g cData1 rcvQueue1 SCMInvitation
|
||||
`shouldReturn` Right "conn1"
|
||||
getConn db "conn1"
|
||||
`shouldReturn` Right (SomeConn SCRcv (RcvConnection cData1 rcvQueue1))
|
||||
upgradeRcvConnToDuplex db "conn1" sndQueue1
|
||||
`shouldReturn` Right ()
|
||||
getConn db "conn1"
|
||||
`shouldReturn` Right (SomeConn SCDuplex (DuplexConnection cData1 rcvQueue1 sndQueue1))
|
||||
|
||||
testCreateRcvConnRandomId :: SpecWith SQLiteStore
|
||||
testCreateRcvConnRandomId =
|
||||
it "should create RcvConnection and add SndQueue with random ID" $ \store -> do
|
||||
it "should create RcvConnection and add SndQueue with random ID" . withStoreTransaction $ \db -> do
|
||||
g <- newTVarIO =<< drgNew
|
||||
Right connId <- runExceptT $ createRcvConn store g cData1 {connId = ""} rcvQueue1 SCMInvitation
|
||||
getConn store connId
|
||||
`returnsResult` SomeConn SCRcv (RcvConnection cData1 {connId} rcvQueue1)
|
||||
upgradeRcvConnToDuplex store connId sndQueue1
|
||||
`returnsResult` ()
|
||||
getConn store connId
|
||||
`returnsResult` SomeConn SCDuplex (DuplexConnection cData1 {connId} rcvQueue1 sndQueue1)
|
||||
Right connId <- createRcvConn db g cData1 {connId = ""} rcvQueue1 SCMInvitation
|
||||
getConn db connId
|
||||
`shouldReturn` Right (SomeConn SCRcv (RcvConnection cData1 {connId} rcvQueue1))
|
||||
upgradeRcvConnToDuplex db connId sndQueue1
|
||||
`shouldReturn` Right ()
|
||||
getConn db connId
|
||||
`shouldReturn` Right (SomeConn SCDuplex (DuplexConnection cData1 {connId} rcvQueue1 sndQueue1))
|
||||
|
||||
testCreateRcvConnDuplicate :: SpecWith SQLiteStore
|
||||
testCreateRcvConnDuplicate =
|
||||
it "should throw error on attempt to create duplicate RcvConnection" $ \store -> do
|
||||
it "should throw error on attempt to create duplicate RcvConnection" . withStoreTransaction $ \db -> do
|
||||
g <- newTVarIO =<< drgNew
|
||||
_ <- runExceptT $ createRcvConn store g cData1 rcvQueue1 SCMInvitation
|
||||
createRcvConn store g cData1 rcvQueue1 SCMInvitation
|
||||
`throwsError` SEConnDuplicate
|
||||
_ <- createRcvConn db g cData1 rcvQueue1 SCMInvitation
|
||||
createRcvConn db g cData1 rcvQueue1 SCMInvitation
|
||||
`shouldReturn` Left (SEConnDuplicate)
|
||||
|
||||
testCreateSndConn :: SpecWith SQLiteStore
|
||||
testCreateSndConn =
|
||||
it "should create SndConnection and add RcvQueue" $ \store -> do
|
||||
it "should create SndConnection and add RcvQueue" . withStoreTransaction $ \db -> do
|
||||
g <- newTVarIO =<< drgNew
|
||||
createSndConn store g cData1 sndQueue1
|
||||
`returnsResult` "conn1"
|
||||
getConn store "conn1"
|
||||
`returnsResult` SomeConn SCSnd (SndConnection cData1 sndQueue1)
|
||||
upgradeSndConnToDuplex store "conn1" rcvQueue1
|
||||
`returnsResult` ()
|
||||
getConn store "conn1"
|
||||
`returnsResult` SomeConn SCDuplex (DuplexConnection cData1 rcvQueue1 sndQueue1)
|
||||
createSndConn db g cData1 sndQueue1
|
||||
`shouldReturn` Right "conn1"
|
||||
getConn db "conn1"
|
||||
`shouldReturn` Right (SomeConn SCSnd (SndConnection cData1 sndQueue1))
|
||||
upgradeSndConnToDuplex db "conn1" rcvQueue1
|
||||
`shouldReturn` Right ()
|
||||
getConn db "conn1"
|
||||
`shouldReturn` Right (SomeConn SCDuplex (DuplexConnection cData1 rcvQueue1 sndQueue1))
|
||||
|
||||
testCreateSndConnRandomID :: SpecWith SQLiteStore
|
||||
testCreateSndConnRandomID =
|
||||
it "should create SndConnection and add RcvQueue with random ID" $ \store -> do
|
||||
it "should create SndConnection and add RcvQueue with random ID" . withStoreTransaction $ \db -> do
|
||||
g <- newTVarIO =<< drgNew
|
||||
Right connId <- runExceptT $ createSndConn store g cData1 {connId = ""} sndQueue1
|
||||
getConn store connId
|
||||
`returnsResult` SomeConn SCSnd (SndConnection cData1 {connId} sndQueue1)
|
||||
upgradeSndConnToDuplex store connId rcvQueue1
|
||||
`returnsResult` ()
|
||||
getConn store connId
|
||||
`returnsResult` SomeConn SCDuplex (DuplexConnection cData1 {connId} rcvQueue1 sndQueue1)
|
||||
Right connId <- createSndConn db g cData1 {connId = ""} sndQueue1
|
||||
getConn db connId
|
||||
`shouldReturn` Right (SomeConn SCSnd (SndConnection cData1 {connId} sndQueue1))
|
||||
upgradeSndConnToDuplex db connId rcvQueue1
|
||||
`shouldReturn` Right ()
|
||||
getConn db connId
|
||||
`shouldReturn` Right (SomeConn SCDuplex (DuplexConnection cData1 {connId} rcvQueue1 sndQueue1))
|
||||
|
||||
testCreateSndConnDuplicate :: SpecWith SQLiteStore
|
||||
testCreateSndConnDuplicate =
|
||||
it "should throw error on attempt to create duplicate SndConnection" $ \store -> do
|
||||
it "should throw error on attempt to create duplicate SndConnection" . withStoreTransaction $ \db -> do
|
||||
g <- newTVarIO =<< drgNew
|
||||
_ <- runExceptT $ createSndConn store g cData1 sndQueue1
|
||||
createSndConn store g cData1 sndQueue1
|
||||
`throwsError` SEConnDuplicate
|
||||
_ <- createSndConn db g cData1 sndQueue1
|
||||
createSndConn db g cData1 sndQueue1
|
||||
`shouldReturn` Left (SEConnDuplicate)
|
||||
|
||||
testGetRcvConn :: SpecWith SQLiteStore
|
||||
testGetRcvConn =
|
||||
it "should get connection using rcv queue id and server" $ \store -> do
|
||||
it "should get connection using rcv queue id and server" . withStoreTransaction $ \db -> do
|
||||
let smpServer = SMPServer "smp.simplex.im" "5223" testKeyHash
|
||||
let recipientId = "1234"
|
||||
g <- newTVarIO =<< drgNew
|
||||
_ <- runExceptT $ createRcvConn store g cData1 rcvQueue1 SCMInvitation
|
||||
getRcvConn store smpServer recipientId
|
||||
`returnsResult` SomeConn SCRcv (RcvConnection cData1 rcvQueue1)
|
||||
_ <- createRcvConn db g cData1 rcvQueue1 SCMInvitation
|
||||
getRcvConn db smpServer recipientId
|
||||
`shouldReturn` Right (SomeConn SCRcv (RcvConnection cData1 rcvQueue1))
|
||||
|
||||
testDeleteRcvConn :: SpecWith SQLiteStore
|
||||
testDeleteRcvConn =
|
||||
it "should create RcvConnection and delete it" $ \store -> do
|
||||
it "should create RcvConnection and delete it" . withStoreTransaction $ \db -> do
|
||||
g <- newTVarIO =<< drgNew
|
||||
_ <- runExceptT $ createRcvConn store g cData1 rcvQueue1 SCMInvitation
|
||||
getConn store "conn1"
|
||||
`returnsResult` SomeConn SCRcv (RcvConnection cData1 rcvQueue1)
|
||||
deleteConn store "conn1"
|
||||
`returnsResult` ()
|
||||
_ <- createRcvConn db g cData1 rcvQueue1 SCMInvitation
|
||||
getConn db "conn1"
|
||||
`shouldReturn` Right (SomeConn SCRcv (RcvConnection cData1 rcvQueue1))
|
||||
deleteConn db "conn1"
|
||||
`shouldReturn` ()
|
||||
-- TODO check queues are deleted as well
|
||||
getConn store "conn1"
|
||||
`throwsError` SEConnNotFound
|
||||
getConn db "conn1"
|
||||
`shouldReturn` Left (SEConnNotFound)
|
||||
|
||||
testDeleteSndConn :: SpecWith SQLiteStore
|
||||
testDeleteSndConn =
|
||||
it "should create SndConnection and delete it" $ \store -> do
|
||||
it "should create SndConnection and delete it" . withStoreTransaction $ \db -> do
|
||||
g <- newTVarIO =<< drgNew
|
||||
_ <- runExceptT $ createSndConn store g cData1 sndQueue1
|
||||
getConn store "conn1"
|
||||
`returnsResult` SomeConn SCSnd (SndConnection cData1 sndQueue1)
|
||||
deleteConn store "conn1"
|
||||
`returnsResult` ()
|
||||
_ <- createSndConn db g cData1 sndQueue1
|
||||
getConn db "conn1"
|
||||
`shouldReturn` Right (SomeConn SCSnd (SndConnection cData1 sndQueue1))
|
||||
deleteConn db "conn1"
|
||||
`shouldReturn` ()
|
||||
-- TODO check queues are deleted as well
|
||||
getConn store "conn1"
|
||||
`throwsError` SEConnNotFound
|
||||
getConn db "conn1"
|
||||
`shouldReturn` Left (SEConnNotFound)
|
||||
|
||||
testDeleteDuplexConn :: SpecWith SQLiteStore
|
||||
testDeleteDuplexConn =
|
||||
it "should create DuplexConnection and delete it" $ \store -> do
|
||||
it "should create DuplexConnection and delete it" . withStoreTransaction $ \db -> do
|
||||
g <- newTVarIO =<< drgNew
|
||||
_ <- runExceptT $ createRcvConn store g cData1 rcvQueue1 SCMInvitation
|
||||
_ <- runExceptT $ upgradeRcvConnToDuplex store "conn1" sndQueue1
|
||||
getConn store "conn1"
|
||||
`returnsResult` SomeConn SCDuplex (DuplexConnection cData1 rcvQueue1 sndQueue1)
|
||||
deleteConn store "conn1"
|
||||
`returnsResult` ()
|
||||
_ <- createRcvConn db g cData1 rcvQueue1 SCMInvitation
|
||||
_ <- upgradeRcvConnToDuplex db "conn1" sndQueue1
|
||||
getConn db "conn1"
|
||||
`shouldReturn` Right (SomeConn SCDuplex (DuplexConnection cData1 rcvQueue1 sndQueue1))
|
||||
deleteConn db "conn1"
|
||||
`shouldReturn` ()
|
||||
-- TODO check queues are deleted as well
|
||||
getConn store "conn1"
|
||||
`throwsError` SEConnNotFound
|
||||
getConn db "conn1"
|
||||
`shouldReturn` Left (SEConnNotFound)
|
||||
|
||||
testUpgradeRcvConnToDuplex :: SpecWith SQLiteStore
|
||||
testUpgradeRcvConnToDuplex =
|
||||
it "should throw error on attempt to add SndQueue to SndConnection or DuplexConnection" $ \store -> do
|
||||
it "should throw error on attempt to add SndQueue to SndConnection or DuplexConnection" . withStoreTransaction $ \db -> do
|
||||
g <- newTVarIO =<< drgNew
|
||||
_ <- runExceptT $ createSndConn store g cData1 sndQueue1
|
||||
_ <- createSndConn db g cData1 sndQueue1
|
||||
let anotherSndQueue =
|
||||
SndQueue
|
||||
{ server = SMPServer "smp.simplex.im" "5223" testKeyHash,
|
||||
|
@ -311,17 +307,17 @@ testUpgradeRcvConnToDuplex =
|
|||
e2eDhSecret = testDhSecret,
|
||||
status = New
|
||||
}
|
||||
upgradeRcvConnToDuplex store "conn1" anotherSndQueue
|
||||
`throwsError` SEBadConnType CSnd
|
||||
_ <- runExceptT $ upgradeSndConnToDuplex store "conn1" rcvQueue1
|
||||
upgradeRcvConnToDuplex store "conn1" anotherSndQueue
|
||||
`throwsError` SEBadConnType CDuplex
|
||||
upgradeRcvConnToDuplex db "conn1" anotherSndQueue
|
||||
`shouldReturn` Left (SEBadConnType CSnd)
|
||||
_ <- upgradeSndConnToDuplex db "conn1" rcvQueue1
|
||||
upgradeRcvConnToDuplex db "conn1" anotherSndQueue
|
||||
`shouldReturn` Left (SEBadConnType CDuplex)
|
||||
|
||||
testUpgradeSndConnToDuplex :: SpecWith SQLiteStore
|
||||
testUpgradeSndConnToDuplex =
|
||||
it "should throw error on attempt to add RcvQueue to RcvConnection or DuplexConnection" $ \store -> do
|
||||
it "should throw error on attempt to add RcvQueue to RcvConnection or DuplexConnection" . withStoreTransaction $ \db -> do
|
||||
g <- newTVarIO =<< drgNew
|
||||
_ <- runExceptT $ createRcvConn store g cData1 rcvQueue1 SCMInvitation
|
||||
_ <- createRcvConn db g cData1 rcvQueue1 SCMInvitation
|
||||
let anotherRcvQueue =
|
||||
RcvQueue
|
||||
{ server = SMPServer "smp.simplex.im" "5223" testKeyHash,
|
||||
|
@ -331,54 +327,55 @@ testUpgradeSndConnToDuplex =
|
|||
e2ePrivKey = testPrivDhKey,
|
||||
e2eDhSecret = Nothing,
|
||||
sndId = Just "4567",
|
||||
status = New
|
||||
status = New,
|
||||
clientNtfCreds = Nothing
|
||||
}
|
||||
upgradeSndConnToDuplex store "conn1" anotherRcvQueue
|
||||
`throwsError` SEBadConnType CRcv
|
||||
_ <- runExceptT $ upgradeRcvConnToDuplex store "conn1" sndQueue1
|
||||
upgradeSndConnToDuplex store "conn1" anotherRcvQueue
|
||||
`throwsError` SEBadConnType CDuplex
|
||||
upgradeSndConnToDuplex db "conn1" anotherRcvQueue
|
||||
`shouldReturn` Left (SEBadConnType CRcv)
|
||||
_ <- upgradeRcvConnToDuplex db "conn1" sndQueue1
|
||||
upgradeSndConnToDuplex db "conn1" anotherRcvQueue
|
||||
`shouldReturn` Left (SEBadConnType CDuplex)
|
||||
|
||||
testSetRcvQueueStatus :: SpecWith SQLiteStore
|
||||
testSetRcvQueueStatus =
|
||||
it "should update status of RcvQueue" $ \store -> do
|
||||
it "should update status of RcvQueue" . withStoreTransaction $ \db -> do
|
||||
g <- newTVarIO =<< drgNew
|
||||
_ <- runExceptT $ createRcvConn store g cData1 rcvQueue1 SCMInvitation
|
||||
getConn store "conn1"
|
||||
`returnsResult` SomeConn SCRcv (RcvConnection cData1 rcvQueue1)
|
||||
setRcvQueueStatus store rcvQueue1 Confirmed
|
||||
`returnsResult` ()
|
||||
getConn store "conn1"
|
||||
`returnsResult` SomeConn SCRcv (RcvConnection cData1 rcvQueue1 {status = Confirmed})
|
||||
_ <- createRcvConn db g cData1 rcvQueue1 SCMInvitation
|
||||
getConn db "conn1"
|
||||
`shouldReturn` Right (SomeConn SCRcv (RcvConnection cData1 rcvQueue1))
|
||||
setRcvQueueStatus db rcvQueue1 Confirmed
|
||||
`shouldReturn` ()
|
||||
getConn db "conn1"
|
||||
`shouldReturn` Right (SomeConn SCRcv (RcvConnection cData1 rcvQueue1 {status = Confirmed}))
|
||||
|
||||
testSetSndQueueStatus :: SpecWith SQLiteStore
|
||||
testSetSndQueueStatus =
|
||||
it "should update status of SndQueue" $ \store -> do
|
||||
it "should update status of SndQueue" . withStoreTransaction $ \db -> do
|
||||
g <- newTVarIO =<< drgNew
|
||||
_ <- runExceptT $ createSndConn store g cData1 sndQueue1
|
||||
getConn store "conn1"
|
||||
`returnsResult` SomeConn SCSnd (SndConnection cData1 sndQueue1)
|
||||
setSndQueueStatus store sndQueue1 Confirmed
|
||||
`returnsResult` ()
|
||||
getConn store "conn1"
|
||||
`returnsResult` SomeConn SCSnd (SndConnection cData1 sndQueue1 {status = Confirmed})
|
||||
_ <- createSndConn db g cData1 sndQueue1
|
||||
getConn db "conn1"
|
||||
`shouldReturn` Right (SomeConn SCSnd (SndConnection cData1 sndQueue1))
|
||||
setSndQueueStatus db sndQueue1 Confirmed
|
||||
`shouldReturn` ()
|
||||
getConn db "conn1"
|
||||
`shouldReturn` Right (SomeConn SCSnd (SndConnection cData1 sndQueue1 {status = Confirmed}))
|
||||
|
||||
testSetQueueStatusDuplex :: SpecWith SQLiteStore
|
||||
testSetQueueStatusDuplex =
|
||||
it "should update statuses of RcvQueue and SndQueue in DuplexConnection" $ \store -> do
|
||||
it "should update statuses of RcvQueue and SndQueue in DuplexConnection" . withStoreTransaction $ \db -> do
|
||||
g <- newTVarIO =<< drgNew
|
||||
_ <- runExceptT $ createRcvConn store g cData1 rcvQueue1 SCMInvitation
|
||||
_ <- runExceptT $ upgradeRcvConnToDuplex store "conn1" sndQueue1
|
||||
getConn store "conn1"
|
||||
`returnsResult` SomeConn SCDuplex (DuplexConnection cData1 rcvQueue1 sndQueue1)
|
||||
setRcvQueueStatus store rcvQueue1 Secured
|
||||
`returnsResult` ()
|
||||
getConn store "conn1"
|
||||
`returnsResult` SomeConn SCDuplex (DuplexConnection cData1 rcvQueue1 {status = Secured} sndQueue1)
|
||||
setSndQueueStatus store sndQueue1 Confirmed
|
||||
`returnsResult` ()
|
||||
getConn store "conn1"
|
||||
`returnsResult` SomeConn SCDuplex (DuplexConnection cData1 rcvQueue1 {status = Secured} sndQueue1 {status = Confirmed})
|
||||
_ <- createRcvConn db g cData1 rcvQueue1 SCMInvitation
|
||||
_ <- upgradeRcvConnToDuplex db "conn1" sndQueue1
|
||||
getConn db "conn1"
|
||||
`shouldReturn` Right (SomeConn SCDuplex (DuplexConnection cData1 rcvQueue1 sndQueue1))
|
||||
setRcvQueueStatus db rcvQueue1 Secured
|
||||
`shouldReturn` ()
|
||||
getConn db "conn1"
|
||||
`shouldReturn` Right (SomeConn SCDuplex (DuplexConnection cData1 rcvQueue1 {status = Secured} sndQueue1))
|
||||
setSndQueueStatus db sndQueue1 Confirmed
|
||||
`shouldReturn` ()
|
||||
getConn db "conn1"
|
||||
`shouldReturn` Right (SomeConn SCDuplex (DuplexConnection cData1 rcvQueue1 {status = Secured} sndQueue1 {status = Confirmed}))
|
||||
|
||||
hw :: ByteString
|
||||
hw = encodeUtf8 "Hello world!"
|
||||
|
@ -398,27 +395,30 @@ mkRcvMsgData internalId internalRcvId externalSndId brokerId internalHash =
|
|||
broker = (brokerId, ts)
|
||||
},
|
||||
msgType = AM_A_MSG_,
|
||||
msgFlags = SMP.noMsgFlags,
|
||||
msgBody = hw,
|
||||
internalHash,
|
||||
externalPrevSndHash = "hash_from_sender"
|
||||
}
|
||||
|
||||
testCreateRcvMsg_ :: SQLiteStore -> PrevExternalSndId -> PrevRcvMsgHash -> ConnId -> RcvMsgData -> Expectation
|
||||
testCreateRcvMsg_ st expectedPrevSndId expectedPrevHash connId rcvMsgData@RcvMsgData {..} = do
|
||||
testCreateRcvMsg_ :: DB.Connection -> PrevExternalSndId -> PrevRcvMsgHash -> ConnId -> RcvMsgData -> Expectation
|
||||
testCreateRcvMsg_ db expectedPrevSndId expectedPrevHash connId rcvMsgData@RcvMsgData {..} = do
|
||||
let MsgMeta {recipient = (internalId, _)} = msgMeta
|
||||
updateRcvIds st connId
|
||||
`returnsResult` (InternalId internalId, internalRcvId, expectedPrevSndId, expectedPrevHash)
|
||||
createRcvMsg st connId rcvMsgData
|
||||
`returnsResult` ()
|
||||
updateRcvIds db connId
|
||||
`shouldReturn` (InternalId internalId, internalRcvId, expectedPrevSndId, expectedPrevHash)
|
||||
createRcvMsg db connId rcvMsgData
|
||||
`shouldReturn` ()
|
||||
|
||||
testCreateRcvMsg :: SpecWith SQLiteStore
|
||||
testCreateRcvMsg =
|
||||
it "should reserve internal ids and create a RcvMsg" $ \st -> do
|
||||
g <- newTVarIO =<< drgNew
|
||||
let ConnData {connId} = cData1
|
||||
_ <- runExceptT $ createRcvConn st g cData1 rcvQueue1 SCMInvitation
|
||||
testCreateRcvMsg_ st 0 "" connId $ mkRcvMsgData (InternalId 1) (InternalRcvId 1) 1 "1" "hash_dummy"
|
||||
testCreateRcvMsg_ st 1 "hash_dummy" connId $ mkRcvMsgData (InternalId 2) (InternalRcvId 2) 2 "2" "new_hash_dummy"
|
||||
_ <- withTransaction st $ \db -> do
|
||||
createRcvConn db g cData1 rcvQueue1 SCMInvitation
|
||||
withTransaction st $ \db -> do
|
||||
testCreateRcvMsg_ db 0 "" connId $ mkRcvMsgData (InternalId 1) (InternalRcvId 1) 1 "1" "hash_dummy"
|
||||
testCreateRcvMsg_ db 1 "hash_dummy" connId $ mkRcvMsgData (InternalId 2) (InternalRcvId 2) 2 "2" "new_hash_dummy"
|
||||
|
||||
mkSndMsgData :: InternalId -> InternalSndId -> MsgHash -> SndMsgData
|
||||
mkSndMsgData internalId internalSndId internalHash =
|
||||
|
@ -427,37 +427,42 @@ mkSndMsgData internalId internalSndId internalHash =
|
|||
internalSndId,
|
||||
internalTs = ts,
|
||||
msgType = AM_A_MSG_,
|
||||
msgFlags = SMP.noMsgFlags,
|
||||
msgBody = hw,
|
||||
internalHash,
|
||||
prevMsgHash = internalHash
|
||||
}
|
||||
|
||||
testCreateSndMsg_ :: SQLiteStore -> PrevSndMsgHash -> ConnId -> SndMsgData -> Expectation
|
||||
testCreateSndMsg_ store expectedPrevHash connId sndMsgData@SndMsgData {..} = do
|
||||
updateSndIds store connId
|
||||
`returnsResult` (internalId, internalSndId, expectedPrevHash)
|
||||
createSndMsg store connId sndMsgData
|
||||
`returnsResult` ()
|
||||
testCreateSndMsg_ :: DB.Connection -> PrevSndMsgHash -> ConnId -> SndMsgData -> Expectation
|
||||
testCreateSndMsg_ db expectedPrevHash connId sndMsgData@SndMsgData {..} = do
|
||||
updateSndIds db connId
|
||||
`shouldReturn` (internalId, internalSndId, expectedPrevHash)
|
||||
createSndMsg db connId sndMsgData
|
||||
`shouldReturn` ()
|
||||
|
||||
testCreateSndMsg :: SpecWith SQLiteStore
|
||||
testCreateSndMsg =
|
||||
it "should create a SndMsg and return InternalId and PrevSndMsgHash" $ \store -> do
|
||||
it "should create a SndMsg and return InternalId and PrevSndMsgHash" $ \st -> do
|
||||
g <- newTVarIO =<< drgNew
|
||||
let ConnData {connId} = cData1
|
||||
_ <- runExceptT $ createSndConn store g cData1 sndQueue1
|
||||
testCreateSndMsg_ store "" connId $ mkSndMsgData (InternalId 1) (InternalSndId 1) "hash_dummy"
|
||||
testCreateSndMsg_ store "hash_dummy" connId $ mkSndMsgData (InternalId 2) (InternalSndId 2) "new_hash_dummy"
|
||||
_ <- withTransaction st $ \db -> do
|
||||
createSndConn db g cData1 sndQueue1
|
||||
withTransaction st $ \db -> do
|
||||
testCreateSndMsg_ db "" connId $ mkSndMsgData (InternalId 1) (InternalSndId 1) "hash_dummy"
|
||||
testCreateSndMsg_ db "hash_dummy" connId $ mkSndMsgData (InternalId 2) (InternalSndId 2) "new_hash_dummy"
|
||||
|
||||
testCreateRcvAndSndMsgs :: SpecWith SQLiteStore
|
||||
testCreateRcvAndSndMsgs =
|
||||
it "should create multiple RcvMsg and SndMsg, correctly ordering internal Ids and returning previous state" $ \store -> do
|
||||
g <- newTVarIO =<< drgNew
|
||||
it "should create multiple RcvMsg and SndMsg, correctly ordering internal Ids and returning previous state" $ \st -> do
|
||||
let ConnData {connId} = cData1
|
||||
_ <- runExceptT $ createRcvConn store g cData1 rcvQueue1 SCMInvitation
|
||||
_ <- runExceptT $ upgradeRcvConnToDuplex store "conn1" sndQueue1
|
||||
testCreateRcvMsg_ store 0 "" connId $ mkRcvMsgData (InternalId 1) (InternalRcvId 1) 1 "1" "rcv_hash_1"
|
||||
testCreateRcvMsg_ store 1 "rcv_hash_1" connId $ mkRcvMsgData (InternalId 2) (InternalRcvId 2) 2 "2" "rcv_hash_2"
|
||||
testCreateSndMsg_ store "" connId $ mkSndMsgData (InternalId 3) (InternalSndId 1) "snd_hash_1"
|
||||
testCreateRcvMsg_ store 2 "rcv_hash_2" connId $ mkRcvMsgData (InternalId 4) (InternalRcvId 3) 3 "3" "rcv_hash_3"
|
||||
testCreateSndMsg_ store "snd_hash_1" connId $ mkSndMsgData (InternalId 5) (InternalSndId 2) "snd_hash_2"
|
||||
testCreateSndMsg_ store "snd_hash_2" connId $ mkSndMsgData (InternalId 6) (InternalSndId 3) "snd_hash_3"
|
||||
_ <- withTransaction st $ \db -> do
|
||||
g <- newTVarIO =<< drgNew
|
||||
createRcvConn db g cData1 rcvQueue1 SCMInvitation
|
||||
withTransaction st $ \db -> do
|
||||
_ <- upgradeRcvConnToDuplex db "conn1" sndQueue1
|
||||
testCreateRcvMsg_ db 0 "" connId $ mkRcvMsgData (InternalId 1) (InternalRcvId 1) 1 "1" "rcv_hash_1"
|
||||
testCreateRcvMsg_ db 1 "rcv_hash_1" connId $ mkRcvMsgData (InternalId 2) (InternalRcvId 2) 2 "2" "rcv_hash_2"
|
||||
testCreateSndMsg_ db "" connId $ mkSndMsgData (InternalId 3) (InternalSndId 1) "snd_hash_1"
|
||||
testCreateRcvMsg_ db 2 "rcv_hash_2" connId $ mkRcvMsgData (InternalId 4) (InternalRcvId 3) 3 "3" "rcv_hash_3"
|
||||
testCreateSndMsg_ db "snd_hash_1" connId $ mkSndMsgData (InternalId 5) (InternalSndId 2) "snd_hash_2"
|
||||
testCreateSndMsg_ db "snd_hash_2" connId $ mkSndMsgData (InternalId 6) (InternalSndId 3) "snd_hash_3"
|
||||
|
|
|
@ -0,0 +1,29 @@
|
|||
{-# LANGUAGE OverloadedStrings #-}
|
||||
|
||||
module AgentTests.SchemaDump where
|
||||
|
||||
import Control.Monad (void)
|
||||
import Simplex.Messaging.Agent.Store.SQLite
|
||||
import qualified Simplex.Messaging.Agent.Store.SQLite.Migrations as Migrations
|
||||
import System.Process (readCreateProcess, shell)
|
||||
import Test.Hspec
|
||||
|
||||
testDB :: FilePath
|
||||
testDB = "tests/tmp/test_agent_schema.db"
|
||||
|
||||
schema :: FilePath
|
||||
schema = "src/Simplex/Messaging/Agent/Store/SQLite/Migrations/agent_schema.sql"
|
||||
|
||||
schemaDumpTest :: Spec
|
||||
schemaDumpTest =
|
||||
it "verify and overwrite schema dump" testVerifySchemaDump
|
||||
|
||||
testVerifySchemaDump :: IO ()
|
||||
testVerifySchemaDump = do
|
||||
void $ createSQLiteStore testDB Migrations.app False
|
||||
void $ readCreateProcess (shell $ "touch " <> schema) ""
|
||||
savedSchema <- readFile schema
|
||||
savedSchema `seq` pure ()
|
||||
void $ readCreateProcess (shell $ "sqlite3 " <> testDB <> " '.schema --indent' > " <> schema) ""
|
||||
currentSchema <- readFile schema
|
||||
savedSchema `shouldBe` currentSchema
|
|
@ -0,0 +1,219 @@
|
|||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE DeriveAnyClass #-}
|
||||
{-# LANGUAGE DeriveGeneric #-}
|
||||
{-# LANGUAGE DuplicateRecordFields #-}
|
||||
{-# LANGUAGE FlexibleContexts #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
{-# LANGUAGE NumericUnderscores #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-# LANGUAGE StandaloneDeriving #-}
|
||||
{-# LANGUAGE TypeApplications #-}
|
||||
{-# OPTIONS_GHC -fno-warn-orphans #-}
|
||||
|
||||
module NtfClient where
|
||||
|
||||
import Control.Monad
|
||||
import Control.Monad.Except (runExceptT)
|
||||
import Control.Monad.IO.Unlift
|
||||
import Crypto.Random
|
||||
import Data.Aeson (FromJSON (..), ToJSON (..), (.:))
|
||||
import qualified Data.Aeson as J
|
||||
import qualified Data.Aeson.Types as JT
|
||||
import Data.ByteString.Builder (lazyByteString)
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import Data.Text (Text)
|
||||
import GHC.Generics (Generic)
|
||||
import Network.HTTP.Types (Status)
|
||||
import qualified Network.HTTP.Types as N
|
||||
import qualified Network.HTTP2.Server as H
|
||||
import Network.Socket
|
||||
import Simplex.Messaging.Client.Agent (defaultSMPClientAgentConfig)
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Encoding
|
||||
import Simplex.Messaging.Notifications.Server (runNtfServerBlocking)
|
||||
import Simplex.Messaging.Notifications.Server.Env
|
||||
import Simplex.Messaging.Notifications.Server.Push.APNS
|
||||
import Simplex.Messaging.Notifications.Transport
|
||||
import Simplex.Messaging.Protocol
|
||||
import Simplex.Messaging.Transport
|
||||
import Simplex.Messaging.Transport.Client
|
||||
import Simplex.Messaging.Transport.HTTP2 (http2TLSParams)
|
||||
import Simplex.Messaging.Transport.HTTP2.Client
|
||||
import Simplex.Messaging.Transport.HTTP2.Server
|
||||
import Simplex.Messaging.Transport.KeepAlive
|
||||
import Test.Hspec
|
||||
import UnliftIO.Async
|
||||
import UnliftIO.Concurrent
|
||||
import qualified UnliftIO.Exception as E
|
||||
import UnliftIO.STM
|
||||
import UnliftIO.Timeout (timeout)
|
||||
|
||||
testHost :: HostName
|
||||
testHost = "localhost"
|
||||
|
||||
ntfTestPort :: ServiceName
|
||||
ntfTestPort = "6001"
|
||||
|
||||
apnsTestPort :: ServiceName
|
||||
apnsTestPort = "6010"
|
||||
|
||||
testKeyHash :: C.KeyHash
|
||||
testKeyHash = "LcJUMfVhwD8yxjAiSaDzzGF3-kLG4Uh0Fl_ZIjrRwjI="
|
||||
|
||||
ntfTestStoreLogFile :: FilePath
|
||||
ntfTestStoreLogFile = "tests/tmp/ntf-server-store.log"
|
||||
|
||||
testNtfClient :: (Transport c, MonadUnliftIO m) => (THandle c -> m a) -> m a
|
||||
testNtfClient client =
|
||||
runTransportClient testHost ntfTestPort (Just testKeyHash) (Just defaultKeepAliveOpts) $ \h ->
|
||||
liftIO (runExceptT $ ntfClientHandshake h testKeyHash supportedNTFServerVRange) >>= \case
|
||||
Right th -> client th
|
||||
Left e -> error $ show e
|
||||
|
||||
ntfServerCfg :: NtfServerConfig
|
||||
ntfServerCfg =
|
||||
NtfServerConfig
|
||||
{ transports = undefined,
|
||||
subIdBytes = 24,
|
||||
regCodeBytes = 32,
|
||||
clientQSize = 1,
|
||||
subQSize = 1,
|
||||
pushQSize = 1,
|
||||
smpAgentCfg = defaultSMPClientAgentConfig,
|
||||
apnsConfig =
|
||||
defaultAPNSPushClientConfig
|
||||
{ apnsPort = apnsTestPort,
|
||||
http2cfg = defaultHTTP2ClientConfig {caStoreFile = "tests/fixtures/ca.crt"}
|
||||
},
|
||||
inactiveClientExpiration = Just defaultInactiveClientExpiration,
|
||||
storeLogFile = Nothing,
|
||||
resubscribeDelay = 1000,
|
||||
-- CA certificate private key is not needed for initialization
|
||||
caCertificateFile = "tests/fixtures/ca.crt",
|
||||
privateKeyFile = "tests/fixtures/server.key",
|
||||
certificateFile = "tests/fixtures/server.crt"
|
||||
}
|
||||
|
||||
withNtfServerStoreLog :: (MonadUnliftIO m, MonadRandom m) => ATransport -> (ThreadId -> m a) -> m a
|
||||
withNtfServerStoreLog t = withNtfServerCfg t ntfServerCfg {storeLogFile = Just ntfTestStoreLogFile}
|
||||
|
||||
withNtfServerThreadOn :: (MonadUnliftIO m, MonadRandom m) => ATransport -> ServiceName -> (ThreadId -> m a) -> m a
|
||||
withNtfServerThreadOn t port' = withNtfServerCfg t ntfServerCfg {transports = [(port', t)]}
|
||||
|
||||
withNtfServerCfg :: (MonadUnliftIO m, MonadRandom m) => ATransport -> NtfServerConfig -> (ThreadId -> m a) -> m a
|
||||
withNtfServerCfg t cfg =
|
||||
serverBracket
|
||||
(\started -> runNtfServerBlocking started cfg {transports = [(ntfTestPort, t)]})
|
||||
(pure ())
|
||||
|
||||
serverBracket :: MonadUnliftIO m => (TMVar Bool -> m ()) -> m () -> (ThreadId -> m a) -> m a
|
||||
serverBracket process afterProcess f = do
|
||||
started <- newEmptyTMVarIO
|
||||
E.bracket
|
||||
(forkIOWithUnmask ($ process started))
|
||||
(\t -> killThread t >> afterProcess >> waitFor started "stop")
|
||||
(\t -> waitFor started "start" >> f t)
|
||||
where
|
||||
waitFor started s =
|
||||
5_000_000 `timeout` atomically (takeTMVar started) >>= \case
|
||||
Nothing -> error $ "server did not " <> s
|
||||
_ -> pure ()
|
||||
|
||||
withNtfServerOn :: (MonadUnliftIO m, MonadRandom m) => ATransport -> ServiceName -> m a -> m a
|
||||
withNtfServerOn t port' = withNtfServerThreadOn t port' . const
|
||||
|
||||
withNtfServer :: (MonadUnliftIO m, MonadRandom m) => ATransport -> m a -> m a
|
||||
withNtfServer t = withNtfServerOn t ntfTestPort
|
||||
|
||||
runNtfTest :: forall c m a. (Transport c, MonadUnliftIO m, MonadRandom m) => (THandle c -> m a) -> m a
|
||||
runNtfTest test = withNtfServer (transport @c) $ testNtfClient test
|
||||
|
||||
ntfServerTest ::
|
||||
forall c smp.
|
||||
(Transport c, Encoding smp) =>
|
||||
TProxy c ->
|
||||
(Maybe C.ASignature, ByteString, ByteString, smp) ->
|
||||
IO (Maybe C.ASignature, ByteString, ByteString, BrokerMsg)
|
||||
ntfServerTest _ t = runNtfTest $ \h -> tPut' h t >> tGet' h
|
||||
where
|
||||
tPut' h (sig, corrId, queueId, smp) = do
|
||||
let t' = smpEncode (sessionId (h :: THandle c), corrId, queueId, smp)
|
||||
Right () <- tPut h (sig, t')
|
||||
pure ()
|
||||
tGet' h = do
|
||||
(Nothing, _, (CorrId corrId, qId, Right cmd)) <- tGet h
|
||||
pure (Nothing, corrId, qId, cmd)
|
||||
|
||||
ntfTest :: Transport c => TProxy c -> (THandle c -> IO ()) -> Expectation
|
||||
ntfTest _ test' = runNtfTest test' `shouldReturn` ()
|
||||
|
||||
data APNSMockRequest = APNSMockRequest
|
||||
{ notification :: APNSNotification,
|
||||
sendApnsResponse :: APNSMockResponse -> IO ()
|
||||
}
|
||||
|
||||
data APNSMockResponse = APNSRespOk | APNSRespError Status Text
|
||||
|
||||
data APNSMockServer = APNSMockServer
|
||||
{ action :: Async (),
|
||||
apnsQ :: TBQueue APNSMockRequest,
|
||||
http2Server :: HTTP2Server
|
||||
}
|
||||
|
||||
apnsMockServerConfig :: HTTP2ServerConfig
|
||||
apnsMockServerConfig =
|
||||
HTTP2ServerConfig
|
||||
{ qSize = 1,
|
||||
http2Port = apnsTestPort,
|
||||
serverSupported = http2TLSParams,
|
||||
caCertificateFile = "tests/fixtures/ca.crt",
|
||||
privateKeyFile = "tests/fixtures/server.key",
|
||||
certificateFile = "tests/fixtures/server.crt"
|
||||
}
|
||||
|
||||
withAPNSMockServer :: (APNSMockServer -> IO ()) -> IO ()
|
||||
withAPNSMockServer = E.bracket (getAPNSMockServer apnsMockServerConfig) closeAPNSMockServer
|
||||
|
||||
deriving instance Generic APNSAlertBody
|
||||
|
||||
instance FromJSON APNSAlertBody where
|
||||
parseJSON (J.Object v) = do
|
||||
title <- v .: "title"
|
||||
subtitle <- v .: "subtitle"
|
||||
body <- v .: "body"
|
||||
pure APNSAlertObject {title, subtitle, body}
|
||||
parseJSON (J.String v) = pure $ APNSAlertText v
|
||||
parseJSON invalid = JT.prependFailure "parsing Coord failed, " (JT.typeMismatch "Object" invalid)
|
||||
|
||||
instance FromJSON APNSNotificationBody where parseJSON = J.genericParseJSON apnsJSONOptions {J.rejectUnknownFields = True}
|
||||
|
||||
deriving instance FromJSON APNSNotification
|
||||
|
||||
deriving instance ToJSON APNSErrorResponse
|
||||
|
||||
getAPNSMockServer :: HTTP2ServerConfig -> IO APNSMockServer
|
||||
getAPNSMockServer config@HTTP2ServerConfig {qSize} = do
|
||||
http2Server <- getHTTP2Server config
|
||||
apnsQ <- newTBQueueIO qSize
|
||||
action <- async $ runAPNSMockServer apnsQ http2Server
|
||||
pure APNSMockServer {action, apnsQ, http2Server}
|
||||
where
|
||||
runAPNSMockServer apnsQ HTTP2Server {reqQ} = forever $ do
|
||||
HTTP2Request {reqBody, sendResponse} <- atomically $ readTBQueue reqQ
|
||||
let sendApnsResponse = \case
|
||||
APNSRespOk -> sendResponse $ H.responseNoBody N.ok200 []
|
||||
APNSRespError status reason ->
|
||||
sendResponse . H.responseBuilder status [] . lazyByteString $ J.encode APNSErrorResponse {reason}
|
||||
case J.decodeStrict' reqBody of
|
||||
Just notification ->
|
||||
atomically $ writeTBQueue apnsQ APNSMockRequest {notification, sendApnsResponse}
|
||||
_ -> do
|
||||
putStrLn $ "runAPNSMockServer J.decodeStrict' error, reqBody: " <> show reqBody
|
||||
sendApnsResponse $ APNSRespError N.badRequest400 "bad_request_body"
|
||||
|
||||
closeAPNSMockServer :: APNSMockServer -> IO ()
|
||||
closeAPNSMockServer APNSMockServer {action, http2Server} = do
|
||||
closeHTTP2Server http2Server
|
||||
uninterruptibleCancel action
|
|
@ -0,0 +1,165 @@
|
|||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE DuplicateRecordFields #-}
|
||||
{-# LANGUAGE GADTs #-}
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE PatternSynonyms #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-# OPTIONS_GHC -fno-warn-incomplete-uni-patterns #-}
|
||||
|
||||
module NtfServerTests where
|
||||
|
||||
import Control.Concurrent (threadDelay)
|
||||
import Control.Monad.Except (runExceptT)
|
||||
import qualified Data.Aeson as J
|
||||
import qualified Data.Aeson.Types as JT
|
||||
import Data.Bifunctor (first)
|
||||
import qualified Data.ByteString.Base64.URL as U
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import Data.Text.Encoding (encodeUtf8)
|
||||
import NtfClient
|
||||
import SMPClient as SMP
|
||||
import ServerTests
|
||||
( createAndSecureQueue,
|
||||
sampleDhPubKey,
|
||||
samplePubKey,
|
||||
sampleSig,
|
||||
signSendRecv,
|
||||
(#==),
|
||||
_SEND',
|
||||
pattern Resp,
|
||||
)
|
||||
import qualified Simplex.Messaging.Agent.Protocol as AP
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Encoding
|
||||
import Simplex.Messaging.Encoding.String
|
||||
import Simplex.Messaging.Notifications.Protocol
|
||||
import Simplex.Messaging.Notifications.Server.Push.APNS
|
||||
import qualified Simplex.Messaging.Notifications.Server.Push.APNS as APNS
|
||||
import Simplex.Messaging.Parsers (parse, parseAll)
|
||||
import Simplex.Messaging.Protocol hiding (notification)
|
||||
import Simplex.Messaging.Transport
|
||||
import Test.Hspec
|
||||
import UnliftIO.STM
|
||||
|
||||
ntfServerTests :: ATransport -> Spec
|
||||
ntfServerTests t = do
|
||||
describe "Notifications server protocol syntax" $ ntfSyntaxTests t
|
||||
describe "Notification subscriptions" $ testNotificationSubscription t
|
||||
|
||||
ntfSyntaxTests :: ATransport -> Spec
|
||||
ntfSyntaxTests (ATransport t) = do
|
||||
it "unknown command" $ ("", "abcd", "1234", ('H', 'E', 'L', 'L', 'O')) >#> ("", "abcd", "1234", ERR $ CMD UNKNOWN)
|
||||
describe "NEW" $ do
|
||||
it "no parameters" $ (sampleSig, "bcda", "", TNEW_) >#> ("", "bcda", "", ERR $ CMD SYNTAX)
|
||||
it "many parameters" $ (sampleSig, "cdab", "", (TNEW_, (' ', '\x01', 'A'), ('T', 'A', 'T', "abcd" :: ByteString), samplePubKey, sampleDhPubKey)) >#> ("", "cdab", "", ERR $ CMD SYNTAX)
|
||||
it "no signature" $ ("", "dabc", "", (TNEW_, ' ', ('T', 'A', 'T', "abcd" :: ByteString), samplePubKey, sampleDhPubKey)) >#> ("", "dabc", "", ERR $ CMD NO_AUTH)
|
||||
it "token ID" $ (sampleSig, "abcd", "12345678", (TNEW_, ' ', ('T', 'A', 'T', "abcd" :: ByteString), samplePubKey, sampleDhPubKey)) >#> ("", "abcd", "12345678", ERR $ CMD HAS_AUTH)
|
||||
where
|
||||
(>#>) ::
|
||||
Encoding smp =>
|
||||
(Maybe C.ASignature, ByteString, ByteString, smp) ->
|
||||
(Maybe C.ASignature, ByteString, ByteString, BrokerMsg) ->
|
||||
Expectation
|
||||
command >#> response = withAPNSMockServer $ \_ -> ntfServerTest t command `shouldReturn` response
|
||||
|
||||
pattern RespNtf :: CorrId -> QueueId -> NtfResponse -> SignedTransmission NtfResponse
|
||||
pattern RespNtf corrId queueId command <- (_, _, (corrId, queueId, Right command))
|
||||
|
||||
sendRecvNtf :: forall c e. (Transport c, NtfEntityI e) => THandle c -> (Maybe C.ASignature, ByteString, ByteString, NtfCommand e) -> IO (SignedTransmission NtfResponse)
|
||||
sendRecvNtf h@THandle {thVersion, sessionId} (sgn, corrId, qId, cmd) = do
|
||||
let t = encodeTransmission thVersion sessionId (CorrId corrId, qId, cmd)
|
||||
Right () <- tPut h (sgn, t)
|
||||
tGet h
|
||||
|
||||
signSendRecvNtf :: forall c e. (Transport c, NtfEntityI e) => THandle c -> C.APrivateSignKey -> (ByteString, ByteString, NtfCommand e) -> IO (SignedTransmission NtfResponse)
|
||||
signSendRecvNtf h@THandle {thVersion, sessionId} pk (corrId, qId, cmd) = do
|
||||
let t = encodeTransmission thVersion sessionId (CorrId corrId, qId, cmd)
|
||||
Right sig <- runExceptT $ C.sign pk t
|
||||
Right () <- tPut h (Just sig, t)
|
||||
tGet h
|
||||
|
||||
(.->) :: J.Value -> J.Key -> Either String ByteString
|
||||
v .-> key =
|
||||
let J.Object o = v
|
||||
in U.decodeLenient . encodeUtf8 <$> JT.parseEither (J..: key) o
|
||||
|
||||
testNotificationSubscription :: ATransport -> Spec
|
||||
testNotificationSubscription (ATransport t) =
|
||||
it "should create notification subscription and notify when message is received" $ do
|
||||
(sPub, sKey) <- C.generateSignatureKeyPair C.SEd25519
|
||||
(nPub, nKey) <- C.generateSignatureKeyPair C.SEd25519
|
||||
(tknPub, tknKey) <- C.generateSignatureKeyPair C.SEd25519
|
||||
(dhPub, dhPriv :: C.PrivateKeyX25519) <- C.generateKeyPair'
|
||||
let tkn = DeviceToken PPApnsTest "abcd"
|
||||
withAPNSMockServer $ \APNSMockServer {apnsQ} ->
|
||||
smpTest2 t $ \rh sh ->
|
||||
ntfTest t $ \nh -> do
|
||||
-- create queue
|
||||
(sId, rId, rKey, rcvDhSecret) <- createAndSecureQueue rh sPub
|
||||
-- register and verify token
|
||||
RespNtf "1" "" (NRTknId tId ntfDh) <- signSendRecvNtf nh tknKey ("1", "", TNEW $ NewNtfTkn tkn tknPub dhPub)
|
||||
APNSMockRequest {notification = APNSNotification {aps = APNSBackground _, notificationData = Just ntfData}, sendApnsResponse = send} <-
|
||||
atomically $ readTBQueue apnsQ
|
||||
send APNSRespOk
|
||||
let dhSecret = C.dh' ntfDh dhPriv
|
||||
Right verification = ntfData .-> "verification"
|
||||
Right nonce = C.cbNonce <$> ntfData .-> "nonce"
|
||||
Right code = NtfRegCode <$> C.cbDecrypt dhSecret nonce verification
|
||||
RespNtf "2" _ NROk <- signSendRecvNtf nh tknKey ("2", tId, TVFY code)
|
||||
RespNtf "2a" _ (NRTkn NTActive) <- signSendRecvNtf nh tknKey ("2a", tId, TCHK)
|
||||
-- enable queue notifications
|
||||
(rcvNtfPubDhKey, rcvNtfPrivDhKey) <- C.generateKeyPair'
|
||||
Resp "3" _ (NID nId rcvNtfSrvPubDhKey) <- signSendRecv rh rKey ("3", rId, NKEY nPub rcvNtfPubDhKey)
|
||||
let srv = SMPServer SMP.testHost SMP.testPort SMP.testKeyHash
|
||||
q = SMPQueueNtf srv nId
|
||||
rcvNtfDhSecret = C.dh' rcvNtfSrvPubDhKey rcvNtfPrivDhKey
|
||||
RespNtf "4" _ (NRSubId _subId) <- signSendRecvNtf nh tknKey ("4", "", SNEW $ NewNtfSub tId q nKey)
|
||||
-- send message
|
||||
threadDelay 50000
|
||||
Resp "5" _ OK <- signSendRecv sh sKey ("5", sId, _SEND' "hello")
|
||||
-- receive notification
|
||||
APNSMockRequest {notification, sendApnsResponse = send'} <- atomically $ readTBQueue apnsQ
|
||||
let APNSNotification {aps = APNSMutableContent {}, notificationData = Just ntfData'} = notification
|
||||
Right nonce' = C.cbNonce <$> ntfData' .-> "nonce"
|
||||
Right message = ntfData' .-> "message"
|
||||
Right ntfDataDecrypted = C.cbDecrypt dhSecret nonce' message
|
||||
Right APNS.PNMessageData {smpQueue = SMPQueueNtf {smpServer, notifierId}, nmsgNonce, encNMsgMeta} =
|
||||
parse strP (AP.INTERNAL "error parsing PNMessageData") ntfDataDecrypted
|
||||
Right nMsgMeta = C.cbDecrypt rcvNtfDhSecret nmsgNonce encNMsgMeta
|
||||
Right NMsgMeta {msgId, msgTs} = parse smpP (AP.INTERNAL "error parsing NMsgMeta") nMsgMeta
|
||||
smpServer `shouldBe` srv
|
||||
notifierId `shouldBe` nId
|
||||
send' APNSRespOk
|
||||
-- receive message
|
||||
Resp "" _ (MSG RcvMessage {msgId = mId1, msgBody = EncRcvMsgBody body}) <- tGet rh
|
||||
Right ClientRcvMsgBody {msgTs = mTs, msgBody} <- pure $ parseAll clientRcvMsgBodyP =<< first show (C.cbDecrypt rcvDhSecret (C.cbNonce mId1) body)
|
||||
mId1 `shouldBe` msgId
|
||||
mTs `shouldBe` msgTs
|
||||
(msgBody, "hello") #== "delivered from queue"
|
||||
Resp "6" _ OK <- signSendRecv rh rKey ("6", rId, ACK mId1)
|
||||
pure ()
|
||||
-- replace token
|
||||
let tkn' = DeviceToken PPApnsTest "efgh"
|
||||
RespNtf "7" tId' NROk <- signSendRecvNtf nh tknKey ("7", tId, TRPL tkn')
|
||||
tId `shouldBe` tId'
|
||||
APNSMockRequest {notification = APNSNotification {aps = APNSBackground _, notificationData = Just ntfData2}, sendApnsResponse = send2} <-
|
||||
atomically $ readTBQueue apnsQ
|
||||
send2 APNSRespOk
|
||||
let Right verification2 = ntfData2 .-> "verification"
|
||||
Right nonce2 = C.cbNonce <$> ntfData2 .-> "nonce"
|
||||
Right code2 = NtfRegCode <$> C.cbDecrypt dhSecret nonce2 verification2
|
||||
RespNtf "8" _ NROk <- signSendRecvNtf nh tknKey ("8", tId, TVFY code2)
|
||||
RespNtf "8a" _ (NRTkn NTActive) <- signSendRecvNtf nh tknKey ("8a", tId, TCHK)
|
||||
-- send message
|
||||
Resp "9" _ OK <- signSendRecv sh sKey ("9", sId, _SEND' "hello 2")
|
||||
APNSMockRequest {notification = notification3, sendApnsResponse = send3} <- atomically $ readTBQueue apnsQ
|
||||
let APNSNotification {aps = APNSMutableContent {}, notificationData = Just ntfData3} = notification3
|
||||
Right nonce3 = C.cbNonce <$> ntfData3 .-> "nonce"
|
||||
Right message3 = ntfData3 .-> "message"
|
||||
Right ntfDataDecrypted3 = C.cbDecrypt dhSecret nonce3 message3
|
||||
Right APNS.PNMessageData {smpQueue = SMPQueueNtf {smpServer = smpServer3, notifierId = notifierId3}} =
|
||||
parse strP (AP.INTERNAL "error parsing PNMessageData") ntfDataDecrypted3
|
||||
smpServer3 `shouldBe` srv
|
||||
notifierId3 `shouldBe` nId
|
||||
send3 APNSRespOk
|
|
@ -11,6 +11,7 @@ import Crypto.Random
|
|||
import qualified Data.ByteString.Char8 as B
|
||||
import qualified Data.List.NonEmpty as L
|
||||
import Network.Socket (HostName, ServiceName)
|
||||
import NtfClient (ntfTestPort)
|
||||
import SMPClient
|
||||
( serverBracket,
|
||||
testKeyHash,
|
||||
|
@ -24,7 +25,7 @@ import Simplex.Messaging.Agent.Env.SQLite
|
|||
import Simplex.Messaging.Agent.Protocol
|
||||
import Simplex.Messaging.Agent.RetryInterval
|
||||
import Simplex.Messaging.Agent.Server (runSMPAgentBlocking)
|
||||
import Simplex.Messaging.Client (SMPClientConfig (..), smpDefaultConfig)
|
||||
import Simplex.Messaging.Client (ProtocolClientConfig (..), defaultClientConfig)
|
||||
import Simplex.Messaging.Transport
|
||||
import Simplex.Messaging.Transport.Client
|
||||
import Simplex.Messaging.Transport.KeepAlive
|
||||
|
@ -154,20 +155,36 @@ smpAgentTest1_1_1 test' =
|
|||
_test [h] = test' h
|
||||
_test _ = error "expected 1 handle"
|
||||
|
||||
cfg :: AgentConfig
|
||||
cfg =
|
||||
testSMPServer :: SMPServer
|
||||
testSMPServer = "smp://LcJUMfVhwD8yxjAiSaDzzGF3-kLG4Uh0Fl_ZIjrRwjI=@localhost:5001"
|
||||
|
||||
initAgentServers :: InitialAgentServers
|
||||
initAgentServers =
|
||||
InitialAgentServers
|
||||
{ smp = L.fromList [testSMPServer],
|
||||
ntf = ["ntf://LcJUMfVhwD8yxjAiSaDzzGF3-kLG4Uh0Fl_ZIjrRwjI=@localhost:6001"]
|
||||
}
|
||||
|
||||
agentCfg :: AgentConfig
|
||||
agentCfg =
|
||||
defaultAgentConfig
|
||||
{ tcpPort = agentTestPort,
|
||||
initialSMPServers = L.fromList ["smp://LcJUMfVhwD8yxjAiSaDzzGF3-kLG4Uh0Fl_ZIjrRwjI=@localhost:5001"],
|
||||
tbqSize = 1,
|
||||
dbFile = testDB,
|
||||
smpCfg =
|
||||
smpDefaultConfig
|
||||
defaultClientConfig
|
||||
{ qSize = 1,
|
||||
defaultTransport = (testPort, transport @TLS),
|
||||
tcpTimeout = 500_000
|
||||
},
|
||||
reconnectInterval = (reconnectInterval defaultAgentConfig) {initialInterval = 50_000},
|
||||
ntfCfg =
|
||||
defaultClientConfig
|
||||
{ qSize = 1,
|
||||
defaultTransport = (ntfTestPort, transport @TLS)
|
||||
},
|
||||
reconnectInterval = defaultReconnectInterval {initialInterval = 50_000},
|
||||
ntfWorkerDelay = 1000,
|
||||
ntfSMPWorkerDelay = 1000,
|
||||
caCertificateFile = "tests/fixtures/ca.crt",
|
||||
privateKeyFile = "tests/fixtures/server.key",
|
||||
certificateFile = "tests/fixtures/server.crt"
|
||||
|
@ -175,9 +192,10 @@ cfg =
|
|||
|
||||
withSmpAgentThreadOn_ :: (MonadUnliftIO m, MonadRandom m) => ATransport -> (ServiceName, ServiceName, String) -> m () -> (ThreadId -> m a) -> m a
|
||||
withSmpAgentThreadOn_ t (port', smpPort', db') afterProcess =
|
||||
let cfg' = cfg {tcpPort = port', dbFile = db', initialSMPServers = L.fromList [SMPServer "localhost" smpPort' testKeyHash]}
|
||||
let cfg' = agentCfg {tcpPort = port', dbFile = db'}
|
||||
initServers' = initAgentServers {smp = L.fromList [SMPServer "localhost" smpPort' testKeyHash]}
|
||||
in serverBracket
|
||||
(\started -> runSMPAgentBlocking t started cfg')
|
||||
(\started -> runSMPAgentBlocking t started cfg' initServers')
|
||||
afterProcess
|
||||
|
||||
withSmpAgentThreadOn :: (MonadUnliftIO m, MonadRandom m) => ATransport -> (ServiceName, ServiceName, String) -> (ThreadId -> m a) -> m a
|
||||
|
@ -191,7 +209,7 @@ withSmpAgent t = withSmpAgentOn t (agentTestPort, testPort, testDB)
|
|||
|
||||
testSMPAgentClientOn :: (Transport c, MonadUnliftIO m) => ServiceName -> (c -> m a) -> m a
|
||||
testSMPAgentClientOn port' client = do
|
||||
runTransportClient agentTestHost port' testKeyHash (Just defaultKeepAliveOpts) $ \h -> do
|
||||
runTransportClient agentTestHost port' (Just testKeyHash) (Just defaultKeepAliveOpts) $ \h -> do
|
||||
line <- liftIO $ getLn h
|
||||
if line == "Welcome to SMP agent v" <> B.pack simplexMQVersion
|
||||
then client h
|
||||
|
|
|
@ -19,10 +19,10 @@ import Simplex.Messaging.Encoding
|
|||
import Simplex.Messaging.Protocol
|
||||
import Simplex.Messaging.Server (runSMPServerBlocking)
|
||||
import Simplex.Messaging.Server.Env.STM
|
||||
import Simplex.Messaging.Server.StoreLog (openReadStoreLog)
|
||||
import Simplex.Messaging.Transport
|
||||
import Simplex.Messaging.Transport.Client
|
||||
import Simplex.Messaging.Transport.KeepAlive
|
||||
import Simplex.Messaging.Version
|
||||
import Test.Hspec
|
||||
import UnliftIO.Concurrent
|
||||
import qualified UnliftIO.Exception as E
|
||||
|
@ -44,13 +44,22 @@ testKeyHash = "LcJUMfVhwD8yxjAiSaDzzGF3-kLG4Uh0Fl_ZIjrRwjI="
|
|||
testStoreLogFile :: FilePath
|
||||
testStoreLogFile = "tests/tmp/smp-server-store.log"
|
||||
|
||||
testStoreMsgsFile :: FilePath
|
||||
testStoreMsgsFile = "tests/tmp/smp-server-messages.log"
|
||||
|
||||
testServerStatsFile :: FilePath
|
||||
testServerStatsFile = "tests/tmp/smp-server-stats.log"
|
||||
|
||||
testSMPClient :: (Transport c, MonadUnliftIO m) => (THandle c -> m a) -> m a
|
||||
testSMPClient client =
|
||||
runTransportClient testHost testPort testKeyHash (Just defaultKeepAliveOpts) $ \h ->
|
||||
liftIO (runExceptT $ clientHandshake h testKeyHash) >>= \case
|
||||
runTransportClient testHost testPort (Just testKeyHash) (Just defaultKeepAliveOpts) $ \h ->
|
||||
liftIO (runExceptT $ smpClientHandshake h testKeyHash supportedSMPServerVRange) >>= \case
|
||||
Right th -> client th
|
||||
Left e -> error $ show e
|
||||
|
||||
cfgV2 :: ServerConfig
|
||||
cfgV2 = cfg {smpServerVRange = mkVersionRange 1 2}
|
||||
|
||||
cfg :: ServerConfig
|
||||
cfg =
|
||||
ServerConfig
|
||||
|
@ -60,19 +69,28 @@ cfg =
|
|||
msgQueueQuota = 4,
|
||||
queueIdBytes = 24,
|
||||
msgIdBytes = 24,
|
||||
storeLog = Nothing,
|
||||
storeLogFile = Nothing,
|
||||
storeMsgsFile = Nothing,
|
||||
allowNewQueues = True,
|
||||
messageTTL = Just $ 7 * 86400, -- seconds, 7 days
|
||||
expireMessagesInterval = Just 21600_000000, -- microseconds, 6 hours
|
||||
messageExpiration = Just defaultMessageExpiration,
|
||||
inactiveClientExpiration = Just defaultInactiveClientExpiration,
|
||||
logStatsInterval = Nothing,
|
||||
logStatsStartTime = 0,
|
||||
serverStatsFile = Nothing,
|
||||
caCertificateFile = "tests/fixtures/ca.crt",
|
||||
privateKeyFile = "tests/fixtures/server.key",
|
||||
certificateFile = "tests/fixtures/server.crt"
|
||||
certificateFile = "tests/fixtures/server.crt",
|
||||
smpServerVRange = supportedSMPServerVRange
|
||||
}
|
||||
|
||||
withSmpServerStoreMsgLogOnV2 :: (MonadUnliftIO m, MonadRandom m) => ATransport -> ServiceName -> (ThreadId -> m a) -> m a
|
||||
withSmpServerStoreMsgLogOnV2 t = withSmpServerConfigOn t cfgV2 {storeLogFile = Just testStoreLogFile, storeMsgsFile = Just testStoreMsgsFile}
|
||||
|
||||
withSmpServerStoreMsgLogOn :: (MonadUnliftIO m, MonadRandom m) => ATransport -> ServiceName -> (ThreadId -> m a) -> m a
|
||||
withSmpServerStoreMsgLogOn t = withSmpServerConfigOn t cfg {storeLogFile = Just testStoreLogFile, storeMsgsFile = Just testStoreMsgsFile, serverStatsFile = Just testServerStatsFile}
|
||||
|
||||
withSmpServerStoreLogOn :: (MonadUnliftIO m, MonadRandom m) => ATransport -> ServiceName -> (ThreadId -> m a) -> m a
|
||||
withSmpServerStoreLogOn t port' client = do
|
||||
s <- liftIO $ openReadStoreLog testStoreLogFile
|
||||
withSmpServerConfigOn t cfg {storeLog = Just s} port' client
|
||||
withSmpServerStoreLogOn t = withSmpServerConfigOn t cfg {storeLogFile = Just testStoreLogFile, serverStatsFile = Just testServerStatsFile}
|
||||
|
||||
withSmpServerConfigOn :: (MonadUnliftIO m, MonadRandom m) => ATransport -> ServerConfig -> ServiceName -> (ThreadId -> m a) -> m a
|
||||
withSmpServerConfigOn t cfg' port' =
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE DuplicateRecordFields #-}
|
||||
{-# LANGUAGE GADTs #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
|
@ -13,6 +14,7 @@ import Control.Concurrent (ThreadId, killThread, threadDelay)
|
|||
import Control.Concurrent.STM
|
||||
import Control.Exception (SomeException, try)
|
||||
import Control.Monad.Except (forM, forM_, runExceptT)
|
||||
import Data.Bifunctor (first)
|
||||
import Data.ByteString.Base64
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
|
@ -20,8 +22,10 @@ import SMPClient
|
|||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Encoding
|
||||
import Simplex.Messaging.Encoding.String
|
||||
import Simplex.Messaging.Parsers (parseAll)
|
||||
import Simplex.Messaging.Protocol
|
||||
import Simplex.Messaging.Server.Env.STM (ServerConfig (..))
|
||||
import Simplex.Messaging.Server.Expiration
|
||||
import Simplex.Messaging.Transport
|
||||
import System.Directory (removeFile)
|
||||
import System.TimeIt (timeItT)
|
||||
|
@ -33,14 +37,19 @@ serverTests :: ATransport -> Spec
|
|||
serverTests t@(ATransport t') = do
|
||||
describe "SMP syntax" $ syntaxTests t
|
||||
describe "SMP queues" $ do
|
||||
describe "NEW and KEY commands, SEND messages" $ testCreateSecure t
|
||||
describe "NEW and KEY commands, SEND messages (v2)" $ testCreateSecureV2 t'
|
||||
describe "NEW and KEY commands, SEND messages (v3)" $ testCreateSecure t
|
||||
describe "NEW, OFF and DEL commands, SEND messages" $ testCreateDelete t
|
||||
describe "Stress test" $ stressTest t
|
||||
describe "allowNewQueues setting" $ testAllowNewQueues t'
|
||||
describe "SMP messages" $ do
|
||||
describe "duplex communication over 2 SMP connections" $ testDuplex t
|
||||
describe "switch subscription to another TCP connection" $ testSwitchSub t
|
||||
describe "GET command" $ testGetCommand t'
|
||||
describe "GET & SUB commands" $ testGetSubCommands t'
|
||||
describe "Store log" $ testWithStoreLog t
|
||||
describe "Restore messages" $ testRestoreMessages t
|
||||
describe "Restore messages (v2)" $ testRestoreMessagesV2 t
|
||||
describe "Timing of AUTH error" $ testTiming t
|
||||
describe "Message notifications" $ testMessageNotifications t
|
||||
describe "Message expiration" $ do
|
||||
|
@ -54,15 +63,18 @@ pattern Resp corrId queueId command <- (_, _, (corrId, queueId, Right command))
|
|||
pattern Ids :: RecipientId -> SenderId -> RcvPublicDhKey -> BrokerMsg
|
||||
pattern Ids rId sId srvDh <- IDS (QIK rId sId srvDh)
|
||||
|
||||
pattern Msg :: MsgId -> MsgBody -> BrokerMsg
|
||||
pattern Msg msgId body <- MSG RcvMessage {msgId, msgBody = EncRcvMsgBody body}
|
||||
|
||||
sendRecv :: forall c p. (Transport c, PartyI p) => THandle c -> (Maybe C.ASignature, ByteString, ByteString, Command p) -> IO (SignedTransmission BrokerMsg)
|
||||
sendRecv h@THandle {sessionId} (sgn, corrId, qId, cmd) = do
|
||||
let t = encodeTransmission sessionId (CorrId corrId, qId, cmd)
|
||||
sendRecv h@THandle {thVersion, sessionId} (sgn, corrId, qId, cmd) = do
|
||||
let t = encodeTransmission thVersion sessionId (CorrId corrId, qId, cmd)
|
||||
Right () <- tPut h (sgn, t)
|
||||
tGet h
|
||||
|
||||
signSendRecv :: forall c p. (Transport c, PartyI p) => THandle c -> C.APrivateSignKey -> (ByteString, ByteString, Command p) -> IO (SignedTransmission BrokerMsg)
|
||||
signSendRecv h@THandle {sessionId} pk (corrId, qId, cmd) = do
|
||||
let t = encodeTransmission sessionId (CorrId corrId, qId, cmd)
|
||||
signSendRecv h@THandle {thVersion, sessionId} pk (corrId, qId, cmd) = do
|
||||
let t = encodeTransmission thVersion sessionId (CorrId corrId, qId, cmd)
|
||||
Right sig <- runExceptT $ C.sign pk t
|
||||
Right () <- tPut h (Just sig, t)
|
||||
tGet h
|
||||
|
@ -70,31 +82,45 @@ signSendRecv h@THandle {sessionId} pk (corrId, qId, cmd) = do
|
|||
(#==) :: (HasCallStack, Eq a, Show a) => (a, a) -> String -> Assertion
|
||||
(actual, expected) #== message = assertEqual message expected actual
|
||||
|
||||
testCreateSecure :: ATransport -> Spec
|
||||
testCreateSecure (ATransport t) =
|
||||
_SEND :: MsgBody -> Command 'Sender
|
||||
_SEND = SEND noMsgFlags
|
||||
|
||||
_SEND' :: MsgBody -> Command 'Sender
|
||||
_SEND' = SEND MsgFlags {notification = True}
|
||||
|
||||
decryptMsgV2 :: C.DhSecret 'C.X25519 -> ByteString -> ByteString -> Either C.CryptoError ByteString
|
||||
decryptMsgV2 dhShared = C.cbDecrypt dhShared . C.cbNonce
|
||||
|
||||
decryptMsgV3 :: C.DhSecret 'C.X25519 -> ByteString -> ByteString -> Either String MsgBody
|
||||
decryptMsgV3 dhShared nonce body = do
|
||||
ClientRcvMsgBody {msgBody} <- parseAll clientRcvMsgBodyP =<< first show (C.cbDecrypt dhShared (C.cbNonce nonce) body)
|
||||
pure msgBody
|
||||
|
||||
testCreateSecureV2 :: forall c. Transport c => TProxy c -> Spec
|
||||
testCreateSecureV2 _ =
|
||||
it "should create (NEW) and secure (KEY) queue" $
|
||||
smpTest t $ \h -> do
|
||||
withSmpServerConfigOn (transport @c) cfgV2 testPort $ \_ -> testSMPClient @c $ \h -> do
|
||||
(rPub, rKey) <- C.generateSignatureKeyPair C.SEd448
|
||||
(dhPub, dhPriv :: C.PrivateKeyX25519) <- C.generateKeyPair'
|
||||
Resp "abcd" rId1 (Ids rId sId srvDh) <- signSendRecv h rKey ("abcd", "", NEW rPub dhPub)
|
||||
let dec nonce = C.cbDecrypt (C.dh' srvDh dhPriv) (C.cbNonce nonce)
|
||||
let dec = decryptMsgV2 $ C.dh' srvDh dhPriv
|
||||
(rId1, "") #== "creates queue"
|
||||
|
||||
Resp "bcda" sId1 ok1 <- sendRecv h ("", "bcda", sId, SEND "hello")
|
||||
Resp "bcda" sId1 ok1 <- sendRecv h ("", "bcda", sId, _SEND "hello")
|
||||
(ok1, OK) #== "accepts unsigned SEND"
|
||||
(sId1, sId) #== "same queue ID in response 1"
|
||||
|
||||
Resp "" _ (MSG mId1 _ msg1) <- tGet h
|
||||
Resp "" _ (Msg mId1 msg1) <- tGet h
|
||||
(dec mId1 msg1, Right "hello") #== "delivers message"
|
||||
|
||||
Resp "cdab" _ ok4 <- signSendRecv h rKey ("cdab", rId, ACK)
|
||||
Resp "cdab" _ ok4 <- signSendRecv h rKey ("cdab", rId, ACK mId1)
|
||||
(ok4, OK) #== "replies OK when message acknowledged if no more messages"
|
||||
|
||||
Resp "dabc" _ err6 <- signSendRecv h rKey ("dabc", rId, ACK)
|
||||
Resp "dabc" _ err6 <- signSendRecv h rKey ("dabc", rId, ACK mId1)
|
||||
(err6, ERR NO_MSG) #== "replies ERR when message acknowledged without messages"
|
||||
|
||||
(sPub, sKey) <- C.generateSignatureKeyPair C.SEd448
|
||||
Resp "abcd" sId2 err1 <- signSendRecv h sKey ("abcd", sId, SEND "hello")
|
||||
Resp "abcd" sId2 err1 <- signSendRecv h sKey ("abcd", sId, _SEND "hello")
|
||||
(err1, ERR AUTH) #== "rejects signed SEND"
|
||||
(sId2, sId) #== "same queue ID in response 2"
|
||||
|
||||
|
@ -111,18 +137,89 @@ testCreateSecure (ATransport t) =
|
|||
Resp "abcd" _ err4 <- signSendRecv h rKey ("abcd", rId, KEY sPub)
|
||||
(err4, ERR AUTH) #== "rejects KEY if already secured"
|
||||
|
||||
Resp "bcda" _ ok3 <- signSendRecv h sKey ("bcda", sId, SEND "hello again")
|
||||
Resp "bcda" _ ok3 <- signSendRecv h sKey ("bcda", sId, _SEND "hello again")
|
||||
(ok3, OK) #== "accepts signed SEND"
|
||||
|
||||
Resp "" _ (MSG mId2 _ msg2) <- tGet h
|
||||
Resp "" _ (Msg mId2 msg2) <- tGet h
|
||||
(dec mId2 msg2, Right "hello again") #== "delivers message 2"
|
||||
|
||||
Resp "cdab" _ ok5 <- signSendRecv h rKey ("cdab", rId, ACK)
|
||||
Resp "cdab" _ ok5 <- signSendRecv h rKey ("cdab", rId, ACK mId2)
|
||||
(ok5, OK) #== "replies OK when message acknowledged 2"
|
||||
|
||||
Resp "dabc" _ err5 <- sendRecv h ("", "dabc", sId, SEND "hello")
|
||||
Resp "dabc" _ err5 <- sendRecv h ("", "dabc", sId, _SEND "hello")
|
||||
(err5, ERR AUTH) #== "rejects unsigned SEND"
|
||||
|
||||
let maxAllowedMessage = B.replicate maxMessageLength '-'
|
||||
Resp "bcda" _ OK <- signSendRecv h sKey ("bcda", sId, _SEND maxAllowedMessage)
|
||||
Resp "" _ (Msg mId3 msg3) <- tGet h
|
||||
(dec mId3 msg3, Right maxAllowedMessage) #== "delivers message of max size"
|
||||
|
||||
let biggerMessage = B.replicate (maxMessageLength + 1) '-'
|
||||
Resp "bcda" _ (ERR LARGE_MSG) <- signSendRecv h sKey ("bcda", sId, _SEND biggerMessage)
|
||||
pure ()
|
||||
|
||||
testCreateSecure :: ATransport -> Spec
|
||||
testCreateSecure (ATransport t) =
|
||||
it "should create (NEW) and secure (KEY) queue" $
|
||||
smpTest t $ \h -> do
|
||||
(rPub, rKey) <- C.generateSignatureKeyPair C.SEd448
|
||||
(dhPub, dhPriv :: C.PrivateKeyX25519) <- C.generateKeyPair'
|
||||
Resp "abcd" rId1 (Ids rId sId srvDh) <- signSendRecv h rKey ("abcd", "", NEW rPub dhPub)
|
||||
let dec = decryptMsgV3 $ C.dh' srvDh dhPriv
|
||||
(rId1, "") #== "creates queue"
|
||||
|
||||
Resp "bcda" sId1 ok1 <- sendRecv h ("", "bcda", sId, _SEND "hello")
|
||||
(ok1, OK) #== "accepts unsigned SEND"
|
||||
(sId1, sId) #== "same queue ID in response 1"
|
||||
|
||||
Resp "" _ (Msg mId1 msg1) <- tGet h
|
||||
(dec mId1 msg1, Right "hello") #== "delivers message"
|
||||
|
||||
Resp "cdab" _ ok4 <- signSendRecv h rKey ("cdab", rId, ACK mId1)
|
||||
(ok4, OK) #== "replies OK when message acknowledged if no more messages"
|
||||
|
||||
Resp "dabc" _ err6 <- signSendRecv h rKey ("dabc", rId, ACK mId1)
|
||||
(err6, ERR NO_MSG) #== "replies ERR when message acknowledged without messages"
|
||||
|
||||
(sPub, sKey) <- C.generateSignatureKeyPair C.SEd448
|
||||
Resp "abcd" sId2 err1 <- signSendRecv h sKey ("abcd", sId, _SEND "hello")
|
||||
(err1, ERR AUTH) #== "rejects signed SEND"
|
||||
(sId2, sId) #== "same queue ID in response 2"
|
||||
|
||||
Resp "bcda" _ err2 <- sendRecv h (sampleSig, "bcda", rId, KEY sPub)
|
||||
(err2, ERR AUTH) #== "rejects KEY with wrong signature"
|
||||
|
||||
Resp "cdab" _ err3 <- signSendRecv h rKey ("cdab", sId, KEY sPub)
|
||||
(err3, ERR AUTH) #== "rejects KEY with sender's ID"
|
||||
|
||||
Resp "dabc" rId2 ok2 <- signSendRecv h rKey ("dabc", rId, KEY sPub)
|
||||
(ok2, OK) #== "secures queue"
|
||||
(rId2, rId) #== "same queue ID in response 3"
|
||||
|
||||
Resp "abcd" _ err4 <- signSendRecv h rKey ("abcd", rId, KEY sPub)
|
||||
(err4, ERR AUTH) #== "rejects KEY if already secured"
|
||||
|
||||
Resp "bcda" _ ok3 <- signSendRecv h sKey ("bcda", sId, _SEND "hello again")
|
||||
(ok3, OK) #== "accepts signed SEND"
|
||||
|
||||
Resp "" _ (Msg mId2 msg2) <- tGet h
|
||||
(dec mId2 msg2, Right "hello again") #== "delivers message 2"
|
||||
|
||||
Resp "cdab" _ ok5 <- signSendRecv h rKey ("cdab", rId, ACK mId2)
|
||||
(ok5, OK) #== "replies OK when message acknowledged 2"
|
||||
|
||||
Resp "dabc" _ err5 <- sendRecv h ("", "dabc", sId, _SEND "hello")
|
||||
(err5, ERR AUTH) #== "rejects unsigned SEND"
|
||||
|
||||
let maxAllowedMessage = B.replicate maxMessageLength '-'
|
||||
Resp "bcda" _ OK <- signSendRecv h sKey ("bcda", sId, _SEND maxAllowedMessage)
|
||||
Resp "" _ (Msg mId3 msg3) <- tGet h
|
||||
(dec mId3 msg3, Right maxAllowedMessage) #== "delivers message of max size"
|
||||
|
||||
let biggerMessage = B.replicate (maxMessageLength + 1) '-'
|
||||
Resp "bcda" _ (ERR LARGE_MSG) <- signSendRecv h sKey ("bcda", sId, _SEND biggerMessage)
|
||||
pure ()
|
||||
|
||||
testCreateDelete :: ATransport -> Spec
|
||||
testCreateDelete (ATransport t) =
|
||||
it "should create (NEW), suspend (OFF) and delete (DEL) queue" $
|
||||
|
@ -130,20 +227,20 @@ testCreateDelete (ATransport t) =
|
|||
(rPub, rKey) <- C.generateSignatureKeyPair C.SEd25519
|
||||
(dhPub, dhPriv :: C.PrivateKeyX25519) <- C.generateKeyPair'
|
||||
Resp "abcd" rId1 (Ids rId sId srvDh) <- signSendRecv rh rKey ("abcd", "", NEW rPub dhPub)
|
||||
let dec nonce = C.cbDecrypt (C.dh' srvDh dhPriv) (C.cbNonce nonce)
|
||||
let dec = decryptMsgV3 $ C.dh' srvDh dhPriv
|
||||
(rId1, "") #== "creates queue"
|
||||
|
||||
(sPub, sKey) <- C.generateSignatureKeyPair C.SEd25519
|
||||
Resp "bcda" _ ok1 <- signSendRecv rh rKey ("bcda", rId, KEY sPub)
|
||||
(ok1, OK) #== "secures queue"
|
||||
|
||||
Resp "cdab" _ ok2 <- signSendRecv sh sKey ("cdab", sId, SEND "hello")
|
||||
Resp "cdab" _ ok2 <- signSendRecv sh sKey ("cdab", sId, _SEND "hello")
|
||||
(ok2, OK) #== "accepts signed SEND"
|
||||
|
||||
Resp "dabc" _ ok7 <- signSendRecv sh sKey ("dabc", sId, SEND "hello 2")
|
||||
Resp "dabc" _ ok7 <- signSendRecv sh sKey ("dabc", sId, _SEND "hello 2")
|
||||
(ok7, OK) #== "accepts signed SEND 2 - this message is not delivered because the first is not ACKed"
|
||||
|
||||
Resp "" _ (MSG mId1 _ msg1) <- tGet rh
|
||||
Resp "" _ (Msg mId1 msg1) <- tGet rh
|
||||
(dec mId1 msg1, Right "hello") #== "delivers message"
|
||||
|
||||
Resp "abcd" _ err1 <- sendRecv rh (sampleSig, "abcd", rId, OFF)
|
||||
|
@ -156,16 +253,16 @@ testCreateDelete (ATransport t) =
|
|||
(ok3, OK) #== "suspends queue"
|
||||
(rId2, rId) #== "same queue ID in response 2"
|
||||
|
||||
Resp "dabc" _ err3 <- signSendRecv sh sKey ("dabc", sId, SEND "hello")
|
||||
Resp "dabc" _ err3 <- signSendRecv sh sKey ("dabc", sId, _SEND "hello")
|
||||
(err3, ERR AUTH) #== "rejects signed SEND"
|
||||
|
||||
Resp "abcd" _ err4 <- sendRecv sh ("", "abcd", sId, SEND "hello")
|
||||
Resp "abcd" _ err4 <- sendRecv sh ("", "abcd", sId, _SEND "hello")
|
||||
(err4, ERR AUTH) #== "reject unsigned SEND too"
|
||||
|
||||
Resp "bcda" _ ok4 <- signSendRecv rh rKey ("bcda", rId, OFF)
|
||||
(ok4, OK) #== "accepts OFF when suspended"
|
||||
|
||||
Resp "cdab" _ (MSG mId2 _ msg2) <- signSendRecv rh rKey ("cdab", rId, SUB)
|
||||
Resp "cdab" _ (Msg mId2 msg2) <- signSendRecv rh rKey ("cdab", rId, SUB)
|
||||
(dec mId2 msg2, Right "hello") #== "accepts SUB when suspended and delivers the message again (because was not ACKed)"
|
||||
|
||||
Resp "dabc" _ err5 <- sendRecv rh (sampleSig, "dabc", rId, DEL)
|
||||
|
@ -178,13 +275,13 @@ testCreateDelete (ATransport t) =
|
|||
(ok6, OK) #== "deletes queue"
|
||||
(rId3, rId) #== "same queue ID in response 3"
|
||||
|
||||
Resp "cdab" _ err7 <- signSendRecv sh sKey ("cdab", sId, SEND "hello")
|
||||
Resp "cdab" _ err7 <- signSendRecv sh sKey ("cdab", sId, _SEND "hello")
|
||||
(err7, ERR AUTH) #== "rejects signed SEND when deleted"
|
||||
|
||||
Resp "dabc" _ err8 <- sendRecv sh ("", "dabc", sId, SEND "hello")
|
||||
Resp "dabc" _ err8 <- sendRecv sh ("", "dabc", sId, _SEND "hello")
|
||||
(err8, ERR AUTH) #== "rejects unsigned SEND too when deleted"
|
||||
|
||||
Resp "abcd" _ err11 <- signSendRecv rh rKey ("abcd", rId, ACK)
|
||||
Resp "abcd" _ err11 <- signSendRecv rh rKey ("abcd", rId, ACK "")
|
||||
(err11, ERR AUTH) #== "rejects ACK when conn deleted - the second message is deleted"
|
||||
|
||||
Resp "bcda" _ err9 <- signSendRecv rh rKey ("bcda", rId, OFF)
|
||||
|
@ -227,15 +324,15 @@ testDuplex (ATransport t) =
|
|||
(arPub, arKey) <- C.generateSignatureKeyPair C.SEd448
|
||||
(aDhPub, aDhPriv :: C.PrivateKeyX25519) <- C.generateKeyPair'
|
||||
Resp "abcd" _ (Ids aRcv aSnd aSrvDh) <- signSendRecv alice arKey ("abcd", "", NEW arPub aDhPub)
|
||||
let aDec nonce = C.cbDecrypt (C.dh' aSrvDh aDhPriv) (C.cbNonce nonce)
|
||||
let aDec = decryptMsgV3 $ C.dh' aSrvDh aDhPriv
|
||||
-- aSnd ID is passed to Bob out-of-band
|
||||
|
||||
(bsPub, bsKey) <- C.generateSignatureKeyPair C.SEd448
|
||||
Resp "bcda" _ OK <- sendRecv bob ("", "bcda", aSnd, SEND $ "key " <> strEncode bsPub)
|
||||
Resp "bcda" _ OK <- sendRecv bob ("", "bcda", aSnd, _SEND $ "key " <> strEncode bsPub)
|
||||
-- "key ..." is ad-hoc, not a part of SMP protocol
|
||||
|
||||
Resp "" _ (MSG mId1 _ msg1) <- tGet alice
|
||||
Resp "cdab" _ OK <- signSendRecv alice arKey ("cdab", aRcv, ACK)
|
||||
Resp "" _ (Msg mId1 msg1) <- tGet alice
|
||||
Resp "cdab" _ OK <- signSendRecv alice arKey ("cdab", aRcv, ACK mId1)
|
||||
Right ["key", bobKey] <- pure $ B.words <$> aDec mId1 msg1
|
||||
(bobKey, strEncode bsPub) #== "key received from Bob"
|
||||
Resp "dabc" _ OK <- signSendRecv alice arKey ("dabc", aRcv, KEY bsPub)
|
||||
|
@ -243,35 +340,35 @@ testDuplex (ATransport t) =
|
|||
(brPub, brKey) <- C.generateSignatureKeyPair C.SEd448
|
||||
(bDhPub, bDhPriv :: C.PrivateKeyX25519) <- C.generateKeyPair'
|
||||
Resp "abcd" _ (Ids bRcv bSnd bSrvDh) <- signSendRecv bob brKey ("abcd", "", NEW brPub bDhPub)
|
||||
let bDec nonce = C.cbDecrypt (C.dh' bSrvDh bDhPriv) (C.cbNonce nonce)
|
||||
Resp "bcda" _ OK <- signSendRecv bob bsKey ("bcda", aSnd, SEND $ "reply_id " <> encode bSnd)
|
||||
let bDec = decryptMsgV3 $ C.dh' bSrvDh bDhPriv
|
||||
Resp "bcda" _ OK <- signSendRecv bob bsKey ("bcda", aSnd, _SEND $ "reply_id " <> encode bSnd)
|
||||
-- "reply_id ..." is ad-hoc, not a part of SMP protocol
|
||||
|
||||
Resp "" _ (MSG mId2 _ msg2) <- tGet alice
|
||||
Resp "cdab" _ OK <- signSendRecv alice arKey ("cdab", aRcv, ACK)
|
||||
Resp "" _ (Msg mId2 msg2) <- tGet alice
|
||||
Resp "cdab" _ OK <- signSendRecv alice arKey ("cdab", aRcv, ACK mId2)
|
||||
Right ["reply_id", bId] <- pure $ B.words <$> aDec mId2 msg2
|
||||
(bId, encode bSnd) #== "reply queue ID received from Bob"
|
||||
|
||||
(asPub, asKey) <- C.generateSignatureKeyPair C.SEd448
|
||||
Resp "dabc" _ OK <- sendRecv alice ("", "dabc", bSnd, SEND $ "key " <> strEncode asPub)
|
||||
Resp "dabc" _ OK <- sendRecv alice ("", "dabc", bSnd, _SEND $ "key " <> strEncode asPub)
|
||||
-- "key ..." is ad-hoc, not a part of SMP protocol
|
||||
|
||||
Resp "" _ (MSG mId3 _ msg3) <- tGet bob
|
||||
Resp "abcd" _ OK <- signSendRecv bob brKey ("abcd", bRcv, ACK)
|
||||
Resp "" _ (Msg mId3 msg3) <- tGet bob
|
||||
Resp "abcd" _ OK <- signSendRecv bob brKey ("abcd", bRcv, ACK mId3)
|
||||
Right ["key", aliceKey] <- pure $ B.words <$> bDec mId3 msg3
|
||||
(aliceKey, strEncode asPub) #== "key received from Alice"
|
||||
Resp "bcda" _ OK <- signSendRecv bob brKey ("bcda", bRcv, KEY asPub)
|
||||
|
||||
Resp "cdab" _ OK <- signSendRecv bob bsKey ("cdab", aSnd, SEND "hi alice")
|
||||
Resp "cdab" _ OK <- signSendRecv bob bsKey ("cdab", aSnd, _SEND "hi alice")
|
||||
|
||||
Resp "" _ (MSG mId4 _ msg4) <- tGet alice
|
||||
Resp "dabc" _ OK <- signSendRecv alice arKey ("dabc", aRcv, ACK)
|
||||
Resp "" _ (Msg mId4 msg4) <- tGet alice
|
||||
Resp "dabc" _ OK <- signSendRecv alice arKey ("dabc", aRcv, ACK mId4)
|
||||
(aDec mId4 msg4, Right "hi alice") #== "message received from Bob"
|
||||
|
||||
Resp "abcd" _ OK <- signSendRecv alice asKey ("abcd", bSnd, SEND "how are you bob")
|
||||
Resp "abcd" _ OK <- signSendRecv alice asKey ("abcd", bSnd, _SEND "how are you bob")
|
||||
|
||||
Resp "" _ (MSG mId5 _ msg5) <- tGet bob
|
||||
Resp "bcda" _ OK <- signSendRecv bob brKey ("bcda", bRcv, ACK)
|
||||
Resp "" _ (Msg mId5 msg5) <- tGet bob
|
||||
Resp "bcda" _ OK <- signSendRecv bob brKey ("bcda", bRcv, ACK mId5)
|
||||
(bDec mId5 msg5, Right "how are you bob") #== "message received from alice"
|
||||
|
||||
testSwitchSub :: ATransport -> Spec
|
||||
|
@ -281,39 +378,106 @@ testSwitchSub (ATransport t) =
|
|||
(rPub, rKey) <- C.generateSignatureKeyPair C.SEd448
|
||||
(dhPub, dhPriv :: C.PrivateKeyX25519) <- C.generateKeyPair'
|
||||
Resp "abcd" _ (Ids rId sId srvDh) <- signSendRecv rh1 rKey ("abcd", "", NEW rPub dhPub)
|
||||
let dec nonce = C.cbDecrypt (C.dh' srvDh dhPriv) (C.cbNonce nonce)
|
||||
Resp "bcda" _ ok1 <- sendRecv sh ("", "bcda", sId, SEND "test1")
|
||||
let dec = decryptMsgV3 $ C.dh' srvDh dhPriv
|
||||
Resp "bcda" _ ok1 <- sendRecv sh ("", "bcda", sId, _SEND "test1")
|
||||
(ok1, OK) #== "sent test message 1"
|
||||
Resp "cdab" _ ok2 <- sendRecv sh ("", "cdab", sId, SEND "test2, no ACK")
|
||||
Resp "cdab" _ ok2 <- sendRecv sh ("", "cdab", sId, _SEND "test2, no ACK")
|
||||
(ok2, OK) #== "sent test message 2"
|
||||
|
||||
Resp "" _ (MSG mId1 _ msg1) <- tGet rh1
|
||||
Resp "" _ (Msg mId1 msg1) <- tGet rh1
|
||||
(dec mId1 msg1, Right "test1") #== "test message 1 delivered to the 1st TCP connection"
|
||||
Resp "abcd" _ (MSG mId2 _ msg2) <- signSendRecv rh1 rKey ("abcd", rId, ACK)
|
||||
Resp "abcd" _ (Msg mId2 msg2) <- signSendRecv rh1 rKey ("abcd", rId, ACK mId1)
|
||||
(dec mId2 msg2, Right "test2, no ACK") #== "test message 2 delivered, no ACK"
|
||||
|
||||
Resp "bcda" _ (MSG mId2' _ msg2') <- signSendRecv rh2 rKey ("bcda", rId, SUB)
|
||||
Resp "bcda" _ (Msg mId2' msg2') <- signSendRecv rh2 rKey ("bcda", rId, SUB)
|
||||
(dec mId2' msg2', Right "test2, no ACK") #== "same simplex queue via another TCP connection, tes2 delivered again (no ACK in 1st queue)"
|
||||
Resp "cdab" _ OK <- signSendRecv rh2 rKey ("cdab", rId, ACK)
|
||||
Resp "cdab" _ OK <- signSendRecv rh2 rKey ("cdab", rId, ACK mId2')
|
||||
|
||||
Resp "" _ end <- tGet rh1
|
||||
(end, END) #== "unsubscribed the 1st TCP connection"
|
||||
|
||||
Resp "dabc" _ OK <- sendRecv sh ("", "dabc", sId, SEND "test3")
|
||||
Resp "dabc" _ OK <- sendRecv sh ("", "dabc", sId, _SEND "test3")
|
||||
|
||||
Resp "" _ (MSG mId3 _ msg3) <- tGet rh2
|
||||
Resp "" _ (Msg mId3 msg3) <- tGet rh2
|
||||
(dec mId3 msg3, Right "test3") #== "delivered to the 2nd TCP connection"
|
||||
|
||||
Resp "abcd" _ err <- signSendRecv rh1 rKey ("abcd", rId, ACK)
|
||||
Resp "abcd" _ err <- signSendRecv rh1 rKey ("abcd", rId, ACK mId3)
|
||||
(err, ERR NO_MSG) #== "rejects ACK from the 1st TCP connection"
|
||||
|
||||
Resp "bcda" _ ok3 <- signSendRecv rh2 rKey ("bcda", rId, ACK)
|
||||
Resp "bcda" _ ok3 <- signSendRecv rh2 rKey ("bcda", rId, ACK mId3)
|
||||
(ok3, OK) #== "accepts ACK from the 2nd TCP connection"
|
||||
|
||||
1000 `timeout` tGet @BrokerMsg rh1 >>= \case
|
||||
Nothing -> return ()
|
||||
Just _ -> error "nothing else is delivered to the 1st TCP connection"
|
||||
|
||||
testGetCommand :: forall c. Transport c => TProxy c -> Spec
|
||||
testGetCommand t =
|
||||
it "should retrieve messages from the queue using GET command" $ do
|
||||
(sPub, sKey) <- C.generateSignatureKeyPair C.SEd25519
|
||||
smpTest t $ \sh -> do
|
||||
queue <- newEmptyTMVarIO
|
||||
testSMPClient @c $ \rh ->
|
||||
atomically . putTMVar queue =<< createAndSecureQueue rh sPub
|
||||
testSMPClient @c $ \rh -> do
|
||||
(sId, rId, rKey, dhShared) <- atomically $ takeTMVar queue
|
||||
let dec = decryptMsgV3 dhShared
|
||||
Resp "1" _ OK <- signSendRecv sh sKey ("1", sId, _SEND "hello")
|
||||
Resp "2" _ (Msg mId1 msg1) <- signSendRecv rh rKey ("2", rId, GET)
|
||||
(dec mId1 msg1, Right "hello") #== "retrieved from queue"
|
||||
Resp "3" _ OK <- signSendRecv rh rKey ("3", rId, ACK mId1)
|
||||
Resp "4" _ OK <- signSendRecv rh rKey ("4", rId, GET)
|
||||
pure ()
|
||||
|
||||
testGetSubCommands :: forall c. Transport c => TProxy c -> Spec
|
||||
testGetSubCommands t =
|
||||
it "should retrieve messages with GET and receive with SUB, only one ACK would work" $ do
|
||||
(sPub, sKey) <- C.generateSignatureKeyPair C.SEd25519
|
||||
smpTest3 t $ \rh1 rh2 sh -> do
|
||||
(sId, rId, rKey, dhShared) <- createAndSecureQueue rh1 sPub
|
||||
let dec = decryptMsgV3 dhShared
|
||||
Resp "1" _ OK <- signSendRecv sh sKey ("1", sId, _SEND "hello 1")
|
||||
Resp "1a" _ OK <- signSendRecv sh sKey ("1a", sId, _SEND "hello 2")
|
||||
Resp "1b" _ OK <- signSendRecv sh sKey ("1b", sId, _SEND "hello 3")
|
||||
Resp "1c" _ OK <- signSendRecv sh sKey ("1c", sId, _SEND "hello 4")
|
||||
-- both get the same if not ACK'd
|
||||
Resp "" _ (Msg mId1 msg1) <- tGet rh1
|
||||
Resp "2" _ (Msg mId1' msg1') <- signSendRecv rh2 rKey ("2", rId, GET)
|
||||
(dec mId1 msg1, Right "hello 1") #== "received from queue via SUB"
|
||||
(dec mId1' msg1', Right "hello 1") #== "retrieved from queue with GET"
|
||||
mId1 `shouldBe` mId1'
|
||||
msg1 `shouldBe` msg1'
|
||||
-- subscriber cannot GET, getter cannot SUB
|
||||
Resp "3" _ (ERR (CMD PROHIBITED)) <- signSendRecv rh1 rKey ("3", rId, GET)
|
||||
Resp "3a" _ (ERR (CMD PROHIBITED)) <- signSendRecv rh2 rKey ("3a", rId, SUB)
|
||||
-- ACK for SUB delivers the next message
|
||||
Resp "4" _ (Msg mId2 msg2) <- signSendRecv rh1 rKey ("4", rId, ACK mId1)
|
||||
(dec mId2 msg2, Right "hello 2") #== "received from queue via SUB"
|
||||
-- bad msgId returns error
|
||||
Resp "5" _ (ERR NO_MSG) <- signSendRecv rh2 rKey ("5", rId, ACK "1234")
|
||||
-- already ACK'd by subscriber, but still returns OK when msgId matches
|
||||
Resp "5a" _ OK <- signSendRecv rh2 rKey ("5a", rId, ACK mId1)
|
||||
-- msg2 is not lost - even if subscriber does not ACK it, it is delivered to getter
|
||||
Resp "6" _ (Msg mId2' msg2') <- signSendRecv rh2 rKey ("6", rId, GET)
|
||||
(dec mId2' msg2', Right "hello 2") #== "retrieved from queue with GET"
|
||||
mId2 `shouldBe` mId2'
|
||||
msg2 `shouldBe` msg2'
|
||||
-- getter ACK returns OK, even though there is the next message
|
||||
Resp "7" _ OK <- signSendRecv rh2 rKey ("7", rId, ACK mId2')
|
||||
Resp "8" _ (Msg mId3 msg3) <- signSendRecv rh2 rKey ("8", rId, GET)
|
||||
(dec mId3 msg3, Right "hello 3") #== "retrieved from queue with GET"
|
||||
-- subscriber ACK does not lose message
|
||||
Resp "9" _ (Msg mId3' msg3') <- signSendRecv rh1 rKey ("9", rId, ACK mId2')
|
||||
(dec mId3' msg3', Right "hello 3") #== "retrieved from queue with GET"
|
||||
mId3 `shouldBe` mId3'
|
||||
msg3 `shouldBe` msg3'
|
||||
Resp "10" _ (Msg mId4 msg4) <- signSendRecv rh1 rKey ("10", rId, ACK mId3)
|
||||
(dec mId4 msg4, Right "hello 4") #== "retrieved from queue with GET"
|
||||
Resp "11" _ OK <- signSendRecv rh1 rKey ("11", rId, ACK mId4)
|
||||
-- no more messages for getter too
|
||||
Resp "12" _ OK <- signSendRecv rh2 rKey ("12", rId, GET)
|
||||
pure ()
|
||||
|
||||
testWithStoreLog :: ATransport -> Spec
|
||||
testWithStoreLog at@(ATransport t) =
|
||||
it "should store simplex queues to log and restore them after server restart" $ do
|
||||
|
@ -329,7 +493,8 @@ testWithStoreLog at@(ATransport t) =
|
|||
|
||||
withSmpServerStoreLogOn at testPort . runTest t $ \h -> runClient t $ \h1 -> do
|
||||
(sId1, rId1, rKey1, dhShared) <- createAndSecureQueue h sPub1
|
||||
Resp "abcd" _ (NID nId) <- signSendRecv h rKey1 ("abcd", rId1, NKEY nPub)
|
||||
(rcvNtfPubDhKey, _) <- C.generateKeyPair'
|
||||
Resp "abcd" _ (NID nId _) <- signSendRecv h rKey1 ("abcd", rId1, NKEY nPub rcvNtfPubDhKey)
|
||||
atomically $ do
|
||||
writeTVar recipientId1 rId1
|
||||
writeTVar recipientKey1 $ Just rKey1
|
||||
|
@ -337,26 +502,26 @@ testWithStoreLog at@(ATransport t) =
|
|||
writeTVar senderId1 sId1
|
||||
writeTVar notifierId nId
|
||||
Resp "dabc" _ OK <- signSendRecv h1 nKey ("dabc", nId, NSUB)
|
||||
Resp "bcda" _ OK <- signSendRecv h sKey1 ("bcda", sId1, SEND "hello")
|
||||
Resp "" _ (MSG mId1 _ msg1) <- tGet h
|
||||
(C.cbDecrypt dhShared (C.cbNonce mId1) msg1, Right "hello") #== "delivered from queue 1"
|
||||
Resp "" _ NMSG <- tGet h1
|
||||
Resp "bcda" _ OK <- signSendRecv h sKey1 ("bcda", sId1, _SEND' "hello")
|
||||
Resp "" _ (Msg mId1 msg1) <- tGet h
|
||||
(decryptMsgV3 dhShared mId1 msg1, Right "hello") #== "delivered from queue 1"
|
||||
Resp "" _ (NMSG _ _) <- tGet h1
|
||||
|
||||
(sId2, rId2, rKey2, dhShared2) <- createAndSecureQueue h sPub2
|
||||
atomically $ writeTVar senderId2 sId2
|
||||
Resp "cdab" _ OK <- signSendRecv h sKey2 ("cdab", sId2, SEND "hello too")
|
||||
Resp "" _ (MSG mId2 _ msg2) <- tGet h
|
||||
(C.cbDecrypt dhShared2 (C.cbNonce mId2) msg2, Right "hello too") #== "delivered from queue 2"
|
||||
Resp "cdab" _ OK <- signSendRecv h sKey2 ("cdab", sId2, _SEND "hello too")
|
||||
Resp "" _ (Msg mId2 msg2) <- tGet h
|
||||
(decryptMsgV3 dhShared2 mId2 msg2, Right "hello too") #== "delivered from queue 2"
|
||||
|
||||
Resp "dabc" _ OK <- signSendRecv h rKey2 ("dabc", rId2, DEL)
|
||||
pure ()
|
||||
|
||||
logSize `shouldReturn` 6
|
||||
logSize testStoreLogFile `shouldReturn` 6
|
||||
|
||||
withSmpServerThreadOn at testPort . runTest t $ \h -> do
|
||||
sId1 <- readTVarIO senderId1
|
||||
-- fails if store log is disabled
|
||||
Resp "bcda" _ (ERR AUTH) <- signSendRecv h sKey1 ("bcda", sId1, SEND "hello")
|
||||
Resp "bcda" _ (ERR AUTH) <- signSendRecv h sKey1 ("bcda", sId1, _SEND "hello")
|
||||
pure ()
|
||||
|
||||
withSmpServerStoreLogOn at testPort . runTest t $ \h -> runClient t $ \h1 -> do
|
||||
|
@ -367,16 +532,16 @@ testWithStoreLog at@(ATransport t) =
|
|||
sId1 <- readTVarIO senderId1
|
||||
nId <- readTVarIO notifierId
|
||||
Resp "dabc" _ OK <- signSendRecv h1 nKey ("dabc", nId, NSUB)
|
||||
Resp "bcda" _ OK <- signSendRecv h sKey1 ("bcda", sId1, SEND "hello")
|
||||
Resp "cdab" _ (MSG mId3 _ msg3) <- signSendRecv h rKey1 ("cdab", rId1, SUB)
|
||||
(C.cbDecrypt dh1 (C.cbNonce mId3) msg3, Right "hello") #== "delivered from restored queue"
|
||||
Resp "" _ NMSG <- tGet h1
|
||||
Resp "bcda" _ OK <- signSendRecv h sKey1 ("bcda", sId1, _SEND' "hello")
|
||||
Resp "cdab" _ (Msg mId3 msg3) <- signSendRecv h rKey1 ("cdab", rId1, SUB)
|
||||
(decryptMsgV3 dh1 mId3 msg3, Right "hello") #== "delivered from restored queue"
|
||||
Resp "" _ (NMSG _ _) <- tGet h1
|
||||
-- this queue is removed - not restored
|
||||
sId2 <- readTVarIO senderId2
|
||||
Resp "cdab" _ (ERR AUTH) <- signSendRecv h sKey2 ("cdab", sId2, SEND "hello too")
|
||||
Resp "cdab" _ (ERR AUTH) <- signSendRecv h sKey2 ("cdab", sId2, _SEND "hello too")
|
||||
pure ()
|
||||
|
||||
logSize `shouldReturn` 1
|
||||
logSize testStoreLogFile `shouldReturn` 1
|
||||
removeFile testStoreLogFile
|
||||
where
|
||||
runTest :: Transport c => TProxy c -> (THandle c -> IO ()) -> ThreadId -> Expectation
|
||||
|
@ -387,11 +552,149 @@ testWithStoreLog at@(ATransport t) =
|
|||
runClient :: Transport c => TProxy c -> (THandle c -> IO ()) -> Expectation
|
||||
runClient _ test' = testSMPClient test' `shouldReturn` ()
|
||||
|
||||
logSize :: IO Int
|
||||
logSize =
|
||||
try (length . B.lines <$> B.readFile testStoreLogFile) >>= \case
|
||||
Right l -> pure l
|
||||
Left (_ :: SomeException) -> logSize
|
||||
logSize :: FilePath -> IO Int
|
||||
logSize f =
|
||||
try (length . B.lines <$> B.readFile f) >>= \case
|
||||
Right l -> pure l
|
||||
Left (_ :: SomeException) -> logSize f
|
||||
|
||||
testRestoreMessages :: ATransport -> Spec
|
||||
testRestoreMessages at@(ATransport t) =
|
||||
it "should store messages on exit and restore on start" $ do
|
||||
(sPub, sKey) <- C.generateSignatureKeyPair C.SEd25519
|
||||
recipientId <- newTVarIO ""
|
||||
recipientKey <- newTVarIO Nothing
|
||||
dhShared <- newTVarIO Nothing
|
||||
senderId <- newTVarIO ""
|
||||
|
||||
withSmpServerStoreMsgLogOn at testPort . runTest t $ \h -> do
|
||||
runClient t $ \h1 -> do
|
||||
(sId, rId, rKey, dh) <- createAndSecureQueue h1 sPub
|
||||
atomically $ do
|
||||
writeTVar recipientId rId
|
||||
writeTVar recipientKey $ Just rKey
|
||||
writeTVar dhShared $ Just dh
|
||||
writeTVar senderId sId
|
||||
Resp "1" _ OK <- signSendRecv h sKey ("1", sId, _SEND "hello")
|
||||
Resp "" _ (Msg mId1 msg1) <- tGet h1
|
||||
Resp "1a" _ OK <- signSendRecv h1 rKey ("1a", rId, ACK mId1)
|
||||
(decryptMsgV3 dh mId1 msg1, Right "hello") #== "message delivered"
|
||||
-- messages below are delivered after server restart
|
||||
sId <- readTVarIO senderId
|
||||
Resp "2" _ OK <- signSendRecv h sKey ("2", sId, _SEND "hello 2")
|
||||
Resp "3" _ OK <- signSendRecv h sKey ("3", sId, _SEND "hello 3")
|
||||
Resp "4" _ OK <- signSendRecv h sKey ("4", sId, _SEND "hello 4")
|
||||
pure ()
|
||||
|
||||
logSize testStoreLogFile `shouldReturn` 2
|
||||
logSize testStoreMsgsFile `shouldReturn` 3
|
||||
|
||||
withSmpServerStoreMsgLogOn at testPort . runTest t $ \h -> do
|
||||
rId <- readTVarIO recipientId
|
||||
Just rKey <- readTVarIO recipientKey
|
||||
Just dh <- readTVarIO dhShared
|
||||
let dec = decryptMsgV3 dh
|
||||
Resp "2" _ (Msg mId2 msg2) <- signSendRecv h rKey ("2", rId, SUB)
|
||||
(dec mId2 msg2, Right "hello 2") #== "restored message delivered"
|
||||
Resp "3" _ (Msg mId3 msg3) <- signSendRecv h rKey ("3", rId, ACK mId2)
|
||||
(dec mId3 msg3, Right "hello 3") #== "restored message delivered"
|
||||
Resp "4" _ (Msg mId4 msg4) <- signSendRecv h rKey ("4", rId, ACK mId3)
|
||||
(dec mId4 msg4, Right "hello 4") #== "restored message delivered"
|
||||
|
||||
logSize testStoreLogFile `shouldReturn` 1
|
||||
-- the last message is not removed because it was not ACK'd
|
||||
logSize testStoreMsgsFile `shouldReturn` 1
|
||||
|
||||
withSmpServerStoreMsgLogOn at testPort . runTest t $ \h -> do
|
||||
rId <- readTVarIO recipientId
|
||||
Just rKey <- readTVarIO recipientKey
|
||||
Just dh <- readTVarIO dhShared
|
||||
Resp "4" _ (Msg mId4 msg4) <- signSendRecv h rKey ("4", rId, SUB)
|
||||
Resp "5" _ OK <- signSendRecv h rKey ("5", rId, ACK mId4)
|
||||
(decryptMsgV3 dh mId4 msg4, Right "hello 4") #== "restored message delivered"
|
||||
|
||||
logSize testStoreLogFile `shouldReturn` 1
|
||||
logSize testStoreMsgsFile `shouldReturn` 0
|
||||
|
||||
removeFile testStoreLogFile
|
||||
removeFile testStoreMsgsFile
|
||||
where
|
||||
runTest :: Transport c => TProxy c -> (THandle c -> IO ()) -> ThreadId -> Expectation
|
||||
runTest _ test' server = do
|
||||
testSMPClient test' `shouldReturn` ()
|
||||
killThread server
|
||||
|
||||
runClient :: Transport c => TProxy c -> (THandle c -> IO ()) -> Expectation
|
||||
runClient _ test' = testSMPClient test' `shouldReturn` ()
|
||||
|
||||
testRestoreMessagesV2 :: ATransport -> Spec
|
||||
testRestoreMessagesV2 at@(ATransport t) =
|
||||
it "should store messages on exit and restore on start" $ do
|
||||
(sPub, sKey) <- C.generateSignatureKeyPair C.SEd25519
|
||||
recipientId <- newTVarIO ""
|
||||
recipientKey <- newTVarIO Nothing
|
||||
dhShared <- newTVarIO Nothing
|
||||
senderId <- newTVarIO ""
|
||||
|
||||
withSmpServerStoreMsgLogOnV2 at testPort . runTest t $ \h -> do
|
||||
runClient t $ \h1 -> do
|
||||
(sId, rId, rKey, dh) <- createAndSecureQueue h1 sPub
|
||||
atomically $ do
|
||||
writeTVar recipientId rId
|
||||
writeTVar recipientKey $ Just rKey
|
||||
writeTVar dhShared $ Just dh
|
||||
writeTVar senderId sId
|
||||
Resp "1" _ OK <- signSendRecv h sKey ("1", sId, _SEND "hello")
|
||||
Resp "" _ (Msg mId1 msg1) <- tGet h1
|
||||
Resp "1a" _ OK <- signSendRecv h1 rKey ("1a", rId, ACK mId1)
|
||||
(decryptMsgV2 dh mId1 msg1, Right "hello") #== "message delivered"
|
||||
-- messages below are delivered after server restart
|
||||
sId <- readTVarIO senderId
|
||||
Resp "2" _ OK <- signSendRecv h sKey ("2", sId, _SEND "hello 2")
|
||||
Resp "3" _ OK <- signSendRecv h sKey ("3", sId, _SEND "hello 3")
|
||||
Resp "4" _ OK <- signSendRecv h sKey ("4", sId, _SEND "hello 4")
|
||||
pure ()
|
||||
|
||||
logSize testStoreLogFile `shouldReturn` 2
|
||||
logSize testStoreMsgsFile `shouldReturn` 3
|
||||
|
||||
withSmpServerStoreMsgLogOnV2 at testPort . runTest t $ \h -> do
|
||||
rId <- readTVarIO recipientId
|
||||
Just rKey <- readTVarIO recipientKey
|
||||
Just dh <- readTVarIO dhShared
|
||||
let dec = decryptMsgV2 dh
|
||||
Resp "2" _ (Msg mId2 msg2) <- signSendRecv h rKey ("2", rId, SUB)
|
||||
(dec mId2 msg2, Right "hello 2") #== "restored message delivered"
|
||||
Resp "3" _ (Msg mId3 msg3) <- signSendRecv h rKey ("3", rId, ACK mId2)
|
||||
(dec mId3 msg3, Right "hello 3") #== "restored message delivered"
|
||||
Resp "4" _ (Msg mId4 msg4) <- signSendRecv h rKey ("4", rId, ACK mId3)
|
||||
(dec mId4 msg4, Right "hello 4") #== "restored message delivered"
|
||||
|
||||
logSize testStoreLogFile `shouldReturn` 1
|
||||
-- the last message is not removed because it was not ACK'd
|
||||
logSize testStoreMsgsFile `shouldReturn` 1
|
||||
|
||||
withSmpServerStoreMsgLogOnV2 at testPort . runTest t $ \h -> do
|
||||
rId <- readTVarIO recipientId
|
||||
Just rKey <- readTVarIO recipientKey
|
||||
Just dh <- readTVarIO dhShared
|
||||
Resp "4" _ (Msg mId4 msg4) <- signSendRecv h rKey ("4", rId, SUB)
|
||||
Resp "5" _ OK <- signSendRecv h rKey ("5", rId, ACK mId4)
|
||||
(decryptMsgV2 dh mId4 msg4, Right "hello 4") #== "restored message delivered"
|
||||
|
||||
logSize testStoreLogFile `shouldReturn` 1
|
||||
logSize testStoreMsgsFile `shouldReturn` 0
|
||||
|
||||
removeFile testStoreLogFile
|
||||
removeFile testStoreMsgsFile
|
||||
where
|
||||
runTest :: Transport c => TProxy c -> (THandle c -> IO ()) -> ThreadId -> Expectation
|
||||
runTest _ test' server = do
|
||||
testSMPClient test' `shouldReturn` ()
|
||||
killThread server
|
||||
|
||||
runClient :: Transport c => TProxy c -> (THandle c -> IO ()) -> Expectation
|
||||
runClient _ test' = testSMPClient test' `shouldReturn` ()
|
||||
|
||||
createAndSecureQueue :: Transport c => THandle c -> SndPublicVerifyKey -> IO (SenderId, RecipientId, RcvPrivateSignKey, RcvDhSecret)
|
||||
createAndSecureQueue h sPub = do
|
||||
|
@ -422,7 +725,7 @@ testTiming (ATransport t) =
|
|||
(rPub, rKey) <- generateKeys goodKeySize
|
||||
(dhPub, dhPriv :: C.PrivateKeyX25519) <- C.generateKeyPair'
|
||||
Resp "abcd" "" (Ids rId sId srvDh) <- signSendRecv rh rKey ("abcd", "", NEW rPub dhPub)
|
||||
let dec nonce = C.cbDecrypt (C.dh' srvDh dhPriv) (C.cbNonce nonce)
|
||||
let dec = decryptMsgV3 $ C.dh' srvDh dhPriv
|
||||
Resp "cdab" _ OK <- signSendRecv rh rKey ("cdab", rId, SUB)
|
||||
|
||||
(_, badKey) <- generateKeys badKeySize
|
||||
|
@ -431,11 +734,11 @@ testTiming (ATransport t) =
|
|||
(sPub, sKey) <- generateKeys goodKeySize
|
||||
Resp "dabc" _ OK <- signSendRecv rh rKey ("dabc", rId, KEY sPub)
|
||||
|
||||
Resp "bcda" _ OK <- signSendRecv sh sKey ("bcda", sId, SEND "hello")
|
||||
Resp "" _ (MSG mId _ msg) <- tGet rh
|
||||
Resp "bcda" _ OK <- signSendRecv sh sKey ("bcda", sId, _SEND "hello")
|
||||
Resp "" _ (Msg mId msg) <- tGet rh
|
||||
(dec mId msg, Right "hello") #== "delivered from queue"
|
||||
|
||||
runTimingTest sh badKey sId $ SEND "hello"
|
||||
runTimingTest sh badKey sId $ _SEND "hello"
|
||||
where
|
||||
generateKeys = \case
|
||||
32 -> C.generateSignatureKeyPair C.SEd25519
|
||||
|
@ -464,37 +767,49 @@ testMessageNotifications (ATransport t) =
|
|||
(nPub, nKey) <- C.generateSignatureKeyPair C.SEd25519
|
||||
smpTest4 t $ \rh sh nh1 nh2 -> do
|
||||
(sId, rId, rKey, dhShared) <- createAndSecureQueue rh sPub
|
||||
let dec nonce = C.cbDecrypt dhShared (C.cbNonce nonce)
|
||||
Resp "1" _ (NID nId) <- signSendRecv rh rKey ("1", rId, NKEY nPub)
|
||||
let dec = decryptMsgV3 dhShared
|
||||
(rcvNtfPubDhKey, _) <- C.generateKeyPair'
|
||||
Resp "1" _ (NID nId' _) <- signSendRecv rh rKey ("1", rId, NKEY nPub rcvNtfPubDhKey)
|
||||
Resp "1a" _ (NID nId _) <- signSendRecv rh rKey ("1a", rId, NKEY nPub rcvNtfPubDhKey)
|
||||
nId' `shouldNotBe` nId
|
||||
Resp "2" _ OK <- signSendRecv nh1 nKey ("2", nId, NSUB)
|
||||
Resp "3" _ OK <- signSendRecv sh sKey ("3", sId, SEND "hello")
|
||||
Resp "" _ (MSG mId1 _ msg1) <- tGet rh
|
||||
Resp "3" _ OK <- signSendRecv sh sKey ("3", sId, _SEND' "hello")
|
||||
Resp "" _ (Msg mId1 msg1) <- tGet rh
|
||||
(dec mId1 msg1, Right "hello") #== "delivered from queue"
|
||||
Resp "3a" _ OK <- signSendRecv rh rKey ("3a", rId, ACK)
|
||||
Resp "" _ NMSG <- tGet nh1
|
||||
Resp "3a" _ OK <- signSendRecv rh rKey ("3a", rId, ACK mId1)
|
||||
Resp "" _ (NMSG _ _) <- tGet nh1
|
||||
Resp "4" _ OK <- signSendRecv nh2 nKey ("4", nId, NSUB)
|
||||
Resp "" _ END <- tGet nh1
|
||||
Resp "5" _ OK <- signSendRecv sh sKey ("5", sId, SEND "hello again")
|
||||
Resp "" _ (MSG mId2 _ msg2) <- tGet rh
|
||||
Resp "5" _ OK <- signSendRecv sh sKey ("5", sId, _SEND' "hello again")
|
||||
Resp "" _ (Msg mId2 msg2) <- tGet rh
|
||||
Resp "5a" _ OK <- signSendRecv rh rKey ("5a", rId, ACK mId2)
|
||||
(dec mId2 msg2, Right "hello again") #== "delivered from queue again"
|
||||
Resp "" _ NMSG <- tGet nh2
|
||||
Resp "" _ (NMSG _ _) <- tGet nh2
|
||||
1000 `timeout` tGet @BrokerMsg nh1 >>= \case
|
||||
Nothing -> return ()
|
||||
Nothing -> pure ()
|
||||
Just _ -> error "nothing else should be delivered to the 1st notifier's TCP connection"
|
||||
Resp "6" _ OK <- signSendRecv rh rKey ("6", rId, NDEL)
|
||||
Resp "7" _ OK <- signSendRecv sh sKey ("7", sId, _SEND' "hello there")
|
||||
Resp "" _ (Msg mId3 msg3) <- tGet rh
|
||||
(dec mId3 msg3, Right "hello there") #== "delivered from queue again"
|
||||
1000 `timeout` tGet @BrokerMsg nh2 >>= \case
|
||||
Nothing -> pure ()
|
||||
Just _ -> error "nothing else should be delivered to the 2nd notifier's TCP connection"
|
||||
|
||||
testMsgExpireOnSend :: forall c. Transport c => TProxy c -> Spec
|
||||
testMsgExpireOnSend t =
|
||||
it "should expire messages that are not received before messageTTL on SEND" $ do
|
||||
(sPub, sKey) <- C.generateSignatureKeyPair C.SEd25519
|
||||
withSmpServerConfigOn (ATransport t) cfg {messageTTL = Just 1} testPort $ \_ ->
|
||||
let cfg' = cfg {messageExpiration = Just ExpirationConfig {ttl = 1, checkInterval = 10000}}
|
||||
withSmpServerConfigOn (ATransport t) cfg' testPort $ \_ ->
|
||||
testSMPClient @c $ \sh -> do
|
||||
(sId, rId, rKey, dhShared) <- testSMPClient @c $ \rh -> createAndSecureQueue rh sPub
|
||||
let dec nonce = C.cbDecrypt dhShared (C.cbNonce nonce)
|
||||
Resp "1" _ OK <- signSendRecv sh sKey ("1", sId, SEND "hello (should expire)")
|
||||
let dec = decryptMsgV3 dhShared
|
||||
Resp "1" _ OK <- signSendRecv sh sKey ("1", sId, _SEND "hello (should expire)")
|
||||
threadDelay 2500000
|
||||
Resp "2" _ OK <- signSendRecv sh sKey ("2", sId, SEND "hello (should NOT expire)")
|
||||
Resp "2" _ OK <- signSendRecv sh sKey ("2", sId, _SEND "hello (should NOT expire)")
|
||||
testSMPClient @c $ \rh -> do
|
||||
Resp "3" _ (MSG mId _ msg) <- signSendRecv rh rKey ("3", rId, SUB)
|
||||
Resp "3" _ (Msg mId msg) <- signSendRecv rh rKey ("3", rId, SUB)
|
||||
(dec mId msg, Right "hello (should NOT expire)") #== "delivered"
|
||||
1000 `timeout` tGet @BrokerMsg rh >>= \case
|
||||
Nothing -> return ()
|
||||
|
@ -504,10 +819,11 @@ testMsgExpireOnInterval :: forall c. Transport c => TProxy c -> Spec
|
|||
testMsgExpireOnInterval t =
|
||||
it "should expire messages that are not received before messageTTL after expiry interval" $ do
|
||||
(sPub, sKey) <- C.generateSignatureKeyPair C.SEd25519
|
||||
withSmpServerConfigOn (ATransport t) cfg {messageTTL = Just 1, expireMessagesInterval = Just 1000000} testPort $ \_ ->
|
||||
let cfg' = cfg {messageExpiration = Just ExpirationConfig {ttl = 1, checkInterval = 1}}
|
||||
withSmpServerConfigOn (ATransport t) cfg' testPort $ \_ ->
|
||||
testSMPClient @c $ \sh -> do
|
||||
(sId, rId, rKey, _) <- testSMPClient @c $ \rh -> createAndSecureQueue rh sPub
|
||||
Resp "1" _ OK <- signSendRecv sh sKey ("1", sId, SEND "hello (should expire)")
|
||||
Resp "1" _ OK <- signSendRecv sh sKey ("1", sId, _SEND "hello (should expire)")
|
||||
threadDelay 2500000
|
||||
testSMPClient @c $ \rh -> do
|
||||
Resp "2" _ OK <- signSendRecv rh rKey ("2", rId, SUB)
|
||||
|
@ -517,16 +833,17 @@ testMsgExpireOnInterval t =
|
|||
|
||||
testMsgNOTExpireOnInterval :: forall c. Transport c => TProxy c -> Spec
|
||||
testMsgNOTExpireOnInterval t =
|
||||
it "should NOT expire messages that are not received before messageTTL if expiry interval is not set" $ do
|
||||
it "should NOT expire messages that are not received before messageTTL if expiry interval is large" $ do
|
||||
(sPub, sKey) <- C.generateSignatureKeyPair C.SEd25519
|
||||
withSmpServerConfigOn (ATransport t) cfg {messageTTL = Just 1, expireMessagesInterval = Nothing} testPort $ \_ ->
|
||||
let cfg' = cfg {messageExpiration = Just ExpirationConfig {ttl = 1, checkInterval = 10000}}
|
||||
withSmpServerConfigOn (ATransport t) cfg' testPort $ \_ ->
|
||||
testSMPClient @c $ \sh -> do
|
||||
(sId, rId, rKey, dhShared) <- testSMPClient @c $ \rh -> createAndSecureQueue rh sPub
|
||||
let dec nonce = C.cbDecrypt dhShared (C.cbNonce nonce)
|
||||
Resp "1" _ OK <- signSendRecv sh sKey ("1", sId, SEND "hello (should NOT expire)")
|
||||
let dec = decryptMsgV3 dhShared
|
||||
Resp "1" _ OK <- signSendRecv sh sKey ("1", sId, _SEND "hello (should NOT expire)")
|
||||
threadDelay 2500000
|
||||
testSMPClient @c $ \rh -> do
|
||||
Resp "2" _ (MSG mId _ msg) <- signSendRecv rh rKey ("2", rId, SUB)
|
||||
Resp "2" _ (Msg mId msg) <- signSendRecv rh rKey ("2", rId, SUB)
|
||||
(dec mId msg, Right "hello (should NOT expire)") #== "delivered"
|
||||
1000 `timeout` tGet @BrokerMsg rh >>= \case
|
||||
Nothing -> return ()
|
||||
|
@ -556,13 +873,17 @@ syntaxTests (ATransport t) = do
|
|||
it "no signature" $ ("", "abcd", "12345678", (KEY_, ' ', samplePubKey)) >#> ("", "abcd", "12345678", ERR $ CMD NO_AUTH)
|
||||
it "no queue ID" $ (sampleSig, "bcda", "", (KEY_, ' ', samplePubKey)) >#> ("", "bcda", "", ERR $ CMD NO_AUTH)
|
||||
noParamsSyntaxTest "SUB" SUB_
|
||||
noParamsSyntaxTest "ACK" ACK_
|
||||
noParamsSyntaxTest "OFF" OFF_
|
||||
noParamsSyntaxTest "DEL" DEL_
|
||||
describe "SEND" $ do
|
||||
it "valid syntax" $ (sampleSig, "cdab", "12345678", (SEND_, ' ', "hello" :: ByteString)) >#> ("", "cdab", "12345678", ERR AUTH)
|
||||
it "valid syntax" $ (sampleSig, "cdab", "12345678", (SEND_, ' ', noMsgFlags, ' ', "hello" :: ByteString)) >#> ("", "cdab", "12345678", ERR AUTH)
|
||||
it "no parameters" $ (sampleSig, "abcd", "12345678", SEND_) >#> ("", "abcd", "12345678", ERR $ CMD SYNTAX)
|
||||
it "no queue ID" $ (sampleSig, "bcda", "", (SEND_, ' ', "hello" :: ByteString)) >#> ("", "bcda", "", ERR $ CMD NO_QUEUE)
|
||||
it "no queue ID" $ (sampleSig, "bcda", "", (SEND_, ' ', noMsgFlags, ' ', "hello" :: ByteString)) >#> ("", "bcda", "", ERR $ CMD NO_ENTITY)
|
||||
describe "ACK" $ do
|
||||
it "valid syntax" $ (sampleSig, "cdab", "12345678", (ACK_, ' ', "1234" :: ByteString)) >#> ("", "cdab", "12345678", ERR AUTH)
|
||||
it "no parameters" $ (sampleSig, "abcd", "12345678", ACK_) >#> ("", "abcd", "12345678", ERR $ CMD SYNTAX)
|
||||
it "no queue ID" $ (sampleSig, "bcda", "", (ACK_, ' ', "1234" :: ByteString)) >#> ("", "bcda", "", ERR $ CMD NO_AUTH)
|
||||
it "no signature" $ ("", "cdab", "12345678", (ACK_, ' ', "1234" :: ByteString)) >#> ("", "cdab", "12345678", ERR $ CMD NO_AUTH)
|
||||
describe "PING" $ do
|
||||
it "valid syntax" $ ("", "abcd", "", PING_) >#> ("", "abcd", "", PONG)
|
||||
describe "broker response not allowed" $ do
|
||||
|
|
|
@ -4,15 +4,19 @@ import AgentTests (agentTests)
|
|||
import CoreTests.EncodingTests
|
||||
import CoreTests.ProtocolErrorTests
|
||||
import CoreTests.VersionRangeTests
|
||||
import NtfServerTests (ntfServerTests)
|
||||
import ServerTests
|
||||
import Simplex.Messaging.Transport (TLS, Transport (..))
|
||||
import Simplex.Messaging.Transport.WebSockets (WS)
|
||||
import System.Directory (createDirectoryIfMissing, removeDirectoryRecursive)
|
||||
import System.Environment (setEnv)
|
||||
import Test.Hspec
|
||||
|
||||
main :: IO ()
|
||||
main = do
|
||||
createDirectoryIfMissing False "tests/tmp"
|
||||
setEnv "APNS_KEY_ID" "H82WD9K9AQ"
|
||||
setEnv "APNS_KEY_FILE" "./tests/fixtures/AuthKey_H82WD9K9AQ.p8"
|
||||
hspec $ do
|
||||
describe "Core tests" $ do
|
||||
describe "Encoding tests" encodingTests
|
||||
|
@ -20,5 +24,6 @@ main = do
|
|||
describe "Version range" versionRangeTests
|
||||
describe "SMP server via TLS" $ serverTests (transport @TLS)
|
||||
describe "SMP server via WebSockets" $ serverTests (transport @WS)
|
||||
describe "Notifications server" $ ntfServerTests (transport @TLS)
|
||||
describe "SMP client agent" $ agentTests (transport @TLS)
|
||||
removeDirectoryRecursive "tests/tmp"
|
||||
|
|
|
@ -0,0 +1,6 @@
|
|||
-----BEGIN PRIVATE KEY-----
|
||||
MIGTAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBHkwdwIBAQQgWuPap5jF6eioxuHM
|
||||
XWZWUK78LdcxkTnMXWg2GqyXuBugCgYIKoZIzj0DAQehRANCAAQn64CvAIbEEzvM
|
||||
KwYjlOxVD5SxlgP1ZcYvVM/+VHLFu0aCkG7ueICTi3qWyqoB5hjjuAqwtc3EK0q0
|
||||
yupyM7Yx
|
||||
-----END PRIVATE KEY-----
|
Reference in New Issue