mirror of
https://github.com/hwchase17/langchain
synced 2024-11-10 01:10:59 +00:00
9aabb446c5
Hello @eyurtsev - package: langchain-comminity - **Description**: Add SQL implementation for docstore. A new implementation, in line with my other PR ([async PGVector](https://github.com/langchain-ai/langchain-postgres/pull/32), [SQLChatMessageMemory](https://github.com/langchain-ai/langchain/pull/22065)) - Twitter handler: pprados --------- Signed-off-by: ChengZi <chen.zhang@zilliz.com> Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com> Co-authored-by: Piotr Mardziel <piotrm@gmail.com> Co-authored-by: ChengZi <chen.zhang@zilliz.com> Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
187 lines
6.5 KiB
Python
187 lines
6.5 KiB
Python
"""Implement integration tests for Redis storage."""
|
|
|
|
import pytest
|
|
from sqlalchemy import Engine, create_engine, text
|
|
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
|
|
|
|
from langchain_community.storage import SQLStore
|
|
|
|
pytest.importorskip("sqlalchemy")
|
|
|
|
|
|
@pytest.fixture
|
|
def sql_engine() -> Engine:
|
|
"""Yield redis client."""
|
|
return create_engine(url="sqlite://", echo=True)
|
|
|
|
|
|
@pytest.fixture
|
|
def sql_aengine() -> AsyncEngine:
|
|
"""Yield redis client."""
|
|
return create_async_engine(url="sqlite+aiosqlite:///:memory:", echo=True)
|
|
|
|
|
|
def test_mget(sql_engine: Engine) -> None:
|
|
"""Test mget method."""
|
|
store = SQLStore(engine=sql_engine, namespace="test")
|
|
store.create_schema()
|
|
keys = ["key1", "key2"]
|
|
with sql_engine.connect() as session:
|
|
session.execute(
|
|
text(
|
|
"insert into langchain_key_value_stores ('namespace', 'key', 'value') "
|
|
"values('test','key1',:value)"
|
|
).bindparams(value=b"value1"),
|
|
)
|
|
session.execute(
|
|
text(
|
|
"insert into langchain_key_value_stores ('namespace', 'key', 'value') "
|
|
"values('test','key2',:value)"
|
|
).bindparams(value=b"value2"),
|
|
)
|
|
session.commit()
|
|
|
|
result = store.mget(keys)
|
|
assert result == [b"value1", b"value2"]
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_amget(sql_aengine: AsyncEngine) -> None:
|
|
"""Test mget method."""
|
|
store = SQLStore(engine=sql_aengine, namespace="test")
|
|
await store.acreate_schema()
|
|
keys = ["key1", "key2"]
|
|
async with sql_aengine.connect() as session:
|
|
await session.execute(
|
|
text(
|
|
"insert into langchain_key_value_stores ('namespace', 'key', 'value') "
|
|
"values('test','key1',:value)"
|
|
).bindparams(value=b"value1"),
|
|
)
|
|
await session.execute(
|
|
text(
|
|
"insert into langchain_key_value_stores ('namespace', 'key', 'value') "
|
|
"values('test','key2',:value)"
|
|
).bindparams(value=b"value2"),
|
|
)
|
|
await session.commit()
|
|
|
|
result = await store.amget(keys)
|
|
assert result == [b"value1", b"value2"]
|
|
|
|
|
|
def test_mset(sql_engine: Engine) -> None:
|
|
"""Test that multiple keys can be set."""
|
|
store = SQLStore(engine=sql_engine, namespace="test")
|
|
store.create_schema()
|
|
key_value_pairs = [("key1", b"value1"), ("key2", b"value2")]
|
|
store.mset(key_value_pairs)
|
|
|
|
with sql_engine.connect() as session:
|
|
result = session.exec_driver_sql("select * from langchain_key_value_stores")
|
|
assert result.keys() == ["namespace", "key", "value"]
|
|
data = [(row[0], row[1]) for row in result]
|
|
assert data == [("test", "key1"), ("test", "key2")]
|
|
session.commit()
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_amset(sql_aengine: AsyncEngine) -> None:
|
|
"""Test that multiple keys can be set."""
|
|
store = SQLStore(engine=sql_aengine, namespace="test")
|
|
await store.acreate_schema()
|
|
key_value_pairs = [("key1", b"value1"), ("key2", b"value2")]
|
|
await store.amset(key_value_pairs)
|
|
|
|
async with sql_aengine.connect() as session:
|
|
result = await session.exec_driver_sql(
|
|
"select * from langchain_key_value_stores"
|
|
)
|
|
assert result.keys() == ["namespace", "key", "value"]
|
|
data = [(row[0], row[1]) for row in result]
|
|
assert data == [("test", "key1"), ("test", "key2")]
|
|
await session.commit()
|
|
|
|
|
|
def test_mdelete(sql_engine: Engine) -> None:
|
|
"""Test that deletion works as expected."""
|
|
store = SQLStore(engine=sql_engine, namespace="test")
|
|
store.create_schema()
|
|
keys = ["key1", "key2"]
|
|
with sql_engine.connect() as session:
|
|
session.execute(
|
|
text(
|
|
"insert into langchain_key_value_stores ('namespace', 'key', 'value') "
|
|
"values('test','key1',:value)"
|
|
).bindparams(value=b"value1"),
|
|
)
|
|
session.execute(
|
|
text(
|
|
"insert into langchain_key_value_stores ('namespace', 'key', 'value') "
|
|
"values('test','key2',:value)"
|
|
).bindparams(value=b"value2"),
|
|
)
|
|
session.commit()
|
|
store.mdelete(keys)
|
|
with sql_engine.connect() as session:
|
|
result = session.exec_driver_sql("select * from langchain_key_value_stores")
|
|
assert result.keys() == ["namespace", "key", "value"]
|
|
data = [row for row in result]
|
|
assert data == []
|
|
session.commit()
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_amdelete(sql_aengine: AsyncEngine) -> None:
|
|
"""Test that deletion works as expected."""
|
|
store = SQLStore(engine=sql_aengine, namespace="test")
|
|
await store.acreate_schema()
|
|
keys = ["key1", "key2"]
|
|
async with sql_aengine.connect() as session:
|
|
await session.execute(
|
|
text(
|
|
"insert into langchain_key_value_stores ('namespace', 'key', 'value') "
|
|
"values('test','key1',:value)"
|
|
).bindparams(value=b"value1"),
|
|
)
|
|
await session.execute(
|
|
text(
|
|
"insert into langchain_key_value_stores ('namespace', 'key', 'value') "
|
|
"values('test','key2',:value)"
|
|
).bindparams(value=b"value2"),
|
|
)
|
|
await session.commit()
|
|
await store.amdelete(keys)
|
|
async with sql_aengine.connect() as session:
|
|
result = await session.exec_driver_sql(
|
|
"select * from langchain_key_value_stores"
|
|
)
|
|
assert result.keys() == ["namespace", "key", "value"]
|
|
data = [row for row in result]
|
|
assert data == []
|
|
await session.commit()
|
|
|
|
|
|
def test_yield_keys(sql_engine: Engine) -> None:
|
|
store = SQLStore(engine=sql_engine, namespace="test")
|
|
store.create_schema()
|
|
key_value_pairs = [("key1", b"value1"), ("key2", b"value2")]
|
|
store.mset(key_value_pairs)
|
|
assert sorted(store.yield_keys()) == ["key1", "key2"]
|
|
assert sorted(store.yield_keys(prefix="key")) == ["key1", "key2"]
|
|
assert sorted(store.yield_keys(prefix="lang")) == []
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_ayield_keys(sql_aengine: AsyncEngine) -> None:
|
|
store = SQLStore(engine=sql_aengine, namespace="test")
|
|
await store.acreate_schema()
|
|
key_value_pairs = [("key1", b"value1"), ("key2", b"value2")]
|
|
await store.amset(key_value_pairs)
|
|
assert sorted([k async for k in store.ayield_keys()]) == ["key1", "key2"]
|
|
assert sorted([k async for k in store.ayield_keys(prefix="key")]) == [
|
|
"key1",
|
|
"key2",
|
|
]
|
|
assert sorted([k async for k in store.ayield_keys(prefix="lang")]) == []
|