langchain/templates/elastic-query-generator/elastic_query_generator/chain.py
Bagatur 480626dc99
docs, community[patch], experimental[patch], langchain[patch], cli[pa… (#15412)
…tch]: import models from community

ran
```bash
git grep -l 'from langchain\.chat_models' | xargs -L 1 sed -i '' "s/from\ langchain\.chat_models/from\ langchain_community.chat_models/g"
git grep -l 'from langchain\.llms' | xargs -L 1 sed -i '' "s/from\ langchain\.llms/from\ langchain_community.llms/g"
git grep -l 'from langchain\.embeddings' | xargs -L 1 sed -i '' "s/from\ langchain\.embeddings/from\ langchain_community.embeddings/g"
git checkout master libs/langchain/tests/unit_tests/llms
git checkout master libs/langchain/tests/unit_tests/chat_models
git checkout master libs/langchain/tests/unit_tests/embeddings/test_imports.py
make format
cd libs/langchain; make format
cd ../experimental; make format
cd ../core; make format
```
2024-01-02 15:32:16 -05:00

52 lines
1.4 KiB
Python

from elasticsearch import Elasticsearch
from langchain.output_parsers.json import SimpleJsonOutputParser
from langchain_community.chat_models import ChatOpenAI
from langchain_core.pydantic_v1 import BaseModel
from .elastic_index_info import get_indices_infos
from .prompts import DSL_PROMPT
# Setup Elasticsearch
# This shows how to set it up for a cloud hosted version
# Password for the 'elastic' user generated by Elasticsearch
ELASTIC_PASSWORD = "..."
# Found in the 'Manage Deployment' page
CLOUD_ID = "..."
# Create the client instance
db = Elasticsearch(cloud_id=CLOUD_ID, basic_auth=("elastic", ELASTIC_PASSWORD))
# Specify indices to include
# If you want to use on your own indices, you will need to change this.
INCLUDE_INDICES = ["customers"]
# With the Elasticsearch connection created, we can now move on to the chain
_model = ChatOpenAI(temperature=0, model="gpt-4")
chain = (
{
"input": lambda x: x["input"],
# This line only get index info for "customers" index.
# If you are running this on your own data, you will want to change.
"indices_info": lambda _: get_indices_infos(
db, include_indices=INCLUDE_INDICES
),
"top_k": lambda x: x.get("top_k", 5),
}
| DSL_PROMPT
| _model
| SimpleJsonOutputParser()
)
# Nicely typed inputs for playground
class ChainInputs(BaseModel):
input: str
top_k: int = 5
chain = chain.with_types(input_type=ChainInputs)