Add COBOL parser and splitter (#11674)

- **Description:** Add COBOL parser and splitter
  - **Issue:** n/a
  - **Dependencies:** n/a
  - **Tag maintainer:** @baskaryan 
  - **Twitter handle:** erhartford

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
pull/12173/head^2
Eric Hartford 9 months ago committed by GitHub
parent bb137fd6e7
commit 8c150ad7f6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -0,0 +1,96 @@
import re
from typing import Callable, List
from langchain.document_loaders.parsers.language.code_segmenter import CodeSegmenter
class CobolSegmenter(CodeSegmenter):
"""Code segmenter for `COBOL`."""
PARAGRAPH_PATTERN = re.compile(r"^[A-Z0-9\-]+(\s+.*)?\.$", re.IGNORECASE)
DIVISION_PATTERN = re.compile(
r"^\s*(IDENTIFICATION|DATA|PROCEDURE|ENVIRONMENT)\s+DIVISION.*$", re.IGNORECASE
)
SECTION_PATTERN = re.compile(r"^\s*[A-Z0-9\-]+\s+SECTION.$", re.IGNORECASE)
def __init__(self, code: str):
super().__init__(code)
self.source_lines: List[str] = self.code.splitlines()
def is_valid(self) -> bool:
# Identify presence of any division to validate COBOL code
return any(self.DIVISION_PATTERN.match(line) for line in self.source_lines)
def _extract_code(self, start_idx: int, end_idx: int) -> str:
return "\n".join(self.source_lines[start_idx:end_idx]).rstrip("\n")
def _is_relevant_code(self, line: str) -> bool:
"""Check if a line is part of the procedure division or a relevant section."""
if "PROCEDURE DIVISION" in line.upper():
return True
# Add additional conditions for relevant sections if needed
return False
def _process_lines(self, func: Callable) -> List[str]:
"""A generic function to process COBOL lines based on provided func."""
elements: List[str] = []
start_idx = None
inside_relevant_section = False
for i, line in enumerate(self.source_lines):
if self._is_relevant_code(line):
inside_relevant_section = True
if inside_relevant_section and (
self.PARAGRAPH_PATTERN.match(line.strip().split(" ")[0])
or self.SECTION_PATTERN.match(line.strip())
):
if start_idx is not None:
func(elements, start_idx, i)
start_idx = i
# Handle the last element if exists
if start_idx is not None:
func(elements, start_idx, len(self.source_lines))
return elements
def extract_functions_classes(self) -> List[str]:
def extract_func(elements: List[str], start_idx: int, end_idx: int) -> None:
elements.append(self._extract_code(start_idx, end_idx))
return self._process_lines(extract_func)
def simplify_code(self) -> str:
simplified_lines: List[str] = []
inside_relevant_section = False
omitted_code_added = (
False # To track if "* OMITTED CODE *" has been added after the last header
)
for line in self.source_lines:
is_header = (
"PROCEDURE DIVISION" in line
or "DATA DIVISION" in line
or "IDENTIFICATION DIVISION" in line
or self.PARAGRAPH_PATTERN.match(line.strip().split(" ")[0])
or self.SECTION_PATTERN.match(line.strip())
)
if is_header:
inside_relevant_section = True
# Reset the flag since we're entering a new section/division or
# paragraph
omitted_code_added = False
if inside_relevant_section:
if is_header:
# Add header and reset the omitted code added flag
simplified_lines.append(line)
elif not omitted_code_added:
# Add omitted code comment only if it hasn't been added directly
# after the last header
simplified_lines.append("* OMITTED CODE *")
omitted_code_added = True
return "\n".join(simplified_lines)

@ -3,6 +3,7 @@ from typing import Any, Dict, Iterator, Optional
from langchain.docstore.document import Document
from langchain.document_loaders.base import BaseBlobParser
from langchain.document_loaders.blob_loaders import Blob
from langchain.document_loaders.parsers.language.cobol import CobolSegmenter
from langchain.document_loaders.parsers.language.javascript import JavaScriptSegmenter
from langchain.document_loaders.parsers.language.python import PythonSegmenter
from langchain.text_splitter import Language
@ -10,11 +11,13 @@ from langchain.text_splitter import Language
LANGUAGE_EXTENSIONS: Dict[str, str] = {
"py": Language.PYTHON,
"js": Language.JS,
"cobol": Language.COBOL,
}
LANGUAGE_SEGMENTERS: Dict[str, Any] = {
Language.PYTHON: PythonSegmenter,
Language.JS: JavaScriptSegmenter,
Language.COBOL: CobolSegmenter,
}

@ -811,6 +811,7 @@ class Language(str, Enum):
HTML = "html"
SOL = "sol"
CSHARP = "csharp"
COBOL = "cobol"
class RecursiveCharacterTextSplitter(TextSplitter):
@ -1305,6 +1306,38 @@ class RecursiveCharacterTextSplitter(TextSplitter):
" ",
"",
]
elif language == Language.COBOL:
return [
# Split along divisions
"\nIDENTIFICATION DIVISION.",
"\nENVIRONMENT DIVISION.",
"\nDATA DIVISION.",
"\nPROCEDURE DIVISION.",
# Split along sections within DATA DIVISION
"\nWORKING-STORAGE SECTION.",
"\nLINKAGE SECTION.",
"\nFILE SECTION.",
# Split along sections within PROCEDURE DIVISION
"\nINPUT-OUTPUT SECTION.",
# Split along paragraphs and common statements
"\nOPEN ",
"\nCLOSE ",
"\nREAD ",
"\nWRITE ",
"\nIF ",
"\nELSE ",
"\nMOVE ",
"\nPERFORM ",
"\nUNTIL ",
"\nVARYING ",
"\nACCEPT ",
"\nDISPLAY ",
"\nSTOP RUN.",
# Split by the normal type of lines
"\n",
" ",
"",
]
else:
raise ValueError(

@ -0,0 +1,49 @@
from langchain.document_loaders.parsers.language.cobol import CobolSegmenter
EXAMPLE_CODE = """
IDENTIFICATION DIVISION.
PROGRAM-ID. SampleProgram.
DATA DIVISION.
WORKING-STORAGE SECTION.
01 SAMPLE-VAR PIC X(20) VALUE 'Sample Value'.
PROCEDURE DIVISION.
A000-INITIALIZE-PARA.
DISPLAY 'Initialization Paragraph'.
MOVE 'New Value' TO SAMPLE-VAR.
A100-PROCESS-PARA.
DISPLAY SAMPLE-VAR.
STOP RUN.
"""
def test_extract_functions_classes() -> None:
"""Test that functions and classes are extracted correctly."""
segmenter = CobolSegmenter(EXAMPLE_CODE)
extracted_code = segmenter.extract_functions_classes()
assert extracted_code == [
"A000-INITIALIZE-PARA.\n "
"DISPLAY 'Initialization Paragraph'.\n "
"MOVE 'New Value' TO SAMPLE-VAR.",
"A100-PROCESS-PARA.\n DISPLAY SAMPLE-VAR.\n STOP RUN.",
]
def test_simplify_code() -> None:
"""Test that code is simplified correctly."""
expected_simplified_code = (
"IDENTIFICATION DIVISION.\n"
"PROGRAM-ID. SampleProgram.\n"
"DATA DIVISION.\n"
"WORKING-STORAGE SECTION.\n"
"* OMITTED CODE *\n"
"PROCEDURE DIVISION.\n"
"A000-INITIALIZE-PARA.\n"
"* OMITTED CODE *\n"
"A100-PROCESS-PARA.\n"
"* OMITTED CODE *\n"
)
segmenter = CobolSegmenter(EXAMPLE_CODE)
simplified_code = segmenter.simplify_code()
assert simplified_code.strip() == expected_simplified_code.strip()

@ -472,6 +472,41 @@ helloWorld();
]
def test_cobol_code_splitter() -> None:
splitter = RecursiveCharacterTextSplitter.from_language(
Language.COBOL, chunk_size=CHUNK_SIZE, chunk_overlap=0
)
code = """
IDENTIFICATION DIVISION.
PROGRAM-ID. HelloWorld.
DATA DIVISION.
WORKING-STORAGE SECTION.
01 GREETING PIC X(12) VALUE 'Hello, World!'.
PROCEDURE DIVISION.
DISPLAY GREETING.
STOP RUN.
"""
chunks = splitter.split_text(code)
assert chunks == [
"IDENTIFICATION",
"DIVISION.",
"PROGRAM-ID.",
"HelloWorld.",
"DATA DIVISION.",
"WORKING-STORAGE",
"SECTION.",
"01 GREETING",
"PIC X(12)",
"VALUE 'Hello,",
"World!'.",
"PROCEDURE",
"DIVISION.",
"DISPLAY",
"GREETING.",
"STOP RUN.",
]
def test_typescript_code_splitter() -> None:
splitter = RecursiveCharacterTextSplitter.from_language(
Language.TS, chunk_size=CHUNK_SIZE, chunk_overlap=0

Loading…
Cancel
Save