Refactored math_utils (#7961)

`math_utils.py` is in the root code folder. This creates the
`langchain.math_utils: Math Utils` group on the API Reference navigation
ToC, on the same level as `Chains` and `Agents`, which is not correct.

Refactoring:
- created the `utils/` folder
- moved `math_utils.py` to `utils/math.py`
- moved `utils.py` to `utils/utils.py`
- split `utils.py` into `utils.py`, `env.py`, and `strings.py`
- added module description

@baskaryan
This commit is contained in:
Leonid Ganeline 2023-07-20 18:55:43 -07:00 committed by GitHub
parent 5137f40dd6
commit 995220b797
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 108 additions and 75 deletions

View File

@ -5,8 +5,8 @@ import numpy as np
from pydantic import BaseModel, Field
from langchain.embeddings.base import Embeddings
from langchain.math_utils import cosine_similarity
from langchain.schema import BaseDocumentTransformer, Document
from langchain.utils.math import cosine_similarity
class _DocumentWithState(Document):

View File

@ -14,8 +14,8 @@ from langchain.chains.base import Chain
from langchain.embeddings.base import Embeddings
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.evaluation.schema import PairwiseStringEvaluator, StringEvaluator
from langchain.math_utils import cosine_similarity
from langchain.schema import RUN_KEY
from langchain.utils.math import cosine_similarity
class EmbeddingDistance(str, Enum):

View File

@ -9,11 +9,11 @@ from langchain.document_transformers.embeddings_redundant_filter import (
get_stateful_documents,
)
from langchain.embeddings.base import Embeddings
from langchain.math_utils import cosine_similarity
from langchain.retrievers.document_compressors.base import (
BaseDocumentCompressor,
)
from langchain.schema import Document
from langchain.utils.math import cosine_similarity
class EmbeddingsFilter(BaseDocumentCompressor):

View File

@ -10,7 +10,7 @@ from sqlalchemy.engine import Engine
from sqlalchemy.exc import ProgrammingError, SQLAlchemyError
from sqlalchemy.schema import CreateTable
from langchain import utils
from langchain.utils import get_from_env
def _format_index(index: sqlalchemy.engine.interfaces.ReflectedIndex) -> str:
@ -192,13 +192,11 @@ class SQLDatabase:
default_host = context.browserHostName if context else None
if host is None:
host = utils.get_from_env("host", "DATABRICKS_HOST", default_host)
host = get_from_env("host", "DATABRICKS_HOST", default_host)
default_api_token = context.apiToken if context else None
if api_token is None:
api_token = utils.get_from_env(
"api_token", "DATABRICKS_TOKEN", default_api_token
)
api_token = get_from_env("api_token", "DATABRICKS_TOKEN", default_api_token)
if warehouse_id is None and cluster_id is None:
if context:

View File

@ -0,0 +1,33 @@
"""
Utility functions for langchain.
These functions do not depend on any other langchain modules.
"""
from langchain.utils.env import get_from_dict_or_env, get_from_env
from langchain.utils.math import cosine_similarity, cosine_similarity_top_k
from langchain.utils.strings import comma_list, stringify_dict, stringify_value
from langchain.utils.utils import (
check_package_version,
get_pydantic_field_names,
guard_import,
mock_now,
raise_for_status_with_text,
xor_args,
)
# Public API of the `langchain.utils` package: every name listed here is
# imported above from one of the env / math / strings / utils submodules.
__all__ = [
"check_package_version",
"comma_list",
"cosine_similarity",
"cosine_similarity_top_k",
"get_from_dict_or_env",
"get_from_env",
"get_pydantic_field_names",
"guard_import",
"mock_now",
"raise_for_status_with_text",
"stringify_dict",
"stringify_value",
"xor_args",
]

26
langchain/utils/env.py Normal file
View File

@ -0,0 +1,26 @@
import os
from typing import Any, Dict, Optional
def get_from_dict_or_env(
    data: Dict[str, Any], key: str, env_key: str, default: Optional[str] = None
) -> str:
    """Look up ``key`` in ``data``, falling back to the environment.

    Args:
        data: Mapping that is consulted first.
        key: Key to look up in ``data``; also used in the error message
            produced by ``get_from_env``.
        env_key: Name of the environment variable used as a fallback.
        default: Value handed to ``get_from_env`` as its default.

    Returns:
        ``data[key]`` when it is present and truthy, otherwise whatever
        ``get_from_env(key, env_key, default=default)`` returns.
    """
    # A missing key and a falsy value (None, "", ...) are treated the same:
    # both defer to the environment lookup.
    if key not in data or not data[key]:
        return get_from_env(key, env_key, default=default)
    return data[key]
def get_from_env(key: str, env_key: str, default: Optional[str] = None) -> str:
    """Read a value from the environment variable ``env_key``.

    Args:
        key: Human-readable parameter name, used only in the error message.
        env_key: Name of the environment variable to read.
        default: Returned when the variable is unset or empty.

    Returns:
        The variable's value if it is set and non-empty, else ``default``.

    Raises:
        ValueError: If the variable is unset or empty and ``default`` is None.
    """
    env_value = os.environ.get(env_key)
    # An empty string counts as "not set", so it falls through to the default.
    if env_value:
        return env_value
    if default is not None:
        return default
    raise ValueError(
        f"Did not find {key}, please add an environment variable"
        f" `{env_key}` which contains it, or pass"
        f" `{key}` as a named parameter."
    )

View File

@ -0,0 +1,39 @@
from typing import Any, List
def stringify_value(val: Any) -> str:
    """Render a value as text.

    Strings are returned unchanged, dicts are rendered by
    ``stringify_dict`` behind a leading newline, lists have each element
    rendered recursively and joined with newlines, and anything else
    goes through ``str()``.

    Args:
        val: The value to stringify.

    Returns:
        str: The stringified value.
    """
    if isinstance(val, str):
        return val
    if isinstance(val, dict):
        return "\n" + stringify_dict(val)
    if isinstance(val, list):
        return "\n".join(map(stringify_value, val))
    return str(val)
def stringify_dict(data: dict) -> str:
    """Render a dictionary as ``key: value`` lines.

    Each entry becomes ``key + ": " + stringify_value(value)`` followed by
    a newline, in the dictionary's iteration order.

    Args:
        data: The dictionary to stringify.

    Returns:
        str: The stringified dictionary (empty string for an empty dict).
    """
    rendered = [
        key + ": " + stringify_value(value) + "\n" for key, value in data.items()
    ]
    return "".join(rendered)
def comma_list(items: List[Any]) -> str:
    """Join the ``str()`` form of every item with ``", "``.

    Args:
        items: The values to list.

    Returns:
        str: The comma-separated text (empty string for no items).
    """
    return ", ".join(map(str, items))

View File

@ -2,38 +2,13 @@
import contextlib
import datetime
import importlib
import os
from importlib.metadata import version
from typing import Any, Callable, Dict, List, Optional, Set, Tuple
from typing import Any, Callable, Optional, Set, Tuple
from packaging.version import parse
from requests import HTTPError, Response
def get_from_dict_or_env(
    data: Dict[str, Any], key: str, env_key: str, default: Optional[str] = None
) -> str:
    """Look up ``key`` in ``data``, falling back to the environment.

    Args:
        data: Mapping that is consulted first.
        key: Key to look up in ``data``; also used in the error message
            produced by ``get_from_env``.
        env_key: Name of the environment variable used as a fallback.
        default: Value handed to ``get_from_env`` as its default.

    Returns:
        ``data[key]`` when it is present and truthy, otherwise whatever
        ``get_from_env(key, env_key, default=default)`` returns.
    """
    # A missing key and a falsy value (None, "", ...) are treated the same:
    # both defer to the environment lookup.
    if key not in data or not data[key]:
        return get_from_env(key, env_key, default=default)
    return data[key]
def get_from_env(key: str, env_key: str, default: Optional[str] = None) -> str:
    """Read a value from the environment variable ``env_key``.

    Args:
        key: Human-readable parameter name, used only in the error message.
        env_key: Name of the environment variable to read.
        default: Returned when the variable is unset or empty.

    Returns:
        The variable's value if it is set and non-empty, else ``default``.

    Raises:
        ValueError: If the variable is unset or empty and ``default`` is None.
    """
    env_value = os.environ.get(env_key)
    # An empty string counts as "not set", so it falls through to the default.
    if env_value:
        return env_value
    if default is not None:
        return default
    raise ValueError(
        f"Did not find {key}, please add an environment variable"
        f" `{env_key}` which contains it, or pass"
        f" `{key}` as a named parameter."
    )
def xor_args(*arg_groups: Tuple[str, ...]) -> Callable:
"""Validate specified keyword args are mutually exclusive."""
@ -67,44 +42,6 @@ def raise_for_status_with_text(response: Response) -> None:
raise ValueError(response.text) from e
def stringify_value(val: Any) -> str:
    """Render a value as text.

    Strings are returned unchanged, dicts are rendered by
    ``stringify_dict`` behind a leading newline, lists have each element
    rendered recursively and joined with newlines, and anything else
    goes through ``str()``.

    Args:
        val: The value to stringify.

    Returns:
        str: The stringified value.
    """
    if isinstance(val, str):
        return val
    if isinstance(val, dict):
        return "\n" + stringify_dict(val)
    if isinstance(val, list):
        return "\n".join(map(stringify_value, val))
    return str(val)
def stringify_dict(data: dict) -> str:
    """Render a dictionary as ``key: value`` lines.

    Each entry becomes ``key + ": " + stringify_value(value)`` followed by
    a newline, in the dictionary's iteration order.

    Args:
        data: The dictionary to stringify.

    Returns:
        str: The stringified dictionary (empty string for an empty dict).
    """
    rendered = [
        key + ": " + stringify_value(value) + "\n" for key, value in data.items()
    ]
    return "".join(rendered)
def comma_list(items: List[Any]) -> str:
    """Join the ``str()`` form of every item with ``", "``.

    Args:
        items: The values to list.

    Returns:
        str: The comma-separated text (empty string for no items).
    """
    return ", ".join(map(str, items))
@contextlib.contextmanager
def mock_now(dt_value): # type: ignore
"""Context manager for mocking out datetime.now() in unit tests.

View File

@ -5,7 +5,7 @@ from typing import List
import numpy as np
from langchain.math_utils import cosine_similarity
from langchain.utils.math import cosine_similarity
class DistanceStrategy(str, Enum):

View File

@ -2,7 +2,7 @@
from langchain.document_transformers.embeddings_redundant_filter import (
_filter_similar_embeddings,
)
from langchain.math_utils import cosine_similarity
from langchain.utils.math import cosine_similarity
def test__filter_similar_embeddings() -> None:

View File

@ -4,7 +4,7 @@ from typing import List
import numpy as np
import pytest
from langchain.math_utils import cosine_similarity, cosine_similarity_top_k
from langchain.utils.math import cosine_similarity, cosine_similarity_top_k
@pytest.fixture