forked from Archives/langchain
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
28 lines
1021 B
Python
28 lines
1021 B
Python
from typing import Any, List, Optional
|
|
|
|
from langchain.schema import BaseRetriever, Document
|
|
|
|
|
|
class MetalRetriever(BaseRetriever):
    """Retriever that fetches documents through a Metal search client.

    Each hit returned by the client's ``search`` call is turned into a
    ``Document``. The hit's ``"text"`` field becomes ``page_content`` and every
    other field of the hit becomes metadata.
    """

    def __init__(self, client: Any, params: Optional[dict] = None):
        """Validate and store the Metal client and extra search arguments.

        Args:
            client: An instance of ``metal_sdk.metal.Metal``.
            params: Optional keyword arguments forwarded verbatim to
                ``client.search`` on every query. ``None`` (or an empty
                dict) means no extra arguments.

        Raises:
            ImportError: If the ``metal_sdk`` package is not installed.
            ValueError: If ``client`` is not a ``metal_sdk.metal.Metal``.
        """
        # Imported lazily so metal_sdk stays an optional dependency. Re-raise
        # with an actionable install hint instead of a bare ModuleNotFoundError.
        try:
            from metal_sdk.metal import Metal
        except ImportError as e:
            raise ImportError(
                "Could not import metal_sdk python package. "
                "Please install it with `pip install metal_sdk`."
            ) from e

        if not isinstance(client, Metal):
            raise ValueError(
                "Got unexpected client, should be of type metal_sdk.metal.Metal. "
                f"Instead, got {type(client)}"
            )
        self.client: Metal = client
        self.params = params or {}

    def get_relevant_documents(self, query: str) -> List[Document]:
        """Search Metal for ``query`` and convert each hit into a Document.

        Args:
            query: Text sent to the client as the payload ``{"text": query}``.

        Returns:
            One ``Document`` per entry in the response's ``"data"`` list, in
            response order. Each ``Document`` holds the hit's ``"text"`` as
            content and the remaining keys as metadata.
        """
        # Assumes the response is a mapping whose "data" entry is a list of
        # dicts that each carry a "text" key -- TODO confirm against metal_sdk.
        results = self.client.search({"text": query}, **self.params)
        return [
            Document(
                page_content=hit["text"],
                metadata={k: v for k, v in hit.items() if k != "text"},
            )
            for hit in results["data"]
        ]

    async def aget_relevant_documents(self, query: str) -> List[Document]:
        """Async retrieval is not supported by this retriever.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError