Mirror of https://github.com/hwchase17/langchain
8698cb9b28
Turns on https://docs.astral.sh/ruff/settings/#format_docstring-code-format and https://docs.astral.sh/ruff/settings/#format_skip-magic-trailing-comma:

```toml
[tool.ruff.format]
docstring-code-format = true
skip-magic-trailing-comma = true
```
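For context, `docstring-code-format = true` makes Ruff's formatter also reformat code examples embedded in docstrings, such as doctest blocks. A minimal illustration of the kind of docstring this affects (the `add` function and its doctest are hypothetical, not from this repository):

```python
def add(a: int, b: int) -> int:
    """Add two integers.

    With docstring-code-format enabled, Ruff reformats the doctest below
    (normalizing its spacing) the same way it formats regular source code:

    >>> add( 1,2 )
    3
    """
    return a + b
```

`skip-magic-trailing-comma = true` tells the formatter to ignore magic trailing commas: a trailing comma no longer forces a collection to stay one item per line, so it is collapsed whenever it fits within the line length (long literals like the `_EXPECTED_NUM_TOKENS` dict below stay multi-line only because they exceed it).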
31 lines · 840 B · Python
```python
import pytest

from langchain_openai import ChatOpenAI, OpenAI

# Expected token counts for the prompt used below, per model.
_EXPECTED_NUM_TOKENS = {
    "ada": 17,
    "babbage": 17,
    "curie": 17,
    "davinci": 17,
    "gpt-4": 12,
    "gpt-4-32k": 12,
    "gpt-3.5-turbo": 12,
}

_MODELS = ["ada", "babbage", "curie", "davinci"]
_CHAT_MODELS = ["gpt-4", "gpt-4-32k", "gpt-3.5-turbo"]


@pytest.mark.parametrize("model", _MODELS)
def test_openai_get_num_tokens(model: str) -> None:
    """Test get_num_tokens for completion models."""
    llm = OpenAI(model=model)
    # The prompt mixes CJK text, a newline, and emoji to exercise the tokenizer.
    assert llm.get_num_tokens("表情符号是\n🦜🔗") == _EXPECTED_NUM_TOKENS[model]


@pytest.mark.parametrize("model", _CHAT_MODELS)
def test_chat_openai_get_num_tokens(model: str) -> None:
    """Test get_num_tokens for chat models."""
    llm = ChatOpenAI(model=model)
    assert llm.get_num_tokens("表情符号是\n🦜🔗") == _EXPECTED_NUM_TOKENS[model]
```
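The two groups of expected counts (17 for the legacy completion models, 12 for the chat models) reflect the different tokenizers those model families use; `langchain_openai` counts tokens with `tiktoken` under the hood. A rough sketch of how the split can be reproduced; the explicit encoding names are assumptions based on the usual model-to-encoding mapping, not something stated in this file:

```python
import tiktoken

text = "表情符号是\n🦜🔗"  # "the emoji are" + parrot-and-link emoji, same prompt as the tests above

# Assumed encodings: the legacy completion models (ada, babbage, curie, davinci)
# historically use r50k_base, while gpt-4 / gpt-4-32k / gpt-3.5-turbo use cl100k_base.
legacy_encoding = tiktoken.get_encoding("r50k_base")
chat_encoding = tiktoken.get_encoding("cl100k_base")

print(len(legacy_encoding.encode(text)))  # should line up with the 17s in _EXPECTED_NUM_TOKENS
print(len(chat_encoding.encode(text)))    # should line up with the 12s in _EXPECTED_NUM_TOKENS
```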