mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
e5cf1e2414
So the API keys don't show up in reprs. Still need to do tests.
184 lines
6.7 KiB
Python
184 lines
6.7 KiB
Python
"""Util that calls Tavily Search API.
|
|
|
|
In order to set this up, follow instructions at:
|
|
"""
|
|
import json
|
|
from typing import Dict, List, Optional
|
|
|
|
import aiohttp
|
|
import requests
|
|
from langchain_core.pydantic_v1 import BaseModel, Extra, SecretStr, root_validator
|
|
from langchain_core.utils import get_from_dict_or_env
|
|
|
|
# Base URL of the Tavily API; endpoint paths such as "/search" are appended to it.
TAVILY_API_URL = "https://api.tavily.com"
|
|
|
|
|
|
class TavilySearchAPIWrapper(BaseModel):
    """Wrapper for Tavily Search API.

    The API key is held as a ``SecretStr`` so that it is masked when the
    wrapper is printed or shown via ``repr``. It is read from the
    ``tavily_api_key`` argument or, failing that, from the ``TAVILY_API_KEY``
    environment variable (see ``validate_environment``).
    """

    tavily_api_key: SecretStr

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @root_validator(pre=True)
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the API key exists in the arguments or environment."""
        tavily_api_key = get_from_dict_or_env(
            values, "tavily_api_key", "TAVILY_API_KEY"
        )
        values["tavily_api_key"] = tavily_api_key

        return values

    def _build_params(
        self,
        query: str,
        max_results: Optional[int],
        search_depth: Optional[str],
        include_domains: Optional[List[str]],
        exclude_domains: Optional[List[str]],
        include_answer: Optional[bool],
        include_raw_content: Optional[bool],
        include_images: Optional[bool],
    ) -> Dict:
        """Build the JSON payload shared by the sync and async search calls."""
        return {
            # The secret is only unwrapped here, right before it is sent.
            "api_key": self.tavily_api_key.get_secret_value(),
            "query": query,
            "max_results": max_results,
            "search_depth": search_depth,
            # ``None`` means "no filter"; the API is sent an empty list.
            "include_domains": [] if include_domains is None else include_domains,
            "exclude_domains": [] if exclude_domains is None else exclude_domains,
            "include_answer": include_answer,
            "include_raw_content": include_raw_content,
            "include_images": include_images,
        }

    def raw_results(
        self,
        query: str,
        max_results: Optional[int] = 5,
        search_depth: Optional[str] = "advanced",
        include_domains: Optional[List[str]] = None,
        exclude_domains: Optional[List[str]] = None,
        include_answer: Optional[bool] = False,
        include_raw_content: Optional[bool] = False,
        include_images: Optional[bool] = False,
    ) -> Dict:
        """Run query through Tavily Search and return the raw JSON response.

        Args:
            query: The query to search for.
            max_results: The maximum number of results to return.
            search_depth: The depth of the search. Can be "basic" or "advanced".
            include_domains: A list of domains to include in the search.
                ``None`` (the default) is sent as an empty list.
            exclude_domains: A list of domains to exclude from the search.
                ``None`` (the default) is sent as an empty list.
            include_answer: Whether to include the answer in the results.
            include_raw_content: Whether to include the raw content in the
                results.
            include_images: Whether to include images in the results.

        Returns:
            The decoded JSON body of the response. As documented by this
            wrapper, it is expected to contain:
                query: The query that was searched for.
                follow_up_questions: A list of follow up questions.
                response_time: The response time of the query.
                answer: The answer to the query.
                images: A list of images.
                results: A list of dictionaries containing the results:
                    title: The title of the result.
                    url: The url of the result.
                    content: The content of the result.
                    score: The score of the result.
                    raw_content: The raw content of the result.

        Raises:
            requests.HTTPError: If the API answers with an error status.
        """
        params = self._build_params(
            query,
            max_results,
            search_depth,
            include_domains,
            exclude_domains,
            include_answer,
            include_raw_content,
            include_images,
        )
        response = requests.post(  # type: ignore
            f"{TAVILY_API_URL}/search",
            json=params,
        )
        response.raise_for_status()
        return response.json()

    def results(
        self,
        query: str,
        max_results: Optional[int] = 5,
        search_depth: Optional[str] = "advanced",
        include_domains: Optional[List[str]] = None,
        exclude_domains: Optional[List[str]] = None,
        include_answer: Optional[bool] = False,
        include_raw_content: Optional[bool] = False,
        include_images: Optional[bool] = False,
    ) -> List[Dict]:
        """Run query through Tavily Search and return cleaned results.

        Args:
            query: The query to search for.
            max_results: The maximum number of results to return.
            search_depth: The depth of the search. Can be "basic" or "advanced".
            include_domains: A list of domains to include in the search.
            exclude_domains: A list of domains to exclude from the search.
            include_answer: Whether to include the answer in the results.
            include_raw_content: Whether to include the raw content in the
                results.
            include_images: Whether to include images in the results.

        Returns:
            A list with one dictionary per entry of the response's
            ``"results"`` field, each reduced to:
                url: The url of the result.
                content: The content of the result.
        """
        raw_search_results = self.raw_results(
            query,
            max_results=max_results,
            search_depth=search_depth,
            include_domains=include_domains,
            exclude_domains=exclude_domains,
            include_answer=include_answer,
            include_raw_content=include_raw_content,
            include_images=include_images,
        )
        return self.clean_results(raw_search_results["results"])

    async def raw_results_async(
        self,
        query: str,
        max_results: Optional[int] = 5,
        search_depth: Optional[str] = "advanced",
        include_domains: Optional[List[str]] = None,
        exclude_domains: Optional[List[str]] = None,
        include_answer: Optional[bool] = False,
        include_raw_content: Optional[bool] = False,
        include_images: Optional[bool] = False,
    ) -> Dict:
        """Get results from the Tavily Search API asynchronously.

        Takes the same arguments as ``raw_results`` and returns the decoded
        JSON body of the response.

        Raises:
            Exception: If the response status is not 200; the message carries
                the status code and reason.
        """
        params = self._build_params(
            query,
            max_results,
            search_depth,
            include_domains,
            exclude_domains,
            include_answer,
            include_raw_content,
            include_images,
        )
        async with aiohttp.ClientSession() as session:
            async with session.post(f"{TAVILY_API_URL}/search", json=params) as res:
                if res.status != 200:
                    raise Exception(f"Error {res.status}: {res.reason}")
                # Read the body while the connection is still open.
                results_json_str = await res.text()
        return json.loads(results_json_str)

    async def results_async(
        self,
        query: str,
        max_results: Optional[int] = 5,
        search_depth: Optional[str] = "advanced",
        include_domains: Optional[List[str]] = None,
        exclude_domains: Optional[List[str]] = None,
        include_answer: Optional[bool] = False,
        include_raw_content: Optional[bool] = False,
        include_images: Optional[bool] = False,
    ) -> List[Dict]:
        """Asynchronous counterpart of ``results``.

        Takes the same arguments as ``results`` and returns the same cleaned
        list of ``url``/``content`` dictionaries.
        """
        results_json = await self.raw_results_async(
            query=query,
            max_results=max_results,
            search_depth=search_depth,
            include_domains=include_domains,
            exclude_domains=exclude_domains,
            include_answer=include_answer,
            include_raw_content=include_raw_content,
            include_images=include_images,
        )
        return self.clean_results(results_json["results"])

    def clean_results(self, results: List[Dict]) -> List[Dict]:
        """Clean results from Tavily Search API.

        Keeps only the ``url`` and ``content`` keys of every result.

        Raises:
            KeyError: If a result lacks ``url`` or ``content``.
        """
        return [
            {
                "url": result["url"],
                "content": result["content"],
            }
            for result in results
        ]
|