add more import checks (#11033)

pull/11140/head
Harrison Chase 11 months ago committed by GitHub
parent efb7c459a2
commit e355606b11
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -15,7 +15,7 @@ from typing import (
from typing_extensions import Literal
from langchain.chat_loaders.base import ChatSession
from langchain.schema.chat import ChatSession
from langchain.schema.messages import (
AIMessage,
AIMessageChunk,

@ -1,15 +1,7 @@
from abc import ABC, abstractmethod
from typing import Iterator, List, Sequence, TypedDict
from typing import Iterator, List
from langchain.schema.messages import BaseMessage
class ChatSession(TypedDict):
    """A group of chat messages that belong together.

    One session stands for a single conversation, channel, or other
    grouping of messages.
    """

    # The LangChain chat messages loaded from the source.
    messages: Sequence[BaseMessage]
from langchain.schema.chat import ChatSession
class BaseChatLoader(ABC):

@ -3,7 +3,8 @@ import logging
from pathlib import Path
from typing import Iterator, Union
from langchain.chat_loaders.base import BaseChatLoader, ChatSession
from langchain.chat_loaders.base import BaseChatLoader
from langchain.schema.chat import ChatSession
from langchain.schema.messages import HumanMessage
logger = logging.getLogger(__file__)

@ -2,7 +2,8 @@ import base64
import re
from typing import Any, Iterator
from langchain.chat_loaders.base import BaseChatLoader, ChatSession
from langchain.chat_loaders.base import BaseChatLoader
from langchain.schema.chat import ChatSession
from langchain.schema.messages import HumanMessage

@ -3,8 +3,9 @@ from __future__ import annotations
from pathlib import Path
from typing import TYPE_CHECKING, Iterator, List, Optional, Union
from langchain.chat_loaders.base import BaseChatLoader, ChatSession
from langchain.chat_loaders.base import BaseChatLoader
from langchain.schema import HumanMessage
from langchain.schema.chat import ChatSession
if TYPE_CHECKING:
import sqlite3

@ -5,8 +5,9 @@ import zipfile
from pathlib import Path
from typing import Dict, Iterator, List, Union
from langchain.chat_loaders.base import BaseChatLoader, ChatSession
from langchain.chat_loaders.base import BaseChatLoader
from langchain.schema import AIMessage, HumanMessage
from langchain.schema.chat import ChatSession
logger = logging.getLogger(__name__)

@ -6,8 +6,9 @@ import zipfile
from pathlib import Path
from typing import Iterator, List, Union
from langchain.chat_loaders.base import BaseChatLoader, ChatSession
from langchain.chat_loaders.base import BaseChatLoader
from langchain.schema import AIMessage, BaseMessage, HumanMessage
from langchain.schema.chat import ChatSession
logger = logging.getLogger(__name__)

@ -2,7 +2,7 @@
from copy import deepcopy
from typing import Iterable, Iterator, List
from langchain.chat_loaders.base import ChatSession
from langchain.schema.chat import ChatSession
from langchain.schema.messages import AIMessage, BaseMessage

@ -4,8 +4,9 @@ import re
import zipfile
from typing import Iterator, List, Union
from langchain.chat_loaders.base import BaseChatLoader, ChatSession
from langchain.chat_loaders.base import BaseChatLoader
from langchain.schema import AIMessage, HumanMessage
from langchain.schema.chat import ChatSession
logger = logging.getLogger(__name__)

@ -1,8 +1,8 @@
from typing import Any, Dict, List, Optional
from langchain.pydantic_v1 import BaseModel, Extra, Field, root_validator
from langchain.requests import Requests
from langchain.schema.embeddings import Embeddings
from langchain.utilities.requests import Requests
from langchain.utils import get_from_dict_or_env

@ -11,7 +11,7 @@ from langchain.callbacks.manager import (
from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens
from langchain.pydantic_v1 import Extra, Field, root_validator
from langchain.requests import Requests
from langchain.utilities.requests import Requests
from langchain.utils import get_from_dict_or_env
logger = logging.getLogger(__name__)

@ -5,10 +5,10 @@ import warnings
from abc import ABC
from typing import Any, Callable, Dict, List, Set
from langchain.formatting import formatter
from langchain.schema.messages import BaseMessage, HumanMessage
from langchain.schema.prompt import PromptValue
from langchain.schema.prompt_template import BasePromptTemplate
from langchain.utils.formatting import formatter
def jinja2_formatter(template: str, **kwargs: Any) -> str:

@ -6,11 +6,10 @@ from typing import Callable, Dict, Union
import yaml
from langchain.output_parsers.regex import RegexParser
from langchain.prompts.few_shot import FewShotPromptTemplate
from langchain.prompts.prompt import PromptTemplate
from langchain.schema import BaseLLMOutputParser, BasePromptTemplate, StrOutputParser
from langchain.utilities.loading import try_load_from_hub
from langchain.utils.loading import try_load_from_hub
URL_BASE = "https://raw.githubusercontent.com/hwchase17/langchain-hub/master/prompts/"
logger = logging.getLogger(__name__)
@ -77,6 +76,8 @@ def _load_output_parser(config: dict) -> dict:
_config = config.pop("output_parser")
output_parser_type = _config.pop("_type")
if output_parser_type == "regex_parser":
from langchain.output_parsers.regex import RegexParser
output_parser: BaseLLMOutputParser = RegexParser(**_config)
elif output_parser_type == "default":
output_parser = StrOutputParser(**_config)

@ -0,0 +1,11 @@
from typing import Sequence, TypedDict
from langchain.schema import BaseMessage
class ChatSession(TypedDict):
    """A group of chat messages that belong together.

    One session stands for a single conversation, channel, or other
    grouping of messages.
    """

    # The LangChain chat messages loaded from the source.
    messages: Sequence[BaseMessage]

@ -32,10 +32,9 @@ if TYPE_CHECKING:
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
)
from langchain.callbacks.tracers.log_stream import RunLogPatch
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.tracers.log_stream import LogStreamCallbackHandler, RunLogPatch
from langchain.load.dump import dumpd
from langchain.load.serializable import Serializable
from langchain.pydantic_v1 import Field
@ -216,6 +215,12 @@ class Runnable(Generic[Input, Output], ABC):
The jsonpatch ops can be applied in order to construct state.
"""
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.tracers.log_stream import (
LogStreamCallbackHandler,
RunLogPatch,
)
# Create a stream handler that will emit Log objects
stream = LogStreamCallbackHandler(
auto_close=False,

@ -1,4 +1,15 @@
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union, cast
from typing import (
TYPE_CHECKING,
Any,
Dict,
List,
Optional,
Tuple,
Type,
TypeVar,
Union,
cast,
)
from tenacity import (
AsyncRetrying,
@ -10,14 +21,16 @@ from tenacity import (
wait_exponential_jitter,
)
from langchain.callbacks.manager import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
)
from langchain.schema.runnable.base import Input, Output, RunnableBinding
from langchain.schema.runnable.config import RunnableConfig, patch_config
T = TypeVar("T", CallbackManagerForChainRun, AsyncCallbackManagerForChainRun)
if TYPE_CHECKING:
from langchain.callbacks.manager import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
)
T = TypeVar("T", CallbackManagerForChainRun, AsyncCallbackManagerForChainRun)
U = TypeVar("U")
@ -54,7 +67,7 @@ class RunnableRetry(RunnableBinding[Input, Output]):
def _patch_config(
self,
config: RunnableConfig,
run_manager: T,
run_manager: "T",
retry_state: RetryCallState,
) -> RunnableConfig:
attempt = retry_state.attempt_number
@ -64,7 +77,7 @@ class RunnableRetry(RunnableBinding[Input, Output]):
def _patch_config_list(
self,
config: List[RunnableConfig],
run_manager: List[T],
run_manager: List["T"],
retry_state: RetryCallState,
) -> List[RunnableConfig]:
return [
@ -74,7 +87,7 @@ class RunnableRetry(RunnableBinding[Input, Output]):
def _invoke(
self,
input: Input,
run_manager: CallbackManagerForChainRun,
run_manager: "CallbackManagerForChainRun",
config: RunnableConfig,
) -> Output:
for attempt in self._sync_retrying(reraise=True):
@ -95,7 +108,7 @@ class RunnableRetry(RunnableBinding[Input, Output]):
async def _ainvoke(
self,
input: Input,
run_manager: AsyncCallbackManagerForChainRun,
run_manager: "AsyncCallbackManagerForChainRun",
config: RunnableConfig,
) -> Output:
async for attempt in self._async_retrying(reraise=True):
@ -116,7 +129,7 @@ class RunnableRetry(RunnableBinding[Input, Output]):
def _batch(
self,
inputs: List[Input],
run_manager: List[CallbackManagerForChainRun],
run_manager: List["CallbackManagerForChainRun"],
config: List[RunnableConfig],
) -> List[Union[Output, Exception]]:
results_map: Dict[int, Output] = {}
@ -180,7 +193,7 @@ class RunnableRetry(RunnableBinding[Input, Output]):
async def _abatch(
self,
inputs: List[Input],
run_manager: List[AsyncCallbackManagerForChainRun],
run_manager: List["AsyncCallbackManagerForChainRun"],
config: List[RunnableConfig],
) -> List[Union[Output, Exception]]:
results_map: Dict[int, Output] = {}

@ -1,10 +1,12 @@
from typing import Any, Callable, Dict, Optional
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional
from langchain.document_loaders import ApifyDatasetLoader
from langchain.document_loaders.base import Document
from langchain.pydantic_v1 import BaseModel, root_validator
from langchain.schema.document import Document
from langchain.utils import get_from_dict_or_env
if TYPE_CHECKING:
from langchain.document_loaders import ApifyDatasetLoader
class ApifyWrapper(BaseModel):
"""Wrapper around Apify.
@ -48,7 +50,7 @@ class ApifyWrapper(BaseModel):
build: Optional[str] = None,
memory_mbytes: Optional[int] = None,
timeout_secs: Optional[int] = None,
) -> ApifyDatasetLoader:
) -> "ApifyDatasetLoader":
"""Run an Actor on the Apify platform and wait for results to be ready.
Args:
actor_id (str): The ID or name of the Actor on the Apify platform.
@ -65,6 +67,8 @@ class ApifyWrapper(BaseModel):
ApifyDatasetLoader: A loader that will fetch the records from the
Actor run's default dataset.
"""
from langchain.document_loaders import ApifyDatasetLoader
actor_call = self.apify_client.actor(actor_id).call(
run_input=run_input,
build=build,
@ -86,7 +90,7 @@ class ApifyWrapper(BaseModel):
build: Optional[str] = None,
memory_mbytes: Optional[int] = None,
timeout_secs: Optional[int] = None,
) -> ApifyDatasetLoader:
) -> "ApifyDatasetLoader":
"""Run an Actor on the Apify platform and wait for results to be ready.
Args:
actor_id (str): The ID or name of the Actor on the Apify platform.
@ -103,6 +107,8 @@ class ApifyWrapper(BaseModel):
ApifyDatasetLoader: A loader that will fetch the records from the
Actor run's default dataset.
"""
from langchain.document_loaders import ApifyDatasetLoader
actor_call = await self.apify_client_async.actor(actor_id).call(
run_input=run_input,
build=build,
@ -124,7 +130,7 @@ class ApifyWrapper(BaseModel):
build: Optional[str] = None,
memory_mbytes: Optional[int] = None,
timeout_secs: Optional[int] = None,
) -> ApifyDatasetLoader:
) -> "ApifyDatasetLoader":
"""Run a saved Actor task on Apify and wait for results to be ready.
Args:
task_id (str): The ID or name of the task on the Apify platform.
@ -142,6 +148,8 @@ class ApifyWrapper(BaseModel):
ApifyDatasetLoader: A loader that will fetch the records from the
task run's default dataset.
"""
from langchain.document_loaders import ApifyDatasetLoader
task_call = self.apify_client.task(task_id).call(
task_input=task_input,
build=build,
@ -163,7 +171,7 @@ class ApifyWrapper(BaseModel):
build: Optional[str] = None,
memory_mbytes: Optional[int] = None,
timeout_secs: Optional[int] = None,
) -> ApifyDatasetLoader:
) -> "ApifyDatasetLoader":
"""Run a saved Actor task on Apify and wait for results to be ready.
Args:
task_id (str): The ID or name of the task on the Apify platform.
@ -181,6 +189,8 @@ class ApifyWrapper(BaseModel):
ApifyDatasetLoader: A loader that will fetch the records from the
task run's default dataset.
"""
from langchain.document_loaders import ApifyDatasetLoader
task_call = await self.apify_client_async.task(task_id).call(
task_input=task_input,
build=build,

@ -1,54 +1,4 @@
"""Utilities for loading configurations from langchain-hub."""
from langchain.utils.loading import try_load_from_hub
import os
import re
import tempfile
from pathlib import Path, PurePosixPath
from typing import Any, Callable, Optional, Set, TypeVar, Union
from urllib.parse import urljoin
import requests
DEFAULT_REF = os.environ.get("LANGCHAIN_HUB_DEFAULT_REF", "master")
URL_BASE = os.environ.get(
"LANGCHAIN_HUB_URL_BASE",
"https://raw.githubusercontent.com/hwchase17/langchain-hub/{ref}/",
)
HUB_PATH_RE = re.compile(r"lc(?P<ref>@[^:]+)?://(?P<path>.*)")
T = TypeVar("T")
def try_load_from_hub(
    path: Union[str, Path],
    loader: Callable[[str], T],
    valid_prefix: str,
    valid_suffixes: Set[str],
    **kwargs: Any,
) -> Optional[T]:
    """Load a configuration object from the LangChain hub.

    Hub paths look like ``lc://<prefix>/<rest>`` or
    ``lc@<ref>://<prefix>/<rest>``.

    Args:
        path: Candidate hub path. Anything that is not a ``str`` matching the
            hub pattern is treated as "not a hub path".
        loader: Callable invoked with the path of the downloaded temporary
            file (plus ``**kwargs``); its return value is returned.
        valid_prefix: Required first component of the remote path.
        valid_suffixes: Accepted file extensions, without the leading dot.
        **kwargs: Extra keyword arguments forwarded to ``loader``.

    Returns:
        Whatever ``loader`` returns, or ``None`` when ``path`` is not a hub
        path or its first component is not ``valid_prefix``.

    Raises:
        ValueError: If the file extension is not in ``valid_suffixes`` or the
            hub answers with a status code other than 200.
    """
    if not isinstance(path, str):
        return None
    hub_match = HUB_PATH_RE.match(path)
    if hub_match is None:
        return None

    ref, remote_path_str = hub_match.groups()
    # The captured ref still carries its leading "@"; without one, use the
    # default ref.
    ref = ref[1:] if ref else DEFAULT_REF

    remote_path = Path(remote_path_str)
    # An empty remote path (e.g. "lc://") has no components at all; treat it
    # like any other path with the wrong prefix instead of raising IndexError.
    if not remote_path.parts or remote_path.parts[0] != valid_prefix:
        return None
    if remote_path.suffix[1:] not in valid_suffixes:
        raise ValueError(f"Unsupported file type, must be one of {valid_suffixes}.")

    # URLs always use forward slashes, so render the path in POSIX form even
    # on Windows, where Path uses the backslash as separator.
    full_url = urljoin(URL_BASE.format(ref=ref), remote_path.as_posix())

    r = requests.get(full_url, timeout=5)
    if r.status_code != 200:
        raise ValueError(f"Could not find file at {full_url}")

    with tempfile.TemporaryDirectory() as tmpdirname:
        file = Path(tmpdirname) / remote_path.name
        file.write_bytes(r.content)
        # Call the loader while the temporary directory still exists.
        return loader(str(file), **kwargs)
# For backwards compatibility
__all__ = ["try_load_from_hub"]

@ -0,0 +1,54 @@
"""Utilities for loading configurations from langchain-hub."""
import os
import re
import tempfile
from pathlib import Path, PurePosixPath
from typing import Any, Callable, Optional, Set, TypeVar, Union
from urllib.parse import urljoin
import requests
DEFAULT_REF = os.environ.get("LANGCHAIN_HUB_DEFAULT_REF", "master")
URL_BASE = os.environ.get(
"LANGCHAIN_HUB_URL_BASE",
"https://raw.githubusercontent.com/hwchase17/langchain-hub/{ref}/",
)
HUB_PATH_RE = re.compile(r"lc(?P<ref>@[^:]+)?://(?P<path>.*)")
T = TypeVar("T")
def try_load_from_hub(
    path: Union[str, Path],
    loader: Callable[[str], T],
    valid_prefix: str,
    valid_suffixes: Set[str],
    **kwargs: Any,
) -> Optional[T]:
    """Load a configuration object from the LangChain hub.

    Hub paths look like ``lc://<prefix>/<rest>`` or
    ``lc@<ref>://<prefix>/<rest>``.

    Args:
        path: Candidate hub path. Anything that is not a ``str`` matching the
            hub pattern is treated as "not a hub path".
        loader: Callable invoked with the path of the downloaded temporary
            file (plus ``**kwargs``); its return value is returned.
        valid_prefix: Required first component of the remote path.
        valid_suffixes: Accepted file extensions, without the leading dot.
        **kwargs: Extra keyword arguments forwarded to ``loader``.

    Returns:
        Whatever ``loader`` returns, or ``None`` when ``path`` is not a hub
        path or its first component is not ``valid_prefix``.

    Raises:
        ValueError: If the file extension is not in ``valid_suffixes`` or the
            hub answers with a status code other than 200.
    """
    if not isinstance(path, str):
        return None
    hub_match = HUB_PATH_RE.match(path)
    if hub_match is None:
        return None

    ref, remote_path_str = hub_match.groups()
    # The captured ref still carries its leading "@"; without one, use the
    # default ref.
    ref = ref[1:] if ref else DEFAULT_REF

    remote_path = Path(remote_path_str)
    # An empty remote path (e.g. "lc://") has no components at all; treat it
    # like any other path with the wrong prefix instead of raising IndexError.
    if not remote_path.parts or remote_path.parts[0] != valid_prefix:
        return None
    if remote_path.suffix[1:] not in valid_suffixes:
        raise ValueError(f"Unsupported file type, must be one of {valid_suffixes}.")

    # URLs always use forward slashes, so render the path in POSIX form even
    # on Windows, where Path uses the backslash as separator.
    full_url = urljoin(URL_BASE.format(ref=ref), remote_path.as_posix())

    r = requests.get(full_url, timeout=5)
    if r.status_code != 200:
        raise ValueError(f"Could not find file at {full_url}")

    with tempfile.TemporaryDirectory() as tmpdirname:
        file = Path(tmpdirname) / remote_path.name
        file.write_bytes(r.content)
        # Call the loader while the temporary directory still exists.
        return loader(str(file), **kwargs)

@ -6,8 +6,7 @@ import time
import uuid
from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Type
from langchain.docstore.document import Document
from langchain.embeddings import TensorflowHubEmbeddings
from langchain.schema.document import Document
from langchain.schema.embeddings import Embeddings
from langchain.schema.vectorstore import VectorStore
@ -16,6 +15,8 @@ if TYPE_CHECKING:
from google.cloud.aiplatform import MatchingEngineIndex, MatchingEngineIndexEndpoint
from google.oauth2.service_account import Credentials
from langchain.embeddings import TensorflowHubEmbeddings
logger = logging.getLogger()
@ -443,10 +444,13 @@ class MatchingEngine(VectorStore):
)
@classmethod
def _get_default_embeddings(cls) -> "TensorflowHubEmbeddings":
    """Return the default embedding model.

    Returns:
        A ``TensorflowHubEmbeddings`` instance built with its default
        arguments.
    """
    # Local import: at module level this name is only imported under
    # TYPE_CHECKING, so it has to be imported here to exist at runtime.
    # For the same reason the return annotation is a string.
    from langchain.embeddings import TensorflowHubEmbeddings

    return TensorflowHubEmbeddings()

