|
|
|
@ -33,7 +33,7 @@ from langchain_core.callbacks.manager import (
|
|
|
|
|
AsyncCallbackManagerForRetrieverRun,
|
|
|
|
|
CallbackManagerForRetrieverRun,
|
|
|
|
|
)
|
|
|
|
|
from langchain_core.documents import Document
|
|
|
|
|
from langchain_core.documents import Document, DocumentSearchHit
|
|
|
|
|
from langchain_core.language_models import BaseLanguageModel
|
|
|
|
|
from langchain_core.pydantic_v1 import Field, root_validator
|
|
|
|
|
from langchain_core.retrievers import BaseRetriever
|
|
|
|
@ -192,19 +192,43 @@ class SelfQueryRetriever(BaseRetriever):
|
|
|
|
|
return new_query, search_kwargs
|
|
|
|
|
|
|
|
|
|
def _get_docs_with_query(
    self, query: str, search_kwargs: Dict[str, Any], include_score: bool = False
) -> List[Document]:
    """Run an already-prepared query against the vector store.

    Args:
        query: Query text to search with.
        search_kwargs: Extra keyword arguments, forwarded unchanged to the
            vector store call.
        include_score: When False (default), call ``vectorstore.search`` with
            ``self.search_type``. When True, call
            ``vectorstore.similarity_search_with_score`` and wrap every
            ``(doc, score)`` pair in a ``DocumentSearchHit``.

    Returns:
        Whatever ``vectorstore.search`` returns when ``include_score`` is
        False; otherwise a list of ``DocumentSearchHit`` objects, one per
        ``(doc, score)`` pair.
    """
    # Exactly one vector store call per invocation: the plain search is only
    # issued on the unscored path.
    if not include_score:
        return self.vectorstore.search(query, self.search_type, **search_kwargs)

    # NOTE(review): the scored path always uses similarity search and does
    # not consult self.search_type -- confirm this is intended.
    docs_and_scores = self.vectorstore.similarity_search_with_score(
        query, **search_kwargs
    )
    return [
        DocumentSearchHit(
            page_content=doc.page_content,
            # Carry the metadata over so scored hits are not stripped of it.
            # Assumes DocumentSearchHit accepts Document's ``metadata``
            # field -- TODO confirm against its definition.
            metadata=doc.metadata,
            score=score,
        )
        for doc, score in docs_and_scores
    ]
|
|
|
|
|
|
|
|
|
|
async def _aget_docs_with_query(
    self, query: str, search_kwargs: Dict[str, Any], include_score: bool = False
) -> List[Document]:
    """Asynchronously run an already-prepared query against the vector store.

    Args:
        query: Query text to search with.
        search_kwargs: Extra keyword arguments, forwarded unchanged to the
            vector store call.
        include_score: When False (default), await ``vectorstore.asearch``
            with ``self.search_type``. When True, await
            ``vectorstore.asimilarity_search_with_score`` and wrap every
            ``(doc, score)`` pair in a ``DocumentSearchHit``.

    Returns:
        Whatever ``vectorstore.asearch`` returns when ``include_score`` is
        False; otherwise a list of ``DocumentSearchHit`` objects, one per
        ``(doc, score)`` pair.
    """
    # Exactly one vector store call per invocation: the plain search is only
    # awaited on the unscored path.
    if not include_score:
        return await self.vectorstore.asearch(
            query, self.search_type, **search_kwargs
        )

    # NOTE(review): the scored path always uses similarity search and does
    # not consult self.search_type -- confirm this is intended.
    docs_and_scores = await self.vectorstore.asimilarity_search_with_score(
        query, **search_kwargs
    )
    return [
        DocumentSearchHit(
            page_content=doc.page_content,
            # Carry the metadata over so scored hits are not stripped of it.
            # Assumes DocumentSearchHit accepts Document's ``metadata``
            # field -- TODO confirm against its definition.
            metadata=doc.metadata,
            score=score,
        )
        for doc, score in docs_and_scores
    ]
|
|
|
|
|
|
|
|
|
|
def _get_relevant_documents(
|
|
|
|
|
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
|
|
|
|
self,
|
|
|
|
|
query: str,
|
|
|
|
|
*,
|
|
|
|
|
run_manager: CallbackManagerForRetrieverRun,
|
|
|
|
|
include_score: bool = False,
|
|
|
|
|
) -> List[Document]:
|
|
|
|
|
"""Get documents relevant for a query.
|
|
|
|
|
|
|
|
|
@ -220,11 +244,17 @@ class SelfQueryRetriever(BaseRetriever):
|
|
|
|
|
if self.verbose:
|
|
|
|
|
logger.info(f"Generated Query: {structured_query}")
|
|
|
|
|
new_query, search_kwargs = self._prepare_query(query, structured_query)
|
|
|
|
|
docs = self._get_docs_with_query(new_query, search_kwargs)
|
|
|
|
|
docs = self._get_docs_with_query(
|
|
|
|
|
new_query, search_kwargs, include_score=include_score
|
|
|
|
|
)
|
|
|
|
|
return docs
|
|
|
|
|
|
|
|
|
|
async def _aget_relevant_documents(
|
|
|
|
|
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
|
|
|
|
|
self,
|
|
|
|
|
query: str,
|
|
|
|
|
*,
|
|
|
|
|
run_manager: AsyncCallbackManagerForRetrieverRun,
|
|
|
|
|
include_score: bool = False,
|
|
|
|
|
) -> List[Document]:
|
|
|
|
|
"""Get documents relevant for a query.
|
|
|
|
|
|
|
|
|
@ -240,7 +270,9 @@ class SelfQueryRetriever(BaseRetriever):
|
|
|
|
|
if self.verbose:
|
|
|
|
|
logger.info(f"Generated Query: {structured_query}")
|
|
|
|
|
new_query, search_kwargs = self._prepare_query(query, structured_query)
|
|
|
|
|
docs = await self._aget_docs_with_query(new_query, search_kwargs)
|
|
|
|
|
docs = await self._aget_docs_with_query(
|
|
|
|
|
new_query, search_kwargs, include_score=include_score
|
|
|
|
|
)
|
|
|
|
|
return docs
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|