spiderweb/spiderweb/tests/test_middleware.py
2024-10-30 00:32:02 -04:00

907 lines
32 KiB
Python

from io import BytesIO, BufferedReader
from datetime import timedelta
import pytest
from peewee import SqliteDatabase
from spiderweb import SpiderwebRouter, HttpResponse, StartupErrors, ConfigError
from spiderweb.constants import DEFAULT_ENCODING
from spiderweb.middleware.cors import (
ACCESS_CONTROL_ALLOW_ORIGIN,
ACCESS_CONTROL_ALLOW_HEADERS,
ACCESS_CONTROL_ALLOW_METHODS,
ACCESS_CONTROL_EXPOSE_HEADERS,
ACCESS_CONTROL_ALLOW_CREDENTIALS,
ACCESS_CONTROL_MAX_AGE,
ACCESS_CONTROL_ALLOW_PRIVATE_NETWORK,
)
from spiderweb.middleware.sessions import Session
from spiderweb.middleware import csrf
from spiderweb.tests.helpers import setup
from spiderweb.tests.views_for_tests import (
form_view_with_csrf,
form_csrf_exempt,
form_view_without_csrf,
text_view,
unauthorized_view,
file_view,
)
from spiderweb.middleware.gzip import (
CheckValidGzipMinimumLength,
CheckValidGzipCompressionLevel,
)
def index(request):
    """Count requests in the session and return the current count as the body.

    The first request for a session stores ``0`` under the ``"value"`` key;
    every later request adds one to the stored number.
    """
    session = request.SESSION
    if "value" not in session:
        session["value"] = 0
    else:
        session["value"] += 1
    return HttpResponse(body=str(session["value"]))
def test_session_middleware():
    """The session counter advances once the session cookie is sent back."""
    app, environ, start_response = setup(
        middleware=["spiderweb.middleware.sessions.SessionMiddleware"],
    )
    app.add_route("/", index)
    environ.update({"HTTP_USER_AGENT": "hi", "REMOTE_ADDR": "1.1.1.1"})

    # First request carries no cookie, so the view starts counting at 0.
    assert app(environ, start_response) == ["0".encode(DEFAULT_ENCODING)]

    # Replay the stored session key as the cookie for the following requests.
    session_key = Session.select().first().session_key
    environ["HTTP_COOKIE"] = f"swsession={session_key}"
    for expected_count in (1, 2):
        body = app(environ, start_response)
        assert body == [str(expected_count).encode(DEFAULT_ENCODING)]
def test_expired_session():
    """A session backdated by ``session_max_age`` is replaced with a new one."""
    app, environ, start_response = setup(
        middleware=["spiderweb.middleware.sessions.SessionMiddleware"],
    )
    app.add_route("/", index)
    environ.update({"HTTP_USER_AGENT": "hi", "REMOTE_ADDR": "1.1.1.1"})
    zero_body = ["0".encode(DEFAULT_ENCODING)]

    assert app(environ, start_response) == zero_body

    # Move the creation time back by the full maximum age, then present the
    # old session key as the cookie.
    old_session = Session.select().first()
    old_session.created_at -= timedelta(seconds=app.session_max_age)
    old_session.save()
    environ["HTTP_COOKIE"] = f"swsession={old_session.session_key}"

    # it shouldn't increment because we get a new session
    assert app(environ, start_response) == zero_body

    newest_session = list(Session.select())[-1]
    assert newest_session.session_key != old_session.session_key
def test_exploding_middleware():
    """The two ``Exploding*`` test middlewares are removed, and the view still runs."""
    exploding = [
        "spiderweb.tests.middleware.ExplodingRequestMiddleware",
        "spiderweb.tests.middleware.ExplodingResponseMiddleware",
    ]
    app, environ, start_response = setup(middleware=exploding)
    app.add_route("/", index)

    body = app(environ, start_response)
    assert body == ["0".encode(DEFAULT_ENCODING)]

    # make sure it kicked out the middleware and isn't just ignoring it
    assert len(app.middleware) == 0
def test_invalid_middleware():
    """An unknown middleware path raises ``ConfigError`` naming that path."""
    expected_message = "Middleware 'nonexistent.middleware' not found."
    with pytest.raises(ConfigError) as excinfo:
        SpiderwebRouter(middleware=["nonexistent.middleware"])
    # Checked after the ``with`` block: statements following the raising call
    # inside it would never execute.
    assert excinfo.value.args[0] == expected_message
def test_csrf_middleware_without_session_middleware():
    """Configuring CSRF middleware alone raises ``StartupErrors``."""
    with pytest.raises(StartupErrors) as excinfo:
        SpiderwebRouter(
            middleware=["spiderweb.middleware.csrf.CSRFMiddleware"],
            db=SqliteDatabase("spiderweb-tests.db"),
        )

    # The second exception argument holds the individual errors; the first
    # of them must be the "session middleware not found" message.
    errors = excinfo.value.args[1]
    first_message = errors[0].args[0]
    expected = csrf.CheckForSessionMiddleware.SESSION_MIDDLEWARE_NOT_FOUND
    assert first_message == expected
