google-genai[patch], community[patch]: Added support for new Google GenerativeAI models (#14530)

Replace this entire comment with:
  - **Description:** added support for new Google GenerativeAI models
  - **Twitter handle:** lkuligin

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
pull/14131/head^2
Leonid Kuligin 6 months ago committed by GitHub
parent 6bbf0797f7
commit 7f42811e14
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -0,0 +1,287 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "7aZWXpbf0Eph",
"metadata": {
"id": "7aZWXpbf0Eph"
},
"source": [
"# Google AI\n"
]
},
{
"cell_type": "markdown",
"id": "bead5ede-d9cc-44b9-b062-99c90a10cf40",
"metadata": {},
"source": [
"A guide on using [Google Generative AI](https://developers.generativeai.google/) models with Langchain. Note: It's separate from Google Cloud Vertex AI [integration](https://python.langchain.com/docs/integrations/llms/google_vertex_ai_palm)."
]
},
{
"cell_type": "markdown",
"id": "H4AjsqTswBCE",
"metadata": {
"id": "H4AjsqTswBCE"
},
"source": [
"## Setting up\n"
]
},
{
"cell_type": "markdown",
"id": "EFHNUieMwJrl",
"metadata": {
"id": "EFHNUieMwJrl"
},
"source": [
"To use Google Generative AI you must install the `langchain-google-genai` Python package and generate an API key. [Read more details](https://developers.generativeai.google/)."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8Qzm6SqKwgak",
"metadata": {},
"outputs": [],
"source": [
"# !pip install langchain-google-genai"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7ONb7ZtOwjbo",
"metadata": {},
"outputs": [],
"source": [
"from langchain_google_genai import GoogleGenerativeAI"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "X3pjCW0i22gm",
"metadata": {},
"outputs": [],
"source": [
"from getpass import getpass\n",
"\n",
"api_key = getpass()"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "GT50LgFP0j-w",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"**Pros of Python:**\n",
"\n",
"* **Easy to learn:** Python is a very easy-to-learn programming language, even for beginners. Its syntax is simple and straightforward, and there are a lot of resources available to help you get started.\n",
"* **Versatile:** Python can be used for a wide variety of tasks, including web development, data science, and machine learning. It's also a good choice for beginners because it can be used for a variety of projects, so you can learn the basics and then move on to more complex tasks.\n",
"* **High-level:** Python is a high-level programming language, which means that it's closer to human language than other programming languages. This makes it easier to read and understand, which can be a big advantage for beginners.\n",
"* **Open-source:** Python is an open-source programming language, which means that it's free to use and there are a lot of resources available to help you learn it.\n",
"* **Community:** Python has a large and active community of developers, which means that there are a lot of people who can help you if you get stuck.\n",
"\n",
"**Cons of Python:**\n",
"\n",
"* **Slow:** Python is a relatively slow programming language compared to some other languages, such as C++. This can be a disadvantage if you're working on computationally intensive tasks.\n",
"* **Not as performant:** Python is not as performant as some other programming languages, such as C++ or Java. This can be a disadvantage if you're working on projects that require high performance.\n",
"* **Dynamic typing:** Python is a dynamically typed programming language, which means that the type of a variable can change during runtime. This can be a disadvantage if you need to ensure that your code is type-safe.\n",
"* **Unmanaged memory:** Python uses a garbage collection system to manage memory. This can be a disadvantage if you need to have more control over memory management.\n",
"\n",
"Overall, Python is a very good programming language for beginners. It's easy to learn, versatile, and has a large community of developers. However, it's important to be aware of its limitations, such as its slow performance and lack of performance.\n"
]
}
],
"source": [
"llm = GoogleGenerativeAI(model=\"models/text-bison-001\", google_api_key=api_key)\n",
"print(\n",
" llm.invoke(\n",
" \"What are some of the pros and cons of Python as a programming language?\"\n",
" )\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "TSGdxkJtwl8-",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"**Pros:**\n",
"\n",
"* **Simplicity and Readability:** Python is known for its simple and easy-to-read syntax, which makes it accessible to beginners and reduces the chance of errors. It uses indentation to define blocks of code, making the code structure clear and visually appealing.\n",
"\n",
"* **Versatility:** Python is a general-purpose language, meaning it can be used for a wide range of tasks, including web development, data science, machine learning, and desktop applications. This versatility makes it a popular choice for various projects and industries.\n",
"\n",
"* **Large Community:** Python has a vast and active community of developers, which contributes to its growth and popularity. This community provides extensive documentation, tutorials, and open-source libraries, making it easy for Python developers to find support and resources.\n",
"\n",
"* **Extensive Libraries:** Python offers a rich collection of libraries and frameworks for various tasks, such as data analysis (NumPy, Pandas), web development (Django, Flask), machine learning (Scikit-learn, TensorFlow), and many more. These libraries provide pre-built functions and modules, allowing developers to quickly and efficiently solve common problems.\n",
"\n",
"* **Cross-Platform Support:** Python is cross-platform, meaning it can run on various operating systems, including Windows, macOS, and Linux. This allows developers to write code that can be easily shared and used across different platforms.\n",
"\n",
"**Cons:**\n",
"\n",
"* **Speed and Performance:** Python is generally slower than compiled languages like C++ or Java due to its interpreted nature. This can be a disadvantage for performance-intensive tasks, such as real-time systems or heavy numerical computations.\n",
"\n",
"* **Memory Usage:** Python programs tend to consume more memory compared to compiled languages. This is because Python uses a dynamic memory allocation system, which can lead to memory fragmentation and higher memory usage.\n",
"\n",
"* **Lack of Static Typing:** Python is a dynamically typed language, which means that data types are not explicitly defined for variables. This can make it challenging to detect type errors during development, which can lead to unexpected behavior or errors at runtime.\n",
"\n",
"* **GIL (Global Interpreter Lock):** Python uses a global interpreter lock (GIL) to ensure that only one thread can execute Python bytecode at a time. This can limit the scalability and parallelism of Python programs, especially in multi-threaded or multiprocessing scenarios.\n",
"\n",
"* **Package Management:** While Python has a vast ecosystem of libraries and packages, managing dependencies and package versions can be challenging. The Python Package Index (PyPI) is the official repository for Python packages, but it can be difficult to ensure compatibility and avoid conflicts between different versions of packages.\n"
]
}
],
"source": [
"llm = GoogleGenerativeAI(model=\"gemini-pro\", google_api_key=api_key)\n",
"print(\n",
" llm.invoke(\n",
" \"What are some of the pros and cons of Python as a programming language?\"\n",
" )\n",
")"
]
},
{
"cell_type": "markdown",
"id": "OQ_SlL0K1Cw6",
"metadata": {
"id": "OQ_SlL0K1Cw6"
},
"source": [
"## Using in a chain"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "Nwc9P5_ry79W",
"metadata": {},
"outputs": [],
"source": [
"from langchain.prompts import PromptTemplate"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "35856bf2-aa5e-436b-977a-9e5725b1a595",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"4\n"
]
}
],
"source": [
"template = \"\"\"Question: {question}\n",
"\n",
"Answer: Let's think step by step.\"\"\"\n",
"prompt = PromptTemplate.from_template(template)\n",
"\n",
"chain = prompt | llm\n",
"\n",
"question = \"How much is 2+2?\"\n",
"print(chain.invoke({\"question\": question}))"
]
},
{
"cell_type": "markdown",
"id": "ueAin0xQzCqq",
"metadata": {
"id": "ueAin0xQzCqq"
},
"source": [
"## Streaming calls"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "WftL7x0A0hlF",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"In winter's embrace, a silent ballet,\n",
"Snowflakes descend, a celestial display.\n",
"Whispering secrets, they softly fall,\n",
"A blanket of white, covering all.\n",
"\n",
"With gentle grace, they paint the land,\n",
"Transforming the world into a winter wonderland.\n",
"Trees stand adorned in icy splendor,\n",
"A glistening spectacle, a sight to render.\n",
"\n",
"Snowflakes twirl, like dancers on a stage,\n",
"Creating a symphony, a winter montage.\n",
"Their silent whispers, a sweet serenade,\n",
"As they dance and twirl, a snowy cascade.\n",
"\n",
"In the hush of dawn, a frosty morn,\n",
"Snow sparkles bright, like diamonds reborn.\n",
"Each flake unique, in its own design,\n",
"A masterpiece crafted by the divine.\n",
"\n",
"So let us revel in this wintry bliss,\n",
"As snowflakes fall, with a gentle kiss.\n",
"For in their embrace, we find a peace profound,\n",
"A frozen world, with magic all around."
]
}
],
"source": [
"import sys\n",
"\n",
"for chunk in llm.stream(\"Tell me a short poem about snow\"):\n",
" sys.stdout.write(chunk)\n",
" sys.stdout.flush()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "aefe6df7",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.4"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

@ -1,10 +1,12 @@
from __future__ import annotations
from typing import Any, Dict, List, Optional
from typing import Any, Dict, Iterator, List, Optional
from langchain_core._api.deprecation import deprecated
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.outputs import Generation, LLMResult
from langchain_core.pydantic_v1 import BaseModel, root_validator
from langchain_core.language_models import LanguageModelInput
from langchain_core.outputs import Generation, GenerationChunk, LLMResult
from langchain_core.pydantic_v1 import BaseModel, SecretStr, root_validator
from langchain_core.utils import get_from_dict_or_env
from langchain_community.llms import BaseLLM
@ -13,7 +15,9 @@ from langchain_community.utilities.vertexai import create_retry_decorator
def completion_with_retry(
llm: GooglePalm,
*args: Any,
prompt: LanguageModelInput,
is_gemini: bool = False,
stream: bool = False,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Any:
@ -23,10 +27,23 @@ def completion_with_retry(
)
@retry_decorator
def _completion_with_retry(*args: Any, **kwargs: Any) -> Any:
return llm.client.generate_text(*args, **kwargs)
def _completion_with_retry(
prompt: LanguageModelInput, is_gemini: bool, stream: bool, **kwargs: Any
) -> Any:
generation_config = kwargs.get("generation_config", {})
if is_gemini:
return llm.client.generate_content(
contents=prompt, stream=stream, generation_config=generation_config
)
return llm.client.generate_text(prompt=prompt, **kwargs)
return _completion_with_retry(
prompt=prompt, is_gemini=is_gemini, stream=stream, **kwargs
)
return _completion_with_retry(*args, **kwargs)
def _is_gemini_model(model_name: str) -> bool:
return "gemini" in model_name
def _strip_erroneous_leading_spaces(text: str) -> str:
@ -42,11 +59,16 @@ def _strip_erroneous_leading_spaces(text: str) -> str:
return text
@deprecated("0.0.351", alternative="langchain_google_genai.GoogleGenerativeAI")
class GooglePalm(BaseLLM, BaseModel):
"""Google PaLM models."""
"""
DEPRECATED: Use `langchain_google_genai.GoogleGenerativeAI` instead.
Google PaLM models.
"""
client: Any #: :meta private:
google_api_key: Optional[str]
google_api_key: Optional[SecretStr]
model_name: str = "models/text-bison-001"
"""Model name to use."""
temperature: float = 0.7
@ -67,6 +89,11 @@ class GooglePalm(BaseLLM, BaseModel):
max_retries: int = 6
"""The maximum number of retries to make when generating."""
@property
def is_gemini(self) -> bool:
"""Returns whether a model is belongs to a Gemini family or not."""
return _is_gemini_model(self.model_name)
@property
def lc_secrets(self) -> Dict[str, str]:
return {"google_api_key": "GOOGLE_API_KEY"}
@ -86,18 +113,25 @@ class GooglePalm(BaseLLM, BaseModel):
google_api_key = get_from_dict_or_env(
values, "google_api_key", "GOOGLE_API_KEY"
)
model_name = values["model_name"]
try:
import google.generativeai as genai
if isinstance(google_api_key, SecretStr):
google_api_key = google_api_key.get_secret_value()
genai.configure(api_key=google_api_key)
if _is_gemini_model(model_name):
values["client"] = genai.GenerativeModel(model_name=model_name)
else:
values["client"] = genai
except ImportError:
raise ImportError(
"Could not import google-generativeai python package. "
"Please install it with `pip install google-generativeai`."
)
values["client"] = genai
if values["temperature"] is not None and not 0 <= values["temperature"] <= 1:
raise ValueError("temperature must be in the range [0.0, 1.0]")
@ -119,30 +153,76 @@ class GooglePalm(BaseLLM, BaseModel):
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult:
generations = []
generations: List[List[Generation]] = []
generation_config = {
"stop_sequences": stop,
"temperature": self.temperature,
"top_p": self.top_p,
"top_k": self.top_k,
"max_output_tokens": self.max_output_tokens,
"candidate_count": self.n,
}
for prompt in prompts:
completion = completion_with_retry(
self,
model=self.model_name,
prompt=prompt,
stop_sequences=stop,
temperature=self.temperature,
top_p=self.top_p,
top_k=self.top_k,
max_output_tokens=self.max_output_tokens,
candidate_count=self.n,
**kwargs,
)
prompt_generations = []
for candidate in completion.candidates:
raw_text = candidate["output"]
stripped_text = _strip_erroneous_leading_spaces(raw_text)
prompt_generations.append(Generation(text=stripped_text))
generations.append(prompt_generations)
if self.is_gemini:
res = completion_with_retry(
self,
prompt=prompt,
stream=False,
is_gemini=True,
run_manager=run_manager,
generation_config=generation_config,
)
candidates = [
"".join([p.text for p in c.content.parts]) for c in res.candidates
]
generations.append([Generation(text=c) for c in candidates])
else:
res = completion_with_retry(
self,
model=self.model_name,
prompt=prompt,
stream=False,
is_gemini=False,
run_manager=run_manager,
**generation_config,
)
prompt_generations = []
for candidate in res.candidates:
raw_text = candidate["output"]
stripped_text = _strip_erroneous_leading_spaces(raw_text)
prompt_generations.append(Generation(text=stripped_text))
generations.append(prompt_generations)
return LLMResult(generations=generations)
def _stream(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[GenerationChunk]:
generation_config = kwargs.get("generation_config", {})
if stop:
generation_config["stop_sequences"] = stop
for stream_resp in completion_with_retry(
self,
prompt,
stream=True,
is_gemini=True,
run_manager=run_manager,
generation_config=generation_config,
**kwargs,
):
chunk = GenerationChunk(text=stream_resp.text)
yield chunk
if run_manager:
run_manager.on_llm_new_token(
stream_resp.text,
chunk=chunk,
verbose=self.verbose,
)
@property
def _llm_type(self) -> str:
"""Return type of llm."""
@ -159,5 +239,7 @@ class GooglePalm(BaseLLM, BaseModel):
Returns:
The integer number of tokens in the text.
"""
if self.is_gemini:
raise ValueError("Counting tokens is not yet supported!")
result = self.client.count_text_tokens(model=self.model_name, prompt=text)
return result["token_count"]

@ -1,4 +1,4 @@
"""Test Google PaLM Text API wrapper.
"""Test Google GenerativeAI API wrapper.
Note: This test must be run with the GOOGLE_API_KEY environment variable set to a
valid API key.
@ -6,35 +6,68 @@ Note: This test must be run with the GOOGLE_API_KEY environment variable set to
from pathlib import Path
import pytest
from langchain_core.outputs import LLMResult
from langchain_community.llms.google_palm import GooglePalm
from langchain_community.llms.loading import load_llm
model_names = [None, "models/text-bison-001", "gemini-pro"]
def test_google_palm_call() -> None:
"""Test valid call to Google PaLM text API."""
llm = GooglePalm(max_output_tokens=10)
@pytest.mark.parametrize(
"model_name",
model_names,
)
def test_google_generativeai_call(model_name: str) -> None:
"""Test valid call to Google GenerativeAI text API."""
if model_name:
llm = GooglePalm(max_output_tokens=10, model_name=model_name)
else:
llm = GooglePalm(max_output_tokens=10)
output = llm("Say foo:")
assert isinstance(output, str)
assert llm._llm_type == "google_palm"
assert llm.model_name == "models/text-bison-001"
if model_name and "gemini" in model_name:
assert llm.client.model_name == "models/gemini-pro"
else:
assert llm.model_name == "models/text-bison-001"
def test_google_palm_generate() -> None:
llm = GooglePalm(temperature=0.3, n=2)
@pytest.mark.parametrize(
"model_name",
model_names,
)
def test_google_generativeai_generate(model_name: str) -> None:
n = 1 if model_name == "gemini-pro" else 2
if model_name:
llm = GooglePalm(temperature=0.3, n=n, model_name=model_name)
else:
llm = GooglePalm(temperature=0.3, n=n)
output = llm.generate(["Say foo:"])
assert isinstance(output, LLMResult)
assert len(output.generations) == 1
assert len(output.generations[0]) == 2
assert len(output.generations[0]) == n
def test_google_palm_get_num_tokens() -> None:
def test_google_generativeai_get_num_tokens() -> None:
llm = GooglePalm()
output = llm.get_num_tokens("How are you?")
assert output == 4
async def test_google_generativeai_agenerate() -> None:
llm = GooglePalm(temperature=0, model_name="gemini-pro")
output = await llm.agenerate(["Please say foo:"])
assert isinstance(output, LLMResult)
def test_generativeai_stream() -> None:
llm = GooglePalm(temperature=0, model_name="gemini-pro")
outputs = list(llm.stream("Please say foo:"))
assert isinstance(outputs[0], str)
def test_saving_loading_llm(tmp_path: Path) -> None:
"""Test saving/loading a Google PaLM LLM."""
llm = GooglePalm(max_output_tokens=10)

@ -6,11 +6,14 @@ This module integrates Google's Generative AI models, specifically the Gemini se
The `ChatGoogleGenerativeAI` class is the primary interface for interacting with Google's Gemini chat models. It allows users to send and receive messages using a specified Gemini model, suitable for various conversational AI applications.
**LLMs**
The `GoogleGenerativeAI` class is the primary interface for interacting with Google's Gemini LLMs. It allows users to generate text using a specified Gemini model.
**Embeddings**
The `GoogleGenerativeAIEmbeddings` class provides functionalities to generate embeddings using Google's models.
These embeddings can be used for a range of NLP tasks, including semantic analysis, similarity comparisons, and more.
**Installation**
To install the package, use pip:
@ -29,6 +32,17 @@ llm = ChatGoogleGenerativeAI(model="gemini-pro")
llm.invoke("Sing a ballad of LangChain.")
```
## Using LLMs
The package also supports generating text with Google's models.
```python
from langchain_google_genai import GoogleGenerativeAI
llm = GoogleGenerativeAI(model="gemini-pro")
llm.invoke("Once upon a time, a library called LangChain")
```
## Embedding Generation
The package also supports creating embeddings with Google's models, useful for textual similarity and other NLP applications.
@ -42,5 +56,10 @@ embeddings.embed_query("hello, world!")
""" # noqa: E501
from langchain_google_genai.chat_models import ChatGoogleGenerativeAI
from langchain_google_genai.embeddings import GoogleGenerativeAIEmbeddings
from langchain_google_genai.llms import GoogleGenerativeAI
__all__ = ["ChatGoogleGenerativeAI", "GoogleGenerativeAIEmbeddings"]
__all__ = [
"ChatGoogleGenerativeAI",
"GoogleGenerativeAIEmbeddings",
"GoogleGenerativeAI",
]

@ -0,0 +1,262 @@
from __future__ import annotations
from typing import Any, Callable, Dict, Iterator, List, Optional, Union
import google.api_core
import google.generativeai as genai # type: ignore[import]
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models import LanguageModelInput
from langchain_core.language_models.llms import BaseLLM, create_base_retry_decorator
from langchain_core.outputs import Generation, GenerationChunk, LLMResult
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
from langchain_core.utils import get_from_dict_or_env
def _create_retry_decorator(
llm: BaseLLM,
*,
max_retries: int = 1,
run_manager: Optional[
Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
] = None,
) -> Callable[[Any], Any]:
"""Creates a retry decorator for Vertex / Palm LLMs."""
errors = [
google.api_core.exceptions.ResourceExhausted,
google.api_core.exceptions.ServiceUnavailable,
google.api_core.exceptions.Aborted,
google.api_core.exceptions.DeadlineExceeded,
google.api_core.exceptions.GoogleAPIError,
]
decorator = create_base_retry_decorator(
error_types=errors, max_retries=max_retries, run_manager=run_manager
)
return decorator
def _completion_with_retry(
llm: GoogleGenerativeAI,
prompt: LanguageModelInput,
is_gemini: bool = False,
stream: bool = False,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Any:
"""Use tenacity to retry the completion call."""
retry_decorator = _create_retry_decorator(
llm, max_retries=llm.max_retries, run_manager=run_manager
)
@retry_decorator
def _completion_with_retry(
prompt: LanguageModelInput, is_gemini: bool, stream: bool, **kwargs: Any
) -> Any:
generation_config = kwargs.get("generation_config", {})
if is_gemini:
return llm.client.generate_content(
contents=prompt, stream=stream, generation_config=generation_config
)
return llm.client.generate_text(prompt=prompt, **kwargs)
return _completion_with_retry(
prompt=prompt, is_gemini=is_gemini, stream=stream, **kwargs
)
def _is_gemini_model(model_name: str) -> bool:
return "gemini" in model_name
def _strip_erroneous_leading_spaces(text: str) -> str:
"""Strip erroneous leading spaces from text.
The PaLM API will sometimes erroneously return a single leading space in all
lines > 1. This function strips that space.
"""
has_leading_space = all(not line or line[0] == " " for line in text.split("\n")[1:])
if has_leading_space:
return text.replace("\n ", "\n")
else:
return text
class GoogleGenerativeAI(BaseLLM, BaseModel):
"""Google GenerativeAI models.
Example:
.. code-block:: python
from langchain_google_genai import GoogleGenerativeAI
llm = GoogleGenerativeAI(model="gemini-pro")
"""
client: Any #: :meta private:
model: str = Field(
...,
description="""The name of the model to use.
Supported examples:
- gemini-pro
- models/text-bison-001""",
)
"""Model name to use."""
google_api_key: Optional[SecretStr] = None
temperature: float = 0.7
"""Run inference with this temperature. Must by in the closed interval
[0.0, 1.0]."""
top_p: Optional[float] = None
"""Decode using nucleus sampling: consider the smallest set of tokens whose
probability sum is at least top_p. Must be in the closed interval [0.0, 1.0]."""
top_k: Optional[int] = None
"""Decode using top-k sampling: consider the set of top_k most probable tokens.
Must be positive."""
max_output_tokens: Optional[int] = None
"""Maximum number of tokens to include in a candidate. Must be greater than zero.
If unset, will default to 64."""
n: int = 1
"""Number of chat completions to generate for each prompt. Note that the API may
not return the full n completions if duplicates are generated."""
max_retries: int = 6
"""The maximum number of retries to make when generating."""
@property
def is_gemini(self) -> bool:
"""Returns whether a model is belongs to a Gemini family or not."""
return _is_gemini_model(self.model)
@property
def lc_secrets(self) -> Dict[str, str]:
return {"google_api_key": "GOOGLE_API_KEY"}
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate api key, python package exists."""
google_api_key = get_from_dict_or_env(
values, "google_api_key", "GOOGLE_API_KEY"
)
model_name = values["model"]
if isinstance(google_api_key, SecretStr):
google_api_key = google_api_key.get_secret_value()
genai.configure(api_key=google_api_key)
if _is_gemini_model(model_name):
values["client"] = genai.GenerativeModel(model_name=model_name)
else:
values["client"] = genai
if values["temperature"] is not None and not 0 <= values["temperature"] <= 1:
raise ValueError("temperature must be in the range [0.0, 1.0]")
if values["top_p"] is not None and not 0 <= values["top_p"] <= 1:
raise ValueError("top_p must be in the range [0.0, 1.0]")
if values["top_k"] is not None and values["top_k"] <= 0:
raise ValueError("top_k must be positive")
if values["max_output_tokens"] is not None and values["max_output_tokens"] <= 0:
raise ValueError("max_output_tokens must be greater than zero")
return values
def _generate(
self,
prompts: List[str],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult:
generations: List[List[Generation]] = []
generation_config = {
"stop_sequences": stop,
"temperature": self.temperature,
"top_p": self.top_p,
"top_k": self.top_k,
"max_output_tokens": self.max_output_tokens,
"candidate_count": self.n,
}
for prompt in prompts:
if self.is_gemini:
res = _completion_with_retry(
self,
prompt=prompt,
stream=False,
is_gemini=True,
run_manager=run_manager,
generation_config=generation_config,
)
candidates = [
"".join([p.text for p in c.content.parts]) for c in res.candidates
]
generations.append([Generation(text=c) for c in candidates])
else:
res = _completion_with_retry(
self,
model=self.model,
prompt=prompt,
stream=False,
is_gemini=False,
run_manager=run_manager,
**generation_config,
)
prompt_generations = []
for candidate in res.candidates:
raw_text = candidate["output"]
stripped_text = _strip_erroneous_leading_spaces(raw_text)
prompt_generations.append(Generation(text=stripped_text))
generations.append(prompt_generations)
return LLMResult(generations=generations)
def _stream(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[GenerationChunk]:
generation_config = kwargs.get("generation_config", {})
if stop:
generation_config["stop_sequences"] = stop
for stream_resp in _completion_with_retry(
self,
prompt,
stream=True,
is_gemini=True,
run_manager=run_manager,
generation_config=generation_config,
**kwargs,
):
chunk = GenerationChunk(text=stream_resp.text)
yield chunk
if run_manager:
run_manager.on_llm_new_token(
stream_resp.text,
chunk=chunk,
verbose=self.verbose,
)
@property
def _llm_type(self) -> str:
"""Return type of llm."""
return "google_palm"
def get_num_tokens(self, text: str) -> int:
"""Get the number of tokens present in the text.
Useful for checking if an input will fit in a model's context window.
Args:
text: The string input to tokenize.
Returns:
The integer number of tokens in the text.
"""
if self.is_gemini:
raise ValueError("Counting tokens is not yet supported!")
result = self.client.count_text_tokens(model=self.model, prompt=text)
return result["token_count"]

@ -0,0 +1,65 @@
"""Test Google GenerativeAI API wrapper.
Note: This test must be run with the GOOGLE_API_KEY environment variable set to a
valid API key.
"""
import pytest
from langchain_core.outputs import LLMResult
from langchain_google_genai.llms import GoogleGenerativeAI
model_names = [None, "models/text-bison-001", "gemini-pro"]
@pytest.mark.parametrize(
"model_name",
model_names,
)
def test_google_generativeai_call(model_name: str) -> None:
"""Test valid call to Google GenerativeAI text API."""
if model_name:
llm = GoogleGenerativeAI(max_output_tokens=10, model=model_name)
else:
llm = GoogleGenerativeAI(max_output_tokens=10)
output = llm("Say foo:")
assert isinstance(output, str)
assert llm._llm_type == "google_palm"
if model_name and "gemini" in model_name:
assert llm.client.model_name == "models/gemini-pro"
else:
assert llm.model == "models/text-bison-001"
@pytest.mark.parametrize(
"model_name",
model_names,
)
def test_google_generativeai_generate(model_name: str) -> None:
n = 1 if model_name == "gemini-pro" else 2
if model_name:
llm = GoogleGenerativeAI(temperature=0.3, n=n, model=model_name)
else:
llm = GoogleGenerativeAI(temperature=0.3, n=n)
output = llm.generate(["Say foo:"])
assert isinstance(output, LLMResult)
assert len(output.generations) == 1
assert len(output.generations[0]) == n
def test_google_generativeai_get_num_tokens() -> None:
llm = GoogleGenerativeAI()
output = llm.get_num_tokens("How are you?")
assert output == 4
async def test_google_generativeai_agenerate() -> None:
llm = GoogleGenerativeAI(temperature=0, model="gemini-pro")
output = await llm.agenerate(["Please say foo:"])
assert isinstance(output, LLMResult)
def test_generativeai_stream() -> None:
llm = GoogleGenerativeAI(temperature=0, model="gemini-pro")
outputs = list(llm.stream("Please say foo:"))
assert isinstance(outputs[0], str)

@ -3,6 +3,7 @@ from langchain_google_genai import __all__
EXPECTED_ALL = [
"ChatGoogleGenerativeAI",
"GoogleGenerativeAIEmbeddings",
"GoogleGenerativeAI",
]

Loading…
Cancel
Save