diff --git a/langchain/chat_models/__init__.py b/langchain/chat_models/__init__.py
index fdfe5e7d..11322b04 100644
--- a/langchain/chat_models/__init__.py
+++ b/langchain/chat_models/__init__.py
@@ -1,6 +1,13 @@
 from langchain.chat_models.anthropic import ChatAnthropic
 from langchain.chat_models.azure_openai import AzureChatOpenAI
+from langchain.chat_models.google_palm import ChatGooglePalm
 from langchain.chat_models.openai import ChatOpenAI
 from langchain.chat_models.promptlayer_openai import PromptLayerChatOpenAI
 
-__all__ = ["ChatOpenAI", "AzureChatOpenAI", "PromptLayerChatOpenAI", "ChatAnthropic"]
+__all__ = [
+    "ChatOpenAI",
+    "AzureChatOpenAI",
+    "PromptLayerChatOpenAI",
+    "ChatAnthropic",
+    "ChatGooglePalm",
+]
diff --git a/langchain/chat_models/google_palm.py b/langchain/chat_models/google_palm.py
new file mode 100644
index 00000000..ee0acd3c
--- /dev/null
+++ b/langchain/chat_models/google_palm.py
@@ -0,0 +1,285 @@
+"""Wrapper around Google's PaLM Chat API."""
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
+
+from pydantic import BaseModel, root_validator
+
+from langchain.chat_models.base import BaseChatModel
+from langchain.schema import (
+    AIMessage,
+    BaseMessage,
+    ChatGeneration,
+    ChatMessage,
+    ChatResult,
+    HumanMessage,
+    SystemMessage,
+)
+from langchain.utils import get_from_dict_or_env
+
+if TYPE_CHECKING:
+    import google.generativeai as genai
+
+
+class ChatGooglePalmError(Exception):
+    """Error raised by the Google PaLM chat wrapper."""
+
+
+def _truncate_at_stop_tokens(
+    text: str,
+    stop: Optional[List[str]],
+) -> str:
+    """Truncate text at the earliest occurrence of any stop token.
+
+    Args:
+        text: The generated text to truncate.
+        stop: Stop tokens to search for, or None to disable truncation.
+
+    Returns:
+        The text up to, but excluding, the first stop token found. The text is
+        returned unchanged when no stop token occurs in it.
+    """
+    if stop is None:
+        return text
+
+    # Search the untruncated text for every token. Truncating inside the loop
+    # makes the result depend on the order of overlapping tokens: cutting
+    # "abcd" at "cd" first would hide the earlier match of "bc".
+    stop_idxs = [idx for idx in map(text.find, stop) if idx != -1]
+    if not stop_idxs:
+        return text
+    return text[: min(stop_idxs)]
+
+
+def _response_to_result(
+    response: genai.types.ChatResponse,
+    stop: Optional[List[str]],
+) -> ChatResult:
+    """Convert a PaLM API response into a LangChain ChatResult.
+
+    Args:
+        response: The chat response returned by the PaLM API.
+        stop: Stop tokens at which each candidate's content is truncated.
+
+    Returns:
+        A ChatResult holding one ChatGeneration per response candidate.
+
+    Raises:
+        ChatGooglePalmError: If the response has no candidates, or a candidate
+            has no author or has None as its content.
+    """
+    if not response.candidates:
+        raise ChatGooglePalmError("ChatResponse must have at least one candidate.")
+
+    generations: List[ChatGeneration] = []
+    for candidate in response.candidates:
+        author = candidate.get("author")
+        if author is None:
+            raise ChatGooglePalmError(f"ChatResponse must have an author: {candidate}")
+
+        # Reject a None content before truncating: truncation calls str methods
+        # on it, so checking afterwards would fail with an AttributeError.
+        content = candidate.get("content", "")
+        if content is None:
+            raise ChatGooglePalmError(f"ChatResponse must have a content: {candidate}")
+        content = _truncate_at_stop_tokens(content, stop)
+
+        # Map the PaLM author onto the matching LangChain message type.
+        message: BaseMessage
+        if author == "ai":
+            message = AIMessage(content=content)
+        elif author == "human":
+            message = HumanMessage(content=content)
+        else:
+            message = ChatMessage(role=author, content=content)
+        generations.append(ChatGeneration(text=content, message=message))
+
+    return ChatResult(generations=generations)
+
+
+def _messages_to_prompt_dict(
+    input_messages: List[BaseMessage],
+) -> genai.types.MessagePromptDict:
+    """Convert a list of LangChain messages into a PaLM API MessagePrompt structure.
+
+    A leading SystemMessage becomes the prompt context, Human/AI message pairs
+    flagged with ``example=True`` become examples, and every other message is
+    appended to the conversation messages.
+
+    Args:
+        input_messages: The messages to convert, in conversation order.
+
+    Returns:
+        A MessagePromptDict with ``context``, ``examples`` and ``messages`` set.
+
+    Raises:
+        ChatGooglePalmError: If a system message is not first, if examples are
+            unpaired or follow regular messages, or if a message type has no
+            PaLM equivalent.
+    """
+    import google.generativeai as genai
+
+    context: str = ""
+    examples: List[genai.types.MessageDict] = []
+    messages: List[genai.types.MessageDict] = []
+
+    # A shared iterator lets the example branch consume the paired AI message
+    # without the quadratic cost of repeatedly popping from the front of a list.
+    remaining = iter(enumerate(input_messages))
+
+    for index, input_message in remaining:
+        if isinstance(input_message, SystemMessage):
+            if index != 0:
+                raise ChatGooglePalmError("System message must be first input message.")
+            context = input_message.content
+        elif isinstance(input_message, HumanMessage) and input_message.example:
+            if messages:
+                raise ChatGooglePalmError(
+                    "Message examples must come before other messages."
+                )
+            # A trailing Human example has no partner; default to None so it is
+            # reported as a ChatGooglePalmError rather than an IndexError.
+            _, next_input_message = next(remaining, (None, None))
+            if isinstance(next_input_message, AIMessage) and next_input_message.example:
+                examples.extend(
+                    [
+                        genai.types.MessageDict(
+                            author="human", content=input_message.content
+                        ),
+                        genai.types.MessageDict(
+                            author="ai", content=next_input_message.content
+                        ),
+                    ]
+                )
+            else:
+                raise ChatGooglePalmError(
+                    "Human example message must be immediately followed by an "
+                    "AI example response."
+                )
+        elif isinstance(input_message, AIMessage) and input_message.example:
+            raise ChatGooglePalmError(
+                "AI example message must be immediately preceded by a Human "
+                "example message."
+            )
+        elif isinstance(input_message, AIMessage):
+            messages.append(
+                genai.types.MessageDict(author="ai", content=input_message.content)
+            )
+        elif isinstance(input_message, HumanMessage):
+            messages.append(
+                genai.types.MessageDict(author="human", content=input_message.content)
+            )
+        elif isinstance(input_message, ChatMessage):
+            messages.append(
+                genai.types.MessageDict(
+                    author=input_message.role, content=input_message.content
+                )
+            )
+        else:
+            raise ChatGooglePalmError(
+                "Messages without an explicit role not supported by PaLM API."
+            )
+
+    return genai.types.MessagePromptDict(
+        context=context,
+        examples=examples,
+        messages=messages,
+    )
+
+
+class ChatGooglePalm(BaseChatModel, BaseModel):
+    """Wrapper around Google's PaLM Chat API.
+
+    To use you must have the google.generativeai Python package installed and
+    either:
+
+        1. The ``GOOGLE_API_KEY`` environment variable set with your API key, or
+        2. Pass your API key using the google_api_key kwarg to the ChatGooglePalm
+           constructor.
+
+    Example:
+        .. code-block:: python
+
+            from langchain.chat_models import ChatGooglePalm
+            chat = ChatGooglePalm()
+
+    """
+
+    client: Any  #: :meta private:
+    model_name: str = "models/chat-bison-001"
+    """Model name to use."""
+    google_api_key: Optional[str] = None
+    temperature: Optional[float] = None
+    """Run inference with this temperature. Must be in the closed
+       interval [0.0, 1.0]."""
+    top_p: Optional[float] = None
+    """Decode using nucleus sampling: consider the smallest set of tokens whose
+       probability sum is at least top_p. Must be in the closed interval [0.0, 1.0]."""
+    top_k: Optional[int] = None
+    """Decode using top-k sampling: consider the set of top_k most probable tokens.
+       Must be positive."""
+    n: int = 1
+    """Number of chat completions to generate for each prompt. Note that the API may
+       not return the full n completions if duplicates are generated."""
+
+    @root_validator()
+    def validate_environment(cls, values: Dict) -> Dict:
+        """Validate api key, python package exists, temperature, top_p, and top_k."""
+        google_api_key = get_from_dict_or_env(
+            values, "google_api_key", "GOOGLE_API_KEY"
+        )
+        try:
+            import google.generativeai as genai
+
+            genai.configure(api_key=google_api_key)
+        except ImportError:
+            raise ChatGooglePalmError(
+                "Could not import google.generativeai python package."
+            )
+
+        values["client"] = genai
+
+        if values["temperature"] is not None and not 0 <= values["temperature"] <= 1:
+            raise ValueError("temperature must be in the range [0.0, 1.0]")
+
+        if values["top_p"] is not None and not 0 <= values["top_p"] <= 1:
+            raise ValueError("top_p must be in the range [0.0, 1.0]")
+
+        if values["top_k"] is not None and values["top_k"] <= 0:
+            raise ValueError("top_k must be positive")
+
+        return values
+
+    def _generate(
+        self, messages: List[BaseMessage], stop: Optional[List[str]] = None
+    ) -> ChatResult:
+        """Generate chat completions for the messages with the PaLM chat API."""
+        prompt = _messages_to_prompt_dict(messages)
+
+        response: genai.types.ChatResponse = self.client.chat(
+            model=self.model_name,
+            prompt=prompt,
+            temperature=self.temperature,
+            top_p=self.top_p,
+            top_k=self.top_k,
+            candidate_count=self.n,
+        )
+
+        return _response_to_result(response, stop)
+
+    async def _agenerate(
+        self, messages: List[BaseMessage], stop: Optional[List[str]] = None
+    ) -> ChatResult:
+        """Asynchronously generate chat completions with the PaLM chat API."""
+        prompt = _messages_to_prompt_dict(messages)
+
+        response: genai.types.ChatResponse = await self.client.chat_async(
+            model=self.model_name,
+            prompt=prompt,
+            temperature=self.temperature,
+            top_p=self.top_p,
+            top_k=self.top_k,
+            candidate_count=self.n,
+        )
+
+        return _response_to_result(response, stop)
diff --git a/langchain/embeddings/__init__.py b/langchain/embeddings/__init__.py
index edcd11ff..1e123f12 100644
--- a/langchain/embeddings/__init__.py
+++ b/langchain/embeddings/__init__.py
@@ -8,6 +8,7 @@ from langchain.embeddings.aleph_alpha import (
 )
 from langchain.embeddings.cohere import CohereEmbeddings
 from langchain.embeddings.fake import FakeEmbeddings
