mirror of
https://github.com/hwchase17/langchain
synced 2024-11-08 07:10:35 +00:00
Added support for examples for VertexAI chat models. (#7636)
#5278 Co-authored-by: Leonid Kuligin <kuligin@google.com>
This commit is contained in:
parent
45bb414be2
commit
85e1c9b348
@ -23,7 +23,7 @@ from langchain.schema.messages import (
|
|||||||
from langchain.utilities.vertexai import raise_vertex_import_error
|
from langchain.utilities.vertexai import raise_vertex_import_error
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from vertexai.language_models import ChatMessage
|
from vertexai.language_models import ChatMessage, InputOutputTextPair
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -65,6 +65,36 @@ def _parse_chat_history(history: List[BaseMessage]) -> _ChatHistory:
|
|||||||
return chat_history
|
return chat_history
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_examples(examples: List[BaseMessage]) -> List["InputOutputTextPair"]:
|
||||||
|
from vertexai.language_models import InputOutputTextPair
|
||||||
|
|
||||||
|
if len(examples) % 2 != 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"Expect examples to have an even amount of messages, got {len(examples)}."
|
||||||
|
)
|
||||||
|
example_pairs = []
|
||||||
|
input_text = None
|
||||||
|
for i, example in enumerate(examples):
|
||||||
|
if i % 2 == 0:
|
||||||
|
if not isinstance(example, HumanMessage):
|
||||||
|
raise ValueError(
|
||||||
|
f"Expected the first message in a part to be from human, got "
|
||||||
|
f"{type(example)} for the {i}th message."
|
||||||
|
)
|
||||||
|
input_text = example.content
|
||||||
|
if i % 2 == 1:
|
||||||
|
if not isinstance(example, AIMessage):
|
||||||
|
raise ValueError(
|
||||||
|
f"Expected the second message in a part to be from AI, got "
|
||||||
|
f"{type(example)} for the {i}th message."
|
||||||
|
)
|
||||||
|
pair = InputOutputTextPair(
|
||||||
|
input_text=input_text, output_text=example.content
|
||||||
|
)
|
||||||
|
example_pairs.append(pair)
|
||||||
|
return example_pairs
|
||||||
|
|
||||||
|
|
||||||
class ChatVertexAI(_VertexAICommon, BaseChatModel):
|
class ChatVertexAI(_VertexAICommon, BaseChatModel):
|
||||||
"""Wrapper around Vertex AI large language models."""
|
"""Wrapper around Vertex AI large language models."""
|
||||||
|
|
||||||
@ -120,13 +150,16 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
|
|||||||
history = _parse_chat_history(messages[:-1])
|
history = _parse_chat_history(messages[:-1])
|
||||||
context = history.context if history.context else None
|
context = history.context if history.context else None
|
||||||
params = {**self._default_params, **kwargs}
|
params = {**self._default_params, **kwargs}
|
||||||
|
examples = kwargs.get("examples", None)
|
||||||
|
if examples:
|
||||||
|
params["examples"] = _parse_examples(examples)
|
||||||
if not self.is_codey_model:
|
if not self.is_codey_model:
|
||||||
chat = self.client.start_chat(
|
chat = self.client.start_chat(
|
||||||
context=context, message_history=history.history, **params
|
context=context, message_history=history.history, **params
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
chat = self.client.start_chat(**params)
|
chat = self.client.start_chat(**params)
|
||||||
response = chat.send_message(question.content, **params)
|
response = chat.send_message(question.content)
|
||||||
text = self._enforce_stop_words(response.text, stop)
|
text = self._enforce_stop_words(response.text, stop)
|
||||||
return ChatResult(generations=[ChatGeneration(message=AIMessage(content=text))])
|
return ChatResult(generations=[ChatGeneration(message=AIMessage(content=text))])
|
||||||
|
|
||||||
|
@ -7,12 +7,12 @@ pip install google-cloud-aiplatform>=1.25.0
|
|||||||
Your end-user credentials would be used to make the calls (make sure you've run
|
Your end-user credentials would be used to make the calls (make sure you've run
|
||||||
`gcloud auth login` first).
|
`gcloud auth login` first).
|
||||||
"""
|
"""
|
||||||
from unittest.mock import Mock, patch
|
from unittest.mock import MagicMock, Mock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from langchain.chat_models import ChatVertexAI
|
from langchain.chat_models import ChatVertexAI
|
||||||
from langchain.chat_models.vertexai import _parse_chat_history
|
from langchain.chat_models.vertexai import _parse_chat_history, _parse_examples
|
||||||
from langchain.schema.messages import AIMessage, HumanMessage, SystemMessage
|
from langchain.schema.messages import AIMessage, HumanMessage, SystemMessage
|
||||||
|
|
||||||
|
|
||||||
@ -42,6 +42,20 @@ def test_vertexai_single_call_with_context() -> None:
|
|||||||
assert isinstance(response.content, str)
|
assert isinstance(response.content, str)
|
||||||
|
|
||||||
|
|
||||||
|
def test_vertexai_single_call_with_examples() -> None:
|
||||||
|
model = ChatVertexAI()
|
||||||
|
raw_context = "My name is Ned. You are my personal assistant."
|
||||||
|
question = "2+2"
|
||||||
|
text_question, text_answer = "4+4", "8"
|
||||||
|
inp = HumanMessage(content=text_question)
|
||||||
|
output = AIMessage(content=text_answer)
|
||||||
|
context = SystemMessage(content=raw_context)
|
||||||
|
message = HumanMessage(content=question)
|
||||||
|
response = model([context, message], examples=[inp, output])
|
||||||
|
assert isinstance(response, AIMessage)
|
||||||
|
assert isinstance(response.content, str)
|
||||||
|
|
||||||
|
|
||||||
def test_parse_chat_history_correct() -> None:
|
def test_parse_chat_history_correct() -> None:
|
||||||
from vertexai.language_models import ChatMessage
|
from vertexai.language_models import ChatMessage
|
||||||
|
|
||||||
@ -92,17 +106,50 @@ def test_vertexai_args_passed() -> None:
|
|||||||
|
|
||||||
# Mock the library to ensure the args are passed correctly
|
# Mock the library to ensure the args are passed correctly
|
||||||
with patch(
|
with patch(
|
||||||
"vertexai.language_models._language_models.ChatSession.send_message"
|
"vertexai.language_models._language_models.ChatModel.start_chat"
|
||||||
) as send_message:
|
) as start_chat:
|
||||||
mock_response = Mock(text=response_text)
|
mock_response = Mock(text=response_text)
|
||||||
send_message.return_value = mock_response
|
mock_chat = MagicMock()
|
||||||
|
start_chat.return_value = mock_chat
|
||||||
|
mock_send_message = MagicMock(return_value=mock_response)
|
||||||
|
mock_chat.send_message = mock_send_message
|
||||||
|
|
||||||
model = ChatVertexAI(**prompt_params)
|
model = ChatVertexAI(**prompt_params)
|
||||||
message = HumanMessage(content=user_prompt)
|
message = HumanMessage(content=user_prompt)
|
||||||
response = model([message])
|
response = model([message])
|
||||||
|
|
||||||
assert response.content == response_text
|
assert response.content == response_text
|
||||||
send_message.assert_called_once_with(
|
mock_send_message.assert_called_once_with(user_prompt)
|
||||||
user_prompt,
|
start_chat.assert_called_once_with(
|
||||||
**prompt_params,
|
context=None, message_history=[], **prompt_params
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_examples_correct() -> None:
|
||||||
|
from vertexai.language_models import InputOutputTextPair
|
||||||
|
|
||||||
|
text_question = (
|
||||||
|
"Hello, could you recommend a good movie for me to watch this evening, please?"
|
||||||
|
)
|
||||||
|
question = HumanMessage(content=text_question)
|
||||||
|
text_answer = (
|
||||||
|
"Sure, You might enjoy The Lord of the Rings: The Fellowship of the Ring "
|
||||||
|
"(2001): This is the first movie in the Lord of the Rings trilogy."
|
||||||
|
)
|
||||||
|
answer = AIMessage(content=text_answer)
|
||||||
|
examples = _parse_examples([question, answer, question, answer])
|
||||||
|
assert len(examples) == 2
|
||||||
|
assert examples == [
|
||||||
|
InputOutputTextPair(input_text=text_question, output_text=text_answer),
|
||||||
|
InputOutputTextPair(input_text=text_question, output_text=text_answer),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_exmaples_failes_wrong_sequence() -> None:
|
||||||
|
with pytest.raises(ValueError) as exc_info:
|
||||||
|
_ = _parse_examples([AIMessage(content="a")])
|
||||||
|
print(str(exc_info.value))
|
||||||
|
assert (
|
||||||
|
str(exc_info.value)
|
||||||
|
== "Expect examples to have an even amount of messages, got 1."
|
||||||
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user