Skip to content

Commit 358c529

Browse files
committed
feat: Update with PR feedback
1 parent 443cb8f commit 358c529

22 files changed

+425
-2121
lines changed

src/strands/agent/agent.py

Lines changed: 11 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -201,10 +201,10 @@ def __init__(
201201
trace_attributes: Optional[Mapping[str, AttributeValue]] = None,
202202
*,
203203
name: Optional[str] = None,
204-
id: Optional[str] = None,
204+
agent_id: Optional[str] = None,
205205
description: Optional[str] = None,
206206
state: Optional[Union[AgentState, dict]] = None,
207-
session_manager: Optional["SessionManager"] = None,
207+
session_manager: Optional[SessionManager] = None,
208208
):
209209
"""Initialize the Agent with the specified configuration.
210210
@@ -237,7 +237,7 @@ def __init__(
237237
trace_attributes: Custom trace attributes to apply to the agent's trace span.
238238
name: name of the Agent
239239
Defaults to None.
240-
id: identifier for the agent, used by session manager.
240+
agent_id: identifier for the agent, used by session manager.
241241
Defaults to uuid4().
242242
description: description of what the Agent does
243243
Defaults to None.
@@ -308,22 +308,22 @@ def __init__(
308308
else:
309309
self.state = AgentState()
310310

311-
# Initialize session management functionality
312-
self.session_manager = session_manager
313-
314311
self.tool_caller = Agent.ToolCaller(self)
315312
self.name = name or _DEFAULT_AGENT_NAME
316-
self.id = id
313+
self.agent_id = agent_id
317314
self.description = description
318315

319316
self._hooks = HookRegistry()
320-
# Register built-in hook providers (like ConversationManager) here
321-
self._hooks.invoke_callbacks(AgentInitializedEvent(agent=self))
322317

318+
# Register built-in hook providers (like SessionManager) here
319+
320+
# Initialize session management functionality
321+
self.session_manager = session_manager
323322

324-
# Setup session callback handler if session is enabled
325323
if self.session_manager:
326-
self.session_manager.initialize_agent(self)
324+
self._hooks.add_hook(self.session_manager)
325+
326+
self._hooks.invoke_callbacks(AgentInitializedEvent(agent=self))
327327

328328
@property
329329
def tool(self) -> ToolCaller:
@@ -536,10 +536,6 @@ async def _run_loop(self, message: Message, kwargs: dict[str, Any]) -> AsyncGene
536536

537537
self._append_message(message)
538538

539-
# Save message if session manager is available
540-
if self.session_manager:
541-
self.session_manager.append_message_to_agent_session(self, new_message)
542-
543539
# Execute the event loop cycle with retry logic for context limits
544540
events = self._execute_event_loop_cycle(kwargs)
545541
async for event in events:
@@ -633,13 +629,6 @@ def _record_tool_execution(
633629
self._append_message(tool_result_msg)
634630
self._append_message(assistant_msg)
635631

636-
if self.session_manager:
637-
self.session_manager.append_message_to_agent_session(self, user_msg)
638-
self.session_manager.append_message_to_agent_session(self, tool_use_msg)
639-
self.session_manager.append_message_to_agent_session(self, tool_result_msg)
640-
self.session_manager.append_message_to_agent_session(self, assistant_msg)
641-
642-
643632
def _start_agent_trace_span(self, message: Message) -> None:
644633
"""Starts a trace span for the agent.
645634

src/strands/event_loop/event_loop.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -189,8 +189,6 @@ async def event_loop_cycle(agent: "Agent", kwargs: dict[str, Any]) -> AsyncGener
189189
# Add the response message to the conversation
190190
agent.messages.append(message)
191191
get_registry(agent).invoke_callbacks(MessageAddedEvent(agent=agent, message=message))
192-
if "agent" in kwargs:
193-
callback_data["agent"] = kwargs["agent"]
194192
yield {"callback": {"message": message}}
195193

196194
# Update metrics
@@ -455,11 +453,8 @@ def tool_handler(tool_use: ToolUse) -> ToolGenerator:
455453
"content": [{"toolResult": result} for result in tool_results],
456454
}
457455

