diff --git a/docs/extras/modules/chains/additional/elasticsearch_database.ipynb b/docs/extras/modules/chains/additional/elasticsearch_database.ipynb new file mode 100644 index 0000000000..460bcb63c3 --- /dev/null +++ b/docs/extras/modules/chains/additional/elasticsearch_database.ipynb @@ -0,0 +1,218 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "dd7ec7af", + "metadata": {}, + "source": [ + "# Elasticsearch database\n", + "\n", + "Interact with Elasticsearch analytics database via Langchain. This chain builds search queries via the Elasticsearch DSL API (filters and aggregations).\n", + "\n", + "The Elasticsearch client must have permissions for index listing, mapping description and search queries.\n", + "\n", + "See [here](https://www.elastic.co/guide/en/elasticsearch/reference/current/docker.html) for instructions on how to run Elasticsearch locally.\n", + "\n", + "Make sure to install the Elasticsearch Python client before:\n", + "\n", + "```sh\n", + "pip install elasticsearch\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "dd8eae75", + "metadata": {}, + "outputs": [], + "source": [ + "from elasticsearch import Elasticsearch\n", + "\n", + "from langchain.chains.elasticsearch_database import ElasticsearchDatabaseChain\n", + "from langchain.chat_models import ChatOpenAI" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "659b5ed0", + "metadata": {}, + "outputs": [], + "source": [ + "import warnings\n", + "\n", + "warnings.filterwarnings(\"ignore\")" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "5cde03bc", + "metadata": {}, + "outputs": [], + "source": [ + "# Initialize Elasticsearch python client.\n", + "# See https://elasticsearch-py.readthedocs.io/en/v8.8.2/api.html#elasticsearch.Elasticsearch\n", + "ELASTIC_SEARCH_SERVER = \"https://elastic:gvODoJ_nRYQIJZfG7=ec@localhost:9200\"\n", + "db = Elasticsearch(ELASTIC_SEARCH_SERVER, ca_certs=False, verify_certs=False)" + ] + }, + { 
+ "cell_type": "markdown", + "id": "74a41374", + "metadata": {}, + "source": [ + "Uncomment the next cell to initially populate your db." + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "430ada0f", + "metadata": {}, + "outputs": [], + "source": [ + "# customers = [\n", + "# {\"firstname\": \"Jennifer\", \"lastname\": \"Walters\"},\n", + "# {\"firstname\": \"Monica\",\"lastname\":\"Rambeau\"},\n", + "# {\"firstname\": \"Carol\",\"lastname\":\"Danvers\"},\n", + "# {\"firstname\": \"Wanda\",\"lastname\":\"Maximoff\"},\n", + "# {\"firstname\": \"Jennifer\",\"lastname\":\"Takeda\"},\n", + "# ]\n", + "# for i, customer in enumerate(customers):\n", + "# db.create(index=\"customers\", document=customer, id=i)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "f36ae0d8", + "metadata": {}, + "outputs": [], + "source": [ + "llm = ChatOpenAI(model_name=\"gpt-4\", temperature=0)\n", + "chain = ElasticsearchDatabaseChain.from_llm(llm=llm, database=db, verbose=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "b5d22d9d", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "\u001b[1m> Entering new ElasticsearchDatabaseChain chain...\u001b[0m\n", + "What are the first names of all the customers?\n", + "ESQuery:\u001b[32;1m\u001b[1;3m{'size': 10, 'query': {'match_all': {}}, '_source': ['firstname']}\u001b[0m\n", + "ESResult: \u001b[33;1m\u001b[1;3m{'took': 5, 'timed_out': False, '_shards': {'total': 1, 'successful': 1, 'skipped': 0, 'failed': 0}, 'hits': {'total': {'value': 6, 'relation': 'eq'}, 'max_score': 1.0, 'hits': [{'_index': 'customers', '_id': '0', '_score': 1.0, '_source': {'firstname': 'Jennifer'}}, {'_index': 'customers', '_id': '1', '_score': 1.0, '_source': {'firstname': 'Monica'}}, {'_index': 'customers', '_id': '2', '_score': 1.0, '_source': {'firstname': 'Carol'}}, {'_index': 'customers', '_id': '3', '_score': 1.0, '_source': {'firstname': 
'Wanda'}}, {'_index': 'customers', '_id': '4', '_score': 1.0, '_source': {'firstname': 'Jennifer'}}, {'_index': 'customers', '_id': 'firstname', '_score': 1.0, '_source': {'firstname': 'Jennifer'}}]}}\u001b[0m\n", + "Answer:\u001b[32;1m\u001b[1;3mThe first names of all the customers are Jennifer, Monica, Carol, Wanda, and Jennifer.\u001b[0m\n", + "\u001b[1m> Finished chain.\u001b[0m\n" + ] + }, + { + "data": { + "text/plain": [ + "'The first names of all the customers are Jennifer, Monica, Carol, Wanda, and Jennifer.'" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "question = \"What are the first names of all the customers?\"\n", + "chain.run(question)" + ] + }, + { + "cell_type": "markdown", + "id": "9b4bfada", + "metadata": {}, + "source": [ + "## Custom prompt\n", + "\n", + "For best results you'll likely need to customize the prompt." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "0a494f5b", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain.chains.elasticsearch_database.prompts import DEFAULT_DSL_TEMPLATE\n", + "from langchain.prompts.prompt import PromptTemplate\n", + "\n", + "PROMPT_TEMPLATE = \"\"\"Given an input question, create a syntactically correct Elasticsearch query to run. Unless the user specifies in their question a specific number of examples they wish to obtain, always limit your query to at most {top_k} results. You can order the results by a relevant column to return the most interesting examples in the database.\n", + "\n", + "Unless told to do otherwise, do not query for all the columns from a specific index; only ask for the few relevant columns given the question.\n", + "\n", + "Pay attention to use only the column names that you can see in the mapping description. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which index. 
Return the query as valid json.\n", + "\n", + "Use the following format:\n", + "\n", + "Question: Question here\n", + "ESQuery: Elasticsearch Query formatted as json\n", + "\"\"\"\n", + "\n", + "PROMPT = PromptTemplate.from_template(\n", + " PROMPT_TEMPLATE,\n", + ")\n", + "chain = ElasticsearchDatabaseChain.from_llm(llm=llm, database=db, query_prompt=PROMPT)" + ] + }, + { + "cell_type": "markdown", + "id": "372b8f93", + "metadata": {}, + "source": [ + "## Adding example rows from each index\n", + "\n", + "Sometimes, the format of the data is not obvious and it is optimal to include a sample of rows from the indices in the prompt to allow the LLM to understand the data before providing a final query. Here we will use this feature to let the LLM know that artists are saved with their full names by providing ten rows from the index." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "eef818de", + "metadata": {}, + "outputs": [], + "source": [ + "chain = ElasticsearchDatabaseChain.from_llm(\n", + " llm=ChatOpenAI(temperature=0),\n", + " database=db,\n", + " sample_documents_in_index_info=2, # 2 rows from each index will be included in the prompt as sample data\n", + ")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "venv", + "language": "python", + "name": "venv" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/langchain/chains/elasticsearch_database/__init__.py b/langchain/chains/elasticsearch_database/__init__.py new file mode 100644 index 0000000000..9b7bf85406 --- /dev/null +++ b/langchain/chains/elasticsearch_database/__init__.py @@ -0,0 +1,3 @@ +from langchain.chains.elasticsearch_database.base import ElasticsearchDatabaseChain + +__all__ = 
["ElasticsearchDatabaseChain"] diff --git a/langchain/chains/elasticsearch_database/base.py b/langchain/chains/elasticsearch_database/base.py new file mode 100644 index 0000000000..17f8ddcdb7 --- /dev/null +++ b/langchain/chains/elasticsearch_database/base.py @@ -0,0 +1,211 @@ +"""Chain for interacting with Elasticsearch Database.""" +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Dict, List, Optional + +from pydantic import Extra, root_validator + +from langchain.callbacks.manager import CallbackManagerForChainRun +from langchain.chains.base import Chain +from langchain.chains.elasticsearch_database.prompts import ANSWER_PROMPT, DSL_PROMPT +from langchain.chains.llm import LLMChain +from langchain.output_parsers.json import SimpleJsonOutputParser +from langchain.schema import BaseLLMOutputParser, BasePromptTemplate +from langchain.schema.language_model import BaseLanguageModel + +if TYPE_CHECKING: + from elasticsearch import Elasticsearch + +INTERMEDIATE_STEPS_KEY = "intermediate_steps" + + +class ElasticsearchDatabaseChain(Chain): + """Chain for interacting with Elasticsearch Database. + + Example: + .. 
code-block:: python + + from langchain import ElasticsearchDatabaseChain, OpenAI + from elasticsearch import Elasticsearch + + database = Elasticsearch("http://localhost:9200") + db_chain = ElasticsearchDatabaseChain.from_llm(OpenAI(), database) + """ + + query_chain: LLMChain + """Chain for creating the ES query.""" + answer_chain: LLMChain + """Chain for answering the user question.""" + database: Any + """Elasticsearch database to connect to of type elasticsearch.Elasticsearch.""" + top_k: int = 10 + """Number of results to return from the query""" + ignore_indices: Optional[List[str]] = None + include_indices: Optional[List[str]] = None + input_key: str = "question" #: :meta private: + output_key: str = "result" #: :meta private: + sample_documents_in_index_info: int = 3 + return_intermediate_steps: bool = False + """Whether or not to return the intermediate steps along with the final answer.""" + + class Config: + """Configuration for this pydantic object.""" + + extra = Extra.forbid + arbitrary_types_allowed = True + + @root_validator() + def validate_indices(cls, values: dict) -> dict: + if values["include_indices"] and values["ignore_indices"]: + raise ValueError( + "Cannot specify both 'include_indices' and 'ignore_indices'." + ) + return values + + @property + def input_keys(self) -> List[str]: + """Return the singular input key. + + :meta private: + """ + return [self.input_key] + + @property + def output_keys(self) -> List[str]: + """Return the singular output key. 
+ + :meta private: + """ + if not self.return_intermediate_steps: + return [self.output_key] + else: + return [self.output_key, INTERMEDIATE_STEPS_KEY] + + def _list_indices(self) -> List[str]: + all_indices = [ + index["index"] for index in self.database.cat.indices(format="json") + ] + + if self.include_indices: + all_indices = [i for i in all_indices if i in self.include_indices] + if self.ignore_indices: + all_indices = [i for i in all_indices if i not in self.ignore_indices] + + return all_indices + + def _get_indices_infos(self, indices: List[str]) -> str: + mappings = self.database.indices.get_mapping(index=",".join(indices)) + if self.sample_documents_in_index_info > 0: + for k, v in mappings.items(): + hits = self.database.search( + index=k, + query={"match_all": {}}, + size=self.sample_documents_in_index_info, + )["hits"]["hits"] + hits = [str(hit["_source"]) for hit in hits] + mappings[k]["mappings"] = str(v) + "\n\n/*\n" + "\n".join(hits) + "\n*/" + return "\n\n".join( + [ + "Mapping for index {}:\n{}".format(index, mappings[index]["mappings"]) + for index in mappings + ] + ) + + def _search(self, indices: List[str], query: str) -> str: + result = self.database.search(index=",".join(indices), body=query) + return str(result) + + def _call( + self, + inputs: Dict[str, Any], + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> Dict[str, Any]: + _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() + input_text = f"{inputs[self.input_key]}\nESQuery:" + _run_manager.on_text(input_text, verbose=self.verbose) + indices = self._list_indices() + indices_info = self._get_indices_infos(indices) + query_inputs: dict = { + "input": input_text, + "top_k": str(self.top_k), + "indices_info": indices_info, + "stop": ["\nESResult:"], + } + intermediate_steps: List = [] + try: + intermediate_steps.append(query_inputs) # input: es generation + es_cmd = self.query_chain.run( + callbacks=_run_manager.get_child(), + **query_inputs, + ) + + 
_run_manager.on_text(es_cmd, color="green", verbose=self.verbose) + intermediate_steps.append( + es_cmd + ) # output: elasticsearch dsl generation (no checker) + intermediate_steps.append({"es_cmd": es_cmd}) # input: ES search + result = self._search(indices=indices, query=es_cmd) + intermediate_steps.append(str(result)) # output: ES search + + _run_manager.on_text("\nESResult: ", verbose=self.verbose) + _run_manager.on_text(result, color="yellow", verbose=self.verbose) + + _run_manager.on_text("\nAnswer:", verbose=self.verbose) + answer_inputs: dict = {"data": result, "input": input_text} + intermediate_steps.append(answer_inputs) # input: final answer + final_result = self.answer_chain.run( + callbacks=_run_manager.get_child(), + **answer_inputs, + ) + + intermediate_steps.append(final_result) # output: final answer + _run_manager.on_text(final_result, color="green", verbose=self.verbose) + chain_result: Dict[str, Any] = {self.output_key: final_result} + if self.return_intermediate_steps: + chain_result[INTERMEDIATE_STEPS_KEY] = intermediate_steps + return chain_result + except Exception as exc: + # Append intermediate steps to exception, to aid in logging and later + # improvement of few shot prompt seeds + exc.intermediate_steps = intermediate_steps # type: ignore + raise exc + + @property + def _chain_type(self) -> str: + return "elasticsearch_database_chain" + + @classmethod + def from_llm( + cls, + llm: BaseLanguageModel, + database: Elasticsearch, + *, + query_prompt: Optional[BasePromptTemplate] = None, + answer_prompt: Optional[BasePromptTemplate] = None, + query_output_parser: Optional[BaseLLMOutputParser] = None, + **kwargs: Any, + ) -> ElasticsearchDatabaseChain: + """Convenience method to construct ElasticsearchDatabaseChain from an LLM. + + Args: + llm: The language model to use. + database: The Elasticsearch db. + query_prompt: The prompt to use for query construction. + answer_prompt: The prompt to use for answering user question given data. 
+ query_output_parser: The output parser to use for parsing model-generated + ES query. Defaults to SimpleJsonOutputParser. + **kwargs: Additional arguments to pass to the constructor. + """ + query_prompt = query_prompt or DSL_PROMPT + query_output_parser = query_output_parser or SimpleJsonOutputParser() + query_chain = LLMChain( + llm=llm, prompt=query_prompt, output_parser=query_output_parser + ) + answer_prompt = answer_prompt or ANSWER_PROMPT + answer_chain = LLMChain(llm=llm, prompt=answer_prompt) + return cls( + query_chain=query_chain, + answer_chain=answer_chain, + database=database, + **kwargs, + ) diff --git a/langchain/chains/elasticsearch_database/prompts.py b/langchain/chains/elasticsearch_database/prompts.py new file mode 100644 index 0000000000..9d9b6b00fe --- /dev/null +++ b/langchain/chains/elasticsearch_database/prompts.py @@ -0,0 +1,36 @@ +# flake8: noqa +from langchain.prompts.prompt import PromptTemplate + +PROMPT_SUFFIX = """Only use the following Elasticsearch indices: +{indices_info} + +Question: {input} +ESQuery:""" + +DEFAULT_DSL_TEMPLATE = """Given an input question, create a syntactically correct Elasticsearch query to run. Unless the user specifies in their question a specific number of examples they wish to obtain, always limit your query to at most {top_k} results. You can order the results by a relevant column to return the most interesting examples in the database. + +Unless told to do otherwise, do not query for all the columns from a specific index; only ask for the few relevant columns given the question. + +Pay attention to use only the column names that you can see in the mapping description. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which index. Return the query as valid json. 
+ +Use the following format: + +Question: Question here +ESQuery: Elasticsearch Query formatted as json +""" + +DSL_PROMPT = PromptTemplate.from_template(DEFAULT_DSL_TEMPLATE + PROMPT_SUFFIX) + +DEFAULT_ANSWER_TEMPLATE = """Given an input question and relevant data from a database, answer the user question. + +Use the following format: + +Question: Question here +Data: Relevant data here +Answer: Final answer here + +Question: {input} +Data: {data} +Answer:""" + +ANSWER_PROMPT = PromptTemplate.from_template(DEFAULT_ANSWER_TEMPLATE) diff --git a/langchain/output_parsers/json.py b/langchain/output_parsers/json.py index 7ac22b025d..0a64fa348d 100644 --- a/langchain/output_parsers/json.py +++ b/langchain/output_parsers/json.py @@ -2,9 +2,10 @@ from __future__ import annotations import json import re -from typing import List +from json import JSONDecodeError +from typing import Any, List -from langchain.schema import OutputParserException +from langchain.schema import BaseOutputParser, OutputParserException def parse_json_markdown(json_string: str) -> dict: @@ -59,3 +60,16 @@ def parse_and_check_json_markdown(text: str, expected_keys: List[str]) -> dict: f"to be present, but got {json_obj}" ) return json_obj + + +class SimpleJsonOutputParser(BaseOutputParser[Any]): + def parse(self, text: str) -> Any: + text = text.strip() + try: + return json.loads(text) + except JSONDecodeError as e: + raise OutputParserException(f"Invalid json output: {text}") from e + + @property + def _type(self) -> str: + return "simple_json_output_parser"