update tools (#13243)

pull/13208/head^2
Harrison Chase 8 months ago committed by GitHub
parent 8d6faf5665
commit 7f1d26160d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -18,6 +18,8 @@ class YouRetriever(BaseRetriever):
ydc_api_key: str
k: Optional[int] = None
n_hits: Optional[int] = None
n_snippets_per_hit: Optional[int] = None
endpoint_type: str = "web"
@root_validator(pre=True)
@ -43,8 +45,10 @@ class YouRetriever(BaseRetriever):
).json()
docs = []
for hit in results["hits"]:
for snippet in hit["snippets"]:
n_hits = self.n_hits or len(results["hits"])
for hit in results["hits"][:n_hits]:
n_snippets_per_hit = self.n_snippets_per_hit or len(hit["snippets"])
for snippet in hit["snippets"][:n_snippets_per_hit]:
docs.append(Document(page_content=snippet))
if self.k is not None and len(docs) >= self.k:
return docs

@ -1,13 +1,17 @@
"""Tool for the Arxiv API."""
from typing import Optional
from typing import Optional, Type
from langchain.callbacks.manager import CallbackManagerForToolRun
from langchain.pydantic_v1 import Field
from langchain.pydantic_v1 import BaseModel, Field
from langchain.tools.base import BaseTool
from langchain.utilities.arxiv import ArxivAPIWrapper
# Input schema for the Arxiv tool: ArxivQueryRun assigns this model to its
# ``args_schema``. Documented with comments rather than a class docstring so
# the schema generated from the model is left untouched.
class ArxivInput(BaseModel):
    # The only declared field; its description text is attached via Field().
    query: str = Field(description="search query to look up")
class ArxivQueryRun(BaseTool):
"""Tool that searches the Arxiv API."""
@ -21,6 +25,7 @@ class ArxivQueryRun(BaseTool):
"Input should be a search query."
)
api_wrapper: ArxivAPIWrapper = Field(default_factory=ArxivAPIWrapper)
args_schema: Type[BaseModel] = ArxivInput
def _run(
self,

@ -1,7 +1,12 @@
from langchain.pydantic_v1 import BaseModel, Field
from langchain.schema import BaseRetriever
from langchain.tools import Tool
# Input schema for retriever-backed tools: create_retriever_tool passes this
# model as ``args_schema`` when it builds the Tool. Documented with comments
# rather than a class docstring so the generated schema is left untouched.
class RetrieverInput(BaseModel):
    # The only declared field; its description text is attached via Field().
    query: str = Field(description="query to look up in retriever")
def create_retriever_tool(
retriever: BaseRetriever, name: str, description: str
) -> Tool:
@ -22,4 +27,5 @@ def create_retriever_tool(
description=description,
func=retriever.get_relevant_documents,
coroutine=retriever.aget_relevant_documents,
args_schema=RetrieverInput,
)

@ -1,5 +1,5 @@
"""Tavily Search API toolkit."""
from langchain.tools.tavily_search.tool import TavilySearchResults
from langchain.tools.tavily_search.tool import TavilyAnswer, TavilySearchResults
__all__ = ["TavilySearchResults"]
__all__ = ["TavilySearchResults", "TavilyAnswer"]

@ -1,26 +1,32 @@
"""Tool for the Tavily search API."""
from typing import Dict, List, Optional, Union
from typing import Dict, List, Optional, Type, Union
from langchain.callbacks.manager import (
AsyncCallbackManagerForToolRun,
CallbackManagerForToolRun,
)
from langchain.pydantic_v1 import BaseModel, Field
from langchain.tools.base import BaseTool
from langchain.utilities.tavily_search import TavilySearchAPIWrapper
# Input schema shared by the Tavily tools: both TavilySearchResults and
# TavilyAnswer assign this model to ``args_schema``. Documented with comments
# rather than a class docstring so the generated schema is left untouched.
class TavilyInput(BaseModel):
    # The only declared field; its description text is attached via Field().
    query: str = Field(description="search query to look up")
class TavilySearchResults(BaseTool):
"""Tool that queries the Tavily Search API and gets back json."""
name: str = "tavily_search_results_json"
description: str = """"
description: str = (
"A search engine optimized for comprehensive, accurate, and trusted results. "
"Useful for when you need to answer questions about current events. "
"Input should be a search query."
"""
)
api_wrapper: TavilySearchAPIWrapper
max_results: int = 5
args_schema: Type[BaseModel] = TavilyInput
def _run(
self,
@ -49,3 +55,50 @@ class TavilySearchResults(BaseTool):
)
except Exception as e:
return repr(e)
class TavilyAnswer(BaseTool):
    """Tool that queries the Tavily Search API and gets back an answer."""

    name: str = "tavily_answer"
    description: str = (
        "A search engine optimized for comprehensive, accurate, and trusted results. "
        "Useful for when you need to answer questions about current events. "
        "Input should be a search query. "
        "This returns only the answer - not the original source data."
    )
    # Wrapper used to call the Tavily API; no default is declared here, so the
    # caller has to supply one.
    api_wrapper: TavilySearchAPIWrapper
    args_schema: Type[BaseModel] = TavilyInput

    def _run(
        self,
        query: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> str:
        """Use the tool.

        Args:
            query: The search query forwarded to the Tavily API.
            run_manager: Accepted for the tool interface; not used here.

        Returns:
            The ``"answer"`` entry of the raw Tavily response (presumably a
            string -- confirm against the API), or ``repr(e)`` of whatever
            exception was raised while fetching it.
        """
        try:
            result = self.api_wrapper.raw_results(
                query,
                max_results=5,
                include_answer=True,
                search_depth="basic",
            )
            return result["answer"]
        except Exception as e:
            # Best effort by design: the failure is reported back as text
            # instead of being raised out of the tool.
            return repr(e)

    async def _arun(
        self,
        query: str,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> str:
        """Use the tool asynchronously.

        Args:
            query: The search query forwarded to the Tavily API.
            run_manager: Accepted for the tool interface; not used here.

        Returns:
            The ``"answer"`` entry of the raw Tavily response (presumably a
            string -- confirm against the API), or ``repr(e)`` of whatever
            exception was raised while fetching it.
        """
        try:
            result = await self.api_wrapper.raw_results_async(
                query,
                max_results=5,
                include_answer=True,
                search_depth="basic",
            )
            return result["answer"]
        except Exception as e:
            # Same best-effort contract as the synchronous path.
            return repr(e)

@ -24,7 +24,17 @@ class TavilySearchAPIWrapper(BaseModel):
extra = Extra.forbid
def _tavily_search_results(
@root_validator(pre=True)
def validate_environment(cls, values: Dict) -> Dict:
    """Fill in ``tavily_api_key`` on the incoming field values.

    The key is looked up with ``get_from_dict_or_env`` using the dict key
    ``"tavily_api_key"`` and the environment variable name
    ``TAVILY_API_KEY``, then written back into ``values``.
    NOTE(review): presumably the helper raises when neither source provides
    a key -- confirm against its definition.
    """
    values["tavily_api_key"] = get_from_dict_or_env(
        values, "tavily_api_key", "TAVILY_API_KEY"
    )
    return values
def raw_results(
self,
query: str,
max_results: Optional[int] = 5,
@ -34,7 +44,7 @@ class TavilySearchAPIWrapper(BaseModel):
include_answer: Optional[bool] = False,
include_raw_content: Optional[bool] = False,
include_images: Optional[bool] = False,
) -> List[dict]:
) -> Dict:
params = {
"api_key": self.tavily_api_key,
"query": query,
@ -51,20 +61,8 @@ class TavilySearchAPIWrapper(BaseModel):
f"{TAVILY_API_URL}/search",
json=params,
)
response.raise_for_status()
search_results = response.json()
return self.clean_results(search_results["results"])
@root_validator(pre=True)
def validate_environment(cls, values: Dict) -> Dict:
    """Fill in ``tavily_api_key`` on the incoming field values.

    The key is looked up with ``get_from_dict_or_env`` using the dict key
    ``"tavily_api_key"`` and the environment variable name
    ``TAVILY_API_KEY``, then written back into ``values``.
    NOTE(review): presumably the helper raises when neither source provides
    a key -- confirm against its definition.
    """
    values["tavily_api_key"] = get_from_dict_or_env(
        values, "tavily_api_key", "TAVILY_API_KEY"
    )
    return values
return response.json()
def results(
self,
@ -88,7 +86,6 @@ class TavilySearchAPIWrapper(BaseModel):
include_answer: Whether to include the answer in the results.
include_raw_content: Whether to include the raw content in the results.
include_images: Whether to include images in the results.
Returns:
query: The query that was searched for.
follow_up_questions: A list of follow up questions.
@ -101,22 +98,20 @@ class TavilySearchAPIWrapper(BaseModel):
content: The content of the result.
score: The score of the result.
raw_content: The raw content of the result.
""" # noqa: E501
raw_search_results = self._tavily_search_results(
raw_search_results = self.raw_results(
query,
max_results,
search_depth,
include_domains,
exclude_domains,
include_answer,
include_raw_content,
include_images,
max_results=max_results,
search_depth=search_depth,
include_domains=include_domains,
exclude_domains=exclude_domains,
include_answer=include_answer,
include_raw_content=include_raw_content,
include_images=include_images,
)
return raw_search_results
return self.clean_results(raw_search_results["results"])
async def results_async(
async def raw_results_async(
self,
query: str,
max_results: Optional[int] = 5,
@ -126,7 +121,7 @@ class TavilySearchAPIWrapper(BaseModel):
include_answer: Optional[bool] = False,
include_raw_content: Optional[bool] = False,
include_images: Optional[bool] = False,
) -> List[Dict]:
) -> Dict:
"""Get results from the Tavily Search API asynchronously."""
# Function to perform the API call
@ -151,7 +146,29 @@ class TavilySearchAPIWrapper(BaseModel):
raise Exception(f"Error {res.status}: {res.reason}")
results_json_str = await fetch()
results_json = json.loads(results_json_str)
return json.loads(results_json_str)
async def results_async(
    self,
    query: str,
    max_results: Optional[int] = 5,
    search_depth: Optional[str] = "advanced",
    include_domains: Optional[List[str]] = [],
    exclude_domains: Optional[List[str]] = [],
    include_answer: Optional[bool] = False,
    include_raw_content: Optional[bool] = False,
    include_images: Optional[bool] = False,
) -> List[Dict]:
    """Search asynchronously and return the cleaned result entries.

    Every argument is forwarded by keyword, unchanged, to
    ``raw_results_async``; the ``"results"`` entry of the response is then
    passed through ``clean_results`` and returned.
    """
    # NOTE(review): the list defaults are shared between calls; they are only
    # forwarded here, never mutated in this method.
    search_options = {
        "max_results": max_results,
        "search_depth": search_depth,
        "include_domains": include_domains,
        "exclude_domains": exclude_domains,
        "include_answer": include_answer,
        "include_raw_content": include_raw_content,
        "include_images": include_images,
    }
    response = await self.raw_results_async(query=query, **search_options)
    return self.clean_results(response["results"])
def clean_results(self, results: List[Dict]) -> List[Dict]:

Loading…
Cancel
Save