|
1 |
| -"""File-based implementation of session manager.""" |
| 1 | +"""Agent session manager implementation.""" |
2 | 2 |
|
3 | 3 | import logging
|
4 |
| -from typing import TYPE_CHECKING, Any, Optional |
| 4 | +from typing import TYPE_CHECKING |
5 | 5 |
|
6 | 6 | from ..agent.state import AgentState
|
7 |
| -from ..handlers.callback_handler import CompositeCallbackHandler |
8 |
| -from ..types.content import Message |
9 |
| -from .exceptions import SessionException |
10 |
| -from .file_session_dao import FileSessionDAO |
11 |
| -from .session_dao import SessionDAO |
| 7 | +from ..experimental.hooks.events import AgentInitializedEvent, MessageAddedEvent |
| 8 | +from ..telemetry.metrics import EventLoopMetrics |
| 9 | +from ..types.session import ( |
| 10 | + SessionType, |
| 11 | + create_session, |
| 12 | + session_agent_from_agent, |
| 13 | + session_message_from_message, |
| 14 | + session_message_to_message, |
| 15 | +) |
12 | 16 | from .session_manager import SessionManager
|
13 |
| -from .session_models import Session, SessionAgent, SessionMessage, SessionType |
| 17 | +from .session_repository import SessionRepository |
14 | 18 |
|
15 | 19 | logger = logging.getLogger(__name__)
|
16 | 20 |
|
17 | 21 | if TYPE_CHECKING:
|
18 |
| - from ..agent.agent import Agent |
| 22 | + pass |
19 | 23 |
|
| 24 | +DEFAULT_SESSION_AGENT_ID = "default" |
20 | 25 |
|
21 |
| -class AgentSessionManager(SessionManager): |
22 |
| - """Session manager for a single Agent. |
23 | 26 |
|
24 |
| - This implementation stores sessions as JSON files in a specified directory. |
25 |
| - Each session is stored in a separate file named by its session_id. |
26 |
| - """ |
| 27 | +class AgentSessionManager(SessionManager): |
| 28 | + """Session manager for persisting agent's in a Session.""" |
27 | 29 |
|
28 | 30 | def __init__(
|
29 | 31 | self,
|
30 | 32 | session_id: str,
|
31 |
| - session_dao: Optional[SessionDAO] = None, |
| 33 | + session_repository: SessionRepository, |
32 | 34 | ):
|
33 |
| - """Initialize the FileSessionManager.""" |
34 |
| - self.session_dao = session_dao or FileSessionDAO() |
| 35 | + """Initialize the AgentSessionManager.""" |
| 36 | + self.session_repository = session_repository |
35 | 37 | self.session_id = session_id
|
| 38 | + session = session_repository.read_session(session_id) |
| 39 | + # Create a session if it does not exist yet |
| 40 | + if session is None: |
| 41 | + logger.debug("session_id=<%s> | Session not found, creating new session.", self.session_id) |
| 42 | + session = create_session(session_id=session_id, session_type=SessionType.AGENT) |
| 43 | + session_repository.create_session(session) |
| 44 | + else: |
| 45 | + if session["session_type"] != SessionType.AGENT: |
| 46 | + raise ValueError(f"Invalid session type: {session.session_type}") |
| 47 | + |
| 48 | + self.session = session |
| 49 | + self._default_agent_initialized = False |
36 | 50 |
|
37 |
| - def append_message_to_agent_session(self, agent: "Agent", message: Message) -> None: |
| 51 | + def append_message(self, event: MessageAddedEvent) -> None: |
38 | 52 | """Append a message to the agent's session.
|
39 | 53 |
|
40 | 54 | Args:
|
41 |
| - agent: The agent whose session to update |
42 |
| - message: The message to append |
| 55 | + event: Event for a newly added Message |
43 | 56 | """
|
44 |
| - if agent.id is None: |
45 |
| - raise ValueError("`agent.id` must be set before appending message to session.") |
| 57 | + agent = event.agent |
| 58 | + message = event.message |
46 | 59 |
|
47 |
| - session_message = SessionMessage.from_dict(dict(message)) |
48 |
| - self.session_dao.create_message(self.session_id, agent.id, session_message) |
49 |
| - self.session_dao.update_agent( |
| 60 | + if agent.agent_id is None: |
| 61 | + raise ValueError("`agent.agent_id` must be set before appending message to session.") |
| 62 | + |
| 63 | + session_message = session_message_from_message(message) |
| 64 | + self.session_repository.create_message(self.session_id, agent.agent_id, session_message) |
| 65 | + self.session_repository.update_agent( |
50 | 66 | self.session_id,
|
51 |
| - SessionAgent( |
52 |
| - agent_id=agent.id, |
53 |
| - session_id=self.session_id, |
54 |
| - event_loop_metrics=agent.event_loop_metrics.to_dict(), |
55 |
| - state=agent.state.get(), |
56 |
| - ), |
| 67 | + session_agent_from_agent(agent=agent), |
57 | 68 | )
|
58 | 69 |
|
59 |
| - def initialize_agent(self, agent: "Agent") -> None: |
60 |
| - """Restore agent data from the current session. |
| 70 | + def initialize(self, event: AgentInitializedEvent) -> None: |
| 71 | + """Initialize an agent with a session. |
61 | 72 |
|
62 | 73 | Args:
|
63 |
| - agent: Agent instance to restore session data to |
64 |
| -
|
65 |
| - Raises: |
66 |
| - SessionException: If restore operation fails |
| 74 | + event: Event when an agent is initialized |
67 | 75 | """
|
68 |
| - if agent.id is None: |
69 |
| - raise ValueError("`agent.id` must be set before initializing session.") |
70 |
| - |
71 |
| - try: |
72 |
| - # Try to read existing session |
73 |
| - session = self.session_dao.read_session(self.session_id) |
| 76 | + agent = event.agent |
| 77 | + |
| 78 | + if agent.agent_id is None: |
| 79 | + if self._default_agent_initialized: |
| 80 | + raise ValueError( |
| 81 | + "By default, only one agent with no `agent_id` can be initialized within session_manager." |
| 82 | + "Set `agent_id` to support more than one agent in a session." |
| 83 | + ) |
| 84 | + logger.debug( |
| 85 | + "agent_id=<%s> | session_id=<%s> | Using default agent_id.", |
| 86 | + agent.agent_id, |
| 87 | + self.session_id, |
| 88 | + ) |
| 89 | + agent.agent_id = DEFAULT_SESSION_AGENT_ID |
| 90 | + self._default_agent_initialized = True |
74 | 91 |
|
75 |
| - if session.session_type != SessionType.AGENT: |
76 |
| - raise ValueError(f"Invalid session type: {session.session_type}") |
| 92 | + session_agent = self.session_repository.read_agent(self.session_id, agent.agent_id) |
77 | 93 |
|
78 |
| - if agent.id not in [agent.agent_id for agent in self.session_dao.list_agents(self.session_id)]: |
79 |
| - raise ValueError(f"Agent {agent.id} not found in session {self.session_id}") |
| 94 | + if session_agent is None: |
| 95 | + logger.debug( |
| 96 | + "agent_id=<%s> | session_id=<%s> | Creating agent.", |
| 97 | + agent.agent_id, |
| 98 | + self.session_id, |
| 99 | + ) |
80 | 100 |
|
81 |
| - # Initialize agent |
| 101 | + session_agent = session_agent_from_agent(agent) |
| 102 | + self.session_repository.create_agent(self.session_id, session_agent) |
| 103 | + for message in agent.messages: |
| 104 | + session_message = session_message_from_message(message) |
| 105 | + self.session_repository.create_message(self.session_id, agent.agent_id, session_message) |
| 106 | + else: |
| 107 | + logger.debug( |
| 108 | + "agent_id=<%s> | session_id=<%s> | Restoring agent.", |
| 109 | + agent.agent_id, |
| 110 | + self.session_id, |
| 111 | + ) |
82 | 112 | agent.messages = [
|
83 |
| - session_message.to_message() |
84 |
| - for session_message in self.session_dao.list_messages(self.session_id, agent.id) |
| 113 | + session_message_to_message(session_message) |
| 114 | + for session_message in self.session_repository.list_messages(self.session_id, agent.agent_id) |
85 | 115 | ]
|
86 |
| - agent.state = AgentState(self.session_dao.read_agent(self.session_id, agent.id).state) |
87 |
| - |
88 |
| - except SessionException: |
89 |
| - # Session doesn't exist, create new one |
90 |
| - logger.debug("Session not found, creating new session") |
91 |
| - # Session doesn't exist, create new one |
92 |
| - session = Session(session_id=self.session_id, session_type=SessionType.AGENT) |
93 |
| - session_agent = SessionAgent( |
94 |
| - agent_id=agent.id, |
95 |
| - session_id=self.session_id, |
96 |
| - event_loop_metrics=agent.event_loop_metrics.to_dict(), |
97 |
| - state=agent.state.get(), |
98 |
| - ) |
99 |
| - self.session_dao.create_session(session) |
100 |
| - self.session_dao.create_agent(self.session_id, session_agent) |
101 |
| - for message in agent.messages: |
102 |
| - session_message = SessionMessage.from_dict(dict(message)) |
103 |
| - self.session_dao.create_message(self.session_id, agent.id, session_message) |
104 |
| - |
105 |
| - self.session = session |
106 |
| - |
107 |
| - # Attach a callback handler for persisting messages |
108 |
| - def session_callback(**kwargs: Any) -> None: |
109 |
| - try: |
110 |
| - # Handle message persistence |
111 |
| - if "message" in kwargs: |
112 |
| - message = kwargs["message"] |
113 |
| - self.append_message_to_agent_session(kwargs["agent"], message) |
114 |
| - except Exception as e: |
115 |
| - logger.error("Persistence operation failed", e) |
116 |
| - |
117 |
| - agent.callback_handler = CompositeCallbackHandler(agent.callback_handler, session_callback) |
| 116 | + agent.state = AgentState(session_agent["state"]) |
| 117 | + agent.event_loop_metrics = EventLoopMetrics.from_dict(session_agent["event_loop_metrics"]) |
0 commit comments