From e7e5cb9d089d9e74484ad95cbc324b1c5d9b860f Mon Sep 17 00:00:00 2001 From: Vadim Gubergrits Date: Thu, 27 Jul 2023 00:29:39 -0400 Subject: [PATCH] Tree of Thought introducing a new ToTChain. (#5167) # [WIP] Tree of Thought introducing a new ToTChain. This PR adds a new chain called ToTChain that implements the ["Large Language Model Guided Tree-of-Though"](https://arxiv.org/pdf/2305.08291.pdf) paper. There's a notebook example `docs/modules/chains/examples/tot.ipynb` that shows how to use it. Implements #4975 ## Who can review? Community members can review the PR once tests pass. Tag maintainers/contributors who might be interested: - @hwchase17 - @vowelparrot --------- Co-authored-by: Vadim Gubergrits Co-authored-by: Harrison Chase --- .../modules/chains/additional/tot.ipynb | 239 ++++++++++++++++++ .../langchain_experimental/tot/__init__.py | 4 + .../langchain_experimental/tot/base.py | 155 ++++++++++++ .../langchain_experimental/tot/checker.py | 52 ++++ .../langchain_experimental/tot/controller.py | 54 ++++ .../langchain_experimental/tot/memory.py | 46 ++++ .../langchain_experimental/tot/prompts.py | 137 ++++++++++ .../langchain_experimental/tot/thought.py | 21 ++ .../tot/thought_generation.py | 94 +++++++ .../experimental/tests/unit_tests/fake_llm.py | 61 +++++ .../experimental/tests/unit_tests/test_tot.py | 151 +++++++++++ 11 files changed, 1014 insertions(+) create mode 100644 docs/extras/modules/chains/additional/tot.ipynb create mode 100644 libs/experimental/langchain_experimental/tot/__init__.py create mode 100644 libs/experimental/langchain_experimental/tot/base.py create mode 100644 libs/experimental/langchain_experimental/tot/checker.py create mode 100644 libs/experimental/langchain_experimental/tot/controller.py create mode 100644 libs/experimental/langchain_experimental/tot/memory.py create mode 100644 libs/experimental/langchain_experimental/tot/prompts.py create mode 100644 libs/experimental/langchain_experimental/tot/thought.py create mode 100644 libs/experimental/langchain_experimental/tot/thought_generation.py create mode 100644 libs/experimental/tests/unit_tests/fake_llm.py create mode 100644 libs/experimental/tests/unit_tests/test_tot.py diff --git a/docs/extras/modules/chains/additional/tot.ipynb b/docs/extras/modules/chains/additional/tot.ipynb new file mode 100644 index 0000000000..03f385514c --- /dev/null +++ b/docs/extras/modules/chains/additional/tot.ipynb @@ -0,0 +1,239 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Tree of Thouht (ToT) example\n", + "\n", + "The Tree of Thought (ToT) is a chain that allows you to query a Large Language Model (LLM) using the Tree of Thought technique. This is based on the papaer [\"Large Language Model Guided Tree-of-Thought\"](https://arxiv.org/pdf/2305.08291.pdf)" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/harrisonchase/.pyenv/versions/3.9.1/envs/langchain/lib/python3.9/site-packages/deeplake/util/check_latest_version.py:32: UserWarning: A newer version of deeplake (3.6.13) is available. It's recommended that you update to the latest version using `pip install -U deeplake`.\n", + " warnings.warn(\n" + ] + } + ], + "source": [ + "from langchain.llms import OpenAI\n", + "\n", + "llm = OpenAI(temperature=1, max_tokens=512, model=\"text-davinci-003\")" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "3,*,*,2|1,*,3,*|*,1,*,3|4,*,*,1\n", + "\n", + "- This is a 4x4 Sudoku puzzle.\n", + "- The * represents a cell to be filled.\n", + "- The | character separates rows.\n", + "- At each step, replace one or more * with digits 1-4.\n", + "- There must be no duplicate digits in any row, column or 2x2 subgrid.\n", + "- Keep the known digits from previous valid thoughts in place.\n", + "- Each thought can be a partial or the final solution.\n" + ] + } + ], + "source": [ + "sudoku_puzzle = \"3,*,*,2|1,*,3,*|*,1,*,3|4,*,*,1\"\n", + "sudoku_solution = \"3,4,1,2|1,2,3,4|2,1,4,3|4,3,2,1\"\n", + "problem_description = f\"\"\"\n", + "{sudoku_puzzle}\n", + "\n", + "- This is a 4x4 Sudoku puzzle.\n", + "- The * represents a cell to be filled.\n", + "- The | character separates rows.\n", + "- At each step, replace one or more * with digits 1-4.\n", + "- There must be no duplicate digits in any row, column or 2x2 subgrid.\n", + "- Keep the known digits from previous valid thoughts in place.\n", + "- Each thought can be a partial or the final solution.\n", + "\"\"\".strip()\n", + "print(problem_description)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Rules Based Checker\n", + "\n", + "Each thought is evaluated by the thought checker and is given a validity type: valid, invalid or partial. A simple checker can be rule based. For example, in the case of a sudoku puzzle, the checker can check if the puzzle is valid, invalid or partial.\n", + "\n", + "In the following code we implement a simple rule based checker for a specific 4x4 sudoku puzzle.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "from typing import Tuple\n", + "from langchain_experimental.tot.checker import ToTChecker\n", + "from langchain_experimental.tot.thought import ThoughtValidity\n", + "import re\n", + "\n", + "class MyChecker(ToTChecker):\n", + " def evaluate(self, problem_description: str, thoughts: Tuple[str, ...] = ()) -> ThoughtValidity:\n", + " last_thought = thoughts[-1]\n", + " clean_solution = last_thought.replace(\" \", \"\").replace('\"', \"\")\n", + " regex_solution = clean_solution.replace(\"*\", \".\").replace(\"|\", \"\\\\|\")\n", + " if sudoku_solution in clean_solution:\n", + " return ThoughtValidity.VALID_FINAL\n", + " elif re.search(regex_solution, sudoku_solution):\n", + " return ThoughtValidity.VALID_INTERMEDIATE\n", + " else:\n", + " return ThoughtValidity.INVALID" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Just testing the MyChecker class above:" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "checker = MyChecker()\n", + "assert checker.evaluate(\"\", (\"3,*,*,2|1,*,3,*|*,1,*,3|4,*,*,1\",)) == ThoughtValidity.VALID_INTERMEDIATE\n", + "assert checker.evaluate(\"\", (\"3,4,1,2|1,2,3,4|2,1,4,3|4,3,2,1\",)) == ThoughtValidity.VALID_FINAL\n", + "assert checker.evaluate(\"\", (\"3,4,1,2|1,2,3,4|2,1,4,3|4,3,*,1\",)) == ThoughtValidity.VALID_INTERMEDIATE\n", + "assert checker.evaluate(\"\", (\"3,4,1,2|1,2,3,4|2,1,4,3|4,*,3,1\",)) == ThoughtValidity.INVALID" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Tree of Thought Chain\n", + "\n", + "Initialize and run the ToT chain, with maximum number of interactions `k` set to `30` and the maximum number child thoughts `c` set to `8`." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "\u001b[1m> Entering new ToTChain chain...\u001b[0m\n", + "Starting the ToT solve procedure.\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/harrisonchase/workplace/langchain/libs/langchain/langchain/chains/llm.py:275: UserWarning: The predict_and_parse method is deprecated, instead pass an output parser directly to LLMChain.\n", + " warnings.warn(\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[31;1m\u001b[1;3mThought: 3*,*,2|1*,3,*|*,1,*,3|4,*,*,1\n", + "\u001b[0m\u001b[31;1m\u001b[1;3mThought: 3*,1,2|1*,3,*|*,1,*,3|4,*,*,1\n", + "\u001b[0m\u001b[31;1m\u001b[1;3mThought: 3*,1,2|1*,3,4|*,1,*,3|4,*,*,1\n", + "\u001b[0m\u001b[31;1m\u001b[1;3mThought: 3*,1,2|1*,3,4|*,1,2,3|4,*,*,1\n", + "\u001b[0m\u001b[31;1m\u001b[1;3mThought: 3*,1,2|1*,3,4|2,1,*,3|4,*,*,1\n", + "\u001b[0m" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Type not serializable\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[31;1m\u001b[1;3mThought: 3,*,*,2|1,*,3,*|*,1,*,3|4,1,*,*\n", + "\u001b[0m\u001b[31;1m\u001b[1;3mThought: 3,*,*,2|*,3,2,*|*,1,*,3|4,1,*,*\n", + "\u001b[0m\u001b[31;1m\u001b[1;3mThought: 3,2,*,2|1,*,3,*|*,1,*,3|4,1,*,*\n", + "\u001b[0m\u001b[31;1m\u001b[1;3mThought: 3,2,*,2|1,*,3,*|1,1,*,3|4,1,*,*\n", + "\u001b[0m\u001b[31;1m\u001b[1;3mThought: 3,2,*,2|1,1,3,*|1,1,*,3|4,1,*,*\n", + "\u001b[0m\u001b[33;1m\u001b[1;3mThought: 3,*,*,2|1,2,3,*|*,1,*,3|4,*,*,1\n", + "\u001b[0m\u001b[31;1m\u001b[1;3m Thought: 3,1,4,2|1,2,3,4|2,1,4,3|4,3,2,1\n", + "\u001b[0m\u001b[32;1m\u001b[1;3m Thought: 3,4,1,2|1,2,3,4|2,1,4,3|4,3,2,1\n", + "\u001b[0m\n", + "\u001b[1m> Finished chain.\u001b[0m\n" + ] + }, + { + "data": { + "text/plain": [ + "'3,4,1,2|1,2,3,4|2,1,4,3|4,3,2,1'" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from langchain_experimental.tot.base import ToTChain\n", + "\n", + "tot_chain = ToTChain(llm=llm, checker=MyChecker(), k=30, c=5, verbose=True, verbose_llm=False)\n", + "tot_chain.run(problem_description=problem_description)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.1" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/libs/experimental/langchain_experimental/tot/__init__.py b/libs/experimental/langchain_experimental/tot/__init__.py new file mode 100644 index 0000000000..ecb8d7f017 --- /dev/null +++ b/libs/experimental/langchain_experimental/tot/__init__.py @@ -0,0 +1,4 @@ +from langchain_experimental.tot.base import ToTChain +from langchain_experimental.tot.checker import ToTChecker + +__all__ = ["ToTChain", "ToTChecker"] diff --git a/libs/experimental/langchain_experimental/tot/base.py b/libs/experimental/langchain_experimental/tot/base.py new file mode 100644 index 0000000000..027d6c8b4e --- /dev/null +++ b/libs/experimental/langchain_experimental/tot/base.py @@ -0,0 +1,155 @@ +""" +This a Tree of Thought (ToT) chain based on the paper "Large Language Model +Guided Tree-of-Thought" + +https://arxiv.org/pdf/2305.08291.pdf + +The Tree of Thought (ToT) chain uses a tree structure to explore the space of +possible solutions to a problem. + +""" + +from __future__ import annotations + +from textwrap import indent +from typing import Any, Dict, List, Optional, Type + +from langchain.base_language import BaseLanguageModel +from langchain.callbacks.manager import ( + AsyncCallbackManagerForChainRun, + CallbackManagerForChainRun, +) +from langchain.chains.base import Chain +from pydantic import Extra + +from langchain_experimental.tot.checker import ToTChecker +from langchain_experimental.tot.controller import ToTController +from langchain_experimental.tot.memory import ToTDFSMemory +from langchain_experimental.tot.thought import Thought, ThoughtValidity +from langchain_experimental.tot.thought_generation import ( + BaseThoughtGenerationStrategy, + ProposePromptStrategy, +) + + +class ToTChain(Chain): + """ + A Chain implementing the Tree of Thought (ToT). + """ + + llm: BaseLanguageModel + """ + Language model to use. It must be set to produce different variations for + the same prompt. + """ + checker: ToTChecker + """ToT Checker to use.""" + output_key: str = "response" #: :meta private: + k: int = 10 + """The maximmum number of conversation rounds""" + c: int = 3 + """The number of children to explore at each node""" + tot_memory: ToTDFSMemory = ToTDFSMemory() + tot_controller: ToTController = ToTController() + tot_strategy_class: Type[BaseThoughtGenerationStrategy] = ProposePromptStrategy + verbose_llm: bool = False + + class Config: + """Configuration for this pydantic object.""" + + extra = Extra.forbid + arbitrary_types_allowed = True + + @classmethod + def from_llm(cls, llm: BaseLanguageModel, **kwargs: Any) -> ToTChain: + """ + Create a ToTChain from a language model. + + :param llm: The language model to use. + :param kwargs: Additional arguments to pass to the ToTChain constructor. + """ + return cls(llm=llm, **kwargs) + + def __init__(self, **kwargs: Any): + super().__init__(**kwargs) + self.tot_controller.c = self.c + + @property + def input_keys(self) -> List[str]: + """Will be whatever keys the prompt expects. + + :meta private: + """ + return ["problem_description"] + + @property + def output_keys(self) -> List[str]: + """Will always return text key. + + :meta private: + """ + return [self.output_key] + + def log_thought( + self, + thought: Thought, + level: int, + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> None: + if run_manager: + colors = { + ThoughtValidity.VALID_FINAL: "green", + ThoughtValidity.VALID_INTERMEDIATE: "yellow", + ThoughtValidity.INVALID: "red", + } + text = indent(f"Thought: {thought.text}\n", prefix=" " * level) + run_manager.on_text( + text=text, color=colors[thought.validity], verbose=self.verbose + ) + + def _call( + self, + inputs: Dict[str, Any], + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> Dict[str, str]: + _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() + if run_manager: + run_manager.on_text(text="Starting the ToT solve procedure.\n") + + problem_description = inputs["problem_description"] + checker_inputs = {"problem_description": problem_description} + thoughts_path: tuple[str, ...] = () + thought_generator = self.tot_strategy_class( + llm=self.llm, c=self.c, verbose=self.verbose_llm + ) + + level = 0 + for _ in range(self.k): + level = self.tot_memory.level + thought_text = thought_generator.next_thought( + problem_description, thoughts_path, callbacks=_run_manager.get_child() + ) + checker_inputs["thoughts"] = thoughts_path + (thought_text,) + thought_validity = self.checker( + checker_inputs, callbacks=_run_manager.get_child() + )["validity"] + thought = Thought(text=thought_text, validity=thought_validity) + if thought.validity == ThoughtValidity.VALID_FINAL: + self.log_thought(thought, level, run_manager) + return {self.output_key: thought.text} + self.tot_memory.store(thought) + self.log_thought(thought, level, run_manager) + thoughts_path = self.tot_controller(self.tot_memory) + + return {self.output_key: "No solution found"} + + async def _acall( + self, + inputs: Dict[str, Any], + run_manager: Optional[AsyncCallbackManagerForChainRun] = None, + ) -> Dict[str, str]: + raise NotImplementedError("Async not implemented yet") + + @property + def _chain_type(self) -> str: + return "tot" diff --git a/libs/experimental/langchain_experimental/tot/checker.py b/libs/experimental/langchain_experimental/tot/checker.py new file mode 100644 index 0000000000..039ec7d5db --- /dev/null +++ b/libs/experimental/langchain_experimental/tot/checker.py @@ -0,0 +1,52 @@ +from abc import ABC, abstractmethod +from typing import Any, Dict, List, Optional, Tuple + +from langchain.callbacks.manager import CallbackManagerForChainRun +from langchain.chains.base import Chain + +from langchain_experimental.tot.thought import ThoughtValidity + + +class ToTChecker(Chain, ABC): + """ + Tree of Thought (ToT) checker. + + This is an abstract ToT checker that must be implemented by the user. You + can implement a simple rule-based checker or a more sophisticated + neural network based classifier. + """ + + output_key: str = "validity" #: :meta private: + + @property + def input_keys(self) -> List[str]: + """The checker input keys. + + :meta private: + """ + return ["problem_description", "thoughts"] + + @property + def output_keys(self) -> List[str]: + """The checker output keys. + + :meta private: + """ + return [self.output_key] + + @abstractmethod + def evaluate( + self, + problem_description: str, + thoughts: Tuple[str, ...] = (), + ) -> ThoughtValidity: + """ + Evaluate the response to the problem description and return the solution type. + """ + + def _call( + self, + inputs: Dict[str, Any], + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> Dict[str, ThoughtValidity]: + return {self.output_key: self.evaluate(**inputs)} diff --git a/libs/experimental/langchain_experimental/tot/controller.py b/libs/experimental/langchain_experimental/tot/controller.py new file mode 100644 index 0000000000..d2a7a6fbd3 --- /dev/null +++ b/libs/experimental/langchain_experimental/tot/controller.py @@ -0,0 +1,54 @@ +from typing import Tuple + +from langchain_experimental.tot.memory import ToTDFSMemory +from langchain_experimental.tot.thought import ThoughtValidity + + +class ToTController: + """ + Tree of Thought (ToT) controller. + + This is a version of a ToT controller, dubbed in the paper as a "Simple + Controller". + + It has one parameter `c` which is the number of children to explore for each + thought. + """ + + def __init__(self, c: int = 3): + """ + Initialize the controller. + + Args: + c: The number of children to explore at each node. + """ + self.c = c + + def __call__(self, memory: ToTDFSMemory) -> Tuple[str, ...]: + next_thought = memory.top() + parent_thought = memory.top_parent() + validity = ( + ThoughtValidity.VALID_INTERMEDIATE + if next_thought is None + else next_thought.validity + ) + + # 1 if the current partial solution is invalid, backtrack to the parent + # thought. + if validity == ThoughtValidity.INVALID: + memory.pop() + next_thought = memory.top() + if next_thought and len(next_thought.children) >= self.c: + memory.pop() + + # 2 if the current partial solution is valid but C children were + # explored and yet failed to find a final solution, backtrack to the + # parent thought. + elif ( + validity == ThoughtValidity.VALID_INTERMEDIATE + and parent_thought + and len(parent_thought.children) >= self.c + ): + memory.pop(2) + + return tuple(thought.text for thought in memory.current_path()) diff --git a/libs/experimental/langchain_experimental/tot/memory.py b/libs/experimental/langchain_experimental/tot/memory.py new file mode 100644 index 0000000000..c63ff227b1 --- /dev/null +++ b/libs/experimental/langchain_experimental/tot/memory.py @@ -0,0 +1,46 @@ +from __future__ import annotations + +from typing import List, Optional + +from langchain_experimental.tot.thought import Thought + + +class ToTDFSMemory: + """ + Memory for the Tree of Thought (ToT) chain. Implemented as a stack of + thoughts. This allows for a depth first search (DFS) of the ToT. + """ + + def __init__(self, stack: Optional[List[Thought]] = None): + self.stack: list[Thought] = stack or [] + + def top(self) -> Optional[Thought]: + "Get the top of the stack without popping it." + return self.stack[-1] if len(self.stack) > 0 else None + + def pop(self, n: int = 1) -> Optional[Thought]: + "Pop the top n elements of the stack and return the last one." + if len(self.stack) < n: + return None + for _ in range(n): + node = self.stack.pop() + return node + + def top_parent(self) -> Optional[Thought]: + "Get the parent of the top of the stack without popping it." + return self.stack[-2] if len(self.stack) > 1 else None + + def store(self, node: Thought) -> None: + "Add a node on the top of the stack." + if len(self.stack) > 0: + self.stack[-1].children.add(node) + self.stack.append(node) + + @property + def level(self) -> int: + "Return the current level of the stack." + return len(self.stack) + + def current_path(self) -> List[Thought]: + "Return the thoughts path." + return self.stack[:] diff --git a/libs/experimental/langchain_experimental/tot/prompts.py b/libs/experimental/langchain_experimental/tot/prompts.py new file mode 100644 index 0000000000..cd368933fe --- /dev/null +++ b/libs/experimental/langchain_experimental/tot/prompts.py @@ -0,0 +1,137 @@ +import json +from textwrap import dedent +from typing import List + +from langchain.prompts import PromptTemplate +from langchain.schema import BaseOutputParser + +from langchain_experimental.tot.thought import ThoughtValidity + +COT_PROMPT = PromptTemplate( + template_format="jinja2", + input_variables=["problem_description", "thoughts"], + template=dedent( + """ + You are an intelligent agent that is generating one thought at a time in + a tree of thoughts setting. + + PROBLEM + + {{problem_description}} + + {% if thoughts %} + THOUGHTS + + {% for thought in thoughts %} + {{ thought }} + {% endfor %} + {% endif %} + + Let's think step by step. + """ + ).strip(), +) + + +class JSONListOutputParser(BaseOutputParser): + """Class to parse the output of a PROPOSE_PROMPT response.""" + + @property + def _type(self) -> str: + return "json_list" + + def parse(self, text: str) -> List[str]: + """Parse the output of an LLM call.""" + + json_string = text.split("```json")[1].strip().strip("```").strip() + try: + return json.loads(json_string) + except json.JSONDecodeError: + return [] + + +PROPOSE_PROMPT = PromptTemplate( + template_format="jinja2", + input_variables=["problem_description", "thoughts", "n"], + output_parser=JSONListOutputParser(), + template=dedent( + """ + You are an intelligent agent that is generating thoughts in a tree of + thoughts setting. + + The output should be a markdown code snippet formatted as a JSON list of + strings, including the leading and trailing "```json" and "```": + + ```json + [ + "", + "", + "" + ] + ``` + + PROBLEM + + {{ problem_description }} + + {% if thoughts %} + VALID THOUGHTS + + {% for thought in thoughts %} + {{ thought }} + {% endfor %} + + Possible next {{ n }} valid thoughts based on the last valid thought: + {% else %} + + Possible next {{ n }} valid thoughts based on the PROBLEM: + {%- endif -%} + """ + ).strip(), +) + + +class CheckerOutputParser(BaseOutputParser): + def parse(self, text: str) -> ThoughtValidity: + """Parse the output of the language model.""" + text = text.upper() + if "INVALID" in text: + return ThoughtValidity.INVALID + elif "INTERMEDIATE" in text: + return ThoughtValidity.VALID_INTERMEDIATE + elif "VALID" in text: + return ThoughtValidity.VALID_FINAL + else: + return ThoughtValidity.INVALID + + @property + def _type(self) -> str: + return "tot_llm_checker_output" + + +CHECKER_PROMPT = PromptTemplate( + input_variables=["problem_description", "thoughts"], + template=dedent( + """ + You are an intelligent agent, validating thoughts of another intelligent agent. + + PROBLEM + + {problem_description} + + THOUGHTS + + {thoughts} + + Evaluate the thoughts and respond with one word. + + - Respond VALID if the last thought is a valid final solution to the + poblem. + - Respond INVALID if the last thought is invalid. + - Respond INTERMEDIATE if the last thought is valid but not the final + solution to the problem. + + This chain of thoughts is""" + ).strip(), + output_parser=CheckerOutputParser(), +) diff --git a/libs/experimental/langchain_experimental/tot/thought.py b/libs/experimental/langchain_experimental/tot/thought.py new file mode 100644 index 0000000000..d0567c8533 --- /dev/null +++ b/libs/experimental/langchain_experimental/tot/thought.py @@ -0,0 +1,21 @@ +from __future__ import annotations + +from enum import Enum +from typing import Set + +from pydantic import BaseModel, Field + + +class ThoughtValidity(Enum): + VALID_INTERMEDIATE = 0 + VALID_FINAL = 1 + INVALID = 2 + + +class Thought(BaseModel): + text: str + validity: ThoughtValidity + children: Set[Thought] = Field(default_factory=set) + + def __hash__(self) -> int: + return id(self) diff --git a/libs/experimental/langchain_experimental/tot/thought_generation.py b/libs/experimental/langchain_experimental/tot/thought_generation.py new file mode 100644 index 0000000000..7c41498467 --- /dev/null +++ b/libs/experimental/langchain_experimental/tot/thought_generation.py @@ -0,0 +1,94 @@ +""" +We provide two strategies for generating thoughts in the Tree of Thoughts (ToT) +framework to avoid repetition: + +These strategies ensure that the language model generates diverse and +non-repeating thoughts, which are crucial for problem-solving tasks that require +exploration. +""" +from abc import abstractmethod +from typing import Any, Dict, List, Tuple + +from langchain.chains.llm import LLMChain +from langchain.prompts.base import BasePromptTemplate +from pydantic import Field + +from langchain_experimental.tot.prompts import COT_PROMPT, PROPOSE_PROMPT + + +class BaseThoughtGenerationStrategy(LLMChain): + """ + Base class for a thought generation strategy. + """ + + c: int = 3 + """The number of children thoughts to propose at each step.""" + + @abstractmethod + def next_thought( + self, + problem_description: str, + thoughts_path: Tuple[str, ...] = (), + **kwargs: Any + ) -> str: + """ + Generate the next thought given the problem description and the thoughts + generated so far. + """ + + +class SampleCoTStrategy(BaseThoughtGenerationStrategy): + """ + Sample thoughts from a Chain-of-Thought (CoT) prompt. + + This strategy works better when the thought space is rich, such as when each + thought is a paragraph. Independent and identically distributed samples + lead to diversity, which helps to avoid repetition. + """ + + prompt: BasePromptTemplate = COT_PROMPT + + def next_thought( + self, + problem_description: str, + thoughts_path: Tuple[str, ...] = (), + **kwargs: Any + ) -> str: + response_text = self.predict_and_parse( + problem_description=problem_description, thoughts=thoughts_path, **kwargs + ) + return response_text if isinstance(response_text, str) else "" + + +class ProposePromptStrategy(BaseThoughtGenerationStrategy): + """ + Propose thoughts sequentially using a "propose prompt". + + This strategy works better when the thought space is more constrained, such + as when each thought is just a word or a line. Proposing different thoughts + in the same prompt completion helps to avoid duplication. + """ + + prompt: BasePromptTemplate = PROPOSE_PROMPT + tot_memory: Dict[Tuple[str, ...], List[str]] = Field(default_factory=dict) + + def next_thought( + self, + problem_description: str, + thoughts_path: Tuple[str, ...] = (), + **kwargs: Any + ) -> str: + if thoughts_path not in self.tot_memory or not self.tot_memory[thoughts_path]: + new_thoughts = self.predict_and_parse( + problem_description=problem_description, + thoughts=thoughts_path, + n=self.c, + **kwargs + ) + if not new_thoughts: + return "" + if isinstance(new_thoughts, list): + self.tot_memory[thoughts_path] = new_thoughts[::-1] + else: + return "" + return self.tot_memory[thoughts_path].pop() diff --git a/libs/experimental/tests/unit_tests/fake_llm.py b/libs/experimental/tests/unit_tests/fake_llm.py new file mode 100644 index 0000000000..7da86861ea --- /dev/null +++ b/libs/experimental/tests/unit_tests/fake_llm.py @@ -0,0 +1,61 @@ +"""Fake LLM wrapper for testing purposes.""" +from typing import Any, Dict, List, Mapping, Optional, cast + +from langchain.callbacks.manager import CallbackManagerForLLMRun +from langchain.llms.base import LLM +from pydantic import validator + + +class FakeLLM(LLM): + """Fake LLM wrapper for testing purposes.""" + + queries: Optional[Mapping] = None + sequential_responses: Optional[bool] = False + response_index: int = 0 + + @validator("queries", always=True) + def check_queries_required( + cls, queries: Optional[Mapping], values: Mapping[str, Any] + ) -> Optional[Mapping]: + if values.get("sequential_response") and not queries: + raise ValueError( + "queries is required when sequential_response is set to True" + ) + return queries + + def get_num_tokens(self, text: str) -> int: + """Return number of tokens.""" + return len(text.split()) + + @property + def _llm_type(self) -> str: + """Return type of llm.""" + return "fake" + + def _call( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> str: + if self.sequential_responses: + return self._get_next_response_in_sequence + + if self.queries is not None: + return self.queries[prompt] + if stop is None: + return "foo" + else: + return "bar" + + @property + def _identifying_params(self) -> Dict[str, Any]: + return {} + + @property + def _get_next_response_in_sequence(self) -> str: + queries = cast(Mapping, self.queries) + response = queries[list(queries.keys())[self.response_index]] + self.response_index = self.response_index + 1 + return response diff --git a/libs/experimental/tests/unit_tests/test_tot.py b/libs/experimental/tests/unit_tests/test_tot.py new file mode 100644 index 0000000000..04982f04c3 --- /dev/null +++ b/libs/experimental/tests/unit_tests/test_tot.py @@ -0,0 +1,151 @@ +import re +import unittest +from typing import Tuple + +import pytest + +from langchain_experimental.tot.base import ToTChain +from langchain_experimental.tot.checker import ToTChecker +from langchain_experimental.tot.controller import ToTController +from langchain_experimental.tot.memory import ToTDFSMemory +from langchain_experimental.tot.thought import Thought, ThoughtValidity +from langchain_experimental.tot.thought_generation import SampleCoTStrategy +from tests.unit_tests.fake_llm import FakeLLM + +sudoku_puzzle = "3,*,*,2|1,*,3,*|*,1,*,3|4,*,*,1" +solutions = [ + "3,*,4,2|1,*,3,*|*,1,*,3|4,*,*,1", # VALID_INTERMEDIATE + " 3,4,1,2|1,6,3,*|*,1,*,3|4,*,*,1", # INVALID c=1 + " 3,4,1,2|1,7,3,*|*,1,*,3|4,*,*,1", # INVALID c=2 + " 3,4,1,2|1,8,3,*|*,1,*,3|4,*,*,1", # INVALID c=3 + " 3,4,1,2|1,2,3,*|*,1,*,3|4,*,*,1", # VALID_INTERMEDIATE c=4 (rollback) + "3,1,4,2|1,*,3,*|*,1,*,3|4,*,*,1", # INVALID (rollback) + "3,4,1,2|1,2,3,4|*,1,*,3|4,*,*,1", # VALID_INTERMEDIATE + " 3,4,1,2|1,2,3,4|4,1,*,3|4,*,*,1", # INVALID (rollback) + " 3,4,1,2|1,2,3,4|2,1,4,3|4,*,*,1", # VALID_INTERMEDIATE + " 3,4,1,2|1,2,3,4|2,1,4,3|4,3,*,1", # VALID_INTERMEDIATE + " 3,4,1,2|1,2,3,4|2,1,4,3|4,3,2,1", # VALID_FINAL +] +sudoku_solution = "3,4,1,2|1,2,3,4|2,1,4,3|4,3,2,1" + + +@pytest.fixture +def fake_llm_sudoku() -> FakeLLM: + """This is a fake LLM that responds to the sudoku problem.""" + queries = {i: next_step.strip() for i, next_step in enumerate(solutions)} + return FakeLLM(queries=queries, sequential_responses=True) + + +class SudokuChecker(ToTChecker): + def evaluate( + self, problem_description: str, thoughts: Tuple[str, ...] = () + ) -> ThoughtValidity: + last_thought = thoughts[-1] + clean_solution = last_thought.replace(" ", "").replace('"', "") + regex_solution = clean_solution.replace("*", ".").replace("|", "\\|") + if sudoku_solution in clean_solution: + return ThoughtValidity.VALID_FINAL + elif re.search(regex_solution, sudoku_solution): + return ThoughtValidity.VALID_INTERMEDIATE + else: + return ThoughtValidity.INVALID + + +def test_solve_sudoku(fake_llm_sudoku: FakeLLM) -> None: + """Test simple question that should not need python.""" + tot_chain = ToTChain( + llm=fake_llm_sudoku, + checker=SudokuChecker(), + k=len(solutions), + c=4, + tot_strategy_class=SampleCoTStrategy, + ) + output = tot_chain.run({"problem_description": ""}) + assert output == sudoku_solution + + +def test_solve_sudoku_k_too_small(fake_llm_sudoku: FakeLLM) -> None: + """Test simple question that should not need python.""" + tot_chain = ToTChain( + llm=fake_llm_sudoku, + checker=SudokuChecker(), + k=len(solutions) - 1, + c=4, + tot_strategy_class=SampleCoTStrategy, + ) + output = tot_chain.run({"problem_description": ""}) + assert output != sudoku_solution + + +@pytest.fixture +def fake_llm_checker() -> FakeLLM: + """This is a fake LLM that responds with a thought validity.""" + responses = [ + "VALID", + "valid", + "INVALID", + "invalid", + "INTERMEDIATE", + "intermediate", + "SOMETHING ELSE", + ] + queries = dict(enumerate(responses)) + return FakeLLM(queries=queries, sequential_responses=True) + + +class ControllerTestCase(unittest.TestCase): + def setUp(self) -> None: + self.controller = ToTController(c=3) + + def test_empty(self) -> None: + memory = ToTDFSMemory([]) + self.assertEqual(self.controller(memory), ()) + + def test_one_thoghts(self) -> None: + thoughts = [Thought(text="a", validity=ThoughtValidity.VALID_FINAL)] + memory = ToTDFSMemory(thoughts) + self.assertEqual(self.controller(memory), ("a",)) + + def test_two_thoghts(self) -> None: + memory = ToTDFSMemory( + [ + Thought(text="a", validity=ThoughtValidity.VALID_INTERMEDIATE), + Thought(text="b", validity=ThoughtValidity.VALID_INTERMEDIATE), + ] + ) + self.assertEqual(self.controller(memory), ("a", "b")) + + def test_two_thoughts_invalid(self) -> None: + memory = ToTDFSMemory( + [ + Thought(text="a", validity=ThoughtValidity.VALID_INTERMEDIATE), + Thought(text="b", validity=ThoughtValidity.INVALID), + ] + ) + self.assertEqual(self.controller(memory), ("a",)) + + def test_thoughts_rollback(self) -> None: + a = Thought(text="a", validity=ThoughtValidity.VALID_INTERMEDIATE) + b = Thought(text="b", validity=ThoughtValidity.VALID_INTERMEDIATE) + c_1 = Thought(text="c_1", validity=ThoughtValidity.VALID_INTERMEDIATE) + c_2 = Thought(text="c_2", validity=ThoughtValidity.VALID_INTERMEDIATE) + c_3 = Thought(text="c_3", validity=ThoughtValidity.VALID_INTERMEDIATE) + + a.children = {b} + b.children = {c_1, c_2, c_3} + + memory = ToTDFSMemory([a, b, c_3]) + self.assertEqual(self.controller(memory), ("a",)) + + def test_thoughts_rollback_invalid(self) -> None: + a = Thought(text="a", validity=ThoughtValidity.VALID_INTERMEDIATE) + b = Thought(text="b", validity=ThoughtValidity.VALID_INTERMEDIATE) + c_1 = Thought(text="c_1", validity=ThoughtValidity.VALID_INTERMEDIATE) + c_2 = Thought(text="c_2", validity=ThoughtValidity.VALID_INTERMEDIATE) + c_3 = Thought(text="c_3", validity=ThoughtValidity.INVALID) + + a.children = {b} + b.children = {c_1, c_2, c_3} + + memory = ToTDFSMemory([a, b, c_3]) + self.assertEqual(self.controller(memory), ("a",))