add dependency requirements to test file

olgavrou 2023-08-18 07:19:56 -04:00
parent e276ae2616
commit 44badd0707
2 changed files with 16 additions and 16 deletions
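The change is mechanical: every @pytest.mark.requires("vowpal_wabbit_next") marker in the test module gains "sentence_transformers", so the tests only run when both optional packages are installed. As a rough illustration only (not this project's actual conftest.py), a requires marker is typically honored by a collection hook that skips tests whose listed packages cannot be imported:

# Hypothetical conftest.py sketch: how a "requires" marker can gate tests
# on optional packages; the real hook in the repository may differ.
import importlib.util

import pytest


def pytest_collection_modifyitems(config, items):
    for item in items:
        marker = item.get_closest_marker("requires")
        if marker is None:
            continue
        # Marker args are import names, e.g. ("vowpal_wabbit_next", "sentence_transformers").
        missing = [pkg for pkg in marker.args if importlib.util.find_spec(pkg) is None]
        if missing:
            item.add_marker(pytest.mark.skip(reason=f"missing packages: {', '.join(missing)}"))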


@@ -471,7 +471,7 @@ class RLChain(Chain):
def save_progress(self) -> None:
"""
-This function should be called to save the state of the Vowpal Wabbit model.
+This function should be called to save the state of the learned policy model.
"""
self.policy.save()


@@ -9,7 +9,7 @@ from langchain.prompts.prompt import PromptTemplate
encoded_text = "[ e n c o d e d ] "
@pytest.mark.requires("vowpal_wabbit_next")
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def setup():
_PROMPT_TEMPLATE = """This is a dummy prompt that will be ignored by the fake llm"""
PROMPT = PromptTemplate(input_variables=[], template=_PROMPT_TEMPLATE)
@@ -18,7 +18,7 @@ def setup():
return llm, PROMPT
@pytest.mark.requires("vowpal_wabbit_next")
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def test_multiple_ToSelectFrom_throws():
llm, PROMPT = setup()
chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
@@ -31,7 +31,7 @@ def test_multiple_ToSelectFrom_throws():
)
@pytest.mark.requires("vowpal_wabbit_next")
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def test_missing_basedOn_from_throws():
llm, PROMPT = setup()
chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
@@ -40,7 +40,7 @@ def test_missing_basedOn_from_throws():
chain.run(action=rl_chain.ToSelectFrom(actions))
@pytest.mark.requires("vowpal_wabbit_next")
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def test_ToSelectFrom_not_a_list_throws():
llm, PROMPT = setup()
chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
@@ -52,7 +52,7 @@ def test_ToSelectFrom_not_a_list_throws():
)
@pytest.mark.requires("vowpal_wabbit_next")
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def test_update_with_delayed_score_with_auto_validator_throws():
llm, PROMPT = setup()
# this LLM returns a number so that the auto validator will return that
@@ -74,7 +74,7 @@ def test_update_with_delayed_score_with_auto_validator_throws():
chain.update_with_delayed_score(event=selection_metadata, score=100)
@pytest.mark.requires("vowpal_wabbit_next")
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def test_update_with_delayed_score_force():
llm, PROMPT = setup()
# this LLM returns a number so that the auto validator will return that
@@ -98,7 +98,7 @@ def test_update_with_delayed_score_force():
assert selection_metadata.selected.score == 100.0
@pytest.mark.requires("vowpal_wabbit_next")
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def test_update_with_delayed_score():
llm, PROMPT = setup()
chain = pick_best_chain.PickBest.from_llm(
@@ -116,7 +116,7 @@ def test_update_with_delayed_score():
assert selection_metadata.selected.score == 100.0
@pytest.mark.requires("vowpal_wabbit_next")
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def test_user_defined_scorer():
llm, PROMPT = setup()
@@ -138,7 +138,7 @@ def test_user_defined_scorer():
assert selection_metadata.selected.score == 200.0
@pytest.mark.requires("vowpal_wabbit_next")
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def test_default_embeddings():
llm, PROMPT = setup()
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
@@ -172,7 +172,7 @@ def test_default_embeddings():
assert vw_str == expected
@pytest.mark.requires("vowpal_wabbit_next")
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def test_default_embeddings_off():
llm, PROMPT = setup()
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
@@ -198,7 +198,7 @@ def test_default_embeddings_off():
assert vw_str == expected
@pytest.mark.requires("vowpal_wabbit_next")
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def test_default_embeddings_mixed_w_explicit_user_embeddings():
llm, PROMPT = setup()
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
@@ -233,7 +233,7 @@ def test_default_embeddings_mixed_w_explicit_user_embeddings():
assert vw_str == expected
@pytest.mark.requires("vowpal_wabbit_next")
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def test_default_no_scorer_specified():
_, PROMPT = setup()
chain_llm = FakeListChatModel(responses=[100])
@@ -248,7 +248,7 @@ def test_default_no_scorer_specified():
assert selection_metadata.selected.score == 100.0
@pytest.mark.requires("vowpal_wabbit_next")
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def test_explicitly_no_scorer():
llm, PROMPT = setup()
chain = pick_best_chain.PickBest.from_llm(
@@ -264,7 +264,7 @@ def test_explicitly_no_scorer():
assert selection_metadata.selected.score is None
@pytest.mark.requires("vowpal_wabbit_next")
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def test_auto_scorer_with_user_defined_llm():
llm, PROMPT = setup()
scorer_llm = FakeListChatModel(responses=[300])
@@ -283,7 +283,7 @@ def test_auto_scorer_with_user_defined_llm():
assert selection_metadata.selected.score == 300.0
@pytest.mark.requires("vowpal_wabbit_next")
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def test_calling_chain_w_reserved_inputs_throws():
llm, PROMPT = setup()
chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)