|
|
@ -1,12 +1,12 @@
|
|
|
|
from typing import Any, Dict
|
|
|
|
from typing import Any, Dict
|
|
|
|
|
|
|
|
|
|
|
|
import pytest
|
|
|
|
import pytest
|
|
|
|
|
|
|
|
from langchain.chat_models import FakeListChatModel
|
|
|
|
|
|
|
|
from langchain.prompts.prompt import PromptTemplate
|
|
|
|
from test_utils import MockEncoder, MockEncoderReturnsList
|
|
|
|
from test_utils import MockEncoder, MockEncoderReturnsList
|
|
|
|
|
|
|
|
|
|
|
|
import langchain_experimental.rl_chain.base as rl_chain
|
|
|
|
import langchain_experimental.rl_chain.base as rl_chain
|
|
|
|
import langchain_experimental.rl_chain.pick_best_chain as pick_best_chain
|
|
|
|
import langchain_experimental.rl_chain.pick_best_chain as pick_best_chain
|
|
|
|
from langchain.chat_models import FakeListChatModel
|
|
|
|
|
|
|
|
from langchain.prompts.prompt import PromptTemplate
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
encoded_keyword = "[encoded]"
|
|
|
|
encoded_keyword = "[encoded]"
|
|
|
|
|
|
|
|
|
|
|
@ -94,7 +94,9 @@ def test_update_with_delayed_score_with_auto_validator_throws() -> None:
|
|
|
|
selection_metadata = response["selection_metadata"] # type: ignore
|
|
|
|
selection_metadata = response["selection_metadata"] # type: ignore
|
|
|
|
assert selection_metadata.selected.score == 3.0 # type: ignore
|
|
|
|
assert selection_metadata.selected.score == 3.0 # type: ignore
|
|
|
|
with pytest.raises(RuntimeError):
|
|
|
|
with pytest.raises(RuntimeError):
|
|
|
|
chain.update_with_delayed_score(chain_response=response, score=100) # type: ignore
|
|
|
|
chain.update_with_delayed_score(
|
|
|
|
|
|
|
|
chain_response=response, score=100 # type: ignore
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
|
|
|
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
|
|
|