diff --git a/example.py b/example.py index 88052d6..c2543fe 100644 --- a/example.py +++ b/example.py @@ -15,6 +15,7 @@ from spiderweb.response import ( app = SpiderwebRouter( templates_dirs=["templates"], middleware=[ + "spiderweb.middleware.sessions.SessionMiddleware", "spiderweb.middleware.csrf.CSRFMiddleware", "example_middleware.TestMiddleware", "example_middleware.RedirectMiddleware", @@ -72,6 +73,15 @@ def form(request: CommentForm): return TemplateResponse(request, "form.html") +@app.route("/session") +def session(request): + if "test" not in request.SESSION: + request.SESSION["test"] = 0 + else: + request.SESSION["test"] += 1 + return HttpResponse(body=f"Session test: {request.SESSION['test']}") + + @app.route("/cookies") def cookies(request): print("request.COOKIES: ", request.COOKIES) diff --git a/poetry.lock b/poetry.lock index b5ea28c..6a2e005 100644 --- a/poetry.lock +++ b/poetry.lock @@ -641,4 +641,4 @@ files = [ [metadata] lock-version = "2.0" python-versions = "^3.11" -content-hash = "4fa8ab616be6891780300d4d66e9fb936aeac75bd3448e04073b2835acf9aadd" +content-hash = "84633fc94c48c2a05b5ec77367ad29f327be1dc249a6e4cb76b50ebbe14739b5" diff --git a/spiderweb/constants.py b/spiderweb/constants.py index ecbd6bb..4495ae5 100644 --- a/spiderweb/constants.py +++ b/spiderweb/constants.py @@ -1,6 +1,10 @@ +from peewee import DatabaseProxy + DEFAULT_ALLOWED_METHODS = ["GET"] DEFAULT_ENCODING = "ISO-8859-1" __version__ = "0.10.0" # https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Set-Cookie REGEX_COOKIE_NAME = r"^[a-zA-Z0-9\s\(\)<>@,;:\/\\\[\]\?=\{\}\"\t]*$" + +DATABASE_PROXY = DatabaseProxy() diff --git a/spiderweb/db.py b/spiderweb/db.py index e69de29..9eedf31 100644 --- a/spiderweb/db.py +++ b/spiderweb/db.py @@ -0,0 +1,98 @@ +from peewee import Model, Field, SchemaManager, DatabaseProxy + +from spiderweb.constants import DATABASE_PROXY + + +class MigrationsNeeded(ExceptionGroup): ... + + +class MigrationRequired(Exception): ... + + +class SpiderwebModel(Model): + + @classmethod + def check_for_needed_migration(cls): + current_model_fields: dict[str, Field] = cls._meta.fields + current_db_fields = { + c.name: { + "data_type": c.data_type, + "null": c.null, + "primary_key": c.primary_key, + "default": c.default, + } + for c in cls._meta.database.get_columns(cls._meta.table_name) + } + problems = [] + s = SchemaManager(cls, cls._meta.database) + ctx = s._create_context() + for field_name, field_obj in current_model_fields.items(): + db_version = current_db_fields.get(field_obj.column_name) + if not db_version: + problems.append( + MigrationRequired(f"Field {field_name} not found in DB.") + ) + continue + + if field_obj.field_type == "VARCHAR": + field_obj.max_length = field_obj.max_length or 255 + if ( + cls._meta.fields[field_name].ddl_datatype(ctx).sql + != db_version["data_type"] + ): + problems.append( + MigrationRequired( + f"CharField `{field_name}` has changed the field type." + ) + ) + else: + if ( + cls._meta.database.get_context_options()["field_types"][ + field_obj.field_type + ] + != db_version["data_type"] + ): + problems.append( + MigrationRequired( + f"Field `{field_name}` has changed the field type." + ) + ) + if field_obj.null != db_version["null"]: + problems.append( + MigrationRequired( + f"Field `{field_name}` has changed the nullability." + ) + ) + if field_obj.__class__.__name__ == "BooleanField": + if field_obj.default == False and db_version["default"] not in ( + False, + None, + 0, + ): + problems.append( + MigrationRequired( + f"BooleanField `{field_name}` has changed the default value." + ) + ) + elif field_obj.default == True and db_version["default"] not in ( + True, + 1, + ): + problems.append( + MigrationRequired( + f"BooleanField `{field_name}` has changed the default value." + ) + ) + else: + if field_obj.default != db_version["default"]: + problems.append( + MigrationRequired( + f"Field `{field_name}` has changed the default value." + ) + ) + + if problems: + raise MigrationsNeeded(f"The model {cls} requires migrations.", problems) + + class Meta: + database = DATABASE_PROXY diff --git a/spiderweb/example_validator.py b/spiderweb/example_validator.py index c791cd7..babb562 100644 --- a/spiderweb/example_validator.py +++ b/spiderweb/example_validator.py @@ -1,8 +1,8 @@ from pydantic import EmailStr -from spiderweb.middleware.pydantic import SpiderwebModel +from spiderweb.middleware.pydantic import RequestModel -class CommentForm(SpiderwebModel): +class CommentForm(RequestModel): email: EmailStr comment: str diff --git a/spiderweb/exceptions.py b/spiderweb/exceptions.py index 4969560..c826d40 100644 --- a/spiderweb/exceptions.py +++ b/spiderweb/exceptions.py @@ -1,6 +1,10 @@ class SpiderwebException(Exception): # parent error class; all child exceptions should inherit from this def __str__(self): + name = self.__class__.__name__ + msg = self.args[0] if len(self.args) > 0 else "" + if msg: + return f"{name}() - {msg}" return f"{self.__class__.__name__}()" diff --git a/spiderweb/local_server.py b/spiderweb/local_server.py index b1ad396..f869e6a 100644 --- a/spiderweb/local_server.py +++ b/spiderweb/local_server.py @@ -16,7 +16,7 @@ class SpiderwebRequestHandler(WSGIRequestHandler): super().__init__(*args, **kwargs) -class LocalServerMiddleware: +class LocalServerMixin: """Cannot be called on its own. Requires context of SpiderwebRouter.""" addr: str diff --git a/spiderweb/main.py b/spiderweb/main.py index daca021..b6d1efb 100644 --- a/spiderweb/main.py +++ b/spiderweb/main.py @@ -8,9 +8,15 @@ from typing import Optional, Callable from wsgiref.simple_server import WSGIServer from jinja2 import Environment, FileSystemLoader +from peewee import Database, SqliteDatabase -from spiderweb.middleware import MiddlewareMiddleware -from spiderweb.constants import DEFAULT_ENCODING, DEFAULT_ALLOWED_METHODS +from spiderweb.middleware import MiddlewareMixin +from spiderweb.constants import ( + DATABASE_PROXY, + DEFAULT_ENCODING, + DEFAULT_ALLOWED_METHODS, +) +from spiderweb.db import SpiderwebModel from spiderweb.default_views import * # noqa: F403 from spiderweb.exceptions import ( ConfigError, @@ -19,24 +25,23 @@ from spiderweb.exceptions import ( NoResponseError, SpiderwebNetworkException, ) -from spiderweb.local_server import LocalServerMiddleware +from spiderweb.local_server import LocalServerMixin from spiderweb.request import Request from spiderweb.response import HttpResponse, TemplateResponse, JsonResponse -from spiderweb.routes import RoutesMiddleware -from spiderweb.secrets import FernetMiddleware +from spiderweb.routes import RoutesMixin +from spiderweb.secrets import FernetMixin from spiderweb.utils import get_http_status_by_code file_logger = logging.getLogger(__name__) logging.basicConfig(level=logging.INFO) -class SpiderwebRouter( - LocalServerMiddleware, MiddlewareMiddleware, RoutesMiddleware, FernetMiddleware -): +class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixin): def __init__( self, addr: str = None, port: int = None, + db: Optional[Database] = None, templates_dirs: list[str] = None, middleware: list[str] = None, append_slash: bool = False, @@ -44,6 +49,12 @@ class SpiderwebRouter( routes: list[list[str | Callable | dict]] = None, error_routes: dict[str, Callable] = None, secret_key: str = None, + session_max_age=60 * 60 * 24 * 14, # 2 weeks + session_cookie_name="swsession", + session_cookie_secure=False, # should be true if serving over HTTPS + session_cookie_http_only=True, + session_cookie_same_site="lax", + session_cookie_path="/", log=None, ): self._routes = {} @@ -59,9 +70,17 @@ class SpiderwebRouter( self.middleware = middleware if middleware else [] self.secret_key = secret_key if secret_key else self.generate_key() + # session middleware + self.session_max_age = session_max_age + self.session_cookie_name = session_cookie_name + self.session_cookie_secure = session_cookie_secure + self.session_cookie_http_only = session_cookie_http_only + self.session_cookie_same_site = session_cookie_same_site + self.session_cookie_path = session_cookie_path + self.DEFAULT_ENCODING = DEFAULT_ENCODING self.DEFAULT_ALLOWED_METHODS = DEFAULT_ALLOWED_METHODS - self.log = log if log else file_logger + self.log: logging.Logger = log if log else file_logger # for using .start() and .stop() self._thread: Optional[Thread] = None @@ -71,6 +90,13 @@ class SpiderwebRouter( self.init_fernet() self.init_middleware() + self.db = db or SqliteDatabase(self.BASE_DIR / "spiderweb.db") + # give the models the db connection + DATABASE_PROXY.initialize(self.db) + self.db.create_tables(SpiderwebModel.__subclasses__()) + for model in SpiderwebModel.__subclasses__(): + model.check_for_needed_migration() + if self.routes: self.add_routes() diff --git a/spiderweb/middleware/__init__.py b/spiderweb/middleware/__init__.py index f3112e0..32c9419 100644 --- a/spiderweb/middleware/__init__.py +++ b/spiderweb/middleware/__init__.py @@ -2,13 +2,14 @@ from typing import Callable, ClassVar from .base import SpiderwebMiddleware as SpiderwebMiddleware from .csrf import CSRFMiddleware as CSRFMiddleware +from .sessions import SessionMiddleware as SessionMiddleware from ..exceptions import ConfigError, UnusedMiddleware from ..request import Request from ..response import HttpResponse from ..utils import import_by_string -class MiddlewareMiddleware: +class MiddlewareMixin: """Cannot be called on its own. Requires context of SpiderwebRouter.""" middleware: list[ClassVar] diff --git a/spiderweb/middleware/pydantic.py b/spiderweb/middleware/pydantic.py index c7be9f2..5258680 100644 --- a/spiderweb/middleware/pydantic.py +++ b/spiderweb/middleware/pydantic.py @@ -7,7 +7,7 @@ from spiderweb.request import Request from spiderweb.response import JsonResponse -class SpiderwebModel(BaseModel, Request): +class RequestModel(BaseModel, Request): # type hinting shenanigans that allow us to annotate Request objects # with the pydantic models we want to validate them with, but doesn't # break the Request object's ability to be used as a Request object diff --git a/spiderweb/middleware/sessions.py b/spiderweb/middleware/sessions.py new file mode 100644 index 0000000..a835a41 --- /dev/null +++ b/spiderweb/middleware/sessions.py @@ -0,0 +1,117 @@ +from datetime import datetime, timedelta +import json + +from peewee import CharField, TextField, DateTimeField, BooleanField + +from spiderweb.middleware import SpiderwebMiddleware +from spiderweb.request import Request +from spiderweb.response import HttpResponse +from spiderweb.db import SpiderwebModel +from spiderweb.utils import generate_key, is_jsonable + + +class Session(SpiderwebModel): + session_key = CharField(max_length=64) + user_id = CharField(max_length=64, null=True) + is_authenticated = BooleanField(default=False) + session_data = TextField() + created_at = DateTimeField() + last_active = DateTimeField() + ip_address = CharField(max_length=30) + user_agent = TextField() + + +class SessionMiddleware(SpiderwebMiddleware): + def process_request(self, request: Request): + existing_session = ( + Session.select() + .where( + Session.session_key + == request.COOKIES.get(self.server.session_cookie_name), + Session.ip_address == request.META.get("client_address"), + Session.user_agent == request.headers.get("HTTP_USER_AGENT"), + ) + .first() + ) + new_session = False + if not existing_session: + new_session = True + elif datetime.now() - existing_session.created_at > timedelta( + seconds=self.server.session_max_age + ): + existing_session.delete_instance() + new_session = True + + if new_session: + request.SESSION = {} + request._session["id"] = generate_key() + request._session["new_session"] = True + return + + request.SESSION = json.loads(existing_session.session_data) + request._session["id"] = existing_session.session_key + existing_session.save() + + def process_response(self, request: Request, response: HttpResponse): + cookie_settings = { + "max_age": self.server.session_max_age, + "same_site": self.server.session_cookie_same_site, + "http_only": self.server.session_cookie_http_only, + "secure": self.server.session_cookie_secure + or request.META.get("HTTPS", False), + "path": self.server.session_cookie_path, + } + + # if a new session has been requested, ignore everything else and make that happen + if request._session["new_session"]: + # we generated a new one earlier, so we can use it now + session_key = request._session["id"] + response.set_cookie( + self.server.session_cookie_name, + session_key, + **cookie_settings, + ) + session = Session( + session_key=session_key, + session_data=json.dumps(request.SESSION), + created_at=datetime.now(), + last_active=datetime.now(), + ip_address=request.META.get("client_address"), + user_agent=request.headers.get("HTTP_USER_AGENT"), + ) + session.save() + return + + # Otherwise, we can save the one we already have. + session_key = request._session["id"] + # update the session expiration time + response.set_cookie( + self.server.session_cookie_name, + session_key, + **cookie_settings, + ) + + session = ( + Session.select() + .where( + Session.session_key == session_key, + Session.ip_address == request.META.get("client_address"), + Session.user_agent == request.headers.get("HTTP_USER_AGENT"), + ) + .first() + ) + if not session: + if not is_jsonable(request.SESSION): + raise ValueError("Session data is not JSON serializable.") + session = Session( + session_key=session_key, + session_data=json.dumps(request.SESSION), + created_at=datetime.now(), + last_active=datetime.now(), + ip_address=request.META.get("client_address"), + user_agent=request.META.get("HTTP_USER_AGENT"), + ) + else: + session.session_data = json.dumps(request.SESSION) + session.last_active = datetime.now() + session.save() diff --git a/spiderweb/request.py b/spiderweb/request.py index 0ba7f95..6f95cde 100644 --- a/spiderweb/request.py +++ b/spiderweb/request.py @@ -2,6 +2,7 @@ import json from urllib.parse import urlparse from spiderweb.constants import DEFAULT_ENCODING +from spiderweb.utils import get_client_address class Request: @@ -27,6 +28,9 @@ class Request: self.POST = {} self.META = {} self.COOKIES = {} + # only used for the session middleware + self.SESSION = {} + self._session: dict = {"new_session": False, "id": None} # only used for the pydantic middleware and only on POST requests self.validated_data = {} @@ -50,6 +54,8 @@ class Request: self.headers[k] = v def populate_meta(self) -> None: + # all caps fields are from WSGI, lowercase names + # are custom fields = [ "SERVER_PROTOCOL", "SERVER_SOFTWARE", @@ -66,6 +72,7 @@ class Request: ] for f in fields: self.META[f] = self.environ.get(f) + self.META["client_address"] = get_client_address(self.environ) def populate_cookies(self) -> None: if cookies := self.environ.get("HTTP_COOKIE"): diff --git a/spiderweb/routes.py b/spiderweb/routes.py index 0a608b7..49fbcde 100644 --- a/spiderweb/routes.py +++ b/spiderweb/routes.py @@ -24,7 +24,7 @@ class DummyRedirectRoute: return RedirectResponse(self.location) -class RoutesMiddleware: +class RoutesMixin: """Cannot be called on its own. Requires context of SpiderwebRouter.""" # ones that start with underscores are the compiled versions, non-underscores diff --git a/spiderweb/secrets.py b/spiderweb/secrets.py index b0e26ae..31f86cd 100644 --- a/spiderweb/secrets.py +++ b/spiderweb/secrets.py @@ -3,7 +3,7 @@ from cryptography.fernet import Fernet from spiderweb.constants import DEFAULT_ENCODING -class FernetMiddleware: +class FernetMixin: """Cannot be called on its own. Requires context of SpiderwebRouter.""" fernet: Fernet diff --git a/spiderweb/utils.py b/spiderweb/utils.py index 4ae033f..42baf35 100644 --- a/spiderweb/utils.py +++ b/spiderweb/utils.py @@ -1,7 +1,14 @@ +import json +import secrets +import string from http import HTTPStatus -from typing import Optional +from typing import Optional, TYPE_CHECKING -from spiderweb.request import Request +if TYPE_CHECKING: + from spiderweb.request import Request + + +VALID_CHARS = string.ascii_letters + string.digits def import_by_string(name): @@ -31,8 +38,28 @@ def get_http_status_by_code(code: int) -> Optional[str]: return f"{resp.value} {resp.phrase}" -def is_form_request(request: Request) -> bool: +def is_form_request(request: "Request") -> bool: return ( "Content-Type" in request.headers and request.headers["Content-Type"] == "application/x-www-form-urlencoded" ) + + +# https://stackoverflow.com/a/7839576 +def get_client_address(environ: dict) -> str: + try: + return environ["HTTP_X_FORWARDED_FOR"].split(",")[-1].strip() + except KeyError: + return environ.get("REMOTE_ADDR", "unknown") + + +def generate_key(length=64): + return "".join(secrets.choice(VALID_CHARS) for _ in range(length)) + + +def is_jsonable(data: str) -> bool: + try: + json.dumps(data) + return True + except (TypeError, OverflowError): + return False