Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
16 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions libs/cassandra-util/src/Cassandra.hs
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ import Cassandra.Exec as C
( BatchM,
Client,
ClientState,
GeneralPaginationState (..),
MonadClient,
Page (..),
PageWithState (..),
Expand All @@ -74,6 +75,8 @@ import Cassandra.Exec as C
paginate,
paginateC,
paginateWithState,
paginationStateCassandra,
paginationStatePostgres,
params,
paramsP,
paramsPagingState,
Expand Down
35 changes: 26 additions & 9 deletions libs/cassandra-util/src/Cassandra/Exec.hs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ module Cassandra.Exec
x5,
x1,
paginateC,
GeneralPaginationState (..),
paginationStateCassandra,
paginationStatePostgres,
PageWithState (..),
paginateWithState,
paginateWithStateC,
Expand Down Expand Up @@ -97,23 +100,37 @@ paginateC q p r = go =<< lift (retry r (paginate q p))
when (hasMore page) $
go =<< lift (retry r (liftClient (nextPage page)))

data PageWithState a = PageWithState
{ pwsResults :: [a],
pwsState :: Maybe Protocol.PagingState
data GeneralPaginationState a
= PaginationStateCassandra Protocol.PagingState
| PaginationStatePostgres a

paginationStateCassandra :: GeneralPaginationState pgState -> Maybe Protocol.PagingState
paginationStateCassandra = \case
PaginationStateCassandra state -> Just state
PaginationStatePostgres {} -> Nothing

paginationStatePostgres :: GeneralPaginationState pgState -> Maybe pgState
paginationStatePostgres = \case
PaginationStatePostgres pgState -> Just pgState
PaginationStateCassandra {} -> Nothing

data PageWithState state res = PageWithState
{ pwsResults :: [res],
pwsState :: Maybe (GeneralPaginationState state)
}
deriving (Functor)

-- | Like 'paginate' but exposes the paging state. This paging state can be
-- serialised and sent to consumers of the API. The state is not good for long
-- term storage as the bytestring format may change when the schema of a table
-- changes or when cassandra is upgraded.
paginateWithState :: (MonadClient m, Tuple a, Tuple b, RunQ q) => q R a b -> QueryParams a -> m (PageWithState b)
paginateWithState :: (MonadClient m, Tuple a, Tuple b, RunQ q) => q R a b -> QueryParams a -> m (PageWithState x b)
paginateWithState q p = do
let p' = p {Protocol.pageSize = Protocol.pageSize p <|> Just 10000}
r <- runQ q p'
getResult r >>= \case
Protocol.RowsResult m b ->
pure $ PageWithState b (pagingState m)
pure $ PageWithState b (PaginationStateCassandra <$> pagingState m)
_ -> throwM $ UnexpectedResponse (hrHost r) (hrResponse r)

-- | Like 'paginateWithState' but returns a conduit instead of one page.
Expand All @@ -128,20 +145,20 @@ paginateWithState q p = do
-- where
-- getUsers state = paginateWithState getUsersQuery (paramsPagingState Quorum () 10000 state)
-- @
paginateWithStateC :: forall m a. (Monad m) => (Maybe Protocol.PagingState -> m (PageWithState a)) -> ConduitT () [a] m ()
paginateWithStateC :: forall m res state. (Monad m) => (Maybe (GeneralPaginationState state) -> m (PageWithState state res)) -> ConduitT () [res] m ()
paginateWithStateC getPage = do
go =<< lift (getPage Nothing)
where
go :: PageWithState a -> ConduitT () [a] m ()
go :: PageWithState state res -> ConduitT () [res] m ()
go page = do
unless (null page.pwsResults) $
yield (page.pwsResults)
when (pwsHasMore page) $
go =<< lift (getPage page.pwsState)
go =<< lift (getPage $ page.pwsState)

paramsPagingState :: Consistency -> a -> Int32 -> Maybe Protocol.PagingState -> QueryParams a
paramsPagingState c p n state = QueryParams c False p (Just n) state Nothing Nothing
{-# INLINE paramsPagingState #-}

pwsHasMore :: PageWithState a -> Bool
pwsHasMore :: PageWithState a b -> Bool
pwsHasMore = isJust . pwsState
2 changes: 1 addition & 1 deletion libs/polysemy-wire-zoo/src/Wire/Sem/Paging/Cassandra.hs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ data CassandraPaging

type instance E.PagingState CassandraPaging a = PagingState

type instance E.Page CassandraPaging a = PageWithState a
type instance E.Page CassandraPaging a = PageWithState Void a

type instance E.PagingBounds CassandraPaging TeamId = Range 1 100 Int32

Expand Down
7 changes: 7 additions & 0 deletions libs/wire-api/src/Wire/API/Asset.hs
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ import Imports
import Servant
import URI.ByteString
import Wire.API.Error
import Wire.API.PostgresMarshall
import Wire.API.Routes.MultiVerb
import Wire.Arbitrary (Arbitrary (..), GenericUniform (..))

Expand Down Expand Up @@ -200,6 +201,12 @@ instance C.Cql AssetKey where
fromCql (C.CqlText txt) = runParser parser . T.encodeUtf8 $ txt
fromCql _ = Left "AssetKey: Text expected"

instance PostgresMarshall Text AssetKey where
postgresMarshall = assetKeyToText

instance PostgresUnmarshall Text AssetKey where
postgresUnmarshall = mapLeft (\e -> "failed to parse AssetKey: " <> T.pack e) . runParser parser . T.encodeUtf8

--------------------------------------------------------------------------------
-- AssetToken

Expand Down
17 changes: 17 additions & 0 deletions libs/wire-api/src/Wire/API/Locale.hs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ import Data.Time.Format
import Data.Time.LocalTime (TimeZone (..), utc)
import Imports
import Test.QuickCheck
import Wire.API.PostgresMarshall
import Wire.API.User.Orphans ()
import Wire.Arbitrary

Expand Down Expand Up @@ -181,6 +182,14 @@ instance C.Cql Language where
Nothing -> Left "Language: ISO 639-1 expected."
fromCql _ = Left "Language: ASCII expected"

instance PostgresMarshall Text Language where
postgresMarshall = lan2Text

instance PostgresUnmarshall Text Language where
postgresUnmarshall =
mapLeft (\e -> "failed to parse Language: " <> Text.pack e)
. parseOnly languageParser

languageParser :: Parser Language
languageParser = codeParser "language" $ fmap Language . checkAndConvert isLower

Expand All @@ -206,6 +215,14 @@ instance C.Cql Country where
Nothing -> Left "Country: ISO 3166-1-alpha2 expected."
fromCql _ = Left "Country: ASCII expected"

instance PostgresMarshall Text Country where
postgresMarshall = con2Text

instance PostgresUnmarshall Text Country where
postgresUnmarshall =
mapLeft (\e -> "failed to parse Country: " <> Text.pack e)
. parseOnly countryParser

countryParser :: Parser Country
countryParser = codeParser "country" $ fmap Country . checkAndConvert isUpper

Expand Down
16 changes: 11 additions & 5 deletions libs/wire-api/src/Wire/API/Password.hs
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,13 @@ import Data.ByteString.Lazy (fromStrict, toStrict)
import Data.Misc
import Data.OpenApi qualified as S
import Data.Schema
import Data.Text qualified as Text
import Data.Text.Encoding qualified as Text
import Imports
import OpenSSL.Random (randBytes)
import Wire.API.Password.Argon2id
import Wire.API.Password.Scrypt
import Wire.API.PostgresMarshall

-- | A derived, stretched password that can be safely stored.
data Password
Expand All @@ -56,11 +58,15 @@ instance Cql Password where
fromCql (CqlBlob lbs) = parsePassword . Text.decodeUtf8 . toStrict $ lbs
fromCql _ = Left "password: expected blob"

toCql pw = CqlBlob . fromStrict $ Text.encodeUtf8 encoded
where
encoded = case pw of
Argon2Password argon2pw -> encodeArgon2HashedPassword argon2pw
ScryptPassword scryptpw -> encodeScryptPassword scryptpw
toCql = CqlBlob . fromStrict . Text.encodeUtf8 . postgresMarshall

instance PostgresMarshall Text Password where
postgresMarshall = \case
Argon2Password argon2pw -> encodeArgon2HashedPassword argon2pw
ScryptPassword scryptpw -> encodeScryptPassword scryptpw

instance PostgresUnmarshall Text Password where
postgresUnmarshall = mapLeft Text.pack . parsePassword

-------------------------------------------------------------------------------

Expand Down
29 changes: 29 additions & 0 deletions libs/wire-api/src/Wire/API/PostgresMarshall.hs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
module Wire.API.PostgresMarshall
( PostgresMarshall (..),
PostgresUnmarshall (..),
StoreAsJSON (..),
lmapPG,
rmapPG,
dimapPG,
Expand All @@ -29,12 +30,15 @@ import Data.Bifunctor (first)
import Data.ByteString qualified as BS
import Data.ByteString.Conversion qualified as BSC
import Data.Domain
import Data.Handle
import Data.Id
import Data.Json.Util (UTCTimeMillis (fromUTCTimeMillis), toUTCTimeMillis)
import Data.Misc
import Data.Profunctor
import Data.Set qualified as Set
import Data.Text qualified as Text
import Data.Text.Encoding qualified as Text
import Data.Time (UTCTime)
import Data.UUID
import Data.Vector (Vector)
import Data.Vector qualified as V
Expand Down Expand Up @@ -518,6 +522,12 @@ instance PostgresMarshall Int64 Milliseconds where
instance PostgresMarshall Text Domain where
postgresMarshall = domainText

instance PostgresMarshall Text Handle where
postgresMarshall = fromHandle

instance PostgresMarshall UTCTime UTCTimeMillis where
postgresMarshall = fromUTCTimeMillis

instance (PostgresMarshall a b) => PostgresMarshall (Maybe a) (Maybe b) where
postgresMarshall = fmap postgresMarshall

Expand Down Expand Up @@ -855,6 +865,12 @@ instance (PostgresUnmarshall a b, Ord b) => PostgresUnmarshall (Vector a) (Set b
instance PostgresUnmarshall Int64 Milliseconds where
postgresUnmarshall = Right . int64ToMs

instance PostgresUnmarshall Text Handle where
postgresUnmarshall = mapLeft Text.pack . parseHandleEither

instance PostgresUnmarshall UTCTime UTCTimeMillis where
postgresUnmarshall = Right . toUTCTimeMillis

---

lmapPG :: (PostgresMarshall db domain, Profunctor p) => p db x -> p domain x
Expand All @@ -868,3 +884,16 @@ dimapPG ::
Statement dbIn dbOut ->
Statement domainIn domainOut
dimapPG = refineResult postgresUnmarshall . lmapPG

---

newtype StoreAsJSON a = StoreAsJSON a

instance (ToJSON a) => PostgresMarshall Value (StoreAsJSON a) where
postgresMarshall (StoreAsJSON a) = toJSON a

instance (FromJSON a) => PostgresUnmarshall Value (StoreAsJSON a) where
postgresUnmarshall v =
case fromJSON v of
Error e -> Left $ Text.pack e
Success a -> Right $ StoreAsJSON a
9 changes: 7 additions & 2 deletions libs/wire-api/src/Wire/API/Team/Member.hs
Original file line number Diff line number Diff line change
Expand Up @@ -263,8 +263,13 @@ instance ToSchema TeamMembersPage where

type TeamMembersPagingState = MultiTablePagingState TeamMembersPagingName TeamMembersTable

teamMemberPagingState :: PageWithState TeamMember -> TeamMembersPagingState
teamMemberPagingState p = MultiTablePagingState TeamMembersTable (LBS.toStrict . C.unPagingState <$> pwsState p)
teamMemberPagingState :: PageWithState Void TeamMember -> TeamMembersPagingState
teamMemberPagingState p =
MultiTablePagingState
TeamMembersTable
( LBS.toStrict . C.unPagingState
<$> (C.paginationStateCassandra =<< p.pwsState)
)

instance ToParamSchema TeamMembersPagingState where
toParamSchema _ = toParamSchema (Proxy @Text)
Expand Down
41 changes: 28 additions & 13 deletions libs/wire-api/src/Wire/API/User.hs
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ import Data.Schema
import Data.Schema qualified as Schema
import Data.Set qualified as Set
import Data.Text qualified as T
import Data.Text qualified as Text
import Data.Text.Ascii
import Data.Text.Encoding qualified as T
import Data.Text.Encoding.Error
Expand All @@ -205,6 +206,7 @@ import Wire.API.Error.Brig
import Wire.API.Error.Brig qualified as E
import Wire.API.Locale
import Wire.API.Password
import Wire.API.PostgresMarshall
import Wire.API.Provider.Service (ServiceRef)
import Wire.API.Routes.MultiVerb
import Wire.API.Team
Expand Down Expand Up @@ -1838,21 +1840,28 @@ instance Schema.ToSchema AccountStatus where
instance C.Cql AccountStatus where
ctype = C.Tagged C.IntColumn

toCql Active = C.CqlInt 0
toCql Suspended = C.CqlInt 1
toCql Deleted = C.CqlInt 2
toCql Ephemeral = C.CqlInt 3
toCql PendingInvitation = C.CqlInt 4

fromCql (C.CqlInt i) = case i of
0 -> pure Active
1 -> pure Suspended
2 -> pure Deleted
3 -> pure Ephemeral
4 -> pure PendingInvitation
n -> Left $ "unexpected account status: " ++ show n
toCql = C.CqlInt . postgresMarshall

fromCql (C.CqlInt i) = mapLeft Text.unpack $ postgresUnmarshall i
fromCql _ = Left "account status: int expected"

instance PostgresMarshall Int32 AccountStatus where
postgresMarshall = \case
Active -> 0
Suspended -> 1
Deleted -> 2
Ephemeral -> 3
PendingInvitation -> 4

instance PostgresUnmarshall Int32 AccountStatus where
postgresUnmarshall = \case
0 -> Right Active
1 -> Right Suspended
2 -> Right Deleted
3 -> Right Ephemeral
4 -> Right PendingInvitation
n -> Left $ "unexpected account status: " <> Text.show n

data AccountStatusResp = AccountStatusResp {fromAccountStatusResp :: AccountStatus}
deriving (Eq, Show, Generic)
deriving (Arbitrary) via (GenericUniform AccountStatusResp)
Expand Down Expand Up @@ -1992,6 +2001,12 @@ instance C.Cql (Imports.Set BaseProtocolTag) where
fromCql (C.CqlInt bits) = pure $ protocolSetFromBits (fromIntegral bits)
fromCql _ = Left "Protocol set: Int expected"

instance PostgresMarshall Int32 (Imports.Set BaseProtocolTag) where
postgresMarshall = fromIntegral . protocolSetBits

instance PostgresUnmarshall Int32 (Imports.Set BaseProtocolTag) where
postgresUnmarshall = Right . protocolSetFromBits . fromIntegral

baseProtocolMask :: BaseProtocolTag -> Word32
baseProtocolMask BaseProtocolProteusTag = 1
baseProtocolMask BaseProtocolMLSTag = 2
Expand Down
9 changes: 9 additions & 0 deletions libs/wire-api/src/Wire/API/User/EmailAddress.hs
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ import Servant.API qualified as S
import Test.QuickCheck
import Text.Email.Parser
import Text.Email.Validate
import Wire.API.PostgresMarshall

--------------------------------------------------------------------------------
-- Email
Expand Down Expand Up @@ -103,6 +104,14 @@ instance C.Cql EmailAddress where

toCql = C.toCql . fromEmail

instance PostgresMarshall Text EmailAddress where
postgresMarshall = fromEmail

instance PostgresUnmarshall Text EmailAddress where
postgresUnmarshall t = case emailAddressText t of
Just e -> Right e
Nothing -> Left "postgresUnmarshall: Invalid email"

fromEmail :: EmailAddress -> Text
fromEmail = decodeUtf8 . toByteString

Expand Down
2 changes: 2 additions & 0 deletions libs/wire-api/src/Wire/API/User/Identity.hs
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ import Text.Email.Parser
import URI.ByteString qualified as URI
import URI.ByteString.QQ (uri)
import Web.Scim.Schema.User.Email ()
import Wire.API.PostgresMarshall
import Wire.API.User.EmailAddress
import Wire.API.User.Phone
import Wire.API.User.Profile (fromName, mkName)
Expand Down Expand Up @@ -150,6 +151,7 @@ data UserSSOId
| UserScimExternalId Text
deriving stock (Eq, Show, Generic)
deriving (Arbitrary) via (GenericUniform UserSSOId)
deriving (PostgresMarshall A.Value, PostgresUnmarshall A.Value) via (StoreAsJSON UserSSOId)

isUserSSOId :: UserSSOId -> Bool
isUserSSOId (UserSSOId _) = True
Expand Down
Loading