diff options
Diffstat (limited to 'app/app_account.py')
-rw-r--r-- | app/app_account.py | 39 |
1 files changed, 27 insertions, 12 deletions
diff --git a/app/app_account.py b/app/app_account.py index 3f4869d..e3d6433 100644 --- a/app/app_account.py +++ b/app/app_account.py | |||
@@ -3,6 +3,7 @@ | |||
3 | # Licence: EUPL-1.2 | 3 | # Licence: EUPL-1.2 |
4 | 4 | ||
5 | from fastapi import APIRouter, Depends, Request, Form, status | 5 | from fastapi import APIRouter, Depends, Request, Form, status |
6 | from fastapi.responses import RedirectResponse, HTMLResponse | ||
6 | 7 | ||
7 | from passlib.context import CryptContext | 8 | from passlib.context import CryptContext |
8 | import re | 9 | import re |
@@ -10,8 +11,9 @@ import re | |||
10 | from embrace.exceptions import IntegrityError | 11 | from embrace.exceptions import IntegrityError |
11 | from psycopg2.errors import UniqueViolation | 12 | from psycopg2.errors import UniqueViolation |
12 | 13 | ||
13 | from app_sessions import UserSession | 14 | from app_sessions import UserSession, FlashMessageQueue |
14 | from app_database import db_transaction | 15 | from app_database import db_transaction |
16 | from app_templating import TemplateRenderer | ||
15 | 17 | ||
16 | 18 | ||
17 | # Password hashing context. | 19 | # Password hashing context. |
@@ -20,42 +22,51 @@ password_ctx = CryptContext(schemes=['bcrypt'], deprecated='auto') | |||
20 | 22 | ||
21 | username_pattern = re.compile(r'^[a-zA-Z0-9-_]{4,16}$') | 23 | username_pattern = re.compile(r'^[a-zA-Z0-9-_]{4,16}$') |
22 | 24 | ||
25 | to_homepage = RedirectResponse('/', status_code=status.HTTP_303_SEE_OTHER) | ||
26 | to_wallet = RedirectResponse('/wallet', status_code=status.HTTP_303_SEE_OTHER) | ||
27 | |||
23 | router = APIRouter() | 28 | router = APIRouter() |
24 | 29 | ||
25 | 30 | ||
26 | @router.get('/') | 31 | @router.get('/', response_class=HTMLResponse) |
27 | def homepage( | 32 | def homepage( |
28 | session: UserSession=Depends(UserSession), | 33 | session: UserSession=Depends(UserSession), |
34 | render: TemplateRenderer=Depends(TemplateRenderer), | ||
29 | ): | 35 | ): |
30 | if session.is_logged_in(): | 36 | if session.is_logged_in(): |
31 | return 'Welcome!' | 37 | return to_wallet |
32 | 38 | ||
33 | return 'Homepage here.' | 39 | return render('homepage.html.jinja') |
34 | 40 | ||
35 | 41 | ||
36 | @router.post('/account/register') | 42 | @router.post('/account/register') |
37 | def account_register( | 43 | def account_register( |
38 | session: UserSession=Depends(UserSession), | 44 | session: UserSession=Depends(UserSession), |
45 | messages: FlashMessageQueue=Depends(FlashMessageQueue), | ||
39 | username: str=Form(...), | 46 | username: str=Form(...), |
40 | password: str=Form(...), | 47 | password: str=Form(...), |
41 | ): | 48 | ): |
42 | try: | 49 | try: |
43 | if username_pattern.match(username) is None: | 50 | if username_pattern.match(username) is None: |
44 | return 'error: Invalid username format.' | 51 | messages.add('error', 'Invalid username format.') |
52 | return to_homepage | ||
45 | 53 | ||
46 | if not 4 <= len(password) <= 32: | 54 | if not 4 <= len(password) <= 32: |
47 | return 'error: Invalid password length.' | 55 | messages.add('error', 'Invalid password length.') |
56 | return to_homepage | ||
48 | 57 | ||
49 | hash = password_ctx.hash(password) | 58 | hash = password_ctx.hash(password) |
50 | with db_transaction() as tx: | 59 | with db_transaction() as tx: |
51 | user = tx.create_account(username=username, password_hash=hash) | 60 | user = tx.create_account(username=username, password_hash=hash) |
52 | 61 | ||
53 | session.login(user.id) | 62 | session.login(user.id) |
54 | return 'Account succesfully created. Welcome!' | 63 | messages.add('success', 'Account succesfully created. Welcome!') |
64 | return to_wallet | ||
55 | 65 | ||
56 | except IntegrityError as exception: | 66 | except IntegrityError as exception: |
57 | if isinstance(exception.__cause__, UniqueViolation): | 67 | if isinstance(exception.__cause__, UniqueViolation): |
58 | return 'error: This username is already taken.' | 68 | messages.add('error', 'This username is already taken.') |
69 | return to_homepage | ||
59 | else: | 70 | else: |
60 | raise exception | 71 | raise exception |
61 | 72 | ||
@@ -63,6 +74,7 @@ def account_register( | |||
63 | @router.post('/account/login') | 74 | @router.post('/account/login') |
64 | def session_login( | 75 | def session_login( |
65 | session: UserSession=Depends(UserSession), | 76 | session: UserSession=Depends(UserSession), |
77 | messages: FlashMessageQueue=Depends(FlashMessageQueue), | ||
66 | username: str=Form(...), | 78 | username: str=Form(...), |
67 | password: str=Form(...), | 79 | password: str=Form(...), |
68 | ): | 80 | ): |
@@ -71,17 +83,20 @@ def session_login( | |||
71 | 83 | ||
72 | if user is not None and password_ctx.verify(password, user.password_hash): | 84 | if user is not None and password_ctx.verify(password, user.password_hash): |
73 | session.login(user.id) | 85 | session.login(user.id) |
74 | return 'Welcome back!' | 86 | messages.add('info', 'Welcome back!') |
87 | return to_wallet | ||
75 | else: | 88 | else: |
76 | return 'error: Invalid credentials.' | 89 | messages.add('error', 'Invalid credentials.') |
90 | return to_homepage | ||
77 | 91 | ||
78 | 92 | ||
79 | @router.post('/account/logout') | 93 | @router.post('/account/logout') |
80 | def session_logout( | 94 | def session_logout( |
81 | session: UserSession=Depends(UserSession), | 95 | session: UserSession=Depends(UserSession), |
96 | messages: FlashMessageQueue=Depends(FlashMessageQueue), | ||
82 | ): | 97 | ): |
83 | if session.is_logged_in(): | 98 | if session.is_logged_in(): |
84 | session.logout() | 99 | session.logout() |
85 | return 'You have been successfully logged out.' | 100 | messages.add('info', 'You have been successfully logged out.') |
86 | 101 | ||
87 | return 'Nothing to do' | 102 | return to_homepage |