finish_middleware

This commit is contained in:
Joe Kaufeld 2024-08-09 12:02:46 -04:00
parent 9caca72d24
commit 4a292a282f
6 changed files with 224 additions and 85 deletions

View File

@ -1,7 +1,15 @@
from spiderweb import WebServer from spiderweb import WebServer
from spiderweb.response import HttpResponse, JsonResponse, TemplateResponse, RedirectResponse from spiderweb.response import HttpResponse, JsonResponse, TemplateResponse, RedirectResponse
app = WebServer(templates_dirs=["templates"])
app = WebServer(
templates_dirs=["templates"],
middleware=[
"example_middleware.TestMiddleware",
"example_middleware.RedirectMiddleware"
],
append_slash=False # default
)
@app.route("/") @app.route("/")
@ -14,10 +22,24 @@ def redirect(request):
return RedirectResponse("/") return RedirectResponse("/")
@app.route("/json")
def json(request):
return JsonResponse(data={"key": "value"})
@app.route("/error")
def error(request):
return HttpResponse(status_code=500, body="Internal Server Error")
@app.route("/middleware")
def middleware(request):
return HttpResponse(
body="We'll never hit this because it's redirected in middleware"
)
if __name__ == "__main__": if __name__ == "__main__":
# can also add routes like this:
# app.add_route("/", index) # app.add_route("/", index)
try: app.start()
app.start()
print("Currently serving on", app.uri())
except KeyboardInterrupt:
app.stop()

22
example_middleware.py Normal file
View File

@ -0,0 +1,22 @@
from spiderweb.middleware import SpiderwebMiddleware
from spiderweb.request import Request
from spiderweb.response import HttpResponse, RedirectResponse
class TestMiddleware(SpiderwebMiddleware):
def process_request(self, request: Request) -> HttpResponse | None:
# example of a middleware that sets a flag on the request
request.spiderweb = True
def process_response(
self, request: Request, response: HttpResponse
) -> HttpResponse | None:
# example of a middleware that sets a header on the resp
if hasattr(request, "spiderweb"):
response.headers["X-Spiderweb"] = "true"
class RedirectMiddleware(SpiderwebMiddleware):
def process_request(self, request: Request) -> HttpResponse | None:
if request.path == "/middleware":
return RedirectResponse("/")

View File

