mirror of
https://github.com/hwchase17/langchain
synced 2024-11-18 09:25:54 +00:00
163 lines
6.0 KiB
Python
163 lines
6.0 KiB
Python
|
from __future__ import annotations
|
||
|
|
||
|
from typing import Any, Dict, List, Optional, cast
|
||
|
from uuid import uuid4
|
||
|
|
||
|
from langchain_core.callbacks import CallbackManagerForRetrieverRun
|
||
|
from langchain_core.documents import Document
|
||
|
from langchain_core.pydantic_v1 import root_validator
|
||
|
from langchain_core.retrievers import BaseRetriever
|
||
|
|
||
|
|
||
|
class WeaviateHybridSearchRetriever(BaseRetriever):
    """`Weaviate hybrid search` retriever.

    Runs a Weaviate hybrid query against ``index_name`` and turns each hit
    into a ``Document`` whose ``page_content`` is the value stored under
    ``text_key`` and whose ``metadata`` holds the remaining returned fields.

    See the documentation:
      https://weaviate.io/blog/hybrid-search-explained
    """

    client: Any
    """The ``weaviate.Client`` instance used to talk to Weaviate."""
    index_name: str
    """The name of the index to use."""
    text_key: str
    """The name of the text key to use."""
    alpha: float = 0.5
    """Value passed as ``alpha`` to ``with_hybrid``.

    Per the Weaviate docs, 0 is a pure keyword (BM25) search and 1 is a
    pure vector search.
    """
    k: int = 4
    """The number of results to return."""
    attributes: List[str]
    """The attributes to return in the results (``text_key`` is appended)."""
    create_schema_if_missing: bool = True
    """Whether to create the schema if it doesn't exist."""

    @root_validator(pre=True)
    def validate_client(
        cls,
        values: Dict[str, Any],
    ) -> Dict[str, Any]:
        """Validate the client and normalise ``attributes`` before field parsing.

        Args:
            values: The raw keyword arguments the model was constructed with.

        Returns:
            ``values`` with ``attributes`` replaced by a new list that ends
            with ``text_key``.

        Raises:
            ImportError: If the ``weaviate`` package is not installed.
            ValueError: If ``client`` is not a ``weaviate.Client``.
        """
        try:
            import weaviate
        except ImportError as e:
            raise ImportError(
                "Could not import weaviate python package. "
                "Please install it with `pip install weaviate-client`."
            ) from e

        client = values["client"]
        if not isinstance(client, weaviate.Client):
            raise ValueError(
                f"client should be an instance of weaviate.Client, got {type(client)}"
            )

        # Work on a copy: appending to the list the caller handed in would
        # silently change their object (and grow it on every construction
        # when the same list is reused for several retrievers).
        requested = values.get("attributes")
        attributes = [] if requested is None else list(requested)
        # The text key must always be fetched; it becomes the page content.
        attributes.append(values["text_key"])
        values["attributes"] = attributes

        if values.get("create_schema_if_missing", True):
            index_name = values["index_name"]
            if not client.schema.exists(index_name):
                class_obj = {
                    "class": index_name,
                    "properties": [{"name": values["text_key"], "dataType": ["text"]}],
                    "vectorizer": "text2vec-openai",
                }
                client.schema.create_class(class_obj)

        return values

    class Config:
        """Configuration for this pydantic object."""

        arbitrary_types_allowed = True

    def add_documents(self, docs: List[Document], **kwargs: Any) -> List[str]:
        """Upload documents to Weaviate.

        Args:
            docs: The documents to upload. Each document's page content is
                stored under ``text_key`` and its metadata entries are stored
                as additional properties.
            **kwargs: May contain ``uuids``, a sequence indexed in step with
                ``docs`` that supplies the id for each document. When it is
                absent (or ``None``) a fresh UUID is generated per document.

        Returns:
            The ids of the uploaded objects, in the order of ``docs``.
        """
        from weaviate.util import get_valid_uuid

        uuids = kwargs.get("uuids")
        ids = []
        with self.client.batch as batch:
            for i, doc in enumerate(docs):
                metadata = doc.metadata or {}
                data_properties = {self.text_key: doc.page_content, **metadata}

                # If the UUID of one of the objects already exists
                # then the existing object will be replaced by the new object.
                if uuids is not None:
                    _id = uuids[i]
                else:
                    _id = get_valid_uuid(uuid4())

                batch.add_data_object(data_properties, self.index_name, _id)
                ids.append(_id)
        return ids

    def _get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: CallbackManagerForRetrieverRun,
        where_filter: Optional[Dict[str, object]] = None,
        score: bool = False,
        hybrid_search_kwargs: Optional[Dict[str, object]] = None,
    ) -> List[Document]:
        """Look up similar documents in Weaviate.

        Args:
            query: The query to search for relevant documents
                using Weaviate hybrid search.
            run_manager: Callback manager for this run (not used by this
                implementation).
            where_filter: A filter to apply to the query.
                https://weaviate.io/developers/weaviate/guides/querying/#filtering
            score: Whether to include the score, and score explanation
                in the returned Documents' metadata.
            hybrid_search_kwargs: Used to pass additional arguments
                to the .with_hybrid() method.
                The primary use cases for this are:

                1) Search specific properties only -
                    specify which properties to be used during hybrid search portion.
                    Note: this is not the same as the (self.attributes) to be returned.
                    Example - hybrid_search_kwargs={"properties": ["question", "answer"]}
                    https://weaviate.io/developers/weaviate/search/hybrid#selected-properties-only

                2) Weight boosted searched properties -
                    Boost the weight of certain properties during the hybrid search portion.
                    Example - hybrid_search_kwargs={"properties": ["question^2", "answer"]}
                    https://weaviate.io/developers/weaviate/search/hybrid#weight-boost-searched-properties

                3) Search with a custom vector - Define a different vector
                    to be used during the hybrid search portion.
                    Example - hybrid_search_kwargs={"vector": [0.1, 0.2, 0.3, ...]}
                    https://weaviate.io/developers/weaviate/search/hybrid#with-a-custom-vector

                4) Use Fusion ranking method
                    Example - from weaviate.gql.get import HybridFusion
                    hybrid_search_kwargs={"fusion_type": HybridFusion.RELATIVE_SCORE}
                    https://weaviate.io/developers/weaviate/search/hybrid#fusion-ranking-method

        Returns:
            At most ``k`` documents built from the query results.

        Raises:
            ValueError: If the Weaviate response contains an ``errors`` entry.
        """
        query_obj = self.client.query.get(self.index_name, self.attributes)
        if where_filter:
            query_obj = query_obj.with_where(where_filter)

        if score:
            query_obj = query_obj.with_additional(["score", "explainScore"])

        if hybrid_search_kwargs is None:
            hybrid_search_kwargs = {}

        result = (
            query_obj.with_hybrid(query, alpha=self.alpha, **hybrid_search_kwargs)
            .with_limit(self.k)
            .do()
        )
        if "errors" in result:
            raise ValueError(f"Error during query: {result['errors']}")

        docs = []

        for res in result["data"]["Get"][self.index_name]:
            # Whatever is left after removing the text becomes the metadata.
            text = res.pop(self.text_key)
            docs.append(Document(page_content=text, metadata=res))
        return docs
|