|
|
|
@ -15,6 +15,11 @@ from langchain_core.outputs import Generation, GenerationChunk, LLMResult
|
|
|
|
|
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
|
|
|
|
|
from langchain_core.utils import get_from_dict_or_env
|
|
|
|
|
|
|
|
|
|
from langchain_google_genai._enums import (
|
|
|
|
|
HarmBlockThreshold,
|
|
|
|
|
HarmCategory,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class GoogleModelFamily(str, Enum):
|
|
|
|
|
GEMINI = auto()
|
|
|
|
@ -77,7 +82,10 @@ def _completion_with_retry(
|
|
|
|
|
try:
|
|
|
|
|
if is_gemini:
|
|
|
|
|
return llm.client.generate_content(
|
|
|
|
|
contents=prompt, stream=stream, generation_config=generation_config
|
|
|
|
|
contents=prompt,
|
|
|
|
|
stream=stream,
|
|
|
|
|
generation_config=generation_config,
|
|
|
|
|
safety_settings=kwargs.pop("safety_settings", None),
|
|
|
|
|
)
|
|
|
|
|
return llm.client.generate_text(prompt=prompt, **kwargs)
|
|
|
|
|
except google.api_core.exceptions.FailedPrecondition as exc:
|
|
|
|
@ -143,6 +151,22 @@ Supported examples:
|
|
|
|
|
description="A string, one of: [`rest`, `grpc`, `grpc_asyncio`].",
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
safety_settings: Optional[Dict[HarmCategory, HarmBlockThreshold]] = None
|
|
|
|
|
"""The default safety settings to use for all generations.
|
|
|
|
|
|
|
|
|
|
For example:
|
|
|
|
|
|
|
|
|
|
from google.generativeai.types.safety_types import HarmBlockThreshold, HarmCategory
|
|
|
|
|
|
|
|
|
|
safety_settings = {
|
|
|
|
|
HarmCategory.HARM_CATEGORY_UNSPECIFIED: HarmBlockThreshold.BLOCK_NONE,
|
|
|
|
|
HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
|
|
|
|
|
HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_ONLY_HIGH,
|
|
|
|
|
HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
|
|
|
|
|
HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
|
|
|
|
|
}
|
|
|
|
|
""" # noqa: E501
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def lc_secrets(self) -> Dict[str, str]:
|
|
|
|
|
return {"google_api_key": "GOOGLE_API_KEY"}
|
|
|
|
@ -184,6 +208,8 @@ class GoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseLLM):
|
|
|
|
|
)
|
|
|
|
|
model_name = values["model"]
|
|
|
|
|
|
|
|
|
|
safety_settings = values["safety_settings"]
|
|
|
|
|
|
|
|
|
|
if isinstance(google_api_key, SecretStr):
|
|
|
|
|
google_api_key = google_api_key.get_secret_value()
|
|
|
|
|
|
|
|
|
@ -193,8 +219,15 @@ class GoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseLLM):
|
|
|
|
|
client_options=values.get("client_options"),
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
if safety_settings and (
|
|
|
|
|
not GoogleModelFamily(model_name) == GoogleModelFamily.GEMINI
|
|
|
|
|
):
|
|
|
|
|
raise ValueError("Safety settings are only supported for Gemini models")
|
|
|
|
|
|
|
|
|
|
if GoogleModelFamily(model_name) == GoogleModelFamily.GEMINI:
|
|
|
|
|
values["client"] = genai.GenerativeModel(model_name=model_name)
|
|
|
|
|
values["client"] = genai.GenerativeModel(
|
|
|
|
|
model_name=model_name, safety_settings=safety_settings
|
|
|
|
|
)
|
|
|
|
|
else:
|
|
|
|
|
values["client"] = genai
|
|
|
|
|
|
|
|
|
@ -237,6 +270,7 @@ class GoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseLLM):
|
|
|
|
|
is_gemini=True,
|
|
|
|
|
run_manager=run_manager,
|
|
|
|
|
generation_config=generation_config,
|
|
|
|
|
safety_settings=kwargs.pop("safety_settings", None),
|
|
|
|
|
)
|
|
|
|
|
candidates = [
|
|
|
|
|
"".join([p.text for p in c.content.parts]) for c in res.candidates
|
|
|
|
@ -278,6 +312,7 @@ class GoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseLLM):
|
|
|
|
|
is_gemini=True,
|
|
|
|
|
run_manager=run_manager,
|
|
|
|
|
generation_config=generation_config,
|
|
|
|
|
safety_settings=kwargs.pop("safety_settings", None),
|
|
|
|
|
**kwargs,
|
|
|
|
|
):
|
|
|
|
|
chunk = GenerationChunk(text=stream_resp.text)
|
|
|
|
|