diff --git a/example.py b/example.py index c2543fe..f763eca 100644 --- a/example.py +++ b/example.py @@ -15,6 +15,7 @@ from spiderweb.response import ( app = SpiderwebRouter( templates_dirs=["templates"], middleware=[ + "spiderweb.middleware.cors.CorsMiddleware", "spiderweb.middleware.sessions.SessionMiddleware", "spiderweb.middleware.csrf.CSRFMiddleware", "example_middleware.TestMiddleware", diff --git a/spiderweb/constants.py b/spiderweb/constants.py index 9ebc2ad..cb46532 100644 --- a/spiderweb/constants.py +++ b/spiderweb/constants.py @@ -8,3 +8,20 @@ __version__ = "0.12.0" REGEX_COOKIE_NAME = r"^[a-zA-Z0-9\s\(\)<>@,;:\/\\\[\]\?=\{\}\"\t]*$" DATABASE_PROXY = DatabaseProxy() + +DEFAULT_CORS_ALLOW_METHODS = ( + "DELETE", + "GET", + "OPTIONS", + "PATCH", + "POST", + "PUT", +) +DEFAULT_CORS_ALLOW_HEADERS = ( + "accept", + "authorization", + "content-type", + "user-agent", + "x-csrftoken", + "x-requested-with", +) diff --git a/spiderweb/exceptions.py b/spiderweb/exceptions.py index bdba675..f784c23 100644 --- a/spiderweb/exceptions.py +++ b/spiderweb/exceptions.py @@ -86,3 +86,7 @@ class UnusedMiddleware(SpiderwebException): class NoResponseError(SpiderwebException): pass + + +class StartupErrors(ExceptionGroup): + pass diff --git a/spiderweb/main.py b/spiderweb/main.py index 2f5dfc9..eb0e44d 100644 --- a/spiderweb/main.py +++ b/spiderweb/main.py @@ -1,16 +1,22 @@ import inspect import logging import pathlib +import re import traceback import urllib.parse as urlparse +from logging import Logger from threading import Thread -from typing import Optional, Callable +from typing import Optional, Callable, Sequence, LiteralString, Literal from wsgiref.simple_server import WSGIServer from jinja2 import BaseLoader, Environment, FileSystemLoader from peewee import Database, SqliteDatabase from spiderweb.middleware import MiddlewareMixin +from spiderweb.constants import ( + DEFAULT_CORS_ALLOW_METHODS, + DEFAULT_CORS_ALLOW_HEADERS, +) from spiderweb.constants import ( DATABASE_PROXY, DEFAULT_ENCODING, @@ -30,7 +36,7 @@ from spiderweb.request import Request from spiderweb.response import HttpResponse, TemplateResponse, JsonResponse from spiderweb.routes import RoutesMixin from spiderweb.secrets import FernetMixin -from spiderweb.utils import get_http_status_by_code +from spiderweb.utils import get_http_status_by_code, convert_url_to_regex console_logger = logging.getLogger(__name__) logging.basicConfig(level=logging.INFO) @@ -42,25 +48,32 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi *, addr: str = None, port: int = None, - allowed_hosts=None, - cors_allowed_origins=None, - cors_allow_all_origins=False, + allowed_hosts: Sequence[str | re.Pattern] = None, + cors_allowed_origins: Sequence[str] = None, + cors_allowed_origins_regexes: Sequence[str] = None, + cors_allow_all_origins: bool = False, + cors_urls_regex: str | re.Pattern[str] = r"^.*$", + cors_allow_methods: Sequence[str] = None, + cors_allow_headers: Sequence[str] = None, + cors_expose_headers: Sequence[str] = None, + cors_preflight_max_age: int = 86400, + cors_allow_credentials: bool = False, csrf_trusted_origins: Sequence[str] = None, db: Optional[Database] = None, - templates_dirs: list[str] = None, - middleware: list[str] = None, + templates_dirs: Sequence[str] = None, + middleware: Sequence[str] = None, append_slash: bool = False, - staticfiles_dirs: list[str] = None, - routes: list[tuple[str, Callable] | tuple[str, Callable, dict]] = None, + staticfiles_dirs: Sequence[str] = None, + routes: Sequence[tuple[str, Callable] | tuple[str, Callable, dict]] = None, error_routes: dict[int, Callable] = None, secret_key: str = None, - session_max_age=60 * 60 * 24 * 14, # 2 weeks - session_cookie_name="swsession", - session_cookie_secure=False, # should be true if serving over HTTPS - session_cookie_http_only=True, - session_cookie_same_site="lax", - session_cookie_path="/", - log=None, + session_max_age: int = 60 * 60 * 24 * 14, # 2 weeks + session_cookie_name: str = "swsession", + session_cookie_secure: bool = False, # should be true if serving over HTTPS + session_cookie_http_only: bool = True, + session_cookie_same_site: Literal["strict", "lax", "none"] = "lax", + session_cookie_path: str = "/", + log: Logger = None, **kwargs, ): self._routes = {} @@ -80,7 +93,15 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi self.allowed_hosts = [convert_url_to_regex(i) for i in self._allowed_hosts] self.cors_allowed_origins = cors_allowed_origins or [] + self.cors_allowed_origins_regexes = cors_allowed_origins_regexes or [] self.cors_allow_all_origins = cors_allow_all_origins + self.cors_urls_regex = cors_urls_regex + self.cors_allow_methods = cors_allow_methods or DEFAULT_CORS_ALLOW_METHODS + self.cors_allow_headers = cors_allow_headers or DEFAULT_CORS_ALLOW_HEADERS + self.cors_expose_headers = cors_expose_headers or [] + self.cors_preflight_max_age = cors_preflight_max_age + self.cors_allow_credentials = cors_allow_credentials + self._csrf_trusted_origins = csrf_trusted_origins or [] self.csrf_trusted_origins = [ convert_url_to_regex(i) for i in self._csrf_trusted_origins diff --git a/spiderweb/middleware/__init__.py b/spiderweb/middleware/__init__.py index 3ffeb8e..265f2a5 100644 --- a/spiderweb/middleware/__init__.py +++ b/spiderweb/middleware/__init__.py @@ -1,9 +1,11 @@ from typing import Callable, ClassVar +import sys from .base import SpiderwebMiddleware as SpiderwebMiddleware +from .cors import CorsMiddleware as CorsMiddleware from .csrf import CSRFMiddleware as CSRFMiddleware from .sessions import SessionMiddleware as SessionMiddleware -from ..exceptions import ConfigError, UnusedMiddleware +from ..exceptions import ConfigError, UnusedMiddleware, StartupErrors from ..request import Request from ..response import HttpResponse from ..utils import import_by_string @@ -27,10 +29,19 @@ class MiddlewareMixin: self.middleware = middleware_by_reference def run_middleware_checks(self): + errors = [] for middleware in self.middleware: if hasattr(middleware, "checks"): for check in middleware.checks: - check(server=self).check() + if issue := check(server=self).check(): + errors.append(issue) + + if errors: + # just show the messages + sys.tracebacklimit = 0 + raise StartupErrors( + "Problems were identified during startup — cannot continue.", errors + ) def process_request_middleware(self, request: Request) -> None | bool: for middleware in self.middleware: diff --git a/spiderweb/middleware/cors.py b/spiderweb/middleware/cors.py index 37de52b..9a1bcc1 100644 --- a/spiderweb/middleware/cors.py +++ b/spiderweb/middleware/cors.py @@ -1 +1,137 @@ -# https://gist.github.com/FND/204ba41bf6ae485965ef +import re +from urllib.parse import urlsplit, SplitResult + +from spiderweb.request import Request +from spiderweb.response import HttpResponse +from spiderweb.middleware import SpiderwebMiddleware + +ACCESS_CONTROL_ALLOW_ORIGIN = "access-control-allow-origin" +ACCESS_CONTROL_EXPOSE_HEADERS = "access-control-expose-headers" +ACCESS_CONTROL_ALLOW_CREDENTIALS = "access-control-allow-credentials" +ACCESS_CONTROL_ALLOW_HEADERS = "access-control-allow-headers" +ACCESS_CONTROL_ALLOW_METHODS = "access-control-allow-methods" +ACCESS_CONTROL_MAX_AGE = "access-control-max-age" +ACCESS_CONTROL_REQUEST_PRIVATE_NETWORK = "access-control-request-private-network" +ACCESS_CONTROL_ALLOW_PRIVATE_NETWORK = "access-control-allow-private-network" + + +class CorsMiddleware(SpiderwebMiddleware): + # heavily 'based' on https://github.com/adamchainz/django-cors-headers, + # which is provided under the MIT license. This is essentially a direct + # port, since django-cors-headers is battle-tested code that has been + # around for a long time and it works well. Shoutouts to Otto, Adam, and + # crew for helping make this a complete non-issue in Django for a very long + # time. + + def is_enabled(self, request: Request): + return bool(re.match(self.server.cors_urls_regex, request.path)) + + def add_response_headers(self, request: Request, response: HttpResponse): + enabled = getattr(request, "_cors_enabled", None) + if enabled is None: + enabled = self.is_enabled(request) + + if not enabled: + return response + + if "vary" in response.headers: + response.headers["vary"].append("origin") + else: + response.headers["vary"] = ["origin"] + + origin = request.headers.get("origin") + if not origin: + return response + + try: + url = urlsplit(origin) + except ValueError: + return response + + if ( + not self.server.cors_allow_all_origins + and not self.origin_found_in_allow_lists(origin, url) + ): + return response + + if ( + self.server.cors_allow_all_origins + and not self.server.cors_allow_credentials + ): + response.headers[ACCESS_CONTROL_ALLOW_ORIGIN] = "*" + else: + response.headers[ACCESS_CONTROL_ALLOW_ORIGIN] = origin + + if self.server.cors_allow_credentials: + response.headers[ACCESS_CONTROL_ALLOW_CREDENTIALS] = "true" + + if len(self.server.cors_expose_headers): + response.headers[ACCESS_CONTROL_EXPOSE_HEADERS] = ", ".join( + self.server.cors_expose_headers + ) + + if request.method == "OPTIONS": + response.headers[ACCESS_CONTROL_ALLOW_HEADERS] = ", ".join( + self.server.cors_allow_headers + ) + response.headers[ACCESS_CONTROL_ALLOW_METHODS] = ", ".join( + self.server.cors_allow_methods + ) + if self.server.cors_preflight_max_age: + response.headers[ACCESS_CONTROL_MAX_AGE] = str( + self.server.cors_preflight_max_age + ) + + if ( + self.server.cors_allow_private_network + and request.headers.get(ACCESS_CONTROL_REQUEST_PRIVATE_NETWORK) == "true" + ): + response.headers[ACCESS_CONTROL_ALLOW_PRIVATE_NETWORK] = "true" + + return response + + def origin_found_in_allow_lists(self, origin: str, url: SplitResult) -> bool: + return ( + (origin == "null" and origin in self.server.cors_allowed_origins) + or self._url_in_allowlist(url) + or self.regex_domain_match(origin) + ) + + def _url_in_allowlist(self, url: SplitResult) -> bool: + origins = [urlsplit(o) for o in self.server.cors_allowed_origins] + return any( + origin.scheme == url.scheme and origin.netloc == url.netloc + for origin in origins + ) + + def regex_domain_match(self, origin: str) -> bool: + return any( + re.match(domain_pattern, origin) + for domain_pattern in self.server.cors_allowed_origin_regexes + ) + + def process_request(self, request: Request) -> HttpResponse | None: + # Identify and handle a preflight request + # origin = request.META.get("HTTP_ORIGIN") + request._cors_enabled = self.is_enabled(request) + if ( + request._cors_enabled + and request.method == "OPTIONS" + and "access-control-request-method" in request.headers + ): + # this should be 204, but according to mozilla, not all browsers + # parse that correctly. See [204] comment below. + resp = HttpResponse( + "", + status_code=200, + headers={"content-type": "text/plain", "content-length": 0}, + ) + self.add_response_headers(request, resp) + return resp + + def process_response( + self, request: Request, response: HttpResponse + ) -> None: + self.add_response_headers(request, response) + +# [204]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods/OPTIONS#status_code diff --git a/spiderweb/routes.py b/spiderweb/routes.py index 4b26448..3622d8b 100644 --- a/spiderweb/routes.py +++ b/spiderweb/routes.py @@ -1,5 +1,5 @@ import re -from typing import Callable, Any, Optional +from typing import Callable, Any, Optional, Sequence from spiderweb.constants import DEFAULT_ALLOWED_METHODS from spiderweb.converters import * # noqa: F403 @@ -30,7 +30,7 @@ class RoutesMixin: # ones that start with underscores are the compiled versions, non-underscores # are the user-supplied versions _routes: dict - routes: list[tuple[str, Callable] | tuple[str, Callable, dict]] = (None,) + routes: Sequence[tuple[str, Callable] | tuple[str, Callable, dict]] _error_routes: dict error_routes: dict[int, Callable] append_slash: bool diff --git a/spiderweb/tests/test_middleware.py b/spiderweb/tests/test_middleware.py index e727b07..f785875 100644 --- a/spiderweb/tests/test_middleware.py +++ b/spiderweb/tests/test_middleware.py @@ -4,12 +4,16 @@ from datetime import timedelta import pytest from peewee import SqliteDatabase -from spiderweb import SpiderwebRouter, HttpResponse, ConfigError +from spiderweb import SpiderwebRouter, HttpResponse, ConfigError, StartupErrors from spiderweb.constants import DEFAULT_ENCODING from spiderweb.middleware.sessions import Session from spiderweb.middleware import csrf from spiderweb.tests.helpers import setup -from spiderweb.tests.views_for_tests import form_view_with_csrf, form_csrf_exempt, form_view_without_csrf +from spiderweb.tests.views_for_tests import ( + form_view_with_csrf, + form_csrf_exempt, + form_view_without_csrf, +) # app = SpiderwebRouter( @@ -99,18 +103,21 @@ def test_exploding_middleware(): def test_csrf_middleware_without_session_middleware(): _, environ, start_response = setup() - with pytest.raises(ConfigError) as e: + with pytest.raises(StartupErrors) as e: SpiderwebRouter( middleware=["spiderweb.middleware.csrf.CSRFMiddleware"], db=SqliteDatabase("spiderweb-tests.db"), ) - - assert e.value.args[0] == csrf.SessionCheck.SESSION_MIDDLEWARE_NOT_FOUND + exceptiongroup = e.value.args[1] + assert ( + exceptiongroup[0].args[0] + == csrf.CheckForSessionMiddleware.SESSION_MIDDLEWARE_NOT_FOUND + ) def test_csrf_middleware_above_session_middleware(): _, environ, start_response = setup() - with pytest.raises(ConfigError) as e: + with pytest.raises(StartupErrors) as e: SpiderwebRouter( middleware=[ "spiderweb.middleware.csrf.CSRFMiddleware", @@ -118,8 +125,11 @@ def test_csrf_middleware_above_session_middleware(): ], db=SqliteDatabase("spiderweb-tests.db"), ) - - assert e.value.args[0] == csrf.SessionCheck.SESSION_MIDDLEWARE_BELOW_CSRF + exceptiongroup = e.value.args[1] + assert ( + exceptiongroup[0].args[0] + == csrf.VerifyCorrectMiddlewarePlacement.SESSION_MIDDLEWARE_BELOW_CSRF + ) def test_csrf_middleware(): @@ -211,6 +221,7 @@ def test_csrf_expired_token(): f"swsession={[i for i in Session.select().dicts()][-1]['session_key']}" ) environ["REQUEST_METHOD"] = "POST" + environ["HTTP_ORIGIN"] = "example.com" environ["HTTP_X_CSRF_TOKEN"] = token environ["CONTENT_LENGTH"] = len(formdata) @@ -254,3 +265,44 @@ def test_csrf_exempt(): environ["PATH_INFO"] = "/2" resp2 = app(environ, start_response)[0].decode(DEFAULT_ENCODING) assert "CSRF token is invalid" in resp2 + + +def test_csrf_trusted_origins(): + _, environ, start_response = setup() + app = SpiderwebRouter( + middleware=[ + "spiderweb.middleware.sessions.SessionMiddleware", + "spiderweb.middleware.csrf.CSRFMiddleware", + ], + csrf_trusted_origins=[ + "example.com", + ], + db=SqliteDatabase("spiderweb-tests.db"), + ) + + app.add_route("/", form_view_without_csrf, ["GET", "POST"]) + + environ["HTTP_USER_AGENT"] = "hi" + environ["REMOTE_ADDR"] = "1.1.1.1" + environ["CONTENT_TYPE"] = "application/x-www-form-urlencoded" + environ["REQUEST_METHOD"] = "POST" + + formdata = "name=bob" + environ["CONTENT_LENGTH"] = len(formdata) + b_handle = BytesIO() + b_handle.write(formdata.encode(DEFAULT_ENCODING)) + b_handle.seek(0) + environ["wsgi.input"] = BufferedReader(b_handle) + + environ["HTTP_ORIGIN"] = "notvalid.com" + resp = app(environ, start_response)[0].decode(DEFAULT_ENCODING) + assert "CSRF token is invalid" in resp + + b_handle = BytesIO() + b_handle.write(formdata.encode(DEFAULT_ENCODING)) + b_handle.seek(0) + environ["wsgi.input"] = BufferedReader(b_handle) + + environ["HTTP_ORIGIN"] = "example.com" + resp2 = app(environ, start_response)[0].decode(DEFAULT_ENCODING) + assert resp2 == '{"name": "bob"}' diff --git a/spiderweb/utils.py b/spiderweb/utils.py index d24ef04..e00bcb7 100644 --- a/spiderweb/utils.py +++ b/spiderweb/utils.py @@ -1,4 +1,5 @@ import json +import re import secrets import string from http import HTTPStatus @@ -76,5 +77,13 @@ class Headers(dict): def get(self, key, default=None): return super().get(key.lower(), default) - def setdefault(self, key, default = None): - return super().setdefault(key.lower(), default) \ No newline at end of file + def setdefault(self, key, default=None): + return super().setdefault(key.lower(), default) + + +def convert_url_to_regex(url: str | re.Pattern) -> re.Pattern: + if isinstance(url, re.Pattern): + return url + url = url.replace(".", "\\.") + url = url.replace("*", ".+") + return re.compile(url)