From a0e163d8bdcb814350633574e5537ab3bf731f15 Mon Sep 17 00:00:00 2001 From: dogeystamp Date: Wed, 26 Apr 2023 19:51:09 -0400 Subject: [PATCH] /files: add GET with pagination --- sachet/server/files/views.py | 6 ++- sachet/server/models.py | 2 + sachet/server/views_common.py | 37 ++++++++++++++++ tests/conftest.py | 3 +- tests/test_pagination.py | 80 +++++++++++++++++++++++++++++++++++ 5 files changed, 126 insertions(+), 2 deletions(-) create mode 100644 tests/test_pagination.py diff --git a/sachet/server/files/views.py b/sachet/server/files/views.py index bf6e4b7..6a44bb3 100644 --- a/sachet/server/files/views.py +++ b/sachet/server/files/views.py @@ -45,11 +45,15 @@ class FilesAPI(ModelListAPI): data["owner_name"] = auth_user.username return super().post(Share, data) + @auth_required(required_permissions=(Permissions.LIST,)) + def get(self, auth_user=None): + return super().get(Share) + files_blueprint.add_url_rule( "/files", view_func=FilesAPI.as_view("files_api"), - methods=["POST"], + methods=["POST", "GET"], ) diff --git a/sachet/server/models.py b/sachet/server/models.py index ede8fcd..1c6c6c8 100644 --- a/sachet/server/models.py +++ b/sachet/server/models.py @@ -213,6 +213,8 @@ class Share(db.Model): Time the share was created (not initialized.) file_name : str File name to download as. + url : str + URL linking to this object. Methods ------- diff --git a/sachet/server/views_common.py b/sachet/server/views_common.py index 48a13da..9de0c82 100644 --- a/sachet/server/views_common.py +++ b/sachet/server/views_common.py @@ -219,3 +219,40 @@ class ModelListAPI(MethodView): db.session.commit() return jsonify({"status": "success", "url": model.url}), 201 + + def get(self, ModelClass): + """List a given range of instances. + + Parameters + ---------- + ModelClass + Model class to query. + + JSON Parameters + --------------- + per_page : int + Amount of entries to return in one query. + page : int + Incrementing this reads the next `per_page` entries. + + Returns + ------- + data : list of dict + All requested entries. + prev : int or None + Number of previous page (if this is not the first). + next : int or None + Number of next page (if this is not the last). + """ + json_data = request.get_json() + per_page = int(json_data.get("per_page", 15)) + page = int(json_data.get("page", 1)) + + page_data = ModelClass.query.paginate(page=page, per_page=per_page) + data = [model.get_schema().dump(model) for model in page_data] + + return jsonify(dict( + data=data, + prev=page_data.prev_num, + next=page_data.next_num, + )) diff --git a/tests/conftest.py b/tests/conftest.py index 31f8b66..151c67b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -57,7 +57,7 @@ def flask_app_bare(): @pytest.fixture def users(client): - """Creates all the test users. + """Create all the test users. Returns a dictionary with all the info for each user. """ @@ -69,6 +69,7 @@ def users(client): Permissions.READ, Permissions.DELETE, Permissions.MODIFY, + Permissions.LIST, ), ), dave=dict( diff --git a/tests/test_pagination.py b/tests/test_pagination.py new file mode 100644 index 0000000..0eb28e2 --- /dev/null +++ b/tests/test_pagination.py @@ -0,0 +1,80 @@ +import pytest +from math import ceil +import json + +"""Test ability to paginate endpoint responses.""" + + +def test_files(client, users, auth): + """Test /files endpoint.""" + + # create multiple shares + shares = set() + share_count = 20 + + for i in range(share_count): + resp = client.post( + "/files", headers=auth("jeff"), json={"file_name": "content.bin"} + ) + assert resp.status_code == 201 + data = resp.get_json() + share_id = data.get("url").split("/")[-1] + shares.add(share_id) + + # we'll paginate through all shares and ensure that: + # - we see all the shares + # - no shares are seen twice + + def paginate(forwards=True, page=1): + """Goes through the pages. + + Parameters + ---------- + forwards : bool, optional + Set direction to paginate in. + page : int + Page to start at. + + Returns + ------- + int + Page we ended at. + """ + seen = set() + + per_page = 9 + while page is not None: + resp = client.get( + "/files", headers=auth("jeff"), json=dict( + page=page, + per_page=per_page + ) + ) + assert resp.status_code == 200 + + data = resp.get_json().get("data") + assert len(data) == per_page or len(data) == share_count % per_page + + for share in data: + share_id = share.get("share_id") + assert share_id in shares + assert share_id not in seen + seen.add(share_id) + + if forwards: + new_page = resp.get_json().get("next") + assert new_page == page + 1 or new_page is None + else: + new_page = resp.get_json().get("prev") + assert new_page == page - 1 or new_page is None + if new_page is not None: + page = new_page + else: + break + + assert seen == shares + + return page + + end_page = paginate(forwards=True) + paginate(forwards=False, page=end_page)