✨ finish_middleware
This commit is contained in:
parent
9caca72d24
commit
4a292a282f
34
example.py
34
example.py
@ -1,7 +1,15 @@
|
||||
from spiderweb import WebServer
|
||||
from spiderweb.response import HttpResponse, JsonResponse, TemplateResponse, RedirectResponse
|
||||
|
||||
app = WebServer(templates_dirs=["templates"])
|
||||
|
||||
app = WebServer(
|
||||
templates_dirs=["templates"],
|
||||
middleware=[
|
||||
"example_middleware.TestMiddleware",
|
||||
"example_middleware.RedirectMiddleware"
|
||||
],
|
||||
append_slash=False # default
|
||||
)
|
||||
|
||||
|
||||
@app.route("/")
|
||||
@ -14,10 +22,24 @@ def redirect(request):
|
||||
return RedirectResponse("/")
|
||||
|
||||
|
||||
@app.route("/json")
|
||||
def json(request):
|
||||
return JsonResponse(data={"key": "value"})
|
||||
|
||||
|
||||
@app.route("/error")
|
||||
def error(request):
|
||||
return HttpResponse(status_code=500, body="Internal Server Error")
|
||||
|
||||
|
||||
@app.route("/middleware")
|
||||
def middleware(request):
|
||||
return HttpResponse(
|
||||
body="We'll never hit this because it's redirected in middleware"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# can also add routes like this:
|
||||
# app.add_route("/", index)
|
||||
try:
|
||||
app.start()
|
||||
print("Currently serving on", app.uri())
|
||||
except KeyboardInterrupt:
|
||||
app.stop()
|
||||
app.start()
|
||||
|
22
example_middleware.py
Normal file
22
example_middleware.py
Normal file
@ -0,0 +1,22 @@
|
||||
from spiderweb.middleware import SpiderwebMiddleware
|
||||
from spiderweb.request import Request
|
||||
from spiderweb.response import HttpResponse, RedirectResponse
|
||||
|
||||
|
||||
class TestMiddleware(SpiderwebMiddleware):
|
||||
def process_request(self, request: Request) -> HttpResponse | None:
|
||||
# example of a middleware that sets a flag on the request
|
||||
request.spiderweb = True
|
||||
|
||||
def process_response(
|
||||
self, request: Request, response: HttpResponse
|
||||
) -> HttpResponse | None:
|
||||
# example of a middleware that sets a header on the resp
|
||||
if hasattr(request, "spiderweb"):
|
||||
response.headers["X-Spiderweb"] = "true"
|
||||
|
||||
|
||||
class RedirectMiddleware(SpiderwebMiddleware):
|
||||
def process_request(self, request: Request) -> HttpResponse | None:
|
||||
if request.path == "/middleware":
|
||||
return RedirectResponse("/")
|
@ -2,15 +2,17 @@
|
||||
# https://gist.github.com/earonesty/ab07b4c0fea2c226e75b3d538cc0dc55
|
||||
#
|
||||
# Extensively modified by @itsthejoker
|
||||
|
||||
import json
|
||||
import re
|
||||
import signal
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
from http.server import BaseHTTPRequestHandler, HTTPServer
|
||||
import urllib.parse as urlparse
|
||||
import threading
|
||||
import logging
|
||||
from typing import Callable, Any
|
||||
from typing import Callable, Any, NoReturn
|
||||
|
||||
from jinja2 import Environment, FileSystemLoader
|
||||
|
||||
@ -24,9 +26,12 @@ from spiderweb.exceptions import (
|
||||
NoResponseError,
|
||||
)
|
||||
from spiderweb.request import Request
|
||||
from spiderweb.response import HttpResponse, JsonResponse, TemplateResponse
|
||||
from spiderweb.response import HttpResponse, JsonResponse, TemplateResponse, RedirectResponse
|
||||
from spiderweb.utils import import_by_string
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
|
||||
def route(path):
|
||||
@ -39,30 +44,12 @@ def route(path):
|
||||
return outer
|
||||
|
||||
|
||||
def convert_path(path):
|
||||
"""Convert a path to a regex."""
|
||||
parts = path.split("/")
|
||||
for i, part in enumerate(parts):
|
||||
if part.startswith("<") and part.endswith(">"):
|
||||
name = part[1:-1]
|
||||
if "__" in name:
|
||||
raise ConfigError(
|
||||
f"Cannot use `__` (double underscore) in path variable."
|
||||
f" Please fix '{name}'."
|
||||
)
|
||||
if ":" in name:
|
||||
converter, name = name.split(":")
|
||||
try:
|
||||
converter = globals()[converter.title() + "Converter"]
|
||||
except KeyError:
|
||||
raise ParseError(f"Unknown converter {converter}")
|
||||
else:
|
||||
converter = StrConverter # noqa: F405
|
||||
parts[i] = r"(?P<%s>%s)" % (
|
||||
f"{name}__{str(converter.__name__)}",
|
||||
converter.regex,
|
||||
)
|
||||
return re.compile(r"^%s$" % "/".join(parts))
|
||||
class DummyRedirectRoute:
|
||||
def __init__(self, location):
|
||||
self.location = location
|
||||
|
||||
def __call__(self, request):
|
||||
return RedirectResponse(self.location)
|
||||
|
||||
|
||||
def convert_match_to_dict(match: dict):
|
||||
@ -81,6 +68,7 @@ class WebServer(HTTPServer):
|
||||
custom_handler: Callable = None,
|
||||
templates_dirs: list[str] = None,
|
||||
middleware: list[str] = None,
|
||||
append_slash: bool = False,
|
||||
):
|
||||
"""
|
||||
Create a new server on address, port. Port can be zero.
|
||||
@ -92,9 +80,21 @@ class WebServer(HTTPServer):
|
||||
by calling `add_handler("path", function)`.
|
||||
"""
|
||||
addr = addr if addr else "localhost"
|
||||
port = port if port else 7777
|
||||
port = port if port else 8000
|
||||
self.append_slash = append_slash
|
||||
self.templates_dirs = templates_dirs
|
||||
self.middleware = middleware if middleware else []
|
||||
self._thread = None
|
||||
|
||||
if self.middleware:
|
||||
middleware_by_reference = []
|
||||
for m in self.middleware:
|
||||
try:
|
||||
middleware_by_reference.append(import_by_string(m)())
|
||||
except ImportError:
|
||||
raise ConfigError(f"Middleware '{m}' not found.")
|
||||
self.middleware = middleware_by_reference
|
||||
|
||||
if self.templates_dirs:
|
||||
self.env = Environment(loader=FileSystemLoader(self.templates_dirs))
|
||||
else:
|
||||
@ -106,29 +106,66 @@ class WebServer(HTTPServer):
|
||||
class HandlerClass(RequestHandler):
|
||||
pass
|
||||
|
||||
# inject template loader, middleware, and other important things into handler
|
||||
self.handler_class = custom_handler if custom_handler else HandlerClass
|
||||
self.handler_class.env = self.env
|
||||
self.handler_class.middleware = self.middleware
|
||||
self.handler_class.append_slash = self.append_slash
|
||||
|
||||
# routed methods map into handler
|
||||
for method in type(self).__dict__.values():
|
||||
if hasattr(method, "_routes"):
|
||||
for route in method._routes:
|
||||
self.add_route(route, method)
|
||||
|
||||
try:
|
||||
super().__init__(server_address, self.handler_class)
|
||||
except OSError:
|
||||
raise GeneralException("Port already in use.")
|
||||
|
||||
def convert_path(self, path: str):
|
||||
"""Convert a path to a regex."""
|
||||
parts = path.split("/")
|
||||
for i, part in enumerate(parts):
|
||||
if part.startswith("<") and part.endswith(">"):
|
||||
name = part[1:-1]
|
||||
if "__" in name:
|
||||
raise ConfigError(
|
||||
f"Cannot use `__` (double underscore) in path variable."
|
||||
f" Please fix '{name}'."
|
||||
)
|
||||
if ":" in name:
|
||||
converter, name = name.split(":")
|
||||
try:
|
||||
converter = globals()[converter.title() + "Converter"]
|
||||
except KeyError:
|
||||
raise ParseError(f"Unknown converter {converter}")
|
||||
else:
|
||||
converter = StrConverter # noqa: F405
|
||||
parts[i] = r"(?P<%s>%s)" % (
|
||||
f"{name}__{str(converter.__name__)}",
|
||||
converter.regex,
|
||||
)
|
||||
return re.compile(r"^%s$" % "/".join(parts))
|
||||
|
||||
def check_for_route_duplicates(self, path: str):
|
||||
if convert_path(path) in self.handler_class._routes:
|
||||
if self.convert_path(path) in self.handler_class._routes:
|
||||
raise ConfigError(f"Route '{path}' already exists.")
|
||||
|
||||
def add_route(self, path: str, method: Callable):
|
||||
"""Add a route to the server."""
|
||||
if not hasattr(self.handler_class, "_routes"):
|
||||
setattr(self.handler_class, "_routes", [])
|
||||
self.check_for_route_duplicates(path)
|
||||
self.handler_class._routes[convert_path(path)] = method
|
||||
|
||||
if self.append_slash and not path.endswith("/"):
|
||||
updated_path = path + "/"
|
||||
self.check_for_route_duplicates(updated_path)
|
||||
self.check_for_route_duplicates(path)
|
||||
self.handler_class._routes[self.convert_path(path)] = DummyRedirectRoute(updated_path)
|
||||
self.handler_class._routes[self.convert_path(updated_path)] = method
|
||||
else:
|
||||
self.check_for_route_duplicates(path)
|
||||
self.handler_class._routes[self.convert_path(path)] = method
|
||||
|
||||
def route(self, path) -> Callable:
|
||||
"""
|
||||
@ -147,18 +184,17 @@ class WebServer(HTTPServer):
|
||||
"""
|
||||
|
||||
def outer(func):
|
||||
if not hasattr(self.handler_class, "_routes"):
|
||||
setattr(self.handler_class, "_routes", [])
|
||||
self.check_for_route_duplicates(path)
|
||||
self.handler_class._routes[convert_path(path)] = func
|
||||
self.add_route(path, func)
|
||||
return func
|
||||
|
||||
return outer
|
||||
|
||||
@property
|
||||
def port(self):
|
||||
"""Return current port."""
|
||||
return self.socket.getsockname()[1]
|
||||
|
||||
@property
|
||||
def address(self):
|
||||
"""Return current IP address."""
|
||||
return self.socket.getsockname()[0]
|
||||
@ -168,18 +204,26 @@ class WebServer(HTTPServer):
|
||||
path = path if path else ""
|
||||
if path.startswith("/"):
|
||||
path = path[1:]
|
||||
return "http://" + self.__addr + ":" + str(self.port()) + "/" + path
|
||||
return self.__addr + ":" + str(self.port()) + "/" + path
|
||||
|
||||
def signal_handler(self, sig, frame) -> NoReturn:
|
||||
log.warning('Shutting down!')
|
||||
self.stop()
|
||||
|
||||
def start(self, blocking=False):
|
||||
signal.signal(signal.SIGINT, self.signal_handler)
|
||||
log.info(f"Starting server on {self.address}:{self.port}")
|
||||
log.info("Press CTRL+C to stop the server.")
|
||||
self._thread = threading.Thread(target=self.serve_forever)
|
||||
self._thread.start()
|
||||
if not blocking:
|
||||
threading.Thread(target=self.serve_forever).start()
|
||||
return self._thread
|
||||
else:
|
||||
try:
|
||||
self.serve_forever()
|
||||
except KeyboardInterrupt:
|
||||
print() # empty line after ^C
|
||||
print("Stopping server!")
|
||||
return
|
||||
while self._thread.is_alive():
|
||||
try:
|
||||
time.sleep(0.2)
|
||||
except KeyboardInterrupt:
|
||||
self.stop()
|
||||
|
||||
def stop(self):
|
||||
super().shutdown()
|
||||
@ -190,6 +234,7 @@ class RequestHandler(BaseHTTPRequestHandler):
|
||||
# I can't help the naming convention of these because that's what
|
||||
# BaseHTTPRequestHandler uses for some weird reason
|
||||
_routes = {}
|
||||
middleware = []
|
||||
|
||||
def get_request(self):
|
||||
return Request(
|
||||
@ -233,7 +278,7 @@ class RequestHandler(BaseHTTPRequestHandler):
|
||||
except KeyError:
|
||||
return http500
|
||||
|
||||
def fire_response(self, resp: HttpResponse):
|
||||
def _fire_response(self, resp: HttpResponse):
|
||||
self.send_response(resp.status_code)
|
||||
content = resp.render()
|
||||
self.send_header("Content-Length", str(len(content)))
|
||||
@ -243,38 +288,81 @@ class RequestHandler(BaseHTTPRequestHandler):
|
||||
self.end_headers()
|
||||
self.wfile.write(bytes(content, "utf-8"))
|
||||
|
||||
def handle_request(self, request):
|
||||
def fire_response(self, request: Request, resp: HttpResponse):
|
||||
try:
|
||||
request.url = urlparse.urlparse(request.path)
|
||||
self._fire_response(resp)
|
||||
except APIError:
|
||||
raise
|
||||
except ConnectionAbortedError as e:
|
||||
log.error(f"GET {self.path} : {e}")
|
||||
except Exception:
|
||||
log.error(traceback.format_exc())
|
||||
self.fire_response(request, self.get_error_route(500)(request))
|
||||
|
||||
handler, additional_args = self.get_route(request.url.path)
|
||||
def process_request_middleware(self, request: Request) -> None | bool:
|
||||
for middleware in self.middleware:
|
||||
resp = middleware.process_request(request)
|
||||
if resp:
|
||||
self.process_response_middleware(request, resp)
|
||||
self.fire_response(request, resp)
|
||||
return True # abort further processing
|
||||
|
||||
if request.url.query:
|
||||
params = urlparse.parse_qs(request.url.query)
|
||||
else:
|
||||
params = {}
|
||||
def process_response_middleware(self, request: Request, response: HttpResponse) -> None:
|
||||
for middleware in self.middleware:
|
||||
middleware.process_response(request, response)
|
||||
|
||||
request.query_params = params
|
||||
def prepare_response(self, request, resp) -> HttpResponse:
|
||||
try:
|
||||
if isinstance(resp, dict):
|
||||
self.fire_response(JsonResponse(data=resp))
|
||||
if isinstance(resp, TemplateResponse):
|
||||
if hasattr(self, "env"): # injected from above
|
||||
resp.set_template_loader(self.env)
|
||||
|
||||
for middleware in self.middleware:
|
||||
middleware.process_response(request, resp)
|
||||
|
||||
self.fire_response(resp)
|
||||
|
||||
except APIError:
|
||||
raise
|
||||
|
||||
except ConnectionAbortedError as e:
|
||||
log.error(f"GET {self.path} : {e}")
|
||||
except Exception:
|
||||
log.error(traceback.format_exc())
|
||||
self.fire_response(request, self.get_error_route(500)(request))
|
||||
|
||||
def handle_request(self, request):
|
||||
|
||||
request.url = urlparse.urlparse(request.path)
|
||||
|
||||
handler, additional_args = self.get_route(request.url.path)
|
||||
|
||||
if request.url.query:
|
||||
params = urlparse.parse_qs(request.url.query)
|
||||
else:
|
||||
params = {}
|
||||
|
||||
request.query_params = params
|
||||
try:
|
||||
if handler:
|
||||
try:
|
||||
resp = handler(request, **additional_args)
|
||||
if resp is None:
|
||||
raise NoResponseError(f"View {handler} returned None.")
|
||||
if isinstance(resp, dict):
|
||||
self.fire_response(JsonResponse(data=resp))
|
||||
if isinstance(resp, TemplateResponse):
|
||||
if hasattr(self, "env"): # injected from above
|
||||
resp.set_template_loader(self.env)
|
||||
self.fire_response(resp)
|
||||
except APIError:
|
||||
raise
|
||||
except ConnectionAbortedError as e:
|
||||
log.error(f"GET {self.path} : {e}")
|
||||
except Exception:
|
||||
log.error(traceback.format_exc())
|
||||
self.fire_response(self.get_error_route(500)(request))
|
||||
# middleware is injected from WebServer
|
||||
abort = self.process_request_middleware(request)
|
||||
if abort:
|
||||
return
|
||||
|
||||
resp = handler(request, **additional_args)
|
||||
if resp is None:
|
||||
raise NoResponseError(f"View {handler} returned None.")
|
||||
if isinstance(resp, dict):
|
||||
self.fire_response(request, JsonResponse(data=resp))
|
||||
if isinstance(resp, TemplateResponse):
|
||||
if hasattr(self, "env"): # injected from above
|
||||
resp.set_template_loader(self.env)
|
||||
|
||||
self.process_response_middleware(request, resp)
|
||||
self.fire_response(request, resp)
|
||||
else:
|
||||
raise APIError(404)
|
||||
except APIError as e:
|
||||
|
@ -16,17 +16,12 @@ class SpiderwebMiddleware:
|
||||
|
||||
If `process_request` returns a HttpResponse, the request will be short-circuited
|
||||
and the response will be returned immediately. `process_response` will not be called.
|
||||
|
||||
"""
|
||||
|
||||
def process_request(self, request: Request) -> HttpResponse | None:
|
||||
# example of a middleware that sets a flag on the request
|
||||
request.spiderweb = True
|
||||
pass
|
||||
|
||||
def process_response(
|
||||
self, request: Request, response: HttpResponse
|
||||
) -> HttpResponse | None:
|
||||
# example of a middleware that sets a header on the resp
|
||||
if hasattr(request, "spiderweb"):
|
||||
response.headers["X-Spiderweb"] = "true"
|
||||
return response
|
||||
pass
|
||||
|
7
spiderweb/utils.py
Normal file
7
spiderweb/utils.py
Normal file
@ -0,0 +1,7 @@
|
||||
def import_by_string(name):
|
||||
# https://stackoverflow.com/a/547867
|
||||
components = name.split('.')
|
||||
mod = __import__(components[0])
|
||||
for comp in components[1:]:
|
||||
mod = getattr(mod, comp)
|
||||
return mod
|
@ -1,5 +1,10 @@
|
||||
<h1>FART</h1>
|
||||
<h1>HI, THIS IS A PAGE</h1>
|
||||
|
||||
<p>
|
||||
This is a test of the {{ value }} template.
|
||||
This is a test of the template rendering system. If rendering is working, this value
|
||||
should be <code>TEST</code>: {{ value }}.
|
||||
</p>
|
||||
<p>
|
||||
The value of <code>request.spiderweb</code> is {{ request.spiderweb }}. If this is True,
|
||||
middleware is working.
|
||||
</p>
|
Loading…
Reference in New Issue
Block a user