@ -18,8 +18,8 @@ from typing import (
Union,
)
from langchain.docstore.document import Document
from langchain.embeddings.base import Embeddings
from langchain.schema.document import Document
from langchain.schema.embeddings import Embeddings
from langchain.schema.vectorstore import VectorStore
from langchain.utils import get_from_dict_or_env
from langchain.vectorstores.utils import DistanceStrategy

@ -2,13 +2,31 @@
set -eu

# Enforce the import layering inside the `langchain` package: each
# sub-package may only import from the sub-packages listed for it below.
# Every check runs (there is no early exit), so one invocation reports all
# violations; the script exits 1 if any check found one.

# Number of checks that found at least one forbidden import.
errors=0

# check_imports DIR ALLOWED
#   DIR      directory scanned for lines starting with `from langchain`
#   ALLOWED  `|`-separated sub-package names that DIR may import from
# Prints the offending lines (as emitted by `git grep`) and bumps `errors`.
check_imports() {
  # `grep -v` exits 0 only when some line survives the filter, i.e. when a
  # forbidden import exists. Running the pipeline as an `if` condition keeps
  # `set -e` from aborting on the non-zero status of a clean check.
  # The dot is escaped so it matches a literal `.` only.
  if git grep '^from langchain' "$1" | grep -vE "from langchain\.($2)"; then
    errors=$((errors+1))
  fi
}

