@@ -1,9 +1,11 @@
 """Wrapper around Google VertexAI chat-based models."""
 from __future__ import annotations
+
 from dataclasses import dataclass, field
-from typing import TYPE_CHECKING, Any, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Union
+
 from langchain.callbacks.manager import CallbackManagerForLLMRun
-from langchain.chat_models.base import BaseChatModel
+from langchain.chat_models.base import BaseChatModel, _generate_from_stream
 from langchain.llms.vertexai import _VertexAICommon, is_codey_model
 from langchain.pydantic_v1 import root_validator
 from langchain.schema import (
@@ -12,14 +14,21 @@ from langchain.schema import (
 )
 from langchain.schema.messages import (
     AIMessage,
+    AIMessageChunk,
     BaseMessage,
     HumanMessage,
     SystemMessage,
 )
+from langchain.schema.output import ChatGenerationChunk
 from langchain.utilities.vertexai import raise_vertex_import_error
 
 if TYPE_CHECKING:
-    from vertexai.language_models import ChatMessage, InputOutputTextPair
+    from vertexai.language_models import (
+        ChatMessage,
+        ChatSession,
+        CodeChatSession,
+        InputOutputTextPair,
+    )
 
 
 @dataclass
@@ -91,10 +100,23 @@ def _parse_examples(examples: List[BaseMessage]) -> List["InputOutputTextPair"]:
     return example_pairs
 
 
+def _get_question(messages: List[BaseMessage]) -> HumanMessage:
+    """Get the human message at the end of a list of input messages to a chat model."""
+    if not messages:
+        raise ValueError("You should provide at least one message to start the chat!")
+    question = messages[-1]
+    if not isinstance(question, HumanMessage):
+        raise ValueError(
+            f"Last message in the list should be from human, got {question.type}."
+        )
+    return question
+
+
 class ChatVertexAI(_VertexAICommon, BaseChatModel):
     """`Vertex AI` Chat large language models API."""
 
     model_name: str = "chat-bison"
+    streaming: bool = False
 
     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
@@ -118,6 +140,7 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        stream: Optional[bool] = None,
         **kwargs: Any,
     ) -> ChatResult:
         """Generate next turn in the conversation.
@@ -127,6 +150,7 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
                 does not support context.
             stop: The list of stop words (optional).
             run_manager: The CallbackManager for LLM run, it's not used at the moment.
+            stream: Whether to use the streaming endpoint.
 
         Returns:
             The ChatResult that contains outputs generated by the model.
@@ -134,27 +158,53 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
         Raises:
             ValueError: if the last message in the list is not from human.
         """
-        if not messages:
-            raise ValueError(
-                "You should provide at least one message to start the chat!"
-            )
-        question = messages[-1]
-        if not isinstance(question, HumanMessage):
-            raise ValueError(
-                f"Last message in the list should be from human, got {question.type}."
+        should_stream = stream if stream is not None else self.streaming
+        if should_stream:
+            stream_iter = self._stream(
+                messages, stop=stop, run_manager=run_manager, **kwargs
             )
+            return _generate_from_stream(stream_iter)
+
+        question = _get_question(messages)
         history = _parse_chat_history(messages[:-1])
-        context = history.context if history.context else None
         params = {**self._default_params, **kwargs}
         examples = kwargs.get("examples", None)
         if examples:
             params["examples"] = _parse_examples(examples)
-        if not self.is_codey_model:
-            chat = self.client.start_chat(
-                context=context, message_history=history.history, **params
-            )
-        else:
-            chat = self.client.start_chat(message_history=history.history, **params)
+
+        chat = self._start_chat(history, params)
         response = chat.send_message(question.content)
         text = self._enforce_stop_words(response.text, stop)
         return ChatResult(generations=[ChatGeneration(message=AIMessage(content=text))])
+
+    def _stream(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> Iterator[ChatGenerationChunk]:
+        question = _get_question(messages)
+        history = _parse_chat_history(messages[:-1])
+        params = {**self._default_params, **kwargs}
+        examples = kwargs.get("examples", None)
+        if examples:
+            params["examples"] = _parse_examples(examples)
+
+        chat = self._start_chat(history, params)
+        responses = chat.send_message_streaming(question.content, **params)
+        for response in responses:
+            text = self._enforce_stop_words(response.text, stop)
+            if run_manager:
+                run_manager.on_llm_new_token(text)
+            yield ChatGenerationChunk(message=AIMessageChunk(content=text))
+
+    def _start_chat(
+        self, history: _ChatHistory, params: dict
+    ) -> Union[ChatSession, CodeChatSession]:
+        if not self.is_codey_model:
+            return self.client.start_chat(
+                context=history.context, message_history=history.history, **params
+            )
+        else:
+            return self.client.start_chat(message_history=history.history, **params)
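
For context when reviewing: a minimal sketch of how the new streaming path is exercised. The model name and credential setup are assumptions about the caller's environment, not part of this change:

```python
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.chat_models import ChatVertexAI
from langchain.schema.messages import HumanMessage

# Assumes Vertex AI credentials are already configured, e.g. via
# `gcloud auth application-default login`.
chat = ChatVertexAI(model_name="chat-bison", streaming=True)

# With streaming=True, _generate delegates to _stream: each chunk returned by
# send_message_streaming is forwarded to the callback's on_llm_new_token, and
# the final ChatResult is assembled from the accumulated chunks.
response = chat(
    [HumanMessage(content="Explain Vertex AI in one sentence.")],
    callbacks=[StreamingStdOutCallbackHandler()],
)
```

Setting `streaming=True` on the instance (or passing `stream=True` to `_generate`) routes generation through `_stream`, so token-level callbacks fire as chunks arrive instead of waiting for the full response.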