mas_storage_pg/
repository.rs

1// Copyright 2024 New Vector Ltd.
2// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only
5// Please see LICENSE in the repository root for full details.
6
7use std::ops::{Deref, DerefMut};
8
9use async_trait::async_trait;
10use futures_util::{FutureExt, TryFutureExt, future::BoxFuture};
11use mas_storage::{
12    BoxRepository, BoxRepositoryFactory, MapErr, Repository, RepositoryAccess, RepositoryError,
13    RepositoryFactory, RepositoryTransaction,
14    app_session::AppSessionRepository,
15    compat::{
16        CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository,
17        CompatSsoLoginRepository,
18    },
19    oauth2::{
20        OAuth2AccessTokenRepository, OAuth2AuthorizationGrantRepository, OAuth2ClientRepository,
21        OAuth2DeviceCodeGrantRepository, OAuth2RefreshTokenRepository, OAuth2SessionRepository,
22    },
23    policy_data::PolicyDataRepository,
24    queue::{QueueJobRepository, QueueScheduleRepository, QueueWorkerRepository},
25    upstream_oauth2::{
26        UpstreamOAuthLinkRepository, UpstreamOAuthProviderRepository,
27        UpstreamOAuthSessionRepository,
28    },
29    user::{
30        BrowserSessionRepository, UserEmailRepository, UserPasswordRepository,
31        UserRecoveryRepository, UserRegistrationRepository, UserRegistrationTokenRepository,
32        UserRepository, UserTermsRepository,
33    },
34};
35use sqlx::{PgConnection, PgPool, Postgres, Transaction};
36use tracing::Instrument;
37
38use crate::{
39    DatabaseError,
40    app_session::PgAppSessionRepository,
41    compat::{
42        PgCompatAccessTokenRepository, PgCompatRefreshTokenRepository, PgCompatSessionRepository,
43        PgCompatSsoLoginRepository,
44    },
45    oauth2::{
46        PgOAuth2AccessTokenRepository, PgOAuth2AuthorizationGrantRepository,
47        PgOAuth2ClientRepository, PgOAuth2DeviceCodeGrantRepository,
48        PgOAuth2RefreshTokenRepository, PgOAuth2SessionRepository,
49    },
50    policy_data::PgPolicyDataRepository,
51    queue::{
52        job::PgQueueJobRepository, schedule::PgQueueScheduleRepository,
53        worker::PgQueueWorkerRepository,
54    },
55    telemetry::DB_CLIENT_CONNECTIONS_CREATE_TIME_HISTOGRAM,
56    upstream_oauth2::{
57        PgUpstreamOAuthLinkRepository, PgUpstreamOAuthProviderRepository,
58        PgUpstreamOAuthSessionRepository,
59    },
60    user::{
61        PgBrowserSessionRepository, PgUserEmailRepository, PgUserPasswordRepository,
62        PgUserRecoveryRepository, PgUserRegistrationRepository, PgUserRegistrationTokenRepository,
63        PgUserRepository, PgUserTermsRepository,
64    },
65};
66
67/// An implementation of the [`RepositoryFactory`] trait backed by a PostgreSQL
68/// connection pool.
69#[derive(Clone)]
70pub struct PgRepositoryFactory {
71    pool: PgPool,
72}
73
74impl PgRepositoryFactory {
75    /// Create a new [`PgRepositoryFactory`] from a PostgreSQL connection pool.
76    #[must_use]
77    pub fn new(pool: PgPool) -> Self {
78        Self { pool }
79    }
80
81    /// Box the factory
82    #[must_use]
83    pub fn boxed(self) -> BoxRepositoryFactory {
84        Box::new(self)
85    }
86
87    /// Get the underlying PostgreSQL connection pool
88    #[must_use]
89    pub fn pool(&self) -> PgPool {
90        self.pool.clone()
91    }
92}
93
94#[async_trait]
95impl RepositoryFactory for PgRepositoryFactory {
96    async fn create(&self) -> Result<BoxRepository, RepositoryError> {
97        let start = std::time::Instant::now();
98        let repo = PgRepository::from_pool(&self.pool)
99            .await
100            .map_err(RepositoryError::from_error)?
101            .boxed();
102
103        // Measure the time it took to create the connection
104        let duration = start.elapsed();
105        let duration_ms = duration.as_millis().try_into().unwrap_or(u64::MAX);
106        DB_CLIENT_CONNECTIONS_CREATE_TIME_HISTOGRAM.record(duration_ms, &[]);
107
108        Ok(repo)
109    }
110}
111
112/// An implementation of the [`Repository`] trait backed by a PostgreSQL
113/// transaction.
114pub struct PgRepository<C = Transaction<'static, Postgres>> {
115    conn: C,
116}
117
118impl PgRepository {
119    /// Create a new [`PgRepository`] from a PostgreSQL connection pool,
120    /// starting a transaction.
121    ///
122    /// # Errors
123    ///
124    /// Returns a [`DatabaseError`] if the transaction could not be started.
125    pub async fn from_pool(pool: &PgPool) -> Result<Self, DatabaseError> {
126        let txn = pool.begin().await?;
127        Ok(Self::from_conn(txn))
128    }
129
130    /// Transform the repository into a type-erased [`BoxRepository`]
131    pub fn boxed(self) -> BoxRepository {
132        Box::new(MapErr::new(self, RepositoryError::from_error))
133    }
134}
135
136impl<C> PgRepository<C> {
137    /// Create a new [`PgRepository`] from an existing PostgreSQL connection
138    /// with a transaction
139    pub fn from_conn(conn: C) -> Self {
140        PgRepository { conn }
141    }
142
143    /// Consume this [`PgRepository`], returning the underlying connection.
144    pub fn into_inner(self) -> C {
145        self.conn
146    }
147}
148
149impl<C> AsRef<C> for PgRepository<C> {
150    fn as_ref(&self) -> &C {
151        &self.conn
152    }
153}
154
155impl<C> AsMut<C> for PgRepository<C> {
156    fn as_mut(&mut self) -> &mut C {
157        &mut self.conn
158    }
159}
160
161impl<C> Deref for PgRepository<C> {
162    type Target = C;
163
164    fn deref(&self) -> &Self::Target {
165        &self.conn
166    }
167}
168
169impl<C> DerefMut for PgRepository<C> {
170    fn deref_mut(&mut self) -> &mut Self::Target {
171        &mut self.conn
172    }
173}
174
175impl Repository<DatabaseError> for PgRepository {}
176
177impl RepositoryTransaction for PgRepository {
178    type Error = DatabaseError;
179
180    fn save(self: Box<Self>) -> BoxFuture<'static, Result<(), Self::Error>> {
181        let span = tracing::info_span!("db.save");
182        self.conn
183            .commit()
184            .map_err(DatabaseError::from)
185            .instrument(span)
186            .boxed()
187    }
188
189    fn cancel(self: Box<Self>) -> BoxFuture<'static, Result<(), Self::Error>> {
190        let span = tracing::info_span!("db.cancel");
191        self.conn
192            .rollback()
193            .map_err(DatabaseError::from)
194            .instrument(span)
195            .boxed()
196    }
197}
198
199impl<C> RepositoryAccess for PgRepository<C>
200where
201    C: AsMut<PgConnection> + Send,
202{
203    type Error = DatabaseError;
204
205    fn upstream_oauth_link<'c>(
206        &'c mut self,
207    ) -> Box<dyn UpstreamOAuthLinkRepository<Error = Self::Error> + 'c> {
208        Box::new(PgUpstreamOAuthLinkRepository::new(self.conn.as_mut()))
209    }
210
211    fn upstream_oauth_provider<'c>(
212        &'c mut self,
213    ) -> Box<dyn UpstreamOAuthProviderRepository<Error = Self::Error> + 'c> {
214        Box::new(PgUpstreamOAuthProviderRepository::new(self.conn.as_mut()))
215    }
216
217    fn upstream_oauth_session<'c>(
218        &'c mut self,
219    ) -> Box<dyn UpstreamOAuthSessionRepository<Error = Self::Error> + 'c> {
220        Box::new(PgUpstreamOAuthSessionRepository::new(self.conn.as_mut()))
221    }
222
223    fn user<'c>(&'c mut self) -> Box<dyn UserRepository<Error = Self::Error> + 'c> {
224        Box::new(PgUserRepository::new(self.conn.as_mut()))
225    }
226
227    fn user_email<'c>(&'c mut self) -> Box<dyn UserEmailRepository<Error = Self::Error> + 'c> {
228        Box::new(PgUserEmailRepository::new(self.conn.as_mut()))
229    }
230
231    fn user_password<'c>(
232        &'c mut self,
233    ) -> Box<dyn UserPasswordRepository<Error = Self::Error> + 'c> {
234        Box::new(PgUserPasswordRepository::new(self.conn.as_mut()))
235    }
236
237    fn user_recovery<'c>(
238        &'c mut self,
239    ) -> Box<dyn UserRecoveryRepository<Error = Self::Error> + 'c> {
240        Box::new(PgUserRecoveryRepository::new(self.conn.as_mut()))
241    }
242
243    fn user_terms<'c>(&'c mut self) -> Box<dyn UserTermsRepository<Error = Self::Error> + 'c> {
244        Box::new(PgUserTermsRepository::new(self.conn.as_mut()))
245    }
246
247    fn user_registration<'c>(
248        &'c mut self,
249    ) -> Box<dyn UserRegistrationRepository<Error = Self::Error> + 'c> {
250        Box::new(PgUserRegistrationRepository::new(self.conn.as_mut()))
251    }
252
253    fn user_registration_token<'c>(
254        &'c mut self,
255    ) -> Box<dyn UserRegistrationTokenRepository<Error = Self::Error> + 'c> {
256        Box::new(PgUserRegistrationTokenRepository::new(self.conn.as_mut()))
257    }
258
259    fn browser_session<'c>(
260        &'c mut self,
261    ) -> Box<dyn BrowserSessionRepository<Error = Self::Error> + 'c> {
262        Box::new(PgBrowserSessionRepository::new(self.conn.as_mut()))
263    }
264
265    fn app_session<'c>(&'c mut self) -> Box<dyn AppSessionRepository<Error = Self::Error> + 'c> {
266        Box::new(PgAppSessionRepository::new(self.conn.as_mut()))
267    }
268
269    fn oauth2_client<'c>(
270        &'c mut self,
271    ) -> Box<dyn OAuth2ClientRepository<Error = Self::Error> + 'c> {
272        Box::new(PgOAuth2ClientRepository::new(self.conn.as_mut()))
273    }
274
275    fn oauth2_authorization_grant<'c>(
276        &'c mut self,
277    ) -> Box<dyn OAuth2AuthorizationGrantRepository<Error = Self::Error> + 'c> {
278        Box::new(PgOAuth2AuthorizationGrantRepository::new(
279            self.conn.as_mut(),
280        ))
281    }
282
283    fn oauth2_session<'c>(
284        &'c mut self,
285    ) -> Box<dyn OAuth2SessionRepository<Error = Self::Error> + 'c> {
286        Box::new(PgOAuth2SessionRepository::new(self.conn.as_mut()))
287    }
288
289    fn oauth2_access_token<'c>(
290        &'c mut self,
291    ) -> Box<dyn OAuth2AccessTokenRepository<Error = Self::Error> + 'c> {
292        Box::new(PgOAuth2AccessTokenRepository::new(self.conn.as_mut()))
293    }
294
295    fn oauth2_refresh_token<'c>(
296        &'c mut self,
297    ) -> Box<dyn OAuth2RefreshTokenRepository<Error = Self::Error> + 'c> {
298        Box::new(PgOAuth2RefreshTokenRepository::new(self.conn.as_mut()))
299    }
300
301    fn oauth2_device_code_grant<'c>(
302        &'c mut self,
303    ) -> Box<dyn OAuth2DeviceCodeGrantRepository<Error = Self::Error> + 'c> {
304        Box::new(PgOAuth2DeviceCodeGrantRepository::new(self.conn.as_mut()))
305    }
306
307    fn compat_session<'c>(
308        &'c mut self,
309    ) -> Box<dyn CompatSessionRepository<Error = Self::Error> + 'c> {
310        Box::new(PgCompatSessionRepository::new(self.conn.as_mut()))
311    }
312
313    fn compat_sso_login<'c>(
314        &'c mut self,
315    ) -> Box<dyn CompatSsoLoginRepository<Error = Self::Error> + 'c> {
316        Box::new(PgCompatSsoLoginRepository::new(self.conn.as_mut()))
317    }
318
319    fn compat_access_token<'c>(
320        &'c mut self,
321    ) -> Box<dyn CompatAccessTokenRepository<Error = Self::Error> + 'c> {
322        Box::new(PgCompatAccessTokenRepository::new(self.conn.as_mut()))
323    }
324
325    fn compat_refresh_token<'c>(
326        &'c mut self,
327    ) -> Box<dyn CompatRefreshTokenRepository<Error = Self::Error> + 'c> {
328        Box::new(PgCompatRefreshTokenRepository::new(self.conn.as_mut()))
329    }
330
331    fn queue_worker<'c>(&'c mut self) -> Box<dyn QueueWorkerRepository<Error = Self::Error> + 'c> {
332        Box::new(PgQueueWorkerRepository::new(self.conn.as_mut()))
333    }
334
335    fn queue_job<'c>(&'c mut self) -> Box<dyn QueueJobRepository<Error = Self::Error> + 'c> {
336        Box::new(PgQueueJobRepository::new(self.conn.as_mut()))
337    }
338
339    fn queue_schedule<'c>(
340        &'c mut self,
341    ) -> Box<dyn QueueScheduleRepository<Error = Self::Error> + 'c> {
342        Box::new(PgQueueScheduleRepository::new(self.conn.as_mut()))
343    }
344
345    fn policy_data<'c>(&'c mut self) -> Box<dyn PolicyDataRepository<Error = Self::Error> + 'c> {
346        Box::new(PgPolicyDataRepository::new(self.conn.as_mut()))
347    }
348}