diff --git a/ordr3/models.py b/ordr3/models.py index e57d036..0e462c0 100644 --- a/ordr3/models.py +++ b/ordr3/models.py @@ -161,6 +161,9 @@ class Vendor(Model): self.term = term self.name = name + def __repr__(self): + return f"" + class User(Model): diff --git a/ordr3/repo.py b/ordr3/repo.py index 5257b7c..5ac49a5 100644 --- a/ordr3/repo.py +++ b/ordr3/repo.py @@ -65,6 +65,10 @@ class AbstractOrderRepository(abc.ABC): def get_vendor_aggregates(self, reference): """ list a all canonical names of vendors """ + @abc.abstractmethod + def update_vendors(self, old_vendor, new_name, new_terms): + """ update autocorrect values of vendors """ + @abc.abstractmethod def add_reset_token(self, token): """ add an password reset token """ @@ -191,6 +195,23 @@ class SqlAlchemyRepository(AbstractOrderRepository): terms = sorted((v.term for v in vendors), key=lambda x: x.lower()) return models.VendorAggregate(vendors[0].name, terms) + def update_vendors(self, old_vendor, new_name, new_terms): + """ update autocorrect values of vendors """ + # remove old vendor autocorrect values + self.session.query(models.Vendor).filter( + models.Vendor.name == old_vendor.name + ).delete(synchronize_session="fetch") + + # delete all terms to be added later + self.session.query(models.Vendor).filter( + models.Vendor.term.in_(new_terms) + ).delete(synchronize_session="fetch") + + # add the new definition: + for new_term in new_terms: + new_vendor = models.Vendor(new_term, new_name) + self._add_item_to_db(new_vendor) + def add_reset_token(self, token): """ add an password reset token """ self._add_item_to_db(token) diff --git a/ordr3/resources.py b/ordr3/resources.py index 5f422ee..24e7c41 100644 --- a/ordr3/resources.py +++ b/ordr3/resources.py @@ -118,7 +118,6 @@ class Vendor(BaseResource): @classmethod def from_model(cls, model, parent): """ initializes a resource from an model object """ - print(model) return cls(model.name, parent, model) diff --git a/ordr3/services.py b/ordr3/services.py index 59247ce..0951688 100644 --- a/ordr3/services.py +++ b/ordr3/services.py @@ -98,6 +98,12 @@ def _vendor_with_common_domain(vendor): return vendor +def canonical_vendor_name(vendor): + cleaned = " ".join(vendor.strip().split()) + tmp = _vendor_from_url(cleaned) + return _vendor_with_common_domain(tmp) + + CheckVendorResult = namedtuple("CheckVendorResult", ["name", "found"]) @@ -105,7 +111,7 @@ def check_vendor_name(repo, to_check): # remove unused whitespace cleaned = " ".join(to_check.strip().split()) tmp = _vendor_from_url(cleaned) - canonical_name = _vendor_with_common_domain(tmp) + canonical_name = canonical_vendor_name(tmp) vendor = repo.search_vendor(canonical_name.lower()) diff --git a/ordr3/templates/vendors/edit.jinja2 b/ordr3/templates/vendors/edit.jinja2 new file mode 100644 index 0000000..65f0ef4 --- /dev/null +++ b/ordr3/templates/vendors/edit.jinja2 @@ -0,0 +1,37 @@ +{% extends "ordr3:templates/layout_full.jinja2" %} + +{% block subtitle %} Edit Vendor Autocorrect for "{{ context.model.name }}" {% endblock subtitle %} + + +{% block sidebar %}{% endblock sidebar %} + + +{% block content %} + +
+

Edit Autorcorrect for {{ context.model.name }}

+
+
+ + +
Must not be empty
+
+
+ + + Add one term per line +
Must not be empty
+
+ +

+ + + +

