Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 97 additions & 37 deletions src/bub/builtin/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,24 +13,27 @@
from datetime import UTC, datetime
from functools import cached_property
from pathlib import Path
from typing import Any, Literal
from typing import Any, Literal, cast

from any_llm import AnyLLM
from any_llm.types.completion import (
ChatCompletion,
ChatCompletionChunk,
ChatCompletionMessage,
ChatCompletionMessageFunctionToolCall,
ChatCompletionMessageToolCall,
ChoiceDeltaToolCall,
Function,
ParsedChatCompletion,
)
from loguru import logger
from openai.types.chat.chat_completion_message_custom_tool_call import ChatCompletionMessageCustomToolCall
from openai.types.chat.chat_completion_message_function_tool_call import ChatCompletionMessageFunctionToolCall, Function
from pydantic import TypeAdapter, ValidationError

from bub.builtin.settings import ModelCandidate, load_settings
from bub.builtin.store import ForkTapeStore
from bub.builtin.tape import TapeService
from bub.framework import BubFramework
from bub.runtime import AsyncStreamEvents, BubError, StreamEvent, StreamState
from bub.runtime import AsyncStreamEvents, BubError, ErrorKind, StreamEvent, StreamState
from bub.skills import discover_skills, render_skills_prompt
from bub.tape import InMemoryTapeStore, Tape
from bub.tools import (
Expand All @@ -49,6 +52,7 @@
re.IGNORECASE,
)
MAX_AUTO_HANDOFF_RETRIES = 1
TOOL_ARGUMENTS_ADAPTER = TypeAdapter(dict[str, Any])


class Agent:
Expand Down Expand Up @@ -418,7 +422,7 @@ async def iterator() -> AsyncGenerator[StreamEvent, None]:
model_tools_for_call = model_tools(tools)
text_parts: list[str] = []
tool_calls = _ToolCallAccumulator()
response: ChatCompletion | None = None
response: ChatCompletion | ParsedChatCompletion[Any] | None = None
async with asyncio.timeout(self.settings.model_timeout_seconds):
completion = await self._completion_response(
model=model or self.settings.model,
Expand All @@ -430,13 +434,23 @@ async def iterator() -> AsyncGenerator[StreamEvent, None]:
async for event in _completion_events(completion, state, text_parts, tool_calls):
yield event

text = "".join(text_parts)
resolved_tool_calls = tool_calls.as_list()
if resolved_tool_calls:
assistant_message = response.choices[0].message if response is not None else None

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where does this handle tool calls for streamed output?

@PsiACE PsiACE Jun 15, 2026

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I haven't given much thought to how to handle this gracefully in the streaming path. Since there was no native state management, I considered abandoning it, but...

So it's a bug.

Fixed in 82010bb; it needs a careful check. I tested it with qwen-max, and tool calls in streaming mode worked.

text = (
assistant_message.content
if assistant_message and assistant_message.content is not None
else "".join(text_parts)
)
native_tool_calls = tool_calls.as_native()
if native_tool_calls:
tool_map = {tool_item.name: tool_item for tool_item in model_tools_for_call}
serialized_tool_calls = [tool_call.model_dump(exclude_none=True) for tool_call in native_tool_calls]
tool_invocations = [
_tool_invocation_from_native(tool_call, tool_map) for tool_call in native_tool_calls
]
yield StreamEvent("tool_call", {"tool_calls": serialized_tool_calls})
context = ToolContext(tape=tape.name, run_id=run_id, state=tape.context.state)
execution = await ToolExecutor().execute_async(
resolved_tool_calls,
tools=model_tools_for_call,
tool_invocations,
context=context,
)
await self.tapes.record_chat(
Expand All @@ -445,16 +459,15 @@ async def iterator() -> AsyncGenerator[StreamEvent, None]:
system_prompt=system_prompt,
new_messages=[prompt_message],
response_text=None,
tool_calls=execution.tool_calls,
tool_calls=serialized_tool_calls,
tool_results=execution.tool_results,
response=response,
model=model or self.settings.model,
usage=state.usage,
)
yield StreamEvent("tool_call", {"tool_calls": execution.tool_calls})
yield StreamEvent("tool_result", {"tool_results": execution.tool_results})
yield StreamEvent(
"final", {"ok": True, "tool_calls": execution.tool_calls, "tool_results": execution.tool_results}
"final", {"ok": True, "tool_calls": serialized_tool_calls, "tool_results": execution.tool_results}
)
return

Expand All @@ -480,7 +493,7 @@ def _build_llm(self, candidate: ModelCandidate) -> AnyLLM:

async def _completion_response(
self, *, model: str, messages: list[dict[str, Any]], tools: list[Tool]
) -> ChatCompletion | AsyncIterator[ChatCompletionChunk]:
) -> ChatCompletion | ParsedChatCompletion[Any] | AsyncIterator[ChatCompletionChunk]:
from bub.builtin.tools import completion_tools

