feat (documents): add a source code loader based on AST manipulation (#6486)

#### Summary

A new approach to loading source code is implemented:

Each top-level function and class in the code is loaded into separate
documents. Then, an additional document is created with the top-level
code, but without the already loaded functions and classes.

This could improve the accuracy of QA chains over source code.

For instance, having this script:

```
class MyClass:
    def __init__(self, name):
        self.name = name

    def greet(self):
        print(f"Hello, {self.name}!")

def main():
    name = input("Enter your name: ")
    obj = MyClass(name)
    obj.greet()

if __name__ == '__main__':
    main()
```

The loader will create three documents with this content:

First document:
```
class MyClass:
    def __init__(self, name):
        self.name = name

    def greet(self):
        print(f"Hello, {self.name}!")
```

Second document:
```
def main():
    name = input("Enter your name: ")
    obj = MyClass(name)
    obj.greet()
```

Third document:
```
# Code for: class MyClass:

# Code for: def main():

if __name__ == '__main__':
    main()
```

A threshold parameter is added to control whether small scripts are
split in this way or not.

At this moment, only Python and JavaScript are supported. The
appropriate parser is determined by examining the file extension.

#### Tests

This PR adds:

- Unit tests
- Integration tests

#### Dependencies

Only one dependency was added as optional (needed for the JavaScript
parser).

#### Documentation

A notebook is added showing how the loader can be used.

#### Who can review?

@eyurtsev @hwchase17

---------

Co-authored-by: rlm <pexpresss31@gmail.com>
pull/6853/head
Cristóbal Carnero Liñán 1 year ago committed by GitHub
parent da462d9dd4
commit e494b0a09f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -0,0 +1,17 @@
class MyClass {
constructor(name) {
this.name = name;
}
greet() {
console.log(`Hello, ${this.name}!`);
}
}
function main() {
const name = prompt("Enter your name:");
const obj = new MyClass(name);
obj.greet();
}
main();

@ -0,0 +1,16 @@
class MyClass:
def __init__(self, name):
self.name = name
def greet(self):
print(f"Hello, {self.name}!")
def main():
name = input("Enter your name: ")
obj = MyClass(name)
obj.greet()
if __name__ == "__main__":
main()

@ -0,0 +1,419 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "213a38a2",
"metadata": {},
"source": [
"# Source Code\n",
"\n",
"This notebook covers how to load source code files using a special approach with language parsing: each top-level function and class in the code is loaded into separate documents. Any remaining code top-level code outside the already loaded functions and classes will be loaded into a seperate document.\n",
"\n",
"This approach can potentially improve the accuracy of QA models over source code. Currently, the supported languages for code parsing are Python and JavaScript. The language used for parsing can be configured, along with the minimum number of lines required to activate the splitting based on syntax."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7fa47b2e",
"metadata": {},
"outputs": [],
"source": [
"! pip install esprima"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "beb55c2f",
"metadata": {},
"outputs": [],
"source": [
"import warnings\n",
"warnings.filterwarnings('ignore')\n",
"from pprint import pprint\n",
"from langchain.text_splitter import Language\n",
"from langchain.document_loaders.generic import GenericLoader\n",
"from langchain.document_loaders.parsers import LanguageParser"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "64056e07",
"metadata": {},
"outputs": [],
"source": [
"loader = GenericLoader.from_filesystem(\n",
" \"./example_data/source_code\",\n",
" glob=\"*\",\n",
" suffixes=[\".py\", \".js\"],\n",
" parser=LanguageParser()\n",
")\n",
"docs = loader.load()"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "8af79bd7",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"6"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(docs)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "85edf3fc",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'content_type': 'functions_classes',\n",
" 'language': <Language.PYTHON: 'python'>,\n",
" 'source': 'example_data/source_code/example.py'}\n",
"{'content_type': 'functions_classes',\n",
" 'language': <Language.PYTHON: 'python'>,\n",
" 'source': 'example_data/source_code/example.py'}\n",
"{'content_type': 'simplified_code',\n",
" 'language': <Language.PYTHON: 'python'>,\n",
" 'source': 'example_data/source_code/example.py'}\n",
"{'content_type': 'functions_classes',\n",
" 'language': <Language.JS: 'js'>,\n",
" 'source': 'example_data/source_code/example.js'}\n",
"{'content_type': 'functions_classes',\n",
" 'language': <Language.JS: 'js'>,\n",
" 'source': 'example_data/source_code/example.js'}\n",
"{'content_type': 'simplified_code',\n",
" 'language': <Language.JS: 'js'>,\n",
" 'source': 'example_data/source_code/example.js'}\n"
]
}
],
"source": [
"for document in docs:\n",
" pprint(document.metadata)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "f44e3e37",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"class MyClass:\n",
" def __init__(self, name):\n",
" self.name = name\n",
"\n",
" def greet(self):\n",
" print(f\"Hello, {self.name}!\")\n",
"\n",
"--8<--\n",
"\n",
"def main():\n",
" name = input(\"Enter your name: \")\n",
" obj = MyClass(name)\n",
" obj.greet()\n",
"\n",
"--8<--\n",
"\n",
"# Code for: class MyClass:\n",
"\n",
"\n",
"# Code for: def main():\n",
"\n",
"\n",
"if __name__ == \"__main__\":\n",
" main()\n",
"\n",
"--8<--\n",
"\n",
"class MyClass {\n",
" constructor(name) {\n",
" this.name = name;\n",
" }\n",
"\n",
" greet() {\n",
" console.log(`Hello, ${this.name}!`);\n",
" }\n",
"}\n",
"\n",
"--8<--\n",
"\n",
"function main() {\n",
" const name = prompt(\"Enter your name:\");\n",
" const obj = new MyClass(name);\n",
" obj.greet();\n",
"}\n",
"\n",
"--8<--\n",
"\n",
"// Code for: class MyClass {\n",
"\n",
"// Code for: function main() {\n",
"\n",
"main();\n"
]
}
],
"source": [
"print(\"\\n\\n--8<--\\n\\n\".join([document.page_content for document in docs]))"
]
},
{
"cell_type": "markdown",
"id": "69aad0ed",
"metadata": {},
"source": [
"The parser can be disabled for small files. \n",
"\n",
"The parameter `parser_threshold` indicates the minimum number of lines that the source code file must have to be segmented using the parser."
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "ae024794",
"metadata": {},
"outputs": [],
"source": [
"loader = GenericLoader.from_filesystem(\n",
" \"./example_data/source_code\",\n",
" glob=\"*\",\n",
" suffixes=[\".py\"],\n",
" parser=LanguageParser(language=Language.PYTHON, parser_threshold=1000)\n",
")\n",
"docs = loader.load()"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "5d3b372a",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"1"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(docs)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "89e546ad",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"class MyClass:\n",
" def __init__(self, name):\n",
" self.name = name\n",
"\n",
" def greet(self):\n",
" print(f\"Hello, {self.name}!\")\n",
"\n",
"\n",
"def main():\n",
" name = input(\"Enter your name: \")\n",
" obj = MyClass(name)\n",
" obj.greet()\n",
"\n",
"\n",
"if __name__ == \"__main__\":\n",
" main()\n",
"\n"
]
}
],
"source": [
"print(docs[0].page_content)"
]
},
{
"cell_type": "markdown",
"id": "c9c71e61",
"metadata": {},
"source": [
"## Splitting\n",
"\n",
"Additional splitting could be needed for those functions, classes, or scripts that are too big."
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "adbaa79f",
"metadata": {},
"outputs": [],
"source": [
"loader = GenericLoader.from_filesystem(\n",
" \"./example_data/source_code\",\n",
" glob=\"*\",\n",
" suffixes=[\".js\"],\n",
" parser=LanguageParser(language=Language.JS)\n",
")\n",
"docs = loader.load()"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "c44c0d3f",
"metadata": {},
"outputs": [],
"source": [
"from langchain.text_splitter import (\n",
" RecursiveCharacterTextSplitter,\n",
" Language,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "b1e0053d",
"metadata": {},
"outputs": [],
"source": [
"js_splitter = RecursiveCharacterTextSplitter.from_language(\n",
" language=Language.JS, chunk_size=60, chunk_overlap=0\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "7dbe6188",
"metadata": {},
"outputs": [],
"source": [
"result = js_splitter.split_documents(docs)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "8a80d089",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"7"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(result)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "000a6011",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"class MyClass {\n",
" constructor(name) {\n",
" this.name = name;\n",
"\n",
"--8<--\n",
"\n",
"}\n",
"\n",
"--8<--\n",
"\n",
"greet() {\n",
" console.log(`Hello, ${this.name}!`);\n",
" }\n",
"}\n",
"\n",
"--8<--\n",
"\n",
"function main() {\n",
" const name = prompt(\"Enter your name:\");\n",
"\n",
"--8<--\n",
"\n",
"const obj = new MyClass(name);\n",
" obj.greet();\n",
"}\n",
"\n",
"--8<--\n",
"\n",
"// Code for: class MyClass {\n",
"\n",
"// Code for: function main() {\n",
"\n",
"--8<--\n",
"\n",
"main();\n"
]
}
],
"source": [
"print(\"\\n\\n--8<--\\n\\n\".join([document.page_content for document in result]))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

@ -1,5 +1,6 @@
from langchain.document_loaders.parsers.audio import OpenAIWhisperParser
from langchain.document_loaders.parsers.html import BS4HTMLParser
from langchain.document_loaders.parsers.language import LanguageParser
from langchain.document_loaders.parsers.pdf import (
PDFMinerParser,
PDFPlumberParser,
@ -10,6 +11,7 @@ from langchain.document_loaders.parsers.pdf import (
__all__ = [
"BS4HTMLParser",
"LanguageParser",
"OpenAIWhisperParser",
"PDFMinerParser",
"PDFPlumberParser",

@ -0,0 +1,3 @@
from langchain.document_loaders.parsers.language.language_parser import LanguageParser
__all__ = ["LanguageParser"]

@ -0,0 +1,18 @@
from abc import ABC, abstractmethod
from typing import List
class CodeSegmenter(ABC):
    """Abstract interface for language-specific source code segmenters.

    A segmenter wraps one piece of source text and exposes two views of it:
    the extracted function/class snippets and a simplified version of the
    code. Concrete subclasses implement both views.
    """

    def __init__(self, code: str) -> None:
        """Remember the source text to be segmented.

        Args:
            code: Source code handled by this segmenter.
        """
        self.code = code

    def is_valid(self) -> bool:
        """Report whether the stored code can be segmented.

        The base implementation accepts everything; subclasses may override it.

        Returns:
            Always ``True``.
        """
        return True

    @abstractmethod
    def simplify_code(self) -> str:
        """Return the simplified form of the stored code."""
        raise NotImplementedError()  # pragma: no cover

    @abstractmethod
    def extract_functions_classes(self) -> List[str]:
        """Return the function/class snippets found in the stored code."""
        raise NotImplementedError()  # pragma: no cover

@ -0,0 +1,65 @@
from typing import Any, List
from langchain.document_loaders.parsers.language.code_segmenter import CodeSegmenter
class JavaScriptSegmenter(CodeSegmenter):
    """Segment JavaScript code into top-level functions/classes and the rest.

    Parsing is delegated to the optional ``esprima`` package, which is
    imported lazily so that this module can be imported without it.
    """

    def __init__(self, code: str):
        """Store the source and check that ``esprima`` can be imported.

        Args:
            code: JavaScript source text to segment.

        Raises:
            ImportError: If the ``esprima`` package cannot be imported.
        """
        super().__init__(code)
        self.source_lines = self._split_lines(self.code)

        try:
            import esprima  # noqa: F401
        except ImportError as e:
            raise ImportError(
                "Could not import esprima Python package. "
                "Please install it with `pip install esprima`."
            ) from e

    @staticmethod
    def _split_lines(code: str) -> List[str]:
        """Split ``code`` on ECMAScript line terminators only.

        ``str.splitlines()`` also breaks on characters such as form feed and
        vertical tab, which JavaScript treats as plain whitespace; using it
        would shift the lines relative to the parser's ``loc`` line numbers.
        """
        normalized = code.replace("\r\n", "\n")
        for terminator in ("\r", "\u2028", "\u2029"):
            normalized = normalized.replace(terminator, "\n")
        lines = normalized.split("\n")
        # Like splitlines(), do not report an empty line after a final terminator.
        if lines[-1] == "":
            lines.pop()
        return lines

    def _top_level_declarations(self) -> List[Any]:
        """Return the top-level function and class declaration nodes, in order."""
        import esprima

        tree = esprima.parseScript(self.code, loc=True)
        declaration_types = (
            esprima.nodes.FunctionDeclaration,
            esprima.nodes.ClassDeclaration,
        )
        return [node for node in tree.body if isinstance(node, declaration_types)]

    def is_valid(self) -> bool:
        """Return ``True`` if ``esprima`` parses the code without an error."""
        import esprima

        try:
            esprima.parseScript(self.code)
            return True
        except esprima.Error:
            return False

    def _extract_code(self, node: Any) -> str:
        """Return the source lines spanned by ``node``, joined with newlines."""
        # loc lines are 1-based and inclusive; slicing is 0-based and exclusive.
        start = node.loc.start.line - 1
        end = node.loc.end.line
        return "\n".join(self.source_lines[start:end])

    def extract_functions_classes(self) -> List[str]:
        """Return the source of every top-level function and class declaration.

        Returns:
            One string per declaration, in the order they appear in the code.
        """
        return [self._extract_code(node) for node in self._top_level_declarations()]

    def simplify_code(self) -> str:
        """Return the code with each top-level function/class collapsed.

        The first line of every declaration is turned into a
        ``// Code for: <line>`` comment and its remaining lines are dropped;
        all other lines are kept unchanged.
        """
        simplified_lines = self.source_lines[:]

        for node in self._top_level_declarations():
            start = node.loc.start.line - 1
            simplified_lines[start] = f"// Code for: {simplified_lines[start]}"

            # None marks a line for removal without shifting later indices.
            for line_num in range(start + 1, node.loc.end.line):
                simplified_lines[line_num] = None  # type: ignore

        return "\n".join(line for line in simplified_lines if line is not None)

@ -0,0 +1,143 @@
from typing import Any, Dict, Iterator, Optional
from langchain.docstore.document import Document
from langchain.document_loaders.base import BaseBlobParser
from langchain.document_loaders.blob_loaders import Blob
from langchain.document_loaders.parsers.language.javascript import JavaScriptSegmenter
from langchain.document_loaders.parsers.language.python import PythonSegmenter
from langchain.text_splitter import Language
# File extension (without the leading dot) -> language. LanguageParser looks
# the blob's extension up here when no language was configured explicitly.
LANGUAGE_EXTENSIONS: Dict[str, str] = {
    "py": Language.PYTHON,
    "js": Language.JS,
}

# Language -> segmenter class that LanguageParser instantiates with the code.
LANGUAGE_SEGMENTERS: Dict[str, Any] = {
    Language.PYTHON: PythonSegmenter,
    Language.JS: JavaScriptSegmenter,
}
class LanguageParser(BaseBlobParser):
    """
    Language parser that splits code using the respective language syntax.

    Each top-level function and class in the code is loaded into separate documents.
    Furthermore, an extra document is generated, containing the remaining top-level code
    that excludes the already segmented functions and classes.

    This approach can potentially improve the accuracy of QA models over source code.

    Currently, the supported languages for code parsing are Python and JavaScript.

    The language used for parsing can be configured, along with the minimum number of
    lines required to activate the splitting based on syntax.

    Examples:

        .. code-block:: python

            from langchain.text_splitter import Language
            from langchain.document_loaders.generic import GenericLoader
            from langchain.document_loaders.parsers import LanguageParser

            loader = GenericLoader.from_filesystem(
                "./code",
                glob="**/*",
                suffixes=[".py", ".js"],
                parser=LanguageParser()
            )
            docs = loader.load()

        Example instantiations to manually select the language:

        .. code-block:: python

            from langchain.text_splitter import Language

            loader = GenericLoader.from_filesystem(
                "./code",
                glob="**/*",
                suffixes=[".py"],
                parser=LanguageParser(language=Language.PYTHON)
            )

        Example instantiations to set number of lines threshold:

        .. code-block:: python

            loader = GenericLoader.from_filesystem(
                "./code",
                glob="**/*",
                suffixes=[".py"],
                parser=LanguageParser(parser_threshold=200)
            )
    """

    def __init__(self, language: Optional[Language] = None, parser_threshold: int = 0):
        """
        Language parser that splits code using the respective language syntax.

        Args:
            language: If None (default), it will try to infer language from source.
            parser_threshold: Line-count threshold (0 by default). Code whose
                number of lines is less than or equal to this value is
                returned as a single, unsegmented document.
        """
        self.language = language
        self.parser_threshold = parser_threshold

    def lazy_parse(self, blob: Blob) -> Iterator[Document]:
        """Lazily split a blob of source code into documents.

        Yields a single document with the whole code when the language is
        unknown, when the code does not exceed ``parser_threshold`` lines, or
        when the segmenter reports the code as invalid. Otherwise yields one
        ``functions_classes`` document per extracted snippet followed by one
        ``simplified_code`` document.

        Args:
            blob: Blob holding the source code; ``blob.source`` is used to
                infer the language from the file extension when it is a string.

        Yields:
            Documents whose metadata always contains ``source`` and, depending
            on the branch taken, ``language`` and ``content_type``.

        Raises:
            KeyError: If the configured language has no registered segmenter.
        """
        code = blob.as_string()

        language = self.language or None
        if language is None and isinstance(blob.source, str):
            # Only text after a real "." counts as an extension, so a source
            # named exactly "py" or "js" is not mistaken for one.
            _, dot, extension = blob.source.rpartition(".")
            if dot:
                language = LANGUAGE_EXTENSIONS.get(extension)

        if language is None:
            yield Document(
                page_content=code,
                metadata={
                    "source": blob.source,
                },
            )
            return

        if self.parser_threshold >= len(code.splitlines()):
            yield Document(
                page_content=code,
                metadata={
                    "source": blob.source,
                    "language": language,
                },
            )
            return

        # NOTE(review): the segmenter class is still stored on the instance as
        # before, in case external code reads this attribute -- confirm.
        self.Segmenter = LANGUAGE_SEGMENTERS[language]
        # Reuse the text read above instead of reading the blob a second time.
        segmenter = self.Segmenter(code)
        if not segmenter.is_valid():
            yield Document(
                page_content=code,
                metadata={
                    "source": blob.source,
                },
            )
            return

        for functions_classes in segmenter.extract_functions_classes():
            yield Document(
                page_content=functions_classes,
                metadata={
                    "source": blob.source,
                    "content_type": "functions_classes",
                    "language": language,
                },
            )
        yield Document(
            page_content=segmenter.simplify_code(),
            metadata={
                "source": blob.source,
                "content_type": "simplified_code",
                "language": language,
            },
        )

@ -0,0 +1,47 @@
import ast
from typing import Any, List
from langchain.document_loaders.parsers.language.code_segmenter import CodeSegmenter
class PythonSegmenter(CodeSegmenter):
    """Segment Python code into top-level functions/classes and the rest.

    The code is parsed with the standard ``ast`` module and the line numbers
    of the top-level definitions are mapped back onto the source lines.
    """

    def __init__(self, code: str):
        """Store the source and pre-split it into physical lines.

        Args:
            code: Python source text to segment.
        """
        super().__init__(code)
        self.source_lines = self._split_lines(self.code)

    @staticmethod
    def _split_lines(code: str) -> List[str]:
        """Split ``code`` on the line endings Python itself recognises.

        ``str.splitlines()`` also breaks on form feed, vertical tab and other
        separators that the Python tokenizer does not treat as end of line;
        using it would shift the lines relative to ``ast`` line numbers.
        """
        lines = code.replace("\r\n", "\n").replace("\r", "\n").split("\n")
        # Like splitlines(), do not report an empty line after a final newline.
        if lines[-1] == "":
            lines.pop()
        return lines

    def _top_level_definitions(self) -> List[Any]:
        """Return the top-level function and class nodes, in source order."""
        tree = ast.parse(self.code)
        definition_types = (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)
        return [
            node
            for node in ast.iter_child_nodes(tree)
            if isinstance(node, definition_types)
        ]

    def is_valid(self) -> bool:
        """Return ``True`` if the code parses without a ``SyntaxError``."""
        try:
            ast.parse(self.code)
            return True
        except SyntaxError:
            return False

    def _extract_code(self, node: Any) -> str:
        """Return the source lines spanned by ``node``, joined with newlines."""
        # lineno/end_lineno are 1-based and inclusive; slicing is 0-based.
        start = node.lineno - 1
        end = node.end_lineno
        return "\n".join(self.source_lines[start:end])

    def extract_functions_classes(self) -> List[str]:
        """Return the source of every top-level function and class.

        Returns:
            One string per definition, in the order they appear in the code.
        """
        return [self._extract_code(node) for node in self._top_level_definitions()]

    def simplify_code(self) -> str:
        """Return the code with each top-level function/class collapsed.

        The line holding each definition's ``lineno`` is turned into a
        ``# Code for: <line>`` comment and the following lines up to
        ``end_lineno`` are dropped; all other lines are kept unchanged.
        """
        simplified_lines = self.source_lines[:]

        for node in self._top_level_definitions():
            start = node.lineno - 1
            simplified_lines[start] = f"# Code for: {simplified_lines[start]}"

            assert isinstance(node.end_lineno, int)
            # None marks a line for removal without shifting later indices.
            for line_num in range(start + 1, node.end_lineno):
                simplified_lines[line_num] = None  # type: ignore

        return "\n".join(line for line in simplified_lines if line is not None)

586
poetry.lock generated

File diff suppressed because it is too large Load Diff

@ -109,6 +109,7 @@ nebula3-python = {version = "^3.4.0", optional = true}
langchainplus-sdk = ">=0.0.17"
awadb = {version = "^0.3.3", optional = true}
azure-search-documents = {version = "11.4.0a20230509004", source = "azure-sdk-dev", optional = true}
esprima = {version = "^4.0.1", optional = true}
openllm = {version = ">=0.1.6", optional = true}
streamlit = {version = "^1.18.0", optional = true, python = ">=3.8.1,<3.9.7 || >3.9.7,<4.0"}
@ -222,6 +223,7 @@ clarifai = ["clarifai"]
cohere = ["cohere"]
docarray = ["docarray"]
embeddings = ["sentence-transformers"]
javascript = ["esprima"]
azure = [
"azure-identity",
"azure-cosmos",
@ -303,6 +305,7 @@ all = [
"tigrisdb",
"nebula3-python",
"awadb",
"esprima",
]
# An extra used to be able to add extended testing.
@ -312,6 +315,7 @@ extended_testing = [
"beautifulsoup4",
"bibtexparser",
"chardet",
"esprima",
"jq",
"pdfminer.six",
"pgvector",
@ -354,7 +358,7 @@ exclude = [
[tool.mypy]
ignore_missing_imports = "True"
disallow_untyped_defs = "True"
exclude = ["notebooks"]
exclude = ["notebooks", "examples", "example_data"]
[tool.coverage.run]
omit = [

@ -0,0 +1,133 @@
from pathlib import Path
import pytest
from langchain.document_loaders.generic import GenericLoader
from langchain.document_loaders.parsers import LanguageParser
from langchain.text_splitter import Language
def test_language_loader_for_python() -> None:
    """Test Python loader with parser enabled."""
    # Fixture directory, resolved relative to this test module.
    file_path = Path(__file__).parent.parent.parent / "examples"
    loader = GenericLoader.from_filesystem(
        file_path, glob="hello_world.py", parser=LanguageParser(parser_threshold=5)
    )
    docs = loader.load()

    # One document for the extracted function, one for the simplified remainder.
    assert len(docs) == 2

    metadata = docs[0].metadata
    assert metadata["source"] == str(file_path / "hello_world.py")
    assert metadata["content_type"] == "functions_classes"
    assert metadata["language"] == "python"
    metadata = docs[1].metadata
    assert metadata["source"] == str(file_path / "hello_world.py")
    assert metadata["content_type"] == "simplified_code"
    assert metadata["language"] == "python"

    # NOTE(review): indentation and blank lines inside the expected literals
    # below look collapsed in this view -- confirm against the checked-in file.
    assert (
        docs[0].page_content
        == """def main():
print("Hello World!")
return 0"""
    )
    assert (
        docs[1].page_content
        == """#!/usr/bin/env python3
import sys
# Code for: def main():
if __name__ == "__main__":
sys.exit(main())"""
    )
def test_language_loader_for_python_with_parser_threshold() -> None:
    """Check that a Python file under the line threshold yields one document."""
    examples_dir = Path(__file__).parent.parent.parent / "examples"
    # A threshold of 1000 lines keeps the segmenter from being used here.
    parser = LanguageParser(language=Language.PYTHON, parser_threshold=1000)
    loader = GenericLoader.from_filesystem(
        examples_dir,
        glob="hello_world.py",
        parser=parser,
    )

    loaded = loader.load()

    assert len(loaded) == 1
def esprima_installed() -> bool:
    """Tell whether the optional ``esprima`` package can be imported.

    Returns:
        ``True`` when the import succeeds; otherwise prints a notice that
        includes the raised exception and returns ``False``.
    """
    try:
        import esprima  # noqa: F401
    except Exception as exc:
        print(f"esprima not installed, skipping test {exc}")
        return False
    return True
@pytest.mark.skipif(not esprima_installed(), reason="requires esprima package")
def test_language_loader_for_javascript() -> None:
    """Test JavaScript loader with parser enabled."""
    # Fixture directory, resolved relative to this test module.
    file_path = Path(__file__).parent.parent.parent / "examples"
    loader = GenericLoader.from_filesystem(
        file_path, glob="hello_world.js", parser=LanguageParser(parser_threshold=5)
    )
    docs = loader.load()

    # Class document, function document, then the simplified remainder.
    assert len(docs) == 3

    metadata = docs[0].metadata
    assert metadata["source"] == str(file_path / "hello_world.js")
    assert metadata["content_type"] == "functions_classes"
    assert metadata["language"] == "js"
    metadata = docs[1].metadata
    assert metadata["source"] == str(file_path / "hello_world.js")
    assert metadata["content_type"] == "functions_classes"
    assert metadata["language"] == "js"
    metadata = docs[2].metadata
    assert metadata["source"] == str(file_path / "hello_world.js")
    assert metadata["content_type"] == "simplified_code"
    assert metadata["language"] == "js"

    # NOTE(review): indentation and blank lines inside the expected literals
    # below look collapsed in this view -- confirm against the checked-in file.
    assert (
        docs[0].page_content
        == """class HelloWorld {
sayHello() {
console.log("Hello World!");
}
}"""
    )
    assert (
        docs[1].page_content
        == """function main() {
const hello = new HelloWorld();
hello.sayHello();
}"""
    )
    assert (
        docs[2].page_content
        == """// Code for: class HelloWorld {
// Code for: function main() {
main();"""
    )
def test_language_loader_for_javascript_with_parser_threshold() -> None:
    """Check that a JavaScript file under the line threshold yields one document."""
    examples_dir = Path(__file__).parent.parent.parent / "examples"
    # A threshold of 1000 lines keeps the segmenter from being used here.
    parser = LanguageParser(language=Language.JS, parser_threshold=1000)
    loader = GenericLoader.from_filesystem(
        examples_dir,
        glob="hello_world.js",
        parser=parser,
    )

    loaded = loader.load()

    assert len(loaded) == 1

@ -0,0 +1,12 @@
class HelloWorld {
sayHello() {
console.log("Hello World!");
}
}
function main() {
const hello = new HelloWorld();
hello.sayHello();
}
main();

@ -0,0 +1,13 @@
#!/usr/bin/env python3
import sys
def main():
print("Hello World!")
return 0
if __name__ == "__main__":
sys.exit(main())

@ -0,0 +1,46 @@
import unittest
import pytest
from langchain.document_loaders.parsers.language.javascript import JavaScriptSegmenter
@pytest.mark.requires("esprima")
class TestJavaScriptSegmenter(unittest.TestCase):
    """Unit tests for JavaScriptSegmenter extraction and simplification."""

    def setUp(self) -> None:
        # NOTE(review): indentation and blank lines inside the literals below
        # look collapsed in this view -- confirm against the checked-in file.
        # Input: one top-level function and one class surrounded by other code.
        self.example_code = """const os = require('os');
function hello(text) {
console.log(text);
}
class Simple {
constructor() {
this.a = 1;
}
}
hello("Hello!");"""

        # Expected simplify_code() output: declarations replaced by markers.
        self.expected_simplified_code = """const os = require('os');
// Code for: function hello(text) {
// Code for: class Simple {
hello("Hello!");"""

        # Expected extract_functions_classes() output, in source order.
        self.expected_extracted_code = [
            "function hello(text) {\n console.log(text);\n}",
            "class Simple {\n constructor() {\n this.a = 1;\n }\n}",
        ]

    def test_extract_functions_classes(self) -> None:
        """The function and the class are extracted as separate snippets."""
        segmenter = JavaScriptSegmenter(self.example_code)
        extracted_code = segmenter.extract_functions_classes()
        self.assertEqual(extracted_code, self.expected_extracted_code)

    def test_simplify_code(self) -> None:
        """Declarations collapse to ``// Code for:`` marker lines."""
        segmenter = JavaScriptSegmenter(self.example_code)
        simplified_code = segmenter.simplify_code()
        self.assertEqual(simplified_code, self.expected_simplified_code)

@ -0,0 +1,40 @@
import unittest
from langchain.document_loaders.parsers.language.python import PythonSegmenter
class TestPythonSegmenter(unittest.TestCase):
    """Unit tests for PythonSegmenter extraction and simplification."""

    def setUp(self) -> None:
        # NOTE(review): indentation and blank lines inside the literals below
        # look collapsed in this view -- confirm against the checked-in file.
        # Input: one top-level function and one class surrounded by other code.
        self.example_code = """import os
def hello(text):
print(text)
class Simple:
def __init__(self):
self.a = 1
hello("Hello!")"""

        # Expected simplify_code() output: definitions replaced by markers.
        self.expected_simplified_code = """import os
# Code for: def hello(text):
# Code for: class Simple:
hello("Hello!")"""

        # Expected extract_functions_classes() output, in source order.
        self.expected_extracted_code = [
            "def hello(text):\n" " print(text)",
            "class Simple:\n" " def __init__(self):\n" " self.a = 1",
        ]

    def test_extract_functions_classes(self) -> None:
        """The function and the class are extracted as separate snippets."""
        segmenter = PythonSegmenter(self.example_code)
        extracted_code = segmenter.extract_functions_classes()
        self.assertEqual(extracted_code, self.expected_extracted_code)

    def test_simplify_code(self) -> None:
        """Definitions collapse to ``# Code for:`` marker lines."""
        segmenter = PythonSegmenter(self.example_code)
        simplified_code = segmenter.simplify_code()
        self.assertEqual(simplified_code, self.expected_simplified_code)

@ -5,6 +5,7 @@ def test_parsers_public_api_correct() -> None:
"""Test public API of parsers for breaking changes."""
assert set(__all__) == {
"BS4HTMLParser",
"LanguageParser",
"OpenAIWhisperParser",
"PyPDFParser",
"PDFMinerParser",

Loading…
Cancel
Save