mirror of https://github.com/hwchase17/langchain
Add ChatModel, LLM, and Embeddings for Google's PaLM APIs (#3575)
- Add langchain.llms.GooglePalm for text completion.
- Add langchain.chat_models.ChatGooglePalm for chat completion.
- Add langchain.embeddings.GooglePalmEmbeddings for sentence embeddings.
- Add an example field to HumanMessage and AIMessage so that users can feed example messages into the PaLM Chat API.
- Add system and unit tests.

Note: async completion for the Text API is not yet supported and will be included in a future PR. Happy for feedback on any aspect of this PR, especially our choice of adding an example field to the HumanMessage and AIMessage objects to enable passing example messages to the API.
parent d15f481352
commit 921894960b
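For illustration, a minimal usage sketch of the three new interfaces (assumes the google.generativeai package is installed and GOOGLE_API_KEY is set; the prompt strings are hypothetical):

from langchain.chat_models import ChatGooglePalm
from langchain.embeddings.google_palm import GooglePalmEmbeddings
from langchain.llms.google_palm import GooglePalm
from langchain.schema import AIMessage, HumanMessage, SystemMessage

# Text completion.
llm = GooglePalm(max_output_tokens=64)
print(llm("Say foo:"))

# Chat completion, feeding an example exchange through the new `example` field.
chat = ChatGooglePalm(temperature=0.5)
response = chat(
    [
        SystemMessage(content="You speak like a pirate."),
        HumanMessage(example=True, content="Hello"),  # example input
        AIMessage(example=True, content="Ahoy!"),     # example response
        HumanMessage(content="How are you?"),         # real message
    ]
)
print(response.content)

# Sentence embeddings (embedding-gecko returns 768-dimensional vectors).
embeddings = GooglePalmEmbeddings()
print(len(embeddings.embed_query("foo bar")))  # 768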
langchain/chat_models/__init__.py
@@ -1,6 +1,13 @@
 from langchain.chat_models.anthropic import ChatAnthropic
 from langchain.chat_models.azure_openai import AzureChatOpenAI
+from langchain.chat_models.google_palm import ChatGooglePalm
 from langchain.chat_models.openai import ChatOpenAI
 from langchain.chat_models.promptlayer_openai import PromptLayerChatOpenAI
 
-__all__ = ["ChatOpenAI", "AzureChatOpenAI", "PromptLayerChatOpenAI", "ChatAnthropic"]
+__all__ = [
+    "ChatOpenAI",
+    "AzureChatOpenAI",
+    "PromptLayerChatOpenAI",
+    "ChatAnthropic",
+    "ChatGooglePalm",
+]
langchain/chat_models/google_palm.py (new file)
@@ -0,0 +1,248 @@

"""Wrapper around Google's PaLM Chat API."""
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Dict, List, Optional

from pydantic import BaseModel, root_validator

from langchain.chat_models.base import BaseChatModel
from langchain.schema import (
    AIMessage,
    BaseMessage,
    ChatGeneration,
    ChatMessage,
    ChatResult,
    HumanMessage,
    SystemMessage,
)
from langchain.utils import get_from_dict_or_env

if TYPE_CHECKING:
    import google.generativeai as genai


class ChatGooglePalmError(Exception):
    pass


def _truncate_at_stop_tokens(
    text: str,
    stop: Optional[List[str]],
) -> str:
    """Truncates text at the earliest stop token found."""
    if stop is None:
        return text

    for stop_token in stop:
        stop_token_idx = text.find(stop_token)
        if stop_token_idx != -1:
            text = text[:stop_token_idx]
    return text


def _response_to_result(
    response: genai.types.ChatResponse,
    stop: Optional[List[str]],
) -> ChatResult:
    """Converts a PaLM API response into a LangChain ChatResult."""
    if not response.candidates:
        raise ChatGooglePalmError("ChatResponse must have at least one candidate.")

    generations: List[ChatGeneration] = []
    for candidate in response.candidates:
        author = candidate.get("author")
        if author is None:
            raise ChatGooglePalmError(f"ChatResponse must have an author: {candidate}")

        content = _truncate_at_stop_tokens(candidate.get("content", ""), stop)
        if content is None:
            raise ChatGooglePalmError(f"ChatResponse must have content: {candidate}")

        if author == "ai":
            generations.append(
                ChatGeneration(text=content, message=AIMessage(content=content))
            )
        elif author == "human":
            generations.append(
                ChatGeneration(
                    text=content,
                    message=HumanMessage(content=content),
                )
            )
        else:
            generations.append(
                ChatGeneration(
                    text=content,
                    message=ChatMessage(role=author, content=content),
                )
            )

    return ChatResult(generations=generations)


def _messages_to_prompt_dict(
    input_messages: List[BaseMessage],
) -> genai.types.MessagePromptDict:
    """Converts a list of LangChain messages into a PaLM API MessagePrompt structure."""
    import google.generativeai as genai

    context: str = ""
    examples: List[genai.types.MessageDict] = []
    messages: List[genai.types.MessageDict] = []

    remaining = list(enumerate(input_messages))

    while remaining:
        index, input_message = remaining.pop(0)

        if isinstance(input_message, SystemMessage):
            if index != 0:
                raise ChatGooglePalmError("System message must be first input message.")
            context = input_message.content
        elif isinstance(input_message, HumanMessage) and input_message.example:
            if messages:
                raise ChatGooglePalmError(
                    "Message examples must come before other messages."
                )
            _, next_input_message = remaining.pop(0)
            if isinstance(next_input_message, AIMessage) and next_input_message.example:
                examples.extend(
                    [
                        genai.types.MessageDict(
                            author="human", content=input_message.content
                        ),
                        genai.types.MessageDict(
                            author="ai", content=next_input_message.content
                        ),
                    ]
                )
            else:
                raise ChatGooglePalmError(
                    "Human example message must be immediately followed by an "
                    "AI example response."
                )
        elif isinstance(input_message, AIMessage) and input_message.example:
            raise ChatGooglePalmError(
                "AI example message must be immediately preceded by a Human "
                "example message."
            )
        elif isinstance(input_message, AIMessage):
            messages.append(
                genai.types.MessageDict(author="ai", content=input_message.content)
            )
        elif isinstance(input_message, HumanMessage):
            messages.append(
                genai.types.MessageDict(author="human", content=input_message.content)
            )
        elif isinstance(input_message, ChatMessage):
            messages.append(
                genai.types.MessageDict(
                    author=input_message.role, content=input_message.content
                )
            )
        else:
            raise ChatGooglePalmError(
                "Messages without an explicit role not supported by PaLM API."
            )

    return genai.types.MessagePromptDict(
        context=context,
        examples=examples,
        messages=messages,
    )


class ChatGooglePalm(BaseChatModel, BaseModel):
    """Wrapper around Google's PaLM Chat API.

    To use you must have the google.generativeai Python package installed and
    either:

    1. The ``GOOGLE_API_KEY`` environment variable set with your API key, or
    2. Pass your API key using the google_api_key kwarg to the ChatGooglePalm
       constructor.

    Example:
        .. code-block:: python

            from langchain.chat_models import ChatGooglePalm
            chat = ChatGooglePalm()

    """

    client: Any  #: :meta private:
    model_name: str = "models/chat-bison-001"
    """Model name to use."""
    google_api_key: Optional[str] = None
    temperature: Optional[float] = None
    """Run inference with this temperature. Must be in the closed
    interval [0.0, 1.0]."""
    top_p: Optional[float] = None
    """Decode using nucleus sampling: consider the smallest set of tokens whose
    probability sum is at least top_p. Must be in the closed interval [0.0, 1.0]."""
    top_k: Optional[int] = None
    """Decode using top-k sampling: consider the set of top_k most probable tokens.
    Must be positive."""
    n: int = 1
    """Number of chat completions to generate for each prompt. Note that the API may
    not return the full n completions if duplicates are generated."""

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate api key, python package exists, temperature, top_p, and top_k."""
        google_api_key = get_from_dict_or_env(
            values, "google_api_key", "GOOGLE_API_KEY"
        )
        try:
            import google.generativeai as genai

            genai.configure(api_key=google_api_key)
        except ImportError:
            raise ChatGooglePalmError(
                "Could not import google.generativeai python package."
            )

        values["client"] = genai

        if values["temperature"] is not None and not 0 <= values["temperature"] <= 1:
            raise ValueError("temperature must be in the range [0.0, 1.0]")

        if values["top_p"] is not None and not 0 <= values["top_p"] <= 1:
            raise ValueError("top_p must be in the range [0.0, 1.0]")

        if values["top_k"] is not None and values["top_k"] <= 0:
            raise ValueError("top_k must be positive")

        return values

    def _generate(
        self, messages: List[BaseMessage], stop: Optional[List[str]] = None
    ) -> ChatResult:
        prompt = _messages_to_prompt_dict(messages)

        response: genai.types.ChatResponse = self.client.chat(
            model=self.model_name,
            prompt=prompt,
            temperature=self.temperature,
            top_p=self.top_p,
            top_k=self.top_k,
            candidate_count=self.n,
        )

        return _response_to_result(response, stop)

    async def _agenerate(
        self, messages: List[BaseMessage], stop: Optional[List[str]] = None
    ) -> ChatResult:
        prompt = _messages_to_prompt_dict(messages)

        response: genai.types.ChatResponse = await self.client.chat_async(
            model=self.model_name,
            prompt=prompt,
            temperature=self.temperature,
            top_p=self.top_p,
            top_k=self.top_k,
            candidate_count=self.n,
        )

        return _response_to_result(response, stop)
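To make the mapping concrete, here is what _messages_to_prompt_dict produces for a prompt with one example pair (a sketch mirroring the unit tests included below; the message contents are hypothetical):

from langchain.chat_models.google_palm import _messages_to_prompt_dict
from langchain.schema import AIMessage, HumanMessage, SystemMessage

prompt = _messages_to_prompt_dict(
    [
        SystemMessage(content="Prompt"),
        HumanMessage(example=True, content="Human example"),
        AIMessage(example=True, content="AI example"),
        HumanMessage(content="Real message"),
    ]
)
# prompt == {
#     "context": "Prompt",
#     "examples": [
#         {"author": "human", "content": "Human example"},
#         {"author": "ai", "content": "AI example"},
#     ],
#     "messages": [{"author": "human", "content": "Real message"}],
# }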
langchain/embeddings/google_palm.py (new file)
@@ -0,0 +1,38 @@

"""Wrapper around Google's PaLM Embeddings APIs."""
from typing import Any, Dict, List, Optional

from pydantic import BaseModel, root_validator

from langchain.embeddings.base import Embeddings
from langchain.utils import get_from_dict_or_env


class GooglePalmEmbeddings(BaseModel, Embeddings):
    client: Any
    google_api_key: Optional[str]
    model_name: str = "models/embedding-gecko-001"

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate api key, python package exists."""
        google_api_key = get_from_dict_or_env(
            values, "google_api_key", "GOOGLE_API_KEY"
        )
        try:
            import google.generativeai as genai

            genai.configure(api_key=google_api_key)
        except ImportError:
            raise ImportError("Could not import google.generativeai python package.")

        values["client"] = genai

        return values

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        return [self.embed_query(text) for text in texts]

    def embed_query(self, text: str) -> List[float]:
        """Embed query text."""
        embedding = self.client.generate_embeddings(self.model_name, text)
        return embedding["embedding"]
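A usage note on the embeddings wrapper: embed_documents simply calls embed_query once per text, so embedding a batch costs one generate_embeddings call per document. A minimal sketch (assumes GOOGLE_API_KEY is set):

from langchain.embeddings.google_palm import GooglePalmEmbeddings

embeddings = GooglePalmEmbeddings()
query_vector = embeddings.embed_query("foo bar")          # one API call
doc_vectors = embeddings.embed_documents(["foo", "bar"])  # one API call per text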
langchain/llms/google_palm.py (new file)
@@ -0,0 +1,109 @@

"""Wrapper around Google's PaLM Text APIs."""
from __future__ import annotations

from typing import Any, Dict, List, Optional

from pydantic import BaseModel, root_validator

from langchain.llms import BaseLLM
from langchain.schema import Generation, LLMResult
from langchain.utils import get_from_dict_or_env


def _strip_erroneous_leading_spaces(text: str) -> str:
    """Strip erroneous leading spaces from text.

    The PaLM API will sometimes erroneously return a single leading space in all
    lines > 1. This function strips that space.
    """
    has_leading_space = all(not line or line[0] == " " for line in text.split("\n")[1:])
    if has_leading_space:
        return text.replace("\n ", "\n")
    else:
        return text


class GooglePalm(BaseLLM, BaseModel):
    client: Any  #: :meta private:
    google_api_key: Optional[str]
    model_name: str = "models/text-bison-001"
    """Model name to use."""
    temperature: float = 0.7
    """Run inference with this temperature. Must be in the closed interval
    [0.0, 1.0]."""
    top_p: Optional[float] = None
    """Decode using nucleus sampling: consider the smallest set of tokens whose
    probability sum is at least top_p. Must be in the closed interval [0.0, 1.0]."""
    top_k: Optional[int] = None
    """Decode using top-k sampling: consider the set of top_k most probable tokens.
    Must be positive."""
    max_output_tokens: Optional[int] = None
    """Maximum number of tokens to include in a candidate. Must be greater than zero.
    If unset, will default to 64."""
    n: int = 1
    """Number of completions to generate for each prompt. Note that the API may
    not return the full n completions if duplicates are generated."""

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate api key, python package exists."""
        google_api_key = get_from_dict_or_env(
            values, "google_api_key", "GOOGLE_API_KEY"
        )
        try:
            import google.generativeai as genai

            genai.configure(api_key=google_api_key)
        except ImportError:
            raise ImportError("Could not import google.generativeai python package.")

        values["client"] = genai

        if values["temperature"] is not None and not 0 <= values["temperature"] <= 1:
            raise ValueError("temperature must be in the range [0.0, 1.0]")

        if values["top_p"] is not None and not 0 <= values["top_p"] <= 1:
            raise ValueError("top_p must be in the range [0.0, 1.0]")

        if values["top_k"] is not None and values["top_k"] <= 0:
            raise ValueError("top_k must be positive")

        if values["max_output_tokens"] is not None and values["max_output_tokens"] <= 0:
            raise ValueError("max_output_tokens must be greater than zero")

        return values

    def _generate(
        self, prompts: List[str], stop: Optional[List[str]] = None
    ) -> LLMResult:
        generations = []
        for prompt in prompts:
            completion = self.client.generate_text(
                model=self.model_name,
                prompt=prompt,
                stop_sequences=stop,
                temperature=self.temperature,
                top_p=self.top_p,
                top_k=self.top_k,
                max_output_tokens=self.max_output_tokens,
                candidate_count=self.n,
            )

            prompt_generations = []
            for candidate in completion.candidates:
                raw_text = candidate["output"]
                stripped_text = _strip_erroneous_leading_spaces(raw_text)
                prompt_generations.append(Generation(text=stripped_text))
            generations.append(prompt_generations)

        return LLMResult(generations=generations)

    async def _agenerate(
        self, prompts: List[str], stop: Optional[List[str]] = None
    ) -> LLMResult:
        raise NotImplementedError()

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "google_palm"
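Note that the text wrapper forwards stop sequences to the API via stop_sequences, whereas the chat wrapper truncates client-side in _truncate_at_stop_tokens. A minimal sketch (assumes GOOGLE_API_KEY is set; the prompt is hypothetical):

from langchain.llms.google_palm import GooglePalm

llm = GooglePalm(max_output_tokens=32)
# Generation stops when the model emits the stop sequence "4".
result = llm.generate(["Count upward: 1, 2, 3,"], stop=["4"])
print(result.generations[0][0].text)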
tests/integration_tests/chat_models/test_google_palm.py (new file)
@@ -0,0 +1,81 @@

"""Test Google PaLM Chat API wrapper.

Note: This test must be run with the GOOGLE_API_KEY environment variable set to a
valid API key.
"""

import pytest

from langchain.chat_models import ChatGooglePalm
from langchain.schema import (
    BaseMessage,
    ChatGeneration,
    ChatResult,
    HumanMessage,
    LLMResult,
    SystemMessage,
)


def test_chat_google_palm() -> None:
    """Test Google PaLM Chat API wrapper."""
    chat = ChatGooglePalm()
    message = HumanMessage(content="Hello")
    response = chat([message])
    assert isinstance(response, BaseMessage)
    assert isinstance(response.content, str)


def test_chat_google_palm_system_message() -> None:
    """Test Google PaLM Chat API wrapper with system message."""
    chat = ChatGooglePalm()
    system_message = SystemMessage(content="You are to chat with the user.")
    human_message = HumanMessage(content="Hello")
    response = chat([system_message, human_message])
    assert isinstance(response, BaseMessage)
    assert isinstance(response.content, str)


def test_chat_google_palm_generate() -> None:
    """Test Google PaLM Chat API wrapper with generate."""
    chat = ChatGooglePalm(n=2, temperature=1.0)
    message = HumanMessage(content="Hello")
    response = chat.generate([[message], [message]])
    assert isinstance(response, LLMResult)
    assert len(response.generations) == 2
    for generations in response.generations:
        assert len(generations) == 2
        for generation in generations:
            assert isinstance(generation, ChatGeneration)
            assert isinstance(generation.text, str)
            assert generation.text == generation.message.content


def test_chat_google_palm_multiple_completions() -> None:
    """Test Google PaLM Chat API wrapper with multiple completions."""
    # The API de-dupes duplicate responses, so set temperature higher. This
    # could be a flaky test though...
    chat = ChatGooglePalm(n=5, temperature=1.0)
    message = HumanMessage(content="Hello")
    response = chat._generate([message])
    assert isinstance(response, ChatResult)
    assert len(response.generations) == 5
    for generation in response.generations:
        assert isinstance(generation.message, BaseMessage)
        assert isinstance(generation.message.content, str)


@pytest.mark.asyncio
async def test_async_chat_google_palm() -> None:
    """Test async generation."""
    chat = ChatGooglePalm(n=2, temperature=1.0)
    message = HumanMessage(content="Hello")
    response = await chat.agenerate([[message], [message]])
    assert isinstance(response, LLMResult)
    assert len(response.generations) == 2
    for generations in response.generations:
        assert len(generations) == 2
        for generation in generations:
            assert isinstance(generation, ChatGeneration)
            assert isinstance(generation.text, str)
            assert generation.text == generation.message.content
tests/integration_tests/embeddings/test_google_palm.py (new file)
@@ -0,0 +1,34 @@

"""Test Google PaLM embeddings.

Note: This test must be run with the GOOGLE_API_KEY environment variable set to a
valid API key.
"""
from langchain.embeddings.google_palm import GooglePalmEmbeddings


def test_google_palm_embedding_documents() -> None:
    """Test Google PaLM embeddings."""
    documents = ["foo bar"]
    embedding = GooglePalmEmbeddings()
    output = embedding.embed_documents(documents)
    assert len(output) == 1
    assert len(output[0]) == 768


def test_google_palm_embedding_documents_multiple() -> None:
    """Test Google PaLM embeddings."""
    documents = ["foo bar", "bar foo", "foo"]
    embedding = GooglePalmEmbeddings()
    output = embedding.embed_documents(documents)
    assert len(output) == 3
    assert len(output[0]) == 768
    assert len(output[1]) == 768
    assert len(output[2]) == 768


def test_google_palm_embedding_query() -> None:
    """Test Google PaLM embeddings."""
    document = "foo bar"
    embedding = GooglePalmEmbeddings()
    output = embedding.embed_query(document)
    assert len(output) == 768
tests/integration_tests/llms/test_google_palm.py (new file)
@@ -0,0 +1,25 @@

"""Test Google PaLM Text API wrapper.

Note: This test must be run with the GOOGLE_API_KEY environment variable set to a
valid API key.
"""

from pathlib import Path

from langchain.llms.google_palm import GooglePalm
from langchain.llms.loading import load_llm


def test_google_palm_call() -> None:
    """Test valid call to Google PaLM text API."""
    llm = GooglePalm(max_output_tokens=10)
    output = llm("Say foo:")
    assert isinstance(output, str)


def test_saving_loading_llm(tmp_path: Path) -> None:
    """Test saving/loading a Google PaLM LLM."""
    llm = GooglePalm(max_output_tokens=10)
    llm.save(file_path=tmp_path / "google_palm.yaml")
    loaded_llm = load_llm(tmp_path / "google_palm.yaml")
    assert loaded_llm == llm
tests/unit_tests/chat_models/test_google_palm.py (new file)
@@ -0,0 +1,114 @@

"""Test Google PaLM Chat API wrapper."""

import pytest

from langchain.chat_models.google_palm import (
    ChatGooglePalm,
    ChatGooglePalmError,
    _messages_to_prompt_dict,
)
from langchain.schema import (
    AIMessage,
    HumanMessage,
    SystemMessage,
)


def test_messages_to_prompt_dict_with_valid_messages() -> None:
    pytest.importorskip("google.generativeai")
    result = _messages_to_prompt_dict(
        [
            SystemMessage(content="Prompt"),
            HumanMessage(example=True, content="Human example #1"),
            AIMessage(example=True, content="AI example #1"),
            HumanMessage(example=True, content="Human example #2"),
            AIMessage(example=True, content="AI example #2"),
            HumanMessage(content="Real human message"),
            AIMessage(content="Real AI message"),
        ]
    )
    expected = {
        "context": "Prompt",
        "examples": [
            {"author": "human", "content": "Human example #1"},
            {"author": "ai", "content": "AI example #1"},
            {"author": "human", "content": "Human example #2"},
            {"author": "ai", "content": "AI example #2"},
        ],
        "messages": [
            {"author": "human", "content": "Real human message"},
            {"author": "ai", "content": "Real AI message"},
        ],
    }

    assert result == expected


def test_messages_to_prompt_dict_raises_with_misplaced_system_message() -> None:
    pytest.importorskip("google.generativeai")
    with pytest.raises(ChatGooglePalmError) as e:
        _messages_to_prompt_dict(
            [
                HumanMessage(content="Real human message"),
                SystemMessage(content="Prompt"),
            ]
        )
    assert "System message must be first" in str(e)


def test_messages_to_prompt_dict_raises_with_misordered_examples() -> None:
    pytest.importorskip("google.generativeai")
    with pytest.raises(ChatGooglePalmError) as e:
        _messages_to_prompt_dict(
            [
                AIMessage(example=True, content="AI example #1"),
                HumanMessage(example=True, content="Human example #1"),
            ]
        )
    assert "AI example message must be immediately preceded" in str(e)


def test_messages_to_prompt_dict_raises_with_mismatched_examples() -> None:
    pytest.importorskip("google.generativeai")
    with pytest.raises(ChatGooglePalmError) as e:
        _messages_to_prompt_dict(
            [
                HumanMessage(example=True, content="Human example #1"),
                AIMessage(example=False, content="AI example #1"),
            ]
        )
    assert "Human example message must be immediately followed" in str(e)


def test_messages_to_prompt_dict_raises_with_example_after_real() -> None:
    pytest.importorskip("google.generativeai")
    with pytest.raises(ChatGooglePalmError) as e:
        _messages_to_prompt_dict(
            [
                HumanMessage(example=False, content="Real message"),
                HumanMessage(example=True, content="Human example #1"),
                AIMessage(example=True, content="AI example #1"),
            ]
        )
    assert "Message examples must come before other" in str(e)


def test_chat_google_raises_with_invalid_temperature() -> None:
    pytest.importorskip("google.generativeai")
    with pytest.raises(ValueError) as e:
        ChatGooglePalm(google_api_key="fake", temperature=2.0)
    assert "must be in the range" in str(e)


def test_chat_google_raises_with_invalid_top_p() -> None:
    pytest.importorskip("google.generativeai")
    with pytest.raises(ValueError) as e:
        ChatGooglePalm(google_api_key="fake", top_p=2.0)
    assert "must be in the range" in str(e)


def test_chat_google_raises_with_invalid_top_k() -> None:
    pytest.importorskip("google.generativeai")
    with pytest.raises(ValueError) as e:
        ChatGooglePalm(google_api_key="fake", top_k=-5)
    assert "must be positive" in str(e)