You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
langchain/libs/community/langchain_community/retrievers/weaviate_hybrid_search.py

163 lines
6.0 KiB
Python

from __future__ import annotations
from typing import Any, Dict, List, Optional, cast
from uuid import uuid4
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.pydantic_v1 import root_validator
from langchain_core.retrievers import BaseRetriever
class WeaviateHybridSearchRetriever(BaseRetriever):
    """`Weaviate hybrid search` retriever.

    Runs Weaviate hybrid (keyword + vector) queries against a single class
    and returns the hits as ``Document`` objects.

    See the documentation:
      https://weaviate.io/blog/hybrid-search-explained
    """

    client: Any
    """The Weaviate client instance (must be a ``weaviate.Client``)."""
    index_name: str
    """The name of the index (Weaviate class) to use."""
    text_key: str
    """The name of the property that holds the document text."""
    alpha: float = 0.5
    """Balance between keyword and vector search passed to ``with_hybrid``.

    Per the Weaviate documentation, 0 is a pure keyword search and 1 is a
    pure vector search.
    """
    k: int = 4
    """The number of results to return."""
    attributes: List[str]
    """The attributes to return in the results (``text_key`` is always added)."""
    create_schema_if_missing: bool = True
    """Whether to create the schema if it doesn't exist."""

    @root_validator(pre=True)
    def validate_client(
        cls,
        values: Dict[str, Any],
    ) -> Dict[str, Any]:
        """Validate the client and normalise ``attributes`` before field parsing.

        Args:
            values: Raw keyword arguments passed to the constructor.

        Returns:
            The same mapping, with ``attributes`` replaced by a fresh list that
            contains ``text_key``.

        Raises:
            ImportError: If the ``weaviate`` package is not installed.
            ValueError: If ``client`` is not a ``weaviate.Client``.
        """
        try:
            import weaviate
        except ImportError as err:
            raise ImportError(
                "Could not import weaviate python package. "
                "Please install it with `pip install weaviate-client`."
            ) from err

        client = values["client"]
        if not isinstance(client, weaviate.Client):
            raise ValueError(
                f"client should be an instance of weaviate.Client, got {type(client)}"
            )

        # Work on a copy: appending to the caller's list in place would leak
        # ``text_key`` into any list the caller reuses for another retriever.
        attributes = list(values.get("attributes") or [])
        text_key = values["text_key"]
        if text_key not in attributes:
            attributes.append(text_key)
        values["attributes"] = attributes

        if values.get("create_schema_if_missing", True):
            index_name = values["index_name"]
            class_obj = {
                "class": index_name,
                "properties": [{"name": text_key, "dataType": ["text"]}],
                "vectorizer": "text2vec-openai",
            }
            # Only create the class when the schema does not already have it.
            if not client.schema.exists(index_name):
                client.schema.create_class(class_obj)

        return values

    class Config:
        """Configuration for this pydantic object."""

        arbitrary_types_allowed = True

    def add_documents(self, docs: List[Document], **kwargs: Any) -> List[str]:
        """Upload documents to Weaviate.

        Args:
            docs: Documents to store; ``page_content`` is written under
                ``text_key`` and the metadata entries become further properties.
            **kwargs: May contain ``uuids``, a sequence of ids indexed in step
                with ``docs``. When absent (or ``None``) a fresh UUID is
                generated for each document.

        Returns:
            The ids used for the uploaded objects, in the order of ``docs``.
        """
        from weaviate.util import get_valid_uuid

        uuids = kwargs.get("uuids")
        ids: List[str] = []
        with self.client.batch as batch:
            for i, doc in enumerate(docs):
                metadata = doc.metadata or {}
                # The page content is stored under ``text_key`` next to the
                # metadata properties.
                data_properties = {self.text_key: doc.page_content, **metadata}

                # If the UUID of one of the objects already exists
                # then the existing object will be replaced by the new object.
                if uuids is not None:
                    _id = uuids[i]
                else:
                    _id = get_valid_uuid(uuid4())

                batch.add_data_object(data_properties, self.index_name, _id)
                ids.append(_id)
        return ids

    def _get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: CallbackManagerForRetrieverRun,
        where_filter: Optional[Dict[str, object]] = None,
        score: bool = False,
        hybrid_search_kwargs: Optional[Dict[str, object]] = None,
    ) -> List[Document]:
        """Look up similar documents in Weaviate.

        Args:
            query: The query to search for relevant documents
                using Weaviate hybrid search.
            run_manager: Callback manager for this retriever run.
            where_filter: A filter to apply to the query.
                https://weaviate.io/developers/weaviate/guides/querying/#filtering
            score: Whether to include the score, and score explanation
                in the returned Documents metadata.
            hybrid_search_kwargs: Used to pass additional arguments
                to the .with_hybrid() method. An ``alpha`` entry here overrides
                ``self.alpha`` for this query.
                The primary use cases for this are:
                1) Search specific properties only -
                    specify which properties to be used during hybrid search
                    portion. Note: this is not the same as the
                    (self.attributes) to be returned.
                    Example - hybrid_search_kwargs={"properties": ["question", "answer"]}
                    https://weaviate.io/developers/weaviate/search/hybrid#selected-properties-only
                2) Weight boosted searched properties -
                    Boost the weight of certain properties during the hybrid
                    search portion.
                    Example - hybrid_search_kwargs={"properties": ["question^2", "answer"]}
                    https://weaviate.io/developers/weaviate/search/hybrid#weight-boost-searched-properties
                3) Search with a custom vector - Define a different vector
                    to be used during the hybrid search portion.
                    Example - hybrid_search_kwargs={"vector": [0.1, 0.2, 0.3, ...]}
                    https://weaviate.io/developers/weaviate/search/hybrid#with-a-custom-vector
                4) Use Fusion ranking method
                    Example - from weaviate.gql.get import HybridFusion
                    hybrid_search_kwargs={"fusion_type": HybridFusion.RELATIVE_SCORE}
                    https://weaviate.io/developers/weaviate/search/hybrid#fusion-ranking-method

        Returns:
            One ``Document`` per hit; the ``text_key`` property becomes
            ``page_content`` and the remaining returned fields the metadata.

        Raises:
            ValueError: If the Weaviate response contains an ``errors`` entry.
        """
        query_obj = self.client.query.get(self.index_name, self.attributes)
        if where_filter:
            query_obj = query_obj.with_where(where_filter)
        if score:
            query_obj = query_obj.with_additional(["score", "explainScore"])

        # Merge so that caller-supplied kwargs win; passing ``alpha`` both
        # explicitly and via ``**hybrid_search_kwargs`` would raise TypeError.
        hybrid_kwargs = {"alpha": self.alpha, **(hybrid_search_kwargs or {})}

        result = query_obj.with_hybrid(query, **hybrid_kwargs).with_limit(self.k).do()
        if "errors" in result:
            raise ValueError(f"Error during query: {result['errors']}")

        docs = []
        for res in result["data"]["Get"][self.index_name]:
            # Remove the text from the hit so only the other fields end up
            # in the metadata.
            text = res.pop(self.text_key)
            docs.append(Document(page_content=text, metadata=res))
        return docs