🎨 reformat tests to remove some duplicate lines

This commit is contained in:
Joe Kaufeld 2024-09-02 10:50:07 -04:00
parent 8cdc6eef44
commit 5cf9dff13a
5 changed files with 25 additions and 43 deletions

View File

@ -23,6 +23,7 @@ class VerifyValidCorsSetting(ServerCheck):
" `cors_allowed_origins`, `cors_allowed_origin_regexes`, or" " `cors_allowed_origins`, `cors_allowed_origin_regexes`, or"
" `cors_allow_all_origins`.", " `cors_allow_all_origins`.",
) )
def check(self): def check(self):
# - `cors_allowed_origins` # - `cors_allowed_origins`
# - `cors_allowed_origin_regexes` # - `cors_allowed_origin_regexes`
@ -150,9 +151,8 @@ class CorsMiddleware(SpiderwebMiddleware):
self.add_response_headers(request, resp) self.add_response_headers(request, resp)
return resp return resp
def process_response( def process_response(self, request: Request, response: HttpResponse) -> None:
self, request: Request, response: HttpResponse
) -> None:
self.add_response_headers(request, response) self.add_response_headers(request, response)
# [204]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods/OPTIONS#status_code # [204]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods/OPTIONS#status_code

View File

@ -72,7 +72,9 @@ class CSRFMiddleware(SpiderwebMiddleware):
def is_trusted_origin(self, request) -> bool: def is_trusted_origin(self, request) -> bool:
origin = request.headers.get("http_origin") origin = request.headers.get("http_origin")
referrer = request.headers.get("http_referer") or request.headers.get("http_referrer") referrer = request.headers.get("http_referer") or request.headers.get(
"http_referrer"
)
host = request.headers.get("http_host") host = request.headers.get("http_host")
if not origin and not (host == referrer): if not origin and not (host == referrer):

View File

@ -18,11 +18,13 @@ class StartResponse:
return {h[0]: h[1] for h in self.headers} return {h[0]: h[1] for h in self.headers}
def setup(): def setup(**kwargs):
environ = {} environ = {}
setup_testing_defaults(environ) setup_testing_defaults(environ)
if "db" not in kwargs:
kwargs["db"] = SqliteDatabase("spiderweb-tests.db")
return ( return (
SpiderwebRouter(db=SqliteDatabase("spiderweb-tests.db")), SpiderwebRouter(**kwargs),
environ, environ,
StartResponse(), StartResponse(),
) )

View File

