From db37bd089fc18c8215da42202dfadc397b20d26c Mon Sep 17 00:00:00 2001
From: Harrison Chase
Date: Tue, 8 Nov 2022 22:17:10 -0800
Subject: [PATCH] model laboratory (#95)

---
 examples/model_laboratory.ipynb        | 147 +++++++++++++++++++++++++
 langchain/input.py                     |   6 +-
 langchain/llms/base.py                 |  12 +-
 langchain/llms/cohere.py               |  31 ++--
 langchain/llms/huggingface_hub.py      |  13 ++-
 langchain/llms/nlpcloud.py             |   7 +-
 langchain/llms/openai.py               |   7 +-
 langchain/model_laboratory.py          |  50 +++++++++
 tests/unit_tests/chains/test_natbot.py |   6 +-
 tests/unit_tests/chains/test_react.py  |   6 +-
 tests/unit_tests/llms/fake_llm.py      |   6 +-
 11 files changed, 268 insertions(+), 23 deletions(-)
 create mode 100644 examples/model_laboratory.ipynb
 create mode 100644 langchain/model_laboratory.py

diff --git a/examples/model_laboratory.ipynb b/examples/model_laboratory.ipynb
new file mode 100644
index 0000000000..b648ca49dc
--- /dev/null
+++ b/examples/model_laboratory.ipynb
@@ -0,0 +1,147 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "id": "ab9e95ad",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from langchain import LLMChain, OpenAI, Cohere, HuggingFaceHub, Prompt\n",
+    "from langchain.model_laboratory import ModelLaboratory"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "id": "32cb94e6",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "llms = [OpenAI(temperature=0), Cohere(model=\"command-xlarge-20221108\", max_tokens=20, temperature=0), HuggingFaceHub(repo_id=\"google/flan-t5-xl\", model_kwargs={\"temperature\":1})]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "id": "14cde09d",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "model_lab = ModelLaboratory(llms)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "id": "f186c741",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "\u001b[1mInput:\u001b[0m\n",
+      "What color is a flamingo?\n",
+      "\n",
+      "\u001b[1mOpenAI\u001b[0m\n",
+      "Params: {'model': 'text-davinci-002', 'temperature': 0.0, 'max_tokens': 256, 'top_p': 1, 'frequency_penalty': 0, 'presence_penalty': 0, 'n': 1, 'best_of': 1}\n",
+      "\u001b[104m\n",
+      "\n",
+      "Flamingos are pink.\u001b[0m\n",
+      "\n",
+      "\u001b[1mCohere\u001b[0m\n",
+      "Params: {'model': 'command-xlarge-20221108', 'max_tokens': 20, 'temperature': 0.0, 'k': 0, 'p': 1, 'frequency_penalty': 0, 'presence_penalty': 0}\n",
+      "\u001b[103m\n",
+      "\n",
+      "Pink\u001b[0m\n",
+      "\n",
+      "\u001b[1mHuggingFaceHub\u001b[0m\n",
+      "Params: {'repo_id': 'google/flan-t5-xl', 'temperature': 1}\n",
+      "\u001b[101mpink\u001b[0m\n",
+      "\n"
+     ]
+    }
+   ],
+   "source": [
+    "model_lab.compare(\"What color is a flamingo?\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "id": "248b652a",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "prompt = Prompt(template=\"What is the capital of {state}?\", input_variables=[\"state\"])\n",
+    "model_lab_with_prompt = ModelLaboratory(llms, prompt=prompt)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "id": "f64377ac",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "\u001b[1mInput:\u001b[0m\n",
+      "New York\n",
+      "\n",
+      "\u001b[1mOpenAI\u001b[0m\n",
+      "Params: {'model': 'text-davinci-002', 'temperature': 0.0, 'max_tokens': 256, 'top_p': 1, 'frequency_penalty': 0, 'presence_penalty': 0, 'n': 1, 'best_of': 1}\n",
+      "\u001b[104m\n",
+      "\n",
+      "The capital of New York is Albany.\u001b[0m\n",
+      "\n",
+      "\u001b[1mCohere\u001b[0m\n",
"Params: {'model': 'command-xlarge-20221108', 'max_tokens': 20, 'temperature': 0.0, 'k': 0, 'p': 1, 'frequency_penalty': 0, 'presence_penalty': 0}\n", + "\u001b[103m\n", + "\n", + "The capital of New York is Albany.\u001b[0m\n", + "\n", + "\u001b[1mHuggingFaceHub\u001b[0m\n", + "Params: {'repo_id': 'google/flan-t5-xl', 'temperature': 1}\n", + "\u001b[101mst john s\u001b[0m\n", + "\n" + ] + } + ], + "source": [ + "model_lab_with_prompt.compare(\"New York\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "54336dbf", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/langchain/input.py b/langchain/input.py index 0d1b3b6daf..94fad90828 100644 --- a/langchain/input.py +++ b/langchain/input.py @@ -15,13 +15,13 @@ def get_color_mapping( return color_mapping -def print_text(text: str, color: Optional[str] = None) -> None: +def print_text(text: str, color: Optional[str] = None, end: str = "") -> None: """Print text with highlighting and no end characters.""" if color is None: - print(text, end="") + print(text, end=end) else: color_str = _COLOR_MAPPING[color] - print(f"\x1b[{color_str}m{text}\x1b[0m", end="") + print(f"\x1b[{color_str}m{text}\x1b[0m", end=end) class ChainedInput: diff --git a/langchain/llms/base.py b/langchain/llms/base.py index 56382efde3..a45ae6069d 100644 --- a/langchain/llms/base.py +++ b/langchain/llms/base.py @@ -1,6 +1,6 @@ """Base interface for large language models to expose.""" from abc import ABC, abstractmethod -from typing import List, Optional +from typing import Any, List, Mapping, Optional class LLM(ABC): @@ -9,3 +9,13 @@ class LLM(ABC): @abstractmethod def __call__(self, prompt: str, stop: Optional[List[str]] = None) -> str: """Run the LLM on the given prompt and input.""" + + @property + @abstractmethod + def _identifying_params(self) -> Mapping[str, Any]: + """Get the identifying parameters.""" + + def __str__(self) -> str: + """Get a string representation of the object for printing.""" + cls_name = f"\033[1m{self.__class__.__name__}\033[0m" + return f"{cls_name}\nParams: {self._identifying_params}" diff --git a/langchain/llms/cohere.py b/langchain/llms/cohere.py index 7d456536e7..2a41b807ea 100644 --- a/langchain/llms/cohere.py +++ b/langchain/llms/cohere.py @@ -1,6 +1,6 @@ """Wrapper around Cohere APIs.""" import os -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Mapping, Optional from pydantic import BaseModel, Extra, root_validator @@ -8,7 +8,7 @@ from langchain.llms.base import LLM from langchain.llms.utils import enforce_stop_tokens -class Cohere(BaseModel, LLM): +class Cohere(LLM, BaseModel): """Wrapper around Cohere large language models. 
     To use, you should have the ``cohere`` python package installed, and the
@@ -73,6 +73,23 @@ class Cohere(BaseModel, LLM):
         )
         return values
 
+    @property
+    def _default_params(self) -> Mapping[str, Any]:
+        """Get the default parameters for calling the Cohere API."""
+        return {
+            "max_tokens": self.max_tokens,
+            "temperature": self.temperature,
+            "k": self.k,
+            "p": self.p,
+            "frequency_penalty": self.frequency_penalty,
+            "presence_penalty": self.presence_penalty,
+        }
+
+    @property
+    def _identifying_params(self) -> Mapping[str, Any]:
+        """Get the identifying parameters."""
+        return {**{"model": self.model}, **self._default_params}
+
     def __call__(self, prompt: str, stop: Optional[List[str]] = None) -> str:
         """Call out to Cohere's generate endpoint.
 
@@ -89,15 +106,7 @@ class Cohere(BaseModel, LLM):
                 response = cohere("Tell me a joke.")
         """
         response = self.client.generate(
-            model=self.model,
-            prompt=prompt,
-            max_tokens=self.max_tokens,
-            temperature=self.temperature,
-            k=self.k,
-            p=self.p,
-            frequency_penalty=self.frequency_penalty,
-            presence_penalty=self.presence_penalty,
-            stop_sequences=stop,
+            model=self.model, prompt=prompt, stop_sequences=stop, **self._default_params
         )
         text = response.generations[0].text
         # If stop tokens are provided, Cohere's endpoint returns them.
diff --git a/langchain/llms/huggingface_hub.py b/langchain/llms/huggingface_hub.py
index fc83586eef..8d584558ae 100644
--- a/langchain/llms/huggingface_hub.py
+++ b/langchain/llms/huggingface_hub.py
@@ -1,6 +1,6 @@
 """Wrapper around HuggingFace APIs."""
 import os
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Mapping, Optional
 
 from pydantic import BaseModel, Extra, root_validator
 
@@ -11,7 +11,7 @@ DEFAULT_REPO_ID = "gpt2"
 VALID_TASKS = ("text2text-generation", "text-generation")
 
 
-class HuggingFaceHub(BaseModel, LLM):
+class HuggingFaceHub(LLM, BaseModel):
     """Wrapper around HuggingFaceHub models.
 
     To use, you should have the ``huggingface_hub`` python package installed, and the
@@ -74,6 +74,12 @@ class HuggingFaceHub(BaseModel, LLM):
         )
         return values
 
+    @property
+    def _identifying_params(self) -> Mapping[str, Any]:
+        """Get the identifying parameters."""
+        _model_kwargs = self.model_kwargs or {}
+        return {**{"repo_id": self.repo_id}, **_model_kwargs}
+
     def __call__(self, prompt: str, stop: Optional[List[str]] = None) -> str:
         """Call out to HuggingFace Hub's inference endpoint.
 
@@ -89,7 +95,8 @@ class HuggingFaceHub(BaseModel, LLM):
 
                 response = hf("Tell me a joke.")
         """
-        response = self.client(inputs=prompt, params=self.model_kwargs)
+        _model_kwargs = self.model_kwargs or {}
+        response = self.client(inputs=prompt, params=_model_kwargs)
         if "error" in response:
             raise ValueError(f"Error raised by inference API: {response['error']}")
         if self.client.task == "text-generation":
diff --git a/langchain/llms/nlpcloud.py b/langchain/llms/nlpcloud.py
index 9d28bb2407..6a8bced270 100644
--- a/langchain/llms/nlpcloud.py
+++ b/langchain/llms/nlpcloud.py
@@ -7,7 +7,7 @@ from pydantic import BaseModel, Extra, root_validator
 from langchain.llms.base import LLM
 
 
-class NLPCloud(BaseModel, LLM):
+class NLPCloud(LLM, BaseModel):
     """Wrapper around NLPCloud large language models.
 
     To use, you should have the ``nlpcloud`` python package installed, and the
@@ -106,6 +106,11 @@ class NLPCloud(BaseModel, LLM):
             "num_return_sequences": self.num_return_sequences,
         }
 
