# (web-page residue captured with the file; not part of the module)
# You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
#
# 150 lines
# 5.3 KiB

"""Wrapper around Google VertexAI chat-based models."""
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional

from pydantic import root_validator

from langchain.callbacks.manager import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain.chat_models.base import BaseChatModel
from langchain.llms.vertexai import _VertexAICommon, is_codey_model
from langchain.schema import (
    AIMessage,
    BaseMessage,
    ChatGeneration,
    ChatResult,
    HumanMessage,
    SystemMessage,
)
from langchain.utilities.vertexai import raise_vertex_import_error
@dataclass
class _MessagePair:
    """One chat exchange: a human question and the AI answer to it."""

    # The human turn of the exchange.
    question: HumanMessage
    # The AI reply that followed the question.
    answer: AIMessage
@dataclass
class _ChatHistory:
    """A parsed conversation: ordered question/answer pairs plus an optional
    leading system message."""

    # Question/answer pairs in conversation order; default_factory gives each
    # instance its own list instead of a shared mutable default.
    history: List[_MessagePair] = field(default_factory=list)
    # The system message that opened the conversation, if there was one.
    system_message: Optional[SystemMessage] = None
def _parse_chat_history(history: List[BaseMessage]) -> _ChatHistory:
    """Parse a sequence of messages into a chat history.

    A sequence should be either (SystemMessage, HumanMessage, AIMessage,
    HumanMessage, AIMessage, ...) or (HumanMessage, AIMessage, HumanMessage,
    AIMessage, ...). CodeChat does not support SystemMessage.

    Args:
        history: The list of messages to re-create the history of the chat.

    Returns:
        A parsed chat history.

    Raises:
        ValueError: If the number of messages (not counting a leading
            SystemMessage) is odd, or a human message is not followed
            by a message from AI (e.g., Human, Human, AI or AI, AI, Human).
    """
    if not history:
        return _ChatHistory()

    # Only a SystemMessage in the very first position is treated as context.
    first_message = history[0]
    system_message = first_message if isinstance(first_message, SystemMessage) else None
    chat_history = _ChatHistory(system_message=system_message)

    # Explicit None check: do not rely on the truthiness of a message object.
    messages_left = history[1:] if system_message is not None else history
    if len(messages_left) % 2 != 0:
        raise ValueError(
            f"Amount of messages in history should be even, got {len(messages_left)}!"
        )

    # Walk the remaining messages as (even index, odd index) pairs.
    for question, answer in zip(messages_left[::2], messages_left[1::2]):
        if not isinstance(question, HumanMessage) or not isinstance(answer, AIMessage):
            # NOTE(review): the wording below reads inverted (the check requires
            # an AI message to follow a human one); the text is left unchanged
            # in case callers match on it.
            raise ValueError(
                "A human message should follow a bot one, "
                f"got {question.type}, {answer.type}."
            )
        chat_history.history.append(_MessagePair(question=question, answer=answer))
    return chat_history
class ChatVertexAI(_VertexAICommon, BaseChatModel):
    """Wrapper around Vertex AI chat models."""

    # Name of the Vertex AI model to load; code-chat models are detected
    # with is_codey_model().
    model_name: str = "chat-bison"

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the python package exists in environment.

        Loads the matching pretrained chat client and stores it under
        ``values["client"]``.

        Args:
            values: The field values being validated; ``model_name`` is read.

        Returns:
            The same ``values`` dict with ``client`` populated.
        """
        # NOTE(review): any project/credential initialisation offered by
        # _VertexAICommon is not visible from here - confirm whether it has to
        # run before the client is created.
        try:
            if is_codey_model(values["model_name"]):
                from vertexai.preview.language_models import CodeChatModel

                values["client"] = CodeChatModel.from_pretrained(values["model_name"])
            else:
                from vertexai.preview.language_models import ChatModel

                values["client"] = ChatModel.from_pretrained(values["model_name"])
        except ImportError:
            # Delegate to the shared helper for the missing-SDK error.
            raise_vertex_import_error()
        return values

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Generate next turn in the conversation.

        Args:
            messages: The history of the conversation as a list of messages. Code chat
                does not support context.
            stop: The list of stop words (optional).
            run_manager: The CallbackManager for LLM run, it's not used at the moment.
            **kwargs: Extra parameters merged over the model's default parameters.

        Returns:
            The ChatResult that contains outputs generated by the model.

        Raises:
            ValueError: if the message list is empty or the last message in the
                list is not from human.
        """
        if not messages:
            raise ValueError(
                "You should provide at least one message to start the chat!"
            )
        question = messages[-1]
        if not isinstance(question, HumanMessage):
            raise ValueError(
                f"Last message in the list should be from human, got {question.type}."
            )

        # Everything before the final human message is prior conversation.
        history = _parse_chat_history(messages[:-1])
        context = history.system_message.content if history.system_message else None
        params = {**self._default_params, **kwargs}

        # The context belongs to the chat session only. Keep it in a separate
        # dict so it is not forwarded to send_message() below, which takes
        # per-request generation parameters.
        start_params = dict(params)
        if not self.is_codey_model:
            start_params["context"] = context
        chat = self.client.start_chat(**start_params)

        # Replay earlier turns into the session as (question, answer) tuples.
        for pair in history.history:
            chat._history.append((pair.question.content, pair.answer.content))

        response = chat.send_message(question.content, **params)
        text = self._enforce_stop_words(response.text, stop)
        return ChatResult(generations=[ChatGeneration(message=AIMessage(content=text))])

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Async generation is not available for this model.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError(
            """Vertex AI doesn't support async requests at the moment."""
        )