Mirror of https://github.com/hwchase17/langchain (synced 2024-10-31 15:20:26 +00:00)
Merge pull request #7 from VowpalWabbit/scorer_activate_deactivate
activate and deactivate scorer
This commit is contained in: d46ad01ee0
@@ -343,6 +343,7 @@ class RLChain(Chain, Generic[TEvent]):
     selection_scorer: Union[SelectionScorer, None]
     active_policy: Policy = _NoOpPolicy()
     auto_embed: bool = False
+    selection_scorer_activated: bool = True
     selected_input_key = "rl_chain_selected"
     selected_based_on_input_key = "rl_chain_selected_based_on"
     metrics: Optional[MetricsTracker] = None
@@ -400,6 +401,42 @@ class RLChain(Chain, Generic[TEvent]):
         """
         return [self.output_key]

+    def update_with_delayed_score(
+        self, score: float, event: TEvent, force_score: bool = False
+    ) -> None:
+        """
+        Updates the learned policy with the score provided.
+        Will raise an error if selection_scorer is set, and force_score=True was not provided during the method call
+        """ # noqa: E501
+        if self._can_use_selection_scorer() and not force_score:
+            raise RuntimeError(
+                "The selection scorer is set, and force_score was not set to True. \
+                Please set force_score=True to use this function."
+            )
+        if self.metrics:
+            self.metrics.on_feedback(score)
+        self._call_after_scoring_before_learning(event=event, score=score)
+        self.active_policy.learn(event=event)
+        self.active_policy.log(event=event)
+
+    def deactivate_selection_scorer(self) -> None:
+        """
+        Deactivates the selection scorer, meaning that the chain will no longer attempt to use the selection scorer to score responses.
+        """ # noqa: E501
+        self.selection_scorer_activated = False
+
+    def activate_selection_scorer(self) -> None:
+        """
+        Activates the selection scorer, meaning that the chain will attempt to use the selection scorer to score responses.
+        """ # noqa: E501
+        self.selection_scorer_activated = True
+
+    def save_progress(self) -> None:
+        """
+        This function should be called to save the state of the learned policy model.
+        """ # noqa: E501
+        self.active_policy.save()
+
     def _validate_inputs(self, inputs: Dict[str, Any]) -> None:
         super()._validate_inputs(inputs)
         if (
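For context, a minimal usage sketch of the new public methods. This is not part of the diff: it assumes a `chain` built the same way as in the test added at the bottom of this PR (pick_best_chain.PickBest.from_llm with a selection scorer configured), and it relies on `_call` returning the {"response": ..., "selection_metadata": event} shape shown later in this diff.

    # Hypothetical illustration only -- `chain` is assumed to be a PickBest chain
    # with a selection scorer set, as in the new test below.
    response = chain.run(
        User=pick_best_chain.base.BasedOn("Context"),
        action=pick_best_chain.base.ToSelectFrom(["0", "1", "2"]),
    )
    event = response["selection_metadata"]

    # Because a selection scorer is configured and activated, pushing a delayed
    # score requires the explicit force_score=True override checked above.
    chain.update_with_delayed_score(score=1.0, event=event, force_score=True)

    # The scorer can be switched off and on between runs without rebuilding the chain.
    chain.deactivate_selection_scorer()
    chain.activate_selection_scorer()

    # Persist the learned policy when done.
    chain.save_progress()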
@@ -412,6 +449,12 @@ class RLChain(Chain, Generic[TEvent]):
                 they are reserved for internal use during auto reward."
             )

+    def _can_use_selection_scorer(self) -> bool:
+        """
+        Returns whether the chain can use the selection scorer to score responses or not.
+        """ # noqa: E501
+        return self.selection_scorer is not None and self.selection_scorer_activated
+
     @abstractmethod
     def _call_before_predict(self, inputs: Dict[str, Any]) -> TEvent:
         ...
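A brief sketch of how the helper's result tracks the activate/deactivate toggle (same assumption as before: a `chain` with a selection scorer configured; not code from the diff):

    # Hypothetical check of the gating helper added above.
    assert chain._can_use_selection_scorer() is True

    chain.deactivate_selection_scorer()
    assert chain._can_use_selection_scorer() is False  # scoring is skipped in _call

    chain.activate_selection_scorer()
    assert chain._can_use_selection_scorer() is True   # scoring resumes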
@@ -434,30 +477,6 @@ class RLChain(Chain, Generic[TEvent]):
     ) -> TEvent:
         ...

-    def update_with_delayed_score(
-        self, score: float, event: TEvent, force_score: bool = False
-    ) -> None:
-        """
-        Updates the learned policy with the score provided.
-        Will raise an error if selection_scorer is set, and force_score=True was not provided during the method call
-        """ # noqa: E501
-        if self.selection_scorer and not force_score:
-            raise RuntimeError(
-                "The selection scorer is set, and force_score was not set to True. \
-                Please set force_score=True to use this function."
-            )
-        if self.metrics:
-            self.metrics.on_feedback(score)
-        self._call_after_scoring_before_learning(event=event, score=score)
-        self.active_policy.learn(event=event)
-        self.active_policy.log(event=event)
-
-    def set_auto_embed(self, auto_embed: bool) -> None:
-        """
-        Sets whether the chain should auto embed the inputs or not.
-        """
-        self.auto_embed = auto_embed
-
     def _call(
         self,
         inputs: Dict[str, Any],
@@ -494,8 +513,8 @@ class RLChain(Chain, Generic[TEvent]):

         score = None
         try:
-            if self.selection_scorer:
-                score = self.selection_scorer.score_response(
+            if self._can_use_selection_scorer():
+                score = self.selection_scorer.score_response( # type: ignore
                     inputs=next_chain_inputs, llm_response=output, event=event
                 )
         except Exception as e:
@@ -511,12 +530,6 @@ class RLChain(Chain, Generic[TEvent]):

         return {self.output_key: {"response": output, "selection_metadata": event}}

-    def save_progress(self) -> None:
-        """
-        This function should be called to save the state of the learned policy model.
-        """
-        self.active_policy.save()
-
     @property
     def _chain_type(self) -> str:
         return "llm_personalizer_chain"
@@ -363,3 +363,41 @@ def test_calling_chain_w_reserved_inputs_throws() -> None:
             User=rl_chain.BasedOn("Context"),
             rl_chain_selected=rl_chain.ToSelectFrom(["0", "1", "2"]),
         )
+
+
+@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
+def test_activate_and_deactivate_scorer() -> None:
+    llm, PROMPT = setup()
+    scorer_llm = FakeListChatModel(responses=[300])
+    chain = pick_best_chain.PickBest.from_llm(
+        llm=llm,
+        prompt=PROMPT,
+        selection_scorer=pick_best_chain.base.AutoSelectionScorer(llm=scorer_llm),
+        feature_embedder=pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()),
+    )
+    response = chain.run(
+        User=pick_best_chain.base.BasedOn("Context"),
+        action=pick_best_chain.base.ToSelectFrom(["0", "1", "2"]),
+    )
+    # chain llm used for both basic prompt and for scoring
+    assert response["response"] == "hey"
+    selection_metadata = response["selection_metadata"]
+    assert selection_metadata.selected.score == 300.0
+
+    chain.deactivate_selection_scorer()
+    response = chain.run(
+        User=pick_best_chain.base.BasedOn("Context"),
+        action=pick_best_chain.base.ToSelectFrom(["0", "1", "2"]),
+    )
+    assert response["response"] == "hey"
+    selection_metadata = response["selection_metadata"]
+    assert selection_metadata.selected.score is None
+
+    chain.activate_selection_scorer()
+    response = chain.run(
+        User=pick_best_chain.base.BasedOn("Context"),
+        action=pick_best_chain.base.ToSelectFrom(["0", "1", "2"]),
+    )
+    assert response["response"] == "hey"
+    selection_metadata = response["selection_metadata"]
+    assert selection_metadata.selected.score == 300.0
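One consequence of switching the gate in update_with_delayed_score from `self.selection_scorer` to `_can_use_selection_scorer()` is that a deactivated scorer no longer forces the force_score override. A hedged sketch, not part of the diff, assuming the same `chain` as in the test above:

    # Hypothetical follow-on to the test above: with the scorer deactivated,
    # _can_use_selection_scorer() returns False, so a delayed score can be
    # reported without force_score=True.
    chain.deactivate_selection_scorer()
    response = chain.run(
        User=pick_best_chain.base.BasedOn("Context"),
        action=pick_best_chain.base.ToSelectFrom(["0", "1", "2"]),
    )
    event = response["selection_metadata"]
    assert event.selected.score is None

    chain.update_with_delayed_score(score=1.0, event=event)  # no override needed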