{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Tree of Thought (ToT) example\n", "\n", "The Tree of Thought (ToT) is a chain that allows you to query a Large Language Model (LLM) using the Tree of Thought technique. It is based on the paper [\"Large Language Model Guided Tree-of-Thought\"](https://arxiv.org/pdf/2305.08291.pdf)." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from langchain_openai import OpenAI\n", "\n", "llm = OpenAI(temperature=1, max_tokens=512, model=\"gpt-3.5-turbo-instruct\")" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "3,*,*,2|1,*,3,*|*,1,*,3|4,*,*,1\n", "\n", "- This is a 4x4 Sudoku puzzle.\n", "- The * represents a cell to be filled.\n", "- The | character separates rows.\n", "- At each step, replace one or more * with digits 1-4.\n", "- There must be no duplicate digits in any row, column or 2x2 subgrid.\n", "- Keep the known digits from previous valid thoughts in place.\n", "- Each thought can be a partial or the final solution.\n" ] } ], "source": [ "sudoku_puzzle = \"3,*,*,2|1,*,3,*|*,1,*,3|4,*,*,1\"\n", "sudoku_solution = \"3,4,1,2|1,2,3,4|2,1,4,3|4,3,2,1\"\n", "problem_description = f\"\"\"\n", "{sudoku_puzzle}\n", "\n", "- This is a 4x4 Sudoku puzzle.\n", "- The * represents a cell to be filled.\n", "- The | character separates rows.\n", "- At each step, replace one or more * with digits 1-4.\n", "- There must be no duplicate digits in any row, column or 2x2 subgrid.\n", "- Keep the known digits from previous valid thoughts in place.\n", "- Each thought can be a partial or the final solution.\n", "\"\"\".strip()\n", "print(problem_description)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Rules-Based Checker\n", "\n", "Each thought is evaluated by the thought checker and given a validity type: valid, invalid, or partial. A simple checker can be rule-based. For example, in the case of a Sudoku puzzle, the checker can check whether a candidate solution is valid, invalid, or partial.\n", "\n", "In the following code we implement a simple rule-based checker for a specific 4x4 Sudoku puzzle.\n" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "import re\n", "from typing import Tuple\n", "\n", "from langchain_experimental.tot.checker import ToTChecker\n", "from langchain_experimental.tot.thought import ThoughtValidity\n", "\n", "\n", "class MyChecker(ToTChecker):\n",
"    def evaluate(\n", "        self, problem_description: str, thoughts: Tuple[str, ...] = ()\n", "    ) -> ThoughtValidity:\n", "        last_thought = thoughts[-1]\n", "        # Normalize the thought: strip spaces and stray quotes.\n", "        clean_solution = last_thought.replace(\" \", \"\").replace('\"', \"\")\n", "        # Turn the thought into a regex: each * becomes a wildcard and | is escaped.\n", "        regex_solution = clean_solution.replace(\"*\", \".\").replace(\"|\", \"\\\\|\")\n", "        if sudoku_solution in clean_solution:\n", "            return ThoughtValidity.VALID_FINAL\n", "        elif re.search(regex_solution, sudoku_solution):\n", "            # The filled cells agree with the known solution, so this is a valid partial step.\n", "            return ThoughtValidity.VALID_INTERMEDIATE\n", "        else:\n", "            return ThoughtValidity.INVALID" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Testing the `MyChecker` class above:" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "checker = MyChecker()\n", "assert (\n", "    checker.evaluate(\"\", (\"3,*,*,2|1,*,3,*|*,1,*,3|4,*,*,1\",))\n", "    == ThoughtValidity.VALID_INTERMEDIATE\n", ")\n", "assert (\n", "    checker.evaluate(\"\", (\"3,4,1,2|1,2,3,4|2,1,4,3|4,3,2,1\",))\n", "    == ThoughtValidity.VALID_FINAL\n", ")\n", "assert (\n", "    checker.evaluate(\"\", (\"3,4,1,2|1,2,3,4|2,1,4,3|4,3,*,1\",))\n", "    == ThoughtValidity.VALID_INTERMEDIATE\n", ")\n", "assert (\n", "    checker.evaluate(\"\", (\"3,4,1,2|1,2,3,4|2,1,4,3|4,*,3,1\",))\n", "    == ThoughtValidity.INVALID\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Tree of Thought Chain\n", "\n", "Initialize and run the ToT chain, with the maximum number of interactions `k` set to `30` and the maximum number of child thoughts `c` set to `5`." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "\n", "\u001b[1m> Entering new ToTChain chain...\u001b[0m\n", "Starting the ToT solve procedure.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/Users/harrisonchase/workplace/langchain/libs/langchain/langchain/chains/llm.py:275: UserWarning: The predict_and_parse method is deprecated, instead pass an output parser directly to LLMChain.\n", " warnings.warn(\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\u001b[31;1m\u001b[1;3mThought: 3*,*,2|1*,3,*|*,1,*,3|4,*,*,1\n", "\u001b[0m\u001b[31;1m\u001b[1;3mThought: 3*,1,2|1*,3,*|*,1,*,3|4,*,*,1\n", "\u001b[0m\u001b[31;1m\u001b[1;3mThought: 3*,1,2|1*,3,4|*,1,*,3|4,*,*,1\n", "\u001b[0m\u001b[31;1m\u001b[1;3mThought: 3*,1,2|1*,3,4|*,1,2,3|4,*,*,1\n", "\u001b[0m\u001b[31;1m\u001b[1;3mThought: 3*,1,2|1*,3,4|2,1,*,3|4,*,*,1\n", "\u001b[0m" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Type not serializable\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\u001b[31;1m\u001b[1;3mThought: 3,*,*,2|1,*,3,*|*,1,*,3|4,1,*,*\n", "\u001b[0m\u001b[31;1m\u001b[1;3mThought: 3,*,*,2|*,3,2,*|*,1,*,3|4,1,*,*\n", "\u001b[0m\u001b[31;1m\u001b[1;3mThought: 3,2,*,2|1,*,3,*|*,1,*,3|4,1,*,*\n", "\u001b[0m\u001b[31;1m\u001b[1;3mThought: 3,2,*,2|1,*,3,*|1,1,*,3|4,1,*,*\n", "\u001b[0m\u001b[31;1m\u001b[1;3mThought: 3,2,*,2|1,1,3,*|1,1,*,3|4,1,*,*\n", "\u001b[0m\u001b[33;1m\u001b[1;3mThought: 3,*,*,2|1,2,3,*|*,1,*,3|4,*,*,1\n", "\u001b[0m\u001b[31;1m\u001b[1;3m Thought: 3,1,4,2|1,2,3,4|2,1,4,3|4,3,2,1\n", "\u001b[0m\u001b[32;1m\u001b[1;3m Thought: 3,4,1,2|1,2,3,4|2,1,4,3|4,3,2,1\n", "\u001b[0m\n", "\u001b[1m> Finished chain.\u001b[0m\n" ] }, { "data": { "text/plain": [ "'3,4,1,2|1,2,3,4|2,1,4,3|4,3,2,1'" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from langchain_experimental.tot.base import ToTChain\n", "\n", "tot_chain = ToTChain(\n", "    llm=llm, checker=MyChecker(), k=30, c=5, verbose=True, verbose_llm=False\n",
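"    # k = maximum number of interactions; c = maximum number of child thoughts to explore at each step\n",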
")\n", "tot_chain.run(problem_description=problem_description)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.1" } }, "nbformat": 4, "nbformat_minor": 2 }