fixes and tests

pull/10242/head
olgavrou 1 year ago
parent b162f1c8e1
commit ca163f0ee6

@ -118,8 +118,7 @@ def get_based_on_and_to_select_from(inputs: Dict[str, Any]) -> Tuple[Dict, Dict]
if not to_select_from:
raise ValueError(
"No variables using 'ToSelectFrom' found in the inputs. \
Please include at least one variable containing a list to select from."
"No variables using 'ToSelectFrom' found in the inputs. Please include at least one variable containing a list to select from." # noqa: E501
)
based_on = {
@ -303,9 +302,7 @@ class AutoSelectionScorer(SelectionScorer[Event], BaseModel):
return resp
except Exception as e:
raise RuntimeError(
f"The auto selection scorer did not manage to score the response, \
there is always the option to try again or tweak the reward prompt.\
Error: {e}"
f"The auto selection scorer did not manage to score the response, there is always the option to try again or tweak the reward prompt. Error: {e}" # noqa: E501
)
@ -426,8 +423,7 @@ class RLChain(Chain, Generic[TEvent]):
""" # noqa: E501
if self._can_use_selection_scorer() and not force_score:
raise RuntimeError(
"The selection scorer is set, and force_score was not set to True. \
Please set force_score=True to use this function."
"The selection scorer is set, and force_score was not set to True. Please set force_score=True to use this function." # noqa: E501
)
if self.metrics:
self.metrics.on_feedback(score)
@ -461,9 +457,7 @@ class RLChain(Chain, Generic[TEvent]):
or self.selected_based_on_input_key in inputs.keys()
):
raise ValueError(
f"The rl chain does not accept '{self.selected_input_key}' \
or '{self.selected_based_on_input_key}' as input keys, \
they are reserved for internal use during auto reward."
f"The rl chain does not accept '{self.selected_input_key}' or '{self.selected_based_on_input_key}' as input keys, they are reserved for internal use during auto reward." # noqa: E501
)
def _can_use_selection_scorer(self) -> bool:
@ -501,9 +495,6 @@ class RLChain(Chain, Generic[TEvent]):
) -> Dict[str, Any]:
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
# if self.auto_embed:
# inputs = prepare_inputs_for_autoembed(inputs=inputs)
event: TEvent = self._call_before_predict(inputs=inputs)
prediction = self.active_policy.predict(event=event)
if self.metrics:
@ -576,8 +567,7 @@ def embed_string_type(
if namespace is None:
raise ValueError(
"The default namespace must be \
provided when embedding a string or _Embed object."
"The default namespace must be provided when embedding a string or _Embed object." # noqa: E501
)
return {namespace: keep_str + encoded}

@ -118,7 +118,7 @@ class PickBestFeatureEmbedder(base.Embedder[PickBestEvent]):
unique_contexts.add(f"{ns}={ea}")
else:
unique_contexts.add(f"{ns}={ee}")
encoded_contexts = self.model.encode(list(unique_contexts))
context_embeddings = dict(zip(unique_contexts, encoded_contexts))
@ -144,9 +144,9 @@ class PickBestFeatureEmbedder(base.Embedder[PickBestEvent]):
indexed_dot_product[context_key] = {}
for j, action_key in enumerate(action_embeddings.keys()):
indexed_dot_product[context_key][action_key] = dot_product_matrix[i, j]
return indexed_dot_product
def format_auto_embed_on(self, event: PickBestEvent) -> str:
chosen_action, cost, prob = self.get_label(event)
context_emb, action_embs = self.get_context_and_action_embeddings(event)
@ -166,12 +166,12 @@ class PickBestFeatureEmbedder(base.Embedder[PickBestEvent]):
line_parts.append(f"{elem}")
ns_a = f"{ns}={elem}"
nsa.append(ns_a)
for k,v in indexed_dot_product.items():
for k, v in indexed_dot_product.items():
dot_prods.append(v[ns_a])
nsa = " ".join(nsa)
line_parts.append(f"|# {nsa}")
line_parts.append(f"|embedding {self._str(dot_prods)}")
line_parts.append(f"|dotprod {self._str(dot_prods)}")
action_lines.append(" ".join(line_parts))
shared = []
@ -186,9 +186,7 @@ class PickBestFeatureEmbedder(base.Embedder[PickBestEvent]):
nsc = " ".join(nsc)
shared.append(f"|@ {nsc}")
r = "shared " + " ".join(shared) + "\n" + "\n".join(action_lines)
print(r)
return r
return "shared " + " ".join(shared) + "\n" + "\n".join(action_lines)
def format_auto_embed_off(self, event: PickBestEvent) -> str:
"""
@ -262,29 +260,35 @@ class PickBest(base.RLChain[PickBestEvent]):
auto_embed = kwargs.get("auto_embed", False)
vw_cmd = kwargs.get("vw_cmd", [])
if not vw_cmd:
if vw_cmd:
if "--cb_explore_adf" not in vw_cmd:
raise ValueError(
"If vw_cmd is specified, it must include --cb_explore_adf"
)
else:
interactions = ["--interactions=::"]
if auto_embed:
interactions = [
"--interactions=@#",
"--ignore_linear=@",
"--ignore_linear=#",
"--noconstant",
]
vw_cmd = interactions + [
"--cb_explore_adf",
"--coin",
"--squarecb",
"--quiet",
]
else:
if "--cb_explore_adf" not in vw_cmd:
raise ValueError(
"If vw_cmd is specified, it must include --cb_explore_adf"
)
kwargs["vw_cmd"] = vw_cmd
feature_embedder = kwargs.get("feature_embedder", None)
if not feature_embedder:
if feature_embedder:
if "auto_embed" in kwargs:
logger.warning(
"auto_embed will take no effect when explicit feature_embedder is provided" # noqa E501
)
else:
feature_embedder = PickBestFeatureEmbedder(auto_embed=auto_embed)
kwargs["feature_embedder"] = feature_embedder
@ -294,23 +298,17 @@ class PickBest(base.RLChain[PickBestEvent]):
context, actions = base.get_based_on_and_to_select_from(inputs=inputs)
if not actions:
raise ValueError(
"No variables using 'ToSelectFrom' found in the inputs. \
Please include at least one variable containing \
a list to select from."
"No variables using 'ToSelectFrom' found in the inputs. Please include at least one variable containing a list to select from." # noqa E501
)
if len(list(actions.values())) > 1:
raise ValueError(
"Only one variable using 'ToSelectFrom' can be provided in the inputs \
for the PickBest chain. Please provide only one variable \
containing a list to select from."
"Only one variable using 'ToSelectFrom' can be provided in the inputs for the PickBest chain. Please provide only one variable containing a list to select from." # noqa E501
)
if not context:
raise ValueError(
"No variables using 'BasedOn' found in the inputs. \
Please include at least one variable containing information \
to base the selected of ToSelectFrom on."
"No variables using 'BasedOn' found in the inputs. Please include at least one variable containing information to base the selected of ToSelectFrom on." # noqa E501
)
event = PickBestEvent(inputs=inputs, to_select_from=actions, based_on=context)

@ -1,7 +1,7 @@
from typing import Any, Dict
import pytest
from test_utils import MockEncoder
from test_utils import MockEncoder, MockEncoderReturnsList
import langchain.chains.rl_chain.base as rl_chain
import langchain.chains.rl_chain.pick_best_chain as pick_best_chain
@ -26,7 +26,9 @@ def test_multiple_ToSelectFrom_throws() -> None:
chain = pick_best_chain.PickBest.from_llm(
llm=llm,
prompt=PROMPT,
feature_embedder=pick_best_chain.PickBestFeatureEmbedder(auto_embed=False, model=MockEncoder()),
feature_embedder=pick_best_chain.PickBestFeatureEmbedder(
auto_embed=False, model=MockEncoder()
),
)
actions = ["0", "1", "2"]
with pytest.raises(ValueError):
@ -43,7 +45,9 @@ def test_missing_basedOn_from_throws() -> None:
chain = pick_best_chain.PickBest.from_llm(
llm=llm,
prompt=PROMPT,
feature_embedder=pick_best_chain.PickBestFeatureEmbedder(auto_embed=False, model=MockEncoder()),
feature_embedder=pick_best_chain.PickBestFeatureEmbedder(
auto_embed=False, model=MockEncoder()
),
)
actions = ["0", "1", "2"]
with pytest.raises(ValueError):
@ -56,7 +60,9 @@ def test_ToSelectFrom_not_a_list_throws() -> None:
chain = pick_best_chain.PickBest.from_llm(
llm=llm,
prompt=PROMPT,
feature_embedder=pick_best_chain.PickBestFeatureEmbedder(auto_embed=False, model=MockEncoder()),
feature_embedder=pick_best_chain.PickBestFeatureEmbedder(
auto_embed=False, model=MockEncoder()
),
)
actions = {"actions": ["0", "1", "2"]}
with pytest.raises(ValueError):
@ -75,7 +81,9 @@ def test_update_with_delayed_score_with_auto_validator_throws() -> None:
llm=llm,
prompt=PROMPT,
selection_scorer=rl_chain.AutoSelectionScorer(llm=auto_val_llm),
feature_embedder=pick_best_chain.PickBestFeatureEmbedder(auto_embed=False, model=MockEncoder()),
feature_embedder=pick_best_chain.PickBestFeatureEmbedder(
auto_embed=False, model=MockEncoder()
),
)
actions = ["0", "1", "2"]
response = chain.run(
@ -98,7 +106,9 @@ def test_update_with_delayed_score_force() -> None:
llm=llm,
prompt=PROMPT,
selection_scorer=rl_chain.AutoSelectionScorer(llm=auto_val_llm),
feature_embedder=pick_best_chain.PickBestFeatureEmbedder(auto_embed=False, model=MockEncoder()),
feature_embedder=pick_best_chain.PickBestFeatureEmbedder(
auto_embed=False, model=MockEncoder()
),
)
actions = ["0", "1", "2"]
response = chain.run(
@ -121,7 +131,9 @@ def test_update_with_delayed_score() -> None:
llm=llm,
prompt=PROMPT,
selection_scorer=None,
feature_embedder=pick_best_chain.PickBestFeatureEmbedder(auto_embed=False, model=MockEncoder()),
feature_embedder=pick_best_chain.PickBestFeatureEmbedder(
auto_embed=False, model=MockEncoder()
),
)
actions = ["0", "1", "2"]
response = chain.run(
@ -153,7 +165,9 @@ def test_user_defined_scorer() -> None:
llm=llm,
prompt=PROMPT,
selection_scorer=CustomSelectionScorer(),
feature_embedder=pick_best_chain.PickBestFeatureEmbedder(auto_embed=False, model=MockEncoder()),
feature_embedder=pick_best_chain.PickBestFeatureEmbedder(
auto_embed=False, model=MockEncoder()
),
)
actions = ["0", "1", "2"]
response = chain.run(
@ -168,7 +182,9 @@ def test_user_defined_scorer() -> None:
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def test_everything_embedded() -> None:
llm, PROMPT = setup()
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(auto_embed=False, model=MockEncoder())
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
auto_embed=False, model=MockEncoder()
)
chain = pick_best_chain.PickBest.from_llm(
llm=llm, prompt=PROMPT, feature_embedder=feature_embedder, auto_embed=False
)
@ -200,7 +216,9 @@ def test_everything_embedded() -> None:
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def test_default_auto_embedder_is_off() -> None:
llm, PROMPT = setup()
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(auto_embed=False, model=MockEncoder())
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
auto_embed=False, model=MockEncoder()
)
chain = pick_best_chain.PickBest.from_llm(
llm=llm, prompt=PROMPT, feature_embedder=feature_embedder
)
@ -224,9 +242,11 @@ def test_default_auto_embedder_is_off() -> None:
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def test_default_embeddings_off() -> None:
def test_default_w_embeddings_off() -> None:
llm, PROMPT = setup()
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(auto_embed=False, model=MockEncoder())
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
auto_embed=False, model=MockEncoder()
)
chain = pick_best_chain.PickBest.from_llm(
llm=llm, prompt=PROMPT, feature_embedder=feature_embedder, auto_embed=False
)
@ -250,29 +270,54 @@ def test_default_embeddings_off() -> None:
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def test_default_embeddings_mixed_w_explicit_user_embeddings() -> None:
def test_default_w_embeddings_on() -> None:
llm, PROMPT = setup()
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(auto_embed=True, model=MockEncoder())
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
auto_embed=True, model=MockEncoderReturnsList()
)
chain = pick_best_chain.PickBest.from_llm(
llm=llm, prompt=PROMPT, feature_embedder=feature_embedder, auto_embed=True
)
str1 = "0"
str2 = "1"
str3 = "2"
encoded_str1 = rl_chain.stringify_embedding(list(encoded_keyword + str1))
encoded_str2 = rl_chain.stringify_embedding(list(encoded_keyword + str2))
encoded_str3 = rl_chain.stringify_embedding(list(encoded_keyword + str3))
ctx_str_1 = "context1"
dot_prod = "dotprod 0:5.0" # dot prod of [1.0, 2.0] and [1.0, 2.0]
expected = f"""shared |User {ctx_str_1} |@ User={ctx_str_1}\n|action {str1} |# action={str1} |{dot_prod}\n|action {str2} |# action={str2} |{dot_prod}""" # noqa
actions = [str1, str2]
response = chain.run(
User=rl_chain.BasedOn(ctx_str_1),
action=rl_chain.ToSelectFrom(actions),
)
selection_metadata = response["selection_metadata"]
vw_str = feature_embedder.format(selection_metadata)
assert vw_str == expected
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def test_default_embeddings_mixed_w_explicit_user_embeddings() -> None:
llm, PROMPT = setup()
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
auto_embed=True, model=MockEncoderReturnsList()
)
chain = pick_best_chain.PickBest.from_llm(
llm=llm, prompt=PROMPT, feature_embedder=feature_embedder, auto_embed=True
)
str1 = "0"
str2 = "1"
encoded_str2 = rl_chain.stringify_embedding([1.0, 2.0])
ctx_str_1 = "context1"
ctx_str_2 = "context2"
encoded_ctx_str_1 = rl_chain.stringify_embedding([1.0, 2.0])
dot_prod = "dotprod 0:5.0 1:5.0" # dot prod of [1.0, 2.0] and [1.0, 2.0]
encoded_ctx_str_1 = rl_chain.stringify_embedding(list(encoded_keyword + ctx_str_1))
encoded_ctx_str_2 = rl_chain.stringify_embedding(list(encoded_keyword + ctx_str_2))
expected = f"""shared |User {encoded_ctx_str_1} |User2 {ctx_str_2 + " " + encoded_ctx_str_2} \n|action {str1 + " " + encoded_str1} \n|action {str2 + " " + encoded_str2} \n|action {encoded_str3} """ # noqa
expected = f"""shared |User {encoded_ctx_str_1} |@ User={encoded_ctx_str_1} |User2 {ctx_str_2} |@ User2={ctx_str_2}\n|action {str1} |# action={str1} |{dot_prod}\n|action {encoded_str2} |# action={encoded_str2} |{dot_prod}""" # noqa
actions = [str1, str2, rl_chain.Embed(str3)]
actions = [str1, rl_chain.Embed(str2)]
response = chain.run(
User=rl_chain.BasedOn(rl_chain.Embed(ctx_str_1)),
@ -291,7 +336,9 @@ def test_default_no_scorer_specified() -> None:
chain = pick_best_chain.PickBest.from_llm(
llm=chain_llm,
prompt=PROMPT,
feature_embedder=pick_best_chain.PickBestFeatureEmbedder(auto_embed=False, model=MockEncoder()),
feature_embedder=pick_best_chain.PickBestFeatureEmbedder(
auto_embed=False, model=MockEncoder()
),
)
response = chain.run(
User=rl_chain.BasedOn("Context"),
@ -310,7 +357,9 @@ def test_explicitly_no_scorer() -> None:
llm=llm,
prompt=PROMPT,
selection_scorer=None,
feature_embedder=pick_best_chain.PickBestFeatureEmbedder(auto_embed=False, model=MockEncoder()),
feature_embedder=pick_best_chain.PickBestFeatureEmbedder(
auto_embed=False, model=MockEncoder()
),
)
response = chain.run(
User=rl_chain.BasedOn("Context"),
@ -330,7 +379,9 @@ def test_auto_scorer_with_user_defined_llm() -> None:
llm=llm,
prompt=PROMPT,
selection_scorer=rl_chain.AutoSelectionScorer(llm=scorer_llm),
feature_embedder=pick_best_chain.PickBestFeatureEmbedder(auto_embed=False, model=MockEncoder()),
feature_embedder=pick_best_chain.PickBestFeatureEmbedder(
auto_embed=False, model=MockEncoder()
),
)
response = chain.run(
User=rl_chain.BasedOn("Context"),
@ -348,7 +399,9 @@ def test_calling_chain_w_reserved_inputs_throws() -> None:
chain = pick_best_chain.PickBest.from_llm(
llm=llm,
prompt=PROMPT,
feature_embedder=pick_best_chain.PickBestFeatureEmbedder(auto_embed=False, model=MockEncoder()),
feature_embedder=pick_best_chain.PickBestFeatureEmbedder(
auto_embed=False, model=MockEncoder()
),
)
with pytest.raises(ValueError):
chain.run(
@ -371,7 +424,9 @@ def test_activate_and_deactivate_scorer() -> None:
llm=llm,
prompt=PROMPT,
selection_scorer=pick_best_chain.base.AutoSelectionScorer(llm=scorer_llm),
feature_embedder=pick_best_chain.PickBestFeatureEmbedder(auto_embed=False, model=MockEncoder()),
feature_embedder=pick_best_chain.PickBestFeatureEmbedder(
auto_embed=False, model=MockEncoder()
),
)
response = chain.run(
User=pick_best_chain.base.BasedOn("Context"),

@ -9,7 +9,9 @@ encoded_keyword = "[encoded]"
@pytest.mark.requires("vowpal_wabbit_next")
def test_pickbest_textembedder_missing_context_throws() -> None:
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(auto_embed=False, model=MockEncoder())
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
auto_embed=False, model=MockEncoder()
)
named_action = {"action": ["0", "1", "2"]}
event = pick_best_chain.PickBestEvent(
inputs={}, to_select_from=named_action, based_on={}
@ -20,7 +22,9 @@ def test_pickbest_textembedder_missing_context_throws() -> None:
@pytest.mark.requires("vowpal_wabbit_next")
def test_pickbest_textembedder_missing_actions_throws() -> None:
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(auto_embed=False, model=MockEncoder())
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
auto_embed=False, model=MockEncoder()
)
event = pick_best_chain.PickBestEvent(
inputs={}, to_select_from={}, based_on={"context": "context"}
)
@ -30,7 +34,9 @@ def test_pickbest_textembedder_missing_actions_throws() -> None:
@pytest.mark.requires("vowpal_wabbit_next")
def test_pickbest_textembedder_no_label_no_emb() -> None:
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(auto_embed=False, model=MockEncoder())
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
auto_embed=False, model=MockEncoder()
)
named_actions = {"action1": ["0", "1", "2"]}
expected = """shared |context context \n|action1 0 \n|action1 1 \n|action1 2 """
event = pick_best_chain.PickBestEvent(
@ -42,7 +48,9 @@ def test_pickbest_textembedder_no_label_no_emb() -> None:
@pytest.mark.requires("vowpal_wabbit_next")
def test_pickbest_textembedder_w_label_no_score_no_emb() -> None:
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(auto_embed=False, model=MockEncoder())
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
auto_embed=False, model=MockEncoder()
)
named_actions = {"action1": ["0", "1", "2"]}
expected = """shared |context context \n|action1 0 \n|action1 1 \n|action1 2 """
selected = pick_best_chain.PickBestSelected(index=0, probability=1.0)
@ -58,7 +66,9 @@ def test_pickbest_textembedder_w_label_no_score_no_emb() -> None:
@pytest.mark.requires("vowpal_wabbit_next")
def test_pickbest_textembedder_w_full_label_no_emb() -> None:
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(auto_embed=False, model=MockEncoder())
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
auto_embed=False, model=MockEncoder()
)
named_actions = {"action1": ["0", "1", "2"]}
expected = (
"""shared |context context \n0:-0.0:1.0 |action1 0 \n|action1 1 \n|action1 2 """
@ -76,7 +86,9 @@ def test_pickbest_textembedder_w_full_label_no_emb() -> None:
@pytest.mark.requires("vowpal_wabbit_next")
def test_pickbest_textembedder_w_full_label_w_emb() -> None:
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(auto_embed=False, model=MockEncoder())
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
auto_embed=False, model=MockEncoder()
)
str1 = "0"
str2 = "1"
str3 = "2"
@ -100,7 +112,9 @@ def test_pickbest_textembedder_w_full_label_w_emb() -> None:
@pytest.mark.requires("vowpal_wabbit_next")
def test_pickbest_textembedder_w_full_label_w_embed_and_keep() -> None:
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(auto_embed=False, model=MockEncoder())
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
auto_embed=False, model=MockEncoder()
)
str1 = "0"
str2 = "1"
str3 = "2"
@ -124,7 +138,9 @@ def test_pickbest_textembedder_w_full_label_w_embed_and_keep() -> None:
@pytest.mark.requires("vowpal_wabbit_next")
def test_pickbest_textembedder_more_namespaces_no_label_no_emb() -> None:
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(auto_embed=False, model=MockEncoder())
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
auto_embed=False, model=MockEncoder()
)
named_actions = {"action1": [{"a": "0", "b": "0"}, "1", "2"]}
context = {"context1": "context1", "context2": "context2"}
expected = """shared |context1 context1 |context2 context2 \n|a 0 |b 0 \n|action1 1 \n|action1 2 """ # noqa: E501
@ -137,7 +153,9 @@ def test_pickbest_textembedder_more_namespaces_no_label_no_emb() -> None:
@pytest.mark.requires("vowpal_wabbit_next")
def test_pickbest_textembedder_more_namespaces_w_label_no_emb() -> None:
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(auto_embed=False, model=MockEncoder())
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
auto_embed=False, model=MockEncoder()
)
named_actions = {"action1": [{"a": "0", "b": "0"}, "1", "2"]}
context = {"context1": "context1", "context2": "context2"}
expected = """shared |context1 context1 |context2 context2 \n|a 0 |b 0 \n|action1 1 \n|action1 2 """ # noqa: E501
@ -151,7 +169,9 @@ def test_pickbest_textembedder_more_namespaces_w_label_no_emb() -> None:
@pytest.mark.requires("vowpal_wabbit_next")
def test_pickbest_textembedder_more_namespaces_w_full_label_no_emb() -> None:
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(auto_embed=False, model=MockEncoder())
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
auto_embed=False, model=MockEncoder()
)
named_actions = {"action1": [{"a": "0", "b": "0"}, "1", "2"]}
context = {"context1": "context1", "context2": "context2"}
expected = """shared |context1 context1 |context2 context2 \n0:-0.0:1.0 |a 0 |b 0 \n|action1 1 \n|action1 2 """ # noqa: E501
@ -165,7 +185,9 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_no_emb() -> None:
@pytest.mark.requires("vowpal_wabbit_next")
def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_emb() -> None:
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(auto_embed=False, model=MockEncoder())
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
auto_embed=False, model=MockEncoder()
)
str1 = "0"
str2 = "1"
@ -198,7 +220,9 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_emb() -> None
def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_embed_and_keep() -> (
None
):
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(auto_embed=False, model=MockEncoder())
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
auto_embed=False, model=MockEncoder()
)
str1 = "0"
str2 = "1"
@ -231,7 +255,9 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_embed_and_kee
@pytest.mark.requires("vowpal_wabbit_next")
def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_emb() -> None:
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(auto_embed=False, model=MockEncoder())
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
auto_embed=False, model=MockEncoder()
)
str1 = "0"
str2 = "1"
@ -263,7 +289,9 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_emb() -> N
@pytest.mark.requires("vowpal_wabbit_next")
def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_emakeep() -> None:
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(auto_embed=False, model=MockEncoder())
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
auto_embed=False, model=MockEncoder()
)
str1 = "0"
str2 = "1"
@ -298,7 +326,9 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_emakeep()
@pytest.mark.requires("vowpal_wabbit_next")
def test_raw_features_underscored() -> None:
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(auto_embed=False, model=MockEncoder())
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(
auto_embed=False, model=MockEncoder()
)
str1 = "this is a long string"
str1_underscored = str1.replace(" ", "_")
encoded_str1 = rl_chain.stringify_embedding(list(encoded_keyword + str1))

