CORS middleware!

This commit is contained in:
Joe Kaufeld 2024-09-01 21:05:43 -04:00
parent 9330918009
commit 15a94b9879
9 changed files with 282 additions and 31 deletions

View File

@ -15,6 +15,7 @@ from spiderweb.response import (
app = SpiderwebRouter( app = SpiderwebRouter(
templates_dirs=["templates"], templates_dirs=["templates"],
middleware=[ middleware=[
"spiderweb.middleware.cors.CorsMiddleware",
"spiderweb.middleware.sessions.SessionMiddleware", "spiderweb.middleware.sessions.SessionMiddleware",
"spiderweb.middleware.csrf.CSRFMiddleware", "spiderweb.middleware.csrf.CSRFMiddleware",
"example_middleware.TestMiddleware", "example_middleware.TestMiddleware",

View File

@ -8,3 +8,20 @@ __version__ = "0.12.0"
REGEX_COOKIE_NAME = r"^[a-zA-Z0-9\s\(\)<>@,;:\/\\\[\]\?=\{\}\"\t]*$" REGEX_COOKIE_NAME = r"^[a-zA-Z0-9\s\(\)<>@,;:\/\\\[\]\?=\{\}\"\t]*$"
DATABASE_PROXY = DatabaseProxy() DATABASE_PROXY = DatabaseProxy()
# Fallback for SpiderwebRouter's ``cors_allow_methods`` argument; the CORS
# middleware joins these into the ``access-control-allow-methods`` header.
DEFAULT_CORS_ALLOW_METHODS = ("DELETE", "GET", "OPTIONS", "PATCH", "POST", "PUT")

# Fallback for SpiderwebRouter's ``cors_allow_headers`` argument; the CORS
# middleware joins these into the ``access-control-allow-headers`` header.
DEFAULT_CORS_ALLOW_HEADERS = (
    "accept",
    "authorization",
    "content-type",
    "user-agent",
    "x-csrftoken",
    "x-requested-with",
)

View File

@ -86,3 +86,7 @@ class UnusedMiddleware(SpiderwebException):
class NoResponseError(SpiderwebException): class NoResponseError(SpiderwebException):
pass pass
class StartupErrors(ExceptionGroup):
pass

View File

@ -1,16 +1,22 @@
import inspect import inspect
import logging import logging
import pathlib import pathlib
import re
import traceback import traceback
import urllib.parse as urlparse import urllib.parse as urlparse
from logging import Logger
from threading import Thread from threading import Thread
from typing import Optional, Callable from typing import Optional, Callable, Sequence, LiteralString, Literal
from wsgiref.simple_server import WSGIServer from wsgiref.simple_server import WSGIServer
from jinja2 import BaseLoader, Environment, FileSystemLoader from jinja2 import BaseLoader, Environment, FileSystemLoader
from peewee import Database, SqliteDatabase from peewee import Database, SqliteDatabase
from spiderweb.middleware import MiddlewareMixin from spiderweb.middleware import MiddlewareMixin
from spiderweb.constants import (
DEFAULT_CORS_ALLOW_METHODS,
DEFAULT_CORS_ALLOW_HEADERS,
)
from spiderweb.constants import ( from spiderweb.constants import (
DATABASE_PROXY, DATABASE_PROXY,
DEFAULT_ENCODING, DEFAULT_ENCODING,
@ -30,7 +36,7 @@ from spiderweb.request import Request
from spiderweb.response import HttpResponse, TemplateResponse, JsonResponse from spiderweb.response import HttpResponse, TemplateResponse, JsonResponse
from spiderweb.routes import RoutesMixin from spiderweb.routes import RoutesMixin
from spiderweb.secrets import FernetMixin from spiderweb.secrets import FernetMixin
from spiderweb.utils import get_http_status_by_code from spiderweb.utils import get_http_status_by_code, convert_url_to_regex
console_logger = logging.getLogger(__name__) console_logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
@ -42,25 +48,32 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi
*, *,
addr: str = None, addr: str = None,
port: int = None, port: int = None,
allowed_hosts=None, allowed_hosts: Sequence[str | re.Pattern] = None,
cors_allowed_origins=None, cors_allowed_origins: Sequence[str] = None,
cors_allow_all_origins=False, cors_allowed_origins_regexes: Sequence[str] = None,
cors_allow_all_origins: bool = False,
cors_urls_regex: str | re.Pattern[str] = r"^.*$",
cors_allow_methods: Sequence[str] = None,
cors_allow_headers: Sequence[str] = None,
cors_expose_headers: Sequence[str] = None,
cors_preflight_max_age: int = 86400,
cors_allow_credentials: bool = False,
csrf_trusted_origins: Sequence[str] = None, csrf_trusted_origins: Sequence[str] = None,
db: Optional[Database] = None, db: Optional[Database] = None,
templates_dirs: list[str] = None, templates_dirs: Sequence[str] = None,
middleware: list[str] = None, middleware: Sequence[str] = None,
append_slash: bool = False, append_slash: bool = False,
staticfiles_dirs: list[str] = None, staticfiles_dirs: Sequence[str] = None,
routes: list[tuple[str, Callable] | tuple[str, Callable, dict]] = None, routes: Sequence[tuple[str, Callable] | tuple[str, Callable, dict]] = None,
error_routes: dict[int, Callable] = None, error_routes: dict[int, Callable] = None,
secret_key: str = None, secret_key: str = None,
session_max_age=60 * 60 * 24 * 14, # 2 weeks session_max_age: int = 60 * 60 * 24 * 14, # 2 weeks
session_cookie_name="swsession", session_cookie_name: str = "swsession",
session_cookie_secure=False, # should be true if serving over HTTPS session_cookie_secure: bool = False, # should be true if serving over HTTPS
session_cookie_http_only=True, session_cookie_http_only: bool = True,
session_cookie_same_site="lax", session_cookie_same_site: Literal["strict", "lax", "none"] = "lax",
session_cookie_path="/", session_cookie_path: str = "/",
log=None, log: Logger = None,
**kwargs, **kwargs,
): ):
self._routes = {} self._routes = {}
@ -80,7 +93,15 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi
self.allowed_hosts = [convert_url_to_regex(i) for i in self._allowed_hosts] self.allowed_hosts = [convert_url_to_regex(i) for i in self._allowed_hosts]
self.cors_allowed_origins = cors_allowed_origins or [] self.cors_allowed_origins = cors_allowed_origins or []
self.cors_allowed_origins_regexes = cors_allowed_origins_regexes or []
self.cors_allow_all_origins = cors_allow_all_origins self.cors_allow_all_origins = cors_allow_all_origins
self.cors_urls_regex = cors_urls_regex
self.cors_allow_methods = cors_allow_methods or DEFAULT_CORS_ALLOW_METHODS
self.cors_allow_headers = cors_allow_headers or DEFAULT_CORS_ALLOW_HEADERS
self.cors_expose_headers = cors_expose_headers or []
self.cors_preflight_max_age = cors_preflight_max_age
self.cors_allow_credentials = cors_allow_credentials
self._csrf_trusted_origins = csrf_trusted_origins or [] self._csrf_trusted_origins = csrf_trusted_origins or []
self.csrf_trusted_origins = [ self.csrf_trusted_origins = [
convert_url_to_regex(i) for i in self._csrf_trusted_origins convert_url_to_regex(i) for i in self._csrf_trusted_origins

View File

@ -1,9 +1,11 @@
from typing import Callable, ClassVar from typing import Callable, ClassVar
import sys
from .base import SpiderwebMiddleware as SpiderwebMiddleware from .base import SpiderwebMiddleware as SpiderwebMiddleware
from .cors import CorsMiddleware as CorsMiddleware
from .csrf import CSRFMiddleware as CSRFMiddleware from .csrf import CSRFMiddleware as CSRFMiddleware
from .sessions import SessionMiddleware as SessionMiddleware from .sessions import SessionMiddleware as SessionMiddleware
from ..exceptions import ConfigError, UnusedMiddleware from ..exceptions import ConfigError, UnusedMiddleware, StartupErrors
from ..request import Request from ..request import Request
from ..response import HttpResponse from ..response import HttpResponse
from ..utils import import_by_string from ..utils import import_by_string
@ -27,10 +29,19 @@ class MiddlewareMixin:
self.middleware = middleware_by_reference self.middleware = middleware_by_reference
def run_middleware_checks(self):
    """Run every check declared by the loaded middleware.

    Each entry in a middleware's ``checks`` attribute is instantiated with
    ``server=self`` and its ``check()`` method is called; any truthy return
    value is collected as an issue.

    Raises:
        StartupErrors: if at least one check reported an issue. All
            collected issues are attached to the group.
    """
    problems = []
    for middleware in self.middleware:
        # Middleware without a ``checks`` attribute contributes nothing.
        for check_cls in getattr(middleware, "checks", ()):
            issue = check_cls(server=self).check()
            if issue:
                problems.append(issue)

    if not problems:
        return

    # just show the messages
    # NOTE(review): this changes the interpreter-wide traceback limit and is
    # never restored, so it also applies if a caller catches the error.
    sys.tracebacklimit = 0
    raise StartupErrors(
        "Problems were identified during startup — cannot continue.", problems
    )
def process_request_middleware(self, request: Request) -> None | bool: def process_request_middleware(self, request: Request) -> None | bool:
for middleware in self.middleware: for middleware in self.middleware:

View File

@ -1 +1,137 @@
# https://gist.github.com/FND/204ba41bf6ae485965ef import re
from urllib.parse import urlsplit, SplitResult
from spiderweb.request import Request
from spiderweb.response import HttpResponse
from spiderweb.middleware import SpiderwebMiddleware
ACCESS_CONTROL_ALLOW_ORIGIN = "access-control-allow-origin"
ACCESS_CONTROL_EXPOSE_HEADERS = "access-control-expose-headers"
ACCESS_CONTROL_ALLOW_CREDENTIALS = "access-control-allow-credentials"
ACCESS_CONTROL_ALLOW_HEADERS = "access-control-allow-headers"
ACCESS_CONTROL_ALLOW_METHODS = "access-control-allow-methods"
ACCESS_CONTROL_MAX_AGE = "access-control-max-age"
ACCESS_CONTROL_REQUEST_PRIVATE_NETWORK = "access-control-request-private-network"
ACCESS_CONTROL_ALLOW_PRIVATE_NETWORK = "access-control-allow-private-network"
class CorsMiddleware(SpiderwebMiddleware):
    """Add CORS response headers and answer browser preflight requests.

    All configuration is read from the ``cors_*`` attributes of
    ``self.server``.
    """

    # heavily 'based' on https://github.com/adamchainz/django-cors-headers,
    # which is provided under the MIT license. This is essentially a direct
    # port, since django-cors-headers is battle-tested code that has been
    # around for a long time and it works well. Shoutouts to Otto, Adam, and
    # crew for helping make this a complete non-issue in Django for a very long
    # time.

    def is_enabled(self, request: Request) -> bool:
        """Return True when ``request.path`` matches ``cors_urls_regex``."""
        return bool(re.match(self.server.cors_urls_regex, request.path))

    def add_response_headers(
        self, request: Request, response: HttpResponse
    ) -> HttpResponse:
        """Set the CORS headers on ``response`` that apply to ``request``.

        Nothing is added when CORS is disabled for the request path. When it
        is enabled, ``vary: origin`` is always recorded, and the
        ``access-control-*`` headers are added only if the request carries an
        ``origin`` header that is allowed by the server configuration.

        Returns:
            The same ``response`` object, mutated in place.
        """
        # process_request caches the decision on the request; recompute it if
        # this method is reached without that having happened.
        enabled = getattr(request, "_cors_enabled", None)
        if enabled is None:
            enabled = self.is_enabled(request)

        if not enabled:
            return response

        if "vary" in response.headers:
            response.headers["vary"].append("origin")
        else:
            response.headers["vary"] = ["origin"]

        origin = request.headers.get("origin")
        if not origin:
            return response

        try:
            url = urlsplit(origin)
        except ValueError:
            # An unparseable origin can never be on an allow list.
            return response

        if (
            not self.server.cors_allow_all_origins
            and not self.origin_found_in_allow_lists(origin, url)
        ):
            return response

        # The wildcard is only sent when credentials are off; otherwise the
        # requesting origin is echoed back.
        if (
            self.server.cors_allow_all_origins
            and not self.server.cors_allow_credentials
        ):
            response.headers[ACCESS_CONTROL_ALLOW_ORIGIN] = "*"
        else:
            response.headers[ACCESS_CONTROL_ALLOW_ORIGIN] = origin

        if self.server.cors_allow_credentials:
            response.headers[ACCESS_CONTROL_ALLOW_CREDENTIALS] = "true"

        if self.server.cors_expose_headers:
            response.headers[ACCESS_CONTROL_EXPOSE_HEADERS] = ", ".join(
                self.server.cors_expose_headers
            )

        if request.method == "OPTIONS":
            response.headers[ACCESS_CONTROL_ALLOW_HEADERS] = ", ".join(
                self.server.cors_allow_headers
            )
            response.headers[ACCESS_CONTROL_ALLOW_METHODS] = ", ".join(
                self.server.cors_allow_methods
            )
            if self.server.cors_preflight_max_age:
                response.headers[ACCESS_CONTROL_MAX_AGE] = str(
                    self.server.cors_preflight_max_age
                )

        # The router does not define ``cors_allow_private_network``, so read
        # it defensively and treat a missing setting as "not allowed" instead
        # of raising AttributeError on every allowed request.
        if (
            getattr(self.server, "cors_allow_private_network", False)
            and request.headers.get(ACCESS_CONTROL_REQUEST_PRIVATE_NETWORK) == "true"
        ):
            response.headers[ACCESS_CONTROL_ALLOW_PRIVATE_NETWORK] = "true"

        return response

    def origin_found_in_allow_lists(self, origin: str, url: SplitResult) -> bool:
        """Return True if ``origin`` is allowed by the list or the regexes.

        The literal origin ``"null"`` is only accepted when it is listed
        verbatim in ``cors_allowed_origins``.
        """
        return (
            (origin == "null" and origin in self.server.cors_allowed_origins)
            or self._url_in_allowlist(url)
            or self.regex_domain_match(origin)
        )

    def _url_in_allowlist(self, url: SplitResult) -> bool:
        """Return True if ``url`` shares scheme and netloc with an allowed origin."""
        origins = [urlsplit(o) for o in self.server.cors_allowed_origins]
        return any(
            origin.scheme == url.scheme and origin.netloc == url.netloc
            for origin in origins
        )

    def regex_domain_match(self, origin: str) -> bool:
        """Return True if ``origin`` matches any configured origin regex."""
        # The router stores these patterns as ``cors_allowed_origins_regexes``;
        # reading any other attribute name raises AttributeError.
        return any(
            re.match(domain_pattern, origin)
            for domain_pattern in self.server.cors_allowed_origins_regexes
        )

    def process_request(self, request: Request) -> HttpResponse | None:
        """Answer CORS preflight requests directly.

        Returns:
            An empty 200 response carrying the CORS headers when the request
            is a preflight (``OPTIONS`` with ``access-control-request-method``)
            for a CORS-enabled path, otherwise ``None``.
        """
        # Identify and handle a preflight request
        # origin = request.META.get("HTTP_ORIGIN")
        request._cors_enabled = self.is_enabled(request)
        if (
            request._cors_enabled
            and request.method == "OPTIONS"
            and "access-control-request-method" in request.headers
        ):
            # this should be 204, but according to mozilla, not all browsers
            # parse that correctly. See [204] comment below.
            resp = HttpResponse(
                "",
                status_code=200,
                headers={"content-type": "text/plain", "content-length": 0},
            )
            self.add_response_headers(request, resp)
            return resp
        return None

    def process_response(
        self, request: Request, response: HttpResponse
    ) -> None:
        """Add the CORS headers to an outgoing response in place."""
        self.add_response_headers(request, response)
# [204]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods/OPTIONS#status_code

View File

@ -1,5 +1,5 @@
import re import re
from typing import Callable, Any, Optional from typing import Callable, Any, Optional, Sequence
from spiderweb.constants import DEFAULT_ALLOWED_METHODS from spiderweb.constants import DEFAULT_ALLOWED_METHODS
from spiderweb.converters import * # noqa: F403 from spiderweb.converters import * # noqa: F403
@ -30,7 +30,7 @@ class RoutesMixin:
# ones that start with underscores are the compiled versions, non-underscores # ones that start with underscores are the compiled versions, non-underscores
# are the user-supplied versions # are the user-supplied versions
_routes: dict _routes: dict
routes: list[tuple[str, Callable] | tuple[str, Callable, dict]] = (None,) routes: Sequence[tuple[str, Callable] | tuple[str, Callable, dict]]
_error_routes: dict _error_routes: dict
error_routes: dict[int, Callable] error_routes: dict[int, Callable]
append_slash: bool append_slash: bool

View File

@ -4,12 +4,16 @@ from datetime import timedelta
import pytest import pytest
from peewee import SqliteDatabase from peewee import SqliteDatabase
from spiderweb import SpiderwebRouter, HttpResponse, ConfigError from spiderweb import SpiderwebRouter, HttpResponse, ConfigError, StartupErrors
from spiderweb.constants import DEFAULT_ENCODING from spiderweb.constants import DEFAULT_ENCODING
from spiderweb.middleware.sessions import Session from spiderweb.middleware.sessions import Session
from spiderweb.middleware import csrf from spiderweb.middleware import csrf
from spiderweb.tests.helpers import setup from spiderweb.tests.helpers import setup
from spiderweb.tests.views_for_tests import form_view_with_csrf, form_csrf_exempt, form_view_without_csrf from spiderweb.tests.views_for_tests import (
form_view_with_csrf,
form_csrf_exempt,
form_view_without_csrf,
)
# app = SpiderwebRouter( # app = SpiderwebRouter(
@ -99,18 +103,21 @@ def test_exploding_middleware():
def test_csrf_middleware_without_session_middleware():
    """CSRF middleware configured without session middleware fails startup."""
    _, environ, start_response = setup()

    with pytest.raises(StartupErrors) as excinfo:
        SpiderwebRouter(
            middleware=["spiderweb.middleware.csrf.CSRFMiddleware"],
            db=SqliteDatabase("spiderweb-tests.db"),
        )

    # args[1] holds the individual issues passed to StartupErrors.
    issues = excinfo.value.args[1]
    first_message = issues[0].args[0]
    expected = csrf.CheckForSessionMiddleware.SESSION_MIDDLEWARE_NOT_FOUND
    assert first_message == expected
def test_csrf_middleware_above_session_middleware(): def test_csrf_middleware_above_session_middleware():
_, environ, start_response = setup() _, environ, start_response = setup()
with pytest.raises(ConfigError) as e: with pytest.raises(StartupErrors) as e:
SpiderwebRouter( SpiderwebRouter(
middleware=[ middleware=[
"spiderweb.middleware.csrf.CSRFMiddleware", "spiderweb.middleware.csrf.CSRFMiddleware",
@ -118,8 +125,11 @@ def test_csrf_middleware_above_session_middleware():
], ],
db=SqliteDatabase("spiderweb-tests.db"), db=SqliteDatabase("spiderweb-tests.db"),
) )
exceptiongroup = e.value.args[1]
assert e.value.args[0] == csrf.SessionCheck.SESSION_MIDDLEWARE_BELOW_CSRF assert (
exceptiongroup[0].args[0]
== csrf.VerifyCorrectMiddlewarePlacement.SESSION_MIDDLEWARE_BELOW_CSRF
)
def test_csrf_middleware(): def test_csrf_middleware():
@ -211,6 +221,7 @@ def test_csrf_expired_token():
f"swsession={[i for i in Session.select().dicts()][-1]['session_key']}" f"swsession={[i for i in Session.select().dicts()][-1]['session_key']}"
) )
environ["REQUEST_METHOD"] = "POST" environ["REQUEST_METHOD"] = "POST"
environ["HTTP_ORIGIN"] = "example.com"
environ["HTTP_X_CSRF_TOKEN"] = token environ["HTTP_X_CSRF_TOKEN"] = token
environ["CONTENT_LENGTH"] = len(formdata) environ["CONTENT_LENGTH"] = len(formdata)
@ -254,3 +265,44 @@ def test_csrf_exempt():
environ["PATH_INFO"] = "/2" environ["PATH_INFO"] = "/2"
resp2 = app(environ, start_response)[0].decode(DEFAULT_ENCODING) resp2 = app(environ, start_response)[0].decode(DEFAULT_ENCODING)
assert "CSRF token is invalid" in resp2 assert "CSRF token is invalid" in resp2
def test_csrf_trusted_origins():
    """A token-less POST is rejected from an untrusted origin and accepted
    from one listed in ``csrf_trusted_origins``."""
    _, environ, start_response = setup()
    app = SpiderwebRouter(
        middleware=[
            "spiderweb.middleware.sessions.SessionMiddleware",
            "spiderweb.middleware.csrf.CSRFMiddleware",
        ],
        csrf_trusted_origins=[
            "example.com",
        ],
        db=SqliteDatabase("spiderweb-tests.db"),
    )
    app.add_route("/", form_view_without_csrf, ["GET", "POST"])

    formdata = "name=bob"
    base_environ = {
        "HTTP_USER_AGENT": "hi",
        "REMOTE_ADDR": "1.1.1.1",
        "CONTENT_TYPE": "application/x-www-form-urlencoded",
        "REQUEST_METHOD": "POST",
        "CONTENT_LENGTH": len(formdata),
    }
    for key, value in base_environ.items():
        environ[key] = value

    def post_from(origin):
        # Each request needs a fresh, rewound body stream.
        body = BytesIO()
        body.write(formdata.encode(DEFAULT_ENCODING))
        body.seek(0)
        environ["wsgi.input"] = BufferedReader(body)
        environ["HTTP_ORIGIN"] = origin
        return app(environ, start_response)[0].decode(DEFAULT_ENCODING)

    resp = post_from("notvalid.com")
    assert "CSRF token is invalid" in resp

    resp2 = post_from("example.com")
    assert resp2 == '{"name": "bob"}'

View File

@ -1,4 +1,5 @@
import json import json
import re
import secrets import secrets
import string import string
from http import HTTPStatus from http import HTTPStatus
@ -78,3 +79,11 @@ class Headers(dict):
def setdefault(self, key, default=None): def setdefault(self, key, default=None):
return super().setdefault(key.lower(), default) return super().setdefault(key.lower(), default)
def convert_url_to_regex(url: str | re.Pattern) -> re.Pattern:
if isinstance(url, re.Pattern):
return url
url = url.replace(".", "\\.")
url = url.replace("*", ".+")
return re.compile(url)