Check for Tiktoken (#7705)

This commit is contained in:
William FH 2023-07-14 09:49:01 -07:00 committed by GitHub
parent bae93682f6
commit fcf98dc4c1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -3,7 +3,7 @@ from enum import Enum
from typing import Any, Dict, List, Optional
import numpy as np
from pydantic import Field
from pydantic import Field, root_validator
from langchain.callbacks.manager import (
AsyncCallbackManagerForChainRun,
@ -48,6 +48,29 @@ class _EmbeddingDistanceChainMixin(Chain):
embeddings: Embeddings = Field(default_factory=OpenAIEmbeddings)
distance_metric: EmbeddingDistance = Field(default=EmbeddingDistance.COSINE)
@root_validator(pre=False)
def _validate_tiktoken_installed(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Validate that the TikTok library is installed.
Args:
values (Dict[str, Any]): The values to validate.
Returns:
Dict[str, Any]: The validated values.
"""
embeddings = values.get("embeddings")
if isinstance(embeddings, OpenAIEmbeddings):
try:
import tiktoken # noqa: F401
except ImportError:
raise ImportError(
"The tiktoken library is required to use the default "
"OpenAI embeddings with embedding distance evaluators."
" Please either manually select a different Embeddings object"
" or install tiktoken using `pip install tiktoken`."
)
return values
class Config:
"""Permit embeddings to go unvalidated."""