diff --git a/spiderweb/main.py b/spiderweb/main.py deleted file mode 100644 index 97d3c84..0000000 --- a/spiderweb/main.py +++ /dev/null @@ -1,497 +0,0 @@ -# Started life from -# https://gist.github.com/earonesty/ab07b4c0fea2c226e75b3d538cc0dc55 -# -# Extensively modified by @itsthejoker -import inspect -import os -import re -import pathlib -import signal -import time -import traceback -from http.server import BaseHTTPRequestHandler, HTTPServer -import urllib.parse as urlparse -import threading -import logging -from typing import Callable, Any, NoReturn -from wsgiref.simple_server import WSGIRequestHandler, WSGIServer - -from cryptography.fernet import Fernet -from jinja2 import Environment, FileSystemLoader - -from spiderweb.constants import DEFAULT_ENCODING, DEFAULT_ALLOWED_METHODS -from spiderweb.converters import * # noqa: F403 -from spiderweb.default_responses import * # noqa: F403 -from spiderweb.exceptions import ( - APIError, - ConfigError, - ParseError, - GeneralException, - NoResponseError, - UnusedMiddleware, - SpiderwebNetworkException, - NotFound, -) -from spiderweb.request import Request -from spiderweb.response import ( - HttpResponse, - JsonResponse, - TemplateResponse, - RedirectResponse, - FileResponse, -) -from spiderweb.utils import import_by_string, is_safe_path - -log = logging.getLogger(__name__) -logging.basicConfig(level=logging.INFO) - - -def route(path): - def outer(func): - if not hasattr(func, "_routes"): - setattr(func, "_routes", []) - func._routes += [path] - return func - - return outer - - -def send_file(request, filename: str) -> HttpResponse: - for folder in request.server.staticfiles_dirs: - requested_path = request.server.BASE_DIR / folder / filename - if os.path.exists(requested_path): - if not is_safe_path(requested_path): - raise NotFound - return FileResponse(filename=requested_path) - raise NotFound - - -class DummyRedirectRoute: - def __init__(self, location): - self.location = location - - def __call__(self, request): - return RedirectResponse(self.location) - - -def convert_match_to_dict(match: dict): - """Convert a match object to a dict with the proper converted types for each match.""" - return { - k.split("__")[0]: globals()[k.split("__")[1]]().to_python(v) - for k, v in match.items() - } - - -class WebServer(HTTPServer): - def __init__( - self, - addr: str = None, - port: int = None, - custom_handler: Callable = None, - templates_dirs: list[str] = None, - middleware: list[str] = None, - append_slash: bool = False, - staticfiles_dirs: list[str] = None, - secret_key: str = None, - ): - """ - Create a new server on address, port. Port can be zero. - - > from simple_rpc_server import WebServer, APIError, route - - Create your handlers by inheriting from WebServer and tagging them with - @route("/path"). Alternately, you can use the WebServer() directly - by calling `add_handler("path", function)`. - """ - addr = addr if addr else "localhost" - port = port if port else 8000 - self.append_slash = append_slash - self.templates_dirs = templates_dirs - self.staticfiles_dirs = staticfiles_dirs - self.middleware = middleware if middleware else [] - self.secret_key = secret_key if secret_key else self._create_secret_key() - self.fernet = Fernet(self.key) - self.DEFAULT_ENCODING = DEFAULT_ENCODING - self.DEFAULT_ALLOWED_METHODS = DEFAULT_ALLOWED_METHODS - self._thread = None - - self.BASE_DIR = self.get_caller_filepath() - - if self.middleware: - middleware_by_reference = [] - for m in self.middleware: - try: - middleware_by_reference.append(import_by_string(m)(server=self)) - except ImportError: - raise ConfigError(f"Middleware '{m}' not found.") - self.middleware = middleware_by_reference - - if self.templates_dirs: - self.env = Environment(loader=FileSystemLoader(self.templates_dirs)) - else: - self.env = None - - server_address = (addr, port) - self.__addr = addr - - # shim class that is an RequestHandler - class HandlerClass(RequestHandler): - pass - - self.handler_class = custom_handler if custom_handler else HandlerClass - self.handler_class.server = self - - # routed methods map into handler - for method in type(self).__dict__.values(): - if hasattr(method, "_routes"): - for route in method._routes: - self.add_route(route, method) - - if self.staticfiles_dirs: - for static_dir in self.staticfiles_dirs: - static_dir = pathlib.Path(static_dir) - if not pathlib.Path(self.BASE_DIR / static_dir).exists(): - log.error( - f"Static files directory '{str(static_dir)}' does not exist." - ) - raise ConfigError - self.add_route(r"/static/", send_file) - - try: - super().__init__(server_address, self.handler_class) - except OSError: - raise GeneralException("Port already in use.") - - def get_caller_filepath(self): - """Figure out who called us and return their path.""" - stack = inspect.stack() - caller_frame = stack[1] - return pathlib.Path(caller_frame.filename).parent.parent - - def convert_path(self, path: str): - """Convert a path to a regex.""" - parts = path.split("/") - for i, part in enumerate(parts): - if part.startswith("<") and part.endswith(">"): - name = part[1:-1] - if "__" in name: - raise ConfigError( - f"Cannot use `__` (double underscore) in path variable." - f" Please fix '{name}'." - ) - if ":" in name: - converter, name = name.split(":") - try: - converter = globals()[converter.title() + "Converter"] - except KeyError: - raise ParseError(f"Unknown converter {converter}") - else: - converter = StrConverter # noqa: F405 - parts[i] = rf"(?P<{name}__{str(converter.__name__)}>{converter.regex})" - return re.compile(rf"^{'/'.join(parts)}$") - - def check_for_route_duplicates(self, path: str): - if self.convert_path(path) in self.handler_class._routes: - raise ConfigError(f"Route '{path}' already exists.") - - def add_route( - self, path: str, method: Callable, allowed_methods: None | list[str] = None - ): - """Add a route to the server.""" - if not hasattr(self.handler_class, "_routes"): - setattr(self.handler_class, "_routes", {}) - - allowed_methods = ( - allowed_methods if allowed_methods else DEFAULT_ALLOWED_METHODS - ) - - if self.append_slash and not path.endswith("/"): - updated_path = path + "/" - self.check_for_route_duplicates(updated_path) - self.check_for_route_duplicates(path) - self.handler_class._routes[self.convert_path(path)] = { - "func": DummyRedirectRoute(updated_path), - "allowed_methods": allowed_methods, - } - self.handler_class._routes[self.convert_path(updated_path)] = { - "func": method, - "allowed_methods": allowed_methods, - } - else: - self.check_for_route_duplicates(path) - self.handler_class._routes[self.convert_path(path)] = { - "func": method, - "allowed_methods": allowed_methods, - } - - def add_error_route(self, code: int, method: Callable): - """Add an error route to the server.""" - if not hasattr(self.handler_class, "_error_routes"): - setattr(self.handler_class, "_error_routes", {}) - - if code not in self.handler_class._error_routes: - self.handler_class._error_routes[code] = method - else: - raise ConfigError(f"Error route for code {code} already exists.") - - def route(self, path, allowed_methods=None) -> Callable: - """ - Decorator for adding a route to a view. - - Usage: - - app = WebServer() - - @app.route("/hello") - def index(request): - return HttpResponse(content="Hello, world!") - - :param path: str - :param allowed_methods: list[str] - :return: Callable - """ - - def outer(func): - self.add_route( - path, - func, - allowed_methods if allowed_methods else DEFAULT_ALLOWED_METHODS, - ) - return func - - return outer - - def error(self, code: int) -> Callable: - def outer(func): - self.add_error_route(code, func) - return func - - return outer - - @property - def port(self): - """Return current port.""" - return self.socket.getsockname()[1] - - @property - def address(self): - """Return current IP address.""" - return self.socket.getsockname()[0] - - def uri(self, path=None): - """Make a URI pointing at myself.""" - path = path if path else "" - if path.startswith("/"): - path = path[1:] - return self.__addr + ":" + str(self.port()) + "/" + path - - def signal_handler(self, sig, frame) -> NoReturn: - log.warning("Shutting down!") - self.stop() - - def start(self, blocking=False): - signal.signal(signal.SIGINT, self.signal_handler) - log.info(f"Starting server on {self.address}:{self.port}") - log.info("Press CTRL+C to stop the server.") - self._thread = threading.Thread(target=self.serve_forever) - self._thread.start() - if not blocking: - return self._thread - else: - while self._thread.is_alive(): - try: - time.sleep(0.2) - except KeyboardInterrupt: - self.stop() - - def stop(self): - super().shutdown() - self.socket.close() - - def _create_secret_key(self): - self.key = Fernet.generate_key() - - def encrypt(self, data: str): - return self.fernet.encrypt(bytes(data, DEFAULT_ENCODING)) - - def decrypt(self, data: str): - if isinstance(data, bytes): - return self.fernet.decrypt(data).decode(DEFAULT_ENCODING) - return self.fernet.decrypt(bytes(data, DEFAULT_ENCODING)).decode( - DEFAULT_ENCODING - ) - - -class RequestHandler(BaseHTTPRequestHandler): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - # These stop pycharm from complaining about these not existing. They're - # injected by the WebServer class at runtime - self._routes = {} - self._error_routes = {} - self.server = None - - def get_request(self): - return Request( - content="", - body="", - method=self.command, - headers=self.headers, - path=self.path, - server=self.server, - ) - - # I can't help the naming convention of these because that's what - # BaseHTTPRequestHandler uses for some weird reason - def do_GET(self): - request = self.get_request() - request.method = "GET" - self.handle_request(request) - - def do_POST(self): - content = "{}" - if self.headers["Content-Length"]: - length = int(self.headers["Content-Length"]) - content = self.rfile.read(length) - request = self.get_request() - request.method = "POST" - request.content = content - self.handle_request(request) - - def get_route(self, path) -> tuple[Callable, dict[str, Any], list[str]]: - for option in self._routes.keys(): - if match_data := option.match(path): - return ( - self._routes[option]["func"], - convert_match_to_dict(match_data.groupdict()), - self._routes[option]["allowed_methods"], - ) - raise NotFound() - - def get_error_route(self, code: int) -> Callable: - view = self._error_routes.get(code) or globals().get(f"http{code}") - if not view: - return http500 - return view - - def _fire_response( - self, - status: int = 200, - content: str = None, - headers: dict[str, str | int] = None, - ): - self.send_response(status) - self.send_header("Content-Length", str(len(content))) - if headers: - for key, value in headers.items(): - self.send_header(key, value) - self.end_headers() - self.wfile.write(bytes(content, DEFAULT_ENCODING)) - - def fire_response(self, request: Request, resp: HttpResponse): - try: - self._fire_response( - status=resp.status_code, content=resp.render(), headers=resp.headers - ) - except APIError: - raise - except ConnectionAbortedError as e: - log.error(f"GET {self.path} : {e}") - except Exception: - log.error(traceback.format_exc()) - self.fire_response(request, self.get_error_route(500)(request)) - - def process_request_middleware(self, request: Request) -> None | bool: - for middleware in self.server.middleware: - try: - resp = middleware.process_request(request) - except UnusedMiddleware: - self.server.middleware.remove(middleware) - continue - if resp: - self.process_response_middleware(request, resp) - self.fire_response(request, resp) - return True # abort further processing - - def process_response_middleware( - self, request: Request, response: HttpResponse - ) -> None: - for middleware in self.server.middleware: - try: - middleware.process_response(request, response) - except UnusedMiddleware: - self.server.middleware.remove(middleware) - continue - - def prepare_and_fire_response(self, request, resp) -> None: - try: - if isinstance(resp, dict): - self.fire_response(request, JsonResponse(data=resp)) - if isinstance(resp, TemplateResponse): - if hasattr(self.server, "env"): - resp.set_template_loader(self.server.env) - - for middleware in self.server.middleware: - middleware.process_response(request, resp) - - self.fire_response(request, resp) - - except APIError: - raise - - except Exception: - log.error(traceback.format_exc()) - self.fire_response(request, self.get_error_route(500)(request)) - - def is_form_request(self, request: Request) -> bool: - return ( - "Content-Type" in request.headers - and request.headers["Content-Type"] == "application/x-www-form-urlencoded" - ) - - def send_error_response(self, request: Request, e: SpiderwebNetworkException): - try: - self.send_error(e.code, e.msg, e.desc) - except ConnectionAbortedError as e: - log.error(f"{request.method} {self.path} : {e}") - - def handle_request(self, request): - - try: - handler, additional_args, allowed_methods = self.get_route(request.url.path) - except NotFound: - handler = self.get_error_route(404) - additional_args = {} - allowed_methods = DEFAULT_ALLOWED_METHODS - - if request.method not in allowed_methods: - # replace the potentially valid handler with the error route - handler = self.get_error_route(405) - - request.query_params = ( - urlparse.parse_qs(request.url.query) if request.url.query else {} - ) - - if self.is_form_request(request): - formdata = urlparse.parse_qs(request.content.decode("utf-8")) - for key, value in formdata.items(): - if len(value) == 1: - formdata[key] = value[0] - setattr(request, request.method, formdata) - - try: - if handler: - # middleware is injected from WebServer - abort = self.process_request_middleware(request) - if abort: - return - - resp = handler(request, **additional_args) - if resp is None: - raise NoResponseError(f"View {handler} returned None.") - # run the response through the middleware and send it - self.prepare_and_fire_response(request, resp) - else: - raise SpiderwebNetworkException(404) - except SpiderwebNetworkException as e: - self.send_error_response(request, e)