✨ finish_middleware
This commit is contained in:
parent
9caca72d24
commit
4a292a282f
32
example.py
32
example.py
@ -1,7 +1,15 @@
|
|||||||
from spiderweb import WebServer
|
from spiderweb import WebServer
|
||||||
from spiderweb.response import HttpResponse, JsonResponse, TemplateResponse, RedirectResponse
|
from spiderweb.response import HttpResponse, JsonResponse, TemplateResponse, RedirectResponse
|
||||||
|
|
||||||
app = WebServer(templates_dirs=["templates"])
|
|
||||||
|
app = WebServer(
|
||||||
|
templates_dirs=["templates"],
|
||||||
|
middleware=[
|
||||||
|
"example_middleware.TestMiddleware",
|
||||||
|
"example_middleware.RedirectMiddleware"
|
||||||
|
],
|
||||||
|
append_slash=False # default
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@app.route("/")
|
@app.route("/")
|
||||||
@ -14,10 +22,24 @@ def redirect(request):
|
|||||||
return RedirectResponse("/")
|
return RedirectResponse("/")
|
||||||
|
|
||||||
|
|
||||||
|
@app.route("/json")
|
||||||
|
def json(request):
|
||||||
|
return JsonResponse(data={"key": "value"})
|
||||||
|
|
||||||
|
|
||||||
|
@app.route("/error")
|
||||||
|
def error(request):
|
||||||
|
return HttpResponse(status_code=500, body="Internal Server Error")
|
||||||
|
|
||||||
|
|
||||||
|
@app.route("/middleware")
|
||||||
|
def middleware(request):
|
||||||
|
return HttpResponse(
|
||||||
|
body="We'll never hit this because it's redirected in middleware"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
# can also add routes like this:
|
||||||
# app.add_route("/", index)
|
# app.add_route("/", index)
|
||||||
try:
|
|
||||||
app.start()
|
app.start()
|
||||||
print("Currently serving on", app.uri())
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
app.stop()
|
|
||||||
|
22
example_middleware.py
Normal file
22
example_middleware.py
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
from spiderweb.middleware import SpiderwebMiddleware
|
||||||
|
from spiderweb.request import Request
|
||||||
|
from spiderweb.response import HttpResponse, RedirectResponse
|
||||||
|
|
||||||
|
|
||||||
|
class TestMiddleware(SpiderwebMiddleware):
|
||||||
|
def process_request(self, request: Request) -> HttpResponse | None:
|
||||||
|
# example of a middleware that sets a flag on the request
|
||||||
|
request.spiderweb = True
|
||||||
|
|
||||||
|
def process_response(
|
||||||
|
self, request: Request, response: HttpResponse
|
||||||
|
) -> HttpResponse | None:
|
||||||
|
# example of a middleware that sets a header on the resp
|
||||||
|
if hasattr(request, "spiderweb"):
|
||||||
|
response.headers["X-Spiderweb"] = "true"
|
||||||
|
|
||||||
|
|
||||||
|
class RedirectMiddleware(SpiderwebMiddleware):
|
||||||
|
def process_request(self, request: Request) -> HttpResponse | None:
|
||||||
|
if request.path == "/middleware":
|
||||||
|
return RedirectResponse("/")
|
@ -2,15 +2,17 @@
|
|||||||
# https://gist.github.com/earonesty/ab07b4c0fea2c226e75b3d538cc0dc55
|
# https://gist.github.com/earonesty/ab07b4c0fea2c226e75b3d538cc0dc55
|
||||||
#
|
#
|
||||||
# Extensively modified by @itsthejoker
|
# Extensively modified by @itsthejoker
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
|
import signal
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
from http.server import BaseHTTPRequestHandler, HTTPServer
|
from http.server import BaseHTTPRequestHandler, HTTPServer
|
||||||
import urllib.parse as urlparse
|
import urllib.parse as urlparse
|
||||||
import threading
|
import threading
|
||||||
import logging
|
import logging
|
||||||
from typing import Callable, Any
|
from typing import Callable, Any, NoReturn
|
||||||
|
|
||||||
from jinja2 import Environment, FileSystemLoader
|
from jinja2 import Environment, FileSystemLoader
|
||||||
|
|
||||||
@ -24,9 +26,12 @@ from spiderweb.exceptions import (
|
|||||||
NoResponseError,
|
NoResponseError,
|
||||||
)
|
)
|
||||||
from spiderweb.request import Request
|
from spiderweb.request import Request
|
||||||
from spiderweb.response import HttpResponse, JsonResponse, TemplateResponse
|
from spiderweb.response import HttpResponse, JsonResponse, TemplateResponse, RedirectResponse
|
||||||
|
from spiderweb.utils import import_by_string
|
||||||
|
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
|
||||||
|
|
||||||
def route(path):
|
def route(path):
|
||||||
@ -39,7 +44,86 @@ def route(path):
|
|||||||
return outer
|
return outer
|
||||||
|
|
||||||
|
|
||||||
def convert_path(path):
|
class DummyRedirectRoute:
|
||||||
|
def __init__(self, location):
|
||||||
|
self.location = location
|
||||||
|
|
||||||
|
def __call__(self, request):
|
||||||
|
return RedirectResponse(self.location)
|
||||||
|
|
||||||
|
|
||||||
|
def convert_match_to_dict(match: dict):
|
||||||
|
"""Convert a match object to a dict with the proper converted types for each match."""
|
||||||
|
return {
|
||||||
|
k.split("__")[0]: globals()[k.split("__")[1]]().to_python(v)
|
||||||
|
for k, v in match.items()
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class WebServer(HTTPServer):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
addr: str = None,
|
||||||
|
port: int = None,
|
||||||
|
custom_handler: Callable = None,
|
||||||
|
templates_dirs: list[str] = None,
|
||||||
|
middleware: list[str] = None,
|
||||||
|
append_slash: bool = False,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Create a new server on address, port. Port can be zero.
|
||||||
|
|
||||||
|
> from simple_rpc_server import WebServer, APIError, route
|
||||||
|
|
||||||
|
Create your handlers by inheriting from WebServer and tagging them with
|
||||||
|
@route("/path"). Alternately, you can use the WebServer() directly
|
||||||
|
by calling `add_handler("path", function)`.
|
||||||
|
"""
|
||||||
|
addr = addr if addr else "localhost"
|
||||||
|
port = port if port else 8000
|
||||||
|
self.append_slash = append_slash
|
||||||
|
self.templates_dirs = templates_dirs
|
||||||
|
self.middleware = middleware if middleware else []
|
||||||
|
self._thread = None
|
||||||
|
|
||||||
|
if self.middleware:
|
||||||
|
middleware_by_reference = []
|
||||||
|
for m in self.middleware:
|
||||||
|
try:
|
||||||
|
middleware_by_reference.append(import_by_string(m)())
|
||||||
|
except ImportError:
|
||||||
|
raise ConfigError(f"Middleware '{m}' not found.")
|
||||||
|
self.middleware = middleware_by_reference
|
||||||
|
|
||||||
|
if self.templates_dirs:
|
||||||
|
self.env = Environment(loader=FileSystemLoader(self.templates_dirs))
|
||||||
|
else:
|
||||||
|
self.env = None
|
||||||
|
server_address = (addr, port)
|
||||||
|
self.__addr = addr
|
||||||
|
|
||||||
|
# shim class that is an RequestHandler
|
||||||
|
class HandlerClass(RequestHandler):
|
||||||
|
pass
|
||||||
|
|
||||||
|
# inject template loader, middleware, and other important things into handler
|
||||||
|
self.handler_class = custom_handler if custom_handler else HandlerClass
|
||||||
|
self.handler_class.env = self.env
|
||||||
|
self.handler_class.middleware = self.middleware
|
||||||
|
self.handler_class.append_slash = self.append_slash
|
||||||
|
|
||||||
|
# routed methods map into handler
|
||||||
|
for method in type(self).__dict__.values():
|
||||||
|
if hasattr(method, "_routes"):
|
||||||
|
for route in method._routes:
|
||||||
|
self.add_route(route, method)
|
||||||
|
|
||||||
|
try:
|
||||||
|
super().__init__(server_address, self.handler_class)
|
||||||
|
except OSError:
|
||||||
|
raise GeneralException("Port already in use.")
|
||||||
|
|
||||||
|
def convert_path(self, path: str):
|
||||||
"""Convert a path to a regex."""
|
"""Convert a path to a regex."""
|
||||||
parts = path.split("/")
|
parts = path.split("/")
|
||||||
for i, part in enumerate(parts):
|
for i, part in enumerate(parts):
|
||||||
@ -64,71 +148,24 @@ def convert_path(path):
|
|||||||
)
|
)
|
||||||
return re.compile(r"^%s$" % "/".join(parts))
|
return re.compile(r"^%s$" % "/".join(parts))
|
||||||
|
|
||||||
|
|
||||||
def convert_match_to_dict(match: dict):
|
|
||||||
"""Convert a match object to a dict with the proper converted types for each match."""
|
|
||||||
return {
|
|
||||||
k.split("__")[0]: globals()[k.split("__")[1]]().to_python(v)
|
|
||||||
for k, v in match.items()
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class WebServer(HTTPServer):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
addr: str = None,
|
|
||||||
port: int = None,
|
|
||||||
custom_handler: Callable = None,
|
|
||||||
templates_dirs: list[str] = None,
|
|
||||||
middleware: list[str] = None,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Create a new server on address, port. Port can be zero.
|
|
||||||
|
|
||||||
> from simple_rpc_server import WebServer, APIError, route
|
|
||||||
|
|
||||||
Create your handlers by inheriting from WebServer and tagging them with
|
|
||||||
@route("/path"). Alternately, you can use the WebServer() directly
|
|
||||||
by calling `add_handler("path", function)`.
|
|
||||||
"""
|
|
||||||
addr = addr if addr else "localhost"
|
|
||||||
port = port if port else 7777
|
|
||||||
self.templates_dirs = templates_dirs
|
|
||||||
self.middleware = middleware if middleware else []
|
|
||||||
if self.templates_dirs:
|
|
||||||
self.env = Environment(loader=FileSystemLoader(self.templates_dirs))
|
|
||||||
else:
|
|
||||||
self.env = None
|
|
||||||
server_address = (addr, port)
|
|
||||||
self.__addr = addr
|
|
||||||
|
|
||||||
# shim class that is an RequestHandler
|
|
||||||
class HandlerClass(RequestHandler):
|
|
||||||
pass
|
|
||||||
|
|
||||||
self.handler_class = custom_handler if custom_handler else HandlerClass
|
|
||||||
self.handler_class.env = self.env
|
|
||||||
|
|
||||||
# routed methods map into handler
|
|
||||||
for method in type(self).__dict__.values():
|
|
||||||
if hasattr(method, "_routes"):
|
|
||||||
for route in method._routes:
|
|
||||||
self.add_route(route, method)
|
|
||||||
try:
|
|
||||||
super().__init__(server_address, self.handler_class)
|
|
||||||
except OSError:
|
|
||||||
raise GeneralException("Port already in use.")
|
|
||||||
|
|
||||||
def check_for_route_duplicates(self, path: str):
|
def check_for_route_duplicates(self, path: str):
|
||||||
if convert_path(path) in self.handler_class._routes:
|
if self.convert_path(path) in self.handler_class._routes:
|
||||||
raise ConfigError(f"Route '{path}' already exists.")
|
raise ConfigError(f"Route '{path}' already exists.")
|
||||||
|
|
||||||
def add_route(self, path: str, method: Callable):
|
def add_route(self, path: str, method: Callable):
|
||||||
"""Add a route to the server."""
|
"""Add a route to the server."""
|
||||||
if not hasattr(self.handler_class, "_routes"):
|
if not hasattr(self.handler_class, "_routes"):
|
||||||
setattr(self.handler_class, "_routes", [])
|
setattr(self.handler_class, "_routes", [])
|
||||||
|
|
||||||
|
if self.append_slash and not path.endswith("/"):
|
||||||
|
updated_path = path + "/"
|
||||||
|
self.check_for_route_duplicates(updated_path)
|
||||||
self.check_for_route_duplicates(path)
|
self.check_for_route_duplicates(path)
|
||||||
self.handler_class._routes[convert_path(path)] = method
|
self.handler_class._routes[self.convert_path(path)] = DummyRedirectRoute(updated_path)
|
||||||
|
self.handler_class._routes[self.convert_path(updated_path)] = method
|
||||||
|
else:
|
||||||
|
self.check_for_route_duplicates(path)
|
||||||
|
self.handler_class._routes[self.convert_path(path)] = method
|
||||||
|
|
||||||
def route(self, path) -> Callable:
|
def route(self, path) -> Callable:
|
||||||
"""
|
"""
|
||||||
@ -147,18 +184,17 @@ class WebServer(HTTPServer):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def outer(func):
|
def outer(func):
|
||||||
if not hasattr(self.handler_class, "_routes"):
|
self.add_route(path, func)
|
||||||
setattr(self.handler_class, "_routes", [])
|
|
||||||
self.check_for_route_duplicates(path)
|
|
||||||
self.handler_class._routes[convert_path(path)] = func
|
|
||||||
return func
|
return func
|
||||||
|
|
||||||
return outer
|
return outer
|
||||||
|
|
||||||
|
@property
|
||||||
def port(self):
|
def port(self):
|
||||||
"""Return current port."""
|
"""Return current port."""
|
||||||
return self.socket.getsockname()[1]
|
return self.socket.getsockname()[1]
|
||||||
|
|
||||||
|
@property
|
||||||
def address(self):
|
def address(self):
|
||||||
"""Return current IP address."""
|
"""Return current IP address."""
|
||||||
return self.socket.getsockname()[0]
|
return self.socket.getsockname()[0]
|
||||||
@ -168,18 +204,26 @@ class WebServer(HTTPServer):
|
|||||||
path = path if path else ""
|
path = path if path else ""
|
||||||
if path.startswith("/"):
|
if path.startswith("/"):
|
||||||
path = path[1:]
|
path = path[1:]
|
||||||
return "http://" + self.__addr + ":" + str(self.port()) + "/" + path
|
return self.__addr + ":" + str(self.port()) + "/" + path
|
||||||
|
|
||||||
|
def signal_handler(self, sig, frame) -> NoReturn:
|
||||||
|
log.warning('Shutting down!')
|
||||||
|
self.stop()
|
||||||
|
|
||||||
def start(self, blocking=False):
|
def start(self, blocking=False):
|
||||||
|
signal.signal(signal.SIGINT, self.signal_handler)
|
||||||
|
log.info(f"Starting server on {self.address}:{self.port}")
|
||||||
|
log.info("Press CTRL+C to stop the server.")
|
||||||
|
self._thread = threading.Thread(target=self.serve_forever)
|
||||||
|
self._thread.start()
|
||||||
if not blocking:
|
if not blocking:
|
||||||
threading.Thread(target=self.serve_forever).start()
|
return self._thread
|
||||||
else:
|
else:
|
||||||
|
while self._thread.is_alive():
|
||||||
try:
|
try:
|
||||||
self.serve_forever()
|
time.sleep(0.2)
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
print() # empty line after ^C
|
self.stop()
|
||||||
print("Stopping server!")
|
|
||||||
return
|
|
||||||
|
|
||||||
def stop(self):
|
def stop(self):
|
||||||
super().shutdown()
|
super().shutdown()
|
||||||
@ -190,6 +234,7 @@ class RequestHandler(BaseHTTPRequestHandler):
|
|||||||
# I can't help the naming convention of these because that's what
|
# I can't help the naming convention of these because that's what
|
||||||
# BaseHTTPRequestHandler uses for some weird reason
|
# BaseHTTPRequestHandler uses for some weird reason
|
||||||
_routes = {}
|
_routes = {}
|
||||||
|
middleware = []
|
||||||
|
|
||||||
def get_request(self):
|
def get_request(self):
|
||||||
return Request(
|
return Request(
|
||||||
@ -233,7 +278,7 @@ class RequestHandler(BaseHTTPRequestHandler):
|
|||||||
except KeyError:
|
except KeyError:
|
||||||
return http500
|
return http500
|
||||||
|
|
||||||
def fire_response(self, resp: HttpResponse):
|
def _fire_response(self, resp: HttpResponse):
|
||||||
self.send_response(resp.status_code)
|
self.send_response(resp.status_code)
|
||||||
content = resp.render()
|
content = resp.render()
|
||||||
self.send_header("Content-Length", str(len(content)))
|
self.send_header("Content-Length", str(len(content)))
|
||||||
@ -243,8 +288,53 @@ class RequestHandler(BaseHTTPRequestHandler):
|
|||||||
self.end_headers()
|
self.end_headers()
|
||||||
self.wfile.write(bytes(content, "utf-8"))
|
self.wfile.write(bytes(content, "utf-8"))
|
||||||
|
|
||||||
def handle_request(self, request):
|
def fire_response(self, request: Request, resp: HttpResponse):
|
||||||
try:
|
try:
|
||||||
|
self._fire_response(resp)
|
||||||
|
except APIError:
|
||||||
|
raise
|
||||||
|
except ConnectionAbortedError as e:
|
||||||
|
log.error(f"GET {self.path} : {e}")
|
||||||
|
except Exception:
|
||||||
|
log.error(traceback.format_exc())
|
||||||
|
self.fire_response(request, self.get_error_route(500)(request))
|
||||||
|
|
||||||
|
def process_request_middleware(self, request: Request) -> None | bool:
|
||||||
|
for middleware in self.middleware:
|
||||||
|
resp = middleware.process_request(request)
|
||||||
|
if resp:
|
||||||
|
self.process_response_middleware(request, resp)
|
||||||
|
self.fire_response(request, resp)
|
||||||
|
return True # abort further processing
|
||||||
|
|
||||||
|
def process_response_middleware(self, request: Request, response: HttpResponse) -> None:
|
||||||
|
for middleware in self.middleware:
|
||||||
|
middleware.process_response(request, response)
|
||||||
|
|
||||||
|
def prepare_response(self, request, resp) -> HttpResponse:
|
||||||
|
try:
|
||||||
|
if isinstance(resp, dict):
|
||||||
|
self.fire_response(JsonResponse(data=resp))
|
||||||
|
if isinstance(resp, TemplateResponse):
|
||||||
|
if hasattr(self, "env"): # injected from above
|
||||||
|
resp.set_template_loader(self.env)
|
||||||
|
|
||||||
|
for middleware in self.middleware:
|
||||||
|
middleware.process_response(request, resp)
|
||||||
|
|
||||||
|
self.fire_response(resp)
|
||||||
|
|
||||||
|
except APIError:
|
||||||
|
raise
|
||||||
|
|
||||||
|
except ConnectionAbortedError as e:
|
||||||
|
log.error(f"GET {self.path} : {e}")
|
||||||
|
except Exception:
|
||||||
|
log.error(traceback.format_exc())
|
||||||
|
self.fire_response(request, self.get_error_route(500)(request))
|
||||||
|
|
||||||
|
def handle_request(self, request):
|
||||||
|
|
||||||
request.url = urlparse.urlparse(request.path)
|
request.url = urlparse.urlparse(request.path)
|
||||||
|
|
||||||
handler, additional_args = self.get_route(request.url.path)
|
handler, additional_args = self.get_route(request.url.path)
|
||||||
@ -255,26 +345,24 @@ class RequestHandler(BaseHTTPRequestHandler):
|
|||||||
params = {}
|
params = {}
|
||||||
|
|
||||||
request.query_params = params
|
request.query_params = params
|
||||||
|
|
||||||
if handler:
|
|
||||||
try:
|
try:
|
||||||
|
if handler:
|
||||||
|
# middleware is injected from WebServer
|
||||||
|
abort = self.process_request_middleware(request)
|
||||||
|
if abort:
|
||||||
|
return
|
||||||
|
|
||||||
resp = handler(request, **additional_args)
|
resp = handler(request, **additional_args)
|
||||||
if resp is None:
|
if resp is None:
|
||||||
raise NoResponseError(f"View {handler} returned None.")
|
raise NoResponseError(f"View {handler} returned None.")
|
||||||
if isinstance(resp, dict):
|
if isinstance(resp, dict):
|
||||||
self.fire_response(JsonResponse(data=resp))
|
self.fire_response(request, JsonResponse(data=resp))
|
||||||
if isinstance(resp, TemplateResponse):
|
if isinstance(resp, TemplateResponse):
|
||||||
if hasattr(self, "env"): # injected from above
|
if hasattr(self, "env"): # injected from above
|
||||||
resp.set_template_loader(self.env)
|
resp.set_template_loader(self.env)
|
||||||
self.fire_response(resp)
|
|
||||||
except APIError:
|
|
||||||
raise
|
|
||||||
except ConnectionAbortedError as e:
|
|
||||||
log.error(f"GET {self.path} : {e}")
|
|
||||||
except Exception:
|
|
||||||
log.error(traceback.format_exc())
|
|
||||||
self.fire_response(self.get_error_route(500)(request))
|
|
||||||
|
|
||||||
|
self.process_response_middleware(request, resp)
|
||||||
|
self.fire_response(request, resp)
|
||||||
else:
|
else:
|
||||||
raise APIError(404)
|
raise APIError(404)
|
||||||
except APIError as e:
|
except APIError as e:
|
||||||
|
@ -16,17 +16,12 @@ class SpiderwebMiddleware:
|
|||||||
|
|
||||||
If `process_request` returns a HttpResponse, the request will be short-circuited
|
If `process_request` returns a HttpResponse, the request will be short-circuited
|
||||||
and the response will be returned immediately. `process_response` will not be called.
|
and the response will be returned immediately. `process_response` will not be called.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def process_request(self, request: Request) -> HttpResponse | None:
|
def process_request(self, request: Request) -> HttpResponse | None:
|
||||||
# example of a middleware that sets a flag on the request
|
pass
|
||||||
request.spiderweb = True
|
|
||||||
|
|
||||||
def process_response(
|
def process_response(
|
||||||
self, request: Request, response: HttpResponse
|
self, request: Request, response: HttpResponse
|
||||||
) -> HttpResponse | None:
|
) -> HttpResponse | None:
|
||||||
# example of a middleware that sets a header on the resp
|
pass
|
||||||
if hasattr(request, "spiderweb"):
|
|
||||||
response.headers["X-Spiderweb"] = "true"
|
|
||||||
return response
|
|
||||||
|
7
spiderweb/utils.py
Normal file
7
spiderweb/utils.py
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
def import_by_string(name):
|
||||||
|
# https://stackoverflow.com/a/547867
|
||||||
|
components = name.split('.')
|
||||||
|
mod = __import__(components[0])
|
||||||
|
for comp in components[1:]:
|
||||||
|
mod = getattr(mod, comp)
|
||||||
|
return mod
|
@ -1,5 +1,10 @@
|
|||||||
<h1>FART</h1>
|
<h1>HI, THIS IS A PAGE</h1>
|
||||||
|
|
||||||
<p>
|
<p>
|
||||||
This is a test of the {{ value }} template.
|
This is a test of the template rendering system. If rendering is working, this value
|
||||||
|
should be <code>TEST</code>: {{ value }}.
|
||||||
|
</p>
|
||||||
|
<p>
|
||||||
|
The value of <code>request.spiderweb</code> is {{ request.spiderweb }}. If this is True,
|
||||||
|
middleware is working.
|
||||||
</p>
|
</p>
|
Loading…
Reference in New Issue
Block a user