Skip to content

Commit

Permalink
lnwire: prep wire messages for TLV extensions
Browse files Browse the repository at this point in the history
Messages:
- UpdateFulfillHTLC
- UpdateFee
- UpdateFailMalformedHTLC
- UpdateFailHTLC
- UpdateAddHTLC
- Shutdown
- RevokeAndAck
- ReplyShortChanIDsEnd
- ReplyChannelRange
- QueryShortChanIDs
- QueryChannelRange
- NodeAnnouncement
- Init
- GossipTimestampRange
- FundingSigned
- FundingLocked
- FundingCreated
- CommitSig
- ClosingSigned
- ChannelUpdate
- ChannelReestablish
- ChannelAnnouncement
- AnnounceSignatures

lnwire: update quickcheck tests, use constant for Error

multi: update unit tests to pass deep equal assertions with messages

In this commit, we update a series of unit tests in the code base to now
pass due to the new wire message encode/decode logic. In many instances,
we'll now manually set the extra bytes to an empty byte slice to avoid
comparisons that fail due to one message having an empty byte slice and
the other having a nil pointer.
  • Loading branch information
Roasbeef authored and halseth committed Feb 24, 2021
1 parent a603ac4 commit 9a6bb19
Show file tree
Hide file tree
Showing 32 changed files with 278 additions and 231 deletions.
22 changes: 14 additions & 8 deletions channeldb/channel_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -607,7 +607,8 @@ func TestChannelStateTransition(t *testing.T) {
{
LogIndex: 2,
UpdateMsg: &lnwire.UpdateAddHTLC{
ChanID: lnwire.ChannelID{1, 2, 3},
ChanID: lnwire.ChannelID{1, 2, 3},
ExtraData: make([]byte, 0),
},
},
}
Expand All @@ -628,7 +629,9 @@ func TestChannelStateTransition(t *testing.T) {
if !reflect.DeepEqual(
dbUnsignedAckedUpdates[0], unsignedAckedUpdates[0],
) {
t.Fatalf("unexpected update")
t.Fatalf("unexpected update: expected %v, got %v",
spew.Sdump(unsignedAckedUpdates[0]),
spew.Sdump(dbUnsignedAckedUpdates))
}

// The balances, new update, the HTLCs and the changes to the fake
Expand Down Expand Up @@ -670,22 +673,25 @@ func TestChannelStateTransition(t *testing.T) {
wireSig,
wireSig,
},
ExtraData: make([]byte, 0),
},
LogUpdates: []LogUpdate{
{
LogIndex: 1,
UpdateMsg: &lnwire.UpdateAddHTLC{
ID: 1,
Amount: lnwire.NewMSatFromSatoshis(100),
Expiry: 25,
ID: 1,
Amount: lnwire.NewMSatFromSatoshis(100),
Expiry: 25,
ExtraData: make([]byte, 0),
},
},
{
LogIndex: 2,
UpdateMsg: &lnwire.UpdateAddHTLC{
ID: 2,
Amount: lnwire.NewMSatFromSatoshis(200),
Expiry: 50,
ID: 2,
Amount: lnwire.NewMSatFromSatoshis(200),
Expiry: 50,
ExtraData: make([]byte, 0),
},
},
},
Expand Down
5 changes: 4 additions & 1 deletion channeldb/waitingproof_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (

"reflect"

"github.com/davecgh/go-spew/spew"
"github.com/go-errors/errors"
"github.com/lightningnetwork/lnd/lnwire"
)
Expand All @@ -23,6 +24,7 @@ func TestWaitingProofStore(t *testing.T) {
proof1 := NewWaitingProof(true, &lnwire.AnnounceSignatures{
NodeSignature: wireSig,
BitcoinSignature: wireSig,
ExtraOpaqueData: make([]byte, 0),
})

store, err := NewWaitingProofStore(db)
Expand All @@ -40,7 +42,8 @@ func TestWaitingProofStore(t *testing.T) {
t.Fatalf("unable retrieve proof from storage: %v", err)
}
if !reflect.DeepEqual(proof1, proof2) {
t.Fatal("wrong proof retrieved")
t.Fatalf("wrong proof retrieved: expected %v, got %v",
spew.Sdump(proof1), spew.Sdump(proof2))
}

if _, err := store.Get(proof1.OppositeKey()); err != ErrWaitingProofNotFound {
Expand Down
6 changes: 4 additions & 2 deletions discovery/message_store_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,13 +64,15 @@ func randCompressedPubKey(t *testing.T) [33]byte {

func randAnnounceSignatures() *lnwire.AnnounceSignatures {
return &lnwire.AnnounceSignatures{
ShortChannelID: lnwire.NewShortChanIDFromInt(rand.Uint64()),
ShortChannelID: lnwire.NewShortChanIDFromInt(rand.Uint64()),
ExtraOpaqueData: make([]byte, 0),
}
}

func randChannelUpdate() *lnwire.ChannelUpdate {
return &lnwire.ChannelUpdate{
ShortChannelID: lnwire.NewShortChanIDFromInt(rand.Uint64()),
ShortChannelID: lnwire.NewShortChanIDFromInt(rand.Uint64()),
ExtraOpaqueData: make([]byte, 0),
}
}

Expand Down
15 changes: 9 additions & 6 deletions htlcswitch/payment_result_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,18 +39,21 @@ func TestNetworkResultSerialization(t *testing.T) {
ChanID: chanID,
ID: 2,
PaymentPreimage: preimage,
ExtraData: make([]byte, 0),
}

fail := &lnwire.UpdateFailHTLC{
ChanID: chanID,
ID: 1,
Reason: []byte{},
ChanID: chanID,
ID: 1,
Reason: []byte{},
ExtraData: make([]byte, 0),
}

fail2 := &lnwire.UpdateFailHTLC{
ChanID: chanID,
ID: 1,
Reason: reason[:],
ChanID: chanID,
ID: 1,
Reason: reason[:],
ExtraData: make([]byte, 0),
}

testCases := []*networkResult{
Expand Down
2 changes: 2 additions & 0 deletions lnwallet/channel_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3176,6 +3176,7 @@ func TestChanSyncOweCommitment(t *testing.T) {
Amount: htlcAmt,
Expiry: uint32(10),
OnionBlob: fakeOnionBlob,
ExtraData: make([]byte, 0),
}

htlcIndex, err := bobChannel.AddHTLC(h, nil)
Expand Down Expand Up @@ -3220,6 +3221,7 @@ func TestChanSyncOweCommitment(t *testing.T) {
Amount: htlcAmt,
Expiry: uint32(10),
OnionBlob: fakeOnionBlob,
ExtraData: make([]byte, 0),
}
aliceHtlcIndex, err := aliceChannel.AddHTLC(aliceHtlc, nil)
if err != nil {
Expand Down
25 changes: 4 additions & 21 deletions lnwire/announcement_signatures.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package lnwire

import (
"io"
"io/ioutil"
)

// AnnounceSignatures is a direct message between two endpoints of a
Expand Down Expand Up @@ -40,7 +39,7 @@ type AnnounceSignatures struct {
// properly validate the set of signatures that cover these new fields,
// and ensure we're able to make upgrades to the network in a forwards
// compatible manner.
ExtraOpaqueData []byte
ExtraOpaqueData ExtraOpaqueData
}

// A compile time check to ensure AnnounceSignatures implements the
Expand All @@ -52,29 +51,13 @@ var _ Message = (*AnnounceSignatures)(nil)
//
// This is part of the lnwire.Message interface.
func (a *AnnounceSignatures) Decode(r io.Reader, pver uint32) error {
err := ReadElements(r,
return ReadElements(r,
&a.ChannelID,
&a.ShortChannelID,
&a.NodeSignature,
&a.BitcoinSignature,
&a.ExtraOpaqueData,
)
if err != nil {
return err
}

// Now that we've read out all the fields that we explicitly know of,
// we'll collect the remainder into the ExtraOpaqueData field. If there
// aren't any bytes, then we'll snip off the slice to avoid carrying
// around excess capacity.
a.ExtraOpaqueData, err = ioutil.ReadAll(r)
if err != nil {
return err
}
if len(a.ExtraOpaqueData) == 0 {
a.ExtraOpaqueData = nil
}

return nil
}

// Encode serializes the target AnnounceSignatures into the passed io.Writer
Expand Down Expand Up @@ -104,5 +87,5 @@ func (a *AnnounceSignatures) MsgType() MessageType {
//
// This is part of the lnwire.Message interface.
func (a *AnnounceSignatures) MaxPayloadLength(pver uint32) uint32 {
return 65533
return MaxMsgBody
}
25 changes: 4 additions & 21 deletions lnwire/channel_announcement.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package lnwire
import (
"bytes"
"io"
"io/ioutil"

"github.com/btcsuite/btcd/chaincfg/chainhash"
)
Expand Down Expand Up @@ -56,7 +55,7 @@ type ChannelAnnouncement struct {
// properly validate the set of signatures that cover these new fields,
// and ensure we're able to make upgrades to the network in a forwards
// compatible manner.
ExtraOpaqueData []byte
ExtraOpaqueData ExtraOpaqueData
}

// A compile time check to ensure ChannelAnnouncement implements the
Expand All @@ -68,7 +67,7 @@ var _ Message = (*ChannelAnnouncement)(nil)
//
// This is part of the lnwire.Message interface.
func (a *ChannelAnnouncement) Decode(r io.Reader, pver uint32) error {
err := ReadElements(r,
return ReadElements(r,
&a.NodeSig1,
&a.NodeSig2,
&a.BitcoinSig1,
Expand All @@ -80,24 +79,8 @@ func (a *ChannelAnnouncement) Decode(r io.Reader, pver uint32) error {
&a.NodeID2,
&a.BitcoinKey1,
&a.BitcoinKey2,
&a.ExtraOpaqueData,
)
if err != nil {
return err
}

// Now that we've read out all the fields that we explicitly know of,
// we'll collect the remainder into the ExtraOpaqueData field. If there
// aren't any bytes, then we'll snip off the slice to avoid carrying
// around excess capacity.
a.ExtraOpaqueData, err = ioutil.ReadAll(r)
if err != nil {
return err
}
if len(a.ExtraOpaqueData) == 0 {
a.ExtraOpaqueData = nil
}

return nil
}

// Encode serializes the target ChannelAnnouncement into the passed io.Writer
Expand Down Expand Up @@ -134,7 +117,7 @@ func (a *ChannelAnnouncement) MsgType() MessageType {
//
// This is part of the lnwire.Message interface.
func (a *ChannelAnnouncement) MaxPayloadLength(pver uint32) uint32 {
return 65533
return MaxMsgBody
}

// DataToSign is used to retrieve part of the announcement message which should
Expand Down
47 changes: 25 additions & 22 deletions lnwire/channel_reestablish.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,11 @@ type ChannelReestablish struct {
// LocalUnrevokedCommitPoint is the commitment point used in the
// current un-revoked commitment transaction of the sending party.
LocalUnrevokedCommitPoint *btcec.PublicKey

// ExtraData is the set of data that was appended to this message to
// fill out the full maximum transport message size. These fields can
// be used to specify optional data such as custom TLV fields.
ExtraData ExtraOpaqueData
}

// A compile time check to ensure ChannelReestablish implements the
Expand All @@ -83,12 +88,20 @@ func (a *ChannelReestablish) Encode(w io.Writer, pver uint32) error {
// If the commit point wasn't sent, then we won't write out any of the
// remaining fields as they're optional.
if a.LocalUnrevokedCommitPoint == nil {
return nil
// However, we'll still write out the extra data if it's
// present.
//
// NOTE: This is here primarily for the quickcheck tests, in
// practice, we'll always populate this field.
return WriteElements(w, a.ExtraData)
}

// Otherwise, we'll write out the remaining elements.
return WriteElements(w, a.LastRemoteCommitSecret[:],
a.LocalUnrevokedCommitPoint)
return WriteElements(w,
a.LastRemoteCommitSecret[:],
a.LocalUnrevokedCommitPoint,
a.ExtraData,
)
}

// Decode deserializes a serialized ChannelReestablish stored in the passed
Expand Down Expand Up @@ -118,6 +131,9 @@ func (a *ChannelReestablish) Decode(r io.Reader, pver uint32) error {
var buf [32]byte
_, err = io.ReadFull(r, buf[:32])
if err == io.EOF {
// If there aren't any more bytes, then we'll emplace an empty
// extra data to make our quickcheck tests happy.
a.ExtraData = make([]byte, 0)
return nil
} else if err != nil {
return err
Expand All @@ -129,7 +145,11 @@ func (a *ChannelReestablish) Decode(r io.Reader, pver uint32) error {
// We'll conclude by parsing out the commitment point. We don't check
// the error in this case, as it has included the commit secret, then
// they MUST also include the commit point.
return ReadElement(r, &a.LocalUnrevokedCommitPoint)
if err = ReadElement(r, &a.LocalUnrevokedCommitPoint); err != nil {
return err
}

return a.ExtraData.Decode(r)
}

// MsgType returns the integer uniquely identifying this message type on the
Expand All @@ -145,22 +165,5 @@ func (a *ChannelReestablish) MsgType() MessageType {
//
// This is part of the lnwire.Message interface.
func (a *ChannelReestablish) MaxPayloadLength(pver uint32) uint32 {
var length uint32

// ChanID - 32 bytes
length += 32

// NextLocalCommitHeight - 8 bytes
length += 8

// RemoteCommitTailHeight - 8 bytes
length += 8

// LastRemoteCommitSecret - 32 bytes
length += 32

// LocalUnrevokedCommitPoint - 33 bytes
length += 33

return length
return MaxMsgBody
}
Loading

0 comments on commit 9a6bb19

Please sign in to comment.