update SelfQueryRetriever

pull/20800/head
Chester Curme 1 month ago
parent 26455d156d
commit c262cef1fb

@ -157,6 +157,12 @@ SERIALIZABLE_MAPPING: Dict[Tuple[str, ...], Tuple[str, ...]] = {
"base",
"Document",
),
("langchain", "schema", "document_search_hit", "DocumentSearchHit"): (
"langchain_core",
"documents",
"base",
"DocumentSearchHit",
),
("langchain", "output_parsers", "fix", "OutputFixingParser"): (
"langchain",
"output_parsers",
@ -666,6 +672,12 @@ OLD_CORE_NAMESPACES_MAPPING: Dict[Tuple[str, ...], Tuple[str, ...]] = {
"base",
"Document",
),
("langchain_core", "documents", "base", "DocumentSearchHit"): (
"langchain_core",
"documents",
"base",
"DocumentSearchHit",
),
("langchain_core", "prompts", "chat", "AIMessagePromptTemplate"): (
"langchain_core",
"prompts",

@ -33,7 +33,7 @@ from langchain_core.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain_core.documents import Document
from langchain_core.documents import Document, DocumentSearchHit
from langchain_core.language_models import BaseLanguageModel
from langchain_core.pydantic_v1 import Field, root_validator
from langchain_core.retrievers import BaseRetriever
@ -192,19 +192,43 @@ class SelfQueryRetriever(BaseRetriever):
return new_query, search_kwargs
def _get_docs_with_query(
    self, query: str, search_kwargs: Dict[str, Any], include_score: bool = False
) -> List[Document]:
    """Run the prepared query against the vector store.

    Args:
        query: Query string handed to the vector store.
        search_kwargs: Extra keyword arguments forwarded to the search call.
        include_score: When True, use ``similarity_search_with_score`` and
            wrap every result in a ``DocumentSearchHit`` that carries its
            score. ``self.search_type`` is not consulted on this path.

    Returns:
        The result of ``vectorstore.search`` or, when ``include_score`` is
        set, one ``DocumentSearchHit`` per ``(document, score)`` pair.
    """
    if not include_score:
        return self.vectorstore.search(query, self.search_type, **search_kwargs)

    docs_and_scores = self.vectorstore.similarity_search_with_score(
        query, **search_kwargs
    )
    # Copy the metadata along with the content: rebuilding the hit from
    # page_content alone would silently discard it.
    # NOTE(review): assumes DocumentSearchHit accepts ``metadata`` like
    # Document does -- confirm against its definition.
    return [
        DocumentSearchHit(
            page_content=doc.page_content, metadata=doc.metadata, score=score
        )
        for doc, score in docs_and_scores
    ]
async def _aget_docs_with_query(
    self, query: str, search_kwargs: Dict[str, Any], include_score: bool = False
) -> List[Document]:
    """Asynchronously run the prepared query against the vector store.

    Args:
        query: Query string handed to the vector store.
        search_kwargs: Extra keyword arguments forwarded to the search call.
        include_score: When True, use ``asimilarity_search_with_score`` and
            wrap every result in a ``DocumentSearchHit`` that carries its
            score. ``self.search_type`` is not consulted on this path.

    Returns:
        The result of ``vectorstore.asearch`` or, when ``include_score`` is
        set, one ``DocumentSearchHit`` per ``(document, score)`` pair.
    """
    if not include_score:
        return await self.vectorstore.asearch(
            query, self.search_type, **search_kwargs
        )

    docs_and_scores = await self.vectorstore.asimilarity_search_with_score(
        query, **search_kwargs
    )
    # Copy the metadata along with the content: rebuilding the hit from
    # page_content alone would silently discard it.
    # NOTE(review): assumes DocumentSearchHit accepts ``metadata`` like
    # Document does -- confirm against its definition.
    return [
        DocumentSearchHit(
            page_content=doc.page_content, metadata=doc.metadata, score=score
        )
        for doc, score in docs_and_scores
    ]
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
self,
query: str,
*,
run_manager: CallbackManagerForRetrieverRun,
include_score: bool = False,
) -> List[Document]:
"""Get documents relevant for a query.
@ -220,11 +244,17 @@ class SelfQueryRetriever(BaseRetriever):
if self.verbose:
logger.info(f"Generated Query: {structured_query}")
new_query, search_kwargs = self._prepare_query(query, structured_query)
docs = self._get_docs_with_query(new_query, search_kwargs)
docs = self._get_docs_with_query(
new_query, search_kwargs, include_score=include_score
)
return docs
async def _aget_relevant_documents(
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
self,
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun,
include_score: bool = False,
) -> List[Document]:
"""Get documents relevant for a query.
@ -240,7 +270,9 @@ class SelfQueryRetriever(BaseRetriever):
if self.verbose:
logger.info(f"Generated Query: {structured_query}")
new_query, search_kwargs = self._prepare_query(query, structured_query)
docs = await self._aget_docs_with_query(new_query, search_kwargs)
docs = await self._aget_docs_with_query(
new_query, search_kwargs, include_score=include_score
)
return docs
@classmethod

Loading…
Cancel
Save