fix batching (#339)

harrison/agent_multi_inputs^2
Harrison Chase 1 year ago committed by GitHub
parent 3c6796b72e
commit e26b6f9c89

@@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "markdown",
"id": "060f7bc9",
"id": "20ac6b98",
"metadata": {},
"source": [
"# LLM Functionality\n",
@@ -15,7 +15,7 @@
{
"cell_type": "code",
"execution_count": 1,
"id": "5bddaa9a",
"id": "df924055",
"metadata": {},
"outputs": [],
"source": [
@@ -25,7 +25,7 @@
{
"cell_type": "code",
"execution_count": 2,
"id": "f6bed875",
"id": "182b484c",
"metadata": {},
"outputs": [],
"source": [
@@ -34,7 +34,7 @@
},
{
"cell_type": "markdown",
"id": "edb2f14e",
"id": "9695ccfc",
"metadata": {},
"source": [
"**Generate Text:** The most basic functionality an LLM has is just the ability to call it, passing in a string and getting back a string."
@@ -43,7 +43,7 @@
{
"cell_type": "code",
"execution_count": 3,
"id": "c29ba285",
"id": "9d12ac26",
"metadata": {},
"outputs": [
{
@@ -63,7 +63,7 @@
},
{
"cell_type": "markdown",
"id": "1f4a350a",
"id": "e7d4d42d",
"metadata": {},
"source": [
"**Generate:** More broadly, you can call it with a list of inputs, getting back a more complete response than just the text. This complete response includes things like multiple top responses, as well as LLM provider specific information"
@@ -72,76 +72,93 @@
{
"cell_type": "code",
"execution_count": 4,
"id": "a586c9ca",
"id": "f4dc241a",
"metadata": {},
"outputs": [],
"source": [
"llm_result = llm.generate([\"Tell me a joke\", \"Tell me a poem\"])"
"llm_result = llm.generate([\"Tell me a joke\", \"Tell me a poem\"]*15)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "22470289",
"execution_count": 7,
"id": "740392f6",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"30"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(llm_result.generations)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "ab6cdcf1",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[Generation(text='\\n\\nWhy did the chicken cross the road?\\n\\nTo get to the other side.'),\n",
"[Generation(text='\\n\\nWhy did the chicken cross the road?\\n\\nTo get to the other side!'),\n",
" Generation(text='\\n\\nWhy did the chicken cross the road?\\n\\nTo get to the other side!')]"
]
},
"execution_count": 5,
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Results of the first input\n",
"llm_result.generations[0]"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "a1e72553",
"execution_count": 9,
"id": "4946a778",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[Generation(text='\\n\\nIn the eyes of the moon\\n\\nI have seen a face\\n\\nThat I will never forget.\\n\\nThe light that I see\\n\\nIs like a fire in my heart.\\n\\nEvery letter I write\\n\\nWill be the last word\\n\\nOf my love for this person.\\n\\nThe darkness that I feel\\n\\nIs like a weight on my heart.'),\n",
" Generation(text=\"\\n\\nA rose by the side of the road\\n\\nIs all I need to find my way\\n\\nTo the place I've been searching for\\n\\nAnd my heart is singing with joy\\n\\nWhen I look at this rose\\n\\nIt reminds me of the love I've found\\n\\nAnd I know that wherever I go\\n\\nI'll always find my rose by the side of the road.\")]"
"[Generation(text='\\n\\nHow can I make you my\\nx- Multiplayer gamemode\\n\\nHow can I make you my\\nx- Multiplayer gamemode\\n\\nHow can I make you my\\nx- Multiplayer gamemode\\n\\nHow can I make you my\\nx- Multiplayer gamemode\\n\\nHow can I make you my\\nx- Multiplayer gamemode\\n\\nHow can I make you my\\nx- Multiplayer gamemode\\n\\nHow can I make you my\\nx- Multiplayer gamemode\\n\\nHow can I make you my\\nx- Multiplayer gamemode\\n\\nHow can I make you my\\nx- Multiplayer gamemode\\n\\nHow can I make you my\\nx- Multiplayer gamemode\\n\\nHow can I make you my\\nx- Multiplayer gamemode\\n\\nHow can I make you my\\nx- Multiplayer gamemode'),\n",
" Generation(text=\"\\n\\nWhen I was younger\\nI thought that love\\nI was something like a fairytale\\nI would find my prince\\nAnd we would be together\\nForever\\nI was naïve\\nAnd I was wrong\\nLove is not a fairytale\\nIt's something else entirely\\nSomething that should be cherished\\nAnd loved\\nAnd never taken for granted\\nLove is something that you have to work for\\nIt doesn't come easy\\nYou have to sacrifice\\nYour time, your effort\\nAnd sometimes you have to give up \\nYou have to do what's best for yourself\\nAnd sometimes that means giving love up\")]"
]
},
"execution_count": 6,
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Results of the second input\n",
"llm_result.generations[1]"
"llm_result.generations[-1]"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "90c52536",
"execution_count": 8,
"id": "242e4527",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'token_usage': <OpenAIObject at 0x10b4f0d10> JSON: {\n",
" \"completion_tokens\": 199,\n",
" \"prompt_tokens\": 8,\n",
" \"total_tokens\": 207\n",
" }}"
"{'token_usage': {'completion_tokens': 3721,\n",
" 'prompt_tokens': 120,\n",
" 'total_tokens': 3841}}"
]
},
"execution_count": 7,
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
@@ -153,7 +170,7 @@
},
{
"cell_type": "markdown",
"id": "92f6e7a5",
"id": "bde8e04f",
"metadata": {},
"source": [
"**Number of Tokens:** You can also estimate how many tokens a piece of text will be in that model. This is useful because models have a context length (and cost more for more tokens), which means you need to be aware of how long the text you are passing in is.\n",
@@ -164,7 +181,7 @@
{
"cell_type": "code",
"execution_count": 8,
"id": "acfd9200",
"id": "b623c774",
"metadata": {},
"outputs": [
{
@@ -192,7 +209,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "68ff3688",
"id": "4196efd9",
"metadata": {},
"outputs": [],
"source": []

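The notebook changes above exercise the new batching path by generating over 30 prompts in a single call. A minimal sketch of that usage, assuming the `OpenAI` LLM from `langchain.llms`; the constructor arguments below are illustrative, since the notebook's setup cells are outside the visible hunks:

from langchain.llms import OpenAI

# Illustrative setup; assumes OPENAI_API_KEY is set in the environment.
# n=2 mirrors the two Generation objects per prompt shown in the outputs above.
llm = OpenAI(model_name="text-ada-001", n=2, best_of=2)

prompts = ["Tell me a joke", "Tell me a poem"] * 15  # 30 prompts in total
llm_result = llm.generate(prompts)

# generations is parallel to the inputs: one list of candidates per prompt.
assert len(llm_result.generations) == len(prompts)
print(llm_result.generations[0])   # candidates for the first prompt
print(llm_result.generations[-1])  # candidates for the last prompt

# Token usage is summed across the batched API calls.
print(llm_result.llm_output["token_usage"])
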
@@ -45,6 +45,8 @@ class OpenAI(LLM, BaseModel):
    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """Holds any model parameters valid for `create` call not explicitly specified."""
    openai_api_key: Optional[str] = None
+    batch_size: int = 20
+    """Batch size to use when passing multiple documents to generate."""

    class Config:
        """Configuration for this pydantic object."""
@@ -114,6 +116,7 @@ class OpenAI(LLM, BaseModel):
                response = openai.generate(["Tell me a joke."])
        """
+        # TODO: write a unit test for this
        params = self._default_params
        if stop is not None:
            if "stop" in params:
@@ -126,15 +129,31 @@
                    "max_tokens set to -1 not supported for multiple inputs."
                )
            params["max_tokens"] = self.max_tokens_for_prompt(prompts[0])
-        response = self.client.create(model=self.model_name, prompt=prompts, **params)
-        generations = []
-        for i, prompt in enumerate(prompts):
-            choices = response["choices"][i * self.n : (i + 1) * self.n]
-            generations.append([Generation(text=choice["text"]) for choice in choices])
+        sub_prompts = [
+            prompts[i : i + self.batch_size]
+            for i in range(0, len(prompts), self.batch_size)
+        ]
+        choices = []
+        token_usage = {}
        # Get the token usage from the response.
        # Includes prompt, completion, and total tokens used.
-        token_usage = response["usage"]
+        _keys = ["completion_tokens", "prompt_tokens", "total_tokens"]
+        for _prompts in sub_prompts:
+            response = self.client.create(
+                model=self.model_name, prompt=_prompts, **params
+            )
+            choices.extend(response["choices"])
+            for _key in _keys:
+                if _key not in token_usage:
+                    token_usage[_key] = response["usage"][_key]
+                else:
+                    token_usage[_key] += response["usage"][_key]
+        generations = []
+        for i, prompt in enumerate(prompts):
+            sub_choices = choices[i * self.n : (i + 1) * self.n]
+            generations.append(
+                [Generation(text=choice["text"]) for choice in sub_choices]
+            )
        return LLMResult(
            generations=generations, llm_output={"token_usage": token_usage}
        )
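
For reference, a self-contained sketch of the chunking-and-aggregation pattern this hunk introduces, with a stand-in for the OpenAI client so it runs offline; `fake_create` and `generate_batched` are hypothetical names used here for illustration only:

from typing import Any, Dict, List

def fake_create(prompt: List[str], n: int = 2, **kwargs: Any) -> Dict[str, Any]:
    """Stand-in for the completion endpoint: n choices per prompt, plus token usage."""
    choices = [{"text": f"completion {i} for {p!r}"} for p in prompt for i in range(n)]
    usage = {"completion_tokens": 5 * len(choices), "prompt_tokens": 4 * len(prompt)}
    usage["total_tokens"] = usage["completion_tokens"] + usage["prompt_tokens"]
    return {"choices": choices, "usage": usage}

def generate_batched(prompts: List[str], batch_size: int = 20, n: int = 2):
    """Mirror the new logic: chunk the prompts, then sum token usage across calls."""
    sub_prompts = [
        prompts[i : i + batch_size] for i in range(0, len(prompts), batch_size)
    ]
    choices: List[Dict[str, Any]] = []
    token_usage: Dict[str, int] = {}
    for _prompts in sub_prompts:
        response = fake_create(prompt=_prompts, n=n)
        choices.extend(response["choices"])
        for key in ("completion_tokens", "prompt_tokens", "total_tokens"):
            token_usage[key] = token_usage.get(key, 0) + response["usage"][key]
    # Regroup the flat choices list into n choices per original prompt.
    generations = [choices[i * n : (i + 1) * n] for i in range(len(prompts))]
    return generations, token_usage

generations, usage = generate_batched(["Tell me a joke", "Tell me a poem"] * 15)
assert len(generations) == 30 and len(generations[0]) == 2
print(usage)  # usage summed over the two batches (20 prompts + 10 prompts)

With 30 prompts and the default batch_size of 20, this makes two underlying requests instead of one oversized one, which is the behavior the notebook's aggregated token_usage output above reflects.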
