mirror of
https://github.com/hwchase17/langchain
synced 2024-10-29 17:07:25 +00:00
e7e5cb9d08
# [WIP] Tree of Thought: introducing a new ToTChain. This PR adds a new chain called ToTChain that implements the ["Large Language Model Guided Tree-of-Thought"](https://arxiv.org/pdf/2305.08291.pdf) paper. There's a notebook example `docs/modules/chains/examples/tot.ipynb` that shows how to use it. Implements #4975 ## Who can review? Community members can review the PR once tests pass. Tag maintainers/contributors who might be interested: - @hwchase17 - @vowelparrot --------- Co-authored-by: Vadim Gubergrits <vgubergrits@outbox.com> Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
55 lines
1.6 KiB
Python
55 lines
1.6 KiB
Python
from typing import Tuple
|
|
|
|
from langchain_experimental.tot.memory import ToTDFSMemory
|
|
from langchain_experimental.tot.thought import ThoughtValidity
|
|
|
|
|
|
class ToTController:
    """Simple Tree of Thought (ToT) controller.

    This is the controller the ToT paper refers to as the "Simple
    Controller". Its single knob, ``c``, caps how many child thoughts are
    explored under a node before the search backtracks.
    """

    def __init__(self, c: int = 3):
        """Create a controller.

        Args:
            c: Maximum number of children to explore under each node.
        """
        self.c = c

    def __call__(self, memory: ToTDFSMemory) -> Tuple[str, ...]:
        """Backtrack in ``memory`` if needed and return the current path.

        Looks at the most recent thought in ``memory`` (and its parent) and,
        depending on that thought's validity and how many children have
        already been explored, pops entries off ``memory`` to backtrack.

        Args:
            memory: The depth-first-search memory of the ongoing search.

        Returns:
            The ``text`` of every thought in ``memory.current_path()`` after
            any backtracking, in the order that path yields them.
        """
        current = memory.top()
        parent = memory.top_parent()

        # With no thought recorded yet, treat the state as a valid
        # intermediate one (so rule 2 below may still apply).
        if current is None:
            status = ThoughtValidity.VALID_INTERMEDIATE
        else:
            status = current.validity

        if status == ThoughtValidity.INVALID:
            # Rule 1: drop the invalid partial solution. If the thought we
            # fall back to has already had ``c`` children explored, drop it
            # as well.
            memory.pop()
            fallback = memory.top()
            if fallback and len(fallback.children) >= self.c:
                memory.pop()
        elif status == ThoughtValidity.VALID_INTERMEDIATE:
            # Rule 2: the partial solution is valid, but its parent has
            # already had ``c`` children explored without reaching a final
            # solution, so back up past the parent via ``memory.pop(2)``.
            if parent and len(parent.children) >= self.c:
                memory.pop(2)

        return tuple(node.text for node in memory.current_path())
|