mirror of
https://github.com/hwchase17/langchain
synced 2024-10-29 17:07:25 +00:00
e7e5cb9d08
# [WIP] Tree of Thought: introducing a new ToTChain. This PR adds a new chain called ToTChain that implements the ["Large Language Model Guided Tree-of-Thought"](https://arxiv.org/pdf/2305.08291.pdf) paper. There's a notebook example `docs/modules/chains/examples/tot.ipynb` that shows how to use it. Implements #4975. ## Who can review? Community members can review the PR once tests pass. Tag maintainers/contributors who might be interested: - @hwchase17 - @vowelparrot --------- Co-authored-by: Vadim Gubergrits <vgubergrits@outbox.com> Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
152 lines
5.5 KiB
Python
152 lines
5.5 KiB
Python
import re
|
|
import unittest
|
|
from typing import Tuple
|
|
|
|
import pytest
|
|
|
|
from langchain_experimental.tot.base import ToTChain
|
|
from langchain_experimental.tot.checker import ToTChecker
|
|
from langchain_experimental.tot.controller import ToTController
|
|
from langchain_experimental.tot.memory import ToTDFSMemory
|
|
from langchain_experimental.tot.thought import Thought, ThoughtValidity
|
|
from langchain_experimental.tot.thought_generation import SampleCoTStrategy
|
|
from tests.unit_tests.fake_llm import FakeLLM
|
|
|
|
# The 4x4 puzzle: cells are comma-separated, rows are separated by "|", and
# "*" marks an empty cell.
sudoku_puzzle = "3,*,*,2|1,*,3,*|*,1,*,3|4,*,*,1"

# Scripted LLM responses used (keyed by position) by the ``fake_llm_sudoku``
# fixture. Each trailing comment is the validity ``SudokuChecker`` assigns to
# that step when its "*" cells are treated as wildcards against
# ``sudoku_solution``.
solutions = [
    # INVALID: the third cell of the first row is 4, but the solution has 1.
    # (The original script annotated this step as VALID_INTERMEDIATE, with the
    # next three steps as its children c=1..3; the checker rejects it.)
    "3,*,4,2|1,*,3,*|*,1,*,3|4,*,*,1",
    "3,4,1,2|1,6,3,*|*,1,*,3|4,*,*,1",  # INVALID
    "3,4,1,2|1,7,3,*|*,1,*,3|4,*,*,1",  # INVALID
    "3,4,1,2|1,8,3,*|*,1,*,3|4,*,*,1",  # INVALID
    "3,4,1,2|1,2,3,*|*,1,*,3|4,*,*,1",  # VALID_INTERMEDIATE
    "3,1,4,2|1,*,3,*|*,1,*,3|4,*,*,1",  # INVALID
    "3,4,1,2|1,2,3,4|*,1,*,3|4,*,*,1",  # VALID_INTERMEDIATE
    "3,4,1,2|1,2,3,4|4,1,*,3|4,*,*,1",  # INVALID
    "3,4,1,2|1,2,3,4|2,1,4,3|4,*,*,1",  # VALID_INTERMEDIATE
    "3,4,1,2|1,2,3,4|2,1,4,3|4,3,*,1",  # VALID_INTERMEDIATE
    "3,4,1,2|1,2,3,4|2,1,4,3|4,3,2,1",  # VALID_FINAL
]

# The unique completed grid consistent with ``sudoku_puzzle``.
sudoku_solution = "3,4,1,2|1,2,3,4|2,1,4,3|4,3,2,1"
|
|
|
|
|
|
@pytest.fixture
def fake_llm_sudoku() -> FakeLLM:
    """Provide a FakeLLM scripted with the sudoku steps from ``solutions``.

    Every step is stripped of surrounding whitespace and keyed by its index.
    The model is built with ``sequential_responses=True`` (presumably so the
    n-th call returns the n-th step).
    """
    scripted_steps = dict(enumerate(step.strip() for step in solutions))
    return FakeLLM(queries=scripted_steps, sequential_responses=True)
|
|
|
|
|
|
class SudokuChecker(ToTChecker):
    """Deterministic checker that grades thoughts against ``sudoku_solution``.

    A thought is expected to be a grid written as comma-separated cells, with
    rows separated by ``|`` and ``*`` marking an unfilled cell.
    """

    def evaluate(
        self, problem_description: str, thoughts: Tuple[str, ...] = ()
    ) -> ThoughtValidity:
        """Grade the most recent thought on the path.

        Args:
            problem_description: Not used; the expected answer is the
                module-level ``sudoku_solution``.
            thoughts: Thoughts on the current path; only the last one is graded.

        Returns:
            ``VALID_FINAL`` if the last thought contains the full solution,
            ``VALID_INTERMEDIATE`` if it is consistent with the solution when
            each ``*`` is treated as a single-character wildcard, and
            ``INVALID`` otherwise, including when there is no thought to grade.
        """
        if not thoughts:
            # Nothing to grade (e.g. the default ``thoughts=()``); indexing
            # ``thoughts[-1]`` would raise IndexError.
            return ThoughtValidity.INVALID

        # Drop spaces and double quotes the LLM may put around the grid.
        clean_solution = thoughts[-1].replace(" ", "").replace('"', "")
        if sudoku_solution in clean_solution:
            return ThoughtValidity.VALID_FINAL

        # Escape the whole thought so stray regex metacharacters are matched
        # literally, then turn the escaped "*" placeholders into wildcards.
        pattern = re.escape(clean_solution).replace(re.escape("*"), ".")
        if re.search(pattern, sudoku_solution):
            return ThoughtValidity.VALID_INTERMEDIATE
        return ThoughtValidity.INVALID
