1use std::ops::{Deref, DerefMut};
8
9use async_trait::async_trait;
10use futures_util::{FutureExt, TryFutureExt, future::BoxFuture};
11use mas_storage::{
12 BoxRepository, BoxRepositoryFactory, MapErr, Repository, RepositoryAccess, RepositoryError,
13 RepositoryFactory, RepositoryTransaction,
14 app_session::AppSessionRepository,
15 compat::{
16 CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository,
17 CompatSsoLoginRepository,
18 },
19 oauth2::{
20 OAuth2AccessTokenRepository, OAuth2AuthorizationGrantRepository, OAuth2ClientRepository,
21 OAuth2DeviceCodeGrantRepository, OAuth2RefreshTokenRepository, OAuth2SessionRepository,
22 },
23 policy_data::PolicyDataRepository,
24 queue::{QueueJobRepository, QueueScheduleRepository, QueueWorkerRepository},
25 upstream_oauth2::{
26 UpstreamOAuthLinkRepository, UpstreamOAuthProviderRepository,
27 UpstreamOAuthSessionRepository,
28 },
29 user::{
30 BrowserSessionRepository, UserEmailRepository, UserPasswordRepository,
31 UserRecoveryRepository, UserRegistrationRepository, UserRegistrationTokenRepository,
32 UserRepository, UserTermsRepository,
33 },
34};
35use sqlx::{PgConnection, PgPool, Postgres, Transaction};
36use tracing::Instrument;
37
38use crate::{
39 DatabaseError,
40 app_session::PgAppSessionRepository,
41 compat::{
42 PgCompatAccessTokenRepository, PgCompatRefreshTokenRepository, PgCompatSessionRepository,
43 PgCompatSsoLoginRepository,
44 },
45 oauth2::{
46 PgOAuth2AccessTokenRepository, PgOAuth2AuthorizationGrantRepository,
47 PgOAuth2ClientRepository, PgOAuth2DeviceCodeGrantRepository,
48 PgOAuth2RefreshTokenRepository, PgOAuth2SessionRepository,
49 },
50 policy_data::PgPolicyDataRepository,
51 queue::{
52 job::PgQueueJobRepository, schedule::PgQueueScheduleRepository,
53 worker::PgQueueWorkerRepository,
54 },
55 telemetry::DB_CLIENT_CONNECTIONS_CREATE_TIME_HISTOGRAM,
56 upstream_oauth2::{
57 PgUpstreamOAuthLinkRepository, PgUpstreamOAuthProviderRepository,
58 PgUpstreamOAuthSessionRepository,
59 },
60 user::{
61 PgBrowserSessionRepository, PgUserEmailRepository, PgUserPasswordRepository,
62 PgUserRecoveryRepository, PgUserRegistrationRepository, PgUserRegistrationTokenRepository,
63 PgUserRepository, PgUserTermsRepository,
64 },
65};
66
67#[derive(Clone)]
70pub struct PgRepositoryFactory {
71 pool: PgPool,
72}
73
74impl PgRepositoryFactory {
75 #[must_use]
77 pub fn new(pool: PgPool) -> Self {
78 Self { pool }
79 }
80
81 #[must_use]
83 pub fn boxed(self) -> BoxRepositoryFactory {
84 Box::new(self)
85 }
86
87 #[must_use]
89 pub fn pool(&self) -> PgPool {
90 self.pool.clone()
91 }
92}
93
94#[async_trait]
95impl RepositoryFactory for PgRepositoryFactory {
96 async fn create(&self) -> Result<BoxRepository, RepositoryError> {
97 let start = std::time::Instant::now();
98 let repo = PgRepository::from_pool(&self.pool)
99 .await
100 .map_err(RepositoryError::from_error)?
101 .boxed();
102
103 let duration = start.elapsed();
105 let duration_ms = duration.as_millis().try_into().unwrap_or(u64::MAX);
106 DB_CLIENT_CONNECTIONS_CREATE_TIME_HISTOGRAM.record(duration_ms, &[]);
107
108 Ok(repo)
109 }
110}
111
/// A repository backed by a PostgreSQL connection of type `C`.
///
/// `C` defaults to an owned transaction (`Transaction<'static, Postgres>`).
/// In this module, only that default form implements `RepositoryTransaction`
/// (commit/rollback) and `Repository`. Any `C: AsMut<PgConnection> + Send`
/// still gets `RepositoryAccess`.
pub struct PgRepository<C = Transaction<'static, Postgres>> {
    // The wrapped connection or transaction all sub-repositories run against.
    conn: C,
}
117
118impl PgRepository {
119 pub async fn from_pool(pool: &PgPool) -> Result<Self, DatabaseError> {
126 let txn = pool.begin().await?;
127 Ok(Self::from_conn(txn))
128 }
129
130 pub fn boxed(self) -> BoxRepository {
132 Box::new(MapErr::new(self, RepositoryError::from_error))
133 }
134}
135
136impl<C> PgRepository<C> {
137 pub fn from_conn(conn: C) -> Self {
140 PgRepository { conn }
141 }
142
143 pub fn into_inner(self) -> C {
145 self.conn
146 }
147}
148
149impl<C> AsRef<C> for PgRepository<C> {
150 fn as_ref(&self) -> &C {
151 &self.conn
152 }
153}
154
155impl<C> AsMut<C> for PgRepository<C> {
156 fn as_mut(&mut self) -> &mut C {
157 &mut self.conn
158 }
159}
160
161impl<C> Deref for PgRepository<C> {
162 type Target = C;
163
164 fn deref(&self) -> &Self::Target {
165 &self.conn
166 }
167}
168
169impl<C> DerefMut for PgRepository<C> {
170 fn deref_mut(&mut self) -> &mut Self::Target {
171 &mut self.conn
172 }
173}
174
175impl Repository<DatabaseError> for PgRepository {}
176
177impl RepositoryTransaction for PgRepository {
178 type Error = DatabaseError;
179
180 fn save(self: Box<Self>) -> BoxFuture<'static, Result<(), Self::Error>> {
181 let span = tracing::info_span!("db.save");
182 self.conn
183 .commit()
184 .map_err(DatabaseError::from)
185 .instrument(span)
186 .boxed()
187 }
188
189 fn cancel(self: Box<Self>) -> BoxFuture<'static, Result<(), Self::Error>> {
190 let span = tracing::info_span!("db.cancel");
191 self.conn
192 .rollback()
193 .map_err(DatabaseError::from)
194 .instrument(span)
195 .boxed()
196 }
197}
198
199impl<C> RepositoryAccess for PgRepository<C>
200where
201 C: AsMut<PgConnection> + Send,
202{
203 type Error = DatabaseError;
204
205 fn upstream_oauth_link<'c>(
206 &'c mut self,
207 ) -> Box<dyn UpstreamOAuthLinkRepository<Error = Self::Error> + 'c> {
208 Box::new(PgUpstreamOAuthLinkRepository::new(self.conn.as_mut()))
209 }
210
211 fn upstream_oauth_provider<'c>(
212 &'c mut self,
213 ) -> Box<dyn UpstreamOAuthProviderRepository<Error = Self::Error> + 'c> {
214 Box::new(PgUpstreamOAuthProviderRepository::new(self.conn.as_mut()))
215 }
216
217 fn upstream_oauth_session<'c>(
218 &'c mut self,
219 ) -> Box<dyn UpstreamOAuthSessionRepository<Error = Self::Error> + 'c> {
220 Box::new(PgUpstreamOAuthSessionRepository::new(self.conn.as_mut()))
221 }
222
223 fn user<'c>(&'c mut self) -> Box<dyn UserRepository<Error = Self::Error> + 'c> {
224 Box::new(PgUserRepository::new(self.conn.as_mut()))
225 }
226
227 fn user_email<'c>(&'c mut self) -> Box<dyn UserEmailRepository<Error = Self::Error> + 'c> {
228 Box::new(PgUserEmailRepository::new(self.conn.as_mut()))
229 }
230
231 fn user_password<'c>(
232 &'c mut self,
233 ) -> Box<dyn UserPasswordRepository<Error = Self::Error> + 'c> {
234 Box::new(PgUserPasswordRepository::new(self.conn.as_mut()))
235 }
236
237 fn user_recovery<'c>(
238 &'c mut self,
239 ) -> Box<dyn UserRecoveryRepository<Error = Self::Error> + 'c> {
240 Box::new(PgUserRecoveryRepository::new(self.conn.as_mut()))
241 }
242
243 fn user_terms<'c>(&'c mut self) -> Box<dyn UserTermsRepository<Error = Self::Error> + 'c> {
244 Box::new(PgUserTermsRepository::new(self.conn.as_mut()))
245 }
246
247 fn user_registration<'c>(
248 &'c mut self,
249 ) -> Box<dyn UserRegistrationRepository<Error = Self::Error> + 'c> {
250 Box::new(PgUserRegistrationRepository::new(self.conn.as_mut()))
251 }
252
253 fn user_registration_token<'c>(
254 &'c mut self,
255 ) -> Box<dyn UserRegistrationTokenRepository<Error = Self::Error> + 'c> {
256 Box::new(PgUserRegistrationTokenRepository::new(self.conn.as_mut()))
257 }
258
259 fn browser_session<'c>(
260 &'c mut self,
261 ) -> Box<dyn BrowserSessionRepository<Error = Self::Error> + 'c> {
262 Box::new(PgBrowserSessionRepository::new(self.conn.as_mut()))
263 }
264
265 fn app_session<'c>(&'c mut self) -> Box<dyn AppSessionRepository<Error = Self::Error> + 'c> {
266 Box::new(PgAppSessionRepository::new(self.conn.as_mut()))
267 }
268
269 fn oauth2_client<'c>(
270 &'c mut self,
271 ) -> Box<dyn OAuth2ClientRepository<Error = Self::Error> + 'c> {
272 Box::new(PgOAuth2ClientRepository::new(self.conn.as_mut()))
273 }
274
275 fn oauth2_authorization_grant<'c>(
276 &'c mut self,
277 ) -> Box<dyn OAuth2AuthorizationGrantRepository<Error = Self::Error> + 'c> {
278 Box::new(PgOAuth2AuthorizationGrantRepository::new(
279 self.conn.as_mut(),
280 ))
281 }
282
283 fn oauth2_session<'c>(
284 &'c mut self,
285 ) -> Box<dyn OAuth2SessionRepository<Error = Self::Error> + 'c> {
286 Box::new(PgOAuth2SessionRepository::new(self.conn.as_mut()))
287 }
288
289 fn oauth2_access_token<'c>(
290 &'c mut self,
291 ) -> Box<dyn OAuth2AccessTokenRepository<Error = Self::Error> + 'c> {
292 Box::new(PgOAuth2AccessTokenRepository::new(self.conn.as_mut()))
293 }
294
295 fn oauth2_refresh_token<'c>(
296 &'c mut self,
297 ) -> Box<dyn OAuth2RefreshTokenRepository<Error = Self::Error> + 'c> {
298 Box::new(PgOAuth2RefreshTokenRepository::new(self.conn.as_mut()))
299 }
300
301 fn oauth2_device_code_grant<'c>(
302 &'c mut self,
303 ) -> Box<dyn OAuth2DeviceCodeGrantRepository<Error = Self::Error> + 'c> {
304 Box::new(PgOAuth2DeviceCodeGrantRepository::new(self.conn.as_mut()))
305 }
306
307 fn compat_session<'c>(
308 &'c mut self,
309 ) -> Box<dyn CompatSessionRepository<Error = Self::Error> + 'c> {
310 Box::new(PgCompatSessionRepository::new(self.conn.as_mut()))
311 }
312
313 fn compat_sso_login<'c>(
314 &'c mut self,
315 ) -> Box<dyn CompatSsoLoginRepository<Error = Self::Error> + 'c> {
316 Box::new(PgCompatSsoLoginRepository::new(self.conn.as_mut()))
317 }
318
319 fn compat_access_token<'c>(
320 &'c mut self,
321 ) -> Box<dyn CompatAccessTokenRepository<Error = Self::Error> + 'c> {
322 Box::new(PgCompatAccessTokenRepository::new(self.conn.as_mut()))
323 }
324
325 fn compat_refresh_token<'c>(
326 &'c mut self,
327 ) -> Box<dyn CompatRefreshTokenRepository<Error = Self::Error> + 'c> {
328 Box::new(PgCompatRefreshTokenRepository::new(self.conn.as_mut()))
329 }
330
331 fn queue_worker<'c>(&'c mut self) -> Box<dyn QueueWorkerRepository<Error = Self::Error> + 'c> {
332 Box::new(PgQueueWorkerRepository::new(self.conn.as_mut()))
333 }
334
335 fn queue_job<'c>(&'c mut self) -> Box<dyn QueueJobRepository<Error = Self::Error> + 'c> {
336 Box::new(PgQueueJobRepository::new(self.conn.as_mut()))
337 }
338
339 fn queue_schedule<'c>(
340 &'c mut self,
341 ) -> Box<dyn QueueScheduleRepository<Error = Self::Error> + 'c> {
342 Box::new(PgQueueScheduleRepository::new(self.conn.as_mut()))
343 }
344
345 fn policy_data<'c>(&'c mut self) -> Box<dyn PolicyDataRepository<Error = Self::Error> + 'c> {
346 Box::new(PgPolicyDataRepository::new(self.conn.as_mut()))
347 }
348}