2023-10-27 02:44:30 +00:00
|
|
|
import os

from elasticsearch import Elasticsearch
from langchain.chat_models import ChatOpenAI
from langchain.output_parsers.json import SimpleJsonOutputParser
from langchain.pydantic_v1 import BaseModel

from .elastic_index_info import get_indices_infos
from .prompts import DSL_PROMPT
|
2023-10-26 01:47:42 +00:00
|
|
|
|
2023-10-29 05:13:22 +00:00
|
|
|
# Setup Elasticsearch
|
|
|
|
# This shows how to set it up for a cloud hosted version
|
|
|
|
|
|
|
|
# Password for the 'elastic' user generated by Elasticsearch
|
|
|
|
# Read from the ELASTIC_PASSWORD environment variable when it is set, so the
# secret does not have to be committed to source. Falls back to the "..."
# placeholder, which must then be replaced by hand.
ELASTIC_PASSWORD = os.environ.get("ELASTIC_PASSWORD", "...")
|
|
|
|
|
|
|
|
# Found in the 'Manage Deployment' page
|
|
|
|
# Read from the ELASTIC_CLOUD_ID environment variable when it is set, so the
# deployment id can be configured without editing this file. Falls back to the
# "..." placeholder, which must then be replaced by hand.
CLOUD_ID = os.environ.get("ELASTIC_CLOUD_ID", "...")
|
2023-10-26 01:47:42 +00:00
|
|
|
|
2023-10-29 05:13:22 +00:00
|
|
|
# Create the client instance
|
2023-10-26 01:47:42 +00:00
|
|
|
# Connects to the deployment identified by CLOUD_ID, authenticating as the
# "elastic" user with ELASTIC_PASSWORD.
db = Elasticsearch(cloud_id=CLOUD_ID, basic_auth=("elastic", ELASTIC_PASSWORD))
|
|
|
|
|
2023-10-29 05:13:22 +00:00
|
|
|
# Specify indices to include
|
|
|
|
# If you want to use this on your own indices, you will need to change this list.
|
|
|
|
INCLUDE_INDICES = [
    "customers",  # sample index; replace or extend with your own index names
]
|
|
|
|
|
|
|
|
# With the Elasticsearch connection created, we can now move on to the chain
|
|
|
|
|
2023-10-26 01:47:42 +00:00
|
|
|
# Chat model that turns the rendered prompt into the query; temperature is
# pinned to 0 here.
_model = ChatOpenAI(model="gpt-4", temperature=0)
|
|
|
|
|
|
|
|
# Pipeline: build the prompt variables -> render DSL_PROMPT -> call the model
# -> run the reply through SimpleJsonOutputParser.
chain = (
    {
        # Pass the caller's "input" value through unchanged.
        "input": lambda request: request["input"],
        # This only gets index info for the indices in INCLUDE_INDICES
        # (the "customers" index by default).
        # If you are running this on your own data, you will want to change it.
        "indices_info": lambda _request: get_indices_infos(
            db, include_indices=INCLUDE_INDICES
        ),
        # Fall back to 5 when the caller does not supply "top_k".
        "top_k": lambda request: request.get("top_k", 5),
    }
    | DSL_PROMPT
    | _model
    | SimpleJsonOutputParser()
)
|
2023-10-29 05:13:22 +00:00
|
|
|
|
|
|
|
|
|
|
|
# Nicely typed inputs for playground
|
|
|
|
class ChainInputs(BaseModel):
    # Value the chain forwards to the prompt under the "input" key.
    input: str
    # Optional; 5 matches the fallback the chain applies when "top_k" is absent.
    top_k: int = 5
|
|
|
|
|
|
|
|
|
|
|
|
# Rebind `chain` with ChainInputs declared as its input type.
chain = chain.with_types(
    input_type=ChainInputs,
)
|