diff --git a/requirements.txt b/requirements.txt index 79eb26c..16cc574 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,6 +14,8 @@ iniconfig==2.0.0 itsdangerous==2.1.2 Jinja2==3.1.2 MarkupSafe==2.1.2 +marshmallow==3.19.0 +marshmallow-sqlalchemy==0.29.0 packaging==23.0 pluggy==1.0.0 PyJWT==2.6.0 diff --git a/sachet/server/__init__.py b/sachet/server/__init__.py index ace6535..e323c0c 100644 --- a/sachet/server/__init__.py +++ b/sachet/server/__init__.py @@ -2,6 +2,7 @@ import os from flask import Flask from flask_cors import CORS from flask_sqlalchemy import SQLAlchemy +from flask_marshmallow import Marshmallow from flask_bcrypt import Bcrypt from .config import DevelopmentConfig, ProductionConfig, TestingConfig, overlay_config @@ -18,6 +19,7 @@ else: bcrypt = Bcrypt(app) db = SQLAlchemy(app) +ma = Marshmallow() import sachet.server.commands diff --git a/sachet/server/models.py b/sachet/server/models.py index 532739b..5ede7a9 100644 --- a/sachet/server/models.py +++ b/sachet/server/models.py @@ -1,4 +1,4 @@ -from sachet.server import app, db, bcrypt +from sachet.server import app, db, ma, bcrypt from flask import request, jsonify from functools import wraps import datetime @@ -35,6 +35,15 @@ class User(db.Model): ) +class UserSchema(ma.SQLAlchemySchema): + class Meta: + model = User + + username = ma.auto_field() + register_date = ma.auto_field() + admin = ma.auto_field() + + class BlacklistToken(db.Model): """Token that has been revoked (but has not expired yet.) diff --git a/sachet/server/users/views.py b/sachet/server/users/views.py index 28e676d..808ea17 100644 --- a/sachet/server/users/views.py +++ b/sachet/server/users/views.py @@ -1,9 +1,11 @@ import jwt from flask import Blueprint, request, jsonify from flask.views import MethodView -from sachet.server.models import auth_required, read_token, User, BlacklistToken +from sachet.server.models import auth_required, read_token, User, UserSchema, BlacklistToken from sachet.server import bcrypt, db +user_schema = UserSchema() + users_blueprint = Blueprint("users_blueprint", __name__) class LoginAPI(MethodView): @@ -114,10 +116,7 @@ class UserAPI(MethodView): } return jsonify(resp), 403 - return jsonify({ - "username": info_user.username, - "admin": info_user.admin, - }) + return jsonify(user_schema.dump(info_user)) users_blueprint.add_url_rule( "/users/", diff --git a/tests/conftest.py b/tests/conftest.py index 6f1e692..a218ece 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -57,9 +57,15 @@ def users(client): @pytest.fixture def validate_info(users): """Given a dictionary, validate the information against a given user's info.""" + + verify_fields = [ + "username", + "admin", + ] + def _validate(user, info): - for k, v in info.items(): - assert users[user][k] == v + for k in verify_fields: + assert users[user][k] == info[k] return _validate