🚧 progress

This commit is contained in:
Joe Kaufeld 2024-08-05 20:24:30 -04:00
parent 738adee6c2
commit fe359538e1
5 changed files with 90 additions and 33 deletions

View File

@ -0,0 +1,13 @@
from spiderweb.response import JsonResponse
def http403(request):
return JsonResponse(data={"error": "Forbidden"}, status_code=403)
def http404(request):
return JsonResponse(data={"error": f"Route {request.url} not found"}, status_code=404)
def http500(request):
return JsonResponse(data={"error": "Internal server error"}, status_code=500)

View File

@ -4,6 +4,7 @@ class SpiderwebException(Exception):
class SpiderwebNetworkException(SpiderwebException): class SpiderwebNetworkException(SpiderwebException):
"""Something has gone wrong with the network stack."""
def __init__(self, code, msg=None, desc=None): def __init__(self, code, msg=None, desc=None):
self.code = code self.code = code
self.msg = msg self.msg = msg
@ -31,3 +32,7 @@ class GeneralException(SpiderwebException):
class UnusedMiddleware(SpiderwebException): class UnusedMiddleware(SpiderwebException):
pass pass
class NoResponseError(SpiderwebException):
pass

View File

@ -1,5 +1,6 @@
# very simple RPC server in python # Started life from
# Originally from https://gist.github.com/earonesty/ab07b4c0fea2c226e75b3d538cc0dc55 # https://gist.github.com/earonesty/ab07b4c0fea2c226e75b3d538cc0dc55
#
# Extensively modified by @itsthejoker # Extensively modified by @itsthejoker
import json import json
@ -11,8 +12,10 @@ import logging
from typing import Callable, Any from typing import Callable, Any
from spiderweb.converters import * # noqa: F403 from spiderweb.converters import * # noqa: F403
from spiderweb.exceptions import APIError, ConfigError, ParseError, GeneralException from spiderweb.default_responses import http403, http404, http500
from spiderweb.exceptions import APIError, ConfigError, ParseError, GeneralException, NoResponseError
from spiderweb.request import Request from spiderweb.request import Request
from spiderweb.response import HttpResponse, JsonResponse
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -23,7 +26,6 @@ def api_route(path):
setattr(func, "_routes", []) setattr(func, "_routes", [])
func._routes += [path] func._routes += [path]
return func return func
return outer return outer
@ -61,15 +63,15 @@ def convert_match_to_dict(match: dict):
} }
class APIServer(HTTPServer): class WebServer(HTTPServer):
def __init__(self, addr: str, port: int, custom_handler: Callable = None): def __init__(self, addr: str, port: int, custom_handler: Callable = None):
""" """
Create a new server on address, port. Port can be zero. Create a new server on address, port. Port can be zero.
> from simple_rpc_server import APIServer, APIError, api_route > from simple_rpc_server import WebServer, APIError, api_route
Create your handlers by inheriting from APIServer and tagging them with Create your handlers by inheriting from WebServer and tagging them with
@api_route("/path"). Alternately, you can use the APIServer() directly @api_route("/path"). Alternately, you can use the WebServer() directly
by calling `add_handler("path", function)`. by calling `add_handler("path", function)`.
Raise network errors by raising `APIError(code, message, description=None)`. Raise network errors by raising `APIError(code, message, description=None)`.
@ -83,8 +85,8 @@ class APIServer(HTTPServer):
server_address = (addr, port) server_address = (addr, port)
self.__addr = addr self.__addr = addr
# shim class that is an APIHandler # shim class that is an RequestHandler
class HandlerClass(APIHandler): class HandlerClass(RequestHandler):
pass pass
self.handler_class = custom_handler if custom_handler else HandlerClass self.handler_class = custom_handler if custom_handler else HandlerClass
@ -133,7 +135,7 @@ class APIServer(HTTPServer):
self.socket.close() self.socket.close()
class APIHandler(BaseHTTPRequestHandler): class RequestHandler(BaseHTTPRequestHandler):
# I can't help the naming convention of these because that's what # I can't help the naming convention of these because that's what
# BaseHTTPRequestHandler uses for some weird reason # BaseHTTPRequestHandler uses for some weird reason
_routes = {} _routes = {}
@ -173,6 +175,23 @@ class APIHandler(BaseHTTPRequestHandler):
) )
raise APIError(404, "No route found") raise APIError(404, "No route found")
def get_error_route(self, code: int) -> Callable:
try:
view = globals()[f"http{code}"]
return view
except KeyError:
return http500
def fire_response(self, resp: HttpResponse):
self.send_response(resp.status_code)
content = resp.render()
self.send_header("Content-Length", str(len(content)))
if resp.headers:
for key, value in resp.headers.items():
self.send_header(key, value)
self.end_headers()
self.wfile.write(bytes(content, "utf-8"))
def handle_request(self, request): def handle_request(self, request):
try: try:
request.url = urlparse.urlparse(request.path) request.url = urlparse.urlparse(request.path)
@ -188,22 +207,19 @@ class APIHandler(BaseHTTPRequestHandler):
if handler: if handler:
try: try:
response = handler(request, **additional_args) resp = handler(request, **additional_args)
self.send_response(200) if resp is None:
if response is None: raise NoResponseError(f"View {handler} returned None.")
response = "" if isinstance(resp, dict):
if isinstance(response, dict): self.fire_response(JsonResponse(data=resp))
response = json.dumps(response)
response = bytes(str(response), "utf-8")
self.send_header("Content-Length", str(len(response)))
self.end_headers()
self.wfile.write(response)
except APIError: except APIError:
raise raise
except ConnectionAbortedError as e: except ConnectionAbortedError as e:
log.error(f"GET {self.path} : {e}") log.error(f"GET {self.path} : {e}")
except Exception as e: except Exception as e:
raise APIError(500, str(e)) log.error(e.__traceback__)
self.fire_response(self.get_error_route(500)(self, request))
else: else:
raise APIError(404) raise APIError(404)
except APIError as e: except APIError as e:

View File

@ -1,5 +1,3 @@
from typing import Optional, NoReturn
from spiderweb.request import Request from spiderweb.request import Request
from spiderweb.response import HttpResponse from spiderweb.response import HttpResponse
@ -10,22 +8,22 @@ class SpiderwebMiddleware:
(optional!) methods: (optional!) methods:
process_request(self, request) -> None or Response process_request(self, request) -> None or Response
process_response(self, request, response) -> None process_response(self, request, resp) -> None
Middleware can be used to modify requests and responses in a variety of ways. Middleware can be used to modify requests and responses in a variety of ways.
If one of the two methods is not defined, the request or response will be passed If one of the two methods is not defined, the request or resp will be passed
through unmodified. through unmodified.
If `process_request` returns If `process_request` returns a HttpResponse, the request will be short-circuited
and the response will be returned immediately. `process_response` will not be called.
""" """
def process_request(self, request: Request) -> HttpResponse | None: def process_request(self, request: Request) -> HttpResponse | None:
# example of a middleware that sets a flag on the request # example of a middleware that sets a flag on the request
request.spiderweb = True request.spiderweb = True
def process_response(self, request: Request, response: HttpResponse) -> HttpResponse | None:
def process_response(self, request: Request, response: HttpResponse) -> NoReturn: # example of a middleware that sets a header on the resp
# example of a middleware that sets a header on the response
if hasattr(request, 'spiderweb'): if hasattr(request, 'spiderweb'):
response['X-Spiderweb'] = 'true' response.headers['X-Spiderweb'] = 'true'
return response return response

View File

@ -1,9 +1,34 @@
import json
from typing import Any
class HttpResponse: class HttpResponse:
... def __init__(
self,
content: str = None,
data: dict[str, Any] = None,
status_code: int = 200,
headers=None,
):
self.content = content
self.data = data
self.status_code = status_code
self.headers = headers if headers else {}
def __str__(self):
return self.content
def render(self) -> str:
raise NotImplemented
class JsonResponse(HttpResponse): class JsonResponse(HttpResponse):
... def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.headers["Content-Type"] = "application/json"
def render(self) -> str:
return json.dumps(self.data)
class RedirectResponse(HttpResponse): class RedirectResponse(HttpResponse):