langchain/libs/community/langchain_community/tools/tavily_search/tool.py

133 lines
4.7 KiB
Python

"""Tool for the Tavily search API."""
from typing import Dict, List, Optional, Type, Union
from langchain_core.callbacks import (
AsyncCallbackManagerForToolRun,
CallbackManagerForToolRun,
)
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.tools import BaseTool
from langchain_community.utilities.tavily_search import TavilySearchAPIWrapper
class TavilyInput(BaseModel):
    """Argument schema shared by the Tavily tools: one required query string."""

    # ``...`` marks the field as required, matching a Field with no default.
    query: str = Field(..., description="search query to look up")
class TavilySearchResults(BaseTool):
    """Tool that sends a query to the Tavily Search API and returns its results.

    The query is forwarded to ``api_wrapper`` together with the search options
    configured on this tool. If the wrapper call raises, the exception is
    caught and its ``repr`` is returned in place of the results.
    """

    name: str = "tavily_search_results_json"
    description: str = (
        "A search engine optimized for comprehensive, accurate, and trusted results. "
        "Useful for when you need to answer questions about current events. "
        "Input should be a search query."
    )
    api_wrapper: TavilySearchAPIWrapper = Field(
        default_factory=TavilySearchAPIWrapper  # type: ignore[arg-type]
    )
    max_results: int = 5
    """Maximum number of search results to return. Defaults to 5."""
    search_depth: str = "advanced"
    '''Depth of the search; either "basic" or "advanced". Defaults to "advanced".'''
    include_domains: List[str] = []
    """Domains to specifically include in the search results. Defaults to an empty list (no domain restriction)."""  # noqa: E501
    exclude_domains: List[str] = []
    """Domains to specifically exclude from the search results. Defaults to an empty list (nothing excluded)."""  # noqa: E501
    include_answer: bool = False
    """Whether to include a short answer to the original query. Defaults to False."""
    include_raw_content: bool = False
    """Whether to include cleaned and parsed HTML of each result. Defaults to False."""
    include_images: bool = False
    """Whether to include a list of query-related images. Defaults to False."""
    args_schema: Type[BaseModel] = TavilyInput

    def _search_options(self) -> tuple:
        """Return the configured options in the positional order the wrapper takes."""
        return (
            self.max_results,
            self.search_depth,
            self.include_domains,
            self.exclude_domains,
            self.include_answer,
            self.include_raw_content,
            self.include_images,
        )

    def _run(
        self,
        query: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> Union[List[Dict], str]:
        """Run the search synchronously.

        Args:
            query: The search query to look up.
            run_manager: Optional callback manager; not used by this method.

        Returns:
            Whatever ``api_wrapper.results`` returns, or the ``repr`` of the
            exception if the call raises.
        """
        try:
            outcome = self.api_wrapper.results(query, *self._search_options())
        except Exception as exc:
            # Report the failure as a string instead of raising.
            outcome = repr(exc)
        return outcome

    async def _arun(
        self,
        query: str,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> Union[List[Dict], str]:
        """Run the search asynchronously.

        Args:
            query: The search query to look up.
            run_manager: Optional async callback manager; not used by this method.

        Returns:
            Whatever ``api_wrapper.results_async`` returns, or the ``repr`` of
            the exception if the call raises.
        """
        try:
            outcome = await self.api_wrapper.results_async(
                query, *self._search_options()
            )
        except Exception as exc:
            # Report the failure as a string instead of raising.
            outcome = repr(exc)
        return outcome
class TavilyAnswer(BaseTool):
    """Tool that queries the Tavily Search API and returns only its answer.

    Each call asks ``api_wrapper`` for raw results with ``include_answer=True``
    and returns the ``"answer"`` entry of the response. If anything in that
    process raises, the ``repr`` of the exception is returned instead.
    """

    name: str = "tavily_answer"
    description: str = (
        "A search engine optimized for comprehensive, accurate, and trusted results. "
        "Useful for when you need to answer questions about current events. "
        "Input should be a search query. "
        "This returns only the answer - not the original source data."
    )
    api_wrapper: TavilySearchAPIWrapper = Field(
        default_factory=TavilySearchAPIWrapper  # type: ignore[arg-type]
    )
    args_schema: Type[BaseModel] = TavilyInput

    def _run(
        self,
        query: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> Union[List[Dict], str]:
        """Fetch the answer synchronously.

        Args:
            query: The search query to look up.
            run_manager: Optional callback manager; not used by this method.

        Returns:
            The ``"answer"`` entry of the raw response, or the ``repr`` of the
            exception if the request or the lookup raises.
        """
        try:
            response = self.api_wrapper.raw_results(
                query,
                max_results=5,
                include_answer=True,
                search_depth="basic",
            )
            # The lookup stays inside the try so a missing key is reported too.
            return response["answer"]
        except Exception as exc:
            return repr(exc)

    async def _arun(
        self,
        query: str,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> Union[List[Dict], str]:
        """Fetch the answer asynchronously.

        Args:
            query: The search query to look up.
            run_manager: Optional async callback manager; not used by this method.

        Returns:
            The ``"answer"`` entry of the raw response, or the ``repr`` of the
            exception if the request or the lookup raises.
        """
        try:
            return (
                await self.api_wrapper.raw_results_async(
                    query,
                    max_results=5,
                    include_answer=True,
                    search_depth="basic",
                )
            )["answer"]
        except Exception as exc:
            return repr(exc)