diff --git a/langchain/document_transformers/embeddings_redundant_filter.py b/langchain/document_transformers/embeddings_redundant_filter.py index e0af22ab29..ba57f2ef72 100644 --- a/langchain/document_transformers/embeddings_redundant_filter.py +++ b/langchain/document_transformers/embeddings_redundant_filter.py @@ -5,8 +5,8 @@ import numpy as np from pydantic import BaseModel, Field from langchain.embeddings.base import Embeddings -from langchain.math_utils import cosine_similarity from langchain.schema import BaseDocumentTransformer, Document +from langchain.utils.math import cosine_similarity class _DocumentWithState(Document): diff --git a/langchain/evaluation/embedding_distance/base.py b/langchain/evaluation/embedding_distance/base.py index fc7ba51e87..3591f45d8b 100644 --- a/langchain/evaluation/embedding_distance/base.py +++ b/langchain/evaluation/embedding_distance/base.py @@ -14,8 +14,8 @@ from langchain.chains.base import Chain from langchain.embeddings.base import Embeddings from langchain.embeddings.openai import OpenAIEmbeddings from langchain.evaluation.schema import PairwiseStringEvaluator, StringEvaluator -from langchain.math_utils import cosine_similarity from langchain.schema import RUN_KEY +from langchain.utils.math import cosine_similarity class EmbeddingDistance(str, Enum): diff --git a/langchain/retrievers/document_compressors/embeddings_filter.py b/langchain/retrievers/document_compressors/embeddings_filter.py index 589ff46569..fb49ca02ea 100644 --- a/langchain/retrievers/document_compressors/embeddings_filter.py +++ b/langchain/retrievers/document_compressors/embeddings_filter.py @@ -9,11 +9,11 @@ from langchain.document_transformers.embeddings_redundant_filter import ( get_stateful_documents, ) from langchain.embeddings.base import Embeddings -from langchain.math_utils import cosine_similarity from langchain.retrievers.document_compressors.base import ( BaseDocumentCompressor, ) from langchain.schema import Document +from 
langchain.utils.math import cosine_similarity class EmbeddingsFilter(BaseDocumentCompressor): diff --git a/langchain/sql_database.py b/langchain/sql_database.py index 19d9f81165..43a43e86fd 100644 --- a/langchain/sql_database.py +++ b/langchain/sql_database.py @@ -10,7 +10,7 @@ from sqlalchemy.engine import Engine from sqlalchemy.exc import ProgrammingError, SQLAlchemyError from sqlalchemy.schema import CreateTable -from langchain import utils +from langchain.utils import get_from_env def _format_index(index: sqlalchemy.engine.interfaces.ReflectedIndex) -> str: @@ -192,13 +192,11 @@ class SQLDatabase: default_host = context.browserHostName if context else None if host is None: - host = utils.get_from_env("host", "DATABRICKS_HOST", default_host) + host = get_from_env("host", "DATABRICKS_HOST", default_host) default_api_token = context.apiToken if context else None if api_token is None: - api_token = utils.get_from_env( - "api_token", "DATABRICKS_TOKEN", default_api_token - ) + api_token = get_from_env("api_token", "DATABRICKS_TOKEN", default_api_token) if warehouse_id is None and cluster_id is None: if context: diff --git a/langchain/utils/__init__.py b/langchain/utils/__init__.py new file mode 100644 index 0000000000..e3db0ddcac --- /dev/null +++ b/langchain/utils/__init__.py @@ -0,0 +1,33 @@ +""" +Utility functions for langchain. + +These functions do not depend on any other langchain modules. 
+""" + +from langchain.utils.env import get_from_dict_or_env, get_from_env +from langchain.utils.math import cosine_similarity, cosine_similarity_top_k +from langchain.utils.strings import comma_list, stringify_dict, stringify_value +from langchain.utils.utils import ( + check_package_version, + get_pydantic_field_names, + guard_import, + mock_now, + raise_for_status_with_text, + xor_args, +) + +__all__ = [ + "check_package_version", + "comma_list", + "cosine_similarity", + "cosine_similarity_top_k", + "get_from_dict_or_env", + "get_from_env", + "get_pydantic_field_names", + "guard_import", + "mock_now", + "raise_for_status_with_text", + "stringify_dict", + "stringify_value", + "xor_args", +] diff --git a/langchain/utils/env.py b/langchain/utils/env.py new file mode 100644 index 0000000000..f9ac5aba92 --- /dev/null +++ b/langchain/utils/env.py @@ -0,0 +1,26 @@ +import os +from typing import Any, Dict, Optional + + +def get_from_dict_or_env( + data: Dict[str, Any], key: str, env_key: str, default: Optional[str] = None +) -> str: + """Get a value from a dictionary or an environment variable.""" + if key in data and data[key]: + return data[key] + else: + return get_from_env(key, env_key, default=default) + + +def get_from_env(key: str, env_key: str, default: Optional[str] = None) -> str: + """Get a value from an environment variable.""" + if env_key in os.environ and os.environ[env_key]: + return os.environ[env_key] + elif default is not None: + return default + else: + raise ValueError( + f"Did not find {key}, please add an environment variable" + f" `{env_key}` which contains it, or pass" + f" `{key}` as a named parameter." 
+ ) diff --git a/langchain/math_utils.py b/langchain/utils/math.py similarity index 100% rename from langchain/math_utils.py rename to langchain/utils/math.py diff --git a/langchain/utils/strings.py b/langchain/utils/strings.py new file mode 100644 index 0000000000..24741257ac --- /dev/null +++ b/langchain/utils/strings.py @@ -0,0 +1,39 @@ +from typing import Any, List + + +def stringify_value(val: Any) -> str: + """Stringify a value. + + Args: + val: The value to stringify. + + Returns: + str: The stringified value. + """ + if isinstance(val, str): + return val + elif isinstance(val, dict): + return "\n" + stringify_dict(val) + elif isinstance(val, list): + return "\n".join(stringify_value(v) for v in val) + else: + return str(val) + + +def stringify_dict(data: dict) -> str: + """Stringify a dictionary. + + Args: + data: The dictionary to stringify. + + Returns: + str: The stringified dictionary. + """ + text = "" + for key, value in data.items(): + text += key + ": " + stringify_value(value) + "\n" + return text + + +def comma_list(items: List[Any]) -> str: + return ", ".join(str(item) for item in items) diff --git a/langchain/utils.py b/langchain/utils/utils.py similarity index 72% rename from langchain/utils.py rename to langchain/utils/utils.py index f35a6f3aac..a9390d6a66 100644 --- a/langchain/utils.py +++ b/langchain/utils/utils.py @@ -2,38 +2,13 @@ import contextlib import datetime import importlib -import os from importlib.metadata import version -from typing import Any, Callable, Dict, List, Optional, Set, Tuple +from typing import Any, Callable, Optional, Set, Tuple from packaging.version import parse from requests import HTTPError, Response -def get_from_dict_or_env( - data: Dict[str, Any], key: str, env_key: str, default: Optional[str] = None -) -> str: - """Get a value from a dictionary or an environment variable.""" - if key in data and data[key]: - return data[key] - else: - return get_from_env(key, env_key, default=default) - - -def 
get_from_env(key: str, env_key: str, default: Optional[str] = None) -> str: - """Get a value from a dictionary or an environment variable.""" - if env_key in os.environ and os.environ[env_key]: - return os.environ[env_key] - elif default is not None: - return default - else: - raise ValueError( - f"Did not find {key}, please add an environment variable" - f" `{env_key}` which contains it, or pass" - f" `{key}` as a named parameter." - ) - - def xor_args(*arg_groups: Tuple[str, ...]) -> Callable: """Validate specified keyword args are mutually exclusive.""" @@ -67,44 +42,6 @@ def raise_for_status_with_text(response: Response) -> None: raise ValueError(response.text) from e -def stringify_value(val: Any) -> str: - """Stringify a value. - - Args: - val: The value to stringify. - - Returns: - str: The stringified value. - """ - if isinstance(val, str): - return val - elif isinstance(val, dict): - return "\n" + stringify_dict(val) - elif isinstance(val, list): - return "\n".join(stringify_value(v) for v in val) - else: - return str(val) - - -def stringify_dict(data: dict) -> str: - """Stringify a dictionary. - - Args: - data: The dictionary to stringify. - - Returns: - str: The stringified dictionary. - """ - text = "" - for key, value in data.items(): - text += key + ": " + stringify_value(value) + "\n" - return text - - -def comma_list(items: List[Any]) -> str: - return ", ".join(str(item) for item in items) - - @contextlib.contextmanager def mock_now(dt_value): # type: ignore """Context manager for mocking out datetime.now() in unit tests. 
diff --git a/langchain/vectorstores/utils.py b/langchain/vectorstores/utils.py index a44560e391..f7a64389e1 100644 --- a/langchain/vectorstores/utils.py +++ b/langchain/vectorstores/utils.py @@ -5,7 +5,7 @@ from typing import List import numpy as np -from langchain.math_utils import cosine_similarity +from langchain.utils.math import cosine_similarity class DistanceStrategy(str, Enum): diff --git a/tests/unit_tests/test_document_transformers.py b/tests/unit_tests/test_document_transformers.py index 0d4d1014fd..f589055999 100644 --- a/tests/unit_tests/test_document_transformers.py +++ b/tests/unit_tests/test_document_transformers.py @@ -2,7 +2,7 @@ from langchain.document_transformers.embeddings_redundant_filter import ( _filter_similar_embeddings, ) -from langchain.math_utils import cosine_similarity +from langchain.utils.math import cosine_similarity def test__filter_similar_embeddings() -> None: diff --git a/tests/unit_tests/test_math_utils.py b/tests/unit_tests/test_math_utils.py index 6b9126fe7b..a64d52a73a 100644 --- a/tests/unit_tests/test_math_utils.py +++ b/tests/unit_tests/test_math_utils.py @@ -4,7 +4,7 @@ from typing import List import numpy as np import pytest -from langchain.math_utils import cosine_similarity, cosine_similarity_top_k +from langchain.utils.math import cosine_similarity, cosine_similarity_top_k @pytest.fixture