# Top-level `from langchain import ...` is only allowed for __version__ and hub.
if git grep '^from langchain import' langchain | grep -vE 'from langchain import (__version__|hub)'; then
  errors=$((errors+1))
fi

# The pydantic bridge should not import from any other module.
# NOTE(review): the trailing space in the `git grep` pattern means only
# `from langchain import ...` lines are inspected here; `from langchain.x`
# imports are not matched. Confirm whether that is intended.
if git grep '^from langchain ' langchain/pydantic_v1 | grep -vE 'from langchain.(pydantic_v1)'; then
  errors=$((errors+1))
fi

check_imports langchain/load 'pydantic_v1|load'
check_imports langchain/utils 'pydantic_v1|utils'
check_imports langchain/schema 'pydantic_v1|utils|schema|load'
check_imports langchain/adapters 'pydantic_v1|utils|schema|load'
check_imports langchain/callbacks 'pydantic_v1|utils|schema|load|callbacks|env'
# TODO: it is probably not great that so many other modules depend on
# `langchain.utilities`, because there can be a lot of imports there.
check_imports langchain/utilities 'pydantic_v1|utils|schema|load|callbacks|env|utilities'
check_imports langchain/storage 'pydantic_v1|utils|schema|load|callbacks|env|storage|utilities'
check_imports langchain/prompts 'pydantic_v1|utils|schema|load|callbacks|env|prompts|_api'
check_imports langchain/output_parsers 'pydantic_v1|utils|schema|load|callbacks|env|prompts|_api|output_parsers'
check_imports langchain/llms 'pydantic_v1|utils|schema|load|callbacks|env|prompts|llms|utilities'
check_imports langchain/chat_models 'pydantic_v1|utils|schema|load|callbacks|env|llms|prompts|adapters|chat_models|utilities'
check_imports langchain/embeddings 'pydantic_v1|utils|schema|load|callbacks|env|storage|llms|embeddings|utilities'
check_imports langchain/docstore 'pydantic_v1|utils|schema|docstore'
check_imports langchain/vectorstores 'pydantic_v1|utils|schema|load|callbacks|env|_api|storage|llms|docstore|vectorstores|utilities'

# Decide on an exit status based on the errors
if [ "$errors" -gt 0 ]; then
  exit 1
else
  exit 0
fi

@ -10,7 +10,7 @@ from urllib.parse import urljoin
import pytest
import responses
from langchain.utilities.loading import DEFAULT_REF, URL_BASE, try_load_from_hub
from langchain.utils.loading import DEFAULT_REF, URL_BASE, try_load_from_hub
@pytest.fixture(autouse=True)

Loading…
Cancel
Save