add ability to adjust headers in post_process

This commit is contained in:
Joe Kaufeld 2024-10-16 17:26:22 -04:00
parent 12f6c726c9
commit 7ac76883fc
7 changed files with 104 additions and 28 deletions

View File

@ -1,3 +1,5 @@
from spiderweb import HttpResponse
# writing your own middleware
Sometimes you want to run the same code on every request or every response (or both!). Lots of processing happens in the middleware layer, and if you want to write your own, all you have to do is write a quick class and put it in a place that Spiderweb can find it. A piece of middleware only needs two things to be successful:
@ -55,13 +57,19 @@ Unlike `process_request`, returning a value here doesn't change anything. We're
This is a helper function that is available for you to override; it's not often used by middleware, but there are some ([like the pydantic middleware](middleware/pydantic.md)) that call `on_error` when there is a validation failure.
## post_process(self, request: Request, rendered_response: str) -> str:
## post_process(self, request: Request, response: HttpResponse, rendered_response: str) -> str:
> New in 1.3.0!
After `process_request` and `process_response` run, the response is rendered out into the raw text that is going to be sent to the client. Right before that happens, `post_process` is called on each middleware in the same order as `process_response` (so the closer something is to the beginning of the middleware list, the more important it is).
Note that this function *must* return something. Each invocation of `post_process` overwrites the entire output of the response, so make sure to return everything that you want to send. For example, here's a middleware that ~~breaks~~ adjusts the capitalization of the response and also demonstrates passing variables into the middleware:
There are three things passed to `post_process`:
- `request`: the request object. It's provided here purely for reference purposes; while you can technically change it here, it won't have any effect on the response.
- `response`: the response object. The full HTML of the response has already been rendered, but the headers can still be modified here. This object can be modified in place, like in `process_response`.
- `rendered_response`: the full HTML of the response as a string. This is the final output that will be sent to the client. Every instance of `post_process` must return the full HTML of the response, so if you want to make changes, you'll need to return the modified string.
Note that this function *must* return the full HTML of the response (provided at the start as `rendered_response`. Each invocation of `post_process` overwrites the entire output of the response, so make sure to return everything that you want to send. For example, here's a middleware that ~~breaks~~ adjusts the capitalization of the response and also demonstrates passing variables into the middleware and modifies the headers with the type of transformation:
```python
import random
@ -74,7 +82,7 @@ from spiderweb.exceptions import ConfigError
class CaseTransformMiddleware(SpiderwebMiddleware):
# this breaks everything, but it's hilarious so it's worth it.
# Blame Sam.
def post_process(self, request: Request, rendered_response: str) -> str:
def post_process(self, request: Request, response: HttpResponse, rendered_response: str) -> str:
valid_options = ["spongebob", "random"]
# grab the value from the extra data passed into the server object
# during instantiation
@ -86,12 +94,14 @@ class CaseTransformMiddleware(SpiderwebMiddleware):
)
if method == "spongebob":
response.headers["X-Case-Transform"] = "spongebob"
return "".join(
char.upper()
if i % 2 == 0
else char.lower() for i, char in enumerate(rendered_response)
)
else:
response.headers["X-Case-Transform"] = "random"
return "".join(
char.upper()
if random.random() > 0.5

View File

@ -32,9 +32,15 @@ class ExplodingMiddleware(SpiderwebMiddleware):
class CaseTransformMiddleware(SpiderwebMiddleware):
# this breaks everything, but it's hilarious so it's worth it.
# Blame Sam.
def post_process(self, request: Request, rendered_response: str) -> str:
def post_process(
self, request: Request, response: HttpResponse, rendered_response: str
) -> str:
valid_options = ["spongebob", "random"]
method = self.server.extra_data.get("case_transform_middleware_type", "spongebob")
# grab the value from the extra data passed into the server object
# during instantiation
method = self.server.extra_data.get(
"case_transform_middleware_type", "spongebob"
)
if method not in valid_options:
raise ConfigError(
f"Invalid method '{method}' for CaseTransformMiddleware."
@ -42,10 +48,14 @@ class CaseTransformMiddleware(SpiderwebMiddleware):
)
if method == "spongebob":
response.headers["X-Case-Transform"] = "spongebob"
return "".join(
char.upper() if i % 2 == 0 else char.lower() for i, char in enumerate(rendered_response)
char.upper() if i % 2 == 0 else char.lower()
for i, char in enumerate(rendered_response)
)
else:
response.headers["X-Case-Transform"] = "random"
return "".join(
char.upper() if random.random() > 0.5 else char for char in rendered_response
char.upper() if random.random() > 0.5 else char
for char in rendered_response
)

View File

@ -201,6 +201,17 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi
def fire_response(self, start_response, request: Request, resp: HttpResponse):
try:
try:
rendered_output: str = resp.render()
final_output: str | list[str] = self.post_process_middleware(
request, resp, rendered_output
)
except Exception as e:
self.log.error("Fatal error!")
self.log.error(e)
self.log.error(traceback.format_exc())
return [f"Internal Server Error: {e}".encode(DEFAULT_ENCODING)]
status = get_http_status_by_code(resp.status_code)
cookies = []
varies = []
@ -218,24 +229,13 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi
for v in varies:
headers.append(("vary", str(v)))
start_response(status, headers)
try:
rendered_output: str = resp.render()
final_output: str | list[str] = self.post_process_middleware(request, rendered_output)
except Exception as e:
self.log.error("Fatal error!")
self.log.error(e)
self.log.error(traceback.format_exc())
return [f"Internal Server Error: {e}".encode(DEFAULT_ENCODING)]
if not isinstance(final_output, list):
final_output = [final_output]
encoded_resp = [
chunk.encode(DEFAULT_ENCODING) if isinstance(chunk, str) else chunk
for chunk in final_output
]
start_response(status, headers)
return encoded_resp
except APIError:
raise

View File

@ -61,13 +61,17 @@ class MiddlewareMixin:
self.middleware.remove(middleware)
continue
def post_process_middleware(self, request: Request, response: str) -> str:
def post_process_middleware(
self, request: Request, response: HttpResponse, rendered_response: str
) -> str:
# run them in reverse order, same as process_response. The top of the middleware
# stack should be the first and last middleware to run.
for middleware in reversed(self.middleware):
try:
response = middleware.post_process(request, response)
rendered_response = middleware.post_process(
request, response, rendered_response
)
except UnusedMiddleware:
self.middleware.remove(middleware)
continue
return response
return rendered_response

View File

@ -29,9 +29,7 @@ class SpiderwebMiddleware:
# the request and return a response immediately.
pass
def process_response(
self, request: Request, response: HttpResponse
) -> None:
def process_response(self, request: Request, response: HttpResponse) -> None:
# This method is called after the view has returned a response. You can modify
# the response in this method. The response will be returned to the client after
# all middleware has been processed.
@ -43,7 +41,9 @@ class SpiderwebMiddleware:
# will be re-raised.
pass
def post_process(self, request: Request, rendered_response: str) -> str:
def post_process(
self, request: Request, response: HttpResponse, rendered_response: str
) -> str:
# This method is called after all the middleware has been processed and receives
# the final rendered response in str form. You can modify the response here.
return rendered_response

View File

@ -19,5 +19,22 @@ class InterruptingMiddleware(SpiderwebMiddleware):
class PostProcessingMiddleware(SpiderwebMiddleware):
def post_process(self, request: Request, response: str) -> str:
return response + " Moo!"
def post_process(
self, request: Request, response: HttpResponse, rendered_response: str
) -> str:
return rendered_response + " Moo!"
class PostProcessingWithHeaderManipulation(SpiderwebMiddleware):
def post_process(
self, request: Request, response: HttpResponse, rendered_response: str
) -> str:
response.headers["X-Moo"] = "true"
return rendered_response
class ExplodingPostProcessingMiddleware(SpiderwebMiddleware):
def post_process(
self, request: Request, response: HttpResponse, rendered_response: str
) -> str:
raise UnusedMiddleware("Unfinished!")

View File

@ -314,6 +314,41 @@ def test_post_process_middleware():
assert app(environ, start_response) == [bytes("Hi! Moo!", DEFAULT_ENCODING)]
def test_post_process_header_manip():
app, environ, start_response = setup(
middleware=[
"spiderweb.tests.middleware.PostProcessingWithHeaderManipulation",
],
)
app.add_route("/", text_view)
environ["HTTP_USER_AGENT"] = "hi"
environ["REMOTE_ADDR"] = "/"
environ["REQUEST_METHOD"] = "GET"
assert app(environ, start_response) == [bytes("Hi!", DEFAULT_ENCODING)]
assert start_response.get_headers()["x-moo"] == "true"
def test_unused_post_process_middleware():
app, environ, start_response = setup(
middleware=[
"spiderweb.tests.middleware.ExplodingPostProcessingMiddleware",
],
)
app.add_route("/", text_view)
environ["HTTP_USER_AGENT"] = "hi"
environ["REMOTE_ADDR"] = "/"
environ["REQUEST_METHOD"] = "GET"
assert app(environ, start_response) == [bytes("Hi!", DEFAULT_ENCODING)]
# make sure it kicked out the middleware and isn't just ignoring it
assert len(app.middleware) == 0
class TestCorsMiddleware:
# adapted from:
# https://github.com/adamchainz/django-cors-headers/blob/main/tests/test_middleware.py