🔥 remove old main.py

This commit is contained in:
Joe Kaufeld 2024-08-18 17:48:10 -04:00
parent 675743bf8d
commit 6c636ffd2e

View File

@ -1,497 +0,0 @@
# Started life from
# https://gist.github.com/earonesty/ab07b4c0fea2c226e75b3d538cc0dc55
#
# Extensively modified by @itsthejoker
import inspect
import os
import re
import pathlib
import signal
import time
import traceback
from http.server import BaseHTTPRequestHandler, HTTPServer
import urllib.parse as urlparse
import threading
import logging
from typing import Callable, Any, NoReturn
from wsgiref.simple_server import WSGIRequestHandler, WSGIServer
from cryptography.fernet import Fernet
from jinja2 import Environment, FileSystemLoader
from spiderweb.constants import DEFAULT_ENCODING, DEFAULT_ALLOWED_METHODS
from spiderweb.converters import * # noqa: F403
from spiderweb.default_responses import * # noqa: F403
from spiderweb.exceptions import (
APIError,
ConfigError,
ParseError,
GeneralException,
NoResponseError,
UnusedMiddleware,
SpiderwebNetworkException,
NotFound,
)
from spiderweb.request import Request
from spiderweb.response import (
HttpResponse,
JsonResponse,
TemplateResponse,
RedirectResponse,
FileResponse,
)
from spiderweb.utils import import_by_string, is_safe_path
log = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
def route(path):
def outer(func):
if not hasattr(func, "_routes"):
setattr(func, "_routes", [])
func._routes += [path]
return func
return outer
def send_file(request, filename: str) -> HttpResponse:
for folder in request.server.staticfiles_dirs:
requested_path = request.server.BASE_DIR / folder / filename
if os.path.exists(requested_path):
if not is_safe_path(requested_path):
raise NotFound
return FileResponse(filename=requested_path)
raise NotFound
class DummyRedirectRoute:
def __init__(self, location):
self.location = location
def __call__(self, request):
return RedirectResponse(self.location)
def convert_match_to_dict(match: dict):
"""Convert a match object to a dict with the proper converted types for each match."""
return {
k.split("__")[0]: globals()[k.split("__")[1]]().to_python(v)
for k, v in match.items()
}
class WebServer(HTTPServer):
def __init__(
self,
addr: str = None,
port: int = None,
custom_handler: Callable = None,
templates_dirs: list[str] = None,
middleware: list[str] = None,
append_slash: bool = False,
staticfiles_dirs: list[str] = None,
secret_key: str = None,
):
"""
Create a new server on address, port. Port can be zero.
> from simple_rpc_server import WebServer, APIError, route
Create your handlers by inheriting from WebServer and tagging them with
@route("/path"). Alternately, you can use the WebServer() directly
by calling `add_handler("path", function)`.
"""
addr = addr if addr else "localhost"
port = port if port else 8000
self.append_slash = append_slash
self.templates_dirs = templates_dirs
self.staticfiles_dirs = staticfiles_dirs
self.middleware = middleware if middleware else []
self.secret_key = secret_key if secret_key else self._create_secret_key()
self.fernet = Fernet(self.key)
self.DEFAULT_ENCODING = DEFAULT_ENCODING
self.DEFAULT_ALLOWED_METHODS = DEFAULT_ALLOWED_METHODS
self._thread = None
self.BASE_DIR = self.get_caller_filepath()
if self.middleware:
middleware_by_reference = []
for m in self.middleware:
try:
middleware_by_reference.append(import_by_string(m)(server=self))
except ImportError:
raise ConfigError(f"Middleware '{m}' not found.")
self.middleware = middleware_by_reference
if self.templates_dirs:
self.env = Environment(loader=FileSystemLoader(self.templates_dirs))
else:
self.env = None
server_address = (addr, port)
self.__addr = addr
# shim class that is an RequestHandler
class HandlerClass(RequestHandler):
pass
self.handler_class = custom_handler if custom_handler else HandlerClass
self.handler_class.server = self
# routed methods map into handler
for method in type(self).__dict__.values():
if hasattr(method, "_routes"):
for route in method._routes:
self.add_route(route, method)
if self.staticfiles_dirs:
for static_dir in self.staticfiles_dirs:
static_dir = pathlib.Path(static_dir)
if not pathlib.Path(self.BASE_DIR / static_dir).exists():
log.error(
f"Static files directory '{str(static_dir)}' does not exist."
)
raise ConfigError
self.add_route(r"/static/<str:filename>", send_file)
try:
super().__init__(server_address, self.handler_class)
except OSError:
raise GeneralException("Port already in use.")
def get_caller_filepath(self):
"""Figure out who called us and return their path."""
stack = inspect.stack()
caller_frame = stack[1]
return pathlib.Path(caller_frame.filename).parent.parent
def convert_path(self, path: str):
"""Convert a path to a regex."""
parts = path.split("/")
for i, part in enumerate(parts):
if part.startswith("<") and part.endswith(">"):
name = part[1:-1]
if "__" in name:
raise ConfigError(
f"Cannot use `__` (double underscore) in path variable."
f" Please fix '{name}'."
)
if ":" in name:
converter, name = name.split(":")
try:
converter = globals()[converter.title() + "Converter"]
except KeyError:
raise ParseError(f"Unknown converter {converter}")
else:
converter = StrConverter # noqa: F405
parts[i] = rf"(?P<{name}__{str(converter.__name__)}>{converter.regex})"
return re.compile(rf"^{'/'.join(parts)}$")
def check_for_route_duplicates(self, path: str):
if self.convert_path(path) in self.handler_class._routes:
raise ConfigError(f"Route '{path}' already exists.")
def add_route(
self, path: str, method: Callable, allowed_methods: None | list[str] = None
):
"""Add a route to the server."""
if not hasattr(self.handler_class, "_routes"):
setattr(self.handler_class, "_routes", {})
allowed_methods = (
allowed_methods if allowed_methods else DEFAULT_ALLOWED_METHODS
)
if self.append_slash and not path.endswith("/"):
updated_path = path + "/"
self.check_for_route_duplicates(updated_path)
self.check_for_route_duplicates(path)
self.handler_class._routes[self.convert_path(path)] = {
"func": DummyRedirectRoute(updated_path),
"allowed_methods": allowed_methods,
}
self.handler_class._routes[self.convert_path(updated_path)] = {
"func": method,
"allowed_methods": allowed_methods,
}
else:
self.check_for_route_duplicates(path)
self.handler_class._routes[self.convert_path(path)] = {
"func": method,
"allowed_methods": allowed_methods,
}
def add_error_route(self, code: int, method: Callable):
"""Add an error route to the server."""
if not hasattr(self.handler_class, "_error_routes"):
setattr(self.handler_class, "_error_routes", {})
if code not in self.handler_class._error_routes:
self.handler_class._error_routes[code] = method
else:
raise ConfigError(f"Error route for code {code} already exists.")
def route(self, path, allowed_methods=None) -> Callable:
"""
Decorator for adding a route to a view.
Usage:
app = WebServer()
@app.route("/hello")
def index(request):
return HttpResponse(content="Hello, world!")
:param path: str
:param allowed_methods: list[str]
:return: Callable
"""
def outer(func):
self.add_route(
path,
func,
allowed_methods if allowed_methods else DEFAULT_ALLOWED_METHODS,
)
return func
return outer
def error(self, code: int) -> Callable:
def outer(func):
self.add_error_route(code, func)
return func
return outer
@property
def port(self):
"""Return current port."""
return self.socket.getsockname()[1]
@property
def address(self):
"""Return current IP address."""
return self.socket.getsockname()[0]
def uri(self, path=None):
"""Make a URI pointing at myself."""
path = path if path else ""
if path.startswith("/"):
path = path[1:]
return self.__addr + ":" + str(self.port()) + "/" + path
def signal_handler(self, sig, frame) -> NoReturn:
log.warning("Shutting down!")
self.stop()
def start(self, blocking=False):
signal.signal(signal.SIGINT, self.signal_handler)
log.info(f"Starting server on {self.address}:{self.port}")
log.info("Press CTRL+C to stop the server.")
self._thread = threading.Thread(target=self.serve_forever)
self._thread.start()
if not blocking:
return self._thread
else:
while self._thread.is_alive():
try:
time.sleep(0.2)
except KeyboardInterrupt:
self.stop()
def stop(self):
super().shutdown()
self.socket.close()
def _create_secret_key(self):
self.key = Fernet.generate_key()
def encrypt(self, data: str):
return self.fernet.encrypt(bytes(data, DEFAULT_ENCODING))
def decrypt(self, data: str):
if isinstance(data, bytes):
return self.fernet.decrypt(data).decode(DEFAULT_ENCODING)
return self.fernet.decrypt(bytes(data, DEFAULT_ENCODING)).decode(
DEFAULT_ENCODING
)
class RequestHandler(BaseHTTPRequestHandler):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# These stop pycharm from complaining about these not existing. They're
# injected by the WebServer class at runtime
self._routes = {}
self._error_routes = {}
self.server = None
def get_request(self):
return Request(
content="",
body="",
method=self.command,
headers=self.headers,
path=self.path,
server=self.server,
)
# I can't help the naming convention of these because that's what
# BaseHTTPRequestHandler uses for some weird reason
def do_GET(self):
request = self.get_request()
request.method = "GET"
self.handle_request(request)
def do_POST(self):
content = "{}"
if self.headers["Content-Length"]:
length = int(self.headers["Content-Length"])
content = self.rfile.read(length)
request = self.get_request()
request.method = "POST"
request.content = content
self.handle_request(request)
def get_route(self, path) -> tuple[Callable, dict[str, Any], list[str]]:
for option in self._routes.keys():
if match_data := option.match(path):
return (
self._routes[option]["func"],
convert_match_to_dict(match_data.groupdict()),
self._routes[option]["allowed_methods"],
)
raise NotFound()
def get_error_route(self, code: int) -> Callable:
view = self._error_routes.get(code) or globals().get(f"http{code}")
if not view:
return http500
return view
def _fire_response(
self,
status: int = 200,
content: str = None,
headers: dict[str, str | int] = None,
):
self.send_response(status)
self.send_header("Content-Length", str(len(content)))
if headers:
for key, value in headers.items():
self.send_header(key, value)
self.end_headers()
self.wfile.write(bytes(content, DEFAULT_ENCODING))
def fire_response(self, request: Request, resp: HttpResponse):
try:
self._fire_response(
status=resp.status_code, content=resp.render(), headers=resp.headers
)
except APIError:
raise
except ConnectionAbortedError as e:
log.error(f"GET {self.path} : {e}")
except Exception:
log.error(traceback.format_exc())
self.fire_response(request, self.get_error_route(500)(request))
def process_request_middleware(self, request: Request) -> None | bool:
for middleware in self.server.middleware:
try:
resp = middleware.process_request(request)
except UnusedMiddleware:
self.server.middleware.remove(middleware)
continue
if resp:
self.process_response_middleware(request, resp)
self.fire_response(request, resp)
return True # abort further processing
def process_response_middleware(
self, request: Request, response: HttpResponse
) -> None:
for middleware in self.server.middleware:
try:
middleware.process_response(request, response)
except UnusedMiddleware:
self.server.middleware.remove(middleware)
continue
def prepare_and_fire_response(self, request, resp) -> None:
try:
if isinstance(resp, dict):
self.fire_response(request, JsonResponse(data=resp))
if isinstance(resp, TemplateResponse):
if hasattr(self.server, "env"):
resp.set_template_loader(self.server.env)
for middleware in self.server.middleware:
middleware.process_response(request, resp)
self.fire_response(request, resp)
except APIError:
raise
except Exception:
log.error(traceback.format_exc())
self.fire_response(request, self.get_error_route(500)(request))
def is_form_request(self, request: Request) -> bool:
return (
"Content-Type" in request.headers
and request.headers["Content-Type"] == "application/x-www-form-urlencoded"
)
def send_error_response(self, request: Request, e: SpiderwebNetworkException):
try:
self.send_error(e.code, e.msg, e.desc)
except ConnectionAbortedError as e:
log.error(f"{request.method} {self.path} : {e}")
def handle_request(self, request):
try:
handler, additional_args, allowed_methods = self.get_route(request.url.path)
except NotFound:
handler = self.get_error_route(404)
additional_args = {}
allowed_methods = DEFAULT_ALLOWED_METHODS
if request.method not in allowed_methods:
# replace the potentially valid handler with the error route
handler = self.get_error_route(405)
request.query_params = (
urlparse.parse_qs(request.url.query) if request.url.query else {}
)
if self.is_form_request(request):
formdata = urlparse.parse_qs(request.content.decode("utf-8"))
for key, value in formdata.items():
if len(value) == 1:
formdata[key] = value[0]
setattr(request, request.method, formdata)
try:
if handler:
# middleware is injected from WebServer
abort = self.process_request_middleware(request)
if abort:
return
resp = handler(request, **additional_args)
if resp is None:
raise NoResponseError(f"View {handler} returned None.")
# run the response through the middleware and send it
self.prepare_and_fire_response(request, resp)
else:
raise SpiderwebNetworkException(404)
except SpiderwebNetworkException as e:
self.send_error_response(request, e)