forked from Archives/langchain
Merge branch 'master' into harrison/agent_multi_inputs
commit c7c38dd3df
@@ -0,0 +1,2 @@
[run]
omit = tests/*
@@ -1,10 +1,35 @@
-Prompts
-=======
+LLMs & Prompts
+==============
 
-The examples here all highlight how to work with prompts.
+The examples here all highlight how to work with LLMs and prompts.
+
+**LLMs**
+
+`LLM Functionality <prompts/llm_functionality.ipynb>`_: A walkthrough of all the functionality the standard LLM interface exposes.
+
+`LLM Serialization <prompts/llm_serialization.ipynb>`_: A walkthrough of how to serialize LLMs to and from disk.
+
+`Custom LLM <prompts/custom_llm.ipynb>`_: How to create and use a custom LLM class, in case you have an LLM not from one of the standard providers (including one that you host yourself).
+
+**Prompts**
+
+`Prompt Management <prompts/prompt_management.ipynb>`_: A walkthrough of all the functionality LangChain supports for working with prompts.
+
+`Prompt Serialization <prompts/prompt_serialization.ipynb>`_: A walkthrough of how to serialize prompts to and from disk.
+
+`Few Shot Examples <prompts/few_shot_examples.ipynb>`_: How to include examples in the prompt.
+
+`Generate Examples <prompts/generate_examples.ipynb>`_: How to use existing examples to generate more examples.
+
+`Custom Example Selector <prompts/custom_example_selector.ipynb>`_: How to create and use a custom ExampleSelector (the class responsible for choosing which examples to use in a prompt).
+
+`Custom Prompt Template <prompts/custom_prompt_template.ipynb>`_: How to create and use a custom PromptTemplate, the logic that decides how input variables get formatted into a prompt.
 
 .. toctree::
    :maxdepth: 1
    :glob:
+   :hidden:
 
    prompts/*
@@ -0,0 +1,11 @@
{
    "model_name": "text-davinci-003",
    "temperature": 0.7,
    "max_tokens": 256,
    "top_p": 1.0,
    "frequency_penalty": 0.0,
    "presence_penalty": 0.0,
    "n": 1,
    "best_of": 1,
    "_type": "openai"
}
@@ -0,0 +1,9 @@
_type: openai
best_of: 1
frequency_penalty: 0.0
max_tokens: 256
model_name: text-davinci-003
n: 1
presence_penalty: 0.0
temperature: 0.7
top_p: 1.0
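Either serialized config can be loaded back into an LLM wrapper with the loader added later in this commit; a minimal sketch, assuming an OpenAI API key is available in the environment so the instantiated wrapper can actually be called:

from langchain.llms.loading import load_llm

llm = load_llm("llm.json")   # reads the JSON config above
llm = load_llm("llm.yaml")   # the YAML form loads the same way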
@@ -0,0 +1,412 @@
# LLM Functionality

This notebook goes over all the different features of the LLM class in LangChain.

We will work with an OpenAI LLM wrapper, although these functionalities should exist for all LLM types.

In [1]: from langchain.llms import OpenAI

In [2]: llm = OpenAI(model_name="text-ada-001", n=2, best_of=2)

**Generate Text:** The most basic functionality an LLM has is just the ability to call it, passing in a string and getting back a string.

In [3]: llm("Tell me a joke")
Out[3]: '\n\nWhy did the chicken cross the road?\n\nTo get to the other side!'

**Generate:** More broadly, you can call it with a list of inputs, getting back a more complete response than just the text. This complete response includes things like multiple top responses, as well as LLM provider-specific information.

In [4]: llm_result = llm.generate(["Tell me a joke", "Tell me a poem"]*15)

In [5]: len(llm_result.generations)
Out[5]: 30

In [6]: llm_result.generations[0]
Out[6]: [Generation(text='\n\nWhy did the chicken cross the road?\n\nTo get to the other side.'),
         Generation(text='\n\nWhy did the chicken cross the road?\n\nTo get to the other side!')]

In [7]: llm_result.generations[-1]
Out[7]: [Generation(text="\n\nA rose by the side of the road\n\nIs all I need to find my way\n\nTo the place I've been searching for\n\nAnd my heart is singing with joy\n\nWhen I look at this rose\n\nIt reminds me of the love I've found\n\nAnd I know that wherever I go\n\nI'll always find my rose by the side of the road."),
         Generation(text="\n\nWhen I was younger\nI thought that love\nI was something like a fairytale\nI would find my prince and they would be my people\nI was naïve\nI thought that\n\nLove was a something that happened\nWhen I was younger\nI was it for my fairytale prince\nNow I realize\nThat love is something that waits\nFor when my prince comes\nAnd when I am ready to be his wife\nI'll tell you a poem\n\nWhen I was younger\nI thought that love\nI was something like a fairytale\nI would find my prince and they would be my people\nI was naïve\nI thought that\n\nLove was a something that happened\nAnd I would be happy\nWhen my prince came\nAnd I was ready to be his wife")]

In [8]: # Provider specific info
        llm_result.llm_output
Out[8]: {'token_usage': {'completion_tokens': 3722,
          'prompt_tokens': 120,
          'total_tokens': 3842}}

**Number of Tokens:** You can also estimate how many tokens a piece of text will use for that model. This is useful because models have a context length (and cost more for more tokens), which means you need to be aware of how long the text you are passing in is.

Notice that by default the tokens are estimated using a HuggingFace tokenizer.

In [9]: llm.get_num_tokens("what a joke")
Out[9]: 3

### Caching
With LangChain, you can also enable caching of LLM calls. Note that currently this only applies to individual LLM calls.

In [3]: import langchain
        from langchain.cache import InMemoryCache
        langchain.llm_cache = InMemoryCache()

In [4]: # To make the caching really obvious, let's use a slower model.
        llm = OpenAI(model_name="text-davinci-002", n=2, best_of=2)

In [5]: %%time
        # The first time, it is not yet in the cache, so it should take longer
        llm("Tell me a joke")
CPU times: user 31.2 ms, sys: 11.8 ms, total: 43.1 ms
Wall time: 1.75 s
Out[5]: '\n\nWhy did the chicken cross the road?\n\nTo get to the other side!'

In [6]: %%time
        # The second time it is, so it goes faster
        llm("Tell me a joke")
CPU times: user 51 µs, sys: 1 µs, total: 52 µs
Wall time: 67.2 µs
Out[6]: '\n\nWhy did the chicken cross the road?\n\nTo get to the other side!'

In [7]: # We can do the same thing with a SQLite cache
        from langchain.cache import SQLiteCache
        langchain.llm_cache = SQLiteCache(database_path=".langchain.db")

In [8]: %%time
        # The first time, it is not yet in the cache, so it should take longer
        llm("Tell me a joke")
CPU times: user 26.6 ms, sys: 11.2 ms, total: 37.7 ms
Wall time: 1.89 s
Out[8]: '\n\nWhy did the chicken cross the road?\n\nTo get to the other side.'

In [9]: %%time
        # The second time it is, so it goes faster
        llm("Tell me a joke")
CPU times: user 2.69 ms, sys: 1.57 ms, total: 4.27 ms
Wall time: 2.73 ms
Out[9]: '\n\nWhy did the chicken cross the road?\n\nTo get to the other side.'

In [ ]: # You can use SQLAlchemyCache to cache with any SQL database supported by SQLAlchemy.
        from langchain.cache import SQLAlchemyCache
        from sqlalchemy import create_engine

        engine = create_engine("postgresql://postgres:postgres@localhost:5432/postgres")
        langchain.llm_cache = SQLAlchemyCache(engine)
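Since `generate` returns one list of Generation objects per input prompt, downstream code often just wants the raw strings; a minimal helper sketch (the function name is hypothetical, not part of LangChain):

def flatten_generation_texts(llm_result):
    """Collect every generated string across all prompts of a generate() call."""
    return [gen.text for prompt_gens in llm_result.generations for gen in prompt_gens]

texts = flatten_generation_texts(llm_result)
len(texts)  # 60 for the 30-prompt, n=2 call above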
@@ -0,0 +1,166 @@
# LLM Serialization

This notebook walks through how to write and read an LLM configuration to and from disk. This is useful if you want to save the configuration for a given LLM (e.g. the provider, the temperature, etc).

In [1]: from langchain.llms import OpenAI
        from langchain.llms.loading import load_llm

### Loading
First, let's go over loading an LLM from disk. LLMs can be saved on disk in two formats: json or yaml. No matter the extension, they are loaded in the same way.

In [2]: !cat llm.json
{
    "model_name": "text-davinci-003",
    "temperature": 0.7,
    "max_tokens": 256,
    "top_p": 1,
    "frequency_penalty": 0,
    "presence_penalty": 0,
    "n": 1,
    "best_of": 1,
    "_type": "openai"
}

In [3]: llm = load_llm("llm.json")

In [4]: !cat llm.yaml
_type: openai
best_of: 1
frequency_penalty: 0
max_tokens: 256
model_name: text-davinci-003
n: 1
presence_penalty: 0
temperature: 0.7
top_p: 1

In [5]: llm = load_llm("llm.yaml")

### Saving
If you want to go from an LLM in memory to a serialized version of it, you can do so easily by calling the `.save` method. Again, this supports both json and yaml.

In [6]: llm.save("llm.json")

In [7]: llm.save("llm.yaml")
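A quick way to sanity-check the round trip described above is to compare configuration values before and after reloading; a minimal sketch (the file name and temperature are illustrative choices):

from langchain.llms import OpenAI
from langchain.llms.loading import load_llm

original = OpenAI(temperature=0.2)
original.save("llm_roundtrip.yaml")
restored = load_llm("llm_roundtrip.yaml")

# The reloaded wrapper should carry the same configuration values.
assert restored.temperature == original.temperature
assert restored.model_name == original.model_name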
@@ -0,0 +1,118 @@
"""Beta Feature: base interface for cache."""
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Tuple, Union

from sqlalchemy import Column, Integer, String, create_engine, select
from sqlalchemy.engine.base import Engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session

from langchain.schema import Generation

RETURN_VAL_TYPE = Union[List[Generation], str]


class BaseCache(ABC):
    """Base interface for cache."""

    @abstractmethod
    def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
        """Look up based on prompt and llm_string."""

    @abstractmethod
    def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
        """Update cache based on prompt and llm_string."""


class InMemoryCache(BaseCache):
    """Cache that stores things in memory."""

    def __init__(self) -> None:
        """Initialize with empty cache."""
        self._cache: Dict[Tuple[str, str], RETURN_VAL_TYPE] = {}

    def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
        """Look up based on prompt and llm_string."""
        return self._cache.get((prompt, llm_string), None)

    def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
        """Update cache based on prompt and llm_string."""
        self._cache[(prompt, llm_string)] = return_val


Base = declarative_base()


class LLMCache(Base):  # type: ignore
    """SQLite table for simple LLM cache (string only)."""

    __tablename__ = "llm_cache"
    prompt = Column(String, primary_key=True)
    llm = Column(String, primary_key=True)
    response = Column(String)


class FullLLMCache(Base):  # type: ignore
    """SQLite table for full LLM cache (all generations)."""

    __tablename__ = "full_llm_cache"
    prompt = Column(String, primary_key=True)
    llm = Column(String, primary_key=True)
    idx = Column(Integer, primary_key=True)
    response = Column(String)


class SQLAlchemyCache(BaseCache):
    """Cache that uses SQLAlchemy as a backend."""

    def __init__(self, engine: Engine):
        """Initialize by creating all tables."""
        self.engine = engine
        Base.metadata.create_all(self.engine)

    def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
        """Look up based on prompt and llm_string."""
        stmt = (
            select(FullLLMCache.response)
            .where(FullLLMCache.prompt == prompt)
            .where(FullLLMCache.llm == llm_string)
            .order_by(FullLLMCache.idx)
        )
        with Session(self.engine) as session:
            generations = []
            for row in session.execute(stmt):
                generations.append(Generation(text=row[0]))
            if len(generations) > 0:
                return generations
        stmt = (
            select(LLMCache.response)
            .where(LLMCache.prompt == prompt)
            .where(LLMCache.llm == llm_string)
        )
        with Session(self.engine) as session:
            for row in session.execute(stmt):
                return row[0]
        return None

    def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
        """Update cache based on prompt and llm_string."""
        if isinstance(return_val, str):
            item = LLMCache(prompt=prompt, llm=llm_string, response=return_val)
            with Session(self.engine) as session, session.begin():
                session.add(item)
        else:
            for i, generation in enumerate(return_val):
                item = FullLLMCache(
                    prompt=prompt, llm=llm_string, response=generation.text, idx=i
                )
                with Session(self.engine) as session, session.begin():
                    session.add(item)


class SQLiteCache(SQLAlchemyCache):
    """Cache that uses SQLite as a backend."""

    def __init__(self, database_path: str = ".langchain.db"):
        """Initialize by creating the engine and all tables."""
        engine = create_engine(f"sqlite:///{database_path}")
        super().__init__(engine)
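Any object that implements the lookup/update pair above can be assigned to langchain.llm_cache; a minimal sketch of a custom subclass (the size bound and eviction policy are illustrative assumptions, not part of this commit):

from typing import Dict, Optional, Tuple

from langchain.cache import RETURN_VAL_TYPE, BaseCache


class BoundedInMemoryCache(BaseCache):
    """In-memory cache that evicts the oldest entry once max_size is reached."""

    def __init__(self, max_size: int = 128) -> None:
        self._cache: Dict[Tuple[str, str], RETURN_VAL_TYPE] = {}
        self._max_size = max_size

    def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
        return self._cache.get((prompt, llm_string))

    def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
        if len(self._cache) >= self._max_size:
            # Dicts preserve insertion order, so the first key is the oldest entry.
            self._cache.pop(next(iter(self._cache)))
        self._cache[(prompt, llm_string)] = return_val


# Usage: langchain.llm_cache = BoundedInMemoryCache(max_size=256)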
@@ -1,7 +1,28 @@
 """Wrappers on top of large language models APIs."""
+from typing import Dict, Type
+
+from langchain.llms.ai21 import AI21
+from langchain.llms.base import LLM
 from langchain.llms.cohere import Cohere
 from langchain.llms.huggingface_hub import HuggingFaceHub
+from langchain.llms.huggingface_pipeline import HuggingFacePipeline
 from langchain.llms.nlpcloud import NLPCloud
 from langchain.llms.openai import OpenAI
 
-__all__ = ["Cohere", "NLPCloud", "OpenAI", "HuggingFaceHub"]
+__all__ = [
+    "Cohere",
+    "NLPCloud",
+    "OpenAI",
+    "HuggingFaceHub",
+    "HuggingFacePipeline",
+    "AI21",
+]
+
+type_to_cls_dict: Dict[str, Type[LLM]] = {
+    "ai21": AI21,
+    "cohere": Cohere,
+    "huggingface_hub": HuggingFaceHub,
+    "nlpcloud": NLPCloud,
+    "openai": OpenAI,
+    "huggingface_pipeline": HuggingFacePipeline,
+}
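The type_to_cls_dict registry is what maps a serialized "_type" string back to a concrete wrapper class; a minimal sketch of using it directly (the parameter value is illustrative):

from langchain.llms import type_to_cls_dict

# "_type" values in saved configs index into this registry.
llm_cls = type_to_cls_dict["openai"]
llm = llm_cls(temperature=0.0)  # equivalent to OpenAI(temperature=0.0)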
@@ -0,0 +1,118 @@
"""Wrapper around HuggingFace Pipeline APIs."""
from typing import Any, List, Mapping, Optional

from pydantic import BaseModel, Extra

from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens

DEFAULT_MODEL_ID = "gpt2"
DEFAULT_TASK = "text-generation"
VALID_TASKS = ("text2text-generation", "text-generation")


class HuggingFacePipeline(LLM, BaseModel):
    """Wrapper around HuggingFace Pipeline API.

    To use, you should have the ``transformers`` python package installed.

    Only supports `text-generation` and `text2text-generation` for now.

    Example using from_model_id:
        .. code-block:: python

            from langchain.llms.huggingface_pipeline import HuggingFacePipeline
            hf = HuggingFacePipeline.from_model_id(
                model_id="gpt2", task="text-generation"
            )
    Example passing pipeline in directly:
        .. code-block:: python

            from langchain.llms.huggingface_pipeline import HuggingFacePipeline
            from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

            model_id = "gpt2"
            tokenizer = AutoTokenizer.from_pretrained(model_id)
            model = AutoModelForCausalLM.from_pretrained(model_id)
            pipe = pipeline(
                "text-generation", model=model, tokenizer=tokenizer, max_new_tokens=10
            )
            hf = HuggingFacePipeline(pipeline=pipe)
    """

    pipeline: Any  #: :meta private:
    model_id: str = DEFAULT_MODEL_ID
    """Model name to use."""
    model_kwargs: Optional[dict] = None
    """Key word arguments to pass to the model."""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @classmethod
    def from_model_id(
        cls,
        model_id: str,
        task: str,
        model_kwargs: Optional[dict] = None,
        **kwargs: Any,
    ) -> LLM:
        """Construct the pipeline object from model_id and task."""
        try:
            from transformers import AutoModelForCausalLM, AutoTokenizer
            from transformers import pipeline as hf_pipeline

            tokenizer = AutoTokenizer.from_pretrained(model_id)
            model = AutoModelForCausalLM.from_pretrained(model_id)
            # model_kwargs may be None, so default to an empty dict before unpacking.
            pipeline = hf_pipeline(
                task=task, model=model, tokenizer=tokenizer, **(model_kwargs or {})
            )
            if pipeline.task not in VALID_TASKS:
                raise ValueError(
                    f"Got invalid task {pipeline.task}, "
                    f"currently only {VALID_TASKS} are supported"
                )
            return cls(
                pipeline=pipeline,
                model_id=model_id,
                model_kwargs=model_kwargs,
                **kwargs,
            )
        except ImportError:
            raise ValueError(
                "Could not import transformers python package. "
                "Please install it with `pip install transformers`."
            )

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {
            **{"model_id": self.model_id},
            **{"model_kwargs": self.model_kwargs},
        }

    @property
    def _llm_type(self) -> str:
        return "huggingface_pipeline"

    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        response = self.pipeline(text_inputs=prompt)
        if self.pipeline.task == "text-generation":
            # Text generation return includes the starter text.
            text = response[0]["generated_text"][len(prompt) :]
        elif self.pipeline.task == "text2text-generation":
            text = response[0]["generated_text"]
        else:
            raise ValueError(
                f"Got invalid task {self.pipeline.task}, "
                f"currently only {VALID_TASKS} are supported"
            )
        if stop is not None:
            # This is a bit hacky, but I can't figure out a better way to enforce
            # stop tokens when making calls to huggingface_hub.
            text = enforce_stop_tokens(text, stop)
        return text
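Because _call also dispatches on the text2text-generation task, a seq2seq model can be wired in by building the transformers pipeline yourself and passing it to the constructor; a minimal sketch (the flan-t5-small checkpoint is an illustrative choice and downloading it needs network access):

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline

from langchain.llms.huggingface_pipeline import HuggingFacePipeline

model_id = "google/flan-t5-small"  # illustrative checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
pipe = pipeline(
    "text2text-generation", model=model, tokenizer=tokenizer, max_new_tokens=10
)

llm = HuggingFacePipeline(pipeline=pipe)
print(llm("Translate to German: hello"))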
@@ -0,0 +1,42 @@
"""Base interface for loading large language models apis."""
import json
from pathlib import Path
from typing import Union

import yaml

from langchain.llms import type_to_cls_dict
from langchain.llms.base import LLM


def load_llm_from_config(config: dict) -> LLM:
    """Load LLM from Config Dict."""
    if "_type" not in config:
        raise ValueError("Must specify an LLM Type in config")
    config_type = config.pop("_type")

    if config_type not in type_to_cls_dict:
        raise ValueError(f"Loading {config_type} LLM not supported")

    llm_cls = type_to_cls_dict[config_type]
    return llm_cls(**config)


def load_llm(file: Union[str, Path]) -> LLM:
    """Load LLM from file."""
    # Convert file to Path object.
    if isinstance(file, str):
        file_path = Path(file)
    else:
        file_path = file
    # Load from either json or yaml.
    if file_path.suffix == ".json":
        with open(file_path) as f:
            config = json.load(f)
    elif file_path.suffix == ".yaml":
        with open(file_path, "r") as f:
            config = yaml.safe_load(f)
    else:
        raise ValueError("File type must be json or yaml")
    # Load the LLM from the config now.
    return load_llm_from_config(config)
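load_llm_from_config can also be called directly when the configuration comes from somewhere other than a file; a minimal sketch (the parameter values are illustrative):

from langchain.llms.loading import load_llm_from_config

config = {"_type": "openai", "model_name": "text-davinci-003", "temperature": 0.7}
# Pops "_type" from the dict and instantiates the registered class, here OpenAI(**config).
llm = load_llm_from_config(config)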
File diff suppressed because it is too large
@@ -0,0 +1,41 @@
"""Test HuggingFace Pipeline wrapper."""

from pathlib import Path

from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

from langchain.llms.huggingface_pipeline import HuggingFacePipeline
from langchain.llms.loading import load_llm
from tests.integration_tests.llms.utils import assert_llm_equality


def test_huggingface_pipeline_text_generation() -> None:
    """Test valid call to HuggingFace text generation model."""
    llm = HuggingFacePipeline.from_model_id(
        model_id="gpt2", task="text-generation", model_kwargs={"max_new_tokens": 10}
    )
    output = llm("Say foo:")
    assert isinstance(output, str)


def test_saving_loading_llm(tmp_path: Path) -> None:
    """Test saving/loading a HuggingFacePipeline LLM."""
    llm = HuggingFacePipeline.from_model_id(
        model_id="gpt2", task="text-generation", model_kwargs={"max_new_tokens": 10}
    )
    llm.save(file_path=tmp_path / "hf.yaml")
    loaded_llm = load_llm(tmp_path / "hf.yaml")
    assert_llm_equality(llm, loaded_llm)


def test_init_with_pipeline() -> None:
    """Test initialization with a HF pipeline."""
    model_id = "gpt2"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id)
    pipe = pipeline(
        "text-generation", model=model, tokenizer=tokenizer, max_new_tokens=10
    )
    llm = HuggingFacePipeline(pipeline=pipe)
    output = llm("Say foo:")
    assert isinstance(output, str)
@@ -0,0 +1,16 @@
"""Utils for LLM Tests."""

from langchain.llms.base import LLM


def assert_llm_equality(llm: LLM, loaded_llm: LLM) -> None:
    """Assert LLM Equality for tests."""
    # Check that they are the same type.
    assert type(llm) == type(loaded_llm)
    # The client field can be session based, so its hash differs even when all
    # other values are the same; just compare all other fields.
    for field in llm.__fields__.keys():
        if field != "client" and field != "pipeline":
            val = getattr(llm, field)
            new_val = getattr(loaded_llm, field)
            assert new_val == val
@@ -0,0 +1,118 @@
"""Test functionality related to combining documents."""

from typing import List

import pytest

from langchain.chains.combine_documents.map_reduce import (
    _collapse_docs,
    _split_list_of_docs,
)
from langchain.docstore.document import Document


def _fake_docs_len_func(docs: List[Document]) -> int:
    return len(_fake_combine_docs_func(docs))


def _fake_combine_docs_func(docs: List[Document]) -> str:
    return "".join([d.page_content for d in docs])


def test__split_list_long_single_doc() -> None:
    """Test splitting of a long single doc."""
    docs = [Document(page_content="foo" * 100)]
    with pytest.raises(ValueError):
        _split_list_of_docs(docs, _fake_docs_len_func, 100)


def test__split_list_long_pair_doc() -> None:
    """Test splitting of a list with two medium docs."""
    docs = [Document(page_content="foo" * 30)] * 2
    with pytest.raises(ValueError):
        _split_list_of_docs(docs, _fake_docs_len_func, 100)


def test__split_list_single_doc() -> None:
    """Test splitting works with just a single doc."""
    docs = [Document(page_content="foo")]
    doc_list = _split_list_of_docs(docs, _fake_docs_len_func, 100)
    assert doc_list == [docs]


def test__split_list_double_doc() -> None:
    """Test splitting works with just two docs."""
    docs = [Document(page_content="foo"), Document(page_content="bar")]
    doc_list = _split_list_of_docs(docs, _fake_docs_len_func, 100)
    assert doc_list == [docs]


def test__split_list_works_correctly() -> None:
    """Test splitting works correctly."""
    docs = [
        Document(page_content="foo"),
        Document(page_content="bar"),
        Document(page_content="baz"),
        Document(page_content="foo" * 2),
        Document(page_content="bar"),
        Document(page_content="baz"),
    ]
    doc_list = _split_list_of_docs(docs, _fake_docs_len_func, 10)
    expected_result = [
        # Test a group of three.
        [
            Document(page_content="foo"),
            Document(page_content="bar"),
            Document(page_content="baz"),
        ],
        # Test a group of two, where one is bigger.
        [Document(page_content="foo" * 2), Document(page_content="bar")],
        # Test no errors on last
        [Document(page_content="baz")],
    ]
    assert doc_list == expected_result


def test__collapse_docs_no_metadata() -> None:
    """Test collapse documents functionality when no metadata."""
    docs = [
        Document(page_content="foo"),
        Document(page_content="bar"),
        Document(page_content="baz"),
    ]
    output = _collapse_docs(docs, _fake_combine_docs_func)
    expected_output = Document(page_content="foobarbaz")
    assert output == expected_output


def test__collapse_docs_one_doc() -> None:
    """Test collapse documents functionality when only one document present."""
    # Test with no metadata.
    docs = [Document(page_content="foo")]
    output = _collapse_docs(docs, _fake_combine_docs_func)
    assert output == docs[0]

    # Test with metadata.
    docs = [Document(page_content="foo", metadata={"source": "a"})]
    output = _collapse_docs(docs, _fake_combine_docs_func)
    assert output == docs[0]


def test__collapse_docs_metadata() -> None:
    """Test collapse documents functionality when metadata exists."""
    metadata1 = {"source": "a", "foo": 2, "bar": "1", "extra1": "foo"}
    metadata2 = {"source": "b", "foo": "3", "bar": 2, "extra2": "bar"}
    docs = [
        Document(page_content="foo", metadata=metadata1),
        Document(page_content="bar", metadata=metadata2),
    ]
    output = _collapse_docs(docs, _fake_combine_docs_func)
    expected_metadata = {
        "source": "a, b",
        "foo": "2, 3",
        "bar": "1, 2",
        "extra1": "foo",
        "extra2": "bar",
    }
    expected_output = Document(page_content="foobar", metadata=expected_metadata)
    assert output == expected_output
@@ -0,0 +1,15 @@
"""Test LLM saving and loading functions."""
from pathlib import Path
from unittest.mock import patch

from langchain.llms.loading import load_llm
from tests.unit_tests.llms.fake_llm import FakeLLM


@patch("langchain.llms.loading.type_to_cls_dict", {"fake": FakeLLM})
def test_saving_loading_round_trip(tmp_path: Path) -> None:
    """Test saving/loading a Fake LLM."""
    fake_llm = FakeLLM()
    fake_llm.save(file_path=tmp_path / "fake_llm.yaml")
    loaded_llm = load_llm(tmp_path / "fake_llm.yaml")
    assert loaded_llm == fake_llm