1use std::net::IpAddr;
7
8use async_trait::async_trait;
9use chrono::{DateTime, Utc};
10use mas_data_model::{
11 UserEmailAuthentication, UserRegistration, UserRegistrationPassword, UserRegistrationToken,
12};
13use mas_storage::{Clock, user::UserRegistrationRepository};
14use rand::RngCore;
15use sqlx::PgConnection;
16use ulid::Ulid;
17use url::Url;
18use uuid::Uuid;
19
20use crate::{DatabaseError, DatabaseInconsistencyError, ExecuteExt as _};
21
22pub struct PgUserRegistrationRepository<'c> {
25 conn: &'c mut PgConnection,
26}
27
28impl<'c> PgUserRegistrationRepository<'c> {
29 pub fn new(conn: &'c mut PgConnection) -> Self {
32 Self { conn }
33 }
34}
35
36struct UserRegistrationLookup {
37 user_registration_id: Uuid,
38 ip_address: Option<IpAddr>,
39 user_agent: Option<String>,
40 post_auth_action: Option<serde_json::Value>,
41 username: String,
42 display_name: Option<String>,
43 terms_url: Option<String>,
44 email_authentication_id: Option<Uuid>,
45 user_registration_token_id: Option<Uuid>,
46 hashed_password: Option<String>,
47 hashed_password_version: Option<i32>,
48 created_at: DateTime<Utc>,
49 completed_at: Option<DateTime<Utc>>,
50}
51
52impl TryFrom<UserRegistrationLookup> for UserRegistration {
53 type Error = DatabaseInconsistencyError;
54
55 fn try_from(value: UserRegistrationLookup) -> Result<Self, Self::Error> {
56 let id = Ulid::from(value.user_registration_id);
57
58 let password = match (value.hashed_password, value.hashed_password_version) {
59 (Some(hashed_password), Some(version)) => {
60 let version = version.try_into().map_err(|e| {
61 DatabaseInconsistencyError::on("user_registrations")
62 .column("hashed_password_version")
63 .row(id)
64 .source(e)
65 })?;
66
67 Some(UserRegistrationPassword {
68 hashed_password,
69 version,
70 })
71 }
72 (None, None) => None,
73 _ => {
74 return Err(DatabaseInconsistencyError::on("user_registrations")
75 .column("hashed_password")
76 .row(id));
77 }
78 };
79
80 let terms_url = value
81 .terms_url
82 .map(|u| u.parse())
83 .transpose()
84 .map_err(|e| {
85 DatabaseInconsistencyError::on("user_registrations")
86 .column("terms_url")
87 .row(id)
88 .source(e)
89 })?;
90
91 Ok(UserRegistration {
92 id,
93 ip_address: value.ip_address,
94 user_agent: value.user_agent,
95 post_auth_action: value.post_auth_action,
96 username: value.username,
97 display_name: value.display_name,
98 terms_url,
99 email_authentication_id: value.email_authentication_id.map(Ulid::from),
100 user_registration_token_id: value.user_registration_token_id.map(Ulid::from),
101 password,
102 created_at: value.created_at,
103 completed_at: value.completed_at,
104 })
105 }
106}
107
108#[async_trait]
109impl UserRegistrationRepository for PgUserRegistrationRepository<'_> {
110 type Error = DatabaseError;
111
112 #[tracing::instrument(
113 name = "db.user_registration.lookup",
114 skip_all,
115 fields(
116 db.query.text,
117 user_registration.id = %id,
118 ),
119 err,
120 )]
121 async fn lookup(&mut self, id: Ulid) -> Result<Option<UserRegistration>, Self::Error> {
122 let res = sqlx::query_as!(
123 UserRegistrationLookup,
124 r#"
125 SELECT user_registration_id
126 , ip_address as "ip_address: IpAddr"
127 , user_agent
128 , post_auth_action
129 , username
130 , display_name
131 , terms_url
132 , email_authentication_id
133 , user_registration_token_id
134 , hashed_password
135 , hashed_password_version
136 , created_at
137 , completed_at
138 FROM user_registrations
139 WHERE user_registration_id = $1
140 "#,
141 Uuid::from(id),
142 )
143 .traced()
144 .fetch_optional(&mut *self.conn)
145 .await?;
146
147 let Some(res) = res else { return Ok(None) };
148
149 Ok(Some(res.try_into()?))
150 }
151
152 #[tracing::instrument(
153 name = "db.user_registration.add",
154 skip_all,
155 fields(
156 db.query.text,
157 user_registration.id,
158 ),
159 err,
160 )]
161 async fn add(
162 &mut self,
163 rng: &mut (dyn RngCore + Send),
164 clock: &dyn Clock,
165 username: String,
166 ip_address: Option<IpAddr>,
167 user_agent: Option<String>,
168 post_auth_action: Option<serde_json::Value>,
169 ) -> Result<UserRegistration, Self::Error> {
170 let created_at = clock.now();
171 let id = Ulid::from_datetime_with_source(created_at.into(), rng);
172 tracing::Span::current().record("user_registration.id", tracing::field::display(id));
173
174 sqlx::query!(
175 r#"
176 INSERT INTO user_registrations
177 ( user_registration_id
178 , ip_address
179 , user_agent
180 , post_auth_action
181 , username
182 , created_at
183 )
184 VALUES ($1, $2, $3, $4, $5, $6)
185 "#,
186 Uuid::from(id),
187 ip_address as Option<IpAddr>,
188 user_agent.as_deref(),
189 post_auth_action,
190 username,
191 created_at,
192 )
193 .traced()
194 .execute(&mut *self.conn)
195 .await?;
196
197 Ok(UserRegistration {
198 id,
199 ip_address,
200 user_agent,
201 post_auth_action,
202 created_at,
203 completed_at: None,
204 username,
205 display_name: None,
206 terms_url: None,
207 email_authentication_id: None,
208 user_registration_token_id: None,
209 password: None,
210 })
211 }
212
213 #[tracing::instrument(
214 name = "db.user_registration.set_display_name",
215 skip_all,
216 fields(
217 db.query.text,
218 user_registration.id = %user_registration.id,
219 user_registration.display_name = display_name,
220 ),
221 err,
222 )]
223 async fn set_display_name(
224 &mut self,
225 mut user_registration: UserRegistration,
226 display_name: String,
227 ) -> Result<UserRegistration, Self::Error> {
228 let res = sqlx::query!(
229 r#"
230 UPDATE user_registrations
231 SET display_name = $2
232 WHERE user_registration_id = $1 AND completed_at IS NULL
233 "#,
234 Uuid::from(user_registration.id),
235 display_name,
236 )
237 .traced()
238 .execute(&mut *self.conn)
239 .await?;
240
241 DatabaseError::ensure_affected_rows(&res, 1)?;
242
243 user_registration.display_name = Some(display_name);
244
245 Ok(user_registration)
246 }
247
248 #[tracing::instrument(
249 name = "db.user_registration.set_terms_url",
250 skip_all,
251 fields(
252 db.query.text,
253 user_registration.id = %user_registration.id,
254 user_registration.terms_url = %terms_url,
255 ),
256 err,
257 )]
258 async fn set_terms_url(
259 &mut self,
260 mut user_registration: UserRegistration,
261 terms_url: Url,
262 ) -> Result<UserRegistration, Self::Error> {
263 let res = sqlx::query!(
264 r#"
265 UPDATE user_registrations
266 SET terms_url = $2
267 WHERE user_registration_id = $1 AND completed_at IS NULL
268 "#,
269 Uuid::from(user_registration.id),
270 terms_url.as_str(),
271 )
272 .traced()
273 .execute(&mut *self.conn)
274 .await?;
275
276 DatabaseError::ensure_affected_rows(&res, 1)?;
277
278 user_registration.terms_url = Some(terms_url);
279
280 Ok(user_registration)
281 }
282
283 #[tracing::instrument(
284 name = "db.user_registration.set_email_authentication",
285 skip_all,
286 fields(
287 db.query.text,
288 %user_registration.id,
289 %user_email_authentication.id,
290 %user_email_authentication.email,
291 ),
292 err,
293 )]
294 async fn set_email_authentication(
295 &mut self,
296 mut user_registration: UserRegistration,
297 user_email_authentication: &UserEmailAuthentication,
298 ) -> Result<UserRegistration, Self::Error> {
299 let res = sqlx::query!(
300 r#"
301 UPDATE user_registrations
302 SET email_authentication_id = $2
303 WHERE user_registration_id = $1 AND completed_at IS NULL
304 "#,
305 Uuid::from(user_registration.id),
306 Uuid::from(user_email_authentication.id),
307 )
308 .traced()
309 .execute(&mut *self.conn)
310 .await?;
311
312 DatabaseError::ensure_affected_rows(&res, 1)?;
313
314 user_registration.email_authentication_id = Some(user_email_authentication.id);
315
316 Ok(user_registration)
317 }
318
319 #[tracing::instrument(
320 name = "db.user_registration.set_password",
321 skip_all,
322 fields(
323 db.query.text,
324 user_registration.id = %user_registration.id,
325 user_registration.hashed_password = hashed_password,
326 user_registration.hashed_password_version = version,
327 ),
328 err,
329 )]
330 async fn set_password(
331 &mut self,
332 mut user_registration: UserRegistration,
333 hashed_password: String,
334 version: u16,
335 ) -> Result<UserRegistration, Self::Error> {
336 let res = sqlx::query!(
337 r#"
338 UPDATE user_registrations
339 SET hashed_password = $2, hashed_password_version = $3
340 WHERE user_registration_id = $1 AND completed_at IS NULL
341 "#,
342 Uuid::from(user_registration.id),
343 hashed_password,
344 i32::from(version),
345 )
346 .traced()
347 .execute(&mut *self.conn)
348 .await?;
349
350 DatabaseError::ensure_affected_rows(&res, 1)?;
351
352 user_registration.password = Some(UserRegistrationPassword {
353 hashed_password,
354 version,
355 });
356
357 Ok(user_registration)
358 }
359
360 #[tracing::instrument(
361 name = "db.user_registration.set_registration_token",
362 skip_all,
363 fields(
364 db.query.text,
365 %user_registration.id,
366 %user_registration_token.id,
367 ),
368 err,
369 )]
370 async fn set_registration_token(
371 &mut self,
372 mut user_registration: UserRegistration,
373 user_registration_token: &UserRegistrationToken,
374 ) -> Result<UserRegistration, Self::Error> {
375 let res = sqlx::query!(
376 r#"
377 UPDATE user_registrations
378 SET user_registration_token_id = $2
379 WHERE user_registration_id = $1 AND completed_at IS NULL
380 "#,
381 Uuid::from(user_registration.id),
382 Uuid::from(user_registration_token.id),
383 )
384 .traced()
385 .execute(&mut *self.conn)
386 .await?;
387
388 DatabaseError::ensure_affected_rows(&res, 1)?;
389
390 user_registration.user_registration_token_id = Some(user_registration_token.id);
391
392 Ok(user_registration)
393 }
394
395 #[tracing::instrument(
396 name = "db.user_registration.complete",
397 skip_all,
398 fields(
399 db.query.text,
400 user_registration.id = %user_registration.id,
401 ),
402 err,
403 )]
404 async fn complete(
405 &mut self,
406 clock: &dyn Clock,
407 mut user_registration: UserRegistration,
408 ) -> Result<UserRegistration, Self::Error> {
409 let completed_at = clock.now();
410 let res = sqlx::query!(
411 r#"
412 UPDATE user_registrations
413 SET completed_at = $2
414 WHERE user_registration_id = $1 AND completed_at IS NULL
415 "#,
416 Uuid::from(user_registration.id),
417 completed_at,
418 )
419 .traced()
420 .execute(&mut *self.conn)
421 .await?;
422
423 DatabaseError::ensure_affected_rows(&res, 1)?;
424
425 user_registration.completed_at = Some(completed_at);
426
427 Ok(user_registration)
428 }
429}
430
#[cfg(test)]
mod tests {
    use std::net::{IpAddr, Ipv4Addr};