def test_csrf_middleware_above_session_middleware():
    """Listing CSRF middleware before session middleware raises ``StartupErrors``."""
    wrong_order = [
        "spiderweb.middleware.csrf.CSRFMiddleware",
        "spiderweb.middleware.sessions.SessionMiddleware",
    ]
    with pytest.raises(StartupErrors) as excinfo:
        setup(middleware=wrong_order)

    # The second exception argument holds the individual errors; the first
    # of them must be the placement message.
    errors = excinfo.value.args[1]
    first_message = errors[0].args[0]
    expected = csrf.VerifyCorrectMiddlewarePlacement.SESSION_MIDDLEWARE_BELOW_CSRF
    assert first_message == expected
def test_csrf_middleware():
    """Exercise the CSRF middleware with a valid, a bogus and a re-bound token.

    Steps:
      1. GET the form and pull the hidden ``csrf_token`` value out of the HTML.
      2. POST it back with the matching session cookie; the view echoes "bob".
      3. POST the literal ``badtoken``; the response reports an invalid token.
      4. POST a token whose part after ``"::"`` was replaced by
         ``"badsession"``; the response reports an invalid token as well.
    """
    app, environ, start_response = setup(
        middleware=[
            "spiderweb.middleware.sessions.SessionMiddleware",
            "spiderweb.middleware.csrf.CSRFMiddleware",
        ],
    )
    app.add_route("/", form_view_with_csrf, ["GET", "POST"])
    environ["HTTP_USER_AGENT"] = "hi"
    environ["REMOTE_ADDR"] = "1.1.1.1"

    def post_form(formdata):
        """Send ``formdata`` as the request body and return the decoded response."""
        payload = formdata.encode(DEFAULT_ENCODING)
        # Refresh the declared length together with the body. Previously the
        # length of the first body was reused for every later request, so a
        # longer replacement body would have been read only partially.
        # NOTE(review): PEP 3333 describes CGI variables as strings; the int
        # is kept to match what the sibling CSRF tests pass — confirm.
        environ["CONTENT_LENGTH"] = len(payload)
        environ["wsgi.input"] = BufferedReader(BytesIO(payload))
        return app(environ, start_response)[0].decode(DEFAULT_ENCODING)

    # Step 1: fetch the form and extract the token from the hidden input.
    resp = app(environ, start_response)[0].decode(DEFAULT_ENCODING)
    assert "<form" in resp
    assert '<input type="hidden" name="csrf_token"' in resp
    token = resp.split('value="')[1].split('"')[0]

    # Step 2: post the token back using the most recently stored session.
    latest_session = list(Session.select().dicts())[-1]
    environ["CONTENT_TYPE"] = "application/x-www-form-urlencoded"
    environ["HTTP_COOKIE"] = f"swsession={latest_session['session_key']}"
    environ["REQUEST_METHOD"] = "POST"
    environ["HTTP_X_CSRF_TOKEN"] = token
    resp2 = post_form(f"name=bob&csrf_token={token}")
    assert "bob" in resp2

    # Step 3: test that it raises a CSRF error on wrong token. The header is
    # blanked so it no longer carries the valid token.
    environ["HTTP_X_CSRF_TOKEN"] = None
    resp3 = post_form("name=bob&csrf_token=badtoken")
    assert "CSRF token is invalid" in resp3

    # Step 4: test that the wrong session also raises a CSRF error. Keep the
    # first "::"-separated part of the decrypted token and re-encrypt it
    # with a different second part.
    first_part = app.decrypt(token).split("::")[0]
    forged = app.encrypt(f"{first_part}::badsession").decode(DEFAULT_ENCODING)
    resp4 = post_form(f"name=bob&csrf_token={forged}")
    assert "CSRF token is invalid" in resp4
def test_csrf_expired_token():
    """A freshly issued token is rejected when ``CSRF_EXPIRY`` is set to ``-1``."""
    app, environ, start_response = setup(
        middleware=[
            "spiderweb.middleware.sessions.SessionMiddleware",
            "spiderweb.middleware.csrf.CSRFMiddleware",
        ],
    )
    # Index 1 is the second entry of the middleware list passed to ``setup``.
    second_middleware = app.middleware[1]
    second_middleware.CSRF_EXPIRY = -1
    app.add_route("/", form_view_with_csrf, ["GET", "POST"])
    environ.update({"HTTP_USER_AGENT": "hi", "REMOTE_ADDR": "1.1.1.1"})

    # Fetch the form and read the token out of the hidden input.
    form_page = app(environ, start_response)[0].decode(DEFAULT_ENCODING)
    token = form_page.split('value="')[1].split('"')[0]
    formdata = f"name=bob&csrf_token={token}"

    session_rows = list(Session.select().dicts())
    environ.update(
        {
            "CONTENT_TYPE": "application/x-www-form-urlencoded",
            "HTTP_COOKIE": f"swsession={session_rows[-1]['session_key']}",
            "REQUEST_METHOD": "POST",
            "HTTP_ORIGIN": "example.com",
            "HTTP_X_CSRF_TOKEN": token,
            "CONTENT_LENGTH": len(formdata),
            "wsgi.input": BufferedReader(BytesIO(formdata.encode(DEFAULT_ENCODING))),
        }
    )

    post_result = app(environ, start_response)[0].decode(DEFAULT_ENCODING)
    assert "CSRF token is invalid" in post_result
