diff --git a/langchain/utilities/google_serper.py b/langchain/utilities/google_serper.py index 0830a631..9db376c4 100644 --- a/langchain/utilities/google_serper.py +++ b/langchain/utilities/google_serper.py @@ -5,6 +5,7 @@ import aiohttp import requests from pydantic.class_validators import root_validator from pydantic.main import BaseModel +from typing_extensions import Literal from langchain.utils import get_from_dict_or_env @@ -28,7 +29,16 @@ class GoogleSerperAPIWrapper(BaseModel): k: int = 10 gl: str = "us" hl: str = "en" - type: str = "search" # search, images, places, news + # "places" and "images" is available from Serper but not implemented in the + # parser of run(). They can be used in results() + type: Literal["news", "search", "places", "images"] = "search" + result_key_for_type = { + "news": "news", + "places": "places", + "images": "images", + "search": "organic", + } + tbs: Optional[str] = None serper_api_key: Optional[str] = None aiosession: Optional[aiohttp.ClientSession] = None @@ -50,7 +60,7 @@ class GoogleSerperAPIWrapper(BaseModel): def results(self, query: str, **kwargs: Any) -> Dict: """Run query through GoogleSearch.""" - return self._google_serper_search_results( + return self._google_serper_api_results( query, gl=self.gl, hl=self.hl, @@ -62,7 +72,7 @@ class GoogleSerperAPIWrapper(BaseModel): def run(self, query: str, **kwargs: Any) -> str: """Run query through GoogleSearch and parse result.""" - results = self._google_serper_search_results( + results = self._google_serper_api_results( query, gl=self.gl, hl=self.hl, @@ -125,7 +135,7 @@ class GoogleSerperAPIWrapper(BaseModel): for attribute, value in kg.get("attributes", {}).items(): snippets.append(f"{title} {attribute}: {value}.") - for result in results["organic"][: self.k]: + for result in results[self.result_key_for_type[self.type]][: self.k]: if "snippet" in result: snippets.append(result["snippet"]) for attribute, value in result.get("attributes", {}).items(): @@ -138,7 +148,7 @@ class GoogleSerperAPIWrapper(BaseModel): def _parse_results(self, results: dict) -> str: return " ".join(self._parse_snippets(results)) - def _google_serper_search_results( + def _google_serper_api_results( self, search_term: str, search_type: str = "search", **kwargs: Any ) -> dict: headers = { diff --git a/tests/integration_tests/utilities/test_googleserper_api.py b/tests/integration_tests/utilities/test_googleserper_api.py index 78f3275b..2fe4a5b7 100644 --- a/tests/integration_tests/utilities/test_googleserper_api.py +++ b/tests/integration_tests/utilities/test_googleserper_api.py @@ -4,13 +4,20 @@ import pytest from langchain.utilities.google_serper import GoogleSerperAPIWrapper -def test_call() -> None: - """Test that call gives the correct answer.""" +def test_search_call() -> None: + """Test that call gives the correct answer from search.""" search = GoogleSerperAPIWrapper() output = search.run("What was Obama's first name?") assert "Barack Hussein Obama II" in output +def test_news_call() -> None: + """Test that call gives the correct answer from news search.""" + search = GoogleSerperAPIWrapper(type="news") + output = search.run("What's new with stock market?").lower() + assert "stock" in output or "market" in output + + async def test_results() -> None: """Test that call gives the correct answer.""" search = GoogleSerperAPIWrapper()