"""Module providing Infinispan as a VectorStore"""

from __future__ import annotations

import json
import logging
import uuid
from typing import Any, Iterable, List, Optional, Tuple, Type, cast

import requests
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.vectorstores import VectorStore

logger = logging.getLogger(__name__)


class InfinispanVS(VectorStore):
    """`Infinispan` VectorStore interface.

    This class exposes the method to present Infinispan as a
    VectorStore. It relies on the Infinispan class (below) which takes care
    of the REST interface with the server.

    Example:
        .. code-block:: python

            from langchain_community.vectorstores import InfinispanVS
            from mymodels import RGBEmbeddings
            ...
            vectorDb = InfinispanVS.from_documents(docs,
                            embedding=RGBEmbeddings(),
                            output_fields=["texture", "color"],
                            lambda_key=lambda text, meta: str(meta["_key"]),
                            lambda_content=lambda item: item["color"])

        or an empty InfinispanVS instance can be created if preliminary setup
        is required before populating the store:

        .. code-block:: python

            from langchain_community.vectorstores import InfinispanVS
            from mymodels import RGBEmbeddings
            ...
            ispnVS = InfinispanVS()
            # configure Infinispan here
            # i.e. create cache and schema

            # then populate the store
            vectorDb = InfinispanVS.from_documents(docs,
                            embedding=RGBEmbeddings(),
                            output_fields=["texture", "color"],
                            lambda_key=lambda text, meta: str(meta["_key"]),
                            lambda_content=lambda item: item["color"])
    """

    def __init__(
        self,
        embedding: Optional[Embeddings] = None,
        ids: Optional[List[str]] = None,
        **kwargs: Any,
    ):
        self.ispn = Infinispan(**kwargs)
        self._configuration = kwargs
        self._cache_name = str(self._configuration.get("cache_name", "vector"))
        self._entity_name = str(self._configuration.get("entity_name", "vector"))
        self._embedding = embedding
        self._textfield = self._configuration.get("textfield", "text")
        self._vectorfield = self._configuration.get("vectorfield", "vector")
        self._to_content = self._configuration.get(
            "lambda_content", lambda item: self._default_content(item)
        )
        self._to_metadata = self._configuration.get(
            "lambda_metadata", lambda item: self._default_metadata(item)
        )
        self._output_fields = self._configuration.get("output_fields")
        self._ids = ids

    def _default_metadata(self, item: dict) -> dict:
        meta = dict(item)
        meta.pop(self._vectorfield, None)
        meta.pop(self._textfield, None)
        meta.pop("_type", None)
        return meta

    def _default_content(self, item: dict[str, Any]) -> Any:
        return item.get(self._textfield)

    def schema_builder(self, templ: dict, dimension: int) -> str:
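        """Build the protobuf schema for the configured entity.

        Each metadata field in the template is mapped to a protobuf type
        from the Python type of its value: str -> string, bool -> bool,
        int -> int64, float -> double, bytes -> bytes.

        Illustrative sketch (the field names here are invented): with
        ``templ={"title": "t", "year": 2020}``, ``dimension=3`` and the
        default entity and vector field names, the result is roughly::

            /**
             * @Indexed
             */
            message vector {
            /**
             * @Vector(dimension=3)
             */
            repeated float vector = 1;
            optional string title = 2;
            optional int64 year = 3;
            }

        Args:
            templ(dict): metadata template; only the value types are used
            dimension(int): dimension of the embedding vector
        Returns:
            The protobuf schema as a string
        """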
        metadata_proto_tpl = """
/**
 * @Indexed
 */
message %s {
/**
 * @Vector(dimension=%d)
 */
repeated float %s = 1;
"""
        metadata_proto = metadata_proto_tpl % (
            self._entity_name,
            dimension,
            self._vectorfield,
        )
        idx = 2
        for f, v in templ.items():
            if isinstance(v, str):
                metadata_proto += "optional string " + f + " = " + str(idx) + ";\n"
            # check bool before int: bool is a subclass of int in Python
            elif isinstance(v, bool):
                metadata_proto += "optional bool " + f + " = " + str(idx) + ";\n"
            elif isinstance(v, int):
                metadata_proto += "optional int64 " + f + " = " + str(idx) + ";\n"
            elif isinstance(v, float):
                metadata_proto += "optional double " + f + " = " + str(idx) + ";\n"
            elif isinstance(v, bytes):
                metadata_proto += "optional bytes " + f + " = " + str(idx) + ";\n"
            else:
                raise Exception(
                    "Unable to build proto schema for metadata. "
                    "Unhandled type for field: " + f
                )
            idx += 1
        metadata_proto += "}\n"
        return metadata_proto

    def schema_create(self, proto: str) -> requests.Response:
        """Deploy the schema for the vector db

        Args:
            proto(str): protobuf schema
        Returns:
            An http Response containing the result of the operation
        """
        return self.ispn.schema_post(self._entity_name + ".proto", proto)

    def schema_delete(self) -> requests.Response:
        """Delete the schema for the vector db

        Returns:
            An http Response containing the result of the operation
        """
        return self.ispn.schema_delete(self._entity_name + ".proto")

    def cache_create(self, config: str = "") -> requests.Response:
        """Create the cache for the vector db

        Args:
            config(str): configuration of the cache. If empty, a default
                configuration for an indexed, protostream-encoded
                distributed cache is generated.
        Returns:
            An http Response containing the result of the operation
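
        Example (an illustrative sketch; ``ispnvs`` is an InfinispanVS
        instance and the JSON follows the Infinispan cache-configuration
        format, as in the generated default below):
            .. code-block:: python

                config = '''
                {
                  "distributed-cache": {
                    "mode": "SYNC",
                    "encoding": {"media-type": "application/x-protostream"},
                    "indexing": {
                      "enabled": true,
                      "indexed-entities": ["vector"]
                    }
                  }
                }
                '''
                response = ispnvs.cache_create(config)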
"""
|
|
if config == "":
|
|
config = (
|
|
'''
|
|
{
|
|
"distributed-cache": {
|
|
"owners": "2",
|
|
"mode": "SYNC",
|
|
"statistics": true,
|
|
"encoding": {
|
|
"media-type": "application/x-protostream"
|
|
},
|
|
"indexing": {
|
|
"enabled": true,
|
|
"storage": "filesystem",
|
|
"startup-mode": "AUTO",
|
|
"indexing-mode": "AUTO",
|
|
"indexed-entities": [
|
|
"'''
|
|
+ self._entity_name
|
|
+ """"
|
|
]
|
|
}
|
|
}
|
|
}
|
|
"""
|
|
)
|
|
return self.ispn.cache_post(self._cache_name, config)
|
|
|
|

    def cache_delete(self) -> requests.Response:
        """Delete the cache for the vector db

        Returns:
            An http Response containing the result of the operation
        """
        return self.ispn.cache_delete(self._cache_name)

    def cache_clear(self) -> requests.Response:
        """Clear the cache for the vector db

        Returns:
            An http Response containing the result of the operation
        """
        return self.ispn.cache_clear(self._cache_name)

    def cache_exists(self) -> bool:
        """Check if the cache exists

        Returns:
            True if the cache exists
        """
        return self.ispn.cache_exists(self._cache_name)

    def cache_index_clear(self) -> requests.Response:
        """Clear the index for the vector db

        Returns:
            An http Response containing the result of the operation
        """
        return self.ispn.index_clear(self._cache_name)

    def cache_index_reindex(self) -> requests.Response:
        """Rebuild the index for the vector db

        Returns:
            An http Response containing the result of the operation
        """
        return self.ispn.index_reindex(self._cache_name)

    def add_texts(
        self,
        texts: Iterable[str],
        metadatas: Optional[List[dict]] = None,
        last_vector: Optional[List[float]] = None,
        **kwargs: Any,
    ) -> List[str]:
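        """Add texts to the vector store.

        Args:
            texts: iterable of texts to embed and store.
            metadatas: optional list of metadata dicts, one per text.
            last_vector: optional precomputed embedding for the last text,
                so it does not have to be embedded again.
        Returns:
            List of the keys under which the entries were stored.

        Example (a minimal sketch; assumes schema and cache are already set
        up and that the schema has a ``title`` field):
            .. code-block:: python

                keys = ispnvs.add_texts(
                    ["first doc", "second doc"],
                    metadatas=[{"title": "one"}, {"title": "two"}],
                )
        """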
        result = []
        texts_l = list(texts)
        if last_vector:
            texts_l.pop()
        embeds = self._embedding.embed_documents(texts_l)  # type: ignore
        if last_vector:
            embeds.append(last_vector)
        # size defaults from embeds, not texts: texts may be a one-shot
        # iterator already consumed by list(texts) above
        if not metadatas:
            metadatas = [{} for _ in embeds]
        ids = self._ids or [str(uuid.uuid4()) for _ in embeds]
        data_input = list(zip(metadatas, embeds, ids))
        for metadata, embed, key in data_input:
            data = {"_type": self._entity_name, self._vectorfield: embed}
            data.update(metadata)
            data_str = json.dumps(data)
            self.ispn.put(key, data_str, self._cache_name)
            result.append(key)
        return result

    def similarity_search(
        self, query: str, k: int = 4, **kwargs: Any
    ) -> List[Document]:
        """Return docs most similar to query."""
        documents = self.similarity_search_with_score(query=query, k=k)
        return [doc for doc, _ in documents]

    def similarity_search_with_score(
        self, query: str, k: int = 4, **kwargs: Any
    ) -> List[Tuple[Document, float]]:
        """Perform a search on a query string and return results with score.

        Args:
            query (str): The text being searched.
            k (int, optional): The amount of results to return. Defaults to 4.

        Returns:
            List[Tuple[Document, float]]
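
        Example (a minimal sketch):
            .. code-block:: python

                results = ispnvs.similarity_search_with_score("my query", k=2)
                for doc, score in results:
                    print(doc.page_content, score)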
"""
|
|
embed = self._embedding.embed_query(query) # type: ignore
|
|
documents = self.similarity_search_with_score_by_vector(embedding=embed, k=k)
|
|
return documents
|
|
|
|

    def similarity_search_by_vector(
        self, embedding: List[float], k: int = 4, **kwargs: Any
    ) -> List[Document]:
        """Return docs most similar to the embedding vector."""
        res = self.similarity_search_with_score_by_vector(embedding, k)
        return [doc for doc, _ in res]

    def similarity_search_with_score_by_vector(
        self, embedding: List[float], k: int = 4
    ) -> List[Tuple[Document, float]]:
        """Return docs most similar to the embedding vector.

        Args:
            embedding: Embedding to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.

        Returns:
            List of pairs (Document, score) most similar to the query vector.
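
        Note:
            The lookup is expressed as an Ickle query over the REST API.
            With the default entity and vector field names and no
            ``output_fields``, the generated query looks roughly like::

                select v, score(v) from vector v where v.vector <-> [...]~4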
"""
|
|
if self._output_fields is None:
|
|
query_str = (
|
|
"select v, score(v) from "
|
|
+ self._entity_name
|
|
+ " v where v."
|
|
+ self._vectorfield
|
|
+ " <-> "
|
|
+ json.dumps(embedding)
|
|
+ "~"
|
|
+ str(k)
|
|
)
|
|
else:
|
|
query_proj = "select "
|
|
for field in self._output_fields[:-1]:
|
|
query_proj = query_proj + "v." + field + ","
|
|
query_proj = query_proj + "v." + self._output_fields[-1]
|
|
query_str = (
|
|
query_proj
|
|
+ ", score(v) from "
|
|
+ self._entity_name
|
|
+ " v where v."
|
|
+ self._vectorfield
|
|
+ " <-> "
|
|
+ json.dumps(embedding)
|
|
+ "~"
|
|
+ str(k)
|
|
)
|
|
query_res = self.ispn.req_query(query_str, self._cache_name)
|
|
result = json.loads(query_res.text)
|
|
return self._query_result_to_docs(result)
|
|
|
|

    def _query_result_to_docs(
        self, result: dict[str, Any]
    ) -> List[Tuple[Document, float]]:
        documents = []
        for row in result["hits"]:
            hit = row["hit"] or {}
            if self._output_fields is None:
                entity = hit["*"]
            else:
                entity = {key: hit.get(key) for key in self._output_fields}
            doc = Document(
                page_content=self._to_content(entity),
                metadata=self._to_metadata(entity),
            )
            documents.append((doc, hit["score()"]))
        return documents

    def configure(self, metadata: dict, dimension: int) -> None:
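        """Create the schema and, if missing, the cache for the vector db.

        Args:
            metadata(dict): sample metadata dict, used as template for the schema
            dimension(int): dimension of the embedding vector

        Example (a minimal sketch; the dimension is taken from a sample
        embedding):
            .. code-block:: python

                vec = embeddings.embed_query("sample text")
                ispnvs.configure({"title": "some title"}, len(vec))
        """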
        schema = self.schema_builder(metadata, dimension)
        output = self.schema_create(schema)
        assert output.ok, (
            "Unable to create schema. Already exists? "
            "Consider using clear_old=True"
        )
        assert json.loads(output.text)["error"] is None
        if not self.cache_exists():
            output = self.cache_create()
            assert output.ok, (
                "Unable to create cache. Already exists? "
                "Consider using clear_old=True"
            )
        # Ensure the index starts clean
        self.cache_index_clear()

    def config_clear(self) -> None:
        self.schema_delete()
        self.cache_delete()

    @classmethod
    def from_texts(
        cls: Type[InfinispanVS],
        texts: List[str],
        embedding: Embeddings,
        metadatas: Optional[List[dict]] = None,
        ids: Optional[List[str]] = None,
        clear_old: Optional[bool] = True,
        auto_config: Optional[bool] = True,
        **kwargs: Any,
    ) -> InfinispanVS:
        """Return VectorStore initialized from texts and embeddings."""
        infinispanvs = cls(embedding=embedding, ids=ids, **kwargs)
        if auto_config and len(metadatas or []) > 0:
            if clear_old:
                infinispanvs.config_clear()
            # embed a sample text to discover the vector dimension
            vec = embedding.embed_query(texts[-1])
            metadatas = cast(List[dict], metadatas)
            infinispanvs.configure(metadatas[0], len(vec))
        else:
            if clear_old:
                infinispanvs.cache_clear()
        if texts:
            # precompute the last vector so add_texts can skip re-embedding it
            vec = embedding.embed_query(texts[-1])
            infinispanvs.add_texts(texts, metadatas, last_vector=vec)
        return infinispanvs


REST_TIMEOUT = 10


class Infinispan:
    """Helper class for `Infinispan` REST interface.

    This class exposes the Infinispan operations needed to
    create and set up a vector db.

    You need a running Infinispan (15+) server without authentication.
    You can easily start one, see:
    https://github.com/rigazilla/infinispan-vector#run-infinispan
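
    Example (a minimal sketch; ``config_json`` stands for any cache
    configuration string, such as the default generated by
    ``InfinispanVS.cache_create``):
        .. code-block:: python

            ispn = Infinispan(hosts=["127.0.0.1:11222"])
            response = ispn.cache_post("vector", config_json)
            assert response.ok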
"""
|
|
|
|

    def __init__(self, **kwargs: Any):
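        """Store the REST endpoint configuration.

        Recognized keyword arguments (all optional):
            schema(str): "http" or "https". Defaults to "http".
            hosts(List[str]): server addresses; only the first entry is used.
                Defaults to ["127.0.0.1:11222"].
            cache_url(str): REST path for caches. Defaults to "/rest/v2/caches".
            schema_url(str): REST path for schemas.
                Defaults to "/rest/v2/schemas".
            use_post_for_query(bool): whether queries are sent via POST rather
                than GET. Defaults to True.
        """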
        self._configuration = kwargs
        self._schema = str(self._configuration.get("schema", "http"))
        self._host = str(self._configuration.get("hosts", ["127.0.0.1:11222"])[0])
        self._default_node = self._schema + "://" + self._host
        self._cache_url = str(self._configuration.get("cache_url", "/rest/v2/caches"))
        self._schema_url = str(
            self._configuration.get("schema_url", "/rest/v2/schemas")
        )
        self._use_post_for_query = bool(
            self._configuration.get("use_post_for_query", True)
        )

    def req_query(
        self, query: str, cache_name: str, local: bool = False
    ) -> requests.Response:
        """Request a query

        Args:
            query(str): query requested
            cache_name(str): name of the target cache
            local(bool): whether the query is restricted to the local node
                instead of the whole cluster
        Returns:
            An http Response containing the result set or errors
        """
        if self._use_post_for_query:
            return self._query_post(query, cache_name, local)
        return self._query_get(query, cache_name, local)

    def _query_post(
        self, query_str: str, cache_name: str, local: bool = False
    ) -> requests.Response:
        api_url = (
            self._default_node
            + self._cache_url
            + "/"
            + cache_name
            + "?action=search&local="
            + str(local)
        )
        data = {"query": query_str}
        data_json = json.dumps(data)
        response = requests.post(
            api_url,
            data_json,
            headers={"Content-Type": "application/json"},
            timeout=REST_TIMEOUT,
        )
        return response

    def _query_get(
        self, query_str: str, cache_name: str, local: bool = False
    ) -> requests.Response:
        api_url = (
            self._default_node
            + self._cache_url
            + "/"
            + cache_name
            + "?action=search&query="
            + query_str
            + "&local="
            + str(local)
        )
        response = requests.get(api_url, timeout=REST_TIMEOUT)
        return response

    def post(self, key: str, data: str, cache_name: str) -> requests.Response:
        """Post an entry

        Args:
            key(str): key of the entry
            data(str): content of the entry in json format
            cache_name(str): target cache
        Returns:
            An http Response containing the result of the operation
        """
        api_url = self._default_node + self._cache_url + "/" + cache_name + "/" + key
        response = requests.post(
            api_url,
            data,
            headers={"Content-Type": "application/json"},
            timeout=REST_TIMEOUT,
        )
        return response

    def put(self, key: str, data: str, cache_name: str) -> requests.Response:
        """Put an entry

        Args:
            key(str): key of the entry
            data(str): content of the entry in json format
            cache_name(str): target cache
        Returns:
            An http Response containing the result of the operation
        """
        api_url = self._default_node + self._cache_url + "/" + cache_name + "/" + key
        response = requests.put(
            api_url,
            data,
            headers={"Content-Type": "application/json"},
            timeout=REST_TIMEOUT,
        )
        return response

    def get(self, key: str, cache_name: str) -> requests.Response:
        """Get an entry

        Args:
            key(str): key of the entry
            cache_name(str): target cache
        Returns:
            An http Response containing the entry or errors
        """
        api_url = self._default_node + self._cache_url + "/" + cache_name + "/" + key
        response = requests.get(
            api_url, headers={"Content-Type": "application/json"}, timeout=REST_TIMEOUT
        )
        return response

    def schema_post(self, name: str, proto: str) -> requests.Response:
        """Deploy a schema

        Args:
            name(str): name of the schema. Will be used as a key
            proto(str): protobuf schema
        Returns:
            An http Response containing the result of the operation
        """
        api_url = self._default_node + self._schema_url + "/" + name
        response = requests.post(api_url, proto, timeout=REST_TIMEOUT)
        return response

    def cache_post(self, name: str, config: str) -> requests.Response:
        """Create a cache

        Args:
            name(str): name of the cache.
            config(str): configuration of the cache.
        Returns:
            An http Response containing the result of the operation
        """
        api_url = self._default_node + self._cache_url + "/" + name
        response = requests.post(
            api_url,
            config,
            headers={"Content-Type": "application/json"},
            timeout=REST_TIMEOUT,
        )
        return response

    def schema_delete(self, name: str) -> requests.Response:
        """Delete a schema

        Args:
            name(str): name of the schema.
        Returns:
            An http Response containing the result of the operation
        """
        api_url = self._default_node + self._schema_url + "/" + name
        response = requests.delete(api_url, timeout=REST_TIMEOUT)
        return response

    def cache_delete(self, name: str) -> requests.Response:
        """Delete a cache

        Args:
            name(str): name of the cache.
        Returns:
            An http Response containing the result of the operation
        """
        api_url = self._default_node + self._cache_url + "/" + name
        response = requests.delete(api_url, timeout=REST_TIMEOUT)
        return response

    def cache_clear(self, cache_name: str) -> requests.Response:
        """Clear a cache

        Args:
            cache_name(str): name of the cache.
        Returns:
            An http Response containing the result of the operation
        """
        api_url = (
            self._default_node + self._cache_url + "/" + cache_name + "?action=clear"
        )
        response = requests.post(api_url, timeout=REST_TIMEOUT)
        return response

    def cache_exists(self, cache_name: str) -> bool:
        """Check if a cache exists

        Args:
            cache_name(str): name of the cache.
        Returns:
            True if the cache exists
        """
        api_url = self._default_node + self._cache_url + "/" + cache_name
        return self.resource_exists(api_url)

    @staticmethod
    def resource_exists(api_url: str) -> bool:
        """Check if a resource exists

        Args:
            api_url(str): url of the resource.
        Returns:
            True if the resource exists
        """
        response = requests.head(api_url, timeout=REST_TIMEOUT)
        return response.ok

    def index_clear(self, cache_name: str) -> requests.Response:
        """Clear an index on a cache

        Args:
            cache_name(str): name of the cache.
        Returns:
            An http Response containing the result of the operation
        """
        api_url = (
            self._default_node
            + self._cache_url
            + "/"
            + cache_name
            + "/search/indexes?action=clear"
        )
        return requests.post(api_url, timeout=REST_TIMEOUT)

    def index_reindex(self, cache_name: str) -> requests.Response:
        """Rebuild index on a cache

        Args:
            cache_name(str): name of the cache.
        Returns:
            An http Response containing the result of the operation
        """
        api_url = (
            self._default_node
            + self._cache_url
            + "/"
            + cache_name
            + "/search/indexes?action=reindex"
        )
        return requests.post(api_url, timeout=REST_TIMEOUT)