Add embedding and vectorstore provider info as tags (#8027)

Example:
https://smith.langchain.com/public/bcd3714d-abba-4790-81c8-9b5718535867/r


The vectorstore implementations aren't super standardized yet, so just
adding an optional embeddings property to pass in.
pull/8071/head
William FH 1 year ago committed by GitHub
parent 355b7d8b86
commit c38965fcba
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -326,6 +326,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> None:
@ -346,6 +347,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
start_time=start_time,
execution_order=execution_order,
child_execution_order=execution_order,
tags=tags,
child_runs=[],
run_type=RunTypeEnum.retriever,
)

@ -79,6 +79,10 @@ class AnalyticDB(VectorStore):
self.engine = create_engine(self.connection_string, **_engine_args)
self.create_collection()
@property
def embeddings(self) -> Embeddings:
return self.embedding_function
def _select_relevance_score_fn(self) -> Callable[[float], float]:
return self._euclidean_relevance_score_fn

@ -61,6 +61,11 @@ class Annoy(VectorStore):
self.docstore = docstore
self.index_to_docstore_id = index_to_docstore_id
@property
def embeddings(self) -> Optional[Embeddings]:
# TODO: Accept embedding object directly
return None
def add_texts(
self,
texts: Iterable[str],

@ -46,7 +46,7 @@ class AtlasDB(VectorStore):
Args:
name (str): The name of your project. If the project already exists,
it will be loaded.
embedding_function (Optional[Callable]): An optional function used for
embedding_function (Optional[Embeddings]): An optional function used for
embedding your data. If None, data will be embedded with
Nomic's embed model.
api_key (str): Your nomic API key
@ -86,6 +86,10 @@ class AtlasDB(VectorStore):
)
self.project._latest_project_state()
@property
def embeddings(self) -> Optional[Embeddings]:
return self._embedding_function
def add_texts(
self,
texts: Iterable[str],

@ -73,6 +73,12 @@ class AwaDB(VectorStore):
self.table2embeddings[table_name] = embedding
self.using_table_name = table_name
@property
def embeddings(self) -> Optional[Embeddings]:
if self.using_table_name in self.table2embeddings:
return self.table2embeddings[self.using_table_name]
return None
def add_texts(
self,
texts: Iterable[str],

@ -191,6 +191,11 @@ class AzureSearch(VectorStore):
self.semantic_configuration_name = semantic_configuration_name
self.semantic_query_language = semantic_query_language
@property
def embeddings(self) -> Optional[Embeddings]:
# TODO: Support embedding object directly
return None
def add_texts(
self,
texts: Iterable[str],

@ -3,6 +3,7 @@
from __future__ import annotations
import asyncio
import logging
import math
import warnings
from abc import ABC, abstractmethod
@ -31,6 +32,8 @@ from langchain.docstore.document import Document
from langchain.embeddings.base import Embeddings
from langchain.schema import BaseRetriever
logger = logging.getLogger(__name__)
VST = TypeVar("VST", bound="VectorStore")
@ -55,6 +58,14 @@ class VectorStore(ABC):
List of ids from adding the texts into the vectorstore.
"""
@property
def embeddings(self) -> Optional[Embeddings]:
"""Access the query embedding object if available."""
logger.debug(
f"{Embeddings.__name__} is not implemented for {self.__class__.__name__}"
)
return None
def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> Optional[bool]:
"""Delete by vector ID or other criteria.
@ -435,8 +446,17 @@ class VectorStore(ABC):
"""Return VectorStore initialized from texts and embeddings."""
raise NotImplementedError
def __get_retriever_tags(self) -> List[str]:
"""Get tags for retriever."""
tags = [self.__class__.__name__]
if self.embeddings:
tags.append(self.embeddings.__class__.__name__)
return tags
def as_retriever(self, **kwargs: Any) -> VectorStoreRetriever:
return VectorStoreRetriever(vectorstore=self, **kwargs)
tags = kwargs.pop("tags", None) or []
tags.extend(self.__get_retriever_tags())
return VectorStoreRetriever(vectorstore=self, **kwargs, tags=tags)
class VectorStoreRetriever(BaseRetriever):

@ -77,6 +77,10 @@ class Cassandra(VectorStore):
primary_key_type="TEXT",
)
@property
def embeddings(self) -> Embeddings:
return self.embedding
def _select_relevance_score_fn(self) -> Callable[[float], float]:
return self._cosine_relevance_score_fn

@ -121,6 +121,10 @@ class Chroma(VectorStore):
)
self.override_relevance_score_fn = relevance_score_fn
@property
def embeddings(self) -> Optional[Embeddings]:
return self._embedding_function
@xor_args(("query_texts", "query_embeddings"))
def __query_collection(
self,

@ -212,6 +212,10 @@ CREATE TABLE IF NOT EXISTS {self.config.database}.{self.config.table}(
self.client.command("SET allow_experimental_annoy_index=1")
self.client.command(self.schema)
@property
def embeddings(self) -> Embeddings:
return self.embedding_function
def escape_str(self, value: str) -> str:
return "".join(f"{self.BS}{c}" if c in self.must_escape else c for c in value)

@ -151,6 +151,10 @@ class DeepLake(VectorStore):
self._embedding_function = embedding_function
self._id_tensor_name = "ids" if "ids" in self.vectorstore.tensors() else "id"
@property
def embeddings(self) -> Optional[Embeddings]:
return self._embedding_function
def add_texts(
self,
texts: Iterable[str],

@ -153,6 +153,10 @@ class ElasticVectorSearch(VectorStore, ABC):
f"Your elasticsearch client string is mis-formatted. Got error: {e} "
)
@property
def embeddings(self) -> Embeddings:
return self.embeddings
def add_texts(
self,
texts: Iterable[str],

@ -87,6 +87,11 @@ class FAISS(VectorStore):
)
)
@property
def embeddings(self) -> Optional[Embeddings]:
# TODO: Accept embeddings object directly
return None
def __add(
self,
texts: Iterable[str],

@ -148,6 +148,10 @@ class Hologres(VectorStore):
self.create_vector_extension()
self.create_table()
@property
def embeddings(self) -> Embeddings:
return self.embedding_function
def create_vector_extension(self) -> None:
try:
self.storage.create_vector_extension()

@ -51,6 +51,10 @@ class LanceDB(VectorStore):
self._id_key = id_key
self._text_key = text_key
@property
def embeddings(self) -> Embeddings:
return self._embedding
def add_texts(
self,
texts: Iterable[str],

@ -82,6 +82,10 @@ class Marqo(VectorStore):
self._document_batch_size = 1024
@property
def embeddings(self) -> Optional[Embeddings]:
return None
def add_texts(
self,
texts: Iterable[str],

@ -84,6 +84,10 @@ class MatchingEngine(VectorStore):
self.credentials = credentials
self.gcs_bucket_name = gcs_bucket_name
@property
def embeddings(self) -> Embeddings:
return self.embedding
def _validate_google_libraries_installation(self) -> None:
"""Validates that Google libraries that are needed are installed."""
try:

@ -164,6 +164,10 @@ class Milvus(VectorStore):
# Initialize the vector store
self._init()
@property
def embeddings(self) -> Embeddings:
return self.embedding_func
def _create_connection_alias(self, connection_args: dict) -> str:
"""Create the connection to the Milvus server."""
from pymilvus import MilvusException, connections

@ -77,6 +77,10 @@ class MongoDBAtlasVectorSearch(VectorStore):
self._text_key = text_key
self._embedding_key = embedding_key
@property
def embeddings(self) -> Embeddings:
return self._embedding
@classmethod
def from_connection_string(
cls,

@ -115,7 +115,7 @@ class MyScale(VectorStore):
) -> None:
"""MyScale Wrapper to LangChain
embedding_function (Embeddings):
embedding (Embeddings):
config (MyScaleSettings): Configuration to MyScale Client
Other keyword arguments will pass into
[clickhouse-connect](https://docs.myscale.com/)
@ -175,7 +175,7 @@ class MyScale(VectorStore):
self.dim = dim
self.BS = "\\"
self.must_escape = ("\\", "'")
self.embedding_function = embedding.embed_query
self._embeddings = embedding
self.dist_order = "ASC" if self.config.metric in ["cosine", "l2"] else "DESC"
# Create a connection to myscale
@ -189,6 +189,10 @@ class MyScale(VectorStore):
self.client.command("SET allow_experimental_object_type=1")
self.client.command(schema_)
@property
def embeddings(self) -> Embeddings:
return self._embeddings
def escape_str(self, value: str) -> str:
return "".join(f"{self.BS}{c}" if c in self.must_escape else c for c in value)
@ -238,7 +242,7 @@ class MyScale(VectorStore):
column_names = {
colmap_["id"]: ids,
colmap_["text"]: texts,
colmap_["vector"]: map(self.embedding_function, texts),
colmap_["vector"]: map(self._embeddings.embed_query, texts),
}
metadatas = metadatas or [{} for _ in texts]
column_names[colmap_["metadata"]] = map(json.dumps, metadatas)
@ -269,7 +273,7 @@ class MyScale(VectorStore):
@classmethod
def from_texts(
cls,
texts: List[str],
texts: Iterable[str],
embedding: Embeddings,
metadatas: Optional[List[Dict[Any, Any]]] = None,
config: Optional[MyScaleSettings] = None,
@ -280,8 +284,8 @@ class MyScale(VectorStore):
"""Create Myscale wrapper with existing texts
Args:
embedding_function (Embeddings): Function to extract text embedding
texts (Iterable[str]): List or tuple of strings to be added
embedding (Embeddings): Function to extract text embedding
config (MyScaleSettings, Optional): Myscale configuration
text_ids (Optional[Iterable], optional): IDs for the texts.
Defaults to None.
@ -357,7 +361,7 @@ class MyScale(VectorStore):
List[Document]: List of Documents
"""
return self.similarity_search_by_vector(
self.embedding_function(query), k, where_str, **kwargs
self._embeddings.embed_query(query), k, where_str, **kwargs
)
def similarity_search_by_vector(
@ -417,7 +421,7 @@ class MyScale(VectorStore):
and cosine distance in float for each.
Lower score represents more similarity.
"""
q_str = self._build_qstr(self.embedding_function(query), k, where_str)
q_str = self._build_qstr(self._embeddings.embed_query(query), k, where_str)
try:
return [
(

@ -316,6 +316,10 @@ class OpenSearchVectorSearch(VectorStore):
self.index_name = index_name
self.client = _get_opensearch_client(opensearch_url, **kwargs)
@property
def embeddings(self) -> Embeddings:
return self.embedding_function
def add_texts(
self,
texts: Iterable[str],

@ -135,6 +135,10 @@ class PGEmbedding(VectorStore):
self.create_tables_if_not_exists()
self.create_collection()
@property
def embeddings(self) -> Embeddings:
return self.embedding_function
def connect(self) -> sqlalchemy.engine.Connection:
engine = sqlalchemy.create_engine(self.connection_string)
conn = engine.connect()

@ -125,6 +125,10 @@ class PGVector(VectorStore):
self.create_tables_if_not_exists()
self.create_collection()
@property
def embeddings(self) -> Embeddings:
return self.embedding_function
def connect(self) -> sqlalchemy.engine.Connection:
engine = sqlalchemy.create_engine(self.connection_string)
conn = engine.connect()

@ -62,6 +62,11 @@ class Pinecone(VectorStore):
self._namespace = namespace
self.distance_strategy = distance_strategy
@property
def embeddings(self) -> Optional[Embeddings]:
# TODO: Accept this object directly
return None
def add_texts(
self,
texts: Iterable[str],

@ -123,7 +123,7 @@ class Qdrant(VectorStore):
"Use `embeddings` only."
)
self.embeddings = embeddings
self._embeddings = embeddings
self._embeddings_function = embedding_function
self.client: qdrant_client.QdrantClient = client
self.collection_name = collection_name
@ -143,10 +143,14 @@ class Qdrant(VectorStore):
"Using `embeddings` as `embedding_function` which is deprecated"
)
self._embeddings_function = embeddings
self.embeddings = None
self._embeddings = None
self.distance_strategy = distance_strategy.upper()
@property
def embeddings(self) -> Optional[Embeddings]:
return self._embeddings
def add_texts(
self,
texts: Iterable[str],

@ -161,6 +161,11 @@ class Redis(VectorStore):
self.distance_metric = distance_metric
self.relevance_score_fn = relevance_score_fn
@property
def embeddings(self) -> Optional[Embeddings]:
# TODO: Accept embedding object directly
return None
def _select_relevance_score_fn(self) -> Callable[[float], float]:
if self.relevance_score_fn:
return self.relevance_score_fn
@ -601,7 +606,9 @@ class Redis(VectorStore):
)
def as_retriever(self, **kwargs: Any) -> RedisVectorStoreRetriever:
return RedisVectorStoreRetriever(vectorstore=self, **kwargs)
tags = kwargs.pop("tags", None) or []
tags.extend(self.__get_retriever_tags())
return RedisVectorStoreRetriever(vectorstore=self, **kwargs, tags=tags)
class RedisVectorStoreRetriever(VectorStoreRetriever):

@ -83,6 +83,10 @@ class Rockset(VectorStore):
self._text_key = text_key
self._embedding_key = embedding_key
@property
def embeddings(self) -> Embeddings:
return self._embeddings
def add_texts(
self,
texts: Iterable[str],

@ -213,6 +213,10 @@ class SingleStoreDB(VectorStore):
)
self._create_table()
@property
def embeddings(self) -> Embeddings:
return self.embedding
def _select_relevance_score_fn(self) -> Callable[[float], float]:
return self._max_inner_product_relevance_score_fn
@ -441,7 +445,9 @@ class SingleStoreDB(VectorStore):
return instance
def as_retriever(self, **kwargs: Any) -> SingleStoreDBRetriever:
return SingleStoreDBRetriever(vectorstore=self, **kwargs)
tags = kwargs.pop("tags", None) or []
tags.extend(self.__get_retriever_tags())
return SingleStoreDBRetriever(vectorstore=self, **kwargs, tags=tags)
class SingleStoreDBRetriever(VectorStoreRetriever):

@ -163,6 +163,10 @@ class SKLearnVectorStore(VectorStore):
if self._persist_path is not None and os.path.isfile(self._persist_path):
self._load()
@property
def embeddings(self) -> Embeddings:
return self._embedding_function
def persist(self) -> None:
if self._serializer is None:
raise SKLearnVectorStoreException(

@ -207,6 +207,10 @@ CREATE TABLE IF NOT EXISTS {self.config.database}.{self.config.table}(
def escape_str(self, value: str) -> str:
return "".join(f"{self.BS}{c}" if c in self.must_escape else c for c in value)
@property
def embeddings(self) -> Embeddings:
return self.embedding_function
def _build_insert_sql(self, transac: Iterable, column_names: Iterable[str]) -> str:
ks = ",".join(column_names)
embed_tuple_index = tuple(column_names).index(

@ -67,6 +67,10 @@ class SupabaseVectorStore(VectorStore):
self.table_name = table_name or "documents"
self.query_name = query_name or "match_documents"
@property
def embeddings(self) -> Embeddings:
return self._embedding
def add_texts(
self,
texts: Iterable[str],

@ -51,6 +51,10 @@ class Tair(VectorStore):
self.metadata_key = metadata_key
self.search_params = search_params
@property
def embeddings(self) -> Embeddings:
return self.embedding_function
def create_index_if_not_exist(
self,
dim: int,

@ -28,6 +28,10 @@ class Tigris(VectorStore):
self._embed_fn = embeddings
self._vector_store = TigrisVectorStore(client.get_search(), index_name)
@property
def embeddings(self) -> Embeddings:
return self._embed_fn
@property
def search_index(self) -> TigrisVectorStore:
return self._vector_store

@ -81,6 +81,10 @@ class Typesense(VectorStore):
def _collection(self) -> Collection:
return self._typesense_client.collections[self._typesense_collection_name]
@property
def embeddings(self) -> Embeddings:
return self._embedding
def _prep_texts(
self,
texts: Iterable[str],

@ -61,6 +61,10 @@ class Vectara(VectorStore):
adapter = requests.adapters.HTTPAdapter(max_retries=3)
self._session.mount("http://", adapter)
@property
def embeddings(self) -> Optional[Embeddings]:
return None
def _get_post_headers(self) -> dict:
"""Returns headers that should be attached to each post request."""
return {
@ -402,7 +406,9 @@ class Vectara(VectorStore):
return vectara
def as_retriever(self, **kwargs: Any) -> VectaraRetriever:
return VectaraRetriever(vectorstore=self, **kwargs)
tags = kwargs.pop("tags", None) or []
tags.extend(self.__get_retriever_tags())
return VectaraRetriever(vectorstore=self, **kwargs, tags=tags)
class VectaraRetriever(VectorStoreRetriever):

@ -118,6 +118,10 @@ class Weaviate(VectorStore):
if attributes is not None:
self._query_attrs.extend(attributes)
@property
def embeddings(self) -> Optional[Embeddings]:
return self._embedding
def _select_relevance_score_fn(self) -> Callable[[float], float]:
return (
self.relevance_score_fn

Loading…
Cancel
Save