Merge pull request #8 from VowpalWabbit/update_w_score

update score to take entire response object to make it easier for user
This commit is contained in:
olgavrou 2023-08-29 09:18:52 -04:00 committed by GitHub
commit 256849e02a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 5 additions and 4 deletions

View File

@ -402,7 +402,7 @@ class RLChain(Chain, Generic[TEvent]):
return [self.output_key]
def update_with_delayed_score(
self, score: float, event: TEvent, force_score: bool = False
self, score: float, chain_response: Dict[str, Any], force_score: bool = False
) -> None:
"""
Updates the learned policy with the score provided.
@ -415,6 +415,7 @@ class RLChain(Chain, Generic[TEvent]):
)
if self.metrics:
self.metrics.on_feedback(score)
event: TEvent = chain_response["selection_metadata"]
self._call_after_scoring_before_learning(event=event, score=score)
self.active_policy.learn(event=event)
self.active_policy.log(event=event)

View File

@ -86,7 +86,7 @@ def test_update_with_delayed_score_with_auto_validator_throws() -> None:
selection_metadata = response["selection_metadata"]
assert selection_metadata.selected.score == 3.0
with pytest.raises(RuntimeError):
chain.update_with_delayed_score(event=selection_metadata, score=100)
chain.update_with_delayed_score(chain_response=response, score=100)
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
@ -109,7 +109,7 @@ def test_update_with_delayed_score_force() -> None:
selection_metadata = response["selection_metadata"]
assert selection_metadata.selected.score == 3.0
chain.update_with_delayed_score(
event=selection_metadata, score=100, force_score=True
chain_response=response, score=100, force_score=True
)
assert selection_metadata.selected.score == 100.0
@ -131,7 +131,7 @@ def test_update_with_delayed_score() -> None:
assert response["response"] == "hey"
selection_metadata = response["selection_metadata"]
assert selection_metadata.selected.score is None
chain.update_with_delayed_score(event=selection_metadata, score=100)
chain.update_with_delayed_score(chain_response=response, score=100)
assert selection_metadata.selected.score == 100.0