From 9330918009daf01e1ba32f041bda797c25483e2a Mon Sep 17 00:00:00 2001 From: Joe Kaufeld Date: Sun, 1 Sep 2024 21:05:24 -0400 Subject: [PATCH] :lock: fix issues with CSRF middleware --- docs/middleware/csrf.md | 3 -- spiderweb/main.py | 21 +++++++++-- spiderweb/middleware/csrf.py | 67 ++++++++++++++++++++++++++++-------- 3 files changed, 72 insertions(+), 19 deletions(-) diff --git a/docs/middleware/csrf.md b/docs/middleware/csrf.md index 8458dd3..b1f9c7d 100644 --- a/docs/middleware/csrf.md +++ b/docs/middleware/csrf.md @@ -11,9 +11,6 @@ app = SpiderwebRouter( ) ``` -> [!DANGER] -> The CSRFMiddleware is incomplete at best and dangerous at worst. I am not a security expert, and my implementation is [very susceptible to the thing it is meant to prevent](https://en.wikipedia.org/wiki/Cross-site_request_forgery). While this is an big issue (and moderately hilarious), the middleware is still provided to you in its unfinished state. Be aware. - Cross-site request forgery, put simply, is a method for attackers to make legitimate-looking requests in your name to a service or system that you've previously authenticated to. Ways that we can protect against this involve aggressively expiring session cookies, special IDs for forms that are keyed to a specific user, and more. > [!TIP] diff --git a/spiderweb/main.py b/spiderweb/main.py index ae1ad69..2f5dfc9 100644 --- a/spiderweb/main.py +++ b/spiderweb/main.py @@ -45,6 +45,7 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi allowed_hosts=None, cors_allowed_origins=None, cors_allow_all_origins=False, + csrf_trusted_origins: Sequence[str] = None, db: Optional[Database] = None, templates_dirs: list[str] = None, middleware: list[str] = None, @@ -75,10 +76,15 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi self._middleware: list[str] = middleware or [] self.middleware: list[Callable] = [] self.secret_key = secret_key if secret_key else self.generate_key() - self.allowed_hosts = allowed_hosts or ["*"] + self._allowed_hosts = allowed_hosts or ["*"] + self.allowed_hosts = [convert_url_to_regex(i) for i in self._allowed_hosts] self.cors_allowed_origins = cors_allowed_origins or [] self.cors_allow_all_origins = cors_allow_all_origins + self._csrf_trusted_origins = csrf_trusted_origins or [] + self.csrf_trusted_origins = [ + convert_url_to_regex(i) for i in self._csrf_trusted_origins + ] self.extra_data = kwargs @@ -154,7 +160,6 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi for v in varies: headers.append(("Vary", v)) - start_response(status, headers) rendered_output = resp.render() @@ -231,6 +236,15 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi start_response, request, self.get_error_route(500)(request) ) + def check_valid_host(self, request) -> bool: + host = request.headers.get("http_host") + if not host: + return False + for option in self.allowed_hosts: + if re.match(option, host): + return True + return False + def __call__(self, environ, start_response, *args, **kwargs): """Entry point for WSGI apps.""" request = self.get_request(environ) @@ -247,6 +261,9 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi # replace the potentially valid handler with the error route handler = self.get_error_route(405) + if not self.check_valid_host(request): + handler = self.get_error_route(403) + if request.is_form_request(): form_data = urlparse.parse_qs(request.content) for key, value in form_data.items(): diff --git a/spiderweb/middleware/csrf.py b/spiderweb/middleware/csrf.py index 5a128e7..3a0ffa9 100644 --- a/spiderweb/middleware/csrf.py +++ b/spiderweb/middleware/csrf.py @@ -1,4 +1,7 @@ +import re +from re import Pattern from datetime import datetime, timedelta +from typing import Optional from spiderweb.exceptions import CSRFError, ConfigError from spiderweb.middleware import SpiderwebMiddleware @@ -7,53 +10,89 @@ from spiderweb.response import HttpResponse from spiderweb.server_checks import ServerCheck -class SessionCheck(ServerCheck): - +class CheckForSessionMiddleware(ServerCheck): SESSION_MIDDLEWARE_NOT_FOUND = ( "Session middleware is not enabled. It must be listed above" "CSRFMiddleware in the middleware list." ) + + def check(self) -> Optional[Exception]: + if ( + "spiderweb.middleware.sessions.SessionMiddleware" + not in self.server._middleware + ): + return ConfigError(self.SESSION_MIDDLEWARE_NOT_FOUND) + + +class VerifyCorrectMiddlewarePlacement(ServerCheck): SESSION_MIDDLEWARE_BELOW_CSRF = ( "SessionMiddleware is enabled, but it must be listed above" "CSRFMiddleware in the middleware list." ) - def check(self): - + def check(self) -> Optional[Exception]: if ( "spiderweb.middleware.sessions.SessionMiddleware" not in self.server._middleware ): - raise ConfigError(self.SESSION_MIDDLEWARE_NOT_FOUND) + # this is handled by CheckForSessionMiddleware + return if self.server._middleware.index( "spiderweb.middleware.sessions.SessionMiddleware" - ) > self.server._middleware.index( - "spiderweb.middleware.csrf.CSRFMiddleware" - ): - raise ConfigError(self.SESSION_MIDDLEWARE_BELOW_CSRF) + ) > self.server._middleware.index("spiderweb.middleware.csrf.CSRFMiddleware"): + return ConfigError(self.SESSION_MIDDLEWARE_BELOW_CSRF) + + +class VerifyCorrectFormatForTrustedOrigins(ServerCheck): + CSRF_TRUSTED_ORIGINS_IS_LIST_OF_STR = ( + "The csrf_trusted_origins setting must be a list of strings." + ) + + def check(self) -> Optional[Exception]: + if not isinstance(self.server.csrf_trusted_origins, list): + return ConfigError(self.CSRF_TRUSTED_ORIGINS_IS_LIST_OF_STR) + + for item in self.server.csrf_trusted_origins: + if not isinstance(item, Pattern): + # It's a pattern here because we've already manipulated it + # by the time this check runs + return ConfigError(self.CSRF_TRUSTED_ORIGINS_IS_LIST_OF_STR) class CSRFMiddleware(SpiderwebMiddleware): - checks = [SessionCheck] + checks = [ + CheckForSessionMiddleware, + VerifyCorrectMiddlewarePlacement, + VerifyCorrectFormatForTrustedOrigins, + ] CSRF_EXPIRY = 60 * 60 # 1 hour def process_request(self, request: Request) -> HttpResponse | None: if request.method == "POST": + trusted_origin = False if hasattr(request.handler, "csrf_exempt"): if request.handler.csrf_exempt is True: return + if origin := request.headers.get("http_origin"): + + for re_origin in self.server.csrf_trusted_origins: + if re.match(re_origin, origin): + trusted_origin = True + csrf_token = ( request.headers.get("X-CSRF-TOKEN") or request.GET.get("csrf_token") or request.POST.get("csrf_token") ) - if self.is_csrf_valid(request, csrf_token): - return None - else: - raise CSRFError() + + if not trusted_origin: + if self.is_csrf_valid(request, csrf_token): + return None + else: + raise CSRFError() return None def process_response(self, request: Request, response: HttpResponse) -> None: