mirror of
https://github.com/hwchase17/langchain
synced 2024-10-31 15:20:26 +00:00
Merge pull request #7 from VowpalWabbit/scorer_activate_deactivate
activate and deactivate scorer
This commit is contained in:
commit
d46ad01ee0
@ -343,6 +343,7 @@ class RLChain(Chain, Generic[TEvent]):
|
|||||||
selection_scorer: Union[SelectionScorer, None]
|
selection_scorer: Union[SelectionScorer, None]
|
||||||
active_policy: Policy = _NoOpPolicy()
|
active_policy: Policy = _NoOpPolicy()
|
||||||
auto_embed: bool = False
|
auto_embed: bool = False
|
||||||
|
selection_scorer_activated: bool = True
|
||||||
selected_input_key = "rl_chain_selected"
|
selected_input_key = "rl_chain_selected"
|
||||||
selected_based_on_input_key = "rl_chain_selected_based_on"
|
selected_based_on_input_key = "rl_chain_selected_based_on"
|
||||||
metrics: Optional[MetricsTracker] = None
|
metrics: Optional[MetricsTracker] = None
|
||||||
@ -400,6 +401,42 @@ class RLChain(Chain, Generic[TEvent]):
|
|||||||
"""
|
"""
|
||||||
return [self.output_key]
|
return [self.output_key]
|
||||||
|
|
||||||
|
def update_with_delayed_score(
|
||||||
|
self, score: float, event: TEvent, force_score: bool = False
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Updates the learned policy with the score provided.
|
||||||
|
Will raise an error if selection_scorer is set and activated, and force_score=True was not provided during the method call
|
||||||
|
""" # noqa: E501
|
||||||
|
if self._can_use_selection_scorer() and not force_score:
|
||||||
|
raise RuntimeError(
|
||||||
|
"The selection scorer is set, and force_score was not set to True. \
|
||||||
|
Please set force_score=True to use this function."
|
||||||
|
)
|
||||||
|
if self.metrics:
|
||||||
|
self.metrics.on_feedback(score)
|
||||||
|
self._call_after_scoring_before_learning(event=event, score=score)
|
||||||
|
self.active_policy.learn(event=event)
|
||||||
|
self.active_policy.log(event=event)
|
||||||
|
|
||||||
|
def deactivate_selection_scorer(self) -> None:
|
||||||
|
"""
|
||||||
|
Deactivates the selection scorer, meaning that the chain will no longer attempt to use the selection scorer to score responses.
|
||||||
|
""" # noqa: E501
|
||||||
|
self.selection_scorer_activated = False
|
||||||
|
|
||||||
|
def activate_selection_scorer(self) -> None:
|
||||||
|
"""
|
||||||
|
Activates the selection scorer, meaning that the chain will attempt to use the selection scorer to score responses.
|
||||||
|
""" # noqa: E501
|
||||||
|
self.selection_scorer_activated = True
|
||||||
|
|
||||||
|
def save_progress(self) -> None:
|
||||||
|
"""
|
||||||
|
This function should be called to save the state of the learned policy model.
|
||||||
|
""" # noqa: E501
|
||||||
|
self.active_policy.save()
|
||||||
|
|
||||||
def _validate_inputs(self, inputs: Dict[str, Any]) -> None:
|
def _validate_inputs(self, inputs: Dict[str, Any]) -> None:
|
||||||
super()._validate_inputs(inputs)
|
super()._validate_inputs(inputs)
|
||||||
if (
|
if (
|
||||||
@ -412,6 +449,12 @@ class RLChain(Chain, Generic[TEvent]):
|
|||||||
they are reserved for internal use during auto reward."
|
they are reserved for internal use during auto reward."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _can_use_selection_scorer(self) -> bool:
|
||||||
|
"""
|
||||||
|
Returns whether the chain can use the selection scorer to score responses or not.
|
||||||
|
""" # noqa: E501
|
||||||
|
return self.selection_scorer is not None and self.selection_scorer_activated
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def _call_before_predict(self, inputs: Dict[str, Any]) -> TEvent:
|
def _call_before_predict(self, inputs: Dict[str, Any]) -> TEvent:
|
||||||
...
|
...
|
||||||
@ -434,30 +477,6 @@ class RLChain(Chain, Generic[TEvent]):
|
|||||||
) -> TEvent:
|
) -> TEvent:
|
||||||
...
|
...
|
||||||
|
|
||||||
def update_with_delayed_score(
|
|
||||||
self, score: float, event: TEvent, force_score: bool = False
|
|
||||||
) -> None:
|
|
||||||
"""
|
|
||||||
Updates the learned policy with the score provided.
|
|
||||||
Will raise an error if selection_scorer is set, and force_score=True was not provided during the method call
|
|
||||||
""" # noqa: E501
|
|
||||||
if self.selection_scorer and not force_score:
|
|
||||||
raise RuntimeError(
|
|
||||||
"The selection scorer is set, and force_score was not set to True. \
|
|
||||||
Please set force_score=True to use this function."
|
|
||||||
)
|
|
||||||
if self.metrics:
|
|
||||||
self.metrics.on_feedback(score)
|
|
||||||
self._call_after_scoring_before_learning(event=event, score=score)
|
|
||||||
self.active_policy.learn(event=event)
|
|
||||||
self.active_policy.log(event=event)
|
|
||||||
|
|
||||||
def set_auto_embed(self, auto_embed: bool) -> None:
|
|
||||||
"""
|
|
||||||
Sets whether the chain should auto embed the inputs or not.
|
|
||||||
"""
|
|
||||||
self.auto_embed = auto_embed
|
|
||||||
|
|
||||||
def _call(
|
def _call(
|
||||||
self,
|
self,
|
||||||
inputs: Dict[str, Any],
|
inputs: Dict[str, Any],
|
||||||
@ -494,8 +513,8 @@ class RLChain(Chain, Generic[TEvent]):
|
|||||||
|
|
||||||
score = None
|
score = None
|
||||||
try:
|
try:
|
||||||
if self.selection_scorer:
|
if self._can_use_selection_scorer():
|
||||||
score = self.selection_scorer.score_response(
|
score = self.selection_scorer.score_response( # type: ignore
|
||||||
inputs=next_chain_inputs, llm_response=output, event=event
|
inputs=next_chain_inputs, llm_response=output, event=event
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -511,12 +530,6 @@ class RLChain(Chain, Generic[TEvent]):
|
|||||||
|
|
||||||
return {self.output_key: {"response": output, "selection_metadata": event}}
|
return {self.output_key: {"response": output, "selection_metadata": event}}
|
||||||
|
|
||||||
def save_progress(self) -> None:
|
|
||||||
"""
|
|
||||||
This function should be called to save the state of the learned policy model.
|
|
||||||
"""
|
|
||||||
self.active_policy.save()
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _chain_type(self) -> str:
|
def _chain_type(self) -> str:
|
||||||
return "llm_personalizer_chain"
|
return "llm_personalizer_chain"
|
||||||
|
@ -363,3 +363,41 @@ def test_calling_chain_w_reserved_inputs_throws() -> None:
|
|||||||
User=rl_chain.BasedOn("Context"),
|
User=rl_chain.BasedOn("Context"),
|
||||||
rl_chain_selected=rl_chain.ToSelectFrom(["0", "1", "2"]),
|
rl_chain_selected=rl_chain.ToSelectFrom(["0", "1", "2"]),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
||||||
|
def test_activate_and_deactivate_scorer() -> None:
|
||||||
|
llm, PROMPT = setup()
|
||||||
|
scorer_llm = FakeListChatModel(responses=[300])
|
||||||
|
chain = pick_best_chain.PickBest.from_llm(
|
||||||
|
llm=llm,
|
||||||
|
prompt=PROMPT,
|
||||||
|
selection_scorer=pick_best_chain.base.AutoSelectionScorer(llm=scorer_llm),
|
||||||
|
feature_embedder=pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()),
|
||||||
|
)
|
||||||
|
response = chain.run(
|
||||||
|
User=pick_best_chain.base.BasedOn("Context"),
|
||||||
|
action=pick_best_chain.base.ToSelectFrom(["0", "1", "2"]),
|
||||||
|
)
|
||||||
|
# the chain llm produces the response; the separate scorer_llm is used for scoring
|
||||||
|
assert response["response"] == "hey"
|
||||||
|
selection_metadata = response["selection_metadata"]
|
||||||
|
assert selection_metadata.selected.score == 300.0
|
||||||
|
|
||||||
|
chain.deactivate_selection_scorer()
|
||||||
|
response = chain.run(
|
||||||
|
User=pick_best_chain.base.BasedOn("Context"),
|
||||||
|
action=pick_best_chain.base.ToSelectFrom(["0", "1", "2"]),
|
||||||
|
)
|
||||||
|
assert response["response"] == "hey"
|
||||||
|
selection_metadata = response["selection_metadata"]
|
||||||
|
assert selection_metadata.selected.score is None
|
||||||
|
|
||||||
|
chain.activate_selection_scorer()
|
||||||
|
response = chain.run(
|
||||||
|
User=pick_best_chain.base.BasedOn("Context"),
|
||||||
|
action=pick_best_chain.base.ToSelectFrom(["0", "1", "2"]),
|
||||||
|
)
|
||||||
|
assert response["response"] == "hey"
|
||||||
|
selection_metadata = response["selection_metadata"]
|
||||||
|
assert selection_metadata.selected.score == 300.0
|
||||||
|
Loading…
Reference in New Issue
Block a user