diff --git a/docs/modules/indexes/text_splitters/examples/code_splitter.ipynb b/docs/modules/indexes/text_splitters/examples/code_splitter.ipynb new file mode 100644 index 00000000..c769dd4a --- /dev/null +++ b/docs/modules/indexes/text_splitters/examples/code_splitter.ipynb @@ -0,0 +1,158 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# CodeTextSplitter\n", + "\n", + "CodeTextSplitter allows you to split your code with multiple language support. Import enum `Language` and specify the language. " + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "from langchain.text_splitter import (\n", + " CodeTextSplitter,\n", + " Language,\n", + ")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Choose a language to use" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "python_splitter = CodeTextSplitter(\n", + " language=Language.PYTHON, chunk_size=16, chunk_overlap=0\n", + ")\n", + "js_splitter = CodeTextSplitter(\n", + " language=Language.JS, chunk_size=16, chunk_overlap=0\n", + ")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Split the code" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[Document(page_content='def', metadata={}),\n", + " Document(page_content='hello_world():', metadata={}),\n", + " Document(page_content='print(\"Hello,', metadata={}),\n", + " Document(page_content='World!\")', metadata={}),\n", + " Document(page_content='# Call the', metadata={}),\n", + " Document(page_content='function', metadata={}),\n", + " Document(page_content='hello_world()', metadata={})]" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "PYTHON_CODE = \"\"\"\n", + "def hello_world():\n", + " print(\"Hello, World!\")\n", + "\n", + "# Call the function\n", + "hello_world()\n", + "\"\"\"\n", + "\n", + "python_docs = python_splitter.create_documents([PYTHON_CODE])\n", + "python_docs" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[Document(page_content='function', metadata={}),\n", + " Document(page_content='helloWorld() {', metadata={}),\n", + " Document(page_content='console.log(\"He', metadata={}),\n", + " Document(page_content='llo,', metadata={}),\n", + " Document(page_content='World!\");', metadata={}),\n", + " Document(page_content='}', metadata={}),\n", + " Document(page_content='// Call the', metadata={}),\n", + " Document(page_content='function', metadata={}),\n", + " Document(page_content='helloWorld();', metadata={})]" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "JS_CODE = \"\"\"\n", + "function helloWorld() {\n", + " console.log(\"Hello, World!\");\n", + "}\n", + "\n", + "// Call the function\n", + "helloWorld();\n", + "\"\"\"\n", + "\n", + "js_docs = js_splitter.create_documents([JS_CODE])\n", + "js_docs" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "langchain", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.12" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/langchain/text_splitter.py b/langchain/text_splitter.py index 38429f6c..c5e1a843 100644 --- a/langchain/text_splitter.py +++ b/langchain/text_splitter.py @@ -5,6 +5,7 @@ import copy import logging import re from abc import ABC, abstractmethod +from enum import Enum from typing import ( AbstractSet, Any, @@ -475,3 +476,314 @@ class PythonCodeTextSplitter(RecursiveCharacterTextSplitter): "", ] super().__init__(separators=separators, **kwargs) + + +class Language(str, Enum): + CPP = "cpp" + GO = "go" + JAVA = "java" + JS = "js" + PHP = "php" + PROTO = "proto" + PYTHON = "python" + RST = "rst" + RUBY = "ruby" + RUST = "rust" + SCALA = "scala" + SWIFT = "swift" + MARKDOWN = "markdown" + LATEX = "latex" + + +class CodeTextSplitter(RecursiveCharacterTextSplitter): + def __init__(self, language: Language, **kwargs: Any): + """ + A generic code text splitter supporting many programming languages. + Example: + splitter = CodeTextSplitter( + language=Language.JAVA + ) + Args: + Language: The programming language to use + """ + separators = self._get_separators_for_language(language) + super().__init__(separators=separators, **kwargs) + + def _get_separators_for_language(self, language: Language) -> List[str]: + if language == Language.CPP: + return [ + # Split along class definitions + "\nclass ", + # Split along function definitions + "\nvoid ", + "\nint ", + "\nfloat ", + "\ndouble ", + # Split along control flow statements + "\nif ", + "\nfor ", + "\nwhile ", + "\nswitch ", + "\ncase ", + # Split by the normal type of lines + "\n\n", + "\n", + " ", + "", + ] + elif language == Language.GO: + return [ + # Split along function definitions + "\nfunc ", + "\nvar ", + "\nconst ", + "\ntype ", + # Split along control flow statements + "\nif ", + "\nfor ", + "\nswitch ", + "\ncase ", + # Split by the normal type of lines + "\n\n", + "\n", + " ", + "", + ] + elif language == Language.JAVA: + return [ + # Split along class definitions + "\nclass ", + # Split along method definitions + "\npublic ", + "\nprotected ", + "\nprivate ", + "\nstatic ", + # Split along control flow statements + "\nif ", + "\nfor ", + "\nwhile ", + "\nswitch ", + "\ncase ", + # Split by the normal type of lines + "\n\n", + "\n", + " ", + "", + ] + elif language == Language.JS: + return [ + # Split along function definitions + "\nfunction ", + "\nconst ", + "\nlet ", + # Split along control flow statements + "\nif ", + "\nfor ", + "\nwhile ", + "\nswitch ", + "\ncase ", + "\ndefault ", + # Split by the normal type of lines + "\n\n", + "\n", + " ", + "", + ] + elif language == Language.PHP: + return [ + # Split along function definitions + "\nfunction ", + # Split along class definitions + "\nclass ", + # Split along control flow statements + "\nif ", + "\nforeach ", + "\nwhile ", + "\ndo ", + "\nswitch ", + "\ncase ", + # Split by the normal type of lines + "\n\n", + "\n", + " ", + "", + ] + elif language == Language.PROTO: + return [ + # Split along message definitions + "\nmessage ", + # Split along service definitions + "\nservice ", + # Split along enum definitions + "\nenum ", + # Split along option definitions + "\noption ", + # Split along import statements + "\nimport ", + # Split along syntax declarations + "\nsyntax ", + # Split by the normal type of lines + "\n\n", + "\n", + " ", + "", + ] + elif language == Language.PYTHON: + return [ + # First, try to split along class definitions + "\nclass ", + "\ndef ", + "\n\tdef ", + # Now split by the normal type of lines + "\n\n", + "\n", + " ", + "", + ] + elif language == Language.RST: + return [ + # Split along section titles + "\n===\n", + "\n---\n", + "\n***\n", + # Split along directive markers + "\n.. ", + # Split by the normal type of lines + "\n\n", + "\n", + " ", + "", + ] + elif language == Language.RUBY: + return [ + # Split along method definitions + "\ndef ", + "\nclass ", + # Split along control flow statements + "\nif ", + "\nunless ", + "\nwhile ", + "\nfor ", + "\ndo ", + "\nbegin ", + "\nrescue ", + # Split by the normal type of lines + "\n\n", + "\n", + " ", + "", + ] + elif language == Language.RUST: + return [ + # Split along function definitions + "\nfn ", + "\nconst ", + "\nlet ", + # Split along control flow statements + "\nif ", + "\nwhile ", + "\nfor ", + "\nloop ", + "\nmatch ", + "\nconst ", + # Split by the normal type of lines + "\n\n", + "\n", + " ", + "", + ] + elif language == Language.SCALA: + return [ + # Split along class definitions + "\nclass ", + "\nobject ", + # Split along method definitions + "\ndef ", + "\nval ", + "\nvar ", + # Split along control flow statements + "\nif ", + "\nfor ", + "\nwhile ", + "\nmatch ", + "\ncase ", + # Split by the normal type of lines + "\n\n", + "\n", + " ", + "", + ] + elif language == Language.SWIFT: + return [ + # Split along function definitions + "\nfunc ", + # Split along class definitions + "\nclass ", + "\nstruct ", + "\nenum ", + # Split along control flow statements + "\nif ", + "\nfor ", + "\nwhile ", + "\ndo ", + "\nswitch ", + "\ncase ", + # Split by the normal type of lines + "\n\n", + "\n", + " ", + "", + ] + elif language == Language.MARKDOWN: + return [ + # First, try to split along Markdown headings (starting with level 2) + "\n## ", + "\n### ", + "\n#### ", + "\n##### ", + "\n###### ", + # Note the alternative syntax for headings (below) is not handled here + # Heading level 2 + # --------------- + # End of code block + "```\n\n", + # Horizontal lines + "\n\n***\n\n", + "\n\n---\n\n", + "\n\n___\n\n", + # Note that this splitter doesn't handle horizontal lines defined + # by *three or more* of ***, ---, or ___, but this is not handled + "\n\n", + "\n", + " ", + "", + ] + elif language == Language.LATEX: + return [ + # First, try to split along Latex sections + "\n\\chapter{", + "\n\\section{", + "\n\\subsection{", + "\n\\subsubsection{", + # Now split by environments + "\n\\begin{enumerate}", + "\n\\begin{itemize}", + "\n\\begin{description}", + "\n\\begin{list}", + "\n\\begin{quote}", + "\n\\begin{quotation}", + "\n\\begin{verse}", + "\n\\begin{verbatim}", + ## Now split by math environments + "\n\\begin{align}", + "$$", + "$", + # Now split by the normal type of lines + " ", + "", + ] + else: + raise ValueError( + f"Language {language} is not supported! " + f"Please choose from {list(Language)}" + ) diff --git a/tests/unit_tests/test_text_splitter.py b/tests/unit_tests/test_text_splitter.py index 75f243b9..89d0e08d 100644 --- a/tests/unit_tests/test_text_splitter.py +++ b/tests/unit_tests/test_text_splitter.py @@ -4,6 +4,8 @@ import pytest from langchain.docstore.document import Document from langchain.text_splitter import ( CharacterTextSplitter, + CodeTextSplitter, + Language, PythonCodeTextSplitter, RecursiveCharacterTextSplitter, ) @@ -194,3 +196,317 @@ def test_python_text_splitter() -> None: split_3 = """def bar():""" expected_splits = [split_0, split_1, split_2, split_3] assert splits == expected_splits + + +CHUNK_SIZE = 16 + + +def test_python_code_splitter() -> None: + splitter = CodeTextSplitter( + language=Language.PYTHON, chunk_size=CHUNK_SIZE, chunk_overlap=0 + ) + code = """ +def hello_world(): + print("Hello, World!") + +# Call the function +hello_world() + """ + chunks = splitter.split_text(code) + assert chunks == [ + "def", + "hello_world():", + 'print("Hello,', + 'World!")', + "# Call the", + "function", + "hello_world()", + ] + + +def test_golang_code_splitter() -> None: + splitter = CodeTextSplitter( + language=Language.GO, chunk_size=CHUNK_SIZE, chunk_overlap=0 + ) + code = """ +package main + +import "fmt" + +func helloWorld() { + fmt.Println("Hello, World!") +} + +func main() { + helloWorld() +} + """ + chunks = splitter.split_text(code) + assert chunks == [ + "package main", + 'import "fmt"', + "func", + "helloWorld() {", + 'fmt.Println("He', + "llo,", + 'World!")', + "}", + "func main() {", + "helloWorld()", + "}", + ] + + +def test_rst_code_splitter() -> None: + splitter = CodeTextSplitter( + language=Language.RST, chunk_size=CHUNK_SIZE, chunk_overlap=0 + ) + code = """ +Sample Document +=============== + +Section +------- + +This is the content of the section. + +Lists +----- + +- Item 1 +- Item 2 +- Item 3 + """ + chunks = splitter.split_text(code) + assert chunks == [ + "Sample Document", + "===============", + "Section", + "-------", + "This is the", + "content of the", + "section.", + "Lists\n-----", + "- Item 1", + "- Item 2", + "- Item 3", + ] + + +def test_proto_file_splitter() -> None: + splitter = CodeTextSplitter( + language=Language.PROTO, chunk_size=CHUNK_SIZE, chunk_overlap=0 + ) + code = """ +syntax = "proto3"; + +package example; + +message Person { + string name = 1; + int32 age = 2; + repeated string hobbies = 3; +} + """ + chunks = splitter.split_text(code) + assert chunks == [ + "syntax =", + '"proto3";', + "package", + "example;", + "message Person", + "{", + "string name", + "= 1;", + "int32 age =", + "2;", + "repeated", + "string hobbies", + "= 3;", + "}", + ] + + +def test_javascript_code_splitter() -> None: + splitter = CodeTextSplitter( + language=Language.JS, chunk_size=CHUNK_SIZE, chunk_overlap=0 + ) + code = """ +function helloWorld() { + console.log("Hello, World!"); +} + +// Call the function +helloWorld(); + """ + chunks = splitter.split_text(code) + assert chunks == [ + "function", + "helloWorld() {", + 'console.log("He', + "llo,", + 'World!");', + "}", + "// Call the", + "function", + "helloWorld();", + ] + + +def test_java_code_splitter() -> None: + splitter = CodeTextSplitter( + language=Language.JAVA, chunk_size=CHUNK_SIZE, chunk_overlap=0 + ) + code = """ +public class HelloWorld { + public static void main(String[] args) { + System.out.println("Hello, World!"); + } +} + """ + chunks = splitter.split_text(code) + assert chunks == [ + "public class", + "HelloWorld {", + "public", + "static void", + "main(String[]", + "args) {", + "System.out.prin", + 'tln("Hello,', + 'World!");', + "}\n}", + ] + + +def test_cpp_code_splitter() -> None: + splitter = CodeTextSplitter( + language=Language.CPP, chunk_size=CHUNK_SIZE, chunk_overlap=0 + ) + code = """ +#include + +int main() { + std::cout << "Hello, World!" << std::endl; + return 0; +} + """ + chunks = splitter.split_text(code) + assert chunks == [ + "#include", + "", + "int main() {", + "std::cout", + '<< "Hello,', + 'World!" <<', + "std::endl;", + "return 0;\n}", + ] + + +def test_scala_code_splitter() -> None: + splitter = CodeTextSplitter( + language=Language.SCALA, chunk_size=CHUNK_SIZE, chunk_overlap=0 + ) + code = """ +object HelloWorld { + def main(args: Array[String]): Unit = { + println("Hello, World!") + } +} + """ + chunks = splitter.split_text(code) + assert chunks == [ + "object", + "HelloWorld {", + "def", + "main(args:", + "Array[String]):", + "Unit = {", + 'println("Hello,', + 'World!")', + "}\n}", + ] + + +def test_ruby_code_splitter() -> None: + splitter = CodeTextSplitter( + language=Language.RUBY, chunk_size=CHUNK_SIZE, chunk_overlap=0 + ) + code = """ +def hello_world + puts "Hello, World!" +end + +hello_world + """ + chunks = splitter.split_text(code) + assert chunks == [ + "def hello_world", + 'puts "Hello,', + 'World!"', + "end", + "hello_world", + ] + + +def test_php_code_splitter() -> None: + splitter = CodeTextSplitter( + language=Language.PHP, chunk_size=CHUNK_SIZE, chunk_overlap=0 + ) + code = """ + + """ + chunks = splitter.split_text(code) + assert chunks == [ + "", + ] + + +def test_swift_code_splitter() -> None: + splitter = CodeTextSplitter( + language=Language.SWIFT, chunk_size=CHUNK_SIZE, chunk_overlap=0 + ) + code = """ +func helloWorld() { + print("Hello, World!") +} + +helloWorld() + """ + chunks = splitter.split_text(code) + assert chunks == [ + "func", + "helloWorld() {", + 'print("Hello,', + 'World!")', + "}", + "helloWorld()", + ] + + +def test_rust_code_splitter() -> None: + splitter = CodeTextSplitter( + language=Language.RUST, chunk_size=CHUNK_SIZE, chunk_overlap=0 + ) + code = """ +fn main() { + println!("Hello, World!"); +} + """ + chunks = splitter.split_text(code) + assert chunks == ["fn main() {", 'println!("Hello', ",", 'World!");', "}"]