update tools (#13243)

pull/13208/head^2
Harrison Chase 9 months ago committed by GitHub
parent 8d6faf5665
commit 7f1d26160d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -18,6 +18,8 @@ class YouRetriever(BaseRetriever):
ydc_api_key: str ydc_api_key: str
k: Optional[int] = None k: Optional[int] = None
n_hits: Optional[int] = None
n_snippets_per_hit: Optional[int] = None
endpoint_type: str = "web" endpoint_type: str = "web"
@root_validator(pre=True) @root_validator(pre=True)
@ -43,8 +45,10 @@ class YouRetriever(BaseRetriever):
).json() ).json()
docs = [] docs = []
for hit in results["hits"]: n_hits = self.n_hits or len(results["hits"])
for snippet in hit["snippets"]: for hit in results["hits"][:n_hits]:
n_snippets_per_hit = self.n_snippets_per_hit or len(hit["snippets"])
for snippet in hit["snippets"][:n_snippets_per_hit]:
docs.append(Document(page_content=snippet)) docs.append(Document(page_content=snippet))
if self.k is not None and len(docs) >= self.k: if self.k is not None and len(docs) >= self.k:
return docs return docs

@ -1,13 +1,17 @@
"""Tool for the Arxiv API.""" """Tool for the Arxiv API."""
from typing import Optional from typing import Optional, Type
from langchain.callbacks.manager import CallbackManagerForToolRun from langchain.callbacks.manager import CallbackManagerForToolRun
from langchain.pydantic_v1 import Field from langchain.pydantic_v1 import BaseModel, Field
from langchain.tools.base import BaseTool from langchain.tools.base import BaseTool
from langchain.utilities.arxiv import ArxivAPIWrapper from langchain.utilities.arxiv import ArxivAPIWrapper
# Argument schema assigned to ArxivQueryRun.args_schema: a single text query.
class ArxivInput(BaseModel):
    query: str = Field(
        description="search query to look up",
    )
