🔥 remove old main.py
This commit is contained in:
parent
675743bf8d
commit
6c636ffd2e
@ -1,497 +0,0 @@
|
|||||||
# Started life from
|
|
||||||
# https://gist.github.com/earonesty/ab07b4c0fea2c226e75b3d538cc0dc55
|
|
||||||
#
|
|
||||||
# Extensively modified by @itsthejoker
|
|
||||||
import inspect
|
|
||||||
import os
|
|
||||||
import re
|
|
||||||
import pathlib
|
|
||||||
import signal
|
|
||||||
import time
|
|
||||||
import traceback
|
|
||||||
from http.server import BaseHTTPRequestHandler, HTTPServer
|
|
||||||
import urllib.parse as urlparse
|
|
||||||
import threading
|
|
||||||
import logging
|
|
||||||
from typing import Callable, Any, NoReturn
|
|
||||||
from wsgiref.simple_server import WSGIRequestHandler, WSGIServer
|
|
||||||
|
|
||||||
from cryptography.fernet import Fernet
|
|
||||||
from jinja2 import Environment, FileSystemLoader
|
|
||||||
|
|
||||||
from spiderweb.constants import DEFAULT_ENCODING, DEFAULT_ALLOWED_METHODS
|
|
||||||
from spiderweb.converters import * # noqa: F403
|
|
||||||
from spiderweb.default_responses import * # noqa: F403
|
|
||||||
from spiderweb.exceptions import (
|
|
||||||
APIError,
|
|
||||||
ConfigError,
|
|
||||||
ParseError,
|
|
||||||
GeneralException,
|
|
||||||
NoResponseError,
|
|
||||||
UnusedMiddleware,
|
|
||||||
SpiderwebNetworkException,
|
|
||||||
NotFound,
|
|
||||||
)
|
|
||||||
from spiderweb.request import Request
|
|
||||||
from spiderweb.response import (
|
|
||||||
HttpResponse,
|
|
||||||
JsonResponse,
|
|
||||||
TemplateResponse,
|
|
||||||
RedirectResponse,
|
|
||||||
FileResponse,
|
|
||||||
)
|
|
||||||
from spiderweb.utils import import_by_string, is_safe_path
|
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
|
||||||
|
|
||||||
|
|
||||||
def route(path):
|
|
||||||
def outer(func):
|
|
||||||
if not hasattr(func, "_routes"):
|
|
||||||
setattr(func, "_routes", [])
|
|
||||||
func._routes += [path]
|
|
||||||
return func
|
|
||||||
|
|
||||||
return outer
|
|
||||||
|
|
||||||
|
|
||||||
def send_file(request, filename: str) -> HttpResponse:
|
|
||||||
for folder in request.server.staticfiles_dirs:
|
|
||||||
requested_path = request.server.BASE_DIR / folder / filename
|
|
||||||
if os.path.exists(requested_path):
|
|
||||||
if not is_safe_path(requested_path):
|
|
||||||
raise NotFound
|
|
||||||
return FileResponse(filename=requested_path)
|
|
||||||
raise NotFound
|
|
||||||
|
|
||||||
|
|
||||||
class DummyRedirectRoute:
|
|
||||||
def __init__(self, location):
|
|
||||||
self.location = location
|
|
||||||
|
|
||||||
def __call__(self, request):
|
|
||||||
return RedirectResponse(self.location)
|
|
||||||
|
|
||||||
|
|
||||||
def convert_match_to_dict(match: dict):
|
|
||||||
"""Convert a match object to a dict with the proper converted types for each match."""
|
|
||||||
return {
|
|
||||||
k.split("__")[0]: globals()[k.split("__")[1]]().to_python(v)
|
|
||||||
for k, v in match.items()
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class WebServer(HTTPServer):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
addr: str = None,
|
|
||||||
port: int = None,
|
|
||||||
custom_handler: Callable = None,
|
|
||||||
templates_dirs: list[str] = None,
|
|
||||||
middleware: list[str] = None,
|
|
||||||
append_slash: bool = False,
|
|
||||||
staticfiles_dirs: list[str] = None,
|
|
||||||
secret_key: str = None,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Create a new server on address, port. Port can be zero.
|
|
||||||
|
|
||||||
> from simple_rpc_server import WebServer, APIError, route
|
|
||||||
|
|
||||||
Create your handlers by inheriting from WebServer and tagging them with
|
|
||||||
@route("/path"). Alternately, you can use the WebServer() directly
|
|
||||||
by calling `add_handler("path", function)`.
|
|
||||||
"""
|
|
||||||
addr = addr if addr else "localhost"
|
|
||||||
port = port if port else 8000
|
|
||||||
self.append_slash = append_slash
|
|
||||||
self.templates_dirs = templates_dirs
|
|
||||||
self.staticfiles_dirs = staticfiles_dirs
|
|
||||||
self.middleware = middleware if middleware else []
|
|
||||||
self.secret_key = secret_key if secret_key else self._create_secret_key()
|
|
||||||
self.fernet = Fernet(self.key)
|
|
||||||
self.DEFAULT_ENCODING = DEFAULT_ENCODING
|
|
||||||
self.DEFAULT_ALLOWED_METHODS = DEFAULT_ALLOWED_METHODS
|
|
||||||
self._thread = None
|
|
||||||
|
|
||||||
self.BASE_DIR = self.get_caller_filepath()
|
|
||||||
|
|
||||||
if self.middleware:
|
|
||||||
middleware_by_reference = []
|
|
||||||
for m in self.middleware:
|
|
||||||
try:
|
|
||||||
middleware_by_reference.append(import_by_string(m)(server=self))
|
|
||||||
except ImportError:
|
|
||||||
raise ConfigError(f"Middleware '{m}' not found.")
|
|
||||||
self.middleware = middleware_by_reference
|
|
||||||
|
|
||||||
if self.templates_dirs:
|
|
||||||
self.env = Environment(loader=FileSystemLoader(self.templates_dirs))
|
|
||||||
else:
|
|
||||||
self.env = None
|
|
||||||
|
|
||||||
server_address = (addr, port)
|
|
||||||
self.__addr = addr
|
|
||||||
|
|
||||||
# shim class that is an RequestHandler
|
|
||||||
class HandlerClass(RequestHandler):
|
|
||||||
pass
|
|
||||||
|
|
||||||
self.handler_class = custom_handler if custom_handler else HandlerClass
|
|
||||||
self.handler_class.server = self
|
|
||||||
|
|
||||||
# routed methods map into handler
|
|
||||||
for method in type(self).__dict__.values():
|
|
||||||
if hasattr(method, "_routes"):
|
|
||||||
for route in method._routes:
|
|
||||||
self.add_route(route, method)
|
|
||||||
|
|
||||||
if self.staticfiles_dirs:
|
|
||||||
for static_dir in self.staticfiles_dirs:
|
|
||||||
static_dir = pathlib.Path(static_dir)
|
|
||||||
if not pathlib.Path(self.BASE_DIR / static_dir).exists():
|
|
||||||
log.error(
|
|
||||||
f"Static files directory '{str(static_dir)}' does not exist."
|
|
||||||
)
|
|
||||||
raise ConfigError
|
|
||||||
self.add_route(r"/static/<str:filename>", send_file)
|
|
||||||
|
|
||||||
try:
|
|
||||||
super().__init__(server_address, self.handler_class)
|
|
||||||
except OSError:
|
|
||||||
raise GeneralException("Port already in use.")
|
|
||||||
|
|
||||||
def get_caller_filepath(self):
|
|
||||||
"""Figure out who called us and return their path."""
|
|
||||||
stack = inspect.stack()
|
|
||||||
caller_frame = stack[1]
|
|
||||||
return pathlib.Path(caller_frame.filename).parent.parent
|
|
||||||
|
|
||||||
def convert_path(self, path: str):
|
|
||||||
"""Convert a path to a regex."""
|
|
||||||
parts = path.split("/")
|
|
||||||
for i, part in enumerate(parts):
|
|
||||||
if part.startswith("<") and part.endswith(">"):
|
|
||||||
name = part[1:-1]
|
|
||||||
if "__" in name:
|
|
||||||
raise ConfigError(
|
|
||||||
f"Cannot use `__` (double underscore) in path variable."
|
|
||||||
f" Please fix '{name}'."
|
|
||||||
)
|
|
||||||
if ":" in name:
|
|
||||||
converter, name = name.split(":")
|
|
||||||
try:
|
|
||||||
converter = globals()[converter.title() + "Converter"]
|
|
||||||
except KeyError:
|
|
||||||
raise ParseError(f"Unknown converter {converter}")
|
|
||||||
else:
|
|
||||||
converter = StrConverter # noqa: F405
|
|
||||||
parts[i] = rf"(?P<{name}__{str(converter.__name__)}>{converter.regex})"
|
|
||||||
return re.compile(rf"^{'/'.join(parts)}$")
|
|
||||||
|
|
||||||
def check_for_route_duplicates(self, path: str):
|
|
||||||
if self.convert_path(path) in self.handler_class._routes:
|
|
||||||
raise ConfigError(f"Route '{path}' already exists.")
|
|
||||||
|
|
||||||
def add_route(
|
|
||||||
self, path: str, method: Callable, allowed_methods: None | list[str] = None
|
|
||||||
):
|
|
||||||
"""Add a route to the server."""
|
|
||||||
if not hasattr(self.handler_class, "_routes"):
|
|
||||||
setattr(self.handler_class, "_routes", {})
|
|
||||||
|
|
||||||
allowed_methods = (
|
|
||||||
allowed_methods if allowed_methods else DEFAULT_ALLOWED_METHODS
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.append_slash and not path.endswith("/"):
|
|
||||||
updated_path = path + "/"
|
|
||||||
self.check_for_route_duplicates(updated_path)
|
|
||||||
self.check_for_route_duplicates(path)
|
|
||||||
self.handler_class._routes[self.convert_path(path)] = {
|
|
||||||
"func": DummyRedirectRoute(updated_path),
|
|
||||||
"allowed_methods": allowed_methods,
|
|
||||||
}
|
|
||||||
self.handler_class._routes[self.convert_path(updated_path)] = {
|
|
||||||
"func": method,
|
|
||||||
"allowed_methods": allowed_methods,
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
self.check_for_route_duplicates(path)
|
|
||||||
self.handler_class._routes[self.convert_path(path)] = {
|
|
||||||
"func": method,
|
|
||||||
"allowed_methods": allowed_methods,
|
|
||||||
}
|
|
||||||
|
|
||||||
def add_error_route(self, code: int, method: Callable):
|
|
||||||
"""Add an error route to the server."""
|
|
||||||
if not hasattr(self.handler_class, "_error_routes"):
|
|
||||||
setattr(self.handler_class, "_error_routes", {})
|
|
||||||
|
|
||||||
if code not in self.handler_class._error_routes:
|
|
||||||
self.handler_class._error_routes[code] = method
|
|
||||||
else:
|
|
||||||
raise ConfigError(f"Error route for code {code} already exists.")
|
|
||||||
|
|
||||||
def route(self, path, allowed_methods=None) -> Callable:
|
|
||||||
"""
|
|
||||||
Decorator for adding a route to a view.
|
|
||||||
|
|
||||||
Usage:
|
|
||||||
|
|
||||||
app = WebServer()
|
|
||||||
|
|
||||||
@app.route("/hello")
|
|
||||||
def index(request):
|
|
||||||
return HttpResponse(content="Hello, world!")
|
|
||||||
|
|
||||||
:param path: str
|
|
||||||
:param allowed_methods: list[str]
|
|
||||||
:return: Callable
|
|
||||||
"""
|
|
||||||
|
|
||||||
def outer(func):
|
|
||||||
self.add_route(
|
|
||||||
path,
|
|
||||||
func,
|
|
||||||
allowed_methods if allowed_methods else DEFAULT_ALLOWED_METHODS,
|
|
||||||
)
|
|
||||||
return func
|
|
||||||
|
|
||||||
return outer
|
|
||||||
|
|
||||||
def error(self, code: int) -> Callable:
|
|
||||||
def outer(func):
|
|
||||||
self.add_error_route(code, func)
|
|
||||||
return func
|
|
||||||
|
|
||||||
return outer
|
|
||||||
|
|
||||||
@property
|
|
||||||
def port(self):
|
|
||||||
"""Return current port."""
|
|
||||||
return self.socket.getsockname()[1]
|
|
||||||
|
|
||||||
@property
|
|
||||||
def address(self):
|
|
||||||
"""Return current IP address."""
|
|
||||||
return self.socket.getsockname()[0]
|
|
||||||
|
|
||||||
def uri(self, path=None):
|
|
||||||
"""Make a URI pointing at myself."""
|
|
||||||
path = path if path else ""
|
|
||||||
if path.startswith("/"):
|
|
||||||
path = path[1:]
|
|
||||||
return self.__addr + ":" + str(self.port()) + "/" + path
|
|
||||||
|
|
||||||
def signal_handler(self, sig, frame) -> NoReturn:
|
|
||||||
log.warning("Shutting down!")
|
|
||||||
self.stop()
|
|
||||||
|
|
||||||
def start(self, blocking=False):
|
|
||||||
signal.signal(signal.SIGINT, self.signal_handler)
|
|
||||||
log.info(f"Starting server on {self.address}:{self.port}")
|
|
||||||
log.info("Press CTRL+C to stop the server.")
|
|
||||||
self._thread = threading.Thread(target=self.serve_forever)
|
|
||||||
self._thread.start()
|
|
||||||
if not blocking:
|
|
||||||
return self._thread
|
|
||||||
else:
|
|
||||||
while self._thread.is_alive():
|
|
||||||
try:
|
|
||||||
time.sleep(0.2)
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
self.stop()
|
|
||||||
|
|
||||||
def stop(self):
|
|
||||||
super().shutdown()
|
|
||||||
self.socket.close()
|
|
||||||
|
|
||||||
def _create_secret_key(self):
|
|
||||||
self.key = Fernet.generate_key()
|
|
||||||
|
|
||||||
def encrypt(self, data: str):
|
|
||||||
return self.fernet.encrypt(bytes(data, DEFAULT_ENCODING))
|
|
||||||
|
|
||||||
def decrypt(self, data: str):
|
|
||||||
if isinstance(data, bytes):
|
|
||||||
return self.fernet.decrypt(data).decode(DEFAULT_ENCODING)
|
|
||||||
return self.fernet.decrypt(bytes(data, DEFAULT_ENCODING)).decode(
|
|
||||||
DEFAULT_ENCODING
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class RequestHandler(BaseHTTPRequestHandler):
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
# These stop pycharm from complaining about these not existing. They're
|
|
||||||
# injected by the WebServer class at runtime
|
|
||||||
self._routes = {}
|
|
||||||
self._error_routes = {}
|
|
||||||
self.server = None
|
|
||||||
|
|
||||||
def get_request(self):
|
|
||||||
return Request(
|
|
||||||
content="",
|
|
||||||
body="",
|
|
||||||
method=self.command,
|
|
||||||
headers=self.headers,
|
|
||||||
path=self.path,
|
|
||||||
server=self.server,
|
|
||||||
)
|
|
||||||
|
|
||||||
# I can't help the naming convention of these because that's what
|
|
||||||
# BaseHTTPRequestHandler uses for some weird reason
|
|
||||||
def do_GET(self):
|
|
||||||
request = self.get_request()
|
|
||||||
request.method = "GET"
|
|
||||||
self.handle_request(request)
|
|
||||||
|
|
||||||
def do_POST(self):
|
|
||||||
content = "{}"
|
|
||||||
if self.headers["Content-Length"]:
|
|
||||||
length = int(self.headers["Content-Length"])
|
|
||||||
content = self.rfile.read(length)
|
|
||||||
request = self.get_request()
|
|
||||||
request.method = "POST"
|
|
||||||
request.content = content
|
|
||||||
self.handle_request(request)
|
|
||||||
|
|
||||||
def get_route(self, path) -> tuple[Callable, dict[str, Any], list[str]]:
|
|
||||||
for option in self._routes.keys():
|
|
||||||
if match_data := option.match(path):
|
|
||||||
return (
|
|
||||||
self._routes[option]["func"],
|
|
||||||
convert_match_to_dict(match_data.groupdict()),
|
|
||||||
self._routes[option]["allowed_methods"],
|
|
||||||
)
|
|
||||||
raise NotFound()
|
|
||||||
|
|
||||||
def get_error_route(self, code: int) -> Callable:
|
|
||||||
view = self._error_routes.get(code) or globals().get(f"http{code}")
|
|
||||||
if not view:
|
|
||||||
return http500
|
|
||||||
return view
|
|
||||||
|
|
||||||
def _fire_response(
|
|
||||||
self,
|
|
||||||
status: int = 200,
|
|
||||||
content: str = None,
|
|
||||||
headers: dict[str, str | int] = None,
|
|
||||||
):
|
|
||||||
self.send_response(status)
|
|
||||||
self.send_header("Content-Length", str(len(content)))
|
|
||||||
if headers:
|
|
||||||
for key, value in headers.items():
|
|
||||||
self.send_header(key, value)
|
|
||||||
self.end_headers()
|
|
||||||
self.wfile.write(bytes(content, DEFAULT_ENCODING))
|
|
||||||
|
|
||||||
def fire_response(self, request: Request, resp: HttpResponse):
|
|
||||||
try:
|
|
||||||
self._fire_response(
|
|
||||||
status=resp.status_code, content=resp.render(), headers=resp.headers
|
|
||||||
)
|
|
||||||
except APIError:
|
|
||||||
raise
|
|
||||||
except ConnectionAbortedError as e:
|
|
||||||
log.error(f"GET {self.path} : {e}")
|
|
||||||
except Exception:
|
|
||||||
log.error(traceback.format_exc())
|
|
||||||
self.fire_response(request, self.get_error_route(500)(request))
|
|
||||||
|
|
||||||
def process_request_middleware(self, request: Request) -> None | bool:
|
|
||||||
for middleware in self.server.middleware:
|
|
||||||
try:
|
|
||||||
resp = middleware.process_request(request)
|
|
||||||
except UnusedMiddleware:
|
|
||||||
self.server.middleware.remove(middleware)
|
|
||||||
continue
|
|
||||||
if resp:
|
|
||||||
self.process_response_middleware(request, resp)
|
|
||||||
self.fire_response(request, resp)
|
|
||||||
return True # abort further processing
|
|
||||||
|
|
||||||
def process_response_middleware(
|
|
||||||
self, request: Request, response: HttpResponse
|
|
||||||
) -> None:
|
|
||||||
for middleware in self.server.middleware:
|
|
||||||
try:
|
|
||||||
middleware.process_response(request, response)
|
|
||||||
except UnusedMiddleware:
|
|
||||||
self.server.middleware.remove(middleware)
|
|
||||||
continue
|
|
||||||
|
|
||||||
def prepare_and_fire_response(self, request, resp) -> None:
|
|
||||||
try:
|
|
||||||
if isinstance(resp, dict):
|
|
||||||
self.fire_response(request, JsonResponse(data=resp))
|
|
||||||
if isinstance(resp, TemplateResponse):
|
|
||||||
if hasattr(self.server, "env"):
|
|
||||||
resp.set_template_loader(self.server.env)
|
|
||||||
|
|
||||||
for middleware in self.server.middleware:
|
|
||||||
middleware.process_response(request, resp)
|
|
||||||
|
|
||||||
self.fire_response(request, resp)
|
|
||||||
|
|
||||||
except APIError:
|
|
||||||
raise
|
|
||||||
|
|
||||||
except Exception:
|
|
||||||
log.error(traceback.format_exc())
|
|
||||||
self.fire_response(request, self.get_error_route(500)(request))
|
|
||||||
|
|
||||||
def is_form_request(self, request: Request) -> bool:
|
|
||||||
return (
|
|
||||||
"Content-Type" in request.headers
|
|
||||||
and request.headers["Content-Type"] == "application/x-www-form-urlencoded"
|
|
||||||
)
|
|
||||||
|
|
||||||
def send_error_response(self, request: Request, e: SpiderwebNetworkException):
|
|
||||||
try:
|
|
||||||
self.send_error(e.code, e.msg, e.desc)
|
|
||||||
except ConnectionAbortedError as e:
|
|
||||||
log.error(f"{request.method} {self.path} : {e}")
|
|
||||||
|
|
||||||
def handle_request(self, request):
|
|
||||||
|
|
||||||
try:
|
|
||||||
handler, additional_args, allowed_methods = self.get_route(request.url.path)
|
|
||||||
except NotFound:
|
|
||||||
handler = self.get_error_route(404)
|
|
||||||
additional_args = {}
|
|
||||||
allowed_methods = DEFAULT_ALLOWED_METHODS
|
|
||||||
|
|
||||||
if request.method not in allowed_methods:
|
|
||||||
# replace the potentially valid handler with the error route
|
|
||||||
handler = self.get_error_route(405)
|
|
||||||
|
|
||||||
request.query_params = (
|
|
||||||
urlparse.parse_qs(request.url.query) if request.url.query else {}
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.is_form_request(request):
|
|
||||||
formdata = urlparse.parse_qs(request.content.decode("utf-8"))
|
|
||||||
for key, value in formdata.items():
|
|
||||||
if len(value) == 1:
|
|
||||||
formdata[key] = value[0]
|
|
||||||
setattr(request, request.method, formdata)
|
|
||||||
|
|
||||||
try:
|
|
||||||
if handler:
|
|
||||||
# middleware is injected from WebServer
|
|
||||||
abort = self.process_request_middleware(request)
|
|
||||||
if abort:
|
|
||||||
return
|
|
||||||
|
|
||||||
resp = handler(request, **additional_args)
|
|
||||||
if resp is None:
|
|
||||||
raise NoResponseError(f"View {handler} returned None.")
|
|
||||||
# run the response through the middleware and send it
|
|
||||||
self.prepare_and_fire_response(request, resp)
|
|
||||||
else:
|
|
||||||
raise SpiderwebNetworkException(404)
|
|
||||||
except SpiderwebNetworkException as e:
|
|
||||||
self.send_error_response(request, e)
|
|
Loading…
Reference in New Issue
Block a user