diff --git a/libs/experimental/langchain_experimental/llms/ollama_functions.py b/libs/experimental/langchain_experimental/llms/ollama_functions.py index f71f91c01a..d9ba49f7b1 100644 --- a/libs/experimental/langchain_experimental/llms/ollama_functions.py +++ b/libs/experimental/langchain_experimental/llms/ollama_functions.py @@ -17,7 +17,10 @@ from typing import ( ) from langchain_community.chat_models.ollama import ChatOllama -from langchain_core.callbacks import CallbackManagerForLLMRun +from langchain_core.callbacks import ( + AsyncCallbackManagerForLLMRun, + CallbackManagerForLLMRun, +) from langchain_core.language_models import LanguageModelInput from langchain_core.messages import AIMessage, BaseMessage, ToolCall from langchain_core.output_parsers.base import OutputParserLike @@ -369,6 +372,86 @@ class OllamaFunctions(ChatOllama): generations=[ChatGeneration(message=response_message_with_functions)] ) + async def _agenerate( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> ChatResult: + functions = kwargs.get("functions", []) + if "functions" in kwargs: + del kwargs["functions"] + if "function_call" in kwargs: + functions = [ + fn for fn in functions if fn["name"] == kwargs["function_call"]["name"] + ] + if not functions: + raise ValueError( + "If `function_call` is specified, you must also pass a " + "matching function in `functions`." + ) + del kwargs["function_call"] + elif not functions: + functions.append(DEFAULT_RESPONSE_FUNCTION) + if _is_pydantic_class(functions[0]): + functions = [convert_to_ollama_tool(fn) for fn in functions] + system_message_prompt_template = SystemMessagePromptTemplate.from_template( + self.tool_system_prompt_template + ) + system_message = system_message_prompt_template.format( + tools=json.dumps(functions, indent=2) + ) + response_message = await super()._agenerate( + [system_message] + messages, stop=stop, run_manager=run_manager, **kwargs + ) + chat_generation_content = response_message.generations[0].text + if not isinstance(chat_generation_content, str): + raise ValueError("OllamaFunctions does not support non-string output.") + try: + parsed_chat_result = json.loads(chat_generation_content) + except json.JSONDecodeError: + raise ValueError( + f"""'{self.model}' did not respond with valid JSON. + Please try again. + Response: {chat_generation_content}""" + ) + called_tool_name = parsed_chat_result["tool"] + called_tool_arguments = parsed_chat_result["tool_input"] + called_tool = next( + (fn for fn in functions if fn["name"] == called_tool_name), None + ) + if called_tool is None: + raise ValueError( + f"Failed to parse a function call from {self.model} output: " + f"{chat_generation_content}" + ) + if called_tool["name"] == DEFAULT_RESPONSE_FUNCTION["name"]: + return ChatResult( + generations=[ + ChatGeneration( + message=AIMessage( + content=called_tool_arguments["response"], + ) + ) + ] + ) + + response_message_with_functions = AIMessage( + content="", + additional_kwargs={ + "function_call": { + "name": called_tool_name, + "arguments": json.dumps(called_tool_arguments) + if called_tool_arguments + else "", + }, + }, + ) + return ChatResult( + generations=[ChatGeneration(message=response_message_with_functions)] + ) + @property def _llm_type(self) -> str: return "ollama_functions"