langchain/libs/experimental/tests/unit_tests/test_tot.py
Vadim Gubergrits e7e5cb9d08
Tree of Thought introducing a new ToTChain. (#5167)
# [WIP] Tree of Thought introducing a new ToTChain.

This PR adds a new chain called ToTChain that implements the ["Large
Language Model Guided
Tree-of-Thought"](https://arxiv.org/pdf/2305.08291.pdf) paper.

There's a notebook example `docs/modules/chains/examples/tot.ipynb` that
shows how to use it.


Implements #4975


## Who can review?

Community members can review the PR once tests pass. Tag
maintainers/contributors who might be interested:

- @hwchase17
- @vowelparrot

---------

Co-authored-by: Vadim Gubergrits <vgubergrits@outbox.com>
Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
2023-07-26 21:29:39 -07:00

152 lines
5.5 KiB
Python

import re
import unittest
from typing import Tuple
import pytest
from langchain_experimental.tot.base import ToTChain
from langchain_experimental.tot.checker import ToTChecker
from langchain_experimental.tot.controller import ToTController
from langchain_experimental.tot.memory import ToTDFSMemory
from langchain_experimental.tot.thought import Thought, ThoughtValidity
from langchain_experimental.tot.thought_generation import SampleCoTStrategy
from tests.unit_tests.fake_llm import FakeLLM
# Starting grid as a single string. "*" is the placeholder that SudokuChecker
# turns into a wildcard; rows appear to be separated by "|" and cells by ",".
sudoku_puzzle = "3,*,*,2|1,*,3,*|*,1,*,3|4,*,*,1"

# Scripted replies used to build the sudoku FakeLLM fixture, one per step.
# Leading spaces are part of the literals here; the fixture strips them.
# The trailing labels are the author's intended verdict for each reply.
# NOTE(review): reply 0 puts 4 in the third cell of row one while
# sudoku_solution has 1 there, so SudokuChecker as written would grade it
# INVALID rather than VALID_INTERMEDIATE -- confirm the intended label.
solutions = [
    "3,*,4,2|1,*,3,*|*,1,*,3|4,*,*,1",  # VALID_INTERMEDIATE
    " 3,4,1,2|1,6,3,*|*,1,*,3|4,*,*,1",  # INVALID c=1
    " 3,4,1,2|1,7,3,*|*,1,*,3|4,*,*,1",  # INVALID c=2
    " 3,4,1,2|1,8,3,*|*,1,*,3|4,*,*,1",  # INVALID c=3
    " 3,4,1,2|1,2,3,*|*,1,*,3|4,*,*,1",  # VALID_INTERMEDIATE c=4 (rollback)
    "3,1,4,2|1,*,3,*|*,1,*,3|4,*,*,1",  # INVALID (rollback)
    "3,4,1,2|1,2,3,4|*,1,*,3|4,*,*,1",  # VALID_INTERMEDIATE
    " 3,4,1,2|1,2,3,4|4,1,*,3|4,*,*,1",  # INVALID (rollback)
    " 3,4,1,2|1,2,3,4|2,1,4,3|4,*,*,1",  # VALID_INTERMEDIATE
    " 3,4,1,2|1,2,3,4|2,1,4,3|4,3,*,1",  # VALID_INTERMEDIATE
    " 3,4,1,2|1,2,3,4|2,1,4,3|4,3,2,1",  # VALID_FINAL
]

# Fully filled grid that the checker and both chain tests compare against.
sudoku_solution = "3,4,1,2|1,2,3,4|2,1,4,3|4,3,2,1"
@pytest.fixture
def fake_llm_sudoku() -> FakeLLM:
    """Build a FakeLLM scripted with the entries of ``solutions``.

    Every entry is stripped of surrounding whitespace and keyed by its
    position in the list; the FakeLLM is created with
    ``sequential_responses=True``.
    """
    scripted_replies = dict(enumerate(reply.strip() for reply in solutions))
    return FakeLLM(queries=scripted_replies, sequential_responses=True)
class SudokuChecker(ToTChecker):
    """Checker that grades a thought by comparing it with ``sudoku_solution``."""

    def evaluate(
        self, problem_description: str, thoughts: Tuple[str, ...] = ()
    ) -> ThoughtValidity:
        """Grade the most recent thought against the known solution.

        Only ``thoughts[-1]`` is inspected; ``problem_description`` is unused.
        Spaces and double quotes are removed from the thought before it is
        compared.

        Args:
            problem_description: Ignored by this checker.
            thoughts: Thought texts so far; the last one is graded.

        Returns:
            ``VALID_FINAL`` if the cleaned thought contains the complete
            solution string; ``VALID_INTERMEDIATE`` if the cleaned thought,
            with each ``*`` standing for any single character, is found in
            the solution string; ``INVALID`` otherwise.

        Raises:
            IndexError: If ``thoughts`` is empty.
        """
        last_thought = thoughts[-1]
        clean_solution = last_thought.replace(" ", "").replace('"', "")
        if sudoku_solution in clean_solution:
            return ThoughtValidity.VALID_FINAL

        # Escape each literal segment so that only "*" acts as a wildcard.
        # Building the pattern from the raw text would let other characters
        # in the thought (".", "+", "(", "[", "\\", ...) be interpreted as
        # regex syntax, which can mis-match or raise re.error.
        literal_segments = clean_solution.split("*")
        regex_solution = ".".join(re.escape(part) for part in literal_segments)
        if re.search(regex_solution, sudoku_solution):
            return ThoughtValidity.VALID_INTERMEDIATE
        return ThoughtValidity.INVALID
def test_solve_sudoku(fake_llm_sudoku: FakeLLM) -> None:
    """Check that ToTChain returns the solution when k covers every reply.

    The chain is built with ``k=len(solutions)`` and ``c=4`` on top of the
    scripted FakeLLM and ``SudokuChecker``; its output must equal
    ``sudoku_solution``.
    """
    step_budget = len(solutions)
    chain = ToTChain(
        llm=fake_llm_sudoku,
        checker=SudokuChecker(),
        k=step_budget,
        c=4,
        tot_strategy_class=SampleCoTStrategy,
    )
    result = chain.run({"problem_description": ""})
    assert result == sudoku_solution
def test_solve_sudoku_k_too_small(fake_llm_sudoku: FakeLLM) -> None:
    """Check that ToTChain misses the solution when k is one reply short.

    Same setup as the full-budget test, except ``k=len(solutions) - 1``;
    the output must differ from ``sudoku_solution``.
    """
    step_budget = len(solutions) - 1
    chain = ToTChain(
        llm=fake_llm_sudoku,
        checker=SudokuChecker(),
        k=step_budget,
        c=4,
        tot_strategy_class=SampleCoTStrategy,
    )
    result = chain.run({"problem_description": ""})
    assert result != sudoku_solution
@pytest.fixture
def fake_llm_checker() -> FakeLLM:
    """Build a FakeLLM scripted with a fixed series of validity words.

    The replies are keyed by position and the FakeLLM is created with
    ``sequential_responses=True``. The series covers upper- and lower-case
    spellings plus one unrecognised phrase.
    """
    verdicts = (
        "VALID",
        "valid",
        "INVALID",
        "invalid",
        "INTERMEDIATE",
        "intermediate",
        "SOMETHING ELSE",
    )
    scripted_replies = {
        position: verdict for position, verdict in enumerate(verdicts)
    }
    return FakeLLM(queries=scripted_replies, sequential_responses=True)
class ControllerTestCase(unittest.TestCase):
    # Exercises ToTController(c=3) against hand-built ToTDFSMemory stacks and
    # checks the tuple of thought texts it returns.

    def setUp(self) -> None:
        # A fresh controller for every test.
        self.controller = ToTController(c=3)

    def test_empty(self) -> None:
        # No stored thoughts -> empty path.
        path = self.controller(ToTDFSMemory([]))
        self.assertEqual(path, ())

    def test_one_thoghts(self) -> None:
        # A single VALID_FINAL thought is returned as the whole path.
        only = Thought(text="a", validity=ThoughtValidity.VALID_FINAL)
        path = self.controller(ToTDFSMemory([only]))
        self.assertEqual(path, ("a",))

    def test_two_thoghts(self) -> None:
        # Two VALID_INTERMEDIATE thoughts are both kept, in stack order.
        first = Thought(text="a", validity=ThoughtValidity.VALID_INTERMEDIATE)
        second = Thought(text="b", validity=ThoughtValidity.VALID_INTERMEDIATE)
        path = self.controller(ToTDFSMemory([first, second]))
        self.assertEqual(path, ("a", "b"))

    def test_two_thoughts_invalid(self) -> None:
        # An INVALID thought on top of the stack is left out of the path.
        first = Thought(text="a", validity=ThoughtValidity.VALID_INTERMEDIATE)
        second = Thought(text="b", validity=ThoughtValidity.INVALID)
        path = self.controller(ToTDFSMemory([first, second]))
        self.assertEqual(path, ("a",))

    def test_thoughts_rollback(self) -> None:
        # "b" already has three children and the controller was built with
        # c=3; the expected path falls back to ("a",) even though the top
        # child is VALID_INTERMEDIATE.
        root = Thought(text="a", validity=ThoughtValidity.VALID_INTERMEDIATE)
        branch = Thought(text="b", validity=ThoughtValidity.VALID_INTERMEDIATE)
        leaf_1 = Thought(text="c_1", validity=ThoughtValidity.VALID_INTERMEDIATE)
        leaf_2 = Thought(text="c_2", validity=ThoughtValidity.VALID_INTERMEDIATE)
        leaf_3 = Thought(text="c_3", validity=ThoughtValidity.VALID_INTERMEDIATE)
        root.children = {branch}
        branch.children = {leaf_1, leaf_2, leaf_3}

        stack = ToTDFSMemory([root, branch, leaf_3])
        path = self.controller(stack)
        self.assertEqual(path, ("a",))

    def test_thoughts_rollback_invalid(self) -> None:
        # Same tree as the rollback test, but the top child is INVALID; the
        # expected path is again ("a",).
        root = Thought(text="a", validity=ThoughtValidity.VALID_INTERMEDIATE)
        branch = Thought(text="b", validity=ThoughtValidity.VALID_INTERMEDIATE)
        leaf_1 = Thought(text="c_1", validity=ThoughtValidity.VALID_INTERMEDIATE)
        leaf_2 = Thought(text="c_2", validity=ThoughtValidity.VALID_INTERMEDIATE)
        leaf_3 = Thought(text="c_3", validity=ThoughtValidity.INVALID)
        root.children = {branch}
        branch.children = {leaf_1, leaf_2, leaf_3}

        stack = ToTDFSMemory([root, branch, leaf_3])
        path = self.controller(stack)
        self.assertEqual(path, ("a",))