finish_middleware

This commit is contained in:
Joe Kaufeld 2024-08-09 12:02:46 -04:00
parent 9caca72d24
commit 4a292a282f
6 changed files with 224 additions and 85 deletions

View File

@ -1,7 +1,15 @@
from spiderweb import WebServer
from spiderweb.response import HttpResponse, JsonResponse, TemplateResponse, RedirectResponse
app = WebServer(templates_dirs=["templates"])
app = WebServer(
templates_dirs=["templates"],
middleware=[
"example_middleware.TestMiddleware",
"example_middleware.RedirectMiddleware"
],
append_slash=False # default
)
@app.route("/")
@ -14,10 +22,24 @@ def redirect(request):
return RedirectResponse("/")
@app.route("/json")
def json(request):
return JsonResponse(data={"key": "value"})
@app.route("/error")
def error(request):
return HttpResponse(status_code=500, body="Internal Server Error")
@app.route("/middleware")
def middleware(request):
return HttpResponse(
body="We'll never hit this because it's redirected in middleware"
)
if __name__ == "__main__":
# can also add routes like this:
# app.add_route("/", index)
try:
app.start()
print("Currently serving on", app.uri())
except KeyboardInterrupt:
app.stop()
app.start()

22
example_middleware.py Normal file
View File

@ -0,0 +1,22 @@
from spiderweb.middleware import SpiderwebMiddleware
from spiderweb.request import Request
from spiderweb.response import HttpResponse, RedirectResponse
class TestMiddleware(SpiderwebMiddleware):
def process_request(self, request: Request) -> HttpResponse | None:
# example of a middleware that sets a flag on the request
request.spiderweb = True
def process_response(
self, request: Request, response: HttpResponse
) -> HttpResponse | None:
# example of a middleware that sets a header on the resp
if hasattr(request, "spiderweb"):
response.headers["X-Spiderweb"] = "true"
class RedirectMiddleware(SpiderwebMiddleware):
def process_request(self, request: Request) -> HttpResponse | None:
if request.path == "/middleware":
return RedirectResponse("/")

View File

