{-# LANGUAGE ScopedTypeVariables #-}

-- | This module provides convenience functions for interfacing @io-streams@
-- with @HsOpenSSL@. It is intended to be imported @qualified@, e.g.:
--
-- @
-- import qualified "OpenSSL" as SSL
-- import qualified "OpenSSL.Session" as SSL
-- import qualified "System.IO.Streams.SSL" as SSLStreams
--
-- \ example :: IO ('InputStream' 'ByteString', 'OutputStream' 'ByteString', 'SSL.SSL')
-- example = SSL.'SSL.withOpenSSL' $ do
--     ctx <- SSL.'SSL.context'
--     SSL.'SSL.contextSetDefaultCiphers' ctx
--
-- \     \-\- Note: the location of the system certificates is system-dependent,
--     \-\- on Linux systems this is usually \"\/etc\/ssl\/certs\". This
--     \-\- step is optional if you choose to disable certificate verification
--     \-\- (not recommended!).
--     SSL.'SSL.contextSetCADirectory' ctx \"\/etc\/ssl\/certs\"
--     SSL.'SSL.contextSetVerificationMode' ctx $
--         SSL.'SSL.VerifyPeer' True True Nothing
--     SSLStreams.'connect' ctx \"foo.com\" 4444
-- @
--

module System.IO.Streams.SSL
  ( connect
  , withConnection
  , sslToStreams
  ) where

import qualified Control.Exception     as E
import           Control.Monad         (void)
import           Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as S
import           Network.Socket        (HostName, PortNumber)
import qualified Network.Socket        as N
import           OpenSSL.Session       (SSL, SSLContext)
import qualified OpenSSL.Session       as SSL
import           System.IO.Streams     (InputStream, OutputStream)
import qualified System.IO.Streams     as Streams


------------------------------------------------------------------------------
-- | Maximum number of bytes requested from 'SSL.read' for each chunk produced
-- by the input stream built in 'sslToStreams'.
--
-- Written as 32 KiB minus 16 bytes (= 32752). NOTE(review): the reason for the
-- 16-byte slack is not visible here (presumably allocator/header overhead) —
-- confirm before changing.
bUFSIZ :: Int
bUFSIZ = 32 * 1024 - 16


------------------------------------------------------------------------------
-- | Wrap an already-established HsOpenSSL 'SSL' connection in an
-- 'InputStream' \/ 'OutputStream' pair.
--
-- Each read from the input stream asks 'SSL.read' for at most 'bUFSIZ' bytes;
-- an empty result is reported as end-of-input ('Nothing'). Writing a chunk to
-- the output stream forwards it to 'SSL.write'. Writing end-of-stream
-- ('Nothing') to the output stream does nothing: it neither shuts down the SSL
-- session nor closes the underlying socket.
sslToStreams :: SSL             -- ^ SSL connection object
             -> IO (InputStream ByteString, OutputStream ByteString)
sslToStreams ssl = do
    inStream  <- Streams.makeInputStream readChunk
    outStream <- Streams.makeOutputStream writeChunk
    return $! (inStream, outStream)

  where
    -- Pull one chunk off the SSL connection; an empty chunk means EOF.
    readChunk :: IO (Maybe ByteString)
    readChunk = do
        chunk <- SSL.read ssl bUFSIZ
        let result | S.null chunk = Nothing
                   | otherwise    = Just chunk
        result `seq` return result

    -- Forward data to the SSL connection; ignore end-of-stream.
    writeChunk :: Maybe ByteString -> IO ()
    writeChunk = maybe (return ()) (SSL.write ssl)


------------------------------------------------------------------------------
-- | Convenience function for initiating an SSL connection to the given
-- @('HostName', 'PortNumber')@ combination.
--
-- The host is resolved with 'N.getAddrInfo' (the port is passed as a numeric
-- service string) and only the first returned address is tried. If connecting
-- the socket or running the TLS handshake throws, the socket is closed before
-- the exception propagates.
--
-- This function itself does not configure SNI or hostname verification on the
-- 'SSL' object; any certificate checking comes from the settings on the
-- supplied 'SSLContext'.
--
-- Note that sending an end-of-file to the returned 'OutputStream' will not
-- close the underlying SSL connection; to do that, call:
--
-- @
-- SSL.'SSL.shutdown' ssl SSL.'SSL.Unidirectional'
-- maybe (return ()) 'N.close' $ SSL.'SSL.sslSocket' ssl
-- @
--
-- on the returned 'SSL' object.
connect :: SSLContext           -- ^ SSL context. See the @HsOpenSSL@
                                -- documentation for information on creating
                                -- this.
        -> HostName             -- ^ hostname to connect to
        -> PortNumber           -- ^ port number to connect to
        -> IO (InputStream ByteString, OutputStream ByteString, SSL)
connect ctx host port = do
    -- 'N.getAddrInfo' throws an exception instead of returning an empty list,
    -- so this partial pattern does not fail in practice.
    (target:_) <- N.getAddrInfo (Just hints) (Just host) (Just $ show port)
    E.bracketOnError (openSocket target) N.close (handshake target)

  where
    hints :: N.AddrInfo
    hints = N.defaultHints { N.addrFlags      = [N.AI_NUMERICSERV]
                           , N.addrSocketType = N.Stream
                           }

    -- Create an unconnected socket matching the resolved address.
    openSocket :: N.AddrInfo -> IO N.Socket
    openSocket ai =
        N.socket (N.addrFamily ai) (N.addrSocketType ai) (N.addrProtocol ai)

    -- Connect the TCP socket, run the TLS client handshake on it, and wrap
    -- the resulting SSL session in streams.
    handshake :: N.AddrInfo
              -> N.Socket
              -> IO (InputStream ByteString, OutputStream ByteString, SSL)
    handshake ai sock = do
        N.connect sock (N.addrAddress ai)
        session <- SSL.connection ctx sock
        SSL.connect session
        (inStream, outStream) <- sslToStreams session
        return $! (inStream, outStream, session)


------------------------------------------------------------------------------
-- | Convenience function for initiating an SSL connection to the given
-- @('HostName', 'PortNumber')@ combination. The socket and SSL connection are
-- closed and deleted after the user handler runs.
--
-- Only the first address returned by 'N.getAddrInfo' is tried. Cleanup runs
-- whether the handler returns or throws, and performs three steps: write
-- end-of-stream to the output stream, send a unidirectional SSL shutdown, and
-- close the socket. Synchronous exceptions from these steps are ignored
-- (best effort). The socket close always runs. Asynchronous exceptions (e.g.
-- 'E.ThreadKilled') that arrive during cleanup are re-thrown rather than
-- swallowed.
--
-- /Since: 1.2.0.0./
withConnection ::
     SSLContext           -- ^ SSL context. See the @HsOpenSSL@
                          -- documentation for information on creating
                          -- this.
  -> HostName             -- ^ hostname to connect to
  -> PortNumber           -- ^ port number to connect to
  -> (InputStream ByteString -> OutputStream ByteString -> SSL -> IO a)
          -- ^ Action to run with the new connection
  -> IO a
withConnection ctx host port action = do
    -- 'N.getAddrInfo' throws an exception instead of returning an empty list,
    -- so this partial pattern does not fail in practice.
    (addrInfo:_) <- N.getAddrInfo (Just hints) (Just host) (Just $ show port)
    E.bracket (connectTo addrInfo) cleanup go

  where
    go (is, os, ssl, _) = action is os ssl

    -- Open a socket for the resolved address, connect it, and run the TLS
    -- client handshake. If any step throws, the socket is closed before the
    -- exception propagates.
    connectTo :: N.AddrInfo
              -> IO (InputStream ByteString, OutputStream ByteString, SSL, N.Socket)
    connectTo addrInfo = do
        let family     = N.addrFamily addrInfo
        let socketType = N.addrSocketType addrInfo
        let protocol   = N.addrProtocol addrInfo
        let address    = N.addrAddress addrInfo
        E.bracketOnError (N.socket family socketType protocol)
                         N.close
                         (\sock -> do N.connect sock address
                                      ssl <- SSL.connection ctx sock
                                      SSL.connect ssl
                                      (is, os) <- sslToStreams ssl
                                      return $! (is, os, ssl, sock))

    -- Best-effort teardown. 'E.bracket' already runs its release action with
    -- asynchronous exceptions masked, so no extra 'E.mask_' is needed. The
    -- earlier steps can block (interruptibly), so an asynchronous exception
    -- may be delivered there. 'E.finally' guarantees the socket is still
    -- closed in that case before the exception is re-thrown.
    cleanup :: (InputStream ByteString, OutputStream ByteString, SSL, N.Socket)
            -> IO ()
    cleanup (_, os, ssl, sock) =
        (do eatSyncException $ Streams.write Nothing os
            eatSyncException $ SSL.shutdown ssl SSL.Unidirectional)
        `E.finally` eatSyncException (N.close sock)

    hints :: N.AddrInfo
    hints = N.defaultHints { N.addrFlags      = [N.AI_NUMERICSERV]
                           , N.addrSocketType = N.Stream
                           }

    -- Run an action, discarding its result and any synchronous exception it
    -- throws. Asynchronous exceptions (anything wrapped in
    -- 'E.SomeAsyncException', such as 'E.ThreadKilled') are re-thrown so that
    -- killing the calling thread still works during cleanup.
    eatSyncException :: IO b -> IO ()
    eatSyncException m = void m `E.catch` \(e :: E.SomeException) ->
        case E.fromException e :: Maybe E.SomeAsyncException of
          Just _  -> E.throwIO e
          Nothing -> return ()