linted everything with black

This commit is contained in:
dogeystamp 2023-03-30 20:20:09 -04:00
parent f01d1d0e54
commit 27f6703318
Signed by: dogeystamp
GPG Key ID: 7225FE3592EFFA38
13 changed files with 216 additions and 298 deletions

View File

@ -34,3 +34,11 @@ Tests are available with the following command:
``` ```
pytest --cov --cov-report term-missing pytest --cov --cov-report term-missing
``` ```
### linting
Please use the linter before submitting code.
```
black .
```

View File

@ -1,6 +1,7 @@
attrs==22.2.0 attrs==22.2.0
bcrypt==4.0.1 bcrypt==4.0.1
bitmask @ git+https://github.com/dogeystamp/bitmask@8524113fcdc22a570bda77d440374f5f269fdb79 bitmask @ git+https://github.com/dogeystamp/bitmask@8524113fcdc22a570bda77d440374f5f269fdb79
black==23.3.0
click==8.1.3 click==8.1.3
coverage==7.2.1 coverage==7.2.1
exceptiongroup==1.1.0 exceptiongroup==1.1.0
@ -18,7 +19,10 @@ Jinja2==3.1.2
MarkupSafe==2.1.2 MarkupSafe==2.1.2
marshmallow==3.19.0 marshmallow==3.19.0
marshmallow-sqlalchemy==0.29.0 marshmallow-sqlalchemy==0.29.0
mypy-extensions==1.0.0
packaging==23.0 packaging==23.0
pathspec==0.11.1
platformdirs==3.2.0
pluggy==1.0.0 pluggy==1.0.0
PyJWT==2.6.0 PyJWT==2.6.0
pytest==7.2.2 pytest==7.2.2

View File

@ -24,6 +24,7 @@ ma = Marshmallow()
import sachet.server.commands import sachet.server.commands
from sachet.server.users.views import users_blueprint from sachet.server.users.views import users_blueprint
app.register_blueprint(users_blueprint) app.register_blueprint(users_blueprint)
with app.app_context(): with app.app_context():

View File

@ -8,29 +8,45 @@ from bitmask import Bitmask
db_cli = AppGroup("db") db_cli = AppGroup("db")
@db_cli.command("create") @db_cli.command("create")
def create_db(): def create_db():
"""Create all db tables.""" """Create all db tables."""
db.create_all() db.create_all()
@db_cli.command("drop") @db_cli.command("drop")
@click.option('--yes', is_flag=True, expose_value=False, prompt="Are you sure you want to drop all tables?") @click.option(
"--yes",
is_flag=True,
expose_value=False,
prompt="Are you sure you want to drop all tables?",
)
def drop_db(): def drop_db():
"""Drop all db tables.""" """Drop all db tables."""
db.drop_all() db.drop_all()
app.cli.add_command(db_cli) app.cli.add_command(db_cli)
user_cli = AppGroup("user") user_cli = AppGroup("user")
@user_cli.command("create") @user_cli.command("create")
@click.option("--admin", default=False, prompt="Set this user as administrator?", help="Set this user an administrator.") @click.option(
"--admin",
default=False,
prompt="Set this user as administrator?",
help="Set this user an administrator.",
)
@click.option("--username", prompt="Username", help="Sets the username.") @click.option("--username", prompt="Username", help="Sets the username.")
@click.option("--password", @click.option(
prompt="Password", "--password",
hide_input=True, prompt="Password",
help="Sets the user's password (for security, avoid setting this from the command line).") hide_input=True,
help="Sets the user's password (for security, avoid setting this from the command line).",
)
def create_user(admin, username, password): def create_user(admin, username, password):
"""Create a user directly in the database.""" """Create a user directly in the database."""
perms = Bitmask() perms = Bitmask()
@ -38,10 +54,17 @@ def create_user(admin, username, password):
perms.add(Permissions.ADMIN) perms.add(Permissions.ADMIN)
manage.create_user(perms, username, password) manage.create_user(perms, username, password)
@user_cli.command("delete") @user_cli.command("delete")
@click.argument("username") @click.argument("username")
@click.option('--yes', is_flag=True, expose_value=False, prompt=f"Are you sure you want to delete this user?") @click.option(
"--yes",
is_flag=True,
expose_value=False,
prompt=f"Are you sure you want to delete this user?",
)
def delete_user(username): def delete_user(username):
manage.delete_user_by_username(username) manage.delete_user_by_username(username)
app.cli.add_command(user_cli) app.cli.add_command(user_cli)

View File

@ -3,22 +3,27 @@ import yaml
sqlalchemy_base = "sqlite:///sachet" sqlalchemy_base = "sqlite:///sachet"
class BaseConfig: class BaseConfig:
SQLALCHEMY_DATABASE_URI = sqlalchemy_base + ".db" SQLALCHEMY_DATABASE_URI = sqlalchemy_base + ".db"
BCRYPT_LOG_ROUNDS = 13 BCRYPT_LOG_ROUNDS = 13
SQLALCHEMY_TRACK_MODIFICATIONS = False SQLALCHEMY_TRACK_MODIFICATIONS = False
class TestingConfig(BaseConfig): class TestingConfig(BaseConfig):
SQLALCHEMY_DATABASE_URI = sqlalchemy_base + "_test" + ".db" SQLALCHEMY_DATABASE_URI = sqlalchemy_base + "_test" + ".db"
BCRYPT_LOG_ROUNDS = 4 BCRYPT_LOG_ROUNDS = 4
class DevelopmentConfig(BaseConfig): class DevelopmentConfig(BaseConfig):
SQLALCHEMY_DATABASE_URI = sqlalchemy_base + "_dev" + ".db" SQLALCHEMY_DATABASE_URI = sqlalchemy_base + "_dev" + ".db"
BCRYPT_LOG_ROUNDS = 4 BCRYPT_LOG_ROUNDS = 4
class ProductionConfig(BaseConfig): class ProductionConfig(BaseConfig):
pass pass
def overlay_config(base, config_file=None): def overlay_config(base, config_file=None):
"""Reading from a YAML file, this overrides configuration options from the bases above.""" """Reading from a YAML file, this overrides configuration options from the bases above."""
config_locations = [config_file, "/etc/sachet/config.yml", "./config.yml"] config_locations = [config_file, "/etc/sachet/config.yml", "./config.yml"]
@ -33,7 +38,9 @@ def overlay_config(base, config_file=None):
break break
if config_path == "": if config_path == "":
raise FileNotFoundError("Please create a configuration: copy config.yml.example to config.yml.") raise FileNotFoundError(
"Please create a configuration: copy config.yml.example to config.yml."
)
config = yaml.safe_load(open(config_path)) config = yaml.safe_load(open(config_path))
@ -41,6 +48,7 @@ def overlay_config(base, config_file=None):
raise ValueError("Please set secret_key within the configuration.") raise ValueError("Please set secret_key within the configuration.")
from sachet.server import app from sachet.server import app
app.config.from_object(base) app.config.from_object(base)
for k, v in config.items(): for k, v in config.items():

View File

@ -10,11 +10,11 @@ from enum import IntFlag
class Permissions(IntFlag): class Permissions(IntFlag):
CREATE = 1 CREATE = 1
MODIFY = 1<<1 MODIFY = 1 << 1
DELETE = 1<<2 DELETE = 1 << 2
LOCK = 1<<3 LOCK = 1 << 3
LIST = 1<<4 LIST = 1 << 4
ADMIN = 1<<5 ADMIN = 1 << 5
def patch(orig, diff): def patch(orig, diff):
@ -25,13 +25,13 @@ def patch(orig, diff):
return diff return diff
# deep copy # deep copy
new = {k:v for k, v in orig.items()} new = {k: v for k, v in orig.items()}
for key, value in diff.items(): for key, value in diff.items():
new[key] = patch(orig.get(key, {}), diff[key]) new[key] = patch(orig.get(key, {}), diff[key])
return new return new
class User(db.Model): class User(db.Model):
__tablename__ = "users" __tablename__ = "users"
@ -48,7 +48,7 @@ class User(db.Model):
Bitmask listing all permissions. Bitmask listing all permissions.
See the Permissions class for all possible permissions. See the Permissions class for all possible permissions.
Also, see https://github.com/dogeystamp/bitmask for information on how Also, see https://github.com/dogeystamp/bitmask for information on how
to use this field. to use this field.
""" """
@ -66,31 +66,25 @@ class User(db.Model):
self.permissions_number = mask.value self.permissions_number = mask.value
db.session.commit() db.session.commit()
def __init__(self, username, password, permissions): def __init__(self, username, password, permissions):
permissions.AllFlags = Permissions permissions.AllFlags = Permissions
self.permissions = permissions self.permissions = permissions
self.password = bcrypt.generate_password_hash( self.password = bcrypt.generate_password_hash(
password, app.config.get("BCRYPT_LOG_ROUNDS") password, app.config.get("BCRYPT_LOG_ROUNDS")
).decode() ).decode()
self.username = username self.username = username
self.register_date = datetime.datetime.now() self.register_date = datetime.datetime.now()
def encode_token(self, jti=None): def encode_token(self, jti=None):
"""Generates an authentication token""" """Generates an authentication token"""
payload = { payload = {
"exp": datetime.datetime.utcnow() + datetime.timedelta(days=7), "exp": datetime.datetime.utcnow() + datetime.timedelta(days=7),
"iat": datetime.datetime.utcnow(), "iat": datetime.datetime.utcnow(),
"sub": self.username, "sub": self.username,
"jti": jti "jti": jti,
} }
return jwt.encode( return jwt.encode(payload, app.config.get("SECRET_KEY"), algorithm="HS256")
payload,
app.config.get("SECRET_KEY"),
algorithm="HS256"
)
class PermissionField(fields.Field): class PermissionField(fields.Field):
@ -133,9 +127,9 @@ class BlacklistToken(db.Model):
""" """
__tablename__ = "blacklist_tokens" __tablename__ = "blacklist_tokens"
id = db.Column(db.Integer, primary_key=True, autoincrement=True) id = db.Column(db.Integer, primary_key=True, autoincrement=True)
token = db.Column(db.String(500), unique=True, nullable=False) token = db.Column(db.String(500), unique=True, nullable=False)
expires = db.Column(db.DateTime, nullable=False) expires = db.Column(db.DateTime, nullable=False)
def __init__(self, token): def __init__(self, token):
@ -190,6 +184,7 @@ def auth_required(f):
Passes an argument 'user' to the function, with a User object corresponding Passes an argument 'user' to the function, with a User object corresponding
to the authenticated session. to the authenticated session.
""" """
@wraps(f) @wraps(f)
def decorator(*args, **kwargs): def decorator(*args, **kwargs):
token = None token = None
@ -198,10 +193,7 @@ def auth_required(f):
try: try:
token = auth_header.split(" ")[1] token = auth_header.split(" ")[1]
except IndexError: except IndexError:
resp = { resp = {"status": "fail", "message": "Malformed Authorization header."}
"status": "fail",
"message": "Malformed Authorization header."
}
return jsonify(resp), 401 return jsonify(resp), 401
if not token: if not token:

View File

@ -1,6 +1,7 @@
from sachet.server import app, db from sachet.server import app, db
from sachet.server.models import User from sachet.server.models import User
def create_user(permissions, username, password): def create_user(permissions, username, password):
# to reduce confusion with API endpoints # to reduce confusion with API endpoints
forbidden = {"login", "logout", "extend"} forbidden = {"login", "logout", "extend"}
@ -10,16 +11,13 @@ def create_user(permissions, username, password):
user = User.query.filter_by(username=username).first() user = User.query.filter_by(username=username).first()
if not user: if not user:
user = User( user = User(username=username, password=password, permissions=permissions)
username=username,
password=password,
permissions=permissions
)
db.session.add(user) db.session.add(user)
db.session.commit() db.session.commit()
else: else:
raise KeyError(f"User '{username}' already exists.") raise KeyError(f"User '{username}' already exists.")
def delete_user_by_username(username): def delete_user_by_username(username):
user = User.query.filter_by(username=username).first() user = User.query.filter_by(username=username).first()

View File

@ -1,7 +1,15 @@
import jwt import jwt
from flask import Blueprint, request, jsonify from flask import Blueprint, request, jsonify
from flask.views import MethodView from flask.views import MethodView
from sachet.server.models import auth_required, read_token, patch, Permissions, User, UserSchema, BlacklistToken from sachet.server.models import (
auth_required,
read_token,
patch,
Permissions,
User,
UserSchema,
BlacklistToken,
)
from sachet.server import bcrypt, db from sachet.server import bcrypt, db
from marshmallow import ValidationError from marshmallow import ValidationError
@ -9,26 +17,22 @@ user_schema = UserSchema()
users_blueprint = Blueprint("users_blueprint", __name__) users_blueprint = Blueprint("users_blueprint", __name__)
class LoginAPI(MethodView): class LoginAPI(MethodView):
def post(self): def post(self):
post_data = request.get_json() post_data = request.get_json()
user = User.query.filter_by(username=post_data.get("username")).first() user = User.query.filter_by(username=post_data.get("username")).first()
if not user: if not user:
resp = { resp = {"status": "fail", "message": "Invalid credentials."}
"status": "fail",
"message": "Invalid credentials."
}
return jsonify(resp), 401 return jsonify(resp), 401
if bcrypt.check_password_hash( if bcrypt.check_password_hash(user.password, post_data.get("password", "")):
user.password, post_data.get("password", "")
):
token = user.encode_token() token = user.encode_token()
resp = { resp = {
"status": "success", "status": "success",
"message": "Logged in.", "message": "Logged in.",
"username": user.username, "username": user.username,
"auth_token": token "auth_token": token,
} }
return jsonify(resp), 200 return jsonify(resp), 200
else: else:
@ -40,9 +44,7 @@ class LoginAPI(MethodView):
users_blueprint.add_url_rule( users_blueprint.add_url_rule(
"/users/login", "/users/login", view_func=LoginAPI.as_view("login_api"), methods=["POST"]
view_func=LoginAPI.as_view("login_api"),
methods=['POST']
) )
@ -54,7 +56,10 @@ class LogoutAPI(MethodView):
post_data = request.get_json() post_data = request.get_json()
token = post_data.get("token") token = post_data.get("token")
if not token: if not token:
return jsonify({"status": "fail", "message": "Specify a token to revoke."}), 400 return (
jsonify({"status": "fail", "message": "Specify a token to revoke."}),
400,
)
res = BlacklistToken.check_blacklist(token) res = BlacklistToken.check_blacklist(token)
if res: if res:
@ -73,13 +78,19 @@ class LogoutAPI(MethodView):
db.session.commit() db.session.commit()
return jsonify({"status": "success", "message": "Token revoked."}), 200 return jsonify({"status": "success", "message": "Token revoked."}), 200
else: else:
return jsonify({"status": "fail", "message": "You are not allowed to revoke this token."}), 403 return (
jsonify(
{
"status": "fail",
"message": "You are not allowed to revoke this token.",
}
),
403,
)
users_blueprint.add_url_rule( users_blueprint.add_url_rule(
"/users/logout", "/users/logout", view_func=LogoutAPI.as_view("logout_api"), methods=["POST"]
view_func=LogoutAPI.as_view("logout_api"),
methods=['POST']
) )
@ -93,27 +104,28 @@ class ExtendAPI(MethodView):
"status": "success", "status": "success",
"message": "Renewed token.", "message": "Renewed token.",
"username": user.username, "username": user.username,
"auth_token": token "auth_token": token,
} }
return jsonify(resp), 200 return jsonify(resp), 200
users_blueprint.add_url_rule( users_blueprint.add_url_rule(
"/users/extend", "/users/extend", view_func=ExtendAPI.as_view("extend_api"), methods=["POST"]
view_func=ExtendAPI.as_view("extend_api"),
methods=['POST']
) )
class UserAPI(MethodView): class UserAPI(MethodView):
"""User information API""" """User information API"""
@auth_required @auth_required
def get(user, self, username): def get(user, self, username):
info_user = User.query.filter_by(username=username).first() info_user = User.query.filter_by(username=username).first()
if (not info_user) or (info_user != user and Permissions.ADMIN not in user.permissions): if (not info_user) or (
info_user != user and Permissions.ADMIN not in user.permissions
):
resp = { resp = {
"status": "fail", "status": "fail",
"message": "You are not authorized to view this page." "message": "You are not authorized to view this page.",
} }
return jsonify(resp), 403 return jsonify(resp), 403
@ -126,22 +138,19 @@ class UserAPI(MethodView):
if not patch_user or Permissions.ADMIN not in user.permissions: if not patch_user or Permissions.ADMIN not in user.permissions:
resp = { resp = {
"status": "fail", "status": "fail",
"message": "You are not authorized to access this page." "message": "You are not authorized to access this page.",
} }
return jsonify(resp), 403 return jsonify(resp), 403
patch_json = request.get_json() patch_json = request.get_json()
orig_json = user_schema.dump(patch_user) orig_json = user_schema.dump(patch_user)
new_json = patch(orig_json, patch_json) new_json = patch(orig_json, patch_json)
try: try:
deserialized = user_schema.load(new_json) deserialized = user_schema.load(new_json)
except ValidationError as e: except ValidationError as e:
resp = { resp = {"status": "fail", "message": f"Invalid patch: {str(e)}"}
"status": "fail",
"message": f"Invalid patch: {str(e)}"
}
return jsonify(resp), 400 return jsonify(resp), 400
for k, v in deserialized.items(): for k, v in deserialized.items():
@ -159,7 +168,7 @@ class UserAPI(MethodView):
if not put_user or Permissions.ADMIN not in user.permissions: if not put_user or Permissions.ADMIN not in user.permissions:
resp = { resp = {
"status": "fail", "status": "fail",
"message": "You are not authorized to access this page." "message": "You are not authorized to access this page.",
} }
return jsonify(resp), 403 return jsonify(resp), 403
@ -168,10 +177,7 @@ class UserAPI(MethodView):
try: try:
deserialized = user_schema.load(new_json) deserialized = user_schema.load(new_json)
except ValidationError as e: except ValidationError as e:
resp = { resp = {"status": "fail", "message": f"Invalid data: {str(e)}"}
"status": "fail",
"message": f"Invalid data: {str(e)}"
}
return jsonify(resp), 400 return jsonify(resp), 400
for k, v in deserialized.items(): for k, v in deserialized.items():
@ -182,8 +188,9 @@ class UserAPI(MethodView):
} }
return jsonify(resp), 200 return jsonify(resp), 200
users_blueprint.add_url_rule( users_blueprint.add_url_rule(
"/users/<username>", "/users/<username>",
view_func=UserAPI.as_view("user_api"), view_func=UserAPI.as_view("user_api"),
methods=['GET', 'PATCH', 'PUT'] methods=["GET", "PATCH", "PUT"],
) )

