from __future__ import annotations

import base64
import json
import logging
import os
from io import BytesIO
from typing import (
    Any,
    AsyncIterator,
    Callable,
    Dict,
    Iterator,
    List,
    Mapping,
    Optional,
    Sequence,
    Tuple,
    Union,
    cast,
)
from urllib.parse import urlparse

import google.ai.generativelanguage as glm
import google.api_core

# TODO: remove ignore once the google package is published with types
import google.generativeai as genai  # type: ignore[import]
import proto  # type: ignore[import]
import requests
from langchain_core.callbacks.manager import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import (
    AIMessage,
    AIMessageChunk,
    BaseMessage,
    FunctionMessage,
    HumanMessage,
    SystemMessage,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import SecretStr, root_validator
from langchain_core.utils import get_from_dict_or_env
from tenacity import (
    before_sleep_log,
    retry,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential,
)

from langchain_google_genai._common import GoogleGenerativeAIError
from langchain_google_genai._function_utils import (
    convert_to_genai_function_declarations,
)
from langchain_google_genai.llms import GoogleModelFamily, _BaseGoogleGenerativeAI

IMAGE_TYPES: Tuple = ()
try:
    import PIL
    from PIL.Image import Image

    IMAGE_TYPES = IMAGE_TYPES + (Image,)
except ImportError:
    PIL = None  # type: ignore
    Image = None  # type: ignore

logger = logging.getLogger(__name__)


class ChatGoogleGenerativeAIError(GoogleGenerativeAIError):
    """
    Custom exception class for errors associated with the `Google GenAI` API.

    This exception is raised when there are specific issues related to the
    Google genai API usage in the ChatGoogleGenerativeAI class, such as unsupported
    message types or roles.
    """


def _create_retry_decorator() -> Callable[[Any], Any]:
    """
    Creates and returns a preconfigured tenacity retry decorator.

    The retry decorator is configured to handle specific Google API exceptions
    such as ResourceExhausted and ServiceUnavailable. It uses an exponential
    backoff strategy for retries.

    Returns:
        Callable[[Any], Any]: A retry decorator configured for handling specific
        Google API exceptions.
    """
    multiplier = 2
    min_seconds = 1
    max_seconds = 60
    max_retries = 10

    return retry(
        reraise=True,
        stop=stop_after_attempt(max_retries),
        wait=wait_exponential(multiplier=multiplier, min=min_seconds, max=max_seconds),
        retry=(
            retry_if_exception_type(google.api_core.exceptions.ResourceExhausted)
            | retry_if_exception_type(google.api_core.exceptions.ServiceUnavailable)
            | retry_if_exception_type(google.api_core.exceptions.GoogleAPIError)
        ),
        before_sleep=before_sleep_log(logger, logging.WARNING),
    )


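# Illustrative sketch of how the decorator above is meant to be applied (the
# wrapped callable and its body are hypothetical):
#
#     retry_decorator = _create_retry_decorator()
#
#     @retry_decorator
#     def _generate_with_retry(**kwargs: Any) -> Any:
#         return some_generation_method(**kwargs)  # hypothetical call
#
# Transient ResourceExhausted / ServiceUnavailable / GoogleAPIError failures are
# retried with exponential backoff; other exceptions propagate immediately.

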
def _chat_with_retry(generation_method: Callable, **kwargs: Any) -> Any:
    """
    Executes a chat generation method with retry logic using tenacity.

    This function is a wrapper that applies a retry mechanism to a provided
    chat generation function. It is useful for handling intermittent issues
    like network errors or temporary service unavailability.

    Args:
        generation_method (Callable): The chat generation method to be executed.
        **kwargs (Any): Additional keyword arguments to pass to the generation method.

    Returns:
        Any: The result from the chat generation method.
    """
    retry_decorator = _create_retry_decorator()

    @retry_decorator
    def _chat_with_retry(**kwargs: Any) -> Any:
        try:
            return generation_method(**kwargs)
        # Do not retry for these errors.
        except google.api_core.exceptions.FailedPrecondition as exc:
            if "location is not supported" in exc.message:
                error_msg = (
                    "Your location is not supported by google-generativeai "
                    "at the moment. Try to use ChatVertexAI LLM from "
                    "langchain_google_vertexai."
                )
                raise ValueError(error_msg)

        except google.api_core.exceptions.InvalidArgument as e:
            raise ChatGoogleGenerativeAIError(
                f"Invalid argument provided to Gemini: {e}"
            ) from e
        except Exception as e:
            raise e

    return _chat_with_retry(**kwargs)


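# Illustrative usage sketch (assumes a configured ``genai.GenerativeModel``):
#
#     model = genai.GenerativeModel(model_name="gemini-pro")
#     chat = model.start_chat(history=[])
#     response = _chat_with_retry(
#         generation_method=chat.send_message,
#         content="Hello!",
#     )

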
async def _achat_with_retry(generation_method: Callable, **kwargs: Any) -> Any:
    """
    Executes a chat generation method with retry logic using tenacity.

    This function is a wrapper that applies a retry mechanism to a provided
    chat generation function. It is useful for handling intermittent issues
    like network errors or temporary service unavailability.

    Args:
        generation_method (Callable): The chat generation method to be executed.
        **kwargs (Any): Additional keyword arguments to pass to the generation method.

    Returns:
        Any: The result from the chat generation method.
    """
    retry_decorator = _create_retry_decorator()
    from google.api_core.exceptions import InvalidArgument  # type: ignore

    @retry_decorator
    async def _achat_with_retry(**kwargs: Any) -> Any:
        try:
            return await generation_method(**kwargs)
        except InvalidArgument as e:
            # Do not retry for these errors.
            raise ChatGoogleGenerativeAIError(
                f"Invalid argument provided to Gemini: {e}"
            ) from e
        except Exception as e:
            raise e

    return await _achat_with_retry(**kwargs)


def _is_openai_parts_format(part: dict) -> bool:
    return "type" in part


def _is_vision_model(model: str) -> bool:
    return "vision" in model


def _is_url(s: str) -> bool:
    try:
        result = urlparse(s)
        return all([result.scheme, result.netloc])
    except Exception as e:
        logger.debug(f"Unable to parse URL: {e}")
        return False


def _is_b64(s: str) -> bool:
    return s.startswith("data:image")


def _load_image_from_gcs(path: str, project: Optional[str] = None) -> Image:
    try:
        from google.cloud import storage  # type: ignore[attr-defined]
    except ImportError:
        raise ImportError(
            "google-cloud-storage is required to load images from GCS."
            " Install it with `pip install google-cloud-storage`"
        )
    if PIL is None:
        raise ImportError(
            "PIL is required to load images. Please install it "
            "with `pip install pillow`"
        )

    gcs_client = storage.Client(project=project)
    pieces = path.split("/")
    blobs = list(gcs_client.list_blobs(pieces[2], prefix="/".join(pieces[3:])))
    if len(blobs) > 1:
        raise ValueError(f"Found more than one candidate for {path}!")
    img_bytes = blobs[0].download_as_bytes()
    return PIL.Image.open(BytesIO(img_bytes))


def _url_to_pil(image_source: str) -> Image:
    if PIL is None:
        raise ImportError(
            "PIL is required to load images. Please install it "
            "with `pip install pillow`"
        )
    try:
        if isinstance(image_source, IMAGE_TYPES):
            return image_source  # type: ignore[return-value]
        elif _is_url(image_source):
            if image_source.startswith("gs://"):
                return _load_image_from_gcs(image_source)
            response = requests.get(image_source)
            response.raise_for_status()
            return PIL.Image.open(BytesIO(response.content))
        elif _is_b64(image_source):
            _, encoded = image_source.split(",", 1)
            data = base64.b64decode(encoded)
            return PIL.Image.open(BytesIO(data))
        elif os.path.exists(image_source):
            return PIL.Image.open(image_source)
        else:
            raise ValueError(
                "The provided string is not a valid URL, base64, or file path."
            )
    except Exception as e:
        raise ValueError(f"Unable to process the provided image source: {e}")


def _convert_to_parts(
    raw_content: Union[str, Sequence[Union[str, dict]]],
) -> List[genai.types.PartType]:
    """Converts LangChain message content into a list of Google GenAI parts."""
    parts = []
    content = [raw_content] if isinstance(raw_content, str) else raw_content
    for part in content:
        if isinstance(part, str):
            parts.append(genai.types.PartDict(text=part))
        elif isinstance(part, Mapping):
            # OpenAI Format
            if _is_openai_parts_format(part):
                if part["type"] == "text":
                    parts.append({"text": part["text"]})
                elif part["type"] == "image_url":
                    img_url = part["image_url"]
                    if isinstance(img_url, dict):
                        if "url" not in img_url:
                            raise ValueError(
                                f"Unrecognized message image format: {img_url}"
                            )
                        img_url = img_url["url"]
                    parts.append({"inline_data": _url_to_pil(img_url)})
                else:
                    raise ValueError(f"Unrecognized message part type: {part['type']}")
            else:
                # Unknown mapping: pass it through and assume it is a text part.
                logger.warning(
                    "Unrecognized message part format. Assuming it's a text part."
                )
                parts.append(part)
        else:
            # TODO: Maybe some of Google's native stuff
            # would hit this branch.
            raise ChatGoogleGenerativeAIError(
                "Gemini only supports text and inline_data parts."
            )
    return parts


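# Illustrative sketch of the OpenAI-style multimodal content handled above:
#
#     parts = _convert_to_parts(
#         [
#             {"type": "text", "text": "What is in this image?"},
#             {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
#         ]
#     )
#
# This yields one ``text`` part and one ``inline_data`` part whose image is
# fetched and decoded via PIL.

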
def _parse_chat_history(
    input_messages: Sequence[BaseMessage], convert_system_message_to_human: bool = False
) -> List[genai.types.ContentDict]:
    messages: List[genai.types.ContentDict] = []

    raw_system_message: Optional[SystemMessage] = None
    for i, message in enumerate(input_messages):
        if (
            i == 0
            and isinstance(message, SystemMessage)
            and not convert_system_message_to_human
        ):
            raise ValueError(
                """SystemMessages are not yet supported!

To automatically convert the leading SystemMessage to a HumanMessage,
set `convert_system_message_to_human` to True. Example:

llm = ChatGoogleGenerativeAI(model="gemini-pro", convert_system_message_to_human=True)
"""
            )
        elif i == 0 and isinstance(message, SystemMessage):
            raw_system_message = message
            continue
        elif isinstance(message, AIMessage):
            role = "model"
            raw_function_call = message.additional_kwargs.get("function_call")
            if raw_function_call:
                function_call = glm.FunctionCall(
                    {
                        "name": raw_function_call["name"],
                        "args": json.loads(raw_function_call["arguments"]),
                    }
                )
                parts = [glm.Part(function_call=function_call)]
            else:
                parts = _convert_to_parts(message.content)
        elif isinstance(message, HumanMessage):
            role = "user"
            parts = _convert_to_parts(message.content)
        elif isinstance(message, FunctionMessage):
            role = "user"
            response: Any
            if not isinstance(message.content, str):
                response = message.content
            else:
                try:
                    response = json.loads(message.content)
                except json.JSONDecodeError:
                    response = message.content  # leave as str representation
            parts = [
                glm.Part(
                    function_response=glm.FunctionResponse(
                        name=message.name,
                        response=(
                            {"output": response}
                            if not isinstance(response, dict)
                            else response
                        ),
                    )
                )
            ]
        else:
            raise ValueError(
                f"Unexpected message with type {type(message)} at position {i}."
            )

        if raw_system_message:
            if role == "model":
                raise ValueError(
                    "SystemMessage should be followed by a HumanMessage and "
                    "not by AIMessage."
                )
            parts = _convert_to_parts(raw_system_message.content) + parts
            raw_system_message = None
        messages.append({"role": role, "parts": parts})
    return messages


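# Illustrative sketch: a LangChain history is flattened into Gemini content
# dicts, e.g.
#
#     history = _parse_chat_history(
#         [
#             SystemMessage(content="You are a helpful assistant."),
#             HumanMessage(content="Hi"),
#             AIMessage(content="Hello!"),
#         ],
#         convert_system_message_to_human=True,
#     )
#
# The system text is merged into the first user turn, so the result is
# [{"role": "user", "parts": [...]}, {"role": "model", "parts": [...]}].

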
def _parse_response_candidate(
    response_candidate: glm.Candidate, stream: bool
) -> AIMessage:
    first_part = response_candidate.content.parts[0]
    if first_part.function_call:
        function_call = proto.Message.to_dict(first_part.function_call)
        function_call["arguments"] = json.dumps(function_call.pop("args", {}))
        return (AIMessageChunk if stream else AIMessage)(
            content="", additional_kwargs={"function_call": function_call}
        )
    else:
        parts = response_candidate.content.parts

        if len(parts) == 1 and parts[0].text:
            content: Union[str, List[Union[str, Dict]]] = parts[0].text
        else:
            content = [proto.Message.to_dict(part) for part in parts]
        return (AIMessageChunk if stream else AIMessage)(
            content=content, additional_kwargs={}
        )


def _response_to_result(
    response: glm.GenerateContentResponse,
    stream: bool = False,
) -> ChatResult:
    """Converts a Gemini API response into a LangChain ChatResult."""
    llm_output = {"prompt_feedback": proto.Message.to_dict(response.prompt_feedback)}

    generations: List[ChatGeneration] = []

    for candidate in response.candidates:
        generation_info = {}
        if candidate.finish_reason:
            generation_info["finish_reason"] = candidate.finish_reason.name
        generation_info["safety_ratings"] = [
            proto.Message.to_dict(safety_rating, use_integers_for_enums=False)
            for safety_rating in candidate.safety_ratings
        ]
        generations.append(
            (ChatGenerationChunk if stream else ChatGeneration)(
                message=_parse_response_candidate(candidate, stream=stream),
                generation_info=generation_info,
            )
        )
    if not response.candidates:
        # Likely a "prompt feedback" violation (e.g., toxic input)
        # Raising an error would be different than how OpenAI handles it,
        # so we'll just log a warning and continue with an empty message.
        logger.warning(
            "Gemini produced an empty response. Continuing with empty message\n"
            f"Feedback: {response.prompt_feedback}"
        )
        generations = [
            (ChatGenerationChunk if stream else ChatGeneration)(
                message=(AIMessageChunk if stream else AIMessage)(content=""),
                generation_info={},
            )
        ]
    return ChatResult(generations=generations, llm_output=llm_output)


class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
    """`Google Generative AI` Chat models API.

    To use, you must have either:

        1. The ``GOOGLE_API_KEY`` environment variable set with your API key, or
        2. Pass your API key using the google_api_key kwarg to the
           ChatGoogleGenerativeAI constructor.

    Example:
        .. code-block:: python

            from langchain_google_genai import ChatGoogleGenerativeAI
            chat = ChatGoogleGenerativeAI(model="gemini-pro")
            chat.invoke("Write me a ballad about LangChain")

    """

    client: Any  #: :meta private:

    convert_system_message_to_human: bool = False
    """Whether to merge any leading SystemMessage into the following HumanMessage.

    Gemini does not support system messages; any unsupported messages will
    raise an error."""

    class Config:
        allow_population_by_field_name = True

    @property
    def lc_secrets(self) -> Dict[str, str]:
        return {"google_api_key": "GOOGLE_API_KEY"}

    @property
    def _llm_type(self) -> str:
        return "chat-google-generative-ai"

    @classmethod
    def is_lc_serializable(cls) -> bool:
        return True

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validates params and passes them to google-generativeai package."""
        google_api_key = get_from_dict_or_env(
            values, "google_api_key", "GOOGLE_API_KEY"
        )
        if isinstance(google_api_key, SecretStr):
            google_api_key = google_api_key.get_secret_value()

        genai.configure(
            api_key=google_api_key,
            transport=values.get("transport"),
            client_options=values.get("client_options"),
        )
        if (
            values.get("temperature") is not None
            and not 0 <= values["temperature"] <= 1
        ):
            raise ValueError("temperature must be in the range [0.0, 1.0]")

        if values.get("top_p") is not None and not 0 <= values["top_p"] <= 1:
            raise ValueError("top_p must be in the range [0.0, 1.0]")

        if values.get("top_k") is not None and values["top_k"] <= 0:
            raise ValueError("top_k must be positive")
        model = values["model"]
        values["client"] = genai.GenerativeModel(model_name=model)
        return values

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Get the identifying parameters."""
        return {
            "model": self.model,
            "temperature": self.temperature,
            "top_k": self.top_k,
            "n": self.n,
        }

    def _prepare_params(
        self, stop: Optional[List[str]], **kwargs: Any
    ) -> Dict[str, Any]:
        gen_config = {
            k: v
            for k, v in {
                "candidate_count": self.n,
                "temperature": self.temperature,
                "stop_sequences": stop,
                "max_output_tokens": self.max_output_tokens,
                "top_k": self.top_k,
                "top_p": self.top_p,
            }.items()
            if v is not None
        }
        if "generation_config" in kwargs:
            gen_config = {**gen_config, **kwargs.pop("generation_config")}
        params = {"generation_config": gen_config, **kwargs}
        return params

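    # Illustrative sketch: a caller-supplied ``generation_config`` overrides the
    # defaults assembled above, e.g.
    #
    #     params = llm._prepare_params(
    #         stop=["\n\n"],
    #         generation_config={"temperature": 0.2},
    #     )
    #
    # returns {"generation_config": {..., "temperature": 0.2,
    # "stop_sequences": ["\n\n"]}} ready to pass to the Gemini client.
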
    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        params, chat, message = self._prepare_chat(
            messages,
            stop=stop,
            functions=kwargs.get("functions"),
        )
        response: genai.types.GenerateContentResponse = _chat_with_retry(
            content=message,
            **params,
            generation_method=chat.send_message,
        )
        return _response_to_result(response)

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        params, chat, message = self._prepare_chat(
            messages,
            stop=stop,
            functions=kwargs.get("functions"),
        )
        response: genai.types.GenerateContentResponse = await _achat_with_retry(
            content=message,
            **params,
            generation_method=chat.send_message_async,
        )
        return _response_to_result(response)

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        params, chat, message = self._prepare_chat(
            messages,
            stop=stop,
            functions=kwargs.get("functions"),
        )
        response: genai.types.GenerateContentResponse = _chat_with_retry(
            content=message,
            **params,
            generation_method=chat.send_message,
            stream=True,
        )
        for chunk in response:
            _chat_result = _response_to_result(chunk, stream=True)
            gen = cast(ChatGenerationChunk, _chat_result.generations[0])
            # Notify callbacks before yielding so on_llm_new_token fires first.
            if run_manager:
                run_manager.on_llm_new_token(gen.text)
            yield gen

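    # Illustrative usage sketch of streaming via the public Runnable API:
    #
    #     llm = ChatGoogleGenerativeAI(model="gemini-pro")
    #     for chunk in llm.stream("Tell me a short joke"):
    #         print(chunk.content, end="", flush=True)
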
    async def _astream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:
        params, chat, message = self._prepare_chat(
            messages,
            stop=stop,
            functions=kwargs.get("functions"),
        )
        async for chunk in await _achat_with_retry(
            content=message,
            **params,
            generation_method=chat.send_message_async,
            stream=True,
        ):
            _chat_result = _response_to_result(chunk, stream=True)
            gen = cast(ChatGenerationChunk, _chat_result.generations[0])
            # Notify callbacks before yielding so on_llm_new_token fires first.
            if run_manager:
                await run_manager.on_llm_new_token(gen.text)
            yield gen

    def _prepare_chat(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> Tuple[Dict[str, Any], genai.ChatSession, genai.types.ContentDict]:
        client = self.client
        functions = kwargs.pop("functions", None)
        if functions:
            tools = convert_to_genai_function_declarations(functions)
            client = genai.GenerativeModel(model_name=self.model, tools=tools)

        params = self._prepare_params(stop, **kwargs)
        history = _parse_chat_history(
            messages,
            convert_system_message_to_human=self.convert_system_message_to_human,
        )
        message = history.pop()
        chat = client.start_chat(history=history)
        return params, chat, message

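    # Illustrative sketch (assumes an OpenAI-style function schema, which
    # convert_to_genai_function_declarations is expected to accept): passing
    # ``functions`` swaps in a tool-enabled client, e.g.
    #
    #     llm = ChatGoogleGenerativeAI(model="gemini-pro")
    #     result = llm.invoke(
    #         "What is 3 + 4?",
    #         functions=[
    #             {
    #                 "name": "add",
    #                 "description": "Add two integers.",
    #                 "parameters": {
    #                     "type": "object",
    #                     "properties": {
    #                         "a": {"type": "integer"},
    #                         "b": {"type": "integer"},
    #                     },
    #                     "required": ["a", "b"],
    #                 },
    #             }
    #         ],
    #     )
    #
    # A tool call, if produced, appears in result.additional_kwargs["function_call"].
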
    def get_num_tokens(self, text: str) -> int:
        """Get the number of tokens present in the text.

        Useful for checking if an input will fit in a model's context window.

        Args:
            text: The string input to tokenize.

        Returns:
            The integer number of tokens in the text.
        """
        if self._model_family == GoogleModelFamily.GEMINI:
            result = self.client.count_tokens(text)
            token_count = result.total_tokens
        else:
            result = self.client.count_text_tokens(model=self.model, prompt=text)
            token_count = result["token_count"]

        return token_count
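
    # Illustrative usage sketch:
    #
    #     llm = ChatGoogleGenerativeAI(model="gemini-pro")
    #     n_tokens = llm.get_num_tokens("Hello, Gemini!")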