From 5c64b86ba3d0909c4fc70909a3b8c711b008794d Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Thu, 6 Apr 2023 22:27:37 -0700 Subject: [PATCH] Harrison/weaviate retriever (#2524) Co-authored-by: Erika Cardenas <110841617+erika-cardenas@users.noreply.github.com> --- .../retrievers/examples/weaviate-hybrid.ipynb | 132 ++++++++++++++++++ langchain/retrievers/__init__.py | 2 + .../retrievers/weaviate_hybrid_search.py | 79 +++++++++++ 3 files changed, 213 insertions(+) create mode 100644 docs/modules/indexes/retrievers/examples/weaviate-hybrid.ipynb create mode 100644 langchain/retrievers/weaviate_hybrid_search.py diff --git a/docs/modules/indexes/retrievers/examples/weaviate-hybrid.ipynb b/docs/modules/indexes/retrievers/examples/weaviate-hybrid.ipynb new file mode 100644 index 00000000..32888d65 --- /dev/null +++ b/docs/modules/indexes/retrievers/examples/weaviate-hybrid.ipynb @@ -0,0 +1,132 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "ce0f17b9", + "metadata": {}, + "source": [ + "# Weaviate Hybrid Search\n", + "\n", + "This notebook shows how to use [Weaviate hybrid search](https://weaviate.io/blog/hybrid-search-explained) as a LangChain retriever." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "c10dd962", + "metadata": {}, + "outputs": [], + "source": [ + "import weaviate\n", + "import os\n", + "\n", + "WEAVIATE_URL = \"...\"\n", + "client = weaviate.Client(\n", + " url=WEAVIATE_URL,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "f47a2bfe", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain.retrievers.weaviate_hybrid_search import WeaviateHybridSearchRetriever\n", + "from langchain.schema import Document" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "f2eff08e", + "metadata": {}, + "outputs": [], + "source": [ + "retriever = WeaviateHybridSearchRetriever(client, index_name=\"LangChain\", text_key=\"text\")" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "cd8a7b17", + "metadata": {}, + "outputs": [], + "source": [ + "docs = [Document(page_content=\"foo\")]" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "3c5970db", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['3f79d151-fb84-44cf-85e0-8682bfe145e0']" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "retriever.add_documents(docs)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "bf7dbb98", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[Document(page_content='foo', metadata={})]" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "retriever.get_relevant_documents(\"foo\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b2bc87c1", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.1" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/langchain/retrievers/__init__.py b/langchain/retrievers/__init__.py index efeff643..33ea97aa 100644 --- a/langchain/retrievers/__init__.py +++ b/langchain/retrievers/__init__.py @@ -4,6 +4,7 @@ from langchain.retrievers.metal import MetalRetriever from langchain.retrievers.pinecone_hybrid_search import PineconeHybridSearchRetriever from langchain.retrievers.remote_retriever import RemoteLangChainRetriever from langchain.retrievers.tfidf import TFIDFRetriever +from langchain.retrievers.weaviate_hybrid_search import WeaviateHybridSearchRetriever __all__ = [ "ChatGPTPluginRetriever", @@ -12,4 +13,5 @@ __all__ = [ "MetalRetriever", "ElasticSearchBM25Retriever", "TFIDFRetriever", + "WeaviateHybridSearchRetriever", ] diff --git a/langchain/retrievers/weaviate_hybrid_search.py b/langchain/retrievers/weaviate_hybrid_search.py new file mode 100644 index 00000000..aeaba149 --- /dev/null +++ b/langchain/retrievers/weaviate_hybrid_search.py @@ -0,0 +1,79 @@ +"""Wrapper around weaviate vector database.""" +from __future__ import annotations + +from typing import Any, Dict, List +from uuid import uuid4 + +from pydantic import Extra + +from langchain.docstore.document import Document +from langchain.schema import BaseRetriever + + +class WeaviateHybridSearchRetriever(BaseRetriever): + def __init__( + self, + client: Any, + index_name: str, + text_key: str, + alpha: float = 0.5, + k: int = 4, + ): + try: + import weaviate + except ImportError: + raise ValueError( + "Could not import weaviate python package. " + "Please install it with `pip install weaviate-client`." + ) + if not isinstance(client, weaviate.Client): + raise ValueError( + f"client should be an instance of weaviate.Client, got {type(client)}" + ) + self._client = client + self.k = k + self.alpha = alpha + self._index_name = index_name + self._text_key = text_key + self._query_attrs = [self._text_key] + + class Config: + """Configuration for this pydantic object.""" + + extra = Extra.forbid + arbitrary_types_allowed = True + + # added text_key + def add_documents(self, docs: List[Document]) -> List[str]: + """Upload documents to Weaviate.""" + from weaviate.util import get_valid_uuid + + with self._client.batch as batch: + ids = [] + for i, doc in enumerate(docs): + data_properties = { + self._text_key: doc.page_content, + } + _id = get_valid_uuid(uuid4()) + batch.add_data_object(data_properties, self._index_name, _id) + ids.append(_id) + return ids + + def get_relevant_documents(self, query: str) -> List[Document]: + """Look up similar documents in Weaviate.""" + content: Dict[str, Any] = {"concepts": [query]} + query_obj = self._client.query.get(self._index_name, self._query_attrs) + + result = ( + query_obj.with_hybrid(content, alpha=self.alpha).with_limit(self.k).do() + ) + + docs = [] + + for res in result["data"]["Get"][self._index_name]: + text = res.pop(self._text_key) + docs.append(Document(page_content=text, metadata=res)) + return docs + + async def aget_relevant_documents(self, query: str) -> List[Document]: + raise NotImplementedError