implemented chunked upload

This commit is contained in:
dogeystamp 2023-05-07 21:08:45 -04:00
parent ec41382f9d
commit 7aa60d8aed
Signed by: dogeystamp
GPG Key ID: 7225FE3592EFFA38
6 changed files with 301 additions and 115 deletions

View File

@ -2,7 +2,7 @@ import uuid
import io import io
from flask import Blueprint, request, jsonify, send_file from flask import Blueprint, request, jsonify, send_file
from flask.views import MethodView from flask.views import MethodView
from sachet.server.models import Share, Permissions from sachet.server.models import Share, Permissions, Upload, Chunk
from sachet.server.views_common import ModelAPI, ModelListAPI, auth_required from sachet.server.views_common import ModelAPI, ModelListAPI, auth_required
from sachet.server import storage, db from sachet.server import storage, db
@ -68,7 +68,60 @@ files_blueprint.add_url_rule(
) )
class FileContentAPI(ModelAPI): class FileContentAPI(MethodView):
def recv_upload(self, share):
"""Receive chunked uploads.
share : Share
Share we are uploading to.
"""
chunk_file = request.files.get("upload")
if not chunk_file:
return (
jsonify(dict(status="fail", message="Missing chunk data in request.")),
400,
)
chunk_data = chunk_file.read()
try:
dz_uuid = request.form["dzuuid"]
dz_chunk_index = int(request.form["dzchunkindex"])
dz_total_chunks = int(request.form["dztotalchunks"])
except KeyError as err:
return (
jsonify(
dict(
status="fail", message=f"Missing data for chunking; {err}"
)
),
400,
)
except ValueError as err:
return (
jsonify(
dict(
status="fail", message=f"{err}"
)
),
400,
)
chunk = Chunk(dz_chunk_index, dz_uuid, dz_total_chunks, share, chunk_data)
db.session.add(chunk)
db.session.commit()
upload = chunk.upload
upload.recv_chunks = upload.recv_chunks + 1
if upload.recv_chunks >= upload.total_chunks:
upload.complete()
if upload.completed:
share.initialized = True
db.session.commit()
return jsonify(dict(status="success", message="Upload completed.")), 201
else:
return jsonify(dict(status="success", message="Chunk uploaded.")), 200
@auth_required(required_permissions=(Permissions.CREATE,), allow_anonymous=True) @auth_required(required_permissions=(Permissions.CREATE,), allow_anonymous=True)
def post(self, share_id, auth_user=None): def post(self, share_id, auth_user=None):
share = Share.query.filter_by(share_id=uuid.UUID(share_id)).first() share = Share.query.filter_by(share_id=uuid.UUID(share_id)).first()
@ -100,20 +153,7 @@ class FileContentAPI(ModelAPI):
423, 423,
) )
upload = request.files["upload"] return self.recv_upload(share)
data = upload.read()
file = share.get_handle()
with file.open(mode="wb") as f:
f.write(data)
share.initialized = True
db.session.commit()
return (
jsonify({"status": "success", "message": "Share has been initialized."}),
201,
)
@auth_required(required_permissions=(Permissions.MODIFY,), allow_anonymous=True) @auth_required(required_permissions=(Permissions.MODIFY,), allow_anonymous=True)
def put(self, share_id, auth_user=None): def put(self, share_id, auth_user=None):
@ -150,17 +190,7 @@ class FileContentAPI(ModelAPI):
423, 423,
) )
upload = request.files["upload"] return self.recv_upload(share)
data = upload.read()
file = share.get_handle()
with file.open(mode="wb") as f:
f.write(data)
return (
jsonify({"status": "success", "message": "Share has been modified."}),
200,
)
@auth_required(required_permissions=(Permissions.READ,), allow_anonymous=True) @auth_required(required_permissions=(Permissions.READ,), allow_anonymous=True)
def get(self, share_id, auth_user=None): def get(self, share_id, auth_user=None):

View File

