|
|
|
@ -24,7 +24,17 @@ class TavilySearchAPIWrapper(BaseModel):
|
|
|
|
|
|
|
|
|
|
extra = Extra.forbid
|
|
|
|
|
|
|
|
|
|
def _tavily_search_results(
|
|
|
|
|
@root_validator(pre=True)
|
|
|
|
|
def validate_environment(cls, values: Dict) -> Dict:
|
|
|
|
|
"""Validate that api key and endpoint exists in environment."""
|
|
|
|
|
tavily_api_key = get_from_dict_or_env(
|
|
|
|
|
values, "tavily_api_key", "TAVILY_API_KEY"
|
|
|
|
|
)
|
|
|
|
|
values["tavily_api_key"] = tavily_api_key
|
|
|
|
|
|
|
|
|
|
return values
|
|
|
|
|
|
|
|
|
|
def raw_results(
|
|
|
|
|
self,
|
|
|
|
|
query: str,
|
|
|
|
|
max_results: Optional[int] = 5,
|
|
|
|
@ -34,7 +44,7 @@ class TavilySearchAPIWrapper(BaseModel):
|
|
|
|
|
include_answer: Optional[bool] = False,
|
|
|
|
|
include_raw_content: Optional[bool] = False,
|
|
|
|
|
include_images: Optional[bool] = False,
|
|
|
|
|
) -> List[dict]:
|
|
|
|
|
) -> Dict:
|
|
|
|
|
params = {
|
|
|
|
|
"api_key": self.tavily_api_key,
|
|
|
|
|
"query": query,
|
|
|
|
@ -51,20 +61,8 @@ class TavilySearchAPIWrapper(BaseModel):
|
|
|
|
|
f"{TAVILY_API_URL}/search",
|
|
|
|
|
json=params,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
response.raise_for_status()
|
|
|
|
|
search_results = response.json()
|
|
|
|
|
return self.clean_results(search_results["results"])
|
|
|
|
|
|
|
|
|
|
@root_validator(pre=True)
|
|
|
|
|
def validate_environment(cls, values: Dict) -> Dict:
|
|
|
|
|
"""Validate that api key and endpoint exists in environment."""
|
|
|
|
|
tavily_api_key = get_from_dict_or_env(
|
|
|
|
|
values, "tavily_api_key", "TAVILY_API_KEY"
|
|
|
|
|
)
|
|
|
|
|
values["tavily_api_key"] = tavily_api_key
|
|
|
|
|
|
|
|
|
|
return values
|
|
|
|
|
return response.json()
|
|
|
|
|
|
|
|
|
|
def results(
|
|
|
|
|
self,
|
|
|
|
@ -88,7 +86,6 @@ class TavilySearchAPIWrapper(BaseModel):
|
|
|
|
|
include_answer: Whether to include the answer in the results.
|
|
|
|
|
include_raw_content: Whether to include the raw content in the results.
|
|
|
|
|
include_images: Whether to include images in the results.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
query: The query that was searched for.
|
|
|
|
|
follow_up_questions: A list of follow up questions.
|
|
|
|
@ -101,22 +98,20 @@ class TavilySearchAPIWrapper(BaseModel):
|
|
|
|
|
content: The content of the result.
|
|
|
|
|
score: The score of the result.
|
|
|
|
|
raw_content: The raw content of the result.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" # noqa: E501
|
|
|
|
|
raw_search_results = self._tavily_search_results(
|
|
|
|
|
raw_search_results = self.raw_results(
|
|
|
|
|
query,
|
|
|
|
|
max_results,
|
|
|
|
|
search_depth,
|
|
|
|
|
include_domains,
|
|
|
|
|
exclude_domains,
|
|
|
|
|
include_answer,
|
|
|
|
|
include_raw_content,
|
|
|
|
|
include_images,
|
|
|
|
|
max_results=max_results,
|
|
|
|
|
search_depth=search_depth,
|
|
|
|
|
include_domains=include_domains,
|
|
|
|
|
exclude_domains=exclude_domains,
|
|
|
|
|
include_answer=include_answer,
|
|
|
|
|
include_raw_content=include_raw_content,
|
|
|
|
|
include_images=include_images,
|
|
|
|
|
)
|
|
|
|
|
return raw_search_results
|
|
|
|
|
return self.clean_results(raw_search_results["results"])
|
|
|
|
|
|
|
|
|
|
async def results_async(
|
|
|
|
|
async def raw_results_async(
|
|
|
|
|
self,
|
|
|
|
|
query: str,
|
|
|
|
|
max_results: Optional[int] = 5,
|
|
|
|
@ -126,7 +121,7 @@ class TavilySearchAPIWrapper(BaseModel):
|
|
|
|
|
include_answer: Optional[bool] = False,
|
|
|
|
|
include_raw_content: Optional[bool] = False,
|
|
|
|
|
include_images: Optional[bool] = False,
|
|
|
|
|
) -> List[Dict]:
|
|
|
|
|
) -> Dict:
|
|
|
|
|
"""Get results from the Tavily Search API asynchronously."""
|
|
|
|
|
|
|
|
|
|
# Function to perform the API call
|
|
|
|
@ -151,7 +146,29 @@ class TavilySearchAPIWrapper(BaseModel):
|
|
|
|
|
raise Exception(f"Error {res.status}: {res.reason}")
|
|
|
|
|
|
|
|
|
|
results_json_str = await fetch()
|
|
|
|
|
results_json = json.loads(results_json_str)
|
|
|
|
|
return json.loads(results_json_str)
|
|
|
|
|
|
|
|
|
|
async def results_async(
|
|
|
|
|
self,
|
|
|
|
|
query: str,
|
|
|
|
|
max_results: Optional[int] = 5,
|
|
|
|
|
search_depth: Optional[str] = "advanced",
|
|
|
|
|
include_domains: Optional[List[str]] = [],
|
|
|
|
|
exclude_domains: Optional[List[str]] = [],
|
|
|
|
|
include_answer: Optional[bool] = False,
|
|
|
|
|
include_raw_content: Optional[bool] = False,
|
|
|
|
|
include_images: Optional[bool] = False,
|
|
|
|
|
) -> List[Dict]:
|
|
|
|
|
results_json = await self.raw_results_async(
|
|
|
|
|
query=query,
|
|
|
|
|
max_results=max_results,
|
|
|
|
|
search_depth=search_depth,
|
|
|
|
|
include_domains=include_domains,
|
|
|
|
|
exclude_domains=exclude_domains,
|
|
|
|
|
include_answer=include_answer,
|
|
|
|
|
include_raw_content=include_raw_content,
|
|
|
|
|
include_images=include_images,
|
|
|
|
|
)
|
|
|
|
|
return self.clean_results(results_json["results"])
|
|
|
|
|
|
|
|
|
|
def clean_results(self, results: List[Dict]) -> List[Dict]:
|
|
|
|
|