mirror of https://github.com/hwchase17/langchain
Harrison/weaviate retriever (#2524)
Co-authored-by: Erika Cardenas <110841617+erika-cardenas@users.noreply.github.com>pull/2525/head
parent
c2f21a519f
commit
5c64b86ba3
@ -0,0 +1,132 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "ce0f17b9",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Weaviate Hybrid Search\n",
|
||||
"\n",
|
||||
"This notebook shows how to use [Weaviate hybrid search](https://weaviate.io/blog/hybrid-search-explained) as a LangChain retriever."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "c10dd962",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import weaviate\n",
|
||||
"import os\n",
|
||||
"\n",
|
||||
"WEAVIATE_URL = \"...\"\n",
|
||||
"client = weaviate.Client(\n",
|
||||
" url=WEAVIATE_URL,\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "f47a2bfe",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.retrievers.weaviate_hybrid_search import WeaviateHybridSearchRetriever\n",
|
||||
"from langchain.schema import Document"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "f2eff08e",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"retriever = WeaviateHybridSearchRetriever(client, index_name=\"LangChain\", text_key=\"text\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "cd8a7b17",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"docs = [Document(page_content=\"foo\")]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"id": "3c5970db",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"['3f79d151-fb84-44cf-85e0-8682bfe145e0']"
|
||||
]
|
||||
},
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"retriever.add_documents(docs)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"id": "bf7dbb98",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"[Document(page_content='foo', metadata={})]"
|
||||
]
|
||||
},
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"retriever.get_relevant_documents(\"foo\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "b2bc87c1",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.1"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
@ -0,0 +1,79 @@
|
||||
"""Wrapper around weaviate vector database."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List
|
||||
from uuid import uuid4
|
||||
|
||||
from pydantic import Extra
|
||||
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.schema import BaseRetriever
|
||||
|
||||
|
||||
class WeaviateHybridSearchRetriever(BaseRetriever):
|
||||
def __init__(
|
||||
self,
|
||||
client: Any,
|
||||
index_name: str,
|
||||
text_key: str,
|
||||
alpha: float = 0.5,
|
||||
k: int = 4,
|
||||
):
|
||||
try:
|
||||
import weaviate
|
||||
except ImportError:
|
||||
raise ValueError(
|
||||
"Could not import weaviate python package. "
|
||||
"Please install it with `pip install weaviate-client`."
|
||||
)
|
||||
if not isinstance(client, weaviate.Client):
|
||||
raise ValueError(
|
||||
f"client should be an instance of weaviate.Client, got {type(client)}"
|
||||
)
|
||||
self._client = client
|
||||
self.k = k
|
||||
self.alpha = alpha
|
||||
self._index_name = index_name
|
||||
self._text_key = text_key
|
||||
self._query_attrs = [self._text_key]
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
extra = Extra.forbid
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
# added text_key
|
||||
def add_documents(self, docs: List[Document]) -> List[str]:
|
||||
"""Upload documents to Weaviate."""
|
||||
from weaviate.util import get_valid_uuid
|
||||
|
||||
with self._client.batch as batch:
|
||||
ids = []
|
||||
for i, doc in enumerate(docs):
|
||||
data_properties = {
|
||||
self._text_key: doc.page_content,
|
||||
}
|
||||
_id = get_valid_uuid(uuid4())
|
||||
batch.add_data_object(data_properties, self._index_name, _id)
|
||||
ids.append(_id)
|
||||
return ids
|
||||
|
||||
def get_relevant_documents(self, query: str) -> List[Document]:
|
||||
"""Look up similar documents in Weaviate."""
|
||||
content: Dict[str, Any] = {"concepts": [query]}
|
||||
query_obj = self._client.query.get(self._index_name, self._query_attrs)
|
||||
|
||||
result = (
|
||||
query_obj.with_hybrid(content, alpha=self.alpha).with_limit(self.k).do()
|
||||
)
|
||||
|
||||
docs = []
|
||||
|
||||
for res in result["data"]["Get"][self._index_name]:
|
||||
text = res.pop(self._text_key)
|
||||
docs.append(Document(page_content=text, metadata=res))
|
||||
return docs
|
||||
|
||||
async def aget_relevant_documents(self, query: str) -> List[Document]:
|
||||
raise NotImplementedError
|
Loading…
Reference in New Issue