Mirror of https://github.com/hwchase17/langchain (synced 2024-11-13 19:10:52 +00:00)

Commit f17bf1f3bf: Merge branch 'master' into mitchell/fix-sqarql-bug
@@ -77,7 +77,7 @@
    "id": "08f8b820-5912-49c1-9d76-40be0571dffb",
    "metadata": {},
    "source": [
-    "This creates a retriever (specifically a [VectorStoreRetriever](https://python.langchain.com/api_reference/core/vectorstores/langchain_core.vectorstores.VectorStoreRetriever.html)), which we can use in the usual way:"
+    "This creates a retriever (specifically a [VectorStoreRetriever](https://python.langchain.com/api_reference/core/vectorstores/langchain_core.vectorstores.base.VectorStoreRetriever.html)), which we can use in the usual way:"
    ]
   },
   {
@@ -434,6 +434,160 @@
    "fine_tuned_model.invoke(messages)"
   ]
  },
+ {
+  "cell_type": "markdown",
+  "id": "5d5d9793",
+  "metadata": {},
+  "source": [
+   "## Multimodal Inputs\n",
+   "\n",
+   "OpenAI has models that support multimodal inputs. You can pass in images or audio to these models. For more information on how to do this in LangChain, head to the [multimodal inputs](/docs/how_to/multimodal_inputs) docs.\n",
+   "\n",
+   "You can see the list of models that support different modalities in [OpenAI's documentation](https://platform.openai.com/docs/models).\n",
+   "\n",
+   "At the time of writing, the main OpenAI models you would use are:\n",
+   "\n",
+   "- Image inputs: `gpt-4o`, `gpt-4o-mini`\n",
+   "- Audio inputs: `gpt-4o-audio-preview`\n",
+   "\n",
+   "For an example of passing in image inputs, see the [multimodal inputs how-to guide](/docs/how_to/multimodal_inputs).\n",
+   "\n",
+   "Below is an example of passing audio inputs to `gpt-4o-audio-preview`:"
+  ]
+ },
+ {
+  "cell_type": "code",
+  "execution_count": 8,
+  "id": "39d08780",
+  "metadata": {},
+  "outputs": [
+   {
+    "data": {
+     "text/plain": [
+      "\"I'm sorry, but I can't create audio content that involves yelling. Is there anything else I can help you with?\""
+     ]
+    },
+    "execution_count": 8,
+    "metadata": {},
+    "output_type": "execute_result"
+   }
+  ],
+  "source": [
+   "import base64\n",
+   "\n",
+   "from langchain_openai import ChatOpenAI\n",
+   "\n",
+   "llm = ChatOpenAI(\n",
+   "    model=\"gpt-4o-audio-preview\",\n",
+   "    temperature=0,\n",
+   ")\n",
+   "\n",
+   "with open(\n",
+   "    \"../../../../libs/partners/openai/tests/integration_tests/chat_models/audio_input.wav\",\n",
+   "    \"rb\",\n",
+   ") as f:\n",
+   "    # b64 encode it\n",
+   "    audio = f.read()\n",
+   "    audio_b64 = base64.b64encode(audio).decode()\n",
+   "\n",
+   "\n",
+   "output_message = llm.invoke(\n",
+   "    [\n",
+   "        (\n",
+   "            \"human\",\n",
+   "            [\n",
+   "                {\"type\": \"text\", \"text\": \"Transcribe the following:\"},\n",
+   "                # the audio clip says \"I'm sorry, but I can't create...\"\n",
+   "                {\n",
+   "                    \"type\": \"input_audio\",\n",
+   "                    \"input_audio\": {\"data\": audio_b64, \"format\": \"wav\"},\n",
+   "                },\n",
+   "            ],\n",
+   "        ),\n",
+   "    ]\n",
+   ")\n",
+   "output_message.content"
+  ]
+ },
+ {
+  "cell_type": "markdown",
+  "id": "feb4a499",
+  "metadata": {},
+  "source": [
+   "## Audio Generation (Preview)\n",
+   "\n",
+   ":::info\n",
+   "Requires `langchain-openai>=0.2.3`\n",
+   ":::\n",
+   "\n",
+   "OpenAI has a new [audio generation feature](https://platform.openai.com/docs/guides/audio?audio-generation-quickstart-example=audio-out) that allows you to use audio inputs and outputs with the `gpt-4o-audio-preview` model."
+  ]
+ },
+ {
+  "cell_type": "code",
+  "execution_count": 1,
+  "id": "f67a2cac",
+  "metadata": {},
+  "outputs": [],
+  "source": [
+   "from langchain_openai import ChatOpenAI\n",
+   "\n",
+   "llm = ChatOpenAI(\n",
+   "    model=\"gpt-4o-audio-preview\",\n",
+   "    temperature=0,\n",
+   "    model_kwargs={\n",
+   "        \"modalities\": [\"text\", \"audio\"],\n",
+   "        \"audio\": {\"voice\": \"alloy\", \"format\": \"wav\"},\n",
+   "    },\n",
+   ")\n",
+   "\n",
+   "output_message = llm.invoke(\n",
+   "    [\n",
+   "        (\"human\", \"Are you made by OpenAI? Just answer yes or no\"),\n",
+   "    ]\n",
+   ")"
+  ]
+ },
+ {
+  "cell_type": "markdown",
+  "id": "b7dd4e8b",
+  "metadata": {},
+  "source": [
+   "`output_message.additional_kwargs['audio']` will contain a dictionary like\n",
+   "```python\n",
+   "{\n",
+   "    'data': '<audio data b64-encoded>',\n",
+   "    'expires_at': 1729268602,\n",
+   "    'id': 'audio_67127d6a44348190af62c1530ef0955a',\n",
+   "    'transcript': 'Yes.'\n",
+   "}\n",
+   "```\n",
+   "and the format will be what was passed in `model_kwargs['audio']['format']`.\n",
+   "\n",
+   "We can also pass this message with audio data back to the model as part of a message history before OpenAI's `expires_at` is reached.\n",
+   "\n",
+   ":::note\n",
+   "Output audio is stored under the `audio` key in `AIMessage.additional_kwargs`, but input content blocks are typed with an `input_audio` type and key in `HumanMessage.content` lists.\n",
+   "\n",
+   "For more information, see OpenAI's [audio docs](https://platform.openai.com/docs/guides/audio).\n",
+   ":::"
+  ]
+ },
+ {
+  "cell_type": "code",
+  "execution_count": 8,
+  "id": "f5ae473d",
+  "metadata": {},
+  "outputs": [],
+  "source": [
+   "history = [\n",
+   "    (\"human\", \"Are you made by OpenAI? Just answer yes or no\"),\n",
+   "    output_message,\n",
+   "    (\"human\", \"And what is your name? Just give your name.\"),\n",
+   "]\n",
+   "second_output_message = llm.invoke(history)"
+  ]
+ },
  {
   "cell_type": "markdown",
   "id": "a796d728-971b-408b-88d5-440015bbb941",
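One practical follow-up the notebook doesn't show: the generated audio arrives base64-encoded under `additional_kwargs['audio']['data']`, so playing it back means decoding it first. A minimal sketch (the `output.wav` filename is an arbitrary choice, not from the notebook):

```python
import base64

# output_message is the AIMessage returned by llm.invoke(...) above;
# its additional_kwargs["audio"]["data"] field is a base64-encoded WAV payload.
audio_bytes = base64.b64decode(output_message.additional_kwargs["audio"]["data"])

with open("output.wav", "wb") as f:
    f.write(audio_bytes)
```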
@@ -447,7 +601,7 @@
  ],
  "metadata": {
   "kernelspec": {
-   "display_name": "Python 3 (ipykernel)",
+   "display_name": ".venv",
    "language": "python",
    "name": "python3"
   },
@@ -85,6 +85,10 @@
     {
       "source": "/docs/integrations/platforms/:path((?:anthropic|aws|google|huggingface|microsoft|openai)?/?)*",
       "destination": "/docs/integrations/providers/:path*"
+    },
+    {
+      "source": "/docs/troubleshooting/errors/:path((?:GRAPH_RECURSION_LIMIT|INVALID_CONCURRENT_GRAPH_UPDATE|INVALID_GRAPH_NODE_RETURN_VALUE|MULTIPLE_SUBGRAPHS)/?)*",
+      "destination": "https://langchain-ai.github.io/langgraph/troubleshooting/errors/:path*"
     }
   ]
 }
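The added rewrite sends the four LangGraph-specific error pages to the LangGraph docs site. As a rough sketch of what the `:path(...)` pattern matches (assuming path-to-regexp-style semantics as used by Vercel rewrites; this Python regex is an approximation, not the platform's actual matcher):

```python
import re
from typing import Optional

# Approximate regex equivalent of the new rule's :path(...) pattern.
ERROR_PAGES = re.compile(
    r"^/docs/troubleshooting/errors/"
    r"(GRAPH_RECURSION_LIMIT|INVALID_CONCURRENT_GRAPH_UPDATE"
    r"|INVALID_GRAPH_NODE_RETURN_VALUE|MULTIPLE_SUBGRAPHS)/?$"
)


def destination(path: str) -> Optional[str]:
    """Map a matching error page to its LangGraph docs URL, else None."""
    m = ERROR_PAGES.match(path)
    if m is None:
        return None
    return (
        "https://langchain-ai.github.io/langgraph/troubleshooting/errors/"
        + m.group(1)
    )


assert destination("/docs/troubleshooting/errors/GRAPH_RECURSION_LIMIT") == (
    "https://langchain-ai.github.io/langgraph/troubleshooting/errors/"
    "GRAPH_RECURSION_LIMIT"
)
```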
@@ -173,21 +173,23 @@ class GradientLLM(BaseLLM):
                 "content-type": "application/json",
             },
             json=dict(
-                samples=tuple(
-                    {
-                        "inputs": input,
-                    }
-                    for input in inputs
-                )
-                if multipliers is None
-                else tuple(
-                    {
-                        "inputs": input,
-                        "fineTuningParameters": {
-                            "multiplier": multiplier,
-                        },
-                    }
-                    for input, multiplier in zip(inputs, multipliers)
+                samples=(
+                    tuple(
+                        {
+                            "inputs": input,
+                        }
+                        for input in inputs
+                    )
+                    if multipliers is None
+                    else tuple(
+                        {
+                            "inputs": input,
+                            "fineTuningParameters": {
+                                "multiplier": multiplier,
+                            },
+                        }
+                        for input, multiplier in zip(inputs, multipliers)
+                    )
                 ),
             ),
         )
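The change above is purely parenthesization: wrapping the `tuple(...) if ... else tuple(...)` conditional expression makes it unambiguous that the whole expression is the value of the `samples` keyword. A standalone sketch of the same pattern (variable names are illustrative, not the library's):

```python
inputs = ["a", "b"]
multipliers = None  # or, e.g., [0.5, 2.0]

# The outer parentheses group the conditional expression as a single value.
samples = (
    tuple({"inputs": i} for i in inputs)
    if multipliers is None
    else tuple(
        {"inputs": i, "fineTuningParameters": {"multiplier": m}}
        for i, m in zip(inputs, multipliers)
    )
)
print(samples)  # ({'inputs': 'a'}, {'inputs': 'b'})
```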
@@ -338,9 +340,11 @@ class GradientLLM(BaseLLM):
     ) -> LLMResult:
         """Run the LLM on the given prompt and input."""
         generations = []
-        for generation in asyncio.gather(
-            [self._acall(prompt, stop=stop, run_manager=run_manager, **kwargs)]
-            for prompt in prompts
+        for generation in await asyncio.gather(
+            *[
+                self._acall(prompt, stop=stop, run_manager=run_manager, **kwargs)
+                for prompt in prompts
+            ]
         ):
             generations.append([Generation(text=generation)])
         return LLMResult(generations=generations)
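This hunk fixes two real bugs: `asyncio.gather` was never awaited, and it was handed a single generator (of one-element lists) instead of the coroutines being unpacked with `*`. A minimal sketch of the corrected fan-out pattern, with a stand-in `_acall` since the real one needs a model client:

```python
import asyncio


async def _acall(prompt: str) -> str:
    # Stand-in for the real model call; just echoes the prompt.
    await asyncio.sleep(0)
    return f"completion for {prompt!r}"


async def generate(prompts: list[str]) -> list[str]:
    # gather(*coros) schedules all calls concurrently; its result must be
    # awaited before iteration, which is exactly what the fix adds.
    return await asyncio.gather(*[_acall(p) for p in prompts])


print(asyncio.run(generate(["a", "b"])))
```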
@@ -64,7 +64,7 @@ class GoogleTrendsAPIWrapper(BaseModel):
             "q": query,
         }
 
-        total_results = []
+        total_results: Any = []
         client = self.serp_search_engine(params)
         client_dict = client.get_dict()
         total_results = (
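The `Any` annotation exists because `total_results` starts as an empty list and is then reassigned from the SerpAPI response dict, whose value mypy cannot prove is a list; without the annotation mypy pins the variable's type to `list` and rejects the reassignment. A toy reproduction of that mypy behavior (values are hypothetical):

```python
from typing import Any

results: Any = []  # without ": Any", mypy infers a list type here
results = {"interest_over_time": {"timeline_data": []}}  # OK only under Any
```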
libs/community/poetry.lock (generated, 71 lines changed)
@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
 
 [[package]]
 name = "aiohappyeyeballs"
@@ -1801,7 +1801,7 @@ files = [
 
 [[package]]
 name = "langchain"
-version = "0.3.3"
+version = "0.3.4"
 description = "Building applications with LLMs through composability"
 optional = false
 python-versions = ">=3.9,<4.0"
@@ -1811,7 +1811,7 @@ develop = true
 [package.dependencies]
 aiohttp = "^3.8.3"
 async-timeout = {version = "^4.0.0", markers = "python_version < \"3.11\""}
-langchain-core = "^0.3.10"
+langchain-core = "^0.3.12"
 langchain-text-splitters = "^0.3.0"
 langsmith = "^0.1.17"
 numpy = [
@@ -1830,7 +1830,7 @@ url = "../langchain"
 
 [[package]]
 name = "langchain-core"
-version = "0.3.10"
+version = "0.3.12"
 description = "Building applications with LLMs through composability"
 optional = false
 python-versions = ">=3.9,<4.0"
@@ -2146,38 +2146,43 @@ typing-extensions = {version = ">=4.1.0", markers = "python_version < \"3.11\""}
 
 [[package]]
 name = "mypy"
-version = "1.11.2"
+version = "1.12.0"
 description = "Optional static typing for Python"
 optional = false
 python-versions = ">=3.8"
 files = [
-    {file = "mypy-1.11.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d42a6dd818ffce7be66cce644f1dff482f1d97c53ca70908dff0b9ddc120b77a"},
-    {file = "mypy-1.11.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:801780c56d1cdb896eacd5619a83e427ce436d86a3bdf9112527f24a66618fef"},
-    {file = "mypy-1.11.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:41ea707d036a5307ac674ea172875f40c9d55c5394f888b168033177fce47383"},
-    {file = "mypy-1.11.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:6e658bd2d20565ea86da7d91331b0eed6d2eee22dc031579e6297f3e12c758c8"},
-    {file = "mypy-1.11.2-cp310-cp310-win_amd64.whl", hash = "sha256:478db5f5036817fe45adb7332d927daa62417159d49783041338921dcf646fc7"},
-    {file = "mypy-1.11.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:75746e06d5fa1e91bfd5432448d00d34593b52e7e91a187d981d08d1f33d4385"},
-    {file = "mypy-1.11.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:a976775ab2256aadc6add633d44f100a2517d2388906ec4f13231fafbb0eccca"},
-    {file = "mypy-1.11.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:cd953f221ac1379050a8a646585a29574488974f79d8082cedef62744f0a0104"},
-    {file = "mypy-1.11.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:57555a7715c0a34421013144a33d280e73c08df70f3a18a552938587ce9274f4"},
-    {file = "mypy-1.11.2-cp311-cp311-win_amd64.whl", hash = "sha256:36383a4fcbad95f2657642a07ba22ff797de26277158f1cc7bd234821468b1b6"},
-    {file = "mypy-1.11.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:e8960dbbbf36906c5c0b7f4fbf2f0c7ffb20f4898e6a879fcf56a41a08b0d318"},
-    {file = "mypy-1.11.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:06d26c277962f3fb50e13044674aa10553981ae514288cb7d0a738f495550b36"},
-    {file = "mypy-1.11.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:6e7184632d89d677973a14d00ae4d03214c8bc301ceefcdaf5c474866814c987"},
-    {file = "mypy-1.11.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:3a66169b92452f72117e2da3a576087025449018afc2d8e9bfe5ffab865709ca"},
-    {file = "mypy-1.11.2-cp312-cp312-win_amd64.whl", hash = "sha256:969ea3ef09617aff826885a22ece0ddef69d95852cdad2f60c8bb06bf1f71f70"},
-    {file = "mypy-1.11.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:37c7fa6121c1cdfcaac97ce3d3b5588e847aa79b580c1e922bb5d5d2902df19b"},
-    {file = "mypy-1.11.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:4a8a53bc3ffbd161b5b2a4fff2f0f1e23a33b0168f1c0778ec70e1a3d66deb86"},
-    {file = "mypy-1.11.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:2ff93107f01968ed834f4256bc1fc4475e2fecf6c661260066a985b52741ddce"},
-    {file = "mypy-1.11.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:edb91dded4df17eae4537668b23f0ff6baf3707683734b6a818d5b9d0c0c31a1"},
-    {file = "mypy-1.11.2-cp38-cp38-win_amd64.whl", hash = "sha256:ee23de8530d99b6db0573c4ef4bd8f39a2a6f9b60655bf7a1357e585a3486f2b"},
-    {file = "mypy-1.11.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:801ca29f43d5acce85f8e999b1e431fb479cb02d0e11deb7d2abb56bdaf24fd6"},
-    {file = "mypy-1.11.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:af8d155170fcf87a2afb55b35dc1a0ac21df4431e7d96717621962e4b9192e70"},
-    {file = "mypy-1.11.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:f7821776e5c4286b6a13138cc935e2e9b6fde05e081bdebf5cdb2bb97c9df81d"},
-    {file = "mypy-1.11.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:539c570477a96a4e6fb718b8d5c3e0c0eba1f485df13f86d2970c91f0673148d"},
-    {file = "mypy-1.11.2-cp39-cp39-win_amd64.whl", hash = "sha256:3f14cd3d386ac4d05c5a39a51b84387403dadbd936e17cb35882134d4f8f0d24"},
-    {file = "mypy-1.11.2-py3-none-any.whl", hash = "sha256:b499bc07dbdcd3de92b0a8b29fdf592c111276f6a12fe29c30f6c417dd546d12"},
-    {file = "mypy-1.11.2.tar.gz", hash = "sha256:7f9993ad3e0ffdc95c2a14b66dee63729f021968bff8ad911867579c65d13a79"},
+    {file = "mypy-1.12.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:4397081e620dc4dc18e2f124d5e1d2c288194c2c08df6bdb1db31c38cd1fe1ed"},
+    {file = "mypy-1.12.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:684a9c508a283f324804fea3f0effeb7858eb03f85c4402a967d187f64562469"},
+    {file = "mypy-1.12.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:6cabe4cda2fa5eca7ac94854c6c37039324baaa428ecbf4de4567279e9810f9e"},
+    {file = "mypy-1.12.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:060a07b10e999ac9e7fa249ce2bdcfa9183ca2b70756f3bce9df7a92f78a3c0a"},
+    {file = "mypy-1.12.0-cp310-cp310-win_amd64.whl", hash = "sha256:0eff042d7257f39ba4ca06641d110ca7d2ad98c9c1fb52200fe6b1c865d360ff"},
+    {file = "mypy-1.12.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:4b86de37a0da945f6d48cf110d5206c5ed514b1ca2614d7ad652d4bf099c7de7"},
+    {file = "mypy-1.12.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:20c7c5ce0c1be0b0aea628374e6cf68b420bcc772d85c3c974f675b88e3e6e57"},
+    {file = "mypy-1.12.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a64ee25f05fc2d3d8474985c58042b6759100a475f8237da1f4faf7fcd7e6309"},
+    {file = "mypy-1.12.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:faca7ab947c9f457a08dcb8d9a8664fd438080e002b0fa3e41b0535335edcf7f"},
+    {file = "mypy-1.12.0-cp311-cp311-win_amd64.whl", hash = "sha256:5bc81701d52cc8767005fdd2a08c19980de9ec61a25dbd2a937dfb1338a826f9"},
+    {file = "mypy-1.12.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:8462655b6694feb1c99e433ea905d46c478041a8b8f0c33f1dab00ae881b2164"},
+    {file = "mypy-1.12.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:923ea66d282d8af9e0f9c21ffc6653643abb95b658c3a8a32dca1eff09c06475"},
+    {file = "mypy-1.12.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:1ebf9e796521f99d61864ed89d1fb2926d9ab6a5fab421e457cd9c7e4dd65aa9"},
+    {file = "mypy-1.12.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:e478601cc3e3fa9d6734d255a59c7a2e5c2934da4378f3dd1e3411ea8a248642"},
+    {file = "mypy-1.12.0-cp312-cp312-win_amd64.whl", hash = "sha256:c72861b7139a4f738344faa0e150834467521a3fba42dc98264e5aa9507dd601"},
+    {file = "mypy-1.12.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:52b9e1492e47e1790360a43755fa04101a7ac72287b1a53ce817f35899ba0521"},
+    {file = "mypy-1.12.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:48d3e37dd7d9403e38fa86c46191de72705166d40b8c9f91a3de77350daa0893"},
+    {file = "mypy-1.12.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:2f106db5ccb60681b622ac768455743ee0e6a857724d648c9629a9bd2ac3f721"},
+    {file = "mypy-1.12.0-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:233e11b3f73ee1f10efada2e6da0f555b2f3a5316e9d8a4a1224acc10e7181d3"},
+    {file = "mypy-1.12.0-cp313-cp313-win_amd64.whl", hash = "sha256:4ae8959c21abcf9d73aa6c74a313c45c0b5a188752bf37dace564e29f06e9c1b"},
+    {file = "mypy-1.12.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:eafc1b7319b40ddabdc3db8d7d48e76cfc65bbeeafaa525a4e0fa6b76175467f"},
+    {file = "mypy-1.12.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:9b9ce1ad8daeb049c0b55fdb753d7414260bad8952645367e70ac91aec90e07e"},
+    {file = "mypy-1.12.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:bfe012b50e1491d439172c43ccb50db66d23fab714d500b57ed52526a1020bb7"},
+    {file = "mypy-1.12.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:2c40658d4fa1ab27cb53d9e2f1066345596af2f8fe4827defc398a09c7c9519b"},
+    {file = "mypy-1.12.0-cp38-cp38-win_amd64.whl", hash = "sha256:dee78a8b9746c30c1e617ccb1307b351ded57f0de0d287ca6276378d770006c0"},
+    {file = "mypy-1.12.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:6b5df6c8a8224f6b86746bda716bbe4dbe0ce89fd67b1fa4661e11bfe38e8ec8"},
+    {file = "mypy-1.12.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:5feee5c74eb9749e91b77f60b30771563327329e29218d95bedbe1257e2fe4b0"},
+    {file = "mypy-1.12.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:77278e8c6ffe2abfba6db4125de55f1024de9a323be13d20e4f73b8ed3402bd1"},
+    {file = "mypy-1.12.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:dcfb754dea911039ac12434d1950d69a2f05acd4d56f7935ed402be09fad145e"},
+    {file = "mypy-1.12.0-cp39-cp39-win_amd64.whl", hash = "sha256:06de0498798527451ffb60f68db0d368bd2bae2bbfb5237eae616d4330cc87aa"},
+    {file = "mypy-1.12.0-py3-none-any.whl", hash = "sha256:fd313226af375d52e1e36c383f39bf3836e1f192801116b31b090dfcd3ec5266"},
+    {file = "mypy-1.12.0.tar.gz", hash = "sha256:65a22d87e757ccd95cbbf6f7e181e6caa87128255eb2b6be901bb71b26d8a99d"},
 ]
 
 [package.dependencies]
@@ -4552,4 +4557,4 @@ type = ["pytest-mypy"]
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.9,<4.0"
-content-hash = "ceb97d5b463eacb11b6c7d9c572101e7f25c3f6b008178c6e9a273ed7e02920b"
+content-hash = "5c436a9ba9a1695c5c456c1ad8a81c9772a2ba0248624278c6cb606dd019b338"
@@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"
 
 [tool.poetry]
 name = "langchain-community"
-version = "0.3.2"
+version = "0.3.3"
 description = "Community contributed LangChain integrations."
 authors = []
 license = "MIT"
@@ -33,8 +33,8 @@ ignore-words-list = "momento,collison,ned,foor,reworkd,parth,whats,aapply,mysogy
 
 [tool.poetry.dependencies]
 python = ">=3.9,<4.0"
-langchain-core = "^0.3.10"
-langchain = "^0.3.3"
+langchain-core = "^0.3.12"
+langchain = "^0.3.4"
 SQLAlchemy = ">=1.4,<3"
 requests = "^2"
 PyYAML = ">=5.3"
@@ -130,7 +130,7 @@ jupyter = "^1.0.0"
 setuptools = "^67.6.1"
 
 [tool.poetry.group.typing.dependencies]
-mypy = "^1.10"
+mypy = "^1.12"
 types-pyyaml = "^6.0.12.2"
 types-requests = "^2.28.11.5"
 types-toml = "^0.10.8.1"
libs/langchain/poetry.lock (generated, 1110 lines changed): file diff suppressed because it is too large
@@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"
 
 [tool.poetry]
 name = "langchain"
-version = "0.3.3"
+version = "0.3.4"
 description = "Building applications with LLMs through composability"
 authors = []
 license = "MIT"
@@ -33,7 +33,7 @@ langchain-server = "langchain.server:main"
 
 [tool.poetry.dependencies]
 python = ">=3.9,<4.0"
-langchain-core = "^0.3.10"
+langchain-core = "^0.3.12"
 langchain-text-splitters = "^0.3.0"
 langsmith = "^0.1.17"
 pydantic = "^2.7.4"
@@ -1 +0,0 @@
-__pycache__
@@ -1,21 +0,0 @@
-MIT License
-
-Copyright (c) 2023 LangChain, Inc.
-
-Permission is hereby granted, free of charge, to any person obtaining a copy
-of this software and associated documentation files (the "Software"), to deal
-in the Software without restriction, including without limitation the rights
-to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-copies of the Software, and to permit persons to whom the Software is
-furnished to do so, subject to the following conditions:
-
-The above copyright notice and this permission notice shall be included in all
-copies or substantial portions of the Software.
-
-THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-SOFTWARE.
@@ -1,61 +0,0 @@
-.PHONY: all format lint test tests integration_tests docker_tests help extended_tests
-
-# Default target executed when no arguments are given to make.
-all: help
-
-# Define a variable for the test file path.
-TEST_FILE ?= tests/unit_tests/
-
-test:
-	poetry run pytest $(TEST_FILE)
-
-tests:
-	poetry run pytest $(TEST_FILE)
-
-test_watch:
-	poetry run ptw --snapshot-update --now . -- -vv $(TEST_FILE)
-
-
-######################
-# LINTING AND FORMATTING
-######################
-
-# Define a variable for Python and notebook files.
-PYTHON_FILES=.
-MYPY_CACHE=.mypy_cache
-lint format: PYTHON_FILES=.
-lint_diff format_diff: PYTHON_FILES=$(shell git diff --relative=libs/partners/azure --name-only --diff-filter=d master | grep -E '\.py$$|\.ipynb$$')
-lint_package: PYTHON_FILES=langchain_azure_dynamic_sessions
-lint_tests: PYTHON_FILES=tests
-lint_tests: MYPY_CACHE=.mypy_cache_test
-
-lint lint_diff lint_package lint_tests:
-	[ "$(PYTHON_FILES)" = "" ] || poetry run ruff check $(PYTHON_FILES)
-	[ "$(PYTHON_FILES)" = "" ] || poetry run ruff format $(PYTHON_FILES) --diff
-	[ "$(PYTHON_FILES)" = "" ] || mkdir -p $(MYPY_CACHE) && poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE)
-
-format format_diff:
-	[ "$(PYTHON_FILES)" = "" ] || poetry run ruff format $(PYTHON_FILES)
-	[ "$(PYTHON_FILES)" = "" ] || poetry run ruff check --select I --fix $(PYTHON_FILES)
-
-spell_check:
-	poetry run codespell --toml pyproject.toml
-
-spell_fix:
-	poetry run codespell --toml pyproject.toml -w
-
-check_imports: $(shell find langchain_azure_dynamic_sessions -name '*.py')
-	poetry run python ./scripts/check_imports.py $^
-
-######################
-# HELP
-######################
-
-help:
-	@echo '----'
-	@echo 'check_imports - check imports'
-	@echo 'format - run code formatters'
-	@echo 'lint - run linters'
-	@echo 'test - run unit tests'
-	@echo 'tests - run unit tests'
-	@echo 'test TEST_FILE=<test_file> - run all tests in file'
@@ -1,36 +1,3 @@
-# langchain-azure-dynamic-sessions
-
-This package contains the LangChain integration for Azure Container Apps dynamic sessions. You can use it to add a secure and scalable code interpreter to your agents.
-
-## Installation
-
-```bash
-pip install -U langchain-azure-dynamic-sessions
-```
-
-## Usage
-
-You first need to create an Azure Container Apps session pool and obtain its management endpoint. Then you can use the `SessionsPythonREPLTool` tool to give your agent the ability to execute Python code.
-
-```python
-from langchain_azure_dynamic_sessions import SessionsPythonREPLTool
-
-
-# get the management endpoint from the session pool in the Azure portal
-tool = SessionsPythonREPLTool(pool_management_endpoint=POOL_MANAGEMENT_ENDPOINT)
-
-prompt = hub.pull("hwchase17/react")
-tools=[tool]
-react_agent = create_react_agent(
-    llm=llm,
-    tools=tools,
-    prompt=prompt,
-)
-
-react_agent_executor = AgentExecutor(agent=react_agent, tools=tools, verbose=True, handle_parsing_errors=True)
-
-react_agent_executor.invoke({"input": "What is the current time in Vancouver, Canada?"})
-```
-
-By default, the tool uses `DefaultAzureCredential` to authenticate with Azure. If you're using a user-assigned managed identity, you must set the `AZURE_CLIENT_ID` environment variable to the ID of the managed identity.
+This package has moved!
+
+https://github.com/langchain-ai/langchain-azure/tree/main/libs/azure-dynamic-sessions
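The README's install and usage instructions are deleted because the integration now lives in the langchain-azure repository linked above. To the best of my knowledge the PyPI package name is unchanged, so usage stays as before; a minimal sketch (the endpoint value is an illustrative placeholder):

```python
# pip install -U langchain-azure-dynamic-sessions  (now developed in langchain-azure)
from langchain_azure_dynamic_sessions import SessionsPythonREPLTool

# Placeholder: use the management endpoint of your own session pool.
POOL_MANAGEMENT_ENDPOINT = (
    "https://<region>.dynamicsessions.io/subscriptions/<sub>"
    "/resourceGroups/<rg>/sessionPools/<pool>"
)

tool = SessionsPythonREPLTool(pool_management_endpoint=POOL_MANAGEMENT_ENDPOINT)
print(tool.invoke("6 * 7"))
```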
@@ -1,7 +0,0 @@
-"""This package provides tools for managing dynamic sessions in Azure."""
-
-from langchain_azure_dynamic_sessions.tools.sessions import SessionsPythonREPLTool
-
-__all__ = [
-    "SessionsPythonREPLTool",
-]
@@ -1,7 +0,0 @@
-"""This package provides tools for managing dynamic sessions in Azure."""
-
-from langchain_azure_dynamic_sessions.tools.sessions import SessionsPythonREPLTool
-
-__all__ = [
-    "SessionsPythonREPLTool",
-]
@@ -1,314 +0,0 @@
-"""This is the Azure Dynamic Sessions module.
-
-This module provides the SessionsPythonREPLTool class for
-managing dynamic sessions in Azure.
-"""
-
-import importlib.metadata
-import json
-import os
-import re
-import urllib
-from copy import deepcopy
-from dataclasses import dataclass
-from datetime import datetime, timedelta, timezone
-from io import BytesIO
-from typing import Any, BinaryIO, Callable, List, Literal, Optional, Tuple
-from uuid import uuid4
-
-import requests
-from azure.core.credentials import AccessToken
-from azure.identity import DefaultAzureCredential
-from langchain_core.tools import BaseTool
-
-try:
-    _package_version = importlib.metadata.version("langchain-azure-dynamic-sessions")
-except importlib.metadata.PackageNotFoundError:
-    _package_version = "0.0.0"
-USER_AGENT = f"langchain-azure-dynamic-sessions/{_package_version} (Language=Python)"
-
-
-def _access_token_provider_factory() -> Callable[[], Optional[str]]:
-    """Factory function for creating an access token provider function.
-
-    Returns:
-        Callable[[], Optional[str]]: The access token provider function
-    """
-    access_token: Optional[AccessToken] = None
-
-    def access_token_provider() -> Optional[str]:
-        nonlocal access_token
-        if access_token is None or datetime.fromtimestamp(
-            access_token.expires_on, timezone.utc
-        ) < datetime.now(timezone.utc) + timedelta(minutes=5):
-            credential = DefaultAzureCredential()
-            access_token = credential.get_token("https://dynamicsessions.io/.default")
-        return access_token.token
-
-    return access_token_provider
-
-
-def _sanitize_input(query: str) -> str:
-    """Sanitize input to the python REPL.
-
-    Remove whitespace, backtick & python (if llm mistakes python console as terminal)
-
-    Args:
-        query: The query to sanitize
-
-    Returns:
-        str: The sanitized query
-    """
-    # Removes `, whitespace & python from start
-    query = re.sub(r"^(\s|`)*(?i:python)?\s*", "", query)
-    # Removes whitespace & ` from end
-    query = re.sub(r"(\s|`)*$", "", query)
-    return query
-
-
-@dataclass
-class RemoteFileMetadata:
-    """Metadata for a file in the session."""
-
-    filename: str
-    """The filename relative to `/mnt/data`."""
-
-    size_in_bytes: int
-    """The size of the file in bytes."""
-
-    @property
-    def full_path(self) -> str:
-        """Get the full path of the file."""
-        return f"/mnt/data/{self.filename}"
-
-    @staticmethod
-    def from_dict(data: dict) -> "RemoteFileMetadata":
-        """Create a RemoteFileMetadata object from a dictionary."""
-        properties = data.get("properties", {})
-        return RemoteFileMetadata(
-            filename=properties.get("filename"),
-            size_in_bytes=properties.get("size"),
-        )
-
-
-class SessionsPythonREPLTool(BaseTool):
-    r"""Azure Dynamic Sessions tool.
-
-    Setup:
-        Install ``langchain-azure-dynamic-sessions`` and create a session pool, which you can do by following the instructions [here](https://learn.microsoft.com/en-us/azure/container-apps/sessions-code-interpreter?tabs=azure-cli#create-a-session-pool-with-azure-cli).
-
-        .. code-block:: bash
-
-            pip install -U langchain-azure-dynamic-sessions
-
-        .. code-block:: python
-
-            import getpass
-
-            POOL_MANAGEMENT_ENDPOINT = getpass.getpass("Enter the management endpoint of the session pool: ")
-
-    Instantiation:
-        .. code-block:: python
-
-            from langchain_azure_dynamic_sessions import SessionsPythonREPLTool
-
-            tool = SessionsPythonREPLTool(
-                pool_management_endpoint=POOL_MANAGEMENT_ENDPOINT
-            )
-
-
-    Invocation with args:
-        .. code-block:: python
-
-            tool.invoke("6 * 7")
-
-        .. code-block:: python
-
-            '{\\n  "result": 42,\\n  "stdout": "",\\n  "stderr": ""\\n}'
-
-    Invocation with ToolCall:
-
-        .. code-block:: python
-
-            tool.invoke({"args": {"input":"6 * 7"}, "id": "1", "name": tool.name, "type": "tool_call"})
-
-        .. code-block:: python
-
-            '{\\n  "result": 42,\\n  "stdout": "",\\n  "stderr": ""\\n}'
-    """  # noqa: E501
-
-    name: str = "Python_REPL"
-    description: str = (
-        "A Python shell. Use this to execute python commands "
-        "when you need to perform calculations or computations. "
-        "Input should be a valid python command. "
-        "Returns a JSON object with the result, stdout, and stderr. "
-    )
-
-    sanitize_input: bool = True
-    """Whether to sanitize input to the python REPL."""
-
-    pool_management_endpoint: str
-    """The management endpoint of the session pool. Should end with a '/'."""
-
-    access_token_provider: Callable[[], Optional[str]] = (
-        _access_token_provider_factory()
-    )
-    """A function that returns the access token to use for the session pool."""
-
-    session_id: str = str(uuid4())
-    """The session ID to use for the code interpreter. Defaults to a random UUID."""
-
-    response_format: Literal["content_and_artifact"] = "content_and_artifact"
-
-    def _build_url(self, path: str) -> str:
-        pool_management_endpoint = self.pool_management_endpoint
-        if not pool_management_endpoint:
-            raise ValueError("pool_management_endpoint is not set")
-        if not pool_management_endpoint.endswith("/"):
-            pool_management_endpoint += "/"
-        encoded_session_id = urllib.parse.quote(self.session_id)
-        query = f"identifier={encoded_session_id}&api-version=2024-02-02-preview"
-        query_separator = "&" if "?" in pool_management_endpoint else "?"
-        full_url = pool_management_endpoint + path + query_separator + query
-        return full_url
-
-    def execute(self, python_code: str) -> Any:
-        """Execute Python code in the session."""
-        if self.sanitize_input:
-            python_code = _sanitize_input(python_code)
-
-        access_token = self.access_token_provider()
-        api_url = self._build_url("code/execute")
-        headers = {
-            "Authorization": f"Bearer {access_token}",
-            "Content-Type": "application/json",
-            "User-Agent": USER_AGENT,
-        }
-        body = {
-            "properties": {
-                "codeInputType": "inline",
-                "executionType": "synchronous",
-                "code": python_code,
-            }
-        }
-
-        response = requests.post(api_url, headers=headers, json=body)
-        response.raise_for_status()
-        response_json = response.json()
-        properties = response_json.get("properties", {})
-        return properties
-
-    def _run(self, python_code: str, **kwargs: Any) -> Tuple[str, dict]:
-        response = self.execute(python_code)
-
-        # if the result is an image, remove the base64 data
-        result = deepcopy(response.get("result"))
-        if isinstance(result, dict):
-            if result.get("type") == "image" and "base64_data" in result:
-                result.pop("base64_data")
-
-        content = json.dumps(
-            {
-                "result": result,
-                "stdout": response.get("stdout"),
-                "stderr": response.get("stderr"),
-            },
-            indent=2,
-        )
-        return content, response
-
-    def upload_file(
-        self,
-        *,
-        data: Optional[BinaryIO] = None,
-        remote_file_path: Optional[str] = None,
-        local_file_path: Optional[str] = None,
-    ) -> RemoteFileMetadata:
-        """Upload a file to the session.
-
-        Args:
-            data: The data to upload.
-            remote_file_path: The path to upload the file to, relative to
-                `/mnt/data`. If local_file_path is provided, this is defaulted
-                to its filename.
-            local_file_path: The path to the local file to upload.
-
-        Returns:
-            RemoteFileMetadata: The metadata for the uploaded file
-        """
-        if data and local_file_path:
-            raise ValueError("data and local_file_path cannot be provided together")
-
-        if data:
-            file_data = data
-        elif local_file_path:
-            if not remote_file_path:
-                remote_file_path = os.path.basename(local_file_path)
-            file_data = open(local_file_path, "rb")
-
-        access_token = self.access_token_provider()
-        api_url = self._build_url("files/upload")
-        headers = {
-            "Authorization": f"Bearer {access_token}",
-            "User-Agent": USER_AGENT,
-        }
-        files = [("file", (remote_file_path, file_data, "application/octet-stream"))]
-
-        response = requests.request(
-            "POST", api_url, headers=headers, data={}, files=files
-        )
-        response.raise_for_status()
-
-        response_json = response.json()
-        return RemoteFileMetadata.from_dict(response_json["value"][0])
-
-    def download_file(
-        self, *, remote_file_path: str, local_file_path: Optional[str] = None
-    ) -> BinaryIO:
-        """Download a file from the session.
-
-        Args:
-            remote_file_path: The path to download the file from,
-                relative to `/mnt/data`.
-            local_file_path: The path to save the downloaded file to.
-                If not provided, the file is returned as a BufferedReader.
-
-        Returns:
-            BinaryIO: The data of the downloaded file.
-        """
-        access_token = self.access_token_provider()
-        encoded_remote_file_path = urllib.parse.quote(remote_file_path)
-        api_url = self._build_url(f"files/content/{encoded_remote_file_path}")
-        headers = {
-            "Authorization": f"Bearer {access_token}",
-            "User-Agent": USER_AGENT,
-        }
-
-        response = requests.get(api_url, headers=headers)
-        response.raise_for_status()
-
-        if local_file_path:
-            with open(local_file_path, "wb") as f:
-                f.write(response.content)
-
-        return BytesIO(response.content)
-
-    def list_files(self) -> List[RemoteFileMetadata]:
-        """List the files in the session.
-
-        Returns:
-            list[RemoteFileMetadata]: The metadata for the files in the session
-        """
-        access_token = self.access_token_provider()
-        api_url = self._build_url("files")
-        headers = {
-            "Authorization": f"Bearer {access_token}",
-            "User-Agent": USER_AGENT,
-        }
-
-        response = requests.get(api_url, headers=headers)
-        response.raise_for_status()
-
-        response_json = response.json()
-        return [RemoteFileMetadata.from_dict(entry) for entry in response_json["value"]]
libs/partners/azure-dynamic-sessions/poetry.lock (generated, 2232 lines changed): file diff suppressed because it is too large
@@ -1,115 +0,0 @@
-[build-system]
-requires = ["poetry-core>=1.0.0"]
-build-backend = "poetry.core.masonry.api"
-
-[tool.poetry]
-name = "langchain-azure-dynamic-sessions"
-version = "0.2.0"
-description = "An integration package connecting Azure Container Apps dynamic sessions and LangChain"
-authors = []
-readme = "README.md"
-repository = "https://github.com/langchain-ai/langchain"
-license = "MIT"
-
-[tool.mypy]
-disallow_untyped_defs = "True"
-
-[tool.poetry.urls]
-"Source Code" = "https://github.com/langchain-ai/langchain/tree/master/libs/partners/azure-dynamic-sessions"
-"Release Notes" = "https://github.com/langchain-ai/langchain/releases?q=tag%3A%22langchain-azure-dynamic-sessions%3D%3D0%22&expanded=true"
-
-[tool.poetry.dependencies]
-python = ">=3.9,<4.0"
-langchain-core = "^0.3.0"
-azure-identity = "^1.16.0"
-requests = "^2.31.0"
-
-[tool.ruff.lint]
-select = ["E", "F", "I", "D"]
-
-[tool.coverage.run]
-omit = ["tests/*"]
-
-[tool.pytest.ini_options]
-addopts = "--snapshot-warn-unused --strict-markers --strict-config --durations=5"
-markers = [
-    "requires: mark tests as requiring a specific library",
-    "compile: mark placeholder test used to compile integration tests without running them",
-]
-asyncio_mode = "auto"
-
-[tool.poetry.group.test]
-optional = true
-
-[tool.poetry.group.test_integration]
-optional = true
-
-[tool.poetry.group.codespell]
-optional = true
-
-[tool.poetry.group.lint]
-optional = true
-
-[tool.poetry.group.dev]
-optional = true
-
-[tool.ruff.lint.pydocstyle]
-convention = "google"
-
-[tool.ruff.lint.per-file-ignores]
-"tests/**" = ["D"]
-
-[tool.poetry.group.test.dependencies]
-pytest = "^7.3.0"
-freezegun = "^1.2.2"
-pytest-mock = "^3.10.0"
-syrupy = "^4.0.2"
-pytest-watcher = "^0.3.4"
-pytest-asyncio = "^0.21.1"
-python-dotenv = "^1.0.1"
-# TODO: hack to fix 3.9 builds
-cffi = [
-    { version = "<1.17.1", python = "<3.10" },
-    { version = "*", python = ">=3.10" },
-]
-
-[tool.poetry.group.test_integration.dependencies]
-pytest = "^7.3.0"
-python-dotenv = "^1.0.1"
-
-[tool.poetry.group.codespell.dependencies]
-codespell = "^2.2.0"
-
-[tool.poetry.group.lint.dependencies]
-ruff = "^0.5"
-python-dotenv = "^1.0.1"
-pytest = "^7.3.0"
-# TODO: hack to fix 3.9 builds
-cffi = [
-    { version = "<1.17.1", python = "<3.10" },
-    { version = "*", python = ">=3.10" },
-]
-
-[tool.poetry.group.dev.dependencies]
-ipykernel = "^6.29.4"
-langchainhub = "^0.1.15"
-
-[tool.poetry.group.typing.dependencies]
-mypy = "^1.10"
-types-requests = "^2.31.0.20240406"
-
-[tool.poetry.group.test.dependencies.langchain-core]
-path = "../../core"
-develop = true
-
-[tool.poetry.group.dev.dependencies.langchain-core]
-path = "../../core"
-develop = true
-
-[tool.poetry.group.dev.dependencies.langchain-openai]
-path = "../openai"
-develop = true
-
-[tool.poetry.group.typing.dependencies.langchain-core]
-path = "../../core"
-develop = true
@@ -1,19 +0,0 @@
-"""This module checks for specific import statements in the codebase."""
-
-import sys
-import traceback
-from importlib.machinery import SourceFileLoader
-
-if __name__ == "__main__":
-    files = sys.argv[1:]
-    has_failure = False
-    for file in files:
-        try:
-            SourceFileLoader("x", file).load_module()
-        except Exception:
-            has_failure = True
-            print(file)
-            traceback.print_exc()
-            print()
-
-    sys.exit(1 if has_failure else 0)
@@ -1,17 +0,0 @@
-#!/bin/bash
-
-set -eu
-
-# Initialize a variable to keep track of errors
-errors=0
-
-# make sure not importing from langchain or langchain_experimental
-git --no-pager grep '^from langchain\.' . && errors=$((errors+1))
-git --no-pager grep '^from langchain_experimental\.' . && errors=$((errors+1))
-
-# Decide on an exit status based on the errors
-if [ "$errors" -gt 0 ]; then
-    exit 1
-else
-    exit 0
-fi
@@ -1 +0,0 @@
-test file content
@@ -1,7 +0,0 @@
-import pytest  # type: ignore[import-not-found]
-
-
-@pytest.mark.compile
-def test_placeholder() -> None:
-    """Used for compiling integration tests without running any real tests."""
-    pass
@@ -1,68 +0,0 @@
-import json
-import os
-from io import BytesIO
-
-import dotenv  # type: ignore[import-not-found]
-
-from langchain_azure_dynamic_sessions import SessionsPythonREPLTool
-
-dotenv.load_dotenv()
-
-POOL_MANAGEMENT_ENDPOINT = os.getenv("AZURE_DYNAMIC_SESSIONS_POOL_MANAGEMENT_ENDPOINT")
-TEST_DATA_PATH = os.path.join(os.path.dirname(__file__), "data", "testdata.txt")
-TEST_DATA_CONTENT = open(TEST_DATA_PATH, "rb").read()
-
-
-def test_end_to_end() -> None:
-    tool = SessionsPythonREPLTool(pool_management_endpoint=POOL_MANAGEMENT_ENDPOINT)  # type: ignore[arg-type]
-    result = tool.run("print('hello world')\n1 + 1")
-    assert json.loads(result) == {
-        "result": 2,
-        "stdout": "hello world\n",
-        "stderr": "",
-    }
-
-    # upload file content
-    uploaded_file1_metadata = tool.upload_file(
-        remote_file_path="test1.txt", data=BytesIO(b"hello world!!!!!")
-    )
-    assert uploaded_file1_metadata.filename == "test1.txt"
-    assert uploaded_file1_metadata.size_in_bytes == 16
-    assert uploaded_file1_metadata.full_path == "/mnt/data/test1.txt"
-    downloaded_file1 = tool.download_file(remote_file_path="test1.txt")
-    assert downloaded_file1.read() == b"hello world!!!!!"
-
-    # upload file from buffer
-    with open(TEST_DATA_PATH, "rb") as f:
-        uploaded_file2_metadata = tool.upload_file(remote_file_path="test2.txt", data=f)
-    assert uploaded_file2_metadata.filename == "test2.txt"
-    downloaded_file2 = tool.download_file(remote_file_path="test2.txt")
-    assert downloaded_file2.read() == TEST_DATA_CONTENT
-
-    # upload file from disk, specifying remote file path
-    uploaded_file3_metadata = tool.upload_file(
-        remote_file_path="test3.txt", local_file_path=TEST_DATA_PATH
-    )
-    assert uploaded_file3_metadata.filename == "test3.txt"
-    downloaded_file3 = tool.download_file(remote_file_path="test3.txt")
-    assert downloaded_file3.read() == TEST_DATA_CONTENT
-
-    # upload file from disk, without specifying remote file path
-    uploaded_file4_metadata = tool.upload_file(local_file_path=TEST_DATA_PATH)
-    assert uploaded_file4_metadata.filename == os.path.basename(TEST_DATA_PATH)
-    downloaded_file4 = tool.download_file(
-        remote_file_path=uploaded_file4_metadata.filename
-    )
-    assert downloaded_file4.read() == TEST_DATA_CONTENT
-
-    # list files
-    remote_files_metadata = tool.list_files()
-    assert len(remote_files_metadata) == 4
-    remote_file_paths = [metadata.filename for metadata in remote_files_metadata]
-    expected_filenames = [
-        "test1.txt",
-        "test2.txt",
-        "test3.txt",
-        os.path.basename(TEST_DATA_PATH),
-    ]
-    assert set(remote_file_paths) == set(expected_filenames)
@@ -1,9 +0,0 @@
-from langchain_azure_dynamic_sessions import __all__
-
-EXPECTED_ALL = [
-    "SessionsPythonREPLTool",
-]
-
-
-def test_all_imports() -> None:
-    assert sorted(EXPECTED_ALL) == sorted(__all__)
@ -1,208 +0,0 @@
|
|||||||
import json
|
|
||||||
import re
|
|
||||||
import time
|
|
||||||
from unittest import mock
|
|
||||||
from urllib.parse import parse_qs, urlparse
|
|
||||||
|
|
||||||
from azure.core.credentials import AccessToken
|
|
||||||
|
|
||||||
from langchain_azure_dynamic_sessions import SessionsPythonREPLTool
|
|
||||||
from langchain_azure_dynamic_sessions.tools.sessions import (
|
|
||||||
_access_token_provider_factory,
|
|
||||||
)
|
|
||||||
|
|
||||||
POOL_MANAGEMENT_ENDPOINT = "https://westus2.dynamicsessions.io/subscriptions/00000000-0000-0000-0000-000000000000/resourceGroups/sessions-rg/sessionPools/my-pool"
|
|
||||||
|
|
||||||
|
|
||||||
def test_default_access_token_provider_returns_token() -> None:
|
|
||||||
access_token_provider = _access_token_provider_factory()
|
|
||||||
with mock.patch(
|
|
||||||
"azure.identity.DefaultAzureCredential.get_token"
|
|
||||||
) as mock_get_token:
|
|
||||||
mock_get_token.return_value = AccessToken("token_value", 0)
|
|
||||||
access_token = access_token_provider()
|
|
||||||
assert access_token == "token_value"
|
|
||||||
|
|
||||||
|
|
||||||
def test_default_access_token_provider_returns_cached_token() -> None:
|
|
||||||
    access_token_provider = _access_token_provider_factory()
    with mock.patch(
        "azure.identity.DefaultAzureCredential.get_token"
    ) as mock_get_token:
        mock_get_token.return_value = AccessToken(
            "token_value", int(time.time() + 1000)
        )
        access_token = access_token_provider()
        assert access_token == "token_value"
        assert mock_get_token.call_count == 1

        mock_get_token.return_value = AccessToken(
            "new_token_value", int(time.time() + 1000)
        )
        access_token = access_token_provider()
        assert access_token == "token_value"
        assert mock_get_token.call_count == 1


def test_default_access_token_provider_refreshes_expiring_token() -> None:
    access_token_provider = _access_token_provider_factory()
    with mock.patch(
        "azure.identity.DefaultAzureCredential.get_token"
    ) as mock_get_token:
        mock_get_token.return_value = AccessToken("token_value", int(time.time() - 1))
        access_token = access_token_provider()
        assert access_token == "token_value"
        assert mock_get_token.call_count == 1

        mock_get_token.return_value = AccessToken(
            "new_token_value", int(time.time() + 1000)
        )
        access_token = access_token_provider()
        assert access_token == "new_token_value"
        assert mock_get_token.call_count == 2


@mock.patch("requests.post")
@mock.patch("azure.identity.DefaultAzureCredential.get_token")
def test_code_execution_calls_api(
    mock_get_token: mock.MagicMock, mock_post: mock.MagicMock
) -> None:
    tool = SessionsPythonREPLTool(pool_management_endpoint=POOL_MANAGEMENT_ENDPOINT)
    mock_post.return_value.json.return_value = {
        "$id": "1",
        "properties": {
            "$id": "2",
            "status": "Success",
            "stdout": "hello world\n",
            "stderr": "",
            "result": "",
            "executionTimeInMilliseconds": 33,
        },
    }
    mock_get_token.return_value = AccessToken("token_value", int(time.time() + 1000))

    result = tool.run("print('hello world')")

    assert json.loads(result) == {
        "result": "",
        "stdout": "hello world\n",
        "stderr": "",
    }

    api_url = f"{POOL_MANAGEMENT_ENDPOINT}/code/execute"
    headers = {
        "Authorization": "Bearer token_value",
        "Content-Type": "application/json",
        "User-Agent": mock.ANY,
    }
    body = {
        "properties": {
            "codeInputType": "inline",
            "executionType": "synchronous",
            "code": "print('hello world')",
        }
    }
    mock_post.assert_called_once_with(mock.ANY, headers=headers, json=body)

    called_headers = mock_post.call_args.kwargs["headers"]
    assert re.match(
        r"^langchain-azure-dynamic-sessions/\d+\.\d+\.\d+.* \(Language=Python\)",
        called_headers["User-Agent"],
    )

    called_api_url = mock_post.call_args.args[0]
    assert called_api_url.startswith(api_url)


@mock.patch("requests.post")
@mock.patch("azure.identity.DefaultAzureCredential.get_token")
def test_uses_specified_session_id(
    mock_get_token: mock.MagicMock, mock_post: mock.MagicMock
) -> None:
    tool = SessionsPythonREPLTool(
        pool_management_endpoint=POOL_MANAGEMENT_ENDPOINT,
        session_id="00000000-0000-0000-0000-000000000003",
    )
    mock_post.return_value.json.return_value = {
        "$id": "1",
        "properties": {
            "$id": "2",
            "status": "Success",
            "stdout": "",
            "stderr": "",
            "result": "2",
            "executionTimeInMilliseconds": 33,
        },
    }
    mock_get_token.return_value = AccessToken("token_value", int(time.time() + 1000))
    tool.run("1 + 1")
    call_url = mock_post.call_args.args[0]
    parsed_url = urlparse(call_url)
    call_identifier = parse_qs(parsed_url.query)["identifier"][0]
    assert call_identifier == "00000000-0000-0000-0000-000000000003"


def test_sanitizes_input() -> None:
    tool = SessionsPythonREPLTool(pool_management_endpoint=POOL_MANAGEMENT_ENDPOINT)
    with mock.patch("requests.post") as mock_post:
        mock_post.return_value.json.return_value = {
            "$id": "1",
            "properties": {
                "$id": "2",
                "status": "Success",
                "stdout": "",
                "stderr": "",
                "result": "",
                "executionTimeInMilliseconds": 33,
            },
        }
        tool.run("```python\nprint('hello world')\n```")
        body = mock_post.call_args.kwargs["json"]
        assert body["properties"]["code"] == "print('hello world')"


def test_does_not_sanitize_input() -> None:
    tool = SessionsPythonREPLTool(
        pool_management_endpoint=POOL_MANAGEMENT_ENDPOINT, sanitize_input=False
    )
    with mock.patch("requests.post") as mock_post:
        mock_post.return_value.json.return_value = {
            "$id": "1",
            "properties": {
                "$id": "2",
                "status": "Success",
                "stdout": "",
                "stderr": "",
                "result": "",
                "executionTimeInMilliseconds": 33,
            },
        }
        tool.run("```python\nprint('hello world')\n```")
        body = mock_post.call_args.kwargs["json"]
        assert body["properties"]["code"] == "```python\nprint('hello world')\n```"


def test_uses_custom_access_token_provider() -> None:
    def custom_access_token_provider() -> str:
        return "custom_token"

    tool = SessionsPythonREPLTool(
        pool_management_endpoint=POOL_MANAGEMENT_ENDPOINT,
        access_token_provider=custom_access_token_provider,
    )

    with mock.patch("requests.post") as mock_post:
        mock_post.return_value.json.return_value = {
            "$id": "1",
            "properties": {
                "$id": "2",
                "status": "Success",
                "stdout": "",
                "stderr": "",
                "result": "",
                "executionTimeInMilliseconds": 33,
            },
        }
        tool.run("print('hello world')")
        headers = mock_post.call_args.kwargs["headers"]
        assert headers["Authorization"] == "Bearer custom_token"
libs/partners/mongodb/.gitignore vendored
@ -1 +0,0 @@
__pycache__
@ -1,21 +0,0 @@
MIT License

Copyright (c) 2024 LangChain, Inc.

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
@ -1,59 +0,0 @@
.PHONY: all format lint test tests integration_tests docker_tests help extended_tests

# Default target executed when no arguments are given to make.
all: help

# Define a variable for the test file path.
TEST_FILE ?= tests/unit_tests/
integration_test integration_tests: TEST_FILE=tests/integration_tests/

test tests integration_test integration_tests:
	poetry run pytest $(TEST_FILE)

test_watch:
	poetry run ptw --snapshot-update --now . -- -vv $(TEST_FILE)


######################
# LINTING AND FORMATTING
######################

# Define a variable for Python and notebook files.
PYTHON_FILES=.
MYPY_CACHE=.mypy_cache
lint format: PYTHON_FILES=.
lint_diff format_diff: PYTHON_FILES=$(shell git diff --relative=libs/partners/mongodb --name-only --diff-filter=d master | grep -E '\.py$$|\.ipynb$$')
lint_package: PYTHON_FILES=langchain_mongodb
lint_tests: PYTHON_FILES=tests
lint_tests: MYPY_CACHE=.mypy_cache_test

lint lint_diff lint_package lint_tests:
	[ "$(PYTHON_FILES)" = "" ] || poetry run ruff check $(PYTHON_FILES)
	[ "$(PYTHON_FILES)" = "" ] || poetry run ruff format $(PYTHON_FILES) --diff
	[ "$(PYTHON_FILES)" = "" ] || mkdir -p $(MYPY_CACHE) && poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE)

format format_diff:
	[ "$(PYTHON_FILES)" = "" ] || poetry run ruff format $(PYTHON_FILES)
	[ "$(PYTHON_FILES)" = "" ] || poetry run ruff check --select I --fix $(PYTHON_FILES)

spell_check:
	poetry run codespell --toml pyproject.toml

spell_fix:
	poetry run codespell --toml pyproject.toml -w

check_imports: $(shell find langchain_mongodb -name '*.py')
	poetry run python ./scripts/check_imports.py $^

######################
# HELP
######################

help:
	@echo '----'
	@echo 'check_imports                - check imports'
	@echo 'format                       - run code formatters'
	@echo 'lint                         - run linters'
	@echo 'test                         - run unit tests'
	@echo 'tests                        - run unit tests'
	@echo 'test TEST_FILE=<test_file>   - run all tests in file'
@ -1,39 +1,3 @@
# langchain-mongodb

This package has moved!

https://github.com/langchain-ai/langchain-mongodb/tree/main/libs/mongodb

# Installation
```
pip install -U langchain-mongodb
```

# Usage
- See [Getting Started with the LangChain Integration](https://www.mongodb.com/docs/atlas/atlas-vector-search/ai-integrations/langchain/#get-started-with-the-langchain-integration) for a walkthrough on using your first LangChain implementation with MongoDB Atlas.

## Using MongoDBAtlasVectorSearch
```python
import os

from pymongo import MongoClient

from langchain_mongodb import MongoDBAtlasVectorSearch
from langchain_openai import OpenAIEmbeddings

# Pull MongoDB Atlas URI from environment variables
MONGODB_ATLAS_CLUSTER_URI = os.environ.get("MONGODB_ATLAS_CLUSTER_URI")

DB_NAME = "langchain_db"
COLLECTION_NAME = "test"
ATLAS_VECTOR_SEARCH_INDEX_NAME = "index_name"

# Initialize MongoDB python client
client = MongoClient(MONGODB_ATLAS_CLUSTER_URI)
MONGODB_COLLECTION = client[DB_NAME][COLLECTION_NAME]

# Create the vector search via `from_connection_string`
vector_search = MongoDBAtlasVectorSearch.from_connection_string(
    MONGODB_ATLAS_CLUSTER_URI,
    DB_NAME + "." + COLLECTION_NAME,
    OpenAIEmbeddings(disallowed_special=()),
    index_name=ATLAS_VECTOR_SEARCH_INDEX_NAME,
)

# Create the vector search via instantiation
vector_search_2 = MongoDBAtlasVectorSearch(
    collection=MONGODB_COLLECTION,
    embedding=OpenAIEmbeddings(disallowed_special=()),
    index_name=ATLAS_VECTOR_SEARCH_INDEX_NAME,
)
```
@ -1,20 +0,0 @@
"""
Integrate your operational database and vector search in a single, unified,
fully managed platform with full vector database capabilities on MongoDB Atlas.


Store your operational data, metadata, and vector embeddings in our VectorStore,
MongoDBAtlasVectorSearch.
Insert into a Chain via a Vector, FullText, or Hybrid Retriever.
"""

from langchain_mongodb.cache import MongoDBAtlasSemanticCache, MongoDBCache
from langchain_mongodb.chat_message_histories import MongoDBChatMessageHistory
from langchain_mongodb.vectorstores import MongoDBAtlasVectorSearch

__all__ = [
    "MongoDBAtlasVectorSearch",
    "MongoDBChatMessageHistory",
    "MongoDBCache",
    "MongoDBAtlasSemanticCache",
]
@ -1,308 +0,0 @@
"""LangChain MongoDB Caches."""

import json
import logging
import time
from importlib.metadata import version
from typing import Any, Callable, Dict, Optional, Union

from langchain_core.caches import RETURN_VAL_TYPE, BaseCache
from langchain_core.embeddings import Embeddings
from langchain_core.load.dump import dumps
from langchain_core.load.load import loads
from langchain_core.outputs import Generation
from pymongo import MongoClient
from pymongo.collection import Collection
from pymongo.database import Database
from pymongo.driver_info import DriverInfo

from langchain_mongodb.vectorstores import MongoDBAtlasVectorSearch

logger = logging.getLogger(__file__)


class MongoDBCache(BaseCache):
    """MongoDB Atlas cache

    A cache that uses MongoDB Atlas as a backend
    """

    PROMPT = "prompt"
    LLM = "llm"
    RETURN_VAL = "return_val"

    def __init__(
        self,
        connection_string: str,
        collection_name: str = "default",
        database_name: str = "default",
        **kwargs: Dict[str, Any],
    ) -> None:
        """
        Initialize Atlas Cache. Creates collection on instantiation

        Args:
            connection_string (str): Connection URI to MongoDB Atlas.
            collection_name (str): Name of collection for cache to live in.
                Defaults to "default".
            database_name (str): Name of database for cache to live in.
                Defaults to "default".
        """
        self.client = _generate_mongo_client(connection_string)
        self.__database_name = database_name
        self.__collection_name = collection_name

        if self.__collection_name not in self.database.list_collection_names():
            self.database.create_collection(self.__collection_name)
            # Create an index on key and llm_string
            self.collection.create_index([self.PROMPT, self.LLM])

    @property
    def database(self) -> Database:
        """Returns the database used to store cache values."""
        return self.client[self.__database_name]

    @property
    def collection(self) -> Collection:
        """Returns the collection used to store cache values."""
        return self.database[self.__collection_name]

    def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
        """Look up based on prompt and llm_string."""
        return_doc = (
            self.collection.find_one(self._generate_keys(prompt, llm_string)) or {}
        )
        return_val = return_doc.get(self.RETURN_VAL)
        return _loads_generations(return_val) if return_val else None  # type: ignore

    def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
        """Update cache based on prompt and llm_string."""
        self.collection.update_one(
            {**self._generate_keys(prompt, llm_string)},
            {"$set": {self.RETURN_VAL: _dumps_generations(return_val)}},
            upsert=True,
        )

    def _generate_keys(self, prompt: str, llm_string: str) -> Dict[str, str]:
        """Create keyed fields for caching layer"""
        return {self.PROMPT: prompt, self.LLM: llm_string}

    def clear(self, **kwargs: Any) -> None:
        """Clear cache that can take additional keyword arguments.
        Any additional arguments will propagate as filtration criteria for
        what gets deleted.

        E.g.
            # Delete only entries that have llm_string as "fake-model"
            self.clear(llm_string="fake-model")
        """
        self.collection.delete_many({**kwargs})


class MongoDBAtlasSemanticCache(BaseCache, MongoDBAtlasVectorSearch):
    """MongoDB Atlas Semantic cache.

    A Cache backed by a MongoDB Atlas server with vector-store support
    """

    LLM = "llm_string"
    RETURN_VAL = "return_val"

    def __init__(
        self,
        connection_string: str,
        embedding: Embeddings,
        collection_name: str = "default",
        database_name: str = "default",
        index_name: str = "default",
        wait_until_ready: Optional[float] = None,
        score_threshold: Optional[float] = None,
        **kwargs: Dict[str, Any],
    ):
        """
        Initialize Atlas VectorSearch Cache.
        Assumes collection exists before instantiation

        Args:
            connection_string (str): MongoDB URI to connect to MongoDB Atlas cluster.
            embedding (Embeddings): Text embedding model to use.
            collection_name (str): MongoDB Collection to add the texts to.
                Defaults to "default".
            database_name (str): MongoDB Database where to store texts.
                Defaults to "default".
            index_name (str): Name of the Atlas Search index.
                Defaults to "default".
            wait_until_ready (float): Wait this time for Atlas to finish indexing
                the stored text. Defaults to None.
            score_threshold (float): Minimum similarity score required for a
                cache hit. Defaults to None.
        """
        client = _generate_mongo_client(connection_string)
        self.collection = client[database_name][collection_name]
        self.score_threshold = score_threshold
        self._wait_until_ready = wait_until_ready
        super().__init__(
            collection=self.collection,
            embedding=embedding,
            index_name=index_name,
            **kwargs,  # type: ignore
        )

    def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
        """Look up based on prompt and llm_string."""
        post_filter_pipeline = (
            [{"$match": {"score": {"$gte": self.score_threshold}}}]
            if self.score_threshold
            else None
        )

        search_response = self.similarity_search_with_score(
            prompt,
            1,
            pre_filter={self.LLM: {"$eq": llm_string}},
            post_filter_pipeline=post_filter_pipeline,
        )
        if search_response:
            return_val = search_response[0][0].metadata.get(self.RETURN_VAL)
            response = _loads_generations(return_val) or return_val  # type: ignore
            return response
        return None

    def update(
        self,
        prompt: str,
        llm_string: str,
        return_val: RETURN_VAL_TYPE,
        wait_until_ready: Optional[float] = None,
    ) -> None:
        """Update cache based on prompt and llm_string."""
        self.add_texts(
            [prompt],
            [
                {
                    self.LLM: llm_string,
                    self.RETURN_VAL: _dumps_generations(return_val),
                }
            ],
        )
        wait = self._wait_until_ready if wait_until_ready is None else wait_until_ready

        def is_indexed() -> bool:
            return self.lookup(prompt, llm_string) == return_val

        if wait:
            _wait_until(is_indexed, return_val, timeout=wait)

    def clear(self, **kwargs: Any) -> None:
        """Clear cache that can take additional keyword arguments.
        Any additional arguments will propagate as filtration criteria for
        what gets deleted. It will delete any locally cached content regardless.

        E.g.
            # Delete only entries that have llm_string as "fake-model"
            self.clear(llm_string="fake-model")
        """
        self.collection.delete_many({**kwargs})


def _generate_mongo_client(connection_string: str) -> MongoClient:
    return MongoClient(
        connection_string,
        driver=DriverInfo(name="Langchain", version=version("langchain-mongodb")),
    )


def _dumps_generations(generations: RETURN_VAL_TYPE) -> str:
    """
    Serialization for generic RETURN_VAL_TYPE, i.e. sequence of `Generation`

    Args:
        generations (RETURN_VAL_TYPE): A list of language model generations.

    Returns:
        str: a single string representing a list of generations.

    This function and "_loads_generations" are duplicated in this utility
    from module "libs/community/langchain_community/cache.py".

    This function and its counterpart rely on
    the dumps/loads pair with Reviver, so are able to deal
    with all subclasses of Generation.

    Each item in the list can be `dumps`ed to a string,
    then we make the whole list of strings into a json-dumped list.
    """
    return json.dumps([dumps(_item) for _item in generations])


def _loads_generations(generations_str: str) -> Union[RETURN_VAL_TYPE, None]:
    """
    Deserialization of a string into a generic RETURN_VAL_TYPE
    (i.e. a sequence of `Generation`).

    Args:
        generations_str (str): A string representing a list of generations.

    Returns:
        RETURN_VAL_TYPE: A list of generations.

    This function and its counterpart rely on
    the dumps/loads pair with Reviver, so are able to deal
    with all subclasses of Generation.

    See `_dumps_generations`, the inverse of this function.

    Compatible with the legacy cache-blob format.
    Does not raise exceptions for malformed entries, just logs a warning
    and returns None: the caller should be prepared for such a cache miss.
    """
    try:
        generations = [loads(_item_str) for _item_str in json.loads(generations_str)]
        return generations
    except (json.JSONDecodeError, TypeError):
        # deferring the (soft) handling to after the legacy-format attempt
        pass

    try:
        gen_dicts = json.loads(generations_str)
        # not relying on `_load_generations_from_json` (which could disappear):
        generations = [Generation(**generation_dict) for generation_dict in gen_dicts]
        logger.warning(
            f"Legacy 'Generation' cached blob encountered: '{generations_str}'"
        )
        return generations
    except (json.JSONDecodeError, TypeError):
        logger.warning(
            f"Malformed/unparsable cached blob encountered: '{generations_str}'"
        )
        return None


def _wait_until(
    predicate: Callable, success_description: Any, timeout: float = 10.0
) -> Any:
    """Wait up to 10 seconds (by default) for predicate to be true.

    E.g.:

        wait_until(lambda: client.primary == ('a', 1),
                   'connect to the primary')

    If the lambda-expression isn't true after 10 seconds, we raise
    TimeoutError("Didn't ever connect to the primary").

    Returns the predicate's first true value.
    """
    start = time.time()
    interval = min(float(timeout) / 100, 0.1)
    while True:
        retval = predicate()
        if retval:
            return retval

        if time.time() - start > timeout:
            raise TimeoutError("Didn't ever %s" % success_description)

        time.sleep(interval)
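As a usage sketch (not part of the diff; the URI is a placeholder and `set_llm_cache` is assumed from `langchain_core.globals`), the document cache above plugs into LangChain's global LLM cache like this:

```python
# Sketch only: "<MONGODB_ATLAS_URI>" is a placeholder for a real cluster URI.
from langchain_core.globals import set_llm_cache

from langchain_mongodb import MongoDBCache

set_llm_cache(
    MongoDBCache(
        connection_string="<MONGODB_ATLAS_URI>",
        database_name="langchain_db",
        collection_name="llm_cache",
    )
)
# Subsequent identical (prompt, llm_string) calls are then served from MongoDB.
```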
@ -1,162 +0,0 @@
import json
import logging
from typing import Dict, List, Optional

from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import (
    BaseMessage,
    message_to_dict,
    messages_from_dict,
)
from pymongo import MongoClient, errors

logger = logging.getLogger(__name__)

DEFAULT_DBNAME = "chat_history"
DEFAULT_COLLECTION_NAME = "message_store"
DEFAULT_SESSION_ID_KEY = "SessionId"
DEFAULT_HISTORY_KEY = "History"


class MongoDBChatMessageHistory(BaseChatMessageHistory):
    """Chat message history that stores history in MongoDB.

    Setup:
        Install ``langchain-mongodb`` python package.

        .. code-block:: bash

            pip install langchain-mongodb

    Instantiate:
        .. code-block:: python

            from langchain_mongodb import MongoDBChatMessageHistory

            history = MongoDBChatMessageHistory(
                connection_string="mongodb://your-host:your-port/",  # mongodb://localhost:27017/
                session_id="your-session-id",
            )

    Add and retrieve messages:
        .. code-block:: python

            # Add single message
            history.add_message(message)

            # Add batch messages
            history.add_messages([message1, message2, message3, ...])

            # Add human message
            history.add_user_message(human_message)

            # Add ai message
            history.add_ai_message(ai_message)

            # Retrieve messages
            messages = history.messages
    """  # noqa: E501

    def __init__(
        self,
        connection_string: str,
        session_id: str,
        database_name: str = DEFAULT_DBNAME,
        collection_name: str = DEFAULT_COLLECTION_NAME,
        *,
        session_id_key: str = DEFAULT_SESSION_ID_KEY,
        history_key: str = DEFAULT_HISTORY_KEY,
        create_index: bool = True,
        history_size: Optional[int] = None,
        index_kwargs: Optional[Dict] = None,
    ):
        """Initialize with a MongoDBChatMessageHistory instance.

        Args:
            connection_string: str
                connection string to connect to MongoDB.
            session_id: str
                arbitrary key that is used to store the messages of
                a single chat session.
            database_name: Optional[str]
                name of the database to use.
            collection_name: Optional[str]
                name of the collection to use.
            session_id_key: Optional[str]
                name of the field that stores the session id.
            history_key: Optional[str]
                name of the field that stores the chat history.
            create_index: Optional[bool]
                whether to create an index on the session id field.
            history_size: Optional[int]
                count of (most recent) messages to fetch from MongoDB.
            index_kwargs: Optional[Dict]
                additional keyword arguments to pass to the index creation.
        """
        self.connection_string = connection_string
        self.session_id = session_id
        self.database_name = database_name
        self.collection_name = collection_name
        self.session_id_key = session_id_key
        self.history_key = history_key
        self.history_size = history_size

        try:
            self.client: MongoClient = MongoClient(connection_string)
        except errors.ConnectionFailure as error:
            logger.error(error)

        self.db = self.client[database_name]
        self.collection = self.db[collection_name]

        if create_index:
            index_kwargs = index_kwargs or {}
            self.collection.create_index(self.session_id_key, **index_kwargs)

    @property
    def messages(self) -> List[BaseMessage]:  # type: ignore
        """Retrieve the messages from MongoDB"""
        try:
            if self.history_size is None:
                cursor = self.collection.find({self.session_id_key: self.session_id})
            else:
                skip_count = max(
                    0,
                    self.collection.count_documents(
                        {self.session_id_key: self.session_id}
                    )
                    - self.history_size,
                )
                cursor = self.collection.find(
                    {self.session_id_key: self.session_id}, skip=skip_count
                )
        except errors.OperationFailure as error:
            logger.error(error)

        if cursor:
            items = [json.loads(document[self.history_key]) for document in cursor]
        else:
            items = []

        messages = messages_from_dict(items)
        return messages

    def add_message(self, message: BaseMessage) -> None:
        """Append the message to the record in MongoDB"""
        try:
            self.collection.insert_one(
                {
                    self.session_id_key: self.session_id,
                    self.history_key: json.dumps(message_to_dict(message)),
                }
            )
        except errors.WriteError as err:
            logger.error(err)

    def clear(self) -> None:
        """Clear session memory from MongoDB"""
        try:
            self.collection.delete_many({self.session_id_key: self.session_id})
        except errors.WriteError as err:
            logger.error(err)
@ -1,270 +0,0 @@
"""Search Index Commands"""

import logging
from time import monotonic, sleep
from typing import Any, Callable, Dict, List, Optional

from pymongo.collection import Collection
from pymongo.errors import OperationFailure
from pymongo.operations import SearchIndexModel

logger = logging.getLogger(__file__)


def _search_index_error_message() -> str:
    return (
        "Search index operations are not currently available on shared clusters, "
        "such as M0. They require dedicated clusters >= M10. "
        "You may still perform vector search. "
        "You simply must set up indexes manually. Follow the instructions here: "
        "https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-type/"
    )


def _vector_search_index_definition(
    dimensions: int,
    path: str,
    similarity: str,
    filters: Optional[List[str]] = None,
    **kwargs: Any,
) -> Dict[str, Any]:
    # https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-type/
    fields = [
        {
            "numDimensions": dimensions,
            "path": path,
            "similarity": similarity,
            "type": "vector",
        },
    ]
    if filters:
        for field in filters:
            fields.append({"type": "filter", "path": field})
    definition = {"fields": fields}
    definition.update(kwargs)
    return definition


def create_vector_search_index(
    collection: Collection,
    index_name: str,
    dimensions: int,
    path: str,
    similarity: str,
    filters: Optional[List[str]] = None,
    *,
    wait_until_complete: Optional[float] = None,
    **kwargs: Any,
) -> None:
    """Experimental utility function to create a vector search index

    Args:
        collection (Collection): MongoDB Collection
        index_name (str): Name of Index
        dimensions (int): Number of dimensions in embedding
        path (str): field with vector embedding
        similarity (str): The similarity score used for the index
        filters (List[str]): Fields/paths to index to allow filtering in $vectorSearch
        wait_until_complete (Optional[float]): If provided, number of seconds to wait
            until search index is ready.
        kwargs: Keyword arguments supplying any additional options to SearchIndexModel.
    """
    logger.info("Creating Search Index %s on %s", index_name, collection.name)

    try:
        result = collection.create_search_index(
            SearchIndexModel(
                definition=_vector_search_index_definition(
                    dimensions=dimensions,
                    path=path,
                    similarity=similarity,
                    filters=filters,
                    **kwargs,
                ),
                name=index_name,
                type="vectorSearch",
            )
        )
    except OperationFailure as e:
        raise OperationFailure(_search_index_error_message()) from e

    if wait_until_complete:
        _wait_for_predicate(
            predicate=lambda: _is_index_ready(collection, index_name),
            err=f"{index_name=} did not complete in {wait_until_complete}!",
            timeout=wait_until_complete,
        )
    logger.info(result)


def drop_vector_search_index(
    collection: Collection,
    index_name: str,
    *,
    wait_until_complete: Optional[float] = None,
) -> None:
    """Drop a created vector search index

    Args:
        collection (Collection): MongoDB Collection with index to be dropped
        index_name (str): Name of the MongoDB index
        wait_until_complete (Optional[float]): If provided, number of seconds to wait
            until search index is dropped.
    """
    logger.info(
        "Dropping Search Index %s from Collection: %s", index_name, collection.name
    )
    try:
        collection.drop_search_index(index_name)
    except OperationFailure as e:
        if "CommandNotSupported" in str(e):
            raise OperationFailure(_search_index_error_message()) from e
        # else this most likely means an ongoing drop request was made so skip
    if wait_until_complete:
        _wait_for_predicate(
            predicate=lambda: len(list(collection.list_search_indexes())) == 0,
            err=f"Index {index_name} did not drop in {wait_until_complete}!",
            timeout=wait_until_complete,
        )
    logger.info("Vector Search index %s.%s dropped", collection.name, index_name)


def update_vector_search_index(
    collection: Collection,
    index_name: str,
    dimensions: int,
    path: str,
    similarity: str,
    filters: Optional[List[str]] = None,
    *,
    wait_until_complete: Optional[float] = None,
    **kwargs: Any,
) -> None:
    """Update a search index.

    Replace the existing index definition with the provided definition.

    Args:
        collection (Collection): MongoDB Collection
        index_name (str): Name of Index
        dimensions (int): Number of dimensions in embedding
        path (str): field with vector embedding
        similarity (str): The similarity score used for the index.
        filters (List[str]): Fields/paths to index to allow filtering in $vectorSearch
        wait_until_complete (Optional[float]): If provided, number of seconds to wait
            until search index is ready.
        kwargs: Keyword arguments supplying any additional options to SearchIndexModel.
    """

    logger.info(
        "Updating Search Index %s from Collection: %s", index_name, collection.name
    )
    try:
        collection.update_search_index(
            name=index_name,
            definition=_vector_search_index_definition(
                dimensions=dimensions,
                path=path,
                similarity=similarity,
                filters=filters,
                **kwargs,
            ),
        )
    except OperationFailure as e:
        raise OperationFailure(_search_index_error_message()) from e

    if wait_until_complete:
        _wait_for_predicate(
            predicate=lambda: _is_index_ready(collection, index_name),
            err=f"Index {index_name} update did not complete in {wait_until_complete}!",
            timeout=wait_until_complete,
        )
    logger.info("Update succeeded")


def _is_index_ready(collection: Collection, index_name: str) -> bool:
    """Check for the index name in the list of available search indexes to see if the
    specified index is of status READY

    Args:
        collection (Collection): MongoDB Collection to search for the indexes
        index_name (str): Vector Search Index name

    Returns:
        bool: True if the index is present and READY, False otherwise
    """
    try:
        search_indexes = collection.list_search_indexes(index_name)
    except OperationFailure as e:
        raise OperationFailure(_search_index_error_message()) from e

    for index in search_indexes:
        if index["type"] == "vectorSearch" and index["status"] == "READY":
            return True
    return False


def _wait_for_predicate(
    predicate: Callable, err: str, timeout: float = 120, interval: float = 0.5
) -> None:
    """Generic helper to block until the predicate returns true.

    Args:
        predicate (Callable[..., bool]): A function that returns a boolean value
        err (str): Error message to raise if the predicate never succeeds
        timeout (float, optional): Wait time for predicate. Defaults to 120.
        interval (float, optional): Interval to check predicate. Defaults to 0.5.

    Raises:
        TimeoutError: Raised with ``err`` if the predicate is not true
            within ``timeout`` seconds.
    """
    start = monotonic()
    while not predicate():
        if monotonic() - start > timeout:
            raise TimeoutError(err)
        sleep(interval)


def create_fulltext_search_index(
    collection: Collection,
    index_name: str,
    field: str,
    *,
    wait_until_complete: Optional[float] = None,
    **kwargs: Any,
) -> None:
    """Experimental utility function to create an Atlas Search index

    Args:
        collection (Collection): MongoDB Collection
        index_name (str): Name of Index
        field (str): Field to index
        wait_until_complete (Optional[float]): If provided, number of seconds to wait
            until search index is ready
        kwargs: Keyword arguments supplying any additional options to SearchIndexModel.
    """
    logger.info("Creating Search Index %s on %s", index_name, collection.name)

    definition = {
        "mappings": {"dynamic": False, "fields": {field: [{"type": "string"}]}}
    }

    try:
        result = collection.create_search_index(
            SearchIndexModel(
                definition=definition,
                name=index_name,
                type="search",
                **kwargs,
            )
        )
    except OperationFailure as e:
        raise OperationFailure(_search_index_error_message()) from e

    if wait_until_complete:
        _wait_for_predicate(
            predicate=lambda: _is_index_ready(collection, index_name),
            err=f"{index_name=} did not complete in {wait_until_complete}!",
            timeout=wait_until_complete,
        )
    logger.info(result)
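A hedged usage sketch of the index helpers above (not part of the diff; the URI, database, collection, and field names are placeholders, and a dedicated Atlas cluster >= M10 is required):

```python
# Sketch only: requires a dedicated Atlas cluster (>= M10) and an
# existing collection; "<MONGODB_ATLAS_URI>" is a placeholder.
from pymongo import MongoClient

from langchain_mongodb.index import create_vector_search_index

collection = MongoClient("<MONGODB_ATLAS_URI>")["langchain_db"]["test"]
create_vector_search_index(
    collection,
    index_name="vector_index",
    dimensions=1536,  # must match the embedding model's output size
    path="embedding",  # document field that stores the vectors
    similarity="cosine",
    filters=["metadata.page"],  # allows pre-filtering on this field
    wait_until_complete=60,  # block up to 60s until the index is READY
)
```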
@ -1,160 +0,0 @@
"""Aggregation pipeline components used in Atlas Full-Text, Vector, and Hybrid Search

See the following for more:
- `Full-Text Search <https://www.mongodb.com/docs/atlas/atlas-search/aggregation-stages/search/#mongodb-pipeline-pipe.-search>`_
- `MongoDB Operators <https://www.mongodb.com/docs/atlas/atlas-search/operators-and-collectors/#std-label-operators-ref>`_
- `Vector Search <https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/>`_
- `Filter Example <https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/#atlas-vector-search-pre-filter>`_
"""

from typing import Any, Dict, List, Optional


def text_search_stage(
    query: str,
    search_field: str,
    index_name: str,
    limit: Optional[int] = None,
    filter: Optional[Dict[str, Any]] = None,
    include_scores: Optional[bool] = True,
    **kwargs: Any,
) -> List[Dict[str, Any]]:  # noqa: E501
    """Full-Text search using Lucene's standard (BM25) analyzer

    Args:
        query: Input text to search for
        search_field: Field in Collection that will be searched
        index_name: Atlas Search Index name
        limit: Maximum number of documents to return. Default of no limit
        filter: Any MQL match expression comparing an indexed field
        include_scores: Scores provide a measure of relative relevance

    Returns:
        List of aggregation stages: the $search stage plus optional
        $match (filter), $set (score), and $limit stages
    """
    pipeline = [
        {
            "$search": {
                "index": index_name,
                "text": {"query": query, "path": search_field},
            }
        }
    ]
    if filter:
        pipeline.append({"$match": filter})  # type: ignore
    if include_scores:
        pipeline.append({"$set": {"score": {"$meta": "searchScore"}}})
    if limit:
        pipeline.append({"$limit": limit})  # type: ignore

    return pipeline  # type: ignore


def vector_search_stage(
    query_vector: List[float],
    search_field: str,
    index_name: str,
    top_k: int = 4,
    filter: Optional[Dict[str, Any]] = None,
    oversampling_factor: int = 10,
    **kwargs: Any,
) -> Dict[str, Any]:  # noqa: E501
    """Vector Search Stage without Scores.

    Scoring is applied later depending on strategy.
    Vector search includes a vectorSearchScore that is typically used.
    Hybrid search uses Reciprocal Rank Fusion.

    Args:
        query_vector: List of embedding vector
        search_field: Field in Collection containing embedding vectors
        index_name: Name of Atlas Vector Search Index tied to Collection
        top_k: Number of documents to return
        oversampling_factor: This times top_k is the number of candidates
        filter: MQL match expression comparing an indexed field.
            Some operators are not supported.
            See `vectorSearch filter docs <https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/#atlas-vector-search-pre-filter>`_

    Returns:
        Dictionary defining the $vectorSearch stage
    """
    stage = {
        "index": index_name,
        "path": search_field,
        "queryVector": query_vector,
        "numCandidates": top_k * oversampling_factor,
        "limit": top_k,
    }
    if filter:
        stage["filter"] = filter
    return {"$vectorSearch": stage}


def combine_pipelines(
    pipeline: List[Any], stage: List[Dict[str, Any]], collection_name: str
) -> None:
    """Combines two aggregations into a single result set in-place."""
    if pipeline:
        pipeline.append({"$unionWith": {"coll": collection_name, "pipeline": stage}})
    else:
        pipeline.extend(stage)


def reciprocal_rank_stage(
    score_field: str, penalty: float = 0, **kwargs: Any
) -> List[Dict[str, Any]]:
    """Stage adds Reciprocal Rank Fusion weighting.

    First, it pushes documents retrieved from the previous stage
    into a temporary sub-document. It then unwinds to establish
    the rank of each and applies the penalty.

    Args:
        score_field: A unique string to identify the search being ranked
        penalty: A non-negative float.

    Returns:
        Stages that set an RRF score := 1 / (rank + penalty),
        with rank in [1, 2, ..., n]
    """

    rrf_pipeline = [
        {"$group": {"_id": None, "docs": {"$push": "$$ROOT"}}},
        {"$unwind": {"path": "$docs", "includeArrayIndex": "rank"}},
        {
            "$addFields": {
                f"docs.{score_field}": {
                    "$divide": [1.0, {"$add": ["$rank", penalty, 1]}]
                },
                "docs.rank": "$rank",
                "_id": "$docs._id",
            }
        },
        {"$replaceRoot": {"newRoot": "$docs"}},
    ]

    return rrf_pipeline  # type: ignore


def final_hybrid_stage(
    scores_fields: List[str], limit: int, **kwargs: Any
) -> List[Dict[str, Any]]:
    """Sum weighted scores, sort, and apply limit.

    Args:
        scores_fields: List of fields given to scores of vector and text searches
        limit: Number of documents to return

    Returns:
        Final aggregation stages
    """

    return [
        {"$group": {"_id": "$_id", "docs": {"$mergeObjects": "$$ROOT"}}},
        {"$replaceRoot": {"newRoot": "$docs"}},
        {"$set": {score: {"$ifNull": [f"${score}", 0]} for score in scores_fields}},
        {"$addFields": {"score": {"$add": [f"${score}" for score in scores_fields]}}},
        {"$sort": {"score": -1}},
        {"$limit": limit},
    ]
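For intuition on the reciprocal_rank_stage weighting, here is a tiny worked example with hypothetical ranks and the 60-point default penalty used by the hybrid retriever later in this diff:

```python
# RRF score = 1 / (rank + penalty), rank being 1-based.
penalty = 60.0
for rank in (1, 2, 3):
    print(rank, round(1.0 / (rank + penalty), 5))
# 1 0.01639
# 2 0.01613
# 3 0.01587
# Scores decay gently, so a document ranked highly by either search
# still contributes meaningfully after the two ranked lists are fused.
```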
@ -1,15 +0,0 @@
"""Search Retrievers of various types.

Use ``MongoDBAtlasVectorSearch.as_retriever(**)``
to create MongoDB's core Vector Search Retriever.
"""

from langchain_mongodb.retrievers.full_text_search import (
    MongoDBAtlasFullTextSearchRetriever,
)
from langchain_mongodb.retrievers.hybrid_search import MongoDBAtlasHybridSearchRetriever

__all__ = [
    "MongoDBAtlasHybridSearchRetriever",
    "MongoDBAtlasFullTextSearchRetriever",
]
@ -1,59 +0,0 @@
from typing import Any, Dict, List, Optional

from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
from pymongo.collection import Collection

from langchain_mongodb.pipelines import text_search_stage
from langchain_mongodb.utils import make_serializable


class MongoDBAtlasFullTextSearchRetriever(BaseRetriever):
    """Full-Text Search Retriever performs full-text searches
    using Lucene's standard (BM25) analyzer.
    """

    collection: Collection
    """MongoDB Collection on an Atlas cluster"""
    search_index_name: str
    """Atlas Search Index name"""
    search_field: str
    """Collection field that contains the text to be searched. It must be indexed"""
    top_k: Optional[int] = None
    """Number of documents to return. Default is no limit"""
    filter: Optional[Dict[str, Any]] = None
    """(Optional) MQL match expression comparing an indexed field"""
    show_embeddings: bool = False
    """If true, returned Document metadata will include vectors"""

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Retrieve documents that are highest scoring / most similar to query.

        Args:
            query: String to find relevant documents for
            run_manager: The callback handler to use
        Returns:
            List of relevant documents
        """

        pipeline = text_search_stage(  # type: ignore
            query=query,
            search_field=self.search_field,
            index_name=self.search_index_name,
            limit=self.top_k,
            filter=self.filter,
        )

        # Execution
        cursor = self.collection.aggregate(pipeline)  # type: ignore[arg-type]

        # Formatting
        docs = []
        for res in cursor:
            text = res.pop(self.search_field)
            make_serializable(res)
            docs.append(Document(page_content=text, metadata=res))
        return docs
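A hedged usage sketch of the full-text retriever above (not part of the diff; the URI, index name, and field name are placeholders for whatever exists in your cluster):

```python
# Sketch only: assumes an Atlas Search (full-text) index named
# "search_index" on the "text" field of the collection.
from pymongo import MongoClient

from langchain_mongodb.retrievers import MongoDBAtlasFullTextSearchRetriever

collection = MongoClient("<MONGODB_ATLAS_URI>")["langchain_db"]["test"]
retriever = MongoDBAtlasFullTextSearchRetriever(
    collection=collection,
    search_index_name="search_index",
    search_field="text",
    top_k=5,
)
docs = retriever.invoke("What is Atlas Search?")
```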
@ -1,126 +0,0 @@
from typing import Any, Dict, List, Optional

from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
from pymongo.collection import Collection

from langchain_mongodb import MongoDBAtlasVectorSearch
from langchain_mongodb.pipelines import (
    combine_pipelines,
    final_hybrid_stage,
    reciprocal_rank_stage,
    text_search_stage,
    vector_search_stage,
)
from langchain_mongodb.utils import make_serializable


class MongoDBAtlasHybridSearchRetriever(BaseRetriever):
    """Hybrid Search Retriever combines vector and full-text searches,
    weighting them via the Reciprocal Rank Fusion (RRF) algorithm.

    Increasing the vector_penalty will reduce the importance of the vector search.
    Increasing the fulltext_penalty will correspondingly reduce the fulltext score.
    For more on the algorithm, see
    https://learn.microsoft.com/en-us/azure/search/hybrid-search-ranking
    """

    vectorstore: MongoDBAtlasVectorSearch
    """MongoDBAtlas VectorStore"""
    search_index_name: str
    """Atlas Search Index (full-text) name"""
    top_k: int = 4
    """Number of documents to return."""
    oversampling_factor: int = 10
    """This times top_k is the number of candidates chosen at each step"""
    pre_filter: Optional[Dict[str, Any]] = None
    """(Optional) Any MQL match expression comparing an indexed field"""
    post_filter: Optional[List[Dict[str, Any]]] = None
    """(Optional) Pipeline of MongoDB aggregation stages for postprocessing."""
    vector_penalty: float = 60.0
    """Penalty applied to vector search results in RRF: scores=1/(rank + penalty)"""
    fulltext_penalty: float = 60.0
    """Penalty applied to full-text search results in RRF: scores=1/(rank + penalty)"""
    show_embeddings: bool = False
    """If true, returned Document metadata will include vectors."""

    @property
    def collection(self) -> Collection:
        return self.vectorstore._collection

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Retrieve documents that are highest scoring / most similar to query.

        Note that the same query is used in both searches,
        embedded for vector search, and as-is for full-text search.

        Args:
            query: String to find relevant documents for
            run_manager: The callback handler to use
        Returns:
            List of relevant documents
        """

        query_vector = self.vectorstore._embedding.embed_query(query)

        scores_fields = ["vector_score", "fulltext_score"]
        pipeline: List[Any] = []

        # First we build up the aggregation pipeline,
        # then it is passed to the server to execute
        # Vector Search stage
        vector_pipeline = [
            vector_search_stage(
                query_vector=query_vector,
                search_field=self.vectorstore._embedding_key,
                index_name=self.vectorstore._index_name,
                top_k=self.top_k,
                filter=self.pre_filter,
                oversampling_factor=self.oversampling_factor,
            )
        ]
        vector_pipeline += reciprocal_rank_stage("vector_score", self.vector_penalty)

        combine_pipelines(pipeline, vector_pipeline, self.collection.name)

        # Full-Text Search stage
        text_pipeline = text_search_stage(
            query=query,
            search_field=self.vectorstore._text_key,
            index_name=self.search_index_name,
            limit=self.top_k,
            filter=self.pre_filter,
        )

        text_pipeline.extend(
            reciprocal_rank_stage("fulltext_score", self.fulltext_penalty)
        )

        combine_pipelines(pipeline, text_pipeline, self.collection.name)

        # Sum and sort stage
        pipeline.extend(
            final_hybrid_stage(scores_fields=scores_fields, limit=self.top_k)
        )

        # Removal of embeddings unless requested.
        if not self.show_embeddings:
            pipeline.append({"$project": {self.vectorstore._embedding_key: 0}})
        # Post filtering
        if self.post_filter is not None:
            pipeline.extend(self.post_filter)

        # Execution
        cursor = self.collection.aggregate(pipeline)  # type: ignore[arg-type]

        # Formatting
        docs = []
        for res in cursor:
            text = res.pop(self.vectorstore._text_key)
            # score = res.pop("score")  # The score remains buried!
            make_serializable(res)
            docs.append(Document(page_content=text, metadata=res))
        return docs
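A hedged usage sketch of the hybrid retriever above (not part of the diff; it assumes an already-constructed vectorstore, such as the `vector_search` instance from the README section earlier, and a full-text index whose name is a placeholder):

```python
# Sketch only: `vector_search` is assumed to be an existing
# MongoDBAtlasVectorSearch instance (see the README section above),
# and "search_index" a full-text Atlas Search index on its text field.
from langchain_mongodb.retrievers import MongoDBAtlasHybridSearchRetriever

retriever = MongoDBAtlasHybridSearchRetriever(
    vectorstore=vector_search,  # a MongoDBAtlasVectorSearch instance
    search_index_name="search_index",
    top_k=4,
    vector_penalty=60.0,
    fulltext_penalty=60.0,
)
docs = retriever.invoke("What is Atlas Vector Search?")
```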
@ -1,183 +0,0 @@
|
|||||||
"""Various Utility Functions
|
|
||||||
|
|
||||||
- Tools for handling bson.ObjectId
|
|
||||||
|
|
||||||
The help IDs live as ObjectId in MongoDB and str in Langchain and JSON.
|
|
||||||
|
|
||||||
|
|
||||||
- Tools for the Maximal Marginal Relevance (MMR) reranking
|
|
||||||
|
|
||||||
These are duplicated from langchain_community to avoid cross-dependencies.
|
|
||||||
|
|
||||||
Functions "maximal_marginal_relevance" and "cosine_similarity"
|
|
||||||
are duplicated in this utility respectively from modules:
|
|
||||||
- "libs/community/langchain_community/vectorstores/utils.py"
|
|
||||||
- "libs/community/langchain_community/utils/math.py"
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from datetime import date, datetime
|
|
||||||
from typing import Any, Dict, List, Union
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
Matrix = Union[List[List[float]], List[np.ndarray], np.ndarray]
|
|
||||||
|
|
||||||
|
|
||||||
def cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray:
|
|
||||||
"""Row-wise cosine similarity between two equal-width matrices."""
|
|
||||||
if len(X) == 0 or len(Y) == 0:
|
|
||||||
return np.array([])
|
|
||||||
|
|
||||||
X = np.array(X)
|
|
||||||
Y = np.array(Y)
|
|
||||||
if X.shape[1] != Y.shape[1]:
|
|
||||||
raise ValueError(
|
|
||||||
f"Number of columns in X and Y must be the same. X has shape {X.shape} "
|
|
||||||
f"and Y has shape {Y.shape}."
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
import simsimd as simd
|
|
||||||
|
|
||||||
X = np.array(X, dtype=np.float32)
|
|
||||||
Y = np.array(Y, dtype=np.float32)
|
|
||||||
Z = 1 - np.array(simd.cdist(X, Y, metric="cosine"))
|
|
||||||
return Z
|
|
||||||
except ImportError:
|
|
||||||
logger.debug(
|
|
||||||
"Unable to import simsimd, defaulting to NumPy implementation. If you want "
|
|
||||||
"to use simsimd please install with `pip install simsimd`."
|
|
||||||
)
|
|
||||||
X_norm = np.linalg.norm(X, axis=1)
|
|
||||||
Y_norm = np.linalg.norm(Y, axis=1)
|
|
||||||
# Ignore divide by zero errors run time warnings as those are handled below.
|
|
||||||
with np.errstate(divide="ignore", invalid="ignore"):
|
|
||||||
similarity = np.dot(X, Y.T) / np.outer(X_norm, Y_norm)
|
|
||||||
similarity[np.isnan(similarity) | np.isinf(similarity)] = 0.0
|
|
||||||
return similarity


def maximal_marginal_relevance(
    query_embedding: np.ndarray,
    embedding_list: list,
    lambda_mult: float = 0.5,
    k: int = 4,
) -> List[int]:
    """Compute Maximal Marginal Relevance (MMR).

    MMR is a technique used to select documents that are both relevant to the query
    and diverse among themselves. This function returns the indices
    of the top-k embeddings that maximize the marginal relevance.

    Args:
        query_embedding (np.ndarray): The embedding vector of the query.
        embedding_list (list of np.ndarray): A list containing the embedding vectors
            of the candidate documents.
        lambda_mult (float, optional): The trade-off parameter between
            relevance and diversity. Defaults to 0.5.
        k (int, optional): The number of embeddings to select. Defaults to 4.

    Returns:
        list of int: The indices of the embeddings that maximize the marginal relevance.

    Notes:
        The Maximal Marginal Relevance (MMR) is computed using the following formula:

        MMR = argmax_{D_i ∈ R \\ S} [λ * Sim(D_i, Q) - (1 - λ) * max_{D_j ∈ S} Sim(D_i, D_j)]

        where:
        - R is the set of candidate documents,
        - S is the set of selected documents,
        - Q is the query embedding,
        - Sim(D_i, Q) is the similarity between document D_i and the query,
        - Sim(D_i, D_j) is the similarity between documents D_i and D_j,
        - λ is the trade-off parameter.
    """

    if min(k, len(embedding_list)) <= 0:
        return []
    if query_embedding.ndim == 1:
        query_embedding = np.expand_dims(query_embedding, axis=0)
    similarity_to_query = cosine_similarity(query_embedding, embedding_list)[0]
    most_similar = int(np.argmax(similarity_to_query))
    idxs = [most_similar]
    selected = np.array([embedding_list[most_similar]])
    while len(idxs) < min(k, len(embedding_list)):
        best_score = -np.inf
        idx_to_add = -1
        similarity_to_selected = cosine_similarity(embedding_list, selected)
        for i, query_score in enumerate(similarity_to_query):
            if i in idxs:
                continue
            redundant_score = max(similarity_to_selected[i])
            equation_score = (
                lambda_mult * query_score - (1 - lambda_mult) * redundant_score
            )
            if equation_score > best_score:
                best_score = equation_score
                idx_to_add = i
        idxs.append(idx_to_add)
        selected = np.append(selected, [embedding_list[idx_to_add]], axis=0)
    return idxs
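

# --- Illustrative usage (editor's sketch, not part of the original module) ---
# Toy demonstration of maximal_marginal_relevance: with lambda_mult=0.3 the
# diversity term dominates, so the second pick is the dissimilar vector
# rather than the near-duplicate of the first pick.
if __name__ == "__main__":
    query = np.array([1.0, 0.0])
    candidates = [
        np.array([1.0, 0.0]),    # identical to the query
        np.array([0.95, 0.05]),  # near-duplicate of candidate 0
        np.array([0.6, 0.8]),    # points in a different direction
    ]
    print(maximal_marginal_relevance(query, candidates, lambda_mult=0.3, k=2))
    # Expected: [0, 2]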


def str_to_oid(str_repr: str) -> Any | str:
    """Attempt to cast string representation of id to MongoDB's internal BSON ObjectId.

    To be consistent with ObjectId, input must be a 24 character hex string.
    If it is not, MongoDB will happily use the string in the main _id index.
    Importantly, the str representation that comes out of MongoDB will have this form.

    Args:
        str_repr: id as string.

    Returns:
        ObjectId
    """
    from bson import ObjectId
    from bson.errors import InvalidId

    try:
        return ObjectId(str_repr)
    except InvalidId:
        logger.debug(
            "ObjectIds must be 12-character byte or 24-character hex strings. "
            "Examples: b'heres12bytes', '6f6e6568656c6c6f68656768'"
        )
        return str_repr


def oid_to_str(oid: Any) -> str:
    """Convert MongoDB's internal BSON ObjectId into a simple str for compatibility.

    Instructive helper to show where data is coming out of MongoDB.

    Args:
        oid: bson.ObjectId

    Returns:
        24 character hex string.
    """
    return str(oid)


def make_serializable(
    obj: Dict[str, Any],
) -> None:
    """Recursively cast values in a dict to a form that json.dumps can handle."""

    from bson import ObjectId

    for k, v in obj.items():
        if isinstance(v, dict):
            make_serializable(v)
        elif isinstance(v, list) and v and isinstance(v[0], (ObjectId, date, datetime)):
            obj[k] = [oid_to_str(item) for item in v]
        elif isinstance(v, ObjectId):
            obj[k] = oid_to_str(v)
        elif isinstance(v, (datetime, date)):
            obj[k] = v.isoformat()
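

# --- Illustrative usage (editor's sketch, not part of the original module) ---
# Round-trip of the id helpers plus make_serializable, assuming pymongo
# (which provides bson) is installed.
if __name__ == "__main__":
    import json

    from bson import ObjectId

    oid = ObjectId()
    assert str_to_oid(oid_to_str(oid)) == oid  # 24-char hex round-trips
    assert str_to_oid("not-a-valid-id") == "not-a-valid-id"  # falls back to str

    record = {"_id": oid, "created": datetime(2024, 1, 1), "nested": {"ref": oid}}
    make_serializable(record)
    print(json.dumps(record))  # now JSON-safe: ObjectId -> hex str, datetime -> ISO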
@@ -1,796 +0,0 @@
from __future__ import annotations

import logging
from importlib.metadata import version
from typing import (
    Any,
    Callable,
    Dict,
    Generator,
    Iterable,
    List,
    Optional,
    Tuple,
    TypeVar,
    Union,
)

import numpy as np
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.runnables.config import run_in_executor
from langchain_core.vectorstores import VectorStore
from pymongo import MongoClient
from pymongo.collection import Collection
from pymongo.driver_info import DriverInfo
from pymongo.errors import CollectionInvalid

from langchain_mongodb.index import (
    create_vector_search_index,
    update_vector_search_index,
)
from langchain_mongodb.pipelines import vector_search_stage
from langchain_mongodb.utils import (
    make_serializable,
    maximal_marginal_relevance,
    oid_to_str,
    str_to_oid,
)

VST = TypeVar("VST", bound=VectorStore)

logger = logging.getLogger(__name__)

DEFAULT_INSERT_BATCH_SIZE = 100_000


class MongoDBAtlasVectorSearch(VectorStore):
    """MongoDB Atlas vector store integration.

    MongoDBAtlasVectorSearch performs data operations on
    text, embeddings and arbitrary data. In addition to CRUD operations,
    the VectorStore provides Vector Search
    based on similarity of embedding vectors following the
    Hierarchical Navigable Small Worlds (HNSW) algorithm.

    It supports a number of search types to ascertain scores:
    "similarity" (default), "MMR", and "similarity_score_threshold".
    These are described in the search_type argument to as_retriever,
    which provides the Runnable.invoke(query) API, allowing
    MongoDBAtlasVectorSearch to be used within a chain.

    Setup:
        * Set up a MongoDB Atlas cluster. The free tier M0 will allow you to start.
          Search Indexes are only available on Atlas, the fully managed cloud service,
          not the self-managed MongoDB.
          Follow [this guide](https://www.mongodb.com/basics/mongodb-atlas-tutorial)

        * Create a Collection and a Vector Search Index. The procedure is described
          [here](https://www.mongodb.com/docs/atlas/atlas-vector-search/create-index/#procedure).

        * Install ``langchain-mongodb``


        .. code-block:: bash

            pip install -qU langchain-mongodb pymongo


        .. code-block:: python

            import getpass
            MONGODB_ATLAS_CLUSTER_URI = getpass.getpass("MongoDB Atlas Cluster URI:")

    Key init args — indexing params:
        embedding: Embeddings
            Embedding function to use.

    Key init args — client params:
        collection: Collection
            MongoDB collection to use.
        index_name: str
            Name of the Atlas Search index.

    Instantiate:
        .. code-block:: python

            from pymongo import MongoClient
            from langchain_mongodb.vectorstores import MongoDBAtlasVectorSearch
            from langchain_openai import OpenAIEmbeddings

            # initialize MongoDB python client
            client = MongoClient(MONGODB_ATLAS_CLUSTER_URI)

            DB_NAME = "langchain_test_db"
            COLLECTION_NAME = "langchain_test_vectorstores"
            ATLAS_VECTOR_SEARCH_INDEX_NAME = "langchain-test-index-vectorstores"

            MONGODB_COLLECTION = client[DB_NAME][COLLECTION_NAME]

            vector_store = MongoDBAtlasVectorSearch(
                collection=MONGODB_COLLECTION,
                embedding=OpenAIEmbeddings(),
                index_name=ATLAS_VECTOR_SEARCH_INDEX_NAME,
                relevance_score_fn="cosine",
            )

    Add Documents:
        .. code-block:: python

            from langchain_core.documents import Document

            document_1 = Document(page_content="foo", metadata={"baz": "bar"})
            document_2 = Document(page_content="thud", metadata={"bar": "baz"})
            document_3 = Document(page_content="i will be deleted :(")

            documents = [document_1, document_2, document_3]
            ids = ["1", "2", "3"]
            vector_store.add_documents(documents=documents, ids=ids)

    Delete Documents:
        .. code-block:: python

            vector_store.delete(ids=["3"])

    Search:
        .. code-block:: python

            results = vector_store.similarity_search(query="thud", k=1)
            for doc in results:
                print(f"* {doc.page_content} [{doc.metadata}]")

        .. code-block:: python

            * thud [{'_id': '2', 'bar': 'baz'}]

    Search with filter:
        .. code-block:: python

            results = vector_store.similarity_search(
                query="thud", k=1, post_filter_pipeline=[{"$match": {"bar": "baz"}}]
            )
            for doc in results:
                print(f"* {doc.page_content} [{doc.metadata}]")

        .. code-block:: python

            * thud [{'_id': '2', 'bar': 'baz'}]

    Search with score:
        .. code-block:: python

            results = vector_store.similarity_search_with_score(query="qux", k=1)
            for doc, score in results:
                print(f"* [SIM={score:3f}] {doc.page_content} [{doc.metadata}]")

        .. code-block:: python

            * [SIM=0.916096] foo [{'_id': '1', 'baz': 'bar'}]

    Async:
        .. code-block:: python

            # add documents
            # await vector_store.aadd_documents(documents=documents, ids=ids)

            # delete documents
            # await vector_store.adelete(ids=["3"])

            # search
            # results = await vector_store.asimilarity_search(query="thud", k=1)

            # search with score
            results = await vector_store.asimilarity_search_with_score(query="qux", k=1)
            for doc, score in results:
                print(f"* [SIM={score:3f}] {doc.page_content} [{doc.metadata}]")

        .. code-block:: python

            * [SIM=0.916096] foo [{'_id': '1', 'baz': 'bar'}]

    Use as Retriever:
        .. code-block:: python

            retriever = vector_store.as_retriever(
                search_type="mmr",
                search_kwargs={"k": 1, "fetch_k": 2, "lambda_mult": 0.5},
            )
            retriever.invoke("thud")

        .. code-block:: python

            [Document(metadata={'_id': '2', 'embedding': [-0.01850726455450058, -0.0014740974875167012, -0.009762819856405258, ...], 'bar': 'baz'}, page_content='thud')]

    """  # noqa: E501

    def __init__(
        self,
        collection: Collection[Dict[str, Any]],
        embedding: Embeddings,
        index_name: str = "vector_index",
        text_key: str = "text",
        embedding_key: str = "embedding",
        relevance_score_fn: str = "cosine",
        **kwargs: Any,
    ):
        """
        Args:
            collection: MongoDB collection to add the texts to
            embedding: Text embedding model to use
            text_key: MongoDB field that will contain the text for each document
            index_name: Name of the existing Atlas Vector Search index
            embedding_key: Field that will contain the embedding for each document
            relevance_score_fn: The similarity score used for the index
                Currently supported: 'euclidean', 'cosine', and 'dotProduct'
        """
        self._collection = collection
        self._embedding = embedding
        self._index_name = index_name
        self._text_key = text_key
        self._embedding_key = embedding_key
        self._relevance_score_fn = relevance_score_fn

    @property
    def embeddings(self) -> Embeddings:
        return self._embedding

    def _select_relevance_score_fn(self) -> Callable[[float], float]:
        scoring: dict[str, Callable] = {
            "euclidean": self._euclidean_relevance_score_fn,
            "dotProduct": self._max_inner_product_relevance_score_fn,
            "cosine": self._cosine_relevance_score_fn,
        }
        if self._relevance_score_fn in scoring:
            return scoring[self._relevance_score_fn]
        else:
            raise NotImplementedError(
                f"No relevance score function for {self._relevance_score_fn}"
            )

    @classmethod
    def from_connection_string(
        cls,
        connection_string: str,
        namespace: str,
        embedding: Embeddings,
        **kwargs: Any,
    ) -> MongoDBAtlasVectorSearch:
        """Construct a `MongoDB Atlas Vector Search` vector store
        from a MongoDB connection URI.

        Args:
            connection_string: A valid MongoDB connection URI.
            namespace: A valid MongoDB namespace (database and collection).
            embedding: The text embedding model to use for the vector store.

        Returns:
            A new MongoDBAtlasVectorSearch instance.

        """
        client: MongoClient = MongoClient(
            connection_string,
            driver=DriverInfo(name="Langchain", version=version("langchain")),
        )
        db_name, collection_name = namespace.split(".")
        collection = client[db_name][collection_name]
        return cls(collection, embedding, **kwargs)

    def add_texts(
        self,
        texts: Iterable[str],
        metadatas: Optional[List[Dict[str, Any]]] = None,
        ids: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> List[str]:
        """Add texts, create embeddings, and add to the Collection and index.

        Important notes on ids:
            - If _id or id is a key in the metadatas dicts, one must
              pop them and provide them as a separate list.
            - They must be unique.
            - If they are not provided, the VectorStore will create unique ones,
              stored as bson.ObjectIds internally, and strings in LangChain.
              These will appear in Document.metadata with key '_id'.

        Args:
            texts: Iterable of strings to add to the vectorstore.
            metadatas: Optional list of metadatas associated with the texts.
            ids: Optional list of unique ids that will be used as index in VectorStore.
                See note on ids.

        Returns:
            List of ids added to the vectorstore.
        """

        # Check to see if metadata includes ids
        if metadatas is not None and (
            metadatas[0].get("_id") or metadatas[0].get("id")
        ):
            logger.warning(
                "_id or id key found in metadata. "
                "Please pop from each dict and input as separate list. "
                "Retrieving methods will include the same id as '_id' in metadata."
            )

        texts_batch = texts
        _metadatas: Union[List, Generator] = metadatas or ({} for _ in texts)
        metadatas_batch = _metadatas

        result_ids = []
        batch_size = kwargs.get("batch_size", DEFAULT_INSERT_BATCH_SIZE)
        if batch_size:
            texts_batch = []
            metadatas_batch = []
            size = 0
            i = 0
            for j, (text, metadata) in enumerate(zip(texts, _metadatas)):
                size += len(text) + len(metadata)
                texts_batch.append(text)
                metadatas_batch.append(metadata)
                if (j + 1) % batch_size == 0 or size >= 47_000_000:
                    if ids:
                        batch_res = self.bulk_embed_and_insert_texts(
                            texts_batch, metadatas_batch, ids[i : j + 1]
                        )
                    else:
                        batch_res = self.bulk_embed_and_insert_texts(
                            texts_batch, metadatas_batch
                        )
                    result_ids.extend(batch_res)
                    texts_batch = []
                    metadatas_batch = []
                    size = 0
                    i = j + 1
        if texts_batch:
            if ids:
                batch_res = self.bulk_embed_and_insert_texts(
                    texts_batch, metadatas_batch, ids[i : j + 1]
                )
            else:
                batch_res = self.bulk_embed_and_insert_texts(
                    texts_batch, metadatas_batch
                )
            result_ids.extend(batch_res)
        return result_ids

    def bulk_embed_and_insert_texts(
        self,
        texts: Union[List[str], Iterable[str]],
        metadatas: Union[List[dict], Generator[dict, Any, Any]],
        ids: Optional[List[str]] = None,
    ) -> List[str]:
        """Bulk insert single batch of texts, embeddings, and optionally ids.

        See add_texts for additional details.
        """
        if not texts:
            return []
        # Compute embedding vectors
        embeddings = self._embedding.embed_documents(texts)  # type: ignore
        if ids:
            to_insert = [
                {
                    "_id": str_to_oid(i),
                    self._text_key: t,
                    self._embedding_key: embedding,
                    **m,
                }
                for i, t, m, embedding in zip(ids, texts, metadatas, embeddings)
            ]
        else:
            to_insert = [
                {self._text_key: t, self._embedding_key: embedding, **m}
                for t, m, embedding in zip(texts, metadatas, embeddings)
            ]
        # insert the documents in MongoDB Atlas
        insert_result = self._collection.insert_many(to_insert)  # type: ignore
        return [oid_to_str(_id) for _id in insert_result.inserted_ids]

    def add_documents(
        self,
        documents: List[Document],
        ids: Optional[List[str]] = None,
        batch_size: int = DEFAULT_INSERT_BATCH_SIZE,
        **kwargs: Any,
    ) -> List[str]:
        """Add documents to the vectorstore.

        Args:
            documents: Documents to add to the vectorstore.
            ids: Optional list of unique ids that will be used as index in VectorStore.
                See note on ids in add_texts.
            batch_size: Number of documents to insert at a time.
                Tuning this may help with performance and sidestep MongoDB limits.

        Returns:
            List of IDs of the added texts.
        """
        n_docs = len(documents)
        if ids:
            assert len(ids) == n_docs, "Number of ids must equal number of documents."
        result_ids = []
        start = 0
        for end in range(batch_size, n_docs + batch_size, batch_size):
            texts, metadatas = zip(
                *[(doc.page_content, doc.metadata) for doc in documents[start:end]]
            )
            if ids:
                result_ids.extend(
                    self.bulk_embed_and_insert_texts(
                        texts=texts, metadatas=metadatas, ids=ids[start:end]
                    )
                )
            else:
                result_ids.extend(
                    self.bulk_embed_and_insert_texts(texts=texts, metadatas=metadatas)
                )
            start = end
        return result_ids

    def similarity_search_with_score(
        self,
        query: str,
        k: int = 4,
        pre_filter: Optional[Dict[str, Any]] = None,
        post_filter_pipeline: Optional[List[Dict]] = None,
        oversampling_factor: int = 10,
        include_embeddings: bool = False,
        **kwargs: Any,
    ) -> List[Tuple[Document, float]]:  # noqa: E501
        """Return MongoDB documents most similar to the given query and their scores.

        Atlas Vector Search eliminates the need to run a separate
        search system alongside your database.

        Args:
            query: Input text of semantic query
            k: Number of documents to return. Also known as top_k.
            pre_filter: List of MQL match expressions comparing an indexed field
            post_filter_pipeline: (Optional) Arbitrary pipeline of MongoDB
                aggregation stages applied after the search is complete.
            oversampling_factor: This times k is the number of candidates chosen
                at each step in the HNSW Vector Search
            include_embeddings: If True, the embedding vector of each result
                will be included in metadata.
            kwargs: Additional arguments are specific to the search_type

        Returns:
            List of documents most similar to the query and their scores.
        """
        embedding = self._embedding.embed_query(query)
        docs = self._similarity_search_with_score(
            embedding,
            k=k,
            pre_filter=pre_filter,
            post_filter_pipeline=post_filter_pipeline,
            oversampling_factor=oversampling_factor,
            include_embeddings=include_embeddings,
            **kwargs,
        )
        return docs

    def similarity_search(
        self,
        query: str,
        k: int = 4,
        pre_filter: Optional[Dict[str, Any]] = None,
        post_filter_pipeline: Optional[List[Dict]] = None,
        oversampling_factor: int = 10,
        include_scores: bool = False,
        include_embeddings: bool = False,
        **kwargs: Any,
    ) -> List[Document]:  # noqa: E501
        """Return MongoDB documents most similar to the given query.

        Atlas Vector Search eliminates the need to run a separate
        search system alongside your database.

        Args:
            query: Input text of semantic query
            k: (Optional) number of documents to return. Defaults to 4.
            pre_filter: List of MQL match expressions comparing an indexed field
            post_filter_pipeline: (Optional) Pipeline of MongoDB aggregation stages
                to filter/process results after $vectorSearch.
            oversampling_factor: Multiple of k used when generating the number of
                candidates at each step in the HNSW Vector Search.
            include_scores: If True, the query score of each result
                will be included in metadata.
            include_embeddings: If True, the embedding vector of each result
                will be included in metadata.
            kwargs: Additional arguments are specific to the search_type

        Returns:
            List of documents most similar to the query.
        """
        docs_and_scores = self.similarity_search_with_score(
            query,
            k=k,
            pre_filter=pre_filter,
            post_filter_pipeline=post_filter_pipeline,
            oversampling_factor=oversampling_factor,
            include_embeddings=include_embeddings,
            **kwargs,
        )

        if include_scores:
            for doc, score in docs_and_scores:
                doc.metadata["score"] = score
        return [doc for doc, _ in docs_and_scores]

    def max_marginal_relevance_search(
        self,
        query: str,
        k: int = 4,
        fetch_k: int = 20,
        lambda_mult: float = 0.5,
        pre_filter: Optional[Dict[str, Any]] = None,
        post_filter_pipeline: Optional[List[Dict]] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Return documents selected using the maximal marginal relevance.

        Maximal marginal relevance optimizes for similarity to query AND diversity
        among selected documents.

        Args:
            query: Text to look up documents similar to.
            k: (Optional) number of documents to return. Defaults to 4.
            fetch_k: (Optional) number of documents to fetch before passing to MMR
                algorithm. Defaults to 20.
            lambda_mult: Number between 0 and 1 that determines the degree
                of diversity among the results with 0 corresponding
                to maximum diversity and 1 to minimum diversity. Defaults to 0.5.
            pre_filter: List of MQL match expressions comparing an indexed field
            post_filter_pipeline: (Optional) pipeline of MongoDB aggregation stages
                following the $vectorSearch stage.
        Returns:
            List of documents selected by maximal marginal relevance.
        """
        return self.max_marginal_relevance_search_by_vector(
            embedding=self._embedding.embed_query(query),
            k=k,
            fetch_k=fetch_k,
            lambda_mult=lambda_mult,
            pre_filter=pre_filter,
            post_filter_pipeline=post_filter_pipeline,
            **kwargs,
        )

    @classmethod
    def from_texts(
        cls,
        texts: List[str],
        embedding: Embeddings,
        metadatas: Optional[List[Dict]] = None,
        collection: Optional[Collection] = None,
        ids: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> MongoDBAtlasVectorSearch:
        """Construct a `MongoDB Atlas Vector Search` vector store from raw documents.

        This is a user-friendly interface that:
            1. Embeds documents.
            2. Adds the documents to a provided MongoDB Atlas Vector Search index
               (Lucene)

        This is intended to be a quick way to get started.

        See `MongoDBAtlasVectorSearch` for kwargs and further description.


        Example:
            .. code-block:: python

                from pymongo import MongoClient

                from langchain_mongodb import MongoDBAtlasVectorSearch
                from langchain_openai import OpenAIEmbeddings

                mongo_client = MongoClient("<YOUR-CONNECTION-STRING>")
                collection = mongo_client["<db_name>"]["<collection_name>"]
                embeddings = OpenAIEmbeddings()
                vectorstore = MongoDBAtlasVectorSearch.from_texts(
                    texts,
                    embeddings,
                    metadatas=metadatas,
                    collection=collection
                )
        """
        if collection is None:
            raise ValueError("Must provide 'collection' named parameter.")
        vectorstore = cls(collection, embedding, **kwargs)
        vectorstore.add_texts(texts=texts, metadatas=metadatas, ids=ids, **kwargs)
        return vectorstore

    def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> Optional[bool]:
        """Delete documents from VectorStore by ids.

        Args:
            ids: List of ids to delete.
            **kwargs: Other keyword arguments passed to Collection.delete_many()

        Returns:
            Optional[bool]: True if deletion is successful,
                False otherwise, None if not implemented.
        """
        filter = {}
        if ids:
            oids = [str_to_oid(i) for i in ids]
            filter = {"_id": {"$in": oids}}
        return self._collection.delete_many(filter=filter, **kwargs).acknowledged

    async def adelete(
        self, ids: Optional[List[str]] = None, **kwargs: Any
    ) -> Optional[bool]:
        """Delete by vector ID or other criteria.

        Args:
            ids: List of ids to delete.
            **kwargs: Other keyword arguments that subclasses might use.

        Returns:
            Optional[bool]: True if deletion is successful,
                False otherwise, None if not implemented.
        """
        return await run_in_executor(None, self.delete, ids=ids, **kwargs)

    def max_marginal_relevance_search_by_vector(
        self,
        embedding: List[float],
        k: int = 4,
        fetch_k: int = 20,
        lambda_mult: float = 0.5,
        pre_filter: Optional[Dict[str, Any]] = None,
        post_filter_pipeline: Optional[List[Dict]] = None,
        oversampling_factor: int = 10,
        **kwargs: Any,
    ) -> List[Document]:  # type: ignore
        """Return docs selected using the maximal marginal relevance.

        Maximal marginal relevance optimizes for similarity to query AND diversity
        among selected documents.

        Args:
            embedding: Embedding to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            fetch_k: Number of Documents to fetch to pass to MMR algorithm.
            lambda_mult: Number between 0 and 1 that determines the degree
                of diversity among the results with 0 corresponding
                to maximum diversity and 1 to minimum diversity.
                Defaults to 0.5.
            pre_filter: (Optional) dictionary of arguments to filter document fields on.
            post_filter_pipeline: (Optional) pipeline of MongoDB aggregation stages
                following the $vectorSearch stage.
            oversampling_factor: Multiple of k used when generating the number
                of candidates in HNSW Vector Search.
            kwargs: Additional arguments are specific to the search_type

        Returns:
            List of Documents selected by maximal marginal relevance.
        """
        docs = self._similarity_search_with_score(
            embedding,
            k=fetch_k,
            pre_filter=pre_filter,
            post_filter_pipeline=post_filter_pipeline,
            include_embeddings=True,
            oversampling_factor=oversampling_factor,
            **kwargs,
        )
        mmr_doc_indexes = maximal_marginal_relevance(
            np.array(embedding),
            [doc.metadata[self._embedding_key] for doc, _ in docs],
            k=k,
            lambda_mult=lambda_mult,
        )
        mmr_docs = [docs[i][0] for i in mmr_doc_indexes]
        return mmr_docs

    async def amax_marginal_relevance_search_by_vector(
        self,
        embedding: List[float],
        k: int = 4,
        fetch_k: int = 20,
        lambda_mult: float = 0.5,
        pre_filter: Optional[Dict[str, Any]] = None,
        post_filter_pipeline: Optional[List[Dict]] = None,
        oversampling_factor: int = 10,
        **kwargs: Any,
    ) -> List[Document]:
        """Return docs selected using the maximal marginal relevance."""
        return await run_in_executor(
            None,
            self.max_marginal_relevance_search_by_vector,  # type: ignore[arg-type]
            embedding,
            k=k,
            fetch_k=fetch_k,
            lambda_mult=lambda_mult,
            pre_filter=pre_filter,
            post_filter_pipeline=post_filter_pipeline,
            oversampling_factor=oversampling_factor,
            **kwargs,
        )

    def _similarity_search_with_score(
        self,
        query_vector: List[float],
        k: int = 4,
        pre_filter: Optional[Dict[str, Any]] = None,
        post_filter_pipeline: Optional[List[Dict]] = None,
        oversampling_factor: int = 10,
        include_embeddings: bool = False,
        **kwargs: Any,
    ) -> List[Tuple[Document, float]]:
        """Core search routine. See external methods for details."""

        # Atlas Vector Search, potentially with filter
        pipeline = [
            vector_search_stage(
                query_vector,
                self._embedding_key,
                self._index_name,
                k,
                pre_filter,
                oversampling_factor,
                **kwargs,
            ),
            {"$set": {"score": {"$meta": "vectorSearchScore"}}},
        ]

        # Remove embeddings unless requested.
        if not include_embeddings:
            pipeline.append({"$project": {self._embedding_key: 0}})
        # Post-processing
        if post_filter_pipeline is not None:
            pipeline.extend(post_filter_pipeline)

        # Execution
        cursor = self._collection.aggregate(pipeline)  # type: ignore[arg-type]
        docs = []

        # Format
        for res in cursor:
            text = res.pop(self._text_key)
            score = res.pop("score")
            make_serializable(res)
            docs.append((Document(page_content=text, metadata=res), score))
        return docs

    def create_vector_search_index(
        self,
        dimensions: int,
        filters: Optional[List[str]] = None,
        update: bool = False,
    ) -> None:
        """Creates a MongoDB Atlas vectorSearch index for the VectorStore.

        **Note:** This method may fail as it requires a MongoDB Atlas with these
        `pre-requisites <https://www.mongodb.com/docs/atlas/atlas-vector-search/create-index/#prerequisites>`_.
        Currently, vector and full-text search index operations need to be
        performed manually on the Atlas UI for shared M0 clusters.

        Args:
            dimensions (int): Number of dimensions in embedding
            filters (Optional[List[str]], optional): additional filters
                for index definition. Defaults to None.
            update (bool, optional): Updates existing vectorSearch index.
                Defaults to False.
        """
        try:
            self._collection.database.create_collection(self._collection.name)
        except CollectionInvalid:
            pass

        index_operation = (
            update_vector_search_index if update else create_vector_search_index
        )

        index_operation(
            collection=self._collection,
            index_name=self._index_name,
            dimensions=dimensions,
            path=self._embedding_key,
            similarity=self._relevance_score_fn,
            filters=filters or [],
        )  # type: ignore [operator]
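

# --- Illustrative usage (editor's sketch, not part of the original module) ---
# End-to-end flow under assumptions: a reachable Atlas cluster URI in the
# MONGODB_ATLAS_URI environment variable, an OpenAI key for OpenAIEmbeddings,
# and an existing Atlas Vector Search index named "vector_index".
if __name__ == "__main__":
    import os

    from langchain_openai import OpenAIEmbeddings

    store = MongoDBAtlasVectorSearch.from_connection_string(
        connection_string=os.environ["MONGODB_ATLAS_URI"],
        namespace="langchain_test_db.langchain_test_vectorstores",
        embedding=OpenAIEmbeddings(model="text-embedding-3-small"),
        index_name="vector_index",
    )
    ids = store.add_texts(["The quick brown fox", "Lorem ipsum"])
    for doc, score in store.similarity_search_with_score("fast animal", k=1):
        print(score, doc.page_content)
    store.delete(ids)  # clean up the demo documents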
libs/partners/mongodb/poetry.lock (generated, 2158 lines): file diff suppressed because it is too large.
@@ -1,103 +0,0 @@
[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"

[tool.poetry]
name = "langchain-mongodb"
version = "0.2.0"
description = "An integration package connecting MongoDB and LangChain"
authors = []
readme = "README.md"
repository = "https://github.com/langchain-ai/langchain"
license = "MIT"

[tool.mypy]
disallow_untyped_defs = "True"

[tool.poetry.urls]
"Source Code" = "https://github.com/langchain-ai/langchain/tree/master/libs/partners/mongodb"
"Release Notes" = "https://github.com/langchain-ai/langchain/releases?q=tag%3A%22langchain-mongodb%3D%3D0%22&expanded=true"

[tool.poetry.dependencies]
python = ">=3.9,<4.0"
pymongo = ">=4.6.1,<5.0"
langchain-core = "^0.3"

[[tool.poetry.dependencies.numpy]]
version = "^1"
python = "<3.12"

[[tool.poetry.dependencies.numpy]]
version = "^1.26.0"
python = ">=3.12"

[tool.ruff.lint]
select = ["E", "F", "I"]

[tool.coverage.run]
omit = ["tests/*"]

[tool.pytest.ini_options]
addopts = "--snapshot-warn-unused --strict-markers --strict-config --durations=5"
markers = [
    "requires: mark tests as requiring a specific library",
    "compile: mark placeholder test used to compile integration tests without running them",
]
asyncio_mode = "auto"

[tool.poetry.group.test]
optional = true

[tool.poetry.group.codespell]
optional = true

[tool.poetry.group.test_integration]
optional = true

[tool.poetry.group.lint]
optional = true

[tool.poetry.group.dev]
optional = true

[tool.poetry.group.test.dependencies]
pytest = "^7.3.0"
freezegun = "^1.2.2"
pytest-mock = "^3.10.0"
syrupy = "^4.0.2"
pytest-watcher = "^0.3.4"
pytest-asyncio = "^0.21.1"

[tool.poetry.group.codespell.dependencies]
codespell = "^2.2.0"

[tool.poetry.group.test_integration.dependencies.langchain-openai]
path = "../openai"
develop = true

[tool.poetry.group.lint.dependencies]
ruff = "^0.5"

[tool.poetry.group.typing.dependencies]
mypy = "^1.10"
simsimd = "^5.0.0"

[tool.poetry.group.test.dependencies.langchain]
path = "../../langchain"
develop = true

[tool.poetry.group.test.dependencies.langchain-core]
path = "../../core"
develop = true

[tool.poetry.group.test.dependencies.langchain-text-splitters]
path = "../../text-splitters"
develop = true

[tool.poetry.group.dev.dependencies.langchain-core]
path = "../../core"
develop = true

[tool.poetry.group.typing.dependencies.langchain-core]
path = "../../core"
develop = true
@@ -1,17 +0,0 @@
import sys
import traceback
from importlib.machinery import SourceFileLoader

if __name__ == "__main__":
    files = sys.argv[1:]
    has_failure = False
    for file in files:
        try:
            SourceFileLoader("x", file).load_module()
        except Exception:
            has_failure = True
            print(file)
            traceback.print_exc()
            print()

    sys.exit(1 if has_failure else 0)
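
# Editor's note: example invocation (hypothetical paths, not part of the original):
#   python scripts/check_imports.py langchain_mongodb/utils.py langchain_mongodb/vectorstores.py
# The script prints each file that fails to import (with a traceback) and
# exits with status 1 if any file failed.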
@@ -1,17 +0,0 @@
#!/bin/bash

set -eu

# Initialize a variable to keep track of errors
errors=0

# make sure not importing from langchain or langchain_experimental
git --no-pager grep '^from langchain\.' . && errors=$((errors+1))
git --no-pager grep '^from langchain_experimental\.' . && errors=$((errors+1))

# Decide on an exit status based on the errors
if [ "$errors" -gt 0 ]; then
    exit 1
else
    exit 0
fi
@@ -1,161 +0,0 @@
import os
import uuid
from typing import Any, List, Union

import pytest  # type: ignore[import-not-found]
from langchain_core.caches import BaseCache
from langchain_core.globals import get_llm_cache, set_llm_cache
from langchain_core.load.dump import dumps
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langchain_core.outputs import ChatGeneration, Generation, LLMResult

from langchain_mongodb.cache import MongoDBAtlasSemanticCache, MongoDBCache

from ..utils import ConsistentFakeEmbeddings, FakeChatModel, FakeLLM

CONN_STRING = os.environ.get("MONGODB_ATLAS_URI")
INDEX_NAME = "langchain-test-index-semantic-cache"
DATABASE = "langchain_test_db"
COLLECTION = "langchain_test_cache"


def random_string() -> str:
    return str(uuid.uuid4())


def llm_cache(cls: Any) -> BaseCache:
    set_llm_cache(
        cls(
            embedding=ConsistentFakeEmbeddings(dimensionality=1536),
            connection_string=CONN_STRING,
            collection_name=COLLECTION,
            database_name=DATABASE,
            index_name=INDEX_NAME,
            score_threshold=0.5,
            wait_until_ready=15.0,
        )
    )
    assert get_llm_cache()
    return get_llm_cache()


def _execute_test(
    prompt: Union[str, List[BaseMessage]],
    llm: Union[str, FakeLLM, FakeChatModel],
    response: List[Generation],
) -> None:
    # Fabricate an LLM String

    if not isinstance(llm, str):
        params = llm.dict()
        params["stop"] = None
        llm_string = str(sorted([(k, v) for k, v in params.items()]))
    else:
        llm_string = llm

    # If the prompt is a str then we should pass just the string
    dumped_prompt: str = prompt if isinstance(prompt, str) else dumps(prompt)

    # Update the cache
    get_llm_cache().update(dumped_prompt, llm_string, response)

    # Retrieve the cached result through 'generate' call
    output: Union[List[Generation], LLMResult, None]
    expected_output: Union[List[Generation], LLMResult]

    if isinstance(llm, str):
        output = get_llm_cache().lookup(dumped_prompt, llm)  # type: ignore
        expected_output = response
    else:
        output = llm.generate([prompt])  # type: ignore
        expected_output = LLMResult(
            generations=[response],
            llm_output={},
        )

    assert output == expected_output  # type: ignore


@pytest.mark.parametrize(
    "prompt, llm, response",
    [
        ("foo", "bar", [Generation(text="fizz")]),
        ("foo", FakeLLM(), [Generation(text="fizz")]),
        (
            [HumanMessage(content="foo")],
            FakeChatModel(),
            [ChatGeneration(message=AIMessage(content="foo"))],
        ),
    ],
    ids=[
        "plain_cache",
        "cache_with_llm",
        "cache_with_chat",
    ],
)
@pytest.mark.parametrize("cacher", [MongoDBCache, MongoDBAtlasSemanticCache])
@pytest.mark.parametrize("remove_score", [True, False])
def test_mongodb_cache(
    remove_score: bool,
    cacher: Union[MongoDBCache, MongoDBAtlasSemanticCache],
    prompt: Union[str, List[BaseMessage]],
    llm: Union[str, FakeLLM, FakeChatModel],
    response: List[Generation],
) -> None:
    llm_cache(cacher)
    if remove_score:
        get_llm_cache().score_threshold = None  # type: ignore
    try:
        _execute_test(prompt, llm, response)
    finally:
        get_llm_cache().clear()


@pytest.mark.parametrize(
    "prompts, generations",
    [
        # Single prompt, single generation
        ([random_string()], [[random_string()]]),
        # Single prompt, two generations
        ([random_string()], [[random_string(), random_string()]]),
        # Single prompt, three generations
        ([random_string()], [[random_string(), random_string(), random_string()]]),
        # Multiple prompts, multiple generations
        (
            [random_string(), random_string()],
            [[random_string()], [random_string(), random_string()]],
        ),
    ],
    ids=[
        "single_prompt_single_generation",
        "single_prompt_two_generations",
        "single_prompt_three_generations",
        "multiple_prompts_multiple_generations",
    ],
)
def test_mongodb_atlas_cache_matrix(
    prompts: List[str],
    generations: List[List[str]],
) -> None:
    llm_cache(MongoDBAtlasSemanticCache)
    llm = FakeLLM()

    # Fabricate an LLM String
    params = llm.dict()
    params["stop"] = None
    llm_string = str(sorted([(k, v) for k, v in params.items()]))

    llm_generations = [
        [
            Generation(text=generation, generation_info=params)
            for generation in prompt_i_generations
        ]
        for prompt_i_generations in generations
    ]

    for prompt_i, llm_generations_i in zip(prompts, llm_generations):
        _execute_test(prompt_i, llm_string, llm_generations_i)
    assert llm.generate(prompts) == LLMResult(
        generations=llm_generations, llm_output={}
    )
    get_llm_cache().clear()
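

# --- Illustrative usage (editor's sketch, not part of the original tests) ---
# How the cache under test is typically wired in application code, assuming a
# real embedding model instead of the test fakes used above.
if __name__ == "__main__":
    from langchain_openai import OpenAIEmbeddings

    set_llm_cache(
        MongoDBAtlasSemanticCache(
            embedding=OpenAIEmbeddings(model="text-embedding-3-small"),
            connection_string=os.environ["MONGODB_ATLAS_URI"],
            database_name="langchain_test_db",
            collection_name="llm_cache",
            index_name="langchain-test-index-semantic-cache",
            score_threshold=0.5,
        )
    )
    # Subsequent chat/LLM calls with semantically similar prompts are now
    # served from MongoDB instead of hitting the model provider again.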
@@ -1,144 +0,0 @@
"Demonstrates MongoDBAtlasVectorSearch.as_retriever() invoked in a chain" ""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import os
|
|
||||||
from time import sleep
|
|
||||||
|
|
||||||
import pytest # type: ignore[import-not-found]
|
|
||||||
from langchain_core.documents import Document
|
|
||||||
from langchain_core.output_parsers.string import StrOutputParser
|
|
||||||
from langchain_core.prompts.chat import ChatPromptTemplate
|
|
||||||
from langchain_core.runnables import RunnablePassthrough
|
|
||||||
from pymongo import MongoClient
|
|
||||||
from pymongo.collection import Collection
|
|
||||||
|
|
||||||
from langchain_mongodb import index
|
|
||||||
|
|
||||||
from ..utils import PatchedMongoDBAtlasVectorSearch
|
|
||||||
|
|
||||||
CONNECTION_STRING = os.environ.get("MONGODB_ATLAS_URI")
|
|
||||||
DB_NAME = "langchain_test_db"
|
|
||||||
COLLECTION_NAME = "langchain_test_chain_example"
|
|
||||||
INDEX_NAME = "vector_index"
|
|
||||||
DIMENSIONS = 1536
|
|
||||||
TIMEOUT = 60.0
|
|
||||||
INTERVAL = 0.5
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def collection() -> Collection:
|
|
||||||
"""A Collection with both a Vector and a Full-text Search Index"""
|
|
||||||
client: MongoClient = MongoClient(CONNECTION_STRING)
|
|
||||||
if COLLECTION_NAME not in client[DB_NAME].list_collection_names():
|
|
||||||
clxn = client[DB_NAME].create_collection(COLLECTION_NAME)
|
|
||||||
else:
|
|
||||||
clxn = client[DB_NAME][COLLECTION_NAME]
|
|
||||||
|
|
||||||
clxn.delete_many({})
|
|
||||||
|
|
||||||
if all([INDEX_NAME != ix["name"] for ix in clxn.list_search_indexes()]):
|
|
||||||
index.create_vector_search_index(
|
|
||||||
collection=clxn,
|
|
||||||
index_name=INDEX_NAME,
|
|
||||||
dimensions=DIMENSIONS,
|
|
||||||
path="embedding",
|
|
||||||
similarity="cosine",
|
|
||||||
filters=None,
|
|
||||||
wait_until_complete=TIMEOUT,
|
|
||||||
)
|
|
||||||
|
|
||||||
return clxn
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(
|
|
||||||
"OPENAI_API_KEY" not in os.environ, reason="Requires OpenAI for chat responses."
|
|
||||||
)
|
|
||||||
def test_chain(
|
|
||||||
collection: Collection,
|
|
||||||
) -> None:
|
|
||||||
"""Demonstrate usage of MongoDBAtlasVectorSearch in a realistic chain
|
|
||||||
|
|
||||||
Follows example in the docs: https://python.langchain.com/docs/how_to/hybrid/
|
|
||||||
|
|
||||||
Requires OpenAI_API_KEY for embedding and chat model.
|
|
||||||
Requires INDEX_NAME to have been set up on MONGODB_ATLAS_URI
|
|
||||||
"""
|
|
||||||
|
|
||||||
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
|
|
||||||
|
|
||||||
embedding_openai = OpenAIEmbeddings(
|
|
||||||
openai_api_key=os.environ["OPENAI_API_KEY"], # type: ignore # noqa
|
|
||||||
model="text-embedding-3-small",
|
|
||||||
)
|
|
||||||
|
|
||||||
vectorstore = PatchedMongoDBAtlasVectorSearch(
|
|
||||||
collection=collection,
|
|
||||||
embedding=embedding_openai,
|
|
||||||
index_name=INDEX_NAME,
|
|
||||||
text_key="page_content",
|
|
||||||
)
|
|
||||||
|
|
||||||
texts = [
|
|
||||||
"In 2023, I visited Paris",
|
|
||||||
"In 2022, I visited New York",
|
|
||||||
"In 2021, I visited New Orleans",
|
|
||||||
"In 2019, I visited San Francisco",
|
|
||||||
"In 2020, I visited Vancouver",
|
|
||||||
]
|
|
||||||
vectorstore.add_texts(texts)
|
|
||||||
|
|
||||||
# Give the index time to build (For CI)
|
|
||||||
sleep(TIMEOUT)
|
|
||||||
|
|
||||||
query = "In the United States, what city did I visit last?"
|
|
||||||
# One can do vector search on the vector store, using its various search types.
|
|
||||||
k = len(texts)
|
|
||||||
|
|
||||||
store_output = list(vectorstore.similarity_search(query=query, k=k))
|
|
||||||
assert len(store_output) == k
|
|
||||||
assert isinstance(store_output[0], Document)
|
|
||||||
|
|
||||||
# Unfortunately, the VectorStore output cannot be given to a Chat Model
|
|
||||||
# If we wish Chat Model to answer based on our own data,
|
|
||||||
# we have to give it the right things to work with.
|
|
||||||
# The way that Langchain does this is by piping results along in
|
|
||||||
# a Chain: https://python.langchain.com/v0.1/docs/modules/chains/
|
|
||||||
|
|
||||||
# Now, we can turn our VectorStore into something Runnable in a Chain
|
|
||||||
# by turning it into a Retriever.
|
|
||||||
# For the simple VectorSearch Retriever, we can do this like so.
|
|
||||||
|
|
||||||
retriever = vectorstore.as_retriever(search_kwargs=dict(k=k))
|
|
||||||
|
|
||||||
# This does not do much other than expose our search function
|
|
||||||
# as an invoke() method with a a certain API, a Runnable.
|
|
||||||
retriever_output = retriever.invoke(query)
|
|
||||||
assert len(retriever_output) == len(texts)
|
|
||||||
assert retriever_output[0].page_content == store_output[0].page_content
|
|
||||||
|
|
||||||
# To get a natural language response to our question,
|
|
||||||
# we need ChatOpenAI, a template to better frame the question as a prompt,
|
|
||||||
# and a parser to send the output to a string.
|
|
||||||
# Together, these become our Chain!
|
|
||||||
# Here goes:
|
|
||||||
|
|
||||||
template = """Answer the question based only on the following context.
|
|
||||||
Answer in as few words as possible.
|
|
||||||
{context}
|
|
||||||
Question: {question}
|
|
||||||
"""
|
|
||||||
prompt = ChatPromptTemplate.from_template(template)
|
|
||||||
|
|
||||||
model = ChatOpenAI()
|
|
||||||
|
|
||||||
chain = (
|
|
||||||
{"context": retriever, "question": RunnablePassthrough()} # type: ignore
|
|
||||||
| prompt
|
|
||||||
| model
|
|
||||||
| StrOutputParser()
|
|
||||||
)
|
|
||||||
|
|
||||||
answer = chain.invoke("What city did I visit last?")
|
|
||||||
|
|
||||||
assert "Paris" in answer
|
|
@@ -1,43 +0,0 @@
import json
import os

from langchain.memory import ConversationBufferMemory  # type: ignore[import-not-found]
from langchain_core.messages import message_to_dict

from langchain_mongodb.chat_message_histories import MongoDBChatMessageHistory

DATABASE = "langchain_test_db"
COLLECTION = "langchain_test_chat"

# Replace this with your MongoDB connection string
connection_string = os.environ.get("MONGODB_ATLAS_URI", "")


def test_memory_with_message_store() -> None:
    """Test the memory with a message store."""
    # setup MongoDB as a message store
    message_history = MongoDBChatMessageHistory(
        connection_string=connection_string,
        session_id="test-session",
        database_name=DATABASE,
        collection_name=COLLECTION,
    )
    memory = ConversationBufferMemory(
        memory_key="baz", chat_memory=message_history, return_messages=True
    )

    # add some messages
    memory.chat_memory.add_ai_message("This is me, the AI")
    memory.chat_memory.add_user_message("This is me, the human")

    # get the message history from the memory store and turn it into JSON
    messages = memory.chat_memory.messages
    messages_json = json.dumps([message_to_dict(msg) for msg in messages])

    assert "This is me, the AI" in messages_json
    assert "This is me, the human" in messages_json

    # remove the record from MongoDB, so the next test run won't pick it up
    memory.chat_memory.clear()

    assert memory.chat_memory.messages == []
@@ -1,7 +0,0 @@
import pytest  # type: ignore[import-not-found]


@pytest.mark.compile
def test_placeholder() -> None:
    """Used for compiling integration tests without running any real tests."""
    pass
@@ -1,83 +0,0 @@
"""Search index commands are only supported on Atlas Clusters >=M10"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
from typing import Generator, List, Optional
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
from pymongo import MongoClient
|
|
||||||
from pymongo.collection import Collection
|
|
||||||
|
|
||||||
from langchain_mongodb import index
|
|
||||||
|
|
||||||
DB_NAME = "langchain_test_index_db"
|
|
||||||
COLLECTION_NAME = "test_index"
|
|
||||||
VECTOR_INDEX_NAME = "vector_index"
|
|
||||||
|
|
||||||
TIMEOUT = 120
|
|
||||||
DIMENSIONS = 10
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def collection() -> Generator:
|
|
||||||
"""Depending on uri, this could point to any type of cluster."""
|
|
||||||
uri = os.environ.get("MONGODB_ATLAS_URI")
|
|
||||||
client: MongoClient = MongoClient(uri)
|
|
||||||
clxn = client[DB_NAME][COLLECTION_NAME]
|
|
||||||
clxn.insert_one({"foo": "bar"})
|
|
||||||
yield clxn
|
|
||||||
clxn.drop()
|
|
||||||
|
|
||||||
|
|
||||||
def test_search_index_commands(collection: Collection) -> None:
|
|
||||||
index_name = VECTOR_INDEX_NAME
|
|
||||||
dimensions = DIMENSIONS
|
|
||||||
path = "embedding"
|
|
||||||
similarity = "cosine"
|
|
||||||
filters: Optional[List[str]] = None
|
|
||||||
wait_until_complete = TIMEOUT
|
|
||||||
|
|
||||||
for index_info in collection.list_search_indexes():
|
|
||||||
index.drop_vector_search_index(
|
|
||||||
collection, index_info["name"], wait_until_complete=wait_until_complete
|
|
||||||
)
|
|
||||||
|
|
||||||
assert len(list(collection.list_search_indexes())) == 0
|
|
||||||
|
|
||||||
index.create_vector_search_index(
|
|
||||||
collection=collection,
|
|
||||||
index_name=index_name,
|
|
||||||
dimensions=dimensions,
|
|
||||||
path=path,
|
|
||||||
similarity=similarity,
|
|
||||||
filters=filters,
|
|
||||||
wait_until_complete=wait_until_complete,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert index._is_index_ready(collection, index_name)
|
|
||||||
indexes = list(collection.list_search_indexes())
|
|
||||||
assert len(indexes) == 1
|
|
||||||
assert indexes[0]["name"] == index_name
|
|
||||||
|
|
||||||
new_similarity = "euclidean"
|
|
||||||
index.update_vector_search_index(
|
|
||||||
collection,
|
|
||||||
index_name,
|
|
||||||
DIMENSIONS,
|
|
||||||
"embedding",
|
|
||||||
new_similarity,
|
|
||||||
filters=[],
|
|
||||||
wait_until_complete=wait_until_complete,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert index._is_index_ready(collection, index_name)
|
|
||||||
indexes = list(collection.list_search_indexes())
|
|
||||||
assert len(indexes) == 1
|
|
||||||
assert indexes[0]["name"] == index_name
|
|
||||||
assert indexes[0]["latestDefinition"]["fields"][0]["similarity"] == new_similarity
|
|
||||||
|
|
||||||
index.drop_vector_search_index(
|
|
||||||
collection, index_name, wait_until_complete=wait_until_complete
|
|
||||||
)
|
|
||||||
|
|
||||||
indexes = list(collection.list_search_indexes())
|
|
||||||
assert len(indexes) == 0
|
|
@@ -1,176 +0,0 @@
import os
from time import sleep
from typing import List

import pytest
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from pymongo import MongoClient
from pymongo.collection import Collection

from langchain_mongodb import index
from langchain_mongodb.retrievers import (
    MongoDBAtlasFullTextSearchRetriever,
    MongoDBAtlasHybridSearchRetriever,
)

from ..utils import ConsistentFakeEmbeddings, PatchedMongoDBAtlasVectorSearch

CONNECTION_STRING = os.environ.get("MONGODB_ATLAS_URI")
DB_NAME = "langchain_test_db"
COLLECTION_NAME = "langchain_test_retrievers"
VECTOR_INDEX_NAME = "vector_index"
EMBEDDING_FIELD = "embedding"
PAGE_CONTENT_FIELD = "text"
SEARCH_INDEX_NAME = "text_index"

DIMENSIONS = 1536
TIMEOUT = 60.0
INTERVAL = 0.5


@pytest.fixture
def example_documents() -> List[Document]:
    return [
        Document(page_content="In 2023, I visited Paris"),
        Document(page_content="In 2022, I visited New York"),
        Document(page_content="In 2021, I visited New Orleans"),
        Document(page_content="Sandwiches are beautiful. Sandwiches are fine."),
    ]


@pytest.fixture
def embedding_openai() -> Embeddings:
    from langchain_openai import OpenAIEmbeddings

    try:
        return OpenAIEmbeddings(
            openai_api_key=os.environ["OPENAI_API_KEY"],  # type: ignore # noqa
            model="text-embedding-3-small",
        )
    except Exception:
        return ConsistentFakeEmbeddings(DIMENSIONS)


@pytest.fixture
def collection() -> Collection:
    """A Collection with both a Vector and a Full-text Search Index"""
    client: MongoClient = MongoClient(CONNECTION_STRING)
    if COLLECTION_NAME not in client[DB_NAME].list_collection_names():
        clxn = client[DB_NAME].create_collection(COLLECTION_NAME)
    else:
        clxn = client[DB_NAME][COLLECTION_NAME]

    clxn.delete_many({})

    if not any([VECTOR_INDEX_NAME == ix["name"] for ix in clxn.list_search_indexes()]):
        index.create_vector_search_index(
            collection=clxn,
            index_name=VECTOR_INDEX_NAME,
            dimensions=DIMENSIONS,
            path="embedding",
            similarity="cosine",
            wait_until_complete=TIMEOUT,
        )

    if not any([SEARCH_INDEX_NAME == ix["name"] for ix in clxn.list_search_indexes()]):
        index.create_fulltext_search_index(
            collection=clxn,
            index_name=SEARCH_INDEX_NAME,
            field=PAGE_CONTENT_FIELD,
            wait_until_complete=TIMEOUT,
        )

    return clxn


def test_hybrid_retriever(
    embedding_openai: Embeddings,
    collection: Collection,
    example_documents: List[Document],
) -> None:
    """Test basic usage of MongoDBAtlasHybridSearchRetriever"""

    vectorstore = PatchedMongoDBAtlasVectorSearch(
        collection=collection,
        embedding=embedding_openai,
        index_name=VECTOR_INDEX_NAME,
        text_key=PAGE_CONTENT_FIELD,
    )

    vectorstore.add_documents(example_documents)
|
|
||||||
|
|
||||||
sleep(TIMEOUT) # Wait for documents to be sync'd
|
|
||||||
|
|
||||||
retriever = MongoDBAtlasHybridSearchRetriever(
|
|
||||||
vectorstore=vectorstore,
|
|
||||||
search_index_name=SEARCH_INDEX_NAME,
|
|
||||||
top_k=3,
|
|
||||||
)
|
|
||||||
|
|
||||||
query1 = "What was the latest city that I visited?"
|
|
||||||
results = retriever.invoke(query1)
|
|
||||||
assert len(results) == 3
|
|
||||||
assert "Paris" in results[0].page_content
|
|
||||||
|
|
||||||
query2 = "When was the last time I visited new orleans?"
|
|
||||||
results = retriever.invoke(query2)
|
|
||||||
assert "New Orleans" in results[0].page_content
|
|
||||||
|
|
||||||
|
|
||||||
def test_fulltext_retriever(
|
|
||||||
collection: Collection,
|
|
||||||
example_documents: List[Document],
|
|
||||||
) -> None:
|
|
||||||
"""Test result of performing fulltext search
|
|
||||||
|
|
||||||
Independent of the VectorStore, one adds documents
|
|
||||||
via MongoDB's Collection API
|
|
||||||
"""
|
|
||||||
#
|
|
||||||
|
|
||||||
collection.insert_many(
|
|
||||||
[{PAGE_CONTENT_FIELD: doc.page_content} for doc in example_documents]
|
|
||||||
)
|
|
||||||
sleep(TIMEOUT) # Wait for documents to be sync'd
|
|
||||||
|
|
||||||
retriever = MongoDBAtlasFullTextSearchRetriever(
|
|
||||||
collection=collection,
|
|
||||||
search_index_name=SEARCH_INDEX_NAME,
|
|
||||||
search_field=PAGE_CONTENT_FIELD,
|
|
||||||
)
|
|
||||||
|
|
||||||
query = "When was the last time I visited new orleans?"
|
|
||||||
results = retriever.invoke(query)
|
|
||||||
assert "New Orleans" in results[0].page_content
|
|
||||||
assert "score" in results[0].metadata
|
|
||||||
|
|
||||||
|
|
||||||
def test_vector_retriever(
|
|
||||||
embedding_openai: Embeddings,
|
|
||||||
collection: Collection,
|
|
||||||
example_documents: List[Document],
|
|
||||||
) -> None:
|
|
||||||
"""Test VectorStoreRetriever"""
|
|
||||||
|
|
||||||
vectorstore = PatchedMongoDBAtlasVectorSearch(
|
|
||||||
collection=collection,
|
|
||||||
embedding=embedding_openai,
|
|
||||||
index_name=VECTOR_INDEX_NAME,
|
|
||||||
text_key=PAGE_CONTENT_FIELD,
|
|
||||||
)
|
|
||||||
|
|
||||||
vectorstore.add_documents(example_documents)
|
|
||||||
|
|
||||||
sleep(TIMEOUT) # Wait for documents to be sync'd
|
|
||||||
|
|
||||||
retriever = vectorstore.as_retriever()
|
|
||||||
|
|
||||||
query1 = "What was the latest city that I visited?"
|
|
||||||
results = retriever.invoke(query1)
|
|
||||||
assert len(results) == 4
|
|
||||||
assert "Paris" in results[0].page_content
|
|
||||||
|
|
||||||
query2 = "When was the last time I visited new orleans?"
|
|
||||||
results = retriever.invoke(query2)
|
|
||||||
assert "New Orleans" in results[0].page_content
|
|
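
As an aside, hybrid retrievers of this kind typically merge the vector and full-text result lists with reciprocal rank fusion (RRF). A toy sketch of that scoring, for illustration only (the retriever's actual fusion logic is not shown in this diff):

def reciprocal_rank_fusion(rankings, k=60):
    # rankings: iterable of result lists, each ordered best-first by doc id
    scores = {}
    for ranking in rankings:
        for rank, doc_id in enumerate(ranking):
            scores[doc_id] = scores.get(doc_id, 0.0) + 1.0 / (k + rank + 1)
    # Highest fused score first
    return sorted(scores, key=scores.get, reverse=True)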
@ -1,473 +0,0 @@
"""Test MongoDB Atlas Vector Search functionality."""

from __future__ import annotations

import os
from time import monotonic, sleep
from typing import Any, Dict, List

import pytest  # type: ignore[import-not-found]
from bson import ObjectId
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from pymongo import MongoClient
from pymongo.collection import Collection
from pymongo.errors import OperationFailure

from langchain_mongodb.index import drop_vector_search_index
from langchain_mongodb.utils import oid_to_str

from ..utils import ConsistentFakeEmbeddings, PatchedMongoDBAtlasVectorSearch

INDEX_NAME = "langchain-test-index-vectorstores"
INDEX_CREATION_NAME = "langchain-test-index-vectorstores-create-test"
NAMESPACE = "langchain_test_db.langchain_test_vectorstores"
CONNECTION_STRING = os.environ.get("MONGODB_ATLAS_URI")
DB_NAME, COLLECTION_NAME = NAMESPACE.split(".")
INDEX_COLLECTION_NAME = "langchain_test_vectorstores_index"
INDEX_DB_NAME = "langchain_test_index_db"
DIMENSIONS = 1536
TIMEOUT = 120.0
INTERVAL = 0.5


@pytest.fixture
def example_documents() -> List[Document]:
    return [
        Document(page_content="Dogs are tough.", metadata={"a": 1}),
        Document(page_content="Cats have fluff.", metadata={"b": 1}),
        Document(page_content="What is a sandwich?", metadata={"c": 1}),
        Document(page_content="That fence is purple.", metadata={"d": 1, "e": 2}),
    ]


def _await_index_deletion(coll: Collection, index_name: str) -> None:
    start = monotonic()
    try:
        drop_vector_search_index(coll, index_name)
    except OperationFailure:
        # This most likely means an ongoing drop request was made, so skip it
        pass

    while list(coll.list_search_indexes(name=index_name)):
        if monotonic() - start > TIMEOUT:
            raise TimeoutError(f"Index Name: {index_name} never dropped")
        sleep(INTERVAL)


def get_collection(
    database_name: str = DB_NAME, collection_name: str = COLLECTION_NAME
) -> Collection:
    test_client: MongoClient = MongoClient(CONNECTION_STRING)
    return test_client[database_name][collection_name]


@pytest.fixture()
def collection() -> Collection:
    return get_collection()


@pytest.fixture
def texts() -> List[str]:
    return [
        "Dogs are tough.",
        "Cats have fluff.",
        "What is a sandwich?",
        "That fence is purple.",
    ]


@pytest.fixture()
def index_collection() -> Collection:
    return get_collection(INDEX_DB_NAME, INDEX_COLLECTION_NAME)


class TestMongoDBAtlasVectorSearch:
    @classmethod
    def setup_class(cls) -> None:
        # ensure the test collection is empty
        collection = get_collection()
        if collection.count_documents({}):
            collection.delete_many({})  # type: ignore[index]

    @classmethod
    def teardown_class(cls) -> None:
        collection = get_collection()
        # delete all the documents in the collection
        collection.delete_many({})  # type: ignore[index]

    @pytest.fixture(autouse=True)
    def setup(self) -> None:
        collection = get_collection()
        # delete all the documents in the collection
        collection.delete_many({})  # type: ignore[index]

        # delete all indexes on the index collection
        _await_index_deletion(
            get_collection(INDEX_DB_NAME, INDEX_COLLECTION_NAME), INDEX_CREATION_NAME
        )

    @pytest.fixture
    def embeddings(self) -> Embeddings:
        try:
            from langchain_openai import OpenAIEmbeddings

            return OpenAIEmbeddings(
                openai_api_key=os.environ["OPENAI_API_KEY"],  # type: ignore # noqa
                model="text-embedding-3-small",
            )
        except Exception:
            return ConsistentFakeEmbeddings(DIMENSIONS)

    def test_from_documents(
        self,
        embeddings: Embeddings,
        collection: Any,
        example_documents: List[Document],
    ) -> None:
        """Test end to end construction and search."""
        vectorstore = PatchedMongoDBAtlasVectorSearch.from_documents(
            example_documents,
            embedding=embeddings,
            collection=collection,
            index_name=INDEX_NAME,
        )
        output = vectorstore.similarity_search("Sandwich", k=1)
        assert len(output) == 1
        # Check that the result is one of the example documents
        assert any(
            key.page_content == output[0].page_content for key in example_documents
        )

    def test_from_documents_no_embedding_return(
        self,
        embeddings: Embeddings,
        collection: Any,
        example_documents: List[Document],
    ) -> None:
        """Test end to end construction and search."""
        vectorstore = PatchedMongoDBAtlasVectorSearch.from_documents(
            example_documents,
            embedding=embeddings,
            collection=collection,
            index_name=INDEX_NAME,
        )
        output = vectorstore.similarity_search("Sandwich", k=1)
        assert len(output) == 1
        # Check that embeddings are not returned in the documents
        assert all("embedding" not in key.metadata for key in output)
        # Check that the result is one of the example documents
        assert any(
            key.page_content == output[0].page_content for key in example_documents
        )

    def test_from_documents_embedding_return(
        self,
        embeddings: Embeddings,
        collection: Any,
        example_documents: List[Document],
    ) -> None:
        """Test end to end construction and search."""
        vectorstore = PatchedMongoDBAtlasVectorSearch.from_documents(
            example_documents,
            embedding=embeddings,
            collection=collection,
            index_name=INDEX_NAME,
        )
        output = vectorstore.similarity_search("Sandwich", k=1, include_embeddings=True)
        assert len(output) == 1
        # Check for presence of the embedding in each document
        assert all(key.metadata.get("embedding") for key in output)
        # Check that the result is one of the example documents
        assert any(
            key.page_content == output[0].page_content for key in example_documents
        )

    def test_from_texts(
        self, embeddings: Embeddings, collection: Collection, texts: List[str]
    ) -> None:
        vectorstore = PatchedMongoDBAtlasVectorSearch.from_texts(
            texts,
            embedding=embeddings,
            collection=collection,
            index_name=INDEX_NAME,
        )
        output = vectorstore.similarity_search("Sandwich", k=1)
        assert len(output) == 1

    def test_from_texts_with_metadatas(
        self,
        embeddings: Embeddings,
        collection: Collection,
        texts: List[str],
    ) -> None:
        metadatas = [{"a": 1}, {"b": 1}, {"c": 1}, {"d": 1, "e": 2}]
        metakeys = ["a", "b", "c", "d", "e"]
        vectorstore = PatchedMongoDBAtlasVectorSearch.from_texts(
            texts,
            embedding=embeddings,
            metadatas=metadatas,
            collection=collection,
            index_name=INDEX_NAME,
        )
        output = vectorstore.similarity_search("Sandwich", k=1)
        assert len(output) == 1
        # Check for the presence of a metadata key
        assert any(key in output[0].metadata for key in metakeys)

    def test_from_texts_with_metadatas_and_pre_filter(
        self, embeddings: Embeddings, collection: Any, texts: List[str]
    ) -> None:
        metadatas = [{"a": 1}, {"b": 1}, {"c": 1}, {"d": 1, "e": 2}]
        vectorstore = PatchedMongoDBAtlasVectorSearch.from_texts(
            texts,
            embedding=embeddings,
            metadatas=metadatas,
            collection=collection,
            index_name=INDEX_NAME,
        )
        does_not_match_filter = vectorstore.similarity_search(
            "Sandwich", k=1, pre_filter={"c": {"$lte": 0}}
        )
        assert does_not_match_filter == []

        matches_filter = vectorstore.similarity_search(
            "Sandwich", k=3, pre_filter={"c": {"$gt": 0}}
        )
        assert len(matches_filter) == 1

    def test_mmr(self, embeddings: Embeddings, collection: Any) -> None:
        texts = ["foo", "foo", "fou", "foy"]
        vectorstore = PatchedMongoDBAtlasVectorSearch.from_texts(
            texts,
            embedding=embeddings,
            collection=collection,
            index_name=INDEX_NAME,
        )
        query = "foo"
        output = vectorstore.max_marginal_relevance_search(query, k=10, lambda_mult=0.1)
        assert len(output) == len(texts)
        assert output[0].page_content == "foo"
        assert output[1].page_content != "foo"

    def test_retriever(
        self,
        embeddings: Embeddings,
        collection: Any,
        example_documents: List[Document],
    ) -> None:
        """Demonstrate usage and parity of VectorStore similarity_search
        with Retriever.invoke."""
        vectorstore = PatchedMongoDBAtlasVectorSearch.from_documents(
            example_documents,
            embedding=embeddings,
            collection=collection,
            index_name=INDEX_NAME,
        )
        query = "sandwich"

        retriever_default_kwargs = vectorstore.as_retriever()
        result_retriever = retriever_default_kwargs.invoke(query)
        result_vectorstore = vectorstore.similarity_search(query)
        assert all(
            result_retriever[i].page_content == result_vectorstore[i].page_content
            for i in range(len(result_retriever))
        )

    def test_include_embeddings(
        self,
        embeddings: Embeddings,
        collection: Any,
        example_documents: List[Document],
    ) -> None:
        """Test explicitly passing vector kwarg matches default."""
        vectorstore = PatchedMongoDBAtlasVectorSearch.from_documents(
            documents=example_documents,
            embedding=embeddings,
            collection=collection,
            index_name=INDEX_NAME,
        )

        output_with = vectorstore.similarity_search(
            "Sandwich", include_embeddings=True, k=1
        )
        assert vectorstore._embedding_key in output_with[0].metadata
        output_without = vectorstore.similarity_search("Sandwich", k=1)
        assert vectorstore._embedding_key not in output_without[0].metadata

    def test_delete(
        self, embeddings: Embeddings, collection: Any, texts: List[str]
    ) -> None:
        vectorstore = PatchedMongoDBAtlasVectorSearch(
            collection=collection,
            embedding=embeddings,
            index_name=INDEX_NAME,
        )
        clxn: Collection = vectorstore._collection
        assert clxn.count_documents({}) == 0
        ids = vectorstore.add_texts(texts)
        assert clxn.count_documents({}) == len(texts)

        deleted = vectorstore.delete(ids[-2:])
        assert deleted
        assert clxn.count_documents({}) == len(texts) - 2

        new_ids = vectorstore.add_texts(["Pigs eat stuff", "Pigs eat sandwiches"])
        assert set(new_ids).intersection(set(ids)) == set()  # new ids will be unique
        assert isinstance(new_ids, list)
        assert all(isinstance(i, str) for i in new_ids)
        assert len(new_ids) == 2
        assert clxn.count_documents({}) == 4

    def test_add_texts(
        self,
        embeddings: Embeddings,
        collection: Collection,
        texts: List[str],
    ) -> None:
        """Tests the API of add_texts, focusing on id treatment.

        Warning: This is slow because of the number of cases.
        """
        metadatas: List[Dict[str, Any]] = [
            {"a": 1},
            {"b": 1},
            {"c": 1},
            {"d": 1, "e": 2},
        ]

        vectorstore = PatchedMongoDBAtlasVectorSearch(
            collection=collection, embedding=embeddings, index_name=INDEX_NAME
        )

        # Case 1: Add texts without ids
        provided_ids = vectorstore.add_texts(texts=texts, metadatas=metadatas)
        all_docs = list(vectorstore._collection.find({}))
        assert all("_id" in doc for doc in all_docs)
        docids = set(doc["_id"] for doc in all_docs)
        assert all(isinstance(_id, ObjectId) for _id in docids)
        assert set(provided_ids) == set(oid_to_str(oid) for oid in docids)

        # Case 2: Test that Document.metadata looks right, i.e. contains _id
        search_res = vectorstore.similarity_search_with_score("sandwich", k=1)
        doc, score = search_res[0]
        assert "_id" in doc.metadata

        # Case 3: Add new ids that are 24-char hex strings
        hex_ids = [oid_to_str(ObjectId()) for _ in range(2)]
        hex_texts = ["Text for hex_id"] * len(hex_ids)
        out_ids = vectorstore.add_texts(texts=hex_texts, ids=hex_ids)
        assert set(out_ids) == set(hex_ids)
        assert collection.count_documents({}) == len(texts) + len(hex_texts)
        assert all(
            isinstance(doc["_id"], ObjectId) for doc in vectorstore._collection.find({})
        )

        # Case 4: Add new ids that cannot be cast to ObjectId
        # - We can still index and search on them
        str_ids = ["Sandwiches are beautiful,", "..sandwiches are fine."]
        str_texts = str_ids  # No reason for them to differ
        out_ids = vectorstore.add_texts(texts=str_texts, ids=str_ids)
        assert set(out_ids) == set(str_ids)
        assert collection.count_documents({}) == 8
        res = vectorstore.similarity_search("sandwich", k=8)
        assert any(str_ids[0] in doc.metadata["_id"] for doc in res)

        # Case 5: Test adding in multiple batches
        batch_size = 2
        batch_ids = [oid_to_str(ObjectId()) for _ in range(2 * batch_size)]
        batch_texts = [f"Text for batch text {i}" for i in range(2 * batch_size)]
        out_ids = vectorstore.add_texts(
            texts=batch_texts, ids=batch_ids, batch_size=batch_size
        )
        assert set(out_ids) == set(batch_ids)
        assert collection.count_documents({}) == 12

        # Case 6: _ids in metadata
        collection.delete_many({})
        # 6a. Unique _id in metadata, but ids=None
        # Will be added as if the ids kwarg were provided
        i = 0
        n = len(texts)
        assert len(metadatas) == n
        _ids = [str(i) for i in range(n)]
        for md in metadatas:
            md["_id"] = _ids[i]
            i += 1
        returned_ids = vectorstore.add_texts(texts=texts, metadatas=metadatas)
        assert returned_ids == ["0", "1", "2", "3"]
        assert set(d["_id"] for d in vectorstore._collection.find({})) == set(_ids)

        # 6b. Unique "id", not "_id", but ids=None
        # New ids will be assigned
        i = 1
        for md in metadatas:
            md.pop("_id")
            md["id"] = f"{i}"
            i += 1
        returned_ids = vectorstore.add_texts(texts=texts, metadatas=metadatas)
        assert len(set(returned_ids).intersection(set(_ids))) == 0

    def test_add_documents(
        self,
        embeddings: Embeddings,
        collection: Collection,
    ) -> None:
        """Tests add_documents."""
        vectorstore = PatchedMongoDBAtlasVectorSearch(
            collection=collection, embedding=embeddings, index_name=INDEX_NAME
        )

        # Case 1: No ids
        n_docs = 10
        batch_size = 3
        docs = [
            Document(page_content=f"document {i}", metadata={"i": i})
            for i in range(n_docs)
        ]
        result_ids = vectorstore.add_documents(docs, batch_size=batch_size)
        assert len(result_ids) == n_docs
        assert collection.count_documents({}) == n_docs

        # Case 2: ids
        collection.delete_many({})
        n_docs = 10
        batch_size = 3
        docs = [
            Document(page_content=f"document {i}", metadata={"i": i})
            for i in range(n_docs)
        ]
        ids = [str(i) for i in range(n_docs)]
        result_ids = vectorstore.add_documents(docs, ids, batch_size=batch_size)
        assert len(result_ids) == n_docs
        assert set(ids) == set(collection.distinct("_id"))

        # Case 3: Single batch
        collection.delete_many({})
        n_docs = 3
        batch_size = 10
        docs = [
            Document(page_content=f"document {i}", metadata={"i": i})
            for i in range(n_docs)
        ]
        ids = [str(i) for i in range(n_docs)]
        result_ids = vectorstore.add_documents(docs, ids, batch_size=batch_size)
        assert len(result_ids) == n_docs
        assert set(ids) == set(collection.distinct("_id"))

    def test_index_creation(
        self, embeddings: Embeddings, index_collection: Any
    ) -> None:
        vectorstore = PatchedMongoDBAtlasVectorSearch(
            index_collection, embedding=embeddings, index_name=INDEX_CREATION_NAME
        )
        vectorstore.create_vector_search_index(dimensions=1536)

    def test_index_update(self, embeddings: Embeddings, index_collection: Any) -> None:
        vectorstore = PatchedMongoDBAtlasVectorSearch(
            index_collection, embedding=embeddings, index_name=INDEX_CREATION_NAME
        )
        vectorstore.create_vector_search_index(dimensions=1536)
        vectorstore.create_vector_search_index(dimensions=1536, update=True)
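
For reference, similarity_search calls like the ones above run an aggregation whose first stage is Atlas's $vectorSearch. The sketch below shows the rough shape of that stage under assumed values (numCandidates, the query embedding placeholder); it is not the library's exact pipeline:

pipeline = [
    {
        "$vectorSearch": {
            "index": INDEX_NAME,
            "path": "embedding",
            "queryVector": embedded_query,  # embedding of the query text (placeholder)
            "numCandidates": 40,  # assumed value for illustration
            "limit": 4,
            "filter": {"c": {"$gt": 0}},  # optional pre_filter
        }
    },
    {"$set": {"score": {"$meta": "vectorSearchScore"}}},
]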
@ -1,215 +0,0 @@
import uuid
from typing import Any, Dict, List, Union

import pytest  # type: ignore[import-not-found]
from langchain_core.caches import BaseCache
from langchain_core.embeddings import Embeddings
from langchain_core.globals import get_llm_cache, set_llm_cache
from langchain_core.load.dump import dumps
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langchain_core.outputs import ChatGeneration, Generation, LLMResult
from pymongo.collection import Collection

from langchain_mongodb.cache import MongoDBAtlasSemanticCache, MongoDBCache
from langchain_mongodb.vectorstores import MongoDBAtlasVectorSearch

from ..utils import ConsistentFakeEmbeddings, FakeChatModel, FakeLLM, MockCollection

CONN_STRING = "MockString"
COLLECTION = "default"
DATABASE = "default"


class PatchedMongoDBCache(MongoDBCache):
    def __init__(
        self,
        connection_string: str,
        collection_name: str = "default",
        database_name: str = "default",
        **kwargs: Dict[str, Any],
    ) -> None:
        self.__database_name = database_name
        self.__collection_name = collection_name
        self.client = {self.__database_name: {self.__collection_name: MockCollection()}}  # type: ignore

    @property
    def database(self) -> Any:  # type: ignore
        """Returns the database used to store cache values."""
        return self.client[self.__database_name]

    @property
    def collection(self) -> Collection:
        """Returns the collection used to store cache values."""
        return self.database[self.__collection_name]


class PatchedMongoDBAtlasSemanticCache(MongoDBAtlasSemanticCache):
    def __init__(
        self,
        connection_string: str,
        embedding: Embeddings,
        collection_name: str = "default",
        database_name: str = "default",
        wait_until_ready: bool = False,
        **kwargs: Dict[str, Any],
    ):
        self.collection = MockCollection()
        self._wait_until_ready = False
        self.score_threshold = None
        MongoDBAtlasVectorSearch.__init__(
            self,
            self.collection,
            embedding=embedding,
            **kwargs,  # type: ignore
        )


def random_string() -> str:
    return str(uuid.uuid4())


def llm_cache(cls: Any) -> BaseCache:
    set_llm_cache(
        cls(
            embedding=ConsistentFakeEmbeddings(dimensionality=1536),
            connection_string=CONN_STRING,
            collection_name=COLLECTION,
            database_name=DATABASE,
            wait_until_ready=15.0,
        )
    )
    assert get_llm_cache()
    return get_llm_cache()


def _execute_test(
    prompt: Union[str, List[BaseMessage]],
    llm: Union[str, FakeLLM, FakeChatModel],
    response: List[Generation],
) -> None:
    # Fabricate an LLM String

    if not isinstance(llm, str):
        params = llm.dict()
        params["stop"] = None
        llm_string = str(sorted([(k, v) for k, v in params.items()]))
    else:
        llm_string = llm

    # If the prompt is a str then we should pass just the string
    dumped_prompt: str = prompt if isinstance(prompt, str) else dumps(prompt)

    # Update the cache
    llm_cache = get_llm_cache()
    llm_cache.update(dumped_prompt, llm_string, response)

    # Retrieve the cached result through 'generate' call
    output: Union[List[Generation], LLMResult, None]
    expected_output: Union[List[Generation], LLMResult]
    if isinstance(llm_cache, PatchedMongoDBAtlasSemanticCache):
        llm_cache._collection._aggregate_result = [  # type: ignore
            data
            for data in llm_cache._collection._data  # type: ignore
            if data.get("text") == dumped_prompt
            and data.get("llm_string") == llm_string
        ]
    if isinstance(llm, str):
        output = get_llm_cache().lookup(dumped_prompt, llm)  # type: ignore
        expected_output = response
    else:
        output = llm.generate([prompt])  # type: ignore
        expected_output = LLMResult(
            generations=[response],
            llm_output={},
        )

    assert output == expected_output  # type: ignore


@pytest.mark.parametrize(
    "prompt, llm, response",
    [
        ("foo", "bar", [Generation(text="fizz")]),
        ("foo", FakeLLM(), [Generation(text="fizz")]),
        (
            [HumanMessage(content="foo")],
            FakeChatModel(),
            [ChatGeneration(message=AIMessage(content="foo"))],
        ),
    ],
    ids=[
        "plain_cache",
        "cache_with_llm",
        "cache_with_chat",
    ],
)
@pytest.mark.parametrize(
    "cacher", [PatchedMongoDBCache, PatchedMongoDBAtlasSemanticCache]
)
@pytest.mark.parametrize("remove_score", [True, False])
def test_mongodb_cache(
    remove_score: bool,
    cacher: Union[MongoDBCache, MongoDBAtlasSemanticCache],
    prompt: Union[str, List[BaseMessage]],
    llm: Union[str, FakeLLM, FakeChatModel],
    response: List[Generation],
) -> None:
    llm_cache(cacher)
    if remove_score:
        get_llm_cache().score_threshold = None  # type: ignore
    try:
        _execute_test(prompt, llm, response)
    finally:
        get_llm_cache().clear()


@pytest.mark.parametrize(
    "prompts, generations",
    [
        # Single prompt, single generation
        ([random_string()], [[random_string()]]),
        # Single prompt, multiple generations
        ([random_string()], [[random_string(), random_string()]]),
        # Single prompt, multiple generations
        ([random_string()], [[random_string(), random_string(), random_string()]]),
        # Multiple prompts, multiple generations
        (
            [random_string(), random_string()],
            [[random_string()], [random_string(), random_string()]],
        ),
    ],
    ids=[
        "single_prompt_single_generation",
        "single_prompt_two_generations",
        "single_prompt_three_generations",
        "multiple_prompts_multiple_generations",
    ],
)
def test_mongodb_atlas_cache_matrix(
    prompts: List[str],
    generations: List[List[str]],
) -> None:
    llm_cache(PatchedMongoDBAtlasSemanticCache)
    llm = FakeLLM()

    # Fabricate an LLM String
    params = llm.dict()
    params["stop"] = None
    llm_string = str(sorted([(k, v) for k, v in params.items()]))

    llm_generations = [
        [
            Generation(text=generation, generation_info=params)
            for generation in prompt_i_generations
        ]
        for prompt_i_generations in generations
    ]

    for prompt_i, llm_generations_i in zip(prompts, llm_generations):
        _execute_test(prompt_i, llm_string, llm_generations_i)

    get_llm_cache()._collection._simulate_cache_aggregation_query = True  # type: ignore
    assert llm.generate(prompts) == LLMResult(
        generations=llm_generations, llm_output={}
    )
    get_llm_cache().clear()
@ -1,44 +0,0 @@
import json

from langchain.memory import ConversationBufferMemory  # type: ignore[import-not-found]
from langchain_core.messages import message_to_dict

from langchain_mongodb.chat_message_histories import MongoDBChatMessageHistory

from ..utils import MockCollection


class PatchedMongoDBChatMessageHistory(MongoDBChatMessageHistory):
    def __init__(self) -> None:
        self.session_id = "test-session"
        self.database_name = "test-database"
        self.collection_name = "test-collection"
        self.collection = MockCollection()
        self.session_id_key = "SessionId"
        self.history_key = "History"
        self.history_size = None


def test_memory_with_message_store() -> None:
    """Test the memory with a message store."""
    # setup MongoDB as a message store
    message_history = PatchedMongoDBChatMessageHistory()
    memory = ConversationBufferMemory(
        memory_key="baz", chat_memory=message_history, return_messages=True
    )

    # add some messages
    memory.chat_memory.add_ai_message("This is me, the AI")
    memory.chat_memory.add_user_message("This is me, the human")

    # get the message history from the memory store and turn it into JSON
    messages = memory.chat_memory.messages
    messages_json = json.dumps([message_to_dict(msg) for msg in messages])

    assert "This is me, the AI" in messages_json
    assert "This is me, the human" in messages_json

    # remove the record from MongoDB, so the next test run won't pick it up
    memory.chat_memory.clear()

    assert memory.chat_memory.messages == []
@ -1,12 +0,0 @@
from langchain_mongodb import __all__

EXPECTED_ALL = [
    "MongoDBAtlasVectorSearch",
    "MongoDBChatMessageHistory",
    "MongoDBCache",
    "MongoDBAtlasSemanticCache",
]


def test_all_imports() -> None:
    assert sorted(EXPECTED_ALL) == sorted(__all__)
@ -1,72 +0,0 @@
"""Search index commands are only supported on Atlas Clusters >=M10"""

import os
from time import sleep

import pytest
from pymongo import MongoClient
from pymongo.collection import Collection
from pymongo.errors import OperationFailure, ServerSelectionTimeoutError

from langchain_mongodb import index

DIMENSION = 10
TIMEOUT = 10


@pytest.fixture
def collection() -> Collection:
    """Depending on uri, this could point to any type of cluster.

    For unit tests, MONGODB_URI should be localhost, None, or Atlas cluster <M10.
    """
    uri = os.environ.get("MONGODB_URI")
    client: MongoClient = MongoClient(uri)
    return client["db"]["collection"]


def test_create_vector_search_index(collection: Collection) -> None:
    with pytest.raises((OperationFailure, ServerSelectionTimeoutError)):
        index.create_vector_search_index(
            collection,
            "index_name",
            DIMENSION,
            "embedding",
            "cosine",
            [],
            wait_until_complete=TIMEOUT,
        )


def test_drop_vector_search_index(collection: Collection) -> None:
    with pytest.raises((OperationFailure, ServerSelectionTimeoutError)):
        index.drop_vector_search_index(
            collection, "index_name", wait_until_complete=TIMEOUT
        )


def test_update_vector_search_index(collection: Collection) -> None:
    with pytest.raises((OperationFailure, ServerSelectionTimeoutError)):
        index.update_vector_search_index(
            collection,
            "index_name",
            DIMENSION,
            "embedding",
            "cosine",
            [],
            wait_until_complete=TIMEOUT,
        )


def test___is_index_ready(collection: Collection) -> None:
    with pytest.raises((OperationFailure, ServerSelectionTimeoutError)):
        index._is_index_ready(collection, "index_name")


def test__wait_for_predicate() -> None:
    err = "error string"
    with pytest.raises(TimeoutError) as e:
        index._wait_for_predicate(lambda: sleep(5), err=err, timeout=0.5, interval=0.1)
    assert err in str(e)

    index._wait_for_predicate(lambda: True, err=err, timeout=1.0, interval=0.5)
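
The _wait_for_predicate helper exercised above evidently polls a callable until it returns a truthy value or a timeout elapses. A minimal sketch of that shape, as an assumption rather than the library's actual source:

from time import monotonic, sleep


def wait_for_predicate(predicate, err, timeout=10.0, interval=0.5):
    # Poll until the predicate is truthy; raise TimeoutError(err) otherwise
    start = monotonic()
    while not predicate():
        if monotonic() - start > timeout:
            raise TimeoutError(err)
        sleep(interval)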
@ -1,191 +0,0 @@
from json import dumps, loads
from typing import Any, Optional

import pytest  # type: ignore[import-not-found]
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from pymongo.collection import Collection

from langchain_mongodb import MongoDBAtlasVectorSearch

from ..utils import ConsistentFakeEmbeddings, MockCollection

INDEX_NAME = "langchain-test-index"
NAMESPACE = "langchain_test_db.langchain_test_collection"
DB_NAME, COLLECTION_NAME = NAMESPACE.split(".")


def get_collection() -> MockCollection:
    return MockCollection()


@pytest.fixture()
def collection() -> MockCollection:
    return get_collection()


@pytest.fixture(scope="module")
def embedding_openai() -> Embeddings:
    return ConsistentFakeEmbeddings()


def test_initialization(collection: Collection, embedding_openai: Embeddings) -> None:
    """Test initialization of vector store class"""
    assert MongoDBAtlasVectorSearch(collection, embedding_openai)


def test_init_from_texts(collection: Collection, embedding_openai: Embeddings) -> None:
    """Test from_texts operation on an empty list"""
    assert MongoDBAtlasVectorSearch.from_texts(
        [], embedding_openai, collection=collection
    )


class TestMongoDBAtlasVectorSearch:
    @classmethod
    def setup_class(cls) -> None:
        # ensure the test collection is empty
        collection = get_collection()
        assert collection.count_documents({}) == 0  # type: ignore[index]

    @classmethod
    def teardown_class(cls) -> None:
        collection = get_collection()
        # delete all the documents in the collection
        collection.delete_many({})  # type: ignore[index]

    @pytest.fixture(autouse=True)
    def setup(self) -> None:
        collection = get_collection()
        # delete all the documents in the collection
        collection.delete_many({})  # type: ignore[index]

    def _validate_search(
        self,
        vectorstore: MongoDBAtlasVectorSearch,
        collection: MockCollection,
        search_term: str = "sandwich",
        page_content: str = "What is a sandwich?",
        metadata: Optional[Any] = 1,
    ) -> None:
        collection._aggregate_result = list(
            filter(
                lambda x: search_term.lower() in x[vectorstore._text_key].lower(),
                collection._data,
            )
        )
        output = vectorstore.similarity_search("", k=1)
        assert output[0].page_content == page_content
        assert output[0].metadata.get("c") == metadata
        # Validate the ObjectId provided is json serializable
        assert loads(dumps(output[0].page_content)) == output[0].page_content
        assert loads(dumps(output[0].metadata)) == output[0].metadata
        assert isinstance(output[0].metadata["_id"], str)

    def test_from_documents(
        self, embedding_openai: Embeddings, collection: MockCollection
    ) -> None:
        """Test end to end construction and search."""
        documents = [
            Document(page_content="Dogs are tough.", metadata={"a": 1}),
            Document(page_content="Cats have fluff.", metadata={"b": 1}),
            Document(page_content="What is a sandwich?", metadata={"c": 1}),
            Document(page_content="That fence is purple.", metadata={"d": 1, "e": 2}),
        ]
        vectorstore = MongoDBAtlasVectorSearch.from_documents(
            documents,
            embedding_openai,
            collection=collection,
            vector_index_name=INDEX_NAME,
        )
        self._validate_search(
            vectorstore, collection, metadata=documents[2].metadata["c"]
        )

    def test_from_texts(
        self, embedding_openai: Embeddings, collection: MockCollection
    ) -> None:
        texts = [
            "Dogs are tough.",
            "Cats have fluff.",
            "What is a sandwich?",
            "That fence is purple.",
        ]
        vectorstore = MongoDBAtlasVectorSearch.from_texts(
            texts,
            embedding_openai,
            collection=collection,
            vector_index_name=INDEX_NAME,
        )
        self._validate_search(vectorstore, collection, metadata=None)

    def test_from_texts_with_metadatas(
        self, embedding_openai: Embeddings, collection: MockCollection
    ) -> None:
        texts = [
            "Dogs are tough.",
            "Cats have fluff.",
            "What is a sandwich?",
            "The fence is purple.",
        ]
        metadatas = [{"a": 1}, {"b": 1}, {"c": 1}, {"d": 1, "e": 2}]
        vectorstore = MongoDBAtlasVectorSearch.from_texts(
            texts,
            embedding_openai,
            metadatas=metadatas,
            collection=collection,
            vector_index_name=INDEX_NAME,
        )
        self._validate_search(vectorstore, collection, metadata=metadatas[2]["c"])

    def test_from_texts_with_metadatas_and_pre_filter(
        self, embedding_openai: Embeddings, collection: MockCollection
    ) -> None:
        texts = [
            "Dogs are tough.",
            "Cats have fluff.",
            "What is a sandwich?",
            "The fence is purple.",
        ]
        metadatas = [{"a": 1}, {"b": 1}, {"c": 1}, {"d": 1, "e": 2}]
        vectorstore = MongoDBAtlasVectorSearch.from_texts(
            texts,
            embedding_openai,
            metadatas=metadatas,
            collection=collection,
            vector_index_name=INDEX_NAME,
        )
        collection._aggregate_result = list(
            filter(
                lambda x: "sandwich" in x[vectorstore._text_key].lower()
                and x.get("c") < 0,
                collection._data,
            )
        )
        output = vectorstore.similarity_search(
            "Sandwich", k=1, pre_filter={"range": {"lte": 0, "path": "c"}}
        )
        assert output == []

    def test_mmr(
        self, embedding_openai: Embeddings, collection: MockCollection
    ) -> None:
        texts = ["foo", "foo", "fou", "foy"]
        vectorstore = MongoDBAtlasVectorSearch.from_texts(
            texts,
            embedding=embedding_openai,
            collection=collection,
            vector_index_name=INDEX_NAME,
        )
        query = "foo"
        self._validate_search(
            vectorstore,
            collection,
            search_term=query[0:2],
            page_content=query,
            metadata=None,
        )
        output = vectorstore.max_marginal_relevance_search(query, k=10, lambda_mult=0.1)
        assert len(output) == len(texts)
        assert output[0].page_content == "foo"
        assert output[1].page_content != "foo"
@ -1,273 +0,0 @@
from __future__ import annotations

from copy import deepcopy
from time import monotonic, sleep
from typing import Any, Dict, Generator, Iterable, List, Mapping, Optional, Union, cast

from bson import ObjectId
from langchain_core.callbacks.manager import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain_core.embeddings import Embeddings
from langchain_core.language_models.chat_models import SimpleChatModel
from langchain_core.language_models.llms import LLM
from langchain_core.messages import (
    AIMessage,
    BaseMessage,
)
from langchain_core.outputs import ChatGeneration, ChatResult
from pydantic import model_validator
from pymongo.collection import Collection
from pymongo.results import DeleteResult, InsertManyResult

from langchain_mongodb import MongoDBAtlasVectorSearch
from langchain_mongodb.cache import MongoDBAtlasSemanticCache

TIMEOUT = 120
INTERVAL = 0.5


class PatchedMongoDBAtlasVectorSearch(MongoDBAtlasVectorSearch):
    def bulk_embed_and_insert_texts(
        self,
        texts: Union[List[str], Iterable[str]],
        metadatas: Union[List[dict], Generator[dict, Any, Any]],
        ids: Optional[List[str]] = None,
    ) -> List:
        """Patched insert_texts that waits for data to be indexed before returning"""
        ids_inserted = super().bulk_embed_and_insert_texts(texts, metadatas, ids)
        start = monotonic()
        while len(ids_inserted) != len(self.similarity_search("sandwich")) and (
            monotonic() - start <= TIMEOUT
        ):
            sleep(INTERVAL)
        return ids_inserted

    def create_vector_search_index(
        self,
        dimensions: int,
        filters: Optional[List[str]] = None,
        update: bool = False,
    ) -> None:
        result = super().create_vector_search_index(
            dimensions=dimensions, filters=filters, update=update
        )
        start = monotonic()
        while monotonic() - start <= TIMEOUT:
            if indexes := list(
                self._collection.list_search_indexes(name=self._index_name)
            ):
                if indexes[0].get("status") == "READY":
                    return result
            sleep(INTERVAL)


class ConsistentFakeEmbeddings(Embeddings):
    """Fake embeddings functionality for testing."""

    def __init__(self, dimensionality: int = 10) -> None:
        self.known_texts: List[str] = []
        self.dimensionality = dimensionality

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Return consistent embeddings for each text seen so far."""
        out_vectors = []
        for text in texts:
            if text not in self.known_texts:
                self.known_texts.append(text)
            vector = [float(1.0)] * (self.dimensionality - 1) + [
                float(self.known_texts.index(text))
            ]
            out_vectors.append(vector)
        return out_vectors

    def embed_query(self, text: str) -> List[float]:
        """Return consistent embeddings for the text, if seen before, or a constant
        one if the text is unknown."""
        return self.embed_documents([text])[0]

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        return self.embed_documents(texts)

    async def aembed_query(self, text: str) -> List[float]:
        return self.embed_query(text)
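
# Illustrative usage (not part of the original file): the fake embedder is
# deterministic, so the same text always maps to the same vector, and distinct
# texts differ only in the final component, which encodes insertion order.
#
#     emb = ConsistentFakeEmbeddings(dimensionality=4)
#     v_alpha, v_beta = emb.embed_documents(["alpha", "beta"])
#     assert emb.embed_query("alpha") == v_alpha  # stable across calls
#     assert v_alpha[-1] != v_beta[-1]            # index encodes identity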


class FakeChatModel(SimpleChatModel):
    """Fake Chat Model wrapper for testing purposes."""

    def _call(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        return "fake response"

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        output_str = "fake response"
        message = AIMessage(content=output_str)
        generation = ChatGeneration(message=message)
        return ChatResult(generations=[generation])

    @property
    def _llm_type(self) -> str:
        return "fake-chat-model"

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        return {"key": "fake"}


class FakeLLM(LLM):
    """Fake LLM wrapper for testing purposes."""

    queries: Optional[Mapping] = None
    sequential_responses: Optional[bool] = False
    response_index: int = 0

    @model_validator(mode="before")
    @classmethod
    def check_queries_required(cls, values: dict) -> dict:
        if values.get("sequential_responses") and not values.get("queries"):
            raise ValueError(
                "queries is required when sequential_responses is set to True"
            )
        return values

    def get_num_tokens(self, text: str) -> int:
        """Return number of tokens."""
        return len(text.split())

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "fake"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        if self.sequential_responses:
            return self._get_next_response_in_sequence
        if self.queries is not None:
            return self.queries[prompt]
        if stop is None:
            return "foo"
        else:
            return "bar"

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        return {}

    @property
    def _get_next_response_in_sequence(self) -> str:
        queries = cast(Mapping, self.queries)
        response = queries[list(queries.keys())[self.response_index]]
        self.response_index = self.response_index + 1
        return response


class MockCollection(Collection):
    """Mocked Mongo Collection"""

    _aggregate_result: List[Any]
    _insert_result: Optional[InsertManyResult]
    _data: List[Any]
    _simulate_cache_aggregation_query: bool

    def __init__(self) -> None:
        self._data = []
        self._aggregate_result = []
        self._insert_result = None
        self._simulate_cache_aggregation_query = False

    def delete_many(self, *args, **kwargs) -> DeleteResult:  # type: ignore
        old_len = len(self._data)
        self._data = []
        return DeleteResult({"n": old_len}, acknowledged=True)

    def insert_many(self, to_insert: List[Any], *args, **kwargs) -> InsertManyResult:  # type: ignore
        mongodb_inserts = [
            {"_id": ObjectId(), "score": 1, **insert} for insert in to_insert
        ]
        self._data.extend(mongodb_inserts)
        return self._insert_result or InsertManyResult(
            [k["_id"] for k in mongodb_inserts], acknowledged=True
        )

    def insert_one(self, to_insert: Any, *args, **kwargs) -> Any:  # type: ignore
        return self.insert_many([to_insert])

    def find_one(self, find_query: Dict[str, Any]) -> Optional[Dict[str, Any]]:  # type: ignore
        find = self.find(find_query) or [None]  # type: ignore
        return find[0]

    def find(self, find_query: Dict[str, Any]) -> Optional[List[Dict[str, Any]]]:  # type: ignore
        def _is_match(item: Dict[str, Any]) -> bool:
            for key, match_val in find_query.items():
                if item.get(key) != match_val:
                    return False
            return True

        return [document for document in self._data if _is_match(document)]

    def update_one(  # type: ignore
        self,
        find_query: Dict[str, Any],
        options: Dict[str, Any],
        *args: Any,
        upsert=True,
        **kwargs: Any,
    ) -> None:  # type: ignore
        result = self.find_one(find_query)
        set_options = options.get("$set", {})

        if result:
            result.update(set_options)
        elif upsert:
            self._data.append({**find_query, **set_options})

    def _execute_cache_aggregation_query(self, *args, **kwargs) -> List[Dict[str, Any]]:  # type: ignore
        """Helper function only to be used for MongoDBAtlasSemanticCache testing

        Returns:
            List[Dict[str, Any]]: Aggregation query result
        """
        pipeline: List[Dict[str, Any]] = args[0]
        params = pipeline[0]["$vectorSearch"]
        embedding = params["queryVector"]
        # Assumes MongoDBAtlasSemanticCache.LLM == "llm_string"
        llm_string = params["filter"][MongoDBAtlasSemanticCache.LLM]["$eq"]

        acc = []
        for document in self._data:
            if (
                document.get("embedding") == embedding
                and document.get(MongoDBAtlasSemanticCache.LLM) == llm_string
            ):
                acc.append(document)
        return acc

    def aggregate(self, *args, **kwargs) -> List[Any]:  # type: ignore
        if self._simulate_cache_aggregation_query:
            return deepcopy(self._execute_cache_aggregation_query(*args, **kwargs))
        return deepcopy(self._aggregate_result)

    def count_documents(self, *args, **kwargs) -> int:  # type: ignore
        return len(self._data)

    def __repr__(self) -> str:
        return "MockCollection"