langchain/libs/experimental/langchain_experimental/retrievers/vector_sql_database.py
Leonid Ganeline e512d3c6a6
langchain: callbacks imports fix (#20348)
Replaced all `from langchain.callbacks` imports with `from
langchain_core.callbacks`.
Changes are in the `langchain` and `langchain_experimental` packages.

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
2024-04-12 20:13:14 +00:00

41 lines
1.2 KiB
Python

"""Vector SQL Database Chain Retriever"""
from typing import Any, Dict, List
from langchain_core.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
from langchain_experimental.sql.vector_sql import VectorSQLDatabaseChain
class VectorSQLDatabaseChainRetriever(BaseRetriever):
    """Retriever that uses Vector SQL Database.

    The query is handed to ``sql_db_chain``; every row found under the
    ``"result"`` key of the chain's output is wrapped in a ``Document``.
    Only synchronous retrieval is implemented.
    """

    sql_db_chain: VectorSQLDatabaseChain
    """SQL Database Chain"""
    page_content_key: str = "content"
    """column name for page content of documents"""

    def _get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: CallbackManagerForRetrieverRun,
        **kwargs: Any,
    ) -> List[Document]:
        """Run ``query`` through the SQL chain and convert rows to documents.

        Args:
            query: Text passed as the chain's input.
            run_manager: Callback manager for this retriever run; its child
                manager is passed to the chain as ``callbacks``.
            **kwargs: Extra keyword arguments forwarded unchanged to the
                chain call.

        Returns:
            One ``Document`` per row of the chain's ``"result"``. The page
            content is ``row[page_content_key]`` and the metadata is the
            whole row (so it still contains the page-content entry).

        Raises:
            KeyError: If the chain output has no ``"result"`` entry, or a row
                lacks ``page_content_key`` (assuming rows are dicts, as the
                annotation below states).
        """
        rows: List[Dict[str, Any]] = self.sql_db_chain(
            query, callbacks=run_manager.get_child(), **kwargs
        )["result"]
        return [
            Document(page_content=row[self.page_content_key], metadata=row)
            for row in rows
        ]

    async def _aget_relevant_documents(
        self,
        query: str,
        *,
        run_manager: AsyncCallbackManagerForRetrieverRun,
        **kwargs: Any,
    ) -> List[Document]:
        """Async retrieval is not supported by this retriever.

        The signature mirrors ``_get_relevant_documents`` (including
        ``**kwargs``) so that forwarding the same extra arguments to the
        async hook reaches the ``NotImplementedError`` below rather than
        failing earlier with a ``TypeError`` for an unexpected keyword.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError(
            "VectorSQLDatabaseChainRetriever does not support async retrieval."
        )