feat(chains): adding ElasticsearchDatabaseChain for interacting with analytics database (#7686)

This pull request adds a ElasticsearchDatabaseChain chain for
interacting with analytics database, in the manner of the
SQLDatabaseChain.

Maintainer: @samber
Twitter handler: samuelberthe

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
pull/7708/head
Samuel Berthe 1 year ago committed by GitHub
parent 6d88b23ef7
commit 7d4843fe84
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -0,0 +1,218 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "dd7ec7af",
"metadata": {},
"source": [
"# Elasticsearch database\n",
"\n",
"Interact with Elasticsearch analytics database via Langchain. This chain builds search queries via the Elasticsearch DSL API (filters and aggregations).\n",
"\n",
"The Elasticsearch client must have permissions for index listing, mapping description and search queries.\n",
"\n",
"See [here](https://www.elastic.co/guide/en/elasticsearch/reference/current/docker.html) for instructions on how to run Elasticsearch locally.\n",
"\n",
"Make sure to install the Elasticsearch Python client before:\n",
"\n",
"```sh\n",
"pip install elasticsearch\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "dd8eae75",
"metadata": {},
"outputs": [],
"source": [
"from elasticsearch import Elasticsearch\n",
"\n",
"from langchain.chains.elasticsearch_database import ElasticsearchDatabaseChain\n",
"from langchain.chat_models import ChatOpenAI"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "659b5ed0",
"metadata": {},
"outputs": [],
"source": [
"import warnings\n",
"\n",
"warnings.filterwarnings(\"ignore\")"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "5cde03bc",
"metadata": {},
"outputs": [],
"source": [
"# Initialize Elasticsearch python client.\n",
"# See https://elasticsearch-py.readthedocs.io/en/v8.8.2/api.html#elasticsearch.Elasticsearch\n",
"ELASTIC_SEARCH_SERVER = \"https://elastic:gvODoJ_nRYQIJZfG7=ec@localhost:9200\"\n",
"db = Elasticsearch(ELASTIC_SEARCH_SERVER, ca_certs=False, verify_certs=False)"
]
},
{
"cell_type": "markdown",
"id": "74a41374",
"metadata": {},
"source": [
"Uncomment the next cell to initially populate your db."
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "430ada0f",
"metadata": {},
"outputs": [],
"source": [
"# customers = [\n",
"# {\"firstname\": \"Jennifer\", \"lastname\": \"Walters\"},\n",
"# {\"firstname\": \"Monica\",\"lastname\":\"Rambeau\"},\n",
"# {\"firstname\": \"Carol\",\"lastname\":\"Danvers\"},\n",
"# {\"firstname\": \"Wanda\",\"lastname\":\"Maximoff\"},\n",
"# {\"firstname\": \"Jennifer\",\"lastname\":\"Takeda\"},\n",
"# ]\n",
"# for i, customer in enumerate(customers):\n",
"# db.create(index=\"customers\", document=customer, id=i)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "f36ae0d8",
"metadata": {},
"outputs": [],
"source": [
"llm = ChatOpenAI(model_name=\"gpt-4\", temperature=0)\n",
"chain = ElasticsearchDatabaseChain.from_llm(llm=llm, database=db, verbose=True)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "b5d22d9d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new ElasticsearchDatabaseChain chain...\u001b[0m\n",
"What are the first names of all the customers?\n",
"ESQuery:\u001b[32;1m\u001b[1;3m{'size': 10, 'query': {'match_all': {}}, '_source': ['firstname']}\u001b[0m\n",
"ESResult: \u001b[33;1m\u001b[1;3m{'took': 5, 'timed_out': False, '_shards': {'total': 1, 'successful': 1, 'skipped': 0, 'failed': 0}, 'hits': {'total': {'value': 6, 'relation': 'eq'}, 'max_score': 1.0, 'hits': [{'_index': 'customers', '_id': '0', '_score': 1.0, '_source': {'firstname': 'Jennifer'}}, {'_index': 'customers', '_id': '1', '_score': 1.0, '_source': {'firstname': 'Monica'}}, {'_index': 'customers', '_id': '2', '_score': 1.0, '_source': {'firstname': 'Carol'}}, {'_index': 'customers', '_id': '3', '_score': 1.0, '_source': {'firstname': 'Wanda'}}, {'_index': 'customers', '_id': '4', '_score': 1.0, '_source': {'firstname': 'Jennifer'}}, {'_index': 'customers', '_id': 'firstname', '_score': 1.0, '_source': {'firstname': 'Jennifer'}}]}}\u001b[0m\n",
"Answer:\u001b[32;1m\u001b[1;3mThe first names of all the customers are Jennifer, Monica, Carol, Wanda, and Jennifer.\u001b[0m\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"'The first names of all the customers are Jennifer, Monica, Carol, Wanda, and Jennifer.'"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"question = \"What are the first names of all the customers?\"\n",
"chain.run(question)"
]
},
{
"cell_type": "markdown",
"id": "9b4bfada",
"metadata": {},
"source": [
"## Custom prompt\n",
"\n",
"For best results you'll likely need to customize the prompt."
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "0a494f5b",
"metadata": {},
"outputs": [],
"source": [
"from langchain.chains.elasticsearch_database.prompts import DEFAULT_DSL_TEMPLATE\n",
"from langchain.prompts.prompt import PromptTemplate\n",
"\n",
"PROMPT_TEMPLATE = \"\"\"Given an input question, create a syntactically correct Elasticsearch query to run. Unless the user specifies in their question a specific number of examples they wish to obtain, always limit your query to at most {top_k} results. You can order the results by a relevant column to return the most interesting examples in the database.\n",
"\n",
"Unless told to do not query for all the columns from a specific index, only ask for a the few relevant columns given the question.\n",
"\n",
"Pay attention to use only the column names that you can see in the mapping description. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which index. Return the query as valid json.\n",
"\n",
"Use the following format:\n",
"\n",
"Question: Question here\n",
"ESQuery: Elasticsearch Query formatted as json\n",
"\"\"\"\n",
"\n",
"PROMPT = PromptTemplate.from_template(\n",
" PROMPT_TEMPLATE,\n",
")\n",
"chain = ElasticsearchDatabaseChain.from_llm(llm=llm, database=db, query_prompt=PROMPT)"
]
},
{
"cell_type": "markdown",
"id": "372b8f93",
"metadata": {},
"source": [
"## Adding example rows from each index\n",
"\n",
"Sometimes, the format of the data is not obvious and it is optimal to include a sample of rows from the indices in the prompt to allow the LLM to understand the data before providing a final query. Here we will use this feature to let the LLM know that artists are saved with their full names by providing ten rows from the index."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "eef818de",
"metadata": {},
"outputs": [],
"source": [
"chain = ElasticsearchDatabaseChain.from_llm(\n",
" llm=ChatOpenAI(temperature=0),\n",
" database=db,\n",
" sample_documents_in_index_info=2, # 2 rows from each index will be included in the prompt as sample data\n",
")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "venv",
"language": "python",
"name": "venv"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.3"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

@ -0,0 +1,3 @@
from langchain.chains.elasticsearch_database.base import ElasticsearchDatabaseChain
__all__ = ["ElasticsearchDatabaseChain"]

@ -0,0 +1,211 @@
"""Chain for interacting with Elasticsearch Database."""
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from pydantic import Extra, root_validator
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from langchain.chains.elasticsearch_database.prompts import ANSWER_PROMPT, DSL_PROMPT
from langchain.chains.llm import LLMChain
from langchain.output_parsers.json import SimpleJsonOutputParser
from langchain.schema import BaseLLMOutputParser, BasePromptTemplate
from langchain.schema.language_model import BaseLanguageModel
if TYPE_CHECKING:
from elasticsearch import Elasticsearch
INTERMEDIATE_STEPS_KEY = "intermediate_steps"
class ElasticsearchDatabaseChain(Chain):
"""Chain for interacting with Elasticsearch Database.
Example:
.. code-block:: python
from langchain import ElasticsearchDatabaseChain, OpenAI
from elasticsearch import Elasticsearch
database = Elasticsearch("http://localhost:9200")
db_chain = ElasticsearchDatabaseChain.from_llm(OpenAI(), database)
"""
query_chain: LLMChain
"""Chain for creating the ES query."""
answer_chain: LLMChain
"""Chain for answering the user question."""
database: Any
"""Elasticsearch database to connect to of type elasticsearch.Elasticsearch."""
top_k: int = 10
"""Number of results to return from the query"""
ignore_indices: Optional[List[str]] = None
include_indices: Optional[List[str]] = None
input_key: str = "question" #: :meta private:
output_key: str = "result" #: :meta private:
sample_documents_in_index_info: int = 3
return_intermediate_steps: bool = False
"""Whether or not to return the intermediate steps along with the final answer."""
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True
@root_validator()
def validate_indices(cls, values: dict) -> dict:
if values["include_indices"] and values["ignore_indices"]:
raise ValueError(
"Cannot specify both 'include_indices' and 'ignore_indices'."
)
return values
@property
def input_keys(self) -> List[str]:
"""Return the singular input key.
:meta private:
"""
return [self.input_key]
@property
def output_keys(self) -> List[str]:
"""Return the singular output key.
:meta private:
"""
if not self.return_intermediate_steps:
return [self.output_key]
else:
return [self.output_key, INTERMEDIATE_STEPS_KEY]
def _list_indices(self) -> List[str]:
all_indices = [
index["index"] for index in self.database.cat.indices(format="json")
]
if self.include_indices:
all_indices = [i for i in all_indices if i in self.include_indices]
if self.ignore_indices:
all_indices = [i for i in all_indices if i not in self.ignore_indices]
return all_indices
def _get_indices_infos(self, indices: List[str]) -> str:
mappings = self.database.indices.get_mapping(index=",".join(indices))
if self.sample_documents_in_index_info > 0:
for k, v in mappings.items():
hits = self.database.search(
index=k,
query={"match_all": {}},
size=self.sample_documents_in_index_info,
)["hits"]["hits"]
hits = [str(hit["_source"]) for hit in hits]
mappings[k]["mappings"] = str(v) + "\n\n/*\n" + "\n".join(hits) + "\n*/"
return "\n\n".join(
[
"Mapping for index {}:\n{}".format(index, mappings[index]["mappings"])
for index in mappings
]
)
def _search(self, indices: List[str], query: str) -> str:
result = self.database.search(index=",".join(indices), body=query)
return str(result)
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
input_text = f"{inputs[self.input_key]}\nESQuery:"
_run_manager.on_text(input_text, verbose=self.verbose)
indices = self._list_indices()
indices_info = self._get_indices_infos(indices)
query_inputs: dict = {
"input": input_text,
"top_k": str(self.top_k),
"indices_info": indices_info,
"stop": ["\nESResult:"],
}
intermediate_steps: List = []
try:
intermediate_steps.append(query_inputs) # input: es generation
es_cmd = self.query_chain.run(
callbacks=_run_manager.get_child(),
**query_inputs,
)
_run_manager.on_text(es_cmd, color="green", verbose=self.verbose)
intermediate_steps.append(
es_cmd
) # output: elasticsearch dsl generation (no checker)
intermediate_steps.append({"es_cmd": es_cmd}) # input: ES search
result = self._search(indices=indices, query=es_cmd)
intermediate_steps.append(str(result)) # output: ES search
_run_manager.on_text("\nESResult: ", verbose=self.verbose)
_run_manager.on_text(result, color="yellow", verbose=self.verbose)
_run_manager.on_text("\nAnswer:", verbose=self.verbose)
answer_inputs: dict = {"data": result, "input": input_text}
intermediate_steps.append(answer_inputs) # input: final answer
final_result = self.answer_chain.run(
callbacks=_run_manager.get_child(),
**answer_inputs,
)
intermediate_steps.append(final_result) # output: final answer
_run_manager.on_text(final_result, color="green", verbose=self.verbose)
chain_result: Dict[str, Any] = {self.output_key: final_result}
if self.return_intermediate_steps:
chain_result[INTERMEDIATE_STEPS_KEY] = intermediate_steps
return chain_result
except Exception as exc:
# Append intermediate steps to exception, to aid in logging and later
# improvement of few shot prompt seeds
exc.intermediate_steps = intermediate_steps # type: ignore
raise exc
@property
def _chain_type(self) -> str:
return "elasticsearch_database_chain"
@classmethod
def from_llm(
cls,
llm: BaseLanguageModel,
database: Elasticsearch,
*,
query_prompt: Optional[BasePromptTemplate] = None,
answer_prompt: Optional[BasePromptTemplate] = None,
query_output_parser: Optional[BaseLLMOutputParser] = None,
**kwargs: Any,
) -> ElasticsearchDatabaseChain:
"""Convenience method to construct ElasticsearchDatabaseChain from an LLM.
Args:
llm: The language model to use.
database: The Elasticsearch db.
query_prompt: The prompt to use for query construction.
answer_prompt: The prompt to use for answering user question given data.
query_output_parser: The output parser to use for parsing model-generated
ES query. Defaults to SimpleJsonOutputParser.
**kwargs: Additional arguments to pass to the constructor.
"""
query_prompt = query_prompt or DSL_PROMPT
query_output_parser = query_output_parser or SimpleJsonOutputParser()
query_chain = LLMChain(
llm=llm, prompt=query_prompt, output_parser=query_output_parser
)
answer_prompt = answer_prompt or ANSWER_PROMPT
answer_chain = LLMChain(llm=llm, prompt=answer_prompt)
return cls(
query_chain=query_chain,
answer_chain=answer_chain,
database=database,
**kwargs,
)

@ -0,0 +1,36 @@
# flake8: noqa
from langchain.prompts.prompt import PromptTemplate
PROMPT_SUFFIX = """Only use the following Elasticsearch indices:
{indices_info}
Question: {input}
ESQuery:"""
DEFAULT_DSL_TEMPLATE = """Given an input question, create a syntactically correct Elasticsearch query to run. Unless the user specifies in their question a specific number of examples they wish to obtain, always limit your query to at most {top_k} results. You can order the results by a relevant column to return the most interesting examples in the database.
Unless told to do not query for all the columns from a specific index, only ask for a the few relevant columns given the question.
Pay attention to use only the column names that you can see in the mapping description. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which index. Return the query as valid json.
Use the following format:
Question: Question here
ESQuery: Elasticsearch Query formatted as json
"""
DSL_PROMPT = PromptTemplate.from_template(DEFAULT_DSL_TEMPLATE + PROMPT_SUFFIX)
DEFAULT_ANSWER_TEMPLATE = """Given an input question and relevant data from a database, answer the user question.
Use the following format:
Question: Question here
Data: Relevant data here
Answer: Final answer here
Question: {input}
Data: {data}
Answer:"""
ANSWER_PROMPT = PromptTemplate.from_template(DEFAULT_ANSWER_TEMPLATE)

@ -2,9 +2,10 @@ from __future__ import annotations
import json
import re
from typing import List
from json import JSONDecodeError
from typing import Any, List
from langchain.schema import OutputParserException
from langchain.schema import BaseOutputParser, OutputParserException
def parse_json_markdown(json_string: str) -> dict:
@ -59,3 +60,16 @@ def parse_and_check_json_markdown(text: str, expected_keys: List[str]) -> dict:
f"to be present, but got {json_obj}"
)
return json_obj
class SimpleJsonOutputParser(BaseOutputParser[Any]):
def parse(self, text: str) -> Any:
text = text.strip()
try:
return json.loads(text)
except JSONDecodeError as e:
raise OutputParserException(f"Invalid json output: {text}") from e
@property
def _type(self) -> str:
return "simple_json_output_parser"

Loading…
Cancel
Save