{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}
module Network.TLS.Handshake.Server.ClientHello13 (
processClientHello13,
) where
import qualified Data.ByteString as B
import Network.TLS.Cipher
import Network.TLS.Context.Internal
import Network.TLS.Crypto
import Network.TLS.Extension
import Network.TLS.Handshake.Common13
import Network.TLS.Handshake.Signature
import Network.TLS.Handshake.State
import Network.TLS.IO.Encode
import Network.TLS.Imports
import Network.TLS.Packet
import Network.TLS.Parameters
import Network.TLS.Session
import Network.TLS.State
import Network.TLS.Struct
import Network.TLS.Types
-- | Process a TLS 1.3 ClientHello on the server side.
--
-- Steps performed, in order:
--
-- 1. reject the hello when @pre_shared_key@ appears anywhere but last;
-- 2. reject the hello when no TLS 1.3 cipher is shared with the client;
-- 3. let the 'onCipherChoosing' hook pick the cipher and derive its hash;
-- 4. record on the context whether the client sent @early_data@;
-- 5. pick a client key share matching one of the server's supported groups;
-- 6. run 'pskAndEarlySecret' and update the transcript hash with the
--    ClientHello stored in the handshake state.
--
-- Returns the chosen key share (if any), the @(cipher, hash, rtt0)@ triple and
-- the result of 'pskAndEarlySecret'.
--
-- Failures are raised through 'throwCore' as 'Error_Protocol' values.
processClientHello13
    :: ServerParams
    -> Context
    -> ClientHello
    -> IO
        ( Maybe KeyShareEntry
        , (Cipher, Hash, Bool)
        , (SecretPair EarlySecret, [ExtensionRaw], Bool, Bool)
        )
processClientHello13 sparams ctx ch@CH{..} = do
    when (any isPreSharedKey allButLast) $
        throwCore $
            Error_Protocol "extension pre_shared_key must be last" IllegalParameter
    -- The cipher list must be checked for emptiness before it is handed to
    -- the user-supplied hook.
    when (null ciphersFilteredVersion) $
        throwCore $
            Error_Protocol "no cipher in common with the TLS 1.3 client" HandshakeFailure
    let usedCipher = onCipherChoosing (serverHooks sparams) TLS13 ciphersFilteredVersion
        usedHash = cipherHash usedCipher
        -- True exactly when an early_data extension is found and decoded.
        rtt0 =
            lookupAndDecode
                EID_EarlyData
                MsgTClientHello
                chExtensions
                False
                (\(EarlyDataIndication _) -> True)
    -- NOTE(review): the 3 passed to EarlyDataNotAllowed is hard-coded here;
    -- its meaning is defined where 'Established' is declared.
    setEstablished ctx $
        if rtt0
            then EarlyDataNotAllowed 3
            else NotEstablished
    -- Both "extension not usable" and "wrong KeyShare variant" end in the
    -- same protocol error.
    let require =
            throwCore $
                Error_Protocol
                    "key exchange not implemented, expected key_share extension"
                    MissingExtension
        extract (KeyShareClientHello kses) = return kses
        extract _ = require
    keyShares <-
        lookupAndDecodeAndDo EID_KeyShare MsgTClientHello chExtensions require extract
    mshare <- findKeyShare keyShares serverGroups
    let triple = (usedCipher, usedHash, rtt0)
    pskEarlySecret <- pskAndEarlySecret sparams ctx triple ch
    -- NOTE(review): 'fromJust' assumes the ClientHello was stored in the
    -- handshake state before this function runs -- confirm against the caller.
    clientHello <- fromJust <$> usingHState ctx getClientHello
    void $ updateTranscriptHash12 ctx $ ClientHello clientHello
    return (mshare, triple, pskEarlySecret)
  where
    isPreSharedKey (ExtensionRaw eid _) = eid == EID_PreSharedKey

    -- Every extension except the final one. Unlike 'init', this is total:
    -- an empty extension list simply yields nothing to check, so the
    -- handshake fails later with a proper protocol error instead of a
    -- Prelude exception.
    allButLast = drop 1 (reverse chExtensions)

    -- Client-offered ciphers restricted to those the server allows for TLS 1.3.
    ciphersFilteredVersion = intersectCiphers chCiphers serverCiphers

    serverCiphers =
        filter
            (cipherAllowedForVersion TLS13)
            (supportedCiphers $ serverSupported sparams)

    serverGroups = supportedGroups (ctxSupported ctx)
-- | Find the client's key share for the first server group that has one.
--
-- The server groups are tried in the order given. For each group the client
-- entries of that group are collected:
--
-- * none: continue with the next group;
-- * exactly one: return it, provided 'checkKeyShareKeyLength' accepts it,
--   otherwise fail with @"broken key_share"@;
-- * more than one: fail with @"duplicated key_share"@.
--
-- 'Nothing' is returned when the groups are exhausted without a match.
findKeyShare :: [KeyShareEntry] -> [Group] -> IO (Maybe KeyShareEntry)
findKeyShare _ [] = return Nothing
findKeyShare entries (grp : rest) =
    case [e | e <- entries, grp == keyShareEntryGroup e] of
        [] -> findKeyShare entries rest
        [entry]
            | checkKeyShareKeyLength entry -> return (Just entry)
            | otherwise ->
                throwCore $ Error_Protocol "broken key_share" IllegalParameter
        _ -> throwCore $ Error_Protocol "duplicated key_share" IllegalParameter
