From 7f1d26160deb2a8c0cdb9b5829f0d27e1e6c9cc7 Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Sun, 12 Nov 2023 10:22:54 -0800 Subject: [PATCH] update tools (#13243) --- libs/langchain/langchain/retrievers/you.py | 8 +- libs/langchain/langchain/tools/arxiv/tool.py | 9 ++- libs/langchain/langchain/tools/retriever.py | 6 ++ .../langchain/tools/tavily_search/__init__.py | 4 +- .../langchain/tools/tavily_search/tool.py | 59 +++++++++++++- .../langchain/utilities/tavily_search.py | 77 +++++++++++-------- 6 files changed, 124 insertions(+), 39 deletions(-) diff --git a/libs/langchain/langchain/retrievers/you.py b/libs/langchain/langchain/retrievers/you.py index c1fe694625..2cc0476858 100644 --- a/libs/langchain/langchain/retrievers/you.py +++ b/libs/langchain/langchain/retrievers/you.py @@ -18,6 +18,8 @@ class YouRetriever(BaseRetriever): ydc_api_key: str k: Optional[int] = None + n_hits: Optional[int] = None + n_snippets_per_hit: Optional[int] = None endpoint_type: str = "web" @root_validator(pre=True) @@ -43,8 +45,10 @@ class YouRetriever(BaseRetriever): ).json() docs = [] - for hit in results["hits"]: - for snippet in hit["snippets"]: + n_hits = self.n_hits or len(results["hits"]) + for hit in results["hits"][:n_hits]: + n_snippets_per_hit = self.n_snippets_per_hit or len(hit["snippets"]) + for snippet in hit["snippets"][:n_snippets_per_hit]: docs.append(Document(page_content=snippet)) if self.k is not None and len(docs) >= self.k: return docs diff --git a/libs/langchain/langchain/tools/arxiv/tool.py b/libs/langchain/langchain/tools/arxiv/tool.py index c9b2f5caa9..c6b8e98e9d 100644 --- a/libs/langchain/langchain/tools/arxiv/tool.py +++ b/libs/langchain/langchain/tools/arxiv/tool.py @@ -1,13 +1,17 @@ """Tool for the Arxiv API.""" -from typing import Optional +from typing import Optional, Type from langchain.callbacks.manager import CallbackManagerForToolRun -from langchain.pydantic_v1 import Field +from langchain.pydantic_v1 import BaseModel, Field from langchain.tools.base import BaseTool from langchain.utilities.arxiv import ArxivAPIWrapper +class ArxivInput(BaseModel): + query: str = Field(description="search query to look up") + + class ArxivQueryRun(BaseTool): """Tool that searches the Arxiv API.""" @@ -21,6 +25,7 @@ class ArxivQueryRun(BaseTool): "Input should be a search query." ) api_wrapper: ArxivAPIWrapper = Field(default_factory=ArxivAPIWrapper) + args_schema: Type[BaseModel] = ArxivInput def _run( self, diff --git a/libs/langchain/langchain/tools/retriever.py b/libs/langchain/langchain/tools/retriever.py index 43f0a2105d..d11c0c1d5c 100644 --- a/libs/langchain/langchain/tools/retriever.py +++ b/libs/langchain/langchain/tools/retriever.py @@ -1,7 +1,12 @@ +from langchain.pydantic_v1 import BaseModel, Field from langchain.schema import BaseRetriever from langchain.tools import Tool +class RetrieverInput(BaseModel): + query: str = Field(description="query to look up in retriever") + + def create_retriever_tool( retriever: BaseRetriever, name: str, description: str ) -> Tool: @@ -22,4 +27,5 @@ def create_retriever_tool( description=description, func=retriever.get_relevant_documents, coroutine=retriever.aget_relevant_documents, + args_schema=RetrieverInput, ) diff --git a/libs/langchain/langchain/tools/tavily_search/__init__.py b/libs/langchain/langchain/tools/tavily_search/__init__.py index ccedc5d14e..4315ad999a 100644 --- a/libs/langchain/langchain/tools/tavily_search/__init__.py +++ b/libs/langchain/langchain/tools/tavily_search/__init__.py @@ -1,5 +1,5 @@ """Tavily Search API toolkit.""" -from langchain.tools.tavily_search.tool import TavilySearchResults +from langchain.tools.tavily_search.tool import TavilyAnswer, TavilySearchResults -__all__ = ["TavilySearchResults"] +__all__ = ["TavilySearchResults", "TavilyAnswer"] diff --git a/libs/langchain/langchain/tools/tavily_search/tool.py b/libs/langchain/langchain/tools/tavily_search/tool.py index d70e295459..a054cf4276 100644 --- a/libs/langchain/langchain/tools/tavily_search/tool.py +++ b/libs/langchain/langchain/tools/tavily_search/tool.py @@ -1,26 +1,32 @@ """Tool for the Tavily search API.""" -from typing import Dict, List, Optional, Union +from typing import Dict, List, Optional, Type, Union from langchain.callbacks.manager import ( AsyncCallbackManagerForToolRun, CallbackManagerForToolRun, ) +from langchain.pydantic_v1 import BaseModel, Field from langchain.tools.base import BaseTool from langchain.utilities.tavily_search import TavilySearchAPIWrapper +class TavilyInput(BaseModel): + query: str = Field(description="search query to look up") + + class TavilySearchResults(BaseTool): """Tool that queries the Tavily Search API and gets back json.""" name: str = "tavily_search_results_json" - description: str = """" + description: str = ( "A search engine optimized for comprehensive, accurate, and trusted results. " "Useful for when you need to answer questions about current events. " "Input should be a search query." - """ + ) api_wrapper: TavilySearchAPIWrapper max_results: int = 5 + args_schema: Type[BaseModel] = TavilyInput def _run( self, @@ -49,3 +55,50 @@ class TavilySearchResults(BaseTool): ) except Exception as e: return repr(e) + + +class TavilyAnswer(BaseTool): + """Tool that queries the Tavily Search API and gets back an answer.""" + + name: str = "tavily_answer" + description: str = ( + "A search engine optimized for comprehensive, accurate, and trusted results. " + "Useful for when you need to answer questions about current events. " + "Input should be a search query. " + "This returns only the answer - not the original source data." + ) + api_wrapper: TavilySearchAPIWrapper + args_schema: Type[BaseModel] = TavilyInput + + def _run( + self, + query: str, + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> Union[List[Dict], str]: + """Use the tool.""" + try: + return self.api_wrapper.raw_results( + query, + max_results=5, + include_answer=True, + search_depth="basic", + )["answer"] + except Exception as e: + return repr(e) + + async def _arun( + self, + query: str, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> Union[List[Dict], str]: + """Use the tool asynchronously.""" + try: + result = await self.api_wrapper.raw_results_async( + query, + max_results=5, + include_answer=True, + search_depth="basic", + ) + return result["answer"] + except Exception as e: + return repr(e) diff --git a/libs/langchain/langchain/utilities/tavily_search.py b/libs/langchain/langchain/utilities/tavily_search.py index 5cbb5a9f3c..57267f2284 100644 --- a/libs/langchain/langchain/utilities/tavily_search.py +++ b/libs/langchain/langchain/utilities/tavily_search.py @@ -24,7 +24,17 @@ class TavilySearchAPIWrapper(BaseModel): extra = Extra.forbid - def _tavily_search_results( + @root_validator(pre=True) + def validate_environment(cls, values: Dict) -> Dict: + """Validate that api key and endpoint exists in environment.""" + tavily_api_key = get_from_dict_or_env( + values, "tavily_api_key", "TAVILY_API_KEY" + ) + values["tavily_api_key"] = tavily_api_key + + return values + + def raw_results( self, query: str, max_results: Optional[int] = 5, @@ -34,7 +44,7 @@ class TavilySearchAPIWrapper(BaseModel): include_answer: Optional[bool] = False, include_raw_content: Optional[bool] = False, include_images: Optional[bool] = False, - ) -> List[dict]: + ) -> Dict: params = { "api_key": self.tavily_api_key, "query": query, @@ -51,20 +61,8 @@ class TavilySearchAPIWrapper(BaseModel): f"{TAVILY_API_URL}/search", json=params, ) - response.raise_for_status() - search_results = response.json() - return self.clean_results(search_results["results"]) - - @root_validator(pre=True) - def validate_environment(cls, values: Dict) -> Dict: - """Validate that api key and endpoint exists in environment.""" - tavily_api_key = get_from_dict_or_env( - values, "tavily_api_key", "TAVILY_API_KEY" - ) - values["tavily_api_key"] = tavily_api_key - - return values + return response.json() def results( self, @@ -88,7 +86,6 @@ class TavilySearchAPIWrapper(BaseModel): include_answer: Whether to include the answer in the results. include_raw_content: Whether to include the raw content in the results. include_images: Whether to include images in the results. - Returns: query: The query that was searched for. follow_up_questions: A list of follow up questions. @@ -101,22 +98,20 @@ class TavilySearchAPIWrapper(BaseModel): content: The content of the result. score: The score of the result. raw_content: The raw content of the result. - - """ # noqa: E501 - raw_search_results = self._tavily_search_results( + raw_search_results = self.raw_results( query, - max_results, - search_depth, - include_domains, - exclude_domains, - include_answer, - include_raw_content, - include_images, + max_results=max_results, + search_depth=search_depth, + include_domains=include_domains, + exclude_domains=exclude_domains, + include_answer=include_answer, + include_raw_content=include_raw_content, + include_images=include_images, ) - return raw_search_results + return self.clean_results(raw_search_results["results"]) - async def results_async( + async def raw_results_async( self, query: str, max_results: Optional[int] = 5, @@ -126,7 +121,7 @@ class TavilySearchAPIWrapper(BaseModel): include_answer: Optional[bool] = False, include_raw_content: Optional[bool] = False, include_images: Optional[bool] = False, - ) -> List[Dict]: + ) -> Dict: """Get results from the Tavily Search API asynchronously.""" # Function to perform the API call @@ -151,7 +146,29 @@ class TavilySearchAPIWrapper(BaseModel): raise Exception(f"Error {res.status}: {res.reason}") results_json_str = await fetch() - results_json = json.loads(results_json_str) + return json.loads(results_json_str) + + async def results_async( + self, + query: str, + max_results: Optional[int] = 5, + search_depth: Optional[str] = "advanced", + include_domains: Optional[List[str]] = [], + exclude_domains: Optional[List[str]] = [], + include_answer: Optional[bool] = False, + include_raw_content: Optional[bool] = False, + include_images: Optional[bool] = False, + ) -> List[Dict]: + results_json = await self.raw_results_async( + query=query, + max_results=max_results, + search_depth=search_depth, + include_domains=include_domains, + exclude_domains=exclude_domains, + include_answer=include_answer, + include_raw_content=include_raw_content, + include_images=include_images, + ) return self.clean_results(results_json["results"]) def clean_results(self, results: List[Dict]) -> List[Dict]: