langchain/libs/community/langchain_community/tools/tavily_search/tool.py
Erick Friis c2a3021bb0
multiple: pydantic 2 compatibility, v0.3 (#26443)
Signed-off-by: ChengZi <chen.zhang@zilliz.com>
Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com>
Co-authored-by: Dan O'Donovan <dan.odonovan@gmail.com>
Co-authored-by: Tom Daniel Grande <tomdgrande@gmail.com>
Co-authored-by: Grande <Tom.Daniel.Grande@statsbygg.no>
Co-authored-by: Bagatur <baskaryan@gmail.com>
Co-authored-by: ccurme <chester.curme@gmail.com>
Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
Co-authored-by: Tomaz Bratanic <bratanic.tomaz@gmail.com>
Co-authored-by: ZhangShenao <15201440436@163.com>
Co-authored-by: Friso H. Kingma <fhkingma@gmail.com>
Co-authored-by: ChengZi <chen.zhang@zilliz.com>
Co-authored-by: Nuno Campos <nuno@langchain.dev>
Co-authored-by: Morgante Pell <morgantep@google.com>
2024-09-13 14:38:45 -07:00

238 lines
7.6 KiB
Python

"""Tool for the Tavily search API."""
from typing import Dict, List, Literal, Optional, Tuple, Type, Union
from langchain_core.callbacks import (
AsyncCallbackManagerForToolRun,
CallbackManagerForToolRun,
)
from langchain_core.tools import BaseTool
from pydantic import BaseModel, Field
from langchain_community.utilities.tavily_search import TavilySearchAPIWrapper
class TavilyInput(BaseModel):
    """Argument schema shared by the Tavily tools: one free-text search query."""

    # Required field: ``...`` marks it as having no default value.
    query: str = Field(..., description="search query to look up")
class TavilySearchResults(BaseTool):
    """Tool that queries the Tavily Search API and gets back json.

    Setup:
        Install ``langchain-community`` and ``tavily-python``, and set environment variable ``TAVILY_API_KEY``.

        .. code-block:: bash

            pip install -U langchain-community tavily-python
            export TAVILY_API_KEY="your-api-key"

    Instantiate:

        .. code-block:: python

            from langchain_community.tools import TavilySearchResults

            tool = TavilySearchResults(
                max_results=5,
                include_answer=True,
                include_raw_content=True,
                include_images=True,
                # search_depth="advanced",
                # include_domains = []
                # exclude_domains = []
            )

    Invoke directly with args:

        .. code-block:: python

            tool.invoke({'query': 'who won the last french open'})

        .. code-block:: python

            '{\n "url": "https://www.nytimes.com...", "content": "Novak Djokovic won the last French Open by beating Casper Ruud ...'

    Invoke with tool call:

        .. code-block:: python

            tool.invoke({"args": {'query': 'who won the last french open'}, "type": "tool_call", "id": "foo", "name": "tavily"})

        .. code-block:: python

            ToolMessage(
                content='{\n "url": "https://www.nytimes.com...", "content": "Novak Djokovic won the last French Open by beating Casper Ruud ...',
                artifact={
                    'query': 'who won the last french open',
                    'follow_up_questions': None,
                    'answer': 'Novak ...',
                    'images': [
                        'https://www.amny.com/wp-content/uploads/2023/06/AP23162622181176-1200x800.jpg',
                        ...
                    ],
                    'results': [
                        {
                            'title': 'Djokovic ...',
                            'url': 'https://www.nytimes.com...',
                            'content': "Novak...",
                            'score': 0.99505633,
                            'raw_content': 'Tennis\nNovak ...'
                        },
                        ...
                    ],
                    'response_time': 2.92
                },
                tool_call_id='foo',
                name='tavily_search_results_json',
            )
    """  # noqa: E501

    name: str = "tavily_search_results_json"
    description: str = (
        "A search engine optimized for comprehensive, accurate, and trusted results. "
        "Useful for when you need to answer questions about current events. "
        "Input should be a search query."
    )
    args_schema: Type[BaseModel] = TavilyInput
    max_results: int = 5
    """Max search results to return, default is 5"""
    search_depth: str = "advanced"
    """The depth of the search. It can be "basic" or "advanced". Default is "advanced".

    .. versionadded:: 0.2.5
    """
    # default_factory gives each instance its own fresh list.
    include_domains: List[str] = Field(default_factory=list)
    """A list of domains to specifically include in the search results.

    Default is an empty list, which includes all domains.

    .. versionadded:: 0.2.5
    """
    exclude_domains: List[str] = Field(default_factory=list)
    """A list of domains to specifically exclude from the search results.

    Default is an empty list, which doesn't exclude any domains.

    .. versionadded:: 0.2.5
    """
    include_answer: bool = False
    """Include a short answer to original query in the search results.

    Default is False.

    .. versionadded:: 0.2.5
    """
    include_raw_content: bool = False
    """Include cleaned and parsed HTML of each site search results.

    Default is False.

    .. versionadded:: 0.2.5
    """
    include_images: bool = False
    """Include a list of query related images in the response.

    Default is False.

    .. versionadded:: 0.2.5
    """
    api_wrapper: TavilySearchAPIWrapper = Field(default_factory=TavilySearchAPIWrapper)  # type: ignore[arg-type]
    response_format: Literal["content_and_artifact"] = "content_and_artifact"
    """The tool response format."""

    def _run(
        self,
        query: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> Tuple[Union[List[Dict[str, str]], str], Dict]:
        """Use the tool.

        Args:
            query: The search query.
            run_manager: Optional callback manager; not used here.

        Returns:
            A ``(content, artifact)`` pair. ``content`` is
            ``api_wrapper.clean_results(raw["results"])`` and ``artifact`` is
            the raw response dict. If the wrapper call raises, the pair is
            ``(repr(exception), {})``.
        """
        # TODO: remove try/except, should be handled by BaseTool
        try:
            raw_results = self.api_wrapper.raw_results(
                query,
                self.max_results,
                self.search_depth,
                self.include_domains,
                self.exclude_domains,
                self.include_answer,
                self.include_raw_content,
                self.include_images,
            )
        except Exception as e:
            return repr(e), {}
        return self.api_wrapper.clean_results(raw_results["results"]), raw_results

    async def _arun(
        self,
        query: str,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> Tuple[Union[List[Dict[str, str]], str], Dict]:
        """Use the tool asynchronously.

        Same contract as ``_run``, but uses ``api_wrapper.raw_results_async``.
        """
        # TODO: remove try/except, should be handled by BaseTool
        try:
            raw_results = await self.api_wrapper.raw_results_async(
                query,
                self.max_results,
                self.search_depth,
                self.include_domains,
                self.exclude_domains,
                self.include_answer,
                self.include_raw_content,
                self.include_images,
            )
        except Exception as e:
            return repr(e), {}
        return self.api_wrapper.clean_results(raw_results["results"]), raw_results
class TavilyAnswer(BaseTool):
    """Tool that queries the Tavily Search API and returns only its answer.

    Every search is issued with ``max_results=5``, ``include_answer=True`` and
    ``search_depth="basic"``. Only the ``"answer"`` entry of the response is
    returned, and the source documents are discarded.
    """

    name: str = "tavily_answer"
    description: str = (
        "A search engine optimized for comprehensive, accurate, and trusted results. "
        "Useful for when you need to answer questions about current events. "
        "Input should be a search query. "
        "This returns only the answer - not the original source data."
    )
    api_wrapper: TavilySearchAPIWrapper = Field(default_factory=TavilySearchAPIWrapper)  # type: ignore[arg-type]
    args_schema: Type[BaseModel] = TavilyInput

    def _run(
        self,
        query: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> Union[List[Dict], str]:
        """Run the search synchronously and return the ``"answer"`` entry.

        Any exception, whether from the wrapper call or the ``"answer"``
        lookup, is caught and its ``repr`` is returned instead.
        """
        try:
            response = self.api_wrapper.raw_results(
                query,
                max_results=5,
                include_answer=True,
                search_depth="basic",
            )
            return response["answer"]
        except Exception as err:
            return repr(err)

    async def _arun(
        self,
        query: str,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> Union[List[Dict], str]:
        """Async variant of ``_run`` built on ``raw_results_async``."""
        try:
            return (
                await self.api_wrapper.raw_results_async(
                    query,
                    max_results=5,
                    include_answer=True,
                    search_depth="basic",
                )
            )["answer"]
        except Exception as err:
            return repr(err)