|
|
|
@ -8,7 +8,7 @@ from langchain.chains.base import Chain
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ForkChain(Chain, BaseModel):
|
|
|
|
|
"""Conditionally executes a follow up chain based on the output of a decision chain."""
|
|
|
|
|
"""Conditionally executes follow up chain based on output of a decision chain."""
|
|
|
|
|
|
|
|
|
|
decision_chain: LLMChain
|
|
|
|
|
follow_up_chains: Dict[str, Chain]
|
|
|
|
@ -21,6 +21,7 @@ class ForkChain(Chain, BaseModel):
|
|
|
|
|
|
|
|
|
|
@validator("follow_up_chains")
|
|
|
|
|
def default_in_follow_up_chains(cls, v: Dict[str, Chain]) -> Dict[str, Chain]:
|
|
|
|
|
"""Make sure that `default` key exists in follow_up_chains."""
|
|
|
|
|
if "default" not in v:
|
|
|
|
|
raise ValueError(
|
|
|
|
|
"`follow_up_chains` must contain a 'default' option. "
|
|
|
|
|