langchain/tests/integration_tests/vectorstores/fake_embeddings.py
Kacper Łukawski f93d256190
Feat: Add batching to Qdrant (#5443)
# Add batching to Qdrant

Several people requested a batching mechanism for uploading data to
Qdrant. This matters because Qdrant limits the maximum size of a request
payload, and without batching implemented in LangChain, users have to
implement it on their own. This PR exposes a new optional `batch_size`
parameter, so that all documents/texts are uploaded in batches of the
given size (64 by default).

The Qdrant integration tests are extended to cover two cases:
1. Documents are sent in separate batches.
2. All the documents are sent in a single request.
2023-05-30 15:33:54 -07:00

48 lines
1.8 KiB
Python

"""Fake Embedding class for testing purposes."""
from typing import List
from langchain.embeddings.base import Embeddings
# Small fixed corpus of sample texts for the tests that use these fakes.
fake_texts = "foo bar baz".split()
class FakeEmbeddings(Embeddings):
    """Deterministic stand-in for a real embedding model, for use in tests.

    Every vector has ten components: nine leading ``1.0`` values followed by
    one final component that carries the distinguishing value.
    """

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed each text by its position in ``texts``.

        The i-th text maps to ``[1.0] * 9 + [float(i)]``; the content of the
        text itself is ignored.
        """
        prefix = [1.0] * 9
        # ``prefix + [...]`` builds a new list per text, so no vectors alias.
        return [prefix + [float(position)] for position in range(len(texts))]

    def embed_query(self, text: str) -> List[float]:
        """Embed any query as ``[1.0] * 9 + [0.0]``, ignoring its content.

        The result equals ``embed_documents(texts)[0]``, so the Euclidean
        distance from the query to the i-th embedded document is exactly ``i``.
        """
        return [1.0] * 9 + [0.0]
class ConsistentFakeEmbeddings(FakeEmbeddings):
    """Fake embeddings which remember all the texts seen so far to return
    consistent vectors for the same texts.

    Each distinct text gets a stable index the first time ``embed_documents``
    sees it. That text then always maps to the vector made of
    ``dimensionality - 1`` leading ``1.0`` values followed by ``float(index)``.
    """

    def __init__(self, dimensionality: int = 10) -> None:
        """Create an instance that has not seen any texts yet.

        Args:
            dimensionality: Length of every returned vector. Defaults to 10,
                which matches the previously hard-coded size.

        Raises:
            ValueError: If ``dimensionality`` is smaller than 1 (there must be
                room for the index component).
        """
        if dimensionality < 1:
            raise ValueError(
                f"dimensionality must be at least 1, got {dimensionality}"
            )
        super().__init__()
        # Texts in first-seen order; a text's position here is its index.
        self.known_texts: List[str] = []
        self.dimensionality = dimensionality

    def _vector(self, index: int) -> List[float]:
        """Return a fresh vector of ones whose last component is ``index``."""
        return [1.0] * (self.dimensionality - 1) + [float(index)]

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Return consistent embeddings for each text seen so far.

        Texts not seen before are registered (appended to ``known_texts``)
        before being embedded, so repeated calls map the same text to the same
        vector.
        """
        out_vectors = []
        for text in texts:
            if text not in self.known_texts:
                self.known_texts.append(text)
            out_vectors.append(self._vector(self.known_texts.index(text)))
        return out_vectors

    def embed_query(self, text: str) -> List[float]:
        """Return consistent embeddings for the text, if seen before, or a
        constant one if the text is unknown.

        Unknown texts are NOT registered. They get the index-0 vector, which
        is the same vector as the first known text.
        """
        if text not in self.known_texts:
            return self._vector(0)
        return self._vector(self.known_texts.index(text))