add tests for cors and get coverage to 89%

This commit is contained in:
Joe Kaufeld 2024-09-02 17:34:50 -04:00
parent 5cf9dff13a
commit f9225848a6
11 changed files with 516 additions and 39 deletions

View File

@ -25,6 +25,7 @@ app = SpiderwebRouter(
], ],
staticfiles_dirs=["static_files"], staticfiles_dirs=["static_files"],
append_slash=False, # default append_slash=False, # default
cors_allow_all_origins=True,
) )

View File

@ -1,8 +1,8 @@
from peewee import DatabaseProxy from peewee import DatabaseProxy
DEFAULT_ALLOWED_METHODS = ["GET"] DEFAULT_ALLOWED_METHODS = ["POST", "GET", "PUT", "PATCH", "DELETE"]
DEFAULT_ENCODING = "UTF-8" DEFAULT_ENCODING = "UTF-8"
__version__ = "0.12.0" __version__ = "1.0.0"
# https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Set-Cookie # https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Set-Cookie
REGEX_COOKIE_NAME = r"^[a-zA-Z0-9\s\(\)<>@,;:\/\\\[\]\?=\{\}\"\t]*$" REGEX_COOKIE_NAME = r"^[a-zA-Z0-9\s\(\)<>@,;:\/\\\[\]\?=\{\}\"\t]*$"

View File

@ -50,7 +50,7 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi
port: int = None, port: int = None,
allowed_hosts: Sequence[str | re.Pattern] = None, allowed_hosts: Sequence[str | re.Pattern] = None,
cors_allowed_origins: Sequence[str] = None, cors_allowed_origins: Sequence[str] = None,
cors_allowed_origins_regexes: Sequence[str] = None, cors_allowed_origin_regexes: Sequence[str] = None,
cors_allow_all_origins: bool = False, cors_allow_all_origins: bool = False,
cors_urls_regex: str | re.Pattern[str] = r"^.*$", cors_urls_regex: str | re.Pattern[str] = r"^.*$",
cors_allow_methods: Sequence[str] = None, cors_allow_methods: Sequence[str] = None,
@ -94,7 +94,7 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi
self.allowed_hosts = [convert_url_to_regex(i) for i in self._allowed_hosts] self.allowed_hosts = [convert_url_to_regex(i) for i in self._allowed_hosts]
self.cors_allowed_origins = cors_allowed_origins or [] self.cors_allowed_origins = cors_allowed_origins or []
self.cors_allowed_origins_regexes = cors_allowed_origins_regexes or [] self.cors_allowed_origin_regexes = cors_allowed_origin_regexes or []
self.cors_allow_all_origins = cors_allow_all_origins self.cors_allow_all_origins = cors_allow_all_origins
self.cors_urls_regex = cors_urls_regex self.cors_urls_regex = cors_urls_regex
self.cors_allow_methods = cors_allow_methods or DEFAULT_CORS_ALLOW_METHODS self.cors_allow_methods = cors_allow_methods or DEFAULT_CORS_ALLOW_METHODS
@ -171,17 +171,19 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi
status = get_http_status_by_code(resp.status_code) status = get_http_status_by_code(resp.status_code)
cookies = [] cookies = []
varies = [] varies = []
resp.headers = {k.replace("_", "-"): v for k, v in resp.headers.items()}
if "set-cookie" in resp.headers: if "set-cookie" in resp.headers:
cookies = resp.headers["set-cookie"] cookies = resp.headers["set-cookie"]
del resp.headers["set-cookie"] del resp.headers["set-cookie"]
if "vary" in resp.headers: if "vary" in resp.headers:
varies = resp.headers["vary"] varies = resp.headers["vary"]
del resp.headers["vary"] del resp.headers["vary"]
resp.headers = {k: str(v) for k, v in resp.headers.items()}
headers = list(resp.headers.items()) headers = list(resp.headers.items())
for c in cookies: for c in cookies:
headers.append(("Set-Cookie", c)) headers.append(("set-cookie", str(c)))
for v in varies: for v in varies:
headers.append(("Vary", v)) headers.append(("vary", str(v)))
start_response(status, headers) start_response(status, headers)
@ -271,7 +273,6 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi
def __call__(self, environ, start_response, *args, **kwargs): def __call__(self, environ, start_response, *args, **kwargs):
"""Entry point for WSGI apps.""" """Entry point for WSGI apps."""
request = self.get_request(environ) request = self.get_request(environ)
try: try:
handler, additional_args, allowed_methods = self.get_route(request.path) handler, additional_args, allowed_methods = self.get_route(request.path)
except NotFound: except NotFound:

