mirror of https://github.com/hwchase17/langchain
Merge branch 'master' into fork/feature_audio_loader_auzre_speech
commit a59739e1d2
@@ -0,0 +1,201 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<a href=\"https://colab.research.google.com/github/langchain-ai/langchain/blob/master/docs/docs/integrations/chat/maritalk.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>\n",
    "\n",
    "# Maritalk\n",
    "\n",
    "## Introduction\n",
    "\n",
    "MariTalk is an assistant developed by the Brazilian company [Maritaca AI](https://www.maritaca.ai).\n",
    "MariTalk is based on language models that have been specially trained to understand Portuguese well.\n",
    "\n",
    "This notebook demonstrates how to use MariTalk with LangChain through two examples:\n",
    "\n",
    "1. A simple example of how to use MariTalk to perform a task.\n",
    "2. LLM + RAG: The second example shows how to answer a question whose answer is found in a long document that does not fit within the token limit of MariTalk. For this, we will use a simple searcher (BM25) to first search the document for the most relevant sections and then feed them to MariTalk for answering."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Installation\n",
    "First, install the LangChain libraries (and all their dependencies) using the following command:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install langchain-core langchain-community"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## API Key\n",
    "You will need an API key, which can be obtained from chat.maritaca.ai (\"Chaves da API\" section)."
   ]
  },
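  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As one way to avoid hardcoding the key in the notebook (a sketch; the variable name `MARITALK_API_KEY` is just a local convention here, not something the library reads), you can prompt for it with `getpass` and then pass it as the `api_key` argument in the example below:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from getpass import getpass\n",
    "\n",
    "# Prompt for the key interactively so it is not stored in the notebook.\n",
    "MARITALK_API_KEY = getpass()"
   ]
  },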
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "### Example 1 - Pet Name Suggestions\n",
    "\n",
    "Let's define our language model, ChatMaritalk, and configure it with your API key."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.prompts.chat import ChatPromptTemplate\n",
    "from langchain_community.chat_models import ChatMaritalk\n",
    "from langchain_core.output_parsers import StrOutputParser\n",
    "\n",
    "llm = ChatMaritalk(\n",
    "    api_key=\"\",  # Insert your API key here (or use MARITALK_API_KEY from above)\n",
    "    temperature=0.7,\n",
    "    max_tokens=100,\n",
    ")\n",
    "\n",
    "output_parser = StrOutputParser()\n",
    "\n",
    "chat_prompt = ChatPromptTemplate.from_messages(\n",
    "    [\n",
    "        (\n",
    "            \"system\",\n",
    "            \"You are an assistant specialized in suggesting pet names. Given the animal, you must suggest 4 names.\",\n",
    "        ),\n",
    "        (\"human\", \"I have a {animal}\"),\n",
    "    ]\n",
    ")\n",
    "\n",
    "chain = chat_prompt | llm | output_parser\n",
    "\n",
    "response = chain.invoke({\"animal\": \"dog\"})\n",
    "print(response)  # should answer something like \"1. Max\\n2. Bella\\n3. Charlie\\n4. Rocky\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Example 2 - RAG + LLM: UNICAMP 2024 Entrance Exam Question Answering System\n",
    "For this example, we need to install some extra libraries:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install unstructured rank_bm25 pdf2image pdfminer-six pikepdf pypdf unstructured_inference fastapi kaleido uvicorn \"pillow<10.1.0\" pillow_heif -q"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Loading the database\n",
    "\n",
    "The first step is to create a database with the information from the exam notice. To do this, we will download the notice from the COMVEST website and segment the extracted text into 500-character windows."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.document_loaders import OnlinePDFLoader\n",
    "from langchain.text_splitter import RecursiveCharacterTextSplitter\n",
    "\n",
    "# Loading the COMVEST 2024 notice\n",
    "loader = OnlinePDFLoader(\n",
    "    \"https://www.comvest.unicamp.br/wp-content/uploads/2023/10/31-2023-Dispoe-sobre-o-Vestibular-Unicamp-2024_com-retificacao.pdf\"\n",
    ")\n",
    "data = loader.load()\n",
    "\n",
    "text_splitter = RecursiveCharacterTextSplitter(\n",
    "    chunk_size=500, chunk_overlap=100, separators=[\"\\n\", \" \", \"\"]\n",
    ")\n",
    "texts = text_splitter.split_documents(data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Creating a Searcher\n",
    "Now that we have our database, we need a searcher. For this example, we will use simple BM25 as the search system, but it could be replaced by any other searcher (such as search via embeddings)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.retrievers import BM25Retriever\n",
    "\n",
    "retriever = BM25Retriever.from_documents(texts)"
   ]
  },
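  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As an optional sanity check (not part of the original flow), you can peek at what BM25 returns for a sample query before wiring it into the chain:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect the top-ranked chunk for an example question.\n",
    "sample_docs = retriever.get_relevant_documents(\n",
    "    \"Qual o tempo máximo para realização da prova?\"\n",
    ")\n",
    "print(sample_docs[0].page_content[:200])"
   ]
  },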
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Combining Search System + LLM\n",
    "Now that we have our searcher, we just need to implement a prompt specifying the task and invoke the chain."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.chains.question_answering import load_qa_chain\n",
    "\n",
    "prompt = \"\"\"Baseado nos seguintes documentos, responda a pergunta abaixo.\n",
    "\n",
    "{context}\n",
    "\n",
    "Pergunta: {query}\n",
    "\"\"\"\n",
    "\n",
    "qa_prompt = ChatPromptTemplate.from_messages([(\"human\", prompt)])\n",
    "\n",
    "chain = load_qa_chain(llm, chain_type=\"stuff\", verbose=True, prompt=qa_prompt)\n",
    "\n",
    "query = \"Qual o tempo máximo para realização da prova?\"\n",
    "\n",
    "docs = retriever.get_relevant_documents(query)\n",
    "\n",
    "chain.invoke(\n",
    "    {\"input_documents\": docs, \"query\": query}\n",
    ")  # Should output something like: \"O tempo máximo para realização da prova é de 5 horas.\""
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
@@ -0,0 +1,229 @@
{
 "cells": [
  {
   "cell_type": "raw",
   "id": "a016701c",
   "metadata": {},
   "source": [
    "---\n",
    "sidebar_label: Perplexity\n",
    "---"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bf733a38-db84-4363-89e2-de6735c37230",
   "metadata": {},
   "source": [
    "# ChatPerplexity\n",
    "\n",
    "This notebook covers how to get started with Perplexity chat models."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "d4a7c55d-b235-4ca4-a579-c90cc9570da9",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-01-19T11:25:00.590587Z",
     "start_time": "2024-01-19T11:25:00.127293Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "from langchain_community.chat_models import ChatPerplexity\n",
    "from langchain_core.prompts import ChatPromptTemplate"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "97a8ce3a",
   "metadata": {},
   "source": [
    "The code provided assumes that your `PPLX_API_KEY` is set in your environment variables. If you would like to manually specify your API key and also choose a different model, you can use the following code:\n",
    "\n",
    "```python\n",
    "chat = ChatPerplexity(temperature=0, pplx_api_key=\"YOUR_API_KEY\", model=\"pplx-70b-online\")\n",
    "```\n",
    "\n",
    "You can check a list of available models [here](https://docs.perplexity.ai/docs/model-cards). In this notebook, we set the API key dynamically by reading it as input, so it is not hardcoded."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "d3e49d78",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from getpass import getpass\n",
    "\n",
    "PPLX_API_KEY = getpass()\n",
    "os.environ[\"PPLX_API_KEY\"] = PPLX_API_KEY"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "70cf04e8-423a-4ff6-8b09-f11fb711c817",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-01-19T11:25:04.349676Z",
     "start_time": "2024-01-19T11:25:03.964930Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "chat = ChatPerplexity(temperature=0, model=\"pplx-70b-online\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "8199ef8f-eb8b-4253-9ea0-6c24a013ca4c",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-01-19T11:25:07.274418Z",
     "start_time": "2024-01-19T11:25:05.898031Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'The Higgs Boson is an elementary subatomic particle that plays a crucial role in the Standard Model of particle physics, which accounts for three of the four fundamental forces governing the behavior of our universe: the strong and weak nuclear forces, electromagnetism, and gravity. The Higgs Boson is important for several reasons:\\n\\n1. **Final Elementary Particle**: The Higgs Boson is the last elementary particle waiting to be discovered under the Standard Model. Its detection helps complete the Standard Model and further our understanding of the fundamental forces in the universe.\\n\\n2. **Mass Generation**: The Higgs Boson is responsible for giving mass to other particles, a process that occurs through its interaction with the Higgs field. This mass generation is essential for the formation of atoms, molecules, and the visible matter we observe in the universe.\\n\\n3. **Implications for New Physics**: While the detection of the Higgs Boson has confirmed many aspects of the Standard Model, it also opens up new possibilities for discoveries beyond the Standard Model. Further research on the Higgs Boson could reveal insights into the nature of dark matter, supersymmetry, and other exotic phenomena.\\n\\n4. **Advancements in Technology**: The search for the Higgs Boson has led to significant advancements in technology, such as the development of artificial intelligence and machine learning algorithms used in particle accelerators like the Large Hadron Collider (LHC). These advancements have not only contributed to the discovery of the Higgs Boson but also have potential applications in various other fields.\\n\\nIn summary, the Higgs Boson is important because it completes the Standard Model, plays a crucial role in mass generation, hints at new physics phenomena beyond the Standard Model, and drives advancements in technology.\\n'"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "system = \"You are a helpful assistant.\"\n",
    "human = \"{input}\"\n",
    "prompt = ChatPromptTemplate.from_messages([(\"system\", system), (\"human\", human)])\n",
    "\n",
    "chain = prompt | chat\n",
    "response = chain.invoke({\"input\": \"Why is the Higgs Boson important?\"})\n",
    "response.content"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "de6d8d5a",
   "metadata": {},
   "source": [
    "You can format and structure the prompts as you typically would. In the following example, we ask the model to tell us a joke about cats."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "c5fac0e9-05a4-4fc1-a3b3-e5bbb24b971b",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-01-19T11:25:10.448733Z",
     "start_time": "2024-01-19T11:25:08.866277Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'Here\\'s a joke about cats:\\n\\nWhy did the cat want math lessons from a mermaid?\\n\\nBecause it couldn\\'t find its \"core purpose\" in life!\\n\\nRemember, cats are unique and fascinating creatures, and each one has its own special traits and abilities. While some may see them as mysterious or even a bit aloof, they are still beloved pets that bring joy and companionship to their owners. So, if your cat ever seeks guidance from a mermaid, just remember that they are on their own journey to self-discovery!\\n'"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "chat = ChatPerplexity(temperature=0, model=\"pplx-70b-online\")\n",
    "prompt = ChatPromptTemplate.from_messages([(\"human\", \"Tell me a joke about {topic}\")])\n",
    "chain = prompt | chat\n",
    "response = chain.invoke({\"topic\": \"cats\"})\n",
    "response.content"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "13d93dc4",
   "metadata": {},
   "source": [
    "`ChatPerplexity` also supports streaming functionality:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "025be980-e50d-4a68-93dc-c9c7b500ce34",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-01-19T11:25:24.438696Z",
     "start_time": "2024-01-19T11:25:14.687480Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Here is a list of some famous tourist attractions in Pakistan:\n",
      "\n",
      "1. **Minar-e-Pakistan**: A 62-meter high minaret in Lahore that represents the history of Pakistan.\n",
      "2. **Badshahi Mosque**: A historic mosque in Lahore with a capacity of 10,000 worshippers.\n",
      "3. **Shalimar Gardens**: A beautiful garden in Lahore with landscaped grounds and a series of cascading pools.\n",
      "4. **Pakistan Monument**: A national monument in Islamabad representing the four provinces and three districts of Pakistan.\n",
      "5. **National Museum of Pakistan**: A museum in Karachi showcasing the country's cultural history.\n",
      "6. **Faisal Mosque**: A large mosque in Islamabad that can accommodate up to 300,000 worshippers.\n",
      "7. **Clifton Beach**: A popular beach in Karachi offering water activities and recreational facilities.\n",
      "8. **Kartarpur Corridor**: A visa-free border crossing and religious corridor connecting Gurdwara Darbar Sahib in Pakistan to Gurudwara Sri Kartarpur Sahib in India.\n",
      "9. **Mohenjo-daro**: An ancient Indus Valley civilization site in Sindh, Pakistan, dating back to around 2500 BCE.\n",
      "10. **Hunza Valley**: A picturesque valley in Gilgit-Baltistan known for its stunning mountain scenery and unique culture.\n",
      "\n",
      "These attractions showcase the rich history, diverse culture, and natural beauty of Pakistan, making them popular destinations for both local and international tourists.\n"
     ]
    }
   ],
   "source": [
    "chat = ChatPerplexity(temperature=0.7, model=\"pplx-70b-online\")\n",
    "prompt = ChatPromptTemplate.from_messages(\n",
    "    [(\"human\", \"Give me a list of famous tourist attractions in Pakistan\")]\n",
    ")\n",
    "chain = prompt | chat\n",
    "for chunk in chain.stream({}):\n",
    "    print(chunk.content, end=\"\", flush=True)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
@@ -0,0 +1,157 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "278b6c63",
   "metadata": {},
   "source": [
    "# llamafile\n",
    "\n",
    "Let's load the [llamafile](https://github.com/Mozilla-Ocho/llamafile) Embeddings class.\n",
    "\n",
    "## Setup\n",
    "\n",
    "First, there are 3 setup steps:\n",
    "\n",
    "1. Download a llamafile. In this notebook, we use `TinyLlama-1.1B-Chat-v1.0.Q5_K_M` but there are many others available on [HuggingFace](https://huggingface.co/models?other=llamafile).\n",
    "2. Make the llamafile executable.\n",
    "3. Start the llamafile in server mode.\n",
    "\n",
    "You can run the following bash script to do all this:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "43ef6dfa-9cc4-4552-8a53-5df523afae7c",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%bash\n",
    "# llamafile setup\n",
    "\n",
    "# Step 1: Download a llamafile. The download may take several minutes.\n",
    "wget -nv -nc https://huggingface.co/jartine/TinyLlama-1.1B-Chat-v1.0-GGUF/resolve/main/TinyLlama-1.1B-Chat-v1.0.Q5_K_M.llamafile\n",
    "\n",
    "# Step 2: Make the llamafile executable. Note: if you're on Windows, just append '.exe' to the filename.\n",
    "chmod +x TinyLlama-1.1B-Chat-v1.0.Q5_K_M.llamafile\n",
    "\n",
    "# Step 3: Start llamafile server in background. All the server logs will be written to 'tinyllama.log'.\n",
    "# Alternatively, you can just open a separate terminal outside this notebook and run:\n",
    "#   ./TinyLlama-1.1B-Chat-v1.0.Q5_K_M.llamafile --server --nobrowser --embedding\n",
    "./TinyLlama-1.1B-Chat-v1.0.Q5_K_M.llamafile --server --nobrowser --embedding > tinyllama.log 2>&1 &\n",
    "pid=$!\n",
    "echo \"${pid}\" > .llamafile_pid  # write the process pid to a file so we can terminate the server later"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3188b22f-879f-47b3-9a27-24412f6fad5f",
   "metadata": {},
   "source": [
    "## Embedding texts using LlamafileEmbeddings\n",
    "\n",
    "Now, we can use the `LlamafileEmbeddings` class to interact with the llamafile server that's currently serving our TinyLlama model at http://localhost:8080."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0be1af71",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_community.embeddings import LlamafileEmbeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2c66e5da",
   "metadata": {},
   "outputs": [],
   "source": [
    "embedder = LlamafileEmbeddings()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "01370375",
   "metadata": {},
   "outputs": [],
   "source": [
    "text = \"This is a test document.\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a42e4035",
   "metadata": {},
   "source": [
    "To generate embeddings, you can either query an individual text, or you can query a list of texts."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "91bc875d-829b-4c3d-8e6f-fc2dda30a3bd",
   "metadata": {},
   "outputs": [],
   "source": [
    "query_result = embedder.embed_query(text)\n",
    "query_result[:5]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a4b0d49e-0c73-44b6-aed5-5b426564e085",
   "metadata": {},
   "outputs": [],
   "source": [
    "doc_result = embedder.embed_documents([text])\n",
    "doc_result[0][:5]"
   ]
  },
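  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a small illustrative sketch (not part of the original walkthrough), you can compare the query embedding against a document embedding with cosine similarity using `numpy`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "q = np.array(query_result)\n",
    "d = np.array(doc_result[0])\n",
    "# Cosine similarity between the query and the document embedding.\n",
    "print(np.dot(q, d) / (np.linalg.norm(q) * np.linalg.norm(d)))"
   ]
  },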
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1ccc78fc-03ae-411d-ae73-74a4ee91c725",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%bash\n",
    "# cleanup: kill the llamafile server process\n",
    "kill $(cat .llamafile_pid)\n",
    "rm .llamafile_pid"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.7"
  },
  "vscode": {
   "interpreter": {
    "hash": "e971737741ff4ec9aff7dc6155a1060a59a8a6d52c757dbbe66bf8ee389494b1"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
@@ -0,0 +1,239 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "c94240f5",
   "metadata": {},
   "source": [
    "# Gremlin (with CosmosDB) QA chain\n",
    "\n",
    "This notebook shows how to use LLMs to provide a natural language interface to a graph database you can query with the Gremlin query language."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dbc0ee68",
   "metadata": {},
   "source": [
    "You will need an Azure Cosmos DB graph database instance. One option is to create a [free Cosmos DB graph database instance in Azure](https://learn.microsoft.com/en-us/azure/cosmos-db/free-tier).\n",
    "\n",
    "When you create your Cosmos DB account and graph, use `/type` as the partition key."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "62812aad",
   "metadata": {},
   "outputs": [],
   "source": [
    "import nest_asyncio\n",
    "from langchain.chains.graph_qa import GremlinQAChain\n",
    "from langchain.schema import Document\n",
    "from langchain_community.graphs import GremlinGraph\n",
    "from langchain_community.graphs.graph_document import GraphDocument, Node, Relationship\n",
    "from langchain_openai import AzureChatOpenAI"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0928915d",
   "metadata": {},
   "outputs": [],
   "source": [
    "cosmosdb_name = \"mycosmosdb\"\n",
    "cosmosdb_db_id = \"graphtesting\"\n",
    "cosmosdb_db_graph_id = \"mygraph\"\n",
    "cosmosdb_access_key = \"longstring==\"\n",
    "\n",
    "graph = GremlinGraph(\n",
    "    url=f\"wss://{cosmosdb_name}.gremlin.cosmos.azure.com:443/\",\n",
    "    username=f\"/dbs/{cosmosdb_db_id}/colls/{cosmosdb_db_graph_id}\",\n",
    "    password=cosmosdb_access_key,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "995ea9b9",
   "metadata": {},
   "source": [
    "## Seeding the database\n",
    "\n",
    "Assuming your database is empty, you can populate it using GraphDocuments.\n",
    "\n",
    "For Gremlin, always add a property called 'label' for each Node.\n",
    "If no label is set, Node.type is used as the label.\n",
    "For Cosmos DB, using natural IDs makes sense, as they are visible in the graph explorer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fedd26b9",
   "metadata": {},
   "outputs": [],
   "source": [
    "source_doc = Document(\n",
    "    page_content=\"Matrix is a movie where Keanu Reeves, Laurence Fishburne and Carrie-Anne Moss acted.\"\n",
    ")\n",
    "movie = Node(id=\"The Matrix\", properties={\"label\": \"movie\", \"title\": \"The Matrix\"})\n",
    "actor1 = Node(id=\"Keanu Reeves\", properties={\"label\": \"actor\", \"name\": \"Keanu Reeves\"})\n",
    "actor2 = Node(\n",
    "    id=\"Laurence Fishburne\", properties={\"label\": \"actor\", \"name\": \"Laurence Fishburne\"}\n",
    ")\n",
    "actor3 = Node(\n",
    "    id=\"Carrie-Anne Moss\", properties={\"label\": \"actor\", \"name\": \"Carrie-Anne Moss\"}\n",
    ")\n",
    "rel1 = Relationship(\n",
    "    id=5, type=\"ActedIn\", source=actor1, target=movie, properties={\"label\": \"ActedIn\"}\n",
    ")\n",
    "rel2 = Relationship(\n",
    "    id=6, type=\"ActedIn\", source=actor2, target=movie, properties={\"label\": \"ActedIn\"}\n",
    ")\n",
    "rel3 = Relationship(\n",
    "    id=7, type=\"ActedIn\", source=actor3, target=movie, properties={\"label\": \"ActedIn\"}\n",
    ")\n",
    "rel4 = Relationship(\n",
    "    id=8,\n",
    "    type=\"Starring\",\n",
    "    source=movie,\n",
    "    target=actor1,\n",
    "    properties={\"label\": \"Starring\"},\n",
    ")\n",
    "rel5 = Relationship(\n",
    "    id=9,\n",
    "    type=\"Starring\",\n",
    "    source=movie,\n",
    "    target=actor2,\n",
    "    properties={\"label\": \"Starring\"},\n",
    ")\n",
    "rel6 = Relationship(\n",
    "    id=10,\n",
    "    type=\"Starring\",\n",
    "    source=movie,\n",
    "    target=actor3,\n",
    "    properties={\"label\": \"Starring\"},\n",
    ")\n",
    "graph_doc = GraphDocument(\n",
    "    nodes=[movie, actor1, actor2, actor3],\n",
    "    relationships=[rel1, rel2, rel3, rel4, rel5, rel6],\n",
    "    source=source_doc,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d18f77a3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The underlying gremlin-python driver has a problem when running in a notebook.\n",
    "# The following line is a workaround for it.\n",
    "nest_asyncio.apply()\n",
    "\n",
    "# Add the document to the CosmosDB graph.\n",
    "graph.add_graph_documents([graph_doc])"
   ]
  },
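  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As an optional sanity check (a sketch, not part of the original flow), you can count the vertices that were just inserted with a raw Gremlin query:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Should return [4] on a freshly seeded graph: the movie plus the three actors.\n",
    "graph.query(\"g.V().count()\")"
   ]
  },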
  {
   "cell_type": "markdown",
   "id": "58c1a8ea",
   "metadata": {},
   "source": [
    "## Refresh graph schema information\n",
    "If the database schema changes (after updates), you can refresh the schema information.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4e3de44f",
   "metadata": {},
   "outputs": [],
   "source": [
    "graph.refresh_schema()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1fe76ccd",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(graph.schema)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "68a3c677",
   "metadata": {},
   "source": [
    "## Querying the graph\n",
    "\n",
    "We can now use the Gremlin QA chain to ask questions of the graph."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7476ce98",
   "metadata": {},
   "outputs": [],
   "source": [
    "chain = GremlinQAChain.from_llm(\n",
    "    AzureChatOpenAI(\n",
    "        temperature=0,\n",
    "        azure_deployment=\"gpt-4-turbo\",\n",
    "    ),\n",
    "    graph=graph,\n",
    "    verbose=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ef8ee27b",
   "metadata": {},
   "outputs": [],
   "source": [
    "chain.invoke(\"Who played in The Matrix?\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "47c64027-cf42-493a-9c76-2d10ba753728",
   "metadata": {},
   "outputs": [],
   "source": [
    "chain.invoke(\"How many people played in The Matrix?\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
@@ -0,0 +1,151 @@
from typing import Any, Dict, List, Optional, Union

import requests
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.chat_models import SimpleChatModel
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
from langchain_core.pydantic_v1 import Field


class ChatMaritalk(SimpleChatModel):
    """`MariTalk` Chat models API.

    This class allows interacting with the MariTalk chatbot API.
    To use it, you must provide an API key through the constructor.

    Example:
        .. code-block:: python

            from langchain_community.chat_models import ChatMaritalk
            chat = ChatMaritalk(api_key="your_api_key_here")
    """

    api_key: str
    """Your MariTalk API key."""

    temperature: float = Field(default=0.7, gt=0.0, lt=1.0)
    """Run inference with this temperature.
    Must be in the open interval (0.0, 1.0)."""

    max_tokens: int = Field(default=512, gt=0)
    """The maximum number of tokens to generate in the reply."""

    do_sample: bool = Field(default=True)
    """Whether or not to use sampling; use `True` to enable."""

    top_p: float = Field(default=0.95, gt=0.0, lt=1.0)
    """Nucleus sampling parameter controlling the size of
    the probability mass considered for sampling."""

    system_message_workaround: bool = Field(default=True)
    """Whether to include a workaround for system messages
    by adding them as a user message."""

    @property
    def _llm_type(self) -> str:
        """Identifies the LLM type as 'maritalk'."""
        return "maritalk"

    def parse_messages_for_model(
        self, messages: List[BaseMessage]
    ) -> List[Dict[str, Union[str, List[Union[str, Dict[Any, Any]]]]]]:
        """
        Parses messages from LangChain's format to the format expected by
        the MariTalk API.

        Parameters:
            messages (List[BaseMessage]): A list of messages in LangChain
                format to be parsed.

        Returns:
            A list of messages formatted for the MariTalk API.
        """
        parsed_messages = []

        for message in messages:
            if isinstance(message, HumanMessage):
                parsed_messages.append({"role": "user", "content": message.content})
            elif isinstance(message, AIMessage):
                parsed_messages.append(
                    {"role": "assistant", "content": message.content}
                )
            elif isinstance(message, SystemMessage) and self.system_message_workaround:
                # MariTalk models do not understand system messages.
                # Instead, we add these messages as user messages.
                parsed_messages.append({"role": "user", "content": message.content})
                parsed_messages.append({"role": "assistant", "content": "ok"})

        return parsed_messages

    def _call(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """
        Sends the parsed messages to the MariTalk API and returns the generated
        response or an error message.

        This method makes an HTTP POST request to the MariTalk API with the
        provided messages and other parameters.
        If the request is successful and the API returns a response,
        this method returns a string containing the answer.
        If the request is rate-limited or encounters another error,
        it returns a string with the error message.

        Parameters:
            messages (List[BaseMessage]): Messages to send to the model.
            stop (Optional[List[str]]): Tokens that will signal the model
                to stop generating further tokens.

        Returns:
            str: If the API call is successful, returns the answer.
                If an error occurs (e.g., rate limiting), returns a string
                describing the error.
        """
        try:
            url = "https://chat.maritaca.ai/api/chat/inference"
            headers = {"authorization": f"Key {self.api_key}"}
            stopping_tokens = stop if stop is not None else []

            parsed_messages = self.parse_messages_for_model(messages)

            data = {
                "messages": parsed_messages,
                "do_sample": self.do_sample,
                "max_tokens": self.max_tokens,
                "temperature": self.temperature,
                "top_p": self.top_p,
                "stopping_tokens": stopping_tokens,
                **kwargs,
            }

            response = requests.post(url, json=data, headers=headers)
            if response.status_code == 429:
                return "Rate limited, please try again soon"
            elif response.ok:
                return response.json().get("answer", "No answer found")

        except requests.exceptions.RequestException as e:
            return f"An error occurred: {str(e)}"

        # Fallback return statement, in case of unexpected code paths
        return "An unexpected error occurred"

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """
        Identifies the key parameters of the chat model for logging
        or tracking purposes.

        Returns:
            A dictionary of the key configuration parameters.
        """
        return {
            "system_message_workaround": self.system_message_workaround,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "max_tokens": self.max_tokens,
        }
@@ -0,0 +1,271 @@
"""Wrapper around Perplexity APIs."""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import (
|
||||
Any,
|
||||
Dict,
|
||||
Iterator,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
)
|
||||
|
||||
from langchain_core.callbacks import CallbackManagerForLLMRun
|
||||
from langchain_core.language_models.chat_models import (
|
||||
BaseChatModel,
|
||||
generate_from_stream,
|
||||
)
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
AIMessageChunk,
|
||||
BaseMessage,
|
||||
BaseMessageChunk,
|
||||
ChatMessage,
|
||||
ChatMessageChunk,
|
||||
FunctionMessageChunk,
|
||||
HumanMessage,
|
||||
HumanMessageChunk,
|
||||
SystemMessage,
|
||||
SystemMessageChunk,
|
||||
ToolMessageChunk,
|
||||
)
|
||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||
from langchain_core.pydantic_v1 import Field, root_validator
|
||||
from langchain_core.utils import get_from_dict_or_env, get_pydantic_field_names
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ChatPerplexity(BaseChatModel):
|
||||
"""`Perplexity AI` Chat models API.
|
||||
|
||||
To use, you should have the ``openai`` python package installed, and the
|
||||
environment variable ``PPLX_API_KEY`` set to your API key.
|
||||
Any parameters that are valid to be passed to the openai.create call can be passed
|
||||
in, even if not explicitly saved on this class.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_community.chat_models import ChatPerplexity
|
||||
|
||||
chat = ChatPerplexity(model="pplx-70b-online", temperature=0.7)
|
||||
"""
|
||||
|
||||
client: Any #: :meta private:
|
||||
model: str = "pplx-70b-online"
|
||||
"""Model name."""
|
||||
temperature: float = 0.7
|
||||
"""What sampling temperature to use."""
|
||||
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
"""Holds any model parameters valid for `create` call not explicitly specified."""
|
||||
pplx_api_key: Optional[str] = None
|
||||
"""Base URL path for API requests,
|
||||
leave blank if not using a proxy or service emulator."""
|
||||
request_timeout: Optional[Union[float, Tuple[float, float]]] = None
|
||||
"""Timeout for requests to PerplexityChat completion API. Default is 600 seconds."""
|
||||
max_retries: int = 6
|
||||
"""Maximum number of retries to make when generating."""
|
||||
streaming: bool = False
|
||||
"""Whether to stream the results or not."""
|
||||
max_tokens: Optional[int] = None
|
||||
"""Maximum number of tokens to generate."""
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
allow_population_by_field_name = True
|
||||
|
||||
@property
|
||||
def lc_secrets(self) -> Dict[str, str]:
|
||||
return {"pplx_api_key": "PPLX_API_KEY"}
|
||||
|
||||
@root_validator(pre=True, allow_reuse=True)
|
||||
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Build extra kwargs from additional params that were passed in."""
|
||||
all_required_field_names = get_pydantic_field_names(cls)
|
||||
extra = values.get("model_kwargs", {})
|
||||
for field_name in list(values):
|
||||
if field_name in extra:
|
||||
raise ValueError(f"Found {field_name} supplied twice.")
|
||||
if field_name not in all_required_field_names:
|
||||
logger.warning(
|
||||
f"""WARNING! {field_name} is not a default parameter.
|
||||
{field_name} was transferred to model_kwargs.
|
||||
Please confirm that {field_name} is what you intended."""
|
||||
)
|
||||
extra[field_name] = values.pop(field_name)
|
||||
|
||||
invalid_model_kwargs = all_required_field_names.intersection(extra.keys())
|
||||
if invalid_model_kwargs:
|
||||
raise ValueError(
|
||||
f"Parameters {invalid_model_kwargs} should be specified explicitly. "
|
||||
f"Instead they were passed in as part of `model_kwargs` parameter."
|
||||
)
|
||||
|
||||
values["model_kwargs"] = extra
|
||||
return values
|
||||
|
||||
@root_validator(allow_reuse=True)
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
values["pplx_api_key"] = get_from_dict_or_env(
|
||||
values, "pplx_api_key", "PPLX_API_KEY"
|
||||
)
|
||||
try:
|
||||
import openai # noqa: F401
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Could not import openai python package. "
|
||||
"Please install it with `pip install openai`."
|
||||
)
|
||||
try:
|
||||
values["client"] = openai.OpenAI(
|
||||
api_key=values["pplx_api_key"], base_url="https://api.perplexity.ai"
|
||||
)
|
||||
except AttributeError:
|
||||
raise ValueError(
|
||||
"`openai` has no `ChatCompletion` attribute, this is likely "
|
||||
"due to an old version of the openai package. Try upgrading it "
|
||||
"with `pip install --upgrade openai`."
|
||||
)
|
||||
return values
|
||||
|
||||
@property
|
||||
def _default_params(self) -> Dict[str, Any]:
|
||||
"""Get the default parameters for calling PerplexityChat API."""
|
||||
return {
|
||||
"request_timeout": self.request_timeout,
|
||||
"max_tokens": self.max_tokens,
|
||||
"stream": self.streaming,
|
||||
"temperature": self.temperature,
|
||||
**self.model_kwargs,
|
||||
}
|
||||
|
||||
def _convert_message_to_dict(self, message: BaseMessage) -> Dict[str, Any]:
|
||||
if isinstance(message, ChatMessage):
|
||||
message_dict = {"role": message.role, "content": message.content}
|
||||
elif isinstance(message, SystemMessage):
|
||||
message_dict = {"role": "system", "content": message.content}
|
||||
elif isinstance(message, HumanMessage):
|
||||
message_dict = {"role": "user", "content": message.content}
|
||||
elif isinstance(message, AIMessage):
|
||||
message_dict = {"role": "assistant", "content": message.content}
|
||||
else:
|
||||
raise TypeError(f"Got unknown type {message}")
|
||||
return message_dict
|
||||
|
||||
def _create_message_dicts(
|
||||
self, messages: List[BaseMessage], stop: Optional[List[str]]
|
||||
) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
|
||||
params = dict(self._invocation_params)
|
||||
if stop is not None:
|
||||
if "stop" in params:
|
||||
raise ValueError("`stop` found in both the input and default params.")
|
||||
params["stop"] = stop
|
||||
message_dicts = [self._convert_message_to_dict(m) for m in messages]
|
||||
return message_dicts, params
|
||||
|
||||
def _convert_delta_to_message_chunk(
|
||||
self, _dict: Mapping[str, Any], default_class: Type[BaseMessageChunk]
|
||||
) -> BaseMessageChunk:
|
||||
role = _dict.get("role")
|
||||
content = _dict.get("content") or ""
|
||||
additional_kwargs: Dict = {}
|
||||
if _dict.get("function_call"):
|
||||
function_call = dict(_dict["function_call"])
|
||||
if "name" in function_call and function_call["name"] is None:
|
||||
function_call["name"] = ""
|
||||
additional_kwargs["function_call"] = function_call
|
||||
if _dict.get("tool_calls"):
|
||||
additional_kwargs["tool_calls"] = _dict["tool_calls"]
|
||||
|
||||
if role == "user" or default_class == HumanMessageChunk:
|
||||
return HumanMessageChunk(content=content)
|
||||
elif role == "assistant" or default_class == AIMessageChunk:
|
||||
return AIMessageChunk(content=content, additional_kwargs=additional_kwargs)
|
||||
elif role == "system" or default_class == SystemMessageChunk:
|
||||
return SystemMessageChunk(content=content)
|
||||
elif role == "function" or default_class == FunctionMessageChunk:
|
||||
return FunctionMessageChunk(content=content, name=_dict["name"])
|
||||
elif role == "tool" or default_class == ToolMessageChunk:
|
||||
return ToolMessageChunk(content=content, tool_call_id=_dict["tool_call_id"])
|
||||
elif role or default_class == ChatMessageChunk:
|
||||
return ChatMessageChunk(content=content, role=role)
|
||||
else:
|
||||
return default_class(content=content)
|
||||
|
||||
def _stream(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[ChatGenerationChunk]:
|
||||
message_dicts, params = self._create_message_dicts(messages, stop)
|
||||
params = {**params, **kwargs}
|
||||
default_chunk_class = AIMessageChunk
|
||||
|
||||
if stop:
|
||||
params["stop_sequences"] = stop
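        # NOTE: only `model`, `messages`, and `stream` are forwarded to the
        # client below; the other accumulated params (e.g. temperature,
        # max_tokens, stop_sequences) are not currently passed to the API call.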
        stream_resp = self.client.chat.completions.create(
            model=params["model"], messages=message_dicts, stream=True
        )
        for chunk in stream_resp:
            if not isinstance(chunk, dict):
                chunk = chunk.dict()
            if len(chunk["choices"]) == 0:
                continue
            choice = chunk["choices"][0]
            chunk = self._convert_delta_to_message_chunk(
                choice["delta"], default_chunk_class
            )
            finish_reason = choice.get("finish_reason")
            generation_info = (
                dict(finish_reason=finish_reason) if finish_reason is not None else None
            )
            default_chunk_class = chunk.__class__
            chunk = ChatGenerationChunk(message=chunk, generation_info=generation_info)
            yield chunk
            if run_manager:
                run_manager.on_llm_new_token(chunk.text, chunk=chunk)

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        if self.streaming:
            stream_iter = self._stream(
                messages, stop=stop, run_manager=run_manager, **kwargs
            )
            return generate_from_stream(stream_iter)
        message_dicts, params = self._create_message_dicts(messages, stop)
        params = {**params, **kwargs}
        response = self.client.chat.completions.create(
            model=params["model"], messages=message_dicts
        )
        message = AIMessage(content=response.choices[0].message.content)
        return ChatResult(generations=[ChatGeneration(message=message)])

    @property
    def _invocation_params(self) -> Mapping[str, Any]:
        """Get the parameters used to invoke the model."""
        pplx_creds: Dict[str, Any] = {
            "api_key": self.pplx_api_key,
            "api_base": "https://api.perplexity.ai",
            "model": self.model,
        }
        return {**pplx_creds, **self._default_params}

    @property
    def _llm_type(self) -> str:
        """Return type of chat model."""
        return "perplexitychat"
@@ -0,0 +1,119 @@
import logging
from typing import List, Optional

import requests
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel

logger = logging.getLogger(__name__)


class LlamafileEmbeddings(BaseModel, Embeddings):
    """Llamafile lets you distribute and run large language models with a
    single file.

    To get started, see: https://github.com/Mozilla-Ocho/llamafile

    To use this class, you will need to first:

    1. Download a llamafile.
    2. Make the downloaded file executable: `chmod +x path/to/model.llamafile`
    3. Start the llamafile in server mode with embeddings enabled:

        `./path/to/model.llamafile --server --nobrowser --embedding`

    Example:
        .. code-block:: python

            from langchain_community.embeddings import LlamafileEmbeddings
            embedder = LlamafileEmbeddings()
            doc_embeddings = embedder.embed_documents(
                [
                    "Alpha is the first letter of the Greek alphabet",
                    "Beta is the second letter of the Greek alphabet",
                ]
            )
            query_embedding = embedder.embed_query(
                "What is the second letter of the Greek alphabet"
            )
    """

    base_url: str = "http://localhost:8080"
    """Base url where the llamafile server is listening."""

    request_timeout: Optional[int] = None
    """Timeout for server requests."""

    def _embed(self, text: str) -> List[float]:
        try:
            response = requests.post(
                url=f"{self.base_url}/embedding",
                headers={
                    "Content-Type": "application/json",
                },
                json={
                    "content": text,
                },
                timeout=self.request_timeout,
            )
        except requests.exceptions.ConnectionError:
            raise requests.exceptions.ConnectionError(
                f"Could not connect to Llamafile server. Please make sure "
                f"that a server is running at {self.base_url}."
            )

        # Raise exception if we got a bad (non-200) response status code
        response.raise_for_status()

        contents = response.json()
        if "embedding" not in contents:
            raise KeyError(
                "Unexpected output from /embedding endpoint, output dict "
                "missing 'embedding' key."
            )

        embedding = contents["embedding"]

        # Sanity check the embedding vector:
        # Prior to llamafile v0.6.2, if the server was not started with the
        # `--embedding` option, the embedding endpoint would always return a
        # 0-vector. See issue:
        # https://github.com/Mozilla-Ocho/llamafile/issues/243
        # So here we raise an exception if the vector sums to exactly 0.
        if sum(embedding) == 0.0:
            raise ValueError(
                "Embedding sums to 0, did you start the llamafile server with "
                "the `--embedding` option enabled?"
            )

        return embedding

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed documents using a llamafile server running at `self.base_url`.
        The llamafile server should be started in a separate process before
        invoking this method.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text.
        """
        doc_embeddings = []
        for text in texts:
            doc_embeddings.append(self._embed(text))
        return doc_embeddings

    def embed_query(self, text: str) -> List[float]:
        """Embed a query using a llamafile server running at `self.base_url`.
        The llamafile server should be started in a separate process before
        invoking this method.

        Args:
            text: The text to embed.

        Returns:
            Embeddings for the text.
        """
        return self._embed(text)
@@ -0,0 +1,207 @@
import hashlib
import sys
from typing import Any, Dict, List, Optional, Union

from langchain_core.utils import get_from_env

from langchain_community.graphs.graph_document import GraphDocument, Node, Relationship
from langchain_community.graphs.graph_store import GraphStore


class GremlinGraph(GraphStore):
    """Gremlin wrapper for graph operations.

    Parameters:
        url (Optional[str]): The URL of the Gremlin database server,
            or env GREMLIN_URI if none provided
        username (Optional[str]): The collection-identifier like
            '/dbs/database/colls/graph', or env GREMLIN_USERNAME if none provided
        password (Optional[str]): The connection-key for database authentication,
            or env GREMLIN_PASSWORD if none provided
        traversal_source (str): The traversal source to use for queries.
            Defaults to 'g'.
        message_serializer (Optional[Any]): The message serializer to use
            for requests. Defaults to serializer.GraphSONSerializersV2d0()

    *Security note*: Make sure that the database connection uses credentials
        that are narrowly-scoped to only include necessary permissions.
        Failure to do so may result in data corruption or loss, since the calling
        code may attempt commands that would result in deletion, mutation
        of data if appropriately prompted or reading sensitive data if such
        data is present in the database.
        The best way to guard against such negative outcomes is to (as appropriate)
        limit the permissions granted to the credentials used with this tool.

        See https://python.langchain.com/docs/security for more information.

    *Implementation details*:
        The Gremlin queries are designed to work with Azure CosmosDB limitations.
    """

    @property
    def get_structured_schema(self) -> Dict[str, Any]:
        return self.structured_schema

    def __init__(
        self,
        url: Optional[str] = None,
        username: Optional[str] = None,
        password: Optional[str] = None,
        traversal_source: str = "g",
        message_serializer: Optional[Any] = None,
    ) -> None:
        """Create a new Gremlin graph wrapper instance."""
        try:
            import asyncio

            from gremlin_python.driver import client, serializer

            if sys.platform == "win32":
                asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
        except ImportError:
            raise ValueError(
                "Please install gremlinpython first: `pip3 install gremlinpython`"
            )

        self.client = client.Client(
            url=get_from_env("url", "GREMLIN_URI", url),
            traversal_source=traversal_source,
            username=get_from_env("username", "GREMLIN_USERNAME", username),
            password=get_from_env("password", "GREMLIN_PASSWORD", password),
            message_serializer=message_serializer
            if message_serializer
            else serializer.GraphSONSerializersV2d0(),
        )
        self.schema: str = ""

    @property
    def get_schema(self) -> str:
        """Returns the schema of the Gremlin database"""
        if len(self.schema) == 0:
            self.refresh_schema()
        return self.schema

    def refresh_schema(self) -> None:
        """
        Refreshes the Gremlin graph schema information.
        """
        vertex_schema = self.client.submit("g.V().label().dedup()").all().result()
        edge_schema = self.client.submit("g.E().label().dedup()").all().result()
        vertex_properties = (
            self.client.submit(
                "g.V().group().by(label).by(properties().label().dedup().fold())"
            )
            .all()
            .result()[0]
        )

        self.structured_schema = {
            "vertex_labels": vertex_schema,
            "edge_labels": edge_schema,
            "vertice_props": vertex_properties,
        }

        self.schema = "\n".join(
            [
                "Vertex labels are the following:",
                ",".join(vertex_schema),
                "Edge labels are the following:",
                ",".join(edge_schema),
                f"Vertices have following properties:\n{vertex_properties}",
            ]
        )

    def query(self, query: str, params: dict = {}) -> List[Dict[str, Any]]:
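        # NOTE: `params` is accepted for interface compatibility with
        # GraphStore but is not currently bound into the submitted query.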
        q = self.client.submit(query)
        return q.all().result()

    def add_graph_documents(
        self, graph_documents: List[GraphDocument], include_source: bool = False
    ) -> None:
        """
        Takes GraphDocument as input and uses it to construct a graph.
        """
        node_cache: Dict[Union[str, int], Node] = {}
        for document in graph_documents:
            if include_source:
                # Create document vertex
                doc_props = {
                    "page_content": document.source.page_content,
                    "metadata": document.source.metadata,
                }
                doc_id = hashlib.md5(document.source.page_content.encode()).hexdigest()
                doc_node = self.add_node(
                    Node(id=doc_id, type="Document", properties=doc_props), node_cache
                )

            # Import nodes to vertices
            for n in document.nodes:
                node = self.add_node(n, node_cache)
                if include_source:
                    # Add edges between the document and each node
                    self.add_edge(
                        Relationship(
                            type="contains information about",
                            source=doc_node,
                            target=node,
                            properties={},
                        )
                    )
                    self.add_edge(
                        Relationship(
                            type="is extracted from",
                            source=node,
                            target=doc_node,
                            properties={},
                        )
                    )

            # Edges
            for el in document.relationships:
                # Find or create the source vertex
                self.add_node(el.source, node_cache)
                # Find or create the target vertex
                self.add_node(el.target, node_cache)
                # Find or create the edge
                self.add_edge(el)

    def build_vertex_query(self, node: Node) -> str:
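        # Upsert idiom: fold()/coalesce(unfold(), addV(...)) returns the
        # existing vertex with this id if present, otherwise creates it.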
        base_query = (
            f"g.V().has('id','{node.id}').fold()"
            + f".coalesce(unfold(),addV('{node.type}')"
            + f".property('id','{node.id}')"
            + f".property('type','{node.type}')"
        )
        for key, value in node.properties.items():
            base_query += f".property('{key}', '{value}')"

        return base_query + ")"

    def build_edge_query(self, relationship: Relationship) -> str:
        source_query = f".has('id','{relationship.source.id}')"
        target_query = f".has('id','{relationship.target.id}')"

        base_query = f"""g.V(){source_query}.as('a')
            .V(){target_query}.as('b')
            .choose(
                __.inE('{relationship.type}').where(outV().as('a')),
                __.identity(),
                __.addE('{relationship.type}').from('a').to('b')
            )
        """.replace("\n", "").replace("\t", "")
        for key, value in relationship.properties.items():
            base_query += f".property('{key}', '{value}')"

        return base_query

    def add_node(self, node: Node, node_cache: dict = {}) -> Node:
        # if the properties do not include a label, use the node type as the label
        if "label" not in node.properties:
            node.properties["label"] = node.type
        if node.id in node_cache:
            return node_cache[node.id]
        else:
            query = self.build_vertex_query(node)
            _ = self.client.submit(query).all().result()[0]
            node_cache[node.id] = node
            return node

    def add_edge(self, relationship: Relationship) -> Any:
        query = self.build_edge_query(relationship)
        return self.client.submit(query).all().result()
|
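# --- Illustrative usage sketch (not part of this diff) ----------------------
# A minimal end-to-end example of the class above. The endpoint, credentials,
# and constructor arguments are placeholder assumptions for illustration.
from langchain_community.graphs import GremlinGraph
from langchain_community.graphs.graph_document import GraphDocument, Node, Relationship
from langchain_core.documents import Document

graph = GremlinGraph(
    url="wss://<your-endpoint>:443/",  # placeholder endpoint
    username="/dbs/<db>/colls/<coll>",  # placeholder (Cosmos DB style)
    password="<key>",  # placeholder key
)

doc = Document(page_content="Alice works at Acme.")
alice = Node(id="alice", type="Person", properties={})
acme = Node(id="acme", type="Company", properties={})
graph.add_graph_documents(
    [
        GraphDocument(
            nodes=[alice, acme],
            relationships=[
                Relationship(source=alice, target=acme, type="works at", properties={})
            ],
            source=doc,
        )
    ],
    include_source=True,
)
print(graph.query("g.V().count()"))
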
@ -0,0 +1,30 @@
"""Test Perplexity Chat API wrapper."""

import os

import pytest

from langchain_community.chat_models import ChatPerplexity

os.environ["PPLX_API_KEY"] = "foo"


@pytest.mark.requires("openai")
def test_perplexity_model_name_param() -> None:
    llm = ChatPerplexity(model="foo")
    assert llm.model == "foo"


@pytest.mark.requires("openai")
def test_perplexity_model_kwargs() -> None:
    llm = ChatPerplexity(model="test", model_kwargs={"foo": "bar"})
    assert llm.model_kwargs == {"foo": "bar"}


@pytest.mark.requires("openai")
def test_perplexity_initialization() -> None:
    """Test Perplexity initialization."""
    # Verify that ChatPerplexity can be initialized using a secret key provided
    # as a parameter rather than an environment variable.
    ChatPerplexity(
        model="test", perplexity_api_key="test", temperature=0.7, verbose=True
    )

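# --- Illustrative companion test (not part of this diff) --------------------
# Hypothetical sketch: initialization without an explicit key should fall back
# to the PPLX_API_KEY environment variable set at the top of this module.
@pytest.mark.requires("openai")
def test_perplexity_initialization_from_env() -> None:
    ChatPerplexity(model="test", temperature=0.7)
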
@ -0,0 +1,67 @@
import json

import numpy as np
import requests
from pytest import MonkeyPatch

from langchain_community.embeddings import LlamafileEmbeddings


def mock_response() -> requests.Response:
    contents = json.dumps({"embedding": np.random.randn(512).tolist()})
    response = requests.Response()
    response.status_code = 200
    response._content = str.encode(contents)
    return response


def test_embed_documents(monkeypatch: MonkeyPatch) -> None:
    """
    Test basic functionality of the `embed_documents` method.
    """
    embedder = LlamafileEmbeddings(
        base_url="http://llamafile-host:8080",
    )

    def mock_post(url, headers, json, timeout):  # type: ignore[no-untyped-def]
        assert url == "http://llamafile-host:8080/embedding"
        assert headers == {
            "Content-Type": "application/json",
        }
        assert json == {"content": "Test text"}
        assert timeout is None
        return mock_response()

    monkeypatch.setattr(requests, "post", mock_post)
    out = embedder.embed_documents(["Test text", "Test text"])
    assert isinstance(out, list)
    assert len(out) == 2
    for vec in out:
        assert len(vec) == 512


def test_embed_query(monkeypatch: MonkeyPatch) -> None:
    """
    Test basic functionality of the `embed_query` method.
    """
    embedder = LlamafileEmbeddings(
        base_url="http://llamafile-host:8080",
    )

    def mock_post(url, headers, json, timeout):  # type: ignore[no-untyped-def]
        assert url == "http://llamafile-host:8080/embedding"
        assert headers == {
            "Content-Type": "application/json",
        }
        assert json == {"content": "Test text"}
        assert timeout is None
        return mock_response()

    monkeypatch.setattr(requests, "post", mock_post)
    out = embedder.embed_query("Test text")
    assert isinstance(out, list)
    assert len(out) == 512

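# --- Illustrative raw request (not part of this diff) -----------------------
# The mocks above pin down the wire format the embedder is expected to use.
# The equivalent plain-requests call against a llamafile server (placeholder
# host) would be:
import requests

resp = requests.post(
    "http://llamafile-host:8080/embedding",
    headers={"Content-Type": "application/json"},
    json={"content": "Test text"},
    timeout=None,
)
vector = resp.json()["embedding"]  # a list of floats; 512-d in the mock above
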
@ -0,0 +1,221 @@
"""Question answering over a graph."""
from __future__ import annotations

from typing import Any, Dict, List, Optional

from langchain_community.graphs import GremlinGraph
from langchain_core.callbacks.manager import CallbackManager, CallbackManagerForChainRun
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import BasePromptTemplate
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.pydantic_v1 import Field

from langchain.chains.base import Chain
from langchain.chains.graph_qa.prompts import (
    CYPHER_QA_PROMPT,
    GRAPHDB_SPARQL_FIX_TEMPLATE,
    GREMLIN_GENERATION_PROMPT,
)
from langchain.chains.llm import LLMChain

INTERMEDIATE_STEPS_KEY = "intermediate_steps"


def extract_gremlin(text: str) -> str:
    """Extract Gremlin code from a text.

    Args:
        text: Text to extract Gremlin code from.

    Returns:
        Gremlin code extracted from the text.
    """
    text = text.replace("`", "")
    if text.startswith("gremlin"):
        text = text[len("gremlin") :]
    return text.replace("\n", "")

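# --- Illustrative example (not part of this diff) ---------------------------
# What extract_gremlin strips from a typical fenced LLM reply:
fence = "`" * 3
raw = f"{fence}gremlin\ng.V().hasLabel('Person').count()\n{fence}"
assert extract_gremlin(raw) == "g.V().hasLabel('Person').count()"
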
class GremlinQAChain(Chain):
    """Chain for question-answering against a graph by generating Gremlin statements.

    *Security note*: Make sure that the database connection uses credentials
    that are narrowly-scoped to only include necessary permissions.
    Failure to do so may result in data corruption or loss, since the calling
    code may attempt commands that would result in deletion or mutation
    of data if appropriately prompted, or in reading sensitive data if such
    data is present in the database.
    The best way to guard against such negative outcomes is to (as appropriate)
    limit the permissions granted to the credentials used with this tool.

    See https://python.langchain.com/docs/security for more information.
    """

    graph: GremlinGraph = Field(exclude=True)
    gremlin_generation_chain: LLMChain
    qa_chain: LLMChain
    gremlin_fix_chain: LLMChain
    max_fix_retries: int = 3
    input_key: str = "query"  #: :meta private:
    output_key: str = "result"  #: :meta private:
    top_k: int = 100
    return_direct: bool = False
    return_intermediate_steps: bool = False

    @property
    def input_keys(self) -> List[str]:
        """Input keys.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Output keys.

        :meta private:
        """
        _output_keys = [self.output_key]
        return _output_keys

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        *,
        gremlin_fix_prompt: BasePromptTemplate = PromptTemplate(
            input_variables=["error_message", "generated_sparql", "schema"],
            template=GRAPHDB_SPARQL_FIX_TEMPLATE.replace("SPARQL", "Gremlin").replace(
                "in Turtle format", ""
            ),
        ),
        qa_prompt: BasePromptTemplate = CYPHER_QA_PROMPT,
        gremlin_prompt: BasePromptTemplate = GREMLIN_GENERATION_PROMPT,
        **kwargs: Any,
    ) -> GremlinQAChain:
        """Initialize from LLM."""
        qa_chain = LLMChain(llm=llm, prompt=qa_prompt)
        gremlin_generation_chain = LLMChain(llm=llm, prompt=gremlin_prompt)
        gremlin_fix_chain = LLMChain(llm=llm, prompt=gremlin_fix_prompt)
        return cls(
            qa_chain=qa_chain,
            gremlin_generation_chain=gremlin_generation_chain,
            gremlin_fix_chain=gremlin_fix_chain,
            **kwargs,
        )

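    # --- Illustrative usage sketch (not part of this diff) ------------------
    # Hypothetical wiring of the chain; the import path, chat model, and graph
    # arguments below are placeholder assumptions:
    #
    #     from langchain.chains.graph_qa.gremlin import GremlinQAChain
    #     from langchain_community.graphs import GremlinGraph
    #
    #     graph = GremlinGraph(
    #         url="wss://<endpoint>:443/", username="<user>", password="<key>"
    #     )
    #     chain = GremlinQAChain.from_llm(llm, graph=graph, verbose=True)
    #     print(chain.invoke({"query": "Who works at Acme?"})["result"])
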
    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        """Generate a Gremlin statement, query the graph, and answer the question."""
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        callbacks = _run_manager.get_child()
        question = inputs[self.input_key]

        intermediate_steps: List = []

        chain_response = self.gremlin_generation_chain.invoke(
            {"question": question, "schema": self.graph.get_schema}, callbacks=callbacks
        )

        generated_gremlin = extract_gremlin(
            chain_response[self.gremlin_generation_chain.output_key]
        )

        _run_manager.on_text("Generated gremlin:", end="\n", verbose=self.verbose)
        _run_manager.on_text(
            generated_gremlin, color="green", end="\n", verbose=self.verbose
        )

        intermediate_steps.append({"query": generated_gremlin})

        if generated_gremlin:
            context = self.execute_with_retry(
                _run_manager, callbacks, generated_gremlin
            )[: self.top_k]
        else:
            context = []

        if self.return_direct:
            final_result = context
        else:
            _run_manager.on_text("Full Context:", end="\n", verbose=self.verbose)
            _run_manager.on_text(
                str(context), color="green", end="\n", verbose=self.verbose
            )

            intermediate_steps.append({"context": context})

            result = self.qa_chain.invoke(
                {"question": question, "context": context},
                callbacks=callbacks,
            )
            final_result = result[self.qa_chain.output_key]

        chain_result: Dict[str, Any] = {self.output_key: final_result}
        if self.return_intermediate_steps:
            chain_result[INTERMEDIATE_STEPS_KEY] = intermediate_steps

        return chain_result

    def execute_query(self, query: str) -> List[Any]:
        try:
            return self.graph.query(query)
        except Exception as e:
            if hasattr(e, "status_message"):
                raise ValueError(e.status_message)
            else:
                raise ValueError(str(e))

    def execute_with_retry(
        self,
        _run_manager: CallbackManagerForChainRun,
        callbacks: CallbackManager,
        generated_gremlin: str,
    ) -> List[Any]:
        try:
            return self.execute_query(generated_gremlin)
        except Exception as e:
            retries = 0
            error_message = str(e)
            self.log_invalid_query(_run_manager, generated_gremlin, error_message)
            # Start from the original query so `fixed_gremlin` is always bound,
            # even if the fix chain itself raises on the first attempt.
            fixed_gremlin = generated_gremlin

            while retries < self.max_fix_retries:
                try:
                    fix_chain_result = self.gremlin_fix_chain.invoke(
                        {
                            "error_message": error_message,
                            # we are borrowing the template from sparql
                            "generated_sparql": generated_gremlin,
                            # the chain itself has no `schema` attribute;
                            # use the schema exposed by the graph
                            "schema": self.graph.get_schema,
                        },
                        callbacks=callbacks,
                    )
                    fixed_gremlin = fix_chain_result[self.gremlin_fix_chain.output_key]
                    return self.execute_query(fixed_gremlin)
                except Exception as e:
                    retries += 1
                    parse_exception = str(e)
                    self.log_invalid_query(_run_manager, fixed_gremlin, parse_exception)

            raise ValueError("The generated Gremlin query is invalid.")

    def log_invalid_query(
        self,
        _run_manager: CallbackManagerForChainRun,
        generated_query: str,
        error_message: str,
    ) -> None:
        _run_manager.on_text("Invalid Gremlin query: ", end="\n", verbose=self.verbose)
        _run_manager.on_text(
            generated_query, color="red", end="\n", verbose=self.verbose
        )
        _run_manager.on_text(
            "Gremlin Query Parse Error: ", end="\n", verbose=self.verbose
        )
        _run_manager.on_text(
            error_message, color="red", end="\n\n", verbose=self.verbose
        )

@ -1,5 +0,0 @@
from openai_functions_agent.agent import agent_executor

if __name__ == "__main__":
    question = "who won the womens world cup in 2023?"
    print(agent_executor.invoke({"input": question, "chat_history": []}))  # noqa: T201