Harrison/weaviate retriever (#2524)

Co-authored-by: Erika Cardenas <110841617+erika-cardenas@users.noreply.github.com>
doc
Harrison Chase 1 year ago committed by GitHub
parent c2f21a519f
commit 5c64b86ba3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -0,0 +1,132 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "ce0f17b9",
"metadata": {},
"source": [
"# Weaviate Hybrid Search\n",
"\n",
"This notebook shows how to use [Weaviate hybrid search](https://weaviate.io/blog/hybrid-search-explained) as a LangChain retriever."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "c10dd962",
"metadata": {},
"outputs": [],
"source": [
"import weaviate\n",
"import os\n",
"\n",
"WEAVIATE_URL = \"...\"\n",
"client = weaviate.Client(\n",
" url=WEAVIATE_URL,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "f47a2bfe",
"metadata": {},
"outputs": [],
"source": [
"from langchain.retrievers.weaviate_hybrid_search import WeaviateHybridSearchRetriever\n",
"from langchain.schema import Document"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "f2eff08e",
"metadata": {},
"outputs": [],
"source": [
"retriever = WeaviateHybridSearchRetriever(client, index_name=\"LangChain\", text_key=\"text\")"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "cd8a7b17",
"metadata": {},
"outputs": [],
"source": [
"docs = [Document(page_content=\"foo\")]"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "3c5970db",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['3f79d151-fb84-44cf-85e0-8682bfe145e0']"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"retriever.add_documents(docs)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "bf7dbb98",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[Document(page_content='foo', metadata={})]"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"retriever.get_relevant_documents(\"foo\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b2bc87c1",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

@ -4,6 +4,7 @@ from langchain.retrievers.metal import MetalRetriever
from langchain.retrievers.pinecone_hybrid_search import PineconeHybridSearchRetriever
from langchain.retrievers.remote_retriever import RemoteLangChainRetriever
from langchain.retrievers.tfidf import TFIDFRetriever
from langchain.retrievers.weaviate_hybrid_search import WeaviateHybridSearchRetriever
__all__ = [
"ChatGPTPluginRetriever",
@ -12,4 +13,5 @@ __all__ = [
"MetalRetriever",
"ElasticSearchBM25Retriever",
"TFIDFRetriever",
"WeaviateHybridSearchRetriever",
]

@ -0,0 +1,79 @@
"""Wrapper around weaviate vector database."""
from __future__ import annotations
from typing import Any, Dict, List
from uuid import uuid4
from pydantic import Extra
from langchain.docstore.document import Document
from langchain.schema import BaseRetriever
class WeaviateHybridSearchRetriever(BaseRetriever):
def __init__(
self,
client: Any,
index_name: str,
text_key: str,
alpha: float = 0.5,
k: int = 4,
):
try:
import weaviate
except ImportError:
raise ValueError(
"Could not import weaviate python package. "
"Please install it with `pip install weaviate-client`."
)
if not isinstance(client, weaviate.Client):
raise ValueError(
f"client should be an instance of weaviate.Client, got {type(client)}"
)
self._client = client
self.k = k
self.alpha = alpha
self._index_name = index_name
self._text_key = text_key
self._query_attrs = [self._text_key]
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True
# added text_key
def add_documents(self, docs: List[Document]) -> List[str]:
"""Upload documents to Weaviate."""
from weaviate.util import get_valid_uuid
with self._client.batch as batch:
ids = []
for i, doc in enumerate(docs):
data_properties = {
self._text_key: doc.page_content,
}
_id = get_valid_uuid(uuid4())
batch.add_data_object(data_properties, self._index_name, _id)
ids.append(_id)
return ids
def get_relevant_documents(self, query: str) -> List[Document]:
"""Look up similar documents in Weaviate."""
content: Dict[str, Any] = {"concepts": [query]}
query_obj = self._client.query.get(self._index_name, self._query_attrs)
result = (
query_obj.with_hybrid(content, alpha=self.alpha).with_limit(self.k).do()
)
docs = []
for res in result["data"]["Get"][self._index_name]:
text = res.pop(self._text_key)
docs.append(Document(page_content=text, metadata=res))
return docs
async def aget_relevant_documents(self, query: str) -> List[Document]:
raise NotImplementedError
Loading…
Cancel
Save