View File

@ -8,6 +8,7 @@ from bitmask import Bitmask
user_schema = UserSchema() user_schema = UserSchema()
@pytest.fixture @pytest.fixture
def client(): def client():
"""Flask application with DB already set up and ready.""" """Flask application with DB already set up and ready."""
@ -19,7 +20,7 @@ def client():
yield client yield client
db.session.remove() db.session.remove()
db.drop_all() db.drop_all()
@pytest.fixture @pytest.fixture
def flask_app_bare(): def flask_app_bare():
@ -37,23 +38,13 @@ def users(client):
Returns a dictionary with all the info for each user. Returns a dictionary with all the info for each user.
""" """
userinfo = dict( userinfo = dict(
jeff = dict( jeff=dict(password="1234", permissions=Bitmask()),
password = "1234", administrator=dict(password="4321", permissions=Bitmask(Permissions.ADMIN)),
permissions = Bitmask() )
),
administrator = dict(
password = "4321",
permissions = Bitmask(Permissions.ADMIN)
),
)
for user, info in userinfo.items(): for user, info in userinfo.items():
info["username"] = user info["username"] = user
manage.create_user( manage.create_user(info["permissions"], info["username"], info["password"])
info["permissions"],
info["username"],
info["password"]
)
return userinfo return userinfo
@ -86,10 +77,10 @@ def tokens(client, users):
toks = {} toks = {}
for user, creds in users.items(): for user, creds in users.items():
resp = client.post("/users/login", json={ resp = client.post(
"username": creds["username"], "/users/login",
"password": creds["password"] json={"username": creds["username"], "password": creds["password"]},
}) )
resp_json = resp.get_json() resp_json = resp.get_json()
token = resp_json.get("auth_token") token = resp_json.get("auth_token")
assert token is not None and token != "" assert token is not None and token != ""

View File

@ -3,81 +3,64 @@ import jwt
from sachet.server import db from sachet.server import db
from sachet.server.users import manage from sachet.server.users import manage
def test_reserved_users(client): def test_reserved_users(client):
"""Test that the server prevents reserved endpoints from being registered as usernames.""" """Test that the server prevents reserved endpoints from being registered as usernames."""
for user in ["login", "logout", "extend"]: for user in ["login", "logout", "extend"]:
with pytest.raises(KeyError): with pytest.raises(KeyError):
manage.create_user(False, user, "") manage.create_user(False, user, "")
def test_unauth_perms(client): def test_unauth_perms(client):
"""Test endpoints to see if they allow unauthenticated users.""" """Test endpoints to see if they allow unauthenticated users."""
resp = client.get("/users/jeff") resp = client.get("/users/jeff")
assert resp.status_code == 401 assert resp.status_code == 401
def test_malformed_authorization(client): def test_malformed_authorization(client):
"""Test attempting authorization incorrectly.""" """Test attempting authorization incorrectly."""
# incorrect token # incorrect token
token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c" token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c"
resp = client.get( resp = client.get("/users/jeff", headers={"Authorization": f"bearer {token}"})
"/users/jeff",
headers={
"Authorization": f"bearer {token}"
}
)
assert resp.status_code == 401 assert resp.status_code == 401
# token for incorrect user (but properly signed) # token for incorrect user (but properly signed)
token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.nZ86hUWPdG43W6HVSGFy6DJnDVOZhx8a73LhQ3gIxY8" token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.nZ86hUWPdG43W6HVSGFy6DJnDVOZhx8a73LhQ3gIxY8"
resp = client.get( resp = client.get("/users/jeff", headers={"Authorization": f"bearer {token}"})
"/users/jeff",
headers={
"Authorization": f"bearer {token}"
}
)
assert resp.status_code == 401 assert resp.status_code == 401
# invalid token # invalid token
token = "not a.real JWT.token" token = "not a.real JWT.token"
resp = client.get( resp = client.get("/users/jeff", headers={"Authorization": f"bearer {token}"})
"/users/jeff",
headers={
"Authorization": f"bearer {token}"
}
)
assert resp.status_code == 401 assert resp.status_code == 401
# missing token # missing token
resp = client.get( resp = client.get("/users/jeff", headers={"Authorization": "bearer"})
"/users/jeff",
headers={
"Authorization": "bearer"
}
)
assert resp.status_code == 401 assert resp.status_code == 401
def test_login(client, users): def test_login(client, users):
"""Test logging in.""" """Test logging in."""
# wrong password # wrong password
resp = client.post("/users/login", json={ resp = client.post(
"username": "jeff", "/users/login",
"password": users["jeff"]["password"] + "garbage" json={"username": "jeff", "password": users["jeff"]["password"] + "garbage"},
}) )
assert resp.status_code == 401 assert resp.status_code == 401
# wrong user # wrong user
resp = client.post("/users/login", json={ resp = client.post(
"username": "jeffery", "/users/login",
"password": users["jeff"]["password"] + "garbage" json={"username": "jeffery", "password": users["jeff"]["password"] + "garbage"},
}) )
assert resp.status_code == 401 assert resp.status_code == 401
# logging in correctly # logging in correctly
resp = client.post("/users/login", json={ resp = client.post(
"username": "jeff", "/users/login", json={"username": "jeff", "password": users["jeff"]["password"]}
"password": users["jeff"]["password"] )
})
assert resp.status_code == 200 assert resp.status_code == 200
resp_json = resp.get_json() resp_json = resp.get_json()
assert resp_json.get("status") == "success" assert resp_json.get("status") == "success"
@ -90,10 +73,8 @@ def test_extend(client, tokens, validate_info):
"""Test extending the token lifespan (get a new one with later expiry).""" """Test extending the token lifespan (get a new one with later expiry)."""
# obtain new token # obtain new token
resp = client.post("/users/extend", resp = client.post(
headers={ "/users/extend", headers={"Authorization": f"Bearer {tokens['jeff']}"}
"Authorization": f"Bearer {tokens['jeff']}"
}
) )
assert resp.status_code == 200 assert resp.status_code == 200
resp_json = resp.get_json() resp_json = resp.get_json()
@ -103,20 +84,14 @@ def test_extend(client, tokens, validate_info):
# revoke old token # revoke old token
resp = client.post("/users/logout", json={ resp = client.post(
"token": tokens["jeff"] "/users/logout",
}, json={"token": tokens["jeff"]},
headers={ headers={"Authorization": f"bearer {tokens['jeff']}"},
"Authorization": f"bearer {tokens['jeff']}"
}
) )
# log in with the new token # log in with the new token
resp = client.get("/users/jeff", resp = client.get("/users/jeff", headers={"Authorization": f"Bearer {new_token}"})
headers={
"Authorization": f"Bearer {new_token}"
}
)
assert resp.status_code == 200 assert resp.status_code == 200
resp_json = resp.get_json() resp_json = resp.get_json()
validate_info("jeff", resp_json) validate_info("jeff", resp_json)
@ -126,98 +101,79 @@ def test_logout(client, tokens, validate_info):
"""Test logging out.""" """Test logging out."""
# unauthenticated # unauthenticated
resp = client.post("/users/logout", json={ resp = client.post(
"token": tokens["jeff"] "/users/logout",
}, json={"token": tokens["jeff"]},
) )
assert resp.status_code == 401 assert resp.status_code == 401
# missing token # missing token
resp = client.post("/users/logout", json={ resp = client.post(
"/users/logout", json={}, headers={"Authorization": f"bearer {tokens['jeff']}"}
},
headers={
"Authorization": f"bearer {tokens['jeff']}"
}
) )
assert resp.status_code == 400 assert resp.status_code == 400
# invalid token # invalid token
resp = client.post("/users/logout", json={ resp = client.post(
"token": "not.real.jwt" "/users/logout",
}, json={"token": "not.real.jwt"},
headers={ headers={"Authorization": f"bearer {tokens['jeff']}"},
"Authorization": f"bearer {tokens['jeff']}"
}
) )
assert resp.status_code == 400 assert resp.status_code == 400
# wrong user's token # wrong user's token
resp = client.post("/users/logout", json={ resp = client.post(
"token": tokens["administrator"] "/users/logout",
}, json={"token": tokens["administrator"]},
headers={ headers={"Authorization": f"bearer {tokens['jeff']}"},
"Authorization": f"bearer {tokens['jeff']}"
}
) )
assert resp.status_code == 403 assert resp.status_code == 403
# check that we can access this endpoint before logging out # check that we can access this endpoint before logging out
resp = client.get("/users/jeff", resp = client.get(
headers={ "/users/jeff", headers={"Authorization": f"bearer {tokens['jeff']}"}
"Authorization": f"bearer {tokens['jeff']}"
}
) )
assert resp.status_code == 200 assert resp.status_code == 200
validate_info("jeff", resp.get_json()) validate_info("jeff", resp.get_json())
# valid logout # valid logout
resp = client.post("/users/logout", json={ resp = client.post(
"token": tokens["jeff"] "/users/logout",
}, json={"token": tokens["jeff"]},
headers={ headers={"Authorization": f"bearer {tokens['jeff']}"},
"Authorization": f"bearer {tokens['jeff']}"
}
) )
assert resp.status_code == 200 assert resp.status_code == 200
# check that the logout worked # check that the logout worked
resp = client.get("/users/jeff", resp = client.get(
headers={ "/users/jeff", headers={"Authorization": f"bearer {tokens['jeff']}"}
"Authorization": f"bearer {tokens['jeff']}"
}
) )
assert resp.status_code == 401 assert resp.status_code == 401
def test_admin_revoke(client, tokens, validate_info): def test_admin_revoke(client, tokens, validate_info):
"""Test that an admin can revoke any token from other users.""" """Test that an admin can revoke any token from other users."""
resp = client.post("/users/logout", json={ resp = client.post(
"token": tokens["jeff"] "/users/logout",
}, json={"token": tokens["jeff"]},
headers={ headers={"Authorization": f"bearer {tokens['administrator']}"},
"Authorization": f"bearer {tokens['administrator']}"
}
) )
assert resp.status_code == 200 assert resp.status_code == 200
# check that the logout worked # check that the logout worked
resp = client.get("/users/jeff", resp = client.get(
headers={ "/users/jeff", headers={"Authorization": f"bearer {tokens['jeff']}"}
"Authorization": f"bearer {tokens['jeff']}"
}
) )
assert resp.status_code == 401 assert resp.status_code == 401
# try revoking twice # try revoking twice
resp = client.post("/users/logout", json={ resp = client.post(
"token": tokens["jeff"] "/users/logout",
}, json={"token": tokens["jeff"]},
headers={ headers={"Authorization": f"bearer {tokens['administrator']}"},
"Authorization": f"bearer {tokens['administrator']}"
}
) )
assert resp.status_code == 400 assert resp.status_code == 400

View File

@ -5,6 +5,7 @@ from sqlalchemy import inspect
from sachet.server.models import User from sachet.server.models import User
def test_db(flask_app_bare, cli): def test_db(flask_app_bare, cli):
"""Test the CLI's ability to create and drop the DB.""" """Test the CLI's ability to create and drop the DB."""
# make tables # make tables

View File

@ -1,74 +1,27 @@
from sachet.server.models import patch from sachet.server.models import patch
def test_patch(): def test_patch():
"""Tests sachet/server/models.py's patch() method for dicts.""" """Tests sachet/server/models.py's patch() method for dicts."""
assert patch( assert patch(dict(), dict()) == dict()
dict(),
dict()
) == dict()
assert patch( assert patch(dict(key="value"), dict()) == dict(key="value")
dict(key="value"),
dict()
) == dict(key="value")
assert patch( assert patch(dict(key="value"), dict(key="newvalue")) == dict(key="newvalue")
dict(key="value"),
dict(key="newvalue")
) == dict(key="newvalue")
assert patch( assert patch(dict(key="value"), dict(key="newvalue")) == dict(key="newvalue")
dict(key="value"),
dict(key="newvalue")
) == dict(key="newvalue")
assert patch( assert patch(dict(key="value"), dict(key2="other_value")) == dict(
dict(key="value"), key="value", key2="other_value"
dict(key2="other_value")
) == dict(
key="value",
key2="other_value"
) )
assert patch( assert patch(
dict( dict(nest=dict(key="value", key2="other_value")),
nest = dict( dict(top_key="newvalue", nest=dict(key2="new_other_value")),
key="value", ) == dict(top_key="newvalue", nest=dict(key="value", key2="new_other_value"))
key2="other_value"
)
),
dict(
top_key="newvalue",
nest = dict(
key2 = "new_other_value"
)
)
) == dict(
top_key="newvalue",
nest = dict(
key="value",
key2="new_other_value"
)
)
assert patch( assert patch(
dict( dict(nest=dict(key="value", list=[1, 2, 3, 4, 5])),
nest = dict( dict(top_key="newvalue", nest=dict(list=[3, 1, 4, 1, 5])),
key="value", ) == dict(top_key="newvalue", nest=dict(key="value", list=[3, 1, 4, 1, 5]))
list=[1, 2, 3, 4, 5]
)
),
dict(
top_key="newvalue",
nest = dict(
list = [3, 1, 4, 1, 5]
)
)
) == dict(
top_key="newvalue",
nest = dict(
key="value",
list=[3, 1, 4, 1, 5]
)
)

