RFC: more complete return (#313)

Co-authored-by: Andrew Williamson <awilliamson10@indstate.edu>
Co-authored-by: awilliamson10 <aw.williamson10@gmail.com>
harrison/agent_multi_inputs^2
Harrison Chase committed 2 years ago (via GitHub)
parent 482611f426
commit 595cc1ae1a

langchain/llms/base.py
@@ -1,11 +1,39 @@
 """Base interface for large language models to expose."""
 from abc import ABC, abstractmethod
-from typing import Any, List, Mapping, Optional
+from typing import Any, List, Mapping, NamedTuple, Optional
+
+
+class Generation(NamedTuple):
+    """Output of a single generation."""
+
+    text: str
+    """Generated text output."""
+    # TODO: add log probs
+
+
+class LLMResult(NamedTuple):
+    """Class that contains all relevant information for an LLM Result."""
+
+    generations: List[List[Generation]]
+    """List of the things generated. This is List[List[]] because
+    each input could have multiple generations."""
+    llm_output: Optional[dict] = None
+    """For arbitrary LLM provider specific output."""
 
 
 class LLM(ABC):
     """LLM wrapper should take in a prompt and return a string."""
 
+    def generate(
+        self, prompts: List[str], stop: Optional[List[str]] = None
+    ) -> LLMResult:
+        """Run the LLM on the given prompt and input."""
+        generations = []
+        for prompt in prompts:
+            text = self(prompt, stop=stop)
+            generations.append([Generation(text=text)])
+        return LLMResult(generations=generations)
+
     def get_num_tokens(self, text: str) -> int:
         """Get the number of tokens present in the text."""
         # TODO: this method may not be exact.
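For orientation, a minimal sketch (not part of the diff) of how the nested return type introduced above fits together; the joke text and token counts below are made-up placeholders:

.. code-block:: python

    from langchain.llms.base import Generation, LLMResult

    # One inner list per input prompt; each inner list holds that prompt's
    # generations (the base ``generate`` produces exactly one per prompt,
    # but a provider may return several).
    result = LLMResult(
        generations=[
            [Generation(text="Why did the chicken cross the road? ...")],
            [Generation(text="Roses are red ...")],
        ],
        llm_output={"token_usage": {"total_tokens": 42}},  # placeholder numbers
    )

    first_text = result.generations[0][0].text  # text of prompt 0, choice 0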

langchain/llms/openai.py
@@ -3,7 +3,7 @@ from typing import Any, Dict, List, Mapping, Optional
 
 from pydantic import BaseModel, Extra, Field, root_validator
 
-from langchain.llms.base import LLM
+from langchain.llms.base import LLM, Generation, LLMResult
 from langchain.utils import get_from_dict_or_env
@@ -97,6 +97,48 @@ class OpenAI(LLM, BaseModel):
         }
         return {**normal_params, **self.model_kwargs}
 
+    def generate(
+        self, prompts: List[str], stop: Optional[List[str]] = None
+    ) -> LLMResult:
+        """Call out to OpenAI's endpoint with k unique prompts.
+
+        Args:
+            prompts: The prompts to pass into the model.
+            stop: Optional list of stop words to use when generating.
+
+        Returns:
+            The full LLM output.
+
+        Example:
+            .. code-block:: python
+
+                response = openai.generate(["Tell me a joke."])
+        """
+        params = self._default_params
+        if stop is not None:
+            if "stop" in params:
+                raise ValueError("`stop` found in both the input and default params.")
+            params["stop"] = stop
+        if params["max_tokens"] == -1:
+            if len(prompts) != 1:
+                raise ValueError(
+                    "max_tokens set to -1 not supported for multiple inputs."
+                )
+            params["max_tokens"] = self.max_tokens_for_prompt(prompts[0])
+        response = self.client.create(model=self.model_name, prompt=prompts, **params)
+        generations = []
+        for i, prompt in enumerate(prompts):
+            choices = response["choices"][i * self.n : (i + 1) * self.n]
+            generations.append([Generation(text=choice["text"]) for choice in choices])
+        # Get the token usage from the response.
+        # Includes prompt, completion, and total tokens used.
+        token_usage = response["usage"]
+        return LLMResult(
+            generations=generations, llm_output={"token_usage": token_usage}
+        )
+
     @property
     def _identifying_params(self) -> Mapping[str, Any]:
         """Get the identifying parameters."""
@@ -117,17 +159,7 @@ class OpenAI(LLM, BaseModel):
 
                 response = openai("Tell me a joke.")
         """
-        params = self._default_params
-        if params["max_tokens"] == -1:
-            params["max_tokens"] = self.max_tokens_for_prompt(prompt)
-        if stop is not None:
-            if "stop" in params:
-                raise ValueError("`stop` found in both the input and default params.")
-            params["stop"] = stop
-        response = self.client.create(model=self.model_name, prompt=prompt, **params)
-        return response["choices"][0]["text"]
+        return self.generate([prompt], stop=stop).generations[0][0].text
 
     def modelname_to_contextsize(self, modelname: str) -> int:
         """Calculate the maximum number of tokens possible to generate for a model.
