mirror of https://github.com/hwchase17/langchain
Add Tavily Search API as a Tool (#12103)
Adding Tavily Search API as a tool. I will be the maintainer, and assaf_elovic is the Twitter handle. --------- Co-authored-by: Bagatur <baskaryan@gmail.com> pull/12108/head
parent
85302a9ec1
commit
78d186fb44
@ -0,0 +1,5 @@
|
||||
"""Tavily Search API toolkit."""
|
||||
|
||||
from langchain.tools.tavily_search.tool import TavilySearchResults
|
||||
|
||||
__all__ = ["TavilySearchResults"]
|
@ -0,0 +1,51 @@
|
||||
"""Tool for the Tavily search API."""
|
||||
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForToolRun,
|
||||
CallbackManagerForToolRun,
|
||||
)
|
||||
from langchain.tools.base import BaseTool
|
||||
from langchain.utilities.tavily_search import TavilySearchAPIWrapper
|
||||
|
||||
|
||||
class TavilySearchResults(BaseTool):
    """Tool that queries the Tavily Search API and gets back json.

    The search itself is delegated to ``api_wrapper``; this class forwards the
    query together with ``max_results`` and, instead of raising, returns the
    ``repr`` of any exception so the caller always receives a value.
    """

    name: str = "tavily_search_results_json"
    # Adjacent string literals are concatenated at compile time, so the
    # description is a single line of plain text. (A triple-quoted string would
    # embed the surrounding quote characters and newlines in the value.)
    description: str = (
        "A search engine optimized for comprehensive, accurate, and trusted results. "
        "Useful for when you need to answer questions about current events. "
        "Input should be a search query."
    )
    api_wrapper: TavilySearchAPIWrapper
    max_results: int = 5

    def _run(
        self,
        query: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> Union[List[Dict], str]:
        """Use the tool.

        Args:
            query: The search query forwarded to ``api_wrapper.results``.
            run_manager: Accepted for interface compatibility; not used here.

        Returns:
            Whatever ``api_wrapper.results`` returns, or ``repr(exception)``
            as a string if that call raises.
        """
        try:
            return self.api_wrapper.results(
                query,
                self.max_results,
            )
        except Exception as e:
            # Deliberate best-effort: report the failure as text rather than
            # propagating it to the caller.
            return repr(e)

    async def _arun(
        self,
        query: str,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> Union[List[Dict], str]:
        """Use the tool asynchronously.

        Args:
            query: The search query forwarded to ``api_wrapper.results_async``.
            run_manager: Accepted for interface compatibility; not used here.

        Returns:
            Whatever ``api_wrapper.results_async`` returns, or
            ``repr(exception)`` as a string if that call raises.
        """
        try:
            return await self.api_wrapper.results_async(
                query,
                self.max_results,
            )
        except Exception as e:
            # Same best-effort error reporting as the synchronous path.
            return repr(e)
|
@ -0,0 +1,167 @@
|
||||
"""Util that calls Tavily Search API.
|
||||
|
||||
In order to set this up, supply a Tavily API key via the ``tavily_api_key``
argument or the ``TAVILY_API_KEY`` environment variable.
|
||||
"""
|
||||
import json
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import aiohttp
|
||||
import requests
|
||||
|
||||
from langchain.pydantic_v1 import BaseModel, Extra, root_validator
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
TAVILY_API_URL = "https://api.tavily.com"
|
||||
|
||||
|
||||
class TavilySearchAPIWrapper(BaseModel):
    """Wrapper for Tavily Search API.

    Posts search requests to ``{TAVILY_API_URL}/search`` (synchronously via
    ``requests`` or asynchronously via ``aiohttp``) and reduces every hit in
    the response to its ``url`` and ``content``.
    """

    tavily_api_key: str

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @root_validator(pre=True)
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the api key exists in the values or the environment.

        The key is looked up with ``get_from_dict_or_env`` under the name
        ``tavily_api_key`` / ``TAVILY_API_KEY`` and stored back into
        ``values`` before field validation runs.
        """
        values["tavily_api_key"] = get_from_dict_or_env(
            values, "tavily_api_key", "TAVILY_API_KEY"
        )
        return values

    def _search_payload(
        self,
        query: str,
        max_results: Optional[int],
        search_depth: Optional[str],
        include_domains: Optional[List[str]],
        exclude_domains: Optional[List[str]],
        include_answer: Optional[bool],
        include_raw_content: Optional[bool],
        include_images: Optional[bool],
    ) -> Dict:
        """Build the JSON body shared by the sync and async search requests.

        The arguments are only read and placed in the dict as-is; in
        particular the list arguments are never mutated here.
        """
        return {
            "api_key": self.tavily_api_key,
            "query": query,
            "max_results": max_results,
            "search_depth": search_depth,
            "include_domains": include_domains,
            "exclude_domains": exclude_domains,
            "include_answer": include_answer,
            "include_raw_content": include_raw_content,
            "include_images": include_images,
        }

    def _tavily_search_results(
        self,
        query: str,
        max_results: Optional[int] = 5,
        search_depth: Optional[str] = "advanced",
        include_domains: Optional[List[str]] = [],
        exclude_domains: Optional[List[str]] = [],
        include_answer: Optional[bool] = False,
        include_raw_content: Optional[bool] = False,
        include_images: Optional[bool] = False,
    ) -> List[dict]:
        """POST the search request and return the cleaned ``results`` list.

        Raises:
            requests.HTTPError: via ``raise_for_status`` on a 4xx/5xx response.
            KeyError: if the response JSON has no ``"results"`` entry.
        """
        payload = self._search_payload(
            query,
            max_results,
            search_depth,
            include_domains,
            exclude_domains,
            include_answer,
            include_raw_content,
            include_images,
        )
        # NOTE(review): no timeout is passed, so this call waits as long as
        # the server takes to answer.
        response = requests.post(  # type: ignore
            f"{TAVILY_API_URL}/search",
            json=payload,
        )
        response.raise_for_status()
        return self.clean_results(response.json()["results"])

    def results(
        self,
        query: str,
        max_results: Optional[int] = 5,
        search_depth: Optional[str] = "advanced",
        include_domains: Optional[List[str]] = [],
        exclude_domains: Optional[List[str]] = [],
        include_answer: Optional[bool] = False,
        include_raw_content: Optional[bool] = False,
        include_images: Optional[bool] = False,
    ) -> List[Dict]:
        """Run query through Tavily Search and return the cleaned results.

        Args:
            query: The query to search for.
            max_results: The maximum number of results to return.
            search_depth: The depth of the search. Can be "basic" or "advanced".
            include_domains: A list of domains to include in the search.
            exclude_domains: A list of domains to exclude from the search.
            include_answer: Whether to ask the API to include the answer.
            include_raw_content: Whether to ask the API to include raw content.
            include_images: Whether to ask the API to include images.

        Returns:
            A list with one dictionary per entry of the response's
            ``results`` field, each containing only:
                url: The url of the result.
                content: The content of the result.
            Any other data in the response (including what the ``include_*``
            flags request) is dropped by ``clean_results``.
        """
        return self._tavily_search_results(
            query,
            max_results,
            search_depth,
            include_domains,
            exclude_domains,
            include_answer,
            include_raw_content,
            include_images,
        )

    async def results_async(
        self,
        query: str,
        max_results: Optional[int] = 5,
        search_depth: Optional[str] = "advanced",
        include_domains: Optional[List[str]] = [],
        exclude_domains: Optional[List[str]] = [],
        include_answer: Optional[bool] = False,
        include_raw_content: Optional[bool] = False,
        include_images: Optional[bool] = False,
    ) -> List[Dict]:
        """Get results from the Tavily Search API asynchronously.

        Takes the same arguments and returns the same cleaned
        ``{"url", "content"}`` dictionaries as ``results``.

        Raises:
            Exception: with message ``"Error <status>: <reason>"`` when the
                response status is not 200.
        """
        payload = self._search_payload(
            query,
            max_results,
            search_depth,
            include_domains,
            exclude_domains,
            include_answer,
            include_raw_content,
            include_images,
        )
        async with aiohttp.ClientSession() as session:
            async with session.post(f"{TAVILY_API_URL}/search", json=payload) as res:
                if res.status != 200:
                    raise Exception(f"Error {res.status}: {res.reason}")
                # Read the body while the response is still open.
                body = await res.text()

        return self.clean_results(json.loads(body)["results"])

    def clean_results(self, results: List[Dict]) -> List[Dict]:
        """Clean results from Tavily Search API.

        Keeps only the ``url`` and ``content`` of each result; a result
        lacking either key raises ``KeyError``.
        """
        return [
            {"url": result["url"], "content": result["content"]}
            for result in results
        ]
|
Loading…
Reference in New Issue