You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
langchain/langchain/retrievers/self_query/base.py

155 lines
5.7 KiB
Python

"""Retriever that generates and executes structured queries over its own data source."""
from typing import Any, Dict, List, Optional, Type, cast
from pydantic import BaseModel, Field, root_validator
from langchain import LLMChain
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import Callbacks
from langchain.chains.query_constructor.base import load_query_constructor_chain
from langchain.chains.query_constructor.ir import StructuredQuery, Visitor
from langchain.chains.query_constructor.schema import AttributeInfo
from langchain.retrievers.self_query.chroma import ChromaTranslator
from langchain.retrievers.self_query.myscale import MyScaleTranslator
from langchain.retrievers.self_query.pinecone import PineconeTranslator
from langchain.retrievers.self_query.qdrant import QdrantTranslator
from langchain.retrievers.self_query.weaviate import WeaviateTranslator
from langchain.schema import BaseRetriever, Document
from langchain.vectorstores import (
Chroma,
MyScale,
Pinecone,
Qdrant,
VectorStore,
Weaviate,
)
def _get_builtin_translator(vectorstore: VectorStore) -> Visitor:
"""Get the translator class corresponding to the vector store class."""
vectorstore_cls = vectorstore.__class__
BUILTIN_TRANSLATORS: Dict[Type[VectorStore], Type[Visitor]] = {
Pinecone: PineconeTranslator,
Chroma: ChromaTranslator,
Weaviate: WeaviateTranslator,
Qdrant: QdrantTranslator,
MyScale: MyScaleTranslator,
}
if vectorstore_cls not in BUILTIN_TRANSLATORS:
raise ValueError(
f"Self query retriever with Vector Store type {vectorstore_cls}"
f" not supported."
)
if isinstance(vectorstore, Qdrant):
return QdrantTranslator(metadata_key=vectorstore.metadata_payload_key)
elif isinstance(vectorstore, MyScale):
return MyScaleTranslator(metadata_key=vectorstore.metadata_column)
return BUILTIN_TRANSLATORS[vectorstore_cls]()
class SelfQueryRetriever(BaseRetriever, BaseModel):
"""Retriever that wraps around a vector store and uses an LLM to generate
the vector store queries."""
vectorstore: VectorStore
"""The underlying vector store from which documents will be retrieved."""
llm_chain: LLMChain
"""The LLMChain for generating the vector store queries."""
search_type: str = "similarity"
"""The search type to perform on the vector store."""
search_kwargs: dict = Field(default_factory=dict)
"""Keyword arguments to pass in to the vector store search."""
structured_query_translator: Visitor
"""Translator for turning internal query language into vectorstore search params."""
verbose: bool = False
"""Use original query instead of the revised new query from LLM"""
use_original_query: bool = False
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
@root_validator(pre=True)
def validate_translator(cls, values: Dict) -> Dict:
"""Validate translator."""
if "structured_query_translator" not in values:
values["structured_query_translator"] = _get_builtin_translator(
values["vectorstore"]
)
return values
def get_relevant_documents(
self, query: str, callbacks: Callbacks = None
) -> List[Document]:
"""Get documents relevant for a query.
Args:
query: string to find relevant documents for
Returns:
List of relevant documents
"""
inputs = self.llm_chain.prep_inputs({"query": query})
structured_query = cast(
StructuredQuery,
self.llm_chain.predict_and_parse(callbacks=callbacks, **inputs),
)
if self.verbose:
print(structured_query)
new_query, new_kwargs = self.structured_query_translator.visit_structured_query(
structured_query
)
if structured_query.limit is not None:
new_kwargs["k"] = structured_query.limit
if self.use_original_query:
new_query = query
search_kwargs = {**self.search_kwargs, **new_kwargs}
docs = self.vectorstore.search(new_query, self.search_type, **search_kwargs)
return docs
async def aget_relevant_documents(self, query: str) -> List[Document]:
raise NotImplementedError
@classmethod
def from_llm(
cls,
llm: BaseLanguageModel,
vectorstore: VectorStore,
document_contents: str,
metadata_field_info: List[AttributeInfo],
structured_query_translator: Optional[Visitor] = None,
chain_kwargs: Optional[Dict] = None,
enable_limit: bool = False,
use_original_query: bool = False,
**kwargs: Any,
) -> "SelfQueryRetriever":
if structured_query_translator is None:
structured_query_translator = _get_builtin_translator(vectorstore)
chain_kwargs = chain_kwargs or {}
if "allowed_comparators" not in chain_kwargs:
chain_kwargs[
"allowed_comparators"
] = structured_query_translator.allowed_comparators
if "allowed_operators" not in chain_kwargs:
chain_kwargs[
"allowed_operators"
] = structured_query_translator.allowed_operators
llm_chain = load_query_constructor_chain(
llm,
document_contents,
metadata_field_info,
enable_limit=enable_limit,
**chain_kwargs,
)
return cls(
llm_chain=llm_chain,
vectorstore=vectorstore,
use_original_query=use_original_query,
structured_query_translator=structured_query_translator,
**kwargs,
)