CORS! #1
@ -15,6 +15,7 @@ from spiderweb.response import (
|
|||||||
app = SpiderwebRouter(
|
app = SpiderwebRouter(
|
||||||
templates_dirs=["templates"],
|
templates_dirs=["templates"],
|
||||||
middleware=[
|
middleware=[
|
||||||
|
"spiderweb.middleware.cors.CorsMiddleware",
|
||||||
"spiderweb.middleware.sessions.SessionMiddleware",
|
"spiderweb.middleware.sessions.SessionMiddleware",
|
||||||
"spiderweb.middleware.csrf.CSRFMiddleware",
|
"spiderweb.middleware.csrf.CSRFMiddleware",
|
||||||
"example_middleware.TestMiddleware",
|
"example_middleware.TestMiddleware",
|
||||||
|
@ -8,3 +8,20 @@ __version__ = "0.12.0"
|
|||||||
REGEX_COOKIE_NAME = r"^[a-zA-Z0-9\s\(\)<>@,;:\/\\\[\]\?=\{\}\"\t]*$"
|
REGEX_COOKIE_NAME = r"^[a-zA-Z0-9\s\(\)<>@,;:\/\\\[\]\?=\{\}\"\t]*$"
|
||||||
|
|
||||||
DATABASE_PROXY = DatabaseProxy()
|
DATABASE_PROXY = DatabaseProxy()
|
||||||
|
|
||||||
|
DEFAULT_CORS_ALLOW_METHODS = (
|
||||||
|
"DELETE",
|
||||||
|
"GET",
|
||||||
|
"OPTIONS",
|
||||||
|
"PATCH",
|
||||||
|
"POST",
|
||||||
|
"PUT",
|
||||||
|
)
|
||||||
|
DEFAULT_CORS_ALLOW_HEADERS = (
|
||||||
|
"accept",
|
||||||
|
"authorization",
|
||||||
|
"content-type",
|
||||||
|
"user-agent",
|
||||||
|
"x-csrftoken",
|
||||||
|
"x-requested-with",
|
||||||
|
)
|
||||||
|
@ -86,3 +86,7 @@ class UnusedMiddleware(SpiderwebException):
|
|||||||
|
|
||||||
class NoResponseError(SpiderwebException):
|
class NoResponseError(SpiderwebException):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class StartupErrors(ExceptionGroup):
|
||||||
|
pass
|
||||||
|
@ -1,16 +1,22 @@
|
|||||||
import inspect
|
import inspect
|
||||||
import logging
|
import logging
|
||||||
import pathlib
|
import pathlib
|
||||||
|
import re
|
||||||
import traceback
|
import traceback
|
||||||
import urllib.parse as urlparse
|
import urllib.parse as urlparse
|
||||||
|
from logging import Logger
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
from typing import Optional, Callable
|
from typing import Optional, Callable, Sequence, LiteralString, Literal
|
||||||
from wsgiref.simple_server import WSGIServer
|
from wsgiref.simple_server import WSGIServer
|
||||||
|
|
||||||
from jinja2 import BaseLoader, Environment, FileSystemLoader
|
from jinja2 import BaseLoader, Environment, FileSystemLoader
|
||||||
from peewee import Database, SqliteDatabase
|
from peewee import Database, SqliteDatabase
|
||||||
|
|
||||||
from spiderweb.middleware import MiddlewareMixin
|
from spiderweb.middleware import MiddlewareMixin
|
||||||
|
from spiderweb.constants import (
|
||||||
|
DEFAULT_CORS_ALLOW_METHODS,
|
||||||
|
DEFAULT_CORS_ALLOW_HEADERS,
|
||||||
|
)
|
||||||
from spiderweb.constants import (
|
from spiderweb.constants import (
|
||||||
DATABASE_PROXY,
|
DATABASE_PROXY,
|
||||||
DEFAULT_ENCODING,
|
DEFAULT_ENCODING,
|
||||||
@ -30,7 +36,7 @@ from spiderweb.request import Request
|
|||||||
from spiderweb.response import HttpResponse, TemplateResponse, JsonResponse
|
from spiderweb.response import HttpResponse, TemplateResponse, JsonResponse
|
||||||
from spiderweb.routes import RoutesMixin
|
from spiderweb.routes import RoutesMixin
|
||||||
from spiderweb.secrets import FernetMixin
|
from spiderweb.secrets import FernetMixin
|
||||||
from spiderweb.utils import get_http_status_by_code
|
from spiderweb.utils import get_http_status_by_code, convert_url_to_regex
|
||||||
|
|
||||||
console_logger = logging.getLogger(__name__)
|
console_logger = logging.getLogger(__name__)
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
@ -42,25 +48,32 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi
|
|||||||
*,
|
*,
|
||||||
addr: str = None,
|
addr: str = None,
|
||||||
port: int = None,
|
port: int = None,
|
||||||
allowed_hosts=None,
|
allowed_hosts: Sequence[str | re.Pattern] = None,
|
||||||
cors_allowed_origins=None,
|
cors_allowed_origins: Sequence[str] = None,
|
||||||
cors_allow_all_origins=False,
|
cors_allowed_origins_regexes: Sequence[str] = None,
|
||||||
|
cors_allow_all_origins: bool = False,
|
||||||
|
cors_urls_regex: str | re.Pattern[str] = r"^.*$",
|
||||||
|
cors_allow_methods: Sequence[str] = None,
|
||||||
|
cors_allow_headers: Sequence[str] = None,
|
||||||
|
cors_expose_headers: Sequence[str] = None,
|
||||||
|
cors_preflight_max_age: int = 86400,
|
||||||
|
cors_allow_credentials: bool = False,
|
||||||
csrf_trusted_origins: Sequence[str] = None,
|
csrf_trusted_origins: Sequence[str] = None,
|
||||||
db: Optional[Database] = None,
|
db: Optional[Database] = None,
|
||||||
templates_dirs: list[str] = None,
|
templates_dirs: Sequence[str] = None,
|
||||||
middleware: list[str] = None,
|
middleware: Sequence[str] = None,
|
||||||
append_slash: bool = False,
|
append_slash: bool = False,
|
||||||
staticfiles_dirs: list[str] = None,
|
staticfiles_dirs: Sequence[str] = None,
|
||||||
routes: list[tuple[str, Callable] | tuple[str, Callable, dict]] = None,
|
routes: Sequence[tuple[str, Callable] | tuple[str, Callable, dict]] = None,
|
||||||
error_routes: dict[int, Callable] = None,
|
error_routes: dict[int, Callable] = None,
|
||||||
secret_key: str = None,
|
secret_key: str = None,
|
||||||
session_max_age=60 * 60 * 24 * 14, # 2 weeks
|
session_max_age: int = 60 * 60 * 24 * 14, # 2 weeks
|
||||||
session_cookie_name="swsession",
|
session_cookie_name: str = "swsession",
|
||||||
session_cookie_secure=False, # should be true if serving over HTTPS
|
session_cookie_secure: bool = False, # should be true if serving over HTTPS
|
||||||
session_cookie_http_only=True,
|
session_cookie_http_only: bool = True,
|
||||||
session_cookie_same_site="lax",
|
session_cookie_same_site: Literal["strict", "lax", "none"] = "lax",
|
||||||
session_cookie_path="/",
|
session_cookie_path: str = "/",
|
||||||
log=None,
|
log: Logger = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
self._routes = {}
|
self._routes = {}
|
||||||
@ -80,7 +93,15 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi
|
|||||||
self.allowed_hosts = [convert_url_to_regex(i) for i in self._allowed_hosts]
|
self.allowed_hosts = [convert_url_to_regex(i) for i in self._allowed_hosts]
|
||||||
|
|
||||||
self.cors_allowed_origins = cors_allowed_origins or []
|
self.cors_allowed_origins = cors_allowed_origins or []
|
||||||
|
self.cors_allowed_origins_regexes = cors_allowed_origins_regexes or []
|
||||||
self.cors_allow_all_origins = cors_allow_all_origins
|
self.cors_allow_all_origins = cors_allow_all_origins
|
||||||
|
self.cors_urls_regex = cors_urls_regex
|
||||||
|
self.cors_allow_methods = cors_allow_methods or DEFAULT_CORS_ALLOW_METHODS
|
||||||
|
self.cors_allow_headers = cors_allow_headers or DEFAULT_CORS_ALLOW_HEADERS
|
||||||
|
self.cors_expose_headers = cors_expose_headers or []
|
||||||
|
self.cors_preflight_max_age = cors_preflight_max_age
|
||||||
|
self.cors_allow_credentials = cors_allow_credentials
|
||||||
|
|
||||||
self._csrf_trusted_origins = csrf_trusted_origins or []
|
self._csrf_trusted_origins = csrf_trusted_origins or []
|
||||||
self.csrf_trusted_origins = [
|
self.csrf_trusted_origins = [
|
||||||
convert_url_to_regex(i) for i in self._csrf_trusted_origins
|
convert_url_to_regex(i) for i in self._csrf_trusted_origins
|
||||||
|
@ -1,9 +1,11 @@
|
|||||||
from typing import Callable, ClassVar
|
from typing import Callable, ClassVar
|
||||||
|
import sys
|
||||||
|
|
||||||
from .base import SpiderwebMiddleware as SpiderwebMiddleware
|
from .base import SpiderwebMiddleware as SpiderwebMiddleware
|
||||||
|
from .cors import CorsMiddleware as CorsMiddleware
|
||||||
from .csrf import CSRFMiddleware as CSRFMiddleware
|
from .csrf import CSRFMiddleware as CSRFMiddleware
|
||||||
from .sessions import SessionMiddleware as SessionMiddleware
|
from .sessions import SessionMiddleware as SessionMiddleware
|
||||||
from ..exceptions import ConfigError, UnusedMiddleware
|
from ..exceptions import ConfigError, UnusedMiddleware, StartupErrors
|
||||||
from ..request import Request
|
from ..request import Request
|
||||||
from ..response import HttpResponse
|
from ..response import HttpResponse
|
||||||
from ..utils import import_by_string
|
from ..utils import import_by_string
|
||||||
@ -27,10 +29,19 @@ class MiddlewareMixin:
|
|||||||
self.middleware = middleware_by_reference
|
self.middleware = middleware_by_reference
|
||||||
|
|
||||||
def run_middleware_checks(self):
|
def run_middleware_checks(self):
|
||||||
|
errors = []
|
||||||
for middleware in self.middleware:
|
for middleware in self.middleware:
|
||||||
if hasattr(middleware, "checks"):
|
if hasattr(middleware, "checks"):
|
||||||
for check in middleware.checks:
|
for check in middleware.checks:
|
||||||
check(server=self).check()
|
if issue := check(server=self).check():
|
||||||
|
errors.append(issue)
|
||||||
|
|
||||||
|
if errors:
|
||||||
|
# just show the messages
|
||||||
|
sys.tracebacklimit = 0
|
||||||
|
raise StartupErrors(
|
||||||
|
"Problems were identified during startup — cannot continue.", errors
|
||||||
|
)
|
||||||
|
|
||||||
def process_request_middleware(self, request: Request) -> None | bool:
|
def process_request_middleware(self, request: Request) -> None | bool:
|
||||||
for middleware in self.middleware:
|
for middleware in self.middleware:
|
||||||
|
@ -1 +1,137 @@
|
|||||||
# https://gist.github.com/FND/204ba41bf6ae485965ef
|
import re
|
||||||
|
from urllib.parse import urlsplit, SplitResult
|
||||||
|
|
||||||
|
from spiderweb.request import Request
|
||||||
|
from spiderweb.response import HttpResponse
|
||||||
|
from spiderweb.middleware import SpiderwebMiddleware
|
||||||
|
|
||||||
|
ACCESS_CONTROL_ALLOW_ORIGIN = "access-control-allow-origin"
|
||||||
|
ACCESS_CONTROL_EXPOSE_HEADERS = "access-control-expose-headers"
|
||||||
|
ACCESS_CONTROL_ALLOW_CREDENTIALS = "access-control-allow-credentials"
|
||||||
|
ACCESS_CONTROL_ALLOW_HEADERS = "access-control-allow-headers"
|
||||||
|
ACCESS_CONTROL_ALLOW_METHODS = "access-control-allow-methods"
|
||||||
|
ACCESS_CONTROL_MAX_AGE = "access-control-max-age"
|
||||||
|
ACCESS_CONTROL_REQUEST_PRIVATE_NETWORK = "access-control-request-private-network"
|
||||||
|
ACCESS_CONTROL_ALLOW_PRIVATE_NETWORK = "access-control-allow-private-network"
|
||||||
|
|
||||||
|
|
||||||
|
class CorsMiddleware(SpiderwebMiddleware):
|
||||||
|
# heavily 'based' on https://github.com/adamchainz/django-cors-headers,
|
||||||
|
# which is provided under the MIT license. This is essentially a direct
|
||||||
|
# port, since django-cors-headers is battle-tested code that has been
|
||||||
|
# around for a long time and it works well. Shoutouts to Otto, Adam, and
|
||||||
|
# crew for helping make this a complete non-issue in Django for a very long
|
||||||
|
# time.
|
||||||
|
|
||||||
|
def is_enabled(self, request: Request):
|
||||||
|
return bool(re.match(self.server.cors_urls_regex, request.path))
|
||||||
|
|
||||||
|
def add_response_headers(self, request: Request, response: HttpResponse):
|
||||||
|
enabled = getattr(request, "_cors_enabled", None)
|
||||||
|
if enabled is None:
|
||||||
|
enabled = self.is_enabled(request)
|
||||||
|
|
||||||
|
if not enabled:
|
||||||
|
return response
|
||||||
|
|
||||||
|
if "vary" in response.headers:
|
||||||
|
response.headers["vary"].append("origin")
|
||||||
|
else:
|
||||||
|
response.headers["vary"] = ["origin"]
|
||||||
|
|
||||||
|
origin = request.headers.get("origin")
|
||||||
|
if not origin:
|
||||||
|
return response
|
||||||
|
|
||||||
|
try:
|
||||||
|
url = urlsplit(origin)
|
||||||
|
except ValueError:
|
||||||
|
return response
|
||||||
|
|
||||||
|
if (
|
||||||
|
not self.server.cors_allow_all_origins
|
||||||
|
and not self.origin_found_in_allow_lists(origin, url)
|
||||||
|
):
|
||||||
|
return response
|
||||||
|
|
||||||
|
if (
|
||||||
|
self.server.cors_allow_all_origins
|
||||||
|
and not self.server.cors_allow_credentials
|
||||||
|
):
|
||||||
|
response.headers[ACCESS_CONTROL_ALLOW_ORIGIN] = "*"
|
||||||
|
else:
|
||||||
|
response.headers[ACCESS_CONTROL_ALLOW_ORIGIN] = origin
|
||||||
|
|
||||||
|
if self.server.cors_allow_credentials:
|
||||||
|
response.headers[ACCESS_CONTROL_ALLOW_CREDENTIALS] = "true"
|
||||||
|
|
||||||
|
if len(self.server.cors_expose_headers):
|
||||||
|
response.headers[ACCESS_CONTROL_EXPOSE_HEADERS] = ", ".join(
|
||||||
|
self.server.cors_expose_headers
|
||||||
|
)
|
||||||
|
|
||||||
|
if request.method == "OPTIONS":
|
||||||
|
response.headers[ACCESS_CONTROL_ALLOW_HEADERS] = ", ".join(
|
||||||
|
self.server.cors_allow_headers
|
||||||
|
)
|
||||||
|
response.headers[ACCESS_CONTROL_ALLOW_METHODS] = ", ".join(
|
||||||
|
self.server.cors_allow_methods
|
||||||
|
)
|
||||||
|
if self.server.cors_preflight_max_age:
|
||||||
|
response.headers[ACCESS_CONTROL_MAX_AGE] = str(
|
||||||
|
self.server.cors_preflight_max_age
|
||||||
|
)
|
||||||
|
|
||||||
|
if (
|
||||||
|
self.server.cors_allow_private_network
|
||||||
|
and request.headers.get(ACCESS_CONTROL_REQUEST_PRIVATE_NETWORK) == "true"
|
||||||
|
):
|
||||||
|
response.headers[ACCESS_CONTROL_ALLOW_PRIVATE_NETWORK] = "true"
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
|
def origin_found_in_allow_lists(self, origin: str, url: SplitResult) -> bool:
|
||||||
|
return (
|
||||||
|
(origin == "null" and origin in self.server.cors_allowed_origins)
|
||||||
|
or self._url_in_allowlist(url)
|
||||||
|
or self.regex_domain_match(origin)
|
||||||
|
)
|
||||||
|
|
||||||
|
def _url_in_allowlist(self, url: SplitResult) -> bool:
|
||||||
|
origins = [urlsplit(o) for o in self.server.cors_allowed_origins]
|
||||||
|
return any(
|
||||||
|
origin.scheme == url.scheme and origin.netloc == url.netloc
|
||||||
|
for origin in origins
|
||||||
|
)
|
||||||
|
|
||||||
|
def regex_domain_match(self, origin: str) -> bool:
|
||||||
|
return any(
|
||||||
|
re.match(domain_pattern, origin)
|
||||||
|
for domain_pattern in self.server.cors_allowed_origin_regexes
|
||||||
|
)
|
||||||
|
|
||||||
|
def process_request(self, request: Request) -> HttpResponse | None:
|
||||||
|
# Identify and handle a preflight request
|
||||||
|
# origin = request.META.get("HTTP_ORIGIN")
|
||||||
|
request._cors_enabled = self.is_enabled(request)
|
||||||
|
if (
|
||||||
|
request._cors_enabled
|
||||||
|
and request.method == "OPTIONS"
|
||||||
|
and "access-control-request-method" in request.headers
|
||||||
|
):
|
||||||
|
# this should be 204, but according to mozilla, not all browsers
|
||||||
|
# parse that correctly. See [204] comment below.
|
||||||
|
resp = HttpResponse(
|
||||||
|
"",
|
||||||
|
status_code=200,
|
||||||
|
headers={"content-type": "text/plain", "content-length": 0},
|
||||||
|
)
|
||||||
|
self.add_response_headers(request, resp)
|
||||||
|
return resp
|
||||||
|
|
||||||
|
def process_response(
|
||||||
|
self, request: Request, response: HttpResponse
|
||||||
|
) -> None:
|
||||||
|
self.add_response_headers(request, response)
|
||||||
|
|
||||||
|
# [204]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods/OPTIONS#status_code
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
import re
|
import re
|
||||||
from typing import Callable, Any, Optional
|
from typing import Callable, Any, Optional, Sequence
|
||||||
|
|
||||||
from spiderweb.constants import DEFAULT_ALLOWED_METHODS
|
from spiderweb.constants import DEFAULT_ALLOWED_METHODS
|
||||||
from spiderweb.converters import * # noqa: F403
|
from spiderweb.converters import * # noqa: F403
|
||||||
@ -30,7 +30,7 @@ class RoutesMixin:
|
|||||||
# ones that start with underscores are the compiled versions, non-underscores
|
# ones that start with underscores are the compiled versions, non-underscores
|
||||||
# are the user-supplied versions
|
# are the user-supplied versions
|
||||||
_routes: dict
|
_routes: dict
|
||||||
routes: list[tuple[str, Callable] | tuple[str, Callable, dict]] = (None,)
|
routes: Sequence[tuple[str, Callable] | tuple[str, Callable, dict]]
|
||||||
_error_routes: dict
|
_error_routes: dict
|
||||||
error_routes: dict[int, Callable]
|
error_routes: dict[int, Callable]
|
||||||
append_slash: bool
|
append_slash: bool
|
||||||
|
@ -4,12 +4,16 @@ from datetime import timedelta
|
|||||||
import pytest
|
import pytest
|
||||||
from peewee import SqliteDatabase
|
from peewee import SqliteDatabase
|
||||||
|
|
||||||
from spiderweb import SpiderwebRouter, HttpResponse, ConfigError
|
from spiderweb import SpiderwebRouter, HttpResponse, ConfigError, StartupErrors
|
||||||
from spiderweb.constants import DEFAULT_ENCODING
|
from spiderweb.constants import DEFAULT_ENCODING
|
||||||
from spiderweb.middleware.sessions import Session
|
from spiderweb.middleware.sessions import Session
|
||||||
from spiderweb.middleware import csrf
|
from spiderweb.middleware import csrf
|
||||||
from spiderweb.tests.helpers import setup
|
from spiderweb.tests.helpers import setup
|
||||||
from spiderweb.tests.views_for_tests import form_view_with_csrf, form_csrf_exempt, form_view_without_csrf
|
from spiderweb.tests.views_for_tests import (
|
||||||
|
form_view_with_csrf,
|
||||||
|
form_csrf_exempt,
|
||||||
|
form_view_without_csrf,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# app = SpiderwebRouter(
|
# app = SpiderwebRouter(
|
||||||
@ -99,18 +103,21 @@ def test_exploding_middleware():
|
|||||||
|
|
||||||
def test_csrf_middleware_without_session_middleware():
|
def test_csrf_middleware_without_session_middleware():
|
||||||
_, environ, start_response = setup()
|
_, environ, start_response = setup()
|
||||||
with pytest.raises(ConfigError) as e:
|
with pytest.raises(StartupErrors) as e:
|
||||||
SpiderwebRouter(
|
SpiderwebRouter(
|
||||||
middleware=["spiderweb.middleware.csrf.CSRFMiddleware"],
|
middleware=["spiderweb.middleware.csrf.CSRFMiddleware"],
|
||||||
db=SqliteDatabase("spiderweb-tests.db"),
|
db=SqliteDatabase("spiderweb-tests.db"),
|
||||||
)
|
)
|
||||||
|
exceptiongroup = e.value.args[1]
|
||||||
assert e.value.args[0] == csrf.SessionCheck.SESSION_MIDDLEWARE_NOT_FOUND
|
assert (
|
||||||
|
exceptiongroup[0].args[0]
|
||||||
|
== csrf.CheckForSessionMiddleware.SESSION_MIDDLEWARE_NOT_FOUND
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_csrf_middleware_above_session_middleware():
|
def test_csrf_middleware_above_session_middleware():
|
||||||
_, environ, start_response = setup()
|
_, environ, start_response = setup()
|
||||||
with pytest.raises(ConfigError) as e:
|
with pytest.raises(StartupErrors) as e:
|
||||||
SpiderwebRouter(
|
SpiderwebRouter(
|
||||||
middleware=[
|
middleware=[
|
||||||
"spiderweb.middleware.csrf.CSRFMiddleware",
|
"spiderweb.middleware.csrf.CSRFMiddleware",
|
||||||
@ -118,8 +125,11 @@ def test_csrf_middleware_above_session_middleware():
|
|||||||
],
|
],
|
||||||
db=SqliteDatabase("spiderweb-tests.db"),
|
db=SqliteDatabase("spiderweb-tests.db"),
|
||||||
)
|
)
|
||||||
|
exceptiongroup = e.value.args[1]
|
||||||
assert e.value.args[0] == csrf.SessionCheck.SESSION_MIDDLEWARE_BELOW_CSRF
|
assert (
|
||||||
|
exceptiongroup[0].args[0]
|
||||||
|
== csrf.VerifyCorrectMiddlewarePlacement.SESSION_MIDDLEWARE_BELOW_CSRF
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_csrf_middleware():
|
def test_csrf_middleware():
|
||||||
@ -211,6 +221,7 @@ def test_csrf_expired_token():
|
|||||||
f"swsession={[i for i in Session.select().dicts()][-1]['session_key']}"
|
f"swsession={[i for i in Session.select().dicts()][-1]['session_key']}"
|
||||||
)
|
)
|
||||||
environ["REQUEST_METHOD"] = "POST"
|
environ["REQUEST_METHOD"] = "POST"
|
||||||
|
environ["HTTP_ORIGIN"] = "example.com"
|
||||||
environ["HTTP_X_CSRF_TOKEN"] = token
|
environ["HTTP_X_CSRF_TOKEN"] = token
|
||||||
environ["CONTENT_LENGTH"] = len(formdata)
|
environ["CONTENT_LENGTH"] = len(formdata)
|
||||||
|
|
||||||
@ -254,3 +265,44 @@ def test_csrf_exempt():
|
|||||||
environ["PATH_INFO"] = "/2"
|
environ["PATH_INFO"] = "/2"
|
||||||
resp2 = app(environ, start_response)[0].decode(DEFAULT_ENCODING)
|
resp2 = app(environ, start_response)[0].decode(DEFAULT_ENCODING)
|
||||||
assert "CSRF token is invalid" in resp2
|
assert "CSRF token is invalid" in resp2
|
||||||
|
|
||||||
|
|
||||||
|
def test_csrf_trusted_origins():
|
||||||
|
_, environ, start_response = setup()
|
||||||
|
app = SpiderwebRouter(
|
||||||
|
middleware=[
|
||||||
|
"spiderweb.middleware.sessions.SessionMiddleware",
|
||||||
|
"spiderweb.middleware.csrf.CSRFMiddleware",
|
||||||
|
],
|
||||||
|
csrf_trusted_origins=[
|
||||||
|
"example.com",
|
||||||
|
],
|
||||||
|
db=SqliteDatabase("spiderweb-tests.db"),
|
||||||
|
)
|
||||||
|
|
||||||
|
app.add_route("/", form_view_without_csrf, ["GET", "POST"])
|
||||||
|
|
||||||
|
environ["HTTP_USER_AGENT"] = "hi"
|
||||||
|
environ["REMOTE_ADDR"] = "1.1.1.1"
|
||||||
|
environ["CONTENT_TYPE"] = "application/x-www-form-urlencoded"
|
||||||
|
environ["REQUEST_METHOD"] = "POST"
|
||||||
|
|
||||||
|
formdata = "name=bob"
|
||||||
|
environ["CONTENT_LENGTH"] = len(formdata)
|
||||||
|
b_handle = BytesIO()
|
||||||
|
b_handle.write(formdata.encode(DEFAULT_ENCODING))
|
||||||
|
b_handle.seek(0)
|
||||||
|
environ["wsgi.input"] = BufferedReader(b_handle)
|
||||||
|
|
||||||
|
environ["HTTP_ORIGIN"] = "notvalid.com"
|
||||||
|
resp = app(environ, start_response)[0].decode(DEFAULT_ENCODING)
|
||||||
|
assert "CSRF token is invalid" in resp
|
||||||
|
|
||||||
|
b_handle = BytesIO()
|
||||||
|
b_handle.write(formdata.encode(DEFAULT_ENCODING))
|
||||||
|
b_handle.seek(0)
|
||||||
|
environ["wsgi.input"] = BufferedReader(b_handle)
|
||||||
|
|
||||||
|
environ["HTTP_ORIGIN"] = "example.com"
|
||||||
|
resp2 = app(environ, start_response)[0].decode(DEFAULT_ENCODING)
|
||||||
|
assert resp2 == '{"name": "bob"}'
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
import json
|
import json
|
||||||
|
import re
|
||||||
import secrets
|
import secrets
|
||||||
import string
|
import string
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
@ -78,3 +79,11 @@ class Headers(dict):
|
|||||||
|
|
||||||
def setdefault(self, key, default=None):
|
def setdefault(self, key, default=None):
|
||||||
return super().setdefault(key.lower(), default)
|
return super().setdefault(key.lower(), default)
|
||||||
|
|
||||||
|
|
||||||
|
def convert_url_to_regex(url: str | re.Pattern) -> re.Pattern:
|
||||||
|
if isinstance(url, re.Pattern):
|
||||||
|
return url
|
||||||
|
url = url.replace(".", "\\.")
|
||||||
|
url = url.replace("*", ".+")
|
||||||
|
return re.compile(url)
|
||||||
|
Loading…
Reference in New Issue
Block a user