mirror of https://github.com/hwchase17/langchain
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
423 lines
13 KiB
Python
423 lines
13 KiB
Python
from typing import List, Union
|
|
|
|
import pytest
|
|
from test_utils import MockEncoder
|
|
|
|
import langchain_experimental.rl_chain.base as base
|
|
|
|
# Marker the tests below prepend to a raw string before passing it to
# base.stringify_embedding to build the expected encoded value.
# NOTE(review): presumably mirrors the prefix MockEncoder adds -- confirm in
# test_utils.
encoded_keyword = "[encoded]"
|
|
|
|
|
|
@pytest.mark.requires("vowpal_wabbit_next")
def test_simple_context_str_no_emb() -> None:
    """A plain string context is stored unchanged under the given namespace."""
    actual = base.embed("test", MockEncoder(), "a_namespace")
    assert actual == [{"a_namespace": "test"}]
|
|
|
|
|
|
@pytest.mark.requires("vowpal_wabbit_next")
def test_simple_context_str_w_emb() -> None:
    """Check Embed and EmbedAndKeep applied to a single string context."""
    text = "test"
    encoded_text = base.stringify_embedding(list(encoded_keyword + text))

    # Embed: only the encoded form is expected under the namespace.
    embedded = base.embed(base.Embed(text), MockEncoder(), "a_namespace")
    assert embedded == [{"a_namespace": encoded_text}]

    # EmbedAndKeep: the raw text, a single space, then the encoded form.
    kept = base.embed(base.EmbedAndKeep(text), MockEncoder(), "a_namespace")
    assert kept == [{"a_namespace": " ".join((text, encoded_text))}]
|
|
|
|
|
|
@pytest.mark.requires("vowpal_wabbit_next")
def test_simple_context_str_w_nested_emb() -> None:
    """Check that with nested embed wrappers the innermost one wins."""
    text = "test"
    encoded_text = base.stringify_embedding(list(encoded_keyword + text))

    # Embed nested inside EmbedAndKeep: only the encoded form is expected.
    embed_innermost = base.embed(
        base.EmbedAndKeep(base.Embed(text)), MockEncoder(), "a_namespace"
    )
    assert embed_innermost == [{"a_namespace": encoded_text}]

    # EmbedAndKeep nested inside Embed: the raw text precedes the encoding.
    keep_innermost = base.embed(
        base.Embed(base.EmbedAndKeep(text)), MockEncoder(), "a_namespace"
    )
    assert keep_innermost == [{"a_namespace": " ".join((text, encoded_text))}]
|
|
|
|
|
|
@pytest.mark.requires("vowpal_wabbit_next")
def test_context_w_namespace_no_emb() -> None:
    """A namespaced dict with a plain string value comes back inside a list."""
    context = {"test_namespace": "test"}
    assert base.embed(context, MockEncoder()) == [{"test_namespace": "test"}]
|
|
|
|
|
|
@pytest.mark.requires("vowpal_wabbit_next")
def test_context_w_namespace_w_emb() -> None:
    """Check embed wrappers placed on the value of a namespaced context dict."""
    text = "test"
    encoded_text = base.stringify_embedding(list(encoded_keyword + text))

    # Embed on the value: the namespace maps to the encoded form only.
    embedded = base.embed({"test_namespace": base.Embed(text)}, MockEncoder())
    assert embedded == [{"test_namespace": encoded_text}]

    # EmbedAndKeep on the value: raw text, a space, then the encoded form.
    kept = base.embed({"test_namespace": base.EmbedAndKeep(text)}, MockEncoder())
    assert kept == [{"test_namespace": " ".join((text, encoded_text))}]
