Compare commits

..

No commits in common. "15a94b9879e34eb50cdab635e1cf03ea96a45c2c" and "4c4bd153be2de7febeb0930651a147447296f081" have entirely different histories.

12 changed files with 50 additions and 371 deletions

View File

@ -1,22 +1,5 @@
# spiderweb # spiderweb
<p align="center">
<img
src="https://img.shields.io/pypi/v/spiderweb-framework.svg?style=for-the-badge"
alt="PyPI release version for Spiderweb"
/>
<a href="https://gitmoji.dev">
<img
src="https://img.shields.io/badge/gitmoji-%20😜%20😍-FFDD67.svg?style=for-the-badge"
alt="Gitmoji"
/>
</a>
<img
src="https://img.shields.io/badge/code%20style-black-000000.svg?style=for-the-badge"
alt="Code style: Black"
/>
</p>
As a professional web developer focusing on arcane uses of Django for arcane purposes, it occurred to me a little while ago that I didn't actually know how a web framework _worked_. As a professional web developer focusing on arcane uses of Django for arcane purposes, it occurred to me a little while ago that I didn't actually know how a web framework _worked_.
So I built one. So I built one.

View File

@ -11,6 +11,9 @@ app = SpiderwebRouter(
) )
``` ```
> [!DANGER]
> The CSRFMiddleware is incomplete at best and dangerous at worst. I am not a security expert, and my implementation is [very susceptible to the thing it is meant to prevent](https://en.wikipedia.org/wiki/Cross-site_request_forgery). While this is an big issue (and moderately hilarious), the middleware is still provided to you in its unfinished state. Be aware.
Cross-site request forgery, put simply, is a method for attackers to make legitimate-looking requests in your name to a service or system that you've previously authenticated to. Ways that we can protect against this involve aggressively expiring session cookies, special IDs for forms that are keyed to a specific user, and more. Cross-site request forgery, put simply, is a method for attackers to make legitimate-looking requests in your name to a service or system that you've previously authenticated to. Ways that we can protect against this involve aggressively expiring session cookies, special IDs for forms that are keyed to a specific user, and more.
> [!TIP] > [!TIP]

View File

@ -15,7 +15,6 @@ from spiderweb.response import (
app = SpiderwebRouter( app = SpiderwebRouter(
templates_dirs=["templates"], templates_dirs=["templates"],
middleware=[ middleware=[
"spiderweb.middleware.cors.CorsMiddleware",
"spiderweb.middleware.sessions.SessionMiddleware", "spiderweb.middleware.sessions.SessionMiddleware",
"spiderweb.middleware.csrf.CSRFMiddleware", "spiderweb.middleware.csrf.CSRFMiddleware",
"example_middleware.TestMiddleware", "example_middleware.TestMiddleware",

View File

@ -8,20 +8,3 @@ __version__ = "0.12.0"
REGEX_COOKIE_NAME = r"^[a-zA-Z0-9\s\(\)<>@,;:\/\\\[\]\?=\{\}\"\t]*$" REGEX_COOKIE_NAME = r"^[a-zA-Z0-9\s\(\)<>@,;:\/\\\[\]\?=\{\}\"\t]*$"
DATABASE_PROXY = DatabaseProxy() DATABASE_PROXY = DatabaseProxy()
DEFAULT_CORS_ALLOW_METHODS = (
"DELETE",
"GET",
"OPTIONS",
"PATCH",
"POST",
"PUT",
)
DEFAULT_CORS_ALLOW_HEADERS = (
"accept",
"authorization",
"content-type",
"user-agent",
"x-csrftoken",
"x-requested-with",
)

View File

@ -86,7 +86,3 @@ class UnusedMiddleware(SpiderwebException):
class NoResponseError(SpiderwebException): class NoResponseError(SpiderwebException):
pass pass
class StartupErrors(ExceptionGroup):
pass

View File

@ -1,22 +1,16 @@
import inspect import inspect
import logging import logging
import pathlib import pathlib
import re
import traceback import traceback
import urllib.parse as urlparse import urllib.parse as urlparse
from logging import Logger
from threading import Thread from threading import Thread
from typing import Optional, Callable, Sequence, LiteralString, Literal from typing import Optional, Callable
from wsgiref.simple_server import WSGIServer from wsgiref.simple_server import WSGIServer
from jinja2 import BaseLoader, Environment, FileSystemLoader from jinja2 import BaseLoader, Environment, FileSystemLoader
from peewee import Database, SqliteDatabase from peewee import Database, SqliteDatabase
from spiderweb.middleware import MiddlewareMixin from spiderweb.middleware import MiddlewareMixin
from spiderweb.constants import (
DEFAULT_CORS_ALLOW_METHODS,
DEFAULT_CORS_ALLOW_HEADERS,
)
from spiderweb.constants import ( from spiderweb.constants import (
DATABASE_PROXY, DATABASE_PROXY,
DEFAULT_ENCODING, DEFAULT_ENCODING,
@ -36,7 +30,7 @@ from spiderweb.request import Request
from spiderweb.response import HttpResponse, TemplateResponse, JsonResponse from spiderweb.response import HttpResponse, TemplateResponse, JsonResponse
from spiderweb.routes import RoutesMixin from spiderweb.routes import RoutesMixin
from spiderweb.secrets import FernetMixin from spiderweb.secrets import FernetMixin
from spiderweb.utils import get_http_status_by_code, convert_url_to_regex from spiderweb.utils import get_http_status_by_code
console_logger = logging.getLogger(__name__) console_logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
@ -48,32 +42,24 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi
*, *,
addr: str = None, addr: str = None,
port: int = None, port: int = None,
allowed_hosts: Sequence[str | re.Pattern] = None, allowed_hosts=None,
cors_allowed_origins: Sequence[str] = None, cors_allowed_origins=None,
cors_allowed_origins_regexes: Sequence[str] = None, cors_allow_all_origins=False,
cors_allow_all_origins: bool = False,
cors_urls_regex: str | re.Pattern[str] = r"^.*$",
cors_allow_methods: Sequence[str] = None,
cors_allow_headers: Sequence[str] = None,
cors_expose_headers: Sequence[str] = None,
cors_preflight_max_age: int = 86400,
cors_allow_credentials: bool = False,
csrf_trusted_origins: Sequence[str] = None,
db: Optional[Database] = None, db: Optional[Database] = None,
templates_dirs: Sequence[str] = None, templates_dirs: list[str] = None,
middleware: Sequence[str] = None, middleware: list[str] = None,
append_slash: bool = False, append_slash: bool = False,
staticfiles_dirs: Sequence[str] = None, staticfiles_dirs: list[str] = None,
routes: Sequence[tuple[str, Callable] | tuple[str, Callable, dict]] = None, routes: list[tuple[str, Callable] | tuple[str, Callable, dict]] = None,
error_routes: dict[int, Callable] = None, error_routes: dict[int, Callable] = None,
secret_key: str = None, secret_key: str = None,
session_max_age: int = 60 * 60 * 24 * 14, # 2 weeks session_max_age=60 * 60 * 24 * 14, # 2 weeks
session_cookie_name: str = "swsession", session_cookie_name="swsession",
session_cookie_secure: bool = False, # should be true if serving over HTTPS session_cookie_secure=False, # should be true if serving over HTTPS
session_cookie_http_only: bool = True, session_cookie_http_only=True,
session_cookie_same_site: Literal["strict", "lax", "none"] = "lax", session_cookie_same_site="lax",
session_cookie_path: str = "/", session_cookie_path="/",
log: Logger = None, log=None,
**kwargs, **kwargs,
): ):
self._routes = {} self._routes = {}
@ -89,23 +75,10 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi
self._middleware: list[str] = middleware or [] self._middleware: list[str] = middleware or []
self.middleware: list[Callable] = [] self.middleware: list[Callable] = []
self.secret_key = secret_key if secret_key else self.generate_key() self.secret_key = secret_key if secret_key else self.generate_key()
self._allowed_hosts = allowed_hosts or ["*"] self.allowed_hosts = allowed_hosts or ["*"]
self.allowed_hosts = [convert_url_to_regex(i) for i in self._allowed_hosts]
self.cors_allowed_origins = cors_allowed_origins or [] self.cors_allowed_origins = cors_allowed_origins or []
self.cors_allowed_origins_regexes = cors_allowed_origins_regexes or []
self.cors_allow_all_origins = cors_allow_all_origins self.cors_allow_all_origins = cors_allow_all_origins
self.cors_urls_regex = cors_urls_regex
self.cors_allow_methods = cors_allow_methods or DEFAULT_CORS_ALLOW_METHODS
self.cors_allow_headers = cors_allow_headers or DEFAULT_CORS_ALLOW_HEADERS
self.cors_expose_headers = cors_expose_headers or []
self.cors_preflight_max_age = cors_preflight_max_age
self.cors_allow_credentials = cors_allow_credentials
self._csrf_trusted_origins = csrf_trusted_origins or []
self.csrf_trusted_origins = [
convert_url_to_regex(i) for i in self._csrf_trusted_origins
]
self.extra_data = kwargs self.extra_data = kwargs
@ -181,6 +154,7 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi
for v in varies: for v in varies:
headers.append(("Vary", v)) headers.append(("Vary", v))
start_response(status, headers) start_response(status, headers)
rendered_output = resp.render() rendered_output = resp.render()
@ -257,15 +231,6 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi
start_response, request, self.get_error_route(500)(request) start_response, request, self.get_error_route(500)(request)
) )
def check_valid_host(self, request) -> bool:
host = request.headers.get("http_host")
if not host:
return False
for option in self.allowed_hosts:
if re.match(option, host):
return True
return False
def __call__(self, environ, start_response, *args, **kwargs): def __call__(self, environ, start_response, *args, **kwargs):
"""Entry point for WSGI apps.""" """Entry point for WSGI apps."""
request = self.get_request(environ) request = self.get_request(environ)
@ -282,9 +247,6 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi
# replace the potentially valid handler with the error route # replace the potentially valid handler with the error route
handler = self.get_error_route(405) handler = self.get_error_route(405)
if not self.check_valid_host(request):
handler = self.get_error_route(403)
if request.is_form_request(): if request.is_form_request():
form_data = urlparse.parse_qs(request.content) form_data = urlparse.parse_qs(request.content)
for key, value in form_data.items(): for key, value in form_data.items():

View File

@ -1,11 +1,9 @@
from typing import Callable, ClassVar from typing import Callable, ClassVar
import sys
from .base import SpiderwebMiddleware as SpiderwebMiddleware from .base import SpiderwebMiddleware as SpiderwebMiddleware
from .cors import CorsMiddleware as CorsMiddleware
from .csrf import CSRFMiddleware as CSRFMiddleware from .csrf import CSRFMiddleware as CSRFMiddleware
from .sessions import SessionMiddleware as SessionMiddleware from .sessions import SessionMiddleware as SessionMiddleware
from ..exceptions import ConfigError, UnusedMiddleware, StartupErrors from ..exceptions import ConfigError, UnusedMiddleware
from ..request import Request from ..request import Request
from ..response import HttpResponse from ..response import HttpResponse
from ..utils import import_by_string from ..utils import import_by_string
@ -29,19 +27,10 @@ class MiddlewareMixin:
self.middleware = middleware_by_reference self.middleware = middleware_by_reference
def run_middleware_checks(self): def run_middleware_checks(self):
errors = []
for middleware in self.middleware: for middleware in self.middleware:
if hasattr(middleware, "checks"): if hasattr(middleware, "checks"):
for check in middleware.checks: for check in middleware.checks:
if issue := check(server=self).check(): check(server=self).check()
errors.append(issue)
if errors:
# just show the messages
sys.tracebacklimit = 0
raise StartupErrors(
"Problems were identified during startup — cannot continue.", errors
)
def process_request_middleware(self, request: Request) -> None | bool: def process_request_middleware(self, request: Request) -> None | bool:
for middleware in self.middleware: for middleware in self.middleware:

View File

@ -1,137 +1 @@
import re # https://gist.github.com/FND/204ba41bf6ae485965ef
from urllib.parse import urlsplit, SplitResult
from spiderweb.request import Request
from spiderweb.response import HttpResponse
from spiderweb.middleware import SpiderwebMiddleware
ACCESS_CONTROL_ALLOW_ORIGIN = "access-control-allow-origin"
ACCESS_CONTROL_EXPOSE_HEADERS = "access-control-expose-headers"
ACCESS_CONTROL_ALLOW_CREDENTIALS = "access-control-allow-credentials"
ACCESS_CONTROL_ALLOW_HEADERS = "access-control-allow-headers"
ACCESS_CONTROL_ALLOW_METHODS = "access-control-allow-methods"
ACCESS_CONTROL_MAX_AGE = "access-control-max-age"
ACCESS_CONTROL_REQUEST_PRIVATE_NETWORK = "access-control-request-private-network"
ACCESS_CONTROL_ALLOW_PRIVATE_NETWORK = "access-control-allow-private-network"
class CorsMiddleware(SpiderwebMiddleware):
# heavily 'based' on https://github.com/adamchainz/django-cors-headers,
# which is provided under the MIT license. This is essentially a direct
# port, since django-cors-headers is battle-tested code that has been
# around for a long time and it works well. Shoutouts to Otto, Adam, and
# crew for helping make this a complete non-issue in Django for a very long
# time.
def is_enabled(self, request: Request):
return bool(re.match(self.server.cors_urls_regex, request.path))
def add_response_headers(self, request: Request, response: HttpResponse):
enabled = getattr(request, "_cors_enabled", None)
if enabled is None:
enabled = self.is_enabled(request)
if not enabled:
return response
if "vary" in response.headers:
response.headers["vary"].append("origin")
else:
response.headers["vary"] = ["origin"]
origin = request.headers.get("origin")
if not origin:
return response
try:
url = urlsplit(origin)
except ValueError:
return response
if (
not self.server.cors_allow_all_origins
and not self.origin_found_in_allow_lists(origin, url)
):
return response
if (
self.server.cors_allow_all_origins
and not self.server.cors_allow_credentials
):
response.headers[ACCESS_CONTROL_ALLOW_ORIGIN] = "*"
else:
response.headers[ACCESS_CONTROL_ALLOW_ORIGIN] = origin
if self.server.cors_allow_credentials:
response.headers[ACCESS_CONTROL_ALLOW_CREDENTIALS] = "true"
if len(self.server.cors_expose_headers):
response.headers[ACCESS_CONTROL_EXPOSE_HEADERS] = ", ".join(
self.server.cors_expose_headers
)
if request.method == "OPTIONS":
response.headers[ACCESS_CONTROL_ALLOW_HEADERS] = ", ".join(
self.server.cors_allow_headers
)
response.headers[ACCESS_CONTROL_ALLOW_METHODS] = ", ".join(
self.server.cors_allow_methods
)
if self.server.cors_preflight_max_age:
response.headers[ACCESS_CONTROL_MAX_AGE] = str(
self.server.cors_preflight_max_age
)
if (
self.server.cors_allow_private_network
and request.headers.get(ACCESS_CONTROL_REQUEST_PRIVATE_NETWORK) == "true"
):
response.headers[ACCESS_CONTROL_ALLOW_PRIVATE_NETWORK] = "true"
return response
def origin_found_in_allow_lists(self, origin: str, url: SplitResult) -> bool:
return (
(origin == "null" and origin in self.server.cors_allowed_origins)
or self._url_in_allowlist(url)
or self.regex_domain_match(origin)
)
def _url_in_allowlist(self, url: SplitResult) -> bool:
origins = [urlsplit(o) for o in self.server.cors_allowed_origins]
return any(
origin.scheme == url.scheme and origin.netloc == url.netloc
for origin in origins
)
def regex_domain_match(self, origin: str) -> bool:
return any(
re.match(domain_pattern, origin)
for domain_pattern in self.server.cors_allowed_origin_regexes
)
def process_request(self, request: Request) -> HttpResponse | None:
# Identify and handle a preflight request
# origin = request.META.get("HTTP_ORIGIN")
request._cors_enabled = self.is_enabled(request)
if (
request._cors_enabled
and request.method == "OPTIONS"
and "access-control-request-method" in request.headers
):
# this should be 204, but according to mozilla, not all browsers
# parse that correctly. See [204] comment below.
resp = HttpResponse(
"",
status_code=200,
headers={"content-type": "text/plain", "content-length": 0},
)
self.add_response_headers(request, resp)
return resp
def process_response(
self, request: Request, response: HttpResponse
) -> None:
self.add_response_headers(request, response)
# [204]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods/OPTIONS#status_code

View File

@ -1,7 +1,4 @@
import re
from re import Pattern
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import Optional
from spiderweb.exceptions import CSRFError, ConfigError from spiderweb.exceptions import CSRFError, ConfigError
from spiderweb.middleware import SpiderwebMiddleware from spiderweb.middleware import SpiderwebMiddleware
@ -10,89 +7,53 @@ from spiderweb.response import HttpResponse
from spiderweb.server_checks import ServerCheck from spiderweb.server_checks import ServerCheck
class CheckForSessionMiddleware(ServerCheck): class SessionCheck(ServerCheck):
SESSION_MIDDLEWARE_NOT_FOUND = ( SESSION_MIDDLEWARE_NOT_FOUND = (
"Session middleware is not enabled. It must be listed above" "Session middleware is not enabled. It must be listed above"
"CSRFMiddleware in the middleware list." "CSRFMiddleware in the middleware list."
) )
def check(self) -> Optional[Exception]:
if (
"spiderweb.middleware.sessions.SessionMiddleware"
not in self.server._middleware
):
return ConfigError(self.SESSION_MIDDLEWARE_NOT_FOUND)
class VerifyCorrectMiddlewarePlacement(ServerCheck):
SESSION_MIDDLEWARE_BELOW_CSRF = ( SESSION_MIDDLEWARE_BELOW_CSRF = (
"SessionMiddleware is enabled, but it must be listed above" "SessionMiddleware is enabled, but it must be listed above"
"CSRFMiddleware in the middleware list." "CSRFMiddleware in the middleware list."
) )
def check(self) -> Optional[Exception]: def check(self):
if ( if (
"spiderweb.middleware.sessions.SessionMiddleware" "spiderweb.middleware.sessions.SessionMiddleware"
not in self.server._middleware not in self.server._middleware
): ):
# this is handled by CheckForSessionMiddleware raise ConfigError(self.SESSION_MIDDLEWARE_NOT_FOUND)
return
if self.server._middleware.index( if self.server._middleware.index(
"spiderweb.middleware.sessions.SessionMiddleware" "spiderweb.middleware.sessions.SessionMiddleware"
) > self.server._middleware.index("spiderweb.middleware.csrf.CSRFMiddleware"): ) > self.server._middleware.index(
return ConfigError(self.SESSION_MIDDLEWARE_BELOW_CSRF) "spiderweb.middleware.csrf.CSRFMiddleware"
):
raise ConfigError(self.SESSION_MIDDLEWARE_BELOW_CSRF)
class VerifyCorrectFormatForTrustedOrigins(ServerCheck):
CSRF_TRUSTED_ORIGINS_IS_LIST_OF_STR = (
"The csrf_trusted_origins setting must be a list of strings."
)
def check(self) -> Optional[Exception]:
if not isinstance(self.server.csrf_trusted_origins, list):
return ConfigError(self.CSRF_TRUSTED_ORIGINS_IS_LIST_OF_STR)
for item in self.server.csrf_trusted_origins:
if not isinstance(item, Pattern):
# It's a pattern here because we've already manipulated it
# by the time this check runs
return ConfigError(self.CSRF_TRUSTED_ORIGINS_IS_LIST_OF_STR)
class CSRFMiddleware(SpiderwebMiddleware): class CSRFMiddleware(SpiderwebMiddleware):
checks = [ checks = [SessionCheck]
CheckForSessionMiddleware,
VerifyCorrectMiddlewarePlacement,
VerifyCorrectFormatForTrustedOrigins,
]
CSRF_EXPIRY = 60 * 60 # 1 hour CSRF_EXPIRY = 60 * 60 # 1 hour
def process_request(self, request: Request) -> HttpResponse | None: def process_request(self, request: Request) -> HttpResponse | None:
if request.method == "POST": if request.method == "POST":
trusted_origin = False
if hasattr(request.handler, "csrf_exempt"): if hasattr(request.handler, "csrf_exempt"):
if request.handler.csrf_exempt is True: if request.handler.csrf_exempt is True:
return return
if origin := request.headers.get("http_origin"):
for re_origin in self.server.csrf_trusted_origins:
if re.match(re_origin, origin):
trusted_origin = True
csrf_token = ( csrf_token = (
request.headers.get("X-CSRF-TOKEN") request.headers.get("X-CSRF-TOKEN")
or request.GET.get("csrf_token") or request.GET.get("csrf_token")
or request.POST.get("csrf_token") or request.POST.get("csrf_token")
) )
if self.is_csrf_valid(request, csrf_token):
if not trusted_origin: return None
if self.is_csrf_valid(request, csrf_token): else:
return None raise CSRFError()
else:
raise CSRFError()
return None return None
def process_response(self, request: Request, response: HttpResponse) -> None: def process_response(self, request: Request, response: HttpResponse) -> None:

View File

@ -1,5 +1,5 @@
import re import re
from typing import Callable, Any, Optional, Sequence from typing import Callable, Any, Optional
from spiderweb.constants import DEFAULT_ALLOWED_METHODS from spiderweb.constants import DEFAULT_ALLOWED_METHODS
from spiderweb.converters import * # noqa: F403 from spiderweb.converters import * # noqa: F403
@ -30,7 +30,7 @@ class RoutesMixin:
# ones that start with underscores are the compiled versions, non-underscores # ones that start with underscores are the compiled versions, non-underscores
# are the user-supplied versions # are the user-supplied versions
_routes: dict _routes: dict
routes: Sequence[tuple[str, Callable] | tuple[str, Callable, dict]] routes: list[tuple[str, Callable] | tuple[str, Callable, dict]] = (None,)
_error_routes: dict _error_routes: dict
error_routes: dict[int, Callable] error_routes: dict[int, Callable]
append_slash: bool append_slash: bool

View File

@ -4,16 +4,12 @@ from datetime import timedelta
import pytest import pytest
from peewee import SqliteDatabase from peewee import SqliteDatabase
from spiderweb import SpiderwebRouter, HttpResponse, ConfigError, StartupErrors from spiderweb import SpiderwebRouter, HttpResponse, ConfigError
from spiderweb.constants import DEFAULT_ENCODING from spiderweb.constants import DEFAULT_ENCODING
from spiderweb.middleware.sessions import Session from spiderweb.middleware.sessions import Session
from spiderweb.middleware import csrf from spiderweb.middleware import csrf
from spiderweb.tests.helpers import setup from spiderweb.tests.helpers import setup
from spiderweb.tests.views_for_tests import ( from spiderweb.tests.views_for_tests import form_view_with_csrf, form_csrf_exempt, form_view_without_csrf
form_view_with_csrf,
form_csrf_exempt,
form_view_without_csrf,
)
# app = SpiderwebRouter( # app = SpiderwebRouter(
@ -103,21 +99,18 @@ def test_exploding_middleware():
def test_csrf_middleware_without_session_middleware(): def test_csrf_middleware_without_session_middleware():
_, environ, start_response = setup() _, environ, start_response = setup()
with pytest.raises(StartupErrors) as e: with pytest.raises(ConfigError) as e:
SpiderwebRouter( SpiderwebRouter(
middleware=["spiderweb.middleware.csrf.CSRFMiddleware"], middleware=["spiderweb.middleware.csrf.CSRFMiddleware"],
db=SqliteDatabase("spiderweb-tests.db"), db=SqliteDatabase("spiderweb-tests.db"),
) )
exceptiongroup = e.value.args[1]
assert ( assert e.value.args[0] == csrf.SessionCheck.SESSION_MIDDLEWARE_NOT_FOUND
exceptiongroup[0].args[0]
== csrf.CheckForSessionMiddleware.SESSION_MIDDLEWARE_NOT_FOUND
)
def test_csrf_middleware_above_session_middleware(): def test_csrf_middleware_above_session_middleware():
_, environ, start_response = setup() _, environ, start_response = setup()
with pytest.raises(StartupErrors) as e: with pytest.raises(ConfigError) as e:
SpiderwebRouter( SpiderwebRouter(
middleware=[ middleware=[
"spiderweb.middleware.csrf.CSRFMiddleware", "spiderweb.middleware.csrf.CSRFMiddleware",
@ -125,11 +118,8 @@ def test_csrf_middleware_above_session_middleware():
], ],
db=SqliteDatabase("spiderweb-tests.db"), db=SqliteDatabase("spiderweb-tests.db"),
) )
exceptiongroup = e.value.args[1]
assert ( assert e.value.args[0] == csrf.SessionCheck.SESSION_MIDDLEWARE_BELOW_CSRF
exceptiongroup[0].args[0]
== csrf.VerifyCorrectMiddlewarePlacement.SESSION_MIDDLEWARE_BELOW_CSRF
)
def test_csrf_middleware(): def test_csrf_middleware():
@ -221,7 +211,6 @@ def test_csrf_expired_token():
f"swsession={[i for i in Session.select().dicts()][-1]['session_key']}" f"swsession={[i for i in Session.select().dicts()][-1]['session_key']}"
) )
environ["REQUEST_METHOD"] = "POST" environ["REQUEST_METHOD"] = "POST"
environ["HTTP_ORIGIN"] = "example.com"
environ["HTTP_X_CSRF_TOKEN"] = token environ["HTTP_X_CSRF_TOKEN"] = token
environ["CONTENT_LENGTH"] = len(formdata) environ["CONTENT_LENGTH"] = len(formdata)
@ -265,44 +254,3 @@ def test_csrf_exempt():
environ["PATH_INFO"] = "/2" environ["PATH_INFO"] = "/2"
resp2 = app(environ, start_response)[0].decode(DEFAULT_ENCODING) resp2 = app(environ, start_response)[0].decode(DEFAULT_ENCODING)
assert "CSRF token is invalid" in resp2 assert "CSRF token is invalid" in resp2
def test_csrf_trusted_origins():
_, environ, start_response = setup()
app = SpiderwebRouter(
middleware=[
"spiderweb.middleware.sessions.SessionMiddleware",
"spiderweb.middleware.csrf.CSRFMiddleware",
],
csrf_trusted_origins=[
"example.com",
],
db=SqliteDatabase("spiderweb-tests.db"),
)
app.add_route("/", form_view_without_csrf, ["GET", "POST"])
environ["HTTP_USER_AGENT"] = "hi"
environ["REMOTE_ADDR"] = "1.1.1.1"
environ["CONTENT_TYPE"] = "application/x-www-form-urlencoded"
environ["REQUEST_METHOD"] = "POST"
formdata = "name=bob"
environ["CONTENT_LENGTH"] = len(formdata)
b_handle = BytesIO()
b_handle.write(formdata.encode(DEFAULT_ENCODING))
b_handle.seek(0)
environ["wsgi.input"] = BufferedReader(b_handle)
environ["HTTP_ORIGIN"] = "notvalid.com"
resp = app(environ, start_response)[0].decode(DEFAULT_ENCODING)
assert "CSRF token is invalid" in resp
b_handle = BytesIO()
b_handle.write(formdata.encode(DEFAULT_ENCODING))
b_handle.seek(0)
environ["wsgi.input"] = BufferedReader(b_handle)
environ["HTTP_ORIGIN"] = "example.com"
resp2 = app(environ, start_response)[0].decode(DEFAULT_ENCODING)
assert resp2 == '{"name": "bob"}'

View File

@ -1,5 +1,4 @@
import json import json
import re
import secrets import secrets
import string import string
from http import HTTPStatus from http import HTTPStatus
@ -77,13 +76,5 @@ class Headers(dict):
def get(self, key, default=None): def get(self, key, default=None):
return super().get(key.lower(), default) return super().get(key.lower(), default)
def setdefault(self, key, default=None): def setdefault(self, key, default = None):
return super().setdefault(key.lower(), default) return super().setdefault(key.lower(), default)
def convert_url_to_regex(url: str | re.Pattern) -> re.Pattern:
if isinstance(url, re.Pattern):
return url
url = url.replace(".", "\\.")
url = url.replace("*", ".+")
return re.compile(url)