diff --git a/libs/langchain/langchain/adapters/openai.py b/libs/langchain/langchain/adapters/openai.py index b846d4f465..3c2429a2b1 100644 --- a/libs/langchain/langchain/adapters/openai.py +++ b/libs/langchain/langchain/adapters/openai.py @@ -15,7 +15,7 @@ from typing import ( from typing_extensions import Literal -from langchain.chat_loaders.base import ChatSession +from langchain.schema.chat import ChatSession from langchain.schema.messages import ( AIMessage, AIMessageChunk, diff --git a/libs/langchain/langchain/chat_loaders/base.py b/libs/langchain/langchain/chat_loaders/base.py index 6e1f37ca9a..63203588d5 100644 --- a/libs/langchain/langchain/chat_loaders/base.py +++ b/libs/langchain/langchain/chat_loaders/base.py @@ -1,15 +1,7 @@ from abc import ABC, abstractmethod -from typing import Iterator, List, Sequence, TypedDict +from typing import Iterator, List -from langchain.schema.messages import BaseMessage - - -class ChatSession(TypedDict): - """Chat Session represents a single - conversation, channel, or other group of messages.""" - - messages: Sequence[BaseMessage] - """The LangChain chat messages loaded from the source.""" +from langchain.schema.chat import ChatSession class BaseChatLoader(ABC): diff --git a/libs/langchain/langchain/chat_loaders/facebook_messenger.py b/libs/langchain/langchain/chat_loaders/facebook_messenger.py index bfdc0155c7..644133f1bf 100644 --- a/libs/langchain/langchain/chat_loaders/facebook_messenger.py +++ b/libs/langchain/langchain/chat_loaders/facebook_messenger.py @@ -3,7 +3,8 @@ import logging from pathlib import Path from typing import Iterator, Union -from langchain.chat_loaders.base import BaseChatLoader, ChatSession +from langchain.chat_loaders.base import BaseChatLoader +from langchain.schema.chat import ChatSession from langchain.schema.messages import HumanMessage logger = logging.getLogger(__file__) diff --git a/libs/langchain/langchain/chat_loaders/gmail.py b/libs/langchain/langchain/chat_loaders/gmail.py index 
94a3c5617e..f4e57d9241 100644 --- a/libs/langchain/langchain/chat_loaders/gmail.py +++ b/libs/langchain/langchain/chat_loaders/gmail.py @@ -2,7 +2,8 @@ import base64 import re from typing import Any, Iterator -from langchain.chat_loaders.base import BaseChatLoader, ChatSession +from langchain.chat_loaders.base import BaseChatLoader +from langchain.schema.chat import ChatSession from langchain.schema.messages import HumanMessage diff --git a/libs/langchain/langchain/chat_loaders/imessage.py b/libs/langchain/langchain/chat_loaders/imessage.py index a656c60b76..78b3c42974 100644 --- a/libs/langchain/langchain/chat_loaders/imessage.py +++ b/libs/langchain/langchain/chat_loaders/imessage.py @@ -3,8 +3,9 @@ from __future__ import annotations from pathlib import Path from typing import TYPE_CHECKING, Iterator, List, Optional, Union -from langchain.chat_loaders.base import BaseChatLoader, ChatSession +from langchain.chat_loaders.base import BaseChatLoader from langchain.schema import HumanMessage +from langchain.schema.chat import ChatSession if TYPE_CHECKING: import sqlite3 diff --git a/libs/langchain/langchain/chat_loaders/slack.py b/libs/langchain/langchain/chat_loaders/slack.py index 29c2dc794c..65791a13f5 100644 --- a/libs/langchain/langchain/chat_loaders/slack.py +++ b/libs/langchain/langchain/chat_loaders/slack.py @@ -5,8 +5,9 @@ import zipfile from pathlib import Path from typing import Dict, Iterator, List, Union -from langchain.chat_loaders.base import BaseChatLoader, ChatSession +from langchain.chat_loaders.base import BaseChatLoader from langchain.schema import AIMessage, HumanMessage +from langchain.schema.chat import ChatSession logger = logging.getLogger(__name__) diff --git a/libs/langchain/langchain/chat_loaders/telegram.py b/libs/langchain/langchain/chat_loaders/telegram.py index 2441d49a07..d417ebafe1 100644 --- a/libs/langchain/langchain/chat_loaders/telegram.py +++ b/libs/langchain/langchain/chat_loaders/telegram.py @@ -6,8 +6,9 @@ import zipfile from 
pathlib import Path from typing import Iterator, List, Union -from langchain.chat_loaders.base import BaseChatLoader, ChatSession +from langchain.chat_loaders.base import BaseChatLoader from langchain.schema import AIMessage, BaseMessage, HumanMessage +from langchain.schema.chat import ChatSession logger = logging.getLogger(__name__) diff --git a/libs/langchain/langchain/chat_loaders/utils.py b/libs/langchain/langchain/chat_loaders/utils.py index 1c75c852d2..9a351becf2 100644 --- a/libs/langchain/langchain/chat_loaders/utils.py +++ b/libs/langchain/langchain/chat_loaders/utils.py @@ -2,7 +2,7 @@ from copy import deepcopy from typing import Iterable, Iterator, List -from langchain.chat_loaders.base import ChatSession +from langchain.schema.chat import ChatSession from langchain.schema.messages import AIMessage, BaseMessage diff --git a/libs/langchain/langchain/chat_loaders/whatsapp.py b/libs/langchain/langchain/chat_loaders/whatsapp.py index ad9c1ee9c3..f8de9c0e41 100644 --- a/libs/langchain/langchain/chat_loaders/whatsapp.py +++ b/libs/langchain/langchain/chat_loaders/whatsapp.py @@ -4,8 +4,9 @@ import re import zipfile from typing import Iterator, List, Union -from langchain.chat_loaders.base import BaseChatLoader, ChatSession +from langchain.chat_loaders.base import BaseChatLoader from langchain.schema import AIMessage, HumanMessage +from langchain.schema.chat import ChatSession logger = logging.getLogger(__name__) diff --git a/libs/langchain/langchain/embeddings/edenai.py b/libs/langchain/langchain/embeddings/edenai.py index 0b3c9749f0..8a0f717dc9 100644 --- a/libs/langchain/langchain/embeddings/edenai.py +++ b/libs/langchain/langchain/embeddings/edenai.py @@ -1,8 +1,8 @@ from typing import Any, Dict, List, Optional from langchain.pydantic_v1 import BaseModel, Extra, Field, root_validator -from langchain.requests import Requests from langchain.schema.embeddings import Embeddings +from langchain.utilities.requests import Requests from langchain.utils import 
get_from_dict_or_env diff --git a/libs/langchain/langchain/llms/edenai.py b/libs/langchain/langchain/llms/edenai.py index 521653b427..8d333c46c1 100644 --- a/libs/langchain/langchain/llms/edenai.py +++ b/libs/langchain/langchain/llms/edenai.py @@ -11,7 +11,7 @@ from langchain.callbacks.manager import ( from langchain.llms.base import LLM from langchain.llms.utils import enforce_stop_tokens from langchain.pydantic_v1 import Extra, Field, root_validator -from langchain.requests import Requests +from langchain.utilities.requests import Requests from langchain.utils import get_from_dict_or_env logger = logging.getLogger(__name__) diff --git a/libs/langchain/langchain/prompts/base.py b/libs/langchain/langchain/prompts/base.py index 95d83256bf..74dea35eea 100644 --- a/libs/langchain/langchain/prompts/base.py +++ b/libs/langchain/langchain/prompts/base.py @@ -5,10 +5,10 @@ import warnings from abc import ABC from typing import Any, Callable, Dict, List, Set -from langchain.formatting import formatter from langchain.schema.messages import BaseMessage, HumanMessage from langchain.schema.prompt import PromptValue from langchain.schema.prompt_template import BasePromptTemplate +from langchain.utils.formatting import formatter def jinja2_formatter(template: str, **kwargs: Any) -> str: diff --git a/libs/langchain/langchain/prompts/loading.py b/libs/langchain/langchain/prompts/loading.py index 84f35fa8f5..47512612f5 100644 --- a/libs/langchain/langchain/prompts/loading.py +++ b/libs/langchain/langchain/prompts/loading.py @@ -6,11 +6,10 @@ from typing import Callable, Dict, Union import yaml -from langchain.output_parsers.regex import RegexParser from langchain.prompts.few_shot import FewShotPromptTemplate from langchain.prompts.prompt import PromptTemplate from langchain.schema import BaseLLMOutputParser, BasePromptTemplate, StrOutputParser -from langchain.utilities.loading import try_load_from_hub +from langchain.utils.loading import try_load_from_hub URL_BASE = 
"https://raw.githubusercontent.com/hwchase17/langchain-hub/master/prompts/" logger = logging.getLogger(__name__) @@ -77,6 +76,8 @@ def _load_output_parser(config: dict) -> dict: _config = config.pop("output_parser") output_parser_type = _config.pop("_type") if output_parser_type == "regex_parser": + from langchain.output_parsers.regex import RegexParser + output_parser: BaseLLMOutputParser = RegexParser(**_config) elif output_parser_type == "default": output_parser = StrOutputParser(**_config) diff --git a/libs/langchain/langchain/schema/chat.py b/libs/langchain/langchain/schema/chat.py new file mode 100644 index 0000000000..659e734853 --- /dev/null +++ b/libs/langchain/langchain/schema/chat.py @@ -0,0 +1,11 @@ +from typing import Sequence, TypedDict + +from langchain.schema import BaseMessage + + +class ChatSession(TypedDict): + """Chat Session represents a single + conversation, channel, or other group of messages.""" + + messages: Sequence[BaseMessage] + """The LangChain chat messages loaded from the source.""" diff --git a/libs/langchain/langchain/schema/runnable/base.py b/libs/langchain/langchain/schema/runnable/base.py index e76e16892a..45872a18c6 100644 --- a/libs/langchain/langchain/schema/runnable/base.py +++ b/libs/langchain/langchain/schema/runnable/base.py @@ -32,10 +32,9 @@ if TYPE_CHECKING: AsyncCallbackManagerForChainRun, CallbackManagerForChainRun, ) + from langchain.callbacks.tracers.log_stream import RunLogPatch -from langchain.callbacks.base import BaseCallbackManager -from langchain.callbacks.tracers.log_stream import LogStreamCallbackHandler, RunLogPatch from langchain.load.dump import dumpd from langchain.load.serializable import Serializable from langchain.pydantic_v1 import Field @@ -216,6 +215,12 @@ class Runnable(Generic[Input, Output], ABC): The jsonpatch ops can be applied in order to construct state. 
""" + from langchain.callbacks.base import BaseCallbackManager + from langchain.callbacks.tracers.log_stream import ( + LogStreamCallbackHandler, + RunLogPatch, + ) + # Create a stream handler that will emit Log objects stream = LogStreamCallbackHandler( auto_close=False, diff --git a/libs/langchain/langchain/schema/runnable/retry.py b/libs/langchain/langchain/schema/runnable/retry.py index b41f74583b..8a04a4ac18 100644 --- a/libs/langchain/langchain/schema/runnable/retry.py +++ b/libs/langchain/langchain/schema/runnable/retry.py @@ -1,4 +1,15 @@ -from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union, cast +from typing import ( + TYPE_CHECKING, + Any, + Dict, + List, + Optional, + Tuple, + Type, + TypeVar, + Union, + cast, +) from tenacity import ( AsyncRetrying, @@ -10,14 +21,16 @@ from tenacity import ( wait_exponential_jitter, ) -from langchain.callbacks.manager import ( - AsyncCallbackManagerForChainRun, - CallbackManagerForChainRun, -) from langchain.schema.runnable.base import Input, Output, RunnableBinding from langchain.schema.runnable.config import RunnableConfig, patch_config -T = TypeVar("T", CallbackManagerForChainRun, AsyncCallbackManagerForChainRun) +if TYPE_CHECKING: + from langchain.callbacks.manager import ( + AsyncCallbackManagerForChainRun, + CallbackManagerForChainRun, + ) + + T = TypeVar("T", CallbackManagerForChainRun, AsyncCallbackManagerForChainRun) U = TypeVar("U") @@ -54,7 +67,7 @@ class RunnableRetry(RunnableBinding[Input, Output]): def _patch_config( self, config: RunnableConfig, - run_manager: T, + run_manager: "T", retry_state: RetryCallState, ) -> RunnableConfig: attempt = retry_state.attempt_number @@ -64,7 +77,7 @@ class RunnableRetry(RunnableBinding[Input, Output]): def _patch_config_list( self, config: List[RunnableConfig], - run_manager: List[T], + run_manager: List["T"], retry_state: RetryCallState, ) -> List[RunnableConfig]: return [ @@ -74,7 +87,7 @@ class RunnableRetry(RunnableBinding[Input, Output]): def 
_invoke( self, input: Input, - run_manager: CallbackManagerForChainRun, + run_manager: "CallbackManagerForChainRun", config: RunnableConfig, ) -> Output: for attempt in self._sync_retrying(reraise=True): @@ -95,7 +108,7 @@ class RunnableRetry(RunnableBinding[Input, Output]): async def _ainvoke( self, input: Input, - run_manager: AsyncCallbackManagerForChainRun, + run_manager: "AsyncCallbackManagerForChainRun", config: RunnableConfig, ) -> Output: async for attempt in self._async_retrying(reraise=True): @@ -116,7 +129,7 @@ class RunnableRetry(RunnableBinding[Input, Output]): def _batch( self, inputs: List[Input], - run_manager: List[CallbackManagerForChainRun], + run_manager: List["CallbackManagerForChainRun"], config: List[RunnableConfig], ) -> List[Union[Output, Exception]]: results_map: Dict[int, Output] = {} @@ -180,7 +193,7 @@ class RunnableRetry(RunnableBinding[Input, Output]): async def _abatch( self, inputs: List[Input], - run_manager: List[AsyncCallbackManagerForChainRun], + run_manager: List["AsyncCallbackManagerForChainRun"], config: List[RunnableConfig], ) -> List[Union[Output, Exception]]: results_map: Dict[int, Output] = {} diff --git a/libs/langchain/langchain/utilities/apify.py b/libs/langchain/langchain/utilities/apify.py index dd7ddcd01d..e10e96a1e5 100644 --- a/libs/langchain/langchain/utilities/apify.py +++ b/libs/langchain/langchain/utilities/apify.py @@ -1,10 +1,12 @@ -from typing import Any, Callable, Dict, Optional +from typing import TYPE_CHECKING, Any, Callable, Dict, Optional -from langchain.document_loaders import ApifyDatasetLoader -from langchain.document_loaders.base import Document from langchain.pydantic_v1 import BaseModel, root_validator +from langchain.schema.document import Document from langchain.utils import get_from_dict_or_env +if TYPE_CHECKING: + from langchain.document_loaders import ApifyDatasetLoader + class ApifyWrapper(BaseModel): """Wrapper around Apify. 
@@ -48,7 +50,7 @@ class ApifyWrapper(BaseModel): build: Optional[str] = None, memory_mbytes: Optional[int] = None, timeout_secs: Optional[int] = None, - ) -> ApifyDatasetLoader: + ) -> "ApifyDatasetLoader": """Run an Actor on the Apify platform and wait for results to be ready. Args: actor_id (str): The ID or name of the Actor on the Apify platform. @@ -65,6 +67,8 @@ class ApifyWrapper(BaseModel): ApifyDatasetLoader: A loader that will fetch the records from the Actor run's default dataset. """ + from langchain.document_loaders import ApifyDatasetLoader + actor_call = self.apify_client.actor(actor_id).call( run_input=run_input, build=build, @@ -86,7 +90,7 @@ class ApifyWrapper(BaseModel): build: Optional[str] = None, memory_mbytes: Optional[int] = None, timeout_secs: Optional[int] = None, - ) -> ApifyDatasetLoader: + ) -> "ApifyDatasetLoader": """Run an Actor on the Apify platform and wait for results to be ready. Args: actor_id (str): The ID or name of the Actor on the Apify platform. @@ -103,6 +107,8 @@ class ApifyWrapper(BaseModel): ApifyDatasetLoader: A loader that will fetch the records from the Actor run's default dataset. """ + from langchain.document_loaders import ApifyDatasetLoader + actor_call = await self.apify_client_async.actor(actor_id).call( run_input=run_input, build=build, @@ -124,7 +130,7 @@ class ApifyWrapper(BaseModel): build: Optional[str] = None, memory_mbytes: Optional[int] = None, timeout_secs: Optional[int] = None, - ) -> ApifyDatasetLoader: + ) -> "ApifyDatasetLoader": """Run a saved Actor task on Apify and wait for results to be ready. Args: task_id (str): The ID or name of the task on the Apify platform. @@ -142,6 +148,8 @@ class ApifyWrapper(BaseModel): ApifyDatasetLoader: A loader that will fetch the records from the task run's default dataset. 
""" + from langchain.document_loaders import ApifyDatasetLoader + task_call = self.apify_client.task(task_id).call( task_input=task_input, build=build, @@ -163,7 +171,7 @@ class ApifyWrapper(BaseModel): build: Optional[str] = None, memory_mbytes: Optional[int] = None, timeout_secs: Optional[int] = None, - ) -> ApifyDatasetLoader: + ) -> "ApifyDatasetLoader": """Run a saved Actor task on Apify and wait for results to be ready. Args: task_id (str): The ID or name of the task on the Apify platform. @@ -181,6 +189,8 @@ class ApifyWrapper(BaseModel): ApifyDatasetLoader: A loader that will fetch the records from the task run's default dataset. """ + from langchain.document_loaders import ApifyDatasetLoader + task_call = await self.apify_client_async.task(task_id).call( task_input=task_input, build=build, diff --git a/libs/langchain/langchain/utilities/loading.py b/libs/langchain/langchain/utilities/loading.py index 60f3e3cf7d..ea46982495 100644 --- a/libs/langchain/langchain/utilities/loading.py +++ b/libs/langchain/langchain/utilities/loading.py @@ -1,54 +1,4 @@ -"""Utilities for loading configurations from langchain-hub.""" +from langchain.utils.loading import try_load_from_hub -import os -import re -import tempfile -from pathlib import Path, PurePosixPath -from typing import Any, Callable, Optional, Set, TypeVar, Union -from urllib.parse import urljoin - -import requests - -DEFAULT_REF = os.environ.get("LANGCHAIN_HUB_DEFAULT_REF", "master") -URL_BASE = os.environ.get( - "LANGCHAIN_HUB_URL_BASE", - "https://raw.githubusercontent.com/hwchase17/langchain-hub/{ref}/", -) -HUB_PATH_RE = re.compile(r"lc(?P<ref>@[^:]+)?://(?P<path>.*)") - -T = TypeVar("T") - - -def try_load_from_hub( - path: Union[str, Path], - loader: Callable[[str], T], - valid_prefix: str, - valid_suffixes: Set[str], - **kwargs: Any, -) -> Optional[T]: - """Load configuration from hub. 
Returns None if path is not a hub path.""" - if not isinstance(path, str) or not (match := HUB_PATH_RE.match(path)): - return None - ref, remote_path_str = match.groups() - ref = ref[1:] if ref else DEFAULT_REF - remote_path = Path(remote_path_str) - if remote_path.parts[0] != valid_prefix: - return None - if remote_path.suffix[1:] not in valid_suffixes: - raise ValueError(f"Unsupported file type, must be one of {valid_suffixes}.") - - # Using Path with URLs is not recommended, because on Windows - # the backslash is used as the path separator, which can cause issues - # when working with URLs that use forward slashes as the path separator. - # Instead, use PurePosixPath to ensure that forward slashes are used as the - # path separator, regardless of the operating system. - full_url = urljoin(URL_BASE.format(ref=ref), PurePosixPath(remote_path).__str__()) - - r = requests.get(full_url, timeout=5) - if r.status_code != 200: - raise ValueError(f"Could not find file at {full_url}") - with tempfile.TemporaryDirectory() as tmpdirname: - file = Path(tmpdirname) / remote_path.name - with open(file, "wb") as f: - f.write(r.content) - return loader(str(file), **kwargs) +# For backwards compatibility +__all__ = ["try_load_from_hub"] diff --git a/libs/langchain/langchain/utils/loading.py b/libs/langchain/langchain/utils/loading.py new file mode 100644 index 0000000000..60f3e3cf7d --- /dev/null +++ b/libs/langchain/langchain/utils/loading.py @@ -0,0 +1,54 @@ +"""Utilities for loading configurations from langchain-hub.""" + +import os +import re +import tempfile +from pathlib import Path, PurePosixPath +from typing import Any, Callable, Optional, Set, TypeVar, Union +from urllib.parse import urljoin + +import requests + +DEFAULT_REF = os.environ.get("LANGCHAIN_HUB_DEFAULT_REF", "master") +URL_BASE = os.environ.get( + "LANGCHAIN_HUB_URL_BASE", + "https://raw.githubusercontent.com/hwchase17/langchain-hub/{ref}/", +) +HUB_PATH_RE = re.compile(r"lc(?P<ref>@[^:]+)?://(?P<path>.*)") + +T = 
TypeVar("T") + + +def try_load_from_hub( + path: Union[str, Path], + loader: Callable[[str], T], + valid_prefix: str, + valid_suffixes: Set[str], + **kwargs: Any, +) -> Optional[T]: + """Load configuration from hub. Returns None if path is not a hub path.""" + if not isinstance(path, str) or not (match := HUB_PATH_RE.match(path)): + return None + ref, remote_path_str = match.groups() + ref = ref[1:] if ref else DEFAULT_REF + remote_path = Path(remote_path_str) + if remote_path.parts[0] != valid_prefix: + return None + if remote_path.suffix[1:] not in valid_suffixes: + raise ValueError(f"Unsupported file type, must be one of {valid_suffixes}.") + + # Using Path with URLs is not recommended, because on Windows + # the backslash is used as the path separator, which can cause issues + # when working with URLs that use forward slashes as the path separator. + # Instead, use PurePosixPath to ensure that forward slashes are used as the + # path separator, regardless of the operating system. + full_url = urljoin(URL_BASE.format(ref=ref), PurePosixPath(remote_path).__str__()) + + r = requests.get(full_url, timeout=5) + if r.status_code != 200: + raise ValueError(f"Could not find file at {full_url}") + with tempfile.TemporaryDirectory() as tmpdirname: + file = Path(tmpdirname) / remote_path.name + with open(file, "wb") as f: + f.write(r.content) + return loader(str(file), **kwargs) diff --git a/libs/langchain/langchain/vectorstores/matching_engine.py b/libs/langchain/langchain/vectorstores/matching_engine.py index 1e197fe819..2c725969fe 100644 --- a/libs/langchain/langchain/vectorstores/matching_engine.py +++ b/libs/langchain/langchain/vectorstores/matching_engine.py @@ -6,8 +6,7 @@ import time import uuid from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Type -from langchain.docstore.document import Document -from langchain.embeddings import TensorflowHubEmbeddings +from langchain.schema.document import Document from langchain.schema.embeddings import 
Embeddings from langchain.schema.vectorstore import VectorStore @@ -16,6 +15,8 @@ if TYPE_CHECKING: from google.cloud.aiplatform import MatchingEngineIndex, MatchingEngineIndexEndpoint from google.oauth2.service_account import Credentials + from langchain.embeddings import TensorflowHubEmbeddings + logger = logging.getLogger() @@ -443,10 +444,13 @@ class MatchingEngine(VectorStore): ) @classmethod - def _get_default_embeddings(cls) -> TensorflowHubEmbeddings: + def _get_default_embeddings(cls) -> "TensorflowHubEmbeddings": """This function returns the default embedding. Returns: Default TensorflowHubEmbeddings to use. """ + + from langchain.embeddings import TensorflowHubEmbeddings + return TensorflowHubEmbeddings() diff --git a/libs/langchain/langchain/vectorstores/timescalevector.py b/libs/langchain/langchain/vectorstores/timescalevector.py index 4f455e2b9e..50f76ee6b8 100644 --- a/libs/langchain/langchain/vectorstores/timescalevector.py +++ b/libs/langchain/langchain/vectorstores/timescalevector.py @@ -18,8 +18,8 @@ from typing import ( Union, ) -from langchain.docstore.document import Document -from langchain.embeddings.base import Embeddings +from langchain.schema.document import Document +from langchain.schema.embeddings import Embeddings from langchain.schema.vectorstore import VectorStore from langchain.utils import get_from_dict_or_env from langchain.vectorstores.utils import DistanceStrategy diff --git a/libs/langchain/scripts/check_imports.sh b/libs/langchain/scripts/check_imports.sh index 2ac2f48bb5..27e6e988da 100755 --- a/libs/langchain/scripts/check_imports.sh +++ b/libs/langchain/scripts/check_imports.sh @@ -2,13 +2,31 @@ set -eu -git grep 'from langchain import' langchain | grep -vE 'from langchain import (__version__|hub)' && exit 1 || exit 0 +# Initialize a variable to keep track of errors +errors=0 -# Pydantic bridge should not import from any other module -git grep 'from langchain ' langchain/pydantic_v1 && exit 1 || exit 0 +# Check the 
conditions +git grep '^from langchain import' langchain | grep -vE 'from langchain import (__version__|hub)' && errors=$((errors+1)) +git grep '^from langchain ' langchain/pydantic_v1 | grep -vE 'from langchain.(pydantic_v1)' && errors=$((errors+1)) +git grep '^from langchain' langchain/load | grep -vE 'from langchain.(pydantic_v1|load)' && errors=$((errors+1)) +git grep '^from langchain' langchain/utils | grep -vE 'from langchain.(pydantic_v1|utils)' && errors=$((errors+1)) +git grep '^from langchain' langchain/schema | grep -vE 'from langchain.(pydantic_v1|utils|schema|load)' && errors=$((errors+1)) +git grep '^from langchain' langchain/adapters | grep -vE 'from langchain.(pydantic_v1|utils|schema|load)' && errors=$((errors+1)) +git grep '^from langchain' langchain/callbacks | grep -vE 'from langchain.(pydantic_v1|utils|schema|load|callbacks|env)' && errors=$((errors+1)) +# TODO: it is not ideal that so many other modules depend on `langchain.utilities`, because that package can pull in a lot of imports +git grep '^from langchain' langchain/utilities | grep -vE 'from langchain.(pydantic_v1|utils|schema|load|callbacks|env|utilities)' && errors=$((errors+1)) +git grep '^from langchain' langchain/storage | grep -vE 'from langchain.(pydantic_v1|utils|schema|load|callbacks|env|storage|utilities)' && errors=$((errors+1)) +git grep '^from langchain' langchain/prompts | grep -vE 'from langchain.(pydantic_v1|utils|schema|load|callbacks|env|prompts|_api)' && errors=$((errors+1)) +git grep '^from langchain' langchain/output_parsers | grep -vE 'from langchain.(pydantic_v1|utils|schema|load|callbacks|env|prompts|_api|output_parsers)' && errors=$((errors+1)) +git grep '^from langchain' langchain/llms | grep -vE 'from langchain.(pydantic_v1|utils|schema|load|callbacks|env|prompts|llms|utilities)' && errors=$((errors+1)) +git grep '^from langchain' langchain/chat_models | grep -vE 'from 
langchain.(pydantic_v1|utils|schema|load|callbacks|env|llms|prompts|adapters|chat_models|utilities)' && errors=$((errors+1)) +git grep '^from langchain' langchain/embeddings | grep -vE 'from langchain.(pydantic_v1|utils|schema|load|callbacks|env|storage|llms|embeddings|utilities)' && errors=$((errors+1)) +git grep '^from langchain' langchain/docstore | grep -vE 'from langchain.(pydantic_v1|utils|schema|docstore)' && errors=$((errors+1)) +git grep '^from langchain' langchain/vectorstores | grep -vE 'from langchain.(pydantic_v1|utils|schema|load|callbacks|env|_api|storage|llms|docstore|vectorstores|utilities)' && errors=$((errors+1)) -# load should not import from anything except itself and pydantic_v1 -git grep 'from langchain' langchain/load | grep -vE 'from langchain.(pydantic_v1)' && exit 1 || exit 0 - -# utils should not import from anything except itself and pydantic_v1 -git grep 'from langchain' langchain/utils | grep -vE 'from langchain.(pydantic_v1|utils)' && exit 1 || exit 0 +# Decide on an exit status based on the errors +if [ "$errors" -gt 0 ]; then + exit 1 +else + exit 0 +fi diff --git a/libs/langchain/tests/unit_tests/utilities/test_loading.py b/libs/langchain/tests/unit_tests/utilities/test_loading.py index f9a275fd81..c74df087d1 100644 --- a/libs/langchain/tests/unit_tests/utilities/test_loading.py +++ b/libs/langchain/tests/unit_tests/utilities/test_loading.py @@ -10,7 +10,7 @@ from urllib.parse import urljoin import pytest import responses -from langchain.utilities.loading import DEFAULT_REF, URL_BASE, try_load_from_hub +from langchain.utils.loading import DEFAULT_REF, URL_BASE, try_load_from_hub @pytest.fixture(autouse=True)