"""Password-based authentication and authorization setup for the Pyramid app."""

from passlib.hash import argon2
from pyramid.authentication import AuthTktAuthenticationPolicy
from pyramid.authorization import ACLAuthorizationPolicy
from pyramid.exceptions import ConfigurationError
from pyramid.httpexceptions import HTTPFound
from pyramid.security import forget, remember
from pyramid.view import forbidden_view_config, view_config

from . import Root

# The single user id written into every auth ticket by ``login``; there are
# no individual accounts, only "knows a valid password" or not.
AUTHENTICATED_USER_ID = "authenticated"

# Lifetime of a login in seconds. Used both as the cookie ``max_age`` and as
# the auth ticket ``timeout`` so that the server, not just the browser,
# rejects tickets older than this.
SESSION_MAX_AGE = 3600


class MyAuthenticationPolicy(AuthTktAuthenticationPolicy):
    """Auth-ticket policy that maps any accepted ticket to AUTHENTICATED_USER_ID."""

    def authenticated_userid(self, request):
        """Return AUTHENTICATED_USER_ID if the request carries an accepted ticket.

        ``request.user`` is the reified request method registered in
        ``includeme`` (see ``get_user``), i.e. the ticket's unauthenticated
        userid. Returns None when there is no such userid.
        """
        if request.user is not None:
            return AUTHENTICATED_USER_ID
        return None


def get_user(request):
    """Return the userid stored in the request's auth ticket, or None."""
    return request.unauthenticated_userid


@forbidden_view_config(
    renderer="superx_budget:pyramid/templates/login.jinja2",
)
def forbidden_view(request):
    """Render the login form when access is denied (no error message shown)."""
    return {"error": False}


@view_config(
    context=Root,
    name="login",
    request_method="POST",
    permission="login",
    renderer="superx_budget:pyramid/templates/login.jinja2",
)
def login(request):
    """Log in if the POSTed password is valid, otherwise re-render with an error.

    On success, sets the auth ticket cookie and redirects to "/".
    """
    if request.check_password():
        headers = remember(
            request, AUTHENTICATED_USER_ID, max_age=SESSION_MAX_AGE
        )
        return HTTPFound("/", headers=headers)
    return {"error": True}


@view_config(
    context=Root,
    name="logout",
    permission="login",
)
def logout(request):
    """Clear the auth ticket cookie and redirect to "/"."""
    headers = forget(request)
    return HTTPFound("/", headers=headers)


def includeme(config):
    """Install the authn/authz policies and the password-check request methods.

    Settings read:
        ``auth.secret``: secret used to sign auth tickets.
        ``pwd.db``: newline-separated argon2 hashes; a password matching any
            one of them grants access. Blank lines and surrounding whitespace
            are ignored.

    Adds request methods ``check_password()`` and ``user`` (reified).

    Raises:
        ConfigurationError: if a non-blank ``pwd.db`` line is not recognized
            as an argon2 hash.
    """
    settings = config.get_settings()
    authn_policy = MyAuthenticationPolicy(
        settings["auth.secret"],
        hashalg="sha512",
        # The cookie max_age set in ``login`` only tells the browser when to
        # drop the cookie; without a timeout here a copied ticket would be
        # accepted by the server indefinitely.
        timeout=SESSION_MAX_AGE,
    )
    config.set_authentication_policy(authn_policy)
    config.set_authorization_policy(ACLAuthorizationPolicy())

    # Validate hashes once at startup instead of letting argon2.verify raise
    # on every login attempt. Stripping drops whitespace-only/padded lines,
    # which the previous emptiness check let through.
    hashes = []
    for lineno, line in enumerate(settings["pwd.db"].splitlines(), start=1):
        pwd_hash = line.strip()
        if not pwd_hash:
            continue
        if not argon2.identify(pwd_hash):
            # Report only the line number; never echo secret material.
            raise ConfigurationError(
                "pwd.db line %d is not an argon2 hash" % lineno
            )
        hashes.append(pwd_hash)

    def check_password(request):
        """Return True if the POSTed ``password`` matches any configured hash."""
        password = request.POST.get("password", "")
        return any(argon2.verify(password, pwd_hash) for pwd_hash in hashes)

    config.add_request_method(check_password, "check_password")
    config.add_request_method(get_user, "user", reify=True)