1use std::net::IpAddr;
7
8use async_trait::async_trait;
9use chrono::{DateTime, Utc};
10use mas_data_model::{UserEmailAuthentication, UserRegistration, UserRegistrationPassword};
11use mas_storage::{Clock, user::UserRegistrationRepository};
12use rand::RngCore;
13use sqlx::PgConnection;
14use ulid::Ulid;
15use url::Url;
16use uuid::Uuid;
17
18use crate::{DatabaseError, DatabaseInconsistencyError, ExecuteExt as _};
19
20pub struct PgUserRegistrationRepository<'c> {
23 conn: &'c mut PgConnection,
24}
25
26impl<'c> PgUserRegistrationRepository<'c> {
27 pub fn new(conn: &'c mut PgConnection) -> Self {
30 Self { conn }
31 }
32}
33
/// Raw row of the `user_registrations` table, as returned by the `SELECT` in
/// `PgUserRegistrationRepository::lookup`.
///
/// It is filled by `sqlx::query_as!`, which maps columns to fields by name, so
/// the field names and types must stay in sync with that query. It is turned
/// into a [`UserRegistration`] through the `TryFrom` implementation below.
struct UserRegistrationLookup {
    user_registration_id: Uuid,
    // Selected with an explicit `"ip_address: IpAddr"` type override
    ip_address: Option<IpAddr>,
    user_agent: Option<String>,
    post_auth_action: Option<serde_json::Value>,
    username: String,
    display_name: Option<String>,
    // Stored as text; parsed into a `Url` during conversion
    terms_url: Option<String>,
    email_authentication_id: Option<Uuid>,
    // `hashed_password` and `hashed_password_version` are expected to be
    // either both set or both unset; the conversion rejects mixed states
    hashed_password: Option<String>,
    hashed_password_version: Option<i32>,
    created_at: DateTime<Utc>,
    completed_at: Option<DateTime<Utc>>,
}
48
49impl TryFrom<UserRegistrationLookup> for UserRegistration {
50 type Error = DatabaseInconsistencyError;
51
52 fn try_from(value: UserRegistrationLookup) -> Result<Self, Self::Error> {
53 let id = Ulid::from(value.user_registration_id);
54
55 let password = match (value.hashed_password, value.hashed_password_version) {
56 (Some(hashed_password), Some(version)) => {
57 let version = version.try_into().map_err(|e| {
58 DatabaseInconsistencyError::on("user_registrations")
59 .column("hashed_password_version")
60 .row(id)
61 .source(e)
62 })?;
63
64 Some(UserRegistrationPassword {
65 hashed_password,
66 version,
67 })
68 }
69 (None, None) => None,
70 _ => {
71 return Err(DatabaseInconsistencyError::on("user_registrations")
72 .column("hashed_password")
73 .row(id));
74 }
75 };
76
77 let terms_url = value
78 .terms_url
79 .map(|u| u.parse())
80 .transpose()
81 .map_err(|e| {
82 DatabaseInconsistencyError::on("user_registrations")
83 .column("terms_url")
84 .row(id)
85 .source(e)
86 })?;
87
88 Ok(UserRegistration {
89 id,
90 ip_address: value.ip_address,
91 user_agent: value.user_agent,
92 post_auth_action: value.post_auth_action,
93 username: value.username,
94 display_name: value.display_name,
95 terms_url,
96 email_authentication_id: value.email_authentication_id.map(Ulid::from),
97 password,
98 created_at: value.created_at,
99 completed_at: value.completed_at,
100 })
101 }
102}
103
104#[async_trait]
105impl UserRegistrationRepository for PgUserRegistrationRepository<'_> {
106 type Error = DatabaseError;
107
108 #[tracing::instrument(
109 name = "db.user_registration.lookup",
110 skip_all,
111 fields(
112 db.query.text,
113 user_registration.id = %id,
114 ),
115 err,
116 )]
117 async fn lookup(&mut self, id: Ulid) -> Result<Option<UserRegistration>, Self::Error> {
118 let res = sqlx::query_as!(
119 UserRegistrationLookup,
120 r#"
121 SELECT user_registration_id
122 , ip_address as "ip_address: IpAddr"
123 , user_agent
124 , post_auth_action
125 , username
126 , display_name
127 , terms_url
128 , email_authentication_id
129 , hashed_password
130 , hashed_password_version
131 , created_at
132 , completed_at
133 FROM user_registrations
134 WHERE user_registration_id = $1
135 "#,
136 Uuid::from(id),
137 )
138 .traced()
139 .fetch_optional(&mut *self.conn)
140 .await?;
141
142 let Some(res) = res else { return Ok(None) };
143
144 Ok(Some(res.try_into()?))
145 }
146
147 #[tracing::instrument(
148 name = "db.user_registration.add",
149 skip_all,
150 fields(
151 db.query.text,
152 user_registration.id,
153 ),
154 err,
155 )]
156 async fn add(
157 &mut self,
158 rng: &mut (dyn RngCore + Send),
159 clock: &dyn Clock,
160 username: String,
161 ip_address: Option<IpAddr>,
162 user_agent: Option<String>,
163 post_auth_action: Option<serde_json::Value>,
164 ) -> Result<UserRegistration, Self::Error> {
165 let created_at = clock.now();
166 let id = Ulid::from_datetime_with_source(created_at.into(), rng);
167 tracing::Span::current().record("user_registration.id", tracing::field::display(id));
168
169 sqlx::query!(
170 r#"
171 INSERT INTO user_registrations
172 ( user_registration_id
173 , ip_address
174 , user_agent
175 , post_auth_action
176 , username
177 , created_at
178 )
179 VALUES ($1, $2, $3, $4, $5, $6)
180 "#,
181 Uuid::from(id),
182 ip_address as Option<IpAddr>,
183 user_agent.as_deref(),
184 post_auth_action,
185 username,
186 created_at,
187 )
188 .traced()
189 .execute(&mut *self.conn)
190 .await?;
191
192 Ok(UserRegistration {
193 id,
194 ip_address,
195 user_agent,
196 post_auth_action,
197 created_at,
198 completed_at: None,
199 username,
200 display_name: None,
201 terms_url: None,
202 email_authentication_id: None,
203 password: None,
204 })
205 }
206
207 #[tracing::instrument(
208 name = "db.user_registration.set_display_name",
209 skip_all,
210 fields(
211 db.query.text,
212 user_registration.id = %user_registration.id,
213 user_registration.display_name = display_name,
214 ),
215 err,
216 )]
217 async fn set_display_name(
218 &mut self,
219 mut user_registration: UserRegistration,
220 display_name: String,
221 ) -> Result<UserRegistration, Self::Error> {
222 let res = sqlx::query!(
223 r#"
224 UPDATE user_registrations
225 SET display_name = $2
226 WHERE user_registration_id = $1 AND completed_at IS NULL
227 "#,
228 Uuid::from(user_registration.id),
229 display_name,
230 )
231 .traced()
232 .execute(&mut *self.conn)
233 .await?;
234
235 DatabaseError::ensure_affected_rows(&res, 1)?;
236
237 user_registration.display_name = Some(display_name);
238
239 Ok(user_registration)
240 }
241
242 #[tracing::instrument(
243 name = "db.user_registration.set_terms_url",
244 skip_all,
245 fields(
246 db.query.text,
247 user_registration.id = %user_registration.id,
248 user_registration.terms_url = %terms_url,
249 ),
250 err,
251 )]
252 async fn set_terms_url(
253 &mut self,
254 mut user_registration: UserRegistration,
255 terms_url: Url,
256 ) -> Result<UserRegistration, Self::Error> {
257 let res = sqlx::query!(
258 r#"
259 UPDATE user_registrations
260 SET terms_url = $2
261 WHERE user_registration_id = $1 AND completed_at IS NULL
262 "#,
263 Uuid::from(user_registration.id),
264 terms_url.as_str(),
265 )
266 .traced()
267 .execute(&mut *self.conn)
268 .await?;
269
270 DatabaseError::ensure_affected_rows(&res, 1)?;
271
272 user_registration.terms_url = Some(terms_url);
273
274 Ok(user_registration)
275 }
276
277 #[tracing::instrument(
278 name = "db.user_registration.set_email_authentication",
279 skip_all,
280 fields(
281 db.query.text,
282 %user_registration.id,
283 %user_email_authentication.id,
284 %user_email_authentication.email,
285 ),
286 err,
287 )]
288 async fn set_email_authentication(
289 &mut self,
290 mut user_registration: UserRegistration,
291 user_email_authentication: &UserEmailAuthentication,
292 ) -> Result<UserRegistration, Self::Error> {
293 let res = sqlx::query!(
294 r#"
295 UPDATE user_registrations
296 SET email_authentication_id = $2
297 WHERE user_registration_id = $1 AND completed_at IS NULL
298 "#,
299 Uuid::from(user_registration.id),
300 Uuid::from(user_email_authentication.id),
301 )
302 .traced()
303 .execute(&mut *self.conn)
304 .await?;
305
306 DatabaseError::ensure_affected_rows(&res, 1)?;
307
308 user_registration.email_authentication_id = Some(user_email_authentication.id);
309
310 Ok(user_registration)
311 }
312
313 #[tracing::instrument(
314 name = "db.user_registration.set_password",
315 skip_all,
316 fields(
317 db.query.text,
318 user_registration.id = %user_registration.id,
319 user_registration.hashed_password = hashed_password,
320 user_registration.hashed_password_version = version,
321 ),
322 err,
323 )]
324 async fn set_password(
325 &mut self,
326 mut user_registration: UserRegistration,
327 hashed_password: String,
328 version: u16,
329 ) -> Result<UserRegistration, Self::Error> {
330 let res = sqlx::query!(
331 r#"
332 UPDATE user_registrations
333 SET hashed_password = $2, hashed_password_version = $3
334 WHERE user_registration_id = $1 AND completed_at IS NULL
335 "#,
336 Uuid::from(user_registration.id),
337 hashed_password,
338 i32::from(version),
339 )
340 .traced()
341 .execute(&mut *self.conn)
342 .await?;
343
344 DatabaseError::ensure_affected_rows(&res, 1)?;
345
346 user_registration.password = Some(UserRegistrationPassword {
347 hashed_password,
348 version,
349 });
350
351 Ok(user_registration)
352 }
353
354 #[tracing::instrument(
355 name = "db.user_registration.complete",
356 skip_all,
357 fields(
358 db.query.text,
359 user_registration.id = %user_registration.id,
360 ),
361 err,
362 )]
363 async fn complete(
364 &mut self,
365 clock: &dyn Clock,
366 mut user_registration: UserRegistration,
367 ) -> Result<UserRegistration, Self::Error> {
368 let completed_at = clock.now();
369 let res = sqlx::query!(
370 r#"
371 UPDATE user_registrations
372 SET completed_at = $2
373 WHERE user_registration_id = $1 AND completed_at IS NULL
374 "#,
375 Uuid::from(user_registration.id),
376 completed_at,
377 )
378 .traced()
379 .execute(&mut *self.conn)
380 .await?;
381
382 DatabaseError::ensure_affected_rows(&res, 1)?;
383
384 user_registration.completed_at = Some(completed_at);
385
386 Ok(user_registration)
387 }
388}
389
#[cfg(test)]
mod tests {
    use std::net::{IpAddr, Ipv4Addr};

