Skip to content

Commit

Permalink
Improve performance and behavior of ListMap and ListSet
Browse files Browse the repository at this point in the history
Makes the immutable `ListMap` and `ListSet` collections more alike one another, both in their semantics and in their performance.

In terms of semantics, makes the `ListMap` iterator return the elements in reverse order, as `ListSet` already does (improving its performance as a side-effect). While, as mentioned in SI-8985, `ListMap` and `ListSet` doesn't seem to make any guarantees in terms of iteration order, I believe users expect `ListSet` and `ListMap` to behave in the same way, particularly when they are implemented in the exact same way.

In terms of performance, `ListMap` is given a custom builder that avoids creation in O(N^2) time using a strategy similar to the one already applied in `ListSet`. `ListSet`'s element removal method was not tail-recursive as the `ListMap` one, so a tail-recursive implementation was added.
  • Loading branch information
ruippeixotog committed Apr 30, 2016
1 parent 4c4c5e6 commit bb05a4d
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 14 deletions.
45 changes: 39 additions & 6 deletions src/library/scala/collection/immutable/ListMap.scala
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ package immutable

import generic._
import scala.annotation.tailrec
import mutable.{Builder, ReusableBuilder}

/** $factoryInfo
* @since 1
Expand All @@ -32,13 +33,45 @@ object ListMap extends ImmutableMapFactory[ListMap] {
new MapCanBuildFrom[A, B]
def empty[A, B]: ListMap[A, B] = EmptyListMap.asInstanceOf[ListMap[A, B]]

override def newBuilder[A, B]: Builder[(A, B), ListMap[A, B]] = new ListMapBuilder[A, B]

@SerialVersionUID(-8256686706655863282L)
private object EmptyListMap extends ListMap[Any, Nothing] {
override def apply(key: Any) = throw new NoSuchElementException("key not found: " + key)
override def contains(key: Any) = false
override def last: (Any, Nothing) = throw new NoSuchElementException("Empty ListMap")
override def init: ListMap[Any, Nothing] = throw new NoSuchElementException("Empty ListMap")
}

/**
* A custom builder because forgetfully adding elements one at a time to a list backed map puts
* the "squared" in N^2. There is a temporary space cost, but it's improbable a list backed set
* could become large enough for this to matter given its pricy element lookup.
*
* This builder is reusable.
**/
class ListMapBuilder[A, B](initial: ListMap[A, B]) extends ReusableBuilder[(A, B), ListMap[A, B]] {
def this() = this(empty[A, B])
protected val elems = (new mutable.ListBuffer[(A, B)] ++= initial).reverse
protected val seen = new mutable.HashMap[A, B] ++= initial

def +=(kv: (A, B)): this.type = {
seen.get(kv._1) match {
case Some(v) =>
if(v != kv._2) {
elems -= (kv._1 -> v) += kv
seen += kv
}
case None =>
elems += kv
seen += kv
}
this
}

def clear() = { elems.clear(); seen.clear() }
def result() = elems.foldLeft(empty[A, B])(_ unchecked_+ _)
}
}

/** This class implements immutable maps using a list-based data structure, which preserves insertion order.
Expand Down Expand Up @@ -123,16 +156,18 @@ extends AbstractMap[A, B]
*/
def - (key: A): ListMap[A, B] = this

private[ListMap] def unchecked_+[B1 >: B](kv: (A, B1)): ListMap[A, B1] = new Node(kv._1, kv._2)

/** Returns an iterator over key-value pairs.
*/
def iterator: Iterator[(A,B)] =
new AbstractIterator[(A,B)] {
var self: ListMap[A,B] = ListMap.this
def hasNext = !self.isEmpty
def next(): (A,B) =
if (!hasNext) throw new NoSuchElementException("next on empty iterator")
else { val res = (self.key, self.value); self = self.next; res }
}.toList.reverseIterator
if (hasNext) { val res = (self.key, self.value); self = self.next; res }
else Iterator.empty.next()
}

protected def key: A = throw new NoSuchElementException("empty map")
protected def value: B = throw new NoSuchElementException("empty map")
Expand Down Expand Up @@ -213,9 +248,7 @@ extends AbstractMap[A, B]
if (cur.isEmpty)
acc.last
else if (k == cur.key)
(cur.next /: acc) {
case (t, h) => val tt = t; new tt.Node(h.key, h.value) // SI-7459
}
(cur.next /: acc) { case (t, h) => new t.Node(h.key, h.value) }
else
remove0(k, cur.next, cur::acc)

Expand Down
18 changes: 10 additions & 8 deletions src/library/scala/collection/immutable/ListSet.scala
Original file line number Diff line number Diff line change
Expand Up @@ -119,11 +119,7 @@ sealed class ListSet[A] extends AbstractSet[A]
var that: ListSet[A] = self
def hasNext = that.nonEmpty
def next: A =
if (hasNext) {
val res = that.head
that = that.tail
res
}
if (hasNext) { val res = that.head; that = that.tail; res }
else Iterator.empty.next()
}

Expand Down Expand Up @@ -174,9 +170,15 @@ sealed class ListSet[A] extends AbstractSet[A]

/** `-` can be used to remove a single element from a set.
*/
override def -(e: A): ListSet[A] = if (e == head) self else {
val tail = self - e; new tail.Node(head)
}
override def -(e: A): ListSet[A] = removeInternal(e, this, Nil)

@tailrec private def removeInternal(k: A, cur: ListSet[A], acc: List[ListSet[A]]): ListSet[A] =
if (cur.isEmpty)
acc.last
else if (k == cur.head)
(cur.tail /: acc) { case (t, h) => new t.Node(h.head) }
else
removeInternal(k, cur.tail, cur :: acc)

override def tail: ListSet[A] = self
}
Expand Down
16 changes: 16 additions & 0 deletions test/junit/scala/collection/immutable/ListMapTest.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package scala.collection.immutable

import org.junit.Assert._
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.JUnit4

@RunWith(classOf[JUnit4])
class ListMapTest {

@Test
def hasCorrectBuilder(): Unit = {
val m = ListMap[String, String]("a" -> "1", "b" -> "2", "c" -> "3", "b" -> "2.2", "d" -> "4")
assertEquals(List("d" -> "4", "b" -> "2.2", "c" -> "3", "a" -> "1"), m.toList)
}
}
16 changes: 16 additions & 0 deletions test/junit/scala/collection/immutable/ListSetTest.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package scala.collection.immutable

import org.junit.Assert._
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.JUnit4

@RunWith(classOf[JUnit4])
class ListSetTest {

@Test
def hasTailRecursiveDelete(): Unit = {
val s = ListSet[Int](1 to 50000: _*)
try s - 25000 catch { case e: StackOverflowError => fail("A stack overflow occurred") }
}
}

0 comments on commit bb05a4d

Please sign in to comment.