{ "cells": [ { "cell_type": "markdown", "id": "e93283d1", "metadata": {}, "source": [ "# Selecting LLMs based on Context Length\n", "\n", "Different LLMs have different context lengths. As a very immediate and practical example, OpenAI has two versions of GPT-3.5-Turbo: one with 4k context, another with 16k context. This notebook shows how to route between them based on the number of tokens in the input." ] }, { "cell_type": "code", "execution_count": 24, "id": "cc453450", "metadata": {}, "outputs": [], "source": [ "from langchain.prompts import PromptTemplate\n", "from langchain_core.output_parsers import StrOutputParser\n", "from langchain_core.prompt_values import PromptValue\n", "from langchain_openai import ChatOpenAI" ] }, { "cell_type": "code", "execution_count": 3, "id": "1cec6a10", "metadata": {}, "outputs": [], "source": [ "short_context_model = ChatOpenAI(model=\"gpt-3.5-turbo\")\n", "long_context_model = ChatOpenAI(model=\"gpt-3.5-turbo-16k\")" ] }, { "cell_type": "code", "execution_count": 4, "id": "772da153", "metadata": {}, "outputs": [], "source": [ "def get_context_length(prompt: PromptValue):\n", " messages = prompt.to_messages()\n", " tokens = short_context_model.get_num_tokens_from_messages(messages)\n", " return tokens" ] }, { "cell_type": "code", "execution_count": 5, "id": "db771e20", "metadata": {}, "outputs": [], "source": [ "prompt = PromptTemplate.from_template(\"Summarize this passage: {context}\")" ] }, { "cell_type": "code", "execution_count": 20, "id": "af057e2f", "metadata": {}, "outputs": [], "source": [ "def choose_model(prompt: PromptValue):\n", " context_len = get_context_length(prompt)\n", " if context_len < 30:\n", " print(\"short model\")\n", " return short_context_model\n", " else:\n", " print(\"long model\")\n", " return long_context_model" ] }, { "cell_type": "code", "execution_count": 25, "id": "84f3e07d", "metadata": {}, "outputs": [], "source": [ "chain = prompt | choose_model | StrOutputParser()" ] }, { "cell_type": "code", "execution_count": 26, 
"id": "d8b14f8f", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "short model\n" ] }, { "data": { "text/plain": [ "'The passage mentions that a frog visited a pond.'" ] }, "execution_count": 26, "metadata": {}, "output_type": "execute_result" } ], "source": [ "chain.invoke({\"context\": \"a frog went to a pond\"})" ] }, { "cell_type": "code", "execution_count": 27, "id": "70ebd3dd", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "long model\n" ] }, { "data": { "text/plain": [ "'The passage describes a frog that moved from one pond to another and perched on a log.'" ] }, "execution_count": 27, "metadata": {}, "output_type": "execute_result" } ], "source": [ "chain.invoke(\n", " {\"context\": \"a frog went to a pond and sat on a log and went to a different pond\"}\n", ")" ] }, { "cell_type": "code", "execution_count": null, "id": "a7e29fef", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.1" } }, "nbformat": 4, "nbformat_minor": 5 }