Harrison/serper api bug (#4902)

Co-authored-by: Jerry Luan <xmaswillyou@gmail.com>
This commit is contained in:
Harrison Chase 2023-05-17 21:40:39 -07:00 committed by GitHub
parent c998569c8f
commit 9e2227ba11
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 24 additions and 7 deletions

View File

@ -5,6 +5,7 @@ import aiohttp
import requests
from pydantic.class_validators import root_validator
from pydantic.main import BaseModel
from typing_extensions import Literal
from langchain.utils import get_from_dict_or_env
@ -28,7 +29,16 @@ class GoogleSerperAPIWrapper(BaseModel):
k: int = 10
gl: str = "us"
hl: str = "en"
type: str = "search" # search, images, places, news
# "places" and "images" is available from Serper but not implemented in the
# parser of run(). They can be used in results()
type: Literal["news", "search", "places", "images"] = "search"
result_key_for_type = {
"news": "news",
"places": "places",
"images": "images",
"search": "organic",
}
tbs: Optional[str] = None
serper_api_key: Optional[str] = None
aiosession: Optional[aiohttp.ClientSession] = None
@ -50,7 +60,7 @@ class GoogleSerperAPIWrapper(BaseModel):
def results(self, query: str, **kwargs: Any) -> Dict:
"""Run query through GoogleSearch."""
return self._google_serper_search_results(
return self._google_serper_api_results(
query,
gl=self.gl,
hl=self.hl,
@ -62,7 +72,7 @@ class GoogleSerperAPIWrapper(BaseModel):
def run(self, query: str, **kwargs: Any) -> str:
"""Run query through GoogleSearch and parse result."""
results = self._google_serper_search_results(
results = self._google_serper_api_results(
query,
gl=self.gl,
hl=self.hl,
@ -125,7 +135,7 @@ class GoogleSerperAPIWrapper(BaseModel):
for attribute, value in kg.get("attributes", {}).items():
snippets.append(f"{title} {attribute}: {value}.")
for result in results["organic"][: self.k]:
for result in results[self.result_key_for_type[self.type]][: self.k]:
if "snippet" in result:
snippets.append(result["snippet"])
for attribute, value in result.get("attributes", {}).items():
@ -138,7 +148,7 @@ class GoogleSerperAPIWrapper(BaseModel):
def _parse_results(self, results: dict) -> str:
return " ".join(self._parse_snippets(results))
def _google_serper_search_results(
def _google_serper_api_results(
self, search_term: str, search_type: str = "search", **kwargs: Any
) -> dict:
headers = {

View File

@ -4,13 +4,20 @@ import pytest
from langchain.utilities.google_serper import GoogleSerperAPIWrapper
def test_call() -> None:
"""Test that call gives the correct answer."""
def test_search_call() -> None:
"""Test that call gives the correct answer from search."""
search = GoogleSerperAPIWrapper()
output = search.run("What was Obama's first name?")
assert "Barack Hussein Obama II" in output
def test_news_call() -> None:
"""Test that call gives the correct answer from news search."""
search = GoogleSerperAPIWrapper(type="news")
output = search.run("What's new with stock market?").lower()
assert "stock" in output or "market" in output
async def test_results() -> None:
"""Test that call gives the correct answer."""
search = GoogleSerperAPIWrapper()