spiderweb/spiderweb/main.py
2024-08-05 20:24:30 -04:00

230 lines
7.4 KiB
Python

# Started life from
# https://gist.github.com/earonesty/ab07b4c0fea2c226e75b3d538cc0dc55
#
# Extensively modified by @itsthejoker
import errno
import json
import re
from http.server import BaseHTTPRequestHandler, HTTPServer
import urllib.parse as urlparse
import threading
import logging
from typing import Any, Callable

from spiderweb.converters import *  # noqa: F403
from spiderweb.default_responses import http403, http404, http500
from spiderweb.exceptions import (
    APIError,
    ConfigError,
    ParseError,
    GeneralException,
    NoResponseError,
)
from spiderweb.request import Request
from spiderweb.response import HttpResponse, JsonResponse
# Module-level logger, named after this module.
log = logging.getLogger(name=__name__)
def api_route(path):
    """Return a decorator that tags a function with the route ``path``.

    The decorated function itself is returned; the only change is that
    ``path`` is appended to its ``_routes`` list (created on first use).
    Stacking the decorator therefore tags one function with several routes.
    """

    def tag(func):
        # First tag on this function: start an empty route list.
        if not hasattr(func, "_routes"):
            func._routes = []
        func._routes += [path]
        return func

    return tag
def convert_path(path):
    """Convert a route path into a compiled, fully-anchored regex.

    The path is split on ``/``.

    - A segment of the form ``<name>`` or ``<converter:name>`` becomes a
      named capture group called ``{name}__{ConverterClassName}``. The group
      matches the converter's ``regex``, and the name lets
      ``convert_match_to_dict`` recover both the variable and its converter.
    - Every other segment is matched literally.

    :param path: route pattern, e.g. ``"/users/<int:user_id>"``.
    :returns: the compiled pattern, anchored with ``^`` and ``$``.
    :raises ConfigError: if a variable name contains ``__``, which is reserved
        as the name/converter separator.
    :raises ParseError: if ``converter`` does not name a known
        ``<Converter>Converter`` class in this module's namespace.
    """
    parts = path.split("/")
    for i, part in enumerate(parts):
        if not (part.startswith("<") and part.endswith(">")):
            # Literal segment: escape it so characters such as "." or "+"
            # match themselves instead of acting as regex metacharacters.
            parts[i] = re.escape(part)
            continue

        name = part[1:-1]
        # "__" separates the variable name from the converter class name in
        # the generated group name, so it cannot appear in the name itself.
        if "__" in name:
            raise ConfigError(
                "Cannot use `__` (double underscore) in path variable."
                f" Please fix '{name}'."
            )
        if ":" in name:
            converter_name, name = name.split(":")
            try:
                # e.g. "int" -> "IntConverter"; converter classes are
                # presumably provided by the `spiderweb.converters` star import.
                converter = globals()[converter_name.title() + "Converter"]
            except KeyError:
                raise ParseError(f"Unknown converter {converter_name}") from None
        else:
            converter = StrConverter  # noqa: F405
        parts[i] = f"(?P<{name}__{converter.__name__}>{converter.regex})"
    return re.compile("^%s$" % "/".join(parts))
def convert_match_to_dict(match: dict) -> dict[str, Any]:
    """Convert a route match's ``groupdict()`` into typed view keyword args.

    Group names have the form ``{variable}__{ConverterClassName}``, as built
    by ``convert_path``. Each raw string value is passed through
    ``to_python`` on a fresh instance of the named converter class, which is
    looked up in this module's namespace.

    :param match: the ``groupdict()`` of a successful route match.
    :returns: ``{variable: converted_value}``.
    """
    converted = {}
    for group_name, raw_value in match.items():
        # Split on the LAST "__". convert_path rejects "__" inside variable
        # names, but a name ending in "_" (e.g. "class_") yields
        # "class___StrConverter", which split("__") would mis-parse.
        var_name, _, converter_name = group_name.rpartition("__")
        converter = globals()[converter_name]()
        converted[var_name] = converter.to_python(raw_value)
    return converted
