Db.hs 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155
  1. {-# OPTIONS -funbox-strict-fields #-}
  2. {-# LANGUAGE OverloadedStrings #-}
  3. module TFB.Db (
  4. Pool
  5. , mkPool
  6. , Config(..)
  7. , queryWorldById
  8. , queryWorldByIds
  9. , updateWorlds
  10. , queryFortunes
  11. , Error
  12. ) where
  13. import qualified TFB.Types as Types
  14. import qualified Data.Either as Either
  15. import Control.Monad (forM, forM_)
  16. import qualified Data.Pool as Pool
  17. import Data.ByteString (ByteString)
  18. import qualified Data.ByteString.Char8 as BSC
  19. import qualified Database.MySQL.Base as MySQL
  20. import qualified System.IO.Streams as Streams
  21. import Data.Text (Text)
  22. import qualified Data.Text as Text
  23. -------------------------------------------------------------------------------
  24. -- * Database
  25. data Config
  26. = Config
  27. { configHost :: String
  28. , configName :: ByteString
  29. , configUser :: ByteString
  30. , configPass :: ByteString
  31. , configStripes :: Int
  32. , configPoolSize :: Int
  33. }
  34. instance Show Config where
  35. show c
  36. = "Config {"
  37. <> " configHost = " <> configHost c
  38. <> ", configName = " <> BSC.unpack (configName c)
  39. <> ", configUser = " <> BSC.unpack (configUser c)
  40. <> ", configPass = REDACTED"
  41. <> ", configStripes = " <> show (configStripes c)
  42. <> ", configPoolSize = " <> show (configPoolSize c)
  43. <> " }"
  44. type Connection = MySQL.MySQLConn
  45. type Pool = Pool.Pool Connection
  46. type Error = Text
  47. type DbRow = [MySQL.MySQLValue]
  48. connect :: Config -> IO Connection
  49. connect c = MySQL.connect myc
  50. where
  51. myc = MySQL.defaultConnectInfoMB4
  52. { MySQL.ciHost = configHost c
  53. , MySQL.ciDatabase = configName c
  54. , MySQL.ciUser = configUser c
  55. , MySQL.ciPassword = configPass c
  56. }
  57. close :: Connection -> IO ()
  58. close = MySQL.close
  59. mkPool :: Config -> IO Pool
  60. mkPool c = Pool.createPool (connect c) close (configStripes c) 0.5 (configPoolSize c)
  61. {-# SPECIALIZE intValEnc :: Int -> MySQL.MySQLValue #-}
  62. {-# SPECIALIZE intValEnc :: Types.QId -> MySQL.MySQLValue #-}
  63. intValEnc :: Integral a => a -> MySQL.MySQLValue
  64. intValEnc = MySQL.MySQLInt16U . fromIntegral
  65. intValDec :: MySQL.MySQLValue -> Either Text Int
  66. intValDec (MySQL.MySQLInt8U i) = pure . fromIntegral $ i
  67. intValDec (MySQL.MySQLInt8 i) = pure . fromIntegral $ i
  68. intValDec (MySQL.MySQLInt16U i) = pure . fromIntegral $ i
  69. intValDec (MySQL.MySQLInt16 i) = pure . fromIntegral $ i
  70. intValDec (MySQL.MySQLInt32U i) = pure . fromIntegral $ i
  71. intValDec (MySQL.MySQLInt32 i) = pure . fromIntegral $ i
  72. intValDec (MySQL.MySQLInt64U i) = pure . fromIntegral $ i
  73. intValDec (MySQL.MySQLInt64 i) = pure . fromIntegral $ i
  74. intValDec x = Left $ "Expected MySQLInt*, received" <> (Text.pack $ show x)
  75. textValDec :: MySQL.MySQLValue -> Either Text Text
  76. textValDec (MySQL.MySQLText t) = pure t
  77. textValDec x = Left $ "Expected Text, received" <> (Text.pack $ show x)
  78. -------------------------------------------------------------------------------
  79. -- * World
  80. decodeWorld :: DbRow -> Either Error Types.World
  81. decodeWorld [] = Left "MarshalError: Expected 2 columns for World, found 0"
  82. decodeWorld (_:[]) = Left "MarshalError: Expected 2 columns for World, found 1"
  83. decodeWorld (c1:c2:_) = Types.World <$> intValDec c1 <*> intValDec c2
  84. queryWorldById :: Pool -> Types.QId -> IO (Either Error Types.World)
  85. queryWorldById dbPool wId = Pool.withResource dbPool $ \conn -> do
  86. (_, rowsS) <- MySQL.query conn s [intValEnc wId]
  87. rows <- Streams.toList rowsS
  88. let eWorlds = fmap decodeWorld rows
  89. let (err, oks) = Either.partitionEithers eWorlds
  90. return $ case err of
  91. [] -> case oks of
  92. [] -> Left "World not found!"
  93. ws -> pure $ head ws
  94. _ -> Left . mconcat $ err
  95. where
  96. s = "SELECT * FROM World WHERE id = ?"
  97. queryWorldByIds :: Pool -> [Types.QId] -> IO (Either Error [Types.World])
  98. queryWorldByIds _ [] = pure . pure $ mempty
  99. queryWorldByIds dbPool wIds = Pool.withResource dbPool $ \conn -> do
  100. sId <- MySQL.prepareStmt conn "SELECT * FROM World WHERE id = ?"
  101. res <- forM wIds $ \wId -> do
  102. (_, rowsS) <- MySQL.queryStmt conn sId [intValEnc wId]
  103. rows <- Streams.toList rowsS
  104. return . fmap decodeWorld $ rows
  105. MySQL.closeStmt conn sId
  106. let (errs, ws) = Either.partitionEithers . mconcat $ res
  107. return $ case errs of
  108. [] -> pure ws
  109. _ -> Left . mconcat $ errs
  110. updateWorlds :: Pool -> [(Types.World, Int)] -> IO (Either Error [Types.World])
  111. updateWorlds _ [] = pure . pure $ mempty
  112. updateWorlds dbPool wsUpdates = Pool.withResource dbPool $ \conn -> do
  113. let ws = fmap updateW wsUpdates
  114. sId <- MySQL.prepareStmt conn "UPDATE World SET randomNumber = ? WHERE id = ?"
  115. forM_ wsUpdates $ \(w, wNum) ->
  116. MySQL.executeStmt conn sId [intValEnc wNum, intValEnc $ Types.wId w]
  117. MySQL.closeStmt conn sId
  118. return . pure $ ws
  119. where
  120. updateW (w,wNum) = w { Types.wRandomNumber = wNum }
  121. -------------------------------------------------------------------------------
  122. -- * Fortunes
  123. decodeFortune :: DbRow -> Either Error Types.Fortune
  124. decodeFortune [] = Left "MarshalError: Expected 2 columns for Fortune, found 0"
  125. decodeFortune (_:[]) = Left "MarshalError: Expected 2 columns for Fortune, found 1"
  126. decodeFortune (c1:c2:_) = Types.Fortune <$> intValDec c1 <*> textValDec c2
  127. queryFortunes :: Pool -> IO (Either Error [Types.Fortune])
  128. queryFortunes dbPool = Pool.withResource dbPool $ \conn -> do
  129. (_, rowsS) <- MySQL.query_ conn "SELECT * FROM Fortune"
  130. rows <- Streams.toList rowsS
  131. let eFortunes = fmap decodeFortune rows
  132. let (err, oks) = Either.partitionEithers eFortunes
  133. return $ case err of
  134. [] -> pure oks
  135. _ -> Left $ head err