    use mas_data_model::UserRegistrationPassword;
    use mas_storage::{Clock, clock::MockClock};
    use rand::SeedableRng;
    use rand_chacha::ChaChaRng;
    use sqlx::PgPool;

    use crate::PgRepository;

    /// A fresh registration round-trips through `lookup` with every optional
    /// field unset. It can be completed exactly once, and the completion
    /// timestamp is persisted.
    #[sqlx::test(migrator = "crate::MIGRATOR")]
    async fn test_create_lookup_complete(pool: PgPool) {
        let mut rng = ChaChaRng::seed_from_u64(42);
        let clock = MockClock::default();

        let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();

        let registration = repo
            .user_registration()
            .add(&mut rng, &clock, "alice".to_owned(), None, None, None)
            .await
            .unwrap();

        assert_eq!(registration.created_at, clock.now());
        assert_eq!(registration.completed_at, None);
        assert_eq!(registration.username, "alice");
        assert_eq!(registration.display_name, None);
        assert_eq!(registration.terms_url, None);
        assert_eq!(registration.email_authentication_id, None);
        assert_eq!(registration.user_registration_token_id, None);
        assert_eq!(registration.password, None);
        assert_eq!(registration.user_agent, None);
        assert_eq!(registration.ip_address, None);
        assert_eq!(registration.post_auth_action, None);

        let lookup = repo
            .user_registration()
            .lookup(registration.id)
            .await
            .unwrap()
            .unwrap();

        assert_eq!(lookup.id, registration.id);
        assert_eq!(lookup.created_at, registration.created_at);
        assert_eq!(lookup.completed_at, registration.completed_at);
        assert_eq!(lookup.username, registration.username);
        assert_eq!(lookup.display_name, registration.display_name);
        assert_eq!(lookup.terms_url, registration.terms_url);
        assert_eq!(
            lookup.email_authentication_id,
            registration.email_authentication_id
        );
        assert_eq!(
            lookup.user_registration_token_id,
            registration.user_registration_token_id
        );
        assert_eq!(lookup.password, registration.password);
        assert_eq!(lookup.user_agent, registration.user_agent);
        assert_eq!(lookup.ip_address, registration.ip_address);
        assert_eq!(lookup.post_auth_action, registration.post_auth_action);

        let registration = repo
            .user_registration()
            .complete(&clock, registration)
            .await
            .unwrap();
        assert_eq!(registration.completed_at, Some(clock.now()));

        let lookup = repo
            .user_registration()
            .lookup(registration.id)
            .await
            .unwrap()
            .unwrap();
        assert_eq!(lookup.completed_at, registration.completed_at);

        // Completing twice must fail.
        let res = repo
            .user_registration()
            .complete(&clock, registration)
            .await;
        assert!(res.is_err());
    }

