diff --git a/langchain/llms/openai.py b/langchain/llms/openai.py
index 2b020595..2cca3e2e 100644
--- a/langchain/llms/openai.py
+++ b/langchain/llms/openai.py
@@ -1,6 +1,6 @@
 """Wrapper around OpenAI APIs."""
 import sys
-from typing import Any, Dict, List, Mapping, Optional
+from typing import Any, Dict, Generator, List, Mapping, Optional
 
 from pydantic import BaseModel, Extra, Field, root_validator
 
@@ -160,6 +160,30 @@ class OpenAI(LLM, BaseModel):
             generations=generations, llm_output={"token_usage": token_usage}
         )
 
+    def stream(self, prompt: str) -> Generator:
+        """Call OpenAI with the streaming flag and return the resulting generator.
+
+        Args:
+            prompt: The prompt to pass into the model.
+
+        Returns:
+            A generator representing the stream of tokens from OpenAI.
+
+        Example:
+            .. code-block:: python
+
+                generator = openai.stream("Tell me a joke.")
+                for token in generator:
+                    print(token["choices"][0]["text"])
+        """
+        params = self._default_params
+        if params["best_of"] != 1:
+            raise ValueError("OpenAI only supports best_of == 1 for streaming")
+        params["stream"] = True
+        generator = self.client.create(model=self.model_name, prompt=prompt, **params)
+
+        return generator
+
     @property
     def _identifying_params(self) -> Mapping[str, Any]:
         """Get the identifying parameters."""
diff --git a/tests/integration_tests/llms/test_openai.py b/tests/integration_tests/llms/test_openai.py
index 721b5b43..9c03b92f 100644
--- a/tests/integration_tests/llms/test_openai.py
+++ b/tests/integration_tests/llms/test_openai.py
@@ -1,6 +1,7 @@
 """Test OpenAI API wrapper."""
 
 from pathlib import Path
+from typing import Generator
 
 import pytest
 
@@ -55,3 +56,21 @@ def test_saving_loading_llm(tmp_path: Path) -> None:
     llm.save(file_path=tmp_path / "openai.yaml")
     loaded_llm = load_llm(tmp_path / "openai.yaml")
     assert loaded_llm == llm
+
+
+def test_openai_streaming() -> None:
+    """Test streaming tokens from OpenAI."""
+    llm = OpenAI(max_tokens=10)
+    generator = llm.stream("I'm Pickle Rick")
+
+    assert isinstance(generator, Generator)
+
+    for token in generator:
+        assert isinstance(token["choices"][0]["text"], str)
+
+
+def test_openai_streaming_error() -> None:
+    """Test error handling in stream."""
+    llm = OpenAI(best_of=2)
+    with pytest.raises(ValueError):
+        llm.stream("I'm Pickle Rick")
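Usage sketch for the new stream() method: a minimal consumption loop, assuming OPENAI_API_KEY is set in the environment and that OpenAI is exported from langchain.llms. Each yielded item is a raw OpenAI completion chunk, so the text delta lives at token["choices"][0]["text"], the same path test_openai_streaming asserts against.

.. code-block:: python

    import sys

    from langchain.llms import OpenAI

    llm = OpenAI(max_tokens=32)
    for token in llm.stream("Tell me a joke."):
        # Write each text delta as it arrives instead of waiting for
        # the full completion; flush so output appears immediately.
        sys.stdout.write(token["choices"][0]["text"])
        sys.stdout.flush()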