mirror of https://github.com/hwchase17/langchain
parent
798deaec2b
commit
76aff023d7
@ -0,0 +1,5 @@
|
||||
:mod:`langchain.embeddings`
|
||||
===========================
|
||||
|
||||
.. automodule:: langchain.embeddings
|
||||
:members:
|
@ -0,0 +1,20 @@
|
||||
"""Simple in memory docstore in the form of a dict."""
|
||||
from typing import Dict, Union
|
||||
|
||||
from langchain.docstore.base import Docstore
|
||||
from langchain.docstore.document import Document
|
||||
|
||||
|
||||
class InMemoryDocstore(Docstore):
    """Docstore backed by a plain in-memory dictionary."""

    def __init__(self, _dict: Dict[str, Document]):
        """Store the mapping of document ids to documents.

        Args:
            _dict: Mapping from id to ``Document``. It is kept by reference,
                not copied.
        """
        self._dict = _dict

    def search(self, search: str) -> Union[str, Document]:
        """Look up a document by its exact id.

        Args:
            search: The id to look up.

        Returns:
            The stored ``Document`` when the id is present, otherwise a
            human-readable "not found" message string.
        """
        # Guard clause: hit path first, miss path falls through.
        if search in self._dict:
            return self._dict[search]
        return f"ID {search} not found."
|
@ -0,0 +1,4 @@
|
||||
"""Wrappers around embedding modules."""
|
||||
from langchain.embeddings.openai import OpenAIEmbeddings
|
||||
|
||||
# Names re-exported as the public API of this package.
__all__ = ["OpenAIEmbeddings"]
|
@ -0,0 +1,15 @@
|
||||
"""Interface for embedding models."""
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List
|
||||
|
||||
|
||||
class Embeddings(ABC):
    """Abstract interface that every embedding model wrapper implements.

    Subclasses must provide both ``embed_documents`` and ``embed_query``;
    the base class cannot be instantiated directly.
    """

    @abstractmethod
    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a batch of documents that will later be searched over.

        Args:
            texts: The document texts to embed.

        Returns:
            One embedding vector per input text.
        """

    @abstractmethod
    def embed_query(self, text: str) -> List[float]:
        """Embed a single piece of query text.

        Args:
            text: The query text to embed.

        Returns:
            The embedding vector for the query.
        """
|
@ -0,0 +1,84 @@
|
||||
"""Wrapper around OpenAI embedding models."""
|
||||
import os
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from pydantic import BaseModel, Extra, root_validator
|
||||
|
||||
from langchain.embeddings.base import Embeddings
|
||||
|
||||
|
||||
class OpenAIEmbeddings(BaseModel, Embeddings):
    """Wrapper around OpenAI embedding models.

    To use, you should have the ``openai`` python package installed, and the
    environment variable ``OPENAI_API_KEY`` set with your API key.

    Example:
        .. code-block:: python

            from langchain.embeddings import OpenAIEmbeddings
            openai = OpenAIEmbeddings(model_name="davinci")
    """

    client: Any  #: :meta private:
    model_name: str = "babbage"
    """Model name to use."""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key and python package exists in environment.

        Args:
            values: The field values collected by pydantic.

        Returns:
            The same ``values`` dict with ``client`` set to ``openai.Embedding``.

        Raises:
            ValueError: If ``OPENAI_API_KEY`` is not set in the environment or
                the ``openai`` package cannot be imported.
        """
        if "OPENAI_API_KEY" not in os.environ:
            raise ValueError(
                "Did not find OpenAI API key, please add an environment variable"
                " `OPENAI_API_KEY` which contains it."
            )
        # Keep the try body limited to the import so unrelated errors are not
        # mistaken for a missing package.
        try:
            import openai
        except ImportError as err:
            raise ValueError(
                "Could not import openai python package. "
                "Please install it with `pip install openai`."
            ) from err
        values["client"] = openai.Embedding
        return values

    def _embedding_func(self, text: str, *, engine: str) -> List[float]:
        """Call out to OpenAI's embedding endpoint.

        Args:
            text: The text to embed.
            engine: Name of the OpenAI engine to send the request to.

        Returns:
            The embedding taken from the first entry of the response data.
        """
        # replace newlines, which can negatively affect performance.
        text = text.replace("\n", " ")
        return self.client.create(input=[text], engine=engine)["data"][0]["embedding"]

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Call out to OpenAI's embedding endpoint for embedding search docs.

        Each text is embedded with its own request to the
        ``text-search-<model_name>-doc-001`` engine.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text.
        """
        engine = f"text-search-{self.model_name}-doc-001"
        return [self._embedding_func(text, engine=engine) for text in texts]

    def embed_query(self, text: str) -> List[float]:
        """Call out to OpenAI's embedding endpoint for embedding query text.

        The request goes to the ``text-search-<model_name>-query-001`` engine.

        Args:
            text: The text to embed.

        Returns:
            Embeddings for the text.
        """
        return self._embedding_func(
            text, engine=f"text-search-{self.model_name}-query-001"
        )
|
@ -0,0 +1,86 @@
|
||||
"""Wrapper around FAISS vector database."""
|
||||
from typing import Any, Callable, List
|
||||
|
||||
import numpy as np
|
||||
|
||||
from langchain.docstore.base import Docstore
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.docstore.in_memory import InMemoryDocstore
|
||||
from langchain.embeddings.base import Embeddings
|
||||
|
||||
|
||||
class FAISS:
    """Wrapper around FAISS vector database.

    To use, you should have the ``faiss`` python package installed.

    Example:
        .. code-block:: python

            from langchain import FAISS
            faiss = FAISS(embedding_function, index, docstore)

    """

    def __init__(self, embedding_function: Callable, index: Any, docstore: Docstore):
        """Initialize with necessary components.

        Args:
            embedding_function: Callable that turns a query string into an
                embedding vector.
            index: The index object that is searched; it must provide a
                ``search(vectors, k)`` method.
            docstore: Docstore queried with the stringified index positions.
        """
        self.embedding_function = embedding_function
        self.index = index
        self.docstore = docstore

    def similarity_search(self, query: str, k: int = 4) -> List[Document]:
        """Return docs most similar to query.

        Args:
            query: Text to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.

        Returns:
            List of Documents most similar to the query.

        Raises:
            ValueError: If the docstore does not return a ``Document`` for an
                id produced by the index.
        """
        embedding = self.embedding_function(query)
        # The index is searched with a batch of one float32 query vector.
        _, indices = self.index.search(np.array([embedding], dtype=np.float32), k)
        docs = []
        for i in indices[0]:
            if i == -1:
                # This happens when not enough docs are returned.
                continue
            doc = self.docstore.search(str(i))
            if not isinstance(doc, Document):
                raise ValueError(f"Could not find document for id {i}, got {doc}")
            docs.append(doc)
        return docs

    @classmethod
    def from_texts(cls, texts: List[str], embedding: Embeddings) -> "FAISS":
        """Construct FAISS wrapper from raw documents.

        This is a user friendly interface that:
            1. Embeds documents.
            2. Creates an in memory docstore
            3. Initializes the FAISS database

        This is intended to be a quick way to get started.

        Example:
            .. code-block:: python

                from langchain import FAISS
                from langchain.embeddings import OpenAIEmbeddings
                embeddings = OpenAIEmbeddings()
                faiss = FAISS.from_texts(texts, embeddings)

        Args:
            texts: The raw texts to embed and index. Must not be empty.
            embedding: Embedding model used for both documents and queries.

        Returns:
            A ``FAISS`` wrapper whose docstore keys are the stringified
            positions of ``texts``.

        Raises:
            ValueError: If the ``faiss`` package cannot be imported, or if
                ``texts`` is empty.
        """
        try:
            import faiss
        except ImportError as err:
            raise ValueError(
                "Could not import faiss python package. "
                "Please install it with `pip install faiss` "
                "or `pip install faiss-cpu` (depending on Python version)."
            ) from err
        # The index dimension is read from the first embedding, so an empty
        # input would otherwise fail with an opaque IndexError after the
        # embedder has already been called.
        if not texts:
            raise ValueError("Cannot build a FAISS index from an empty list of texts.")
        embeddings = embedding.embed_documents(texts)
        index = faiss.IndexFlatL2(len(embeddings[0]))
        index.add(np.array(embeddings, dtype=np.float32))
        documents = [Document(page_content=text) for text in texts]
        docstore = InMemoryDocstore({str(i): doc for i, doc in enumerate(documents)})
        return cls(embedding.embed_query, index, docstore)
|
@ -1,3 +1,4 @@
|
||||
-e .
|
||||
# For testing
|
||||
pytest
|
||||
pytest-dotenv
|
||||
|
@ -0,0 +1 @@
|
||||
"""Test embedding integrations."""
|
@ -0,0 +1,19 @@
|
||||
"""Test openai embeddings."""
|
||||
from langchain.embeddings.openai import OpenAIEmbeddings
|
||||
|
||||
|
||||
def test_openai_embedding_documents() -> None:
    """Embedding one document yields a single 2048-dimensional vector."""
    texts = ["foo bar"]
    vectors = OpenAIEmbeddings().embed_documents(texts)
    assert len(vectors) == 1
    assert len(vectors[0]) == 2048
|
||||
|
||||
|
||||
def test_openai_embedding_query() -> None:
    """Embedding a query string yields a 2048-dimensional vector."""
    query = "foo bar"
    vector = OpenAIEmbeddings().embed_query(query)
    assert len(vector) == 2048
|
@ -0,0 +1,47 @@
|
||||
"""Test FAISS functionality."""
|
||||
from typing import List
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.docstore.in_memory import InMemoryDocstore
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.faiss import FAISS
|
||||
|
||||
|
||||
class FakeEmbeddings(Embeddings):
    """Deterministic stand-in for an embedding model, used by the tests."""

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Return one 10-element vector per text, filled with the text's position."""
        return [[position] * 10 for position, _ in enumerate(texts)]

    def embed_query(self, text: str) -> List[float]:
        """Return a 10-element vector of zeros, whatever the query is."""
        return [0] * 10
