diff --git a/sachet/server/models.py b/sachet/server/models.py index bc5f1bc..0ebf381 100644 --- a/sachet/server/models.py +++ b/sachet/server/models.py @@ -17,20 +17,28 @@ class Permissions(IntFlag): ADMIN = 1 << 5 -def patch(orig, diff): - """Patch the dictionary orig recursively with the dictionary diff.""" +class PermissionField(fields.Field): + """Field that serializes a Permissions bitmask to an array of strings.""" - # if we get to a leaf node, just replace it - if not isinstance(orig, dict) or not isinstance(diff, dict): - return diff + def _serialize(self, value, attr, obj, **kwargs): + mask = Bitmask() + mask.AllFlags = Permissions + mask += value + return [flag.name for flag in mask] - # deep copy - new = {k: v for k, v in orig.items()} + def _deserialize(self, value, attr, data, **kwargs): + mask = Bitmask() + mask.AllFlags = Permissions - for key, value in diff.items(): - new[key] = patch(orig.get(key, {}), diff[key]) + flags = value - return new + try: + for flag in flags: + mask.add(Permissions[flag]) + except KeyError as e: + raise ValidationError("Invalid permission.") from e + + return mask class User(db.Model): @@ -86,7 +94,6 @@ class User(db.Model): } return jwt.encode(payload, app.config.get("SECRET_KEY"), algorithm="HS256") - def read_token(token): """Read a JWT and validate it. @@ -109,38 +116,16 @@ class User(db.Model): return data, user + def get_schema(self): + class Schema(ma.SQLAlchemySchema): + class Meta: + model = self -class PermissionField(fields.Field): - """Field that serializes a Permissions bitmask to an array of strings.""" + username = ma.auto_field() + register_date = ma.auto_field() + permissions = PermissionField(data_key="permissions") - def _serialize(self, value, attr, obj, **kwargs): - mask = Bitmask() - mask.AllFlags = Permissions - mask += value - return [flag.name for flag in mask] - - def _deserialize(self, value, attr, data, **kwargs): - mask = Bitmask() - mask.AllFlags = Permissions - - flags = value - - try: - for flag in flags: - mask.add(Permissions[flag]) - except KeyError as e: - raise ValidationError("Invalid permission.") from e - - return mask - - -class UserSchema(ma.SQLAlchemySchema): - class Meta: - model = User - - username = ma.auto_field() - register_date = ma.auto_field() - permissions = PermissionField(data_key="permissions") + return Schema() class BlacklistToken(db.Model): @@ -176,38 +161,3 @@ class BlacklistToken(db.Model): if entry.expires < datetime.datetime.utcnow(): db.session.delete(entry) return True - - -def auth_required(f): - """Decorator to require authentication. - - Passes an argument 'user' to the function, with a User object corresponding - to the authenticated session. - """ - - @wraps(f) - def decorator(*args, **kwargs): - token = None - auth_header = request.headers.get("Authorization") - if auth_header: - try: - token = auth_header.split(" ")[1] - except IndexError: - resp = {"status": "fail", "message": "Malformed Authorization header."} - return jsonify(resp), 401 - - if not token: - return jsonify({"status": "fail", "message": "Missing auth token"}), 401 - - try: - data, user = User.read_token(token) - except jwt.ExpiredSignatureError: - # if it's expired we don't want it lingering in the db - BlacklistToken.check_blacklist(token) - return jsonify({"status": "fail", "message": "Token has expired."}), 401 - except jwt.InvalidTokenError: - return jsonify({"status": "fail", "message": "Invalid auth token."}), 401 - - return f(user, *args, **kwargs) - - return decorator diff --git a/sachet/server/users/views.py b/sachet/server/users/views.py index 170ffe2..4ab1609 100644 --- a/sachet/server/users/views.py +++ b/sachet/server/users/views.py @@ -2,17 +2,12 @@ import jwt from flask import Blueprint, request, jsonify from flask.views import MethodView from sachet.server.models import ( - auth_required, - patch, Permissions, User, - UserSchema, BlacklistToken, ) +from sachet.server.views_common import ModelAPI, auth_required from sachet.server import bcrypt, db -from marshmallow import ValidationError - -user_schema = UserSchema() users_blueprint = Blueprint("users_blueprint", __name__) @@ -51,7 +46,7 @@ class LogoutAPI(MethodView): """Endpoint to revoke a user's token.""" @auth_required - def post(user, self): + def post(self, auth_user=None): post_data = request.get_json() token = post_data.get("token") if not token: @@ -71,7 +66,7 @@ class LogoutAPI(MethodView): except jwt.InvalidTokenError: return jsonify({"status": "fail", "message": "Invalid auth token."}), 400 - if user == token_user or Permissions.ADMIN in user.permissions: + if auth_user == token_user or Permissions.ADMIN in auth_user.permissions: entry = BlacklistToken(token=token) db.session.add(entry) db.session.commit() @@ -97,12 +92,12 @@ class ExtendAPI(MethodView): """Endpoint to take a token and get a new one with a later expiry date.""" @auth_required - def post(user, self): - token = user.encode_token(jti="renew") + def post(self, auth_user=None): + token = auth_user.encode_token(jti="renew") resp = { "status": "success", "message": "Renewed token.", - "username": user.username, + "username": auth_user.username, "auth_token": token, } return jsonify(resp), 200 @@ -113,14 +108,15 @@ users_blueprint.add_url_rule( ) -class UserAPI(MethodView): +class UserAPI(ModelAPI): """User information API""" @auth_required - def get(user, self, username): + def get(self, username, auth_user=None): info_user = User.query.filter_by(username=username).first() + # only allow user to query themselves, but admin can query anyone if (not info_user) or ( - info_user != user and Permissions.ADMIN not in user.permissions + info_user != auth_user and Permissions.ADMIN not in auth_user.permissions ): resp = { "status": "fail", @@ -128,64 +124,17 @@ class UserAPI(MethodView): } return jsonify(resp), 403 - return jsonify(user_schema.dump(info_user)) + return super().get(info_user) - @auth_required - def patch(user, self, username): + @auth_required(require_admin=True) + def patch(self, username, auth_user=None): patch_user = User.query.filter_by(username=username).first() + return super().patch(patch_user) - if not patch_user or Permissions.ADMIN not in user.permissions: - resp = { - "status": "fail", - "message": "You are not authorized to access this page.", - } - return jsonify(resp), 403 - - patch_json = request.get_json() - orig_json = user_schema.dump(patch_user) - - new_json = patch(orig_json, patch_json) - - try: - deserialized = user_schema.load(new_json) - except ValidationError as e: - resp = {"status": "fail", "message": f"Invalid patch: {str(e)}"} - return jsonify(resp), 400 - - for k, v in deserialized.items(): - setattr(patch_user, k, v) - - resp = { - "status": "success", - } - return jsonify(resp), 200 - - @auth_required - def put(user, self, username): + @auth_required(require_admin=True) + def put(self, username, auth_user=None): put_user = User.query.filter_by(username=username).first() - - if not put_user or Permissions.ADMIN not in user.permissions: - resp = { - "status": "fail", - "message": "You are not authorized to access this page.", - } - return jsonify(resp), 403 - - new_json = request.get_json() - - try: - deserialized = user_schema.load(new_json) - except ValidationError as e: - resp = {"status": "fail", "message": f"Invalid data: {str(e)}"} - return jsonify(resp), 400 - - for k, v in deserialized.items(): - setattr(put_user, k, v) - - resp = { - "status": "success", - } - return jsonify(resp), 200 + return super().put(put_user) users_blueprint.add_url_rule( diff --git a/sachet/server/views_common.py b/sachet/server/views_common.py new file mode 100644 index 0000000..60142a5 --- /dev/null +++ b/sachet/server/views_common.py @@ -0,0 +1,144 @@ +from flask import request, jsonify +from flask.views import MethodView +from sachet.server.models import Permissions, User, BlacklistToken +from functools import wraps +from marshmallow import ValidationError +import jwt + + +def auth_required(func=None, *, require_admin=False): + """Decorator to require authentication. + + Passes an argument 'user' to the function, with a User object corresponding + to the authenticated session. + """ + + # see https://stackoverflow.com/questions/3888158/making-decorators-with-optional-arguments + def _decorate(f): + @wraps(f) + def decorator(*args, **kwargs): + token = None + auth_header = request.headers.get("Authorization") + if auth_header: + try: + token = auth_header.split(" ")[1] + except IndexError: + resp = { + "status": "fail", + "message": "Malformed Authorization header.", + } + return jsonify(resp), 401 + + if not token: + return jsonify({"status": "fail", "message": "Missing auth token"}), 401 + + try: + data, user = User.read_token(token) + except jwt.ExpiredSignatureError: + # if it's expired we don't want it lingering in the db + BlacklistToken.check_blacklist(token) + return jsonify({"status": "fail", "message": "Token has expired."}), 401 + except jwt.InvalidTokenError: + return ( + jsonify({"status": "fail", "message": "Invalid auth token."}), + 401, + ) + + if require_admin and Permissions.ADMIN not in user.permissions: + return ( + jsonify( + { + "status": "fail", + "message": "Administrator permission is required to see this page.", + } + ), + 403, + ) + + kwargs["auth_user"] = user + return f(*args, **kwargs) + + return decorator + + if func: + return _decorate(func) + + return _decorate + + +def patch(orig, diff): + """Patch the dictionary orig recursively with the dictionary diff.""" + + # if we get to a leaf node, just replace it + if not isinstance(orig, dict) or not isinstance(diff, dict): + return diff + + # deep copy + new = {k: v for k, v in orig.items()} + + for key, value in diff.items(): + new[key] = patch(orig.get(key, {}), diff[key]) + + return new + + +class ModelAPI(MethodView): + """Generic REST API for interacting with models.""" + + def get(self, model): + return jsonify(model.get_schema().dump(model)) + + def patch(self, model): + model_schema = model.get_schema() + + if not model: + resp = { + "status": "fail", + "message": "This resource does not exist.", + } + return jsonify(resp), 404 + + patch_json = request.get_json() + orig_json = model_schema.dump(model) + + new_json = patch(orig_json, patch_json) + + try: + deserialized = model_schema.load(new_json) + except ValidationError as e: + resp = {"status": "fail", "message": f"Invalid patch: {str(e)}"} + return jsonify(resp), 400 + + for k, v in deserialized.items(): + setattr(model, k, v) + + resp = { + "status": "success", + } + return jsonify(resp), 200 + + def put(self, model): + if not model: + resp = { + "status": "fail", + "message": "This resource does not exist.", + } + return jsonify(resp), 404 + + model_schema = model.get_schema() + + new_json = request.get_json() + + try: + deserialized = model_schema.load(new_json) + except ValidationError as e: + resp = {"status": "fail", "message": f"Invalid data: {str(e)}"} + return jsonify(resp), 400 + + for k, v in deserialized.items(): + setattr(model, k, v) + + resp = { + "status": "success", + } + return jsonify(resp), 200 diff --git a/tests/conftest.py b/tests/conftest.py index 8175a4f..d68e41c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,11 +3,9 @@ import yaml from sachet.server.users import manage from click.testing import CliRunner from sachet.server import app, db -from sachet.server.models import Permissions, UserSchema +from sachet.server.models import Permissions, User from bitmask import Bitmask -user_schema = UserSchema() - @pytest.fixture def client(): @@ -59,7 +57,7 @@ def validate_info(users): ] def _validate(user, info): - info = user_schema.load(info) + info = User.get_schema(User).load(info) for k in verify_fields: assert users[user][k] == info[k] diff --git a/tests/test_models.py b/tests/test_models.py index 9847fa6..7eeb09f 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -1,4 +1,4 @@ -from sachet.server.models import patch +from sachet.server.views_common import patch def test_patch(): diff --git a/tests/test_userinfo.py b/tests/test_userinfo.py index bd9a9a8..fa86089 100644 --- a/tests/test_userinfo.py +++ b/tests/test_userinfo.py @@ -1,9 +1,9 @@ import pytest from bitmask import Bitmask -from sachet.server.models import Permissions, UserSchema +from sachet.server.models import Permissions, User from datetime import datetime -user_schema = UserSchema() +user_schema = User.get_schema(User) def test_get(client, tokens, validate_info):