CORS! #1
@ -11,9 +11,6 @@ app = SpiderwebRouter(
|
||||
)
|
||||
```
|
||||
|
||||
> [!DANGER]
|
||||
> The CSRFMiddleware is incomplete at best and dangerous at worst. I am not a security expert, and my implementation is [very susceptible to the thing it is meant to prevent](https://en.wikipedia.org/wiki/Cross-site_request_forgery). While this is an big issue (and moderately hilarious), the middleware is still provided to you in its unfinished state. Be aware.
|
||||
|
||||
Cross-site request forgery, put simply, is a method for attackers to make legitimate-looking requests in your name to a service or system that you've previously authenticated to. Ways that we can protect against this involve aggressively expiring session cookies, special IDs for forms that are keyed to a specific user, and more.
|
||||
|
||||
> [!TIP]
|
||||
|
@ -45,6 +45,7 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi
|
||||
allowed_hosts=None,
|
||||
cors_allowed_origins=None,
|
||||
cors_allow_all_origins=False,
|
||||
csrf_trusted_origins: Sequence[str] = None,
|
||||
db: Optional[Database] = None,
|
||||
templates_dirs: list[str] = None,
|
||||
middleware: list[str] = None,
|
||||
@ -75,10 +76,15 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi
|
||||
self._middleware: list[str] = middleware or []
|
||||
self.middleware: list[Callable] = []
|
||||
self.secret_key = secret_key if secret_key else self.generate_key()
|
||||
self.allowed_hosts = allowed_hosts or ["*"]
|
||||
self._allowed_hosts = allowed_hosts or ["*"]
|
||||
self.allowed_hosts = [convert_url_to_regex(i) for i in self._allowed_hosts]
|
||||
|
||||
self.cors_allowed_origins = cors_allowed_origins or []
|
||||
self.cors_allow_all_origins = cors_allow_all_origins
|
||||
self._csrf_trusted_origins = csrf_trusted_origins or []
|
||||
self.csrf_trusted_origins = [
|
||||
convert_url_to_regex(i) for i in self._csrf_trusted_origins
|
||||
]
|
||||
|
||||
self.extra_data = kwargs
|
||||
|
||||
@ -154,7 +160,6 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi
|
||||
for v in varies:
|
||||
headers.append(("Vary", v))
|
||||
|
||||
|
||||
start_response(status, headers)
|
||||
|
||||
rendered_output = resp.render()
|
||||
@ -231,6 +236,15 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi
|
||||
start_response, request, self.get_error_route(500)(request)
|
||||
)
|
||||
|
||||
def check_valid_host(self, request) -> bool:
|
||||
host = request.headers.get("http_host")
|
||||
if not host:
|
||||
return False
|
||||
for option in self.allowed_hosts:
|
||||
if re.match(option, host):
|
||||
return True
|
||||
return False
|
||||
|
||||
def __call__(self, environ, start_response, *args, **kwargs):
|
||||
"""Entry point for WSGI apps."""
|
||||
request = self.get_request(environ)
|
||||
@ -247,6 +261,9 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi
|
||||
# replace the potentially valid handler with the error route
|
||||
handler = self.get_error_route(405)
|
||||
|
||||
if not self.check_valid_host(request):
|
||||
handler = self.get_error_route(403)
|
||||
|
||||
if request.is_form_request():
|
||||
form_data = urlparse.parse_qs(request.content)
|
||||
for key, value in form_data.items():
|
||||
|
@ -1,4 +1,7 @@
|
||||
import re
|
||||
from re import Pattern
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional
|
||||
|
||||
from spiderweb.exceptions import CSRFError, ConfigError
|
||||
from spiderweb.middleware import SpiderwebMiddleware
|
||||
@ -7,49 +10,85 @@ from spiderweb.response import HttpResponse
|
||||
from spiderweb.server_checks import ServerCheck
|
||||
|
||||
|
||||
class SessionCheck(ServerCheck):
|
||||
|
||||
class CheckForSessionMiddleware(ServerCheck):
|
||||
SESSION_MIDDLEWARE_NOT_FOUND = (
|
||||
"Session middleware is not enabled. It must be listed above"
|
||||
"CSRFMiddleware in the middleware list."
|
||||
)
|
||||
|
||||
def check(self) -> Optional[Exception]:
|
||||
if (
|
||||
"spiderweb.middleware.sessions.SessionMiddleware"
|
||||
not in self.server._middleware
|
||||
):
|
||||
return ConfigError(self.SESSION_MIDDLEWARE_NOT_FOUND)
|
||||
|
||||
|
||||
class VerifyCorrectMiddlewarePlacement(ServerCheck):
|
||||
SESSION_MIDDLEWARE_BELOW_CSRF = (
|
||||
"SessionMiddleware is enabled, but it must be listed above"
|
||||
"CSRFMiddleware in the middleware list."
|
||||
)
|
||||
|
||||
def check(self):
|
||||
|
||||
def check(self) -> Optional[Exception]:
|
||||
if (
|
||||
"spiderweb.middleware.sessions.SessionMiddleware"
|
||||
not in self.server._middleware
|
||||
):
|
||||
raise ConfigError(self.SESSION_MIDDLEWARE_NOT_FOUND)
|
||||
# this is handled by CheckForSessionMiddleware
|
||||
return
|
||||
|
||||
if self.server._middleware.index(
|
||||
"spiderweb.middleware.sessions.SessionMiddleware"
|
||||
) > self.server._middleware.index(
|
||||
"spiderweb.middleware.csrf.CSRFMiddleware"
|
||||
):
|
||||
raise ConfigError(self.SESSION_MIDDLEWARE_BELOW_CSRF)
|
||||
) > self.server._middleware.index("spiderweb.middleware.csrf.CSRFMiddleware"):
|
||||
return ConfigError(self.SESSION_MIDDLEWARE_BELOW_CSRF)
|
||||
|
||||
|
||||
class VerifyCorrectFormatForTrustedOrigins(ServerCheck):
|
||||
CSRF_TRUSTED_ORIGINS_IS_LIST_OF_STR = (
|
||||
"The csrf_trusted_origins setting must be a list of strings."
|
||||
)
|
||||
|
||||
def check(self) -> Optional[Exception]:
|
||||
if not isinstance(self.server.csrf_trusted_origins, list):
|
||||
return ConfigError(self.CSRF_TRUSTED_ORIGINS_IS_LIST_OF_STR)
|
||||
|
||||
for item in self.server.csrf_trusted_origins:
|
||||
if not isinstance(item, Pattern):
|
||||
# It's a pattern here because we've already manipulated it
|
||||
# by the time this check runs
|
||||
return ConfigError(self.CSRF_TRUSTED_ORIGINS_IS_LIST_OF_STR)
|
||||
|
||||
|
||||
class CSRFMiddleware(SpiderwebMiddleware):
|
||||
|
||||
checks = [SessionCheck]
|
||||
checks = [
|
||||
CheckForSessionMiddleware,
|
||||
VerifyCorrectMiddlewarePlacement,
|
||||
VerifyCorrectFormatForTrustedOrigins,
|
||||
]
|
||||
|
||||
CSRF_EXPIRY = 60 * 60 # 1 hour
|
||||
|
||||
def process_request(self, request: Request) -> HttpResponse | None:
|
||||
if request.method == "POST":
|
||||
trusted_origin = False
|
||||
if hasattr(request.handler, "csrf_exempt"):
|
||||
if request.handler.csrf_exempt is True:
|
||||
return
|
||||
if origin := request.headers.get("http_origin"):
|
||||
|
||||
for re_origin in self.server.csrf_trusted_origins:
|
||||
if re.match(re_origin, origin):
|
||||
trusted_origin = True
|
||||
|
||||
csrf_token = (
|
||||
request.headers.get("X-CSRF-TOKEN")
|
||||
or request.GET.get("csrf_token")
|
||||
or request.POST.get("csrf_token")
|
||||
)
|
||||
|
||||
if not trusted_origin:
|
||||
if self.is_csrf_valid(request, csrf_token):
|
||||
return None
|
||||
else:
|
||||
|
Loading…
Reference in New Issue
Block a user