From ee6c7d4fa5c238048e72aa5206993193a4a9039a Mon Sep 17 00:00:00 2001 From: dogeystamp Date: Tue, 28 Mar 2023 21:48:09 -0400 Subject: [PATCH] /users/: added PATCH endpoint --- sachet/server/models.py | 16 ++++++++ sachet/server/users/views.py | 38 +++++++++++++++++- tests/test_models.py | 74 ++++++++++++++++++++++++++++++++++++ tests/test_userinfo.py | 49 +++++++++++++++++++++++- 4 files changed, 174 insertions(+), 3 deletions(-) create mode 100644 tests/test_models.py diff --git a/sachet/server/models.py b/sachet/server/models.py index ca5927b..b5a495d 100644 --- a/sachet/server/models.py +++ b/sachet/server/models.py @@ -17,6 +17,22 @@ class Permissions(IntFlag): ADMIN = 1<<5 +def patch(orig, diff): + """Patch the dictionary orig recursively with the dictionary diff.""" + + # if we get to a leaf node, just replace it + if not isinstance(orig, dict) or not isinstance(diff, dict): + return diff + + # deep copy + new = {k:v for k, v in orig.items()} + + for key, value in diff.items(): + new[key] = patch(orig.get(key, {}), diff[key]) + + return new + + class User(db.Model): __tablename__ = "users" diff --git a/sachet/server/users/views.py b/sachet/server/users/views.py index 5113b56..48b218a 100644 --- a/sachet/server/users/views.py +++ b/sachet/server/users/views.py @@ -1,8 +1,9 @@ import jwt from flask import Blueprint, request, jsonify from flask.views import MethodView -from sachet.server.models import auth_required, read_token, Permissions, User, UserSchema, BlacklistToken +from sachet.server.models import auth_required, read_token, patch, Permissions, User, UserSchema, BlacklistToken from sachet.server import bcrypt, db +from marshmallow import ValidationError user_schema = UserSchema() @@ -118,8 +119,41 @@ class UserAPI(MethodView): return jsonify(user_schema.dump(info_user)) + @auth_required + def patch(user, self, username): + patch_user = User.query.filter_by(username=username).first() + + if Permissions.ADMIN not in user.permissions: + resp = { + "status": "fail", + "message": "You are not authorized to access this page." + } + return jsonify(resp), 403 + + patch_json = request.get_json() + orig_json = user_schema.dump(patch_user) + + new_json = patch(orig_json, patch_json) + + try: + deserialized = user_schema.load(new_json) + except ValidationError as e: + resp = { + "status": "fail", + "message": f"Invalid patch: {str(e)}" + } + return jsonify(resp), 400 + + for k, v in deserialized.items(): + setattr(patch_user, k, v) + + resp = { + "status": "success", + } + return jsonify(resp), 200 + users_blueprint.add_url_rule( "/users/", view_func=UserAPI.as_view("user_api"), - methods=['GET'] + methods=['GET', 'PATCH'] ) diff --git a/tests/test_models.py b/tests/test_models.py new file mode 100644 index 0000000..3686d4a --- /dev/null +++ b/tests/test_models.py @@ -0,0 +1,74 @@ +from sachet.server.models import patch + +def test_patch(): + """Tests sachet/server/models.py's patch() method for dicts.""" + + assert patch( + dict(), + dict() + ) == dict() + + assert patch( + dict(key="value"), + dict() + ) == dict(key="value") + + assert patch( + dict(key="value"), + dict(key="newvalue") + ) == dict(key="newvalue") + + assert patch( + dict(key="value"), + dict(key="newvalue") + ) == dict(key="newvalue") + + assert patch( + dict(key="value"), + dict(key2="other_value") + ) == dict( + key="value", + key2="other_value" + ) + + assert patch( + dict( + nest = dict( + key="value", + key2="other_value" + ) + ), + dict( + top_key="newvalue", + nest = dict( + key2 = "new_other_value" + ) + ) + ) == dict( + top_key="newvalue", + nest = dict( + key="value", + key2="new_other_value" + ) + ) + + assert patch( + dict( + nest = dict( + key="value", + list=[1, 2, 3, 4, 5] + ) + ), + dict( + top_key="newvalue", + nest = dict( + list = [3, 1, 4, 1, 5] + ) + ) + ) == dict( + top_key="newvalue", + nest = dict( + key="value", + list=[3, 1, 4, 1, 5] + ) + ) diff --git a/tests/test_userinfo.py b/tests/test_userinfo.py index f50f894..effaca0 100644 --- a/tests/test_userinfo.py +++ b/tests/test_userinfo.py @@ -1,6 +1,8 @@ import pytest +from bitmask import Bitmask +from sachet.server.models import Permissions -def test_userinfo(client, tokens, validate_info): +def test_get(client, tokens, validate_info): """Test accessing the user information endpoint as a normal user.""" # access user info endpoint @@ -45,3 +47,48 @@ def test_userinfo_admin(client, tokens, validate_info): ) assert resp.status_code == 200 validate_info("jeff", resp.get_json()) + +def test_patch(client, users, tokens, validate_info): + """Test modifying user information as an administrator.""" + + # try with regular user to make sure it doesn't work + resp = client.patch( + "/users/jeff", + json = { "permissions": ["ADMIN"] }, + headers={ + "Authorization": f"bearer {tokens['jeff']}" + } + ) + assert resp.status_code == 403 + + # test malformed patch + resp = client.patch( + "/users/jeff", + json = "hurr durr", + headers={ + "Authorization": f"bearer {tokens['administrator']}" + } + ) + assert resp.status_code == 400 + + resp = client.patch( + "/users/jeff", + json = { "permissions": ["ADMIN"] }, + headers={ + "Authorization": f"bearer {tokens['administrator']}" + } + ) + assert resp.status_code == 200 + + # modify the expected values + users["jeff"]["permissions"] = Bitmask(Permissions.ADMIN) + + # request new info + resp = client.get( + "/users/jeff", + headers={ + "Authorization": f"bearer {tokens['jeff']}" + } + ) + assert resp.status_code == 200 + validate_info("jeff", resp.get_json())