forked from Archives/langchain
Harrison/chain lab (#156)
This commit is contained in:
parent
0ac08bbca6
commit
b15c84e19d
@ -42,7 +42,7 @@
|
|||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"model_lab = ModelLaboratory(llms)"
|
"model_lab = ModelLaboratory.from_llms(llms)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -60,19 +60,19 @@
|
|||||||
"\n",
|
"\n",
|
||||||
"\u001b[1mOpenAI\u001b[0m\n",
|
"\u001b[1mOpenAI\u001b[0m\n",
|
||||||
"Params: {'model': 'text-davinci-002', 'temperature': 0.0, 'max_tokens': 256, 'top_p': 1, 'frequency_penalty': 0, 'presence_penalty': 0, 'n': 1, 'best_of': 1}\n",
|
"Params: {'model': 'text-davinci-002', 'temperature': 0.0, 'max_tokens': 256, 'top_p': 1, 'frequency_penalty': 0, 'presence_penalty': 0, 'n': 1, 'best_of': 1}\n",
|
||||||
"\u001b[104m\n",
|
"\u001b[36;1m\u001b[1;3m\n",
|
||||||
"\n",
|
"\n",
|
||||||
"Flamingos are pink.\u001b[0m\n",
|
"Flamingos are pink.\u001b[0m\n",
|
||||||
"\n",
|
"\n",
|
||||||
"\u001b[1mCohere\u001b[0m\n",
|
"\u001b[1mCohere\u001b[0m\n",
|
||||||
"Params: {'model': 'command-xlarge-20221108', 'max_tokens': 20, 'temperature': 0.0, 'k': 0, 'p': 1, 'frequency_penalty': 0, 'presence_penalty': 0}\n",
|
"Params: {'model': 'command-xlarge-20221108', 'max_tokens': 20, 'temperature': 0.0, 'k': 0, 'p': 1, 'frequency_penalty': 0, 'presence_penalty': 0}\n",
|
||||||
"\u001b[103m\n",
|
"\u001b[33;1m\u001b[1;3m\n",
|
||||||
"\n",
|
"\n",
|
||||||
"Pink\u001b[0m\n",
|
"Pink\u001b[0m\n",
|
||||||
"\n",
|
"\n",
|
||||||
"\u001b[1mHuggingFaceHub\u001b[0m\n",
|
"\u001b[1mHuggingFaceHub\u001b[0m\n",
|
||||||
"Params: {'repo_id': 'google/flan-t5-xl', 'temperature': 1}\n",
|
"Params: {'repo_id': 'google/flan-t5-xl', 'temperature': 1}\n",
|
||||||
"\u001b[101mpink\u001b[0m\n",
|
"\u001b[38;5;200m\u001b[1;3mpink\u001b[0m\n",
|
||||||
"\n"
|
"\n"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
@ -89,7 +89,7 @@
|
|||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"prompt = Prompt(template=\"What is the capital of {state}?\", input_variables=[\"state\"])\n",
|
"prompt = Prompt(template=\"What is the capital of {state}?\", input_variables=[\"state\"])\n",
|
||||||
"model_lab_with_prompt = ModelLaboratory(llms, prompt=prompt)"
|
"model_lab_with_prompt = ModelLaboratory.from_llms(llms, prompt=prompt)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -107,19 +107,19 @@
|
|||||||
"\n",
|
"\n",
|
||||||
"\u001b[1mOpenAI\u001b[0m\n",
|
"\u001b[1mOpenAI\u001b[0m\n",
|
||||||
"Params: {'model': 'text-davinci-002', 'temperature': 0.0, 'max_tokens': 256, 'top_p': 1, 'frequency_penalty': 0, 'presence_penalty': 0, 'n': 1, 'best_of': 1}\n",
|
"Params: {'model': 'text-davinci-002', 'temperature': 0.0, 'max_tokens': 256, 'top_p': 1, 'frequency_penalty': 0, 'presence_penalty': 0, 'n': 1, 'best_of': 1}\n",
|
||||||
"\u001b[104m\n",
|
"\u001b[36;1m\u001b[1;3m\n",
|
||||||
"\n",
|
"\n",
|
||||||
"The capital of New York is Albany.\u001b[0m\n",
|
"The capital of New York is Albany.\u001b[0m\n",
|
||||||
"\n",
|
"\n",
|
||||||
"\u001b[1mCohere\u001b[0m\n",
|
"\u001b[1mCohere\u001b[0m\n",
|
||||||
"Params: {'model': 'command-xlarge-20221108', 'max_tokens': 20, 'temperature': 0.0, 'k': 0, 'p': 1, 'frequency_penalty': 0, 'presence_penalty': 0}\n",
|
"Params: {'model': 'command-xlarge-20221108', 'max_tokens': 20, 'temperature': 0.0, 'k': 0, 'p': 1, 'frequency_penalty': 0, 'presence_penalty': 0}\n",
|
||||||
"\u001b[103m\n",
|
"\u001b[33;1m\u001b[1;3m\n",
|
||||||
"\n",
|
"\n",
|
||||||
"The capital of New York is Albany.\u001b[0m\n",
|
"The capital of New York is Albany.\u001b[0m\n",
|
||||||
"\n",
|
"\n",
|
||||||
"\u001b[1mHuggingFaceHub\u001b[0m\n",
|
"\u001b[1mHuggingFaceHub\u001b[0m\n",
|
||||||
"Params: {'repo_id': 'google/flan-t5-xl', 'temperature': 1}\n",
|
"Params: {'repo_id': 'google/flan-t5-xl', 'temperature': 1}\n",
|
||||||
"\u001b[101mst john s\u001b[0m\n",
|
"\u001b[38;5;200m\u001b[1;3mst john s\u001b[0m\n",
|
||||||
"\n"
|
"\n"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
@ -130,10 +130,103 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": 7,
|
||||||
"id": "54336dbf",
|
"id": "54336dbf",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from langchain import SelfAskWithSearchChain, SerpAPIChain\n",
|
||||||
|
"\n",
|
||||||
|
"open_ai_llm = OpenAI(temperature=0)\n",
|
||||||
|
"search = SerpAPIChain()\n",
|
||||||
|
"self_ask_with_search_openai = SelfAskWithSearchChain(llm=open_ai_llm, search_chain=search, verbose=True)\n",
|
||||||
|
"\n",
|
||||||
|
"cohere_llm = Cohere(temperature=0, model=\"command-xlarge-20221108\")\n",
|
||||||
|
"search = SerpAPIChain()\n",
|
||||||
|
"self_ask_with_search_cohere = SelfAskWithSearchChain(llm=cohere_llm, search_chain=search, verbose=True)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 8,
|
||||||
|
"id": "6a50a9f1",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"chains = [self_ask_with_search_openai, self_ask_with_search_cohere]\n",
|
||||||
|
"names = [str(open_ai_llm), str(cohere_llm)]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 9,
|
||||||
|
"id": "d3549e99",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"model_lab = ModelLaboratory(chains, names=names)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 10,
|
||||||
|
"id": "362f7f57",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"\u001b[1mInput:\u001b[0m\n",
|
||||||
|
"What is the hometown of the reigning men's U.S. Open champion?\n",
|
||||||
|
"\n",
|
||||||
|
"\u001b[1mOpenAI\u001b[0m\n",
|
||||||
|
"Params: {'model': 'text-davinci-002', 'temperature': 0.0, 'max_tokens': 256, 'top_p': 1, 'frequency_penalty': 0, 'presence_penalty': 0, 'n': 1, 'best_of': 1}\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"\u001b[1m> Entering new chain...\u001b[0m\n",
|
||||||
|
"What is the hometown of the reigning men's U.S. Open champion?\n",
|
||||||
|
"Are follow up questions needed here:\u001b[32;1m\u001b[1;3m Yes.\n",
|
||||||
|
"Follow up: Who is the reigning men's U.S. Open champion?\u001b[0m\n",
|
||||||
|
"Intermediate answer: \u001b[33;1m\u001b[1;3mCarlos Alcaraz.\u001b[0m\u001b[32;1m\u001b[1;3m\n",
|
||||||
|
"Follow up: Where is Carlos Alcaraz from?\u001b[0m\n",
|
||||||
|
"Intermediate answer: \u001b[33;1m\u001b[1;3mEl Palmar, Spain.\u001b[0m\u001b[32;1m\u001b[1;3m\n",
|
||||||
|
"So the final answer is: El Palmar, Spain\u001b[0m\n",
|
||||||
|
"\u001b[1m> Finished chain.\u001b[0m\n",
|
||||||
|
"\u001b[36;1m\u001b[1;3m\n",
|
||||||
|
"So the final answer is: El Palmar, Spain\u001b[0m\n",
|
||||||
|
"\n",
|
||||||
|
"\u001b[1mCohere\u001b[0m\n",
|
||||||
|
"Params: {'model': 'command-xlarge-20221108', 'max_tokens': 256, 'temperature': 0.0, 'k': 0, 'p': 1, 'frequency_penalty': 0, 'presence_penalty': 0}\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"\u001b[1m> Entering new chain...\u001b[0m\n",
|
||||||
|
"What is the hometown of the reigning men's U.S. Open champion?\n",
|
||||||
|
"Are follow up questions needed here:\u001b[32;1m\u001b[1;3m Yes.\n",
|
||||||
|
"Follow up: Who is the reigning men's U.S. Open champion?\u001b[0m\n",
|
||||||
|
"Intermediate answer: \u001b[33;1m\u001b[1;3mCarlos Alcaraz.\u001b[0m\u001b[32;1m\u001b[1;3m\n",
|
||||||
|
"So the final answer is:\n",
|
||||||
|
"\n",
|
||||||
|
"Carlos Alcaraz\u001b[0m\n",
|
||||||
|
"\u001b[1m> Finished chain.\u001b[0m\n",
|
||||||
|
"\u001b[33;1m\u001b[1;3m\n",
|
||||||
|
"So the final answer is:\n",
|
||||||
|
"\n",
|
||||||
|
"Carlos Alcaraz\u001b[0m\n",
|
||||||
|
"\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"model_lab.compare(\"What is the hometown of the reigning men's U.S. Open champion?\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "94159131",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
"source": []
|
"source": []
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
|
@ -27,6 +27,8 @@ class VectorDBQA(Chain, BaseModel):
|
|||||||
"""LLM wrapper to use."""
|
"""LLM wrapper to use."""
|
||||||
vectorstore: VectorStore
|
vectorstore: VectorStore
|
||||||
"""Vector Database to connect to."""
|
"""Vector Database to connect to."""
|
||||||
|
k: int = 4
|
||||||
|
"""Number of documents to query for."""
|
||||||
input_key: str = "query" #: :meta private:
|
input_key: str = "query" #: :meta private:
|
||||||
output_key: str = "result" #: :meta private:
|
output_key: str = "result" #: :meta private:
|
||||||
|
|
||||||
@ -55,7 +57,7 @@ class VectorDBQA(Chain, BaseModel):
|
|||||||
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
|
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
|
||||||
question = inputs[self.input_key]
|
question = inputs[self.input_key]
|
||||||
llm_chain = LLMChain(llm=self.llm, prompt=prompt)
|
llm_chain = LLMChain(llm=self.llm, prompt=prompt)
|
||||||
docs = self.vectorstore.similarity_search(question)
|
docs = self.vectorstore.similarity_search(question, k=self.k)
|
||||||
contexts = []
|
contexts = []
|
||||||
for j, doc in enumerate(docs):
|
for j, doc in enumerate(docs):
|
||||||
contexts.append(f"Context {j}:\n{doc.page_content}")
|
contexts.append(f"Context {j}:\n{doc.page_content}")
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
"""Experiment with different models."""
|
"""Experiment with different models."""
|
||||||
from typing import List, Optional
|
from typing import List, Optional, Sequence
|
||||||
|
|
||||||
|
from langchain.chains.base import Chain
|
||||||
from langchain.chains.llm import LLMChain
|
from langchain.chains.llm import LLMChain
|
||||||
from langchain.input import get_color_mapping, print_text
|
from langchain.input import get_color_mapping, print_text
|
||||||
from langchain.llms.base import LLM
|
from langchain.llms.base import LLM
|
||||||
@ -10,7 +11,41 @@ from langchain.prompts.prompt import Prompt
|
|||||||
class ModelLaboratory:
|
class ModelLaboratory:
|
||||||
"""Experiment with different models."""
|
"""Experiment with different models."""
|
||||||
|
|
||||||
def __init__(self, llms: List[LLM], prompt: Optional[Prompt] = None):
|
def __init__(self, chains: Sequence[Chain], names: Optional[List[str]] = None):
|
||||||
|
"""Initialize with chains to experiment with.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
chains: list of chains to experiment with.
|
||||||
|
"""
|
||||||
|
if not isinstance(chains[0], Chain):
|
||||||
|
raise ValueError(
|
||||||
|
"ModelLaboratory should now be initialized with Chains. "
|
||||||
|
"If you want to initialize with LLMs, use the `from_llms` method "
|
||||||
|
"instead (`ModelLaboratory.from_llms(...)`)"
|
||||||
|
)
|
||||||
|
for chain in chains:
|
||||||
|
if len(chain.input_keys) != 1:
|
||||||
|
raise ValueError(
|
||||||
|
"Currently only support chains with one input variable, "
|
||||||
|
f"got {chain.input_keys}"
|
||||||
|
)
|
||||||
|
if len(chain.output_keys) != 1:
|
||||||
|
raise ValueError(
|
||||||
|
"Currently only support chains with one output variable, "
|
||||||
|
f"got {chain.output_keys}"
|
||||||
|
)
|
||||||
|
if names is not None:
|
||||||
|
if len(names) != len(chains):
|
||||||
|
raise ValueError("Length of chains does not match length of names.")
|
||||||
|
self.chains = chains
|
||||||
|
chain_range = [str(i) for i in range(len(self.chains))]
|
||||||
|
self.chain_colors = get_color_mapping(chain_range)
|
||||||
|
self.names = names
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_llms(
|
||||||
|
cls, llms: List[LLM], prompt: Optional[Prompt] = None
|
||||||
|
) -> "ModelLaboratory":
|
||||||
"""Initialize with LLMs to experiment with and optional prompt.
|
"""Initialize with LLMs to experiment with and optional prompt.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -18,18 +53,11 @@ class ModelLaboratory:
|
|||||||
prompt: Optional prompt to use to prompt the LLMs. Defaults to None.
|
prompt: Optional prompt to use to prompt the LLMs. Defaults to None.
|
||||||
If a prompt was provided, it should only have one input variable.
|
If a prompt was provided, it should only have one input variable.
|
||||||
"""
|
"""
|
||||||
self.llms = llms
|
|
||||||
llm_range = [str(i) for i in range(len(self.llms))]
|
|
||||||
self.llm_colors = get_color_mapping(llm_range)
|
|
||||||
if prompt is None:
|
if prompt is None:
|
||||||
self.prompt = Prompt(input_variables=["_input"], template="{_input}")
|
prompt = Prompt(input_variables=["_input"], template="{_input}")
|
||||||
else:
|
chains = [LLMChain(llm=llm, prompt=prompt) for llm in llms]
|
||||||
if len(prompt.input_variables) != 1:
|
names = [str(llm) for llm in llms]
|
||||||
raise ValueError(
|
return cls(chains, names=names)
|
||||||
"Currently only support prompts with one input variable, "
|
|
||||||
f"got {prompt}"
|
|
||||||
)
|
|
||||||
self.prompt = prompt
|
|
||||||
|
|
||||||
def compare(self, text: str) -> None:
|
def compare(self, text: str) -> None:
|
||||||
"""Compare model outputs on an input text.
|
"""Compare model outputs on an input text.
|
||||||
@ -42,9 +70,11 @@ class ModelLaboratory:
|
|||||||
text: input text to run all models on.
|
text: input text to run all models on.
|
||||||
"""
|
"""
|
||||||
print(f"\033[1mInput:\033[0m\n{text}\n")
|
print(f"\033[1mInput:\033[0m\n{text}\n")
|
||||||
for i, llm in enumerate(self.llms):
|
for i, chain in enumerate(self.chains):
|
||||||
print_text(str(llm), end="\n")
|
if self.names is not None:
|
||||||
chain = LLMChain(llm=llm, prompt=self.prompt)
|
name = self.names[i]
|
||||||
llm_inputs = {self.prompt.input_variables[0]: text}
|
else:
|
||||||
output = chain.predict(**llm_inputs)
|
name = str(chain)
|
||||||
print_text(output, color=self.llm_colors[str(i)], end="\n\n")
|
print_text(name, end="\n")
|
||||||
|
output = chain.run(text)
|
||||||
|
print_text(output, color=self.chain_colors[str(i)], end="\n\n")
|
||||||
|
Loading…
Reference in New Issue
Block a user