From 3a2782053b41e3efa954a241a246622150ff89f3 Mon Sep 17 00:00:00 2001 From: Alexander Weichart <55558407+AlexW00@users.noreply.github.com> Date: Sun, 2 Apr 2023 23:05:21 +0200 Subject: [PATCH] feat: category support for SearxSearchWrapper (#2271) Added an optional parameter "categories" to specify the active search categories. API: https://docs.searxng.org/dev/search_api.html --- langchain/utilities/searx_search.py | 26 ++++++++++++++++++++------ 1 file changed, 20 insertions(+), 6 deletions(-) diff --git a/langchain/utilities/searx_search.py b/langchain/utilities/searx_search.py index da0bbd83..b2ed0805 100644 --- a/langchain/utilities/searx_search.py +++ b/langchain/utilities/searx_search.py @@ -203,16 +203,11 @@ class SearxSearchWrapper(BaseModel): params: dict = Field(default_factory=_get_default_params) headers: Optional[dict] = None engines: Optional[List[str]] = [] + categories: Optional[List[str]] = [] query_suffix: Optional[str] = "" k: int = 10 aiosession: Optional[Any] = None - class Config: - """Configuration for this pydantic object.""" - - extra = Extra.forbid - arbitrary_types_allowed = True - @validator("unsecure") def disable_ssl_warnings(cls, v: bool) -> bool: """Disable SSL warnings.""" @@ -238,6 +233,10 @@ class SearxSearchWrapper(BaseModel): if engines: values["params"]["engines"] = ",".join(engines) + categories = values.get("categories") + if categories: + values["params"]["categories"] = ",".join(categories) + searx_host = get_from_dict_or_env(values, "searx_host", "SEARX_HOST") if not searx_host.startswith("http"): print( @@ -252,6 +251,11 @@ class SearxSearchWrapper(BaseModel): return values + class Config: + """Configuration for this pydantic object.""" + + extra = Extra.forbid + def _searx_api_query(self, params: dict) -> SearxResults: """Actual request to searx API.""" raw_result = requests.get( @@ -298,6 +302,7 @@ class SearxSearchWrapper(BaseModel): self, query: str, engines: Optional[List[str]] = None, + categories: Optional[List[str]] = None, query_suffix: Optional[str] = "", **kwargs: Any, ) -> str: @@ -309,6 +314,7 @@ class SearxSearchWrapper(BaseModel): query: The query to search for. query_suffix: Extra suffix appended to the query. engines: List of engines to use for the query. + categories: List of categories to use for the query. **kwargs: extra parameters to pass to the searx API. Returns: @@ -345,6 +351,9 @@ class SearxSearchWrapper(BaseModel): if isinstance(engines, list) and len(engines) > 0: params["engines"] = ",".join(engines) + if isinstance(categories, list) and len(categories) > 0: + params["categories"] = ",".join(categories) + res = self._searx_api_query(params) if len(res.answers) > 0: @@ -398,6 +407,7 @@ class SearxSearchWrapper(BaseModel): query: str, num_results: int, engines: Optional[List[str]] = None, + categories: Optional[List[str]] = None, query_suffix: Optional[str] = "", **kwargs: Any, ) -> List[Dict]: @@ -412,6 +422,8 @@ class SearxSearchWrapper(BaseModel): engines: List of engines to use for the query. + categories: List of categories to use for the query. + **kwargs: extra parameters to pass to the searx API. Returns: @@ -441,6 +453,8 @@ class SearxSearchWrapper(BaseModel): params["q"] += " " + query_suffix if isinstance(engines, list) and len(engines) > 0: params["engines"] = ",".join(engines) + if isinstance(categories, list) and len(categories) > 0: + params["categories"] = ",".join(categories) results = self._searx_api_query(params).results[:num_results] if len(results) == 0: return [{"Result": "No good Search Result was found"}]