diff --git a/tests/unit_tests/agents/test_agent.py b/tests/unit_tests/agents/test_agent.py index c9461e5d..b1a70ffb 100644 --- a/tests/unit_tests/agents/test_agent.py +++ b/tests/unit_tests/agents/test_agent.py @@ -19,8 +19,8 @@ class FakeListLLM(LLM, BaseModel): def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: """Increment counter, and then return response in that index.""" self.i += 1 - print(self.i) - print(self.responses) + print(f"=== Mock Response #{self.i} ===") + print(self.responses[self.i]) return self.responses[self.i] @property @@ -92,7 +92,10 @@ def test_agent_with_callbacks_global() -> None: output = agent.run("when was langchain made") assert output == "curses foiled again" - # 1 top level chain run, 2 LLMChain runs, 2 LLM runs, 1 tool run + # 1 top level chain run, 2 LLMChain runs, 2 LLM runs, 1 tool run + assert handler.chain_starts == handler.chain_ends == 3 + assert handler.llm_starts == handler.llm_ends == 2 + assert handler.tool_starts == handler.tool_ends == 1 assert handler.starts == 6 # 1 extra agent end assert handler.ends == 7 @@ -130,7 +133,10 @@ def test_agent_with_callbacks_local() -> None: output = agent.run("when was langchain made") assert output == "curses foiled again" - # 1 top level chain run, 2 LLMChain runs, 2 LLM runs, 1 tool run + # 1 top level chain run, 2 LLMChain runs, 2 LLM runs, 1 tool run + assert handler.chain_starts == handler.chain_ends == 3 + assert handler.llm_starts == handler.llm_ends == 2 + assert handler.tool_starts == handler.tool_ends == 1 assert handler.starts == 6 # 1 extra agent end assert handler.ends == 7 diff --git a/tests/unit_tests/callbacks/fake_callback_handler.py b/tests/unit_tests/callbacks/fake_callback_handler.py index 36f32b85..7dd0fc01 100644 --- a/tests/unit_tests/callbacks/fake_callback_handler.py +++ b/tests/unit_tests/callbacks/fake_callback_handler.py @@ -39,14 +39,25 @@ class FakeCallbackHandler(BaseModel, BaseCallbackHandler): """Whether to ignore agent 
callbacks.""" return self.ignore_agent_ + # add finer-grained counters for easier debugging of failing tests + chain_starts: int = 0 + chain_ends: int = 0 + llm_starts: int = 0 + llm_ends: int = 0 + tool_starts: int = 0 + tool_ends: int = 0 + agent_ends: int = 0 + def on_llm_start( self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any ) -> None: """Run when LLM starts running.""" + self.llm_starts += 1 self.starts += 1 def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: """Run when LLM ends running.""" + self.llm_ends += 1 self.ends += 1 def on_llm_error( @@ -59,10 +70,12 @@ class FakeCallbackHandler(BaseModel, BaseCallbackHandler): self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any ) -> None: """Run when chain starts running.""" + self.chain_starts += 1 self.starts += 1 def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None: """Run when chain ends running.""" + self.chain_ends += 1 self.ends += 1 def on_chain_error( @@ -75,10 +88,12 @@ class FakeCallbackHandler(BaseModel, BaseCallbackHandler): self, serialized: Dict[str, Any], action: AgentAction, **kwargs: Any ) -> None: """Run when tool starts running.""" + self.tool_starts += 1 self.starts += 1 def on_tool_end(self, output: str, **kwargs: Any) -> None: """Run when tool ends running.""" + self.tool_ends += 1 self.ends += 1 def on_tool_error( @@ -93,4 +108,5 @@ class FakeCallbackHandler(BaseModel, BaseCallbackHandler): def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None: """Run when agent ends running.""" + self.agent_ends += 1 self.ends += 1