|
|
|
|
|
|
@pytest.mark.requires("vowpal_wabbit_next")
def test_context_w_namespace_w_emb2() -> None:
    """Check embed wrappers placed around the whole namespaced context dict."""
    text = "test"
    encoded_text = base.stringify_embedding(list(encoded_keyword + text))

    # Embed around the dict: the namespace maps to the encoded form only.
    embedded = base.embed(base.Embed({"test_namespace": text}), MockEncoder())
    assert embedded == [{"test_namespace": encoded_text}]

    # EmbedAndKeep around the dict: raw text, a space, then the encoded form.
    kept = base.embed(base.EmbedAndKeep({"test_namespace": text}), MockEncoder())
    assert kept == [{"test_namespace": " ".join((text, encoded_text))}]
|
|
|
|
|
|
@pytest.mark.requires("vowpal_wabbit_next")
def test_context_w_namespace_w_some_emb() -> None:
    """Check a context dict where only one of its two namespaces is wrapped."""
    plain = "test1"
    wrapped = "test2"
    encoded_wrapped = base.stringify_embedding(list(encoded_keyword + wrapped))

    # The unwrapped namespace keeps its raw value; the Embed one is encoded.
    embedded = base.embed(
        {"test_namespace": plain, "test_namespace2": base.Embed(wrapped)},
        MockEncoder(),
    )
    assert embedded == [
        {"test_namespace": plain, "test_namespace2": encoded_wrapped}
    ]

    # Same input shape with EmbedAndKeep: raw text is kept before the encoding.
    kept = base.embed(
        {"test_namespace": plain, "test_namespace2": base.EmbedAndKeep(wrapped)},
        MockEncoder(),
    )
    assert kept == [
        {
            "test_namespace": plain,
            "test_namespace2": " ".join((wrapped, encoded_wrapped)),
        }
    ]
|
|
|
|
|
|
@pytest.mark.requires("vowpal_wabbit_next")
def test_simple_action_strlist_no_emb() -> None:
    """A list of plain strings yields one namespaced dict per string."""
    actions: List[Union[str, base._Embed]] = ["test1", "test2", "test3"]
    result = base.embed(actions, MockEncoder(), "a_namespace")
    assert result == [
        {"a_namespace": text} for text in ("test1", "test2", "test3")
    ]
|
|
|
|
|
|
@pytest.mark.requires("vowpal_wabbit_next")
def test_simple_action_strlist_w_emb() -> None:
    """Check embed wrappers applied to a whole list of action strings."""
    texts = ("test1", "test2", "test3")
    encodings = [
        base.stringify_embedding(list(encoded_keyword + text)) for text in texts
    ]

    # Embed around the list: every entry is expected in encoded form only.
    embedded = base.embed(base.Embed(list(texts)), MockEncoder(), "a_namespace")
    assert embedded == [{"a_namespace": encoding} for encoding in encodings]

    # EmbedAndKeep around the list: each entry is raw text, space, encoding.
    kept = base.embed(
        base.EmbedAndKeep(list(texts)), MockEncoder(), "a_namespace"
    )
    assert kept == [
        {"a_namespace": " ".join(pair)} for pair in zip(texts, encodings)
    ]
|
|
|
|
|
|
@pytest.mark.requires("vowpal_wabbit_next")
def test_simple_action_strlist_w_some_emb() -> None:
    """Check an action list in which only some of the strings are wrapped."""
    plain, second, third = "test1", "test2", "test3"
    encoded_second = base.stringify_embedding(list(encoded_keyword + second))
    encoded_third = base.stringify_embedding(list(encoded_keyword + third))

    # The unwrapped entry stays raw; the Embed entries are encoded.
    embedded = base.embed(
        [plain, base.Embed(second), base.Embed(third)],
        MockEncoder(),
        "a_namespace",
    )
    assert embedded == [
        {"a_namespace": plain},
        {"a_namespace": encoded_second},
        {"a_namespace": encoded_third},
    ]

    # With EmbedAndKeep the wrapped entries keep their raw text as a prefix.
    kept = base.embed(
        [plain, base.EmbedAndKeep(second), base.EmbedAndKeep(third)],
        MockEncoder(),
        "a_namespace",
    )
    assert kept == [
        {"a_namespace": plain},
        {"a_namespace": " ".join((second, encoded_second))},
        {"a_namespace": " ".join((third, encoded_third))},
    ]
