add metal retriever (#2244)

This commit is contained in:
Harrison Chase 2023-04-04 12:17:13 -07:00 committed by GitHub
parent 1f88b11c99
commit 2b975de94d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 185 additions and 1 deletions

View File

@ -0,0 +1,156 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "9fc6205b",
"metadata": {},
"source": [
"# Metal\n",
"\n",
"This notebook shows how to use [Metal's](https://docs.getmetal.io/introduction) retriever.\n",
"\n",
"First, you will need to sign up for Metal and get an API key. You can do so [here](https://docs.getmetal.io/misc-create-app)"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "1a737220",
"metadata": {},
"outputs": [],
"source": [
"# !pip install metal_sdk"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "b1bb478f",
"metadata": {},
"outputs": [],
"source": [
"from metal_sdk.metal import Metal\n",
"API_KEY = \"\"\n",
"CLIENT_ID = \"\"\n",
"APP_ID = \"\"\n",
"\n",
"metal = Metal(API_KEY, CLIENT_ID, APP_ID);\n"
]
},
{
"cell_type": "markdown",
"id": "ae3c3d16",
"metadata": {},
"source": [
"## Ingest Documents\n",
"\n",
"You only need to do this if you haven't already set up an index"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "f0425fa0",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'data': {'id': '642739aa7559b026b4430e42',\n",
" 'text': 'foo',\n",
" 'createdAt': '2023-03-31T19:51:06.748Z'}}"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"metal.index( {\"text\": \"foo1\"})\n",
"metal.index( {\"text\": \"foo\"})"
]
},
{
"cell_type": "markdown",
"id": "944e172b",
"metadata": {},
"source": [
"## Query\n",
"\n",
"Now that our index is set up, we can set up a retriever and start querying it."
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "d0e6f506",
"metadata": {},
"outputs": [],
"source": [
"from langchain.retrievers import MetalRetriever"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "f381f642",
"metadata": {},
"outputs": [],
"source": [
"retriever = MetalRetriever(metal, params={\"limit\": 2})"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "20ae1a74",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[Document(page_content='foo1', metadata={'dist': '1.19209289551e-07', 'id': '642739a17559b026b4430e40', 'createdAt': '2023-03-31T19:50:57.853Z'}),\n",
" Document(page_content='foo1', metadata={'dist': '4.05311584473e-06', 'id': '642738f67559b026b4430e3c', 'createdAt': '2023-03-31T19:48:06.769Z'})]"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"retriever.get_relevant_documents(\"foo1\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1d5a5088",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@ -1,4 +1,5 @@
from langchain.retrievers.chatgpt_plugin_retriever import ChatGPTPluginRetriever
from langchain.retrievers.metal import MetalRetriever
from langchain.retrievers.remote_retriever import RemoteLangChainRetriever
__all__ = ["ChatGPTPluginRetriever", "RemoteLangChainRetriever"]
__all__ = ["ChatGPTPluginRetriever", "RemoteLangChainRetriever", "MetalRetriever"]

View File

@ -0,0 +1,27 @@
from typing import Any, List, Optional
from langchain.schema import BaseRetriever, Document
class MetalRetriever(BaseRetriever):
def __init__(self, client: Any, params: Optional[dict] = None):
from metal_sdk.metal import Metal
if not isinstance(client, Metal):
raise ValueError(
"Got unexpected client, should be of type metal_sdk.metal.Metal. "
f"Instead, got {type(client)}"
)
self.client: Metal = client
self.params = params or {}
def get_relevant_documents(self, query: str) -> List[Document]:
results = self.client.search({"text": query}, **self.params)
final_results = []
for r in results["data"]:
metadata = {k: v for k, v in r.items() if k != "text"}
final_results.append(Document(page_content=r["text"], metadata=metadata))
return final_results
async def aget_relevant_documents(self, query: str) -> List[Document]:
raise NotImplementedError