Mirror of https://github.com/hwchase17/langchain (synced 2024-11-13 19:10:52 +00:00)
core: Add various ruff rules (#26836)
Adds
- ASYNC
- COM
- DJ
- EXE
- FLY
- FURB
- ICN
- INT
- LOG
- NPY
- PD
- Q
- RSE
- SLOT
- T10
- TID
- YTT

Co-authored-by: Erick Friis <erick@langchain.dev>
Parent: 5c826faece
Commit: 16f5fdb38b
@@ -53,7 +53,7 @@ class FakeEmbeddings(Embeddings, BaseModel):
     def _get_embedding(self) -> list[float]:
         import numpy as np  # type: ignore[import-not-found, import-untyped]

-        return list(np.random.normal(size=self.size))
+        return list(np.random.default_rng().normal(size=self.size))

     def embed_documents(self, texts: list[str]) -> list[list[float]]:
         return [self._get_embedding() for _ in texts]
@@ -109,8 +109,8 @@ class DeterministicFakeEmbedding(Embeddings, BaseModel):
         import numpy as np  # type: ignore[import-not-found, import-untyped]

         # set the seed for the random generator
-        np.random.seed(seed)
-        return list(np.random.normal(size=self.size))
+        rng = np.random.default_rng(seed)
+        return list(rng.normal(size=self.size))

     def _get_seed(self, text: str) -> int:
         """Get a seed for the random generator, using the hash of the text."""
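The fake-embeddings hunks above follow the NPY rules, which steer code away from the legacy global NumPy random API toward an explicit Generator. A rough, self-contained illustration of the difference (not part of the commit): both forms are reproducible for a fixed seed, but they draw from different underlying generators, which is why some test expectations further down change as well.

import numpy as np

# Legacy pattern the NPY rules flag: seeding shared global state.
np.random.seed(42)
legacy_draw = np.random.normal(size=3)

# Preferred pattern: an explicit, locally seeded Generator.
rng = np.random.default_rng(42)
modern_draw = rng.normal(size=3)

# Each stream is deterministic for its seed, but the two streams differ.
print(legacy_draw, modern_draw)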
@@ -237,7 +237,7 @@ class BaseLanguageModel(
         """Not implemented on this class."""
         # Implement this on child class if there is a way of steering the model to
         # generate responses that match a given schema.
-        raise NotImplementedError()
+        raise NotImplementedError

     @deprecated("0.1.7", alternative="invoke", removal="1.0")
     @abstractmethod
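This is the first of many hunks driven by the RSE rules (RSE102: unnecessary parentheses on a raised exception with no arguments). The two spellings behave identically, since Python instantiates a bare exception class at raise time; a minimal standalone sketch:

def not_ready() -> None:
    # Flagged form: the empty call parentheses add nothing.
    raise NotImplementedError()


def not_ready_clean() -> None:
    # Preferred form: raise the class itself; Python creates the
    # instance with no arguments when the exception propagates.
    raise NotImplementedError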
@@ -977,7 +977,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> Iterator[ChatGenerationChunk]:
-        raise NotImplementedError()
+        raise NotImplementedError

     async def _astream(
         self,
@@ -1112,7 +1112,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
         ],
         **kwargs: Any,
     ) -> Runnable[LanguageModelInput, BaseMessage]:
-        raise NotImplementedError()
+        raise NotImplementedError

     def with_structured_output(
         self,
@@ -698,7 +698,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
         Returns:
             An iterator of GenerationChunks.
         """
-        raise NotImplementedError()
+        raise NotImplementedError

     async def _astream(
         self,
@@ -151,7 +151,7 @@ class JsonOutputFunctionsParser(BaseCumulativeTransformOutputParser[Any]):
         Returns:
             The parsed JSON object.
         """
-        raise NotImplementedError()
+        raise NotImplementedError


 class JsonKeyOutputFunctionsParser(JsonOutputFunctionsParser):
@@ -207,7 +207,7 @@ class JsonOutputToolsParser(BaseCumulativeTransformOutputParser[Any]):
         Returns:
             The parsed tool calls.
         """
-        raise NotImplementedError()
+        raise NotImplementedError


 class JsonOutputKeyToolsParser(JsonOutputToolsParser):
@@ -106,7 +106,7 @@ class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]):
         Returns:
             The diff between the previous and current parsed output.
         """
-        raise NotImplementedError()
+        raise NotImplementedError

     def _transform(self, input: Iterator[Union[str, BaseMessage]]) -> Iterator[Any]:
         prev_parsed = None
@@ -1336,7 +1336,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
         Args:
             file_path: path to file.
         """
-        raise NotImplementedError()
+        raise NotImplementedError

     def pretty_repr(self, html: bool = False) -> str:
         """Human-readable representation.
@@ -464,4 +464,4 @@ class FewShotChatMessagePromptTemplate(
         Returns:
             A pretty representation of the prompt template.
         """
-        raise NotImplementedError()
+        raise NotImplementedError
@@ -132,4 +132,4 @@ class ImagePromptTemplate(BasePromptTemplate[ImageURL]):
         Returns:
             A pretty representation of the prompt.
         """
-        raise NotImplementedError()
+        raise NotImplementedError
@@ -1,3 +1,4 @@
+import asyncio
 import base64
 import re
 from dataclasses import asdict
@@ -289,9 +290,14 @@ async def _render_mermaid_using_pyppeteer(
     img_bytes = await page.screenshot({"fullPage": False})
     await browser.close()

+    def write_to_file(path: str, bytes: bytes) -> None:
+        with open(path, "wb") as file:
+            file.write(bytes)
+
     if output_file_path is not None:
-        with open(output_file_path, "wb") as file:
-            file.write(img_bytes)
+        await asyncio.get_event_loop().run_in_executor(
+            None, write_to_file, output_file_path, img_bytes
+        )

     return img_bytes

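The mermaid-rendering hunk reflects the ASYNC rules, which flag blocking file I/O inside async functions. A minimal sketch of the same offloading pattern in isolation (the file name is just a placeholder):

import asyncio


def _write(path: str, data: bytes) -> None:
    # Ordinary blocking file I/O, kept in a plain function.
    with open(path, "wb") as f:
        f.write(data)


async def save(path: str, data: bytes) -> None:
    # Run the blocking write in the default thread pool so the event
    # loop is not stalled, mirroring the run_in_executor call above.
    loop = asyncio.get_running_loop()
    await loop.run_in_executor(None, _write, path, data)


asyncio.run(save("output.bin", b"\x00\x01"))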
@@ -453,7 +453,7 @@ def secret_from_env(
                 return SecretStr(os.environ[key])
         if isinstance(default, str):
             return SecretStr(default)
-        elif isinstance(default, type(None)):
+        elif default is None:
             return None
         else:
             if error_message:
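The utils hunk swaps an isinstance check against type(None) for a plain identity test, the form the FURB rules prefer when testing for None. Outside the diff, the same idea looks like this (hypothetical function, assumed values):

from typing import Optional


def describe_default(default: Optional[str]) -> str:
    # None is a singleton, so an identity check is the idiomatic test;
    # isinstance(default, type(None)) says the same thing indirectly.
    if default is None:
        return "no default configured"
    return f"default is {default!r}"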
@@ -44,8 +44,42 @@ python = ">=3.12.4"
 [tool.poetry.extras]

 [tool.ruff.lint]
-select = [ "B", "C4", "E", "EM", "F", "I", "N", "PIE", "SIM", "T201", "UP", "W",]
-ignore = [ "UP007", "W293",]
+select = [
+    "ASYNC",
+    "B",
+    "C4",
+    "COM",
+    "DJ",
+    "E",
+    "EM",
+    "EXE",
+    "F",
+    "FLY",
+    "FURB",
+    "I",
+    "ICN",
+    "INT",
+    "LOG",
+    "N",
+    "NPY",
+    "PD",
+    "PIE",
+    "Q",
+    "RSE",
+    "SIM",
+    "SLOT",
+    "T10",
+    "T201",
+    "TID",
+    "UP",
+    "W",
+    "YTT"
+]
+ignore = [
+    "COM812", # Messes with the formatter
+    "UP007", # Incompatible with pydantic + Python 3.9
+    "W293", #
+]

 [tool.coverage.run]
 omit = [ "tests/*",]
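For orientation, a small illustrative snippet (not from the repository) showing the kind of findings a few of the newly selected families report when `ruff check` runs with this configuration; the codes are the usual Ruff codes for these families:

import numpy as np


def sketchy(values: list[int]) -> None:
    breakpoint()            # T10: leftover debugger trace
    print(sum(values))      # T201: print call
    np.random.seed(0)       # NPY: legacy global random seeding
    if not values:
        raise ValueError()  # RSE: unnecessary parentheses on the raise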
@@ -17,7 +17,7 @@ def test_add_message_implementation_only() -> None:

         def clear(self) -> None:
             """Clear the store."""
-            raise NotImplementedError()
+            raise NotImplementedError

     store: list[BaseMessage] = []
     chat_history = SampleChatHistory(store=store)
@@ -50,7 +50,7 @@ def test_bulk_message_implementation_only() -> None:

         def clear(self) -> None:
             """Clear the store."""
-            raise NotImplementedError()
+            raise NotImplementedError

     chat_history = BulkAddHistory(store=store)
     chat_history.add_message(HumanMessage(content="Hello"))
@@ -165,7 +165,7 @@ async def test_astream_implementation_fallback_to_stream() -> None:
             **kwargs: Any,
         ) -> ChatResult:
             """Top Level call"""
-            raise NotImplementedError()
+            raise NotImplementedError

         def _stream(
             self,
@@ -210,7 +210,7 @@ async def test_astream_implementation_uses_astream() -> None:
             **kwargs: Any,
         ) -> ChatResult:
             """Top Level call"""
-            raise NotImplementedError()
+            raise NotImplementedError

         async def _astream(  # type: ignore
             self,
@@ -161,7 +161,7 @@ async def test_astream_implementation_fallback_to_stream() -> None:
             **kwargs: Any,
         ) -> LLMResult:
             """Top Level call"""
-            raise NotImplementedError()
+            raise NotImplementedError

         def _stream(
             self,
@@ -198,7 +198,7 @@ async def test_astream_implementation_uses_astream() -> None:
             **kwargs: Any,
         ) -> LLMResult:
             """Top Level call"""
-            raise NotImplementedError()
+            raise NotImplementedError

         async def _astream(
             self,
@@ -59,7 +59,7 @@ def test_base_transform_output_parser() -> None:

         def parse(self, text: str) -> str:
             """Parse a single string into a specific format."""
-            raise NotImplementedError()
+            raise NotImplementedError

         def parse_result(
             self, result: list[Generation], *, partial: bool = False
@@ -61,13 +61,13 @@ def chain() -> Runnable:


 def _raise_error(inputs: dict) -> str:
-    raise ValueError()
+    raise ValueError


 def _dont_raise_error(inputs: dict) -> str:
     if "exception" in inputs:
         return "bar"
-    raise ValueError()
+    raise ValueError


 @pytest.fixture()
@@ -99,11 +99,11 @@ def _runnable(inputs: dict) -> str:
     if inputs["text"] == "foo":
         return "first"
     if "exception" not in inputs:
-        raise ValueError()
+        raise ValueError
     if inputs["text"] == "bar":
         return "second"
     if isinstance(inputs["exception"], ValueError):
-        raise RuntimeError()
+        raise RuntimeError
     return "third"


@@ -251,13 +251,13 @@ def _generate(input: Iterator) -> Iterator[str]:


 def _generate_immediate_error(input: Iterator) -> Iterator[str]:
-    raise ValueError()
+    raise ValueError
     yield ""


 def _generate_delayed_error(input: Iterator) -> Iterator[str]:
     yield ""
-    raise ValueError()
+    raise ValueError


 def test_fallbacks_stream() -> None:
@@ -279,13 +279,13 @@ async def _agenerate(input: AsyncIterator) -> AsyncIterator[str]:


 async def _agenerate_immediate_error(input: AsyncIterator) -> AsyncIterator[str]:
-    raise ValueError()
+    raise ValueError
     yield ""


 async def _agenerate_delayed_error(input: AsyncIterator) -> AsyncIterator[str]:
     yield ""
-    raise ValueError()
+    raise ValueError


 async def test_fallbacks_astream() -> None:
@@ -356,7 +356,7 @@ def test_runnable_get_graph_with_invalid_input_type() -> None:
         @property
         @override
         def InputType(self) -> type:
-            raise TypeError()
+            raise TypeError

         @override
         def invoke(
@@ -381,7 +381,7 @@ def test_runnable_get_graph_with_invalid_output_type() -> None:
         @property
         @override
         def OutputType(self) -> type:
-            raise TypeError()
+            raise TypeError

         @override
         def invoke(
@@ -653,7 +653,7 @@ def test_with_types_with_type_generics() -> None:

     def foo(x: int) -> None:
         """Add one to the input."""
-        raise NotImplementedError()
+        raise NotImplementedError

     # Try specifying some
     RunnableLambda(foo).with_types(
@@ -3980,7 +3980,7 @@ def test_seq_batch_return_exceptions(mocker: MockerFixture) -> None:
         def invoke(
             self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any
         ) -> Any:
-            raise NotImplementedError()
+            raise NotImplementedError

         def _batch(
             self,
@@ -4101,7 +4101,7 @@ async def test_seq_abatch_return_exceptions(mocker: MockerFixture) -> None:
         def invoke(
             self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any
         ) -> Any:
-            raise NotImplementedError()
+            raise NotImplementedError

         async def _abatch(
             self,
@@ -5352,7 +5352,7 @@ async def test_listeners_async() -> None:
     assert value2 in shared_state.values(), "Value not found in the dictionary."


-async def test_closing_iterator_doesnt_raise_error() -> None:
+def test_closing_iterator_doesnt_raise_error() -> None:
     """Test that closing an iterator calls on_chain_end rather than on_chain_error."""
     import time

@@ -5361,9 +5361,10 @@ async def test_closing_iterator_doesnt_raise_error() -> None:
     from langchain_core.output_parsers import StrOutputParser

     on_chain_error_triggered = False
+    on_chain_end_triggered = False

     class MyHandler(BaseCallbackHandler):
-        async def on_chain_error(
+        def on_chain_error(
             self,
             error: BaseException,
             *,
@@ -5376,6 +5377,17 @@ async def test_closing_iterator_doesnt_raise_error() -> None:
             nonlocal on_chain_error_triggered
             on_chain_error_triggered = True

+        def on_chain_end(
+            self,
+            outputs: dict[str, Any],
+            *,
+            run_id: UUID,
+            parent_run_id: Optional[UUID] = None,
+            **kwargs: Any,
+        ) -> None:
+            nonlocal on_chain_end_triggered
+            on_chain_end_triggered = True
+
     llm = GenericFakeChatModel(messages=iter(["hi there"]))
     chain = llm | StrOutputParser()
     chain_ = chain.with_config({"callbacks": [MyHandler()]})
@@ -5386,6 +5398,7 @@ async def test_closing_iterator_doesnt_raise_error() -> None:
     # Wait for a bit to make sure that the callback is called.
     time.sleep(0.05)
     assert on_chain_error_triggered is False
+    assert on_chain_end_triggered is True


 def test_pydantic_protected_namespaces() -> None:
@@ -2067,7 +2067,7 @@ class StreamingRunnable(Runnable[Input, Output]):
         config: Optional[RunnableConfig] = None,
         **kwargs: Optional[Any],
     ) -> Iterator[Output]:
-        raise NotImplementedError()
+        raise NotImplementedError

     async def astream(
         self,
@@ -19,7 +19,7 @@ from langchain_core.runnables.utils import (
     [
         (lambda x: x * 2, "lambda x: x * 2"),
         (lambda a, b: a + b, "lambda a, b: a + b"),
-        (lambda x: x if x > 0 else 0, "lambda x: x if x > 0 else 0"),
+        (lambda x: x if x > 0 else 0, "lambda x: x if x > 0 else 0"),  # noqa: FURB136
     ],
 )
 def test_get_lambda_source(func: Callable, expected_source: str) -> None:
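FURB136 proposes replacing a conditional expression like `x if x > 0 else 0` with the max builtin; the lambda above keeps its original form (hence the noqa) because the test compares against its exact source text. The suggested rewrite on its own would be:

def clamp_to_zero(x: int) -> int:
    # Equivalent to `x if x > 0 else 0`, which is what FURB136 suggests.
    return max(x, 0)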
@@ -5,6 +5,8 @@ from langchain_core.messages import AIMessage, AIMessageChunk, HumanMessage


 class AnyStr(str):
+    __slots__ = ()
+
     def __eq__(self, other: Any) -> bool:
         return isinstance(other, str)

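The __slots__ addition satisfies the SLOT rules, which ask subclasses of immutable builtins such as str to declare empty slots so instances do not pick up a per-instance __dict__. A standalone illustration:

class Tag(str):
    # Without this, each Tag instance would carry its own __dict__,
    # making the subclass heavier than str for no benefit.
    __slots__ = ()


assert isinstance(Tag("release"), str)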
@@ -346,7 +346,7 @@ class TestGetBufferString(unittest.TestCase):
             self.chat_msg,
             self.tool_calls_msg,
         ]
-        expected_output = "\n".join(
+        expected_output = "\n".join(  # noqa: FLY002
             [
                 "Human: human",
                 "AI: ai",
@@ -401,7 +401,7 @@ def test_structured_tool_from_function_docstring() -> None:
             bar: the bar value
             baz: the baz value
         """
-        raise NotImplementedError()
+        raise NotImplementedError

     structured_tool = StructuredTool.from_function(foo)
     assert structured_tool.name == "foo"
@@ -435,7 +435,7 @@ def test_structured_tool_from_function_docstring_complex_args() -> None:
             bar: int
             baz: List[str]
         """
-        raise NotImplementedError()
+        raise NotImplementedError

     structured_tool = StructuredTool.from_function(foo)
     assert structured_tool.name == "foo"
@@ -781,7 +781,7 @@ def test_structured_tool_from_function() -> None:
             bar: the bar value
             baz: the baz value
         """
-        raise NotImplementedError()
+        raise NotImplementedError

     structured_tool = StructuredTool.from_function(foo)
     assert structured_tool.name == "foo"
@@ -854,7 +854,7 @@ def test_validation_error_handling_non_validation_error(
             self,
             tool_input: Union[str, dict],
         ) -> Union[str, dict[str, Any]]:
-            raise NotImplementedError()
+            raise NotImplementedError

         def _run(self) -> str:
             return "dummy"
@@ -916,7 +916,7 @@ async def test_async_validation_error_handling_non_validation_error(
             self,
             tool_input: Union[str, dict],
         ) -> Union[str, dict[str, Any]]:
-            raise NotImplementedError()
+            raise NotImplementedError

         def _run(self) -> str:
             return "dummy"
@@ -39,7 +39,7 @@ async def test_inmemory_similarity_search() -> None:
     output = await store.asimilarity_search("bar", k=2)
     assert output == [
         _any_id_document(page_content="bar"),
-        _any_id_document(page_content="baz"),
+        _any_id_document(page_content="foo"),
     ]


@@ -81,7 +81,7 @@ async def test_inmemory_mmr() -> None:
     output = docsearch.max_marginal_relevance_search("foo", k=10, lambda_mult=0.1)
     assert len(output) == len(texts)
     assert output[0] == _any_id_document(page_content="foo")
-    assert output[1] == _any_id_document(page_content="foy")
+    assert output[1] == _any_id_document(page_content="fou")

     # Check async version
     output = await docsearch.amax_marginal_relevance_search(
@@ -89,7 +89,7 @@ async def test_inmemory_mmr() -> None:
     )
     assert len(output) == len(texts)
     assert output[0] == _any_id_document(page_content="foo")
-    assert output[1] == _any_id_document(page_content="foy")
+    assert output[1] == _any_id_document(page_content="fou")


 async def test_inmemory_dump_load(tmp_path: Path) -> None:
@@ -63,7 +63,7 @@ class CustomAddTextsVectorstore(VectorStore):
     def similarity_search(
         self, query: str, k: int = 4, **kwargs: Any
     ) -> list[Document]:
-        raise NotImplementedError()
+        raise NotImplementedError


 class CustomAddDocumentsVectorstore(VectorStore):
@@ -107,7 +107,7 @@ class CustomAddDocumentsVectorstore(VectorStore):
     def similarity_search(
         self, query: str, k: int = 4, **kwargs: Any
     ) -> list[Document]:
-        raise NotImplementedError()
+        raise NotImplementedError


 @pytest.mark.parametrize(