|
|
@ -3,7 +3,7 @@ from enum import Enum
|
|
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
import numpy as np
|
|
|
|
from pydantic import Field
|
|
|
|
from pydantic import Field, root_validator
|
|
|
|
|
|
|
|
|
|
|
|
from langchain.callbacks.manager import (
|
|
|
|
from langchain.callbacks.manager import (
|
|
|
|
AsyncCallbackManagerForChainRun,
|
|
|
|
AsyncCallbackManagerForChainRun,
|
|
|
@ -48,6 +48,29 @@ class _EmbeddingDistanceChainMixin(Chain):
|
|
|
|
embeddings: Embeddings = Field(default_factory=OpenAIEmbeddings)
|
|
|
|
embeddings: Embeddings = Field(default_factory=OpenAIEmbeddings)
|
|
|
|
distance_metric: EmbeddingDistance = Field(default=EmbeddingDistance.COSINE)
|
|
|
|
distance_metric: EmbeddingDistance = Field(default=EmbeddingDistance.COSINE)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@root_validator(pre=False)
|
|
|
|
|
|
|
|
def _validate_tiktoken_installed(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
|
|
|
|
|
|
|
"""Validate that the TikTok library is installed.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
|
|
|
values (Dict[str, Any]): The values to validate.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
|
|
|
Dict[str, Any]: The validated values.
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
embeddings = values.get("embeddings")
|
|
|
|
|
|
|
|
if isinstance(embeddings, OpenAIEmbeddings):
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
|
|
|
import tiktoken # noqa: F401
|
|
|
|
|
|
|
|
except ImportError:
|
|
|
|
|
|
|
|
raise ImportError(
|
|
|
|
|
|
|
|
"The tiktoken library is required to use the default "
|
|
|
|
|
|
|
|
"OpenAI embeddings with embedding distance evaluators."
|
|
|
|
|
|
|
|
" Please either manually select a different Embeddings object"
|
|
|
|
|
|
|
|
" or install tiktoken using `pip install tiktoken`."
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
return values
|
|
|
|
|
|
|
|
|
|
|
|
class Config:
|
|
|
|
class Config:
|
|
|
|
"""Permit embeddings to go unvalidated."""
|
|
|
|
"""Permit embeddings to go unvalidated."""
|
|
|
|
|
|
|
|
|
|
|
|