move search to not be a chain (#226)

pull/227/head
Harrison Chase 2 years ago committed by GitHub
parent b19a73be26
commit ca2394028f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -48,7 +48,7 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"from langchain.agents import ZeroShotAgent, Tool\n", "from langchain.agents import ZeroShotAgent, Tool\n",
"from langchain import OpenAI, SerpAPIChain, LLMChain" "from langchain import OpenAI, SerpAPIWrapper, LLMChain"
] ]
}, },
{ {
@ -58,7 +58,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"search = SerpAPIChain()\n", "search = SerpAPIWrapper()\n",
"tools = [\n", "tools = [\n",
" Tool(\n", " Tool(\n",
" name = \"Search\",\n", " name = \"Search\",\n",

@ -26,7 +26,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"from langchain import LLMMathChain, OpenAI, SerpAPIChain, SQLDatabase, SQLDatabaseChain\n", "from langchain import LLMMathChain, OpenAI, SerpAPIWrapper, SQLDatabase, SQLDatabaseChain\n",
"from langchain.agents import initialize_agent, Tool" "from langchain.agents import initialize_agent, Tool"
] ]
}, },
@ -38,7 +38,7 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"llm = OpenAI(temperature=0)\n", "llm = OpenAI(temperature=0)\n",
"search = SerpAPIChain()\n", "search = SerpAPIWrapper()\n",
"llm_math_chain = LLMMathChain(llm=llm, verbose=True)\n", "llm_math_chain = LLMMathChain(llm=llm, verbose=True)\n",
"db = SQLDatabase.from_uri(\"sqlite:///../../../notebooks/Chinook.db\")\n", "db = SQLDatabase.from_uri(\"sqlite:///../../../notebooks/Chinook.db\")\n",
"db_chain = SQLDatabaseChain(llm=llm, database=db, verbose=True)\n", "db_chain = SQLDatabaseChain(llm=llm, database=db, verbose=True)\n",

@ -45,11 +45,11 @@
} }
], ],
"source": [ "source": [
"from langchain import OpenAI, SerpAPIChain\n", "from langchain import OpenAI, SerpAPIWrapper\n",
"from langchain.agents import initialize_agent, Tool\n", "from langchain.agents import initialize_agent, Tool\n",
"\n", "\n",
"llm = OpenAI(temperature=0)\n", "llm = OpenAI(temperature=0)\n",
"search = SerpAPIChain()\n", "search = SerpAPIWrapper()\n",
"tools = [\n", "tools = [\n",
" Tool(\n", " Tool(\n",
" name=\"Intermediate Answer\",\n", " name=\"Intermediate Answer\",\n",

@ -29,7 +29,7 @@
"source": [ "source": [
"from langchain.agents import ZeroShotAgent, Tool\n", "from langchain.agents import ZeroShotAgent, Tool\n",
"from langchain.chains.conversation.memory import ConversationBufferMemory\n", "from langchain.chains.conversation.memory import ConversationBufferMemory\n",
"from langchain import OpenAI, SerpAPIChain, LLMChain" "from langchain import OpenAI, SerpAPIWrapper, LLMChain"
] ]
}, },
{ {
@ -39,7 +39,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"search = SerpAPIChain()\n", "search = SerpAPIWrapper()\n",
"tools = [\n", "tools = [\n",
" Tool(\n", " Tool(\n",
" name = \"Search\",\n", " name = \"Search\",\n",

@ -135,14 +135,14 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"from langchain import SelfAskWithSearchChain, SerpAPIChain\n", "from langchain import SelfAskWithSearchChain, SerpAPIWrapper\n",
"\n", "\n",
"open_ai_llm = OpenAI(temperature=0)\n", "open_ai_llm = OpenAI(temperature=0)\n",
"search = SerpAPIChain()\n", "search = SerpAPIWrapper()\n",
"self_ask_with_search_openai = SelfAskWithSearchChain(llm=open_ai_llm, search_chain=search, verbose=True)\n", "self_ask_with_search_openai = SelfAskWithSearchChain(llm=open_ai_llm, search_chain=search, verbose=True)\n",
"\n", "\n",
"cohere_llm = Cohere(temperature=0, model=\"command-xlarge-20221108\")\n", "cohere_llm = Cohere(temperature=0, model=\"command-xlarge-20221108\")\n",
"search = SerpAPIChain()\n", "search = SerpAPIWrapper()\n",
"self_ask_with_search_cohere = SelfAskWithSearchChain(llm=cohere_llm, search_chain=search, verbose=True)" "self_ask_with_search_cohere = SelfAskWithSearchChain(llm=cohere_llm, search_chain=search, verbose=True)"
] ]
}, },

