Update VectorStore interface to contain from_texts, enforce common in… (#97)

…terface
pull/98/head
Samantha Whitmore 2 years ago committed by GitHub
parent 61f12229df
commit 2ddab88c06
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 2,
"id": "965eecee",
"metadata": {},
"outputs": [],
@ -15,7 +15,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 3,
"id": "68481687",
"metadata": {},
"outputs": [],
@ -30,7 +30,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 4,
"id": "015f4ff5",
"metadata": {},
"outputs": [],
@ -43,7 +43,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 5,
"id": "67baf32e",
"metadata": {},
"outputs": [
@ -69,12 +69,12 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 6,
"id": "4906b8a3",
"metadata": {},
"outputs": [],
"source": [
"docsearch = ElasticVectorSearch.from_texts(\"http://localhost:9200\", texts, embeddings)\n",
"docsearch = ElasticVectorSearch.from_texts(texts, embeddings, elasticsearch_url=\"http://localhost:9200\")\n",
"\n",
"query = \"What did the president say about Ketanji Brown Jackson\"\n",
"docs = docsearch.similarity_search(query)"
@ -82,7 +82,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 7,
"id": "95f9eee9",
"metadata": {},
"outputs": [

@ -1,8 +1,9 @@
"""Interface for vector stores."""
from abc import ABC, abstractmethod
from typing import List
from typing import Any, List
from langchain.docstore.document import Document
from langchain.embeddings.base import Embeddings
class VectorStore(ABC):
@ -11,3 +12,10 @@ class VectorStore(ABC):
@abstractmethod
def similarity_search(self, query: str, k: int = 4) -> List[Document]:
"""Return docs most similar to query."""
@classmethod
@abstractmethod
def from_texts(
cls, texts: List[str], embedding: Embeddings, **kwargs: Any
) -> "VectorStore":
"""Return VectorStore initialized from texts and embeddings."""

@ -1,6 +1,7 @@
"""Wrapper around Elasticsearch vector database."""
import os
import uuid
from typing import Callable, Dict, List
from typing import Any, Callable, Dict, List
from langchain.docstore.document import Document
from langchain.embeddings.base import Embeddings
@ -46,7 +47,7 @@ class ElasticVectorSearch(VectorStore):
def __init__(
self,
elastic_url: str,
elasticsearch_url: str,
index_name: str,
mapping: Dict,
embedding_function: Callable,
@ -62,7 +63,7 @@ class ElasticVectorSearch(VectorStore):
self.embedding_function = embedding_function
self.index_name = index_name
try:
es_client = elasticsearch.Elasticsearch(elastic_url) # noqa
es_client = elasticsearch.Elasticsearch(elasticsearch_url) # noqa
except ValueError as e:
raise ValueError(
"Your elasticsearch client string is misformatted. " f"Got error: {e} "
@ -89,7 +90,7 @@ class ElasticVectorSearch(VectorStore):
@classmethod
def from_texts(
cls, elastic_url: str, texts: List[str], embedding: Embeddings
cls, texts: List[str], embedding: Embeddings, **kwargs: Any
) -> "ElasticVectorSearch":
"""Construct ElasticVectorSearch wrapper from raw documents.
@ -107,11 +108,21 @@ class ElasticVectorSearch(VectorStore):
from langchain.embeddings import OpenAIEmbeddings
embeddings = OpenAIEmbeddings()
elastic_vector_search = ElasticVectorSearch.from_texts(
"http://localhost:9200",
texts,
embeddings
embeddings,
elasticsearch_url="http://localhost:9200"
)
"""
elasticsearch_url = kwargs.get("elasticsearch_url")
if not elasticsearch_url:
elasticsearch_url = os.environ.get("ELASTICSEARCH_URL")
if elasticsearch_url is None or elasticsearch_url == "":
raise ValueError(
"Did not find Elasticsearch URL, please add an environment variable"
" `ELASTICSEARCH_URL` which contains it, or pass"
" `elasticsearch_url` as a named parameter."
)
try:
import elasticsearch
from elasticsearch.helpers import bulk
@ -121,7 +132,7 @@ class ElasticVectorSearch(VectorStore):
"Please install it with `pip install elasticearch`."
)
try:
client = elasticsearch.Elasticsearch(elastic_url)
client = elasticsearch.Elasticsearch(elasticsearch_url)
except ValueError as e:
raise ValueError(
"Your elasticsearch client string is misformatted. " f"Got error: {e} "
@ -144,4 +155,4 @@ class ElasticVectorSearch(VectorStore):
requests.append(request)
bulk(client, requests)
client.indices.refresh(index=index_name)
return cls(elastic_url, index_name, mapping, embedding.embed_query)
return cls(elasticsearch_url, index_name, mapping, embedding.embed_query)

@ -53,7 +53,9 @@ class FAISS(VectorStore):
return docs
@classmethod
def from_texts(cls, texts: List[str], embedding: Embeddings) -> "FAISS":
def from_texts(
cls, texts: List[str], embedding: Embeddings, **kwargs: Any
) -> "FAISS":
"""Construct FAISS wrapper from raw documents.
This is a user friendly interface that:

Loading…
Cancel
Save