def test_csrf_exempt():
    """A token-less POST succeeds on ``form_csrf_exempt`` but not on the plain view."""
    app, environ, start_response = setup(
        middleware=[
            "spiderweb.middleware.sessions.SessionMiddleware",
            "spiderweb.middleware.csrf.CSRFMiddleware",
        ],
    )
    for path, view in (("/", form_csrf_exempt), ("/2", form_view_without_csrf)):
        app.add_route(path, view, ["GET", "POST"])

    formdata = "name=bob"
    environ.update(
        {
            "HTTP_USER_AGENT": "hi",
            "REMOTE_ADDR": "1.1.1.1",
            "CONTENT_TYPE": "application/x-www-form-urlencoded",
            "REQUEST_METHOD": "POST",
            "CONTENT_LENGTH": len(formdata),
            "wsgi.input": BufferedReader(BytesIO(formdata.encode(DEFAULT_ENCODING))),
        }
    )

    # The route registered with form_csrf_exempt echoes the submitted name.
    exempt_body = app(environ, start_response)[0].decode(DEFAULT_ENCODING)
    assert "bob" in exempt_body

    # The same request against "/2" is refused; the input stream is reused
    # as-is for this second call.
    environ["PATH_INFO"] = "/2"
    protected_body = app(environ, start_response)[0].decode(DEFAULT_ENCODING)
    assert "CSRF token is invalid" in protected_body
def test_csrf_trusted_origins():
    """A token-less POST passes only when Origin is in ``csrf_trusted_origins``."""
    app, environ, start_response = setup(
        middleware=[
            "spiderweb.middleware.sessions.SessionMiddleware",
            "spiderweb.middleware.csrf.CSRFMiddleware",
        ],
        csrf_trusted_origins=[
            "example.com",
        ],
    )
    app.add_route("/", form_view_without_csrf, ["GET", "POST"])

    formdata = "name=bob"
    encoded_form = formdata.encode(DEFAULT_ENCODING)
    environ.update(
        {
            "HTTP_USER_AGENT": "hi",
            "REMOTE_ADDR": "1.1.1.1",
            "CONTENT_TYPE": "application/x-www-form-urlencoded",
            "REQUEST_METHOD": "POST",
            "CONTENT_LENGTH": len(formdata),
        }
    )

    # An origin that is not in the trusted list is rejected.
    environ["wsgi.input"] = BufferedReader(BytesIO(encoded_form))
    environ["HTTP_ORIGIN"] = "notvalid.com"
    rejected = app(environ, start_response)[0].decode(DEFAULT_ENCODING)
    assert "CSRF token is invalid" in rejected

    # The trusted origin gets the form data echoed back; a fresh input
    # stream is supplied for this request.
    environ["wsgi.input"] = BufferedReader(BytesIO(encoded_form))
    environ["HTTP_ORIGIN"] = "example.com"
    accepted = app(environ, start_response)[0].decode(DEFAULT_ENCODING)
    assert accepted == '{"name": "bob"}'
def test_post_process_middleware():
    """With ``PostProcessingMiddleware`` installed the body is ``"Hi! Moo!"``."""
    app, environ, start_response = setup(
        middleware=["spiderweb.tests.middleware.PostProcessingMiddleware"],
    )
    app.add_route("/", text_view)
    # NOTE(review): "/" is an unusual REMOTE_ADDR value; PATH_INFO may have
    # been intended — confirm.
    environ.update(
        {"HTTP_USER_AGENT": "hi", "REMOTE_ADDR": "/", "REQUEST_METHOD": "GET"}
    )

    body = app(environ, start_response)
    assert body == ["Hi! Moo!".encode(DEFAULT_ENCODING)]
def test_post_process_header_manip():
    """``PostProcessingWithHeaderManipulation`` adds ``x-moo: true``, body unchanged."""
    app, environ, start_response = setup(
        middleware=[
            "spiderweb.tests.middleware.PostProcessingWithHeaderManipulation",
        ],
    )
    app.add_route("/", text_view)
    # NOTE(review): "/" is an unusual REMOTE_ADDR value; PATH_INFO may have
    # been intended — confirm.
    environ.update(
        {"HTTP_USER_AGENT": "hi", "REMOTE_ADDR": "/", "REQUEST_METHOD": "GET"}
    )

    body = app(environ, start_response)
    assert body == ["Hi!".encode(DEFAULT_ENCODING)]

    sent_headers = start_response.get_headers()
    assert sent_headers["x-moo"] == "true"
def test_unused_post_process_middleware():
    """``ExplodingPostProcessingMiddleware`` is removed and the body is untouched."""
    app, environ, start_response = setup(
        middleware=[
            "spiderweb.tests.middleware.ExplodingPostProcessingMiddleware",
        ],
    )
    app.add_route("/", text_view)
    # NOTE(review): "/" is an unusual REMOTE_ADDR value; PATH_INFO may have
    # been intended — confirm.
    environ.update(
        {"HTTP_USER_AGENT": "hi", "REMOTE_ADDR": "/", "REQUEST_METHOD": "GET"}
    )

    body = app(environ, start_response)
    assert body == ["Hi!".encode(DEFAULT_ENCODING)]

    # make sure it kicked out the middleware and isn't just ignoring it
    assert len(app.middleware) == 0
