Compare commits

...

2 Commits

Author SHA1 Message Date
Harrison Chase 6766dff804 custom chain and example 2 years ago
Harrison Chase b5325c212b chain pipelines 2 years ago

@ -0,0 +1,243 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "dd2aa1bb",
"metadata": {},
"outputs": [],
"source": [
"from langchain.chains.pipeline import Pipeline\n",
"from langchain.chains.custom import SimpleCustomChain\n",
"from langchain.llms import OpenAI\n",
"from langchain.chains import LLMChain\n",
"from langchain import Prompt"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "05390d95",
"metadata": {},
"outputs": [],
"source": [
"prompt_template = \"\"\"You are given the below API Documentation:\n",
"\n",
"API Documentation\n",
"The API endpoint /v1/forecast accepts a geographical coordinate, a list of weather variables and responds with a JSON hourly weather forecast for 7 days. Time always starts at 0:00 today and contains 168 hours. All URL parameters are listed below:\n",
"\n",
"Parameter\tFormat\tRequired\tDefault\tDescription\n",
"latitude, longitude\tFloating point\tYes\t\tGeographical WGS84 coordinate of the location\n",
"hourly\tString array\tNo\t\tA list of weather variables which should be returned. Values can be comma separated, or multiple &hourly= parameter in the URL can be used.\n",
"daily\tString array\tNo\t\tA list of daily weather variable aggregations which should be returned. Values can be comma separated, or multiple &daily= parameter in the URL can be used. If daily weather variables are specified, parameter timezone is required.\n",
"current_weather\tBool\tNo\tfalse\tInclude current weather conditions in the JSON output.\n",
"temperature_unit\tString\tNo\tcelsius\tIf fahrenheit is set, all temperature values are converted to Fahrenheit.\n",
"windspeed_unit\tString\tNo\tkmh\tOther wind speed speed units: ms, mph and kn\n",
"precipitation_unit\tString\tNo\tmm\tOther precipitation amount units: inch\n",
"timeformat\tString\tNo\tiso8601\tIf format unixtime is selected, all time values are returned in UNIX epoch time in seconds. Please note that all timestamp are in GMT+0! For daily values with unix timestamps, please apply utc_offset_seconds again to get the correct date.\n",
"timezone\tString\tNo\tGMT\tIf timezone is set, all timestamps are returned as local-time and data is returned starting at 00:00 local-time. Any time zone name from the time zone database is supported. If auto is set as a time zone, the coordinates will be automatically resolved to the local time zone.\n",
"past_days\tInteger (0-2)\tNo\t0\tIf past_days is set, yesterday or the day before yesterday data are also returned.\n",
"start_date\n",
"end_date\tString (yyyy-mm-dd)\tNo\t\tThe time interval to get weather data. A day must be specified as an ISO8601 date (e.g. 2022-06-30).\n",
"models\tString array\tNo\tauto\tManually select one or more weather models. Per default, the best suitable weather models will be combined.\n",
"\n",
"Hourly Parameter Definition\n",
"The parameter &hourly= accepts the following values. Most weather variables are given as an instantaneous value for the indicated hour. Some variables like precipitation are calculated from the preceding hour as an average or sum.\n",
"\n",
"Variable\tValid time\tUnit\tDescription\n",
"temperature_2m\tInstant\t°C (°F)\tAir temperature at 2 meters above ground\n",
"relativehumidity_2m\tInstant\t%\tRelative humidity at 2 meters above ground\n",
"dewpoint_2m\tInstant\t°C (°F)\tDew point temperature at 2 meters above ground\n",
"apparent_temperature\tInstant\t°C (°F)\tApparent temperature is the perceived feels-like temperature combining wind chill factor, relative humidity and solar radiation\n",
"pressure_msl\n",
"surface_pressure\tInstant\thPa\tAtmospheric air pressure reduced to mean sea level (msl) or pressure at surface. Typically pressure on mean sea level is used in meteorology. Surface pressure gets lower with increasing elevation.\n",
"cloudcover\tInstant\t%\tTotal cloud cover as an area fraction\n",
"cloudcover_low\tInstant\t%\tLow level clouds and fog up to 3 km altitude\n",
"cloudcover_mid\tInstant\t%\tMid level clouds from 3 to 8 km altitude\n",
"cloudcover_high\tInstant\t%\tHigh level clouds from 8 km altitude\n",
"windspeed_10m\n",
"windspeed_80m\n",
"windspeed_120m\n",
"windspeed_180m\tInstant\tkm/h (mph, m/s, knots)\tWind speed at 10, 80, 120 or 180 meters above ground. Wind speed on 10 meters is the standard level.\n",
"winddirection_10m\n",
"winddirection_80m\n",
"winddirection_120m\n",
"winddirection_180m\tInstant\t°\tWind direction at 10, 80, 120 or 180 meters above ground\n",
"windgusts_10m\tPreceding hour max\tkm/h (mph, m/s, knots)\tGusts at 10 meters above ground as a maximum of the preceding hour\n",
"shortwave_radiation\tPreceding hour mean\tW/m²\tShortwave solar radiation as average of the preceding hour. This is equal to the total global horizontal irradiation\n",
"direct_radiation\n",
"direct_normal_irradiance\tPreceding hour mean\tW/m²\tDirect solar radiation as average of the preceding hour on the horizontal plane and the normal plane (perpendicular to the sun)\n",
"diffuse_radiation\tPreceding hour mean\tW/m²\tDiffuse solar radiation as average of the preceding hour\n",
"vapor_pressure_deficit\tInstant\tkPa\tVapor Pressure Deificit (VPD) in kilopascal (kPa). For high VPD (>1.6), water transpiration of plants increases. For low VPD (<0.4), transpiration decreases\n",
"evapotranspiration\tPreceding hour sum\tmm (inch)\tEvapotranspration from land surface and plants that weather models assumes for this location. Available soil water is considered. 1 mm evapotranspiration per hour equals 1 liter of water per spare meter.\n",
"et0_fao_evapotranspiration\tPreceding hour sum\tmm (inch)\tET₀ Reference Evapotranspiration of a well watered grass field. Based on FAO-56 Penman-Monteith equations ET₀ is calculated from temperature, wind speed, humidity and solar radiation. Unlimited soil water is assumed. ET₀ is commonly used to estimate the required irrigation for plants.\n",
"precipitation\tPreceding hour sum\tmm (inch)\tTotal precipitation (rain, showers, snow) sum of the preceding hour\n",
"snowfall\tPreceding hour sum\tcm (inch)\tSnowfall amount of the preceding hour in centimeters. For the water equivalent in millimeter, divide by 7. E.g. 7 cm snow = 10 mm precipitation water equivalent\n",
"rain\tPreceding hour sum\tmm (inch)\tRain from large scale weather systems of the preceding hour in millimeter\n",
"showers\tPreceding hour sum\tmm (inch)\tShowers from convective precipitation in millimeters from the preceding hour\n",
"weathercode\tInstant\tWMO code\tWeather condition as a numeric code. Follow WMO weather interpretation codes. See table below for details.\n",
"snow_depth\tInstant\tmeters\tSnow depth on the ground\n",
"freezinglevel_height\tInstant\tmeters\tAltitude above sea level of the 0°C level\n",
"visibility\tInstant\tmeters\tViewing distance in meters. Influenced by low clouds, humidity and aerosols. Maximum visibility is approximately 24 km.\n",
"soil_temperature_0cm\n",
"soil_temperature_6cm\n",
"soil_temperature_18cm\n",
"soil_temperature_54cm\tInstant\t°C (°F)\tTemperature in the soil at 0, 6, 18 and 54 cm depths. 0 cm is the surface temperature on land or water surface temperature on water.\n",
"soil_moisture_0_1cm\n",
"soil_moisture_1_3cm\n",
"soil_moisture_3_9cm\n",
"soil_moisture_9_27cm\n",
"soil_moisture_27_81cm\tInstant\tm³/m³\tAverage soil water content as volumetric mixing ratio at 0-1, 1-3, 3-9, 9-27 and 27-81 cm depths.\n",
"\n",
"Using that documentation, write a query that you could send to the open meteo api to answer the following question.\n",
"\n",
"Question: {question}\n",
"GET Request: /v1/forecast?\"\"\""
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "b95e1af1",
"metadata": {},
"outputs": [],
"source": [
"prompt = Prompt(input_variables=[\"question\"], template=prompt_template)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "f0f3db6e",
"metadata": {},
"outputs": [],
"source": [
"question_chain = LLMChain(llm=OpenAI(temperature=0), prompt=prompt, output_key=\"meteo_query\")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "939c1039",
"metadata": {},
"outputs": [],
"source": [
"import requests\n",
"def make_open_meteo_request(req):\n",
" return str(requests.get(f\"https://api.open-meteo.com/v1/forecast?{req}\").json())"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "88d73fd3",
"metadata": {},
"outputs": [],
"source": [
"custom_chain = SimpleCustomChain(func=make_open_meteo_request, input_key=\"meteo_query\", output_key=\"meteo_answer\")"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "c63dc349",
"metadata": {},
"outputs": [],
"source": [
"answer_prompt_template = \"\"\"You are given the below weather information:\n",
"\n",
"{meteo_answer}\n",
"\n",
"Now answer the following question:\n",
"\n",
"Question: {question}\n",
"Answer:\"\"\"\n",
"answer_prompt = Prompt(input_variables=[\"question\", \"meteo_answer\"], template=answer_prompt_template)\n",
"\n",
"answer_chain = LLMChain(llm=OpenAI(temperature=0), prompt=answer_prompt)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "1842b72d",
"metadata": {},
"outputs": [],
"source": [
"pipeline = Pipeline(chains=[question_chain, custom_chain, answer_chain], input_variables=[\"question\"], verbose=True)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "e381155b",
"metadata": {},
"outputs": [],
"source": [
"question = \"is it snowing in boston?\""
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "d9110ed5",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new chain...\u001b[0m\n",
"\u001b[1mChain 0\u001b[0m:\n",
"{'meteo_query': 'latitude=42.3601&longitude=-71.0589&hourly=snowfall'}\n",
"\n",
"\u001b[1mChain 1\u001b[0m:\n",
"{'meteo_answer': \"{'latitude': 42.36515, 'longitude': -71.0618, 'generationtime_ms': 0.4010200500488281, 'utc_offset_seconds': 0, 'timezone': 'GMT', 'timezone_abbreviation': 'GMT', 'elevation': 9.0, 'hourly_units': {'time': 'iso8601', 'snowfall': 'cm'}, 'hourly': {'time': ['2022-11-19T00:00', '2022-11-19T01:00', '2022-11-19T02:00', '2022-11-19T03:00', '2022-11-19T04:00', '2022-11-19T05:00', '2022-11-19T06:00', '2022-11-19T07:00', '2022-11-19T08:00', '2022-11-19T09:00', '2022-11-19T10:00', '2022-11-19T11:00', '2022-11-19T12:00', '2022-11-19T13:00', '2022-11-19T14:00', '2022-11-19T15:00', '2022-11-19T16:00', '2022-11-19T17:00', '2022-11-19T18:00', '2022-11-19T19:00', '2022-11-19T20:00', '2022-11-19T21:00', '2022-11-19T22:00', '2022-11-19T23:00', '2022-11-20T00:00', '2022-11-20T01:00', '2022-11-20T02:00', '2022-11-20T03:00', '2022-11-20T04:00', '2022-11-20T05:00', '2022-11-20T06:00', '2022-11-20T07:00', '2022-11-20T08:00', '2022-11-20T09:00', '2022-11-20T10:00', '2022-11-20T11:00', '2022-11-20T12:00', '2022-11-20T13:00', '2022-11-20T14:00', '2022-11-20T15:00', '2022-11-20T16:00', '2022-11-20T17:00', '2022-11-20T18:00', '2022-11-20T19:00', '2022-11-20T20:00', '2022-11-20T21:00', '2022-11-20T22:00', '2022-11-20T23:00', '2022-11-21T00:00', '2022-11-21T01:00', '2022-11-21T02:00', '2022-11-21T03:00', '2022-11-21T04:00', '2022-11-21T05:00', '2022-11-21T06:00', '2022-11-21T07:00', '2022-11-21T08:00', '2022-11-21T09:00', '2022-11-21T10:00', '2022-11-21T11:00', '2022-11-21T12:00', '2022-11-21T13:00', '2022-11-21T14:00', '2022-11-21T15:00', '2022-11-21T16:00', '2022-11-21T17:00', '2022-11-21T18:00', '2022-11-21T19:00', '2022-11-21T20:00', '2022-11-21T21:00', '2022-11-21T22:00', '2022-11-21T23:00', '2022-11-22T00:00', '2022-11-22T01:00', '2022-11-22T02:00', '2022-11-22T03:00', '2022-11-22T04:00', '2022-11-22T05:00', '2022-11-22T06:00', '2022-11-22T07:00', '2022-11-22T08:00', '2022-11-22T09:00', '2022-11-22T10:00', '2022-11-22T11:00', '2022-11-22T12:00', '2022-11-22T13:00', '2022-11-22T14:00', '2022-11-22T15:00', '2022-11-22T16:00', '2022-11-22T17:00', '2022-11-22T18:00', '2022-11-22T19:00', '2022-11-22T20:00', '2022-11-22T21:00', '2022-11-22T22:00', '2022-11-22T23:00', '2022-11-23T00:00', '2022-11-23T01:00', '2022-11-23T02:00', '2022-11-23T03:00', '2022-11-23T04:00', '2022-11-23T05:00', '2022-11-23T06:00', '2022-11-23T07:00', '2022-11-23T08:00', '2022-11-23T09:00', '2022-11-23T10:00', '2022-11-23T11:00', '2022-11-23T12:00', '2022-11-23T13:00', '2022-11-23T14:00', '2022-11-23T15:00', '2022-11-23T16:00', '2022-11-23T17:00', '2022-11-23T18:00', '2022-11-23T19:00', '2022-11-23T20:00', '2022-11-23T21:00', '2022-11-23T22:00', '2022-11-23T23:00', '2022-11-24T00:00', '2022-11-24T01:00', '2022-11-24T02:00', '2022-11-24T03:00', '2022-11-24T04:00', '2022-11-24T05:00', '2022-11-24T06:00', '2022-11-24T07:00', '2022-11-24T08:00', '2022-11-24T09:00', '2022-11-24T10:00', '2022-11-24T11:00', '2022-11-24T12:00', '2022-11-24T13:00', '2022-11-24T14:00', '2022-11-24T15:00', '2022-11-24T16:00', '2022-11-24T17:00', '2022-11-24T18:00', '2022-11-24T19:00', '2022-11-24T20:00', '2022-11-24T21:00', '2022-11-24T22:00', '2022-11-24T23:00', '2022-11-25T00:00', '2022-11-25T01:00', '2022-11-25T02:00', '2022-11-25T03:00', '2022-11-25T04:00', '2022-11-25T05:00', '2022-11-25T06:00', '2022-11-25T07:00', '2022-11-25T08:00', '2022-11-25T09:00', '2022-11-25T10:00', '2022-11-25T11:00', '2022-11-25T12:00', '2022-11-25T13:00', '2022-11-25T14:00', '2022-11-25T15:00', '2022-11-25T16:00', '2022-11-25T17:00', '2022-11-25T18:00', '2022-11-25T19:00', '2022-11-25T20:00', '2022-11-25T21:00', '2022-11-25T22:00', '2022-11-25T23:00'], 'snowfall': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.07, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]}}\"}\n",
"\n",
"\u001b[1mChain 2\u001b[0m:\n",
"{'text': ' No'}\n",
"\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"' No'"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pipeline.run(question)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

@ -38,7 +38,9 @@ class Chain(BaseModel, ABC):
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
"""Run the logic of this chain and return the output."""
def __call__(self, inputs: Dict[str, Any]) -> Dict[str, str]:
def __call__(
self, inputs: Dict[str, Any], return_only_outputs: bool = False
) -> Dict[str, str]:
"""Run the logic of this chain and add to output."""
self._validate_inputs(inputs)
if self.verbose:
@ -47,7 +49,10 @@ class Chain(BaseModel, ABC):
if self.verbose:
print("\n\033[1m> Finished chain.\033[0m")
self._validate_outputs(outputs)
return {**inputs, **outputs}
if return_only_outputs:
return outputs
else:
return {**inputs, **outputs}
def apply(self, input_list: List[Dict[str, Any]]) -> List[Dict[str, str]]:
"""Call the chain on all inputs in the list."""

@ -0,0 +1,41 @@
"""Custom chain class."""
from typing import Callable, Dict, List
from pydantic import BaseModel, Extra
from langchain.chains.base import Chain
class SimpleCustomChain(Chain, BaseModel):
"""Custom chain with single string input/output."""
func: Callable[[str], str]
"""Custom callable function."""
input_key: str = "query" #: :meta private:
output_key: str = "result" #: :meta private:
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
@property
def input_keys(self) -> List[str]:
"""Return the singular input key.
:meta private:
"""
return [self.input_key]
@property
def output_keys(self) -> List[str]:
"""Return the singular output key.
:meta private:
"""
return [self.output_key]
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
_input = inputs[self.input_key]
output = self.func(_input)
return {self.output_key: output}

@ -0,0 +1,78 @@
"""Chain pipeline where the outputs of one step feed directly into next."""
from typing import Dict, List
from pydantic import BaseModel, Extra, root_validator
from langchain.chains.base import Chain
class Pipeline(Chain, BaseModel):
"""Chain pipeline where the outputs of one step feed directly into next."""
chains: List[Chain]
input_variables: List[str]
output_variables: List[str] #: :meta private:
return_all: bool = False
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True
@property
def input_keys(self) -> List[str]:
"""Expect input key.
:meta private:
"""
return self.input_variables
@property
def output_keys(self) -> List[str]:
"""Return output key.
:meta private:
"""
return self.output_variables
@root_validator(pre=True)
def validate_chains(cls, values: Dict) -> Dict:
"""Validate that the correct inputs exist for all chains."""
chains = values["chains"]
input_variables = values["input_variables"]
known_variables = set(input_variables)
for chain in chains:
missing_vars = set(chain.input_keys).difference(known_variables)
if missing_vars:
raise ValueError(f"Missing required input keys: {missing_vars}")
overlapping_keys = known_variables.intersection(chain.output_keys)
if overlapping_keys:
raise ValueError(
f"Chain returned keys that already exist: {overlapping_keys}"
)
known_variables |= set(chain.output_keys)
if "output_variables" not in values:
if values.get("return_all", False):
output_keys = known_variables.difference(input_variables)
else:
output_keys = chains[-1].output_keys
values["output_variables"] = output_keys
else:
missing_vars = known_variables.difference(values["output_variables"])
if missing_vars:
raise ValueError(
f"Expected output variables that were not found: {missing_vars}."
)
return values
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
known_values = inputs.copy()
for i, chain in enumerate(self.chains):
outputs = chain(known_values, return_only_outputs=True)
if self.verbose:
print(f"\033[1mChain {i}\033[0m:\n{outputs}\n")
known_values.update(outputs)
return {k: known_values[k] for k in self.output_variables}

@ -0,0 +1,59 @@
"""Simple chain pipeline where the outputs of one step feed directly into next."""
from typing import Dict, List
from pydantic import BaseModel, Extra, root_validator
from langchain.chains.base import Chain
class SimplePipeline(Chain, BaseModel):
"""Simple chain pipeline where the outputs of one step feed directly into next."""
chains: List[Chain]
input_key: str = "input" #: :meta private:
output_key: str = "output" #: :meta private:
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True
@property
def input_keys(self) -> List[str]:
"""Expect input key.
:meta private:
"""
return [self.input_key]
@property
def output_keys(self) -> List[str]:
"""Return output key.
:meta private:
"""
return [self.output_key]
@root_validator()
def validate_chains(cls, values: Dict) -> Dict:
"""Validate that chains are all single input/output."""
for chain in values["chains"]:
if len(chain.input_keys) != 1:
raise ValueError(
"Chains used in SimplePipeline should all have one input, got "
f"{chain} with {len(chain.input_keys)} inputs."
)
if len(chain.output_keys) != 1:
raise ValueError(
"Chains used in SimplePipeline should all have one output, got "
f"{chain} with {len(chain.output_keys)} outputs."
)
return values
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
_input = inputs[self.input_key]
for chain in self.chains:
_input = chain.run(_input)
return {self.output_key: _input}

@ -0,0 +1,103 @@
"""Test pipeline functionality."""
from typing import Dict, List
import pytest
from pydantic import BaseModel
from langchain.chains.base import Chain
from langchain.chains.pipeline import Pipeline
class FakeChain(Chain, BaseModel):
"""Fake Chain for testing purposes."""
input_variables: List[str]
output_variables: List[str]
@property
def input_keys(self) -> List[str]:
"""Input keys this chain returns."""
return self.input_variables
@property
def output_keys(self) -> List[str]:
"""Input keys this chain returns."""
return self.output_variables
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
outputs = {}
for var in self.output_variables:
variables = [inputs[k] for k in self.input_variables]
outputs[var] = " ".join(variables) + "foo"
return outputs
def test_pipeline_usage_single_inputs() -> None:
"""Test pipeline on single input chains."""
chain_1 = FakeChain(input_variables=["foo"], output_variables=["bar"])
chain_2 = FakeChain(input_variables=["bar"], output_variables=["baz"])
pipeline = Pipeline(chains=[chain_1, chain_2], input_variables=["foo"])
output = pipeline({"foo": "123"})
expected_output = {"bar": "123foo", "baz": "123foofoo", "foo": "123"}
assert output == expected_output
def test_pipeline_usage_multiple_inputs() -> None:
"""Test pipeline on multiple input chains."""
chain_1 = FakeChain(input_variables=["foo", "test"], output_variables=["bar"])
chain_2 = FakeChain(input_variables=["bar", "foo"], output_variables=["baz"])
pipeline = Pipeline(chains=[chain_1, chain_2], input_variables=["foo", "test"])
output = pipeline({"foo": "123", "test": "456"})
expected_output = {
"bar": "123 456foo",
"baz": "123 456foo 123foo",
"foo": "123",
"test": "456",
}
assert output == expected_output
def test_pipeline_usage_multiple_outputs() -> None:
"""Test pipeline usage on multiple output chains."""
chain_1 = FakeChain(input_variables=["foo"], output_variables=["bar", "test"])
chain_2 = FakeChain(input_variables=["bar", "foo"], output_variables=["baz"])
pipeline = Pipeline(chains=[chain_1, chain_2], input_variables=["foo"])
output = pipeline({"foo": "123"})
expected_output = {
"bar": "123foo",
"baz": "123foo 123foo",
"foo": "123",
"test": "123foo",
}
assert output == expected_output
def test_pipeline_missing_inputs() -> None:
"""Test error is raised when input variables are missing."""
chain_1 = FakeChain(input_variables=["foo"], output_variables=["bar"])
chain_2 = FakeChain(input_variables=["bar", "test"], output_variables=["baz"])
with pytest.raises(ValueError):
# Also needs "test" as an input
Pipeline(chains=[chain_1, chain_2], input_variables=["foo"])
def test_pipeline_bad_outputs() -> None:
"""Test error is raised when bad outputs are specified."""
chain_1 = FakeChain(input_variables=["foo"], output_variables=["bar"])
chain_2 = FakeChain(input_variables=["bar"], output_variables=["baz"])
with pytest.raises(ValueError):
# "test" is not present as an output variable.
Pipeline(
chains=[chain_1, chain_2],
input_variables=["foo"],
output_variables=["test"],
)
def test_pipeline_overlapping_inputs() -> None:
"""Test error is raised when input variables are overlapping."""
chain_1 = FakeChain(input_variables=["foo"], output_variables=["bar", "test"])
chain_2 = FakeChain(input_variables=["bar"], output_variables=["baz"])
with pytest.raises(ValueError):
# "test" is specified as an input, but also is an output of one step
Pipeline(chains=[chain_1, chain_2], input_variables=["foo", "test"])

@ -0,0 +1,59 @@
"""Test functionality around the simple pipeline chain."""
from typing import Dict, List
import pytest
from pydantic import BaseModel
from langchain.chains.base import Chain
from langchain.chains.simple_pipeline import SimplePipeline
class FakeChain(Chain, BaseModel):
"""Fake chain for testing purposes."""
input_variables: List[str]
output_variables: List[str]
@property
def input_keys(self) -> List[str]:
"""Input keys this chain returns."""
return self.input_variables
@property
def output_keys(self) -> List[str]:
"""Input keys this chain returns."""
return self.output_variables
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
outputs = {}
for var in self.output_variables:
variables = [inputs[k] for k in self.input_variables]
outputs[var] = " ".join(variables) + "foo"
return outputs
def test_pipeline_functionality() -> None:
"""Test simple pipeline functionality."""
chain_1 = FakeChain(input_variables=["foo"], output_variables=["bar"])
chain_2 = FakeChain(input_variables=["bar"], output_variables=["baz"])
pipeline = SimplePipeline(chains=[chain_1, chain_2])
output = pipeline({"input": "123"})
expected_output = {"output": "123foofoo", "input": "123"}
assert output == expected_output
def test_multi_input_errors() -> None:
"""Test pipeline errors if multiple input variables are expected."""
chain_1 = FakeChain(input_variables=["foo"], output_variables=["bar"])
chain_2 = FakeChain(input_variables=["bar", "foo"], output_variables=["baz"])
with pytest.raises(ValueError):
SimplePipeline(chains=[chain_1, chain_2])
def test_multi_output_errors() -> None:
"""Test pipeline errors if multiple output variables are expected."""
chain_1 = FakeChain(input_variables=["foo"], output_variables=["bar", "grok"])
chain_2 = FakeChain(input_variables=["bar"], output_variables=["baz"])
with pytest.raises(ValueError):
SimplePipeline(chains=[chain_1, chain_2])
Loading…
Cancel
Save