    use mas_data_model::UserRegistrationPassword;
    use mas_storage::{Clock, clock::MockClock};
    use rand::SeedableRng;
    use rand_chacha::ChaChaRng;
    use sqlx::PgPool;

    use crate::PgRepository;

    /// Create a registration, look it up, complete it, and check that
    /// completing it a second time fails.
    #[sqlx::test(migrator = "crate::MIGRATOR")]
    async fn test_create_lookup_complete(pool: PgPool) {
        let mut rng = ChaChaRng::seed_from_u64(42);
        let clock = MockClock::default();

        let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();

        let registration = repo
            .user_registration()
            .add(&mut rng, &clock, "alice".to_owned(), None, None, None)
            .await
            .unwrap();

        assert_eq!(registration.created_at, clock.now());
        assert_eq!(registration.completed_at, None);
        assert_eq!(registration.username, "alice");
        assert_eq!(registration.display_name, None);
        assert_eq!(registration.terms_url, None);
        assert_eq!(registration.email_authentication_id, None);
        assert_eq!(registration.password, None);
        assert_eq!(registration.user_agent, None);
        assert_eq!(registration.ip_address, None);
        assert_eq!(registration.post_auth_action, None);

        let lookup = repo
            .user_registration()
            .lookup(registration.id)
            .await
            .unwrap()
            .unwrap();

        assert_eq!(lookup.id, registration.id);
        assert_eq!(lookup.created_at, registration.created_at);
        assert_eq!(lookup.completed_at, registration.completed_at);
        assert_eq!(lookup.username, registration.username);
        assert_eq!(lookup.display_name, registration.display_name);
        assert_eq!(lookup.terms_url, registration.terms_url);
        assert_eq!(
            lookup.email_authentication_id,
            registration.email_authentication_id
        );
        assert_eq!(lookup.password, registration.password);
        assert_eq!(lookup.user_agent, registration.user_agent);
        assert_eq!(lookup.ip_address, registration.ip_address);
        assert_eq!(lookup.post_auth_action, registration.post_auth_action);

        let registration = repo
            .user_registration()
            .complete(&clock, registration)
            .await
            .unwrap();
        assert_eq!(registration.completed_at, Some(clock.now()));

        let lookup = repo
            .user_registration()
            .lookup(registration.id)
            .await
            .unwrap()
            .unwrap();
        assert_eq!(lookup.completed_at, registration.completed_at);

        // Completing an already-completed registration must fail
        let res = repo
            .user_registration()
            .complete(&clock, registration)
            .await;
        assert!(res.is_err());
    }

