Merge pull request #8 from VowpalWabbit/update_w_score

update score to take entire response object to make it easier for user
This commit is contained in:
olgavrou 2023-08-29 09:18:52 -04:00 committed by GitHub
commit 256849e02a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 5 additions and 4 deletions

View File

@ -402,7 +402,7 @@ class RLChain(Chain, Generic[TEvent]):
return [self.output_key] return [self.output_key]
def update_with_delayed_score( def update_with_delayed_score(
self, score: float, event: TEvent, force_score: bool = False self, score: float, chain_response: Dict[str, Any], force_score: bool = False
) -> None: ) -> None:
""" """
Updates the learned policy with the score provided. Updates the learned policy with the score provided.
@ -415,6 +415,7 @@ class RLChain(Chain, Generic[TEvent]):
) )
if self.metrics: if self.metrics:
self.metrics.on_feedback(score) self.metrics.on_feedback(score)
event: TEvent = chain_response["selection_metadata"]
self._call_after_scoring_before_learning(event=event, score=score) self._call_after_scoring_before_learning(event=event, score=score)
self.active_policy.learn(event=event) self.active_policy.learn(event=event)
self.active_policy.log(event=event) self.active_policy.log(event=event)

View File

@ -86,7 +86,7 @@ def test_update_with_delayed_score_with_auto_validator_throws() -> None:
selection_metadata = response["selection_metadata"] selection_metadata = response["selection_metadata"]
assert selection_metadata.selected.score == 3.0 assert selection_metadata.selected.score == 3.0
with pytest.raises(RuntimeError): with pytest.raises(RuntimeError):
chain.update_with_delayed_score(event=selection_metadata, score=100) chain.update_with_delayed_score(chain_response=response, score=100)
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers") @pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
@ -109,7 +109,7 @@ def test_update_with_delayed_score_force() -> None:
selection_metadata = response["selection_metadata"] selection_metadata = response["selection_metadata"]
assert selection_metadata.selected.score == 3.0 assert selection_metadata.selected.score == 3.0
chain.update_with_delayed_score( chain.update_with_delayed_score(
event=selection_metadata, score=100, force_score=True chain_response=response, score=100, force_score=True
) )
assert selection_metadata.selected.score == 100.0 assert selection_metadata.selected.score == 100.0
@ -131,7 +131,7 @@ def test_update_with_delayed_score() -> None:
assert response["response"] == "hey" assert response["response"] == "hey"
selection_metadata = response["selection_metadata"] selection_metadata = response["selection_metadata"]
assert selection_metadata.selected.score is None assert selection_metadata.selected.score is None
chain.update_with_delayed_score(event=selection_metadata, score=100) chain.update_with_delayed_score(chain_response=response, score=100)
assert selection_metadata.selected.score == 100.0 assert selection_metadata.selected.score == 100.0