{-# LANGUAGE ScopedTypeVariables #-}
module System.IO.Streams.SSL
( connect
, withConnection
, sslToStreams
) where
import qualified Control.Exception as E
import Control.Monad (void)
import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as S
import Network.Socket (HostName, PortNumber)
import qualified Network.Socket as N
import OpenSSL.Session (SSL, SSLContext)
import qualified OpenSSL.Session as SSL
import System.IO.Streams (InputStream, OutputStream)
import qualified System.IO.Streams as Streams
-- | Maximum number of bytes requested from 'SSL.read' for each chunk yielded
-- by the input stream built in 'sslToStreams'.
--
-- NOTE(review): 32752 is 32768 - 16; presumably chosen to stay just under a
-- 32 KiB boundary to leave room for allocation overhead — confirm.
bUFSIZ :: Int
bUFSIZ = 32752
-- | Wrap an established 'SSL' connection in a pair of io-streams.
--
-- The 'InputStream' yields the chunks returned by 'SSL.read' (requesting at
-- most 'bUFSIZ' bytes per read) and reports end-of-stream the first time a
-- read returns an empty 'ByteString'.
--
-- The 'OutputStream' forwards every non-empty chunk to 'SSL.write'. Writing
-- 'Nothing' (end-of-stream) does nothing: it neither shuts down the TLS
-- session nor closes the underlying socket, so that cleanup remains the
-- caller's responsibility.
sslToStreams :: SSL   -- ^ established SSL connection
             -> IO (InputStream ByteString, OutputStream ByteString)
sslToStreams ssl = do
    is <- Streams.makeInputStream input
    os <- Streams.makeOutputStream output
    return $! (is, os)

  where
    input :: IO (Maybe ByteString)
    input = do
        s <- SSL.read ssl bUFSIZ
        -- An empty read is interpreted as end-of-stream.
        return $! if S.null s then Nothing else Just s

    output :: Maybe ByteString -> IO ()
    output Nothing = return $! ()
    output (Just s)
        -- 'SSL.write' presumably ends up in OpenSSL's SSL_write, which
        -- documents a zero-length write as an error (undefined behaviour in
        -- older releases), so empty chunks are dropped here instead.
        | S.null s  = return $! ()
        | otherwise = SSL.write ssl s
-- | Resolve @host@ and @port@, open a TCP connection to the first address
-- returned, and perform a client-side TLS handshake over it using @ctx@.
--
-- Returns io-streams over the TLS session (built by 'sslToStreams') together
-- with the 'SSL' handle. On success the caller owns the connection: nothing
-- here shuts down the TLS session or closes the socket. If connecting,
-- creating the SSL object, or the handshake throws, the socket is closed
-- before the exception propagates.
connect :: SSLContext   -- ^ TLS context used for the connection
        -> HostName     -- ^ host to connect to
        -> PortNumber   -- ^ port to connect to
        -> IO (InputStream ByteString, OutputStream ByteString, SSL)
connect ctx host port = do
    addrInfos <- N.getAddrInfo (Just hints) (Just host) (Just $ show port)
    addrInfo <- case addrInfos of
        (ai:_) -> return ai
        -- getAddrInfo is documented to throw rather than return an empty
        -- list; this branch makes that assumption explicit instead of relying
        -- on a partial pattern match.
        []     -> ioError $ userError $
                      "System.IO.Streams.SSL.connect: no addresses for " ++ host
    let family     = N.addrFamily addrInfo
        socketType = N.addrSocketType addrInfo
        protocol   = N.addrProtocol addrInfo
        address    = N.addrAddress addrInfo

    -- bracketOnError: the socket is only closed here if a later step throws;
    -- on success it is handed to the caller.
    E.bracketOnError (N.socket family socketType protocol) N.close $ \sock -> do
        N.connect sock address
        ssl <- SSL.connection ctx sock
        SSL.connect ssl
        (is, os) <- sslToStreams ssl
        return $! (is, os, ssl)

  where
    -- The port is passed as its decimal string, hence AI_NUMERICSERV.
    hints :: N.AddrInfo
    hints = N.defaultHints { N.addrFlags      = [N.AI_NUMERICSERV]
                           , N.addrSocketType = N.Stream
                           }
-- | Bracketed variant of 'connect'. It opens a TLS connection to @host@ and
-- @port@, then runs @action@ with the connection's streams and 'SSL' handle.
-- Cleanup always runs afterwards, whether @action@ returns normally or
-- throws.
--
-- Cleanup runs with asynchronous exceptions masked and has three steps:
--
-- 1. Write end-of-stream to the output stream.
-- 2. Perform a unidirectional 'SSL.shutdown'.
-- 3. Close the socket.
--
-- Each step is best-effort: an exception thrown by one step is discarded so
-- the remaining steps still run.
withConnection
    :: forall a.
       SSLContext   -- ^ TLS context used for the connection
    -> HostName     -- ^ host to connect to
    -> PortNumber   -- ^ port to connect to
    -> (InputStream ByteString -> OutputStream ByteString -> SSL -> IO a)
       -- ^ action run with the input stream, output stream and 'SSL' handle
    -> IO a         -- ^ result of the action
withConnection ctx host port action = do
    addrInfos <- N.getAddrInfo (Just hints) (Just host) (Just $ show port)
    addrInfo <- case addrInfos of
        (ai:_) -> return ai
        -- getAddrInfo is documented to throw rather than return an empty
        -- list; this branch makes that assumption explicit.
        []     -> ioError $ userError $
                      "System.IO.Streams.SSL.withConnection: no addresses for "
                      ++ host
    E.bracket (connectTo addrInfo) cleanup go

  where
    -- Run the user action; the socket is only needed by 'cleanup'.
    go :: (InputStream ByteString, OutputStream ByteString, SSL, N.Socket)
       -> IO a
    go (is, os, ssl, _) = action is os ssl

    -- Open the socket and complete the TLS handshake. The socket is closed
    -- here only if one of these steps throws.
    connectTo :: N.AddrInfo
              -> IO (InputStream ByteString, OutputStream ByteString, SSL,
                     N.Socket)
    connectTo addrInfo = do
        let family     = N.addrFamily addrInfo
            socketType = N.addrSocketType addrInfo
            protocol   = N.addrProtocol addrInfo
            address    = N.addrAddress addrInfo
        E.bracketOnError (N.socket family socketType protocol) N.close $
            \sock -> do
                N.connect sock address
                ssl <- SSL.connection ctx sock
                SSL.connect ssl
                (is, os) <- sslToStreams ssl
                return $! (is, os, ssl, sock)

    -- Best-effort teardown. Each step is wrapped so that a failure in one
    -- does not prevent the following ones from running.
    cleanup :: (InputStream ByteString, OutputStream ByteString, SSL, N.Socket)
            -> IO ()
    cleanup (_, os, ssl, sock) = E.mask_ $ do
        eatException $ Streams.write Nothing os
        eatException $ SSL.shutdown ssl SSL.Unidirectional
        eatException $ N.close sock

    -- The port is passed as its decimal string, hence AI_NUMERICSERV.
    hints :: N.AddrInfo
    hints = N.defaultHints { N.addrFlags      = [N.AI_NUMERICSERV]
                           , N.addrSocketType = N.Stream
                           }
-- | Run an action for its effects only, discarding its result and any
-- exception it throws.
--
-- Used by 'withConnection' to make each cleanup step best-effort.
--
-- NOTE(review): catching 'E.SomeException' also swallows asynchronous
-- exceptions (e.g. 'E.ThreadKilled') delivered while the action runs, so a
-- thread killed during cleanup may carry on instead of terminating. This is
-- kept deliberately, because re-throwing would skip the remaining cleanup
-- steps. Revisit if that trade-off matters.
eatException :: IO a -> IO ()
eatException m = void m `E.catch` ignore
  where
    ignore :: E.SomeException -> IO ()
    ignore _ = return ()