diff --git a/sachet/server/__init__.py b/sachet/server/__init__.py index 8f411b7..0c55c42 100644 --- a/sachet/server/__init__.py +++ b/sachet/server/__init__.py @@ -9,13 +9,14 @@ from .config import DevelopmentConfig, ProductionConfig, TestingConfig, overlay_ app = Flask(__name__) CORS(app) -if os.getenv("RUN_ENV") == "test": - overlay_config(TestingConfig, "./config-testing.yml") -elif app.config["DEBUG"]: - overlay_config(DevelopmentConfig) - app.logger.warning("Running in DEVELOPMENT MODE; do NOT use this in production!") -else: - overlay_config(ProductionConfig) +with app.app_context(): + if os.getenv("RUN_ENV") == "test": + overlay_config(TestingConfig, "./config-testing.yml") + elif app.config["DEBUG"]: + overlay_config(DevelopmentConfig) + app.logger.warning("Running in DEVELOPMENT MODE; do NOT use this in production!") + else: + overlay_config(ProductionConfig) bcrypt = Bcrypt(app) db = SQLAlchemy(app) @@ -27,10 +28,13 @@ storage = None from sachet.storage import FileSystem -if _storage_method == "filesystem": - storage = FileSystem() -else: - raise ValueError(f"{_storage_method} is not a valid storage method.") + +with app.app_context(): + db.create_all() + if _storage_method == "filesystem": + storage = FileSystem() + else: + raise ValueError(f"{_storage_method} is not a valid storage method.") import sachet.server.commands @@ -45,6 +49,3 @@ app.register_blueprint(admin_blueprint) from sachet.server.files.views import files_blueprint app.register_blueprint(files_blueprint) - -with app.app_context(): - db.create_all() diff --git a/sachet/server/config.py b/sachet/server/config.py index e5f3a8e..6875717 100644 --- a/sachet/server/config.py +++ b/sachet/server/config.py @@ -1,4 +1,5 @@ from os import getenv, path +from flask import current_app import yaml sqlalchemy_base = "sqlite:///sachet" @@ -51,9 +52,7 @@ def overlay_config(base, config_file=None): if config["SECRET_KEY"] == "" or config["SECRET_KEY"] is None: raise ValueError("Please set secret_key within the configuration.") - from sachet.server import app - - app.config.from_object(base) + current_app.config.from_object(base) for k, v in config.items(): - app.config[k] = v + current_app.config[k] = v diff --git a/sachet/server/models.py b/sachet/server/models.py index 1f0cd6d..291de02 100644 --- a/sachet/server/models.py +++ b/sachet/server/models.py @@ -1,10 +1,10 @@ -from sachet.server import app, db, ma, bcrypt, storage +from sachet.server import db, ma, bcrypt, storage import datetime import jwt from enum import IntFlag from bitmask import Bitmask from marshmallow import fields, ValidationError -from flask import request, jsonify, url_for +from flask import request, jsonify, url_for, current_app from sqlalchemy_utils import UUIDType import uuid @@ -87,7 +87,7 @@ class User(db.Model): self.permissions = permissions self.password = bcrypt.generate_password_hash( - password, app.config.get("BCRYPT_LOG_ROUNDS") + password, current_app.config.get("BCRYPT_LOG_ROUNDS") ).decode() self.username = username self.register_date = datetime.datetime.now() @@ -100,7 +100,7 @@ class User(db.Model): "sub": self.username, "jti": jti, } - return jwt.encode(payload, app.config.get("SECRET_KEY"), algorithm="HS256") + return jwt.encode(payload, current_app.config.get("SECRET_KEY"), algorithm="HS256") def read_token(token): """Read a JWT and validate it. @@ -111,7 +111,7 @@ class User(db.Model): data = jwt.decode( token, - app.config["SECRET_KEY"], + current_app.config["SECRET_KEY"], algorithms=["HS256"], ) @@ -153,7 +153,7 @@ class BlacklistToken(db.Model): data = jwt.decode( token, - app.config["SECRET_KEY"], + current_app.config["SECRET_KEY"], algorithms=["HS256"], ) self.expires = datetime.datetime.fromtimestamp(data["exp"]) diff --git a/sachet/storage/filesystem.py b/sachet/storage/filesystem.py index 613f8a5..b0710f7 100644 --- a/sachet/storage/filesystem.py +++ b/sachet/storage/filesystem.py @@ -1,4 +1,5 @@ from sachet.storage import Storage +from flask import current_app from pathlib import Path from werkzeug.utils import secure_filename import json @@ -6,21 +7,18 @@ import json class FileSystem(Storage): def __init__(self): - # prevent circular import when inspecting this file outside of Flask - from sachet.server import app - - config_path = Path(app.config["SACHET_FILE_DIR"]) + config_path = Path(current_app.config["SACHET_FILE_DIR"]) if config_path.is_absolute(): self._directory = config_path else: - self._directory = Path(app.instance_path) / config_path + self._directory = Path(current_app.instance_path) / config_path self._files_directory = self._directory / Path("files") self._files_directory.mkdir(mode=0o700, exist_ok=True, parents=True) if not self._directory.is_dir(): - raise OSError(f"'{app.config['SACHET_FILE_DIR']}' is not a directory.") + raise OSError(f"'{current_app.config['SACHET_FILE_DIR']}' is not a directory.") def _get_path(self, name): name = secure_filename(name) diff --git a/tests/test_cli.py b/tests/test_cli.py index 65243ac..e442838 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,6 +1,6 @@ import pytest from sachet.server.commands import create_db, drop_db, create_user, delete_user -from sachet.server import app, db +from sachet.server import db from sqlalchemy import inspect from sachet.server.models import User