+from langchain.embeddings.google_palm import GooglePalmEmbeddings
 from langchain.embeddings.huggingface import (
     HuggingFaceEmbeddings,
     HuggingFaceInstructEmbeddings,
@@ -44,6 +45,7 @@ __all__ = [
     "AlephAlphaAsymmetricSemanticEmbedding",
     "AlephAlphaSymmetricSemanticEmbedding",
     "SentenceTransformerEmbeddings",
+    "GooglePalmEmbeddings",
 ]
 
 
diff --git a/langchain/embeddings/google_palm.py b/langchain/embeddings/google_palm.py
new file mode 100644
index 00000000..0d198137
--- /dev/null
+++ b/langchain/embeddings/google_palm.py
@@ -0,0 +1,41 @@
+"""Wrapper around Google's PaLM Embeddings APIs."""
+from typing import Any, Dict, List, Optional
+
+from pydantic import BaseModel, root_validator
+
+from langchain.embeddings.base import Embeddings
+from langchain.utils import get_from_dict_or_env
+
+
+class GooglePalmEmbeddings(BaseModel, Embeddings):
+    """Wrapper around Google's PaLM Embeddings API."""
+
+    client: Any
+    google_api_key: Optional[str]
+    model_name: str = "models/embedding-gecko-001"
+
+    @root_validator()
+    def validate_environment(cls, values: Dict) -> Dict:
+        """Validate api key, python package exists."""
+        google_api_key = get_from_dict_or_env(
+            values, "google_api_key", "GOOGLE_API_KEY"
+        )
+        try:
+            import google.generativeai as genai
+
+            genai.configure(api_key=google_api_key)
+        except ImportError:
+            raise ImportError("Could not import google.generativeai python package.")
+
+        values["client"] = genai
+
+        return values
+
+    def embed_documents(self, texts: List[str]) -> List[List[float]]:
+        """Embed each text with its own call to ``embed_query``."""
+        return [self.embed_query(text) for text in texts]
+
+    def embed_query(self, text: str) -> List[float]:
+        """Embed query text."""
+        embedding = self.client.generate_embeddings(self.model_name, text)
+        return embedding["embedding"]
diff --git a/langchain/llms/__init__.py b/langchain/llms/__init__.py
index ecb7a56e..d928af05 100644
--- a/langchain/llms/__init__.py
+++ b/langchain/llms/__init__.py
@@ -10,6 +10,7 @@ from langchain.llms.cerebriumai import CerebriumAI
 from langchain.llms.cohere import Cohere
 from langchain.llms.deepinfra import DeepInfra
 from langchain.llms.forefrontai import ForefrontAI
+from langchain.llms.google_palm import GooglePalm
 from langchain.llms.gooseai import GooseAI
 from langchain.llms.gpt4all import GPT4All
 from langchain.llms.huggingface_endpoint import HuggingFaceEndpoint
@@ -39,6 +40,7 @@ __all__ = [
     "Cohere",
     "DeepInfra",
     "ForefrontAI",
+    "GooglePalm",
     "GooseAI",
     "GPT4All",
     "LlamaCpp",
@@ -74,6 +76,7 @@ type_to_cls_dict: Dict[str, Type[BaseLLM]] = {
     "cohere": Cohere,
     "deepinfra": DeepInfra,
     "forefrontai": ForefrontAI,
+    "google_palm": GooglePalm,
     "gooseai": GooseAI,
     "gpt4all": GPT4All,
     "huggingface_hub": HuggingFaceHub,
diff --git a/langchain/llms/google_palm.py b/langchain/llms/google_palm.py
new file mode 100644
index 00000000..2b71535b
--- /dev/null
+++ b/langchain/llms/google_palm.py
@@ -0,0 +1,113 @@
+"""Wrapper around Google's PaLM Text APIs."""
+from __future__ import annotations
+
+from typing import Any, Dict, List, Optional
+
+from pydantic import BaseModel, root_validator
+
+from langchain.llms import BaseLLM
+from langchain.schema import Generation, LLMResult
+from langchain.utils import get_from_dict_or_env
+
+
+def _strip_erroneous_leading_spaces(text: str) -> str:
+    """Strip erroneous leading spaces from text.
+
+    The PaLM API will sometimes erroneously return a single leading space in all
+    lines > 1. This function strips that space.
+    """
+    has_leading_space = all(not line or line[0] == " " for line in text.split("\n")[1:])
+    if has_leading_space:
+        return text.replace("\n ", "\n")
+    else:
+        return text
+
+
+class GooglePalm(BaseLLM, BaseModel):
+    """Wrapper around Google's PaLM Text API."""
+
+    client: Any  #: :meta private:
+    google_api_key: Optional[str]
+    model_name: str = "models/text-bison-001"
+    """Model name to use."""
+    temperature: float = 0.7
+    """Run inference with this temperature. Must be in the closed interval
+       [0.0, 1.0]."""
+    top_p: Optional[float] = None
+    """Decode using nucleus sampling: consider the smallest set of tokens whose
+       probability sum is at least top_p. Must be in the closed interval [0.0, 1.0]."""
+    top_k: Optional[int] = None
+    """Decode using top-k sampling: consider the set of top_k most probable tokens.
+       Must be positive."""
+    max_output_tokens: Optional[int] = None
+    """Maximum number of tokens to include in a candidate. Must be greater than zero.
+       If unset, will default to 64."""
+    n: int = 1
+    """Number of text completions to generate for each prompt. Note that the API may
+       not return the full n completions if duplicates are generated."""
+
+    @root_validator()
+    def validate_environment(cls, values: Dict) -> Dict:
+        """Validate api key, python package exists."""
+        google_api_key = get_from_dict_or_env(
+            values, "google_api_key", "GOOGLE_API_KEY"
+        )
+        try:
+            import google.generativeai as genai
+
+            genai.configure(api_key=google_api_key)
+        except ImportError:
+            raise ImportError("Could not import google.generativeai python package.")
+
+        values["client"] = genai
+
+        if values["temperature"] is not None and not 0 <= values["temperature"] <= 1:
+            raise ValueError("temperature must be in the range [0.0, 1.0]")
+
+        if values["top_p"] is not None and not 0 <= values["top_p"] <= 1:
+            raise ValueError("top_p must be in the range [0.0, 1.0]")
+
+        if values["top_k"] is not None and values["top_k"] <= 0:
+            raise ValueError("top_k must be positive")
+
+        if values["max_output_tokens"] is not None and values["max_output_tokens"] <= 0:
+            raise ValueError("max_output_tokens must be greater than zero")
+
+        return values
+
+    def _generate(
+        self, prompts: List[str], stop: Optional[List[str]] = None
+    ) -> LLMResult:
+        """Generate text completions for each prompt with the PaLM text API."""
+        generations = []
+        for prompt in prompts:
+            completion = self.client.generate_text(
+                model=self.model_name,
+                prompt=prompt,
+                stop_sequences=stop,
+                temperature=self.temperature,
+                top_p=self.top_p,
+                top_k=self.top_k,
+                max_output_tokens=self.max_output_tokens,
+                candidate_count=self.n,
+            )
+
+            prompt_generations = []
+            for candidate in completion.candidates:
+                raw_text = candidate["output"]
+                stripped_text = _strip_erroneous_leading_spaces(raw_text)
+                prompt_generations.append(Generation(text=stripped_text))
+            generations.append(prompt_generations)
+
+        return LLMResult(generations=generations)
+
+    async def _agenerate(
+        self, prompts: List[str], stop: Optional[List[str]] = None
+    ) -> LLMResult:
+        """Async generation is not implemented for this wrapper."""
+        raise NotImplementedError()
+
+    @property
+    def _llm_type(self) -> str:
+        """Return type of llm."""
+        return "google_palm"
diff --git a/langchain/schema.py b/langchain/schema.py
index ced8b383..ac248c9d 100644
--- a/langchain/schema.py
+++ b/langchain/schema.py
@@ -79,6 +79,8 @@ class BaseMessage(BaseModel):
 class HumanMessage(BaseMessage):
     """Type of message that is spoken by the human."""
 
+    example: bool = False
+
     @property
     def type(self) -> str:
         """Type of the message, used for serialization."""
@@ -88,6 +90,8 @@ class HumanMessage(BaseMessage):
 class AIMessage(BaseMessage):
     """Type of message that is spoken by the AI."""
 
+    example: bool = False
+
     @property
     def type(self) -> str:
         """Type of the message, used for serialization."""
diff --git a/tests/integration_tests/chat_models/test_google_palm.py b/tests/integration_tests/chat_models/test_google_palm.py
new file mode 100644
index 00000000..a95419e6
--- /dev/null
+++ b/tests/integration_tests/chat_models/test_google_palm.py
@@ -0,0 +1,81 @@
+"""Test Google PaLM Chat API wrapper.
+
+Note: This test must be run with the GOOGLE_API_KEY environment variable set to a
+    valid API key.
+"""
+
+import pytest
+
+from langchain.chat_models import ChatGooglePalm
+from langchain.schema import (
+    BaseMessage,
+    ChatGeneration,
+    ChatResult,
+    HumanMessage,
+    LLMResult,
+    SystemMessage,
+)
+
+
+def test_chat_google_palm() -> None:
+    """Test Google PaLM Chat API wrapper."""
+    chat = ChatGooglePalm()
+    message = HumanMessage(content="Hello")
+    response = chat([message])
+    assert isinstance(response, BaseMessage)
+    assert isinstance(response.content, str)
+
+
+def test_chat_google_palm_system_message() -> None:
+    """Test Google PaLM Chat API wrapper with system message."""
+    chat = ChatGooglePalm()
+    system_message = SystemMessage(content="You are to chat with the user.")
+    human_message = HumanMessage(content="Hello")
+    response = chat([system_message, human_message])
+    assert isinstance(response, BaseMessage)
+    assert isinstance(response.content, str)
+
+
+def test_chat_google_palm_generate() -> None:
+    """Test Google PaLM Chat API wrapper with generate."""
+    chat = ChatGooglePalm(n=2, temperature=1.0)
+    message = HumanMessage(content="Hello")
+    response = chat.generate([[message], [message]])
+    assert isinstance(response, LLMResult)
+    assert len(response.generations) == 2
+    for generations in response.generations:
+        assert 1 <= len(generations) <= 2
+        for generation in generations:
+            assert isinstance(generation, ChatGeneration)
+            assert isinstance(generation.text, str)
+            assert generation.text == generation.message.content
+
+
+def test_chat_google_palm_multiple_completions() -> None:
+    """Test Google PaLM Chat API wrapper with multiple completions."""
+    # The API de-dupes duplicate responses, so set temperature higher and only
+    # require between 1 and n candidates instead of exactly n.
+    chat = ChatGooglePalm(n=5, temperature=1.0)
+    message = HumanMessage(content="Hello")
+    response = chat._generate([message])
+    assert isinstance(response, ChatResult)
+    assert 1 <= len(response.generations) <= 5
+    for generation in response.generations:
+        assert isinstance(generation.message, BaseMessage)
+        assert isinstance(generation.message.content, str)
+
+
+@pytest.mark.asyncio
+async def test_async_chat_google_palm() -> None:
+    """Test async generation."""
+    chat = ChatGooglePalm(n=2, temperature=1.0)
+    message = HumanMessage(content="Hello")
+    response = await chat.agenerate([[message], [message]])
+    assert isinstance(response, LLMResult)
+    assert len(response.generations) == 2
+    for generations in response.generations:
+        assert 1 <= len(generations) <= 2
+        for generation in generations:
+            assert isinstance(generation, ChatGeneration)
+            assert isinstance(generation.text, str)
+            assert generation.text == generation.message.content
diff --git a/tests/integration_tests/embeddings/test_google_palm.py b/tests/integration_tests/embeddings/test_google_palm.py
new file mode 100644
index 00000000..251c266c
--- /dev/null
+++ b/tests/integration_tests/embeddings/test_google_palm.py
@@ -0,0 +1,34 @@
+"""Test Google PaLM embeddings.
+
+Note: This test must be run with the GOOGLE_API_KEY environment variable set to a
+    valid API key.
+"""
+from langchain.embeddings.google_palm import GooglePalmEmbeddings
+
+
+def test_google_palm_embedding_documents() -> None:
+    """Embedding one document returns a single 768-dimensional vector."""
+    texts = ["foo bar"]
+    embedder = GooglePalmEmbeddings()
+    vectors = embedder.embed_documents(texts)
+    assert len(vectors) == 1
+    assert len(vectors[0]) == 768
+
+
+def test_google_palm_embedding_documents_multiple() -> None:
+    """Embedding several documents returns one 768-dimensional vector each."""
+    texts = ["foo bar", "bar foo", "foo"]
+    embedder = GooglePalmEmbeddings()
+    vectors = embedder.embed_documents(texts)
+    assert len(vectors) == 3
+    # Check every vector rather than spelling out each index.
+    for vector in vectors:
+        assert len(vector) == 768
+
+
+def test_google_palm_embedding_query() -> None:
+    """Embedding a query string returns a 768-dimensional vector."""
+    text = "foo bar"
+    embedder = GooglePalmEmbeddings()
+    vector = embedder.embed_query(text)
+    assert len(vector) == 768
diff --git a/tests/integration_tests/llms/test_google_palm.py b/tests/integration_tests/llms/test_google_palm.py
new file mode 100644
index 00000000..ca02b185
--- /dev/null
+++ b/tests/integration_tests/llms/test_google_palm.py
@@ -0,0 +1,25 @@
+"""Test Google PaLM Text API wrapper.
+
+Note: This test must be run with the GOOGLE_API_KEY environment variable set to a
+    valid API key.
+"""
+
+from pathlib import Path
+
+from langchain.llms.google_palm import GooglePalm
+from langchain.llms.loading import load_llm
+
+
+def test_google_palm_call() -> None:
+    """Test valid call to Google PaLM text API."""
+    llm = GooglePalm(max_output_tokens=10)
+    output = llm("Say foo:")
+    assert isinstance(output, str)
+
+
+def test_saving_loading_llm(tmp_path: Path) -> None:
+    """Test saving/loading a Google PaLM LLM."""
+    llm = GooglePalm(max_output_tokens=10)
+    llm.save(file_path=tmp_path / "google_palm.yaml")
+    loaded_llm = load_llm(tmp_path / "google_palm.yaml")
+    assert loaded_llm == llm
diff --git a/tests/unit_tests/chat_models/__init__.py b/tests/unit_tests/chat_models/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/unit_tests/chat_models/test_google_palm.py b/tests/unit_tests/chat_models/test_google_palm.py
new file mode 100644
index 00000000..0ca7fb4e
--- /dev/null
+++ b/tests/unit_tests/chat_models/test_google_palm.py
@@ -0,0 +1,122 @@
+"""Test Google PaLM Chat API wrapper."""
+
+import pytest
+
+from langchain.chat_models.google_palm import (
+    ChatGooglePalm,
+    ChatGooglePalmError,
+    _messages_to_prompt_dict,
+)
+from langchain.schema import (
+    AIMessage,
+    HumanMessage,
+    SystemMessage,
+)
+
+
+def test_messages_to_prompt_dict_with_valid_messages() -> None:
+    """A system message, example pairs and real messages are all converted."""
+    pytest.importorskip("google.generativeai")
+    result = _messages_to_prompt_dict(
+        [
+            SystemMessage(content="Prompt"),
+            HumanMessage(example=True, content="Human example #1"),
+            AIMessage(example=True, content="AI example #1"),
+            HumanMessage(example=True, content="Human example #2"),
+            AIMessage(example=True, content="AI example #2"),
+            HumanMessage(content="Real human message"),
+            AIMessage(content="Real AI message"),
+        ]
+    )
+    expected = {
+        "context": "Prompt",
+        "examples": [
+            {"author": "human", "content": "Human example #1"},
+            {"author": "ai", "content": "AI example #1"},
+            {"author": "human", "content": "Human example #2"},
+            {"author": "ai", "content": "AI example #2"},
+        ],
+        "messages": [
+            {"author": "human", "content": "Real human message"},
+            {"author": "ai", "content": "Real AI message"},
+        ],
+    }
+
+    assert result == expected
+
+
+def test_messages_to_prompt_dict_raises_with_misplaced_system_message() -> None:
+    """A system message that is not the first message is rejected."""
+    pytest.importorskip("google.generativeai")
+    with pytest.raises(ChatGooglePalmError) as e:
+        _messages_to_prompt_dict(
+            [
+                HumanMessage(content="Real human message"),
+                SystemMessage(content="Prompt"),
+            ]
+        )
+    assert "System message must be first" in str(e.value)
+
+
+def test_messages_to_prompt_dict_raises_with_misordered_examples() -> None:
+    """An AI example that precedes its Human example is rejected."""
+    pytest.importorskip("google.generativeai")
+    with pytest.raises(ChatGooglePalmError) as e:
+        _messages_to_prompt_dict(
+            [
+                AIMessage(example=True, content="AI example #1"),
+                HumanMessage(example=True, content="Human example #1"),
+            ]
+        )
+    assert "AI example message must be immediately preceded" in str(e.value)
+
+
+def test_messages_to_prompt_dict_raises_with_mismatched_examples() -> None:
+    """A Human example followed by a non-example AI message is rejected."""
+    pytest.importorskip("google.generativeai")
+    with pytest.raises(ChatGooglePalmError) as e:
+        _messages_to_prompt_dict(
+            [
+                HumanMessage(example=True, content="Human example #1"),
+                AIMessage(example=False, content="AI example #1"),
+            ]
+        )
+    assert "Human example message must be immediately followed" in str(e.value)
+
+
+def test_messages_to_prompt_dict_raises_with_example_after_real() -> None:
+    """Examples that come after a regular message are rejected."""
+    pytest.importorskip("google.generativeai")
+    with pytest.raises(ChatGooglePalmError) as e:
+        _messages_to_prompt_dict(
+            [
+                HumanMessage(example=False, content="Real message"),
+                HumanMessage(example=True, content="Human example #1"),
+                AIMessage(example=True, content="AI example #1"),
+            ]
+        )
+    assert "Message examples must come before other" in str(e.value)
+
+
+def test_chat_google_raises_with_invalid_temperature() -> None:
+    """A temperature outside [0.0, 1.0] fails validation."""
+    pytest.importorskip("google.generativeai")
+    with pytest.raises(ValueError) as e:
+        ChatGooglePalm(google_api_key="fake", temperature=2.0)
+    assert "must be in the range" in str(e.value)
+
+
+def test_chat_google_raises_with_invalid_top_p() -> None:
+    """A top_p outside [0.0, 1.0] fails validation."""
+    pytest.importorskip("google.generativeai")
+    with pytest.raises(ValueError) as e:
+        ChatGooglePalm(google_api_key="fake", top_p=2.0)
+    assert "must be in the range" in str(e.value)
+
+
+def test_chat_google_raises_with_invalid_top_k() -> None:
+    """A non-positive top_k fails validation."""
+    pytest.importorskip("google.generativeai")
+    with pytest.raises(ValueError) as e:
+        ChatGooglePalm(google_api_key="fake", top_k=-5)
+    assert "must be positive" in str(e.value)