From 4a292a282f4e18872fdd479fb0e0321d170afe50 Mon Sep 17 00:00:00 2001
From: Joe Kaufeld
Date: Fri, 9 Aug 2024 12:02:46 -0400
Subject: [PATCH] :sparkles: finish_middleware
---
example.py | 34 ++++--
example_middleware.py | 22 ++++
spiderweb/main.py | 226 ++++++++++++++++++++++++++++------------
spiderweb/middleware.py | 9 +-
spiderweb/utils.py | 7 ++
templates/test.html | 11 +-
6 files changed, 224 insertions(+), 85 deletions(-)
create mode 100644 example_middleware.py
create mode 100644 spiderweb/utils.py
diff --git a/example.py b/example.py
index 760a433..ad1f75f 100644
--- a/example.py
+++ b/example.py
@@ -1,7 +1,15 @@
from spiderweb import WebServer
from spiderweb.response import HttpResponse, JsonResponse, TemplateResponse, RedirectResponse
-app = WebServer(templates_dirs=["templates"])
+
+app = WebServer(
+ templates_dirs=["templates"],
+ middleware=[
+ "example_middleware.TestMiddleware",
+ "example_middleware.RedirectMiddleware"
+ ],
+ append_slash=False # default
+)
@app.route("/")
@@ -14,10 +22,24 @@ def redirect(request):
return RedirectResponse("/")
+@app.route("/json")
+def json(request):
+ return JsonResponse(data={"key": "value"})
+
+
+@app.route("/error")
+def error(request):
+ return HttpResponse(status_code=500, body="Internal Server Error")
+
+
+@app.route("/middleware")
+def middleware(request):
+ return HttpResponse(
+ body="We'll never hit this because it's redirected in middleware"
+ )
+
+
if __name__ == "__main__":
+ # can also add routes like this:
# app.add_route("/", index)
- try:
- app.start()
- print("Currently serving on", app.uri())
- except KeyboardInterrupt:
- app.stop()
+ app.start()
diff --git a/example_middleware.py b/example_middleware.py
new file mode 100644
index 0000000..09f38d7
--- /dev/null
+++ b/example_middleware.py
@@ -0,0 +1,22 @@
+from spiderweb.middleware import SpiderwebMiddleware
+from spiderweb.request import Request
+from spiderweb.response import HttpResponse, RedirectResponse
+
+
+class TestMiddleware(SpiderwebMiddleware):
+ def process_request(self, request: Request) -> HttpResponse | None:
+ # example of a middleware that sets a flag on the request
+ request.spiderweb = True
+
+ def process_response(
+ self, request: Request, response: HttpResponse
+ ) -> HttpResponse | None:
+ # example of a middleware that sets a header on the resp
+ if hasattr(request, "spiderweb"):
+ response.headers["X-Spiderweb"] = "true"
+
+
+class RedirectMiddleware(SpiderwebMiddleware):
+ def process_request(self, request: Request) -> HttpResponse | None:
+ if request.path == "/middleware":
+ return RedirectResponse("/")
diff --git a/spiderweb/main.py b/spiderweb/main.py
index c8b0993..91f748e 100644
--- a/spiderweb/main.py
+++ b/spiderweb/main.py
@@ -2,15 +2,17 @@
# https://gist.github.com/earonesty/ab07b4c0fea2c226e75b3d538cc0dc55
#
# Extensively modified by @itsthejoker
-
import json
import re
+import signal
+import sys
+import time
import traceback
from http.server import BaseHTTPRequestHandler, HTTPServer
import urllib.parse as urlparse
import threading
import logging
-from typing import Callable, Any
+from typing import Callable, Any, NoReturn
from jinja2 import Environment, FileSystemLoader
@@ -24,9 +26,12 @@ from spiderweb.exceptions import (
NoResponseError,
)
from spiderweb.request import Request
-from spiderweb.response import HttpResponse, JsonResponse, TemplateResponse
+from spiderweb.response import HttpResponse, JsonResponse, TemplateResponse, RedirectResponse
+from spiderweb.utils import import_by_string
+
log = logging.getLogger(__name__)
+logging.basicConfig(level=logging.INFO)
def route(path):
@@ -39,30 +44,12 @@ def route(path):
return outer
-def convert_path(path):
- """Convert a path to a regex."""
- parts = path.split("/")
- for i, part in enumerate(parts):
- if part.startswith("<") and part.endswith(">"):
- name = part[1:-1]
- if "__" in name:
- raise ConfigError(
- f"Cannot use `__` (double underscore) in path variable."
- f" Please fix '{name}'."
- )
- if ":" in name:
- converter, name = name.split(":")
- try:
- converter = globals()[converter.title() + "Converter"]
- except KeyError:
- raise ParseError(f"Unknown converter {converter}")
- else:
- converter = StrConverter # noqa: F405
- parts[i] = r"(?P<%s>%s)" % (
- f"{name}__{str(converter.__name__)}",
- converter.regex,
- )
- return re.compile(r"^%s$" % "/".join(parts))
+class DummyRedirectRoute:
+ def __init__(self, location):
+ self.location = location
+
+ def __call__(self, request):
+ return RedirectResponse(self.location)
def convert_match_to_dict(match: dict):
@@ -81,6 +68,7 @@ class WebServer(HTTPServer):
custom_handler: Callable = None,
templates_dirs: list[str] = None,
middleware: list[str] = None,
+ append_slash: bool = False,
):
"""
Create a new server on address, port. Port can be zero.
@@ -92,9 +80,21 @@ class WebServer(HTTPServer):
by calling `add_handler("path", function)`.
"""
addr = addr if addr else "localhost"
- port = port if port else 7777
+ port = port if port else 8000
+ self.append_slash = append_slash
self.templates_dirs = templates_dirs
self.middleware = middleware if middleware else []
+ self._thread = None
+
+ if self.middleware:
+ middleware_by_reference = []
+ for m in self.middleware:
+ try:
+ middleware_by_reference.append(import_by_string(m)())
+ except ImportError:
+ raise ConfigError(f"Middleware '{m}' not found.")
+ self.middleware = middleware_by_reference
+
if self.templates_dirs:
self.env = Environment(loader=FileSystemLoader(self.templates_dirs))
else:
@@ -106,29 +106,66 @@ class WebServer(HTTPServer):
class HandlerClass(RequestHandler):
pass
+ # inject template loader, middleware, and other important things into handler
self.handler_class = custom_handler if custom_handler else HandlerClass
self.handler_class.env = self.env
+ self.handler_class.middleware = self.middleware
+ self.handler_class.append_slash = self.append_slash
# routed methods map into handler
for method in type(self).__dict__.values():
if hasattr(method, "_routes"):
for route in method._routes:
self.add_route(route, method)
+
try:
super().__init__(server_address, self.handler_class)
except OSError:
raise GeneralException("Port already in use.")
+ def convert_path(self, path: str):
+ """Convert a path to a regex."""
+ parts = path.split("/")
+ for i, part in enumerate(parts):
+ if part.startswith("<") and part.endswith(">"):
+ name = part[1:-1]
+ if "__" in name:
+ raise ConfigError(
+ f"Cannot use `__` (double underscore) in path variable."
+ f" Please fix '{name}'."
+ )
+ if ":" in name:
+ converter, name = name.split(":")
+ try:
+ converter = globals()[converter.title() + "Converter"]
+ except KeyError:
+ raise ParseError(f"Unknown converter {converter}")
+ else:
+ converter = StrConverter # noqa: F405
+ parts[i] = r"(?P<%s>%s)" % (
+ f"{name}__{str(converter.__name__)}",
+ converter.regex,
+ )
+ return re.compile(r"^%s$" % "/".join(parts))
+
def check_for_route_duplicates(self, path: str):
- if convert_path(path) in self.handler_class._routes:
+ if self.convert_path(path) in self.handler_class._routes:
raise ConfigError(f"Route '{path}' already exists.")
def add_route(self, path: str, method: Callable):
"""Add a route to the server."""
if not hasattr(self.handler_class, "_routes"):
setattr(self.handler_class, "_routes", [])
- self.check_for_route_duplicates(path)
- self.handler_class._routes[convert_path(path)] = method
+
+ if self.append_slash and not path.endswith("/"):
+ updated_path = path + "/"
+ self.check_for_route_duplicates(updated_path)
+ self.check_for_route_duplicates(path)
+ self.handler_class._routes[self.convert_path(path)] = DummyRedirectRoute(updated_path)
+ self.handler_class._routes[self.convert_path(updated_path)] = method
+ else:
+ self.check_for_route_duplicates(path)
+ self.handler_class._routes[self.convert_path(path)] = method
def route(self, path) -> Callable:
"""
@@ -147,18 +184,17 @@ class WebServer(HTTPServer):
"""
def outer(func):
- if not hasattr(self.handler_class, "_routes"):
- setattr(self.handler_class, "_routes", [])
- self.check_for_route_duplicates(path)
- self.handler_class._routes[convert_path(path)] = func
+ self.add_route(path, func)
return func
return outer
+ @property
def port(self):
"""Return current port."""
return self.socket.getsockname()[1]
+ @property
def address(self):
"""Return current IP address."""
return self.socket.getsockname()[0]
@@ -168,18 +204,26 @@ class WebServer(HTTPServer):
path = path if path else ""
if path.startswith("/"):
path = path[1:]
- return "http://" + self.__addr + ":" + str(self.port()) + "/" + path
+ return self.__addr + ":" + str(self.port()) + "/" + path
+
+ def signal_handler(self, sig, frame) -> NoReturn:
+ log.warning('Shutting down!')
+ self.stop()
def start(self, blocking=False):
+ signal.signal(signal.SIGINT, self.signal_handler)
+ log.info(f"Starting server on {self.address}:{self.port}")
+ log.info("Press CTRL+C to stop the server.")
+ self._thread = threading.Thread(target=self.serve_forever)
+ self._thread.start()
if not blocking:
- threading.Thread(target=self.serve_forever).start()
+ return self._thread
else:
- try:
- self.serve_forever()
- except KeyboardInterrupt:
- print() # empty line after ^C
- print("Stopping server!")
- return
+ while self._thread.is_alive():
+ try:
+ time.sleep(0.2)
+ except KeyboardInterrupt:
+ self.stop()
def stop(self):
super().shutdown()
@@ -190,6 +234,7 @@ class RequestHandler(BaseHTTPRequestHandler):
# I can't help the naming convention of these because that's what
# BaseHTTPRequestHandler uses for some weird reason
_routes = {}
+ middleware = []
def get_request(self):
return Request(
@@ -233,7 +278,7 @@ class RequestHandler(BaseHTTPRequestHandler):
except KeyError:
return http500
- def fire_response(self, resp: HttpResponse):
+ def _fire_response(self, resp: HttpResponse):
self.send_response(resp.status_code)
content = resp.render()
self.send_header("Content-Length", str(len(content)))
@@ -243,38 +288,81 @@ class RequestHandler(BaseHTTPRequestHandler):
self.end_headers()
self.wfile.write(bytes(content, "utf-8"))
- def handle_request(self, request):
+ def fire_response(self, request: Request, resp: HttpResponse):
try:
- request.url = urlparse.urlparse(request.path)
+ self._fire_response(resp)
+ except APIError:
+ raise
+ except ConnectionAbortedError as e:
+ log.error(f"GET {self.path} : {e}")
+ except Exception:
+ log.error(traceback.format_exc())
+ self.fire_response(request, self.get_error_route(500)(request))
- handler, additional_args = self.get_route(request.url.path)
+ def process_request_middleware(self, request: Request) -> None | bool:
+ for middleware in self.middleware:
+ resp = middleware.process_request(request)
+ if resp:
+ self.process_response_middleware(request, resp)
+ self.fire_response(request, resp)
+ return True # abort further processing
- if request.url.query:
- params = urlparse.parse_qs(request.url.query)
- else:
- params = {}
+ def process_response_middleware(self, request: Request, response: HttpResponse) -> None:
+ for middleware in self.middleware:
+ middleware.process_response(request, response)
- request.query_params = params
+ def prepare_response(self, request, resp) -> HttpResponse:
+ try:
+ if isinstance(resp, dict):
+ self.fire_response(JsonResponse(data=resp))
+ if isinstance(resp, TemplateResponse):
+ if hasattr(self, "env"): # injected from above
+ resp.set_template_loader(self.env)
+ for middleware in self.middleware:
+ middleware.process_response(request, resp)
+
+ self.fire_response(resp)
+
+ except APIError:
+ raise
+
+ except ConnectionAbortedError as e:
+ log.error(f"GET {self.path} : {e}")
+ except Exception:
+ log.error(traceback.format_exc())
+ self.fire_response(request, self.get_error_route(500)(request))
+
+ def handle_request(self, request):
+
+ request.url = urlparse.urlparse(request.path)
+
+ handler, additional_args = self.get_route(request.url.path)
+
+ if request.url.query:
+ params = urlparse.parse_qs(request.url.query)
+ else:
+ params = {}
+
+ request.query_params = params
+ try:
if handler:
- try:
- resp = handler(request, **additional_args)
- if resp is None:
- raise NoResponseError(f"View {handler} returned None.")
- if isinstance(resp, dict):
- self.fire_response(JsonResponse(data=resp))
- if isinstance(resp, TemplateResponse):
- if hasattr(self, "env"): # injected from above
- resp.set_template_loader(self.env)
- self.fire_response(resp)
- except APIError:
- raise
- except ConnectionAbortedError as e:
- log.error(f"GET {self.path} : {e}")
- except Exception:
- log.error(traceback.format_exc())
- self.fire_response(self.get_error_route(500)(request))
+ # middleware is injected from WebServer
+ abort = self.process_request_middleware(request)
+ if abort:
+ return
+ resp = handler(request, **additional_args)
+ if resp is None:
+ raise NoResponseError(f"View {handler} returned None.")
+ if isinstance(resp, dict):
+ self.fire_response(request, JsonResponse(data=resp))
+ if isinstance(resp, TemplateResponse):
+ if hasattr(self, "env"): # injected from above
+ resp.set_template_loader(self.env)
+
+ self.process_response_middleware(request, resp)
+ self.fire_response(request, resp)
else:
raise APIError(404)
except APIError as e:
diff --git a/spiderweb/middleware.py b/spiderweb/middleware.py
index 97b4bbc..9a1eebd 100644
--- a/spiderweb/middleware.py
+++ b/spiderweb/middleware.py
@@ -16,17 +16,12 @@ class SpiderwebMiddleware:
If `process_request` returns a HttpResponse, the request will be short-circuited
and the response will be returned immediately. `process_response` will not be called.
-
"""
def process_request(self, request: Request) -> HttpResponse | None:
- # example of a middleware that sets a flag on the request
- request.spiderweb = True
+ pass
def process_response(
self, request: Request, response: HttpResponse
) -> HttpResponse | None:
- # example of a middleware that sets a header on the resp
- if hasattr(request, "spiderweb"):
- response.headers["X-Spiderweb"] = "true"
- return response
+ pass
diff --git a/spiderweb/utils.py b/spiderweb/utils.py
new file mode 100644
index 0000000..360e5a3
--- /dev/null
+++ b/spiderweb/utils.py
@@ -0,0 +1,7 @@
+def import_by_string(name):
+ # https://stackoverflow.com/a/547867
+ components = name.split('.')
+ mod = __import__(components[0])
+ for comp in components[1:]:
+ mod = getattr(mod, comp)
+ return mod
diff --git a/templates/test.html b/templates/test.html
index a27453b..101cdaf 100644
--- a/templates/test.html
+++ b/templates/test.html
@@ -1,5 +1,10 @@
-FART
+HI, THIS IS A PAGE
- This is a test of the {{ value }} template.
-
\ No newline at end of file
+ This is a test of the template rendering system. If rendering is working, this value
+ should be TEST
: {{ value }}.
+
+
+ The value of request.spiderweb
is {{ request.spiderweb }}. If this is True,
+ middleware is working.
+