diff --git a/libs/partners/astradb/langchain_astradb/utils/astradb.py b/libs/partners/astradb/langchain_astradb/utils/astradb.py index dc2c24c3ff..b1869a8bff 100644 --- a/libs/partners/astradb/langchain_astradb/utils/astradb.py +++ b/libs/partners/astradb/langchain_astradb/utils/astradb.py @@ -6,6 +6,7 @@ from asyncio import InvalidStateError, Task from enum import Enum from typing import Awaitable, Optional, Union +import langchain_core from astrapy.db import AstraDB, AsyncAstraDB @@ -51,13 +52,13 @@ class _AstraDBEnvironment: ) if astra_db: - self.astra_db = astra_db + self.astra_db = astra_db.copy() if async_astra_db: - self.async_astra_db = async_astra_db + self.async_astra_db = async_astra_db.copy() else: self.async_astra_db = self.astra_db.to_async() elif async_astra_db: - self.async_astra_db = async_astra_db + self.async_astra_db = async_astra_db.copy() self.astra_db = self.async_astra_db.to_sync() else: raise ValueError( @@ -65,6 +66,15 @@ class _AstraDBEnvironment: "'token' and 'api_endpoint'" ) + self.astra_db.set_caller( + caller_name="langchain", + caller_version=getattr(langchain_core, "__version__", None), + ) + self.async_astra_db.set_caller( + caller_name="langchain", + caller_version=getattr(langchain_core, "__version__", None), + ) + class _AstraDBCollectionEnvironment(_AstraDBEnvironment): def __init__(