You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
241 lines
7.3 KiB
241 lines
7.3 KiB
""" Classes for acessing a datastore """ |
|
|
|
import abc |
|
from datetime import datetime |
|
|
|
from sqlalchemy import func |
|
from sqlalchemy.orm.exc import NoResultFound |
|
|
|
from . import models |
|
|
|
|
|
class RepoItemNotFound(StopIteration):
    """Repository error: the requested item was not found.

    The exception derives from ``StopIteration``; that base class is kept
    unchanged because callers may catch it under that name.

    NOTE: because of this base class, an instance that escapes from inside a
    generator body is converted to ``RuntimeError`` (PEP 479), and one that
    escapes from a ``__next__`` method ends the iteration silently.  Catch it
    explicitly before it can reach such code.
    """
|
|
|
|
|
class AbstractOrderRepository(abc.ABC):
    """Abstract base class for a datastore.

    A concrete repository has to implement every abstract method declared
    here before it can be instantiated.  The session object passed to the
    constructor is stored as ``self.session``.
    """

    def __init__(self, session):
        """Remember the session the repository operates on."""
        self.session = session

    # --- orders -----------------------------------------------------------

    @abc.abstractmethod
    def add_order(self, order):
        """Add an order to the datastore."""

    @abc.abstractmethod
    def get_order(self, reference):
        """Get an order from the datastore by primary key."""

    @abc.abstractmethod
    def delete_order(self, order):
        """Remove an order from the datastore."""

    @abc.abstractmethod
    def list_consumable_candidates(self, limit_date, statuses):
        """List orders that might be consumables."""

    # --- users ------------------------------------------------------------

    @abc.abstractmethod
    def add_user(self, user):
        """Add a user to the datastore."""

    @abc.abstractmethod
    def delete_user(self, user):
        """Remove a user from the datastore."""

    @abc.abstractmethod
    def get_user(self, reference):
        """Get a user from the datastore by primary key."""

    @abc.abstractmethod
    def get_user_by_username(self, reference):
        """Get a user from the datastore by username."""

    @abc.abstractmethod
    def get_user_by_email(self, reference):
        """Get a user from the datastore by email."""

    # --- vendors ----------------------------------------------------------

    @abc.abstractmethod
    def search_vendor(self, reference):
        """Search for the vendor entry registered for a term."""

    @abc.abstractmethod
    def get_vendor_aggregates(self, reference):
        """Get all terms registered for a vendor's canonical name."""

    @abc.abstractmethod
    def update_vendors(self, old_vendor, new_name, new_terms):
        """Update the autocorrect values of vendors."""

    # --- password reset tokens --------------------------------------------

    @abc.abstractmethod
    def add_reset_token(self, token):
        """Add a password reset token."""

    @abc.abstractmethod
    def get_reset_token(self, reference):
        """Get a password reset token."""

    @abc.abstractmethod
    def delete_reset_token(self, token):
        """Delete a password reset token."""

    @abc.abstractmethod
    def clear_stale_reset_tokens(self):
        """Remove invalid reset tokens."""
|
|
|
|
|
class SqlAlchemyRepository(AbstractOrderRepository):
    """Repository implementation for SQLAlchemy.

    Every operation goes through ``self.session``.  Lookups that expect
    exactly one row raise ``RepoItemNotFound`` when no row matches.
    """

    # --- private helpers --------------------------------------------------

    def _add_item_to_db(self, item):
        """Add any item to the session and flush the session."""
        self.session.add(item)
        self.session.flush()

    def _delete_item_from_db(self, item):
        """Delete any item from the session and flush the session."""
        self.session.delete(item)
        self.session.flush()

    @staticmethod
    def _one_or_not_found(query):
        """Return the single result of ``query``.

        Raises:
            RepoItemNotFound: if the query yields no row.  Any other error
                raised by ``query.one()`` (for example when several rows
                match) is not translated and propagates unchanged.
        """
        try:
            return query.one()
        except NoResultFound as exc:
            raise RepoItemNotFound from exc

    # --- orders -----------------------------------------------------------

    def add_order(self, order):
        """Add an order to the database."""
        self._add_item_to_db(order)

    def delete_order(self, order):
        """Remove an order, together with its log entries, from the database."""
        # the log entries are deleted explicitly before the order itself
        for log_entry in order.log:
            self.session.delete(log_entry)
        self._delete_item_from_db(order)

    def get_order(self, reference):
        """Get an order from the database by primary key.

        Raises:
            RepoItemNotFound: if no order has the id ``reference``.
        """
        query = self.session.query(models.OrderItem).filter_by(id=reference)
        return self._one_or_not_found(query)

    def list_consumable_candidates(self, limit_date, statuses):
        """List orders that might be consumables.

        Returns all orders created after ``limit_date`` whose status is one
        of ``statuses``, newest first.
        """
        return (
            self.session.query(models.OrderItem)
            .filter(models.OrderItem.created_on > limit_date)
            .filter(models.OrderItem.status.in_(statuses))
            .order_by(models.OrderItem.created_on.desc())
            .all()
        )

    # --- users ------------------------------------------------------------

    def add_user(self, user):
        """Add a user to the database."""
        self._add_item_to_db(user)

    def delete_user(self, user):
        """Remove a user from the database."""
        self._delete_item_from_db(user)

    def get_user(self, reference):
        """Get a user from the database by primary key.

        Raises:
            RepoItemNotFound: if no user has the id ``reference``.
        """
        query = self.session.query(models.User).filter_by(id=reference)
        return self._one_or_not_found(query)

    def get_user_by_username(self, reference):
        """Get a user from the database by username.

        Raises:
            RepoItemNotFound: if no user has the username ``reference``.
        """
        query = self.session.query(models.User).filter_by(username=reference)
        return self._one_or_not_found(query)

    def get_user_by_email(self, reference):
        """Get a user from the database by email.

        Both sides of the comparison are lower-cased in SQL, so the match
        is case-insensitive.

        Raises:
            RepoItemNotFound: if no user has the email ``reference``.
        """
        query = self.session.query(models.User).filter(
            func.lower(models.User.email) == func.lower(reference)
        )
        return self._one_or_not_found(query)

    def count_new_users(self):
        """Count the number of new users that need approval."""
        return (
            self.session.query(models.User)
            .filter(models.User.role == models.UserRole.NEW)
            .count()
        )

    # --- vendors ----------------------------------------------------------

    def search_vendor(self, reference):
        """Search for the vendor entry whose term equals ``reference``.

        Returns ``None`` when no entry matches.
        """
        return (
            self.session.query(models.Vendor)
            .filter_by(term=reference)
            .one_or_none()
        )

    def get_vendor_aggregates(self, reference):
        """Collect all terms registered for the vendor name ``reference``.

        Returns:
            A ``models.VendorAggregate`` built from the vendor name and its
            terms, the terms sorted case-insensitively.

        Raises:
            RepoItemNotFound: if no vendor entry has that name.
        """
        vendors = (
            self.session.query(models.Vendor).filter_by(name=reference).all()
        )
        if not vendors:
            raise RepoItemNotFound
        terms = sorted((v.term for v in vendors), key=lambda x: x.lower())
        return models.VendorAggregate(vendors[0].name, terms)

    def update_vendors(self, old_vendor, new_name, new_terms):
        """Update the autocorrect values of vendors.

        All entries stored under ``old_vendor.name`` and all entries whose
        term is in ``new_terms`` are removed; afterwards one entry per term
        is created that maps the term to ``new_name``.
        """
        # ``new_terms`` is needed twice (IN clause and insert loop), so a
        # one-shot iterable is materialised once up front
        terms = list(new_terms)

        # remove old vendor autocorrect values
        self.session.query(models.Vendor).filter(
            models.Vendor.name == old_vendor.name
        ).delete(synchronize_session="fetch")

        # delete all terms to be added later
        self.session.query(models.Vendor).filter(
            models.Vendor.term.in_(terms)
        ).delete(synchronize_session="fetch")

        # add the new definition:
        for new_term in terms:
            new_vendor = models.Vendor(new_term, new_name)
            self._add_item_to_db(new_vendor)

    # --- password reset tokens --------------------------------------------

    def add_reset_token(self, token):
        """Add a password reset token."""
        self._add_item_to_db(token)

    def delete_reset_token(self, token):
        """Delete a password reset token."""
        self._delete_item_from_db(token)

    def get_reset_token(self, reference):
        """Get a password reset token from the database.

        Both sides of the comparison are lower-cased in SQL, so the match
        is case-insensitive.

        Raises:
            RepoItemNotFound: if no token matches ``reference``.
        """
        query = self.session.query(models.PasswordResetToken).filter(
            func.lower(models.PasswordResetToken.token)
            == func.lower(reference)
        )
        return self._one_or_not_found(query)

    def clear_stale_reset_tokens(self):
        """Remove reset tokens whose ``valid_until`` lies in the past."""
        # NOTE(review): datetime.utcnow() returns a naive UTC timestamp and
        # is deprecated since Python 3.12; before switching to an aware
        # timestamp confirm how ``valid_until`` is stored.
        self.session.query(models.PasswordResetToken).filter(
            models.PasswordResetToken.valid_until < datetime.utcnow()
        ).delete()
|
|
|