spiderweb/spiderweb/middleware/csrf.py

140 lines
4.9 KiB
Python

import re
from re import Pattern
from datetime import datetime, timedelta
from typing import Optional
from spiderweb.exceptions import CSRFError, ConfigError
from spiderweb.middleware import SpiderwebMiddleware
from spiderweb.request import Request
from spiderweb.response import HttpResponse
from spiderweb.server_checks import ServerCheck
class CheckForSessionMiddleware(ServerCheck):
SESSION_MIDDLEWARE_NOT_FOUND = (
"Session middleware is not enabled. It must be listed above"
"CSRFMiddleware in the middleware list. Add"
" 'spiderweb.middleware.sessions.SessionMiddleware' to your"
" `middleware` list."
)
def check(self) -> Optional[Exception]:
if (
"spiderweb.middleware.sessions.SessionMiddleware"
not in self.server._middleware
):
return ConfigError(self.SESSION_MIDDLEWARE_NOT_FOUND)
class VerifyCorrectMiddlewarePlacement(ServerCheck):
SESSION_MIDDLEWARE_BELOW_CSRF = (
"Session middleware is enabled, but must be listed above"
" CSRFMiddleware in the middleware list."
)
def check(self) -> Optional[Exception]:
if (
"spiderweb.middleware.sessions.SessionMiddleware"
not in self.server._middleware
):
# this is handled by CheckForSessionMiddleware
return
if self.server._middleware.index(
"spiderweb.middleware.sessions.SessionMiddleware"
) > self.server._middleware.index("spiderweb.middleware.csrf.CSRFMiddleware"):
return ConfigError(self.SESSION_MIDDLEWARE_BELOW_CSRF)
class VerifyCorrectFormatForTrustedOrigins(ServerCheck):
CSRF_TRUSTED_ORIGINS_IS_LIST_OF_STR = (
"The csrf_trusted_origins setting must be a list of strings."
)
def check(self) -> Optional[Exception]:
if not isinstance(self.server.csrf_trusted_origins, list):
return ConfigError(self.CSRF_TRUSTED_ORIGINS_IS_LIST_OF_STR)
for item in self.server.csrf_trusted_origins:
if not isinstance(item, Pattern):
# It's a pattern here because we've already manipulated it
# by the time this check runs
return ConfigError(self.CSRF_TRUSTED_ORIGINS_IS_LIST_OF_STR)
class CSRFMiddleware(SpiderwebMiddleware):
checks = [
CheckForSessionMiddleware,
VerifyCorrectMiddlewarePlacement,
VerifyCorrectFormatForTrustedOrigins,
]
CSRF_EXPIRY = 60 * 60 # 1 hour
def is_trusted_origin(self, request) -> bool:
origin = request.headers.get("http_origin")
referrer = request.headers.get("http_referer") or request.headers.get(
"http_referrer"
)
host = request.headers.get("http_host")
if not origin and not (host == referrer):
return False
if not origin and (host == referrer):
origin = host
for re_origin in self.server.csrf_trusted_origins:
if re.match(re_origin, origin):
return True
return False
def process_request(self, request: Request) -> HttpResponse | None:
if request.method == "POST":
if hasattr(request.handler, "csrf_exempt"):
if request.handler.csrf_exempt is True:
return
csrf_token = (
request.headers.get("x-csrf-token")
or request.GET.get("csrf_token")
or request.POST.get("csrf_token")
)
if not self.is_trusted_origin(request):
if self.is_csrf_valid(request, csrf_token):
return None
else:
raise CSRFError()
return None
def process_response(self, request: Request, response: HttpResponse) -> None:
token = self.get_csrf_token(request)
# do we need it in both places?
response.headers["x-csrf-token"] = token
response.context |= {
"csrf_token": f"""<input type="hidden" name="csrf_token" value="{token}">""",
"raw_csrf_token": token, # in case they want to format it themselves
}
def get_csrf_token(self, request):
# the session key should be here because we've processed the session first
session_key = request._session["id"]
return self.server.encrypt(
f"{str(datetime.now().isoformat())}::{session_key}"
).decode(self.server.DEFAULT_ENCODING)
def is_csrf_valid(self, request, key):
try:
decoded = self.server.decrypt(key)
timestamp, session_key = decoded.split("::")
if session_key != request._session["id"]:
return False
if datetime.now() - timedelta(
seconds=self.CSRF_EXPIRY
) > datetime.fromisoformat(timestamp):
return False
return True
except Exception:
return False