/users: added endpoint for creating, listing users

This commit is contained in:
dogeystamp 2023-05-18 22:06:16 -04:00
parent 71683662ea
commit cbf1310e1b
Signed by: dogeystamp
GPG Key ID: 7225FE3592EFFA38
5 changed files with 113 additions and 6 deletions

View File

@ -91,6 +91,7 @@ class User(db.Model):
password, current_app.config.get("BCRYPT_LOG_ROUNDS") password, current_app.config.get("BCRYPT_LOG_ROUNDS")
).decode() ).decode()
self.username = username self.username = username
self.url = url_for("users_blueprint.user_list_api", username=self.username)
self.register_date = datetime.datetime.now() self.register_date = datetime.datetime.now()
def encode_token(self, jti=None): def encode_token(self, jti=None):
@ -133,7 +134,8 @@ class User(db.Model):
model = self model = self
username = ma.auto_field() username = ma.auto_field()
register_date = ma.auto_field() register_date = ma.auto_field(dump_only=True)
password = ma.auto_field(load_only=True, required=False)
permissions = PermissionField() permissions = PermissionField()
return Schema() return Schema()

View File

@ -6,7 +6,7 @@ from sachet.server.models import (
User, User,
BlacklistToken, BlacklistToken,
) )
from sachet.server.views_common import ModelAPI, auth_required from sachet.server.views_common import ModelAPI, ModelListAPI, auth_required
from sachet.server import bcrypt, db from sachet.server import bcrypt, db
users_blueprint = Blueprint("users_blueprint", __name__) users_blueprint = Blueprint("users_blueprint", __name__)
@ -142,3 +142,21 @@ users_blueprint.add_url_rule(
view_func=UserAPI.as_view("user_api"), view_func=UserAPI.as_view("user_api"),
methods=["GET", "PATCH", "PUT"], methods=["GET", "PATCH", "PUT"],
) )
class UserListAPI(ModelListAPI):
@auth_required(required_permissions=(Permissions.ADMIN,))
def post(self, auth_user=None):
data = request.get_json()
return super().post(User, data)
@auth_required(required_permissions=(Permissions.ADMIN,))
def get(self, auth_user=None):
return super().get(User)
users_blueprint.add_url_rule(
"/users",
view_func=UserListAPI.as_view("user_list_api"),
methods=["GET", "POST"],
)

View File

@ -145,10 +145,12 @@ def validate_info(users):
] ]
def _validate(user, info): def _validate(user, info):
info = User.get_schema(User).load(info) schema = User.get_schema(User)
dumped = schema.dump(users[user])
for k in verify_fields: for k in verify_fields:
assert users[user][k] == info[k] assert dumped[k] == info[k]
return _validate return _validate

View File

@ -77,6 +77,89 @@ def test_files(client, users, auth):
paginate(forwards=False, page=end_page) paginate(forwards=False, page=end_page)
def test_users(client, users, auth):
"""Test /users endpoint."""
# suck it for me violating DRY this is a test
# create multiple users
total_users = set()
# this is the amount of shares we'll create
user_count = 20
for i in range(user_count):
resp = client.post(
"/users",
headers=auth("administrator"),
json={"username": f"user{i}", "permissions": [], "password": "123"},
)
assert resp.status_code == 201
data = resp.get_json()
total_users.add(f"user{i}")
# add on the existing amount of users
user_count += len(users)
for user in users.keys():
total_users.add(user)
# we'll paginate through all users and ensure that:
# - we see all the users
# - no users are seen twice
def paginate(forwards=True, page=1):
"""Goes through the pages.
Parameters
----------
forwards : bool, optional
Set direction to paginate in.
page : int
Page to start at.
Returns
-------
int
Page we ended at.
"""
seen = set()
per_page = 9
while page is not None:
resp = client.get(
"/users",
headers=auth("administrator"),
json=dict(page=page, per_page=per_page),
)
assert resp.status_code == 200
data = resp.get_json().get("data")
assert len(data) == per_page or len(data) == user_count % per_page
for user in data:
username = user.get("username")
assert username in total_users
assert username not in seen
seen.add(username)
if forwards:
new_page = resp.get_json().get("next")
assert new_page == page + 1 or new_page is None
else:
new_page = resp.get_json().get("prev")
assert new_page == page - 1 or new_page is None
if new_page is not None:
page = new_page
else:
break
assert seen == total_users
return page
end_page = paginate(forwards=True)
paginate(forwards=False, page=end_page)
def test_invalid(client, auth): def test_invalid(client, auth):
"""Test invalid requests to pagination.""" """Test invalid requests to pagination."""

View File

@ -84,11 +84,13 @@ def test_put(client, users, auth, validate_info):
new_data = {k: v for k, v in users["jeff"].items()} new_data = {k: v for k, v in users["jeff"].items()}
new_data["permissions"] = Bitmask(Permissions.ADMIN) new_data["permissions"] = Bitmask(Permissions.ADMIN)
new_data["register_date"] = datetime(2022, 2, 2, 0, 0, 0)
json_data = user_schema.dump(new_data)
json_data.update(dict(password="123"))
resp = client.put( resp = client.put(
"/users/jeff", "/users/jeff",
json=user_schema.dump(new_data), json=json_data,
headers=auth("administrator"), headers=auth("administrator"),
) )
assert resp.status_code == 200 assert resp.status_code == 200