🐛 fix importing middleware

This commit is contained in:
Joe Kaufeld 2024-09-16 18:15:54 -04:00
parent 24109014af
commit 1d8559f766
3 changed files with 13 additions and 10 deletions

View File

@ -2,9 +2,6 @@ from typing import Callable, ClassVar
import sys import sys
from .base import SpiderwebMiddleware as SpiderwebMiddleware from .base import SpiderwebMiddleware as SpiderwebMiddleware
from .cors import CorsMiddleware as CorsMiddleware
from .csrf import CSRFMiddleware as CSRFMiddleware
from .sessions import SessionMiddleware as SessionMiddleware
from ..exceptions import ConfigError, UnusedMiddleware, StartupErrors from ..exceptions import ConfigError, UnusedMiddleware, StartupErrors
from ..request import Request from ..request import Request
from ..response import HttpResponse from ..response import HttpResponse

View File

@ -4,7 +4,7 @@ from datetime import timedelta
import pytest import pytest
from peewee import SqliteDatabase from peewee import SqliteDatabase
from spiderweb import SpiderwebRouter, HttpResponse, StartupErrors from spiderweb import SpiderwebRouter, HttpResponse, StartupErrors, ConfigError
from spiderweb.constants import DEFAULT_ENCODING from spiderweb.constants import DEFAULT_ENCODING
from spiderweb.middleware.cors import ( from spiderweb.middleware.cors import (
ACCESS_CONTROL_ALLOW_ORIGIN, ACCESS_CONTROL_ALLOW_ORIGIN,
@ -94,6 +94,13 @@ def test_exploding_middleware():
assert len(app.middleware) == 0 assert len(app.middleware) == 0
def test_invalid_middleware():
with pytest.raises(ConfigError) as e:
SpiderwebRouter(middleware=["nonexistent.middleware"])
assert e.value.args[0] == "Middleware 'nonexistent.middleware' not found."
def test_csrf_middleware_without_session_middleware(): def test_csrf_middleware_without_session_middleware():
with pytest.raises(StartupErrors) as e: with pytest.raises(StartupErrors) as e:
SpiderwebRouter( SpiderwebRouter(

View File

@ -1,3 +1,4 @@
import importlib
import json import json
import re import re
import secrets import secrets
@ -13,12 +14,10 @@ VALID_CHARS = string.ascii_letters + string.digits
def import_by_string(name): def import_by_string(name):
# https://stackoverflow.com/a/547867 mod_name, klass_name = name.rsplit(".", 1)
components = name.split(".") module = importlib.import_module(mod_name)
mod = __import__(components[0]) klass = getattr(module, klass_name)
for comp in components[1:]: return klass
mod = getattr(mod, comp)
return mod
def is_safe_path(path: str) -> bool: def is_safe_path(path: str) -> bool: