forked from Archives/langchain
Suggestions for better debugging (#765)
Please feel free to disregard any changes you disagree with
This commit is contained in:
parent
5198d6f541
commit
6ad360bdef
@ -19,8 +19,8 @@ class FakeListLLM(LLM, BaseModel):
|
||||
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
|
||||
"""Increment counter, and then return response in that index."""
|
||||
self.i += 1
|
||||
print(self.i)
|
||||
print(self.responses)
|
||||
print(f"=== Mock Response #{self.i} ===")
|
||||
print(self.responses[self.i])
|
||||
return self.responses[self.i]
|
||||
|
||||
@property
|
||||
@ -92,7 +92,10 @@ def test_agent_with_callbacks_global() -> None:
|
||||
output = agent.run("when was langchain made")
|
||||
assert output == "curses foiled again"
|
||||
|
||||
# 1 top level chain run, 2 LLMChain runs, 2 LLM runs, 1 tool run
|
||||
# 1 top level chain run runs, 2 LLMChain runs, 2 LLM runs, 1 tool run
|
||||
assert handler.chain_starts == handler.chain_ends == 3
|
||||
assert handler.llm_starts == handler.llm_ends == 2
|
||||
assert handler.tool_starts == handler.tool_ends == 1
|
||||
assert handler.starts == 6
|
||||
# 1 extra agent end
|
||||
assert handler.ends == 7
|
||||
@ -130,7 +133,10 @@ def test_agent_with_callbacks_local() -> None:
|
||||
output = agent.run("when was langchain made")
|
||||
assert output == "curses foiled again"
|
||||
|
||||
# 1 top level chain run, 2 LLMChain runs, 2 LLM runs, 1 tool run
|
||||
# 1 top level chain run, 2 LLMChain starts, 2 LLM runs, 1 tool run
|
||||
assert handler.chain_starts == handler.chain_ends == 3
|
||||
assert handler.llm_starts == handler.llm_ends == 2
|
||||
assert handler.tool_starts == handler.tool_ends == 1
|
||||
assert handler.starts == 6
|
||||
# 1 extra agent end
|
||||
assert handler.ends == 7
|
||||
|
@ -39,14 +39,25 @@ class FakeCallbackHandler(BaseModel, BaseCallbackHandler):
|
||||
"""Whether to ignore agent callbacks."""
|
||||
return self.ignore_agent_
|
||||
|
||||
# add finer-grained counters for easier debugging of failing tests
|
||||
chain_starts: int = 0
|
||||
chain_ends: int = 0
|
||||
llm_starts: int = 0
|
||||
llm_ends: int = 0
|
||||
tool_starts: int = 0
|
||||
tool_ends: int = 0
|
||||
agent_ends: int = 0
|
||||
|
||||
def on_llm_start(
|
||||
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when LLM starts running."""
|
||||
self.llm_starts += 1
|
||||
self.starts += 1
|
||||
|
||||
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
|
||||
"""Run when LLM ends running."""
|
||||
self.llm_ends += 1
|
||||
self.ends += 1
|
||||
|
||||
def on_llm_error(
|
||||
@ -59,10 +70,12 @@ class FakeCallbackHandler(BaseModel, BaseCallbackHandler):
|
||||
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when chain starts running."""
|
||||
self.chain_starts += 1
|
||||
self.starts += 1
|
||||
|
||||
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
|
||||
"""Run when chain ends running."""
|
||||
self.chain_ends += 1
|
||||
self.ends += 1
|
||||
|
||||
def on_chain_error(
|
||||
@ -75,10 +88,12 @@ class FakeCallbackHandler(BaseModel, BaseCallbackHandler):
|
||||
self, serialized: Dict[str, Any], action: AgentAction, **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when tool starts running."""
|
||||
self.tool_starts += 1
|
||||
self.starts += 1
|
||||
|
||||
def on_tool_end(self, output: str, **kwargs: Any) -> None:
|
||||
"""Run when tool ends running."""
|
||||
self.tool_ends += 1
|
||||
self.ends += 1
|
||||
|
||||
def on_tool_error(
|
||||
@ -93,4 +108,5 @@ class FakeCallbackHandler(BaseModel, BaseCallbackHandler):
|
||||
|
||||
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
|
||||
"""Run when agent ends running."""
|
||||
self.agent_ends += 1
|
||||
self.ends += 1
|
||||
|
Loading…
Reference in New Issue
Block a user