test notification token with agent and notifications server (#353)
* test notification token with agent and notifications server * notification server test with APNS mock * set environment variables in the test * use base64url encoding in encrypted notification data
This commit is contained in:
parent
9d8a9c4fe4
commit
17888f89a9
|
@ -74,7 +74,9 @@ library
|
|||
Simplex.Messaging.TMap
|
||||
Simplex.Messaging.Transport
|
||||
Simplex.Messaging.Transport.Client
|
||||
Simplex.Messaging.Transport.Client.HTTP2
|
||||
Simplex.Messaging.Transport.HTTP2
|
||||
Simplex.Messaging.Transport.HTTP2.Client
|
||||
Simplex.Messaging.Transport.HTTP2.Server
|
||||
Simplex.Messaging.Transport.KeepAlive
|
||||
Simplex.Messaging.Transport.Server
|
||||
Simplex.Messaging.Transport.WebSockets
|
||||
|
@ -325,6 +327,7 @@ test-suite smp-server-test
|
|||
AgentTests.ConnectionRequestTests
|
||||
AgentTests.DoubleRatchetTests
|
||||
AgentTests.FunctionalAPITests
|
||||
AgentTests.NotificationTests
|
||||
AgentTests.SQLiteTests
|
||||
CoreTests.EncodingTests
|
||||
CoreTests.ProtocolErrorTests
|
||||
|
|
|
@ -8,6 +8,7 @@
|
|||
module Simplex.Messaging.Notifications.Server.Env where
|
||||
|
||||
import Control.Concurrent.Async (Async)
|
||||
import Control.Monad (void)
|
||||
import Control.Monad.IO.Unlift
|
||||
import Crypto.Random
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
|
@ -61,7 +62,7 @@ newNtfServerEnv config@NtfServerConfig {subQSize, pushQSize, smpAgentCfg, apnsCo
|
|||
subscriber <- atomically $ newNtfSubscriber subQSize smpAgentCfg
|
||||
pushServer <- atomically $ newNtfPushServer pushQSize apnsConfig
|
||||
-- TODO not creating APNS client on start to pass CI test, has to be replaced with mock APNS server
|
||||
-- void . liftIO $ newPushClient pushServer PPApns
|
||||
void . liftIO $ newPushClient pushServer PPApns
|
||||
tlsServerParams <- liftIO $ loadTLSServerParams caCertificateFile certificateFile privateKeyFile
|
||||
Fingerprint fp <- liftIO $ loadFingerprint caCertificateFile
|
||||
pure NtfEnv {config, subscriber, pushServer, store, idsDrg, tlsServerParams, serverIdentity = C.KeyHash fp}
|
||||
|
|
|
@ -24,7 +24,6 @@ import Data.Aeson (FromJSON, ToJSON, (.=))
|
|||
import qualified Data.Aeson as J
|
||||
import qualified Data.Aeson.Encoding as JE
|
||||
import Data.Bifunctor (first)
|
||||
import qualified Data.ByteString.Base64 as B64
|
||||
import qualified Data.ByteString.Base64.URL as U
|
||||
import Data.ByteString.Builder (lazyByteString)
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
|
@ -37,8 +36,8 @@ import qualified Data.Text as T
|
|||
import Data.Text.Encoding (decodeUtf8With)
|
||||
import Data.Time.Clock.System
|
||||
import qualified Data.X509 as X
|
||||
import GHC.Generics
|
||||
import Network.HTTP.Types (HeaderName, Status, hAuthorization, methodPost)
|
||||
import GHC.Generics (Generic)
|
||||
import Network.HTTP.Types (HeaderName, Status)
|
||||
import qualified Network.HTTP.Types as N
|
||||
import Network.HTTP2.Client (Request)
|
||||
import qualified Network.HTTP2.Client as H
|
||||
|
@ -48,7 +47,7 @@ import Simplex.Messaging.Encoding.String
|
|||
import Simplex.Messaging.Notifications.Protocol
|
||||
import Simplex.Messaging.Notifications.Server.Subscriptions (NtfTknData (..))
|
||||
import Simplex.Messaging.Protocol (NotifierId, SMPServer)
|
||||
import Simplex.Messaging.Transport.Client.HTTP2
|
||||
import Simplex.Messaging.Transport.HTTP2.Client
|
||||
import System.Environment (getEnv)
|
||||
import UnliftIO.STM
|
||||
|
||||
|
@ -177,9 +176,9 @@ data APNSPushClientConfig = APNSPushClientConfig
|
|||
paddedNtfLength :: Int,
|
||||
appName :: ByteString,
|
||||
appTeamId :: Text,
|
||||
apnHost :: HostName,
|
||||
apnPort :: ServiceName,
|
||||
https2cfg :: HTTP2SClientConfig
|
||||
apnsHost :: HostName,
|
||||
apnsPort :: ServiceName,
|
||||
http2cfg :: HTTP2ClientConfig
|
||||
}
|
||||
deriving (Show)
|
||||
|
||||
|
@ -193,13 +192,13 @@ defaultAPNSPushClientConfig =
|
|||
paddedNtfLength = 256,
|
||||
appName = "chat.simplex.app",
|
||||
appTeamId = "5NN7GUYB6T",
|
||||
apnHost = "api.sandbox.push.apple.com",
|
||||
apnPort = "443",
|
||||
https2cfg = defaultHTTP2SClientConfig
|
||||
apnsHost = "api.sandbox.push.apple.com",
|
||||
apnsPort = "443",
|
||||
http2cfg = defaultHTTP2ClientConfig
|
||||
}
|
||||
|
||||
data APNSPushClient = APNSPushClient
|
||||
{ https2Client :: TVar (Maybe HTTPS2Client),
|
||||
{ https2Client :: TVar (Maybe HTTP2Client),
|
||||
privateKey :: EC.PrivateKey,
|
||||
jwtHeader :: JWTHeader,
|
||||
jwtToken :: TVar (JWTToken, SignedJWTToken),
|
||||
|
@ -213,7 +212,6 @@ createAPNSPushClient apnsCfg@APNSPushClientConfig {authKeyFileEnv, authKeyAlg, a
|
|||
void $ connectHTTPS2 apnsCfg https2Client
|
||||
privateKey <- readECPrivateKey =<< getEnv authKeyFileEnv
|
||||
authKeyId <- T.pack <$> getEnv authKeyIdEnv
|
||||
putStrLn $ authKeyIdEnv <> "=" <> T.unpack authKeyId
|
||||
let jwtHeader = JWTHeader {alg = authKeyAlg, kid = authKeyId}
|
||||
jwtToken <- newTVarIO =<< mkApnsJWTToken appTeamId jwtHeader privateKey
|
||||
nonceDrg <- drgNew >>= newTVarIO
|
||||
|
@ -238,9 +236,9 @@ mkApnsJWTToken appTeamId jwtHeader privateKey = do
|
|||
signedJWT <- signedJWTToken privateKey jwt
|
||||
pure (jwt, signedJWT)
|
||||
|
||||
connectHTTPS2 :: APNSPushClientConfig -> TVar (Maybe HTTPS2Client) -> IO (Either HTTPS2ClientError HTTPS2Client)
|
||||
connectHTTPS2 APNSPushClientConfig {apnHost, apnPort, https2cfg} https2Client = do
|
||||
r <- getHTTPS2Client apnHost apnPort https2cfg disconnected
|
||||
connectHTTPS2 :: APNSPushClientConfig -> TVar (Maybe HTTP2Client) -> IO (Either HTTP2ClientError HTTP2Client)
|
||||
connectHTTPS2 APNSPushClientConfig {apnsHost, apnsPort, http2cfg} https2Client = do
|
||||
r <- getHTTP2Client apnsHost apnsPort http2cfg disconnected
|
||||
case r of
|
||||
Right client -> atomically . writeTVar https2Client $ Just client
|
||||
Left e -> putStrLn $ "Error connecting to APNS: " <> show e
|
||||
|
@ -248,13 +246,13 @@ connectHTTPS2 APNSPushClientConfig {apnHost, apnPort, https2cfg} https2Client =
|
|||
where
|
||||
disconnected = atomically $ writeTVar https2Client Nothing
|
||||
|
||||
getApnsHTTP2Client :: APNSPushClient -> IO (Either HTTPS2ClientError HTTPS2Client)
|
||||
getApnsHTTP2Client :: APNSPushClient -> IO (Either HTTP2ClientError HTTP2Client)
|
||||
getApnsHTTP2Client APNSPushClient {https2Client, apnsCfg} =
|
||||
readTVarIO https2Client >>= maybe (connectHTTPS2 apnsCfg https2Client) (pure . Right)
|
||||
|
||||
disconnectApnsHTTP2Client :: APNSPushClient -> IO ()
|
||||
disconnectApnsHTTP2Client APNSPushClient {https2Client} =
|
||||
readTVarIO https2Client >>= mapM_ closeHTTPS2Client >> atomically (writeTVar https2Client Nothing)
|
||||
readTVarIO https2Client >>= mapM_ closeHTTP2Client >> atomically (writeTVar https2Client Nothing)
|
||||
|
||||
apnsNotification :: NtfTknData -> C.CbNonce -> Int -> PushNotification -> Either C.CryptoError APNSNotification
|
||||
apnsNotification NtfTknData {tknDhSecret} nonce paddedLen = \case
|
||||
|
@ -268,7 +266,7 @@ apnsNotification NtfTknData {tknDhSecret} nonce paddedLen = \case
|
|||
PNCheckMessages -> Right $ apn APNSBackground {contentAvailable = 1} . Just $ J.object ["checkMessages" .= True]
|
||||
where
|
||||
encrypt :: ByteString -> (Text -> APNSNotification) -> Either C.CryptoError APNSNotification
|
||||
encrypt ntfData f = f . safeDecodeUtf8 . B64.encode <$> C.cbEncrypt tknDhSecret nonce ntfData paddedLen
|
||||
encrypt ntfData f = f . safeDecodeUtf8 . U.encode <$> C.cbEncrypt tknDhSecret nonce ntfData paddedLen
|
||||
apn aps notificationData = APNSNotification {aps, notificationData}
|
||||
apnMutableContent = APNSMutableContent {mutableContent = 1, alert = APNSAlertText "Encrypted message or some other app event", category = Nothing}
|
||||
apnAlert alert = APNSAlert {alert, badge = Nothing, sound = Nothing, category = Nothing}
|
||||
|
@ -277,13 +275,13 @@ apnsNotification NtfTknData {tknDhSecret} nonce paddedLen = \case
|
|||
apnsRequest :: APNSPushClient -> ByteString -> APNSNotification -> IO Request
|
||||
apnsRequest c tkn ntf@APNSNotification {aps} = do
|
||||
signedJWT <- getApnsJWTToken c
|
||||
pure $ H.requestBuilder methodPost path (headers signedJWT) (lazyByteString $ J.encode ntf)
|
||||
pure $ H.requestBuilder N.methodPost path (headers signedJWT) (lazyByteString $ J.encode ntf)
|
||||
where
|
||||
path = "/3/device/" <> tkn
|
||||
headers signedJWT =
|
||||
[ (hApnsTopic, appName $ apnsCfg (c :: APNSPushClient)),
|
||||
(hApnsPushType, pushType aps),
|
||||
(hAuthorization, "bearer " <> signedJWT)
|
||||
(N.hAuthorization, "bearer " <> signedJWT)
|
||||
]
|
||||
<> [(hApnsPriority, "5") | isBackground aps]
|
||||
isBackground = \case
|
||||
|
@ -294,7 +292,7 @@ apnsRequest c tkn ntf@APNSNotification {aps} = do
|
|||
_ -> "alert"
|
||||
|
||||
data PushProviderError
|
||||
= PPConnection HTTPS2ClientError
|
||||
= PPConnection HTTP2ClientError
|
||||
| PPCryptoError C.CryptoError
|
||||
| PPResponseError (Maybe Status) Text
|
||||
| PPTokenInvalid
|
||||
|
@ -305,7 +303,7 @@ data PushProviderError
|
|||
type PushProviderClient = NtfTknData -> PushNotification -> ExceptT PushProviderError IO ()
|
||||
|
||||
-- this is not a newtype on purpose to have a correct JSON encoding as a record
|
||||
data APNSErrorReponse = APNSErrorReponse {reason :: Text}
|
||||
data APNSErrorResponse = APNSErrorResponse {reason :: Text}
|
||||
deriving (Generic, FromJSON)
|
||||
|
||||
apnsPushProviderClient :: APNSPushClient -> PushProviderClient
|
||||
|
@ -316,7 +314,7 @@ apnsPushProviderClient c@APNSPushClient {nonceDrg, apnsCfg} tkn@NtfTknData {toke
|
|||
req <- liftIO $ apnsRequest c tknStr apnsNtf
|
||||
HTTP2Response {response, respBody} <- liftHTTPS2 $ sendRequest http2 req
|
||||
let status = H.responseStatus response
|
||||
reason' = maybe "?" reason $ J.decodeStrict' respBody
|
||||
reason' = maybe "" reason $ J.decodeStrict' respBody
|
||||
logDebug $ "APNS response: " <> T.pack (show status) <> " " <> reason'
|
||||
result status reason'
|
||||
where
|
||||
|
|
|
@ -1,159 +0,0 @@
|
|||
{-# LANGUAGE DuplicateRecordFields #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
|
||||
module Simplex.Messaging.Transport.Client.HTTP2 where
|
||||
|
||||
import Control.Concurrent.Async
|
||||
import Control.Exception (IOException, catch, finally)
|
||||
import qualified Control.Exception as E
|
||||
import Control.Logger.Simple (logDebug)
|
||||
import Control.Monad.Except
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import Data.Default (def)
|
||||
import Data.Maybe (isNothing)
|
||||
import qualified Data.Text as T
|
||||
import qualified Data.X509.CertificateStore as XS
|
||||
import Foreign (mallocBytes)
|
||||
import Network.HPACK (BufferSize, HeaderTable)
|
||||
import Network.HTTP2.Client (ClientConfig (..), Config (..), Request, Response)
|
||||
import qualified Network.HTTP2.Client as H
|
||||
import Network.Socket (HostName, ServiceName)
|
||||
import qualified Network.TLS as T
|
||||
import qualified Network.TLS.Extra as TE
|
||||
import Numeric.Natural (Natural)
|
||||
import Simplex.Messaging.Transport (TLS, Transport (cGet, cPut))
|
||||
import Simplex.Messaging.Transport.Client (runTLSTransportClient)
|
||||
import Simplex.Messaging.Transport.KeepAlive (KeepAliveOpts)
|
||||
import qualified System.TimeManager as TI
|
||||
import UnliftIO.STM
|
||||
import UnliftIO.Timeout
|
||||
|
||||
data HTTPS2Client = HTTPS2Client
|
||||
{ action :: Async (),
|
||||
connected :: TVar Bool,
|
||||
host :: HostName,
|
||||
port :: ServiceName,
|
||||
config :: HTTP2SClientConfig,
|
||||
reqQ :: TBQueue (Request, TMVar HTTP2Response)
|
||||
}
|
||||
|
||||
data HTTP2Response = HTTP2Response
|
||||
{ response :: Response,
|
||||
respBody :: ByteString,
|
||||
respTrailers :: Maybe HeaderTable
|
||||
}
|
||||
|
||||
data HTTP2SClientConfig = HTTP2SClientConfig
|
||||
{ qSize :: Natural,
|
||||
connTimeout :: Int,
|
||||
tcpKeepAlive :: Maybe KeepAliveOpts,
|
||||
caStoreFile :: FilePath,
|
||||
suportedTLSParams :: T.Supported
|
||||
}
|
||||
deriving (Show)
|
||||
|
||||
defaultHTTP2SClientConfig :: HTTP2SClientConfig
|
||||
defaultHTTP2SClientConfig =
|
||||
HTTP2SClientConfig
|
||||
{ qSize = 64,
|
||||
connTimeout = 10000000,
|
||||
tcpKeepAlive = Nothing,
|
||||
caStoreFile = "/etc/ssl/cert.pem",
|
||||
suportedTLSParams =
|
||||
def
|
||||
{ T.supportedVersions = [T.TLS13, T.TLS12],
|
||||
T.supportedCiphers = TE.ciphersuite_strong_det,
|
||||
T.supportedSecureRenegotiation = False
|
||||
}
|
||||
}
|
||||
|
||||
data HTTPS2ClientError = HCResponseTimeout | HCNetworkError | HCIOError IOException
|
||||
deriving (Show)
|
||||
|
||||
getHTTPS2Client :: HostName -> ServiceName -> HTTP2SClientConfig -> IO () -> IO (Either HTTPS2ClientError HTTPS2Client)
|
||||
getHTTPS2Client host port config@HTTP2SClientConfig {tcpKeepAlive, connTimeout, caStoreFile, suportedTLSParams} disconnected =
|
||||
(atomically mkHTTPS2Client >>= runClient)
|
||||
`catch` \(e :: IOException) -> pure . Left $ HCIOError e
|
||||
where
|
||||
mkHTTPS2Client :: STM HTTPS2Client
|
||||
mkHTTPS2Client = do
|
||||
connected <- newTVar False
|
||||
reqQ <- newTBQueue $ qSize config
|
||||
pure HTTPS2Client {action = undefined, connected, host, port, config, reqQ}
|
||||
|
||||
runClient :: HTTPS2Client -> IO (Either HTTPS2ClientError HTTPS2Client)
|
||||
runClient c = do
|
||||
cVar <- newEmptyTMVarIO
|
||||
caStore <- XS.readCertificateStore caStoreFile
|
||||
when (isNothing caStore) . putStrLn $ "Error loading CertificateStore from " <> caStoreFile
|
||||
action <-
|
||||
async $
|
||||
runHTTPS2Client suportedTLSParams caStore host port tcpKeepAlive (client c cVar)
|
||||
`finally` atomically (putTMVar cVar $ Left HCNetworkError)
|
||||
conn_ <- connTimeout `timeout` atomically (takeTMVar cVar)
|
||||
pure $ case conn_ of
|
||||
Just (Right ()) -> Right c {action}
|
||||
Just (Left e) -> Left e
|
||||
Nothing -> Left HCNetworkError
|
||||
|
||||
client :: HTTPS2Client -> TMVar (Either HTTPS2ClientError ()) -> (Request -> (Response -> IO ()) -> IO ()) -> IO ()
|
||||
client c cVar sendReq = do
|
||||
atomically $ do
|
||||
writeTVar (connected c) True
|
||||
putTMVar cVar $ Right ()
|
||||
process c sendReq `finally` disconnected
|
||||
|
||||
process :: HTTPS2Client -> (Request -> (Response -> IO ()) -> IO ()) -> IO ()
|
||||
process HTTPS2Client {reqQ} sendReq = forever $ do
|
||||
(req, respVar) <- atomically $ readTBQueue reqQ
|
||||
sendReq req $ \r -> do
|
||||
let writeResp respBody respTrailers = atomically $ putTMVar respVar HTTP2Response {response = r, respBody, respTrailers}
|
||||
respBody <- getResponseBody r ""
|
||||
respTrailers <- H.getResponseTrailers r
|
||||
writeResp respBody respTrailers
|
||||
|
||||
getResponseBody :: Response -> ByteString -> IO ByteString
|
||||
getResponseBody r s =
|
||||
H.getResponseBodyChunk r >>= \chunk ->
|
||||
if B.null chunk then pure s else getResponseBody r $ s <> chunk
|
||||
|
||||
-- | Disconnects client from the server and terminates client threads.
|
||||
closeHTTPS2Client :: HTTPS2Client -> IO ()
|
||||
-- TODO disconnect
|
||||
closeHTTPS2Client = uninterruptibleCancel . action
|
||||
|
||||
sendRequest :: HTTPS2Client -> Request -> IO (Either HTTPS2ClientError HTTP2Response)
|
||||
sendRequest HTTPS2Client {reqQ, config} req = do
|
||||
resp <- newEmptyTMVarIO
|
||||
atomically $ writeTBQueue reqQ (req, resp)
|
||||
maybe (Left HCResponseTimeout) Right <$> (connTimeout config `timeout` atomically (takeTMVar resp))
|
||||
|
||||
runHTTPS2Client :: T.Supported -> Maybe XS.CertificateStore -> HostName -> ServiceName -> Maybe KeepAliveOpts -> ((Request -> (Response -> IO ()) -> IO ()) -> IO ()) -> IO ()
|
||||
runHTTPS2Client tlsParams caStore host port keepAliveOpts client =
|
||||
runTLSTransportClient tlsParams caStore host port Nothing keepAliveOpts https2Client
|
||||
where
|
||||
cfg = ClientConfig "https" (B.pack host) 20
|
||||
https2Client :: TLS -> IO ()
|
||||
https2Client c =
|
||||
E.bracket
|
||||
(allocTlsConfig c 16384)
|
||||
H.freeSimpleConfig
|
||||
(\conf -> H.run cfg conf client)
|
||||
|
||||
allocTlsConfig :: TLS -> BufferSize -> IO Config
|
||||
allocTlsConfig c sz = do
|
||||
buf <- mallocBytes sz
|
||||
tm <- TI.initialize $ 30 * 1000000
|
||||
pure
|
||||
Config
|
||||
{ confWriteBuffer = buf,
|
||||
confBufferSize = sz,
|
||||
confSendAll = cPut c,
|
||||
confReadN = cGet c,
|
||||
confPositionReadMaker = H.defaultPositionReadMaker,
|
||||
confTimeoutManager = tm
|
||||
}
|
|
@ -0,0 +1,36 @@
|
|||
module Simplex.Messaging.Transport.HTTP2 where
|
||||
|
||||
import qualified Control.Exception as E
|
||||
import Data.Default (def)
|
||||
import Foreign (mallocBytes)
|
||||
import Network.HPACK (BufferSize)
|
||||
import Network.HTTP2.Client (Config (..), defaultPositionReadMaker, freeSimpleConfig)
|
||||
import qualified Network.TLS as T
|
||||
import qualified Network.TLS.Extra as TE
|
||||
import Simplex.Messaging.Transport (TLS, Transport (cGet, cPut))
|
||||
import qualified System.TimeManager as TI
|
||||
|
||||
withTlsConfig :: TLS -> BufferSize -> (Config -> IO ()) -> IO ()
|
||||
withTlsConfig c sz = E.bracket (allocTlsConfig c sz) freeSimpleConfig
|
||||
|
||||
allocTlsConfig :: TLS -> BufferSize -> IO Config
|
||||
allocTlsConfig c sz = do
|
||||
buf <- mallocBytes sz
|
||||
tm <- TI.initialize $ 30 * 1000000
|
||||
pure
|
||||
Config
|
||||
{ confWriteBuffer = buf,
|
||||
confBufferSize = sz,
|
||||
confSendAll = cPut c,
|
||||
confReadN = cGet c,
|
||||
confPositionReadMaker = defaultPositionReadMaker,
|
||||
confTimeoutManager = tm
|
||||
}
|
||||
|
||||
http2TLSParams :: T.Supported
|
||||
http2TLSParams =
|
||||
def
|
||||
{ T.supportedVersions = [T.TLS13, T.TLS12],
|
||||
T.supportedCiphers = TE.ciphersuite_strong_det,
|
||||
T.supportedSecureRenegotiation = False
|
||||
}
|
|
@ -0,0 +1,128 @@
|
|||
{-# LANGUAGE DuplicateRecordFields #-}
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
|
||||
module Simplex.Messaging.Transport.HTTP2.Client where
|
||||
|
||||
import Control.Concurrent.Async
|
||||
import Control.Exception (IOException)
|
||||
import qualified Control.Exception as E
|
||||
import Control.Monad.Except
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import Data.Maybe (isNothing)
|
||||
import qualified Data.X509.CertificateStore as XS
|
||||
import Network.HPACK (HeaderTable)
|
||||
import Network.HTTP2.Client (ClientConfig (..), Request, Response)
|
||||
import qualified Network.HTTP2.Client as H
|
||||
import Network.Socket (HostName, ServiceName)
|
||||
import qualified Network.TLS as T
|
||||
import Numeric.Natural (Natural)
|
||||
import Simplex.Messaging.Transport.Client (runTLSTransportClient)
|
||||
import Simplex.Messaging.Transport.HTTP2 (http2TLSParams, withTlsConfig)
|
||||
import Simplex.Messaging.Transport.KeepAlive (KeepAliveOpts)
|
||||
import UnliftIO.STM
|
||||
import UnliftIO.Timeout
|
||||
|
||||
data HTTP2Client = HTTP2Client
|
||||
{ action :: Async (),
|
||||
connected :: TVar Bool,
|
||||
host :: HostName,
|
||||
port :: ServiceName,
|
||||
config :: HTTP2ClientConfig,
|
||||
reqQ :: TBQueue (Request, TMVar HTTP2Response)
|
||||
}
|
||||
|
||||
data HTTP2Response = HTTP2Response
|
||||
{ response :: Response,
|
||||
respBody :: ByteString,
|
||||
respTrailers :: Maybe HeaderTable
|
||||
}
|
||||
|
||||
data HTTP2ClientConfig = HTTP2ClientConfig
|
||||
{ qSize :: Natural,
|
||||
connTimeout :: Int,
|
||||
tcpKeepAlive :: Maybe KeepAliveOpts,
|
||||
caStoreFile :: FilePath,
|
||||
suportedTLSParams :: T.Supported
|
||||
}
|
||||
deriving (Show)
|
||||
|
||||
defaultHTTP2ClientConfig :: HTTP2ClientConfig
|
||||
defaultHTTP2ClientConfig =
|
||||
HTTP2ClientConfig
|
||||
{ qSize = 64,
|
||||
connTimeout = 10000000,
|
||||
tcpKeepAlive = Nothing,
|
||||
caStoreFile = "/etc/ssl/cert.pem",
|
||||
suportedTLSParams = http2TLSParams
|
||||
}
|
||||
|
||||
data HTTP2ClientError = HCResponseTimeout | HCNetworkError | HCNetworkError1 | HCIOError IOException
|
||||
deriving (Show)
|
||||
|
||||
getHTTP2Client :: HostName -> ServiceName -> HTTP2ClientConfig -> IO () -> IO (Either HTTP2ClientError HTTP2Client)
|
||||
getHTTP2Client host port config@HTTP2ClientConfig {tcpKeepAlive, connTimeout, caStoreFile, suportedTLSParams} disconnected =
|
||||
(atomically mkHTTPS2Client >>= runClient)
|
||||
`E.catch` \(e :: IOException) -> pure . Left $ HCIOError e
|
||||
where
|
||||
mkHTTPS2Client :: STM HTTP2Client
|
||||
mkHTTPS2Client = do
|
||||
connected <- newTVar False
|
||||
reqQ <- newTBQueue $ qSize config
|
||||
pure HTTP2Client {action = undefined, connected, host, port, config, reqQ}
|
||||
|
||||
runClient :: HTTP2Client -> IO (Either HTTP2ClientError HTTP2Client)
|
||||
runClient c = do
|
||||
cVar <- newEmptyTMVarIO
|
||||
caStore <- XS.readCertificateStore caStoreFile
|
||||
when (isNothing caStore) . putStrLn $ "Error loading CertificateStore from " <> caStoreFile
|
||||
action <-
|
||||
async $
|
||||
runHTTP2Client suportedTLSParams caStore host port tcpKeepAlive (client c cVar)
|
||||
`E.finally` atomically (putTMVar cVar $ Left HCNetworkError)
|
||||
conn_ <- connTimeout `timeout` atomically (takeTMVar cVar)
|
||||
pure $ case conn_ of
|
||||
Just (Right ()) -> Right c {action}
|
||||
Just (Left e) -> Left e
|
||||
Nothing -> Left HCNetworkError1
|
||||
|
||||
client :: HTTP2Client -> TMVar (Either HTTP2ClientError ()) -> (Request -> (Response -> IO ()) -> IO ()) -> IO ()
|
||||
client c cVar sendReq = do
|
||||
atomically $ do
|
||||
writeTVar (connected c) True
|
||||
putTMVar cVar $ Right ()
|
||||
process c sendReq `E.finally` disconnected
|
||||
|
||||
process :: HTTP2Client -> (Request -> (Response -> IO ()) -> IO ()) -> IO ()
|
||||
process HTTP2Client {reqQ} sendReq = forever $ do
|
||||
(req, respVar) <- atomically $ readTBQueue reqQ
|
||||
sendReq req $ \r -> do
|
||||
let writeResp respBody respTrailers = atomically $ putTMVar respVar HTTP2Response {response = r, respBody, respTrailers}
|
||||
respBody <- getResponseBody r ""
|
||||
respTrailers <- H.getResponseTrailers r
|
||||
writeResp respBody respTrailers
|
||||
|
||||
getResponseBody :: Response -> ByteString -> IO ByteString
|
||||
getResponseBody r s =
|
||||
H.getResponseBodyChunk r >>= \chunk ->
|
||||
if B.null chunk then pure s else getResponseBody r $ s <> chunk
|
||||
|
||||
-- | Disconnects client from the server and terminates client threads.
|
||||
closeHTTP2Client :: HTTP2Client -> IO ()
|
||||
-- TODO disconnect
|
||||
closeHTTP2Client = uninterruptibleCancel . action
|
||||
|
||||
sendRequest :: HTTP2Client -> Request -> IO (Either HTTP2ClientError HTTP2Response)
|
||||
sendRequest HTTP2Client {reqQ, config} req = do
|
||||
resp <- newEmptyTMVarIO
|
||||
atomically $ writeTBQueue reqQ (req, resp)
|
||||
maybe (Left HCResponseTimeout) Right <$> (connTimeout config `timeout` atomically (takeTMVar resp))
|
||||
|
||||
runHTTP2Client :: T.Supported -> Maybe XS.CertificateStore -> HostName -> ServiceName -> Maybe KeepAliveOpts -> ((Request -> (Response -> IO ()) -> IO ()) -> IO ()) -> IO ()
|
||||
runHTTP2Client tlsParams caStore host port keepAliveOpts client =
|
||||
runTLSTransportClient tlsParams caStore host port Nothing keepAliveOpts $ \c ->
|
||||
withTlsConfig c 16384 (`run` client)
|
||||
where
|
||||
run = H.run $ ClientConfig "https" (B.pack host) 20
|
|
@ -0,0 +1,70 @@
|
|||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
|
||||
module Simplex.Messaging.Transport.HTTP2.Server where
|
||||
|
||||
import Control.Concurrent.Async (Async, async, uninterruptibleCancel)
|
||||
import Control.Concurrent.STM
|
||||
import Control.Monad
|
||||
import Data.ByteString (ByteString)
|
||||
import qualified Data.ByteString.Char8 as B
|
||||
import Network.HPACK (HeaderTable)
|
||||
import Network.HTTP2.Server (Aux, PushPromise, Request, Response)
|
||||
import qualified Network.HTTP2.Server as H
|
||||
import Network.Socket
|
||||
import qualified Network.TLS as T
|
||||
import Numeric.Natural (Natural)
|
||||
import Simplex.Messaging.Transport.HTTP2 (withTlsConfig)
|
||||
import Simplex.Messaging.Transport.Server (loadSupportedTLSServerParams, runTransportServer)
|
||||
|
||||
type HTTP2ServerFunc = (Request -> (Response -> IO ()) -> IO ())
|
||||
|
||||
data HTTP2ServerConfig = HTTP2ServerConfig
|
||||
{ qSize :: Natural,
|
||||
http2Port :: ServiceName,
|
||||
serverSupported :: T.Supported,
|
||||
caCertificateFile :: FilePath,
|
||||
privateKeyFile :: FilePath,
|
||||
certificateFile :: FilePath
|
||||
}
|
||||
deriving (Show)
|
||||
|
||||
data HTTP2Request = HTTP2Request
|
||||
{ request :: Request,
|
||||
reqBody :: ByteString,
|
||||
reqTrailers :: Maybe HeaderTable,
|
||||
sendResponse :: Response -> IO ()
|
||||
}
|
||||
|
||||
data HTTP2Server = HTTP2Server
|
||||
{ action :: Async (),
|
||||
reqQ :: TBQueue HTTP2Request
|
||||
}
|
||||
|
||||
getHTTP2Server :: HTTP2ServerConfig -> IO HTTP2Server
|
||||
getHTTP2Server HTTP2ServerConfig {qSize, http2Port, serverSupported, caCertificateFile, certificateFile, privateKeyFile} = do
|
||||
tlsServerParams <- loadSupportedTLSServerParams serverSupported caCertificateFile certificateFile privateKeyFile
|
||||
started <- newEmptyTMVarIO
|
||||
reqQ <- newTBQueueIO qSize
|
||||
action <- async $
|
||||
runHTTP2Server started http2Port tlsServerParams $ \r sendResponse -> do
|
||||
reqBody <- getRequestBody r ""
|
||||
reqTrailers <- H.getRequestTrailers r
|
||||
atomically $ writeTBQueue reqQ HTTP2Request {request = r, reqBody, reqTrailers, sendResponse}
|
||||
void . atomically $ takeTMVar started
|
||||
pure HTTP2Server {action, reqQ}
|
||||
where
|
||||
getRequestBody :: Request -> ByteString -> IO ByteString
|
||||
getRequestBody r s =
|
||||
H.getRequestBodyChunk r >>= \chunk ->
|
||||
if B.null chunk then pure s else getRequestBody r $ s <> chunk
|
||||
|
||||
closeHTTP2Server :: HTTP2Server -> IO ()
|
||||
closeHTTP2Server = uninterruptibleCancel . action
|
||||
|
||||
runHTTP2Server :: TMVar Bool -> ServiceName -> T.ServerParams -> HTTP2ServerFunc -> IO ()
|
||||
runHTTP2Server started port serverParams http2Server =
|
||||
runTransportServer started port serverParams $ \c -> withTlsConfig c 16384 (`H.run` server)
|
||||
where
|
||||
server :: Request -> Aux -> (Response -> [PushPromise] -> IO ()) -> IO ()
|
||||
server req _aux sendResp = http2Server req (`sendResp` [])
|
|
@ -4,6 +4,7 @@
|
|||
|
||||
module Simplex.Messaging.Transport.Server
|
||||
( runTransportServer,
|
||||
loadSupportedTLSServerParams,
|
||||
loadTLSServerParams,
|
||||
loadFingerprint,
|
||||
smpServerHandshake,
|
||||
|
@ -71,7 +72,10 @@ startTCPServer started port = withSocketsDo $ resolve >>= open >>= setStarted
|
|||
setStarted sock = atomically (tryPutTMVar started True) >> pure sock
|
||||
|
||||
loadTLSServerParams :: FilePath -> FilePath -> FilePath -> IO T.ServerParams
|
||||
loadTLSServerParams caCertificateFile certificateFile privateKeyFile =
|
||||
loadTLSServerParams = loadSupportedTLSServerParams supportedParameters
|
||||
|
||||
loadSupportedTLSServerParams :: T.Supported -> FilePath -> FilePath -> FilePath -> IO T.ServerParams
|
||||
loadSupportedTLSServerParams serverSupported caCertificateFile certificateFile privateKeyFile =
|
||||
fromCredential <$> loadServerCredential
|
||||
where
|
||||
loadServerCredential :: IO T.Credential
|
||||
|
@ -85,7 +89,7 @@ loadTLSServerParams caCertificateFile certificateFile privateKeyFile =
|
|||
{ T.serverWantClientCert = False,
|
||||
T.serverShared = def {T.sharedCredentials = T.Credentials [credential]},
|
||||
T.serverHooks = def,
|
||||
T.serverSupported = supportedParameters
|
||||
T.serverSupported = serverSupported
|
||||
}
|
||||
|
||||
loadFingerprint :: FilePath -> IO Fingerprint
|
||||
|
|
|
@ -9,13 +9,11 @@ module AgentTests.FunctionalAPITests (functionalAPITests) where
|
|||
|
||||
import Control.Monad.Except (ExceptT, runExceptT)
|
||||
import Control.Monad.IO.Unlift
|
||||
import NtfClient (withNtfServer)
|
||||
import SMPAgentClient
|
||||
import SMPClient (testPort, withSmpServer, withSmpServerStoreLogOn)
|
||||
import Simplex.Messaging.Agent
|
||||
import Simplex.Messaging.Agent.Env.SQLite (AgentConfig (..))
|
||||
import Simplex.Messaging.Agent.Protocol
|
||||
import Simplex.Messaging.Notifications.Protocol (DeviceToken (..), PushProvider (..))
|
||||
import Simplex.Messaging.Protocol (ErrorType (..), MsgBody)
|
||||
import Simplex.Messaging.Transport (ATransport (..))
|
||||
import System.Timeout
|
||||
|
@ -50,14 +48,11 @@ functionalAPITests t = do
|
|||
testAsyncServerOffline t
|
||||
it "should notify after HELLO timeout" $
|
||||
withSmpServer t testAsyncHelloTimeout
|
||||
describe "Notification server" $ do
|
||||
it "should register device token" $
|
||||
withNtfServer t testNotificationToken
|
||||
|
||||
testAgentClient :: IO ()
|
||||
testAgentClient = do
|
||||
alice <- getSMPAgentClient cfg initAgentServers
|
||||
bob <- getSMPAgentClient cfg {dbFile = testDB2} initAgentServers
|
||||
alice <- getSMPAgentClient agentCfg initAgentServers
|
||||
bob <- getSMPAgentClient agentCfg {dbFile = testDB2} initAgentServers
|
||||
Right () <- runExceptT $ do
|
||||
(bobId, qInfo) <- createConnection alice SCMInvitation
|
||||
aliceId <- joinConnection bob qInfo "bob's connInfo"
|
||||
|
@ -100,13 +95,13 @@ testAgentClient = do
|
|||
|
||||
testAsyncInitiatingOffline :: IO ()
|
||||
testAsyncInitiatingOffline = do
|
||||
alice <- getSMPAgentClient cfg initAgentServers
|
||||
bob <- getSMPAgentClient cfg {dbFile = testDB2} initAgentServers
|
||||
alice <- getSMPAgentClient agentCfg initAgentServers
|
||||
bob <- getSMPAgentClient agentCfg {dbFile = testDB2} initAgentServers
|
||||
Right () <- runExceptT $ do
|
||||
(bobId, cReq) <- createConnection alice SCMInvitation
|
||||
disconnectAgentClient alice
|
||||
aliceId <- joinConnection bob cReq "bob's connInfo"
|
||||
alice' <- liftIO $ getSMPAgentClient cfg initAgentServers
|
||||
alice' <- liftIO $ getSMPAgentClient agentCfg initAgentServers
|
||||
subscribeConnection alice' bobId
|
||||
("", _, CONF confId "bob's connInfo") <- get alice'
|
||||
allowConnection alice' bobId confId "alice's connInfo"
|
||||
|
@ -118,15 +113,15 @@ testAsyncInitiatingOffline = do
|
|||
|
||||
testAsyncJoiningOfflineBeforeActivation :: IO ()
|
||||
testAsyncJoiningOfflineBeforeActivation = do
|
||||
alice <- getSMPAgentClient cfg initAgentServers
|
||||
bob <- getSMPAgentClient cfg {dbFile = testDB2} initAgentServers
|
||||
alice <- getSMPAgentClient agentCfg initAgentServers
|
||||
bob <- getSMPAgentClient agentCfg {dbFile = testDB2} initAgentServers
|
||||
Right () <- runExceptT $ do
|
||||
(bobId, qInfo) <- createConnection alice SCMInvitation
|
||||
aliceId <- joinConnection bob qInfo "bob's connInfo"
|
||||
disconnectAgentClient bob
|
||||
("", _, CONF confId "bob's connInfo") <- get alice
|
||||
allowConnection alice bobId confId "alice's connInfo"
|
||||
bob' <- liftIO $ getSMPAgentClient cfg {dbFile = testDB2} initAgentServers
|
||||
bob' <- liftIO $ getSMPAgentClient agentCfg {dbFile = testDB2} initAgentServers
|
||||
subscribeConnection bob' aliceId
|
||||
get alice ##> ("", bobId, CON)
|
||||
get bob' ##> ("", aliceId, INFO "alice's connInfo")
|
||||
|
@ -136,18 +131,18 @@ testAsyncJoiningOfflineBeforeActivation = do
|
|||
|
||||
testAsyncBothOffline :: IO ()
|
||||
testAsyncBothOffline = do
|
||||
alice <- getSMPAgentClient cfg initAgentServers
|
||||
bob <- getSMPAgentClient cfg {dbFile = testDB2} initAgentServers
|
||||
alice <- getSMPAgentClient agentCfg initAgentServers
|
||||
bob <- getSMPAgentClient agentCfg {dbFile = testDB2} initAgentServers
|
||||
Right () <- runExceptT $ do
|
||||
(bobId, cReq) <- createConnection alice SCMInvitation
|
||||
disconnectAgentClient alice
|
||||
aliceId <- joinConnection bob cReq "bob's connInfo"
|
||||
disconnectAgentClient bob
|
||||
alice' <- liftIO $ getSMPAgentClient cfg initAgentServers
|
||||
alice' <- liftIO $ getSMPAgentClient agentCfg initAgentServers
|
||||
subscribeConnection alice' bobId
|
||||
("", _, CONF confId "bob's connInfo") <- get alice'
|
||||
allowConnection alice' bobId confId "alice's connInfo"
|
||||
bob' <- liftIO $ getSMPAgentClient cfg {dbFile = testDB2} initAgentServers
|
||||
bob' <- liftIO $ getSMPAgentClient agentCfg {dbFile = testDB2} initAgentServers
|
||||
subscribeConnection bob' aliceId
|
||||
get alice' ##> ("", bobId, CON)
|
||||
get bob' ##> ("", aliceId, INFO "alice's connInfo")
|
||||
|
@ -157,8 +152,8 @@ testAsyncBothOffline = do
|
|||
|
||||
testAsyncServerOffline :: ATransport -> IO ()
|
||||
testAsyncServerOffline t = do
|
||||
alice <- getSMPAgentClient cfg initAgentServers
|
||||
bob <- getSMPAgentClient cfg {dbFile = testDB2} initAgentServers
|
||||
alice <- getSMPAgentClient agentCfg initAgentServers
|
||||
bob <- getSMPAgentClient agentCfg {dbFile = testDB2} initAgentServers
|
||||
-- create connection and shutdown the server
|
||||
Right (bobId, cReq) <- withSmpServerStoreLogOn t testPort $ \_ ->
|
||||
runExceptT $ createConnection alice SCMInvitation
|
||||
|
@ -181,8 +176,8 @@ testAsyncServerOffline t = do
|
|||
|
||||
testAsyncHelloTimeout :: IO ()
|
||||
testAsyncHelloTimeout = do
|
||||
alice <- getSMPAgentClient cfg initAgentServers
|
||||
bob <- getSMPAgentClient cfg {dbFile = testDB2, helloTimeout = 1} initAgentServers
|
||||
alice <- getSMPAgentClient agentCfg initAgentServers
|
||||
bob <- getSMPAgentClient agentCfg {dbFile = testDB2, helloTimeout = 1} initAgentServers
|
||||
Right () <- runExceptT $ do
|
||||
(_, cReq) <- createConnection alice SCMInvitation
|
||||
disconnectAgentClient alice
|
||||
|
@ -190,13 +185,6 @@ testAsyncHelloTimeout = do
|
|||
get bob ##> ("", aliceId, ERR $ CONN NOT_ACCEPTED)
|
||||
pure ()
|
||||
|
||||
testNotificationToken :: IO ()
|
||||
testNotificationToken = do
|
||||
alice <- getSMPAgentClient cfg initAgentServers
|
||||
Right () <- runExceptT $ do
|
||||
registerNtfToken alice $ DeviceToken PPApns "abcd"
|
||||
pure ()
|
||||
|
||||
exchangeGreetings :: AgentClient -> ConnId -> AgentClient -> ConnId -> ExceptT AgentErrorType IO ()
|
||||
exchangeGreetings alice bobId bob aliceId = do
|
||||
5 <- sendMessage alice bobId "hello"
|
||||
|
|
|
@ -0,0 +1,49 @@
|
|||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
|
||||
module AgentTests.NotificationTests where
|
||||
|
||||
import Control.Monad.Except
|
||||
import qualified Data.Aeson as J
|
||||
import qualified Data.Aeson.Types as JT
|
||||
import Data.Bifunctor (bimap)
|
||||
import qualified Data.ByteString.Base64.URL as U
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import Data.Text.Encoding (encodeUtf8)
|
||||
import NtfClient
|
||||
import SMPAgentClient (agentCfg, initAgentServers)
|
||||
import Simplex.Messaging.Agent
|
||||
import Simplex.Messaging.Agent.Protocol
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Notifications.Protocol
|
||||
import Simplex.Messaging.Notifications.Server.Push.APNS
|
||||
import Simplex.Messaging.Transport (ATransport)
|
||||
import Test.Hspec
|
||||
import UnliftIO.STM
|
||||
|
||||
notificationTests :: ATransport -> Spec
|
||||
notificationTests t = do
|
||||
describe "Managing notification tokens" $
|
||||
it "should register and verify notification token" $
|
||||
withAPNSMockServer $ \apns -> withNtfServer t $ testNotificationToken apns
|
||||
|
||||
testNotificationToken :: APNSMockServer -> IO ()
|
||||
testNotificationToken APNSMockServer {apnsQ} = do
|
||||
a <- getSMPAgentClient agentCfg initAgentServers
|
||||
Right () <- runExceptT $ do
|
||||
let tkn = DeviceToken PPApns "abcd"
|
||||
registerNtfToken a tkn
|
||||
APNSMockRequest {notification = APNSNotification {aps = APNSBackground _, notificationData = Just ntfData}, sendApnsResponse} <-
|
||||
atomically $ readTBQueue apnsQ
|
||||
verification <- ntfData .-> "verification"
|
||||
nonce <- C.cbNonce <$> ntfData .-> "nonce"
|
||||
liftIO $ sendApnsResponse APNSRespOk
|
||||
verifyNtfToken a tkn verification nonce
|
||||
enableNtfCron a tkn 30
|
||||
pure ()
|
||||
pure ()
|
||||
where
|
||||
(.->) :: J.Value -> J.Key -> ExceptT AgentErrorType IO ByteString
|
||||
v .-> key = do
|
||||
J.Object o <- pure v
|
||||
liftEither . bimap INTERNAL (U.decodeLenient . encodeUtf8) $ JT.parseEither (J..: key) o
|
|
@ -1,53 +1,74 @@
|
|||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE DeriveAnyClass #-}
|
||||
{-# LANGUAGE DeriveGeneric #-}
|
||||
{-# LANGUAGE DuplicateRecordFields #-}
|
||||
{-# LANGUAGE FlexibleContexts #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
{-# LANGUAGE NumericUnderscores #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-# LANGUAGE StandaloneDeriving #-}
|
||||
{-# LANGUAGE TypeApplications #-}
|
||||
{-# OPTIONS_GHC -fno-warn-orphans #-}
|
||||
|
||||
module NtfClient where
|
||||
|
||||
import Control.Monad
|
||||
import Control.Monad.Except (runExceptT)
|
||||
import Control.Monad.IO.Unlift
|
||||
import Crypto.Random
|
||||
import Data.Aeson (FromJSON (..), ToJSON (..))
|
||||
import qualified Data.Aeson as J
|
||||
import Data.ByteString.Builder (lazyByteString)
|
||||
import Data.ByteString.Char8 (ByteString)
|
||||
import Data.Text (Text)
|
||||
import GHC.Generics (Generic)
|
||||
import Network.HTTP.Types (Status)
|
||||
import qualified Network.HTTP.Types as N
|
||||
import qualified Network.HTTP2.Server as H
|
||||
import Network.Socket
|
||||
import Simplex.Messaging.Client.Agent (defaultSMPClientAgentConfig)
|
||||
import qualified Simplex.Messaging.Crypto as C
|
||||
import Simplex.Messaging.Encoding
|
||||
import Simplex.Messaging.Notifications.Server (runNtfServerBlocking)
|
||||
import Simplex.Messaging.Notifications.Server.Env
|
||||
import Simplex.Messaging.Notifications.Server.Push.APNS (defaultAPNSPushClientConfig)
|
||||
import Simplex.Messaging.Notifications.Server.Push.APNS
|
||||
import Simplex.Messaging.Notifications.Transport
|
||||
import Simplex.Messaging.Protocol
|
||||
import Simplex.Messaging.Transport
|
||||
import Simplex.Messaging.Transport.Client
|
||||
import Simplex.Messaging.Transport.HTTP2 (http2TLSParams)
|
||||
import Simplex.Messaging.Transport.HTTP2.Client
|
||||
import Simplex.Messaging.Transport.HTTP2.Server
|
||||
import Simplex.Messaging.Transport.KeepAlive
|
||||
import UnliftIO.Async
|
||||
import UnliftIO.Concurrent
|
||||
import qualified UnliftIO.Exception as E
|
||||
import UnliftIO.STM (TMVar, atomically, newEmptyTMVarIO, takeTMVar)
|
||||
import UnliftIO.STM
|
||||
import UnliftIO.Timeout (timeout)
|
||||
|
||||
testHost :: HostName
|
||||
testHost = "localhost"
|
||||
|
||||
testPort :: ServiceName
|
||||
testPort = "6001"
|
||||
ntfTestPort :: ServiceName
|
||||
ntfTestPort = "6001"
|
||||
|
||||
apnsTestPort :: ServiceName
|
||||
apnsTestPort = "6010"
|
||||
|
||||
testKeyHash :: C.KeyHash
|
||||
testKeyHash = "LcJUMfVhwD8yxjAiSaDzzGF3-kLG4Uh0Fl_ZIjrRwjI="
|
||||
|
||||
testNtfClient :: (Transport c, MonadUnliftIO m) => (THandle c -> m a) -> m a
|
||||
testNtfClient client =
|
||||
runTransportClient testHost testPort (Just testKeyHash) (Just defaultKeepAliveOpts) $ \h ->
|
||||
runTransportClient testHost ntfTestPort (Just testKeyHash) (Just defaultKeepAliveOpts) $ \h ->
|
||||
liftIO (runExceptT $ ntfClientHandshake h testKeyHash) >>= \case
|
||||
Right th -> client th
|
||||
Left e -> error $ show e
|
||||
|
||||
cfg :: NtfServerConfig
|
||||
cfg =
|
||||
ntfServerCfg :: NtfServerConfig
|
||||
ntfServerCfg =
|
||||
NtfServerConfig
|
||||
{ transports = undefined,
|
||||
subIdBytes = 24,
|
||||
|
@ -56,7 +77,12 @@ cfg =
|
|||
subQSize = 1,
|
||||
pushQSize = 1,
|
||||
smpAgentCfg = defaultSMPClientAgentConfig,
|
||||
apnsConfig = defaultAPNSPushClientConfig,
|
||||
apnsConfig =
|
||||
defaultAPNSPushClientConfig
|
||||
{ apnsHost = "localhost",
|
||||
apnsPort = apnsTestPort,
|
||||
http2cfg = defaultHTTP2ClientConfig {caStoreFile = "tests/fixtures/ca.crt"}
|
||||
},
|
||||
-- CA certificate private key is not needed for initialization
|
||||
caCertificateFile = "tests/fixtures/ca.crt",
|
||||
privateKeyFile = "tests/fixtures/server.key",
|
||||
|
@ -66,7 +92,7 @@ cfg =
|
|||
withNtfServerThreadOn :: (MonadUnliftIO m, MonadRandom m) => ATransport -> ServiceName -> (ThreadId -> m a) -> m a
|
||||
withNtfServerThreadOn t port' =
|
||||
serverBracket
|
||||
(\started -> runNtfServerBlocking started cfg {transports = [(port', t)]})
|
||||
(\started -> runNtfServerBlocking started ntfServerCfg {transports = [(port', t)]})
|
||||
(pure ())
|
||||
|
||||
serverBracket :: MonadUnliftIO m => (TMVar Bool -> m ()) -> m () -> (ThreadId -> m a) -> m a
|
||||
|
@ -86,7 +112,7 @@ withNtfServerOn :: (MonadUnliftIO m, MonadRandom m) => ATransport -> ServiceName
|
|||
withNtfServerOn t port' = withNtfServerThreadOn t port' . const
|
||||
|
||||
withNtfServer :: (MonadUnliftIO m, MonadRandom m) => ATransport -> m a -> m a
|
||||
withNtfServer t = withNtfServerOn t testPort
|
||||
withNtfServer t = withNtfServerOn t ntfTestPort
|
||||
|
||||
runNtfTest :: forall c m a. (Transport c, MonadUnliftIO m, MonadRandom m) => (THandle c -> m a) -> m a
|
||||
runNtfTest test = withNtfServer (transport @c) $ testNtfClient test
|
||||
|
@ -105,4 +131,63 @@ ntfServerTest _ t = runNtfTest $ \h -> tPut' h t >> tGet' h
|
|||
pure ()
|
||||
tGet' h = do
|
||||
(Nothing, _, (CorrId corrId, qId, Right cmd)) <- tGet h
|
||||
pure (Nothing, corrId, qId, cmd)
|
||||
pure (Nothing, corrId, qId, cmd)
|
||||
|
||||
data APNSMockRequest = APNSMockRequest
|
||||
{ notification :: APNSNotification,
|
||||
sendApnsResponse :: APNSMockResponse -> IO ()
|
||||
}
|
||||
|
||||
data APNSMockResponse = APNSRespOk | APNSRespError Status Text
|
||||
|
||||
data APNSMockServer = APNSMockServer
|
||||
{ action :: Async (),
|
||||
apnsQ :: TBQueue APNSMockRequest,
|
||||
http2Server :: HTTP2Server
|
||||
}
|
||||
|
||||
apnsMockServerConfig :: HTTP2ServerConfig
|
||||
apnsMockServerConfig =
|
||||
HTTP2ServerConfig
|
||||
{ qSize = 1,
|
||||
http2Port = apnsTestPort,
|
||||
serverSupported = http2TLSParams,
|
||||
caCertificateFile = "tests/fixtures/ca.crt",
|
||||
privateKeyFile = "tests/fixtures/server.key",
|
||||
certificateFile = "tests/fixtures/server.crt"
|
||||
}
|
||||
|
||||
withAPNSMockServer :: (APNSMockServer -> IO ()) -> IO ()
|
||||
withAPNSMockServer = E.bracket (getAPNSMockServer apnsMockServerConfig) closeAPNSMockServer
|
||||
|
||||
deriving instance Generic APNSAlertBody
|
||||
|
||||
deriving instance FromJSON APNSAlertBody
|
||||
|
||||
instance FromJSON APNSNotificationBody where parseJSON = J.genericParseJSON apnsJSONOptions
|
||||
|
||||
deriving instance FromJSON APNSNotification
|
||||
|
||||
deriving instance ToJSON APNSErrorResponse
|
||||
|
||||
getAPNSMockServer :: HTTP2ServerConfig -> IO APNSMockServer
|
||||
getAPNSMockServer config@HTTP2ServerConfig {qSize} = do
|
||||
http2Server <- getHTTP2Server config
|
||||
apnsQ <- newTBQueueIO qSize
|
||||
action <- async $ runAPNSMockServer apnsQ http2Server
|
||||
pure APNSMockServer {action, apnsQ, http2Server}
|
||||
where
|
||||
runAPNSMockServer apnsQ HTTP2Server {reqQ} = forever $ do
|
||||
HTTP2Request {reqBody, sendResponse} <- atomically $ readTBQueue reqQ
|
||||
let sendApnsResponse = \case
|
||||
APNSRespOk -> sendResponse $ H.responseNoBody N.ok200 []
|
||||
APNSRespError status reason ->
|
||||
sendResponse . H.responseBuilder status [] . lazyByteString $ J.encode APNSErrorResponse {reason}
|
||||
case J.decodeStrict' reqBody of
|
||||
Just notification -> atomically $ writeTBQueue apnsQ APNSMockRequest {notification, sendApnsResponse}
|
||||
_ -> sendApnsResponse $ APNSRespError N.badRequest400 "bad_request_body"
|
||||
|
||||
closeAPNSMockServer :: APNSMockServer -> IO ()
|
||||
closeAPNSMockServer APNSMockServer {action, http2Server} = do
|
||||
closeHTTP2Server http2Server
|
||||
uninterruptibleCancel action
|
||||
|
|
|
@ -33,4 +33,4 @@ ntfSyntaxTests (ATransport t) = do
|
|||
(Maybe C.ASignature, ByteString, ByteString, smp) ->
|
||||
(Maybe C.ASignature, ByteString, ByteString, BrokerMsg) ->
|
||||
Expectation
|
||||
command >#> response = ntfServerTest t command `shouldReturn` response
|
||||
command >#> response = withAPNSMockServer $ \_ -> ntfServerTest t command `shouldReturn` response
|
||||
|
|
|
@ -11,6 +11,7 @@ import Crypto.Random
|
|||
import qualified Data.ByteString.Char8 as B
|
||||
import qualified Data.List.NonEmpty as L
|
||||
import Network.Socket (HostName, ServiceName)
|
||||
import NtfClient (ntfTestPort)
|
||||
import SMPClient
|
||||
( serverBracket,
|
||||
testKeyHash,
|
||||
|
@ -161,8 +162,8 @@ initAgentServers =
|
|||
ntf = ["smp://LcJUMfVhwD8yxjAiSaDzzGF3-kLG4Uh0Fl_ZIjrRwjI=@localhost:6001"]
|
||||
}
|
||||
|
||||
cfg :: AgentConfig
|
||||
cfg =
|
||||
agentCfg :: AgentConfig
|
||||
agentCfg =
|
||||
defaultAgentConfig
|
||||
{ tcpPort = agentTestPort,
|
||||
tbqSize = 1,
|
||||
|
@ -173,6 +174,11 @@ cfg =
|
|||
defaultTransport = (testPort, transport @TLS),
|
||||
tcpTimeout = 500_000
|
||||
},
|
||||
ntfCfg =
|
||||
defaultClientConfig
|
||||
{ qSize = 1,
|
||||
defaultTransport = (ntfTestPort, transport @TLS)
|
||||
},
|
||||
reconnectInterval = defaultReconnectInterval {initialInterval = 50_000},
|
||||
caCertificateFile = "tests/fixtures/ca.crt",
|
||||
privateKeyFile = "tests/fixtures/server.key",
|
||||
|
@ -181,7 +187,7 @@ cfg =
|
|||
|
||||
withSmpAgentThreadOn_ :: (MonadUnliftIO m, MonadRandom m) => ATransport -> (ServiceName, ServiceName, String) -> m () -> (ThreadId -> m a) -> m a
|
||||
withSmpAgentThreadOn_ t (port', smpPort', db') afterProcess =
|
||||
let cfg' = cfg {tcpPort = port', dbFile = db'}
|
||||
let cfg' = agentCfg {tcpPort = port', dbFile = db'}
|
||||
initServers' = initAgentServers {smp = L.fromList [SMPServer "localhost" smpPort' testKeyHash]}
|
||||
in serverBracket
|
||||
(\started -> runSMPAgentBlocking t started cfg' initServers')
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
{-# LANGUAGE TypeApplications #-}
|
||||
|
||||
import AgentTests (agentTests)
|
||||
import AgentTests.NotificationTests (notificationTests)
|
||||
import CoreTests.EncodingTests
|
||||
import CoreTests.ProtocolErrorTests
|
||||
import CoreTests.VersionRangeTests
|
||||
|
@ -9,11 +10,14 @@ import ServerTests
|
|||
import Simplex.Messaging.Transport (TLS, Transport (..))
|
||||
import Simplex.Messaging.Transport.WebSockets (WS)
|
||||
import System.Directory (createDirectoryIfMissing, removeDirectoryRecursive)
|
||||
import System.Environment (setEnv)
|
||||
import Test.Hspec
|
||||
|
||||
main :: IO ()
|
||||
main = do
|
||||
createDirectoryIfMissing False "tests/tmp"
|
||||
setEnv "APNS_KEY_ID" "H82WD9K9AQ"
|
||||
setEnv "APNS_KEY_FILE" "./tests/fixtures/AuthKey_H82WD9K9AQ.p8"
|
||||
hspec $ do
|
||||
describe "Core tests" $ do
|
||||
describe "Encoding tests" encodingTests
|
||||
|
@ -21,6 +25,8 @@ main = do
|
|||
describe "Version range" versionRangeTests
|
||||
describe "SMP server via TLS" $ serverTests (transport @TLS)
|
||||
describe "SMP server via WebSockets" $ serverTests (transport @WS)
|
||||
describe "Ntf server via TLS" $ ntfServerTests (transport @TLS)
|
||||
describe "Notifications server" $ do
|
||||
ntfServerTests (transport @TLS)
|
||||
notificationTests (transport @TLS)
|
||||
describe "SMP client agent" $ agentTests (transport @TLS)
|
||||
removeDirectoryRecursive "tests/tmp"
|
||||
|
|
|
@ -0,0 +1,6 @@
|
|||
-----BEGIN PRIVATE KEY-----
|
||||
MIGTAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBHkwdwIBAQQgWuPap5jF6eioxuHM
|
||||
XWZWUK78LdcxkTnMXWg2GqyXuBugCgYIKoZIzj0DAQehRANCAAQn64CvAIbEEzvM
|
||||
KwYjlOxVD5SxlgP1ZcYvVM/+VHLFu0aCkG7ueICTi3qWyqoB5hjjuAqwtc3EK0q0
|
||||
yupyM7Yx
|
||||
-----END PRIVATE KEY-----
|
Reference in New Issue