diff --git a/sachet/server/files/views.py b/sachet/server/files/views.py index a3c09f4..4beedfd 100644 --- a/sachet/server/files/views.py +++ b/sachet/server/files/views.py @@ -208,49 +208,15 @@ class FileContentAPI(MethodView): file = share.get_handle() - range_header = request.headers.get("Range") - if not range_header: - resp = make_response( - send_file(file.open(mode="rb"), download_name=share.file_name) - ) - resp.headers["Accept-Ranges"] = "bytes" - return resp - - # handle partial file request - vals = range_header.strip().split("=") - if len(vals) != 2: - return ( - jsonify( - dict( - status="fail", message=f"Invalid range header '{range_header}'." - ) - ), - 400, - ) - - try: - start, end = vals[1].split("-") - start = int(start) - end = int(end) or file.size - 1 - except ValueError as err: - return ( - jsonify(dict(status="fail", message=str(err))), - 400, - ) - - content_length = end - start + 1 - - with file.open(mode="rb") as f: - f.seek(start) + with file.open("rb") as f: resp = make_response( send_file( - io.BytesIO(f.read(content_length)), download_name=share.file_name + io.BytesIO(f.read()), + download_name=share.file_name, + conditional=True, ) ) - resp.headers["Content-Range"] = f"bytes {start}-{end}/{file.size}" - resp.headers["Accept-Ranges"] = "bytes" - resp.headers["Content-Length"] = content_length - return resp, 206 + return resp files_blueprint.add_url_rule( diff --git a/tests/test_files.py b/tests/test_files.py index 9c8ec6e..30d15e9 100644 --- a/tests/test_files.py +++ b/tests/test_files.py @@ -268,3 +268,58 @@ class TestSuite: headers=auth("no_lock_user"), ) assert resp.status_code == 403 + + def test_partial(self, client, users, auth, rand, upload): + # create share + resp = client.post( + "/files", headers=auth("jeff"), json={"file_name": "content.bin"} + ) + assert resp.status_code == 201 + + data = resp.get_json() + url = data.get("url") + + upload_data = b"1234567890" * 400 + + resp = upload( + url + "/content", + BytesIO(upload_data), + headers=auth("jeff"), + chunk_size=1230, + ) + assert resp.status_code == 201 + + # test the following ranges + ranges = [ + [0, 1], + [1, 1], + [2, 300], + [300, 30], + [3, 4], + [30, 3999], + [4000, 4000], + [3999, 39999], + [40000, 0], + [48000, 9], + [-1, 0], + [-2, 3], + [0, 4000], + [0, ""], + ] + + for r in ranges: + resp = client.get( + url + "/content", + headers=auth("jeff", data={"Range": f"bytes={r[0]}-{r[1]}"}), + ) + if r[1] == "": + r[1] = len(upload_data) + # apparently if you specify an endpoint past the end + # it just truncates the response to the end + if r[0] < 0 or r[0] >= 4000: + assert resp.status_code == 416 + elif r[0] > r[1]: + assert resp.status_code == 416 + else: + assert resp.status_code == 206 + assert resp.data == upload_data[r[0] : r[1] + 1]