You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
langchain/langchain/model_laboratory.py

85 lines
3.3 KiB
Python

"""Experiment with different models."""
from typing import List, Optional, Sequence, Union
from langchain.agents.agent import Agent
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.printing import print_text, get_color_mapping
from langchain.llms.base import LLM
from langchain.prompts.prompt import PromptTemplate
class ModelLaboratory:
    """Run one input through several chains or agents and print each output.

    Instances carry three attributes: ``chains`` (the objects being compared),
    ``chain_colors`` (the result of ``get_color_mapping`` over each chain's
    position rendered as a string) and ``names`` (optional display labels).
    """

    def __init__(
        self, chains: Sequence[Union[Chain, Agent]], names: Optional[List[str]] = None
    ):
        """Validate and store the chains (or agents) to compare.

        Args:
            chains: the chains or agents to experiment with. Every ``Chain``
                must have exactly one input key and exactly one output key.
            names: optional display names, one per entry of ``chains``. When
                omitted, ``compare`` labels each entry with ``str(chain)``.

        Raises:
            ValueError: if an entry is neither a ``Chain`` nor an ``Agent``,
                if a ``Chain`` does not have exactly one input key and one
                output key, or if ``names`` and ``chains`` differ in length.
        """
        for candidate in chains:
            if not isinstance(candidate, (Chain, Agent)):
                raise ValueError(
                    "ModelLaboratory should now be initialized with Chains or Agents. "
                    "If you want to initialize with LLMs, use the `from_llms` method "
                    "instead (`ModelLaboratory.from_llms(...)`)"
                )
            if not isinstance(candidate, Chain):
                # The key-count checks below are applied to Chains only.
                continue
            if len(candidate.input_keys) != 1:
                raise ValueError(
                    "Currently only support chains with one input variable, "
                    f"got {candidate.input_keys}"
                )
            if len(candidate.output_keys) != 1:
                raise ValueError(
                    "Currently only support chains with one output variable, "
                    f"got {candidate.output_keys}"
                )

        if names is not None and len(names) != len(chains):
            raise ValueError("Length of chains does not match length of names.")

        self.chains = chains
        # Colors are looked up by the chain's position rendered as a string.
        positions = [str(index) for index in range(len(self.chains))]
        self.chain_colors = get_color_mapping(positions)
        self.names = names

    @classmethod
    def from_llms(
        cls, llms: List[LLM], prompt: Optional[PromptTemplate] = None
    ) -> "ModelLaboratory":
        """Build a laboratory from bare LLMs, wrapping each one in an ``LLMChain``.

        Args:
            llms: list of LLMs to experiment with.
            prompt: optional prompt shared by every LLM. When omitted, a
                pass-through template with the single variable ``_input`` is
                used. A supplied prompt should have only one input variable.

        Returns:
            A ``ModelLaboratory`` whose display names are ``str(llm)`` for
            each LLM, in the order given.
        """
        template = prompt
        if template is None:
            template = PromptTemplate(input_variables=["_input"], template="{_input}")
        wrapped = [LLMChain(llm=model, prompt=template) for model in llms]
        labels = [str(model) for model in llms]
        return cls(wrapped, names=labels)

    def compare(self, text: str) -> None:
        """Print the output of every chain for the same input text.

        If the laboratory was created with a prompt, the text is fed into
        that prompt; if no prompt was provided, the text is the entire prompt.

        Args:
            text: input text to run all models on.
        """
        print(f"\033[1mInput:\033[0m\n{text}\n")
        for position, chain in enumerate(self.chains):
            # Use the explicit label when names were given, else the chain's str().
            label = str(chain) if self.names is None else self.names[position]
            print_text(label, end="\n")
            result = chain.run(text)
            print_text(result, color=self.chain_colors[str(position)], end="\n\n")