core: Add various ruff rules (#26836)

Adds
- ASYNC
- COM
- DJ
- EXE
- FLY
- FURB
- ICN
- INT
- LOG
- NPY
- PD
- Q
- RSE
- SLOT
- T10
- TID
- YTT

Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
Christophe Bornet 2024-10-08 00:30:27 +02:00 committed by GitHub
parent 5c826faece
commit 16f5fdb38b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
27 changed files with 108 additions and 53 deletions

View File

@ -53,7 +53,7 @@ class FakeEmbeddings(Embeddings, BaseModel):
def _get_embedding(self) -> list[float]:
import numpy as np # type: ignore[import-not-found, import-untyped]
return list(np.random.normal(size=self.size))
return list(np.random.default_rng().normal(size=self.size))
def embed_documents(self, texts: list[str]) -> list[list[float]]:
return [self._get_embedding() for _ in texts]
@ -109,8 +109,8 @@ class DeterministicFakeEmbedding(Embeddings, BaseModel):
import numpy as np # type: ignore[import-not-found, import-untyped]
# set the seed for the random generator
np.random.seed(seed)
return list(np.random.normal(size=self.size))
rng = np.random.default_rng(seed)
return list(rng.normal(size=self.size))
def _get_seed(self, text: str) -> int:
"""Get a seed for the random generator, using the hash of the text."""

View File

@ -237,7 +237,7 @@ class BaseLanguageModel(
"""Not implemented on this class."""
# Implement this on child class if there is a way of steering the model to
# generate responses that match a given schema.
raise NotImplementedError()
raise NotImplementedError
@deprecated("0.1.7", alternative="invoke", removal="1.0")
@abstractmethod

View File

@ -977,7 +977,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
raise NotImplementedError()
raise NotImplementedError
async def _astream(
self,
@ -1112,7 +1112,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
],
**kwargs: Any,
) -> Runnable[LanguageModelInput, BaseMessage]:
raise NotImplementedError()
raise NotImplementedError
def with_structured_output(
self,

View File

@ -698,7 +698,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
Returns:
An iterator of GenerationChunks.
"""
raise NotImplementedError()
raise NotImplementedError
async def _astream(
self,

View File

@ -151,7 +151,7 @@ class JsonOutputFunctionsParser(BaseCumulativeTransformOutputParser[Any]):
Returns:
The parsed JSON object.
"""
raise NotImplementedError()
raise NotImplementedError
class JsonKeyOutputFunctionsParser(JsonOutputFunctionsParser):

View File

@ -207,7 +207,7 @@ class JsonOutputToolsParser(BaseCumulativeTransformOutputParser[Any]):
Returns:
The parsed tool calls.
"""
raise NotImplementedError()
raise NotImplementedError
class JsonOutputKeyToolsParser(JsonOutputToolsParser):

View File

@ -106,7 +106,7 @@ class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]):
Returns:
The diff between the previous and current parsed output.
"""
raise NotImplementedError()
raise NotImplementedError
def _transform(self, input: Iterator[Union[str, BaseMessage]]) -> Iterator[Any]:
prev_parsed = None

View File

@ -1336,7 +1336,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
Args:
file_path: path to file.
"""
raise NotImplementedError()
raise NotImplementedError
def pretty_repr(self, html: bool = False) -> str:
"""Human-readable representation.

View File

@ -464,4 +464,4 @@ class FewShotChatMessagePromptTemplate(
Returns:
A pretty representation of the prompt template.
"""
raise NotImplementedError()
raise NotImplementedError

View File

@ -132,4 +132,4 @@ class ImagePromptTemplate(BasePromptTemplate[ImageURL]):
Returns:
A pretty representation of the prompt.
"""
raise NotImplementedError()
raise NotImplementedError

View File

@ -1,3 +1,4 @@
import asyncio
import base64
import re
from dataclasses import asdict
@ -289,9 +290,14 @@ async def _render_mermaid_using_pyppeteer(
img_bytes = await page.screenshot({"fullPage": False})
await browser.close()
def write_to_file(path: str, bytes: bytes) -> None:
with open(path, "wb") as file:
file.write(bytes)
if output_file_path is not None:
with open(output_file_path, "wb") as file:
file.write(img_bytes)
await asyncio.get_event_loop().run_in_executor(
None, write_to_file, output_file_path, img_bytes
)
return img_bytes

View File

@ -453,7 +453,7 @@ def secret_from_env(
return SecretStr(os.environ[key])
if isinstance(default, str):
return SecretStr(default)
elif isinstance(default, type(None)):
elif default is None:
return None
else:
if error_message:

View File

@ -44,8 +44,42 @@ python = ">=3.12.4"
[tool.poetry.extras]
[tool.ruff.lint]
select = [ "B", "C4", "E", "EM", "F", "I", "N", "PIE", "SIM", "T201", "UP", "W",]
ignore = [ "UP007", "W293",]
select = [
"ASYNC",
"B",
"C4",
"COM",
"DJ",
"E",
"EM",
"EXE",
"F",
"FLY",
"FURB",
"I",
"ICN",
"INT",
"LOG",
"N",
"NPY",
"PD",
"PIE",
"Q",
"RSE",
"SIM",
"SLOT",
"T10",
"T201",
"TID",
"UP",
"W",
"YTT"
]
ignore = [
"COM812", # Messes with the formatter
"UP007", # Incompatible with pydantic + Python 3.9
"W293", #
]
[tool.coverage.run]
omit = [ "tests/*",]

View File

@ -17,7 +17,7 @@ def test_add_message_implementation_only() -> None:
def clear(self) -> None:
"""Clear the store."""
raise NotImplementedError()
raise NotImplementedError
store: list[BaseMessage] = []
chat_history = SampleChatHistory(store=store)
@ -50,7 +50,7 @@ def test_bulk_message_implementation_only() -> None:
def clear(self) -> None:
"""Clear the store."""
raise NotImplementedError()
raise NotImplementedError
chat_history = BulkAddHistory(store=store)
chat_history.add_message(HumanMessage(content="Hello"))

View File

@ -165,7 +165,7 @@ async def test_astream_implementation_fallback_to_stream() -> None:
**kwargs: Any,
) -> ChatResult:
"""Top Level call"""
raise NotImplementedError()
raise NotImplementedError
def _stream(
self,
@ -210,7 +210,7 @@ async def test_astream_implementation_uses_astream() -> None:
**kwargs: Any,
) -> ChatResult:
"""Top Level call"""
raise NotImplementedError()
raise NotImplementedError
async def _astream( # type: ignore
self,

View File

@ -161,7 +161,7 @@ async def test_astream_implementation_fallback_to_stream() -> None:
**kwargs: Any,
) -> LLMResult:
"""Top Level call"""
raise NotImplementedError()
raise NotImplementedError
def _stream(
self,
@ -198,7 +198,7 @@ async def test_astream_implementation_uses_astream() -> None:
**kwargs: Any,
) -> LLMResult:
"""Top Level call"""
raise NotImplementedError()
raise NotImplementedError
async def _astream(
self,

View File

@ -59,7 +59,7 @@ def test_base_transform_output_parser() -> None:
def parse(self, text: str) -> str:
"""Parse a single string into a specific format."""
raise NotImplementedError()
raise NotImplementedError
def parse_result(
self, result: list[Generation], *, partial: bool = False

View File

@ -61,13 +61,13 @@ def chain() -> Runnable:
def _raise_error(inputs: dict) -> str:
raise ValueError()
raise ValueError
def _dont_raise_error(inputs: dict) -> str:
if "exception" in inputs:
return "bar"
raise ValueError()
raise ValueError
@pytest.fixture()
@ -99,11 +99,11 @@ def _runnable(inputs: dict) -> str:
if inputs["text"] == "foo":
return "first"
if "exception" not in inputs:
raise ValueError()
raise ValueError
if inputs["text"] == "bar":
return "second"
if isinstance(inputs["exception"], ValueError):
raise RuntimeError()
raise RuntimeError
return "third"
@ -251,13 +251,13 @@ def _generate(input: Iterator) -> Iterator[str]:
def _generate_immediate_error(input: Iterator) -> Iterator[str]:
raise ValueError()
raise ValueError
yield ""
def _generate_delayed_error(input: Iterator) -> Iterator[str]:
yield ""
raise ValueError()
raise ValueError
def test_fallbacks_stream() -> None:
@ -279,13 +279,13 @@ async def _agenerate(input: AsyncIterator) -> AsyncIterator[str]:
async def _agenerate_immediate_error(input: AsyncIterator) -> AsyncIterator[str]:
raise ValueError()
raise ValueError
yield ""
async def _agenerate_delayed_error(input: AsyncIterator) -> AsyncIterator[str]:
yield ""
raise ValueError()
raise ValueError
async def test_fallbacks_astream() -> None:

View File

@ -356,7 +356,7 @@ def test_runnable_get_graph_with_invalid_input_type() -> None:
@property
@override
def InputType(self) -> type:
raise TypeError()
raise TypeError
@override
def invoke(
@ -381,7 +381,7 @@ def test_runnable_get_graph_with_invalid_output_type() -> None:
@property
@override
def OutputType(self) -> type:
raise TypeError()
raise TypeError
@override
def invoke(

View File

@ -653,7 +653,7 @@ def test_with_types_with_type_generics() -> None:
def foo(x: int) -> None:
"""Add one to the input."""
raise NotImplementedError()
raise NotImplementedError
# Try specifying some
RunnableLambda(foo).with_types(
@ -3980,7 +3980,7 @@ def test_seq_batch_return_exceptions(mocker: MockerFixture) -> None:
def invoke(
self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> Any:
raise NotImplementedError()
raise NotImplementedError
def _batch(
self,
@ -4101,7 +4101,7 @@ async def test_seq_abatch_return_exceptions(mocker: MockerFixture) -> None:
def invoke(
self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> Any:
raise NotImplementedError()
raise NotImplementedError
async def _abatch(
self,
@ -5352,7 +5352,7 @@ async def test_listeners_async() -> None:
assert value2 in shared_state.values(), "Value not found in the dictionary."
async def test_closing_iterator_doesnt_raise_error() -> None:
def test_closing_iterator_doesnt_raise_error() -> None:
"""Test that closing an iterator calls on_chain_end rather than on_chain_error."""
import time
@ -5361,9 +5361,10 @@ async def test_closing_iterator_doesnt_raise_error() -> None:
from langchain_core.output_parsers import StrOutputParser
on_chain_error_triggered = False
on_chain_end_triggered = False
class MyHandler(BaseCallbackHandler):
async def on_chain_error(
def on_chain_error(
self,
error: BaseException,
*,
@ -5376,6 +5377,17 @@ async def test_closing_iterator_doesnt_raise_error() -> None:
nonlocal on_chain_error_triggered
on_chain_error_triggered = True
def on_chain_end(
self,
outputs: dict[str, Any],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> None:
nonlocal on_chain_end_triggered
on_chain_end_triggered = True
llm = GenericFakeChatModel(messages=iter(["hi there"]))
chain = llm | StrOutputParser()
chain_ = chain.with_config({"callbacks": [MyHandler()]})
@ -5386,6 +5398,7 @@ async def test_closing_iterator_doesnt_raise_error() -> None:
# Wait for a bit to make sure that the callback is called.
time.sleep(0.05)
assert on_chain_error_triggered is False
assert on_chain_end_triggered is True
def test_pydantic_protected_namespaces() -> None:

View File

@ -2067,7 +2067,7 @@ class StreamingRunnable(Runnable[Input, Output]):
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> Iterator[Output]:
raise NotImplementedError()
raise NotImplementedError
async def astream(
self,

View File

@ -19,7 +19,7 @@ from langchain_core.runnables.utils import (
[
(lambda x: x * 2, "lambda x: x * 2"),
(lambda a, b: a + b, "lambda a, b: a + b"),
(lambda x: x if x > 0 else 0, "lambda x: x if x > 0 else 0"),
(lambda x: x if x > 0 else 0, "lambda x: x if x > 0 else 0"), # noqa: FURB136
],
)
def test_get_lambda_source(func: Callable, expected_source: str) -> None:

View File

@ -5,6 +5,8 @@ from langchain_core.messages import AIMessage, AIMessageChunk, HumanMessage
class AnyStr(str):
__slots__ = ()
def __eq__(self, other: Any) -> bool:
return isinstance(other, str)

View File

@ -346,7 +346,7 @@ class TestGetBufferString(unittest.TestCase):
self.chat_msg,
self.tool_calls_msg,
]
expected_output = "\n".join(
expected_output = "\n".join( # noqa: FLY002
[
"Human: human",
"AI: ai",

View File

@ -401,7 +401,7 @@ def test_structured_tool_from_function_docstring() -> None:
bar: the bar value
baz: the baz value
"""
raise NotImplementedError()
raise NotImplementedError
structured_tool = StructuredTool.from_function(foo)
assert structured_tool.name == "foo"
@ -435,7 +435,7 @@ def test_structured_tool_from_function_docstring_complex_args() -> None:
bar: int
baz: List[str]
"""
raise NotImplementedError()
raise NotImplementedError
structured_tool = StructuredTool.from_function(foo)
assert structured_tool.name == "foo"
@ -781,7 +781,7 @@ def test_structured_tool_from_function() -> None:
bar: the bar value
baz: the baz value
"""
raise NotImplementedError()
raise NotImplementedError
structured_tool = StructuredTool.from_function(foo)
assert structured_tool.name == "foo"
@ -854,7 +854,7 @@ def test_validation_error_handling_non_validation_error(
self,
tool_input: Union[str, dict],
) -> Union[str, dict[str, Any]]:
raise NotImplementedError()
raise NotImplementedError
def _run(self) -> str:
return "dummy"
@ -916,7 +916,7 @@ async def test_async_validation_error_handling_non_validation_error(
self,
tool_input: Union[str, dict],
) -> Union[str, dict[str, Any]]:
raise NotImplementedError()
raise NotImplementedError
def _run(self) -> str:
return "dummy"

View File

@ -39,7 +39,7 @@ async def test_inmemory_similarity_search() -> None:
output = await store.asimilarity_search("bar", k=2)
assert output == [
_any_id_document(page_content="bar"),
_any_id_document(page_content="baz"),
_any_id_document(page_content="foo"),
]
@ -81,7 +81,7 @@ async def test_inmemory_mmr() -> None:
output = docsearch.max_marginal_relevance_search("foo", k=10, lambda_mult=0.1)
assert len(output) == len(texts)
assert output[0] == _any_id_document(page_content="foo")
assert output[1] == _any_id_document(page_content="foy")
assert output[1] == _any_id_document(page_content="fou")
# Check async version
output = await docsearch.amax_marginal_relevance_search(
@ -89,7 +89,7 @@ async def test_inmemory_mmr() -> None:
)
assert len(output) == len(texts)
assert output[0] == _any_id_document(page_content="foo")
assert output[1] == _any_id_document(page_content="foy")
assert output[1] == _any_id_document(page_content="fou")
async def test_inmemory_dump_load(tmp_path: Path) -> None:

View File

@ -63,7 +63,7 @@ class CustomAddTextsVectorstore(VectorStore):
def similarity_search(
self, query: str, k: int = 4, **kwargs: Any
) -> list[Document]:
raise NotImplementedError()
raise NotImplementedError
class CustomAddDocumentsVectorstore(VectorStore):
@ -107,7 +107,7 @@ class CustomAddDocumentsVectorstore(VectorStore):
def similarity_search(
self, query: str, k: int = 4, **kwargs: Any
) -> list[Document]:
raise NotImplementedError()
raise NotImplementedError
@pytest.mark.parametrize(