mirror of
https://github.com/hwchase17/langchain
synced 2024-11-10 01:10:59 +00:00
c2a3021bb0
Signed-off-by: ChengZi <chen.zhang@zilliz.com> Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com> Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com> Co-authored-by: Dan O'Donovan <dan.odonovan@gmail.com> Co-authored-by: Tom Daniel Grande <tomdgrande@gmail.com> Co-authored-by: Grande <Tom.Daniel.Grande@statsbygg.no> Co-authored-by: Bagatur <baskaryan@gmail.com> Co-authored-by: ccurme <chester.curme@gmail.com> Co-authored-by: Harrison Chase <hw.chase.17@gmail.com> Co-authored-by: Tomaz Bratanic <bratanic.tomaz@gmail.com> Co-authored-by: ZhangShenao <15201440436@163.com> Co-authored-by: Friso H. Kingma <fhkingma@gmail.com> Co-authored-by: ChengZi <chen.zhang@zilliz.com> Co-authored-by: Nuno Campos <nuno@langchain.dev> Co-authored-by: Morgante Pell <morgantep@google.com>
186 lines
6.7 KiB
Python
186 lines
6.7 KiB
Python
"""Util that calls Tavily Search API.
|
|
|
|
In order to set this up, follow instructions at:
|
|
https://docs.tavily.com/docs/tavily-api/introduction
|
|
"""
|
|
|
|
import json
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
import aiohttp
|
|
import requests
|
|
from langchain_core.utils import get_from_dict_or_env
|
|
from pydantic import BaseModel, ConfigDict, SecretStr, model_validator
|
|
|
|
# Root of the Tavily REST API; the wrapper appends endpoint paths such as
# "/search" to this value when issuing requests.
TAVILY_API_URL: str = "https://api.tavily.com"
|
|
|
|
|
|
class TavilySearchAPIWrapper(BaseModel):
    """Wrapper for Tavily Search API.

    The API key is resolved at construction time from the ``tavily_api_key``
    argument or, via ``get_from_dict_or_env``, the ``TAVILY_API_KEY``
    environment variable.
    """

    tavily_api_key: SecretStr

    # Reject unknown constructor keyword arguments instead of ignoring them.
    model_config = ConfigDict(
        extra="forbid",
    )

    @model_validator(mode="before")
    @classmethod
    def validate_environment(cls, values: Dict) -> Any:
        """Populate ``tavily_api_key`` from the init kwargs or the environment.

        Runs before field validation (``mode="before"``), so the raw value
        placed here is what the ``tavily_api_key`` field then validates.
        """
        tavily_api_key = get_from_dict_or_env(
            values, "tavily_api_key", "TAVILY_API_KEY"
        )
        values["tavily_api_key"] = tavily_api_key

        return values

    def _build_search_params(
        self,
        query: str,
        max_results: Optional[int],
        search_depth: Optional[str],
        include_domains: Optional[List[str]],
        exclude_domains: Optional[List[str]],
        include_answer: Optional[bool],
        include_raw_content: Optional[bool],
        include_images: Optional[bool],
    ) -> Dict[str, Any]:
        """Assemble the JSON body sent to the ``/search`` endpoint.

        Shared by the sync and async code paths so both send identical
        payloads. ``None`` domain filters are sent as empty lists, which is
        what the previous ``[]`` defaults sent.
        """
        return {
            "api_key": self.tavily_api_key.get_secret_value(),
            "query": query,
            "max_results": max_results,
            "search_depth": search_depth,
            "include_domains": [] if include_domains is None else include_domains,
            "exclude_domains": [] if exclude_domains is None else exclude_domains,
            "include_answer": include_answer,
            "include_raw_content": include_raw_content,
            "include_images": include_images,
        }

    def raw_results(
        self,
        query: str,
        max_results: Optional[int] = 5,
        search_depth: Optional[str] = "advanced",
        include_domains: Optional[List[str]] = None,
        exclude_domains: Optional[List[str]] = None,
        include_answer: Optional[bool] = False,
        include_raw_content: Optional[bool] = False,
        include_images: Optional[bool] = False,
    ) -> Dict:
        """Run a query against Tavily Search and return the full JSON response.

        Args:
            query: The query to search for.
            max_results: The maximum number of results to return.
            search_depth: The depth of the search. Can be "basic" or "advanced".
            include_domains: Domains to restrict the search to (None = no filter).
            exclude_domains: Domains to exclude from the search (None = no filter).
            include_answer: Whether to include an answer in the response.
            include_raw_content: Whether to include raw page content per result.
            include_images: Whether to include images in the response.

        Returns:
            The decoded response body, including:
                query: The query that was searched for.
                follow_up_questions: A list of follow up questions.
                response_time: The response time of the query.
                answer: The answer to the query.
                images: A list of images.
                results: A list of dictionaries containing the results:
                    title: The title of the result.
                    url: The url of the result.
                    content: The content of the result.
                    score: The score of the result.
                    raw_content: The raw content of the result.

        Raises:
            requests.HTTPError: If the API responds with an error status.
        """
        params = self._build_search_params(
            query,
            max_results,
            search_depth,
            include_domains,
            exclude_domains,
            include_answer,
            include_raw_content,
            include_images,
        )
        # NOTE(review): no timeout is passed, so a stalled connection can block
        # this call indefinitely; consider exposing a timeout option.
        response = requests.post(  # type: ignore
            f"{TAVILY_API_URL}/search",
            json=params,
        )
        response.raise_for_status()
        return response.json()

    def results(
        self,
        query: str,
        max_results: Optional[int] = 5,
        search_depth: Optional[str] = "advanced",
        include_domains: Optional[List[str]] = None,
        exclude_domains: Optional[List[str]] = None,
        include_answer: Optional[bool] = False,
        include_raw_content: Optional[bool] = False,
        include_images: Optional[bool] = False,
    ) -> List[Dict]:
        """Run query through Tavily Search and return cleaned results.

        Args:
            query: The query to search for.
            max_results: The maximum number of results to return.
            search_depth: The depth of the search. Can be "basic" or "advanced".
            include_domains: A list of domains to include in the search.
            exclude_domains: A list of domains to exclude from the search.
            include_answer: Whether to include the answer in the response.
            include_raw_content: Whether to include the raw content in the
                response.
            include_images: Whether to include images in the response.

        Note:
            ``include_answer``, ``include_raw_content`` and ``include_images``
            only affect the raw API response; the cleaned output below keeps
            only ``url`` and ``content``. Use ``raw_results`` to get them.

        Returns:
            A list of dictionaries, one per search result, each containing:
                url: The url of the result.
                content: The content of the result.
        """
        raw_search_results = self.raw_results(
            query,
            max_results=max_results,
            search_depth=search_depth,
            include_domains=include_domains,
            exclude_domains=exclude_domains,
            include_answer=include_answer,
            include_raw_content=include_raw_content,
            include_images=include_images,
        )
        return self.clean_results(raw_search_results["results"])

    async def raw_results_async(
        self,
        query: str,
        max_results: Optional[int] = 5,
        search_depth: Optional[str] = "advanced",
        include_domains: Optional[List[str]] = None,
        exclude_domains: Optional[List[str]] = None,
        include_answer: Optional[bool] = False,
        include_raw_content: Optional[bool] = False,
        include_images: Optional[bool] = False,
    ) -> Dict:
        """Get results from the Tavily Search API asynchronously.

        Arguments and return value are the same as for ``raw_results``.

        Raises:
            Exception: If the API responds with any status other than 200.
        """
        params = self._build_search_params(
            query,
            max_results,
            search_depth,
            include_domains,
            exclude_domains,
            include_answer,
            include_raw_content,
            include_images,
        )
        async with aiohttp.ClientSession() as session:
            async with session.post(f"{TAVILY_API_URL}/search", json=params) as res:
                # Only HTTP 200 counts as success on this path.
                if res.status != 200:
                    raise Exception(f"Error {res.status}: {res.reason}")
                results_json_str = await res.text()
        return json.loads(results_json_str)

    async def results_async(
        self,
        query: str,
        max_results: Optional[int] = 5,
        search_depth: Optional[str] = "advanced",
        include_domains: Optional[List[str]] = None,
        exclude_domains: Optional[List[str]] = None,
        include_answer: Optional[bool] = False,
        include_raw_content: Optional[bool] = False,
        include_images: Optional[bool] = False,
    ) -> List[Dict]:
        """Async variant of ``results``: search and return cleaned results.

        Arguments and return value are the same as for ``results``.
        """
        results_json = await self.raw_results_async(
            query=query,
            max_results=max_results,
            search_depth=search_depth,
            include_domains=include_domains,
            exclude_domains=exclude_domains,
            include_answer=include_answer,
            include_raw_content=include_raw_content,
            include_images=include_images,
        )
        return self.clean_results(results_json["results"])

    def clean_results(self, results: List[Dict]) -> List[Dict]:
        """Clean results from Tavily Search API.

        Keeps only the ``url`` and ``content`` fields of each result, in order.
        """
        return [
            {"url": result["url"], "content": result["content"]}
            for result in results
        ]
|