From b588446bf9be4d7aa0e4590ff3d4860538b089ad Mon Sep 17 00:00:00 2001
From: Mike Wang <62768671+skcoirz@users.noreply.github.com>
Date: Fri, 28 Apr 2023 20:42:24 -0700
Subject: [PATCH] [simple][test] Added test case for schema.py (#3692)

- Added unit tests for schema.py covering the utility functions and token
  counting.
- Fixed a nit: per the Hugging Face docs, the tokenizer model is GPT-2.
  [link](https://huggingface.co/transformers/v4.8.2/_modules/transformers/models/gpt2/tokenization_gpt2_fast.html)
- `make lint && make format` pass locally.
- Screenshot of the new test run is attached in the PR.
---
 langchain/schema.py                    | 42 +++++++-------
 tests/integration_tests/test_schema.py | 15 +++++
 tests/unit_tests/test_schema.py        | 77 ++++++++++++++++++++++++++
 3 files changed, 115 insertions(+), 19 deletions(-)
 create mode 100644 tests/integration_tests/test_schema.py
 create mode 100644 tests/unit_tests/test_schema.py

diff --git a/langchain/schema.py b/langchain/schema.py
index 821dc70a..1e4770e0 100644
--- a/langchain/schema.py
+++ b/langchain/schema.py
@@ -181,6 +181,28 @@ class PromptValue(BaseModel, ABC):
         """Return prompt as messages."""
 
 
+def _get_num_tokens_default_method(text: str) -> int:
+    """Get the number of tokens present in the text."""
+    # TODO: this method may not be exact.
+    # TODO: this method may differ based on model (eg codex).
+    try:
+        from transformers import GPT2TokenizerFast
+    except ImportError:
+        raise ValueError(
+            "Could not import transformers python package. "
+            "This is needed in order to calculate get_num_tokens. "
+            "Please install it with `pip install transformers`."
+        )
+    # create a GPT-2 tokenizer instance
+    tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
+
+    # tokenize the text using the GPT-2 tokenizer
+    tokenized_text = tokenizer.tokenize(text)
+
+    # calculate the number of tokens in the tokenized text
+    return len(tokenized_text)
+
+
 class BaseLanguageModel(BaseModel, ABC):
     @abstractmethod
     def generate_prompt(
@@ -195,25 +217,7 @@ class BaseLanguageModel(BaseModel, ABC):
         """Take in a list of prompt values and return an LLMResult."""
 
     def get_num_tokens(self, text: str) -> int:
-        """Get the number of tokens present in the text."""
-        # TODO: this method may not be exact.
-        # TODO: this method may differ based on model (eg codex).
-        try:
-            from transformers import GPT2TokenizerFast
-        except ImportError:
-            raise ValueError(
-                "Could not import transformers python package. "
-                "This is needed in order to calculate get_num_tokens. "
-                "Please install it with `pip install transformers`."
-            )
-        # create a GPT-3 tokenizer instance
-        tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
-
-        # tokenize the text using the GPT-3 tokenizer
-        tokenized_text = tokenizer.tokenize(text)
-
-        # calculate the number of tokens in the tokenized text
-        return len(tokenized_text)
+        return _get_num_tokens_default_method(text)
 
     def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
         """Get the number of tokens in the message."""
diff --git a/tests/integration_tests/test_schema.py b/tests/integration_tests/test_schema.py
new file mode 100644
index 00000000..472c27cc
--- /dev/null
+++ b/tests/integration_tests/test_schema.py
@@ -0,0 +1,15 @@
+"""Test token counting functionality."""
+
+from langchain.schema import _get_num_tokens_default_method
+
+
+class TestTokenCountingWithGPT2Tokenizer:
+    def test_empty_token(self) -> None:
+        assert _get_num_tokens_default_method("") == 0
+
+    def test_multiple_tokens(self) -> None:
+        assert _get_num_tokens_default_method("a b c") == 3
+
+    def test_special_tokens(self) -> None:
+        # guards against silent behavior changes if the default tokenizer is swapped
+        assert _get_num_tokens_default_method("a:b_c d") == 6
diff --git a/tests/unit_tests/test_schema.py b/tests/unit_tests/test_schema.py
new file mode 100644
index 00000000..ef9d9918
--- /dev/null
+++ b/tests/unit_tests/test_schema.py
@@ -0,0 +1,77 @@
+"""Test message formatting and conversion utilities."""
+
+import unittest
+
+from langchain.schema import (
+    AIMessage,
+    HumanMessage,
+    SystemMessage,
+    get_buffer_string,
+    messages_from_dict,
+    messages_to_dict,
+)
+
+
+class TestGetBufferString(unittest.TestCase):
+    human_msg: HumanMessage = HumanMessage(content="human")
+    ai_msg: AIMessage = AIMessage(content="ai")
+    sys_msg: SystemMessage = SystemMessage(content="sys")
+
+    def test_empty_input(self) -> None:
+        self.assertEqual(get_buffer_string([]), "")
+
+    def test_valid_single_message(self) -> None:
+        expected_output = f"Human: {self.human_msg.content}"
+        self.assertEqual(
+            get_buffer_string([self.human_msg]),
+            expected_output,
+        )
+
+    def test_custom_human_prefix(self) -> None:
+        prefix = "H"
+        expected_output = f"{prefix}: {self.human_msg.content}"
+        self.assertEqual(
+            get_buffer_string([self.human_msg], human_prefix=prefix),
+            expected_output,
+        )
+
+    def test_custom_ai_prefix(self) -> None:
+        prefix = "A"
+        expected_output = f"{prefix}: {self.ai_msg.content}"
+        self.assertEqual(
+            get_buffer_string([self.ai_msg], ai_prefix=prefix),
+            expected_output,
+        )
+
+    def test_multiple_msg(self) -> None:
+        msgs = [self.human_msg, self.ai_msg, self.sys_msg]
+        expected_output = "\n".join(
+            [
+                f"Human: {self.human_msg.content}",
+                f"AI: {self.ai_msg.content}",
+                f"System: {self.sys_msg.content}",
+            ]
+        )
+        self.assertEqual(
+            get_buffer_string(msgs),
+            expected_output,
+        )
+
+
+class TestMessageDictConversion(unittest.TestCase):
+    human_msg: HumanMessage = HumanMessage(
+        content="human", additional_kwargs={"key": "value"}
+    )
+    ai_msg: AIMessage = AIMessage(content="ai")
+    sys_msg: SystemMessage = SystemMessage(content="sys")
+
+    def test_multiple_msg(self) -> None:
+        msgs = [
+            self.human_msg,
+            self.ai_msg,
+            self.sys_msg,
+        ]
+        self.assertEqual(
+            messages_from_dict(messages_to_dict(msgs)),
+            msgs,
+        )
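
Note: pulling the method body out into the module-level `_get_num_tokens_default_method`
lets the tests exercise token counting without instantiating the abstract
BaseLanguageModel. A minimal usage sketch, assuming `transformers` is installed
and the helper stays importable from `langchain.schema` as in this patch (the
expected counts mirror the integration tests above):

    from langchain.schema import _get_num_tokens_default_method

    # Each call loads the "gpt2" tokenizer, tokenizes the text, and
    # returns the token count.
    print(_get_num_tokens_default_method(""))         # 0
    print(_get_num_tokens_default_method("a b c"))    # 3
    print(_get_num_tokens_default_method("a:b_c d"))  # 6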