mirror of
https://github.com/hwchase17/langchain
synced 2024-11-04 06:00:26 +00:00
127 lines
4.4 KiB
Python
127 lines
4.4 KiB
Python
|
from __future__ import annotations
|
||
|
|
||
|
import json
|
||
|
from typing import Any, Dict, List, Literal, Optional, Sequence, Union
|
||
|
|
||
|
from langchain_core.callbacks import CallbackManagerForRetrieverRun
|
||
|
from langchain_core.documents import Document
|
||
|
from langchain_core.retrievers import BaseRetriever
|
||
|
|
||
|
|
||
|
class VespaRetriever(BaseRetriever):
    """`Vespa` retriever.

    Sends a query body to a Vespa application and converts each returned hit
    into a ``Document``.
    """

    app: Any
    """Vespa application to query."""
    body: Dict
    """Body of the query."""
    content_field: str
    """Name of the content field."""
    metadata_fields: Sequence[str]
    """Names of the metadata fields, or ``"*"`` to keep every returned field."""

    def _query(self, body: Dict) -> List[Document]:
        """Run ``body`` against the Vespa app and convert the hits to Documents.

        Args:
            body: Complete query body passed to ``self.app.query``.

        Returns:
            One ``Document`` per hit. ``page_content`` is the hit's
            ``content_field`` value (empty string when absent); metadata holds
            the selected fields plus the hit ``id``.

        Raises:
            RuntimeError: If the response status code is not 2xx, or the
                response root contains an ``errors`` entry.
        """
        response = self.app.query(body)

        if not str(response.status_code).startswith("2"):
            raise RuntimeError(
                "Could not retrieve data from Vespa. "
                f"Error code: {response.status_code}"
            )

        root = response.json["root"]
        if "errors" in root:
            raise RuntimeError(json.dumps(root["errors"]))

        keep_all_fields = self.metadata_fields == "*"
        docs = []
        for hit in response.hits:
            # Work on a copy so the response's own hit dicts are left untouched.
            fields = dict(hit["fields"])
            page_content = fields.pop(self.content_field, "")
            if keep_all_fields:
                # Everything except the content field becomes metadata.
                metadata = fields
            else:
                metadata = {name: fields.get(name) for name in self.metadata_fields}
            metadata["id"] = hit["id"]
            docs.append(Document(page_content=page_content, metadata=metadata))
        return docs

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Return the documents matching ``query`` using the stored body.

        Args:
            query: User query text, sent as the ``query`` entry of the body.
            run_manager: Callback manager for the retriever run (not used
                by this implementation).

        Returns:
            Documents built from the Vespa hits.
        """
        # Shallow copy: only top-level keys are replaced, so self.body is
        # never modified between calls.
        body = self.body.copy()
        body["query"] = query
        return self._query(body)

    def get_relevant_documents_with_filter(
        self, query: str, *, _filter: Optional[str] = None
    ) -> List[Document]:
        """Return the documents matching ``query``, narrowed by a YQL filter.

        The filter, when given, is appended to the stored ``yql`` entry as
        `` and <_filter>``; the stored body must therefore contain ``yql``.

        Args:
            query: User query text, sent as the ``query`` entry of the body.
            _filter: Extra YQL condition. ``None`` or an empty string leaves
                the stored YQL unchanged.

        Returns:
            Documents built from the Vespa hits.
        """
        body = self.body.copy()
        filter_clause = f" and {_filter}" if _filter else ""
        body["yql"] = body["yql"] + filter_clause
        body["query"] = query
        return self._query(body)

    @classmethod
    def from_params(
        cls,
        url: str,
        content_field: str,
        *,
        k: Optional[int] = None,
        metadata_fields: Union[Sequence[str], Literal["*"]] = (),
        sources: Union[Sequence[str], Literal["*"], None] = None,
        _filter: Optional[str] = None,
        yql: Optional[str] = None,
        **kwargs: Any,
    ) -> VespaRetriever:
        """Instantiate retriever from params.

        Args:
            url (str): Vespa app URL.
            content_field (str): Field in results to return as Document page_content.
            k (Optional[int]): Number of Documents to return. Defaults to None.
            metadata_fields(Sequence[str] or "*"): Fields in results to include in
                document metadata. Defaults to empty tuple ().
            sources (Sequence[str] or "*" or None): Sources to retrieve
                from. A single source may also be given as a plain string.
                Defaults to None, which queries all sources.
            _filter (Optional[str]): Document filter condition expressed in YQL.
                Defaults to None.
            yql (Optional[str]): Full YQL query to be used. Should not be specified
                if _filter or sources are specified. Defaults to None, in which
                case the query is generated from the other parameters.
            kwargs (Any): Keyword arguments added to query body.

        Returns:
            VespaRetriever: Instantiated VespaRetriever.

        Raises:
            ImportError: If pyvespa is not installed.
            ValueError: If yql is combined with sources or _filter.
        """
        try:
            from vespa.application import Vespa
        except ImportError as e:
            raise ImportError(
                "pyvespa is not installed, please install with `pip install pyvespa`"
            ) from e
        app = Vespa(url)
        body = kwargs.copy()
        if yql and (sources or _filter):
            raise ValueError(
                "yql should only be specified if both sources and _filter are not "
                "specified."
            )

        if metadata_fields == "*":
            body["summary"] = "short"

        # Only generate a query when the caller did not supply one; a
        # caller-provided yql must be used verbatim.
        if not yql:
            if metadata_fields == "*":
                _fields = "*"
            else:
                _fields = ", ".join([content_field, *(metadata_fields or ())])

            if not sources:
                # None or an empty collection means "search every source".
                _sources = "*"
            elif isinstance(sources, str):
                # A bare string (including "*") is one source name; joining it
                # would split it into individual characters.
                _sources = sources
            else:
                _sources = ", ".join(sources)

            filter_clause = f" and {_filter}" if _filter else ""
            yql = (
                f"select {_fields} from sources {_sources} "
                f"where userQuery(){filter_clause}"
            )

        body["yql"] = yql
        if k:
            body["hits"] = k
        return cls(
            app=app,
            body=body,
            content_field=content_field,
            metadata_fields=metadata_fields,
        )
|