raw, final int loadSize) throws IOException {
SparkLoaderFactoryUtil.createLoaderFactory(childColumns[index]).create(columnBinary, loadSize);
}
// NOTE: null columns
- for (int i = 0; i < childColumns.length; i++) {
+ for (int i = 0; i < schema.length(); i++) {
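+ // NOTE: loop over the schema length so columns that never received data are also null-filled.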
if (!isSet[i]) {
childColumns[i].putNulls(0, loadSize);
}
diff --git a/src/main/java/jp/co/yahoo/yosegi/spark/reader/SparkColumnarBatchReader.java b/src/main/java/jp/co/yahoo/yosegi/spark/reader/SparkColumnarBatchReader.java
index aff02b0..f8e86cf 100644
--- a/src/main/java/jp/co/yahoo/yosegi/spark/reader/SparkColumnarBatchReader.java
+++ b/src/main/java/jp/co/yahoo/yosegi/spark/reader/SparkColumnarBatchReader.java
@@ -85,6 +85,9 @@ public ColumnarBatch next() throws IOException {
public void close() throws Exception {
reader.close();
for (int i = 0; i < childColumns.length; i++) {
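+ // NOTE: skip entries that were never initialized to avoid a NullPointerException on close.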
+ if (childColumns[i] == null) {
+ continue;
+ }
childColumns[i].close();
}
}
diff --git a/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
new file mode 100644
index 0000000..43dad69
--- /dev/null
+++ b/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
@@ -0,0 +1,16 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more contributor license
+# agreements. See the NOTICE file distributed with this work for additional information regarding
+# copyright ownership. The ASF licenses this file to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance with the License. You may obtain a
+# copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software distributed under the
+# License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
+# express or implied. See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
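+# NOTE: registering YosegiFileFormat as a DataSourceRegister service lets Spark resolve
+# the data source by its short name (e.g. "USING yosegi" in the tests below).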
+jp.co.yahoo.yosegi.spark.YosegiFileFormat
\ No newline at end of file
diff --git a/src/test/java/jp/co/yahoo/yosegi/spark/blackbox/DataSourceTest.java b/src/test/java/jp/co/yahoo/yosegi/spark/blackbox/DataSourceTest.java
new file mode 100644
index 0000000..1ed096b
--- /dev/null
+++ b/src/test/java/jp/co/yahoo/yosegi/spark/blackbox/DataSourceTest.java
@@ -0,0 +1,341 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more contributor license
+ * agreements. See the NOTICE file distributed with this work for additional information regarding
+ * copyright ownership. The ASF licenses this file to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance with the License. You may obtain a
+ * copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software distributed under the
+ * License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
+ * express or implied. See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package jp.co.yahoo.yosegi.spark.blackbox;
+
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.SparkSession;
+import org.apache.spark.sql.types.DecimalType;
+import org.junit.jupiter.api.AfterAll;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.BeforeAll;
+import org.junit.jupiter.api.Test;
+
+import java.io.File;
+import java.io.IOException;
+import java.math.BigDecimal;
+import java.util.List;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNull;
+
+public class DataSourceTest {
+ private static SparkSession spark;
+ private static SQLContext sqlContext;
+ private static final String appName = "DataSourceTest";
+
+ public boolean deleteDirectory(final File directory) {
+ final File[] allContents = directory.listFiles();
+ if (allContents != null) {
+ for (final File file : allContents) {
+ deleteDirectory(file);
+ }
+ }
+ return directory.delete();
+ }
+
+ public String getTmpPath() {
+ String tmpdir = System.getProperty("java.io.tmpdir");
+ if (tmpdir.endsWith("/")) {
+ tmpdir = tmpdir.substring(0, tmpdir.length() - 1);
+ }
+ return tmpdir + "/" + appName;
+ }
+
+ public String getLocation(final String table) {
+ String tmpPath = getTmpPath();
+ return tmpPath + "/" + table;
+ }
+
+ @BeforeAll
+ static void initAll() {
+ spark = SparkSession.builder().appName(appName).master("local[*]").getOrCreate();
+ sqlContext = spark.sqlContext();
+ }
+
+ @AfterAll
+ static void tearDownAll() {
+ spark.close();
+ }
+
+ @AfterEach
+ void tearDown() {
+ deleteDirectory(new File(getTmpPath()));
+ }
+
+ @Test
+ void T_DataSource_Primitive_1() throws IOException {
+ final String location = getLocation("primitive1");
+ final int precision = DecimalType.MAX_PRECISION();
+ final int scale = DecimalType.MINIMUM_ADJUSTED_SCALE();
+ /**
+ * CREATE TABLE primitive1 (
+ * id INT,
+ * bo BOOLEAN,
+ * by BYTE,
+ * de DECIMAL(38,6),
+ * do DOUBLE,
+ * fl FLOAT,
+ * in INT,
+ * lo LONG,
+ * sh SHORT,
+ * st STRING
+ * )
+ * USING yosegi
+ * LOCATION ''
+ */
+ final String ddl = String.format("CREATE TABLE primitive1 (\n" +
+ " id INT,\n" +
+ " bo BOOLEAN,\n" +
+ " by BYTE,\n" +
+ " de DECIMAL(%d, %d),\n" +
+ " do DOUBLE,\n" +
+ " fl FLOAT,\n" +
+ " in INT,\n" +
+ " lo LONG,\n" +
+ " sh SHORT,\n" +
+ " st STRING\n" +
+ ")\n" +
+ "USING yosegi\n" +
+ "LOCATION '%s';", precision, scale, location);
+ spark.sql(ddl);
+
+ /**
+ * FIXME: cannot insert decimal value.
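+ * The decimal is written, but it currently reads back as null (see the assertNull on "de" below).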
+ */
+ final String insertSql = "INSERT INTO primitive1\n" +
+ "(id, bo, by, de, do, fl, in, lo, sh, st)\n" +
+ "VALUES\n" +
+ "(0, true, 127, 123.45678901, 1.7976931348623157e+308, 3.402823e+37, 2147483647, 9223372036854775807, 32767, 'value1');";
+ spark.sql(insertSql);
+
+ List<Row> rows = spark.sql("SELECT * FROM primitive1;").collectAsList();
+ assertEquals(rows.get(0).getAs("id"), Integer.valueOf(0));
+ assertEquals(rows.get(0).getAs("bo"), Boolean.valueOf(true));
+ assertEquals(rows.get(0).getAs("by"), Byte.valueOf((byte) 127));
+ assertNull(rows.get(0).getAs("de"));
+ assertEquals(rows.get(0).getAs("do"), Double.valueOf(1.7976931348623157e+308));
+ assertEquals(rows.get(0).getAs("fl"), Float.valueOf(3.402823e+37F));
+ assertEquals(rows.get(0).getAs("in"), Integer.valueOf(2147483647));
+ assertEquals(rows.get(0).getAs("lo"), Long.valueOf(9223372036854775807L));
+ assertEquals(rows.get(0).getAs("sh"), Short.valueOf((short) 32767));
+ assertEquals(rows.get(0).getAs("st"), "value1");
+ }
+
+ @Test
+ void T_DataSource_Primitive_2() throws IOException {
+ final String location = getLocation("primitive2");
+ final int precision = DecimalType.MAX_PRECISION();
+ final int scale = DecimalType.MINIMUM_ADJUSTED_SCALE();
+ /**
+ * CREATE TABLE primitive2 (
+ * id INT,
+ * bo BOOLEAN,
+ * by BYTE,
+ * de DOUBLE,
+ * do DOUBLE,
+ * fl FLOAT,
+ * in INT,
+ * lo LONG,
+ * sh SHORT,
+ * st STRING
+ * )
+ * USING yosegi
+ * LOCATION ''
+ */
+ final String ddl1 = String.format("CREATE TABLE primitive2 (\n" +
+ " id INT,\n" +
+ " bo BOOLEAN,\n" +
+ " by BYTE,\n" +
+ " de DOUBLE,\n" +
+ " do DOUBLE,\n" +
+ " fl FLOAT,\n" +
+ " in INT,\n" +
+ " lo LONG,\n" +
+ " sh SHORT,\n" +
+ " st STRING\n" +
+ ")\n" +
+ "USING yosegi\n" +
+ "LOCATION '%s';", location);
+ spark.sql(ddl1);
+
+ final String insertSql = "INSERT INTO primitive2\n" +
+ "(id, bo, by, de, do, fl, in, lo, sh, st)\n" +
+ "VALUES\n" +
+ "(0, true, 127, 123.45678901, 1.7976931348623157e+308, 3.402823e+37, 2147483647, 9223372036854775807, 32767, 'value1');";
+ spark.sql(insertSql);
+
+ spark.sql("DROP TABLE primitive2;");
+
+ /**
+ * CREATE TABLE primitive2 (
+ * id INT,
+ * bo BOOLEAN,
+ * by BYTE,
+ * de DECIMAL(38,6),
+ * do DOUBLE,
+ * fl FLOAT,
+ * in INT,
+ * lo LONG,
+ * sh SHORT,
+ * st STRING
+ * )
+ * USING yosegi
+ * LOCATION ''
+ */
+ final String ddl2 = String.format("CREATE TABLE primitive2 (\n" +
+ " id INT,\n" +
+ " bo BOOLEAN,\n" +
+ " by BYTE,\n" +
+ " de DECIMAL(%d, %d),\n" +
+ " do DOUBLE,\n" +
+ " fl FLOAT,\n" +
+ " in INT,\n" +
+ " lo LONG,\n" +
+ " sh SHORT,\n" +
+ " st STRING\n" +
+ ")\n" +
+ "USING yosegi\n" +
+ "LOCATION '%s';", precision, scale, location);
+ spark.sql(ddl2);
+
+ List<Row> rows = spark.sql("SELECT * FROM primitive2;").collectAsList();
+ assertEquals(rows.get(0).getAs("id"), Integer.valueOf(0));
+ assertEquals(rows.get(0).getAs("bo"), Boolean.valueOf(true));
+ assertEquals(rows.get(0).getAs("by"), Byte.valueOf((byte) 127));
+ assertEquals(rows.get(0).getAs("de"), BigDecimal.valueOf(123.456789));
+ assertEquals(rows.get(0).getAs("do"), Double.valueOf(1.7976931348623157e+308));
+ assertEquals(rows.get(0).getAs("fl"), Float.valueOf(3.402823e+37F));
+ assertEquals(rows.get(0).getAs("in"), Integer.valueOf(2147483647));
+ assertEquals(rows.get(0).getAs("lo"), Long.valueOf(9223372036854775807L));
+ assertEquals(rows.get(0).getAs("sh"), Short.valueOf((short) 32767));
+ assertEquals(rows.get(0).getAs("st"), "value1");
+ }
+
+ @Test
+ void T_DataSource_Expand_1() throws IOException {
+ final String location = getLocation("expand1");
+ /**
+ * CREATE TABLE expand1 (
+ * id INT,
+ * a ARRAY<INT>
+ * )
+ * USING yosegi
+ * LOCATION '';
+ */
+ final String ddl1 = String.format("CREATE TABLE expand1 (\n" +
+ " id INT,\n" +
+ " a ARRAY\n" +
+ ")\n" +
+ "USING yosegi\n" +
+ "LOCATION '%s';", location);
+ spark.sql(ddl1);
+
+ final String insertSql = "INSERT INTO expand1\n" +
+ "(id, a)\n" +
+ "VALUES\n" +
+ "(0, array(1,2,3));";
+ spark.sql(insertSql);
+
+ spark.sql("DROP TABLE expand1;");
+
+ /**
+ * CREATE TABLE expand1(
+ * id INT,
+ * aa INT
+ * )
+ * USING yosegi
+ * LOCATION ''
+ * OPTIONS (
+ * 'spread.reader.expand.column'='{"base":{"node":"a", "link_name":"aa"}}'
+ * );
+ */
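+ // NOTE: expand unrolls the array: the single inserted row (0, array(1,2,3)) loads back as three rows.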
+ final String ddl2 = String.format("CREATE TABLE expand1(\n" +
+ " id INT,\n" +
+ " aa INT\n" +
+ ")\n" +
+ "USING yosegi\n" +
+ "LOCATION '%s'\n" +
+ "OPTIONS (\n" +
+ " 'spread.reader.expand.column'='{\"base\":{\"node\":\"a\", \"link_name\":\"aa\"}}'\n" +
+ ");", location);
+ spark.sql(ddl2);
+
+ List<Row> rows = spark.sql("SELECT * FROM expand1 ORDER BY id, aa;").collectAsList();
+ assertEquals(rows.get(0).getAs("id"), Integer.valueOf(0));
+ assertEquals(rows.get(1).getAs("id"), Integer.valueOf(0));
+ assertEquals(rows.get(2).getAs("id"), Integer.valueOf(0));
+ assertEquals(rows.get(0).getAs("aa"), Integer.valueOf(1));
+ assertEquals(rows.get(1).getAs("aa"), Integer.valueOf(2));
+ assertEquals(rows.get(2).getAs("aa"), Integer.valueOf(3));
+ }
+
+ @Test
+ void T_DataSource_Flatten_1() throws IOException {
+ final String location = getLocation("flatten1");
+ /**
+ * CREATE TABLE flatten1 (
+ * id INT,
+ * m MAP<STRING, STRING>
+ * )
+ * USING yosegi
+ * LOCATION '';
+ */
+ final String ddl1 = String.format("CREATE TABLE flatten1 (\n" +
+ " id INT,\n" +
+ " m MAP\n" +
+ ")\n" +
+ "USING yosegi\n" +
+ "LOCATION '%s';", location);
+ spark.sql(ddl1);
+
+ final String insertSql = "INSERT INTO flatten1\n" +
+ "(id, m)\n" +
+ "VALUES\n" +
+ "(0, map('k1', 'v1', 'k2', 'v2'));";
+ spark.sql(insertSql);
+
+ spark.sql("DROP TABLE flatten1;");
+
+ /**
+ * CREATE TABLE flatten1 (
+ * id INT,
+ * mk1 STRING,
+ * mk2 STRING
+ * )
+ * USING yosegi
+ * LOCATION ''
+ * OPTIONS (
+ * 'spread.reader.flatten.column'='[{"link_name":"id", "nodes":["id"]}, {"link_name":"mk1", "nodes":["m","k1"]}, {"link_name":"mk2", "nodes":["m","k2"]}]'
+ * );
+ */
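+ // NOTE: flatten maps the nested entries m['k1'] and m['k2'] onto the flat columns mk1 and mk2.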
+ final String ddl2 = String.format("CREATE TABLE flatten1 (\n" +
+ " id INT,\n" +
+ " mk1 STRING,\n" +
+ " mk2 STRING\n" +
+ ")\n" +
+ "USING yosegi\n" +
+ "LOCATION '%s'\n" +
+ "OPTIONS (\n" +
+ " 'spread.reader.flatten.column'='[{\"link_name\":\"id\", \"nodes\":[\"id\"]}, {\"link_name\":\"mk1\", \"nodes\":[\"m\",\"k1\"]}, {\"link_name\":\"mk2\", \"nodes\":[\"m\",\"k2\"]}]'\n" +
+ ");", location);
+ spark.sql(ddl2);
+
+ List<Row> rows = spark.sql("SELECT * FROM flatten1;").collectAsList();
+ assertEquals(rows.get(0).getAs("id"), Integer.valueOf(0));
+ assertEquals(rows.get(0).getAs("mk1"), "v1");
+ assertEquals(rows.get(0).getAs("mk2"), "v2");
+ }
+}
diff --git a/src/test/java/jp/co/yahoo/yosegi/spark/blackbox/ExpandFlatten.java b/src/test/java/jp/co/yahoo/yosegi/spark/blackbox/ExpandFlattenTest.java
similarity index 97%
rename from src/test/java/jp/co/yahoo/yosegi/spark/blackbox/ExpandFlatten.java
rename to src/test/java/jp/co/yahoo/yosegi/spark/blackbox/ExpandFlattenTest.java
index 3773720..790588d 100644
--- a/src/test/java/jp/co/yahoo/yosegi/spark/blackbox/ExpandFlatten.java
+++ b/src/test/java/jp/co/yahoo/yosegi/spark/blackbox/ExpandFlattenTest.java
@@ -40,7 +40,7 @@
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNull;
-public class ExpandFlatten {
+public class ExpandFlattenTest {
private static SparkSession spark;
private static SQLContext sqlContext;
private static final String appName = "ExpandFlattenTest";
@@ -60,7 +60,11 @@ public String getResourcePath(final String resource) {
}
public String getTmpPath() {
- return System.getProperty("java.io.tmpdir") + appName + ".yosegi";
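+ // NOTE: java.io.tmpdir may or may not end with a separator, so normalize it before appending.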
+ String tmpdir = System.getProperty("java.io.tmpdir");
+ if (tmpdir.endsWith("/")) {
+ tmpdir = tmpdir.substring(0, tmpdir.length() - 1);
+ }
+ return tmpdir + "/" + appName + ".yosegi";
}
public Dataset<Row> loadJsonFile(final String resource, final StructType schema) {
diff --git a/src/test/java/jp/co/yahoo/yosegi/spark/blackbox/Load.java b/src/test/java/jp/co/yahoo/yosegi/spark/blackbox/LoadTest.java
similarity index 99%
rename from src/test/java/jp/co/yahoo/yosegi/spark/blackbox/Load.java
rename to src/test/java/jp/co/yahoo/yosegi/spark/blackbox/LoadTest.java
index 9b8fa97..0b8b6c6 100644
--- a/src/test/java/jp/co/yahoo/yosegi/spark/blackbox/Load.java
+++ b/src/test/java/jp/co/yahoo/yosegi/spark/blackbox/LoadTest.java
@@ -43,7 +43,7 @@
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertTrue;
-public class Load {
+public class LoadTest {
private static SparkSession spark;
private static SQLContext sqlContext;
private static String appName = "LoadTest";
@@ -63,7 +63,11 @@ public String getResourcePath(final String resource) {
}
public String getTmpPath() {
- return System.getProperty("java.io.tmpdir") + appName + ".yosegi";
+ String tmpdir = System.getProperty("java.io.tmpdir");
+ if (tmpdir.endsWith("/")) {
+ tmpdir = tmpdir.substring(0, tmpdir.length() - 1);
+ }
+ return tmpdir + "/" + appName + ".yosegi";
}
public Dataset<Row> loadJsonFile(final String resource, final StructType schema) {
diff --git a/src/test/java/jp/co/yahoo/yosegi/spark/blackbox/PartitionLoadTest.java b/src/test/java/jp/co/yahoo/yosegi/spark/blackbox/PartitionLoadTest.java
new file mode 100644
index 0000000..ce2b988
--- /dev/null
+++ b/src/test/java/jp/co/yahoo/yosegi/spark/blackbox/PartitionLoadTest.java
@@ -0,0 +1,189 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more contributor license
+ * agreements. See the NOTICE file distributed with this work for additional information regarding
+ * copyright ownership. The ASF licenses this file to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance with the License. You may obtain a
+ * copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software distributed under the
+ * License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
+ * express or implied. See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package jp.co.yahoo.yosegi.spark.blackbox;
+
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.SaveMode;
+import org.apache.spark.sql.SparkSession;
+import org.apache.spark.sql.types.DataType;
+import org.apache.spark.sql.types.DataTypes;
+import org.apache.spark.sql.types.DecimalType;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
+import org.junit.jupiter.api.AfterAll;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.BeforeAll;
+import org.junit.jupiter.api.Test;
+
+import java.io.File;
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.List;
+
+import static org.apache.spark.sql.functions.col;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+public class PartitionLoadTest {
+ private static SparkSession spark;
+ private static SQLContext sqlContext;
+ private static String appName = "PartitionLoadTest";
+
+ public boolean deleteDirectory(final File directory) {
+ final File[] allContents = directory.listFiles();
+ if (allContents != null) {
+ for (final File file : allContents) {
+ deleteDirectory(file);
+ }
+ }
+ return directory.delete();
+ }
+
+ public String getResourcePath(final String resource) {
+ return Thread.currentThread().getContextClassLoader().getResource(resource).getPath();
+ }
+
+ public String getTmpPath() {
+ String tmpdir = System.getProperty("java.io.tmpdir");
+ if (tmpdir.endsWith("/")) {
+ tmpdir = tmpdir.substring(0, tmpdir.length() - 1);
+ }
+ return tmpdir + "/" + appName + ".yosegi";
+ }
+
+ public Dataset<Row> loadJsonFile(final String resource, final StructType schema) {
+ final String resourcePath = getResourcePath(resource);
+ if (schema == null) {
+ return sqlContext.read().json(resourcePath).orderBy(col("id").asc());
+ }
+ return sqlContext.read().schema(schema).json(resourcePath).orderBy(col("id").asc());
+ }
+
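+ // Writes the JSON fixture as a yosegi dataset partitioned by the given columns,
+ // e.g. createYosegiFile(resource, "p1", "p2").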
+ public void createYosegiFile(final String resource, final String... partitions) {
+ final Dataset<Row> df = loadJsonFile(resource, null);
+ final String tmpPath = getTmpPath();
+ df.write()
+ .mode(SaveMode.Overwrite)
+ .partitionBy(partitions)
+ .format("jp.co.yahoo.yosegi.spark.YosegiFileFormat")
+ .save(tmpPath);
+ }
+
+ public Dataset<Row> loadYosegiFile(final StructType schema) {
+ final String tmpPath = getTmpPath();
+ if (schema == null) {
+ return sqlContext
+ .read()
+ .format("jp.co.yahoo.yosegi.spark.YosegiFileFormat")
+ .load(tmpPath)
+ .orderBy(col("id").asc());
+ }
+ return sqlContext
+ .read()
+ .format("jp.co.yahoo.yosegi.spark.YosegiFileFormat")
+ .schema(schema)
+ .load(tmpPath)
+ .orderBy(col("id").asc());
+ }
+
+ @BeforeAll
+ static void initAll() {
+ spark = SparkSession.builder().appName(appName).master("local[*]").getOrCreate();
+ sqlContext = spark.sqlContext();
+ }
+
+ @AfterAll
+ static void tearDownAll() {
+ spark.close();
+ }
+
+ @AfterEach
+ void tearDown() {
+ deleteDirectory(new File(getTmpPath()));
+ }
+
+ /*
+ * FIXME: The rows cannot be loaded if rows in a partition have only null values.
+ * * {"id":1}
+ */
+ @Test
+ void T_load_Partition_Primitive_1() throws IOException {
+ // NOTE: create yosegi file
+ final String resource = "blackbox/Partition_Primitive_1.txt";
+ createYosegiFile(resource, "id");
+
+ // NOTE: schema
+ final int precision = DecimalType.MAX_PRECISION();
+ final int scale = DecimalType.MINIMUM_ADJUSTED_SCALE();
+ final List<StructField> fields =
+ Arrays.asList(
+ DataTypes.createStructField("id", DataTypes.IntegerType, true),
+ DataTypes.createStructField("bo", DataTypes.BooleanType, true),
+ DataTypes.createStructField("by", DataTypes.ByteType, true),
+ DataTypes.createStructField("de", DataTypes.createDecimalType(precision, scale), true),
+ DataTypes.createStructField("do", DataTypes.DoubleType, true),
+ DataTypes.createStructField("fl", DataTypes.FloatType, true),
+ DataTypes.createStructField("in", DataTypes.IntegerType, true),
+ DataTypes.createStructField("lo", DataTypes.LongType, true),
+ DataTypes.createStructField("sh", DataTypes.ShortType, true),
+ DataTypes.createStructField("st", DataTypes.StringType, true));
+ final StructType structType = DataTypes.createStructType(fields);
+ // NOTE: load
+ final Dataset<Row> dfj = loadJsonFile(resource, structType);
+ final Dataset<Row> dfy = loadYosegiFile(structType);
+
+ // NOTE: assert
+ final List<Row> ldfj = dfj.collectAsList();
+ final List<Row> ldfy = dfy.collectAsList();
+ for (int i = 0; i < ldfj.size(); i++) {
+ for (final StructField field : fields) {
+ final String name = field.name();
+ final DataType dataType = field.dataType();
+ assertEquals((Object) ldfj.get(i).getAs(name), (Object) ldfy.get(i).getAs(name));
+ }
+ }
+ }
+
+ @Test
+ void T_load_Partition_Primitive_2() throws IOException {
+ // NOTE: create yosegi file
+ final String resource = "blackbox/Partition_Primitive_2.txt";
+ createYosegiFile(resource, "p1", "p2");
+
+ // NOTE: schema
+ final List<StructField> fields =
+ Arrays.asList(
+ DataTypes.createStructField("id", DataTypes.IntegerType, true),
+ DataTypes.createStructField("p1", DataTypes.IntegerType, true),
+ DataTypes.createStructField("p2", DataTypes.StringType, true),
+ DataTypes.createStructField("v", DataTypes.StringType, true));
+ final StructType structType = DataTypes.createStructType(fields);
+ // NOTE: load
+ final Dataset<Row> dfj = loadJsonFile(resource, structType);
+ final Dataset<Row> dfy = loadYosegiFile(structType);
+
+ // NOTE: assert
+ final List<Row> ldfj = dfj.collectAsList();
+ final List<Row> ldfy = dfy.collectAsList();
+ for (int i = 0; i < ldfj.size(); i++) {
+ for (final StructField field : fields) {
+ final String name = field.name();
+ final DataType dataType = field.dataType();
+ assertEquals((Object) ldfj.get(i).getAs(name), (Object) ldfy.get(i).getAs(name));
+ }
+ }
+ }
+}
diff --git a/src/test/java/jp/co/yahoo/yosegi/spark/inmemory/loader/SparkArrayLoaderTest.java b/src/test/java/jp/co/yahoo/yosegi/spark/inmemory/loader/SparkArrayLoaderTest.java
index 8c80909..f244a9e 100644
--- a/src/test/java/jp/co/yahoo/yosegi/spark/inmemory/loader/SparkArrayLoaderTest.java
+++ b/src/test/java/jp/co/yahoo/yosegi/spark/inmemory/loader/SparkArrayLoaderTest.java
@@ -17,9 +17,12 @@
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import jp.co.yahoo.yosegi.binary.ColumnBinary;
+import jp.co.yahoo.yosegi.binary.ColumnBinaryMakerConfig;
import jp.co.yahoo.yosegi.binary.FindColumnBinaryMaker;
import jp.co.yahoo.yosegi.binary.maker.IColumnBinaryMaker;
import jp.co.yahoo.yosegi.binary.maker.MaxLengthBasedArrayColumnBinaryMaker;
+import jp.co.yahoo.yosegi.binary.maker.OptimizedNullArrayDumpStringColumnBinaryMaker;
+import jp.co.yahoo.yosegi.binary.maker.OptimizedNullArrayStringColumnBinaryMaker;
import jp.co.yahoo.yosegi.inmemory.IArrayLoader;
import jp.co.yahoo.yosegi.message.parser.json.JacksonMessageReader;
import jp.co.yahoo.yosegi.spark.test.Utils;
@@ -370,4 +373,46 @@ void T_load_String_1(final String binaryMakerClassName) throws IOException {
// NOTE: assert
assertArray(values, vector, loadSize, elmDataType);
}
+
+ @ParameterizedTest
+ @MethodSource("D_arrayColumnBinaryMaker")
+ void T_load_String_checkDictionaryReset(final String binaryMakerClassName) throws IOException {
+ // NOTE: test data
+ // NOTE: expected
+ String resource = "SparkArrayLoaderTest/String_1.txt";
+ List<List<String>> values = toValues(resource);
+ int loadSize = values.size();
+
+ // NOTE: create ColumnBinary
+ IColumn column = toArrayColumn(resource);
+ IColumnBinaryMaker binaryMaker = FindColumnBinaryMaker.get(binaryMakerClassName);
+
+ ColumnBinaryMakerConfig config = new ColumnBinaryMakerConfig();
+ config.stringMakerClass = new OptimizedNullArrayStringColumnBinaryMaker();
+ ColumnBinary columnBinary = Utils.getColumnBinary(binaryMaker, column, config, null, null);
+
+ // NOTE: load
+ final DataType elmDataType = DataTypes.StringType;
+ final DataType dataType = DataTypes.createArrayType(elmDataType);
+ final WritableColumnVector vector = new OnHeapColumnVector(loadSize, dataType);
+ IArrayLoader loader = new SparkArrayLoader(vector, loadSize);
+ binaryMaker.load(columnBinary, loader);
+
+ // NOTE: assert
+ assertArray(values, vector, loadSize, elmDataType);
+
+ // NOTE: Check if the vector is reset
+ String resource2 = "SparkArrayLoaderTest/String_2.txt";
+ List<List<String>> values2 = toValues(resource2);
+ int loadSize2 = values2.size();
+ IColumn column2 = toArrayColumn(resource2);
+ IColumnBinaryMaker binaryMaker2 = FindColumnBinaryMaker.get(binaryMakerClassName);
+ config.stringMakerClass = new OptimizedNullArrayDumpStringColumnBinaryMaker();
+ ColumnBinary columnBinary2 = Utils.getColumnBinary(binaryMaker2, column2, config, null, null);
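+ // NOTE: switch the string maker and reuse the same vector to verify that reset()/reserve() clears stale dictionary state.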
+ vector.reset();
+ vector.reserve(loadSize2);
+ IArrayLoader loader2 = new SparkArrayLoader(vector, loadSize2);
+ binaryMaker2.load(columnBinary2, loader2);
+ assertArray(values2, vector, loadSize2, elmDataType);
+ }
}
diff --git a/src/test/java/jp/co/yahoo/yosegi/spark/inmemory/loader/SparkStructLoaderTest.java b/src/test/java/jp/co/yahoo/yosegi/spark/inmemory/loader/SparkStructLoaderTest.java
index 2a0b266..80bc78c 100644
--- a/src/test/java/jp/co/yahoo/yosegi/spark/inmemory/loader/SparkStructLoaderTest.java
+++ b/src/test/java/jp/co/yahoo/yosegi/spark/inmemory/loader/SparkStructLoaderTest.java
@@ -220,4 +220,55 @@ void T_load_Struct_1(final String binaryMakerClassName) throws IOException {
// NOTE: assert
assertStruct(values, vector, loadSize, fields);
}
+
+ @ParameterizedTest
+ @MethodSource("D_spreadColumnBinaryMaker")
+ void T_load_Struct_checkDictionaryReset(final String binaryMakerClassName) throws IOException {
+ // NOTE: test data
+ // NOTE: expected
+ final String resource = "SparkStructLoaderTest/Struct_1.txt";
+ final List