@ -1,3 +1,15 @@
from typing import Any, List
class MockEncoder:
    """Test double for a sentence-encoder: tags input with an ``[encoded]`` prefix."""

    def encode(self, to_encode: str) -> str:
        """Return *to_encode* prefixed with the literal marker ``[encoded]``."""
        return f"[encoded]{to_encode}"
class MockEncoderReturnsList:
    """Test double for an embedding model that returns numeric vectors.

    A single string encodes to the fixed vector ``[1.0, 2.0]``; a list of
    strings encodes to one such vector per element (mirroring the batch API
    of real sentence encoders).
    """

    def encode(self, to_encode: Any) -> List:
        """Return a fixed embedding for *to_encode*.

        Args:
            to_encode: A single string, or a list of strings to batch-encode.

        Returns:
            ``[1.0, 2.0]`` for a string input, or a list containing one
            ``[1.0, 2.0]`` vector per element for a list input.

        Raises:
            ValueError: If *to_encode* is neither a string nor a list.
        """
        if isinstance(to_encode, str):
            return [1.0, 2.0]
        # isinstance must check the builtin `list`, not `typing.List`:
        # runtime checks against typing aliases are deprecated (Python 3.9+).
        if isinstance(to_encode, list):
            return [[1.0, 2.0] for _ in range(len(to_encode))]
        raise ValueError("Invalid input type for unit test")

Loading…
Cancel
Save