diff --git a/sachet/server/files/views.py b/sachet/server/files/views.py index 05d4ca9..76bce5b 100644 --- a/sachet/server/files/views.py +++ b/sachet/server/files/views.py @@ -43,8 +43,7 @@ files_blueprint.add_url_rule( class FileCreationAPI(ModelAPI): @auth_required(required_permissions=(Permissions.CREATE,)) def post(self, auth_user=None): - # silent means it will return None if there is no JSON - data = request.get_json(silent=True) or {} + data = request.get_json() data["owner_name"] = auth_user.username return super().post(Share, data) @@ -157,7 +156,7 @@ class FileContentAPI(ModelAPI): with file.open(mode="rb") as f: data = f.read() - return send_file(io.BytesIO(data), download_name=str(share.share_id)) + return send_file(io.BytesIO(data), download_name=share.file_name) files_blueprint.add_url_rule( diff --git a/sachet/server/models.py b/sachet/server/models.py index c4acc9f..7083814 100644 --- a/sachet/server/models.py +++ b/sachet/server/models.py @@ -211,6 +211,8 @@ class Share(db.Model): the real data is uploaded. create_date : DateTime Time the share was created (not initialized.) + file_name : str + File name to download as. Methods ------- @@ -230,12 +232,18 @@ class Share(db.Model): create_date = db.Column(db.DateTime, nullable=False) - def __init__(self, owner_name): + file_name = db.Column(db.String, nullable=False) + + def __init__(self, owner_name, file_name=None): self.owner = User.query.filter_by(username=owner_name).first() self.owner_name = self.owner.username self.share_id = uuid.uuid4() self.url = url_for("files_blueprint.files_api", share_id=self.share_id) self.create_date = datetime.datetime.now() + if file_name: + self.file_name = file_name + else: + self.file_name = str(self.share_id) def get_schema(self): class Schema(ma.SQLAlchemySchema): @@ -244,6 +252,7 @@ class Share(db.Model): share_id = ma.auto_field(dump_only=True) owner_name = ma.auto_field() + file_name = ma.auto_field() initialized = ma.auto_field(dump_only=True) return Schema() diff --git a/tests/test_files.py b/tests/test_files.py index 25ccedb..c257109 100644 --- a/tests/test_files.py +++ b/tests/test_files.py @@ -12,7 +12,9 @@ import uuid class TestSuite: def test_sharing(self, client, users, auth, rand): # create share - resp = client.post("/files", headers=auth("jeff")) + resp = client.post( + "/files", headers=auth("jeff"), json={"file_name": "content.bin"} + ) assert resp.status_code == 201 data = resp.get_json() @@ -40,6 +42,7 @@ class TestSuite: headers=auth("jeff"), ) assert resp.data == upload_data + assert resp.headers["Content-Disposition"] == "inline; filename=content.bin" # test deletion resp = client.delete( @@ -85,7 +88,9 @@ class TestSuite: assert resp.status_code == 403 # valid share creation to move on to testing content endpoint - resp = client.post("/files", headers=auth("jeff")) + resp = client.post( + "/files", headers=auth("jeff"), json={"file_name": "content.bin"} + ) assert resp.status_code == 201 data = resp.get_json() url = data.get("url")