Add ChatModel, LLM, and Embeddings for Google's PaLM APIs (#3575)

- Add langchain.llms.GooglePalm for text completion.
- Add langchain.chat_models.ChatGooglePalm for chat completion.
- Add langchain.embeddings.GooglePalmEmbeddings for sentence embeddings.
- Add an example field to HumanMessage and AIMessage so that users can feed
  example messages into the PaLM Chat API.
- Add system and unit tests.

Note that async completion for the Text API is not yet supported; it will be
included in a future PR.

Happy for feedback on any aspect of this PR, especially our choice of
adding an example field to the HumanMessage and AIMessage objects to
enable passing example messages to the API.
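
For reference, a minimal sketch of the new example field in use, assuming a
valid GOOGLE_API_KEY in the environment (this mirrors the behavior exercised
by the unit tests below):

    from langchain.chat_models import ChatGooglePalm
    from langchain.schema import AIMessage, HumanMessage, SystemMessage

    chat = ChatGooglePalm(temperature=0.2)

    # Messages flagged with example=True are sent to the PaLM Chat API as
    # few-shot examples, separate from the real conversation history.
    response = chat(
        [
            SystemMessage(content="Translate English to French."),
            HumanMessage(example=True, content="Hello"),
            AIMessage(example=True, content="Bonjour"),
            HumanMessage(content="Good morning"),
        ]
    )
    print(response.content)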
James Brotchie, committed by GitHub
parent d15f481352
commit 921894960b

@@ -1,6 +1,13 @@
from langchain.chat_models.anthropic import ChatAnthropic
from langchain.chat_models.azure_openai import AzureChatOpenAI
from langchain.chat_models.google_palm import ChatGooglePalm
from langchain.chat_models.openai import ChatOpenAI
from langchain.chat_models.promptlayer_openai import PromptLayerChatOpenAI
__all__ = ["ChatOpenAI", "AzureChatOpenAI", "PromptLayerChatOpenAI", "ChatAnthropic"]
__all__ = [
    "ChatOpenAI",
    "AzureChatOpenAI",
    "PromptLayerChatOpenAI",
    "ChatAnthropic",
    "ChatGooglePalm",
]

@@ -0,0 +1,248 @@
"""Wrapper around Google's PaLM Chat API."""
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Dict, List, Optional

from pydantic import BaseModel, root_validator

from langchain.chat_models.base import BaseChatModel
from langchain.schema import (
    AIMessage,
    BaseMessage,
    ChatGeneration,
    ChatMessage,
    ChatResult,
    HumanMessage,
    SystemMessage,
)
from langchain.utils import get_from_dict_or_env

if TYPE_CHECKING:
    import google.generativeai as genai


class ChatGooglePalmError(Exception):
    pass
def _truncate_at_stop_tokens(
    text: str,
    stop: Optional[List[str]],
) -> str:
    """Truncates text at the earliest stop token found."""
    if stop is None:
        return text

    for stop_token in stop:
        stop_token_idx = text.find(stop_token)
        if stop_token_idx != -1:
            text = text[:stop_token_idx]
    return text
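
# Editor's note: an illustrative example of the helper above, not part of the
# original diff:
#     _truncate_at_stop_tokens("Hello world. Goodbye.", stop=["."]) == "Hello world"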
def _response_to_result(
    response: genai.types.ChatResponse,
    stop: Optional[List[str]],
) -> ChatResult:
    """Converts a PaLM API response into a LangChain ChatResult."""
    if not response.candidates:
        raise ChatGooglePalmError("ChatResponse must have at least one candidate.")

    generations: List[ChatGeneration] = []
    for candidate in response.candidates:
        author = candidate.get("author")
        if author is None:
            raise ChatGooglePalmError(f"ChatResponse must have an author: {candidate}")

        content = candidate.get("content")
        if content is None:
            raise ChatGooglePalmError(f"ChatResponse must have content: {candidate}")
        content = _truncate_at_stop_tokens(content, stop)

        if author == "ai":
            generations.append(
                ChatGeneration(text=content, message=AIMessage(content=content))
            )
        elif author == "human":
            generations.append(
                ChatGeneration(
                    text=content,
                    message=HumanMessage(content=content),
                )
            )
        else:
            generations.append(
                ChatGeneration(
                    text=content,
                    message=ChatMessage(role=author, content=content),
                )
            )
    return ChatResult(generations=generations)
def _messages_to_prompt_dict(
    input_messages: List[BaseMessage],
) -> genai.types.MessagePromptDict:
    """Converts a list of LangChain messages into a PaLM API MessagePrompt."""
    import google.generativeai as genai

    context: str = ""
    examples: List[genai.types.MessageDict] = []
    messages: List[genai.types.MessageDict] = []

    remaining = list(enumerate(input_messages))

    while remaining:
        index, input_message = remaining.pop(0)

        if isinstance(input_message, SystemMessage):
            if index != 0:
                raise ChatGooglePalmError("System message must be first input message.")
            context = input_message.content
        elif isinstance(input_message, HumanMessage) and input_message.example:
            if messages:
                raise ChatGooglePalmError(
                    "Message examples must come before other messages."
                )
            if not remaining:
                raise ChatGooglePalmError(
                    "Human example message must be immediately followed by an "
                    "AI example response."
                )
            _, next_input_message = remaining.pop(0)
            if isinstance(next_input_message, AIMessage) and next_input_message.example:
                examples.extend(
                    [
                        genai.types.MessageDict(
                            author="human", content=input_message.content
                        ),
                        genai.types.MessageDict(
                            author="ai", content=next_input_message.content
                        ),
                    ]
                )
            else:
                raise ChatGooglePalmError(
                    "Human example message must be immediately followed by an "
                    "AI example response."
                )
        elif isinstance(input_message, AIMessage) and input_message.example:
            raise ChatGooglePalmError(
                "AI example message must be immediately preceded by a Human "
                "example message."
            )
        elif isinstance(input_message, AIMessage):
            messages.append(
                genai.types.MessageDict(author="ai", content=input_message.content)
            )
        elif isinstance(input_message, HumanMessage):
            messages.append(
                genai.types.MessageDict(author="human", content=input_message.content)
            )
        elif isinstance(input_message, ChatMessage):
            messages.append(
                genai.types.MessageDict(
                    author=input_message.role, content=input_message.content
                )
            )
        else:
            raise ChatGooglePalmError(
                "Messages without an explicit role are not supported by the PaLM API."
            )

    return genai.types.MessagePromptDict(
        context=context,
        examples=examples,
        messages=messages,
    )
class ChatGooglePalm(BaseChatModel, BaseModel):
    """Wrapper around Google's PaLM Chat API.

    To use, you must have the google.generativeai Python package installed and
    either:

        1. The ``GOOGLE_API_KEY`` environment variable set with your API key, or
        2. Pass your API key using the google_api_key kwarg to the
           ChatGooglePalm constructor.

    Example:
        .. code-block:: python

            from langchain.chat_models import ChatGooglePalm
            chat = ChatGooglePalm()
    """

    client: Any  #: :meta private:
    model_name: str = "models/chat-bison-001"
    """Model name to use."""
    google_api_key: Optional[str] = None
    temperature: Optional[float] = None
    """Run inference with this temperature. Must be in the closed
    interval [0.0, 1.0]."""
    top_p: Optional[float] = None
    """Decode using nucleus sampling: consider the smallest set of tokens whose
    probability sum is at least top_p. Must be in the closed interval [0.0, 1.0]."""
    top_k: Optional[int] = None
    """Decode using top-k sampling: consider the set of top_k most probable tokens.
    Must be positive."""
    n: int = 1
    """Number of chat completions to generate for each prompt. Note that the API may
    not return the full n completions if duplicates are generated."""
    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate api key, python package exists, temperature, top_p, and top_k."""
        google_api_key = get_from_dict_or_env(
            values, "google_api_key", "GOOGLE_API_KEY"
        )
        try:
            import google.generativeai as genai

            genai.configure(api_key=google_api_key)
        except ImportError:
            raise ChatGooglePalmError(
                "Could not import google.generativeai python package."
            )
        values["client"] = genai

        if values["temperature"] is not None and not 0 <= values["temperature"] <= 1:
            raise ValueError("temperature must be in the range [0.0, 1.0]")

        if values["top_p"] is not None and not 0 <= values["top_p"] <= 1:
            raise ValueError("top_p must be in the range [0.0, 1.0]")

        if values["top_k"] is not None and values["top_k"] <= 0:
            raise ValueError("top_k must be positive")

        return values
    def _generate(
        self, messages: List[BaseMessage], stop: Optional[List[str]] = None
    ) -> ChatResult:
        prompt = _messages_to_prompt_dict(messages)

        response: genai.types.ChatResponse = self.client.chat(
            model=self.model_name,
            prompt=prompt,
            temperature=self.temperature,
            top_p=self.top_p,
            top_k=self.top_k,
            candidate_count=self.n,
        )

        return _response_to_result(response, stop)

    async def _agenerate(
        self, messages: List[BaseMessage], stop: Optional[List[str]] = None
    ) -> ChatResult:
        prompt = _messages_to_prompt_dict(messages)

        response: genai.types.ChatResponse = await self.client.chat_async(
            model=self.model_name,
            prompt=prompt,
            temperature=self.temperature,
            top_p=self.top_p,
            top_k=self.top_k,
            candidate_count=self.n,
        )

        return _response_to_result(response, stop)

@@ -8,6 +8,7 @@ from langchain.embeddings.aleph_alpha import (
)
from langchain.embeddings.cohere import CohereEmbeddings
from langchain.embeddings.fake import FakeEmbeddings
from langchain.embeddings.google_palm import GooglePalmEmbeddings
from langchain.embeddings.huggingface import (
    HuggingFaceEmbeddings,
    HuggingFaceInstructEmbeddings,
@@ -44,6 +45,7 @@ __all__ = [
    "AlephAlphaAsymmetricSemanticEmbedding",
    "AlephAlphaSymmetricSemanticEmbedding",
    "SentenceTransformerEmbeddings",
    "GooglePalmEmbeddings",
]

@@ -0,0 +1,38 @@
"""Wrapper around Google's PaLM Embeddings APIs."""
from typing import Any, Dict, List, Optional

from pydantic import BaseModel, root_validator

from langchain.embeddings.base import Embeddings
from langchain.utils import get_from_dict_or_env


class GooglePalmEmbeddings(BaseModel, Embeddings):
    client: Any
    google_api_key: Optional[str]
    model_name: str = "models/embedding-gecko-001"

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate api key, python package exists."""
        google_api_key = get_from_dict_or_env(
            values, "google_api_key", "GOOGLE_API_KEY"
        )
        try:
            import google.generativeai as genai

            genai.configure(api_key=google_api_key)
        except ImportError:
            raise ImportError("Could not import google.generativeai python package.")

        values["client"] = genai

        return values

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed documents, one generate_embeddings call per text."""
        return [self.embed_query(text) for text in texts]

    def embed_query(self, text: str) -> List[float]:
        """Embed query text."""
        embedding = self.client.generate_embeddings(self.model_name, text)
        return embedding["embedding"]
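
Editor's note: a minimal usage sketch of the embeddings wrapper above, assuming
a valid GOOGLE_API_KEY in the environment (the 768-dimension figure matches what
the integration tests below expect for embedding-gecko-001):

    from langchain.embeddings import GooglePalmEmbeddings

    embeddings = GooglePalmEmbeddings()  # picks up GOOGLE_API_KEY from the env

    # embed_documents issues one generate_embeddings call per input text.
    vectors = embeddings.embed_documents(["foo bar", "bar foo"])
    query_vector = embeddings.embed_query("foo")
    print(len(vectors), len(query_vector))  # 2 768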

@@ -10,6 +10,7 @@ from langchain.llms.cerebriumai import CerebriumAI
from langchain.llms.cohere import Cohere
from langchain.llms.deepinfra import DeepInfra
from langchain.llms.forefrontai import ForefrontAI
from langchain.llms.google_palm import GooglePalm
from langchain.llms.gooseai import GooseAI
from langchain.llms.gpt4all import GPT4All
from langchain.llms.huggingface_endpoint import HuggingFaceEndpoint
@@ -39,6 +40,7 @@ __all__ = [
    "Cohere",
    "DeepInfra",
    "ForefrontAI",
    "GooglePalm",
    "GooseAI",
    "GPT4All",
    "LlamaCpp",
@@ -74,6 +76,7 @@ type_to_cls_dict: Dict[str, Type[BaseLLM]] = {
    "cohere": Cohere,
    "deepinfra": DeepInfra,
    "forefrontai": ForefrontAI,
    "google_palm": GooglePalm,
    "gooseai": GooseAI,
    "gpt4all": GPT4All,
    "huggingface_hub": HuggingFaceHub,

@@ -0,0 +1,109 @@
"""Wrapper around Google's PaLM Text APIs."""
from __future__ import annotations

from typing import Any, Dict, List, Optional

from pydantic import BaseModel, root_validator

from langchain.llms import BaseLLM
from langchain.schema import Generation, LLMResult
from langchain.utils import get_from_dict_or_env


def _strip_erroneous_leading_spaces(text: str) -> str:
    """Strip erroneous leading spaces from text.

    The PaLM API will sometimes erroneously return a single leading space in all
    lines after the first. This function strips that space.
    """
    has_leading_space = all(not line or line[0] == " " for line in text.split("\n")[1:])
    if has_leading_space:
        return text.replace("\n ", "\n")
    else:
        return text
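
# Editor's note: an illustrative example of the quirk handled above, not part of
# the original diff. When every line after the first starts with a single space,
# that space is stripped; otherwise the text is returned unchanged:
#     _strip_erroneous_leading_spaces("a:\n one\n two") == "a:\none\ntwo"
#     _strip_erroneous_leading_spaces("a:\none\n two") == "a:\none\n two"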
class GooglePalm(BaseLLM, BaseModel):
    client: Any  #: :meta private:
    google_api_key: Optional[str]
    model_name: str = "models/text-bison-001"
    """Model name to use."""
    temperature: float = 0.7
    """Run inference with this temperature. Must be in the closed interval
    [0.0, 1.0]."""
    top_p: Optional[float] = None
    """Decode using nucleus sampling: consider the smallest set of tokens whose
    probability sum is at least top_p. Must be in the closed interval [0.0, 1.0]."""
    top_k: Optional[int] = None
    """Decode using top-k sampling: consider the set of top_k most probable tokens.
    Must be positive."""
    max_output_tokens: Optional[int] = None
    """Maximum number of tokens to include in a candidate. Must be greater than zero.
    If unset, will default to 64."""
    n: int = 1
    """Number of completions to generate for each prompt. Note that the API may
    not return the full n completions if duplicates are generated."""

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate api key, python package exists."""
        google_api_key = get_from_dict_or_env(
            values, "google_api_key", "GOOGLE_API_KEY"
        )
        try:
            import google.generativeai as genai

            genai.configure(api_key=google_api_key)
        except ImportError:
            raise ImportError("Could not import google.generativeai python package.")

        values["client"] = genai

        if values["temperature"] is not None and not 0 <= values["temperature"] <= 1:
            raise ValueError("temperature must be in the range [0.0, 1.0]")

        if values["top_p"] is not None and not 0 <= values["top_p"] <= 1:
            raise ValueError("top_p must be in the range [0.0, 1.0]")

        if values["top_k"] is not None and values["top_k"] <= 0:
            raise ValueError("top_k must be positive")

        if values["max_output_tokens"] is not None and values["max_output_tokens"] <= 0:
            raise ValueError("max_output_tokens must be greater than zero")

        return values
    def _generate(
        self, prompts: List[str], stop: Optional[List[str]] = None
    ) -> LLMResult:
        generations = []
        for prompt in prompts:
            completion = self.client.generate_text(
                model=self.model_name,
                prompt=prompt,
                stop_sequences=stop,
                temperature=self.temperature,
                top_p=self.top_p,
                top_k=self.top_k,
                max_output_tokens=self.max_output_tokens,
                candidate_count=self.n,
            )
            prompt_generations = []
            for candidate in completion.candidates:
                raw_text = candidate["output"]
                stripped_text = _strip_erroneous_leading_spaces(raw_text)
                prompt_generations.append(Generation(text=stripped_text))
            generations.append(prompt_generations)

        return LLMResult(generations=generations)

    async def _agenerate(
        self, prompts: List[str], stop: Optional[List[str]] = None
    ) -> LLMResult:
        raise NotImplementedError()

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "google_palm"

@@ -79,6 +79,8 @@ class BaseMessage(BaseModel):
class HumanMessage(BaseMessage):
    """Type of message that is spoken by the human."""

    example: bool = False

    @property
    def type(self) -> str:
        """Type of the message, used for serialization."""
@@ -88,6 +90,8 @@ class HumanMessage(BaseMessage):
class AIMessage(BaseMessage):
    """Type of message that is spoken by the AI."""

    example: bool = False

    @property
    def type(self) -> str:
        """Type of the message, used for serialization."""

@@ -0,0 +1,81 @@
"""Test Google PaLM Chat API wrapper.

Note: This test must be run with the GOOGLE_API_KEY environment variable set to a
valid API key.
"""
import pytest

from langchain.chat_models import ChatGooglePalm
from langchain.schema import (
    BaseMessage,
    ChatGeneration,
    ChatResult,
    HumanMessage,
    LLMResult,
    SystemMessage,
)


def test_chat_google_palm() -> None:
    """Test Google PaLM Chat API wrapper."""
    chat = ChatGooglePalm()
    message = HumanMessage(content="Hello")
    response = chat([message])
    assert isinstance(response, BaseMessage)
    assert isinstance(response.content, str)


def test_chat_google_palm_system_message() -> None:
    """Test Google PaLM Chat API wrapper with system message."""
    chat = ChatGooglePalm()
    system_message = SystemMessage(content="You are to chat with the user.")
    human_message = HumanMessage(content="Hello")
    response = chat([system_message, human_message])
    assert isinstance(response, BaseMessage)
    assert isinstance(response.content, str)


def test_chat_google_palm_generate() -> None:
    """Test Google PaLM Chat API wrapper with generate."""
    chat = ChatGooglePalm(n=2, temperature=1.0)
    message = HumanMessage(content="Hello")
    response = chat.generate([[message], [message]])
    assert isinstance(response, LLMResult)
    assert len(response.generations) == 2
    for generations in response.generations:
        assert len(generations) == 2
        for generation in generations:
            assert isinstance(generation, ChatGeneration)
            assert isinstance(generation.text, str)
            assert generation.text == generation.message.content


def test_chat_google_palm_multiple_completions() -> None:
    """Test Google PaLM Chat API wrapper with multiple completions."""
    # The API de-dupes duplicate responses, so set temperature higher. This
    # could be a flaky test, though...
    chat = ChatGooglePalm(n=5, temperature=1.0)
    message = HumanMessage(content="Hello")
    response = chat._generate([message])
    assert isinstance(response, ChatResult)
    assert len(response.generations) == 5
    for generation in response.generations:
        assert isinstance(generation.message, BaseMessage)
        assert isinstance(generation.message.content, str)


@pytest.mark.asyncio
async def test_async_chat_google_palm() -> None:
    """Test async generation."""
    chat = ChatGooglePalm(n=2, temperature=1.0)
    message = HumanMessage(content="Hello")
    response = await chat.agenerate([[message], [message]])
    assert isinstance(response, LLMResult)
    assert len(response.generations) == 2
    for generations in response.generations:
        assert len(generations) == 2
        for generation in generations:
            assert isinstance(generation, ChatGeneration)
            assert isinstance(generation.text, str)
            assert generation.text == generation.message.content

@@ -0,0 +1,34 @@
"""Test Google PaLM embeddings.

Note: This test must be run with the GOOGLE_API_KEY environment variable set to a
valid API key.
"""
from langchain.embeddings.google_palm import GooglePalmEmbeddings


def test_google_palm_embedding_documents() -> None:
    """Test Google PaLM embeddings."""
    documents = ["foo bar"]
    embedding = GooglePalmEmbeddings()
    output = embedding.embed_documents(documents)
    assert len(output) == 1
    assert len(output[0]) == 768


def test_google_palm_embedding_documents_multiple() -> None:
    """Test Google PaLM embeddings."""
    documents = ["foo bar", "bar foo", "foo"]
    embedding = GooglePalmEmbeddings()
    output = embedding.embed_documents(documents)
    assert len(output) == 3
    assert len(output[0]) == 768
    assert len(output[1]) == 768
    assert len(output[2]) == 768


def test_google_palm_embedding_query() -> None:
    """Test Google PaLM embeddings."""
    document = "foo bar"
    embedding = GooglePalmEmbeddings()
    output = embedding.embed_query(document)
    assert len(output) == 768

@@ -0,0 +1,25 @@
"""Test Google PaLM Text API wrapper.

Note: This test must be run with the GOOGLE_API_KEY environment variable set to a
valid API key.
"""
from pathlib import Path

from langchain.llms.google_palm import GooglePalm
from langchain.llms.loading import load_llm


def test_google_palm_call() -> None:
    """Test valid call to Google PaLM text API."""
    llm = GooglePalm(max_output_tokens=10)
    output = llm("Say foo:")
    assert isinstance(output, str)


def test_saving_loading_llm(tmp_path: Path) -> None:
    """Test saving/loading a Google PaLM LLM."""
    llm = GooglePalm(max_output_tokens=10)
    llm.save(file_path=tmp_path / "google_palm.yaml")
    loaded_llm = load_llm(tmp_path / "google_palm.yaml")
    assert loaded_llm == llm

@@ -0,0 +1,114 @@
"""Test Google PaLM Chat API wrapper."""
import pytest

from langchain.chat_models.google_palm import (
    ChatGooglePalm,
    ChatGooglePalmError,
    _messages_to_prompt_dict,
)
from langchain.schema import (
    AIMessage,
    HumanMessage,
    SystemMessage,
)


def test_messages_to_prompt_dict_with_valid_messages() -> None:
    pytest.importorskip("google.generativeai")
    result = _messages_to_prompt_dict(
        [
            SystemMessage(content="Prompt"),
            HumanMessage(example=True, content="Human example #1"),
            AIMessage(example=True, content="AI example #1"),
            HumanMessage(example=True, content="Human example #2"),
            AIMessage(example=True, content="AI example #2"),
            HumanMessage(content="Real human message"),
            AIMessage(content="Real AI message"),
        ]
    )
    expected = {
        "context": "Prompt",
        "examples": [
            {"author": "human", "content": "Human example #1"},
            {"author": "ai", "content": "AI example #1"},
            {"author": "human", "content": "Human example #2"},
            {"author": "ai", "content": "AI example #2"},
        ],
        "messages": [
            {"author": "human", "content": "Real human message"},
            {"author": "ai", "content": "Real AI message"},
        ],
    }
    assert result == expected


def test_messages_to_prompt_dict_raises_with_misplaced_system_message() -> None:
    pytest.importorskip("google.generativeai")
    with pytest.raises(ChatGooglePalmError) as e:
        _messages_to_prompt_dict(
            [
                HumanMessage(content="Real human message"),
                SystemMessage(content="Prompt"),
            ]
        )
    assert "System message must be first" in str(e)


def test_messages_to_prompt_dict_raises_with_misordered_examples() -> None:
    pytest.importorskip("google.generativeai")
    with pytest.raises(ChatGooglePalmError) as e:
        _messages_to_prompt_dict(
            [
                AIMessage(example=True, content="AI example #1"),
                HumanMessage(example=True, content="Human example #1"),
            ]
        )
    assert "AI example message must be immediately preceded" in str(e)


def test_messages_to_prompt_dict_raises_with_mismatched_examples() -> None:
    pytest.importorskip("google.generativeai")
    with pytest.raises(ChatGooglePalmError) as e:
        _messages_to_prompt_dict(
            [
                HumanMessage(example=True, content="Human example #1"),
                AIMessage(example=False, content="AI example #1"),
            ]
        )
    assert "Human example message must be immediately followed" in str(e)


def test_messages_to_prompt_dict_raises_with_example_after_real() -> None:
    pytest.importorskip("google.generativeai")
    with pytest.raises(ChatGooglePalmError) as e:
        _messages_to_prompt_dict(
            [
                HumanMessage(example=False, content="Real message"),
                HumanMessage(example=True, content="Human example #1"),
                AIMessage(example=True, content="AI example #1"),
            ]
        )
    assert "Message examples must come before other" in str(e)


def test_chat_google_raises_with_invalid_temperature() -> None:
    pytest.importorskip("google.generativeai")
    with pytest.raises(ValueError) as e:
        ChatGooglePalm(google_api_key="fake", temperature=2.0)
    assert "must be in the range" in str(e)


def test_chat_google_raises_with_invalid_top_p() -> None:
    pytest.importorskip("google.generativeai")
    with pytest.raises(ValueError) as e:
        ChatGooglePalm(google_api_key="fake", top_p=2.0)
    assert "must be in the range" in str(e)


def test_chat_google_raises_with_invalid_top_k() -> None:
    pytest.importorskip("google.generativeai")
    with pytest.raises(ValueError) as e:
        ChatGooglePalm(google_api_key="fake", top_k=-5)
    assert "must be positive" in str(e)