✨ add ability to adjust headers in post_process
This commit is contained in:
parent
12f6c726c9
commit
7ac76883fc
@ -1,3 +1,5 @@
|
||||
from spiderweb import HttpResponse
|
||||
|
||||
# writing your own middleware
|
||||
|
||||
Sometimes you want to run the same code on every request or every response (or both!). Lots of processing happens in the middleware layer, and if you want to write your own, all you have to do is write a quick class and put it in a place that Spiderweb can find it. A piece of middleware only needs two things to be successful:
|
||||
@ -55,13 +57,19 @@ Unlike `process_request`, returning a value here doesn't change anything. We're
|
||||
|
||||
This is a helper function that is available for you to override; it's not often used by middleware, but there are some ([like the pydantic middleware](middleware/pydantic.md)) that call `on_error` when there is a validation failure.
|
||||
|
||||
## post_process(self, request: Request, rendered_response: str) -> str:
|
||||
## post_process(self, request: Request, response: HttpResponse, rendered_response: str) -> str:
|
||||
|
||||
> New in 1.3.0!
|
||||
|
||||
After `process_request` and `process_response` run, the response is rendered out into the raw text that is going to be sent to the client. Right before that happens, `post_process` is called on each middleware in the same order as `process_response` (so the closer something is to the beginning of the middleware list, the more important it is).
|
||||
|
||||
Note that this function *must* return something. Each invocation of `post_process` overwrites the entire output of the response, so make sure to return everything that you want to send. For example, here's a middleware that ~~breaks~~ adjusts the capitalization of the response and also demonstrates passing variables into the middleware:
|
||||
There are three things passed to `post_process`:
|
||||
|
||||
- `request`: the request object. It's provided here purely for reference purposes; while you can technically change it here, it won't have any effect on the response.
|
||||
- `response`: the response object. The full HTML of the response has already been rendered, but the headers can still be modified here. This object can be modified in place, like in `process_response`.
|
||||
- `rendered_response`: the full HTML of the response as a string. This is the final output that will be sent to the client. Every instance of `post_process` must return the full HTML of the response, so if you want to make changes, you'll need to return the modified string.
|
||||
|
||||
Note that this function *must* return the full HTML of the response (provided at the start as `rendered_response`. Each invocation of `post_process` overwrites the entire output of the response, so make sure to return everything that you want to send. For example, here's a middleware that ~~breaks~~ adjusts the capitalization of the response and also demonstrates passing variables into the middleware and modifies the headers with the type of transformation:
|
||||
|
||||
```python
|
||||
import random
|
||||
@ -74,7 +82,7 @@ from spiderweb.exceptions import ConfigError
|
||||
class CaseTransformMiddleware(SpiderwebMiddleware):
|
||||
# this breaks everything, but it's hilarious so it's worth it.
|
||||
# Blame Sam.
|
||||
def post_process(self, request: Request, rendered_response: str) -> str:
|
||||
def post_process(self, request: Request, response: HttpResponse, rendered_response: str) -> str:
|
||||
valid_options = ["spongebob", "random"]
|
||||
# grab the value from the extra data passed into the server object
|
||||
# during instantiation
|
||||
@ -86,12 +94,14 @@ class CaseTransformMiddleware(SpiderwebMiddleware):
|
||||
)
|
||||
|
||||
if method == "spongebob":
|
||||
response.headers["X-Case-Transform"] = "spongebob"
|
||||
return "".join(
|
||||
char.upper()
|
||||
if i % 2 == 0
|
||||
else char.lower() for i, char in enumerate(rendered_response)
|
||||
)
|
||||
else:
|
||||
response.headers["X-Case-Transform"] = "random"
|
||||
return "".join(
|
||||
char.upper()
|
||||
if random.random() > 0.5
|
||||
|
@ -32,9 +32,15 @@ class ExplodingMiddleware(SpiderwebMiddleware):
|
||||
class CaseTransformMiddleware(SpiderwebMiddleware):
|
||||
# this breaks everything, but it's hilarious so it's worth it.
|
||||
# Blame Sam.
|
||||
def post_process(self, request: Request, rendered_response: str) -> str:
|
||||
def post_process(
|
||||
self, request: Request, response: HttpResponse, rendered_response: str
|
||||
) -> str:
|
||||
valid_options = ["spongebob", "random"]
|
||||
method = self.server.extra_data.get("case_transform_middleware_type", "spongebob")
|
||||
# grab the value from the extra data passed into the server object
|
||||
# during instantiation
|
||||
method = self.server.extra_data.get(
|
||||
"case_transform_middleware_type", "spongebob"
|
||||
)
|
||||
if method not in valid_options:
|
||||
raise ConfigError(
|
||||
f"Invalid method '{method}' for CaseTransformMiddleware."
|
||||
@ -42,10 +48,14 @@ class CaseTransformMiddleware(SpiderwebMiddleware):
|
||||
)
|
||||
|
||||
if method == "spongebob":
|
||||
response.headers["X-Case-Transform"] = "spongebob"
|
||||
return "".join(
|
||||
char.upper() if i % 2 == 0 else char.lower() for i, char in enumerate(rendered_response)
|
||||
char.upper() if i % 2 == 0 else char.lower()
|
||||
for i, char in enumerate(rendered_response)
|
||||
)
|
||||
else:
|
||||
response.headers["X-Case-Transform"] = "random"
|
||||
return "".join(
|
||||
char.upper() if random.random() > 0.5 else char for char in rendered_response
|
||||
char.upper() if random.random() > 0.5 else char
|
||||
for char in rendered_response
|
||||
)
|
||||
|
@ -201,6 +201,17 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi
|
||||
|
||||
def fire_response(self, start_response, request: Request, resp: HttpResponse):
|
||||
try:
|
||||
try:
|
||||
rendered_output: str = resp.render()
|
||||
final_output: str | list[str] = self.post_process_middleware(
|
||||
request, resp, rendered_output
|
||||
)
|
||||
except Exception as e:
|
||||
self.log.error("Fatal error!")
|
||||
self.log.error(e)
|
||||
self.log.error(traceback.format_exc())
|
||||
return [f"Internal Server Error: {e}".encode(DEFAULT_ENCODING)]
|
||||
|
||||
status = get_http_status_by_code(resp.status_code)
|
||||
cookies = []
|
||||
varies = []
|
||||
@ -218,24 +229,13 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi
|
||||
for v in varies:
|
||||
headers.append(("vary", str(v)))
|
||||
|
||||
start_response(status, headers)
|
||||
|
||||
try:
|
||||
rendered_output: str = resp.render()
|
||||
final_output: str | list[str] = self.post_process_middleware(request, rendered_output)
|
||||
except Exception as e:
|
||||
self.log.error("Fatal error!")
|
||||
self.log.error(e)
|
||||
self.log.error(traceback.format_exc())
|
||||
return [f"Internal Server Error: {e}".encode(DEFAULT_ENCODING)]
|
||||
|
||||
if not isinstance(final_output, list):
|
||||
final_output = [final_output]
|
||||
encoded_resp = [
|
||||
chunk.encode(DEFAULT_ENCODING) if isinstance(chunk, str) else chunk
|
||||
for chunk in final_output
|
||||
]
|
||||
|
||||
start_response(status, headers)
|
||||
return encoded_resp
|
||||
except APIError:
|
||||
raise
|
||||
|
@ -61,13 +61,17 @@ class MiddlewareMixin:
|
||||
self.middleware.remove(middleware)
|
||||
continue
|
||||
|
||||
def post_process_middleware(self, request: Request, response: str) -> str:
|
||||
def post_process_middleware(
|
||||
self, request: Request, response: HttpResponse, rendered_response: str
|
||||
) -> str:
|
||||
# run them in reverse order, same as process_response. The top of the middleware
|
||||
# stack should be the first and last middleware to run.
|
||||
for middleware in reversed(self.middleware):
|
||||
try:
|
||||
response = middleware.post_process(request, response)
|
||||
rendered_response = middleware.post_process(
|
||||
request, response, rendered_response
|
||||
)
|
||||
except UnusedMiddleware:
|
||||
self.middleware.remove(middleware)
|
||||
continue
|
||||
return response
|
||||
return rendered_response
|
||||
|
@ -29,9 +29,7 @@ class SpiderwebMiddleware:
|
||||
# the request and return a response immediately.
|
||||
pass
|
||||
|
||||
def process_response(
|
||||
self, request: Request, response: HttpResponse
|
||||
) -> None:
|
||||
def process_response(self, request: Request, response: HttpResponse) -> None:
|
||||
# This method is called after the view has returned a response. You can modify
|
||||
# the response in this method. The response will be returned to the client after
|
||||
# all middleware has been processed.
|
||||
@ -43,7 +41,9 @@ class SpiderwebMiddleware:
|
||||
# will be re-raised.
|
||||
pass
|
||||
|
||||
def post_process(self, request: Request, rendered_response: str) -> str:
|
||||
def post_process(
|
||||
self, request: Request, response: HttpResponse, rendered_response: str
|
||||
) -> str:
|
||||
# This method is called after all the middleware has been processed and receives
|
||||
# the final rendered response in str form. You can modify the response here.
|
||||
return rendered_response
|
||||
|
@ -19,5 +19,22 @@ class InterruptingMiddleware(SpiderwebMiddleware):
|
||||
|
||||
|
||||
class PostProcessingMiddleware(SpiderwebMiddleware):
|
||||
def post_process(self, request: Request, response: str) -> str:
|
||||
return response + " Moo!"
|
||||
def post_process(
|
||||
self, request: Request, response: HttpResponse, rendered_response: str
|
||||
) -> str:
|
||||
return rendered_response + " Moo!"
|
||||
|
||||
|
||||
class PostProcessingWithHeaderManipulation(SpiderwebMiddleware):
|
||||
def post_process(
|
||||
self, request: Request, response: HttpResponse, rendered_response: str
|
||||
) -> str:
|
||||
response.headers["X-Moo"] = "true"
|
||||
return rendered_response
|
||||
|
||||
|
||||
class ExplodingPostProcessingMiddleware(SpiderwebMiddleware):
|
||||
def post_process(
|
||||
self, request: Request, response: HttpResponse, rendered_response: str
|
||||
) -> str:
|
||||
raise UnusedMiddleware("Unfinished!")
|
||||
|
@ -314,6 +314,41 @@ def test_post_process_middleware():
|
||||
assert app(environ, start_response) == [bytes("Hi! Moo!", DEFAULT_ENCODING)]
|
||||
|
||||
|
||||
def test_post_process_header_manip():
|
||||
app, environ, start_response = setup(
|
||||
middleware=[
|
||||
"spiderweb.tests.middleware.PostProcessingWithHeaderManipulation",
|
||||
],
|
||||
)
|
||||
|
||||
app.add_route("/", text_view)
|
||||
|
||||
environ["HTTP_USER_AGENT"] = "hi"
|
||||
environ["REMOTE_ADDR"] = "/"
|
||||
environ["REQUEST_METHOD"] = "GET"
|
||||
|
||||
assert app(environ, start_response) == [bytes("Hi!", DEFAULT_ENCODING)]
|
||||
assert start_response.get_headers()["x-moo"] == "true"
|
||||
|
||||
|
||||
def test_unused_post_process_middleware():
|
||||
app, environ, start_response = setup(
|
||||
middleware=[
|
||||
"spiderweb.tests.middleware.ExplodingPostProcessingMiddleware",
|
||||
],
|
||||
)
|
||||
|
||||
app.add_route("/", text_view)
|
||||
|
||||
environ["HTTP_USER_AGENT"] = "hi"
|
||||
environ["REMOTE_ADDR"] = "/"
|
||||
environ["REQUEST_METHOD"] = "GET"
|
||||
|
||||
assert app(environ, start_response) == [bytes("Hi!", DEFAULT_ENCODING)]
|
||||
# make sure it kicked out the middleware and isn't just ignoring it
|
||||
assert len(app.middleware) == 0
|
||||
|
||||
|
||||
class TestCorsMiddleware:
|
||||
# adapted from:
|
||||
# https://github.com/adamchainz/django-cors-headers/blob/main/tests/test_middleware.py
|
||||
|
Loading…
Reference in New Issue
Block a user