# Reconstructed from langchain commit 7a5d042bd2fc (#21998, Eugene Yurtsev,
# 2024-05-21): "langchain[patch]: Add unit test to detect changes to community
# imports", which patched libs/langchain/tests/unit_tests/test_imports.py.
import ast
import importlib
from pathlib import Path
from typing import Any, Dict, Optional

# Attempt to recursively import all modules in langchain
PKG_ROOT = Path(__file__).parent.parent.parent

# Star-imports of this module should expose only the reusable helper;
# re-exporting a ``test_*`` function would make pytest collect it a second
# time in the importing module.
__all__ = ["extract_deprecated_lookup"]


def test_no_more_changes_to_proxy_community() -> None:
    """This test is meant to catch any changes to the proxy community module.

    Imports from langchain to community are officially DEPRECATED. Contributors
    should not be adding new imports from langchain to community. This test
    is meant to catch any new changes to the proxy community module.

    Every ``__init__.py`` below ``PKG_ROOT / "langchain"`` is scanned for a
    ``DEPRECATED_LOOKUP`` dict; the lengths of the sorted items' reprs are
    summed and compared against a pinned value.
    """
    library_code = PKG_ROOT / "langchain"
    hash_ = 0
    # Only package ``__init__.py`` files are inspected. Matching on the exact
    # file name avoids picking up files that merely end with that string.
    for path in library_code.rglob("__init__.py"):
        deprecated_lookup = extract_deprecated_lookup(str(path))
        if deprecated_lookup is None:
            continue

        # This uses a very simple hash, so it's not foolproof, but it should
        # catch most cases. Summation makes it independent of traversal order.
        hash_ += len(str(sorted(deprecated_lookup.items())))

    evil_magic_number = 38620

    assert hash_ == evil_magic_number, (
        "If you're triggering this test, you're likely adding a new import "
        "to the langchain package that is importing something from "
        "langchain_community. This test is meant to catch such imports "
        "as they are officially DEPRECATED. Please do not add any new imports "
        "from langchain_community to the langchain package. "
    )


def extract_deprecated_lookup(file_path: str) -> Optional[Dict[str, Any]]:
    """Detect and extract the value of a dictionary named DEPRECATED_LOOKUP.

    The file is parsed, never imported or executed. The whole syntax tree is
    searched with ``ast.walk``, so a plain assignment at any nesting level is
    considered; the first one whose right-hand side is a dict literal wins.

    Args:
        file_path: The path to the Python file.

    Returns:
        The value of DEPRECATED_LOOKUP if it is assigned a dict literal,
        None otherwise.

    Raises:
        AssertionError: If the dict literal is not a plain str-to-str mapping.
    """
    # Python source is UTF-8 by default; do not depend on the platform locale.
    with open(file_path, "r", encoding="utf-8") as file:
        tree = ast.parse(file.read(), filename=file_path)

    for node in ast.walk(tree):
        if not isinstance(node, ast.Assign):
            continue
        if not isinstance(node.value, ast.Dict):
            continue
        for target in node.targets:
            if isinstance(target, ast.Name) and target.id == "DEPRECATED_LOOKUP":
                return _dict_from_ast(node.value)
    return None


def _dict_from_ast(node: ast.Dict) -> Dict[str, str]:
    """Convert an AST dict node to a Python dictionary, assuming str to str format.

    Args:
        node: The AST node representing a dictionary literal.

    Returns:
        The corresponding Python dictionary.

    Raises:
        AssertionError: If the literal uses ``**`` unpacking, or if any key or
            value is not a string constant.
    """
    result: Dict[str, str] = {}
    for key, value in zip(node.keys, node.values):
        # ``ast.Dict`` stores ``None`` as the key of a ``**mapping`` entry,
        # which cannot be resolved without executing the module.
        if key is None:
            raise AssertionError(
                "Invalid DEPRECATED_LOOKUP format: dict unpacking (**) "
                "is not supported"
            )
        result[_literal_eval_str(key)] = _literal_eval_str(value)
    return result


def _literal_eval_str(node: ast.AST) -> str:
    """Evaluate an AST literal node to its corresponding string value.

    Args:
        node: The AST node representing a literal value.

    Returns:
        The string held by the node.

    Raises:
        AssertionError: If the node is not a string constant. The message
            names the constant's value type, or the node class when the node
            is not a constant at all.
    """
    if isinstance(node, ast.Constant):  # Python 3.8+
        if isinstance(node.value, str):
            return node.value
        found = type(node.value).__name__
    else:
        found = type(node).__name__
    raise AssertionError(
        f"Invalid DEPRECATED_LOOKUP format: expected str, got {found}"
    )