Compare commits

..

2 Commits

8 changed files with 183 additions and 33 deletions

View File

@ -9,12 +9,13 @@ from .config import DevelopmentConfig, ProductionConfig, TestingConfig, overlay_
app = Flask(__name__) app = Flask(__name__)
CORS(app) CORS(app)
if os.getenv("RUN_ENV") == "test": with app.app_context():
if os.getenv("RUN_ENV") == "test":
overlay_config(TestingConfig, "./config-testing.yml") overlay_config(TestingConfig, "./config-testing.yml")
elif app.config["DEBUG"]: elif app.config["DEBUG"]:
overlay_config(DevelopmentConfig) overlay_config(DevelopmentConfig)
app.logger.warning("Running in DEVELOPMENT MODE; do NOT use this in production!") app.logger.warning("Running in DEVELOPMENT MODE; do NOT use this in production!")
else: else:
overlay_config(ProductionConfig) overlay_config(ProductionConfig)
bcrypt = Bcrypt(app) bcrypt = Bcrypt(app)
@ -27,9 +28,12 @@ storage = None
from sachet.storage import FileSystem from sachet.storage import FileSystem
if _storage_method == "filesystem":
with app.app_context():
db.create_all()
if _storage_method == "filesystem":
storage = FileSystem() storage = FileSystem()
else: else:
raise ValueError(f"{_storage_method} is not a valid storage method.") raise ValueError(f"{_storage_method} is not a valid storage method.")
import sachet.server.commands import sachet.server.commands
@ -45,6 +49,3 @@ app.register_blueprint(admin_blueprint)
from sachet.server.files.views import files_blueprint from sachet.server.files.views import files_blueprint
app.register_blueprint(files_blueprint) app.register_blueprint(files_blueprint)
with app.app_context():
db.create_all()

View File

@ -1,4 +1,5 @@
from os import getenv, path from os import getenv, path
from flask import current_app
import yaml import yaml
sqlalchemy_base = "sqlite:///sachet" sqlalchemy_base = "sqlite:///sachet"
@ -51,9 +52,7 @@ def overlay_config(base, config_file=None):
if config["SECRET_KEY"] == "" or config["SECRET_KEY"] is None: if config["SECRET_KEY"] == "" or config["SECRET_KEY"] is None:
raise ValueError("Please set secret_key within the configuration.") raise ValueError("Please set secret_key within the configuration.")
from sachet.server import app current_app.config.from_object(base)
app.config.from_object(base)
for k, v in config.items(): for k, v in config.items():
app.config[k] = v current_app.config[k] = v

View File

