From 76f30f5297dd6519883d46a93ad6994e18bb1ff2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kacper=20=C5=81ukawski?= Date: Thu, 7 Dec 2023 20:13:19 +0100 Subject: [PATCH] langchain[patch]: Rollback multiple keys in Qdrant (#14390) This reverts commit 38813d7090294c0c96d4963a2a230db4fef5e37e. This is a temporary fix, as I don't see a clear way to use multiple keys with `Qdrant.from_texts`. Context: #14378 --- .../langchain/vectorstores/qdrant.py | 106 +++++------------- 1 file changed, 28 insertions(+), 78 deletions(-) diff --git a/libs/langchain/langchain/vectorstores/qdrant.py b/libs/langchain/langchain/vectorstores/qdrant.py index 4d6f3170c8..09cba48911 100644 --- a/libs/langchain/langchain/vectorstores/qdrant.py +++ b/libs/langchain/langchain/vectorstores/qdrant.py @@ -82,8 +82,8 @@ class Qdrant(VectorStore): qdrant = Qdrant(client, collection_name, embedding_function) """ - CONTENT_KEY = ["page_content"] - METADATA_KEY = ["metadata"] + CONTENT_KEY = "page_content" + METADATA_KEY = "metadata" VECTOR_NAME = None def __init__( @@ -91,8 +91,8 @@ class Qdrant(VectorStore): client: Any, collection_name: str, embeddings: Optional[Embeddings] = None, - content_payload_key: Union[list, str] = CONTENT_KEY, - metadata_payload_key: Union[list, str] = METADATA_KEY, + content_payload_key: str = CONTENT_KEY, + metadata_payload_key: str = METADATA_KEY, distance_strategy: str = "COSINE", vector_name: Optional[str] = VECTOR_NAME, embedding_function: Optional[Callable] = None, # deprecated @@ -112,12 +112,6 @@ class Qdrant(VectorStore): f"got {type(client)}" ) - if isinstance(content_payload_key, str): # Ensuring Backward compatibility - content_payload_key = [content_payload_key] - - if isinstance(metadata_payload_key, str): # Ensuring Backward compatibility - metadata_payload_key = [metadata_payload_key] - if embeddings is None and embedding_function is None: raise ValueError( "`embeddings` value can't be None. Pass `Embeddings` instance." 
@@ -133,14 +127,8 @@ class Qdrant(VectorStore): self._embeddings_function = embedding_function self.client: qdrant_client.QdrantClient = client self.collection_name = collection_name - self.content_payload_key = ( - content_payload_key if content_payload_key is not None else self.CONTENT_KEY - ) - self.metadata_payload_key = ( - metadata_payload_key - if metadata_payload_key is not None - else self.METADATA_KEY - ) + self.content_payload_key = content_payload_key or self.CONTENT_KEY + self.metadata_payload_key = metadata_payload_key or self.METADATA_KEY self.vector_name = vector_name or self.VECTOR_NAME if embedding_function is not None: @@ -1190,8 +1178,8 @@ class Qdrant(VectorStore): path: Optional[str] = None, collection_name: Optional[str] = None, distance_func: str = "Cosine", - content_payload_key: List[str] = CONTENT_KEY, - metadata_payload_key: List[str] = METADATA_KEY, + content_payload_key: str = CONTENT_KEY, + metadata_payload_key: str = METADATA_KEY, vector_name: Optional[str] = VECTOR_NAME, batch_size: int = 64, shard_number: Optional[int] = None, @@ -1366,8 +1354,8 @@ class Qdrant(VectorStore): path: Optional[str] = None, collection_name: Optional[str] = None, distance_func: str = "Cosine", - content_payload_key: List[str] = CONTENT_KEY, - metadata_payload_key: List[str] = METADATA_KEY, + content_payload_key: str = CONTENT_KEY, + metadata_payload_key: str = METADATA_KEY, vector_name: Optional[str] = VECTOR_NAME, batch_size: int = 64, shard_number: Optional[int] = None, @@ -1539,8 +1527,8 @@ class Qdrant(VectorStore): path: Optional[str] = None, collection_name: Optional[str] = None, distance_func: str = "Cosine", - content_payload_key: List[str] = CONTENT_KEY, - metadata_payload_key: List[str] = METADATA_KEY, + content_payload_key: str = CONTENT_KEY, + metadata_payload_key: str = METADATA_KEY, vector_name: Optional[str] = VECTOR_NAME, shard_number: Optional[int] = None, replication_factor: Optional[int] = None, @@ -1703,8 +1691,8 @@ class 
Qdrant(VectorStore): path: Optional[str] = None, collection_name: Optional[str] = None, distance_func: str = "Cosine", - content_payload_key: List[str] = CONTENT_KEY, - metadata_payload_key: List[str] = METADATA_KEY, + content_payload_key: str = CONTENT_KEY, + metadata_payload_key: str = METADATA_KEY, vector_name: Optional[str] = VECTOR_NAME, shard_number: Optional[int] = None, replication_factor: Optional[int] = None, @@ -1900,11 +1888,11 @@ class Qdrant(VectorStore): @classmethod def _build_payloads( - cls: Type[Qdrant], + cls, texts: Iterable[str], metadatas: Optional[List[dict]], - content_payload_key: list[str], - metadata_payload_key: list[str], + content_payload_key: str, + metadata_payload_key: str, ) -> List[dict]: payloads = [] for i, text in enumerate(texts): @@ -1925,67 +1913,29 @@ class Qdrant(VectorStore): @classmethod def _document_from_scored_point( - cls: Type[Qdrant], + cls, scored_point: Any, - content_payload_key: list[str], - metadata_payload_key: list[str], + content_payload_key: str, + metadata_payload_key: str, ) -> Document: - payload = scored_point.payload - return Qdrant._document_from_payload( - payload=payload, - content_payload_key=content_payload_key, - metadata_payload_key=metadata_payload_key, + return Document( + page_content=scored_point.payload.get(content_payload_key), + metadata=scored_point.payload.get(metadata_payload_key) or {}, ) @classmethod def _document_from_scored_point_grpc( - cls: Type[Qdrant], + cls, scored_point: Any, - content_payload_key: list[str], - metadata_payload_key: list[str], + content_payload_key: str, + metadata_payload_key: str, ) -> Document: from qdrant_client.conversions.conversion import grpc_to_payload payload = grpc_to_payload(scored_point.payload) - return Qdrant._document_from_payload( - payload=payload, - content_payload_key=content_payload_key, - metadata_payload_key=metadata_payload_key, - ) - - @classmethod - def _document_from_payload( - cls: Type[Qdrant], - payload: Any, - 
content_payload_key: list[str], - metadata_payload_key: list[str], - ) -> Document: - if len(content_payload_key) == 1: - content = payload.get( - content_payload_key - ) # Ensuring backward compatibility - elif len(content_payload_key) > 1: - content = { - content_key: payload.get(content_key) - for content_key in content_payload_key - } - content = str(content) # Ensuring str type output - else: - content = "" - if len(metadata_payload_key) == 1: - metadata = payload.get( - metadata_payload_key - ) # Ensuring backward compatibility - elif len(metadata_payload_key) > 1: - metadata = { - metadata_key: payload.get(metadata_key) - for metadata_key in metadata_payload_key - } - else: - metadata = {} return Document( - page_content=content, - metadata=metadata, + page_content=payload[content_payload_key], + metadata=payload.get(metadata_payload_key) or {}, ) def _build_condition(self, key: str, value: Any) -> List[rest.FieldCondition]: