fix lock, imports, deps, test w deps, typo, formatting

pull/10242/head
olgavrou 1 year ago
parent e9423300d9
commit 1ae5a9c7a3

@ -2,25 +2,22 @@ from __future__ import annotations
import logging
import os
from typing import Any, Dict, List, Optional, Tuple, Union, Sequence
from abc import ABC, abstractmethod
import vowpal_wabbit_next as vw
from langchain.chains.rl_chain.vw_logger import VwLogger
from langchain.chains.rl_chain.model_repository import ModelRepository
from langchain.chains.rl_chain.metrics import MetricsTracker
from langchain.prompts import BasePromptTemplate
from langchain.pydantic_v1 import Extra, BaseModel, root_validator
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.chains.rl_chain.metrics import MetricsTracker
from langchain.chains.rl_chain.model_repository import ModelRepository
from langchain.chains.rl_chain.vw_logger import VwLogger
from langchain.prompts import (
BasePromptTemplate,
ChatPromptTemplate,
SystemMessagePromptTemplate,
HumanMessagePromptTemplate,
SystemMessagePromptTemplate,
)
from langchain.pydantic_v1 import BaseModel, Extra, root_validator
logger = logging.getLogger(__name__)
@ -87,7 +84,9 @@ def EmbedAndKeep(anything):
# helper functions
def parse_lines(parser: vw.TextFormatParser, input_str: str) -> List[vw.Example]:
def parse_lines(parser: "vw.TextFormatParser", input_str: str) -> List["vw.Example"]:
import vowpal_wabbit_next as vw
return [parser.parse_line(line) for line in input_str.split("\n")]
@ -100,7 +99,8 @@ def get_based_on_and_to_select_from(inputs: Dict[str, Any]):
if not to_select_from:
raise ValueError(
"No variables using 'ToSelectFrom' found in the inputs. Please include at least one variable containing a list to select from."
"No variables using 'ToSelectFrom' found in the inputs. \
Please include at least one variable containing a list to select from."
)
based_on = {
@ -173,14 +173,17 @@ class VwPolicy(Policy):
self.vw_logger = vw_logger
def predict(self, event: Event) -> Any:
import vowpal_wabbit_next as vw
text_parser = vw.TextFormatParser(self.workspace)
return self.workspace.predict_one(
parse_lines(text_parser, self.feature_embedder.format(event))
)
def learn(self, event: Event):
vw_ex = self.feature_embedder.format(event)
import vowpal_wabbit_next as vw
vw_ex = self.feature_embedder.format(event)
text_parser = vw.TextFormatParser(self.workspace)
multi_ex = parse_lines(text_parser, vw_ex)
self.workspace.learn_one(multi_ex)
@ -216,7 +219,7 @@ class AutoSelectionScorer(SelectionScorer, BaseModel):
@staticmethod
def get_default_system_prompt() -> SystemMessagePromptTemplate:
return SystemMessagePromptTemplate.from_template(
"PLEASE RESPOND ONLY WITH A SIGNLE FLOAT AND NO OTHER TEXT EXPLANATION\n You are a strict judge that is called on to rank a response based on given criteria.\
"PLEASE RESPOND ONLY WITH A SINGLE FLOAT AND NO OTHER TEXT EXPLANATION\n You are a strict judge that is called on to rank a response based on given criteria.\
You must respond with your ranking by providing a single float within the range [0, 1], 0 being very bad response and 1 being very good response."
)

@ -1,4 +1,3 @@
import pandas as pd
from typing import Optional
@ -23,5 +22,7 @@ class MetricsTracker:
if self._step > 0 and self._i % self._step == 0:
self._history.append({"step": self._i, "score": self.score})
def to_pandas(self) -> pd.DataFrame:
def to_pandas(self) -> "pd.DataFrame":
import pandas as pd
return pd.DataFrame(self._history)

@ -1,11 +1,10 @@
from pathlib import Path
import shutil
import datetime
import vowpal_wabbit_next as vw
from typing import Union, Sequence
import os
import glob
import logging
import os
import shutil
from pathlib import Path
from typing import Sequence, Union
logger = logging.getLogger(__name__)
@ -35,14 +34,18 @@ class ModelRepository:
def has_history(self) -> bool:
return len(glob.glob(str(self.folder / "model-????????-??????.vw"))) > 0
def save(self, workspace: vw.Workspace) -> None:
def save(self, workspace: "vw.Workspace") -> None:
import vowpal_wabbit_next as vw
with open(self.model_path, "wb") as f:
logger.info(f"storing rl_chain model in: {self.model_path}")
f.write(workspace.serialize())
if self.with_history: # write history
shutil.copyfile(self.model_path, self.folder / f"model-{self.get_tag()}.vw")
def load(self, commandline: Sequence[str]) -> vw.Workspace:
def load(self, commandline: Sequence[str]) -> "vw.Workspace":
import vowpal_wabbit_next as vw
model_data = None
if self.model_path.exists():
with open(self.model_path, "rb") as f:

@ -1,19 +1,15 @@
from __future__ import annotations
import langchain.chains.rl_chain.base as base
import logging
from typing import Any, Dict, List, Optional, Tuple
import langchain.chains.rl_chain.base as base
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
from langchain.base_language import BaseLanguageModel
from langchain.chains.llm import LLMChain
from sentence_transformers import SentenceTransformer
from langchain.prompts import BasePromptTemplate
import logging
logger = logging.getLogger(__name__)
# sentinel object used to distinguish between user didn't supply anything or user explicitly supplied None
@ -23,7 +19,7 @@ SENTINEL = object()
class PickBestFeatureEmbedder(base.Embedder):
"""
Contextual Bandit Text Embedder class that embeds the based_on and to_select_from into a format that can be used by VW
Attributes:
model name (Any, optional): The type of embeddings to be used for feature representation. Defaults to BERT SentenceTransformer.
"""
@ -32,6 +28,8 @@ class PickBestFeatureEmbedder(base.Embedder):
super().__init__(*args, **kwargs)
if model is None:
from sentence_transformers import SentenceTransformer
model = SentenceTransformer("bert-base-nli-mean-tokens")
self.model = model
@ -67,7 +65,7 @@ class PickBestFeatureEmbedder(base.Embedder):
)
example_string = ""
example_string += f"shared "
example_string += "shared "
for context_item in context_emb:
for ns, based_on in context_item.items():
example_string += f"|{ns} {' '.join(based_on) if isinstance(based_on, list) else based_on} "
@ -190,6 +188,8 @@ class PickBest(base.RLChain):
def _call_after_predict_before_llm(
self, inputs: Dict[str, Any], event: Event, prediction: List[Tuple[int, float]]
) -> Tuple[Dict[str, Any], PickBest.Event]:
import numpy as np
prob_sum = sum(prob for _, prob in prediction)
probabilities = [prob / prob_sum for _, prob in prediction]
## sample from the pmf
@ -237,7 +237,7 @@ class PickBest(base.RLChain):
Attributes:
inputs: (Dict, required) The inputs to the chain. The inputs must contain a input variables that are wrapped in BasedOn and ToSelectFrom. BasedOn is the based_on that will be used for selecting an ToSelectFrom action that will be passed to the LLM prompt.
run_manager: (CallbackManagerForChainRun, optional) The callback manager to use for this run. If not provided, a default callback manager is used.
Returns:
A dictionary containing:
- `response`: The response generated by the LLM (Language Model).

