core[patch]: Add B(bugbear) ruff rules (#25520)

Co-authored-by: Bagatur <baskaryan@gmail.com>
Authored by Christophe Bornet on 2024-08-28 09:09:29 +02:00; committed by GitHub
parent d5ddaac1fc
commit ff0df5ea15
46 changed files with 149 additions and 112 deletions

View File

@@ -270,7 +270,7 @@ def warn_beta(
         message += f" {addendum}"
     warning = LangChainBetaWarning(message)
-    warnings.warn(warning, category=LangChainBetaWarning, stacklevel=2)
+    warnings.warn(warning, category=LangChainBetaWarning, stacklevel=4)

 def surface_langchain_beta_warnings() -> None:
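Editor's note: the stacklevel bumps throughout this commit address bugbear rule B028 (warnings.warn without an explicit stacklevel). Raising stacklevel makes the warning point at the user's call site rather than at LangChain internals. A minimal sketch of the mechanism (helper names are hypothetical):

import warnings

def deprecated_helper() -> None:
    # stacklevel=1 (the default) would attribute the warning to this line;
    # stacklevel=2 attributes it to whoever called deprecated_helper().
    warnings.warn("use new_helper() instead", DeprecationWarning, stacklevel=2)

def user_code() -> None:
    deprecated_helper()  # with stacklevel=2 the warning points here

user_code()

The decorators here sit several call frames deep, hence stacklevel=4 rather than 2.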

View File

@@ -444,7 +444,7 @@ def warn_deprecated(
             LangChainPendingDeprecationWarning if pending else LangChainDeprecationWarning
         )
         warning = warning_cls(message)
-        warnings.warn(warning, category=LangChainDeprecationWarning, stacklevel=2)
+        warnings.warn(warning, category=LangChainDeprecationWarning, stacklevel=4)

 def surface_langchain_deprecation_warnings() -> None:

View File

@@ -14,7 +14,7 @@ if TYPE_CHECKING:
     from langchain_core.documents.base import Blob

-class BaseLoader(ABC):
+class BaseLoader(ABC):  # noqa: B024
     """Interface for Document Loader.

     Implementations should implement the lazy-loading method using generators
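B024 flags classes that inherit from ABC yet declare no @abstractmethod. BaseLoader is an intentional interface whose methods all have default implementations, so the rule is silenced with a noqa rather than refactored. What the rule catches, with hypothetical names:

from abc import ABC, abstractmethod

class LooksAbstract(ABC):  # B024: ABC subclass with no abstract methods
    def load(self) -> list:
        return []

class TrulyAbstract(ABC):
    @abstractmethod
    def load(self) -> list: ...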

View File

@@ -80,12 +80,12 @@ def _texts_to_nodes(
     for text in texts:
         try:
             _metadata = next(metadatas_it).copy() if metadatas_it else {}
-        except StopIteration:
-            raise ValueError("texts iterable longer than metadatas")
+        except StopIteration as e:
+            raise ValueError("texts iterable longer than metadatas") from e
         try:
             _id = next(ids_it) if ids_it else None
-        except StopIteration:
-            raise ValueError("texts iterable longer than ids")
+        except StopIteration as e:
+            raise ValueError("texts iterable longer than ids") from e
         links = _metadata.pop(METADATA_LINKS_KEY, [])
         if not isinstance(links, list):
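The bulk of this commit is B904 fixes like the two above: when re-raising inside an except block, chain with `from` so the traceback records the original cause explicitly (or `from None` to suppress it). A minimal sketch:

def parse_port(raw: str) -> int:
    try:
        return int(raw)
    except ValueError as e:
        # B904: without "from e" the chaining is implicit ("During handling
        # of the above exception..."); "from e" states the cause directly.
        raise RuntimeError(f"invalid port: {raw!r}") from e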

View File

@@ -91,7 +91,7 @@ class _HashedDocument(Document):
             raise ValueError(
                 f"Failed to hash metadata: {e}. "
                 f"Please use a dict that can be serialized using json."
-            )
+            ) from e
         values["content_hash"] = content_hash
         values["metadata_hash"] = metadata_hash

View File

@@ -64,12 +64,12 @@ def get_tokenizer() -> Any:
     """
     try:
         from transformers import GPT2TokenizerFast  # type: ignore[import]
-    except ImportError:
+    except ImportError as e:
         raise ImportError(
             "Could not import transformers python package. "
             "This is needed in order to calculate get_token_ids. "
             "Please install it with `pip install transformers`."
-        )
+        ) from e
     # create a GPT-2 tokenizer instance
     return GPT2TokenizerFast.from_pretrained("gpt2")

View File

@@ -235,6 +235,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
             warnings.warn(
                 "callback_manager is deprecated. Please use callbacks instead.",
                 DeprecationWarning,
+                stacklevel=5,
             )
             values["callbacks"] = values.pop("callback_manager", None)
         return values

View File

@@ -69,6 +69,10 @@ class FakeListLLM(LLM):
         return {"responses": self.responses}

+class FakeListLLMError(Exception):
+    """Fake error for testing purposes."""

 class FakeStreamingListLLM(FakeListLLM):
     """Fake streaming list LLM for testing purposes.
@@ -98,7 +102,7 @@ class FakeStreamingListLLM(FakeListLLM):
                 self.error_on_chunk_number is not None
                 and i_c == self.error_on_chunk_number
             ):
-                raise Exception("Fake error")
+                raise FakeListLLMError
             yield c

     async def astream(
@@ -118,5 +122,5 @@ class FakeStreamingListLLM(FakeListLLM):
                 self.error_on_chunk_number is not None
                 and i_c == self.error_on_chunk_number
             ):
-                raise Exception("Fake error")
+                raise FakeListLLMError
             yield c
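Raising a dedicated FakeListLLMError instead of a bare Exception lets the test suite assert the exact failure type; `pytest.raises(Exception)` is what bugbear's B017 flags, since it passes for almost any bug. The pattern in isolation:

import pytest

class FakeListLLMError(Exception):
    """Fake error for testing purposes."""

def stream_chunks():
    yield "a"
    raise FakeListLLMError

def test_stream_raises() -> None:
    with pytest.raises(FakeListLLMError):  # precise; Exception would trip B017
        list(stream_chunks())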

View File

@@ -44,6 +44,10 @@ class FakeMessagesListChatModel(BaseChatModel):
         return "fake-messages-list-chat-model"

+class FakeListChatModelError(Exception):
+    pass

 class FakeListChatModel(SimpleChatModel):
     """Fake ChatModel for testing purposes."""
@@ -93,7 +97,7 @@ class FakeListChatModel(SimpleChatModel):
                 self.error_on_chunk_number is not None
                 and i_c == self.error_on_chunk_number
             ):
-                raise Exception("Fake error")
+                raise FakeListChatModelError
             yield ChatGenerationChunk(message=AIMessageChunk(content=c))
@@ -116,7 +120,7 @@ class FakeListChatModel(SimpleChatModel):
                 self.error_on_chunk_number is not None
                 and i_c == self.error_on_chunk_number
             ):
-                raise Exception("Fake error")
+                raise FakeListChatModelError
             yield ChatGenerationChunk(message=AIMessageChunk(content=c))

     @property

View File

@@ -114,7 +114,6 @@ def create_base_retry_decorator(
                     _log_error_once(f"Error in on_retry: {e}")
             else:
                 run_manager.on_retry(retry_state)
-        return None

     min_seconds = 4
     max_seconds = 10
@@ -311,6 +310,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
             warnings.warn(
                 "callback_manager is deprecated. Please use callbacks instead.",
                 DeprecationWarning,
+                stacklevel=5,
             )
             values["callbacks"] = values.pop("callback_manager", None)
         return values

View File

@@ -290,10 +290,10 @@ def _convert_to_message(message: MessageLikeRepresentation) -> BaseMessage:
             msg_type = msg_kwargs.pop("type")
             # None msg content is not allowed
             msg_content = msg_kwargs.pop("content") or ""
-        except KeyError:
+        except KeyError as e:
             raise ValueError(
                 f"Message dict must contain 'role' and 'content' keys, got {message}"
-            )
+            ) from e
         _message = _create_message_from_message_type(
             msg_type, msg_content, **msg_kwargs
         )
@@ -344,9 +344,7 @@ def _runnable_support(func: Callable) -> Callable:
         if messages is not None:
             return func(messages, **kwargs)
         else:
-            return RunnableLambda(
-                partial(func, **kwargs), name=getattr(func, "__name__")
-            )
+            return RunnableLambda(partial(func, **kwargs), name=func.__name__)

     wrapped.__doc__ = func.__doc__
     return wrapped
@@ -791,7 +789,7 @@ def trim_messages(
         raise ValueError
     messages = convert_to_messages(messages)
     if hasattr(token_counter, "get_num_tokens_from_messages"):
-        list_token_counter = getattr(token_counter, "get_num_tokens_from_messages")
+        list_token_counter = token_counter.get_num_tokens_from_messages
     elif callable(token_counter):
         if (
             list(inspect.signature(token_counter).parameters.values())[0].annotation
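These getattr edits are B009 fixes: getattr with a constant string literal is just attribute access with extra steps, and the plain form is faster and visible to type checkers. An illustration (TokenCounter is a hypothetical stand-in):

class TokenCounter:
    def get_num_tokens_from_messages(self, messages: list) -> int:
        return sum(len(str(m)) for m in messages)

counter = TokenCounter()
bad = getattr(counter, "get_num_tokens_from_messages")  # B009
good = counter.get_num_tokens_from_messages             # equivalent
assert bad == good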

View File

@@ -42,7 +42,9 @@ class OutputFunctionsParser(BaseGenerationOutputParser[Any]):
         try:
             func_call = copy.deepcopy(message.additional_kwargs["function_call"])
         except KeyError as exc:
-            raise OutputParserException(f"Could not parse function call: {exc}")
+            raise OutputParserException(
+                f"Could not parse function call: {exc}"
+            ) from exc

         if self.args_only:
             return func_call["arguments"]
@@ -100,7 +102,9 @@ class JsonOutputFunctionsParser(BaseCumulativeTransformOutputParser[Any]):
             if partial:
                 return None
             else:
-                raise OutputParserException(f"Could not parse function call: {exc}")
+                raise OutputParserException(
+                    f"Could not parse function call: {exc}"
+                ) from exc
         try:
             if partial:
                 try:
@@ -126,7 +130,7 @@ class JsonOutputFunctionsParser(BaseCumulativeTransformOutputParser[Any]):
                 except (json.JSONDecodeError, TypeError) as exc:
                     raise OutputParserException(
                         f"Could not parse function call data: {exc}"
-                    )
+                    ) from exc
             else:
                 try:
                     return {
@@ -138,7 +142,7 @@ class JsonOutputFunctionsParser(BaseCumulativeTransformOutputParser[Any]):
                 except (json.JSONDecodeError, TypeError) as exc:
                     raise OutputParserException(
                         f"Could not parse function call data: {exc}"
-                    )
+                    ) from exc
         except KeyError:
             return None

View File

@@ -55,7 +55,7 @@ def parse_tool_call(
                 f"Function {raw_tool_call['function']['name']} arguments:\n\n"
                 f"{raw_tool_call['function']['arguments']}\n\nare not valid JSON. "
                 f"Received JSONDecodeError {e}"
-            )
+            ) from e
     parsed = {
         "name": raw_tool_call["function"]["name"] or "",
         "args": function_args or {},

View File

@@ -32,12 +32,12 @@ class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]):
                     {self.pydantic_object.__class__}"
                 )
             except (pydantic.ValidationError, pydantic.v1.ValidationError) as e:
-                raise self._parser_exception(e, obj)
+                raise self._parser_exception(e, obj) from e
         else:  # pydantic v1
             try:
                 return self.pydantic_object.parse_obj(obj)
             except pydantic.ValidationError as e:
-                raise self._parser_exception(e, obj)
+                raise self._parser_exception(e, obj) from e

     def _parser_exception(
         self, e: Exception, json_object: dict

View File

@@ -46,12 +46,12 @@ class _StreamingParser:
         if parser == "defusedxml":
             try:
                 from defusedxml import ElementTree as DET  # type: ignore
-            except ImportError:
+            except ImportError as e:
                 raise ImportError(
                     "defusedxml is not installed. "
                     "Please install it to use the defusedxml parser."
                     "You can install it with `pip install defusedxml` "
-                )
+                ) from e
             _parser = DET.DefusedXMLParser(target=TreeBuilder())
         else:
             _parser = None
@@ -189,13 +189,13 @@ class XMLOutputParser(BaseTransformOutputParser):
         if self.parser == "defusedxml":
             try:
                 from defusedxml import ElementTree as DET  # type: ignore
-            except ImportError:
+            except ImportError as e:
                 raise ImportError(
                     "defusedxml is not installed. "
                     "Please install it to use the defusedxml parser."
                     "You can install it with `pip install defusedxml`"
                     "See https://github.com/tiran/defusedxml for more details"
-                )
+                ) from e
             _ET = DET  # Use the defusedxml parser
         else:
             _ET = ET  # Use the standard library parser

View File

@@ -235,7 +235,9 @@ class PromptTemplate(StringPromptTemplate):
             template = f.read()
         if input_variables:
             warnings.warn(
-                "`input_variables' is deprecated and ignored.", DeprecationWarning
+                "`input_variables' is deprecated and ignored.",
+                DeprecationWarning,
+                stacklevel=2,
             )
         return cls.from_template(template=template, **kwargs)

View File

@@ -40,14 +40,14 @@ def jinja2_formatter(template: str, **kwargs: Any) -> str:
     """
     try:
         from jinja2.sandbox import SandboxedEnvironment
-    except ImportError:
+    except ImportError as e:
         raise ImportError(
             "jinja2 not installed, which is needed to use the jinja2_formatter. "
             "Please install it with `pip install jinja2`."
             "Please be cautious when using jinja2 templates. "
             "Do not expand jinja2 templates using unverified or user-controlled "
             "inputs as that can result in arbitrary Python code execution."
-        )
+        ) from e

     # This uses a sandboxed environment to prevent arbitrary code execution.
     # Jinja2 uses an opt-out rather than opt-in approach for sand-boxing.
@@ -81,17 +81,17 @@ def validate_jinja2(template: str, input_variables: List[str]) -> None:
         warning_message += f"Extra variables: {extra_variables}"

     if warning_message:
-        warnings.warn(warning_message.strip())
+        warnings.warn(warning_message.strip(), stacklevel=7)

 def _get_jinja2_variables_from_template(template: str) -> Set[str]:
     try:
         from jinja2 import Environment, meta
-    except ImportError:
+    except ImportError as e:
         raise ImportError(
             "jinja2 not installed, which is needed to use the jinja2_formatter. "
             "Please install it with `pip install jinja2`."
-        )
+        ) from e
     env = Environment()
     ast = env.parse(template)
     variables = meta.find_undeclared_variables(ast)

View File

@@ -155,6 +155,7 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
                 "Retrievers must implement abstract `_get_relevant_documents` method"
                 " instead of `get_relevant_documents`",
                 DeprecationWarning,
+                stacklevel=4,
             )
             swap = cls.get_relevant_documents
             cls.get_relevant_documents = (  # type: ignore[assignment]
@@ -169,6 +170,7 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
                 "Retrievers must implement abstract `_aget_relevant_documents` method"
                 " instead of `aget_relevant_documents`",
                 DeprecationWarning,
+                stacklevel=4,
             )
             aswap = cls.aget_relevant_documents
             cls.aget_relevant_documents = (  # type: ignore[assignment]

View File

@@ -3915,7 +3915,7 @@ class RunnableGenerator(Runnable[Input, Output]):
     @property
     def InputType(self) -> Any:
-        func = getattr(self, "_transform", None) or getattr(self, "_atransform")
+        func = getattr(self, "_transform", None) or self._atransform
         try:
             params = inspect.signature(func).parameters
             first_param = next(iter(params.values()), None)
@@ -3928,7 +3928,7 @@ class RunnableGenerator(Runnable[Input, Output]):
     @property
     def OutputType(self) -> Any:
-        func = getattr(self, "_transform", None) or getattr(self, "_atransform")
+        func = getattr(self, "_transform", None) or self._atransform
         try:
             sig = inspect.signature(func)
             return (
@@ -4152,7 +4152,7 @@ class RunnableLambda(Runnable[Input, Output]):
     @property
     def InputType(self) -> Any:
         """The type of the input to this Runnable."""
-        func = getattr(self, "func", None) or getattr(self, "afunc")
+        func = getattr(self, "func", None) or self.afunc
         try:
             params = inspect.signature(func).parameters
             first_param = next(iter(params.values()), None)
@@ -4174,7 +4174,7 @@ class RunnableLambda(Runnable[Input, Output]):
         Returns:
             The input schema for this Runnable.
         """
-        func = getattr(self, "func", None) or getattr(self, "afunc")
+        func = getattr(self, "func", None) or self.afunc

         if isinstance(func, itemgetter):
             # This is terrible, but afaict it's not possible to access _items
@@ -4212,7 +4212,7 @@ class RunnableLambda(Runnable[Input, Output]):
         Returns:
             The type of the output of this Runnable.
         """
-        func = getattr(self, "func", None) or getattr(self, "afunc")
+        func = getattr(self, "func", None) or self.afunc
         try:
             sig = inspect.signature(func)
             if sig.return_annotation != inspect.Signature.empty:

View File

@@ -236,6 +236,7 @@ def get_config_list(
         warnings.warn(
             "Provided run_id be used only for the first element of the batch.",
             category=RuntimeWarning,
+            stacklevel=3,
         )
         subsequent = cast(
             RunnableConfig, {k: v for k, v in config.items() if k != "run_id"}

View File

@@ -537,7 +537,7 @@ class Graph:
         *,
         with_styles: bool = True,
         curve_style: CurveStyle = CurveStyle.LINEAR,
-        node_colors: NodeStyles = NodeStyles(),
+        node_colors: Optional[NodeStyles] = None,
         wrap_label_n_words: int = 9,
     ) -> str:
         """Draw the graph as a Mermaid syntax string.
@@ -573,7 +573,7 @@ class Graph:
         self,
         *,
         curve_style: CurveStyle = CurveStyle.LINEAR,
-        node_colors: NodeStyles = NodeStyles(),
+        node_colors: Optional[NodeStyles] = None,
         wrap_label_n_words: int = 9,
         output_file_path: Optional[str] = None,
         draw_method: MermaidDrawMethod = MermaidDrawMethod.API,

View File

@@ -20,7 +20,7 @@ def draw_mermaid(
     last_node: Optional[str] = None,
     with_styles: bool = True,
     curve_style: CurveStyle = CurveStyle.LINEAR,
-    node_styles: NodeStyles = NodeStyles(),
+    node_styles: Optional[NodeStyles] = None,
     wrap_label_n_words: int = 9,
 ) -> str:
     """Draws a Mermaid graph using the provided graph data.
@@ -153,7 +153,7 @@ def draw_mermaid(
     # Add custom styles for nodes
     if with_styles:
-        mermaid_graph += _generate_mermaid_graph_styles(node_styles)
+        mermaid_graph += _generate_mermaid_graph_styles(node_styles or NodeStyles())

     return mermaid_graph
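The NodeStyles() defaults above are B008 fixes: a call in a default argument runs once at definition time, so every call that omits the argument shares one instance. Deferring construction with Optional/None keeps each call independent. The failure mode, sketched with a stand-in dataclass:

from dataclasses import dataclass, field
from typing import List, Optional

@dataclass
class NodeStyles:
    colors: List[str] = field(default_factory=list)

def draw(styles: Optional[NodeStyles] = None) -> NodeStyles:
    styles = styles or NodeStyles()  # fresh instance per call
    styles.colors.append("#f2f0ff")
    return styles

assert draw().colors == ["#f2f0ff"]  # no state leaks between calls
assert draw().colors == ["#f2f0ff"]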

View File

@@ -218,6 +218,8 @@ class RunnableRetry(RunnableBindingBase[Input, Output]):
         def pending(iterable: List[U]) -> List[U]:
             return [item for idx, item in enumerate(iterable) if idx not in results_map]

+        not_set: List[Output] = []
+        result = not_set
         try:
             for attempt in self._sync_retrying():
                 with attempt:
@@ -247,9 +249,7 @@ class RunnableRetry(RunnableBindingBase[Input, Output]):
                 ):
                     attempt.retry_state.set_result(result)
         except RetryError as e:
-            try:
-                result
-            except UnboundLocalError:
+            if result is not_set:
                 result = cast(List[Output], [e] * len(inputs))

         outputs: List[Union[Output, Exception]] = []
@@ -284,6 +284,8 @@ class RunnableRetry(RunnableBindingBase[Input, Output]):
         def pending(iterable: List[U]) -> List[U]:
             return [item for idx, item in enumerate(iterable) if idx not in results_map]

+        not_set: List[Output] = []
+        result = not_set
         try:
             async for attempt in self._async_retrying():
                 with attempt:
@@ -313,9 +315,7 @@ class RunnableRetry(RunnableBindingBase[Input, Output]):
                 ):
                     attempt.retry_state.set_result(result)
         except RetryError as e:
-            try:
-                result
-            except UnboundLocalError:
+            if result is not_set:
                 result = cast(List[Output], [e] * len(inputs))

         outputs: List[Union[Output, Exception]] = []
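The old code detected "result was never assigned" by evaluating the bare name and catching UnboundLocalError, which bugbear's B018 reads as a useless expression. A sentinel compared with `is` says the same thing plainly:

_NOT_SET: list = []

def gather(ready: bool) -> list:
    result = _NOT_SET
    if ready:
        result = ["ok"]
    if result is _NOT_SET:  # replaces: try: result / except UnboundLocalError
        result = ["fallback"]
    return result

assert gather(False) == ["fallback"]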

View File

@@ -748,7 +748,7 @@ def is_async_generator(
     """
     return (
         inspect.isasyncgenfunction(func)
-        or hasattr(func, "__call__")
+        or hasattr(func, "__call__")  # noqa: B004
         and inspect.isasyncgenfunction(func.__call__)
     )
@@ -767,6 +767,6 @@ def is_async_callable(
     """
     return (
         asyncio.iscoroutinefunction(func)
-        or hasattr(func, "__call__")
+        or hasattr(func, "__call__")  # noqa: B004
         and asyncio.iscoroutinefunction(func.__call__)
     )
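B004 warns that hasattr(x, "__call__") is an unreliable callability test; callable(x) is the robust spelling. Here the code goes on to inspect func.__call__ itself, so the attribute check is deliberate and gets a noqa. The usual fix elsewhere, for contrast:

def is_invocable(obj: object) -> bool:
    return callable(obj)  # preferred over hasattr(obj, "__call__") (B004)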

View File

@@ -443,6 +443,7 @@ class ChildTool(BaseTool):
             warnings.warn(
                 "callback_manager is deprecated. Please use callbacks instead.",
                 DeprecationWarning,
+                stacklevel=6,
             )
             values["callbacks"] = values.pop("callback_manager", None)
         return values

View File

@@ -244,7 +244,7 @@ def _get_schema_from_runnable_and_arg_types(
             "Tool input must be str or dict. If dict, dict arguments must be "
             "typed. Either annotate types (e.g., with TypedDict) or pass "
             f"arg_types into `.as_tool` to specify. {str(e)}"
-        )
+        ) from e
     fields = {key: (key_type, Field(...)) for key, key_type in arg_types.items()}
     return create_model(name, **fields)  # type: ignore

View File

@@ -516,15 +516,19 @@ class _TracerCore(ABC):
     def _end_trace(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
         """End a trace for a run."""
+        return None

     def _on_run_create(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
         """Process a run upon creation."""
+        return None

     def _on_run_update(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
         """Process a run upon update."""
+        return None

     def _on_llm_start(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
         """Process the LLM Run upon start."""
+        return None

     def _on_llm_new_token(
         self,
@@ -533,39 +537,52 @@ class _TracerCore(ABC):
         chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]],
     ) -> Union[None, Coroutine[Any, Any, None]]:
         """Process new LLM token."""
+        return None

     def _on_llm_end(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
         """Process the LLM Run."""
+        return None

     def _on_llm_error(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
         """Process the LLM Run upon error."""
+        return None

     def _on_chain_start(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
         """Process the Chain Run upon start."""
+        return None

     def _on_chain_end(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
         """Process the Chain Run."""
+        return None

     def _on_chain_error(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
         """Process the Chain Run upon error."""
+        return None

     def _on_tool_start(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
         """Process the Tool Run upon start."""
+        return None

     def _on_tool_end(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
         """Process the Tool Run."""
+        return None

     def _on_tool_error(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
         """Process the Tool Run upon error."""
+        return None

     def _on_chat_model_start(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
         """Process the Chat Model Run upon start."""
+        return None

     def _on_retriever_start(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
         """Process the Retriever Run upon start."""
+        return None

     def _on_retriever_end(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
         """Process the Retriever Run."""
+        return None

     def _on_retriever_error(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
         """Process the Retriever Run upon error."""
+        return None

View File

@@ -144,11 +144,7 @@ class EvaluatorCallbackHandler(BaseTracer):
         example_id = str(run.reference_example_id)
         with self.lock:
             for res in eval_results:
-                run_id = (
-                    str(getattr(res, "target_run_id"))
-                    if hasattr(res, "target_run_id")
-                    else str(run.id)
-                )
+                run_id = str(getattr(res, "target_run_id", run.id))
                 self.logged_eval_results.setdefault((run_id, example_id), []).append(
                     res
                 )
@@ -179,11 +175,9 @@ class EvaluatorCallbackHandler(BaseTracer):
             source_info_: Dict[str, Any] = {}
             if res.evaluator_info:
                 source_info_ = {**res.evaluator_info, **source_info_}
-            run_id_ = (
-                getattr(res, "target_run_id")
-                if hasattr(res, "target_run_id") and res.target_run_id is not None
-                else run.id
-            )
+            run_id_ = getattr(res, "target_run_id", None)
+            if run_id_ is None:
+                run_id_ = run.id
             self.client.create_feedback(
                 run_id_,
                 res.key,

View File

@@ -22,6 +22,7 @@ def RunTypeEnum() -> Type[RunTypeEnumDep]:
         "RunTypeEnum is deprecated. Please directly use a string instead"
         " (e.g. 'llm', 'chain', 'tool').",
         DeprecationWarning,
+        stacklevel=2,
     )
     return RunTypeEnumDep

View File

@@ -62,8 +62,8 @@ def py_anext(
         __anext__ = cast(
             Callable[[AsyncIterator[T]], Awaitable[T]], type(iterator).__anext__
         )
-    except AttributeError:
-        raise TypeError(f"{iterator!r} is not an async iterator")
+    except AttributeError as e:
+        raise TypeError(f"{iterator!r} is not an async iterator") from e

     if default is _no_default:
         return __anext__(iterator)

View File

@@ -182,7 +182,7 @@ def parse_and_check_json_markdown(text: str, expected_keys: List[str]) -> dict:
     try:
         json_obj = parse_json_markdown(text)
     except json.JSONDecodeError as e:
-        raise OutputParserException(f"Got invalid JSON object. Error: {e}")
+        raise OutputParserException(f"Got invalid JSON object. Error: {e}") from e
     for key in expected_keys:
         if key not in json_obj:
             raise OutputParserException(

View File

@@ -21,7 +21,9 @@ def try_load_from_hub(
 ) -> Any:
     warnings.warn(
         "Loading from the deprecated github-based Hub is no longer supported. "
-        "Please use the new LangChain Hub at https://smith.langchain.com/hub instead."
+        "Please use the new LangChain Hub at https://smith.langchain.com/hub instead.",
+        DeprecationWarning,
+        stacklevel=2,
     )
     # return None, which indicates that we shouldn't load from old hub
     # and might just be a filepath for e.g. load_chain

View File

@@ -3,13 +3,17 @@ Adapted from https://github.com/noahmorrison/chevron
 MIT License
 """

 from __future__ import annotations

 import logging
+from types import MappingProxyType
 from typing import (
     Any,
     Dict,
     Iterator,
     List,
     Literal,
+    Mapping,
     Optional,
     Sequence,
     Tuple,
@@ -22,7 +26,7 @@ from typing_extensions import TypeAlias

 logger = logging.getLogger(__name__)

-Scopes: TypeAlias = List[Union[Literal[False, 0], Dict[str, Any]]]
+Scopes: TypeAlias = List[Union[Literal[False, 0], Mapping[str, Any]]]

 # Globals
@@ -152,8 +156,8 @@ def parse_tag(template: str, l_del: str, r_del: str) -> Tuple[Tuple[str, str], str]:
     # Get the tag
     try:
         tag, template = template.split(r_del, 1)
-    except ValueError:
-        raise ChevronError("unclosed tag " f"at line {_CURRENT_LINE}")
+    except ValueError as e:
+        raise ChevronError("unclosed tag " f"at line {_CURRENT_LINE}") from e

     # Find the type meaning of the first character
     tag_type = tag_types.get(tag[0], "variable")
@@ -279,12 +283,12 @@ def tokenize(
         # is the same as us
         try:
             last_section = open_sections.pop()
-        except IndexError:
+        except IndexError as e:
             raise ChevronError(
                 f'Trying to close tag "{tag_key}"\n'
                 "Looks like it was not opened.\n"
                 f"line {_CURRENT_LINE + 1}"
-            )
+            ) from e
         if tag_key != last_section:
             # Otherwise we need to complain
             raise ChevronError(
@@ -411,7 +415,7 @@ def _get_key(
     return ""

-def _get_partial(name: str, partials_dict: Dict[str, str]) -> str:
+def _get_partial(name: str, partials_dict: Mapping[str, str]) -> str:
     """Load a partial"""
     try:
         # Maybe the partial is in the dictionary
@@ -425,11 +429,13 @@ def _get_partial(name: str, partials_dict: Dict[str, str]) -> str:
 #
 g_token_cache: Dict[str, List[Tuple[str, str]]] = {}

+EMPTY_DICT: MappingProxyType[str, str] = MappingProxyType({})

 def render(
     template: Union[str, List[Tuple[str, str]]] = "",
-    data: Dict[str, Any] = {},
-    partials_dict: Dict[str, str] = {},
+    data: Mapping[str, Any] = EMPTY_DICT,
+    partials_dict: Mapping[str, str] = EMPTY_DICT,
     padding: str = "",
     def_ldel: str = "{{",
     def_rdel: str = "}}",
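This file gets the commit's B006 fixes: a mutable default like `data: Dict[str, Any] = {}` is created once and shared by every call, so mutations persist across calls. The replacement is a module-level read-only MappingProxyType, with parameters widened to Mapping since the function only reads them. The classic trap:

from types import MappingProxyType
from typing import Any, Dict, Mapping

def render_bad(data: Dict[str, Any] = {}) -> int:  # B006: one shared dict
    data.setdefault("hits", 0)
    data["hits"] += 1
    return data["hits"]

assert render_bad() == 1
assert render_bad() == 2  # the default remembered state across calls

EMPTY: Mapping[str, Any] = MappingProxyType({})

def render_good(data: Mapping[str, Any] = EMPTY) -> int:
    return len(data)  # a read-only default cannot accumulate state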

View File

@@ -131,12 +131,12 @@ def guard_import(
     """
     try:
         module = importlib.import_module(module_name, package)
-    except (ImportError, ModuleNotFoundError):
+    except (ImportError, ModuleNotFoundError) as e:
         pip_name = pip_name or module_name.split(".")[0].replace("_", "-")
         raise ImportError(
             f"Could not import {module_name} python package. "
             f"Please install it with `pip install {pip_name}`."
-        )
+        ) from e
     return module
@@ -235,7 +235,8 @@ def build_extra_kwargs(
             warnings.warn(
                 f"""WARNING! {field_name} is not default parameter.
                     {field_name} was transferred to model_kwargs.
-                    Please confirm that {field_name} is what you intended."""
+                    Please confirm that {field_name} is what you intended.""",
+                stacklevel=7,
             )
             extra_kwargs[field_name] = values.pop(field_name)

View File

@@ -558,7 +558,8 @@ class VectorStore(ABC):
         ):
             warnings.warn(
                 "Relevance scores must be between"
-                f" 0 and 1, got {docs_and_similarities}"
+                f" 0 and 1, got {docs_and_similarities}",
+                stacklevel=2,
             )

         if score_threshold is not None:
@@ -568,7 +569,7 @@ class VectorStore(ABC):
                 if similarity >= score_threshold
             ]
             if len(docs_and_similarities) == 0:
-                warnings.warn(
+                logger.warning(
                     "No relevant docs were retrieved using the relevance score"
                     f" threshold {score_threshold}"
                 )
@@ -605,7 +606,8 @@ class VectorStore(ABC):
         ):
             warnings.warn(
                 "Relevance scores must be between"
-                f" 0 and 1, got {docs_and_similarities}"
+                f" 0 and 1, got {docs_and_similarities}",
+                stacklevel=2,
             )

         if score_threshold is not None:
@@ -615,7 +617,7 @@ class VectorStore(ABC):
                 if similarity >= score_threshold
             ]
             if len(docs_and_similarities) == 0:
-                warnings.warn(
+                logger.warning(
                     "No relevant docs were retrieved using the relevance score"
                     f" threshold {score_threshold}"
                 )

View File

@@ -435,11 +435,11 @@ class InMemoryVectorStore(VectorStore):
         try:
             import numpy as np
-        except ImportError:
+        except ImportError as e:
             raise ImportError(
                 "numpy must be installed to use max_marginal_relevance_search "
                 "pip install numpy"
-            )
+            ) from e

         mmr_chosen_indices = maximal_marginal_relevance(
             np.array(embedding, dtype=np.float32),

View File

@@ -34,11 +34,11 @@ def _cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray:
     """
     try:
         import numpy as np
-    except ImportError:
+    except ImportError as e:
         raise ImportError(
             "cosine_similarity requires numpy to be installed. "
             "Please install numpy with `pip install numpy`."
-        )
+        ) from e

     if len(X) == 0 or len(Y) == 0:
         return np.array([])
@@ -93,11 +93,11 @@ def maximal_marginal_relevance(
     """
     try:
         import numpy as np
-    except ImportError:
+    except ImportError as e:
         raise ImportError(
             "maximal_marginal_relevance requires numpy to be installed. "
             "Please install numpy with `pip install numpy`."
-        )
+        ) from e

     if min(k, len(embedding_list)) <= 0:
         return []

View File

@@ -41,7 +41,7 @@ python = ">=3.12.4"
 [tool.poetry.extras]

 [tool.ruff.lint]
-select = [ "E", "F", "I", "T201", "UP",]
+select = [ "B", "E", "F", "I", "T201", "UP",]
 ignore = [ "UP006", "UP007",]

 [tool.coverage.run]

View File

@@ -7,6 +7,7 @@ import pytest
 from langchain_core.callbacks import CallbackManagerForLLMRun
 from langchain_core.language_models import BaseChatModel, FakeListChatModel
+from langchain_core.language_models.fake_chat_models import FakeListChatModelError
 from langchain_core.messages import (
     AIMessage,
     AIMessageChunk,
@@ -110,7 +111,7 @@ async def test_stream_error_callback() -> None:
         responses=[message],
         error_on_chunk_number=i,
     )
-    with pytest.raises(Exception):
+    with pytest.raises(FakeListChatModelError):
         cb_async = FakeAsyncCallbackHandler()
         async for _ in llm.astream("Dummy message", callbacks=[cb_async]):
             pass

View File

@@ -7,6 +7,7 @@ from langchain_core.callbacks import (
     CallbackManagerForLLMRun,
 )
 from langchain_core.language_models import BaseLLM, FakeListLLM, FakeStreamingListLLM
+from langchain_core.language_models.fake import FakeListLLMError
 from langchain_core.outputs import Generation, GenerationChunk, LLMResult
 from langchain_core.tracers.context import collect_runs
 from tests.unit_tests.fake.callbacks import (
@@ -108,7 +109,7 @@ async def test_stream_error_callback() -> None:
         responses=[message],
         error_on_chunk_number=i,
     )
-    with pytest.raises(Exception):
+    with pytest.raises(FakeListLLMError):
         cb_async = FakeAsyncCallbackHandler()
         async for _ in llm.astream("Dummy message", callbacks=[cb_async]):
             pass

View File

@@ -3,6 +3,7 @@ from typing import Any, AsyncIterator, Iterator, Tuple

 import pytest

+from langchain_core.exceptions import OutputParserException
 from langchain_core.output_parsers.json import (
     SimpleJsonOutputParser,
 )
@@ -531,7 +532,7 @@ async def test_partial_text_json_output_parser_diff_async() -> None:
 def test_raises_error() -> None:
     parser = SimpleJsonOutputParser()
-    with pytest.raises(Exception):
+    with pytest.raises(OutputParserException):
         parser.invoke("hi")

View File

@@ -164,13 +164,9 @@ def test_pydantic_output_parser_fail() -> None:
         pydantic_object=TestModel
     )

-    try:
+    with pytest.raises(OutputParserException) as e:
         pydantic_parser.parse(DEF_RESULT_FAIL)
-    except OutputParserException as e:
-        print("parse_result:", e)  # noqa: T201
-        assert "Failed to parse TestModel from completion" in str(e)
-    else:
-        assert False, "Expected OutputParserException"
+    assert "Failed to parse TestModel from completion" in str(e)

 def test_pydantic_output_parser_type_inference() -> None:
def test_pydantic_output_parser_type_inference() -> None:

View File

@@ -28,7 +28,7 @@ def _replace_all_of_with_ref(schema: Any) -> None:
                 del schema["default"]
         else:
             # Recursively process nested schemas
-            for key, value in schema.items():
+            for value in schema.values():
                 if isinstance(value, (dict, list)):
                     _replace_all_of_with_ref(value)
     elif isinstance(schema, list):
@@ -47,7 +47,7 @@ def _remove_bad_none_defaults(schema: Any) -> None:
     See difference between Optional and NotRequired types in python.
     """
     if isinstance(schema, dict):
-        for key, value in schema.items():
+        for value in schema.values():
             if isinstance(value, dict):
                 if "default" in value and value["default"] is None:
                     any_of = value.get("anyOf", [])
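B007 flags loop control variables that the body never uses; iterating .values() (or renaming the variable to `_`) makes the intent explicit. For instance:

schema = {"a": {"default": None}, "b": {"type": "string"}}

for value in schema.values():  # was: for key, value in schema.items()  (B007)
    if isinstance(value, dict):
        value.pop("default", None)

assert schema == {"a": {}, "b": {"type": "string"}}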

View File

@@ -307,7 +307,7 @@ async def test_fallbacks_astream() -> None:
     runnable = RunnableGenerator(_agenerate_delayed_error).with_fallbacks(
         [RunnableGenerator(_agenerate)]
     )
-    async for c in runnable.astream({}):
+    async for _ in runnable.astream({}):
         pass
@@ -373,7 +373,7 @@ def test_fallbacks_getattr() -> None:
     assert llm_with_fallbacks.foo == 3
     with pytest.raises(AttributeError):
-        llm_with_fallbacks.bar
+        assert llm_with_fallbacks.bar == 4

 def test_fallbacks_getattr_runnable_output() -> None:

View File

@@ -1516,7 +1516,7 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None:
     ) == [5, 7]
     assert len(spy.call_args_list) == 2

-    for i, call in enumerate(spy.call_args_list):
+    for call in spy.call_args_list:
         call_arg = call.args[0]

         if call_arg == "hello":
@@ -1533,7 +1533,7 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None:
     assert fake.batch(["hello", "wooorld"], dict(tags=["a-tag"])) == [5, 7]
     assert len(spy.call_args_list) == 2
     assert set(call.args[0] for call in spy.call_args_list) == {"hello", "wooorld"}
-    for i, call in enumerate(spy.call_args_list):
+    for call in spy.call_args_list:
         assert call.args[1].get("tags") == ["a-tag"]
         assert call.args[1].get("metadata") == {}
     spy.reset_mock()
@@ -5205,7 +5205,7 @@ def test_invoke_stream_passthrough_assign_trace() -> None:
     assert tracer.runs[0].child_runs[0].name == "RunnableParallel<urls>"

     tracer = FakeTracer()
-    for item in chain.stream({"example": [1, 2, 3]}, dict(callbacks=[tracer])):
+    for _ in chain.stream({"example": [1, 2, 3]}, dict(callbacks=[tracer])):
         pass

     assert tracer.runs[0].name == "RunnableAssign<urls>"
@@ -5225,7 +5225,7 @@ async def test_ainvoke_astream_passthrough_assign_trace() -> None:
     assert tracer.runs[0].child_runs[0].name == "RunnableParallel<urls>"

     tracer = FakeTracer()
-    async for item in chain.astream({"example": [1, 2, 3]}, dict(callbacks=[tracer])):
+    async for _ in chain.astream({"example": [1, 2, 3]}, dict(callbacks=[tracer])):
         pass

     assert tracer.runs[0].name == "RunnableAssign<urls>"

View File

@@ -140,7 +140,7 @@ class LangChainProjectNameTest(unittest.TestCase):
             projects = []

             def mock_create_run(**kwargs: Any) -> Any:
-                projects.append(kwargs.get("project_name"))
+                projects.append(kwargs.get("project_name"))  # noqa: B023
                 return unittest.mock.MagicMock()

             client.create_run = mock_create_run
@@ -151,6 +151,4 @@ class LangChainProjectNameTest(unittest.TestCase):
                 run_id=UUID("9d878ab3-e5ca-4218-aef6-44cbdc90160a"),
             )
             tracer.wait_for_futures()
-            assert (
-                len(projects) == 1 and projects[0] == case.expected_project_name
-            )
+            assert projects == [case.expected_project_name]
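The noqa above is for B023, which warns when a function defined in a loop closes over a loop variable: closures see the variable's final value, not the value at definition time. Here `projects` is deliberately shared within each iteration, so the warning is suppressed. The hazard B023 exists for, in miniature:

# Late binding: every closure reads i's final value.
fns = [lambda: i for i in range(3)]
assert [f() for f in fns] == [2, 2, 2]

# Fix: bind the current value via a default argument.
fns = [lambda i=i: i for i in range(3)]
assert [f() for f in fns] == [0, 1, 2]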