mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
141 lines
3.9 KiB
Python
141 lines
3.9 KiB
Python
import json
|
|
from textwrap import dedent
|
|
from typing import List
|
|
|
|
from langchain.prompts import PromptTemplate
|
|
from langchain.schema import BaseOutputParser
|
|
|
|
from langchain_experimental.tot.thought import ThoughtValidity
|
|
|
|
|
|
def get_cot_prompt() -> PromptTemplate:
    """Build the chain-of-thought prompt used to generate one thought at a time.

    Returns:
        A jinja2-format ``PromptTemplate`` that takes ``problem_description``
        and ``thoughts``. The THOUGHTS section is rendered only when
        ``thoughts`` is truthy.
    """
    # dedent() removes the shared source indentation; strip() drops the
    # leading/trailing newlines introduced by the triple-quoted literal.
    cot_template = dedent(
        """
        You are an intelligent agent that is generating one thought at a time in
        a tree of thoughts setting.

        PROBLEM

        {{problem_description}}

        {% if thoughts %}
        THOUGHTS

        {% for thought in thoughts %}
        {{ thought }}
        {% endfor %}
        {% endif %}

        Let's think step by step.
        """
    ).strip()

    return PromptTemplate(
        input_variables=["problem_description", "thoughts"],
        template=cot_template,
        template_format="jinja2",
    )
|
|
|
|
|
|
class JSONListOutputParser(BaseOutputParser):
    """Parse the fenced JSON payload of a propose-prompt response.

    The reply is expected to contain a markdown code block that opens with
    "```json" and closes with "```", holding a JSON list of strings.
    """

    @property
    def _type(self) -> str:
        """Return the identifier string of this parser."""
        return "json_list"

    def parse(self, text: str) -> List[str]:
        """Extract and decode the JSON found after the "```json" fence.

        Args:
            text: Raw output of the LLM call.

        Returns:
            The value decoded by ``json.loads`` (a list of strings when the
            model follows the requested format). An empty list is returned
            when the reply has no "```json" fence or when the fenced content
            is not valid JSON.
        """
        segments = text.split("```json")
        if len(segments) < 2:
            # No opening fence at all: handle it like undecodable output
            # instead of letting an IndexError escape from the indexing below.
            return []

        # Keep the text between the first and (if any) second opening fence,
        # then trim surrounding whitespace and any backticks left over from
        # the closing fence. str.strip treats its argument as a character set.
        json_string = segments[1].strip().strip("```").strip()
        try:
            return json.loads(json_string)
        except json.JSONDecodeError:
            return []
|
|
|
|
|
|
def get_propose_prompt() -> PromptTemplate:
    """Build the prompt that asks for several candidate next thoughts.

    Returns:
        A jinja2-format ``PromptTemplate`` that takes ``problem_description``,
        ``thoughts`` and ``n``, with a ``JSONListOutputParser`` attached as
        its output parser. The closing question is phrased differently
        depending on whether ``thoughts`` is truthy.
    """
    # dedent() removes the shared source indentation; strip() drops the
    # leading/trailing whitespace introduced by the triple-quoted literal.
    propose_template = dedent(
        """
        You are an intelligent agent that is generating thoughts in a tree of
        thoughts setting.

        The output should be a markdown code snippet formatted as a JSON list of
        strings, including the leading and trailing "```json" and "```":

        ```json
        [
        "<thought-1>",
        "<thought-2>",
        "<thought-3>"
        ]
        ```

        PROBLEM

        {{ problem_description }}

        {% if thoughts %}
        VALID THOUGHTS

        {% for thought in thoughts %}
        {{ thought }}
        {% endfor %}

        Possible next {{ n }} valid thoughts based on the last valid thought:
        {% else %}

        Possible next {{ n }} valid thoughts based on the PROBLEM:
        {%- endif -%}
        """
    ).strip()
    list_parser = JSONListOutputParser()

    return PromptTemplate(
        input_variables=["problem_description", "thoughts", "n"],
        template=propose_template,
        template_format="jinja2",
        output_parser=list_parser,
    )
|
|
|
|
|
|
class CheckerOutputParser(BaseOutputParser):
    """Translate the checker model's verdict text into a ``ThoughtValidity``."""

    def parse(self, text: str) -> ThoughtValidity:
        """Classify the reply by case-insensitive keyword search.

        Args:
            text: Raw output of the language model.

        Returns:
            ``ThoughtValidity.INVALID`` if the text contains "INVALID",
            otherwise ``VALID_INTERMEDIATE`` if it contains "INTERMEDIATE",
            otherwise ``VALID_FINAL`` if it contains "VALID"; ``INVALID``
            when none of the keywords is present.
        """
        verdict = text.upper()
        # Order matters: "INVALID" contains "VALID", so it is tested first.
        keyword_table = (
            ("INVALID", ThoughtValidity.INVALID),
            ("INTERMEDIATE", ThoughtValidity.VALID_INTERMEDIATE),
            ("VALID", ThoughtValidity.VALID_FINAL),
        )
        for keyword, validity in keyword_table:
            if keyword in verdict:
                return validity
        # No recognised keyword in the reply.
        return ThoughtValidity.INVALID

    @property
    def _type(self) -> str:
        """Return the identifier string of this parser."""
        return "tot_llm_checker_output"
|
|
|
|
|
|
# Prompt asking a model to judge the last thought of a chain. It uses the
# default (non-jinja2) template format, hence the single-brace placeholders.
# The reply is parsed by CheckerOutputParser.
CHECKER_PROMPT = PromptTemplate(
    output_parser=CheckerOutputParser(),
    input_variables=["problem_description", "thoughts"],
    # dedent() removes the shared source indentation and strip() the leading
    # newline; the text deliberately ends mid-sentence for the model to finish.
    template=dedent(
        """
        You are an intelligent agent, validating thoughts of another intelligent agent.

        PROBLEM

        {problem_description}

        THOUGHTS

        {thoughts}

        Evaluate the thoughts and respond with one word.

        - Respond VALID if the last thought is a valid final solution to the
        problem.
        - Respond INVALID if the last thought is invalid.
        - Respond INTERMEDIATE if the last thought is valid but not the final
        solution to the problem.

        This chain of thoughts is"""
    ).strip(),
)
|