fix imports

pull/10242/head
olgavrou 1 year ago
parent c37fd29fd8
commit b422dc035f

@ -1,5 +1,5 @@
from .pick_best_chain import PickBest
from .rl_chain_base import (
from langchain.chains.rl_chain.pick_best_chain import PickBest
from langchain.chains.rl_chain.rl_chain_base import (
Embed,
BasedOn,
ToSelectFrom,

@ -1,6 +1,6 @@
from __future__ import annotations
from . import rl_chain_base as base
import langchain.chains.rl_chain.rl_chain_base as base
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain

@ -6,9 +6,9 @@ from typing import Any, Dict, List, Optional, Tuple, Union, Sequence
from abc import ABC, abstractmethod
import vowpal_wabbit_next as vw
from .vw_logger import VwLogger
from .model_repository import ModelRepository
from .metrics import MetricsTracker
from langchain.chains.rl_chain.vw_logger import VwLogger
from langchain.chains.rl_chain.model_repository import ModelRepository
from langchain.chains.rl_chain.metrics import MetricsTracker
from langchain.prompts import BasePromptTemplate
from langchain.pydantic_v1 import Extra, BaseModel, root_validator

@ -1,4 +1,4 @@
import langchain.chains.rl_chain.pick_best_chain as pick_best_chain
import langchain.chains.rl_chain as rl_chain
from test_utils import MockEncoder
import pytest
from langchain.prompts.prompt import PromptTemplate
@ -17,32 +17,32 @@ def setup():
def test_multiple_ToSelectFrom_throws():
llm, PROMPT = setup()
chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
chain = rl_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
actions = ["0", "1", "2"]
with pytest.raises(ValueError):
chain.run(
User=pick_best_chain.base.BasedOn("Context"),
action=pick_best_chain.base.ToSelectFrom(actions),
another_action=pick_best_chain.base.ToSelectFrom(actions),
User=rl_chain.BasedOn("Context"),
action=rl_chain.ToSelectFrom(actions),
another_action=rl_chain.ToSelectFrom(actions),
)
def test_missing_basedOn_from_throws():
llm, PROMPT = setup()
chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
chain = rl_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
actions = ["0", "1", "2"]
with pytest.raises(ValueError):
chain.run(action=pick_best_chain.base.ToSelectFrom(actions))
chain.run(action=rl_chain.ToSelectFrom(actions))
def test_ToSelectFrom_not_a_list_throws():
llm, PROMPT = setup()
chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
chain = rl_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
actions = {"actions": ["0", "1", "2"]}
with pytest.raises(ValueError):
chain.run(
User=pick_best_chain.base.BasedOn("Context"),
action=pick_best_chain.base.ToSelectFrom(actions),
User=rl_chain.BasedOn("Context"),
action=rl_chain.ToSelectFrom(actions),
)
@ -50,15 +50,15 @@ def test_update_with_delayed_score_with_auto_validator_throws():
llm, PROMPT = setup()
# this LLM returns a number so that the auto validator will return that
auto_val_llm = FakeListChatModel(responses=["3"])
chain = pick_best_chain.PickBest.from_llm(
chain = rl_chain.PickBest.from_llm(
llm=llm,
prompt=PROMPT,
selection_scorer=pick_best_chain.base.AutoSelectionScorer(llm=auto_val_llm),
selection_scorer=rl_chain.AutoSelectionScorer(llm=auto_val_llm),
)
actions = ["0", "1", "2"]
response = chain.run(
User=pick_best_chain.base.BasedOn("Context"),
action=pick_best_chain.base.ToSelectFrom(actions),
User=rl_chain.BasedOn("Context"),
action=rl_chain.ToSelectFrom(actions),
)
assert response["response"] == "hey"
selection_metadata = response["selection_metadata"]
@ -71,15 +71,15 @@ def test_update_with_delayed_score_force():
llm, PROMPT = setup()
# this LLM returns a number so that the auto validator will return that
auto_val_llm = FakeListChatModel(responses=["3"])
chain = pick_best_chain.PickBest.from_llm(
chain = rl_chain.PickBest.from_llm(
llm=llm,
prompt=PROMPT,
selection_scorer=pick_best_chain.base.AutoSelectionScorer(llm=auto_val_llm),
selection_scorer=rl_chain.AutoSelectionScorer(llm=auto_val_llm),
)
actions = ["0", "1", "2"]
response = chain.run(
User=pick_best_chain.base.BasedOn("Context"),
action=pick_best_chain.base.ToSelectFrom(actions),
User=rl_chain.BasedOn("Context"),
action=rl_chain.ToSelectFrom(actions),
)
assert response["response"] == "hey"
selection_metadata = response["selection_metadata"]
@ -92,13 +92,13 @@ def test_update_with_delayed_score_force():
def test_update_with_delayed_score():
llm, PROMPT = setup()
chain = pick_best_chain.PickBest.from_llm(
chain = rl_chain.PickBest.from_llm(
llm=llm, prompt=PROMPT, selection_scorer=None
)
actions = ["0", "1", "2"]
response = chain.run(
User=pick_best_chain.base.BasedOn("Context"),
action=pick_best_chain.base.ToSelectFrom(actions),
User=rl_chain.BasedOn("Context"),
action=rl_chain.ToSelectFrom(actions),
)
assert response["response"] == "hey"
selection_metadata = response["selection_metadata"]
@ -110,18 +110,18 @@ def test_update_with_delayed_score():
def test_user_defined_scorer():
llm, PROMPT = setup()
class CustomSelectionScorer(pick_best_chain.base.SelectionScorer):
class CustomSelectionScorer(rl_chain.SelectionScorer):
def score_response(self, inputs, llm_response: str) -> float:
score = 200
return score
chain = pick_best_chain.PickBest.from_llm(
chain = rl_chain.PickBest.from_llm(
llm=llm, prompt=PROMPT, selection_scorer=CustomSelectionScorer()
)
actions = ["0", "1", "2"]
response = chain.run(
User=pick_best_chain.base.BasedOn("Context"),
action=pick_best_chain.base.ToSelectFrom(actions),
User=rl_chain.BasedOn("Context"),
action=rl_chain.ToSelectFrom(actions),
)
assert response["response"] == "hey"
selection_metadata = response["selection_metadata"]
@ -130,8 +130,8 @@ def test_user_defined_scorer():
def test_default_embeddings():
llm, PROMPT = setup()
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
chain = pick_best_chain.PickBest.from_llm(
feature_embedder = rl_chain.PickBestFeatureEmbedder(model=MockEncoder())
chain = rl_chain.PickBest.from_llm(
llm=llm, prompt=PROMPT, feature_embedder=feature_embedder
)
@ -153,8 +153,8 @@ def test_default_embeddings():
actions = [str1, str2, str3]
response = chain.run(
User=pick_best_chain.base.BasedOn(ctx_str_1),
action=pick_best_chain.base.ToSelectFrom(actions),
User=rl_chain.BasedOn(ctx_str_1),
action=rl_chain.ToSelectFrom(actions),
)
selection_metadata = response["selection_metadata"]
vw_str = feature_embedder.format(selection_metadata)
@ -163,8 +163,8 @@ def test_default_embeddings():
def test_default_embeddings_off():
llm, PROMPT = setup()
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
chain = pick_best_chain.PickBest.from_llm(
feature_embedder = rl_chain.PickBestFeatureEmbedder(model=MockEncoder())
chain = rl_chain.PickBest.from_llm(
llm=llm, prompt=PROMPT, feature_embedder=feature_embedder, auto_embed=False
)
@ -178,8 +178,8 @@ def test_default_embeddings_off():
actions = [str1, str2, str3]
response = chain.run(
User=pick_best_chain.base.BasedOn(ctx_str_1),
action=pick_best_chain.base.ToSelectFrom(actions),
User=rl_chain.BasedOn(ctx_str_1),
action=rl_chain.ToSelectFrom(actions),
)
selection_metadata = response["selection_metadata"]
vw_str = feature_embedder.format(selection_metadata)
@ -188,8 +188,8 @@ def test_default_embeddings_off():
def test_default_embeddings_mixed_w_explicit_user_embeddings():
llm, PROMPT = setup()
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
chain = pick_best_chain.PickBest.from_llm(
feature_embedder = rl_chain.PickBestFeatureEmbedder(model=MockEncoder())
chain = rl_chain.PickBest.from_llm(
llm=llm, prompt=PROMPT, feature_embedder=feature_embedder
)
@ -208,12 +208,12 @@ def test_default_embeddings_mixed_w_explicit_user_embeddings():
expected = f"""shared |User {encoded_ctx_str_1} |User2 {ctx_str_2 + " " + encoded_ctx_str_2} \n|action {str1 + " " + encoded_str1} \n|action {str2 + " " + encoded_str2} \n|action {encoded_str3} """
actions = [str1, str2, pick_best_chain.base.Embed(str3)]
actions = [str1, str2, rl_chain.Embed(str3)]
response = chain.run(
User=pick_best_chain.base.BasedOn(pick_best_chain.base.Embed(ctx_str_1)),
User2=pick_best_chain.base.BasedOn(ctx_str_2),
action=pick_best_chain.base.ToSelectFrom(actions),
User=rl_chain.BasedOn(rl_chain.Embed(ctx_str_1)),
User2=rl_chain.BasedOn(ctx_str_2),
action=rl_chain.ToSelectFrom(actions),
)
selection_metadata = response["selection_metadata"]
vw_str = feature_embedder.format(selection_metadata)
@ -223,10 +223,10 @@ def test_default_embeddings_mixed_w_explicit_user_embeddings():
def test_default_no_scorer_specified():
_, PROMPT = setup()
chain_llm = FakeListChatModel(responses=[100])
chain = pick_best_chain.PickBest.from_llm(llm=chain_llm, prompt=PROMPT)
chain = rl_chain.PickBest.from_llm(llm=chain_llm, prompt=PROMPT)
response = chain.run(
User=pick_best_chain.base.BasedOn("Context"),
action=pick_best_chain.base.ToSelectFrom(["0", "1", "2"]),
User=rl_chain.BasedOn("Context"),
action=rl_chain.ToSelectFrom(["0", "1", "2"]),
)
# chain llm used for both basic prompt and for scoring
assert response["response"] == "100"
@ -236,12 +236,12 @@ def test_default_no_scorer_specified():
def test_explicitly_no_scorer():
llm, PROMPT = setup()
chain = pick_best_chain.PickBest.from_llm(
chain = rl_chain.PickBest.from_llm(
llm=llm, prompt=PROMPT, selection_scorer=None
)
response = chain.run(
User=pick_best_chain.base.BasedOn("Context"),
action=pick_best_chain.base.ToSelectFrom(["0", "1", "2"]),
User=rl_chain.BasedOn("Context"),
action=rl_chain.ToSelectFrom(["0", "1", "2"]),
)
# chain llm used for both basic prompt and for scoring
assert response["response"] == "hey"
@ -252,14 +252,14 @@ def test_explicitly_no_scorer():
def test_auto_scorer_with_user_defined_llm():
llm, PROMPT = setup()
scorer_llm = FakeListChatModel(responses=[300])
chain = pick_best_chain.PickBest.from_llm(
chain = rl_chain.PickBest.from_llm(
llm=llm,
prompt=PROMPT,
selection_scorer=pick_best_chain.base.AutoSelectionScorer(llm=scorer_llm),
selection_scorer=rl_chain.AutoSelectionScorer(llm=scorer_llm),
)
response = chain.run(
User=pick_best_chain.base.BasedOn("Context"),
action=pick_best_chain.base.ToSelectFrom(["0", "1", "2"]),
User=rl_chain.BasedOn("Context"),
action=rl_chain.ToSelectFrom(["0", "1", "2"]),
)
# chain llm used for both basic prompt and for scoring
assert response["response"] == "hey"
@ -269,17 +269,17 @@ def test_auto_scorer_with_user_defined_llm():
def test_calling_chain_w_reserved_inputs_throws():
llm, PROMPT = setup()
chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
chain = rl_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
with pytest.raises(ValueError):
chain.run(
User=pick_best_chain.base.BasedOn("Context"),
rl_chain_selected_based_on=pick_best_chain.base.ToSelectFrom(
User=rl_chain.BasedOn("Context"),
rl_chain_selected_based_on=rl_chain.ToSelectFrom(
["0", "1", "2"]
),
)
with pytest.raises(ValueError):
chain.run(
User=pick_best_chain.base.BasedOn("Context"),
rl_chain_selected=pick_best_chain.base.ToSelectFrom(["0", "1", "2"]),
User=rl_chain.BasedOn("Context"),
rl_chain_selected=rl_chain.ToSelectFrom(["0", "1", "2"]),
)

@ -1,4 +1,4 @@
import langchain.chains.rl_chain.pick_best_chain as pick_best_chain
import langchain.chains.rl_chain as rl_chain
from test_utils import MockEncoder
import pytest
@ -7,9 +7,9 @@ encoded_text = "[ e n c o d e d ] "
def test_pickbest_textembedder_missing_context_throws():
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
feature_embedder = rl_chain.PickBestFeatureEmbedder(model=MockEncoder())
named_action = {"action": ["0", "1", "2"]}
event = pick_best_chain.PickBest.Event(
event = rl_chain.PickBest.Event(
inputs={}, to_select_from=named_action, based_on={}
)
with pytest.raises(ValueError):
@ -17,8 +17,8 @@ def test_pickbest_textembedder_missing_context_throws():
def test_pickbest_textembedder_missing_actions_throws():
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
event = pick_best_chain.PickBest.Event(
feature_embedder = rl_chain.PickBestFeatureEmbedder(model=MockEncoder())
event = rl_chain.PickBest.Event(
inputs={}, to_select_from={}, based_on={"context": "context"}
)
with pytest.raises(ValueError):
@ -26,10 +26,10 @@ def test_pickbest_textembedder_missing_actions_throws():
def test_pickbest_textembedder_no_label_no_emb():
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
feature_embedder = rl_chain.PickBestFeatureEmbedder(model=MockEncoder())
named_actions = {"action1": ["0", "1", "2"]}
expected = """shared |context context \n|action1 0 \n|action1 1 \n|action1 2 """
event = pick_best_chain.PickBest.Event(
event = rl_chain.PickBest.Event(
inputs={}, to_select_from=named_actions, based_on={"context": "context"}
)
vw_ex_str = feature_embedder.format(event)
@ -37,11 +37,11 @@ def test_pickbest_textembedder_no_label_no_emb():
def test_pickbest_textembedder_w_label_no_score_no_emb():
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
feature_embedder = rl_chain.PickBestFeatureEmbedder(model=MockEncoder())
named_actions = {"action1": ["0", "1", "2"]}
expected = """shared |context context \n|action1 0 \n|action1 1 \n|action1 2 """
selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0)
event = pick_best_chain.PickBest.Event(
selected = rl_chain.PickBest.Selected(index=0, probability=1.0)
event = rl_chain.PickBest.Event(
inputs={},
to_select_from=named_actions,
based_on={"context": "context"},
@ -52,13 +52,13 @@ def test_pickbest_textembedder_w_label_no_score_no_emb():
def test_pickbest_textembedder_w_full_label_no_emb():
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
feature_embedder = rl_chain.PickBestFeatureEmbedder(model=MockEncoder())
named_actions = {"action1": ["0", "1", "2"]}
expected = (
"""shared |context context \n0:-0.0:1.0 |action1 0 \n|action1 1 \n|action1 2 """
)
selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0, score=0.0)
event = pick_best_chain.PickBest.Event(
selected = rl_chain.PickBest.Selected(index=0, probability=1.0, score=0.0)
event = rl_chain.PickBest.Event(
inputs={},
to_select_from=named_actions,
based_on={"context": "context"},
@ -69,7 +69,7 @@ def test_pickbest_textembedder_w_full_label_no_emb():
def test_pickbest_textembedder_w_full_label_w_emb():
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
feature_embedder = rl_chain.PickBestFeatureEmbedder(model=MockEncoder())
str1 = "0"
str2 = "1"
str3 = "2"
@ -80,11 +80,11 @@ def test_pickbest_textembedder_w_full_label_w_emb():
ctx_str_1 = "context1"
encoded_ctx_str_1 = encoded_text + " ".join(char for char in ctx_str_1)
named_actions = {"action1": pick_best_chain.base.Embed([str1, str2, str3])}
context = {"context": pick_best_chain.base.Embed(ctx_str_1)}
named_actions = {"action1": rl_chain.Embed([str1, str2, str3])}
context = {"context": rl_chain.Embed(ctx_str_1)}
expected = f"""shared |context {encoded_ctx_str_1} \n0:-0.0:1.0 |action1 {encoded_str1} \n|action1 {encoded_str2} \n|action1 {encoded_str3} """
selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0, score=0.0)
event = pick_best_chain.PickBest.Event(
selected = rl_chain.PickBest.Selected(index=0, probability=1.0, score=0.0)
event = rl_chain.PickBest.Event(
inputs={}, to_select_from=named_actions, based_on=context, selected=selected
)
vw_ex_str = feature_embedder.format(event)
@ -92,7 +92,7 @@ def test_pickbest_textembedder_w_full_label_w_emb():
def test_pickbest_textembedder_w_full_label_w_embed_and_keep():
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
feature_embedder = rl_chain.PickBestFeatureEmbedder(model=MockEncoder())
str1 = "0"
str2 = "1"
str3 = "2"
@ -103,11 +103,11 @@ def test_pickbest_textembedder_w_full_label_w_embed_and_keep():
ctx_str_1 = "context1"
encoded_ctx_str_1 = encoded_text + " ".join(char for char in ctx_str_1)
named_actions = {"action1": pick_best_chain.base.EmbedAndKeep([str1, str2, str3])}
context = {"context": pick_best_chain.base.EmbedAndKeep(ctx_str_1)}
named_actions = {"action1": rl_chain.EmbedAndKeep([str1, str2, str3])}
context = {"context": rl_chain.EmbedAndKeep(ctx_str_1)}
expected = f"""shared |context {ctx_str_1 + " " + encoded_ctx_str_1} \n0:-0.0:1.0 |action1 {str1 + " " + encoded_str1} \n|action1 {str2 + " " + encoded_str2} \n|action1 {str3 + " " + encoded_str3} """
selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0, score=0.0)
event = pick_best_chain.PickBest.Event(
selected = rl_chain.PickBest.Selected(index=0, probability=1.0, score=0.0)
event = rl_chain.PickBest.Event(
inputs={}, to_select_from=named_actions, based_on=context, selected=selected
)
vw_ex_str = feature_embedder.format(event)
@ -115,11 +115,11 @@ def test_pickbest_textembedder_w_full_label_w_embed_and_keep():
def test_pickbest_textembedder_more_namespaces_no_label_no_emb():
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
feature_embedder = rl_chain.PickBestFeatureEmbedder(model=MockEncoder())
named_actions = {"action1": [{"a": "0", "b": "0"}, "1", "2"]}
context = {"context1": "context1", "context2": "context2"}
expected = """shared |context1 context1 |context2 context2 \n|a 0 |b 0 \n|action1 1 \n|action1 2 """
event = pick_best_chain.PickBest.Event(
event = rl_chain.PickBest.Event(
inputs={}, to_select_from=named_actions, based_on=context
)
vw_ex_str = feature_embedder.format(event)
@ -127,12 +127,12 @@ def test_pickbest_textembedder_more_namespaces_no_label_no_emb():
def test_pickbest_textembedder_more_namespaces_w_label_no_emb():
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
feature_embedder = rl_chain.PickBestFeatureEmbedder(model=MockEncoder())
named_actions = {"action1": [{"a": "0", "b": "0"}, "1", "2"]}
context = {"context1": "context1", "context2": "context2"}
expected = """shared |context1 context1 |context2 context2 \n|a 0 |b 0 \n|action1 1 \n|action1 2 """
selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0)
event = pick_best_chain.PickBest.Event(
selected = rl_chain.PickBest.Selected(index=0, probability=1.0)
event = rl_chain.PickBest.Event(
inputs={}, to_select_from=named_actions, based_on=context, selected=selected
)
vw_ex_str = feature_embedder.format(event)
@ -140,12 +140,12 @@ def test_pickbest_textembedder_more_namespaces_w_label_no_emb():
def test_pickbest_textembedder_more_namespaces_w_full_label_no_emb():
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
feature_embedder = rl_chain.PickBestFeatureEmbedder(model=MockEncoder())
named_actions = {"action1": [{"a": "0", "b": "0"}, "1", "2"]}
context = {"context1": "context1", "context2": "context2"}
expected = """shared |context1 context1 |context2 context2 \n0:-0.0:1.0 |a 0 |b 0 \n|action1 1 \n|action1 2 """
selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0, score=0.0)
event = pick_best_chain.PickBest.Event(
selected = rl_chain.PickBest.Selected(index=0, probability=1.0, score=0.0)
event = rl_chain.PickBest.Event(
inputs={}, to_select_from=named_actions, based_on=context, selected=selected
)
vw_ex_str = feature_embedder.format(event)
@ -153,7 +153,7 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_no_emb():
def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_emb():
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
feature_embedder = rl_chain.PickBestFeatureEmbedder(model=MockEncoder())
str1 = "0"
str2 = "1"
@ -168,16 +168,16 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_emb():
encoded_ctx_str_2 = encoded_text + " ".join(char for char in ctx_str_2)
named_actions = {
"action1": pick_best_chain.base.Embed([{"a": str1, "b": str1}, str2, str3])
"action1": rl_chain.Embed([{"a": str1, "b": str1}, str2, str3])
}
context = {
"context1": pick_best_chain.base.Embed(ctx_str_1),
"context2": pick_best_chain.base.Embed(ctx_str_2),
"context1": rl_chain.Embed(ctx_str_1),
"context2": rl_chain.Embed(ctx_str_2),
}
expected = f"""shared |context1 {encoded_ctx_str_1} |context2 {encoded_ctx_str_2} \n0:-0.0:1.0 |a {encoded_str1} |b {encoded_str1} \n|action1 {encoded_str2} \n|action1 {encoded_str3} """
selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0, score=0.0)
event = pick_best_chain.PickBest.Event(
selected = rl_chain.PickBest.Selected(index=0, probability=1.0, score=0.0)
event = rl_chain.PickBest.Event(
inputs={}, to_select_from=named_actions, based_on=context, selected=selected
)
vw_ex_str = feature_embedder.format(event)
@ -185,7 +185,7 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_emb():
def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_embed_and_keep():
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
feature_embedder = rl_chain.PickBestFeatureEmbedder(model=MockEncoder())
str1 = "0"
str2 = "1"
@ -200,18 +200,18 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_embed_and_kee
encoded_ctx_str_2 = encoded_text + " ".join(char for char in ctx_str_2)
named_actions = {
"action1": pick_best_chain.base.EmbedAndKeep(
"action1": rl_chain.EmbedAndKeep(
[{"a": str1, "b": str1}, str2, str3]
)
}
context = {
"context1": pick_best_chain.base.EmbedAndKeep(ctx_str_1),
"context2": pick_best_chain.base.EmbedAndKeep(ctx_str_2),
"context1": rl_chain.EmbedAndKeep(ctx_str_1),
"context2": rl_chain.EmbedAndKeep(ctx_str_2),
}
expected = f"""shared |context1 {ctx_str_1 + " " + encoded_ctx_str_1} |context2 {ctx_str_2 + " " + encoded_ctx_str_2} \n0:-0.0:1.0 |a {str1 + " " + encoded_str1} |b {str1 + " " + encoded_str1} \n|action1 {str2 + " " + encoded_str2} \n|action1 {str3 + " " + encoded_str3} """
selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0, score=0.0)
event = pick_best_chain.PickBest.Event(
selected = rl_chain.PickBest.Selected(index=0, probability=1.0, score=0.0)
event = rl_chain.PickBest.Event(
inputs={}, to_select_from=named_actions, based_on=context, selected=selected
)
vw_ex_str = feature_embedder.format(event)
@ -219,7 +219,7 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_embed_and_kee
def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_emb():
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
feature_embedder = rl_chain.PickBestFeatureEmbedder(model=MockEncoder())
str1 = "0"
str2 = "1"
@ -235,16 +235,16 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_emb():
named_actions = {
"action1": [
{"a": str1, "b": pick_best_chain.base.Embed(str1)},
{"a": str1, "b": rl_chain.Embed(str1)},
str2,
pick_best_chain.base.Embed(str3),
rl_chain.Embed(str3),
]
}
context = {"context1": ctx_str_1, "context2": pick_best_chain.base.Embed(ctx_str_2)}
context = {"context1": ctx_str_1, "context2": rl_chain.Embed(ctx_str_2)}
expected = f"""shared |context1 {ctx_str_1} |context2 {encoded_ctx_str_2} \n0:-0.0:1.0 |a {str1} |b {encoded_str1} \n|action1 {str2} \n|action1 {encoded_str3} """
selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0, score=0.0)
event = pick_best_chain.PickBest.Event(
selected = rl_chain.PickBest.Selected(index=0, probability=1.0, score=0.0)
event = rl_chain.PickBest.Event(
inputs={}, to_select_from=named_actions, based_on=context, selected=selected
)
vw_ex_str = feature_embedder.format(event)
@ -252,7 +252,7 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_emb():
def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_embed_and_keep():
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
feature_embedder = rl_chain.PickBestFeatureEmbedder(model=MockEncoder())
str1 = "0"
str2 = "1"
@ -268,19 +268,19 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_embed_and_
named_actions = {
"action1": [
{"a": str1, "b": pick_best_chain.base.EmbedAndKeep(str1)},
{"a": str1, "b": rl_chain.EmbedAndKeep(str1)},
str2,
pick_best_chain.base.EmbedAndKeep(str3),
rl_chain.EmbedAndKeep(str3),
]
}
context = {
"context1": ctx_str_1,
"context2": pick_best_chain.base.EmbedAndKeep(ctx_str_2),
"context2": rl_chain.EmbedAndKeep(ctx_str_2),
}
expected = f"""shared |context1 {ctx_str_1} |context2 {ctx_str_2 + " " + encoded_ctx_str_2} \n0:-0.0:1.0 |a {str1} |b {str1 + " " + encoded_str1} \n|action1 {str2} \n|action1 {str3 + " " + encoded_str3} """
selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0, score=0.0)
event = pick_best_chain.PickBest.Event(
selected = rl_chain.PickBest.Selected(index=0, probability=1.0, score=0.0)
event = rl_chain.PickBest.Event(
inputs={}, to_select_from=named_actions, based_on=context, selected=selected
)
vw_ex_str = feature_embedder.format(event)
@ -288,7 +288,7 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_embed_and_
def test_raw_features_underscored():
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
feature_embedder = rl_chain.PickBestFeatureEmbedder(model=MockEncoder())
str1 = "this is a long string"
str1_underscored = str1.replace(" ", "_")
encoded_str1 = encoded_text + " ".join(char for char in str1)
@ -303,27 +303,27 @@ def test_raw_features_underscored():
expected_no_embed = (
f"""shared |context {ctx_str_underscored} \n|action {str1_underscored} """
)
event = pick_best_chain.PickBest.Event(
event = rl_chain.PickBest.Event(
inputs={}, to_select_from=named_actions, based_on=context
)
vw_ex_str = feature_embedder.format(event)
assert vw_ex_str == expected_no_embed
# Just embeddings
named_actions = {"action": pick_best_chain.base.Embed([str1])}
context = {"context": pick_best_chain.base.Embed(ctx_str)}
named_actions = {"action": rl_chain.Embed([str1])}
context = {"context": rl_chain.Embed(ctx_str)}
expected_embed = f"""shared |context {encoded_ctx_str} \n|action {encoded_str1} """
event = pick_best_chain.PickBest.Event(
event = rl_chain.PickBest.Event(
inputs={}, to_select_from=named_actions, based_on=context
)
vw_ex_str = feature_embedder.format(event)
assert vw_ex_str == expected_embed
# Embeddings and raw features
named_actions = {"action": pick_best_chain.base.EmbedAndKeep([str1])}
context = {"context": pick_best_chain.base.EmbedAndKeep(ctx_str)}
named_actions = {"action": rl_chain.EmbedAndKeep([str1])}
context = {"context": rl_chain.EmbedAndKeep(ctx_str)}
expected_embed_and_keep = f"""shared |context {ctx_str_underscored + " " + encoded_ctx_str} \n|action {str1_underscored + " " + encoded_str1} """
event = pick_best_chain.PickBest.Event(
event = rl_chain.PickBest.Event(
inputs={}, to_select_from=named_actions, based_on=context
)
vw_ex_str = feature_embedder.format(event)

@ -1,4 +1,4 @@
import langchain.chains.rl_chain.rl_chain_base as base
import langchain.chains.rl_chain as base
from test_utils import MockEncoder
import pytest

Loading…
Cancel
Save