@ -77,9 +77,9 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"# Load the tool configs that are needed.\n", "# Load the tool configs that are needed.\n",
"from langchain import LLMMathChain, SerpAPIChain\n", "from langchain import LLMMathChain, SerpAPIWrapper\n",
"llm = OpenAI(temperature=0)\n", "llm = OpenAI(temperature=0)\n",
"search = SerpAPIChain()\n", "search = SerpAPIWrapper()\n",
"llm_math_chain = LLMMathChain(llm=llm, verbose=True)\n", "llm_math_chain = LLMMathChain(llm=llm, verbose=True)\n",
"tools = [\n", "tools = [\n",
" Tool(\n", " Tool(\n",

@ -12,7 +12,6 @@ from langchain.chains import (
LLMMathChain, LLMMathChain,
PALChain, PALChain,
PythonChain, PythonChain,
SerpAPIChain,
SQLDatabaseChain, SQLDatabaseChain,
VectorDBQA, VectorDBQA,
) )
@ -24,6 +23,7 @@ from langchain.prompts import (
Prompt, Prompt,
PromptTemplate, PromptTemplate,
) )
from langchain.serpapi import SerpAPIWrapper
from langchain.sql_database import SQLDatabase from langchain.sql_database import SQLDatabase
from langchain.vectorstores import FAISS, ElasticVectorSearch from langchain.vectorstores import FAISS, ElasticVectorSearch
@ -32,7 +32,8 @@ __all__ = [
"LLMMathChain", "LLMMathChain",
"PythonChain", "PythonChain",
"SelfAskWithSearchChain", "SelfAskWithSearchChain",
"SerpAPIChain", "SerpAPIWrapper",
"SerpAPIWrapper",
"Cohere", "Cohere",
"OpenAI", "OpenAI",
"BasePromptTemplate", "BasePromptTemplate",

@ -131,10 +131,10 @@ class MRKLChain(ZeroShotAgent):
Example: Example:
.. code-block:: python .. code-block:: python
from langchain import LLMMathChain, OpenAI, SerpAPIChain, MRKLChain from langchain import LLMMathChain, OpenAI, SerpAPIWrapper, MRKLChain
from langchain.chains.mrkl.base import ChainConfig from langchain.chains.mrkl.base import ChainConfig
llm = OpenAI(temperature=0) llm = OpenAI(temperature=0)
search = SerpAPIChain() search = SerpAPIWrapper()
llm_math_chain = LLMMathChain(llm=llm) llm_math_chain = LLMMathChain(llm=llm)
chains = [ chains = [
ChainConfig( ChainConfig(

@ -5,9 +5,9 @@ from langchain.agents.agent import Agent
from langchain.agents.self_ask_with_search.prompt import PROMPT from langchain.agents.self_ask_with_search.prompt import PROMPT
from langchain.agents.tools import Tool from langchain.agents.tools import Tool
from langchain.chains.llm import LLMChain from langchain.chains.llm import LLMChain
from langchain.chains.serpapi import SerpAPIChain
from langchain.llms.base import LLM from langchain.llms.base import LLM
from langchain.prompts.base import BasePromptTemplate from langchain.prompts.base import BasePromptTemplate
from langchain.serpapi import SerpAPIWrapper
class SelfAskWithSearchAgent(Agent): class SelfAskWithSearchAgent(Agent):
@ -73,12 +73,12 @@ class SelfAskWithSearchChain(SelfAskWithSearchAgent):
Example: Example:
.. code-block:: python .. code-block:: python
from langchain import SelfAskWithSearchChain, OpenAI, SerpAPIChain from langchain import SelfAskWithSearchChain, OpenAI, SerpAPIWrapper
search_chain = SerpAPIChain() search_chain = SerpAPIWrapper()
self_ask = SelfAskWithSearchChain(llm=OpenAI(), search_chain=search_chain) self_ask = SelfAskWithSearchChain(llm=OpenAI(), search_chain=search_chain)
""" """
def __init__(self, llm: LLM, search_chain: SerpAPIChain, **kwargs: Any): def __init__(self, llm: LLM, search_chain: SerpAPIWrapper, **kwargs: Any):
"""Initialize with just an LLM and a search chain.""" """Initialize with just an LLM and a search chain."""
search_tool = Tool(name="Intermediate Answer", func=search_chain.run) search_tool = Tool(name="Intermediate Answer", func=search_chain.run)
llm_chain = LLMChain(llm=llm, prompt=PROMPT) llm_chain = LLMChain(llm=llm, prompt=PROMPT)

@ -5,7 +5,6 @@ from langchain.chains.llm_math.base import LLMMathChain
from langchain.chains.pal.base import PALChain from langchain.chains.pal.base import PALChain
from langchain.chains.python import PythonChain from langchain.chains.python import PythonChain
from langchain.chains.sequential import SequentialChain, SimpleSequentialChain from langchain.chains.sequential import SequentialChain, SimpleSequentialChain
from langchain.chains.serpapi import SerpAPIChain
from langchain.chains.sql_database.base import SQLDatabaseChain from langchain.chains.sql_database.base import SQLDatabaseChain
from langchain.chains.vector_db_qa.base import VectorDBQA from langchain.chains.vector_db_qa.base import VectorDBQA
@ -13,7 +12,6 @@ __all__ = [
"LLMChain", "LLMChain",
"LLMMathChain", "LLMMathChain",
"PythonChain", "PythonChain",
"SerpAPIChain",
"SQLDatabaseChain", "SQLDatabaseChain",
"VectorDBQA", "VectorDBQA",
"SequentialChain", "SequentialChain",

@ -4,11 +4,10 @@ Heavily borrowed from https://github.com/ofirpress/self-ask
""" """
import os import os
import sys import sys
from typing import Any, Dict, List, Optional from typing import Any, Dict, Optional
from pydantic import BaseModel, Extra, root_validator from pydantic import BaseModel, Extra, root_validator
from langchain.chains.base import Chain
from langchain.utils import get_from_dict_or_env from langchain.utils import get_from_dict_or_env
@ -26,8 +25,8 @@ class HiddenPrints:
sys.stdout = self._original_stdout sys.stdout = self._original_stdout
class SerpAPIChain(Chain, BaseModel): class SerpAPIWrapper(BaseModel):
"""Chain that calls SerpAPI. """Wrapper around SerpAPI.
To use, you should have the ``google-search-results`` python package installed, To use, you should have the ``google-search-results`` python package installed,
and the environment variable ``SERPAPI_API_KEY`` set with your API key, or pass and the environment variable ``SERPAPI_API_KEY`` set with your API key, or pass
@ -36,13 +35,11 @@ class SerpAPIChain(Chain, BaseModel):
Example: Example:
.. code-block:: python .. code-block:: python
from langchain import SerpAPIChain from langchain import SerpAPIWrapper
serpapi = SerpAPIChain() serpapi = SerpAPIWrapper()
""" """
search_engine: Any #: :meta private: search_engine: Any #: :meta private:
input_key: str = "search_query" #: :meta private:
output_key: str = "search_result" #: :meta private:
serpapi_api_key: Optional[str] = None serpapi_api_key: Optional[str] = None
@ -51,22 +48,6 @@ class SerpAPIChain(Chain, BaseModel):
extra = Extra.forbid extra = Extra.forbid
@property
def input_keys(self) -> List[str]:
"""Return the singular input key.
:meta private:
"""
return [self.input_key]
@property
def output_keys(self) -> List[str]:
"""Return the singular output key.
:meta private:
"""
return [self.output_key]
@root_validator() @root_validator()
def validate_environment(cls, values: Dict) -> Dict: def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment.""" """Validate that api key and python package exists in environment."""
@ -85,11 +66,12 @@ class SerpAPIChain(Chain, BaseModel):
) )
return values return values
def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]: def run(self, query: str) -> str:
"""Run query through SerpAPI and parse result."""
params = { params = {
"api_key": self.serpapi_api_key, "api_key": self.serpapi_api_key,
"engine": "google", "engine": "google",
"q": inputs[self.input_key], "q": query,
"google_domain": "google.com", "google_domain": "google.com",
"gl": "us", "gl": "us",
"hl": "en", "hl": "en",
@ -112,4 +94,9 @@ class SerpAPIChain(Chain, BaseModel):
toret = res["organic_results"][0]["snippet"] toret = res["organic_results"][0]["snippet"]
else: else:
toret = "No good search result found" toret = "No good search result found"
return {self.output_key: toret} return toret
# For backwards compatability
SerpAPIWrapper = SerpAPIWrapper

@ -1,7 +1,7 @@
"""Integration test for self ask with search.""" """Integration test for self ask with search."""
from langchain.agents.self_ask_with_search.base import SelfAskWithSearchChain from langchain.agents.self_ask_with_search.base import SelfAskWithSearchChain
from langchain.chains.serpapi import SerpAPIChain
from langchain.llms.openai import OpenAI from langchain.llms.openai import OpenAI
from langchain.serpapi import SerpAPIWrapper
def test_self_ask_with_search() -> None: def test_self_ask_with_search() -> None:
@ -9,7 +9,7 @@ def test_self_ask_with_search() -> None:
question = "What is the hometown of the reigning men's U.S. Open champion?" question = "What is the hometown of the reigning men's U.S. Open champion?"
chain = SelfAskWithSearchChain( chain = SelfAskWithSearchChain(
llm=OpenAI(temperature=0), llm=OpenAI(temperature=0),
search_chain=SerpAPIChain(), search_chain=SerpAPIWrapper(),
input_key="q", input_key="q",
output_key="a", output_key="a",
) )

@ -1,9 +1,9 @@
"""Integration test for SerpAPI.""" """Integration test for SerpAPI."""
from langchain.chains.serpapi import SerpAPIChain from langchain.serpapi import SerpAPIWrapper
def test_call() -> None: def test_call() -> None:
"""Test that call gives the correct answer.""" """Test that call gives the correct answer."""
chain = SerpAPIChain() chain = SerpAPIWrapper()
output = chain.run("What was Obama's first name?") output = chain.run("What was Obama's first name?")
assert output == "Barack Hussein Obama II" assert output == "Barack Hussein Obama II"
Loading…
Cancel
Save