__all__ = [
"BaseAgentSpec",
"BasePromptAgentSpec",
"AgentStreamChunk",
]
from ..utils.basic.misc_utils import unique
from ..utils.basic.config_utils import HEAVEN_CM
from ..utils.basic.serialize_utils import dumps_json
from ..utils.basic.log_utils import get_logger
logger = get_logger(__name__)
from ..llm import Messages, LLM, LLMIncludeType, gather_stream, format_messages
from ..ukf.templates.basic.prompt import PromptUKFT
from ..tool.base import ToolSpec
from typing import Dict, Any, List, Optional, Generator, Tuple, Union
from abc import ABC, abstractmethod
from copy import deepcopy
# Type alias for stream chunks
AgentStreamChunk = Dict[str, Any]
# NOTE(review): stray "[docs]" lines (Sphinx viewcode extraction artifacts)
# removed throughout — they were not valid Python and broke the module.
class BaseAgentSpec(ABC):
    """Abstract base class for step-wise, tool-using LLM agents.

    Subclasses implement :meth:`encode` (inputs -> Messages),
    :meth:`is_done` (termination check after each step) and
    :meth:`decode` (final Messages -> output). This base class provides
    the streaming step loop (:meth:`stream`), a blocking runner
    (:meth:`run`), and an encode/run/decode convenience wrapper
    (:meth:`__call__`).
    """

    def __init__(
        self,
        tools: Optional[List[ToolSpec]] = None,
        llm_args: Optional[Dict] = None,
        max_steps: Optional[int] = None,
        **kwargs,
    ):
        """Initialize the agent.

        Args:
            tools: Tool specs made available to the LLM. Defaults to no tools.
            llm_args: Keyword arguments forwarded to the ``LLM`` constructor.
            max_steps: Maximum number of LLM steps per run; when None, falls
                back to the ``agent.max_steps`` config value (default 20).
            **kwargs: Ignored here; accepted so cooperative subclasses can
                pass extra arguments through.
        """
        super().__init__()
        self.tools = tools or list()
        self.llm = LLM(**(llm_args or dict()))
        self.max_steps = HEAVEN_CM.get("agent.max_steps", 20) if max_steps is None else max_steps

    @abstractmethod
    def encode(self, **inputs) -> Messages:
        """\
        Convert input arguments into Messages for the agent.
        Args:
            **inputs: Arbitrary input arguments.
        Returns:
            Messages: The encoded messages for the agent.
        """
        pass

    def step(self, messages: Messages, include: Optional[List[LLMIncludeType]] = None) -> Generator[AgentStreamChunk, None, None]:
        """\
        Execute a single LLM call with streaming.
        Args:
            messages: Current conversation messages.
            include: Fields to include in the stream chunks.
        Yields:
            Stream chunks from the LLM.
        """
        # reduce=False keeps the raw per-chunk deltas instead of merging them,
        # so stream() can collect delta_messages incrementally.
        yield from self.llm.stream(messages, tools=self.tools, include=include, reduce=False)

    @abstractmethod
    def is_done(self, messages: Messages, delta_messages: List[Dict[str, Any]]) -> Tuple[bool, Dict[str, Any]]:
        """\
        Decide whether the agent has finished.
        Args:
            messages: Full conversation so far (including the last step).
            delta_messages: Messages produced by the last step only.
        Returns:
            Tuple of (done, finish_state); finish_state is merged into the
            final stream chunk by stream().
        """
        pass

    def user_proxy(self, messages: Messages, delta_messages: List[Dict[str, Any]], finish_state: Optional[Dict[str, Any]] = None) -> Messages:
        """\
        Add a user proxy message to prompt the agent to continue.
        This is called when the agent is not done after a step, to
        encourage it to keep going.
        Args:
            messages: Current conversation messages.
            delta_messages: Delta messages from the last step.
            finish_state: Optional finish state from the last step.
        Returns:
            Messages: A list of messages to append.
        """
        return [{"role": "user", "content": "The task is not complete. Please continue until the task is complete."}]

    @abstractmethod
    def decode(self, messages: Messages, finish_state: Optional[Dict[str, Any]] = None) -> Any:
        """\
        Convert the final messages (and finish state) into the output value.
        Args:
            messages: Final conversation messages.
            finish_state: Finish state produced by the run.
        Returns:
            The decoded output of the agent.
        """
        pass

    def stream(self, messages: Messages, include: Optional[List[LLMIncludeType]] = None) -> Generator[AgentStreamChunk, None, None]:
        """\
        Stream the agent execution, yielding chunks as they are generated.
        This is the core streaming interface. Each chunk contains:
        - Standard LLM fields: text, think, tool_calls, tool_messages, etc.
        - Agent control fields: step, done, finish_state, messages
        Args:
            messages: Initial messages to start the agent.
            include: Fields to include in the stream. Defaults to common fields.
        Yields:
            AgentStreamChunk: Stream chunks with LLM output and agent state.
        """
        default_include = ["text", "think", "tool_calls", "tool_messages", "delta_messages"]
        # "delta_messages" must always be present so we can track the
        # conversation state regardless of the caller's include list.
        include = unique((include or default_include) + ["delta_messages"])
        # Work on a copy so the caller's messages list is never mutated.
        cloned_messages = deepcopy(messages)
        for step_num in range(self.max_steps):
            # Emit step start
            yield {"step": step_num, "step_status": "start"}
            delta_messages = list()
            # Stream the LLM step, passing chunks through to the caller
            # while collecting the per-step message deltas.
            for chunk in self.step(cloned_messages, include=include):
                yield chunk
                if chunk.get("delta_messages"):
                    delta_messages.extend(chunk.get("delta_messages", list()))
            # Update messages
            cloned_messages.extend(delta_messages)
            # Stale step (LLM produced no messages): stop to avoid looping
            # forever on an unresponsive model.
            if len(delta_messages) == 0:
                finish_state = {"msg": "stale", "steps": step_num + 1, "max_steps": self.max_steps}
                yield {"step": step_num, "step_status": "end", "done": True, "finish_state": finish_state, "messages": cloned_messages}
                return
            # Check if done
            done, finish_state = self.is_done(messages=cloned_messages, delta_messages=delta_messages)
            if done:
                finish_state = (finish_state or {}) | {"steps": step_num + 1, "max_steps": self.max_steps}
                yield {"step": step_num, "step_status": "end", "done": True, "finish_state": finish_state, "messages": cloned_messages}
                return
            # Not done: if the model just spoke, append a user proxy message
            # to nudge it to continue.
            user_proxy_messages = []
            if cloned_messages and cloned_messages[-1].get("role") == "assistant":
                user_proxy_messages = self.user_proxy(cloned_messages, delta_messages, finish_state=finish_state)
                cloned_messages.extend(user_proxy_messages)
            # Emit step end (not done yet), include user_proxy messages for session storage
            yield {"step": step_num, "step_status": "end", "done": False, "delta_messages": user_proxy_messages}
        # Max steps reached
        finish_state = {"msg": "max_steps_reached", "steps": self.max_steps, "max_steps": self.max_steps}
        yield {"step": self.max_steps - 1, "step_status": "end", "done": True, "finish_state": finish_state, "messages": cloned_messages}

    def run(self, messages: Messages, include: Optional[List[LLMIncludeType]] = None) -> Tuple[Messages, Dict[str, Any]]:
        """\
        Run the agent to completion, collecting all stream output.
        This is a convenience wrapper around stream() that blocks until completion.
        Args:
            messages: Initial messages to start the agent.
            include: Fields to include (passed to stream).
        Returns:
            Tuple of (final_messages, finish_state).
        """
        final_messages = messages
        finish_state = {"msg": "unknown"}
        # Drain the stream; only the final "done" chunk carries the
        # finish_state and the complete message list.
        for chunk in self.stream(messages, include=include):
            if chunk.get("done"):
                finish_state = chunk.get("finish_state", finish_state)
                final_messages = chunk.get("messages", final_messages)
        return final_messages, finish_state

    def __call__(self, **inputs) -> Any:
        """\
        Convenience method to encode, run, and decode in one call.
        Args:
            **inputs: Input arguments passed to encode().
        Returns:
            Decoded output from the agent.
        """
        encoded_messages = self.encode(**inputs)
        final_messages, finish_state = self.run(encoded_messages)
        decoded_output = self.decode(final_messages, finish_state=finish_state)
        return decoded_output
