Compare commits

..

3 Commits

Author SHA1 Message Date
0d1fba1aad 🔖 v1.0!!! 2024-09-02 17:35:06 -04:00
f9225848a6 add tests for cors and get coverage to 89% 2024-09-02 17:34:50 -04:00
5cf9dff13a 🎨 reformat tests to remove some duplicate lines 2024-09-02 10:50:07 -04:00
13 changed files with 542 additions and 83 deletions

View File

@ -25,6 +25,7 @@ app = SpiderwebRouter(
], ],
staticfiles_dirs=["static_files"], staticfiles_dirs=["static_files"],
append_slash=False, # default append_slash=False, # default
cors_allow_all_origins=True,
) )

View File

@ -1,6 +1,6 @@
[tool.poetry] [tool.poetry]
name = "spiderweb-framework" name = "spiderweb-framework"
version = "0.12.0" version = "1.0.0"
description = "A small web framework, just big enough for a spider." description = "A small web framework, just big enough for a spider."
authors = ["Joe Kaufeld <opensource@joekaufeld.com>"] authors = ["Joe Kaufeld <opensource@joekaufeld.com>"]
readme = "README.md" readme = "README.md"

View File

@ -1,8 +1,8 @@
from peewee import DatabaseProxy from peewee import DatabaseProxy
DEFAULT_ALLOWED_METHODS = ["GET"] DEFAULT_ALLOWED_METHODS = ["POST", "GET", "PUT", "PATCH", "DELETE"]
DEFAULT_ENCODING = "UTF-8" DEFAULT_ENCODING = "UTF-8"
__version__ = "0.12.0" __version__ = "1.0.0"
# https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Set-Cookie # https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Set-Cookie
REGEX_COOKIE_NAME = r"^[a-zA-Z0-9\s\(\)<>@,;:\/\\\[\]\?=\{\}\"\t]*$" REGEX_COOKIE_NAME = r"^[a-zA-Z0-9\s\(\)<>@,;:\/\\\[\]\?=\{\}\"\t]*$"

View File

@ -50,7 +50,7 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi
port: int = None, port: int = None,
allowed_hosts: Sequence[str | re.Pattern] = None, allowed_hosts: Sequence[str | re.Pattern] = None,
cors_allowed_origins: Sequence[str] = None, cors_allowed_origins: Sequence[str] = None,
cors_allowed_origins_regexes: Sequence[str] = None, cors_allowed_origin_regexes: Sequence[str] = None,
cors_allow_all_origins: bool = False, cors_allow_all_origins: bool = False,
cors_urls_regex: str | re.Pattern[str] = r"^.*$", cors_urls_regex: str | re.Pattern[str] = r"^.*$",
cors_allow_methods: Sequence[str] = None, cors_allow_methods: Sequence[str] = None,
@ -94,7 +94,7 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi
self.allowed_hosts = [convert_url_to_regex(i) for i in self._allowed_hosts] self.allowed_hosts = [convert_url_to_regex(i) for i in self._allowed_hosts]
self.cors_allowed_origins = cors_allowed_origins or [] self.cors_allowed_origins = cors_allowed_origins or []
self.cors_allowed_origins_regexes = cors_allowed_origins_regexes or [] self.cors_allowed_origin_regexes = cors_allowed_origin_regexes or []
self.cors_allow_all_origins = cors_allow_all_origins self.cors_allow_all_origins = cors_allow_all_origins
self.cors_urls_regex = cors_urls_regex self.cors_urls_regex = cors_urls_regex
self.cors_allow_methods = cors_allow_methods or DEFAULT_CORS_ALLOW_METHODS self.cors_allow_methods = cors_allow_methods or DEFAULT_CORS_ALLOW_METHODS
@ -171,17 +171,19 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi
status = get_http_status_by_code(resp.status_code) status = get_http_status_by_code(resp.status_code)
cookies = [] cookies = []
varies = [] varies = []
resp.headers = {k.replace("_", "-"): v for k, v in resp.headers.items()}
if "set-cookie" in resp.headers: if "set-cookie" in resp.headers:
cookies = resp.headers["set-cookie"] cookies = resp.headers["set-cookie"]
del resp.headers["set-cookie"] del resp.headers["set-cookie"]
if "vary" in resp.headers: if "vary" in resp.headers:
varies = resp.headers["vary"] varies = resp.headers["vary"]
del resp.headers["vary"] del resp.headers["vary"]
resp.headers = {k: str(v) for k, v in resp.headers.items()}
headers = list(resp.headers.items()) headers = list(resp.headers.items())
for c in cookies: for c in cookies:
headers.append(("Set-Cookie", c)) headers.append(("set-cookie", str(c)))
for v in varies: for v in varies:
headers.append(("Vary", v)) headers.append(("vary", str(v)))
start_response(status, headers) start_response(status, headers)
@ -271,7 +273,6 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi
def __call__(self, environ, start_response, *args, **kwargs): def __call__(self, environ, start_response, *args, **kwargs):
"""Entry point for WSGI apps.""" """Entry point for WSGI apps."""
request = self.get_request(environ) request = self.get_request(environ)
try: try:
handler, additional_args, allowed_methods = self.get_route(request.path) handler, additional_args, allowed_methods = self.get_route(request.path)
except NotFound: except NotFound:

View File

