mirror of https://github.com/hwchase17/langchain
Add COBOL parser and splitter (#11674)
- **Description:** Add COBOL parser and splitter - **Issue:** n/a - **Dependencies:** n/a - **Tag maintainer:** @baskaryan - **Twitter handle:** erhartford --------- Co-authored-by: Bagatur <baskaryan@gmail.com> Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com> pull/12173/head^2
parent
bb137fd6e7
commit
8c150ad7f6
@ -0,0 +1,96 @@
|
|||||||
|
import re
|
||||||
|
from typing import Callable, List
|
||||||
|
|
||||||
|
from langchain.document_loaders.parsers.language.code_segmenter import CodeSegmenter
|
||||||
|
|
||||||
|
|
||||||
|
class CobolSegmenter(CodeSegmenter):
    """Code segmenter for `COBOL`.

    The segmenter works line by line with regular-expression heuristics; it
    does not parse COBOL. Paragraphs and sections found after the
    ``PROCEDURE DIVISION`` header are the units returned by
    ``extract_functions_classes``, and ``simplify_code`` keeps only header
    lines, replacing everything between them with a placeholder comment.
    """

    # A paragraph header: a single name made of letters, digits and hyphens,
    # terminated by a period (optionally with trailing text before the period).
    PARAGRAPH_PATTERN = re.compile(r"^[A-Z0-9\-]+(\s+.*)?\.$", re.IGNORECASE)
    # A division header at the start of a line (leading whitespace allowed).
    DIVISION_PATTERN = re.compile(
        r"^\s*(IDENTIFICATION|DATA|PROCEDURE|ENVIRONMENT)\s+DIVISION.*$", re.IGNORECASE
    )
    # A section header: "<name> SECTION." -- the terminating period is escaped
    # so that only a literal "." is accepted after the SECTION keyword.
    SECTION_PATTERN = re.compile(r"^\s*[A-Z0-9\-]+\s+SECTION\.$", re.IGNORECASE)

    # Division headers that ``simplify_code`` keeps; the same four divisions
    # that DIVISION_PATTERN recognises.
    _DIVISION_HEADERS = (
        "PROCEDURE DIVISION",
        "DATA DIVISION",
        "IDENTIFICATION DIVISION",
        "ENVIRONMENT DIVISION",
    )

    def __init__(self, code: str):
        """Store the source and pre-split it into lines.

        Args:
            code: The COBOL source text to segment.
        """
        super().__init__(code)
        # NOTE(review): relies on the base-class initializer setting
        # ``self.code`` -- confirm against CodeSegmenter.
        self.source_lines: List[str] = self.code.splitlines()

    def is_valid(self) -> bool:
        """Return True if at least one line looks like a division header."""
        # Identify presence of any division to validate COBOL code
        return any(self.DIVISION_PATTERN.match(line) for line in self.source_lines)

    def _extract_code(self, start_idx: int, end_idx: int) -> str:
        """Return source lines ``[start_idx, end_idx)`` joined by newlines.

        Trailing newlines (produced by trailing blank lines) are removed.
        """
        return "\n".join(self.source_lines[start_idx:end_idx]).rstrip("\n")

    def _is_relevant_code(self, line: str) -> bool:
        """Check if a line is part of the procedure division or a relevant section."""
        if "PROCEDURE DIVISION" in line.upper():
            return True
        # Add additional conditions for relevant sections if needed
        return False

    def _is_division_header(self, line: str) -> bool:
        """Return True if the line mentions one of the known division headers.

        The check is a case-insensitive substring test, so it agrees with the
        case-insensitive handling in ``_is_relevant_code``.
        """
        upper_line = line.upper()
        return any(header in upper_line for header in self._DIVISION_HEADERS)

    def _is_paragraph_or_section(self, line: str) -> bool:
        """Return True if the line looks like a paragraph or section header."""
        stripped = line.strip()
        # Only the first space-separated token is tested against the paragraph
        # pattern, so a statement such as "STOP RUN." is not taken for a header.
        first_token = stripped.split(" ")[0]
        return bool(
            self.PARAGRAPH_PATTERN.match(first_token)
            or self.SECTION_PATTERN.match(stripped)
        )

    def _process_lines(self, func: Callable) -> List[str]:
        """A generic function to process COBOL lines based on provided func.

        Args:
            func: Callback invoked as ``func(elements, start_idx, end_idx)``
                for every paragraph/section found after the PROCEDURE
                DIVISION header; ``end_idx`` is exclusive.

        Returns:
            The ``elements`` list that ``func`` was given to fill.
        """
        elements: List[str] = []
        start_idx = None
        inside_relevant_section = False

        for i, line in enumerate(self.source_lines):
            if self._is_relevant_code(line):
                inside_relevant_section = True

            if inside_relevant_section and self._is_paragraph_or_section(line):
                # A new header closes the element that was open, if any.
                if start_idx is not None:
                    func(elements, start_idx, i)
                start_idx = i

        # Handle the last element if exists
        if start_idx is not None:
            func(elements, start_idx, len(self.source_lines))

        return elements

    def extract_functions_classes(self) -> List[str]:
        """Return the text of each paragraph/section in the procedure division."""

        def extract_func(elements: List[str], start_idx: int, end_idx: int) -> None:
            elements.append(self._extract_code(start_idx, end_idx))

        return self._process_lines(extract_func)

    def simplify_code(self) -> str:
        """Return the source reduced to its header lines.

        Division, section and paragraph headers are kept verbatim. Every run
        of other lines following a header is collapsed into a single
        ``* OMITTED CODE *`` line. Lines before the first header are dropped.
        """
        simplified_lines: List[str] = []
        inside_relevant_section = False
        # Tracks whether the placeholder was already emitted after the last
        # header, so that each header is followed by at most one placeholder.
        omitted_code_added = False

        for line in self.source_lines:
            is_header = self._is_division_header(
                line
            ) or self._is_paragraph_or_section(line)

            if is_header:
                # Entering a new division/section/paragraph: keep the header
                # and allow one placeholder for the code that follows it.
                inside_relevant_section = True
                omitted_code_added = False
                simplified_lines.append(line)
            elif inside_relevant_section and not omitted_code_added:
                simplified_lines.append("* OMITTED CODE *")
                omitted_code_added = True

        return "\n".join(simplified_lines)
|
@ -0,0 +1,49 @@
|
|||||||
|
from langchain.document_loaders.parsers.language.cobol import CobolSegmenter
|
||||||
|
|
||||||
|
# Minimal COBOL program shared by the tests in this module. Statement lines
# inside each paragraph carry a one-space indent because the expected values
# in test_extract_functions_classes contain "\n " before every statement.
EXAMPLE_CODE = """
IDENTIFICATION DIVISION.
PROGRAM-ID. SampleProgram.
DATA DIVISION.
WORKING-STORAGE SECTION.
01 SAMPLE-VAR PIC X(20) VALUE 'Sample Value'.

PROCEDURE DIVISION.
A000-INITIALIZE-PARA.
 DISPLAY 'Initialization Paragraph'.
 MOVE 'New Value' TO SAMPLE-VAR.

A100-PROCESS-PARA.
 DISPLAY SAMPLE-VAR.
 STOP RUN.
"""
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_functions_classes() -> None:
    """Each paragraph of the example program is extracted as its own chunk."""
    first_paragraph = (
        "A000-INITIALIZE-PARA.\n"
        " DISPLAY 'Initialization Paragraph'.\n"
        " MOVE 'New Value' TO SAMPLE-VAR."
    )
    second_paragraph = "A100-PROCESS-PARA.\n DISPLAY SAMPLE-VAR.\n STOP RUN."

    chunks = CobolSegmenter(EXAMPLE_CODE).extract_functions_classes()

    assert chunks == [first_paragraph, second_paragraph]
|
||||||
|
|
||||||
|
|
||||||
|
def test_simplify_code() -> None:
    """Simplifying keeps header lines and collapses the rest to placeholders."""
    expected_lines = [
        "IDENTIFICATION DIVISION.",
        "PROGRAM-ID. SampleProgram.",
        "DATA DIVISION.",
        "WORKING-STORAGE SECTION.",
        "* OMITTED CODE *",
        "PROCEDURE DIVISION.",
        "A000-INITIALIZE-PARA.",
        "* OMITTED CODE *",
        "A100-PROCESS-PARA.",
        "* OMITTED CODE *",
    ]

    result = CobolSegmenter(EXAMPLE_CODE).simplify_code()

    # Surrounding whitespace is ignored; only the line content is compared.
    assert result.strip() == "\n".join(expected_lines)
|
Loading…
Reference in New Issue