[Experimental]: Async agenerate method ollama functions (#21682)

- **Description:**
Added the missing async `_agenerate` method to `OllamaFunctions`; its absence caused errors for users calling the async API. A usage sketch follows below.
   
- **Issue:** #21422
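
A minimal usage sketch of the new async path, assuming a locally running Ollama server and the experimental wrapper's usual import path; the model name and tool schema are illustrative, not part of this commit:

```python
import asyncio

from langchain_experimental.llms.ollama_functions import OllamaFunctions

# Hypothetical tool schema; any JSON-schema style function definition works.
weather_tool = {
    "name": "get_current_weather",
    "description": "Get the current weather in a given location",
    "parameters": {
        "type": "object",
        "properties": {
            "location": {"type": "string", "description": "City name"},
        },
        "required": ["location"],
    },
}


async def main() -> None:
    llm = OllamaFunctions(model="llama3", format="json")
    # Bind the tool so it reaches _agenerate via kwargs["functions"].
    llm_with_tools = llm.bind(functions=[weather_tool])
    # Before this change, async entry points failed because _agenerate was missing.
    message = await llm_with_tools.ainvoke("What is the weather in Berlin?")
    print(message.additional_kwargs.get("function_call"))


asyncio.run(main())
```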
Mohammad Mohtashim 2024-06-05 20:50:36 +05:00 committed by GitHub
parent 328d0c99f2
commit 7fcef2556c


@@ -17,7 +17,10 @@ from typing import (
)
from langchain_community.chat_models.ollama import ChatOllama
from langchain_core.callbacks import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain_core.language_models import LanguageModelInput
from langchain_core.messages import AIMessage, BaseMessage, ToolCall
from langchain_core.output_parsers.base import OutputParserLike
@@ -369,6 +372,86 @@ class OllamaFunctions(ChatOllama):
            generations=[ChatGeneration(message=response_message_with_functions)]
        )

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        functions = kwargs.get("functions", [])
        if "functions" in kwargs:
            del kwargs["functions"]
        if "function_call" in kwargs:
            functions = [
                fn for fn in functions if fn["name"] == kwargs["function_call"]["name"]
            ]
            if not functions:
                raise ValueError(
                    "If `function_call` is specified, you must also pass a "
                    "matching function in `functions`."
                )
            del kwargs["function_call"]
        elif not functions:
            functions.append(DEFAULT_RESPONSE_FUNCTION)
        if _is_pydantic_class(functions[0]):
            functions = [convert_to_ollama_tool(fn) for fn in functions]
        system_message_prompt_template = SystemMessagePromptTemplate.from_template(
            self.tool_system_prompt_template
        )
        system_message = system_message_prompt_template.format(
            tools=json.dumps(functions, indent=2)
        )
        response_message = await super()._agenerate(
            [system_message] + messages, stop=stop, run_manager=run_manager, **kwargs
        )
        chat_generation_content = response_message.generations[0].text
        if not isinstance(chat_generation_content, str):
            raise ValueError("OllamaFunctions does not support non-string output.")
        try:
            parsed_chat_result = json.loads(chat_generation_content)
        except json.JSONDecodeError:
            raise ValueError(
                f"""'{self.model}' did not respond with valid JSON.
                Please try again.
                Response: {chat_generation_content}"""
            )
        called_tool_name = parsed_chat_result["tool"]
        called_tool_arguments = parsed_chat_result["tool_input"]
        called_tool = next(
            (fn for fn in functions if fn["name"] == called_tool_name), None
        )
        if called_tool is None:
            raise ValueError(
                f"Failed to parse a function call from {self.model} output: "
                f"{chat_generation_content}"
            )
        if called_tool["name"] == DEFAULT_RESPONSE_FUNCTION["name"]:
            return ChatResult(
                generations=[
                    ChatGeneration(
                        message=AIMessage(
                            content=called_tool_arguments["response"],
                        )
                    )
                ]
            )
        response_message_with_functions = AIMessage(
            content="",
            additional_kwargs={
                "function_call": {
                    "name": called_tool_name,
                    "arguments": json.dumps(called_tool_arguments)
                    if called_tool_arguments
                    else "",
                },
            },
        )
        return ChatResult(
            generations=[ChatGeneration(message=response_message_with_functions)]
        )

    @property
    def _llm_type(self) -> str:
        return "ollama_functions"