# langchain/libs/experimental/tests/unit_tests/test_tot.py
# (154 lines, 5.6 KiB, Python)

import re
import unittest
from typing import Tuple
import pytest
from langchain_experimental.tot.base import ToTChain
from langchain_experimental.tot.checker import ToTChecker
from langchain_experimental.tot.controller import ToTController
from langchain_experimental.tot.memory import ToTDFSMemory
from langchain_experimental.tot.thought import Thought, ThoughtValidity
from langchain_experimental.tot.thought_generation import SampleCoTStrategy
from tests.unit_tests.fake_llm import FakeLLM
# Starting grid for the sudoku scenario. Rows are separated by "|", cells by
# ",", and "*" stands for a cell that has not been filled in yet.
# NOTE(review): not referenced anywhere in the visible part of this file.
sudoku_puzzle = "3,*,*,2|1,*,3,*|*,1,*,3|4,*,*,1"

# Canned answers handed to the fake LLM by the ``fake_llm_sudoku`` fixture,
# keyed there by list position. Several entries carry leading whitespace on
# purpose; the fixture strips it before use.
# The trailing label on each entry is the author's note on the expected
# checker verdict (and, where given, the attempt counter / rollback).
solutions = [
    # NOTE(review): with SudokuChecker.evaluate as written in this file, this
    # entry does not appear to match the final solution (its third cell is 4,
    # the solution has 1), so it looks like it would be reported as INVALID
    # rather than the label below -- confirm the intended verdict.
    "3,*,4,2|1,*,3,*|*,1,*,3|4,*,*,1",  # VALID_INTERMEDIATE
    " 3,4,1,2|1,6,3,*|*,1,*,3|4,*,*,1",  # INVALID c=1
    " 3,4,1,2|1,7,3,*|*,1,*,3|4,*,*,1",  # INVALID c=2
    " 3,4,1,2|1,8,3,*|*,1,*,3|4,*,*,1",  # INVALID c=3
    " 3,4,1,2|1,2,3,*|*,1,*,3|4,*,*,1",  # VALID_INTERMEDIATE c=4 (rollback)
    "3,1,4,2|1,*,3,*|*,1,*,3|4,*,*,1",  # INVALID (rollback)
    "3,4,1,2|1,2,3,4|*,1,*,3|4,*,*,1",  # VALID_INTERMEDIATE
    " 3,4,1,2|1,2,3,4|4,1,*,3|4,*,*,1",  # INVALID (rollback)
    " 3,4,1,2|1,2,3,4|2,1,4,3|4,*,*,1",  # VALID_INTERMEDIATE
    " 3,4,1,2|1,2,3,4|2,1,4,3|4,3,*,1",  # VALID_INTERMEDIATE
    " 3,4,1,2|1,2,3,4|2,1,4,3|4,3,2,1",  # VALID_FINAL
]

# Fully solved grid; SudokuChecker compares every thought against this string.
sudoku_solution = "3,4,1,2|1,2,3,4|2,1,4,3|4,3,2,1"
@pytest.fixture
def fake_llm_sudoku() -> FakeLLM:
    """Provide a ``FakeLLM`` loaded with the canned sudoku answers.

    Every entry of ``solutions`` is stripped of surrounding whitespace and
    registered under its list position; the fake is created with
    ``sequential_responses=True``.
    """
    stripped_steps = (step.strip() for step in solutions)
    return FakeLLM(
        queries=dict(enumerate(stripped_steps)),
        sequential_responses=True,
    )
class SudokuChecker(ToTChecker):
    """Checker that grades a thought against the known ``sudoku_solution``."""

    def evaluate(
        self, problem_description: str, thoughts: Tuple[str, ...] = ()
    ) -> ThoughtValidity:
        """Grade the most recent thought against ``sudoku_solution``.

        Args:
            problem_description: Not used by this checker.
            thoughts: Thought texts produced so far; only the last one is
                inspected.

        Returns:
            ``VALID_FINAL`` when the cleaned thought contains the complete
            solution, ``VALID_INTERMEDIATE`` when the thought -- with each
            ``*`` standing for any single character -- is found in the
            solution, and ``INVALID`` otherwise.

        Raises:
            IndexError: If ``thoughts`` is empty.
        """
        last_thought = thoughts[-1]
        # Spaces and double quotes are stripped before any comparison.
        clean_solution = last_thought.replace(" ", "").replace('"', "")
        if sudoku_solution in clean_solution:
            return ThoughtValidity.VALID_FINAL

        # Escape the whole thought first so that stray regex metacharacters in
        # the text are matched literally instead of raising ``re.error`` or
        # acting as wildcards; only the (now escaped) "*" placeholder is then
        # turned into the single-character wildcard. For well-formed grids the
        # resulting pattern is identical to escaping "|" by hand.
        regex_solution = re.escape(clean_solution).replace(r"\*", ".")
        if re.search(regex_solution, sudoku_solution):
            return ThoughtValidity.VALID_INTERMEDIATE
        return ThoughtValidity.INVALID
