session middleware and databases!

This commit is contained in:
Joe Kaufeld 2024-08-25 00:04:29 -04:00
parent 173325731b
commit 9d4dffb358
15 changed files with 314 additions and 20 deletions

View File

@ -15,6 +15,7 @@ from spiderweb.response import (
app = SpiderwebRouter( app = SpiderwebRouter(
templates_dirs=["templates"], templates_dirs=["templates"],
middleware=[ middleware=[
"spiderweb.middleware.sessions.SessionMiddleware",
"spiderweb.middleware.csrf.CSRFMiddleware", "spiderweb.middleware.csrf.CSRFMiddleware",
"example_middleware.TestMiddleware", "example_middleware.TestMiddleware",
"example_middleware.RedirectMiddleware", "example_middleware.RedirectMiddleware",
@ -72,6 +73,15 @@ def form(request: CommentForm):
return TemplateResponse(request, "form.html") return TemplateResponse(request, "form.html")
@app.route("/session")
def session(request):
if "test" not in request.SESSION:
request.SESSION["test"] = 0
else:
request.SESSION["test"] += 1
return HttpResponse(body=f"Session test: {request.SESSION['test']}")
@app.route("/cookies") @app.route("/cookies")
def cookies(request): def cookies(request):
print("request.COOKIES: ", request.COOKIES) print("request.COOKIES: ", request.COOKIES)

2
poetry.lock generated
View File

@ -641,4 +641,4 @@ files = [
[metadata] [metadata]
lock-version = "2.0" lock-version = "2.0"
python-versions = "^3.11" python-versions = "^3.11"
content-hash = "4fa8ab616be6891780300d4d66e9fb936aeac75bd3448e04073b2835acf9aadd" content-hash = "84633fc94c48c2a05b5ec77367ad29f327be1dc249a6e4cb76b50ebbe14739b5"

View File

@ -1,6 +1,10 @@
from peewee import DatabaseProxy
DEFAULT_ALLOWED_METHODS = ["GET"] DEFAULT_ALLOWED_METHODS = ["GET"]
DEFAULT_ENCODING = "ISO-8859-1" DEFAULT_ENCODING = "ISO-8859-1"
__version__ = "0.10.0" __version__ = "0.10.0"
# https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Set-Cookie # https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Set-Cookie
REGEX_COOKIE_NAME = r"^[a-zA-Z0-9\s\(\)<>@,;:\/\\\[\]\?=\{\}\"\t]*$" REGEX_COOKIE_NAME = r"^[a-zA-Z0-9\s\(\)<>@,;:\/\\\[\]\?=\{\}\"\t]*$"
DATABASE_PROXY = DatabaseProxy()

View File

@ -0,0 +1,98 @@
from peewee import Model, Field, SchemaManager, DatabaseProxy
from spiderweb.constants import DATABASE_PROXY
class MigrationsNeeded(ExceptionGroup): ...
class MigrationRequired(Exception): ...
class SpiderwebModel(Model):
@classmethod
def check_for_needed_migration(cls):
current_model_fields: dict[str, Field] = cls._meta.fields
current_db_fields = {
c.name: {
"data_type": c.data_type,
"null": c.null,
"primary_key": c.primary_key,
"default": c.default,
}
for c in cls._meta.database.get_columns(cls._meta.table_name)
}
problems = []
s = SchemaManager(cls, cls._meta.database)
ctx = s._create_context()
for field_name, field_obj in current_model_fields.items():
db_version = current_db_fields.get(field_obj.column_name)
if not db_version:
problems.append(
MigrationRequired(f"Field {field_name} not found in DB.")
)
continue
if field_obj.field_type == "VARCHAR":
field_obj.max_length = field_obj.max_length or 255
if (
cls._meta.fields[field_name].ddl_datatype(ctx).sql
!= db_version["data_type"]
):
problems.append(
MigrationRequired(
f"CharField `{field_name}` has changed the field type."
)
)
else:
if (
cls._meta.database.get_context_options()["field_types"][
field_obj.field_type
]
!= db_version["data_type"]
):
problems.append(
MigrationRequired(
f"Field `{field_name}` has changed the field type."
)
)
if field_obj.null != db_version["null"]:
problems.append(
MigrationRequired(
f"Field `{field_name}` has changed the nullability."
)
)
if field_obj.__class__.__name__ == "BooleanField":
if field_obj.default == False and db_version["default"] not in (
False,
None,
0,
):
problems.append(
MigrationRequired(
f"BooleanField `{field_name}` has changed the default value."
)
)
elif field_obj.default == True and db_version["default"] not in (
True,
1,
):
problems.append(
MigrationRequired(
f"BooleanField `{field_name}` has changed the default value."
)
)
else:
if field_obj.default != db_version["default"]:
problems.append(
MigrationRequired(
f"Field `{field_name}` has changed the default value."
)
)
if problems:
raise MigrationsNeeded(f"The model {cls} requires migrations.", problems)
class Meta:
database = DATABASE_PROXY

View File

@ -1,8 +1,8 @@
from pydantic import EmailStr from pydantic import EmailStr
from spiderweb.middleware.pydantic import SpiderwebModel from spiderweb.middleware.pydantic import RequestModel
class CommentForm(SpiderwebModel): class CommentForm(RequestModel):
email: EmailStr email: EmailStr
comment: str comment: str

View File

@ -1,6 +1,10 @@
class SpiderwebException(Exception): class SpiderwebException(Exception):
# parent error class; all child exceptions should inherit from this # parent error class; all child exceptions should inherit from this
def __str__(self): def __str__(self):
name = self.__class__.__name__
msg = self.args[0] if len(self.args) > 0 else ""
if msg:
return f"{name}() - {msg}"
return f"{self.__class__.__name__}()" return f"{self.__class__.__name__}()"

View File

@ -16,7 +16,7 @@ class SpiderwebRequestHandler(WSGIRequestHandler):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
class LocalServerMiddleware: class LocalServerMixin:
"""Cannot be called on its own. Requires context of SpiderwebRouter.""" """Cannot be called on its own. Requires context of SpiderwebRouter."""
addr: str addr: str

View File

@ -8,9 +8,15 @@ from typing import Optional, Callable
from wsgiref.simple_server import WSGIServer from wsgiref.simple_server import WSGIServer
from jinja2 import Environment, FileSystemLoader from jinja2 import Environment, FileSystemLoader
from peewee import Database, SqliteDatabase
from spiderweb.middleware import MiddlewareMiddleware from spiderweb.middleware import MiddlewareMixin
from spiderweb.constants import DEFAULT_ENCODING, DEFAULT_ALLOWED_METHODS from spiderweb.constants import (
DATABASE_PROXY,
DEFAULT_ENCODING,
DEFAULT_ALLOWED_METHODS,
)
from spiderweb.db import SpiderwebModel
from spiderweb.default_views import * # noqa: F403 from spiderweb.default_views import * # noqa: F403
from spiderweb.exceptions import ( from spiderweb.exceptions import (
ConfigError, ConfigError,
@ -19,24 +25,23 @@ from spiderweb.exceptions import (
NoResponseError, NoResponseError,
SpiderwebNetworkException, SpiderwebNetworkException,
) )
from spiderweb.local_server import LocalServerMiddleware from spiderweb.local_server import LocalServerMixin
from spiderweb.request import Request from spiderweb.request import Request
from spiderweb.response import HttpResponse, TemplateResponse, JsonResponse from spiderweb.response import HttpResponse, TemplateResponse, JsonResponse
from spiderweb.routes import RoutesMiddleware from spiderweb.routes import RoutesMixin
from spiderweb.secrets import FernetMiddleware from spiderweb.secrets import FernetMixin
from spiderweb.utils import get_http_status_by_code from spiderweb.utils import get_http_status_by_code
file_logger = logging.getLogger(__name__) file_logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
class SpiderwebRouter( class SpiderwebRouter(LocalServerMixin, MiddlewareMixin, RoutesMixin, FernetMixin):
LocalServerMiddleware, MiddlewareMiddleware, RoutesMiddleware, FernetMiddleware
):
def __init__( def __init__(
self, self,
addr: str = None, addr: str = None,
port: int = None, port: int = None,
db: Optional[Database] = None,
templates_dirs: list[str] = None, templates_dirs: list[str] = None,
middleware: list[str] = None, middleware: list[str] = None,
append_slash: bool = False, append_slash: bool = False,
@ -44,6 +49,12 @@ class SpiderwebRouter(
routes: list[list[str | Callable | dict]] = None, routes: list[list[str | Callable | dict]] = None,
error_routes: dict[str, Callable] = None, error_routes: dict[str, Callable] = None,
secret_key: str = None, secret_key: str = None,
session_max_age=60 * 60 * 24 * 14, # 2 weeks
session_cookie_name="swsession",
session_cookie_secure=False, # should be true if serving over HTTPS
session_cookie_http_only=True,
session_cookie_same_site="lax",
session_cookie_path="/",
log=None, log=None,
): ):
self._routes = {} self._routes = {}
@ -59,9 +70,17 @@ class SpiderwebRouter(
self.middleware = middleware if middleware else [] self.middleware = middleware if middleware else []
self.secret_key = secret_key if secret_key else self.generate_key() self.secret_key = secret_key if secret_key else self.generate_key()
# session middleware
self.session_max_age = session_max_age
self.session_cookie_name = session_cookie_name
self.session_cookie_secure = session_cookie_secure
self.session_cookie_http_only = session_cookie_http_only
self.session_cookie_same_site = session_cookie_same_site
self.session_cookie_path = session_cookie_path
self.DEFAULT_ENCODING = DEFAULT_ENCODING self.DEFAULT_ENCODING = DEFAULT_ENCODING
self.DEFAULT_ALLOWED_METHODS = DEFAULT_ALLOWED_METHODS self.DEFAULT_ALLOWED_METHODS = DEFAULT_ALLOWED_METHODS
self.log = log if log else file_logger self.log: logging.Logger = log if log else file_logger
# for using .start() and .stop() # for using .start() and .stop()
self._thread: Optional[Thread] = None self._thread: Optional[Thread] = None
@ -71,6 +90,13 @@ class SpiderwebRouter(
self.init_fernet() self.init_fernet()
self.init_middleware() self.init_middleware()
self.db = db or SqliteDatabase(self.BASE_DIR / "spiderweb.db")
# give the models the db connection
DATABASE_PROXY.initialize(self.db)
self.db.create_tables(SpiderwebModel.__subclasses__())
for model in SpiderwebModel.__subclasses__():
model.check_for_needed_migration()
if self.routes: if self.routes:
self.add_routes() self.add_routes()

View File

@ -2,13 +2,14 @@ from typing import Callable, ClassVar
from .base import SpiderwebMiddleware as SpiderwebMiddleware from .base import SpiderwebMiddleware as SpiderwebMiddleware
from .csrf import CSRFMiddleware as CSRFMiddleware from .csrf import CSRFMiddleware as CSRFMiddleware
from .sessions import SessionMiddleware as SessionMiddleware
from ..exceptions import ConfigError, UnusedMiddleware from ..exceptions import ConfigError, UnusedMiddleware
from ..request import Request from ..request import Request
from ..response import HttpResponse from ..response import HttpResponse
from ..utils import import_by_string from ..utils import import_by_string
class MiddlewareMiddleware: class MiddlewareMixin:
"""Cannot be called on its own. Requires context of SpiderwebRouter.""" """Cannot be called on its own. Requires context of SpiderwebRouter."""
middleware: list[ClassVar] middleware: list[ClassVar]

View File

@ -7,7 +7,7 @@ from spiderweb.request import Request
from spiderweb.response import JsonResponse from spiderweb.response import JsonResponse
class SpiderwebModel(BaseModel, Request): class RequestModel(BaseModel, Request):
# type hinting shenanigans that allow us to annotate Request objects # type hinting shenanigans that allow us to annotate Request objects
# with the pydantic models we want to validate them with, but doesn't # with the pydantic models we want to validate them with, but doesn't
# break the Request object's ability to be used as a Request object # break the Request object's ability to be used as a Request object

View File

@ -0,0 +1,117 @@
from datetime import datetime, timedelta
import json
from peewee import CharField, TextField, DateTimeField, BooleanField
from spiderweb.middleware import SpiderwebMiddleware
from spiderweb.request import Request
from spiderweb.response import HttpResponse
from spiderweb.db import SpiderwebModel
from spiderweb.utils import generate_key, is_jsonable
class Session(SpiderwebModel):
session_key = CharField(max_length=64)
user_id = CharField(max_length=64, null=True)
is_authenticated = BooleanField(default=False)
session_data = TextField()
created_at = DateTimeField()
last_active = DateTimeField()
ip_address = CharField(max_length=30)
user_agent = TextField()
class SessionMiddleware(SpiderwebMiddleware):
def process_request(self, request: Request):
existing_session = (
Session.select()
.where(
Session.session_key
== request.COOKIES.get(self.server.session_cookie_name),
Session.ip_address == request.META.get("client_address"),
Session.user_agent == request.headers.get("HTTP_USER_AGENT"),
)
.first()
)
new_session = False
if not existing_session:
new_session = True
elif datetime.now() - existing_session.created_at > timedelta(
seconds=self.server.session_max_age
):
existing_session.delete_instance()
new_session = True
if new_session:
request.SESSION = {}
request._session["id"] = generate_key()
request._session["new_session"] = True
return
request.SESSION = json.loads(existing_session.session_data)
request._session["id"] = existing_session.session_key
existing_session.save()
def process_response(self, request: Request, response: HttpResponse):
cookie_settings = {
"max_age": self.server.session_max_age,
"same_site": self.server.session_cookie_same_site,
"http_only": self.server.session_cookie_http_only,
"secure": self.server.session_cookie_secure
or request.META.get("HTTPS", False),
"path": self.server.session_cookie_path,
}
# if a new session has been requested, ignore everything else and make that happen
if request._session["new_session"]:
# we generated a new one earlier, so we can use it now
session_key = request._session["id"]
response.set_cookie(
self.server.session_cookie_name,
session_key,
**cookie_settings,
)
session = Session(
session_key=session_key,
session_data=json.dumps(request.SESSION),
created_at=datetime.now(),
last_active=datetime.now(),
ip_address=request.META.get("client_address"),
user_agent=request.headers.get("HTTP_USER_AGENT"),
)
session.save()
return
# Otherwise, we can save the one we already have.
session_key = request._session["id"]
# update the session expiration time
response.set_cookie(
self.server.session_cookie_name,
session_key,
**cookie_settings,
)
session = (
Session.select()
.where(
Session.session_key == session_key,
Session.ip_address == request.META.get("client_address"),
Session.user_agent == request.headers.get("HTTP_USER_AGENT"),
)
.first()
)
if not session:
if not is_jsonable(request.SESSION):
raise ValueError("Session data is not JSON serializable.")
session = Session(
session_key=session_key,
session_data=json.dumps(request.SESSION),
created_at=datetime.now(),
last_active=datetime.now(),
ip_address=request.META.get("client_address"),
user_agent=request.META.get("HTTP_USER_AGENT"),
)
else:
session.session_data = json.dumps(request.SESSION)
session.last_active = datetime.now()
session.save()

View File

@ -2,6 +2,7 @@ import json
from urllib.parse import urlparse from urllib.parse import urlparse
from spiderweb.constants import DEFAULT_ENCODING from spiderweb.constants import DEFAULT_ENCODING
from spiderweb.utils import get_client_address
class Request: class Request:
@ -27,6 +28,9 @@ class Request:
self.POST = {} self.POST = {}
self.META = {} self.META = {}
self.COOKIES = {} self.COOKIES = {}
# only used for the session middleware
self.SESSION = {}
self._session: dict = {"new_session": False, "id": None}
# only used for the pydantic middleware and only on POST requests # only used for the pydantic middleware and only on POST requests
self.validated_data = {} self.validated_data = {}
@ -50,6 +54,8 @@ class Request:
self.headers[k] = v self.headers[k] = v
def populate_meta(self) -> None: def populate_meta(self) -> None:
# all caps fields are from WSGI, lowercase names
# are custom
fields = [ fields = [
"SERVER_PROTOCOL", "SERVER_PROTOCOL",
"SERVER_SOFTWARE", "SERVER_SOFTWARE",
@ -66,6 +72,7 @@ class Request:
] ]
for f in fields: for f in fields:
self.META[f] = self.environ.get(f) self.META[f] = self.environ.get(f)
self.META["client_address"] = get_client_address(self.environ)
def populate_cookies(self) -> None: def populate_cookies(self) -> None:
if cookies := self.environ.get("HTTP_COOKIE"): if cookies := self.environ.get("HTTP_COOKIE"):

View File

@ -24,7 +24,7 @@ class DummyRedirectRoute:
return RedirectResponse(self.location) return RedirectResponse(self.location)
class RoutesMiddleware: class RoutesMixin:
"""Cannot be called on its own. Requires context of SpiderwebRouter.""" """Cannot be called on its own. Requires context of SpiderwebRouter."""
# ones that start with underscores are the compiled versions, non-underscores # ones that start with underscores are the compiled versions, non-underscores

View File

@ -3,7 +3,7 @@ from cryptography.fernet import Fernet
from spiderweb.constants import DEFAULT_ENCODING from spiderweb.constants import DEFAULT_ENCODING
class FernetMiddleware: class FernetMixin:
"""Cannot be called on its own. Requires context of SpiderwebRouter.""" """Cannot be called on its own. Requires context of SpiderwebRouter."""
fernet: Fernet fernet: Fernet

View File

@ -1,9 +1,16 @@
import json
import secrets
import string
from http import HTTPStatus from http import HTTPStatus
from typing import Optional from typing import Optional, TYPE_CHECKING
if TYPE_CHECKING:
from spiderweb.request import Request from spiderweb.request import Request
VALID_CHARS = string.ascii_letters + string.digits
def import_by_string(name): def import_by_string(name):
# https://stackoverflow.com/a/547867 # https://stackoverflow.com/a/547867
components = name.split(".") components = name.split(".")
@ -31,8 +38,28 @@ def get_http_status_by_code(code: int) -> Optional[str]:
return f"{resp.value} {resp.phrase}" return f"{resp.value} {resp.phrase}"
def is_form_request(request: Request) -> bool: def is_form_request(request: "Request") -> bool:
return ( return (
"Content-Type" in request.headers "Content-Type" in request.headers
and request.headers["Content-Type"] == "application/x-www-form-urlencoded" and request.headers["Content-Type"] == "application/x-www-form-urlencoded"
) )
# https://stackoverflow.com/a/7839576
def get_client_address(environ: dict) -> str:
try:
return environ["HTTP_X_FORWARDED_FOR"].split(",")[-1].strip()
except KeyError:
return environ.get("REMOTE_ADDR", "unknown")
def generate_key(length=64):
return "".join(secrets.choice(VALID_CHARS) for _ in range(length))
def is_jsonable(data: str) -> bool:
try:
json.dumps(data)
return True
except (TypeError, OverflowError):
return False