|
|
@ -1,8 +1,9 @@
|
|
|
|
"""Chain for applying constitutional principles to the outputs of another chain."""
|
|
|
|
"""Chain for applying constitutional principles to the outputs of another chain."""
|
|
|
|
from typing import Any, Dict, List
|
|
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
|
|
|
|
|
|
|
|
from langchain.chains.base import Chain
|
|
|
|
from langchain.chains.base import Chain
|
|
|
|
from langchain.chains.constitutional_ai.models import ConstitutionalPrinciple
|
|
|
|
from langchain.chains.constitutional_ai.models import ConstitutionalPrinciple
|
|
|
|
|
|
|
|
from langchain.chains.constitutional_ai.principles import PRINCIPLES
|
|
|
|
from langchain.chains.constitutional_ai.prompts import CRITIQUE_PROMPT, REVISION_PROMPT
|
|
|
|
from langchain.chains.constitutional_ai.prompts import CRITIQUE_PROMPT, REVISION_PROMPT
|
|
|
|
from langchain.chains.llm import LLMChain
|
|
|
|
from langchain.chains.llm import LLMChain
|
|
|
|
from langchain.prompts.base import BasePromptTemplate
|
|
|
|
from langchain.prompts.base import BasePromptTemplate
|
|
|
@ -42,6 +43,15 @@ class ConstitutionalChain(Chain):
|
|
|
|
critique_chain: LLMChain
|
|
|
|
critique_chain: LLMChain
|
|
|
|
revision_chain: LLMChain
|
|
|
|
revision_chain: LLMChain
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
|
|
|
def get_principles(
|
|
|
|
|
|
|
|
cls, names: Optional[List[str]] = None
|
|
|
|
|
|
|
|
) -> List[ConstitutionalPrinciple]:
|
|
|
|
|
|
|
|
if names is None:
|
|
|
|
|
|
|
|
return list(PRINCIPLES.values())
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
return [PRINCIPLES[name] for name in names]
|
|
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
@classmethod
|
|
|
|
def from_llm(
|
|
|
|
def from_llm(
|
|
|
|
cls,
|
|
|
|
cls,
|
|
|
|