class TestGzipMiddleware:
    """Tests for ``GzipMiddleware`` and its two configuration checks."""

    # Keyword arguments unpacked into ``setup`` to enable only the gzip middleware.
    middleware = {"middleware": ["spiderweb.middleware.gzip.GzipMiddleware"]}

    def test_not_enabled_on_small_response(self):
        """A body shorter than ``gzip_minimum_response_length`` is sent as-is."""
        app, environ, start_response = setup(
            **self.middleware,
            gzip_minimum_response_length=500,
        )
        app.add_route("/", text_view)
        # Advertise gzip support so the body length is the only remaining
        # reason to skip compression; without this header the test passed
        # no matter what the minimum length was.
        environ["HTTP_ACCEPT_ENCODING"] = "gzip"
        environ["HTTP_USER_AGENT"] = "hi"
        environ["REMOTE_ADDR"] = "/"
        environ["REQUEST_METHOD"] = "GET"
        assert app(environ, start_response) == [bytes("Hi!", DEFAULT_ENCODING)]
        # Lower-case key, like every other header assertion in this file;
        # the capitalised spelling could never match and made the check vacuous.
        assert "content-encoding" not in start_response.get_headers()

    def test_changing_minimum_response_length(self):
        """With a minimum length of 1 the short text body is gzip-compressed."""
        app, environ, start_response = setup(
            **self.middleware,
            gzip_minimum_response_length=1,
        )
        app.add_route("/", text_view)
        environ["HTTP_ACCEPT_ENCODING"] = "gzip"
        environ["HTTP_USER_AGENT"] = "hi"
        environ["REMOTE_ADDR"] = "/"
        environ["REQUEST_METHOD"] = "GET"
        body = app(environ, start_response)[0]
        # gzip magic number (1f 8b) followed by the deflate method byte (08).
        # Compared as bytes: the repr of a bytes object switches to double
        # quotes when the payload contains "'" but no '"', so matching the
        # repr text could fail on perfectly valid gzip output.
        assert body.startswith(b"\x1f\x8b\x08")
        assert "content-encoding" in start_response.get_headers()

    def test_not_enabled_on_error_response(self):
        """The response of ``unauthorized_view`` is not compressed."""
        app, environ, start_response = setup(
            **self.middleware,
            gzip_minimum_response_length=1,
        )
        app.add_route("/", unauthorized_view)
        environ["HTTP_ACCEPT_ENCODING"] = "gzip"
        environ["HTTP_USER_AGENT"] = "hi"
        environ["REMOTE_ADDR"] = "/"
        environ["REQUEST_METHOD"] = "GET"
        assert app(environ, start_response) == [bytes("Unauthorized", DEFAULT_ENCODING)]
        assert "content-encoding" not in start_response.get_headers()

    def test_not_enabled_on_bytes_response(self):
        """The response of ``file_view`` is not compressed."""
        app, environ, start_response = setup(
            **self.middleware,
            gzip_minimum_response_length=1,
        )
        # send a file that's already in bytes form
        app.add_route("/", file_view)
        environ["HTTP_ACCEPT_ENCODING"] = "gzip"
        environ["HTTP_USER_AGENT"] = "hi"
        environ["REMOTE_ADDR"] = "/"
        environ["REQUEST_METHOD"] = "GET"
        assert app(environ, start_response) == [bytes("hi", DEFAULT_ENCODING)]
        assert "content-encoding" not in start_response.get_headers()

    def test_invalid_response_length(self):
        """A non-numeric minimum length makes the length check raise ``ConfigError``."""

        class FakeServer:
            gzip_minimum_response_length = "asdf"

        with pytest.raises(ConfigError) as e:
            CheckValidGzipMinimumLength(server=FakeServer).check()
        assert (
            e.value.args[0] == CheckValidGzipMinimumLength.INVALID_GZIP_MINIMUM_LENGTH
        )

    def test_negative_response_length(self):
        """A negative minimum length makes the length check raise ``ConfigError``."""

        class FakeServer:
            gzip_minimum_response_length = -1

        with pytest.raises(ConfigError) as e:
            CheckValidGzipMinimumLength(server=FakeServer).check()
        assert (
            e.value.args[0] == CheckValidGzipMinimumLength.INVALID_GZIP_MINIMUM_LENGTH
        )

    def test_bad_compression_level(self):
        """A non-numeric compression level makes the level check raise ``ConfigError``."""

        class FakeServer:
            gzip_compression_level = "asdf"

        with pytest.raises(ConfigError) as e:
            CheckValidGzipCompressionLevel(server=FakeServer).check()
        assert (
            e.value.args[0]
            == CheckValidGzipCompressionLevel.INVALID_GZIP_COMPRESSION_LEVEL
        )
