Harrison/chain lab (#156)

Harrison Chase · 2022-11-18 05:50:02 -08:00 · committed by GitHub
parent 0ac08bbca6
commit b15c84e19d
3 changed files with 154 additions and 29 deletions

File 1 of 3: Jupyter notebook (model laboratory demo)

@@ -42,7 +42,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"model_lab = ModelLaboratory(llms)"
+"model_lab = ModelLaboratory.from_llms(llms)"
 ]
 },
 {
@@ -60,19 +60,19 @@
 "\n",
 "\u001b[1mOpenAI\u001b[0m\n",
 "Params: {'model': 'text-davinci-002', 'temperature': 0.0, 'max_tokens': 256, 'top_p': 1, 'frequency_penalty': 0, 'presence_penalty': 0, 'n': 1, 'best_of': 1}\n",
-"\u001b[104m\n",
+"\u001b[36;1m\u001b[1;3m\n",
 "\n",
 "Flamingos are pink.\u001b[0m\n",
 "\n",
 "\u001b[1mCohere\u001b[0m\n",
 "Params: {'model': 'command-xlarge-20221108', 'max_tokens': 20, 'temperature': 0.0, 'k': 0, 'p': 1, 'frequency_penalty': 0, 'presence_penalty': 0}\n",
-"\u001b[103m\n",
+"\u001b[33;1m\u001b[1;3m\n",
 "\n",
 "Pink\u001b[0m\n",
 "\n",
 "\u001b[1mHuggingFaceHub\u001b[0m\n",
 "Params: {'repo_id': 'google/flan-t5-xl', 'temperature': 1}\n",
-"\u001b[101mpink\u001b[0m\n",
+"\u001b[38;5;200m\u001b[1;3mpink\u001b[0m\n",
 "\n"
 ]
 }
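A note on the escape-code change in this output: the old codes (104, 103, 101) are bright background colors, while the new ones are foreground styles. 36;1 is bold cyan, 33;1 is bold yellow, 38;5;200 picks a pink from the 256-color palette, and the extra 1;3 turns on bold italic. A minimal sketch to preview the new palette in any ANSI-capable terminal (plain Python, no langchain needed):

# Each entry pairs a provider label with the SGR code shown in this notebook output.
# 36;1 = bold cyan, 33;1 = bold yellow, 38;5;200 = 256-color pink; 1;3 = bold + italic.
for label, code in [("OpenAI", "36;1"), ("Cohere", "33;1"), ("HuggingFaceHub", "38;5;200")]:
    print(f"\033[{code}m\033[1;3m{label} sample output\033[0m")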
@@ -89,7 +89,7 @@
 "outputs": [],
 "source": [
 "prompt = Prompt(template=\"What is the capital of {state}?\", input_variables=[\"state\"])\n",
-"model_lab_with_prompt = ModelLaboratory(llms, prompt=prompt)"
+"model_lab_with_prompt = ModelLaboratory.from_llms(llms, prompt=prompt)"
 ]
 },
 {
@@ -107,19 +107,19 @@
 "\n",
 "\u001b[1mOpenAI\u001b[0m\n",
 "Params: {'model': 'text-davinci-002', 'temperature': 0.0, 'max_tokens': 256, 'top_p': 1, 'frequency_penalty': 0, 'presence_penalty': 0, 'n': 1, 'best_of': 1}\n",
-"\u001b[104m\n",
+"\u001b[36;1m\u001b[1;3m\n",
 "\n",
 "The capital of New York is Albany.\u001b[0m\n",
 "\n",
 "\u001b[1mCohere\u001b[0m\n",
 "Params: {'model': 'command-xlarge-20221108', 'max_tokens': 20, 'temperature': 0.0, 'k': 0, 'p': 1, 'frequency_penalty': 0, 'presence_penalty': 0}\n",
-"\u001b[103m\n",
+"\u001b[33;1m\u001b[1;3m\n",
 "\n",
 "The capital of New York is Albany.\u001b[0m\n",
 "\n",
 "\u001b[1mHuggingFaceHub\u001b[0m\n",
 "Params: {'repo_id': 'google/flan-t5-xl', 'temperature': 1}\n",
-"\u001b[101mst john s\u001b[0m\n",
+"\u001b[38;5;200m\u001b[1;3mst john s\u001b[0m\n",
 "\n"
 ]
 }
@@ -130,10 +130,103 @@
 },
 {
 "cell_type": "code",
-"execution_count": null,
+"execution_count": 7,
 "id": "54336dbf",
 "metadata": {},
 "outputs": [],
+"source": [
+"from langchain import SelfAskWithSearchChain, SerpAPIChain\n",
+"\n",
+"open_ai_llm = OpenAI(temperature=0)\n",
+"search = SerpAPIChain()\n",
+"self_ask_with_search_openai = SelfAskWithSearchChain(llm=open_ai_llm, search_chain=search, verbose=True)\n",
+"\n",
+"cohere_llm = Cohere(temperature=0, model=\"command-xlarge-20221108\")\n",
+"search = SerpAPIChain()\n",
+"self_ask_with_search_cohere = SelfAskWithSearchChain(llm=cohere_llm, search_chain=search, verbose=True)"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 8,
+"id": "6a50a9f1",
+"metadata": {},
+"outputs": [],
+"source": [
+"chains = [self_ask_with_search_openai, self_ask_with_search_cohere]\n",
+"names = [str(open_ai_llm), str(cohere_llm)]"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 9,
+"id": "d3549e99",
+"metadata": {},
+"outputs": [],
+"source": [
+"model_lab = ModelLaboratory(chains, names=names)"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 10,
+"id": "362f7f57",
+"metadata": {},
+"outputs": [
+{
+"name": "stdout",
+"output_type": "stream",
+"text": [
+"\u001b[1mInput:\u001b[0m\n",
+"What is the hometown of the reigning men's U.S. Open champion?\n",
+"\n",
+"\u001b[1mOpenAI\u001b[0m\n",
+"Params: {'model': 'text-davinci-002', 'temperature': 0.0, 'max_tokens': 256, 'top_p': 1, 'frequency_penalty': 0, 'presence_penalty': 0, 'n': 1, 'best_of': 1}\n",
+"\n",
+"\n",
+"\u001b[1m> Entering new chain...\u001b[0m\n",
+"What is the hometown of the reigning men's U.S. Open champion?\n",
+"Are follow up questions needed here:\u001b[32;1m\u001b[1;3m Yes.\n",
+"Follow up: Who is the reigning men's U.S. Open champion?\u001b[0m\n",
+"Intermediate answer: \u001b[33;1m\u001b[1;3mCarlos Alcaraz.\u001b[0m\u001b[32;1m\u001b[1;3m\n",
+"Follow up: Where is Carlos Alcaraz from?\u001b[0m\n",
+"Intermediate answer: \u001b[33;1m\u001b[1;3mEl Palmar, Spain.\u001b[0m\u001b[32;1m\u001b[1;3m\n",
+"So the final answer is: El Palmar, Spain\u001b[0m\n",
+"\u001b[1m> Finished chain.\u001b[0m\n",
+"\u001b[36;1m\u001b[1;3m\n",
+"So the final answer is: El Palmar, Spain\u001b[0m\n",
+"\n",
+"\u001b[1mCohere\u001b[0m\n",
+"Params: {'model': 'command-xlarge-20221108', 'max_tokens': 256, 'temperature': 0.0, 'k': 0, 'p': 1, 'frequency_penalty': 0, 'presence_penalty': 0}\n",
+"\n",
+"\n",
+"\u001b[1m> Entering new chain...\u001b[0m\n",
+"What is the hometown of the reigning men's U.S. Open champion?\n",
+"Are follow up questions needed here:\u001b[32;1m\u001b[1;3m Yes.\n",
+"Follow up: Who is the reigning men's U.S. Open champion?\u001b[0m\n",
+"Intermediate answer: \u001b[33;1m\u001b[1;3mCarlos Alcaraz.\u001b[0m\u001b[32;1m\u001b[1;3m\n",
+"So the final answer is:\n",
+"\n",
+"Carlos Alcaraz\u001b[0m\n",
+"\u001b[1m> Finished chain.\u001b[0m\n",
+"\u001b[33;1m\u001b[1;3m\n",
+"So the final answer is:\n",
+"\n",
+"Carlos Alcaraz\u001b[0m\n",
+"\n"
+]
+}
+],
+"source": [
+"model_lab.compare(\"What is the hometown of the reigning men's U.S. Open champion?\")"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"id": "94159131",
+"metadata": {},
+"outputs": [],
 "source": []
 }
 ],
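Flattened out of the notebook JSON, the new cells amount to the following script. A minimal sketch assuming this era's top-level exports (OpenAI, Cohere, and ModelLaboratory importable from the package root, as SelfAskWithSearchChain and SerpAPIChain are above) and the usual OpenAI, Cohere, and SerpAPI keys in the environment:

from langchain import Cohere, ModelLaboratory, OpenAI, SelfAskWithSearchChain, SerpAPIChain

# One self-ask-with-search chain per provider, each with its own search tool.
open_ai_llm = OpenAI(temperature=0)
self_ask_with_search_openai = SelfAskWithSearchChain(
    llm=open_ai_llm, search_chain=SerpAPIChain(), verbose=True
)
cohere_llm = Cohere(temperature=0, model="command-xlarge-20221108")
self_ask_with_search_cohere = SelfAskWithSearchChain(
    llm=cohere_llm, search_chain=SerpAPIChain(), verbose=True
)

# The new constructor takes whole chains plus display names, not bare LLMs.
chains = [self_ask_with_search_openai, self_ask_with_search_cohere]
names = [str(open_ai_llm), str(cohere_llm)]
model_lab = ModelLaboratory(chains, names=names)
model_lab.compare("What is the hometown of the reigning men's U.S. Open champion?")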

File 2 of 3: VectorDBQA chain (Python source)

@@ -27,6 +27,8 @@ class VectorDBQA(Chain, BaseModel):
     """LLM wrapper to use."""
     vectorstore: VectorStore
     """Vector Database to connect to."""
+    k: int = 4
+    """Number of documents to query for."""
     input_key: str = "query"  #: :meta private:
     output_key: str = "result"  #: :meta private:
@@ -55,7 +57,7 @@ class VectorDBQA(Chain, BaseModel):
     def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
         question = inputs[self.input_key]
         llm_chain = LLMChain(llm=self.llm, prompt=prompt)
-        docs = self.vectorstore.similarity_search(question)
+        docs = self.vectorstore.similarity_search(question, k=self.k)
         contexts = []
         for j, doc in enumerate(docs):
            contexts.append(f"Context {j}:\n{doc.page_content}")
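The new `k` field only threads a value into `similarity_search`, whose own default is already four documents, so existing behavior is unchanged unless a caller overrides it. A hedged usage sketch; the FAISS, embeddings, and chain import paths are assumptions based on this era's layout, not shown in this diff:

from langchain.chains import VectorDBQA  # assumed export path
from langchain.embeddings import OpenAIEmbeddings  # assumed export path
from langchain.llms import OpenAI
from langchain.vectorstores import FAISS  # assumed export path

# Tiny in-memory store; real usage would embed a document corpus.
texts = ["Flamingos are pink.", "The capital of New York is Albany."]
vectorstore = FAISS.from_texts(texts, OpenAIEmbeddings())

# k=2 now flows through to similarity_search instead of the fixed default of 4.
qa = VectorDBQA(llm=OpenAI(temperature=0), vectorstore=vectorstore, k=2)
print(qa.run("What color are flamingos?"))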

File 3 of 3: ModelLaboratory (Python source)

@@ -1,6 +1,7 @@
 """Experiment with different models."""
-from typing import List, Optional
+from typing import List, Optional, Sequence

+from langchain.chains.base import Chain
 from langchain.chains.llm import LLMChain
 from langchain.input import get_color_mapping, print_text
 from langchain.llms.base import LLM
@@ -10,7 +11,41 @@ from langchain.prompts.prompt import Prompt
 class ModelLaboratory:
     """Experiment with different models."""

-    def __init__(self, llms: List[LLM], prompt: Optional[Prompt] = None):
+    def __init__(self, chains: Sequence[Chain], names: Optional[List[str]] = None):
+        """Initialize with chains to experiment with.
+
+        Args:
+            chains: list of chains to experiment with.
+        """
+        if not isinstance(chains[0], Chain):
+            raise ValueError(
+                "ModelLaboratory should now be initialized with Chains. "
+                "If you want to initialize with LLMs, use the `from_llms` method "
+                "instead (`ModelLaboratory.from_llms(...)`)"
+            )
+        for chain in chains:
+            if len(chain.input_keys) != 1:
+                raise ValueError(
+                    "Currently only support chains with one input variable, "
+                    f"got {chain.input_keys}"
+                )
+            if len(chain.output_keys) != 1:
+                raise ValueError(
+                    "Currently only support chains with one output variable, "
+                    f"got {chain.output_keys}"
+                )
+        if names is not None:
+            if len(names) != len(chains):
+                raise ValueError("Length of chains does not match length of names.")
+        self.chains = chains
+        chain_range = [str(i) for i in range(len(self.chains))]
+        self.chain_colors = get_color_mapping(chain_range)
+        self.names = names
+
+    @classmethod
+    def from_llms(
+        cls, llms: List[LLM], prompt: Optional[Prompt] = None
+    ) -> "ModelLaboratory":
         """Initialize with LLMs to experiment with and optional prompt.

         Args:
@@ -18,18 +53,11 @@ class ModelLaboratory:
             prompt: Optional prompt to use to prompt the LLMs. Defaults to None.
                 If a prompt was provided, it should only have one input variable.
         """
-        self.llms = llms
-        llm_range = [str(i) for i in range(len(self.llms))]
-        self.llm_colors = get_color_mapping(llm_range)
         if prompt is None:
-            self.prompt = Prompt(input_variables=["_input"], template="{_input}")
-        else:
-            if len(prompt.input_variables) != 1:
-                raise ValueError(
-                    "Currently only support prompts with one input variable, "
-                    f"got {prompt}"
-                )
-            self.prompt = prompt
+            prompt = Prompt(input_variables=["_input"], template="{_input}")
+        chains = [LLMChain(llm=llm, prompt=prompt) for llm in llms]
+        names = [str(llm) for llm in llms]
+        return cls(chains, names=names)

     def compare(self, text: str) -> None:
         """Compare model outputs on an input text.
@@ -42,9 +70,11 @@ class ModelLaboratory:
             text: input text to run all models on.
         """
         print(f"\033[1mInput:\033[0m\n{text}\n")
-        for i, llm in enumerate(self.llms):
-            print_text(str(llm), end="\n")
-            chain = LLMChain(llm=llm, prompt=self.prompt)
-            llm_inputs = {self.prompt.input_variables[0]: text}
-            output = chain.predict(**llm_inputs)
-            print_text(output, color=self.llm_colors[str(i)], end="\n\n")
+        for i, chain in enumerate(self.chains):
+            if self.names is not None:
+                name = self.names[i]
+            else:
+                name = str(chain)
+            print_text(name, end="\n")
+            output = chain.run(text)
+            print_text(output, color=self.chain_colors[str(i)], end="\n\n")
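Taken together, `compare` now only assumes a `Chain` with one input key and one output key, delegating execution to `Chain.run`, while `from_llms` reproduces the old LLM-only behavior by wrapping each model in an `LLMChain`. A small sketch of both entry points, assuming the same era imports as in the notebook sketch above:

from langchain import Cohere, ModelLaboratory, OpenAI  # assumed root exports
from langchain.prompts.prompt import Prompt

llms = [OpenAI(temperature=0), Cohere(temperature=0)]

# Classmethod path: each LLM is wrapped in an LLMChain over a shared prompt.
prompt = Prompt(template="What is the capital of {state}?", input_variables=["state"])
lab = ModelLaboratory.from_llms(llms, prompt=prompt)
lab.compare("New York")

# Constructor path now requires Chains; passing bare LLMs raises the
# ValueError above, which points callers at ModelLaboratory.from_llms.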