Skip to content

Commit

Permalink
[SPARK-12935][SQL] DataFrame API for Count-Min Sketch
Browse files Browse the repository at this point in the history
This PR integrates Count-Min Sketch from spark-sketch into DataFrame. This version resorts to `RDD.aggregate` for building the sketch. A more performant UDAF version can be built in future follow-up PRs.

Author: Cheng Lian <lian@databricks.com>

Closes apache#10911 from liancheng/cms-df-api.
  • Loading branch information
liancheng authored and rxin committed Jan 27, 2016
1 parent e7f9199 commit ce38a35
Show file tree
Hide file tree
Showing 7 changed files with 205 additions and 37 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,12 @@ public abstract class BloomFilter {
public enum Version {
/**
* {@code BloomFilter} binary format version 1 (all values written in big-endian order):
* - Version number, always 1 (32 bit)
* - Total number of words of the underlying bit array (32 bit)
* - The words/longs (numWords * 64 bit)
* - Number of hash functions (32 bit)
* <ul>
* <li>Version number, always 1 (32 bit)</li>
* <li>Total number of words of the underlying bit array (32 bit)</li>
* <li>The words/longs (numWords * 64 bit)</li>
* <li>Number of hash functions (32 bit)</li>
* </ul>
*/
V1(1);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,16 +59,22 @@ abstract public class CountMinSketch {
public enum Version {
/**
* {@code CountMinSketch} binary format version 1 (all values written in big-endian order):
* - Version number, always 1 (32 bit)
* - Total count of added items (64 bit)
* - Depth (32 bit)
* - Width (32 bit)
* - Hash functions (depth * 64 bit)
* - Count table
* - Row 0 (width * 64 bit)
* - Row 1 (width * 64 bit)
* - ...
* - Row depth - 1 (width * 64 bit)
* <ul>
* <li>Version number, always 1 (32 bit)</li>
* <li>Total count of added items (64 bit)</li>
* <li>Depth (32 bit)</li>
* <li>Width (32 bit)</li>
* <li>Hash functions (depth * 64 bit)</li>
* <li>
* Count table
* <ul>
* <li>Row 0 (width * 64 bit)</li>
* <li>Row 1 (width * 64 bit)</li>
* <li>...</li>
* <li>Row {@code depth - 1} (width * 64 bit)</li>
* </ul>
* </li>
* </ul>
*/
V1(1);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,16 @@
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.OutputStream;
import java.io.Serializable;
import java.io.UnsupportedEncodingException;
import java.util.Arrays;
import java.util.Random;

class CountMinSketchImpl extends CountMinSketch {
public static final long PRIME_MODULUS = (1L << 31) - 1;
class CountMinSketchImpl extends CountMinSketch implements Serializable {
private static final long PRIME_MODULUS = (1L << 31) - 1;

private int depth;
private int width;
Expand All @@ -37,6 +40,9 @@ class CountMinSketchImpl extends CountMinSketch {
private double eps;
private double confidence;

private CountMinSketchImpl() {
}

CountMinSketchImpl(int depth, int width, int seed) {
this.depth = depth;
this.width = width;
Expand All @@ -55,16 +61,6 @@ class CountMinSketchImpl extends CountMinSketch {
initTablesWith(depth, width, seed);
}

CountMinSketchImpl(int depth, int width, long totalCount, long hashA[], long table[][]) {
this.depth = depth;
this.width = width;
this.eps = 2.0 / width;
this.confidence = 1 - 1 / Math.pow(2, depth);
this.hashA = hashA;
this.table = table;
this.totalCount = totalCount;
}

@Override
public boolean equals(Object other) {
if (other == this) {
Expand Down Expand Up @@ -325,27 +321,43 @@ public void writeTo(OutputStream out) throws IOException {
}

public static CountMinSketchImpl readFrom(InputStream in) throws IOException {
CountMinSketchImpl sketch = new CountMinSketchImpl();
sketch.readFrom0(in);
return sketch;
}

private void readFrom0(InputStream in) throws IOException {
DataInputStream dis = new DataInputStream(in);

// Ignores version number
dis.readInt();
int version = dis.readInt();
if (version != Version.V1.getVersionNumber()) {
throw new IOException("Unexpected Count-Min Sketch version number (" + version + ")");
}

long totalCount = dis.readLong();
int depth = dis.readInt();
int width = dis.readInt();
this.totalCount = dis.readLong();
this.depth = dis.readInt();
this.width = dis.readInt();
this.eps = 2.0 / width;
this.confidence = 1 - 1 / Math.pow(2, depth);

long hashA[] = new long[depth];
this.hashA = new long[depth];
for (int i = 0; i < depth; ++i) {
hashA[i] = dis.readLong();
this.hashA[i] = dis.readLong();
}

long table[][] = new long[depth][width];
this.table = new long[depth][width];
for (int i = 0; i < depth; ++i) {
for (int j = 0; j < width; ++j) {
table[i][j] = dis.readLong();
this.table[i][j] = dis.readLong();
}
}
}

private void writeObject(ObjectOutputStream out) throws IOException {
this.writeTo(out);
}

return new CountMinSketchImpl(depth, width, totalCount, hashA, table);
private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
this.readFrom0(in);
}
}
5 changes: 5 additions & 0 deletions sql/core/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,11 @@
<version>1.5.6</version>
<type>jar</type>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-sketch_2.10</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-core_${scala.binary.version}</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ import scala.collection.JavaConverters._

import org.apache.spark.annotation.Experimental
import org.apache.spark.sql.execution.stat._
import org.apache.spark.sql.types._
import org.apache.spark.util.sketch.CountMinSketch

/**
* :: Experimental ::
Expand Down Expand Up @@ -309,4 +311,83 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
def sampleBy[T](col: String, fractions: ju.Map[T, jl.Double], seed: Long): DataFrame = {
sampleBy(col, fractions.asScala.toMap.asInstanceOf[Map[T, Double]], seed)
}

/**
* Builds a Count-min Sketch over a specified column.
*
* @param colName name of the column over which the sketch is built
* @param depth depth of the sketch
* @param width width of the sketch
* @param seed random seed
* @return a [[CountMinSketch]] over column `colName`
* @since 2.0.0
*/
def countMinSketch(colName: String, depth: Int, width: Int, seed: Int): CountMinSketch = {
countMinSketch(Column(colName), depth, width, seed)
}

/**
* Builds a Count-min Sketch over a specified column.
*
* @param colName name of the column over which the sketch is built
* @param eps relative error of the sketch
* @param confidence confidence of the sketch
* @param seed random seed
* @return a [[CountMinSketch]] over column `colName`
* @since 2.0.0
*/
def countMinSketch(
colName: String, eps: Double, confidence: Double, seed: Int): CountMinSketch = {
countMinSketch(Column(colName), eps, confidence, seed)
}

/**
* Builds a Count-min Sketch over a specified column.
*
* @param col the column over which the sketch is built
* @param depth depth of the sketch
* @param width width of the sketch
* @param seed random seed
* @return a [[CountMinSketch]] over column `colName`
* @since 2.0.0
*/
def countMinSketch(col: Column, depth: Int, width: Int, seed: Int): CountMinSketch = {
countMinSketch(col, CountMinSketch.create(depth, width, seed))
}

/**
* Builds a Count-min Sketch over a specified column.
*
* @param col the column over which the sketch is built
* @param eps relative error of the sketch
* @param confidence confidence of the sketch
* @param seed random seed
* @return a [[CountMinSketch]] over column `colName`
* @since 2.0.0
*/
def countMinSketch(col: Column, eps: Double, confidence: Double, seed: Int): CountMinSketch = {
countMinSketch(col, CountMinSketch.create(eps, confidence, seed))
}

private def countMinSketch(col: Column, zero: CountMinSketch): CountMinSketch = {
val singleCol = df.select(col)
val colType = singleCol.schema.head.dataType

require(
colType == StringType || colType.isInstanceOf[IntegralType],
s"Count-min Sketch only supports string type and integral types, " +
s"and does not support type $colType."
)

singleCol.rdd.aggregate(zero)(
(sketch: CountMinSketch, row: Row) => {
sketch.add(row.get(0))
sketch
},

(sketch1: CountMinSketch, sketch2: CountMinSketch) => {
sketch1.mergeInPlace(sketch2)
}
)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,10 @@
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.*;
import static org.apache.spark.sql.functions.*;
import org.apache.spark.sql.test.TestSQLContext;
import org.apache.spark.sql.types.*;
import org.apache.spark.util.sketch.CountMinSketch;
import static org.apache.spark.sql.functions.*;
import static org.apache.spark.sql.types.DataTypes.*;

public class JavaDataFrameSuite {
Expand Down Expand Up @@ -321,4 +322,29 @@ public void testTextLoad() {
Thread.currentThread().getContextClassLoader().getResource("text-suite2.txt").toString());
Assert.assertEquals(5L, df2.count());
}

@Test
public void testCountMinSketch() {
DataFrame df = context.range(1000);

CountMinSketch sketch1 = df.stat().countMinSketch("id", 10, 20, 42);
Assert.assertEquals(sketch1.totalCount(), 1000);
Assert.assertEquals(sketch1.depth(), 10);
Assert.assertEquals(sketch1.width(), 20);

CountMinSketch sketch2 = df.stat().countMinSketch(col("id"), 10, 20, 42);
Assert.assertEquals(sketch2.totalCount(), 1000);
Assert.assertEquals(sketch2.depth(), 10);
Assert.assertEquals(sketch2.width(), 20);

CountMinSketch sketch3 = df.stat().countMinSketch("id", 0.001, 0.99, 42);
Assert.assertEquals(sketch3.totalCount(), 1000);
Assert.assertEquals(sketch3.relativeError(), 0.001, 1e-4);
Assert.assertEquals(sketch3.confidence(), 0.99, 5e-3);

CountMinSketch sketch4 = df.stat().countMinSketch(col("id"), 0.001, 0.99, 42);
Assert.assertEquals(sketch4.totalCount(), 1000);
Assert.assertEquals(sketch4.relativeError(), 0.001, 1e-4);
Assert.assertEquals(sketch4.confidence(), 0.99, 5e-3);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,11 @@ package org.apache.spark.sql

import java.util.Random

import org.scalatest.Matchers._

import org.apache.spark.sql.functions.col
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.DoubleType

class DataFrameStatSuite extends QueryTest with SharedSQLContext {
import testImplicits._
Expand Down Expand Up @@ -210,4 +213,37 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext {
sampled.groupBy("key").count().orderBy("key"),
Seq(Row(0, 6), Row(1, 11)))
}

// This test case only verifies that `DataFrame.countMinSketch()` methods do return
// `CountMinSketch`es that meet required specs. Test cases for `CountMinSketch` can be found in
// `CountMinSketchSuite` in project spark-sketch.
test("countMinSketch") {
val df = sqlContext.range(1000)

val sketch1 = df.stat.countMinSketch("id", depth = 10, width = 20, seed = 42)
assert(sketch1.totalCount() === 1000)
assert(sketch1.depth() === 10)
assert(sketch1.width() === 20)

val sketch2 = df.stat.countMinSketch($"id", depth = 10, width = 20, seed = 42)
assert(sketch2.totalCount() === 1000)
assert(sketch2.depth() === 10)
assert(sketch2.width() === 20)

val sketch3 = df.stat.countMinSketch("id", eps = 0.001, confidence = 0.99, seed = 42)
assert(sketch3.totalCount() === 1000)
assert(sketch3.relativeError() === 0.001)
assert(sketch3.confidence() === 0.99 +- 5e-3)

val sketch4 = df.stat.countMinSketch($"id", eps = 0.001, confidence = 0.99, seed = 42)
assert(sketch4.totalCount() === 1000)
assert(sketch4.relativeError() === 0.001 +- 1e04)
assert(sketch4.confidence() === 0.99 +- 5e-3)

intercept[IllegalArgumentException] {
df.select('id cast DoubleType as 'id)
.stat
.countMinSketch('id, depth = 10, width = 20, seed = 42)
}
}
}

0 comments on commit ce38a35

Please sign in to comment.