diff --git a/pyproject.toml b/pyproject.toml index 0801600..5f2dc91 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "spiderweb-framework" -version = "0.10.0" +version = "0.11.0" description = "A small web framework, just big enough for a spider." authors = ["Joe Kaufeld "] readme = "README.md" diff --git a/spiderweb/constants.py b/spiderweb/constants.py index ab3aed6..cf8734d 100644 --- a/spiderweb/constants.py +++ b/spiderweb/constants.py @@ -2,7 +2,7 @@ from peewee import DatabaseProxy DEFAULT_ALLOWED_METHODS = ["GET"] DEFAULT_ENCODING = "UTF-8" -__version__ = "0.10.0" +__version__ = "0.11.0" # https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Set-Cookie REGEX_COOKIE_NAME = r"^[a-zA-Z0-9\s\(\)<>@,;:\/\\\[\]\?=\{\}\"\t]*$" diff --git a/spiderweb/exceptions.py b/spiderweb/exceptions.py index c826d40..bdba675 100644 --- a/spiderweb/exceptions.py +++ b/spiderweb/exceptions.py @@ -5,7 +5,7 @@ class SpiderwebException(Exception): msg = self.args[0] if len(self.args) > 0 else "" if msg: return f"{name}() - {msg}" - return f"{self.__class__.__name__}()" + return f"{name}()" class SpiderwebNetworkException(SpiderwebException): diff --git a/spiderweb/main.py b/spiderweb/main.py index 8727932..b0333dc 100644 --- a/spiderweb/main.py +++ b/spiderweb/main.py @@ -32,7 +32,7 @@ from spiderweb.routes import RoutesMixin from spiderweb.secrets import FernetMixin from spiderweb.utils import get_http_status_by_code -file_logger = logging.getLogger(__name__) +console_logger = logging.getLogger(__name__) logging.basicConfig(level=logging.INFO) @@ -57,7 +57,7 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi session_cookie_same_site="lax", session_cookie_path="/", log=None, - **kwargs + **kwargs, ): self._routes = {} self.routes = routes @@ -69,7 +69,8 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi self.append_slash = append_slash self.templates_dirs = templates_dirs self.staticfiles_dirs = staticfiles_dirs - self.middleware = middleware if middleware else [] + self._middleware: list[str] = middleware if middleware else [] + self.middleware: list[Callable] = [] self.secret_key = secret_key if secret_key else self.generate_key() self.extra_data = kwargs @@ -84,7 +85,7 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi self.DEFAULT_ENCODING = DEFAULT_ENCODING self.DEFAULT_ALLOWED_METHODS = DEFAULT_ALLOWED_METHODS - self.log: logging.Logger = log if log else file_logger + self.log: logging.Logger = log if log else console_logger # for using .start() and .stop() self._thread: Optional[Thread] = None @@ -108,7 +109,9 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi self.add_error_routes() if self.templates_dirs: - self.template_loader = Environment(loader=FileSystemLoader(self.templates_dirs)) + self.template_loader = Environment( + loader=FileSystemLoader(self.templates_dirs) + ) else: self.template_loader = None self.string_loader = Environment(loader=BaseLoader()) @@ -117,12 +120,16 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi for static_dir in self.staticfiles_dirs: static_dir = pathlib.Path(static_dir) if not pathlib.Path(self.BASE_DIR / static_dir).exists(): - log.error( + self.log.error( f"Static files directory '{str(static_dir)}' does not exist." ) raise ConfigError self.add_route(r"/static/", send_file) # noqa: F405 + # finally, run the startup checks to verify everything is correct and happy. + self.log.info("Run startup checks...") + self.run_middleware_checks() + def fire_response(self, start_response, request: Request, resp: HttpResponse): try: status = get_http_status_by_code(resp.status_code) @@ -190,7 +197,9 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi def prepare_and_fire_response(self, start_response, request, resp) -> list[bytes]: try: if isinstance(resp, dict): - return self.fire_response(start_response, request, JsonResponse(data=resp)) + return self.fire_response( + start_response, request, JsonResponse(data=resp) + ) if isinstance(resp, TemplateResponse): resp.set_template_loader(self.template_loader) resp.set_string_loader(self.string_loader) diff --git a/spiderweb/middleware/__init__.py b/spiderweb/middleware/__init__.py index c43b43b..3ffeb8e 100644 --- a/spiderweb/middleware/__init__.py +++ b/spiderweb/middleware/__init__.py @@ -12,19 +12,26 @@ from ..utils import import_by_string class MiddlewareMixin: """Cannot be called on its own. Requires context of SpiderwebRouter.""" + _middleware: list[str] middleware: list[ClassVar] fire_response: Callable def init_middleware(self): - if self.middleware: + if self._middleware: middleware_by_reference = [] - for m in self.middleware: + for m in self._middleware: try: middleware_by_reference.append(import_by_string(m)(server=self)) except ImportError: raise ConfigError(f"Middleware '{m}' not found.") self.middleware = middleware_by_reference + def run_middleware_checks(self): + for middleware in self.middleware: + if hasattr(middleware, "checks"): + for check in middleware.checks: + check(server=self).check() + def process_request_middleware(self, request: Request) -> None | bool: for middleware in self.middleware: try: diff --git a/spiderweb/middleware/base.py b/spiderweb/middleware/base.py index 0cf3b00..9a68e3e 100644 --- a/spiderweb/middleware/base.py +++ b/spiderweb/middleware/base.py @@ -30,4 +30,4 @@ class SpiderwebMiddleware: pass def on_error(self, request: Request, e: Exception) -> HttpResponse | None: - pass \ No newline at end of file + pass diff --git a/spiderweb/middleware/csrf.py b/spiderweb/middleware/csrf.py index b2dd922..5a128e7 100644 --- a/spiderweb/middleware/csrf.py +++ b/spiderweb/middleware/csrf.py @@ -1,25 +1,42 @@ from datetime import datetime, timedelta -from spiderweb.exceptions import CSRFError +from spiderweb.exceptions import CSRFError, ConfigError from spiderweb.middleware import SpiderwebMiddleware from spiderweb.request import Request from spiderweb.response import HttpResponse +from spiderweb.server_checks import ServerCheck + + +class SessionCheck(ServerCheck): + + SESSION_MIDDLEWARE_NOT_FOUND = ( + "Session middleware is not enabled. It must be listed above" + "CSRFMiddleware in the middleware list." + ) + SESSION_MIDDLEWARE_BELOW_CSRF = ( + "SessionMiddleware is enabled, but it must be listed above" + "CSRFMiddleware in the middleware list." + ) + + def check(self): + + if ( + "spiderweb.middleware.sessions.SessionMiddleware" + not in self.server._middleware + ): + raise ConfigError(self.SESSION_MIDDLEWARE_NOT_FOUND) + + if self.server._middleware.index( + "spiderweb.middleware.sessions.SessionMiddleware" + ) > self.server._middleware.index( + "spiderweb.middleware.csrf.CSRFMiddleware" + ): + raise ConfigError(self.SESSION_MIDDLEWARE_BELOW_CSRF) class CSRFMiddleware(SpiderwebMiddleware): - """ - tl;dr: this is a naive implementation going off just what I could think of - at the time. It is very vulnerable to CSRF Forgery and should be updated. - Eventually I'll probably just pull everything out of Django and use their - implementation, as it's written by people who know a lot more about these - things than I do, but in the meantime, this is still here until I get - around to making it more solid. - - todo: fix - - https://en.wikipedia.org/wiki/Cross-site_request_forgery - """ + checks = [SessionCheck] CSRF_EXPIRY = 60 * 60 # 1 hour @@ -33,14 +50,14 @@ class CSRFMiddleware(SpiderwebMiddleware): or request.GET.get("csrf_token") or request.POST.get("csrf_token") ) - if self.is_csrf_valid(csrf_token): + if self.is_csrf_valid(request, csrf_token): return None else: raise CSRFError() return None def process_response(self, request: Request, response: HttpResponse) -> None: - token = self.get_csrf_token() + token = self.get_csrf_token(request) # do we need it in both places? response.headers["X-CSRF-TOKEN"] = token response.context |= { @@ -48,17 +65,22 @@ class CSRFMiddleware(SpiderwebMiddleware): "raw_csrf_token": token, # in case they want to format it themselves } - def get_csrf_token(self): - return self.server.encrypt(str(datetime.now().isoformat())).decode( - self.server.DEFAULT_ENCODING - ) + def get_csrf_token(self, request): + # the session key should be here because we've processed the session first + session_key = request._session["id"] + return self.server.encrypt( + f"{str(datetime.now().isoformat())}::{session_key}" + ).decode(self.server.DEFAULT_ENCODING) - def is_csrf_valid(self, key): + def is_csrf_valid(self, request, key): try: decoded = self.server.decrypt(key) + timestamp, session_key = decoded.split("::") + if session_key != request._session["id"]: + return False if datetime.now() - timedelta( seconds=self.CSRF_EXPIRY - ) > datetime.fromisoformat(decoded): + ) > datetime.fromisoformat(timestamp): return False return True except Exception: diff --git a/spiderweb/routes.py b/spiderweb/routes.py index 5b4421b..4b26448 100644 --- a/spiderweb/routes.py +++ b/spiderweb/routes.py @@ -30,7 +30,7 @@ class RoutesMixin: # ones that start with underscores are the compiled versions, non-underscores # are the user-supplied versions _routes: dict - routes: list[tuple[str, Callable] | tuple[str, Callable, dict]] = None, + routes: list[tuple[str, Callable] | tuple[str, Callable, dict]] = (None,) _error_routes: dict error_routes: dict[int, Callable] append_slash: bool diff --git a/spiderweb/server_checks.py b/spiderweb/server_checks.py new file mode 100644 index 0000000..7ec14fe --- /dev/null +++ b/spiderweb/server_checks.py @@ -0,0 +1,18 @@ +class ServerCheck: + """ + Single server check base class. + + During startup, each middleware can request checks to be run against the + current state of the server. These are usually used to verify that things + are configured correctly, but can also be used for database setup or other + similar things. + + To fail a check, raise any error that makes sense to raise. This will halt + startup so the error can be fixed. + """ + + def __init__(self, server): + self.server = server + + def check(self): + pass diff --git a/spiderweb/tests/__init__.py b/spiderweb/tests/__init__.py index 187efdc..cec2bf9 100644 --- a/spiderweb/tests/__init__.py +++ b/spiderweb/tests/__init__.py @@ -1 +1,4 @@ -from spiderweb.tests.middleware import ExplodingResponseMiddleware, ExplodingRequestMiddleware \ No newline at end of file +from spiderweb.tests.middleware import ( + ExplodingResponseMiddleware, + ExplodingRequestMiddleware, +) diff --git a/spiderweb/tests/helpers.py b/spiderweb/tests/helpers.py index 7948c1a..10e88ee 100644 --- a/spiderweb/tests/helpers.py +++ b/spiderweb/tests/helpers.py @@ -25,4 +25,4 @@ def setup(): SpiderwebRouter(db=SqliteDatabase("spiderweb-tests.db")), environ, StartResponse(), - ) \ No newline at end of file + ) diff --git a/spiderweb/tests/middleware.py b/spiderweb/tests/middleware.py index c2e32c4..6c555ff 100644 --- a/spiderweb/tests/middleware.py +++ b/spiderweb/tests/middleware.py @@ -7,5 +7,7 @@ class ExplodingRequestMiddleware(SpiderwebMiddleware): class ExplodingResponseMiddleware(SpiderwebMiddleware): - def process_response(self, request: Request, response: HttpResponse) -> HttpResponse | None: - raise UnusedMiddleware("Unfinished!") \ No newline at end of file + def process_response( + self, request: Request, response: HttpResponse + ) -> HttpResponse | None: + raise UnusedMiddleware("Unfinished!") diff --git a/spiderweb/tests/test_middleware.py b/spiderweb/tests/test_middleware.py index 0aff358..e727b07 100644 --- a/spiderweb/tests/test_middleware.py +++ b/spiderweb/tests/test_middleware.py @@ -1,11 +1,16 @@ +from io import BytesIO, BufferedReader from datetime import timedelta +import pytest from peewee import SqliteDatabase -from spiderweb import SpiderwebRouter, HttpResponse +from spiderweb import SpiderwebRouter, HttpResponse, ConfigError from spiderweb.constants import DEFAULT_ENCODING from spiderweb.middleware.sessions import Session +from spiderweb.middleware import csrf from spiderweb.tests.helpers import setup +from spiderweb.tests.views_for_tests import form_view_with_csrf, form_csrf_exempt, form_view_without_csrf + # app = SpiderwebRouter( # middleware=[ @@ -18,19 +23,20 @@ from spiderweb.tests.helpers import setup # ], # ) + def index(request): if "value" in request.SESSION: - request.SESSION['value'] += 1 + request.SESSION["value"] += 1 else: - request.SESSION['value'] = 0 - return HttpResponse(body=str(request.SESSION['value'])) + request.SESSION["value"] = 0 + return HttpResponse(body=str(request.SESSION["value"])) def test_session_middleware(): _, environ, start_response = setup() app = SpiderwebRouter( middleware=["spiderweb.middleware.sessions.SessionMiddleware"], - db=SqliteDatabase("spiderweb-tests.db") + db=SqliteDatabase("spiderweb-tests.db"), ) app.add_route("/", index) @@ -46,11 +52,12 @@ def test_session_middleware(): assert app(environ, start_response) == [bytes(str(1), DEFAULT_ENCODING)] assert app(environ, start_response) == [bytes(str(2), DEFAULT_ENCODING)] + def test_expired_session(): _, environ, start_response = setup() app = SpiderwebRouter( middleware=["spiderweb.middleware.sessions.SessionMiddleware"], - db=SqliteDatabase("spiderweb-tests.db") + db=SqliteDatabase("spiderweb-tests.db"), ) app.add_route("/", index) @@ -80,9 +87,170 @@ def test_exploding_middleware(): "spiderweb.tests.middleware.ExplodingRequestMiddleware", "spiderweb.tests.middleware.ExplodingResponseMiddleware", ], - db=SqliteDatabase("spiderweb-tests.db") + db=SqliteDatabase("spiderweb-tests.db"), ) app.add_route("/", index) assert app(environ, start_response) == [bytes(str(0), DEFAULT_ENCODING)] + # make sure it kicked out the middleware and isn't just ignoring it + assert len(app.middleware) == 0 + + +def test_csrf_middleware_without_session_middleware(): + _, environ, start_response = setup() + with pytest.raises(ConfigError) as e: + SpiderwebRouter( + middleware=["spiderweb.middleware.csrf.CSRFMiddleware"], + db=SqliteDatabase("spiderweb-tests.db"), + ) + + assert e.value.args[0] == csrf.SessionCheck.SESSION_MIDDLEWARE_NOT_FOUND + + +def test_csrf_middleware_above_session_middleware(): + _, environ, start_response = setup() + with pytest.raises(ConfigError) as e: + SpiderwebRouter( + middleware=[ + "spiderweb.middleware.csrf.CSRFMiddleware", + "spiderweb.middleware.sessions.SessionMiddleware", + ], + db=SqliteDatabase("spiderweb-tests.db"), + ) + + assert e.value.args[0] == csrf.SessionCheck.SESSION_MIDDLEWARE_BELOW_CSRF + + +def test_csrf_middleware(): + _, environ, start_response = setup() + app = SpiderwebRouter( + middleware=[ + "spiderweb.middleware.sessions.SessionMiddleware", + "spiderweb.middleware.csrf.CSRFMiddleware", + ], + db=SqliteDatabase("spiderweb-tests.db"), + ) + + app.add_route("/", form_view_with_csrf, ["GET", "POST"]) + + environ["HTTP_USER_AGENT"] = "hi" + environ["REMOTE_ADDR"] = "1.1.1.1" + + resp = app(environ, start_response)[0].decode(DEFAULT_ENCODING) + + assert "") def index(request, test_input: str): return HttpResponse(test_input) @@ -52,19 +58,19 @@ def test_duplicate_route(): app, environ, start_response = setup() @app.route("/") - def index(request): - ... + def index(request): ... with pytest.raises(ConfigError): + @app.route("/") - def index(request): - ... + def index(request): ... def test_url_with_double_underscore(): app, environ, start_response = setup() with pytest.raises(ConfigError): + @app.route("/") def index(request, test_input: str): return HttpResponse(test_input) @@ -99,4 +105,4 @@ def test_float_converter(number): return HttpResponse(test_input) environ["PATH_INFO"] = f"/{number}" - assert app(environ, start_response) == [bytes(str(number), DEFAULT_ENCODING)] \ No newline at end of file + assert app(environ, start_response) == [bytes(str(number), DEFAULT_ENCODING)] diff --git a/spiderweb/tests/views_for_tests.py b/spiderweb/tests/views_for_tests.py new file mode 100644 index 0000000..cc4dac1 --- /dev/null +++ b/spiderweb/tests/views_for_tests.py @@ -0,0 +1,40 @@ +from spiderweb.decorators import csrf_exempt +from spiderweb.response import JsonResponse, TemplateResponse + + +EXAMPLE_HTML_FORM = """ +
+ + +
+""" + +EXAMPLE_HTML_FORM_WITH_CSRF = """ +
+ + + {{ csrf_token }} +
+""" + + +def form_view_without_csrf(request): + if request.method == "POST": + return JsonResponse(data=request.POST) + else: + return TemplateResponse(request, template_string=EXAMPLE_HTML_FORM) + + +@csrf_exempt +def form_csrf_exempt(request): + if request.method == "POST": + return JsonResponse(data=request.POST) + else: + return TemplateResponse(request, template_string=EXAMPLE_HTML_FORM_WITH_CSRF) + + +def form_view_with_csrf(request): + if request.method == "POST": + return JsonResponse(data=request.POST) + else: + return TemplateResponse(request, template_string=EXAMPLE_HTML_FORM_WITH_CSRF)