/users: added endpoint for creating, listing users
This commit is contained in:
parent
71683662ea
commit
cbf1310e1b
@ -91,6 +91,7 @@ class User(db.Model):
|
||||
password, current_app.config.get("BCRYPT_LOG_ROUNDS")
|
||||
).decode()
|
||||
self.username = username
|
||||
self.url = url_for("users_blueprint.user_list_api", username=self.username)
|
||||
self.register_date = datetime.datetime.now()
|
||||
|
||||
def encode_token(self, jti=None):
|
||||
@ -133,7 +134,8 @@ class User(db.Model):
|
||||
model = self
|
||||
|
||||
username = ma.auto_field()
|
||||
register_date = ma.auto_field()
|
||||
register_date = ma.auto_field(dump_only=True)
|
||||
password = ma.auto_field(load_only=True, required=False)
|
||||
permissions = PermissionField()
|
||||
|
||||
return Schema()
|
||||
|
@ -6,7 +6,7 @@ from sachet.server.models import (
|
||||
User,
|
||||
BlacklistToken,
|
||||
)
|
||||
from sachet.server.views_common import ModelAPI, auth_required
|
||||
from sachet.server.views_common import ModelAPI, ModelListAPI, auth_required
|
||||
from sachet.server import bcrypt, db
|
||||
|
||||
users_blueprint = Blueprint("users_blueprint", __name__)
|
||||
@ -142,3 +142,21 @@ users_blueprint.add_url_rule(
|
||||
view_func=UserAPI.as_view("user_api"),
|
||||
methods=["GET", "PATCH", "PUT"],
|
||||
)
|
||||
|
||||
|
||||
class UserListAPI(ModelListAPI):
|
||||
@auth_required(required_permissions=(Permissions.ADMIN,))
|
||||
def post(self, auth_user=None):
|
||||
data = request.get_json()
|
||||
return super().post(User, data)
|
||||
|
||||
@auth_required(required_permissions=(Permissions.ADMIN,))
|
||||
def get(self, auth_user=None):
|
||||
return super().get(User)
|
||||
|
||||
|
||||
users_blueprint.add_url_rule(
|
||||
"/users",
|
||||
view_func=UserListAPI.as_view("user_list_api"),
|
||||
methods=["GET", "POST"],
|
||||
)
|
||||
|
@ -145,10 +145,12 @@ def validate_info(users):
|
||||
]
|
||||
|
||||
def _validate(user, info):
|
||||
info = User.get_schema(User).load(info)
|
||||
schema = User.get_schema(User)
|
||||
|
||||
dumped = schema.dump(users[user])
|
||||
|
||||
for k in verify_fields:
|
||||
assert users[user][k] == info[k]
|
||||
assert dumped[k] == info[k]
|
||||
|
||||
return _validate
|
||||
|
||||
|
@ -77,6 +77,89 @@ def test_files(client, users, auth):
|
||||
paginate(forwards=False, page=end_page)
|
||||
|
||||
|
||||
def test_users(client, users, auth):
|
||||
"""Test /users endpoint."""
|
||||
|
||||
# suck it for me violating DRY this is a test
|
||||
|
||||
# create multiple users
|
||||
total_users = set()
|
||||
# this is the amount of shares we'll create
|
||||
user_count = 20
|
||||
|
||||
for i in range(user_count):
|
||||
resp = client.post(
|
||||
"/users",
|
||||
headers=auth("administrator"),
|
||||
json={"username": f"user{i}", "permissions": [], "password": "123"},
|
||||
)
|
||||
assert resp.status_code == 201
|
||||
data = resp.get_json()
|
||||
total_users.add(f"user{i}")
|
||||
|
||||
# add on the existing amount of users
|
||||
user_count += len(users)
|
||||
for user in users.keys():
|
||||
total_users.add(user)
|
||||
|
||||
# we'll paginate through all users and ensure that:
|
||||
# - we see all the users
|
||||
# - no users are seen twice
|
||||
|
||||
def paginate(forwards=True, page=1):
|
||||
"""Goes through the pages.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
forwards : bool, optional
|
||||
Set direction to paginate in.
|
||||
page : int
|
||||
Page to start at.
|
||||
|
||||
Returns
|
||||
-------
|
||||
int
|
||||
Page we ended at.
|
||||
"""
|
||||
seen = set()
|
||||
|
||||
per_page = 9
|
||||
while page is not None:
|
||||
resp = client.get(
|
||||
"/users",
|
||||
headers=auth("administrator"),
|
||||
json=dict(page=page, per_page=per_page),
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
|
||||
data = resp.get_json().get("data")
|
||||
assert len(data) == per_page or len(data) == user_count % per_page
|
||||
|
||||
for user in data:
|
||||
username = user.get("username")
|
||||
assert username in total_users
|
||||
assert username not in seen
|
||||
seen.add(username)
|
||||
|
||||
if forwards:
|
||||
new_page = resp.get_json().get("next")
|
||||
assert new_page == page + 1 or new_page is None
|
||||
else:
|
||||
new_page = resp.get_json().get("prev")
|
||||
assert new_page == page - 1 or new_page is None
|
||||
if new_page is not None:
|
||||
page = new_page
|
||||
else:
|
||||
break
|
||||
|
||||
assert seen == total_users
|
||||
|
||||
return page
|
||||
|
||||
end_page = paginate(forwards=True)
|
||||
paginate(forwards=False, page=end_page)
|
||||
|
||||
|
||||
def test_invalid(client, auth):
|
||||
"""Test invalid requests to pagination."""
|
||||
|
||||
|
@ -84,11 +84,13 @@ def test_put(client, users, auth, validate_info):
|
||||
|
||||
new_data = {k: v for k, v in users["jeff"].items()}
|
||||
new_data["permissions"] = Bitmask(Permissions.ADMIN)
|
||||
new_data["register_date"] = datetime(2022, 2, 2, 0, 0, 0)
|
||||
|
||||
json_data = user_schema.dump(new_data)
|
||||
json_data.update(dict(password="123"))
|
||||
|
||||
resp = client.put(
|
||||
"/users/jeff",
|
||||
json=user_schema.dump(new_data),
|
||||
json=json_data,
|
||||
headers=auth("administrator"),
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
|
Loading…
Reference in New Issue
Block a user