diff --git a/docs/examples/agents/custom_agent.ipynb b/docs/examples/agents/custom_agent.ipynb index b97921e439..cc6f135bec 100644 --- a/docs/examples/agents/custom_agent.ipynb +++ b/docs/examples/agents/custom_agent.ipynb @@ -48,7 +48,7 @@ "outputs": [], "source": [ "from langchain.agents import ZeroShotAgent, Tool\n", - "from langchain import OpenAI, SerpAPIChain, LLMChain" + "from langchain import OpenAI, SerpAPIWrapper, LLMChain" ] }, { @@ -58,7 +58,7 @@ "metadata": {}, "outputs": [], "source": [ - "search = SerpAPIChain()\n", + "search = SerpAPIWrapper()\n", "tools = [\n", " Tool(\n", " name = \"Search\",\n", diff --git a/docs/examples/agents/mrkl.ipynb b/docs/examples/agents/mrkl.ipynb index c0dcb817b3..71bc463ddc 100644 --- a/docs/examples/agents/mrkl.ipynb +++ b/docs/examples/agents/mrkl.ipynb @@ -26,7 +26,7 @@ "metadata": {}, "outputs": [], "source": [ - "from langchain import LLMMathChain, OpenAI, SerpAPIChain, SQLDatabase, SQLDatabaseChain\n", + "from langchain import LLMMathChain, OpenAI, SerpAPIWrapper, SQLDatabase, SQLDatabaseChain\n", "from langchain.agents import initialize_agent, Tool" ] }, @@ -38,7 +38,7 @@ "outputs": [], "source": [ "llm = OpenAI(temperature=0)\n", - "search = SerpAPIChain()\n", + "search = SerpAPIWrapper()\n", "llm_math_chain = LLMMathChain(llm=llm, verbose=True)\n", "db = SQLDatabase.from_uri(\"sqlite:///../../../notebooks/Chinook.db\")\n", "db_chain = SQLDatabaseChain(llm=llm, database=db, verbose=True)\n", diff --git a/docs/examples/agents/self_ask_with_search.ipynb b/docs/examples/agents/self_ask_with_search.ipynb index 7be3b59fdc..d4a56ab2a8 100644 --- a/docs/examples/agents/self_ask_with_search.ipynb +++ b/docs/examples/agents/self_ask_with_search.ipynb @@ -45,11 +45,11 @@ } ], "source": [ - "from langchain import OpenAI, SerpAPIChain\n", + "from langchain import OpenAI, SerpAPIWrapper\n", "from langchain.agents import initialize_agent, Tool\n", "\n", "llm = OpenAI(temperature=0)\n", - "search = SerpAPIChain()\n", + "search = SerpAPIWrapper()\n", "tools = [\n", " Tool(\n", " name=\"Intermediate Answer\",\n", diff --git a/docs/examples/memory/agent_with_memory.ipynb b/docs/examples/memory/agent_with_memory.ipynb index fa93a7e0dc..7527907c15 100644 --- a/docs/examples/memory/agent_with_memory.ipynb +++ b/docs/examples/memory/agent_with_memory.ipynb @@ -29,7 +29,7 @@ "source": [ "from langchain.agents import ZeroShotAgent, Tool\n", "from langchain.chains.conversation.memory import ConversationBufferMemory\n", - "from langchain import OpenAI, SerpAPIChain, LLMChain" + "from langchain import OpenAI, SerpAPIWrapper, LLMChain" ] }, { @@ -39,7 +39,7 @@ "metadata": {}, "outputs": [], "source": [ - "search = SerpAPIChain()\n", + "search = SerpAPIWrapper()\n", "tools = [\n", " Tool(\n", " name = \"Search\",\n", diff --git a/docs/examples/model_laboratory.ipynb b/docs/examples/model_laboratory.ipynb index 13892a73d6..5649019788 100644 --- a/docs/examples/model_laboratory.ipynb +++ b/docs/examples/model_laboratory.ipynb @@ -135,14 +135,14 @@ "metadata": {}, "outputs": [], "source": [ - "from langchain import SelfAskWithSearchChain, SerpAPIChain\n", + "from langchain import SelfAskWithSearchChain, SerpAPIWrapper\n", "\n", "open_ai_llm = OpenAI(temperature=0)\n", - "search = SerpAPIChain()\n", + "search = SerpAPIWrapper()\n", "self_ask_with_search_openai = SelfAskWithSearchChain(llm=open_ai_llm, search_chain=search, verbose=True)\n", "\n", "cohere_llm = Cohere(temperature=0, model=\"command-xlarge-20221108\")\n", - "search = SerpAPIChain()\n", + "search = SerpAPIWrapper()\n", "self_ask_with_search_cohere = SelfAskWithSearchChain(llm=cohere_llm, search_chain=search, verbose=True)" ] }, diff --git a/docs/getting_started/agents.ipynb b/docs/getting_started/agents.ipynb index 51d0fe78bb..7f7b22aea4 100644 --- a/docs/getting_started/agents.ipynb +++ b/docs/getting_started/agents.ipynb @@ -77,9 +77,9 @@ "outputs": [], "source": [ "# Load the tool configs that are needed.\n", - "from langchain import LLMMathChain, SerpAPIChain\n", + "from langchain import LLMMathChain, SerpAPIWrapper\n", "llm = OpenAI(temperature=0)\n", - "search = SerpAPIChain()\n", + "search = SerpAPIWrapper()\n", "llm_math_chain = LLMMathChain(llm=llm, verbose=True)\n", "tools = [\n", " Tool(\n", diff --git a/langchain/__init__.py b/langchain/__init__.py index dde7aa2259..35b6446ca6 100644 --- a/langchain/__init__.py +++ b/langchain/__init__.py @@ -12,7 +12,6 @@ from langchain.chains import ( LLMMathChain, PALChain, PythonChain, - SerpAPIChain, SQLDatabaseChain, VectorDBQA, ) @@ -24,6 +23,7 @@ from langchain.prompts import ( Prompt, PromptTemplate, ) +from langchain.serpapi import SerpAPIWrapper from langchain.sql_database import SQLDatabase from langchain.vectorstores import FAISS, ElasticVectorSearch @@ -32,7 +32,8 @@ __all__ = [ "LLMMathChain", "PythonChain", "SelfAskWithSearchChain", - "SerpAPIChain", + "SerpAPIWrapper", + "SerpAPIWrapper", "Cohere", "OpenAI", "BasePromptTemplate", diff --git a/langchain/agents/mrkl/base.py b/langchain/agents/mrkl/base.py index 5474b2a907..28eac3db1d 100644 --- a/langchain/agents/mrkl/base.py +++ b/langchain/agents/mrkl/base.py @@ -131,10 +131,10 @@ class MRKLChain(ZeroShotAgent): Example: .. code-block:: python - from langchain import LLMMathChain, OpenAI, SerpAPIChain, MRKLChain + from langchain import LLMMathChain, OpenAI, SerpAPIWrapper, MRKLChain from langchain.chains.mrkl.base import ChainConfig llm = OpenAI(temperature=0) - search = SerpAPIChain() + search = SerpAPIWrapper() llm_math_chain = LLMMathChain(llm=llm) chains = [ ChainConfig( diff --git a/langchain/agents/self_ask_with_search/base.py b/langchain/agents/self_ask_with_search/base.py index ab0b6c62cf..1273308db5 100644 --- a/langchain/agents/self_ask_with_search/base.py +++ b/langchain/agents/self_ask_with_search/base.py @@ -5,9 +5,9 @@ from langchain.agents.agent import Agent from langchain.agents.self_ask_with_search.prompt import PROMPT from langchain.agents.tools import Tool from langchain.chains.llm import LLMChain -from langchain.chains.serpapi import SerpAPIChain from langchain.llms.base import LLM from langchain.prompts.base import BasePromptTemplate +from langchain.serpapi import SerpAPIWrapper class SelfAskWithSearchAgent(Agent): @@ -73,12 +73,12 @@ class SelfAskWithSearchChain(SelfAskWithSearchAgent): Example: .. code-block:: python - from langchain import SelfAskWithSearchChain, OpenAI, SerpAPIChain - search_chain = SerpAPIChain() + from langchain import SelfAskWithSearchChain, OpenAI, SerpAPIWrapper + search_chain = SerpAPIWrapper() self_ask = SelfAskWithSearchChain(llm=OpenAI(), search_chain=search_chain) """ - def __init__(self, llm: LLM, search_chain: SerpAPIChain, **kwargs: Any): + def __init__(self, llm: LLM, search_chain: SerpAPIWrapper, **kwargs: Any): """Initialize with just an LLM and a search chain.""" search_tool = Tool(name="Intermediate Answer", func=search_chain.run) llm_chain = LLMChain(llm=llm, prompt=PROMPT) diff --git a/langchain/chains/__init__.py b/langchain/chains/__init__.py index eceb11d40a..62018b9393 100644 --- a/langchain/chains/__init__.py +++ b/langchain/chains/__init__.py @@ -5,7 +5,6 @@ from langchain.chains.llm_math.base import LLMMathChain from langchain.chains.pal.base import PALChain from langchain.chains.python import PythonChain from langchain.chains.sequential import SequentialChain, SimpleSequentialChain -from langchain.chains.serpapi import SerpAPIChain from langchain.chains.sql_database.base import SQLDatabaseChain from langchain.chains.vector_db_qa.base import VectorDBQA @@ -13,7 +12,6 @@ __all__ = [ "LLMChain", "LLMMathChain", "PythonChain", - "SerpAPIChain", "SQLDatabaseChain", "VectorDBQA", "SequentialChain", diff --git a/langchain/chains/serpapi.py b/langchain/serpapi.py similarity index 77% rename from langchain/chains/serpapi.py rename to langchain/serpapi.py index 30ac632a3c..99affb4246 100644 --- a/langchain/chains/serpapi.py +++ b/langchain/serpapi.py @@ -4,11 +4,10 @@ Heavily borrowed from https://github.com/ofirpress/self-ask """ import os import sys -from typing import Any, Dict, List, Optional +from typing import Any, Dict, Optional from pydantic import BaseModel, Extra, root_validator -from langchain.chains.base import Chain from langchain.utils import get_from_dict_or_env @@ -26,8 +25,8 @@ class HiddenPrints: sys.stdout = self._original_stdout -class SerpAPIChain(Chain, BaseModel): - """Chain that calls SerpAPI. +class SerpAPIWrapper(BaseModel): + """Wrapper around SerpAPI. To use, you should have the ``google-search-results`` python package installed, and the environment variable ``SERPAPI_API_KEY`` set with your API key, or pass @@ -36,13 +35,11 @@ class SerpAPIChain(Chain, BaseModel): Example: .. code-block:: python - from langchain import SerpAPIChain - serpapi = SerpAPIChain() + from langchain import SerpAPIWrapper + serpapi = SerpAPIWrapper() """ search_engine: Any #: :meta private: - input_key: str = "search_query" #: :meta private: - output_key: str = "search_result" #: :meta private: serpapi_api_key: Optional[str] = None @@ -51,22 +48,6 @@ class SerpAPIChain(Chain, BaseModel): extra = Extra.forbid - @property - def input_keys(self) -> List[str]: - """Return the singular input key. - - :meta private: - """ - return [self.input_key] - - @property - def output_keys(self) -> List[str]: - """Return the singular output key. - - :meta private: - """ - return [self.output_key] - @root_validator() def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" @@ -85,11 +66,12 @@ class SerpAPIChain(Chain, BaseModel): ) return values - def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]: + def run(self, query: str) -> str: + """Run query through SerpAPI and parse result.""" params = { "api_key": self.serpapi_api_key, "engine": "google", - "q": inputs[self.input_key], + "q": query, "google_domain": "google.com", "gl": "us", "hl": "en", @@ -112,4 +94,9 @@ class SerpAPIChain(Chain, BaseModel): toret = res["organic_results"][0]["snippet"] else: toret = "No good search result found" - return {self.output_key: toret} + return toret + + +# For backwards compatability + +SerpAPIWrapper = SerpAPIWrapper diff --git a/tests/integration_tests/chains/test_self_ask_with_search.py b/tests/integration_tests/chains/test_self_ask_with_search.py index 8873be8bf8..e4536f7586 100644 --- a/tests/integration_tests/chains/test_self_ask_with_search.py +++ b/tests/integration_tests/chains/test_self_ask_with_search.py @@ -1,7 +1,7 @@ """Integration test for self ask with search.""" from langchain.agents.self_ask_with_search.base import SelfAskWithSearchChain -from langchain.chains.serpapi import SerpAPIChain from langchain.llms.openai import OpenAI +from langchain.serpapi import SerpAPIWrapper def test_self_ask_with_search() -> None: @@ -9,7 +9,7 @@ def test_self_ask_with_search() -> None: question = "What is the hometown of the reigning men's U.S. Open champion?" chain = SelfAskWithSearchChain( llm=OpenAI(temperature=0), - search_chain=SerpAPIChain(), + search_chain=SerpAPIWrapper(), input_key="q", output_key="a", ) diff --git a/tests/integration_tests/chains/test_serpapi.py b/tests/integration_tests/test_serpapi.py similarity index 73% rename from tests/integration_tests/chains/test_serpapi.py rename to tests/integration_tests/test_serpapi.py index 60cda1aa62..cd2b63437a 100644 --- a/tests/integration_tests/chains/test_serpapi.py +++ b/tests/integration_tests/test_serpapi.py @@ -1,9 +1,9 @@ """Integration test for SerpAPI.""" -from langchain.chains.serpapi import SerpAPIChain +from langchain.serpapi import SerpAPIWrapper def test_call() -> None: """Test that call gives the correct answer.""" - chain = SerpAPIChain() + chain = SerpAPIWrapper() output = chain.run("What was Obama's first name?") assert output == "Barack Hussein Obama II"