diff --git a/auto-configurations/models/tool/spring-ai-autoconfigure-model-tool/src/main/java/org/springframework/ai/model/tool/autoconfigure/ToolCallingAutoConfiguration.java b/auto-configurations/models/tool/spring-ai-autoconfigure-model-tool/src/main/java/org/springframework/ai/model/tool/autoconfigure/ToolCallingAutoConfiguration.java index a7dfbc74f40..bbbdd2f46b5 100644 --- a/auto-configurations/models/tool/spring-ai-autoconfigure-model-tool/src/main/java/org/springframework/ai/model/tool/autoconfigure/ToolCallingAutoConfiguration.java +++ b/auto-configurations/models/tool/spring-ai-autoconfigure-model-tool/src/main/java/org/springframework/ai/model/tool/autoconfigure/ToolCallingAutoConfiguration.java @@ -32,6 +32,7 @@ import org.springframework.ai.tool.observation.ToolCallingContentObservationFilter; import org.springframework.ai.tool.observation.ToolCallingObservationConvention; import org.springframework.ai.tool.resolution.DelegatingToolCallbackResolver; +import org.springframework.ai.tool.resolution.ProviderToolCallbackResolver; import org.springframework.ai.tool.resolution.SpringBeanToolCallbackResolver; import org.springframework.ai.tool.resolution.StaticToolCallbackResolver; import org.springframework.ai.tool.resolution.ToolCallbackResolver; @@ -51,6 +52,7 @@ * @author Thomas Vitale * @author Christian Tzolov * @author Daniel Garnier-Moiroux + * @author Yanming Zhou * @since 1.0.0 */ @AutoConfiguration @@ -66,15 +68,17 @@ ToolCallbackResolver toolCallbackResolver(GenericApplicationContext applicationC List toolCallbacks, List tcbProviders) { List allFunctionAndToolCallbacks = new ArrayList<>(toolCallbacks); - tcbProviders.stream().map(pr -> List.of(pr.getToolCallbacks())).forEach(allFunctionAndToolCallbacks::addAll); var staticToolCallbackResolver = new StaticToolCallbackResolver(allFunctionAndToolCallbacks); + var providerToolCallbackResolver = new ProviderToolCallbackResolver(tcbProviders); + var springBeanToolCallbackResolver = SpringBeanToolCallbackResolver.builder() .applicationContext(applicationContext) .build(); - return new DelegatingToolCallbackResolver(List.of(staticToolCallbackResolver, springBeanToolCallbackResolver)); + return new DelegatingToolCallbackResolver( + List.of(staticToolCallbackResolver, providerToolCallbackResolver, springBeanToolCallbackResolver)); } @Bean diff --git a/auto-configurations/models/tool/spring-ai-autoconfigure-model-tool/src/test/java/org/springframework/ai/model/tool/autoconfigure/ToolCallingAutoConfigurationTests.java b/auto-configurations/models/tool/spring-ai-autoconfigure-model-tool/src/test/java/org/springframework/ai/model/tool/autoconfigure/ToolCallingAutoConfigurationTests.java index af42d744158..ac2d9fa943a 100644 --- a/auto-configurations/models/tool/spring-ai-autoconfigure-model-tool/src/test/java/org/springframework/ai/model/tool/autoconfigure/ToolCallingAutoConfigurationTests.java +++ b/auto-configurations/models/tool/spring-ai-autoconfigure-model-tool/src/test/java/org/springframework/ai/model/tool/autoconfigure/ToolCallingAutoConfigurationTests.java @@ -45,12 +45,17 @@ import org.springframework.util.ReflectionUtils; import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; /** * Unit tests for {@link ToolCallingAutoConfiguration}. * * @author Thomas Vitale * @author Christian Tzolov + * @author Yanming Zhou */ class ToolCallingAutoConfigurationTests { @@ -69,6 +74,19 @@ void beansAreCreated() { }); } + @Test + void deferGetToolCallbacks() { + new ApplicationContextRunner().withConfiguration(AutoConfigurations.of(ToolCallingAutoConfiguration.class)) + .withUserConfiguration(Config.class) + .run(context -> { + var toolCallbackResolver = context.getBean(ToolCallbackResolver.class); + var toolCallbackProvider = context.getBean("toolCallbacks", ToolCallbackProvider.class); + verify(toolCallbackProvider, never()).getToolCallbacks(); + assertThat(toolCallbackResolver.resolve("getForecast")).isNotNull(); + verify(toolCallbackProvider, times(1)).getToolCallbacks(); + }); + } + @Test void resolveMultipleFunctionAndToolCallbacks() { new ApplicationContextRunner().withConfiguration(AutoConfigurations.of(ToolCallingAutoConfiguration.class)) @@ -212,7 +230,10 @@ static class Config { // ToolCallbacks.from(...) utility method. @Bean public ToolCallbackProvider toolCallbacks() { - return MethodToolCallbackProvider.builder().toolObjects(new WeatherService()).build(); + ToolCallbackProvider toolCallbackProvider = MethodToolCallbackProvider.builder() + .toolObjects(new WeatherService()) + .build(); + return spy(toolCallbackProvider); } @Bean diff --git a/spring-ai-model/src/main/java/org/springframework/ai/tool/resolution/ProviderToolCallbackResolver.java b/spring-ai-model/src/main/java/org/springframework/ai/tool/resolution/ProviderToolCallbackResolver.java new file mode 100644 index 00000000000..f5efec01160 --- /dev/null +++ b/spring-ai-model/src/main/java/org/springframework/ai/tool/resolution/ProviderToolCallbackResolver.java @@ -0,0 +1,63 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.tool.resolution; + +import java.util.List; +import java.util.stream.Stream; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.springframework.ai.tool.ToolCallback; +import org.springframework.ai.tool.ToolCallbackProvider; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.function.SingletonSupplier; + +/** + * A {@link ToolCallbackResolver} that resolves tool callbacks from + * {@link ToolCallbackProvider} lazily. + * + * @author Yanming Zhou + */ +public class ProviderToolCallbackResolver implements ToolCallbackResolver { + + private static final Logger logger = LoggerFactory.getLogger(ProviderToolCallbackResolver.class); + + private final SingletonSupplier> toolCallbackSupplier; + + public ProviderToolCallbackResolver(List toolCallbackProviders) { + Assert.notNull(toolCallbackProviders, "toolCallbackProviders cannot be null"); + + this.toolCallbackSupplier = SingletonSupplier.of(() -> toolCallbackProviders.stream() + .flatMap(provider -> Stream.of(provider.getToolCallbacks())) + .toList()); + } + + @Override + @Nullable + public ToolCallback resolve(String toolName) { + Assert.hasText(toolName, "toolName cannot be null or empty"); + logger.debug("ToolCallback resolution attempt from tool callback provider"); + return this.toolCallbackSupplier.obtain() + .stream() + .filter(toolCallback -> toolName.equals(toolCallback.getToolDefinition().name())) + .findAny() + .orElse(null); + } + +}