|
|
|
|
|
|
@pytest.mark.requires("vowpal_wabbit_next")
def test_action_w_namespace_no_emb() -> None:
    """Namespaced action dicts holding plain strings come back unchanged."""
    texts = ("test1", "test2", "test3")
    actions = [{"test_namespace": text} for text in texts]
    result = base.embed(actions, MockEncoder())
    # Expected value is rebuilt independently of the list passed to embed.
    assert result == [{"test_namespace": text} for text in texts]
|
|
|
|
|
|
@pytest.mark.requires("vowpal_wabbit_next")
def test_action_w_namespace_w_emb() -> None:
    """Check embed wrappers on the value of every namespaced action dict."""
    texts = ("test1", "test2", "test3")
    encodings = [
        base.stringify_embedding(list(encoded_keyword + text)) for text in texts
    ]

    # Embed on each value: every namespace maps to the encoded form only.
    embedded = base.embed(
        [{"test_namespace": base.Embed(text)} for text in texts],
        MockEncoder(),
    )
    assert embedded == [{"test_namespace": encoding} for encoding in encodings]

    # EmbedAndKeep on each value: raw text, a space, then the encoded form.
    kept = base.embed(
        [{"test_namespace": base.EmbedAndKeep(text)} for text in texts],
        MockEncoder(),
    )
    assert kept == [
        {"test_namespace": " ".join(pair)} for pair in zip(texts, encodings)
    ]
|
|
|
|
|
|
@pytest.mark.requires("vowpal_wabbit_next")
def test_action_w_namespace_w_emb2() -> None:
    """Check embed wrappers placed around a whole list of namespaced dicts.

    Each action dict uses its own namespace key.
    """
    namespaces = ("test_namespace1", "test_namespace2", "test_namespace3")
    texts = ("test1", "test2", "test3")
    encodings = [
        base.stringify_embedding(list(encoded_keyword + text)) for text in texts
    ]

    # Embed around the list: each namespace maps to the encoded form only.
    embedded = base.embed(
        base.Embed([{ns: text} for ns, text in zip(namespaces, texts)]),
        MockEncoder(),
    )
    assert embedded == [{ns: enc} for ns, enc in zip(namespaces, encodings)]

    # EmbedAndKeep around the list: raw text, a space, then the encoded form.
    kept = base.embed(
        base.EmbedAndKeep([{ns: text} for ns, text in zip(namespaces, texts)]),
        MockEncoder(),
    )
    assert kept == [
        {ns: " ".join((text, enc))}
        for ns, text, enc in zip(namespaces, texts, encodings)
    ]
|
|
|
|
|
|
@pytest.mark.requires("vowpal_wabbit_next")
def test_action_w_namespace_w_some_emb() -> None:
    """Check namespaced action dicts where only some values are wrapped."""
    plain, second, third = "test1", "test2", "test3"
    encoded_second = base.stringify_embedding(list(encoded_keyword + second))
    encoded_third = base.stringify_embedding(list(encoded_keyword + third))

    # The unwrapped value stays raw; the Embed values are encoded.
    embedded = base.embed(
        [
            {"test_namespace": plain},
            {"test_namespace": base.Embed(second)},
            {"test_namespace": base.Embed(third)},
        ],
        MockEncoder(),
    )
    assert embedded == [
        {"test_namespace": plain},
        {"test_namespace": encoded_second},
        {"test_namespace": encoded_third},
    ]

    # With EmbedAndKeep the wrapped values keep their raw text as a prefix.
    kept = base.embed(
        [
            {"test_namespace": plain},
            {"test_namespace": base.EmbedAndKeep(second)},
            {"test_namespace": base.EmbedAndKeep(third)},
        ],
        MockEncoder(),
    )
    assert kept == [
        {"test_namespace": plain},
        {"test_namespace": " ".join((second, encoded_second))},
        {"test_namespace": " ".join((third, encoded_third))},
    ]