+    @property
+    def _identifying_params(self) -> Mapping[str, Any]:
+        """Get the identifying parameters."""
+        return {**{"model_name": self.model_name}, **self._default_params}
+
     def __call__(self, prompt: str, stop: Optional[List[str]] = None) -> str:
         """Call out to NLPCloud's create endpoint.
 
diff --git a/langchain/llms/openai.py b/langchain/llms/openai.py
index ad86c22fff..2355015b64 100644
--- a/langchain/llms/openai.py
+++ b/langchain/llms/openai.py
@@ -7,7 +7,7 @@ from pydantic import BaseModel, Extra, root_validator
 from langchain.llms.base import LLM
 
 
-class OpenAI(BaseModel, LLM):
+class OpenAI(LLM, BaseModel):
     """Wrapper around OpenAI large language models.
 
     To use, you should have the ``openai`` python package installed, and the
@@ -81,6 +81,11 @@ class OpenAI(BaseModel, LLM):
             "best_of": self.best_of,
         }
 
+    @property
+    def _identifying_params(self) -> Mapping[str, Any]:
+        """Get the identifying parameters."""
+        return {**{"model": self.model_name}, **self._default_params}
+
     def __call__(self, prompt: str, stop: Optional[List[str]] = None) -> str:
         """Call out to OpenAI's create endpoint.
 
diff --git a/langchain/model_laboratory.py b/langchain/model_laboratory.py
new file mode 100644
index 0000000000..0243f70e88
--- /dev/null
+++ b/langchain/model_laboratory.py
@@ -0,0 +1,50 @@
+"""Experiment with different models."""
+from typing import List, Optional
+
+from langchain.chains.llm import LLMChain
+from langchain.input import get_color_mapping, print_text
+from langchain.llms.base import LLM
+from langchain.prompts.prompt import Prompt
+
+
+class ModelLaboratory:
+    """Experiment with different models."""
+
+    def __init__(self, llms: List[LLM], prompt: Optional[Prompt] = None):
+        """Initialize with LLMs to experiment with and optional prompt.
+
+        Args:
+            llms: list of LLMs to experiment with.
+            prompt: Optional prompt to use to prompt the LLMs. Defaults to None.
+                If a prompt is provided, it should have exactly one input variable.
+        """
+        self.llms = llms
+        llm_range = [str(i) for i in range(len(self.llms))]
+        self.llm_colors = get_color_mapping(llm_range)
+        if prompt is None:
+            self.prompt = Prompt(input_variables=["_input"], template="{_input}")
+        else:
+            if len(prompt.input_variables) != 1:
+                raise ValueError(
+                    "Currently only supports prompts with one input variable, "
+                    f"got {prompt}"
+                )
+            self.prompt = prompt
+
+    def compare(self, text: str) -> None:
+        """Compare model outputs on an input text.
+
+        If a prompt was provided when starting the laboratory, then this text will be
+        fed into the prompt. If no prompt was provided, then the input text is the
+        entire prompt.
+
+        Args:
+            text: input text to run all models on.
+ """ + print(f"\033[1mInput:\033[0m\n{text}\n") + for i, llm in enumerate(self.llms): + print_text(str(llm), end="\n") + chain = LLMChain(llm=llm, prompt=self.prompt) + llm_inputs = {self.prompt.input_variables[0]: text} + output = chain.predict(**llm_inputs) + print_text(output, color=self.llm_colors[str(i)], end="\n\n") diff --git a/tests/unit_tests/chains/test_natbot.py b/tests/unit_tests/chains/test_natbot.py index d2701f8b54..622237ae31 100644 --- a/tests/unit_tests/chains/test_natbot.py +++ b/tests/unit_tests/chains/test_natbot.py @@ -1,6 +1,6 @@ """Test functionality related to natbot.""" -from typing import List, Optional +from typing import Any, List, Mapping, Optional from langchain.chains.natbot.base import NatBotChain from langchain.llms.base import LLM @@ -16,6 +16,10 @@ class FakeLLM(LLM): else: return "bar" + @property + def _identifying_params(self) -> Mapping[str, Any]: + return {} + def test_proper_inputs() -> None: """Test that natbot shortens inputs correctly.""" diff --git a/tests/unit_tests/chains/test_react.py b/tests/unit_tests/chains/test_react.py index e5c22dd4be..4490be08b0 100644 --- a/tests/unit_tests/chains/test_react.py +++ b/tests/unit_tests/chains/test_react.py @@ -1,6 +1,6 @@ """Unit tests for ReAct.""" -from typing import List, Optional, Union +from typing import Any, List, Mapping, Optional, Union import pytest @@ -35,6 +35,10 @@ class FakeListLLM(LLM): self.i += 1 return self.responses[self.i] + @property + def _identifying_params(self) -> Mapping[str, Any]: + return {} + class FakeDocstore(Docstore): """Fake docstore for testing purposes.""" diff --git a/tests/unit_tests/llms/fake_llm.py b/tests/unit_tests/llms/fake_llm.py index f9d6387b04..65fd8c7936 100644 --- a/tests/unit_tests/llms/fake_llm.py +++ b/tests/unit_tests/llms/fake_llm.py @@ -1,5 +1,5 @@ """Fake LLM wrapper for testing purposes.""" -from typing import List, Mapping, Optional +from typing import Any, List, Mapping, Optional from langchain.llms.base import LLM @@ -19,3 +19,7 @@ class FakeLLM(LLM): return "foo" else: return "bar" + + @property + def _identifying_params(self) -> Mapping[str, Any]: + return {}