From 44badd07077879d092fdd53ae29fc8fc8f418756 Mon Sep 17 00:00:00 2001 From: olgavrou Date: Fri, 18 Aug 2023 07:19:56 -0400 Subject: [PATCH] add dependency requirements to test file --- .../langchain/chains/rl_chain/base.py | 2 +- .../rl_chain/test_pick_best_chain_call.py | 30 +++++++++---------- 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/libs/langchain/langchain/chains/rl_chain/base.py b/libs/langchain/langchain/chains/rl_chain/base.py index 437053f2dc..28baf898d2 100644 --- a/libs/langchain/langchain/chains/rl_chain/base.py +++ b/libs/langchain/langchain/chains/rl_chain/base.py @@ -471,7 +471,7 @@ class RLChain(Chain): def save_progress(self) -> None: """ - This function should be called to save the state of the Vowpal Wabbit model. + This function should be called to save the state of the learned policy model. """ self.policy.save() diff --git a/libs/langchain/tests/unit_tests/chains/rl_chain/test_pick_best_chain_call.py b/libs/langchain/tests/unit_tests/chains/rl_chain/test_pick_best_chain_call.py index e42818ea8c..3fad1667d9 100644 --- a/libs/langchain/tests/unit_tests/chains/rl_chain/test_pick_best_chain_call.py +++ b/libs/langchain/tests/unit_tests/chains/rl_chain/test_pick_best_chain_call.py @@ -9,7 +9,7 @@ from langchain.prompts.prompt import PromptTemplate encoded_text = "[ e n c o d e d ] " -@pytest.mark.requires("vowpal_wabbit_next") +@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers") def setup(): _PROMPT_TEMPLATE = """This is a dummy prompt that will be ignored by the fake llm""" PROMPT = PromptTemplate(input_variables=[], template=_PROMPT_TEMPLATE) @@ -18,7 +18,7 @@ def setup(): return llm, PROMPT -@pytest.mark.requires("vowpal_wabbit_next") +@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers") def test_multiple_ToSelectFrom_throws(): llm, PROMPT = setup() chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT) @@ -31,7 +31,7 @@ def test_multiple_ToSelectFrom_throws(): ) -@pytest.mark.requires("vowpal_wabbit_next") +@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers") def test_missing_basedOn_from_throws(): llm, PROMPT = setup() chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT) @@ -40,7 +40,7 @@ def test_missing_basedOn_from_throws(): chain.run(action=rl_chain.ToSelectFrom(actions)) -@pytest.mark.requires("vowpal_wabbit_next") +@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers") def test_ToSelectFrom_not_a_list_throws(): llm, PROMPT = setup() chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT) @@ -52,7 +52,7 @@ def test_ToSelectFrom_not_a_list_throws(): ) -@pytest.mark.requires("vowpal_wabbit_next") +@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers") def test_update_with_delayed_score_with_auto_validator_throws(): llm, PROMPT = setup() # this LLM returns a number so that the auto validator will return that @@ -74,7 +74,7 @@ def test_update_with_delayed_score_with_auto_validator_throws(): chain.update_with_delayed_score(event=selection_metadata, score=100) -@pytest.mark.requires("vowpal_wabbit_next") +@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers") def test_update_with_delayed_score_force(): llm, PROMPT = setup() # this LLM returns a number so that the auto validator will return that @@ -98,7 +98,7 @@ def test_update_with_delayed_score_force(): assert selection_metadata.selected.score == 100.0 -@pytest.mark.requires("vowpal_wabbit_next") +@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers") def test_update_with_delayed_score(): llm, PROMPT = setup() chain = pick_best_chain.PickBest.from_llm( @@ -116,7 +116,7 @@ def test_update_with_delayed_score(): assert selection_metadata.selected.score == 100.0 -@pytest.mark.requires("vowpal_wabbit_next") +@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers") def test_user_defined_scorer(): llm, PROMPT = setup() @@ -138,7 +138,7 @@ def test_user_defined_scorer(): assert selection_metadata.selected.score == 200.0 -@pytest.mark.requires("vowpal_wabbit_next") +@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers") def test_default_embeddings(): llm, PROMPT = setup() feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()) @@ -172,7 +172,7 @@ def test_default_embeddings(): assert vw_str == expected -@pytest.mark.requires("vowpal_wabbit_next") +@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers") def test_default_embeddings_off(): llm, PROMPT = setup() feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()) @@ -198,7 +198,7 @@ def test_default_embeddings_off(): assert vw_str == expected -@pytest.mark.requires("vowpal_wabbit_next") +@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers") def test_default_embeddings_mixed_w_explicit_user_embeddings(): llm, PROMPT = setup() feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()) @@ -233,7 +233,7 @@ def test_default_embeddings_mixed_w_explicit_user_embeddings(): assert vw_str == expected -@pytest.mark.requires("vowpal_wabbit_next") +@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers") def test_default_no_scorer_specified(): _, PROMPT = setup() chain_llm = FakeListChatModel(responses=[100]) @@ -248,7 +248,7 @@ def test_default_no_scorer_specified(): assert selection_metadata.selected.score == 100.0 -@pytest.mark.requires("vowpal_wabbit_next") +@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers") def test_explicitly_no_scorer(): llm, PROMPT = setup() chain = pick_best_chain.PickBest.from_llm( @@ -264,7 +264,7 @@ def test_explicitly_no_scorer(): assert selection_metadata.selected.score is None -@pytest.mark.requires("vowpal_wabbit_next") +@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers") def test_auto_scorer_with_user_defined_llm(): llm, PROMPT = setup() scorer_llm = FakeListChatModel(responses=[300]) @@ -283,7 +283,7 @@ def test_auto_scorer_with_user_defined_llm(): assert selection_metadata.selected.score == 300.0 -@pytest.mark.requires("vowpal_wabbit_next") +@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers") def test_calling_chain_w_reserved_inputs_throws(): llm, PROMPT = setup() chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)