diff --git a/ordr2/resources/base.py b/ordr2/resources/base.py index 75f488c..224c6e5 100644 --- a/ordr2/resources/base.py +++ b/ordr2/resources/base.py @@ -1,5 +1,9 @@ +from collections import namedtuple from pyramid.security import DENY_ALL +from sqlalchemy import asc, desc +from sqlalchemy.inspection import inspect + class BaseResource(object): @@ -8,18 +12,201 @@ class BaseResource(object): request = None model = None - nodes = {} + nodes = dict() nav_highlight = None - def __init__(self, name, parent): + def __init__(self, name, parent, sql_model_instance=None): self.__name__ = name self.__parent__ = parent self.request = parent.request + self.model = sql_model_instance + + # call to super().__init_() needed to set up PaginationMixin + super().__init__() def __acl__(self): return [ DENY_ALL ] + def __getitem__(self, key): klass = self.nodes[key] return klass(key, self) + + + @classmethod + def from_sqla(cls, sql_model_instance, parent): + ''' initializes a resource from an SQLalchemy object ''' + primary_keys = inspect(sql_model_instance).identity + if primary_keys is None: + raise ValueError('Cannot init resource for primary key: None') + elif len(primary_keys) != 1: + raise ValueError('Cannot init resource for composite primary key') + primary_key = str(primary_keys[0]) + return cls(primary_key, parent, sql_model_instance) + + +class Pagination(object): + + default_items = 25 + default_window_size = 7 + + def __init__(self, current, count, items=None, window_size=None): + current = self._ensure_int(current, 1) + count = self._ensure_int(count, 0) + items = self._ensure_int(items, self.default_items) + window_size = self._ensure_int(window_size, self.default_window_size) + + self.count = count + self.items = self.items_per_page = items + + pages = (count - 1) // items + 1 + self.first = 1 + self.last = max(self.first, pages) + self.current = self._is_valid(current, default=self.first) + self.previous = self._is_valid(self.current - 1) + self.next = self._is_valid(self.current + 1) + + # window calculations + # example: lets assume the current page is 10 and window size is 7 + # self.window = [7, 8, 9, 10, 11, 12, 13] + half_window = window_size // 2 + start = self.current - half_window + end = self.current + half_window + calculated_window = range(start, end + 1) + self.window = [p for p in calculated_window if self._is_valid(p)] + + def _is_valid(self, page, default=None): + ''' checks if the given page is valid, returns default if not ''' + if self.count and self.first <= page <= self.last: + return page + return default + + def _ensure_int(self, value, default): + try: + return int(value) + except Exception: + return default + + +SortParameter = namedtuple('SortParameter', 'text field direction func') + + +class PaginationResourceMixin(object): + + sql_model_class = None + child_resource_class = None + default_sorting = None + default_items_per_page = 25 + + pages = None + sorting = None + filters = {} + + query_key_current_page = 'p' + query_key_items_per_page = 'n' + query_key_sorting = 'o' + + _base_query = None + + def __init__(self): + # first we need to remove non-filter parameters from GET + params = dict(self.request.GET) + page = params.pop(self.query_key_current_page, 1) + items = params.pop( + self.query_key_items_per_page, + self.default_items_per_page + ) + sort = params.pop(self.query_key_sorting, self.default_sorting) + + # we can now setup a base query with applied filters + self._base_query = self.prepare_filtered_query( + self.request.dbsession, + params + ) + + # with this base query, the pagination can be calculated: + count = self._base_query.count() + self.pages = Pagination(page, count, items) + + # and we should check that we can sort results later + self.sorting = self.parse_sort_parameters(sort) + if self.sorting is None: + msg = 'Error in default sorting {}'.format(self.default_sorting) + raise ValueError(msg) + + + def prepare_filtered_query(self, dbsession, filter_params): + ''' setup the base filtered query + + An example: + def prepare_filtered_query(self, dbsession, filter_params): + query = dbsession.query(self.sql_model_class) + by_username = filter_params.get('username', None) + if by_username is not None: + query = query.filter_by(username=by_username) + # don't forget to remember the filter + self.filters['username'] = by_username + return query + ''' + msg = 'Query setup must be implemented in child class' + raise NotImplementedError(msg) + + + def prepare_sorted_query(self, query, sorting): + ''' setup the base filtered query + + An example: + def prepare_sorted_query(self, query, sorting): + model_field = getattr(self.sql_model_class, sorting.field) + sort_func = sorting.func(model_field) + return query.order_by(sort_func) + ''' + msg = 'Query setup must be implemented in child class' + raise NotImplementedError(msg) + + + def parse_sort_parameters(self, sort_param): + sort_functions = { 'asc': asc, 'desc': desc} + try: + sort_param = sort_param.lower() + field, direction = sort_param.split('.', 1) + func = sort_functions[direction] + return SortParameter(sort_param, field, direction, func) + except (AttributeError, IndexError, KeyError, ValueError): + return None + + + def items(self): + ''' returns the items of the current page as resources''' + if not self.pages.count: + return + + offset = (self.pages.current - 1) * self.pages.items_per_page + + query = self.prepare_sorted_query(self._base_query, self.sorting) + query = query.offset(offset).limit(self.pages.items_per_page) + return [ + self.child_resource_class.from_sqla(item, self) + for item + in query.all() + ] + + + def query_params(self, *args, **kwargs): + params = { + self.query_key_current_page: self.pages.current, + self.query_key_items_per_page: self.pages.items, + self.query_key_sorting: self.sorting.text + } + params.update(self.filters) + params.update(args) + params.update(kwargs) + filtered = {k: v for k, v in params.items() if v is not None} + return filtered + + + def url(self, *args, **kwargs): + params = self.query_params(*args, **kwargs) + return self.request.resource_url(self, query=params) +