diff --git a/libs/langchain/langchain/chains/rl_chain/base.py b/libs/langchain/langchain/chains/rl_chain/base.py index c250815943..d08200c709 100644 --- a/libs/langchain/langchain/chains/rl_chain/base.py +++ b/libs/langchain/langchain/chains/rl_chain/base.py @@ -402,7 +402,7 @@ class RLChain(Chain, Generic[TEvent]): return [self.output_key] def update_with_delayed_score( - self, score: float, event: TEvent, force_score: bool = False + self, score: float, chain_response: Dict[str, Any], force_score: bool = False ) -> None: """ Updates the learned policy with the score provided. @@ -415,6 +415,7 @@ class RLChain(Chain, Generic[TEvent]): ) if self.metrics: self.metrics.on_feedback(score) + event: TEvent = chain_response["selection_metadata"] self._call_after_scoring_before_learning(event=event, score=score) self.active_policy.learn(event=event) self.active_policy.log(event=event) diff --git a/libs/langchain/tests/unit_tests/chains/rl_chain/test_pick_best_chain_call.py b/libs/langchain/tests/unit_tests/chains/rl_chain/test_pick_best_chain_call.py index d7dee7fdf6..d4576ce254 100644 --- a/libs/langchain/tests/unit_tests/chains/rl_chain/test_pick_best_chain_call.py +++ b/libs/langchain/tests/unit_tests/chains/rl_chain/test_pick_best_chain_call.py @@ -86,7 +86,7 @@ def test_update_with_delayed_score_with_auto_validator_throws() -> None: selection_metadata = response["selection_metadata"] assert selection_metadata.selected.score == 3.0 with pytest.raises(RuntimeError): - chain.update_with_delayed_score(event=selection_metadata, score=100) + chain.update_with_delayed_score(chain_response=response, score=100) @pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers") @@ -109,7 +109,7 @@ def test_update_with_delayed_score_force() -> None: selection_metadata = response["selection_metadata"] assert selection_metadata.selected.score == 3.0 chain.update_with_delayed_score( - event=selection_metadata, score=100, force_score=True + chain_response=response, score=100, force_score=True ) assert selection_metadata.selected.score == 100.0 @@ -131,7 +131,7 @@ def test_update_with_delayed_score() -> None: assert response["response"] == "hey" selection_metadata = response["selection_metadata"] assert selection_metadata.selected.score is None - chain.update_with_delayed_score(event=selection_metadata, score=100) + chain.update_with_delayed_score(chain_response=response, score=100) assert selection_metadata.selected.score == 100.0