Skip to content

Commit

Permalink
#396 Implement support for parsing copybooks that do not have a root record GROUP.
Browse files Browse the repository at this point in the history
  • Loading branch information
yruslan committed May 3, 2022
1 parent 3e0d45a commit 6b71a0b
Show file tree
Hide file tree
Showing 7 changed files with 297 additions and 68 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,16 @@ class Copybook(val ast: CopybookAST) extends Logging with Serializable {
*/
lazy val isHierarchical: Boolean = getAllSegmentRedefines.exists(_.parentSegment.nonEmpty)

/** True when at least one root-level child of the AST is a primitive (PIC) field,
  * i.e. the copybook is not wrapped in a single root GROUP record. */
val isFlatCopybook: Boolean = ast.children.exists(_.isInstanceOf[Primitive])

/** Returns the statements to treat as the root records of the copybook.
  *
  * For a flat copybook (one with primitive fields at the top level) the whole AST
  * acts as the single root record; otherwise each top-level child is a root record.
  */
def getRootRecords: Seq[Statement] =
  if (isFlatCopybook) Seq(ast) else ast.children

/**
* Get the AST object of a field by name.
*
Expand All @@ -85,7 +95,7 @@ class Copybook(val ast: CopybookAST) extends Logging with Serializable {

/** Finds a field by its name anywhere in the given schema.
  *
  * The field name is normalized via [[CopybookParser.transformIdentifier]] before
  * the lookup so it matches the transformed identifiers stored in the AST.
  *
  * @param schema    the AST to search
  * @param fieldName the (unique) name of the field to look up
  * @return all statements matching the name (callers presumably expect exactly one — verify at call site)
  */
def getFieldByUniqueName(schema: CopybookAST, fieldName: String): Seq[Statement] =
  getFieldByNameInGroup(schema, CopybookParser.transformIdentifier(fieldName))

def getFieldByPathInGroup(group: Group, path: Array[String]): Seq[Statement] = {
Expand Down Expand Up @@ -116,12 +126,13 @@ class Copybook(val ast: CopybookAST) extends Logging with Serializable {

def getFieldByPathName(ast: CopybookAST, fieldName: String): Seq[Statement] = {
val origPath = fieldName.split('.').map(str => CopybookParser.transformIdentifier(str))
val rootRecords = getRootRecords
val path = if (!pathBeginsWithRoot(ast, origPath)) {
ast.children.head.name +: origPath
rootRecords.head.name +: origPath
} else {
origPath
}
ast.children.flatMap(grp =>
rootRecords.flatMap(grp =>
if (grp.name.equalsIgnoreCase(path.head))
getFieldByPathInGroup(grp.asInstanceOf[Group], path.drop(1))
else
Expand Down Expand Up @@ -189,8 +200,6 @@ class Copybook(val ast: CopybookAST) extends Logging with Serializable {

/** This routine is used for testing by generating a layout position information to compare with mainframe output */
def generateRecordLayoutPositions(): String = {
validate()

var fieldCounter: Int = 0

def alignLeft(str: String, width: Int): String = {
Expand Down Expand Up @@ -317,14 +326,6 @@ class Copybook(val ast: CopybookAST) extends Logging with Serializable {
}
visitGroup(ast)
}

/** Ensures every root-level statement of the copybook is a GROUP.
  *
  * @throws IllegalArgumentException on the first root-level field that is not a GROUP
  */
private def validate(): Unit = {
  ast.children.find(!_.isInstanceOf[Group]).foreach { grp =>
    throw new IllegalArgumentException(s"Found a non-GROUP field at the root level (${grp.name}). Please, add the record-level root field, for example '01 RECORD.' at the top of the copybook")
  }
}
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -636,7 +636,12 @@ object CopybookParser extends Logging {
if (segmentRedefines.isEmpty) {
ast
} else {
val newSchema = processRootLevelFields(ast)
val isFlatAst = ast.children.exists(_.isInstanceOf[Primitive])
val newSchema = if (isFlatAst) {
processGroupFields(ast)
} else {
processRootLevelFields(ast)
}
validateAllSegmentsFound()
ast.withUpdatedChildren(newSchema.children)
}
Expand Down Expand Up @@ -985,10 +990,8 @@ object CopybookParser extends Logging {
* @return The same AST with non-filler size set for each group
*/
private def calculateNonFillerSizes(ast: CopybookAST): CopybookAST = {
var lastFillerIndex = 0

def calcSubGroupNonFillers(group: Group): Group = {
val newChildren = calcNonFillers(group)
def calcGroupNonFillers(group: Group): Group = {
val newChildren = calcNonFillerChildren(group)
var i = 0
var nonFillers = 0
while (i < group.children.length) {
Expand All @@ -999,11 +1002,11 @@ object CopybookParser extends Logging {
group.copy(nonFillerSize = nonFillers, children = newChildren.children)(group.parent)
}

def calcNonFillers(group: CopybookAST): CopybookAST = {
def calcNonFillerChildren(group: CopybookAST): CopybookAST = {
val newChildren = ArrayBuffer[Statement]()
group.children.foreach {
case grp: Group =>
val newGrp = calcSubGroupNonFillers(grp)
val newGrp = calcGroupNonFillers(grp)
if (newGrp.children.nonEmpty) {
newChildren += newGrp
}
Expand All @@ -1012,7 +1015,7 @@ object CopybookParser extends Logging {
group.withUpdatedChildren(newChildren)
}

calcNonFillers(ast)
calcGroupNonFillers(ast)
}

/** Transforms the Cobol identifiers to be useful in Spark context. Removes characters an identifier cannot contain. */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ object RecordExtractors {
): Seq[Any] = {
val dependFields = scala.collection.mutable.HashMap.empty[String, Either[Int, String]]

val isAstFlat = ast.children.exists(_.isInstanceOf[Primitive])

def extractArray(field: Statement, useOffset: Int): (Int, Array[Any]) = {
val from = 0
val arraySize = field.arrayMaxSize
Expand Down Expand Up @@ -173,15 +175,27 @@ object RecordExtractors {

var nextOffset = offsetBytes

val records = for (record <- ast.children) yield {
val rootRecords = if (isAstFlat) {
Seq(ast)
} else {
ast.children
}

val records = for (record <- rootRecords) yield {
val (size, values) = getGroupValues(nextOffset, record.asInstanceOf[Group])
if (!record.isRedefined) {
nextOffset += size
}
values
}

applyRecordPostProcessing(ast, records, policy, generateRecordId, segmentLevelIds, fileId, recordId, data.length, generateInputFileField, inputFileName, handler)
val effectiveSchemaRetentionPolicy = if (isAstFlat) {
SchemaRetentionPolicy.CollapseRoot
} else {
policy
}

applyRecordPostProcessing(ast, records, effectiveSchemaRetentionPolicy, generateRecordId, segmentLevelIds, fileId, recordId, data.length, generateInputFileField, inputFileName, handler)
}

/**
Expand Down Expand Up @@ -226,6 +240,8 @@ object RecordExtractors {
inputFileName: String = "",
handler: RecordHandler[T]
): Seq[Any] = {
val isAstFlat = ast.children.exists(_.isInstanceOf[Primitive])

val dependFields = scala.collection.mutable.HashMap.empty[String, Either[Int, String]]

def extractArray(field: Statement, useOffset: Int, data: Array[Byte], currentIndex: Int, parentSegmentIds: List[String]): (Int, Array[Any]) = {
Expand Down Expand Up @@ -377,15 +393,27 @@ object RecordExtractors {

var nextOffset = offsetBytes

val records = ast.children.collect { case grp: Group if grp.parentSegment.isEmpty =>
val rootRecords = if (isAstFlat) {
Seq(ast)
} else {
ast.children
}

val records = rootRecords.collect { case grp: Group if grp.parentSegment.isEmpty =>
val (size, values) = getGroupValues(nextOffset, grp, segmentsData(0)._2, 0, segmentsData(0)._1 :: Nil)
nextOffset += size
values
}

val recordLength = segmentsData.map(_._2.length).sum

applyRecordPostProcessing(ast, records, policy, generateRecordId, Nil, fileId, recordId, recordLength, generateInputFileField, inputFileName, handler)
val effectiveSchemaRetentionPolicy = if (isAstFlat) {
SchemaRetentionPolicy.CollapseRoot
} else {
policy
}

applyRecordPostProcessing(ast, records, effectiveSchemaRetentionPolicy, generateRecordId, Nil, fileId, recordId, recordLength, generateInputFileField, inputFileName, handler)
}

/**
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
/*
* Copyright 2018 ABSA Group Limited
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package za.co.absa.cobrix.cobol.parser.copybooks

import org.scalatest.FunSuite
import org.slf4j.{Logger, LoggerFactory}
import za.co.absa.cobrix.cobol.parser.CopybookParser
import za.co.absa.cobrix.cobol.testutils.SimpleComparisonBase

/** Tests that copybooks without a single root GROUP record ("flat" copybooks) parse correctly. */
class FlatCopybooksSpec extends FunSuite with SimpleComparisonBase {
  private implicit val logger: Logger = LoggerFactory.getLogger(this.getClass)

  test("Flat copybooks should be parsed normally") {
    // Several 01-level records instead of one root GROUP wrapping everything.
    val copybookContents: String =
      s""" 01 F1 PIC X(10).
         | 01 F2 PIC 9(4).
         | 01 F3 PIC S9(6).
         | 01 G1.
         | 03 F4 PIC 9(4).
         |""".stripMargin

    val parsedCopybook = CopybookParser.parseTree(copybookContents)

    val layout = parsedCopybook.generateRecordLayoutPositions()

    // Fields of consecutive 01-level records occupy consecutive byte ranges.
    val expectedLayout =
      """-------- FIELD LEVEL/NAME --------- --ATTRIBS-- FLD START END LENGTH
        |
        |1 F1 1 1 10 10
        |1 F2 2 11 14 4
        |1 F3 3 15 20 6
        |1 G1 4 21 24 4
        | 3 F4 5 21 24 4"""
        .stripMargin.replace("\r\n", "\n")

    assertEqualsMultiline(layout, expectedLayout)
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -77,12 +77,13 @@ class CobolSchema(copybook: Copybook,
@throws(classOf[IllegalStateException])
private def createSparkSchema(): StructType = {
logger.info("Layout positions:\n" + copybook.generateRecordLayoutPositions())
val records = for (record <- copybook.ast.children) yield {

val records = for (record <- copybook.getRootRecords) yield {
val group = record.asInstanceOf[Group]
val redefines = copybook.getAllSegmentRedefines
parseGroup(group, redefines)
}
val expandRecords = if (policy == SchemaRetentionPolicy.CollapseRoot) {
val expandRecords = if (policy == SchemaRetentionPolicy.CollapseRoot || copybook.isFlatCopybook) {
// Expand root group fields
records.toArray.flatMap(group => group.dataType.asInstanceOf[StructType].fields)
} else {
Expand Down
Loading

0 comments on commit 6b71a0b

Please sign in to comment.