🔒 fix issues with CSRF middleware

This commit is contained in:
Joe Kaufeld 2024-09-01 21:05:24 -04:00
parent 572675b076
commit 9330918009
3 changed files with 72 additions and 19 deletions

View File

@ -11,9 +11,6 @@ app = SpiderwebRouter(
) )
``` ```
> [!DANGER]
> The CSRFMiddleware is incomplete at best and dangerous at worst. I am not a security expert, and my implementation is [very susceptible to the thing it is meant to prevent](https://en.wikipedia.org/wiki/Cross-site_request_forgery). While this is an big issue (and moderately hilarious), the middleware is still provided to you in its unfinished state. Be aware.
Cross-site request forgery, put simply, is a method for attackers to make legitimate-looking requests in your name to a service or system that you've previously authenticated to. Ways that we can protect against this involve aggressively expiring session cookies, special IDs for forms that are keyed to a specific user, and more. Cross-site request forgery, put simply, is a method for attackers to make legitimate-looking requests in your name to a service or system that you've previously authenticated to. Ways that we can protect against this involve aggressively expiring session cookies, special IDs for forms that are keyed to a specific user, and more.
> [!TIP] > [!TIP]

View File

@ -45,6 +45,7 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi
allowed_hosts=None, allowed_hosts=None,
cors_allowed_origins=None, cors_allowed_origins=None,
cors_allow_all_origins=False, cors_allow_all_origins=False,
csrf_trusted_origins: Sequence[str] = None,
db: Optional[Database] = None, db: Optional[Database] = None,
templates_dirs: list[str] = None, templates_dirs: list[str] = None,
middleware: list[str] = None, middleware: list[str] = None,
@ -75,10 +76,15 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi
self._middleware: list[str] = middleware or [] self._middleware: list[str] = middleware or []
self.middleware: list[Callable] = [] self.middleware: list[Callable] = []
self.secret_key = secret_key if secret_key else self.generate_key() self.secret_key = secret_key if secret_key else self.generate_key()
self.allowed_hosts = allowed_hosts or ["*"] self._allowed_hosts = allowed_hosts or ["*"]
self.allowed_hosts = [convert_url_to_regex(i) for i in self._allowed_hosts]
self.cors_allowed_origins = cors_allowed_origins or [] self.cors_allowed_origins = cors_allowed_origins or []
self.cors_allow_all_origins = cors_allow_all_origins self.cors_allow_all_origins = cors_allow_all_origins
self._csrf_trusted_origins = csrf_trusted_origins or []
self.csrf_trusted_origins = [
convert_url_to_regex(i) for i in self._csrf_trusted_origins
]
self.extra_data = kwargs self.extra_data = kwargs
@ -154,7 +160,6 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi
for v in varies: for v in varies:
headers.append(("Vary", v)) headers.append(("Vary", v))
start_response(status, headers) start_response(status, headers)
rendered_output = resp.render() rendered_output = resp.render()
@ -231,6 +236,15 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi
start_response, request, self.get_error_route(500)(request) start_response, request, self.get_error_route(500)(request)
) )
def check_valid_host(self, request) -> bool:
host = request.headers.get("http_host")
if not host:
return False
for option in self.allowed_hosts:
if re.match(option, host):
return True
return False
def __call__(self, environ, start_response, *args, **kwargs): def __call__(self, environ, start_response, *args, **kwargs):
"""Entry point for WSGI apps.""" """Entry point for WSGI apps."""
request = self.get_request(environ) request = self.get_request(environ)
@ -247,6 +261,9 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi
# replace the potentially valid handler with the error route # replace the potentially valid handler with the error route
handler = self.get_error_route(405) handler = self.get_error_route(405)
if not self.check_valid_host(request):
handler = self.get_error_route(403)
if request.is_form_request(): if request.is_form_request():
form_data = urlparse.parse_qs(request.content) form_data = urlparse.parse_qs(request.content)
for key, value in form_data.items(): for key, value in form_data.items():

View File

