mirror of
https://github.com/hwchase17/langchain
synced 2024-10-29 17:07:25 +00:00
b14d74dd4d
Add an iMessage chat loader
421 lines
12 KiB
Plaintext
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "01fcfa2f-33a9-48f3-835a-b1956c394d6b",
   "metadata": {},
   "source": [
    "# iMessage\n",
    "\n",
    "This notebook shows how to use the iMessage chat loader. This class helps convert iMessage conversations to LangChain chat messages.\n",
    "\n",
    "On macOS, iMessage stores conversations in a SQLite database at `~/Library/Messages/chat.db` (at least for macOS Ventura 13.4).\n",
    "The `IMessageChatLoader` loads from this database file.\n",
    "\n",
    "1. Create the `IMessageChatLoader` with the file path pointing to the `chat.db` database you'd like to process.\n",
    "2. Call `loader.load()` (or `loader.lazy_load()`) to perform the conversion. Optionally use `merge_chat_runs` to combine consecutive messages from the same sender, and/or `map_ai_messages` to convert messages from the specified sender to the \"AIMessage\" class.\n",
    "\n",
    "## 1. Access Chat DB\n",
    "\n",
    "It's likely that your terminal is denied access to `~/Library/Messages`. To use this class, you can copy the DB to an accessible directory (e.g., Documents) and load from there. Alternatively (and not recommended), you can grant full disk access to your terminal emulator in System Settings > Privacy & Security > Full Disk Access.\n",
    "\n",
    "We have created an example database you can use at [this linked drive file](https://drive.google.com/file/d/1NebNKqTA2NXApCmeH6mu0unJD2tANZzo/view?usp=sharing)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "036ce7e0-a38f-4cbe-89a6-a205ae7c23be",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "File chat.db downloaded.\n"
     ]
    }
   ],
   "source": [
    "# This uses some example data\n",
    "import requests\n",
    "\n",
    "def download_drive_file(url: str, output_path: str = 'chat.db') -> None:\n",
    "    file_id = url.split('/')[-2]\n",
    "    download_url = f'https://drive.google.com/uc?export=download&id={file_id}'\n",
    "\n",
    "    response = requests.get(download_url)\n",
    "    if response.status_code != 200:\n",
    "        print('Failed to download the file.')\n",
    "        return\n",
    "\n",
    "    with open(output_path, 'wb') as file:\n",
    "        file.write(response.content)\n",
    "        print(f'File {output_path} downloaded.')\n",
    "\n",
    "url = 'https://drive.google.com/file/d/1NebNKqTA2NXApCmeH6mu0unJD2tANZzo/view?usp=sharing'\n",
    "\n",
    "# Download file to chat.db\n",
    "download_drive_file(url)"
   ]
  },
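  {
   "cell_type": "markdown",
   "id": "5d1c7e42-9b0a-4f6e-8a31-2c4e7b9d0f11",
   "metadata": {},
   "source": [
    "As an optional sanity check before loading, the sketch below lists the tables in the downloaded file using only Python's standard `sqlite3` module. The helper name `list_tables` is made up for this example and is not part of LangChain; it assumes only that `chat.db` is a regular SQLite file, as described above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7a3e9c18-2f4b-4d6a-9c05-6b1d8e2f3a47",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sqlite3\n",
    "from typing import List\n",
    "\n",
    "\n",
    "def list_tables(db_path: str) -> List[str]:\n",
    "    # Hypothetical helper: return the table names found in a SQLite file.\n",
    "    # Open read-only through a URI so the check cannot modify the database.\n",
    "    conn = sqlite3.connect(f'file:{db_path}?mode=ro', uri=True)\n",
    "    try:\n",
    "        rows = conn.execute(\n",
    "            \"SELECT name FROM sqlite_master WHERE type = 'table' ORDER BY name\"\n",
    "        ).fetchall()\n",
    "    finally:\n",
    "        conn.close()\n",
    "    return [name for (name,) in rows]\n",
    "\n",
    "\n",
    "# Only inspect the file if the download above succeeded\n",
    "if os.path.exists('chat.db'):\n",
    "    print(list_tables('chat.db'))"
   ]
  },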
  {
   "cell_type": "markdown",
   "id": "cf60f703-76f1-4602-a723-02c59535c1af",
   "metadata": {},
   "source": [
    "## 2. Create the Chat Loader\n",
    "\n",
    "Provide the loader with the file path to the `chat.db` database. You can optionally specify the user ID that maps to an AI message, as well as configure whether to merge message runs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "4b8b432a-d2bc-49e1-b35f-761730a8fd6d",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.chat_loaders.imessage import IMessageChatLoader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "8ec6661b-0aca-48ae-9e2b-6412856c287b",
   "metadata": {},
   "outputs": [],
   "source": [
    "loader = IMessageChatLoader(\n",
    "    path=\"./chat.db\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8805a7c5-84b4-49f5-8989-0022f2054ace",
   "metadata": {},
   "source": [
    "## 3. Load messages\n",
    "\n",
    "The `load()` method returns a list of \"ChatSessions\" (and `lazy_load()` an iterator over them); each session currently just contains the list of messages for one loaded conversation. All messages are mapped to \"HumanMessage\" objects to start.\n",
    "\n",
    "You can optionally choose to merge message \"runs\" (consecutive messages from the same sender) and select a sender to represent the \"AI\". The fine-tuned LLM will learn to generate these AI messages."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "fcd69b3e-020d-4a15-8a0d-61c2d34e1ee1",
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import List\n",
    "from langchain.chat_loaders.base import ChatSession\n",
    "from langchain.chat_loaders.utils import (\n",
    "    map_ai_messages,\n",
    "    merge_chat_runs,\n",
    ")\n",
    "\n",
    "raw_messages = loader.lazy_load()\n",
    "# Merge consecutive messages from the same sender into a single message\n",
    "merged_messages = merge_chat_runs(raw_messages)\n",
    "# Convert messages from \"Tortoise\" to AI messages. Do you have a guess who these conversations are between?\n",
    "chat_sessions: List[ChatSession] = list(map_ai_messages(merged_messages, sender=\"Tortoise\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "370b8c26-c7a8-434c-a225-45c20ff14a03",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[AIMessage(content=\"Slow and steady, that's my motto.\", additional_kwargs={'message_time': 1693182723, 'sender': 'Tortoise'}, example=False),\n",
       " HumanMessage(content='Speed is key!', additional_kwargs={'message_time': 1693182753, 'sender': 'Hare'}, example=False),\n",
       " AIMessage(content='A balanced approach is more reliable.', additional_kwargs={'message_time': 1693182783, 'sender': 'Tortoise'}, example=False)]"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Now all of the Tortoise's messages will take the AI message class\n",
    "# which maps to the 'assistant' role in OpenAI's training format\n",
    "chat_sessions[0]['messages'][:3]"
   ]
  },
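  {
   "cell_type": "markdown",
   "id": "c2b84f60-1e37-4a95-b7d3-90f5a6e1c8d2",
   "metadata": {},
   "source": [
    "To make the idea of a message \"run\" concrete, here is a minimal pure-Python sketch that merges consecutive messages from the same sender. It works on plain dictionaries rather than LangChain message objects and is not the implementation of `merge_chat_runs`; the function name `merge_runs`, the example messages, and the choice to join texts with a blank line are assumptions made for illustration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e94a1d07-6c52-4b8f-a3e0-5f7c2d9b4a63",
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Dict, List\n",
    "\n",
    "\n",
    "def merge_runs(messages: List[Dict[str, str]]) -> List[Dict[str, str]]:\n",
    "    # Hypothetical helper: collapse each run of same-sender messages into one\n",
    "    merged: List[Dict[str, str]] = []\n",
    "    for msg in messages:\n",
    "        if merged and merged[-1]['sender'] == msg['sender']:\n",
    "            # Same sender as the previous message: extend that message\n",
    "            merged[-1]['content'] += '\\n\\n' + msg['content']\n",
    "        else:\n",
    "            # New sender: start a new message (copied so the input stays unchanged)\n",
    "            merged.append(dict(msg))\n",
    "    return merged\n",
    "\n",
    "\n",
    "example_messages = [\n",
    "    {'sender': 'Hare', 'content': 'Speed is key!'},\n",
    "    {'sender': 'Hare', 'content': 'Race you there.'},\n",
    "    {'sender': 'Tortoise', 'content': 'A balanced approach is more reliable.'},\n",
    "]\n",
    "for message in merge_runs(example_messages):\n",
    "    print(message)"
   ]
  },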
  {
   "cell_type": "markdown",
   "id": "05208f9d-3193-4a8d-86a5-13df2c8197e5",
   "metadata": {},
   "source": [
    "## 3. Prepare for fine-tuning\n",
    "\n",
    "Now it's time to convert our chat messages to OpenAI dictionaries. We can use the `convert_messages_for_finetuning` utility to do so."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "8834861f-f37f-4c08-96c6-917269bf09b8",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.adapters.openai import convert_messages_for_finetuning"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "ce7ab0f9-6e6a-4a1c-8b86-c635251d437e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Prepared 10 dialogues for training\n"
     ]
    }
   ],
   "source": [
    "training_data = convert_messages_for_finetuning(chat_sessions)\n",
    "print(f\"Prepared {len(training_data)} dialogues for training\")"
   ]
  },
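  {
   "cell_type": "markdown",
   "id": "3f6d0b95-8a21-4c7e-b4f8-1d5e9a7c2b30",
   "metadata": {},
   "source": [
    "As a rough illustration of the target shape, the sketch below turns a short dialogue into role/content dictionaries and prints it as one JSON line, the form later written to the training file. It is not the implementation of `convert_messages_for_finetuning`: the `to_openai_dicts` helper and the `(kind, text)` tuples are hypothetical stand-ins for LangChain messages, and the role table assumes that human messages map to `user` and AI messages to `assistant`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b08c5e71-4d93-4f2a-8e16-7a2b3c9d0e54",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "from typing import Dict, List, Tuple\n",
    "\n",
    "# Assumed mapping from message kind to chat role\n",
    "ROLE_BY_KIND = {'human': 'user', 'ai': 'assistant', 'system': 'system'}\n",
    "\n",
    "\n",
    "def to_openai_dicts(dialogue: List[Tuple[str, str]]) -> List[Dict[str, str]]:\n",
    "    # Hypothetical helper: one role/content dictionary per message\n",
    "    return [{'role': ROLE_BY_KIND[kind], 'content': text} for kind, text in dialogue]\n",
    "\n",
    "\n",
    "example_dialogue = [('ai', 'A balanced approach is more reliable.'), ('human', 'Speed is key!')]\n",
    "# Each training example becomes one line of the JSONL file\n",
    "print(json.dumps({'messages': to_openai_dicts(example_dialogue)}))"
   ]
  },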
  {
   "cell_type": "markdown",
   "id": "b494d64c-8056-42ae-b4c1-a9cfabc002ea",
   "metadata": {},
   "source": [
    "## 4. Fine-tune the model\n",
    "\n",
    "It's time to fine-tune the model. Make sure you have `openai` installed\n",
    "and have set your `OPENAI_API_KEY` appropriately."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "b4b60daa-b899-4291-a09a-412ce9c218fc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The cells below use the pre-1.0 openai SDK interface (openai.File, openai.FineTuningJob)\n",
    "# %pip install -U \"openai<1\" --quiet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "2cca6c95-c0d6-4826-b4fa-1c403f217f93",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "File file-zHIgf4r8LltZG3RFpkGd4Sjf ready after 10.19 seconds.\n"
     ]
    }
   ],
   "source": [
    "import json\n",
    "from io import BytesIO\n",
    "import time\n",
    "\n",
    "import openai\n",
    "\n",
    "# We will write the jsonl file in memory\n",
    "my_file = BytesIO()\n",
    "for m in training_data:\n",
    "    my_file.write((json.dumps({\"messages\": m}) + \"\\n\").encode('utf-8'))\n",
    "\n",
    "my_file.seek(0)\n",
    "training_file = openai.File.create(\n",
    "    file=my_file,\n",
    "    purpose='fine-tune'\n",
    ")\n",
    "\n",
    "# OpenAI audits each training file for compliance reasons.\n",
    "# This may take a few minutes\n",
    "status = openai.File.retrieve(training_file.id).status\n",
    "start_time = time.time()\n",
    "while status != \"processed\":\n",
    "    print(f\"Status=[{status}]... {time.time() - start_time:.2f}s\", end=\"\\r\", flush=True)\n",
    "    time.sleep(5)\n",
    "    status = openai.File.retrieve(training_file.id).status\n",
    "print(f\"File {training_file.id} ready after {time.time() - start_time:.2f} seconds.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "60ee0476-3113-4dc8-a886-bce878c60b07",
   "metadata": {},
   "source": [
    "With the file ready, it's time to kick off a training job."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "c376ddca-5b4f-4e5a-bf4e-6beeb467eacc",
   "metadata": {},
   "outputs": [],
   "source": [
    "job = openai.FineTuningJob.create(\n",
    "    training_file=training_file.id,\n",
    "    model=\"gpt-3.5-turbo\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "09344c60-0bee-4989-b8d1-4a8821553cc3",
   "metadata": {},
   "source": [
    "Grab a cup of tea while your model is being prepared. This may take some time!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "22eae900-04ca-456b-ba51-1dfff1f8e0e1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Status=[running]... 524.95s\r"
     ]
    }
   ],
   "source": [
    "status = openai.FineTuningJob.retrieve(job.id).status\n",
    "start_time = time.time()\n",
    "while status != \"succeeded\":\n",
    "    print(f\"Status=[{status}]... {time.time() - start_time:.2f}s\", end=\"\\r\", flush=True)\n",
    "    time.sleep(5)\n",
    "    job = openai.FineTuningJob.retrieve(job.id)\n",
    "    status = job.status"
   ]
  },
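  {
   "cell_type": "markdown",
   "id": "8e27f3a4-0b6c-4d19-a5e2-c41d7f90b386",
   "metadata": {},
   "source": [
    "The polling loops above run until the expected status appears, so they would keep spinning if a job ended in some other state. A more defensive variant might look like the sketch below. The `wait_for_status` helper, its parameters, and the failure status names passed to it are assumptions for illustration, and the demo uses a fake status source so that it runs offline. With the SDK calls used in this notebook, `fetch_status` could be `lambda: openai.FineTuningJob.retrieve(job.id).status`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "16d9b7c3-5a40-4e8b-92f1-e03c6a8d4b75",
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "from typing import Callable, Collection\n",
    "\n",
    "\n",
    "def wait_for_status(\n",
    "    fetch_status: Callable[[], str],\n",
    "    done: str,\n",
    "    failed: Collection[str] = (),\n",
    "    interval: float = 5.0,\n",
    "    timeout: float = 3600.0,\n",
    ") -> float:\n",
    "    # Hypothetical helper: poll until `done` is reported and return the seconds waited\n",
    "    start = time.monotonic()\n",
    "    while True:\n",
    "        status = fetch_status()\n",
    "        elapsed = time.monotonic() - start\n",
    "        if status == done:\n",
    "            return elapsed\n",
    "        if status in failed:\n",
    "            # Give up right away on a terminal failure state\n",
    "            raise RuntimeError(f'Stopped with status {status!r} after {elapsed:.2f}s')\n",
    "        if elapsed > timeout:\n",
    "            raise TimeoutError(f'Still {status!r} after {elapsed:.2f}s')\n",
    "        time.sleep(interval)\n",
    "\n",
    "\n",
    "# Fake status source standing in for a real API call (assumed status names)\n",
    "fake_statuses = iter(['queued', 'running', 'succeeded'])\n",
    "waited = wait_for_status(\n",
    "    lambda: next(fake_statuses),\n",
    "    done='succeeded',\n",
    "    failed=('failed', 'cancelled'),\n",
    "    interval=0.01,\n",
    ")\n",
    "print(f'Finished after {waited:.2f}s')"
   ]
  },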
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "39e72616-a7d9-44b8-a4eb-506611d119f4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ft:gpt-3.5-turbo-0613:personal::7sKoRdlz\n"
     ]
    }
   ],
   "source": [
    "print(job.fine_tuned_model)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0d717749-b1b6-451f-b3c5-3286b82d45b9",
   "metadata": {},
   "source": [
    "## 5. Use in LangChain\n",
    "\n",
    "You can use the resulting model ID directly in the `ChatOpenAI` model class."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "1579dfca-95c6-47b7-8549-1195b9dce5b0",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.chat_models import ChatOpenAI\n",
    "\n",
    "model = ChatOpenAI(\n",
    "    model=job.fine_tuned_model,\n",
    "    temperature=1,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "6f53d1b1-dcbf-4976-a61a-17f74c6f1b0a",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.prompts import ChatPromptTemplate\n",
    "from langchain.schema.output_parser import StrOutputParser\n",
    "\n",
    "prompt = ChatPromptTemplate.from_messages(\n",
    "    [\n",
    "        (\"system\", \"You are speaking to hare.\"),\n",
    "        (\"human\", \"{input}\"),\n",
    "    ]\n",
    ")\n",
    "\n",
    "chain = prompt | model | StrOutputParser()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "id": "6619c9bc-54ea-4136-bd9a-44557f7da724",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "A symbol of interconnectedness."
     ]
    }
   ],
   "source": [
    "for tok in chain.stream({\"input\": \"What's the golden thread?\"}):\n",
    "    print(tok, end=\"\", flush=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "88e0d1a1-48a9-4d9d-9f4e-010cbbb65af8",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}