diff --git a/sachet/server/models.py b/sachet/server/models.py index 1e4738a..1ea0f3e 100644 --- a/sachet/server/models.py +++ b/sachet/server/models.py @@ -91,6 +91,7 @@ class User(db.Model): password, current_app.config.get("BCRYPT_LOG_ROUNDS") ).decode() self.username = username + self.url = url_for("users_blueprint.user_list_api", username=self.username) self.register_date = datetime.datetime.now() def encode_token(self, jti=None): @@ -133,7 +134,8 @@ class User(db.Model): model = self username = ma.auto_field() - register_date = ma.auto_field() + register_date = ma.auto_field(dump_only=True) + password = ma.auto_field(load_only=True, required=False) permissions = PermissionField() return Schema() diff --git a/sachet/server/users/views.py b/sachet/server/users/views.py index 019e449..2a802e3 100644 --- a/sachet/server/users/views.py +++ b/sachet/server/users/views.py @@ -6,7 +6,7 @@ from sachet.server.models import ( User, BlacklistToken, ) -from sachet.server.views_common import ModelAPI, auth_required +from sachet.server.views_common import ModelAPI, ModelListAPI, auth_required from sachet.server import bcrypt, db users_blueprint = Blueprint("users_blueprint", __name__) @@ -142,3 +142,21 @@ users_blueprint.add_url_rule( view_func=UserAPI.as_view("user_api"), methods=["GET", "PATCH", "PUT"], ) + + +class UserListAPI(ModelListAPI): + @auth_required(required_permissions=(Permissions.ADMIN,)) + def post(self, auth_user=None): + data = request.get_json() + return super().post(User, data) + + @auth_required(required_permissions=(Permissions.ADMIN,)) + def get(self, auth_user=None): + return super().get(User) + + +users_blueprint.add_url_rule( + "/users", + view_func=UserListAPI.as_view("user_list_api"), + methods=["GET", "POST"], +) diff --git a/tests/conftest.py b/tests/conftest.py index bddd6a0..a0e33c2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -145,10 +145,12 @@ def validate_info(users): ] def _validate(user, info): - info = User.get_schema(User).load(info) + schema = User.get_schema(User) + + dumped = schema.dump(users[user]) for k in verify_fields: - assert users[user][k] == info[k] + assert dumped[k] == info[k] return _validate diff --git a/tests/test_pagination.py b/tests/test_pagination.py index 49917fd..68833bc 100644 --- a/tests/test_pagination.py +++ b/tests/test_pagination.py @@ -77,6 +77,89 @@ def test_files(client, users, auth): paginate(forwards=False, page=end_page) +def test_users(client, users, auth): + """Test /users endpoint.""" + + # suck it for me violating DRY this is a test + + # create multiple users + total_users = set() + # this is the amount of shares we'll create + user_count = 20 + + for i in range(user_count): + resp = client.post( + "/users", + headers=auth("administrator"), + json={"username": f"user{i}", "permissions": [], "password": "123"}, + ) + assert resp.status_code == 201 + data = resp.get_json() + total_users.add(f"user{i}") + + # add on the existing amount of users + user_count += len(users) + for user in users.keys(): + total_users.add(user) + + # we'll paginate through all users and ensure that: + # - we see all the users + # - no users are seen twice + + def paginate(forwards=True, page=1): + """Goes through the pages. + + Parameters + ---------- + forwards : bool, optional + Set direction to paginate in. + page : int + Page to start at. + + Returns + ------- + int + Page we ended at. + """ + seen = set() + + per_page = 9 + while page is not None: + resp = client.get( + "/users", + headers=auth("administrator"), + json=dict(page=page, per_page=per_page), + ) + assert resp.status_code == 200 + + data = resp.get_json().get("data") + assert len(data) == per_page or len(data) == user_count % per_page + + for user in data: + username = user.get("username") + assert username in total_users + assert username not in seen + seen.add(username) + + if forwards: + new_page = resp.get_json().get("next") + assert new_page == page + 1 or new_page is None + else: + new_page = resp.get_json().get("prev") + assert new_page == page - 1 or new_page is None + if new_page is not None: + page = new_page + else: + break + + assert seen == total_users + + return page + + end_page = paginate(forwards=True) + paginate(forwards=False, page=end_page) + + def test_invalid(client, auth): """Test invalid requests to pagination.""" diff --git a/tests/test_userinfo.py b/tests/test_userinfo.py index 4c3a11c..d903b99 100644 --- a/tests/test_userinfo.py +++ b/tests/test_userinfo.py @@ -84,11 +84,13 @@ def test_put(client, users, auth, validate_info): new_data = {k: v for k, v in users["jeff"].items()} new_data["permissions"] = Bitmask(Permissions.ADMIN) - new_data["register_date"] = datetime(2022, 2, 2, 0, 0, 0) + + json_data = user_schema.dump(new_data) + json_data.update(dict(password="123")) resp = client.put( "/users/jeff", - json=user_schema.dump(new_data), + json=json_data, headers=auth("administrator"), ) assert resp.status_code == 200