From e4da7918f37d0966b04d10649baeb54721013877 Mon Sep 17 00:00:00 2001
From: Erick Friis
Date: Thu, 8 Feb 2024 17:29:53 -0800
Subject: [PATCH] google-genai[patch]: fix streaming, function calling (#17268)

---
 .../langchain_google_genai/chat_models.py     | 41 +++++++++++++++----
 libs/partners/google-genai/poetry.lock        | 10 ++---
 .../tests/unit_tests/test_chat_models.py      | 17 +++++++-
 3 files changed, 53 insertions(+), 15 deletions(-)

diff --git a/libs/partners/google-genai/langchain_google_genai/chat_models.py b/libs/partners/google-genai/langchain_google_genai/chat_models.py
index 8d56e983c8..420dbcfd13 100644
--- a/libs/partners/google-genai/langchain_google_genai/chat_models.py
+++ b/libs/partners/google-genai/langchain_google_genai/chat_models.py
@@ -35,6 +35,7 @@ from langchain_core.callbacks.manager import (
 from langchain_core.language_models.chat_models import BaseChatModel
 from langchain_core.messages import (
     AIMessage,
+    AIMessageChunk,
     BaseMessage,
     FunctionMessage,
     HumanMessage,
@@ -339,11 +340,23 @@ llm = ChatGoogleGenerativeAI(model="gemini-pro", convert_system_message_to_human
                 parts = _convert_to_parts(message.content)
         elif isinstance(message, FunctionMessage):
             role = "user"
+            response: Any
+            if not isinstance(message.content, str):
+                response = message.content
+            else:
+                try:
+                    response = json.loads(message.content)
+                except json.JSONDecodeError:
+                    response = message.content  # leave as str representation
             parts = [
                 glm.Part(
                     function_response=glm.FunctionResponse(
                         name=message.name,
-                        response=message.content,
+                        response=(
+                            {"output": response}
+                            if not isinstance(response, dict)
+                            else response
+                        ),
                     )
                 )
             ]
@@ -364,12 +377,16 @@ llm = ChatGoogleGenerativeAI(model="gemini-pro", convert_system_message_to_human
 
     return messages
 
 
-def _parse_response_candidate(response_candidate: glm.Candidate) -> AIMessage:
+def _parse_response_candidate(
+    response_candidate: glm.Candidate, stream: bool
+) -> AIMessage:
     first_part = response_candidate.content.parts[0]
     if first_part.function_call:
         function_call = proto.Message.to_dict(first_part.function_call)
         function_call["arguments"] = json.dumps(function_call.pop("args", {}))
-        return AIMessage(content="", additional_kwargs={"function_call": function_call})
+        return (AIMessageChunk if stream else AIMessage)(
+            content="", additional_kwargs={"function_call": function_call}
+        )
     else:
         parts = response_candidate.content.parts
@@ -377,11 +394,14 @@
             content: Union[str, List[Union[str, Dict]]] = parts[0].text
         else:
             content = [proto.Message.to_dict(part) for part in parts]
-        return AIMessage(content=content, additional_kwargs={})
+        return (AIMessageChunk if stream else AIMessage)(
+            content=content, additional_kwargs={}
+        )
 
 
 def _response_to_result(
     response: glm.GenerateContentResponse,
+    stream: bool = False,
 ) -> ChatResult:
     """Converts a PaLM API response into a LangChain ChatResult."""
     llm_output = {"prompt_feedback": proto.Message.to_dict(response.prompt_feedback)}
@@ -397,8 +417,8 @@
             for safety_rating in candidate.safety_ratings
         ]
         generations.append(
-            ChatGeneration(
-                message=_parse_response_candidate(candidate),
+            (ChatGenerationChunk if stream else ChatGeneration)(
+                message=_parse_response_candidate(candidate, stream=stream),
                 generation_info=generation_info,
             )
         )
@@ -411,7 +431,10 @@
             f"Feedback: {response.prompt_feedback}"
         )
         generations = [
-            ChatGeneration(message=AIMessage(content=""), generation_info={})
+            (ChatGenerationChunk if stream else ChatGeneration)(
+                message=(AIMessageChunk if stream else AIMessage)(content=""),
+                generation_info={},
+            )
         ]
     return ChatResult(generations=generations, llm_output=llm_output)
 
@@ -573,7 +596,7 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
             stream=True,
         )
         for chunk in response:
-            _chat_result = _response_to_result(chunk)
+            _chat_result = _response_to_result(chunk, stream=True)
             gen = cast(ChatGenerationChunk, _chat_result.generations[0])
             if run_manager:
                 run_manager.on_llm_new_token(gen.text)
@@ -597,7 +620,7 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
             generation_method=chat.send_message_async,
             stream=True,
         ):
-            _chat_result = _response_to_result(chunk)
+            _chat_result = _response_to_result(chunk, stream=True)
             gen = cast(ChatGenerationChunk, _chat_result.generations[0])
             if run_manager:
                 await run_manager.on_llm_new_token(gen.text)
diff --git a/libs/partners/google-genai/poetry.lock b/libs/partners/google-genai/poetry.lock
index dc37a3fc06..ad06fd99ce 100644
--- a/libs/partners/google-genai/poetry.lock
+++ b/libs/partners/google-genai/poetry.lock
@@ -228,13 +228,13 @@ protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4
 
 [[package]]
 name = "google-api-core"
-version = "2.16.2"
+version = "2.17.0"
 description = "Google API client core library"
 optional = false
 python-versions = ">=3.7"
 files = [
-    {file = "google-api-core-2.16.2.tar.gz", hash = "sha256:032d37b45d1d6bdaf68fb11ff621e2593263a239fa9246e2e94325f9c47876d2"},
-    {file = "google_api_core-2.16.2-py3-none-any.whl", hash = "sha256:449ca0e3f14c179b4165b664256066c7861610f70b6ffe54bb01a04e9b466929"},
+    {file = "google-api-core-2.17.0.tar.gz", hash = "sha256:de7ef0450faec7c75e0aea313f29ac870fdc44cfaec9d6499a9a17305980ef66"},
+    {file = "google_api_core-2.17.0-py3-none-any.whl", hash = "sha256:08ed79ed8e93e329de5e3e7452746b734e6bf8438d8d64dd3319d21d3164890c"},
 ]
 
 [package.dependencies]
@@ -448,17 +448,17 @@ files = [
 
 [[package]]
 name = "langchain-core"
-version = "0.1.19"
+version = "0.1.21"
 description = "Building applications with LLMs through composability"
 optional = false
 python-versions = ">=3.8.1,<4.0"
 files = []
 develop = true
 
 [package.dependencies]
 anyio = ">=3,<5"
 jsonpatch = "^1.33"
-langsmith = ">=0.0.83,<0.1"
+langsmith = "^0.0.87"
 packaging = "^23.2"
 pydantic = ">=1,<3"
 PyYAML = ">=5.3"
diff --git a/libs/partners/google-genai/tests/unit_tests/test_chat_models.py b/libs/partners/google-genai/tests/unit_tests/test_chat_models.py
index 93d5bb3231..e13dcccafc 100644
--- a/libs/partners/google-genai/tests/unit_tests/test_chat_models.py
+++ b/libs/partners/google-genai/tests/unit_tests/test_chat_models.py
@@ -1,5 +1,14 @@
 """Test chat model integration."""
-from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
+
+from typing import Dict, List, Union
+
+import pytest
+from langchain_core.messages import (
+    AIMessage,
+    FunctionMessage,
+    HumanMessage,
+    SystemMessage,
+)
 from langchain_core.pydantic_v1 import SecretStr
 from pytest import CaptureFixture
 
@@ -58,3 +67,9 @@ def test_parse_history() -> None:
         "parts": [{"text": system_input}, {"text": text_question1}],
     }
     assert history[1] == {"role": "model", "parts": [{"text": text_answer1}]}
+
+
+@pytest.mark.parametrize("content", ['["a"]', '{"a":"b"}', "function output"])
+def test_parse_function_history(content: Union[str, List[Union[str, Dict]]]) -> None:
+    function_message = FunctionMessage(name="search_tool", content=content)
+    _parse_chat_history([function_message], convert_system_message_to_human=True)
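
Usage sketch (illustrative, not part of the patch): with stream=True threaded
through _response_to_result, the generations built during streaming are
ChatGenerationChunk/AIMessageChunk, so the pieces yielded by the public
stream() API can be merged incrementally. Assumes a valid GOOGLE_API_KEY in
the environment; the model name and prompt are placeholders.

    from langchain_core.messages import AIMessageChunk

    from langchain_google_genai import ChatGoogleGenerativeAI

    llm = ChatGoogleGenerativeAI(model="gemini-pro")

    # Each piece is an AIMessageChunk, so "+" accumulates a running message;
    # plain AIMessage (the pre-patch type) does not support concatenation.
    full = None
    for chunk in llm.stream("Tell me a joke"):
        assert isinstance(chunk, AIMessageChunk)
        full = chunk if full is None else full + chunk
    print(full.content if full is not None else "")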
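
A second sketch, mirroring the new unit test: FunctionMessage content may now
be a dict, a JSON string, or plain text. _parse_chat_history is a private
helper of langchain_google_genai.chat_models, imported here only to show the
conversion; no API key is needed.

    from langchain_core.messages import FunctionMessage

    from langchain_google_genai.chat_models import _parse_chat_history

    # A JSON object decodes to a dict and is passed through as the
    # FunctionResponse payload; JSON arrays and non-JSON strings are
    # wrapped as {"output": <value>}.
    for content in ['{"a": "b"}', '["a"]', "function output"]:
        message = FunctionMessage(name="search_tool", content=content)
        _parse_chat_history([message], convert_system_message_to_human=True)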