diff --git a/libs/langchain/langchain/chains/rl_chain/base.py b/libs/langchain/langchain/chains/rl_chain/base.py index 6e01bb5063..66ead42e71 100644 --- a/libs/langchain/langchain/chains/rl_chain/base.py +++ b/libs/langchain/langchain/chains/rl_chain/base.py @@ -229,6 +229,9 @@ class VwPolicy(Policy): class Embedder(Generic[TEvent], ABC): + def __init__(self, *args: Any, **kwargs: Any): + pass + @abstractmethod def format(self, event: TEvent) -> str: ... @@ -498,8 +501,8 @@ class RLChain(Chain, Generic[TEvent]): ) -> Dict[str, Any]: _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() - if self.auto_embed: - inputs = prepare_inputs_for_autoembed(inputs=inputs) + # if self.auto_embed: + # inputs = prepare_inputs_for_autoembed(inputs=inputs) event: TEvent = self._call_before_predict(inputs=inputs) prediction = self.active_policy.predict(event=event) diff --git a/libs/langchain/langchain/chains/rl_chain/pick_best_chain.py b/libs/langchain/langchain/chains/rl_chain/pick_best_chain.py index 04218d2934..5ed32c4cad 100644 --- a/libs/langchain/langchain/chains/rl_chain/pick_best_chain.py +++ b/libs/langchain/langchain/chains/rl_chain/pick_best_chain.py @@ -53,21 +53,25 @@ class PickBestFeatureEmbedder(base.Embedder[PickBestEvent]): model name (Any, optional): The type of embeddings to be used for feature representation. Defaults to BERT SentenceTransformer. """ # noqa E501 - def __init__(self, model: Optional[Any] = None, *args: Any, **kwargs: Any): + def __init__( + self, auto_embed: bool, model: Optional[Any] = None, *args: Any, **kwargs: Any + ): super().__init__(*args, **kwargs) if model is None: from sentence_transformers import SentenceTransformer - model = SentenceTransformer("bert-base-nli-mean-tokens") + model = SentenceTransformer("all-mpnet-base-v2") + # model = SentenceTransformer("all-MiniLM-L6-v2") self.model = model + self.auto_embed = auto_embed - def format(self, event: PickBestEvent) -> str: - """ - Converts the `BasedOn` and `ToSelectFrom` into a format that can be used by VW - """ + @staticmethod + def _str(embedding): + return " ".join([f"{i}:{e}" for i, e in enumerate(embedding)]) + def get_label(self, event: PickBestEvent) -> tuple: cost = None if event.selected: chosen_action = event.selected.index @@ -77,7 +81,11 @@ class PickBestFeatureEmbedder(base.Embedder[PickBestEvent]): else None ) prob = event.selected.probability + return chosen_action, cost, prob + else: + return None, None, None + def get_context_and_action_embeddings(self, event: PickBestEvent) -> tuple: context_emb = base.embed(event.based_on, self.model) if event.based_on else None to_select_from_var_name, to_select_from = next( iter(event.to_select_from.items()), (None, None) @@ -97,6 +105,97 @@ class PickBestFeatureEmbedder(base.Embedder[PickBestEvent]): raise ValueError( "Context and to_select_from must be provided in the inputs dictionary" ) + return context_emb, action_embs + + def get_indexed_dot_product(self, context_emb: List, action_embs: List) -> Dict: + import numpy as np + + unique_contexts = set() + for context_item in context_emb: + for ns, ee in context_item.items(): + if isinstance(ee, list): + for ea in ee: + unique_contexts.add(f"{ns}={ea}") + else: + unique_contexts.add(f"{ns}={ee}") + + encoded_contexts = self.model.encode(list(unique_contexts)) + context_embeddings = dict(zip(unique_contexts, encoded_contexts)) + + unique_actions = set() + for action in action_embs: + for ns, e in action.items(): + if isinstance(e, list): + for ea in e: + unique_actions.add(f"{ns}={ea}") + else: + unique_actions.add(f"{ns}={e}") + + encoded_actions = self.model.encode(list(unique_actions)) + action_embeddings = dict(zip(unique_actions, encoded_actions)) + + action_matrix = np.stack([v for k, v in action_embeddings.items()]) + context_matrix = np.stack([v for k, v in context_embeddings.items()]) + dot_product_matrix = np.dot(context_matrix, action_matrix.T) + + indexed_dot_product = {} + + for i, context_key in enumerate(context_embeddings.keys()): + indexed_dot_product[context_key] = {} + for j, action_key in enumerate(action_embeddings.keys()): + indexed_dot_product[context_key][action_key] = dot_product_matrix[i, j] + + return indexed_dot_product + + def format_auto_embed_on(self, event: PickBestEvent) -> str: + chosen_action, cost, prob = self.get_label(event) + context_emb, action_embs = self.get_context_and_action_embeddings(event) + indexed_dot_product = self.get_indexed_dot_product(context_emb, action_embs) + + action_lines = [] + for i, action in enumerate(action_embs): + line_parts = [] + dot_prods = [] + if cost is not None and chosen_action == i: + line_parts.append(f"{chosen_action}:{cost}:{prob}") + for ns, action in action.items(): + line_parts.append(f"|{ns}") + elements = action if isinstance(action, list) else [action] + nsa = [] + for elem in elements: + line_parts.append(f"{elem}") + ns_a = f"{ns}={elem}" + nsa.append(ns_a) + for k,v in indexed_dot_product.items(): + dot_prods.append(v[ns_a]) + nsa = " ".join(nsa) + line_parts.append(f"|# {nsa}") + + line_parts.append(f"|embedding {self._str(dot_prods)}") + action_lines.append(" ".join(line_parts)) + + shared = [] + for item in context_emb: + for ns, context in item.items(): + shared.append(f"|{ns}") + elements = context if isinstance(context, list) else [context] + nsc = [] + for elem in elements: + shared.append(f"{elem}") + nsc.append(f"{ns}={elem}") + nsc = " ".join(nsc) + shared.append(f"|@ {nsc}") + + r = "shared " + " ".join(shared) + "\n" + "\n".join(action_lines) + print(r) + return r + + def format_auto_embed_off(self, event: PickBestEvent) -> str: + """ + Converts the `BasedOn` and `ToSelectFrom` into a format that can be used by VW + """ + chosen_action, cost, prob = self.get_label(event) + context_emb, action_embs = self.get_context_and_action_embeddings(event) example_string = "" example_string += "shared " @@ -120,6 +219,12 @@ class PickBestFeatureEmbedder(base.Embedder[PickBestEvent]): # Strip the last newline return example_string[:-1] + def format(self, event: PickBestEvent) -> str: + if self.auto_embed: + return self.format_auto_embed_on(event) + else: + return self.format_auto_embed_off(event) + class PickBest(base.RLChain[PickBestEvent]): """ @@ -154,12 +259,20 @@ class PickBest(base.RLChain[PickBestEvent]): *args: Any, **kwargs: Any, ): + auto_embed = kwargs.get("auto_embed", False) + vw_cmd = kwargs.get("vw_cmd", []) if not vw_cmd: - vw_cmd = [ + interactions = ["--interactions=::"] + if auto_embed: + interactions = [ + "--interactions=@#", + "--ignore_linear=@", + "--ignore_linear=#", + "--noconstant", + ] + vw_cmd = interactions + [ "--cb_explore_adf", - "--quiet", - "--interactions=::", "--coin", "--squarecb", ] @@ -172,7 +285,7 @@ class PickBest(base.RLChain[PickBestEvent]): feature_embedder = kwargs.get("feature_embedder", None) if not feature_embedder: - feature_embedder = PickBestFeatureEmbedder() + feature_embedder = PickBestFeatureEmbedder(auto_embed=auto_embed) kwargs["feature_embedder"] = feature_embedder super().__init__(*args, **kwargs) diff --git a/libs/langchain/tests/unit_tests/chains/rl_chain/test_pick_best_chain_call.py b/libs/langchain/tests/unit_tests/chains/rl_chain/test_pick_best_chain_call.py index 7bfa5ad550..3678523a04 100644 --- a/libs/langchain/tests/unit_tests/chains/rl_chain/test_pick_best_chain_call.py +++ b/libs/langchain/tests/unit_tests/chains/rl_chain/test_pick_best_chain_call.py @@ -26,7 +26,7 @@ def test_multiple_ToSelectFrom_throws() -> None: chain = pick_best_chain.PickBest.from_llm( llm=llm, prompt=PROMPT, - feature_embedder=pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()), + feature_embedder=pick_best_chain.PickBestFeatureEmbedder(auto_embed=False, model=MockEncoder()), ) actions = ["0", "1", "2"] with pytest.raises(ValueError): @@ -43,7 +43,7 @@ def test_missing_basedOn_from_throws() -> None: chain = pick_best_chain.PickBest.from_llm( llm=llm, prompt=PROMPT, - feature_embedder=pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()), + feature_embedder=pick_best_chain.PickBestFeatureEmbedder(auto_embed=False, model=MockEncoder()), ) actions = ["0", "1", "2"] with pytest.raises(ValueError): @@ -56,7 +56,7 @@ def test_ToSelectFrom_not_a_list_throws() -> None: chain = pick_best_chain.PickBest.from_llm( llm=llm, prompt=PROMPT, - feature_embedder=pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()), + feature_embedder=pick_best_chain.PickBestFeatureEmbedder(auto_embed=False, model=MockEncoder()), ) actions = {"actions": ["0", "1", "2"]} with pytest.raises(ValueError): @@ -75,7 +75,7 @@ def test_update_with_delayed_score_with_auto_validator_throws() -> None: llm=llm, prompt=PROMPT, selection_scorer=rl_chain.AutoSelectionScorer(llm=auto_val_llm), - feature_embedder=pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()), + feature_embedder=pick_best_chain.PickBestFeatureEmbedder(auto_embed=False, model=MockEncoder()), ) actions = ["0", "1", "2"] response = chain.run( @@ -98,7 +98,7 @@ def test_update_with_delayed_score_force() -> None: llm=llm, prompt=PROMPT, selection_scorer=rl_chain.AutoSelectionScorer(llm=auto_val_llm), - feature_embedder=pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()), + feature_embedder=pick_best_chain.PickBestFeatureEmbedder(auto_embed=False, model=MockEncoder()), ) actions = ["0", "1", "2"] response = chain.run( @@ -121,7 +121,7 @@ def test_update_with_delayed_score() -> None: llm=llm, prompt=PROMPT, selection_scorer=None, - feature_embedder=pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()), + feature_embedder=pick_best_chain.PickBestFeatureEmbedder(auto_embed=False, model=MockEncoder()), ) actions = ["0", "1", "2"] response = chain.run( @@ -153,7 +153,7 @@ def test_user_defined_scorer() -> None: llm=llm, prompt=PROMPT, selection_scorer=CustomSelectionScorer(), - feature_embedder=pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()), + feature_embedder=pick_best_chain.PickBestFeatureEmbedder(auto_embed=False, model=MockEncoder()), ) actions = ["0", "1", "2"] response = chain.run( @@ -166,11 +166,11 @@ def test_user_defined_scorer() -> None: @pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers") -def test_auto_embeddings_on() -> None: +def test_everything_embedded() -> None: llm, PROMPT = setup() - feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()) + feature_embedder = pick_best_chain.PickBestFeatureEmbedder(auto_embed=False, model=MockEncoder()) chain = pick_best_chain.PickBest.from_llm( - llm=llm, prompt=PROMPT, feature_embedder=feature_embedder, auto_embed=True + llm=llm, prompt=PROMPT, feature_embedder=feature_embedder, auto_embed=False ) str1 = "0" @@ -189,8 +189,8 @@ def test_auto_embeddings_on() -> None: actions = [str1, str2, str3] response = chain.run( - User=rl_chain.BasedOn(ctx_str_1), - action=rl_chain.ToSelectFrom(actions), + User=rl_chain.EmbedAndKeep(rl_chain.BasedOn(ctx_str_1)), + action=rl_chain.EmbedAndKeep(rl_chain.ToSelectFrom(actions)), ) selection_metadata = response["selection_metadata"] vw_str = feature_embedder.format(selection_metadata) @@ -200,7 +200,7 @@ def test_auto_embeddings_on() -> None: @pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers") def test_default_auto_embedder_is_off() -> None: llm, PROMPT = setup() - feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()) + feature_embedder = pick_best_chain.PickBestFeatureEmbedder(auto_embed=False, model=MockEncoder()) chain = pick_best_chain.PickBest.from_llm( llm=llm, prompt=PROMPT, feature_embedder=feature_embedder ) @@ -226,7 +226,7 @@ def test_default_auto_embedder_is_off() -> None: @pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers") def test_default_embeddings_off() -> None: llm, PROMPT = setup() - feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()) + feature_embedder = pick_best_chain.PickBestFeatureEmbedder(auto_embed=False, model=MockEncoder()) chain = pick_best_chain.PickBest.from_llm( llm=llm, prompt=PROMPT, feature_embedder=feature_embedder, auto_embed=False ) @@ -252,7 +252,7 @@ def test_default_embeddings_off() -> None: @pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers") def test_default_embeddings_mixed_w_explicit_user_embeddings() -> None: llm, PROMPT = setup() - feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()) + feature_embedder = pick_best_chain.PickBestFeatureEmbedder(auto_embed=True, model=MockEncoder()) chain = pick_best_chain.PickBest.from_llm( llm=llm, prompt=PROMPT, feature_embedder=feature_embedder, auto_embed=True ) @@ -291,7 +291,7 @@ def test_default_no_scorer_specified() -> None: chain = pick_best_chain.PickBest.from_llm( llm=chain_llm, prompt=PROMPT, - feature_embedder=pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()), + feature_embedder=pick_best_chain.PickBestFeatureEmbedder(auto_embed=False, model=MockEncoder()), ) response = chain.run( User=rl_chain.BasedOn("Context"), @@ -310,7 +310,7 @@ def test_explicitly_no_scorer() -> None: llm=llm, prompt=PROMPT, selection_scorer=None, - feature_embedder=pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()), + feature_embedder=pick_best_chain.PickBestFeatureEmbedder(auto_embed=False, model=MockEncoder()), ) response = chain.run( User=rl_chain.BasedOn("Context"), @@ -330,7 +330,7 @@ def test_auto_scorer_with_user_defined_llm() -> None: llm=llm, prompt=PROMPT, selection_scorer=rl_chain.AutoSelectionScorer(llm=scorer_llm), - feature_embedder=pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()), + feature_embedder=pick_best_chain.PickBestFeatureEmbedder(auto_embed=False, model=MockEncoder()), ) response = chain.run( User=rl_chain.BasedOn("Context"), @@ -348,7 +348,7 @@ def test_calling_chain_w_reserved_inputs_throws() -> None: chain = pick_best_chain.PickBest.from_llm( llm=llm, prompt=PROMPT, - feature_embedder=pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()), + feature_embedder=pick_best_chain.PickBestFeatureEmbedder(auto_embed=False, model=MockEncoder()), ) with pytest.raises(ValueError): chain.run( @@ -371,7 +371,7 @@ def test_activate_and_deactivate_scorer() -> None: llm=llm, prompt=PROMPT, selection_scorer=pick_best_chain.base.AutoSelectionScorer(llm=scorer_llm), - feature_embedder=pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()), + feature_embedder=pick_best_chain.PickBestFeatureEmbedder(auto_embed=False, model=MockEncoder()), ) response = chain.run( User=pick_best_chain.base.BasedOn("Context"), diff --git a/libs/langchain/tests/unit_tests/chains/rl_chain/test_pick_best_text_embedder.py b/libs/langchain/tests/unit_tests/chains/rl_chain/test_pick_best_text_embedder.py index 8683e3b0e5..734dae8d25 100644 --- a/libs/langchain/tests/unit_tests/chains/rl_chain/test_pick_best_text_embedder.py +++ b/libs/langchain/tests/unit_tests/chains/rl_chain/test_pick_best_text_embedder.py @@ -9,7 +9,7 @@ encoded_keyword = "[encoded]" @pytest.mark.requires("vowpal_wabbit_next") def test_pickbest_textembedder_missing_context_throws() -> None: - feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()) + feature_embedder = pick_best_chain.PickBestFeatureEmbedder(auto_embed=False, model=MockEncoder()) named_action = {"action": ["0", "1", "2"]} event = pick_best_chain.PickBestEvent( inputs={}, to_select_from=named_action, based_on={} @@ -20,7 +20,7 @@ def test_pickbest_textembedder_missing_context_throws() -> None: @pytest.mark.requires("vowpal_wabbit_next") def test_pickbest_textembedder_missing_actions_throws() -> None: - feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()) + feature_embedder = pick_best_chain.PickBestFeatureEmbedder(auto_embed=False, model=MockEncoder()) event = pick_best_chain.PickBestEvent( inputs={}, to_select_from={}, based_on={"context": "context"} ) @@ -30,7 +30,7 @@ def test_pickbest_textembedder_missing_actions_throws() -> None: @pytest.mark.requires("vowpal_wabbit_next") def test_pickbest_textembedder_no_label_no_emb() -> None: - feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()) + feature_embedder = pick_best_chain.PickBestFeatureEmbedder(auto_embed=False, model=MockEncoder()) named_actions = {"action1": ["0", "1", "2"]} expected = """shared |context context \n|action1 0 \n|action1 1 \n|action1 2 """ event = pick_best_chain.PickBestEvent( @@ -42,7 +42,7 @@ def test_pickbest_textembedder_no_label_no_emb() -> None: @pytest.mark.requires("vowpal_wabbit_next") def test_pickbest_textembedder_w_label_no_score_no_emb() -> None: - feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()) + feature_embedder = pick_best_chain.PickBestFeatureEmbedder(auto_embed=False, model=MockEncoder()) named_actions = {"action1": ["0", "1", "2"]} expected = """shared |context context \n|action1 0 \n|action1 1 \n|action1 2 """ selected = pick_best_chain.PickBestSelected(index=0, probability=1.0) @@ -58,7 +58,7 @@ def test_pickbest_textembedder_w_label_no_score_no_emb() -> None: @pytest.mark.requires("vowpal_wabbit_next") def test_pickbest_textembedder_w_full_label_no_emb() -> None: - feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()) + feature_embedder = pick_best_chain.PickBestFeatureEmbedder(auto_embed=False, model=MockEncoder()) named_actions = {"action1": ["0", "1", "2"]} expected = ( """shared |context context \n0:-0.0:1.0 |action1 0 \n|action1 1 \n|action1 2 """ @@ -76,7 +76,7 @@ def test_pickbest_textembedder_w_full_label_no_emb() -> None: @pytest.mark.requires("vowpal_wabbit_next") def test_pickbest_textembedder_w_full_label_w_emb() -> None: - feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()) + feature_embedder = pick_best_chain.PickBestFeatureEmbedder(auto_embed=False, model=MockEncoder()) str1 = "0" str2 = "1" str3 = "2" @@ -100,7 +100,7 @@ def test_pickbest_textembedder_w_full_label_w_emb() -> None: @pytest.mark.requires("vowpal_wabbit_next") def test_pickbest_textembedder_w_full_label_w_embed_and_keep() -> None: - feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()) + feature_embedder = pick_best_chain.PickBestFeatureEmbedder(auto_embed=False, model=MockEncoder()) str1 = "0" str2 = "1" str3 = "2" @@ -124,7 +124,7 @@ def test_pickbest_textembedder_w_full_label_w_embed_and_keep() -> None: @pytest.mark.requires("vowpal_wabbit_next") def test_pickbest_textembedder_more_namespaces_no_label_no_emb() -> None: - feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()) + feature_embedder = pick_best_chain.PickBestFeatureEmbedder(auto_embed=False, model=MockEncoder()) named_actions = {"action1": [{"a": "0", "b": "0"}, "1", "2"]} context = {"context1": "context1", "context2": "context2"} expected = """shared |context1 context1 |context2 context2 \n|a 0 |b 0 \n|action1 1 \n|action1 2 """ # noqa: E501 @@ -137,7 +137,7 @@ def test_pickbest_textembedder_more_namespaces_no_label_no_emb() -> None: @pytest.mark.requires("vowpal_wabbit_next") def test_pickbest_textembedder_more_namespaces_w_label_no_emb() -> None: - feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()) + feature_embedder = pick_best_chain.PickBestFeatureEmbedder(auto_embed=False, model=MockEncoder()) named_actions = {"action1": [{"a": "0", "b": "0"}, "1", "2"]} context = {"context1": "context1", "context2": "context2"} expected = """shared |context1 context1 |context2 context2 \n|a 0 |b 0 \n|action1 1 \n|action1 2 """ # noqa: E501 @@ -151,7 +151,7 @@ def test_pickbest_textembedder_more_namespaces_w_label_no_emb() -> None: @pytest.mark.requires("vowpal_wabbit_next") def test_pickbest_textembedder_more_namespaces_w_full_label_no_emb() -> None: - feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()) + feature_embedder = pick_best_chain.PickBestFeatureEmbedder(auto_embed=False, model=MockEncoder()) named_actions = {"action1": [{"a": "0", "b": "0"}, "1", "2"]} context = {"context1": "context1", "context2": "context2"} expected = """shared |context1 context1 |context2 context2 \n0:-0.0:1.0 |a 0 |b 0 \n|action1 1 \n|action1 2 """ # noqa: E501 @@ -165,7 +165,7 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_no_emb() -> None: @pytest.mark.requires("vowpal_wabbit_next") def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_emb() -> None: - feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()) + feature_embedder = pick_best_chain.PickBestFeatureEmbedder(auto_embed=False, model=MockEncoder()) str1 = "0" str2 = "1" @@ -198,7 +198,7 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_emb() -> None def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_embed_and_keep() -> ( None ): - feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()) + feature_embedder = pick_best_chain.PickBestFeatureEmbedder(auto_embed=False, model=MockEncoder()) str1 = "0" str2 = "1" @@ -231,7 +231,7 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_embed_and_kee @pytest.mark.requires("vowpal_wabbit_next") def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_emb() -> None: - feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()) + feature_embedder = pick_best_chain.PickBestFeatureEmbedder(auto_embed=False, model=MockEncoder()) str1 = "0" str2 = "1" @@ -263,7 +263,7 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_emb() -> N @pytest.mark.requires("vowpal_wabbit_next") def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_emakeep() -> None: - feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()) + feature_embedder = pick_best_chain.PickBestFeatureEmbedder(auto_embed=False, model=MockEncoder()) str1 = "0" str2 = "1" @@ -298,7 +298,7 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_emakeep() @pytest.mark.requires("vowpal_wabbit_next") def test_raw_features_underscored() -> None: - feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()) + feature_embedder = pick_best_chain.PickBestFeatureEmbedder(auto_embed=False, model=MockEncoder()) str1 = "this is a long string" str1_underscored = str1.replace(" ", "_") encoded_str1 = rl_chain.stringify_embedding(list(encoded_keyword + str1))