Skip to content

Commit

Permalink
reindent
Browse files — browse the repository at this point in the history
  • Loading branch information
tuzhucheng committed Dec 15, 2017
1 parent 2e3e1b3 commit e2c1551
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 10 deletions.
1 change: 1 addition & 0 deletions src/main/scala/largelsh/PairwiseNaive.scala
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ object PairwiseNaive {
val spark = SparkSession
.builder()
.appName("Naive All Pairs Implementation")
.config("spark.driver.maxResultSize", 0)
.getOrCreate()

import spark.implicits._
Expand Down
19 changes: 9 additions & 10 deletions src/main/scala/largelsh/SparkLSHv2.scala
Original file line number Diff line number Diff line change
Expand Up @@ -86,23 +86,22 @@ object SparkLSHv2 {
model.approxSimilarityJoin(transformedA, transformedB, threshold, "EuclideanDistance")

val predictionPoints = transformedB.select("label", "features")
.rdd
.zipWithIndex
.rdd
.zipWithIndex

val seqop = (s: (Double, Double), t: (Double, Double)) => if (t._1 == t._2) (s._1 + 1, s._2 + 1) else (s._1, s._2 + 1)
val combop = (s1: (Double, Double), s2: (Double, Double)) => (s1._1 + s2._1, s1._2 + s2._2)
val groups = testingCount / 1000
val overallAccAndCount = (0L until groups).toList.par.map(mod => {
val predictionsSubset = predictionPoints.filter { case (row, idx) => idx % groups == mod}
.collect.par
val predictionsSubset = predictionPoints.filter { case (row, idx) => idx % groups == mod }.collect.par
val accAndCount = predictionsSubset.map { case (row, idx) => {
val key = row.getAs[org.apache.spark.ml.linalg.SparseVector](1)
val ann = model.approxNearestNeighbors(transformedA, key, k)
val prediction = ann.select("label").groupBy("label").count.sort(desc("label")).first.getDouble(0)
(row.getDouble(0), prediction) // label, prediction
}}.aggregate((0.0, 0.0))(seqop, combop)
val key = row.getAs[org.apache.spark.ml.linalg.SparseVector](1)
val ann = model.approxNearestNeighbors(transformedA, key, k)
val prediction = ann.select("label").groupBy("label").count.sort(desc("label")).first.getDouble(0)
(row.getDouble(0), prediction) // label, prediction
}}.aggregate((0.0, 0.0))(seqop, combop)

accAndCount
accAndCount
}).aggregate((0.0, 0.0))(combop, combop)
val accuracy = overallAccAndCount._1 / overallAccAndCount._2
println("bl:", bl, "nht:", nht, "k:", k, "accuracy:", accuracy)
Expand Down

0 comments on commit e2c1551

Please sign in to comment.