|
|
|
@ -12,10 +12,7 @@ from langchain.callbacks.manager import (
|
|
|
|
|
from langchain.chat_models.base import BaseChatModel, _generate_from_stream
|
|
|
|
|
from langchain.llms.vertexai import _VertexAICommon, is_codey_model
|
|
|
|
|
from langchain.pydantic_v1 import root_validator
|
|
|
|
|
from langchain.schema import (
|
|
|
|
|
ChatGeneration,
|
|
|
|
|
ChatResult,
|
|
|
|
|
)
|
|
|
|
|
from langchain.schema import ChatGeneration, ChatResult
|
|
|
|
|
from langchain.schema.messages import (
|
|
|
|
|
AIMessage,
|
|
|
|
|
AIMessageChunk,
|
|
|
|
@ -177,16 +174,22 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
|
|
|
|
|
|
|
|
|
|
question = _get_question(messages)
|
|
|
|
|
history = _parse_chat_history(messages[:-1])
|
|
|
|
|
params = self._prepare_params(stop=stop, **kwargs)
|
|
|
|
|
params = self._prepare_params(stop=stop, stream=False, **kwargs)
|
|
|
|
|
examples = kwargs.get("examples", None)
|
|
|
|
|
if examples:
|
|
|
|
|
params["examples"] = _parse_examples(examples)
|
|
|
|
|
|
|
|
|
|
chat = self._start_chat(history, params)
|
|
|
|
|
response = chat.send_message(question.content)
|
|
|
|
|
return ChatResult(
|
|
|
|
|
generations=[ChatGeneration(message=AIMessage(content=response.text))]
|
|
|
|
|
)
|
|
|
|
|
msg_params = {}
|
|
|
|
|
if "candidate_count" in params:
|
|
|
|
|
msg_params["candidate_count"] = params.pop("candidate_count")
|
|
|
|
|
|
|
|
|
|
chat = self._start_chat(history, **params)
|
|
|
|
|
response = chat.send_message(question.content, **msg_params)
|
|
|
|
|
generations = [
|
|
|
|
|
ChatGeneration(message=AIMessage(content=r.text))
|
|
|
|
|
for r in response.candidates
|
|
|
|
|
]
|
|
|
|
|
return ChatResult(generations=generations)
|
|
|
|
|
|
|
|
|
|
async def _agenerate(
|
|
|
|
|
self,
|
|
|
|
@ -219,11 +222,16 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
|
|
|
|
|
if examples:
|
|
|
|
|
params["examples"] = _parse_examples(examples)
|
|
|
|
|
|
|
|
|
|
chat = self._start_chat(history, params)
|
|
|
|
|
response = await chat.send_message_async(question.content)
|
|
|
|
|
return ChatResult(
|
|
|
|
|
generations=[ChatGeneration(message=AIMessage(content=response.text))]
|
|
|
|
|
)
|
|
|
|
|
msg_params = {}
|
|
|
|
|
if "candidate_count" in params:
|
|
|
|
|
msg_params["candidate_count"] = params.pop("candidate_count")
|
|
|
|
|
chat = self._start_chat(history, **params)
|
|
|
|
|
response = await chat.send_message_async(question.content, **msg_params)
|
|
|
|
|
generations = [
|
|
|
|
|
ChatGeneration(message=AIMessage(content=r.text))
|
|
|
|
|
for r in response.candidates
|
|
|
|
|
]
|
|
|
|
|
return ChatResult(generations=generations)
|
|
|
|
|
|
|
|
|
|
def _stream(
|
|
|
|
|
self,
|
|
|
|
@ -239,7 +247,7 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
|
|
|
|
|
if examples:
|
|
|
|
|
params["examples"] = _parse_examples(examples)
|
|
|
|
|
|
|
|
|
|
chat = self._start_chat(history, params)
|
|
|
|
|
chat = self._start_chat(history, **params)
|
|
|
|
|
responses = chat.send_message_streaming(question.content, **params)
|
|
|
|
|
for response in responses:
|
|
|
|
|
if run_manager:
|
|
|
|
@ -247,11 +255,11 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
|
|
|
|
|
yield ChatGenerationChunk(message=AIMessageChunk(content=response.text))
|
|
|
|
|
|
|
|
|
|
def _start_chat(
|
|
|
|
|
self, history: _ChatHistory, params: dict
|
|
|
|
|
self, history: _ChatHistory, **kwargs: Any
|
|
|
|
|
) -> Union[ChatSession, CodeChatSession]:
|
|
|
|
|
if not self.is_codey_model:
|
|
|
|
|
return self.client.start_chat(
|
|
|
|
|
context=history.context, message_history=history.history, **params
|
|
|
|
|
context=history.context, message_history=history.history, **kwargs
|
|
|
|
|
)
|
|
|
|
|
else:
|
|
|
|
|
return self.client.start_chat(message_history=history.history, **params)
|
|
|
|
|
return self.client.start_chat(message_history=history.history, **kwargs)
|
|
|
|
|