diff --git a/example.py b/example.py index f763eca..f53b123 100644 --- a/example.py +++ b/example.py @@ -25,6 +25,7 @@ app = SpiderwebRouter( ], staticfiles_dirs=["static_files"], append_slash=False, # default + cors_allow_all_origins=True, ) diff --git a/spiderweb/constants.py b/spiderweb/constants.py index cb46532..ec9c8de 100644 --- a/spiderweb/constants.py +++ b/spiderweb/constants.py @@ -1,8 +1,8 @@ from peewee import DatabaseProxy -DEFAULT_ALLOWED_METHODS = ["GET"] +DEFAULT_ALLOWED_METHODS = ["POST", "GET", "PUT", "PATCH", "DELETE"] DEFAULT_ENCODING = "UTF-8" -__version__ = "0.12.0" +__version__ = "1.0.0" # https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Set-Cookie REGEX_COOKIE_NAME = r"^[a-zA-Z0-9\s\(\)<>@,;:\/\\\[\]\?=\{\}\"\t]*$" diff --git a/spiderweb/main.py b/spiderweb/main.py index e019cb2..4817f0c 100644 --- a/spiderweb/main.py +++ b/spiderweb/main.py @@ -50,7 +50,7 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi port: int = None, allowed_hosts: Sequence[str | re.Pattern] = None, cors_allowed_origins: Sequence[str] = None, - cors_allowed_origins_regexes: Sequence[str] = None, + cors_allowed_origin_regexes: Sequence[str] = None, cors_allow_all_origins: bool = False, cors_urls_regex: str | re.Pattern[str] = r"^.*$", cors_allow_methods: Sequence[str] = None, @@ -94,7 +94,7 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi self.allowed_hosts = [convert_url_to_regex(i) for i in self._allowed_hosts] self.cors_allowed_origins = cors_allowed_origins or [] - self.cors_allowed_origins_regexes = cors_allowed_origins_regexes or [] + self.cors_allowed_origin_regexes = cors_allowed_origin_regexes or [] self.cors_allow_all_origins = cors_allow_all_origins self.cors_urls_regex = cors_urls_regex self.cors_allow_methods = cors_allow_methods or DEFAULT_CORS_ALLOW_METHODS @@ -171,17 +171,19 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi status = get_http_status_by_code(resp.status_code) cookies = [] varies = [] + resp.headers = {k.replace("_", "-"): v for k, v in resp.headers.items()} if "set-cookie" in resp.headers: cookies = resp.headers["set-cookie"] del resp.headers["set-cookie"] if "vary" in resp.headers: varies = resp.headers["vary"] del resp.headers["vary"] + resp.headers = {k: str(v) for k, v in resp.headers.items()} headers = list(resp.headers.items()) for c in cookies: - headers.append(("Set-Cookie", c)) + headers.append(("set-cookie", str(c))) for v in varies: - headers.append(("Vary", v)) + headers.append(("vary", str(v))) start_response(status, headers) @@ -271,7 +273,6 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi def __call__(self, environ, start_response, *args, **kwargs): """Entry point for WSGI apps.""" request = self.get_request(environ) - try: handler, additional_args, allowed_methods = self.get_route(request.path) except NotFound: diff --git a/spiderweb/middleware/__init__.py b/spiderweb/middleware/__init__.py index 265f2a5..2d2b0be 100644 --- a/spiderweb/middleware/__init__.py +++ b/spiderweb/middleware/__init__.py @@ -38,7 +38,7 @@ class MiddlewareMixin: if errors: # just show the messages - sys.tracebacklimit = 0 + sys.tracebacklimit = 1 raise StartupErrors( "Problems were identified during startup — cannot continue.", errors ) diff --git a/spiderweb/middleware/cors.py b/spiderweb/middleware/cors.py index 313ee62..a7e8672 100644 --- a/spiderweb/middleware/cors.py +++ b/spiderweb/middleware/cors.py @@ -25,12 +25,9 @@ class VerifyValidCorsSetting(ServerCheck): ) def check(self): - # - `cors_allowed_origins` - # - `cors_allowed_origin_regexes` - # - `cors_allow_all_origins` if ( not self.server.cors_allowed_origins - and not self.server.cors.allowed_origin_regexes + and not self.server.cors_allowed_origin_regexes and not self.server.cors_allow_all_origins ): return ConfigError(self.INVALID_BASE_CONFIG) @@ -52,7 +49,6 @@ class CorsMiddleware(SpiderwebMiddleware): enabled = getattr(request, "_cors_enabled", None) if enabled is None: enabled = self.is_enabled(request) - if not enabled: return response @@ -61,7 +57,7 @@ class CorsMiddleware(SpiderwebMiddleware): else: response.headers["vary"] = ["origin"] - origin = request.headers.get("origin") + origin = request.headers.get("http_origin") if not origin: return response @@ -103,10 +99,9 @@ class CorsMiddleware(SpiderwebMiddleware): response.headers[ACCESS_CONTROL_MAX_AGE] = str( self.server.cors_preflight_max_age ) - if ( self.server.cors_allow_private_network - and request.headers.get(ACCESS_CONTROL_REQUEST_PRIVATE_NETWORK) == "true" + and request.headers.get(ACCESS_CONTROL_REQUEST_PRIVATE_NETWORK.replace("-", "_")) == "true" ): response.headers[ACCESS_CONTROL_ALLOW_PRIVATE_NETWORK] = "true" @@ -134,8 +129,8 @@ class CorsMiddleware(SpiderwebMiddleware): def process_request(self, request: Request) -> HttpResponse | None: # Identify and handle a preflight request - # origin = request.META.get("HTTP_ORIGIN") request._cors_enabled = self.is_enabled(request) + request.META["cors_ran"] = True if ( request._cors_enabled and request.method == "OPTIONS" @@ -152,6 +147,11 @@ class CorsMiddleware(SpiderwebMiddleware): return resp def process_response(self, request: Request, response: HttpResponse) -> None: + if not request.META.get("cors_ran"): + # something happened and process_request didn't run. Abort early. + # We're not relying on request._cors_enabled because it's more + # visible and the view may have destroyed it accidentally. + return self.add_response_headers(request, response) diff --git a/spiderweb/middleware/csrf.py b/spiderweb/middleware/csrf.py index aad8c70..d692a04 100644 --- a/spiderweb/middleware/csrf.py +++ b/spiderweb/middleware/csrf.py @@ -90,13 +90,12 @@ class CSRFMiddleware(SpiderwebMiddleware): def process_request(self, request: Request) -> HttpResponse | None: if request.method == "POST": - if hasattr(request.handler, "csrf_exempt"): if request.handler.csrf_exempt is True: return csrf_token = ( - request.headers.get("X-CSRF-TOKEN") + request.headers.get("x-csrf-token") or request.GET.get("csrf_token") or request.POST.get("csrf_token") ) @@ -111,7 +110,7 @@ class CSRFMiddleware(SpiderwebMiddleware): def process_response(self, request: Request, response: HttpResponse) -> None: token = self.get_csrf_token(request) # do we need it in both places? - response.headers["X-CSRF-TOKEN"] = token + response.headers["x-csrf-token"] = token response.context |= { "csrf_token": f"""""", "raw_csrf_token": token, # in case they want to format it themselves diff --git a/spiderweb/tests/helpers.py b/spiderweb/tests/helpers.py index 89dd7e4..34f72bb 100644 --- a/spiderweb/tests/helpers.py +++ b/spiderweb/tests/helpers.py @@ -15,7 +15,7 @@ class StartResponse: self.headers = headers def get_headers(self): - return {h[0]: h[1] for h in self.headers} + return {h[0]: h[1] for h in self.headers} if self.headers else {} def setup(**kwargs): diff --git a/spiderweb/tests/middleware.py b/spiderweb/tests/middleware.py index 6c555ff..1c8f9da 100644 --- a/spiderweb/tests/middleware.py +++ b/spiderweb/tests/middleware.py @@ -11,3 +11,8 @@ class ExplodingResponseMiddleware(SpiderwebMiddleware): self, request: Request, response: HttpResponse ) -> HttpResponse | None: raise UnusedMiddleware("Unfinished!") + + +class InterruptingMiddleware(SpiderwebMiddleware): + def process_request(self, request: Request) -> HttpResponse: + return HttpResponse("Moo!") \ No newline at end of file diff --git a/spiderweb/tests/test_middleware.py b/spiderweb/tests/test_middleware.py index 2625ff0..ea4b517 100644 --- a/spiderweb/tests/test_middleware.py +++ b/spiderweb/tests/test_middleware.py @@ -6,28 +6,25 @@ from peewee import SqliteDatabase from spiderweb import SpiderwebRouter, HttpResponse, StartupErrors from spiderweb.constants import DEFAULT_ENCODING +from spiderweb.middleware.cors import ( + ACCESS_CONTROL_ALLOW_ORIGIN, + ACCESS_CONTROL_ALLOW_HEADERS, + ACCESS_CONTROL_ALLOW_METHODS, + ACCESS_CONTROL_EXPOSE_HEADERS, + ACCESS_CONTROL_ALLOW_CREDENTIALS, + ACCESS_CONTROL_MAX_AGE, + ACCESS_CONTROL_ALLOW_PRIVATE_NETWORK, +) from spiderweb.middleware.sessions import Session from spiderweb.middleware import csrf from spiderweb.tests.helpers import setup from spiderweb.tests.views_for_tests import ( form_view_with_csrf, form_csrf_exempt, - form_view_without_csrf, + form_view_without_csrf, text_view, unauthorized_view, ) -# app = SpiderwebRouter( -# middleware=[ -# "spiderweb.middleware.sessions.SessionMiddleware", -# "spiderweb.middleware.csrf.CSRFMiddleware", -# "example_middleware.TestMiddleware", -# "example_middleware.RedirectMiddleware", -# "spiderweb.middleware.pydantic.PydanticMiddleware", -# "example_middleware.ExplodingMiddleware", -# ], -# ) - - def index(request): if "value" in request.SESSION: request.SESSION["value"] += 1 @@ -169,8 +166,8 @@ def test_csrf_middleware(): b_handle = BytesIO() b_handle.write(formdata.encode(DEFAULT_ENCODING)) b_handle.seek(0) - environ["wsgi.input"] = BufferedReader(b_handle) + environ["HTTP_X_CSRF_TOKEN"] = None resp3 = app(environ, start_response)[0].decode(DEFAULT_ENCODING) assert "CSRF token is invalid" in resp3 @@ -290,3 +287,448 @@ def test_csrf_trusted_origins(): environ["HTTP_ORIGIN"] = "example.com" resp2 = app(environ, start_response)[0].decode(DEFAULT_ENCODING) assert resp2 == '{"name": "bob"}' + + +class TestCorsMiddleware: + # adapted from: + # https://github.com/adamchainz/django-cors-headers/blob/main/tests/test_middleware.py + # to make sure I didn't miss anything + middleware = {"middleware": ["spiderweb.middleware.cors.CorsMiddleware"]} + + def test_get_no_origin(self): + app, environ, start_response = setup( + **self.middleware, cors_allow_all_origins=True + ) + app(environ, start_response) + assert ACCESS_CONTROL_ALLOW_ORIGIN not in start_response.get_headers() + + def test_get_origin_vary_by_default(self): + app, environ, start_response = setup( + **self.middleware, cors_allow_all_origins=True + ) + app(environ, start_response) + assert start_response.get_headers()["vary"] == "origin" + + def test_get_invalid_origin(self): + app, environ, start_response = setup( + **self.middleware, cors_allow_all_origins=True + ) + environ["HTTP_ORIGIN"] = "https://example.com]" + app(environ, start_response) + assert ACCESS_CONTROL_ALLOW_ORIGIN not in start_response.get_headers() + + def test_get_not_in_allowed_origins(self): + app, environ, start_response = setup( + **self.middleware, cors_allowed_origins=["https://example.com"] + ) + environ["HTTP_ORIGIN"] = "https://example.org" + app(environ, start_response) + assert ACCESS_CONTROL_ALLOW_ORIGIN not in start_response.get_headers() + + def test_get_not_in_allowed_origins_due_to_wrong_scheme(self): + app, environ, start_response = setup( + **self.middleware, cors_allowed_origins=["http://example.org"] + ) + environ["HTTP_ORIGIN"] = "https://example.org" + app(environ, start_response) + assert ACCESS_CONTROL_ALLOW_ORIGIN not in start_response.get_headers() + + def test_get_in_allowed_origins(self): + app, environ, start_response = setup( + **self.middleware, + cors_allowed_origins=["https://example.com", "https://example.org"], + ) + environ["HTTP_ORIGIN"] = "https://example.org" + app(environ, start_response) + assert ( + start_response.get_headers()[ACCESS_CONTROL_ALLOW_ORIGIN] + == "https://example.org" + ) + + def test_null_in_allowed_origins(self): + app, environ, start_response = setup( + **self.middleware, + cors_allowed_origins=["https://example.com", "null"], + ) + environ["HTTP_ORIGIN"] = "null" + app(environ, start_response) + assert start_response.get_headers()[ACCESS_CONTROL_ALLOW_ORIGIN] == "null" + + def test_file_in_allowed_origins(self): + """ + 'file://' should be allowed as an origin since Chrome on Android + mistakenly sends it + """ + app, environ, start_response = setup( + **self.middleware, + cors_allowed_origins=["https://example.com", "file://"], + ) + environ["HTTP_ORIGIN"] = "file://" + app(environ, start_response) + assert start_response.get_headers()[ACCESS_CONTROL_ALLOW_ORIGIN] == "file://" + + def test_get_expose_headers(self): + app, environ, start_response = setup( + **self.middleware, + cors_allow_all_origins=True, + cors_expose_headers=["accept", "content-type"] + ) + environ["HTTP_ORIGIN"] = "https://example.com" + app(environ, start_response) + assert start_response.get_headers()[ACCESS_CONTROL_EXPOSE_HEADERS] == "accept, content-type" + + def test_get_dont_expose_headers(self): + app, environ, start_response = setup( + **self.middleware, + cors_allow_all_origins=True, + ) + environ["HTTP_ORIGIN"] = "https://example.com" + app(environ, start_response) + assert ACCESS_CONTROL_EXPOSE_HEADERS not in start_response.get_headers() + + def test_get_allow_credentials(self): + app, environ, start_response = setup( + **self.middleware, + cors_allowed_origins=["https://example.com"], + cors_allow_credentials=True, + ) + environ["HTTP_ORIGIN"] = "https://example.com" + app(environ, start_response) + assert start_response.get_headers()[ACCESS_CONTROL_ALLOW_CREDENTIALS] == "true" + + def test_get_allow_credentials_bad_origin(self): + app, environ, start_response = setup( + **self.middleware, + cors_allowed_origins=["https://example.com"], + cors_allow_credentials=True, + ) + environ["HTTP_ORIGIN"] = "https://example.org" + app(environ, start_response) + assert ACCESS_CONTROL_ALLOW_CREDENTIALS not in start_response.get_headers() + + def test_get_allow_credentials_disabled(self): + app, environ, start_response = setup( + **self.middleware, + cors_allowed_origins=["https://example.com"], + ) + environ["HTTP_ORIGIN"] = "https://example.com" + app(environ, start_response) + assert ACCESS_CONTROL_ALLOW_CREDENTIALS not in start_response.get_headers() + + def test_allow_private_network_added_if_enabled_and_requested(self): + app, environ, start_response = setup( + **self.middleware, + cors_allow_private_network=True, + cors_allow_all_origins=True + ) + environ["HTTP_ACCESS_CONTROL_REQUEST_PRIVATE_NETWORK"] = "true" + environ["HTTP_ORIGIN"] = "http://example.com" + app(environ, start_response) + assert start_response.get_headers()[ACCESS_CONTROL_ALLOW_PRIVATE_NETWORK] == "true" + + def test_allow_private_network_not_added_if_enabled_and_not_requested(self): + app, environ, start_response = setup( + **self.middleware, + cors_allow_private_network=True, + cors_allow_all_origins=True + ) + environ["HTTP_ORIGIN"] = "http://example.com" + app(environ, start_response) + assert ACCESS_CONTROL_ALLOW_PRIVATE_NETWORK not in start_response.get_headers() + + def test_allow_private_network_not_added_if_enabled_and_no_cors_origin(self): + app, environ, start_response = setup( + **self.middleware, + cors_allow_private_network=True, + cors_allowed_origins=["http://example.com"] + ) + environ["HTTP_ACCESS_CONTROL_REQUEST_PRIVATE_NETWORK"] = "true" + environ["HTTP_ORIGIN"] = "http://example.org" + app(environ, start_response) + assert ACCESS_CONTROL_ALLOW_PRIVATE_NETWORK not in start_response.get_headers() + + + def test_allow_private_network_not_added_if_disabled_and_requested(self): + app, environ, start_response = setup( + **self.middleware, + cors_allow_private_network=False, + cors_allow_all_origins=True + ) + environ["HTTP_ACCESS_CONTROL_REQUEST_PRIVATE_NETWORK"] = "true" + environ["HTTP_ORIGIN"] = "http://example.com" + app(environ, start_response) + assert ACCESS_CONTROL_ALLOW_PRIVATE_NETWORK not in start_response.get_headers() + + def test_options_allowed_origin(self): + app, environ, start_response = setup( + **self.middleware, + cors_allow_headers=["content-type"], + cors_allow_methods=["GET", "OPTIONS"], + cors_preflight_max_age=1002, + cors_allow_all_origins=True + ) + environ["HTTP_ACCESS_CONTROL_REQUEST_METHOD"] = "GET" + environ["HTTP_ORIGIN"] = "https://example.com" + environ["REQUEST_METHOD"] = "OPTIONS" + app(environ, start_response) + + headers = start_response.get_headers() + + assert start_response.status == '200 OK' + assert headers[ACCESS_CONTROL_ALLOW_HEADERS] == "content-type" + assert headers[ACCESS_CONTROL_ALLOW_METHODS] == "GET, OPTIONS" + assert headers[ACCESS_CONTROL_MAX_AGE] == "1002" + + + def test_options_no_max_age(self): + app, environ, start_response = setup( + **self.middleware, + cors_allow_headers=["content-type"], + cors_allow_methods=["GET", "OPTIONS"], + cors_preflight_max_age=0, + cors_allow_all_origins=True + ) + environ["HTTP_ACCESS_CONTROL_REQUEST_METHOD"] = "GET" + environ["HTTP_ORIGIN"] = "https://example.com" + environ["REQUEST_METHOD"] = "OPTIONS" + app(environ, start_response) + + + headers = start_response.get_headers() + assert headers[ACCESS_CONTROL_ALLOW_HEADERS] == "content-type" + assert headers[ACCESS_CONTROL_ALLOW_METHODS] == "GET, OPTIONS" + assert ACCESS_CONTROL_MAX_AGE not in headers + + def test_options_allowed_origins_with_port(self): + app, environ, start_response = setup( + **self.middleware, + cors_allowed_origins=["https://localhost:9000"] + ) + environ["HTTP_ACCESS_CONTROL_REQUEST_METHOD"] = "GET" + environ["HTTP_ORIGIN"] = "https://localhost:9000" + environ["REQUEST_METHOD"] = "OPTIONS" + app(environ, start_response) + + assert start_response.get_headers()[ACCESS_CONTROL_ALLOW_ORIGIN] == "https://localhost:9000" + + def test_options_adds_origin_when_domain_found_in_allowed_regexes(self): + app, environ, start_response = setup( + **self.middleware, + cors_allowed_origin_regexes=[r"^https://\w+\.example\.com$"] + ) + environ["HTTP_ACCESS_CONTROL_REQUEST_METHOD"] = "GET" + environ["HTTP_ORIGIN"] = "https://foo.example.com" + environ["REQUEST_METHOD"] = "OPTIONS" + app(environ, start_response) + + assert start_response.get_headers()[ACCESS_CONTROL_ALLOW_ORIGIN] == "https://foo.example.com" + + def test_options_adds_origin_when_domain_found_in_allowed_regexes_second(self): + app, environ, start_response = setup( + **self.middleware, + cors_allowed_origin_regexes=[ + r"^https://\w+\.example\.org$", + r"^https://\w+\.example\.com$", + ], + ) + environ["HTTP_ACCESS_CONTROL_REQUEST_METHOD"] = "GET" + environ["HTTP_ORIGIN"] = "https://foo.example.com" + environ["REQUEST_METHOD"] = "OPTIONS" + app(environ, start_response) + + assert start_response.get_headers()[ACCESS_CONTROL_ALLOW_ORIGIN] == "https://foo.example.com" + + def test_options_doesnt_add_origin_when_domain_not_found_in_allowed_regexes(self): + app, environ, start_response = setup( + **self.middleware, + cors_allowed_origin_regexes=[r"^https://\w+\.example\.org$"], + ) + environ["HTTP_ACCESS_CONTROL_REQUEST_METHOD"] = "GET" + environ["HTTP_ORIGIN"] = "https://foo.example.com" + environ["REQUEST_METHOD"] = "OPTIONS" + app(environ, start_response) + + assert ACCESS_CONTROL_ALLOW_ORIGIN not in start_response.get_headers() + + def test_options_empty_request_method(self): + app, environ, start_response = setup( + **self.middleware, + cors_allow_all_origins=True, + ) + environ["HTTP_ACCESS_CONTROL_REQUEST_METHOD"] = "" + environ["HTTP_ORIGIN"] = "https://example.com" + environ["REQUEST_METHOD"] = "OPTIONS" + app(environ, start_response) + + assert start_response.status == "200 OK" + + + def test_options_no_headers(self): + app, environ, start_response = setup( + **self.middleware, + cors_allow_all_origins=True, + routes=[ + ("/", text_view) + ] + ) + environ["REQUEST_METHOD"] = "OPTIONS" + app(environ, start_response) + assert start_response.status == "405 Method Not Allowed" + + def test_allow_all_origins_get(self): + app, environ, start_response = setup( + **self.middleware, + cors_allow_credentials=True, + cors_allow_all_origins=True, + routes=[("/", text_view)] + ) + environ["HTTP_ACCESS_CONTROL_REQUEST_METHOD"] = "GET" + environ["HTTP_ORIGIN"] = "https://example.com" + environ["REQUEST_METHOD"] = "GET" + app(environ, start_response) + + assert start_response.status == "200 OK" + assert start_response.get_headers()[ACCESS_CONTROL_ALLOW_ORIGIN] == "https://example.com" + assert start_response.get_headers()["vary"] == "origin" + + def test_allow_all_origins_options(self): + app, environ, start_response = setup( + **self.middleware, + cors_allow_credentials=True, + cors_allow_all_origins=True, + routes=[("/", text_view)] + ) + + environ["HTTP_ACCESS_CONTROL_REQUEST_METHOD"] = "GET" + environ["HTTP_ORIGIN"] = "https://example.com" + environ["REQUEST_METHOD"] = "OPTIONS" + app(environ, start_response) + + assert start_response.status == "200 OK" + assert start_response.get_headers()[ACCESS_CONTROL_ALLOW_ORIGIN] == "https://example.com" + assert start_response.get_headers()["vary"] == "origin" + + def test_non_200_headers_still_set(self): + """ + It's not clear whether the header should still be set for non-HTTP200 + when not a preflight request. However, this is the existing behavior for + django-cors-middleware, and Spiderweb should mirror it. + """ + app, environ, start_response = setup( + **self.middleware, + cors_allow_credentials=True, + cors_allow_all_origins=True, + routes=[("/unauthorized", unauthorized_view)] + ) + environ["HTTP_ORIGIN"] = "https://example.com" + environ["PATH_INFO"] = "/unauthorized" + app(environ, start_response) + + + assert start_response.status == "401 Unauthorized" + assert start_response.get_headers()[ACCESS_CONTROL_ALLOW_ORIGIN] == "https://example.com" + + def test_auth_view_options(self): + """ + Ensure HTTP200 and header still set, for preflight requests to views requiring + authentication. See: https://github.com/adamchainz/django-cors-headers/issues/3 + """ + app, environ, start_response = setup( + **self.middleware, + cors_allow_credentials=True, + cors_allow_all_origins=True, + routes=[("/unauthorized", unauthorized_view)] + ) + environ["HTTP_ORIGIN"] = "https://example.com" + environ["HTTP_ACCESS_CONTROL_REQUEST_METHOD"] = "GET" + environ["PATH_INFO"] = "/unauthorized" + environ["REQUEST_METHOD"] = "OPTIONS" + app(environ, start_response) + + assert start_response.status == "200 OK" + assert start_response.get_headers()[ACCESS_CONTROL_ALLOW_ORIGIN] == "https://example.com" + assert start_response.get_headers()["content-length"] == "0" + + + def test_get_short_circuit(self): + """ + Test a scenario when a middleware that returns a response is run before + the `CorsMiddleware`. In this case + `CorsMiddleware.process_response()` should ignore the request. + """ + app, environ, start_response = setup( + middleware=[ + "spiderweb.tests.middleware.InterruptingMiddleware", + "spiderweb.middleware.cors.CorsMiddleware", + ], + cors_allow_credentials=True, + cors_allowed_origins=["https://example.com"], + ) + environ["HTTP_ORIGIN"] = "https://example.com" + app(environ, start_response) + + assert ACCESS_CONTROL_ALLOW_ORIGIN not in start_response.get_headers() + + def test_get_short_circuit_should_be_ignored(self): + app, environ, start_response = setup( + middleware=[ + "spiderweb.tests.middleware.InterruptingMiddleware", + "spiderweb.middleware.cors.CorsMiddleware", + ], + cors_urls_regex=r"^/foo/$", + cors_allowed_origins=["https://example.com"], + ) + environ["HTTP_ORIGIN"] = "https://example.com" + app(environ, start_response) + + assert ACCESS_CONTROL_ALLOW_ORIGIN not in start_response.get_headers() + + def test_get_regex_matches(self): + app, environ, start_response = setup( + **self.middleware, + cors_urls_regex=r"^/foo$", + cors_allowed_origins=["https://example.com"], + ) + environ["HTTP_ORIGIN"] = "https://example.com" + environ["HTTP_ACCESS_CONTROL_REQUEST_METHOD"] = "GET" + environ["PATH_INFO"] = "/foo" + environ["REQUEST_METHOD"] = "GET" + app(environ, start_response) + + assert ACCESS_CONTROL_ALLOW_ORIGIN in start_response.get_headers() + + def test_get_regex_doesnt_match(self): + app, environ, start_response = setup( + **self.middleware, + cors_urls_regex=r"^/not-foo/$", + cors_allowed_origins=["https://example.com"], + ) + environ["HTTP_ORIGIN"] = "https://example.com" + environ["HTTP_ACCESS_CONTROL_REQUEST_METHOD"] = "GET" + environ["PATH_INFO"] = "/foo" + environ["REQUEST_METHOD"] = "GET" + app(environ, start_response) + + assert ACCESS_CONTROL_ALLOW_ORIGIN not in start_response.get_headers() + + def test_works_if_view_deletes_cors_enabled(self): + """ + Just in case something crazy happens in the view or other middleware, + check that get_response doesn't fall over if `_cors_enabled` is removed + """ + def yeet(request): + del request._cors_enabled + return HttpResponse("hahaha") + + app, environ, start_response = setup( + **self.middleware, + cors_allowed_origins=["https://example.com"], + routes=[('/yeet', yeet)] + ) + + environ["HTTP_ORIGIN"] = "https://example.com" + environ["PATH_INFO"] = "/yeet" + environ["REQUEST_METHOD"] = "GET" + app(environ, start_response) + + assert ACCESS_CONTROL_ALLOW_ORIGIN in start_response.get_headers() diff --git a/spiderweb/tests/views_for_tests.py b/spiderweb/tests/views_for_tests.py index cc4dac1..62e21b1 100644 --- a/spiderweb/tests/views_for_tests.py +++ b/spiderweb/tests/views_for_tests.py @@ -1,3 +1,4 @@ +from spiderweb import HttpResponse from spiderweb.decorators import csrf_exempt from spiderweb.response import JsonResponse, TemplateResponse @@ -38,3 +39,11 @@ def form_view_with_csrf(request): return JsonResponse(data=request.POST) else: return TemplateResponse(request, template_string=EXAMPLE_HTML_FORM_WITH_CSRF) + + +def text_view(request): + return HttpResponse("Hi!") + + +def unauthorized_view(request): + return HttpResponse("Unauthorized", status_code=401) \ No newline at end of file diff --git a/spiderweb/utils.py b/spiderweb/utils.py index e00bcb7..635c910 100644 --- a/spiderweb/utils.py +++ b/spiderweb/utils.py @@ -67,15 +67,35 @@ def is_jsonable(data: str) -> bool: class Headers(dict): - # special dict that forces lowercase for all keys + # special dict that forces lowercase and snake_case for all keys def __getitem__(self, key): - return super().__getitem__(key.lower()) + key = key.replace("-", "_") + try: + regular = super().__getitem__(key.lower()) + except KeyError: + regular = None + try: + http_version = super().__getitem__(f"http_{key.lower()}") + except KeyError: + http_version = None + return regular or http_version + + def __contains__(self, item): + item = item.lower().replace("-", "_") + + regular = super().__contains__(item) + http = super().__contains__(f"http_{item}") + + return regular or http def __setitem__(self, key, value): - return super().__setitem__(key.lower(), value) + return super().__setitem__(key.lower().replace("-", "_"), value) def get(self, key, default=None): - return super().get(key.lower(), default) + key = key.replace("-", "_") + regular = super().get(key.lower(), default) + http_version = super().get(f"http_{key.lower()}", default) + return regular or http_version def setdefault(self, key, default=None): return super().setdefault(key.lower(), default)