# NOTE(review): stray "[docs]" lines (Sphinx viewcode extraction artifacts)
# removed — they were not valid Python and broke the module.
class BasePromptAgentSpec(BaseAgentSpec):
    """Agent spec whose encode() step renders a PromptUKFT template.

    Still abstract: subclasses must implement is_done() and decode()
    from :class:`BaseAgentSpec`.
    """

    def __init__(
        self,
        prompt: PromptUKFT,
        tools: Optional[List[ToolSpec]] = None,
        llm_args: Optional[Dict] = None,
        max_steps: Optional[int] = None,
        **kwargs,
    ):
        """Initialize the prompt-driven agent.

        Args:
            prompt: Prompt template used by encode(); cloned so later
                mutations of the caller's template don't affect this agent.
            tools: Tool specs made available to the LLM.
            llm_args: Keyword arguments forwarded to the ``LLM`` constructor.
            max_steps: Maximum number of LLM steps per run.
            **kwargs: Passed through to :class:`BaseAgentSpec`.
        """
        super().__init__(tools=tools, llm_args=llm_args, max_steps=max_steps, **kwargs)
        self.prompt = prompt.clone()

    def encode(self, **inputs) -> Messages:
        """Render the prompt template with *inputs* and format it as Messages.

        Args:
            **inputs: Arbitrary input arguments, exposed to the template
                under the ``inputs`` key.
        Returns:
            Messages: The rendered, formatted messages.
        """
        return format_messages(self.prompt.text(instance={"inputs": inputs}))