refactor: split off generic model REST API logic

also, a model's marshmallow schema class is now within the model class
This commit is contained in:
dogeystamp 2023-04-01 17:56:25 -04:00
parent a53fba5f4b
commit 68f56dd1a1
Signed by: dogeystamp
GPG Key ID: 7225FE3592EFFA38
6 changed files with 192 additions and 151 deletions

View File

@ -17,20 +17,28 @@ class Permissions(IntFlag):
ADMIN = 1 << 5 ADMIN = 1 << 5
def patch(orig, diff): class PermissionField(fields.Field):
"""Patch the dictionary orig recursively with the dictionary diff.""" """Field that serializes a Permissions bitmask to an array of strings."""
# if we get to a leaf node, just replace it def _serialize(self, value, attr, obj, **kwargs):
if not isinstance(orig, dict) or not isinstance(diff, dict): mask = Bitmask()
return diff mask.AllFlags = Permissions
mask += value
return [flag.name for flag in mask]
# deep copy def _deserialize(self, value, attr, data, **kwargs):
new = {k: v for k, v in orig.items()} mask = Bitmask()
mask.AllFlags = Permissions
for key, value in diff.items(): flags = value
new[key] = patch(orig.get(key, {}), diff[key])
return new try:
for flag in flags:
mask.add(Permissions[flag])
except KeyError as e:
raise ValidationError("Invalid permission.") from e
return mask
class User(db.Model): class User(db.Model):
@ -86,7 +94,6 @@ class User(db.Model):
} }
return jwt.encode(payload, app.config.get("SECRET_KEY"), algorithm="HS256") return jwt.encode(payload, app.config.get("SECRET_KEY"), algorithm="HS256")
def read_token(token): def read_token(token):
"""Read a JWT and validate it. """Read a JWT and validate it.
@ -109,38 +116,16 @@ class User(db.Model):
return data, user return data, user
def get_schema(self):
class Schema(ma.SQLAlchemySchema):
class Meta:
model = self
class PermissionField(fields.Field): username = ma.auto_field()
"""Field that serializes a Permissions bitmask to an array of strings.""" register_date = ma.auto_field()
permissions = PermissionField(data_key="permissions")
def _serialize(self, value, attr, obj, **kwargs): return Schema()
mask = Bitmask()
mask.AllFlags = Permissions
mask += value
return [flag.name for flag in mask]
def _deserialize(self, value, attr, data, **kwargs):
mask = Bitmask()
mask.AllFlags = Permissions
flags = value
try:
for flag in flags:
mask.add(Permissions[flag])
except KeyError as e:
raise ValidationError("Invalid permission.") from e
return mask
class UserSchema(ma.SQLAlchemySchema):
class Meta:
model = User
username = ma.auto_field()
register_date = ma.auto_field()
permissions = PermissionField(data_key="permissions")
class BlacklistToken(db.Model): class BlacklistToken(db.Model):
@ -176,38 +161,3 @@ class BlacklistToken(db.Model):
if entry.expires < datetime.datetime.utcnow(): if entry.expires < datetime.datetime.utcnow():
db.session.delete(entry) db.session.delete(entry)
return True return True
def auth_required(f):
"""Decorator to require authentication.
Passes an argument 'user' to the function, with a User object corresponding
to the authenticated session.
"""
@wraps(f)
def decorator(*args, **kwargs):
token = None
auth_header = request.headers.get("Authorization")
if auth_header:
try:
token = auth_header.split(" ")[1]
except IndexError:
resp = {"status": "fail", "message": "Malformed Authorization header."}
return jsonify(resp), 401
if not token:
return jsonify({"status": "fail", "message": "Missing auth token"}), 401
try:
data, user = User.read_token(token)
except jwt.ExpiredSignatureError:
# if it's expired we don't want it lingering in the db
BlacklistToken.check_blacklist(token)
return jsonify({"status": "fail", "message": "Token has expired."}), 401
except jwt.InvalidTokenError:
return jsonify({"status": "fail", "message": "Invalid auth token."}), 401
return f(user, *args, **kwargs)
return decorator

View File

