Suggestions for better debugging (#765)

Please feel free to disregard any changes you disagree with
This commit is contained in:
Amos Ng 2023-01-28 23:05:20 +07:00 committed by GitHub
parent 5198d6f541
commit 6ad360bdef
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 26 additions and 4 deletions

View File

@ -19,8 +19,8 @@ class FakeListLLM(LLM, BaseModel):
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
"""Increment counter, and then return response in that index."""
self.i += 1
print(self.i)
print(self.responses)
print(f"=== Mock Response #{self.i} ===")
print(self.responses[self.i])
return self.responses[self.i]
@property
@ -92,7 +92,10 @@ def test_agent_with_callbacks_global() -> None:
output = agent.run("when was langchain made")
assert output == "curses foiled again"
# 1 top level chain run, 2 LLMChain runs, 2 LLM runs, 1 tool run
# 1 top level chain run runs, 2 LLMChain runs, 2 LLM runs, 1 tool run
assert handler.chain_starts == handler.chain_ends == 3
assert handler.llm_starts == handler.llm_ends == 2
assert handler.tool_starts == handler.tool_ends == 1
assert handler.starts == 6
# 1 extra agent end
assert handler.ends == 7
@ -130,7 +133,10 @@ def test_agent_with_callbacks_local() -> None:
output = agent.run("when was langchain made")
assert output == "curses foiled again"
# 1 top level chain run, 2 LLMChain runs, 2 LLM runs, 1 tool run
# 1 top level chain run, 2 LLMChain starts, 2 LLM runs, 1 tool run
assert handler.chain_starts == handler.chain_ends == 3
assert handler.llm_starts == handler.llm_ends == 2
assert handler.tool_starts == handler.tool_ends == 1
assert handler.starts == 6
# 1 extra agent end
assert handler.ends == 7

View File

@ -39,14 +39,25 @@ class FakeCallbackHandler(BaseModel, BaseCallbackHandler):
"""Whether to ignore agent callbacks."""
return self.ignore_agent_
# add finer-grained counters for easier debugging of failing tests
chain_starts: int = 0
chain_ends: int = 0
llm_starts: int = 0
llm_ends: int = 0
tool_starts: int = 0
tool_ends: int = 0
agent_ends: int = 0
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
"""Run when LLM starts running."""
self.llm_starts += 1
self.starts += 1
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
"""Run when LLM ends running."""
self.llm_ends += 1
self.ends += 1
def on_llm_error(
@ -59,10 +70,12 @@ class FakeCallbackHandler(BaseModel, BaseCallbackHandler):
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
"""Run when chain starts running."""
self.chain_starts += 1
self.starts += 1
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
"""Run when chain ends running."""
self.chain_ends += 1
self.ends += 1
def on_chain_error(
@ -75,10 +88,12 @@ class FakeCallbackHandler(BaseModel, BaseCallbackHandler):
self, serialized: Dict[str, Any], action: AgentAction, **kwargs: Any
) -> None:
"""Run when tool starts running."""
self.tool_starts += 1
self.starts += 1
def on_tool_end(self, output: str, **kwargs: Any) -> None:
"""Run when tool ends running."""
self.tool_ends += 1
self.ends += 1
def on_tool_error(
@ -93,4 +108,5 @@ class FakeCallbackHandler(BaseModel, BaseCallbackHandler):
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
"""Run when agent ends running."""
self.agent_ends += 1
self.ends += 1