forked from Archives/langchain
establish router
parent
8869b0ab0e
commit
6f55fa8ba7
@ -0,0 +1,75 @@
|
||||
"""Chain that takes in an input and produces an action and action input."""
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.llm import LLMChain
|
||||
|
||||
|
||||
class RouterChain(Chain, BaseModel, ABC):
|
||||
"""Chain responsible for deciding the action to take."""
|
||||
|
||||
input_key: str = "input_text" #: :meta private:
|
||||
action_key: str = "action" #: :meta private:
|
||||
action_input_key: str = "action_input" #: :meta private:
|
||||
log_key: str = "log" #: :meta private:
|
||||
|
||||
@property
|
||||
def input_keys(self) -> List[str]:
|
||||
"""Will be the input key.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
return [self.input_key]
|
||||
|
||||
@property
|
||||
def output_keys(self) -> List[str]:
|
||||
"""Return three keys: the action, the action input, and the log.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
return [self.action_key, self.action_input_key, self.log_key]
|
||||
|
||||
@abstractmethod
|
||||
def get_action_and_input(self, text: str) -> Tuple[str, str, str]:
|
||||
"""Return action, action input, and log (in that order)."""
|
||||
|
||||
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
|
||||
action, action_input, log = self.get_action_and_input(inputs[self.input_key])
|
||||
return {
|
||||
self.action_key: action,
|
||||
self.action_input_key: action_input,
|
||||
self.log_key: log,
|
||||
}
|
||||
|
||||
|
||||
class LLMRouterChain(RouterChain, BaseModel, ABC):
|
||||
"""RouterChain that uses an LLM."""
|
||||
|
||||
llm_chain: LLMChain
|
||||
stops: Optional[List[str]]
|
||||
|
||||
@abstractmethod
|
||||
def _extract_action_and_input(self, text: str) -> Optional[Tuple[str, str]]:
|
||||
"""Extract action and action input from llm output."""
|
||||
|
||||
def _fix_text(self, text: str) -> str:
|
||||
"""Fix the text."""
|
||||
raise ValueError("fix_text not implemented for this router.")
|
||||
|
||||
def get_action_and_input(self, text: str) -> Tuple[str, str, str]:
|
||||
"""Return action, action input, and log (in that order)."""
|
||||
input_key = self.llm_chain.input_keys[0]
|
||||
inputs = {input_key: text, "stop": self.stops}
|
||||
full_output = self.llm_chain.predict(**inputs)
|
||||
parsed_output = self._extract_action_and_input(full_output)
|
||||
while parsed_output is None:
|
||||
full_output = self._fix_text(full_output)
|
||||
inputs = {input_key: text + full_output, "stop": self.stops}
|
||||
output = self.llm_chain.predict(**inputs)
|
||||
full_output += output
|
||||
parsed_output = self._extract_action_and_input(full_output)
|
||||
action, action_input = parsed_output
|
||||
return action, action_input, full_output
|
Loading…
Reference in New Issue