Add more code splitters (go, rst, js, java, cpp, scala, ruby, php, swift, rust) (#5171)

As the title says, I added more code splitters.
The implementation is trivial, so I did not add separate tests for each
splitter.
Let me know if you have any concerns.

Fixes # (issue)
https://github.com/hwchase17/langchain/issues/5170

## Who can review?

Community members can review the PR once tests pass. Tag
maintainers/contributors who might be interested:
@eyurtsev @hwchase17

---------

Signed-off-by: byhsu <byhsu@linkedin.com>
Co-authored-by: byhsu <byhsu@linkedin.com>
This commit is contained in:
ByronHsu 2023-05-30 08:04:05 -07:00 committed by GitHub
parent a61b7f7e7c
commit 9d658aaa5a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 786 additions and 0 deletions

View File

@ -0,0 +1,158 @@
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# CodeTextSplitter\n",
"\n",
"CodeTextSplitter splits source code using separators tailored to the chosen programming language. Import the `Language` enum and pass the language you want when constructing the splitter."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"from langchain.text_splitter import (\n",
" CodeTextSplitter,\n",
" Language,\n",
")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Choose a language to use"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"python_splitter = CodeTextSplitter(\n",
" language=Language.PYTHON, chunk_size=16, chunk_overlap=0\n",
")\n",
"js_splitter = CodeTextSplitter(\n",
" language=Language.JS, chunk_size=16, chunk_overlap=0\n",
")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Split the code"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[Document(page_content='def', metadata={}),\n",
" Document(page_content='hello_world():', metadata={}),\n",
" Document(page_content='print(\"Hello,', metadata={}),\n",
" Document(page_content='World!\")', metadata={}),\n",
" Document(page_content='# Call the', metadata={}),\n",
" Document(page_content='function', metadata={}),\n",
" Document(page_content='hello_world()', metadata={})]"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"PYTHON_CODE = \"\"\"\n",
"def hello_world():\n",
" print(\"Hello, World!\")\n",
"\n",
"# Call the function\n",
"hello_world()\n",
"\"\"\"\n",
"\n",
"python_docs = python_splitter.create_documents([PYTHON_CODE])\n",
"python_docs"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[Document(page_content='function', metadata={}),\n",
" Document(page_content='helloWorld() {', metadata={}),\n",
" Document(page_content='console.log(\"He', metadata={}),\n",
" Document(page_content='llo,', metadata={}),\n",
" Document(page_content='World!\");', metadata={}),\n",
" Document(page_content='}', metadata={}),\n",
" Document(page_content='// Call the', metadata={}),\n",
" Document(page_content='function', metadata={}),\n",
" Document(page_content='helloWorld();', metadata={})]"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"JS_CODE = \"\"\"\n",
"function helloWorld() {\n",
" console.log(\"Hello, World!\");\n",
"}\n",
"\n",
"// Call the function\n",
"helloWorld();\n",
"\"\"\"\n",
"\n",
"js_docs = js_splitter.create_documents([JS_CODE])\n",
"js_docs"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "langchain",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.12"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@ -5,6 +5,7 @@ import copy
import logging import logging
import re import re
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from enum import Enum
from typing import ( from typing import (
AbstractSet, AbstractSet,
Any, Any,
@ -475,3 +476,314 @@ class PythonCodeTextSplitter(RecursiveCharacterTextSplitter):
"", "",
] ]
super().__init__(separators=separators, **kwargs) super().__init__(separators=separators, **kwargs)
class Language(str, Enum):
    """Languages and markup formats that have a dedicated separator list.

    Members are also ``str`` instances, so each one compares equal to its
    lowercase string value (for example ``Language.PYTHON == "python"``).
    """

    # Programming languages and interface-definition formats.
    CPP = "cpp"
    GO = "go"
    JAVA = "java"
    JS = "js"
    PHP = "php"
    PROTO = "proto"
    PYTHON = "python"
    RST = "rst"
    RUBY = "ruby"
    RUST = "rust"
    SCALA = "scala"
    SWIFT = "swift"
    # Document markup formats.
    MARKDOWN = "markdown"
    LATEX = "latex"
class CodeTextSplitter(RecursiveCharacterTextSplitter):
    """Recursive character splitter preconfigured with per-language separators."""

    def __init__(self, language: Language, **kwargs: Any):
        """
        A generic code text splitter supporting many programming languages.

        Example:
            splitter = CodeTextSplitter(
                language=Language.JAVA
            )

        Args:
            language: The language whose separator list should be used.
            **kwargs: Passed through unchanged to the parent constructor.

        Raises:
            ValueError: If ``language`` has no separator list.
        """
        separators = self._get_separators_for_language(language)
        super().__init__(separators=separators, **kwargs)

    def _get_separators_for_language(self, language: Language) -> List[str]:
        """Return the ordered separator list for ``language``.

        Every list runs from the coarsest boundary (definitions) down to the
        finest one (the empty string, i.e. individual characters). A fresh
        list is built on every call, so callers may keep or mutate the result.

        Args:
            language: The language to look up.

        Returns:
            The separators for ``language``, most significant first.

        Raises:
            ValueError: If ``language`` matches none of the handled members.
        """
        if language == Language.CPP:
            return [
                # Split along class definitions
                "\nclass ",
                # Split along function definitions
                "\nvoid ",
                "\nint ",
                "\nfloat ",
                "\ndouble ",
                # Split along control flow statements
                "\nif ",
                "\nfor ",
                "\nwhile ",
                "\nswitch ",
                "\ncase ",
                # Split by the normal type of lines
                "\n\n",
                "\n",
                " ",
                "",
            ]
        elif language == Language.GO:
            return [
                # Split along function definitions
                "\nfunc ",
                "\nvar ",
                "\nconst ",
                "\ntype ",
                # Split along control flow statements
                "\nif ",
                "\nfor ",
                "\nswitch ",
                "\ncase ",
                # Split by the normal type of lines
                "\n\n",
                "\n",
                " ",
                "",
            ]
        elif language == Language.JAVA:
            return [
                # Split along class definitions
                "\nclass ",
                # Split along method definitions
                "\npublic ",
                "\nprotected ",
                "\nprivate ",
                "\nstatic ",
                # Split along control flow statements
                "\nif ",
                "\nfor ",
                "\nwhile ",
                "\nswitch ",
                "\ncase ",
                # Split by the normal type of lines
                "\n\n",
                "\n",
                " ",
                "",
            ]
        elif language == Language.JS:
            return [
                # Split along function definitions
                "\nfunction ",
                "\nconst ",
                "\nlet ",
                # Split along control flow statements
                "\nif ",
                "\nfor ",
                "\nwhile ",
                "\nswitch ",
                "\ncase ",
                "\ndefault ",
                # Split by the normal type of lines
                "\n\n",
                "\n",
                " ",
                "",
            ]
        elif language == Language.PHP:
            return [
                # Split along function definitions
                "\nfunction ",
                # Split along class definitions
                "\nclass ",
                # Split along control flow statements
                "\nif ",
                "\nforeach ",
                "\nwhile ",
                "\ndo ",
                "\nswitch ",
                "\ncase ",
                # Split by the normal type of lines
                "\n\n",
                "\n",
                " ",
                "",
            ]
        elif language == Language.PROTO:
            return [
                # Split along message definitions
                "\nmessage ",
                # Split along service definitions
                "\nservice ",
                # Split along enum definitions
                "\nenum ",
                # Split along option definitions
                "\noption ",
                # Split along import statements
                "\nimport ",
                # Split along syntax declarations
                "\nsyntax ",
                # Split by the normal type of lines
                "\n\n",
                "\n",
                " ",
                "",
            ]
        elif language == Language.PYTHON:
            return [
                # First, try to split along class definitions
                "\nclass ",
                "\ndef ",
                "\n\tdef ",
                # Now split by the normal type of lines
                "\n\n",
                "\n",
                " ",
                "",
            ]
        elif language == Language.RST:
            return [
                # Split along section titles
                "\n===\n",
                "\n---\n",
                "\n***\n",
                # Split along directive markers
                "\n.. ",
                # Split by the normal type of lines
                "\n\n",
                "\n",
                " ",
                "",
            ]
        elif language == Language.RUBY:
            return [
                # Split along method definitions
                "\ndef ",
                "\nclass ",
                # Split along control flow statements
                "\nif ",
                "\nunless ",
                "\nwhile ",
                "\nfor ",
                "\ndo ",
                "\nbegin ",
                "\nrescue ",
                # Split by the normal type of lines
                "\n\n",
                "\n",
                " ",
                "",
            ]
        elif language == Language.RUST:
            return [
                # Split along function definitions
                "\nfn ",
                "\nconst ",
                "\nlet ",
                # Split along control flow statements
                # ("\nconst " is listed once only: a repeated entry further
                # down the list could never be the first separator to match.)
                "\nif ",
                "\nwhile ",
                "\nfor ",
                "\nloop ",
                "\nmatch ",
                # Split by the normal type of lines
                "\n\n",
                "\n",
                " ",
                "",
            ]
        elif language == Language.SCALA:
            return [
                # Split along class definitions
                "\nclass ",
                "\nobject ",
                # Split along method definitions
                "\ndef ",
                "\nval ",
                "\nvar ",
                # Split along control flow statements
                "\nif ",
                "\nfor ",
                "\nwhile ",
                "\nmatch ",
                "\ncase ",
                # Split by the normal type of lines
                "\n\n",
                "\n",
                " ",
                "",
            ]
        elif language == Language.SWIFT:
            return [
                # Split along function definitions
                "\nfunc ",
                # Split along class definitions
                "\nclass ",
                "\nstruct ",
                "\nenum ",
                # Split along control flow statements
                "\nif ",
                "\nfor ",
                "\nwhile ",
                "\ndo ",
                "\nswitch ",
                "\ncase ",
                # Split by the normal type of lines
                "\n\n",
                "\n",
                " ",
                "",
            ]
        elif language == Language.MARKDOWN:
            return [
                # First, try to split along Markdown headings (starting with level 2)
                "\n## ",
                "\n### ",
                "\n#### ",
                "\n##### ",
                "\n###### ",
                # Note the alternative syntax for headings (below) is not handled here
                # Heading level 2
                # ---------------
                # End of code block
                "```\n\n",
                # Horizontal lines
                "\n\n***\n\n",
                "\n\n---\n\n",
                "\n\n___\n\n",
                # Note that horizontal lines written with more than three
                # characters (e.g. *****, ----- or _____) are not handled
                "\n\n",
                "\n",
                " ",
                "",
            ]
        elif language == Language.LATEX:
            return [
                # First, try to split along Latex sections
                "\n\\chapter{",
                "\n\\section{",
                "\n\\subsection{",
                "\n\\subsubsection{",
                # Now split by environments
                "\n\\begin{enumerate}",
                "\n\\begin{itemize}",
                "\n\\begin{description}",
                "\n\\begin{list}",
                "\n\\begin{quote}",
                "\n\\begin{quotation}",
                "\n\\begin{verse}",
                "\n\\begin{verbatim}",
                # Now split by math environments
                "\n\\begin{align}",
                "$$",
                "$",
                # Now split by the normal type of lines
                # (unlike the other languages, this list has no newline-based
                # fallback: it goes straight from math delimiters to spaces)
                " ",
                "",
            ]
        else:
            raise ValueError(
                f"Language {language} is not supported! "
                f"Please choose from {list(Language)}"
            )

View File

@ -4,6 +4,8 @@ import pytest
from langchain.docstore.document import Document from langchain.docstore.document import Document
from langchain.text_splitter import ( from langchain.text_splitter import (
CharacterTextSplitter, CharacterTextSplitter,
CodeTextSplitter,
Language,
PythonCodeTextSplitter, PythonCodeTextSplitter,
RecursiveCharacterTextSplitter, RecursiveCharacterTextSplitter,
) )
@ -194,3 +196,317 @@ def test_python_text_splitter() -> None:
split_3 = """def bar():""" split_3 = """def bar():"""
expected_splits = [split_0, split_1, split_2, split_3] expected_splits = [split_0, split_1, split_2, split_3]
assert splits == expected_splits assert splits == expected_splits
# Chunk size passed to every CodeTextSplitter test in this module; it is small
# enough that even the short sample snippets are broken into several chunks.
CHUNK_SIZE = 16
def test_python_code_splitter() -> None:
    """Check the chunks produced for a small Python snippet."""
    code_splitter = CodeTextSplitter(
        language=Language.PYTHON, chunk_size=CHUNK_SIZE, chunk_overlap=0
    )
    snippet = """
def hello_world():
print("Hello, World!")
# Call the function
hello_world()
"""
    expected_chunks = [
        "def",
        "hello_world():",
        'print("Hello,',
        'World!")',
        "# Call the",
        "function",
        "hello_world()",
    ]
    assert code_splitter.split_text(snippet) == expected_chunks
def test_golang_code_splitter() -> None:
    """Check the chunks produced for a small Go program."""
    code_splitter = CodeTextSplitter(
        language=Language.GO, chunk_size=CHUNK_SIZE, chunk_overlap=0
    )
    snippet = """
package main
import "fmt"
func helloWorld() {
fmt.Println("Hello, World!")
}
func main() {
helloWorld()
}
"""
    expected_chunks = [
        "package main",
        'import "fmt"',
        "func",
        "helloWorld() {",
        'fmt.Println("He',
        "llo,",
        'World!")',
        "}",
        "func main() {",
        "helloWorld()",
        "}",
    ]
    assert code_splitter.split_text(snippet) == expected_chunks
def test_rst_code_splitter() -> None:
    """Check the chunks produced for a small reStructuredText document."""
    code_splitter = CodeTextSplitter(
        language=Language.RST, chunk_size=CHUNK_SIZE, chunk_overlap=0
    )
    snippet = """
Sample Document
===============
Section
-------
This is the content of the section.
Lists
-----
- Item 1
- Item 2
- Item 3
"""
    expected_chunks = [
        "Sample Document",
        "===============",
        "Section",
        "-------",
        "This is the",
        "content of the",
        "section.",
        "Lists\n-----",
        "- Item 1",
        "- Item 2",
        "- Item 3",
    ]
    assert code_splitter.split_text(snippet) == expected_chunks
def test_proto_file_splitter() -> None:
    """Check the chunks produced for a small Protocol Buffers file."""
    code_splitter = CodeTextSplitter(
        language=Language.PROTO, chunk_size=CHUNK_SIZE, chunk_overlap=0
    )
    snippet = """
syntax = "proto3";
package example;
message Person {
string name = 1;
int32 age = 2;
repeated string hobbies = 3;
}
"""
    expected_chunks = [
        "syntax =",
        '"proto3";',
        "package",
        "example;",
        "message Person",
        "{",
        "string name",
        "= 1;",
        "int32 age =",
        "2;",
        "repeated",
        "string hobbies",
        "= 3;",
        "}",
    ]
    assert code_splitter.split_text(snippet) == expected_chunks
def test_javascript_code_splitter() -> None:
    """Check the chunks produced for a small JavaScript snippet."""
    code_splitter = CodeTextSplitter(
        language=Language.JS, chunk_size=CHUNK_SIZE, chunk_overlap=0
    )
    snippet = """
function helloWorld() {
console.log("Hello, World!");
}
// Call the function
helloWorld();
"""
    expected_chunks = [
        "function",
        "helloWorld() {",
        'console.log("He',
        "llo,",
        'World!");',
        "}",
        "// Call the",
        "function",
        "helloWorld();",
    ]
    assert code_splitter.split_text(snippet) == expected_chunks
def test_java_code_splitter() -> None:
    """Check the chunks produced for a small Java class."""
    code_splitter = CodeTextSplitter(
        language=Language.JAVA, chunk_size=CHUNK_SIZE, chunk_overlap=0
    )
    snippet = """
public class HelloWorld {
public static void main(String[] args) {
System.out.println("Hello, World!");
}
}
"""
    expected_chunks = [
        "public class",
        "HelloWorld {",
        "public",
        "static void",
        "main(String[]",
        "args) {",
        "System.out.prin",
        'tln("Hello,',
        'World!");',
        "}\n}",
    ]
    assert code_splitter.split_text(snippet) == expected_chunks
def test_cpp_code_splitter() -> None:
    """Check the chunks produced for a small C++ program."""
    code_splitter = CodeTextSplitter(
        language=Language.CPP, chunk_size=CHUNK_SIZE, chunk_overlap=0
    )
    snippet = """
#include <iostream>
int main() {
std::cout << "Hello, World!" << std::endl;
return 0;
}
"""
    expected_chunks = [
        "#include",
        "<iostream>",
        "int main() {",
        "std::cout",
        '<< "Hello,',
        'World!" <<',
        "std::endl;",
        "return 0;\n}",
    ]
    assert code_splitter.split_text(snippet) == expected_chunks
def test_scala_code_splitter() -> None:
    """Check the chunks produced for a small Scala object."""
    code_splitter = CodeTextSplitter(
        language=Language.SCALA, chunk_size=CHUNK_SIZE, chunk_overlap=0
    )
    snippet = """
object HelloWorld {
def main(args: Array[String]): Unit = {
println("Hello, World!")
}
}
"""
    expected_chunks = [
        "object",
        "HelloWorld {",
        "def",
        "main(args:",
        "Array[String]):",
        "Unit = {",
        'println("Hello,',
        'World!")',
        "}\n}",
    ]
    assert code_splitter.split_text(snippet) == expected_chunks
def test_ruby_code_splitter() -> None:
    """Check the chunks produced for a small Ruby snippet."""
    code_splitter = CodeTextSplitter(
        language=Language.RUBY, chunk_size=CHUNK_SIZE, chunk_overlap=0
    )
    snippet = """
def hello_world
puts "Hello, World!"
end
hello_world
"""
    expected_chunks = [
        "def hello_world",
        'puts "Hello,',
        'World!"',
        "end",
        "hello_world",
    ]
    assert code_splitter.split_text(snippet) == expected_chunks
def test_php_code_splitter() -> None:
    """Check the chunks produced for a small PHP script."""
    code_splitter = CodeTextSplitter(
        language=Language.PHP, chunk_size=CHUNK_SIZE, chunk_overlap=0
    )
    snippet = """
<?php
function hello_world() {
echo "Hello, World!";
}
hello_world();
?>
"""
    expected_chunks = [
        "<?php",
        "function",
        "hello_world() {",
        "echo",
        '"Hello,',
        'World!";',
        "}",
        "hello_world();",
        "?>",
    ]
    assert code_splitter.split_text(snippet) == expected_chunks
def test_swift_code_splitter() -> None:
    """Check the chunks produced for a small Swift snippet."""
    code_splitter = CodeTextSplitter(
        language=Language.SWIFT, chunk_size=CHUNK_SIZE, chunk_overlap=0
    )
    snippet = """
func helloWorld() {
print("Hello, World!")
}
helloWorld()
"""
    expected_chunks = [
        "func",
        "helloWorld() {",
        'print("Hello,',
        'World!")',
        "}",
        "helloWorld()",
    ]
    assert code_splitter.split_text(snippet) == expected_chunks
def test_rust_code_splitter() -> None:
    """Check the chunks produced for a small Rust program."""
    code_splitter = CodeTextSplitter(
        language=Language.RUST, chunk_size=CHUNK_SIZE, chunk_overlap=0
    )
    snippet = """
fn main() {
println!("Hello, World!");
}
"""
    expected_chunks = [
        "fn main() {",
        'println!("Hello',
        ",",
        'World!");',
        "}",
    ]
    assert code_splitter.split_text(snippet) == expected_chunks