|
|
@ -90,11 +90,11 @@ def test_update_with_delayed_score_with_auto_validator_throws() -> None:
|
|
|
|
User=rl_chain.BasedOn("Context"),
|
|
|
|
User=rl_chain.BasedOn("Context"),
|
|
|
|
action=rl_chain.ToSelectFrom(actions),
|
|
|
|
action=rl_chain.ToSelectFrom(actions),
|
|
|
|
)
|
|
|
|
)
|
|
|
|
assert response["response"] == "hey"
|
|
|
|
assert response["response"] == "hey" # type: ignore
|
|
|
|
selection_metadata = response["selection_metadata"]
|
|
|
|
selection_metadata = response["selection_metadata"] # type: ignore
|
|
|
|
assert selection_metadata.selected.score == 3.0
|
|
|
|
assert selection_metadata.selected.score == 3.0 # type: ignore
|
|
|
|
with pytest.raises(RuntimeError):
|
|
|
|
with pytest.raises(RuntimeError):
|
|
|
|
chain.update_with_delayed_score(chain_response=response, score=100)
|
|
|
|
chain.update_with_delayed_score(chain_response=response, score=100) # type: ignore
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
|
|
|
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
|
|
@ -115,13 +115,13 @@ def test_update_with_delayed_score_force() -> None:
|
|
|
|
User=rl_chain.BasedOn("Context"),
|
|
|
|
User=rl_chain.BasedOn("Context"),
|
|
|
|
action=rl_chain.ToSelectFrom(actions),
|
|
|
|
action=rl_chain.ToSelectFrom(actions),
|
|
|
|
)
|
|
|
|
)
|
|
|
|
assert response["response"] == "hey"
|
|
|
|
assert response["response"] == "hey" # type: ignore
|
|
|
|
selection_metadata = response["selection_metadata"]
|
|
|
|
selection_metadata = response["selection_metadata"] # type: ignore
|
|
|
|
assert selection_metadata.selected.score == 3.0
|
|
|
|
assert selection_metadata.selected.score == 3.0 # type: ignore
|
|
|
|
chain.update_with_delayed_score(
|
|
|
|
chain.update_with_delayed_score(
|
|
|
|
chain_response=response, score=100, force_score=True
|
|
|
|
chain_response=response, score=100, force_score=True # type: ignore
|
|
|
|
)
|
|
|
|
)
|
|
|
|
assert selection_metadata.selected.score == 100.0
|
|
|
|
assert selection_metadata.selected.score == 100.0 # type: ignore
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
|
|
|
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
|
|
@ -140,11 +140,11 @@ def test_update_with_delayed_score() -> None:
|
|
|
|
User=rl_chain.BasedOn("Context"),
|
|
|
|
User=rl_chain.BasedOn("Context"),
|
|
|
|
action=rl_chain.ToSelectFrom(actions),
|
|
|
|
action=rl_chain.ToSelectFrom(actions),
|
|
|
|
)
|
|
|
|
)
|
|
|
|
assert response["response"] == "hey"
|
|
|
|
assert response["response"] == "hey" # type: ignore
|
|
|
|
selection_metadata = response["selection_metadata"]
|
|
|
|
selection_metadata = response["selection_metadata"] # type: ignore
|
|
|
|
assert selection_metadata.selected.score is None
|
|
|
|
assert selection_metadata.selected.score is None # type: ignore
|
|
|
|
chain.update_with_delayed_score(chain_response=response, score=100)
|
|
|
|
chain.update_with_delayed_score(chain_response=response, score=100) # type: ignore
|
|
|
|
assert selection_metadata.selected.score == 100.0
|
|
|
|
assert selection_metadata.selected.score == 100.0 # type: ignore
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
|
|
|
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
|
|
@ -174,9 +174,9 @@ def test_user_defined_scorer() -> None:
|
|
|
|
User=rl_chain.BasedOn("Context"),
|
|
|
|
User=rl_chain.BasedOn("Context"),
|
|
|
|
action=rl_chain.ToSelectFrom(actions),
|
|
|
|
action=rl_chain.ToSelectFrom(actions),
|
|
|
|
)
|
|
|
|
)
|
|
|
|
assert response["response"] == "hey"
|
|
|
|
assert response["response"] == "hey" # type: ignore
|
|
|
|
selection_metadata = response["selection_metadata"]
|
|
|
|
selection_metadata = response["selection_metadata"] # type: ignore
|
|
|
|
assert selection_metadata.selected.score == 200.0
|
|
|
|
assert selection_metadata.selected.score == 200.0 # type: ignore
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
|
|
|
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
|
|
@ -208,8 +208,8 @@ def test_everything_embedded() -> None:
|
|
|
|
User=rl_chain.EmbedAndKeep(rl_chain.BasedOn(ctx_str_1)),
|
|
|
|
User=rl_chain.EmbedAndKeep(rl_chain.BasedOn(ctx_str_1)),
|
|
|
|
action=rl_chain.EmbedAndKeep(rl_chain.ToSelectFrom(actions)),
|
|
|
|
action=rl_chain.EmbedAndKeep(rl_chain.ToSelectFrom(actions)),
|
|
|
|
)
|
|
|
|
)
|
|
|
|
selection_metadata = response["selection_metadata"]
|
|
|
|
selection_metadata = response["selection_metadata"] # type: ignore
|
|
|
|
vw_str = feature_embedder.format(selection_metadata)
|
|
|
|
vw_str = feature_embedder.format(selection_metadata) # type: ignore
|
|
|
|
assert vw_str == expected
|
|
|
|
assert vw_str == expected
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -236,8 +236,8 @@ def test_default_auto_embedder_is_off() -> None:
|
|
|
|
User=pick_best_chain.base.BasedOn(ctx_str_1),
|
|
|
|
User=pick_best_chain.base.BasedOn(ctx_str_1),
|
|
|
|
action=pick_best_chain.base.ToSelectFrom(actions),
|
|
|
|
action=pick_best_chain.base.ToSelectFrom(actions),
|
|
|
|
)
|
|
|
|
)
|
|
|
|
selection_metadata = response["selection_metadata"]
|
|
|
|
selection_metadata = response["selection_metadata"] # type: ignore
|
|
|
|
vw_str = feature_embedder.format(selection_metadata)
|
|
|
|
vw_str = feature_embedder.format(selection_metadata) # type: ignore
|
|
|
|
assert vw_str == expected
|
|
|
|
assert vw_str == expected
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -264,8 +264,8 @@ def test_default_w_embeddings_off() -> None:
|
|
|
|
User=rl_chain.BasedOn(ctx_str_1),
|
|
|
|
User=rl_chain.BasedOn(ctx_str_1),
|
|
|
|
action=rl_chain.ToSelectFrom(actions),
|
|
|
|
action=rl_chain.ToSelectFrom(actions),
|
|
|
|
)
|
|
|
|
)
|
|
|
|
selection_metadata = response["selection_metadata"]
|
|
|
|
selection_metadata = response["selection_metadata"] # type: ignore
|
|
|
|
vw_str = feature_embedder.format(selection_metadata)
|
|
|
|
vw_str = feature_embedder.format(selection_metadata) # type: ignore
|
|
|
|
assert vw_str == expected
|
|
|
|
assert vw_str == expected
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -292,8 +292,8 @@ def test_default_w_embeddings_on() -> None:
|
|
|
|
User=rl_chain.BasedOn(ctx_str_1),
|
|
|
|
User=rl_chain.BasedOn(ctx_str_1),
|
|
|
|
action=rl_chain.ToSelectFrom(actions),
|
|
|
|
action=rl_chain.ToSelectFrom(actions),
|
|
|
|
)
|
|
|
|
)
|
|
|
|
selection_metadata = response["selection_metadata"]
|
|
|
|
selection_metadata = response["selection_metadata"] # type: ignore
|
|
|
|
vw_str = feature_embedder.format(selection_metadata)
|
|
|
|
vw_str = feature_embedder.format(selection_metadata) # type: ignore
|
|
|
|
assert vw_str == expected
|
|
|
|
assert vw_str == expected
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -324,8 +324,8 @@ def test_default_embeddings_mixed_w_explicit_user_embeddings() -> None:
|
|
|
|
User2=rl_chain.BasedOn(ctx_str_2),
|
|
|
|
User2=rl_chain.BasedOn(ctx_str_2),
|
|
|
|
action=rl_chain.ToSelectFrom(actions),
|
|
|
|
action=rl_chain.ToSelectFrom(actions),
|
|
|
|
)
|
|
|
|
)
|
|
|
|
selection_metadata = response["selection_metadata"]
|
|
|
|
selection_metadata = response["selection_metadata"] # type: ignore
|
|
|
|
vw_str = feature_embedder.format(selection_metadata)
|
|
|
|
vw_str = feature_embedder.format(selection_metadata) # type: ignore
|
|
|
|
assert vw_str == expected
|
|
|
|
assert vw_str == expected
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -345,9 +345,9 @@ def test_default_no_scorer_specified() -> None:
|
|
|
|
action=rl_chain.ToSelectFrom(["0", "1", "2"]),
|
|
|
|
action=rl_chain.ToSelectFrom(["0", "1", "2"]),
|
|
|
|
)
|
|
|
|
)
|
|
|
|
# chain llm used for both basic prompt and for scoring
|
|
|
|
# chain llm used for both basic prompt and for scoring
|
|
|
|
assert response["response"] == "hey"
|
|
|
|
assert response["response"] == "hey" # type: ignore
|
|
|
|
selection_metadata = response["selection_metadata"]
|
|
|
|
selection_metadata = response["selection_metadata"] # type: ignore
|
|
|
|
assert selection_metadata.selected.score == 100.0
|
|
|
|
assert selection_metadata.selected.score == 100.0 # type: ignore
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
|
|
|
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
|
|
@ -366,9 +366,9 @@ def test_explicitly_no_scorer() -> None:
|
|
|
|
action=rl_chain.ToSelectFrom(["0", "1", "2"]),
|
|
|
|
action=rl_chain.ToSelectFrom(["0", "1", "2"]),
|
|
|
|
)
|
|
|
|
)
|
|
|
|
# chain llm used for both basic prompt and for scoring
|
|
|
|
# chain llm used for both basic prompt and for scoring
|
|
|
|
assert response["response"] == "hey"
|
|
|
|
assert response["response"] == "hey" # type: ignore
|
|
|
|
selection_metadata = response["selection_metadata"]
|
|
|
|
selection_metadata = response["selection_metadata"] # type: ignore
|
|
|
|
assert selection_metadata.selected.score is None
|
|
|
|
assert selection_metadata.selected.score is None # type: ignore
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
|
|
|
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
|
|
@ -388,9 +388,9 @@ def test_auto_scorer_with_user_defined_llm() -> None:
|
|
|
|
action=rl_chain.ToSelectFrom(["0", "1", "2"]),
|
|
|
|
action=rl_chain.ToSelectFrom(["0", "1", "2"]),
|
|
|
|
)
|
|
|
|
)
|
|
|
|
# chain llm used for both basic prompt and for scoring
|
|
|
|
# chain llm used for both basic prompt and for scoring
|
|
|
|
assert response["response"] == "hey"
|
|
|
|
assert response["response"] == "hey" # type: ignore
|
|
|
|
selection_metadata = response["selection_metadata"]
|
|
|
|
selection_metadata = response["selection_metadata"] # type: ignore
|
|
|
|
assert selection_metadata.selected.score == 300.0
|
|
|
|
assert selection_metadata.selected.score == 300.0 # type: ignore
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
|
|
|
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
|
|
@ -434,24 +434,24 @@ def test_activate_and_deactivate_scorer() -> None:
|
|
|
|
action=pick_best_chain.base.ToSelectFrom(["0", "1", "2"]),
|
|
|
|
action=pick_best_chain.base.ToSelectFrom(["0", "1", "2"]),
|
|
|
|
)
|
|
|
|
)
|
|
|
|
# chain llm used for both basic prompt and for scoring
|
|
|
|
# chain llm used for both basic prompt and for scoring
|
|
|
|
assert response["response"] == "hey1"
|
|
|
|
assert response["response"] == "hey1" # type: ignore
|
|
|
|
selection_metadata = response["selection_metadata"]
|
|
|
|
selection_metadata = response["selection_metadata"] # type: ignore
|
|
|
|
assert selection_metadata.selected.score == 300.0
|
|
|
|
assert selection_metadata.selected.score == 300.0 # type: ignore
|
|
|
|
|
|
|
|
|
|
|
|
chain.deactivate_selection_scorer()
|
|
|
|
chain.deactivate_selection_scorer()
|
|
|
|
response = chain.run(
|
|
|
|
response = chain.run(
|
|
|
|
User=pick_best_chain.base.BasedOn("Context"),
|
|
|
|
User=pick_best_chain.base.BasedOn("Context"),
|
|
|
|
action=pick_best_chain.base.ToSelectFrom(["0", "1", "2"]),
|
|
|
|
action=pick_best_chain.base.ToSelectFrom(["0", "1", "2"]),
|
|
|
|
)
|
|
|
|
)
|
|
|
|
assert response["response"] == "hey2"
|
|
|
|
assert response["response"] == "hey2" # type: ignore
|
|
|
|
selection_metadata = response["selection_metadata"]
|
|
|
|
selection_metadata = response["selection_metadata"] # type: ignore
|
|
|
|
assert selection_metadata.selected.score is None
|
|
|
|
assert selection_metadata.selected.score is None # type: ignore
|
|
|
|
|
|
|
|
|
|
|
|
chain.activate_selection_scorer()
|
|
|
|
chain.activate_selection_scorer()
|
|
|
|
response = chain.run(
|
|
|
|
response = chain.run(
|
|
|
|
User=pick_best_chain.base.BasedOn("Context"),
|
|
|
|
User=pick_best_chain.base.BasedOn("Context"),
|
|
|
|
action=pick_best_chain.base.ToSelectFrom(["0", "1", "2"]),
|
|
|
|
action=pick_best_chain.base.ToSelectFrom(["0", "1", "2"]),
|
|
|
|
)
|
|
|
|
)
|
|
|
|
assert response["response"] == "hey3"
|
|
|
|
assert response["response"] == "hey3" # type: ignore
|
|
|
|
selection_metadata = response["selection_metadata"]
|
|
|
|
selection_metadata = response["selection_metadata"] # type: ignore
|
|
|
|
assert selection_metadata.selected.score == 400.0
|
|
|
|
assert selection_metadata.selected.score == 400.0 # type: ignore
|
|
|
|