class WebServer(HTTPServer):
    """HTTP server that dispatches requests to registered view functions.

    Routes come from two places:

    - methods defined directly on a subclass and tagged with
      ``@api_route("/path")``, which are collected in ``__init__``;
    - explicit ``add_route(path, view)`` calls.

    Unless a ``custom_handler`` is supplied, each server gets its own handler
    class with its own route table. Routes registered on one server
    therefore do not leak into another.
    """

    def __init__(self, addr: str, port: int, custom_handler: Callable = None):
        """Create a server bound to ``(addr, port)``; ``port`` may be 0.

        Views are called as ``view(request, **path_variables)``:

        - ``request`` is a ``spiderweb.request.Request``.
        - ``request.url`` holds the parsed URL.
        - ``request.query_params`` holds the query arguments from
          ``urllib.parse.parse_qs``.

        A view that returns a ``dict`` has it sent as a JSON response. Raise
        ``APIError(code, message, description=None)`` to send an HTTP error.

        Note: methods tagged with ``@api_route`` are registered as plain
        functions, so the first argument they receive is the request, not
        the server instance.

        :param addr: interface to bind to; also used by :meth:`uri`.
        :param port: TCP port to bind to, or 0 to let the OS pick one.
        :param custom_handler: optional ``RequestHandler`` subclass to use
            instead of the per-server default. Its ``_routes`` dict is used
            as-is, so it is shared with anything else using that class.
        :raises GeneralException: if the listening socket cannot be bound.
        """
        server_address = (addr, port)
        self.__addr = addr

        if custom_handler:
            self.handler_class = custom_handler
        else:
            # Shim subclass with its own route table. Without this explicit
            # attribute, add_route() would write into the dict defined on
            # RequestHandler itself, which every server would then share.
            class HandlerClass(RequestHandler):
                _routes = {}

            self.handler_class = HandlerClass

        # Register methods on this server's class tagged via @api_route.
        for method in type(self).__dict__.values():
            if hasattr(method, "_routes"):
                for route in method._routes:
                    self.add_route(route, method)

        try:
            super().__init__(server_address, self.handler_class)
        except OSError as e:
            # Windows reports WSAEADDRINUSE rather than EADDRINUSE.
            addr_in_use = {
                errno.EADDRINUSE,
                getattr(errno, "WSAEADDRINUSE", errno.EADDRINUSE),
            }
            if e.errno in addr_in_use:
                raise GeneralException("Port already in use.") from e
            raise GeneralException(f"Could not bind to {addr}:{port}: {e}") from e

    def add_route(self, path: str, method: Callable):
        """Register ``method`` as the view for route pattern ``path``.

        ``path`` is compiled with ``convert_path``, so it may contain
        ``<name>`` / ``<converter:name>`` segments.
        """
        self.handler_class._routes[convert_path(path)] = method

    def port(self):
        """Return the port the listening socket is bound to."""
        return self.socket.getsockname()[1]

    def address(self):
        """Return the IP address the listening socket is bound to."""
        return self.socket.getsockname()[0]

    def uri(self, path):
        """Build an absolute ``http://`` URI for ``path`` on this server.

        Uses the address passed to the constructor together with the actually
        bound port. A single leading ``/`` on ``path`` is optional, and an
        empty ``path`` yields the root URI.
        """
        return f"http://{self.__addr}:{self.port()}/{path.removeprefix('/')}"

    def start(self, blocking=False):
        """Start serving requests.

        :param blocking: if False (the default), run ``serve_forever`` on a
            new background thread and return immediately. If True, serve on
            the calling thread until interrupted with Ctrl-C, then close the
            listening socket.
        :returns: the background ``threading.Thread`` when non-blocking,
            otherwise ``None``.
        """
        if not blocking:
            thread = threading.Thread(target=self.serve_forever)
            thread.start()
            return thread
        try:
            self.serve_forever()
        except KeyboardInterrupt:
            print()  # empty line after ^C
            print("Stopping server!")
            # serve_forever() has already exited; release the socket.
            self.server_close()
        return None

    def shutdown(self):
        """Stop the ``serve_forever`` loop and close the listening socket.

        ``HTTPServer.shutdown`` waits for the serving loop to exit, so call
        this from a thread other than the one running ``serve_forever``.
        """
        super().shutdown()
        self.server_close()
