fix linting errors

pull/10242/head
olgavrou 1 year ago
parent 631289a38d
commit 248db75cd6

@ -19,19 +19,20 @@ from typing import (
from langchain.callbacks.manager import CallbackManagerForChainRun from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain from langchain.chains.llm import LLMChain
from langchain_experimental.rl_chain.metrics import (
MetricsTrackerAverage,
MetricsTrackerRollingWindow,
)
from langchain_experimental.rl_chain.model_repository import ModelRepository
from langchain_experimental.rl_chain.vw_logger import VwLogger
from langchain.prompts import ( from langchain.prompts import (
BasePromptTemplate, BasePromptTemplate,
ChatPromptTemplate, ChatPromptTemplate,
HumanMessagePromptTemplate, HumanMessagePromptTemplate,
SystemMessagePromptTemplate, SystemMessagePromptTemplate,
) )
from langchain_experimental.pydantic_v1 import BaseModel, Extra, root_validator from langchain_experimental.pydantic_v1 import BaseModel, Extra, root_validator
from langchain_experimental.rl_chain.metrics import (
MetricsTrackerAverage,
MetricsTrackerRollingWindow,
)
from langchain_experimental.rl_chain.model_repository import ModelRepository
from langchain_experimental.rl_chain.vw_logger import VwLogger
if TYPE_CHECKING: if TYPE_CHECKING:
import vowpal_wabbit_next as vw import vowpal_wabbit_next as vw

@ -3,12 +3,13 @@ from __future__ import annotations
import logging import logging
from typing import Any, Dict, List, Optional, Tuple, Type, Union from typing import Any, Dict, List, Optional, Tuple, Type, Union
import langchain_experimental.rl_chain.base as base
from langchain.base_language import BaseLanguageModel from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import CallbackManagerForChainRun from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.llm import LLMChain from langchain.chains.llm import LLMChain
from langchain.prompts import BasePromptTemplate from langchain.prompts import BasePromptTemplate
import langchain_experimental.rl_chain.base as base
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# sentinel object used to distinguish between # sentinel object used to distinguish between

@ -1,12 +1,12 @@
from typing import Any, Dict from typing import Any, Dict
import pytest import pytest
from langchain.chat_models import FakeListChatModel
from langchain.prompts.prompt import PromptTemplate
from test_utils import MockEncoder, MockEncoderReturnsList from test_utils import MockEncoder, MockEncoderReturnsList
import langchain_experimental.rl_chain.base as rl_chain import langchain_experimental.rl_chain.base as rl_chain
import langchain_experimental.rl_chain.pick_best_chain as pick_best_chain import langchain_experimental.rl_chain.pick_best_chain as pick_best_chain
from langchain.chat_models import FakeListChatModel
from langchain.prompts.prompt import PromptTemplate
encoded_keyword = "[encoded]" encoded_keyword = "[encoded]"
@ -94,7 +94,9 @@ def test_update_with_delayed_score_with_auto_validator_throws() -> None:
selection_metadata = response["selection_metadata"] # type: ignore selection_metadata = response["selection_metadata"] # type: ignore
assert selection_metadata.selected.score == 3.0 # type: ignore assert selection_metadata.selected.score == 3.0 # type: ignore
with pytest.raises(RuntimeError): with pytest.raises(RuntimeError):
chain.update_with_delayed_score(chain_response=response, score=100) # type: ignore chain.update_with_delayed_score(
chain_response=response, score=100 # type: ignore
)
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers") @pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")

Loading…
Cancel
Save