mirror of https://github.com/hwchase17/langchain
Enable streaming for OpenAI LLM (#986)
* Support a callback `on_llm_new_token` that users can implement when `OpenAI.streaming` is set to `True`pull/1057/head
parent
f05f025e41
commit
caa8e4742e
@ -0,0 +1,140 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "6eaf7e66-f49c-42da-8d11-22ea13bef718",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Streaming with LLMs\n",
|
||||
"\n",
|
||||
"LangChain provides streaming support for LLMs. Currently, we only support streaming for the `OpenAI` LLM implementation, but streaming support for other LLM implementations is on the roadmap. To utilize streaming, use a [`CallbackHandler`](https://github.com/hwchase17/langchain/blob/master/langchain/callbacks/base.py) that implements `on_llm_new_token`. In this example, we are using [`StreamingStdOutCallbackHandler`]()."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"id": "4ac0ff54-540a-4f2b-8d9a-b590fec7fe07",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"\n",
|
||||
"Verse 1\n",
|
||||
"I'm sippin' on sparkling water,\n",
|
||||
"It's so refreshing and light,\n",
|
||||
"It's the perfect way to quench my thirst,\n",
|
||||
"On a hot summer night.\n",
|
||||
"\n",
|
||||
"Chorus\n",
|
||||
"Sparkling water, sparkling water,\n",
|
||||
"It's the best way to stay hydrated,\n",
|
||||
"It's so refreshing and light,\n",
|
||||
"It's the perfect way to stay alive.\n",
|
||||
"\n",
|
||||
"Verse 2\n",
|
||||
"I'm sippin' on sparkling water,\n",
|
||||
"It's so bubbly and bright,\n",
|
||||
"It's the perfect way to cool me down,\n",
|
||||
"On a hot summer night.\n",
|
||||
"\n",
|
||||
"Chorus\n",
|
||||
"Sparkling water, sparkling water,\n",
|
||||
"It's the best way to stay hydrated,\n",
|
||||
"It's so refreshing and light,\n",
|
||||
"It's the perfect way to stay alive.\n",
|
||||
"\n",
|
||||
"Verse 3\n",
|
||||
"I'm sippin' on sparkling water,\n",
|
||||
"It's so crisp and clean,\n",
|
||||
"It's the perfect way to keep me going,\n",
|
||||
"On a hot summer day.\n",
|
||||
"\n",
|
||||
"Chorus\n",
|
||||
"Sparkling water, sparkling water,\n",
|
||||
"It's the best way to stay hydrated,\n",
|
||||
"It's so refreshing and light,\n",
|
||||
"It's the perfect way to stay alive."
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from langchain.llms import OpenAI\n",
|
||||
"from langchain.callbacks.base import CallbackManager\n",
|
||||
"from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"llm = OpenAI(streaming=True, callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]), verbose=True, temperature=0)\n",
|
||||
"resp = llm(\"Write me a song about sparkling water.\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "61fb6de7-c6c8-48d0-a48e-1204c027a23c",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"source": [
|
||||
"We still have access to the end `LLMResult` if using `generate`. However, `token_usage` is not currently supported for streaming."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"id": "a35373f1-9ee6-4753-a343-5aee749b8527",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"\n",
|
||||
"Q: What did the fish say when it hit the wall?\n",
|
||||
"A: Dam!"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"LLMResult(generations=[[Generation(text='\\n\\nQ: What did the fish say when it hit the wall?\\nA: Dam!', generation_info={'finish_reason': 'stop', 'logprobs': None})]], llm_output={'token_usage': {}})"
|
||||
]
|
||||
},
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"llm.generate([\"Tell me a joke.\"])"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.9"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
@ -0,0 +1,60 @@
|
||||
"""Callback Handler streams to stdout on new llm token."""
|
||||
import sys
|
||||
from typing import Any, Dict, List, Union
|
||||
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
from langchain.schema import AgentAction, AgentFinish, LLMResult
|
||||
|
||||
|
||||
class StreamingStdOutCallbackHandler(BaseCallbackHandler):
|
||||
"""Callback handler for streaming. Only works with LLMs that support streaming."""
|
||||
|
||||
def on_llm_start(
|
||||
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when LLM starts running."""
|
||||
|
||||
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
|
||||
"""Run on new LLM token. Only available when streaming is enabled."""
|
||||
sys.stdout.write(token)
|
||||
sys.stdout.flush()
|
||||
|
||||
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
|
||||
"""Run when LLM ends running."""
|
||||
|
||||
def on_llm_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when LLM errors."""
|
||||
|
||||
def on_chain_start(
|
||||
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when chain starts running."""
|
||||
|
||||
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
|
||||
"""Run when chain ends running."""
|
||||
|
||||
def on_chain_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when chain errors."""
|
||||
|
||||
def on_tool_start(
|
||||
self, serialized: Dict[str, Any], action: AgentAction, **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when tool starts running."""
|
||||
|
||||
def on_tool_end(self, output: str, **kwargs: Any) -> None:
|
||||
"""Run when tool ends running."""
|
||||
|
||||
def on_tool_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when tool errors."""
|
||||
|
||||
def on_text(self, text: str, **kwargs: Any) -> None:
|
||||
"""Run on arbitrary text."""
|
||||
|
||||
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
|
||||
"""Run on agent end."""
|
Loading…
Reference in New Issue