Merge pull request #7 from VowpalWabbit/scorer_activate_deactivate

activate and deactivate scorer
This commit is contained in:
olgavrou 2023-08-29 09:12:11 -04:00 committed by GitHub
commit d46ad01ee0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 83 additions and 32 deletions

View File

@ -343,6 +343,7 @@ class RLChain(Chain, Generic[TEvent]):
selection_scorer: Union[SelectionScorer, None] selection_scorer: Union[SelectionScorer, None]
active_policy: Policy = _NoOpPolicy() active_policy: Policy = _NoOpPolicy()
auto_embed: bool = False auto_embed: bool = False
selection_scorer_activated: bool = True
selected_input_key = "rl_chain_selected" selected_input_key = "rl_chain_selected"
selected_based_on_input_key = "rl_chain_selected_based_on" selected_based_on_input_key = "rl_chain_selected_based_on"
metrics: Optional[MetricsTracker] = None metrics: Optional[MetricsTracker] = None
@ -400,6 +401,42 @@ class RLChain(Chain, Generic[TEvent]):
""" """
return [self.output_key] return [self.output_key]
def update_with_delayed_score(
self, score: float, event: TEvent, force_score: bool = False
) -> None:
"""
Updates the learned policy with the score provided.
Will raise an error if selection_scorer is set, and force_score=True was not provided during the method call
""" # noqa: E501
if self._can_use_selection_scorer() and not force_score:
raise RuntimeError(
"The selection scorer is set, and force_score was not set to True. \
Please set force_score=True to use this function."
)
if self.metrics:
self.metrics.on_feedback(score)
self._call_after_scoring_before_learning(event=event, score=score)
self.active_policy.learn(event=event)
self.active_policy.log(event=event)
def deactivate_selection_scorer(self) -> None:
"""
Deactivates the selection scorer, meaning that the chain will no longer attempt to use the selection scorer to score responses.
""" # noqa: E501
self.selection_scorer_activated = False
def activate_selection_scorer(self) -> None:
"""
Activates the selection scorer, meaning that the chain will attempt to use the selection scorer to score responses.
""" # noqa: E501
self.selection_scorer_activated = True
def save_progress(self) -> None:
"""
This function should be called to save the state of the learned policy model.
""" # noqa: E501
self.active_policy.save()
def _validate_inputs(self, inputs: Dict[str, Any]) -> None: def _validate_inputs(self, inputs: Dict[str, Any]) -> None:
super()._validate_inputs(inputs) super()._validate_inputs(inputs)
if ( if (
@ -412,6 +449,12 @@ class RLChain(Chain, Generic[TEvent]):
they are reserved for internal use during auto reward." they are reserved for internal use during auto reward."
) )
def _can_use_selection_scorer(self) -> bool:
"""
Returns whether the chain can use the selection scorer to score responses or not.
""" # noqa: E501
return self.selection_scorer is not None and self.selection_scorer_activated
@abstractmethod @abstractmethod
def _call_before_predict(self, inputs: Dict[str, Any]) -> TEvent: def _call_before_predict(self, inputs: Dict[str, Any]) -> TEvent:
... ...
@ -434,30 +477,6 @@ class RLChain(Chain, Generic[TEvent]):
) -> TEvent: ) -> TEvent:
... ...
def update_with_delayed_score(
self, score: float, event: TEvent, force_score: bool = False
) -> None:
"""
Updates the learned policy with the score provided.
Will raise an error if selection_scorer is set, and force_score=True was not provided during the method call
""" # noqa: E501
if self.selection_scorer and not force_score:
raise RuntimeError(
"The selection scorer is set, and force_score was not set to True. \
Please set force_score=True to use this function."
)
if self.metrics:
self.metrics.on_feedback(score)
self._call_after_scoring_before_learning(event=event, score=score)
self.active_policy.learn(event=event)
self.active_policy.log(event=event)
def set_auto_embed(self, auto_embed: bool) -> None:
"""
Sets whether the chain should auto embed the inputs or not.
"""
self.auto_embed = auto_embed
def _call( def _call(
self, self,
inputs: Dict[str, Any], inputs: Dict[str, Any],
@ -494,8 +513,8 @@ class RLChain(Chain, Generic[TEvent]):
score = None score = None
try: try:
if self.selection_scorer: if self._can_use_selection_scorer():
score = self.selection_scorer.score_response( score = self.selection_scorer.score_response( # type: ignore
inputs=next_chain_inputs, llm_response=output, event=event inputs=next_chain_inputs, llm_response=output, event=event
) )
except Exception as e: except Exception as e:
@ -511,12 +530,6 @@ class RLChain(Chain, Generic[TEvent]):
return {self.output_key: {"response": output, "selection_metadata": event}} return {self.output_key: {"response": output, "selection_metadata": event}}
def save_progress(self) -> None:
"""
This function should be called to save the state of the learned policy model.
"""
self.active_policy.save()
@property @property
def _chain_type(self) -> str: def _chain_type(self) -> str:
return "llm_personalizer_chain" return "llm_personalizer_chain"

View File

@ -363,3 +363,41 @@ def test_calling_chain_w_reserved_inputs_throws() -> None:
User=rl_chain.BasedOn("Context"), User=rl_chain.BasedOn("Context"),
rl_chain_selected=rl_chain.ToSelectFrom(["0", "1", "2"]), rl_chain_selected=rl_chain.ToSelectFrom(["0", "1", "2"]),
) )
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def test_activate_and_deactivate_scorer() -> None:
llm, PROMPT = setup()
scorer_llm = FakeListChatModel(responses=[300])
chain = pick_best_chain.PickBest.from_llm(
llm=llm,
prompt=PROMPT,
selection_scorer=pick_best_chain.base.AutoSelectionScorer(llm=scorer_llm),
feature_embedder=pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()),
)
response = chain.run(
User=pick_best_chain.base.BasedOn("Context"),
action=pick_best_chain.base.ToSelectFrom(["0", "1", "2"]),
)
# chain llm used for both basic prompt and for scoring
assert response["response"] == "hey"
selection_metadata = response["selection_metadata"]
assert selection_metadata.selected.score == 300.0
chain.deactivate_selection_scorer()
response = chain.run(
User=pick_best_chain.base.BasedOn("Context"),
action=pick_best_chain.base.ToSelectFrom(["0", "1", "2"]),
)
assert response["response"] == "hey"
selection_metadata = response["selection_metadata"]
assert selection_metadata.selected.score is None
chain.activate_selection_scorer()
response = chain.run(
User=pick_best_chain.base.BasedOn("Context"),
action=pick_best_chain.base.ToSelectFrom(["0", "1", "2"]),
)
assert response["response"] == "hey"
selection_metadata = response["selection_metadata"]
assert selection_metadata.selected.score == 300.0