@ -2,15 +2,17 @@
# https://gist.github.com/earonesty/ab07b4c0fea2c226e75b3d538cc0dc55
#
# Extensively modified by @itsthejoker
import json
import re
import signal
import sys
import time
import traceback
from http.server import BaseHTTPRequestHandler, HTTPServer
import urllib.parse as urlparse
import threading
import logging
from typing import Callable, Any
from typing import Callable, Any, NoReturn
from jinja2 import Environment, FileSystemLoader
@ -24,9 +26,12 @@ from spiderweb.exceptions import (
NoResponseError,
)
from spiderweb.request import Request
from spiderweb.response import HttpResponse, JsonResponse, TemplateResponse
from spiderweb.response import HttpResponse, JsonResponse, TemplateResponse, RedirectResponse
from spiderweb.utils import import_by_string
log = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
def route(path):
@ -39,30 +44,12 @@ def route(path):
return outer
def convert_path(path):
"""Convert a path to a regex."""
parts = path.split("/")
for i, part in enumerate(parts):
if part.startswith("<") and part.endswith(">"):
name = part[1:-1]
if "__" in name:
raise ConfigError(
f"Cannot use `__` (double underscore) in path variable."
f" Please fix '{name}'."
)
if ":" in name:
converter, name = name.split(":")
try:
converter = globals()[converter.title() + "Converter"]
except KeyError:
raise ParseError(f"Unknown converter {converter}")
else:
converter = StrConverter # noqa: F405
parts[i] = r"(?P<%s>%s)" % (
f"{name}__{str(converter.__name__)}",
converter.regex,
)
return re.compile(r"^%s$" % "/".join(parts))
class DummyRedirectRoute:
def __init__(self, location):
self.location = location
def __call__(self, request):
return RedirectResponse(self.location)
def convert_match_to_dict(match: dict):
@ -81,6 +68,7 @@ class WebServer(HTTPServer):
custom_handler: Callable = None,
templates_dirs: list[str] = None,
middleware: list[str] = None,
append_slash: bool = False,
):
"""
Create a new server on address, port. Port can be zero.
@ -92,9 +80,21 @@ class WebServer(HTTPServer):
by calling `add_handler("path", function)`.
"""
addr = addr if addr else "localhost"
port = port if port else 7777
port = port if port else 8000
self.append_slash = append_slash
self.templates_dirs = templates_dirs
self.middleware = middleware if middleware else []
self._thread = None
if self.middleware:
middleware_by_reference = []
for m in self.middleware:
try:
middleware_by_reference.append(import_by_string(m)())
except ImportError:
raise ConfigError(f"Middleware '{m}' not found.")
self.middleware = middleware_by_reference
if self.templates_dirs:
self.env = Environment(loader=FileSystemLoader(self.templates_dirs))
else:
@ -106,29 +106,66 @@ class WebServer(HTTPServer):
class HandlerClass(RequestHandler):
pass
# inject template loader, middleware, and other important things into handler
self.handler_class = custom_handler if custom_handler else HandlerClass
self.handler_class.env = self.env
self.handler_class.middleware = self.middleware
self.handler_class.append_slash = self.append_slash
# routed methods map into handler
for method in type(self).__dict__.values():
if hasattr(method, "_routes"):
for route in method._routes:
self.add_route(route, method)
try:
super().__init__(server_address, self.handler_class)
except OSError:
raise GeneralException("Port already in use.")
def convert_path(self, path: str):
"""Convert a path to a regex."""
parts = path.split("/")
for i, part in enumerate(parts):
if part.startswith("<") and part.endswith(">"):
name = part[1:-1]
if "__" in name:
raise ConfigError(
f"Cannot use `__` (double underscore) in path variable."
f" Please fix '{name}'."
)
if ":" in name:
converter, name = name.split(":")
try:
converter = globals()[converter.title() + "Converter"]
except KeyError:
raise ParseError(f"Unknown converter {converter}")
else:
converter = StrConverter # noqa: F405
parts[i] = r"(?P<%s>%s)" % (
f"{name}__{str(converter.__name__)}",
converter.regex,
)
return re.compile(r"^%s$" % "/".join(parts))
def check_for_route_duplicates(self, path: str):
if convert_path(path) in self.handler_class._routes:
if self.convert_path(path) in self.handler_class._routes:
raise ConfigError(f"Route '{path}' already exists.")
def add_route(self, path: str, method: Callable):
"""Add a route to the server."""
if not hasattr(self.handler_class, "_routes"):
setattr(self.handler_class, "_routes", [])
self.check_for_route_duplicates(path)
self.handler_class._routes[convert_path(path)] = method
if self.append_slash and not path.endswith("/"):
updated_path = path + "/"
self.check_for_route_duplicates(updated_path)
self.check_for_route_duplicates(path)
self.handler_class._routes[self.convert_path(path)] = DummyRedirectRoute(updated_path)
self.handler_class._routes[self.convert_path(updated_path)] = method
else:
self.check_for_route_duplicates(path)
self.handler_class._routes[self.convert_path(path)] = method
def route(self, path) -> Callable:
"""
@ -147,18 +184,17 @@ class WebServer(HTTPServer):
"""
def outer(func):
if not hasattr(self.handler_class, "_routes"):
setattr(self.handler_class, "_routes", [])
self.check_for_route_duplicates(path)
self.handler_class._routes[convert_path(path)] = func
self.add_route(path, func)
return func
return outer
@property
def port(self):
"""Return current port."""
return self.socket.getsockname()[1]
@property
def address(self):
"""Return current IP address."""
return self.socket.getsockname()[0]
@ -168,18 +204,26 @@ class WebServer(HTTPServer):
path = path if path else ""
if path.startswith("/"):
path = path[1:]
return "http://" + self.__addr + ":" + str(self.port()) + "/" + path
return self.__addr + ":" + str(self.port()) + "/" + path
def signal_handler(self, sig, frame) -> NoReturn:
log.warning('Shutting down!')
self.stop()
def start(self, blocking=False):
signal.signal(signal.SIGINT, self.signal_handler)
log.info(f"Starting server on {self.address}:{self.port}")
log.info("Press CTRL+C to stop the server.")
self._thread = threading.Thread(target=self.serve_forever)
self._thread.start()
if not blocking:
threading.Thread(target=self.serve_forever).start()
return self._thread
else:
try:
self.serve_forever()
except KeyboardInterrupt:
print() # empty line after ^C
print("Stopping server!")
return
while self._thread.is_alive():
try:
time.sleep(0.2)
except KeyboardInterrupt:
self.stop()
def stop(self):
super().shutdown()
@ -190,6 +234,7 @@ class RequestHandler(BaseHTTPRequestHandler):
# I can't help the naming convention of these because that's what
# BaseHTTPRequestHandler uses for some weird reason
_routes = {}
middleware = []
def get_request(self):
return Request(
@ -233,7 +278,7 @@ class RequestHandler(BaseHTTPRequestHandler):
except KeyError:
return http500
def fire_response(self, resp: HttpResponse):
def _fire_response(self, resp: HttpResponse):
self.send_response(resp.status_code)
content = resp.render()
self.send_header("Content-Length", str(len(content)))
@ -243,38 +288,81 @@ class RequestHandler(BaseHTTPRequestHandler):
self.end_headers()
self.wfile.write(bytes(content, "utf-8"))
def handle_request(self, request):
def fire_response(self, request: Request, resp: HttpResponse):
try:
request.url = urlparse.urlparse(request.path)
self._fire_response(resp)
except APIError:
raise
except ConnectionAbortedError as e:
log.error(f"GET {self.path} : {e}")
except Exception:
log.error(traceback.format_exc())
self.fire_response(request, self.get_error_route(500)(request))
handler, additional_args = self.get_route(request.url.path)
def process_request_middleware(self, request: Request) -> None | bool:
for middleware in self.middleware:
resp = middleware.process_request(request)
if resp:
self.process_response_middleware(request, resp)
self.fire_response(request, resp)
return True # abort further processing
if request.url.query:
params = urlparse.parse_qs(request.url.query)
else:
params = {}
def process_response_middleware(self, request: Request, response: HttpResponse) -> None:
for middleware in self.middleware:
middleware.process_response(request, response)
request.query_params = params
def prepare_response(self, request, resp) -> HttpResponse:
try:
if isinstance(resp, dict):
self.fire_response(JsonResponse(data=resp))
if isinstance(resp, TemplateResponse):
if hasattr(self, "env"): # injected from above
resp.set_template_loader(self.env)
for middleware in self.middleware:
middleware.process_response(request, resp)
self.fire_response(resp)
except APIError:
raise
except ConnectionAbortedError as e:
log.error(f"GET {self.path} : {e}")
except Exception:
log.error(traceback.format_exc())
self.fire_response(request, self.get_error_route(500)(request))
def handle_request(self, request):
request.url = urlparse.urlparse(request.path)
handler, additional_args = self.get_route(request.url.path)
if request.url.query:
params = urlparse.parse_qs(request.url.query)
else:
params = {}
request.query_params = params
try:
if handler:
try:
resp = handler(request, **additional_args)
if resp is None:
raise NoResponseError(f"View {handler} returned None.")
if isinstance(resp, dict):
self.fire_response(JsonResponse(data=resp))
if isinstance(resp, TemplateResponse):
if hasattr(self, "env"): # injected from above
resp.set_template_loader(self.env)
self.fire_response(resp)
except APIError:
raise
except ConnectionAbortedError as e:
log.error(f"GET {self.path} : {e}")
except Exception:
log.error(traceback.format_exc())
self.fire_response(self.get_error_route(500)(request))
# middleware is injected from WebServer
abort = self.process_request_middleware(request)
if abort:
return
resp = handler(request, **additional_args)
if resp is None:
raise NoResponseError(f"View {handler} returned None.")
if isinstance(resp, dict):
self.fire_response(request, JsonResponse(data=resp))
if isinstance(resp, TemplateResponse):
if hasattr(self, "env"): # injected from above
resp.set_template_loader(self.env)
self.process_response_middleware(request, resp)
self.fire_response(request, resp)
else:
raise APIError(404)
except APIError as e:

View File

@ -16,17 +16,12 @@ class SpiderwebMiddleware:
If `process_request` returns a HttpResponse, the request will be short-circuited
and the response will be returned immediately. `process_response` will not be called.
"""
def process_request(self, request: Request) -> HttpResponse | None:
# example of a middleware that sets a flag on the request
request.spiderweb = True
pass
def process_response(
self, request: Request, response: HttpResponse
) -> HttpResponse | None:
# example of a middleware that sets a header on the resp
if hasattr(request, "spiderweb"):
response.headers["X-Spiderweb"] = "true"
return response
pass

7
spiderweb/utils.py Normal file
View File

@ -0,0 +1,7 @@
def import_by_string(name):
# https://stackoverflow.com/a/547867
components = name.split('.')
mod = __import__(components[0])
for comp in components[1:]:
mod = getattr(mod, comp)
return mod

View File

@ -1,5 +1,10 @@
<h1>FART</h1>
<h1>HI, THIS IS A PAGE</h1>
<p>
This is a test of the {{ value }} template.
</p>
This is a test of the template rendering system. If rendering is working, this value
should be <code>TEST</code>: {{ value }}.
</p>
<p>
The value of <code>request.spiderweb</code> is {{ request.spiderweb }}. If this is True,
middleware is working.
</p>