From 2b975de94dc88e63056a8c89775de503c6e1ebd8 Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Tue, 4 Apr 2023 12:17:13 -0700 Subject: [PATCH] add metal retriever (#2244) --- .../indexes/retrievers/examples/metal.ipynb | 156 ++++++++++++++++++ langchain/retrievers/__init__.py | 3 +- langchain/retrievers/metal.py | 27 +++ 3 files changed, 185 insertions(+), 1 deletion(-) create mode 100644 docs/modules/indexes/retrievers/examples/metal.ipynb create mode 100644 langchain/retrievers/metal.py diff --git a/docs/modules/indexes/retrievers/examples/metal.ipynb b/docs/modules/indexes/retrievers/examples/metal.ipynb new file mode 100644 index 0000000000..8e6908d708 --- /dev/null +++ b/docs/modules/indexes/retrievers/examples/metal.ipynb @@ -0,0 +1,156 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "9fc6205b", + "metadata": {}, + "source": [ + "# Metal\n", + "\n", + "This notebook shows how to use [Metal's](https://docs.getmetal.io/introduction) retriever.\n", + "\n", + "First, you will need to sign up for Metal and get an API key. 
You can do so [here](https://docs.getmetal.io/misc-create-app)" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "1a737220", + "metadata": {}, + "outputs": [], + "source": [ + "# !pip install metal_sdk" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "b1bb478f", + "metadata": {}, + "outputs": [], + "source": [ + "from metal_sdk.metal import Metal\n", + "API_KEY = \"\"\n", + "CLIENT_ID = \"\"\n", + "APP_ID = \"\"\n", + "\n", + "metal = Metal(API_KEY, CLIENT_ID, APP_ID);\n" + ] + }, + { + "cell_type": "markdown", + "id": "ae3c3d16", + "metadata": {}, + "source": [ + "## Ingest Documents\n", + "\n", + "You only need to do this if you haven't already set up an index" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "f0425fa0", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'data': {'id': '642739aa7559b026b4430e42',\n", + " 'text': 'foo',\n", + " 'createdAt': '2023-03-31T19:51:06.748Z'}}" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "metal.index( {\"text\": \"foo1\"})\n", + "metal.index( {\"text\": \"foo\"})" + ] + }, + { + "cell_type": "markdown", + "id": "944e172b", + "metadata": {}, + "source": [ + "## Query\n", + "\n", + "Now that our index is set up, we can set up a retriever and start querying it." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "d0e6f506", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain.retrievers import MetalRetriever" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "f381f642", + "metadata": {}, + "outputs": [], + "source": [ + "retriever = MetalRetriever(metal, params={\"limit\": 2})" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "20ae1a74", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[Document(page_content='foo1', metadata={'dist': '1.19209289551e-07', 'id': '642739a17559b026b4430e40', 'createdAt': '2023-03-31T19:50:57.853Z'}),\n", + " Document(page_content='foo1', metadata={'dist': '4.05311584473e-06', 'id': '642738f67559b026b4430e3c', 'createdAt': '2023-03-31T19:48:06.769Z'})]" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "retriever.get_relevant_documents(\"foo1\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1d5a5088", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.1" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/langchain/retrievers/__init__.py b/langchain/retrievers/__init__.py index 6742f5bcc8..e8ec900d41 100644 --- a/langchain/retrievers/__init__.py +++ b/langchain/retrievers/__init__.py @@ -1,4 +1,5 @@ from langchain.retrievers.chatgpt_plugin_retriever import ChatGPTPluginRetriever +from langchain.retrievers.metal import MetalRetriever from langchain.retrievers.remote_retriever import RemoteLangChainRetriever -__all__ = 
["ChatGPTPluginRetriever", "RemoteLangChainRetriever"]
+__all__ = ["ChatGPTPluginRetriever", "RemoteLangChainRetriever", "MetalRetriever"]
diff --git a/langchain/retrievers/metal.py b/langchain/retrievers/metal.py
new file mode 100644
index 0000000000..dcaad005f0
--- /dev/null
+++ b/langchain/retrievers/metal.py
@@ -0,0 +1,56 @@
+from typing import Any, List, Optional
+
+from langchain.schema import BaseRetriever, Document
+
+
+class MetalRetriever(BaseRetriever):
+    """Retriever that searches a Metal index through a ``metal_sdk`` client."""
+
+    def __init__(self, client: Any, params: Optional[dict] = None):
+        """Store the Metal client and the extra search parameters.
+
+        Args:
+            client: A ``metal_sdk.metal.Metal`` instance used to run searches.
+            params: Optional keyword arguments forwarded to every
+                ``client.search`` call, e.g. ``{"limit": 2}``. Falsy values
+                (including ``None``) are replaced by an empty dict.
+
+        Raises:
+            ImportError: If ``metal_sdk`` cannot be imported.
+            ValueError: If ``client`` is not a ``metal_sdk.metal.Metal``.
+        """
+        # Imported here rather than at module level, so importing this module
+        # does not itself need metal_sdk.
+        from metal_sdk.metal import Metal
+
+        if not isinstance(client, Metal):
+            raise ValueError(
+                "Got unexpected client, should be of type metal_sdk.metal.Metal. "
+                f"Instead, got {type(client)}"
+            )
+        self.client: Metal = client
+        self.params = params or {}
+
+    def get_relevant_documents(self, query: str) -> List[Document]:
+        """Search Metal for ``query`` and wrap each hit in a ``Document``.
+
+        Args:
+            query: Text sent to ``client.search`` as ``{"text": query}``.
+
+        Returns:
+            One ``Document`` per item in the response's ``"data"`` entry: the
+            item's ``"text"`` value becomes ``page_content`` and all of its
+            other keys are kept as metadata.
+        """
+        results = self.client.search({"text": query}, **self.params)
+        return [
+            Document(
+                page_content=r["text"],
+                metadata={k: v for k, v in r.items() if k != "text"},
+            )
+            for r in results["data"]
+        ]
+
+    async def aget_relevant_documents(self, query: str) -> List[Document]:
+        """Async retrieval is not implemented; always raises NotImplementedError."""
+        raise NotImplementedError