"""Route guards for ``litestar_permissions`` (module ``litestar_permissions.guards``)."""

from __future__ import annotations

from typing import TYPE_CHECKING, Any

from litestar.exceptions import NotAuthorizedException, PermissionDeniedException
from sqlalchemy import and_, or_, select

if TYPE_CHECKING:
    from litestar.connection import ASGIConnection
    from litestar.handlers import BaseRouteHandler


def _is_superuser(user: object) -> bool:
    """Return True if *user* has a truthy ``is_superuser`` or ``admin`` attribute.

    Missing attributes count as False. The result is coerced to ``bool`` so the
    return value honours the annotation even when a user model stores these
    flags as ints, strings, or ``None``.
    """
    return bool(getattr(user, "is_superuser", False) or getattr(user, "admin", False))


def require_permission(
    *permissions: str,
    resource_type_param: str | None = None,
    resource_id_param: str | None = None,
) -> Any:
    """Guard factory that checks if the current user has ALL of the specified permissions.

    Args:
        permissions: Permission codenames the user must have (e.g. "application:deploy").
            At least one is required.
        resource_type_param: Path param name that holds the resource type. If None,
            a resource ID was resolved, and a permissions config is registered, the
            resource type is taken from the prefix (text before ``:``) of the first
            permission codename that contains ``:``.
        resource_id_param: Path param name that holds the resource ID.

    Returns:
        An async Litestar guard ``(connection, route_handler) -> None``.

    Raises:
        ValueError: If no permission codenames are given. Without this check the
            guard's per-permission loop would never run and every authenticated
            user would be let through.

    The returned guard raises ``NotAuthorizedException`` when no user is found in
    the connection scope, and ``PermissionDeniedException`` when the permissions
    resolver or ``db_session`` is missing or when any permission check fails.
    When the registered config enables ``superuser_bypass``, superusers skip all
    checks.

    Usage:
        @get("/apps/{app_id}/deploy",
             guards=[require_permission("application:deploy", resource_id_param="app_id")])

    Note:
        Consumers must populate ``connection.scope["db_session"]`` with a
        per-request ``AsyncSession`` (e.g. via middleware or a dependency that
        writes to scope).
    """
    if not permissions:
        # Fail fast at route-definition time instead of building a fail-open guard.
        raise ValueError("require_permission() needs at least one permission codename")

    async def guard(connection: ASGIConnection, _: BaseRouteHandler) -> None:
        permissions_config = connection.app.state.get("permissions_config")
        user_key = permissions_config.user_key if permissions_config else "user"
        user = connection.scope.get(user_key)
        if user is None:
            raise NotAuthorizedException("Authentication required")

        # The superuser bypass only applies when a config is registered and enables it.
        if permissions_config and permissions_config.superuser_bypass and _is_superuser(user):
            return

        resolver = connection.app.state.get("permissions_resolver")
        if resolver is None:
            raise PermissionDeniedException("Permissions system not configured")

        db = connection.scope.get("db_session")
        if db is None:
            raise PermissionDeniedException("db_session not found in request scope")

        resource_type, resource_id = _resolve_resource_scope(
            connection, permissions, resource_type_param, resource_id_param, permissions_config
        )
        # ALL semantics: the first missing permission denies the request.
        for perm in permissions:
            if not await resolver.can(user.id, perm, resource_type, resource_id, db=db):
                raise PermissionDeniedException(f"Missing permission: {perm}")

    return guard


