From 48b093823e16f343a53696d226a9a4339be1a431 Mon Sep 17 00:00:00 2001 From: Akash Samant <70665700+asamant21@users.noreply.github.com> Date: Tue, 6 Dec 2022 21:58:16 -0800 Subject: [PATCH] Add a Transformation Chain (#257) Arbitrary transformation chains that can be used to add dictionary extractions from llms/other chains --- langchain/chains/transform.py | 41 +++++++++++++++++++++++ tests/unit_tests/chains/test_transform.py | 40 ++++++++++++++++++++++ 2 files changed, 81 insertions(+) create mode 100644 langchain/chains/transform.py create mode 100644 tests/unit_tests/chains/test_transform.py diff --git a/langchain/chains/transform.py b/langchain/chains/transform.py new file mode 100644 index 00000000..f3635671 --- /dev/null +++ b/langchain/chains/transform.py @@ -0,0 +1,41 @@ +"""Chain that runs an arbitrary python function.""" +from typing import Callable, Dict, List + +from pydantic import BaseModel + +from langchain.chains.base import Chain + + +class TransformChain(Chain, BaseModel): + """Chain transform chain output. + + Example: + .. code-block:: python + + from langchain import TransformChain + transform_chain = TransformChain(input_variables=["text"], + output_variables["entities"], transform=func()) + """ + + input_variables: List[str] + output_variables: List[str] + transform: Callable[[Dict[str, str]], Dict[str, str]] + + @property + def input_keys(self) -> List[str]: + """Expect input keys. + + :meta private: + """ + return self.input_variables + + @property + def output_keys(self) -> List[str]: + """Return output keys. + + :meta private: + """ + return self.output_variables + + def _call(self, inputs: Dict[str, str]) -> Dict[str, str]: + return self.transform(inputs) diff --git a/tests/unit_tests/chains/test_transform.py b/tests/unit_tests/chains/test_transform.py new file mode 100644 index 00000000..a4dbca25 --- /dev/null +++ b/tests/unit_tests/chains/test_transform.py @@ -0,0 +1,40 @@ +"""Test transform chain.""" +from typing import Dict + +import pytest + +from langchain.chains.transform import TransformChain + + +def dummy_transform(inputs: Dict[str, str]) -> Dict[str, str]: + """Transform a dummy input for tests.""" + outputs = inputs + outputs["greeting"] = f"{inputs['first_name']} {inputs['last_name']} says hello" + del outputs["first_name"] + del outputs["last_name"] + return outputs + + +def test_tranform_chain() -> None: + """Test basic transform chain.""" + transform_chain = TransformChain( + input_variables=["first_name", "last_name"], + output_variables=["greeting"], + transform=dummy_transform, + ) + input_dict = {"first_name": "Leroy", "last_name": "Jenkins"} + response = transform_chain(input_dict) + expected_response = {"greeting": "Leroy Jenkins says hello"} + assert response == expected_response + + +def test_transform_chain_bad_inputs() -> None: + """Test basic transform chain.""" + transform_chain = TransformChain( + input_variables=["first_name", "last_name"], + output_variables=["greeting"], + transform=dummy_transform, + ) + input_dict = {"name": "Leroy", "last_name": "Jenkins"} + with pytest.raises(ValueError): + _ = transform_chain(input_dict)