From 30899847fb9fcb797d9913aa8a38e0872ffb5bf1 Mon Sep 17 00:00:00 2001 From: dogeystamp Date: Thu, 27 Apr 2023 21:16:10 -0400 Subject: [PATCH] /files: make tests more robust --- sachet/server/admin/views.py | 17 ++++------------- sachet/server/models.py | 11 +++++++++++ sachet/server/views_common.py | 17 +++++++++++------ tests/test_anonymous.py | 13 +++++++++++++ tests/test_pagination.py | 9 +++++++++ 5 files changed, 48 insertions(+), 19 deletions(-) diff --git a/sachet/server/admin/views.py b/sachet/server/admin/views.py index ba839f6..f70dd5a 100644 --- a/sachet/server/admin/views.py +++ b/sachet/server/admin/views.py @@ -1,6 +1,6 @@ from flask import Blueprint, request, jsonify from flask.views import MethodView -from sachet.server.models import ServerSettings, Permissions +from sachet.server.models import ServerSettings, get_settings, Permissions from sachet.server import db from sachet.server.views_common import auth_required, ModelAPI @@ -9,28 +9,19 @@ admin_blueprint = Blueprint("admin_blueprint", __name__) class ServerSettingsAPI(ModelAPI): - def get_settings(self): - rows = ServerSettings.query.all() - if len(rows) == 0: - settings = ServerSettings() - db.session.add(settings) - db.session.commit() - return settings - return rows[-1] - @auth_required(required_permissions=(Permissions.ADMIN,)) def get(self, auth_user=None): - settings = self.get_settings() + settings = get_settings() return super().get(settings) @auth_required(required_permissions=(Permissions.ADMIN,)) def patch(self, auth_user=None): - settings = self.get_settings() + settings = get_settings() return super().patch(settings) @auth_required(required_permissions=(Permissions.ADMIN,)) def put(self, auth_user=None): - settings = self.get_settings() + settings = get_settings() return super().put(settings) diff --git a/sachet/server/models.py b/sachet/server/models.py index a270a52..a097793 100644 --- a/sachet/server/models.py +++ b/sachet/server/models.py @@ -192,6 +192,17 @@ class ServerSettings(db.Model): return Schema() +def get_settings(): + """Return server settings, and create them if they don't exist.""" + rows = ServerSettings.query.all() + if len(rows) == 0: + settings = ServerSettings() + db.session.add(settings) + db.session.commit() + return settings + return rows[-1] + + class Share(db.Model): """Share for a single file. diff --git a/sachet/server/views_common.py b/sachet/server/views_common.py index cbad803..e87ac22 100644 --- a/sachet/server/views_common.py +++ b/sachet/server/views_common.py @@ -1,6 +1,6 @@ from flask import request, jsonify from flask.views import MethodView -from sachet.server.models import Permissions, User, BlacklistToken, ServerSettings +from sachet.server.models import Permissions, User, BlacklistToken, get_settings from sachet.server import db from functools import wraps from marshmallow import ValidationError @@ -40,7 +40,7 @@ def auth_required(func=None, *, required_permissions=(), allow_anonymous=False): if not token: if allow_anonymous: - server_settings = ServerSettings.query.first() + server_settings = get_settings() if ( Bitmask(AllFlags=Permissions, *required_permissions) not in server_settings.default_permissions @@ -52,7 +52,7 @@ def auth_required(func=None, *, required_permissions=(), allow_anonymous=False): "message": "Missing permissions to access this page.", } ), - 403, + 401, ) kwargs["auth_user"] = None return f(*args, **kwargs) @@ -101,7 +101,6 @@ def auth_required(func=None, *, required_permissions=(), allow_anonymous=False): def patch(orig, diff): """Patch the dictionary orig recursively with the dictionary diff.""" - # if we get to a leaf node, just replace it if not isinstance(orig, dict) or not isinstance(diff, dict): return diff @@ -267,8 +266,14 @@ class ModelListAPI(MethodView): Number of next page (if this is not the last). """ json_data = request.get_json() - per_page = int(json_data.get("per_page", 15)) - page = int(json_data.get("page", 1)) + try: + per_page = int(json_data.get("per_page", 15)) + page = int(json_data.get("page", 1)) + except ValueError as e: + return jsonify(dict( + status="fail", + message=str(e), + )), 400 page_data = ModelClass.query.paginate(page=page, per_page=per_page) data = [model.get_schema().dump(model) for model in page_data] diff --git a/tests/test_anonymous.py b/tests/test_anonymous.py index 70ccc00..89f3a2e 100644 --- a/tests/test_anonymous.py +++ b/tests/test_anonymous.py @@ -70,6 +70,12 @@ def test_files(client, auth, rand): url, json={"file_name": "new_bin.bin"}, ) + # set read perm for anon users + resp = client.patch( + "/admin/settings", + headers=auth("administrator"), + json={"default_permissions": ["READ"]}, + ) assert resp.status_code == 200 resp = client.get( url + "/content", @@ -110,6 +116,13 @@ def test_files(client, auth, rand): ) assert resp.status_code == 200 + # set delete perm for anon users + resp = client.patch( + "/admin/settings", + headers=auth("administrator"), + json={"default_permissions": ["READ"]}, + ) + assert resp.status_code == 200 # file shouldn't exist anymore resp = client.get( url + "/content", diff --git a/tests/test_pagination.py b/tests/test_pagination.py index c37bf02..49917fd 100644 --- a/tests/test_pagination.py +++ b/tests/test_pagination.py @@ -75,3 +75,12 @@ def test_files(client, users, auth): end_page = paginate(forwards=True) paginate(forwards=False, page=end_page) + + +def test_invalid(client, auth): + """Test invalid requests to pagination.""" + + resp = client.get( + "/files", headers=auth("jeff"), json=dict(page="one", per_page="two") + ) + assert resp.status_code == 400