forked from Archives/langchain
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
47 lines
1.6 KiB
Python
47 lines
1.6 KiB
Python
import re
|
|
|
|
from langchain.base_language import BaseLanguageModel
|
|
from langchain.chains import LLMChain
|
|
from langchain.experimental.plan_and_execute.planners.base import LLMPlanner
|
|
from langchain.experimental.plan_and_execute.schema import (
|
|
Plan,
|
|
PlanOutputParser,
|
|
Step,
|
|
)
|
|
from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate
|
|
from langchain.schema import SystemMessage
|
|
|
|
# Default system prompt for the chat planner: instructs the model to emit a
# "Plan:" header followed by a numbered list of steps, ending with the
# "<END_OF_PLAN>" sentinel that load_chat_planner uses as a stop token.
SYSTEM_PROMPT = (
    "Let's first understand the problem and devise a plan to solve the problem."
    " Please output the plan starting with the header 'Plan:' "
    "and then followed by a numbered list of steps. "
    "Please make the plan the minimum number of steps required "
    "to accurately complete the task. If the task is a question, "
    "the final step should almost always be 'Given the above steps taken, "
    "please respond to the users original question'. "
    "At the end of your plan, say '<END_OF_PLAN>'"
)
|
|
|
|
|
|
class PlanningOutputParser(PlanOutputParser):
    """Parse an LLM planning response into a ``Plan`` of numbered ``Step``s."""

    def parse(self, text: str) -> Plan:
        """Split ``text`` on numbered-list markers (e.g. ``"\\n1. "``).

        The chunk before the first numbered item (the "Plan:" header or any
        other preamble) is discarded by the ``[1:]`` slice; each remaining
        chunk becomes one ``Step``.
        """
        # Raw string: the original non-raw "\n\d+\. " relies on an invalid
        # escape sequence (\d), which warns today and is slated to become an
        # error in future CPython versions. The pattern itself is unchanged.
        steps = [Step(value=v) for v in re.split(r"\n\d+\. ", text)[1:]]
        return Plan(steps=steps)
|
|
|
|
|
|
def load_chat_planner(
    llm: BaseLanguageModel, system_prompt: str = SYSTEM_PROMPT
) -> LLMPlanner:
    """Build an ``LLMPlanner`` that asks a chat model for a numbered plan.

    Args:
        llm: Chat-capable language model used to generate the plan.
        system_prompt: System instructions steering the model toward a
            "Plan:" header plus numbered steps. Defaults to
            ``SYSTEM_PROMPT``.

    Returns:
        An ``LLMPlanner`` wired with the chat prompt, the given model, a
        ``PlanningOutputParser``, and a stop sequence so generation halts
        at the "<END_OF_PLAN>" sentinel.
    """
    messages = [
        SystemMessage(content=system_prompt),
        HumanMessagePromptTemplate.from_template("{input}"),
    ]
    chat_prompt = ChatPromptTemplate.from_messages(messages)
    planner_chain = LLMChain(llm=llm, prompt=chat_prompt)
    return LLMPlanner(
        llm_chain=planner_chain,
        output_parser=PlanningOutputParser(),
        stop=["<END_OF_PLAN>"],
    )
|