From b15c84e19d31722693f6388b1b59418e88d32d3d Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Fri, 18 Nov 2022 05:50:02 -0800 Subject: [PATCH] Harrison/chain lab (#156) --- docs/examples/model_laboratory.ipynb | 111 +++++++++++++++++++++++--- langchain/chains/vector_db_qa/base.py | 4 +- langchain/model_laboratory.py | 68 +++++++++++----- 3 files changed, 154 insertions(+), 29 deletions(-) diff --git a/docs/examples/model_laboratory.ipynb b/docs/examples/model_laboratory.ipynb index 0646386e..8c5af92f 100644 --- a/docs/examples/model_laboratory.ipynb +++ b/docs/examples/model_laboratory.ipynb @@ -42,7 +42,7 @@ "metadata": {}, "outputs": [], "source": [ - "model_lab = ModelLaboratory(llms)" + "model_lab = ModelLaboratory.from_llms(llms)" ] }, { @@ -60,19 +60,19 @@ "\n", "\u001b[1mOpenAI\u001b[0m\n", "Params: {'model': 'text-davinci-002', 'temperature': 0.0, 'max_tokens': 256, 'top_p': 1, 'frequency_penalty': 0, 'presence_penalty': 0, 'n': 1, 'best_of': 1}\n", - "\u001b[104m\n", + "\u001b[36;1m\u001b[1;3m\n", "\n", "Flamingos are pink.\u001b[0m\n", "\n", "\u001b[1mCohere\u001b[0m\n", "Params: {'model': 'command-xlarge-20221108', 'max_tokens': 20, 'temperature': 0.0, 'k': 0, 'p': 1, 'frequency_penalty': 0, 'presence_penalty': 0}\n", - "\u001b[103m\n", + "\u001b[33;1m\u001b[1;3m\n", "\n", "Pink\u001b[0m\n", "\n", "\u001b[1mHuggingFaceHub\u001b[0m\n", "Params: {'repo_id': 'google/flan-t5-xl', 'temperature': 1}\n", - "\u001b[101mpink\u001b[0m\n", + "\u001b[38;5;200m\u001b[1;3mpink\u001b[0m\n", "\n" ] } @@ -89,7 +89,7 @@ "outputs": [], "source": [ "prompt = Prompt(template=\"What is the capital of {state}?\", input_variables=[\"state\"])\n", - "model_lab_with_prompt = ModelLaboratory(llms, prompt=prompt)" + "model_lab_with_prompt = ModelLaboratory.from_llms(llms, prompt=prompt)" ] }, { @@ -107,19 +107,19 @@ "\n", "\u001b[1mOpenAI\u001b[0m\n", "Params: {'model': 'text-davinci-002', 'temperature': 0.0, 'max_tokens': 256, 'top_p': 1, 'frequency_penalty': 0, 'presence_penalty': 0, 'n': 1, 'best_of': 1}\n", - "\u001b[104m\n", + "\u001b[36;1m\u001b[1;3m\n", "\n", "The capital of New York is Albany.\u001b[0m\n", "\n", "\u001b[1mCohere\u001b[0m\n", "Params: {'model': 'command-xlarge-20221108', 'max_tokens': 20, 'temperature': 0.0, 'k': 0, 'p': 1, 'frequency_penalty': 0, 'presence_penalty': 0}\n", - "\u001b[103m\n", + "\u001b[33;1m\u001b[1;3m\n", "\n", "The capital of New York is Albany.\u001b[0m\n", "\n", "\u001b[1mHuggingFaceHub\u001b[0m\n", "Params: {'repo_id': 'google/flan-t5-xl', 'temperature': 1}\n", - "\u001b[101mst john s\u001b[0m\n", + "\u001b[38;5;200m\u001b[1;3mst john s\u001b[0m\n", "\n" ] } @@ -130,10 +130,103 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 7, "id": "54336dbf", "metadata": {}, "outputs": [], + "source": [ + "from langchain import SelfAskWithSearchChain, SerpAPIChain\n", + "\n", + "open_ai_llm = OpenAI(temperature=0)\n", + "search = SerpAPIChain()\n", + "self_ask_with_search_openai = SelfAskWithSearchChain(llm=open_ai_llm, search_chain=search, verbose=True)\n", + "\n", + "cohere_llm = Cohere(temperature=0, model=\"command-xlarge-20221108\")\n", + "search = SerpAPIChain()\n", + "self_ask_with_search_cohere = SelfAskWithSearchChain(llm=cohere_llm, search_chain=search, verbose=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "6a50a9f1", + "metadata": {}, + "outputs": [], + "source": [ + "chains = [self_ask_with_search_openai, self_ask_with_search_cohere]\n", + "names = [str(open_ai_llm), str(cohere_llm)]" + ] + }, + { + 
"cell_type": "code", + "execution_count": 9, + "id": "d3549e99", + "metadata": {}, + "outputs": [], + "source": [ + "model_lab = ModelLaboratory(chains, names=names)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "362f7f57", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[1mInput:\u001b[0m\n", + "What is the hometown of the reigning men's U.S. Open champion?\n", + "\n", + "\u001b[1mOpenAI\u001b[0m\n", + "Params: {'model': 'text-davinci-002', 'temperature': 0.0, 'max_tokens': 256, 'top_p': 1, 'frequency_penalty': 0, 'presence_penalty': 0, 'n': 1, 'best_of': 1}\n", + "\n", + "\n", + "\u001b[1m> Entering new chain...\u001b[0m\n", + "What is the hometown of the reigning men's U.S. Open champion?\n", + "Are follow up questions needed here:\u001b[32;1m\u001b[1;3m Yes.\n", + "Follow up: Who is the reigning men's U.S. Open champion?\u001b[0m\n", + "Intermediate answer: \u001b[33;1m\u001b[1;3mCarlos Alcaraz.\u001b[0m\u001b[32;1m\u001b[1;3m\n", + "Follow up: Where is Carlos Alcaraz from?\u001b[0m\n", + "Intermediate answer: \u001b[33;1m\u001b[1;3mEl Palmar, Spain.\u001b[0m\u001b[32;1m\u001b[1;3m\n", + "So the final answer is: El Palmar, Spain\u001b[0m\n", + "\u001b[1m> Finished chain.\u001b[0m\n", + "\u001b[36;1m\u001b[1;3m\n", + "So the final answer is: El Palmar, Spain\u001b[0m\n", + "\n", + "\u001b[1mCohere\u001b[0m\n", + "Params: {'model': 'command-xlarge-20221108', 'max_tokens': 256, 'temperature': 0.0, 'k': 0, 'p': 1, 'frequency_penalty': 0, 'presence_penalty': 0}\n", + "\n", + "\n", + "\u001b[1m> Entering new chain...\u001b[0m\n", + "What is the hometown of the reigning men's U.S. Open champion?\n", + "Are follow up questions needed here:\u001b[32;1m\u001b[1;3m Yes.\n", + "Follow up: Who is the reigning men's U.S. Open champion?\u001b[0m\n", + "Intermediate answer: \u001b[33;1m\u001b[1;3mCarlos Alcaraz.\u001b[0m\u001b[32;1m\u001b[1;3m\n", + "So the final answer is:\n", + "\n", + "Carlos Alcaraz\u001b[0m\n", + "\u001b[1m> Finished chain.\u001b[0m\n", + "\u001b[33;1m\u001b[1;3m\n", + "So the final answer is:\n", + "\n", + "Carlos Alcaraz\u001b[0m\n", + "\n" + ] + } + ], + "source": [ + "model_lab.compare(\"What is the hometown of the reigning men's U.S. 
Open champion?\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "94159131",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
diff --git a/langchain/chains/vector_db_qa/base.py b/langchain/chains/vector_db_qa/base.py
index 3e010710..d54de11c 100644
--- a/langchain/chains/vector_db_qa/base.py
+++ b/langchain/chains/vector_db_qa/base.py
@@ -27,6 +27,8 @@ class VectorDBQA(Chain, BaseModel):
     """LLM wrapper to use."""
     vectorstore: VectorStore
     """Vector Database to connect to."""
+    k: int = 4
+    """Number of documents to retrieve per query and use as context."""
     input_key: str = "query"  #: :meta private:
     output_key: str = "result"  #: :meta private:
 
@@ -55,7 +57,7 @@ class VectorDBQA(Chain, BaseModel):
     def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
         question = inputs[self.input_key]
         llm_chain = LLMChain(llm=self.llm, prompt=prompt)
-        docs = self.vectorstore.similarity_search(question)
+        docs = self.vectorstore.similarity_search(question, k=self.k)
         contexts = []
         for j, doc in enumerate(docs):
             contexts.append(f"Context {j}:\n{doc.page_content}")
diff --git a/langchain/model_laboratory.py b/langchain/model_laboratory.py
index 0243f70e..d61265c0 100644
--- a/langchain/model_laboratory.py
+++ b/langchain/model_laboratory.py
@@ -1,6 +1,7 @@
 """Experiment with different models."""
-from typing import List, Optional
+from typing import List, Optional, Sequence
 
+from langchain.chains.base import Chain
 from langchain.chains.llm import LLMChain
 from langchain.input import get_color_mapping, print_text
 from langchain.llms.base import LLM
@@ -10,7 +11,41 @@ from langchain.prompts.prompt import Prompt
 class ModelLaboratory:
     """Experiment with different models."""
 
-    def __init__(self, llms: List[LLM], prompt: Optional[Prompt] = None):
+    def __init__(self, chains: Sequence[Chain], names: Optional[List[str]] = None):
+        """Initialize with chains to experiment with.
+        Args:
+            chains: list of chains to experiment with.
+            names: optional list of names for the chains, used when printing results.
+        """
+        if not isinstance(chains[0], Chain):
+            raise ValueError(
+                "ModelLaboratory should now be initialized with Chains. "
+                "If you want to initialize with LLMs, use the `from_llms` method "
+                "instead (`ModelLaboratory.from_llms(...)`)"
+            )
+        for chain in chains:
+            if len(chain.input_keys) != 1:
+                raise ValueError(
+                    "Currently only supports chains with one input variable, "
+                    f"got {chain.input_keys}"
+                )
+            if len(chain.output_keys) != 1:
+                raise ValueError(
+                    "Currently only supports chains with one output variable, "
+                    f"got {chain.output_keys}"
+                )
+        if names is not None:
+            if len(names) != len(chains):
+                raise ValueError("Length of chains does not match length of names.")
+        self.chains = chains
+        chain_range = [str(i) for i in range(len(self.chains))]
+        self.chain_colors = get_color_mapping(chain_range)
+        self.names = names
+
+    @classmethod
+    def from_llms(
+        cls, llms: List[LLM], prompt: Optional[Prompt] = None
+    ) -> "ModelLaboratory":
         """Initialize with LLMs to experiment with and optional prompt.
 
         Args:
@@ -18,18 +53,11 @@
             prompt: Optional prompt to use to prompt the LLMs. Defaults to None.
                 If a prompt was provided, it should only have one input variable.
""" - self.llms = llms - llm_range = [str(i) for i in range(len(self.llms))] - self.llm_colors = get_color_mapping(llm_range) if prompt is None: - self.prompt = Prompt(input_variables=["_input"], template="{_input}") - else: - if len(prompt.input_variables) != 1: - raise ValueError( - "Currently only support prompts with one input variable, " - f"got {prompt}" - ) - self.prompt = prompt + prompt = Prompt(input_variables=["_input"], template="{_input}") + chains = [LLMChain(llm=llm, prompt=prompt) for llm in llms] + names = [str(llm) for llm in llms] + return cls(chains, names=names) def compare(self, text: str) -> None: """Compare model outputs on an input text. @@ -42,9 +70,11 @@ class ModelLaboratory: text: input text to run all models on. """ print(f"\033[1mInput:\033[0m\n{text}\n") - for i, llm in enumerate(self.llms): - print_text(str(llm), end="\n") - chain = LLMChain(llm=llm, prompt=self.prompt) - llm_inputs = {self.prompt.input_variables[0]: text} - output = chain.predict(**llm_inputs) - print_text(output, color=self.llm_colors[str(i)], end="\n\n") + for i, chain in enumerate(self.chains): + if self.names is not None: + name = self.names[i] + else: + name = str(chain) + print_text(name, end="\n") + output = chain.run(text) + print_text(output, color=self.chain_colors[str(i)], end="\n\n")