@ -281,3 +281,129 @@ class Share(db.Model):
def get_handle(self): def get_handle(self):
return storage.get_file(str(self.share_id)) return storage.get_file(str(self.share_id))
class Upload(db.Model):
"""Upload instance for a given file.
Parameters
----------
upload_id : str
ID associated to this upload.
total_chunks: int
Total amount of chunks in this upload.
share_id : uuid.UUID
Assigns this upload to the given share id.
Attributes
----------
upload_id : str
ID associated to this upload.
total_chunks : int
Total amount of chunks in this upload.
recv_chunks : int
Amount of chunks received in this upload.
completed : bool
Whether the file has been fully uploaded.
share : Share
The share this upload is for.
chunks : list of Chunk
Chunks composing this upload.
create_date : DateTime
Time this upload was started.
"""
__tablename__ = "uploads"
upload_id = db.Column(db.String, primary_key=True)
share_id = db.Column(UUIDType(), db.ForeignKey("shares.share_id"))
share = db.relationship("Share", backref=db.backref("upload"))
create_date = db.Column(db.DateTime, nullable=False)
total_chunks = db.Column(db.Integer, nullable=False)
recv_chunks = db.Column(db.Integer, nullable=False, default=0)
completed = db.Column(db.Boolean, nullable=False, default=False)
chunks = db.relationship(
"Chunk", backref=db.backref("upload"), order_by="Chunk.chunk_id"
)
def __init__(self, upload_id, total_chunks, share_id):
self.share = Share.query.filter_by(share_id=share_id).first()
if self.share is None:
raise KeyError(f"Share '{self.share_id}' could not be found.")
self.upload_id = upload_id
self.total_chunks = total_chunks
self.create_date = datetime.datetime.now()
def complete(self):
"""Merge chunks, save the file, then clean up."""
tmp_file = storage.get_file(f"{self.share.share_id}_{self.upload_id}")
with tmp_file.open(mode="ab") as tmp_f:
for chunk in self.chunks:
chunk_file = storage.get_file(chunk.filename)
with chunk_file.open(mode="rb") as chunk_f:
data = chunk_f.read()
tmp_f.write(data)
# replace the old file
old_file = self.share.get_handle()
old_file.delete()
tmp_file.rename(str(self.share.share_id))
self.completed = True
class Chunk(db.Model):
"""Single chunk within an upload.
Parameters
----------
index : int
Index of this chunk within an upload.
upload_id : str
ID of the upload this chunk is associated to.
total_chunks : int
Total amount of chunks within this upload.
share : Share
Assigns this chunk to the given share.
data : bytes
Raw chunk data.
Attributes
----------
chunk_id : int
ID unique for all chunks (not just in a single upload.)
create_date : DateTime
Time this chunk was received.
index : int
Index of this chunk within an upload.
upload : Upload
Upload this chunk is associated to.
filename : str
Filename the data is stored in.
"""
__tablename__ = "chunks"
chunk_id = db.Column(db.Integer, primary_key=True, autoincrement=True)
create_date = db.Column(db.DateTime, nullable=False)
index = db.Column(db.Integer, nullable=False)
upload_id = db.Column(db.String, db.ForeignKey("uploads.upload_id"))
filename = db.Column(db.String, nullable=False)
def __init__(self, index, upload_id, total_chunks, share, data):
self.upload = Upload.query.filter_by(upload_id=upload_id).first()
if self.upload is None:
self.upload = Upload(upload_id, total_chunks, share.share_id)
self.upload.recv_chunks = 0
db.session.add(self.upload)
self.create_date = datetime.datetime.now()
self.index = index
self.filename = f"{share.share_id}_{self.upload_id}_{self.index}"
file = storage.get_file(self.filename)
with file.open(mode="wb") as f:
f.write(data)

View File

@ -43,7 +43,7 @@ class FileSystem(Storage):
self._path = self._storage._get_path(name) self._path = self._storage._get_path(name)
self._path.touch() self._path.touch()
def delete(self, name): def delete(self):
self._path.unlink() self._path.unlink()
def open(self, mode="r"): def open(self, mode="r"):

View File