@ -38,7 +38,7 @@ class MiddlewareMixin:
if errors: if errors:
# just show the messages # just show the messages
sys.tracebacklimit = 0 sys.tracebacklimit = 1
raise StartupErrors( raise StartupErrors(
"Problems were identified during startup — cannot continue.", errors "Problems were identified during startup — cannot continue.", errors
) )

View File

@ -23,13 +23,11 @@ class VerifyValidCorsSetting(ServerCheck):
" `cors_allowed_origins`, `cors_allowed_origin_regexes`, or" " `cors_allowed_origins`, `cors_allowed_origin_regexes`, or"
" `cors_allow_all_origins`.", " `cors_allow_all_origins`.",
) )
def check(self): def check(self):
# - `cors_allowed_origins`
# - `cors_allowed_origin_regexes`
# - `cors_allow_all_origins`
if ( if (
not self.server.cors_allowed_origins not self.server.cors_allowed_origins
and not self.server.cors.allowed_origin_regexes and not self.server.cors_allowed_origin_regexes
and not self.server.cors_allow_all_origins and not self.server.cors_allow_all_origins
): ):
return ConfigError(self.INVALID_BASE_CONFIG) return ConfigError(self.INVALID_BASE_CONFIG)
@ -51,7 +49,6 @@ class CorsMiddleware(SpiderwebMiddleware):
enabled = getattr(request, "_cors_enabled", None) enabled = getattr(request, "_cors_enabled", None)
if enabled is None: if enabled is None:
enabled = self.is_enabled(request) enabled = self.is_enabled(request)
if not enabled: if not enabled:
return response return response
@ -60,7 +57,7 @@ class CorsMiddleware(SpiderwebMiddleware):
else: else:
response.headers["vary"] = ["origin"] response.headers["vary"] = ["origin"]
origin = request.headers.get("origin") origin = request.headers.get("http_origin")
if not origin: if not origin:
return response return response
@ -102,10 +99,9 @@ class CorsMiddleware(SpiderwebMiddleware):
response.headers[ACCESS_CONTROL_MAX_AGE] = str( response.headers[ACCESS_CONTROL_MAX_AGE] = str(
self.server.cors_preflight_max_age self.server.cors_preflight_max_age
) )
if ( if (
self.server.cors_allow_private_network self.server.cors_allow_private_network
and request.headers.get(ACCESS_CONTROL_REQUEST_PRIVATE_NETWORK) == "true" and request.headers.get(ACCESS_CONTROL_REQUEST_PRIVATE_NETWORK.replace("-", "_")) == "true"
): ):
response.headers[ACCESS_CONTROL_ALLOW_PRIVATE_NETWORK] = "true" response.headers[ACCESS_CONTROL_ALLOW_PRIVATE_NETWORK] = "true"
@ -133,8 +129,8 @@ class CorsMiddleware(SpiderwebMiddleware):
def process_request(self, request: Request) -> HttpResponse | None: def process_request(self, request: Request) -> HttpResponse | None:
# Identify and handle a preflight request # Identify and handle a preflight request
# origin = request.META.get("HTTP_ORIGIN")
request._cors_enabled = self.is_enabled(request) request._cors_enabled = self.is_enabled(request)
request.META["cors_ran"] = True
if ( if (
request._cors_enabled request._cors_enabled
and request.method == "OPTIONS" and request.method == "OPTIONS"
@ -150,9 +146,13 @@ class CorsMiddleware(SpiderwebMiddleware):
self.add_response_headers(request, resp) self.add_response_headers(request, resp)
return resp return resp
def process_response( def process_response(self, request: Request, response: HttpResponse) -> None:
self, request: Request, response: HttpResponse if not request.META.get("cors_ran"):
) -> None: # something happened and process_request didn't run. Abort early.
# We're not relying on request._cors_enabled because it's more
# visible and the view may have destroyed it accidentally.
return
self.add_response_headers(request, response) self.add_response_headers(request, response)
# [204]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods/OPTIONS#status_code # [204]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods/OPTIONS#status_code

View File

@ -72,7 +72,9 @@ class CSRFMiddleware(SpiderwebMiddleware):
def is_trusted_origin(self, request) -> bool: def is_trusted_origin(self, request) -> bool:
origin = request.headers.get("http_origin") origin = request.headers.get("http_origin")
referrer = request.headers.get("http_referer") or request.headers.get("http_referrer") referrer = request.headers.get("http_referer") or request.headers.get(
"http_referrer"
)
host = request.headers.get("http_host") host = request.headers.get("http_host")
if not origin and not (host == referrer): if not origin and not (host == referrer):
@ -88,13 +90,12 @@ class CSRFMiddleware(SpiderwebMiddleware):
def process_request(self, request: Request) -> HttpResponse | None: def process_request(self, request: Request) -> HttpResponse | None:
if request.method == "POST": if request.method == "POST":
if hasattr(request.handler, "csrf_exempt"): if hasattr(request.handler, "csrf_exempt"):
if request.handler.csrf_exempt is True: if request.handler.csrf_exempt is True:
return return
csrf_token = ( csrf_token = (
request.headers.get("X-CSRF-TOKEN") request.headers.get("x-csrf-token")
or request.GET.get("csrf_token") or request.GET.get("csrf_token")
or request.POST.get("csrf_token") or request.POST.get("csrf_token")
) )
@ -109,7 +110,7 @@ class CSRFMiddleware(SpiderwebMiddleware):
def process_response(self, request: Request, response: HttpResponse) -> None: def process_response(self, request: Request, response: HttpResponse) -> None:
token = self.get_csrf_token(request) token = self.get_csrf_token(request)
# do we need it in both places? # do we need it in both places?
response.headers["X-CSRF-TOKEN"] = token response.headers["x-csrf-token"] = token
response.context |= { response.context |= {
"csrf_token": f"""<input type="hidden" name="csrf_token" value="{token}">""", "csrf_token": f"""<input type="hidden" name="csrf_token" value="{token}">""",
"raw_csrf_token": token, # in case they want to format it themselves "raw_csrf_token": token, # in case they want to format it themselves

View File

@ -15,14 +15,16 @@ class StartResponse:
self.headers = headers self.headers = headers
def get_headers(self): def get_headers(self):
return {h[0]: h[1] for h in self.headers} return {h[0]: h[1] for h in self.headers} if self.headers else {}
def setup(): def setup(**kwargs):
environ = {} environ = {}
setup_testing_defaults(environ) setup_testing_defaults(environ)
if "db" not in kwargs:
kwargs["db"] = SqliteDatabase("spiderweb-tests.db")
return ( return (
SpiderwebRouter(db=SqliteDatabase("spiderweb-tests.db")), SpiderwebRouter(**kwargs),
environ, environ,
StartResponse(), StartResponse(),
) )

View File

@ -11,3 +11,8 @@ class ExplodingResponseMiddleware(SpiderwebMiddleware):
self, request: Request, response: HttpResponse self, request: Request, response: HttpResponse
) -> HttpResponse | None: ) -> HttpResponse | None:
raise UnusedMiddleware("Unfinished!") raise UnusedMiddleware("Unfinished!")
class InterruptingMiddleware(SpiderwebMiddleware):
def process_request(self, request: Request) -> HttpResponse:
return HttpResponse("Moo!")

View File

@ -4,30 +4,27 @@ from datetime import timedelta
import pytest import pytest
from peewee import SqliteDatabase from peewee import SqliteDatabase
from spiderweb import SpiderwebRouter, HttpResponse, ConfigError, StartupErrors from spiderweb import SpiderwebRouter, HttpResponse, StartupErrors
from spiderweb.constants import DEFAULT_ENCODING from spiderweb.constants import DEFAULT_ENCODING
from spiderweb.middleware.cors import (
ACCESS_CONTROL_ALLOW_ORIGIN,
ACCESS_CONTROL_ALLOW_HEADERS,
ACCESS_CONTROL_ALLOW_METHODS,
ACCESS_CONTROL_EXPOSE_HEADERS,
ACCESS_CONTROL_ALLOW_CREDENTIALS,
ACCESS_CONTROL_MAX_AGE,
ACCESS_CONTROL_ALLOW_PRIVATE_NETWORK,
)
from spiderweb.middleware.sessions import Session from spiderweb.middleware.sessions import Session
from spiderweb.middleware import csrf from spiderweb.middleware import csrf
from spiderweb.tests.helpers import setup from spiderweb.tests.helpers import setup
from spiderweb.tests.views_for_tests import ( from spiderweb.tests.views_for_tests import (
form_view_with_csrf, form_view_with_csrf,
form_csrf_exempt, form_csrf_exempt,
form_view_without_csrf, form_view_without_csrf, text_view, unauthorized_view,
) )
# app = SpiderwebRouter(
# middleware=[
# "spiderweb.middleware.sessions.SessionMiddleware",
# "spiderweb.middleware.csrf.CSRFMiddleware",
# "example_middleware.TestMiddleware",
# "example_middleware.RedirectMiddleware",
# "spiderweb.middleware.pydantic.PydanticMiddleware",
# "example_middleware.ExplodingMiddleware",
# ],
# )
def index(request): def index(request):
if "value" in request.SESSION: if "value" in request.SESSION:
request.SESSION["value"] += 1 request.SESSION["value"] += 1
@ -37,10 +34,8 @@ def index(request):
def test_session_middleware(): def test_session_middleware():
_, environ, start_response = setup() app, environ, start_response = setup(
app = SpiderwebRouter(
middleware=["spiderweb.middleware.sessions.SessionMiddleware"], middleware=["spiderweb.middleware.sessions.SessionMiddleware"],
db=SqliteDatabase("spiderweb-tests.db"),
) )
app.add_route("/", index) app.add_route("/", index)
@ -58,10 +53,8 @@ def test_session_middleware():
def test_expired_session(): def test_expired_session():
_, environ, start_response = setup() app, environ, start_response = setup(
app = SpiderwebRouter(
middleware=["spiderweb.middleware.sessions.SessionMiddleware"], middleware=["spiderweb.middleware.sessions.SessionMiddleware"],
db=SqliteDatabase("spiderweb-tests.db"),
) )
app.add_route("/", index) app.add_route("/", index)
@ -85,13 +78,11 @@ def test_expired_session():
def test_exploding_middleware(): def test_exploding_middleware():
_, environ, start_response = setup() app, environ, start_response = setup(
app = SpiderwebRouter(
middleware=[ middleware=[
"spiderweb.tests.middleware.ExplodingRequestMiddleware", "spiderweb.tests.middleware.ExplodingRequestMiddleware",
"spiderweb.tests.middleware.ExplodingResponseMiddleware", "spiderweb.tests.middleware.ExplodingResponseMiddleware",
], ],
db=SqliteDatabase("spiderweb-tests.db"),
) )
app.add_route("/", index) app.add_route("/", index)
@ -102,7 +93,6 @@ def test_exploding_middleware():
def test_csrf_middleware_without_session_middleware(): def test_csrf_middleware_without_session_middleware():
_, environ, start_response = setup()
with pytest.raises(StartupErrors) as e: with pytest.raises(StartupErrors) as e:
SpiderwebRouter( SpiderwebRouter(
middleware=["spiderweb.middleware.csrf.CSRFMiddleware"], middleware=["spiderweb.middleware.csrf.CSRFMiddleware"],
@ -116,15 +106,14 @@ def test_csrf_middleware_without_session_middleware():
def test_csrf_middleware_above_session_middleware(): def test_csrf_middleware_above_session_middleware():
_, environ, start_response = setup()
with pytest.raises(StartupErrors) as e: with pytest.raises(StartupErrors) as e:
SpiderwebRouter( app, environ, start_response = setup(
middleware=[ middleware=[
"spiderweb.middleware.csrf.CSRFMiddleware", "spiderweb.middleware.csrf.CSRFMiddleware",
"spiderweb.middleware.sessions.SessionMiddleware", "spiderweb.middleware.sessions.SessionMiddleware",
], ],
db=SqliteDatabase("spiderweb-tests.db"),
) )
exceptiongroup = e.value.args[1] exceptiongroup = e.value.args[1]
assert ( assert (
exceptiongroup[0].args[0] exceptiongroup[0].args[0]
@ -133,13 +122,11 @@ def test_csrf_middleware_above_session_middleware():
def test_csrf_middleware(): def test_csrf_middleware():
_, environ, start_response = setup() app, environ, start_response = setup(
app = SpiderwebRouter(
middleware=[ middleware=[
"spiderweb.middleware.sessions.SessionMiddleware", "spiderweb.middleware.sessions.SessionMiddleware",
"spiderweb.middleware.csrf.CSRFMiddleware", "spiderweb.middleware.csrf.CSRFMiddleware",
], ],
db=SqliteDatabase("spiderweb-tests.db"),
) )
app.add_route("/", form_view_with_csrf, ["GET", "POST"]) app.add_route("/", form_view_with_csrf, ["GET", "POST"])
@ -179,8 +166,8 @@ def test_csrf_middleware():
b_handle = BytesIO() b_handle = BytesIO()
b_handle.write(formdata.encode(DEFAULT_ENCODING)) b_handle.write(formdata.encode(DEFAULT_ENCODING))
b_handle.seek(0) b_handle.seek(0)
environ["wsgi.input"] = BufferedReader(b_handle) environ["wsgi.input"] = BufferedReader(b_handle)
environ["HTTP_X_CSRF_TOKEN"] = None
resp3 = app(environ, start_response)[0].decode(DEFAULT_ENCODING) resp3 = app(environ, start_response)[0].decode(DEFAULT_ENCODING)
assert "CSRF token is invalid" in resp3 assert "CSRF token is invalid" in resp3
@ -198,14 +185,13 @@ def test_csrf_middleware():
def test_csrf_expired_token(): def test_csrf_expired_token():
_, environ, start_response = setup() app, environ, start_response = setup(
app = SpiderwebRouter(
middleware=[ middleware=[
"spiderweb.middleware.sessions.SessionMiddleware", "spiderweb.middleware.sessions.SessionMiddleware",
"spiderweb.middleware.csrf.CSRFMiddleware", "spiderweb.middleware.csrf.CSRFMiddleware",
], ],
db=SqliteDatabase("spiderweb-tests.db"),
) )
app.middleware[1].CSRF_EXPIRY = -1 app.middleware[1].CSRF_EXPIRY = -1
app.add_route("/", form_view_with_csrf, ["GET", "POST"]) app.add_route("/", form_view_with_csrf, ["GET", "POST"])
@ -235,13 +221,11 @@ def test_csrf_expired_token():
def test_csrf_exempt(): def test_csrf_exempt():
_, environ, start_response = setup() app, environ, start_response = setup(
app = SpiderwebRouter(
middleware=[ middleware=[
"spiderweb.middleware.sessions.SessionMiddleware", "spiderweb.middleware.sessions.SessionMiddleware",
"spiderweb.middleware.csrf.CSRFMiddleware", "spiderweb.middleware.csrf.CSRFMiddleware",
], ],
db=SqliteDatabase("spiderweb-tests.db"),
) )
app.add_route("/", form_csrf_exempt, ["GET", "POST"]) app.add_route("/", form_csrf_exempt, ["GET", "POST"])
@ -268,8 +252,7 @@ def test_csrf_exempt():
def test_csrf_trusted_origins(): def test_csrf_trusted_origins():
_, environ, start_response = setup() app, environ, start_response = setup(
app = SpiderwebRouter(
middleware=[ middleware=[
"spiderweb.middleware.sessions.SessionMiddleware", "spiderweb.middleware.sessions.SessionMiddleware",
"spiderweb.middleware.csrf.CSRFMiddleware", "spiderweb.middleware.csrf.CSRFMiddleware",
@ -277,9 +260,7 @@ def test_csrf_trusted_origins():
csrf_trusted_origins=[ csrf_trusted_origins=[
"example.com", "example.com",
], ],
db=SqliteDatabase("spiderweb-tests.db"),
) )
app.add_route("/", form_view_without_csrf, ["GET", "POST"]) app.add_route("/", form_view_without_csrf, ["GET", "POST"])
environ["HTTP_USER_AGENT"] = "hi" environ["HTTP_USER_AGENT"] = "hi"
@ -306,3 +287,448 @@ def test_csrf_trusted_origins():
environ["HTTP_ORIGIN"] = "example.com" environ["HTTP_ORIGIN"] = "example.com"
resp2 = app(environ, start_response)[0].decode(DEFAULT_ENCODING) resp2 = app(environ, start_response)[0].decode(DEFAULT_ENCODING)
assert resp2 == '{"name": "bob"}' assert resp2 == '{"name": "bob"}'
class TestCorsMiddleware:
# adapted from:
# https://github.com/adamchainz/django-cors-headers/blob/main/tests/test_middleware.py
# to make sure I didn't miss anything
middleware = {"middleware": ["spiderweb.middleware.cors.CorsMiddleware"]}
def test_get_no_origin(self):
app, environ, start_response = setup(
**self.middleware, cors_allow_all_origins=True
)
app(environ, start_response)
assert ACCESS_CONTROL_ALLOW_ORIGIN not in start_response.get_headers()
def test_get_origin_vary_by_default(self):
app, environ, start_response = setup(
**self.middleware, cors_allow_all_origins=True
)
app(environ, start_response)
assert start_response.get_headers()["vary"] == "origin"
def test_get_invalid_origin(self):
app, environ, start_response = setup(
**self.middleware, cors_allow_all_origins=True
)
environ["HTTP_ORIGIN"] = "https://example.com]"
app(environ, start_response)
assert ACCESS_CONTROL_ALLOW_ORIGIN not in start_response.get_headers()
def test_get_not_in_allowed_origins(self):
app, environ, start_response = setup(
**self.middleware, cors_allowed_origins=["https://example.com"]
)
environ["HTTP_ORIGIN"] = "https://example.org"
app(environ, start_response)
assert ACCESS_CONTROL_ALLOW_ORIGIN not in start_response.get_headers()
def test_get_not_in_allowed_origins_due_to_wrong_scheme(self):
app, environ, start_response = setup(
**self.middleware, cors_allowed_origins=["http://example.org"]
)
environ["HTTP_ORIGIN"] = "https://example.org"
app(environ, start_response)
assert ACCESS_CONTROL_ALLOW_ORIGIN not in start_response.get_headers()
def test_get_in_allowed_origins(self):
app, environ, start_response = setup(
**self.middleware,
cors_allowed_origins=["https://example.com", "https://example.org"],
)
environ["HTTP_ORIGIN"] = "https://example.org"
app(environ, start_response)
assert (
start_response.get_headers()[ACCESS_CONTROL_ALLOW_ORIGIN]
== "https://example.org"
)
def test_null_in_allowed_origins(self):
app, environ, start_response = setup(
**self.middleware,
cors_allowed_origins=["https://example.com", "null"],
)
environ["HTTP_ORIGIN"] = "null"
app(environ, start_response)
assert start_response.get_headers()[ACCESS_CONTROL_ALLOW_ORIGIN] == "null"
def test_file_in_allowed_origins(self):
"""
'file://' should be allowed as an origin since Chrome on Android
mistakenly sends it
"""
app, environ, start_response = setup(
**self.middleware,
cors_allowed_origins=["https://example.com", "file://"],
)
environ["HTTP_ORIGIN"] = "file://"
app(environ, start_response)
assert start_response.get_headers()[ACCESS_CONTROL_ALLOW_ORIGIN] == "file://"
def test_get_expose_headers(self):
app, environ, start_response = setup(
**self.middleware,
cors_allow_all_origins=True,
cors_expose_headers=["accept", "content-type"]
)
environ["HTTP_ORIGIN"] = "https://example.com"
app(environ, start_response)
assert start_response.get_headers()[ACCESS_CONTROL_EXPOSE_HEADERS] == "accept, content-type"
def test_get_dont_expose_headers(self):
app, environ, start_response = setup(
**self.middleware,
cors_allow_all_origins=True,
)
environ["HTTP_ORIGIN"] = "https://example.com"
app(environ, start_response)
assert ACCESS_CONTROL_EXPOSE_HEADERS not in start_response.get_headers()
def test_get_allow_credentials(self):
app, environ, start_response = setup(
**self.middleware,
cors_allowed_origins=["https://example.com"],
cors_allow_credentials=True,
)
environ["HTTP_ORIGIN"] = "https://example.com"
app(environ, start_response)
assert start_response.get_headers()[ACCESS_CONTROL_ALLOW_CREDENTIALS] == "true"
def test_get_allow_credentials_bad_origin(self):
app, environ, start_response = setup(
**self.middleware,
cors_allowed_origins=["https://example.com"],
cors_allow_credentials=True,
)
environ["HTTP_ORIGIN"] = "https://example.org"
app(environ, start_response)
assert ACCESS_CONTROL_ALLOW_CREDENTIALS not in start_response.get_headers()
def test_get_allow_credentials_disabled(self):
app, environ, start_response = setup(
**self.middleware,
cors_allowed_origins=["https://example.com"],
)
environ["HTTP_ORIGIN"] = "https://example.com"
app(environ, start_response)
assert ACCESS_CONTROL_ALLOW_CREDENTIALS not in start_response.get_headers()
def test_allow_private_network_added_if_enabled_and_requested(self):
app, environ, start_response = setup(
**self.middleware,
cors_allow_private_network=True,
cors_allow_all_origins=True
)
environ["HTTP_ACCESS_CONTROL_REQUEST_PRIVATE_NETWORK"] = "true"
environ["HTTP_ORIGIN"] = "http://example.com"
app(environ, start_response)
assert start_response.get_headers()[ACCESS_CONTROL_ALLOW_PRIVATE_NETWORK] == "true"
def test_allow_private_network_not_added_if_enabled_and_not_requested(self):
app, environ, start_response = setup(
**self.middleware,
cors_allow_private_network=True,
cors_allow_all_origins=True
)
environ["HTTP_ORIGIN"] = "http://example.com"
app(environ, start_response)
assert ACCESS_CONTROL_ALLOW_PRIVATE_NETWORK not in start_response.get_headers()
def test_allow_private_network_not_added_if_enabled_and_no_cors_origin(self):
app, environ, start_response = setup(
**self.middleware,
cors_allow_private_network=True,
cors_allowed_origins=["http://example.com"]
)
environ["HTTP_ACCESS_CONTROL_REQUEST_PRIVATE_NETWORK"] = "true"
environ["HTTP_ORIGIN"] = "http://example.org"
app(environ, start_response)
assert ACCESS_CONTROL_ALLOW_PRIVATE_NETWORK not in start_response.get_headers()
def test_allow_private_network_not_added_if_disabled_and_requested(self):
app, environ, start_response = setup(
**self.middleware,
cors_allow_private_network=False,
cors_allow_all_origins=True
)
environ["HTTP_ACCESS_CONTROL_REQUEST_PRIVATE_NETWORK"] = "true"
environ["HTTP_ORIGIN"] = "http://example.com"
app(environ, start_response)
assert ACCESS_CONTROL_ALLOW_PRIVATE_NETWORK not in start_response.get_headers()
def test_options_allowed_origin(self):
app, environ, start_response = setup(
**self.middleware,
cors_allow_headers=["content-type"],
cors_allow_methods=["GET", "OPTIONS"],
cors_preflight_max_age=1002,
cors_allow_all_origins=True
)
environ["HTTP_ACCESS_CONTROL_REQUEST_METHOD"] = "GET"
environ["HTTP_ORIGIN"] = "https://example.com"
environ["REQUEST_METHOD"] = "OPTIONS"
app(environ, start_response)
headers = start_response.get_headers()
assert start_response.status == '200 OK'
assert headers[ACCESS_CONTROL_ALLOW_HEADERS] == "content-type"
assert headers[ACCESS_CONTROL_ALLOW_METHODS] == "GET, OPTIONS"
assert headers[ACCESS_CONTROL_MAX_AGE] == "1002"
def test_options_no_max_age(self):
app, environ, start_response = setup(
**self.middleware,
cors_allow_headers=["content-type"],
cors_allow_methods=["GET", "OPTIONS"],
cors_preflight_max_age=0,
cors_allow_all_origins=True
)
environ["HTTP_ACCESS_CONTROL_REQUEST_METHOD"] = "GET"
environ["HTTP_ORIGIN"] = "https://example.com"
environ["REQUEST_METHOD"] = "OPTIONS"
app(environ, start_response)
headers = start_response.get_headers()
assert headers[ACCESS_CONTROL_ALLOW_HEADERS] == "content-type"
assert headers[ACCESS_CONTROL_ALLOW_METHODS] == "GET, OPTIONS"
assert ACCESS_CONTROL_MAX_AGE not in headers
def test_options_allowed_origins_with_port(self):
app, environ, start_response = setup(
**self.middleware,
cors_allowed_origins=["https://localhost:9000"]
)
environ["HTTP_ACCESS_CONTROL_REQUEST_METHOD"] = "GET"
environ["HTTP_ORIGIN"] = "https://localhost:9000"
environ["REQUEST_METHOD"] = "OPTIONS"
app(environ, start_response)
assert start_response.get_headers()[ACCESS_CONTROL_ALLOW_ORIGIN] == "https://localhost:9000"
def test_options_adds_origin_when_domain_found_in_allowed_regexes(self):
app, environ, start_response = setup(
**self.middleware,
cors_allowed_origin_regexes=[r"^https://\w+\.example\.com$"]
)
environ["HTTP_ACCESS_CONTROL_REQUEST_METHOD"] = "GET"
environ["HTTP_ORIGIN"] = "https://foo.example.com"
environ["REQUEST_METHOD"] = "OPTIONS"
app(environ, start_response)
assert start_response.get_headers()[ACCESS_CONTROL_ALLOW_ORIGIN] == "https://foo.example.com"
def test_options_adds_origin_when_domain_found_in_allowed_regexes_second(self):
app, environ, start_response = setup(
**self.middleware,
cors_allowed_origin_regexes=[
r"^https://\w+\.example\.org$",
r"^https://\w+\.example\.com$",
],
)
environ["HTTP_ACCESS_CONTROL_REQUEST_METHOD"] = "GET"
environ["HTTP_ORIGIN"] = "https://foo.example.com"
environ["REQUEST_METHOD"] = "OPTIONS"
app(environ, start_response)
assert start_response.get_headers()[ACCESS_CONTROL_ALLOW_ORIGIN] == "https://foo.example.com"
def test_options_doesnt_add_origin_when_domain_not_found_in_allowed_regexes(self):
app, environ, start_response = setup(
**self.middleware,
cors_allowed_origin_regexes=[r"^https://\w+\.example\.org$"],
)
environ["HTTP_ACCESS_CONTROL_REQUEST_METHOD"] = "GET"
environ["HTTP_ORIGIN"] = "https://foo.example.com"
environ["REQUEST_METHOD"] = "OPTIONS"
app(environ, start_response)
assert ACCESS_CONTROL_ALLOW_ORIGIN not in start_response.get_headers()
def test_options_empty_request_method(self):
app, environ, start_response = setup(
**self.middleware,
cors_allow_all_origins=True,
)
environ["HTTP_ACCESS_CONTROL_REQUEST_METHOD"] = ""
environ["HTTP_ORIGIN"] = "https://example.com"
environ["REQUEST_METHOD"] = "OPTIONS"
app(environ, start_response)
assert start_response.status == "200 OK"
def test_options_no_headers(self):
app, environ, start_response = setup(
**self.middleware,
cors_allow_all_origins=True,
routes=[
("/", text_view)
]
)
environ["REQUEST_METHOD"] = "OPTIONS"
app(environ, start_response)
assert start_response.status == "405 Method Not Allowed"
def test_allow_all_origins_get(self):
app, environ, start_response = setup(
**self.middleware,
cors_allow_credentials=True,
cors_allow_all_origins=True,
routes=[("/", text_view)]
)
environ["HTTP_ACCESS_CONTROL_REQUEST_METHOD"] = "GET"
environ["HTTP_ORIGIN"] = "https://example.com"
environ["REQUEST_METHOD"] = "GET"
app(environ, start_response)
assert start_response.status == "200 OK"
assert start_response.get_headers()[ACCESS_CONTROL_ALLOW_ORIGIN] == "https://example.com"
assert start_response.get_headers()["vary"] == "origin"
def test_allow_all_origins_options(self):
app, environ, start_response = setup(
**self.middleware,
cors_allow_credentials=True,
cors_allow_all_origins=True,
routes=[("/", text_view)]
)
environ["HTTP_ACCESS_CONTROL_REQUEST_METHOD"] = "GET"
environ["HTTP_ORIGIN"] = "https://example.com"
environ["REQUEST_METHOD"] = "OPTIONS"
app(environ, start_response)
assert start_response.status == "200 OK"
assert start_response.get_headers()[ACCESS_CONTROL_ALLOW_ORIGIN] == "https://example.com"
assert start_response.get_headers()["vary"] == "origin"
def test_non_200_headers_still_set(self):
"""
It's not clear whether the header should still be set for non-HTTP200
when not a preflight request. However, this is the existing behavior for
django-cors-middleware, and Spiderweb should mirror it.
"""
app, environ, start_response = setup(
**self.middleware,
cors_allow_credentials=True,
cors_allow_all_origins=True,
routes=[("/unauthorized", unauthorized_view)]
)
environ["HTTP_ORIGIN"] = "https://example.com"
environ["PATH_INFO"] = "/unauthorized"
app(environ, start_response)
assert start_response.status == "401 Unauthorized"
assert start_response.get_headers()[ACCESS_CONTROL_ALLOW_ORIGIN] == "https://example.com"
def test_auth_view_options(self):
"""
Ensure HTTP200 and header still set, for preflight requests to views requiring
authentication. See: https://github.com/adamchainz/django-cors-headers/issues/3
"""
app, environ, start_response = setup(
**self.middleware,
cors_allow_credentials=True,
cors_allow_all_origins=True,
routes=[("/unauthorized", unauthorized_view)]
)
environ["HTTP_ORIGIN"] = "https://example.com"
environ["HTTP_ACCESS_CONTROL_REQUEST_METHOD"] = "GET"
environ["PATH_INFO"] = "/unauthorized"
environ["REQUEST_METHOD"] = "OPTIONS"
app(environ, start_response)
assert start_response.status == "200 OK"
assert start_response.get_headers()[ACCESS_CONTROL_ALLOW_ORIGIN] == "https://example.com"
assert start_response.get_headers()["content-length"] == "0"
def test_get_short_circuit(self):
"""
Test a scenario when a middleware that returns a response is run before
the `CorsMiddleware`. In this case
`CorsMiddleware.process_response()` should ignore the request.
"""
app, environ, start_response = setup(
middleware=[
"spiderweb.tests.middleware.InterruptingMiddleware",
"spiderweb.middleware.cors.CorsMiddleware",
],
cors_allow_credentials=True,
cors_allowed_origins=["https://example.com"],
)
environ["HTTP_ORIGIN"] = "https://example.com"
app(environ, start_response)
assert ACCESS_CONTROL_ALLOW_ORIGIN not in start_response.get_headers()
def test_get_short_circuit_should_be_ignored(self):
app, environ, start_response = setup(
middleware=[
"spiderweb.tests.middleware.InterruptingMiddleware",
"spiderweb.middleware.cors.CorsMiddleware",
],
cors_urls_regex=r"^/foo/$",
cors_allowed_origins=["https://example.com"],
)
environ["HTTP_ORIGIN"] = "https://example.com"
app(environ, start_response)
assert ACCESS_CONTROL_ALLOW_ORIGIN not in start_response.get_headers()
def test_get_regex_matches(self):
app, environ, start_response = setup(
**self.middleware,
cors_urls_regex=r"^/foo$",
cors_allowed_origins=["https://example.com"],
)
environ["HTTP_ORIGIN"] = "https://example.com"
environ["HTTP_ACCESS_CONTROL_REQUEST_METHOD"] = "GET"
environ["PATH_INFO"] = "/foo"
environ["REQUEST_METHOD"] = "GET"
app(environ, start_response)
assert ACCESS_CONTROL_ALLOW_ORIGIN in start_response.get_headers()
def test_get_regex_doesnt_match(self):
app, environ, start_response = setup(
**self.middleware,
cors_urls_regex=r"^/not-foo/$",
cors_allowed_origins=["https://example.com"],
)
environ["HTTP_ORIGIN"] = "https://example.com"
environ["HTTP_ACCESS_CONTROL_REQUEST_METHOD"] = "GET"
environ["PATH_INFO"] = "/foo"
environ["REQUEST_METHOD"] = "GET"
app(environ, start_response)
assert ACCESS_CONTROL_ALLOW_ORIGIN not in start_response.get_headers()
def test_works_if_view_deletes_cors_enabled(self):
"""
Just in case something crazy happens in the view or other middleware,
check that get_response doesn't fall over if `_cors_enabled` is removed
"""
def yeet(request):
del request._cors_enabled
return HttpResponse("hahaha")
app, environ, start_response = setup(
**self.middleware,
cors_allowed_origins=["https://example.com"],
routes=[('/yeet', yeet)]
)
environ["HTTP_ORIGIN"] = "https://example.com"
environ["PATH_INFO"] = "/yeet"
environ["REQUEST_METHOD"] = "GET"
app(environ, start_response)
assert ACCESS_CONTROL_ALLOW_ORIGIN in start_response.get_headers()

View File

@ -75,15 +75,13 @@ def test_redirect_response():
def test_add_route_at_server_start(): def test_add_route_at_server_start():
app, environ, start_response = setup()
def index(request): def index(request):
return RedirectResponse(location="/redirected") return RedirectResponse(location="/redirected")
def view2(request): def view2(request):
return HttpResponse("View 2") return HttpResponse("View 2")
app = SpiderwebRouter( app, environ, start_response = setup(
routes=[ routes=[
("/", index, {"allowed_methods": ["GET", "POST"], "csrf_exempt": True}), ("/", index, {"allowed_methods": ["GET", "POST"], "csrf_exempt": True}),
("/view2", view2), ("/view2", view2),
@ -95,8 +93,7 @@ def test_add_route_at_server_start():
def test_redirect_on_append_slash(): def test_redirect_on_append_slash():
_, environ, start_response = setup() app, environ, start_response = setup(append_slash=True)
app = SpiderwebRouter(append_slash=True)
@app.route("/hello") @app.route("/hello")
def index(request): def index(request):
@ -109,9 +106,7 @@ def test_redirect_on_append_slash():
@given(st.text()) @given(st.text())
def test_template_response_with_template(text): def test_template_response_with_template(text):
_, environ, start_response = setup() app, environ, start_response = setup(templates_dirs=["spiderweb/tests"])
app = SpiderwebRouter(templates_dirs=["spiderweb/tests"])
@app.route("/") @app.route("/")
def index(request): def index(request):
@ -174,11 +169,10 @@ def test_duplicate_error_view():
def test_missing_view_with_custom_404_alt(): def test_missing_view_with_custom_404_alt():
_, environ, start_response = setup()
def custom_404(request): def custom_404(request):
return HttpResponse("Custom 404 2") return HttpResponse("Custom 404 2")
app = SpiderwebRouter(error_routes={404: custom_404}) app, environ, start_response = setup(error_routes={404: custom_404})
assert app(environ, start_response) == [b"Custom 404 2"] assert app(environ, start_response) == [b"Custom 404 2"]

View File

@ -1,3 +1,4 @@
from spiderweb import HttpResponse
from spiderweb.decorators import csrf_exempt from spiderweb.decorators import csrf_exempt
from spiderweb.response import JsonResponse, TemplateResponse from spiderweb.response import JsonResponse, TemplateResponse
@ -38,3 +39,11 @@ def form_view_with_csrf(request):
return JsonResponse(data=request.POST) return JsonResponse(data=request.POST)
else: else:
return TemplateResponse(request, template_string=EXAMPLE_HTML_FORM_WITH_CSRF) return TemplateResponse(request, template_string=EXAMPLE_HTML_FORM_WITH_CSRF)
def text_view(request):
return HttpResponse("Hi!")
def unauthorized_view(request):
return HttpResponse("Unauthorized", status_code=401)

View File

@ -67,15 +67,35 @@ def is_jsonable(data: str) -> bool:
class Headers(dict): class Headers(dict):
# special dict that forces lowercase for all keys # special dict that forces lowercase and snake_case for all keys
def __getitem__(self, key): def __getitem__(self, key):
return super().__getitem__(key.lower()) key = key.replace("-", "_")
try:
regular = super().__getitem__(key.lower())
except KeyError:
regular = None
try:
http_version = super().__getitem__(f"http_{key.lower()}")
except KeyError:
http_version = None
return regular or http_version
def __contains__(self, item):
item = item.lower().replace("-", "_")
regular = super().__contains__(item)
http = super().__contains__(f"http_{item}")
return regular or http
def __setitem__(self, key, value): def __setitem__(self, key, value):
return super().__setitem__(key.lower(), value) return super().__setitem__(key.lower().replace("-", "_"), value)
def get(self, key, default=None): def get(self, key, default=None):
return super().get(key.lower(), default) key = key.replace("-", "_")
regular = super().get(key.lower(), default)
http_version = super().get(f"http_{key.lower()}", default)
return regular or http_version
def setdefault(self, key, default=None): def setdefault(self, key, default=None):
return super().setdefault(key.lower(), default) return super().setdefault(key.lower(), default)