mirror of
https://github.com/hwchase17/langchain
synced 2024-11-18 09:25:54 +00:00
e512d3c6a6
Replaced all `from langchain.callbacks` imports with `from langchain_core.callbacks`. Changes are in the `langchain` and `langchain_experimental` packages. --------- Co-authored-by: Erick Friis <erick@langchain.dev>
41 lines
1.2 KiB
Python
41 lines
1.2 KiB
Python
"""Vector SQL Database Chain Retriever"""
|
|
|
|
from typing import Any, Dict, List
|
|
|
|
from langchain_core.callbacks.manager import (
|
|
AsyncCallbackManagerForRetrieverRun,
|
|
CallbackManagerForRetrieverRun,
|
|
)
|
|
from langchain_core.documents import Document
|
|
from langchain_core.retrievers import BaseRetriever
|
|
|
|
from langchain_experimental.sql.vector_sql import VectorSQLDatabaseChain
|
|
|
|
|
|
class VectorSQLDatabaseChainRetriever(BaseRetriever):
    """Retriever that uses Vector SQL Database."""

    sql_db_chain: VectorSQLDatabaseChain
    """SQL Database Chain"""
    page_content_key: str = "content"
    """column name for page content of documents"""

    def _get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: CallbackManagerForRetrieverRun,
        **kwargs: Any,
    ) -> List[Document]:
        """Run ``query`` through the SQL chain and wrap each result row.

        Args:
            query: Text passed as the input of ``sql_db_chain``.
            run_manager: Retriever callback manager; a child manager derived
                from it is handed to the chain as ``callbacks``.
            **kwargs: Extra keyword arguments forwarded to the chain call.

        Returns:
            One ``Document`` per entry of the chain output's ``"result"``
            value. The page content is the row's ``page_content_key`` entry
            and the metadata is the row itself.

        Raises:
            KeyError: If the chain output has no ``"result"`` entry or a row
                lacks ``page_content_key``.
        """
        chain_output = self.sql_db_chain(
            query, callbacks=run_manager.get_child(), **kwargs
        )
        # The chain output is expected to carry its rows under "result",
        # each row being a mapping of column name to value.
        rows: List[Dict[str, Any]] = chain_output["result"]

        content_key = self.page_content_key
        documents: List[Document] = []
        for row in rows:
            # The full row (content column included) is kept as metadata.
            documents.append(Document(page_content=row[content_key], metadata=row))
        return documents

    async def _aget_relevant_documents(
        self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Async retrieval is not implemented; always raises.

        Raises:
            NotImplementedError: Unconditionally.
        """
        raise NotImplementedError
|