add more import checks (#11033)

pull/11140/head
Harrison Chase 12 months ago committed by GitHub
parent efb7c459a2
commit e355606b11
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -15,7 +15,7 @@ from typing import (
from typing_extensions import Literal from typing_extensions import Literal
from langchain.chat_loaders.base import ChatSession from langchain.schema.chat import ChatSession
from langchain.schema.messages import ( from langchain.schema.messages import (
AIMessage, AIMessage,
AIMessageChunk, AIMessageChunk,

@ -1,15 +1,7 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Iterator, List, Sequence, TypedDict from typing import Iterator, List
from langchain.schema.messages import BaseMessage from langchain.schema.chat import ChatSession
class ChatSession(TypedDict):
"""Chat Session represents a single
conversation, channel, or other group of messages."""
messages: Sequence[BaseMessage]
"""The LangChain chat messages loaded from the source."""
class BaseChatLoader(ABC): class BaseChatLoader(ABC):

@ -3,7 +3,8 @@ import logging
from pathlib import Path from pathlib import Path
from typing import Iterator, Union from typing import Iterator, Union
from langchain.chat_loaders.base import BaseChatLoader, ChatSession from langchain.chat_loaders.base import BaseChatLoader
from langchain.schema.chat import ChatSession
from langchain.schema.messages import HumanMessage from langchain.schema.messages import HumanMessage
logger = logging.getLogger(__file__) logger = logging.getLogger(__file__)

@ -2,7 +2,8 @@ import base64
import re import re
from typing import Any, Iterator from typing import Any, Iterator
from langchain.chat_loaders.base import BaseChatLoader, ChatSession from langchain.chat_loaders.base import BaseChatLoader
from langchain.schema.chat import ChatSession
from langchain.schema.messages import HumanMessage from langchain.schema.messages import HumanMessage

@ -3,8 +3,9 @@ from __future__ import annotations
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING, Iterator, List, Optional, Union from typing import TYPE_CHECKING, Iterator, List, Optional, Union
from langchain.chat_loaders.base import BaseChatLoader, ChatSession from langchain.chat_loaders.base import BaseChatLoader
from langchain.schema import HumanMessage from langchain.schema import HumanMessage
from langchain.schema.chat import ChatSession
if TYPE_CHECKING: if TYPE_CHECKING:
import sqlite3 import sqlite3

@ -5,8 +5,9 @@ import zipfile
from pathlib import Path from pathlib import Path
from typing import Dict, Iterator, List, Union from typing import Dict, Iterator, List, Union
from langchain.chat_loaders.base import BaseChatLoader, ChatSession from langchain.chat_loaders.base import BaseChatLoader
from langchain.schema import AIMessage, HumanMessage from langchain.schema import AIMessage, HumanMessage
from langchain.schema.chat import ChatSession
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

@ -6,8 +6,9 @@ import zipfile
from pathlib import Path from pathlib import Path
from typing import Iterator, List, Union from typing import Iterator, List, Union
from langchain.chat_loaders.base import BaseChatLoader, ChatSession from langchain.chat_loaders.base import BaseChatLoader
from langchain.schema import AIMessage, BaseMessage, HumanMessage from langchain.schema import AIMessage, BaseMessage, HumanMessage
from langchain.schema.chat import ChatSession
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

@ -2,7 +2,7 @@
from copy import deepcopy from copy import deepcopy
from typing import Iterable, Iterator, List from typing import Iterable, Iterator, List
from langchain.chat_loaders.base import ChatSession from langchain.schema.chat import ChatSession
from langchain.schema.messages import AIMessage, BaseMessage from langchain.schema.messages import AIMessage, BaseMessage

@ -4,8 +4,9 @@ import re
import zipfile import zipfile
from typing import Iterator, List, Union from typing import Iterator, List, Union
from langchain.chat_loaders.base import BaseChatLoader, ChatSession from langchain.chat_loaders.base import BaseChatLoader
from langchain.schema import AIMessage, HumanMessage from langchain.schema import AIMessage, HumanMessage
from langchain.schema.chat import ChatSession
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

@ -1,8 +1,8 @@
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from langchain.pydantic_v1 import BaseModel, Extra, Field, root_validator from langchain.pydantic_v1 import BaseModel, Extra, Field, root_validator
from langchain.requests import Requests
from langchain.schema.embeddings import Embeddings from langchain.schema.embeddings import Embeddings
from langchain.utilities.requests import Requests
from langchain.utils import get_from_dict_or_env from langchain.utils import get_from_dict_or_env

@ -11,7 +11,7 @@ from langchain.callbacks.manager import (
from langchain.llms.base import LLM from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens from langchain.llms.utils import enforce_stop_tokens
from langchain.pydantic_v1 import Extra, Field, root_validator from langchain.pydantic_v1 import Extra, Field, root_validator
from langchain.requests import Requests from langchain.utilities.requests import Requests
from langchain.utils import get_from_dict_or_env from langchain.utils import get_from_dict_or_env
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

@ -5,10 +5,10 @@ import warnings
from abc import ABC from abc import ABC
from typing import Any, Callable, Dict, List, Set from typing import Any, Callable, Dict, List, Set
from langchain.formatting import formatter
from langchain.schema.messages import BaseMessage, HumanMessage from langchain.schema.messages import BaseMessage, HumanMessage
from langchain.schema.prompt import PromptValue from langchain.schema.prompt import PromptValue
from langchain.schema.prompt_template import BasePromptTemplate from langchain.schema.prompt_template import BasePromptTemplate
from langchain.utils.formatting import formatter
def jinja2_formatter(template: str, **kwargs: Any) -> str: def jinja2_formatter(template: str, **kwargs: Any) -> str:

@ -6,11 +6,10 @@ from typing import Callable, Dict, Union
import yaml import yaml
from langchain.output_parsers.regex import RegexParser
from langchain.prompts.few_shot import FewShotPromptTemplate from langchain.prompts.few_shot import FewShotPromptTemplate
from langchain.prompts.prompt import PromptTemplate from langchain.prompts.prompt import PromptTemplate
from langchain.schema import BaseLLMOutputParser, BasePromptTemplate, StrOutputParser from langchain.schema import BaseLLMOutputParser, BasePromptTemplate, StrOutputParser
from langchain.utilities.loading import try_load_from_hub from langchain.utils.loading import try_load_from_hub
URL_BASE = "https://raw.githubusercontent.com/hwchase17/langchain-hub/master/prompts/" URL_BASE = "https://raw.githubusercontent.com/hwchase17/langchain-hub/master/prompts/"
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -77,6 +76,8 @@ def _load_output_parser(config: dict) -> dict:
_config = config.pop("output_parser") _config = config.pop("output_parser")
output_parser_type = _config.pop("_type") output_parser_type = _config.pop("_type")
if output_parser_type == "regex_parser": if output_parser_type == "regex_parser":
from langchain.output_parsers.regex import RegexParser
output_parser: BaseLLMOutputParser = RegexParser(**_config) output_parser: BaseLLMOutputParser = RegexParser(**_config)
elif output_parser_type == "default": elif output_parser_type == "default":
output_parser = StrOutputParser(**_config) output_parser = StrOutputParser(**_config)

@ -0,0 +1,11 @@
from typing import Sequence, TypedDict
from langchain.schema import BaseMessage
class ChatSession(TypedDict):
    """A single conversation, channel, or other grouping of chat messages."""

    # The LangChain chat messages loaded from the source.
    messages: Sequence[BaseMessage]

@ -32,10 +32,9 @@ if TYPE_CHECKING:
AsyncCallbackManagerForChainRun, AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun, CallbackManagerForChainRun,
) )
from langchain.callbacks.tracers.log_stream import RunLogPatch
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.tracers.log_stream import LogStreamCallbackHandler, RunLogPatch
from langchain.load.dump import dumpd from langchain.load.dump import dumpd
from langchain.load.serializable import Serializable from langchain.load.serializable import Serializable
from langchain.pydantic_v1 import Field from langchain.pydantic_v1 import Field
@ -216,6 +215,12 @@ class Runnable(Generic[Input, Output], ABC):
The jsonpatch ops can be applied in order to construct state. The jsonpatch ops can be applied in order to construct state.
""" """
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.tracers.log_stream import (
LogStreamCallbackHandler,
RunLogPatch,
)
# Create a stream handler that will emit Log objects # Create a stream handler that will emit Log objects
stream = LogStreamCallbackHandler( stream = LogStreamCallbackHandler(
auto_close=False, auto_close=False,

@ -1,4 +1,15 @@
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union, cast from typing import (
TYPE_CHECKING,
Any,
Dict,
List,
Optional,
Tuple,
Type,
TypeVar,
Union,
cast,
)
from tenacity import ( from tenacity import (
AsyncRetrying, AsyncRetrying,
@ -10,14 +21,16 @@ from tenacity import (
wait_exponential_jitter, wait_exponential_jitter,
) )
from langchain.callbacks.manager import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
)
from langchain.schema.runnable.base import Input, Output, RunnableBinding from langchain.schema.runnable.base import Input, Output, RunnableBinding
from langchain.schema.runnable.config import RunnableConfig, patch_config from langchain.schema.runnable.config import RunnableConfig, patch_config
T = TypeVar("T", CallbackManagerForChainRun, AsyncCallbackManagerForChainRun) if TYPE_CHECKING:
from langchain.callbacks.manager import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
)
T = TypeVar("T", CallbackManagerForChainRun, AsyncCallbackManagerForChainRun)
U = TypeVar("U") U = TypeVar("U")
@ -54,7 +67,7 @@ class RunnableRetry(RunnableBinding[Input, Output]):
def _patch_config( def _patch_config(
self, self,
config: RunnableConfig, config: RunnableConfig,
run_manager: T, run_manager: "T",
retry_state: RetryCallState, retry_state: RetryCallState,
) -> RunnableConfig: ) -> RunnableConfig:
attempt = retry_state.attempt_number attempt = retry_state.attempt_number
@ -64,7 +77,7 @@ class RunnableRetry(RunnableBinding[Input, Output]):
def _patch_config_list( def _patch_config_list(
self, self,
config: List[RunnableConfig], config: List[RunnableConfig],
run_manager: List[T], run_manager: List["T"],
retry_state: RetryCallState, retry_state: RetryCallState,
) -> List[RunnableConfig]: ) -> List[RunnableConfig]:
return [ return [
@ -74,7 +87,7 @@ class RunnableRetry(RunnableBinding[Input, Output]):
def _invoke( def _invoke(
self, self,
input: Input, input: Input,
run_manager: CallbackManagerForChainRun, run_manager: "CallbackManagerForChainRun",
config: RunnableConfig, config: RunnableConfig,
) -> Output: ) -> Output:
for attempt in self._sync_retrying(reraise=True): for attempt in self._sync_retrying(reraise=True):
@ -95,7 +108,7 @@ class RunnableRetry(RunnableBinding[Input, Output]):
async def _ainvoke( async def _ainvoke(
self, self,
input: Input, input: Input,
run_manager: AsyncCallbackManagerForChainRun, run_manager: "AsyncCallbackManagerForChainRun",
config: RunnableConfig, config: RunnableConfig,
) -> Output: ) -> Output:
async for attempt in self._async_retrying(reraise=True): async for attempt in self._async_retrying(reraise=True):
@ -116,7 +129,7 @@ class RunnableRetry(RunnableBinding[Input, Output]):
def _batch( def _batch(
self, self,
inputs: List[Input], inputs: List[Input],
run_manager: List[CallbackManagerForChainRun], run_manager: List["CallbackManagerForChainRun"],
config: List[RunnableConfig], config: List[RunnableConfig],
) -> List[Union[Output, Exception]]: ) -> List[Union[Output, Exception]]:
results_map: Dict[int, Output] = {} results_map: Dict[int, Output] = {}
@ -180,7 +193,7 @@ class RunnableRetry(RunnableBinding[Input, Output]):
async def _abatch( async def _abatch(
self, self,
inputs: List[Input], inputs: List[Input],
run_manager: List[AsyncCallbackManagerForChainRun], run_manager: List["AsyncCallbackManagerForChainRun"],
config: List[RunnableConfig], config: List[RunnableConfig],
) -> List[Union[Output, Exception]]: ) -> List[Union[Output, Exception]]:
results_map: Dict[int, Output] = {} results_map: Dict[int, Output] = {}

@ -1,10 +1,12 @@
from typing import Any, Callable, Dict, Optional from typing import TYPE_CHECKING, Any, Callable, Dict, Optional
from langchain.document_loaders import ApifyDatasetLoader
from langchain.document_loaders.base import Document
from langchain.pydantic_v1 import BaseModel, root_validator from langchain.pydantic_v1 import BaseModel, root_validator
from langchain.schema.document import Document
from langchain.utils import get_from_dict_or_env from langchain.utils import get_from_dict_or_env
if TYPE_CHECKING:
from langchain.document_loaders import ApifyDatasetLoader
class ApifyWrapper(BaseModel): class ApifyWrapper(BaseModel):
"""Wrapper around Apify. """Wrapper around Apify.
@ -48,7 +50,7 @@ class ApifyWrapper(BaseModel):
build: Optional[str] = None, build: Optional[str] = None,
memory_mbytes: Optional[int] = None, memory_mbytes: Optional[int] = None,
timeout_secs: Optional[int] = None, timeout_secs: Optional[int] = None,
) -> ApifyDatasetLoader: ) -> "ApifyDatasetLoader":
"""Run an Actor on the Apify platform and wait for results to be ready. """Run an Actor on the Apify platform and wait for results to be ready.
Args: Args:
actor_id (str): The ID or name of the Actor on the Apify platform. actor_id (str): The ID or name of the Actor on the Apify platform.
@ -65,6 +67,8 @@ class ApifyWrapper(BaseModel):
ApifyDatasetLoader: A loader that will fetch the records from the ApifyDatasetLoader: A loader that will fetch the records from the
Actor run's default dataset. Actor run's default dataset.
""" """
from langchain.document_loaders import ApifyDatasetLoader
actor_call = self.apify_client.actor(actor_id).call( actor_call = self.apify_client.actor(actor_id).call(
run_input=run_input, run_input=run_input,
build=build, build=build,
@ -86,7 +90,7 @@ class ApifyWrapper(BaseModel):
build: Optional[str] = None, build: Optional[str] = None,
memory_mbytes: Optional[int] = None, memory_mbytes: Optional[int] = None,
timeout_secs: Optional[int] = None, timeout_secs: Optional[int] = None,
) -> ApifyDatasetLoader: ) -> "ApifyDatasetLoader":
"""Run an Actor on the Apify platform and wait for results to be ready. """Run an Actor on the Apify platform and wait for results to be ready.
Args: Args:
actor_id (str): The ID or name of the Actor on the Apify platform. actor_id (str): The ID or name of the Actor on the Apify platform.
@ -103,6 +107,8 @@ class ApifyWrapper(BaseModel):
ApifyDatasetLoader: A loader that will fetch the records from the ApifyDatasetLoader: A loader that will fetch the records from the
Actor run's default dataset. Actor run's default dataset.
""" """
from langchain.document_loaders import ApifyDatasetLoader
actor_call = await self.apify_client_async.actor(actor_id).call( actor_call = await self.apify_client_async.actor(actor_id).call(
run_input=run_input, run_input=run_input,
build=build, build=build,
@ -124,7 +130,7 @@ class ApifyWrapper(BaseModel):
build: Optional[str] = None, build: Optional[str] = None,
memory_mbytes: Optional[int] = None, memory_mbytes: Optional[int] = None,
timeout_secs: Optional[int] = None, timeout_secs: Optional[int] = None,
) -> ApifyDatasetLoader: ) -> "ApifyDatasetLoader":
"""Run a saved Actor task on Apify and wait for results to be ready. """Run a saved Actor task on Apify and wait for results to be ready.
Args: Args:
task_id (str): The ID or name of the task on the Apify platform. task_id (str): The ID or name of the task on the Apify platform.
@ -142,6 +148,8 @@ class ApifyWrapper(BaseModel):
ApifyDatasetLoader: A loader that will fetch the records from the ApifyDatasetLoader: A loader that will fetch the records from the
task run's default dataset. task run's default dataset.
""" """
from langchain.document_loaders import ApifyDatasetLoader
task_call = self.apify_client.task(task_id).call( task_call = self.apify_client.task(task_id).call(
task_input=task_input, task_input=task_input,
build=build, build=build,
@ -163,7 +171,7 @@ class ApifyWrapper(BaseModel):
build: Optional[str] = None, build: Optional[str] = None,
memory_mbytes: Optional[int] = None, memory_mbytes: Optional[int] = None,
timeout_secs: Optional[int] = None, timeout_secs: Optional[int] = None,
) -> ApifyDatasetLoader: ) -> "ApifyDatasetLoader":
"""Run a saved Actor task on Apify and wait for results to be ready. """Run a saved Actor task on Apify and wait for results to be ready.
Args: Args:
task_id (str): The ID or name of the task on the Apify platform. task_id (str): The ID or name of the task on the Apify platform.
@ -181,6 +189,8 @@ class ApifyWrapper(BaseModel):
ApifyDatasetLoader: A loader that will fetch the records from the ApifyDatasetLoader: A loader that will fetch the records from the
task run's default dataset. task run's default dataset.
""" """
from langchain.document_loaders import ApifyDatasetLoader
task_call = await self.apify_client_async.task(task_id).call( task_call = await self.apify_client_async.task(task_id).call(
task_input=task_input, task_input=task_input,
build=build, build=build,

@ -1,54 +1,4 @@
"""Utilities for loading configurations from langchain-hub.""" from langchain.utils.loading import try_load_from_hub
import os # For backwards compatibility
import re __all__ = ["try_load_from_hub"]
import tempfile
from pathlib import Path, PurePosixPath
from typing import Any, Callable, Optional, Set, TypeVar, Union
from urllib.parse import urljoin
import requests
DEFAULT_REF = os.environ.get("LANGCHAIN_HUB_DEFAULT_REF", "master")
URL_BASE = os.environ.get(
"LANGCHAIN_HUB_URL_BASE",
"https://raw.githubusercontent.com/hwchase17/langchain-hub/{ref}/",
)
HUB_PATH_RE = re.compile(r"lc(?P<ref>@[^:]+)?://(?P<path>.*)")
T = TypeVar("T")
def try_load_from_hub(
path: Union[str, Path],
loader: Callable[[str], T],
valid_prefix: str,
valid_suffixes: Set[str],
**kwargs: Any,
) -> Optional[T]:
"""Load configuration from hub. Returns None if path is not a hub path."""
if not isinstance(path, str) or not (match := HUB_PATH_RE.match(path)):
return None
ref, remote_path_str = match.groups()
ref = ref[1:] if ref else DEFAULT_REF
remote_path = Path(remote_path_str)
if remote_path.parts[0] != valid_prefix:
return None
if remote_path.suffix[1:] not in valid_suffixes:
raise ValueError(f"Unsupported file type, must be one of {valid_suffixes}.")
# Using Path with URLs is not recommended, because on Windows
# the backslash is used as the path separator, which can cause issues
# when working with URLs that use forward slashes as the path separator.
# Instead, use PurePosixPath to ensure that forward slashes are used as the
# path separator, regardless of the operating system.
full_url = urljoin(URL_BASE.format(ref=ref), PurePosixPath(remote_path).__str__())
r = requests.get(full_url, timeout=5)
if r.status_code != 200:
raise ValueError(f"Could not find file at {full_url}")
with tempfile.TemporaryDirectory() as tmpdirname:
file = Path(tmpdirname) / remote_path.name
with open(file, "wb") as f:
f.write(r.content)
return loader(str(file), **kwargs)

@ -0,0 +1,54 @@
"""Utilities for loading configurations from langchain-hub."""
import os
import re
import tempfile
from pathlib import Path, PurePosixPath
from typing import Any, Callable, Optional, Set, TypeVar, Union
from urllib.parse import urljoin
import requests
# Git ref on langchain-hub to fetch from; overridable via environment.
DEFAULT_REF = os.environ.get("LANGCHAIN_HUB_DEFAULT_REF", "master")
# Raw-content URL template; "{ref}" is substituted with the resolved ref.
URL_BASE = os.environ.get(
    "LANGCHAIN_HUB_URL_BASE",
    "https://raw.githubusercontent.com/hwchase17/langchain-hub/{ref}/",
)
# Hub paths look like "lc://prompts/foo.json" or "lc@<ref>://prompts/foo.json".
HUB_PATH_RE = re.compile(r"lc(?P<ref>@[^:]+)?://(?P<path>.*)")

T = TypeVar("T")


def try_load_from_hub(
    path: Union[str, Path],
    loader: Callable[[str], T],
    valid_prefix: str,
    valid_suffixes: Set[str],
    **kwargs: Any,
) -> Optional[T]:
    """Load a configuration from langchain-hub.

    Args:
        path: Candidate hub path (e.g. ``"lc://prompts/foo.json"``). Anything
            that is not a string matching the hub URI pattern is ignored.
        loader: Callable invoked with the local path of the downloaded file.
        valid_prefix: Required first path component (e.g. ``"prompts"``).
        valid_suffixes: Allowed file extensions (without the leading dot).
        **kwargs: Extra keyword arguments forwarded to ``loader``.

    Returns:
        The loader's result, or None if ``path`` is not a hub path for
        ``valid_prefix``.

    Raises:
        ValueError: If the file extension is unsupported, or the file cannot
            be fetched from the hub.
    """
    if not isinstance(path, str) or not (match := HUB_PATH_RE.match(path)):
        return None
    ref, remote_path_str = match.groups()
    ref = ref[1:] if ref else DEFAULT_REF
    remote_path = Path(remote_path_str)
    # Guard against an empty hub path such as "lc://": there is no first
    # component to check, so it cannot be a valid hub reference.
    if not remote_path.parts or remote_path.parts[0] != valid_prefix:
        return None
    if remote_path.suffix[1:] not in valid_suffixes:
        raise ValueError(f"Unsupported file type, must be one of {valid_suffixes}.")

    # Using Path with URLs is not recommended, because on Windows
    # the backslash is used as the path separator, which can cause issues
    # when working with URLs that use forward slashes as the path separator.
    # Instead, use PurePosixPath to ensure that forward slashes are used as the
    # path separator, regardless of the operating system.
    full_url = urljoin(URL_BASE.format(ref=ref), str(PurePosixPath(remote_path)))

    r = requests.get(full_url, timeout=5)
    if r.status_code != 200:
        raise ValueError(f"Could not find file at {full_url}")
    with tempfile.TemporaryDirectory() as tmpdirname:
        file = Path(tmpdirname) / remote_path.name
        with open(file, "wb") as f:
            f.write(r.content)
        # Load while the temp dir still exists; it is removed on exit.
        return loader(str(file), **kwargs)

@ -6,8 +6,7 @@ import time
import uuid import uuid
from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Type from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Type
from langchain.docstore.document import Document from langchain.schema.document import Document
from langchain.embeddings import TensorflowHubEmbeddings
from langchain.schema.embeddings import Embeddings from langchain.schema.embeddings import Embeddings
from langchain.schema.vectorstore import VectorStore from langchain.schema.vectorstore import VectorStore
@ -16,6 +15,8 @@ if TYPE_CHECKING:
from google.cloud.aiplatform import MatchingEngineIndex, MatchingEngineIndexEndpoint from google.cloud.aiplatform import MatchingEngineIndex, MatchingEngineIndexEndpoint
from google.oauth2.service_account import Credentials from google.oauth2.service_account import Credentials
from langchain.embeddings import TensorflowHubEmbeddings
logger = logging.getLogger() logger = logging.getLogger()
@ -443,10 +444,13 @@ class MatchingEngine(VectorStore):
) )
@classmethod @classmethod
def _get_default_embeddings(cls) -> TensorflowHubEmbeddings: def _get_default_embeddings(cls) -> "TensorflowHubEmbeddings":
"""This function returns the default embedding. """This function returns the default embedding.
Returns: Returns:
Default TensorflowHubEmbeddings to use. Default TensorflowHubEmbeddings to use.
""" """
from langchain.embeddings import TensorflowHubEmbeddings
return TensorflowHubEmbeddings() return TensorflowHubEmbeddings()

@ -18,8 +18,8 @@ from typing import (
Union, Union,
) )
from langchain.docstore.document import Document from langchain.schema.document import Document
from langchain.embeddings.base import Embeddings from langchain.schema.embeddings import Embeddings
from langchain.schema.vectorstore import VectorStore from langchain.schema.vectorstore import VectorStore
from langchain.utils import get_from_dict_or_env from langchain.utils import get_from_dict_or_env
from langchain.vectorstores.utils import DistanceStrategy from langchain.vectorstores.utils import DistanceStrategy

@ -2,13 +2,31 @@
set -eu set -eu
git grep 'from langchain import' langchain | grep -vE 'from langchain import (__version__|hub)' && exit 1 || exit 0 # Initialize a variable to keep track of errors
errors=0
# Pydantic bridge should not import from any other module # Check the conditions
git grep 'from langchain ' langchain/pydantic_v1 && exit 1 || exit 0 git grep '^from langchain import' langchain | grep -vE 'from langchain import (__version__|hub)' && errors=$((errors+1))
git grep '^from langchain ' langchain/pydantic_v1 | grep -vE 'from langchain.(pydantic_v1)' && errors=$((errors+1))
git grep '^from langchain' langchain/load | grep -vE 'from langchain.(pydantic_v1|load)' && errors=$((errors+1))
git grep '^from langchain' langchain/utils | grep -vE 'from langchain.(pydantic_v1|utils)' && errors=$((errors+1))
git grep '^from langchain' langchain/schema | grep -vE 'from langchain.(pydantic_v1|utils|schema|load)' && errors=$((errors+1))
git grep '^from langchain' langchain/adapters | grep -vE 'from langchain.(pydantic_v1|utils|schema|load)' && errors=$((errors+1))
git grep '^from langchain' langchain/callbacks | grep -vE 'from langchain.(pydantic_v1|utils|schema|load|callbacks|env)' && errors=$((errors+1))
# TODO: it's probably not amazing so that so many other modules depend on `langchain.utilities`, because there can be a lot of imports there
git grep '^from langchain' langchain/utilities | grep -vE 'from langchain.(pydantic_v1|utils|schema|load|callbacks|env|utilities)' && errors=$((errors+1))
git grep '^from langchain' langchain/storage | grep -vE 'from langchain.(pydantic_v1|utils|schema|load|callbacks|env|storage|utilities)' && errors=$((errors+1))
git grep '^from langchain' langchain/prompts | grep -vE 'from langchain.(pydantic_v1|utils|schema|load|callbacks|env|prompts|_api)' && errors=$((errors+1))
git grep '^from langchain' langchain/output_parsers | grep -vE 'from langchain.(pydantic_v1|utils|schema|load|callbacks|env|prompts|_api|output_parsers)' && errors=$((errors+1))
git grep '^from langchain' langchain/llms | grep -vE 'from langchain.(pydantic_v1|utils|schema|load|callbacks|env|prompts|llms|utilities)' && errors=$((errors+1))
git grep '^from langchain' langchain/chat_models | grep -vE 'from langchain.(pydantic_v1|utils|schema|load|callbacks|env|llms|prompts|adapters|chat_models|utilities)' && errors=$((errors+1))
git grep '^from langchain' langchain/embeddings | grep -vE 'from langchain.(pydantic_v1|utils|schema|load|callbacks|env|storage|llms|embeddings|utilities)' && errors=$((errors+1))
git grep '^from langchain' langchain/docstore | grep -vE 'from langchain.(pydantic_v1|utils|schema|docstore)' && errors=$((errors+1))
git grep '^from langchain' langchain/vectorstores | grep -vE 'from langchain.(pydantic_v1|utils|schema|load|callbacks|env|_api|storage|llms|docstore|vectorstores|utilities)' && errors=$((errors+1))
# load should not import from anything except itself and pydantic_v1 # Decide on an exit status based on the errors
git grep 'from langchain' langchain/load | grep -vE 'from langchain.(pydantic_v1)' && exit 1 || exit 0 if [ "$errors" -gt 0 ]; then
exit 1
# utils should not import from anything except itself and pydantic_v1 else
git grep 'from langchain' langchain/utils | grep -vE 'from langchain.(pydantic_v1|utils)' && exit 1 || exit 0 exit 0
fi

@ -10,7 +10,7 @@ from urllib.parse import urljoin
import pytest import pytest
import responses import responses
from langchain.utilities.loading import DEFAULT_REF, URL_BASE, try_load_from_hub from langchain.utils.loading import DEFAULT_REF, URL_BASE, try_load_from_hub
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)

Loading…
Cancel
Save