tool_payloads = completion_tools(tools) or None
Expand Down Expand Up @@ -557,25 +570,45 @@ def as_tool_call(self, index: int) -> ChatCompletionMessageFunctionToolCall:

class _ToolCallAccumulator:
def __init__(self) -> None:
self._calls: list[ChatCompletionMessageFunctionToolCall | ChatCompletionMessageCustomToolCall] = []
self._message_calls: list[ChatCompletionMessageToolCall] = []
self._stream_calls: dict[int, _StreamToolCall] = {}

def add_message_call(
self, call: ChatCompletionMessageFunctionToolCall | ChatCompletionMessageCustomToolCall
) -> None:
self._calls.append(call)
def add_message_calls(self, calls: Iterable[ChatCompletionMessageToolCall]) -> None:
self._message_calls.extend(calls)

def merge_delta_calls(self, deltas: Iterable[ChoiceDeltaToolCall]) -> None:
for delta in deltas:
self._stream_calls.setdefault(delta.index, _StreamToolCall()).merge(delta)

def as_list(self) -> list[dict[str, Any]]:
calls = self._calls or [self._stream_calls[index].as_tool_call(index) for index in sorted(self._stream_calls)]
return [call.model_dump(exclude_none=True) for call in calls]
def as_native(self) -> list[ChatCompletionMessageToolCall]:
if self._message_calls:
return list(self._message_calls)
return [self._stream_calls[index].as_tool_call(index) for index in sorted(self._stream_calls)]


def _tool_invocation_from_native(
    tool_call: ChatCompletionMessageToolCall,
    tool_map: dict[str, Tool],
) -> tuple[Tool, dict[str, Any]]:
    """Pair a native tool call with the tool it names and its parsed arguments.

    Args:
        tool_call: The tool call object produced by the model.
        tool_map: Available tools keyed by tool name.

    Returns:
        A ``(tool, arguments)`` tuple for the requested tool.

    Raises:
        BubError: With ``ErrorKind.TOOL`` when the requested name has no entry
            in ``tool_map``. Errors raised while parsing the call propagate
            unchanged.
    """
    requested_name, parsed_arguments = _parse_native_function_call(tool_call)
    resolved_tool = tool_map.get(requested_name)
    if resolved_tool is not None:
        return resolved_tool, parsed_arguments
    raise BubError(ErrorKind.TOOL, f"Unknown tool name: {requested_name}.")


def _parse_native_function_call(tool_call: ChatCompletionMessageToolCall) -> tuple[str, dict[str, Any]]:
    """Extract the function name and JSON-object arguments from a tool call.

    An empty or missing arguments string is treated as ``"{}"``.

    Returns:
        A ``(name, arguments)`` tuple.

    Raises:
        BubError: With ``ErrorKind.INVALID_INPUT`` when the call is not a
            function tool call, or when its arguments fail validation against
            ``TOOL_ARGUMENTS_ADAPTER``.
    """
    invalid_call_message = "Expected a function tool call with JSON object arguments."
    if isinstance(tool_call, ChatCompletionMessageFunctionToolCall):
        raw_arguments = tool_call.function.arguments or "{}"
        try:
            parsed_arguments = TOOL_ARGUMENTS_ADAPTER.validate_json(raw_arguments)
        except ValidationError as exc:
            raise BubError(ErrorKind.INVALID_INPUT, invalid_call_message) from exc
        return tool_call.function.name, parsed_arguments
    raise BubError(ErrorKind.INVALID_INPUT, invalid_call_message)


