langchain/libs/experimental/langchain_experimental/tot/controller.py
Vadim Gubergrits e7e5cb9d08
Tree of Thought: introducing a new ToTChain. (#5167)
# [WIP] Tree of Thought: introducing a new ToTChain.

This PR adds a new chain called ToTChain that implements the ["Large
Language Model Guided
Tree-of-Thought"](https://arxiv.org/pdf/2305.08291.pdf) paper.
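The search idea behind ToTChain can be illustrated with a small, self-contained sketch. A proposer suggests the next step of a partial solution, a checker labels each partial solution with one of the three verdicts used by `ThoughtValidity` (valid intermediate, valid final, invalid), and the search backtracks after trying at most `c` children per node. This is a simplified recursive depth-first version written only for illustration, not the stack-based implementation in this PR. `propose` and `check` are hypothetical stand-ins for the LLM and the checker, and the digit-sum puzzle is an invented toy problem.

```python
from enum import Enum
from typing import Callable, List, Optional


class Validity(Enum):
    """Hypothetical mirror of the checker's three verdicts."""

    VALID_INTERMEDIATE = "valid_intermediate"
    VALID_FINAL = "valid_final"
    INVALID = "invalid"


def tot_dfs(
    propose: Callable[[List[int]], List[int]],
    check: Callable[[List[int]], Validity],
    c: int,
    path: Optional[List[int]] = None,
) -> Optional[List[int]]:
    """Depth-first search over partial solutions, trying at most c children per node."""
    path = [] if path is None else path
    for step in propose(path)[:c]:  # bounded branching: at most c children
        candidate = path + [step]
        verdict = check(candidate)
        if verdict is Validity.VALID_FINAL:
            return candidate  # a complete, valid solution
        if verdict is Validity.VALID_INTERMEDIATE:
            found = tot_dfs(propose, check, c, candidate)
            if found is not None:
                return found
        # INVALID thought or failed subtree: move on to the next sibling.
    return None  # all children failed: backtrack to the parent


# Hypothetical toy problem: pick 3 digits that sum to exactly 12.
TARGET, LENGTH = 12, 3


def propose(path: List[int]) -> List[int]:
    # Stand-in for the LLM: always suggests the same candidates, largest first.
    return [5, 4, 3, 2, 1]


def check(path: List[int]) -> Validity:
    # Stand-in for a rule-based checker.
    if sum(path) > TARGET:
        return Validity.INVALID
    if len(path) == LENGTH:
        return Validity.VALID_FINAL if sum(path) == TARGET else Validity.INVALID
    return Validity.VALID_INTERMEDIATE


print(tot_dfs(propose, check, c=5))  # [5, 5, 2]
print(tot_dfs(propose, check, c=2))  # [4, 4, 4]
print(tot_dfs(propose, check, c=1))  # None
```

Lowering `c` narrows the search: with `c=1` only the greedy branch `5, 5, 5` is tried, the checker rejects it, and no solution is found.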

There's a notebook example `docs/modules/chains/examples/tot.ipynb` that
shows how to use it.


Implements #4975


## Who can review?

Community members can review the PR once tests pass. Tag
maintainers/contributors who might be interested:

- @hwchase17
- @vowelparrot

---------

Co-authored-by: Vadim Gubergrits <vgubergrits@outbox.com>
Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
2023-07-26 21:29:39 -07:00


```python
from typing import Tuple

from langchain_experimental.tot.memory import ToTDFSMemory
from langchain_experimental.tot.thought import ThoughtValidity


class ToTController:
    """
    Tree of Thought (ToT) controller.

    This is a version of a ToT controller, called the "Simple Controller" in
    the paper.

    It has one parameter, `c`, which is the number of children to explore for
    each thought before backtracking.
    """

    def __init__(self, c: int = 3):
        """
        Initialize the controller.

        Args:
            c: The number of children to explore at each node.
        """
        self.c = c

    def __call__(self, memory: ToTDFSMemory) -> Tuple[str, ...]:
        """
        Backtrack if needed, then return the texts of the thoughts on the
        current path.

        Args:
            memory: The DFS memory holding the current path of thoughts.
        """
        next_thought = memory.top()
        parent_thought = memory.top_parent()
        # An empty memory is treated like a valid intermediate state.
        validity = (
            ThoughtValidity.VALID_INTERMEDIATE
            if next_thought is None
            else next_thought.validity
        )

        # 1. If the current partial solution is invalid, backtrack to the
        # parent thought. If that parent has now explored `c` children,
        # backtrack once more.
        if validity == ThoughtValidity.INVALID:
            memory.pop()
            next_thought = memory.top()
            if next_thought and len(next_thought.children) >= self.c:
                memory.pop()

        # 2. If the current partial solution is valid but its parent has
        # already explored `c` children without finding a final solution,
        # backtrack to the parent thought.
        elif (
            validity == ThoughtValidity.VALID_INTERMEDIATE
            and parent_thought
            and len(parent_thought.children) >= self.c
        ):
            memory.pop(2)

        return tuple(thought.text for thought in memory.current_path())
```
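To see the two backtracking rules of `ToTController.__call__` in isolation, here is a self-contained sketch that applies the same logic to a hypothetical in-memory stack. `Validity`, `Node`, `StackMemory`, `simple_controller` and `run` are illustrative names, not part of `langchain_experimental`. In particular, `StackMemory.pop(n)` is assumed to discard the top `n` thoughts, and `store` is assumed to record a new thought as a child of the current one before descending into it. Each step stores a thought and then asks the controller for the path to continue from.

```python
from dataclasses import dataclass, field
from enum import Enum
from typing import List, Optional, Set, Tuple


class Validity(Enum):
    """Hypothetical stand-in for ThoughtValidity."""

    VALID_INTERMEDIATE = "valid_intermediate"
    VALID_FINAL = "valid_final"  # unused in this sketch
    INVALID = "invalid"


@dataclass(eq=False)  # eq=False keeps identity hashing, so nodes can live in a set
class Node:
    text: str
    validity: Validity
    children: Set["Node"] = field(default_factory=set)


class StackMemory:
    """Hypothetical DFS memory: the stack is the path from the root to the current thought."""

    def __init__(self) -> None:
        self.stack: List[Node] = []

    def top(self) -> Optional[Node]:
        return self.stack[-1] if self.stack else None

    def top_parent(self) -> Optional[Node]:
        return self.stack[-2] if len(self.stack) > 1 else None

    def store(self, node: Node) -> None:
        # Register the new thought as a child of the current one, then descend into it.
        if self.stack:
            self.stack[-1].children.add(node)
        self.stack.append(node)

    def pop(self, n: int = 1) -> None:
        # Assumption: pop(n) discards the top n thoughts, i.e. backtracks n levels.
        del self.stack[max(0, len(self.stack) - n):]

    def current_path(self) -> List[Node]:
        return list(self.stack)


def simple_controller(memory: StackMemory, c: int) -> Tuple[str, ...]:
    """Apply the same two backtracking rules as ToTController.__call__."""
    node = memory.top()
    parent = memory.top_parent()
    validity = Validity.VALID_INTERMEDIATE if node is None else node.validity
    if validity == Validity.INVALID:
        # Rule 1: drop the invalid thought; if its parent has used up c children, drop it too.
        memory.pop()
        node = memory.top()
        if node and len(node.children) >= c:
            memory.pop()
    elif (
        validity == Validity.VALID_INTERMEDIATE
        and parent
        and len(parent.children) >= c
    ):
        # Rule 2: valid thought, but its parent has already explored c children.
        memory.pop(2)
    return tuple(t.text for t in memory.current_path())


VI, INV = Validity.VALID_INTERMEDIATE, Validity.INVALID


def run(steps: List[Tuple[str, Validity]], c: int = 2) -> List[Tuple[str, ...]]:
    """Store each thought in turn and record the path the controller returns."""
    memory = StackMemory()
    paths = []
    for text, validity in steps:
        memory.store(Node(text, validity))
        paths.append(simple_controller(memory, c))
    return paths


rule1 = run([("root", VI), ("a", VI), ("a1", INV), ("a2", INV)])
rule2 = run([("root", VI), ("a", VI), ("a1", INV), ("a2", VI)])
print(rule1)  # [('root',), ('root', 'a'), ('root', 'a'), ('root',)]
print(rule2)  # [('root',), ('root', 'a'), ('root', 'a'), ('root',)]
```

Both runs end back at `('root',)`, but through different rules. In `rule1`, the second invalid child of `a` exhausts its budget of `c = 2`, so rule 1 pops twice. In `rule2`, the valid `a2` is the second child of `a`, so rule 2 pops two levels at once.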