"""Tool for the Tavily search API."""

from typing import Dict, List, Optional, Type, Union

from langchain_core.callbacks import (
    AsyncCallbackManagerForToolRun,
    CallbackManagerForToolRun,
)
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.tools import BaseTool

from langchain_community.utilities.tavily_search import TavilySearchAPIWrapper


class TavilyInput(BaseModel):
    """Input for the Tavily tool."""

    query: str = Field(description="search query to look up")


class TavilySearchResults(BaseTool):
    """Tool that queries the Tavily Search API and gets back json.

    The search options below are forwarded to
    ``TavilySearchAPIWrapper.results`` on every call. Any exception raised by
    the wrapper is not propagated: the tool returns ``repr(exception)`` as its
    output instead, so the caller always receives a value.
    """

    name: str = "tavily_search_results_json"
    description: str = (
        "A search engine optimized for comprehensive, accurate, and trusted results. "
        "Useful for when you need to answer questions about current events. "
        "Input should be a search query."
    )
    # NOTE(review): the ignore presumably silences a type mismatch on the
    # default_factory for the wrapper's constructor; confirm against the wrapper.
    api_wrapper: TavilySearchAPIWrapper = Field(default_factory=TavilySearchAPIWrapper)  # type: ignore[arg-type]
    max_results: int = 5
    """Max search results to return, default is 5"""
    search_depth: str = "advanced"
    '''The depth of the search. It can be "basic" or "advanced"'''
    include_domains: List[str] = Field(default_factory=list)
    """A list of domains to specifically include in the search results.
    Default is an empty list, which includes all domains."""
    exclude_domains: List[str] = Field(default_factory=list)
    """A list of domains to specifically exclude from the search results.
    Default is an empty list, which doesn't exclude any domains."""
    include_answer: bool = False
    """Include a short answer to the original query in the search results.
    Default is False."""
    include_raw_content: bool = False
    """Include the cleaned and parsed HTML of each site in the search results.
    Default is False."""
    include_images: bool = False
    """Include a list of query-related images in the response.
    Default is False."""

    args_schema: Type[BaseModel] = TavilyInput

    def _run(
        self,
        query: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> Union[List[Dict], str]:
        """Use the tool.

        Args:
            query: The search query.
            run_manager: Callback manager for the run (unused here).

        Returns:
            Whatever ``api_wrapper.results`` returns, or ``repr`` of the
            exception if the call raised.
        """
        try:
            # Arguments are passed positionally, so this order must match the
            # wrapper's ``results`` signature.
            return self.api_wrapper.results(
                query,
                self.max_results,
                self.search_depth,
                self.include_domains,
                self.exclude_domains,
                self.include_answer,
                self.include_raw_content,
                self.include_images,
            )
        except Exception as e:
            # Deliberate: report the failure to the caller as the tool output
            # rather than raising.
            return repr(e)

    async def _arun(
        self,
        query: str,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> Union[List[Dict], str]:
        """Use the tool asynchronously.

        Args:
            query: The search query.
            run_manager: Async callback manager for the run (unused here).

        Returns:
            Whatever ``api_wrapper.results_async`` returns, or ``repr`` of the
            exception if the call raised.
        """
        try:
            # Same positional order as in ``_run``.
            return await self.api_wrapper.results_async(
                query,
                self.max_results,
                self.search_depth,
                self.include_domains,
                self.exclude_domains,
                self.include_answer,
                self.include_raw_content,
                self.include_images,
            )
        except Exception as e:
            return repr(e)


class TavilyAnswer(BaseTool):
    """Tool that queries the Tavily Search API and gets back an answer.

    Each call requests a raw search with fixed settings (``max_results=5``,
    ``include_answer=True``, ``search_depth="basic"``) and returns only the
    ``"answer"`` entry of the response. Any exception, including a missing
    ``"answer"`` key, is returned as ``repr(exception)`` instead of raised.
    """

    name: str = "tavily_answer"
    description: str = (
        "A search engine optimized for comprehensive, accurate, and trusted results. "
        "Useful for when you need to answer questions about current events. "
        "Input should be a search query. "
        "This returns only the answer - not the original source data."
    )
    api_wrapper: TavilySearchAPIWrapper = Field(default_factory=TavilySearchAPIWrapper)  # type: ignore[arg-type]
    args_schema: Type[BaseModel] = TavilyInput

    def _run(
        self,
        query: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> Union[List[Dict], str]:
        """Use the tool.

        Args:
            query: The search query.
            run_manager: Callback manager for the run (unused here).

        Returns:
            The ``"answer"`` value from the raw search response, or ``repr``
            of the exception if the request or the lookup failed.
        """
        try:
            result = self.api_wrapper.raw_results(
                query,
                max_results=5,
                include_answer=True,
                search_depth="basic",
            )
            return result["answer"]
        except Exception as e:
            return repr(e)

    async def _arun(
        self,
        query: str,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> Union[List[Dict], str]:
        """Use the tool asynchronously.

        Args:
            query: The search query.
            run_manager: Async callback manager for the run (unused here).

        Returns:
            The ``"answer"`` value from the raw search response, or ``repr``
            of the exception if the request or the lookup failed.
        """
        try:
            result = await self.api_wrapper.raw_results_async(
                query,
                max_results=5,
                include_answer=True,
                search_depth="basic",
            )
            return result["answer"]
        except Exception as e:
            return repr(e)