Akashmj22122002's picture
Upload folder using huggingface_hub
14edff4 verified
"""Encrypted Session wrapper for secure conversation storage.
This module provides transparent encryption for session storage with automatic
expiration of old data. When TTL expires, expired items are silently skipped.
Usage::
from agents.extensions.memory import EncryptedSession, SQLAlchemySession
# Create underlying session (e.g. SQLAlchemySession)
underlying_session = SQLAlchemySession.from_url(
session_id="user-123",
url="postgresql+asyncpg://app:secret@db.example.com/agents",
create_tables=True,
)
# Wrap with encryption and TTL-based expiration
session = EncryptedSession(
session_id="user-123",
underlying_session=underlying_session,
encryption_key="your-encryption-key",
ttl=600, # 10 minutes
)
await Runner.run(agent, "Hello", session=session)
"""
from __future__ import annotations
import base64
import json
from typing import Any, cast
from cryptography.fernet import Fernet, InvalidToken
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.kdf.hkdf import HKDF
from typing_extensions import Literal, TypedDict, TypeGuard
from ...items import TResponseInputItem
from ...memory.session import SessionABC
class EncryptedEnvelope(TypedDict):
    """TypedDict for encrypted message envelopes stored in the underlying session.

    The envelope is what gets written to the wrapped store in place of the
    plaintext item; the item itself travels encrypted inside ``payload``.
    """

    # Marker flagging the dict as an envelope; `_is_encrypted_envelope` checks it equals 1.
    __enc__: Literal[1]
    # Envelope format version (`EncryptedSession` currently writes 1).
    v: int
    # Key identifier (`EncryptedSession` currently writes "hkdf-v1");
    # presumably names the key-derivation scheme -- confirm before relying on it.
    kid: str
    # Fernet token, as text, holding the JSON-serialized original item.
    payload: str
def _ensure_fernet_key_bytes(master_key: str) -> bytes:
    """Normalize the configured master key into raw bytes suitable for HKDF input.

    Args:
        master_key: Either a Fernet-style key (URL-safe base64 that decodes to
            exactly 32 bytes) or an arbitrary secret string.

    Returns:
        The 32 decoded bytes when ``master_key`` is a Fernet-style key,
        otherwise the UTF-8 encoding of ``master_key``.

    Raises:
        ValueError: If ``master_key`` is empty.
    """
    if not master_key:
        raise ValueError("encryption_key not set; required for EncryptedSession.")

    try:
        decoded = base64.urlsafe_b64decode(master_key)
    except ValueError:
        # Covers both malformed base64 (binascii.Error subclasses ValueError)
        # and non-ASCII text: the value is a raw secret, not a Fernet key.
        decoded = None

    if decoded is not None and len(decoded) == 32:
        return decoded

    # Anything that is not a 32-byte base64 key is used verbatim as key material.
    return master_key.encode("utf-8")
def _derive_session_fernet_key(master_key_bytes: bytes, session_id: str) -> Fernet:
    """Build the Fernet cipher for a single session from the master key material.

    HKDF-SHA256 expands ``master_key_bytes`` into 32 bytes, using the UTF-8
    encoded ``session_id`` as salt and a fixed ``info`` label, so each session
    ID yields its own key. The result is base64-encoded into the form
    ``Fernet`` expects.
    """
    salt = session_id.encode("utf-8")
    key_material = HKDF(
        algorithm=hashes.SHA256(),
        length=32,
        salt=salt,
        info=b"agents.session-store.hkdf.v1",
    ).derive(master_key_bytes)
    fernet_key = base64.urlsafe_b64encode(key_material)
    return Fernet(fernet_key)
def _to_json_bytes(obj: Any) -> bytes:
    """Serialize ``obj`` to compact, UTF-8 encoded JSON.

    Non-ASCII characters are written as-is rather than escaped, and values the
    JSON encoder cannot handle natively are converted with ``str``.
    """
    text = json.dumps(
        obj,
        default=str,
        separators=(",", ":"),
        ensure_ascii=False,
    )
    return text.encode("utf-8")
def _from_json_bytes(data: bytes) -> Any:
    """Decode UTF-8 encoded JSON bytes back into the Python value they describe."""
    text = data.decode("utf-8")
    return json.loads(text)
def _is_encrypted_envelope(item: object) -> TypeGuard[EncryptedEnvelope]:
    """Type guard reporting whether ``item`` has the shape of an encrypted envelope.

    An envelope is a dict whose ``__enc__`` entry equals 1 and which also has
    ``payload``, ``kid`` and ``v`` keys. The values of those keys are not
    inspected.
    """
    if not isinstance(item, dict):
        return False

    has_marker = item.get("__enc__") == 1
    if not has_marker:
        return False

    return all(field in item for field in ("payload", "kid", "v"))
