langchain/tests/unit_tests/chains/test_transform.py
Akash Samant 48b093823e
Add a Transformation Chain (#257)
Arbitrary transformation chains that can be used to add dictionary
extractions from llms/other chains
2022-12-06 21:58:16 -08:00

41 lines
1.2 KiB
Python

"""Test transform chain."""
from typing import Dict
import pytest
from langchain.chains.transform import TransformChain
def dummy_transform(inputs: Dict[str, str]) -> Dict[str, str]:
    """Turn a first/last name pair into a single greeting entry, in place.

    Reads ``inputs["first_name"]`` and ``inputs["last_name"]``, stores the
    sentence ``"<first> <last> says hello"`` under ``"greeting"`` and then
    removes the two name keys.

    The dict that is passed in is the one that gets modified and returned;
    no copy is made.

    NOTE(review): the tests in this module expect the chain response to hold
    only ``greeting``, so they presumably rely on this in-place removal of
    the name keys -- confirm against the chain implementation before
    switching to a copy.

    Raises:
        KeyError: if either name key is missing (raised before ``inputs`` is
            touched).
    """
    # Look both names up first so a missing key fails before any mutation.
    first, last = inputs["first_name"], inputs["last_name"]
    inputs["greeting"] = f"{first} {last} says hello"
    # The name keys have been folded into the greeting; drop them.
    for consumed_key in ("first_name", "last_name"):
        del inputs[consumed_key]
    return inputs
def test_tranform_chain() -> None:
    """Check that a chain given both declared inputs returns the greeting."""
    chain = TransformChain(
        input_variables=["first_name", "last_name"],
        output_variables=["greeting"],
        transform=dummy_transform,
    )
    # Both declared input variables are supplied, so the call should succeed.
    result = chain({"first_name": "Leroy", "last_name": "Jenkins"})
    assert result == {"greeting": "Leroy Jenkins says hello"}
def test_transform_chain_bad_inputs() -> None:
    """Check that calling the chain without a declared input raises ValueError."""
    chain = TransformChain(
        input_variables=["first_name", "last_name"],
        output_variables=["greeting"],
        transform=dummy_transform,
    )
    # "name" is supplied where the declared "first_name" variable is expected.
    malformed_inputs = {"name": "Leroy", "last_name": "Jenkins"}
    with pytest.raises(ValueError):
        chain(malformed_inputs)