1use async_trait::async_trait;
11use mas_data_model::User;
12use mas_storage::{
13 Clock,
14 user::{UserFilter, UserRepository},
15};
16use rand::RngCore;
17use sea_query::{Expr, PostgresQueryBuilder, Query};
18use sea_query_binder::SqlxBinder;
19use sqlx::PgConnection;
20use ulid::Ulid;
21use uuid::Uuid;
22
23use crate::{
24 DatabaseError,
25 filter::{Filter, StatementExt},
26 iden::Users,
27 pagination::QueryBuilderExt,
28 tracing::ExecuteExt,
29};
30
31mod email;
32mod password;
33mod recovery;
34mod registration;
35mod registration_token;
36mod session;
37mod terms;
38
39#[cfg(test)]
40mod tests;
41
42pub use self::{
43 email::PgUserEmailRepository, password::PgUserPasswordRepository,
44 recovery::PgUserRecoveryRepository, registration::PgUserRegistrationRepository,
45 registration_token::PgUserRegistrationTokenRepository, session::PgBrowserSessionRepository,
46 terms::PgUserTermsRepository,
47};
48
49pub struct PgUserRepository<'c> {
51 conn: &'c mut PgConnection,
52}
53
54impl<'c> PgUserRepository<'c> {
55 pub fn new(conn: &'c mut PgConnection) -> Self {
57 Self { conn }
58 }
59}
60
mod priv_ {
    // NOTE(review): the lint is presumably silenced here because the
    // `#[enum_def]`-generated `UserLookupIden` enum carries no docs; this module
    // appears to exist mainly to scope that `allow`.
    #![allow(missing_docs)]

    use chrono::{DateTime, Utc};
    use sea_query::enum_def;
    use uuid::Uuid;