@ -18,11 +18,15 @@ class FilesMetadataAPI(ModelAPI):
@auth_required(required_permissions=(Permissions.MODIFY,), allow_anonymous=True) @auth_required(required_permissions=(Permissions.MODIFY,), allow_anonymous=True)
def patch(self, share_id, auth_user=None): def patch(self, share_id, auth_user=None):
share = Share.query.filter_by(share_id=share_id).first() share = Share.query.filter_by(share_id=share_id).first()
if share.locked:
return jsonify({"status": "fail", "message": "This share is locked."}), 423
return super().patch(share) return super().patch(share)
@auth_required(required_permissions=(Permissions.MODIFY,), allow_anonymous=True) @auth_required(required_permissions=(Permissions.MODIFY,), allow_anonymous=True)
def put(self, share_id, auth_user=None): def put(self, share_id, auth_user=None):
share = Share.query.filter_by(share_id=share_id).first() share = Share.query.filter_by(share_id=share_id).first()
if share.locked:
return jsonify({"status": "fail", "message": "This share is locked."}), 423
return super().put(share) return super().put(share)
@auth_required(required_permissions=(Permissions.DELETE,), allow_anonymous=True) @auth_required(required_permissions=(Permissions.DELETE,), allow_anonymous=True)
@ -32,6 +36,8 @@ class FilesMetadataAPI(ModelAPI):
except ValueError: except ValueError:
return jsonify(dict(status="fail", message=f"Invalid ID: '{share_id}'.")) return jsonify(dict(status="fail", message=f"Invalid ID: '{share_id}'."))
share = Share.query.filter_by(share_id=share_id).first() share = Share.query.filter_by(share_id=share_id).first()
if share.locked:
return jsonify({"status": "fail", "message": "This share is locked."}), 423
return super().delete(share) return super().delete(share)
@ -117,6 +123,11 @@ class FileContentAPI(ModelAPI):
jsonify({"status": "fail", "message": "This share does not exist."}) jsonify({"status": "fail", "message": "This share does not exist."})
), 404 ), 404
if share.locked:
return (
jsonify({"status": "fail", "message": "This share is locked."})
), 423
if auth_user != share.owner: if auth_user != share.owner:
return ( return (
jsonify( jsonify(
@ -182,3 +193,47 @@ files_blueprint.add_url_rule(
view_func=FileContentAPI.as_view("files_content_api"), view_func=FileContentAPI.as_view("files_content_api"),
methods=["POST", "PUT", "GET"], methods=["POST", "PUT", "GET"],
) )
class FileLockAPI(ModelAPI):
@auth_required(required_permissions=(Permissions.LOCK,), allow_anonymous=True)
def post(self, share_id, auth_user=None):
share = Share.query.filter_by(share_id=uuid.UUID(share_id)).first()
if not share:
return (
jsonify({"status": "fail", "message": "This share does not exist."})
), 404
share.locked = True
db.session.commit()
return jsonify({"status": "success", "message": "Share has been locked."})
files_blueprint.add_url_rule(
"/files/<share_id>/lock",
view_func=FileLockAPI.as_view("files_lock_api"),
methods=["POST"],
)
class FileUnlockAPI(ModelAPI):
@auth_required(required_permissions=(Permissions.LOCK,), allow_anonymous=True)
def post(self, share_id, auth_user=None):
share = Share.query.filter_by(share_id=uuid.UUID(share_id)).first()
if not share:
return (
jsonify({"status": "fail", "message": "This share does not exist."})
), 404
share.locked = False
db.session.commit()
return jsonify({"status": "success", "message": "Share has been unlocked."})
files_blueprint.add_url_rule(
"/files/<share_id>/unlock",
view_func=FileUnlockAPI.as_view("files_unlock_api"),
methods=["POST"],
)

View File

@ -1,10 +1,10 @@
from sachet.server import app, db, ma, bcrypt, storage from sachet.server import db, ma, bcrypt, storage
import datetime import datetime
import jwt import jwt
from enum import IntFlag from enum import IntFlag
from bitmask import Bitmask from bitmask import Bitmask
from marshmallow import fields, ValidationError from marshmallow import fields, ValidationError
from flask import request, jsonify, url_for from flask import request, jsonify, url_for, current_app
from sqlalchemy_utils import UUIDType from sqlalchemy_utils import UUIDType
import uuid import uuid
@ -87,7 +87,7 @@ class User(db.Model):
self.permissions = permissions self.permissions = permissions
self.password = bcrypt.generate_password_hash( self.password = bcrypt.generate_password_hash(
password, app.config.get("BCRYPT_LOG_ROUNDS") password, current_app.config.get("BCRYPT_LOG_ROUNDS")
).decode() ).decode()
self.username = username self.username = username
self.register_date = datetime.datetime.now() self.register_date = datetime.datetime.now()
@ -100,7 +100,7 @@ class User(db.Model):
"sub": self.username, "sub": self.username,
"jti": jti, "jti": jti,
} }
return jwt.encode(payload, app.config.get("SECRET_KEY"), algorithm="HS256") return jwt.encode(payload, current_app.config.get("SECRET_KEY"), algorithm="HS256")
def read_token(token): def read_token(token):
"""Read a JWT and validate it. """Read a JWT and validate it.
@ -111,7 +111,7 @@ class User(db.Model):
data = jwt.decode( data = jwt.decode(
token, token,
app.config["SECRET_KEY"], current_app.config["SECRET_KEY"],
algorithms=["HS256"], algorithms=["HS256"],
) )
@ -153,7 +153,7 @@ class BlacklistToken(db.Model):
data = jwt.decode( data = jwt.decode(
token, token,
app.config["SECRET_KEY"], current_app.config["SECRET_KEY"],
algorithms=["HS256"], algorithms=["HS256"],
) )
self.expires = datetime.datetime.fromtimestamp(data["exp"]) self.expires = datetime.datetime.fromtimestamp(data["exp"])
@ -220,6 +220,8 @@ class Share(db.Model):
initialized : bool initialized : bool
Since only the metadata is uploaded first, this switches to True when Since only the metadata is uploaded first, this switches to True when
the real data is uploaded. the real data is uploaded.
locked : bool
Locks modification and deletion of this share.
create_date : DateTime create_date : DateTime
Time the share was created (not initialized.) Time the share was created (not initialized.)
file_name : str file_name : str
@ -242,12 +244,13 @@ class Share(db.Model):
owner = db.relationship("User", backref=db.backref("owner")) owner = db.relationship("User", backref=db.backref("owner"))
initialized = db.Column(db.Boolean, nullable=False, default=False) initialized = db.Column(db.Boolean, nullable=False, default=False)
locked = db.Column(db.Boolean, nullable=False, default=False)
create_date = db.Column(db.DateTime, nullable=False) create_date = db.Column(db.DateTime, nullable=False)
file_name = db.Column(db.String, nullable=False) file_name = db.Column(db.String, nullable=False)
def __init__(self, owner_name=None, file_name=None): def __init__(self, owner_name=None, file_name=None, locked=False):
self.owner = User.query.filter_by(username=owner_name).first() self.owner = User.query.filter_by(username=owner_name).first()
if self.owner: if self.owner:
self.owner_name = self.owner.username self.owner_name = self.owner.username
@ -259,6 +262,8 @@ class Share(db.Model):
else: else:
self.file_name = str(self.share_id) self.file_name = str(self.share_id)
self.locked = locked
def get_schema(self): def get_schema(self):
class Schema(ma.SQLAlchemySchema): class Schema(ma.SQLAlchemySchema):
class Meta: class Meta:
@ -268,6 +273,7 @@ class Share(db.Model):
owner_name = ma.auto_field() owner_name = ma.auto_field()
file_name = ma.auto_field() file_name = ma.auto_field()
initialized = ma.auto_field(dump_only=True) initialized = ma.auto_field(dump_only=True)
locked = ma.auto_field(dump_only=True)
return Schema() return Schema()

View File

@ -1,4 +1,5 @@
from sachet.storage import Storage from sachet.storage import Storage
from flask import current_app
from pathlib import Path from pathlib import Path
from werkzeug.utils import secure_filename from werkzeug.utils import secure_filename
import json import json
@ -6,21 +7,18 @@ import json
class FileSystem(Storage): class FileSystem(Storage):
def __init__(self): def __init__(self):
# prevent circular import when inspecting this file outside of Flask config_path = Path(current_app.config["SACHET_FILE_DIR"])
from sachet.server import app
config_path = Path(app.config["SACHET_FILE_DIR"])
if config_path.is_absolute(): if config_path.is_absolute():
self._directory = config_path self._directory = config_path
else: else:
self._directory = Path(app.instance_path) / config_path self._directory = Path(current_app.instance_path) / config_path
self._files_directory = self._directory / Path("files") self._files_directory = self._directory / Path("files")
self._files_directory.mkdir(mode=0o700, exist_ok=True, parents=True) self._files_directory.mkdir(mode=0o700, exist_ok=True, parents=True)
if not self._directory.is_dir(): if not self._directory.is_dir():
raise OSError(f"'{app.config['SACHET_FILE_DIR']}' is not a directory.") raise OSError(f"'{current_app.config['SACHET_FILE_DIR']}' is not a directory.")
def _get_path(self, name): def _get_path(self, name):
name = secure_filename(name) name = secure_filename(name)

View File

@ -70,6 +70,7 @@ def users(client):
Permissions.DELETE, Permissions.DELETE,
Permissions.MODIFY, Permissions.MODIFY,
Permissions.LIST, Permissions.LIST,
Permissions.LOCK,
), ),
), ),
dave=dict( dave=dict(
@ -110,6 +111,16 @@ def users(client):
Permissions.READ, Permissions.READ,
), ),
), ),
no_lock_user=dict(
password="password",
permissions=Bitmask(
Permissions.CREATE,
Permissions.MODIFY,
Permissions.DELETE,
Permissions.ADMIN,
Permissions.READ,
),
),
administrator=dict(password="4321", permissions=Bitmask(Permissions.ADMIN)), administrator=dict(password="4321", permissions=Bitmask(Permissions.ADMIN)),
) )

