diff --git a/langchain/chains/constitutional_ai/base.py b/langchain/chains/constitutional_ai/base.py index b3ff12f5..7845da22 100644 --- a/langchain/chains/constitutional_ai/base.py +++ b/langchain/chains/constitutional_ai/base.py @@ -18,14 +18,19 @@ class ConstitutionalChain(Chain): from langchain.llms import OpenAI from langchain.chains import LLMChain, ConstitutionalChain + from langchain.chains.constitutional_ai.models \ + import ConstitutionalPrinciple + + llm = OpenAI() qa_prompt = PromptTemplate( template="Q: {question} A:", input_variables=["question"], ) - qa_chain = LLMChain(llm=OpenAI(), prompt=qa_prompt) + qa_chain = LLMChain(llm=llm, prompt=qa_prompt) constitutional_chain = ConstitutionalChain.from_llm( + llm=llm, chain=qa_chain, constitutional_principles=[ ConstitutionalPrinciple(