    /// Check that the user agent, IP address and post-auth action are stored
    /// and read back.
    #[sqlx::test(migrator = "crate::MIGRATOR")]
    async fn test_create_useragent_ipaddress(pool: PgPool) {
        let mut rng = ChaChaRng::seed_from_u64(42);
        let clock = MockClock::default();

        let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();

        let registration = repo
            .user_registration()
            .add(
                &mut rng,
                &clock,
                "alice".to_owned(),
                Some(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1))),
                Some("Mozilla/5.0".to_owned()),
                Some(serde_json::json!({"action": "continue_compat_sso_login", "id": "01FSHN9AG0MKGTBNZ16RDR3PVY"})),
            )
            .await
            .unwrap();

        assert_eq!(registration.user_agent, Some("Mozilla/5.0".to_owned()));
        assert_eq!(
            registration.ip_address,
            Some(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)))
        );
        assert_eq!(
            registration.post_auth_action,
            Some(
                serde_json::json!({"action": "continue_compat_sso_login", "id": "01FSHN9AG0MKGTBNZ16RDR3PVY"})
            )
        );

        let lookup = repo
            .user_registration()
            .lookup(registration.id)
            .await
            .unwrap()
            .unwrap();

        assert_eq!(lookup.user_agent, registration.user_agent);
        assert_eq!(lookup.ip_address, registration.ip_address);
        assert_eq!(lookup.post_auth_action, registration.post_auth_action);
    }

    /// The display name can be set and overwritten until the registration is
    /// completed.
    #[sqlx::test(migrator = "crate::MIGRATOR")]
    async fn test_set_display_name(pool: PgPool) {
        let mut rng = ChaChaRng::seed_from_u64(42);
        let clock = MockClock::default();

        let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();

        let registration = repo
            .user_registration()
            .add(&mut rng, &clock, "alice".to_owned(), None, None, None)
            .await
            .unwrap();

        assert_eq!(registration.display_name, None);

        let registration = repo
            .user_registration()
            .set_display_name(registration, "Alice".to_owned())
            .await
            .unwrap();

        assert_eq!(registration.display_name, Some("Alice".to_owned()));

        let lookup = repo
            .user_registration()
            .lookup(registration.id)
            .await
            .unwrap()
            .unwrap();

        assert_eq!(lookup.display_name, registration.display_name);

        let registration = repo
            .user_registration()
            .set_display_name(registration, "Bob".to_owned())
            .await
            .unwrap();

        assert_eq!(registration.display_name, Some("Bob".to_owned()));

        let lookup = repo
            .user_registration()
            .lookup(registration.id)
            .await
            .unwrap()
            .unwrap();

        assert_eq!(lookup.display_name, registration.display_name);

        let registration = repo
            .user_registration()
            .complete(&clock, registration)
            .await
            .unwrap();

        let res = repo
            .user_registration()
            .set_display_name(registration, "Charlie".to_owned())
            .await;
        assert!(res.is_err());
    }

    /// The terms URL can be set and overwritten until the registration is
    /// completed.
    #[sqlx::test(migrator = "crate::MIGRATOR")]
    async fn test_set_terms_url(pool: PgPool) {
        let mut rng = ChaChaRng::seed_from_u64(42);
        let clock = MockClock::default();

        let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();

        let registration = repo
            .user_registration()
            .add(&mut rng, &clock, "alice".to_owned(), None, None, None)
            .await
            .unwrap();

        assert_eq!(registration.terms_url, None);

        let registration = repo
            .user_registration()
            .set_terms_url(registration, "https://example.com/terms".parse().unwrap())
            .await
            .unwrap();

        assert_eq!(
            registration.terms_url,
            Some("https://example.com/terms".parse().unwrap())
        );

        let lookup = repo
            .user_registration()
            .lookup(registration.id)
            .await
            .unwrap()
            .unwrap();

        assert_eq!(lookup.terms_url, registration.terms_url);

        let registration = repo
            .user_registration()
            .set_terms_url(registration, "https://example.com/terms2".parse().unwrap())
            .await
            .unwrap();

        assert_eq!(
            registration.terms_url,
            Some("https://example.com/terms2".parse().unwrap())
        );

        let lookup = repo
            .user_registration()
            .lookup(registration.id)
            .await
            .unwrap()
            .unwrap();

        assert_eq!(lookup.terms_url, registration.terms_url);

        let registration = repo
            .user_registration()
            .complete(&clock, registration)
            .await
            .unwrap();

        let res = repo
            .user_registration()
            .set_terms_url(registration, "https://example.com/terms3".parse().unwrap())
            .await;
        assert!(res.is_err());
    }

    /// An email authentication can be linked, and re-linked, until the
    /// registration is completed.
    #[sqlx::test(migrator = "crate::MIGRATOR")]
    async fn test_set_email_authentication(pool: PgPool) {
        let mut rng = ChaChaRng::seed_from_u64(42);
        let clock = MockClock::default();

        let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();

        let registration = repo
            .user_registration()
            .add(&mut rng, &clock, "alice".to_owned(), None, None, None)
            .await
            .unwrap();

        assert_eq!(registration.email_authentication_id, None);

        let authentication = repo
            .user_email()
            .add_authentication_for_registration(
                &mut rng,
                &clock,
                "alice@example.com".to_owned(),
                &registration,
            )
            .await
            .unwrap();

        let registration = repo
            .user_registration()
            .set_email_authentication(registration, &authentication)
            .await
            .unwrap();

        assert_eq!(
            registration.email_authentication_id,
            Some(authentication.id)
        );

        let lookup = repo
            .user_registration()
            .lookup(registration.id)
            .await
            .unwrap()
            .unwrap();

        assert_eq!(
            lookup.email_authentication_id,
            registration.email_authentication_id
        );

        let registration = repo
            .user_registration()
            .set_email_authentication(registration, &authentication)
            .await
            .unwrap();

        assert_eq!(
            registration.email_authentication_id,
            Some(authentication.id)
        );

        let lookup = repo
            .user_registration()
            .lookup(registration.id)
            .await
            .unwrap()
            .unwrap();

        assert_eq!(
            lookup.email_authentication_id,
            registration.email_authentication_id
        );

        let registration = repo
            .user_registration()
            .complete(&clock, registration)
            .await
            .unwrap();

        let res = repo
            .user_registration()
            .set_email_authentication(registration, &authentication)
            .await;
        assert!(res.is_err());
    }

    /// The password hash and version can be set and overwritten until the
    /// registration is completed.
    #[sqlx::test(migrator = "crate::MIGRATOR")]
    async fn test_set_password(pool: PgPool) {
        let mut rng = ChaChaRng::seed_from_u64(42);
        let clock = MockClock::default();

        let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();

        let registration = repo
            .user_registration()
            .add(&mut rng, &clock, "alice".to_owned(), None, None, None)
            .await
            .unwrap();

        assert_eq!(registration.password, None);

        let registration = repo
            .user_registration()
            .set_password(registration, "fakehashedpassword".to_owned(), 1)
            .await
            .unwrap();

        assert_eq!(
            registration.password,
            Some(UserRegistrationPassword {
                hashed_password: "fakehashedpassword".to_owned(),
                version: 1,
            })
        );

        let lookup = repo
            .user_registration()
            .lookup(registration.id)
            .await
            .unwrap()
            .unwrap();

        assert_eq!(lookup.password, registration.password);

        let registration = repo
            .user_registration()
            .set_password(registration, "fakehashedpassword2".to_owned(), 2)
            .await
            .unwrap();

        assert_eq!(
            registration.password,
            Some(UserRegistrationPassword {
                hashed_password: "fakehashedpassword2".to_owned(),
                version: 2,
            })
        );

        let lookup = repo
            .user_registration()
            .lookup(registration.id)
            .await
            .unwrap()
            .unwrap();

        assert_eq!(lookup.password, registration.password);

        let registration = repo
            .user_registration()
            .complete(&clock, registration)
            .await
            .unwrap();

        let res = repo
            .user_registration()
            .set_password(registration, "fakehashedpassword3".to_owned(), 3)
            .await;
        assert!(res.is_err());
    }
}