mirror of
https://github.com/hwchase17/langchain
synced 2024-11-10 01:10:59 +00:00
39 lines
1.2 KiB
Python
39 lines
1.2 KiB
Python
|
"""Vector SQL Database Chain Retriever"""
|
||
|
from typing import Any, Dict, List
|
||
|
|
||
|
from langchain.callbacks.manager import (
|
||
|
AsyncCallbackManagerForRetrieverRun,
|
||
|
CallbackManagerForRetrieverRun,
|
||
|
)
|
||
|
from langchain.schema import BaseRetriever, Document
|
||
|
|
||
|
from langchain_experimental.sql.vector_sql import VectorSQLDatabaseChain
|
||
|
|
||
|
|
||
|
class VectorSQLDatabaseChainRetriever(BaseRetriever):
    """Retriever that runs a VectorSQLDatabaseChain and returns its rows as Documents.

    Each row produced by the chain becomes one ``Document``. The value stored
    under ``page_content_key`` is used as the page content, and the whole row
    dict is attached as metadata.
    """

    sql_db_chain: VectorSQLDatabaseChain
    """Chain invoked with the user query; its ``"result"`` output is read as a list of row dicts."""
    page_content_key: str = "content"
    """Row key whose value becomes each Document's ``page_content``."""

    def _get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: CallbackManagerForRetrieverRun,
        **kwargs: Any,
    ) -> List[Document]:
        """Query the SQL chain and convert every returned row into a Document.

        Args:
            query: Text forwarded unchanged to ``sql_db_chain``.
            run_manager: Retriever callback manager; a child manager derived
                from it is handed to the chain so its callbacks nest under
                this retrieval run.
            **kwargs: Extra keyword arguments passed through to the chain call.

        Returns:
            One Document per row in the chain's ``"result"`` output, in order.

        Raises:
            KeyError: if a row lacks ``page_content_key`` (propagated from the
                row lookup), or if the chain output has no ``"result"`` key.
        """
        chain = self.sql_db_chain
        chain_output = chain(query, callbacks=run_manager.get_child(), **kwargs)
        rows: List[Dict[str, Any]] = chain_output["result"]

        documents: List[Document] = []
        for row in rows:
            # The full row (including the content column) is kept as metadata.
            documents.append(
                Document(page_content=row[self.page_content_key], metadata=row)
            )
        return documents

    async def _aget_relevant_documents(
        self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Async retrieval is not supported by this retriever.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError
|