View File

@ -38,7 +38,7 @@ class MiddlewareMixin:
if errors: if errors:
# just show the messages # just show the messages
sys.tracebacklimit = 0 sys.tracebacklimit = 1
raise StartupErrors( raise StartupErrors(
"Problems were identified during startup — cannot continue.", errors "Problems were identified during startup — cannot continue.", errors
) )

View File

@ -25,12 +25,9 @@ class VerifyValidCorsSetting(ServerCheck):
) )
def check(self): def check(self):
# - `cors_allowed_origins`
# - `cors_allowed_origin_regexes`
# - `cors_allow_all_origins`
if ( if (
not self.server.cors_allowed_origins not self.server.cors_allowed_origins
and not self.server.cors.allowed_origin_regexes and not self.server.cors_allowed_origin_regexes
and not self.server.cors_allow_all_origins and not self.server.cors_allow_all_origins
): ):
return ConfigError(self.INVALID_BASE_CONFIG) return ConfigError(self.INVALID_BASE_CONFIG)
@ -52,7 +49,6 @@ class CorsMiddleware(SpiderwebMiddleware):
enabled = getattr(request, "_cors_enabled", None) enabled = getattr(request, "_cors_enabled", None)
if enabled is None: if enabled is None:
enabled = self.is_enabled(request) enabled = self.is_enabled(request)
if not enabled: if not enabled:
return response return response
@ -61,7 +57,7 @@ class CorsMiddleware(SpiderwebMiddleware):
else: else:
response.headers["vary"] = ["origin"] response.headers["vary"] = ["origin"]
origin = request.headers.get("origin") origin = request.headers.get("http_origin")
if not origin: if not origin:
return response return response
@ -103,10 +99,9 @@ class CorsMiddleware(SpiderwebMiddleware):
response.headers[ACCESS_CONTROL_MAX_AGE] = str( response.headers[ACCESS_CONTROL_MAX_AGE] = str(
self.server.cors_preflight_max_age self.server.cors_preflight_max_age
) )
if ( if (
self.server.cors_allow_private_network self.server.cors_allow_private_network
and request.headers.get(ACCESS_CONTROL_REQUEST_PRIVATE_NETWORK) == "true" and request.headers.get(ACCESS_CONTROL_REQUEST_PRIVATE_NETWORK.replace("-", "_")) == "true"
): ):
response.headers[ACCESS_CONTROL_ALLOW_PRIVATE_NETWORK] = "true" response.headers[ACCESS_CONTROL_ALLOW_PRIVATE_NETWORK] = "true"
@ -134,8 +129,8 @@ class CorsMiddleware(SpiderwebMiddleware):
def process_request(self, request: Request) -> HttpResponse | None: def process_request(self, request: Request) -> HttpResponse | None:
# Identify and handle a preflight request # Identify and handle a preflight request
# origin = request.META.get("HTTP_ORIGIN")
request._cors_enabled = self.is_enabled(request) request._cors_enabled = self.is_enabled(request)
request.META["cors_ran"] = True
if ( if (
request._cors_enabled request._cors_enabled
and request.method == "OPTIONS" and request.method == "OPTIONS"
@ -152,6 +147,11 @@ class CorsMiddleware(SpiderwebMiddleware):
return resp return resp
def process_response(self, request: Request, response: HttpResponse) -> None: def process_response(self, request: Request, response: HttpResponse) -> None:
if not request.META.get("cors_ran"):
# something happened and process_request didn't run. Abort early.
# We're not relying on request._cors_enabled because it's more
# visible and the view may have destroyed it accidentally.
return
self.add_response_headers(request, response) self.add_response_headers(request, response)

View File

@ -90,13 +90,12 @@ class CSRFMiddleware(SpiderwebMiddleware):
def process_request(self, request: Request) -> HttpResponse | None: def process_request(self, request: Request) -> HttpResponse | None:
if request.method == "POST": if request.method == "POST":
if hasattr(request.handler, "csrf_exempt"): if hasattr(request.handler, "csrf_exempt"):
if request.handler.csrf_exempt is True: if request.handler.csrf_exempt is True:
return return
csrf_token = ( csrf_token = (
request.headers.get("X-CSRF-TOKEN") request.headers.get("x-csrf-token")
or request.GET.get("csrf_token") or request.GET.get("csrf_token")
or request.POST.get("csrf_token") or request.POST.get("csrf_token")
) )
@ -111,7 +110,7 @@ class CSRFMiddleware(SpiderwebMiddleware):
def process_response(self, request: Request, response: HttpResponse) -> None: def process_response(self, request: Request, response: HttpResponse) -> None:
token = self.get_csrf_token(request) token = self.get_csrf_token(request)
# do we need it in both places? # do we need it in both places?
response.headers["X-CSRF-TOKEN"] = token response.headers["x-csrf-token"] = token
response.context |= { response.context |= {
"csrf_token": f"""<input type="hidden" name="csrf_token" value="{token}">""", "csrf_token": f"""<input type="hidden" name="csrf_token" value="{token}">""",
"raw_csrf_token": token, # in case they want to format it themselves "raw_csrf_token": token, # in case they want to format it themselves

View File

@ -15,7 +15,7 @@ class StartResponse:
self.headers = headers self.headers = headers
def get_headers(self): def get_headers(self):
return {h[0]: h[1] for h in self.headers} return {h[0]: h[1] for h in self.headers} if self.headers else {}
def setup(**kwargs): def setup(**kwargs):

View File

@ -11,3 +11,8 @@ class ExplodingResponseMiddleware(SpiderwebMiddleware):
self, request: Request, response: HttpResponse self, request: Request, response: HttpResponse
) -> HttpResponse | None: ) -> HttpResponse | None:
raise UnusedMiddleware("Unfinished!") raise UnusedMiddleware("Unfinished!")
class InterruptingMiddleware(SpiderwebMiddleware):
def process_request(self, request: Request) -> HttpResponse:
return HttpResponse("Moo!")

View File

@ -6,28 +6,25 @@ from peewee import SqliteDatabase
from spiderweb import SpiderwebRouter, HttpResponse, StartupErrors from spiderweb import SpiderwebRouter, HttpResponse, StartupErrors
from spiderweb.constants import DEFAULT_ENCODING from spiderweb.constants import DEFAULT_ENCODING
from spiderweb.middleware.cors import (
ACCESS_CONTROL_ALLOW_ORIGIN,
ACCESS_CONTROL_ALLOW_HEADERS,
ACCESS_CONTROL_ALLOW_METHODS,
ACCESS_CONTROL_EXPOSE_HEADERS,
ACCESS_CONTROL_ALLOW_CREDENTIALS,
ACCESS_CONTROL_MAX_AGE,
ACCESS_CONTROL_ALLOW_PRIVATE_NETWORK,
)
from spiderweb.middleware.sessions import Session from spiderweb.middleware.sessions import Session
from spiderweb.middleware import csrf from spiderweb.middleware import csrf
from spiderweb.tests.helpers import setup from spiderweb.tests.helpers import setup
from spiderweb.tests.views_for_tests import ( from spiderweb.tests.views_for_tests import (
form_view_with_csrf, form_view_with_csrf,
form_csrf_exempt, form_csrf_exempt,
form_view_without_csrf, form_view_without_csrf, text_view, unauthorized_view,
) )
# app = SpiderwebRouter(
# middleware=[
# "spiderweb.middleware.sessions.SessionMiddleware",
# "spiderweb.middleware.csrf.CSRFMiddleware",
# "example_middleware.TestMiddleware",
# "example_middleware.RedirectMiddleware",
# "spiderweb.middleware.pydantic.PydanticMiddleware",
# "example_middleware.ExplodingMiddleware",
# ],
# )
def index(request): def index(request):
if "value" in request.SESSION: if "value" in request.SESSION:
request.SESSION["value"] += 1 request.SESSION["value"] += 1
@ -169,8 +166,8 @@ def test_csrf_middleware():
b_handle = BytesIO() b_handle = BytesIO()
b_handle.write(formdata.encode(DEFAULT_ENCODING)) b_handle.write(formdata.encode(DEFAULT_ENCODING))
b_handle.seek(0) b_handle.seek(0)
environ["wsgi.input"] = BufferedReader(b_handle) environ["wsgi.input"] = BufferedReader(b_handle)
environ["HTTP_X_CSRF_TOKEN"] = None
resp3 = app(environ, start_response)[0].decode(DEFAULT_ENCODING) resp3 = app(environ, start_response)[0].decode(DEFAULT_ENCODING)
assert "CSRF token is invalid" in resp3 assert "CSRF token is invalid" in resp3
@ -290,3 +287,448 @@ def test_csrf_trusted_origins():
environ["HTTP_ORIGIN"] = "example.com" environ["HTTP_ORIGIN"] = "example.com"
resp2 = app(environ, start_response)[0].decode(DEFAULT_ENCODING) resp2 = app(environ, start_response)[0].decode(DEFAULT_ENCODING)
assert resp2 == '{"name": "bob"}' assert resp2 == '{"name": "bob"}'
class TestCorsMiddleware:
# adapted from:
# https://github.com/adamchainz/django-cors-headers/blob/main/tests/test_middleware.py
# to make sure I didn't miss anything
middleware = {"middleware": ["spiderweb.middleware.cors.CorsMiddleware"]}
def test_get_no_origin(self):
app, environ, start_response = setup(
**self.middleware, cors_allow_all_origins=True
)
app(environ, start_response)
assert ACCESS_CONTROL_ALLOW_ORIGIN not in start_response.get_headers()
def test_get_origin_vary_by_default(self):
app, environ, start_response = setup(
**self.middleware, cors_allow_all_origins=True
)
app(environ, start_response)
assert start_response.get_headers()["vary"] == "origin"
def test_get_invalid_origin(self):
app, environ, start_response = setup(
**self.middleware, cors_allow_all_origins=True
)
environ["HTTP_ORIGIN"] = "https://example.com]"
app(environ, start_response)
assert ACCESS_CONTROL_ALLOW_ORIGIN not in start_response.get_headers()
def test_get_not_in_allowed_origins(self):
app, environ, start_response = setup(
**self.middleware, cors_allowed_origins=["https://example.com"]
)
environ["HTTP_ORIGIN"] = "https://example.org"
app(environ, start_response)
assert ACCESS_CONTROL_ALLOW_ORIGIN not in start_response.get_headers()
def test_get_not_in_allowed_origins_due_to_wrong_scheme(self):
app, environ, start_response = setup(
**self.middleware, cors_allowed_origins=["http://example.org"]
)
environ["HTTP_ORIGIN"] = "https://example.org"
app(environ, start_response)
assert ACCESS_CONTROL_ALLOW_ORIGIN not in start_response.get_headers()
def test_get_in_allowed_origins(self):
app, environ, start_response = setup(
**self.middleware,
cors_allowed_origins=["https://example.com", "https://example.org"],
)
environ["HTTP_ORIGIN"] = "https://example.org"
app(environ, start_response)
assert (
start_response.get_headers()[ACCESS_CONTROL_ALLOW_ORIGIN]
== "https://example.org"
)
def test_null_in_allowed_origins(self):
app, environ, start_response = setup(
**self.middleware,
cors_allowed_origins=["https://example.com", "null"],
)
environ["HTTP_ORIGIN"] = "null"
app(environ, start_response)
assert start_response.get_headers()[ACCESS_CONTROL_ALLOW_ORIGIN] == "null"
def test_file_in_allowed_origins(self):
"""
'file://' should be allowed as an origin since Chrome on Android
mistakenly sends it
"""
app, environ, start_response = setup(
**self.middleware,
cors_allowed_origins=["https://example.com", "file://"],
)
environ["HTTP_ORIGIN"] = "file://"
app(environ, start_response)
assert start_response.get_headers()[ACCESS_CONTROL_ALLOW_ORIGIN] == "file://"
def test_get_expose_headers(self):
app, environ, start_response = setup(
**self.middleware,
cors_allow_all_origins=True,
cors_expose_headers=["accept", "content-type"]
)
environ["HTTP_ORIGIN"] = "https://example.com"
app(environ, start_response)
assert start_response.get_headers()[ACCESS_CONTROL_EXPOSE_HEADERS] == "accept, content-type"
def test_get_dont_expose_headers(self):
app, environ, start_response = setup(
**self.middleware,
cors_allow_all_origins=True,
)
environ["HTTP_ORIGIN"] = "https://example.com"
app(environ, start_response)
assert ACCESS_CONTROL_EXPOSE_HEADERS not in start_response.get_headers()
def test_get_allow_credentials(self):
app, environ, start_response = setup(
**self.middleware,
cors_allowed_origins=["https://example.com"],
cors_allow_credentials=True,
)
environ["HTTP_ORIGIN"] = "https://example.com"
app(environ, start_response)
assert start_response.get_headers()[ACCESS_CONTROL_ALLOW_CREDENTIALS] == "true"
def test_get_allow_credentials_bad_origin(self):
app, environ, start_response = setup(
**self.middleware,
cors_allowed_origins=["https://example.com"],
cors_allow_credentials=True,
)
environ["HTTP_ORIGIN"] = "https://example.org"
app(environ, start_response)
assert ACCESS_CONTROL_ALLOW_CREDENTIALS not in start_response.get_headers()
def test_get_allow_credentials_disabled(self):
app, environ, start_response = setup(
**self.middleware,
cors_allowed_origins=["https://example.com"],
)
environ["HTTP_ORIGIN"] = "https://example.com"
app(environ, start_response)
assert ACCESS_CONTROL_ALLOW_CREDENTIALS not in start_response.get_headers()
def test_allow_private_network_added_if_enabled_and_requested(self):
app, environ, start_response = setup(
**self.middleware,
cors_allow_private_network=True,
cors_allow_all_origins=True
)
environ["HTTP_ACCESS_CONTROL_REQUEST_PRIVATE_NETWORK"] = "true"
environ["HTTP_ORIGIN"] = "http://example.com"
app(environ, start_response)
assert start_response.get_headers()[ACCESS_CONTROL_ALLOW_PRIVATE_NETWORK] == "true"
def test_allow_private_network_not_added_if_enabled_and_not_requested(self):
app, environ, start_response = setup(
**self.middleware,
cors_allow_private_network=True,
cors_allow_all_origins=True
)
environ["HTTP_ORIGIN"] = "http://example.com"
app(environ, start_response)
assert ACCESS_CONTROL_ALLOW_PRIVATE_NETWORK not in start_response.get_headers()
def test_allow_private_network_not_added_if_enabled_and_no_cors_origin(self):
app, environ, start_response = setup(
**self.middleware,
cors_allow_private_network=True,
cors_allowed_origins=["http://example.com"]
)
environ["HTTP_ACCESS_CONTROL_REQUEST_PRIVATE_NETWORK"] = "true"
environ["HTTP_ORIGIN"] = "http://example.org"
app(environ, start_response)
assert ACCESS_CONTROL_ALLOW_PRIVATE_NETWORK not in start_response.get_headers()
def test_allow_private_network_not_added_if_disabled_and_requested(self):
app, environ, start_response = setup(
**self.middleware,
cors_allow_private_network=False,
cors_allow_all_origins=True
)
environ["HTTP_ACCESS_CONTROL_REQUEST_PRIVATE_NETWORK"] = "true"
environ["HTTP_ORIGIN"] = "http://example.com"
app(environ, start_response)
assert ACCESS_CONTROL_ALLOW_PRIVATE_NETWORK not in start_response.get_headers()
def test_options_allowed_origin(self):
app, environ, start_response = setup(
**self.middleware,
cors_allow_headers=["content-type"],
cors_allow_methods=["GET", "OPTIONS"],
cors_preflight_max_age=1002,
cors_allow_all_origins=True
)
environ["HTTP_ACCESS_CONTROL_REQUEST_METHOD"] = "GET"
environ["HTTP_ORIGIN"] = "https://example.com"
environ["REQUEST_METHOD"] = "OPTIONS"
app(environ, start_response)
headers = start_response.get_headers()
assert start_response.status == '200 OK'
assert headers[ACCESS_CONTROL_ALLOW_HEADERS] == "content-type"
assert headers[ACCESS_CONTROL_ALLOW_METHODS] == "GET, OPTIONS"
assert headers[ACCESS_CONTROL_MAX_AGE] == "1002"
def test_options_no_max_age(self):
app, environ, start_response = setup(
**self.middleware,
cors_allow_headers=["content-type"],
cors_allow_methods=["GET", "OPTIONS"],
cors_preflight_max_age=0,
cors_allow_all_origins=True
)
environ["HTTP_ACCESS_CONTROL_REQUEST_METHOD"] = "GET"
environ["HTTP_ORIGIN"] = "https://example.com"
environ["REQUEST_METHOD"] = "OPTIONS"
app(environ, start_response)
headers = start_response.get_headers()
assert headers[ACCESS_CONTROL_ALLOW_HEADERS] == "content-type"
assert headers[ACCESS_CONTROL_ALLOW_METHODS] == "GET, OPTIONS"
assert ACCESS_CONTROL_MAX_AGE not in headers
def test_options_allowed_origins_with_port(self):
app, environ, start_response = setup(
**self.middleware,
cors_allowed_origins=["https://localhost:9000"]
)
environ["HTTP_ACCESS_CONTROL_REQUEST_METHOD"] = "GET"
environ["HTTP_ORIGIN"] = "https://localhost:9000"
environ["REQUEST_METHOD"] = "OPTIONS"
app(environ, start_response)
assert start_response.get_headers()[ACCESS_CONTROL_ALLOW_ORIGIN] == "https://localhost:9000"
def test_options_adds_origin_when_domain_found_in_allowed_regexes(self):
app, environ, start_response = setup(
**self.middleware,
cors_allowed_origin_regexes=[r"^https://\w+\.example\.com$"]
)
environ["HTTP_ACCESS_CONTROL_REQUEST_METHOD"] = "GET"
environ["HTTP_ORIGIN"] = "https://foo.example.com"
environ["REQUEST_METHOD"] = "OPTIONS"
app(environ, start_response)
assert start_response.get_headers()[ACCESS_CONTROL_ALLOW_ORIGIN] == "https://foo.example.com"
def test_options_adds_origin_when_domain_found_in_allowed_regexes_second(self):
app, environ, start_response = setup(
**self.middleware,
cors_allowed_origin_regexes=[
r"^https://\w+\.example\.org$",
r"^https://\w+\.example\.com$",
],
)
environ["HTTP_ACCESS_CONTROL_REQUEST_METHOD"] = "GET"
environ["HTTP_ORIGIN"] = "https://foo.example.com"
environ["REQUEST_METHOD"] = "OPTIONS"
app(environ, start_response)
assert start_response.get_headers()[ACCESS_CONTROL_ALLOW_ORIGIN] == "https://foo.example.com"
def test_options_doesnt_add_origin_when_domain_not_found_in_allowed_regexes(self):
app, environ, start_response = setup(
**self.middleware,
cors_allowed_origin_regexes=[r"^https://\w+\.example\.org$"],
)
environ["HTTP_ACCESS_CONTROL_REQUEST_METHOD"] = "GET"
environ["HTTP_ORIGIN"] = "https://foo.example.com"
environ["REQUEST_METHOD"] = "OPTIONS"
app(environ, start_response)
assert ACCESS_CONTROL_ALLOW_ORIGIN not in start_response.get_headers()
def test_options_empty_request_method(self):
app, environ, start_response = setup(
**self.middleware,
cors_allow_all_origins=True,
)
environ["HTTP_ACCESS_CONTROL_REQUEST_METHOD"] = ""
environ["HTTP_ORIGIN"] = "https://example.com"
environ["REQUEST_METHOD"] = "OPTIONS"
app(environ, start_response)
assert start_response.status == "200 OK"
def test_options_no_headers(self):
app, environ, start_response = setup(
**self.middleware,
cors_allow_all_origins=True,
routes=[
("/", text_view)
]
)
environ["REQUEST_METHOD"] = "OPTIONS"
app(environ, start_response)
assert start_response.status == "405 Method Not Allowed"
def test_allow_all_origins_get(self):
app, environ, start_response = setup(
**self.middleware,
cors_allow_credentials=True,
cors_allow_all_origins=True,
routes=[("/", text_view)]
)
environ["HTTP_ACCESS_CONTROL_REQUEST_METHOD"] = "GET"
environ["HTTP_ORIGIN"] = "https://example.com"
environ["REQUEST_METHOD"] = "GET"
app(environ, start_response)
assert start_response.status == "200 OK"
assert start_response.get_headers()[ACCESS_CONTROL_ALLOW_ORIGIN] == "https://example.com"
assert start_response.get_headers()["vary"] == "origin"
def test_allow_all_origins_options(self):
app, environ, start_response = setup(
**self.middleware,
cors_allow_credentials=True,
cors_allow_all_origins=True,
routes=[("/", text_view)]
)
environ["HTTP_ACCESS_CONTROL_REQUEST_METHOD"] = "GET"
environ["HTTP_ORIGIN"] = "https://example.com"
environ["REQUEST_METHOD"] = "OPTIONS"
app(environ, start_response)
assert start_response.status == "200 OK"
assert start_response.get_headers()[ACCESS_CONTROL_ALLOW_ORIGIN] == "https://example.com"
assert start_response.get_headers()["vary"] == "origin"
def test_non_200_headers_still_set(self):
"""
It's not clear whether the header should still be set for non-HTTP200
when not a preflight request. However, this is the existing behavior for
django-cors-middleware, and Spiderweb should mirror it.
"""
app, environ, start_response = setup(
**self.middleware,
cors_allow_credentials=True,
cors_allow_all_origins=True,
routes=[("/unauthorized", unauthorized_view)]
)
environ["HTTP_ORIGIN"] = "https://example.com"
environ["PATH_INFO"] = "/unauthorized"
app(environ, start_response)
assert start_response.status == "401 Unauthorized"
assert start_response.get_headers()[ACCESS_CONTROL_ALLOW_ORIGIN] == "https://example.com"
def test_auth_view_options(self):
"""
Ensure HTTP200 and header still set, for preflight requests to views requiring
authentication. See: https://github.com/adamchainz/django-cors-headers/issues/3
"""
app, environ, start_response = setup(
**self.middleware,
cors_allow_credentials=True,
cors_allow_all_origins=True,
routes=[("/unauthorized", unauthorized_view)]
)
environ["HTTP_ORIGIN"] = "https://example.com"
environ["HTTP_ACCESS_CONTROL_REQUEST_METHOD"] = "GET"
environ["PATH_INFO"] = "/unauthorized"
environ["REQUEST_METHOD"] = "OPTIONS"
app(environ, start_response)
assert start_response.status == "200 OK"
assert start_response.get_headers()[ACCESS_CONTROL_ALLOW_ORIGIN] == "https://example.com"
assert start_response.get_headers()["content-length"] == "0"
def test_get_short_circuit(self):
"""
Test a scenario when a middleware that returns a response is run before
the `CorsMiddleware`. In this case
`CorsMiddleware.process_response()` should ignore the request.
"""
app, environ, start_response = setup(
middleware=[
"spiderweb.tests.middleware.InterruptingMiddleware",
"spiderweb.middleware.cors.CorsMiddleware",
],
cors_allow_credentials=True,
cors_allowed_origins=["https://example.com"],
)
environ["HTTP_ORIGIN"] = "https://example.com"
app(environ, start_response)
assert ACCESS_CONTROL_ALLOW_ORIGIN not in start_response.get_headers()
def test_get_short_circuit_should_be_ignored(self):
app, environ, start_response = setup(
middleware=[
"spiderweb.tests.middleware.InterruptingMiddleware",
"spiderweb.middleware.cors.CorsMiddleware",
],
cors_urls_regex=r"^/foo/$",
cors_allowed_origins=["https://example.com"],
)
environ["HTTP_ORIGIN"] = "https://example.com"
app(environ, start_response)
assert ACCESS_CONTROL_ALLOW_ORIGIN not in start_response.get_headers()
def test_get_regex_matches(self):
app, environ, start_response = setup(
**self.middleware,
cors_urls_regex=r"^/foo$",
cors_allowed_origins=["https://example.com"],
)
environ["HTTP_ORIGIN"] = "https://example.com"
environ["HTTP_ACCESS_CONTROL_REQUEST_METHOD"] = "GET"
environ["PATH_INFO"] = "/foo"
environ["REQUEST_METHOD"] = "GET"
app(environ, start_response)
assert ACCESS_CONTROL_ALLOW_ORIGIN in start_response.get_headers()
def test_get_regex_doesnt_match(self):
app, environ, start_response = setup(
**self.middleware,
cors_urls_regex=r"^/not-foo/$",
cors_allowed_origins=["https://example.com"],
)
environ["HTTP_ORIGIN"] = "https://example.com"
environ["HTTP_ACCESS_CONTROL_REQUEST_METHOD"] = "GET"
environ["PATH_INFO"] = "/foo"
environ["REQUEST_METHOD"] = "GET"
app(environ, start_response)
assert ACCESS_CONTROL_ALLOW_ORIGIN not in start_response.get_headers()
def test_works_if_view_deletes_cors_enabled(self):
"""
Just in case something crazy happens in the view or other middleware,
check that get_response doesn't fall over if `_cors_enabled` is removed
"""
def yeet(request):
del request._cors_enabled
return HttpResponse("hahaha")
app, environ, start_response = setup(
**self.middleware,
cors_allowed_origins=["https://example.com"],
routes=[('/yeet', yeet)]
)
environ["HTTP_ORIGIN"] = "https://example.com"
environ["PATH_INFO"] = "/yeet"
environ["REQUEST_METHOD"] = "GET"
app(environ, start_response)
assert ACCESS_CONTROL_ALLOW_ORIGIN in start_response.get_headers()

View File

@ -1,3 +1,4 @@
from spiderweb import HttpResponse
from spiderweb.decorators import csrf_exempt from spiderweb.decorators import csrf_exempt
from spiderweb.response import JsonResponse, TemplateResponse from spiderweb.response import JsonResponse, TemplateResponse
@ -38,3 +39,11 @@ def form_view_with_csrf(request):
return JsonResponse(data=request.POST) return JsonResponse(data=request.POST)
else: else:
return TemplateResponse(request, template_string=EXAMPLE_HTML_FORM_WITH_CSRF) return TemplateResponse(request, template_string=EXAMPLE_HTML_FORM_WITH_CSRF)
def text_view(request):
return HttpResponse("Hi!")
def unauthorized_view(request):
return HttpResponse("Unauthorized", status_code=401)

View File

@ -67,15 +67,35 @@ def is_jsonable(data: str) -> bool:
class Headers(dict): class Headers(dict):
# special dict that forces lowercase for all keys # special dict that forces lowercase and snake_case for all keys
def __getitem__(self, key): def __getitem__(self, key):
return super().__getitem__(key.lower()) key = key.replace("-", "_")
try:
regular = super().__getitem__(key.lower())
except KeyError:
regular = None
try:
http_version = super().__getitem__(f"http_{key.lower()}")
except KeyError:
http_version = None
return regular or http_version
def __contains__(self, item):
item = item.lower().replace("-", "_")
regular = super().__contains__(item)
http = super().__contains__(f"http_{item}")
return regular or http
def __setitem__(self, key, value): def __setitem__(self, key, value):
return super().__setitem__(key.lower(), value) return super().__setitem__(key.lower().replace("-", "_"), value)
def get(self, key, default=None): def get(self, key, default=None):
return super().get(key.lower(), default) key = key.replace("-", "_")
regular = super().get(key.lower(), default)
http_version = super().get(f"http_{key.lower()}", default)
return regular or http_version
def setdefault(self, key, default=None): def setdefault(self, key, default=None):
return super().setdefault(key.lower(), default) return super().setdefault(key.lower(), default)