Tree of Thought introducing a new ToTChain. (#5167)

# [WIP] Tree of Thought introducing a new ToTChain.

This PR adds a new chain called ToTChain that implements the ["Large
Language Model Guided
Tree-of-Though"](https://arxiv.org/pdf/2305.08291.pdf) paper.

There's a notebook example `docs/modules/chains/examples/tot.ipynb` that
shows how to use it.


Implements #4975


## Who can review?

Community members can review the PR once tests pass. Tag
maintainers/contributors who might be interested:

- @hwchase17
- @vowelparrot

---------

Co-authored-by: Vadim Gubergrits <vgubergrits@outbox.com>
Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
pull/8339/head
Vadim Gubergrits 12 months ago committed by GitHub
parent 412e29d436
commit e7e5cb9d08
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -0,0 +1,239 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Tree of Thouht (ToT) example\n",
"\n",
"The Tree of Thought (ToT) is a chain that allows you to query a Large Language Model (LLM) using the Tree of Thought technique. This is based on the papaer [\"Large Language Model Guided Tree-of-Thought\"](https://arxiv.org/pdf/2305.08291.pdf)"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/harrisonchase/.pyenv/versions/3.9.1/envs/langchain/lib/python3.9/site-packages/deeplake/util/check_latest_version.py:32: UserWarning: A newer version of deeplake (3.6.13) is available. It's recommended that you update to the latest version using `pip install -U deeplake`.\n",
" warnings.warn(\n"
]
}
],
"source": [
"from langchain.llms import OpenAI\n",
"\n",
"llm = OpenAI(temperature=1, max_tokens=512, model=\"text-davinci-003\")"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"3,*,*,2|1,*,3,*|*,1,*,3|4,*,*,1\n",
"\n",
"- This is a 4x4 Sudoku puzzle.\n",
"- The * represents a cell to be filled.\n",
"- The | character separates rows.\n",
"- At each step, replace one or more * with digits 1-4.\n",
"- There must be no duplicate digits in any row, column or 2x2 subgrid.\n",
"- Keep the known digits from previous valid thoughts in place.\n",
"- Each thought can be a partial or the final solution.\n"
]
}
],
"source": [
"sudoku_puzzle = \"3,*,*,2|1,*,3,*|*,1,*,3|4,*,*,1\"\n",
"sudoku_solution = \"3,4,1,2|1,2,3,4|2,1,4,3|4,3,2,1\"\n",
"problem_description = f\"\"\"\n",
"{sudoku_puzzle}\n",
"\n",
"- This is a 4x4 Sudoku puzzle.\n",
"- The * represents a cell to be filled.\n",
"- The | character separates rows.\n",
"- At each step, replace one or more * with digits 1-4.\n",
"- There must be no duplicate digits in any row, column or 2x2 subgrid.\n",
"- Keep the known digits from previous valid thoughts in place.\n",
"- Each thought can be a partial or the final solution.\n",
"\"\"\".strip()\n",
"print(problem_description)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Rules Based Checker\n",
"\n",
"Each thought is evaluated by the thought checker and is given a validity type: valid, invalid or partial. A simple checker can be rule based. For example, in the case of a sudoku puzzle, the checker can check if the puzzle is valid, invalid or partial.\n",
"\n",
"In the following code we implement a simple rule based checker for a specific 4x4 sudoku puzzle.\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"from typing import Tuple\n",
"from langchain_experimental.tot.checker import ToTChecker\n",
"from langchain_experimental.tot.thought import ThoughtValidity\n",
"import re\n",
"\n",
"class MyChecker(ToTChecker):\n",
" def evaluate(self, problem_description: str, thoughts: Tuple[str, ...] = ()) -> ThoughtValidity:\n",
" last_thought = thoughts[-1]\n",
" clean_solution = last_thought.replace(\" \", \"\").replace('\"', \"\")\n",
" regex_solution = clean_solution.replace(\"*\", \".\").replace(\"|\", \"\\\\|\")\n",
" if sudoku_solution in clean_solution:\n",
" return ThoughtValidity.VALID_FINAL\n",
" elif re.search(regex_solution, sudoku_solution):\n",
" return ThoughtValidity.VALID_INTERMEDIATE\n",
" else:\n",
" return ThoughtValidity.INVALID"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Just testing the MyChecker class above:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"checker = MyChecker()\n",
"assert checker.evaluate(\"\", (\"3,*,*,2|1,*,3,*|*,1,*,3|4,*,*,1\",)) == ThoughtValidity.VALID_INTERMEDIATE\n",
"assert checker.evaluate(\"\", (\"3,4,1,2|1,2,3,4|2,1,4,3|4,3,2,1\",)) == ThoughtValidity.VALID_FINAL\n",
"assert checker.evaluate(\"\", (\"3,4,1,2|1,2,3,4|2,1,4,3|4,3,*,1\",)) == ThoughtValidity.VALID_INTERMEDIATE\n",
"assert checker.evaluate(\"\", (\"3,4,1,2|1,2,3,4|2,1,4,3|4,*,3,1\",)) == ThoughtValidity.INVALID"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Tree of Thought Chain\n",
"\n",
"Initialize and run the ToT chain, with maximum number of interactions `k` set to `30` and the maximum number child thoughts `c` set to `8`."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new ToTChain chain...\u001b[0m\n",
"Starting the ToT solve procedure.\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/harrisonchase/workplace/langchain/libs/langchain/langchain/chains/llm.py:275: UserWarning: The predict_and_parse method is deprecated, instead pass an output parser directly to LLMChain.\n",
" warnings.warn(\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[31;1m\u001b[1;3mThought: 3*,*,2|1*,3,*|*,1,*,3|4,*,*,1\n",
"\u001b[0m\u001b[31;1m\u001b[1;3mThought: 3*,1,2|1*,3,*|*,1,*,3|4,*,*,1\n",
"\u001b[0m\u001b[31;1m\u001b[1;3mThought: 3*,1,2|1*,3,4|*,1,*,3|4,*,*,1\n",
"\u001b[0m\u001b[31;1m\u001b[1;3mThought: 3*,1,2|1*,3,4|*,1,2,3|4,*,*,1\n",
"\u001b[0m\u001b[31;1m\u001b[1;3mThought: 3*,1,2|1*,3,4|2,1,*,3|4,*,*,1\n",
"\u001b[0m"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Type <enum 'ThoughtValidity'> not serializable\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[31;1m\u001b[1;3mThought: 3,*,*,2|1,*,3,*|*,1,*,3|4,1,*,*\n",
"\u001b[0m\u001b[31;1m\u001b[1;3mThought: 3,*,*,2|*,3,2,*|*,1,*,3|4,1,*,*\n",
"\u001b[0m\u001b[31;1m\u001b[1;3mThought: 3,2,*,2|1,*,3,*|*,1,*,3|4,1,*,*\n",
"\u001b[0m\u001b[31;1m\u001b[1;3mThought: 3,2,*,2|1,*,3,*|1,1,*,3|4,1,*,*\n",
"\u001b[0m\u001b[31;1m\u001b[1;3mThought: 3,2,*,2|1,1,3,*|1,1,*,3|4,1,*,*\n",
"\u001b[0m\u001b[33;1m\u001b[1;3mThought: 3,*,*,2|1,2,3,*|*,1,*,3|4,*,*,1\n",
"\u001b[0m\u001b[31;1m\u001b[1;3m Thought: 3,1,4,2|1,2,3,4|2,1,4,3|4,3,2,1\n",
"\u001b[0m\u001b[32;1m\u001b[1;3m Thought: 3,4,1,2|1,2,3,4|2,1,4,3|4,3,2,1\n",
"\u001b[0m\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"'3,4,1,2|1,2,3,4|2,1,4,3|4,3,2,1'"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from langchain_experimental.tot.base import ToTChain\n",
"\n",
"tot_chain = ToTChain(llm=llm, checker=MyChecker(), k=30, c=5, verbose=True, verbose_llm=False)\n",
"tot_chain.run(problem_description=problem_description)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

@ -0,0 +1,4 @@
from langchain_experimental.tot.base import ToTChain
from langchain_experimental.tot.checker import ToTChecker
__all__ = ["ToTChain", "ToTChecker"]

@ -0,0 +1,155 @@
"""
This a Tree of Thought (ToT) chain based on the paper "Large Language Model
Guided Tree-of-Thought"
https://arxiv.org/pdf/2305.08291.pdf
The Tree of Thought (ToT) chain uses a tree structure to explore the space of
possible solutions to a problem.
"""
from __future__ import annotations
from textwrap import indent
from typing import Any, Dict, List, Optional, Type
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
)
from langchain.chains.base import Chain
from pydantic import Extra
from langchain_experimental.tot.checker import ToTChecker
from langchain_experimental.tot.controller import ToTController
from langchain_experimental.tot.memory import ToTDFSMemory
from langchain_experimental.tot.thought import Thought, ThoughtValidity
from langchain_experimental.tot.thought_generation import (
BaseThoughtGenerationStrategy,
ProposePromptStrategy,
)
class ToTChain(Chain):
"""
A Chain implementing the Tree of Thought (ToT).
"""
llm: BaseLanguageModel
"""
Language model to use. It must be set to produce different variations for
the same prompt.
"""
checker: ToTChecker
"""ToT Checker to use."""
output_key: str = "response" #: :meta private:
k: int = 10
"""The maximmum number of conversation rounds"""
c: int = 3
"""The number of children to explore at each node"""
tot_memory: ToTDFSMemory = ToTDFSMemory()
tot_controller: ToTController = ToTController()
tot_strategy_class: Type[BaseThoughtGenerationStrategy] = ProposePromptStrategy
verbose_llm: bool = False
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True
@classmethod
def from_llm(cls, llm: BaseLanguageModel, **kwargs: Any) -> ToTChain:
"""
Create a ToTChain from a language model.
:param llm: The language model to use.
:param kwargs: Additional arguments to pass to the ToTChain constructor.
"""
return cls(llm=llm, **kwargs)
def __init__(self, **kwargs: Any):
super().__init__(**kwargs)
self.tot_controller.c = self.c
@property
def input_keys(self) -> List[str]:
"""Will be whatever keys the prompt expects.
:meta private:
"""
return ["problem_description"]
@property
def output_keys(self) -> List[str]:
"""Will always return text key.
:meta private:
"""
return [self.output_key]
def log_thought(
self,
thought: Thought,
level: int,
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> None:
if run_manager:
colors = {
ThoughtValidity.VALID_FINAL: "green",
ThoughtValidity.VALID_INTERMEDIATE: "yellow",
ThoughtValidity.INVALID: "red",
}
text = indent(f"Thought: {thought.text}\n", prefix=" " * level)
run_manager.on_text(
text=text, color=colors[thought.validity], verbose=self.verbose
)
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
if run_manager:
run_manager.on_text(text="Starting the ToT solve procedure.\n")
problem_description = inputs["problem_description"]
checker_inputs = {"problem_description": problem_description}
thoughts_path: tuple[str, ...] = ()
thought_generator = self.tot_strategy_class(
llm=self.llm, c=self.c, verbose=self.verbose_llm
)
level = 0
for _ in range(self.k):
level = self.tot_memory.level
thought_text = thought_generator.next_thought(
problem_description, thoughts_path, callbacks=_run_manager.get_child()
)
checker_inputs["thoughts"] = thoughts_path + (thought_text,)
thought_validity = self.checker(
checker_inputs, callbacks=_run_manager.get_child()
)["validity"]
thought = Thought(text=thought_text, validity=thought_validity)
if thought.validity == ThoughtValidity.VALID_FINAL:
self.log_thought(thought, level, run_manager)
return {self.output_key: thought.text}
self.tot_memory.store(thought)
self.log_thought(thought, level, run_manager)
thoughts_path = self.tot_controller(self.tot_memory)
return {self.output_key: "No solution found"}
async def _acall(
self,
inputs: Dict[str, Any],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, str]:
raise NotImplementedError("Async not implemented yet")
@property
def _chain_type(self) -> str:
return "tot"

@ -0,0 +1,52 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Tuple
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from langchain_experimental.tot.thought import ThoughtValidity
class ToTChecker(Chain, ABC):
"""
Tree of Thought (ToT) checker.
This is an abstract ToT checker that must be implemented by the user. You
can implement a simple rule-based checker or a more sophisticated
neural network based classifier.
"""
output_key: str = "validity" #: :meta private:
@property
def input_keys(self) -> List[str]:
"""The checker input keys.
:meta private:
"""
return ["problem_description", "thoughts"]
@property
def output_keys(self) -> List[str]:
"""The checker output keys.
:meta private:
"""
return [self.output_key]
@abstractmethod
def evaluate(
self,
problem_description: str,
thoughts: Tuple[str, ...] = (),
) -> ThoughtValidity:
"""
Evaluate the response to the problem description and return the solution type.
"""
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, ThoughtValidity]:
return {self.output_key: self.evaluate(**inputs)}

@ -0,0 +1,54 @@
from typing import Tuple
from langchain_experimental.tot.memory import ToTDFSMemory
from langchain_experimental.tot.thought import ThoughtValidity
class ToTController:
"""
Tree of Thought (ToT) controller.
This is a version of a ToT controller, dubbed in the paper as a "Simple
Controller".
It has one parameter `c` which is the number of children to explore for each
thought.
"""
def __init__(self, c: int = 3):
"""
Initialize the controller.
Args:
c: The number of children to explore at each node.
"""
self.c = c
def __call__(self, memory: ToTDFSMemory) -> Tuple[str, ...]:
next_thought = memory.top()
parent_thought = memory.top_parent()
validity = (
ThoughtValidity.VALID_INTERMEDIATE
if next_thought is None
else next_thought.validity
)
# 1 if the current partial solution is invalid, backtrack to the parent
# thought.
if validity == ThoughtValidity.INVALID:
memory.pop()
next_thought = memory.top()
if next_thought and len(next_thought.children) >= self.c:
memory.pop()
# 2 if the current partial solution is valid but C children were
# explored and yet failed to find a final solution, backtrack to the
# parent thought.
elif (
validity == ThoughtValidity.VALID_INTERMEDIATE
and parent_thought
and len(parent_thought.children) >= self.c
):
memory.pop(2)
return tuple(thought.text for thought in memory.current_path())

@ -0,0 +1,46 @@
from __future__ import annotations
from typing import List, Optional
from langchain_experimental.tot.thought import Thought
class ToTDFSMemory:
"""
Memory for the Tree of Thought (ToT) chain. Implemented as a stack of
thoughts. This allows for a depth first search (DFS) of the ToT.
"""
def __init__(self, stack: Optional[List[Thought]] = None):
self.stack: list[Thought] = stack or []
def top(self) -> Optional[Thought]:
"Get the top of the stack without popping it."
return self.stack[-1] if len(self.stack) > 0 else None
def pop(self, n: int = 1) -> Optional[Thought]:
"Pop the top n elements of the stack and return the last one."
if len(self.stack) < n:
return None
for _ in range(n):
node = self.stack.pop()
return node
def top_parent(self) -> Optional[Thought]:
"Get the parent of the top of the stack without popping it."
return self.stack[-2] if len(self.stack) > 1 else None
def store(self, node: Thought) -> None:
"Add a node on the top of the stack."
if len(self.stack) > 0:
self.stack[-1].children.add(node)
self.stack.append(node)
@property
def level(self) -> int:
"Return the current level of the stack."
return len(self.stack)
def current_path(self) -> List[Thought]:
"Return the thoughts path."
return self.stack[:]

@ -0,0 +1,137 @@
import json
from textwrap import dedent
from typing import List
from langchain.prompts import PromptTemplate
from langchain.schema import BaseOutputParser
from langchain_experimental.tot.thought import ThoughtValidity
COT_PROMPT = PromptTemplate(
template_format="jinja2",
input_variables=["problem_description", "thoughts"],
template=dedent(
"""
You are an intelligent agent that is generating one thought at a time in
a tree of thoughts setting.
PROBLEM
{{problem_description}}
{% if thoughts %}
THOUGHTS
{% for thought in thoughts %}
{{ thought }}
{% endfor %}
{% endif %}
Let's think step by step.
"""
).strip(),
)
class JSONListOutputParser(BaseOutputParser):
"""Class to parse the output of a PROPOSE_PROMPT response."""
@property
def _type(self) -> str:
return "json_list"
def parse(self, text: str) -> List[str]:
"""Parse the output of an LLM call."""
json_string = text.split("```json")[1].strip().strip("```").strip()
try:
return json.loads(json_string)
except json.JSONDecodeError:
return []
PROPOSE_PROMPT = PromptTemplate(
template_format="jinja2",
input_variables=["problem_description", "thoughts", "n"],
output_parser=JSONListOutputParser(),
template=dedent(
"""
You are an intelligent agent that is generating thoughts in a tree of
thoughts setting.
The output should be a markdown code snippet formatted as a JSON list of
strings, including the leading and trailing "```json" and "```":
```json
[
"<thought-1>",
"<thought-2>",
"<thought-3>"
]
```
PROBLEM
{{ problem_description }}
{% if thoughts %}
VALID THOUGHTS
{% for thought in thoughts %}
{{ thought }}
{% endfor %}
Possible next {{ n }} valid thoughts based on the last valid thought:
{% else %}
Possible next {{ n }} valid thoughts based on the PROBLEM:
{%- endif -%}
"""
).strip(),
)
class CheckerOutputParser(BaseOutputParser):
def parse(self, text: str) -> ThoughtValidity:
"""Parse the output of the language model."""
text = text.upper()
if "INVALID" in text:
return ThoughtValidity.INVALID
elif "INTERMEDIATE" in text:
return ThoughtValidity.VALID_INTERMEDIATE
elif "VALID" in text:
return ThoughtValidity.VALID_FINAL
else:
return ThoughtValidity.INVALID
@property
def _type(self) -> str:
return "tot_llm_checker_output"
CHECKER_PROMPT = PromptTemplate(
input_variables=["problem_description", "thoughts"],
template=dedent(
"""
You are an intelligent agent, validating thoughts of another intelligent agent.
PROBLEM
{problem_description}
THOUGHTS
{thoughts}
Evaluate the thoughts and respond with one word.
- Respond VALID if the last thought is a valid final solution to the
poblem.
- Respond INVALID if the last thought is invalid.
- Respond INTERMEDIATE if the last thought is valid but not the final
solution to the problem.
This chain of thoughts is"""
).strip(),
output_parser=CheckerOutputParser(),
)

@ -0,0 +1,21 @@
from __future__ import annotations
from enum import Enum
from typing import Set
from pydantic import BaseModel, Field
class ThoughtValidity(Enum):
VALID_INTERMEDIATE = 0
VALID_FINAL = 1
INVALID = 2
class Thought(BaseModel):
text: str
validity: ThoughtValidity
children: Set[Thought] = Field(default_factory=set)
def __hash__(self) -> int:
return id(self)

@ -0,0 +1,94 @@
"""
We provide two strategies for generating thoughts in the Tree of Thoughts (ToT)
framework to avoid repetition:
These strategies ensure that the language model generates diverse and
non-repeating thoughts, which are crucial for problem-solving tasks that require
exploration.
"""
from abc import abstractmethod
from typing import Any, Dict, List, Tuple
from langchain.chains.llm import LLMChain
from langchain.prompts.base import BasePromptTemplate
from pydantic import Field
from langchain_experimental.tot.prompts import COT_PROMPT, PROPOSE_PROMPT
class BaseThoughtGenerationStrategy(LLMChain):
"""
Base class for a thought generation strategy.
"""
c: int = 3
"""The number of children thoughts to propose at each step."""
@abstractmethod
def next_thought(
self,
problem_description: str,
thoughts_path: Tuple[str, ...] = (),
**kwargs: Any
) -> str:
"""
Generate the next thought given the problem description and the thoughts
generated so far.
"""
class SampleCoTStrategy(BaseThoughtGenerationStrategy):
"""
Sample thoughts from a Chain-of-Thought (CoT) prompt.
This strategy works better when the thought space is rich, such as when each
thought is a paragraph. Independent and identically distributed samples
lead to diversity, which helps to avoid repetition.
"""
prompt: BasePromptTemplate = COT_PROMPT
def next_thought(
self,
problem_description: str,
thoughts_path: Tuple[str, ...] = (),
**kwargs: Any
) -> str:
response_text = self.predict_and_parse(
problem_description=problem_description, thoughts=thoughts_path, **kwargs
)
return response_text if isinstance(response_text, str) else ""
class ProposePromptStrategy(BaseThoughtGenerationStrategy):
"""
Propose thoughts sequentially using a "propose prompt".
This strategy works better when the thought space is more constrained, such
as when each thought is just a word or a line. Proposing different thoughts
in the same prompt completion helps to avoid duplication.
"""
prompt: BasePromptTemplate = PROPOSE_PROMPT
tot_memory: Dict[Tuple[str, ...], List[str]] = Field(default_factory=dict)
def next_thought(
self,
problem_description: str,
thoughts_path: Tuple[str, ...] = (),
**kwargs: Any
) -> str:
if thoughts_path not in self.tot_memory or not self.tot_memory[thoughts_path]:
new_thoughts = self.predict_and_parse(
problem_description=problem_description,
thoughts=thoughts_path,
n=self.c,
**kwargs
)
if not new_thoughts:
return ""
if isinstance(new_thoughts, list):
self.tot_memory[thoughts_path] = new_thoughts[::-1]
else:
return ""
return self.tot_memory[thoughts_path].pop()

@ -0,0 +1,61 @@
"""Fake LLM wrapper for testing purposes."""
from typing import Any, Dict, List, Mapping, Optional, cast
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from pydantic import validator
class FakeLLM(LLM):
"""Fake LLM wrapper for testing purposes."""
queries: Optional[Mapping] = None
sequential_responses: Optional[bool] = False
response_index: int = 0
@validator("queries", always=True)
def check_queries_required(
cls, queries: Optional[Mapping], values: Mapping[str, Any]
) -> Optional[Mapping]:
if values.get("sequential_response") and not queries:
raise ValueError(
"queries is required when sequential_response is set to True"
)
return queries
def get_num_tokens(self, text: str) -> int:
"""Return number of tokens."""
return len(text.split())
@property
def _llm_type(self) -> str:
"""Return type of llm."""
return "fake"
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
if self.sequential_responses:
return self._get_next_response_in_sequence
if self.queries is not None:
return self.queries[prompt]
if stop is None:
return "foo"
else:
return "bar"
@property
def _identifying_params(self) -> Dict[str, Any]:
return {}
@property
def _get_next_response_in_sequence(self) -> str:
queries = cast(Mapping, self.queries)
response = queries[list(queries.keys())[self.response_index]]
self.response_index = self.response_index + 1
return response

@ -0,0 +1,151 @@
import re
import unittest
from typing import Tuple
import pytest
from langchain_experimental.tot.base import ToTChain
from langchain_experimental.tot.checker import ToTChecker
from langchain_experimental.tot.controller import ToTController
from langchain_experimental.tot.memory import ToTDFSMemory
from langchain_experimental.tot.thought import Thought, ThoughtValidity
from langchain_experimental.tot.thought_generation import SampleCoTStrategy
from tests.unit_tests.fake_llm import FakeLLM
sudoku_puzzle = "3,*,*,2|1,*,3,*|*,1,*,3|4,*,*,1"
solutions = [
"3,*,4,2|1,*,3,*|*,1,*,3|4,*,*,1", # VALID_INTERMEDIATE
" 3,4,1,2|1,6,3,*|*,1,*,3|4,*,*,1", # INVALID c=1
" 3,4,1,2|1,7,3,*|*,1,*,3|4,*,*,1", # INVALID c=2
" 3,4,1,2|1,8,3,*|*,1,*,3|4,*,*,1", # INVALID c=3
" 3,4,1,2|1,2,3,*|*,1,*,3|4,*,*,1", # VALID_INTERMEDIATE c=4 (rollback)
"3,1,4,2|1,*,3,*|*,1,*,3|4,*,*,1", # INVALID (rollback)
"3,4,1,2|1,2,3,4|*,1,*,3|4,*,*,1", # VALID_INTERMEDIATE
" 3,4,1,2|1,2,3,4|4,1,*,3|4,*,*,1", # INVALID (rollback)
" 3,4,1,2|1,2,3,4|2,1,4,3|4,*,*,1", # VALID_INTERMEDIATE
" 3,4,1,2|1,2,3,4|2,1,4,3|4,3,*,1", # VALID_INTERMEDIATE
" 3,4,1,2|1,2,3,4|2,1,4,3|4,3,2,1", # VALID_FINAL
]
sudoku_solution = "3,4,1,2|1,2,3,4|2,1,4,3|4,3,2,1"
@pytest.fixture
def fake_llm_sudoku() -> FakeLLM:
"""This is a fake LLM that responds to the sudoku problem."""
queries = {i: next_step.strip() for i, next_step in enumerate(solutions)}
return FakeLLM(queries=queries, sequential_responses=True)
class SudokuChecker(ToTChecker):
def evaluate(
self, problem_description: str, thoughts: Tuple[str, ...] = ()
) -> ThoughtValidity:
last_thought = thoughts[-1]
clean_solution = last_thought.replace(" ", "").replace('"', "")
regex_solution = clean_solution.replace("*", ".").replace("|", "\\|")
if sudoku_solution in clean_solution:
return ThoughtValidity.VALID_FINAL
elif re.search(regex_solution, sudoku_solution):
return ThoughtValidity.VALID_INTERMEDIATE
else:
return ThoughtValidity.INVALID
def test_solve_sudoku(fake_llm_sudoku: FakeLLM) -> None:
"""Test simple question that should not need python."""
tot_chain = ToTChain(
llm=fake_llm_sudoku,
checker=SudokuChecker(),
k=len(solutions),
c=4,
tot_strategy_class=SampleCoTStrategy,
)
output = tot_chain.run({"problem_description": ""})
assert output == sudoku_solution
def test_solve_sudoku_k_too_small(fake_llm_sudoku: FakeLLM) -> None:
"""Test simple question that should not need python."""
tot_chain = ToTChain(
llm=fake_llm_sudoku,
checker=SudokuChecker(),
k=len(solutions) - 1,
c=4,
tot_strategy_class=SampleCoTStrategy,
)
output = tot_chain.run({"problem_description": ""})
assert output != sudoku_solution
@pytest.fixture
def fake_llm_checker() -> FakeLLM:
"""This is a fake LLM that responds with a thought validity."""
responses = [
"VALID",
"valid",
"INVALID",
"invalid",
"INTERMEDIATE",
"intermediate",
"SOMETHING ELSE",
]
queries = dict(enumerate(responses))
return FakeLLM(queries=queries, sequential_responses=True)
class ControllerTestCase(unittest.TestCase):
def setUp(self) -> None:
self.controller = ToTController(c=3)
def test_empty(self) -> None:
memory = ToTDFSMemory([])
self.assertEqual(self.controller(memory), ())
def test_one_thoghts(self) -> None:
thoughts = [Thought(text="a", validity=ThoughtValidity.VALID_FINAL)]
memory = ToTDFSMemory(thoughts)
self.assertEqual(self.controller(memory), ("a",))
def test_two_thoghts(self) -> None:
memory = ToTDFSMemory(
[
Thought(text="a", validity=ThoughtValidity.VALID_INTERMEDIATE),
Thought(text="b", validity=ThoughtValidity.VALID_INTERMEDIATE),
]
)
self.assertEqual(self.controller(memory), ("a", "b"))
def test_two_thoughts_invalid(self) -> None:
memory = ToTDFSMemory(
[
Thought(text="a", validity=ThoughtValidity.VALID_INTERMEDIATE),
Thought(text="b", validity=ThoughtValidity.INVALID),
]
)
self.assertEqual(self.controller(memory), ("a",))
def test_thoughts_rollback(self) -> None:
a = Thought(text="a", validity=ThoughtValidity.VALID_INTERMEDIATE)
b = Thought(text="b", validity=ThoughtValidity.VALID_INTERMEDIATE)
c_1 = Thought(text="c_1", validity=ThoughtValidity.VALID_INTERMEDIATE)
c_2 = Thought(text="c_2", validity=ThoughtValidity.VALID_INTERMEDIATE)
c_3 = Thought(text="c_3", validity=ThoughtValidity.VALID_INTERMEDIATE)
a.children = {b}
b.children = {c_1, c_2, c_3}
memory = ToTDFSMemory([a, b, c_3])
self.assertEqual(self.controller(memory), ("a",))
def test_thoughts_rollback_invalid(self) -> None:
a = Thought(text="a", validity=ThoughtValidity.VALID_INTERMEDIATE)
b = Thought(text="b", validity=ThoughtValidity.VALID_INTERMEDIATE)
c_1 = Thought(text="c_1", validity=ThoughtValidity.VALID_INTERMEDIATE)
c_2 = Thought(text="c_2", validity=ThoughtValidity.VALID_INTERMEDIATE)
c_3 = Thought(text="c_3", validity=ThoughtValidity.INVALID)
a.children = {b}
b.children = {c_1, c_2, c_3}
memory = ToTDFSMemory([a, b, c_3])
self.assertEqual(self.controller(memory), ("a",))
Loading…
Cancel
Save