class ArxivQueryRun(BaseTool): class ArxivQueryRun(BaseTool):
"""Tool that searches the Arxiv API.""" """Tool that searches the Arxiv API."""
@ -21,6 +25,7 @@ class ArxivQueryRun(BaseTool):
"Input should be a search query." "Input should be a search query."
) )
api_wrapper: ArxivAPIWrapper = Field(default_factory=ArxivAPIWrapper) api_wrapper: ArxivAPIWrapper = Field(default_factory=ArxivAPIWrapper)
args_schema: Type[BaseModel] = ArxivInput
def _run( def _run(
self, self,

@ -1,7 +1,12 @@
from langchain.pydantic_v1 import BaseModel, Field
from langchain.schema import BaseRetriever from langchain.schema import BaseRetriever
from langchain.tools import Tool from langchain.tools import Tool
# Argument schema passed as args_schema by create_retriever_tool: a single
# text query.
class RetrieverInput(BaseModel):
    query: str = Field(
        description="query to look up in retriever",
    )
def create_retriever_tool( def create_retriever_tool(
retriever: BaseRetriever, name: str, description: str retriever: BaseRetriever, name: str, description: str
) -> Tool: ) -> Tool:
@ -22,4 +27,5 @@ def create_retriever_tool(
description=description, description=description,
func=retriever.get_relevant_documents, func=retriever.get_relevant_documents,
coroutine=retriever.aget_relevant_documents, coroutine=retriever.aget_relevant_documents,
args_schema=RetrieverInput,
) )

@ -1,5 +1,5 @@
"""Tavily Search API toolkit.""" """Tavily Search API toolkit."""
from langchain.tools.tavily_search.tool import TavilySearchResults from langchain.tools.tavily_search.tool import TavilyAnswer, TavilySearchResults
__all__ = ["TavilySearchResults"] __all__ = ["TavilySearchResults", "TavilyAnswer"]

@ -1,26 +1,32 @@
"""Tool for the Tavily search API.""" """Tool for the Tavily search API."""
from typing import Dict, List, Optional, Union from typing import Dict, List, Optional, Type, Union
from langchain.callbacks.manager import ( from langchain.callbacks.manager import (
AsyncCallbackManagerForToolRun, AsyncCallbackManagerForToolRun,
CallbackManagerForToolRun, CallbackManagerForToolRun,
) )
from langchain.pydantic_v1 import BaseModel, Field
from langchain.tools.base import BaseTool from langchain.tools.base import BaseTool
from langchain.utilities.tavily_search import TavilySearchAPIWrapper from langchain.utilities.tavily_search import TavilySearchAPIWrapper
# Argument schema shared by the Tavily tools in this module (used as their
# args_schema): a single text query.
class TavilyInput(BaseModel):
    query: str = Field(
        description="search query to look up",
    )
class TavilySearchResults(BaseTool): class TavilySearchResults(BaseTool):
"""Tool that queries the Tavily Search API and gets back json.""" """Tool that queries the Tavily Search API and gets back json."""
name: str = "tavily_search_results_json" name: str = "tavily_search_results_json"
description: str = """" description: str = (
"A search engine optimized for comprehensive, accurate, and trusted results. " "A search engine optimized for comprehensive, accurate, and trusted results. "
"Useful for when you need to answer questions about current events. " "Useful for when you need to answer questions about current events. "
"Input should be a search query." "Input should be a search query."
""" )
api_wrapper: TavilySearchAPIWrapper api_wrapper: TavilySearchAPIWrapper
max_results: int = 5 max_results: int = 5
args_schema: Type[BaseModel] = TavilyInput
def _run( def _run(
self, self,
@ -49,3 +55,50 @@ class TavilySearchResults(BaseTool):
) )
except Exception as e: except Exception as e:
return repr(e) return repr(e)
class TavilyAnswer(BaseTool):
    """Tool that queries the Tavily Search API and gets back an answer."""

    name: str = "tavily_answer"
    description: str = (
        "A search engine optimized for comprehensive, accurate, and trusted results. "
        "Useful for when you need to answer questions about current events. "
        "Input should be a search query. "
        "This returns only the answer - not the original source data."
    )
    # Wrapper that performs the actual API requests; it has no default, so it
    # must be supplied when the tool is constructed.
    api_wrapper: TavilySearchAPIWrapper
    args_schema: Type[BaseModel] = TavilyInput

    def _run(
        self,
        query: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> Union[List[Dict], str]:
        """Run a basic-depth search and return the response's "answer" entry.

        Any exception raised along the way is caught and handed back as its
        ``repr`` string instead of propagating.
        """
        try:
            response = self.api_wrapper.raw_results(
                query,
                max_results=5,
                include_answer=True,
                search_depth="basic",
            )
            return response["answer"]
        except Exception as err:
            return repr(err)

    async def _arun(
        self,
        query: str,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> Union[List[Dict], str]:
        """Asynchronous counterpart of ``_run`` with the same error handling."""
        try:
            response = await self.api_wrapper.raw_results_async(
                query,
                max_results=5,
                include_answer=True,
                search_depth="basic",
            )
        except Exception as err:
            return repr(err)
        try:
            return response["answer"]
        except Exception as err:
            return repr(err)

@ -24,7 +24,17 @@ class TavilySearchAPIWrapper(BaseModel):
extra = Extra.forbid extra = Extra.forbid
def _tavily_search_results( @root_validator(pre=True)
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and endpoint exists in environment."""
tavily_api_key = get_from_dict_or_env(
values, "tavily_api_key", "TAVILY_API_KEY"
)
values["tavily_api_key"] = tavily_api_key
return values
def raw_results(
self, self,
query: str, query: str,
max_results: Optional[int] = 5, max_results: Optional[int] = 5,
@ -34,7 +44,7 @@ class TavilySearchAPIWrapper(BaseModel):
include_answer: Optional[bool] = False, include_answer: Optional[bool] = False,
include_raw_content: Optional[bool] = False, include_raw_content: Optional[bool] = False,
include_images: Optional[bool] = False, include_images: Optional[bool] = False,
) -> List[dict]: ) -> Dict:
params = { params = {
"api_key": self.tavily_api_key, "api_key": self.tavily_api_key,
"query": query, "query": query,
@ -51,20 +61,8 @@ class TavilySearchAPIWrapper(BaseModel):
f"{TAVILY_API_URL}/search", f"{TAVILY_API_URL}/search",
json=params, json=params,
) )
response.raise_for_status() response.raise_for_status()
search_results = response.json() return response.json()
return self.clean_results(search_results["results"])
@root_validator(pre=True)
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and endpoint exists in environment."""
tavily_api_key = get_from_dict_or_env(
values, "tavily_api_key", "TAVILY_API_KEY"
)
values["tavily_api_key"] = tavily_api_key
return values
def results( def results(
self, self,
@ -88,7 +86,6 @@ class TavilySearchAPIWrapper(BaseModel):
include_answer: Whether to include the answer in the results. include_answer: Whether to include the answer in the results.
include_raw_content: Whether to include the raw content in the results. include_raw_content: Whether to include the raw content in the results.
include_images: Whether to include images in the results. include_images: Whether to include images in the results.
Returns: Returns:
query: The query that was searched for. query: The query that was searched for.
follow_up_questions: A list of follow up questions. follow_up_questions: A list of follow up questions.
@ -101,22 +98,20 @@ class TavilySearchAPIWrapper(BaseModel):
content: The content of the result. content: The content of the result.
score: The score of the result. score: The score of the result.
raw_content: The raw content of the result. raw_content: The raw content of the result.
""" # noqa: E501 """ # noqa: E501
raw_search_results = self._tavily_search_results( raw_search_results = self.raw_results(
query, query,
max_results, max_results=max_results,
search_depth, search_depth=search_depth,
include_domains, include_domains=include_domains,
exclude_domains, exclude_domains=exclude_domains,
include_answer, include_answer=include_answer,
include_raw_content, include_raw_content=include_raw_content,
include_images, include_images=include_images,
) )
return raw_search_results return self.clean_results(raw_search_results["results"])
async def results_async( async def raw_results_async(
self, self,
query: str, query: str,
max_results: Optional[int] = 5, max_results: Optional[int] = 5,
@ -126,7 +121,7 @@ class TavilySearchAPIWrapper(BaseModel):
include_answer: Optional[bool] = False, include_answer: Optional[bool] = False,
include_raw_content: Optional[bool] = False, include_raw_content: Optional[bool] = False,
include_images: Optional[bool] = False, include_images: Optional[bool] = False,
) -> List[Dict]: ) -> Dict:
"""Get results from the Tavily Search API asynchronously.""" """Get results from the Tavily Search API asynchronously."""
# Function to perform the API call # Function to perform the API call
@ -151,7 +146,29 @@ class TavilySearchAPIWrapper(BaseModel):
raise Exception(f"Error {res.status}: {res.reason}") raise Exception(f"Error {res.status}: {res.reason}")
results_json_str = await fetch() results_json_str = await fetch()
results_json = json.loads(results_json_str) return json.loads(results_json_str)
async def results_async(
    self,
    query: str,
    max_results: Optional[int] = 5,
    search_depth: Optional[str] = "advanced",
    include_domains: Optional[List[str]] = [],
    exclude_domains: Optional[List[str]] = [],
    include_answer: Optional[bool] = False,
    include_raw_content: Optional[bool] = False,
    include_images: Optional[bool] = False,
) -> List[Dict]:
    """Fetch search results asynchronously and return them cleaned.

    Every argument is forwarded by keyword to ``raw_results_async``; the
    ``"results"`` entry of the response it returns is then passed through
    ``clean_results``.
    """
    # The list defaults above are only forwarded, never mutated in this method.
    search_options = {
        "max_results": max_results,
        "search_depth": search_depth,
        "include_domains": include_domains,
        "exclude_domains": exclude_domains,
        "include_answer": include_answer,
        "include_raw_content": include_raw_content,
        "include_images": include_images,
    }
    raw_response = await self.raw_results_async(query=query, **search_options)
    return self.clean_results(raw_response["results"])
def clean_results(self, results: List[Dict]) -> List[Dict]: def clean_results(self, results: List[Dict]) -> List[Dict]:

Loading…
Cancel
Save