CORS! #1

Merged
jkaufeld merged 9 commits from origins into main 2024-09-02 00:39:35 -04:00
9 changed files with 282 additions and 31 deletions
Showing only changes of commit 15a94b9879 - Show all commits

View File

@ -15,6 +15,7 @@ from spiderweb.response import (
app = SpiderwebRouter(
templates_dirs=["templates"],
middleware=[
"spiderweb.middleware.cors.CorsMiddleware",
"spiderweb.middleware.sessions.SessionMiddleware",
"spiderweb.middleware.csrf.CSRFMiddleware",
"example_middleware.TestMiddleware",

View File

@ -8,3 +8,20 @@ __version__ = "0.12.0"
REGEX_COOKIE_NAME = r"^[a-zA-Z0-9\s\(\)<>@,;:\/\\\[\]\?=\{\}\"\t]*$"
DATABASE_PROXY = DatabaseProxy()
DEFAULT_CORS_ALLOW_METHODS = (
"DELETE",
"GET",
"OPTIONS",
"PATCH",
"POST",
"PUT",
)
DEFAULT_CORS_ALLOW_HEADERS = (
"accept",
"authorization",
"content-type",
"user-agent",
"x-csrftoken",
"x-requested-with",
)

View File

@ -86,3 +86,7 @@ class UnusedMiddleware(SpiderwebException):
class NoResponseError(SpiderwebException):
pass
class StartupErrors(ExceptionGroup):
pass

View File

@ -1,16 +1,22 @@
import inspect
import logging
import pathlib
import re
import traceback
import urllib.parse as urlparse
from logging import Logger
from threading import Thread
from typing import Optional, Callable
from typing import Optional, Callable, Sequence, LiteralString, Literal
from wsgiref.simple_server import WSGIServer
from jinja2 import BaseLoader, Environment, FileSystemLoader
from peewee import Database, SqliteDatabase
from spiderweb.middleware import MiddlewareMixin
from spiderweb.constants import (
DEFAULT_CORS_ALLOW_METHODS,
DEFAULT_CORS_ALLOW_HEADERS,
)
from spiderweb.constants import (
DATABASE_PROXY,
DEFAULT_ENCODING,
@ -30,7 +36,7 @@ from spiderweb.request import Request
from spiderweb.response import HttpResponse, TemplateResponse, JsonResponse
from spiderweb.routes import RoutesMixin
from spiderweb.secrets import FernetMixin
from spiderweb.utils import get_http_status_by_code
from spiderweb.utils import get_http_status_by_code, convert_url_to_regex
console_logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
@ -42,25 +48,32 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi
*,
addr: str = None,
port: int = None,
allowed_hosts=None,
cors_allowed_origins=None,
cors_allow_all_origins=False,
allowed_hosts: Sequence[str | re.Pattern] = None,
cors_allowed_origins: Sequence[str] = None,
cors_allowed_origins_regexes: Sequence[str] = None,
cors_allow_all_origins: bool = False,
cors_urls_regex: str | re.Pattern[str] = r"^.*$",
cors_allow_methods: Sequence[str] = None,
cors_allow_headers: Sequence[str] = None,
cors_expose_headers: Sequence[str] = None,
cors_preflight_max_age: int = 86400,
cors_allow_credentials: bool = False,
csrf_trusted_origins: Sequence[str] = None,
db: Optional[Database] = None,
templates_dirs: list[str] = None,
middleware: list[str] = None,
templates_dirs: Sequence[str] = None,
middleware: Sequence[str] = None,
append_slash: bool = False,
staticfiles_dirs: list[str] = None,
routes: list[tuple[str, Callable] | tuple[str, Callable, dict]] = None,
staticfiles_dirs: Sequence[str] = None,
routes: Sequence[tuple[str, Callable] | tuple[str, Callable, dict]] = None,
error_routes: dict[int, Callable] = None,
secret_key: str = None,
session_max_age=60 * 60 * 24 * 14, # 2 weeks
session_cookie_name="swsession",
session_cookie_secure=False, # should be true if serving over HTTPS
session_cookie_http_only=True,
session_cookie_same_site="lax",
session_cookie_path="/",
log=None,
session_max_age: int = 60 * 60 * 24 * 14, # 2 weeks
session_cookie_name: str = "swsession",
session_cookie_secure: bool = False, # should be true if serving over HTTPS
session_cookie_http_only: bool = True,
session_cookie_same_site: Literal["strict", "lax", "none"] = "lax",
session_cookie_path: str = "/",
log: Logger = None,
**kwargs,
):
self._routes = {}
@ -80,7 +93,15 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi
self.allowed_hosts = [convert_url_to_regex(i) for i in self._allowed_hosts]
self.cors_allowed_origins = cors_allowed_origins or []
self.cors_allowed_origins_regexes = cors_allowed_origins_regexes or []
self.cors_allow_all_origins = cors_allow_all_origins
self.cors_urls_regex = cors_urls_regex
self.cors_allow_methods = cors_allow_methods or DEFAULT_CORS_ALLOW_METHODS
self.cors_allow_headers = cors_allow_headers or DEFAULT_CORS_ALLOW_HEADERS
self.cors_expose_headers = cors_expose_headers or []
self.cors_preflight_max_age = cors_preflight_max_age
self.cors_allow_credentials = cors_allow_credentials
self._csrf_trusted_origins = csrf_trusted_origins or []
self.csrf_trusted_origins = [
convert_url_to_regex(i) for i in self._csrf_trusted_origins

View File

@ -1,9 +1,11 @@
from typing import Callable, ClassVar
import sys
from .base import SpiderwebMiddleware as SpiderwebMiddleware
from .cors import CorsMiddleware as CorsMiddleware
from .csrf import CSRFMiddleware as CSRFMiddleware
from .sessions import SessionMiddleware as SessionMiddleware
from ..exceptions import ConfigError, UnusedMiddleware
from ..exceptions import ConfigError, UnusedMiddleware, StartupErrors
from ..request import Request
from ..response import HttpResponse
from ..utils import import_by_string
@ -27,10 +29,19 @@ class MiddlewareMixin:
self.middleware = middleware_by_reference
def run_middleware_checks(self):
errors = []
for middleware in self.middleware:
if hasattr(middleware, "checks"):
for check in middleware.checks:
check(server=self).check()
if issue := check(server=self).check():
errors.append(issue)
if errors:
# just show the messages
sys.tracebacklimit = 0
raise StartupErrors(
"Problems were identified during startup — cannot continue.", errors
)
def process_request_middleware(self, request: Request) -> None | bool:
for middleware in self.middleware:

View File

@ -1 +1,137 @@
# https://gist.github.com/FND/204ba41bf6ae485965ef
import re
from urllib.parse import urlsplit, SplitResult
from spiderweb.request import Request
from spiderweb.response import HttpResponse
from spiderweb.middleware import SpiderwebMiddleware
ACCESS_CONTROL_ALLOW_ORIGIN = "access-control-allow-origin"
ACCESS_CONTROL_EXPOSE_HEADERS = "access-control-expose-headers"
ACCESS_CONTROL_ALLOW_CREDENTIALS = "access-control-allow-credentials"
ACCESS_CONTROL_ALLOW_HEADERS = "access-control-allow-headers"
ACCESS_CONTROL_ALLOW_METHODS = "access-control-allow-methods"
ACCESS_CONTROL_MAX_AGE = "access-control-max-age"
ACCESS_CONTROL_REQUEST_PRIVATE_NETWORK = "access-control-request-private-network"
ACCESS_CONTROL_ALLOW_PRIVATE_NETWORK = "access-control-allow-private-network"
class CorsMiddleware(SpiderwebMiddleware):
# heavily 'based' on https://github.com/adamchainz/django-cors-headers,
# which is provided under the MIT license. This is essentially a direct
# port, since django-cors-headers is battle-tested code that has been
# around for a long time and it works well. Shoutouts to Otto, Adam, and
# crew for helping make this a complete non-issue in Django for a very long
# time.
def is_enabled(self, request: Request):
return bool(re.match(self.server.cors_urls_regex, request.path))
def add_response_headers(self, request: Request, response: HttpResponse):
enabled = getattr(request, "_cors_enabled", None)
if enabled is None:
enabled = self.is_enabled(request)
if not enabled:
return response
if "vary" in response.headers:
response.headers["vary"].append("origin")
else:
response.headers["vary"] = ["origin"]
origin = request.headers.get("origin")
if not origin:
return response
try:
url = urlsplit(origin)
except ValueError:
return response
if (
not self.server.cors_allow_all_origins
and not self.origin_found_in_allow_lists(origin, url)
):
return response
if (
self.server.cors_allow_all_origins
and not self.server.cors_allow_credentials
):
response.headers[ACCESS_CONTROL_ALLOW_ORIGIN] = "*"
else:
response.headers[ACCESS_CONTROL_ALLOW_ORIGIN] = origin
if self.server.cors_allow_credentials:
response.headers[ACCESS_CONTROL_ALLOW_CREDENTIALS] = "true"
if len(self.server.cors_expose_headers):
response.headers[ACCESS_CONTROL_EXPOSE_HEADERS] = ", ".join(
self.server.cors_expose_headers
)
if request.method == "OPTIONS":
response.headers[ACCESS_CONTROL_ALLOW_HEADERS] = ", ".join(
self.server.cors_allow_headers
)
response.headers[ACCESS_CONTROL_ALLOW_METHODS] = ", ".join(
self.server.cors_allow_methods
)
if self.server.cors_preflight_max_age:
response.headers[ACCESS_CONTROL_MAX_AGE] = str(
self.server.cors_preflight_max_age
)
if (
self.server.cors_allow_private_network
and request.headers.get(ACCESS_CONTROL_REQUEST_PRIVATE_NETWORK) == "true"
):
response.headers[ACCESS_CONTROL_ALLOW_PRIVATE_NETWORK] = "true"
return response
def origin_found_in_allow_lists(self, origin: str, url: SplitResult) -> bool:
return (
(origin == "null" and origin in self.server.cors_allowed_origins)
or self._url_in_allowlist(url)
or self.regex_domain_match(origin)
)
def _url_in_allowlist(self, url: SplitResult) -> bool:
origins = [urlsplit(o) for o in self.server.cors_allowed_origins]
return any(
origin.scheme == url.scheme and origin.netloc == url.netloc
for origin in origins
)
def regex_domain_match(self, origin: str) -> bool:
return any(
re.match(domain_pattern, origin)
for domain_pattern in self.server.cors_allowed_origin_regexes
)
def process_request(self, request: Request) -> HttpResponse | None:
# Identify and handle a preflight request
# origin = request.META.get("HTTP_ORIGIN")
request._cors_enabled = self.is_enabled(request)
if (
request._cors_enabled
and request.method == "OPTIONS"
and "access-control-request-method" in request.headers
):
# this should be 204, but according to mozilla, not all browsers
# parse that correctly. See [204] comment below.
resp = HttpResponse(
"",
status_code=200,
headers={"content-type": "text/plain", "content-length": 0},
)
self.add_response_headers(request, resp)
return resp
def process_response(
self, request: Request, response: HttpResponse
) -> None:
self.add_response_headers(request, response)
# [204]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods/OPTIONS#status_code

View File

@ -1,5 +1,5 @@
import re
from typing import Callable, Any, Optional
from typing import Callable, Any, Optional, Sequence
from spiderweb.constants import DEFAULT_ALLOWED_METHODS
from spiderweb.converters import * # noqa: F403
@ -30,7 +30,7 @@ class RoutesMixin:
# ones that start with underscores are the compiled versions, non-underscores
# are the user-supplied versions
_routes: dict
routes: list[tuple[str, Callable] | tuple[str, Callable, dict]] = (None,)
routes: Sequence[tuple[str, Callable] | tuple[str, Callable, dict]]
_error_routes: dict
error_routes: dict[int, Callable]
append_slash: bool

View File

@ -4,12 +4,16 @@ from datetime import timedelta
import pytest
from peewee import SqliteDatabase
from spiderweb import SpiderwebRouter, HttpResponse, ConfigError
from spiderweb import SpiderwebRouter, HttpResponse, ConfigError, StartupErrors
from spiderweb.constants import DEFAULT_ENCODING
from spiderweb.middleware.sessions import Session
from spiderweb.middleware import csrf
from spiderweb.tests.helpers import setup
from spiderweb.tests.views_for_tests import form_view_with_csrf, form_csrf_exempt, form_view_without_csrf
from spiderweb.tests.views_for_tests import (
form_view_with_csrf,
form_csrf_exempt,
form_view_without_csrf,
)
# app = SpiderwebRouter(
@ -99,18 +103,21 @@ def test_exploding_middleware():
def test_csrf_middleware_without_session_middleware():
_, environ, start_response = setup()
with pytest.raises(ConfigError) as e:
with pytest.raises(StartupErrors) as e:
SpiderwebRouter(
middleware=["spiderweb.middleware.csrf.CSRFMiddleware"],
db=SqliteDatabase("spiderweb-tests.db"),
)
assert e.value.args[0] == csrf.SessionCheck.SESSION_MIDDLEWARE_NOT_FOUND
exceptiongroup = e.value.args[1]
assert (
exceptiongroup[0].args[0]
== csrf.CheckForSessionMiddleware.SESSION_MIDDLEWARE_NOT_FOUND
)
def test_csrf_middleware_above_session_middleware():
_, environ, start_response = setup()
with pytest.raises(ConfigError) as e:
with pytest.raises(StartupErrors) as e:
SpiderwebRouter(
middleware=[
"spiderweb.middleware.csrf.CSRFMiddleware",
@ -118,8 +125,11 @@ def test_csrf_middleware_above_session_middleware():
],
db=SqliteDatabase("spiderweb-tests.db"),
)
assert e.value.args[0] == csrf.SessionCheck.SESSION_MIDDLEWARE_BELOW_CSRF
exceptiongroup = e.value.args[1]
assert (
exceptiongroup[0].args[0]
== csrf.VerifyCorrectMiddlewarePlacement.SESSION_MIDDLEWARE_BELOW_CSRF
)
def test_csrf_middleware():
@ -211,6 +221,7 @@ def test_csrf_expired_token():
f"swsession={[i for i in Session.select().dicts()][-1]['session_key']}"
)
environ["REQUEST_METHOD"] = "POST"
environ["HTTP_ORIGIN"] = "example.com"
environ["HTTP_X_CSRF_TOKEN"] = token
environ["CONTENT_LENGTH"] = len(formdata)
@ -254,3 +265,44 @@ def test_csrf_exempt():
environ["PATH_INFO"] = "/2"
resp2 = app(environ, start_response)[0].decode(DEFAULT_ENCODING)
assert "CSRF token is invalid" in resp2
def test_csrf_trusted_origins():
_, environ, start_response = setup()
app = SpiderwebRouter(
middleware=[
"spiderweb.middleware.sessions.SessionMiddleware",
"spiderweb.middleware.csrf.CSRFMiddleware",
],
csrf_trusted_origins=[
"example.com",
],
db=SqliteDatabase("spiderweb-tests.db"),
)
app.add_route("/", form_view_without_csrf, ["GET", "POST"])
environ["HTTP_USER_AGENT"] = "hi"
environ["REMOTE_ADDR"] = "1.1.1.1"
environ["CONTENT_TYPE"] = "application/x-www-form-urlencoded"
environ["REQUEST_METHOD"] = "POST"
formdata = "name=bob"
environ["CONTENT_LENGTH"] = len(formdata)
b_handle = BytesIO()
b_handle.write(formdata.encode(DEFAULT_ENCODING))
b_handle.seek(0)
environ["wsgi.input"] = BufferedReader(b_handle)
environ["HTTP_ORIGIN"] = "notvalid.com"
resp = app(environ, start_response)[0].decode(DEFAULT_ENCODING)
assert "CSRF token is invalid" in resp
b_handle = BytesIO()
b_handle.write(formdata.encode(DEFAULT_ENCODING))
b_handle.seek(0)
environ["wsgi.input"] = BufferedReader(b_handle)
environ["HTTP_ORIGIN"] = "example.com"
resp2 = app(environ, start_response)[0].decode(DEFAULT_ENCODING)
assert resp2 == '{"name": "bob"}'

View File

@ -1,4 +1,5 @@
import json
import re
import secrets
import string
from http import HTTPStatus
@ -78,3 +79,11 @@ class Headers(dict):
def setdefault(self, key, default=None):
return super().setdefault(key.lower(), default)
def convert_url_to_regex(url: str | re.Pattern) -> re.Pattern:
if isinstance(url, re.Pattern):
return url
url = url.replace(".", "\\.")
url = url.replace("*", ".+")
return re.compile(url)