mirror of
https://github.com/hwchase17/langchain
synced 2024-11-08 07:10:35 +00:00
no errors in pick best chain
This commit is contained in:
parent
6a1102d4c0
commit
dd6fff1c62
@ -13,7 +13,7 @@ from langchain.chains.rl_chain.base import (
|
||||
from langchain.chains.rl_chain.pick_best_chain import PickBest
|
||||
|
||||
|
||||
def configure_logger():
|
||||
def configure_logger() -> None:
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.INFO)
|
||||
ch = logging.StreamHandler()
|
||||
|
@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type, Union
|
||||
|
||||
import langchain.chains.rl_chain.base as base
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
@ -145,7 +145,6 @@ class PickBest(base.RLChain[PickBest.Event]):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
feature_embedder: Optional[PickBestFeatureEmbedder] = None,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
):
|
||||
@ -163,12 +162,14 @@ class PickBest(base.RLChain[PickBest.Event]):
|
||||
raise ValueError(
|
||||
"If vw_cmd is specified, it must include --cb_explore_adf"
|
||||
)
|
||||
|
||||
kwargs["vw_cmd"] = vw_cmd
|
||||
|
||||
feature_embedder = kwargs.get("feature_embedder", None)
|
||||
if not feature_embedder:
|
||||
feature_embedder = PickBestFeatureEmbedder()
|
||||
kwargs["feature_embedder"] = feature_embedder
|
||||
|
||||
super().__init__(feature_embedder=feature_embedder, *args, **kwargs)
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def _call_before_predict(self, inputs: Dict[str, Any]) -> Event:
|
||||
context, actions = base.get_based_on_and_to_select_from(inputs=inputs)
|
||||
@ -223,10 +224,15 @@ class PickBest(base.RLChain[PickBest.Event]):
|
||||
next_chain_inputs = event.inputs.copy()
|
||||
# only one key, value pair in event.to_select_from
|
||||
value = next(iter(event.to_select_from.values()))
|
||||
v = (
|
||||
value[event.selected.index]
|
||||
if event.selected
|
||||
else event.to_select_from.values()
|
||||
)
|
||||
next_chain_inputs.update(
|
||||
{
|
||||
self.selected_based_on_input_key: str(event.based_on),
|
||||
self.selected_input_key: value[event.selected.index],
|
||||
self.selected_input_key: v,
|
||||
}
|
||||
)
|
||||
return next_chain_inputs, event
|
||||
@ -234,6 +240,7 @@ class PickBest(base.RLChain[PickBest.Event]):
|
||||
def _call_after_scoring_before_learning(
|
||||
self, event: Event, score: Optional[float]
|
||||
) -> Event:
|
||||
if event.selected:
|
||||
event.selected.score = score
|
||||
return event
|
||||
|
||||
@ -249,34 +256,20 @@ class PickBest(base.RLChain[PickBest.Event]):
|
||||
return "rl_chain_pick_best"
|
||||
|
||||
@classmethod
|
||||
def from_chain(
|
||||
cls,
|
||||
llm_chain: Chain,
|
||||
def from_llm(
|
||||
cls: Type[PickBest],
|
||||
llm: BaseLanguageModel,
|
||||
prompt: BasePromptTemplate,
|
||||
selection_scorer=SENTINEL,
|
||||
selection_scorer: Union[base.AutoSelectionScorer, object] = SENTINEL,
|
||||
**kwargs: Any,
|
||||
):
|
||||
) -> PickBest:
|
||||
llm_chain = LLMChain(llm=llm, prompt=prompt)
|
||||
if selection_scorer is SENTINEL:
|
||||
selection_scorer = base.AutoSelectionScorer(llm=llm_chain.llm)
|
||||
|
||||
return PickBest(
|
||||
llm_chain=llm_chain,
|
||||
prompt=prompt,
|
||||
selection_scorer=selection_scorer,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_llm(
|
||||
cls,
|
||||
llm: BaseLanguageModel,
|
||||
prompt: BasePromptTemplate,
|
||||
selection_scorer=SENTINEL,
|
||||
**kwargs: Any,
|
||||
):
|
||||
llm_chain = LLMChain(llm=llm, prompt=prompt)
|
||||
return PickBest.from_chain(
|
||||
llm_chain=llm_chain,
|
||||
prompt=prompt,
|
||||
selection_scorer=selection_scorer,
|
||||
**kwargs,
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user