diff --git a/langchain/callbacks/openai_info.py b/langchain/callbacks/openai_info.py
index 9b0af085..e8d2f3b6 100644
--- a/langchain/callbacks/openai_info.py
+++ b/langchain/callbacks/openai_info.py
@@ -26,26 +26,34 @@ MODEL_COST_PER_1K_TOKENS = {
     "code-davinci-002": 0.02,
     "ada-finetuned": 0.0016,
     "babbage-finetuned": 0.0024,
-    "curie-finetuned": 0.0120,
-    "davinci-finetuned": 0.1200,
+    "curie-finetuned": 0.012,
+    "davinci-finetuned": 0.12,
 }


+def standardize_model_name(
+    model_name: str,
+    is_completion: bool = False,
+) -> str:
+    model_name = model_name.lower()
+    if "ft-" in model_name:
+        return model_name.split(":")[0] + "-finetuned"
+    elif is_completion and model_name.startswith("gpt-4"):
+        return model_name + "-completion"
+    else:
+        return model_name
+
+
 def get_openai_token_cost_for_model(
     model_name: str, num_tokens: int, is_completion: bool = False
 ) -> float:
-    # handling finetuned models
-    if "ft-" in model_name:
-        model_name = f"{model_name.split(':')[0]}-finetuned"
-
-    suffix = "-completion" if is_completion and model_name.startswith("gpt-4") else ""
-    model = model_name.lower() + suffix
-    if model not in MODEL_COST_PER_1K_TOKENS:
+    model_name = standardize_model_name(model_name, is_completion=is_completion)
+    if model_name not in MODEL_COST_PER_1K_TOKENS:
         raise ValueError(
             f"Unknown model: {model_name}. Please provide a valid OpenAI model name."
             "Known models are: " + ", ".join(MODEL_COST_PER_1K_TOKENS.keys())
         )
-    return MODEL_COST_PER_1K_TOKENS[model] * num_tokens / 1000
+    return MODEL_COST_PER_1K_TOKENS[model_name] * num_tokens / 1000


 class OpenAICallbackHandler(BaseCallbackHandler):
@@ -91,8 +99,8 @@ class OpenAICallbackHandler(BaseCallbackHandler):
         token_usage = response.llm_output["token_usage"]
         completion_tokens = token_usage.get("completion_tokens", 0)
         prompt_tokens = token_usage.get("prompt_tokens", 0)
-        model_name = response.llm_output.get("model_name")
-        if model_name and model_name in MODEL_COST_PER_1K_TOKENS:
+        model_name = standardize_model_name(response.llm_output.get("model_name", ""))
+        if model_name in MODEL_COST_PER_1K_TOKENS:
             completion_cost = get_openai_token_cost_for_model(
                 model_name, completion_tokens, is_completion=True
             )
diff --git a/tests/unit_tests/callbacks/test_openai_info.py b/tests/unit_tests/callbacks/test_openai_info.py
index c7348821..19b55424 100644
--- a/tests/unit_tests/callbacks/test_openai_info.py
+++ b/tests/unit_tests/callbacks/test_openai_info.py
@@ -44,3 +44,19 @@ def test_on_llm_end_custom_model(handler: OpenAICallbackHandler) -> None:
     )
     handler.on_llm_end(response)
     assert handler.total_cost == 0
+
+
+def test_on_llm_end_finetuned_model(handler: OpenAICallbackHandler) -> None:
+    response = LLMResult(
+        generations=[],
+        llm_output={
+            "token_usage": {
+                "prompt_tokens": 2,
+                "completion_tokens": 1,
+                "total_tokens": 3,
+            },
+            "model_name": "ada:ft-your-org:custom-model-name-2022-02-15-04-21-04",
+        },
+    )
+    handler.on_llm_end(response)
+    assert handler.total_cost > 0
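
Reviewer note (illustrative, not part of the patch): the refactor extracts the
name-normalization logic into standardize_model_name so that on_llm_end can
reuse it for fine-tuned models instead of duplicating the "ft-" handling. A
minimal sketch of the expected behavior, assuming the patched
langchain.callbacks.openai_info is importable; the fine-tuned model string
mirrors the one used in the new unit test:

    from langchain.callbacks.openai_info import (
        get_openai_token_cost_for_model,
        standardize_model_name,
    )

    # Fine-tuned names like "<base>:ft-..." collapse to "<base>-finetuned"
    # before the cost lookup.
    name = standardize_model_name(
        "ada:ft-your-org:custom-model-name-2022-02-15-04-21-04"
    )
    assert name == "ada-finetuned"

    # GPT-4 completion tokens get a "-completion" suffix so completion and
    # prompt tokens can be priced separately.
    assert standardize_model_name("gpt-4", is_completion=True) == "gpt-4-completion"

    # Costs are quoted per 1K tokens: 1000 "ada-finetuned" tokens cost $0.0016.
    cost = get_openai_token_cost_for_model("ada-finetuned", 1000)
    assert abs(cost - 0.0016) < 1e-9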