mirror of
https://github.com/hwchase17/langchain
synced 2024-10-31 15:20:26 +00:00
387 lines
11 KiB
Plaintext
387 lines
11 KiB
Plaintext
|
{
|
||
|
"cells": [
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {
|
||
|
"id": "g9EmNu5DD9YI"
|
||
|
},
|
||
|
"source": [
|
||
|
"# Custom functions with OpenAI Functions Agent\n",
|
||
|
"\n",
|
||
|
"This notebook goes through how to integrate custom functions with the OpenAI Functions agent."
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {
|
||
|
"id": "LFKylC3CPtTl"
|
||
|
},
|
||
|
"source": [
|
||
|
"Install the libraries required to run this example notebook:\n",
|
||
|
"\n",
|
||
|
"`pip install -q openai langchain yfinance`"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {
|
||
|
"id": "E2DqzmEGDPak"
|
||
|
},
|
||
|
"source": [
|
||
|
"## Define custom functions"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 2,
|
||
|
"metadata": {
|
||
|
"id": "SiucthMs6SIK"
|
||
|
},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"import yfinance as yf\n",
|
||
|
"from datetime import datetime, timedelta\n",
|
||
|
"\n",
|
||
|
"def get_current_stock_price(ticker):\n",
|
||
|
" \"\"\"Method to get current stock price\"\"\"\n",
|
||
|
"\n",
|
||
|
" ticker_data = yf.Ticker(ticker)\n",
|
||
|
" recent = ticker_data.history(period='1d')\n",
|
||
|
" return {\n",
|
||
|
" 'price': recent.iloc[0]['Close'],\n",
|
||
|
" 'currency': ticker_data.info['currency']\n",
|
||
|
" }\n",
|
||
|
"\n",
|
||
|
"def get_stock_performance(ticker, days):\n",
|
||
|
" \"\"\"Method to get stock price change in percentage\"\"\"\n",
|
||
|
"\n",
|
||
|
" past_date = datetime.today() - timedelta(days=days)\n",
|
||
|
" ticker_data = yf.Ticker(ticker)\n",
|
||
|
" history = ticker_data.history(start=past_date)\n",
|
||
|
" old_price = history.iloc[0]['Close']\n",
|
||
|
" current_price = history.iloc[-1]['Close']\n",
|
||
|
" return {\n",
|
||
|
" 'percent_change': ((current_price - old_price)/old_price)*100\n",
|
||
|
" }"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 3,
|
||
|
"metadata": {
|
||
|
"colab": {
|
||
|
"base_uri": "https://localhost:8080/"
|
||
|
},
|
||
|
"id": "vRLINGvQR1rO",
|
||
|
"outputId": "68230a4b-dda2-4273-b956-7439661e3785"
|
||
|
},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"{'price': 334.57000732421875, 'currency': 'USD'}"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 3,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"get_current_stock_price('MSFT')"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 4,
|
||
|
"metadata": {
|
||
|
"colab": {
|
||
|
"base_uri": "https://localhost:8080/"
|
||
|
},
|
||
|
"id": "57T190q235mD",
|
||
|
"outputId": "c6ee66ec-0659-4632-85d1-263b08826e68"
|
||
|
},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"{'percent_change': 1.014466941163018}"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 4,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"get_stock_performance('MSFT', 30)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {
|
||
|
"id": "MT8QsdyBDhwg"
|
||
|
},
|
||
|
"source": [
|
||
|
"## Make custom tools"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 5,
|
||
|
"metadata": {
|
||
|
"id": "NvLOUv-XP3Ap"
|
||
|
},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"from typing import Type\n",
|
||
|
"from pydantic import BaseModel, Field\n",
|
||
|
"from langchain.tools import BaseTool\n",
|
||
|
"\n",
|
||
|
"class CurrentStockPriceInput(BaseModel):\n",
|
||
|
" \"\"\"Inputs for get_current_stock_price\"\"\"\n",
|
||
|
" ticker: str = Field(description=\"Ticker symbol of the stock\")\n",
|
||
|
"\n",
|
||
|
"class CurrentStockPriceTool(BaseTool):\n",
|
||
|
" name = \"get_current_stock_price\"\n",
|
||
|
" description = \"\"\"\n",
|
||
|
" Useful when you want to get current stock price.\n",
|
||
|
" You should enter the stock ticker symbol recognized by the yahoo finance\n",
|
||
|
" \"\"\"\n",
|
||
|
" args_schema: Type[BaseModel] = CurrentStockPriceInput\n",
|
||
|
"\n",
|
||
|
" def _run(self, ticker: str):\n",
|
||
|
" price_response = get_current_stock_price(ticker)\n",
|
||
|
" return price_response\n",
|
||
|
"\n",
|
||
|
" def _arun(self, ticker: str):\n",
|
||
|
" raise NotImplementedError(\"get_current_stock_price does not support async\")\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"class StockPercentChangeInput(BaseModel):\n",
|
||
|
" \"\"\"Inputs for get_stock_performance\"\"\"\n",
|
||
|
" ticker: str = Field(description=\"Ticker symbol of the stock\")\n",
|
||
|
" days: int = Field(description='Timedelta days to get past date from current date')\n",
|
||
|
"\n",
|
||
|
"class StockPerformanceTool(BaseTool):\n",
|
||
|
" name = \"get_stock_performance\"\n",
|
||
|
" description = \"\"\"\n",
|
||
|
" Useful when you want to check performance of the stock.\n",
|
||
|
" You should enter the stock ticker symbol recognized by the yahoo finance.\n",
|
||
|
" You should enter days as number of days from today from which performance needs to be check.\n",
|
||
|
" output will be the change in the stock price represented as a percentage.\n",
|
||
|
" \"\"\"\n",
|
||
|
" args_schema: Type[BaseModel] = StockPercentChangeInput\n",
|
||
|
"\n",
|
||
|
" def _run(self, ticker: str, days: int):\n",
|
||
|
" response = get_stock_performance(ticker, days)\n",
|
||
|
" return response\n",
|
||
|
"\n",
|
||
|
" def _arun(self, ticker: str):\n",
|
||
|
" raise NotImplementedError(\"get_stock_performance does not support async\")"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {
|
||
|
"id": "PVKoqeCyFKHF"
|
||
|
},
|
||
|
"source": [
|
||
|
"## Create Agent"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 6,
|
||
|
"metadata": {
|
||
|
"id": "yY7qNB7vSQGh"
|
||
|
},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"from langchain.agents import AgentType\n",
|
||
|
"from langchain.chat_models import ChatOpenAI\n",
|
||
|
"from langchain.agents import initialize_agent\n",
|
||
|
"\n",
|
||
|
"llm = ChatOpenAI(\n",
|
||
|
" model=\"gpt-3.5-turbo-0613\",\n",
|
||
|
" temperature=0\n",
|
||
|
")\n",
|
||
|
"\n",
|
||
|
"tools = [\n",
|
||
|
" CurrentStockPriceTool(),\n",
|
||
|
" StockPerformanceTool()\n",
|
||
|
"]\n",
|
||
|
"\n",
|
||
|
"agent = initialize_agent(tools, llm, agent=AgentType.OPENAI_FUNCTIONS, verbose=True)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 7,
|
||
|
"metadata": {
|
||
|
"colab": {
|
||
|
"base_uri": "https://localhost:8080/",
|
||
|
"height": 321
|
||
|
},
|
||
|
"id": "4X96xmgwRkcC",
|
||
|
"outputId": "a91b13ef-9643-4f60-d067-c4341e0b285e"
|
||
|
},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"\u001b[1m> Entering new chain...\u001b[0m\n",
|
||
|
"\u001b[32;1m\u001b[1;3m\n",
|
||
|
"Invoking: `get_current_stock_price` with `{'ticker': 'MSFT'}`\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"\u001b[0m\u001b[36;1m\u001b[1;3m{'price': 334.57000732421875, 'currency': 'USD'}\u001b[0m\u001b[32;1m\u001b[1;3m\n",
|
||
|
"Invoking: `get_stock_performance` with `{'ticker': 'MSFT', 'days': 180}`\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"\u001b[0m\u001b[33;1m\u001b[1;3m{'percent_change': 40.163963297187905}\u001b[0m\u001b[32;1m\u001b[1;3mThe current price of Microsoft stock is $334.57 USD. \n",
|
||
|
"\n",
|
||
|
"Over the past 6 months, Microsoft stock has performed well with a 40.16% increase in its price.\u001b[0m\n",
|
||
|
"\n",
|
||
|
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"'The current price of Microsoft stock is $334.57 USD. \\n\\nOver the past 6 months, Microsoft stock has performed well with a 40.16% increase in its price.'"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 7,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"agent.run(\"What is the current price of Microsoft stock? How it has performed over past 6 months?\")"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 9,
|
||
|
"metadata": {
|
||
|
"colab": {
|
||
|
"base_uri": "https://localhost:8080/",
|
||
|
"height": 285
|
||
|
},
|
||
|
"id": "nkZ_vmAcT7Al",
|
||
|
"outputId": "092ebc55-4d28-4a4b-aa2a-98ae47ceec20"
|
||
|
},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"\u001b[1m> Entering new chain...\u001b[0m\n",
|
||
|
"\u001b[32;1m\u001b[1;3m\n",
|
||
|
"Invoking: `get_current_stock_price` with `{'ticker': 'GOOGL'}`\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"\u001b[0m\u001b[36;1m\u001b[1;3m{'price': 118.33000183105469, 'currency': 'USD'}\u001b[0m\u001b[32;1m\u001b[1;3m\n",
|
||
|
"Invoking: `get_current_stock_price` with `{'ticker': 'META'}`\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"\u001b[0m\u001b[36;1m\u001b[1;3m{'price': 287.04998779296875, 'currency': 'USD'}\u001b[0m\u001b[32;1m\u001b[1;3mThe recent stock price of Google (GOOGL) is $118.33 USD and the recent stock price of Meta (META) is $287.05 USD.\u001b[0m\n",
|
||
|
"\n",
|
||
|
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"'The recent stock price of Google (GOOGL) is $118.33 USD and the recent stock price of Meta (META) is $287.05 USD.'"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 9,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"agent.run(\"Give me recent stock prices of Google and Meta?\")"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 11,
|
||
|
"metadata": {
|
||
|
"colab": {
|
||
|
"base_uri": "https://localhost:8080/",
|
||
|
"height": 466
|
||
|
},
|
||
|
"id": "jLU-HjMq7n1o",
|
||
|
"outputId": "a42194dd-26ed-4b5a-d4a2-1038420045c4"
|
||
|
},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"\u001b[1m> Entering new chain...\u001b[0m\n",
|
||
|
"\u001b[32;1m\u001b[1;3m\n",
|
||
|
"Invoking: `get_stock_performance` with `{'ticker': 'MSFT', 'days': 90}`\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"\u001b[0m\u001b[33;1m\u001b[1;3m{'percent_change': 18.043096235165596}\u001b[0m\u001b[32;1m\u001b[1;3m\n",
|
||
|
"Invoking: `get_stock_performance` with `{'ticker': 'GOOGL', 'days': 90}`\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"\u001b[0m\u001b[33;1m\u001b[1;3m{'percent_change': 17.286155760642853}\u001b[0m\u001b[32;1m\u001b[1;3mIn the past 3 months, Microsoft (MSFT) has performed better than Google (GOOGL). Microsoft's stock price has increased by 18.04% while Google's stock price has increased by 17.29%.\u001b[0m\n",
|
||
|
"\n",
|
||
|
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"\"In the past 3 months, Microsoft (MSFT) has performed better than Google (GOOGL). Microsoft's stock price has increased by 18.04% while Google's stock price has increased by 17.29%.\""
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 11,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"agent.run('In the past 3 months, which stock between Microsoft and Google has performed the best?')"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"metadata": {
|
||
|
"colab": {
|
||
|
"provenance": []
|
||
|
},
|
||
|
"kernelspec": {
|
||
|
"display_name": "Python 3 (ipykernel)",
|
||
|
"language": "python",
|
||
|
"name": "python3"
|
||
|
},
|
||
|
"language_info": {
|
||
|
"codemirror_mode": {
|
||
|
"name": "ipython",
|
||
|
"version": 3
|
||
|
},
|
||
|
"file_extension": ".py",
|
||
|
"mimetype": "text/x-python",
|
||
|
"name": "python",
|
||
|
"nbconvert_exporter": "python",
|
||
|
"pygments_lexer": "ipython3",
|
||
|
"version": "3.9.16"
|
||
|
}
|
||
|
},
|
||
|
"nbformat": 4,
|
||
|
"nbformat_minor": 1
|
||
|
}
|