mirror of
https://github.com/hwchase17/langchain
synced 2024-11-11 19:11:02 +00:00
core[patch]: Add B(bugbear) ruff rules (#25520)
Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent
d5ddaac1fc
commit
ff0df5ea15
@ -270,7 +270,7 @@ def warn_beta(
|
||||
message += f" {addendum}"
|
||||
|
||||
warning = LangChainBetaWarning(message)
|
||||
warnings.warn(warning, category=LangChainBetaWarning, stacklevel=2)
|
||||
warnings.warn(warning, category=LangChainBetaWarning, stacklevel=4)
|
||||
|
||||
|
||||
def surface_langchain_beta_warnings() -> None:
|
||||
|
@ -444,7 +444,7 @@ def warn_deprecated(
|
||||
LangChainPendingDeprecationWarning if pending else LangChainDeprecationWarning
|
||||
)
|
||||
warning = warning_cls(message)
|
||||
warnings.warn(warning, category=LangChainDeprecationWarning, stacklevel=2)
|
||||
warnings.warn(warning, category=LangChainDeprecationWarning, stacklevel=4)
|
||||
|
||||
|
||||
def surface_langchain_deprecation_warnings() -> None:
|
||||
|
@ -14,7 +14,7 @@ if TYPE_CHECKING:
|
||||
from langchain_core.documents.base import Blob
|
||||
|
||||
|
||||
class BaseLoader(ABC):
|
||||
class BaseLoader(ABC): # noqa: B024
|
||||
"""Interface for Document Loader.
|
||||
|
||||
Implementations should implement the lazy-loading method using generators
|
||||
|
@ -80,12 +80,12 @@ def _texts_to_nodes(
|
||||
for text in texts:
|
||||
try:
|
||||
_metadata = next(metadatas_it).copy() if metadatas_it else {}
|
||||
except StopIteration:
|
||||
raise ValueError("texts iterable longer than metadatas")
|
||||
except StopIteration as e:
|
||||
raise ValueError("texts iterable longer than metadatas") from e
|
||||
try:
|
||||
_id = next(ids_it) if ids_it else None
|
||||
except StopIteration:
|
||||
raise ValueError("texts iterable longer than ids")
|
||||
except StopIteration as e:
|
||||
raise ValueError("texts iterable longer than ids") from e
|
||||
|
||||
links = _metadata.pop(METADATA_LINKS_KEY, [])
|
||||
if not isinstance(links, list):
|
||||
|
@ -91,7 +91,7 @@ class _HashedDocument(Document):
|
||||
raise ValueError(
|
||||
f"Failed to hash metadata: {e}. "
|
||||
f"Please use a dict that can be serialized using json."
|
||||
)
|
||||
) from e
|
||||
|
||||
values["content_hash"] = content_hash
|
||||
values["metadata_hash"] = metadata_hash
|
||||
|
@ -64,12 +64,12 @@ def get_tokenizer() -> Any:
|
||||
"""
|
||||
try:
|
||||
from transformers import GPT2TokenizerFast # type: ignore[import]
|
||||
except ImportError:
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"Could not import transformers python package. "
|
||||
"This is needed in order to calculate get_token_ids. "
|
||||
"Please install it with `pip install transformers`."
|
||||
)
|
||||
) from e
|
||||
# create a GPT-2 tokenizer instance
|
||||
return GPT2TokenizerFast.from_pretrained("gpt2")
|
||||
|
||||
|
@ -235,6 +235,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
warnings.warn(
|
||||
"callback_manager is deprecated. Please use callbacks instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=5,
|
||||
)
|
||||
values["callbacks"] = values.pop("callback_manager", None)
|
||||
return values
|
||||
|
@ -69,6 +69,10 @@ class FakeListLLM(LLM):
|
||||
return {"responses": self.responses}
|
||||
|
||||
|
||||
class FakeListLLMError(Exception):
|
||||
"""Fake error for testing purposes."""
|
||||
|
||||
|
||||
class FakeStreamingListLLM(FakeListLLM):
|
||||
"""Fake streaming list LLM for testing purposes.
|
||||
|
||||
@ -98,7 +102,7 @@ class FakeStreamingListLLM(FakeListLLM):
|
||||
self.error_on_chunk_number is not None
|
||||
and i_c == self.error_on_chunk_number
|
||||
):
|
||||
raise Exception("Fake error")
|
||||
raise FakeListLLMError
|
||||
yield c
|
||||
|
||||
async def astream(
|
||||
@ -118,5 +122,5 @@ class FakeStreamingListLLM(FakeListLLM):
|
||||
self.error_on_chunk_number is not None
|
||||
and i_c == self.error_on_chunk_number
|
||||
):
|
||||
raise Exception("Fake error")
|
||||
raise FakeListLLMError
|
||||
yield c
|
||||
|
@ -44,6 +44,10 @@ class FakeMessagesListChatModel(BaseChatModel):
|
||||
return "fake-messages-list-chat-model"
|
||||
|
||||
|
||||
class FakeListChatModelError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class FakeListChatModel(SimpleChatModel):
|
||||
"""Fake ChatModel for testing purposes."""
|
||||
|
||||
@ -93,7 +97,7 @@ class FakeListChatModel(SimpleChatModel):
|
||||
self.error_on_chunk_number is not None
|
||||
and i_c == self.error_on_chunk_number
|
||||
):
|
||||
raise Exception("Fake error")
|
||||
raise FakeListChatModelError
|
||||
|
||||
yield ChatGenerationChunk(message=AIMessageChunk(content=c))
|
||||
|
||||
@ -116,7 +120,7 @@ class FakeListChatModel(SimpleChatModel):
|
||||
self.error_on_chunk_number is not None
|
||||
and i_c == self.error_on_chunk_number
|
||||
):
|
||||
raise Exception("Fake error")
|
||||
raise FakeListChatModelError
|
||||
yield ChatGenerationChunk(message=AIMessageChunk(content=c))
|
||||
|
||||
@property
|
||||
|
@ -114,7 +114,6 @@ def create_base_retry_decorator(
|
||||
_log_error_once(f"Error in on_retry: {e}")
|
||||
else:
|
||||
run_manager.on_retry(retry_state)
|
||||
return None
|
||||
|
||||
min_seconds = 4
|
||||
max_seconds = 10
|
||||
@ -311,6 +310,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
warnings.warn(
|
||||
"callback_manager is deprecated. Please use callbacks instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=5,
|
||||
)
|
||||
values["callbacks"] = values.pop("callback_manager", None)
|
||||
return values
|
||||
|
@ -290,10 +290,10 @@ def _convert_to_message(message: MessageLikeRepresentation) -> BaseMessage:
|
||||
msg_type = msg_kwargs.pop("type")
|
||||
# None msg content is not allowed
|
||||
msg_content = msg_kwargs.pop("content") or ""
|
||||
except KeyError:
|
||||
except KeyError as e:
|
||||
raise ValueError(
|
||||
f"Message dict must contain 'role' and 'content' keys, got {message}"
|
||||
)
|
||||
) from e
|
||||
_message = _create_message_from_message_type(
|
||||
msg_type, msg_content, **msg_kwargs
|
||||
)
|
||||
@ -344,9 +344,7 @@ def _runnable_support(func: Callable) -> Callable:
|
||||
if messages is not None:
|
||||
return func(messages, **kwargs)
|
||||
else:
|
||||
return RunnableLambda(
|
||||
partial(func, **kwargs), name=getattr(func, "__name__")
|
||||
)
|
||||
return RunnableLambda(partial(func, **kwargs), name=func.__name__)
|
||||
|
||||
wrapped.__doc__ = func.__doc__
|
||||
return wrapped
|
||||
@ -791,7 +789,7 @@ def trim_messages(
|
||||
raise ValueError
|
||||
messages = convert_to_messages(messages)
|
||||
if hasattr(token_counter, "get_num_tokens_from_messages"):
|
||||
list_token_counter = getattr(token_counter, "get_num_tokens_from_messages")
|
||||
list_token_counter = token_counter.get_num_tokens_from_messages
|
||||
elif callable(token_counter):
|
||||
if (
|
||||
list(inspect.signature(token_counter).parameters.values())[0].annotation
|
||||
|
@ -42,7 +42,9 @@ class OutputFunctionsParser(BaseGenerationOutputParser[Any]):
|
||||
try:
|
||||
func_call = copy.deepcopy(message.additional_kwargs["function_call"])
|
||||
except KeyError as exc:
|
||||
raise OutputParserException(f"Could not parse function call: {exc}")
|
||||
raise OutputParserException(
|
||||
f"Could not parse function call: {exc}"
|
||||
) from exc
|
||||
|
||||
if self.args_only:
|
||||
return func_call["arguments"]
|
||||
@ -100,7 +102,9 @@ class JsonOutputFunctionsParser(BaseCumulativeTransformOutputParser[Any]):
|
||||
if partial:
|
||||
return None
|
||||
else:
|
||||
raise OutputParserException(f"Could not parse function call: {exc}")
|
||||
raise OutputParserException(
|
||||
f"Could not parse function call: {exc}"
|
||||
) from exc
|
||||
try:
|
||||
if partial:
|
||||
try:
|
||||
@ -126,7 +130,7 @@ class JsonOutputFunctionsParser(BaseCumulativeTransformOutputParser[Any]):
|
||||
except (json.JSONDecodeError, TypeError) as exc:
|
||||
raise OutputParserException(
|
||||
f"Could not parse function call data: {exc}"
|
||||
)
|
||||
) from exc
|
||||
else:
|
||||
try:
|
||||
return {
|
||||
@ -138,7 +142,7 @@ class JsonOutputFunctionsParser(BaseCumulativeTransformOutputParser[Any]):
|
||||
except (json.JSONDecodeError, TypeError) as exc:
|
||||
raise OutputParserException(
|
||||
f"Could not parse function call data: {exc}"
|
||||
)
|
||||
) from exc
|
||||
except KeyError:
|
||||
return None
|
||||
|
||||
|
@ -55,7 +55,7 @@ def parse_tool_call(
|
||||
f"Function {raw_tool_call['function']['name']} arguments:\n\n"
|
||||
f"{raw_tool_call['function']['arguments']}\n\nare not valid JSON. "
|
||||
f"Received JSONDecodeError {e}"
|
||||
)
|
||||
) from e
|
||||
parsed = {
|
||||
"name": raw_tool_call["function"]["name"] or "",
|
||||
"args": function_args or {},
|
||||
|
@ -32,12 +32,12 @@ class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]):
|
||||
{self.pydantic_object.__class__}"
|
||||
)
|
||||
except (pydantic.ValidationError, pydantic.v1.ValidationError) as e:
|
||||
raise self._parser_exception(e, obj)
|
||||
raise self._parser_exception(e, obj) from e
|
||||
else: # pydantic v1
|
||||
try:
|
||||
return self.pydantic_object.parse_obj(obj)
|
||||
except pydantic.ValidationError as e:
|
||||
raise self._parser_exception(e, obj)
|
||||
raise self._parser_exception(e, obj) from e
|
||||
|
||||
def _parser_exception(
|
||||
self, e: Exception, json_object: dict
|
||||
|
@ -46,12 +46,12 @@ class _StreamingParser:
|
||||
if parser == "defusedxml":
|
||||
try:
|
||||
from defusedxml import ElementTree as DET # type: ignore
|
||||
except ImportError:
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"defusedxml is not installed. "
|
||||
"Please install it to use the defusedxml parser."
|
||||
"You can install it with `pip install defusedxml` "
|
||||
)
|
||||
) from e
|
||||
_parser = DET.DefusedXMLParser(target=TreeBuilder())
|
||||
else:
|
||||
_parser = None
|
||||
@ -189,13 +189,13 @@ class XMLOutputParser(BaseTransformOutputParser):
|
||||
if self.parser == "defusedxml":
|
||||
try:
|
||||
from defusedxml import ElementTree as DET # type: ignore
|
||||
except ImportError:
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"defusedxml is not installed. "
|
||||
"Please install it to use the defusedxml parser."
|
||||
"You can install it with `pip install defusedxml`"
|
||||
"See https://github.com/tiran/defusedxml for more details"
|
||||
)
|
||||
) from e
|
||||
_ET = DET # Use the defusedxml parser
|
||||
else:
|
||||
_ET = ET # Use the standard library parser
|
||||
|
@ -235,7 +235,9 @@ class PromptTemplate(StringPromptTemplate):
|
||||
template = f.read()
|
||||
if input_variables:
|
||||
warnings.warn(
|
||||
"`input_variables' is deprecated and ignored.", DeprecationWarning
|
||||
"`input_variables' is deprecated and ignored.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
return cls.from_template(template=template, **kwargs)
|
||||
|
||||
|
@ -40,14 +40,14 @@ def jinja2_formatter(template: str, **kwargs: Any) -> str:
|
||||
"""
|
||||
try:
|
||||
from jinja2.sandbox import SandboxedEnvironment
|
||||
except ImportError:
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"jinja2 not installed, which is needed to use the jinja2_formatter. "
|
||||
"Please install it with `pip install jinja2`."
|
||||
"Please be cautious when using jinja2 templates. "
|
||||
"Do not expand jinja2 templates using unverified or user-controlled "
|
||||
"inputs as that can result in arbitrary Python code execution."
|
||||
)
|
||||
) from e
|
||||
|
||||
# This uses a sandboxed environment to prevent arbitrary code execution.
|
||||
# Jinja2 uses an opt-out rather than opt-in approach for sand-boxing.
|
||||
@ -81,17 +81,17 @@ def validate_jinja2(template: str, input_variables: List[str]) -> None:
|
||||
warning_message += f"Extra variables: {extra_variables}"
|
||||
|
||||
if warning_message:
|
||||
warnings.warn(warning_message.strip())
|
||||
warnings.warn(warning_message.strip(), stacklevel=7)
|
||||
|
||||
|
||||
def _get_jinja2_variables_from_template(template: str) -> Set[str]:
|
||||
try:
|
||||
from jinja2 import Environment, meta
|
||||
except ImportError:
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"jinja2 not installed, which is needed to use the jinja2_formatter. "
|
||||
"Please install it with `pip install jinja2`."
|
||||
)
|
||||
) from e
|
||||
env = Environment()
|
||||
ast = env.parse(template)
|
||||
variables = meta.find_undeclared_variables(ast)
|
||||
|
@ -155,6 +155,7 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
|
||||
"Retrievers must implement abstract `_get_relevant_documents` method"
|
||||
" instead of `get_relevant_documents`",
|
||||
DeprecationWarning,
|
||||
stacklevel=4,
|
||||
)
|
||||
swap = cls.get_relevant_documents
|
||||
cls.get_relevant_documents = ( # type: ignore[assignment]
|
||||
@ -169,6 +170,7 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
|
||||
"Retrievers must implement abstract `_aget_relevant_documents` method"
|
||||
" instead of `aget_relevant_documents`",
|
||||
DeprecationWarning,
|
||||
stacklevel=4,
|
||||
)
|
||||
aswap = cls.aget_relevant_documents
|
||||
cls.aget_relevant_documents = ( # type: ignore[assignment]
|
||||
|
@ -3915,7 +3915,7 @@ class RunnableGenerator(Runnable[Input, Output]):
|
||||
|
||||
@property
|
||||
def InputType(self) -> Any:
|
||||
func = getattr(self, "_transform", None) or getattr(self, "_atransform")
|
||||
func = getattr(self, "_transform", None) or self._atransform
|
||||
try:
|
||||
params = inspect.signature(func).parameters
|
||||
first_param = next(iter(params.values()), None)
|
||||
@ -3928,7 +3928,7 @@ class RunnableGenerator(Runnable[Input, Output]):
|
||||
|
||||
@property
|
||||
def OutputType(self) -> Any:
|
||||
func = getattr(self, "_transform", None) or getattr(self, "_atransform")
|
||||
func = getattr(self, "_transform", None) or self._atransform
|
||||
try:
|
||||
sig = inspect.signature(func)
|
||||
return (
|
||||
@ -4152,7 +4152,7 @@ class RunnableLambda(Runnable[Input, Output]):
|
||||
@property
|
||||
def InputType(self) -> Any:
|
||||
"""The type of the input to this Runnable."""
|
||||
func = getattr(self, "func", None) or getattr(self, "afunc")
|
||||
func = getattr(self, "func", None) or self.afunc
|
||||
try:
|
||||
params = inspect.signature(func).parameters
|
||||
first_param = next(iter(params.values()), None)
|
||||
@ -4174,7 +4174,7 @@ class RunnableLambda(Runnable[Input, Output]):
|
||||
Returns:
|
||||
The input schema for this Runnable.
|
||||
"""
|
||||
func = getattr(self, "func", None) or getattr(self, "afunc")
|
||||
func = getattr(self, "func", None) or self.afunc
|
||||
|
||||
if isinstance(func, itemgetter):
|
||||
# This is terrible, but afaict it's not possible to access _items
|
||||
@ -4212,7 +4212,7 @@ class RunnableLambda(Runnable[Input, Output]):
|
||||
Returns:
|
||||
The type of the output of this Runnable.
|
||||
"""
|
||||
func = getattr(self, "func", None) or getattr(self, "afunc")
|
||||
func = getattr(self, "func", None) or self.afunc
|
||||
try:
|
||||
sig = inspect.signature(func)
|
||||
if sig.return_annotation != inspect.Signature.empty:
|
||||
|
@ -236,6 +236,7 @@ def get_config_list(
|
||||
warnings.warn(
|
||||
"Provided run_id be used only for the first element of the batch.",
|
||||
category=RuntimeWarning,
|
||||
stacklevel=3,
|
||||
)
|
||||
subsequent = cast(
|
||||
RunnableConfig, {k: v for k, v in config.items() if k != "run_id"}
|
||||
|
@ -537,7 +537,7 @@ class Graph:
|
||||
*,
|
||||
with_styles: bool = True,
|
||||
curve_style: CurveStyle = CurveStyle.LINEAR,
|
||||
node_colors: NodeStyles = NodeStyles(),
|
||||
node_colors: Optional[NodeStyles] = None,
|
||||
wrap_label_n_words: int = 9,
|
||||
) -> str:
|
||||
"""Draw the graph as a Mermaid syntax string.
|
||||
@ -573,7 +573,7 @@ class Graph:
|
||||
self,
|
||||
*,
|
||||
curve_style: CurveStyle = CurveStyle.LINEAR,
|
||||
node_colors: NodeStyles = NodeStyles(),
|
||||
node_colors: Optional[NodeStyles] = None,
|
||||
wrap_label_n_words: int = 9,
|
||||
output_file_path: Optional[str] = None,
|
||||
draw_method: MermaidDrawMethod = MermaidDrawMethod.API,
|
||||
|
@ -20,7 +20,7 @@ def draw_mermaid(
|
||||
last_node: Optional[str] = None,
|
||||
with_styles: bool = True,
|
||||
curve_style: CurveStyle = CurveStyle.LINEAR,
|
||||
node_styles: NodeStyles = NodeStyles(),
|
||||
node_styles: Optional[NodeStyles] = None,
|
||||
wrap_label_n_words: int = 9,
|
||||
) -> str:
|
||||
"""Draws a Mermaid graph using the provided graph data.
|
||||
@ -153,7 +153,7 @@ def draw_mermaid(
|
||||
|
||||
# Add custom styles for nodes
|
||||
if with_styles:
|
||||
mermaid_graph += _generate_mermaid_graph_styles(node_styles)
|
||||
mermaid_graph += _generate_mermaid_graph_styles(node_styles or NodeStyles())
|
||||
return mermaid_graph
|
||||
|
||||
|
||||
|
@ -218,6 +218,8 @@ class RunnableRetry(RunnableBindingBase[Input, Output]):
|
||||
def pending(iterable: List[U]) -> List[U]:
|
||||
return [item for idx, item in enumerate(iterable) if idx not in results_map]
|
||||
|
||||
not_set: List[Output] = []
|
||||
result = not_set
|
||||
try:
|
||||
for attempt in self._sync_retrying():
|
||||
with attempt:
|
||||
@ -247,9 +249,7 @@ class RunnableRetry(RunnableBindingBase[Input, Output]):
|
||||
):
|
||||
attempt.retry_state.set_result(result)
|
||||
except RetryError as e:
|
||||
try:
|
||||
result
|
||||
except UnboundLocalError:
|
||||
if result is not_set:
|
||||
result = cast(List[Output], [e] * len(inputs))
|
||||
|
||||
outputs: List[Union[Output, Exception]] = []
|
||||
@ -284,6 +284,8 @@ class RunnableRetry(RunnableBindingBase[Input, Output]):
|
||||
def pending(iterable: List[U]) -> List[U]:
|
||||
return [item for idx, item in enumerate(iterable) if idx not in results_map]
|
||||
|
||||
not_set: List[Output] = []
|
||||
result = not_set
|
||||
try:
|
||||
async for attempt in self._async_retrying():
|
||||
with attempt:
|
||||
@ -313,9 +315,7 @@ class RunnableRetry(RunnableBindingBase[Input, Output]):
|
||||
):
|
||||
attempt.retry_state.set_result(result)
|
||||
except RetryError as e:
|
||||
try:
|
||||
result
|
||||
except UnboundLocalError:
|
||||
if result is not_set:
|
||||
result = cast(List[Output], [e] * len(inputs))
|
||||
|
||||
outputs: List[Union[Output, Exception]] = []
|
||||
|
@ -748,7 +748,7 @@ def is_async_generator(
|
||||
"""
|
||||
return (
|
||||
inspect.isasyncgenfunction(func)
|
||||
or hasattr(func, "__call__")
|
||||
or hasattr(func, "__call__") # noqa: B004
|
||||
and inspect.isasyncgenfunction(func.__call__)
|
||||
)
|
||||
|
||||
@ -767,6 +767,6 @@ def is_async_callable(
|
||||
"""
|
||||
return (
|
||||
asyncio.iscoroutinefunction(func)
|
||||
or hasattr(func, "__call__")
|
||||
or hasattr(func, "__call__") # noqa: B004
|
||||
and asyncio.iscoroutinefunction(func.__call__)
|
||||
)
|
||||
|
@ -443,6 +443,7 @@ class ChildTool(BaseTool):
|
||||
warnings.warn(
|
||||
"callback_manager is deprecated. Please use callbacks instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=6,
|
||||
)
|
||||
values["callbacks"] = values.pop("callback_manager", None)
|
||||
return values
|
||||
|
@ -244,7 +244,7 @@ def _get_schema_from_runnable_and_arg_types(
|
||||
"Tool input must be str or dict. If dict, dict arguments must be "
|
||||
"typed. Either annotate types (e.g., with TypedDict) or pass "
|
||||
f"arg_types into `.as_tool` to specify. {str(e)}"
|
||||
)
|
||||
) from e
|
||||
fields = {key: (key_type, Field(...)) for key, key_type in arg_types.items()}
|
||||
return create_model(name, **fields) # type: ignore
|
||||
|
||||
|
@ -516,15 +516,19 @@ class _TracerCore(ABC):
|
||||
|
||||
def _end_trace(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
|
||||
"""End a trace for a run."""
|
||||
return None
|
||||
|
||||
def _on_run_create(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
|
||||
"""Process a run upon creation."""
|
||||
return None
|
||||
|
||||
def _on_run_update(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
|
||||
"""Process a run upon update."""
|
||||
return None
|
||||
|
||||
def _on_llm_start(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
|
||||
"""Process the LLM Run upon start."""
|
||||
return None
|
||||
|
||||
def _on_llm_new_token(
|
||||
self,
|
||||
@ -533,39 +537,52 @@ class _TracerCore(ABC):
|
||||
chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]],
|
||||
) -> Union[None, Coroutine[Any, Any, None]]:
|
||||
"""Process new LLM token."""
|
||||
return None
|
||||
|
||||
def _on_llm_end(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
|
||||
"""Process the LLM Run."""
|
||||
return None
|
||||
|
||||
def _on_llm_error(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
|
||||
"""Process the LLM Run upon error."""
|
||||
return None
|
||||
|
||||
def _on_chain_start(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
|
||||
"""Process the Chain Run upon start."""
|
||||
return None
|
||||
|
||||
def _on_chain_end(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
|
||||
"""Process the Chain Run."""
|
||||
return None
|
||||
|
||||
def _on_chain_error(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
|
||||
"""Process the Chain Run upon error."""
|
||||
return None
|
||||
|
||||
def _on_tool_start(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
|
||||
"""Process the Tool Run upon start."""
|
||||
return None
|
||||
|
||||
def _on_tool_end(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
|
||||
"""Process the Tool Run."""
|
||||
return None
|
||||
|
||||
def _on_tool_error(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
|
||||
"""Process the Tool Run upon error."""
|
||||
return None
|
||||
|
||||
def _on_chat_model_start(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
|
||||
"""Process the Chat Model Run upon start."""
|
||||
return None
|
||||
|
||||
def _on_retriever_start(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
|
||||
"""Process the Retriever Run upon start."""
|
||||
return None
|
||||
|
||||
def _on_retriever_end(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
|
||||
"""Process the Retriever Run."""
|
||||
return None
|
||||
|
||||
def _on_retriever_error(self, run: Run) -> Union[None, Coroutine[Any, Any, None]]:
|
||||
"""Process the Retriever Run upon error."""
|
||||
return None
|
||||
|
@ -144,11 +144,7 @@ class EvaluatorCallbackHandler(BaseTracer):
|
||||
example_id = str(run.reference_example_id)
|
||||
with self.lock:
|
||||
for res in eval_results:
|
||||
run_id = (
|
||||
str(getattr(res, "target_run_id"))
|
||||
if hasattr(res, "target_run_id")
|
||||
else str(run.id)
|
||||
)
|
||||
run_id = str(getattr(res, "target_run_id", run.id))
|
||||
self.logged_eval_results.setdefault((run_id, example_id), []).append(
|
||||
res
|
||||
)
|
||||
@ -179,11 +175,9 @@ class EvaluatorCallbackHandler(BaseTracer):
|
||||
source_info_: Dict[str, Any] = {}
|
||||
if res.evaluator_info:
|
||||
source_info_ = {**res.evaluator_info, **source_info_}
|
||||
run_id_ = (
|
||||
getattr(res, "target_run_id")
|
||||
if hasattr(res, "target_run_id") and res.target_run_id is not None
|
||||
else run.id
|
||||
)
|
||||
run_id_ = getattr(res, "target_run_id", None)
|
||||
if run_id_ is None:
|
||||
run_id_ = run.id
|
||||
self.client.create_feedback(
|
||||
run_id_,
|
||||
res.key,
|
||||
|
@ -22,6 +22,7 @@ def RunTypeEnum() -> Type[RunTypeEnumDep]:
|
||||
"RunTypeEnum is deprecated. Please directly use a string instead"
|
||||
" (e.g. 'llm', 'chain', 'tool').",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
return RunTypeEnumDep
|
||||
|
||||
|
@ -62,8 +62,8 @@ def py_anext(
|
||||
__anext__ = cast(
|
||||
Callable[[AsyncIterator[T]], Awaitable[T]], type(iterator).__anext__
|
||||
)
|
||||
except AttributeError:
|
||||
raise TypeError(f"{iterator!r} is not an async iterator")
|
||||
except AttributeError as e:
|
||||
raise TypeError(f"{iterator!r} is not an async iterator") from e
|
||||
|
||||
if default is _no_default:
|
||||
return __anext__(iterator)
|
||||
|
@ -182,7 +182,7 @@ def parse_and_check_json_markdown(text: str, expected_keys: List[str]) -> dict:
|
||||
try:
|
||||
json_obj = parse_json_markdown(text)
|
||||
except json.JSONDecodeError as e:
|
||||
raise OutputParserException(f"Got invalid JSON object. Error: {e}")
|
||||
raise OutputParserException(f"Got invalid JSON object. Error: {e}") from e
|
||||
for key in expected_keys:
|
||||
if key not in json_obj:
|
||||
raise OutputParserException(
|
||||
|
@ -21,7 +21,9 @@ def try_load_from_hub(
|
||||
) -> Any:
|
||||
warnings.warn(
|
||||
"Loading from the deprecated github-based Hub is no longer supported. "
|
||||
"Please use the new LangChain Hub at https://smith.langchain.com/hub instead."
|
||||
"Please use the new LangChain Hub at https://smith.langchain.com/hub instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
# return None, which indicates that we shouldn't load from old hub
|
||||
# and might just be a filepath for e.g. load_chain
|
||||
|
@ -3,13 +3,17 @@ Adapted from https://github.com/noahmorrison/chevron
|
||||
MIT License
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from types import MappingProxyType
|
||||
from typing import (
|
||||
Any,
|
||||
Dict,
|
||||
Iterator,
|
||||
List,
|
||||
Literal,
|
||||
Mapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
@ -22,7 +26,7 @@ from typing_extensions import TypeAlias
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
Scopes: TypeAlias = List[Union[Literal[False, 0], Dict[str, Any]]]
|
||||
Scopes: TypeAlias = List[Union[Literal[False, 0], Mapping[str, Any]]]
|
||||
|
||||
|
||||
# Globals
|
||||
@ -152,8 +156,8 @@ def parse_tag(template: str, l_del: str, r_del: str) -> Tuple[Tuple[str, str], s
|
||||
# Get the tag
|
||||
try:
|
||||
tag, template = template.split(r_del, 1)
|
||||
except ValueError:
|
||||
raise ChevronError("unclosed tag " f"at line {_CURRENT_LINE}")
|
||||
except ValueError as e:
|
||||
raise ChevronError("unclosed tag " f"at line {_CURRENT_LINE}") from e
|
||||
|
||||
# Find the type meaning of the first character
|
||||
tag_type = tag_types.get(tag[0], "variable")
|
||||
@ -279,12 +283,12 @@ def tokenize(
|
||||
# is the same as us
|
||||
try:
|
||||
last_section = open_sections.pop()
|
||||
except IndexError:
|
||||
except IndexError as e:
|
||||
raise ChevronError(
|
||||
f'Trying to close tag "{tag_key}"\n'
|
||||
"Looks like it was not opened.\n"
|
||||
f"line {_CURRENT_LINE + 1}"
|
||||
)
|
||||
) from e
|
||||
if tag_key != last_section:
|
||||
# Otherwise we need to complain
|
||||
raise ChevronError(
|
||||
@ -411,7 +415,7 @@ def _get_key(
|
||||
return ""
|
||||
|
||||
|
||||
def _get_partial(name: str, partials_dict: Dict[str, str]) -> str:
|
||||
def _get_partial(name: str, partials_dict: Mapping[str, str]) -> str:
|
||||
"""Load a partial"""
|
||||
try:
|
||||
# Maybe the partial is in the dictionary
|
||||
@ -425,11 +429,13 @@ def _get_partial(name: str, partials_dict: Dict[str, str]) -> str:
|
||||
#
|
||||
g_token_cache: Dict[str, List[Tuple[str, str]]] = {}
|
||||
|
||||
EMPTY_DICT: MappingProxyType[str, str] = MappingProxyType({})
|
||||
|
||||
|
||||
def render(
|
||||
template: Union[str, List[Tuple[str, str]]] = "",
|
||||
data: Dict[str, Any] = {},
|
||||
partials_dict: Dict[str, str] = {},
|
||||
data: Mapping[str, Any] = EMPTY_DICT,
|
||||
partials_dict: Mapping[str, str] = EMPTY_DICT,
|
||||
padding: str = "",
|
||||
def_ldel: str = "{{",
|
||||
def_rdel: str = "}}",
|
||||
|
@ -131,12 +131,12 @@ def guard_import(
|
||||
"""
|
||||
try:
|
||||
module = importlib.import_module(module_name, package)
|
||||
except (ImportError, ModuleNotFoundError):
|
||||
except (ImportError, ModuleNotFoundError) as e:
|
||||
pip_name = pip_name or module_name.split(".")[0].replace("_", "-")
|
||||
raise ImportError(
|
||||
f"Could not import {module_name} python package. "
|
||||
f"Please install it with `pip install {pip_name}`."
|
||||
)
|
||||
) from e
|
||||
return module
|
||||
|
||||
|
||||
@ -235,7 +235,8 @@ def build_extra_kwargs(
|
||||
warnings.warn(
|
||||
f"""WARNING! {field_name} is not default parameter.
|
||||
{field_name} was transferred to model_kwargs.
|
||||
Please confirm that {field_name} is what you intended."""
|
||||
Please confirm that {field_name} is what you intended.""",
|
||||
stacklevel=7,
|
||||
)
|
||||
extra_kwargs[field_name] = values.pop(field_name)
|
||||
|
||||
|
@ -558,7 +558,8 @@ class VectorStore(ABC):
|
||||
):
|
||||
warnings.warn(
|
||||
"Relevance scores must be between"
|
||||
f" 0 and 1, got {docs_and_similarities}"
|
||||
f" 0 and 1, got {docs_and_similarities}",
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
if score_threshold is not None:
|
||||
@ -568,7 +569,7 @@ class VectorStore(ABC):
|
||||
if similarity >= score_threshold
|
||||
]
|
||||
if len(docs_and_similarities) == 0:
|
||||
warnings.warn(
|
||||
logger.warning(
|
||||
"No relevant docs were retrieved using the relevance score"
|
||||
f" threshold {score_threshold}"
|
||||
)
|
||||
@ -605,7 +606,8 @@ class VectorStore(ABC):
|
||||
):
|
||||
warnings.warn(
|
||||
"Relevance scores must be between"
|
||||
f" 0 and 1, got {docs_and_similarities}"
|
||||
f" 0 and 1, got {docs_and_similarities}",
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
if score_threshold is not None:
|
||||
@ -615,7 +617,7 @@ class VectorStore(ABC):
|
||||
if similarity >= score_threshold
|
||||
]
|
||||
if len(docs_and_similarities) == 0:
|
||||
warnings.warn(
|
||||
logger.warning(
|
||||
"No relevant docs were retrieved using the relevance score"
|
||||
f" threshold {score_threshold}"
|
||||
)
|
||||
|
@ -435,11 +435,11 @@ class InMemoryVectorStore(VectorStore):
|
||||
|
||||
try:
|
||||
import numpy as np
|
||||
except ImportError:
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"numpy must be installed to use max_marginal_relevance_search "
|
||||
"pip install numpy"
|
||||
)
|
||||
) from e
|
||||
|
||||
mmr_chosen_indices = maximal_marginal_relevance(
|
||||
np.array(embedding, dtype=np.float32),
|
||||
|
@ -34,11 +34,11 @@ def _cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray:
|
||||
"""
|
||||
try:
|
||||
import numpy as np
|
||||
except ImportError:
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"cosine_similarity requires numpy to be installed. "
|
||||
"Please install numpy with `pip install numpy`."
|
||||
)
|
||||
) from e
|
||||
|
||||
if len(X) == 0 or len(Y) == 0:
|
||||
return np.array([])
|
||||
@ -93,11 +93,11 @@ def maximal_marginal_relevance(
|
||||
"""
|
||||
try:
|
||||
import numpy as np
|
||||
except ImportError:
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"maximal_marginal_relevance requires numpy to be installed. "
|
||||
"Please install numpy with `pip install numpy`."
|
||||
)
|
||||
) from e
|
||||
|
||||
if min(k, len(embedding_list)) <= 0:
|
||||
return []
|
||||
|
@ -41,7 +41,7 @@ python = ">=3.12.4"
|
||||
[tool.poetry.extras]
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = [ "E", "F", "I", "T201", "UP",]
|
||||
select = [ "B", "E", "F", "I", "T201", "UP",]
|
||||
ignore = [ "UP006", "UP007",]
|
||||
|
||||
[tool.coverage.run]
|
||||
|
@ -7,6 +7,7 @@ import pytest
|
||||
|
||||
from langchain_core.callbacks import CallbackManagerForLLMRun
|
||||
from langchain_core.language_models import BaseChatModel, FakeListChatModel
|
||||
from langchain_core.language_models.fake_chat_models import FakeListChatModelError
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
AIMessageChunk,
|
||||
@ -110,7 +111,7 @@ async def test_stream_error_callback() -> None:
|
||||
responses=[message],
|
||||
error_on_chunk_number=i,
|
||||
)
|
||||
with pytest.raises(Exception):
|
||||
with pytest.raises(FakeListChatModelError):
|
||||
cb_async = FakeAsyncCallbackHandler()
|
||||
async for _ in llm.astream("Dummy message", callbacks=[cb_async]):
|
||||
pass
|
||||
|
@ -7,6 +7,7 @@ from langchain_core.callbacks import (
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain_core.language_models import BaseLLM, FakeListLLM, FakeStreamingListLLM
|
||||
from langchain_core.language_models.fake import FakeListLLMError
|
||||
from langchain_core.outputs import Generation, GenerationChunk, LLMResult
|
||||
from langchain_core.tracers.context import collect_runs
|
||||
from tests.unit_tests.fake.callbacks import (
|
||||
@ -108,7 +109,7 @@ async def test_stream_error_callback() -> None:
|
||||
responses=[message],
|
||||
error_on_chunk_number=i,
|
||||
)
|
||||
with pytest.raises(Exception):
|
||||
with pytest.raises(FakeListLLMError):
|
||||
cb_async = FakeAsyncCallbackHandler()
|
||||
async for _ in llm.astream("Dummy message", callbacks=[cb_async]):
|
||||
pass
|
||||
|
@ -3,6 +3,7 @@ from typing import Any, AsyncIterator, Iterator, Tuple
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain_core.exceptions import OutputParserException
|
||||
from langchain_core.output_parsers.json import (
|
||||
SimpleJsonOutputParser,
|
||||
)
|
||||
@ -531,7 +532,7 @@ async def test_partial_text_json_output_parser_diff_async() -> None:
|
||||
|
||||
def test_raises_error() -> None:
|
||||
parser = SimpleJsonOutputParser()
|
||||
with pytest.raises(Exception):
|
||||
with pytest.raises(OutputParserException):
|
||||
parser.invoke("hi")
|
||||
|
||||
|
||||
|
@ -164,13 +164,9 @@ def test_pydantic_output_parser_fail() -> None:
|
||||
pydantic_object=TestModel
|
||||
)
|
||||
|
||||
try:
|
||||
with pytest.raises(OutputParserException) as e:
|
||||
pydantic_parser.parse(DEF_RESULT_FAIL)
|
||||
except OutputParserException as e:
|
||||
print("parse_result:", e) # noqa: T201
|
||||
assert "Failed to parse TestModel from completion" in str(e)
|
||||
else:
|
||||
assert False, "Expected OutputParserException"
|
||||
|
||||
|
||||
def test_pydantic_output_parser_type_inference() -> None:
|
||||
|
@ -28,7 +28,7 @@ def _replace_all_of_with_ref(schema: Any) -> None:
|
||||
del schema["default"]
|
||||
else:
|
||||
# Recursively process nested schemas
|
||||
for key, value in schema.items():
|
||||
for value in schema.values():
|
||||
if isinstance(value, (dict, list)):
|
||||
_replace_all_of_with_ref(value)
|
||||
elif isinstance(schema, list):
|
||||
@ -47,7 +47,7 @@ def _remove_bad_none_defaults(schema: Any) -> None:
|
||||
See difference between Optional and NotRequired types in python.
|
||||
"""
|
||||
if isinstance(schema, dict):
|
||||
for key, value in schema.items():
|
||||
for value in schema.values():
|
||||
if isinstance(value, dict):
|
||||
if "default" in value and value["default"] is None:
|
||||
any_of = value.get("anyOf", [])
|
||||
|
@ -307,7 +307,7 @@ async def test_fallbacks_astream() -> None:
|
||||
runnable = RunnableGenerator(_agenerate_delayed_error).with_fallbacks(
|
||||
[RunnableGenerator(_agenerate)]
|
||||
)
|
||||
async for c in runnable.astream({}):
|
||||
async for _ in runnable.astream({}):
|
||||
pass
|
||||
|
||||
|
||||
@ -373,7 +373,7 @@ def test_fallbacks_getattr() -> None:
|
||||
assert llm_with_fallbacks.foo == 3
|
||||
|
||||
with pytest.raises(AttributeError):
|
||||
llm_with_fallbacks.bar
|
||||
assert llm_with_fallbacks.bar == 4
|
||||
|
||||
|
||||
def test_fallbacks_getattr_runnable_output() -> None:
|
||||
|
@ -1516,7 +1516,7 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None:
|
||||
) == [5, 7]
|
||||
|
||||
assert len(spy.call_args_list) == 2
|
||||
for i, call in enumerate(spy.call_args_list):
|
||||
for call in spy.call_args_list:
|
||||
call_arg = call.args[0]
|
||||
|
||||
if call_arg == "hello":
|
||||
@ -1533,7 +1533,7 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None:
|
||||
assert fake.batch(["hello", "wooorld"], dict(tags=["a-tag"])) == [5, 7]
|
||||
assert len(spy.call_args_list) == 2
|
||||
assert set(call.args[0] for call in spy.call_args_list) == {"hello", "wooorld"}
|
||||
for i, call in enumerate(spy.call_args_list):
|
||||
for call in spy.call_args_list:
|
||||
assert call.args[1].get("tags") == ["a-tag"]
|
||||
assert call.args[1].get("metadata") == {}
|
||||
spy.reset_mock()
|
||||
@ -5205,7 +5205,7 @@ def test_invoke_stream_passthrough_assign_trace() -> None:
|
||||
assert tracer.runs[0].child_runs[0].name == "RunnableParallel<urls>"
|
||||
|
||||
tracer = FakeTracer()
|
||||
for item in chain.stream({"example": [1, 2, 3]}, dict(callbacks=[tracer])):
|
||||
for _ in chain.stream({"example": [1, 2, 3]}, dict(callbacks=[tracer])):
|
||||
pass
|
||||
|
||||
assert tracer.runs[0].name == "RunnableAssign<urls>"
|
||||
@ -5225,7 +5225,7 @@ async def test_ainvoke_astream_passthrough_assign_trace() -> None:
|
||||
assert tracer.runs[0].child_runs[0].name == "RunnableParallel<urls>"
|
||||
|
||||
tracer = FakeTracer()
|
||||
async for item in chain.astream({"example": [1, 2, 3]}, dict(callbacks=[tracer])):
|
||||
async for _ in chain.astream({"example": [1, 2, 3]}, dict(callbacks=[tracer])):
|
||||
pass
|
||||
|
||||
assert tracer.runs[0].name == "RunnableAssign<urls>"
|
||||
|
@ -140,7 +140,7 @@ class LangChainProjectNameTest(unittest.TestCase):
|
||||
projects = []
|
||||
|
||||
def mock_create_run(**kwargs: Any) -> Any:
|
||||
projects.append(kwargs.get("project_name"))
|
||||
projects.append(kwargs.get("project_name")) # noqa: B023
|
||||
return unittest.mock.MagicMock()
|
||||
|
||||
client.create_run = mock_create_run
|
||||
@ -151,6 +151,4 @@ class LangChainProjectNameTest(unittest.TestCase):
|
||||
run_id=UUID("9d878ab3-e5ca-4218-aef6-44cbdc90160a"),
|
||||
)
|
||||
tracer.wait_for_futures()
|
||||
assert (
|
||||
len(projects) == 1 and projects[0] == case.expected_project_name
|
||||
)
|
||||
assert projects == [case.expected_project_name]
|
||||
|
Loading…
Reference in New Issue
Block a user