-- | Decide whether a PSK offered by the client is accepted and derive the
-- early secret from it.
--
-- The third argument is the @(cipher, hash, rtt0)@ triple chosen for this
-- handshake.
--
-- Returns, in order:
--
-- * the secret pair produced by 'calculateEarlySecret' (computed from an
--   all-zero key of hash-digest length when no PSK is accepted);
-- * the @pre_shared_key@ ServerHello extension, or @[]@ when no PSK is
--   accepted;
-- * whether a PSK was accepted;
-- * whether the resumed session additionally matches on version and cipher
--   (the 0-RTT validity flag).
--
-- Fails with 'Error_Protocol' when a PSK is offered without any
-- @psk_key_exchange_modes@, and via 'decryptError' when the binder does not
-- match.
pskAndEarlySecret
    :: ServerParams
    -> Context
    -> (Cipher, Hash, Bool)
    -> ClientHello
    -> IO (SecretPair EarlySecret, [ExtensionRaw], Bool, Bool)
pskAndEarlySecret sparams ctx (usedCipher, usedHash, rtt0) CH{..} = do
    (psk, binderInfo, is0RTTvalid) <- choosePSK
    earlyKey <- calculateEarlySecret ctx choice (Left psk)
    let earlySecret = pairBase earlyKey
        authenticated = isJust binderInfo
    preSharedKeyExt <- checkBinder earlySecret binderInfo
    return (earlyKey, preSharedKeyExt, authenticated, is0RTTvalid)
  where
    choice = makeCipherChoice TLS13 usedCipher

    -- Result used on every path where no PSK is accepted: zero key, no
    -- binder to verify, 0-RTT not valid.
    noPSK = (zero, Nothing, False)

    choosePSK =
        lookupAndDecodeAndDo
            EID_PreSharedKey
            MsgTClientHello
            chExtensions
            (return noPSK)
            selectPSK

    -- Only the first offered identity and the first binder are considered.
    selectPSK (PreSharedKeyClientHello (PskIdentity identity obfAge : _) bnds@(bnd : _)) = do
        when (null dhModes) $
            throwCore $
                Error_Protocol "no psk_key_exchange_modes extension" MissingExtension
        if PSK_DHE_KE `elem` dhModes
            then do
                -- Each binder counts as its length plus one, plus two overall;
                -- presumably the encoded size of the binders list -- it is
                -- handed to 'makePSKBinder' as the length argument.
                let len = sum (map (\x -> B.length x + 1) bnds) + 2
                    mgr = sharedSessionManager $ serverShared sparams
                msdata <-
                    if rtt0
                        then sessionResumeOnlyOnce mgr identity
                        else sessionResume mgr identity
                case msdata of
                    Just sdata -> case sessionTicketInfo sdata of
                        -- A session without ticket info cannot be checked for
                        -- freshness; treat it like any other unusable PSK
                        -- instead of crashing on 'fromJust'.
                        Nothing -> return noPSK
                        Just tinfo -> do
                            let psk = sessionSecret sdata
                            isFresh <- checkFreshness tinfo obfAge
                            (isPSKvalid, is0RTTvalid) <- checkSessionEquality sdata
                            if isPSKvalid && isFresh
                                then return (psk, Just (bnd, 0 :: Int, len), is0RTTvalid)
                                else return noPSK
                    _ -> return noPSK
            else return noPSK
    selectPSK _ = return noPSK

    checkBinder _ Nothing = return []
    checkBinder earlySecret (Just (binder, n, tlen)) = do
        -- NOTE(review): 'fromJust' assumes the ClientHello is already stored
        -- in the handshake state -- confirm against the caller.
        ch <- fromJust <$> usingHState ctx getClientHello
        let ech = encodeHandshake $ ClientHello ch
            binder' = makePSKBinder earlySecret usedHash tlen ech
        unless (binder == binder') $
            decryptError "PSK binder validation failed"
        return [toExtensionRaw $ PreSharedKeyServerHello $ fromIntegral n]

    -- Compare the stored session with the current handshake.
    -- PSK validity: same hash (KDF) and same SNI.
    -- 0-RTT validity: TLS 1.3 session and same cipher.
    checkSessionEquality sdata = do
        msni <- usingState_ ctx getClientSNI
        let scid = sessionCipher sdata
            ciphers = supportedCiphers $ serverSupported sparams
            isSameSNI = sessionClientSNI sdata == msni
            isSameCipher = scid == cipherID usedCipher
            isSameKDF = case findCipher scid ciphers of
                Nothing -> False
                Just c -> cipherHash c == cipherHash usedCipher
            isSameVersion = TLS13 == sessionVersion sdata
            isPSKvalid = isSameKDF && isSameSNI
            is0RTTvalid = isSameVersion && isSameCipher
        return (isPSKvalid, is0RTTvalid)

    -- Key exchange modes offered by the client; [] when the extension is
    -- not found or not decoded.
    dhModes =
        lookupAndDecode
            EID_PskKeyExchangeModes
            MsgTClientHello
            chExtensions
            []
            (\(PskKeyExchangeModes ms) -> ms)

    hashSize = hashDigestSize usedHash

    -- All-zero key of digest length, used when no PSK is accepted.
    zero = B.replicate hashSize 0