@pytest.mark.requires("jinja2")
def test_solve_sudoku(fake_llm_sudoku: FakeLLM) -> None:
    """Run the chain with ``k`` equal to the number of canned answers.

    The chain's output is expected to equal ``sudoku_solution``.
    """
    step_budget = len(solutions)
    chain = ToTChain(
        llm=fake_llm_sudoku,
        checker=SudokuChecker(),
        k=step_budget,
        c=4,
        tot_strategy_class=SampleCoTStrategy,
    )
    assert chain.run({"problem_description": ""}) == sudoku_solution
@pytest.mark.requires("jinja2")
def test_solve_sudoku_k_too_small(fake_llm_sudoku: FakeLLM) -> None:
    """Run the chain with ``k`` one short of the number of canned answers.

    With that smaller ``k`` the chain's output is expected to differ from
    ``sudoku_solution``.
    """
    step_budget = len(solutions) - 1
    chain = ToTChain(
        llm=fake_llm_sudoku,
        checker=SudokuChecker(),
        k=step_budget,
        c=4,
        tot_strategy_class=SampleCoTStrategy,
    )
    assert chain.run({"problem_description": ""}) != sudoku_solution
@pytest.fixture
def fake_llm_checker() -> FakeLLM:
    """Provide a ``FakeLLM`` loaded with canned validity-verdict strings.

    The replies cover upper- and lower-case spellings of each verdict plus
    one unrecognised string. They are registered under their position and
    the fake is created with ``sequential_responses=True``.
    """
    canned_verdicts = (
        "VALID",
        "valid",
        "INVALID",
        "invalid",
        "INTERMEDIATE",
        "intermediate",
        "SOMETHING ELSE",
    )
    return FakeLLM(
        queries=dict(enumerate(canned_verdicts)),
        sequential_responses=True,
    )
class ControllerTestCase(unittest.TestCase):
    """Checks the thought path ``ToTController`` returns for a given memory."""

    def setUp(self) -> None:
        # Every test uses a fresh controller created with c=3.
        self.controller = ToTController(c=3)

    def test_empty(self) -> None:
        """A memory built from an empty list yields an empty path."""
        self.assertEqual(self.controller(ToTDFSMemory([])), ())

    def test_one_thoghts(self) -> None:
        """A single VALID_FINAL thought is returned as the whole path."""
        only = Thought(text="a", validity=ThoughtValidity.VALID_FINAL)
        self.assertEqual(self.controller(ToTDFSMemory([only])), ("a",))

    def test_two_thoghts(self) -> None:
        """Two VALID_INTERMEDIATE thoughts are both kept in the path."""
        first = Thought(text="a", validity=ThoughtValidity.VALID_INTERMEDIATE)
        second = Thought(text="b", validity=ThoughtValidity.VALID_INTERMEDIATE)
        path = self.controller(ToTDFSMemory([first, second]))
        self.assertEqual(path, ("a", "b"))

    def test_two_thoughts_invalid(self) -> None:
        """An INVALID last thought is left out of the returned path."""
        first = Thought(text="a", validity=ThoughtValidity.VALID_INTERMEDIATE)
        second = Thought(text="b", validity=ThoughtValidity.INVALID)
        path = self.controller(ToTDFSMemory([first, second]))
        self.assertEqual(path, ("a",))

    def test_thoughts_rollback(self) -> None:
        """With three valid children under "b", the path falls back to "a"."""
        root = Thought(text="a", validity=ThoughtValidity.VALID_INTERMEDIATE)
        middle = Thought(text="b", validity=ThoughtValidity.VALID_INTERMEDIATE)
        leaves = [
            Thought(text=label, validity=ThoughtValidity.VALID_INTERMEDIATE)
            for label in ("c_1", "c_2", "c_3")
        ]
        root.children = {middle}
        middle.children = set(leaves)
        # The memory stack ends on the third child of "b".
        memory = ToTDFSMemory([root, middle, leaves[-1]])
        self.assertEqual(self.controller(memory), ("a",))

    def test_thoughts_rollback_invalid(self) -> None:
        """Same tree, but the third child is INVALID; the path is still "a"."""
        root = Thought(text="a", validity=ThoughtValidity.VALID_INTERMEDIATE)
        middle = Thought(text="b", validity=ThoughtValidity.VALID_INTERMEDIATE)
        first_leaf, second_leaf = (
            Thought(text=label, validity=ThoughtValidity.VALID_INTERMEDIATE)
            for label in ("c_1", "c_2")
        )
        last_leaf = Thought(text="c_3", validity=ThoughtValidity.INVALID)
        root.children = {middle}
        middle.children = {first_leaf, second_leaf, last_leaf}
        # The memory stack ends on the invalid third child of "b".
        memory = ToTDFSMemory([root, middle, last_leaf])
        self.assertEqual(self.controller(memory), ("a",))