458-
agent.messages.append(tool_result_message)
459-
callback_data = {"message": tool_result_message}
460-
if "agent" in kwargs:
461-
callback_data["agent"] = kwargs["agent"]
462-
yield {"callback": callback_data}
456+
agent._append_message(tool_result_message)
457+
yield {"callback": {"message": tool_result_message}}
463458

464459
if cycle_span:
465460
tracer = get_tracer()
Lines changed: 82 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -1,117 +1,117 @@
1-
"""File-based implementation of session manager."""
1+
"""Agent session manager implementation."""
22

33
import logging
4-
from typing import TYPE_CHECKING, Any, Optional
4+
from typing import TYPE_CHECKING
55

66
from ..agent.state import AgentState
7-
from ..handlers.callback_handler import CompositeCallbackHandler
8-
from ..types.content import Message
9-
from .exceptions import SessionException
10-
from .file_session_dao import FileSessionDAO
11-
from .session_dao import SessionDAO
7+
from ..experimental.hooks.events import AgentInitializedEvent, MessageAddedEvent
8+
from ..telemetry.metrics import EventLoopMetrics
9+
from ..types.session import (
10+
SessionType,
11+
create_session,
12+
session_agent_from_agent,
13+
session_message_from_message,
14+
session_message_to_message,
15+
)
1216
from .session_manager import SessionManager
13-
from .session_models import Session, SessionAgent, SessionMessage, SessionType
17+
from .session_repository import SessionRepository
1418

1519
logger = logging.getLogger(__name__)
1620

1721
if TYPE_CHECKING:
18-
from ..agent.agent import Agent
22+
pass
1923

24+
DEFAULT_SESSION_AGENT_ID = "default"
2025

21-
class AgentSessionManager(SessionManager):
22-
"""Session manager for a single Agent.
2326

24-
This implementation stores sessions as JSON files in a specified directory.
25-
Each session is stored in a separate file named by its session_id.
26-
"""
27+
class AgentSessionManager(SessionManager):
28+
"""Session manager for persisting agent's in a Session."""
2729

2830
def __init__(
2931
self,
3032
session_id: str,
31-
session_dao: Optional[SessionDAO] = None,
33+
session_repository: SessionRepository,
3234
):
33-
"""Initialize the FileSessionManager."""
34-
self.session_dao = session_dao or FileSessionDAO()
35+
"""Initialize the AgentSessionManager."""
36+
self.session_repository = session_repository
3537
self.session_id = session_id
38+
session = session_repository.read_session(session_id)
39+
# Create a session if it does not exist yet
40+
if session is None:
41+
logger.debug("session_id=<%s> | Session not found, creating new session.", self.session_id)
42+
session = create_session(session_id=session_id, session_type=SessionType.AGENT)
43+
session_repository.create_session(session)
44+
else:
45+
if session["session_type"] != SessionType.AGENT:
46+
raise ValueError(f"Invalid session type: {session.session_type}")
47+
48+
self.session = session
49+
self._default_agent_initialized = False
3650

37-
def append_message_to_agent_session(self, agent: "Agent", message: Message) -> None:
51+
def append_message(self, event: MessageAddedEvent) -> None:
3852
"""Append a message to the agent's session.
3953
4054
Args:
41-
agent: The agent whose session to update
42-
message: The message to append
55+
event: Event for a newly added Message
4356
"""
44-
if agent.id is None:
45-
raise ValueError("`agent.id` must be set before appending message to session.")
57+
agent = event.agent
58+
message = event.message
4659

