Check for Tiktoken (#7705)

pull/7719/head
William FH 1 year ago committed by GitHub
parent bae93682f6
commit fcf98dc4c1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -3,7 +3,7 @@ from enum import Enum
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
import numpy as np import numpy as np
from pydantic import Field from pydantic import Field, root_validator
from langchain.callbacks.manager import ( from langchain.callbacks.manager import (
AsyncCallbackManagerForChainRun, AsyncCallbackManagerForChainRun,
@ -48,6 +48,29 @@ class _EmbeddingDistanceChainMixin(Chain):
embeddings: Embeddings = Field(default_factory=OpenAIEmbeddings) embeddings: Embeddings = Field(default_factory=OpenAIEmbeddings)
distance_metric: EmbeddingDistance = Field(default=EmbeddingDistance.COSINE) distance_metric: EmbeddingDistance = Field(default=EmbeddingDistance.COSINE)
@root_validator(pre=False)
def _validate_tiktoken_installed(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Validate that the TikTok library is installed.
Args:
values (Dict[str, Any]): The values to validate.
Returns:
Dict[str, Any]: The validated values.
"""
embeddings = values.get("embeddings")
if isinstance(embeddings, OpenAIEmbeddings):
try:
import tiktoken # noqa: F401
except ImportError:
raise ImportError(
"The tiktoken library is required to use the default "
"OpenAI embeddings with embedding distance evaluators."
" Please either manually select a different Embeddings object"
" or install tiktoken using `pip install tiktoken`."
)
return values
class Config: class Config:
"""Permit embeddings to go unvalidated.""" """Permit embeddings to go unvalidated."""

Loading…
Cancel
Save