/files: make tests more robust

This commit is contained in:
dogeystamp 2023-04-27 21:16:10 -04:00
parent f97cfbbe33
commit 30899847fb
Signed by: dogeystamp
GPG Key ID: 7225FE3592EFFA38
5 changed files with 48 additions and 19 deletions

View File

@ -1,6 +1,6 @@
from flask import Blueprint, request, jsonify
from flask.views import MethodView
from sachet.server.models import ServerSettings, Permissions
from sachet.server.models import ServerSettings, get_settings, Permissions
from sachet.server import db
from sachet.server.views_common import auth_required, ModelAPI
@ -9,28 +9,19 @@ admin_blueprint = Blueprint("admin_blueprint", __name__)
class ServerSettingsAPI(ModelAPI):
def get_settings(self):
rows = ServerSettings.query.all()
if len(rows) == 0:
settings = ServerSettings()
db.session.add(settings)
db.session.commit()
return settings
return rows[-1]
@auth_required(required_permissions=(Permissions.ADMIN,))
def get(self, auth_user=None):
settings = self.get_settings()
settings = get_settings()
return super().get(settings)
@auth_required(required_permissions=(Permissions.ADMIN,))
def patch(self, auth_user=None):
settings = self.get_settings()
settings = get_settings()
return super().patch(settings)
@auth_required(required_permissions=(Permissions.ADMIN,))
def put(self, auth_user=None):
settings = self.get_settings()
settings = get_settings()
return super().put(settings)

View File

@ -192,6 +192,17 @@ class ServerSettings(db.Model):
return Schema()
def get_settings():
"""Return server settings, and create them if they don't exist."""
rows = ServerSettings.query.all()
if len(rows) == 0:
settings = ServerSettings()
db.session.add(settings)
db.session.commit()
return settings
return rows[-1]
class Share(db.Model):
"""Share for a single file.

View File

@ -1,6 +1,6 @@
from flask import request, jsonify
from flask.views import MethodView
from sachet.server.models import Permissions, User, BlacklistToken, ServerSettings
from sachet.server.models import Permissions, User, BlacklistToken, get_settings
from sachet.server import db
from functools import wraps
from marshmallow import ValidationError
@ -40,7 +40,7 @@ def auth_required(func=None, *, required_permissions=(), allow_anonymous=False):
if not token:
if allow_anonymous:
server_settings = ServerSettings.query.first()
server_settings = get_settings()
if (
Bitmask(AllFlags=Permissions, *required_permissions)
not in server_settings.default_permissions
@ -52,7 +52,7 @@ def auth_required(func=None, *, required_permissions=(), allow_anonymous=False):
"message": "Missing permissions to access this page.",
}
),
403,
401,
)
kwargs["auth_user"] = None
return f(*args, **kwargs)
@ -101,7 +101,6 @@ def auth_required(func=None, *, required_permissions=(), allow_anonymous=False):
def patch(orig, diff):
"""Patch the dictionary orig recursively with the dictionary diff."""
# if we get to a leaf node, just replace it
if not isinstance(orig, dict) or not isinstance(diff, dict):
return diff
@ -267,8 +266,14 @@ class ModelListAPI(MethodView):
Number of next page (if this is not the last).
"""
json_data = request.get_json()
per_page = int(json_data.get("per_page", 15))
page = int(json_data.get("page", 1))
try:
per_page = int(json_data.get("per_page", 15))
page = int(json_data.get("page", 1))
except ValueError as e:
return jsonify(dict(
status="fail",
message=str(e),
)), 400
page_data = ModelClass.query.paginate(page=page, per_page=per_page)
data = [model.get_schema().dump(model) for model in page_data]

View File

@ -70,6 +70,12 @@ def test_files(client, auth, rand):
url,
json={"file_name": "new_bin.bin"},
)
# set read perm for anon users
resp = client.patch(
"/admin/settings",
headers=auth("administrator"),
json={"default_permissions": ["READ"]},
)
assert resp.status_code == 200
resp = client.get(
url + "/content",
@ -110,6 +116,13 @@ def test_files(client, auth, rand):
)
assert resp.status_code == 200
# set delete perm for anon users
resp = client.patch(
"/admin/settings",
headers=auth("administrator"),
json={"default_permissions": ["READ"]},
)
assert resp.status_code == 200
# file shouldn't exist anymore
resp = client.get(
url + "/content",

View File

@ -75,3 +75,12 @@ def test_files(client, users, auth):
end_page = paginate(forwards=True)
paginate(forwards=False, page=end_page)
def test_invalid(client, auth):
"""Test invalid requests to pagination."""
resp = client.get(
"/files", headers=auth("jeff"), json=dict(page="one", per_page="two")
)
assert resp.status_code == 400