async def _completion_events(
completion: ChatCompletion | AsyncIterator[ChatCompletionChunk],
completion: ChatCompletion | ParsedChatCompletion[Any] | AsyncIterator[ChatCompletionChunk],
state: StreamState,
text_parts: list[str],
tool_calls: _ToolCallAccumulator,
Expand All @@ -584,23 +617,50 @@ async def _completion_events(
if usage := TapeService._extract_usage(completion):
state.usage = usage
message = completion.choices[0].message
if message.content:
text_parts.append(message.content)
yield StreamEvent("text", {"delta": message.content})
for tool_call in message.tool_calls or []:
tool_calls.add_message_call(tool_call)
for event in _completion_message_events(message, text_parts, tool_calls):
yield event
return

async for chunk in completion:
if usage := TapeService._extract_usage(chunk):
state.usage = usage
for choice in chunk.choices:
delta = choice.delta
if delta.content:
text_parts.append(delta.content)
yield StreamEvent("text", {"delta": delta.content})
if delta.tool_calls:
tool_calls.merge_delta_calls(delta.tool_calls)
async for event in _completion_chunk_events(chunk, state, text_parts, tool_calls):
yield event


def _completion_message_events(
    message: ChatCompletionMessage,
    text_parts: list[str],
    tool_calls: _ToolCallAccumulator,
) -> Iterable[StreamEvent]:
    """Yield stream events for a complete (non-streamed) assistant message.

    Emits a ``"reasoning"`` event when the message carries reasoning and a
    ``"text"`` event when it carries content; the content is also appended to
    ``text_parts``. The message's tool calls are handed to ``tool_calls``.

    This is a generator, so the tool calls are only registered once the caller
    has iterated it to the end.
    """
    reasoning = message.reasoning
    if reasoning:
        yield StreamEvent("reasoning", {"delta": _reasoning_text(reasoning)})

    content = message.content
    if content:
        text_parts.append(content)
        yield StreamEvent("text", {"delta": content})

    message_tool_calls = message.tool_calls or []
    tool_calls.add_message_calls(cast("Iterable[ChatCompletionMessageToolCall]", message_tool_calls))


async def _completion_chunk_events(
    chunk: ChatCompletionChunk,
    state: StreamState,
    text_parts: list[str],
    tool_calls: _ToolCallAccumulator,
) -> AsyncGenerator[StreamEvent, None]:
    """Yield stream events for one streamed completion chunk.

    Usage extracted from the chunk, when truthy, replaces ``state.usage``.
    For each choice, reasoning and content deltas become ``"reasoning"`` and
    ``"text"`` events (content is also appended to ``text_parts``), and
    tool-call deltas are merged into ``tool_calls``.
    """
    chunk_usage = TapeService._extract_usage(chunk)
    if chunk_usage:
        state.usage = chunk_usage

    for choice in chunk.choices:
        piece = choice.delta

        reasoning = piece.reasoning
        if reasoning:
            yield StreamEvent("reasoning", {"delta": _reasoning_text(reasoning)})

        content = piece.content
        if content:
            text_parts.append(content)
            yield StreamEvent("text", {"delta": content})

        delta_calls = piece.tool_calls
        if delta_calls:
            tool_calls.merge_delta_calls(delta_calls)


def _reasoning_text(reasoning: object) -> str:
content = getattr(reasoning, "content", reasoning)
return "" if content is None else str(content)


@dataclass(frozen=True)
Expand Down
Loading
Loading