forked from Archives/langchain
Harrison/serper api bug (#4902)
Co-authored-by: Jerry Luan <xmaswillyou@gmail.com>
This commit is contained in:
parent
c998569c8f
commit
9e2227ba11
@ -5,6 +5,7 @@ import aiohttp
|
||||
import requests
|
||||
from pydantic.class_validators import root_validator
|
||||
from pydantic.main import BaseModel
|
||||
from typing_extensions import Literal
|
||||
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
@ -28,7 +29,16 @@ class GoogleSerperAPIWrapper(BaseModel):
|
||||
k: int = 10
|
||||
gl: str = "us"
|
||||
hl: str = "en"
|
||||
type: str = "search" # search, images, places, news
|
||||
# "places" and "images" is available from Serper but not implemented in the
|
||||
# parser of run(). They can be used in results()
|
||||
type: Literal["news", "search", "places", "images"] = "search"
|
||||
result_key_for_type = {
|
||||
"news": "news",
|
||||
"places": "places",
|
||||
"images": "images",
|
||||
"search": "organic",
|
||||
}
|
||||
|
||||
tbs: Optional[str] = None
|
||||
serper_api_key: Optional[str] = None
|
||||
aiosession: Optional[aiohttp.ClientSession] = None
|
||||
@ -50,7 +60,7 @@ class GoogleSerperAPIWrapper(BaseModel):
|
||||
|
||||
def results(self, query: str, **kwargs: Any) -> Dict:
|
||||
"""Run query through GoogleSearch."""
|
||||
return self._google_serper_search_results(
|
||||
return self._google_serper_api_results(
|
||||
query,
|
||||
gl=self.gl,
|
||||
hl=self.hl,
|
||||
@ -62,7 +72,7 @@ class GoogleSerperAPIWrapper(BaseModel):
|
||||
|
||||
def run(self, query: str, **kwargs: Any) -> str:
|
||||
"""Run query through GoogleSearch and parse result."""
|
||||
results = self._google_serper_search_results(
|
||||
results = self._google_serper_api_results(
|
||||
query,
|
||||
gl=self.gl,
|
||||
hl=self.hl,
|
||||
@ -125,7 +135,7 @@ class GoogleSerperAPIWrapper(BaseModel):
|
||||
for attribute, value in kg.get("attributes", {}).items():
|
||||
snippets.append(f"{title} {attribute}: {value}.")
|
||||
|
||||
for result in results["organic"][: self.k]:
|
||||
for result in results[self.result_key_for_type[self.type]][: self.k]:
|
||||
if "snippet" in result:
|
||||
snippets.append(result["snippet"])
|
||||
for attribute, value in result.get("attributes", {}).items():
|
||||
@ -138,7 +148,7 @@ class GoogleSerperAPIWrapper(BaseModel):
|
||||
def _parse_results(self, results: dict) -> str:
|
||||
return " ".join(self._parse_snippets(results))
|
||||
|
||||
def _google_serper_search_results(
|
||||
def _google_serper_api_results(
|
||||
self, search_term: str, search_type: str = "search", **kwargs: Any
|
||||
) -> dict:
|
||||
headers = {
|
||||
|
@ -4,13 +4,20 @@ import pytest
|
||||
from langchain.utilities.google_serper import GoogleSerperAPIWrapper
|
||||
|
||||
|
||||
def test_call() -> None:
|
||||
"""Test that call gives the correct answer."""
|
||||
def test_search_call() -> None:
|
||||
"""Test that call gives the correct answer from search."""
|
||||
search = GoogleSerperAPIWrapper()
|
||||
output = search.run("What was Obama's first name?")
|
||||
assert "Barack Hussein Obama II" in output
|
||||
|
||||
|
||||
def test_news_call() -> None:
|
||||
"""Test that call gives the correct answer from news search."""
|
||||
search = GoogleSerperAPIWrapper(type="news")
|
||||
output = search.run("What's new with stock market?").lower()
|
||||
assert "stock" in output or "market" in output
|
||||
|
||||
|
||||
async def test_results() -> None:
|
||||
"""Test that call gives the correct answer."""
|
||||
search = GoogleSerperAPIWrapper()
|
||||
|
Loading…
Reference in New Issue
Block a user