refactor: split off generic model REST API logic
also, a model's marshmallow schema class is now within the model class
This commit is contained in:
parent
a53fba5f4b
commit
68f56dd1a1
@ -17,20 +17,28 @@ class Permissions(IntFlag):
|
||||
ADMIN = 1 << 5
|
||||
|
||||
|
||||
def patch(orig, diff):
|
||||
"""Patch the dictionary orig recursively with the dictionary diff."""
|
||||
class PermissionField(fields.Field):
|
||||
"""Field that serializes a Permissions bitmask to an array of strings."""
|
||||
|
||||
# if we get to a leaf node, just replace it
|
||||
if not isinstance(orig, dict) or not isinstance(diff, dict):
|
||||
return diff
|
||||
def _serialize(self, value, attr, obj, **kwargs):
|
||||
mask = Bitmask()
|
||||
mask.AllFlags = Permissions
|
||||
mask += value
|
||||
return [flag.name for flag in mask]
|
||||
|
||||
# deep copy
|
||||
new = {k: v for k, v in orig.items()}
|
||||
def _deserialize(self, value, attr, data, **kwargs):
|
||||
mask = Bitmask()
|
||||
mask.AllFlags = Permissions
|
||||
|
||||
for key, value in diff.items():
|
||||
new[key] = patch(orig.get(key, {}), diff[key])
|
||||
flags = value
|
||||
|
||||
return new
|
||||
try:
|
||||
for flag in flags:
|
||||
mask.add(Permissions[flag])
|
||||
except KeyError as e:
|
||||
raise ValidationError("Invalid permission.") from e
|
||||
|
||||
return mask
|
||||
|
||||
|
||||
class User(db.Model):
|
||||
@ -86,7 +94,6 @@ class User(db.Model):
|
||||
}
|
||||
return jwt.encode(payload, app.config.get("SECRET_KEY"), algorithm="HS256")
|
||||
|
||||
|
||||
def read_token(token):
|
||||
"""Read a JWT and validate it.
|
||||
|
||||
@ -109,38 +116,16 @@ class User(db.Model):
|
||||
|
||||
return data, user
|
||||
|
||||
def get_schema(self):
|
||||
class Schema(ma.SQLAlchemySchema):
|
||||
class Meta:
|
||||
model = self
|
||||
|
||||
class PermissionField(fields.Field):
|
||||
"""Field that serializes a Permissions bitmask to an array of strings."""
|
||||
username = ma.auto_field()
|
||||
register_date = ma.auto_field()
|
||||
permissions = PermissionField(data_key="permissions")
|
||||
|
||||
def _serialize(self, value, attr, obj, **kwargs):
|
||||
mask = Bitmask()
|
||||
mask.AllFlags = Permissions
|
||||
mask += value
|
||||
return [flag.name for flag in mask]
|
||||
|
||||
def _deserialize(self, value, attr, data, **kwargs):
|
||||
mask = Bitmask()
|
||||
mask.AllFlags = Permissions
|
||||
|
||||
flags = value
|
||||
|
||||
try:
|
||||
for flag in flags:
|
||||
mask.add(Permissions[flag])
|
||||
except KeyError as e:
|
||||
raise ValidationError("Invalid permission.") from e
|
||||
|
||||
return mask
|
||||
|
||||
|
||||
class UserSchema(ma.SQLAlchemySchema):
|
||||
class Meta:
|
||||
model = User
|
||||
|
||||
username = ma.auto_field()
|
||||
register_date = ma.auto_field()
|
||||
permissions = PermissionField(data_key="permissions")
|
||||
return Schema()
|
||||
|
||||
|
||||
class BlacklistToken(db.Model):
|
||||
@ -176,38 +161,3 @@ class BlacklistToken(db.Model):
|
||||
if entry.expires < datetime.datetime.utcnow():
|
||||
db.session.delete(entry)
|
||||
return True
|
||||
|
||||
|
||||
def auth_required(f):
|
||||
"""Decorator to require authentication.
|
||||
|
||||
Passes an argument 'user' to the function, with a User object corresponding
|
||||
to the authenticated session.
|
||||
"""
|
||||
|
||||
@wraps(f)
|
||||
def decorator(*args, **kwargs):
|
||||
token = None
|
||||
auth_header = request.headers.get("Authorization")
|
||||
if auth_header:
|
||||
try:
|
||||
token = auth_header.split(" ")[1]
|
||||
except IndexError:
|
||||
resp = {"status": "fail", "message": "Malformed Authorization header."}
|
||||
return jsonify(resp), 401
|
||||
|
||||
if not token:
|
||||
return jsonify({"status": "fail", "message": "Missing auth token"}), 401
|
||||
|
||||
try:
|
||||
data, user = User.read_token(token)
|
||||
except jwt.ExpiredSignatureError:
|
||||
# if it's expired we don't want it lingering in the db
|
||||
BlacklistToken.check_blacklist(token)
|
||||
return jsonify({"status": "fail", "message": "Token has expired."}), 401
|
||||
except jwt.InvalidTokenError:
|
||||
return jsonify({"status": "fail", "message": "Invalid auth token."}), 401
|
||||
|
||||
return f(user, *args, **kwargs)
|
||||
|
||||
return decorator
|
||||
|
@ -2,17 +2,12 @@ import jwt
|
||||
from flask import Blueprint, request, jsonify
|
||||
from flask.views import MethodView
|
||||
from sachet.server.models import (
|
||||
auth_required,
|
||||
patch,
|
||||
Permissions,
|
||||
User,
|
||||
UserSchema,
|
||||
BlacklistToken,
|
||||
)
|
||||
from sachet.server.views_common import ModelAPI, auth_required
|
||||
from sachet.server import bcrypt, db
|
||||
from marshmallow import ValidationError
|
||||
|
||||
user_schema = UserSchema()
|
||||
|
||||
users_blueprint = Blueprint("users_blueprint", __name__)
|
||||
|
||||
@ -51,7 +46,7 @@ class LogoutAPI(MethodView):
|
||||
"""Endpoint to revoke a user's token."""
|
||||
|
||||
@auth_required
|
||||
def post(user, self):
|
||||
def post(self, auth_user=None):
|
||||
post_data = request.get_json()
|
||||
token = post_data.get("token")
|
||||
if not token:
|
||||
@ -71,7 +66,7 @@ class LogoutAPI(MethodView):
|
||||
except jwt.InvalidTokenError:
|
||||
return jsonify({"status": "fail", "message": "Invalid auth token."}), 400
|
||||
|
||||
if user == token_user or Permissions.ADMIN in user.permissions:
|
||||
if auth_user == token_user or Permissions.ADMIN in auth_user.permissions:
|
||||
entry = BlacklistToken(token=token)
|
||||
db.session.add(entry)
|
||||
db.session.commit()
|
||||
@ -97,12 +92,12 @@ class ExtendAPI(MethodView):
|
||||
"""Endpoint to take a token and get a new one with a later expiry date."""
|
||||
|
||||
@auth_required
|
||||
def post(user, self):
|
||||
token = user.encode_token(jti="renew")
|
||||
def post(self, auth_user=None):
|
||||
token = auth_user.encode_token(jti="renew")
|
||||
resp = {
|
||||
"status": "success",
|
||||
"message": "Renewed token.",
|
||||
"username": user.username,
|
||||
"username": auth_user.username,
|
||||
"auth_token": token,
|
||||
}
|
||||
return jsonify(resp), 200
|
||||
@ -113,14 +108,15 @@ users_blueprint.add_url_rule(
|
||||
)
|
||||
|
||||
|
||||
class UserAPI(MethodView):
|
||||
class UserAPI(ModelAPI):
|
||||
"""User information API"""
|
||||
|
||||
@auth_required
|
||||
def get(user, self, username):
|
||||
def get(self, username, auth_user=None):
|
||||
info_user = User.query.filter_by(username=username).first()
|
||||
# only allow user to query themselves, but admin can query anyone
|
||||
if (not info_user) or (
|
||||
info_user != user and Permissions.ADMIN not in user.permissions
|
||||
info_user != auth_user and Permissions.ADMIN not in auth_user.permissions
|
||||
):
|
||||
resp = {
|
||||
"status": "fail",
|
||||
@ -128,64 +124,17 @@ class UserAPI(MethodView):
|
||||
}
|
||||
return jsonify(resp), 403
|
||||
|
||||
return jsonify(user_schema.dump(info_user))
|
||||
return super().get(info_user)
|
||||
|
||||
@auth_required
|
||||
def patch(user, self, username):
|
||||
@auth_required(require_admin=True)
|
||||
def patch(self, username, auth_user=None):
|
||||
patch_user = User.query.filter_by(username=username).first()
|
||||
return super().patch(patch_user)
|
||||
|
||||
if not patch_user or Permissions.ADMIN not in user.permissions:
|
||||
resp = {
|
||||
"status": "fail",
|
||||
"message": "You are not authorized to access this page.",
|
||||
}
|
||||
return jsonify(resp), 403
|
||||
|
||||
patch_json = request.get_json()
|
||||
orig_json = user_schema.dump(patch_user)
|
||||
|
||||
new_json = patch(orig_json, patch_json)
|
||||
|
||||
try:
|
||||
deserialized = user_schema.load(new_json)
|
||||
except ValidationError as e:
|
||||
resp = {"status": "fail", "message": f"Invalid patch: {str(e)}"}
|
||||
return jsonify(resp), 400
|
||||
|
||||
for k, v in deserialized.items():
|
||||
setattr(patch_user, k, v)
|
||||
|
||||
resp = {
|
||||
"status": "success",
|
||||
}
|
||||
return jsonify(resp), 200
|
||||
|
||||
@auth_required
|
||||
def put(user, self, username):
|
||||
@auth_required(require_admin=True)
|
||||
def put(self, username, auth_user=None):
|
||||
put_user = User.query.filter_by(username=username).first()
|
||||
|
||||
if not put_user or Permissions.ADMIN not in user.permissions:
|
||||
resp = {
|
||||
"status": "fail",
|
||||
"message": "You are not authorized to access this page.",
|
||||
}
|
||||
return jsonify(resp), 403
|
||||
|
||||
new_json = request.get_json()
|
||||
|
||||
try:
|
||||
deserialized = user_schema.load(new_json)
|
||||
except ValidationError as e:
|
||||
resp = {"status": "fail", "message": f"Invalid data: {str(e)}"}
|
||||
return jsonify(resp), 400
|
||||
|
||||
for k, v in deserialized.items():
|
||||
setattr(put_user, k, v)
|
||||
|
||||
resp = {
|
||||
"status": "success",
|
||||
}
|
||||
return jsonify(resp), 200
|
||||
return super().put(put_user)
|
||||
|
||||
|
||||
users_blueprint.add_url_rule(
|
||||
|
144
sachet/server/views_common.py
Normal file
144
sachet/server/views_common.py
Normal file
@ -0,0 +1,144 @@
|
||||
from flask import request, jsonify
|
||||
from flask.views import MethodView
|
||||
from sachet.server.models import Permissions, User, BlacklistToken
|
||||
from functools import wraps
|
||||
from marshmallow import ValidationError
|
||||
import jwt
|
||||
|
||||
|
||||
def auth_required(func=None, *, require_admin=False):
|
||||
"""Decorator to require authentication.
|
||||
|
||||
Passes an argument 'user' to the function, with a User object corresponding
|
||||
to the authenticated session.
|
||||
"""
|
||||
|
||||
# see https://stackoverflow.com/questions/3888158/making-decorators-with-optional-arguments
|
||||
def _decorate(f):
|
||||
@wraps(f)
|
||||
def decorator(*args, **kwargs):
|
||||
token = None
|
||||
auth_header = request.headers.get("Authorization")
|
||||
if auth_header:
|
||||
try:
|
||||
token = auth_header.split(" ")[1]
|
||||
except IndexError:
|
||||
resp = {
|
||||
"status": "fail",
|
||||
"message": "Malformed Authorization header.",
|
||||
}
|
||||
return jsonify(resp), 401
|
||||
|
||||
if not token:
|
||||
return jsonify({"status": "fail", "message": "Missing auth token"}), 401
|
||||
|
||||
try:
|
||||
data, user = User.read_token(token)
|
||||
except jwt.ExpiredSignatureError:
|
||||
# if it's expired we don't want it lingering in the db
|
||||
BlacklistToken.check_blacklist(token)
|
||||
return jsonify({"status": "fail", "message": "Token has expired."}), 401
|
||||
except jwt.InvalidTokenError:
|
||||
return (
|
||||
jsonify({"status": "fail", "message": "Invalid auth token."}),
|
||||
401,
|
||||
)
|
||||
|
||||
if require_admin and Permissions.ADMIN not in user.permissions:
|
||||
return (
|
||||
jsonify(
|
||||
{
|
||||
"status": "fail",
|
||||
"message": "Administrator permission is required to see this page.",
|
||||
}
|
||||
),
|
||||
403,
|
||||
)
|
||||
|
||||
kwargs["auth_user"] = user
|
||||
return f(*args, **kwargs)
|
||||
|
||||
return decorator
|
||||
|
||||
if func:
|
||||
return _decorate(func)
|
||||
|
||||
return _decorate
|
||||
|
||||
|
||||
def patch(orig, diff):
|
||||
"""Patch the dictionary orig recursively with the dictionary diff."""
|
||||
|
||||
# if we get to a leaf node, just replace it
|
||||
if not isinstance(orig, dict) or not isinstance(diff, dict):
|
||||
return diff
|
||||
|
||||
# deep copy
|
||||
new = {k: v for k, v in orig.items()}
|
||||
|
||||
for key, value in diff.items():
|
||||
new[key] = patch(orig.get(key, {}), diff[key])
|
||||
|
||||
return new
|
||||
|
||||
|
||||
class ModelAPI(MethodView):
|
||||
"""Generic REST API for interacting with models."""
|
||||
|
||||
def get(self, model):
|
||||
return jsonify(model.get_schema().dump(model))
|
||||
|
||||
def patch(self, model):
|
||||
model_schema = model.get_schema()
|
||||
|
||||
if not model:
|
||||
resp = {
|
||||
"status": "fail",
|
||||
"message": "This resource does not exist.",
|
||||
}
|
||||
return jsonify(resp), 404
|
||||
|
||||
patch_json = request.get_json()
|
||||
orig_json = model_schema.dump(model)
|
||||
|
||||
new_json = patch(orig_json, patch_json)
|
||||
|
||||
try:
|
||||
deserialized = model_schema.load(new_json)
|
||||
except ValidationError as e:
|
||||
resp = {"status": "fail", "message": f"Invalid patch: {str(e)}"}
|
||||
return jsonify(resp), 400
|
||||
|
||||
for k, v in deserialized.items():
|
||||
setattr(model, k, v)
|
||||
|
||||
resp = {
|
||||
"status": "success",
|
||||
}
|
||||
return jsonify(resp), 200
|
||||
|
||||
def put(self, model):
|
||||
if not model:
|
||||
resp = {
|
||||
"status": "fail",
|
||||
"message": "This resource does not exist.",
|
||||
}
|
||||
return jsonify(resp), 404
|
||||
|
||||
model_schema = model.get_schema()
|
||||
|
||||
new_json = request.get_json()
|
||||
|
||||
try:
|
||||
deserialized = model_schema.load(new_json)
|
||||
except ValidationError as e:
|
||||
resp = {"status": "fail", "message": f"Invalid data: {str(e)}"}
|
||||
return jsonify(resp), 400
|
||||
|
||||
for k, v in deserialized.items():
|
||||
setattr(model, k, v)
|
||||
|
||||
resp = {
|
||||
"status": "success",
|
||||
}
|
||||
return jsonify(resp), 200
|
@ -3,11 +3,9 @@ import yaml
|
||||
from sachet.server.users import manage
|
||||
from click.testing import CliRunner
|
||||
from sachet.server import app, db
|
||||
from sachet.server.models import Permissions, UserSchema
|
||||
from sachet.server.models import Permissions, User
|
||||
from bitmask import Bitmask
|
||||
|
||||
user_schema = UserSchema()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
@ -59,7 +57,7 @@ def validate_info(users):
|
||||
]
|
||||
|
||||
def _validate(user, info):
|
||||
info = user_schema.load(info)
|
||||
info = User.get_schema(User).load(info)
|
||||
|
||||
for k in verify_fields:
|
||||
assert users[user][k] == info[k]
|
||||
|
@ -1,4 +1,4 @@
|
||||
from sachet.server.models import patch
|
||||
from sachet.server.views_common import patch
|
||||
|
||||
|
||||
def test_patch():
|
||||
|
@ -1,9 +1,9 @@
|
||||
import pytest
|
||||
from bitmask import Bitmask
|
||||
from sachet.server.models import Permissions, UserSchema
|
||||
from sachet.server.models import Permissions, User
|
||||
from datetime import datetime
|
||||
|
||||
user_schema = UserSchema()
|
||||
user_schema = User.get_schema(User)
|
||||
|
||||
|
||||
def test_get(client, tokens, validate_info):
|
||||
|
Loading…
Reference in New Issue
Block a user