@ -1,15 +1,46 @@
""" Chain that interprets a prompt and executes bash code to perform bash operations. """
""" Chain that interprets a prompt and executes bash code to perform bash operations. """
from typing import Dict , List
import logging
import re
from typing import Any , Dict , List
from pydantic import Extra
from pydantic import Extra , Field
from langchain . chains . base import Chain
from langchain . chains . base import Chain
from langchain . chains . llm import LLMChain
from langchain . chains . llm import LLMChain
from langchain . chains . llm_bash . prompt import PROMPT
from langchain . chains . llm_bash . prompt import PROMPT
from langchain . prompts . base import BasePromptTemplate
from langchain . prompts . base import BasePromptTemplate
from langchain . schema import BaseLanguageModel
from langchain . schema import BaseLanguageModel , BaseOutputParser , OutputParserException
from langchain . utilities . bash import BashProcess
from langchain . utilities . bash import BashProcess
logger = logging . getLogger ( __name__ )
class BashOutputParser ( BaseOutputParser ) :
""" Parser for bash output. """
def parse ( self , text : str ) - > List [ str ] :
if " ```bash " in text :
return self . get_code_blocks ( text )
else :
raise OutputParserException (
f " Failed to parse bash output. Got: { text } " ,
)
@staticmethod
def get_code_blocks ( t : str ) - > List [ str ] :
""" Get multiple code blocks from the LLM result. """
code_blocks : List [ str ] = [ ]
# Bash markdown code blocks
pattern = re . compile ( r " ```bash(.*?)(?: \ n \ s*)``` " , re . DOTALL )
for match in pattern . finditer ( t ) :
matched = match . group ( 1 ) . strip ( )
if matched :
code_blocks . extend (
[ line for line in matched . split ( " \n " ) if line . strip ( ) ]
)
return code_blocks
class LLMBashChain ( Chain ) :
class LLMBashChain ( Chain ) :
""" Chain that interprets a prompt and executes bash code to perform bash operations.
""" Chain that interprets a prompt and executes bash code to perform bash operations.
@ -26,6 +57,8 @@ class LLMBashChain(Chain):
input_key : str = " question " #: :meta private:
input_key : str = " question " #: :meta private:
output_key : str = " answer " #: :meta private:
output_key : str = " answer " #: :meta private:
prompt : BasePromptTemplate = PROMPT
prompt : BasePromptTemplate = PROMPT
output_parser : BaseOutputParser = Field ( default_factory = BashOutputParser )
bash_process : BashProcess = Field ( default_factory = BashProcess ) #: :meta private:
class Config :
class Config :
""" Configuration for this pydantic object. """
""" Configuration for this pydantic object. """
@ -51,29 +84,40 @@ class LLMBashChain(Chain):
def _call ( self , inputs : Dict [ str , str ] ) - > Dict [ str , str ] :
def _call ( self , inputs : Dict [ str , str ] ) - > Dict [ str , str ] :
llm_executor = LLMChain ( prompt = self . prompt , llm = self . llm )
llm_executor = LLMChain ( prompt = self . prompt , llm = self . llm )
bash_executor = BashProcess ( )
self . callback_manager . on_text ( inputs [ self . input_key ] , verbose = self . verbose )
self . callback_manager . on_text ( inputs [ self . input_key ] , verbose = self . verbose )
t = llm_executor . predict ( question = inputs [ self . input_key ] )
t = llm_executor . predict ( question = inputs [ self . input_key ] )
self . callback_manager . on_text ( t , color = " green " , verbose = self . verbose )
self . callback_manager . on_text ( t , color = " green " , verbose = self . verbose )
t = t . strip ( )
t = t . strip ( )
if t . startswith ( " ```bash " ) :
try :
# Split the string into a list of substrings
command_list = self . output_parser . parse ( t )
command_list = t . split ( " \n " )
except OutputParserException as e :
print ( command_list )
self . callback_manager . on_chain_error ( e , verbose = self . verbose )
raise e
# Remove the first and last substrings
command_list = [ s for s in command_list [ 1 : - 1 ] ]
if self . verbose :
output = bash_executor . run ( command_list )
self . callback_manager . on_text ( " \n Code: " , verbose = self . verbose )
self . callback_manager . on_text (
self . callback_manager . on_text ( " \n Answer: " , verbose = self . verbose )
str ( command_list ) , color = " yellow " , verbose = self . verbose
self . callback_manager . on_text ( output , color = " yellow " , verbose = self . verbose )
)
else :
output = self . bash_process . run ( command_list )
raise ValueError ( f " unknown format from LLM: { t } " )
self . callback_manager . on_text ( " \n Answer: " , verbose = self . verbose )
self . callback_manager . on_text ( output , color = " yellow " , verbose = self . verbose )
return { self . output_key : output }
return { self . output_key : output }
@property
@property
def _chain_type ( self ) - > str :
def _chain_type ( self ) - > str :
return " llm_bash_chain "
return " llm_bash_chain "
@classmethod
def from_bash_process (
cls ,
bash_process : BashProcess ,
llm : BaseLanguageModel ,
* * kwargs : Any ,
) - > " LLMBashChain " :
""" Create a LLMBashChain from a BashProcess. """
return cls ( llm = llm , bash_process = bash_process , * * kwargs )