|
|
|
|
|
|
def test_solve_sudoku(fake_llm_sudoku: FakeLLM) -> None:
    """Test that ToTChain reaches the known solution when ``k`` covers every step.

    ``k`` is set to the number of scripted responses in ``solutions``
    (presumably ``k`` caps how many thoughts the chain generates), so the last
    scripted response -- the solved grid -- is reachable.
    """
    tot_chain = ToTChain(
        llm=fake_llm_sudoku,
        checker=SudokuChecker(),
        k=len(solutions),
        c=4,
        tot_strategy_class=SampleCoTStrategy,
    )
    output = tot_chain.run({"problem_description": ""})
    assert output == sudoku_solution
|
|
|
|
|
|
def test_solve_sudoku_k_too_small(fake_llm_sudoku: FakeLLM) -> None:
    """Test that ToTChain misses the solution when ``k`` is one step too small.

    With ``k = len(solutions) - 1``, the final scripted response (the solved
    grid) is presumably never generated, so the output must differ from
    ``sudoku_solution``.
    """
    tot_chain = ToTChain(
        llm=fake_llm_sudoku,
        checker=SudokuChecker(),
        k=len(solutions) - 1,
        c=4,
        tot_strategy_class=SampleCoTStrategy,
    )
    output = tot_chain.run({"problem_description": ""})
    assert output != sudoku_solution
|
|
|
|
|
|
@pytest.fixture
def fake_llm_checker() -> FakeLLM:
    """Provide a FakeLLM whose scripted replies are thought-validity labels.

    The replies are the upper- and lower-case spellings of "VALID", "INVALID",
    and "INTERMEDIATE", followed by "SOMETHING ELSE", which matches none of
    them. Replies are keyed by their index.
    """
    labels = (
        "VALID",
        "valid",
        "INVALID",
        "invalid",
        "INTERMEDIATE",
        "intermediate",
        "SOMETHING ELSE",
    )
    return FakeLLM(queries=dict(enumerate(labels)), sequential_responses=True)
|
|
|
|
|
|
class ControllerTestCase(unittest.TestCase):
    """Tests for the thought path ``ToTController`` returns from a memory."""

    def setUp(self) -> None:
        # All tests use c=3; the rollback tests rely on it by giving ``b``
        # exactly three children.
        self.controller = ToTController(c=3)

    def test_empty(self) -> None:
        """An empty memory produces an empty path."""
        memory = ToTDFSMemory([])
        self.assertEqual(self.controller(memory), ())

    def test_one_thought(self) -> None:
        """A single stored thought is returned as a one-element path."""
        thoughts = [Thought(text="a", validity=ThoughtValidity.VALID_FINAL)]
        memory = ToTDFSMemory(thoughts)
        self.assertEqual(self.controller(memory), ("a",))

    def test_two_thoughts(self) -> None:
        """Two valid intermediate thoughts are both kept on the path."""
        memory = ToTDFSMemory(
            [
                Thought(text="a", validity=ThoughtValidity.VALID_INTERMEDIATE),
                Thought(text="b", validity=ThoughtValidity.VALID_INTERMEDIATE),
            ]
        )
        self.assertEqual(self.controller(memory), ("a", "b"))

    def test_two_thoughts_invalid(self) -> None:
        """An invalid last thought is dropped, leaving only its parent."""
        memory = ToTDFSMemory(
            [
                Thought(text="a", validity=ThoughtValidity.VALID_INTERMEDIATE),
                Thought(text="b", validity=ThoughtValidity.INVALID),
            ]
        )
        self.assertEqual(self.controller(memory), ("a",))

    def test_thoughts_rollback(self) -> None:
        """A valid thought whose parent already has c children rolls back.

        ``b`` has three children (equal to c), so the path returns to ``a``.
        """
        a = Thought(text="a", validity=ThoughtValidity.VALID_INTERMEDIATE)
        b = Thought(text="b", validity=ThoughtValidity.VALID_INTERMEDIATE)
        c_1 = Thought(text="c_1", validity=ThoughtValidity.VALID_INTERMEDIATE)
        c_2 = Thought(text="c_2", validity=ThoughtValidity.VALID_INTERMEDIATE)
        c_3 = Thought(text="c_3", validity=ThoughtValidity.VALID_INTERMEDIATE)

        a.children = {b}
        b.children = {c_1, c_2, c_3}

        memory = ToTDFSMemory([a, b, c_3])
        self.assertEqual(self.controller(memory), ("a",))

    def test_thoughts_rollback_invalid(self) -> None:
        """An invalid thought whose parent already has c children rolls back.

        ``c_3`` is invalid and ``b`` has three children, so the path returns
        to ``a``.
        """
        a = Thought(text="a", validity=ThoughtValidity.VALID_INTERMEDIATE)
        b = Thought(text="b", validity=ThoughtValidity.VALID_INTERMEDIATE)
        c_1 = Thought(text="c_1", validity=ThoughtValidity.VALID_INTERMEDIATE)
        c_2 = Thought(text="c_2", validity=ThoughtValidity.VALID_INTERMEDIATE)
        c_3 = Thought(text="c_3", validity=ThoughtValidity.INVALID)

        a.children = {b}
        b.children = {c_1, c_2, c_3}

        memory = ToTDFSMemory([a, b, c_3])
        self.assertEqual(self.controller(memory), ("a",))
|