diff --git a/libs/langchain/tests/unit_tests/chains/rl_chain/test_pick_best_chain_call.py b/libs/langchain/tests/unit_tests/chains/rl_chain/test_pick_best_chain_call.py index 7bca6b470d..1b882e932d 100644 --- a/libs/langchain/tests/unit_tests/chains/rl_chain/test_pick_best_chain_call.py +++ b/libs/langchain/tests/unit_tests/chains/rl_chain/test_pick_best_chain_call.py @@ -23,7 +23,11 @@ def setup() -> tuple: @pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers") def test_multiple_ToSelectFrom_throws() -> None: llm, PROMPT = setup() - chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT) + chain = pick_best_chain.PickBest.from_llm( + llm=llm, + prompt=PROMPT, + feature_embedder=pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()), + ) actions = ["0", "1", "2"] with pytest.raises(ValueError): chain.run( @@ -36,7 +40,11 @@ def test_multiple_ToSelectFrom_throws() -> None: @pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers") def test_missing_basedOn_from_throws() -> None: llm, PROMPT = setup() - chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT) + chain = pick_best_chain.PickBest.from_llm( + llm=llm, + prompt=PROMPT, + feature_embedder=pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()), + ) actions = ["0", "1", "2"] with pytest.raises(ValueError): chain.run(action=rl_chain.ToSelectFrom(actions)) @@ -45,7 +53,11 @@ def test_missing_basedOn_from_throws() -> None: @pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers") def test_ToSelectFrom_not_a_list_throws() -> None: llm, PROMPT = setup() - chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT) + chain = pick_best_chain.PickBest.from_llm( + llm=llm, + prompt=PROMPT, + feature_embedder=pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()), + ) actions = {"actions": ["0", "1", "2"]} with pytest.raises(ValueError): chain.run( @@ -63,6 +75,7 @@ def test_update_with_delayed_score_with_auto_validator_throws() -> None: llm=llm, prompt=PROMPT, selection_scorer=rl_chain.AutoSelectionScorer(llm=auto_val_llm), + feature_embedder=pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()), ) actions = ["0", "1", "2"] response = chain.run( @@ -85,6 +98,7 @@ def test_update_with_delayed_score_force() -> None: llm=llm, prompt=PROMPT, selection_scorer=rl_chain.AutoSelectionScorer(llm=auto_val_llm), + feature_embedder=pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()), ) actions = ["0", "1", "2"] response = chain.run( @@ -104,7 +118,10 @@ def test_update_with_delayed_score_force() -> None: def test_update_with_delayed_score() -> None: llm, PROMPT = setup() chain = pick_best_chain.PickBest.from_llm( - llm=llm, prompt=PROMPT, selection_scorer=None + llm=llm, + prompt=PROMPT, + selection_scorer=None, + feature_embedder=pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()), ) actions = ["0", "1", "2"] response = chain.run( @@ -128,7 +145,10 @@ def test_user_defined_scorer() -> None: return score chain = pick_best_chain.PickBest.from_llm( - llm=llm, prompt=PROMPT, selection_scorer=CustomSelectionScorer() + llm=llm, + prompt=PROMPT, + selection_scorer=CustomSelectionScorer(), + feature_embedder=pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()), ) actions = ["0", "1", "2"] response = chain.run( @@ -239,7 +259,11 @@ def test_default_embeddings_mixed_w_explicit_user_embeddings() -> None: def test_default_no_scorer_specified() -> None: _, PROMPT = setup() chain_llm = FakeListChatModel(responses=[100]) - chain = pick_best_chain.PickBest.from_llm(llm=chain_llm, prompt=PROMPT) + chain = pick_best_chain.PickBest.from_llm( + llm=chain_llm, + prompt=PROMPT, + feature_embedder=pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()), + ) response = chain.run( User=rl_chain.BasedOn("Context"), action=rl_chain.ToSelectFrom(["0", "1", "2"]), @@ -254,7 +278,10 @@ def test_default_no_scorer_specified() -> None: def test_explicitly_no_scorer() -> None: llm, PROMPT = setup() chain = pick_best_chain.PickBest.from_llm( - llm=llm, prompt=PROMPT, selection_scorer=None + llm=llm, + prompt=PROMPT, + selection_scorer=None, + feature_embedder=pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()), ) response = chain.run( User=rl_chain.BasedOn("Context"), @@ -274,6 +301,7 @@ def test_auto_scorer_with_user_defined_llm() -> None: llm=llm, prompt=PROMPT, selection_scorer=rl_chain.AutoSelectionScorer(llm=scorer_llm), + feature_embedder=pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()), ) response = chain.run( User=rl_chain.BasedOn("Context"), @@ -288,7 +316,11 @@ def test_auto_scorer_with_user_defined_llm() -> None: @pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers") def test_calling_chain_w_reserved_inputs_throws() -> None: llm, PROMPT = setup() - chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT) + chain = pick_best_chain.PickBest.from_llm( + llm=llm, + prompt=PROMPT, + feature_embedder=pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()), + ) with pytest.raises(ValueError): chain.run( User=rl_chain.BasedOn("Context"),