You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
langchain/langchain/chains/sql_database/base.py

76 lines
2.3 KiB
Python

"""Chain for interacting with SQL Database."""
from typing import Dict, List, Any
from pydantic import BaseModel, Extra
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.chains.sql_database.prompt import PROMPT
from langchain.input import ChainedInput
from langchain.llms.base import LLM
from langchain.sql_database import SQLDatabase
from langchain.logger import CONTEXT_KEY
class SQLDatabaseChain(Chain, BaseModel):
"""Chain for interacting with SQL Database.
Example:
.. code-block:: python
from langchain import SQLDatabaseChain, OpenAI, SQLDatabase
db = SQLDatabase(...)
db_chain = SelfAskWithSearchChain(llm=OpenAI(), database=db)
"""
llm: LLM
"""LLM wrapper to use."""
database: SQLDatabase
"""SQL Database to connect to."""
input_key: str = "query" #: :meta private:
output_key: str = "result" #: :meta private:
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True
@property
def input_keys(self) -> List[str]:
"""Return the singular input key.
:meta private:
"""
return [self.input_key]
@property
def output_keys(self) -> List[str]:
"""Return the singular output key.
:meta private:
"""
return [self.output_key]
def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]:
llm_chain = LLMChain(llm=self.llm, prompt=PROMPT)
chained_input = ChainedInput(
inputs[self.input_key] + "\nSQLQuery:", inputs[CONTEXT_KEY], logger=self.logger
)
llm_inputs = {
"input": chained_input.input,
"dialect": self.database.dialect,
"table_info": self.database.table_info,
"stop": ["\nSQLResult:"],
}
sql_cmd = llm_chain.predict(**llm_inputs)
chained_input.add(sql_cmd, color="green")
result = self.database.run(sql_cmd)
chained_input.add("\nSQLResult: ")
chained_input.add(result, color="yellow")
chained_input.add("\nAnswer:")
llm_inputs["input"] = chained_input.input
final_result = llm_chain.predict(**llm_inputs)
chained_input.add(final_result, color="green")
return {self.output_key: final_result}