Added chat history to codey models (#8831)

#7469

Since version 1.29.0, the Vertex AI SDK supports providing chat history to a
Codey chat model.

Co-authored-by: Leonid Kuligin <kuligin@google.com>
Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
This commit is contained in:
Leonid Kuligin 2023-08-07 16:34:35 +02:00 committed by GitHub
parent a616e19975
commit 6e3fa59073
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 25 additions and 5 deletions

View File

@ -111,7 +111,7 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
values["client"] = ChatModel.from_pretrained(values["model_name"]) values["client"] = ChatModel.from_pretrained(values["model_name"])
except ImportError: except ImportError:
raise_vertex_import_error(minimum_expected_version="1.28.0") raise_vertex_import_error(minimum_expected_version="1.29.0")
return values return values
def _generate( def _generate(
@ -155,7 +155,7 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
context=context, message_history=history.history, **params context=context, message_history=history.history, **params
) )
else: else:
chat = self.client.start_chat(**params) chat = self.client.start_chat(message_history=history.history, **params)
response = chat.send_message(question.content) response = chat.send_message(question.content)
text = self._enforce_stop_words(response.text, stop) text = self._enforce_stop_words(response.text, stop)
return ChatResult(generations=[ChatGeneration(message=AIMessage(content=text))]) return ChatResult(generations=[ChatGeneration(message=AIMessage(content=text))])

View File

@ -16,8 +16,12 @@ from langchain.chat_models.vertexai import _parse_chat_history, _parse_examples
from langchain.schema.messages import AIMessage, HumanMessage, SystemMessage from langchain.schema.messages import AIMessage, HumanMessage, SystemMessage
def test_vertexai_single_call() -> None: @pytest.mark.parametrize("model_name", [None, "codechat-bison", "chat-bison"])
model = ChatVertexAI() def test_vertexai_single_call(model_name: str) -> None:
if model_name:
model = ChatVertexAI(model_name=model_name)
else:
model = ChatVertexAI()
message = HumanMessage(content="Hello") message = HumanMessage(content="Hello")
response = model([message]) response = model([message])
assert isinstance(response, AIMessage) assert isinstance(response, AIMessage)
@ -56,6 +60,22 @@ def test_vertexai_single_call_with_examples() -> None:
assert isinstance(response.content, str) assert isinstance(response.content, str)
@pytest.mark.parametrize("model_name", [None, "codechat-bison", "chat-bison"])
def test_vertexai_single_call_with_history(model_name: str) -> None:
if model_name:
model = ChatVertexAI(model_name=model_name)
else:
model = ChatVertexAI()
text_question1, text_answer1 = "How much is 2+2?", "4"
text_question2 = "How much is 3+3?"
message1 = HumanMessage(content=text_question1)
message2 = AIMessage(content=text_answer1)
message3 = HumanMessage(content=text_question2)
response = model([message1, message2, message3])
assert isinstance(response, AIMessage)
assert isinstance(response.content, str)
def test_parse_chat_history_correct() -> None: def test_parse_chat_history_correct() -> None:
from vertexai.language_models import ChatMessage from vertexai.language_models import ChatMessage
@ -145,7 +165,7 @@ def test_parse_examples_correct() -> None:
] ]
def test_parse_exmaples_failes_wrong_sequence() -> None: def test_parse_examples_failes_wrong_sequence() -> None:
with pytest.raises(ValueError) as exc_info: with pytest.raises(ValueError) as exc_info:
_ = _parse_examples([AIMessage(content="a")]) _ = _parse_examples([AIMessage(content="a")])
print(str(exc_info.value)) print(str(exc_info.value))