@ -2,15 +2,17 @@
# https://gist.github.com/earonesty/ab07b4c0fea2c226e75b3d538cc0dc55 # https://gist.github.com/earonesty/ab07b4c0fea2c226e75b3d538cc0dc55
# #
# Extensively modified by @itsthejoker # Extensively modified by @itsthejoker
import json import json
import re import re
import signal
import sys
import time
import traceback import traceback
from http.server import BaseHTTPRequestHandler, HTTPServer from http.server import BaseHTTPRequestHandler, HTTPServer
import urllib.parse as urlparse import urllib.parse as urlparse
import threading import threading
import logging import logging
from typing import Callable, Any from typing import Callable, Any, NoReturn
from jinja2 import Environment, FileSystemLoader from jinja2 import Environment, FileSystemLoader
@ -24,9 +26,12 @@ from spiderweb.exceptions import (
NoResponseError, NoResponseError,
) )
from spiderweb.request import Request from spiderweb.request import Request
from spiderweb.response import HttpResponse, JsonResponse, TemplateResponse from spiderweb.response import HttpResponse, JsonResponse, TemplateResponse, RedirectResponse
from spiderweb.utils import import_by_string
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
def route(path): def route(path):
@ -39,30 +44,12 @@ def route(path):
return outer return outer
def convert_path(path): class DummyRedirectRoute:
"""Convert a path to a regex.""" def __init__(self, location):
parts = path.split("/") self.location = location
for i, part in enumerate(parts):
if part.startswith("<") and part.endswith(">"): def __call__(self, request):
name = part[1:-1] return RedirectResponse(self.location)
if "__" in name:
raise ConfigError(
f"Cannot use `__` (double underscore) in path variable."
f" Please fix '{name}'."
)
if ":" in name:
converter, name = name.split(":")
try:
converter = globals()[converter.title() + "Converter"]
except KeyError:
raise ParseError(f"Unknown converter {converter}")
else:
converter = StrConverter # noqa: F405
parts[i] = r"(?P<%s>%s)" % (
f"{name}__{str(converter.__name__)}",
converter.regex,
)
return re.compile(r"^%s$" % "/".join(parts))
def convert_match_to_dict(match: dict): def convert_match_to_dict(match: dict):
@ -81,6 +68,7 @@ class WebServer(HTTPServer):
custom_handler: Callable = None, custom_handler: Callable = None,
templates_dirs: list[str] = None, templates_dirs: list[str] = None,
middleware: list[str] = None, middleware: list[str] = None,
append_slash: bool = False,
): ):
""" """
Create a new server on address, port. Port can be zero. Create a new server on address, port. Port can be zero.
@ -92,9 +80,21 @@ class WebServer(HTTPServer):
by calling `add_handler("path", function)`. by calling `add_handler("path", function)`.
""" """
addr = addr if addr else "localhost" addr = addr if addr else "localhost"
port = port if port else 7777 port = port if port else 8000
self.append_slash = append_slash
self.templates_dirs = templates_dirs self.templates_dirs = templates_dirs
self.middleware = middleware if middleware else [] self.middleware = middleware if middleware else []
self._thread = None
if self.middleware:
middleware_by_reference = []
for m in self.middleware:
try:
middleware_by_reference.append(import_by_string(m)())
except ImportError:
raise ConfigError(f"Middleware '{m}' not found.")
self.middleware = middleware_by_reference
if self.templates_dirs: if self.templates_dirs:
self.env = Environment(loader=FileSystemLoader(self.templates_dirs)) self.env = Environment(loader=FileSystemLoader(self.templates_dirs))
else: else:
@ -106,29 +106,66 @@ class WebServer(HTTPServer):
class HandlerClass(RequestHandler): class HandlerClass(RequestHandler):
pass pass
# inject template loader, middleware, and other important things into handler
self.handler_class = custom_handler if custom_handler else HandlerClass self.handler_class = custom_handler if custom_handler else HandlerClass
self.handler_class.env = self.env self.handler_class.env = self.env
self.handler_class.middleware = self.middleware
self.handler_class.append_slash = self.append_slash
# routed methods map into handler # routed methods map into handler
for method in type(self).__dict__.values(): for method in type(self).__dict__.values():
if hasattr(method, "_routes"): if hasattr(method, "_routes"):
for route in method._routes: for route in method._routes:
self.add_route(route, method) self.add_route(route, method)
try: try:
super().__init__(server_address, self.handler_class) super().__init__(server_address, self.handler_class)
except OSError: except OSError:
raise GeneralException("Port already in use.") raise GeneralException("Port already in use.")
def convert_path(self, path: str):
"""Convert a path to a regex."""
parts = path.split("/")
for i, part in enumerate(parts):
if part.startswith("<") and part.endswith(">"):
name = part[1:-1]
if "__" in name:
raise ConfigError(
f"Cannot use `__` (double underscore) in path variable."
f" Please fix '{name}'."
)
if ":" in name:
converter, name = name.split(":")
try:
converter = globals()[converter.title() + "Converter"]
except KeyError:
raise ParseError(f"Unknown converter {converter}")
else:
converter = StrConverter # noqa: F405
parts[i] = r"(?P<%s>%s)" % (
f"{name}__{str(converter.__name__)}",
converter.regex,
)
return re.compile(r"^%s$" % "/".join(parts))
def check_for_route_duplicates(self, path: str): def check_for_route_duplicates(self, path: str):
if convert_path(path) in self.handler_class._routes: if self.convert_path(path) in self.handler_class._routes:
raise ConfigError(f"Route '{path}' already exists.") raise ConfigError(f"Route '{path}' already exists.")
def add_route(self, path: str, method: Callable): def add_route(self, path: str, method: Callable):
"""Add a route to the server.""" """Add a route to the server."""
if not hasattr(self.handler_class, "_routes"): if not hasattr(self.handler_class, "_routes"):
setattr(self.handler_class, "_routes", []) setattr(self.handler_class, "_routes", [])
self.check_for_route_duplicates(path)
self.handler_class._routes[convert_path(path)] = method if self.append_slash and not path.endswith("/"):
updated_path = path + "/"
self.check_for_route_duplicates(updated_path)
self.check_for_route_duplicates(path)
self.handler_class._routes[self.convert_path(path)] = DummyRedirectRoute(updated_path)
self.handler_class._routes[self.convert_path(updated_path)] = method
else:
self.check_for_route_duplicates(path)
self.handler_class._routes[self.convert_path(path)] = method
def route(self, path) -> Callable: def route(self, path) -> Callable:
""" """
@ -147,18 +184,17 @@ class WebServer(HTTPServer):
""" """
def outer(func): def outer(func):
if not hasattr(self.handler_class, "_routes"): self.add_route(path, func)
setattr(self.handler_class, "_routes", [])
self.check_for_route_duplicates(path)
self.handler_class._routes[convert_path(path)] = func
return func return func
return outer return outer
@property
def port(self): def port(self):
"""Return current port.""" """Return current port."""
return self.socket.getsockname()[1] return self.socket.getsockname()[1]
@property
def address(self): def address(self):
"""Return current IP address.""" """Return current IP address."""
return self.socket.getsockname()[0] return self.socket.getsockname()[0]
@ -168,18 +204,26 @@ class WebServer(HTTPServer):
path = path if path else "" path = path if path else ""
if path.startswith("/"): if path.startswith("/"):
path = path[1:] path = path[1:]
return "http://" + self.__addr + ":" + str(self.port()) + "/" + path return self.__addr + ":" + str(self.port()) + "/" + path
def signal_handler(self, sig, frame) -> NoReturn:
log.warning('Shutting down!')
self.stop()
def start(self, blocking=False): def start(self, blocking=False):
signal.signal(signal.SIGINT, self.signal_handler)
log.info(f"Starting server on {self.address}:{self.port}")
log.info("Press CTRL+C to stop the server.")
self._thread = threading.Thread(target=self.serve_forever)
self._thread.start()
if not blocking: if not blocking:
threading.Thread(target=self.serve_forever).start() return self._thread
else: else:
try: while self._thread.is_alive():
self.serve_forever() try:
except KeyboardInterrupt: time.sleep(0.2)
print() # empty line after ^C except KeyboardInterrupt:
print("Stopping server!") self.stop()
return
def stop(self): def stop(self):
super().shutdown() super().shutdown()
@ -190,6 +234,7 @@ class RequestHandler(BaseHTTPRequestHandler):
# I can't help the naming convention of these because that's what # I can't help the naming convention of these because that's what
# BaseHTTPRequestHandler uses for some weird reason # BaseHTTPRequestHandler uses for some weird reason
_routes = {} _routes = {}
middleware = []
def get_request(self): def get_request(self):
return Request( return Request(
@ -233,7 +278,7 @@ class RequestHandler(BaseHTTPRequestHandler):
except KeyError: except KeyError:
return http500 return http500
def fire_response(self, resp: HttpResponse): def _fire_response(self, resp: HttpResponse):
self.send_response(resp.status_code) self.send_response(resp.status_code)
content = resp.render() content = resp.render()
self.send_header("Content-Length", str(len(content))) self.send_header("Content-Length", str(len(content)))
@ -243,38 +288,81 @@ class RequestHandler(BaseHTTPRequestHandler):
self.end_headers() self.end_headers()
self.wfile.write(bytes(content, "utf-8")) self.wfile.write(bytes(content, "utf-8"))
def handle_request(self, request): def fire_response(self, request: Request, resp: HttpResponse):
try: try:
request.url = urlparse.urlparse(request.path) self._fire_response(resp)
except APIError:
raise
except ConnectionAbortedError as e:
log.error(f"GET {self.path} : {e}")
except Exception:
log.error(traceback.format_exc())
self.fire_response(request, self.get_error_route(500)(request))
handler, additional_args = self.get_route(request.url.path) def process_request_middleware(self, request: Request) -> None | bool:
for middleware in self.middleware:
resp = middleware.process_request(request)
if resp:
self.process_response_middleware(request, resp)
self.fire_response(request, resp)
return True # abort further processing
if request.url.query: def process_response_middleware(self, request: Request, response: HttpResponse) -> None:
params = urlparse.parse_qs(request.url.query) for middleware in self.middleware:
else: middleware.process_response(request, response)
params = {}
request.query_params = params def prepare_response(self, request, resp) -> HttpResponse:
try:
if isinstance(resp, dict):
self.fire_response(JsonResponse(data=resp))
if isinstance(resp, TemplateResponse):
if hasattr(self, "env"): # injected from above
resp.set_template_loader(self.env)
for middleware in self.middleware:
middleware.process_response(request, resp)
self.fire_response(resp)
except APIError:
raise
except ConnectionAbortedError as e:
log.error(f"GET {self.path} : {e}")
except Exception:
log.error(traceback.format_exc())
self.fire_response(request, self.get_error_route(500)(request))
def handle_request(self, request):
request.url = urlparse.urlparse(request.path)
handler, additional_args = self.get_route(request.url.path)
if request.url.query:
params = urlparse.parse_qs(request.url.query)
else:
params = {}
request.query_params = params
try:
if handler: if handler:
try: # middleware is injected from WebServer
resp = handler(request, **additional_args) abort = self.process_request_middleware(request)
if resp is None: if abort:
raise NoResponseError(f"View {handler} returned None.") return
if isinstance(resp, dict):
self.fire_response(JsonResponse(data=resp))
if isinstance(resp, TemplateResponse):
if hasattr(self, "env"): # injected from above
resp.set_template_loader(self.env)
self.fire_response(resp)
except APIError:
raise
except ConnectionAbortedError as e:
log.error(f"GET {self.path} : {e}")
except Exception:
log.error(traceback.format_exc())
self.fire_response(self.get_error_route(500)(request))
resp = handler(request, **additional_args)
if resp is None:
raise NoResponseError(f"View {handler} returned None.")
if isinstance(resp, dict):
self.fire_response(request, JsonResponse(data=resp))
if isinstance(resp, TemplateResponse):
if hasattr(self, "env"): # injected from above
resp.set_template_loader(self.env)
self.process_response_middleware(request, resp)
self.fire_response(request, resp)
else: else:
raise APIError(404) raise APIError(404)
except APIError as e: except APIError as e:

View File

@ -16,17 +16,12 @@ class SpiderwebMiddleware:
If `process_request` returns a HttpResponse, the request will be short-circuited If `process_request` returns a HttpResponse, the request will be short-circuited
and the response will be returned immediately. `process_response` will not be called. and the response will be returned immediately. `process_response` will not be called.
""" """
def process_request(self, request: Request) -> HttpResponse | None: def process_request(self, request: Request) -> HttpResponse | None:
# example of a middleware that sets a flag on the request pass
request.spiderweb = True
def process_response( def process_response(
self, request: Request, response: HttpResponse self, request: Request, response: HttpResponse
) -> HttpResponse | None: ) -> HttpResponse | None:
# example of a middleware that sets a header on the resp pass
if hasattr(request, "spiderweb"):
response.headers["X-Spiderweb"] = "true"
return response

7
spiderweb/utils.py Normal file
View File

@ -0,0 +1,7 @@
def import_by_string(name):
# https://stackoverflow.com/a/547867
components = name.split('.')
mod = __import__(components[0])
for comp in components[1:]:
mod = getattr(mod, comp)
return mod

View File

@ -1,5 +1,10 @@
<h1>FART</h1> <h1>HI, THIS IS A PAGE</h1>
<p> <p>
This is a test of the {{ value }} template. This is a test of the template rendering system. If rendering is working, this value
should be <code>TEST</code>: {{ value }}.
</p>
<p>
The value of <code>request.spiderweb</code> is {{ request.spiderweb }}. If this is True,
middleware is working.
</p> </p>