Merge pull request #5 from VowpalWabbit/nosockettests

Update unit tests to use a mock encoder
This commit is contained in:
olgavrou 2023-08-29 07:28:03 -04:00 committed by GitHub
commit 42bdb003ee
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -23,7 +23,11 @@ def setup() -> tuple:
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def test_multiple_ToSelectFrom_throws() -> None:
llm, PROMPT = setup()
chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
chain = pick_best_chain.PickBest.from_llm(
llm=llm,
prompt=PROMPT,
feature_embedder=pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()),
)
actions = ["0", "1", "2"]
with pytest.raises(ValueError):
chain.run(
@ -36,7 +40,11 @@ def test_multiple_ToSelectFrom_throws() -> None:
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def test_missing_basedOn_from_throws() -> None:
llm, PROMPT = setup()
chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
chain = pick_best_chain.PickBest.from_llm(
llm=llm,
prompt=PROMPT,
feature_embedder=pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()),
)
actions = ["0", "1", "2"]
with pytest.raises(ValueError):
chain.run(action=rl_chain.ToSelectFrom(actions))
@ -45,7 +53,11 @@ def test_missing_basedOn_from_throws() -> None:
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def test_ToSelectFrom_not_a_list_throws() -> None:
llm, PROMPT = setup()
chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
chain = pick_best_chain.PickBest.from_llm(
llm=llm,
prompt=PROMPT,
feature_embedder=pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()),
)
actions = {"actions": ["0", "1", "2"]}
with pytest.raises(ValueError):
chain.run(
@ -63,6 +75,7 @@ def test_update_with_delayed_score_with_auto_validator_throws() -> None:
llm=llm,
prompt=PROMPT,
selection_scorer=rl_chain.AutoSelectionScorer(llm=auto_val_llm),
feature_embedder=pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()),
)
actions = ["0", "1", "2"]
response = chain.run(
@ -85,6 +98,7 @@ def test_update_with_delayed_score_force() -> None:
llm=llm,
prompt=PROMPT,
selection_scorer=rl_chain.AutoSelectionScorer(llm=auto_val_llm),
feature_embedder=pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()),
)
actions = ["0", "1", "2"]
response = chain.run(
@ -104,7 +118,10 @@ def test_update_with_delayed_score_force() -> None:
def test_update_with_delayed_score() -> None:
llm, PROMPT = setup()
chain = pick_best_chain.PickBest.from_llm(
llm=llm, prompt=PROMPT, selection_scorer=None
llm=llm,
prompt=PROMPT,
selection_scorer=None,
feature_embedder=pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()),
)
actions = ["0", "1", "2"]
response = chain.run(
@ -128,7 +145,10 @@ def test_user_defined_scorer() -> None:
return score
chain = pick_best_chain.PickBest.from_llm(
llm=llm, prompt=PROMPT, selection_scorer=CustomSelectionScorer()
llm=llm,
prompt=PROMPT,
selection_scorer=CustomSelectionScorer(),
feature_embedder=pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()),
)
actions = ["0", "1", "2"]
response = chain.run(
@ -239,7 +259,11 @@ def test_default_embeddings_mixed_w_explicit_user_embeddings() -> None:
def test_default_no_scorer_specified() -> None:
_, PROMPT = setup()
chain_llm = FakeListChatModel(responses=[100])
chain = pick_best_chain.PickBest.from_llm(llm=chain_llm, prompt=PROMPT)
chain = pick_best_chain.PickBest.from_llm(
llm=chain_llm,
prompt=PROMPT,
feature_embedder=pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()),
)
response = chain.run(
User=rl_chain.BasedOn("Context"),
action=rl_chain.ToSelectFrom(["0", "1", "2"]),
@ -254,7 +278,10 @@ def test_default_no_scorer_specified() -> None:
def test_explicitly_no_scorer() -> None:
llm, PROMPT = setup()
chain = pick_best_chain.PickBest.from_llm(
llm=llm, prompt=PROMPT, selection_scorer=None
llm=llm,
prompt=PROMPT,
selection_scorer=None,
feature_embedder=pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()),
)
response = chain.run(
User=rl_chain.BasedOn("Context"),
@ -274,6 +301,7 @@ def test_auto_scorer_with_user_defined_llm() -> None:
llm=llm,
prompt=PROMPT,
selection_scorer=rl_chain.AutoSelectionScorer(llm=scorer_llm),
feature_embedder=pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()),
)
response = chain.run(
User=rl_chain.BasedOn("Context"),
@ -288,7 +316,11 @@ def test_auto_scorer_with_user_defined_llm() -> None:
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
def test_calling_chain_w_reserved_inputs_throws() -> None:
llm, PROMPT = setup()
chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
chain = pick_best_chain.PickBest.from_llm(
llm=llm,
prompt=PROMPT,
feature_embedder=pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder()),
)
with pytest.raises(ValueError):
chain.run(
User=rl_chain.BasedOn("Context"),