From 923a7dde5aa58664b0abd2483e6df03a1aaab7fb Mon Sep 17 00:00:00 2001
From: Harrison Chase
Date: Tue, 28 Mar 2023 08:06:27 -0700
Subject: [PATCH] Harrison/llama index loader (#2097)

Co-authored-by: Jerry Liu
---
 langchain/retrievers/llama_index.py | 71 +++++++++++++++++++++++++++++
 1 file changed, 71 insertions(+)
 create mode 100644 langchain/retrievers/llama_index.py

diff --git a/langchain/retrievers/llama_index.py b/langchain/retrievers/llama_index.py
new file mode 100644
index 0000000000..4b3ee6453f
--- /dev/null
+++ b/langchain/retrievers/llama_index.py
@@ -0,0 +1,71 @@
+from typing import Any, Dict, List, cast
+
+from pydantic import BaseModel, Field
+
+from langchain.schema import BaseRetriever, Document
+
+
+class LlamaIndexRetriever(BaseRetriever, BaseModel):
+    """Question-answering with sources over a LlamaIndex data structure."""
+
+    index: Any
+    query_kwargs: Dict = Field(default_factory=dict)
+
+    def get_relevant_documents(self, query: str) -> List[Document]:
+        """Get documents relevant for a query."""
+        try:
+            from llama_index.indices.base import BaseGPTIndex
+            from llama_index.response.schema import Response
+        except ImportError:
+            raise ImportError(
+                "You need to install `llama-index` to use this retriever."
+            )
+        index = cast(BaseGPTIndex, self.index)
+
+        response = index.query(query, response_mode="no_text", **self.query_kwargs)
+        response = cast(Response, response)
+        # parse source nodes
+        docs = []
+        for source_node in response.source_nodes:
+            metadata = source_node.extra_info or {}
+            docs.append(
+                Document(page_content=source_node.source_text, metadata=metadata)
+            )
+        return docs
+
+
+class LlamaIndexGraphRetriever(BaseRetriever, BaseModel):
+    """Question-answering with sources over a LlamaIndex graph data structure."""
+
+    graph: Any
+    query_configs: List[Dict] = Field(default_factory=list)
+
+    def get_relevant_documents(self, query: str) -> List[Document]:
+        """Get documents relevant for a query."""
+        try:
+            from llama_index.composability.graph import (
+                QUERY_CONFIG_TYPE,
+                ComposableGraph,
+            )
+            from llama_index.response.schema import Response
+        except ImportError:
+            raise ImportError(
+                "You need to install `llama-index` to use this retriever."
+            )
+        graph = cast(ComposableGraph, self.graph)
+
+        # for now, inject response_mode="no_text" into query configs
+        for query_config in self.query_configs:
+            query_config["response_mode"] = "no_text"
+        query_configs = cast(List[QUERY_CONFIG_TYPE], self.query_configs)
+        response = graph.query(query, query_configs=query_configs)
+        response = cast(Response, response)
+
+        # parse source nodes
+        docs = []
+        for source_node in response.source_nodes:
+            metadata = source_node.extra_info or {}
+            docs.append(
+                Document(page_content=source_node.source_text, metadata=metadata)
+            )
+        return docs
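
Usage note (not part of the patch): a minimal sketch of how the new LlamaIndexRetriever might be wired up, assuming a llama-index release contemporary with this change, where GPTSimpleVectorIndex and SimpleDirectoryReader are available and index.query() accepts response_mode and similarity_top_k. The data path and query string are placeholders.

# Illustrative sketch only; assumes the llama-index API surface of this era.
from llama_index import GPTSimpleVectorIndex, SimpleDirectoryReader

from langchain.retrievers.llama_index import LlamaIndexRetriever

# Build a LlamaIndex index over local files ("data/" is a placeholder path).
documents = SimpleDirectoryReader("data/").load_data()
index = GPTSimpleVectorIndex(documents)

# Wrap the index; query_kwargs are forwarded to index.query() by the retriever.
retriever = LlamaIndexRetriever(index=index, query_kwargs={"similarity_top_k": 3})
docs = retriever.get_relevant_documents("What does the project do?")
for doc in docs:
    print(doc.metadata, doc.page_content[:80])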
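
Similarly, a sketch for LlamaIndexGraphRetriever, assuming a llama_index ComposableGraph has already been assembled from several indices; graph construction and the exact query-config keys vary by llama-index version, so build_graph() below is a hypothetical stand-in and the config dict is illustrative only.

# Illustrative sketch only; build_graph() is a hypothetical placeholder for your
# own ComposableGraph construction, and the config keys below mirror the
# dict-based query configs accepted by graph.query() at the time (keys may vary).
from langchain.retrievers.llama_index import LlamaIndexGraphRetriever

graph = build_graph()  # hypothetical helper returning a ComposableGraph

query_configs = [
    {
        "index_struct_type": "simple_dict",
        "query_mode": "default",
        "query_kwargs": {"similarity_top_k": 2},
    },
]
retriever = LlamaIndexGraphRetriever(graph=graph, query_configs=query_configs)
docs = retriever.get_relevant_documents("Compare the documents in the graph.")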