refactor: split off generic model REST API logic

also, a model's marshmallow schema class is now within the model class
This commit is contained in:
dogeystamp 2023-04-01 17:56:25 -04:00
parent a53fba5f4b
commit 68f56dd1a1
Signed by: dogeystamp
GPG Key ID: 7225FE3592EFFA38
6 changed files with 192 additions and 151 deletions

View File

@ -17,20 +17,28 @@ class Permissions(IntFlag):
ADMIN = 1 << 5
def patch(orig, diff):
"""Patch the dictionary orig recursively with the dictionary diff."""
class PermissionField(fields.Field):
"""Field that serializes a Permissions bitmask to an array of strings."""
# if we get to a leaf node, just replace it
if not isinstance(orig, dict) or not isinstance(diff, dict):
return diff
def _serialize(self, value, attr, obj, **kwargs):
mask = Bitmask()
mask.AllFlags = Permissions
mask += value
return [flag.name for flag in mask]
# deep copy
new = {k: v for k, v in orig.items()}
def _deserialize(self, value, attr, data, **kwargs):
mask = Bitmask()
mask.AllFlags = Permissions
for key, value in diff.items():
new[key] = patch(orig.get(key, {}), diff[key])
flags = value
return new
try:
for flag in flags:
mask.add(Permissions[flag])
except KeyError as e:
raise ValidationError("Invalid permission.") from e
return mask
class User(db.Model):
@ -86,7 +94,6 @@ class User(db.Model):
}
return jwt.encode(payload, app.config.get("SECRET_KEY"), algorithm="HS256")
def read_token(token):
"""Read a JWT and validate it.
@ -109,39 +116,17 @@ class User(db.Model):
return data, user
class PermissionField(fields.Field):
"""Field that serializes a Permissions bitmask to an array of strings."""
def _serialize(self, value, attr, obj, **kwargs):
mask = Bitmask()
mask.AllFlags = Permissions
mask += value
return [flag.name for flag in mask]
def _deserialize(self, value, attr, data, **kwargs):
mask = Bitmask()
mask.AllFlags = Permissions
flags = value
try:
for flag in flags:
mask.add(Permissions[flag])
except KeyError as e:
raise ValidationError("Invalid permission.") from e
return mask
class UserSchema(ma.SQLAlchemySchema):
def get_schema(self):
class Schema(ma.SQLAlchemySchema):
class Meta:
model = User
model = self
username = ma.auto_field()
register_date = ma.auto_field()
permissions = PermissionField(data_key="permissions")
return Schema()
class BlacklistToken(db.Model):
"""Token that has been revoked (but has not expired yet.)
@ -176,38 +161,3 @@ class BlacklistToken(db.Model):
if entry.expires < datetime.datetime.utcnow():
db.session.delete(entry)
return True
def auth_required(f):
"""Decorator to require authentication.
Passes an argument 'user' to the function, with a User object corresponding
to the authenticated session.
"""
@wraps(f)
def decorator(*args, **kwargs):
token = None
auth_header = request.headers.get("Authorization")
if auth_header:
try:
token = auth_header.split(" ")[1]
except IndexError:
resp = {"status": "fail", "message": "Malformed Authorization header."}
return jsonify(resp), 401
if not token:
return jsonify({"status": "fail", "message": "Missing auth token"}), 401
try:
data, user = User.read_token(token)
except jwt.ExpiredSignatureError:
# if it's expired we don't want it lingering in the db
BlacklistToken.check_blacklist(token)
return jsonify({"status": "fail", "message": "Token has expired."}), 401
except jwt.InvalidTokenError:
return jsonify({"status": "fail", "message": "Invalid auth token."}), 401
return f(user, *args, **kwargs)
return decorator

View File