class TestCorsMiddleware:
    """Tests for the response headers produced by ``CorsMiddleware``."""

    # adapted from:
    # https://github.com/adamchainz/django-cors-headers/blob/main/tests/test_middleware.py
    # to make sure I didn't miss anything

    # Keyword arguments unpacked into ``setup`` to enable only the CORS middleware.
    middleware = {"middleware": ["spiderweb.middleware.cors.CorsMiddleware"]}

    def test_get_no_origin(self):
        """No Origin header in the request: no allow-origin header in the response."""
        app, environ, start_response = setup(
            **self.middleware, cors_allow_all_origins=True
        )
        app(environ, start_response)
        sent = start_response.get_headers()
        assert ACCESS_CONTROL_ALLOW_ORIGIN not in sent

    def test_get_origin_vary_by_default(self):
        """The ``vary`` response header is ``origin`` even without an Origin header."""
        app, environ, start_response = setup(
            **self.middleware, cors_allow_all_origins=True
        )
        app(environ, start_response)
        sent = start_response.get_headers()
        assert sent["vary"] == "origin"

    def test_get_invalid_origin(self):
        """A malformed Origin value gets no allow-origin header."""
        app, environ, start_response = setup(
            **self.middleware, cors_allow_all_origins=True
        )
        environ["HTTP_ORIGIN"] = "https://example.com]"
        app(environ, start_response)
        sent = start_response.get_headers()
        assert ACCESS_CONTROL_ALLOW_ORIGIN not in sent

    def test_get_not_in_allowed_origins(self):
        """An origin missing from ``cors_allowed_origins`` is not echoed back."""
        app, environ, start_response = setup(
            **self.middleware, cors_allowed_origins=["https://example.com"]
        )
        environ["HTTP_ORIGIN"] = "https://example.org"
        app(environ, start_response)
        sent = start_response.get_headers()
        assert ACCESS_CONTROL_ALLOW_ORIGIN not in sent

    def test_get_not_in_allowed_origins_due_to_wrong_scheme(self):
        """Same host but a different scheme does not count as an allowed origin."""
        app, environ, start_response = setup(
            **self.middleware, cors_allowed_origins=["http://example.org"]
        )
        environ["HTTP_ORIGIN"] = "https://example.org"
        app(environ, start_response)
        sent = start_response.get_headers()
        assert ACCESS_CONTROL_ALLOW_ORIGIN not in sent

    def test_get_in_allowed_origins(self):
        """An origin listed in ``cors_allowed_origins`` is echoed back."""
        origin = "https://example.org"
        app, environ, start_response = setup(
            **self.middleware,
            cors_allowed_origins=["https://example.com", "https://example.org"],
        )
        environ["HTTP_ORIGIN"] = origin
        app(environ, start_response)
        sent = start_response.get_headers()
        assert sent[ACCESS_CONTROL_ALLOW_ORIGIN] == origin

    def test_null_in_allowed_origins(self):
        """The literal origin ``null`` is honoured when it is in the allowed list."""
        app, environ, start_response = setup(
            **self.middleware,
            cors_allowed_origins=["https://example.com", "null"],
        )
        environ["HTTP_ORIGIN"] = "null"
        app(environ, start_response)
        sent = start_response.get_headers()
        assert sent[ACCESS_CONTROL_ALLOW_ORIGIN] == "null"

    def test_file_in_allowed_origins(self):
        """
        'file://' should be allowed as an origin since Chrome on Android
        mistakenly sends it
        """
        app, environ, start_response = setup(
            **self.middleware,
            cors_allowed_origins=["https://example.com", "file://"],
        )
        environ["HTTP_ORIGIN"] = "file://"
        app(environ, start_response)
        sent = start_response.get_headers()
        assert sent[ACCESS_CONTROL_ALLOW_ORIGIN] == "file://"

    def test_get_expose_headers(self):
        """``cors_expose_headers`` entries are joined with ``", "`` in the response."""
        app, environ, start_response = setup(
            **self.middleware,
            cors_allow_all_origins=True,
            cors_expose_headers=["accept", "content-type"],
        )
        environ["HTTP_ORIGIN"] = "https://example.com"
        app(environ, start_response)
        sent = start_response.get_headers()
        assert sent[ACCESS_CONTROL_EXPOSE_HEADERS] == "accept, content-type"

    def test_get_dont_expose_headers(self):
        """Without ``cors_expose_headers`` no expose-headers header is sent."""
        app, environ, start_response = setup(
            **self.middleware,
            cors_allow_all_origins=True,
        )
        environ["HTTP_ORIGIN"] = "https://example.com"
        app(environ, start_response)
        sent = start_response.get_headers()
        assert ACCESS_CONTROL_EXPOSE_HEADERS not in sent

    def test_get_allow_credentials(self):
        """``cors_allow_credentials`` yields ``"true"`` for an allowed origin."""
        app, environ, start_response = setup(
            **self.middleware,
            cors_allowed_origins=["https://example.com"],
            cors_allow_credentials=True,
        )
        environ["HTTP_ORIGIN"] = "https://example.com"
        app(environ, start_response)
        sent = start_response.get_headers()
        assert sent[ACCESS_CONTROL_ALLOW_CREDENTIALS] == "true"

    def test_get_allow_credentials_bad_origin(self):
        """The credentials header is withheld from an origin that is not allowed."""
        app, environ, start_response = setup(
            **self.middleware,
            cors_allowed_origins=["https://example.com"],
            cors_allow_credentials=True,
        )
        environ["HTTP_ORIGIN"] = "https://example.org"
        app(environ, start_response)
        sent = start_response.get_headers()
        assert ACCESS_CONTROL_ALLOW_CREDENTIALS not in sent

    def test_get_allow_credentials_disabled(self):
        """Without ``cors_allow_credentials`` the credentials header is absent."""
        app, environ, start_response = setup(
            **self.middleware,
            cors_allowed_origins=["https://example.com"],
        )
        environ["HTTP_ORIGIN"] = "https://example.com"
        app(environ, start_response)
        sent = start_response.get_headers()
        assert ACCESS_CONTROL_ALLOW_CREDENTIALS not in sent

    def test_allow_private_network_added_if_enabled_and_requested(self):
        """Private-network access is granted when enabled and requested."""
        app, environ, start_response = setup(
            **self.middleware,
            cors_allow_private_network=True,
            cors_allow_all_origins=True,
        )
        environ.update(
            {
                "HTTP_ACCESS_CONTROL_REQUEST_PRIVATE_NETWORK": "true",
                "HTTP_ORIGIN": "http://example.com",
            }
        )
        app(environ, start_response)
        sent = start_response.get_headers()
        assert sent[ACCESS_CONTROL_ALLOW_PRIVATE_NETWORK] == "true"

    def test_allow_private_network_not_added_if_enabled_and_not_requested(self):
        """Private-network access is not mentioned when the request did not ask."""
        app, environ, start_response = setup(
            **self.middleware,
            cors_allow_private_network=True,
            cors_allow_all_origins=True,
        )
        environ["HTTP_ORIGIN"] = "http://example.com"
        app(environ, start_response)
        sent = start_response.get_headers()
        assert ACCESS_CONTROL_ALLOW_PRIVATE_NETWORK not in sent

    def test_allow_private_network_not_added_if_enabled_and_no_cors_origin(self):
        """Private-network access is withheld from an origin that is not allowed."""
        app, environ, start_response = setup(
            **self.middleware,
            cors_allow_private_network=True,
            cors_allowed_origins=["http://example.com"],
        )
        environ.update(
            {
                "HTTP_ACCESS_CONTROL_REQUEST_PRIVATE_NETWORK": "true",
                "HTTP_ORIGIN": "http://example.org",
            }
        )
        app(environ, start_response)
        sent = start_response.get_headers()
        assert ACCESS_CONTROL_ALLOW_PRIVATE_NETWORK not in sent

    def test_allow_private_network_not_added_if_disabled_and_requested(self):
        """Private-network access is withheld when the setting is off."""
        app, environ, start_response = setup(
            **self.middleware,
            cors_allow_private_network=False,
            cors_allow_all_origins=True,
        )
        environ.update(
            {
                "HTTP_ACCESS_CONTROL_REQUEST_PRIVATE_NETWORK": "true",
                "HTTP_ORIGIN": "http://example.com",
            }
        )
        app(environ, start_response)
        sent = start_response.get_headers()
        assert ACCESS_CONTROL_ALLOW_PRIVATE_NETWORK not in sent

    def test_options_allowed_origin(self):
        """A preflight returns 200 with the configured headers, methods and max age."""
        app, environ, start_response = setup(
            **self.middleware,
            cors_allow_headers=["content-type"],
            cors_allow_methods=["GET", "OPTIONS"],
            cors_preflight_max_age=1002,
            cors_allow_all_origins=True,
        )
        environ.update(
            {
                "HTTP_ACCESS_CONTROL_REQUEST_METHOD": "GET",
                "HTTP_ORIGIN": "https://example.com",
                "REQUEST_METHOD": "OPTIONS",
            }
        )
        app(environ, start_response)
        sent = start_response.get_headers()
        assert start_response.status == "200 OK"
        assert sent[ACCESS_CONTROL_ALLOW_HEADERS] == "content-type"
        assert sent[ACCESS_CONTROL_ALLOW_METHODS] == "GET, OPTIONS"
        assert sent[ACCESS_CONTROL_MAX_AGE] == "1002"

    def test_options_no_max_age(self):
        """A preflight max age of 0 leaves the max-age header out."""
        app, environ, start_response = setup(
            **self.middleware,
            cors_allow_headers=["content-type"],
            cors_allow_methods=["GET", "OPTIONS"],
            cors_preflight_max_age=0,
            cors_allow_all_origins=True,
        )
        environ.update(
            {
                "HTTP_ACCESS_CONTROL_REQUEST_METHOD": "GET",
                "HTTP_ORIGIN": "https://example.com",
                "REQUEST_METHOD": "OPTIONS",
            }
        )
        app(environ, start_response)
        sent = start_response.get_headers()
        assert sent[ACCESS_CONTROL_ALLOW_HEADERS] == "content-type"
        assert sent[ACCESS_CONTROL_ALLOW_METHODS] == "GET, OPTIONS"
        assert ACCESS_CONTROL_MAX_AGE not in sent

    def test_options_allowed_origins_with_port(self):
        """An allowed origin that includes a port is echoed back on preflight."""
        origin = "https://localhost:9000"
        app, environ, start_response = setup(
            **self.middleware, cors_allowed_origins=["https://localhost:9000"]
        )
        environ.update(
            {
                "HTTP_ACCESS_CONTROL_REQUEST_METHOD": "GET",
                "HTTP_ORIGIN": origin,
                "REQUEST_METHOD": "OPTIONS",
            }
        )
        app(environ, start_response)
        sent = start_response.get_headers()
        assert sent[ACCESS_CONTROL_ALLOW_ORIGIN] == origin

    def test_options_adds_origin_when_domain_found_in_allowed_regexes(self):
        """An origin matching the single allowed regex is echoed back."""
        origin = "https://foo.example.com"
        app, environ, start_response = setup(
            **self.middleware,
            cors_allowed_origin_regexes=[r"^https://\w+\.example\.com$"],
        )
        environ.update(
            {
                "HTTP_ACCESS_CONTROL_REQUEST_METHOD": "GET",
                "HTTP_ORIGIN": origin,
                "REQUEST_METHOD": "OPTIONS",
            }
        )
        app(environ, start_response)
        sent = start_response.get_headers()
        assert sent[ACCESS_CONTROL_ALLOW_ORIGIN] == origin

    def test_options_adds_origin_when_domain_found_in_allowed_regexes_second(self):
        """An origin matching only the second allowed regex is still echoed back."""
        origin = "https://foo.example.com"
        app, environ, start_response = setup(
            **self.middleware,
            cors_allowed_origin_regexes=[
                r"^https://\w+\.example\.org$",
                r"^https://\w+\.example\.com$",
            ],
        )
        environ.update(
            {
                "HTTP_ACCESS_CONTROL_REQUEST_METHOD": "GET",
                "HTTP_ORIGIN": origin,
                "REQUEST_METHOD": "OPTIONS",
            }
        )
        app(environ, start_response)
        sent = start_response.get_headers()
        assert sent[ACCESS_CONTROL_ALLOW_ORIGIN] == origin

    def test_options_doesnt_add_origin_when_domain_not_found_in_allowed_regexes(self):
        """An origin matching none of the allowed regexes is not echoed back."""
        app, environ, start_response = setup(
            **self.middleware,
            cors_allowed_origin_regexes=[r"^https://\w+\.example\.org$"],
        )
        environ.update(
            {
                "HTTP_ACCESS_CONTROL_REQUEST_METHOD": "GET",
                "HTTP_ORIGIN": "https://foo.example.com",
                "REQUEST_METHOD": "OPTIONS",
            }
        )
        app(environ, start_response)
        sent = start_response.get_headers()
        assert ACCESS_CONTROL_ALLOW_ORIGIN not in sent

    def test_options_empty_request_method(self):
        """An empty Access-Control-Request-Method still gets a 200 response."""
        app, environ, start_response = setup(
            **self.middleware,
            cors_allow_all_origins=True,
        )
        environ.update(
            {
                "HTTP_ACCESS_CONTROL_REQUEST_METHOD": "",
                "HTTP_ORIGIN": "https://example.com",
                "REQUEST_METHOD": "OPTIONS",
            }
        )
        app(environ, start_response)
        assert start_response.status == "200 OK"

    def test_options_no_headers(self):
        """OPTIONS without any CORS request headers gets 405 from the text view route."""
        app, environ, start_response = setup(
            **self.middleware, cors_allow_all_origins=True, routes=[("/", text_view)]
        )
        environ["REQUEST_METHOD"] = "OPTIONS"
        app(environ, start_response)
        assert start_response.status == "405 Method Not Allowed"

    def test_allow_all_origins_get(self):
        """With all origins allowed, a GET echoes the origin and varies on it."""
        origin = "https://example.com"
        app, environ, start_response = setup(
            **self.middleware,
            cors_allow_credentials=True,
            cors_allow_all_origins=True,
            routes=[("/", text_view)],
        )
        environ.update(
            {
                "HTTP_ACCESS_CONTROL_REQUEST_METHOD": "GET",
                "HTTP_ORIGIN": origin,
                "REQUEST_METHOD": "GET",
            }
        )
        app(environ, start_response)
        assert start_response.status == "200 OK"
        sent = start_response.get_headers()
        assert sent[ACCESS_CONTROL_ALLOW_ORIGIN] == origin
        assert sent["vary"] == "origin"

    def test_allow_all_origins_options(self):
        """With all origins allowed, a preflight echoes the origin and varies on it."""
        origin = "https://example.com"
        app, environ, start_response = setup(
            **self.middleware,
            cors_allow_credentials=True,
            cors_allow_all_origins=True,
            routes=[("/", text_view)],
        )
        environ.update(
            {
                "HTTP_ACCESS_CONTROL_REQUEST_METHOD": "GET",
                "HTTP_ORIGIN": origin,
                "REQUEST_METHOD": "OPTIONS",
            }
        )
        app(environ, start_response)
        assert start_response.status == "200 OK"
        sent = start_response.get_headers()
        assert sent[ACCESS_CONTROL_ALLOW_ORIGIN] == origin
        assert sent["vary"] == "origin"

    def test_non_200_headers_still_set(self):
        """
        It's not clear whether the header should still be set for non-HTTP200
        when not a preflight request. However, this is the existing behavior for
        django-cors-middleware, and Spiderweb should mirror it.
        """
        origin = "https://example.com"
        app, environ, start_response = setup(
            **self.middleware,
            cors_allow_credentials=True,
            cors_allow_all_origins=True,
            routes=[("/unauthorized", unauthorized_view)],
        )
        environ.update({"HTTP_ORIGIN": origin, "PATH_INFO": "/unauthorized"})
        app(environ, start_response)
        assert start_response.status == "401 Unauthorized"
        sent = start_response.get_headers()
        assert sent[ACCESS_CONTROL_ALLOW_ORIGIN] == origin

    def test_auth_view_options(self):
        """
        Ensure HTTP200 and header still set, for preflight requests to views requiring
        authentication. See: https://github.com/adamchainz/django-cors-headers/issues/3
        """
        origin = "https://example.com"
        app, environ, start_response = setup(
            **self.middleware,
            cors_allow_credentials=True,
            cors_allow_all_origins=True,
            routes=[("/unauthorized", unauthorized_view)],
        )
        environ.update(
            {
                "HTTP_ORIGIN": origin,
                "HTTP_ACCESS_CONTROL_REQUEST_METHOD": "GET",
                "PATH_INFO": "/unauthorized",
                "REQUEST_METHOD": "OPTIONS",
            }
        )
        app(environ, start_response)
        assert start_response.status == "200 OK"
        sent = start_response.get_headers()
        assert sent[ACCESS_CONTROL_ALLOW_ORIGIN] == origin
        assert sent["content-length"] == "0"

    def test_get_short_circuit(self):
        """
        Test a scenario when a middleware that returns a response is run before
        the `CorsMiddleware`. In this case
        `CorsMiddleware.process_response()` should ignore the request.
        """
        stack = [
            "spiderweb.tests.middleware.InterruptingMiddleware",
            "spiderweb.middleware.cors.CorsMiddleware",
        ]
        app, environ, start_response = setup(
            middleware=stack,
            cors_allow_credentials=True,
            cors_allowed_origins=["https://example.com"],
        )
        environ["HTTP_ORIGIN"] = "https://example.com"
        app(environ, start_response)
        sent = start_response.get_headers()
        assert ACCESS_CONTROL_ALLOW_ORIGIN not in sent

    def test_get_short_circuit_should_be_ignored(self):
        """Same interrupting stack with a ``cors_urls_regex``: still no allow-origin."""
        stack = [
            "spiderweb.tests.middleware.InterruptingMiddleware",
            "spiderweb.middleware.cors.CorsMiddleware",
        ]
        app, environ, start_response = setup(
            middleware=stack,
            cors_urls_regex=r"^/foo/$",
            cors_allowed_origins=["https://example.com"],
        )
        environ["HTTP_ORIGIN"] = "https://example.com"
        app(environ, start_response)
        sent = start_response.get_headers()
        assert ACCESS_CONTROL_ALLOW_ORIGIN not in sent

    def test_get_regex_matches(self):
        """A path matching ``cors_urls_regex`` receives the allow-origin header."""
        app, environ, start_response = setup(
            **self.middleware,
            cors_urls_regex=r"^/foo$",
            cors_allowed_origins=["https://example.com"],
        )
        environ.update(
            {
                "HTTP_ORIGIN": "https://example.com",
                "HTTP_ACCESS_CONTROL_REQUEST_METHOD": "GET",
                "PATH_INFO": "/foo",
                "REQUEST_METHOD": "GET",
            }
        )
        app(environ, start_response)
        sent = start_response.get_headers()
        assert ACCESS_CONTROL_ALLOW_ORIGIN in sent

    def test_get_regex_doesnt_match(self):
        """A path not matching ``cors_urls_regex`` receives no allow-origin header."""
        app, environ, start_response = setup(
            **self.middleware,
            cors_urls_regex=r"^/not-foo/$",
            cors_allowed_origins=["https://example.com"],
        )
        environ.update(
            {
                "HTTP_ORIGIN": "https://example.com",
                "HTTP_ACCESS_CONTROL_REQUEST_METHOD": "GET",
                "PATH_INFO": "/foo",
                "REQUEST_METHOD": "GET",
            }
        )
        app(environ, start_response)
        sent = start_response.get_headers()
        assert ACCESS_CONTROL_ALLOW_ORIGIN not in sent

    def test_works_if_view_deletes_cors_enabled(self):
        """
        Just in case something crazy happens in the view or other middleware,
        check that get_response doesn't fall over if `_cors_enabled` is removed
        """

        def yeet(request):
            # Remove the attribute from the request before responding.
            del request._cors_enabled
            return HttpResponse("hahaha")

        app, environ, start_response = setup(
            **self.middleware,
            cors_allowed_origins=["https://example.com"],
            routes=[("/yeet", yeet)],
        )
        environ.update(
            {
                "HTTP_ORIGIN": "https://example.com",
                "PATH_INFO": "/yeet",
                "REQUEST_METHOD": "GET",
            }
        )
        app(environ, start_response)
        sent = start_response.get_headers()
        assert ACCESS_CONTROL_ALLOW_ORIGIN in sent