@ -1,8 +1,12 @@
import pytest import pytest
import uuid
from math import ceil
from sachet.server.users import manage from sachet.server.users import manage
from click.testing import CliRunner from click.testing import CliRunner
from sachet.server import app, db, storage from sachet.server import app, db, storage
from sachet.server.models import Permissions, User from sachet.server.models import Permissions, User
from werkzeug.datastructures import FileStorage
from io import BytesIO
from bitmask import Bitmask from bitmask import Bitmask
from pathlib import Path from pathlib import Path
import random import random
@ -200,3 +204,59 @@ def auth(tokens):
return ret return ret
return auth_headers return auth_headers
@pytest.fixture
def upload(client):
"""Perform chunked upload of some data.
Parameters
----------
url : str
URL to upload to.
data : BytesIO
Stream of data to upload.
You can use BytesIO(data) to convert raw bytes to a stream.
headers : dict, optional
Headers to upload with.
chunk_size : int, optional
Size of chunks in bytes.
method : function
Method like client.post or client.put to use.
"""
def upload(url, data, headers={}, chunk_size=int(2e6), method=client.post):
data_size = len(data.getbuffer())
buf = data.getbuffer()
upload_uuid = uuid.uuid4()
total_chunks = int(ceil(data_size / chunk_size))
resp = None
for chunk_idx in range(total_chunks):
start = chunk_size * chunk_idx
end = min(chunk_size * (chunk_idx + 1), data_size)
resp = method(
url,
headers=headers,
data={
"upload": FileStorage(
stream=BytesIO(buf[start:end]), filename="upload"
),
"dzuuid": str(upload_uuid),
"dzchunkindex": chunk_idx,
"dztotalchunks": total_chunks,
},
content_type="multipart/form-data",
)
if not resp.status_code == 200 or resp.status_code == 201:
break
return resp
return upload

View File