@ -2,17 +2,12 @@ import jwt
from flask import Blueprint, request, jsonify
from flask.views import MethodView
from sachet.server.models import (
auth_required,
patch,
Permissions,
User,
UserSchema,
BlacklistToken,
)
from sachet.server.views_common import ModelAPI, auth_required
from sachet.server import bcrypt, db
from marshmallow import ValidationError
user_schema = UserSchema()
users_blueprint = Blueprint("users_blueprint", __name__)
@ -51,7 +46,7 @@ class LogoutAPI(MethodView):
"""Endpoint to revoke a user's token."""
@auth_required
def post(user, self):
def post(self, auth_user=None):
post_data = request.get_json()
token = post_data.get("token")
if not token:
@ -71,7 +66,7 @@ class LogoutAPI(MethodView):
except jwt.InvalidTokenError:
return jsonify({"status": "fail", "message": "Invalid auth token."}), 400
if user == token_user or Permissions.ADMIN in user.permissions:
if auth_user == token_user or Permissions.ADMIN in auth_user.permissions:
entry = BlacklistToken(token=token)
db.session.add(entry)
db.session.commit()
@ -97,12 +92,12 @@ class ExtendAPI(MethodView):
"""Endpoint to take a token and get a new one with a later expiry date."""
@auth_required
def post(user, self):
token = user.encode_token(jti="renew")
def post(self, auth_user=None):
token = auth_user.encode_token(jti="renew")
resp = {
"status": "success",
"message": "Renewed token.",
"username": user.username,
"username": auth_user.username,
"auth_token": token,
}
return jsonify(resp), 200
@ -113,14 +108,15 @@ users_blueprint.add_url_rule(
)
class UserAPI(MethodView):
class UserAPI(ModelAPI):
"""User information API"""
@auth_required
def get(user, self, username):
def get(self, username, auth_user=None):
info_user = User.query.filter_by(username=username).first()
# only allow user to query themselves, but admin can query anyone
if (not info_user) or (
info_user != user and Permissions.ADMIN not in user.permissions
info_user != auth_user and Permissions.ADMIN not in auth_user.permissions
):
resp = {
"status": "fail",
@ -128,64 +124,17 @@ class UserAPI(MethodView):
}
return jsonify(resp), 403
return jsonify(user_schema.dump(info_user))
return super().get(info_user)
@auth_required
def patch(user, self, username):
@auth_required(require_admin=True)
def patch(self, username, auth_user=None):
patch_user = User.query.filter_by(username=username).first()
return super().patch(patch_user)
if not patch_user or Permissions.ADMIN not in user.permissions:
resp = {
"status": "fail",
"message": "You are not authorized to access this page.",
}
return jsonify(resp), 403
patch_json = request.get_json()
orig_json = user_schema.dump(patch_user)
new_json = patch(orig_json, patch_json)
try:
deserialized = user_schema.load(new_json)
except ValidationError as e:
resp = {"status": "fail", "message": f"Invalid patch: {str(e)}"}
return jsonify(resp), 400
for k, v in deserialized.items():
setattr(patch_user, k, v)
resp = {
"status": "success",
}
return jsonify(resp), 200
@auth_required
def put(user, self, username):
@auth_required(require_admin=True)
def put(self, username, auth_user=None):
put_user = User.query.filter_by(username=username).first()
if not put_user or Permissions.ADMIN not in user.permissions:
resp = {
"status": "fail",
"message": "You are not authorized to access this page.",
}
return jsonify(resp), 403
new_json = request.get_json()
try:
deserialized = user_schema.load(new_json)
except ValidationError as e:
resp = {"status": "fail", "message": f"Invalid data: {str(e)}"}
return jsonify(resp), 400
for k, v in deserialized.items():
setattr(put_user, k, v)
resp = {
"status": "success",
}
return jsonify(resp), 200
return super().put(put_user)
users_blueprint.add_url_rule(

View File

@ -0,0 +1,144 @@
from flask import request, jsonify
from flask.views import MethodView
from sachet.server.models import Permissions, User, BlacklistToken
from functools import wraps
from marshmallow import ValidationError
import jwt
def auth_required(func=None, *, require_admin=False):
"""Decorator to require authentication.
Passes an argument 'user' to the function, with a User object corresponding
to the authenticated session.
"""
# see https://stackoverflow.com/questions/3888158/making-decorators-with-optional-arguments
def _decorate(f):
@wraps(f)
def decorator(*args, **kwargs):
token = None
auth_header = request.headers.get("Authorization")
if auth_header:
try:
token = auth_header.split(" ")[1]
except IndexError:
resp = {
"status": "fail",
"message": "Malformed Authorization header.",
}
return jsonify(resp), 401
if not token:
return jsonify({"status": "fail", "message": "Missing auth token"}), 401
try:
data, user = User.read_token(token)
except jwt.ExpiredSignatureError:
# if it's expired we don't want it lingering in the db
BlacklistToken.check_blacklist(token)
return jsonify({"status": "fail", "message": "Token has expired."}), 401
except jwt.InvalidTokenError:
return (
jsonify({"status": "fail", "message": "Invalid auth token."}),
401,
)
if require_admin and Permissions.ADMIN not in user.permissions:
return (
jsonify(
{
"status": "fail",
"message": "Administrator permission is required to see this page.",
}
),
403,
)
kwargs["auth_user"] = user
return f(*args, **kwargs)
return decorator
if func:
return _decorate(func)
return _decorate
def patch(orig, diff):
"""Patch the dictionary orig recursively with the dictionary diff."""
# if we get to a leaf node, just replace it
if not isinstance(orig, dict) or not isinstance(diff, dict):
return diff
# deep copy
new = {k: v for k, v in orig.items()}
for key, value in diff.items():
new[key] = patch(orig.get(key, {}), diff[key])
return new
class ModelAPI(MethodView):
"""Generic REST API for interacting with models."""
def get(self, model):
return jsonify(model.get_schema().dump(model))
def patch(self, model):
model_schema = model.get_schema()
if not model:
resp = {
"status": "fail",
"message": "This resource does not exist.",
}
return jsonify(resp), 404
patch_json = request.get_json()
orig_json = model_schema.dump(model)
new_json = patch(orig_json, patch_json)
try:
deserialized = model_schema.load(new_json)
except ValidationError as e:
resp = {"status": "fail", "message": f"Invalid patch: {str(e)}"}
return jsonify(resp), 400
for k, v in deserialized.items():
setattr(model, k, v)
resp = {
"status": "success",
}
return jsonify(resp), 200
def put(self, model):
if not model:
resp = {
"status": "fail",
"message": "This resource does not exist.",
}
return jsonify(resp), 404
model_schema = model.get_schema()
new_json = request.get_json()
try:
deserialized = model_schema.load(new_json)
except ValidationError as e:
resp = {"status": "fail", "message": f"Invalid data: {str(e)}"}
return jsonify(resp), 400
for k, v in deserialized.items():
setattr(model, k, v)
resp = {
"status": "success",
}
return jsonify(resp), 200

View File

@ -3,11 +3,9 @@ import yaml
from sachet.server.users import manage
from click.testing import CliRunner
from sachet.server import app, db
from sachet.server.models import Permissions, UserSchema
from sachet.server.models import Permissions, User
from bitmask import Bitmask
user_schema = UserSchema()
@pytest.fixture
def client():
@ -59,7 +57,7 @@ def validate_info(users):
]
def _validate(user, info):
info = user_schema.load(info)
info = User.get_schema(User).load(info)
for k in verify_fields:
assert users[user][k] == info[k]

View File

@ -1,4 +1,4 @@
from sachet.server.models import patch
from sachet.server.views_common import patch
def test_patch():

View File

@ -1,9 +1,9 @@
import pytest
from bitmask import Bitmask
from sachet.server.models import Permissions, UserSchema
from sachet.server.models import Permissions, User
from datetime import datetime
user_schema = UserSchema()
user_schema = User.get_schema(User)
def test_get(client, tokens, validate_info):