mirror of
https://github.com/hwchase17/langchain
synced 2024-11-08 07:10:35 +00:00
Added chat history to codey models (#8831)
#7469 since 1.29.0, Vertex SDK supports a chat history provided to a codey chat model. Co-authored-by: Leonid Kuligin <kuligin@google.com> Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
This commit is contained in:
parent
a616e19975
commit
6e3fa59073
@ -111,7 +111,7 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
|
|||||||
|
|
||||||
values["client"] = ChatModel.from_pretrained(values["model_name"])
|
values["client"] = ChatModel.from_pretrained(values["model_name"])
|
||||||
except ImportError:
|
except ImportError:
|
||||||
raise_vertex_import_error(minimum_expected_version="1.28.0")
|
raise_vertex_import_error(minimum_expected_version="1.29.0")
|
||||||
return values
|
return values
|
||||||
|
|
||||||
def _generate(
|
def _generate(
|
||||||
@ -155,7 +155,7 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
|
|||||||
context=context, message_history=history.history, **params
|
context=context, message_history=history.history, **params
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
chat = self.client.start_chat(**params)
|
chat = self.client.start_chat(message_history=history.history, **params)
|
||||||
response = chat.send_message(question.content)
|
response = chat.send_message(question.content)
|
||||||
text = self._enforce_stop_words(response.text, stop)
|
text = self._enforce_stop_words(response.text, stop)
|
||||||
return ChatResult(generations=[ChatGeneration(message=AIMessage(content=text))])
|
return ChatResult(generations=[ChatGeneration(message=AIMessage(content=text))])
|
||||||
|
@ -16,8 +16,12 @@ from langchain.chat_models.vertexai import _parse_chat_history, _parse_examples
|
|||||||
from langchain.schema.messages import AIMessage, HumanMessage, SystemMessage
|
from langchain.schema.messages import AIMessage, HumanMessage, SystemMessage
|
||||||
|
|
||||||
|
|
||||||
def test_vertexai_single_call() -> None:
|
@pytest.mark.parametrize("model_name", [None, "codechat-bison", "chat-bison"])
|
||||||
model = ChatVertexAI()
|
def test_vertexai_single_call(model_name: str) -> None:
|
||||||
|
if model_name:
|
||||||
|
model = ChatVertexAI(model_name=model_name)
|
||||||
|
else:
|
||||||
|
model = ChatVertexAI()
|
||||||
message = HumanMessage(content="Hello")
|
message = HumanMessage(content="Hello")
|
||||||
response = model([message])
|
response = model([message])
|
||||||
assert isinstance(response, AIMessage)
|
assert isinstance(response, AIMessage)
|
||||||
@ -56,6 +60,22 @@ def test_vertexai_single_call_with_examples() -> None:
|
|||||||
assert isinstance(response.content, str)
|
assert isinstance(response.content, str)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("model_name", [None, "codechat-bison", "chat-bison"])
|
||||||
|
def test_vertexai_single_call_with_history(model_name: str) -> None:
|
||||||
|
if model_name:
|
||||||
|
model = ChatVertexAI(model_name=model_name)
|
||||||
|
else:
|
||||||
|
model = ChatVertexAI()
|
||||||
|
text_question1, text_answer1 = "How much is 2+2?", "4"
|
||||||
|
text_question2 = "How much is 3+3?"
|
||||||
|
message1 = HumanMessage(content=text_question1)
|
||||||
|
message2 = AIMessage(content=text_answer1)
|
||||||
|
message3 = HumanMessage(content=text_question2)
|
||||||
|
response = model([message1, message2, message3])
|
||||||
|
assert isinstance(response, AIMessage)
|
||||||
|
assert isinstance(response.content, str)
|
||||||
|
|
||||||
|
|
||||||
def test_parse_chat_history_correct() -> None:
|
def test_parse_chat_history_correct() -> None:
|
||||||
from vertexai.language_models import ChatMessage
|
from vertexai.language_models import ChatMessage
|
||||||
|
|
||||||
@ -145,7 +165,7 @@ def test_parse_examples_correct() -> None:
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def test_parse_exmaples_failes_wrong_sequence() -> None:
|
def test_parse_examples_failes_wrong_sequence() -> None:
|
||||||
with pytest.raises(ValueError) as exc_info:
|
with pytest.raises(ValueError) as exc_info:
|
||||||
_ = _parse_examples([AIMessage(content="a")])
|
_ = _parse_examples([AIMessage(content="a")])
|
||||||
print(str(exc_info.value))
|
print(str(exc_info.value))
|
||||||
|
Loading…
Reference in New Issue
Block a user