Add a Transformation Chain (#257)

Arbitrary transformation chains that can be used to add dictionary
extractions from llms/other chains
harrison/promot-mrkl
Akash Samant 1 year ago committed by GitHub
parent b7bef36ee1
commit 48b093823e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -0,0 +1,41 @@
"""Chain that runs an arbitrary python function."""
from typing import Callable, Dict, List
from pydantic import BaseModel
from langchain.chains.base import Chain
class TransformChain(Chain, BaseModel):
"""Chain transform chain output.
Example:
.. code-block:: python
from langchain import TransformChain
transform_chain = TransformChain(input_variables=["text"],
output_variables["entities"], transform=func())
"""
input_variables: List[str]
output_variables: List[str]
transform: Callable[[Dict[str, str]], Dict[str, str]]
@property
def input_keys(self) -> List[str]:
"""Expect input keys.
:meta private:
"""
return self.input_variables
@property
def output_keys(self) -> List[str]:
"""Return output keys.
:meta private:
"""
return self.output_variables
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
return self.transform(inputs)

@ -0,0 +1,40 @@
"""Test transform chain."""
from typing import Dict
import pytest
from langchain.chains.transform import TransformChain
def dummy_transform(inputs: Dict[str, str]) -> Dict[str, str]:
"""Transform a dummy input for tests."""
outputs = inputs
outputs["greeting"] = f"{inputs['first_name']} {inputs['last_name']} says hello"
del outputs["first_name"]
del outputs["last_name"]
return outputs
def test_tranform_chain() -> None:
"""Test basic transform chain."""
transform_chain = TransformChain(
input_variables=["first_name", "last_name"],
output_variables=["greeting"],
transform=dummy_transform,
)
input_dict = {"first_name": "Leroy", "last_name": "Jenkins"}
response = transform_chain(input_dict)
expected_response = {"greeting": "Leroy Jenkins says hello"}
assert response == expected_response
def test_transform_chain_bad_inputs() -> None:
"""Test basic transform chain."""
transform_chain = TransformChain(
input_variables=["first_name", "last_name"],
output_variables=["greeting"],
transform=dummy_transform,
)
input_dict = {"name": "Leroy", "last_name": "Jenkins"}
with pytest.raises(ValueError):
_ = transform_chain(input_dict)
Loading…
Cancel
Save