add dependency requirements to test file

pull/10242/head
olgavrou 1 year ago
parent e276ae2616
commit 44badd0707

@ -471,7 +471,7 @@ class RLChain(Chain):
def save_progress(self) -> None:
"""
This function should be called to save the state of the Vowpal Wabbit model.
This function should be called to save the state of the learned policy model.
"""
self.policy.save()

@ -9,7 +9,7 @@ from langchain.prompts.prompt import PromptTemplate
encoded_text = "[ e n c o d e d ] "
@pytest.mark.requires("vowpal_wabbit_next")
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def setup():
_PROMPT_TEMPLATE = """This is a dummy prompt that will be ignored by the fake llm"""
PROMPT = PromptTemplate(input_variables=[], template=_PROMPT_TEMPLATE)
@ -18,7 +18,7 @@ def setup():
return llm, PROMPT
@pytest.mark.requires("vowpal_wabbit_next")
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def test_multiple_ToSelectFrom_throws():
llm, PROMPT = setup()
chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
@ -31,7 +31,7 @@ def test_multiple_ToSelectFrom_throws():
)
@pytest.mark.requires("vowpal_wabbit_next")
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def test_missing_basedOn_from_throws():
llm, PROMPT = setup()
chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
@ -40,7 +40,7 @@ def test_missing_basedOn_from_throws():
chain.run(action=rl_chain.ToSelectFrom(actions))
@pytest.mark.requires("vowpal_wabbit_next")
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def test_ToSelectFrom_not_a_list_throws():
llm, PROMPT = setup()
chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
@ -52,7 +52,7 @@ def test_ToSelectFrom_not_a_list_throws():
)
@pytest.mark.requires("vowpal_wabbit_next")
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def test_update_with_delayed_score_with_auto_validator_throws():
llm, PROMPT = setup()
# this LLM returns a number so that the auto validator will return that
@ -74,7 +74,7 @@ def test_update_with_delayed_score_with_auto_validator_throws():
chain.update_with_delayed_score(event=selection_metadata, score=100)
@pytest.mark.requires("vowpal_wabbit_next")
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def test_update_with_delayed_score_force():
llm, PROMPT = setup()
# this LLM returns a number so that the auto validator will return that
@ -98,7 +98,7 @@ def test_update_with_delayed_score_force():
assert selection_metadata.selected.score == 100.0
@pytest.mark.requires("vowpal_wabbit_next")
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def test_update_with_delayed_score():
llm, PROMPT = setup()
chain = pick_best_chain.PickBest.from_llm(
@ -116,7 +116,7 @@ def test_update_with_delayed_score():
assert selection_metadata.selected.score == 100.0
@pytest.mark.requires("vowpal_wabbit_next")
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def test_user_defined_scorer():
llm, PROMPT = setup()
@ -138,7 +138,7 @@ def test_user_defined_scorer():
assert selection_metadata.selected.score == 200.0
@pytest.mark.requires("vowpal_wabbit_next")
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def test_default_embeddings():
llm, PROMPT = setup()
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
@ -172,7 +172,7 @@ def test_default_embeddings():
assert vw_str == expected
@pytest.mark.requires("vowpal_wabbit_next")
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def test_default_embeddings_off():
llm, PROMPT = setup()
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
@ -198,7 +198,7 @@ def test_default_embeddings_off():
assert vw_str == expected
@pytest.mark.requires("vowpal_wabbit_next")
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def test_default_embeddings_mixed_w_explicit_user_embeddings():
llm, PROMPT = setup()
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
@ -233,7 +233,7 @@ def test_default_embeddings_mixed_w_explicit_user_embeddings():
assert vw_str == expected
@pytest.mark.requires("vowpal_wabbit_next")
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def test_default_no_scorer_specified():
_, PROMPT = setup()
chain_llm = FakeListChatModel(responses=[100])
@ -248,7 +248,7 @@ def test_default_no_scorer_specified():
assert selection_metadata.selected.score == 100.0
@pytest.mark.requires("vowpal_wabbit_next")
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def test_explicitly_no_scorer():
llm, PROMPT = setup()
chain = pick_best_chain.PickBest.from_llm(
@ -264,7 +264,7 @@ def test_explicitly_no_scorer():
assert selection_metadata.selected.score is None
@pytest.mark.requires("vowpal_wabbit_next")
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def test_auto_scorer_with_user_defined_llm():
llm, PROMPT = setup()
scorer_llm = FakeListChatModel(responses=[300])
@ -283,7 +283,7 @@ def test_auto_scorer_with_user_defined_llm():
assert selection_metadata.selected.score == 300.0
@pytest.mark.requires("vowpal_wabbit_next")
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def test_calling_chain_w_reserved_inputs_throws():
llm, PROMPT = setup()
chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)

Loading…
Cancel
Save