class EncryptedSession(SessionABC):
    """Encrypted wrapper for Session implementations with TTL-based expiration.

    This class wraps any SessionABC implementation to provide transparent
    encryption/decryption of stored items using Fernet encryption with
    per-session key derivation and automatic expiration of old data.
    When items expire (exceed TTL), they are silently skipped during retrieval.

    Attributes that are not defined on the wrapper itself are looked up on the
    underlying session.

    Note: Expired tokens are rejected based on the system clock of the application server.
    To avoid valid tokens being rejected due to clock drift, ensure all servers in
    your environment are synchronized using NTP.
    """

    def __init__(
        self,
        session_id: str,
        underlying_session: SessionABC,
        encryption_key: str,
        ttl: int = 600,
    ) -> None:
        """
        Args:
            session_id: ID for this session
            underlying_session: The real session store (e.g. SQLiteSession, SQLAlchemySession)
            encryption_key: Master key (Fernet key or raw secret)
            ttl: Token time-to-live in seconds (default 10 min)

        Raises:
            ValueError: If ``encryption_key`` is empty.
        """
        self.session_id = session_id
        self.underlying_session = underlying_session
        self.ttl = ttl

        master = _ensure_fernet_key_bytes(encryption_key)
        self.cipher = _derive_session_fernet_key(master, session_id)
        # Metadata written into every envelope produced by this wrapper.
        self._kid = "hkdf-v1"
        self._ver = 1

    def __getattr__(self, name: str) -> Any:
        """Delegate attributes missing on the wrapper to the underlying session.

        Raises:
            AttributeError: If ``name`` is ``underlying_session`` itself (i.e.
                the wrapper has no wrapped session yet), or if the underlying
                session does not have the attribute either.
        """
        # __getattr__ only runs after normal lookup has failed. If the wrapped
        # session is what is missing (instance created through __new__ by
        # copy/pickle, or __init__ failed early), reading
        # self.underlying_session below would re-enter this method forever and
        # end in RecursionError instead of a clean AttributeError.
        if name == "underlying_session":
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            )
        return getattr(self.underlying_session, name)

    def _wrap(self, item: TResponseInputItem) -> EncryptedEnvelope:
        """Encrypt ``item`` and return the envelope to store in its place.

        The item is first reduced to a plain mapping (the dict itself,
        ``model_dump()``, ``__dict__`` or ``dict(item)``, in that order of
        preference), serialized to JSON and encrypted with the session cipher.
        """
        if isinstance(item, dict):
            payload = item
        elif hasattr(item, "model_dump"):
            payload = item.model_dump()
        elif hasattr(item, "__dict__"):
            payload = item.__dict__
        else:
            payload = dict(item)

        token = self.cipher.encrypt(_to_json_bytes(payload)).decode("utf-8")
        return {"__enc__": 1, "v": self._ver, "kid": self._kid, "payload": token}

    def _unwrap(self, item: TResponseInputItem | EncryptedEnvelope) -> TResponseInputItem | None:
        """Decrypt a stored item.

        Returns:
            The decrypted item; ``item`` unchanged when it is not an encrypted
            envelope; or ``None`` when the token is rejected by Fernet
            (``InvalidToken``, which includes tokens older than ``ttl``).
        """
        if not _is_encrypted_envelope(item):
            # Plaintext items already present in the store pass straight through.
            return cast(TResponseInputItem, item)

        try:
            token = item["payload"].encode("utf-8")
            plaintext = self.cipher.decrypt(token, ttl=self.ttl)
            return cast(TResponseInputItem, _from_json_bytes(plaintext))
        except (InvalidToken, KeyError):
            return None

    async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
        """Return the readable items of the session.

        ``limit`` is forwarded to the underlying session; items that cannot be
        decrypted are dropped afterwards, so fewer than ``limit`` items may be
        returned.
        """
        encrypted_items = await self.underlying_session.get_items(limit)

        valid_items: list[TResponseInputItem] = []
        for enc in encrypted_items:
            item = self._unwrap(enc)
            if item is not None:
                valid_items.append(item)
        return valid_items

    async def add_items(self, items: list[TResponseInputItem]) -> None:
        """Encrypt ``items`` and append the envelopes to the underlying session."""
        wrapped: list[EncryptedEnvelope] = [self._wrap(it) for it in items]
        await self.underlying_session.add_items(cast(list[TResponseInputItem], wrapped))

    async def pop_item(self) -> TResponseInputItem | None:
        """Pop items from the underlying session until one can be decrypted.

        Items that fail to decrypt are discarded (they are not put back).
        Returns ``None`` once the underlying session yields a falsy value.
        """
        while True:
            enc = await self.underlying_session.pop_item()
            if not enc:
                return None

            item = self._unwrap(enc)
            if item is not None:
                return item

    async def clear_session(self) -> None:
        """Clear the underlying session."""
        await self.underlying_session.clear_session()