forked from Archives/langchain
Suggestions for better debugging (#765)
Please feel free to disregard any changes you disagree with
This commit is contained in:
parent
5198d6f541
commit
6ad360bdef
@ -19,8 +19,8 @@ class FakeListLLM(LLM, BaseModel):
|
|||||||
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
|
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
|
||||||
"""Increment counter, and then return response in that index."""
|
"""Increment counter, and then return response in that index."""
|
||||||
self.i += 1
|
self.i += 1
|
||||||
print(self.i)
|
print(f"=== Mock Response #{self.i} ===")
|
||||||
print(self.responses)
|
print(self.responses[self.i])
|
||||||
return self.responses[self.i]
|
return self.responses[self.i]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -92,7 +92,10 @@ def test_agent_with_callbacks_global() -> None:
|
|||||||
output = agent.run("when was langchain made")
|
output = agent.run("when was langchain made")
|
||||||
assert output == "curses foiled again"
|
assert output == "curses foiled again"
|
||||||
|
|
||||||
# 1 top level chain run, 2 LLMChain runs, 2 LLM runs, 1 tool run
|
# 1 top level chain run runs, 2 LLMChain runs, 2 LLM runs, 1 tool run
|
||||||
|
assert handler.chain_starts == handler.chain_ends == 3
|
||||||
|
assert handler.llm_starts == handler.llm_ends == 2
|
||||||
|
assert handler.tool_starts == handler.tool_ends == 1
|
||||||
assert handler.starts == 6
|
assert handler.starts == 6
|
||||||
# 1 extra agent end
|
# 1 extra agent end
|
||||||
assert handler.ends == 7
|
assert handler.ends == 7
|
||||||
@ -130,7 +133,10 @@ def test_agent_with_callbacks_local() -> None:
|
|||||||
output = agent.run("when was langchain made")
|
output = agent.run("when was langchain made")
|
||||||
assert output == "curses foiled again"
|
assert output == "curses foiled again"
|
||||||
|
|
||||||
# 1 top level chain run, 2 LLMChain runs, 2 LLM runs, 1 tool run
|
# 1 top level chain run, 2 LLMChain starts, 2 LLM runs, 1 tool run
|
||||||
|
assert handler.chain_starts == handler.chain_ends == 3
|
||||||
|
assert handler.llm_starts == handler.llm_ends == 2
|
||||||
|
assert handler.tool_starts == handler.tool_ends == 1
|
||||||
assert handler.starts == 6
|
assert handler.starts == 6
|
||||||
# 1 extra agent end
|
# 1 extra agent end
|
||||||
assert handler.ends == 7
|
assert handler.ends == 7
|
||||||
|
@ -39,14 +39,25 @@ class FakeCallbackHandler(BaseModel, BaseCallbackHandler):
|
|||||||
"""Whether to ignore agent callbacks."""
|
"""Whether to ignore agent callbacks."""
|
||||||
return self.ignore_agent_
|
return self.ignore_agent_
|
||||||
|
|
||||||
|
# add finer-grained counters for easier debugging of failing tests
|
||||||
|
chain_starts: int = 0
|
||||||
|
chain_ends: int = 0
|
||||||
|
llm_starts: int = 0
|
||||||
|
llm_ends: int = 0
|
||||||
|
tool_starts: int = 0
|
||||||
|
tool_ends: int = 0
|
||||||
|
agent_ends: int = 0
|
||||||
|
|
||||||
def on_llm_start(
|
def on_llm_start(
|
||||||
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Run when LLM starts running."""
|
"""Run when LLM starts running."""
|
||||||
|
self.llm_starts += 1
|
||||||
self.starts += 1
|
self.starts += 1
|
||||||
|
|
||||||
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
|
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
|
||||||
"""Run when LLM ends running."""
|
"""Run when LLM ends running."""
|
||||||
|
self.llm_ends += 1
|
||||||
self.ends += 1
|
self.ends += 1
|
||||||
|
|
||||||
def on_llm_error(
|
def on_llm_error(
|
||||||
@ -59,10 +70,12 @@ class FakeCallbackHandler(BaseModel, BaseCallbackHandler):
|
|||||||
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
|
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Run when chain starts running."""
|
"""Run when chain starts running."""
|
||||||
|
self.chain_starts += 1
|
||||||
self.starts += 1
|
self.starts += 1
|
||||||
|
|
||||||
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
|
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
|
||||||
"""Run when chain ends running."""
|
"""Run when chain ends running."""
|
||||||
|
self.chain_ends += 1
|
||||||
self.ends += 1
|
self.ends += 1
|
||||||
|
|
||||||
def on_chain_error(
|
def on_chain_error(
|
||||||
@ -75,10 +88,12 @@ class FakeCallbackHandler(BaseModel, BaseCallbackHandler):
|
|||||||
self, serialized: Dict[str, Any], action: AgentAction, **kwargs: Any
|
self, serialized: Dict[str, Any], action: AgentAction, **kwargs: Any
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Run when tool starts running."""
|
"""Run when tool starts running."""
|
||||||
|
self.tool_starts += 1
|
||||||
self.starts += 1
|
self.starts += 1
|
||||||
|
|
||||||
def on_tool_end(self, output: str, **kwargs: Any) -> None:
|
def on_tool_end(self, output: str, **kwargs: Any) -> None:
|
||||||
"""Run when tool ends running."""
|
"""Run when tool ends running."""
|
||||||
|
self.tool_ends += 1
|
||||||
self.ends += 1
|
self.ends += 1
|
||||||
|
|
||||||
def on_tool_error(
|
def on_tool_error(
|
||||||
@ -93,4 +108,5 @@ class FakeCallbackHandler(BaseModel, BaseCallbackHandler):
|
|||||||
|
|
||||||
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
|
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
|
||||||
"""Run when agent ends running."""
|
"""Run when agent ends running."""
|
||||||
|
self.agent_ends += 1
|
||||||
self.ends += 1
|
self.ends += 1
|
||||||
|
Loading…
Reference in New Issue
Block a user