Skip to content

Commit

Permalink
optimize Chain.Ops() (#99)
Browse files Browse the repository at this point in the history
Use linear 2SUM-like algorithm in the typical case where the chain is
ascending.

Updates #60
Updates #25
  • Loading branch information
mmcloughlin authored May 12, 2021
1 parent c4e8a5d commit 23389d5
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 0 deletions.
38 changes: 38 additions & 0 deletions chain.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ import (
// [efficientcompaddchain] Bergeron, F., Berstel, J. and Brlek, S. Efficient computation of addition
// chains. Journal de theorie des nombres de Bordeaux. 1994.
// http://www.numdam.org/item/JTNB_1994__6_1_21_0
// [knuth] Knuth, Donald E. Evaluation of Powers. The Art of Computer Programming, Volume 2
// (Third Edition): Seminumerical Algorithms, chapter 4.6.3. 1997.
// https://www-cs-faculty.stanford.edu/~knuth/taocp.html

// Chain is an addition chain.
type Chain []*big.Int
Expand Down Expand Up @@ -49,6 +52,25 @@ func (c Chain) End() *big.Int {
func (c Chain) Ops(k int) []Op {
ops := []Op{}
s := new(big.Int)

// If the prefix is ascending this can be done in linear time.
if c[:k].IsAscending() {
for l, r := 0, k-1; l <= r; {
s.Add(c[l], c[r])
cmp := s.Cmp(c[k])
if cmp == 0 {
ops = append(ops, Op{l, r})
}
if cmp <= 0 {
l++
} else {
r--
}
}
return ops
}

// Fallback to quadratic.
for i := 0; i < k; i++ {
for j := i; j < k; j++ {
s.Add(c[i], c[j])
Expand All @@ -57,6 +79,7 @@ func (c Chain) Ops(k int) []Op {
}
}
}

return ops
}

Expand Down Expand Up @@ -135,6 +158,21 @@ func (c Chain) Superset(targets []*big.Int) error {
return nil
}

// IsAscending reports whether the chain is ascending, that is if it's in sorted
// order without repeats, as defined in [knuth] Section 4.6.3 formula (11).
// Does not fully validate the chain, only that it is ascending.
func (c Chain) IsAscending() bool {
if len(c) == 0 || !bigint.EqualInt64(c[0], 1) {
return false
}
for i := 1; i < len(c); i++ {
if c[i-1].Cmp(c[i]) >= 0 {
return false
}
}
return true
}

// Product computes the product of two addition chains. The is the "o times"
// operator defined in [efficientcompaddchain] Section 2.
func Product(a, b Chain) Chain {
Expand Down
73 changes: 73 additions & 0 deletions chain_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,79 @@ import (
"testing"
)

func TestChainOps(t *testing.T) {
cases := []struct {
Name string
Chain Chain
Expect [][]Op
}{
{
Name: "short",
Chain: Int64s(1, 2),
Expect: [][]Op{
{{0, 0}}, // 2
},
},
{
Name: "multiple_choices",
Chain: Int64s(1, 2, 3, 4),
Expect: [][]Op{
{{0, 0}}, // 2
{{0, 1}}, // 3
{{0, 2}, {1, 1}}, // 4
},
},
{
Name: "non_ascending",
Chain: Int64s(1, 2, 3, 4, 7, 5, 6),
Expect: [][]Op{
{{0, 0}}, // 2
{{0, 1}}, // 3
{{0, 2}, {1, 1}}, // 4
{{2, 3}}, // 7
{{0, 3}, {1, 2}}, // 5
{{0, 5}, {1, 3}, {2, 2}}, // 6
},
},
}
for _, c := range cases {
c := c // scopelint
t.Run(c.Name, func(t *testing.T) {
var got [][]Op
for k := 1; k < len(c.Chain); k++ {
got = append(got, c.Chain.Ops(k))
}
if !reflect.DeepEqual(got, c.Expect) {
t.Logf("got = %v", got)
t.Logf("expect = %v", c.Expect)
t.Fail()
}
})
}
}

func TestChainIsAscending(t *testing.T) {
cases := []struct {
Name string
Chain Chain
Expect bool
}{
{Name: "empty", Chain: Int64s(), Expect: false},
{Name: "does_not_start_with_one", Chain: Int64s(42), Expect: false},
{Name: "ascending", Chain: Int64s(1, 2, 3, 5, 8), Expect: true},
{Name: "repeat", Chain: Int64s(1, 2, 3, 3, 8), Expect: false},
{Name: "not_sorted", Chain: Int64s(1, 2, 3, 4, 7, 5, 6), Expect: false},
}
for _, c := range cases {
c := c // scopelint
t.Run(c.Name, func(t *testing.T) {
if got := c.Chain.IsAscending(); got != c.Expect {
t.Fatalf("%v.IsAscending() = %v; expect %v", c.Chain, got, c.Expect)
}
})
}
}

func TestProduct(t *testing.T) {
a := Int64s(1, 2, 4, 6, 10)
b := Int64s(1, 2, 4, 8)
Expand Down

0 comments on commit 23389d5

Please sign in to comment.