diff --git a/docs/docs/modules/data_connection/document_transformers/code_splitter.ipynb b/docs/docs/modules/data_connection/document_transformers/code_splitter.ipynb index 1d91f2877a..1090eb17ea 100644 --- a/docs/docs/modules/data_connection/document_transformers/code_splitter.ipynb +++ b/docs/docs/modules/data_connection/document_transformers/code_splitter.ipynb @@ -22,7 +22,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 4, "id": "a9e37aa1", "metadata": {}, "outputs": [], @@ -35,7 +35,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 5, "id": "e21a2434", "metadata": {}, "outputs": [ @@ -61,10 +61,14 @@ " 'html',\n", " 'sol',\n", " 'csharp',\n", - " 'cobol']" + " 'cobol',\n", + " 'c',\n", + " 'lua',\n", + " 'perl',\n", + " 'haskell']" ] }, - "execution_count": 2, + "execution_count": 5, "metadata": {}, "output_type": "execute_result" } @@ -564,13 +568,50 @@ "c_docs" ] }, + { + "cell_type": "markdown", + "id": "af9de667-230e-4c2a-8c5f-122a28515d97", + "metadata": {}, + "source": [ + "## Haskell\n", + "Here's an example using the Haskell text splitter:" + ] + }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "id": "688185b5", "metadata": {}, - "outputs": [], - "source": [] + "outputs": [ + { + "data": { + "text/plain": [ + "[Document(page_content='main :: IO ()'),\n", + " Document(page_content='main = do\\n putStrLn \"Hello, World!\"\\n-- Some'),\n", + " Document(page_content='sample functions\\nadd :: Int -> Int -> Int\\nadd x y'),\n", + " Document(page_content='= x + y')]" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "HASKELL_CODE = \"\"\"\n", + "main :: IO ()\n", + "main = do\n", + " putStrLn \"Hello, World!\"\n", + "-- Some sample functions\n", + "add :: Int -> Int -> Int\n", + "add x y = x + y\n", + "\"\"\"\n", + "haskell_splitter = RecursiveCharacterTextSplitter.from_language(\n", + " language=Language.HASKELL, chunk_size=50, chunk_overlap=0\n", + ")\n", + "haskell_docs = haskell_splitter.create_documents([HASKELL_CODE])\n", + "haskell_docs" + ] } ], "metadata": { diff --git a/libs/text-splitters/langchain_text_splitters/base.py b/libs/text-splitters/langchain_text_splitters/base.py index 16480a16ed..b0fb33caa2 100644 --- a/libs/text-splitters/langchain_text_splitters/base.py +++ b/libs/text-splitters/langchain_text_splitters/base.py @@ -291,6 +291,7 @@ class Language(str, Enum): C = "c" LUA = "lua" PERL = "perl" + HASKELL = "haskell" @dataclass(frozen=True) diff --git a/libs/text-splitters/langchain_text_splitters/character.py b/libs/text-splitters/langchain_text_splitters/character.py index 090f6cc7f6..d01f2662e4 100644 --- a/libs/text-splitters/langchain_text_splitters/character.py +++ b/libs/text-splitters/langchain_text_splitters/character.py @@ -571,7 +571,45 @@ class RecursiveCharacterTextSplitter(TextSplitter): " ", "", ] - + elif language == Language.HASKELL: + return [ + # Split along function definitions + "\nmain :: ", + "\nmain = ", + "\nlet ", + "\nin ", + "\ndo ", + "\nwhere ", + "\n:: ", + "\n= ", + # Split along type declarations + "\ndata ", + "\nnewtype ", + "\ntype ", + "\n:: ", + # Split along module declarations + "\nmodule ", + # Split along import statements + "\nimport ", + "\nqualified ", + "\nimport qualified ", + # Split along typeclass declarations + "\nclass ", + "\ninstance ", + # Split along case expressions + "\ncase ", + # Split along guards in function definitions + "\n| ", + # Split along record field declarations + "\ndata ", + "\n= {", + "\n, ", + # Split by the normal type of lines + "\n\n", + "\n", + " ", + "", + ] else: raise ValueError( f"Language {language} is not supported! " diff --git a/libs/text-splitters/tests/unit_tests/test_text_splitters.py b/libs/text-splitters/tests/unit_tests/test_text_splitters.py index edfcd0c61a..825fc4397f 100644 --- a/libs/text-splitters/tests/unit_tests/test_text_splitters.py +++ b/libs/text-splitters/tests/unit_tests/test_text_splitters.py @@ -1248,6 +1248,38 @@ def test_solidity_code_splitter() -> None: ] +def test_haskell_code_splitter() -> None: + splitter = RecursiveCharacterTextSplitter.from_language( + Language.HASKELL, chunk_size=CHUNK_SIZE, chunk_overlap=0 + ) + code = """ + main :: IO () + main = do + putStrLn "Hello, World!" + + -- Some sample functions + add :: Int -> Int -> Int + add x y = x + y + """ + # Adjusted expected chunks to account for indentation and newlines + expected_chunks = [ + "main ::", + "IO ()", + "main = do", + "putStrLn", + '"Hello, World!"', + "--", + "Some sample", + "functions", + "add :: Int ->", + "Int -> Int", + "add x y = x", + "+ y", + ] + chunks = splitter.split_text(code) + assert chunks == expected_chunks + + @pytest.mark.requires("lxml") def test_html_header_text_splitter(tmp_path: Path) -> None: splitter = HTMLHeaderTextSplitter(