    /// Request metadata passed to `add` is returned and persisted.
    #[sqlx::test(migrator = "crate::MIGRATOR")]
    async fn test_create_useragent_ipaddress(pool: PgPool) {
        let mut rng = ChaChaRng::seed_from_u64(42);
        let clock = MockClock::default();

        let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();

        let registration = repo
            .user_registration()
            .add(
                &mut rng,
                &clock,
                "alice".to_owned(),
                Some(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1))),
                Some("Mozilla/5.0".to_owned()),
                Some(serde_json::json!({"action": "continue_compat_sso_login", "id": "01FSHN9AG0MKGTBNZ16RDR3PVY"})),
            )
            .await
            .unwrap();

        assert_eq!(registration.user_agent, Some("Mozilla/5.0".to_owned()));
        assert_eq!(
            registration.ip_address,
            Some(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)))
        );
        assert_eq!(
            registration.post_auth_action,
            Some(
                serde_json::json!({"action": "continue_compat_sso_login", "id": "01FSHN9AG0MKGTBNZ16RDR3PVY"})
            )
        );

        let lookup = repo
            .user_registration()
            .lookup(registration.id)
            .await
            .unwrap()
            .unwrap();

        assert_eq!(lookup.user_agent, registration.user_agent);
        assert_eq!(lookup.ip_address, registration.ip_address);
        assert_eq!(lookup.post_auth_action, registration.post_auth_action);
    }

    /// The display name can be set and overwritten, but not after completion.
    #[sqlx::test(migrator = "crate::MIGRATOR")]
    async fn test_set_display_name(pool: PgPool) {
        let mut rng = ChaChaRng::seed_from_u64(42);
        let clock = MockClock::default();

        let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();

        let registration = repo
            .user_registration()
            .add(&mut rng, &clock, "alice".to_owned(), None, None, None)
            .await
            .unwrap();

        assert_eq!(registration.display_name, None);

        let registration = repo
            .user_registration()
            .set_display_name(registration, "Alice".to_owned())
            .await
            .unwrap();

        assert_eq!(registration.display_name, Some("Alice".to_owned()));

        let lookup = repo
            .user_registration()
            .lookup(registration.id)
            .await
            .unwrap()
            .unwrap();

        assert_eq!(lookup.display_name, registration.display_name);

        let registration = repo
            .user_registration()
            .set_display_name(registration, "Bob".to_owned())
            .await
            .unwrap();

        assert_eq!(registration.display_name, Some("Bob".to_owned()));

        let lookup = repo
            .user_registration()
            .lookup(registration.id)
            .await
            .unwrap()
            .unwrap();

        assert_eq!(lookup.display_name, registration.display_name);

        let registration = repo
            .user_registration()
            .complete(&clock, registration)
            .await
            .unwrap();

        let res = repo
            .user_registration()
            .set_display_name(registration, "Charlie".to_owned())
            .await;
        assert!(res.is_err());
    }

    /// The terms URL can be set and overwritten, but not after completion.
    #[sqlx::test(migrator = "crate::MIGRATOR")]
    async fn test_set_terms_url(pool: PgPool) {
        let mut rng = ChaChaRng::seed_from_u64(42);
        let clock = MockClock::default();

        let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();

        let registration = repo
            .user_registration()
            .add(&mut rng, &clock, "alice".to_owned(), None, None, None)
            .await
            .unwrap();

        assert_eq!(registration.terms_url, None);

        let registration = repo
            .user_registration()
            .set_terms_url(registration, "https://example.com/terms".parse().unwrap())
            .await
            .unwrap();

        assert_eq!(
            registration.terms_url,
            Some("https://example.com/terms".parse().unwrap())
        );

        let lookup = repo
            .user_registration()
            .lookup(registration.id)
            .await
            .unwrap()
            .unwrap();

        assert_eq!(lookup.terms_url, registration.terms_url);

        let registration = repo
            .user_registration()
            .set_terms_url(registration, "https://example.com/terms2".parse().unwrap())
            .await
            .unwrap();

        assert_eq!(
            registration.terms_url,
            Some("https://example.com/terms2".parse().unwrap())
        );

        let lookup = repo
            .user_registration()
            .lookup(registration.id)
            .await
            .unwrap()
            .unwrap();

        assert_eq!(lookup.terms_url, registration.terms_url);

        let registration = repo
            .user_registration()
            .complete(&clock, registration)
            .await
            .unwrap();

        let res = repo
            .user_registration()
            .set_terms_url(registration, "https://example.com/terms3".parse().unwrap())
            .await;
        assert!(res.is_err());
    }

    /// An email authentication can be linked (idempotently), but not after
    /// completion.
    #[sqlx::test(migrator = "crate::MIGRATOR")]
    async fn test_set_email_authentication(pool: PgPool) {
        let mut rng = ChaChaRng::seed_from_u64(42);
        let clock = MockClock::default();

        let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();

        let registration = repo
            .user_registration()
            .add(&mut rng, &clock, "alice".to_owned(), None, None, None)
            .await
            .unwrap();

        assert_eq!(registration.email_authentication_id, None);

        let authentication = repo
            .user_email()
            .add_authentication_for_registration(
                &mut rng,
                &clock,
                "alice@example.com".to_owned(),
                &registration,
            )
            .await
            .unwrap();

        let registration = repo
            .user_registration()
            .set_email_authentication(registration, &authentication)
            .await
            .unwrap();

        assert_eq!(
            registration.email_authentication_id,
            Some(authentication.id)
        );

        let lookup = repo
            .user_registration()
            .lookup(registration.id)
            .await
            .unwrap()
            .unwrap();

        assert_eq!(
            lookup.email_authentication_id,
            registration.email_authentication_id
        );

        let registration = repo
            .user_registration()
            .set_email_authentication(registration, &authentication)
            .await
            .unwrap();

        assert_eq!(
            registration.email_authentication_id,
            Some(authentication.id)
        );

        let lookup = repo
            .user_registration()
            .lookup(registration.id)
            .await
            .unwrap()
            .unwrap();

        assert_eq!(
            lookup.email_authentication_id,
            registration.email_authentication_id
        );

        let registration = repo
            .user_registration()
            .complete(&clock, registration)
            .await
            .unwrap();

        let res = repo
            .user_registration()
            .set_email_authentication(registration, &authentication)
            .await;
        assert!(res.is_err());
    }

    /// The password hash and version are stored together, can be
    /// overwritten, and cannot change after completion.
    #[sqlx::test(migrator = "crate::MIGRATOR")]
    async fn test_set_password(pool: PgPool) {
        let mut rng = ChaChaRng::seed_from_u64(42);
        let clock = MockClock::default();

        let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();

        let registration = repo
            .user_registration()
            .add(&mut rng, &clock, "alice".to_owned(), None, None, None)
            .await
            .unwrap();

        assert_eq!(registration.password, None);

        let registration = repo
            .user_registration()
            .set_password(registration, "fakehashedpassword".to_owned(), 1)
            .await
            .unwrap();

        assert_eq!(
            registration.password,
            Some(UserRegistrationPassword {
                hashed_password: "fakehashedpassword".to_owned(),
                version: 1,
            })
        );

        let lookup = repo
            .user_registration()
            .lookup(registration.id)
            .await
            .unwrap()
            .unwrap();

        assert_eq!(lookup.password, registration.password);

        let registration = repo
            .user_registration()
            .set_password(registration, "fakehashedpassword2".to_owned(), 2)
            .await
            .unwrap();

        assert_eq!(
            registration.password,
            Some(UserRegistrationPassword {
                hashed_password: "fakehashedpassword2".to_owned(),
                version: 2,
            })
        );

        let lookup = repo
            .user_registration()
            .lookup(registration.id)
            .await
            .unwrap()
            .unwrap();

        assert_eq!(lookup.password, registration.password);

        let registration = repo
            .user_registration()
            .complete(&clock, registration)
            .await
            .unwrap();

        let res = repo
            .user_registration()
            .set_password(registration, "fakehashedpassword3".to_owned(), 3)
            .await;
        assert!(res.is_err());
    }
}