View File

@ -1,6 +1,6 @@
import pytest import pytest
from sachet.server.commands import create_db, drop_db, create_user, delete_user from sachet.server.commands import create_db, drop_db, create_user, delete_user
from sachet.server import app, db from sachet.server import db
from sqlalchemy import inspect from sqlalchemy import inspect
from sachet.server.models import User from sachet.server.models import User

View File

@ -193,7 +193,6 @@ class TestSuite:
) )
assert resp.status_code == 403 assert resp.status_code == 403
# test not allowing re-upload # test not allowing re-upload
resp = client.post( resp = client.post(
url + "/content", url + "/content",
@ -210,3 +209,84 @@ class TestSuite:
assert resp.status_code == 403 assert resp.status_code == 403
resp = client.get(url + "/content", headers=auth("no_read_user")) resp = client.get(url + "/content", headers=auth("no_read_user"))
assert resp.status_code == 403 assert resp.status_code == 403
def test_locking(self, client, users, auth, rand):
# upload share
resp = client.post(
"/files", headers=auth("jeff"), json={"file_name": "content.bin"}
)
data = resp.get_json()
url = data.get("url")
upload_data = rand.randbytes(4000)
resp = client.post(
url + "/content",
headers=auth("jeff"),
data={
"upload": FileStorage(stream=BytesIO(upload_data), filename="upload")
},
content_type="multipart/form-data",
)
assert resp.status_code == 201
# lock share
resp = client.post(
url + "/lock",
headers=auth("jeff"),
)
assert resp.status_code == 200
# attempt to modify share
resp = client.put(
url + "/content",
headers=auth("jeff"),
data={
"upload": FileStorage(stream=BytesIO(upload_data), filename="upload")
},
content_type="multipart/form-data",
)
assert resp.status_code == 423
# attempt to delete share
resp = client.delete(
url,
headers=auth("jeff"),
)
assert resp.status_code == 423
# unlock share
resp = client.post(
url + "/unlock",
headers=auth("jeff"),
)
assert resp.status_code == 200
# attempt to modify share
resp = client.put(
url + "/content",
headers=auth("jeff"),
data={
"upload": FileStorage(stream=BytesIO(upload_data), filename="upload")
},
content_type="multipart/form-data",
)
assert resp.status_code == 200
# attempt to delete share
resp = client.delete(
url,
headers=auth("jeff"),
)
assert resp.status_code == 200
# attempt to lock/unlock without perms
resp = client.post(
url + "/lock",
headers=auth("no_lock_user"),
)
assert resp.status_code == 403
resp = client.post(
url + "/unlock",
headers=auth("no_lock_user"),
)
assert resp.status_code == 403