From 4a292a282f4e18872fdd479fb0e0321d170afe50 Mon Sep 17 00:00:00 2001 From: Joe Kaufeld Date: Fri, 9 Aug 2024 12:02:46 -0400 Subject: [PATCH] :sparkles: finish_middleware --- example.py | 34 ++++-- example_middleware.py | 22 ++++ spiderweb/main.py | 226 ++++++++++++++++++++++++++++------------ spiderweb/middleware.py | 9 +- spiderweb/utils.py | 7 ++ templates/test.html | 11 +- 6 files changed, 224 insertions(+), 85 deletions(-) create mode 100644 example_middleware.py create mode 100644 spiderweb/utils.py diff --git a/example.py b/example.py index 760a433..ad1f75f 100644 --- a/example.py +++ b/example.py @@ -1,7 +1,15 @@ from spiderweb import WebServer from spiderweb.response import HttpResponse, JsonResponse, TemplateResponse, RedirectResponse -app = WebServer(templates_dirs=["templates"]) + +app = WebServer( + templates_dirs=["templates"], + middleware=[ + "example_middleware.TestMiddleware", + "example_middleware.RedirectMiddleware" + ], + append_slash=False # default +) @app.route("/") @@ -14,10 +22,24 @@ def redirect(request): return RedirectResponse("/") +@app.route("/json") +def json(request): + return JsonResponse(data={"key": "value"}) + + +@app.route("/error") +def error(request): + return HttpResponse(status_code=500, body="Internal Server Error") + + +@app.route("/middleware") +def middleware(request): + return HttpResponse( + body="We'll never hit this because it's redirected in middleware" + ) + + if __name__ == "__main__": + # can also add routes like this: # app.add_route("/", index) - try: - app.start() - print("Currently serving on", app.uri()) - except KeyboardInterrupt: - app.stop() + app.start() diff --git a/example_middleware.py b/example_middleware.py new file mode 100644 index 0000000..09f38d7 --- /dev/null +++ b/example_middleware.py @@ -0,0 +1,22 @@ +from spiderweb.middleware import SpiderwebMiddleware +from spiderweb.request import Request +from spiderweb.response import HttpResponse, RedirectResponse + + +class TestMiddleware(SpiderwebMiddleware): + def process_request(self, request: Request) -> HttpResponse | None: + # example of a middleware that sets a flag on the request + request.spiderweb = True + + def process_response( + self, request: Request, response: HttpResponse + ) -> HttpResponse | None: + # example of a middleware that sets a header on the resp + if hasattr(request, "spiderweb"): + response.headers["X-Spiderweb"] = "true" + + +class RedirectMiddleware(SpiderwebMiddleware): + def process_request(self, request: Request) -> HttpResponse | None: + if request.path == "/middleware": + return RedirectResponse("/") diff --git a/spiderweb/main.py b/spiderweb/main.py index c8b0993..91f748e 100644 --- a/spiderweb/main.py +++ b/spiderweb/main.py @@ -2,15 +2,17 @@ # https://gist.github.com/earonesty/ab07b4c0fea2c226e75b3d538cc0dc55 # # Extensively modified by @itsthejoker - import json import re +import signal +import sys +import time import traceback from http.server import BaseHTTPRequestHandler, HTTPServer import urllib.parse as urlparse import threading import logging -from typing import Callable, Any +from typing import Callable, Any, NoReturn from jinja2 import Environment, FileSystemLoader @@ -24,9 +26,12 @@ from spiderweb.exceptions import ( NoResponseError, ) from spiderweb.request import Request -from spiderweb.response import HttpResponse, JsonResponse, TemplateResponse +from spiderweb.response import HttpResponse, JsonResponse, TemplateResponse, RedirectResponse +from spiderweb.utils import import_by_string + log = logging.getLogger(__name__) +logging.basicConfig(level=logging.INFO) def route(path): @@ -39,30 +44,12 @@ def route(path): return outer -def convert_path(path): - """Convert a path to a regex.""" - parts = path.split("/") - for i, part in enumerate(parts): - if part.startswith("<") and part.endswith(">"): - name = part[1:-1] - if "__" in name: - raise ConfigError( - f"Cannot use `__` (double underscore) in path variable." - f" Please fix '{name}'." - ) - if ":" in name: - converter, name = name.split(":") - try: - converter = globals()[converter.title() + "Converter"] - except KeyError: - raise ParseError(f"Unknown converter {converter}") - else: - converter = StrConverter # noqa: F405 - parts[i] = r"(?P<%s>%s)" % ( - f"{name}__{str(converter.__name__)}", - converter.regex, - ) - return re.compile(r"^%s$" % "/".join(parts)) +class DummyRedirectRoute: + def __init__(self, location): + self.location = location + + def __call__(self, request): + return RedirectResponse(self.location) def convert_match_to_dict(match: dict): @@ -81,6 +68,7 @@ class WebServer(HTTPServer): custom_handler: Callable = None, templates_dirs: list[str] = None, middleware: list[str] = None, + append_slash: bool = False, ): """ Create a new server on address, port. Port can be zero. @@ -92,9 +80,21 @@ class WebServer(HTTPServer): by calling `add_handler("path", function)`. """ addr = addr if addr else "localhost" - port = port if port else 7777 + port = port if port else 8000 + self.append_slash = append_slash self.templates_dirs = templates_dirs self.middleware = middleware if middleware else [] + self._thread = None + + if self.middleware: + middleware_by_reference = [] + for m in self.middleware: + try: + middleware_by_reference.append(import_by_string(m)()) + except ImportError: + raise ConfigError(f"Middleware '{m}' not found.") + self.middleware = middleware_by_reference + if self.templates_dirs: self.env = Environment(loader=FileSystemLoader(self.templates_dirs)) else: @@ -106,29 +106,66 @@ class WebServer(HTTPServer): class HandlerClass(RequestHandler): pass + # inject template loader, middleware, and other important things into handler self.handler_class = custom_handler if custom_handler else HandlerClass self.handler_class.env = self.env + self.handler_class.middleware = self.middleware + self.handler_class.append_slash = self.append_slash # routed methods map into handler for method in type(self).__dict__.values(): if hasattr(method, "_routes"): for route in method._routes: self.add_route(route, method) + try: super().__init__(server_address, self.handler_class) except OSError: raise GeneralException("Port already in use.") + def convert_path(self, path: str): + """Convert a path to a regex.""" + parts = path.split("/") + for i, part in enumerate(parts): + if part.startswith("<") and part.endswith(">"): + name = part[1:-1] + if "__" in name: + raise ConfigError( + f"Cannot use `__` (double underscore) in path variable." + f" Please fix '{name}'." + ) + if ":" in name: + converter, name = name.split(":") + try: + converter = globals()[converter.title() + "Converter"] + except KeyError: + raise ParseError(f"Unknown converter {converter}") + else: + converter = StrConverter # noqa: F405 + parts[i] = r"(?P<%s>%s)" % ( + f"{name}__{str(converter.__name__)}", + converter.regex, + ) + return re.compile(r"^%s$" % "/".join(parts)) + def check_for_route_duplicates(self, path: str): - if convert_path(path) in self.handler_class._routes: + if self.convert_path(path) in self.handler_class._routes: raise ConfigError(f"Route '{path}' already exists.") def add_route(self, path: str, method: Callable): """Add a route to the server.""" if not hasattr(self.handler_class, "_routes"): setattr(self.handler_class, "_routes", []) - self.check_for_route_duplicates(path) - self.handler_class._routes[convert_path(path)] = method + + if self.append_slash and not path.endswith("/"): + updated_path = path + "/" + self.check_for_route_duplicates(updated_path) + self.check_for_route_duplicates(path) + self.handler_class._routes[self.convert_path(path)] = DummyRedirectRoute(updated_path) + self.handler_class._routes[self.convert_path(updated_path)] = method + else: + self.check_for_route_duplicates(path) + self.handler_class._routes[self.convert_path(path)] = method def route(self, path) -> Callable: """ @@ -147,18 +184,17 @@ class WebServer(HTTPServer): """ def outer(func): - if not hasattr(self.handler_class, "_routes"): - setattr(self.handler_class, "_routes", []) - self.check_for_route_duplicates(path) - self.handler_class._routes[convert_path(path)] = func + self.add_route(path, func) return func return outer + @property def port(self): """Return current port.""" return self.socket.getsockname()[1] + @property def address(self): """Return current IP address.""" return self.socket.getsockname()[0] @@ -168,18 +204,26 @@ class WebServer(HTTPServer): path = path if path else "" if path.startswith("/"): path = path[1:] - return "http://" + self.__addr + ":" + str(self.port()) + "/" + path + return self.__addr + ":" + str(self.port()) + "/" + path + + def signal_handler(self, sig, frame) -> NoReturn: + log.warning('Shutting down!') + self.stop() def start(self, blocking=False): + signal.signal(signal.SIGINT, self.signal_handler) + log.info(f"Starting server on {self.address}:{self.port}") + log.info("Press CTRL+C to stop the server.") + self._thread = threading.Thread(target=self.serve_forever) + self._thread.start() if not blocking: - threading.Thread(target=self.serve_forever).start() + return self._thread else: - try: - self.serve_forever() - except KeyboardInterrupt: - print() # empty line after ^C - print("Stopping server!") - return + while self._thread.is_alive(): + try: + time.sleep(0.2) + except KeyboardInterrupt: + self.stop() def stop(self): super().shutdown() @@ -190,6 +234,7 @@ class RequestHandler(BaseHTTPRequestHandler): # I can't help the naming convention of these because that's what # BaseHTTPRequestHandler uses for some weird reason _routes = {} + middleware = [] def get_request(self): return Request( @@ -233,7 +278,7 @@ class RequestHandler(BaseHTTPRequestHandler): except KeyError: return http500 - def fire_response(self, resp: HttpResponse): + def _fire_response(self, resp: HttpResponse): self.send_response(resp.status_code) content = resp.render() self.send_header("Content-Length", str(len(content))) @@ -243,38 +288,81 @@ class RequestHandler(BaseHTTPRequestHandler): self.end_headers() self.wfile.write(bytes(content, "utf-8")) - def handle_request(self, request): + def fire_response(self, request: Request, resp: HttpResponse): try: - request.url = urlparse.urlparse(request.path) + self._fire_response(resp) + except APIError: + raise + except ConnectionAbortedError as e: + log.error(f"GET {self.path} : {e}") + except Exception: + log.error(traceback.format_exc()) + self.fire_response(request, self.get_error_route(500)(request)) - handler, additional_args = self.get_route(request.url.path) + def process_request_middleware(self, request: Request) -> None | bool: + for middleware in self.middleware: + resp = middleware.process_request(request) + if resp: + self.process_response_middleware(request, resp) + self.fire_response(request, resp) + return True # abort further processing - if request.url.query: - params = urlparse.parse_qs(request.url.query) - else: - params = {} + def process_response_middleware(self, request: Request, response: HttpResponse) -> None: + for middleware in self.middleware: + middleware.process_response(request, response) - request.query_params = params + def prepare_response(self, request, resp) -> HttpResponse: + try: + if isinstance(resp, dict): + self.fire_response(JsonResponse(data=resp)) + if isinstance(resp, TemplateResponse): + if hasattr(self, "env"): # injected from above + resp.set_template_loader(self.env) + for middleware in self.middleware: + middleware.process_response(request, resp) + + self.fire_response(resp) + + except APIError: + raise + + except ConnectionAbortedError as e: + log.error(f"GET {self.path} : {e}") + except Exception: + log.error(traceback.format_exc()) + self.fire_response(request, self.get_error_route(500)(request)) + + def handle_request(self, request): + + request.url = urlparse.urlparse(request.path) + + handler, additional_args = self.get_route(request.url.path) + + if request.url.query: + params = urlparse.parse_qs(request.url.query) + else: + params = {} + + request.query_params = params + try: if handler: - try: - resp = handler(request, **additional_args) - if resp is None: - raise NoResponseError(f"View {handler} returned None.") - if isinstance(resp, dict): - self.fire_response(JsonResponse(data=resp)) - if isinstance(resp, TemplateResponse): - if hasattr(self, "env"): # injected from above - resp.set_template_loader(self.env) - self.fire_response(resp) - except APIError: - raise - except ConnectionAbortedError as e: - log.error(f"GET {self.path} : {e}") - except Exception: - log.error(traceback.format_exc()) - self.fire_response(self.get_error_route(500)(request)) + # middleware is injected from WebServer + abort = self.process_request_middleware(request) + if abort: + return + resp = handler(request, **additional_args) + if resp is None: + raise NoResponseError(f"View {handler} returned None.") + if isinstance(resp, dict): + self.fire_response(request, JsonResponse(data=resp)) + if isinstance(resp, TemplateResponse): + if hasattr(self, "env"): # injected from above + resp.set_template_loader(self.env) + + self.process_response_middleware(request, resp) + self.fire_response(request, resp) else: raise APIError(404) except APIError as e: diff --git a/spiderweb/middleware.py b/spiderweb/middleware.py index 97b4bbc..9a1eebd 100644 --- a/spiderweb/middleware.py +++ b/spiderweb/middleware.py @@ -16,17 +16,12 @@ class SpiderwebMiddleware: If `process_request` returns a HttpResponse, the request will be short-circuited and the response will be returned immediately. `process_response` will not be called. - """ def process_request(self, request: Request) -> HttpResponse | None: - # example of a middleware that sets a flag on the request - request.spiderweb = True + pass def process_response( self, request: Request, response: HttpResponse ) -> HttpResponse | None: - # example of a middleware that sets a header on the resp - if hasattr(request, "spiderweb"): - response.headers["X-Spiderweb"] = "true" - return response + pass diff --git a/spiderweb/utils.py b/spiderweb/utils.py new file mode 100644 index 0000000..360e5a3 --- /dev/null +++ b/spiderweb/utils.py @@ -0,0 +1,7 @@ +def import_by_string(name): + # https://stackoverflow.com/a/547867 + components = name.split('.') + mod = __import__(components[0]) + for comp in components[1:]: + mod = getattr(mod, comp) + return mod diff --git a/templates/test.html b/templates/test.html index a27453b..101cdaf 100644 --- a/templates/test.html +++ b/templates/test.html @@ -1,5 +1,10 @@ -

FART

+

HI, THIS IS A PAGE

- This is a test of the {{ value }} template. -

\ No newline at end of file + This is a test of the template rendering system. If rendering is working, this value + should be TEST: {{ value }}. +

+

+ The value of request.spiderweb is {{ request.spiderweb }}. If this is True, + middleware is working. +