mas_storage_pg/user/registration.rs

1// Copyright 2025 New Vector Ltd.
2//
3// SPDX-License-Identifier: AGPL-3.0-only
4// Please see LICENSE in the repository root for full details.
5
6use std::net::IpAddr;
7
8use async_trait::async_trait;
9use chrono::{DateTime, Utc};
10use mas_data_model::{
11    UserEmailAuthentication, UserRegistration, UserRegistrationPassword, UserRegistrationToken,
12};
13use mas_storage::{Clock, user::UserRegistrationRepository};
14use rand::RngCore;
15use sqlx::PgConnection;
16use ulid::Ulid;
17use url::Url;
18use uuid::Uuid;
19
20use crate::{DatabaseError, DatabaseInconsistencyError, ExecuteExt as _};
21
22/// An implementation of [`UserRegistrationRepository`] for a PostgreSQL
23/// connection
24pub struct PgUserRegistrationRepository<'c> {
25    conn: &'c mut PgConnection,
26}
27
28impl<'c> PgUserRegistrationRepository<'c> {
29    /// Create a new [`PgUserRegistrationRepository`] from an active PostgreSQL
30    /// connection
31    pub fn new(conn: &'c mut PgConnection) -> Self {
32        Self { conn }
33    }
34}
35
36struct UserRegistrationLookup {
37    user_registration_id: Uuid,
38    ip_address: Option<IpAddr>,
39    user_agent: Option<String>,
40    post_auth_action: Option<serde_json::Value>,
41    username: String,
42    display_name: Option<String>,
43    terms_url: Option<String>,
44    email_authentication_id: Option<Uuid>,
45    user_registration_token_id: Option<Uuid>,
46    hashed_password: Option<String>,
47    hashed_password_version: Option<i32>,
48    created_at: DateTime<Utc>,
49    completed_at: Option<DateTime<Utc>>,
50}
51
52impl TryFrom<UserRegistrationLookup> for UserRegistration {
53    type Error = DatabaseInconsistencyError;
54
55    fn try_from(value: UserRegistrationLookup) -> Result<Self, Self::Error> {
56        let id = Ulid::from(value.user_registration_id);
57
58        let password = match (value.hashed_password, value.hashed_password_version) {
59            (Some(hashed_password), Some(version)) => {
60                let version = version.try_into().map_err(|e| {
61                    DatabaseInconsistencyError::on("user_registrations")
62                        .column("hashed_password_version")
63                        .row(id)
64                        .source(e)
65                })?;
66
67                Some(UserRegistrationPassword {
68                    hashed_password,
69                    version,
70                })
71            }
72            (None, None) => None,
73            _ => {
74                return Err(DatabaseInconsistencyError::on("user_registrations")
75                    .column("hashed_password")
76                    .row(id));
77            }
78        };
79
80        let terms_url = value
81            .terms_url
82            .map(|u| u.parse())
83            .transpose()
84            .map_err(|e| {
85                DatabaseInconsistencyError::on("user_registrations")
86                    .column("terms_url")
87                    .row(id)
88                    .source(e)
89            })?;
90
91        Ok(UserRegistration {
92            id,
93            ip_address: value.ip_address,
94            user_agent: value.user_agent,
95            post_auth_action: value.post_auth_action,
96            username: value.username,
97            display_name: value.display_name,
98            terms_url,
99            email_authentication_id: value.email_authentication_id.map(Ulid::from),
100            user_registration_token_id: value.user_registration_token_id.map(Ulid::from),
101            password,
102            created_at: value.created_at,
103            completed_at: value.completed_at,
104        })
105    }
106}
107
108#[async_trait]
109impl UserRegistrationRepository for PgUserRegistrationRepository<'_> {
110    type Error = DatabaseError;
111
112    #[tracing::instrument(
113        name = "db.user_registration.lookup",
114        skip_all,
115        fields(
116            db.query.text,
117            user_registration.id = %id,
118        ),
119        err,
120    )]
121    async fn lookup(&mut self, id: Ulid) -> Result<Option<UserRegistration>, Self::Error> {
122        let res = sqlx::query_as!(
123            UserRegistrationLookup,
124            r#"
125                SELECT user_registration_id
126                     , ip_address as "ip_address: IpAddr"
127                     , user_agent
128                     , post_auth_action
129                     , username
130                     , display_name
131                     , terms_url
132                     , email_authentication_id
133                     , user_registration_token_id
134                     , hashed_password
135                     , hashed_password_version
136                     , created_at
137                     , completed_at
138                FROM user_registrations
139                WHERE user_registration_id = $1
140            "#,
141            Uuid::from(id),
142        )
143        .traced()
144        .fetch_optional(&mut *self.conn)
145        .await?;
146
147        let Some(res) = res else { return Ok(None) };
148
149        Ok(Some(res.try_into()?))
150    }
151
152    #[tracing::instrument(
153        name = "db.user_registration.add",
154        skip_all,
155        fields(
156            db.query.text,
157            user_registration.id,
158        ),
159        err,
160    )]
161    async fn add(
162        &mut self,
163        rng: &mut (dyn RngCore + Send),
164        clock: &dyn Clock,
165        username: String,
166        ip_address: Option<IpAddr>,
167        user_agent: Option<String>,
168        post_auth_action: Option<serde_json::Value>,
169    ) -> Result<UserRegistration, Self::Error> {
170        let created_at = clock.now();
171        let id = Ulid::from_datetime_with_source(created_at.into(), rng);
172        tracing::Span::current().record("user_registration.id", tracing::field::display(id));
173
174        sqlx::query!(
175            r#"
176                INSERT INTO user_registrations
177                  ( user_registration_id
178                  , ip_address
179                  , user_agent
180                  , post_auth_action
181                  , username
182                  , created_at
183                  )
184                VALUES ($1, $2, $3, $4, $5, $6)
185            "#,
186            Uuid::from(id),
187            ip_address as Option<IpAddr>,
188            user_agent.as_deref(),
189            post_auth_action,
190            username,
191            created_at,
192        )
193        .traced()
194        .execute(&mut *self.conn)
195        .await?;
196
197        Ok(UserRegistration {
198            id,
199            ip_address,
200            user_agent,
201            post_auth_action,
202            created_at,
203            completed_at: None,
204            username,
205            display_name: None,
206            terms_url: None,
207            email_authentication_id: None,
208            user_registration_token_id: None,
209            password: None,
210        })
211    }
212
213    #[tracing::instrument(
214        name = "db.user_registration.set_display_name",
215        skip_all,
216        fields(
217            db.query.text,
218            user_registration.id = %user_registration.id,
219            user_registration.display_name = display_name,
220        ),
221        err,
222    )]
223    async fn set_display_name(
224        &mut self,
225        mut user_registration: UserRegistration,
226        display_name: String,
227    ) -> Result<UserRegistration, Self::Error> {
228        let res = sqlx::query!(
229            r#"
230                UPDATE user_registrations
231                SET display_name = $2
232                WHERE user_registration_id = $1 AND completed_at IS NULL
233            "#,
234            Uuid::from(user_registration.id),
235            display_name,
236        )
237        .traced()
238        .execute(&mut *self.conn)
239        .await?;
240
241        DatabaseError::ensure_affected_rows(&res, 1)?;
242
243        user_registration.display_name = Some(display_name);
244
245        Ok(user_registration)
246    }
247
248    #[tracing::instrument(
249        name = "db.user_registration.set_terms_url",
250        skip_all,
251        fields(
252            db.query.text,
253            user_registration.id = %user_registration.id,
254            user_registration.terms_url = %terms_url,
255        ),
256        err,
257    )]
258    async fn set_terms_url(
259        &mut self,
260        mut user_registration: UserRegistration,
261        terms_url: Url,
262    ) -> Result<UserRegistration, Self::Error> {
263        let res = sqlx::query!(
264            r#"
265                UPDATE user_registrations
266                SET terms_url = $2
267                WHERE user_registration_id = $1 AND completed_at IS NULL
268            "#,
269            Uuid::from(user_registration.id),
270            terms_url.as_str(),
271        )
272        .traced()
273        .execute(&mut *self.conn)
274        .await?;
275
276        DatabaseError::ensure_affected_rows(&res, 1)?;
277
278        user_registration.terms_url = Some(terms_url);
279
280        Ok(user_registration)
281    }
282
283    #[tracing::instrument(
284        name = "db.user_registration.set_email_authentication",
285        skip_all,
286        fields(
287            db.query.text,
288            %user_registration.id,
289            %user_email_authentication.id,
290            %user_email_authentication.email,
291        ),
292        err,
293    )]
294    async fn set_email_authentication(
295        &mut self,
296        mut user_registration: UserRegistration,
297        user_email_authentication: &UserEmailAuthentication,
298    ) -> Result<UserRegistration, Self::Error> {
299        let res = sqlx::query!(
300            r#"
301                UPDATE user_registrations
302                SET email_authentication_id = $2
303                WHERE user_registration_id = $1 AND completed_at IS NULL
304            "#,
305            Uuid::from(user_registration.id),
306            Uuid::from(user_email_authentication.id),
307        )
308        .traced()
309        .execute(&mut *self.conn)
310        .await?;
311
312        DatabaseError::ensure_affected_rows(&res, 1)?;
313
314        user_registration.email_authentication_id = Some(user_email_authentication.id);
315
316        Ok(user_registration)
317    }
318
319    #[tracing::instrument(
320        name = "db.user_registration.set_password",
321        skip_all,
322        fields(
323            db.query.text,
324            user_registration.id = %user_registration.id,
325            user_registration.hashed_password = hashed_password,
326            user_registration.hashed_password_version = version,
327        ),
328        err,
329    )]
330    async fn set_password(
331        &mut self,
332        mut user_registration: UserRegistration,
333        hashed_password: String,
334        version: u16,
335    ) -> Result<UserRegistration, Self::Error> {
336        let res = sqlx::query!(
337            r#"
338                UPDATE user_registrations
339                SET hashed_password = $2, hashed_password_version = $3
340                WHERE user_registration_id = $1 AND completed_at IS NULL
341            "#,
342            Uuid::from(user_registration.id),
343            hashed_password,
344            i32::from(version),
345        )
346        .traced()
347        .execute(&mut *self.conn)
348        .await?;
349
350        DatabaseError::ensure_affected_rows(&res, 1)?;
351
352        user_registration.password = Some(UserRegistrationPassword {
353            hashed_password,
354            version,
355        });
356
357        Ok(user_registration)
358    }
359
360    #[tracing::instrument(
361        name = "db.user_registration.set_registration_token",
362        skip_all,
363        fields(
364            db.query.text,
365            %user_registration.id,
366            %user_registration_token.id,
367        ),
368        err,
369    )]
370    async fn set_registration_token(
371        &mut self,
372        mut user_registration: UserRegistration,
373        user_registration_token: &UserRegistrationToken,
374    ) -> Result<UserRegistration, Self::Error> {
375        let res = sqlx::query!(
376            r#"
377                UPDATE user_registrations
378                SET user_registration_token_id = $2
379                WHERE user_registration_id = $1 AND completed_at IS NULL
380            "#,
381            Uuid::from(user_registration.id),
382            Uuid::from(user_registration_token.id),
383        )
384        .traced()
385        .execute(&mut *self.conn)
386        .await?;
387
388        DatabaseError::ensure_affected_rows(&res, 1)?;
389
390        user_registration.user_registration_token_id = Some(user_registration_token.id);
391
392        Ok(user_registration)
393    }
394
395    #[tracing::instrument(
396        name = "db.user_registration.complete",
397        skip_all,
398        fields(
399            db.query.text,
400            user_registration.id = %user_registration.id,
401        ),
402        err,
403    )]
404    async fn complete(
405        &mut self,
406        clock: &dyn Clock,
407        mut user_registration: UserRegistration,
408    ) -> Result<UserRegistration, Self::Error> {
409        let completed_at = clock.now();
410        let res = sqlx::query!(
411            r#"
412                UPDATE user_registrations
413                SET completed_at = $2
414                WHERE user_registration_id = $1 AND completed_at IS NULL
415            "#,
416            Uuid::from(user_registration.id),
417            completed_at,
418        )
419        .traced()
420        .execute(&mut *self.conn)
421        .await?;
422
423        DatabaseError::ensure_affected_rows(&res, 1)?;
424
425        user_registration.completed_at = Some(completed_at);
426
427        Ok(user_registration)
428    }
429}
430
#[cfg(test)]
mod tests {
    use std::net::{IpAddr, Ipv4Addr};

    use mas_data_model::UserRegistrationPassword;
    use mas_storage::{Clock, clock::MockClock};
    use rand::SeedableRng;
    use rand_chacha::ChaChaRng;
    use sqlx::PgPool;

    use crate::PgRepository;

    /// Create a registration, check every field of the fresh row and of its
    /// lookup round-trip, then check that completion works exactly once.
    #[sqlx::test(migrator = "crate::MIGRATOR")]
    async fn test_create_lookup_complete(pool: PgPool) {
        let mut rng = ChaChaRng::seed_from_u64(42);
        let clock = MockClock::default();

        let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();

        let registration = repo
            .user_registration()
            .add(&mut rng, &clock, "alice".to_owned(), None, None, None)
            .await
            .unwrap();

        assert_eq!(registration.created_at, clock.now());
        assert_eq!(registration.completed_at, None);
        assert_eq!(registration.username, "alice");
        assert_eq!(registration.display_name, None);
        assert_eq!(registration.terms_url, None);
        assert_eq!(registration.email_authentication_id, None);
        assert_eq!(registration.user_registration_token_id, None);
        assert_eq!(registration.password, None);
        assert_eq!(registration.user_agent, None);
        assert_eq!(registration.ip_address, None);
        assert_eq!(registration.post_auth_action, None);

        let lookup = repo
            .user_registration()
            .lookup(registration.id)
            .await
            .unwrap()
            .unwrap();

        assert_eq!(lookup.id, registration.id);
        assert_eq!(lookup.created_at, registration.created_at);
        assert_eq!(lookup.completed_at, registration.completed_at);
        assert_eq!(lookup.username, registration.username);
        assert_eq!(lookup.display_name, registration.display_name);
        assert_eq!(lookup.terms_url, registration.terms_url);
        assert_eq!(
            lookup.email_authentication_id,
            registration.email_authentication_id
        );
        assert_eq!(
            lookup.user_registration_token_id,
            registration.user_registration_token_id
        );
        assert_eq!(lookup.password, registration.password);
        assert_eq!(lookup.user_agent, registration.user_agent);
        assert_eq!(lookup.ip_address, registration.ip_address);
        assert_eq!(lookup.post_auth_action, registration.post_auth_action);

        // Mark the registration as completed
        let registration = repo
            .user_registration()
            .complete(&clock, registration)
            .await
            .unwrap();
        assert_eq!(registration.completed_at, Some(clock.now()));

        // Lookup the registration again
        let lookup = repo
            .user_registration()
            .lookup(registration.id)
            .await
            .unwrap()
            .unwrap();
        assert_eq!(lookup.completed_at, registration.completed_at);

        // Do it again, it should fail
        let res = repo
            .user_registration()
            .complete(&clock, registration)
            .await;
        assert!(res.is_err());
    }

