From fcf98dc4c12095fc97e4ad0599b1ae7ea0b1191b Mon Sep 17 00:00:00 2001 From: William FH <13333726+hinthornw@users.noreply.github.com> Date: Fri, 14 Jul 2023 09:49:01 -0700 Subject: [PATCH] Check for Tiktoken (#7705) --- .../evaluation/embedding_distance/base.py | 25 ++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) diff --git a/langchain/evaluation/embedding_distance/base.py b/langchain/evaluation/embedding_distance/base.py index f25502809a..68c77b3c10 100644 --- a/langchain/evaluation/embedding_distance/base.py +++ b/langchain/evaluation/embedding_distance/base.py @@ -3,7 +3,7 @@ from enum import Enum from typing import Any, Dict, List, Optional import numpy as np -from pydantic import Field +from pydantic import Field, root_validator from langchain.callbacks.manager import ( AsyncCallbackManagerForChainRun, @@ -48,6 +48,29 @@ class _EmbeddingDistanceChainMixin(Chain): embeddings: Embeddings = Field(default_factory=OpenAIEmbeddings) distance_metric: EmbeddingDistance = Field(default=EmbeddingDistance.COSINE) + @root_validator(pre=False) + def _validate_tiktoken_installed(cls, values: Dict[str, Any]) -> Dict[str, Any]: + """Validate that the tiktoken library is installed. + + Args: + values (Dict[str, Any]): The values to validate. + + Returns: + Dict[str, Any]: The validated values. + """ + embeddings = values.get("embeddings") + if isinstance(embeddings, OpenAIEmbeddings): + try: + import tiktoken # noqa: F401 + except ImportError: + raise ImportError( + "The tiktoken library is required to use the default " + "OpenAI embeddings with embedding distance evaluators." + " Please either manually select a different Embeddings object" + " or install tiktoken using `pip install tiktoken`." + ) + return values + class Config: """Permit embeddings to go unvalidated."""