From 7ac76883fcc743710e93b8613c2d20960da769bb Mon Sep 17 00:00:00 2001 From: Joe Kaufeld Date: Wed, 16 Oct 2024 17:26:22 -0400 Subject: [PATCH] :sparkles: add ability to adjust headers in post_process --- docs/middleware/custom_middleware.md | 16 ++++++++++--- example_middleware.py | 18 ++++++++++---- spiderweb/main.py | 24 +++++++++---------- spiderweb/middleware/__init__.py | 10 +++++--- spiderweb/middleware/base.py | 8 +++---- spiderweb/tests/middleware.py | 21 +++++++++++++++-- spiderweb/tests/test_middleware.py | 35 ++++++++++++++++++++++++++++ 7 files changed, 104 insertions(+), 28 deletions(-) diff --git a/docs/middleware/custom_middleware.md b/docs/middleware/custom_middleware.md index 26af02b..7cfa8a2 100644 --- a/docs/middleware/custom_middleware.md +++ b/docs/middleware/custom_middleware.md @@ -1,3 +1,5 @@ +from spiderweb import HttpResponse + # writing your own middleware Sometimes you want to run the same code on every request or every response (or both!). Lots of processing happens in the middleware layer, and if you want to write your own, all you have to do is write a quick class and put it in a place that Spiderweb can find it. A piece of middleware only needs two things to be successful: @@ -55,13 +57,19 @@ Unlike `process_request`, returning a value here doesn't change anything. We're This is a helper function that is available for you to override; it's not often used by middleware, but there are some ([like the pydantic middleware](middleware/pydantic.md)) that call `on_error` when there is a validation failure. -## post_process(self, request: Request, rendered_response: str) -> str: +## post_process(self, request: Request, response: HttpResponse, rendered_response: str) -> str: > New in 1.3.0! After `process_request` and `process_response` run, the response is rendered out into the raw text that is going to be sent to the client. Right before that happens, `post_process` is called on each middleware in the same order as `process_response` (so the closer something is to the beginning of the middleware list, the more important it is). -Note that this function *must* return something. Each invocation of `post_process` overwrites the entire output of the response, so make sure to return everything that you want to send. For example, here's a middleware that ~~breaks~~ adjusts the capitalization of the response and also demonstrates passing variables into the middleware: +There are three things passed to `post_process`: + +- `request`: the request object. It's provided here purely for reference purposes; while you can technically change it here, it won't have any effect on the response. +- `response`: the response object. The full HTML of the response has already been rendered, but the headers can still be modified here. This object can be modified in place, like in `process_response`. +- `rendered_response`: the full HTML of the response as a string. This is the final output that will be sent to the client. Every instance of `post_process` must return the full HTML of the response, so if you want to make changes, you'll need to return the modified string. + +Note that this function *must* return the full HTML of the response (provided at the start as `rendered_response`. Each invocation of `post_process` overwrites the entire output of the response, so make sure to return everything that you want to send. For example, here's a middleware that ~~breaks~~ adjusts the capitalization of the response and also demonstrates passing variables into the middleware and modifies the headers with the type of transformation: ```python import random @@ -74,7 +82,7 @@ from spiderweb.exceptions import ConfigError class CaseTransformMiddleware(SpiderwebMiddleware): # this breaks everything, but it's hilarious so it's worth it. # Blame Sam. - def post_process(self, request: Request, rendered_response: str) -> str: + def post_process(self, request: Request, response: HttpResponse, rendered_response: str) -> str: valid_options = ["spongebob", "random"] # grab the value from the extra data passed into the server object # during instantiation @@ -86,12 +94,14 @@ class CaseTransformMiddleware(SpiderwebMiddleware): ) if method == "spongebob": + response.headers["X-Case-Transform"] = "spongebob" return "".join( char.upper() if i % 2 == 0 else char.lower() for i, char in enumerate(rendered_response) ) else: + response.headers["X-Case-Transform"] = "random" return "".join( char.upper() if random.random() > 0.5 diff --git a/example_middleware.py b/example_middleware.py index b69ed2b..21d0a21 100644 --- a/example_middleware.py +++ b/example_middleware.py @@ -32,9 +32,15 @@ class ExplodingMiddleware(SpiderwebMiddleware): class CaseTransformMiddleware(SpiderwebMiddleware): # this breaks everything, but it's hilarious so it's worth it. # Blame Sam. - def post_process(self, request: Request, rendered_response: str) -> str: + def post_process( + self, request: Request, response: HttpResponse, rendered_response: str + ) -> str: valid_options = ["spongebob", "random"] - method = self.server.extra_data.get("case_transform_middleware_type", "spongebob") + # grab the value from the extra data passed into the server object + # during instantiation + method = self.server.extra_data.get( + "case_transform_middleware_type", "spongebob" + ) if method not in valid_options: raise ConfigError( f"Invalid method '{method}' for CaseTransformMiddleware." @@ -42,10 +48,14 @@ class CaseTransformMiddleware(SpiderwebMiddleware): ) if method == "spongebob": + response.headers["X-Case-Transform"] = "spongebob" return "".join( - char.upper() if i % 2 == 0 else char.lower() for i, char in enumerate(rendered_response) + char.upper() if i % 2 == 0 else char.lower() + for i, char in enumerate(rendered_response) ) else: + response.headers["X-Case-Transform"] = "random" return "".join( - char.upper() if random.random() > 0.5 else char for char in rendered_response + char.upper() if random.random() > 0.5 else char + for char in rendered_response ) diff --git a/spiderweb/main.py b/spiderweb/main.py index e30bb1d..82ab30a 100644 --- a/spiderweb/main.py +++ b/spiderweb/main.py @@ -201,6 +201,17 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi def fire_response(self, start_response, request: Request, resp: HttpResponse): try: + try: + rendered_output: str = resp.render() + final_output: str | list[str] = self.post_process_middleware( + request, resp, rendered_output + ) + except Exception as e: + self.log.error("Fatal error!") + self.log.error(e) + self.log.error(traceback.format_exc()) + return [f"Internal Server Error: {e}".encode(DEFAULT_ENCODING)] + status = get_http_status_by_code(resp.status_code) cookies = [] varies = [] @@ -218,24 +229,13 @@ class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixi for v in varies: headers.append(("vary", str(v))) - start_response(status, headers) - - try: - rendered_output: str = resp.render() - final_output: str | list[str] = self.post_process_middleware(request, rendered_output) - except Exception as e: - self.log.error("Fatal error!") - self.log.error(e) - self.log.error(traceback.format_exc()) - return [f"Internal Server Error: {e}".encode(DEFAULT_ENCODING)] - if not isinstance(final_output, list): final_output = [final_output] encoded_resp = [ chunk.encode(DEFAULT_ENCODING) if isinstance(chunk, str) else chunk for chunk in final_output ] - + start_response(status, headers) return encoded_resp except APIError: raise diff --git a/spiderweb/middleware/__init__.py b/spiderweb/middleware/__init__.py index ad4757b..9118034 100644 --- a/spiderweb/middleware/__init__.py +++ b/spiderweb/middleware/__init__.py @@ -61,13 +61,17 @@ class MiddlewareMixin: self.middleware.remove(middleware) continue - def post_process_middleware(self, request: Request, response: str) -> str: + def post_process_middleware( + self, request: Request, response: HttpResponse, rendered_response: str + ) -> str: # run them in reverse order, same as process_response. The top of the middleware # stack should be the first and last middleware to run. for middleware in reversed(self.middleware): try: - response = middleware.post_process(request, response) + rendered_response = middleware.post_process( + request, response, rendered_response + ) except UnusedMiddleware: self.middleware.remove(middleware) continue - return response + return rendered_response diff --git a/spiderweb/middleware/base.py b/spiderweb/middleware/base.py index 90a52a1..e9a22ff 100644 --- a/spiderweb/middleware/base.py +++ b/spiderweb/middleware/base.py @@ -29,9 +29,7 @@ class SpiderwebMiddleware: # the request and return a response immediately. pass - def process_response( - self, request: Request, response: HttpResponse - ) -> None: + def process_response(self, request: Request, response: HttpResponse) -> None: # This method is called after the view has returned a response. You can modify # the response in this method. The response will be returned to the client after # all middleware has been processed. @@ -43,7 +41,9 @@ class SpiderwebMiddleware: # will be re-raised. pass - def post_process(self, request: Request, rendered_response: str) -> str: + def post_process( + self, request: Request, response: HttpResponse, rendered_response: str + ) -> str: # This method is called after all the middleware has been processed and receives # the final rendered response in str form. You can modify the response here. return rendered_response diff --git a/spiderweb/tests/middleware.py b/spiderweb/tests/middleware.py index 9517b89..ba17113 100644 --- a/spiderweb/tests/middleware.py +++ b/spiderweb/tests/middleware.py @@ -19,5 +19,22 @@ class InterruptingMiddleware(SpiderwebMiddleware): class PostProcessingMiddleware(SpiderwebMiddleware): - def post_process(self, request: Request, response: str) -> str: - return response + " Moo!" + def post_process( + self, request: Request, response: HttpResponse, rendered_response: str + ) -> str: + return rendered_response + " Moo!" + + +class PostProcessingWithHeaderManipulation(SpiderwebMiddleware): + def post_process( + self, request: Request, response: HttpResponse, rendered_response: str + ) -> str: + response.headers["X-Moo"] = "true" + return rendered_response + + +class ExplodingPostProcessingMiddleware(SpiderwebMiddleware): + def post_process( + self, request: Request, response: HttpResponse, rendered_response: str + ) -> str: + raise UnusedMiddleware("Unfinished!") diff --git a/spiderweb/tests/test_middleware.py b/spiderweb/tests/test_middleware.py index 75947c7..e2c4199 100644 --- a/spiderweb/tests/test_middleware.py +++ b/spiderweb/tests/test_middleware.py @@ -314,6 +314,41 @@ def test_post_process_middleware(): assert app(environ, start_response) == [bytes("Hi! Moo!", DEFAULT_ENCODING)] +def test_post_process_header_manip(): + app, environ, start_response = setup( + middleware=[ + "spiderweb.tests.middleware.PostProcessingWithHeaderManipulation", + ], + ) + + app.add_route("/", text_view) + + environ["HTTP_USER_AGENT"] = "hi" + environ["REMOTE_ADDR"] = "/" + environ["REQUEST_METHOD"] = "GET" + + assert app(environ, start_response) == [bytes("Hi!", DEFAULT_ENCODING)] + assert start_response.get_headers()["x-moo"] == "true" + + +def test_unused_post_process_middleware(): + app, environ, start_response = setup( + middleware=[ + "spiderweb.tests.middleware.ExplodingPostProcessingMiddleware", + ], + ) + + app.add_route("/", text_view) + + environ["HTTP_USER_AGENT"] = "hi" + environ["REMOTE_ADDR"] = "/" + environ["REQUEST_METHOD"] = "GET" + + assert app(environ, start_response) == [bytes("Hi!", DEFAULT_ENCODING)] + # make sure it kicked out the middleware and isn't just ignoring it + assert len(app.middleware) == 0 + + class TestCorsMiddleware: # adapted from: # https://github.com/adamchainz/django-cors-headers/blob/main/tests/test_middleware.py