    /// The optional request metadata passed to `add` must survive a lookup.
    #[sqlx::test(migrator = "crate::MIGRATOR")]
    async fn test_create_useragent_ipaddress(pool: PgPool) {
        let mut rng = ChaChaRng::seed_from_u64(42);
        let clock = MockClock::default();

        let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();

        let registration = repo
            .user_registration()
            .add(
                &mut rng,
                &clock,
                "alice".to_owned(),
                Some(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1))),
                Some("Mozilla/5.0".to_owned()),
                Some(serde_json::json!({"action": "continue_compat_sso_login", "id": "01FSHN9AG0MKGTBNZ16RDR3PVY"})),
            )
            .await
            .unwrap();

        assert_eq!(registration.user_agent, Some("Mozilla/5.0".to_owned()));
        assert_eq!(
            registration.ip_address,
            Some(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)))
        );
        assert_eq!(
            registration.post_auth_action,
            Some(
                serde_json::json!({"action": "continue_compat_sso_login", "id": "01FSHN9AG0MKGTBNZ16RDR3PVY"})
            )
        );

        let lookup = repo
            .user_registration()
            .lookup(registration.id)
            .await
            .unwrap()
            .unwrap();

        assert_eq!(lookup.user_agent, registration.user_agent);
        assert_eq!(lookup.ip_address, registration.ip_address);
        assert_eq!(lookup.post_auth_action, registration.post_auth_action);
    }

    /// The display name can be set and overwritten, but not after completion.
    #[sqlx::test(migrator = "crate::MIGRATOR")]
    async fn test_set_display_name(pool: PgPool) {
        let mut rng = ChaChaRng::seed_from_u64(42);
        let clock = MockClock::default();

        let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();

        let registration = repo
            .user_registration()
            .add(&mut rng, &clock, "alice".to_owned(), None, None, None)
            .await
            .unwrap();

        assert_eq!(registration.display_name, None);

        let registration = repo
            .user_registration()
            .set_display_name(registration, "Alice".to_owned())
            .await
            .unwrap();

        assert_eq!(registration.display_name, Some("Alice".to_owned()));

        let lookup = repo
            .user_registration()
            .lookup(registration.id)
            .await
            .unwrap()
            .unwrap();

        assert_eq!(lookup.display_name, registration.display_name);

        // Setting it again should work
        let registration = repo
            .user_registration()
            .set_display_name(registration, "Bob".to_owned())
            .await
            .unwrap();

        assert_eq!(registration.display_name, Some("Bob".to_owned()));

        let lookup = repo
            .user_registration()
            .lookup(registration.id)
            .await
            .unwrap()
            .unwrap();

        assert_eq!(lookup.display_name, registration.display_name);

        // Can't set it once completed
        let registration = repo
            .user_registration()
            .complete(&clock, registration)
            .await
            .unwrap();

        let res = repo
            .user_registration()
            .set_display_name(registration, "Charlie".to_owned())
            .await;
        assert!(res.is_err());
    }

    /// The terms URL can be set and overwritten, but not after completion.
    #[sqlx::test(migrator = "crate::MIGRATOR")]
    async fn test_set_terms_url(pool: PgPool) {
        let mut rng = ChaChaRng::seed_from_u64(42);
        let clock = MockClock::default();

        let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();

        let registration = repo
            .user_registration()
            .add(&mut rng, &clock, "alice".to_owned(), None, None, None)
            .await
            .unwrap();

        assert_eq!(registration.terms_url, None);

        let registration = repo
            .user_registration()
            .set_terms_url(registration, "https://example.com/terms".parse().unwrap())
            .await
            .unwrap();

        assert_eq!(
            registration.terms_url,
            Some("https://example.com/terms".parse().unwrap())
        );

        let lookup = repo
            .user_registration()
            .lookup(registration.id)
            .await
            .unwrap()
            .unwrap();

        assert_eq!(lookup.terms_url, registration.terms_url);

        // Setting it again should work
        let registration = repo
            .user_registration()
            .set_terms_url(registration, "https://example.com/terms2".parse().unwrap())
            .await
            .unwrap();

        assert_eq!(
            registration.terms_url,
            Some("https://example.com/terms2".parse().unwrap())
        );

        let lookup = repo
            .user_registration()
            .lookup(registration.id)
            .await
            .unwrap()
            .unwrap();

        assert_eq!(lookup.terms_url, registration.terms_url);

        // Can't set it once completed
        let registration = repo
            .user_registration()
            .complete(&clock, registration)
            .await
            .unwrap();

        let res = repo
            .user_registration()
            .set_terms_url(registration, "https://example.com/terms3".parse().unwrap())
            .await;
        assert!(res.is_err());
    }

    /// An email authentication can be linked (repeatedly), but not after
    /// completion.
    #[sqlx::test(migrator = "crate::MIGRATOR")]
    async fn test_set_email_authentication(pool: PgPool) {
        let mut rng = ChaChaRng::seed_from_u64(42);
        let clock = MockClock::default();

        let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();

        let registration = repo
            .user_registration()
            .add(&mut rng, &clock, "alice".to_owned(), None, None, None)
            .await
            .unwrap();

        assert_eq!(registration.email_authentication_id, None);

        let authentication = repo
            .user_email()
            .add_authentication_for_registration(
                &mut rng,
                &clock,
                "alice@example.com".to_owned(),
                &registration,
            )
            .await
            .unwrap();

        let registration = repo
            .user_registration()
            .set_email_authentication(registration, &authentication)
            .await
            .unwrap();

        assert_eq!(
            registration.email_authentication_id,
            Some(authentication.id)
        );

        let lookup = repo
            .user_registration()
            .lookup(registration.id)
            .await
            .unwrap()
            .unwrap();

        assert_eq!(
            lookup.email_authentication_id,
            registration.email_authentication_id
        );

        // Setting it again should work
        let registration = repo
            .user_registration()
            .set_email_authentication(registration, &authentication)
            .await
            .unwrap();

        assert_eq!(
            registration.email_authentication_id,
            Some(authentication.id)
        );

        let lookup = repo
            .user_registration()
            .lookup(registration.id)
            .await
            .unwrap()
            .unwrap();

        assert_eq!(
            lookup.email_authentication_id,
            registration.email_authentication_id
        );

        // Can't set it once completed
        let registration = repo
            .user_registration()
            .complete(&clock, registration)
            .await
            .unwrap();

        let res = repo
            .user_registration()
            .set_email_authentication(registration, &authentication)
            .await;
        assert!(res.is_err());
    }

    /// The password hash and its version can be set and overwritten, but not
    /// after completion.
    #[sqlx::test(migrator = "crate::MIGRATOR")]
    async fn test_set_password(pool: PgPool) {
        let mut rng = ChaChaRng::seed_from_u64(42);
        let clock = MockClock::default();

        let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();

        let registration = repo
            .user_registration()
            .add(&mut rng, &clock, "alice".to_owned(), None, None, None)
            .await
            .unwrap();

        assert_eq!(registration.password, None);

        let registration = repo
            .user_registration()
            .set_password(registration, "fakehashedpassword".to_owned(), 1)
            .await
            .unwrap();

        assert_eq!(
            registration.password,
            Some(UserRegistrationPassword {
                hashed_password: "fakehashedpassword".to_owned(),
                version: 1,
            })
        );

        let lookup = repo
            .user_registration()
            .lookup(registration.id)
            .await
            .unwrap()
            .unwrap();

        assert_eq!(lookup.password, registration.password);

        // Setting it again should work
        let registration = repo
            .user_registration()
            .set_password(registration, "fakehashedpassword2".to_owned(), 2)
            .await
            .unwrap();

        assert_eq!(
            registration.password,
            Some(UserRegistrationPassword {
                hashed_password: "fakehashedpassword2".to_owned(),
                version: 2,
            })
        );

        let lookup = repo
            .user_registration()
            .lookup(registration.id)
            .await
            .unwrap()
            .unwrap();

        assert_eq!(lookup.password, registration.password);

        // Can't set it once completed
        let registration = repo
            .user_registration()
            .complete(&clock, registration)
            .await
            .unwrap();

        let res = repo
            .user_registration()
            .set_password(registration, "fakehashedpassword3".to_owned(), 3)
            .await;
        assert!(res.is_err());
    }
}