google-genai[patch]: fix streaming, function calling (#17268)

Erick Friis authored 8 months ago, committed by GitHub
parent 96b5711a0c
commit e4da7918f3

libs/partners/google-genai/langchain_google_genai/chat_models.py
@@ -35,6 +35,7 @@ from langchain_core.callbacks.manager import (
 from langchain_core.language_models.chat_models import BaseChatModel
 from langchain_core.messages import (
     AIMessage,
+    AIMessageChunk,
     BaseMessage,
     FunctionMessage,
     HumanMessage,
@@ -339,11 +340,23 @@ llm = ChatGoogleGenerativeAI(model="gemini-pro", convert_system_message_to_human
             parts = _convert_to_parts(message.content)
         elif isinstance(message, FunctionMessage):
             role = "user"
+            response: Any
+            if not isinstance(message.content, str):
+                response = message.content
+            else:
+                try:
+                    response = json.loads(message.content)
+                except json.JSONDecodeError:
+                    response = message.content  # leave as str representation
             parts = [
                 glm.Part(
                     function_response=glm.FunctionResponse(
                         name=message.name,
-                        response=message.content,
+                        response=(
+                            {"output": response}
+                            if not isinstance(response, dict)
+                            else response
+                        ),
                     )
                 )
             ]
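
The heart of the function-calling fix is this coercion: glm.FunctionResponse expects a dict-like value for its response field, while FunctionMessage.content is usually a plain string. A minimal sketch of the same logic in isolation; the helper name is hypothetical, not part of the library:

import json
from typing import Any, Dict


def coerce_function_response(content: Any) -> Dict[str, Any]:
    # Mirror the diff: JSON-decode string content when possible, then
    # wrap anything that is not already a dict under an "output" key.
    if not isinstance(content, str):
        response = content
    else:
        try:
            response = json.loads(content)
        except json.JSONDecodeError:
            response = content  # leave as str representation
    return response if isinstance(response, dict) else {"output": response}


assert coerce_function_response('{"a": "b"}') == {"a": "b"}
assert coerce_function_response('["a"]') == {"output": ["a"]}
assert coerce_function_response("function output") == {"output": "function output"}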
@@ -364,12 +377,16 @@ llm = ChatGoogleGenerativeAI(model="gemini-pro", convert_system_message_to_human
     return messages


-def _parse_response_candidate(response_candidate: glm.Candidate) -> AIMessage:
+def _parse_response_candidate(
+    response_candidate: glm.Candidate, stream: bool
+) -> AIMessage:
     first_part = response_candidate.content.parts[0]
     if first_part.function_call:
         function_call = proto.Message.to_dict(first_part.function_call)
         function_call["arguments"] = json.dumps(function_call.pop("args", {}))
-        return AIMessage(content="", additional_kwargs={"function_call": function_call})
+        return (AIMessageChunk if stream else AIMessage)(
+            content="", additional_kwargs={"function_call": function_call}
+        )
     else:
         parts = response_candidate.content.parts
@@ -377,11 +394,14 @@ def _parse_response_candidate(response_candidate: glm.Candidate) -> AIMessage:
             content: Union[str, List[Union[str, Dict]]] = parts[0].text
         else:
             content = [proto.Message.to_dict(part) for part in parts]
-        return AIMessage(content=content, additional_kwargs={})
+        return (AIMessageChunk if stream else AIMessage)(
+            content=content, additional_kwargs={}
+        )


 def _response_to_result(
     response: glm.GenerateContentResponse,
+    stream: bool = False,
 ) -> ChatResult:
     """Converts a PaLM API response into a LangChain ChatResult."""
     llm_output = {"prompt_feedback": proto.Message.to_dict(response.prompt_feedback)}
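
The new stream parameter selects between AIMessage and AIMessageChunk. The distinction matters because the chunk classes define +, so partial messages can be merged as they stream in; plain messages cannot. A quick illustration against langchain_core:

from langchain_core.messages import AIMessageChunk

# Chunks concatenate; downstream streaming code relies on this.
merged = AIMessageChunk(content="Hello, ") + AIMessageChunk(content="world!")
assert merged.content == "Hello, world!"
# The same expression with two plain AIMessage objects raises TypeError,
# which is why returning AIMessage from a streaming path was a bug.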
@@ -397,8 +417,8 @@ def _response_to_result(
                 for safety_rating in candidate.safety_ratings
             ]
         generations.append(
-            ChatGeneration(
-                message=_parse_response_candidate(candidate),
+            (ChatGenerationChunk if stream else ChatGeneration)(
+                message=_parse_response_candidate(candidate, stream=stream),
                 generation_info=generation_info,
             )
         )
@@ -411,7 +431,10 @@ def _response_to_result(
             f"Feedback: {response.prompt_feedback}"
         )
         generations = [
-            ChatGeneration(message=AIMessage(content=""), generation_info={})
+            (ChatGenerationChunk if stream else ChatGeneration)(
+                message=(AIMessageChunk if stream else AIMessage)(content=""),
+                generation_info={},
+            )
         ]
     return ChatResult(generations=generations, llm_output=llm_output)
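
ChatGenerationChunk concatenates the same way, which is what lets the stream loops below hand incremental results to callbacks and to stream() callers. A small sanity check, assuming current langchain_core semantics:

from langchain_core.messages import AIMessageChunk
from langchain_core.outputs import ChatGenerationChunk

g1 = ChatGenerationChunk(message=AIMessageChunk(content="Hi "))
g2 = ChatGenerationChunk(message=AIMessageChunk(content="there"))
assert (g1 + g2).text == "Hi there"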
@@ -573,7 +596,7 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
             stream=True,
         )
         for chunk in response:
-            _chat_result = _response_to_result(chunk)
+            _chat_result = _response_to_result(chunk, stream=True)
             gen = cast(ChatGenerationChunk, _chat_result.generations[0])
             if run_manager:
                 run_manager.on_llm_new_token(gen.text)
@@ -597,7 +620,7 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
             generation_method=chat.send_message_async,
             stream=True,
         ):
-            _chat_result = _response_to_result(chunk)
+            _chat_result = _response_to_result(chunk, stream=True)
             gen = cast(ChatGenerationChunk, _chat_result.generations[0])
             if run_manager:
                 await run_manager.on_llm_new_token(gen.text)
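
End to end, both loops now produce real chunk objects. A hedged usage sketch (the model name follows the docstring above; a valid GOOGLE_API_KEY is assumed):

from langchain_google_genai import ChatGoogleGenerativeAI

llm = ChatGoogleGenerativeAI(model="gemini-pro")
for chunk in llm.stream("Tell me a short joke"):
    # Each chunk is an AIMessageChunk after this fix, so token-by-token
    # callbacks and `+` accumulation behave as LangChain expects.
    print(chunk.content, end="", flush=True)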

libs/partners/google-genai/poetry.lock
@@ -228,13 +228,13 @@ protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4

 [[package]]
 name = "google-api-core"
-version = "2.16.2"
+version = "2.17.0"
 description = "Google API client core library"
 optional = false
 python-versions = ">=3.7"
 files = [
-    {file = "google-api-core-2.16.2.tar.gz", hash = "sha256:032d37b45d1d6bdaf68fb11ff621e2593263a239fa9246e2e94325f9c47876d2"},
-    {file = "google_api_core-2.16.2-py3-none-any.whl", hash = "sha256:449ca0e3f14c179b4165b664256066c7861610f70b6ffe54bb01a04e9b466929"},
+    {file = "google-api-core-2.17.0.tar.gz", hash = "sha256:de7ef0450faec7c75e0aea313f29ac870fdc44cfaec9d6499a9a17305980ef66"},
+    {file = "google_api_core-2.17.0-py3-none-any.whl", hash = "sha256:08ed79ed8e93e329de5e3e7452746b734e6bf8438d8d64dd3319d21d3164890c"},
 ]

 [package.dependencies]
@@ -448,7 +448,7 @@ files = [
 [[package]]
 name = "langchain-core"
-version = "0.1.19"
+version = "0.1.21"
 description = "Building applications with LLMs through composability"
 optional = false
 python-versions = ">=3.8.1,<4.0"
@@ -458,7 +458,7 @@ develop = true
 [package.dependencies]
 anyio = ">=3,<5"
 jsonpatch = "^1.33"
-langsmith = ">=0.0.83,<0.1"
+langsmith = "^0.0.87"
 packaging = "^23.2"
 pydantic = ">=1,<3"
 PyYAML = ">=5.3"

libs/partners/google-genai/tests/unit_tests/test_chat_models.py
@@ -1,5 +1,14 @@
 """Test chat model integration."""
-from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
+from typing import Dict, List, Union
+
+import pytest
+from langchain_core.messages import (
+    AIMessage,
+    FunctionMessage,
+    HumanMessage,
+    SystemMessage,
+)
 from langchain_core.pydantic_v1 import SecretStr
 from pytest import CaptureFixture
@@ -58,3 +67,9 @@ def test_parse_history() -> None:
         "parts": [{"text": system_input}, {"text": text_question1}],
     }
     assert history[1] == {"role": "model", "parts": [{"text": text_answer1}]}
+
+
+@pytest.mark.parametrize("content", ['["a"]', '{"a":"b"}', "function output"])
+def test_parse_function_history(content: Union[str, List[Union[str, Dict]]]) -> None:
+    function_message = FunctionMessage(name="search_tool", content=content)
+    _parse_chat_history([function_message], convert_system_message_to_human=True)
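
The three parametrized payloads exercise each coercion branch added above: '["a"]' decodes to a list and gets wrapped as {"output": [...]}, '{"a":"b"}' decodes to a dict and passes through unchanged, and "function output" fails json.loads and is kept as a string under {"output": ...}.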
