mirror of
https://github.com/hwchase17/langchain
synced 2024-10-31 15:20:26 +00:00
Merge pull request #8 from VowpalWabbit/update_w_score
update score to take entire response object to make it easier for user
This commit is contained in:
commit
256849e02a
@ -402,7 +402,7 @@ class RLChain(Chain, Generic[TEvent]):
|
||||
return [self.output_key]
|
||||
|
||||
def update_with_delayed_score(
|
||||
self, score: float, event: TEvent, force_score: bool = False
|
||||
self, score: float, chain_response: Dict[str, Any], force_score: bool = False
|
||||
) -> None:
|
||||
"""
|
||||
Updates the learned policy with the score provided.
|
||||
@ -415,6 +415,7 @@ class RLChain(Chain, Generic[TEvent]):
|
||||
)
|
||||
if self.metrics:
|
||||
self.metrics.on_feedback(score)
|
||||
event: TEvent = chain_response["selection_metadata"]
|
||||
self._call_after_scoring_before_learning(event=event, score=score)
|
||||
self.active_policy.learn(event=event)
|
||||
self.active_policy.log(event=event)
|
||||
|
@ -86,7 +86,7 @@ def test_update_with_delayed_score_with_auto_validator_throws() -> None:
|
||||
selection_metadata = response["selection_metadata"]
|
||||
assert selection_metadata.selected.score == 3.0
|
||||
with pytest.raises(RuntimeError):
|
||||
chain.update_with_delayed_score(event=selection_metadata, score=100)
|
||||
chain.update_with_delayed_score(chain_response=response, score=100)
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
||||
@ -109,7 +109,7 @@ def test_update_with_delayed_score_force() -> None:
|
||||
selection_metadata = response["selection_metadata"]
|
||||
assert selection_metadata.selected.score == 3.0
|
||||
chain.update_with_delayed_score(
|
||||
event=selection_metadata, score=100, force_score=True
|
||||
chain_response=response, score=100, force_score=True
|
||||
)
|
||||
assert selection_metadata.selected.score == 100.0
|
||||
|
||||
@ -131,7 +131,7 @@ def test_update_with_delayed_score() -> None:
|
||||
assert response["response"] == "hey"
|
||||
selection_metadata = response["selection_metadata"]
|
||||
assert selection_metadata.selected.score is None
|
||||
chain.update_with_delayed_score(event=selection_metadata, score=100)
|
||||
chain.update_with_delayed_score(chain_response=response, score=100)
|
||||
assert selection_metadata.selected.score == 100.0
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user