diff --git a/ordr3/adapters.py b/ordr3/adapters.py index 05c9366..3591240 100644 --- a/ordr3/adapters.py +++ b/ordr3/adapters.py @@ -81,6 +81,14 @@ user_table = Table( Column("role", Enum(models.UserRole)), ) +reset_token_table = Table( + "reset_tokens", + metadata, + Column("token", Text, primary_key=True), + Column("user_id", Integer, nullable=False), + Column("valid_unitl", DateTime, nullable=False), +) + def start_mappers(): """ maps data base tables to model objects """ @@ -96,6 +104,7 @@ def start_mappers(): mapper(models.LogEntry, log_table) mapper(models.Vendor, vendor_table) mapper(models.User, user_table) + mapper(models.PasswordResetToken, reset_token_table) def get_engine(settings, prefix="sqlalchemy."): diff --git a/ordr3/models.py b/ordr3/models.py index 21b33d3..120e3ea 100644 --- a/ordr3/models.py +++ b/ordr3/models.py @@ -1,5 +1,5 @@ import enum -from datetime import datetime +from datetime import datetime, timedelta @enum.unique @@ -202,3 +202,30 @@ class User(Model): def __repr__(self): return f"" + + +class PasswordResetToken(Model): + + token = None + user_id = None + valid_until = None + + def __init__(self, token, user_id, valid_until=None): + self.token = token + self.user_id = user_id + defaul_valid_until = datetime.utcnow() + timedelta(hours=1) + self.valid_until = valid_until or defaul_valid_until + + @property + def is_valid(self): + try: + return datetime.utcnow() < self.valid_until + except TypeError: + return False + + def __repr__(self): + date_str = self.valid_until.strftime("%Y-%m-%d %H:%M:%S") + return ( + f"" + ) diff --git a/ordr3/repo.py b/ordr3/repo.py index 2c8fd96..d603df6 100644 --- a/ordr3/repo.py +++ b/ordr3/repo.py @@ -1,6 +1,7 @@ """ Classes for acessing a datastore """ import abc +from datetime import datetime from sqlalchemy import func from sqlalchemy.orm.exc import NoResultFound @@ -55,6 +56,22 @@ class AbstractOrderRepository(abc.ABC): def search_vendor(self, reference): """ search for a vendor by its canonical name """ + @abc.abstractmethod + def add_reset_token(self, token): + """ add an password reset token """ + + @abc.abstractmethod + def get_reset_token(self, reference): + """ add an password reset token """ + + @abc.abstractmethod + def delete_reset_token(self, token): + """ deletes a password reset token """ + + @abc.abstractmethod + def clean_stale_reset_tokens(self): + """ removes invalid reset tokens """ + class SqlAlchemyRepository(AbstractOrderRepository): """ Repository implementation for SQLAlchemy """ @@ -64,6 +81,11 @@ class SqlAlchemyRepository(AbstractOrderRepository): self.session.add(item) self.session.flush() + def _delete_item_from_db(self, item): + """ add any item to the database """ + self.session.delete(item) + self.session.flush() + def add_order(self, order): """ add an order to the database """ self._add_item_to_db(order) @@ -140,3 +162,31 @@ class SqlAlchemyRepository(AbstractOrderRepository): if vendor is None: return None return vendor.name + + def add_reset_token(self, token): + """ add an password reset token """ + self._add_item_to_db(token) + + def delete_reset_token(self, token): + """ deletes a password reset token """ + self._delete_item_from_db(token) + + def get_reset_token(self, reference): + """ get a passowrd reset token from the database""" + try: + return ( + self.session.query(models.PassworResetToken) + .filter( + func.lower(models.PassworResetToken.token) + == func.lower(reference) + ) + .one() + ) + except NoResultFound as exc: + raise RepoItemNotFound from exc + + def clean_stale_reset_tokens(self): + """ removes invalid reset tokens """ + self.session.delete(models.PassworResetToken).filter( + models.PassworResetToken.valid_unit < datetime.utcnow() + ) diff --git a/ordr3/services.py b/ordr3/services.py index c75a9ae..b1365ea 100644 --- a/ordr3/services.py +++ b/ordr3/services.py @@ -149,3 +149,13 @@ def _check_have_i_been_pwned(password_hash, event_queue): event_queue.emit(MSG_PWNED_PASSWORD) return True return False + + +def get_user_from_reset_token(repo, identifier): + try: + token = repo.get_reset_token(identifier) + if token.is_valid: + return repo.get_user(token.user_id) + except StopIteration: + pass + return None diff --git a/tests/test_models.py b/tests/test_models.py index c4d7d8a..a2a16bb 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -163,3 +163,65 @@ def test_user_role_rincipal(): from ordr3.models import UserRole assert UserRole.INACTIVE.principal == "role:inactive" + + +def test_password_reset_token_init(): + from ordr3.models import PasswordResetToken + + token = PasswordResetToken(*list("ABC")) + + assert token.token == "A" + assert token.user_id == "B" + assert token.valid_until == "C" + + +def test_password_reset_token_init_auto_validity(): + from ordr3.models import PasswordResetToken + from datetime import datetime, timedelta + + token = PasswordResetToken(*list("AB")) + + assert token.token == "A" + assert token.user_id == "B" + assert token.valid_until - datetime.utcnow() <= timedelta(hours=1) + + +def test_password_reset_token_is_valid_ok(): + from ordr3.models import PasswordResetToken + + token = PasswordResetToken(*list("ABC")) + + assert not token.is_valid + + +def test_password_reset_token_is_valid_not_ok(): + from ordr3.models import PasswordResetToken + from datetime import datetime, timedelta + + invalid_time = datetime.utcnow() - timedelta(hours=1) + token = PasswordResetToken(*list("AB"), invalid_time) + + assert not token.is_valid + + +def test_password_reset_token_is_valid_invalid_time_format(): + from ordr3.models import PasswordResetToken + + token = PasswordResetToken(*list("ABC")) + + assert not token.is_valid + + +def test_password_reset_token_repr(): + from ordr3.models import PasswordResetToken + from datetime import datetime + + valid_until = datetime(2020, 4, 2, 10, 11, 12) + token = PasswordResetToken(*list("AB"), valid_until) + + result = repr(token) + + assert result == ( + "" + ) diff --git a/tests/test_services.py b/tests/test_services.py index 2270b58..909c608 100644 --- a/tests/test_services.py +++ b/tests/test_services.py @@ -12,6 +12,7 @@ class FakeOrderRepository(AbstractOrderRepository): self._orders = set() self._users = set() self._vendors = {"sa": "Sigma Aldrich"} + self._tokens = set() def add_order(self, order): """ add an order to the datastore """ @@ -49,6 +50,23 @@ class FakeOrderRepository(AbstractOrderRepository): """ search for a vendor by a canonical search term """ return self._vendors.get(reference, None) + def add_reset_token(self, token): + """ add an password reset token """ + self._tokens.add(token) + + def delete_reset_token(self, token): + """ deletes a password reset token """ + self._tokens.remove(token) + + def get_reset_token(self, reference): + """ add an password reset token """ + return next(t for t in self._tokens if t.token == reference) + + def clean_stale_reset_tokens(self): + """ removes invalid reset tokens """ + now = datetime.utcnow() + self._tokens = {t for t in self._tokens if t.valid_until > now} + class FakePasslibContext: def __init__(self, needs_update): @@ -323,3 +341,65 @@ def test_set_new_password_to_short_and_breached(monkeypatch): assert get_passlib_context().verify("1", user.password) assert len(queue) == 1 # only one item in que due to monkeypatch assert queue[0].text.startswith("Your password is quite short") + + +def test_get_user_from_reset_token_ok(): + from ordr3 import services + from ordr3.models import PasswordResetToken, User + + repo = FakeOrderRepository(None) + user = User(*list("ABCDEFG")) + repo.add_user(user) + token = PasswordResetToken("identifier", "A") + repo.add_reset_token(token) + + result = services.get_user_from_reset_token(repo, "identifier") + + assert result == user + + +def test_get_user_from_reset_token_wrong_token(): + from ordr3 import services + from ordr3.models import PasswordResetToken, User + + repo = FakeOrderRepository(None) + user = User(*list("ABCDEFG")) + repo.add_user(user) + token = PasswordResetToken("identifier", "A") + repo.add_reset_token(token) + + result = services.get_user_from_reset_token(repo, "wrong identifier") + + assert result is None + + +def test_get_user_from_reset_token_invalid_token(): + from ordr3 import services + from ordr3.models import PasswordResetToken, User + from datetime import datetime, timedelta + + repo = FakeOrderRepository(None) + user = User(*list("ABCDEFG")) + repo.add_user(user) + valid_until = datetime.now() - timedelta(hours=2) + token = PasswordResetToken("identifier", "A", valid_until) + repo.add_reset_token(token) + + result = services.get_user_from_reset_token(repo, "identifier") + + assert result is None + + +def test_get_user_from_reset_token_unknown_user(): + from ordr3 import services + from ordr3.models import PasswordResetToken, User + + repo = FakeOrderRepository(None) + user = User(*list("ABCDEFG")) + repo.add_user(user) + token = PasswordResetToken("identifier", "B") + repo.add_reset_token(token) + + result = services.get_user_from_reset_token(repo, "identifier") + + assert result is None