langchain/libs/community/langchain_community/vectorstores/semadb.py
Erick Friis 3a2eb6e12b
infra: add print rule to ruff (#16221)
Added noqa for existing prints. Can slowly remove / will prevent more
being intro'd
2024-02-09 16:13:30 -08:00

273 lines
9.5 KiB
Python

from typing import Any, Iterable, List, Optional, Tuple
from uuid import uuid4
import numpy as np
import requests
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.utils import get_from_env
from langchain_core.vectorstores import VectorStore
from langchain_community.vectorstores.utils import DistanceStrategy
class SemaDB(VectorStore):
    """`SemaDB` vector store.

    This vector store is a wrapper around the SemaDB database, which is
    reached over HTTP through its RapidAPI host.

    Example:
        .. code-block:: python

            from langchain_community.vectorstores import SemaDB

            db = SemaDB('mycollection', 768, embeddings, DistanceStrategy.COSINE)

    """

    HOST = "semadb.p.rapidapi.com"
    BASE_URL = "https://" + HOST

    def __init__(
        self,
        collection_name: str,
        vector_size: int,
        embedding: Embeddings,
        distance_strategy: DistanceStrategy = DistanceStrategy.EUCLIDEAN_DISTANCE,
        api_key: str = "",
    ):
        """Initialize the SemaDB vector store.

        Args:
            collection_name: Identifier of the SemaDB collection to operate on.
            vector_size: Expected dimensionality of every embedding vector.
            embedding: Embedding function used for documents and queries.
            distance_strategy: Distance metric of the collection.
            api_key: RapidAPI key. When empty, the ``SEMADB_API_KEY``
                environment variable is used instead.
        """
        self.collection_name = collection_name
        self.vector_size = vector_size
        self.api_key = api_key or get_from_env("api_key", "SEMADB_API_KEY")
        self._embedding = embedding
        self.distance_strategy = distance_strategy

    @property
    def headers(self) -> dict:
        """Return the common headers sent with every request."""
        return {
            "content-type": "application/json",
            "X-RapidAPI-Key": self.api_key,
            "X-RapidAPI-Host": SemaDB.HOST,
        }

    def _get_internal_distance_strategy(self) -> str:
        """Return the SemaDB metric name for the configured distance strategy.

        Raises:
            ValueError: If the strategy is unsupported by SemaDB or unknown.
        """
        if self.distance_strategy == DistanceStrategy.EUCLIDEAN_DISTANCE:
            return "euclidean"
        elif self.distance_strategy == DistanceStrategy.MAX_INNER_PRODUCT:
            raise ValueError("Max inner product is not supported by SemaDB")
        elif self.distance_strategy == DistanceStrategy.DOT_PRODUCT:
            return "dot"
        elif self.distance_strategy == DistanceStrategy.JACCARD:
            # The message names the strategy that was actually requested.
            raise ValueError("Jaccard is not supported by SemaDB")
        elif self.distance_strategy == DistanceStrategy.COSINE:
            return "cosine"
        else:
            raise ValueError(f"Unknown distance strategy {self.distance_strategy}")

    def create_collection(self) -> bool:
        """Create the corresponding collection in SemaDB.

        Returns:
            True if the server answered with HTTP 200, False otherwise.

        Raises:
            ValueError: If the configured distance strategy is not supported.
        """
        payload = {
            "id": self.collection_name,
            "vectorSize": self.vector_size,
            "distanceMetric": self._get_internal_distance_strategy(),
        }
        response = requests.post(
            SemaDB.BASE_URL + "/collections",
            json=payload,
            headers=self.headers,
        )
        return response.status_code == 200

    def delete_collection(self) -> bool:
        """Delete the corresponding collection in SemaDB.

        Returns:
            True if the server answered with HTTP 200, False otherwise.
        """
        response = requests.delete(
            SemaDB.BASE_URL + f"/collections/{self.collection_name}",
            headers=self.headers,
        )
        return response.status_code == 200

    def add_texts(
        self,
        texts: Iterable[str],
        metadatas: Optional[List[dict]] = None,
        batch_size: int = 1000,
        **kwargs: Any,
    ) -> List[str]:
        """Embed texts and add them to the vector store.

        Args:
            texts: Texts to embed and store.
            metadatas: Optional metadata dicts, paired with ``texts`` by
                position. The text itself is stored under the ``"text"``
                metadata key and overrides any existing value for that key.
            batch_size: Maximum number of points sent per insert request.
            **kwargs: Unused.

        Returns:
            The generated ids of the points that were built, in order.

        Raises:
            ValueError: If the embedding size does not match ``vector_size``,
                or if an insert request fails or reports failed ranges.
                Batches sent before the failing one are not rolled back.
        """
        if not isinstance(texts, list):
            texts = list(texts)
        # Nothing to embed or send; also avoids indexing an empty result below.
        if not texts:
            return []
        embeddings = self._embedding.embed_documents(texts)
        # Check dimensions
        if len(embeddings[0]) != self.vector_size:
            raise ValueError(
                f"Embedding size mismatch {len(embeddings[0])} != {self.vector_size}"
            )
        # Normalise if needed
        if self.distance_strategy == DistanceStrategy.COSINE:
            embed_matrix = np.array(embeddings)
            embed_matrix = embed_matrix / np.linalg.norm(
                embed_matrix, axis=1, keepdims=True
            )
            embeddings = embed_matrix.tolist()
        # Create points; without metadatas every point carries only its text.
        if metadatas is None:
            metadatas = [{} for _ in texts]
        points = [
            {
                "id": str(uuid4()),
                "vector": embedding,
                "metadata": {**metadata, "text": text},
            }
            for text, embedding, metadata in zip(texts, embeddings, metadatas)
        ]
        ids: List[str] = [point["id"] for point in points]
        # Insert points in batches
        for start in range(0, len(points), batch_size):
            batch = points[start : start + batch_size]
            response = requests.post(
                SemaDB.BASE_URL + f"/collections/{self.collection_name}/points",
                json={"points": batch},
                headers=self.headers,
            )
            if response.status_code != 200:
                raise ValueError(f"Error adding points: {response.text}")
            failed_ranges = response.json()["failedRanges"]
            if failed_ranges:
                raise ValueError(f"Error adding points: {failed_ranges}")
        # Return ids
        return ids

    @property
    def embeddings(self) -> Embeddings:
        """Return the embeddings."""
        return self._embedding

    def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> Optional[bool]:
        """Delete points by id.

        Args:
            ids: List of ids to delete.
            **kwargs: Other keyword arguments that subclasses might use.

        Returns:
            Optional[bool]: True if the server answered with HTTP 200 and
            reported no failed points, False otherwise.
        """
        payload = {
            "ids": ids,
        }
        response = requests.delete(
            SemaDB.BASE_URL + f"/collections/{self.collection_name}/points",
            json=payload,
            headers=self.headers,
        )
        # The body is only parsed when the status code signals success.
        return response.status_code == 200 and len(response.json()["failedPoints"]) == 0

    def _search_points(self, embedding: List[float], k: int = 4) -> List[dict]:
        """Search the collection for the points closest to ``embedding``.

        Args:
            embedding: Query vector.
            k: Maximum number of points to request.

        Returns:
            The ``"points"`` list of the search response.

        Raises:
            ValueError: If the search request does not return HTTP 200.
        """
        # Normalise if needed
        if self.distance_strategy == DistanceStrategy.COSINE:
            vec = np.array(embedding)
            vec = vec / np.linalg.norm(vec)
            embedding = vec.tolist()
        # Perform search request
        payload = {
            "vector": embedding,
            "limit": k,
        }
        response = requests.post(
            SemaDB.BASE_URL + f"/collections/{self.collection_name}/points/search",
            json=payload,
            headers=self.headers,
        )
        if response.status_code != 200:
            raise ValueError(f"Error searching: {response.text}")
        return response.json()["points"]

    def similarity_search(
        self, query: str, k: int = 4, **kwargs: Any
    ) -> List[Document]:
        """Return docs most similar to query.

        Args:
            query: Text to embed and search with.
            k: Number of Documents to return. Defaults to 4.

        Returns:
            List of Documents most similar to the query.
        """
        query_embedding = self._embedding.embed_query(query)
        return self.similarity_search_by_vector(query_embedding, k=k)

    def similarity_search_with_score(
        self, query: str, k: int = 4, **kwargs: Any
    ) -> List[Tuple[Document, float]]:
        """Run similarity search and return each document with its distance.

        Args:
            query: Text to embed and search with.
            k: Number of Documents to return. Defaults to 4.

        Returns:
            List of ``(Document, distance)`` tuples, where the distance is the
            ``"distance"`` field reported by SemaDB for that point.
        """
        query_embedding = self._embedding.embed_query(query)
        points = self._search_points(query_embedding, k=k)
        return [
            (
                Document(page_content=p["metadata"]["text"], metadata=p["metadata"]),
                p["distance"],
            )
            for p in points
        ]

    def similarity_search_by_vector(
        self, embedding: List[float], k: int = 4, **kwargs: Any
    ) -> List[Document]:
        """Return docs most similar to embedding vector.

        Args:
            embedding: Embedding to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.

        Returns:
            List of Documents most similar to the query vector.
        """
        points = self._search_points(embedding, k=k)
        return [
            Document(page_content=p["metadata"]["text"], metadata=p["metadata"])
            for p in points
        ]

    @classmethod
    def from_texts(
        cls,
        texts: List[str],
        embedding: Embeddings,
        metadatas: Optional[List[dict]] = None,
        collection_name: str = "",
        vector_size: int = 0,
        api_key: str = "",
        distance_strategy: DistanceStrategy = DistanceStrategy.EUCLIDEAN_DISTANCE,
        **kwargs: Any,
    ) -> "SemaDB":
        """Create the collection, add the texts and return the vector store.

        Args:
            texts: Texts to embed and store.
            embedding: Embedding function to use.
            metadatas: Optional metadata dicts, paired with ``texts``.
            collection_name: Name of the collection to create. Required.
            vector_size: Dimensionality of the embeddings. Required.
            api_key: RapidAPI key. When empty, the constructor falls back to
                the ``SEMADB_API_KEY`` environment variable.
            distance_strategy: Distance metric of the new collection.
            **kwargs: Unused.

        Returns:
            The initialized ``SemaDB`` instance.

        Raises:
            ValueError: If ``collection_name`` or ``vector_size`` is missing,
                or the collection cannot be created.
        """
        if not collection_name:
            raise ValueError("Collection name must be provided")
        if not vector_size:
            raise ValueError("Vector size must be provided")
        # API key resolution (argument, then environment) is left to the
        # constructor so both construction paths behave the same way.
        semadb = cls(
            collection_name,
            vector_size,
            embedding,
            distance_strategy=distance_strategy,
            api_key=api_key,
        )
        if not semadb.create_collection():
            raise ValueError("Error creating collection")
        semadb.add_texts(texts, metadatas=metadatas)
        return semadb