diff --git a/spiderweb/default_responses.py b/spiderweb/default_responses.py new file mode 100644 index 0000000..b0ee819 --- /dev/null +++ b/spiderweb/default_responses.py @@ -0,0 +1,13 @@ +from spiderweb.response import JsonResponse + + +def http403(request): + return JsonResponse(data={"error": "Forbidden"}, status_code=403) + + +def http404(request): + return JsonResponse(data={"error": f"Route {request.url} not found"}, status_code=404) + + +def http500(request): + return JsonResponse(data={"error": "Internal server error"}, status_code=500) diff --git a/spiderweb/exceptions.py b/spiderweb/exceptions.py index 09ccbe6..90150d1 100644 --- a/spiderweb/exceptions.py +++ b/spiderweb/exceptions.py @@ -4,6 +4,7 @@ class SpiderwebException(Exception): class SpiderwebNetworkException(SpiderwebException): + """Something has gone wrong with the network stack.""" def __init__(self, code, msg=None, desc=None): self.code = code self.msg = msg @@ -31,3 +32,7 @@ class GeneralException(SpiderwebException): class UnusedMiddleware(SpiderwebException): pass + + +class NoResponseError(SpiderwebException): + pass diff --git a/spiderweb/main.py b/spiderweb/main.py index fa2f913..296049d 100644 --- a/spiderweb/main.py +++ b/spiderweb/main.py @@ -1,5 +1,6 @@ -# very simple RPC server in python -# Originally from https://gist.github.com/earonesty/ab07b4c0fea2c226e75b3d538cc0dc55 +# Started life from +# https://gist.github.com/earonesty/ab07b4c0fea2c226e75b3d538cc0dc55 +# # Extensively modified by @itsthejoker import json @@ -11,8 +12,10 @@ import logging from typing import Callable, Any from spiderweb.converters import * # noqa: F403 -from spiderweb.exceptions import APIError, ConfigError, ParseError, GeneralException +from spiderweb.default_responses import http403, http404, http500 +from spiderweb.exceptions import APIError, ConfigError, ParseError, GeneralException, NoResponseError from spiderweb.request import Request +from spiderweb.response import HttpResponse, JsonResponse log = logging.getLogger(__name__) @@ -23,7 +26,6 @@ def api_route(path): setattr(func, "_routes", []) func._routes += [path] return func - return outer @@ -61,15 +63,15 @@ def convert_match_to_dict(match: dict): } -class APIServer(HTTPServer): +class WebServer(HTTPServer): def __init__(self, addr: str, port: int, custom_handler: Callable = None): """ Create a new server on address, port. Port can be zero. - > from simple_rpc_server import APIServer, APIError, api_route + > from simple_rpc_server import WebServer, APIError, api_route - Create your handlers by inheriting from APIServer and tagging them with - @api_route("/path"). Alternately, you can use the APIServer() directly + Create your handlers by inheriting from WebServer and tagging them with + @api_route("/path"). Alternately, you can use the WebServer() directly by calling `add_handler("path", function)`. Raise network errors by raising `APIError(code, message, description=None)`. @@ -83,8 +85,8 @@ class APIServer(HTTPServer): server_address = (addr, port) self.__addr = addr - # shim class that is an APIHandler - class HandlerClass(APIHandler): + # shim class that is an RequestHandler + class HandlerClass(RequestHandler): pass self.handler_class = custom_handler if custom_handler else HandlerClass @@ -133,7 +135,7 @@ class APIServer(HTTPServer): self.socket.close() -class APIHandler(BaseHTTPRequestHandler): +class RequestHandler(BaseHTTPRequestHandler): # I can't help the naming convention of these because that's what # BaseHTTPRequestHandler uses for some weird reason _routes = {} @@ -173,6 +175,23 @@ class APIHandler(BaseHTTPRequestHandler): ) raise APIError(404, "No route found") + def get_error_route(self, code: int) -> Callable: + try: + view = globals()[f"http{code}"] + return view + except KeyError: + return http500 + + def fire_response(self, resp: HttpResponse): + self.send_response(resp.status_code) + content = resp.render() + self.send_header("Content-Length", str(len(content))) + if resp.headers: + for key, value in resp.headers.items(): + self.send_header(key, value) + self.end_headers() + self.wfile.write(bytes(content, "utf-8")) + def handle_request(self, request): try: request.url = urlparse.urlparse(request.path) @@ -188,22 +207,19 @@ class APIHandler(BaseHTTPRequestHandler): if handler: try: - response = handler(request, **additional_args) - self.send_response(200) - if response is None: - response = "" - if isinstance(response, dict): - response = json.dumps(response) - response = bytes(str(response), "utf-8") - self.send_header("Content-Length", str(len(response))) - self.end_headers() - self.wfile.write(response) + resp = handler(request, **additional_args) + if resp is None: + raise NoResponseError(f"View {handler} returned None.") + if isinstance(resp, dict): + self.fire_response(JsonResponse(data=resp)) except APIError: raise except ConnectionAbortedError as e: log.error(f"GET {self.path} : {e}") except Exception as e: - raise APIError(500, str(e)) + log.error(e.__traceback__) + self.fire_response(self.get_error_route(500)(self, request)) + else: raise APIError(404) except APIError as e: diff --git a/spiderweb/middleware.py b/spiderweb/middleware.py index 7639187..68a4586 100644 --- a/spiderweb/middleware.py +++ b/spiderweb/middleware.py @@ -1,5 +1,3 @@ -from typing import Optional, NoReturn - from spiderweb.request import Request from spiderweb.response import HttpResponse @@ -10,22 +8,22 @@ class SpiderwebMiddleware: (optional!) methods: process_request(self, request) -> None or Response - process_response(self, request, response) -> None + process_response(self, request, resp) -> None Middleware can be used to modify requests and responses in a variety of ways. - If one of the two methods is not defined, the request or response will be passed + If one of the two methods is not defined, the request or resp will be passed through unmodified. - If `process_request` returns + If `process_request` returns a HttpResponse, the request will be short-circuited + and the response will be returned immediately. `process_response` will not be called. """ def process_request(self, request: Request) -> HttpResponse | None: # example of a middleware that sets a flag on the request request.spiderweb = True - - def process_response(self, request: Request, response: HttpResponse) -> NoReturn: - # example of a middleware that sets a header on the response + def process_response(self, request: Request, response: HttpResponse) -> HttpResponse | None: + # example of a middleware that sets a header on the resp if hasattr(request, 'spiderweb'): - response['X-Spiderweb'] = 'true' + response.headers['X-Spiderweb'] = 'true' return response diff --git a/spiderweb/response.py b/spiderweb/response.py index 5a94653..1b36b6d 100644 --- a/spiderweb/response.py +++ b/spiderweb/response.py @@ -1,9 +1,34 @@ +import json +from typing import Any + + class HttpResponse: - ... + def __init__( + self, + content: str = None, + data: dict[str, Any] = None, + status_code: int = 200, + headers=None, + ): + self.content = content + self.data = data + self.status_code = status_code + self.headers = headers if headers else {} + + def __str__(self): + return self.content + + def render(self) -> str: + raise NotImplemented class JsonResponse(HttpResponse): - ... + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.headers["Content-Type"] = "application/json" + + def render(self) -> str: + return json.dumps(self.data) class RedirectResponse(HttpResponse):