From 44c8d8a9acf580d8768f8b529de4d21095800cec Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Mon, 20 Feb 2023 21:15:45 -0800 Subject: [PATCH] move serpapi wrapper (#1199) Co-authored-by: Tim Asp <707699+timothyasp@users.noreply.github.com> --- langchain/__init__.py | 5 +- langchain/agents/load_tools.py | 2 +- langchain/agents/self_ask_with_search/base.py | 2 +- langchain/serpapi.py | 157 +----------------- langchain/utilities/__init__.py | 2 +- langchain/utilities/serpapi.py | 152 +++++++++++++++++ tests/integration_tests/test_serpapi.py | 2 +- 7 files changed, 163 insertions(+), 159 deletions(-) create mode 100644 langchain/utilities/serpapi.py diff --git a/langchain/__init__.py b/langchain/__init__.py index 59a6cc54..3b7f3ebf 100644 --- a/langchain/__init__.py +++ b/langchain/__init__.py @@ -40,11 +40,11 @@ from langchain.prompts import ( Prompt, PromptTemplate, ) -from langchain.serpapi import SerpAPIChain, SerpAPIWrapper from langchain.sql_database import SQLDatabase from langchain.utilities.google_search import GoogleSearchAPIWrapper from langchain.utilities.google_serper import GoogleSerperAPIWrapper from langchain.utilities.searx_search import SearxSearchWrapper +from langchain.utilities.serpapi import SerpAPIWrapper from langchain.utilities.wolfram_alpha import WolframAlphaAPIWrapper from langchain.vectorstores import FAISS, ElasticVectorSearch @@ -52,6 +52,9 @@ verbose: bool = False llm_cache: Optional[BaseCache] = None set_default_callback_manager() +# For backwards compatibility +SerpAPIChain = SerpAPIWrapper + __all__ = [ "LLMChain", "LLMBashChain", diff --git a/langchain/agents/load_tools.py b/langchain/agents/load_tools.py index d2388454..606bc127 100644 --- a/langchain/agents/load_tools.py +++ b/langchain/agents/load_tools.py @@ -11,7 +11,6 @@ from langchain.chains.pal.base import PALChain from langchain.llms.base import BaseLLM from langchain.python import PythonREPL from langchain.requests import RequestsWrapper -from langchain.serpapi import 
SerpAPIWrapper from langchain.tools.base import BaseTool from langchain.tools.bing_search.tool import BingSearchRun from langchain.tools.google_search.tool import GoogleSearchResults, GoogleSearchRun @@ -21,6 +20,7 @@ from langchain.utilities.bing_search import BingSearchAPIWrapper from langchain.utilities.google_search import GoogleSearchAPIWrapper from langchain.utilities.google_serper import GoogleSerperAPIWrapper from langchain.utilities.searx_search import SearxSearchWrapper +from langchain.utilities.serpapi import SerpAPIWrapper from langchain.utilities.wolfram_alpha import WolframAlphaAPIWrapper diff --git a/langchain/agents/self_ask_with_search/base.py b/langchain/agents/self_ask_with_search/base.py index d9d23b05..694273e3 100644 --- a/langchain/agents/self_ask_with_search/base.py +++ b/langchain/agents/self_ask_with_search/base.py @@ -6,9 +6,9 @@ from langchain.agents.self_ask_with_search.prompt import PROMPT from langchain.agents.tools import Tool from langchain.llms.base import BaseLLM from langchain.prompts.base import BasePromptTemplate -from langchain.serpapi import SerpAPIWrapper from langchain.tools.base import BaseTool from langchain.utilities.google_serper import GoogleSerperAPIWrapper +from langchain.utilities.serpapi import SerpAPIWrapper class SelfAskWithSearchAgent(Agent): diff --git a/langchain/serpapi.py b/langchain/serpapi.py index aea65d83..dd8569b6 100644 --- a/langchain/serpapi.py +++ b/langchain/serpapi.py @@ -1,155 +1,4 @@ -"""Chain that calls SerpAPI. 
+"""For backwards compatibility.""" +from langchain.utilities.serpapi import SerpAPIWrapper -Heavily borrowed from https://github.com/ofirpress/self-ask -""" -import os -import sys -from typing import Any, Dict, Optional, Tuple - -import aiohttp -from pydantic import BaseModel, Extra, Field, root_validator - -from langchain.utils import get_from_dict_or_env - - -class HiddenPrints: - """Context manager to hide prints.""" - - def __enter__(self) -> None: - """Open file to pipe stdout to.""" - self._original_stdout = sys.stdout - sys.stdout = open(os.devnull, "w") - - def __exit__(self, *_: Any) -> None: - """Close file that stdout was piped to.""" - sys.stdout.close() - sys.stdout = self._original_stdout - - -def _get_default_params() -> dict: - return { - "engine": "google", - "google_domain": "google.com", - "gl": "us", - "hl": "en", - } - - -def process_response(res: dict) -> str: - """Process response from SerpAPI.""" - if "error" in res.keys(): - raise ValueError(f"Got error from SerpAPI: {res['error']}") - if "answer_box" in res.keys() and "answer" in res["answer_box"].keys(): - toret = res["answer_box"]["answer"] - elif "answer_box" in res.keys() and "snippet" in res["answer_box"].keys(): - toret = res["answer_box"]["snippet"] - elif ( - "answer_box" in res.keys() - and "snippet_highlighted_words" in res["answer_box"].keys() - ): - toret = res["answer_box"]["snippet_highlighted_words"][0] - elif ( - "sports_results" in res.keys() - and "game_spotlight" in res["sports_results"].keys() - ): - toret = res["sports_results"]["game_spotlight"] - elif ( - "knowledge_graph" in res.keys() - and "description" in res["knowledge_graph"].keys() - ): - toret = res["knowledge_graph"]["description"] - elif "snippet" in res["organic_results"][0].keys(): - toret = res["organic_results"][0]["snippet"] - - else: - toret = "No good search result found" - return toret - - -class SerpAPIWrapper(BaseModel): - """Wrapper around SerpAPI. 
- - To use, you should have the ``google-search-results`` python package installed, - and the environment variable ``SERPAPI_API_KEY`` set with your API key, or pass - `serpapi_api_key` as a named parameter to the constructor. - - Example: - .. code-block:: python - - from langchain import SerpAPIWrapper - serpapi = SerpAPIWrapper() - """ - - search_engine: Any #: :meta private: - params: dict = Field(default_factory=_get_default_params) - serpapi_api_key: Optional[str] = None - aiosession: Optional[aiohttp.ClientSession] = None - - class Config: - """Configuration for this pydantic object.""" - - extra = Extra.forbid - arbitrary_types_allowed = True - - @root_validator() - def validate_environment(cls, values: Dict) -> Dict: - """Validate that api key and python package exists in environment.""" - serpapi_api_key = get_from_dict_or_env( - values, "serpapi_api_key", "SERPAPI_API_KEY" - ) - values["serpapi_api_key"] = serpapi_api_key - try: - from serpapi import GoogleSearch - - values["search_engine"] = GoogleSearch - except ImportError: - raise ValueError( - "Could not import serpapi python package. " - "Please it install it with `pip install google-search-results`." 
- ) - return values - - async def arun(self, query: str) -> str: - """Use aiohttp to run query through SerpAPI and parse result.""" - - def construct_url_and_params() -> Tuple[str, Dict[str, str]]: - params = self.get_params(query) - params["source"] = "python" - if self.serpapi_api_key: - params["serp_api_key"] = self.serpapi_api_key - params["output"] = "json" - url = "https://serpapi.com/search" - return url, params - - url, params = construct_url_and_params() - if not self.aiosession: - async with aiohttp.ClientSession() as session: - async with session.get(url, params=params) as response: - res = await response.json() - else: - async with self.aiosession.get(url, params=params) as response: - res = await response.json() - - return process_response(res) - - def run(self, query: str) -> str: - """Run query through SerpAPI and parse result.""" - params = self.get_params(query) - with HiddenPrints(): - search = self.search_engine(params) - res = search.get_dict() - return process_response(res) - - def get_params(self, query: str) -> Dict[str, str]: - """Get parameters for SerpAPI.""" - _params = { - "api_key": self.serpapi_api_key, - "q": query, - } - params = {**self.params, **_params} - return params - - -# For backwards compatibility - -SerpAPIChain = SerpAPIWrapper +__all__ = ["SerpAPIWrapper"] diff --git a/langchain/utilities/__init__.py b/langchain/utilities/__init__.py index 36a5092c..b61fa9fd 100644 --- a/langchain/utilities/__init__.py +++ b/langchain/utilities/__init__.py @@ -1,12 +1,12 @@ """General utilities.""" from langchain.python import PythonREPL from langchain.requests import RequestsWrapper -from langchain.serpapi import SerpAPIWrapper from langchain.utilities.bash import BashProcess from langchain.utilities.bing_search import BingSearchAPIWrapper from langchain.utilities.google_search import GoogleSearchAPIWrapper from langchain.utilities.google_serper import GoogleSerperAPIWrapper from langchain.utilities.searx_search import SearxSearchWrapper 
+from langchain.utilities.serpapi import SerpAPIWrapper from langchain.utilities.wolfram_alpha import WolframAlphaAPIWrapper __all__ = [ diff --git a/langchain/utilities/serpapi.py b/langchain/utilities/serpapi.py new file mode 100644 index 00000000..3a35a711 --- /dev/null +++ b/langchain/utilities/serpapi.py @@ -0,0 +1,152 @@ +"""Chain that calls SerpAPI. + +Heavily borrowed from https://github.com/ofirpress/self-ask +""" +import os +import sys +from typing import Any, Dict, Optional, Tuple + +import aiohttp +from pydantic import BaseModel, Extra, Field, root_validator + +from langchain.utils import get_from_dict_or_env + + +class HiddenPrints: + """Context manager to hide prints.""" + + def __enter__(self) -> None: + """Open file to pipe stdout to.""" + self._original_stdout = sys.stdout + sys.stdout = open(os.devnull, "w") + + def __exit__(self, *_: Any) -> None: + """Close file that stdout was piped to.""" + sys.stdout.close() + sys.stdout = self._original_stdout + + +class SerpAPIWrapper(BaseModel): + """Wrapper around SerpAPI. + + To use, you should have the ``google-search-results`` python package installed, + and the environment variable ``SERPAPI_API_KEY`` set with your API key, or pass + `serpapi_api_key` as a named parameter to the constructor. + + Example: + .. 
code-block:: python + + from langchain import SerpAPIWrapper + serpapi = SerpAPIWrapper() + """ + + search_engine: Any #: :meta private: + params: dict = Field( + default={ + "engine": "google", + "google_domain": "google.com", + "gl": "us", + "hl": "en", + } + ) + serpapi_api_key: Optional[str] = None + aiosession: Optional[aiohttp.ClientSession] = None + + class Config: + """Configuration for this pydantic object.""" + + extra = Extra.forbid + arbitrary_types_allowed = True + + @root_validator() + def validate_environment(cls, values: Dict) -> Dict: + """Validate that api key and python package exists in environment.""" + serpapi_api_key = get_from_dict_or_env( + values, "serpapi_api_key", "SERPAPI_API_KEY" + ) + values["serpapi_api_key"] = serpapi_api_key + try: + from serpapi import GoogleSearch + + values["search_engine"] = GoogleSearch + except ImportError: + raise ValueError( + "Could not import serpapi python package. " + "Please it install it with `pip install google-search-results`." 
+ ) + return values + + async def arun(self, query: str) -> str: + """Use aiohttp to run query through SerpAPI and parse result.""" + + def construct_url_and_params() -> Tuple[str, Dict[str, str]]: + params = self.get_params(query) + params["source"] = "python" + if self.serpapi_api_key: + params["serp_api_key"] = self.serpapi_api_key + params["output"] = "json" + url = "https://serpapi.com/search" + return url, params + + url, params = construct_url_and_params() + if not self.aiosession: + async with aiohttp.ClientSession() as session: + async with session.get(url, params=params) as response: + res = await response.json() + else: + async with self.aiosession.get(url, params=params) as response: + res = await response.json() + + return self._process_response(res) + + def run(self, query: str) -> str: + """Run query through SerpAPI and parse result.""" + return self._process_response(self.results(query)) + + def results(self, query: str) -> dict: + """Run query through SerpAPI and return the raw result.""" + params = self.get_params(query) + with HiddenPrints(): + search = self.search_engine(params) + res = search.get_dict() + return res + + def get_params(self, query: str) -> Dict[str, str]: + """Get parameters for SerpAPI.""" + _params = { + "api_key": self.serpapi_api_key, + "q": query, + } + params = {**self.params, **_params} + return params + + @staticmethod + def _process_response(res: dict) -> str: + """Process response from SerpAPI.""" + if "error" in res.keys(): + raise ValueError(f"Got error from SerpAPI: {res['error']}") + if "answer_box" in res.keys() and "answer" in res["answer_box"].keys(): + toret = res["answer_box"]["answer"] + elif "answer_box" in res.keys() and "snippet" in res["answer_box"].keys(): + toret = res["answer_box"]["snippet"] + elif ( + "answer_box" in res.keys() + and "snippet_highlighted_words" in res["answer_box"].keys() + ): + toret = res["answer_box"]["snippet_highlighted_words"][0] + elif ( + "sports_results" in res.keys() + and 
"game_spotlight" in res["sports_results"].keys() + ): + toret = res["sports_results"]["game_spotlight"] + elif ( + "knowledge_graph" in res.keys() + and "description" in res["knowledge_graph"].keys() + ): + toret = res["knowledge_graph"]["description"] + elif "snippet" in res["organic_results"][0].keys(): + toret = res["organic_results"][0]["snippet"] + + else: + toret = "No good search result found" + return toret diff --git a/tests/integration_tests/test_serpapi.py b/tests/integration_tests/test_serpapi.py index cd2b6343..2e3d3427 100644 --- a/tests/integration_tests/test_serpapi.py +++ b/tests/integration_tests/test_serpapi.py @@ -1,5 +1,5 @@ """Integration test for SerpAPI.""" -from langchain.serpapi import SerpAPIWrapper +from langchain.utilities import SerpAPIWrapper def test_call() -> None: