move search to not be a chain (#226)

pull/227/head
Harrison Chase 2 years ago committed by GitHub
parent b19a73be26
commit ca2394028f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -48,7 +48,7 @@
"outputs": [],
"source": [
"from langchain.agents import ZeroShotAgent, Tool\n",
"from langchain import OpenAI, SerpAPIChain, LLMChain"
"from langchain import OpenAI, SerpAPIWrapper, LLMChain"
]
},
{
@ -58,7 +58,7 @@
"metadata": {},
"outputs": [],
"source": [
"search = SerpAPIChain()\n",
"search = SerpAPIWrapper()\n",
"tools = [\n",
" Tool(\n",
" name = \"Search\",\n",

@ -26,7 +26,7 @@
"metadata": {},
"outputs": [],
"source": [
"from langchain import LLMMathChain, OpenAI, SerpAPIChain, SQLDatabase, SQLDatabaseChain\n",
"from langchain import LLMMathChain, OpenAI, SerpAPIWrapper, SQLDatabase, SQLDatabaseChain\n",
"from langchain.agents import initialize_agent, Tool"
]
},
@ -38,7 +38,7 @@
"outputs": [],
"source": [
"llm = OpenAI(temperature=0)\n",
"search = SerpAPIChain()\n",
"search = SerpAPIWrapper()\n",
"llm_math_chain = LLMMathChain(llm=llm, verbose=True)\n",
"db = SQLDatabase.from_uri(\"sqlite:///../../../notebooks/Chinook.db\")\n",
"db_chain = SQLDatabaseChain(llm=llm, database=db, verbose=True)\n",

@ -45,11 +45,11 @@
}
],
"source": [
"from langchain import OpenAI, SerpAPIChain\n",
"from langchain import OpenAI, SerpAPIWrapper\n",
"from langchain.agents import initialize_agent, Tool\n",
"\n",
"llm = OpenAI(temperature=0)\n",
"search = SerpAPIChain()\n",
"search = SerpAPIWrapper()\n",
"tools = [\n",
" Tool(\n",
" name=\"Intermediate Answer\",\n",

@ -29,7 +29,7 @@
"source": [
"from langchain.agents import ZeroShotAgent, Tool\n",
"from langchain.chains.conversation.memory import ConversationBufferMemory\n",
"from langchain import OpenAI, SerpAPIChain, LLMChain"
"from langchain import OpenAI, SerpAPIWrapper, LLMChain"
]
},
{
@ -39,7 +39,7 @@
"metadata": {},
"outputs": [],
"source": [
"search = SerpAPIChain()\n",
"search = SerpAPIWrapper()\n",
"tools = [\n",
" Tool(\n",
" name = \"Search\",\n",

@ -135,14 +135,14 @@
"metadata": {},
"outputs": [],
"source": [
"from langchain import SelfAskWithSearchChain, SerpAPIChain\n",
"from langchain import SelfAskWithSearchChain, SerpAPIWrapper\n",
"\n",
"open_ai_llm = OpenAI(temperature=0)\n",
"search = SerpAPIChain()\n",
"search = SerpAPIWrapper()\n",
"self_ask_with_search_openai = SelfAskWithSearchChain(llm=open_ai_llm, search_chain=search, verbose=True)\n",
"\n",
"cohere_llm = Cohere(temperature=0, model=\"command-xlarge-20221108\")\n",
"search = SerpAPIChain()\n",
"search = SerpAPIWrapper()\n",
"self_ask_with_search_cohere = SelfAskWithSearchChain(llm=cohere_llm, search_chain=search, verbose=True)"
]
},

@ -77,9 +77,9 @@
"outputs": [],
"source": [
"# Load the tool configs that are needed.\n",
"from langchain import LLMMathChain, SerpAPIChain\n",
"from langchain import LLMMathChain, SerpAPIWrapper\n",
"llm = OpenAI(temperature=0)\n",
"search = SerpAPIChain()\n",
"search = SerpAPIWrapper()\n",
"llm_math_chain = LLMMathChain(llm=llm, verbose=True)\n",
"tools = [\n",
" Tool(\n",

@ -12,7 +12,6 @@ from langchain.chains import (
LLMMathChain,
PALChain,
PythonChain,
SerpAPIChain,
SQLDatabaseChain,
VectorDBQA,
)
@ -24,6 +23,7 @@ from langchain.prompts import (
Prompt,
PromptTemplate,
)
from langchain.serpapi import SerpAPIWrapper
from langchain.sql_database import SQLDatabase
from langchain.vectorstores import FAISS, ElasticVectorSearch
@ -32,7 +32,8 @@ __all__ = [
"LLMMathChain",
"PythonChain",
"SelfAskWithSearchChain",
"SerpAPIChain",
"SerpAPIWrapper",
"SerpAPIWrapper",
"Cohere",
"OpenAI",
"BasePromptTemplate",

@ -131,10 +131,10 @@ class MRKLChain(ZeroShotAgent):
Example:
.. code-block:: python
from langchain import LLMMathChain, OpenAI, SerpAPIChain, MRKLChain
from langchain import LLMMathChain, OpenAI, SerpAPIWrapper, MRKLChain
from langchain.chains.mrkl.base import ChainConfig
llm = OpenAI(temperature=0)
search = SerpAPIChain()
search = SerpAPIWrapper()
llm_math_chain = LLMMathChain(llm=llm)
chains = [
ChainConfig(

@ -5,9 +5,9 @@ from langchain.agents.agent import Agent
from langchain.agents.self_ask_with_search.prompt import PROMPT
from langchain.agents.tools import Tool
from langchain.chains.llm import LLMChain
from langchain.chains.serpapi import SerpAPIChain
from langchain.llms.base import LLM
from langchain.prompts.base import BasePromptTemplate
from langchain.serpapi import SerpAPIWrapper
class SelfAskWithSearchAgent(Agent):
@ -73,12 +73,12 @@ class SelfAskWithSearchChain(SelfAskWithSearchAgent):
Example:
.. code-block:: python
from langchain import SelfAskWithSearchChain, OpenAI, SerpAPIChain
search_chain = SerpAPIChain()
from langchain import SelfAskWithSearchChain, OpenAI, SerpAPIWrapper
search_chain = SerpAPIWrapper()
self_ask = SelfAskWithSearchChain(llm=OpenAI(), search_chain=search_chain)
"""
def __init__(self, llm: LLM, search_chain: SerpAPIChain, **kwargs: Any):
def __init__(self, llm: LLM, search_chain: SerpAPIWrapper, **kwargs: Any):
"""Initialize with just an LLM and a search chain."""
search_tool = Tool(name="Intermediate Answer", func=search_chain.run)
llm_chain = LLMChain(llm=llm, prompt=PROMPT)

@ -5,7 +5,6 @@ from langchain.chains.llm_math.base import LLMMathChain
from langchain.chains.pal.base import PALChain
from langchain.chains.python import PythonChain
from langchain.chains.sequential import SequentialChain, SimpleSequentialChain
from langchain.chains.serpapi import SerpAPIChain
from langchain.chains.sql_database.base import SQLDatabaseChain
from langchain.chains.vector_db_qa.base import VectorDBQA
@ -13,7 +12,6 @@ __all__ = [
"LLMChain",
"LLMMathChain",
"PythonChain",
"SerpAPIChain",
"SQLDatabaseChain",
"VectorDBQA",
"SequentialChain",

@ -4,11 +4,10 @@ Heavily borrowed from https://github.com/ofirpress/self-ask
"""
import os
import sys
from typing import Any, Dict, List, Optional
from typing import Any, Dict, Optional
from pydantic import BaseModel, Extra, root_validator
from langchain.chains.base import Chain
from langchain.utils import get_from_dict_or_env
@ -26,8 +25,8 @@ class HiddenPrints:
sys.stdout = self._original_stdout
class SerpAPIChain(Chain, BaseModel):
"""Chain that calls SerpAPI.
class SerpAPIWrapper(BaseModel):
"""Wrapper around SerpAPI.
To use, you should have the ``google-search-results`` python package installed,
and the environment variable ``SERPAPI_API_KEY`` set with your API key, or pass
@ -36,13 +35,11 @@ class SerpAPIChain(Chain, BaseModel):
Example:
.. code-block:: python
from langchain import SerpAPIChain
serpapi = SerpAPIChain()
from langchain import SerpAPIWrapper
serpapi = SerpAPIWrapper()
"""
search_engine: Any #: :meta private:
input_key: str = "search_query" #: :meta private:
output_key: str = "search_result" #: :meta private:
serpapi_api_key: Optional[str] = None
@ -51,22 +48,6 @@ class SerpAPIChain(Chain, BaseModel):
extra = Extra.forbid
@property
def input_keys(self) -> List[str]:
"""Return the singular input key.
:meta private:
"""
return [self.input_key]
@property
def output_keys(self) -> List[str]:
"""Return the singular output key.
:meta private:
"""
return [self.output_key]
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
@ -85,11 +66,12 @@ class SerpAPIChain(Chain, BaseModel):
)
return values
def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]:
def run(self, query: str) -> str:
"""Run query through SerpAPI and parse result."""
params = {
"api_key": self.serpapi_api_key,
"engine": "google",
"q": inputs[self.input_key],
"q": query,
"google_domain": "google.com",
"gl": "us",
"hl": "en",
@ -112,4 +94,9 @@ class SerpAPIChain(Chain, BaseModel):
toret = res["organic_results"][0]["snippet"]
else:
toret = "No good search result found"
return {self.output_key: toret}
return toret
# For backwards compatability
SerpAPIWrapper = SerpAPIWrapper

@ -1,7 +1,7 @@
"""Integration test for self ask with search."""
from langchain.agents.self_ask_with_search.base import SelfAskWithSearchChain
from langchain.chains.serpapi import SerpAPIChain
from langchain.llms.openai import OpenAI
from langchain.serpapi import SerpAPIWrapper
def test_self_ask_with_search() -> None:
@ -9,7 +9,7 @@ def test_self_ask_with_search() -> None:
question = "What is the hometown of the reigning men's U.S. Open champion?"
chain = SelfAskWithSearchChain(
llm=OpenAI(temperature=0),
search_chain=SerpAPIChain(),
search_chain=SerpAPIWrapper(),
input_key="q",
output_key="a",
)

@ -1,9 +1,9 @@
"""Integration test for SerpAPI."""
from langchain.chains.serpapi import SerpAPIChain
from langchain.serpapi import SerpAPIWrapper
def test_call() -> None:
"""Test that call gives the correct answer."""
chain = SerpAPIChain()
chain = SerpAPIWrapper()
output = chain.run("What was Obama's first name?")
assert output == "Barack Hussein Obama II"
Loading…
Cancel
Save