Compare commits
No commits in common. "0d1fba1aad68a7b64ffed6485d033bd5c2f081fc" and "8cdc6eef4438296c076eff94e962bd92ce73c794" have entirely different histories.
0d1fba1aad
...
8cdc6eef44
@ -25,7 +25,6 @@ app = SpiderwebRouter(
|
|||||||
],
|
],
|
||||||
staticfiles_dirs=["static_files"],
|
staticfiles_dirs=["static_files"],
|
||||||
append_slash=False, # default
|
append_slash=False, # default
|
||||||
cors_allow_all_origins=True,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
[tool.poetry]
|
[tool.poetry]
|
||||||
name = "spiderweb-framework"
|
name = "spiderweb-framework"
|
||||||
version = "1.0.0"
|
version = "0.12.0"
|
||||||
description = "A small web framework, just big enough for a spider."
|
description = "A small web framework, just big enough for a spider."
|
||||||
authors = ["Joe Kaufeld <opensource@joekaufeld.com>"]
|
authors = ["Joe Kaufeld <opensource@joekaufeld.com>"]
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
|
@ -1,8 +1,8 @@
|
|||||||
from peewee import DatabaseProxy
|
from peewee import DatabaseProxy
|
||||||
|
|
||||||
DEFAULT_ALLOWED_METHODS = ["POST", "GET", "PUT", "PATCH", "DELETE"]
|
DEFAULT_ALLOWED_METHODS = ["GET"]
|
||||||
DEFAULT_ENCODING = "UTF-8"
|
DEFAULT_ENCODING = "UTF-8"
|
||||||
__version__ = "1.0.0"
|
__version__ = "0.12.0"
|
||||||
|
|
||||||
# https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Set-Cookie
|
# https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Set-Cookie
|
||||||
REGEX_COOKIE_NAME = r"^[a-zA-Z0-9\s\(\)<>@,;:\/\\\[\]\?=\{\}\"\t]*$"
|
REGEX_COOKIE_NAME = r"^[a-zA-Z0-9\s\(\)<>@,;:\/\\\[\]\?=\{\}\"\t]*$"
|
||||||
|
@ -50,7 +50,7 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi
|
|||||||
port: int = None,
|
port: int = None,
|
||||||
allowed_hosts: Sequence[str | re.Pattern] = None,
|
allowed_hosts: Sequence[str | re.Pattern] = None,
|
||||||
cors_allowed_origins: Sequence[str] = None,
|
cors_allowed_origins: Sequence[str] = None,
|
||||||
cors_allowed_origin_regexes: Sequence[str] = None,
|
cors_allowed_origins_regexes: Sequence[str] = None,
|
||||||
cors_allow_all_origins: bool = False,
|
cors_allow_all_origins: bool = False,
|
||||||
cors_urls_regex: str | re.Pattern[str] = r"^.*$",
|
cors_urls_regex: str | re.Pattern[str] = r"^.*$",
|
||||||
cors_allow_methods: Sequence[str] = None,
|
cors_allow_methods: Sequence[str] = None,
|
||||||
@ -94,7 +94,7 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi
|
|||||||
self.allowed_hosts = [convert_url_to_regex(i) for i in self._allowed_hosts]
|
self.allowed_hosts = [convert_url_to_regex(i) for i in self._allowed_hosts]
|
||||||
|
|
||||||
self.cors_allowed_origins = cors_allowed_origins or []
|
self.cors_allowed_origins = cors_allowed_origins or []
|
||||||
self.cors_allowed_origin_regexes = cors_allowed_origin_regexes or []
|
self.cors_allowed_origins_regexes = cors_allowed_origins_regexes or []
|
||||||
self.cors_allow_all_origins = cors_allow_all_origins
|
self.cors_allow_all_origins = cors_allow_all_origins
|
||||||
self.cors_urls_regex = cors_urls_regex
|
self.cors_urls_regex = cors_urls_regex
|
||||||
self.cors_allow_methods = cors_allow_methods or DEFAULT_CORS_ALLOW_METHODS
|
self.cors_allow_methods = cors_allow_methods or DEFAULT_CORS_ALLOW_METHODS
|
||||||
@ -171,19 +171,17 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi
|
|||||||
status = get_http_status_by_code(resp.status_code)
|
status = get_http_status_by_code(resp.status_code)
|
||||||
cookies = []
|
cookies = []
|
||||||
varies = []
|
varies = []
|
||||||
resp.headers = {k.replace("_", "-"): v for k, v in resp.headers.items()}
|
|
||||||
if "set-cookie" in resp.headers:
|
if "set-cookie" in resp.headers:
|
||||||
cookies = resp.headers["set-cookie"]
|
cookies = resp.headers["set-cookie"]
|
||||||
del resp.headers["set-cookie"]
|
del resp.headers["set-cookie"]
|
||||||
if "vary" in resp.headers:
|
if "vary" in resp.headers:
|
||||||
varies = resp.headers["vary"]
|
varies = resp.headers["vary"]
|
||||||
del resp.headers["vary"]
|
del resp.headers["vary"]
|
||||||
resp.headers = {k: str(v) for k, v in resp.headers.items()}
|
|
||||||
headers = list(resp.headers.items())
|
headers = list(resp.headers.items())
|
||||||
for c in cookies:
|
for c in cookies:
|
||||||
headers.append(("set-cookie", str(c)))
|
headers.append(("Set-Cookie", c))
|
||||||
for v in varies:
|
for v in varies:
|
||||||
headers.append(("vary", str(v)))
|
headers.append(("Vary", v))
|
||||||
|
|
||||||
start_response(status, headers)
|
start_response(status, headers)
|
||||||
|
|
||||||
@ -273,6 +271,7 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi
|
|||||||
def __call__(self, environ, start_response, *args, **kwargs):
|
def __call__(self, environ, start_response, *args, **kwargs):
|
||||||
"""Entry point for WSGI apps."""
|
"""Entry point for WSGI apps."""
|
||||||
request = self.get_request(environ)
|
request = self.get_request(environ)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
handler, additional_args, allowed_methods = self.get_route(request.path)
|
handler, additional_args, allowed_methods = self.get_route(request.path)
|
||||||
except NotFound:
|
except NotFound:
|
||||||
|
@ -38,7 +38,7 @@ class MiddlewareMixin:
|
|||||||
|
|
||||||
if errors:
|
if errors:
|
||||||
# just show the messages
|
# just show the messages
|
||||||
sys.tracebacklimit = 1
|
sys.tracebacklimit = 0
|
||||||
raise StartupErrors(
|
raise StartupErrors(
|
||||||
"Problems were identified during startup — cannot continue.", errors
|
"Problems were identified during startup — cannot continue.", errors
|
||||||
)
|
)
|
||||||
|
@ -23,11 +23,13 @@ class VerifyValidCorsSetting(ServerCheck):
|
|||||||
" `cors_allowed_origins`, `cors_allowed_origin_regexes`, or"
|
" `cors_allowed_origins`, `cors_allowed_origin_regexes`, or"
|
||||||
" `cors_allow_all_origins`.",
|
" `cors_allow_all_origins`.",
|
||||||
)
|
)
|
||||||
|
|
||||||
def check(self):
|
def check(self):
|
||||||
|
# - `cors_allowed_origins`
|
||||||
|
# - `cors_allowed_origin_regexes`
|
||||||
|
# - `cors_allow_all_origins`
|
||||||
if (
|
if (
|
||||||
not self.server.cors_allowed_origins
|
not self.server.cors_allowed_origins
|
||||||
and not self.server.cors_allowed_origin_regexes
|
and not self.server.cors.allowed_origin_regexes
|
||||||
and not self.server.cors_allow_all_origins
|
and not self.server.cors_allow_all_origins
|
||||||
):
|
):
|
||||||
return ConfigError(self.INVALID_BASE_CONFIG)
|
return ConfigError(self.INVALID_BASE_CONFIG)
|
||||||
@ -49,6 +51,7 @@ class CorsMiddleware(SpiderwebMiddleware):
|
|||||||
enabled = getattr(request, "_cors_enabled", None)
|
enabled = getattr(request, "_cors_enabled", None)
|
||||||
if enabled is None:
|
if enabled is None:
|
||||||
enabled = self.is_enabled(request)
|
enabled = self.is_enabled(request)
|
||||||
|
|
||||||
if not enabled:
|
if not enabled:
|
||||||
return response
|
return response
|
||||||
|
|
||||||
@ -57,7 +60,7 @@ class CorsMiddleware(SpiderwebMiddleware):
|
|||||||
else:
|
else:
|
||||||
response.headers["vary"] = ["origin"]
|
response.headers["vary"] = ["origin"]
|
||||||
|
|
||||||
origin = request.headers.get("http_origin")
|
origin = request.headers.get("origin")
|
||||||
if not origin:
|
if not origin:
|
||||||
return response
|
return response
|
||||||
|
|
||||||
@ -99,9 +102,10 @@ class CorsMiddleware(SpiderwebMiddleware):
|
|||||||
response.headers[ACCESS_CONTROL_MAX_AGE] = str(
|
response.headers[ACCESS_CONTROL_MAX_AGE] = str(
|
||||||
self.server.cors_preflight_max_age
|
self.server.cors_preflight_max_age
|
||||||
)
|
)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
self.server.cors_allow_private_network
|
self.server.cors_allow_private_network
|
||||||
and request.headers.get(ACCESS_CONTROL_REQUEST_PRIVATE_NETWORK.replace("-", "_")) == "true"
|
and request.headers.get(ACCESS_CONTROL_REQUEST_PRIVATE_NETWORK) == "true"
|
||||||
):
|
):
|
||||||
response.headers[ACCESS_CONTROL_ALLOW_PRIVATE_NETWORK] = "true"
|
response.headers[ACCESS_CONTROL_ALLOW_PRIVATE_NETWORK] = "true"
|
||||||
|
|
||||||
@ -129,8 +133,8 @@ class CorsMiddleware(SpiderwebMiddleware):
|
|||||||
|
|
||||||
def process_request(self, request: Request) -> HttpResponse | None:
|
def process_request(self, request: Request) -> HttpResponse | None:
|
||||||
# Identify and handle a preflight request
|
# Identify and handle a preflight request
|
||||||
|
# origin = request.META.get("HTTP_ORIGIN")
|
||||||
request._cors_enabled = self.is_enabled(request)
|
request._cors_enabled = self.is_enabled(request)
|
||||||
request.META["cors_ran"] = True
|
|
||||||
if (
|
if (
|
||||||
request._cors_enabled
|
request._cors_enabled
|
||||||
and request.method == "OPTIONS"
|
and request.method == "OPTIONS"
|
||||||
@ -146,13 +150,9 @@ class CorsMiddleware(SpiderwebMiddleware):
|
|||||||
self.add_response_headers(request, resp)
|
self.add_response_headers(request, resp)
|
||||||
return resp
|
return resp
|
||||||
|
|
||||||
def process_response(self, request: Request, response: HttpResponse) -> None:
|
def process_response(
|
||||||
if not request.META.get("cors_ran"):
|
self, request: Request, response: HttpResponse
|
||||||
# something happened and process_request didn't run. Abort early.
|
) -> None:
|
||||||
# We're not relying on request._cors_enabled because it's more
|
|
||||||
# visible and the view may have destroyed it accidentally.
|
|
||||||
return
|
|
||||||
self.add_response_headers(request, response)
|
self.add_response_headers(request, response)
|
||||||
|
|
||||||
|
|
||||||
# [204]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods/OPTIONS#status_code
|
# [204]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods/OPTIONS#status_code
|
||||||
|
@ -72,9 +72,7 @@ class CSRFMiddleware(SpiderwebMiddleware):
|
|||||||
|
|
||||||
def is_trusted_origin(self, request) -> bool:
|
def is_trusted_origin(self, request) -> bool:
|
||||||
origin = request.headers.get("http_origin")
|
origin = request.headers.get("http_origin")
|
||||||
referrer = request.headers.get("http_referer") or request.headers.get(
|
referrer = request.headers.get("http_referer") or request.headers.get("http_referrer")
|
||||||
"http_referrer"
|
|
||||||
)
|
|
||||||
host = request.headers.get("http_host")
|
host = request.headers.get("http_host")
|
||||||
|
|
||||||
if not origin and not (host == referrer):
|
if not origin and not (host == referrer):
|
||||||
@ -90,12 +88,13 @@ class CSRFMiddleware(SpiderwebMiddleware):
|
|||||||
|
|
||||||
def process_request(self, request: Request) -> HttpResponse | None:
|
def process_request(self, request: Request) -> HttpResponse | None:
|
||||||
if request.method == "POST":
|
if request.method == "POST":
|
||||||
|
|
||||||
if hasattr(request.handler, "csrf_exempt"):
|
if hasattr(request.handler, "csrf_exempt"):
|
||||||
if request.handler.csrf_exempt is True:
|
if request.handler.csrf_exempt is True:
|
||||||
return
|
return
|
||||||
|
|
||||||
csrf_token = (
|
csrf_token = (
|
||||||
request.headers.get("x-csrf-token")
|
request.headers.get("X-CSRF-TOKEN")
|
||||||
or request.GET.get("csrf_token")
|
or request.GET.get("csrf_token")
|
||||||
or request.POST.get("csrf_token")
|
or request.POST.get("csrf_token")
|
||||||
)
|
)
|
||||||
@ -110,7 +109,7 @@ class CSRFMiddleware(SpiderwebMiddleware):
|
|||||||
def process_response(self, request: Request, response: HttpResponse) -> None:
|
def process_response(self, request: Request, response: HttpResponse) -> None:
|
||||||
token = self.get_csrf_token(request)
|
token = self.get_csrf_token(request)
|
||||||
# do we need it in both places?
|
# do we need it in both places?
|
||||||
response.headers["x-csrf-token"] = token
|
response.headers["X-CSRF-TOKEN"] = token
|
||||||
response.context |= {
|
response.context |= {
|
||||||
"csrf_token": f"""<input type="hidden" name="csrf_token" value="{token}">""",
|
"csrf_token": f"""<input type="hidden" name="csrf_token" value="{token}">""",
|
||||||
"raw_csrf_token": token, # in case they want to format it themselves
|
"raw_csrf_token": token, # in case they want to format it themselves
|
||||||
|
@ -15,16 +15,14 @@ class StartResponse:
|
|||||||
self.headers = headers
|
self.headers = headers
|
||||||
|
|
||||||
def get_headers(self):
|
def get_headers(self):
|
||||||
return {h[0]: h[1] for h in self.headers} if self.headers else {}
|
return {h[0]: h[1] for h in self.headers}
|
||||||
|
|
||||||
|
|
||||||
def setup(**kwargs):
|
def setup():
|
||||||
environ = {}
|
environ = {}
|
||||||
setup_testing_defaults(environ)
|
setup_testing_defaults(environ)
|
||||||
if "db" not in kwargs:
|
|
||||||
kwargs["db"] = SqliteDatabase("spiderweb-tests.db")
|
|
||||||
return (
|
return (
|
||||||
SpiderwebRouter(**kwargs),
|
SpiderwebRouter(db=SqliteDatabase("spiderweb-tests.db")),
|
||||||
environ,
|
environ,
|
||||||
StartResponse(),
|
StartResponse(),
|
||||||
)
|
)
|
||||||
|
@ -11,8 +11,3 @@ class ExplodingResponseMiddleware(SpiderwebMiddleware):
|
|||||||
self, request: Request, response: HttpResponse
|
self, request: Request, response: HttpResponse
|
||||||
) -> HttpResponse | None:
|
) -> HttpResponse | None:
|
||||||
raise UnusedMiddleware("Unfinished!")
|
raise UnusedMiddleware("Unfinished!")
|
||||||
|
|
||||||
|
|
||||||
class InterruptingMiddleware(SpiderwebMiddleware):
|
|
||||||
def process_request(self, request: Request) -> HttpResponse:
|
|
||||||
return HttpResponse("Moo!")
|
|
@ -4,27 +4,30 @@ from datetime import timedelta
|
|||||||
import pytest
|
import pytest
|
||||||
from peewee import SqliteDatabase
|
from peewee import SqliteDatabase
|
||||||
|
|
||||||
from spiderweb import SpiderwebRouter, HttpResponse, StartupErrors
|
from spiderweb import SpiderwebRouter, HttpResponse, ConfigError, StartupErrors
|
||||||
from spiderweb.constants import DEFAULT_ENCODING
|
from spiderweb.constants import DEFAULT_ENCODING
|
||||||
from spiderweb.middleware.cors import (
|
|
||||||
ACCESS_CONTROL_ALLOW_ORIGIN,
|
|
||||||
ACCESS_CONTROL_ALLOW_HEADERS,
|
|
||||||
ACCESS_CONTROL_ALLOW_METHODS,
|
|
||||||
ACCESS_CONTROL_EXPOSE_HEADERS,
|
|
||||||
ACCESS_CONTROL_ALLOW_CREDENTIALS,
|
|
||||||
ACCESS_CONTROL_MAX_AGE,
|
|
||||||
ACCESS_CONTROL_ALLOW_PRIVATE_NETWORK,
|
|
||||||
)
|
|
||||||
from spiderweb.middleware.sessions import Session
|
from spiderweb.middleware.sessions import Session
|
||||||
from spiderweb.middleware import csrf
|
from spiderweb.middleware import csrf
|
||||||
from spiderweb.tests.helpers import setup
|
from spiderweb.tests.helpers import setup
|
||||||
from spiderweb.tests.views_for_tests import (
|
from spiderweb.tests.views_for_tests import (
|
||||||
form_view_with_csrf,
|
form_view_with_csrf,
|
||||||
form_csrf_exempt,
|
form_csrf_exempt,
|
||||||
form_view_without_csrf, text_view, unauthorized_view,
|
form_view_without_csrf,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# app = SpiderwebRouter(
|
||||||
|
# middleware=[
|
||||||
|
# "spiderweb.middleware.sessions.SessionMiddleware",
|
||||||
|
# "spiderweb.middleware.csrf.CSRFMiddleware",
|
||||||
|
# "example_middleware.TestMiddleware",
|
||||||
|
# "example_middleware.RedirectMiddleware",
|
||||||
|
# "spiderweb.middleware.pydantic.PydanticMiddleware",
|
||||||
|
# "example_middleware.ExplodingMiddleware",
|
||||||
|
# ],
|
||||||
|
# )
|
||||||
|
|
||||||
|
|
||||||
def index(request):
|
def index(request):
|
||||||
if "value" in request.SESSION:
|
if "value" in request.SESSION:
|
||||||
request.SESSION["value"] += 1
|
request.SESSION["value"] += 1
|
||||||
@ -34,8 +37,10 @@ def index(request):
|
|||||||
|
|
||||||
|
|
||||||
def test_session_middleware():
|
def test_session_middleware():
|
||||||
app, environ, start_response = setup(
|
_, environ, start_response = setup()
|
||||||
|
app = SpiderwebRouter(
|
||||||
middleware=["spiderweb.middleware.sessions.SessionMiddleware"],
|
middleware=["spiderweb.middleware.sessions.SessionMiddleware"],
|
||||||
|
db=SqliteDatabase("spiderweb-tests.db"),
|
||||||
)
|
)
|
||||||
|
|
||||||
app.add_route("/", index)
|
app.add_route("/", index)
|
||||||
@ -53,8 +58,10 @@ def test_session_middleware():
|
|||||||
|
|
||||||
|
|
||||||
def test_expired_session():
|
def test_expired_session():
|
||||||
app, environ, start_response = setup(
|
_, environ, start_response = setup()
|
||||||
|
app = SpiderwebRouter(
|
||||||
middleware=["spiderweb.middleware.sessions.SessionMiddleware"],
|
middleware=["spiderweb.middleware.sessions.SessionMiddleware"],
|
||||||
|
db=SqliteDatabase("spiderweb-tests.db"),
|
||||||
)
|
)
|
||||||
|
|
||||||
app.add_route("/", index)
|
app.add_route("/", index)
|
||||||
@ -78,11 +85,13 @@ def test_expired_session():
|
|||||||
|
|
||||||
|
|
||||||
def test_exploding_middleware():
|
def test_exploding_middleware():
|
||||||
app, environ, start_response = setup(
|
_, environ, start_response = setup()
|
||||||
|
app = SpiderwebRouter(
|
||||||
middleware=[
|
middleware=[
|
||||||
"spiderweb.tests.middleware.ExplodingRequestMiddleware",
|
"spiderweb.tests.middleware.ExplodingRequestMiddleware",
|
||||||
"spiderweb.tests.middleware.ExplodingResponseMiddleware",
|
"spiderweb.tests.middleware.ExplodingResponseMiddleware",
|
||||||
],
|
],
|
||||||
|
db=SqliteDatabase("spiderweb-tests.db"),
|
||||||
)
|
)
|
||||||
|
|
||||||
app.add_route("/", index)
|
app.add_route("/", index)
|
||||||
@ -93,6 +102,7 @@ def test_exploding_middleware():
|
|||||||
|
|
||||||
|
|
||||||
def test_csrf_middleware_without_session_middleware():
|
def test_csrf_middleware_without_session_middleware():
|
||||||
|
_, environ, start_response = setup()
|
||||||
with pytest.raises(StartupErrors) as e:
|
with pytest.raises(StartupErrors) as e:
|
||||||
SpiderwebRouter(
|
SpiderwebRouter(
|
||||||
middleware=["spiderweb.middleware.csrf.CSRFMiddleware"],
|
middleware=["spiderweb.middleware.csrf.CSRFMiddleware"],
|
||||||
@ -106,14 +116,15 @@ def test_csrf_middleware_without_session_middleware():
|
|||||||
|
|
||||||
|
|
||||||
def test_csrf_middleware_above_session_middleware():
|
def test_csrf_middleware_above_session_middleware():
|
||||||
|
_, environ, start_response = setup()
|
||||||
with pytest.raises(StartupErrors) as e:
|
with pytest.raises(StartupErrors) as e:
|
||||||
app, environ, start_response = setup(
|
SpiderwebRouter(
|
||||||
middleware=[
|
middleware=[
|
||||||
"spiderweb.middleware.csrf.CSRFMiddleware",
|
"spiderweb.middleware.csrf.CSRFMiddleware",
|
||||||
"spiderweb.middleware.sessions.SessionMiddleware",
|
"spiderweb.middleware.sessions.SessionMiddleware",
|
||||||
],
|
],
|
||||||
|
db=SqliteDatabase("spiderweb-tests.db"),
|
||||||
)
|
)
|
||||||
|
|
||||||
exceptiongroup = e.value.args[1]
|
exceptiongroup = e.value.args[1]
|
||||||
assert (
|
assert (
|
||||||
exceptiongroup[0].args[0]
|
exceptiongroup[0].args[0]
|
||||||
@ -122,11 +133,13 @@ def test_csrf_middleware_above_session_middleware():
|
|||||||
|
|
||||||
|
|
||||||
def test_csrf_middleware():
|
def test_csrf_middleware():
|
||||||
app, environ, start_response = setup(
|
_, environ, start_response = setup()
|
||||||
|
app = SpiderwebRouter(
|
||||||
middleware=[
|
middleware=[
|
||||||
"spiderweb.middleware.sessions.SessionMiddleware",
|
"spiderweb.middleware.sessions.SessionMiddleware",
|
||||||
"spiderweb.middleware.csrf.CSRFMiddleware",
|
"spiderweb.middleware.csrf.CSRFMiddleware",
|
||||||
],
|
],
|
||||||
|
db=SqliteDatabase("spiderweb-tests.db"),
|
||||||
)
|
)
|
||||||
|
|
||||||
app.add_route("/", form_view_with_csrf, ["GET", "POST"])
|
app.add_route("/", form_view_with_csrf, ["GET", "POST"])
|
||||||
@ -166,8 +179,8 @@ def test_csrf_middleware():
|
|||||||
b_handle = BytesIO()
|
b_handle = BytesIO()
|
||||||
b_handle.write(formdata.encode(DEFAULT_ENCODING))
|
b_handle.write(formdata.encode(DEFAULT_ENCODING))
|
||||||
b_handle.seek(0)
|
b_handle.seek(0)
|
||||||
|
|
||||||
environ["wsgi.input"] = BufferedReader(b_handle)
|
environ["wsgi.input"] = BufferedReader(b_handle)
|
||||||
environ["HTTP_X_CSRF_TOKEN"] = None
|
|
||||||
resp3 = app(environ, start_response)[0].decode(DEFAULT_ENCODING)
|
resp3 = app(environ, start_response)[0].decode(DEFAULT_ENCODING)
|
||||||
assert "CSRF token is invalid" in resp3
|
assert "CSRF token is invalid" in resp3
|
||||||
|
|
||||||
@ -185,13 +198,14 @@ def test_csrf_middleware():
|
|||||||
|
|
||||||
|
|
||||||
def test_csrf_expired_token():
|
def test_csrf_expired_token():
|
||||||
app, environ, start_response = setup(
|
_, environ, start_response = setup()
|
||||||
|
app = SpiderwebRouter(
|
||||||
middleware=[
|
middleware=[
|
||||||
"spiderweb.middleware.sessions.SessionMiddleware",
|
"spiderweb.middleware.sessions.SessionMiddleware",
|
||||||
"spiderweb.middleware.csrf.CSRFMiddleware",
|
"spiderweb.middleware.csrf.CSRFMiddleware",
|
||||||
],
|
],
|
||||||
|
db=SqliteDatabase("spiderweb-tests.db"),
|
||||||
)
|
)
|
||||||
|
|
||||||
app.middleware[1].CSRF_EXPIRY = -1
|
app.middleware[1].CSRF_EXPIRY = -1
|
||||||
|
|
||||||
app.add_route("/", form_view_with_csrf, ["GET", "POST"])
|
app.add_route("/", form_view_with_csrf, ["GET", "POST"])
|
||||||
@ -221,11 +235,13 @@ def test_csrf_expired_token():
|
|||||||
|
|
||||||
|
|
||||||
def test_csrf_exempt():
|
def test_csrf_exempt():
|
||||||
app, environ, start_response = setup(
|
_, environ, start_response = setup()
|
||||||
|
app = SpiderwebRouter(
|
||||||
middleware=[
|
middleware=[
|
||||||
"spiderweb.middleware.sessions.SessionMiddleware",
|
"spiderweb.middleware.sessions.SessionMiddleware",
|
||||||
"spiderweb.middleware.csrf.CSRFMiddleware",
|
"spiderweb.middleware.csrf.CSRFMiddleware",
|
||||||
],
|
],
|
||||||
|
db=SqliteDatabase("spiderweb-tests.db"),
|
||||||
)
|
)
|
||||||
|
|
||||||
app.add_route("/", form_csrf_exempt, ["GET", "POST"])
|
app.add_route("/", form_csrf_exempt, ["GET", "POST"])
|
||||||
@ -252,7 +268,8 @@ def test_csrf_exempt():
|
|||||||
|
|
||||||
|
|
||||||
def test_csrf_trusted_origins():
|
def test_csrf_trusted_origins():
|
||||||
app, environ, start_response = setup(
|
_, environ, start_response = setup()
|
||||||
|
app = SpiderwebRouter(
|
||||||
middleware=[
|
middleware=[
|
||||||
"spiderweb.middleware.sessions.SessionMiddleware",
|
"spiderweb.middleware.sessions.SessionMiddleware",
|
||||||
"spiderweb.middleware.csrf.CSRFMiddleware",
|
"spiderweb.middleware.csrf.CSRFMiddleware",
|
||||||
@ -260,7 +277,9 @@ def test_csrf_trusted_origins():
|
|||||||
csrf_trusted_origins=[
|
csrf_trusted_origins=[
|
||||||
"example.com",
|
"example.com",
|
||||||
],
|
],
|
||||||
|
db=SqliteDatabase("spiderweb-tests.db"),
|
||||||
)
|
)
|
||||||
|
|
||||||
app.add_route("/", form_view_without_csrf, ["GET", "POST"])
|
app.add_route("/", form_view_without_csrf, ["GET", "POST"])
|
||||||
|
|
||||||
environ["HTTP_USER_AGENT"] = "hi"
|
environ["HTTP_USER_AGENT"] = "hi"
|
||||||
@ -287,448 +306,3 @@ def test_csrf_trusted_origins():
|
|||||||
environ["HTTP_ORIGIN"] = "example.com"
|
environ["HTTP_ORIGIN"] = "example.com"
|
||||||
resp2 = app(environ, start_response)[0].decode(DEFAULT_ENCODING)
|
resp2 = app(environ, start_response)[0].decode(DEFAULT_ENCODING)
|
||||||
assert resp2 == '{"name": "bob"}'
|
assert resp2 == '{"name": "bob"}'
|
||||||
|
|
||||||
|
|
||||||
class TestCorsMiddleware:
|
|
||||||
# adapted from:
|
|
||||||
# https://github.com/adamchainz/django-cors-headers/blob/main/tests/test_middleware.py
|
|
||||||
# to make sure I didn't miss anything
|
|
||||||
middleware = {"middleware": ["spiderweb.middleware.cors.CorsMiddleware"]}
|
|
||||||
|
|
||||||
def test_get_no_origin(self):
|
|
||||||
app, environ, start_response = setup(
|
|
||||||
**self.middleware, cors_allow_all_origins=True
|
|
||||||
)
|
|
||||||
app(environ, start_response)
|
|
||||||
assert ACCESS_CONTROL_ALLOW_ORIGIN not in start_response.get_headers()
|
|
||||||
|
|
||||||
def test_get_origin_vary_by_default(self):
|
|
||||||
app, environ, start_response = setup(
|
|
||||||
**self.middleware, cors_allow_all_origins=True
|
|
||||||
)
|
|
||||||
app(environ, start_response)
|
|
||||||
assert start_response.get_headers()["vary"] == "origin"
|
|
||||||
|
|
||||||
def test_get_invalid_origin(self):
|
|
||||||
app, environ, start_response = setup(
|
|
||||||
**self.middleware, cors_allow_all_origins=True
|
|
||||||
)
|
|
||||||
environ["HTTP_ORIGIN"] = "https://example.com]"
|
|
||||||
app(environ, start_response)
|
|
||||||
assert ACCESS_CONTROL_ALLOW_ORIGIN not in start_response.get_headers()
|
|
||||||
|
|
||||||
def test_get_not_in_allowed_origins(self):
|
|
||||||
app, environ, start_response = setup(
|
|
||||||
**self.middleware, cors_allowed_origins=["https://example.com"]
|
|
||||||
)
|
|
||||||
environ["HTTP_ORIGIN"] = "https://example.org"
|
|
||||||
app(environ, start_response)
|
|
||||||
assert ACCESS_CONTROL_ALLOW_ORIGIN not in start_response.get_headers()
|
|
||||||
|
|
||||||
def test_get_not_in_allowed_origins_due_to_wrong_scheme(self):
|
|
||||||
app, environ, start_response = setup(
|
|
||||||
**self.middleware, cors_allowed_origins=["http://example.org"]
|
|
||||||
)
|
|
||||||
environ["HTTP_ORIGIN"] = "https://example.org"
|
|
||||||
app(environ, start_response)
|
|
||||||
assert ACCESS_CONTROL_ALLOW_ORIGIN not in start_response.get_headers()
|
|
||||||
|
|
||||||
def test_get_in_allowed_origins(self):
|
|
||||||
app, environ, start_response = setup(
|
|
||||||
**self.middleware,
|
|
||||||
cors_allowed_origins=["https://example.com", "https://example.org"],
|
|
||||||
)
|
|
||||||
environ["HTTP_ORIGIN"] = "https://example.org"
|
|
||||||
app(environ, start_response)
|
|
||||||
assert (
|
|
||||||
start_response.get_headers()[ACCESS_CONTROL_ALLOW_ORIGIN]
|
|
||||||
== "https://example.org"
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_null_in_allowed_origins(self):
|
|
||||||
app, environ, start_response = setup(
|
|
||||||
**self.middleware,
|
|
||||||
cors_allowed_origins=["https://example.com", "null"],
|
|
||||||
)
|
|
||||||
environ["HTTP_ORIGIN"] = "null"
|
|
||||||
app(environ, start_response)
|
|
||||||
assert start_response.get_headers()[ACCESS_CONTROL_ALLOW_ORIGIN] == "null"
|
|
||||||
|
|
||||||
def test_file_in_allowed_origins(self):
|
|
||||||
"""
|
|
||||||
'file://' should be allowed as an origin since Chrome on Android
|
|
||||||
mistakenly sends it
|
|
||||||
"""
|
|
||||||
app, environ, start_response = setup(
|
|
||||||
**self.middleware,
|
|
||||||
cors_allowed_origins=["https://example.com", "file://"],
|
|
||||||
)
|
|
||||||
environ["HTTP_ORIGIN"] = "file://"
|
|
||||||
app(environ, start_response)
|
|
||||||
assert start_response.get_headers()[ACCESS_CONTROL_ALLOW_ORIGIN] == "file://"
|
|
||||||
|
|
||||||
def test_get_expose_headers(self):
|
|
||||||
app, environ, start_response = setup(
|
|
||||||
**self.middleware,
|
|
||||||
cors_allow_all_origins=True,
|
|
||||||
cors_expose_headers=["accept", "content-type"]
|
|
||||||
)
|
|
||||||
environ["HTTP_ORIGIN"] = "https://example.com"
|
|
||||||
app(environ, start_response)
|
|
||||||
assert start_response.get_headers()[ACCESS_CONTROL_EXPOSE_HEADERS] == "accept, content-type"
|
|
||||||
|
|
||||||
def test_get_dont_expose_headers(self):
|
|
||||||
app, environ, start_response = setup(
|
|
||||||
**self.middleware,
|
|
||||||
cors_allow_all_origins=True,
|
|
||||||
)
|
|
||||||
environ["HTTP_ORIGIN"] = "https://example.com"
|
|
||||||
app(environ, start_response)
|
|
||||||
assert ACCESS_CONTROL_EXPOSE_HEADERS not in start_response.get_headers()
|
|
||||||
|
|
||||||
def test_get_allow_credentials(self):
|
|
||||||
app, environ, start_response = setup(
|
|
||||||
**self.middleware,
|
|
||||||
cors_allowed_origins=["https://example.com"],
|
|
||||||
cors_allow_credentials=True,
|
|
||||||
)
|
|
||||||
environ["HTTP_ORIGIN"] = "https://example.com"
|
|
||||||
app(environ, start_response)
|
|
||||||
assert start_response.get_headers()[ACCESS_CONTROL_ALLOW_CREDENTIALS] == "true"
|
|
||||||
|
|
||||||
def test_get_allow_credentials_bad_origin(self):
|
|
||||||
app, environ, start_response = setup(
|
|
||||||
**self.middleware,
|
|
||||||
cors_allowed_origins=["https://example.com"],
|
|
||||||
cors_allow_credentials=True,
|
|
||||||
)
|
|
||||||
environ["HTTP_ORIGIN"] = "https://example.org"
|
|
||||||
app(environ, start_response)
|
|
||||||
assert ACCESS_CONTROL_ALLOW_CREDENTIALS not in start_response.get_headers()
|
|
||||||
|
|
||||||
def test_get_allow_credentials_disabled(self):
|
|
||||||
app, environ, start_response = setup(
|
|
||||||
**self.middleware,
|
|
||||||
cors_allowed_origins=["https://example.com"],
|
|
||||||
)
|
|
||||||
environ["HTTP_ORIGIN"] = "https://example.com"
|
|
||||||
app(environ, start_response)
|
|
||||||
assert ACCESS_CONTROL_ALLOW_CREDENTIALS not in start_response.get_headers()
|
|
||||||
|
|
||||||
def test_allow_private_network_added_if_enabled_and_requested(self):
|
|
||||||
app, environ, start_response = setup(
|
|
||||||
**self.middleware,
|
|
||||||
cors_allow_private_network=True,
|
|
||||||
cors_allow_all_origins=True
|
|
||||||
)
|
|
||||||
environ["HTTP_ACCESS_CONTROL_REQUEST_PRIVATE_NETWORK"] = "true"
|
|
||||||
environ["HTTP_ORIGIN"] = "http://example.com"
|
|
||||||
app(environ, start_response)
|
|
||||||
assert start_response.get_headers()[ACCESS_CONTROL_ALLOW_PRIVATE_NETWORK] == "true"
|
|
||||||
|
|
||||||
def test_allow_private_network_not_added_if_enabled_and_not_requested(self):
|
|
||||||
app, environ, start_response = setup(
|
|
||||||
**self.middleware,
|
|
||||||
cors_allow_private_network=True,
|
|
||||||
cors_allow_all_origins=True
|
|
||||||
)
|
|
||||||
environ["HTTP_ORIGIN"] = "http://example.com"
|
|
||||||
app(environ, start_response)
|
|
||||||
assert ACCESS_CONTROL_ALLOW_PRIVATE_NETWORK not in start_response.get_headers()
|
|
||||||
|
|
||||||
def test_allow_private_network_not_added_if_enabled_and_no_cors_origin(self):
|
|
||||||
app, environ, start_response = setup(
|
|
||||||
**self.middleware,
|
|
||||||
cors_allow_private_network=True,
|
|
||||||
cors_allowed_origins=["http://example.com"]
|
|
||||||
)
|
|
||||||
environ["HTTP_ACCESS_CONTROL_REQUEST_PRIVATE_NETWORK"] = "true"
|
|
||||||
environ["HTTP_ORIGIN"] = "http://example.org"
|
|
||||||
app(environ, start_response)
|
|
||||||
assert ACCESS_CONTROL_ALLOW_PRIVATE_NETWORK not in start_response.get_headers()
|
|
||||||
|
|
||||||
|
|
||||||
def test_allow_private_network_not_added_if_disabled_and_requested(self):
|
|
||||||
app, environ, start_response = setup(
|
|
||||||
**self.middleware,
|
|
||||||
cors_allow_private_network=False,
|
|
||||||
cors_allow_all_origins=True
|
|
||||||
)
|
|
||||||
environ["HTTP_ACCESS_CONTROL_REQUEST_PRIVATE_NETWORK"] = "true"
|
|
||||||
environ["HTTP_ORIGIN"] = "http://example.com"
|
|
||||||
app(environ, start_response)
|
|
||||||
assert ACCESS_CONTROL_ALLOW_PRIVATE_NETWORK not in start_response.get_headers()
|
|
||||||
|
|
||||||
def test_options_allowed_origin(self):
|
|
||||||
app, environ, start_response = setup(
|
|
||||||
**self.middleware,
|
|
||||||
cors_allow_headers=["content-type"],
|
|
||||||
cors_allow_methods=["GET", "OPTIONS"],
|
|
||||||
cors_preflight_max_age=1002,
|
|
||||||
cors_allow_all_origins=True
|
|
||||||
)
|
|
||||||
environ["HTTP_ACCESS_CONTROL_REQUEST_METHOD"] = "GET"
|
|
||||||
environ["HTTP_ORIGIN"] = "https://example.com"
|
|
||||||
environ["REQUEST_METHOD"] = "OPTIONS"
|
|
||||||
app(environ, start_response)
|
|
||||||
|
|
||||||
headers = start_response.get_headers()
|
|
||||||
|
|
||||||
assert start_response.status == '200 OK'
|
|
||||||
assert headers[ACCESS_CONTROL_ALLOW_HEADERS] == "content-type"
|
|
||||||
assert headers[ACCESS_CONTROL_ALLOW_METHODS] == "GET, OPTIONS"
|
|
||||||
assert headers[ACCESS_CONTROL_MAX_AGE] == "1002"
|
|
||||||
|
|
||||||
|
|
||||||
def test_options_no_max_age(self):
|
|
||||||
app, environ, start_response = setup(
|
|
||||||
**self.middleware,
|
|
||||||
cors_allow_headers=["content-type"],
|
|
||||||
cors_allow_methods=["GET", "OPTIONS"],
|
|
||||||
cors_preflight_max_age=0,
|
|
||||||
cors_allow_all_origins=True
|
|
||||||
)
|
|
||||||
environ["HTTP_ACCESS_CONTROL_REQUEST_METHOD"] = "GET"
|
|
||||||
environ["HTTP_ORIGIN"] = "https://example.com"
|
|
||||||
environ["REQUEST_METHOD"] = "OPTIONS"
|
|
||||||
app(environ, start_response)
|
|
||||||
|
|
||||||
|
|
||||||
headers = start_response.get_headers()
|
|
||||||
assert headers[ACCESS_CONTROL_ALLOW_HEADERS] == "content-type"
|
|
||||||
assert headers[ACCESS_CONTROL_ALLOW_METHODS] == "GET, OPTIONS"
|
|
||||||
assert ACCESS_CONTROL_MAX_AGE not in headers
|
|
||||||
|
|
||||||
def test_options_allowed_origins_with_port(self):
|
|
||||||
app, environ, start_response = setup(
|
|
||||||
**self.middleware,
|
|
||||||
cors_allowed_origins=["https://localhost:9000"]
|
|
||||||
)
|
|
||||||
environ["HTTP_ACCESS_CONTROL_REQUEST_METHOD"] = "GET"
|
|
||||||
environ["HTTP_ORIGIN"] = "https://localhost:9000"
|
|
||||||
environ["REQUEST_METHOD"] = "OPTIONS"
|
|
||||||
app(environ, start_response)
|
|
||||||
|
|
||||||
assert start_response.get_headers()[ACCESS_CONTROL_ALLOW_ORIGIN] == "https://localhost:9000"
|
|
||||||
|
|
||||||
def test_options_adds_origin_when_domain_found_in_allowed_regexes(self):
|
|
||||||
app, environ, start_response = setup(
|
|
||||||
**self.middleware,
|
|
||||||
cors_allowed_origin_regexes=[r"^https://\w+\.example\.com$"]
|
|
||||||
)
|
|
||||||
environ["HTTP_ACCESS_CONTROL_REQUEST_METHOD"] = "GET"
|
|
||||||
environ["HTTP_ORIGIN"] = "https://foo.example.com"
|
|
||||||
environ["REQUEST_METHOD"] = "OPTIONS"
|
|
||||||
app(environ, start_response)
|
|
||||||
|
|
||||||
assert start_response.get_headers()[ACCESS_CONTROL_ALLOW_ORIGIN] == "https://foo.example.com"
|
|
||||||
|
|
||||||
def test_options_adds_origin_when_domain_found_in_allowed_regexes_second(self):
|
|
||||||
app, environ, start_response = setup(
|
|
||||||
**self.middleware,
|
|
||||||
cors_allowed_origin_regexes=[
|
|
||||||
r"^https://\w+\.example\.org$",
|
|
||||||
r"^https://\w+\.example\.com$",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
environ["HTTP_ACCESS_CONTROL_REQUEST_METHOD"] = "GET"
|
|
||||||
environ["HTTP_ORIGIN"] = "https://foo.example.com"
|
|
||||||
environ["REQUEST_METHOD"] = "OPTIONS"
|
|
||||||
app(environ, start_response)
|
|
||||||
|
|
||||||
assert start_response.get_headers()[ACCESS_CONTROL_ALLOW_ORIGIN] == "https://foo.example.com"
|
|
||||||
|
|
||||||
def test_options_doesnt_add_origin_when_domain_not_found_in_allowed_regexes(self):
|
|
||||||
app, environ, start_response = setup(
|
|
||||||
**self.middleware,
|
|
||||||
cors_allowed_origin_regexes=[r"^https://\w+\.example\.org$"],
|
|
||||||
)
|
|
||||||
environ["HTTP_ACCESS_CONTROL_REQUEST_METHOD"] = "GET"
|
|
||||||
environ["HTTP_ORIGIN"] = "https://foo.example.com"
|
|
||||||
environ["REQUEST_METHOD"] = "OPTIONS"
|
|
||||||
app(environ, start_response)
|
|
||||||
|
|
||||||
assert ACCESS_CONTROL_ALLOW_ORIGIN not in start_response.get_headers()
|
|
||||||
|
|
||||||
def test_options_empty_request_method(self):
|
|
||||||
app, environ, start_response = setup(
|
|
||||||
**self.middleware,
|
|
||||||
cors_allow_all_origins=True,
|
|
||||||
)
|
|
||||||
environ["HTTP_ACCESS_CONTROL_REQUEST_METHOD"] = ""
|
|
||||||
environ["HTTP_ORIGIN"] = "https://example.com"
|
|
||||||
environ["REQUEST_METHOD"] = "OPTIONS"
|
|
||||||
app(environ, start_response)
|
|
||||||
|
|
||||||
assert start_response.status == "200 OK"
|
|
||||||
|
|
||||||
|
|
||||||
def test_options_no_headers(self):
|
|
||||||
app, environ, start_response = setup(
|
|
||||||
**self.middleware,
|
|
||||||
cors_allow_all_origins=True,
|
|
||||||
routes=[
|
|
||||||
("/", text_view)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
environ["REQUEST_METHOD"] = "OPTIONS"
|
|
||||||
app(environ, start_response)
|
|
||||||
assert start_response.status == "405 Method Not Allowed"
|
|
||||||
|
|
||||||
def test_allow_all_origins_get(self):
|
|
||||||
app, environ, start_response = setup(
|
|
||||||
**self.middleware,
|
|
||||||
cors_allow_credentials=True,
|
|
||||||
cors_allow_all_origins=True,
|
|
||||||
routes=[("/", text_view)]
|
|
||||||
)
|
|
||||||
environ["HTTP_ACCESS_CONTROL_REQUEST_METHOD"] = "GET"
|
|
||||||
environ["HTTP_ORIGIN"] = "https://example.com"
|
|
||||||
environ["REQUEST_METHOD"] = "GET"
|
|
||||||
app(environ, start_response)
|
|
||||||
|
|
||||||
assert start_response.status == "200 OK"
|
|
||||||
assert start_response.get_headers()[ACCESS_CONTROL_ALLOW_ORIGIN] == "https://example.com"
|
|
||||||
assert start_response.get_headers()["vary"] == "origin"
|
|
||||||
|
|
||||||
def test_allow_all_origins_options(self):
|
|
||||||
app, environ, start_response = setup(
|
|
||||||
**self.middleware,
|
|
||||||
cors_allow_credentials=True,
|
|
||||||
cors_allow_all_origins=True,
|
|
||||||
routes=[("/", text_view)]
|
|
||||||
)
|
|
||||||
|
|
||||||
environ["HTTP_ACCESS_CONTROL_REQUEST_METHOD"] = "GET"
|
|
||||||
environ["HTTP_ORIGIN"] = "https://example.com"
|
|
||||||
environ["REQUEST_METHOD"] = "OPTIONS"
|
|
||||||
app(environ, start_response)
|
|
||||||
|
|
||||||
assert start_response.status == "200 OK"
|
|
||||||
assert start_response.get_headers()[ACCESS_CONTROL_ALLOW_ORIGIN] == "https://example.com"
|
|
||||||
assert start_response.get_headers()["vary"] == "origin"
|
|
||||||
|
|
||||||
def test_non_200_headers_still_set(self):
|
|
||||||
"""
|
|
||||||
It's not clear whether the header should still be set for non-HTTP200
|
|
||||||
when not a preflight request. However, this is the existing behavior for
|
|
||||||
django-cors-middleware, and Spiderweb should mirror it.
|
|
||||||
"""
|
|
||||||
app, environ, start_response = setup(
|
|
||||||
**self.middleware,
|
|
||||||
cors_allow_credentials=True,
|
|
||||||
cors_allow_all_origins=True,
|
|
||||||
routes=[("/unauthorized", unauthorized_view)]
|
|
||||||
)
|
|
||||||
environ["HTTP_ORIGIN"] = "https://example.com"
|
|
||||||
environ["PATH_INFO"] = "/unauthorized"
|
|
||||||
app(environ, start_response)
|
|
||||||
|
|
||||||
|
|
||||||
assert start_response.status == "401 Unauthorized"
|
|
||||||
assert start_response.get_headers()[ACCESS_CONTROL_ALLOW_ORIGIN] == "https://example.com"
|
|
||||||
|
|
||||||
def test_auth_view_options(self):
|
|
||||||
"""
|
|
||||||
Ensure HTTP200 and header still set, for preflight requests to views requiring
|
|
||||||
authentication. See: https://github.com/adamchainz/django-cors-headers/issues/3
|
|
||||||
"""
|
|
||||||
app, environ, start_response = setup(
|
|
||||||
**self.middleware,
|
|
||||||
cors_allow_credentials=True,
|
|
||||||
cors_allow_all_origins=True,
|
|
||||||
routes=[("/unauthorized", unauthorized_view)]
|
|
||||||
)
|
|
||||||
environ["HTTP_ORIGIN"] = "https://example.com"
|
|
||||||
environ["HTTP_ACCESS_CONTROL_REQUEST_METHOD"] = "GET"
|
|
||||||
environ["PATH_INFO"] = "/unauthorized"
|
|
||||||
environ["REQUEST_METHOD"] = "OPTIONS"
|
|
||||||
app(environ, start_response)
|
|
||||||
|
|
||||||
assert start_response.status == "200 OK"
|
|
||||||
assert start_response.get_headers()[ACCESS_CONTROL_ALLOW_ORIGIN] == "https://example.com"
|
|
||||||
assert start_response.get_headers()["content-length"] == "0"
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_short_circuit(self):
|
|
||||||
"""
|
|
||||||
Test a scenario when a middleware that returns a response is run before
|
|
||||||
the `CorsMiddleware`. In this case
|
|
||||||
`CorsMiddleware.process_response()` should ignore the request.
|
|
||||||
"""
|
|
||||||
app, environ, start_response = setup(
|
|
||||||
middleware=[
|
|
||||||
"spiderweb.tests.middleware.InterruptingMiddleware",
|
|
||||||
"spiderweb.middleware.cors.CorsMiddleware",
|
|
||||||
],
|
|
||||||
cors_allow_credentials=True,
|
|
||||||
cors_allowed_origins=["https://example.com"],
|
|
||||||
)
|
|
||||||
environ["HTTP_ORIGIN"] = "https://example.com"
|
|
||||||
app(environ, start_response)
|
|
||||||
|
|
||||||
assert ACCESS_CONTROL_ALLOW_ORIGIN not in start_response.get_headers()
|
|
||||||
|
|
||||||
def test_get_short_circuit_should_be_ignored(self):
|
|
||||||
app, environ, start_response = setup(
|
|
||||||
middleware=[
|
|
||||||
"spiderweb.tests.middleware.InterruptingMiddleware",
|
|
||||||
"spiderweb.middleware.cors.CorsMiddleware",
|
|
||||||
],
|
|
||||||
cors_urls_regex=r"^/foo/$",
|
|
||||||
cors_allowed_origins=["https://example.com"],
|
|
||||||
)
|
|
||||||
environ["HTTP_ORIGIN"] = "https://example.com"
|
|
||||||
app(environ, start_response)
|
|
||||||
|
|
||||||
assert ACCESS_CONTROL_ALLOW_ORIGIN not in start_response.get_headers()
|
|
||||||
|
|
||||||
def test_get_regex_matches(self):
|
|
||||||
app, environ, start_response = setup(
|
|
||||||
**self.middleware,
|
|
||||||
cors_urls_regex=r"^/foo$",
|
|
||||||
cors_allowed_origins=["https://example.com"],
|
|
||||||
)
|
|
||||||
environ["HTTP_ORIGIN"] = "https://example.com"
|
|
||||||
environ["HTTP_ACCESS_CONTROL_REQUEST_METHOD"] = "GET"
|
|
||||||
environ["PATH_INFO"] = "/foo"
|
|
||||||
environ["REQUEST_METHOD"] = "GET"
|
|
||||||
app(environ, start_response)
|
|
||||||
|
|
||||||
assert ACCESS_CONTROL_ALLOW_ORIGIN in start_response.get_headers()
|
|
||||||
|
|
||||||
def test_get_regex_doesnt_match(self):
|
|
||||||
app, environ, start_response = setup(
|
|
||||||
**self.middleware,
|
|
||||||
cors_urls_regex=r"^/not-foo/$",
|
|
||||||
cors_allowed_origins=["https://example.com"],
|
|
||||||
)
|
|
||||||
environ["HTTP_ORIGIN"] = "https://example.com"
|
|
||||||
environ["HTTP_ACCESS_CONTROL_REQUEST_METHOD"] = "GET"
|
|
||||||
environ["PATH_INFO"] = "/foo"
|
|
||||||
environ["REQUEST_METHOD"] = "GET"
|
|
||||||
app(environ, start_response)
|
|
||||||
|
|
||||||
assert ACCESS_CONTROL_ALLOW_ORIGIN not in start_response.get_headers()
|
|
||||||
|
|
||||||
def test_works_if_view_deletes_cors_enabled(self):
|
|
||||||
"""
|
|
||||||
Just in case something crazy happens in the view or other middleware,
|
|
||||||
check that get_response doesn't fall over if `_cors_enabled` is removed
|
|
||||||
"""
|
|
||||||
def yeet(request):
|
|
||||||
del request._cors_enabled
|
|
||||||
return HttpResponse("hahaha")
|
|
||||||
|
|
||||||
app, environ, start_response = setup(
|
|
||||||
**self.middleware,
|
|
||||||
cors_allowed_origins=["https://example.com"],
|
|
||||||
routes=[('/yeet', yeet)]
|
|
||||||
)
|
|
||||||
|
|
||||||
environ["HTTP_ORIGIN"] = "https://example.com"
|
|
||||||
environ["PATH_INFO"] = "/yeet"
|
|
||||||
environ["REQUEST_METHOD"] = "GET"
|
|
||||||
app(environ, start_response)
|
|
||||||
|
|
||||||
assert ACCESS_CONTROL_ALLOW_ORIGIN in start_response.get_headers()
|
|
||||||
|
@ -75,13 +75,15 @@ def test_redirect_response():
|
|||||||
|
|
||||||
|
|
||||||
def test_add_route_at_server_start():
|
def test_add_route_at_server_start():
|
||||||
|
app, environ, start_response = setup()
|
||||||
|
|
||||||
def index(request):
|
def index(request):
|
||||||
return RedirectResponse(location="/redirected")
|
return RedirectResponse(location="/redirected")
|
||||||
|
|
||||||
def view2(request):
|
def view2(request):
|
||||||
return HttpResponse("View 2")
|
return HttpResponse("View 2")
|
||||||
|
|
||||||
app, environ, start_response = setup(
|
app = SpiderwebRouter(
|
||||||
routes=[
|
routes=[
|
||||||
("/", index, {"allowed_methods": ["GET", "POST"], "csrf_exempt": True}),
|
("/", index, {"allowed_methods": ["GET", "POST"], "csrf_exempt": True}),
|
||||||
("/view2", view2),
|
("/view2", view2),
|
||||||
@ -93,7 +95,8 @@ def test_add_route_at_server_start():
|
|||||||
|
|
||||||
|
|
||||||
def test_redirect_on_append_slash():
|
def test_redirect_on_append_slash():
|
||||||
app, environ, start_response = setup(append_slash=True)
|
_, environ, start_response = setup()
|
||||||
|
app = SpiderwebRouter(append_slash=True)
|
||||||
|
|
||||||
@app.route("/hello")
|
@app.route("/hello")
|
||||||
def index(request):
|
def index(request):
|
||||||
@ -106,7 +109,9 @@ def test_redirect_on_append_slash():
|
|||||||
|
|
||||||
@given(st.text())
|
@given(st.text())
|
||||||
def test_template_response_with_template(text):
|
def test_template_response_with_template(text):
|
||||||
app, environ, start_response = setup(templates_dirs=["spiderweb/tests"])
|
_, environ, start_response = setup()
|
||||||
|
|
||||||
|
app = SpiderwebRouter(templates_dirs=["spiderweb/tests"])
|
||||||
|
|
||||||
@app.route("/")
|
@app.route("/")
|
||||||
def index(request):
|
def index(request):
|
||||||
@ -169,10 +174,11 @@ def test_duplicate_error_view():
|
|||||||
|
|
||||||
|
|
||||||
def test_missing_view_with_custom_404_alt():
|
def test_missing_view_with_custom_404_alt():
|
||||||
|
_, environ, start_response = setup()
|
||||||
|
|
||||||
def custom_404(request):
|
def custom_404(request):
|
||||||
return HttpResponse("Custom 404 2")
|
return HttpResponse("Custom 404 2")
|
||||||
|
|
||||||
app, environ, start_response = setup(error_routes={404: custom_404})
|
app = SpiderwebRouter(error_routes={404: custom_404})
|
||||||
|
|
||||||
assert app(environ, start_response) == [b"Custom 404 2"]
|
assert app(environ, start_response) == [b"Custom 404 2"]
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
from spiderweb import HttpResponse
|
|
||||||
from spiderweb.decorators import csrf_exempt
|
from spiderweb.decorators import csrf_exempt
|
||||||
from spiderweb.response import JsonResponse, TemplateResponse
|
from spiderweb.response import JsonResponse, TemplateResponse
|
||||||
|
|
||||||
@ -39,11 +38,3 @@ def form_view_with_csrf(request):
|
|||||||
return JsonResponse(data=request.POST)
|
return JsonResponse(data=request.POST)
|
||||||
else:
|
else:
|
||||||
return TemplateResponse(request, template_string=EXAMPLE_HTML_FORM_WITH_CSRF)
|
return TemplateResponse(request, template_string=EXAMPLE_HTML_FORM_WITH_CSRF)
|
||||||
|
|
||||||
|
|
||||||
def text_view(request):
|
|
||||||
return HttpResponse("Hi!")
|
|
||||||
|
|
||||||
|
|
||||||
def unauthorized_view(request):
|
|
||||||
return HttpResponse("Unauthorized", status_code=401)
|
|
@ -67,35 +67,15 @@ def is_jsonable(data: str) -> bool:
|
|||||||
|
|
||||||
|
|
||||||
class Headers(dict):
|
class Headers(dict):
|
||||||
# special dict that forces lowercase and snake_case for all keys
|
# special dict that forces lowercase for all keys
|
||||||
def __getitem__(self, key):
|
def __getitem__(self, key):
|
||||||
key = key.replace("-", "_")
|
return super().__getitem__(key.lower())
|
||||||
try:
|
|
||||||
regular = super().__getitem__(key.lower())
|
|
||||||
except KeyError:
|
|
||||||
regular = None
|
|
||||||
try:
|
|
||||||
http_version = super().__getitem__(f"http_{key.lower()}")
|
|
||||||
except KeyError:
|
|
||||||
http_version = None
|
|
||||||
return regular or http_version
|
|
||||||
|
|
||||||
def __contains__(self, item):
|
|
||||||
item = item.lower().replace("-", "_")
|
|
||||||
|
|
||||||
regular = super().__contains__(item)
|
|
||||||
http = super().__contains__(f"http_{item}")
|
|
||||||
|
|
||||||
return regular or http
|
|
||||||
|
|
||||||
def __setitem__(self, key, value):
|
def __setitem__(self, key, value):
|
||||||
return super().__setitem__(key.lower().replace("-", "_"), value)
|
return super().__setitem__(key.lower(), value)
|
||||||
|
|
||||||
def get(self, key, default=None):
|
def get(self, key, default=None):
|
||||||
key = key.replace("-", "_")
|
return super().get(key.lower(), default)
|
||||||
regular = super().get(key.lower(), default)
|
|
||||||
http_version = super().get(f"http_{key.lower()}", default)
|
|
||||||
return regular or http_version
|
|
||||||
|
|
||||||
def setdefault(self, key, default=None):
|
def setdefault(self, key, default=None):
|
||||||
return super().setdefault(key.lower(), default)
|
return super().setdefault(key.lower(), default)
|
||||||
|
Loading…
Reference in New Issue
Block a user