From 2ddac9a7c347efd59a2b0e508955a57eca9f8eba Mon Sep 17 00:00:00 2001 From: Karim Lalani Date: Mon, 29 Apr 2024 09:13:33 -0500 Subject: [PATCH] experimental[minor]: Add bind_tools and with_structured_output functions to OllamaFunctions (#20881) Implemented bind_tools for OllamaFunctions. Made OllamaFunctions sub class of ChatOllama. Implemented with_structured_output for OllamaFunctions. integration unit test has been updated. notebook has been updated. --------- Co-authored-by: Bagatur --- .../integrations/chat/ollama_functions.ipynb | 151 +++++++--- .../llms/ollama_functions.py | 274 ++++++++++++++++-- .../llms/test_ollama_functions.py | 46 ++- 3 files changed, 401 insertions(+), 70 deletions(-) diff --git a/docs/docs/integrations/chat/ollama_functions.ipynb b/docs/docs/integrations/chat/ollama_functions.ipynb index 8a2e2826e9..ae4d0ac205 100644 --- a/docs/docs/integrations/chat/ollama_functions.ipynb +++ b/docs/docs/integrations/chat/ollama_functions.ipynb @@ -17,7 +17,7 @@ "\n", "This notebook shows how to use an experimental wrapper around Ollama that gives it the same API as OpenAI Functions.\n", "\n", - "Note that more powerful and capable models will perform better with complex schema and/or multiple functions. The examples below use Mistral.\n", + "Note that more powerful and capable models will perform better with complex schema and/or multiple functions. The examples below use llama3 and phi3 models.\n", "For a complete list of supported models and model variants, see the [Ollama model library](https://ollama.ai/library).\n", "\n", "## Setup\n", @@ -32,12 +32,18 @@ { "cell_type": "code", "execution_count": 1, - "metadata": {}, + "metadata": { + "ExecuteTime": { + "end_time": "2024-04-28T00:53:25.276543Z", + "start_time": "2024-04-28T00:53:24.881202Z" + }, + "scrolled": true + }, "outputs": [], "source": [ "from langchain_experimental.llms.ollama_functions import OllamaFunctions\n", "\n", - "model = OllamaFunctions(model=\"mistral\")" + "model = OllamaFunctions(model=\"llama3\", format=\"json\")" ] }, { @@ -50,11 +56,16 @@ { "cell_type": "code", "execution_count": 2, - "metadata": {}, + "metadata": { + "ExecuteTime": { + "end_time": "2024-04-26T04:59:17.270931Z", + "start_time": "2024-04-26T04:59:17.263347Z" + } + }, "outputs": [], "source": [ - "model = model.bind(\n", - " functions=[\n", + "model = model.bind_tools(\n", + " tools=[\n", " {\n", " \"name\": \"get_current_weather\",\n", " \"description\": \"Get the current weather in a given location\",\n", @@ -88,12 +99,17 @@ { "cell_type": "code", "execution_count": 3, - "metadata": {}, + "metadata": { + "ExecuteTime": { + "end_time": "2024-04-26T04:59:26.092428Z", + "start_time": "2024-04-26T04:59:17.272627Z" + } + }, "outputs": [ { "data": { "text/plain": [ - "AIMessage(content='', additional_kwargs={'function_call': {'name': 'get_current_weather', 'arguments': '{\"location\": \"Boston, MA\", \"unit\": \"celsius\"}'}})" + "AIMessage(content='', additional_kwargs={'function_call': {'name': 'get_current_weather', 'arguments': '{\"location\": \"Boston, MA\"}'}}, id='run-1791f9fe-95ad-4ca4-bdf7-9f73eab31e6f-0')" ] }, "execution_count": 3, @@ -111,54 +127,119 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Using for extraction\n", + "## Structured Output\n", "\n", - "One useful thing you can do with function calling here is extracting properties from a given input in a structured format:" + "One useful thing you can do with function calling using `with_structured_output()` function is extracting properties from a 
given input in a structured format:" ] }, { "cell_type": "code", "execution_count": 4, + "metadata": { + "ExecuteTime": { + "end_time": "2024-04-26T04:59:26.098828Z", + "start_time": "2024-04-26T04:59:26.094021Z" + } + }, + "outputs": [], + "source": [ + "from langchain_core.prompts import PromptTemplate\n", + "from langchain_core.pydantic_v1 import BaseModel, Field\n", + "\n", + "\n", + "# Schema for structured response\n", + "class Person(BaseModel):\n", + " name: str = Field(description=\"The person's name\", required=True)\n", + " height: float = Field(description=\"The person's height\", required=True)\n", + " hair_color: str = Field(description=\"The person's hair color\")\n", + "\n", + "\n", + "# Prompt template\n", + "prompt = PromptTemplate.from_template(\n", + " \"\"\"Alex is 5 feet tall. \n", + "Claudia is 1 feet taller than Alex and jumps higher than him. \n", + "Claudia is a brunette and Alex is blonde.\n", + "\n", + "Human: {question}\n", + "AI: \"\"\"\n", + ")\n", + "\n", + "# Chain\n", + "llm = OllamaFunctions(model=\"phi3\", format=\"json\", temperature=0)\n", + "structured_llm = llm.with_structured_output(Person)\n", + "chain = prompt | structured_llm" + ] + }, + { + "cell_type": "markdown", "metadata": {}, + "source": [ + "### Extracting data about Alex" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "ExecuteTime": { + "end_time": "2024-04-26T04:59:30.164955Z", + "start_time": "2024-04-26T04:59:26.099790Z" + } + }, "outputs": [ { "data": { "text/plain": [ - "[{'name': 'Alex', 'height': 5, 'hair_color': 'blonde'},\n", - " {'name': 'Claudia', 'height': 6, 'hair_color': 'brunette'}]" + "Person(name='Alex', height=5.0, hair_color='blonde')" ] }, - "execution_count": 4, + "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "from langchain.chains import create_extraction_chain\n", - "\n", - "# Schema\n", - "schema = {\n", - " \"properties\": {\n", - " \"name\": {\"type\": \"string\"},\n", - " \"height\": {\"type\": \"integer\"},\n", - " \"hair_color\": {\"type\": \"string\"},\n", - " },\n", - " \"required\": [\"name\", \"height\"],\n", - "}\n", - "\n", - "# Input\n", - "input = \"\"\"Alex is 5 feet tall. Claudia is 1 feet taller than Alex and jumps higher than him. 
Claudia is a brunette and Alex is blonde.\"\"\"\n", - "\n", - "# Run chain\n", - "llm = OllamaFunctions(model=\"mistral\", temperature=0)\n", - "chain = create_extraction_chain(schema, llm)\n", - "chain.run(input)" + "alex = chain.invoke(\"Describe Alex\")\n", + "alex" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Extracting data about Claudia" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "ExecuteTime": { + "end_time": "2024-04-26T04:59:31.509846Z", + "start_time": "2024-04-26T04:59:30.165662Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "Person(name='Claudia', height=6.0, hair_color='brunette')" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "claudia = chain.invoke(\"Describe Claudia\")\n", + "claudia" ] } ], "metadata": { "kernelspec": { - "display_name": ".venv", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -172,9 +253,9 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.5" + "version": "3.9.1" } }, "nbformat": 4, - "nbformat_minor": 2 + "nbformat_minor": 4 } diff --git a/libs/experimental/langchain_experimental/llms/ollama_functions.py b/libs/experimental/langchain_experimental/llms/ollama_functions.py index af5d7a478b..7bd04f918f 100644 --- a/libs/experimental/langchain_experimental/llms/ollama_functions.py +++ b/libs/experimental/langchain_experimental/llms/ollama_functions.py @@ -1,14 +1,34 @@ import json -from typing import Any, Dict, List, Optional +from operator import itemgetter +from typing import ( + Any, + Callable, + Dict, + List, + Literal, + Optional, + Sequence, + Type, + TypedDict, + TypeVar, + Union, + overload, +) from langchain_community.chat_models.ollama import ChatOllama from langchain_core.callbacks import CallbackManagerForLLMRun -from langchain_core.language_models import BaseChatModel +from langchain_core.language_models import LanguageModelInput from langchain_core.messages import AIMessage, BaseMessage +from langchain_core.output_parsers.base import OutputParserLike +from langchain_core.output_parsers.json import JsonOutputParser +from langchain_core.output_parsers.pydantic import PydanticOutputParser from langchain_core.outputs import ChatGeneration, ChatResult from langchain_core.prompts import SystemMessagePromptTemplate - -from langchain_experimental.pydantic_v1 import root_validator +from langchain_core.pydantic_v1 import BaseModel +from langchain_core.runnables import Runnable, RunnableLambda +from langchain_core.runnables.base import RunnableMap +from langchain_core.runnables.passthrough import RunnablePassthrough +from langchain_core.tools import BaseTool DEFAULT_SYSTEM_TEMPLATE = """You have access to the following tools: @@ -22,7 +42,6 @@ You must always select one of the above tools and respond with only a JSON objec }} """ # noqa: E501 - DEFAULT_RESPONSE_FUNCTION = { "name": "__conversational_response", "description": ( @@ -40,26 +59,219 @@ DEFAULT_RESPONSE_FUNCTION = { }, } +_BM = TypeVar("_BM", bound=BaseModel) +_DictOrPydanticClass = Union[Dict[str, Any], Type[_BM]] +_DictOrPydantic = Union[Dict, _BM] + + +def _is_pydantic_class(obj: Any) -> bool: + return isinstance(obj, type) and ( + issubclass(obj, BaseModel) or BaseModel in obj.__bases__ + ) + + +def convert_to_ollama_tool(tool: Any) -> Dict: + """Convert a tool to an Ollama tool.""" + if _is_pydantic_class(tool): + schema = tool.construct().schema() + 
definition = {"name": schema["title"], "properties": schema["properties"]} + if "required" in schema: + definition["required"] = schema["required"] + + return definition + raise ValueError( + f"Cannot convert {tool} to an Ollama tool. {tool} needs to be a Pydantic model." + ) -class OllamaFunctions(BaseChatModel): - """Function chat model that uses Ollama API.""" - llm: ChatOllama +class _AllReturnType(TypedDict): + raw: BaseMessage + parsed: Optional[_DictOrPydantic] + parsing_error: Optional[BaseException] - tool_system_prompt_template: str - @root_validator(pre=True) - def validate_environment(cls, values: Dict) -> Dict: - values["llm"] = values.get("llm") or ChatOllama(**values, format="json") - values["tool_system_prompt_template"] = ( - values.get("tool_system_prompt_template") or DEFAULT_SYSTEM_TEMPLATE +def parse_response(message: BaseMessage) -> str: + """Extract `function_call` from `AIMessage`.""" + if isinstance(message, AIMessage): + kwargs = message.additional_kwargs + if "function_call" in kwargs: + if "arguments" in kwargs["function_call"]: + return kwargs["function_call"]["arguments"] + raise ValueError( + f"`arguments` missing from `function_call` within AIMessage: {message}" + ) + raise ValueError( + "`function_call` missing from `additional_kwargs` " + f"within AIMessage: {message}" ) - return values + raise ValueError(f"`message` is not an instance of `AIMessage`: {message}") - @property - def model(self) -> BaseChatModel: - """For backwards compatibility.""" - return self.llm + +class OllamaFunctions(ChatOllama): + """Function chat model that uses Ollama API.""" + + tool_system_prompt_template: str = DEFAULT_SYSTEM_TEMPLATE + + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + + def bind_tools( + self, + tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]], + **kwargs: Any, + ) -> Runnable[LanguageModelInput, BaseMessage]: + return self.bind(functions=tools, **kwargs) + + @overload + def with_structured_output( + self, + schema: Optional[_DictOrPydanticClass] = None, + *, + include_raw: Literal[True] = True, + **kwargs: Any, + ) -> Runnable[LanguageModelInput, _AllReturnType]: + ... + + @overload + def with_structured_output( + self, + schema: Optional[_DictOrPydanticClass] = None, + *, + include_raw: Literal[False] = False, + **kwargs: Any, + ) -> Runnable[LanguageModelInput, _DictOrPydantic]: + ... + + def with_structured_output( + self, + schema: Optional[_DictOrPydanticClass] = None, + *, + include_raw: bool = False, + **kwargs: Any, + ) -> Runnable[LanguageModelInput, _DictOrPydantic]: + """Model wrapper that returns outputs formatted to match the given schema. + + Args: + schema: The output schema as a dict or a Pydantic class. If a Pydantic class + then the model output will be an object of that class. If a dict then + the model output will be a dict. With a Pydantic class the returned + attributes will be validated, whereas with a dict they will not be. + include_raw: If False then only the parsed structured output is returned. If + an error occurs during model output parsing it will be raised. If True + then both the raw model response (a BaseMessage) and the parsed model + response will be returned. If an error occurs during output parsing it + will be caught and returned as well. The final output is always a dict + with keys "raw", "parsed", and "parsing_error". 
+
+        Returns:
+            A Runnable that takes any ChatModel input and returns as output:
+
+                If include_raw is True then a dict with keys:
+                    raw: BaseMessage
+                    parsed: Optional[_DictOrPydantic]
+                    parsing_error: Optional[BaseException]
+
+                If include_raw is False then just _DictOrPydantic is returned,
+                where _DictOrPydantic depends on the schema:
+
+                If schema is a Pydantic class then _DictOrPydantic is the Pydantic
+                class.
+
+                If schema is a dict then _DictOrPydantic is a dict.
+
+        Example: Pydantic schema (include_raw=False):
+            .. code-block:: python
+
+                from langchain_experimental.llms import OllamaFunctions
+                from langchain_core.pydantic_v1 import BaseModel
+
+                class AnswerWithJustification(BaseModel):
+                    '''An answer to the user question along with justification for the answer.'''
+                    answer: str
+                    justification: str
+
+                llm = OllamaFunctions(model="phi3", format="json", temperature=0)
+                structured_llm = llm.with_structured_output(AnswerWithJustification)
+
+                structured_llm.invoke("What weighs more a pound of bricks or a pound of feathers")
+
+                # -> AnswerWithJustification(
+                #     answer='They weigh the same',
+                #     justification='Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume or density of the objects may differ.'
+                # )
+
+        Example: Pydantic schema (include_raw=True):
+            .. code-block:: python
+
+                from langchain_experimental.llms import OllamaFunctions
+                from langchain_core.pydantic_v1 import BaseModel
+
+                class AnswerWithJustification(BaseModel):
+                    '''An answer to the user question along with justification for the answer.'''
+                    answer: str
+                    justification: str
+
+                llm = OllamaFunctions(model="phi3", format="json", temperature=0)
+                structured_llm = llm.with_structured_output(AnswerWithJustification, include_raw=True)
+
+                structured_llm.invoke("What weighs more a pound of bricks or a pound of feathers")
+                # -> {
+                #     'raw': AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_Ao02pnFYXD6GN1yzc0uXPsvF', 'function': {'arguments': '{"answer":"They weigh the same.","justification":"Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume or density of the objects may differ."}', 'name': 'AnswerWithJustification'}, 'type': 'function'}]}),
+                #     'parsed': AnswerWithJustification(answer='They weigh the same.', justification='Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume or density of the objects may differ.'),
+                #     'parsing_error': None
+                # }
+
+        Example: dict schema (include_raw=False):
+            .. code-block:: python
+
+                from langchain_experimental.llms import OllamaFunctions, convert_to_ollama_tool
+                from langchain_core.pydantic_v1 import BaseModel
+
+                class AnswerWithJustification(BaseModel):
+                    '''An answer to the user question along with justification for the answer.'''
+                    answer: str
+                    justification: str
+
+                dict_schema = convert_to_ollama_tool(AnswerWithJustification)
+                llm = OllamaFunctions(model="phi3", format="json", temperature=0)
+                structured_llm = llm.with_structured_output(dict_schema)
+
+                structured_llm.invoke("What weighs more a pound of bricks or a pound of feathers")
+                # -> {
+                #     'answer': 'They weigh the same',
+                #     'justification': 'Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume and density of the two substances differ.'
+ # } + + + """ # noqa: E501 + if kwargs: + raise ValueError(f"Received unsupported arguments {kwargs}") + is_pydantic_schema = _is_pydantic_class(schema) + if schema is None: + raise ValueError( + "schema must be specified when method is 'function_calling'. " + "Received None." + ) + llm = self.bind_tools(tools=[schema], format="json") + if is_pydantic_schema: + output_parser: OutputParserLike = PydanticOutputParser( + pydantic_object=schema + ) + else: + output_parser = JsonOutputParser() + + parser_chain = RunnableLambda(parse_response) | output_parser + if include_raw: + parser_assign = RunnablePassthrough.assign( + parsed=itemgetter("raw") | parser_chain, parsing_error=lambda _: None + ) + parser_none = RunnablePassthrough.assign(parsed=lambda _: None) + parser_with_fallback = parser_assign.with_fallbacks( + [parser_none], exception_key="parsing_error" + ) + return RunnableMap(raw=llm) | parser_with_fallback + else: + return llm | parser_chain def _generate( self, @@ -69,37 +281,41 @@ class OllamaFunctions(BaseChatModel): **kwargs: Any, ) -> ChatResult: functions = kwargs.get("functions", []) + if "functions" in kwargs: + del kwargs["functions"] if "function_call" in kwargs: functions = [ fn for fn in functions if fn["name"] == kwargs["function_call"]["name"] ] if not functions: raise ValueError( - 'If "function_call" is specified, you must also pass a matching \ -function in "functions".' + "If `function_call` is specified, you must also pass a " + "matching function in `functions`." ) del kwargs["function_call"] elif not functions: functions.append(DEFAULT_RESPONSE_FUNCTION) + if _is_pydantic_class(functions[0]): + functions = [convert_to_ollama_tool(fn) for fn in functions] system_message_prompt_template = SystemMessagePromptTemplate.from_template( self.tool_system_prompt_template ) system_message = system_message_prompt_template.format( tools=json.dumps(functions, indent=2) ) - if "functions" in kwargs: - del kwargs["functions"] - response_message = self.llm.invoke( - [system_message] + messages, stop=stop, callbacks=run_manager, **kwargs + response_message = super()._generate( + [system_message] + messages, stop=stop, run_manager=run_manager, **kwargs ) - chat_generation_content = response_message.content + chat_generation_content = response_message.generations[0].text if not isinstance(chat_generation_content, str): raise ValueError("OllamaFunctions does not support non-string output.") try: parsed_chat_result = json.loads(chat_generation_content) except json.JSONDecodeError: raise ValueError( - f'"{self.llm.model}" did not respond with valid JSON. Please try again.' + f"""'{self.model}' did not respond with valid JSON. + Please try again. + Response: {chat_generation_content}""" ) called_tool_name = parsed_chat_result["tool"] called_tool_arguments = parsed_chat_result["tool_input"] @@ -108,8 +324,8 @@ function in "functions".' 
) if called_tool is None: raise ValueError( - f"Failed to parse a function call from {self.llm.model} \ -output: {chat_generation_content}" + f"Failed to parse a function call from {self.model} output: " + f"{chat_generation_content}" ) if called_tool["name"] == DEFAULT_RESPONSE_FUNCTION["name"]: return ChatResult( diff --git a/libs/experimental/tests/integration_tests/llms/test_ollama_functions.py b/libs/experimental/tests/integration_tests/llms/test_ollama_functions.py index c1b845bdbd..fb63ee5d38 100644 --- a/libs/experimental/tests/integration_tests/llms/test_ollama_functions.py +++ b/libs/experimental/tests/integration_tests/llms/test_ollama_functions.py @@ -2,9 +2,18 @@ import unittest -from langchain_community.chat_models.ollama import ChatOllama +from langchain_core.messages import AIMessage +from langchain_core.pydantic_v1 import BaseModel, Field -from langchain_experimental.llms.ollama_functions import OllamaFunctions +from langchain_experimental.llms.ollama_functions import ( + OllamaFunctions, + convert_to_ollama_tool, +) + + +class Joke(BaseModel): + setup: str = Field(description="The setup of the joke") + punchline: str = Field(description="The punchline to the joke") class TestOllamaFunctions(unittest.TestCase): @@ -13,12 +22,11 @@ class TestOllamaFunctions(unittest.TestCase): """ def test_default_ollama_functions(self) -> None: - base_model = OllamaFunctions(model="mistral") - self.assertIsInstance(base_model.model, ChatOllama) + base_model = OllamaFunctions(model="llama3", format="json") # bind functions - model = base_model.bind( - functions=[ + model = base_model.bind_tools( + tools=[ { "name": "get_current_weather", "description": "Get the current weather in a given location", @@ -47,3 +55,29 @@ class TestOllamaFunctions(unittest.TestCase): function_call = res.additional_kwargs.get("function_call") assert function_call self.assertEqual(function_call.get("name"), "get_current_weather") + + def test_ollama_structured_output(self) -> None: + model = OllamaFunctions(model="phi3") + structured_llm = model.with_structured_output(Joke, include_raw=False) + + res = structured_llm.invoke("Tell me a joke about cats") + assert isinstance(res, Joke) + + def test_ollama_structured_output_with_json(self) -> None: + model = OllamaFunctions(model="phi3") + joke_schema = convert_to_ollama_tool(Joke) + structured_llm = model.with_structured_output(joke_schema, include_raw=False) + + res = structured_llm.invoke("Tell me a joke about cats") + assert "setup" in res + assert "punchline" in res + + def test_ollama_structured_output_raw(self) -> None: + model = OllamaFunctions(model="phi3") + structured_llm = model.with_structured_output(Joke, include_raw=True) + + res = structured_llm.invoke("Tell me a joke about cars") + assert "raw" in res + assert "parsed" in res + assert isinstance(res["raw"], AIMessage) + assert isinstance(res["parsed"], Joke)
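
Usage sketch for the API added above. This is a minimal end-to-end example, not part of the patch itself; it assumes a local Ollama server with the `llama3` and `phi3` models pulled (as in the notebook examples), and the weather tool definition and prompts are illustrative only.

.. code-block:: python

    from langchain_core.pydantic_v1 import BaseModel, Field

    from langchain_experimental.llms.ollama_functions import OllamaFunctions


    # Schema for structured output (same shape as the Joke model in the tests).
    class Joke(BaseModel):
        setup: str = Field(description="The setup of the joke")
        punchline: str = Field(description="The punchline to the joke")


    # Function calling: bind a JSON tool schema; the model answers with a
    # `function_call` entry in `additional_kwargs`.
    llm = OllamaFunctions(model="llama3", format="json")
    bound = llm.bind_tools(
        tools=[
            {
                "name": "get_current_weather",
                "description": "Get the current weather in a given location",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "location": {
                            "type": "string",
                            "description": "The city and state, e.g. San Francisco, CA",
                        }
                    },
                    "required": ["location"],
                },
            }
        ]
    )
    response = bound.invoke("What is the weather in Boston?")
    print(response.additional_kwargs["function_call"])

    # Structured output: with_structured_output() wires bind_tools() to a
    # Pydantic parser and returns a validated Joke instance.
    structured = OllamaFunctions(
        model="phi3", format="json", temperature=0
    ).with_structured_output(Joke)
    print(structured.invoke("Tell me a joke about cats"))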