@ -1,4 +1,7 @@
import re
from re import Pattern
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import Optional
from spiderweb.exceptions import CSRFError, ConfigError from spiderweb.exceptions import CSRFError, ConfigError
from spiderweb.middleware import SpiderwebMiddleware from spiderweb.middleware import SpiderwebMiddleware
@ -7,53 +10,89 @@ from spiderweb.response import HttpResponse
from spiderweb.server_checks import ServerCheck from spiderweb.server_checks import ServerCheck
class SessionCheck(ServerCheck): class CheckForSessionMiddleware(ServerCheck):
SESSION_MIDDLEWARE_NOT_FOUND = ( SESSION_MIDDLEWARE_NOT_FOUND = (
"Session middleware is not enabled. It must be listed above" "Session middleware is not enabled. It must be listed above"
"CSRFMiddleware in the middleware list." "CSRFMiddleware in the middleware list."
) )
def check(self) -> Optional[Exception]:
if (
"spiderweb.middleware.sessions.SessionMiddleware"
not in self.server._middleware
):
return ConfigError(self.SESSION_MIDDLEWARE_NOT_FOUND)
class VerifyCorrectMiddlewarePlacement(ServerCheck):
SESSION_MIDDLEWARE_BELOW_CSRF = ( SESSION_MIDDLEWARE_BELOW_CSRF = (
"SessionMiddleware is enabled, but it must be listed above" "SessionMiddleware is enabled, but it must be listed above"
"CSRFMiddleware in the middleware list." "CSRFMiddleware in the middleware list."
) )
def check(self): def check(self) -> Optional[Exception]:
if ( if (
"spiderweb.middleware.sessions.SessionMiddleware" "spiderweb.middleware.sessions.SessionMiddleware"
not in self.server._middleware not in self.server._middleware
): ):
raise ConfigError(self.SESSION_MIDDLEWARE_NOT_FOUND) # this is handled by CheckForSessionMiddleware
return
if self.server._middleware.index( if self.server._middleware.index(
"spiderweb.middleware.sessions.SessionMiddleware" "spiderweb.middleware.sessions.SessionMiddleware"
) > self.server._middleware.index( ) > self.server._middleware.index("spiderweb.middleware.csrf.CSRFMiddleware"):
"spiderweb.middleware.csrf.CSRFMiddleware" return ConfigError(self.SESSION_MIDDLEWARE_BELOW_CSRF)
):
raise ConfigError(self.SESSION_MIDDLEWARE_BELOW_CSRF)
class VerifyCorrectFormatForTrustedOrigins(ServerCheck):
CSRF_TRUSTED_ORIGINS_IS_LIST_OF_STR = (
"The csrf_trusted_origins setting must be a list of strings."
)
def check(self) -> Optional[Exception]:
if not isinstance(self.server.csrf_trusted_origins, list):
return ConfigError(self.CSRF_TRUSTED_ORIGINS_IS_LIST_OF_STR)
for item in self.server.csrf_trusted_origins:
if not isinstance(item, Pattern):
# It's a pattern here because we've already manipulated it
# by the time this check runs
return ConfigError(self.CSRF_TRUSTED_ORIGINS_IS_LIST_OF_STR)
class CSRFMiddleware(SpiderwebMiddleware): class CSRFMiddleware(SpiderwebMiddleware):
checks = [SessionCheck] checks = [
CheckForSessionMiddleware,
VerifyCorrectMiddlewarePlacement,
VerifyCorrectFormatForTrustedOrigins,
]
CSRF_EXPIRY = 60 * 60 # 1 hour CSRF_EXPIRY = 60 * 60 # 1 hour
def process_request(self, request: Request) -> HttpResponse | None: def process_request(self, request: Request) -> HttpResponse | None:
if request.method == "POST": if request.method == "POST":
trusted_origin = False
if hasattr(request.handler, "csrf_exempt"): if hasattr(request.handler, "csrf_exempt"):
if request.handler.csrf_exempt is True: if request.handler.csrf_exempt is True:
return return
if origin := request.headers.get("http_origin"):
for re_origin in self.server.csrf_trusted_origins:
if re.match(re_origin, origin):
trusted_origin = True
csrf_token = ( csrf_token = (
request.headers.get("X-CSRF-TOKEN") request.headers.get("X-CSRF-TOKEN")
or request.GET.get("csrf_token") or request.GET.get("csrf_token")
or request.POST.get("csrf_token") or request.POST.get("csrf_token")
) )
if self.is_csrf_valid(request, csrf_token):
return None if not trusted_origin:
else: if self.is_csrf_valid(request, csrf_token):
raise CSRFError() return None
else:
raise CSRFError()
return None return None
def process_response(self, request: Request, response: HttpResponse) -> None: def process_response(self, request: Request, response: HttpResponse) -> None: