mirror of
https://github.com/hwchase17/langchain
synced 2024-10-31 15:20:26 +00:00
Merge pull request #8 from VowpalWabbit/update_w_score
Update `update_with_delayed_score` to take the entire chain response object instead of the event, making it easier for the user.
This commit is contained in:
commit
256849e02a
@ -402,7 +402,7 @@ class RLChain(Chain, Generic[TEvent]):
|
|||||||
return [self.output_key]
|
return [self.output_key]
|
||||||
|
|
||||||
def update_with_delayed_score(
|
def update_with_delayed_score(
|
||||||
self, score: float, event: TEvent, force_score: bool = False
|
self, score: float, chain_response: Dict[str, Any], force_score: bool = False
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Updates the learned policy with the score provided.
|
Updates the learned policy with the score provided.
|
||||||
@ -415,6 +415,7 @@ class RLChain(Chain, Generic[TEvent]):
|
|||||||
)
|
)
|
||||||
if self.metrics:
|
if self.metrics:
|
||||||
self.metrics.on_feedback(score)
|
self.metrics.on_feedback(score)
|
||||||
|
event: TEvent = chain_response["selection_metadata"]
|
||||||
self._call_after_scoring_before_learning(event=event, score=score)
|
self._call_after_scoring_before_learning(event=event, score=score)
|
||||||
self.active_policy.learn(event=event)
|
self.active_policy.learn(event=event)
|
||||||
self.active_policy.log(event=event)
|
self.active_policy.log(event=event)
|
||||||
|
@ -86,7 +86,7 @@ def test_update_with_delayed_score_with_auto_validator_throws() -> None:
|
|||||||
selection_metadata = response["selection_metadata"]
|
selection_metadata = response["selection_metadata"]
|
||||||
assert selection_metadata.selected.score == 3.0
|
assert selection_metadata.selected.score == 3.0
|
||||||
with pytest.raises(RuntimeError):
|
with pytest.raises(RuntimeError):
|
||||||
chain.update_with_delayed_score(event=selection_metadata, score=100)
|
chain.update_with_delayed_score(chain_response=response, score=100)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
||||||
@ -109,7 +109,7 @@ def test_update_with_delayed_score_force() -> None:
|
|||||||
selection_metadata = response["selection_metadata"]
|
selection_metadata = response["selection_metadata"]
|
||||||
assert selection_metadata.selected.score == 3.0
|
assert selection_metadata.selected.score == 3.0
|
||||||
chain.update_with_delayed_score(
|
chain.update_with_delayed_score(
|
||||||
event=selection_metadata, score=100, force_score=True
|
chain_response=response, score=100, force_score=True
|
||||||
)
|
)
|
||||||
assert selection_metadata.selected.score == 100.0
|
assert selection_metadata.selected.score == 100.0
|
||||||
|
|
||||||
@ -131,7 +131,7 @@ def test_update_with_delayed_score() -> None:
|
|||||||
assert response["response"] == "hey"
|
assert response["response"] == "hey"
|
||||||
selection_metadata = response["selection_metadata"]
|
selection_metadata = response["selection_metadata"]
|
||||||
assert selection_metadata.selected.score is None
|
assert selection_metadata.selected.score is None
|
||||||
chain.update_with_delayed_score(event=selection_metadata, score=100)
|
chain.update_with_delayed_score(chain_response=response, score=100)
|
||||||
assert selection_metadata.selected.score == 100.0
|
assert selection_metadata.selected.score == 100.0
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user