mirror of
https://github.com/hwchase17/langchain
synced 2024-11-08 07:10:35 +00:00
resolving linting and formatting errors
This commit is contained in:
parent
a2f807e055
commit
5aafb3bc46
@ -1,16 +1,16 @@
|
||||
from langchain.chains.rl_chain.pick_best_chain import PickBest
|
||||
import logging
|
||||
|
||||
from langchain.chains.rl_chain.base import (
|
||||
Embed,
|
||||
BasedOn,
|
||||
ToSelectFrom,
|
||||
SelectionScorer,
|
||||
AutoSelectionScorer,
|
||||
BasedOn,
|
||||
Embed,
|
||||
Embedder,
|
||||
Policy,
|
||||
SelectionScorer,
|
||||
ToSelectFrom,
|
||||
VwPolicy,
|
||||
)
|
||||
|
||||
import logging
|
||||
from langchain.chains.rl_chain.pick_best_chain import PickBest
|
||||
|
||||
|
||||
def configure_logger():
|
||||
@ -26,3 +26,15 @@ def configure_logger():
|
||||
|
||||
|
||||
configure_logger()
|
||||
|
||||
__all__ = [
|
||||
"PickBest",
|
||||
"Embed",
|
||||
"BasedOn",
|
||||
"ToSelectFrom",
|
||||
"SelectionScorer",
|
||||
"AutoSelectionScorer",
|
||||
"Embedder",
|
||||
"Policy",
|
||||
"VwPolicy",
|
||||
]
|
||||
|
@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
import logging
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple, Union
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForChainRun
|
||||
from langchain.chains.base import Chain
|
||||
@ -19,6 +19,9 @@ from langchain.prompts import (
|
||||
)
|
||||
from langchain.pydantic_v1 import BaseModel, Extra, root_validator
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import vowpal_wabbit_next as vw
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ -85,8 +88,6 @@ def EmbedAndKeep(anything):
|
||||
|
||||
|
||||
def parse_lines(parser: "vw.TextFormatParser", input_str: str) -> List["vw.Example"]:
|
||||
import vowpal_wabbit_next as vw
|
||||
|
||||
return [parser.parse_line(line) for line in input_str.split("\n")]
|
||||
|
||||
|
||||
@ -113,8 +114,11 @@ def get_based_on_and_to_select_from(inputs: Dict[str, Any]):
|
||||
|
||||
|
||||
def prepare_inputs_for_autoembed(inputs: Dict[str, Any]):
|
||||
# go over all the inputs and if something is either wrapped in _ToSelectFrom or _BasedOn, and if
|
||||
# their inner values are not already _Embed, then wrap them in EmbedAndKeep while retaining their _ToSelectFrom or _BasedOn status
|
||||
"""
|
||||
go over all the inputs and if something is either wrapped in _ToSelectFrom or _BasedOn, and if their inner values are not already _Embed,
|
||||
then wrap them in EmbedAndKeep while retaining their _ToSelectFrom or _BasedOn status
|
||||
""" # noqa: E501
|
||||
|
||||
next_inputs = inputs.copy()
|
||||
for k, v in next_inputs.items():
|
||||
if isinstance(v, _ToSelectFrom) or isinstance(v, _BasedOn):
|
||||
@ -219,13 +223,18 @@ class AutoSelectionScorer(SelectionScorer, BaseModel):
|
||||
@staticmethod
|
||||
def get_default_system_prompt() -> SystemMessagePromptTemplate:
|
||||
return SystemMessagePromptTemplate.from_template(
|
||||
"PLEASE RESPOND ONLY WITH A SINGLE FLOAT AND NO OTHER TEXT EXPLANATION\n You are a strict judge that is called on to rank a response based on given criteria.\
|
||||
You must respond with your ranking by providing a single float within the range [0, 1], 0 being very bad response and 1 being very good response."
|
||||
"PLEASE RESPOND ONLY WITH A SINGLE FLOAT AND NO OTHER TEXT EXPLANATION\n \
|
||||
You are a strict judge that is called on to rank a response based on \
|
||||
given criteria. You must respond with your ranking by providing a \
|
||||
single float within the range [0, 1], 0 being very bad \
|
||||
response and 1 being very good response."
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_default_prompt() -> ChatPromptTemplate:
|
||||
human_template = 'Given this based_on "{rl_chain_selected_based_on}" as the most important attribute, rank how good or bad this text is: "{llm_response}".'
|
||||
human_template = 'Given this based_on "{rl_chain_selected_based_on}" \
|
||||
as the most important attribute, rank how good or bad this text is: \
|
||||
"{llm_response}".'
|
||||
human_message_prompt = HumanMessagePromptTemplate.from_template(human_template)
|
||||
default_system_prompt = AutoSelectionScorer.get_default_system_prompt()
|
||||
chat_prompt = ChatPromptTemplate.from_messages(
|
||||
@ -260,25 +269,36 @@ class AutoSelectionScorer(SelectionScorer, BaseModel):
|
||||
return resp
|
||||
except Exception as e:
|
||||
raise RuntimeError(
|
||||
f"The llm did not manage to rank the response as expected, there is always the option to try again or tweak the reward prompt. Error: {e}"
|
||||
f"The auto selection scorer did not manage to score the response, \
|
||||
there is always the option to try again or tweak the reward prompt.\
|
||||
Error: {e}"
|
||||
)
|
||||
|
||||
|
||||
class RLChain(Chain):
|
||||
"""
|
||||
RLChain class that utilizes the Vowpal Wabbit (VW) model for personalization.
|
||||
The `RLChain` class leverages the Vowpal Wabbit (VW) model as a learned policy for reinforcement learning.
|
||||
|
||||
Attributes:
|
||||
model_loading (bool, optional): If set to True, the chain will attempt to load an existing VW model from the latest checkpoint file in the {model_save_dir} directory (current directory if none specified). If set to False, it will start training from scratch, potentially overwriting existing files. Defaults to True.
|
||||
large_action_spaces (bool, optional): If set to True and vw_cmd has not been specified in the constructor, it will enable large action spaces
|
||||
vw_cmd (List[str], optional): Advanced users can set the VW command line to whatever they want, as long as it is compatible with the Type that is specified (Type Enum)
|
||||
model_save_dir (str, optional): The directory to save the VW model to. Defaults to the current directory.
|
||||
selection_scorer (SelectionScorer): If set, the chain will check the response using the provided selection_scorer and the VW model will be updated with the result. Defaults to None.
|
||||
- llm_chain (Chain): Represents the underlying Language Model chain.
|
||||
- prompt (BasePromptTemplate): The template for the base prompt.
|
||||
- selection_scorer (Union[SelectionScorer, None]): Scorer for the selection. Can be set to None.
|
||||
- policy (Optional[Policy]): The policy used by the chain to learn to populate a dynamic prompt.
|
||||
- auto_embed (bool): Determines if embedding should be automatic. Default is True.
|
||||
- metrics (Optional[MetricsTracker]): Tracker for metrics, can be set to None.
|
||||
|
||||
Initialization Attributes:
|
||||
- feature_embedder (Embedder): Embedder used for the `BasedOn` and `ToSelectFrom` inputs.
|
||||
- model_save_dir (str, optional): Directory for saving the VW model. Default is the current directory.
|
||||
- reset_model (bool): If set to True, the model starts training from scratch. Default is False.
|
||||
- vw_cmd (List[str], optional): Command line arguments for the VW model.
|
||||
- policy (VwPolicy): Policy used by the chain.
|
||||
- vw_logs (Optional[Union[str, os.PathLike]]): Path for the VW logs.
|
||||
- metrics_step (int): Step for the metrics tracker. Default is -1.
|
||||
|
||||
Notes:
|
||||
The class creates a VW model instance using the provided arguments. Before the chain object is destroyed the save_progress() function can be called. If it is called, the learned VW model is saved to a file in the current directory named `model-<checkpoint>.vw`. Checkpoints start at 1 and increment monotonically.
|
||||
When making predictions, VW is first called to choose action(s) which are then passed into the prompt with the key `{actions}`. After action selection, the LLM (Language Model) is called with the prompt populated by the chosen action(s), and the response is returned.
|
||||
"""
|
||||
The class initializes the VW model using the provided arguments. If `selection_scorer` is not provided, a warning is logged, indicating that no reinforcement learning will occur unless the `update_with_delayed_score` method is called.
|
||||
""" # noqa: E501
|
||||
|
||||
llm_chain: Chain
|
||||
|
||||
@ -306,7 +326,9 @@ class RLChain(Chain):
|
||||
super().__init__(*args, **kwargs)
|
||||
if self.selection_scorer is None:
|
||||
logger.warning(
|
||||
"No response validator provided, which means that no reinforcement learning will be done in the RL chain unless update_with_delayed_score is called."
|
||||
"No selection scorer provided, which means that no \
|
||||
reinforcement learning will be done in the RL chain \
|
||||
unless update_with_delayed_score is called."
|
||||
)
|
||||
self.policy = policy(
|
||||
model_repo=ModelRepository(
|
||||
@ -346,7 +368,9 @@ class RLChain(Chain):
|
||||
or self.selected_based_on_input_key in inputs.keys()
|
||||
):
|
||||
raise ValueError(
|
||||
f"The rl chain does not accept '{self.selected_input_key}' or '{self.selected_based_on_input_key}' as input keys, they are reserved for internal use during auto reward."
|
||||
f"The rl chain does not accept '{self.selected_input_key}' \
|
||||
or '{self.selected_based_on_input_key}' as input keys, \
|
||||
they are reserved for internal use during auto reward."
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
@ -375,13 +399,13 @@ class RLChain(Chain):
|
||||
self, score: float, event: Event, force_score=False
|
||||
) -> None:
|
||||
"""
|
||||
Learn will be called with the score specified and the actions/embeddings/etc stored in event
|
||||
|
||||
Updates the learned policy with the score provided.
|
||||
Will raise an error if selection_scorer is set, and force_score=True was not provided during the method call
|
||||
"""
|
||||
""" # noqa: E501
|
||||
if self.selection_scorer and not force_score:
|
||||
raise RuntimeError(
|
||||
"The selection scorer is set, and force_score was not set to True. Please set force_score=True to use this function."
|
||||
"The selection scorer is set, and force_score was not set to True. \
|
||||
Please set force_score=True to use this function."
|
||||
)
|
||||
self.metrics.on_feedback(score)
|
||||
self._call_after_scoring_before_learning(event=event, score=score)
|
||||
@ -390,10 +414,7 @@ class RLChain(Chain):
|
||||
|
||||
def set_auto_embed(self, auto_embed: bool) -> None:
|
||||
"""
|
||||
Set whether the chain should auto embed the inputs or not. If set to False, the inputs will not be embedded and the user will need to embed the inputs themselves before calling run.
|
||||
|
||||
Args:
|
||||
auto_embed (bool): Whether the chain should auto embed the inputs or not.
|
||||
Sets whether the chain should auto embed the inputs or not.
|
||||
"""
|
||||
self.auto_embed = auto_embed
|
||||
|
||||
@ -438,7 +459,8 @@ class RLChain(Chain):
|
||||
)
|
||||
except Exception as e:
|
||||
logger.info(
|
||||
f"The LLM was not able to rank and the chain was not able to adjust to this response, error: {e}"
|
||||
f"The selection scorer was not able to score, \
|
||||
and the chain was not able to adjust to this response, error: {e}"
|
||||
)
|
||||
self.metrics.on_feedback(score)
|
||||
event = self._call_after_scoring_before_learning(score=score, event=event)
|
||||
@ -449,16 +471,7 @@ class RLChain(Chain):
|
||||
|
||||
def save_progress(self) -> None:
|
||||
"""
|
||||
This function should be called whenever there is a need to save the progress of the VW (Vowpal Wabbit) model within the chain. It saves the current state of the VW model to a file.
|
||||
|
||||
File Naming Convention:
|
||||
The file will be named using the pattern `model-<checkpoint>.vw`, where `<checkpoint>` is a monotonically increasing number. The numbering starts from 1, and increments by 1 for each subsequent save. If there are already saved checkpoints, the number used for `<checkpoint>` will be the next in the sequence.
|
||||
|
||||
Example:
|
||||
If there are already two saved checkpoints, `model-1.vw` and `model-2.vw`, the next time this function is called, it will save the model as `model-3.vw`.
|
||||
|
||||
Note:
|
||||
Be cautious when deleting or renaming checkpoint files manually, as this could cause the function to reuse checkpoint numbers.
|
||||
This function should be called to save the state of the Vowpal Wabbit model.
|
||||
"""
|
||||
self.policy.save()
|
||||
|
||||
@ -493,7 +506,8 @@ def embed_string_type(
|
||||
|
||||
if namespace is None:
|
||||
raise ValueError(
|
||||
"The default namespace must be provided when embedding a string or _Embed object."
|
||||
"The default namespace must be \
|
||||
provided when embedding a string or _Embed object."
|
||||
)
|
||||
|
||||
return {namespace: keep_str + join_char.join(map(str, encoded))}
|
||||
@ -533,7 +547,7 @@ def embed(
|
||||
namespace: Optional[str] = None,
|
||||
) -> List[Dict[str, Union[str, List[str]]]]:
|
||||
"""
|
||||
Embeds the actions or context using the SentenceTransformer model
|
||||
Embeds the actions or context using the SentenceTransformer model (or a model that has an `encode` function)
|
||||
|
||||
Attributes:
|
||||
to_embed: (Union[Union(str, _Embed(str)), Dict, List[Union(str, _Embed(str))], List[Dict]], required) The text to be embedded, either a string, a list of strings or a dictionary or a list of dictionaries.
|
||||
@ -541,7 +555,7 @@ def embed(
|
||||
model: (Any, required) The model to use for embedding
|
||||
Returns:
|
||||
List[Dict[str, str]]: A list of dictionaries where each dictionary has the namespace as the key and the embedded string as the value
|
||||
"""
|
||||
""" # noqa: E501
|
||||
if (isinstance(to_embed, _Embed) and isinstance(to_embed.value, str)) or isinstance(
|
||||
to_embed, str
|
||||
):
|
||||
|
@ -1,4 +1,7 @@
|
||||
from typing import Optional
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import pandas as pd
|
||||
|
||||
|
||||
class MetricsTracker:
|
||||
|
@ -4,7 +4,10 @@ import logging
|
||||
import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import Sequence, Union
|
||||
from typing import TYPE_CHECKING, Sequence, Union
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import vowpal_wabbit_next as vw
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -35,8 +38,6 @@ class ModelRepository:
|
||||
return len(glob.glob(str(self.folder / "model-????????-??????.vw"))) > 0
|
||||
|
||||
def save(self, workspace: "vw.Workspace") -> None:
|
||||
import vowpal_wabbit_next as vw
|
||||
|
||||
with open(self.model_path, "wb") as f:
|
||||
logger.info(f"storing rl_chain model in: {self.model_path}")
|
||||
f.write(workspace.serialize())
|
||||
|
@ -12,17 +12,18 @@ from langchain.prompts import BasePromptTemplate
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# sentinel object used to distinguish between user didn't supply anything or user explicitly supplied None
|
||||
# sentinel object used to distinguish between
|
||||
# user didn't supply anything or user explicitly supplied None
|
||||
SENTINEL = object()
|
||||
|
||||
|
||||
class PickBestFeatureEmbedder(base.Embedder):
|
||||
"""
|
||||
Contextual Bandit Text Embedder class that embeds the based_on and to_select_from into a format that can be used by VW
|
||||
Text Embedder class that embeds the `BasedOn` and `ToSelectFrom` inputs into a format that can be used by the learning policy
|
||||
|
||||
Attributes:
|
||||
model name (Any, optional): The type of embeddings to be used for feature representation. Defaults to BERT SentenceTransformer.
|
||||
"""
|
||||
""" # noqa E501
|
||||
|
||||
def __init__(self, model: Optional[Any] = None, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
@ -36,7 +37,7 @@ class PickBestFeatureEmbedder(base.Embedder):
|
||||
|
||||
def format(self, event: PickBest.Event) -> str:
|
||||
"""
|
||||
Converts the based_on and to_select_from into a format that can be used by VW
|
||||
Converts the `BasedOn` and `ToSelectFrom` into a format that can be used by VW
|
||||
"""
|
||||
|
||||
cost = None
|
||||
@ -68,14 +69,20 @@ class PickBestFeatureEmbedder(base.Embedder):
|
||||
example_string += "shared "
|
||||
for context_item in context_emb:
|
||||
for ns, based_on in context_item.items():
|
||||
example_string += f"|{ns} {' '.join(based_on) if isinstance(based_on, list) else based_on} "
|
||||
e = " ".join(based_on) if isinstance(based_on, list) else based_on
|
||||
example_string += f"|{ns} {e} "
|
||||
example_string += "\n"
|
||||
|
||||
for i, action in enumerate(action_embs):
|
||||
if cost is not None and chosen_action == i:
|
||||
example_string += f"{chosen_action}:{cost}:{prob} "
|
||||
for ns, action_embedding in action.items():
|
||||
example_string += f"|{ns} {' '.join(action_embedding) if isinstance(action_embedding, list) else action_embedding} "
|
||||
e = (
|
||||
" ".join(action_embedding)
|
||||
if isinstance(action_embedding, list)
|
||||
else action_embedding
|
||||
)
|
||||
example_string += f"|{ns} {e} "
|
||||
example_string += "\n"
|
||||
# Strip the last newline
|
||||
return example_string[:-1]
|
||||
@ -83,33 +90,31 @@ class PickBestFeatureEmbedder(base.Embedder):
|
||||
|
||||
class PickBest(base.RLChain):
|
||||
"""
|
||||
PickBest class that utilizes the Vowpal Wabbit (VW) model for personalization.
|
||||
`PickBest` is a class designed to leverage the Vowpal Wabbit (VW) model for reinforcement learning with a context, with the goal of modifying the prompt before the LLM call.
|
||||
|
||||
The Chain is initialized with a set of potential to_select_from. For each call to the Chain, a specific action will be chosen based on an input based_on.
|
||||
This chosen action is then passed to the prompt that will be utilized in the subsequent call to the LLM (Language Model).
|
||||
Each invocation of the chain's `run()` method should be equipped with a set of potential actions (`ToSelectFrom`) and will result in the selection of a specific action based on the `BasedOn` input. This chosen action then informs the LLM (Language Model) prompt for the subsequent response generation.
|
||||
|
||||
The flow of this chain is:
|
||||
- Chain is initialized
|
||||
- Chain is called input containing the based_on and the List of potential to_select_from
|
||||
- Chain chooses an action based on the based_on
|
||||
- Chain calls the LLM with the chosen action
|
||||
- LLM returns a response
|
||||
- If the selection_scorer is specified, the response is checked against the selection_scorer
|
||||
- The internal model will be updated with the based_on, action, and reward of the response (how good or bad the response was)
|
||||
- The response is returned
|
||||
The standard operation flow of this Chain includes:
|
||||
1. The Chain is invoked with inputs containing the `BasedOn` criteria and a list of potential actions (`ToSelectFrom`).
|
||||
2. An action is selected based on the `BasedOn` input.
|
||||
3. The LLM is called with the dynamic prompt, producing a response.
|
||||
4. If a `selection_scorer` is provided, it is used to score the selection.
|
||||
5. The internal Vowpal Wabbit model is updated with the `BasedOn` input, the chosen `ToSelectFrom` action, and the resulting score from the scorer.
|
||||
6. The final response is returned.
|
||||
|
||||
Expected input dictionary format:
|
||||
- At least one variable encapsulated within `BasedOn` to serve as the selection criteria.
|
||||
- A single list variable within `ToSelectFrom`, representing potential actions for the VW model. This list can take the form of:
|
||||
- A list of strings, e.g., `action = ToSelectFrom(["action1", "action2", "action3"])`
|
||||
- A list of list of strings e.g. `action = ToSelectFrom([["action1", "another identifier of action1"], ["action2", "another identifier of action2"]])`
|
||||
- A list of dictionaries, where each dictionary represents an action with namespace names as keys and corresponding action strings as values. For instance, `action = ToSelectFrom([{"namespace1": ["action1", "another identifier of action1"], "namespace2": "action2"}, {"namespace1": "action3", "namespace2": "action4"}])`.
|
||||
|
||||
input dictionary expects:
|
||||
- at least one variable wrapped in BasedOn which will be the based_on to use for personalization
|
||||
- one variable of a list wrapped in ToSelectFrom which will be the list of to_select_from for the Vowpal Wabbit model to choose from.
|
||||
This list can either be a List of str's or a List of Dict's.
|
||||
- Actions provided as a list of strings e.g. to_select_from = ["action1", "action2", "action3"]
|
||||
- If to_select_from are provided as a list of dictionaries, each action should be a dictionary where the keys are namespace names and the values are the corresponding action strings e.g. to_select_from = [{"namespace1": "action1", "namespace2": "action2"}, {"namespace1": "action3", "namespace2": "action4"}]
|
||||
Extends:
|
||||
RLChain
|
||||
|
||||
Attributes:
|
||||
feature_embedder: (PickBestFeatureEmbedder, optional) The text embedder to use for embedding the based_on and the to_select_from. If not provided, a default embedder is used.
|
||||
"""
|
||||
feature_embedder (PickBestFeatureEmbedder, optional): Is an advanced attribute. Responsible for embedding the `BasedOn` and `ToSelectFrom` inputs. If omitted, a default embedder is utilized.
|
||||
""" # noqa E501
|
||||
|
||||
class Selected(base.Selected):
|
||||
index: Optional[int]
|
||||
@ -169,17 +174,23 @@ class PickBest(base.RLChain):
|
||||
context, actions = base.get_based_on_and_to_select_from(inputs=inputs)
|
||||
if not actions:
|
||||
raise ValueError(
|
||||
"No variables using 'ToSelectFrom' found in the inputs. Please include at least one variable containing a list to select from."
|
||||
"No variables using 'ToSelectFrom' found in the inputs. \
|
||||
Please include at least one variable containing \
|
||||
a list to select from."
|
||||
)
|
||||
|
||||
if len(list(actions.values())) > 1:
|
||||
raise ValueError(
|
||||
"Only one variable using 'ToSelectFrom' can be provided in the inputs for the PickBest chain. Please provide only one variable containing a list to select from."
|
||||
"Only one variable using 'ToSelectFrom' can be provided in the inputs \
|
||||
for the PickBest chain. Please provide only one variable \
|
||||
containing a list to select from."
|
||||
)
|
||||
|
||||
if not context:
|
||||
raise ValueError(
|
||||
"No variables using 'BasedOn' found in the inputs. Please include at least one variable containing information to base the selected of ToSelectFrom on."
|
||||
"No variables using 'BasedOn' found in the inputs. \
|
||||
Please include at least one variable containing information \
|
||||
to base the selected of ToSelectFrom on."
|
||||
)
|
||||
|
||||
event = PickBest.Event(inputs=inputs, to_select_from=actions, based_on=context)
|
||||
@ -231,19 +242,6 @@ class PickBest(base.RLChain):
|
||||
inputs: Dict[str, Any],
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
When chain.run() is called with the given inputs, this function is called. It is responsible for calling the VW model to choose an action (ToSelectFrom) based on the (BasedOn) based_on, and then calling the LLM (Language Model) with the chosen action to generate a response.
|
||||
|
||||
Attributes:
|
||||
inputs: (Dict, required) The inputs to the chain. The inputs must contain a input variables that are wrapped in BasedOn and ToSelectFrom. BasedOn is the based_on that will be used for selecting an ToSelectFrom action that will be passed to the LLM prompt.
|
||||
run_manager: (CallbackManagerForChainRun, optional) The callback manager to use for this run. If not provided, a default callback manager is used.
|
||||
|
||||
Returns:
|
||||
A dictionary containing:
|
||||
- `response`: The response generated by the LLM (Language Model).
|
||||
- `selection_metadata`: A Event object containing all the information needed to learn the reward for the chosen action at a later point. If an automatic selection_scorer is not provided, then this object can be used at a later point with the `update_with_delayed_score()` function to learn the delayed reward and update the Vowpal Wabbit model.
|
||||
- the `score` in the `selection_metadata` object is set to None if an automatic selection_scorer is not provided or if the selection_scorer failed (e.g. LLM timeout or LLM failed to rank correctly).
|
||||
"""
|
||||
return super()._call(run_manager=run_manager, inputs=inputs)
|
||||
|
||||
@property
|
||||
|
@ -1,6 +1,6 @@
|
||||
from typing import Union, Optional
|
||||
from pathlib import Path
|
||||
from os import PathLike
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
|
||||
|
||||
class VwLogger:
|
||||
|
@ -111,7 +111,7 @@ def test_update_with_delayed_score():
|
||||
)
|
||||
assert response["response"] == "hey"
|
||||
selection_metadata = response["selection_metadata"]
|
||||
assert selection_metadata.selected.score == None
|
||||
assert selection_metadata.selected.score is None
|
||||
chain.update_with_delayed_score(event=selection_metadata, score=100)
|
||||
assert selection_metadata.selected.score == 100.0
|
||||
|
||||
@ -157,7 +157,7 @@ def test_default_embeddings():
|
||||
ctx_str_2 = "context2"
|
||||
|
||||
encoded_ctx_str_1 = encoded_text + " ".join(char for char in ctx_str_1)
|
||||
encoded_ctx_str_2 = encoded_text + " ".join(char for char in ctx_str_2)
|
||||
encoded_text + " ".join(char for char in ctx_str_2)
|
||||
|
||||
expected = f"""shared |User {ctx_str_1 + " " + encoded_ctx_str_1} \n|action {str1 + " " + encoded_str1} \n|action {str2 + " " + encoded_str2} \n|action {str3 + " " + encoded_str3} """
|
||||
|
||||
@ -261,7 +261,7 @@ def test_explicitly_no_scorer():
|
||||
# chain llm used for both basic prompt and for scoring
|
||||
assert response["response"] == "hey"
|
||||
selection_metadata = response["selection_metadata"]
|
||||
assert selection_metadata.selected.score == None
|
||||
assert selection_metadata.selected.score is None
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
|
@ -235,12 +235,12 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_emb():
|
||||
str2 = "1"
|
||||
str3 = "2"
|
||||
encoded_str1 = encoded_text + " ".join(char for char in str1)
|
||||
encoded_str2 = encoded_text + " ".join(char for char in str2)
|
||||
encoded_text + " ".join(char for char in str2)
|
||||
encoded_str3 = encoded_text + " ".join(char for char in str3)
|
||||
|
||||
ctx_str_1 = "context1"
|
||||
ctx_str_2 = "context2"
|
||||
encoded_ctx_str_1 = encoded_text + " ".join(char for char in ctx_str_1)
|
||||
encoded_text + " ".join(char for char in ctx_str_1)
|
||||
encoded_ctx_str_2 = encoded_text + " ".join(char for char in ctx_str_2)
|
||||
|
||||
named_actions = {
|
||||
@ -269,12 +269,12 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_embed_and_
|
||||
str2 = "1"
|
||||
str3 = "2"
|
||||
encoded_str1 = encoded_text + " ".join(char for char in str1)
|
||||
encoded_str2 = encoded_text + " ".join(char for char in str2)
|
||||
encoded_text + " ".join(char for char in str2)
|
||||
encoded_str3 = encoded_text + " ".join(char for char in str3)
|
||||
|
||||
ctx_str_1 = "context1"
|
||||
ctx_str_2 = "context2"
|
||||
encoded_ctx_str_1 = encoded_text + " ".join(char for char in ctx_str_1)
|
||||
encoded_text + " ".join(char for char in ctx_str_1)
|
||||
encoded_ctx_str_2 = encoded_text + " ".join(char for char in ctx_str_2)
|
||||
|
||||
named_actions = {
|
||||
|
Loading…
Reference in New Issue
Block a user