Added chat history to codey models (#8831)

#7469

Since 1.29.0, the Vertex AI SDK supports passing chat history to a Codey chat model.
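
With this change, `ChatVertexAI` forwards prior messages for Codey chat models as well, so a call like the following should work. This is a minimal sketch mirroring the new test; it assumes Google Cloud credentials and a Vertex AI project are already configured:

```python
from langchain.chat_models import ChatVertexAI
from langchain.schema.messages import AIMessage, HumanMessage

# Codey chat model; history is now forwarded via start_chat(message_history=...).
chat = ChatVertexAI(model_name="codechat-bison")
history = [
    HumanMessage(content="How much is 2+2?"),
    AIMessage(content="4"),
    HumanMessage(content="How much is 3+3?"),
]
response = chat(history)
print(response.content)
```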

Co-authored-by: Leonid Kuligin <kuligin@google.com>
Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>

@@ -111,7 +111,7 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
             values["client"] = ChatModel.from_pretrained(values["model_name"])
         except ImportError:
-            raise_vertex_import_error(minimum_expected_version="1.28.0")
+            raise_vertex_import_error(minimum_expected_version="1.29.0")
         return values

     def _generate(
@@ -155,7 +155,7 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
                 context=context, message_history=history.history, **params
            )
        else:
-            chat = self.client.start_chat(**params)
+            chat = self.client.start_chat(message_history=history.history, **params)
        response = chat.send_message(question.content)
        text = self._enforce_stop_words(response.text, stop)
        return ChatResult(generations=[ChatGeneration(message=AIMessage(content=text))])
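
For reference, the SDK-level call that the `else` branch now relies on looks roughly like this. It is a minimal sketch assuming `google-cloud-aiplatform >= 1.29.0`; the `message_history` parameter of `start_chat()` and the `ChatMessage(content=..., author=...)` fields are taken from that release, and the `"user"`/`"bot"` author values follow `_parse_chat_history`:

```python
from vertexai.language_models import ChatMessage, CodeChatModel

# Prior turns, in the shape _parse_chat_history produces ("user"/"bot" authors).
history = [
    ChatMessage(content="How much is 2+2?", author="user"),
    ChatMessage(content="4", author="bot"),
]

model = CodeChatModel.from_pretrained("codechat-bison")
# Since 1.29.0, start_chat() also accepts message_history for Codey chat models.
chat = model.start_chat(message_history=history)
response = chat.send_message("How much is 3+3?")
print(response.text)
```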

@@ -16,8 +16,12 @@ from langchain.chat_models.vertexai import _parse_chat_history, _parse_examples
 from langchain.schema.messages import AIMessage, HumanMessage, SystemMessage


-def test_vertexai_single_call() -> None:
-    model = ChatVertexAI()
+@pytest.mark.parametrize("model_name", [None, "codechat-bison", "chat-bison"])
+def test_vertexai_single_call(model_name: str) -> None:
+    if model_name:
+        model = ChatVertexAI(model_name=model_name)
+    else:
+        model = ChatVertexAI()
     message = HumanMessage(content="Hello")
     response = model([message])
     assert isinstance(response, AIMessage)
@@ -56,6 +60,22 @@ def test_vertexai_single_call_with_examples() -> None:
     assert isinstance(response.content, str)


+@pytest.mark.parametrize("model_name", [None, "codechat-bison", "chat-bison"])
+def test_vertexai_single_call_with_history(model_name: str) -> None:
+    if model_name:
+        model = ChatVertexAI(model_name=model_name)
+    else:
+        model = ChatVertexAI()
+    text_question1, text_answer1 = "How much is 2+2?", "4"
+    text_question2 = "How much is 3+3?"
+    message1 = HumanMessage(content=text_question1)
+    message2 = AIMessage(content=text_answer1)
+    message3 = HumanMessage(content=text_question2)
+    response = model([message1, message2, message3])
+    assert isinstance(response, AIMessage)
+    assert isinstance(response.content, str)
+
+
 def test_parse_chat_history_correct() -> None:
     from vertexai.language_models import ChatMessage
@@ -145,7 +165,7 @@ def test_parse_examples_correct() -> None:
     ]


-def test_parse_exmaples_failes_wrong_sequence() -> None:
+def test_parse_examples_failes_wrong_sequence() -> None:
     with pytest.raises(ValueError) as exc_info:
         _ = _parse_examples([AIMessage(content="a")])
     print(str(exc_info.value))
