forked from Archives/langchain
Add a Transformation Chain (#257)
Arbitrary transformation chains that can be used to add dictionary extractions from llms/other chainsharrison/promot-mrkl
parent
b7bef36ee1
commit
48b093823e
@ -0,0 +1,41 @@
|
||||
"""Chain that runs an arbitrary python function."""
|
||||
from typing import Callable, Dict, List
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langchain.chains.base import Chain
|
||||
|
||||
|
||||
class TransformChain(Chain, BaseModel):
|
||||
"""Chain transform chain output.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain import TransformChain
|
||||
transform_chain = TransformChain(input_variables=["text"],
|
||||
output_variables["entities"], transform=func())
|
||||
"""
|
||||
|
||||
input_variables: List[str]
|
||||
output_variables: List[str]
|
||||
transform: Callable[[Dict[str, str]], Dict[str, str]]
|
||||
|
||||
@property
|
||||
def input_keys(self) -> List[str]:
|
||||
"""Expect input keys.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
return self.input_variables
|
||||
|
||||
@property
|
||||
def output_keys(self) -> List[str]:
|
||||
"""Return output keys.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
return self.output_variables
|
||||
|
||||
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
|
||||
return self.transform(inputs)
|
@ -0,0 +1,40 @@
|
||||
"""Test transform chain."""
|
||||
from typing import Dict
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain.chains.transform import TransformChain
|
||||
|
||||
|
||||
def dummy_transform(inputs: Dict[str, str]) -> Dict[str, str]:
|
||||
"""Transform a dummy input for tests."""
|
||||
outputs = inputs
|
||||
outputs["greeting"] = f"{inputs['first_name']} {inputs['last_name']} says hello"
|
||||
del outputs["first_name"]
|
||||
del outputs["last_name"]
|
||||
return outputs
|
||||
|
||||
|
||||
def test_tranform_chain() -> None:
|
||||
"""Test basic transform chain."""
|
||||
transform_chain = TransformChain(
|
||||
input_variables=["first_name", "last_name"],
|
||||
output_variables=["greeting"],
|
||||
transform=dummy_transform,
|
||||
)
|
||||
input_dict = {"first_name": "Leroy", "last_name": "Jenkins"}
|
||||
response = transform_chain(input_dict)
|
||||
expected_response = {"greeting": "Leroy Jenkins says hello"}
|
||||
assert response == expected_response
|
||||
|
||||
|
||||
def test_transform_chain_bad_inputs() -> None:
|
||||
"""Test basic transform chain."""
|
||||
transform_chain = TransformChain(
|
||||
input_variables=["first_name", "last_name"],
|
||||
output_variables=["greeting"],
|
||||
transform=dummy_transform,
|
||||
)
|
||||
input_dict = {"name": "Leroy", "last_name": "Jenkins"}
|
||||
with pytest.raises(ValueError):
|
||||
_ = transform_chain(input_dict)
|
Loading…
Reference in New Issue