mirror of
https://github.com/hwchase17/langchain
synced 2024-11-08 07:10:35 +00:00
Refactored math_utils
(#7961)
`math_utils.py` is in the root code folder. This creates the `langchain.math_utils: Math Utils` group on the API Reference navigation ToC, on the same level with `Chains` and `Agents` which is not correct. Refactoring: - created the `utils/` folder - moved `math_utils.py` to `utils/math.py` - moved `utils.py` to `utils/utils.py` - split `utils.py` into `utils.py, env.py, strings.py` - added module description @baskaryan
This commit is contained in:
parent
5137f40dd6
commit
995220b797
@ -5,8 +5,8 @@ import numpy as np
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.math_utils import cosine_similarity
|
||||
from langchain.schema import BaseDocumentTransformer, Document
|
||||
from langchain.utils.math import cosine_similarity
|
||||
|
||||
|
||||
class _DocumentWithState(Document):
|
||||
|
@ -14,8 +14,8 @@ from langchain.chains.base import Chain
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.embeddings.openai import OpenAIEmbeddings
|
||||
from langchain.evaluation.schema import PairwiseStringEvaluator, StringEvaluator
|
||||
from langchain.math_utils import cosine_similarity
|
||||
from langchain.schema import RUN_KEY
|
||||
from langchain.utils.math import cosine_similarity
|
||||
|
||||
|
||||
class EmbeddingDistance(str, Enum):
|
||||
|
@ -9,11 +9,11 @@ from langchain.document_transformers.embeddings_redundant_filter import (
|
||||
get_stateful_documents,
|
||||
)
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.math_utils import cosine_similarity
|
||||
from langchain.retrievers.document_compressors.base import (
|
||||
BaseDocumentCompressor,
|
||||
)
|
||||
from langchain.schema import Document
|
||||
from langchain.utils.math import cosine_similarity
|
||||
|
||||
|
||||
class EmbeddingsFilter(BaseDocumentCompressor):
|
||||
|
@ -10,7 +10,7 @@ from sqlalchemy.engine import Engine
|
||||
from sqlalchemy.exc import ProgrammingError, SQLAlchemyError
|
||||
from sqlalchemy.schema import CreateTable
|
||||
|
||||
from langchain import utils
|
||||
from langchain.utils import get_from_env
|
||||
|
||||
|
||||
def _format_index(index: sqlalchemy.engine.interfaces.ReflectedIndex) -> str:
|
||||
@ -192,13 +192,11 @@ class SQLDatabase:
|
||||
|
||||
default_host = context.browserHostName if context else None
|
||||
if host is None:
|
||||
host = utils.get_from_env("host", "DATABRICKS_HOST", default_host)
|
||||
host = get_from_env("host", "DATABRICKS_HOST", default_host)
|
||||
|
||||
default_api_token = context.apiToken if context else None
|
||||
if api_token is None:
|
||||
api_token = utils.get_from_env(
|
||||
"api_token", "DATABRICKS_TOKEN", default_api_token
|
||||
)
|
||||
api_token = get_from_env("api_token", "DATABRICKS_TOKEN", default_api_token)
|
||||
|
||||
if warehouse_id is None and cluster_id is None:
|
||||
if context:
|
||||
|
33
langchain/utils/__init__.py
Normal file
33
langchain/utils/__init__.py
Normal file
@ -0,0 +1,33 @@
|
||||
"""
|
||||
Utility functions for langchain.
|
||||
|
||||
These functions do not depend on any other langchain modules.
|
||||
"""
|
||||
|
||||
from langchain.utils.env import get_from_dict_or_env, get_from_env
|
||||
from langchain.utils.math import cosine_similarity, cosine_similarity_top_k
|
||||
from langchain.utils.strings import comma_list, stringify_dict, stringify_value
|
||||
from langchain.utils.utils import (
|
||||
check_package_version,
|
||||
get_pydantic_field_names,
|
||||
guard_import,
|
||||
mock_now,
|
||||
raise_for_status_with_text,
|
||||
xor_args,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"check_package_version",
|
||||
"comma_list",
|
||||
"cosine_similarity",
|
||||
"cosine_similarity_top_k",
|
||||
"get_from_dict_or_env",
|
||||
"get_from_env",
|
||||
"get_pydantic_field_names",
|
||||
"guard_import",
|
||||
"mock_now",
|
||||
"raise_for_status_with_text",
|
||||
"stringify_dict",
|
||||
"stringify_value",
|
||||
"xor_args",
|
||||
]
|
26
langchain/utils/env.py
Normal file
26
langchain/utils/env.py
Normal file
@ -0,0 +1,26 @@
|
||||
import os
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
|
||||
def get_from_dict_or_env(
|
||||
data: Dict[str, Any], key: str, env_key: str, default: Optional[str] = None
|
||||
) -> str:
|
||||
"""Get a value from a dictionary or an environment variable."""
|
||||
if key in data and data[key]:
|
||||
return data[key]
|
||||
else:
|
||||
return get_from_env(key, env_key, default=default)
|
||||
|
||||
|
||||
def get_from_env(key: str, env_key: str, default: Optional[str] = None) -> str:
|
||||
"""Get a value from a dictionary or an environment variable."""
|
||||
if env_key in os.environ and os.environ[env_key]:
|
||||
return os.environ[env_key]
|
||||
elif default is not None:
|
||||
return default
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Did not find {key}, please add an environment variable"
|
||||
f" `{env_key}` which contains it, or pass"
|
||||
f" `{key}` as a named parameter."
|
||||
)
|
39
langchain/utils/strings.py
Normal file
39
langchain/utils/strings.py
Normal file
@ -0,0 +1,39 @@
|
||||
from typing import Any, List
|
||||
|
||||
|
||||
def stringify_value(val: Any) -> str:
|
||||
"""Stringify a value.
|
||||
|
||||
Args:
|
||||
val: The value to stringify.
|
||||
|
||||
Returns:
|
||||
str: The stringified value.
|
||||
"""
|
||||
if isinstance(val, str):
|
||||
return val
|
||||
elif isinstance(val, dict):
|
||||
return "\n" + stringify_dict(val)
|
||||
elif isinstance(val, list):
|
||||
return "\n".join(stringify_value(v) for v in val)
|
||||
else:
|
||||
return str(val)
|
||||
|
||||
|
||||
def stringify_dict(data: dict) -> str:
|
||||
"""Stringify a dictionary.
|
||||
|
||||
Args:
|
||||
data: The dictionary to stringify.
|
||||
|
||||
Returns:
|
||||
str: The stringified dictionary.
|
||||
"""
|
||||
text = ""
|
||||
for key, value in data.items():
|
||||
text += key + ": " + stringify_value(value) + "\n"
|
||||
return text
|
||||
|
||||
|
||||
def comma_list(items: List[Any]) -> str:
|
||||
return ", ".join(str(item) for item in items)
|
@ -2,38 +2,13 @@
|
||||
import contextlib
|
||||
import datetime
|
||||
import importlib
|
||||
import os
|
||||
from importlib.metadata import version
|
||||
from typing import Any, Callable, Dict, List, Optional, Set, Tuple
|
||||
from typing import Any, Callable, Optional, Set, Tuple
|
||||
|
||||
from packaging.version import parse
|
||||
from requests import HTTPError, Response
|
||||
|
||||
|
||||
def get_from_dict_or_env(
|
||||
data: Dict[str, Any], key: str, env_key: str, default: Optional[str] = None
|
||||
) -> str:
|
||||
"""Get a value from a dictionary or an environment variable."""
|
||||
if key in data and data[key]:
|
||||
return data[key]
|
||||
else:
|
||||
return get_from_env(key, env_key, default=default)
|
||||
|
||||
|
||||
def get_from_env(key: str, env_key: str, default: Optional[str] = None) -> str:
|
||||
"""Get a value from a dictionary or an environment variable."""
|
||||
if env_key in os.environ and os.environ[env_key]:
|
||||
return os.environ[env_key]
|
||||
elif default is not None:
|
||||
return default
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Did not find {key}, please add an environment variable"
|
||||
f" `{env_key}` which contains it, or pass"
|
||||
f" `{key}` as a named parameter."
|
||||
)
|
||||
|
||||
|
||||
def xor_args(*arg_groups: Tuple[str, ...]) -> Callable:
|
||||
"""Validate specified keyword args are mutually exclusive."""
|
||||
|
||||
@ -67,44 +42,6 @@ def raise_for_status_with_text(response: Response) -> None:
|
||||
raise ValueError(response.text) from e
|
||||
|
||||
|
||||
def stringify_value(val: Any) -> str:
|
||||
"""Stringify a value.
|
||||
|
||||
Args:
|
||||
val: The value to stringify.
|
||||
|
||||
Returns:
|
||||
str: The stringified value.
|
||||
"""
|
||||
if isinstance(val, str):
|
||||
return val
|
||||
elif isinstance(val, dict):
|
||||
return "\n" + stringify_dict(val)
|
||||
elif isinstance(val, list):
|
||||
return "\n".join(stringify_value(v) for v in val)
|
||||
else:
|
||||
return str(val)
|
||||
|
||||
|
||||
def stringify_dict(data: dict) -> str:
|
||||
"""Stringify a dictionary.
|
||||
|
||||
Args:
|
||||
data: The dictionary to stringify.
|
||||
|
||||
Returns:
|
||||
str: The stringified dictionary.
|
||||
"""
|
||||
text = ""
|
||||
for key, value in data.items():
|
||||
text += key + ": " + stringify_value(value) + "\n"
|
||||
return text
|
||||
|
||||
|
||||
def comma_list(items: List[Any]) -> str:
|
||||
return ", ".join(str(item) for item in items)
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def mock_now(dt_value): # type: ignore
|
||||
"""Context manager for mocking out datetime.now() in unit tests.
|
@ -5,7 +5,7 @@ from typing import List
|
||||
|
||||
import numpy as np
|
||||
|
||||
from langchain.math_utils import cosine_similarity
|
||||
from langchain.utils.math import cosine_similarity
|
||||
|
||||
|
||||
class DistanceStrategy(str, Enum):
|
||||
|
@ -2,7 +2,7 @@
|
||||
from langchain.document_transformers.embeddings_redundant_filter import (
|
||||
_filter_similar_embeddings,
|
||||
)
|
||||
from langchain.math_utils import cosine_similarity
|
||||
from langchain.utils.math import cosine_similarity
|
||||
|
||||
|
||||
def test__filter_similar_embeddings() -> None:
|
||||
|
@ -4,7 +4,7 @@ from typing import List
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from langchain.math_utils import cosine_similarity, cosine_similarity_top_k
|
||||
from langchain.utils.math import cosine_similarity, cosine_similarity_top_k
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
Loading…
Reference in New Issue
Block a user