From acfce300178b9dc91533bb7a2626b7606982363a Mon Sep 17 00:00:00 2001 From: Joel Akeret Date: Wed, 24 Jul 2024 23:57:05 +0200 Subject: [PATCH] Adding compatibility for OllamaFunctions with ImagePromptTemplate (#24499) - [ ] **PR title**: "experimental: Adding compatibility for OllamaFunctions with ImagePromptTemplate" - [ ] **PR message**: - **Description:** Removes the outdated `_convert_messages_to_ollama_messages` method override in the `OllamaFunctions` class to ensure that ollama multimodal models can be invoked with an image. - **Issue:** #24174 --------- Co-authored-by: Joel Akeret Co-authored-by: Isaac Francisco <78627776+isahers1@users.noreply.github.com> Co-authored-by: isaac hershenson --- .../llms/ollama_functions.py | 57 ------------------- .../tests/unit_tests/test_ollama_functions.py | 30 ++++++++++ 2 files changed, 30 insertions(+), 57 deletions(-) create mode 100644 libs/experimental/tests/unit_tests/test_ollama_functions.py diff --git a/libs/experimental/langchain_experimental/llms/ollama_functions.py b/libs/experimental/langchain_experimental/llms/ollama_functions.py index 630e514fd1..ba68a7f3ea 100644 --- a/libs/experimental/langchain_experimental/llms/ollama_functions.py +++ b/libs/experimental/langchain_experimental/llms/ollama_functions.py @@ -12,7 +12,6 @@ from typing import ( TypedDict, TypeVar, Union, - cast, ) from langchain_community.chat_models.ollama import ChatOllama @@ -24,10 +23,7 @@ from langchain_core.language_models import LanguageModelInput from langchain_core.messages import ( AIMessage, BaseMessage, - HumanMessage, - SystemMessage, ToolCall, - ToolMessage, ) from langchain_core.output_parsers.base import OutputParserLike from langchain_core.output_parsers.json import JsonOutputParser @@ -282,59 +278,6 @@ class OllamaFunctions(ChatOllama): else: return llm | parser_chain - def _convert_messages_to_ollama_messages( - self, messages: List[BaseMessage] - ) -> List[Dict[str, Union[str, List[str]]]]: - ollama_messages: List = [] - for message in messages: - role = "" - if isinstance(message, HumanMessage): - role = "user" - elif isinstance(message, AIMessage) or isinstance(message, ToolMessage): - role = "assistant" - elif isinstance(message, SystemMessage): - role = "system" - else: - raise ValueError("Received unsupported message type for Ollama.") - - content = "" - images = [] - if isinstance(message.content, str): - content = message.content - else: - for content_part in cast(List[Dict], message.content): - if content_part.get("type") == "text": - content += f"\n{content_part['text']}" - elif content_part.get("type") == "image_url": - if isinstance(content_part.get("image_url"), str): - image_url_components = content_part["image_url"].split(",") - # Support data:image/jpeg;base64, format - # and base64 strings - if len(image_url_components) > 1: - images.append(image_url_components[1]) - else: - images.append(image_url_components[0]) - else: - raise ValueError( - "Only string image_url content parts are supported." - ) - else: - raise ValueError( - "Unsupported message content type. " - "Must either have type 'text' or type 'image_url' " - "with a string 'image_url' field." - ) - - ollama_messages.append( - { - "role": role, - "content": content, - "images": images, - } - ) - - return ollama_messages - def _generate( self, messages: List[BaseMessage], diff --git a/libs/experimental/tests/unit_tests/test_ollama_functions.py b/libs/experimental/tests/unit_tests/test_ollama_functions.py new file mode 100644 index 0000000000..1404473d50 --- /dev/null +++ b/libs/experimental/tests/unit_tests/test_ollama_functions.py @@ -0,0 +1,30 @@ +import json +from typing import Any +from unittest.mock import patch + +from langchain_core.prompts import ChatPromptTemplate +from langchain_core.pydantic_v1 import BaseModel + +from langchain_experimental.llms.ollama_functions import OllamaFunctions + + +class Schema(BaseModel): + pass + + +@patch.object(OllamaFunctions, "_create_stream") +def test_convert_image_prompt( + _create_stream_mock: Any, +) -> None: + response = {"message": {"content": '{"tool": "Schema", "tool_input": {}}'}} + _create_stream_mock.return_value = [json.dumps(response)] + + prompt = ChatPromptTemplate.from_messages( + [("human", [{"image_url": "data:image/jpeg;base64,{image_url}"}])] + ) + + lmm = prompt | OllamaFunctions().with_structured_output(schema=Schema) + + schema_instance = lmm.invoke(dict(image_url="")) + + assert schema_instance is not None