refactor: split off generic model REST API logic
also, a model's marshmallow schema class is now within the model class
This commit is contained in:
parent
a53fba5f4b
commit
68f56dd1a1
@ -17,20 +17,28 @@ class Permissions(IntFlag):
|
|||||||
ADMIN = 1 << 5
|
ADMIN = 1 << 5
|
||||||
|
|
||||||
|
|
||||||
def patch(orig, diff):
|
class PermissionField(fields.Field):
|
||||||
"""Patch the dictionary orig recursively with the dictionary diff."""
|
"""Field that serializes a Permissions bitmask to an array of strings."""
|
||||||
|
|
||||||
# if we get to a leaf node, just replace it
|
def _serialize(self, value, attr, obj, **kwargs):
|
||||||
if not isinstance(orig, dict) or not isinstance(diff, dict):
|
mask = Bitmask()
|
||||||
return diff
|
mask.AllFlags = Permissions
|
||||||
|
mask += value
|
||||||
|
return [flag.name for flag in mask]
|
||||||
|
|
||||||
# deep copy
|
def _deserialize(self, value, attr, data, **kwargs):
|
||||||
new = {k: v for k, v in orig.items()}
|
mask = Bitmask()
|
||||||
|
mask.AllFlags = Permissions
|
||||||
|
|
||||||
for key, value in diff.items():
|
flags = value
|
||||||
new[key] = patch(orig.get(key, {}), diff[key])
|
|
||||||
|
|
||||||
return new
|
try:
|
||||||
|
for flag in flags:
|
||||||
|
mask.add(Permissions[flag])
|
||||||
|
except KeyError as e:
|
||||||
|
raise ValidationError("Invalid permission.") from e
|
||||||
|
|
||||||
|
return mask
|
||||||
|
|
||||||
|
|
||||||
class User(db.Model):
|
class User(db.Model):
|
||||||
@ -86,7 +94,6 @@ class User(db.Model):
|
|||||||
}
|
}
|
||||||
return jwt.encode(payload, app.config.get("SECRET_KEY"), algorithm="HS256")
|
return jwt.encode(payload, app.config.get("SECRET_KEY"), algorithm="HS256")
|
||||||
|
|
||||||
|
|
||||||
def read_token(token):
|
def read_token(token):
|
||||||
"""Read a JWT and validate it.
|
"""Read a JWT and validate it.
|
||||||
|
|
||||||
@ -109,38 +116,16 @@ class User(db.Model):
|
|||||||
|
|
||||||
return data, user
|
return data, user
|
||||||
|
|
||||||
|
def get_schema(self):
|
||||||
|
class Schema(ma.SQLAlchemySchema):
|
||||||
|
class Meta:
|
||||||
|
model = self
|
||||||
|
|
||||||
class PermissionField(fields.Field):
|
username = ma.auto_field()
|
||||||
"""Field that serializes a Permissions bitmask to an array of strings."""
|
register_date = ma.auto_field()
|
||||||
|
permissions = PermissionField(data_key="permissions")
|
||||||
|
|
||||||
def _serialize(self, value, attr, obj, **kwargs):
|
return Schema()
|
||||||
mask = Bitmask()
|
|
||||||
mask.AllFlags = Permissions
|
|
||||||
mask += value
|
|
||||||
return [flag.name for flag in mask]
|
|
||||||
|
|
||||||
def _deserialize(self, value, attr, data, **kwargs):
|
|
||||||
mask = Bitmask()
|
|
||||||
mask.AllFlags = Permissions
|
|
||||||
|
|
||||||
flags = value
|
|
||||||
|
|
||||||
try:
|
|
||||||
for flag in flags:
|
|
||||||
mask.add(Permissions[flag])
|
|
||||||
except KeyError as e:
|
|
||||||
raise ValidationError("Invalid permission.") from e
|
|
||||||
|
|
||||||
return mask
|
|
||||||
|
|
||||||
|
|
||||||
class UserSchema(ma.SQLAlchemySchema):
|
|
||||||
class Meta:
|
|
||||||
model = User
|
|
||||||
|
|
||||||
username = ma.auto_field()
|
|
||||||
register_date = ma.auto_field()
|
|
||||||
permissions = PermissionField(data_key="permissions")
|
|
||||||
|
|
||||||
|
|
||||||
class BlacklistToken(db.Model):
|
class BlacklistToken(db.Model):
|
||||||
@ -176,38 +161,3 @@ class BlacklistToken(db.Model):
|
|||||||
if entry.expires < datetime.datetime.utcnow():
|
if entry.expires < datetime.datetime.utcnow():
|
||||||
db.session.delete(entry)
|
db.session.delete(entry)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
def auth_required(f):
|
|
||||||
"""Decorator to require authentication.
|
|
||||||
|
|
||||||
Passes an argument 'user' to the function, with a User object corresponding
|
|
||||||
to the authenticated session.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@wraps(f)
|
|
||||||
def decorator(*args, **kwargs):
|
|
||||||
token = None
|
|
||||||
auth_header = request.headers.get("Authorization")
|
|
||||||
if auth_header:
|
|
||||||
try:
|
|
||||||
token = auth_header.split(" ")[1]
|
|
||||||
except IndexError:
|
|
||||||
resp = {"status": "fail", "message": "Malformed Authorization header."}
|
|
||||||
return jsonify(resp), 401
|
|
||||||
|
|
||||||
if not token:
|
|
||||||
return jsonify({"status": "fail", "message": "Missing auth token"}), 401
|
|
||||||
|
|
||||||
try:
|
|
||||||
data, user = User.read_token(token)
|
|
||||||
except jwt.ExpiredSignatureError:
|
|
||||||
# if it's expired we don't want it lingering in the db
|
|
||||||
BlacklistToken.check_blacklist(token)
|
|
||||||
return jsonify({"status": "fail", "message": "Token has expired."}), 401
|
|
||||||
except jwt.InvalidTokenError:
|
|
||||||
return jsonify({"status": "fail", "message": "Invalid auth token."}), 401
|
|
||||||
|
|
||||||
return f(user, *args, **kwargs)
|
|
||||||
|
|
||||||
return decorator
|
|
||||||
|
@ -2,17 +2,12 @@ import jwt
|
|||||||
from flask import Blueprint, request, jsonify
|
from flask import Blueprint, request, jsonify
|
||||||
from flask.views import MethodView
|
from flask.views import MethodView
|
||||||
from sachet.server.models import (
|
from sachet.server.models import (
|
||||||
auth_required,
|
|
||||||
patch,
|
|
||||||
Permissions,
|
Permissions,
|
||||||
User,
|
User,
|
||||||
UserSchema,
|
|
||||||
BlacklistToken,
|
BlacklistToken,
|
||||||
)
|
)
|
||||||
|
from sachet.server.views_common import ModelAPI, auth_required
|
||||||
from sachet.server import bcrypt, db
|
from sachet.server import bcrypt, db
|
||||||
from marshmallow import ValidationError
|
|
||||||
|
|
||||||
user_schema = UserSchema()
|
|
||||||
|
|
||||||
users_blueprint = Blueprint("users_blueprint", __name__)
|
users_blueprint = Blueprint("users_blueprint", __name__)
|
||||||
|
|
||||||
@ -51,7 +46,7 @@ class LogoutAPI(MethodView):
|
|||||||
"""Endpoint to revoke a user's token."""
|
"""Endpoint to revoke a user's token."""
|
||||||
|
|
||||||
@auth_required
|
@auth_required
|
||||||
def post(user, self):
|
def post(self, auth_user=None):
|
||||||
post_data = request.get_json()
|
post_data = request.get_json()
|
||||||
token = post_data.get("token")
|
token = post_data.get("token")
|
||||||
if not token:
|
if not token:
|
||||||
@ -71,7 +66,7 @@ class LogoutAPI(MethodView):
|
|||||||
except jwt.InvalidTokenError:
|
except jwt.InvalidTokenError:
|
||||||
return jsonify({"status": "fail", "message": "Invalid auth token."}), 400
|
return jsonify({"status": "fail", "message": "Invalid auth token."}), 400
|
||||||
|
|
||||||
if user == token_user or Permissions.ADMIN in user.permissions:
|
if auth_user == token_user or Permissions.ADMIN in auth_user.permissions:
|
||||||
entry = BlacklistToken(token=token)
|
entry = BlacklistToken(token=token)
|
||||||
db.session.add(entry)
|
db.session.add(entry)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
@ -97,12 +92,12 @@ class ExtendAPI(MethodView):
|
|||||||
"""Endpoint to take a token and get a new one with a later expiry date."""
|
"""Endpoint to take a token and get a new one with a later expiry date."""
|
||||||
|
|
||||||
@auth_required
|
@auth_required
|
||||||
def post(user, self):
|
def post(self, auth_user=None):
|
||||||
token = user.encode_token(jti="renew")
|
token = auth_user.encode_token(jti="renew")
|
||||||
resp = {
|
resp = {
|
||||||
"status": "success",
|
"status": "success",
|
||||||
"message": "Renewed token.",
|
"message": "Renewed token.",
|
||||||
"username": user.username,
|
"username": auth_user.username,
|
||||||
"auth_token": token,
|
"auth_token": token,
|
||||||
}
|
}
|
||||||
return jsonify(resp), 200
|
return jsonify(resp), 200
|
||||||
@ -113,14 +108,15 @@ users_blueprint.add_url_rule(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class UserAPI(MethodView):
|
class UserAPI(ModelAPI):
|
||||||
"""User information API"""
|
"""User information API"""
|
||||||
|
|
||||||
@auth_required
|
@auth_required
|
||||||
def get(user, self, username):
|
def get(self, username, auth_user=None):
|
||||||
info_user = User.query.filter_by(username=username).first()
|
info_user = User.query.filter_by(username=username).first()
|
||||||
|
# only allow user to query themselves, but admin can query anyone
|
||||||
if (not info_user) or (
|
if (not info_user) or (
|
||||||
info_user != user and Permissions.ADMIN not in user.permissions
|
info_user != auth_user and Permissions.ADMIN not in auth_user.permissions
|
||||||
):
|
):
|
||||||
resp = {
|
resp = {
|
||||||
"status": "fail",
|
"status": "fail",
|
||||||
@ -128,64 +124,17 @@ class UserAPI(MethodView):
|
|||||||
}
|
}
|
||||||
return jsonify(resp), 403
|
return jsonify(resp), 403
|
||||||
|
|
||||||
return jsonify(user_schema.dump(info_user))
|
return super().get(info_user)
|
||||||
|
|
||||||
@auth_required
|
@auth_required(require_admin=True)
|
||||||
def patch(user, self, username):
|
def patch(self, username, auth_user=None):
|
||||||
patch_user = User.query.filter_by(username=username).first()
|
patch_user = User.query.filter_by(username=username).first()
|
||||||
|
return super().patch(patch_user)
|
||||||
|
|
||||||
if not patch_user or Permissions.ADMIN not in user.permissions:
|
@auth_required(require_admin=True)
|
||||||
resp = {
|
def put(self, username, auth_user=None):
|
||||||
"status": "fail",
|
|
||||||
"message": "You are not authorized to access this page.",
|
|
||||||
}
|
|
||||||
return jsonify(resp), 403
|
|
||||||
|
|
||||||
patch_json = request.get_json()
|
|
||||||
orig_json = user_schema.dump(patch_user)
|
|
||||||
|
|
||||||
new_json = patch(orig_json, patch_json)
|
|
||||||
|
|
||||||
try:
|
|
||||||
deserialized = user_schema.load(new_json)
|
|
||||||
except ValidationError as e:
|
|
||||||
resp = {"status": "fail", "message": f"Invalid patch: {str(e)}"}
|
|
||||||
return jsonify(resp), 400
|
|
||||||
|
|
||||||
for k, v in deserialized.items():
|
|
||||||
setattr(patch_user, k, v)
|
|
||||||
|
|
||||||
resp = {
|
|
||||||
"status": "success",
|
|
||||||
}
|
|
||||||
return jsonify(resp), 200
|
|
||||||
|
|
||||||
@auth_required
|
|
||||||
def put(user, self, username):
|
|
||||||
put_user = User.query.filter_by(username=username).first()
|
put_user = User.query.filter_by(username=username).first()
|
||||||
|
return super().put(put_user)
|
||||||
if not put_user or Permissions.ADMIN not in user.permissions:
|
|
||||||
resp = {
|
|
||||||
"status": "fail",
|
|
||||||
"message": "You are not authorized to access this page.",
|
|
||||||
}
|
|
||||||
return jsonify(resp), 403
|
|
||||||
|
|
||||||
new_json = request.get_json()
|
|
||||||
|
|
||||||
try:
|
|
||||||
deserialized = user_schema.load(new_json)
|
|
||||||
except ValidationError as e:
|
|
||||||
resp = {"status": "fail", "message": f"Invalid data: {str(e)}"}
|
|
||||||
return jsonify(resp), 400
|
|
||||||
|
|
||||||
for k, v in deserialized.items():
|
|
||||||
setattr(put_user, k, v)
|
|
||||||
|
|
||||||
resp = {
|
|
||||||
"status": "success",
|
|
||||||
}
|
|
||||||
return jsonify(resp), 200
|
|
||||||
|
|
||||||
|
|
||||||
users_blueprint.add_url_rule(
|
users_blueprint.add_url_rule(
|
||||||
|
144
sachet/server/views_common.py
Normal file
144
sachet/server/views_common.py
Normal file
@ -0,0 +1,144 @@
|
|||||||
|
from flask import request, jsonify
|
||||||
|
from flask.views import MethodView
|
||||||
|
from sachet.server.models import Permissions, User, BlacklistToken
|
||||||
|
from functools import wraps
|
||||||
|
from marshmallow import ValidationError
|
||||||
|
import jwt
|
||||||
|
|
||||||
|
|
||||||
|
def auth_required(func=None, *, require_admin=False):
|
||||||
|
"""Decorator to require authentication.
|
||||||
|
|
||||||
|
Passes an argument 'user' to the function, with a User object corresponding
|
||||||
|
to the authenticated session.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# see https://stackoverflow.com/questions/3888158/making-decorators-with-optional-arguments
|
||||||
|
def _decorate(f):
|
||||||
|
@wraps(f)
|
||||||
|
def decorator(*args, **kwargs):
|
||||||
|
token = None
|
||||||
|
auth_header = request.headers.get("Authorization")
|
||||||
|
if auth_header:
|
||||||
|
try:
|
||||||
|
token = auth_header.split(" ")[1]
|
||||||
|
except IndexError:
|
||||||
|
resp = {
|
||||||
|
"status": "fail",
|
||||||
|
"message": "Malformed Authorization header.",
|
||||||
|
}
|
||||||
|
return jsonify(resp), 401
|
||||||
|
|
||||||
|
if not token:
|
||||||
|
return jsonify({"status": "fail", "message": "Missing auth token"}), 401
|
||||||
|
|
||||||
|
try:
|
||||||
|
data, user = User.read_token(token)
|
||||||
|
except jwt.ExpiredSignatureError:
|
||||||
|
# if it's expired we don't want it lingering in the db
|
||||||
|
BlacklistToken.check_blacklist(token)
|
||||||
|
return jsonify({"status": "fail", "message": "Token has expired."}), 401
|
||||||
|
except jwt.InvalidTokenError:
|
||||||
|
return (
|
||||||
|
jsonify({"status": "fail", "message": "Invalid auth token."}),
|
||||||
|
401,
|
||||||
|
)
|
||||||
|
|
||||||
|
if require_admin and Permissions.ADMIN not in user.permissions:
|
||||||
|
return (
|
||||||
|
jsonify(
|
||||||
|
{
|
||||||
|
"status": "fail",
|
||||||
|
"message": "Administrator permission is required to see this page.",
|
||||||
|
}
|
||||||
|
),
|
||||||
|
403,
|
||||||
|
)
|
||||||
|
|
||||||
|
kwargs["auth_user"] = user
|
||||||
|
return f(*args, **kwargs)
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
if func:
|
||||||
|
return _decorate(func)
|
||||||
|
|
||||||
|
return _decorate
|
||||||
|
|
||||||
|
|
||||||
|
def patch(orig, diff):
|
||||||
|
"""Patch the dictionary orig recursively with the dictionary diff."""
|
||||||
|
|
||||||
|
# if we get to a leaf node, just replace it
|
||||||
|
if not isinstance(orig, dict) or not isinstance(diff, dict):
|
||||||
|
return diff
|
||||||
|
|
||||||
|
# deep copy
|
||||||
|
new = {k: v for k, v in orig.items()}
|
||||||
|
|
||||||
|
for key, value in diff.items():
|
||||||
|
new[key] = patch(orig.get(key, {}), diff[key])
|
||||||
|
|
||||||
|
return new
|
||||||
|
|
||||||
|
|
||||||
|
class ModelAPI(MethodView):
|
||||||
|
"""Generic REST API for interacting with models."""
|
||||||
|
|
||||||
|
def get(self, model):
|
||||||
|
return jsonify(model.get_schema().dump(model))
|
||||||
|
|
||||||
|
def patch(self, model):
|
||||||
|
model_schema = model.get_schema()
|
||||||
|
|
||||||
|
if not model:
|
||||||
|
resp = {
|
||||||
|
"status": "fail",
|
||||||
|
"message": "This resource does not exist.",
|
||||||
|
}
|
||||||
|
return jsonify(resp), 404
|
||||||
|
|
||||||
|
patch_json = request.get_json()
|
||||||
|
orig_json = model_schema.dump(model)
|
||||||
|
|
||||||
|
new_json = patch(orig_json, patch_json)
|
||||||
|
|
||||||
|
try:
|
||||||
|
deserialized = model_schema.load(new_json)
|
||||||
|
except ValidationError as e:
|
||||||
|
resp = {"status": "fail", "message": f"Invalid patch: {str(e)}"}
|
||||||
|
return jsonify(resp), 400
|
||||||
|
|
||||||
|
for k, v in deserialized.items():
|
||||||
|
setattr(model, k, v)
|
||||||
|
|
||||||
|
resp = {
|
||||||
|
"status": "success",
|
||||||
|
}
|
||||||
|
return jsonify(resp), 200
|
||||||
|
|
||||||
|
def put(self, model):
|
||||||
|
if not model:
|
||||||
|
resp = {
|
||||||
|
"status": "fail",
|
||||||
|
"message": "This resource does not exist.",
|
||||||
|
}
|
||||||
|
return jsonify(resp), 404
|
||||||
|
|
||||||
|
model_schema = model.get_schema()
|
||||||
|
|
||||||
|
new_json = request.get_json()
|
||||||
|
|
||||||
|
try:
|
||||||
|
deserialized = model_schema.load(new_json)
|
||||||
|
except ValidationError as e:
|
||||||
|
resp = {"status": "fail", "message": f"Invalid data: {str(e)}"}
|
||||||
|
return jsonify(resp), 400
|
||||||
|
|
||||||
|
for k, v in deserialized.items():
|
||||||
|
setattr(model, k, v)
|
||||||
|
|
||||||
|
resp = {
|
||||||
|
"status": "success",
|
||||||
|
}
|
||||||
|
return jsonify(resp), 200
|
@ -3,11 +3,9 @@ import yaml
|
|||||||
from sachet.server.users import manage
|
from sachet.server.users import manage
|
||||||
from click.testing import CliRunner
|
from click.testing import CliRunner
|
||||||
from sachet.server import app, db
|
from sachet.server import app, db
|
||||||
from sachet.server.models import Permissions, UserSchema
|
from sachet.server.models import Permissions, User
|
||||||
from bitmask import Bitmask
|
from bitmask import Bitmask
|
||||||
|
|
||||||
user_schema = UserSchema()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def client():
|
def client():
|
||||||
@ -59,7 +57,7 @@ def validate_info(users):
|
|||||||
]
|
]
|
||||||
|
|
||||||
def _validate(user, info):
|
def _validate(user, info):
|
||||||
info = user_schema.load(info)
|
info = User.get_schema(User).load(info)
|
||||||
|
|
||||||
for k in verify_fields:
|
for k in verify_fields:
|
||||||
assert users[user][k] == info[k]
|
assert users[user][k] == info[k]
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
from sachet.server.models import patch
|
from sachet.server.views_common import patch
|
||||||
|
|
||||||
|
|
||||||
def test_patch():
|
def test_patch():
|
||||||
|
@ -1,9 +1,9 @@
|
|||||||
import pytest
|
import pytest
|
||||||
from bitmask import Bitmask
|
from bitmask import Bitmask
|
||||||
from sachet.server.models import Permissions, UserSchema
|
from sachet.server.models import Permissions, User
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
user_schema = UserSchema()
|
user_schema = User.get_schema(User)
|
||||||
|
|
||||||
|
|
||||||
def test_get(client, tokens, validate_info):
|
def test_get(client, tokens, validate_info):
|
||||||
|
Loading…
Reference in New Issue
Block a user