✨ CORS middleware!
This commit is contained in:
parent
9330918009
commit
15a94b9879
example.py
spiderweb
@ -15,6 +15,7 @@ from spiderweb.response import (
|
||||
app = SpiderwebRouter(
|
||||
templates_dirs=["templates"],
|
||||
middleware=[
|
||||
"spiderweb.middleware.cors.CorsMiddleware",
|
||||
"spiderweb.middleware.sessions.SessionMiddleware",
|
||||
"spiderweb.middleware.csrf.CSRFMiddleware",
|
||||
"example_middleware.TestMiddleware",
|
||||
|
@ -8,3 +8,20 @@ __version__ = "0.12.0"
|
||||
REGEX_COOKIE_NAME = r"^[a-zA-Z0-9\s\(\)<>@,;:\/\\\[\]\?=\{\}\"\t]*$"
|
||||
|
||||
DATABASE_PROXY = DatabaseProxy()
|
||||
|
||||
DEFAULT_CORS_ALLOW_METHODS = (
|
||||
"DELETE",
|
||||
"GET",
|
||||
"OPTIONS",
|
||||
"PATCH",
|
||||
"POST",
|
||||
"PUT",
|
||||
)
|
||||
DEFAULT_CORS_ALLOW_HEADERS = (
|
||||
"accept",
|
||||
"authorization",
|
||||
"content-type",
|
||||
"user-agent",
|
||||
"x-csrftoken",
|
||||
"x-requested-with",
|
||||
)
|
||||
|
@ -86,3 +86,7 @@ class UnusedMiddleware(SpiderwebException):
|
||||
|
||||
class NoResponseError(SpiderwebException):
|
||||
pass
|
||||
|
||||
|
||||
class StartupErrors(ExceptionGroup):
|
||||
pass
|
||||
|
@ -1,16 +1,22 @@
|
||||
import inspect
|
||||
import logging
|
||||
import pathlib
|
||||
import re
|
||||
import traceback
|
||||
import urllib.parse as urlparse
|
||||
from logging import Logger
|
||||
from threading import Thread
|
||||
from typing import Optional, Callable
|
||||
from typing import Optional, Callable, Sequence, LiteralString, Literal
|
||||
from wsgiref.simple_server import WSGIServer
|
||||
|
||||
from jinja2 import BaseLoader, Environment, FileSystemLoader
|
||||
from peewee import Database, SqliteDatabase
|
||||
|
||||
from spiderweb.middleware import MiddlewareMixin
|
||||
from spiderweb.constants import (
|
||||
DEFAULT_CORS_ALLOW_METHODS,
|
||||
DEFAULT_CORS_ALLOW_HEADERS,
|
||||
)
|
||||
from spiderweb.constants import (
|
||||
DATABASE_PROXY,
|
||||
DEFAULT_ENCODING,
|
||||
@ -30,7 +36,7 @@ from spiderweb.request import Request
|
||||
from spiderweb.response import HttpResponse, TemplateResponse, JsonResponse
|
||||
from spiderweb.routes import RoutesMixin
|
||||
from spiderweb.secrets import FernetMixin
|
||||
from spiderweb.utils import get_http_status_by_code
|
||||
from spiderweb.utils import get_http_status_by_code, convert_url_to_regex
|
||||
|
||||
console_logger = logging.getLogger(__name__)
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
@ -42,25 +48,32 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi
|
||||
*,
|
||||
addr: str = None,
|
||||
port: int = None,
|
||||
allowed_hosts=None,
|
||||
cors_allowed_origins=None,
|
||||
cors_allow_all_origins=False,
|
||||
allowed_hosts: Sequence[str | re.Pattern] = None,
|
||||
cors_allowed_origins: Sequence[str] = None,
|
||||
cors_allowed_origins_regexes: Sequence[str] = None,
|
||||
cors_allow_all_origins: bool = False,
|
||||
cors_urls_regex: str | re.Pattern[str] = r"^.*$",
|
||||
cors_allow_methods: Sequence[str] = None,
|
||||
cors_allow_headers: Sequence[str] = None,
|
||||
cors_expose_headers: Sequence[str] = None,
|
||||
cors_preflight_max_age: int = 86400,
|
||||
cors_allow_credentials: bool = False,
|
||||
csrf_trusted_origins: Sequence[str] = None,
|
||||
db: Optional[Database] = None,
|
||||
templates_dirs: list[str] = None,
|
||||
middleware: list[str] = None,
|
||||
templates_dirs: Sequence[str] = None,
|
||||
middleware: Sequence[str] = None,
|
||||
append_slash: bool = False,
|
||||
staticfiles_dirs: list[str] = None,
|
||||
routes: list[tuple[str, Callable] | tuple[str, Callable, dict]] = None,
|
||||
staticfiles_dirs: Sequence[str] = None,
|
||||
routes: Sequence[tuple[str, Callable] | tuple[str, Callable, dict]] = None,
|
||||
error_routes: dict[int, Callable] = None,
|
||||
secret_key: str = None,
|
||||
session_max_age=60 * 60 * 24 * 14, # 2 weeks
|
||||
session_cookie_name="swsession",
|
||||
session_cookie_secure=False, # should be true if serving over HTTPS
|
||||
session_cookie_http_only=True,
|
||||
session_cookie_same_site="lax",
|
||||
session_cookie_path="/",
|
||||
log=None,
|
||||
session_max_age: int = 60 * 60 * 24 * 14, # 2 weeks
|
||||
session_cookie_name: str = "swsession",
|
||||
session_cookie_secure: bool = False, # should be true if serving over HTTPS
|
||||
session_cookie_http_only: bool = True,
|
||||
session_cookie_same_site: Literal["strict", "lax", "none"] = "lax",
|
||||
session_cookie_path: str = "/",
|
||||
log: Logger = None,
|
||||
**kwargs,
|
||||
):
|
||||
self._routes = {}
|
||||
@ -80,7 +93,15 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi
|
||||
self.allowed_hosts = [convert_url_to_regex(i) for i in self._allowed_hosts]
|
||||
|
||||
self.cors_allowed_origins = cors_allowed_origins or []
|
||||
self.cors_allowed_origins_regexes = cors_allowed_origins_regexes or []
|
||||
self.cors_allow_all_origins = cors_allow_all_origins
|
||||
self.cors_urls_regex = cors_urls_regex
|
||||
self.cors_allow_methods = cors_allow_methods or DEFAULT_CORS_ALLOW_METHODS
|
||||
self.cors_allow_headers = cors_allow_headers or DEFAULT_CORS_ALLOW_HEADERS
|
||||
self.cors_expose_headers = cors_expose_headers or []
|
||||
self.cors_preflight_max_age = cors_preflight_max_age
|
||||
self.cors_allow_credentials = cors_allow_credentials
|
||||
|
||||
self._csrf_trusted_origins = csrf_trusted_origins or []
|
||||
self.csrf_trusted_origins = [
|
||||
convert_url_to_regex(i) for i in self._csrf_trusted_origins
|
||||
|
@ -1,9 +1,11 @@
|
||||
from typing import Callable, ClassVar
|
||||
import sys
|
||||
|
||||
from .base import SpiderwebMiddleware as SpiderwebMiddleware
|
||||
from .cors import CorsMiddleware as CorsMiddleware
|
||||
from .csrf import CSRFMiddleware as CSRFMiddleware
|
||||
from .sessions import SessionMiddleware as SessionMiddleware
|
||||
from ..exceptions import ConfigError, UnusedMiddleware
|
||||
from ..exceptions import ConfigError, UnusedMiddleware, StartupErrors
|
||||
from ..request import Request
|
||||
from ..response import HttpResponse
|
||||
from ..utils import import_by_string
|
||||
@ -27,10 +29,19 @@ class MiddlewareMixin:
|
||||
self.middleware = middleware_by_reference
|
||||
|
||||
def run_middleware_checks(self):
|
||||
errors = []
|
||||
for middleware in self.middleware:
|
||||
if hasattr(middleware, "checks"):
|
||||
for check in middleware.checks:
|
||||
check(server=self).check()
|
||||
if issue := check(server=self).check():
|
||||
errors.append(issue)
|
||||
|
||||
if errors:
|
||||
# just show the messages
|
||||
sys.tracebacklimit = 0
|
||||
raise StartupErrors(
|
||||
"Problems were identified during startup — cannot continue.", errors
|
||||
)
|
||||
|
||||
def process_request_middleware(self, request: Request) -> None | bool:
|
||||
for middleware in self.middleware:
|
||||
|
@ -1 +1,137 @@
|
||||
# https://gist.github.com/FND/204ba41bf6ae485965ef
|
||||
import re
|
||||
from urllib.parse import urlsplit, SplitResult
|
||||
|
||||
from spiderweb.request import Request
|
||||
from spiderweb.response import HttpResponse
|
||||
from spiderweb.middleware import SpiderwebMiddleware
|
||||
|
||||
ACCESS_CONTROL_ALLOW_ORIGIN = "access-control-allow-origin"
|
||||
ACCESS_CONTROL_EXPOSE_HEADERS = "access-control-expose-headers"
|
||||
ACCESS_CONTROL_ALLOW_CREDENTIALS = "access-control-allow-credentials"
|
||||
ACCESS_CONTROL_ALLOW_HEADERS = "access-control-allow-headers"
|
||||
ACCESS_CONTROL_ALLOW_METHODS = "access-control-allow-methods"
|
||||
ACCESS_CONTROL_MAX_AGE = "access-control-max-age"
|
||||
ACCESS_CONTROL_REQUEST_PRIVATE_NETWORK = "access-control-request-private-network"
|
||||
ACCESS_CONTROL_ALLOW_PRIVATE_NETWORK = "access-control-allow-private-network"
|
||||
|
||||
|
||||
class CorsMiddleware(SpiderwebMiddleware):
|
||||
# heavily 'based' on https://github.com/adamchainz/django-cors-headers,
|
||||
# which is provided under the MIT license. This is essentially a direct
|
||||
# port, since django-cors-headers is battle-tested code that has been
|
||||
# around for a long time and it works well. Shoutouts to Otto, Adam, and
|
||||
# crew for helping make this a complete non-issue in Django for a very long
|
||||
# time.
|
||||
|
||||
def is_enabled(self, request: Request):
|
||||
return bool(re.match(self.server.cors_urls_regex, request.path))
|
||||
|
||||
def add_response_headers(self, request: Request, response: HttpResponse):
|
||||
enabled = getattr(request, "_cors_enabled", None)
|
||||
if enabled is None:
|
||||
enabled = self.is_enabled(request)
|
||||
|
||||
if not enabled:
|
||||
return response
|
||||
|
||||
if "vary" in response.headers:
|
||||
response.headers["vary"].append("origin")
|
||||
else:
|
||||
response.headers["vary"] = ["origin"]
|
||||
|
||||
origin = request.headers.get("origin")
|
||||
if not origin:
|
||||
return response
|
||||
|
||||
try:
|
||||
url = urlsplit(origin)
|
||||
except ValueError:
|
||||
return response
|
||||
|
||||
if (
|
||||
not self.server.cors_allow_all_origins
|
||||
and not self.origin_found_in_allow_lists(origin, url)
|
||||
):
|
||||
return response
|
||||
|
||||
if (
|
||||
self.server.cors_allow_all_origins
|
||||
and not self.server.cors_allow_credentials
|
||||
):
|
||||
response.headers[ACCESS_CONTROL_ALLOW_ORIGIN] = "*"
|
||||
else:
|
||||
response.headers[ACCESS_CONTROL_ALLOW_ORIGIN] = origin
|
||||
|
||||
if self.server.cors_allow_credentials:
|
||||
response.headers[ACCESS_CONTROL_ALLOW_CREDENTIALS] = "true"
|
||||
|
||||
if len(self.server.cors_expose_headers):
|
||||
response.headers[ACCESS_CONTROL_EXPOSE_HEADERS] = ", ".join(
|
||||
self.server.cors_expose_headers
|
||||
)
|
||||
|
||||
if request.method == "OPTIONS":
|
||||
response.headers[ACCESS_CONTROL_ALLOW_HEADERS] = ", ".join(
|
||||
self.server.cors_allow_headers
|
||||
)
|
||||
response.headers[ACCESS_CONTROL_ALLOW_METHODS] = ", ".join(
|
||||
self.server.cors_allow_methods
|
||||
)
|
||||
if self.server.cors_preflight_max_age:
|
||||
response.headers[ACCESS_CONTROL_MAX_AGE] = str(
|
||||
self.server.cors_preflight_max_age
|
||||
)
|
||||
|
||||
if (
|
||||
self.server.cors_allow_private_network
|
||||
and request.headers.get(ACCESS_CONTROL_REQUEST_PRIVATE_NETWORK) == "true"
|
||||
):
|
||||
response.headers[ACCESS_CONTROL_ALLOW_PRIVATE_NETWORK] = "true"
|
||||
|
||||
return response
|
||||
|
||||
def origin_found_in_allow_lists(self, origin: str, url: SplitResult) -> bool:
|
||||
return (
|
||||
(origin == "null" and origin in self.server.cors_allowed_origins)
|
||||
or self._url_in_allowlist(url)
|
||||
or self.regex_domain_match(origin)
|
||||
)
|
||||
|
||||
def _url_in_allowlist(self, url: SplitResult) -> bool:
|
||||
origins = [urlsplit(o) for o in self.server.cors_allowed_origins]
|
||||
return any(
|
||||
origin.scheme == url.scheme and origin.netloc == url.netloc
|
||||
for origin in origins
|
||||
)
|
||||
|
||||
def regex_domain_match(self, origin: str) -> bool:
|
||||
return any(
|
||||
re.match(domain_pattern, origin)
|
||||
for domain_pattern in self.server.cors_allowed_origin_regexes
|
||||
)
|
||||
|
||||
def process_request(self, request: Request) -> HttpResponse | None:
|
||||
# Identify and handle a preflight request
|
||||
# origin = request.META.get("HTTP_ORIGIN")
|
||||
request._cors_enabled = self.is_enabled(request)
|
||||
if (
|
||||
request._cors_enabled
|
||||
and request.method == "OPTIONS"
|
||||
and "access-control-request-method" in request.headers
|
||||
):
|
||||
# this should be 204, but according to mozilla, not all browsers
|
||||
# parse that correctly. See [204] comment below.
|
||||
resp = HttpResponse(
|
||||
"",
|
||||
status_code=200,
|
||||
headers={"content-type": "text/plain", "content-length": 0},
|
||||
)
|
||||
self.add_response_headers(request, resp)
|
||||
return resp
|
||||
|
||||
def process_response(
|
||||
self, request: Request, response: HttpResponse
|
||||
) -> None:
|
||||
self.add_response_headers(request, response)
|
||||
|
||||
# [204]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods/OPTIONS#status_code
|
||||
|
@ -1,5 +1,5 @@
|
||||
import re
|
||||
from typing import Callable, Any, Optional
|
||||
from typing import Callable, Any, Optional, Sequence
|
||||
|
||||
from spiderweb.constants import DEFAULT_ALLOWED_METHODS
|
||||
from spiderweb.converters import * # noqa: F403
|
||||
@ -30,7 +30,7 @@ class RoutesMixin:
|
||||
# ones that start with underscores are the compiled versions, non-underscores
|
||||
# are the user-supplied versions
|
||||
_routes: dict
|
||||
routes: list[tuple[str, Callable] | tuple[str, Callable, dict]] = (None,)
|
||||
routes: Sequence[tuple[str, Callable] | tuple[str, Callable, dict]]
|
||||
_error_routes: dict
|
||||
error_routes: dict[int, Callable]
|
||||
append_slash: bool
|
||||
|
@ -4,12 +4,16 @@ from datetime import timedelta
|
||||
import pytest
|
||||
from peewee import SqliteDatabase
|
||||
|
||||
from spiderweb import SpiderwebRouter, HttpResponse, ConfigError
|
||||
from spiderweb import SpiderwebRouter, HttpResponse, ConfigError, StartupErrors
|
||||
from spiderweb.constants import DEFAULT_ENCODING
|
||||
from spiderweb.middleware.sessions import Session
|
||||
from spiderweb.middleware import csrf
|
||||
from spiderweb.tests.helpers import setup
|
||||
from spiderweb.tests.views_for_tests import form_view_with_csrf, form_csrf_exempt, form_view_without_csrf
|
||||
from spiderweb.tests.views_for_tests import (
|
||||
form_view_with_csrf,
|
||||
form_csrf_exempt,
|
||||
form_view_without_csrf,
|
||||
)
|
||||
|
||||
|
||||
# app = SpiderwebRouter(
|
||||
@ -99,18 +103,21 @@ def test_exploding_middleware():
|
||||
|
||||
def test_csrf_middleware_without_session_middleware():
|
||||
_, environ, start_response = setup()
|
||||
with pytest.raises(ConfigError) as e:
|
||||
with pytest.raises(StartupErrors) as e:
|
||||
SpiderwebRouter(
|
||||
middleware=["spiderweb.middleware.csrf.CSRFMiddleware"],
|
||||
db=SqliteDatabase("spiderweb-tests.db"),
|
||||
)
|
||||
|
||||
assert e.value.args[0] == csrf.SessionCheck.SESSION_MIDDLEWARE_NOT_FOUND
|
||||
exceptiongroup = e.value.args[1]
|
||||
assert (
|
||||
exceptiongroup[0].args[0]
|
||||
== csrf.CheckForSessionMiddleware.SESSION_MIDDLEWARE_NOT_FOUND
|
||||
)
|
||||
|
||||
|
||||
def test_csrf_middleware_above_session_middleware():
|
||||
_, environ, start_response = setup()
|
||||
with pytest.raises(ConfigError) as e:
|
||||
with pytest.raises(StartupErrors) as e:
|
||||
SpiderwebRouter(
|
||||
middleware=[
|
||||
"spiderweb.middleware.csrf.CSRFMiddleware",
|
||||
@ -118,8 +125,11 @@ def test_csrf_middleware_above_session_middleware():
|
||||
],
|
||||
db=SqliteDatabase("spiderweb-tests.db"),
|
||||
)
|
||||
|
||||
assert e.value.args[0] == csrf.SessionCheck.SESSION_MIDDLEWARE_BELOW_CSRF
|
||||
exceptiongroup = e.value.args[1]
|
||||
assert (
|
||||
exceptiongroup[0].args[0]
|
||||
== csrf.VerifyCorrectMiddlewarePlacement.SESSION_MIDDLEWARE_BELOW_CSRF
|
||||
)
|
||||
|
||||
|
||||
def test_csrf_middleware():
|
||||
@ -211,6 +221,7 @@ def test_csrf_expired_token():
|
||||
f"swsession={[i for i in Session.select().dicts()][-1]['session_key']}"
|
||||
)
|
||||
environ["REQUEST_METHOD"] = "POST"
|
||||
environ["HTTP_ORIGIN"] = "example.com"
|
||||
environ["HTTP_X_CSRF_TOKEN"] = token
|
||||
environ["CONTENT_LENGTH"] = len(formdata)
|
||||
|
||||
@ -254,3 +265,44 @@ def test_csrf_exempt():
|
||||
environ["PATH_INFO"] = "/2"
|
||||
resp2 = app(environ, start_response)[0].decode(DEFAULT_ENCODING)
|
||||
assert "CSRF token is invalid" in resp2
|
||||
|
||||
|
||||
def test_csrf_trusted_origins():
|
||||
_, environ, start_response = setup()
|
||||
app = SpiderwebRouter(
|
||||
middleware=[
|
||||
"spiderweb.middleware.sessions.SessionMiddleware",
|
||||
"spiderweb.middleware.csrf.CSRFMiddleware",
|
||||
],
|
||||
csrf_trusted_origins=[
|
||||
"example.com",
|
||||
],
|
||||
db=SqliteDatabase("spiderweb-tests.db"),
|
||||
)
|
||||
|
||||
app.add_route("/", form_view_without_csrf, ["GET", "POST"])
|
||||
|
||||
environ["HTTP_USER_AGENT"] = "hi"
|
||||
environ["REMOTE_ADDR"] = "1.1.1.1"
|
||||
environ["CONTENT_TYPE"] = "application/x-www-form-urlencoded"
|
||||
environ["REQUEST_METHOD"] = "POST"
|
||||
|
||||
formdata = "name=bob"
|
||||
environ["CONTENT_LENGTH"] = len(formdata)
|
||||
b_handle = BytesIO()
|
||||
b_handle.write(formdata.encode(DEFAULT_ENCODING))
|
||||
b_handle.seek(0)
|
||||
environ["wsgi.input"] = BufferedReader(b_handle)
|
||||
|
||||
environ["HTTP_ORIGIN"] = "notvalid.com"
|
||||
resp = app(environ, start_response)[0].decode(DEFAULT_ENCODING)
|
||||
assert "CSRF token is invalid" in resp
|
||||
|
||||
b_handle = BytesIO()
|
||||
b_handle.write(formdata.encode(DEFAULT_ENCODING))
|
||||
b_handle.seek(0)
|
||||
environ["wsgi.input"] = BufferedReader(b_handle)
|
||||
|
||||
environ["HTTP_ORIGIN"] = "example.com"
|
||||
resp2 = app(environ, start_response)[0].decode(DEFAULT_ENCODING)
|
||||
assert resp2 == '{"name": "bob"}'
|
||||
|
@ -1,4 +1,5 @@
|
||||
import json
|
||||
import re
|
||||
import secrets
|
||||
import string
|
||||
from http import HTTPStatus
|
||||
@ -76,5 +77,13 @@ class Headers(dict):
|
||||
def get(self, key, default=None):
|
||||
return super().get(key.lower(), default)
|
||||
|
||||
def setdefault(self, key, default = None):
|
||||
return super().setdefault(key.lower(), default)
|
||||
def setdefault(self, key, default=None):
|
||||
return super().setdefault(key.lower(), default)
|
||||
|
||||
|
||||
def convert_url_to_regex(url: str | re.Pattern) -> re.Pattern:
|
||||
if isinstance(url, re.Pattern):
|
||||
return url
|
||||
url = url.replace(".", "\\.")
|
||||
url = url.replace("*", ".+")
|
||||
return re.compile(url)
|
||||
|
Loading…
Reference in New Issue
Block a user