+
+
+ + +
+ +{% endblock content %} diff --git a/ordr3/templates/vendors/list.jinja2 b/ordr3/templates/vendors/list.jinja2 index 2f0c8c1..8b7bd8e 100644 --- a/ordr3/templates/vendors/list.jinja2 +++ b/ordr3/templates/vendors/list.jinja2 @@ -3,9 +3,7 @@ {% block subtitle %} Manage Vendor Autocorrect {% endblock subtitle %} -{% block sidebar %} - -{% endblock sidebar %} +{% block sidebar %}{% endblock sidebar %} {% block content %} diff --git a/ordr3/views/vendors.py b/ordr3/views/vendors.py index 2a044b0..add52ab 100644 --- a/ordr3/views/vendors.py +++ b/ordr3/views/vendors.py @@ -1,13 +1,9 @@ -# import deform from sqlalchemy import func - -# from pyramid.csrf import get_csrf_token +from pyramid.csrf import get_csrf_token from pyramid.view import view_config +from pyramid.httpexceptions import HTTPFound -# from .. import events, models, services, resources -from .. import models - -# from pyramid.httpexceptions import HTTPFound +from .. import events, models, services @view_config( @@ -26,3 +22,48 @@ def vendor_list(context, request): ) return {"vendors": vendors} + + +@view_config( + context="ordr3:resources.Vendor", + permission="edit", + request_method="GET", + renderer="ordr3:templates/vendors/edit.jinja2", +) +def vendor_edit_form(context, request): + return { + "form_error": False, + "csrf_token": get_csrf_token(request), + } + + +@view_config( + context="ordr3:resources.Vendor", + permission="edit", + request_method="POST", + renderer="ordr3:templates/vendors/edit.jinja2", +) +def vendor_edit(context, request): + if "change" not in request.POST: + return HTTPFound(request.resource_url(context.__parent__)) + + vendor = request.POST.get("vendor", "").strip() + terms = request.POST.get("terms", "").strip() + + if vendor and terms: + terms = set(terms.lower().splitlines()) + canonical_name = services.canonical_vendor_name(vendor) + terms.add(canonical_name.lower()) + request.repo.update_vendors(context.model, vendor, terms) + + request.emit( + events.FlashMessage.info( + f"The autocorrect for {vendor} was updated." + ) + ) + return HTTPFound(request.resource_url(context.__parent__)) + + return { + "form_error": True, + "csrf_token": get_csrf_token(request), + } diff --git a/tests/test_models.py b/tests/test_models.py index 655a5fe..9d135a2 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -112,6 +112,14 @@ def test_vendor_init(): assert vendor.name == "B" +def test_vendor_repr(): + from ordr3.models import Vendor + + vendor = Vendor("A", "B") + + assert repr(vendor) == "" + + def test_user_init(): from ordr3.models import User diff --git a/tests/test_repo.py b/tests/test_repo.py index 984e218..e17a981 100644 --- a/tests/test_repo.py +++ b/tests/test_repo.py @@ -89,6 +89,18 @@ def example_tokens(): ] +@pytest.fixture() +def example_vendors(): + from ordr3.models import Vendor + + return [ + Vendor("sb", "Sigma Aldrich"), + Vendor("sa", "Sigma Aldrich"), + Vendor("vw", "VWR"), + Vendor("me", "Merck"), + ] + + def test_sql_repo_add_order(session, example_orders): from ordr3.repo import SqlAlchemyRepository from ordr3.models import OrderItem @@ -283,17 +295,12 @@ def test_sql_search_vendor(session, example_users): assert repo.search_vendor("unknown") is None -def test_sql_get_vendor_aggregates(session): +def test_sql_get_vendor_aggregates(session, example_vendors): from ordr3.repo import SqlAlchemyRepository - from ordr3.models import Vendor, VendorAggregate + from ordr3.models import VendorAggregate repo = SqlAlchemyRepository(session) - entry_1 = Vendor("sb", "Sigma Aldrich") - entry_2 = Vendor("sa", "Sigma Aldrich") - entry_3 = Vendor("vwr", "VWR") - session.add(entry_1) - session.add(entry_2) - session.add(entry_3) + session.add_all(example_vendors) session.flush() result = repo.get_vendor_aggregates("Sigma Aldrich") @@ -312,6 +319,28 @@ def test_sql_get_vendor_aggregates_raises_error(session): repo.get_vendor_aggregates("Sigma Aldrich") +def test_sql_update_vendors(session, example_vendors): + from ordr3.repo import SqlAlchemyRepository + from ordr3.models import Vendor, VendorAggregate + + repo = SqlAlchemyRepository(session) + session.add_all(example_vendors) + session.flush() + + old_vendor = VendorAggregate("Sigma Aldrich", None) + + repo.update_vendors(old_vendor, "ACME", {"sa", "me", "sx"}) + + result = session.query(Vendor).order_by(Vendor.term).all() + + assert result == [ + Vendor("me", "ACME"), + Vendor("sa", "ACME"), + Vendor("sx", "ACME"), + Vendor("vw", "VWR"), + ] + + def test_sql_repo_add_reset_token(session, example_tokens): from ordr3.repo import SqlAlchemyRepository from ordr3.models import PasswordResetToken diff --git a/tests/test_services.py b/tests/test_services.py index 8c46bff..d8d01a9 100644 --- a/tests/test_services.py +++ b/tests/test_services.py @@ -3,7 +3,7 @@ from datetime import datetime, timedelta import pytest from ordr3.repo import AbstractOrderRepository -from ordr3.models import VendorAggregate +from ordr3.models import Vendor, VendorAggregate class FakeOrderRepository(AbstractOrderRepository): @@ -67,6 +67,16 @@ class FakeOrderRepository(AbstractOrderRepository): first_vendor = next(iter(vendors)) return VendorAggregate(first_vendor.name, terms) + def update_vendors(self, old_vendor, new_name, new_terms): + """ update autocorrect values of vendors """ + vendors_to_delete = { + v for v in self._vendors if v.name == old_vendor.name + } + terms_to_delete = {v for v in self._vendors if v.term in new_terms} + self._vendors = self._vendors - vendors_to_delete - terms_to_delete + for new_term in new_terms: + self._vendors.add(Vendor(new_term, new_name)) + def add_reset_token(self, token): """ add an password reset token """ self._tokens.add(token)