class RequestHandler(BaseHTTPRequestHandler):
    """HTTP request handler that dispatches GET and POST requests to views.

    Routes are looked up in ``_routes``, a dict mapping compiled path regexes
    (as produced by ``convert_path``) to view callables. ``WebServer.add_route``
    writes into it. Views are called as ``view(request, **path_variables)``.
    """

    # Route table: compiled path regex -> view callable.
    # NOTE: this is a class attribute, so it is shared by every subclass that
    # does not define its own ``_routes``.
    _routes = {}

    def get_request(self):
        """Build a ``Request`` from the current command, path and headers.

        ``content`` and ``body`` start out empty; ``do_POST`` fills in
        ``content`` afterwards.
        """
        return Request(
            content="",
            body="",
            method=self.command,
            headers=self.headers,
            path=self.path,
        )

    # I can't help the naming convention of these because that's what
    # BaseHTTPRequestHandler uses: it dispatches each request to
    # ``do_<COMMAND>``.
    def do_GET(self):
        """Dispatch a GET request to its view."""
        request = self.get_request()
        self.handle_request(request)

    def do_POST(self):
        """Read a POST body, check that it parses as JSON, then dispatch it.

        A missing ``Content-Length`` header is treated as an empty JSON
        object (``"{}"``). A malformed or negative ``Content-Length``, or a
        body that ``request.json()`` rejects, is answered with a 400 error
        page. Previously such errors escaped the handler and the client got
        no response.
        """
        content = "{}"
        raw_length = self.headers["Content-Length"]
        if raw_length:
            try:
                length = int(raw_length)
            except ValueError:
                length = -1
            # rfile.read(-1) would block until EOF, so reject negatives too.
            if length < 0:
                self._send_error_safely(400, "Invalid Content-Length")
                return
            content = self.rfile.read(length)

        request = self.get_request()
        request.content = content
        if content:
            try:
                request.json()
            except ValueError:
                # Covers json.JSONDecodeError and UnicodeDecodeError, both
                # ValueError subclasses; assumes Request.json() parses
                # `content` with the json module -- TODO confirm.
                # send_error() HTML-escapes `explain`, which requires str.
                explain = (
                    content.decode("utf-8", "replace")
                    if isinstance(content, bytes)
                    else content
                )
                self._send_error_safely(400, "Invalid JSON", explain)
                return
        self.handle_request(request)

    def get_route(self, path) -> tuple[Callable, dict[str, Any]]:
        """Find the view registered for ``path``.

        Patterns are tried in the route table's insertion order, and the
        first one that matches wins.

        :returns: ``(view, path_variables)``, with the variables already
            converted by ``convert_match_to_dict``.
        :raises APIError: 404 if no route matches.
        """
        for pattern, view in self._routes.items():
            if match_data := pattern.match(path):
                return view, convert_match_to_dict(match_data.groupdict())
        raise APIError(404, "No route found")

    def get_error_route(self, code: int) -> Callable:
        """Return the module-level ``http<code>`` error view, else ``http500``."""
        return globals().get(f"http{code}", http500)

    def fire_response(self, resp: HttpResponse):
        """Write ``resp`` to the client: status line, headers, then body.

        The body is ``resp.render()`` encoded as UTF-8, and ``Content-Length``
        is the length of those encoded bytes. Measuring the ``str`` instead
        would under-count non-ASCII bodies and truncate them.
        """
        # Render before emitting the status line: if render() raises, no
        # partial status/headers are left in the send buffer.
        body = resp.render().encode("utf-8")
        self.send_response(resp.status_code)
        self.send_header("Content-Length", str(len(body)))
        if resp.headers:
            for key, value in resp.headers.items():
                self.send_header(key, value)
        self.end_headers()
        self.wfile.write(body)

    def handle_request(self, request):
        """Route ``request`` to its view and send the view's response.

        Before calling ``view(request, **path_variables)`` this sets:

        - ``request.url`` to the ``urlparse`` result;
        - ``request.query_params`` to ``parse_qs`` of the query string, or
          ``{}`` when there is none.

        How the outcome is sent:

        - A ``dict`` return is sent as JSON.
        - An ``HttpResponse`` / ``JsonResponse`` is sent as-is.
        - ``None``, any other return type, or an unexpected exception is
          logged and answered with the 500 error view.
        - ``APIError``, including the 404 for unmatched paths, becomes an
          HTTP error page via ``send_error``.
        """
        try:
            request.url = urlparse.urlparse(request.path)
            handler, additional_args = self.get_route(request.url.path)
            if request.url.query:
                request.query_params = urlparse.parse_qs(request.url.query)
            else:
                request.query_params = {}
            if not handler:
                raise APIError(404)
            try:
                resp = handler(request, **additional_args)
                if resp is None:
                    raise NoResponseError(f"View {handler} returned None.")
                if isinstance(resp, dict):
                    resp = JsonResponse(data=resp)
                elif not isinstance(resp, (HttpResponse, JsonResponse)):
                    # Previously any non-dict return sent nothing at all.
                    raise TypeError(
                        f"View {handler} returned unsupported type "
                        f"{type(resp).__name__}."
                    )
                self.fire_response(resp)
            except APIError:
                raise
            except ConnectionAbortedError as e:
                log.error("%s %s : %s", self.command, self.path, e)
            except Exception:
                log.exception("Unhandled error in view %s", handler)
                self.fire_response(self.get_error_route(500)(self, request))
        except APIError as e:
            self._send_error_safely(e.code, e.msg, e.desc)

    def _send_error_safely(self, code, message=None, explain=None):
        """Send an HTTP error page, logging instead of raising if the client left."""
        try:
            self.send_error(code, message, explain)
        except ConnectionAbortedError as e:
            log.error("%s %s : %s", self.command, self.path, e)