mirror of
https://github.com/hwchase17/langchain
synced 2024-11-08 07:10:35 +00:00
add dependency requirements to test file
This commit is contained in:
parent
e276ae2616
commit
44badd0707
@ -471,7 +471,7 @@ class RLChain(Chain):
|
||||
|
||||
def save_progress(self) -> None:
|
||||
"""
|
||||
This function should be called to save the state of the Vowpal Wabbit model.
|
||||
This function should be called to save the state of the learned policy model.
|
||||
"""
|
||||
self.policy.save()
|
||||
|
||||
|
@ -9,7 +9,7 @@ from langchain.prompts.prompt import PromptTemplate
|
||||
encoded_text = "[ e n c o d e d ] "
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
||||
def setup():
|
||||
_PROMPT_TEMPLATE = """This is a dummy prompt that will be ignored by the fake llm"""
|
||||
PROMPT = PromptTemplate(input_variables=[], template=_PROMPT_TEMPLATE)
|
||||
@ -18,7 +18,7 @@ def setup():
|
||||
return llm, PROMPT
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
||||
def test_multiple_ToSelectFrom_throws():
|
||||
llm, PROMPT = setup()
|
||||
chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
|
||||
@ -31,7 +31,7 @@ def test_multiple_ToSelectFrom_throws():
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
||||
def test_missing_basedOn_from_throws():
|
||||
llm, PROMPT = setup()
|
||||
chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
|
||||
@ -40,7 +40,7 @@ def test_missing_basedOn_from_throws():
|
||||
chain.run(action=rl_chain.ToSelectFrom(actions))
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
||||
def test_ToSelectFrom_not_a_list_throws():
|
||||
llm, PROMPT = setup()
|
||||
chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
|
||||
@ -52,7 +52,7 @@ def test_ToSelectFrom_not_a_list_throws():
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
||||
def test_update_with_delayed_score_with_auto_validator_throws():
|
||||
llm, PROMPT = setup()
|
||||
# this LLM returns a number so that the auto validator will return that
|
||||
@ -74,7 +74,7 @@ def test_update_with_delayed_score_with_auto_validator_throws():
|
||||
chain.update_with_delayed_score(event=selection_metadata, score=100)
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
||||
def test_update_with_delayed_score_force():
|
||||
llm, PROMPT = setup()
|
||||
# this LLM returns a number so that the auto validator will return that
|
||||
@ -98,7 +98,7 @@ def test_update_with_delayed_score_force():
|
||||
assert selection_metadata.selected.score == 100.0
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
||||
def test_update_with_delayed_score():
|
||||
llm, PROMPT = setup()
|
||||
chain = pick_best_chain.PickBest.from_llm(
|
||||
@ -116,7 +116,7 @@ def test_update_with_delayed_score():
|
||||
assert selection_metadata.selected.score == 100.0
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
||||
def test_user_defined_scorer():
|
||||
llm, PROMPT = setup()
|
||||
|
||||
@ -138,7 +138,7 @@ def test_user_defined_scorer():
|
||||
assert selection_metadata.selected.score == 200.0
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
||||
def test_default_embeddings():
|
||||
llm, PROMPT = setup()
|
||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
||||
@ -172,7 +172,7 @@ def test_default_embeddings():
|
||||
assert vw_str == expected
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
||||
def test_default_embeddings_off():
|
||||
llm, PROMPT = setup()
|
||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
||||
@ -198,7 +198,7 @@ def test_default_embeddings_off():
|
||||
assert vw_str == expected
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
||||
def test_default_embeddings_mixed_w_explicit_user_embeddings():
|
||||
llm, PROMPT = setup()
|
||||
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
|
||||
@ -233,7 +233,7 @@ def test_default_embeddings_mixed_w_explicit_user_embeddings():
|
||||
assert vw_str == expected
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
||||
def test_default_no_scorer_specified():
|
||||
_, PROMPT = setup()
|
||||
chain_llm = FakeListChatModel(responses=[100])
|
||||
@ -248,7 +248,7 @@ def test_default_no_scorer_specified():
|
||||
assert selection_metadata.selected.score == 100.0
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
||||
def test_explicitly_no_scorer():
|
||||
llm, PROMPT = setup()
|
||||
chain = pick_best_chain.PickBest.from_llm(
|
||||
@ -264,7 +264,7 @@ def test_explicitly_no_scorer():
|
||||
assert selection_metadata.selected.score is None
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
||||
def test_auto_scorer_with_user_defined_llm():
|
||||
llm, PROMPT = setup()
|
||||
scorer_llm = FakeListChatModel(responses=[300])
|
||||
@ -283,7 +283,7 @@ def test_auto_scorer_with_user_defined_llm():
|
||||
assert selection_metadata.selected.score == 300.0
|
||||
|
||||
|
||||
@pytest.mark.requires("vowpal_wabbit_next")
|
||||
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
||||
def test_calling_chain_w_reserved_inputs_throws():
|
||||
llm, PROMPT = setup()
|
||||
chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
|
||||
|
Loading…
Reference in New Issue
Block a user