|
|
|
@ -1,9 +1,8 @@
|
|
|
|
|
"""Experiment with different models."""
|
|
|
|
|
from __future__ import annotations
|
|
|
|
|
|
|
|
|
|
from typing import List, Optional, Sequence, Union
|
|
|
|
|
from typing import List, Optional, Sequence
|
|
|
|
|
|
|
|
|
|
from langchain.agents.agent import Agent
|
|
|
|
|
from langchain.chains.base import Chain
|
|
|
|
|
from langchain.chains.llm import LLMChain
|
|
|
|
|
from langchain.input import get_color_mapping, print_text
|
|
|
|
@ -14,22 +13,19 @@ from langchain.prompts.prompt import PromptTemplate
|
|
|
|
|
class ModelLaboratory:
|
|
|
|
|
"""Experiment with different models."""
|
|
|
|
|
|
|
|
|
|
def __init__(
|
|
|
|
|
self, chains: Sequence[Union[Chain, Agent]], names: Optional[List[str]] = None
|
|
|
|
|
):
|
|
|
|
|
def __init__(self, chains: Sequence[Chain], names: Optional[List[str]] = None):
|
|
|
|
|
"""Initialize with chains to experiment with.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
chains: list of chains to experiment with.
|
|
|
|
|
"""
|
|
|
|
|
for chain in chains:
|
|
|
|
|
if not isinstance(chain, (Chain, Agent)):
|
|
|
|
|
if not isinstance(chain, Chain):
|
|
|
|
|
raise ValueError(
|
|
|
|
|
"ModelLaboratory should now be initialized with Chains or Agents. "
|
|
|
|
|
"ModelLaboratory should now be initialized with Chains. "
|
|
|
|
|
"If you want to initialize with LLMs, use the `from_llms` method "
|
|
|
|
|
"instead (`ModelLaboratory.from_llms(...)`)"
|
|
|
|
|
)
|
|
|
|
|
if isinstance(chain, Chain):
|
|
|
|
|
if len(chain.input_keys) != 1:
|
|
|
|
|
raise ValueError(
|
|
|
|
|
"Currently only support chains with one input variable, "
|
|
|
|
|