langchain/libs/community/tests/integration_tests/storage/test_sql.py
Philippe PRADOS 9aabb446c5
community[minor]: Add SQL storage implementation (#22207)
Hello @eyurtsev

- package: langchain-community
- **Description**: Add a SQL implementation for the docstore. This is a new
implementation, in line with my other PRs ([async
PGVector](https://github.com/langchain-ai/langchain-postgres/pull/32),
[SQLChatMessageMemory](https://github.com/langchain-ai/langchain/pull/22065)).
- Twitter handle: pprados

---------

Signed-off-by: ChengZi <chen.zhang@zilliz.com>
Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com>
Co-authored-by: Piotr Mardziel <piotrm@gmail.com>
Co-authored-by: ChengZi <chen.zhang@zilliz.com>
Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
2024-06-07 21:17:02 +00:00

187 lines
6.5 KiB
Python

"""Implement integration tests for SQL storage."""
import pytest
from sqlalchemy import Engine, create_engine, text
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
from langchain_community.storage import SQLStore
# Skip this module when SQLAlchemy is not installed.
# NOTE(review): sqlalchemy is already imported at the top of this file, so a
# missing install raises ImportError before this guard can run — consider
# moving this call above those imports.
pytest.importorskip("sqlalchemy")
@pytest.fixture
def sql_engine() -> Engine:
    """Provide a synchronous SQLAlchemy engine backed by SQLite.

    Returns:
        An ``Engine`` created from the ``sqlite://`` URL with statement
        echoing enabled.
    """
    engine = create_engine(url="sqlite://", echo=True)
    return engine
@pytest.fixture
def sql_aengine() -> AsyncEngine:
    """Provide an asynchronous SQLAlchemy engine backed by aiosqlite.

    Returns:
        An ``AsyncEngine`` created from the ``sqlite+aiosqlite:///:memory:``
        URL with statement echoing enabled.
    """
    async_engine = create_async_engine(
        url="sqlite+aiosqlite:///:memory:",
        echo=True,
    )
    return async_engine
def test_mget(sql_engine: Engine) -> None:
    """Check that ``mget`` returns the values of rows inserted with raw SQL."""
    store = SQLStore(engine=sql_engine, namespace="test")
    store.create_schema()

    # Seed the table directly so that only the store's read path is exercised.
    seed_statements = [
        text(
            "insert into langchain_key_value_stores ('namespace', 'key', 'value') "
            "values('test','key1',:value)"
        ).bindparams(value=b"value1"),
        text(
            "insert into langchain_key_value_stores ('namespace', 'key', 'value') "
            "values('test','key2',:value)"
        ).bindparams(value=b"value2"),
    ]
    with sql_engine.connect() as connection:
        for statement in seed_statements:
            connection.execute(statement)
        connection.commit()

    fetched = store.mget(["key1", "key2"])
    assert fetched == [b"value1", b"value2"]
@pytest.mark.asyncio
async def test_amget(sql_aengine: AsyncEngine) -> None:
    """Check that ``amget`` returns the values of rows inserted with raw SQL."""
    store = SQLStore(engine=sql_aengine, namespace="test")
    await store.acreate_schema()

    # Seed the table directly so that only the store's read path is exercised.
    seed_statements = [
        text(
            "insert into langchain_key_value_stores ('namespace', 'key', 'value') "
            "values('test','key1',:value)"
        ).bindparams(value=b"value1"),
        text(
            "insert into langchain_key_value_stores ('namespace', 'key', 'value') "
            "values('test','key2',:value)"
        ).bindparams(value=b"value2"),
    ]
    async with sql_aengine.connect() as connection:
        for statement in seed_statements:
            await connection.execute(statement)
        await connection.commit()

    fetched = await store.amget(["key1", "key2"])
    assert fetched == [b"value1", b"value2"]
def test_mset(sql_engine: Engine) -> None:
    """Test that multiple keys can be set.

    Writes two pairs through ``mset`` and then reads the table back with raw
    SQL, checking the column names and the full content of every row
    (namespace, key and stored value).
    """
    store = SQLStore(engine=sql_engine, namespace="test")
    store.create_schema()
    key_value_pairs = [("key1", b"value1"), ("key2", b"value2")]
    store.mset(key_value_pairs)

    with sql_engine.connect() as session:
        result = session.exec_driver_sql("select * from langchain_key_value_stores")
        assert result.keys() == ["namespace", "key", "value"]
        # The query has no ORDER BY, so sort before comparing (as the
        # yield_keys tests do), and include the value column so that the
        # stored payload is actually verified.
        data = sorted((row[0], row[1], row[2]) for row in result)
        assert data == [("test", "key1", b"value1"), ("test", "key2", b"value2")]
        session.commit()
@pytest.mark.asyncio
async def test_amset(sql_aengine: AsyncEngine) -> None:
    """Test that multiple keys can be set asynchronously.

    Writes two pairs through ``amset`` and then reads the table back with raw
    SQL, checking the column names and the full content of every row
    (namespace, key and stored value).
    """
    store = SQLStore(engine=sql_aengine, namespace="test")
    await store.acreate_schema()
    key_value_pairs = [("key1", b"value1"), ("key2", b"value2")]
    await store.amset(key_value_pairs)

    async with sql_aengine.connect() as session:
        result = await session.exec_driver_sql(
            "select * from langchain_key_value_stores"
        )
        assert result.keys() == ["namespace", "key", "value"]
        # The query has no ORDER BY, so sort before comparing (as the
        # yield_keys tests do), and include the value column so that the
        # stored payload is actually verified.
        data = sorted((row[0], row[1], row[2]) for row in result)
        assert data == [("test", "key1", b"value1"), ("test", "key2", b"value2")]
        await session.commit()
def test_mdelete(sql_engine: Engine) -> None:
    """Check that ``mdelete`` removes rows that were inserted with raw SQL."""
    store = SQLStore(engine=sql_engine, namespace="test")
    store.create_schema()

    # Seed the table directly so the deletion has something to remove.
    seed_statements = [
        text(
            "insert into langchain_key_value_stores ('namespace', 'key', 'value') "
            "values('test','key1',:value)"
        ).bindparams(value=b"value1"),
        text(
            "insert into langchain_key_value_stores ('namespace', 'key', 'value') "
            "values('test','key2',:value)"
        ).bindparams(value=b"value2"),
    ]
    with sql_engine.connect() as connection:
        for statement in seed_statements:
            connection.execute(statement)
        connection.commit()

    store.mdelete(["key1", "key2"])

    # Read the table back on a fresh connection; it must now be empty.
    with sql_engine.connect() as connection:
        remaining = connection.exec_driver_sql(
            "select * from langchain_key_value_stores"
        )
        assert remaining.keys() == ["namespace", "key", "value"]
        leftover_rows = list(remaining)
        assert leftover_rows == []
        connection.commit()
@pytest.mark.asyncio
async def test_amdelete(sql_aengine: AsyncEngine) -> None:
    """Check that ``amdelete`` removes rows that were inserted with raw SQL."""
    store = SQLStore(engine=sql_aengine, namespace="test")
    await store.acreate_schema()

    # Seed the table directly so the deletion has something to remove.
    seed_statements = [
        text(
            "insert into langchain_key_value_stores ('namespace', 'key', 'value') "
            "values('test','key1',:value)"
        ).bindparams(value=b"value1"),
        text(
            "insert into langchain_key_value_stores ('namespace', 'key', 'value') "
            "values('test','key2',:value)"
        ).bindparams(value=b"value2"),
    ]
    async with sql_aengine.connect() as connection:
        for statement in seed_statements:
            await connection.execute(statement)
        await connection.commit()

    await store.amdelete(["key1", "key2"])

    # Read the table back on a fresh connection; it must now be empty.
    async with sql_aengine.connect() as connection:
        remaining = await connection.exec_driver_sql(
            "select * from langchain_key_value_stores"
        )
        assert remaining.keys() == ["namespace", "key", "value"]
        leftover_rows = list(remaining)
        assert leftover_rows == []
        await connection.commit()
def test_yield_keys(sql_engine: Engine) -> None:
    """Check ``yield_keys`` without a prefix, with a matching one and with a non-matching one."""
    store = SQLStore(engine=sql_engine, namespace="test")
    store.create_schema()
    store.mset([("key1", b"value1"), ("key2", b"value2")])

    expected_keys = ["key1", "key2"]

    every_key = sorted(store.yield_keys())
    assert every_key == expected_keys

    matching_keys = sorted(store.yield_keys(prefix="key"))
    assert matching_keys == expected_keys

    unmatched_keys = sorted(store.yield_keys(prefix="lang"))
    assert unmatched_keys == []
@pytest.mark.asyncio
async def test_ayield_keys(sql_aengine: AsyncEngine) -> None:
    """Check ``ayield_keys`` without a prefix, with a matching one and with a non-matching one."""
    store = SQLStore(engine=sql_aengine, namespace="test")
    await store.acreate_schema()
    await store.amset([("key1", b"value1"), ("key2", b"value2")])

    expected_keys = ["key1", "key2"]

    every_key = sorted([key async for key in store.ayield_keys()])
    assert every_key == expected_keys

    matching_keys = sorted([key async for key in store.ayield_keys(prefix="key")])
    assert matching_keys == expected_keys

    unmatched_keys = sorted([key async for key in store.ayield_keys(prefix="lang")])
    assert unmatched_keys == []