|
|
|
|
|
|
@pytest.mark.requires("vowpal_wabbit_next")
def test_action_w_namespace_w_emb_w_more_than_one_item_in_first_dict() -> None:
    """Check action dicts with two namespaces where only the first is wrapped."""
    texts = ("test1", "test2", "test3")
    encodings = [
        base.stringify_embedding(list(encoded_keyword + text)) for text in texts
    ]

    # Embed on the first namespace; the second namespace keeps the raw text.
    embedded = base.embed(
        [
            {"test_namespace": base.Embed(text), "test_namespace2": text}
            for text in texts
        ],
        MockEncoder(),
    )
    assert embedded == [
        {"test_namespace": enc, "test_namespace2": text}
        for text, enc in zip(texts, encodings)
    ]

    # EmbedAndKeep on the first namespace: raw text, a space, then encoding.
    kept = base.embed(
        [
            {"test_namespace": base.EmbedAndKeep(text), "test_namespace2": text}
            for text in texts
        ],
        MockEncoder(),
    )
    assert kept == [
        {"test_namespace": " ".join((text, enc)), "test_namespace2": text}
        for text, enc in zip(texts, encodings)
    ]
|
|
|
|
|
|
@pytest.mark.requires("vowpal_wabbit_next")
def test_one_namespace_w_list_of_features_no_emb() -> None:
    """A namespace whose value is a list of plain strings keeps that list."""
    features = ["test1", "test2"]
    # Pass a copy so the expected value is independent of the embed input.
    result = base.embed({"test_namespace": list(features)}, MockEncoder())
    assert result == [{"test_namespace": features}]
|
|
|
|
|
|
@pytest.mark.requires("vowpal_wabbit_next")
def test_one_namespace_w_list_of_features_w_some_emb() -> None:
    """Check a feature list under one namespace with one wrapped entry."""
    plain = "test1"
    wrapped = "test2"
    encoded_wrapped = base.stringify_embedding(list(encoded_keyword + wrapped))

    # The unwrapped feature stays raw; the Embed feature is encoded in place.
    result = base.embed(
        {"test_namespace": [plain, base.Embed(wrapped)]}, MockEncoder()
    )
    assert result == [{"test_namespace": [plain, encoded_wrapped]}]
|
|
|
|
|
|
@pytest.mark.requires("vowpal_wabbit_next")
def test_nested_list_features_throws() -> None:
    """A namespace value made of nested lists must raise ValueError."""
    nested_lists = {"test_namespace": [[1, 2], [3, 4]]}
    with pytest.raises(ValueError):
        base.embed(nested_lists, MockEncoder())
|
|
|
|
|
|
@pytest.mark.requires("vowpal_wabbit_next")
def test_dict_in_list_throws() -> None:
    """A namespace value that is a list of dicts must raise ValueError."""
    dicts_in_list = {"test_namespace": [{"a": 1}, {"b": 2}]}
    with pytest.raises(ValueError):
        base.embed(dicts_in_list, MockEncoder())
|
|
|
|
|
|
@pytest.mark.requires("vowpal_wabbit_next")
def test_nested_dict_throws() -> None:
    """A namespace value that is a dict of dicts must raise ValueError."""
    nested_dict = {"test_namespace": {"a": {"b": 1}}}
    with pytest.raises(ValueError):
        base.embed(nested_dict, MockEncoder())
|
|
|
|
|
|
@pytest.mark.requires("vowpal_wabbit_next")
def test_list_of_tuples_throws() -> None:
    """A namespace value that is a list of tuples must raise ValueError."""
    tuples_in_list = {"test_namespace": [("a", 1), ("b", 2)]}
    with pytest.raises(ValueError):
        base.embed(tuples_in_list, MockEncoder())
|