diff --git a/libs/langchain/langchain/document_loaders/parsers/language/cobol.py b/libs/langchain/langchain/document_loaders/parsers/language/cobol.py new file mode 100644 index 0000000000..2267e8c522 --- /dev/null +++ b/libs/langchain/langchain/document_loaders/parsers/language/cobol.py @@ -0,0 +1,96 @@ +import re +from typing import Callable, List + +from langchain.document_loaders.parsers.language.code_segmenter import CodeSegmenter + + +class CobolSegmenter(CodeSegmenter): + """Code segmenter for `COBOL`.""" + + PARAGRAPH_PATTERN = re.compile(r"^[A-Z0-9\-]+(\s+.*)?\.$", re.IGNORECASE) + DIVISION_PATTERN = re.compile( + r"^\s*(IDENTIFICATION|DATA|PROCEDURE|ENVIRONMENT)\s+DIVISION.*$", re.IGNORECASE + ) + SECTION_PATTERN = re.compile(r"^\s*[A-Z0-9\-]+\s+SECTION.$", re.IGNORECASE) + + def __init__(self, code: str): + super().__init__(code) + self.source_lines: List[str] = self.code.splitlines() + + def is_valid(self) -> bool: + # Identify presence of any division to validate COBOL code + return any(self.DIVISION_PATTERN.match(line) for line in self.source_lines) + + def _extract_code(self, start_idx: int, end_idx: int) -> str: + return "\n".join(self.source_lines[start_idx:end_idx]).rstrip("\n") + + def _is_relevant_code(self, line: str) -> bool: + """Check if a line is part of the procedure division or a relevant section.""" + if "PROCEDURE DIVISION" in line.upper(): + return True + # Add additional conditions for relevant sections if needed + return False + + def _process_lines(self, func: Callable) -> List[str]: + """A generic function to process COBOL lines based on provided func.""" + elements: List[str] = [] + start_idx = None + inside_relevant_section = False + + for i, line in enumerate(self.source_lines): + if self._is_relevant_code(line): + inside_relevant_section = True + + if inside_relevant_section and ( + self.PARAGRAPH_PATTERN.match(line.strip().split(" ")[0]) + or self.SECTION_PATTERN.match(line.strip()) + ): + if start_idx is not None: + func(elements, start_idx, i) + start_idx = i + + # Handle the last element if exists + if start_idx is not None: + func(elements, start_idx, len(self.source_lines)) + + return elements + + def extract_functions_classes(self) -> List[str]: + def extract_func(elements: List[str], start_idx: int, end_idx: int) -> None: + elements.append(self._extract_code(start_idx, end_idx)) + + return self._process_lines(extract_func) + + def simplify_code(self) -> str: + simplified_lines: List[str] = [] + inside_relevant_section = False + omitted_code_added = ( + False # To track if "* OMITTED CODE *" has been added after the last header + ) + + for line in self.source_lines: + is_header = ( + "PROCEDURE DIVISION" in line + or "DATA DIVISION" in line + or "IDENTIFICATION DIVISION" in line + or self.PARAGRAPH_PATTERN.match(line.strip().split(" ")[0]) + or self.SECTION_PATTERN.match(line.strip()) + ) + + if is_header: + inside_relevant_section = True + # Reset the flag since we're entering a new section/division or + # paragraph + omitted_code_added = False + + if inside_relevant_section: + if is_header: + # Add header and reset the omitted code added flag + simplified_lines.append(line) + elif not omitted_code_added: + # Add omitted code comment only if it hasn't been added directly + # after the last header + simplified_lines.append("* OMITTED CODE *") + omitted_code_added = True + + return "\n".join(simplified_lines) diff --git a/libs/langchain/langchain/document_loaders/parsers/language/language_parser.py b/libs/langchain/langchain/document_loaders/parsers/language/language_parser.py index 97d26a99e6..534f151e39 100644 --- a/libs/langchain/langchain/document_loaders/parsers/language/language_parser.py +++ b/libs/langchain/langchain/document_loaders/parsers/language/language_parser.py @@ -3,6 +3,7 @@ from typing import Any, Dict, Iterator, Optional from langchain.docstore.document import Document from langchain.document_loaders.base import BaseBlobParser from langchain.document_loaders.blob_loaders import Blob +from langchain.document_loaders.parsers.language.cobol import CobolSegmenter from langchain.document_loaders.parsers.language.javascript import JavaScriptSegmenter from langchain.document_loaders.parsers.language.python import PythonSegmenter from langchain.text_splitter import Language @@ -10,11 +11,13 @@ from langchain.text_splitter import Language LANGUAGE_EXTENSIONS: Dict[str, str] = { "py": Language.PYTHON, "js": Language.JS, + "cobol": Language.COBOL, } LANGUAGE_SEGMENTERS: Dict[str, Any] = { Language.PYTHON: PythonSegmenter, Language.JS: JavaScriptSegmenter, + Language.COBOL: CobolSegmenter, } diff --git a/libs/langchain/langchain/text_splitter.py b/libs/langchain/langchain/text_splitter.py index b894937a57..721ffab6be 100644 --- a/libs/langchain/langchain/text_splitter.py +++ b/libs/langchain/langchain/text_splitter.py @@ -811,6 +811,7 @@ class Language(str, Enum): HTML = "html" SOL = "sol" CSHARP = "csharp" + COBOL = "cobol" class RecursiveCharacterTextSplitter(TextSplitter): @@ -1305,6 +1306,38 @@ class RecursiveCharacterTextSplitter(TextSplitter): " ", "", ] + elif language == Language.COBOL: + return [ + # Split along divisions + "\nIDENTIFICATION DIVISION.", + "\nENVIRONMENT DIVISION.", + "\nDATA DIVISION.", + "\nPROCEDURE DIVISION.", + # Split along sections within DATA DIVISION + "\nWORKING-STORAGE SECTION.", + "\nLINKAGE SECTION.", + "\nFILE SECTION.", + # Split along sections within PROCEDURE DIVISION + "\nINPUT-OUTPUT SECTION.", + # Split along paragraphs and common statements + "\nOPEN ", + "\nCLOSE ", + "\nREAD ", + "\nWRITE ", + "\nIF ", + "\nELSE ", + "\nMOVE ", + "\nPERFORM ", + "\nUNTIL ", + "\nVARYING ", + "\nACCEPT ", + "\nDISPLAY ", + "\nSTOP RUN.", + # Split by the normal type of lines + "\n", + " ", + "", + ] else: raise ValueError( diff --git a/libs/langchain/tests/unit_tests/document_loaders/parsers/language/test_cobol.py b/libs/langchain/tests/unit_tests/document_loaders/parsers/language/test_cobol.py new file mode 100644 index 0000000000..49d691696d --- /dev/null +++ b/libs/langchain/tests/unit_tests/document_loaders/parsers/language/test_cobol.py @@ -0,0 +1,49 @@ +from langchain.document_loaders.parsers.language.cobol import CobolSegmenter + +EXAMPLE_CODE = """ +IDENTIFICATION DIVISION. +PROGRAM-ID. SampleProgram. +DATA DIVISION. +WORKING-STORAGE SECTION. +01 SAMPLE-VAR PIC X(20) VALUE 'Sample Value'. + +PROCEDURE DIVISION. +A000-INITIALIZE-PARA. + DISPLAY 'Initialization Paragraph'. + MOVE 'New Value' TO SAMPLE-VAR. + +A100-PROCESS-PARA. + DISPLAY SAMPLE-VAR. + STOP RUN. +""" + + +def test_extract_functions_classes() -> None: + """Test that functions and classes are extracted correctly.""" + segmenter = CobolSegmenter(EXAMPLE_CODE) + extracted_code = segmenter.extract_functions_classes() + assert extracted_code == [ + "A000-INITIALIZE-PARA.\n " + "DISPLAY 'Initialization Paragraph'.\n " + "MOVE 'New Value' TO SAMPLE-VAR.", + "A100-PROCESS-PARA.\n DISPLAY SAMPLE-VAR.\n STOP RUN.", + ] + + +def test_simplify_code() -> None: + """Test that code is simplified correctly.""" + expected_simplified_code = ( + "IDENTIFICATION DIVISION.\n" + "PROGRAM-ID. SampleProgram.\n" + "DATA DIVISION.\n" + "WORKING-STORAGE SECTION.\n" + "* OMITTED CODE *\n" + "PROCEDURE DIVISION.\n" + "A000-INITIALIZE-PARA.\n" + "* OMITTED CODE *\n" + "A100-PROCESS-PARA.\n" + "* OMITTED CODE *\n" + ) + segmenter = CobolSegmenter(EXAMPLE_CODE) + simplified_code = segmenter.simplify_code() + assert simplified_code.strip() == expected_simplified_code.strip() diff --git a/libs/langchain/tests/unit_tests/test_text_splitter.py b/libs/langchain/tests/unit_tests/test_text_splitter.py index 578286cedc..e6951ef67f 100644 --- a/libs/langchain/tests/unit_tests/test_text_splitter.py +++ b/libs/langchain/tests/unit_tests/test_text_splitter.py @@ -472,6 +472,41 @@ helloWorld(); ] +def test_cobol_code_splitter() -> None: + splitter = RecursiveCharacterTextSplitter.from_language( + Language.COBOL, chunk_size=CHUNK_SIZE, chunk_overlap=0 + ) + code = """ +IDENTIFICATION DIVISION. +PROGRAM-ID. HelloWorld. +DATA DIVISION. +WORKING-STORAGE SECTION. +01 GREETING PIC X(12) VALUE 'Hello, World!'. +PROCEDURE DIVISION. +DISPLAY GREETING. +STOP RUN. + """ + chunks = splitter.split_text(code) + assert chunks == [ + "IDENTIFICATION", + "DIVISION.", + "PROGRAM-ID.", + "HelloWorld.", + "DATA DIVISION.", + "WORKING-STORAGE", + "SECTION.", + "01 GREETING", + "PIC X(12)", + "VALUE 'Hello,", + "World!'.", + "PROCEDURE", + "DIVISION.", + "DISPLAY", + "GREETING.", + "STOP RUN.", + ] + + def test_typescript_code_splitter() -> None: splitter = RecursiveCharacterTextSplitter.from_language( Language.TS, chunk_size=CHUNK_SIZE, chunk_overlap=0