|
||||
|
||||
|
||||
def test_faiss() -> None:
    """Build a store from texts and check the docstore and a top-1 search."""
    texts = ["foo", "bar", "baz"]
    docsearch = FAISS.from_texts(texts, FakeEmbeddings())

    # Docstore keys are expected to be the stringified positions of the texts.
    expected_mapping = {
        str(position): Document(page_content=text)
        for position, text in enumerate(texts)
    }
    expected_docstore = InMemoryDocstore(expected_mapping)
    assert docsearch.docstore.__dict__ == expected_docstore.__dict__

    output = docsearch.similarity_search("foo", k=1)
    assert output == [Document(page_content="foo")]
|
||||
|
||||
|
||||
def test_faiss_search_not_found() -> None:
    """A search must raise ValueError when the docstore lacks the indexed ids."""
    docsearch = FAISS.from_texts(["foo", "bar", "baz"], FakeEmbeddings())
    # Swap in an empty docstore so every id returned by the index is missing.
    docsearch.docstore = InMemoryDocstore({})
    with pytest.raises(ValueError):
        docsearch.similarity_search("foo")
|
@ -0,0 +1,21 @@
|
||||
"""Test in memory docstore."""
|
||||
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.docstore.in_memory import InMemoryDocstore
|
||||
|
||||
|
||||
def test_document_found() -> None:
    """Searching for a stored id returns the matching Document."""
    docstore = InMemoryDocstore({"foo": Document(page_content="bar")})
    result = docstore.search("foo")
    assert isinstance(result, Document)
    assert result.page_content == "bar"
|
||||
|
||||
|
||||
def test_document_not_found() -> None:
    """Searching for an unknown id returns the "not found" message string."""
    docstore = InMemoryDocstore({"foo": Document(page_content="bar")})
    result = docstore.search("bar")
    assert result == "ID bar not found."
|
Loading…
Reference in New Issue