@ -3,7 +3,7 @@ import inspect
import os
import pathlib
from pathlib import Path
from typing import Any , List , Tuple, Type
from typing import Any , List , Optional, Tuple, Type
HERE = Path ( __file__ ) . parent
# Should bring us to [root]/src
@ -15,13 +15,15 @@ PARTNER_PKGS = PKGS_ROOT / "partners"
class ImportExtractor ( ast . NodeVisitor ) :
def __init__ ( self , * , from_package : str ) - > None :
""" Extract all imports from the given package."""
def __init__ ( self , * , from_package : Optional [ str ] = None ) - > None :
""" Extract all imports from the given code, optionally filtering by package."""
self . imports = [ ]
self . package = from_package
def visit_ImportFrom ( self , node ) :
if node . module and str ( node . module ) . startswith ( self . package ) :
if node . module and (
self . package is None or str ( node . module ) . startswith ( self . package )
) :
for alias in node . names :
self . imports . append ( ( node . module , alias . name ) )
self . generic_visit ( node )
@ -72,13 +74,40 @@ def _get_all_classnames_from_file(file: str, pkg: str) -> List[Tuple[str, str]]:
code = f . read ( )
module_name = _get_current_module ( file , pkg )
class_names = _get_class_names ( code )
return [ ( module_name , class_name ) for class_name in class_names ]
def identify_all_imports_in_file (
file : str , * , from_package : Optional [ str ] = None
) - > List [ Tuple [ str , str ] ] :
""" Let ' s also identify all the imports in the given file. """
with open ( file , encoding = " utf-8 " ) as f :
code = f . read ( )
return find_imports_from_package ( code , from_package = from_package )
def identify_pkg_source ( pkg_root : str ) - > pathlib . Path :
""" Identify the source of the package.
Args :
pkg_root : the root of the package . This contains source + tests , and other
things like pyproject . toml , lock files etc
Returns :
Returns the path to the source code for the package .
"""
dirs = [ d for d in Path ( pkg_root ) . iterdir ( ) if d . is_dir ( ) ]
matching_dirs = [ d for d in dirs if d . name . startswith ( " langchain_ " ) ]
assert len ( matching_dirs ) == 1 , " There should be only one langchain package. "
return matching_dirs [ 0 ]
def list_classes_by_package ( pkg_root : str ) - > List [ Tuple [ str , str ] ] :
""" List all classes in a package. """
module_classes = [ ]
files = list ( Path ( pkg_root ) . rglob ( " *.py " ) )
pkg_source = identify_pkg_source ( pkg_root )
files = list ( pkg_source . rglob ( " *.py " ) )
for file in files :
rel_path = os . path . relpath ( file , pkg_root )
@ -88,11 +117,29 @@ def list_classes_by_package(pkg_root: str) -> List[Tuple[str, str]]:
return module_classes
def find_imports_from_package ( code : str , * , from_package : str ) - > List [ Tuple [ str , str ] ] :
def list_init_imports_by_package ( pkg_root : str ) - > List [ Tuple [ str , str ] ] :
""" List all the things that are being imported in a package by module. """
imports = [ ]
pkg_source = identify_pkg_source ( pkg_root )
# Scan all the files in the package
files = list ( Path ( pkg_source ) . rglob ( " *.py " ) )
for file in files :
if not file . name == " __init__.py " :
continue
import_in_file = identify_all_imports_in_file ( str ( file ) )
module_name = _get_current_module ( file , pkg_root )
imports . extend ( [ ( module_name , item ) for _ , item in import_in_file ] )
return imports
def find_imports_from_package (
code : str , * , from_package : Optional [ str ] = None
) - > List [ Tuple [ str , str ] ] :
# Parse the code into an AST
tree = ast . parse ( code )
# Create an instance of the visitor
extractor = ImportExtractor ( from_package = " langchain_community " )
extractor = ImportExtractor ( from_package = from_package )
# Use the visitor to update the imports list
extractor . visit ( tree )
return extractor . imports