You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
langchain/langchain/smart_chains/router.py

91 lines
3.0 KiB
Python

"""Chain that takes in an input and produces an action and action input."""
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Tuple
from pydantic import BaseModel
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
class RouterChain(Chain, BaseModel, ABC):
    """Abstract chain that maps an input text to an action decision.

    Subclasses implement :meth:`get_action_and_input`; ``_call`` packages the
    returned action, action input, and log under the configured output keys.
    """

    input_key: str = "input_text"  #: :meta private:
    action_key: str = "action"  #: :meta private:
    action_input_key: str = "action_input"  #: :meta private:
    log_key: str = "log"  #: :meta private:

    @property
    def input_keys(self) -> List[str]:
        """Return a single-element list holding the input key.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Return the action key, the action input key, and the log key.

        :meta private:
        """
        keys = [self.action_key, self.action_input_key, self.log_key]
        return keys

    @abstractmethod
    def get_action_and_input(self, text: str) -> Tuple[str, str, str]:
        """Return the action, the action input, and the log, in that order."""

    @property
    @abstractmethod
    def observation_prefix(self) -> str:
        """Prefix to append the observation with."""

    @property
    @abstractmethod
    def router_prefix(self) -> str:
        """Prefix to append the router call with."""

    @property
    def finish_action_name(self) -> str:
        """Name of the action that signals the chain should finish."""
        return "Final Answer"

    def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
        """Route ``inputs[self.input_key]`` and return the decision as a dict.

        The result maps ``action_key``, ``action_input_key`` and ``log_key``
        to the three values returned by :meth:`get_action_and_input`.
        """
        source_text = inputs[self.input_key]
        chosen_action, chosen_input, trace = self.get_action_and_input(source_text)
        # Insert in the same order as ``output_keys``.
        result: Dict[str, str] = {}
        result[self.action_key] = chosen_action
        result[self.action_input_key] = chosen_input
        result[self.log_key] = trace
        return result
class LLMRouterChain(RouterChain, BaseModel, ABC):
    """RouterChain that obtains its decision from an LLM.

    The LLM output is parsed by :meth:`_extract_action_and_input`. While
    parsing yields ``None``, the output is passed through :meth:`_fix_text`
    and the LLM is asked to continue from the text produced so far.
    """

    llm_chain: LLMChain
    stops: Optional[List[str]]

    @abstractmethod
    def _extract_action_and_input(self, text: str) -> Optional[Tuple[str, str]]:
        """Extract the action and action input from LLM output.

        ``get_action_and_input`` treats a ``None`` result as "not parseable
        yet" and retries.
        """

    def _fix_text(self, text: str) -> str:
        """Repair LLM output that could not be parsed.

        The base implementation always raises ``ValueError``.
        """
        raise ValueError("fix_text not implemented for this router.")

    def get_action_and_input(self, text: str) -> Tuple[str, str, str]:
        """Return the action, the action input, and the log, in that order.

        The log is the full accumulated LLM output, including any text added
        by ``_fix_text`` and by follow-up predictions.
        """
        # The LLM chain's first input key receives the prompt text.
        prompt_key = self.llm_chain.input_keys[0]

        first_request = {prompt_key: text, "stop": self.stops}
        transcript = self.llm_chain.predict(**first_request)
        parsed = self._extract_action_and_input(transcript)

        while parsed is None:
            transcript = self._fix_text(transcript)
            # Ask the LLM to continue from the original text plus what has
            # been generated so far.
            followup_request = {prompt_key: text + transcript, "stop": self.stops}
            continuation = self.llm_chain.predict(**followup_request)
            transcript += continuation
            parsed = self._extract_action_and_input(transcript)

        action, action_input = parsed
        return action, action_input, transcript