You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
langchain/langchain/chains/sql_database/base.py

79 lines
2.4 KiB
Python

"""Chain for interacting with SQL Database."""
from typing import Dict, List
from pydantic import BaseModel, Extra
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.chains.sql_database.prompt import PROMPT
from langchain.input import print_text
from langchain.llms.base import BaseLLM
from langchain.sql_database import SQLDatabase
class SQLDatabaseChain(Chain, BaseModel):
    """Chain for interacting with SQL Database.

    Example:
        .. code-block:: python

            from langchain import SQLDatabaseChain, OpenAI, SQLDatabase
            db = SQLDatabase(...)
            db_chain = SQLDatabaseChain(llm=OpenAI(), database=db)
    """

    llm: BaseLLM
    """LLM wrapper to use."""
    database: SQLDatabase
    """SQL Database to connect to."""
    input_key: str = "query"  #: :meta private:
    output_key: str = "result"  #: :meta private:

    class Config:
        """Configuration for this pydantic object."""

        # Reject unknown constructor kwargs, and allow non-pydantic field
        # types such as SQLDatabase.
        extra = Extra.forbid
        arbitrary_types_allowed = True

    @property
    def input_keys(self) -> List[str]:
        """Return the singular input key.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Return the singular output key.

        :meta private:
        """
        return [self.output_key]

    def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
        """Generate SQL for a question, run it, then answer from the result.

        Two predictions are made with the same ``PROMPT``:

        1. The question followed by ``SQLQuery:`` is sent, and the LLM's
           completion is taken as the SQL command.
        2. That command is run with ``self.database.run``. The question, the
           SQL and its result are then sent back followed by ``Answer:``, and
           that completion becomes the final answer.

        Args:
            inputs: Mapping that holds the user's question under
                ``self.input_key``.

        Returns:
            A one-entry dict mapping ``self.output_key`` to the LLM's answer.
        """
        llm_chain = LLMChain(llm=self.llm, prompt=PROMPT)
        # Ending the text with "SQLQuery:" cues the LLM to write the SQL
        # command next.
        input_text = f"{inputs[self.input_key]} \nSQLQuery:"
        if self.verbose:
            print_text(input_text)
        llm_inputs = {
            "input": input_text,
            "dialect": self.database.dialect,
            "table_info": self.database.table_info,
            # Presumably forwarded by LLMChain as a stop sequence, so the LLM
            # stops before inventing its own "SQLResult:" section.
            # TODO(review): confirm against LLMChain.
            "stop": ["\nSQLResult:"],
        }
        sql_cmd = llm_chain.predict(**llm_inputs)
        if self.verbose:
            print_text(sql_cmd, color="green")
        # NOTE: the LLM-generated SQL is executed verbatim. Scope the
        # database's permissions accordingly.
        result = self.database.run(sql_cmd)
        if self.verbose:
            print_text("\nSQLResult: ")
            print_text(result, color="yellow")
            print_text("\nAnswer:")
        # Extend the transcript with the SQL and its result, so the second
        # prediction answers in the context of what the query returned.
        input_text += f"{sql_cmd}\nSQLResult: {result}\nAnswer:"
        llm_inputs["input"] = input_text
        final_result = llm_chain.predict(**llm_inputs)
        if self.verbose:
            print_text(final_result, color="green")
        return {self.output_key: final_result}