View File

@ -5,81 +5,66 @@ from datetime import datetime
user_schema = UserSchema() user_schema = UserSchema()
def test_get(client, tokens, validate_info): def test_get(client, tokens, validate_info):
"""Test accessing the user information endpoint as a normal user.""" """Test accessing the user information endpoint as a normal user."""
# access user info endpoint # access user info endpoint
resp = client.get( resp = client.get(
"/users/jeff", "/users/jeff", headers={"Authorization": f"bearer {tokens['jeff']}"}
headers={
"Authorization": f"bearer {tokens['jeff']}"
}
) )
assert resp.status_code == 200 assert resp.status_code == 200
validate_info("jeff", resp.get_json()) validate_info("jeff", resp.get_json())
# access other user's info endpoint # access other user's info endpoint
resp = client.get( resp = client.get(
"/users/administrator", "/users/administrator", headers={"Authorization": f"bearer {tokens['jeff']}"}
headers={
"Authorization": f"bearer {tokens['jeff']}"
}
) )
assert resp.status_code == 403 assert resp.status_code == 403
def test_userinfo_admin(client, tokens, validate_info): def test_userinfo_admin(client, tokens, validate_info):
"""Test accessing other user's information as an admin.""" """Test accessing other user's information as an admin."""
# first test that admin can access its own info # first test that admin can access its own info
resp = client.get( resp = client.get(
"/users/administrator", "/users/administrator",
headers={ headers={"Authorization": f"bearer {tokens['administrator']}"},
"Authorization": f"bearer {tokens['administrator']}"
}
) )
assert resp.status_code == 200 assert resp.status_code == 200
validate_info("administrator", resp.get_json()) validate_info("administrator", resp.get_json())
# now test accessing other user's info # now test accessing other user's info
resp = client.get( resp = client.get(
"/users/jeff", "/users/jeff", headers={"Authorization": f"bearer {tokens['administrator']}"}
headers={
"Authorization": f"bearer {tokens['administrator']}"
}
) )
assert resp.status_code == 200 assert resp.status_code == 200
validate_info("jeff", resp.get_json()) validate_info("jeff", resp.get_json())
def test_patch(client, users, tokens, validate_info): def test_patch(client, users, tokens, validate_info):
"""Test modifying user information as an administrator.""" """Test modifying user information as an administrator."""
# try with regular user to make sure it doesn't work # try with regular user to make sure it doesn't work
resp = client.patch( resp = client.patch(
"/users/jeff", "/users/jeff",
json = { "permissions": ["ADMIN"] }, json={"permissions": ["ADMIN"]},
headers={ headers={"Authorization": f"bearer {tokens['jeff']}"},
"Authorization": f"bearer {tokens['jeff']}"
}
) )
assert resp.status_code == 403 assert resp.status_code == 403
# test malformed patch # test malformed patch
resp = client.patch( resp = client.patch(
"/users/jeff", "/users/jeff",
json = "hurr durr", json="hurr durr",
headers={ headers={"Authorization": f"bearer {tokens['administrator']}"},
"Authorization": f"bearer {tokens['administrator']}"
}
) )
assert resp.status_code == 400 assert resp.status_code == 400
resp = client.patch( resp = client.patch(
"/users/jeff", "/users/jeff",
json = { "permissions": ["ADMIN"] }, json={"permissions": ["ADMIN"]},
headers={ headers={"Authorization": f"bearer {tokens['administrator']}"},
"Authorization": f"bearer {tokens['administrator']}"
}
) )
assert resp.status_code == 200 assert resp.status_code == 200
@ -88,37 +73,31 @@ def test_patch(client, users, tokens, validate_info):
# request new info # request new info
resp = client.get( resp = client.get(
"/users/jeff", "/users/jeff", headers={"Authorization": f"bearer {tokens['jeff']}"}
headers={
"Authorization": f"bearer {tokens['jeff']}"
}
) )
assert resp.status_code == 200 assert resp.status_code == 200
validate_info("jeff", resp.get_json()) validate_info("jeff", resp.get_json())
def test_put(client, users, tokens, validate_info): def test_put(client, users, tokens, validate_info):
"""Test replacing user information as an administrator.""" """Test replacing user information as an administrator."""
# try with regular user to make sure it doesn't work # try with regular user to make sure it doesn't work
resp = client.patch( resp = client.patch(
"/users/jeff", "/users/jeff",
json = dict(), json=dict(),
headers={ headers={"Authorization": f"bearer {tokens['jeff']}"},
"Authorization": f"bearer {tokens['jeff']}"
}
) )
assert resp.status_code == 403 assert resp.status_code == 403
new_data = {k:v for k, v in users["jeff"].items()} new_data = {k: v for k, v in users["jeff"].items()}
new_data["permissions"] = Bitmask(Permissions.ADMIN) new_data["permissions"] = Bitmask(Permissions.ADMIN)
new_data["register_date"] = datetime(2022,2,2,0,0,0) new_data["register_date"] = datetime(2022, 2, 2, 0, 0, 0)
resp = client.put( resp = client.put(
"/users/jeff", "/users/jeff",
json = user_schema.dump(new_data), json=user_schema.dump(new_data),
headers={ headers={"Authorization": f"bearer {tokens['administrator']}"},
"Authorization": f"bearer {tokens['administrator']}"
}
) )
assert resp.status_code == 200 assert resp.status_code == 200
@ -127,10 +106,7 @@ def test_put(client, users, tokens, validate_info):
# request new info # request new info
resp = client.get( resp = client.get(
"/users/jeff", "/users/jeff", headers={"Authorization": f"bearer {tokens['jeff']}"}
headers={
"Authorization": f"bearer {tokens['jeff']}"
}
) )
assert resp.status_code == 200 assert resp.status_code == 200
validate_info("jeff", resp.get_json()) validate_info("jeff", resp.get_json())