module Network.Wai.Middleware.RealIp (
realIp,
realIpHeader,
realIpTrusted,
defaultTrusted,
ipInRange,
) where
import qualified Data.ByteString.Char8 as B8 (split, unpack)
import qualified Data.IP as IP
import Data.Maybe (fromMaybe, listToMaybe, mapMaybe)
import Network.HTTP.Types (HeaderName, RequestHeaders)
import Network.Wai (Middleware, remoteHost, requestHeaders)
import Text.Read (readMaybe)
-- | Infer the client's IP address from the @X-Forwarded-For@ header.
--
-- This is @'realIpHeader' "X-Forwarded-For"@: the header is only honoured
-- when the directly connected peer lies in one of the 'defaultTrusted'
-- ranges.
--
-- NOTE(review): the header name is a string literal used as a 'HeaderName',
-- which needs @OverloadedStrings@; no LANGUAGE pragma is visible at the top
-- of this file — confirm it is enabled.
realIp :: Middleware
realIp = realIpHeader "X-Forwarded-For"
-- | Infer the client's IP address from the given header.
--
-- A peer is considered trusted when its address falls inside any range of
-- 'defaultTrusted'. See 'realIpTrusted' for how the header is interpreted.
realIpHeader :: HeaderName -> Middleware
realIpHeader header =
    realIpTrusted header $ \ip -> any (ipInRange ip) defaultTrusted
-- | Replace the request's 'remoteHost' with an address taken from the given
-- header, but only when the directly connected peer satisfies @isTrusted@.
--
-- The port of the original 'remoteHost' is kept. The request reaches the
-- wrapped application unchanged in the following cases:
--
-- * the peer address cannot be converted by 'IP.fromSockAddr';
-- * the peer is not trusted;
-- * the header yields no parseable address.
realIpTrusted :: HeaderName -> (IP.IP -> Bool) -> Middleware
realIpTrusted header isTrusted app req = app req'
  where
    -- Each step runs in 'Maybe'; any 'Nothing' falls back to the original
    -- request through 'fromMaybe'.
    req' = fromMaybe req $ do
        (ip, port) <- IP.fromSockAddr (remoteHost req)
        -- The header is only consulted when the immediate peer is trusted,
        -- so an untrusted client cannot pick its own address by sending it.
        ip' <-
            if isTrusted ip
                then findRealIp (requestHeaders req) header isTrusted
                else Nothing
        Just $ req{remoteHost = IP.toSockAddr (ip', port)}
-- | Address ranges trusted by 'realIp' and 'realIpHeader':
--
-- * IPv4 loopback (@127.0.0.0/8@);
-- * the IPv4 private-use blocks (@10.0.0.0/8@, @172.16.0.0/12@,
--   @192.168.0.0/16@);
-- * IPv6 loopback (@::1/128@);
-- * IPv6 unique local addresses (@fc00::/7@).
--
-- NOTE(review): the entries are string literals used as 'IP.IPRange', which
-- needs @OverloadedStrings@ — confirm the extension is enabled for this
-- module.
defaultTrusted :: [IP.IPRange]
defaultTrusted =
    [ "127.0.0.0/8"
    , "10.0.0.0/8"
    , "172.16.0.0/12"
    , "192.168.0.0/16"
    , "::1/128"
    , "fc00::/7"
    ]
-- | Check whether an IP address lies inside an address range.
--
-- The rules by address family are:
--
-- * IPv4 against IPv4 and IPv6 against IPv6 are matched directly.
-- * An IPv4 address is checked against an IPv6 range after conversion with
--   'IP.ipv4ToIPv6'.
-- * An IPv6 address against an IPv4 range is always 'False'.
--
-- NOTE(review): because of the last rule, a peer reported as an IPv6
-- address that embeds an IPv4 address never matches an IPv4 range — confirm
-- this is intended for dual-stack listeners.
ipInRange :: IP.IP -> IP.IPRange -> Bool
ipInRange (IP.IPv4 ip) (IP.IPv4Range r) = ip `IP.isMatchedTo` r
ipInRange (IP.IPv6 ip) (IP.IPv6Range r) = ip `IP.isMatchedTo` r
ipInRange (IP.IPv4 ip) (IP.IPv6Range r) = IP.ipv4ToIPv6 ip `IP.isMatchedTo` r
ipInRange _ _ = False
-- | Choose a client address from the values of a possibly repeated header.
--
-- Every occurrence of @header@ is used, in request order. Each value is
-- split on @','@ and each piece is parsed with 'readMaybe' as an 'IP.IP'.
-- Pieces that do not parse are dropped. Then:
--
-- * if every parsed address satisfies @isTrusted@, the first (leftmost)
--   one is returned;
-- * otherwise the last (rightmost) untrusted address is returned.
--
-- Returns 'Nothing' when no address could be parsed.
findRealIp :: RequestHeaders -> HeaderName -> (IP.IP -> Bool) -> Maybe IP.IP
findRealIp reqHeaders header isTrusted =
    case (nonTrusted, ips) of
        ([], xs) -> listToMaybe xs
        -- NOTE(review): presumably the rightmost untrusted entry is chosen
        -- because proxies append to the header, making it the nearest hop
        -- not vouched for by @isTrusted@ — confirm.
        (xs, _) -> listToMaybe (reverse xs)
  where
    -- Values from every occurrence of the header, in request order.
    headerVals = [v | (k, v) <- reqHeaders, k == header]
    ips = concatMap (mapMaybe (readMaybe . B8.unpack) . B8.split ',') headerVals
    nonTrusted = filter (not . isTrusted) ips