{-# LANGUAGE CPP #-}
{-# LANGUAGE ForeignFunctionInterface #-}

module Network.Wai.Handler.SCGI (
    run,
    runSendfile,
) where

import Control.Exception (bracket)
import Data.ByteString (ByteString)
import qualified Data.ByteString as S
import qualified Data.ByteString.Char8 as S8
import Data.ByteString.Lazy.Internal (defaultChunkSize)
import qualified Data.ByteString.Unsafe as S
import Data.IORef (IORef, newIORef, readIORef, writeIORef)
import Data.Maybe (fromMaybe, listToMaybe)
import Foreign.C (CChar, CInt (..))
import Foreign.C.Error (eINTR, getErrno, throwErrnoIfMinus1Retry)
import Foreign.Marshal.Alloc (free, mallocBytes)
import Foreign.Ptr (Ptr, castPtr, nullPtr, plusPtr)
import Network.Wai (Application)
import Network.Wai.Handler.CGI (requestBodyFunc, runGeneric)

-- | Serve SCGI requests with @app@ forever, one connection at a time.
--
-- Each iteration handles exactly one accepted connection and then loops;
-- no sendfile header is configured (the @Nothing@ passed down).
-- NOTE(review): connections are accepted on file descriptor 0, presumably
-- a listening socket inherited from the front-end web server — confirm.
run :: Application -> IO ()
run app = do
    runOne Nothing app
    run app

-- | Like 'run', but hands @sf@ to the CGI layer on every request.
--
-- NOTE(review): @sf@ is presumably the name of an X-Sendfile-style response
-- header understood by 'runGeneric' — confirm against wai-extra.
runSendfile :: ByteString -> Application -> IO ()
runSendfile sf app = loop
  where
    -- Serve one connection, then start over; never returns normally.
    loop = runOne (Just sf) app *> loop

-- | Accept one connection on file descriptor 0, serve a single SCGI request
-- on it with @app@, and close it.
--
-- The request header block is a netstring of NUL-separated name/value
-- pairs; the body length comes from the @CONTENT_LENGTH@ header. Any body
-- bytes the application leaves unread are drained before closing.
--
-- * A failing @accept(2)@ raises an 'IOError' (it is retried on @EINTR@)
--   rather than continuing with descriptor @-1@.
-- * The accepted socket is closed even if parsing the request or running
--   the application throws ('bracket').
-- * A missing, unparsable or negative @CONTENT_LENGTH@ counts as 0, so a
--   hostile value can never reach the allocator as a negative size.
runOne :: Maybe ByteString -> Application -> IO ()
runOne sf app = bracket acceptConn c'close serve
  where
    acceptConn :: IO CInt
    acceptConn =
        throwErrnoIfMinus1Retry "Network.Wai.Handler.SCGI.runOne: accept" $
            c'accept 0 nullPtr nullPtr

    serve :: CInt -> IO ()
    serve socket = do
        headersBS <- readNetstring socket
        -- Headers are NUL (byte 0) separated: name, value, name, value, ...
        let headers = parseHeaders $ S.split 0 headersBS
            conLen = max 0 $ fromMaybe 0 $ do
                conLenS <- lookup "CONTENT_LENGTH" headers
                (i, _) <- listToMaybe $ reads conLenS
                pure i
        conLenI <- newIORef conLen
        runGeneric
            headers
            (requestBodyFunc $ input socket conLenI)
            (write socket)
            sf
            app
        -- Consume whatever body the application did not read.
        drain socket conLenI

-- | Write an entire 'ByteString' to the descriptor @socket@.
--
-- @write(2)@ may accept fewer bytes than requested (routine on sockets), so
-- this keeps writing from where the previous call stopped until all bytes
-- are sent. A call interrupted by a signal (@EINTR@) is retried. Any other
-- error, or a call that makes no progress, ends the write silently. Write
-- failures stay best-effort, so a client that hangs up cannot take down the
-- serving loop.
write :: CInt -> ByteString -> IO ()
write socket bs = S.unsafeUseAsCStringLen bs $ \(start, total) -> go start total
  where
    go :: Ptr CChar -> Int -> IO ()
    go ptr remaining
        | remaining <= 0 = return ()
        | otherwise = do
            -- Never ask for more than fits in the C 'int' count argument.
            n <- c'write socket ptr (fromIntegral (min remaining maxWrite))
            case compare n 0 of
                GT ->
                    let sent = fromIntegral n
                     in go (ptr `plusPtr` sent) (remaining - sent)
                EQ -> return ()
                LT -> do
                    errno <- getErrno
                    if errno == eINTR then go ptr remaining else return ()

    maxWrite :: Int
    maxWrite = fromIntegral (maxBound :: CInt)

-- | Read the next piece of the request body.
--
-- @ilen@ holds the number of body bytes still expected on @socket@. @rlen@
-- is the caller's requested size. At most
-- @min defaultChunkSize remaining rlen@ bytes are read, and the counter is
-- decreased by the number actually received.
--
-- Returns 'Nothing' when no body remains. That includes a non-positive
-- counter, e.g. from a bogus length. It also includes the peer closing the
-- connection early: the read comes back empty even though bytes were
-- requested, and the counter is then zeroed so later calls stop immediately.
input :: CInt -> IORef Int -> Int -> IO (Maybe ByteString)
input socket ilen rlen = do
    len <- readIORef ilen
    if len <= 0
        then return Nothing
        else do
            let want = minimum [defaultChunkSize, len, rlen]
            bs <- readByteString socket want
            if S.null bs && want > 0
                then do
                    -- Premature end of stream: give up on the rest.
                    writeIORef ilen 0
                    return Nothing
                else do
                    writeIORef ilen $ len - S.length bs
                    return $ Just bs

-- | Read and discard the rest of the request body from @socket@.
--
-- @ilen@ holds the number of body bytes still expected. They are consumed
-- in chunks of at most 'defaultChunkSize', so a large (or hostile)
-- @CONTENT_LENGTH@ never forces a single huge allocation. Draining stops
-- early if the peer closes the connection, i.e. a read comes back empty.
-- The counter in @ilen@ itself is left untouched.
drain :: CInt -> IORef Int -> IO ()
drain socket ilen = readIORef ilen >>= go
  where
    go :: Int -> IO ()
    go remaining
        | remaining <= 0 = return ()
        | otherwise = do
            chunk <- readByteString socket (min defaultChunkSize remaining)
            if S.null chunk
                then return ()
                else go (remaining - S.length chunk)

-- | Pair up a flat list of fields as @(name, value)@ strings.
--
-- Fields are consumed two at a time. A trailing unpaired field is dropped;
-- this covers the empty piece produced by the terminating NUL of the
-- header block.
parseHeaders :: [ByteString] -> [(String, String)]
parseHeaders fields = case fields of
    key : value : rest -> (S8.unpack key, S8.unpack value) : parseHeaders rest
    _ -> []

-- | Read one netstring, @<decimal length>:<payload>,@, from @socket@ and
-- return the payload.
--
-- The length prefix is read one byte at a time up to the @:@. Only ASCII
-- digits are accepted, and the running value is checked so it cannot
-- overflow 'Int'. Anything else, including end of stream before the @:@,
-- raises an error. The trailing comma byte is consumed but not checked.
readNetstring :: CInt -> IO ByteString
readNetstring socket = do
    len <- readLen 0
    payload <- readByteString socket len
    _ <- readByteString socket 1 -- the trailing comma
    return payload
  where
    readLen :: Int -> IO Int
    readLen acc = do
        bs <- readByteString socket 1
        case S8.unpack bs of
            [':'] -> return acc
            [c]
                -- Accept a digit only while acc * 10 + 9 still fits in Int.
                | c >= '0' && c <= '9' && acc <= (maxBound - 9) `div` 10 ->
                    readLen $ acc * 10 + (fromEnum c - fromEnum '0')
            _ ->
                error "Network.Wai.Handler.SCGI.readNetstring: malformed length prefix"

-- | Read up to @len@ bytes from the descriptor @socket@.
--
-- @read(2)@ may return fewer bytes than asked for, so this keeps reading
-- until @len@ bytes have arrived, the stream ends (@read@ returns 0), or an
-- error other than @EINTR@ occurs (@EINTR@ is retried). The result holds
-- only the bytes actually received. It is shorter than @len@ (possibly
-- empty) on end of stream or error, and never exposes uninitialised memory.
-- A non-positive @len@ yields the empty string without touching the
-- socket.
readByteString :: CInt -> Int -> IO ByteString
readByteString socket len
    | len <= 0 = return S.empty
    | otherwise = do
        buf <- mallocBytes len
        got <- fill buf 0
        -- Zero-copy: the ByteString owns @buf@ and frees it when collected.
        S.unsafePackCStringFinalizer (castPtr buf) got (free buf)
  where
    fill :: Ptr CChar -> Int -> IO Int
    fill buf off
        | off >= len = return off
        | otherwise = do
            let want = min (len - off) maxRead
            n <- c'read socket (buf `plusPtr` off) (fromIntegral want)
            case compare n 0 of
                GT -> fill buf (off + fromIntegral n)
                EQ -> return off -- end of stream
                LT -> do
                    errno <- getErrno
                    if errno == eINTR then fill buf off else return off

    -- Largest count that fits in the C 'int' argument of read.
    maxRead :: Int
    maxRead = fromIntegral (maxBound :: CInt)

-- | C @accept(2)@: take the next connection from a listening descriptor.
-- The two pointer arguments are the peer-address out-parameters; this
-- module passes 'nullPtr' for both.
--
-- Imported @safe@ because accept blocks until a client connects. A blocking
-- @unsafe@ call would stall every other Haskell thread, and the garbage
-- collector, for as long as it waits.
foreign import ccall safe "accept"
    c'accept :: CInt -> Ptr a -> Ptr a -> IO CInt

-- Raw descriptor primitives: the underscore-prefixed CRT names on Windows,
-- the POSIX names elsewhere.
--
-- read and write are imported @safe@ because they can block waiting on the
-- peer. An @unsafe@ call that blocks stalls every other Haskell thread (and
-- the garbage collector) until it returns. close is left @unsafe@; it is
-- expected to return promptly (nothing in this module sets SO_LINGER).
#if WINDOWS
foreign import ccall unsafe "_close"
    c'close :: CInt -> IO CInt

foreign import ccall safe "_write"
    c'write :: CInt -> Ptr CChar -> CInt -> IO CInt

foreign import ccall safe "_read"
    c'read :: CInt -> Ptr CChar -> CInt -> IO CInt
#else
foreign import ccall unsafe "close"
    c'close :: CInt -> IO CInt

foreign import ccall safe "write"
    c'write :: CInt -> Ptr CChar -> CInt -> IO CInt

foreign import ccall safe "read"
    c'read :: CInt -> Ptr CChar -> CInt -> IO CInt
#endif