mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
Check for Tiktoken (#7705)
This commit is contained in:
parent
bae93682f6
commit
fcf98dc4c1
@ -3,7 +3,7 @@ from enum import Enum
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import numpy as np
|
||||
from pydantic import Field
|
||||
from pydantic import Field, root_validator
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForChainRun,
|
||||
@ -48,6 +48,29 @@ class _EmbeddingDistanceChainMixin(Chain):
|
||||
embeddings: Embeddings = Field(default_factory=OpenAIEmbeddings)
|
||||
distance_metric: EmbeddingDistance = Field(default=EmbeddingDistance.COSINE)
|
||||
|
||||
@root_validator(pre=False)
|
||||
def _validate_tiktoken_installed(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Validate that the TikTok library is installed.
|
||||
|
||||
Args:
|
||||
values (Dict[str, Any]): The values to validate.
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: The validated values.
|
||||
"""
|
||||
embeddings = values.get("embeddings")
|
||||
if isinstance(embeddings, OpenAIEmbeddings):
|
||||
try:
|
||||
import tiktoken # noqa: F401
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"The tiktoken library is required to use the default "
|
||||
"OpenAI embeddings with embedding distance evaluators."
|
||||
" Please either manually select a different Embeddings object"
|
||||
" or install tiktoken using `pip install tiktoken`."
|
||||
)
|
||||
return values
|
||||
|
||||
class Config:
|
||||
"""Permit embeddings to go unvalidated."""
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user