@ -3,8 +3,8 @@ from typing import Any, Dict
import pytest
from test_utils import MockEncoder , MockEncoderReturnsList
import langchain . chains . rl_chain . base as rl_chain
import langchain . chains . rl_chain . pick_best_chain as pick_best_chain
import langchain _experimental . rl_chain . base as rl_chain
import langchain _experimental . rl_chain . pick_best_chain as pick_best_chain
from langchain . chat_models import FakeListChatModel
from langchain . prompts . prompt import PromptTemplate
@ -332,7 +332,7 @@ def test_default_embeddings_mixed_w_explicit_user_embeddings() -> None:
@pytest.mark.requires ( " vowpal_wabbit_next " , " sentence_transformers " )
def test_default_no_scorer_specified ( ) - > None :
_ , PROMPT = setup ( )
chain_llm = FakeListChatModel ( responses = [ 100 ] )
chain_llm = FakeListChatModel ( responses = [ " hey " , " 100 " ] )
chain = pick_best_chain . PickBest . from_llm (
llm = chain_llm ,
prompt = PROMPT ,
@ -345,7 +345,7 @@ def test_default_no_scorer_specified() -> None:
action = rl_chain . ToSelectFrom ( [ " 0 " , " 1 " , " 2 " ] ) ,
)
# chain llm used for both basic prompt and for scoring
assert response [ " response " ] == " 100 "
assert response [ " response " ] == " hey "
selection_metadata = response [ " selection_metadata " ]
assert selection_metadata . selected . score == 100.0
@ -374,7 +374,7 @@ def test_explicitly_no_scorer() -> None:
@pytest.mark.requires ( " vowpal_wabbit_next " , " sentence_transformers " )
def test_auto_scorer_with_user_defined_llm ( ) - > None :
llm , PROMPT = setup ( )
scorer_llm = FakeListChatModel ( responses = [ 300 ] )
scorer_llm = FakeListChatModel ( responses = [ " 300 " ] )
chain = pick_best_chain . PickBest . from_llm (
llm = llm ,
prompt = PROMPT ,
@ -418,8 +418,9 @@ def test_calling_chain_w_reserved_inputs_throws() -> None:
@pytest.mark.requires ( " vowpal_wabbit_next " , " sentence_transformers " )
def test_activate_and_deactivate_scorer ( ) - > None :
llm , PROMPT = setup ( )
scorer_llm = FakeListChatModel ( responses = [ 300 ] )
_ , PROMPT = setup ( )
llm = FakeListChatModel ( responses = [ " hey1 " , " hey2 " , " hey3 " ] )
scorer_llm = FakeListChatModel ( responses = [ " 300 " , " 400 " ] )
chain = pick_best_chain . PickBest . from_llm (
llm = llm ,
prompt = PROMPT ,
@ -433,7 +434,7 @@ def test_activate_and_deactivate_scorer() -> None:
action = pick_best_chain . base . ToSelectFrom ( [ " 0 " , " 1 " , " 2 " ] ) ,
)
# chain llm used for both basic prompt and for scoring
assert response [ " response " ] == " hey "
assert response [ " response " ] == " hey 1 "
selection_metadata = response [ " selection_metadata " ]
assert selection_metadata . selected . score == 300.0
@ -442,7 +443,7 @@ def test_activate_and_deactivate_scorer() -> None:
User = pick_best_chain . base . BasedOn ( " Context " ) ,
action = pick_best_chain . base . ToSelectFrom ( [ " 0 " , " 1 " , " 2 " ] ) ,
)
assert response [ " response " ] == " hey "
assert response [ " response " ] == " hey 2 "
selection_metadata = response [ " selection_metadata " ]
assert selection_metadata . selected . score is None
@ -451,6 +452,6 @@ def test_activate_and_deactivate_scorer() -> None:
User = pick_best_chain . base . BasedOn ( " Context " ) ,
action = pick_best_chain . base . ToSelectFrom ( [ " 0 " , " 1 " , " 2 " ] ) ,
)
assert response [ " response " ] == " hey "
assert response [ " response " ] == " hey 3 "
selection_metadata = response [ " selection_metadata " ]
assert selection_metadata . selected . score == 3 00.0
assert selection_metadata . selected . score == 4 00.0