langchain/libs/community/langchain_community/chat_models/vertexai.py
mackong 31891092d8
community[patch]: add missing chunk parameter for _stream/_astream (#17807)
- Description: Add missing chunk parameter for _stream/_astream for some
chat models, make all chat models in a consistent behaviour.
- Issue: N/A
- Dependencies: N/A
2024-02-21 15:32:28 -08:00

393 lines
14 KiB
Python

"""Wrapper around Google VertexAI chat-based models."""
from __future__ import annotations
import base64
import logging
import re
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Union, cast
from urllib.parse import urlparse
import requests
from langchain_core._api.deprecation import deprecated
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models.chat_models import (
BaseChatModel,
generate_from_stream,
)
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
HumanMessage,
SystemMessage,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import root_validator
from langchain_community.llms.vertexai import (
_VertexAICommon,
is_codey_model,
is_gemini_model,
)
from langchain_community.utilities.vertexai import (
load_image_from_gcs,
raise_vertex_import_error,
)
if TYPE_CHECKING:
from vertexai.language_models import (
ChatMessage,
ChatSession,
CodeChatSession,
InputOutputTextPair,
)
from vertexai.preview.generative_models import Content
logger = logging.getLogger(__name__)
@dataclass
class _ChatHistory:
"""Represents a context and a history of messages."""
history: List["ChatMessage"] = field(default_factory=list)
context: Optional[str] = None
def _parse_chat_history(history: List[BaseMessage]) -> _ChatHistory:
"""Parse a sequence of messages into history.
Args:
history: The list of messages to re-create the history of the chat.
Returns:
A parsed chat history.
Raises:
ValueError: If a sequence of message has a SystemMessage not at the
first place.
"""
from vertexai.language_models import ChatMessage
vertex_messages, context = [], None
for i, message in enumerate(history):
content = cast(str, message.content)
if i == 0 and isinstance(message, SystemMessage):
context = content
elif isinstance(message, AIMessage):
vertex_message = ChatMessage(content=message.content, author="bot")
vertex_messages.append(vertex_message)
elif isinstance(message, HumanMessage):
vertex_message = ChatMessage(content=message.content, author="user")
vertex_messages.append(vertex_message)
else:
raise ValueError(
f"Unexpected message with type {type(message)} at the position {i}."
)
chat_history = _ChatHistory(context=context, history=vertex_messages)
return chat_history
def _is_url(s: str) -> bool:
try:
result = urlparse(s)
return all([result.scheme, result.netloc])
except Exception as e:
logger.debug(f"Unable to parse URL: {e}")
return False
def _parse_chat_history_gemini(
history: List[BaseMessage], project: Optional[str]
) -> List["Content"]:
from vertexai.preview.generative_models import Content, Image, Part
def _convert_to_prompt(part: Union[str, Dict]) -> Part:
if isinstance(part, str):
return Part.from_text(part)
if not isinstance(part, Dict):
raise ValueError(
f"Message's content is expected to be a dict, got {type(part)}!"
)
if part["type"] == "text":
return Part.from_text(part["text"])
elif part["type"] == "image_url":
path = part["image_url"]["url"]
if path.startswith("gs://"):
image = load_image_from_gcs(path=path, project=project)
elif path.startswith("data:image/"):
# extract base64 component from image uri
encoded: Any = re.search(r"data:image/\w{2,4};base64,(.*)", path)
if encoded:
encoded = encoded.group(1)
else:
raise ValueError(
"Invalid image uri. It should be in the format "
"data:image/<image_type>;base64,<base64_encoded_image>."
)
image = Image.from_bytes(base64.b64decode(encoded))
elif _is_url(path):
response = requests.get(path)
response.raise_for_status()
image = Image.from_bytes(response.content)
else:
image = Image.load_from_file(path)
else:
raise ValueError("Only text and image_url types are supported!")
return Part.from_image(image)
vertex_messages = []
for i, message in enumerate(history):
if i == 0 and isinstance(message, SystemMessage):
raise ValueError("SystemMessages are not yet supported!")
elif isinstance(message, AIMessage):
role = "model"
elif isinstance(message, HumanMessage):
role = "user"
else:
raise ValueError(
f"Unexpected message with type {type(message)} at the position {i}."
)
raw_content = message.content
if isinstance(raw_content, str):
raw_content = [raw_content]
parts = [_convert_to_prompt(part) for part in raw_content]
vertex_message = Content(role=role, parts=parts)
vertex_messages.append(vertex_message)
return vertex_messages
def _parse_examples(examples: List[BaseMessage]) -> List["InputOutputTextPair"]:
from vertexai.language_models import InputOutputTextPair
if len(examples) % 2 != 0:
raise ValueError(
f"Expect examples to have an even amount of messages, got {len(examples)}."
)
example_pairs = []
input_text = None
for i, example in enumerate(examples):
if i % 2 == 0:
if not isinstance(example, HumanMessage):
raise ValueError(
f"Expected the first message in a part to be from human, got "
f"{type(example)} for the {i}th message."
)
input_text = example.content
if i % 2 == 1:
if not isinstance(example, AIMessage):
raise ValueError(
f"Expected the second message in a part to be from AI, got "
f"{type(example)} for the {i}th message."
)
pair = InputOutputTextPair(
input_text=input_text, output_text=example.content
)
example_pairs.append(pair)
return example_pairs
def _get_question(messages: List[BaseMessage]) -> HumanMessage:
"""Get the human message at the end of a list of input messages to a chat model."""
if not messages:
raise ValueError("You should provide at least one message to start the chat!")
question = messages[-1]
if not isinstance(question, HumanMessage):
raise ValueError(
f"Last message in the list should be from human, got {question.type}."
)
return question
@deprecated(
since="0.0.12",
removal="0.2.0",
alternative_import="langchain_google_vertexai.ChatVertexAI",
)
class ChatVertexAI(_VertexAICommon, BaseChatModel):
"""`Vertex AI` Chat large language models API."""
model_name: str = "chat-bison"
"Underlying model name."
examples: Optional[List[BaseMessage]] = None
@classmethod
def is_lc_serializable(self) -> bool:
return True
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "chat_models", "vertexai"]
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that the python package exists in environment."""
is_gemini = is_gemini_model(values["model_name"])
cls._try_init_vertexai(values)
try:
from vertexai.language_models import ChatModel, CodeChatModel
if is_gemini:
from vertexai.preview.generative_models import (
GenerativeModel,
)
except ImportError:
raise_vertex_import_error()
if is_gemini:
values["client"] = GenerativeModel(model_name=values["model_name"])
else:
if is_codey_model(values["model_name"]):
model_cls = CodeChatModel
else:
model_cls = ChatModel
values["client"] = model_cls.from_pretrained(values["model_name"])
return values
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
stream: Optional[bool] = None,
**kwargs: Any,
) -> ChatResult:
"""Generate next turn in the conversation.
Args:
messages: The history of the conversation as a list of messages. Code chat
does not support context.
stop: The list of stop words (optional).
run_manager: The CallbackManager for LLM run, it's not used at the moment.
stream: Whether to use the streaming endpoint.
Returns:
The ChatResult that contains outputs generated by the model.
Raises:
ValueError: if the last message in the list is not from human.
"""
should_stream = stream if stream is not None else self.streaming
if should_stream:
stream_iter = self._stream(
messages, stop=stop, run_manager=run_manager, **kwargs
)
return generate_from_stream(stream_iter)
question = _get_question(messages)
params = self._prepare_params(stop=stop, stream=False, **kwargs)
msg_params = {}
if "candidate_count" in params:
msg_params["candidate_count"] = params.pop("candidate_count")
if self._is_gemini_model:
history_gemini = _parse_chat_history_gemini(messages, project=self.project)
message = history_gemini.pop()
chat = self.client.start_chat(history=history_gemini)
response = chat.send_message(message, generation_config=params)
else:
history = _parse_chat_history(messages[:-1])
examples = kwargs.get("examples") or self.examples
if examples:
params["examples"] = _parse_examples(examples)
chat = self._start_chat(history, **params)
response = chat.send_message(question.content, **msg_params)
generations = [
ChatGeneration(message=AIMessage(content=r.text))
for r in response.candidates
]
return ChatResult(generations=generations)
async def _agenerate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
"""Asynchronously generate next turn in the conversation.
Args:
messages: The history of the conversation as a list of messages. Code chat
does not support context.
stop: The list of stop words (optional).
run_manager: The CallbackManager for LLM run, it's not used at the moment.
Returns:
The ChatResult that contains outputs generated by the model.
Raises:
ValueError: if the last message in the list is not from human.
"""
if "stream" in kwargs:
kwargs.pop("stream")
logger.warning("ChatVertexAI does not currently support async streaming.")
params = self._prepare_params(stop=stop, **kwargs)
msg_params = {}
if "candidate_count" in params:
msg_params["candidate_count"] = params.pop("candidate_count")
if self._is_gemini_model:
history_gemini = _parse_chat_history_gemini(messages, project=self.project)
message = history_gemini.pop()
chat = self.client.start_chat(history=history_gemini)
response = await chat.send_message_async(message, generation_config=params)
else:
question = _get_question(messages)
history = _parse_chat_history(messages[:-1])
examples = kwargs.get("examples", None)
if examples:
params["examples"] = _parse_examples(examples)
chat = self._start_chat(history, **params)
response = await chat.send_message_async(question.content, **msg_params)
generations = [
ChatGeneration(message=AIMessage(content=r.text))
for r in response.candidates
]
return ChatResult(generations=generations)
def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
params = self._prepare_params(stop=stop, stream=True, **kwargs)
if self._is_gemini_model:
history_gemini = _parse_chat_history_gemini(messages, project=self.project)
message = history_gemini.pop()
chat = self.client.start_chat(history=history_gemini)
responses = chat.send_message(
message, stream=True, generation_config=params
)
else:
question = _get_question(messages)
history = _parse_chat_history(messages[:-1])
examples = kwargs.get("examples", None)
if examples:
params["examples"] = _parse_examples(examples)
chat = self._start_chat(history, **params)
responses = chat.send_message_streaming(question.content, **params)
for response in responses:
chunk = ChatGenerationChunk(message=AIMessageChunk(content=response.text))
if run_manager:
run_manager.on_llm_new_token(response.text, chunk=chunk)
yield chunk
def _start_chat(
self, history: _ChatHistory, **kwargs: Any
) -> Union[ChatSession, CodeChatSession]:
if not self.is_codey_model:
return self.client.start_chat(
context=history.context, message_history=history.history, **kwargs
)
else:
return self.client.start_chat(message_history=history.history, **kwargs)