@ -4,7 +4,7 @@ from datetime import timedelta
import pytest import pytest
from peewee import SqliteDatabase from peewee import SqliteDatabase
from spiderweb import SpiderwebRouter, HttpResponse, ConfigError, StartupErrors from spiderweb import SpiderwebRouter, HttpResponse, StartupErrors
from spiderweb.constants import DEFAULT_ENCODING from spiderweb.constants import DEFAULT_ENCODING
from spiderweb.middleware.sessions import Session from spiderweb.middleware.sessions import Session
from spiderweb.middleware import csrf from spiderweb.middleware import csrf
@ -37,10 +37,8 @@ def index(request):
def test_session_middleware(): def test_session_middleware():
_, environ, start_response = setup() app, environ, start_response = setup(
app = SpiderwebRouter(
middleware=["spiderweb.middleware.sessions.SessionMiddleware"], middleware=["spiderweb.middleware.sessions.SessionMiddleware"],
db=SqliteDatabase("spiderweb-tests.db"),
) )
app.add_route("/", index) app.add_route("/", index)
@ -58,10 +56,8 @@ def test_session_middleware():
def test_expired_session(): def test_expired_session():
_, environ, start_response = setup() app, environ, start_response = setup(
app = SpiderwebRouter(
middleware=["spiderweb.middleware.sessions.SessionMiddleware"], middleware=["spiderweb.middleware.sessions.SessionMiddleware"],
db=SqliteDatabase("spiderweb-tests.db"),
) )
app.add_route("/", index) app.add_route("/", index)
@ -85,13 +81,11 @@ def test_expired_session():
def test_exploding_middleware(): def test_exploding_middleware():
_, environ, start_response = setup() app, environ, start_response = setup(
app = SpiderwebRouter(
middleware=[ middleware=[
"spiderweb.tests.middleware.ExplodingRequestMiddleware", "spiderweb.tests.middleware.ExplodingRequestMiddleware",
"spiderweb.tests.middleware.ExplodingResponseMiddleware", "spiderweb.tests.middleware.ExplodingResponseMiddleware",
], ],
db=SqliteDatabase("spiderweb-tests.db"),
) )
app.add_route("/", index) app.add_route("/", index)
@ -102,7 +96,6 @@ def test_exploding_middleware():
def test_csrf_middleware_without_session_middleware(): def test_csrf_middleware_without_session_middleware():
_, environ, start_response = setup()
with pytest.raises(StartupErrors) as e: with pytest.raises(StartupErrors) as e:
SpiderwebRouter( SpiderwebRouter(
middleware=["spiderweb.middleware.csrf.CSRFMiddleware"], middleware=["spiderweb.middleware.csrf.CSRFMiddleware"],
@ -116,15 +109,14 @@ def test_csrf_middleware_without_session_middleware():
def test_csrf_middleware_above_session_middleware(): def test_csrf_middleware_above_session_middleware():
_, environ, start_response = setup()
with pytest.raises(StartupErrors) as e: with pytest.raises(StartupErrors) as e:
SpiderwebRouter( app, environ, start_response = setup(
middleware=[ middleware=[
"spiderweb.middleware.csrf.CSRFMiddleware", "spiderweb.middleware.csrf.CSRFMiddleware",
"spiderweb.middleware.sessions.SessionMiddleware", "spiderweb.middleware.sessions.SessionMiddleware",
], ],
db=SqliteDatabase("spiderweb-tests.db"),
) )
exceptiongroup = e.value.args[1] exceptiongroup = e.value.args[1]
assert ( assert (
exceptiongroup[0].args[0] exceptiongroup[0].args[0]
@ -133,13 +125,11 @@ def test_csrf_middleware_above_session_middleware():
def test_csrf_middleware(): def test_csrf_middleware():
_, environ, start_response = setup() app, environ, start_response = setup(
app = SpiderwebRouter(
middleware=[ middleware=[
"spiderweb.middleware.sessions.SessionMiddleware", "spiderweb.middleware.sessions.SessionMiddleware",
"spiderweb.middleware.csrf.CSRFMiddleware", "spiderweb.middleware.csrf.CSRFMiddleware",
], ],
db=SqliteDatabase("spiderweb-tests.db"),
) )
app.add_route("/", form_view_with_csrf, ["GET", "POST"]) app.add_route("/", form_view_with_csrf, ["GET", "POST"])
@ -198,14 +188,13 @@ def test_csrf_middleware():
def test_csrf_expired_token(): def test_csrf_expired_token():
_, environ, start_response = setup() app, environ, start_response = setup(
app = SpiderwebRouter(
middleware=[ middleware=[
"spiderweb.middleware.sessions.SessionMiddleware", "spiderweb.middleware.sessions.SessionMiddleware",
"spiderweb.middleware.csrf.CSRFMiddleware", "spiderweb.middleware.csrf.CSRFMiddleware",
], ],
db=SqliteDatabase("spiderweb-tests.db"),
) )
app.middleware[1].CSRF_EXPIRY = -1 app.middleware[1].CSRF_EXPIRY = -1
app.add_route("/", form_view_with_csrf, ["GET", "POST"]) app.add_route("/", form_view_with_csrf, ["GET", "POST"])
@ -235,13 +224,11 @@ def test_csrf_expired_token():
def test_csrf_exempt(): def test_csrf_exempt():
_, environ, start_response = setup() app, environ, start_response = setup(
app = SpiderwebRouter(
middleware=[ middleware=[
"spiderweb.middleware.sessions.SessionMiddleware", "spiderweb.middleware.sessions.SessionMiddleware",
"spiderweb.middleware.csrf.CSRFMiddleware", "spiderweb.middleware.csrf.CSRFMiddleware",
], ],
db=SqliteDatabase("spiderweb-tests.db"),
) )
app.add_route("/", form_csrf_exempt, ["GET", "POST"]) app.add_route("/", form_csrf_exempt, ["GET", "POST"])
@ -268,8 +255,7 @@ def test_csrf_exempt():
def test_csrf_trusted_origins(): def test_csrf_trusted_origins():
_, environ, start_response = setup() app, environ, start_response = setup(
app = SpiderwebRouter(
middleware=[ middleware=[
"spiderweb.middleware.sessions.SessionMiddleware", "spiderweb.middleware.sessions.SessionMiddleware",
"spiderweb.middleware.csrf.CSRFMiddleware", "spiderweb.middleware.csrf.CSRFMiddleware",
@ -277,9 +263,7 @@ def test_csrf_trusted_origins():
csrf_trusted_origins=[ csrf_trusted_origins=[
"example.com", "example.com",
], ],
db=SqliteDatabase("spiderweb-tests.db"),
) )
app.add_route("/", form_view_without_csrf, ["GET", "POST"]) app.add_route("/", form_view_without_csrf, ["GET", "POST"])
environ["HTTP_USER_AGENT"] = "hi" environ["HTTP_USER_AGENT"] = "hi"

View File

@ -75,15 +75,13 @@ def test_redirect_response():
def test_add_route_at_server_start(): def test_add_route_at_server_start():
app, environ, start_response = setup()
def index(request): def index(request):
return RedirectResponse(location="/redirected") return RedirectResponse(location="/redirected")
def view2(request): def view2(request):
return HttpResponse("View 2") return HttpResponse("View 2")
app = SpiderwebRouter( app, environ, start_response = setup(
routes=[ routes=[
("/", index, {"allowed_methods": ["GET", "POST"], "csrf_exempt": True}), ("/", index, {"allowed_methods": ["GET", "POST"], "csrf_exempt": True}),
("/view2", view2), ("/view2", view2),
@ -95,8 +93,7 @@ def test_add_route_at_server_start():
def test_redirect_on_append_slash(): def test_redirect_on_append_slash():
_, environ, start_response = setup() app, environ, start_response = setup(append_slash=True)
app = SpiderwebRouter(append_slash=True)
@app.route("/hello") @app.route("/hello")
def index(request): def index(request):
@ -109,9 +106,7 @@ def test_redirect_on_append_slash():
@given(st.text()) @given(st.text())
def test_template_response_with_template(text): def test_template_response_with_template(text):
_, environ, start_response = setup() app, environ, start_response = setup(templates_dirs=["spiderweb/tests"])
app = SpiderwebRouter(templates_dirs=["spiderweb/tests"])
@app.route("/") @app.route("/")
def index(request): def index(request):
@ -174,11 +169,10 @@ def test_duplicate_error_view():
def test_missing_view_with_custom_404_alt(): def test_missing_view_with_custom_404_alt():
_, environ, start_response = setup()
def custom_404(request): def custom_404(request):
return HttpResponse("Custom 404 2") return HttpResponse("Custom 404 2")
app = SpiderwebRouter(error_routes={404: custom_404}) app, environ, start_response = setup(error_routes={404: custom_404})
assert app(environ, start_response) == [b"Custom 404 2"] assert app(environ, start_response) == [b"Custom 404 2"]