File diff suppressed because it is too large Load Diff

@ -338,6 +338,7 @@ extended_testing = [
"xmltodict",
"faiss-cpu",
"openapi-schema-pydantic",
"vowpal-wabbit-next"
]
[tool.ruff]

@ -1,13 +1,15 @@
import langchain.chains.rl_chain.pick_best_chain as pick_best_chain
import langchain.chains.rl_chain.base as rl_chain
from test_utils import MockEncoder
import pytest
from langchain.prompts.prompt import PromptTemplate
from test_utils import MockEncoder
import langchain.chains.rl_chain.base as rl_chain
import langchain.chains.rl_chain.pick_best_chain as pick_best_chain
from langchain.chat_models import FakeListChatModel
from langchain.prompts.prompt import PromptTemplate
encoded_text = "[ e n c o d e d ] "
@pytest.mark.requires("vowpal_wabbit_next")
def setup():
_PROMPT_TEMPLATE = """This is a dummy prompt that will be ignored by the fake llm"""
PROMPT = PromptTemplate(input_variables=[], template=_PROMPT_TEMPLATE)
@ -16,6 +18,7 @@ def setup():
return llm, PROMPT
@pytest.mark.requires("vowpal_wabbit_next")
def test_multiple_ToSelectFrom_throws():
llm, PROMPT = setup()
chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
@ -28,6 +31,7 @@ def test_multiple_ToSelectFrom_throws():
)
@pytest.mark.requires("vowpal_wabbit_next")
def test_missing_basedOn_from_throws():
llm, PROMPT = setup()
chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
@ -36,6 +40,7 @@ def test_missing_basedOn_from_throws():
chain.run(action=rl_chain.ToSelectFrom(actions))
@pytest.mark.requires("vowpal_wabbit_next")
def test_ToSelectFrom_not_a_list_throws():
llm, PROMPT = setup()
chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
@ -47,6 +52,7 @@ def test_ToSelectFrom_not_a_list_throws():
)
@pytest.mark.requires("vowpal_wabbit_next")
def test_update_with_delayed_score_with_auto_validator_throws():
llm, PROMPT = setup()
# this LLM returns a number so that the auto validator will return that
@ -68,6 +74,7 @@ def test_update_with_delayed_score_with_auto_validator_throws():
chain.update_with_delayed_score(event=selection_metadata, score=100)
@pytest.mark.requires("vowpal_wabbit_next")
def test_update_with_delayed_score_force():
llm, PROMPT = setup()
# this LLM returns a number so that the auto validator will return that
@ -91,6 +98,7 @@ def test_update_with_delayed_score_force():
assert selection_metadata.selected.score == 100.0
@pytest.mark.requires("vowpal_wabbit_next")
def test_update_with_delayed_score():
llm, PROMPT = setup()
chain = pick_best_chain.PickBest.from_llm(
@ -108,6 +116,7 @@ def test_update_with_delayed_score():
assert selection_metadata.selected.score == 100.0
@pytest.mark.requires("vowpal_wabbit_next")
def test_user_defined_scorer():
llm, PROMPT = setup()
@ -129,6 +138,7 @@ def test_user_defined_scorer():
assert selection_metadata.selected.score == 200.0
@pytest.mark.requires("vowpal_wabbit_next")
def test_default_embeddings():
llm, PROMPT = setup()
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
@ -162,6 +172,7 @@ def test_default_embeddings():
assert vw_str == expected
@pytest.mark.requires("vowpal_wabbit_next")
def test_default_embeddings_off():
llm, PROMPT = setup()
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
@ -187,6 +198,7 @@ def test_default_embeddings_off():
assert vw_str == expected
@pytest.mark.requires("vowpal_wabbit_next")
def test_default_embeddings_mixed_w_explicit_user_embeddings():
llm, PROMPT = setup()
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
@ -221,6 +233,7 @@ def test_default_embeddings_mixed_w_explicit_user_embeddings():
assert vw_str == expected
@pytest.mark.requires("vowpal_wabbit_next")
def test_default_no_scorer_specified():
_, PROMPT = setup()
chain_llm = FakeListChatModel(responses=[100])
@ -235,6 +248,7 @@ def test_default_no_scorer_specified():
assert selection_metadata.selected.score == 100.0
@pytest.mark.requires("vowpal_wabbit_next")
def test_explicitly_no_scorer():
llm, PROMPT = setup()
chain = pick_best_chain.PickBest.from_llm(
@ -250,6 +264,7 @@ def test_explicitly_no_scorer():
assert selection_metadata.selected.score == None
@pytest.mark.requires("vowpal_wabbit_next")
def test_auto_scorer_with_user_defined_llm():
llm, PROMPT = setup()
scorer_llm = FakeListChatModel(responses=[300])
@ -268,15 +283,14 @@ def test_auto_scorer_with_user_defined_llm():
assert selection_metadata.selected.score == 300.0
@pytest.mark.requires("vowpal_wabbit_next")
def test_calling_chain_w_reserved_inputs_throws():
llm, PROMPT = setup()
chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
with pytest.raises(ValueError):
chain.run(
User=rl_chain.BasedOn("Context"),
rl_chain_selected_based_on=rl_chain.ToSelectFrom(
["0", "1", "2"]
),
rl_chain_selected_based_on=rl_chain.ToSelectFrom(["0", "1", "2"]),
)
with pytest.raises(ValueError):

@ -1,12 +1,13 @@
import langchain.chains.rl_chain.pick_best_chain as pick_best_chain
import langchain.chains.rl_chain.base as rl_chain
import pytest
from test_utils import MockEncoder
import pytest
import langchain.chains.rl_chain.base as rl_chain
import langchain.chains.rl_chain.pick_best_chain as pick_best_chain
encoded_text = "[ e n c o d e d ] "
@pytest.mark.requires("vowpal_wabbit_next")
def test_pickbest_textembedder_missing_context_throws():
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
named_action = {"action": ["0", "1", "2"]}
@ -17,6 +18,7 @@ def test_pickbest_textembedder_missing_context_throws():
feature_embedder.format(event)
@pytest.mark.requires("vowpal_wabbit_next")
def test_pickbest_textembedder_missing_actions_throws():
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
event = pick_best_chain.PickBest.Event(
@ -26,6 +28,7 @@ def test_pickbest_textembedder_missing_actions_throws():
feature_embedder.format(event)
@pytest.mark.requires("vowpal_wabbit_next")
def test_pickbest_textembedder_no_label_no_emb():
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
named_actions = {"action1": ["0", "1", "2"]}
@ -37,6 +40,7 @@ def test_pickbest_textembedder_no_label_no_emb():
assert vw_ex_str == expected
@pytest.mark.requires("vowpal_wabbit_next")
def test_pickbest_textembedder_w_label_no_score_no_emb():
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
named_actions = {"action1": ["0", "1", "2"]}
@ -52,6 +56,7 @@ def test_pickbest_textembedder_w_label_no_score_no_emb():
assert vw_ex_str == expected
@pytest.mark.requires("vowpal_wabbit_next")
def test_pickbest_textembedder_w_full_label_no_emb():
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
named_actions = {"action1": ["0", "1", "2"]}
@ -69,6 +74,7 @@ def test_pickbest_textembedder_w_full_label_no_emb():
assert vw_ex_str == expected
@pytest.mark.requires("vowpal_wabbit_next")
def test_pickbest_textembedder_w_full_label_w_emb():
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
str1 = "0"
@ -92,6 +98,7 @@ def test_pickbest_textembedder_w_full_label_w_emb():
assert vw_ex_str == expected
@pytest.mark.requires("vowpal_wabbit_next")
def test_pickbest_textembedder_w_full_label_w_embed_and_keep():
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
str1 = "0"
@ -115,6 +122,7 @@ def test_pickbest_textembedder_w_full_label_w_embed_and_keep():
assert vw_ex_str == expected
@pytest.mark.requires("vowpal_wabbit_next")
def test_pickbest_textembedder_more_namespaces_no_label_no_emb():
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
named_actions = {"action1": [{"a": "0", "b": "0"}, "1", "2"]}
@ -127,6 +135,7 @@ def test_pickbest_textembedder_more_namespaces_no_label_no_emb():
assert vw_ex_str == expected
@pytest.mark.requires("vowpal_wabbit_next")
def test_pickbest_textembedder_more_namespaces_w_label_no_emb():
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
named_actions = {"action1": [{"a": "0", "b": "0"}, "1", "2"]}
@ -140,6 +149,7 @@ def test_pickbest_textembedder_more_namespaces_w_label_no_emb():
assert vw_ex_str == expected
@pytest.mark.requires("vowpal_wabbit_next")
def test_pickbest_textembedder_more_namespaces_w_full_label_no_emb():
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
named_actions = {"action1": [{"a": "0", "b": "0"}, "1", "2"]}
@ -153,6 +163,7 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_no_emb():
assert vw_ex_str == expected
@pytest.mark.requires("vowpal_wabbit_next")
def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_emb():
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
@ -168,9 +179,7 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_emb():
encoded_ctx_str_1 = encoded_text + " ".join(char for char in ctx_str_1)
encoded_ctx_str_2 = encoded_text + " ".join(char for char in ctx_str_2)
named_actions = {
"action1": rl_chain.Embed([{"a": str1, "b": str1}, str2, str3])
}
named_actions = {"action1": rl_chain.Embed([{"a": str1, "b": str1}, str2, str3])}
context = {
"context1": rl_chain.Embed(ctx_str_1),
"context2": rl_chain.Embed(ctx_str_2),
@ -185,6 +194,7 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_emb():
assert vw_ex_str == expected
@pytest.mark.requires("vowpal_wabbit_next")
def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_embed_and_keep():
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
@ -201,9 +211,7 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_embed_and_kee
encoded_ctx_str_2 = encoded_text + " ".join(char for char in ctx_str_2)
named_actions = {
"action1": rl_chain.EmbedAndKeep(
[{"a": str1, "b": str1}, str2, str3]
)
"action1": rl_chain.EmbedAndKeep([{"a": str1, "b": str1}, str2, str3])
}
context = {
"context1": rl_chain.EmbedAndKeep(ctx_str_1),
@ -219,6 +227,7 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_embed_and_kee
assert vw_ex_str == expected
@pytest.mark.requires("vowpal_wabbit_next")
def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_emb():
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
@ -252,6 +261,7 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_emb():
assert vw_ex_str == expected
@pytest.mark.requires("vowpal_wabbit_next")
def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_embed_and_keep():
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
@ -288,6 +298,7 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_embed_and_
assert vw_ex_str == expected
@pytest.mark.requires("vowpal_wabbit_next")
def test_raw_features_underscored():
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
str1 = "this is a long string"

@ -1,16 +1,18 @@
import langchain.chains.rl_chain.base as base
import pytest
from test_utils import MockEncoder
import pytest
import langchain.chains.rl_chain.base as base
encoded_text = "[ e n c o d e d ] "
@pytest.mark.requires("vowpal_wabbit_next")
def test_simple_context_str_no_emb():
expected = [{"a_namespace": "test"}]
assert base.embed("test", MockEncoder(), "a_namespace") == expected
@pytest.mark.requires("vowpal_wabbit_next")
def test_simple_context_str_w_emb():
str1 = "test"
encoded_str1 = " ".join(char for char in str1)
@ -25,6 +27,7 @@ def test_simple_context_str_w_emb():
)
@pytest.mark.requires("vowpal_wabbit_next")
def test_simple_context_str_w_nested_emb():
# nested embeddings, innermost wins
str1 = "test"
@ -42,11 +45,13 @@ def test_simple_context_str_w_nested_emb():
)
@pytest.mark.requires("vowpal_wabbit_next")
def test_context_w_namespace_no_emb():
expected = [{"test_namespace": "test"}]
assert base.embed({"test_namespace": "test"}, MockEncoder()) == expected
@pytest.mark.requires("vowpal_wabbit_next")
def test_context_w_namespace_w_emb():
str1 = "test"
encoded_str1 = " ".join(char for char in str1)
@ -61,6 +66,7 @@ def test_context_w_namespace_w_emb():
)
@pytest.mark.requires("vowpal_wabbit_next")
def test_context_w_namespace_w_emb2():
str1 = "test"
encoded_str1 = " ".join(char for char in str1)
@ -75,6 +81,7 @@ def test_context_w_namespace_w_emb2():
)
@pytest.mark.requires("vowpal_wabbit_next")
def test_context_w_namespace_w_some_emb():
str1 = "test1"
str2 = "test2"
@ -103,6 +110,7 @@ def test_context_w_namespace_w_some_emb():
)
@pytest.mark.requires("vowpal_wabbit_next")
def test_simple_action_strlist_no_emb():
str1 = "test1"
str2 = "test2"
@ -111,6 +119,7 @@ def test_simple_action_strlist_no_emb():
assert base.embed([str1, str2, str3], MockEncoder(), "a_namespace") == expected
@pytest.mark.requires("vowpal_wabbit_next")
def test_simple_action_strlist_w_emb():
str1 = "test1"
str2 = "test2"
@ -138,6 +147,7 @@ def test_simple_action_strlist_w_emb():
)
@pytest.mark.requires("vowpal_wabbit_next")
def test_simple_action_strlist_w_some_emb():
str1 = "test1"
str2 = "test2"
@ -170,6 +180,7 @@ def test_simple_action_strlist_w_some_emb():
)
@pytest.mark.requires("vowpal_wabbit_next")
def test_action_w_namespace_no_emb():
str1 = "test1"
str2 = "test2"
@ -192,6 +203,7 @@ def test_action_w_namespace_no_emb():
)
@pytest.mark.requires("vowpal_wabbit_next")
def test_action_w_namespace_w_emb():
str1 = "test1"
str2 = "test2"
@ -233,6 +245,7 @@ def test_action_w_namespace_w_emb():
)
@pytest.mark.requires("vowpal_wabbit_next")
def test_action_w_namespace_w_emb2():
str1 = "test1"
str2 = "test2"
@ -278,6 +291,7 @@ def test_action_w_namespace_w_emb2():
)
@pytest.mark.requires("vowpal_wabbit_next")
def test_action_w_namespace_w_some_emb():
str1 = "test1"
str2 = "test2"
@ -318,6 +332,7 @@ def test_action_w_namespace_w_some_emb():
)
@pytest.mark.requires("vowpal_wabbit_next")
def test_action_w_namespace_w_emb_w_more_than_one_item_in_first_dict():
str1 = "test1"
str2 = "test2"
@ -368,6 +383,7 @@ def test_action_w_namespace_w_emb_w_more_than_one_item_in_first_dict():
)
@pytest.mark.requires("vowpal_wabbit_next")
def test_one_namespace_w_list_of_features_no_emb():
str1 = "test1"
str2 = "test2"
@ -375,6 +391,7 @@ def test_one_namespace_w_list_of_features_no_emb():
assert base.embed({"test_namespace": [str1, str2]}, MockEncoder()) == expected
@pytest.mark.requires("vowpal_wabbit_next")
def test_one_namespace_w_list_of_features_w_some_emb():
str1 = "test1"
str2 = "test2"
@ -386,21 +403,25 @@ def test_one_namespace_w_list_of_features_w_some_emb():
)
@pytest.mark.requires("vowpal_wabbit_next")
def test_nested_list_features_throws():
with pytest.raises(ValueError):
base.embed({"test_namespace": [[1, 2], [3, 4]]}, MockEncoder())
@pytest.mark.requires("vowpal_wabbit_next")
def test_dict_in_list_throws():
with pytest.raises(ValueError):
base.embed({"test_namespace": [{"a": 1}, {"b": 2}]}, MockEncoder())
@pytest.mark.requires("vowpal_wabbit_next")
def test_nested_dict_throws():
with pytest.raises(ValueError):
base.embed({"test_namespace": {"a": {"b": 1}}}, MockEncoder())
@pytest.mark.requires("vowpal_wabbit_next")
def test_list_of_tuples_throws():
with pytest.raises(ValueError):
base.embed({"test_namespace": [("a", 1), ("b", 2)]}, MockEncoder())

Loading…
Cancel
Save