Add finish reason to Generation for usage downstream (#526)

Add a `generation_info` field to `Generation` and extend
`BaseOpenAI._generate` to populate it with each choice's `finish_reason`
(and `logprobs`). This is useful in downstream tasks, for example when we
need to keep only the generations that finished because of `"stop"`.
Maybe we should add this to `LLMChain` as well?

For more details, see
https://beta.openai.com/docs/guides/completion/best-practices

Signed-off-by: Diwank Singh Tomer <diwank.singh@gmail.com>
This commit is contained in:
Diwank Singh Tomer 2023-01-06 20:45:25 +05:30 committed by GitHub
parent e64ed7b975
commit ba0cbb4a41
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 15 additions and 2 deletions

View File

@ -164,7 +164,16 @@ class BaseOpenAI(BaseLLM, BaseModel):
for i, prompt in enumerate(prompts):
sub_choices = choices[i * self.n : (i + 1) * self.n]
generations.append(
[Generation(text=choice["text"]) for choice in sub_choices]
[
Generation(
text=choice["text"],
generation_info=dict(
finish_reason=choice["finish_reason"],
logprobs=choice["logprobs"],
),
)
for choice in sub_choices
]
)
return LLMResult(
generations=generations, llm_output={"token_usage": token_usage}

View File

@ -1,6 +1,6 @@
"""Common schema objects."""
from typing import List, NamedTuple, Optional
from typing import Any, Dict, List, NamedTuple, Optional
class AgentAction(NamedTuple):
@ -23,6 +23,10 @@ class Generation(NamedTuple):
text: str
"""Generated text output."""
generation_info: Optional[Dict[str, Any]] = None
"""Raw generation info response from the provider"""
"""May include things like reason for finishing (e.g. in OpenAI)"""
# TODO: add log probs