@ -6,7 +6,7 @@ import uuid
"""Test anonymous authentication to endpoints.""" """Test anonymous authentication to endpoints."""
def test_files(client, auth, rand): def test_files(client, auth, rand, upload):
# set create perm for anon users # set create perm for anon users
resp = client.patch( resp = client.patch(
"/admin/settings", "/admin/settings",
@ -28,10 +28,9 @@ def test_files(client, auth, rand):
upload_data = rand.randbytes(4000) upload_data = rand.randbytes(4000)
# upload file to share # upload file to share
resp = client.post( resp = upload(
url + "/content", url + "/content",
data={"upload": FileStorage(stream=BytesIO(upload_data), filename="upload")}, BytesIO(upload_data),
content_type="multipart/form-data",
) )
assert resp.status_code == 201 assert resp.status_code == 201
@ -60,12 +59,12 @@ def test_files(client, auth, rand):
# modify share # modify share
upload_data = rand.randbytes(4000) upload_data = rand.randbytes(4000)
resp = client.put( resp = upload(
url + "/content", url + "/content",
data={"upload": FileStorage(stream=BytesIO(upload_data), filename="upload")}, BytesIO(upload_data),
content_type="multipart/form-data", method=client.put,
) )
assert resp.status_code == 200 assert resp.status_code == 201
resp = client.patch( resp = client.patch(
url, url,
json={"file_name": "new_bin.bin"}, json={"file_name": "new_bin.bin"},
@ -130,7 +129,7 @@ def test_files(client, auth, rand):
assert resp.status_code == 404 assert resp.status_code == 404
def test_files_invalid(client, auth, rand): def test_files_invalid(client, auth, rand, upload):
# set create perm for anon users # set create perm for anon users
resp = client.patch( resp = client.patch(
"/admin/settings", "/admin/settings",
@ -151,10 +150,9 @@ def test_files_invalid(client, auth, rand):
data = resp.get_json() data = resp.get_json()
url = data.get("url") url = data.get("url")
upload_data = rand.randbytes(4000) upload_data = rand.randbytes(4000)
resp = client.post( resp = upload(
url + "/content", url + "/content",
data={"upload": FileStorage(stream=BytesIO(upload_data), filename="upload")}, BytesIO(upload_data)
content_type="multipart/form-data",
) )
assert resp.status_code == 201 assert resp.status_code == 201
@ -167,27 +165,25 @@ def test_files_invalid(client, auth, rand):
assert resp.status_code == 200 assert resp.status_code == 200
# test initializing a share without perms # test initializing a share without perms
resp = client.post( resp = upload(
uninit_url + "/content", url + "/content",
data={"upload": FileStorage(stream=BytesIO(upload_data), filename="upload")}, BytesIO(upload_data)
content_type="multipart/form-data",
) )
assert resp.status_code == 401 assert resp.status_code == 401
# test reading a share without perms # test reading a share without perms
resp = client.get(url + "/content") resp = client.get(url + "/content")
# test modifying an uninitialized share without perms # test modifying an uninitialized share without perms
resp = client.put( resp = upload(
uninit_url + "/content", uninit_url + "/content",
data={"upload": FileStorage(stream=BytesIO(upload_data), filename="upload")}, BytesIO(upload_data),
content_type="multipart/form-data", method=client.put
) )
assert resp.status_code == 401 assert resp.status_code == 401
assert resp.status_code == 401
# test modifying a share without perms # test modifying a share without perms
resp = client.put( resp = upload(
url + "/content", url + "/content",
data={"upload": FileStorage(stream=BytesIO(upload_data), filename="upload")}, BytesIO(upload_data),
content_type="multipart/form-data", method=client.put
) )
assert resp.status_code == 401 assert resp.status_code == 401

View File

@ -10,7 +10,7 @@ import uuid
# this might be redundant because test_storage tests the backends already # this might be redundant because test_storage tests the backends already
@pytest.mark.parametrize("client", [{"SACHET_STORAGE": "filesystem"}], indirect=True) @pytest.mark.parametrize("client", [{"SACHET_STORAGE": "filesystem"}], indirect=True)
class TestSuite: class TestSuite:
def test_sharing(self, client, users, auth, rand): def test_sharing(self, client, users, auth, rand, upload):
# create share # create share
resp = client.post( resp = client.post(
"/files", headers=auth("jeff"), json={"file_name": "content.bin"} "/files", headers=auth("jeff"), json={"file_name": "content.bin"}
@ -25,14 +25,11 @@ class TestSuite:
upload_data = rand.randbytes(4000) upload_data = rand.randbytes(4000)
# upload file to share resp = upload(
resp = client.post(
url + "/content", url + "/content",
BytesIO(upload_data),
headers=auth("jeff"), headers=auth("jeff"),
data={ chunk_size=1230,
"upload": FileStorage(stream=BytesIO(upload_data), filename="upload")
},
content_type="multipart/form-data",
) )
assert resp.status_code == 201 assert resp.status_code == 201
@ -58,7 +55,7 @@ class TestSuite:
) )
assert resp.status_code == 404 assert resp.status_code == 404
def test_modification(self, client, users, auth, rand): def test_modification(self, client, users, auth, rand, upload):
# create share # create share
resp = client.post( resp = client.post(
"/files", headers=auth("jeff"), json={"file_name": "content.bin"} "/files", headers=auth("jeff"), json={"file_name": "content.bin"}
@ -70,14 +67,12 @@ class TestSuite:
new_data = rand.randbytes(4000) new_data = rand.randbytes(4000)
# upload file to share # upload file to share
resp = client.post( resp = upload(
url + "/content", url + "/content",
BytesIO(upload_data),
headers=auth("jeff"), headers=auth("jeff"),
data={
"upload": FileStorage(stream=BytesIO(upload_data), filename="upload")
},
content_type="multipart/form-data",
) )
assert resp.status_code == 201
# modify metadata # modify metadata
resp = client.patch( resp = client.patch(
@ -88,12 +83,13 @@ class TestSuite:
assert resp.status_code == 200 assert resp.status_code == 200
# modify file contents # modify file contents
resp = client.put( resp = upload(
url + "/content", url + "/content",
BytesIO(new_data),
headers=auth("jeff"), headers=auth("jeff"),
data={"upload": FileStorage(stream=BytesIO(new_data), filename="upload")}, method=client.put,
) )
assert resp.status_code == 200 assert resp.status_code == 201
# read file # read file
resp = client.get( resp = client.get(
@ -103,7 +99,7 @@ class TestSuite:
assert resp.data == new_data assert resp.data == new_data
assert "filename=new_bin.bin" in resp.headers["Content-Disposition"].split("; ") assert "filename=new_bin.bin" in resp.headers["Content-Disposition"].split("; ")
def test_invalid(self, client, users, auth, rand): def test_invalid(self, client, users, auth, rand, upload):
"""Test invalid requests.""" """Test invalid requests."""
upload_data = rand.randbytes(4000) upload_data = rand.randbytes(4000)
@ -141,66 +137,51 @@ class TestSuite:
url = data.get("url") url = data.get("url")
# test invalid methods # test invalid methods
resp = client.put( resp = upload(
url + "/content", url + "/content",
BytesIO(upload_data),
headers=auth("jeff"), headers=auth("jeff"),
data={ method=client.put,
"upload": FileStorage(stream=BytesIO(upload_data), filename="upload")
},
content_type="multipart/form-data",
) )
assert resp.status_code == 423 assert resp.status_code == 423
resp = client.patch( resp = upload(
url + "/content", url + "/content",
BytesIO(upload_data),
headers=auth("jeff"), headers=auth("jeff"),
data={ method=client.patch,
"upload": FileStorage(stream=BytesIO(upload_data), filename="upload")
},
content_type="multipart/form-data",
) )
assert resp.status_code == 405 assert resp.status_code == 405
# test other user being unable to upload to this share # test other user being unable to upload to this share
resp = client.post( resp = upload(
url + "/content", url + "/content",
BytesIO(upload_data),
headers=auth("dave"), headers=auth("dave"),
data={
"upload": FileStorage(stream=BytesIO(upload_data), filename="upload")
},
content_type="multipart/form-data",
) )
assert resp.status_code == 403 assert resp.status_code == 403
# upload file to share (properly) # upload file to share (properly)
resp = client.post( resp = upload(
url + "/content", url + "/content",
BytesIO(upload_data),
headers=auth("jeff"), headers=auth("jeff"),
data={
"upload": FileStorage(stream=BytesIO(upload_data), filename="upload")
},
content_type="multipart/form-data",
) )
assert resp.status_code == 201 assert resp.status_code == 201
# test other user being unable to modify this share # test other user being unable to modify this share
resp = client.put( resp = upload(
url + "/content", url + "/content",
BytesIO(upload_data),
headers=auth("dave"), headers=auth("dave"),
data={ method=client.put,
"upload": FileStorage(stream=BytesIO(upload_data), filename="upload")
},
content_type="multipart/form-data",
) )
assert resp.status_code == 403 assert resp.status_code == 403
# test not allowing re-upload # test not allowing re-upload
resp = client.post( resp = upload(
url + "/content", url + "/content",
BytesIO(upload_data),
headers=auth("jeff"), headers=auth("jeff"),
data={
"upload": FileStorage(stream=BytesIO(upload_data), filename="upload")
},
content_type="multipart/form-data",
) )
assert resp.status_code == 423 assert resp.status_code == 423
@ -210,7 +191,7 @@ class TestSuite:
resp = client.get(url + "/content", headers=auth("no_read_user")) resp = client.get(url + "/content", headers=auth("no_read_user"))
assert resp.status_code == 403 assert resp.status_code == 403
def test_locking(self, client, users, auth, rand): def test_locking(self, client, users, auth, rand, upload):
# upload share # upload share
resp = client.post( resp = client.post(
"/files", headers=auth("jeff"), json={"file_name": "content.bin"} "/files", headers=auth("jeff"), json={"file_name": "content.bin"}
@ -218,13 +199,10 @@ class TestSuite:
data = resp.get_json() data = resp.get_json()
url = data.get("url") url = data.get("url")
upload_data = rand.randbytes(4000) upload_data = rand.randbytes(4000)
resp = client.post( resp = upload(
url + "/content", url + "/content",
BytesIO(upload_data),
headers=auth("jeff"), headers=auth("jeff"),
data={
"upload": FileStorage(stream=BytesIO(upload_data), filename="upload")
},
content_type="multipart/form-data",
) )
assert resp.status_code == 201 assert resp.status_code == 201
@ -236,13 +214,11 @@ class TestSuite:
assert resp.status_code == 200 assert resp.status_code == 200
# attempt to modify share # attempt to modify share
resp = client.put( resp = upload(
url + "/content", url + "/content",
BytesIO(upload_data),
headers=auth("jeff"), headers=auth("jeff"),
data={ method=client.put,
"upload": FileStorage(stream=BytesIO(upload_data), filename="upload")
},
content_type="multipart/form-data",
) )
assert resp.status_code == 423 assert resp.status_code == 423
@ -261,15 +237,13 @@ class TestSuite:
assert resp.status_code == 200 assert resp.status_code == 200
# attempt to modify share # attempt to modify share
resp = client.put( resp = upload(
url + "/content", url + "/content",
BytesIO(upload_data),
headers=auth("jeff"), headers=auth("jeff"),
data={ method=client.put,
"upload": FileStorage(stream=BytesIO(upload_data), filename="upload")
},
content_type="multipart/form-data",
) )
assert resp.status_code == 200 assert resp.status_code == 201
# attempt to delete share # attempt to delete share
resp = client.delete( resp = client.delete(