{-# LANGUAGE OverloadedStrings #-}

module Network.Socket.BufferPool.Recv (
    receive,
    makeRecvN,
) where

import qualified Data.ByteString as BS
import Data.ByteString.Internal (ByteString (..), unsafeCreate)
import Data.IORef
import Network.Socket (Socket, recvBuf)

import Network.Socket.BufferPool.Buffer
import Network.Socket.BufferPool.Types

----------------------------------------------------------------

-- | Build a 'Recv' that reads from the socket into buffers taken from
--   the given 'BufferPool'. Every read goes through 'withBufferPool',
--   so callers never handle buffers themselves.
receive :: Socket -> BufferPool -> Recv
receive sock pool = withBufferPool pool fillFromSocket
  where
    -- Read at most @room@ bytes from the socket into @buf@.
    fillFromSocket buf room = recvBuf sock buf room

----------------------------------------------------------------

-- | Create a 'RecvN', a receiving function that returns exactly N bytes
--   per call, from a single 'Recv' action.
--
--   The first argument is data that has already been received. It is
--   consumed before the 'Recv' action is called. Surplus bytes from one
--   call are kept and served first on the next call.
--   If the 'Recv' action reports end of stream (an empty 'ByteString')
--   before N bytes have arrived, fewer than N bytes are returned.
--
--   NOTE(review): earlier docs stated that the buffer pool is used when
--   N <= 4096 and that otherwise a new buffer is allocated under the global
--   lock. That depends on the 'Recv' passed in and on allocation details
--   not visible in this module; confirm before relying on it.
--
-- >>> :seti -XOverloadedStrings
-- >>> tryRecvN "a" 3 =<< _iorefRecv ["bcd"]
-- ("abc","d")
-- >>> tryRecvN "a" 3 =<< _iorefRecv ["bc"]
-- ("abc","")
-- >>> tryRecvN "a" 3 =<< _iorefRecv ["b"]
-- ("ab","")
makeRecvN :: ByteString -> Recv -> IO RecvN
makeRecvN bs0 recv = do
    -- The IORef holds bytes received but not yet handed out.
    ref <- newIORef bs0
    return $ recvN ref recv

-- | Receive exactly N bytes, where N is the third argument.
--
--   Bytes cached in the 'IORef' are consumed first, and the 'Recv' action
--   is called only if more are needed. Any surplus from the last 'Recv'
--   call is written back to the 'IORef' for the next call. If the stream
--   ends early, fewer than N bytes may be returned (see 'tryRecvN').
--
--   NOTE(review): the read/update of the 'IORef' is not atomic, so two
--   threads sharing one 'RecvN' could race. A single reader per connection
--   is presumably intended; confirm.
recvN :: IORef ByteString -> Recv -> RecvN
recvN ref recv size = do
    cached <- readIORef ref
    (bs, leftover) <- tryRecvN cached size recv
    writeIORef ref leftover
    return bs

----------------------------------------------------------------

-- | Try to collect exactly @siz0@ bytes. It starts from the already-received
--   @init0@ and calls @recv@ for more as needed.
--
--   Returns the requested bytes paired with the surplus (leftover) bytes
--   from the last 'recv'. If 'recv' returns an empty 'ByteString' before
--   enough data has arrived, the bytes gathered so far are returned
--   (possibly fewer than @siz0@) with an empty leftover.
tryRecvN :: ByteString -> Int -> Recv -> IO (ByteString, ByteString)
tryRecvN init0 siz0 recv
    | siz0 <= len0 = return $ BS.splitAt siz0 init0
    -- An empty @init0@ is not put into the accumulator. A single later
    -- 'recv' that satisfies the request then produces exactly one chunk,
    -- and 'concatN' returns it without allocating and copying.
    | BS.null init0 = go id siz0
    | otherwise = go (init0 :) (siz0 - len0)
  where
    len0 :: Int
    len0 = BS.length init0

    -- @build@ is a difference list of the chunks gathered so far.
    -- @left@ is the number of bytes still needed (always > 0 here).
    go :: ([ByteString] -> [ByteString]) -> Int -> IO (ByteString, ByteString)
    go build left = do
        bs <- recv
        let len = BS.length bs
        if len == 0
            then
                -- End of stream: return only what has been gathered.
                return (concatN (siz0 - left) (build []), "")
            else
                if len >= left
                    then do
                        -- This chunk completes the request; keep the rest.
                        let (consume, leftover) = BS.splitAt left bs
                        return (concatN siz0 (build [consume]), leftover)
                    else
                        -- Still short: remember the chunk and read again.
                        go (build . (bs :)) (left - len)

-- | Concatenate byte strings whose combined length is exactly @total@.
--
--   The caller must supply the correct @total@, because it sizes the
--   destination buffer that the chunks are copied into.
concatN :: Int -> [ByteString] -> ByteString
concatN total chunks = case chunks of
    -- Nothing to join.
    [] -> ""
    -- A single chunk is returned as-is, avoiding a copy.
    [only] -> only
    -- Several chunks: allocate @total@ bytes and copy each one in turn.
    _ -> unsafeCreate total (fill chunks)
  where
    -- Copy the chunks one after another, advancing the destination
    -- pointer by whatever 'copy' returns.
    fill :: [ByteString] -> Buffer -> IO ()
    fill [] _ = return ()
    fill (c : cs) dst = copy dst c >>= fill cs

-- | For doctests only: build a 'Recv' that yields the given chunks in
--   order.
--
--   Once the chunks are exhausted it returns @""@ a single time, signalling
--   end of stream. Any further call fails with @error "closed"@. The chunks
--   must be non-empty, since an empty chunk would be mistaken for end of
--   stream.
_iorefRecv :: [ByteString] -> IO Recv
_iorefRecv chunks0 = pop <$> newIORef chunks0
  where
    pop :: IORef [ByteString] -> IO ByteString
    pop ref = do
        pending <- readIORef ref
        case pending of
            x : rest -> do
                writeIORef ref rest
                return x
            [] -> do
                -- Poison the ref so that the next read raises "closed".
                writeIORef ref $ error "closed"
                return ""