def _resolve_resource_scope(
    connection: ASGIConnection,
    permissions: tuple[str, ...],
    resource_type_param: str | None,
    resource_id_param: str | None,
    permissions_config: object | None,
) -> tuple[str | None, str | None]:
    """Work out which resource a permission check is scoped to.

    Only ``connection.path_params`` is consulted. The resource ID is read from
    ``resource_id_param`` when one is given. The resource type is read from
    ``resource_type_param`` when one is given. Otherwise, provided a truthy
    resource ID was found and ``permissions_config`` is truthy, the type is the
    text before the first ``:`` of the first permission codename containing ``:``.

    Returns:
        ``(resource_type, resource_id)``; either element may be ``None``.
    """
    # NOTE(review): path_params values may not be ``str`` for typed path params
    # (e.g. ``{app_id:int}``) despite the annotation - confirm with callers.
    found_id = connection.path_params.get(resource_id_param) if resource_id_param else None

    if resource_type_param:
        return connection.path_params.get(resource_type_param), found_id

    # permissions_config is only checked for truthiness here; its contents are unused.
    if not (found_id and permissions_config):
        return None, found_id

    inferred_type = next(
        (codename.partition(":")[0] for codename in permissions if ":" in codename),
        None,
    )
    return inferred_type, found_id


def require_role(
    *role_names: str,
    resource_type_param: str | None = None,
    resource_id_param: str | None = None,
) -> Any:
    """Guard factory that checks if the current user has ANY of the specified roles.

    A role assignment satisfies the guard when it is global (both
    ``resource_type`` and ``resource_id`` are NULL) or, when both path params
    resolve to a value, when it is scoped to exactly that resource. Unlike
    ``require_permission``, the resource type is never inferred; it is only read
    from ``resource_type_param``.

    Args:
        role_names: Role names (user must have at least one). At least one is required.
        resource_type_param: Path param for resource type.
        resource_id_param: Path param for resource ID.

    Returns:
        An async Litestar guard ``(connection, route_handler) -> None``.

    Raises:
        ValueError: If no role names are given. The resulting empty ``IN`` filter
            could never match, so the guard would deny every non-superuser.

    The returned guard raises ``NotAuthorizedException`` when no user is found in
    the connection scope, and ``PermissionDeniedException`` when ``db_session`` or
    the permissions models are missing or when no matching role assignment exists.

    Note:
        Consumers must populate ``connection.scope["db_session"]`` with a
        per-request ``AsyncSession`` (e.g. via middleware or a dependency that
        writes to scope).
    """
    if not role_names:
        # Fail fast at route-definition time instead of building a guard that always denies.
        raise ValueError("require_role() needs at least one role name")

    async def guard(connection: ASGIConnection, _: BaseRouteHandler) -> None:
        permissions_config = connection.app.state.get("permissions_config")
        user_key = permissions_config.user_key if permissions_config else "user"
        user = connection.scope.get(user_key)
        if user is None:
            raise NotAuthorizedException("Authentication required")

        # The superuser bypass only applies when a config is registered and enables it.
        if permissions_config and permissions_config.superuser_bypass and _is_superuser(user):
            return

        db = connection.scope.get("db_session")
        if db is None:
            raise PermissionDeniedException("db_session not found in request scope")

        # Mapping of model name -> ORM class, registered in app state.
        models = connection.app.state.get("permissions_models")
        if not models:
            raise PermissionDeniedException("Permissions system not configured")
        user_role_assignment = models["UserRoleAssignment"]
        role_model = models["Role"]

        resource_type = connection.path_params.get(resource_type_param) if resource_type_param else None
        resource_id = connection.path_params.get(resource_id_param) if resource_id_param else None

        stmt = (
            select(role_model.name)
            .join(user_role_assignment, user_role_assignment.role_id == role_model.id)
            .where(user_role_assignment.user_id == user.id)
            .where(role_model.name.in_(role_names))
        )

        # Global (unscoped) assignments always count.
        scope_filters = [
            and_(
                user_role_assignment.resource_type.is_(None),
                user_role_assignment.resource_id.is_(None),
            )
        ]
        # Resource-scoped assignments count only for the exact resource in the path.
        # Compare against None rather than truthiness so a falsy-but-valid ID
        # (e.g. an int path param of 0) is not silently ignored.
        if resource_type is not None and resource_id is not None:
            scope_filters.append(
                and_(
                    user_role_assignment.resource_type == resource_type,
                    user_role_assignment.resource_id == resource_id,
                )
            )
        stmt = stmt.where(or_(*scope_filters))

        result = await db.execute(stmt)
        if result.first() is None:
            raise PermissionDeniedException(f"Required role: {' or '.join(role_names)}")

    return guard