diff --git a/langchain/llms/openai.py b/langchain/llms/openai.py
index a13beb73..ff99ee48 100644
--- a/langchain/llms/openai.py
+++ b/langchain/llms/openai.py
@@ -476,13 +476,6 @@ class BaseOpenAI(BaseLLM):
     def modelname_to_contextsize(self, modelname: str) -> int:
         """Calculate the maximum number of tokens possible to generate for a model.
 
-        text-davinci-003: 4,097 tokens
-        text-curie-001: 2,048 tokens
-        text-babbage-001: 2,048 tokens
-        text-ada-001: 2,048 tokens
-        code-davinci-002: 8,000 tokens
-        code-cushman-001: 2,048 tokens
-
         Args:
             modelname: The modelname we want to know the context size for.
 
@@ -494,20 +487,37 @@ class BaseOpenAI(BaseLLM):
 
                 max_tokens = openai.modelname_to_contextsize("text-davinci-003")
         """
-        if modelname == "text-davinci-003":
-            return 4097
-        elif modelname == "text-curie-001":
-            return 2048
-        elif modelname == "text-babbage-001":
-            return 2048
-        elif modelname == "text-ada-001":
-            return 2048
-        elif modelname == "code-davinci-002":
-            return 8000
-        elif modelname == "code-cushman-001":
-            return 2048
-        else:
-            return 4097
+        # Context sizes per OpenAI's model documentation. The base GPT-3
+        # models (ada/babbage/curie/davinci) all have 2,049-token contexts.
+        model_token_mapping = {
+            "gpt-4": 8192,
+            "gpt-4-0314": 8192,
+            "gpt-4-32k": 32768,
+            "gpt-4-32k-0314": 32768,
+            "gpt-3.5-turbo": 4096,
+            "gpt-3.5-turbo-0301": 4096,
+            "text-ada-001": 2049,
+            "ada": 2049,
+            "text-babbage-001": 2049,
+            "babbage": 2049,
+            "text-curie-001": 2049,
+            "curie": 2049,
+            "davinci": 2049,
+            "text-davinci-003": 4097,
+            "text-davinci-002": 4097,
+            "code-davinci-002": 8001,
+            "code-davinci-001": 8001,
+            "code-cushman-002": 2048,
+            "code-cushman-001": 2048,
+        }
+
+        context_size = model_token_mapping.get(modelname, None)
+
+        if context_size is None:
+            # Unknown model: fail loudly instead of silently assuming a default.
+            raise ValueError(
+                f"Unknown model: {modelname}. Please provide a valid OpenAI model name. "
+                "Known models are: " + ", ".join(model_token_mapping.keys())
+            )
+
+        return context_size
 
     def max_tokens_for_prompt(self, prompt: str) -> int:
         """Calculate the maximum number of tokens possible to generate for a prompt.
diff --git a/tests/integration_tests/llms/test_openai.py b/tests/integration_tests/llms/test_openai.py
index 1ada0ca6..9db120a5 100644
--- a/tests/integration_tests/llms/test_openai.py
+++ b/tests/integration_tests/llms/test_openai.py
@@ -211,3 +211,14 @@ async def test_openai_chat_async_streaming_callback() -> None:
     result = await llm.agenerate(["Write me a sentence with 100 words."])
     assert callback_handler.llm_streams != 0
     assert isinstance(result, LLMResult)
+
+
+def test_openai_modelname_to_contextsize_valid() -> None:
+    """Test model name to context size on a valid model."""
+    assert OpenAI().modelname_to_contextsize("davinci") == 2049
+
+
+def test_openai_modelname_to_contextsize_invalid() -> None:
+    """Test model name to context size on an invalid model."""
+    with pytest.raises(ValueError):
+        OpenAI().modelname_to_contextsize("foobar")