diff --git a/sachet/server/models.py b/sachet/server/models.py index c251b77..e42e1ad 100644 --- a/sachet/server/models.py +++ b/sachet/server/models.py @@ -33,50 +33,97 @@ class User(db.Model): algorithm="HS256" ) -def _token_decorator(require_admin, f, *args, **kwargs): - """Generic function for checking tokens. - require_admin: require user to be administrator to authenticate +class BlacklistToken(db.Model): + """Token that has been revoked (but has not expired yet.) + + This is needed to perform functionality like logging out. """ - token = None - auth_header = request.headers.get("Authorization") - if auth_header: - try: - token = auth_header.split(" ")[1] - except IndexError: - resp = { - "status": "fail", - "message": "Malformed Authorization header." - } - return jsonify(resp), 401 - if not token: - return jsonify({"status": "fail", "message": "Missing auth token"}), 401 - try: - data = jwt.decode(token, app.config["SECRET_KEY"], algorithms=["HS256"]) - user = User.query.filter_by(username=data.get("sub")).first() - except: - return jsonify({"status": "fail", "message": "Invalid auth token."}), 401 + __tablename__ = "blacklist_tokens" + + id = db.Column(db.Integer, primary_key=True, autoincrement=True) + token = db.Column(db.String(500), unique=True, nullable=False) + expires = db.Column(db.DateTime, nullable=False) + def __init__(self, token): + self.token = token + + data = jwt.decode( + token, + app.config["SECRET_KEY"], + algorithms=["HS256"], + ) + self.expires = datetime.datetime.fromtimestamp(data["exp"]) + + @staticmethod + def check_blacklist(token): + """Returns if a token is blacklisted.""" + entry = BlacklistToken.query.filter_by(token=token).first() + + if not entry: + return False + else: + if entry.expires < datetime.datetime.utcnow(): + db.session.delete(entry) + return True + + +def read_token(token): + """Read a JWT and validate it. + + Returns a tuple: dictionary of the JWT's data, and the corresponding user + if available. + """ + + data = jwt.decode( + token, + app.config["SECRET_KEY"], + algorithms=["HS256"], + ) + + if BlacklistToken.check_blacklist(token): + raise jwt.ExpiredSignatureError("Token revoked.") + + user = User.query.filter_by(username=data.get("sub")).first() if not user: - return jsonify({"status": "fail", "message": "Invalid auth token."}), 401 + raise jwt.InvalidTokenError("No user corresponds to this token.") - if require_admin and not user.admin: - return jsonify({"status": "fail", "message": "You are not authorized to view this page."}), 403 + return data, user - return f(user, *args, **kwargs) -def token_required(f): - """Decorator to require authentication.""" +def auth_required(f): + """Decorator to require authentication. + + Passes an argument 'user' to the function, with a User object corresponding + to the authenticated session. + """ @wraps(f) def decorator(*args, **kwargs): - return _token_decorator(False, f, *args, **kwargs) - return decorator + token = None + auth_header = request.headers.get("Authorization") + if auth_header: + try: + token = auth_header.split(" ")[1] + except IndexError: + resp = { + "status": "fail", + "message": "Malformed Authorization header." + } + return jsonify(resp), 401 -def admin_required(f): - """Decorator to require authentication and admin privileges.""" + if not token: + return jsonify({"status": "fail", "message": "Missing auth token"}), 401 + + try: + data, user = read_token(token) + except jwt.ExpiredSignatureError: + # if it's expired we don't want it lingering in the db + BlacklistToken.check_blacklist(token) + return jsonify({"status": "fail", "message": "Token has expired."}), 401 + except jwt.InvalidTokenError: + return jsonify({"status": "fail", "message": "Invalid auth token."}), 401 + + return f(user, *args, **kwargs) - @wraps(f) - def decorator(*args, **kwargs): - return _token_decorator(True, f, *args, **kwargs) return decorator diff --git a/sachet/server/users/views.py b/sachet/server/users/views.py index 2ca886b..db6936d 100644 --- a/sachet/server/users/views.py +++ b/sachet/server/users/views.py @@ -1,7 +1,8 @@ +import jwt from flask import Blueprint, request, jsonify from flask.views import MethodView -from sachet.server.models import token_required, admin_required, User -from sachet.server import bcrypt +from sachet.server.models import auth_required, read_token, User, BlacklistToken +from sachet.server import bcrypt, db users_blueprint = Blueprint("users_blueprint", __name__) @@ -42,9 +43,46 @@ users_blueprint.add_url_rule( ) +class LogoutAPI(MethodView): + """Endpoint to revoke a user's token.""" + + @auth_required + def post(user, self): + post_data = request.get_json() + token = post_data.get("token") + if not token: + return jsonify({"status": "fail", "message": "Specify a token to revoke."}), 400 + + res = BlacklistToken.check_blacklist(token) + if res: + return jsonify({"status": "fail", "message": "Token already revoked."}), 400 + + try: + data, token_user = read_token(token) + except jwt.ExpiredSignatureError: + return jsonify({"status": "fail", "message": "Token already expired."}), 400 + except jwt.InvalidTokenError: + return jsonify({"status": "fail", "message": "Invalid auth token."}), 400 + + if user == token_user or user.admin == True: + entry = BlacklistToken(token=token) + db.session.add(entry) + db.session.commit() + return jsonify({"status": "success", "message": "Token revoked."}), 200 + else: + return jsonify({"status": "fail", "message": "You are not allowed to revoke this token."}), 403 + + +users_blueprint.add_url_rule( + "/users/logout", + view_func=LogoutAPI.as_view("logout_api"), + methods=['POST'] +) + + class UserAPI(MethodView): """User information API""" - @token_required + @auth_required def get(user, self, username): info_user = User.query.filter_by(username=username).first() if (not info_user) or (info_user != user and not user.admin): diff --git a/tests/test_auth.py b/tests/test_auth.py index 8cf1564..6bc7f89 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -84,3 +84,104 @@ def test_login(client, users): assert resp_json.get("username") == "jeff" token = resp_json.get("auth_token") assert token is not None and token != "" + + +def test_logout(client, tokens, validate_info): + """Test logging out.""" + + # unauthenticated + resp = client.post("/users/logout", json={ + "token": tokens["jeff"] + }, + ) + assert resp.status_code == 401 + + # missing token + resp = client.post("/users/logout", json={ + + }, + headers={ + "Authorization": f"bearer {tokens['jeff']}" + } + ) + assert resp.status_code == 400 + + # invalid token + resp = client.post("/users/logout", json={ + "token": "not.real.jwt" + }, + headers={ + "Authorization": f"bearer {tokens['jeff']}" + } + ) + assert resp.status_code == 400 + + # wrong user's token + resp = client.post("/users/logout", json={ + "token": tokens["administrator"] + }, + headers={ + "Authorization": f"bearer {tokens['jeff']}" + } + ) + assert resp.status_code == 403 + + # check that we can access this endpoint before logging out + resp = client.get("/users/jeff", + headers={ + "Authorization": f"bearer {tokens['jeff']}" + } + ) + assert resp.status_code == 200 + validate_info("jeff", resp.get_json()) + + # valid logout + resp = client.post("/users/logout", json={ + "token": tokens["jeff"] + }, + headers={ + "Authorization": f"bearer {tokens['jeff']}" + } + ) + assert resp.status_code == 200 + + # check that the logout worked + + resp = client.get("/users/jeff", + headers={ + "Authorization": f"bearer {tokens['jeff']}" + } + ) + assert resp.status_code == 401 + +def test_admin_revoke(client, tokens, validate_info): + """Test that an admin can revoke any token from other users.""" + + resp = client.post("/users/logout", json={ + "token": tokens["jeff"] + }, + headers={ + "Authorization": f"bearer {tokens['administrator']}" + } + ) + assert resp.status_code == 200 + + # check that the logout worked + + resp = client.get("/users/jeff", + headers={ + "Authorization": f"bearer {tokens['jeff']}" + } + ) + assert resp.status_code == 401 + + # try revoking twice + + resp = client.post("/users/logout", json={ + "token": tokens["jeff"] + }, + headers={ + "Authorization": f"bearer {tokens['administrator']}" + } + ) + assert resp.status_code == 400