    /// Row shape of the `users` columns selected by the user repository queries.
    ///
    /// `sqlx::FromRow` decodes query results into it, and `#[enum_def]` generates
    /// the companion `UserLookupIden` enum used as column aliases when building
    /// queries with sea_query. It is converted into `mas_data_model::User` via `From`.
    #[derive(Debug, Clone, sqlx::FromRow)]
    #[enum_def]
    pub(super) struct UserLookup {
        /// Primary key; converted into the user's ULID.
        pub(super) user_id: Uuid,
        pub(super) username: String,
        pub(super) created_at: DateTime<Utc>,
        /// `None` unless the user has been locked.
        pub(super) locked_at: Option<DateTime<Utc>>,
        /// `None` unless the user has been deactivated.
        pub(super) deactivated_at: Option<DateTime<Utc>>,
        pub(super) can_request_admin: bool,
    }
}
81
82use priv_::{UserLookup, UserLookupIden};
83
84impl From<UserLookup> for User {
85 fn from(value: UserLookup) -> Self {
86 let id = value.user_id.into();
87 Self {
88 id,
89 username: value.username,
90 sub: id.to_string(),
91 created_at: value.created_at,
92 locked_at: value.locked_at,
93 deactivated_at: value.deactivated_at,
94 can_request_admin: value.can_request_admin,
95 }
96 }
97}
98
99impl Filter for UserFilter<'_> {
100 fn generate_condition(&self, _has_joins: bool) -> impl sea_query::IntoCondition {
101 sea_query::Condition::all()
102 .add_option(self.state().map(|state| {
103 match state {
104 mas_storage::user::UserState::Deactivated => {
105 Expr::col((Users::Table, Users::DeactivatedAt)).is_not_null()
106 }
107 mas_storage::user::UserState::Locked => {
108 Expr::col((Users::Table, Users::LockedAt)).is_not_null()
109 }
110 mas_storage::user::UserState::Active => {
111 Expr::col((Users::Table, Users::LockedAt))
112 .is_null()
113 .and(Expr::col((Users::Table, Users::DeactivatedAt)).is_null())
114 }
115 }
116 }))
117 .add_option(self.can_request_admin().map(|can_request_admin| {
118 Expr::col((Users::Table, Users::CanRequestAdmin)).eq(can_request_admin)
119 }))
120 }
121}
122
123#[async_trait]
124impl UserRepository for PgUserRepository<'_> {
125 type Error = DatabaseError;
126
127 #[tracing::instrument(
128 name = "db.user.lookup",
129 skip_all,
130 fields(
131 db.query.text,
132 user.id = %id,
133 ),
134 err,
135 )]
136 async fn lookup(&mut self, id: Ulid) -> Result<Option<User>, Self::Error> {
137 let res = sqlx::query_as!(
138 UserLookup,
139 r#"
140 SELECT user_id
141 , username
142 , created_at
143 , locked_at
144 , deactivated_at
145 , can_request_admin
146 FROM users
147 WHERE user_id = $1
148 "#,
149 Uuid::from(id),
150 )
151 .traced()
152 .fetch_optional(&mut *self.conn)
153 .await?;
154
155 let Some(res) = res else { return Ok(None) };
156
157 Ok(Some(res.into()))
158 }
159
160 #[tracing::instrument(
161 name = "db.user.find_by_username",
162 skip_all,
163 fields(
164 db.query.text,
165 user.username = username,
166 ),
167 err,
168 )]
169 async fn find_by_username(&mut self, username: &str) -> Result<Option<User>, Self::Error> {
170 let res = sqlx::query_as!(
174 UserLookup,
175 r#"
176 SELECT user_id
177 , username
178 , created_at
179 , locked_at
180 , deactivated_at
181 , can_request_admin
182 FROM users
183 WHERE LOWER(username) = LOWER($1)
184 "#,
185 username,
186 )
187 .traced()
188 .fetch_all(&mut *self.conn)
189 .await?;
190
191 match &res[..] {
192 [user] => Ok(Some(user.clone().into())),
194 [] => Ok(None),
196 list => {
197 if let Some(user) = list.iter().find(|user| user.username == username) {
200 Ok(Some(user.clone().into()))
201 } else {
202 Ok(None)
204 }
205 }
206 }
207 }
208
209 #[tracing::instrument(
210 name = "db.user.add",
211 skip_all,
212 fields(
213 db.query.text,
214 user.username = username,
215 user.id,
216 ),
217 err,
218 )]
219 async fn add(
220 &mut self,
221 rng: &mut (dyn RngCore + Send),
222 clock: &dyn Clock,
223 username: String,
224 ) -> Result<User, Self::Error> {
225 let created_at = clock.now();
226 let id = Ulid::from_datetime_with_source(created_at.into(), rng);
227 tracing::Span::current().record("user.id", tracing::field::display(id));
228
229 let res = sqlx::query!(
230 r#"
231 INSERT INTO users (user_id, username, created_at)
232 VALUES ($1, $2, $3)
233 ON CONFLICT (username) DO NOTHING
234 "#,
235 Uuid::from(id),
236 username,
237 created_at,
238 )
239 .traced()
240 .execute(&mut *self.conn)
241 .await?;
242
243 DatabaseError::ensure_affected_rows(&res, 1)?;
246
247 Ok(User {
248 id,
249 username,
250 sub: id.to_string(),
251 created_at,
252 locked_at: None,
253 deactivated_at: None,
254 can_request_admin: false,
255 })
256 }
257
258 #[tracing::instrument(
259 name = "db.user.exists",
260 skip_all,
261 fields(
262 db.query.text,
263 user.username = username,
264 ),
265 err,
266 )]
267 async fn exists(&mut self, username: &str) -> Result<bool, Self::Error> {
268 let exists = sqlx::query_scalar!(
269 r#"
270 SELECT EXISTS(
271 SELECT 1 FROM users WHERE LOWER(username) = LOWER($1)
272 ) AS "exists!"
273 "#,
274 username
275 )
276 .traced()
277 .fetch_one(&mut *self.conn)
278 .await?;
279
280 Ok(exists)
281 }
282
283 #[tracing::instrument(
284 name = "db.user.lock",
285 skip_all,
286 fields(
287 db.query.text,
288 %user.id,
289 ),
290 err,
291 )]
292 async fn lock(&mut self, clock: &dyn Clock, mut user: User) -> Result<User, Self::Error> {
293 if user.locked_at.is_some() {
294 return Ok(user);
295 }
296
297 let locked_at = clock.now();
298 let res = sqlx::query!(
299 r#"
300 UPDATE users
301 SET locked_at = $1
302 WHERE user_id = $2
303 "#,
304 locked_at,
305 Uuid::from(user.id),
306 )
307 .traced()
308 .execute(&mut *self.conn)
309 .await?;
310
311 DatabaseError::ensure_affected_rows(&res, 1)?;
312
313 user.locked_at = Some(locked_at);
314
315 Ok(user)
316 }
317
318 #[tracing::instrument(
319 name = "db.user.unlock",
320 skip_all,
321 fields(
322 db.query.text,
323 %user.id,
324 ),
325 err,
326 )]
327 async fn unlock(&mut self, mut user: User) -> Result<User, Self::Error> {
328 if user.locked_at.is_none() {
329 return Ok(user);
330 }
331
332 let res = sqlx::query!(
333 r#"
334 UPDATE users
335 SET locked_at = NULL
336 WHERE user_id = $1
337 "#,
338 Uuid::from(user.id),
339 )
340 .traced()
341 .execute(&mut *self.conn)
342 .await?;
343
344 DatabaseError::ensure_affected_rows(&res, 1)?;
345
346 user.locked_at = None;
347
348 Ok(user)
349 }
350
351 #[tracing::instrument(
352 name = "db.user.deactivate",
353 skip_all,
354 fields(
355 db.query.text,
356 %user.id,
357 ),
358 err,
359 )]
360 async fn deactivate(&mut self, clock: &dyn Clock, mut user: User) -> Result<User, Self::Error> {
361 if user.deactivated_at.is_some() {
362 return Ok(user);
363 }
364
365 let deactivated_at = clock.now();
366 let res = sqlx::query!(
367 r#"
368 UPDATE users
369 SET deactivated_at = $2
370 WHERE user_id = $1
371 AND deactivated_at IS NULL
372 "#,
373 Uuid::from(user.id),
374 deactivated_at,
375 )
376 .traced()
377 .execute(&mut *self.conn)
378 .await?;
379
380 DatabaseError::ensure_affected_rows(&res, 1)?;
381
382 user.deactivated_at = Some(user.created_at);
383
384 Ok(user)
385 }
386
387 #[tracing::instrument(
388 name = "db.user.set_can_request_admin",
389 skip_all,
390 fields(
391 db.query.text,
392 %user.id,
393 user.can_request_admin = can_request_admin,
394 ),
395 err,
396 )]
397 async fn set_can_request_admin(
398 &mut self,
399 mut user: User,
400 can_request_admin: bool,
401 ) -> Result<User, Self::Error> {
402 let res = sqlx::query!(
403 r#"
404 UPDATE users
405 SET can_request_admin = $2
406 WHERE user_id = $1
407 "#,
408 Uuid::from(user.id),
409 can_request_admin,
410 )
411 .traced()
412 .execute(&mut *self.conn)
413 .await?;
414
415 DatabaseError::ensure_affected_rows(&res, 1)?;
416
417 user.can_request_admin = can_request_admin;
418
419 Ok(user)
420 }
421
422 #[tracing::instrument(
423 name = "db.user.list",
424 skip_all,
425 fields(
426 db.query.text,
427 ),
428 err,
429 )]
430 async fn list(
431 &mut self,
432 filter: UserFilter<'_>,
433 pagination: mas_storage::Pagination,
434 ) -> Result<mas_storage::Page<User>, Self::Error> {
435 let (sql, arguments) = Query::select()
436 .expr_as(
437 Expr::col((Users::Table, Users::UserId)),
438 UserLookupIden::UserId,
439 )
440 .expr_as(
441 Expr::col((Users::Table, Users::Username)),
442 UserLookupIden::Username,
443 )
444 .expr_as(
445 Expr::col((Users::Table, Users::CreatedAt)),
446 UserLookupIden::CreatedAt,
447 )
448 .expr_as(
449 Expr::col((Users::Table, Users::LockedAt)),
450 UserLookupIden::LockedAt,
451 )
452 .expr_as(
453 Expr::col((Users::Table, Users::DeactivatedAt)),
454 UserLookupIden::DeactivatedAt,
455 )
456 .expr_as(
457 Expr::col((Users::Table, Users::CanRequestAdmin)),
458 UserLookupIden::CanRequestAdmin,
459 )
460 .from(Users::Table)
461 .apply_filter(filter)
462 .generate_pagination((Users::Table, Users::UserId), pagination)
463 .build_sqlx(PostgresQueryBuilder);
464
465 let edges: Vec<UserLookup> = sqlx::query_as_with(&sql, arguments)
466 .traced()
467 .fetch_all(&mut *self.conn)
468 .await?;
469
470 let page = pagination.process(edges).map(User::from);
471
472 Ok(page)
473 }
474
475 #[tracing::instrument(
476 name = "db.user.count",
477 skip_all,
478 fields(
479 db.query.text,
480 ),
481 err,
482 )]
483 async fn count(&mut self, filter: UserFilter<'_>) -> Result<usize, Self::Error> {
484 let (sql, arguments) = Query::select()
485 .expr(Expr::col((Users::Table, Users::UserId)).count())
486 .from(Users::Table)
487 .apply_filter(filter)
488 .build_sqlx(PostgresQueryBuilder);
489
490 let count: i64 = sqlx::query_scalar_with(&sql, arguments)
491 .traced()
492 .fetch_one(&mut *self.conn)
493 .await?;
494
495 count
496 .try_into()
497 .map_err(DatabaseError::to_invalid_operation)
498 }
499
500 #[tracing::instrument(
501 name = "db.user.acquire_lock_for_sync",
502 skip_all,
503 fields(
504 db.query.text,
505 user.id = %user.id,
506 ),
507 err,
508 )]
509 async fn acquire_lock_for_sync(&mut self, user: &User) -> Result<(), Self::Error> {
510 let lock_id = (u128::from(user.id) & 0xffff_ffff_ffff_ffff) as i64;
518
519 sqlx::query!(
522 r#"
523 SELECT pg_advisory_xact_lock($1)
524 "#,
525 lock_id,
526 )
527 .traced()
528 .execute(&mut *self.conn)
529 .await?;
530
531 Ok(())
532 }
533}