47-
session_message = SessionMessage.from_dict(dict(message))
48-
self.session_dao.create_message(self.session_id, agent.id, session_message)
49-
self.session_dao.update_agent(
60+
if agent.agent_id is None:
61+
raise ValueError("`agent.agent_id` must be set before appending message to session.")
62+
63+
session_message = session_message_from_message(message)
64+
self.session_repository.create_message(self.session_id, agent.agent_id, session_message)
65+
self.session_repository.update_agent(
5066
self.session_id,
51-
SessionAgent(
52-
agent_id=agent.id,
53-
session_id=self.session_id,
54-
event_loop_metrics=agent.event_loop_metrics.to_dict(),
55-
state=agent.state.get(),
56-
),
67+
session_agent_from_agent(agent=agent),
5768
)
5869

59-
def initialize_agent(self, agent: "Agent") -> None:
60-
"""Restore agent data from the current session.
70+
def initialize(self, event: AgentInitializedEvent) -> None:
71+
"""Initialize an agent with a session.
6172
6273
Args:
63-
agent: Agent instance to restore session data to
64-
65-
Raises:
66-
SessionException: If restore operation fails
74+
event: Event when an agent is initialized
6775
"""
68-
if agent.id is None:
69-
raise ValueError("`agent.id` must be set before initializing session.")
70-
71-
try:
72-
# Try to read existing session
73-
session = self.session_dao.read_session(self.session_id)
76+
agent = event.agent
77+
78+
if agent.agent_id is None:
79+
if self._default_agent_initialized:
80+
raise ValueError(
81+
"By default, only one agent with no `agent_id` can be initialized within session_manager."
82+
"Set `agent_id` to support more than one agent in a session."
83+
)
84+
logger.debug(
85+
"agent_id=<%s> | session_id=<%s> | Using default agent_id.",
86+
agent.agent_id,
87+
self.session_id,
88+
)
89+
agent.agent_id = DEFAULT_SESSION_AGENT_ID
90+
self._default_agent_initialized = True
7491

75-
if session.session_type != SessionType.AGENT:
76-
raise ValueError(f"Invalid session type: {session.session_type}")
92+
session_agent = self.session_repository.read_agent(self.session_id, agent.agent_id)
7793

78-
if agent.id not in [agent.agent_id for agent in self.session_dao.list_agents(self.session_id)]:
79-
raise ValueError(f"Agent {agent.id} not found in session {self.session_id}")
94+
if session_agent is None:
95+
logger.debug(
96+
"agent_id=<%s> | session_id=<%s> | Creating agent.",
97+
agent.agent_id,
98+
self.session_id,
99+
)
80100

81-
# Initialize agent
101+
session_agent = session_agent_from_agent(agent)
102+
self.session_repository.create_agent(self.session_id, session_agent)
103+
for message in agent.messages:
104+
session_message = session_message_from_message(message)
105+
self.session_repository.create_message(self.session_id, agent.agent_id, session_message)
106+
else:
107+
logger.debug(
108+
"agent_id=<%s> | session_id=<%s> | Restoring agent.",
109+
agent.agent_id,
110+
self.session_id,
111+
)
82112
agent.messages = [
83-
session_message.to_message()
84-
for session_message in self.session_dao.list_messages(self.session_id, agent.id)
113+
session_message_to_message(session_message)
114+
for session_message in self.session_repository.list_messages(self.session_id, agent.agent_id)
85115
]
86-
agent.state = AgentState(self.session_dao.read_agent(self.session_id, agent.id).state)
87-
88-
except SessionException:
89-
# Session doesn't exist, create new one
90-
logger.debug("Session not found, creating new session")
91-
# Session doesn't exist, create new one
92-
session = Session(session_id=self.session_id, session_type=SessionType.AGENT)
93-
session_agent = SessionAgent(
94-
agent_id=agent.id,
95-
session_id=self.session_id,
96-
event_loop_metrics=agent.event_loop_metrics.to_dict(),
97-
state=agent.state.get(),
98-
)
99-
self.session_dao.create_session(session)
100-
self.session_dao.create_agent(self.session_id, session_agent)
101-
for message in agent.messages:
102-
session_message = SessionMessage.from_dict(dict(message))
103-
self.session_dao.create_message(self.session_id, agent.id, session_message)
104-
105-
self.session = session
106-
107-
# Attach a callback handler for persisting messages
108-
def session_callback(**kwargs: Any) -> None:
109-
try:
110-
# Handle message persistence
111-
if "message" in kwargs:
112-
message = kwargs["message"]
113-
self.append_message_to_agent_session(kwargs["agent"], message)
114-
except Exception as e:
115-
logger.error("Persistence operation failed", e)
116-
117-
agent.callback_handler = CompositeCallbackHandler(agent.callback_handler, session_callback)
116+
agent.state = AgentState(session_agent["state"])
117+
agent.event_loop_metrics = EventLoopMetrics.from_dict(session_agent["event_loop_metrics"])

src/strands/session/exceptions.py

Lines changed: 0 additions & 5 deletions
This file was deleted.

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy