# Adding per tool usage limit (#3691)
base: main
Changes from all commits
017eac1
190e230
c9d7aed
909716b
3880ab6
1d62583
114d2eb
64f359d
903d5d5
8ba5cc8
0a96653
ad40e36
7554e0d
3352ab3
89265a4
776e542
f981f99
34b830a
364295b
283a3f3
3914322
0f63561
@@ -669,6 +669,51 @@ except UsageLimitExceeded as e:

- Usage limits are especially relevant if you've registered many tools. Use `request_limit` to bound the number of model turns, and `tool_calls_limit` to cap the number of successful tool executions within a run.
- The `tool_calls_limit` is checked before executing tool calls. If the model returns parallel tool calls that would exceed the limit, no tools are executed.

##### Soft Tool Call Limits with `max_tool_calls`

If you want to limit tool calls but let the model decide how to proceed instead of raising an error, use the `max_tool_calls` parameter. This is a "soft" limit that returns a message to the model when exceeded, rather than raising a [`UsageLimitExceeded`][pydantic_ai.exceptions.UsageLimitExceeded] exception.
```py {test="skip"}
from pydantic_ai import Agent

agent = Agent('anthropic:claude-sonnet-4-5', max_tool_calls=2)  # (1)!


@agent.tool_plain
def do_work() -> str:
    return 'ok'


# The model can make up to 2 tool calls
result = agent.run_sync('Please call the tool three times')
print(result.output)
#> I was able to call the tool twice, but the third call reached the limit.
```

> **Collaborator:** We shouldn't skip tests unless absolutely necessary.

1. Set the maximum number of tool calls allowed during runs. This can also be set per-run.
When `max_tool_calls` is exceeded, instead of executing the tool, the agent returns a message to the model: `'Tool call limit reached for tool "{tool_name}".'`. The model then decides how to respond based on this information.

> **Collaborator:** This seems like it can be combined with the sentence before the example.
You can also override `max_tool_calls` at run time:
```py {test="skip"}
from pydantic_ai import Agent

agent = Agent('anthropic:claude-sonnet-4-5', max_tool_calls=5)  # Default limit


@agent.tool_plain
def calculate(x: int) -> int:
    return x * 2


# Override the limit for this specific run
result = agent.run_sync('Calculate something', max_tool_calls=1)
```
**When to use `max_tool_calls` vs `tool_calls_limit`:**

> **Collaborator:** This sounds like something for the explanation in the previous section where we link to this feature. And for the primary docs of this feature where we mention how it's related to the `UsageLimits` hard limit.

| Parameter | Behavior | Use Case |
| --------- | -------- | -------- |
| `tool_calls_limit` | Raises [`UsageLimitExceeded`][pydantic_ai.exceptions.UsageLimitExceeded] | Hard stop when you need to prevent runaway costs |
| `max_tool_calls` | Returns a message to the model | Soft limit where you want the model to adapt gracefully |
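The two rows of the table can be sketched in plain Python, independent of the framework (the function names and messages here are illustrative, not pydantic-ai API): a hard limit raises before the batch runs, while a soft limit substitutes a notice the model can react to.

```python
def hard_limit_check(executed: int, new_calls: int, limit: int) -> None:
    # tool_calls_limit-style: raise before running any call in the batch
    if executed + new_calls > limit:
        raise RuntimeError(
            f'The next {new_calls} tool calls would exceed the tool_calls_limit of {limit}'
        )


def soft_limit_result(executed: int, limit: int, tool_name: str) -> str:
    # max_tool_calls-style: return a message the model can react to
    if executed >= limit:
        return f'Tool call limit reached for tool "{tool_name}".'
    return 'run the tool'


try:
    hard_limit_check(executed=2, new_calls=2, limit=3)
except RuntimeError as e:
    print(e)
#> The next 2 tool calls would exceed the tool_calls_limit of 3

print(soft_limit_result(executed=3, limit=3, tool_name='do_work'))
#> Tool call limit reached for tool "do_work".
```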
#### Model (Run) Settings

Pydantic AI offers a [`settings.ModelSettings`][pydantic_ai.settings.ModelSettings] structure to help you fine-tune your requests.
@@ -411,7 +411,39 @@ If a tool requires sequential/serial execution, you can pass the [`sequential`][

Async functions are run on the event loop, while sync functions are offloaded to threads. To get the best performance, _always_ use an async function _unless_ you're doing blocking I/O (and there's no way to use a non-blocking library instead) or CPU-bound work (like `numpy` or `scikit-learn` operations), so that simple functions are not offloaded to threads unnecessarily.

!!! note "Limiting tool executions"
    You can cap the total number of tool executions within a run using [`UsageLimits(tool_calls_limit=...)`](agents.md#usage-limits). For finer control, you can limit how many times a *specific* tool can be called by setting the `max_uses` parameter when registering the tool (e.g., `@agent.tool(max_uses=3)` or `Tool(func, max_uses=3)`). Once a tool reaches its `max_uses` limit, it is automatically removed from the available tools for subsequent steps in the run. The `tool_calls` counter increments only after a successful tool invocation. Output tools (used for [structured output](output.md)) are not counted in the `tool_calls` metric.
> **Collaborator:** The bit about …

For a "soft" limit that lets the model decide how to proceed instead of raising an error, use the [`max_tool_calls`](agents.md#soft-tool-call-limits-with-max_tool_calls) parameter on the agent or run method.
#### Raising Hard Errors on Tool Usage Limits

By default, when a tool reaches its `max_uses` limit, it is silently removed from the available tools. If you want to raise an error instead, you can use a [`prepare`](#tool-prepare) function to check the tool usage and raise a [`UsageLimitExceeded`][pydantic_ai.exceptions.UsageLimitExceeded] exception:
```python {title="tool_max_uses_hard_error.py"}
from typing import Any

from pydantic_ai import Agent, RunContext, ToolDefinition
from pydantic_ai.exceptions import UsageLimitExceeded

agent = Agent('test')


async def raise_on_limit(
    ctx: RunContext[Any], tool_def: ToolDefinition
) -> ToolDefinition | None:
    if ctx.max_uses and ctx.tool_usage.get(tool_def.name, 0) >= ctx.max_uses:
        raise UsageLimitExceeded(
            f'Tool "{tool_def.name}" has reached its usage limit of {ctx.max_uses}.'
        )
    return tool_def


@agent.tool(max_uses=2, prepare=raise_on_limit)
def limited_tool(ctx: RunContext[None]) -> str:
    return 'Tool executed'
```

> **Collaborator:** Another interesting example would be to check …
In this example, calling `limited_tool` more than twice raises a `UsageLimitExceeded` error instead of the tool being silently removed.
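The prepare-based check above boils down to a small guard. Here is a standalone sketch of the same logic with plain data in place of `RunContext` (the function name, arguments, and `RuntimeError` stand-in are illustrative, not pydantic-ai API):

```python
def raise_on_limit(max_uses: int, tool_usage: dict[str, int], tool_name: str) -> None:
    # Raise once the tool has already been used max_uses times
    if max_uses and tool_usage.get(tool_name, 0) >= max_uses:
        raise RuntimeError(
            f'Tool "{tool_name}" has reached its usage limit of {max_uses}.'
        )


raise_on_limit(2, {'limited_tool': 1}, 'limited_tool')  # under the limit: no error

try:
    raise_on_limit(2, {'limited_tool': 2}, 'limited_tool')
except RuntimeError as e:
    print(e)
#> Tool "limited_tool" has reached its usage limit of 2.
```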
#### Output Tool Calls
```diff
@@ -90,6 +90,7 @@ class GraphAgentState:
     retries: int = 0
     run_step: int = 0
     run_id: str = dataclasses.field(default_factory=lambda: str(uuid.uuid4()))
+    tool_usage: dict[str, int] = dataclasses.field(default_factory=dict)

     def increment_retries(
         self,
```
```diff
@@ -135,6 +136,7 @@ class GraphAgentDeps(Generic[DepsT, OutputDataT]):
     model_settings: ModelSettings | None
     usage_limits: _usage.UsageLimits
     max_result_retries: int
+    max_tool_calls: int | None
     end_strategy: EndStrategy
     get_instructions: Callable[[RunContext[DepsT]], Awaitable[str | None]]
```
```diff
@@ -816,6 +818,8 @@ def build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT
         else DEFAULT_INSTRUMENTATION_VERSION,
         run_step=ctx.state.run_step,
         run_id=ctx.state.run_id,
+        tool_usage=ctx.state.tool_usage,
+        max_tool_calls=ctx.deps.max_tool_calls,
     )
     validation_context = build_validation_context(ctx.deps.validation_context, run_context)
     run_context = replace(run_context, validation_context=validation_context)
```

> **Collaborator:** This is very similar to …
```diff
@@ -1018,6 +1022,17 @@ async def process_tool_calls(  # noqa: C901
         output_final_result.append(final_result)


+def _projection_count_of_tool_usage(
+    tool_call_counts: defaultdict[str, int], tool_calls: list[_messages.ToolCallPart]
+) -> None:
+    """Populate a count of tool usage based on the provided tool calls for this run step.
+
+    We will use this to make sure the calls do not exceed tool usage limits.
+    """
+    for call in tool_calls:
+        tool_call_counts[call.tool_name] += 1
+
+
 async def _call_tools(
     tool_manager: ToolManager[DepsT],
     tool_calls: list[_messages.ToolCallPart],
```

> **Collaborator:** I don't think this is worth a method, it'll be easier to follow if we inline it.
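The helper above only tallies the batch. The sketch below shows, with plain dicts standing in for the `ToolManager` internals, how such projected counts combine with prior usage to decide which parallel calls would exceed a per-tool limit (all names here are illustrative, not the PR's actual code):

```python
from collections import defaultdict


def blocked_calls(
    tool_calls: list[str],
    current_use: dict[str, int],
    max_uses: dict[str, int],
) -> list[str]:
    # Count how many times each tool appears in this batch of parallel calls
    counts: defaultdict[str, int] = defaultdict(int)
    for name in tool_calls:
        counts[name] += 1

    # A call is blocked if prior uses plus this whole batch would exceed the tool's limit;
    # tools absent from max_uses are unlimited
    return [
        name
        for name in tool_calls
        if name in max_uses and current_use.get(name, 0) + counts[name] > max_uses[name]
    ]


print(blocked_calls(['search', 'search', 'fetch'], {'search': 2}, {'search': 3}))
#> ['search', 'search']
```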
```diff
@@ -1039,14 +1054,38 @@ async def _call_tools(
     projected_usage.tool_calls += len(tool_calls)
     usage_limits.check_before_tool_call(projected_usage)

+    # Checks for soft limits (if any are set on total tool calls)
+    can_make_tool_calls = tool_manager.can_make_tool_calls(len(tool_calls), deepcopy(usage))
+
+    calls_to_run: list[_messages.ToolCallPart] = []
+
+    # For each tool, check how many calls are going to be made
+    tool_call_counts: defaultdict[str, int] = defaultdict(int)
+    _projection_count_of_tool_usage(tool_call_counts, tool_calls)
+
     for call in tool_calls:
-        yield _messages.FunctionToolCallEvent(call)
+        current_tool_use = tool_manager.get_current_use_of_tool(call.tool_name)
+        max_tool_use = tool_manager.get_max_use_of_tool(call.tool_name)
+        if (
+            max_tool_use is not None and current_tool_use + tool_call_counts[call.tool_name] > max_tool_use
+        ) or not can_make_tool_calls:
+            return_part = _messages.ToolReturnPart(
+                tool_name=call.tool_name,
+                content=f'Tool call limit reached for tool "{call.tool_name}".',
+                tool_call_id=call.tool_call_id,
+                # TODO: Add return kind and prompt_config here once supported by #3656
+            )
+            output_parts.append(return_part)
+            yield _messages.FunctionToolResultEvent(return_part)
+        else:
+            yield _messages.FunctionToolCallEvent(call)
+            calls_to_run.append(call)

     with tracer.start_as_current_span(
         'running tools',
         attributes={
-            'tools': [call.tool_name for call in tool_calls],
-            'logfire.msg': f'running {len(tool_calls)} tool{"" if len(tool_calls) == 1 else "s"}',
+            'tools': [call.tool_name for call in calls_to_run],
+            'logfire.msg': f'running {len(calls_to_run)} tool{"" if len(calls_to_run) == 1 else "s"}',
         },
     ):
```

> **Collaborator:** This seems to duplicate some of what the …
>
> **Collaborator:** Could this be managed entirely in …
>
> **Collaborator:** It would be interesting to implement this suggestion from #3352 (comment) at the same time: …
```diff
@@ -1080,8 +1119,8 @@ async def handle_call_or_result(

         return _messages.FunctionToolResultEvent(tool_part, content=tool_user_content)

-    if tool_manager.should_call_sequentially(tool_calls):
-        for index, call in enumerate(tool_calls):
+    if tool_manager.should_call_sequentially(calls_to_run):
+        for index, call in enumerate(calls_to_run):
             if event := await handle_call_or_result(
                 _call_tool(tool_manager, call, tool_call_results.get(call.tool_call_id)),
                 index,
```
```diff
@@ -1094,7 +1133,7 @@ async def handle_call_or_result(
                 _call_tool(tool_manager, call, tool_call_results.get(call.tool_call_id)),
                 name=call.tool_name,
             )
-            for call in tool_calls
+            for call in calls_to_run
         ]

         pending = tasks
```
```diff
@@ -1111,7 +1150,11 @@ async def handle_call_or_result(
     output_parts.extend([user_parts_by_index[k] for k in sorted(user_parts_by_index)])

     _populate_deferred_calls(
-        tool_calls, deferred_calls_by_index, deferred_metadata_by_index, output_deferred_calls, output_deferred_metadata
+        calls_to_run,
+        deferred_calls_by_index,
+        deferred_metadata_by_index,
+        output_deferred_calls,
+        output_deferred_metadata,
     )
```
```diff
@@ -48,6 +48,8 @@ class RunContext(Generic[RunContextAgentDepsT]):
     """Instrumentation settings version, if instrumentation is enabled."""
     retries: dict[str, int] = field(default_factory=dict)
     """Number of retries for each tool so far."""
+    tool_usage: dict[str, int] = field(default_factory=dict)
+    """Number of calls for each tool so far."""
     tool_call_id: str | None = None
     """The ID of the tool call."""
     tool_name: str | None = None
```

> **Collaborator:** We have to be consistent/specific about calls vs uses and how this relates to retries; does this imply successful calls, or any call at all?
```diff
@@ -56,6 +58,8 @@ class RunContext(Generic[RunContextAgentDepsT]):
     """Number of retries of this tool so far."""
     max_retries: int = 0
     """The maximum number of retries of this tool."""
+    max_uses: int = 0
+    """The maximum number of times this tool can be used in the run."""
     run_step: int = 0
     """The current step in the run."""
     tool_call_approved: bool = False
```
```diff
@@ -64,6 +68,8 @@ class RunContext(Generic[RunContextAgentDepsT]):
     """Whether the output passed to an output validator is partial."""
     run_id: str | None = None
     """Unique identifier for the agent run."""
+    max_tool_calls: int | None = None
+    """The maximum number of tool calls allowed during this run, or `None` if unlimited."""

     @property
     def last_attempt(self) -> bool:
```
```diff
@@ -66,11 +66,22 @@ async def for_run_step(self, ctx: RunContext[AgentDepsT]) -> ToolManager[AgentDe

     @property
     def tool_defs(self) -> list[ToolDefinition]:
-        """The tool definitions for the tools in this tool manager."""
-        if self.tools is None:
+        """The tool definitions for the tools in this tool manager.
+
+        Tools that have reached their `max_uses` limit are filtered out.
+        """
+        if self.tools is None or self.ctx is None:
             raise ValueError('ToolManager has not been prepared for a run step yet')  # pragma: no cover

-        return [tool.tool_def for tool in self.tools.values()]
+        result: list[ToolDefinition] = []
+        for tool in self.tools.values():
+            # Filter out tools that have reached their max_uses limit
+            if tool.max_uses is not None:
+                current_uses = self.ctx.tool_usage.get(tool.tool_def.name, 0)
+                if current_uses >= tool.max_uses:
+                    continue
+            result.append(tool.tool_def)
+        return result

     def should_call_sequentially(self, calls: list[ToolCallPart]) -> bool:
         """Whether to require sequential tool calls for a list of tool calls."""
```

> **Collaborator:** I'd rather have this be a list comprehension.
```diff
@@ -159,8 +170,11 @@ async def _call_tool(
             max_retries=tool.max_retries,
             tool_call_approved=approved,
             partial_output=allow_partial,
+            max_uses=tool.max_uses,
         )

+        self.ctx.tool_usage[name] = self.ctx.tool_usage.get(name, 0) + 1
+
         pyd_allow_partial = 'trailing-strings' if allow_partial else 'off'
         validator = tool.args_validator
         if isinstance(call.args, str):
```

> **Collaborator:** To me, "uses" implied successful uses of a potentially costly thing, so excluding retries. Otherwise I think it should definitely be …
>
> **Collaborator:** If we take it to mean "successful calls", then this part from #3352 (comment) will require special care, especially if we do it in …
```diff
@@ -272,3 +286,37 @@ async def _call_function_tool(
         )

         return tool_result
+
+    def get_max_use_of_tool(self, tool_name: str) -> int | None:
+        """Get the maximum number of uses allowed for a given tool, or `None` if unlimited."""
+        if self.tools is None:
+            raise ValueError('ToolManager has not been prepared for a run step yet')  # pragma: no cover
+
+        tool = self.tools.get(tool_name, None)
+        if tool is None:
+            return None
+
+        return tool.max_uses
+
+    def get_current_use_of_tool(self, tool_name: str) -> int:
+        """Get the current number of uses of a given tool."""
+        if self.ctx is None:
+            raise ValueError('ToolManager has not been prepared for a run step yet')  # pragma: no cover
+
+        return self.ctx.tool_usage.get(tool_name, 0)
+
+    def _get_max_tool_calls(self) -> int | None:
+        """Get the maximum number of tool calls allowed during this run, or `None` if unlimited."""
+        if self.ctx is None:
+            raise ValueError('ToolManager has not been prepared for a run step yet')  # pragma: no cover
+
+        return self.ctx.max_tool_calls
+
+    def can_make_tool_calls(self, num_tool_calls: int, usage: RunUsage) -> bool:
+        """Check if the tool calls can be made within the max_tool_calls limit, if it is set."""
+        max_tool_calls = self._get_max_tool_calls()
+        if max_tool_calls is not None:
+            usage.tool_calls += num_tool_calls
+            if usage.tool_calls > max_tool_calls:
+                return False
+        return True
```

> **Collaborator:** Let's move this to a private helper method like …
>
> **Reviewer:** Shouldn't this require a …
>
> **Collaborator:** @Jeremaiha I don't think so as it's only called from the async event loop, not from threads.
> **Collaborator:** I think this belongs on `tools-advanced.md`, with a link from the usage limits section above. And no need to mention the name of the field in the title, that'll make it too long for the ToC sidebar.