custom chain and example

harrison/custom_pipeline
Harrison Chase 2 years ago
parent b5325c212b
commit 6766dff804

@ -0,0 +1,243 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "dd2aa1bb",
"metadata": {},
"outputs": [],
"source": [
"from langchain.chains.pipeline import Pipeline\n",
"from langchain.chains.custom import SimpleCustomChain\n",
"from langchain.llms import OpenAI\n",
"from langchain.chains import LLMChain\n",
"from langchain import Prompt"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "05390d95",
"metadata": {},
"outputs": [],
"source": [
"prompt_template = \"\"\"You are given the below API Documentation:\n",
"\n",
"API Documentation\n",
"The API endpoint /v1/forecast accepts a geographical coordinate, a list of weather variables and responds with a JSON hourly weather forecast for 7 days. Time always starts at 0:00 today and contains 168 hours. All URL parameters are listed below:\n",
"\n",
"Parameter\tFormat\tRequired\tDefault\tDescription\n",
"latitude, longitude\tFloating point\tYes\t\tGeographical WGS84 coordinate of the location\n",
"hourly\tString array\tNo\t\tA list of weather variables which should be returned. Values can be comma separated, or multiple &hourly= parameter in the URL can be used.\n",
"daily\tString array\tNo\t\tA list of daily weather variable aggregations which should be returned. Values can be comma separated, or multiple &daily= parameter in the URL can be used. If daily weather variables are specified, parameter timezone is required.\n",
"current_weather\tBool\tNo\tfalse\tInclude current weather conditions in the JSON output.\n",
"temperature_unit\tString\tNo\tcelsius\tIf fahrenheit is set, all temperature values are converted to Fahrenheit.\n",
"windspeed_unit\tString\tNo\tkmh\tOther wind speed speed units: ms, mph and kn\n",
"precipitation_unit\tString\tNo\tmm\tOther precipitation amount units: inch\n",
"timeformat\tString\tNo\tiso8601\tIf format unixtime is selected, all time values are returned in UNIX epoch time in seconds. Please note that all timestamp are in GMT+0! For daily values with unix timestamps, please apply utc_offset_seconds again to get the correct date.\n",
"timezone\tString\tNo\tGMT\tIf timezone is set, all timestamps are returned as local-time and data is returned starting at 00:00 local-time. Any time zone name from the time zone database is supported. If auto is set as a time zone, the coordinates will be automatically resolved to the local time zone.\n",
"past_days\tInteger (0-2)\tNo\t0\tIf past_days is set, yesterday or the day before yesterday data are also returned.\n",
"start_date\n",
"end_date\tString (yyyy-mm-dd)\tNo\t\tThe time interval to get weather data. A day must be specified as an ISO8601 date (e.g. 2022-06-30).\n",
"models\tString array\tNo\tauto\tManually select one or more weather models. Per default, the best suitable weather models will be combined.\n",
"\n",
"Hourly Parameter Definition\n",
"The parameter &hourly= accepts the following values. Most weather variables are given as an instantaneous value for the indicated hour. Some variables like precipitation are calculated from the preceding hour as an average or sum.\n",
"\n",
"Variable\tValid time\tUnit\tDescription\n",
"temperature_2m\tInstant\t°C (°F)\tAir temperature at 2 meters above ground\n",
"relativehumidity_2m\tInstant\t%\tRelative humidity at 2 meters above ground\n",
"dewpoint_2m\tInstant\t°C (°F)\tDew point temperature at 2 meters above ground\n",
"apparent_temperature\tInstant\t°C (°F)\tApparent temperature is the perceived feels-like temperature combining wind chill factor, relative humidity and solar radiation\n",
"pressure_msl\n",
"surface_pressure\tInstant\thPa\tAtmospheric air pressure reduced to mean sea level (msl) or pressure at surface. Typically pressure on mean sea level is used in meteorology. Surface pressure gets lower with increasing elevation.\n",
"cloudcover\tInstant\t%\tTotal cloud cover as an area fraction\n",
"cloudcover_low\tInstant\t%\tLow level clouds and fog up to 3 km altitude\n",
"cloudcover_mid\tInstant\t%\tMid level clouds from 3 to 8 km altitude\n",
"cloudcover_high\tInstant\t%\tHigh level clouds from 8 km altitude\n",
"windspeed_10m\n",
"windspeed_80m\n",
"windspeed_120m\n",
"windspeed_180m\tInstant\tkm/h (mph, m/s, knots)\tWind speed at 10, 80, 120 or 180 meters above ground. Wind speed on 10 meters is the standard level.\n",
"winddirection_10m\n",
"winddirection_80m\n",
"winddirection_120m\n",
"winddirection_180m\tInstant\t°\tWind direction at 10, 80, 120 or 180 meters above ground\n",
"windgusts_10m\tPreceding hour max\tkm/h (mph, m/s, knots)\tGusts at 10 meters above ground as a maximum of the preceding hour\n",
"shortwave_radiation\tPreceding hour mean\tW/m²\tShortwave solar radiation as average of the preceding hour. This is equal to the total global horizontal irradiation\n",
"direct_radiation\n",
"direct_normal_irradiance\tPreceding hour mean\tW/m²\tDirect solar radiation as average of the preceding hour on the horizontal plane and the normal plane (perpendicular to the sun)\n",
"diffuse_radiation\tPreceding hour mean\tW/m²\tDiffuse solar radiation as average of the preceding hour\n",
"vapor_pressure_deficit\tInstant\tkPa\tVapor Pressure Deificit (VPD) in kilopascal (kPa). For high VPD (>1.6), water transpiration of plants increases. For low VPD (<0.4), transpiration decreases\n",
"evapotranspiration\tPreceding hour sum\tmm (inch)\tEvapotranspration from land surface and plants that weather models assumes for this location. Available soil water is considered. 1 mm evapotranspiration per hour equals 1 liter of water per spare meter.\n",
"et0_fao_evapotranspiration\tPreceding hour sum\tmm (inch)\tET₀ Reference Evapotranspiration of a well watered grass field. Based on FAO-56 Penman-Monteith equations ET₀ is calculated from temperature, wind speed, humidity and solar radiation. Unlimited soil water is assumed. ET₀ is commonly used to estimate the required irrigation for plants.\n",
"precipitation\tPreceding hour sum\tmm (inch)\tTotal precipitation (rain, showers, snow) sum of the preceding hour\n",
"snowfall\tPreceding hour sum\tcm (inch)\tSnowfall amount of the preceding hour in centimeters. For the water equivalent in millimeter, divide by 7. E.g. 7 cm snow = 10 mm precipitation water equivalent\n",
"rain\tPreceding hour sum\tmm (inch)\tRain from large scale weather systems of the preceding hour in millimeter\n",
"showers\tPreceding hour sum\tmm (inch)\tShowers from convective precipitation in millimeters from the preceding hour\n",
"weathercode\tInstant\tWMO code\tWeather condition as a numeric code. Follow WMO weather interpretation codes. See table below for details.\n",
"snow_depth\tInstant\tmeters\tSnow depth on the ground\n",
"freezinglevel_height\tInstant\tmeters\tAltitude above sea level of the 0°C level\n",
"visibility\tInstant\tmeters\tViewing distance in meters. Influenced by low clouds, humidity and aerosols. Maximum visibility is approximately 24 km.\n",
"soil_temperature_0cm\n",
"soil_temperature_6cm\n",
"soil_temperature_18cm\n",
"soil_temperature_54cm\tInstant\t°C (°F)\tTemperature in the soil at 0, 6, 18 and 54 cm depths. 0 cm is the surface temperature on land or water surface temperature on water.\n",
"soil_moisture_0_1cm\n",
"soil_moisture_1_3cm\n",
"soil_moisture_3_9cm\n",
"soil_moisture_9_27cm\n",
"soil_moisture_27_81cm\tInstant\tm³/m³\tAverage soil water content as volumetric mixing ratio at 0-1, 1-3, 3-9, 9-27 and 27-81 cm depths.\n",
"\n",
"Using that documentation, write a query that you could send to the open meteo api to answer the following question.\n",
"\n",
"Question: {question}\n",
"GET Request: /v1/forecast?\"\"\""
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "b95e1af1",
"metadata": {},
"outputs": [],
"source": [
"prompt = Prompt(input_variables=[\"question\"], template=prompt_template)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "f0f3db6e",
"metadata": {},
"outputs": [],
"source": [
"question_chain = LLMChain(llm=OpenAI(temperature=0), prompt=prompt, output_key=\"meteo_query\")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "939c1039",
"metadata": {},
"outputs": [],
"source": [
"import requests\n",
"def make_open_meteo_request(req):\n",
" return str(requests.get(f\"https://api.open-meteo.com/v1/forecast?{req}\").json())"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "88d73fd3",
"metadata": {},
"outputs": [],
"source": [
"custom_chain = SimpleCustomChain(func=make_open_meteo_request, input_key=\"meteo_query\", output_key=\"meteo_answer\")"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "c63dc349",
"metadata": {},
"outputs": [],
"source": [
"answer_prompt_template = \"\"\"You are given the below weather information:\n",
"\n",
"{meteo_answer}\n",
"\n",
"Now answer the following question:\n",
"\n",
"Question: {question}\n",
"Answer:\"\"\"\n",
"answer_prompt = Prompt(input_variables=[\"question\", \"meteo_answer\"], template=answer_prompt_template)\n",
"\n",
"answer_chain = LLMChain(llm=OpenAI(temperature=0), prompt=answer_prompt)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "1842b72d",
"metadata": {},
"outputs": [],
"source": [
"pipeline = Pipeline(chains=[question_chain, custom_chain, answer_chain], input_variables=[\"question\"], verbose=True)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "e381155b",
"metadata": {},
"outputs": [],
"source": [
"question = \"is it snowing in boston?\""
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "d9110ed5",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new chain...\u001b[0m\n",
"\u001b[1mChain 0\u001b[0m:\n",
"{'meteo_query': 'latitude=42.3601&longitude=-71.0589&hourly=snowfall'}\n",
"\n",
"\u001b[1mChain 1\u001b[0m:\n",
"{'meteo_answer': \"{'latitude': 42.36515, 'longitude': -71.0618, 'generationtime_ms': 0.4010200500488281, 'utc_offset_seconds': 0, 'timezone': 'GMT', 'timezone_abbreviation': 'GMT', 'elevation': 9.0, 'hourly_units': {'time': 'iso8601', 'snowfall': 'cm'}, 'hourly': {'time': ['2022-11-19T00:00', '2022-11-19T01:00', '2022-11-19T02:00', '2022-11-19T03:00', '2022-11-19T04:00', '2022-11-19T05:00', '2022-11-19T06:00', '2022-11-19T07:00', '2022-11-19T08:00', '2022-11-19T09:00', '2022-11-19T10:00', '2022-11-19T11:00', '2022-11-19T12:00', '2022-11-19T13:00', '2022-11-19T14:00', '2022-11-19T15:00', '2022-11-19T16:00', '2022-11-19T17:00', '2022-11-19T18:00', '2022-11-19T19:00', '2022-11-19T20:00', '2022-11-19T21:00', '2022-11-19T22:00', '2022-11-19T23:00', '2022-11-20T00:00', '2022-11-20T01:00', '2022-11-20T02:00', '2022-11-20T03:00', '2022-11-20T04:00', '2022-11-20T05:00', '2022-11-20T06:00', '2022-11-20T07:00', '2022-11-20T08:00', '2022-11-20T09:00', '2022-11-20T10:00', '2022-11-20T11:00', '2022-11-20T12:00', '2022-11-20T13:00', '2022-11-20T14:00', '2022-11-20T15:00', '2022-11-20T16:00', '2022-11-20T17:00', '2022-11-20T18:00', '2022-11-20T19:00', '2022-11-20T20:00', '2022-11-20T21:00', '2022-11-20T22:00', '2022-11-20T23:00', '2022-11-21T00:00', '2022-11-21T01:00', '2022-11-21T02:00', '2022-11-21T03:00', '2022-11-21T04:00', '2022-11-21T05:00', '2022-11-21T06:00', '2022-11-21T07:00', '2022-11-21T08:00', '2022-11-21T09:00', '2022-11-21T10:00', '2022-11-21T11:00', '2022-11-21T12:00', '2022-11-21T13:00', '2022-11-21T14:00', '2022-11-21T15:00', '2022-11-21T16:00', '2022-11-21T17:00', '2022-11-21T18:00', '2022-11-21T19:00', '2022-11-21T20:00', '2022-11-21T21:00', '2022-11-21T22:00', '2022-11-21T23:00', '2022-11-22T00:00', '2022-11-22T01:00', '2022-11-22T02:00', '2022-11-22T03:00', '2022-11-22T04:00', '2022-11-22T05:00', '2022-11-22T06:00', '2022-11-22T07:00', '2022-11-22T08:00', '2022-11-22T09:00', '2022-11-22T10:00', '2022-11-22T11:00', '2022-11-22T12:00', '2022-11-22T13:00', '2022-11-22T14:00', '2022-11-22T15:00', '2022-11-22T16:00', '2022-11-22T17:00', '2022-11-22T18:00', '2022-11-22T19:00', '2022-11-22T20:00', '2022-11-22T21:00', '2022-11-22T22:00', '2022-11-22T23:00', '2022-11-23T00:00', '2022-11-23T01:00', '2022-11-23T02:00', '2022-11-23T03:00', '2022-11-23T04:00', '2022-11-23T05:00', '2022-11-23T06:00', '2022-11-23T07:00', '2022-11-23T08:00', '2022-11-23T09:00', '2022-11-23T10:00', '2022-11-23T11:00', '2022-11-23T12:00', '2022-11-23T13:00', '2022-11-23T14:00', '2022-11-23T15:00', '2022-11-23T16:00', '2022-11-23T17:00', '2022-11-23T18:00', '2022-11-23T19:00', '2022-11-23T20:00', '2022-11-23T21:00', '2022-11-23T22:00', '2022-11-23T23:00', '2022-11-24T00:00', '2022-11-24T01:00', '2022-11-24T02:00', '2022-11-24T03:00', '2022-11-24T04:00', '2022-11-24T05:00', '2022-11-24T06:00', '2022-11-24T07:00', '2022-11-24T08:00', '2022-11-24T09:00', '2022-11-24T10:00', '2022-11-24T11:00', '2022-11-24T12:00', '2022-11-24T13:00', '2022-11-24T14:00', '2022-11-24T15:00', '2022-11-24T16:00', '2022-11-24T17:00', '2022-11-24T18:00', '2022-11-24T19:00', '2022-11-24T20:00', '2022-11-24T21:00', '2022-11-24T22:00', '2022-11-24T23:00', '2022-11-25T00:00', '2022-11-25T01:00', '2022-11-25T02:00', '2022-11-25T03:00', '2022-11-25T04:00', '2022-11-25T05:00', '2022-11-25T06:00', '2022-11-25T07:00', '2022-11-25T08:00', '2022-11-25T09:00', '2022-11-25T10:00', '2022-11-25T11:00', '2022-11-25T12:00', '2022-11-25T13:00', '2022-11-25T14:00', '2022-11-25T15:00', '2022-11-25T16:00', '2022-11-25T17:00', '2022-11-25T18:00', '2022-11-25T19:00', '2022-11-25T20:00', '2022-11-25T21:00', '2022-11-25T22:00', '2022-11-25T23:00'], 'snowfall': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.07, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]}}\"}\n",
"\n",
"\u001b[1mChain 2\u001b[0m:\n",
"{'text': ' No'}\n",
"\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"' No'"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pipeline.run(question)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

@ -38,7 +38,9 @@ class Chain(BaseModel, ABC):
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
"""Run the logic of this chain and return the output."""
def __call__(self, inputs: Dict[str, Any]) -> Dict[str, str]:
def __call__(
self, inputs: Dict[str, Any], return_only_outputs: bool = False
) -> Dict[str, str]:
"""Run the logic of this chain and add to output."""
self._validate_inputs(inputs)
if self.verbose:
@ -47,7 +49,10 @@ class Chain(BaseModel, ABC):
if self.verbose:
print("\n\033[1m> Finished chain.\033[0m")
self._validate_outputs(outputs)
return {**inputs, **outputs}
if return_only_outputs:
return outputs
else:
return {**inputs, **outputs}
def apply(self, input_list: List[Dict[str, Any]]) -> List[Dict[str, str]]:
"""Call the chain on all inputs in the list."""

@ -0,0 +1,41 @@
"""Custom chain class."""
from typing import Callable, Dict, List
from pydantic import BaseModel, Extra
from langchain.chains.base import Chain
class SimpleCustomChain(Chain, BaseModel):
"""Custom chain with single string input/output."""
func: Callable[[str], str]
"""Custom callable function."""
input_key: str = "query" #: :meta private:
output_key: str = "result" #: :meta private:
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
@property
def input_keys(self) -> List[str]:
"""Return the singular input key.
:meta private:
"""
return [self.input_key]
@property
def output_keys(self) -> List[str]:
"""Return the singular output key.
:meta private:
"""
return [self.output_key]
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
_input = inputs[self.input_key]
output = self.func(_input)
return {self.output_key: output}

@ -13,6 +13,7 @@ class Pipeline(Chain, BaseModel):
chains: List[Chain]
input_variables: List[str]
output_variables: List[str] #: :meta private:
return_all: bool = False
class Config:
"""Configuration for this pydantic object."""
@ -54,7 +55,11 @@ class Pipeline(Chain, BaseModel):
known_variables |= set(chain.output_keys)
if "output_variables" not in values:
values["output_variables"] = known_variables.difference(input_variables)
if values.get("return_all", False):
output_keys = known_variables.difference(input_variables)
else:
output_keys = chains[-1].output_keys
values["output_variables"] = output_keys
else:
missing_vars = known_variables.difference(values["output_variables"])
if missing_vars:
@ -65,7 +70,9 @@ class Pipeline(Chain, BaseModel):
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
known_values = inputs.copy()
for chain in self.chains:
outputs = chain(known_values)
for i, chain in enumerate(self.chains):
outputs = chain(known_values, return_only_outputs=True)
if self.verbose:
print(f"\033[1mChain {i}\033[0m:\n{outputs}\n")
known_values.update(outputs)
return {k: known_values[k] for k in self.output_variables}

Loading…
Cancel
Save