diff --git a/spiderweb/middleware/__init__.py b/spiderweb/middleware/__init__.py index 2d2b0be..3bcfce2 100644 --- a/spiderweb/middleware/__init__.py +++ b/spiderweb/middleware/__init__.py @@ -2,9 +2,6 @@ from typing import Callable, ClassVar import sys from .base import SpiderwebMiddleware as SpiderwebMiddleware -from .cors import CorsMiddleware as CorsMiddleware -from .csrf import CSRFMiddleware as CSRFMiddleware -from .sessions import SessionMiddleware as SessionMiddleware from ..exceptions import ConfigError, UnusedMiddleware, StartupErrors from ..request import Request from ..response import HttpResponse diff --git a/spiderweb/tests/test_middleware.py b/spiderweb/tests/test_middleware.py index ac558d6..b42fa61 100644 --- a/spiderweb/tests/test_middleware.py +++ b/spiderweb/tests/test_middleware.py @@ -4,7 +4,7 @@ from datetime import timedelta import pytest from peewee import SqliteDatabase -from spiderweb import SpiderwebRouter, HttpResponse, StartupErrors +from spiderweb import SpiderwebRouter, HttpResponse, StartupErrors, ConfigError from spiderweb.constants import DEFAULT_ENCODING from spiderweb.middleware.cors import ( ACCESS_CONTROL_ALLOW_ORIGIN, @@ -94,6 +94,13 @@ def test_exploding_middleware(): assert len(app.middleware) == 0 +def test_invalid_middleware(): + with pytest.raises(ConfigError) as e: + SpiderwebRouter(middleware=["nonexistent.middleware"]) + + assert e.value.args[0] == "Middleware 'nonexistent.middleware' not found." + + def test_csrf_middleware_without_session_middleware(): with pytest.raises(StartupErrors) as e: SpiderwebRouter( diff --git a/spiderweb/utils.py b/spiderweb/utils.py index 635c910..15f8225 100644 --- a/spiderweb/utils.py +++ b/spiderweb/utils.py @@ -1,3 +1,4 @@ +import importlib import json import re import secrets @@ -13,12 +14,10 @@ VALID_CHARS = string.ascii_letters + string.digits def import_by_string(name): - # https://stackoverflow.com/a/547867 - components = name.split(".") - mod = __import__(components[0]) - for comp in components[1:]: - mod = getattr(mod, comp) - return mod + mod_name, klass_name = name.rsplit(".", 1) + module = importlib.import_module(mod_name) + klass = getattr(module, klass_name) + return klass def is_safe_path(path: str) -> bool: