|
|
@ -182,7 +182,7 @@ class BaseOpenAI(BaseLLM, BaseModel):
|
|
|
|
generations=generations, llm_output={"token_usage": token_usage}
|
|
|
|
generations=generations, llm_output={"token_usage": token_usage}
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
def stream(self, prompt: str) -> Generator:
|
|
|
|
def stream(self, prompt: str, stop: Optional[List[str]] = None) -> Generator:
|
|
|
|
"""Call OpenAI with streaming flag and return the resulting generator.
|
|
|
|
"""Call OpenAI with streaming flag and return the resulting generator.
|
|
|
|
|
|
|
|
|
|
|
|
BETA: this is a beta feature while we figure out the right abstraction.
|
|
|
|
BETA: this is a beta feature while we figure out the right abstraction.
|
|
|
@ -190,6 +190,7 @@ class BaseOpenAI(BaseLLM, BaseModel):
|
|
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
Args:
|
|
|
|
prompt: The prompts to pass into the model.
|
|
|
|
prompt: The prompts to pass into the model.
|
|
|
|
|
|
|
|
stop: Optional list of stop words to use when generating.
|
|
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
Returns:
|
|
|
|
A generator representing the stream of tokens from OpenAI.
|
|
|
|
A generator representing the stream of tokens from OpenAI.
|
|
|
@ -204,6 +205,10 @@ class BaseOpenAI(BaseLLM, BaseModel):
|
|
|
|
params = self._invocation_params
|
|
|
|
params = self._invocation_params
|
|
|
|
if params["best_of"] != 1:
|
|
|
|
if params["best_of"] != 1:
|
|
|
|
raise ValueError("OpenAI only supports best_of == 1 for streaming")
|
|
|
|
raise ValueError("OpenAI only supports best_of == 1 for streaming")
|
|
|
|
|
|
|
|
if stop is not None:
|
|
|
|
|
|
|
|
if "stop" in params:
|
|
|
|
|
|
|
|
raise ValueError("`stop` found in both the input and default params.")
|
|
|
|
|
|
|
|
params["stop"] = stop
|
|
|
|
params["stream"] = True
|
|
|
|
params["stream"] = True
|
|
|
|
generator = self.client.create(prompt=prompt, **params)
|
|
|
|
generator = self.client.create(prompt=prompt, **params)
|
|
|
|
|
|
|
|
|
|
|
|