forked from Archives/langchain
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
78 lines
2.3 KiB
Python
78 lines
2.3 KiB
Python
"""Chain for interacting with SQL Database."""
|
|
from typing import Dict, List
|
|
|
|
from pydantic import BaseModel, Extra
|
|
|
|
from langchain.chains.base import Chain
|
|
from langchain.chains.llm import LLMChain
|
|
from langchain.chains.sql_database.prompt import PROMPT
|
|
from langchain.input import ChainedInput
|
|
from langchain.llms.base import LLM
|
|
from langchain.sql_database import SQLDatabase
|
|
|
|
|
|
class SQLDatabaseChain(Chain, BaseModel):
    """Chain for interacting with SQL Database.

    A question is sent to the LLM to obtain a SQL command, the command is
    run against ``database``, and the question, command and result are sent
    back to the LLM to obtain the final answer.

    Example:
        .. code-block:: python

            from langchain import SQLDatabaseChain, OpenAI, SQLDatabase
            db = SQLDatabase(...)
            db_chain = SQLDatabaseChain(llm=OpenAI(), database=db)
    """

    llm: LLM
    """LLM wrapper to use."""
    database: SQLDatabase
    """SQL Database to connect to."""
    input_key: str = "query"  #: :meta private:
    output_key: str = "result"  #: :meta private:

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    @property
    def input_keys(self) -> List[str]:
        """Return the singular input key.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Return the singular output key.

        :meta private:
        """
        return [self.output_key]

    def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
        """Generate a SQL command for the question, run it, and answer.

        Args:
            inputs: Mapping that must contain ``self.input_key``; its value
                is the question to answer.

        Returns:
            A single-entry dict mapping ``self.output_key`` to the text the
            LLM produced after being shown the SQL result.

        Raises:
            KeyError: If ``inputs`` has no ``self.input_key`` entry.
        """
        llm_chain = LLMChain(llm=self.llm, prompt=PROMPT)
        # The running transcript starts as "<question>\nSQLQuery:" and is
        # extended below with the command, its result, and the answer.
        chained_input = ChainedInput(
            inputs[self.input_key] + "\nSQLQuery:", verbose=self.verbose
        )
        llm_inputs = {
            "input": chained_input.input,
            "dialect": self.database.dialect,
            "table_info": self.database.table_info,
            # Stop before the model starts writing a result of its own; the
            # real result is obtained from the database below.
            "stop": ["\nSQLResult:"],
        }

        # First LLM pass: obtain the SQL command.
        sql_cmd = llm_chain.predict(**llm_inputs)
        chained_input.add(sql_cmd, color="green")

        try:
            result = self.database.run(sql_cmd)
        except Exception as e:
            # Deliberately broad: a failing command does not abort the chain.
            # The error text takes the place of the result and is passed to
            # the second LLM pass like any other result.
            result = str(e)

        chained_input.add("\nSQLResult: ")
        chained_input.add(result, color="yellow")
        chained_input.add("\nAnswer:")

        # Second LLM pass: same prompt and stop sequence, with the input
        # replaced by the extended transcript.
        llm_inputs["input"] = chained_input.input
        final_result = llm_chain.predict(**llm_inputs)
        chained_input.add(final_result, color="green")
        return {self.output_key: final_result}