added history and support for system_message as param (#14824)

- **Description:** added support for chat_history for Google
GenerativeAI (to actually use the `chat` API) plus since Gemini
currently doesn't have a support for SystemMessage, added support for it
only if a user provides additional `convert_system_message_to_human`
flag during model initialization (in this case, SystemMessage would be
prepanded to the first HumanMessage)
  - **Issue:** #14710 
  - **Dependencies:** any dependencies required for this change,
- **Tag maintainer:** for a quicker response, tag the relevant
maintainer (see below),
  - **Twitter handle:** lkuligin

---------

Co-authored-by: William FH <13333726+hinthornw@users.noreply.github.com>
pull/14882/head
Leonid Kuligin 10 months ago committed by GitHub
parent 2861766d0d
commit 2d0f1cae8c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -136,6 +136,32 @@
"print(result.content)"
]
},
{
"cell_type": "markdown",
"id": "9e55d043-bb2f-44e3-9134-c39a1abe3a9e",
"metadata": {},
"source": [
"Gemini doesn't support `SystemMessage` at the moment, but it can be added to the first human message in the row. If you want such behavior, just set the `convert_system_message_to_human` to True:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7a64b523-9710-4d15-9944-1e3cc567a52b",
"metadata": {},
"outputs": [],
"source": [
"from langchain.schema.messages import HumanMessage, SystemMessage\n",
"\n",
"model = ChatGoogleGenerativeAI(model=\"gemini-pro\", convert_system_message_to_human=True)\n",
"model(\n",
" [\n",
" SystemMessage(content=\"Answer only yes or no.\"),\n",
" HumanMessage(content=\"Is apple a fruit?\"),\n",
" ]\n",
")"
]
},
{
"cell_type": "markdown",
"id": "40773fac-b24d-476d-91c8-2da8fed99b53",

@ -37,6 +37,7 @@ from langchain_core.messages import (
ChatMessageChunk,
HumanMessage,
HumanMessageChunk,
SystemMessage,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
@ -106,7 +107,7 @@ def _create_retry_decorator() -> Callable[[Any], Any]:
)
def _chat_with_retry(*, generation_method: Callable, **kwargs: Any) -> Any:
def _chat_with_retry(generation_method: Callable, **kwargs: Any) -> Any:
"""
Executes a chat generation method with retry logic using tenacity.
@ -139,7 +140,7 @@ def _chat_with_retry(*, generation_method: Callable, **kwargs: Any) -> Any:
return _chat_with_retry(**kwargs)
async def _achat_with_retry(*, generation_method: Callable, **kwargs: Any) -> Any:
async def _achat_with_retry(generation_method: Callable, **kwargs: Any) -> Any:
"""
Executes a chat generation method with retry logic using tenacity.
@ -172,26 +173,6 @@ async def _achat_with_retry(*, generation_method: Callable, **kwargs: Any) -> An
return await _achat_with_retry(**kwargs)
def _get_role(message: BaseMessage) -> str:
if isinstance(message, ChatMessage):
if message.role not in ("user", "model"):
raise ChatGoogleGenerativeAIError(
"Gemini only supports user and model roles when"
" providing it with Chat messages."
)
return message.role
elif isinstance(message, HumanMessage):
return "user"
elif isinstance(message, AIMessage):
return "model"
else:
# TODO: Gemini doesn't seem to have a concept of system messages yet.
raise ChatGoogleGenerativeAIError(
f"Message of '{message.type}' type not supported by Gemini."
" Please only provide it with Human or AI (user/assistant) messages."
)
def _is_openai_parts_format(part: dict) -> bool:
return "type" in part
@ -266,13 +247,14 @@ def _url_to_pil(image_source: str) -> Image:
def _convert_to_parts(
content: Sequence[Union[str, dict]],
raw_content: Union[str, Sequence[Union[str, dict]]],
) -> List[genai.types.PartType]:
"""Converts a list of LangChain messages into a google parts."""
parts = []
content = [raw_content] if isinstance(raw_content, str) else raw_content
for part in content:
if isinstance(part, str):
parts.append(genai.types.PartDict(text=part, inline_data=None))
parts.append(genai.types.PartDict(text=part))
elif isinstance(part, Mapping):
# OpenAI Format
if _is_openai_parts_format(part):
@ -304,27 +286,49 @@ def _convert_to_parts(
return parts
def _messages_to_genai_contents(
input_messages: Sequence[BaseMessage],
def _parse_chat_history(
input_messages: Sequence[BaseMessage], convert_system_message_to_human: bool = False
) -> List[genai.types.ContentDict]:
"""Converts a list of messages into a Gemini API google content dicts."""
messages: List[genai.types.MessageDict] = []
raw_system_message: Optional[SystemMessage] = None
for i, message in enumerate(input_messages):
role = _get_role(message)
if isinstance(message.content, str):
parts = [message.content]
if (
i == 0
and isinstance(message, SystemMessage)
and not convert_system_message_to_human
):
raise ValueError(
"""SystemMessages are not yet supported!
To automatically convert the leading SystemMessage to a HumanMessage,
set `convert_system_message_to_human` to True. Example:
llm = ChatGoogleGenerativeAI(model="gemini-pro", convert_system_message_to_human=True)
"""
)
elif i == 0 and isinstance(message, SystemMessage):
raw_system_message = message
continue
elif isinstance(message, AIMessage):
role = "model"
elif isinstance(message, HumanMessage):
role = "user"
else:
parts = _convert_to_parts(message.content)
messages.append({"role": role, "parts": parts})
if i > 0:
# Cannot have multiple messages from the same role in a row.
if role == messages[-2]["role"]:
raise ChatGoogleGenerativeAIError(
"Cannot have multiple messages from the same role in a row."
" Consider merging them into a single message with multiple"
f" parts.\nReceived: {messages}"
raise ValueError(
f"Unexpected message with type {type(message)} at the position {i}."
)
parts = _convert_to_parts(message.content)
if raw_system_message:
if role == "model":
raise ValueError(
"SystemMessage should be followed by a HumanMessage and "
"not by AIMessage."
)
parts = _convert_to_parts(raw_system_message.content) + parts
raw_system_message = None
messages.append({"role": role, "parts": parts})
return messages
@ -457,8 +461,11 @@ Supported examples:
n: int = Field(default=1, alias="candidate_count")
"""Number of chat completions to generate for each prompt. Note that the API may
not return the full n completions if duplicates are generated."""
_generative_model: Any #: :meta private:
convert_system_message_to_human: bool = False
"""Whether to merge any leading SystemMessage into the following HumanMessage.
Gemini does not support system messages; any unsupported messages will
raise an error."""
class Config:
allow_population_by_field_name = True
@ -499,7 +506,7 @@ Supported examples:
if values.get("top_k") is not None and values["top_k"] <= 0:
raise ValueError("top_k must be positive")
model = values["model"]
values["_generative_model"] = genai.GenerativeModel(model_name=model)
values["client"] = genai.GenerativeModel(model_name=model)
return values
@property
@ -512,18 +519,9 @@ Supported examples:
"n": self.n,
}
@property
def _generation_method(self) -> Callable:
return self._generative_model.generate_content
@property
def _async_generation_method(self) -> Callable:
return self._generative_model.generate_content_async
def _prepare_params(
self, messages: Sequence[BaseMessage], stop: Optional[List[str]], **kwargs: Any
self, stop: Optional[List[str]], **kwargs: Any
) -> Dict[str, Any]:
contents = _messages_to_genai_contents(messages)
gen_config = {
k: v
for k, v in {
@ -538,7 +536,7 @@ Supported examples:
}
if "generation_config" in kwargs:
gen_config = {**gen_config, **kwargs.pop("generation_config")}
params = {"generation_config": gen_config, "contents": contents, **kwargs}
params = {"generation_config": gen_config, **kwargs}
return params
def _generate(
@ -548,10 +546,11 @@ Supported examples:
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
params = self._prepare_params(messages, stop, **kwargs)
params, chat, message = self._prepare_chat(messages, stop=stop)
response: genai.types.GenerateContentResponse = _chat_with_retry(
content=message,
**params,
generation_method=self._generation_method,
generation_method=chat.send_message,
)
return _response_to_result(response)
@ -562,10 +561,11 @@ Supported examples:
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
params = self._prepare_params(messages, stop, **kwargs)
params, chat, message = self._prepare_chat(messages, stop=stop)
response: genai.types.GenerateContentResponse = await _achat_with_retry(
content=message,
**params,
generation_method=self._async_generation_method,
generation_method=chat.send_message_async,
)
return _response_to_result(response)
@ -576,10 +576,11 @@ Supported examples:
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
params = self._prepare_params(messages, stop, **kwargs)
params, chat, message = self._prepare_chat(messages, stop=stop)
response: genai.types.GenerateContentResponse = _chat_with_retry(
content=message,
**params,
generation_method=self._generation_method,
generation_method=chat.send_message,
stream=True,
)
for chunk in response:
@ -602,10 +603,11 @@ Supported examples:
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
params = self._prepare_params(messages, stop, **kwargs)
params, chat, message = self._prepare_chat(messages, stop=stop)
async for chunk in await _achat_with_retry(
content=message,
**params,
generation_method=self._async_generation_method,
generation_method=chat.send_message_async,
stream=True,
):
_chat_result = _response_to_result(
@ -619,3 +621,18 @@ Supported examples:
yield gen
if run_manager:
await run_manager.on_llm_new_token(gen.text)
def _prepare_chat(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> Tuple[Dict[str, Any], genai.ChatSession, genai.types.ContentDict]:
params = self._prepare_params(stop, **kwargs)
history = _parse_chat_history(
messages,
convert_system_message_to_human=self.convert_system_message_to_human,
)
message = history.pop()
chat = self.client.start_chat(history=history)
return params, chat, message

@ -1,6 +1,6 @@
"""Test ChatGoogleGenerativeAI chat model."""
import pytest
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from langchain_google_genai.chat_models import (
ChatGoogleGenerativeAI,
@ -147,3 +147,40 @@ def test_chat_google_genai_invoke_multimodal_invalid_model() -> None:
llm = ChatGoogleGenerativeAI(model=_MODEL)
with pytest.raises(ChatGoogleGenerativeAIError):
llm.invoke(messages)
def test_chat_google_genai_single_call_with_history() -> None:
model = ChatGoogleGenerativeAI(model=_MODEL)
text_question1, text_answer1 = "How much is 2+2?", "4"
text_question2 = "How much is 3+3?"
message1 = HumanMessage(content=text_question1)
message2 = AIMessage(content=text_answer1)
message3 = HumanMessage(content=text_question2)
response = model([message1, message2, message3])
assert isinstance(response, AIMessage)
assert isinstance(response.content, str)
def test_chat_google_genai_system_message_error() -> None:
model = ChatGoogleGenerativeAI(model=_MODEL)
text_question1, text_answer1 = "How much is 2+2?", "4"
text_question2 = "How much is 3+3?"
system_message = SystemMessage(content="You're supposed to answer math questions.")
message1 = HumanMessage(content=text_question1)
message2 = AIMessage(content=text_answer1)
message3 = HumanMessage(content=text_question2)
with pytest.raises(ValueError):
model([system_message, message1, message2, message3])
def test_chat_google_genai_system_message() -> None:
model = ChatGoogleGenerativeAI(model=_MODEL, convert_system_message_to_human=True)
text_question1, text_answer1 = "How much is 2+2?", "4"
text_question2 = "How much is 3+3?"
system_message = SystemMessage(content="You're supposed to answer math questions.")
message1 = HumanMessage(content=text_question1)
message2 = AIMessage(content=text_answer1)
message3 = HumanMessage(content=text_question2)
response = model([system_message, message1, message2, message3])
assert isinstance(response, AIMessage)
assert isinstance(response.content, str)

@ -1,8 +1,12 @@
"""Test chat model integration."""
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from langchain_core.pydantic_v1 import SecretStr
from pytest import CaptureFixture
from langchain_google_genai.chat_models import ChatGoogleGenerativeAI
from langchain_google_genai.chat_models import (
ChatGoogleGenerativeAI,
_parse_chat_history,
)
def test_integration_initialization() -> None:
@ -36,3 +40,21 @@ def test_api_key_masked_when_passed_via_constructor(capsys: CaptureFixture) -> N
captured = capsys.readouterr()
assert captured.out == "**********"
def test_parse_history() -> None:
system_input = "You're supposed to answer math questions."
text_question1, text_answer1 = "How much is 2+2?", "4"
text_question2 = "How much is 3+3?"
system_message = SystemMessage(content=system_input)
message1 = HumanMessage(content=text_question1)
message2 = AIMessage(content=text_answer1)
message3 = HumanMessage(content=text_question2)
messages = [system_message, message1, message2, message3]
history = _parse_chat_history(messages, convert_system_message_to_human=True)
assert len(history) == 3
assert history[0] == {
"role": "user",
"parts": [{"text": system_input}, {"text": text_question1}],
}
assert history[1] == {"role": "model", "parts": [{"text": text_answer1}]}

Loading…
Cancel
Save