diff --git a/langchain/llms/openai.py b/langchain/llms/openai.py
index 476d7dca34..b69acc868c 100644
--- a/langchain/llms/openai.py
+++ b/langchain/llms/openai.py
@@ -182,7 +182,7 @@ class BaseOpenAI(BaseLLM, BaseModel):
             generations=generations, llm_output={"token_usage": token_usage}
         )
 
-    def stream(self, prompt: str) -> Generator:
+    def stream(self, prompt: str, stop: Optional[List[str]] = None) -> Generator:
         """Call OpenAI with streaming flag and return the resulting generator.
 
         BETA: this is a beta feature while we figure out the right abstraction.
@@ -190,6 +190,7 @@ class BaseOpenAI(BaseLLM, BaseModel):
 
         Args:
             prompt: The prompts to pass into the model.
+            stop: Optional list of stop words to use when generating.
 
         Returns:
             A generator representing the stream of tokens from OpenAI.
@@ -204,6 +205,10 @@ class BaseOpenAI(BaseLLM, BaseModel):
         params = self._invocation_params
         if params["best_of"] != 1:
             raise ValueError("OpenAI only supports best_of == 1 for streaming")
+        if stop is not None:
+            if "stop" in params:
+                raise ValueError("`stop` found in both the input and default params.")
+            params["stop"] = stop
         params["stream"] = True
         generator = self.client.create(prompt=prompt, **params)
 
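For context, a minimal sketch of how the patched `stream()` signature could be exercised. The model settings, prompt, and stop sequence below are illustrative assumptions, and each yielded item is assumed to be a raw OpenAI streaming chunk (as produced by `self.client.create(..., stream=True)`), not a plain string:

```python
from langchain.llms import OpenAI

# best_of defaults to 1; streaming raises a ValueError for any other value.
llm = OpenAI(temperature=0)

# Hypothetical prompt and stop sequence, forwarded to the new `stop`
# parameter. A ValueError is raised if `stop` is also present in the
# default invocation params.
generator = llm.stream("Q: What is 2 + 2?\nA:", stop=["\n"])

# Each item is assumed to be a raw OpenAI streaming chunk, so the token
# text lives under choices[0]["text"].
for chunk in generator:
    print(chunk["choices"][0]["text"], end="", flush=True)
```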