Mirror of https://github.com/hwchase17/langchain (synced 2024-11-04 06:00:26 +00:00)
google-genai[patch], community[patch]: Added support for new Google GenerativeAI models (#14530)
- **Description:** added support for new Google GenerativeAI models
- **Twitter handle:** lkuligin

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in: parent 6bbf0797f7, commit 7f42811e14
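Before the file-by-file diff, here is a minimal sketch of what this change enables. It is an illustration only, assuming the `langchain-google-genai` package is installed and that a valid API key is available (either passed as `google_api_key` or exported as `GOOGLE_API_KEY`):

```python
from langchain_google_genai import GoogleGenerativeAI

# New LLM wrapper added by this PR; both the Gemini name ("gemini-pro") and
# the older PaLM text model name ("models/text-bison-001") are accepted.
llm = GoogleGenerativeAI(model="gemini-pro")
print(llm.invoke("What are some of the pros and cons of Python as a programming language?"))

# Streaming is implemented for Gemini models.
for chunk in llm.stream("Tell me a short poem about snow"):
    print(chunk, end="")
```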
docs/docs/integrations/llms/google_ai.ipynb (new file, 287 lines)
@@ -0,0 +1,287 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "7aZWXpbf0Eph",
   "metadata": {
    "id": "7aZWXpbf0Eph"
   },
   "source": [
    "# Google AI\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bead5ede-d9cc-44b9-b062-99c90a10cf40",
   "metadata": {},
   "source": [
    "A guide on using [Google Generative AI](https://developers.generativeai.google/) models with Langchain. Note: It's separate from Google Cloud Vertex AI [integration](https://python.langchain.com/docs/integrations/llms/google_vertex_ai_palm)."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "H4AjsqTswBCE",
   "metadata": {
    "id": "H4AjsqTswBCE"
   },
   "source": [
    "## Setting up\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "EFHNUieMwJrl",
   "metadata": {
    "id": "EFHNUieMwJrl"
   },
   "source": [
    "To use Google Generative AI you must install the `langchain-google-genai` Python package and generate an API key. [Read more details](https://developers.generativeai.google/)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8Qzm6SqKwgak",
   "metadata": {},
   "outputs": [],
   "source": [
    "# !pip install langchain-google-genai"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7ONb7ZtOwjbo",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_google_genai import GoogleGenerativeAI"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "X3pjCW0i22gm",
   "metadata": {},
   "outputs": [],
   "source": [
    "from getpass import getpass\n",
    "\n",
    "api_key = getpass()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "GT50LgFP0j-w",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "**Pros of Python:**\n",
      "\n",
      "* **Easy to learn:** Python is a very easy-to-learn programming language, even for beginners. Its syntax is simple and straightforward, and there are a lot of resources available to help you get started.\n",
      "* **Versatile:** Python can be used for a wide variety of tasks, including web development, data science, and machine learning. It's also a good choice for beginners because it can be used for a variety of projects, so you can learn the basics and then move on to more complex tasks.\n",
      "* **High-level:** Python is a high-level programming language, which means that it's closer to human language than other programming languages. This makes it easier to read and understand, which can be a big advantage for beginners.\n",
      "* **Open-source:** Python is an open-source programming language, which means that it's free to use and there are a lot of resources available to help you learn it.\n",
      "* **Community:** Python has a large and active community of developers, which means that there are a lot of people who can help you if you get stuck.\n",
      "\n",
      "**Cons of Python:**\n",
      "\n",
      "* **Slow:** Python is a relatively slow programming language compared to some other languages, such as C++. This can be a disadvantage if you're working on computationally intensive tasks.\n",
      "* **Not as performant:** Python is not as performant as some other programming languages, such as C++ or Java. This can be a disadvantage if you're working on projects that require high performance.\n",
      "* **Dynamic typing:** Python is a dynamically typed programming language, which means that the type of a variable can change during runtime. This can be a disadvantage if you need to ensure that your code is type-safe.\n",
      "* **Unmanaged memory:** Python uses a garbage collection system to manage memory. This can be a disadvantage if you need to have more control over memory management.\n",
      "\n",
      "Overall, Python is a very good programming language for beginners. It's easy to learn, versatile, and has a large community of developers. However, it's important to be aware of its limitations, such as its slow performance and lack of performance.\n"
     ]
    }
   ],
   "source": [
    "llm = GoogleGenerativeAI(model=\"models/text-bison-001\", google_api_key=api_key)\n",
    "print(\n",
    "    llm.invoke(\n",
    "        \"What are some of the pros and cons of Python as a programming language?\"\n",
    "    )\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "TSGdxkJtwl8-",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "**Pros:**\n",
      "\n",
      "* **Simplicity and Readability:** Python is known for its simple and easy-to-read syntax, which makes it accessible to beginners and reduces the chance of errors. It uses indentation to define blocks of code, making the code structure clear and visually appealing.\n",
      "\n",
      "* **Versatility:** Python is a general-purpose language, meaning it can be used for a wide range of tasks, including web development, data science, machine learning, and desktop applications. This versatility makes it a popular choice for various projects and industries.\n",
      "\n",
      "* **Large Community:** Python has a vast and active community of developers, which contributes to its growth and popularity. This community provides extensive documentation, tutorials, and open-source libraries, making it easy for Python developers to find support and resources.\n",
      "\n",
      "* **Extensive Libraries:** Python offers a rich collection of libraries and frameworks for various tasks, such as data analysis (NumPy, Pandas), web development (Django, Flask), machine learning (Scikit-learn, TensorFlow), and many more. These libraries provide pre-built functions and modules, allowing developers to quickly and efficiently solve common problems.\n",
      "\n",
      "* **Cross-Platform Support:** Python is cross-platform, meaning it can run on various operating systems, including Windows, macOS, and Linux. This allows developers to write code that can be easily shared and used across different platforms.\n",
      "\n",
      "**Cons:**\n",
      "\n",
      "* **Speed and Performance:** Python is generally slower than compiled languages like C++ or Java due to its interpreted nature. This can be a disadvantage for performance-intensive tasks, such as real-time systems or heavy numerical computations.\n",
      "\n",
      "* **Memory Usage:** Python programs tend to consume more memory compared to compiled languages. This is because Python uses a dynamic memory allocation system, which can lead to memory fragmentation and higher memory usage.\n",
      "\n",
      "* **Lack of Static Typing:** Python is a dynamically typed language, which means that data types are not explicitly defined for variables. This can make it challenging to detect type errors during development, which can lead to unexpected behavior or errors at runtime.\n",
      "\n",
      "* **GIL (Global Interpreter Lock):** Python uses a global interpreter lock (GIL) to ensure that only one thread can execute Python bytecode at a time. This can limit the scalability and parallelism of Python programs, especially in multi-threaded or multiprocessing scenarios.\n",
      "\n",
      "* **Package Management:** While Python has a vast ecosystem of libraries and packages, managing dependencies and package versions can be challenging. The Python Package Index (PyPI) is the official repository for Python packages, but it can be difficult to ensure compatibility and avoid conflicts between different versions of packages.\n"
     ]
    }
   ],
   "source": [
    "llm = GoogleGenerativeAI(model=\"gemini-pro\", google_api_key=api_key)\n",
    "print(\n",
    "    llm.invoke(\n",
    "        \"What are some of the pros and cons of Python as a programming language?\"\n",
    "    )\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "OQ_SlL0K1Cw6",
   "metadata": {
    "id": "OQ_SlL0K1Cw6"
   },
   "source": [
    "## Using in a chain"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "Nwc9P5_ry79W",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.prompts import PromptTemplate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "35856bf2-aa5e-436b-977a-9e5725b1a595",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "4\n"
     ]
    }
   ],
   "source": [
    "template = \"\"\"Question: {question}\n",
    "\n",
    "Answer: Let's think step by step.\"\"\"\n",
    "prompt = PromptTemplate.from_template(template)\n",
    "\n",
    "chain = prompt | llm\n",
    "\n",
    "question = \"How much is 2+2?\"\n",
    "print(chain.invoke({\"question\": question}))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ueAin0xQzCqq",
   "metadata": {
    "id": "ueAin0xQzCqq"
   },
   "source": [
    "## Streaming calls"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "WftL7x0A0hlF",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "In winter's embrace, a silent ballet,\n",
      "Snowflakes descend, a celestial display.\n",
      "Whispering secrets, they softly fall,\n",
      "A blanket of white, covering all.\n",
      "\n",
      "With gentle grace, they paint the land,\n",
      "Transforming the world into a winter wonderland.\n",
      "Trees stand adorned in icy splendor,\n",
      "A glistening spectacle, a sight to render.\n",
      "\n",
      "Snowflakes twirl, like dancers on a stage,\n",
      "Creating a symphony, a winter montage.\n",
      "Their silent whispers, a sweet serenade,\n",
      "As they dance and twirl, a snowy cascade.\n",
      "\n",
      "In the hush of dawn, a frosty morn,\n",
      "Snow sparkles bright, like diamonds reborn.\n",
      "Each flake unique, in its own design,\n",
      "A masterpiece crafted by the divine.\n",
      "\n",
      "So let us revel in this wintry bliss,\n",
      "As snowflakes fall, with a gentle kiss.\n",
      "For in their embrace, we find a peace profound,\n",
      "A frozen world, with magic all around."
     ]
    }
   ],
   "source": [
    "import sys\n",
    "\n",
    "for chunk in llm.stream(\"Tell me a short poem about snow\"):\n",
    "    sys.stdout.write(chunk)\n",
    "    sys.stdout.flush()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aefe6df7",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
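The notebook collects the key with `getpass` and passes it explicitly via `google_api_key`. As a small variation (a sketch, not part of the notebook), the key can instead be supplied through the `GOOGLE_API_KEY` environment variable, which the validators shown in the diffs below read as a fallback:

```python
import os

from langchain_google_genai import GoogleGenerativeAI

# Placeholder value; validate_environment() picks the key up from the
# environment when google_api_key is not passed explicitly.
os.environ["GOOGLE_API_KEY"] = "<your-api-key>"

llm = GoogleGenerativeAI(model="models/text-bison-001")
print(llm.invoke("How much is 2+2?"))
```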
libs/community/langchain_community/llms/google_palm.py

@@ -1,10 +1,12 @@
 from __future__ import annotations
 
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, Iterator, List, Optional
 
+from langchain_core._api.deprecation import deprecated
 from langchain_core.callbacks import CallbackManagerForLLMRun
-from langchain_core.outputs import Generation, LLMResult
-from langchain_core.pydantic_v1 import BaseModel, root_validator
+from langchain_core.language_models import LanguageModelInput
+from langchain_core.outputs import Generation, GenerationChunk, LLMResult
+from langchain_core.pydantic_v1 import BaseModel, SecretStr, root_validator
 from langchain_core.utils import get_from_dict_or_env
 
 from langchain_community.llms import BaseLLM
@@ -13,7 +15,9 @@ from langchain_community.utilities.vertexai import create_retry_decorator
 
 def completion_with_retry(
     llm: GooglePalm,
-    *args: Any,
+    prompt: LanguageModelInput,
+    is_gemini: bool = False,
+    stream: bool = False,
     run_manager: Optional[CallbackManagerForLLMRun] = None,
     **kwargs: Any,
 ) -> Any:
@@ -23,10 +27,23 @@ def completion_with_retry(
     )
 
     @retry_decorator
-    def _completion_with_retry(*args: Any, **kwargs: Any) -> Any:
-        return llm.client.generate_text(*args, **kwargs)
+    def _completion_with_retry(
+        prompt: LanguageModelInput, is_gemini: bool, stream: bool, **kwargs: Any
+    ) -> Any:
+        generation_config = kwargs.get("generation_config", {})
+        if is_gemini:
+            return llm.client.generate_content(
+                contents=prompt, stream=stream, generation_config=generation_config
+            )
+        return llm.client.generate_text(prompt=prompt, **kwargs)
 
-    return _completion_with_retry(*args, **kwargs)
+    return _completion_with_retry(
+        prompt=prompt, is_gemini=is_gemini, stream=stream, **kwargs
+    )
+
+
+def _is_gemini_model(model_name: str) -> bool:
+    return "gemini" in model_name
 
 
 def _strip_erroneous_leading_spaces(text: str) -> str:
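A tiny illustration of the routing rule introduced above: `_is_gemini_model` is a plain substring check, so any model name containing "gemini" takes the `generate_content` path while everything else falls back to `generate_text`. The helper is restated here so the snippet stands alone:

```python
def _is_gemini_model(model_name: str) -> bool:
    return "gemini" in model_name

# Gemini names are routed to generate_content; PaLM names keep the old path.
assert _is_gemini_model("gemini-pro")
assert not _is_gemini_model("models/text-bison-001")
```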
@@ -42,11 +59,16 @@ def _strip_erroneous_leading_spaces(text: str) -> str:
         return text
 
 
+@deprecated("0.0.351", alternative="langchain_google_genai.GoogleGenerativeAI")
 class GooglePalm(BaseLLM, BaseModel):
-    """Google PaLM models."""
+    """
+    DEPRECATED: Use `langchain_google_genai.GoogleGenerativeAI` instead.
+
+    Google PaLM models.
+    """
 
     client: Any  #: :meta private:
-    google_api_key: Optional[str]
+    google_api_key: Optional[SecretStr]
     model_name: str = "models/text-bison-001"
     """Model name to use."""
     temperature: float = 0.7
@@ -67,6 +89,11 @@ class GooglePalm(BaseLLM, BaseModel):
     max_retries: int = 6
     """The maximum number of retries to make when generating."""
 
+    @property
+    def is_gemini(self) -> bool:
+        """Returns whether a model belongs to the Gemini family or not."""
+        return _is_gemini_model(self.model_name)
+
     @property
     def lc_secrets(self) -> Dict[str, str]:
         return {"google_api_key": "GOOGLE_API_KEY"}
@@ -86,18 +113,25 @@ class GooglePalm(BaseLLM, BaseModel):
         google_api_key = get_from_dict_or_env(
             values, "google_api_key", "GOOGLE_API_KEY"
         )
+        model_name = values["model_name"]
         try:
             import google.generativeai as genai
 
+            if isinstance(google_api_key, SecretStr):
+                google_api_key = google_api_key.get_secret_value()
+
             genai.configure(api_key=google_api_key)
+
+            if _is_gemini_model(model_name):
+                values["client"] = genai.GenerativeModel(model_name=model_name)
+            else:
+                values["client"] = genai
         except ImportError:
             raise ImportError(
                 "Could not import google-generativeai python package. "
                 "Please install it with `pip install google-generativeai`."
             )
 
-        values["client"] = genai
-
         if values["temperature"] is not None and not 0 <= values["temperature"] <= 1:
             raise ValueError("temperature must be in the range [0.0, 1.0]")
 
@@ -119,30 +153,76 @@ class GooglePalm(BaseLLM, BaseModel):
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> LLMResult:
-        generations = []
+        generations: List[List[Generation]] = []
+        generation_config = {
+            "stop_sequences": stop,
+            "temperature": self.temperature,
+            "top_p": self.top_p,
+            "top_k": self.top_k,
+            "max_output_tokens": self.max_output_tokens,
+            "candidate_count": self.n,
+        }
         for prompt in prompts:
-            completion = completion_with_retry(
-                self,
-                model=self.model_name,
-                prompt=prompt,
-                stop_sequences=stop,
-                temperature=self.temperature,
-                top_p=self.top_p,
-                top_k=self.top_k,
-                max_output_tokens=self.max_output_tokens,
-                candidate_count=self.n,
-                **kwargs,
-            )
-            prompt_generations = []
-            for candidate in completion.candidates:
-                raw_text = candidate["output"]
-                stripped_text = _strip_erroneous_leading_spaces(raw_text)
-                prompt_generations.append(Generation(text=stripped_text))
-            generations.append(prompt_generations)
+            if self.is_gemini:
+                res = completion_with_retry(
+                    self,
+                    prompt=prompt,
+                    stream=False,
+                    is_gemini=True,
+                    run_manager=run_manager,
+                    generation_config=generation_config,
+                )
+                candidates = [
+                    "".join([p.text for p in c.content.parts]) for c in res.candidates
+                ]
+                generations.append([Generation(text=c) for c in candidates])
+            else:
+                res = completion_with_retry(
+                    self,
+                    model=self.model_name,
+                    prompt=prompt,
+                    stream=False,
+                    is_gemini=False,
+                    run_manager=run_manager,
+                    **generation_config,
+                )
+                prompt_generations = []
+                for candidate in res.candidates:
+                    raw_text = candidate["output"]
+                    stripped_text = _strip_erroneous_leading_spaces(raw_text)
+                    prompt_generations.append(Generation(text=stripped_text))
+                generations.append(prompt_generations)
 
         return LLMResult(generations=generations)
 
+    def _stream(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> Iterator[GenerationChunk]:
+        generation_config = kwargs.get("generation_config", {})
+        if stop:
+            generation_config["stop_sequences"] = stop
+        for stream_resp in completion_with_retry(
+            self,
+            prompt,
+            stream=True,
+            is_gemini=True,
+            run_manager=run_manager,
+            generation_config=generation_config,
+            **kwargs,
+        ):
+            chunk = GenerationChunk(text=stream_resp.text)
+            yield chunk
+            if run_manager:
+                run_manager.on_llm_new_token(
+                    stream_resp.text,
+                    chunk=chunk,
+                    verbose=self.verbose,
+                )
+
     @property
     def _llm_type(self) -> str:
         """Return type of llm."""
@@ -159,5 +239,7 @@ class GooglePalm(BaseLLM, BaseModel):
         Returns:
             The integer number of tokens in the text.
         """
+        if self.is_gemini:
+            raise ValueError("Counting tokens is not yet supported!")
         result = self.client.count_text_tokens(model=self.model_name, prompt=text)
         return result["token_count"]
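Since `GooglePalm` is now marked deprecated in favor of `langchain_google_genai.GoogleGenerativeAI`, a migration looks roughly like the following sketch (the old import path is the one exercised by the community tests; the new class is the one added by this PR):

```python
# Before: the community wrapper, deprecated as of 0.0.351.
from langchain_community.llms.google_palm import GooglePalm

old_llm = GooglePalm(model_name="models/text-bison-001")

# After: the replacement named by the @deprecated decorator.
from langchain_google_genai import GoogleGenerativeAI

new_llm = GoogleGenerativeAI(model="models/text-bison-001")
```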
Integration tests for GooglePalm (langchain_community)

@@ -1,4 +1,4 @@
-"""Test Google PaLM Text API wrapper.
+"""Test Google GenerativeAI API wrapper.
 
 Note: This test must be run with the GOOGLE_API_KEY environment variable set to a
 valid API key.
@@ -6,35 +6,68 @@ Note: This test must be run with the GOOGLE_API_KEY environment variable set to
 
 from pathlib import Path
 
+import pytest
 from langchain_core.outputs import LLMResult
 
 from langchain_community.llms.google_palm import GooglePalm
 from langchain_community.llms.loading import load_llm
 
+model_names = [None, "models/text-bison-001", "gemini-pro"]
+
 
-def test_google_palm_call() -> None:
-    """Test valid call to Google PaLM text API."""
-    llm = GooglePalm(max_output_tokens=10)
+
+@pytest.mark.parametrize(
+    "model_name",
+    model_names,
+)
+def test_google_generativeai_call(model_name: str) -> None:
+    """Test valid call to Google GenerativeAI text API."""
+    if model_name:
+        llm = GooglePalm(max_output_tokens=10, model_name=model_name)
+    else:
+        llm = GooglePalm(max_output_tokens=10)
     output = llm("Say foo:")
     assert isinstance(output, str)
     assert llm._llm_type == "google_palm"
-    assert llm.model_name == "models/text-bison-001"
+    if model_name and "gemini" in model_name:
+        assert llm.client.model_name == "models/gemini-pro"
+    else:
+        assert llm.model_name == "models/text-bison-001"
 
 
-def test_google_palm_generate() -> None:
-    llm = GooglePalm(temperature=0.3, n=2)
+@pytest.mark.parametrize(
+    "model_name",
+    model_names,
+)
+def test_google_generativeai_generate(model_name: str) -> None:
+    n = 1 if model_name == "gemini-pro" else 2
+    if model_name:
+        llm = GooglePalm(temperature=0.3, n=n, model_name=model_name)
+    else:
+        llm = GooglePalm(temperature=0.3, n=n)
     output = llm.generate(["Say foo:"])
     assert isinstance(output, LLMResult)
     assert len(output.generations) == 1
-    assert len(output.generations[0]) == 2
+    assert len(output.generations[0]) == n
 
 
-def test_google_palm_get_num_tokens() -> None:
+def test_google_generativeai_get_num_tokens() -> None:
     llm = GooglePalm()
     output = llm.get_num_tokens("How are you?")
     assert output == 4
 
 
+async def test_google_generativeai_agenerate() -> None:
+    llm = GooglePalm(temperature=0, model_name="gemini-pro")
+    output = await llm.agenerate(["Please say foo:"])
+    assert isinstance(output, LLMResult)
+
+
+def test_generativeai_stream() -> None:
+    llm = GooglePalm(temperature=0, model_name="gemini-pro")
+    outputs = list(llm.stream("Please say foo:"))
+    assert isinstance(outputs[0], str)
+
+
 def test_saving_loading_llm(tmp_path: Path) -> None:
     """Test saving/loading a Google PaLM LLM."""
     llm = GooglePalm(max_output_tokens=10)
libs/partners/google-genai/langchain_google_genai/__init__.py

@@ -6,11 +6,14 @@ This module integrates Google's Generative AI models, specifically the Gemini se
 
 The `ChatGoogleGenerativeAI` class is the primary interface for interacting with Google's Gemini chat models. It allows users to send and receive messages using a specified Gemini model, suitable for various conversational AI applications.
 
+**LLMs**
+
+The `GoogleGenerativeAI` class is the primary interface for interacting with Google's Gemini LLMs. It allows users to generate text using a specified Gemini model.
+
 **Embeddings**
 
 The `GoogleGenerativeAIEmbeddings` class provides functionalities to generate embeddings using Google's models.
 These embeddings can be used for a range of NLP tasks, including semantic analysis, similarity comparisons, and more.
 
 **Installation**
 
 To install the package, use pip:
@@ -29,6 +32,17 @@ llm = ChatGoogleGenerativeAI(model="gemini-pro")
 llm.invoke("Sing a ballad of LangChain.")
 ```
 
+## Using LLMs
+
+The package also supports generating text with Google's models.
+
+```python
+from langchain_google_genai import GoogleGenerativeAI
+
+llm = GoogleGenerativeAI(model="gemini-pro")
+llm.invoke("Once upon a time, a library called LangChain")
+```
+
 ## Embedding Generation
 
 The package also supports creating embeddings with Google's models, useful for textual similarity and other NLP applications.
@@ -42,5 +56,10 @@ embeddings.embed_query("hello, world!")
 """  # noqa: E501
 from langchain_google_genai.chat_models import ChatGoogleGenerativeAI
 from langchain_google_genai.embeddings import GoogleGenerativeAIEmbeddings
+from langchain_google_genai.llms import GoogleGenerativeAI
 
-__all__ = ["ChatGoogleGenerativeAI", "GoogleGenerativeAIEmbeddings"]
+__all__ = [
+    "ChatGoogleGenerativeAI",
+    "GoogleGenerativeAIEmbeddings",
+    "GoogleGenerativeAI",
+]
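The hunk header above references the embeddings example that sits just beyond this hunk in the module docstring. For completeness, a hedged sketch of that usage (the `models/embedding-001` model name is an assumption; only the `embed_query("hello, world!")` call is visible in this diff):

```python
from langchain_google_genai import GoogleGenerativeAIEmbeddings

# Model name assumed for illustration; adjust to whichever embedding
# model the package documents.
embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001")
vector = embeddings.embed_query("hello, world!")
print(len(vector))
```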
libs/partners/google-genai/langchain_google_genai/llms.py (new file, 262 lines)
@@ -0,0 +1,262 @@
from __future__ import annotations

from typing import Any, Callable, Dict, Iterator, List, Optional, Union

import google.api_core
import google.generativeai as genai  # type: ignore[import]
from langchain_core.callbacks import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain_core.language_models import LanguageModelInput
from langchain_core.language_models.llms import BaseLLM, create_base_retry_decorator
from langchain_core.outputs import Generation, GenerationChunk, LLMResult
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
from langchain_core.utils import get_from_dict_or_env


def _create_retry_decorator(
    llm: BaseLLM,
    *,
    max_retries: int = 1,
    run_manager: Optional[
        Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
    ] = None,
) -> Callable[[Any], Any]:
    """Creates a retry decorator for Vertex / Palm LLMs."""

    errors = [
        google.api_core.exceptions.ResourceExhausted,
        google.api_core.exceptions.ServiceUnavailable,
        google.api_core.exceptions.Aborted,
        google.api_core.exceptions.DeadlineExceeded,
        google.api_core.exceptions.GoogleAPIError,
    ]
    decorator = create_base_retry_decorator(
        error_types=errors, max_retries=max_retries, run_manager=run_manager
    )
    return decorator


def _completion_with_retry(
    llm: GoogleGenerativeAI,
    prompt: LanguageModelInput,
    is_gemini: bool = False,
    stream: bool = False,
    run_manager: Optional[CallbackManagerForLLMRun] = None,
    **kwargs: Any,
) -> Any:
    """Use tenacity to retry the completion call."""
    retry_decorator = _create_retry_decorator(
        llm, max_retries=llm.max_retries, run_manager=run_manager
    )

    @retry_decorator
    def _completion_with_retry(
        prompt: LanguageModelInput, is_gemini: bool, stream: bool, **kwargs: Any
    ) -> Any:
        generation_config = kwargs.get("generation_config", {})
        if is_gemini:
            return llm.client.generate_content(
                contents=prompt, stream=stream, generation_config=generation_config
            )
        return llm.client.generate_text(prompt=prompt, **kwargs)

    return _completion_with_retry(
        prompt=prompt, is_gemini=is_gemini, stream=stream, **kwargs
    )


def _is_gemini_model(model_name: str) -> bool:
    return "gemini" in model_name


def _strip_erroneous_leading_spaces(text: str) -> str:
    """Strip erroneous leading spaces from text.

    The PaLM API will sometimes erroneously return a single leading space in all
    lines > 1. This function strips that space.
    """
    has_leading_space = all(not line or line[0] == " " for line in text.split("\n")[1:])
    if has_leading_space:
        return text.replace("\n ", "\n")
    else:
        return text


class GoogleGenerativeAI(BaseLLM, BaseModel):
    """Google GenerativeAI models.

    Example:
        .. code-block:: python

            from langchain_google_genai import GoogleGenerativeAI
            llm = GoogleGenerativeAI(model="gemini-pro")
    """

    client: Any  #: :meta private:
    model: str = Field(
        ...,
        description="""The name of the model to use.
Supported examples:
    - gemini-pro
    - models/text-bison-001""",
    )
    """Model name to use."""
    google_api_key: Optional[SecretStr] = None
    temperature: float = 0.7
    """Run inference with this temperature. Must be in the closed interval
    [0.0, 1.0]."""
    top_p: Optional[float] = None
    """Decode using nucleus sampling: consider the smallest set of tokens whose
    probability sum is at least top_p. Must be in the closed interval [0.0, 1.0]."""
    top_k: Optional[int] = None
    """Decode using top-k sampling: consider the set of top_k most probable tokens.
    Must be positive."""
    max_output_tokens: Optional[int] = None
    """Maximum number of tokens to include in a candidate. Must be greater than zero.
    If unset, will default to 64."""
    n: int = 1
    """Number of chat completions to generate for each prompt. Note that the API may
    not return the full n completions if duplicates are generated."""
    max_retries: int = 6
    """The maximum number of retries to make when generating."""

    @property
    def is_gemini(self) -> bool:
        """Returns whether a model belongs to the Gemini family or not."""
        return _is_gemini_model(self.model)

    @property
    def lc_secrets(self) -> Dict[str, str]:
        return {"google_api_key": "GOOGLE_API_KEY"}

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate api key, python package exists."""
        google_api_key = get_from_dict_or_env(
            values, "google_api_key", "GOOGLE_API_KEY"
        )
        model_name = values["model"]

        if isinstance(google_api_key, SecretStr):
            google_api_key = google_api_key.get_secret_value()

        genai.configure(api_key=google_api_key)

        if _is_gemini_model(model_name):
            values["client"] = genai.GenerativeModel(model_name=model_name)
        else:
            values["client"] = genai

        if values["temperature"] is not None and not 0 <= values["temperature"] <= 1:
            raise ValueError("temperature must be in the range [0.0, 1.0]")

        if values["top_p"] is not None and not 0 <= values["top_p"] <= 1:
            raise ValueError("top_p must be in the range [0.0, 1.0]")

        if values["top_k"] is not None and values["top_k"] <= 0:
            raise ValueError("top_k must be positive")

        if values["max_output_tokens"] is not None and values["max_output_tokens"] <= 0:
            raise ValueError("max_output_tokens must be greater than zero")

        return values

    def _generate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> LLMResult:
        generations: List[List[Generation]] = []
        generation_config = {
            "stop_sequences": stop,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "top_k": self.top_k,
            "max_output_tokens": self.max_output_tokens,
            "candidate_count": self.n,
        }
        for prompt in prompts:
            if self.is_gemini:
                res = _completion_with_retry(
                    self,
                    prompt=prompt,
                    stream=False,
                    is_gemini=True,
                    run_manager=run_manager,
                    generation_config=generation_config,
                )
                candidates = [
                    "".join([p.text for p in c.content.parts]) for c in res.candidates
                ]
                generations.append([Generation(text=c) for c in candidates])
            else:
                res = _completion_with_retry(
                    self,
                    model=self.model,
                    prompt=prompt,
                    stream=False,
                    is_gemini=False,
                    run_manager=run_manager,
                    **generation_config,
                )
                prompt_generations = []
                for candidate in res.candidates:
                    raw_text = candidate["output"]
                    stripped_text = _strip_erroneous_leading_spaces(raw_text)
                    prompt_generations.append(Generation(text=stripped_text))
                generations.append(prompt_generations)

        return LLMResult(generations=generations)

    def _stream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        generation_config = kwargs.get("generation_config", {})
        if stop:
            generation_config["stop_sequences"] = stop
        for stream_resp in _completion_with_retry(
            self,
            prompt,
            stream=True,
            is_gemini=True,
            run_manager=run_manager,
            generation_config=generation_config,
            **kwargs,
        ):
            chunk = GenerationChunk(text=stream_resp.text)
            yield chunk
            if run_manager:
                run_manager.on_llm_new_token(
                    stream_resp.text,
                    chunk=chunk,
                    verbose=self.verbose,
                )

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "google_palm"

    def get_num_tokens(self, text: str) -> int:
        """Get the number of tokens present in the text.

        Useful for checking if an input will fit in a model's context window.

        Args:
            text: The string input to tokenize.

        Returns:
            The integer number of tokens in the text.
        """
        if self.is_gemini:
            raise ValueError("Counting tokens is not yet supported!")
        result = self.client.count_text_tokens(model=self.model, prompt=text)
        return result["token_count"]
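One behavioral detail worth noting from `validate_environment` above: sampling parameters are range-checked at construction time. A short sketch (it assumes `GOOGLE_API_KEY` is set, since the key lookup runs before the range checks; pydantic surfaces the failure as a validation error wrapping the message):

```python
from langchain_google_genai import GoogleGenerativeAI

try:
    GoogleGenerativeAI(model="gemini-pro", temperature=1.5)
except ValueError as err:
    # Expected message: "temperature must be in the range [0.0, 1.0]"
    print(err)
```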
Integration tests for GoogleGenerativeAI (langchain-google-genai, new file, 65 lines)

@@ -0,0 +1,65 @@
"""Test Google GenerativeAI API wrapper.

Note: This test must be run with the GOOGLE_API_KEY environment variable set to a
valid API key.
"""

import pytest
from langchain_core.outputs import LLMResult

from langchain_google_genai.llms import GoogleGenerativeAI

model_names = [None, "models/text-bison-001", "gemini-pro"]


@pytest.mark.parametrize(
    "model_name",
    model_names,
)
def test_google_generativeai_call(model_name: str) -> None:
    """Test valid call to Google GenerativeAI text API."""
    if model_name:
        llm = GoogleGenerativeAI(max_output_tokens=10, model=model_name)
    else:
        llm = GoogleGenerativeAI(max_output_tokens=10)
    output = llm("Say foo:")
    assert isinstance(output, str)
    assert llm._llm_type == "google_palm"
    if model_name and "gemini" in model_name:
        assert llm.client.model_name == "models/gemini-pro"
    else:
        assert llm.model == "models/text-bison-001"


@pytest.mark.parametrize(
    "model_name",
    model_names,
)
def test_google_generativeai_generate(model_name: str) -> None:
    n = 1 if model_name == "gemini-pro" else 2
    if model_name:
        llm = GoogleGenerativeAI(temperature=0.3, n=n, model=model_name)
    else:
        llm = GoogleGenerativeAI(temperature=0.3, n=n)
    output = llm.generate(["Say foo:"])
    assert isinstance(output, LLMResult)
    assert len(output.generations) == 1
    assert len(output.generations[0]) == n


def test_google_generativeai_get_num_tokens() -> None:
    llm = GoogleGenerativeAI()
    output = llm.get_num_tokens("How are you?")
    assert output == 4


async def test_google_generativeai_agenerate() -> None:
    llm = GoogleGenerativeAI(temperature=0, model="gemini-pro")
    output = await llm.agenerate(["Please say foo:"])
    assert isinstance(output, LLMResult)


def test_generativeai_stream() -> None:
    llm = GoogleGenerativeAI(temperature=0, model="gemini-pro")
    outputs = list(llm.stream("Please say foo:"))
    assert isinstance(outputs[0], str)
Unit test for the langchain_google_genai public exports

@@ -3,6 +3,7 @@ from langchain_google_genai import __all__
 EXPECTED_ALL = [
     "ChatGoogleGenerativeAI",
     "GoogleGenerativeAIEmbeddings",
+    "GoogleGenerativeAI",
 ]