/files: make tests more robust

This commit is contained in:
dogeystamp 2023-04-27 21:16:10 -04:00
parent f97cfbbe33
commit 30899847fb
Signed by: dogeystamp
GPG Key ID: 7225FE3592EFFA38
5 changed files with 48 additions and 19 deletions

View File

@ -1,6 +1,6 @@
from flask import Blueprint, request, jsonify from flask import Blueprint, request, jsonify
from flask.views import MethodView from flask.views import MethodView
from sachet.server.models import ServerSettings, Permissions from sachet.server.models import ServerSettings, get_settings, Permissions
from sachet.server import db from sachet.server import db
from sachet.server.views_common import auth_required, ModelAPI from sachet.server.views_common import auth_required, ModelAPI
@ -9,28 +9,19 @@ admin_blueprint = Blueprint("admin_blueprint", __name__)
class ServerSettingsAPI(ModelAPI): class ServerSettingsAPI(ModelAPI):
def get_settings(self):
rows = ServerSettings.query.all()
if len(rows) == 0:
settings = ServerSettings()
db.session.add(settings)
db.session.commit()
return settings
return rows[-1]
@auth_required(required_permissions=(Permissions.ADMIN,)) @auth_required(required_permissions=(Permissions.ADMIN,))
def get(self, auth_user=None): def get(self, auth_user=None):
settings = self.get_settings() settings = get_settings()
return super().get(settings) return super().get(settings)
@auth_required(required_permissions=(Permissions.ADMIN,)) @auth_required(required_permissions=(Permissions.ADMIN,))
def patch(self, auth_user=None): def patch(self, auth_user=None):
settings = self.get_settings() settings = get_settings()
return super().patch(settings) return super().patch(settings)
@auth_required(required_permissions=(Permissions.ADMIN,)) @auth_required(required_permissions=(Permissions.ADMIN,))
def put(self, auth_user=None): def put(self, auth_user=None):
settings = self.get_settings() settings = get_settings()
return super().put(settings) return super().put(settings)

View File

@ -192,6 +192,17 @@ class ServerSettings(db.Model):
return Schema() return Schema()
def get_settings():
"""Return server settings, and create them if they don't exist."""
rows = ServerSettings.query.all()
if len(rows) == 0:
settings = ServerSettings()
db.session.add(settings)
db.session.commit()
return settings
return rows[-1]
class Share(db.Model): class Share(db.Model):
"""Share for a single file. """Share for a single file.

View File

@ -1,6 +1,6 @@
from flask import request, jsonify from flask import request, jsonify
from flask.views import MethodView from flask.views import MethodView
from sachet.server.models import Permissions, User, BlacklistToken, ServerSettings from sachet.server.models import Permissions, User, BlacklistToken, get_settings
from sachet.server import db from sachet.server import db
from functools import wraps from functools import wraps
from marshmallow import ValidationError from marshmallow import ValidationError
@ -40,7 +40,7 @@ def auth_required(func=None, *, required_permissions=(), allow_anonymous=False):
if not token: if not token:
if allow_anonymous: if allow_anonymous:
server_settings = ServerSettings.query.first() server_settings = get_settings()
if ( if (
Bitmask(AllFlags=Permissions, *required_permissions) Bitmask(AllFlags=Permissions, *required_permissions)
not in server_settings.default_permissions not in server_settings.default_permissions
@ -52,7 +52,7 @@ def auth_required(func=None, *, required_permissions=(), allow_anonymous=False):
"message": "Missing permissions to access this page.", "message": "Missing permissions to access this page.",
} }
), ),
403, 401,
) )
kwargs["auth_user"] = None kwargs["auth_user"] = None
return f(*args, **kwargs) return f(*args, **kwargs)
@ -101,7 +101,6 @@ def auth_required(func=None, *, required_permissions=(), allow_anonymous=False):
def patch(orig, diff): def patch(orig, diff):
"""Patch the dictionary orig recursively with the dictionary diff.""" """Patch the dictionary orig recursively with the dictionary diff."""
# if we get to a leaf node, just replace it # if we get to a leaf node, just replace it
if not isinstance(orig, dict) or not isinstance(diff, dict): if not isinstance(orig, dict) or not isinstance(diff, dict):
return diff return diff
@ -267,8 +266,14 @@ class ModelListAPI(MethodView):
Number of next page (if this is not the last). Number of next page (if this is not the last).
""" """
json_data = request.get_json() json_data = request.get_json()
per_page = int(json_data.get("per_page", 15)) try:
page = int(json_data.get("page", 1)) per_page = int(json_data.get("per_page", 15))
page = int(json_data.get("page", 1))
except ValueError as e:
return jsonify(dict(
status="fail",
message=str(e),
)), 400
page_data = ModelClass.query.paginate(page=page, per_page=per_page) page_data = ModelClass.query.paginate(page=page, per_page=per_page)
data = [model.get_schema().dump(model) for model in page_data] data = [model.get_schema().dump(model) for model in page_data]

View File

@ -70,6 +70,12 @@ def test_files(client, auth, rand):
url, url,
json={"file_name": "new_bin.bin"}, json={"file_name": "new_bin.bin"},
) )
# set read perm for anon users
resp = client.patch(
"/admin/settings",
headers=auth("administrator"),
json={"default_permissions": ["READ"]},
)
assert resp.status_code == 200 assert resp.status_code == 200
resp = client.get( resp = client.get(
url + "/content", url + "/content",
@ -110,6 +116,13 @@ def test_files(client, auth, rand):
) )
assert resp.status_code == 200 assert resp.status_code == 200
# set delete perm for anon users
resp = client.patch(
"/admin/settings",
headers=auth("administrator"),
json={"default_permissions": ["READ"]},
)
assert resp.status_code == 200
# file shouldn't exist anymore # file shouldn't exist anymore
resp = client.get( resp = client.get(
url + "/content", url + "/content",

View File

@ -75,3 +75,12 @@ def test_files(client, users, auth):
end_page = paginate(forwards=True) end_page = paginate(forwards=True)
paginate(forwards=False, page=end_page) paginate(forwards=False, page=end_page)
def test_invalid(client, auth):
"""Test invalid requests to pagination."""
resp = client.get(
"/files", headers=auth("jeff"), json=dict(page="one", per_page="two")
)
assert resp.status_code == 400