@ -2,17 +2,12 @@ import jwt
from flask import Blueprint, request, jsonify from flask import Blueprint, request, jsonify
from flask.views import MethodView from flask.views import MethodView
from sachet.server.models import ( from sachet.server.models import (
auth_required,
patch,
Permissions, Permissions,
User, User,
UserSchema,
BlacklistToken, BlacklistToken,
) )
from sachet.server.views_common import ModelAPI, auth_required
from sachet.server import bcrypt, db from sachet.server import bcrypt, db
from marshmallow import ValidationError
user_schema = UserSchema()
users_blueprint = Blueprint("users_blueprint", __name__) users_blueprint = Blueprint("users_blueprint", __name__)
@ -51,7 +46,7 @@ class LogoutAPI(MethodView):
"""Endpoint to revoke a user's token.""" """Endpoint to revoke a user's token."""
@auth_required @auth_required
def post(user, self): def post(self, auth_user=None):
post_data = request.get_json() post_data = request.get_json()
token = post_data.get("token") token = post_data.get("token")
if not token: if not token:
@ -71,7 +66,7 @@ class LogoutAPI(MethodView):
except jwt.InvalidTokenError: except jwt.InvalidTokenError:
return jsonify({"status": "fail", "message": "Invalid auth token."}), 400 return jsonify({"status": "fail", "message": "Invalid auth token."}), 400
if user == token_user or Permissions.ADMIN in user.permissions: if auth_user == token_user or Permissions.ADMIN in auth_user.permissions:
entry = BlacklistToken(token=token) entry = BlacklistToken(token=token)
db.session.add(entry) db.session.add(entry)
db.session.commit() db.session.commit()
@ -97,12 +92,12 @@ class ExtendAPI(MethodView):
"""Endpoint to take a token and get a new one with a later expiry date.""" """Endpoint to take a token and get a new one with a later expiry date."""
@auth_required @auth_required
def post(user, self): def post(self, auth_user=None):
token = user.encode_token(jti="renew") token = auth_user.encode_token(jti="renew")
resp = { resp = {
"status": "success", "status": "success",
"message": "Renewed token.", "message": "Renewed token.",
"username": user.username, "username": auth_user.username,
"auth_token": token, "auth_token": token,
} }
return jsonify(resp), 200 return jsonify(resp), 200
@ -113,14 +108,15 @@ users_blueprint.add_url_rule(
) )
class UserAPI(MethodView): class UserAPI(ModelAPI):
"""User information API""" """User information API"""
@auth_required @auth_required
def get(user, self, username): def get(self, username, auth_user=None):
info_user = User.query.filter_by(username=username).first() info_user = User.query.filter_by(username=username).first()
# only allow user to query themselves, but admin can query anyone
if (not info_user) or ( if (not info_user) or (
info_user != user and Permissions.ADMIN not in user.permissions info_user != auth_user and Permissions.ADMIN not in auth_user.permissions
): ):
resp = { resp = {
"status": "fail", "status": "fail",
@ -128,64 +124,17 @@ class UserAPI(MethodView):
} }
return jsonify(resp), 403 return jsonify(resp), 403
return jsonify(user_schema.dump(info_user)) return super().get(info_user)
@auth_required @auth_required(require_admin=True)
def patch(user, self, username): def patch(self, username, auth_user=None):
patch_user = User.query.filter_by(username=username).first() patch_user = User.query.filter_by(username=username).first()
return super().patch(patch_user)
if not patch_user or Permissions.ADMIN not in user.permissions: @auth_required(require_admin=True)
resp = { def put(self, username, auth_user=None):
"status": "fail",
"message": "You are not authorized to access this page.",
}
return jsonify(resp), 403
patch_json = request.get_json()
orig_json = user_schema.dump(patch_user)
new_json = patch(orig_json, patch_json)
try:
deserialized = user_schema.load(new_json)
except ValidationError as e:
resp = {"status": "fail", "message": f"Invalid patch: {str(e)}"}
return jsonify(resp), 400
for k, v in deserialized.items():
setattr(patch_user, k, v)
resp = {
"status": "success",
}
return jsonify(resp), 200
@auth_required
def put(user, self, username):
put_user = User.query.filter_by(username=username).first() put_user = User.query.filter_by(username=username).first()
return super().put(put_user)
if not put_user or Permissions.ADMIN not in user.permissions:
resp = {
"status": "fail",
"message": "You are not authorized to access this page.",
}
return jsonify(resp), 403
new_json = request.get_json()
try:
deserialized = user_schema.load(new_json)
except ValidationError as e:
resp = {"status": "fail", "message": f"Invalid data: {str(e)}"}
return jsonify(resp), 400
for k, v in deserialized.items():
setattr(put_user, k, v)
resp = {
"status": "success",
}
return jsonify(resp), 200
users_blueprint.add_url_rule( users_blueprint.add_url_rule(

View File

@ -0,0 +1,144 @@
from flask import request, jsonify
from flask.views import MethodView
from sachet.server.models import Permissions, User, BlacklistToken
from functools import wraps
from marshmallow import ValidationError
import jwt
def auth_required(func=None, *, require_admin=False):
"""Decorator to require authentication.
Passes an argument 'user' to the function, with a User object corresponding
to the authenticated session.
"""
# see https://stackoverflow.com/questions/3888158/making-decorators-with-optional-arguments
def _decorate(f):
@wraps(f)
def decorator(*args, **kwargs):
token = None
auth_header = request.headers.get("Authorization")
if auth_header:
try:
token = auth_header.split(" ")[1]
except IndexError:
resp = {
"status": "fail",
"message": "Malformed Authorization header.",
}
return jsonify(resp), 401
if not token:
return jsonify({"status": "fail", "message": "Missing auth token"}), 401
try:
data, user = User.read_token(token)
except jwt.ExpiredSignatureError:
# if it's expired we don't want it lingering in the db
BlacklistToken.check_blacklist(token)
return jsonify({"status": "fail", "message": "Token has expired."}), 401
except jwt.InvalidTokenError:
return (
jsonify({"status": "fail", "message": "Invalid auth token."}),
401,
)
if require_admin and Permissions.ADMIN not in user.permissions:
return (
jsonify(
{
"status": "fail",
"message": "Administrator permission is required to see this page.",
}
),
403,
)
kwargs["auth_user"] = user
return f(*args, **kwargs)
return decorator
if func:
return _decorate(func)
return _decorate
def patch(orig, diff):
"""Patch the dictionary orig recursively with the dictionary diff."""
# if we get to a leaf node, just replace it
if not isinstance(orig, dict) or not isinstance(diff, dict):
return diff
# deep copy
new = {k: v for k, v in orig.items()}
for key, value in diff.items():
new[key] = patch(orig.get(key, {}), diff[key])
return new
class ModelAPI(MethodView):
"""Generic REST API for interacting with models."""
def get(self, model):
return jsonify(model.get_schema().dump(model))
def patch(self, model):
model_schema = model.get_schema()
if not model:
resp = {
"status": "fail",
"message": "This resource does not exist.",
}
return jsonify(resp), 404
patch_json = request.get_json()
orig_json = model_schema.dump(model)
new_json = patch(orig_json, patch_json)
try:
deserialized = model_schema.load(new_json)
except ValidationError as e:
resp = {"status": "fail", "message": f"Invalid patch: {str(e)}"}
return jsonify(resp), 400
for k, v in deserialized.items():
setattr(model, k, v)
resp = {
"status": "success",
}
return jsonify(resp), 200
def put(self, model):
if not model:
resp = {
"status": "fail",
"message": "This resource does not exist.",
}
return jsonify(resp), 404
model_schema = model.get_schema()
new_json = request.get_json()
try:
deserialized = model_schema.load(new_json)
except ValidationError as e:
resp = {"status": "fail", "message": f"Invalid data: {str(e)}"}
return jsonify(resp), 400
for k, v in deserialized.items():
setattr(model, k, v)
resp = {
"status": "success",
}
return jsonify(resp), 200

View File

@ -3,11 +3,9 @@ import yaml
from sachet.server.users import manage from sachet.server.users import manage
from click.testing import CliRunner from click.testing import CliRunner
from sachet.server import app, db from sachet.server import app, db
from sachet.server.models import Permissions, UserSchema from sachet.server.models import Permissions, User
from bitmask import Bitmask from bitmask import Bitmask
user_schema = UserSchema()
@pytest.fixture @pytest.fixture
def client(): def client():
@ -59,7 +57,7 @@ def validate_info(users):
] ]
def _validate(user, info): def _validate(user, info):
info = user_schema.load(info) info = User.get_schema(User).load(info)
for k in verify_fields: for k in verify_fields:
assert users[user][k] == info[k] assert users[user][k] == info[k]

View File

@ -1,4 +1,4 @@
from sachet.server.models import patch from sachet.server.views_common import patch
def test_patch(): def test_patch():

View File

@ -1,9 +1,9 @@
import pytest import pytest
from bitmask import Bitmask from bitmask import Bitmask
from sachet.server.models import Permissions, UserSchema from sachet.server.models import Permissions, User
from datetime import datetime from datetime import datetime
user_schema = UserSchema() user_schema = User.get_schema(User)
def test_get(client, tokens, validate_info): def test_get(client, tokens, validate_info):