Merge branch 'master' into mitchell/fix-sqarql-bug

Mingchen Li 2024-10-19 12:53:25 +08:00 committed by GitHub
commit f17bf1f3bf
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
64 changed files with 769 additions and 10161 deletions

View File

@@ -77,7 +77,7 @@
   "id": "08f8b820-5912-49c1-9d76-40be0571dffb",
   "metadata": {},
   "source": [
-   "This creates a retriever (specifically a [VectorStoreRetriever](https://python.langchain.com/api_reference/core/vectorstores/langchain_core.vectorstores.VectorStoreRetriever.html)), which we can use in the usual way:"
+   "This creates a retriever (specifically a [VectorStoreRetriever](https://python.langchain.com/api_reference/core/vectorstores/langchain_core.vectorstores.base.VectorStoreRetriever.html)), which we can use in the usual way:"
  ]
 },
 {
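For context, a minimal sketch of that usage (the store, embedding model, and query below are illustrative assumptions, not part of this diff):

```python
from langchain_core.vectorstores import InMemoryVectorStore
from langchain_openai import OpenAIEmbeddings

# Build a tiny in-memory store, then expose it as a VectorStoreRetriever.
vectorstore = InMemoryVectorStore.from_texts(
    ["LangChain supports many vector stores."],
    embedding=OpenAIEmbeddings(),
)
retriever = vectorstore.as_retriever()
docs = retriever.invoke("What does LangChain support?")
```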

View File

@@ -434,6 +434,160 @@
   "fine_tuned_model.invoke(messages)"
  ]
 },
{
"cell_type": "markdown",
"id": "5d5d9793",
"metadata": {},
"source": [
"## Multimodal Inputs\n",
"\n",
"OpenAI has models that support multimodal inputs. You can pass in images or audio to these models. For more information on how to do this in LangChain, head to the [multimodal inputs](/docs/how_to/multimodal_inputs) docs.\n",
"\n",
"You can see the list of models that support different modalities in [OpenAI's documentation](https://platform.openai.com/docs/models).\n",
"\n",
"At the time of writing, the main OpenAI models you would use are:\n",
"\n",
"- Image inputs: `gpt-4o`, `gpt-4o-mini`\n",
"- Audio inputs: `gpt-4o-audio-preview`\n",
"\n",
"For an example of passing in image inputs, see the [multimodal inputs how-to guide](/docs/how_to/multimodal_inputs); a brief sketch also follows this cell.\n",
"\n",
"Below is an example of passing audio inputs to `gpt-4o-audio-preview`:"
]
},
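A hedged sketch of the image-input pattern mentioned above (the image URL is a placeholder, and the content-block shape follows the multimodal how-to guide; this cell is not part of the diff):

```python
import base64

import httpx
from langchain_openai import ChatOpenAI

llm = ChatOpenAI(model="gpt-4o-mini")

# Placeholder URL: substitute any accessible image.
image_url = "https://example.com/some-image.jpg"
image_b64 = base64.b64encode(httpx.get(image_url).content).decode()

message = (
    "human",
    [
        {"type": "text", "text": "Describe this image."},
        {
            "type": "image_url",
            "image_url": {"url": f"data:image/jpeg;base64,{image_b64}"},
        },
    ],
)
response = llm.invoke([message])
```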
{
"cell_type": "code",
"execution_count": 8,
"id": "39d08780",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"\"I'm sorry, but I can't create audio content that involves yelling. Is there anything else I can help you with?\""
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import base64\n",
"\n",
"from langchain_openai import ChatOpenAI\n",
"\n",
"llm = ChatOpenAI(\n",
" model=\"gpt-4o-audio-preview\",\n",
" temperature=0,\n",
")\n",
"\n",
"with open(\n",
" \"../../../../libs/partners/openai/tests/integration_tests/chat_models/audio_input.wav\",\n",
" \"rb\",\n",
") as f:\n",
" # b64 encode it\n",
" audio = f.read()\n",
" audio_b64 = base64.b64encode(audio).decode()\n",
"\n",
"\n",
"output_message = llm.invoke(\n",
" [\n",
" (\n",
" \"human\",\n",
" [\n",
" {\"type\": \"text\", \"text\": \"Transcribe the following:\"},\n",
" # the audio clip says \"I'm sorry, but I can't create...\"\n",
" {\n",
" \"type\": \"input_audio\",\n",
" \"input_audio\": {\"data\": audio_b64, \"format\": \"wav\"},\n",
" },\n",
" ],\n",
" ),\n",
" ]\n",
")\n",
"output_message.content"
]
},
{
"cell_type": "markdown",
"id": "feb4a499",
"metadata": {},
"source": [
"## Audio Generation (Preview)\n",
"\n",
":::info\n",
"Requires `langchain-openai>=0.2.3`\n",
":::\n",
"\n",
"OpenAI has a new [audio generation feature](https://platform.openai.com/docs/guides/audio?audio-generation-quickstart-example=audio-out) that allows you to use audio inputs and outputs with the `gpt-4o-audio-preview` model."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "f67a2cac",
"metadata": {},
"outputs": [],
"source": [
"from langchain_openai import ChatOpenAI\n",
"\n",
"llm = ChatOpenAI(\n",
" model=\"gpt-4o-audio-preview\",\n",
" temperature=0,\n",
" model_kwargs={\n",
" \"modalities\": [\"text\", \"audio\"],\n",
" \"audio\": {\"voice\": \"alloy\", \"format\": \"wav\"},\n",
" },\n",
")\n",
"\n",
"output_message = llm.invoke(\n",
" [\n",
" (\"human\", \"Are you made by OpenAI? Just answer yes or no\"),\n",
" ]\n",
")"
]
},
{
"cell_type": "markdown",
"id": "b7dd4e8b",
"metadata": {},
"source": [
"`output_message.additional_kwargs['audio']` will contain a dictionary like\n",
"```python\n",
"{\n",
" 'data': '<audio data b64-encoded>',\n",
" 'expires_at': 1729268602,\n",
" 'id': 'audio_67127d6a44348190af62c1530ef0955a',\n",
" 'transcript': 'Yes.'\n",
"}\n",
"```\n",
"and the format will be what was passed in `model_kwargs['audio']['format']`.\n",
"\n",
"We can also pass this message, with its audio data, back to the model as part of a message history, as long as OpenAI's `expires_at` timestamp has not been reached.\n",
"\n",
":::note\n",
"Output audio is stored under the `audio` key in `AIMessage.additional_kwargs`, while audio input content blocks use the `input_audio` type and key in `HumanMessage.content` lists.\n",
"\n",
"For more information, see OpenAI's [audio docs](https://platform.openai.com/docs/guides/audio).\n",
":::"
]
},
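A minimal sketch for consuming that payload (assuming the `output_message` produced above): decode the base64 `data` and write it out in the requested format.

```python
import base64

audio = output_message.additional_kwargs["audio"]
# The bytes are in the format requested via model_kwargs["audio"]["format"].
wav_bytes = base64.b64decode(audio["data"])
with open("reply.wav", "wb") as f:
    f.write(wav_bytes)
print(audio["transcript"])  # "Yes."
```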
{
"cell_type": "code",
"execution_count": 8,
"id": "f5ae473d",
"metadata": {},
"outputs": [],
"source": [
"history = [\n",
" (\"human\", \"Are you made by OpenAI? Just answer yes or no\"),\n",
" output_message,\n",
" (\"human\", \"And what is your name? Just give your name.\"),\n",
"]\n",
"second_output_message = llm.invoke(history)"
]
},
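As before, the second reply's text is available under `additional_kwargs` (a usage note, not part of the diff):

```python
print(second_output_message.additional_kwargs["audio"]["transcript"])
```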
{ {
"cell_type": "markdown", "cell_type": "markdown",
"id": "a796d728-971b-408b-88d5-440015bbb941", "id": "a796d728-971b-408b-88d5-440015bbb941",
@@ -447,7 +601,7 @@
 ],
 "metadata": {
  "kernelspec": {
-  "display_name": "Python 3 (ipykernel)",
+  "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },

View File

@@ -85,6 +85,10 @@
   {
    "source": "/docs/integrations/platforms/:path((?:anthropic|aws|google|huggingface|microsoft|openai)?/?)*",
    "destination": "/docs/integrations/providers/:path*"
+  },
+  {
+   "source": "/docs/troubleshooting/errors/:path((?:GRAPH_RECURSION_LIMIT|INVALID_CONCURRENT_GRAPH_UPDATE|INVALID_GRAPH_NODE_RETURN_VALUE|MULTIPLE_SUBGRAPHS)/?)*",
+   "destination": "https://langchain-ai.github.io/langgraph/troubleshooting/errors/:path*"
   }
  ]
 }
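The added rule redirects only the four named LangGraph error pages. A rough Python approximation of the path pattern (assumption: Vercel's path-to-regexp syntax behaves like this alternation for these simple paths):

```python
import re

pattern = re.compile(
    r"^/docs/troubleshooting/errors/"
    r"(GRAPH_RECURSION_LIMIT|INVALID_CONCURRENT_GRAPH_UPDATE|"
    r"INVALID_GRAPH_NODE_RETURN_VALUE|MULTIPLE_SUBGRAPHS)/?$"
)

assert pattern.match("/docs/troubleshooting/errors/GRAPH_RECURSION_LIMIT")
assert pattern.match("/docs/troubleshooting/errors/MULTIPLE_SUBGRAPHS/")
assert not pattern.match("/docs/troubleshooting/errors/SOME_OTHER_ERROR")
```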

View File

@@ -173,21 +173,23 @@ class GradientLLM(BaseLLM):
                 "content-type": "application/json",
             },
             json=dict(
-                samples=tuple(
-                    {
-                        "inputs": input,
-                    }
-                    for input in inputs
-                )
-                if multipliers is None
-                else tuple(
-                    {
-                        "inputs": input,
-                        "fineTuningParameters": {
-                            "multiplier": multiplier,
-                        },
-                    }
-                    for input, multiplier in zip(inputs, multipliers)
-                ),
+                samples=(
+                    tuple(
+                        {
+                            "inputs": input,
+                        }
+                        for input in inputs
+                    )
+                    if multipliers is None
+                    else tuple(
+                        {
+                            "inputs": input,
+                            "fineTuningParameters": {
+                                "multiplier": multiplier,
+                            },
+                        }
+                        for input, multiplier in zip(inputs, multipliers)
+                    )
+                ),
             ),
         )
@@ -338,9 +340,11 @@ class GradientLLM(BaseLLM):
     ) -> LLMResult:
         """Run the LLM on the given prompt and input."""
         generations = []
-        for generation in asyncio.gather(
-            [self._acall(prompt, stop=stop, run_manager=run_manager, **kwargs)]
-            for prompt in prompts
-        ):
+        for generation in await asyncio.gather(
+            *[
+                self._acall(prompt, stop=stop, run_manager=run_manager, **kwargs)
+                for prompt in prompts
+            ]
+        ):
             generations.append([Generation(text=generation)])
         return LLMResult(generations=generations)
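The fix matters because `asyncio.gather` takes awaitables as separate positional arguments and must itself be awaited; the old code passed a single generator and never awaited the result. A standalone sketch of the corrected pattern (the `fake_acall` helper is hypothetical):

```python
import asyncio


async def fake_acall(prompt: str) -> str:
    # Stand-in for self._acall; just echoes the prompt.
    await asyncio.sleep(0)
    return f"completion for {prompt!r}"


async def main() -> None:
    prompts = ["a", "b", "c"]
    # Unpack one coroutine per prompt and await them all concurrently.
    generations = await asyncio.gather(*[fake_acall(p) for p in prompts])
    print(generations)


asyncio.run(main())
```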

View File

@@ -64,7 +64,7 @@ class GoogleTrendsAPIWrapper(BaseModel):
             "q": query,
         }
-        total_results = []
+        total_results: Any = []
         client = self.serp_search_engine(params)
         client_dict = client.get_dict()
         total_results = (
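The `Any` annotation is there because `total_results` is rebound a few lines later, which mypy would otherwise flag. A compact illustration of the idea (the values below are made up):

```python
from typing import Any

# With a bare `total_results = []`, mypy reports "Need type annotation";
# annotating as Any also permits the later rebinding to a non-list value.
total_results: Any = []
total_results = {"interest_over_time": []}  # hypothetical later value
```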

View File

@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.

 [[package]]
 name = "aiohappyeyeballs"
@@ -1801,7 +1801,7 @@ files = [

 [[package]]
 name = "langchain"
-version = "0.3.3"
+version = "0.3.4"
 description = "Building applications with LLMs through composability"
 optional = false
 python-versions = ">=3.9,<4.0"
@@ -1811,7 +1811,7 @@ develop = true
 [package.dependencies]
 aiohttp = "^3.8.3"
 async-timeout = {version = "^4.0.0", markers = "python_version < \"3.11\""}
-langchain-core = "^0.3.10"
+langchain-core = "^0.3.12"
 langchain-text-splitters = "^0.3.0"
 langsmith = "^0.1.17"
 numpy = [
@@ -1830,7 +1830,7 @@ url = "../langchain"

 [[package]]
 name = "langchain-core"
-version = "0.3.10"
+version = "0.3.12"
 description = "Building applications with LLMs through composability"
 optional = false
 python-versions = ">=3.9,<4.0"
@@ -2146,38 +2146,43 @@ typing-extensions = {version = ">=4.1.0", markers = "python_version < \"3.11\""}

 [[package]]
 name = "mypy"
-version = "1.11.2"
+version = "1.12.0"
 description = "Optional static typing for Python"
 optional = false
 python-versions = ">=3.8"
 files = [
-    {file = "mypy-1.11.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d42a6dd818ffce7be66cce644f1dff482f1d97c53ca70908dff0b9ddc120b77a"},
-    {file = "mypy-1.11.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:801780c56d1cdb896eacd5619a83e427ce436d86a3bdf9112527f24a66618fef"},
-    {file = "mypy-1.11.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:41ea707d036a5307ac674ea172875f40c9d55c5394f888b168033177fce47383"},
-    {file = "mypy-1.11.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:6e658bd2d20565ea86da7d91331b0eed6d2eee22dc031579e6297f3e12c758c8"},
-    {file = "mypy-1.11.2-cp310-cp310-win_amd64.whl", hash = "sha256:478db5f5036817fe45adb7332d927daa62417159d49783041338921dcf646fc7"},
-    {file = "mypy-1.11.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:75746e06d5fa1e91bfd5432448d00d34593b52e7e91a187d981d08d1f33d4385"},
-    {file = "mypy-1.11.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:a976775ab2256aadc6add633d44f100a2517d2388906ec4f13231fafbb0eccca"},
-    {file = "mypy-1.11.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:cd953f221ac1379050a8a646585a29574488974f79d8082cedef62744f0a0104"},
-    {file = "mypy-1.11.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:57555a7715c0a34421013144a33d280e73c08df70f3a18a552938587ce9274f4"},
-    {file = "mypy-1.11.2-cp311-cp311-win_amd64.whl", hash = "sha256:36383a4fcbad95f2657642a07ba22ff797de26277158f1cc7bd234821468b1b6"},
-    {file = "mypy-1.11.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:e8960dbbbf36906c5c0b7f4fbf2f0c7ffb20f4898e6a879fcf56a41a08b0d318"},
-    {file = "mypy-1.11.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:06d26c277962f3fb50e13044674aa10553981ae514288cb7d0a738f495550b36"},
-    {file = "mypy-1.11.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:6e7184632d89d677973a14d00ae4d03214c8bc301ceefcdaf5c474866814c987"},
-    {file = "mypy-1.11.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:3a66169b92452f72117e2da3a576087025449018afc2d8e9bfe5ffab865709ca"},
-    {file = "mypy-1.11.2-cp312-cp312-win_amd64.whl", hash = "sha256:969ea3ef09617aff826885a22ece0ddef69d95852cdad2f60c8bb06bf1f71f70"},
-    {file = "mypy-1.11.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:37c7fa6121c1cdfcaac97ce3d3b5588e847aa79b580c1e922bb5d5d2902df19b"},
-    {file = "mypy-1.11.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:4a8a53bc3ffbd161b5b2a4fff2f0f1e23a33b0168f1c0778ec70e1a3d66deb86"},
-    {file = "mypy-1.11.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:2ff93107f01968ed834f4256bc1fc4475e2fecf6c661260066a985b52741ddce"},
-    {file = "mypy-1.11.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:edb91dded4df17eae4537668b23f0ff6baf3707683734b6a818d5b9d0c0c31a1"},
-    {file = "mypy-1.11.2-cp38-cp38-win_amd64.whl", hash = "sha256:ee23de8530d99b6db0573c4ef4bd8f39a2a6f9b60655bf7a1357e585a3486f2b"},
-    {file = "mypy-1.11.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:801ca29f43d5acce85f8e999b1e431fb479cb02d0e11deb7d2abb56bdaf24fd6"},
-    {file = "mypy-1.11.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:af8d155170fcf87a2afb55b35dc1a0ac21df4431e7d96717621962e4b9192e70"},
-    {file = "mypy-1.11.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:f7821776e5c4286b6a13138cc935e2e9b6fde05e081bdebf5cdb2bb97c9df81d"},
-    {file = "mypy-1.11.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:539c570477a96a4e6fb718b8d5c3e0c0eba1f485df13f86d2970c91f0673148d"},
-    {file = "mypy-1.11.2-cp39-cp39-win_amd64.whl", hash = "sha256:3f14cd3d386ac4d05c5a39a51b84387403dadbd936e17cb35882134d4f8f0d24"},
-    {file = "mypy-1.11.2-py3-none-any.whl", hash = "sha256:b499bc07dbdcd3de92b0a8b29fdf592c111276f6a12fe29c30f6c417dd546d12"},
-    {file = "mypy-1.11.2.tar.gz", hash = "sha256:7f9993ad3e0ffdc95c2a14b66dee63729f021968bff8ad911867579c65d13a79"},
+    {file = "mypy-1.12.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:4397081e620dc4dc18e2f124d5e1d2c288194c2c08df6bdb1db31c38cd1fe1ed"},
+    {file = "mypy-1.12.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:684a9c508a283f324804fea3f0effeb7858eb03f85c4402a967d187f64562469"},
+    {file = "mypy-1.12.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:6cabe4cda2fa5eca7ac94854c6c37039324baaa428ecbf4de4567279e9810f9e"},
+    {file = "mypy-1.12.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:060a07b10e999ac9e7fa249ce2bdcfa9183ca2b70756f3bce9df7a92f78a3c0a"},
+    {file = "mypy-1.12.0-cp310-cp310-win_amd64.whl", hash = "sha256:0eff042d7257f39ba4ca06641d110ca7d2ad98c9c1fb52200fe6b1c865d360ff"},
+    {file = "mypy-1.12.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:4b86de37a0da945f6d48cf110d5206c5ed514b1ca2614d7ad652d4bf099c7de7"},
+    {file = "mypy-1.12.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:20c7c5ce0c1be0b0aea628374e6cf68b420bcc772d85c3c974f675b88e3e6e57"},
+    {file = "mypy-1.12.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a64ee25f05fc2d3d8474985c58042b6759100a475f8237da1f4faf7fcd7e6309"},
+    {file = "mypy-1.12.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:faca7ab947c9f457a08dcb8d9a8664fd438080e002b0fa3e41b0535335edcf7f"},
+    {file = "mypy-1.12.0-cp311-cp311-win_amd64.whl", hash = "sha256:5bc81701d52cc8767005fdd2a08c19980de9ec61a25dbd2a937dfb1338a826f9"},
+    {file = "mypy-1.12.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:8462655b6694feb1c99e433ea905d46c478041a8b8f0c33f1dab00ae881b2164"},
+    {file = "mypy-1.12.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:923ea66d282d8af9e0f9c21ffc6653643abb95b658c3a8a32dca1eff09c06475"},
+    {file = "mypy-1.12.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:1ebf9e796521f99d61864ed89d1fb2926d9ab6a5fab421e457cd9c7e4dd65aa9"},
+    {file = "mypy-1.12.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:e478601cc3e3fa9d6734d255a59c7a2e5c2934da4378f3dd1e3411ea8a248642"},
+    {file = "mypy-1.12.0-cp312-cp312-win_amd64.whl", hash = "sha256:c72861b7139a4f738344faa0e150834467521a3fba42dc98264e5aa9507dd601"},
+    {file = "mypy-1.12.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:52b9e1492e47e1790360a43755fa04101a7ac72287b1a53ce817f35899ba0521"},
+    {file = "mypy-1.12.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:48d3e37dd7d9403e38fa86c46191de72705166d40b8c9f91a3de77350daa0893"},
+    {file = "mypy-1.12.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:2f106db5ccb60681b622ac768455743ee0e6a857724d648c9629a9bd2ac3f721"},
+    {file = "mypy-1.12.0-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:233e11b3f73ee1f10efada2e6da0f555b2f3a5316e9d8a4a1224acc10e7181d3"},
+    {file = "mypy-1.12.0-cp313-cp313-win_amd64.whl", hash = "sha256:4ae8959c21abcf9d73aa6c74a313c45c0b5a188752bf37dace564e29f06e9c1b"},
+    {file = "mypy-1.12.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:eafc1b7319b40ddabdc3db8d7d48e76cfc65bbeeafaa525a4e0fa6b76175467f"},
+    {file = "mypy-1.12.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:9b9ce1ad8daeb049c0b55fdb753d7414260bad8952645367e70ac91aec90e07e"},
+    {file = "mypy-1.12.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:bfe012b50e1491d439172c43ccb50db66d23fab714d500b57ed52526a1020bb7"},
+    {file = "mypy-1.12.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:2c40658d4fa1ab27cb53d9e2f1066345596af2f8fe4827defc398a09c7c9519b"},
+    {file = "mypy-1.12.0-cp38-cp38-win_amd64.whl", hash = "sha256:dee78a8b9746c30c1e617ccb1307b351ded57f0de0d287ca6276378d770006c0"},
+    {file = "mypy-1.12.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:6b5df6c8a8224f6b86746bda716bbe4dbe0ce89fd67b1fa4661e11bfe38e8ec8"},
+    {file = "mypy-1.12.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:5feee5c74eb9749e91b77f60b30771563327329e29218d95bedbe1257e2fe4b0"},
+    {file = "mypy-1.12.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:77278e8c6ffe2abfba6db4125de55f1024de9a323be13d20e4f73b8ed3402bd1"},
+    {file = "mypy-1.12.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:dcfb754dea911039ac12434d1950d69a2f05acd4d56f7935ed402be09fad145e"},
+    {file = "mypy-1.12.0-cp39-cp39-win_amd64.whl", hash = "sha256:06de0498798527451ffb60f68db0d368bd2bae2bbfb5237eae616d4330cc87aa"},
+    {file = "mypy-1.12.0-py3-none-any.whl", hash = "sha256:fd313226af375d52e1e36c383f39bf3836e1f192801116b31b090dfcd3ec5266"},
+    {file = "mypy-1.12.0.tar.gz", hash = "sha256:65a22d87e757ccd95cbbf6f7e181e6caa87128255eb2b6be901bb71b26d8a99d"},
 ]

 [package.dependencies]
@@ -4552,4 +4557,4 @@ type = ["pytest-mypy"]
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.9,<4.0"
-content-hash = "ceb97d5b463eacb11b6c7d9c572101e7f25c3f6b008178c6e9a273ed7e02920b"
+content-hash = "5c436a9ba9a1695c5c456c1ad8a81c9772a2ba0248624278c6cb606dd019b338"

View File

@@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"

 [tool.poetry]
 name = "langchain-community"
-version = "0.3.2"
+version = "0.3.3"
 description = "Community contributed LangChain integrations."
 authors = []
 license = "MIT"
@@ -33,8 +33,8 @@ ignore-words-list = "momento,collison,ned,foor,reworkd,parth,whats,aapply,mysogy

 [tool.poetry.dependencies]
 python = ">=3.9,<4.0"
-langchain-core = "^0.3.10"
-langchain = "^0.3.3"
+langchain-core = "^0.3.12"
+langchain = "^0.3.4"
 SQLAlchemy = ">=1.4,<3"
 requests = "^2"
 PyYAML = ">=5.3"
@@ -130,7 +130,7 @@ jupyter = "^1.0.0"
 setuptools = "^67.6.1"

 [tool.poetry.group.typing.dependencies]
-mypy = "^1.10"
+mypy = "^1.12"
 types-pyyaml = "^6.0.12.2"
 types-requests = "^2.28.11.5"
 types-toml = "^0.10.8.1"

File diff suppressed because it is too large.

View File

@@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"

 [tool.poetry]
 name = "langchain"
-version = "0.3.3"
+version = "0.3.4"
 description = "Building applications with LLMs through composability"
 authors = []
 license = "MIT"
@@ -33,7 +33,7 @@ langchain-server = "langchain.server:main"

 [tool.poetry.dependencies]
 python = ">=3.9,<4.0"
-langchain-core = "^0.3.10"
+langchain-core = "^0.3.12"
 langchain-text-splitters = "^0.3.0"
 langsmith = "^0.1.17"
 pydantic = "^2.7.4"

View File

@@ -1 +0,0 @@
__pycache__

View File

@@ -1,21 +0,0 @@
MIT License
Copyright (c) 2023 LangChain, Inc.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@@ -1,61 +0,0 @@
.PHONY: all format lint test tests integration_tests docker_tests help extended_tests
# Default target executed when no arguments are given to make.
all: help
# Define a variable for the test file path.
TEST_FILE ?= tests/unit_tests/
test:
poetry run pytest $(TEST_FILE)
tests:
poetry run pytest $(TEST_FILE)
test_watch:
poetry run ptw --snapshot-update --now . -- -vv $(TEST_FILE)
######################
# LINTING AND FORMATTING
######################
# Define a variable for Python and notebook files.
PYTHON_FILES=.
MYPY_CACHE=.mypy_cache
lint format: PYTHON_FILES=.
lint_diff format_diff: PYTHON_FILES=$(shell git diff --relative=libs/partners/azure --name-only --diff-filter=d master | grep -E '\.py$$|\.ipynb$$')
lint_package: PYTHON_FILES=langchain_azure_dynamic_sessions
lint_tests: PYTHON_FILES=tests
lint_tests: MYPY_CACHE=.mypy_cache_test
lint lint_diff lint_package lint_tests:
[ "$(PYTHON_FILES)" = "" ] || poetry run ruff check $(PYTHON_FILES)
[ "$(PYTHON_FILES)" = "" ] || poetry run ruff format $(PYTHON_FILES) --diff
[ "$(PYTHON_FILES)" = "" ] || mkdir -p $(MYPY_CACHE) && poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE)
format format_diff:
[ "$(PYTHON_FILES)" = "" ] || poetry run ruff format $(PYTHON_FILES)
[ "$(PYTHON_FILES)" = "" ] || poetry run ruff check --select I --fix $(PYTHON_FILES)
spell_check:
poetry run codespell --toml pyproject.toml
spell_fix:
poetry run codespell --toml pyproject.toml -w
check_imports: $(shell find langchain_azure_dynamic_sessions -name '*.py')
poetry run python ./scripts/check_imports.py $^
######################
# HELP
######################
help:
@echo '----'
@echo 'check_imports - check imports'
@echo 'format - run code formatters'
@echo 'lint - run linters'
@echo 'test - run unit tests'
@echo 'tests - run unit tests'
@echo 'test TEST_FILE=<test_file> - run all tests in file'

View File

@@ -1,36 +1,3 @@
-# langchain-azure-dynamic-sessions
-
-This package contains the LangChain integration for Azure Container Apps dynamic sessions. You can use it to add a secure and scalable code interpreter to your agents.
-
-## Installation
-
-```bash
-pip install -U langchain-azure-dynamic-sessions
-```
-
-## Usage
-
-You first need to create an Azure Container Apps session pool and obtain its management endpoint. Then you can use the `SessionsPythonREPLTool` tool to give your agent the ability to execute Python code.
-
-```python
-from langchain_azure_dynamic_sessions import SessionsPythonREPLTool
-
-# get the management endpoint from the session pool in the Azure portal
-tool = SessionsPythonREPLTool(pool_management_endpoint=POOL_MANAGEMENT_ENDPOINT)
-
-prompt = hub.pull("hwchase17/react")
-tools=[tool]
-react_agent = create_react_agent(
-    llm=llm,
-    tools=tools,
-    prompt=prompt,
-)
-
-react_agent_executor = AgentExecutor(agent=react_agent, tools=tools, verbose=True, handle_parsing_errors=True)
-
-react_agent_executor.invoke({"input": "What is the current time in Vancouver, Canada?"})
-```
-
-By default, the tool uses `DefaultAzureCredential` to authenticate with Azure. If you're using a user-assigned managed identity, you must set the `AZURE_CLIENT_ID` environment variable to the ID of the managed identity.
+This package has moved!
+
+https://github.com/langchain-ai/langchain-azure/tree/main/libs/azure-dynamic-sessions

View File

@@ -1,7 +0,0 @@
"""This package provides tools for managing dynamic sessions in Azure."""
from langchain_azure_dynamic_sessions.tools.sessions import SessionsPythonREPLTool
__all__ = [
"SessionsPythonREPLTool",
]

View File

@@ -1,7 +0,0 @@
"""This package provides tools for managing dynamic sessions in Azure."""
from langchain_azure_dynamic_sessions.tools.sessions import SessionsPythonREPLTool
__all__ = [
"SessionsPythonREPLTool",
]

View File

@@ -1,314 +0,0 @@
"""This is the Azure Dynamic Sessions module.
This module provides the SessionsPythonREPLTool class for
managing dynamic sessions in Azure.
"""
import importlib.metadata
import json
import os
import re
import urllib
from copy import deepcopy
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from io import BytesIO
from typing import Any, BinaryIO, Callable, List, Literal, Optional, Tuple
from uuid import uuid4
import requests
from azure.core.credentials import AccessToken
from azure.identity import DefaultAzureCredential
from langchain_core.tools import BaseTool
try:
_package_version = importlib.metadata.version("langchain-azure-dynamic-sessions")
except importlib.metadata.PackageNotFoundError:
_package_version = "0.0.0"
USER_AGENT = f"langchain-azure-dynamic-sessions/{_package_version} (Language=Python)"
def _access_token_provider_factory() -> Callable[[], Optional[str]]:
"""Factory function for creating an access token provider function.
Returns:
Callable[[], Optional[str]]: The access token provider function
"""
access_token: Optional[AccessToken] = None
def access_token_provider() -> Optional[str]:
nonlocal access_token
if access_token is None or datetime.fromtimestamp(
access_token.expires_on, timezone.utc
) < datetime.now(timezone.utc) + timedelta(minutes=5):
credential = DefaultAzureCredential()
access_token = credential.get_token("https://dynamicsessions.io/.default")
return access_token.token
return access_token_provider
def _sanitize_input(query: str) -> str:
"""Sanitize input to the python REPL.
Remove whitespace, backtick & python (if llm mistakes python console as terminal)
Args:
query: The query to sanitize
Returns:
str: The sanitized query
"""
# Removes `, whitespace & python from start
query = re.sub(r"^(\s|`)*(?i:python)?\s*", "", query)
# Removes whitespace & ` from end
query = re.sub(r"(\s|`)*$", "", query)
return query
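An illustrative check of the sanitizer's behavior, consistent with the unit tests that appear later in this diff (the fence string is built indirectly so the example itself stays readable):

```python
fence = "`" * 3
fenced = f"{fence}python\nprint('hello world')\n{fence}"

# Leading/trailing backticks, whitespace, and a "python" marker are stripped.
assert _sanitize_input(fenced) == "print('hello world')"
assert _sanitize_input("  `1 + 1`  ") == "1 + 1"
```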
@dataclass
class RemoteFileMetadata:
"""Metadata for a file in the session."""
filename: str
"""The filename relative to `/mnt/data`."""
size_in_bytes: int
"""The size of the file in bytes."""
@property
def full_path(self) -> str:
"""Get the full path of the file."""
return f"/mnt/data/{self.filename}"
@staticmethod
def from_dict(data: dict) -> "RemoteFileMetadata":
"""Create a RemoteFileMetadata object from a dictionary."""
properties = data.get("properties", {})
return RemoteFileMetadata(
filename=properties.get("filename"),
size_in_bytes=properties.get("size"),
)
class SessionsPythonREPLTool(BaseTool):
r"""Azure Dynamic Sessions tool.
Setup:
Install ``langchain-azure-dynamic-sessions`` and create a session pool, which you can do by following the instructions [here](https://learn.microsoft.com/en-us/azure/container-apps/sessions-code-interpreter?tabs=azure-cli#create-a-session-pool-with-azure-cli).
.. code-block:: bash
pip install -U langchain-azure-dynamic-sessions
.. code-block:: python
import getpass
POOL_MANAGEMENT_ENDPOINT = getpass.getpass("Enter the management endpoint of the session pool: ")
Instantiation:
.. code-block:: python
from langchain_azure_dynamic_sessions import SessionsPythonREPLTool
tool = SessionsPythonREPLTool(
pool_management_endpoint=POOL_MANAGEMENT_ENDPOINT
)
Invocation with args:
.. code-block:: python
tool.invoke("6 * 7")
.. code-block:: python
'{\\n "result": 42,\\n "stdout": "",\\n "stderr": ""\\n}'
Invocation with ToolCall:
.. code-block:: python
tool.invoke({"args": {"input":"6 * 7"}, "id": "1", "name": tool.name, "type": "tool_call"})
.. code-block:: python
'{\\n "result": 42,\\n "stdout": "",\\n "stderr": ""\\n}'
""" # noqa: E501
name: str = "Python_REPL"
description: str = (
"A Python shell. Use this to execute python commands "
"when you need to perform calculations or computations. "
"Input should be a valid python command. "
"Returns a JSON object with the result, stdout, and stderr. "
)
sanitize_input: bool = True
"""Whether to sanitize input to the python REPL."""
pool_management_endpoint: str
"""The management endpoint of the session pool. Should end with a '/'."""
access_token_provider: Callable[[], Optional[str]] = (
_access_token_provider_factory()
)
"""A function that returns the access token to use for the session pool."""
session_id: str = str(uuid4())
"""The session ID to use for the code interpreter. Defaults to a random UUID."""
response_format: Literal["content_and_artifact"] = "content_and_artifact"
def _build_url(self, path: str) -> str:
pool_management_endpoint = self.pool_management_endpoint
if not pool_management_endpoint:
raise ValueError("pool_management_endpoint is not set")
if not pool_management_endpoint.endswith("/"):
pool_management_endpoint += "/"
encoded_session_id = urllib.parse.quote(self.session_id)
query = f"identifier={encoded_session_id}&api-version=2024-02-02-preview"
query_separator = "&" if "?" in pool_management_endpoint else "?"
full_url = pool_management_endpoint + path + query_separator + query
return full_url
def execute(self, python_code: str) -> Any:
"""Execute Python code in the session."""
if self.sanitize_input:
python_code = _sanitize_input(python_code)
access_token = self.access_token_provider()
api_url = self._build_url("code/execute")
headers = {
"Authorization": f"Bearer {access_token}",
"Content-Type": "application/json",
"User-Agent": USER_AGENT,
}
body = {
"properties": {
"codeInputType": "inline",
"executionType": "synchronous",
"code": python_code,
}
}
response = requests.post(api_url, headers=headers, json=body)
response.raise_for_status()
response_json = response.json()
properties = response_json.get("properties", {})
return properties
def _run(self, python_code: str, **kwargs: Any) -> Tuple[str, dict]:
response = self.execute(python_code)
# if the result is an image, remove the base64 data
result = deepcopy(response.get("result"))
if isinstance(result, dict):
if result.get("type") == "image" and "base64_data" in result:
result.pop("base64_data")
content = json.dumps(
{
"result": result,
"stdout": response.get("stdout"),
"stderr": response.get("stderr"),
},
indent=2,
)
return content, response
def upload_file(
self,
*,
data: Optional[BinaryIO] = None,
remote_file_path: Optional[str] = None,
local_file_path: Optional[str] = None,
) -> RemoteFileMetadata:
"""Upload a file to the session.
Args:
data: The data to upload.
remote_file_path: The path to upload the file to, relative to
`/mnt/data`. If local_file_path is provided, this is defaulted
to its filename.
local_file_path: The path to the local file to upload.
Returns:
RemoteFileMetadata: The metadata for the uploaded file
"""
if data and local_file_path:
raise ValueError("data and local_file_path cannot be provided together")
if data:
file_data = data
elif local_file_path:
if not remote_file_path:
remote_file_path = os.path.basename(local_file_path)
file_data = open(local_file_path, "rb")
access_token = self.access_token_provider()
api_url = self._build_url("files/upload")
headers = {
"Authorization": f"Bearer {access_token}",
"User-Agent": USER_AGENT,
}
files = [("file", (remote_file_path, file_data, "application/octet-stream"))]
response = requests.request(
"POST", api_url, headers=headers, data={}, files=files
)
response.raise_for_status()
response_json = response.json()
return RemoteFileMetadata.from_dict(response_json["value"][0])
def download_file(
self, *, remote_file_path: str, local_file_path: Optional[str] = None
) -> BinaryIO:
"""Download a file from the session.
Args:
remote_file_path: The path to download the file from,
relative to `/mnt/data`.
local_file_path: The path to save the downloaded file to.
If not provided, the file is returned as a BufferedReader.
Returns:
BinaryIO: The data of the downloaded file.
"""
access_token = self.access_token_provider()
encoded_remote_file_path = urllib.parse.quote(remote_file_path)
api_url = self._build_url(f"files/content/{encoded_remote_file_path}")
headers = {
"Authorization": f"Bearer {access_token}",
"User-Agent": USER_AGENT,
}
response = requests.get(api_url, headers=headers)
response.raise_for_status()
if local_file_path:
with open(local_file_path, "wb") as f:
f.write(response.content)
return BytesIO(response.content)
def list_files(self) -> List[RemoteFileMetadata]:
"""List the files in the session.
Returns:
list[RemoteFileMetadata]: The metadata for the files in the session
"""
access_token = self.access_token_provider()
api_url = self._build_url("files")
headers = {
"Authorization": f"Bearer {access_token}",
"User-Agent": USER_AGENT,
}
response = requests.get(api_url, headers=headers)
response.raise_for_status()
response_json = response.json()
return [RemoteFileMetadata.from_dict(entry) for entry in response_json["value"]]

File diff suppressed because it is too large.

View File

@@ -1,115 +0,0 @@
[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"
[tool.poetry]
name = "langchain-azure-dynamic-sessions"
version = "0.2.0"
description = "An integration package connecting Azure Container Apps dynamic sessions and LangChain"
authors = []
readme = "README.md"
repository = "https://github.com/langchain-ai/langchain"
license = "MIT"
[tool.mypy]
disallow_untyped_defs = "True"
[tool.poetry.urls]
"Source Code" = "https://github.com/langchain-ai/langchain/tree/master/libs/partners/azure-dynamic-sessions"
"Release Notes" = "https://github.com/langchain-ai/langchain/releases?q=tag%3A%22langchain-azure-dynamic-sessions%3D%3D0%22&expanded=true"
[tool.poetry.dependencies]
python = ">=3.9,<4.0"
langchain-core = "^0.3.0"
azure-identity = "^1.16.0"
requests = "^2.31.0"
[tool.ruff.lint]
select = ["E", "F", "I", "D"]
[tool.coverage.run]
omit = ["tests/*"]
[tool.pytest.ini_options]
addopts = "--snapshot-warn-unused --strict-markers --strict-config --durations=5"
markers = [
"requires: mark tests as requiring a specific library",
"compile: mark placeholder test used to compile integration tests without running them",
]
asyncio_mode = "auto"
[tool.poetry.group.test]
optional = true
[tool.poetry.group.test_integration]
optional = true
[tool.poetry.group.codespell]
optional = true
[tool.poetry.group.lint]
optional = true
[tool.poetry.group.dev]
optional = true
[tool.ruff.lint.pydocstyle]
convention = "google"
[tool.ruff.lint.per-file-ignores]
"tests/**" = ["D"]
[tool.poetry.group.test.dependencies]
pytest = "^7.3.0"
freezegun = "^1.2.2"
pytest-mock = "^3.10.0"
syrupy = "^4.0.2"
pytest-watcher = "^0.3.4"
pytest-asyncio = "^0.21.1"
python-dotenv = "^1.0.1"
# TODO: hack to fix 3.9 builds
cffi = [
{ version = "<1.17.1", python = "<3.10" },
{ version = "*", python = ">=3.10" },
]
[tool.poetry.group.test_integration.dependencies]
pytest = "^7.3.0"
python-dotenv = "^1.0.1"
[tool.poetry.group.codespell.dependencies]
codespell = "^2.2.0"
[tool.poetry.group.lint.dependencies]
ruff = "^0.5"
python-dotenv = "^1.0.1"
pytest = "^7.3.0"
# TODO: hack to fix 3.9 builds
cffi = [
{ version = "<1.17.1", python = "<3.10" },
{ version = "*", python = ">=3.10" },
]
[tool.poetry.group.dev.dependencies]
ipykernel = "^6.29.4"
langchainhub = "^0.1.15"
[tool.poetry.group.typing.dependencies]
mypy = "^1.10"
types-requests = "^2.31.0.20240406"
[tool.poetry.group.test.dependencies.langchain-core]
path = "../../core"
develop = true
[tool.poetry.group.dev.dependencies.langchain-core]
path = "../../core"
develop = true
[tool.poetry.group.dev.dependencies.langchain-openai]
path = "../openai"
develop = true
[tool.poetry.group.typing.dependencies.langchain-core]
path = "../../core"
develop = true

View File

@@ -1,19 +0,0 @@
"""This module checks for specific import statements in the codebase."""
import sys
import traceback
from importlib.machinery import SourceFileLoader
if __name__ == "__main__":
files = sys.argv[1:]
has_failure = False
for file in files:
try:
SourceFileLoader("x", file).load_module()
except Exception:
has_failure = True
print(file)
traceback.print_exc()
print()
sys.exit(1 if has_failure else 0)

View File

@@ -1,17 +0,0 @@
#!/bin/bash
set -eu
# Initialize a variable to keep track of errors
errors=0
# make sure not importing from langchain or langchain_experimental
git --no-pager grep '^from langchain\.' . && errors=$((errors+1))
git --no-pager grep '^from langchain_experimental\.' . && errors=$((errors+1))
# Decide on an exit status based on the errors
if [ "$errors" -gt 0 ]; then
exit 1
else
exit 0
fi

View File

@@ -1,7 +0,0 @@
import pytest # type: ignore[import-not-found]
@pytest.mark.compile
def test_placeholder() -> None:
"""Used for compiling integration tests without running any real tests."""
pass

View File

@@ -1,68 +0,0 @@
import json
import os
from io import BytesIO
import dotenv # type: ignore[import-not-found]
from langchain_azure_dynamic_sessions import SessionsPythonREPLTool
dotenv.load_dotenv()
POOL_MANAGEMENT_ENDPOINT = os.getenv("AZURE_DYNAMIC_SESSIONS_POOL_MANAGEMENT_ENDPOINT")
TEST_DATA_PATH = os.path.join(os.path.dirname(__file__), "data", "testdata.txt")
TEST_DATA_CONTENT = open(TEST_DATA_PATH, "rb").read()
def test_end_to_end() -> None:
tool = SessionsPythonREPLTool(pool_management_endpoint=POOL_MANAGEMENT_ENDPOINT) # type: ignore[arg-type]
result = tool.run("print('hello world')\n1 + 1")
assert json.loads(result) == {
"result": 2,
"stdout": "hello world\n",
"stderr": "",
}
# upload file content
uploaded_file1_metadata = tool.upload_file(
remote_file_path="test1.txt", data=BytesIO(b"hello world!!!!!")
)
assert uploaded_file1_metadata.filename == "test1.txt"
assert uploaded_file1_metadata.size_in_bytes == 16
assert uploaded_file1_metadata.full_path == "/mnt/data/test1.txt"
downloaded_file1 = tool.download_file(remote_file_path="test1.txt")
assert downloaded_file1.read() == b"hello world!!!!!"
# upload file from buffer
with open(TEST_DATA_PATH, "rb") as f:
uploaded_file2_metadata = tool.upload_file(remote_file_path="test2.txt", data=f)
assert uploaded_file2_metadata.filename == "test2.txt"
downloaded_file2 = tool.download_file(remote_file_path="test2.txt")
assert downloaded_file2.read() == TEST_DATA_CONTENT
# upload file from disk, specifying remote file path
uploaded_file3_metadata = tool.upload_file(
remote_file_path="test3.txt", local_file_path=TEST_DATA_PATH
)
assert uploaded_file3_metadata.filename == "test3.txt"
downloaded_file3 = tool.download_file(remote_file_path="test3.txt")
assert downloaded_file3.read() == TEST_DATA_CONTENT
# upload file from disk, without specifying remote file path
uploaded_file4_metadata = tool.upload_file(local_file_path=TEST_DATA_PATH)
assert uploaded_file4_metadata.filename == os.path.basename(TEST_DATA_PATH)
downloaded_file4 = tool.download_file(
remote_file_path=uploaded_file4_metadata.filename
)
assert downloaded_file4.read() == TEST_DATA_CONTENT
# list files
remote_files_metadata = tool.list_files()
assert len(remote_files_metadata) == 4
remote_file_paths = [metadata.filename for metadata in remote_files_metadata]
expected_filenames = [
"test1.txt",
"test2.txt",
"test3.txt",
os.path.basename(TEST_DATA_PATH),
]
assert set(remote_file_paths) == set(expected_filenames)

View File

@@ -1,9 +0,0 @@
from langchain_azure_dynamic_sessions import __all__
EXPECTED_ALL = [
"SessionsPythonREPLTool",
]
def test_all_imports() -> None:
assert sorted(EXPECTED_ALL) == sorted(__all__)

View File

@@ -1,208 +0,0 @@
import json
import re
import time
from unittest import mock
from urllib.parse import parse_qs, urlparse
from azure.core.credentials import AccessToken
from langchain_azure_dynamic_sessions import SessionsPythonREPLTool
from langchain_azure_dynamic_sessions.tools.sessions import (
_access_token_provider_factory,
)
POOL_MANAGEMENT_ENDPOINT = "https://westus2.dynamicsessions.io/subscriptions/00000000-0000-0000-0000-000000000000/resourceGroups/sessions-rg/sessionPools/my-pool"
def test_default_access_token_provider_returns_token() -> None:
access_token_provider = _access_token_provider_factory()
with mock.patch(
"azure.identity.DefaultAzureCredential.get_token"
) as mock_get_token:
mock_get_token.return_value = AccessToken("token_value", 0)
access_token = access_token_provider()
assert access_token == "token_value"
def test_default_access_token_provider_returns_cached_token() -> None:
access_token_provider = _access_token_provider_factory()
with mock.patch(
"azure.identity.DefaultAzureCredential.get_token"
) as mock_get_token:
mock_get_token.return_value = AccessToken(
"token_value", int(time.time() + 1000)
)
access_token = access_token_provider()
assert access_token == "token_value"
assert mock_get_token.call_count == 1
mock_get_token.return_value = AccessToken(
"new_token_value", int(time.time() + 1000)
)
access_token = access_token_provider()
assert access_token == "token_value"
assert mock_get_token.call_count == 1
def test_default_access_token_provider_refreshes_expiring_token() -> None:
access_token_provider = _access_token_provider_factory()
with mock.patch(
"azure.identity.DefaultAzureCredential.get_token"
) as mock_get_token:
mock_get_token.return_value = AccessToken("token_value", int(time.time() - 1))
access_token = access_token_provider()
assert access_token == "token_value"
assert mock_get_token.call_count == 1
mock_get_token.return_value = AccessToken(
"new_token_value", int(time.time() + 1000)
)
access_token = access_token_provider()
assert access_token == "new_token_value"
assert mock_get_token.call_count == 2
@mock.patch("requests.post")
@mock.patch("azure.identity.DefaultAzureCredential.get_token")
def test_code_execution_calls_api(
mock_get_token: mock.MagicMock, mock_post: mock.MagicMock
) -> None:
tool = SessionsPythonREPLTool(pool_management_endpoint=POOL_MANAGEMENT_ENDPOINT)
mock_post.return_value.json.return_value = {
"$id": "1",
"properties": {
"$id": "2",
"status": "Success",
"stdout": "hello world\n",
"stderr": "",
"result": "",
"executionTimeInMilliseconds": 33,
},
}
mock_get_token.return_value = AccessToken("token_value", int(time.time() + 1000))
result = tool.run("print('hello world')")
assert json.loads(result) == {
"result": "",
"stdout": "hello world\n",
"stderr": "",
}
api_url = f"{POOL_MANAGEMENT_ENDPOINT}/code/execute"
headers = {
"Authorization": "Bearer token_value",
"Content-Type": "application/json",
"User-Agent": mock.ANY,
}
body = {
"properties": {
"codeInputType": "inline",
"executionType": "synchronous",
"code": "print('hello world')",
}
}
mock_post.assert_called_once_with(mock.ANY, headers=headers, json=body)
called_headers = mock_post.call_args.kwargs["headers"]
assert re.match(
r"^langchain-azure-dynamic-sessions/\d+\.\d+\.\d+.* \(Language=Python\)",
called_headers["User-Agent"],
)
called_api_url = mock_post.call_args.args[0]
assert called_api_url.startswith(api_url)
@mock.patch("requests.post")
@mock.patch("azure.identity.DefaultAzureCredential.get_token")
def test_uses_specified_session_id(
mock_get_token: mock.MagicMock, mock_post: mock.MagicMock
) -> None:
tool = SessionsPythonREPLTool(
pool_management_endpoint=POOL_MANAGEMENT_ENDPOINT,
session_id="00000000-0000-0000-0000-000000000003",
)
mock_post.return_value.json.return_value = {
"$id": "1",
"properties": {
"$id": "2",
"status": "Success",
"stdout": "",
"stderr": "",
"result": "2",
"executionTimeInMilliseconds": 33,
},
}
mock_get_token.return_value = AccessToken("token_value", int(time.time() + 1000))
tool.run("1 + 1")
call_url = mock_post.call_args.args[0]
parsed_url = urlparse(call_url)
call_identifier = parse_qs(parsed_url.query)["identifier"][0]
assert call_identifier == "00000000-0000-0000-0000-000000000003"
def test_sanitizes_input() -> None:
tool = SessionsPythonREPLTool(pool_management_endpoint=POOL_MANAGEMENT_ENDPOINT)
with mock.patch("requests.post") as mock_post:
mock_post.return_value.json.return_value = {
"$id": "1",
"properties": {
"$id": "2",
"status": "Success",
"stdout": "",
"stderr": "",
"result": "",
"executionTimeInMilliseconds": 33,
},
}
tool.run("```python\nprint('hello world')\n```")
body = mock_post.call_args.kwargs["json"]
assert body["properties"]["code"] == "print('hello world')"
def test_does_not_sanitize_input() -> None:
tool = SessionsPythonREPLTool(
pool_management_endpoint=POOL_MANAGEMENT_ENDPOINT, sanitize_input=False
)
with mock.patch("requests.post") as mock_post:
mock_post.return_value.json.return_value = {
"$id": "1",
"properties": {
"$id": "2",
"status": "Success",
"stdout": "",
"stderr": "",
"result": "",
"executionTimeInMilliseconds": 33,
},
}
tool.run("```python\nprint('hello world')\n```")
body = mock_post.call_args.kwargs["json"]
assert body["properties"]["code"] == "```python\nprint('hello world')\n```"
def test_uses_custom_access_token_provider() -> None:
def custom_access_token_provider() -> str:
return "custom_token"
tool = SessionsPythonREPLTool(
pool_management_endpoint=POOL_MANAGEMENT_ENDPOINT,
access_token_provider=custom_access_token_provider,
)
with mock.patch("requests.post") as mock_post:
mock_post.return_value.json.return_value = {
"$id": "1",
"properties": {
"$id": "2",
"status": "Success",
"stdout": "",
"stderr": "",
"result": "",
"executionTimeInMilliseconds": 33,
},
}
tool.run("print('hello world')")
headers = mock_post.call_args.kwargs["headers"]
assert headers["Authorization"] == "Bearer custom_token"

View File

@@ -1 +0,0 @@
__pycache__

View File

@@ -1,21 +0,0 @@
MIT License
Copyright (c) 2024 LangChain, Inc.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@@ -1,59 +0,0 @@
.PHONY: all format lint test tests integration_tests docker_tests help extended_tests
# Default target executed when no arguments are given to make.
all: help
# Define a variable for the test file path.
TEST_FILE ?= tests/unit_tests/
integration_test integration_tests: TEST_FILE=tests/integration_tests/
test tests integration_test integration_tests:
poetry run pytest $(TEST_FILE)
test_watch:
poetry run ptw --snapshot-update --now . -- -vv $(TEST_FILE)
######################
# LINTING AND FORMATTING
######################
# Define a variable for Python and notebook files.
PYTHON_FILES=.
MYPY_CACHE=.mypy_cache
lint format: PYTHON_FILES=.
lint_diff format_diff: PYTHON_FILES=$(shell git diff --relative=libs/partners/mongodb --name-only --diff-filter=d master | grep -E '\.py$$|\.ipynb$$')
lint_package: PYTHON_FILES=langchain_mongodb
lint_tests: PYTHON_FILES=tests
lint_tests: MYPY_CACHE=.mypy_cache_test
lint lint_diff lint_package lint_tests:
[ "$(PYTHON_FILES)" = "" ] || poetry run ruff check $(PYTHON_FILES)
[ "$(PYTHON_FILES)" = "" ] || poetry run ruff format $(PYTHON_FILES) --diff
[ "$(PYTHON_FILES)" = "" ] || mkdir -p $(MYPY_CACHE) && poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE)
format format_diff:
[ "$(PYTHON_FILES)" = "" ] || poetry run ruff format $(PYTHON_FILES)
[ "$(PYTHON_FILES)" = "" ] || poetry run ruff check --select I --fix $(PYTHON_FILES)
spell_check:
poetry run codespell --toml pyproject.toml
spell_fix:
poetry run codespell --toml pyproject.toml -w
check_imports: $(shell find langchain_mongodb -name '*.py')
poetry run python ./scripts/check_imports.py $^
######################
# HELP
######################
help:
@echo '----'
@echo 'check_imports - check imports'
@echo 'format - run code formatters'
@echo 'lint - run linters'
@echo 'test - run unit tests'
@echo 'tests - run unit tests'
@echo 'test TEST_FILE=<test_file> - run all tests in file'

View File

@@ -1,39 +1,3 @@
-# langchain-mongodb
-
-# Installation
-```
-pip install -U langchain-mongodb
-```
-
-# Usage
-- See [Getting Started with the LangChain Integration](https://www.mongodb.com/docs/atlas/atlas-vector-search/ai-integrations/langchain/#get-started-with-the-langchain-integration) for a walkthrough on using your first LangChain implementation with MongoDB Atlas.
-
-## Using MongoDBAtlasVectorSearch
-```python
-from langchain_mongodb import MongoDBAtlasVectorSearch
-
-# Pull MongoDB Atlas URI from environment variables
-MONGODB_ATLAS_CLUSTER_URI = os.environ.get("MONGODB_ATLAS_CLUSTER_URI")
-
-DB_NAME = "langchain_db"
-COLLECTION_NAME = "test"
-ATLAS_VECTOR_SEARCH_INDEX_NAME = "index_name"
-
-MONGODB_COLLECTION = client[DB_NAME][COLLECTION_NAME]
-
-# Create the vector search via `from_connection_string`
-vector_search = MongoDBAtlasVectorSearch.from_connection_string(
-    MONGODB_ATLAS_CLUSTER_URI,
-    DB_NAME + "." + COLLECTION_NAME,
-    OpenAIEmbeddings(disallowed_special=()),
-    index_name=ATLAS_VECTOR_SEARCH_INDEX_NAME,
-)
-
-# Initialize MongoDB python client
-client = MongoClient(MONGODB_ATLAS_CLUSTER_URI)
-
-# Create the vector search via instantiation
-vector_search_2 = MongoDBAtlasVectorSearch(
-    collection=MONGODB_COLLECTION,
-    embeddings=OpenAIEmbeddings(disallowed_special=()),
-    index_name=ATLAS_VECTOR_SEARCH_INDEX_NAME,
-)
-```
+This package has moved!
+
+https://github.com/langchain-ai/langchain-mongodb/tree/main/libs/mongodb

View File

@@ -1,20 +0,0 @@
"""
Integrate your operational database and vector search in a single, unified,
fully managed platform with full vector database capabilities on MongoDB Atlas.
Store your operational data, metadata, and vector embeddings in our VectorStore,
MongoDBAtlasVectorSearch.
Insert into a Chain via a Vector, FullText, or Hybrid Retriever.
"""
from langchain_mongodb.cache import MongoDBAtlasSemanticCache, MongoDBCache
from langchain_mongodb.chat_message_histories import MongoDBChatMessageHistory
from langchain_mongodb.vectorstores import MongoDBAtlasVectorSearch
__all__ = [
"MongoDBAtlasVectorSearch",
"MongoDBChatMessageHistory",
"MongoDBCache",
"MongoDBAtlasSemanticCache",
]

View File

@@ -1,308 +0,0 @@
"""LangChain MongoDB Caches."""
import json
import logging
import time
from importlib.metadata import version
from typing import Any, Callable, Dict, Optional, Union
from langchain_core.caches import RETURN_VAL_TYPE, BaseCache
from langchain_core.embeddings import Embeddings
from langchain_core.load.dump import dumps
from langchain_core.load.load import loads
from langchain_core.outputs import Generation
from pymongo import MongoClient
from pymongo.collection import Collection
from pymongo.database import Database
from pymongo.driver_info import DriverInfo
from langchain_mongodb.vectorstores import MongoDBAtlasVectorSearch
logger = logging.getLogger(__file__)
class MongoDBCache(BaseCache):
"""MongoDB Atlas cache
A cache that uses MongoDB Atlas as a backend
"""
PROMPT = "prompt"
LLM = "llm"
RETURN_VAL = "return_val"
def __init__(
self,
connection_string: str,
collection_name: str = "default",
database_name: str = "default",
**kwargs: Dict[str, Any],
) -> None:
"""
Initialize Atlas Cache. Creates collection on instantiation
Args:
collection_name (str): Name of collection for cache to live.
Defaults to "default".
connection_string (str): Connection URI to MongoDB Atlas.
Defaults to "default".
database_name (str): Name of database for cache to live.
Defaults to "default".
"""
self.client = _generate_mongo_client(connection_string)
self.__database_name = database_name
self.__collection_name = collection_name
if self.__collection_name not in self.database.list_collection_names():
self.database.create_collection(self.__collection_name)
# Create an index on key and llm_string
self.collection.create_index([self.PROMPT, self.LLM])
@property
def database(self) -> Database:
"""Returns the database used to store cache values."""
return self.client[self.__database_name]
@property
def collection(self) -> Collection:
"""Returns the collection used to store cache values."""
return self.database[self.__collection_name]
def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
"""Look up based on prompt and llm_string."""
return_doc = (
self.collection.find_one(self._generate_keys(prompt, llm_string)) or {}
)
return_val = return_doc.get(self.RETURN_VAL)
return _loads_generations(return_val) if return_val else None # type: ignore
def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
"""Update cache based on prompt and llm_string."""
self.collection.update_one(
{**self._generate_keys(prompt, llm_string)},
{"$set": {self.RETURN_VAL: _dumps_generations(return_val)}},
upsert=True,
)
def _generate_keys(self, prompt: str, llm_string: str) -> Dict[str, str]:
"""Create keyed fields for caching layer"""
return {self.PROMPT: prompt, self.LLM: llm_string}
def clear(self, **kwargs: Any) -> None:
"""Clear cache that can take additional keyword arguments.
Any additional arguments will propagate as filtration criteria for
what gets deleted.
E.g.
# Delete only entries that have llm_string as "fake-model"
self.clear(llm_string="fake-model")
"""
self.collection.delete_many({**kwargs})
class MongoDBAtlasSemanticCache(BaseCache, MongoDBAtlasVectorSearch):
"""MongoDB Atlas Semantic cache.
A Cache backed by a MongoDB Atlas server with vector-store support
"""
LLM = "llm_string"
RETURN_VAL = "return_val"
def __init__(
self,
connection_string: str,
embedding: Embeddings,
collection_name: str = "default",
database_name: str = "default",
index_name: str = "default",
wait_until_ready: Optional[float] = None,
score_threshold: Optional[float] = None,
**kwargs: Dict[str, Any],
):
"""
Initialize Atlas VectorSearch Cache.
Assumes collection exists before instantiation
Args:
connection_string (str): MongoDB URI to connect to MongoDB Atlas cluster.
embedding (Embeddings): Text embedding model to use.
collection_name (str): MongoDB Collection to add the texts to.
Defaults to "default".
database_name (str): MongoDB Database where to store texts.
Defaults to "default".
index_name: Name of the Atlas Search index.
defaults to 'default'
wait_until_ready (float): Wait this time for Atlas to finish indexing
the stored text. Defaults to None.
"""
client = _generate_mongo_client(connection_string)
self.collection = client[database_name][collection_name]
self.score_threshold = score_threshold
self._wait_until_ready = wait_until_ready
super().__init__(
collection=self.collection,
embedding=embedding,
index_name=index_name,
**kwargs, # type: ignore
)
def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
"""Look up based on prompt and llm_string."""
post_filter_pipeline = (
[{"$match": {"score": {"$gte": self.score_threshold}}}]
if self.score_threshold
else None
)
search_response = self.similarity_search_with_score(
prompt,
1,
pre_filter={self.LLM: {"$eq": llm_string}},
post_filter_pipeline=post_filter_pipeline,
)
if search_response:
return_val = search_response[0][0].metadata.get(self.RETURN_VAL)
response = _loads_generations(return_val) or return_val # type: ignore
return response
return None
def update(
self,
prompt: str,
llm_string: str,
return_val: RETURN_VAL_TYPE,
wait_until_ready: Optional[float] = None,
) -> None:
"""Update cache based on prompt and llm_string."""
self.add_texts(
[prompt],
[
{
self.LLM: llm_string,
self.RETURN_VAL: _dumps_generations(return_val),
}
],
)
wait = self._wait_until_ready if wait_until_ready is None else wait_until_ready
def is_indexed() -> bool:
return self.lookup(prompt, llm_string) == return_val
if wait:
_wait_until(is_indexed, return_val, timeout=wait)
def clear(self, **kwargs: Any) -> None:
"""Clear cache that can take additional keyword arguments.
Any additional arguments will propagate as filtration criteria for
what gets deleted.
E.g.
# Delete only entries that have llm_string as "fake-model"
self.clear(llm_string="fake-model")
"""
self.collection.delete_many({**kwargs})
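# Usage sketch (illustrative): this mirrors how the integration test later in
# this diff configures the semantic cache; the URI is a placeholder and the
# embedding model is interchangeable.
#
#   from langchain_core.globals import set_llm_cache
#   from langchain_openai import OpenAIEmbeddings
#
#   set_llm_cache(
#       MongoDBAtlasSemanticCache(
#           connection_string="<YOUR-ATLAS-URI>",
#           embedding=OpenAIEmbeddings(),
#           collection_name="llm_cache",
#           database_name="langchain_test_db",
#           index_name="langchain-test-index-semantic-cache",
#           score_threshold=0.5,
#           wait_until_ready=15.0,
#       )
#   )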
def _generate_mongo_client(connection_string: str) -> MongoClient:
return MongoClient(
connection_string,
driver=DriverInfo(name="Langchain", version=version("langchain-mongodb")),
)
def _dumps_generations(generations: RETURN_VAL_TYPE) -> str:
"""
Serialization for generic RETURN_VAL_TYPE, i.e. sequence of `Generation`
Args:
generations (RETURN_VAL_TYPE): A list of language model generations.
Returns:
str: a single string representing a list of generations.
This, and "_loads_generations", are duplicated in this utility
from modules: "libs/community/langchain_community/cache.py"
This function and its counterpart rely on
the dumps/loads pair with Reviver, so are able to deal
with all subclasses of Generation.
Each item in the list can be `dumps`ed to a string,
then the whole list of strings is JSON-dumped.
"""
return json.dumps([dumps(_item) for _item in generations])
def _loads_generations(generations_str: str) -> Union[RETURN_VAL_TYPE, None]:
"""
Deserialization of a string into a generic RETURN_VAL_TYPE
(i.e. a sequence of `Generation`).
Args:
generations_str (str): A string representing a list of generations.
Returns:
RETURN_VAL_TYPE: A list of generations.
This function and its counterpart rely on
the dumps/loads pair with Reviver, so are able to deal
with all subclasses of Generation.
See `_dumps_generations`, the inverse of this function.
Compatible with the legacy cache-blob format.
Does not raise exceptions for malformed entries, just logs a warning
and returns None: the caller should be prepared for such a cache miss.
"""
try:
generations = [loads(_item_str) for _item_str in json.loads(generations_str)]
return generations
except (json.JSONDecodeError, TypeError):
# deferring the (soft) handling to after the legacy-format attempt
pass
try:
gen_dicts = json.loads(generations_str)
# not relying on `_load_generations_from_json` (which could disappear):
generations = [Generation(**generation_dict) for generation_dict in gen_dicts]
logger.warning(
f"Legacy 'Generation' cached blob encountered: '{generations_str}'"
)
return generations
except (json.JSONDecodeError, TypeError):
logger.warning(
f"Malformed/unparsable cached blob encountered: '{generations_str}'"
)
return None
def _wait_until(
predicate: Callable, success_description: Any, timeout: float = 10.0
) -> Any:
"""Wait up to 10 seconds (by default) for predicate to be true.
E.g.:
wait_until(lambda: client.primary == ('a', 1),
'connect to the primary')
If the predicate isn't true after the timeout, we raise
TimeoutError("Didn't ever connect to the primary").
Returns the predicate's first true value.
"""
start = time.time()
interval = min(float(timeout) / 100, 0.1)
while True:
retval = predicate()
if retval:
return retval
if time.time() - start > timeout:
raise TimeoutError("Didn't ever %s" % success_description)
time.sleep(interval)


@@ -1,162 +0,0 @@
import json
import logging
from typing import Dict, List, Optional
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import (
BaseMessage,
message_to_dict,
messages_from_dict,
)
from pymongo import MongoClient, errors
logger = logging.getLogger(__name__)
DEFAULT_DBNAME = "chat_history"
DEFAULT_COLLECTION_NAME = "message_store"
DEFAULT_SESSION_ID_KEY = "SessionId"
DEFAULT_HISTORY_KEY = "History"
class MongoDBChatMessageHistory(BaseChatMessageHistory):
"""Chat message history that stores history in MongoDB.
Setup:
Install ``langchain-mongodb`` python package.
.. code-block:: bash
pip install langchain-mongodb
Instantiate:
.. code-block:: python
from langchain_mongodb import MongoDBChatMessageHistory
history = MongoDBChatMessageHistory(
connection_string="mongodb://your-host:your-port/", # mongodb://localhost:27017/
session_id = "your-session-id",
)
Add and retrieve messages:
.. code-block:: python
# Add single message
history.add_message(message)
# Add batch messages
history.add_messages([message1, message2, message3, ...])
# Add human message
history.add_user_message(human_message)
# Add ai message
history.add_ai_message(ai_message)
# Retrieve messages
messages = history.messages
""" # noqa: E501
def __init__(
self,
connection_string: str,
session_id: str,
database_name: str = DEFAULT_DBNAME,
collection_name: str = DEFAULT_COLLECTION_NAME,
*,
session_id_key: str = DEFAULT_SESSION_ID_KEY,
history_key: str = DEFAULT_HISTORY_KEY,
create_index: bool = True,
history_size: Optional[int] = None,
index_kwargs: Optional[Dict] = None,
):
"""Initialize with a MongoDBChatMessageHistory instance.
Args:
connection_string: str
connection string to connect to MongoDB.
session_id: str
arbitrary key that is used to store the messages of
a single chat session.
database_name: Optional[str]
name of the database to use.
collection_name: Optional[str]
name of the collection to use.
session_id_key: Optional[str]
name of the field that stores the session id.
history_key: Optional[str]
name of the field that stores the chat history.
create_index: Optional[bool]
whether to create an index on the session id field.
history_size: Optional[int]
count of (most recent) messages to fetch from MongoDB.
index_kwargs: Optional[Dict]
additional keyword arguments to pass to the index creation.
"""
self.connection_string = connection_string
self.session_id = session_id
self.database_name = database_name
self.collection_name = collection_name
self.session_id_key = session_id_key
self.history_key = history_key
self.history_size = history_size
try:
self.client: MongoClient = MongoClient(connection_string)
except errors.ConnectionFailure as error:
logger.error(error)
self.db = self.client[database_name]
self.collection = self.db[collection_name]
if create_index:
index_kwargs = index_kwargs or {}
self.collection.create_index(self.session_id_key, **index_kwargs)
@property
def messages(self) -> List[BaseMessage]: # type: ignore
"""Retrieve the messages from MongoDB"""
try:
if self.history_size is None:
cursor = self.collection.find({self.session_id_key: self.session_id})
else:
skip_count = max(
0,
self.collection.count_documents(
{self.session_id_key: self.session_id}
)
- self.history_size,
)
cursor = self.collection.find(
{self.session_id_key: self.session_id}, skip=skip_count
)
except errors.OperationFailure as error:
logger.error(error)
cursor = None
if cursor:
items = [json.loads(document[self.history_key]) for document in cursor]
else:
items = []
messages = messages_from_dict(items)
return messages
def add_message(self, message: BaseMessage) -> None:
"""Append the message to the record in MongoDB"""
try:
self.collection.insert_one(
{
self.session_id_key: self.session_id,
self.history_key: json.dumps(message_to_dict(message)),
}
)
except errors.WriteError as err:
logger.error(err)
def clear(self) -> None:
"""Clear session memory from MongoDB"""
try:
self.collection.delete_many({self.session_id_key: self.session_id})
except errors.WriteError as err:
logger.error(err)


@@ -1,270 +0,0 @@
"""Search Index Commands"""
import logging
from time import monotonic, sleep
from typing import Any, Callable, Dict, List, Optional
from pymongo.collection import Collection
from pymongo.errors import OperationFailure
from pymongo.operations import SearchIndexModel
logger = logging.getLogger(__file__)
def _search_index_error_message() -> str:
return (
"Search index operations are not currently available on shared clusters, "
"such as MO. They require dedicated clusters >= M10. "
"You may still perform vector search. "
"You simply must set up indexes manually. Follow the instructions here: "
"https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-type/"
)
def _vector_search_index_definition(
dimensions: int,
path: str,
similarity: str,
filters: Optional[List[str]] = None,
**kwargs: Any,
) -> Dict[str, Any]:
# https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-type/
fields = [
{
"numDimensions": dimensions,
"path": path,
"similarity": similarity,
"type": "vector",
},
]
if filters:
for field in filters:
fields.append({"type": "filter", "path": field})
definition = {"fields": fields}
definition.update(kwargs)
return definition
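# For example (illustrative values), dimensions=1536, path="embedding",
# similarity="cosine", filters=["genre"] yields:
#
#   {
#       "fields": [
#           {"numDimensions": 1536, "path": "embedding",
#            "similarity": "cosine", "type": "vector"},
#           {"type": "filter", "path": "genre"},
#       ]
#   }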
def create_vector_search_index(
collection: Collection,
index_name: str,
dimensions: int,
path: str,
similarity: str,
filters: Optional[List[str]] = None,
*,
wait_until_complete: Optional[float] = None,
**kwargs: Any,
) -> None:
"""Experimental Utility function to create a vector search index
Args:
collection (Collection): MongoDB Collection
index_name (str): Name of Index
dimensions (int): Number of dimensions in embedding
path (str): field with vector embedding
similarity (str): The similarity score used for the index
filters (List[str]): Fields/paths to index to allow filtering in $vectorSearch
wait_until_complete (Optional[float]): If provided, number of seconds to wait
until search index is ready.
kwargs: Keyword arguments supplying any additional options to SearchIndexModel.
"""
logger.info("Creating Search Index %s on %s", index_name, collection.name)
try:
result = collection.create_search_index(
SearchIndexModel(
definition=_vector_search_index_definition(
dimensions=dimensions,
path=path,
similarity=similarity,
filters=filters,
**kwargs,
),
name=index_name,
type="vectorSearch",
)
)
except OperationFailure as e:
raise OperationFailure(_search_index_error_message()) from e
if wait_until_complete:
_wait_for_predicate(
predicate=lambda: _is_index_ready(collection, index_name),
err=f"{index_name=} did not complete in {wait_until_complete}!",
timeout=wait_until_complete,
)
logger.info(result)
def drop_vector_search_index(
collection: Collection,
index_name: str,
*,
wait_until_complete: Optional[float] = None,
) -> None:
"""Drop a created vector search index
Args:
collection (Collection): MongoDB Collection with index to be dropped
index_name (str): Name of the MongoDB index
wait_until_complete (Optional[float]): If provided, number of seconds to wait
until search index is ready.
"""
logger.info(
"Dropping Search Index %s from Collection: %s", index_name, collection.name
)
try:
collection.drop_search_index(index_name)
except OperationFailure as e:
if "CommandNotSupported" in str(e):
raise OperationFailure(_search_index_error_message()) from e
# else this most likely means an ongoing drop request was made so skip
if wait_until_complete:
_wait_for_predicate(
predicate=lambda: len(list(collection.list_search_indexes())) == 0,
err=f"Index {index_name} did not drop in {wait_until_complete}!",
timeout=wait_until_complete,
)
logger.info("Vector Search index %s.%s dropped", collection.name, index_name)
def update_vector_search_index(
collection: Collection,
index_name: str,
dimensions: int,
path: str,
similarity: str,
filters: Optional[List[str]] = None,
*,
wait_until_complete: Optional[float] = None,
**kwargs: Any,
) -> None:
"""Update a search index.
Replace the existing index definition with the provided definition.
Args:
collection (Collection): MongoDB Collection
index_name (str): Name of Index
dimensions (int): Number of dimensions in embedding
path (str): field with vector embedding
similarity (str): The similarity score used for the index.
filters (List[str]): Fields/paths to index to allow filtering in $vectorSearch
wait_until_complete (Optional[float]): If provided, number of seconds to wait
until search index is ready.
kwargs: Keyword arguments supplying any additional options to SearchIndexModel.
"""
logger.info(
"Updating Search Index %s from Collection: %s", index_name, collection.name
)
try:
collection.update_search_index(
name=index_name,
definition=_vector_search_index_definition(
dimensions=dimensions,
path=path,
similarity=similarity,
filters=filters,
**kwargs,
),
)
except OperationFailure as e:
raise OperationFailure(_search_index_error_message()) from e
if wait_until_complete:
_wait_for_predicate(
predicate=lambda: _is_index_ready(collection, index_name),
err=f"Index {index_name} update did not complete in {wait_until_complete}!",
timeout=wait_until_complete,
)
logger.info("Update succeeded")
def _is_index_ready(collection: Collection, index_name: str) -> bool:
"""Check for the index name in the list of available search indexes to see if the
specified index is of status READY
Args:
collection (Collection): MongoDB Collection to check for the search index
index_name (str): Vector Search Index name
Returns:
bool: True if the index is present and READY; False otherwise
"""
try:
search_indexes = collection.list_search_indexes(index_name)
except OperationFailure as e:
raise OperationFailure(_search_index_error_message()) from e
for index in search_indexes:
if index["type"] == "vectorSearch" and index["status"] == "READY":
return True
return False
def _wait_for_predicate(
predicate: Callable, err: str, timeout: float = 120, interval: float = 0.5
) -> None:
"""Generic to block until the predicate returns true
Args:
predicate (Callable[[], bool]): A function that returns a boolean value
err (str): Error message to raise if the predicate never becomes true
timeout (float, optional): Wait time for predicate. Defaults to 120.
interval (float, optional): Interval to check predicate. Defaults to 0.5.
Raises:
TimeoutError: If the predicate is still false after ``timeout`` seconds
"""
start = monotonic()
while not predicate():
if monotonic() - start > timeout:
raise TimeoutError(err)
sleep(interval)
def create_fulltext_search_index(
collection: Collection,
index_name: str,
field: str,
*,
wait_until_complete: Optional[float] = None,
**kwargs: Any,
) -> None:
"""Experimental Utility function to create an Atlas Search index
Args:
collection (Collection): MongoDB Collection
index_name (str): Name of Index
field (str): Field to index
wait_until_complete (Optional[float]): If provided, number of seconds to wait
until search index is ready
kwargs: Keyword arguments supplying any additional options to SearchIndexModel.
"""
logger.info("Creating Search Index %s on %s", index_name, collection.name)
definition = {
"mappings": {"dynamic": False, "fields": {field: [{"type": "string"}]}}
}
try:
result = collection.create_search_index(
SearchIndexModel(
definition=definition,
name=index_name,
type="search",
**kwargs,
)
)
except OperationFailure as e:
raise OperationFailure(_search_index_error_message()) from e
if wait_until_complete:
_wait_for_predicate(
predicate=lambda: _is_index_ready(collection, index_name),
err=f"{index_name=} did not complete in {wait_until_complete}!",
timeout=wait_until_complete,
)
logger.info(result)
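# Usage sketch (illustrative; per the error message above, this requires a
# dedicated Atlas cluster, not a shared M0). Names and URI are placeholders.
#
#   from pymongo import MongoClient
#
#   coll = MongoClient("<YOUR-ATLAS-URI>")["langchain_test_db"]["docs"]
#   create_vector_search_index(
#       coll, index_name="vector_index", dimensions=1536,
#       path="embedding", similarity="cosine", wait_until_complete=120.0,
#   )
#   create_fulltext_search_index(
#       coll, index_name="search_index", field="text",
#       wait_until_complete=120.0,
#   )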


@@ -1,160 +0,0 @@
"""Aggregation pipeline components used in Atlas Full-Text, Vector, and Hybrid Search
See the following for more:
- `Full-Text Search <https://www.mongodb.com/docs/atlas/atlas-search/aggregation-stages/search/#mongodb-pipeline-pipe.-search>`_
- `MongoDB Operators <https://www.mongodb.com/docs/atlas/atlas-search/operators-and-collectors/#std-label-operators-ref>`_
- `Vector Search <https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/>`_
- `Filter Example <https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/#atlas-vector-search-pre-filter>`_
"""
from typing import Any, Dict, List, Optional
def text_search_stage(
query: str,
search_field: str,
index_name: str,
limit: Optional[int] = None,
filter: Optional[Dict[str, Any]] = None,
include_scores: Optional[bool] = True,
**kwargs: Any,
) -> List[Dict[str, Any]]: # noqa: E501
"""Full-Text search using Lucene's standard (BM25) analyzer
Args:
query: Input text to search for
search_field: Field in Collection that will be searched
index_name: Atlas Search Index name
limit: Maximum number of documents to return. Default of no limit
filter: Any MQL match expression comparing an indexed field
include_scores: Scores provide a measure of relative relevance
Returns:
List of aggregation pipeline stages, beginning with $search
"""
pipeline = [
{
"$search": {
"index": index_name,
"text": {"query": query, "path": search_field},
}
}
]
if filter:
pipeline.append({"$match": filter}) # type: ignore
if include_scores:
pipeline.append({"$set": {"score": {"$meta": "searchScore"}}})
if limit:
pipeline.append({"$limit": limit}) # type: ignore
return pipeline # type: ignore
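# For example (illustrative), text_search_stage(query="harry potter",
# search_field="title", index_name="search_index", limit=5) returns:
#
#   [
#       {"$search": {"index": "search_index",
#                    "text": {"query": "harry potter", "path": "title"}}},
#       {"$set": {"score": {"$meta": "searchScore"}}},
#       {"$limit": 5},
#   ]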
def vector_search_stage(
query_vector: List[float],
search_field: str,
index_name: str,
top_k: int = 4,
filter: Optional[Dict[str, Any]] = None,
oversampling_factor: int = 10,
**kwargs: Any,
) -> Dict[str, Any]: # noqa: E501
"""Vector Search Stage without Scores.
Scoring is applied later depending on strategy.
Vector search includes a vectorSearchScore that is typically used.
Hybrid search uses Reciprocal Rank Fusion.
Args:
query_vector: List of embedding vector
search_field: Field in Collection containing embedding vectors
index_name: Name of Atlas Vector Search Index tied to Collection
top_k: Number of documents to return
oversampling_factor: This times top_k is the number of candidates
filter: MQL match expression comparing an indexed field.
Some operators are not supported.
See `vectorSearch filter docs <https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/#atlas-vector-search-pre-filter>`_
Returns:
Dictionary defining the $vectorSearch
"""
stage = {
"index": index_name,
"path": search_field,
"queryVector": query_vector,
"numCandidates": top_k * oversampling_factor,
"limit": top_k,
}
if filter:
stage["filter"] = filter
return {"$vectorSearch": stage}
def combine_pipelines(
pipeline: List[Any], stage: List[Dict[str, Any]], collection_name: str
) -> None:
"""Combines two aggregations into a single result set in-place."""
if pipeline:
pipeline.append({"$unionWith": {"coll": collection_name, "pipeline": stage}})
else:
pipeline.extend(stage)
def reciprocal_rank_stage(
score_field: str, penalty: float = 0, **kwargs: Any
) -> List[Dict[str, Any]]:
"""Stage adds Reciprocal Rank Fusion weighting.
First, it pushes documents retrieved from previous stage
into a temporary sub-document. It then unwinds to assign
a rank to each document and applies the penalty.
Args:
score_field: A unique string to identify the search being ranked
penalty: A non-negative float.
extra_fields: Any fields other than text_field that one wishes to keep.
Returns:
RRF score := 1 / (rank + penalty), with rank in [1, 2, ..., n]
"""
rrf_pipeline = [
{"$group": {"_id": None, "docs": {"$push": "$$ROOT"}}},
{"$unwind": {"path": "$docs", "includeArrayIndex": "rank"}},
{
"$addFields": {
f"docs.{score_field}": {
"$divide": [1.0, {"$add": ["$rank", penalty, 1]}]
},
"docs.rank": "$rank",
"_id": "$docs._id",
}
},
{"$replaceRoot": {"newRoot": "$docs"}},
]
return rrf_pipeline # type: ignore
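# Worked example (illustrative): the $unwind rank is 0-based, so with the
# hybrid retriever's default penalty of 60 the top three documents score
# 1/(0+60+1) = 1/61, then 1/62 and 1/63. The weighting depends only on rank,
# which makes it robust to the incompatible raw score scales of full-text
# and vector search.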
def final_hybrid_stage(
scores_fields: List[str], limit: int, **kwargs: Any
) -> List[Dict[str, Any]]:
"""Sum weighted scores, sort, and apply limit.
Args:
scores_fields: List of fields given to scores of vector and text searches
limit: Number of documents to return
Returns:
Final aggregation stages
"""
return [
{"$group": {"_id": "$_id", "docs": {"$mergeObjects": "$$ROOT"}}},
{"$replaceRoot": {"newRoot": "$docs"}},
{"$set": {score: {"$ifNull": [f"${score}", 0]} for score in scores_fields}},
{"$addFields": {"score": {"$add": [f"${score}" for score in scores_fields]}}},
{"$sort": {"score": -1}},
{"$limit": limit},
]


@@ -1,15 +0,0 @@
"""Search Retrievers of various types.
Use ``MongoDBAtlasVectorSearch.as_retriever(**)``
to create MongoDB's core Vector Search Retriever.
"""
from langchain_mongodb.retrievers.full_text_search import (
MongoDBAtlasFullTextSearchRetriever,
)
from langchain_mongodb.retrievers.hybrid_search import MongoDBAtlasHybridSearchRetriever
__all__ = [
"MongoDBAtlasHybridSearchRetriever",
"MongoDBAtlasFullTextSearchRetriever",
]


@@ -1,59 +0,0 @@
from typing import Any, Dict, List, Optional
from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
from pymongo.collection import Collection
from langchain_mongodb.pipelines import text_search_stage
from langchain_mongodb.utils import make_serializable
class MongoDBAtlasFullTextSearchRetriever(BaseRetriever):
"""Hybrid Search Retriever performs full-text searches
using Lucene's standard (BM25) analyzer.
"""
collection: Collection
"""MongoDB Collection on an Atlas cluster"""
search_index_name: str
"""Atlas Search Index name"""
search_field: str
"""Collection field that contains the text to be searched. It must be indexed"""
top_k: Optional[int] = None
"""Number of documents to return. Default is no limit"""
filter: Optional[Dict[str, Any]] = None
"""(Optional) List of MQL match expression comparing an indexed field"""
show_embeddings: bool = False
"""If true, returned Document metadata will include vectors"""
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:
"""Retrieve documents that are highest scoring / most similar to query.
Args:
query: String to find relevant documents for
run_manager: The callback handler to use
Returns:
List of relevant documents
"""
pipeline = text_search_stage( # type: ignore
query=query,
search_field=self.search_field,
index_name=self.search_index_name,
limit=self.top_k,
filter=self.filter,
)
# Execution
cursor = self.collection.aggregate(pipeline) # type: ignore[arg-type]
# Formatting
docs = []
for res in cursor:
text = res.pop(self.search_field)
make_serializable(res)
docs.append(Document(page_content=text, metadata=res))
return docs
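# Usage sketch (illustrative; the URI, index, and field names are
# placeholders):
#
#   from pymongo import MongoClient
#
#   retriever = MongoDBAtlasFullTextSearchRetriever(
#       collection=MongoClient("<YOUR-ATLAS-URI>")["db"]["docs"],
#       search_index_name="search_index",
#       search_field="text",
#       top_k=5,
#   )
#   docs = retriever.invoke("harry potter")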


@@ -1,126 +0,0 @@
from typing import Any, Dict, List, Optional
from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
from pymongo.collection import Collection
from langchain_mongodb import MongoDBAtlasVectorSearch
from langchain_mongodb.pipelines import (
combine_pipelines,
final_hybrid_stage,
reciprocal_rank_stage,
text_search_stage,
vector_search_stage,
)
from langchain_mongodb.utils import make_serializable
class MongoDBAtlasHybridSearchRetriever(BaseRetriever):
"""Hybrid Search Retriever combines vector and full-text searches
weighting them via the Reciprocal Rank Fusion (RRF) algorithm.
Increasing the vector_penalty will reduce the importance of the vector search.
Increasing the fulltext_penalty will correspondingly reduce the fulltext score.
For more on the algorithm, see
https://learn.microsoft.com/en-us/azure/search/hybrid-search-ranking
"""
vectorstore: MongoDBAtlasVectorSearch
"""MongoDBAtlas VectorStore"""
search_index_name: str
"""Atlas Search Index (full-text) name"""
top_k: int = 4
"""Number of documents to return."""
oversampling_factor: int = 10
"""This times top_k is the number of candidates chosen at each step"""
pre_filter: Optional[Dict[str, Any]] = None
"""(Optional) Any MQL match expression comparing an indexed field"""
post_filter: Optional[List[Dict[str, Any]]] = None
"""(Optional) Pipeline of MongoDB aggregation stages for postprocessing."""
vector_penalty: float = 60.0
"""Penalty applied to vector search results in RRF: scores=1/(rank + penalty)"""
fulltext_penalty: float = 60.0
"""Penalty applied to full-text search results in RRF: scores=1/(rank + penalty)"""
show_embeddings: bool = False
"""If true, returned Document metadata will include vectors."""
@property
def collection(self) -> Collection:
return self.vectorstore._collection
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:
"""Retrieve documents that are highest scoring / most similar to query.
Note that the same query is used in both searches,
embedded for vector search, and as-is for full-text search.
Args:
query: String to find relevant documents for
run_manager: The callback handler to use
Returns:
List of relevant documents
"""
query_vector = self.vectorstore._embedding.embed_query(query)
scores_fields = ["vector_score", "fulltext_score"]
pipeline: List[Any] = []
# First we build up the aggregation pipeline,
# then it is passed to the server to execute
# Vector Search stage
vector_pipeline = [
vector_search_stage(
query_vector=query_vector,
search_field=self.vectorstore._embedding_key,
index_name=self.vectorstore._index_name,
top_k=self.top_k,
filter=self.pre_filter,
oversampling_factor=self.oversampling_factor,
)
]
vector_pipeline += reciprocal_rank_stage("vector_score", self.vector_penalty)
combine_pipelines(pipeline, vector_pipeline, self.collection.name)
# Full-Text Search stage
text_pipeline = text_search_stage(
query=query,
search_field=self.vectorstore._text_key,
index_name=self.search_index_name,
limit=self.top_k,
filter=self.pre_filter,
)
text_pipeline.extend(
reciprocal_rank_stage("fulltext_score", self.fulltext_penalty)
)
combine_pipelines(pipeline, text_pipeline, self.collection.name)
# Sum and sort stage
pipeline.extend(
final_hybrid_stage(scores_fields=scores_fields, limit=self.top_k)
)
# Removal of embeddings unless requested.
if not self.show_embeddings:
pipeline.append({"$project": {self.vectorstore._embedding_key: 0}})
# Post filtering
if self.post_filter is not None:
pipeline.extend(self.post_filter)
# Execution
cursor = self.collection.aggregate(pipeline) # type: ignore[arg-type]
# Formatting
docs = []
for res in cursor:
text = res.pop(self.vectorstore._text_key)
# score = res.pop("score") # The score remains buried!
make_serializable(res)
docs.append(Document(page_content=text, metadata=res))
return docs
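# Usage sketch (illustrative), assuming an existing MongoDBAtlasVectorSearch
# instance and a full-text index named "search_index" on its text field:
#
#   retriever = MongoDBAtlasHybridSearchRetriever(
#       vectorstore=vector_store,
#       search_index_name="search_index",
#       top_k=4,
#   )
#   docs = retriever.invoke("harry potter")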


@@ -1,183 +0,0 @@
"""Various Utility Functions
- Tools for handling bson.ObjectId
The help IDs live as ObjectId in MongoDB and str in Langchain and JSON.
- Tools for the Maximal Marginal Relevance (MMR) reranking
These are duplicated from langchain_community to avoid cross-dependencies.
Functions "maximal_marginal_relevance" and "cosine_similarity"
are duplicated in this utility respectively from modules:
- "libs/community/langchain_community/vectorstores/utils.py"
- "libs/community/langchain_community/utils/math.py"
"""
from __future__ import annotations
import logging
from datetime import date, datetime
from typing import Any, Dict, List, Union
import numpy as np
logger = logging.getLogger(__name__)
Matrix = Union[List[List[float]], List[np.ndarray], np.ndarray]
def cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray:
"""Row-wise cosine similarity between two equal-width matrices."""
if len(X) == 0 or len(Y) == 0:
return np.array([])
X = np.array(X)
Y = np.array(Y)
if X.shape[1] != Y.shape[1]:
raise ValueError(
f"Number of columns in X and Y must be the same. X has shape {X.shape} "
f"and Y has shape {Y.shape}."
)
try:
import simsimd as simd
X = np.array(X, dtype=np.float32)
Y = np.array(Y, dtype=np.float32)
Z = 1 - np.array(simd.cdist(X, Y, metric="cosine"))
return Z
except ImportError:
logger.debug(
"Unable to import simsimd, defaulting to NumPy implementation. If you want "
"to use simsimd please install with `pip install simsimd`."
)
X_norm = np.linalg.norm(X, axis=1)
Y_norm = np.linalg.norm(Y, axis=1)
# Ignore divide by zero errors run time warnings as those are handled below.
with np.errstate(divide="ignore", invalid="ignore"):
similarity = np.dot(X, Y.T) / np.outer(X_norm, Y_norm)
similarity[np.isnan(similarity) | np.isinf(similarity)] = 0.0
return similarity
def maximal_marginal_relevance(
query_embedding: np.ndarray,
embedding_list: list,
lambda_mult: float = 0.5,
k: int = 4,
) -> List[int]:
"""Compute Maximal Marginal Relevance (MMR).
MMR is a technique used to select documents that are both relevant to the query
and diverse among themselves. This function returns the indices
of the top-k embeddings that maximize the marginal relevance.
Args:
query_embedding (np.ndarray): The embedding vector of the query.
embedding_list (list of np.ndarray): A list containing the embedding vectors
of the candidate documents.
lambda_mult (float, optional): The trade-off parameter between
relevance and diversity. Defaults to 0.5.
k (int, optional): The number of embeddings to select. Defaults to 4.
Returns:
list of int: The indices of the embeddings that maximize the marginal relevance.
Notes:
The Maximal Marginal Relevance (MMR) is computed using the following formula:
MMR = argmax_{D_i ∈ R \ S} [λ * Sim(D_i, Q) - (1 - λ) * max_{D_j ∈ S} Sim(D_i, D_j)]
where:
- R is the set of candidate documents,
- S is the set of selected documents,
- Q is the query embedding,
- Sim(D_i, Q) is the similarity between document D_i and the query,
- Sim(D_i, D_j) is the similarity between documents D_i and D_j,
- λ is the trade-off parameter.
"""
if min(k, len(embedding_list)) <= 0:
return []
if query_embedding.ndim == 1:
query_embedding = np.expand_dims(query_embedding, axis=0)
similarity_to_query = cosine_similarity(query_embedding, embedding_list)[0]
most_similar = int(np.argmax(similarity_to_query))
idxs = [most_similar]
selected = np.array([embedding_list[most_similar]])
while len(idxs) < min(k, len(embedding_list)):
best_score = -np.inf
idx_to_add = -1
similarity_to_selected = cosine_similarity(embedding_list, selected)
for i, query_score in enumerate(similarity_to_query):
if i in idxs:
continue
redundant_score = max(similarity_to_selected[i])
equation_score = (
lambda_mult * query_score - (1 - lambda_mult) * redundant_score
)
if equation_score > best_score:
best_score = equation_score
idx_to_add = i
idxs.append(idx_to_add)
selected = np.append(selected, [embedding_list[idx_to_add]], axis=0)
return idxs
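# Worked example (illustrative): with a diversity-heavy lambda_mult, the
# near-duplicate second candidate is passed over in favor of the orthogonal
# third one.
#
#   import numpy as np
#
#   query = np.array([1.0, 0.0])
#   candidates = [[0.9, 0.1], [0.89, 0.12], [0.1, 0.9]]
#   maximal_marginal_relevance(query, candidates, lambda_mult=0.1, k=2)
#   # -> [0, 2]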
def str_to_oid(str_repr: str) -> Any | str:
"""Attempt to cast string representation of id to MongoDB's internal BSON ObjectId.
To be consistent with ObjectId, input must be a 24 character hex string.
If it is not, MongoDB will happily use the string in the main _id index.
Importantly, the str representation that comes out of MongoDB will have this form.
Args:
str_repr: id as string.
Returns:
ObjectId if ``str_repr`` is a valid 24-character hex string; otherwise the original str.
"""
from bson import ObjectId
from bson.errors import InvalidId
try:
return ObjectId(str_repr)
except InvalidId:
logger.debug(
"ObjectIds must be 12-character byte or 24-character hex strings. "
"Examples: b'heres12bytes', '6f6e6568656c6c6f68656768'"
)
return str_repr
def oid_to_str(oid: Any) -> str:
"""Convert MongoDB's internal BSON ObjectId into a simple str for compatibility.
Instructive helper to show where data is coming out of MongoDB.
Args:
oid: bson.ObjectId
Returns:
24 character hex string.
"""
return str(oid)
def make_serializable(
obj: Dict[str, Any],
) -> None:
"""Recursively cast values in a dict to a form able to json.dump"""
from bson import ObjectId
for k, v in obj.items():
if isinstance(v, dict):
make_serializable(v)
elif isinstance(v, list) and v and isinstance(v[0], (ObjectId, date, datetime)):
obj[k] = [oid_to_str(item) for item in v]
elif isinstance(v, ObjectId):
obj[k] = oid_to_str(v)
elif isinstance(v, (datetime, date)):
obj[k] = v.isoformat()
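# Round-trip sketch (illustrative):
#
#   from datetime import datetime
#
#   oid = str_to_oid("6f6e6568656c6c6f68656768")  # 24-char hex -> ObjectId
#   assert oid_to_str(oid) == "6f6e6568656c6c6f68656768"
#   doc = {"_id": oid, "when": datetime(2024, 1, 1)}
#   make_serializable(doc)
#   # doc == {"_id": "6f6e6568656c6c6f68656768", "when": "2024-01-01T00:00:00"}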


@@ -1,796 +0,0 @@
from __future__ import annotations
import logging
from importlib.metadata import version
from typing import (
Any,
Callable,
Dict,
Generator,
Iterable,
List,
Optional,
Tuple,
TypeVar,
Union,
)
import numpy as np
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.runnables.config import run_in_executor
from langchain_core.vectorstores import VectorStore
from pymongo import MongoClient
from pymongo.collection import Collection
from pymongo.driver_info import DriverInfo
from pymongo.errors import CollectionInvalid
from langchain_mongodb.index import (
create_vector_search_index,
update_vector_search_index,
)
from langchain_mongodb.pipelines import vector_search_stage
from langchain_mongodb.utils import (
make_serializable,
maximal_marginal_relevance,
oid_to_str,
str_to_oid,
)
VST = TypeVar("VST", bound=VectorStore)
logger = logging.getLogger(__name__)
DEFAULT_INSERT_BATCH_SIZE = 100_000
class MongoDBAtlasVectorSearch(VectorStore):
"""MongoDB Atlas vector store integration.
MongoDBAtlasVectorSearch performs data operations on
text, embeddings and arbitrary data. In addition to CRUD operations,
the VectorStore provides Vector Search
based on similarity of embedding vectors following the
Hierarchical Navigable Small Worlds (HNSW) algorithm.
This supports a number of search types to ascertain scores,
"similarity" (default), "MMR", and "similarity_score_threshold".
These are described in the search_type argument to as_retriever,
which provides the Runnable.invoke(query) API, allowing
MongoDBAtlasVectorSearch to be used within a chain.
Setup:
* Set up a MongoDB Atlas cluster. The free tier M0 will allow you to start.
Search Indexes are only available on Atlas, the fully managed cloud service,
not the self-managed MongoDB.
Follow [this guide](https://www.mongodb.com/basics/mongodb-atlas-tutorial)
* Create a Collection and a Vector Search Index. The procedure is described
[here](https://www.mongodb.com/docs/atlas/atlas-vector-search/create-index/#procedure).
* Install ``langchain-mongodb``
.. code-block:: bash
pip install -qU langchain-mongodb pymongo
.. code-block:: python
import getpass
MONGODB_ATLAS_CLUSTER_URI = getpass.getpass("MongoDB Atlas Cluster URI:")
Key init args - indexing params:
embedding: Embeddings
Embedding function to use.
Key init args - client params:
collection: Collection
MongoDB collection to use.
index_name: str
Name of the Atlas Search index.
Instantiate:
.. code-block:: python
from pymongo import MongoClient
from langchain_mongodb.vectorstores import MongoDBAtlasVectorSearch
from langchain_openai import OpenAIEmbeddings
# initialize MongoDB python client
client = MongoClient(MONGODB_ATLAS_CLUSTER_URI)
DB_NAME = "langchain_test_db"
COLLECTION_NAME = "langchain_test_vectorstores"
ATLAS_VECTOR_SEARCH_INDEX_NAME = "langchain-test-index-vectorstores"
MONGODB_COLLECTION = client[DB_NAME][COLLECTION_NAME]
vector_store = MongoDBAtlasVectorSearch(
collection=MONGODB_COLLECTION,
embedding=OpenAIEmbeddings(),
index_name=ATLAS_VECTOR_SEARCH_INDEX_NAME,
relevance_score_fn="cosine",
)
Add Documents:
.. code-block:: python
from langchain_core.documents import Document
document_1 = Document(page_content="foo", metadata={"baz": "bar"})
document_2 = Document(page_content="thud", metadata={"bar": "baz"})
document_3 = Document(page_content="i will be deleted :(")
documents = [document_1, document_2, document_3]
ids = ["1", "2", "3"]
vector_store.add_documents(documents=documents, ids=ids)
Delete Documents:
.. code-block:: python
vector_store.delete(ids=["3"])
Search:
.. code-block:: python
results = vector_store.similarity_search(query="thud",k=1)
for doc in results:
print(f"* {doc.page_content} [{doc.metadata}]")
.. code-block:: python
* thud [{'_id': '2', 'bar': 'baz'}]
Search with filter:
.. code-block:: python
results = vector_store.similarity_search(query="thud",k=1,post_filter=[{"bar": "baz"]})
for doc in results:
print(f"* {doc.page_content} [{doc.metadata}]")
.. code-block:: python
* thud [{'_id': '2', 'bar': 'baz'}]
Search with score:
.. code-block:: python
results = vector_store.similarity_search_with_score(query="qux",k=1)
for doc, score in results:
print(f"* [SIM={score:3f}] {doc.page_content} [{doc.metadata}]")
.. code-block:: python
* [SIM=0.916096] foo [{'_id': '1', 'baz': 'bar'}]
Async:
.. code-block:: python
# add documents
# await vector_store.aadd_documents(documents=documents, ids=ids)
# delete documents
# await vector_store.adelete(ids=["3"])
# search
# results = await vector_store.asimilarity_search(query="thud",k=1)
# search with score
results = await vector_store.asimilarity_search_with_score(query="qux",k=1)
for doc,score in results:
print(f"* [SIM={score:3f}] {doc.page_content} [{doc.metadata}]")
.. code-block:: python
* [SIM=0.916096] foo [{'_id': '1', 'baz': 'bar'}]
Use as Retriever:
.. code-block:: python
retriever = vector_store.as_retriever(
search_type="mmr",
search_kwargs={"k": 1, "fetch_k": 2, "lambda_mult": 0.5},
)
retriever.invoke("thud")
.. code-block:: python
[Document(metadata={'_id': '2', 'embedding': [-0.01850726455450058, -0.0014740974875167012, -0.009762819856405258, ...], 'baz': 'baz'}, page_content='thud')]
""" # noqa: E501
def __init__(
self,
collection: Collection[Dict[str, Any]],
embedding: Embeddings,
index_name: str = "vector_index",
text_key: str = "text",
embedding_key: str = "embedding",
relevance_score_fn: str = "cosine",
**kwargs: Any,
):
"""
Args:
collection: MongoDB collection to add the texts to
embedding: Text embedding model to use
text_key: MongoDB field that will contain the text for each document
index_name: Name of the existing Atlas Vector Search index
embedding_key: Field that will contain the embedding for each document
relevance_score_fn: The similarity score used for the index
Currently supported: 'euclidean', 'cosine', and 'dotProduct'
"""
self._collection = collection
self._embedding = embedding
self._index_name = index_name
self._text_key = text_key
self._embedding_key = embedding_key
self._relevance_score_fn = relevance_score_fn
@property
def embeddings(self) -> Embeddings:
return self._embedding
def _select_relevance_score_fn(self) -> Callable[[float], float]:
scoring: dict[str, Callable] = {
"euclidean": self._euclidean_relevance_score_fn,
"dotProduct": self._max_inner_product_relevance_score_fn,
"cosine": self._cosine_relevance_score_fn,
}
if self._relevance_score_fn in scoring:
return scoring[self._relevance_score_fn]
else:
raise NotImplementedError(
f"No relevance score function for ${self._relevance_score_fn}"
)
@classmethod
def from_connection_string(
cls,
connection_string: str,
namespace: str,
embedding: Embeddings,
**kwargs: Any,
) -> MongoDBAtlasVectorSearch:
"""Construct a `MongoDB Atlas Vector Search` vector store
from a MongoDB connection URI.
Args:
connection_string: A valid MongoDB connection URI.
namespace: A valid MongoDB namespace (database and collection).
embedding: The text embedding model to use for the vector store.
Returns:
A new MongoDBAtlasVectorSearch instance.
"""
client: MongoClient = MongoClient(
connection_string,
driver=DriverInfo(name="Langchain", version=version("langchain")),
)
db_name, collection_name = namespace.split(".")
collection = client[db_name][collection_name]
return cls(collection, embedding, **kwargs)
def add_texts(
self,
texts: Iterable[str],
metadatas: Optional[List[Dict[str, Any]]] = None,
ids: Optional[List[str]] = None,
**kwargs: Any,
) -> List[str]:
"""Add texts, create embeddings, and add to the Collection and index.
Important notes on ids:
- If _id or id is a key in the metadatas dicts, one must
pop them and provide as separate list.
- They must be unique.
- If they are not provided, the VectorStore will create unique ones,
stored as bson.ObjectIds internally, and strings in Langchain.
These will appear in Document.metadata with key, '_id'.
Args:
texts: Iterable of strings to add to the vectorstore.
metadatas: Optional list of metadatas associated with the texts.
ids: Optional list of unique ids that will be used as index in VectorStore.
See note on ids.
Returns:
List of ids added to the vectorstore.
"""
# Check to see if metadata includes ids
if metadatas is not None and (
metadatas[0].get("_id") or metadatas[0].get("id")
):
logger.warning(
"_id or id key found in metadata. "
"Please pop from each dict and input as separate list."
"Retrieving methods will include the same id as '_id' in metadata."
)
texts_batch = texts
_metadatas: Union[List, Generator] = metadatas or ({} for _ in texts)
metadatas_batch = _metadatas
result_ids = []
batch_size = kwargs.get("batch_size", DEFAULT_INSERT_BATCH_SIZE)
if batch_size:
texts_batch = []
metadatas_batch = []
size = 0
i = 0
for j, (text, metadata) in enumerate(zip(texts, _metadatas)):
size += len(text) + len(metadata)
texts_batch.append(text)
metadatas_batch.append(metadata)
if (j + 1) % batch_size == 0 or size >= 47_000_000:
if ids:
batch_res = self.bulk_embed_and_insert_texts(
texts_batch, metadatas_batch, ids[i : j + 1]
)
else:
batch_res = self.bulk_embed_and_insert_texts(
texts_batch, metadatas_batch
)
result_ids.extend(batch_res)
texts_batch = []
metadatas_batch = []
size = 0
i = j + 1
if texts_batch:
if ids:
batch_res = self.bulk_embed_and_insert_texts(
texts_batch, metadatas_batch, ids[i : j + 1]
)
else:
batch_res = self.bulk_embed_and_insert_texts(
texts_batch, metadatas_batch
)
result_ids.extend(batch_res)
return result_ids
def bulk_embed_and_insert_texts(
self,
texts: Union[List[str], Iterable[str]],
metadatas: Union[List[dict], Generator[dict, Any, Any]],
ids: Optional[List[str]] = None,
) -> List[str]:
"""Bulk insert single batch of texts, embeddings, and optionally ids.
See add_texts for additional details.
"""
if not texts:
return []
# Compute embedding vectors
embeddings = self._embedding.embed_documents(texts) # type: ignore
if ids:
to_insert = [
{
"_id": str_to_oid(i),
self._text_key: t,
self._embedding_key: embedding,
**m,
}
for i, t, m, embedding in zip(ids, texts, metadatas, embeddings)
]
else:
to_insert = [
{self._text_key: t, self._embedding_key: embedding, **m}
for t, m, embedding in zip(texts, metadatas, embeddings)
]
# insert the documents in MongoDB Atlas
insert_result = self._collection.insert_many(to_insert) # type: ignore
return [oid_to_str(_id) for _id in insert_result.inserted_ids]
def add_documents(
self,
documents: List[Document],
ids: Optional[List[str]] = None,
batch_size: int = DEFAULT_INSERT_BATCH_SIZE,
**kwargs: Any,
) -> List[str]:
"""Add documents to the vectorstore.
Args:
documents: Documents to add to the vectorstore.
ids: Optional list of unique ids that will be used as index in VectorStore.
See note on ids in add_texts.
batch_size: Number of documents to insert at a time.
Tuning this may help with performance and sidestep MongoDB limits.
Returns:
List of IDs of the added texts.
"""
n_docs = len(documents)
if ids:
assert len(ids) == n_docs, "Number of ids must equal number of documents."
result_ids = []
start = 0
for end in range(batch_size, n_docs + batch_size, batch_size):
texts, metadatas = zip(
*[(doc.page_content, doc.metadata) for doc in documents[start:end]]
)
if ids:
result_ids.extend(
self.bulk_embed_and_insert_texts(
texts=texts, metadatas=metadatas, ids=ids[start:end]
)
)
else:
result_ids.extend(
self.bulk_embed_and_insert_texts(texts=texts, metadatas=metadatas)
)
start = end
return result_ids
def similarity_search_with_score(
self,
query: str,
k: int = 4,
pre_filter: Optional[Dict[str, Any]] = None,
post_filter_pipeline: Optional[List[Dict]] = None,
oversampling_factor: int = 10,
include_embeddings: bool = False,
**kwargs: Any,
) -> List[Tuple[Document, float]]: # noqa: E501
"""Return MongoDB documents most similar to the given query and their scores.
Atlas Vector Search eliminates the need to run a separate
search system alongside your database.
Args:
query: Input text of semantic query
k: Number of documents to return. Also known as top_k.
pre_filter: List of MQL match expressions comparing an indexed field
post_filter_pipeline: (Optional) Arbitrary pipeline of MongoDB
aggregation stages applied after the search is complete.
oversampling_factor: This times k is the number of candidates chosen
at each step in the HNSW Vector Search
include_embeddings: If True, the embedding vector of each result
will be included in metadata.
kwargs: Additional arguments are specific to the search_type
Returns:
List of documents most similar to the query and their scores.
"""
embedding = self._embedding.embed_query(query)
docs = self._similarity_search_with_score(
embedding,
k=k,
pre_filter=pre_filter,
post_filter_pipeline=post_filter_pipeline,
oversampling_factor=oversampling_factor,
include_embeddings=include_embeddings,
**kwargs,
)
return docs
def similarity_search(
self,
query: str,
k: int = 4,
pre_filter: Optional[Dict[str, Any]] = None,
post_filter_pipeline: Optional[List[Dict]] = None,
oversampling_factor: int = 10,
include_scores: bool = False,
include_embeddings: bool = False,
**kwargs: Any,
) -> List[Document]: # noqa: E501
"""Return MongoDB documents most similar to the given query.
Atlas Vector Search eliminates the need to run a separate
search system alongside your database.
Args:
query: Input text of semantic query
k: (Optional) number of documents to return. Defaults to 4.
pre_filter: List of MQL match expressions comparing an indexed field
post_filter_pipeline: (Optional) Pipeline of MongoDB aggregation stages
to filter/process results after $vectorSearch.
oversampling_factor: Multiple of k used when generating number of candidates
at each step in the HNSW Vector Search.
include_scores: If True, the query score of each result
will be included in metadata.
include_embeddings: If True, the embedding vector of each result
will be included in metadata.
kwargs: Additional arguments are specific to the search_type
Returns:
List of documents most similar to the query.
"""
docs_and_scores = self.similarity_search_with_score(
query,
k=k,
pre_filter=pre_filter,
post_filter_pipeline=post_filter_pipeline,
oversampling_factor=oversampling_factor,
include_embeddings=include_embeddings,
**kwargs,
)
if include_scores:
for doc, score in docs_and_scores:
doc.metadata["score"] = score
return [doc for doc, _ in docs_and_scores]
def max_marginal_relevance_search(
self,
query: str,
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
pre_filter: Optional[Dict[str, Any]] = None,
post_filter_pipeline: Optional[List[Dict]] = None,
**kwargs: Any,
) -> List[Document]:
"""Return documents selected using the maximal marginal relevance.
Maximal marginal relevance optimizes for similarity to query AND diversity
among selected documents.
Args:
query: Text to look up documents similar to.
k: (Optional) number of documents to return. Defaults to 4.
fetch_k: (Optional) number of documents to fetch before passing to MMR
algorithm. Defaults to 20.
lambda_mult: Number between 0 and 1 that determines the degree
of diversity among the results with 0 corresponding
to maximum diversity and 1 to minimum diversity. Defaults to 0.5.
pre_filter: List of MQL match expressions comparing an indexed field
post_filter_pipeline: (Optional) pipeline of MongoDB aggregation stages
following the $vectorSearch stage.
Returns:
List of documents selected by maximal marginal relevance.
"""
return self.max_marginal_relevance_search_by_vector(
embedding=self._embedding.embed_query(query),
k=k,
fetch_k=fetch_k,
lambda_mult=lambda_mult,
pre_filter=pre_filter,
post_filter_pipeline=post_filter_pipeline,
**kwargs,
)
@classmethod
def from_texts(
cls,
texts: List[str],
embedding: Embeddings,
metadatas: Optional[List[Dict]] = None,
collection: Optional[Collection] = None,
ids: Optional[List[str]] = None,
**kwargs: Any,
) -> MongoDBAtlasVectorSearch:
"""Construct a `MongoDB Atlas Vector Search` vector store from raw documents.
This is a user-friendly interface that:
1. Embeds documents.
2. Adds the documents to a provided MongoDB Atlas Vector Search index
(Lucene)
This is intended to be a quick way to get started.
See `MongoDBAtlasVectorSearch` for kwargs and further description.
Example:
.. code-block:: python
from pymongo import MongoClient
from langchain_mongodb import MongoDBAtlasVectorSearch
from langchain_openai import OpenAIEmbeddings
mongo_client = MongoClient("<YOUR-CONNECTION-STRING>")
collection = mongo_client["<db_name>"]["<collection_name>"]
embeddings = OpenAIEmbeddings()
vectorstore = MongoDBAtlasVectorSearch.from_texts(
texts,
embeddings,
metadatas=metadatas,
collection=collection
)
"""
if collection is None:
raise ValueError("Must provide 'collection' named parameter.")
vectorstore = cls(collection, embedding, **kwargs)
vectorstore.add_texts(texts=texts, metadatas=metadatas, ids=ids, **kwargs)
return vectorstore
def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> Optional[bool]:
"""Delete documents from VectorStore by ids.
Args:
ids: List of ids to delete.
**kwargs: Other keyword arguments passed to Collection.delete_many()
Returns:
Optional[bool]: True if deletion is successful,
False otherwise, None if not implemented.
"""
filter = {}
if ids:
oids = [str_to_oid(i) for i in ids]
filter = {"_id": {"$in": oids}}
return self._collection.delete_many(filter=filter, **kwargs).acknowledged
async def adelete(
self, ids: Optional[List[str]] = None, **kwargs: Any
) -> Optional[bool]:
"""Delete by vector ID or other criteria.
Args:
ids: List of ids to delete.
**kwargs: Other keyword arguments that subclasses might use.
Returns:
Optional[bool]: True if deletion is successful,
False otherwise, None if not implemented.
"""
return await run_in_executor(None, self.delete, ids=ids, **kwargs)
def max_marginal_relevance_search_by_vector(
self,
embedding: List[float],
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
pre_filter: Optional[Dict[str, Any]] = None,
post_filter_pipeline: Optional[List[Dict]] = None,
oversampling_factor: int = 10,
**kwargs: Any,
) -> List[Document]: # type: ignore
"""Return docs selected using the maximal marginal relevance.
Maximal marginal relevance optimizes for similarity to query AND diversity
among selected documents.
Args:
embedding: Embedding to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
fetch_k: Number of Documents to fetch to pass to MMR algorithm.
lambda_mult: Number between 0 and 1 that determines the degree
of diversity among the results with 0 corresponding
to maximum diversity and 1 to minimum diversity.
Defaults to 0.5.
pre_filter: (Optional) dictionary of arguments to filter document fields on.
post_filter_pipeline: (Optional) pipeline of MongoDB aggregation stages
following the vectorSearch stage.
oversampling_factor: Multiple of k used when generating number
of candidates in the HNSW Vector Search.
kwargs: Additional arguments are specific to the search_type
Returns:
List of Documents selected by maximal marginal relevance.
"""
docs = self._similarity_search_with_score(
embedding,
k=fetch_k,
pre_filter=pre_filter,
post_filter_pipeline=post_filter_pipeline,
include_embeddings=True,
oversampling_factor=oversampling_factor,
**kwargs,
)
mmr_doc_indexes = maximal_marginal_relevance(
np.array(embedding),
[doc.metadata[self._embedding_key] for doc, _ in docs],
k=k,
lambda_mult=lambda_mult,
)
mmr_docs = [docs[i][0] for i in mmr_doc_indexes]
return mmr_docs
async def amax_marginal_relevance_search_by_vector(
self,
embedding: List[float],
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
pre_filter: Optional[Dict[str, Any]] = None,
post_filter_pipeline: Optional[List[Dict]] = None,
oversampling_factor: int = 10,
**kwargs: Any,
) -> List[Document]:
"""Return docs selected using the maximal marginal relevance."""
return await run_in_executor(
None,
self.max_marginal_relevance_search_by_vector, # type: ignore[arg-type]
embedding,
k=k,
fetch_k=fetch_k,
lambda_mult=lambda_mult,
pre_filter=pre_filter,
post_filter_pipeline=post_filter_pipeline,
oversampling_factor=oversampling_factor,
**kwargs,
)
def _similarity_search_with_score(
self,
query_vector: List[float],
k: int = 4,
pre_filter: Optional[Dict[str, Any]] = None,
post_filter_pipeline: Optional[List[Dict]] = None,
oversampling_factor: int = 10,
include_embeddings: bool = False,
**kwargs: Any,
) -> List[Tuple[Document, float]]:
"""Core search routine. See external methods for details."""
# Atlas Vector Search, potentially with filter
pipeline = [
vector_search_stage(
query_vector,
self._embedding_key,
self._index_name,
k,
pre_filter,
oversampling_factor,
**kwargs,
),
{"$set": {"score": {"$meta": "vectorSearchScore"}}},
]
# Remove embeddings unless requested.
if not include_embeddings:
pipeline.append({"$project": {self._embedding_key: 0}})
# Post-processing
if post_filter_pipeline is not None:
pipeline.extend(post_filter_pipeline)
# Execution
cursor = self._collection.aggregate(pipeline) # type: ignore[arg-type]
docs = []
# Format
for res in cursor:
text = res.pop(self._text_key)
score = res.pop("score")
make_serializable(res)
docs.append((Document(page_content=text, metadata=res), score))
return docs
def create_vector_search_index(
self,
dimensions: int,
filters: Optional[List[str]] = None,
update: bool = False,
) -> None:
"""Creates a MongoDB Atlas vectorSearch index for the VectorStore
**Note**: This method may fail as it requires a MongoDB Atlas cluster with these
`pre-requisites <https://www.mongodb.com/docs/atlas/atlas-vector-search/create-index/#prerequisites>`_.
Currently, vector and full-text search index operations need to be
performed manually on the Atlas UI for shared M0 clusters.
Args:
dimensions (int): Number of dimensions in embedding
filters (Optional[List[Dict[str, str]]], optional): additional filters
for index definition.
Defaults to None.
update (bool, optional): Updates existing vectorSearch index.
Defaults to False.
"""
try:
self._collection.database.create_collection(self._collection.name)
except CollectionInvalid:
pass
index_operation = (
update_vector_search_index if update else create_vector_search_index
)
index_operation(
collection=self._collection,
index_name=self._index_name,
dimensions=dimensions,
path=self._embedding_key,
similarity=self._relevance_score_fn,
filters=filters or [],
) # type: ignore [operator]
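# Usage sketch (illustrative), following the instantiation example in the
# class docstring above; dimensions must match the embedding model's output
# size (1536 for OpenAIEmbeddings' default model):
#
#   vector_store.create_vector_search_index(dimensions=1536)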

File diff suppressed because it is too large


@@ -1,103 +0,0 @@
[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"
[tool.poetry]
name = "langchain-mongodb"
version = "0.2.0"
description = "An integration package connecting MongoDB and LangChain"
authors = []
readme = "README.md"
repository = "https://github.com/langchain-ai/langchain"
license = "MIT"
[tool.mypy]
disallow_untyped_defs = "True"
[tool.poetry.urls]
"Source Code" = "https://github.com/langchain-ai/langchain/tree/master/libs/partners/mongodb"
"Release Notes" = "https://github.com/langchain-ai/langchain/releases?q=tag%3A%22langchain-mongodb%3D%3D0%22&expanded=true"
[tool.poetry.dependencies]
python = ">=3.9,<4.0"
pymongo = ">=4.6.1,<5.0"
langchain-core = "^0.3"
[[tool.poetry.dependencies.numpy]]
version = "^1"
python = "<3.12"
[[tool.poetry.dependencies.numpy]]
version = "^1.26.0"
python = ">=3.12"
[tool.ruff.lint]
select = ["E", "F", "I"]
[tool.coverage.run]
omit = ["tests/*"]
[tool.pytest.ini_options]
addopts = "--snapshot-warn-unused --strict-markers --strict-config --durations=5"
markers = [
"requires: mark tests as requiring a specific library",
"compile: mark placeholder test used to compile integration tests without running them",
]
asyncio_mode = "auto"
[tool.poetry.group.test]
optional = true
[tool.poetry.group.codespell]
optional = true
[tool.poetry.group.test_integration]
optional = true
[tool.poetry.group.lint]
optional = true
[tool.poetry.group.dev]
optional = true
[tool.poetry.group.test.dependencies]
pytest = "^7.3.0"
freezegun = "^1.2.2"
pytest-mock = "^3.10.0"
syrupy = "^4.0.2"
pytest-watcher = "^0.3.4"
pytest-asyncio = "^0.21.1"
[tool.poetry.group.codespell.dependencies]
codespell = "^2.2.0"
[tool.poetry.group.test_integration.dependencies.langchain-openai]
path = "../openai"
develop = true
[tool.poetry.group.lint.dependencies]
ruff = "^0.5"
[tool.poetry.group.typing.dependencies]
mypy = "^1.10"
simsimd = "^5.0.0"
[tool.poetry.group.test.dependencies.langchain]
path = "../../langchain"
develop = true
[tool.poetry.group.test.dependencies.langchain-core]
path = "../../core"
develop = true
[tool.poetry.group.test.dependencies.langchain-text-splitters]
path = "../../text-splitters"
develop = true
[tool.poetry.group.dev.dependencies.langchain-core]
path = "../../core"
develop = true
[tool.poetry.group.typing.dependencies.langchain-core]
path = "../../core"
develop = true


@@ -1,17 +0,0 @@
import sys
import traceback
from importlib.machinery import SourceFileLoader
if __name__ == "__main__":
files = sys.argv[1:]
has_failure = False
for file in files:
try:
SourceFileLoader("x", file).load_module()
except Exception:
has_failure = True
print(file)
traceback.print_exc()
print()
sys.exit(1 if has_failure else 0)


@@ -1,17 +0,0 @@
#!/bin/bash
set -eu
# Initialize a variable to keep track of errors
errors=0
# make sure not importing from langchain or langchain_experimental
git --no-pager grep '^from langchain\.' . && errors=$((errors+1))
git --no-pager grep '^from langchain_experimental\.' . && errors=$((errors+1))
# Decide on an exit status based on the errors
if [ "$errors" -gt 0 ]; then
exit 1
else
exit 0
fi

View File

@@ -1,161 +0,0 @@
import os
import uuid
from typing import Any, List, Union
import pytest # type: ignore[import-not-found]
from langchain_core.caches import BaseCache
from langchain_core.globals import get_llm_cache, set_llm_cache
from langchain_core.load.dump import dumps
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langchain_core.outputs import ChatGeneration, Generation, LLMResult
from langchain_mongodb.cache import MongoDBAtlasSemanticCache, MongoDBCache
from ..utils import ConsistentFakeEmbeddings, FakeChatModel, FakeLLM
CONN_STRING = os.environ.get("MONGODB_ATLAS_URI")
INDEX_NAME = "langchain-test-index-semantic-cache"
DATABASE = "langchain_test_db"
COLLECTION = "langchain_test_cache"
def random_string() -> str:
return str(uuid.uuid4())
def llm_cache(cls: Any) -> BaseCache:
set_llm_cache(
cls(
embedding=ConsistentFakeEmbeddings(dimensionality=1536),
connection_string=CONN_STRING,
collection_name=COLLECTION,
database_name=DATABASE,
index_name=INDEX_NAME,
score_threshold=0.5,
wait_until_ready=15.0,
)
)
assert get_llm_cache()
return get_llm_cache()
def _execute_test(
prompt: Union[str, List[BaseMessage]],
llm: Union[str, FakeLLM, FakeChatModel],
response: List[Generation],
) -> None:
# Fabricate an LLM String
if not isinstance(llm, str):
params = llm.dict()
params["stop"] = None
llm_string = str(sorted([(k, v) for k, v in params.items()]))
else:
llm_string = llm
# If the prompt is a str then we should pass just the string
dumped_prompt: str = prompt if isinstance(prompt, str) else dumps(prompt)
# Update the cache
get_llm_cache().update(dumped_prompt, llm_string, response)
# Retrieve the cached result through 'generate' call
output: Union[List[Generation], LLMResult, None]
expected_output: Union[List[Generation], LLMResult]
if isinstance(llm, str):
output = get_llm_cache().lookup(dumped_prompt, llm) # type: ignore
expected_output = response
else:
output = llm.generate([prompt]) # type: ignore
expected_output = LLMResult(
generations=[response],
llm_output={},
)
assert output == expected_output # type: ignore
@pytest.mark.parametrize(
"prompt, llm, response",
[
("foo", "bar", [Generation(text="fizz")]),
("foo", FakeLLM(), [Generation(text="fizz")]),
(
[HumanMessage(content="foo")],
FakeChatModel(),
[ChatGeneration(message=AIMessage(content="foo"))],
),
],
ids=[
"plain_cache",
"cache_with_llm",
"cache_with_chat",
],
)
@pytest.mark.parametrize("cacher", [MongoDBCache, MongoDBAtlasSemanticCache])
@pytest.mark.parametrize("remove_score", [True, False])
def test_mongodb_cache(
remove_score: bool,
cacher: Union[MongoDBCache, MongoDBAtlasSemanticCache],
prompt: Union[str, List[BaseMessage]],
llm: Union[str, FakeLLM, FakeChatModel],
response: List[Generation],
) -> None:
llm_cache(cacher)
if remove_score:
get_llm_cache().score_threshold = None # type: ignore
try:
_execute_test(prompt, llm, response)
finally:
get_llm_cache().clear()
@pytest.mark.parametrize(
"prompts, generations",
[
# Single prompt, single generation
([random_string()], [[random_string()]]),
# Single prompt, multiple generations
([random_string()], [[random_string(), random_string()]]),
# Single prompt, multiple generations
([random_string()], [[random_string(), random_string(), random_string()]]),
# Multiple prompts, multiple generations
(
[random_string(), random_string()],
[[random_string()], [random_string(), random_string()]],
),
],
ids=[
"single_prompt_single_generation",
"single_prompt_two_generations",
"single_prompt_three_generations",
"multiple_prompts_multiple_generations",
],
)
def test_mongodb_atlas_cache_matrix(
prompts: List[str],
generations: List[List[str]],
) -> None:
llm_cache(MongoDBAtlasSemanticCache)
llm = FakeLLM()
# Fabricate an LLM String
params = llm.dict()
params["stop"] = None
llm_string = str(sorted([(k, v) for k, v in params.items()]))
llm_generations = [
[
Generation(text=generation, generation_info=params)
for generation in prompt_i_generations
]
for prompt_i_generations in generations
]
for prompt_i, llm_generations_i in zip(prompts, llm_generations):
_execute_test(prompt_i, llm_string, llm_generations_i)
assert llm.generate(prompts) == LLMResult(
generations=llm_generations, llm_output={}
)
get_llm_cache().clear()
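Outside the test harness, wiring MongoDBAtlasSemanticCache in as the global LLM cache follows the same pattern as the llm_cache helper above. A sketch, assuming a real Embeddings object and an Atlas URI:

from langchain_core.globals import set_llm_cache
from langchain_mongodb.cache import MongoDBAtlasSemanticCache

set_llm_cache(
    MongoDBAtlasSemanticCache(
        embedding=embeddings,  # a real Embeddings implementation (assumed)
        connection_string=CONN_STRING,  # your Atlas URI (assumed)
        database_name="langchain_test_db",
        collection_name="langchain_test_cache",
        index_name="langchain-test-index-semantic-cache",
        score_threshold=0.5,
        wait_until_ready=15.0,
    )
)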

View File

@@ -1,144 +0,0 @@
"Demonstrates MongoDBAtlasVectorSearch.as_retriever() invoked in a chain" ""
from __future__ import annotations
import os
from time import sleep
import pytest # type: ignore[import-not-found]
from langchain_core.documents import Document
from langchain_core.output_parsers.string import StrOutputParser
from langchain_core.prompts.chat import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from pymongo import MongoClient
from pymongo.collection import Collection
from langchain_mongodb import index
from ..utils import PatchedMongoDBAtlasVectorSearch
CONNECTION_STRING = os.environ.get("MONGODB_ATLAS_URI")
DB_NAME = "langchain_test_db"
COLLECTION_NAME = "langchain_test_chain_example"
INDEX_NAME = "vector_index"
DIMENSIONS = 1536
TIMEOUT = 60.0
INTERVAL = 0.5
@pytest.fixture
def collection() -> Collection:
"""A Collection with both a Vector and a Full-text Search Index"""
client: MongoClient = MongoClient(CONNECTION_STRING)
if COLLECTION_NAME not in client[DB_NAME].list_collection_names():
clxn = client[DB_NAME].create_collection(COLLECTION_NAME)
else:
clxn = client[DB_NAME][COLLECTION_NAME]
clxn.delete_many({})
if all([INDEX_NAME != ix["name"] for ix in clxn.list_search_indexes()]):
index.create_vector_search_index(
collection=clxn,
index_name=INDEX_NAME,
dimensions=DIMENSIONS,
path="embedding",
similarity="cosine",
filters=None,
wait_until_complete=TIMEOUT,
)
return clxn
@pytest.mark.skipif(
"OPENAI_API_KEY" not in os.environ, reason="Requires OpenAI for chat responses."
)
def test_chain(
collection: Collection,
) -> None:
"""Demonstrate usage of MongoDBAtlasVectorSearch in a realistic chain
Follows example in the docs: https://python.langchain.com/docs/how_to/hybrid/
Requires OPENAI_API_KEY for embedding and chat model.
Requires INDEX_NAME to have been set up on MONGODB_ATLAS_URI
"""
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
embedding_openai = OpenAIEmbeddings(
openai_api_key=os.environ["OPENAI_API_KEY"], # type: ignore # noqa
model="text-embedding-3-small",
)
vectorstore = PatchedMongoDBAtlasVectorSearch(
collection=collection,
embedding=embedding_openai,
index_name=INDEX_NAME,
text_key="page_content",
)
texts = [
"In 2023, I visited Paris",
"In 2022, I visited New York",
"In 2021, I visited New Orleans",
"In 2019, I visited San Francisco",
"In 2020, I visited Vancouver",
]
vectorstore.add_texts(texts)
# Give the index time to build (For CI)
sleep(TIMEOUT)
query = "In the United States, what city did I visit last?"
# One can do vector search on the vector store, using its various search types.
k = len(texts)
store_output = list(vectorstore.similarity_search(query=query, k=k))
assert len(store_output) == k
assert isinstance(store_output[0], Document)
# Unfortunately, the VectorStore output cannot be given to a Chat Model
# If we wish Chat Model to answer based on our own data,
# we have to give it the right things to work with.
# The way that Langchain does this is by piping results along in
# a Chain: https://python.langchain.com/v0.1/docs/modules/chains/
# Now, we can turn our VectorStore into something Runnable in a Chain
# by turning it into a Retriever.
# For the simple VectorSearch Retriever, we can do this like so.
retriever = vectorstore.as_retriever(search_kwargs=dict(k=k))
# This does not do much other than expose our search function
# as an invoke() method with a certain API, a Runnable.
retriever_output = retriever.invoke(query)
assert len(retriever_output) == len(texts)
assert retriever_output[0].page_content == store_output[0].page_content
# To get a natural language response to our question,
# we need ChatOpenAI, a template to better frame the question as a prompt,
# and a parser to send the output to a string.
# Together, these become our Chain!
# Here goes:
template = """Answer the question based only on the following context.
Answer in as few words as possible.
{context}
Question: {question}
"""
prompt = ChatPromptTemplate.from_template(template)
model = ChatOpenAI()
chain = (
{"context": retriever, "question": RunnablePassthrough()} # type: ignore
| prompt
| model
| StrOutputParser()
)
answer = chain.invoke("What city did I visit last?")
assert "Paris" in answer

View File

@@ -1,43 +0,0 @@
import json
import os
from langchain.memory import ConversationBufferMemory # type: ignore[import-not-found]
from langchain_core.messages import message_to_dict
from langchain_mongodb.chat_message_histories import MongoDBChatMessageHistory
DATABASE = "langchain_test_db"
COLLECTION = "langchain_test_chat"
# Replace these with your mongodb connection string
connection_string = os.environ.get("MONGODB_ATLAS_URI", "")
def test_memory_with_message_store() -> None:
"""Test the memory with a message store."""
# setup MongoDB as a message store
message_history = MongoDBChatMessageHistory(
connection_string=connection_string,
session_id="test-session",
database_name=DATABASE,
collection_name=COLLECTION,
)
memory = ConversationBufferMemory(
memory_key="baz", chat_memory=message_history, return_messages=True
)
# add some messages
memory.chat_memory.add_ai_message("This is me, the AI")
memory.chat_memory.add_user_message("This is me, the human")
# get the message history from the memory store and turn it into a json
messages = memory.chat_memory.messages
messages_json = json.dumps([message_to_dict(msg) for msg in messages])
assert "This is me, the AI" in messages_json
assert "This is me, the human" in messages_json
# remove the record from MongoDB, so the next test run won't pick it up
memory.chat_memory.clear()
assert memory.chat_memory.messages == []
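The test drives the history through ConversationBufferMemory, but MongoDBChatMessageHistory can also be used directly. A sketch with the same connection parameters as above:

from langchain_mongodb.chat_message_histories import MongoDBChatMessageHistory

history = MongoDBChatMessageHistory(
    connection_string=connection_string,  # MONGODB_ATLAS_URI (assumed)
    session_id="example-session",  # illustrative session id
    database_name="langchain_test_db",
    collection_name="langchain_test_chat",
)
history.add_user_message("This is me, the human")
history.add_ai_message("This is me, the AI")
print(history.messages)  # both messages, round-tripped through MongoDB
history.clear()  # clean up the session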

View File

@@ -1,7 +0,0 @@
import pytest # type: ignore[import-not-found]
@pytest.mark.compile
def test_placeholder() -> None:
"""Used for compiling integration tests without running any real tests."""
pass

View File

@@ -1,83 +0,0 @@
"""Search index commands are only supported on Atlas Clusters >=M10"""
import os
from typing import Generator, List, Optional
import pytest
from pymongo import MongoClient
from pymongo.collection import Collection
from langchain_mongodb import index
DB_NAME = "langchain_test_index_db"
COLLECTION_NAME = "test_index"
VECTOR_INDEX_NAME = "vector_index"
TIMEOUT = 120
DIMENSIONS = 10
@pytest.fixture
def collection() -> Generator:
"""Depending on uri, this could point to any type of cluster."""
uri = os.environ.get("MONGODB_ATLAS_URI")
client: MongoClient = MongoClient(uri)
clxn = client[DB_NAME][COLLECTION_NAME]
clxn.insert_one({"foo": "bar"})
yield clxn
clxn.drop()
def test_search_index_commands(collection: Collection) -> None:
index_name = VECTOR_INDEX_NAME
dimensions = DIMENSIONS
path = "embedding"
similarity = "cosine"
filters: Optional[List[str]] = None
wait_until_complete = TIMEOUT
for index_info in collection.list_search_indexes():
index.drop_vector_search_index(
collection, index_info["name"], wait_until_complete=wait_until_complete
)
assert len(list(collection.list_search_indexes())) == 0
index.create_vector_search_index(
collection=collection,
index_name=index_name,
dimensions=dimensions,
path=path,
similarity=similarity,
filters=filters,
wait_until_complete=wait_until_complete,
)
assert index._is_index_ready(collection, index_name)
indexes = list(collection.list_search_indexes())
assert len(indexes) == 1
assert indexes[0]["name"] == index_name
new_similarity = "euclidean"
index.update_vector_search_index(
collection,
index_name,
DIMENSIONS,
"embedding",
new_similarity,
filters=[],
wait_until_complete=wait_until_complete,
)
assert index._is_index_ready(collection, index_name)
indexes = list(collection.list_search_indexes())
assert len(indexes) == 1
assert indexes[0]["name"] == index_name
assert indexes[0]["latestDefinition"]["fields"][0]["similarity"] == new_similarity
index.drop_vector_search_index(
collection, index_name, wait_until_complete=wait_until_complete
)
indexes = list(collection.list_search_indexes())
assert len(indexes) == 0
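Condensed into application code, the create/update/drop lifecycle exercised above looks roughly like the following sketch (constants reused from the test; an Atlas cluster >=M10 is assumed):

from langchain_mongodb import index

index.create_vector_search_index(
    collection=collection,
    index_name="vector_index",
    dimensions=10,
    path="embedding",
    similarity="cosine",
    filters=None,
    wait_until_complete=120,
)
index.update_vector_search_index(
    collection, "vector_index", 10, "embedding", "euclidean",
    filters=[], wait_until_complete=120,
)
index.drop_vector_search_index(collection, "vector_index", wait_until_complete=120)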

View File

@@ -1,176 +0,0 @@
import os
from time import sleep
from typing import List
import pytest
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from pymongo import MongoClient
from pymongo.collection import Collection
from langchain_mongodb import index
from langchain_mongodb.retrievers import (
MongoDBAtlasFullTextSearchRetriever,
MongoDBAtlasHybridSearchRetriever,
)
from ..utils import ConsistentFakeEmbeddings, PatchedMongoDBAtlasVectorSearch
CONNECTION_STRING = os.environ.get("MONGODB_ATLAS_URI")
DB_NAME = "langchain_test_db"
COLLECTION_NAME = "langchain_test_retrievers"
VECTOR_INDEX_NAME = "vector_index"
EMBEDDING_FIELD = "embedding"
PAGE_CONTENT_FIELD = "text"
SEARCH_INDEX_NAME = "text_index"
DIMENSIONS = 1536
TIMEOUT = 60.0
INTERVAL = 0.5
@pytest.fixture
def example_documents() -> List[Document]:
return [
Document(page_content="In 2023, I visited Paris"),
Document(page_content="In 2022, I visited New York"),
Document(page_content="In 2021, I visited New Orleans"),
Document(page_content="Sandwiches are beautiful. Sandwiches are fine."),
]
@pytest.fixture
def embedding_openai() -> Embeddings:
from langchain_openai import OpenAIEmbeddings
try:
return OpenAIEmbeddings(
openai_api_key=os.environ["OPENAI_API_KEY"], # type: ignore # noqa
model="text-embedding-3-small",
)
except Exception:
return ConsistentFakeEmbeddings(DIMENSIONS)
@pytest.fixture
def collection() -> Collection:
"""A Collection with both a Vector and a Full-text Search Index"""
client: MongoClient = MongoClient(CONNECTION_STRING)
if COLLECTION_NAME not in client[DB_NAME].list_collection_names():
clxn = client[DB_NAME].create_collection(COLLECTION_NAME)
else:
clxn = client[DB_NAME][COLLECTION_NAME]
clxn.delete_many({})
if not any([VECTOR_INDEX_NAME == ix["name"] for ix in clxn.list_search_indexes()]):
index.create_vector_search_index(
collection=clxn,
index_name=VECTOR_INDEX_NAME,
dimensions=DIMENSIONS,
path="embedding",
similarity="cosine",
wait_until_complete=TIMEOUT,
)
if not any([SEARCH_INDEX_NAME == ix["name"] for ix in clxn.list_search_indexes()]):
index.create_fulltext_search_index(
collection=clxn,
index_name=SEARCH_INDEX_NAME,
field=PAGE_CONTENT_FIELD,
wait_until_complete=TIMEOUT,
)
return clxn
def test_hybrid_retriever(
embedding_openai: Embeddings,
collection: Collection,
example_documents: List[Document],
) -> None:
"""Test basic usage of MongoDBAtlasHybridSearchRetriever"""
vectorstore = PatchedMongoDBAtlasVectorSearch(
collection=collection,
embedding=embedding_openai,
index_name=VECTOR_INDEX_NAME,
text_key=PAGE_CONTENT_FIELD,
)
vectorstore.add_documents(example_documents)
sleep(TIMEOUT) # Wait for documents to be sync'd
retriever = MongoDBAtlasHybridSearchRetriever(
vectorstore=vectorstore,
search_index_name=SEARCH_INDEX_NAME,
top_k=3,
)
query1 = "What was the latest city that I visited?"
results = retriever.invoke(query1)
assert len(results) == 3
assert "Paris" in results[0].page_content
query2 = "When was the last time I visited new orleans?"
results = retriever.invoke(query2)
assert "New Orleans" in results[0].page_content
def test_fulltext_retriever(
collection: Collection,
example_documents: List[Document],
) -> None:
"""Test result of performing fulltext search
Independent of the VectorStore, one adds documents
via MongoDB's Collection API
"""
collection.insert_many(
[{PAGE_CONTENT_FIELD: doc.page_content} for doc in example_documents]
)
sleep(TIMEOUT) # Wait for documents to be sync'd
retriever = MongoDBAtlasFullTextSearchRetriever(
collection=collection,
search_index_name=SEARCH_INDEX_NAME,
search_field=PAGE_CONTENT_FIELD,
)
query = "When was the last time I visited new orleans?"
results = retriever.invoke(query)
assert "New Orleans" in results[0].page_content
assert "score" in results[0].metadata
def test_vector_retriever(
embedding_openai: Embeddings,
collection: Collection,
example_documents: List[Document],
) -> None:
"""Test VectorStoreRetriever"""
vectorstore = PatchedMongoDBAtlasVectorSearch(
collection=collection,
embedding=embedding_openai,
index_name=VECTOR_INDEX_NAME,
text_key=PAGE_CONTENT_FIELD,
)
vectorstore.add_documents(example_documents)
sleep(TIMEOUT) # Wait for documents to be sync'd
retriever = vectorstore.as_retriever()
query1 = "What was the latest city that I visited?"
results = retriever.invoke(query1)
assert len(results) == 4
assert "Paris" in results[0].page_content
query2 = "When was the last time I visited new orleans?"
results = retriever.invoke(query2)
assert "New Orleans" in results[0].page_content

View File

@@ -1,473 +0,0 @@
"""Test MongoDB Atlas Vector Search functionality."""
from __future__ import annotations
import os
from time import monotonic, sleep
from typing import Any, Dict, List
import pytest # type: ignore[import-not-found]
from bson import ObjectId
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from pymongo import MongoClient
from pymongo.collection import Collection
from pymongo.errors import OperationFailure
from langchain_mongodb.index import drop_vector_search_index
from langchain_mongodb.utils import oid_to_str
from ..utils import ConsistentFakeEmbeddings, PatchedMongoDBAtlasVectorSearch
INDEX_NAME = "langchain-test-index-vectorstores"
INDEX_CREATION_NAME = "langchain-test-index-vectorstores-create-test"
NAMESPACE = "langchain_test_db.langchain_test_vectorstores"
CONNECTION_STRING = os.environ.get("MONGODB_ATLAS_URI")
DB_NAME, COLLECTION_NAME = NAMESPACE.split(".")
INDEX_COLLECTION_NAME = "langchain_test_vectorstores_index"
INDEX_DB_NAME = "langchain_test_index_db"
DIMENSIONS = 1536
TIMEOUT = 120.0
INTERVAL = 0.5
@pytest.fixture
def example_documents() -> List[Document]:
return [
Document(page_content="Dogs are tough.", metadata={"a": 1}),
Document(page_content="Cats have fluff.", metadata={"b": 1}),
Document(page_content="What is a sandwich?", metadata={"c": 1}),
Document(page_content="That fence is purple.", metadata={"d": 1, "e": 2}),
]
def _await_index_deletion(coll: Collection, index_name: str) -> None:
start = monotonic()
try:
drop_vector_search_index(coll, index_name)
except OperationFailure:
# This most likely means an ongoing drop request was made so skip
pass
while list(coll.list_search_indexes(name=index_name)):
if monotonic() - start > TIMEOUT:
raise TimeoutError(f"Index Name: {index_name} never dropped")
sleep(INTERVAL)
def get_collection(
database_name: str = DB_NAME, collection_name: str = COLLECTION_NAME
) -> Collection:
test_client: MongoClient = MongoClient(CONNECTION_STRING)
return test_client[database_name][collection_name]
@pytest.fixture()
def collection() -> Collection:
return get_collection()
@pytest.fixture
def texts() -> List[str]:
return [
"Dogs are tough.",
"Cats have fluff.",
"What is a sandwich?",
"That fence is purple.",
]
@pytest.fixture()
def index_collection() -> Collection:
return get_collection(INDEX_DB_NAME, INDEX_COLLECTION_NAME)
class TestMongoDBAtlasVectorSearch:
@classmethod
def setup_class(cls) -> None:
# ensure the test collection is empty
collection = get_collection()
if collection.count_documents({}):
collection.delete_many({}) # type: ignore[index]
@classmethod
def teardown_class(cls) -> None:
collection = get_collection()
# delete all the documents in the collection
collection.delete_many({}) # type: ignore[index]
@pytest.fixture(autouse=True)
def setup(self) -> None:
collection = get_collection()
# delete all the documents in the collection
collection.delete_many({}) # type: ignore[index]
# delete all indexes on index collection name
_await_index_deletion(
get_collection(INDEX_DB_NAME, INDEX_COLLECTION_NAME), INDEX_CREATION_NAME
)
@pytest.fixture
def embeddings(self) -> Embeddings:
try:
from langchain_openai import OpenAIEmbeddings
return OpenAIEmbeddings(
openai_api_key=os.environ["OPENAI_API_KEY"], # type: ignore # noqa
model="text-embedding-3-small",
)
except Exception:
return ConsistentFakeEmbeddings(DIMENSIONS)
def test_from_documents(
self,
embeddings: Embeddings,
collection: Any,
example_documents: List[Document],
) -> None:
"""Test end to end construction and search."""
vectorstore = PatchedMongoDBAtlasVectorSearch.from_documents(
example_documents,
embedding=embeddings,
collection=collection,
index_name=INDEX_NAME,
)
output = vectorstore.similarity_search("Sandwich", k=1)
assert len(output) == 1
# Check for the presence of the metadata key
assert any(
[key.page_content == output[0].page_content for key in example_documents]
)
def test_from_documents_no_embedding_return(
self,
embeddings: Embeddings,
collection: Any,
example_documents: List[Document],
) -> None:
"""Test end to end construction and search."""
vectorstore = PatchedMongoDBAtlasVectorSearch.from_documents(
example_documents,
embedding=embeddings,
collection=collection,
index_name=INDEX_NAME,
)
output = vectorstore.similarity_search("Sandwich", k=1)
assert len(output) == 1
# Check for presence of embedding in each document
assert all(["embedding" not in key.metadata for key in output])
# Check for the presence of the metadata key
assert any(
[key.page_content == output[0].page_content for key in example_documents]
)
def test_from_documents_embedding_return(
self,
embeddings: Embeddings,
collection: Any,
example_documents: List[Document],
) -> None:
"""Test end to end construction and search."""
vectorstore = PatchedMongoDBAtlasVectorSearch.from_documents(
example_documents,
embedding=embeddings,
collection=collection,
index_name=INDEX_NAME,
)
output = vectorstore.similarity_search("Sandwich", k=1, include_embeddings=True)
assert len(output) == 1
# Check for presence of embedding in each document
assert all([key.metadata.get("embedding") for key in output])
# Check for the presence of the metadata key
assert any(
[key.page_content == output[0].page_content for key in example_documents]
)
def test_from_texts(
self, embeddings: Embeddings, collection: Collection, texts: List[str]
) -> None:
vectorstore = PatchedMongoDBAtlasVectorSearch.from_texts(
texts,
embedding=embeddings,
collection=collection,
index_name=INDEX_NAME,
)
output = vectorstore.similarity_search("Sandwich", k=1)
assert len(output) == 1
def test_from_texts_with_metadatas(
self,
embeddings: Embeddings,
collection: Collection,
texts: List[str],
) -> None:
metadatas = [{"a": 1}, {"b": 1}, {"c": 1}, {"d": 1, "e": 2}]
metakeys = ["a", "b", "c", "d", "e"]
vectorstore = PatchedMongoDBAtlasVectorSearch.from_texts(
texts,
embedding=embeddings,
metadatas=metadatas,
collection=collection,
index_name=INDEX_NAME,
)
output = vectorstore.similarity_search("Sandwich", k=1)
assert len(output) == 1
# Check for the presence of the metadata key
assert any([key in output[0].metadata for key in metakeys])
def test_from_texts_with_metadatas_and_pre_filter(
self, embeddings: Embeddings, collection: Any, texts: List[str]
) -> None:
metadatas = [{"a": 1}, {"b": 1}, {"c": 1}, {"d": 1, "e": 2}]
vectorstore = PatchedMongoDBAtlasVectorSearch.from_texts(
texts,
embedding=embeddings,
metadatas=metadatas,
collection=collection,
index_name=INDEX_NAME,
)
does_not_match_filter = vectorstore.similarity_search(
"Sandwich", k=1, pre_filter={"c": {"$lte": 0}}
)
assert does_not_match_filter == []
matches_filter = vectorstore.similarity_search(
"Sandwich", k=3, pre_filter={"c": {"$gt": 0}}
)
assert len(matches_filter) == 1
def test_mmr(self, embeddings: Embeddings, collection: Any) -> None:
texts = ["foo", "foo", "fou", "foy"]
vectorstore = PatchedMongoDBAtlasVectorSearch.from_texts(
texts,
embedding=embeddings,
collection=collection,
index_name=INDEX_NAME,
)
query = "foo"
output = vectorstore.max_marginal_relevance_search(query, k=10, lambda_mult=0.1)
assert len(output) == len(texts)
assert output[0].page_content == "foo"
assert output[1].page_content != "foo"
def test_retriever(
self,
embeddings: Embeddings,
collection: Any,
example_documents: List[Document],
) -> None:
"""Demonstrate usage and parity of VectorStore similarity_search
with Retriever.invoke."""
vectorstore = PatchedMongoDBAtlasVectorSearch.from_documents(
example_documents,
embedding=embeddings,
collection=collection,
index_name=INDEX_NAME,
)
query = "sandwich"
retriever_default_kwargs = vectorstore.as_retriever()
result_retriever = retriever_default_kwargs.invoke(query)
result_vectorstore = vectorstore.similarity_search(query)
assert all(
[
result_retriever[i].page_content == result_vectorstore[i].page_content
for i in range(len(result_retriever))
]
)
def test_include_embeddings(
self,
embeddings: Embeddings,
collection: Any,
example_documents: List[Document],
) -> None:
"""Test explicitly passing vector kwarg matches default."""
vectorstore = PatchedMongoDBAtlasVectorSearch.from_documents(
documents=example_documents,
embedding=embeddings,
collection=collection,
index_name=INDEX_NAME,
)
output_with = vectorstore.similarity_search(
"Sandwich", include_embeddings=True, k=1
)
assert vectorstore._embedding_key in output_with[0].metadata
output_without = vectorstore.similarity_search("Sandwich", k=1)
assert vectorstore._embedding_key not in output_without[0].metadata
def test_delete(
self, embeddings: Embeddings, collection: Any, texts: List[str]
) -> None:
vectorstore = PatchedMongoDBAtlasVectorSearch(
collection=collection,
embedding=embeddings,
index_name=INDEX_NAME,
)
clxn: Collection = vectorstore._collection
assert clxn.count_documents({}) == 0
ids = vectorstore.add_texts(texts)
assert clxn.count_documents({}) == len(texts)
deleted = vectorstore.delete(ids[-2:])
assert deleted
assert clxn.count_documents({}) == len(texts) - 2
new_ids = vectorstore.add_texts(["Pigs eat stuff", "Pigs eat sandwiches"])
assert set(new_ids).intersection(set(ids)) == set() # new ids will be unique.
assert isinstance(new_ids, list)
assert all(isinstance(i, str) for i in new_ids)
assert len(new_ids) == 2
assert clxn.count_documents({}) == 4
def test_add_texts(
self,
embeddings: Embeddings,
collection: Collection,
texts: List[str],
) -> None:
"""Tests API of add_texts, focussing on id treatment
Warning: This is slow because of the number of cases
"""
metadatas: List[Dict[str, Any]] = [
{"a": 1},
{"b": 1},
{"c": 1},
{"d": 1, "e": 2},
]
vectorstore = PatchedMongoDBAtlasVectorSearch(
collection=collection, embedding=embeddings, index_name=INDEX_NAME
)
# Case 1. Add texts without ids
provided_ids = vectorstore.add_texts(texts=texts, metadatas=metadatas)
all_docs = list(vectorstore._collection.find({}))
assert all("_id" in doc for doc in all_docs)
docids = set(doc["_id"] for doc in all_docs)
assert all(isinstance(_id, ObjectId) for _id in docids)
assert set(provided_ids) == set(oid_to_str(oid) for oid in docids)
# Case 2: Test Document.metadata looks right. i.e. contains _id
search_res = vectorstore.similarity_search_with_score("sandwich", k=1)
doc, score = search_res[0]
assert "_id" in doc.metadata
# Case 3: Add new ids that are 24-char hex strings
hex_ids = [oid_to_str(ObjectId()) for _ in range(2)]
hex_texts = ["Text for hex_id"] * len(hex_ids)
out_ids = vectorstore.add_texts(texts=hex_texts, ids=hex_ids)
assert set(out_ids) == set(hex_ids)
assert collection.count_documents({}) == len(texts) + len(hex_texts)
assert all(
isinstance(doc["_id"], ObjectId) for doc in vectorstore._collection.find({})
)
# Case 4: Add new ids that cannot be cast to ObjectId
# - We can still index and search on them
str_ids = ["Sandwiches are beautiful,", "..sandwiches are fine."]
str_texts = str_ids # No reason for them to differ
out_ids = vectorstore.add_texts(texts=str_texts, ids=str_ids)
assert set(out_ids) == set(str_ids)
assert collection.count_documents({}) == 8
res = vectorstore.similarity_search("sandwich", k=8)
assert any(str_ids[0] in doc.metadata["_id"] for doc in res)
# Case 5: Test adding in multiple batches
batch_size = 2
batch_ids = [oid_to_str(ObjectId()) for _ in range(2 * batch_size)]
batch_texts = [f"Text for batch text {i}" for i in range(2 * batch_size)]
out_ids = vectorstore.add_texts(
texts=batch_texts, ids=batch_ids, batch_size=batch_size
)
assert set(out_ids) == set(batch_ids)
assert collection.count_documents({}) == 12
# Case 6: _ids in metadata
collection.delete_many({})
# 6a. Unique _id in metadata, but ids=None
# Will be added as if ids kwarg provided
i = 0
n = len(texts)
assert len(metadatas) == n
_ids = [str(i) for i in range(n)]
for md in metadatas:
md["_id"] = _ids[i]
i += 1
returned_ids = vectorstore.add_texts(texts=texts, metadatas=metadatas)
assert returned_ids == ["0", "1", "2", "3"]
assert set(d["_id"] for d in vectorstore._collection.find({})) == set(_ids)
# 6b. Unique "id", not "_id", but ids=None
# New ids will be assigned
i = 1
for md in metadatas:
md.pop("_id")
md["id"] = f"{1}"
i += 1
returned_ids = vectorstore.add_texts(texts=texts, metadatas=metadatas)
assert len(set(returned_ids).intersection(set(_ids))) == 0
def test_add_documents(
self,
embeddings: Embeddings,
collection: Collection,
) -> None:
"""Tests add_documents."""
vectorstore = PatchedMongoDBAtlasVectorSearch(
collection=collection, embedding=embeddings, index_name=INDEX_NAME
)
# Case 1: No ids
n_docs = 10
batch_size = 3
docs = [
Document(page_content=f"document {i}", metadata={"i": i})
for i in range(n_docs)
]
result_ids = vectorstore.add_documents(docs, batch_size=batch_size)
assert len(result_ids) == n_docs
assert collection.count_documents({}) == n_docs
# Case 2: ids
collection.delete_many({})
n_docs = 10
batch_size = 3
docs = [
Document(page_content=f"document {i}", metadata={"i": i})
for i in range(n_docs)
]
ids = [str(i) for i in range(n_docs)]
result_ids = vectorstore.add_documents(docs, ids, batch_size=batch_size)
assert len(result_ids) == n_docs
assert set(ids) == set(collection.distinct("_id"))
# Case 3: Single batch
collection.delete_many({})
n_docs = 3
batch_size = 10
docs = [
Document(page_content=f"document {i}", metadata={"i": i})
for i in range(n_docs)
]
ids = [str(i) for i in range(n_docs)]
result_ids = vectorstore.add_documents(docs, ids, batch_size=batch_size)
assert len(result_ids) == n_docs
assert set(ids) == set(collection.distinct("_id"))
def test_index_creation(
self, embeddings: Embeddings, index_collection: Any
) -> None:
vectorstore = PatchedMongoDBAtlasVectorSearch(
index_collection, embedding=embeddings, index_name=INDEX_CREATION_NAME
)
vectorstore.create_vector_search_index(dimensions=1536)
def test_index_update(self, embeddings: Embeddings, index_collection: Any) -> None:
vectorstore = PatchedMongoDBAtlasVectorSearch(
index_collection, embedding=embeddings, index_name=INDEX_CREATION_NAME
)
vectorstore.create_vector_search_index(dimensions=1536)
vectorstore.create_vector_search_index(dimensions=1536, update=True)

View File

@@ -1,215 +0,0 @@
import uuid
from typing import Any, Dict, List, Union
import pytest # type: ignore[import-not-found]
from langchain_core.caches import BaseCache
from langchain_core.embeddings import Embeddings
from langchain_core.globals import get_llm_cache, set_llm_cache
from langchain_core.load.dump import dumps
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langchain_core.outputs import ChatGeneration, Generation, LLMResult
from pymongo.collection import Collection
from langchain_mongodb.cache import MongoDBAtlasSemanticCache, MongoDBCache
from langchain_mongodb.vectorstores import MongoDBAtlasVectorSearch
from ..utils import ConsistentFakeEmbeddings, FakeChatModel, FakeLLM, MockCollection
CONN_STRING = "MockString"
COLLECTION = "default"
DATABASE = "default"
class PatchedMongoDBCache(MongoDBCache):
def __init__(
self,
connection_string: str,
collection_name: str = "default",
database_name: str = "default",
**kwargs: Dict[str, Any],
) -> None:
self.__database_name = database_name
self.__collection_name = collection_name
self.client = {self.__database_name: {self.__collection_name: MockCollection()}} # type: ignore
@property
def database(self) -> Any: # type: ignore
"""Returns the database used to store cache values."""
return self.client[self.__database_name]
@property
def collection(self) -> Collection:
"""Returns the collection used to store cache values."""
return self.database[self.__collection_name]
class PatchedMongoDBAtlasSemanticCache(MongoDBAtlasSemanticCache):
def __init__(
self,
connection_string: str,
embedding: Embeddings,
collection_name: str = "default",
database_name: str = "default",
wait_until_ready: bool = False,
**kwargs: Dict[str, Any],
):
self.collection = MockCollection()
self._wait_until_ready = False
self.score_threshold = None
MongoDBAtlasVectorSearch.__init__(
self,
self.collection,
embedding=embedding,
**kwargs, # type: ignore
)
def random_string() -> str:
return str(uuid.uuid4())
def llm_cache(cls: Any) -> BaseCache:
set_llm_cache(
cls(
embedding=ConsistentFakeEmbeddings(dimensionality=1536),
connection_string=CONN_STRING,
collection_name=COLLECTION,
database_name=DATABASE,
wait_until_ready=15.0,
)
)
assert get_llm_cache()
return get_llm_cache()
def _execute_test(
prompt: Union[str, List[BaseMessage]],
llm: Union[str, FakeLLM, FakeChatModel],
response: List[Generation],
) -> None:
# Fabricate an LLM String
if not isinstance(llm, str):
params = llm.dict()
params["stop"] = None
llm_string = str(sorted([(k, v) for k, v in params.items()]))
else:
llm_string = llm
# If the prompt is a str then we should pass just the string
dumped_prompt: str = prompt if isinstance(prompt, str) else dumps(prompt)
# Update the cache
llm_cache = get_llm_cache()
llm_cache.update(dumped_prompt, llm_string, response)
# Retrieve the cached result through 'generate' call
output: Union[List[Generation], LLMResult, None]
expected_output: Union[List[Generation], LLMResult]
if isinstance(llm_cache, PatchedMongoDBAtlasSemanticCache):
llm_cache._collection._aggregate_result = [ # type: ignore
data
for data in llm_cache._collection._data # type: ignore
if data.get("text") == dumped_prompt
and data.get("llm_string") == llm_string
] # type: ignore
if isinstance(llm, str):
output = get_llm_cache().lookup(dumped_prompt, llm) # type: ignore
expected_output = response
else:
output = llm.generate([prompt]) # type: ignore
expected_output = LLMResult(
generations=[response],
llm_output={},
)
assert output == expected_output # type: ignore
@pytest.mark.parametrize(
"prompt, llm, response",
[
("foo", "bar", [Generation(text="fizz")]),
("foo", FakeLLM(), [Generation(text="fizz")]),
(
[HumanMessage(content="foo")],
FakeChatModel(),
[ChatGeneration(message=AIMessage(content="foo"))],
),
],
ids=[
"plain_cache",
"cache_with_llm",
"cache_with_chat",
],
)
@pytest.mark.parametrize(
"cacher", [PatchedMongoDBCache, PatchedMongoDBAtlasSemanticCache]
)
@pytest.mark.parametrize("remove_score", [True, False])
def test_mongodb_cache(
remove_score: bool,
cacher: Union[MongoDBCache, MongoDBAtlasSemanticCache],
prompt: Union[str, List[BaseMessage]],
llm: Union[str, FakeLLM, FakeChatModel],
response: List[Generation],
) -> None:
llm_cache(cacher)
if remove_score:
get_llm_cache().score_threshold = None # type: ignore
try:
_execute_test(prompt, llm, response)
finally:
get_llm_cache().clear()
@pytest.mark.parametrize(
"prompts, generations",
[
# Single prompt, single generation
([random_string()], [[random_string()]]),
# Single prompt, multiple generations
([random_string()], [[random_string(), random_string()]]),
# Single prompt, multiple generations
([random_string()], [[random_string(), random_string(), random_string()]]),
# Multiple prompts, multiple generations
(
[random_string(), random_string()],
[[random_string()], [random_string(), random_string()]],
),
],
ids=[
"single_prompt_single_generation",
"single_prompt_two_generations",
"single_prompt_three_generations",
"multiple_prompts_multiple_generations",
],
)
def test_mongodb_atlas_cache_matrix(
prompts: List[str],
generations: List[List[str]],
) -> None:
llm_cache(PatchedMongoDBAtlasSemanticCache)
llm = FakeLLM()
# Fabricate an LLM String
params = llm.dict()
params["stop"] = None
llm_string = str(sorted([(k, v) for k, v in params.items()]))
llm_generations = [
[
Generation(text=generation, generation_info=params)
for generation in prompt_i_generations
]
for prompt_i_generations in generations
]
for prompt_i, llm_generations_i in zip(prompts, llm_generations):
_execute_test(prompt_i, llm_string, llm_generations_i)
get_llm_cache()._collection._simulate_cache_aggregation_query = True # type: ignore
assert llm.generate(prompts) == LLMResult(
generations=llm_generations, llm_output={}
)
get_llm_cache().clear()

View File

@@ -1,44 +0,0 @@
import json
from langchain.memory import ConversationBufferMemory # type: ignore[import-not-found]
from langchain_core.messages import message_to_dict
from langchain_mongodb.chat_message_histories import MongoDBChatMessageHistory
from ..utils import MockCollection
class PatchedMongoDBChatMessageHistory(MongoDBChatMessageHistory):
def __init__(self) -> None:
self.session_id = "test-session"
self.database_name = "test-database"
self.collection_name = "test-collection"
self.collection = MockCollection()
self.session_id_key = "SessionId"
self.history_key = "History"
self.history_size = None
def test_memory_with_message_store() -> None:
"""Test the memory with a message store."""
# setup MongoDB as a message store
message_history = PatchedMongoDBChatMessageHistory()
memory = ConversationBufferMemory(
memory_key="baz", chat_memory=message_history, return_messages=True
)
# add some messages
memory.chat_memory.add_ai_message("This is me, the AI")
memory.chat_memory.add_user_message("This is me, the human")
# get the message history from the memory store and turn it into a json
messages = memory.chat_memory.messages
messages_json = json.dumps([message_to_dict(msg) for msg in messages])
assert "This is me, the AI" in messages_json
assert "This is me, the human" in messages_json
# remove the record from MongoDB, so the next test run won't pick it up
memory.chat_memory.clear()
assert memory.chat_memory.messages == []

View File

@@ -1,12 +0,0 @@
from langchain_mongodb import __all__
EXPECTED_ALL = [
"MongoDBAtlasVectorSearch",
"MongoDBChatMessageHistory",
"MongoDBCache",
"MongoDBAtlasSemanticCache",
]
def test_all_imports() -> None:
assert sorted(EXPECTED_ALL) == sorted(__all__)

View File

@@ -1,72 +0,0 @@
"""Search index commands are only supported on Atlas Clusters >=M10"""
import os
from time import sleep
import pytest
from pymongo import MongoClient
from pymongo.collection import Collection
from pymongo.errors import OperationFailure, ServerSelectionTimeoutError
from langchain_mongodb import index
DIMENSION = 10
TIMEOUT = 10
@pytest.fixture
def collection() -> Collection:
"""Depending on uri, this could point to any type of cluster.
For unit tests, MONGODB_URI should be localhost, None, or Atlas cluster <M10.
"""
uri = os.environ.get("MONGODB_URI")
client: MongoClient = MongoClient(uri)
return client["db"]["collection"]
def test_create_vector_search_index(collection: Collection) -> None:
with pytest.raises((OperationFailure, ServerSelectionTimeoutError)):
index.create_vector_search_index(
collection,
"index_name",
DIMENSION,
"embedding",
"cosine",
[],
wait_until_complete=TIMEOUT,
)
def test_drop_vector_search_index(collection: Collection) -> None:
with pytest.raises((OperationFailure, ServerSelectionTimeoutError)):
index.drop_vector_search_index(
collection, "index_name", wait_until_complete=TIMEOUT
)
def test_update_vector_search_index(collection: Collection) -> None:
with pytest.raises((OperationFailure, ServerSelectionTimeoutError)):
index.update_vector_search_index(
collection,
"index_name",
DIMENSION,
"embedding",
"cosine",
[],
wait_until_complete=TIMEOUT,
)
def test___is_index_ready(collection: Collection) -> None:
with pytest.raises((OperationFailure, ServerSelectionTimeoutError)):
index._is_index_ready(collection, "index_name")
def test__wait_for_predicate() -> None:
err = "error string"
with pytest.raises(TimeoutError) as e:
index._wait_for_predicate(lambda: sleep(5), err=err, timeout=0.5, interval=0.1)
assert err in str(e)
index._wait_for_predicate(lambda: True, err=err, timeout=1.0, interval=0.5)
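These two cases pin down the contract of the private _wait_for_predicate helper: poll the predicate every interval seconds, return once it is truthy, and raise a TimeoutError containing err once timeout has elapsed. A plausible reimplementation of that contract (a sketch, not the library's actual source):

from time import monotonic, sleep
from typing import Any, Callable

def wait_for_predicate(
    predicate: Callable[[], Any], err: str, timeout: float = 120.0, interval: float = 0.5
) -> None:
    # Poll until predicate() is truthy; fail loudly once the deadline passes.
    start = monotonic()
    while not predicate():
        if monotonic() - start > timeout:
            raise TimeoutError(err)
        sleep(interval)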

View File

@@ -1,191 +0,0 @@
from json import dumps, loads
from typing import Any, Optional
import pytest # type: ignore[import-not-found]
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from pymongo.collection import Collection
from langchain_mongodb import MongoDBAtlasVectorSearch
from ..utils import ConsistentFakeEmbeddings, MockCollection
INDEX_NAME = "langchain-test-index"
NAMESPACE = "langchain_test_db.langchain_test_collection"
DB_NAME, COLLECTION_NAME = NAMESPACE.split(".")
def get_collection() -> MockCollection:
return MockCollection()
@pytest.fixture()
def collection() -> MockCollection:
return get_collection()
@pytest.fixture(scope="module")
def embedding_openai() -> Embeddings:
return ConsistentFakeEmbeddings()
def test_initialization(collection: Collection, embedding_openai: Embeddings) -> None:
"""Test initialization of vector store class"""
assert MongoDBAtlasVectorSearch(collection, embedding_openai)
def test_init_from_texts(collection: Collection, embedding_openai: Embeddings) -> None:
"""Test from_texts operation on an empty list"""
assert MongoDBAtlasVectorSearch.from_texts(
[], embedding_openai, collection=collection
)
class TestMongoDBAtlasVectorSearch:
@classmethod
def setup_class(cls) -> None:
# ensure the test collection is empty
collection = get_collection()
assert collection.count_documents({}) == 0 # type: ignore[index]
@classmethod
def teardown_class(cls) -> None:
collection = get_collection()
# delete all the documents in the collection
collection.delete_many({}) # type: ignore[index]
@pytest.fixture(autouse=True)
def setup(self) -> None:
collection = get_collection()
# delete all the documents in the collection
collection.delete_many({}) # type: ignore[index]
def _validate_search(
self,
vectorstore: MongoDBAtlasVectorSearch,
collection: MockCollection,
search_term: str = "sandwich",
page_content: str = "What is a sandwich?",
metadata: Optional[Any] = 1,
) -> None:
collection._aggregate_result = list(
filter(
lambda x: search_term.lower() in x[vectorstore._text_key].lower(),
collection._data,
)
)
output = vectorstore.similarity_search("", k=1)
assert output[0].page_content == page_content
assert output[0].metadata.get("c") == metadata
# Validate the ObjectId provided is json serializable
assert loads(dumps(output[0].page_content)) == output[0].page_content
assert loads(dumps(output[0].metadata)) == output[0].metadata
assert isinstance(output[0].metadata["_id"], str)
def test_from_documents(
self, embedding_openai: Embeddings, collection: MockCollection
) -> None:
"""Test end to end construction and search."""
documents = [
Document(page_content="Dogs are tough.", metadata={"a": 1}),
Document(page_content="Cats have fluff.", metadata={"b": 1}),
Document(page_content="What is a sandwich?", metadata={"c": 1}),
Document(page_content="That fence is purple.", metadata={"d": 1, "e": 2}),
]
vectorstore = MongoDBAtlasVectorSearch.from_documents(
documents,
embedding_openai,
collection=collection,
vector_index_name=INDEX_NAME,
)
self._validate_search(
vectorstore, collection, metadata=documents[2].metadata["c"]
)
def test_from_texts(
self, embedding_openai: Embeddings, collection: MockCollection
) -> None:
texts = [
"Dogs are tough.",
"Cats have fluff.",
"What is a sandwich?",
"That fence is purple.",
]
vectorstore = MongoDBAtlasVectorSearch.from_texts(
texts,
embedding_openai,
collection=collection,
vector_index_name=INDEX_NAME,
)
self._validate_search(vectorstore, collection, metadata=None)
def test_from_texts_with_metadatas(
self, embedding_openai: Embeddings, collection: MockCollection
) -> None:
texts = [
"Dogs are tough.",
"Cats have fluff.",
"What is a sandwich?",
"The fence is purple.",
]
metadatas = [{"a": 1}, {"b": 1}, {"c": 1}, {"d": 1, "e": 2}]
vectorstore = MongoDBAtlasVectorSearch.from_texts(
texts,
embedding_openai,
metadatas=metadatas,
collection=collection,
vector_index_name=INDEX_NAME,
)
self._validate_search(vectorstore, collection, metadata=metadatas[2]["c"])
def test_from_texts_with_metadatas_and_pre_filter(
self, embedding_openai: Embeddings, collection: MockCollection
) -> None:
texts = [
"Dogs are tough.",
"Cats have fluff.",
"What is a sandwich?",
"The fence is purple.",
]
metadatas = [{"a": 1}, {"b": 1}, {"c": 1}, {"d": 1, "e": 2}]
vectorstore = MongoDBAtlasVectorSearch.from_texts(
texts,
embedding_openai,
metadatas=metadatas,
collection=collection,
vector_index_name=INDEX_NAME,
)
collection._aggregate_result = list(
filter(
lambda x: "sandwich" in x[vectorstore._text_key].lower()
and x.get("c") < 0,
collection._data,
)
)
output = vectorstore.similarity_search(
"Sandwich", k=1, pre_filter={"range": {"lte": 0, "path": "c"}}
)
assert output == []
def test_mmr(
self, embedding_openai: Embeddings, collection: MockCollection
) -> None:
texts = ["foo", "foo", "fou", "foy"]
vectorstore = MongoDBAtlasVectorSearch.from_texts(
texts,
embedding=embedding_openai,
collection=collection,
vector_index_name=INDEX_NAME,
)
query = "foo"
self._validate_search(
vectorstore,
collection,
search_term=query[0:2],
page_content=query,
metadata=None,
)
output = vectorstore.max_marginal_relevance_search(query, k=10, lambda_mult=0.1)
assert len(output) == len(texts)
assert output[0].page_content == "foo"
assert output[1].page_content != "foo"

View File

@@ -1,273 +0,0 @@
from __future__ import annotations
from copy import deepcopy
from time import monotonic, sleep
from typing import Any, Dict, Generator, Iterable, List, Mapping, Optional, Union, cast
from bson import ObjectId
from langchain_core.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.embeddings import Embeddings
from langchain_core.language_models.chat_models import SimpleChatModel
from langchain_core.language_models.llms import LLM
from langchain_core.messages import (
AIMessage,
BaseMessage,
)
from langchain_core.outputs import ChatGeneration, ChatResult
from pydantic import model_validator
from pymongo.collection import Collection
from pymongo.results import DeleteResult, InsertManyResult
from langchain_mongodb import MongoDBAtlasVectorSearch
from langchain_mongodb.cache import MongoDBAtlasSemanticCache
TIMEOUT = 120
INTERVAL = 0.5
class PatchedMongoDBAtlasVectorSearch(MongoDBAtlasVectorSearch):
def bulk_embed_and_insert_texts(
self,
texts: Union[List[str], Iterable[str]],
metadatas: Union[List[dict], Generator[dict, Any, Any]],
ids: Optional[List[str]] = None,
) -> List:
"""Patched insert_texts that waits for data to be indexed before returning"""
ids_inserted = super().bulk_embed_and_insert_texts(texts, metadatas, ids)
start = monotonic()
while len(ids_inserted) != len(self.similarity_search("sandwich")) and (
monotonic() - start <= TIMEOUT
):
sleep(INTERVAL)
return ids_inserted
def create_vector_search_index(
self,
dimensions: int,
filters: Optional[List[str]] = None,
update: bool = False,
) -> None:
result = super().create_vector_search_index(
dimensions=dimensions, filters=filters, update=update
)
start = monotonic()
while monotonic() - start <= TIMEOUT:
if indexes := list(
self._collection.list_search_indexes(name=self._index_name)
):
if indexes[0].get("status") == "READY":
return result
sleep(INTERVAL)
class ConsistentFakeEmbeddings(Embeddings):
"""Fake embeddings functionality for testing."""
def __init__(self, dimensionality: int = 10) -> None:
self.known_texts: List[str] = []
self.dimensionality = dimensionality
def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Return consistent embeddings for each text seen so far."""
out_vectors = []
for text in texts:
if text not in self.known_texts:
self.known_texts.append(text)
vector = [float(1.0)] * (self.dimensionality - 1) + [
float(self.known_texts.index(text))
]
out_vectors.append(vector)
return out_vectors
def embed_query(self, text: str) -> List[float]:
"""Return consistent embeddings for the text, if seen before, or a constant
one if the text is unknown."""
return self.embed_documents([text])[0]
async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
return self.embed_documents(texts)
async def aembed_query(self, text: str) -> List[float]:
return self.embed_query(text)
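# Because vectors are keyed by first-seen order, ConsistentFakeEmbeddings gives
# stable, repeatable results across calls. Illustrative usage:
#   emb = ConsistentFakeEmbeddings(dimensionality=3)
#   emb.embed_documents(["sandwich"])[0]  # first text seen -> [1.0, 1.0, 0.0]
#   emb.embed_query("sandwich")  # same text again -> the identical vector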
class FakeChatModel(SimpleChatModel):
"""Fake Chat Model wrapper for testing purposes."""
def _call(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
return "fake response"
async def _agenerate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
output_str = "fake response"
message = AIMessage(content=output_str)
generation = ChatGeneration(message=message)
return ChatResult(generations=[generation])
@property
def _llm_type(self) -> str:
return "fake-chat-model"
@property
def _identifying_params(self) -> Dict[str, Any]:
return {"key": "fake"}
class FakeLLM(LLM):
"""Fake LLM wrapper for testing purposes."""
queries: Optional[Mapping] = None
sequential_responses: Optional[bool] = False
response_index: int = 0
@model_validator(mode="before")
@classmethod
def check_queries_required(cls, values: dict) -> dict:
if values.get("sequential_response") and not values.get("queries"):
raise ValueError(
"queries is required when sequential_response is set to True"
)
return values
def get_num_tokens(self, text: str) -> int:
"""Return number of tokens."""
return len(text.split())
@property
def _llm_type(self) -> str:
"""Return type of llm."""
return "fake"
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
if self.sequential_responses:
return self._get_next_response_in_sequence
if self.queries is not None:
return self.queries[prompt]
if stop is None:
return "foo"
else:
return "bar"
@property
def _identifying_params(self) -> Dict[str, Any]:
return {}
@property
def _get_next_response_in_sequence(self) -> str:
queries = cast(Mapping, self.queries)
response = queries[list(queries.keys())[self.response_index]]
self.response_index = self.response_index + 1
return response
class MockCollection(Collection):
"""Mocked Mongo Collection"""
_aggregate_result: List[Any]
_insert_result: Optional[InsertManyResult]
_data: List[Any]
_simulate_cache_aggregation_query: bool
def __init__(self) -> None:
self._data = []
self._aggregate_result = []
self._insert_result = None
self._simulate_cache_aggregation_query = False
def delete_many(self, *args, **kwargs) -> DeleteResult: # type: ignore
old_len = len(self._data)
self._data = []
return DeleteResult({"n": old_len}, acknowledged=True)
def insert_many(self, to_insert: List[Any], *args, **kwargs) -> InsertManyResult: # type: ignore
mongodb_inserts = [
{"_id": ObjectId(), "score": 1, **insert} for insert in to_insert
]
self._data.extend(mongodb_inserts)
return self._insert_result or InsertManyResult(
[k["_id"] for k in mongodb_inserts], acknowledged=True
)
def insert_one(self, to_insert: Any, *args, **kwargs) -> Any: # type: ignore
return self.insert_many([to_insert])
def find_one(self, find_query: Dict[str, Any]) -> Optional[Dict[str, Any]]: # type: ignore
find = self.find(find_query) or [None] # type: ignore
return find[0]
def find(self, find_query: Dict[str, Any]) -> Optional[List[Dict[str, Any]]]: # type: ignore
def _is_match(item: Dict[str, Any]) -> bool:
for key, match_val in find_query.items():
if item.get(key) != match_val:
return False
return True
return [document for document in self._data if _is_match(document)]
def update_one( # type: ignore
self,
find_query: Dict[str, Any],
options: Dict[str, Any],
*args: Any,
upsert=True,
**kwargs: Any,
) -> None: # type: ignore
result = self.find_one(find_query)
set_options = options.get("$set", {})
if result:
result.update(set_options)
elif upsert:
self._data.append({**find_query, **set_options})
def _execute_cache_aggregation_query(self, *args, **kwargs) -> List[Dict[str, Any]]: # type: ignore
"""Helper function only to be used for MongoDBAtlasSemanticCache Testing
Returns:
List[Dict[str, Any]]: Aggregation query result
"""
pipeline: List[Dict[str, Any]] = args[0]
params = pipeline[0]["$vectorSearch"]
embedding = params["queryVector"]
# Assumes MongoDBAtlasSemanticCache.LLM == "llm_string"
llm_string = params["filter"][MongoDBAtlasSemanticCache.LLM]["$eq"]
acc = []
for document in self._data:
if (
document.get("embedding") == embedding
and document.get(MongoDBAtlasSemanticCache.LLM) == llm_string
):
acc.append(document)
return acc
def aggregate(self, *args, **kwargs) -> List[Any]: # type: ignore
if self._simulate_cache_aggregation_query:
return deepcopy(self._execute_cache_aggregation_query(*args, **kwargs))
return deepcopy(self._aggregate_result)
def count_documents(self, *args, **kwargs) -> int: # type: ignore
return len(self._data)
def __repr__(self) -> str:
return "MockCollection"