From 0a4baca291d4946860355bc772b0428db2f5eda5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fynn=20Fl=C3=BCgge?= Date: Wed, 4 Oct 2023 00:35:36 +0200 Subject: [PATCH] chore: add kotlin code splitter (#11364) - **Description:** Adds Kotlin language to `TextSplitter` --------- Co-authored-by: Eugene Yurtsev --- .../text_splitters/code_splitter.mdx | 1 + libs/langchain/langchain/text_splitter.py | 27 ++++++++++++++++ .../tests/unit_tests/test_text_splitter.py | 32 +++++++++++++++++++ 3 files changed, 60 insertions(+) diff --git a/docs/snippets/modules/data_connection/document_transformers/text_splitters/code_splitter.mdx b/docs/snippets/modules/data_connection/document_transformers/text_splitters/code_splitter.mdx index e7a2db3d06..0a6135e303 100644 --- a/docs/snippets/modules/data_connection/document_transformers/text_splitters/code_splitter.mdx +++ b/docs/snippets/modules/data_connection/document_transformers/text_splitters/code_splitter.mdx @@ -17,6 +17,7 @@ from langchain.text_splitter import ( ['cpp', 'go', 'java', + 'kotlin', 'js', 'ts', 'php', diff --git a/libs/langchain/langchain/text_splitter.py b/libs/langchain/langchain/text_splitter.py index 12664014be..5aebe6b5f3 100644 --- a/libs/langchain/langchain/text_splitter.py +++ b/libs/langchain/langchain/text_splitter.py @@ -614,6 +614,7 @@ class Language(str, Enum): CPP = "cpp" GO = "go" JAVA = "java" + KOTLIN = "kotlin" JS = "js" TS = "ts" PHP = "php" @@ -762,6 +763,32 @@ class RecursiveCharacterTextSplitter(TextSplitter): " ", "", ] + elif language == Language.KOTLIN: + return [ + # Split along class definitions + "\nclass ", + # Split along method definitions + "\npublic ", + "\nprotected ", + "\nprivate ", + "\ninternal ", + "\ncompanion ", + "\nfun ", + "\nval ", + "\nvar ", + # Split along control flow statements + "\nif ", + "\nfor ", + "\nwhile ", + "\nwhen ", + "\ncase ", + "\nelse ", + # Split by the normal type of lines + "\n\n", + "\n", + " ", + "", + ] elif language == Language.JS: return [ # Split along function definitions diff --git a/libs/langchain/tests/unit_tests/test_text_splitter.py b/libs/langchain/tests/unit_tests/test_text_splitter.py index ce4680c20e..ccddee434a 100644 --- a/libs/langchain/tests/unit_tests/test_text_splitter.py +++ b/libs/langchain/tests/unit_tests/test_text_splitter.py @@ -525,6 +525,38 @@ public class HelloWorld { ] +def test_kotlin_code_splitter() -> None: + splitter = RecursiveCharacterTextSplitter.from_language( + Language.KOTLIN, chunk_size=CHUNK_SIZE, chunk_overlap=0 + ) + code = """ +class HelloWorld { + companion object { + @JvmStatic + fun main(args: Array) { + println("Hello, World!") + } + } +} + """ + chunks = splitter.split_text(code) + assert chunks == [ + "class", + "HelloWorld {", + "companion", + "object {", + "@JvmStatic", + "fun", + "main(args:", + "Array)", + "{", + 'println("Hello,', + 'World!")', + "}\n }", + "}", + ] + + def test_csharp_code_splitter() -> None: splitter = RecursiveCharacterTextSplitter.from_language( Language.CSHARP, chunk_size=CHUNK_SIZE, chunk_overlap=0