Skip to content

Commit

Permalink
Better error reporting from decoder
Browse files Browse the repository at this point in the history
The parser works in three stages:

1. Lex the TOML.
2. Parse the TOML values (e.g. parse integers to int64)
3. Decode that in to the destination type (struct or map).

Reporting good errors with position info from stage 1 and 2 was already
possible, but not from stage 3. In some cases we only know if something
is wrong in stage 3, for example trying to parse "200" in to an int8: we
need the type info for that, and we can't do that in stage 1 or 2 since
we don't have it.

This copies a bit more data, and is slightly slower, but only by about
1% so it's acceptable.
  • Loading branch information
arp242 committed May 28, 2022
1 parent 201477d commit 2004196
Show file tree
Hide file tree
Showing 6 changed files with 348 additions and 141 deletions.
140 changes: 67 additions & 73 deletions decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,10 +123,10 @@ func (dec *Decoder) Decode(v interface{}) (MetaData, error) {
s = "%v"
}

return MetaData{}, e("cannot decode to non-pointer "+s, reflect.TypeOf(v))
return MetaData{}, fmt.Errorf("toml: cannot decode to non-pointer "+s, reflect.TypeOf(v))
}
if rv.IsNil() {
return MetaData{}, e("cannot decode to nil value of %q", reflect.TypeOf(v))
return MetaData{}, fmt.Errorf("toml: cannot decode to nil value of %q", reflect.TypeOf(v))
}

// Check if this is a supported type: struct, map, interface{}, or something
Expand All @@ -136,7 +136,7 @@ func (dec *Decoder) Decode(v interface{}) (MetaData, error) {
if rv.Kind() != reflect.Struct && rv.Kind() != reflect.Map &&
!(rv.Kind() == reflect.Interface && rv.NumMethod() == 0) &&
!rt.Implements(unmarshalToml) && !rt.Implements(unmarshalText) {
return MetaData{}, e("cannot decode to type %s", rt)
return MetaData{}, fmt.Errorf("toml: cannot decode to type %s", rt)
}

// TODO: parser should read from io.Reader? Or at the very least, make it
Expand All @@ -153,10 +153,11 @@ func (dec *Decoder) Decode(v interface{}) (MetaData, error) {

md := MetaData{
mapping: p.mapping,
types: p.types,
keyInfo: p.keyInfo,
keys: p.ordered,
decoded: make(map[string]struct{}, len(p.ordered)),
context: nil,
data: data,
}
return md, md.unify(p.mapping, rv)
}
Expand Down Expand Up @@ -242,15 +243,14 @@ func (md *MetaData) unify(data interface{}, rv reflect.Value) error {
case reflect.Bool:
return md.unifyBool(data, rv)
case reflect.Interface:
// we only support empty interfaces.
if rv.NumMethod() > 0 {
return e("unsupported type %s", rv.Type())
if rv.NumMethod() > 0 { // Only support empty interfaces are supported.
return md.e("unsupported type %s", rv.Type())
}
return md.unifyAnything(data, rv)
case reflect.Float32, reflect.Float64:
return md.unifyFloat64(data, rv)
}
return e("unsupported type %s", rv.Kind())
return md.e("unsupported type %s", rv.Kind())
}

func (md *MetaData) unifyStruct(mapping interface{}, rv reflect.Value) error {
Expand All @@ -259,7 +259,7 @@ func (md *MetaData) unifyStruct(mapping interface{}, rv reflect.Value) error {
if mapping == nil {
return nil
}
return e("type mismatch for %s: expected table but found %T",
return md.e("type mismatch for %s: expected table but found %T",
rv.Type().String(), mapping)
}

Expand Down Expand Up @@ -291,7 +291,7 @@ func (md *MetaData) unifyStruct(mapping interface{}, rv reflect.Value) error {
}
md.context = md.context[0 : len(md.context)-1]
} else if f.name != "" {
return e("cannot write unexported field %s.%s", rv.Type().String(), f.name)
return md.e("cannot write unexported field %s.%s", rv.Type().String(), f.name)
}
}
}
Expand Down Expand Up @@ -341,7 +341,7 @@ func (md *MetaData) unifyArray(data interface{}, rv reflect.Value) error {
return md.badtype("slice", data)
}
if l := datav.Len(); l != rv.Len() {
return e("expected array length %d; got TOML array of length %d", rv.Len(), l)
return md.e("expected array length %d; got TOML array of length %d", rv.Len(), l)
}
return md.unifySliceArray(datav, rv)
}
Expand Down Expand Up @@ -382,11 +382,13 @@ func (md *MetaData) unifyString(data interface{}, rv reflect.Value) error {
}

func (md *MetaData) unifyFloat64(data interface{}, rv reflect.Value) error {
rvk := rv.Kind()

if num, ok := data.(float64); ok {
switch rv.Kind() {
switch rvk {
case reflect.Float32:
if num < -math.MaxFloat32 || num > math.MaxFloat32 {
return e("value %f is out of range for float32", num)
return md.parseErr(errParseRange{i: num, size: rvk.String()})
}
fallthrough
case reflect.Float64:
Expand All @@ -398,71 +400,44 @@ func (md *MetaData) unifyFloat64(data interface{}, rv reflect.Value) error {
}

if num, ok := data.(int64); ok {
switch rv.Kind() {
case reflect.Float32:
if num < -maxSafeFloat32Int || num > maxSafeFloat32Int {
return e("value %d is out of range for float32", num)
}
fallthrough
case reflect.Float64:
if num < -maxSafeFloat64Int || num > maxSafeFloat64Int {
return e("value %d is out of range for float64", num)
}
rv.SetFloat(float64(num))
default:
panic("bug")
if (rvk == reflect.Float32 && (num < -maxSafeFloat32Int || num > maxSafeFloat32Int)) ||
(rvk == reflect.Float64 && (num < -maxSafeFloat64Int || num > maxSafeFloat64Int)) {
return md.parseErr(errParseRange{i: num, size: rvk.String()})
}
rv.SetFloat(float64(num))
return nil
}

return md.badtype("float", data)
}

func (md *MetaData) unifyInt(data interface{}, rv reflect.Value) error {
if num, ok := data.(int64); ok {
if rv.Kind() >= reflect.Int && rv.Kind() <= reflect.Int64 {
switch rv.Kind() {
case reflect.Int, reflect.Int64:
// No bounds checking necessary.
case reflect.Int8:
if num < math.MinInt8 || num > math.MaxInt8 {
return e("value %d is out of range for int8", num)
}
case reflect.Int16:
if num < math.MinInt16 || num > math.MaxInt16 {
return e("value %d is out of range for int16", num)
}
case reflect.Int32:
if num < math.MinInt32 || num > math.MaxInt32 {
return e("value %d is out of range for int32", num)
}
}
rv.SetInt(num)
} else if rv.Kind() >= reflect.Uint && rv.Kind() <= reflect.Uint64 {
unum := uint64(num)
switch rv.Kind() {
case reflect.Uint, reflect.Uint64:
// No bounds checking necessary.
case reflect.Uint8:
if num < 0 || unum > math.MaxUint8 {
return e("value %d is out of range for uint8", num)
}
case reflect.Uint16:
if num < 0 || unum > math.MaxUint16 {
return e("value %d is out of range for uint16", num)
}
case reflect.Uint32:
if num < 0 || unum > math.MaxUint32 {
return e("value %d is out of range for uint32", num)
}
}
rv.SetUint(unum)
} else {
panic("unreachable")
num, ok := data.(int64)
if !ok {
return md.badtype("integer", data)
}

rvk := rv.Kind()
switch {
case rvk >= reflect.Int && rvk <= reflect.Int64:
if (rvk == reflect.Int8 && (num < math.MinInt8 || num > math.MaxInt8)) ||
(rvk == reflect.Int16 && (num < math.MinInt16 || num > math.MaxInt16)) ||
(rvk == reflect.Int32 && (num < math.MinInt32 || num > math.MaxInt32)) {
return md.parseErr(errParseRange{i: num, size: rvk.String()})
}
return nil
rv.SetInt(num)
case rvk >= reflect.Uint && rvk <= reflect.Uint64:
unum := uint64(num)
if rvk == reflect.Uint8 && (num < 0 || unum > math.MaxUint8) ||
rvk == reflect.Uint16 && (num < 0 || unum > math.MaxUint16) ||
rvk == reflect.Uint32 && (num < 0 || unum > math.MaxUint32) {
return md.parseErr(errParseRange{i: num, size: rvk.String()})
}
rv.SetUint(unum)
default:
panic("unreachable")
}
return md.badtype("integer", data)
return nil
}

func (md *MetaData) unifyBool(data interface{}, rv reflect.Value) error {
Expand Down Expand Up @@ -513,7 +488,30 @@ func (md *MetaData) unifyText(data interface{}, v encoding.TextUnmarshaler) erro
}

func (md *MetaData) badtype(dst string, data interface{}) error {
return e("incompatible types: TOML key %q has type %T; destination has type %s", md.context, data, dst)
return md.e("incompatible types: TOML value has type %T; destination has type %s", data, dst)
}

func (md *MetaData) parseErr(err error) error {
k := md.context.String()
return ParseError{
LastKey: k,
Position: md.keyInfo[k].pos,
Line: md.keyInfo[k].pos.Line,
err: err,
input: string(md.data),
}
}

func (md *MetaData) e(format string, args ...interface{}) error {
f := "toml: "
if len(md.context) > 0 {
f = fmt.Sprintf("toml: (last key %q): ", md.context)
p := md.keyInfo[md.context.String()].pos
if p.Line > 0 {
f = fmt.Sprintf("toml: line %d (last key %q): ", p.Line, md.context)
}
}
return fmt.Errorf(f+format, args...)
}

// rvalue returns a reflect.Value of `v`. All pointers are resolved.
Expand Down Expand Up @@ -561,7 +559,3 @@ func isUnifiable(rv reflect.Value) bool {
}
return false
}

func e(format string, args ...interface{}) error {
return fmt.Errorf("toml: "+format, args...)
}
93 changes: 87 additions & 6 deletions decode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,87 @@ func TestDecodeEmbedded(t *testing.T) {
}
}

func TestDecodeErrors(t *testing.T) {
tests := []struct {
s interface{}
toml string
wantErr string
}{
{
&struct{ V int8 }{},
`V = 999`,
`toml: line 1 (last key "V"): 999 is out of range for int8`,
},
{
&struct{ V float32 }{},
`V = 999999999999999`,
`toml: line 1 (last key "V"): 999999999999999 is out of range for float32`,
},
{
&struct{ V string }{},
`V = 5`,
`toml: line 1 (last key "V"): incompatible types: TOML value has type int64; destination has type string`,
},
{
&struct{ V interface{ ASD() } }{},
`V = 999`,
`toml: line 1 (last key "V"): unsupported type interface { ASD() }`,
},
{
&struct{ V struct{ V int } }{},
`V = 999`,
`toml: line 1 (last key "V"): type mismatch for struct { V int }: expected table but found int64`,
},
{
&struct{ V [1]int }{},
`V = [1,2,3]`,
`toml: line 1 (last key "V"): expected array length 1; got TOML array of length 3`,
},
{
&struct{ V struct{ N int8 } }{},
`V.N = 999`,
`toml: line 1 (last key "V.N"): 999 is out of range for int8`,
},
{
&struct{ V struct{ N float32 } }{},
`V.N = 999999999999999`,
`toml: line 1 (last key "V.N"): 999999999999999 is out of range for float32`,
},
{
&struct{ V struct{ N string } }{},
`V.N = 5`,
`toml: line 1 (last key "V.N"): incompatible types: TOML value has type int64; destination has type string`,
},
{
&struct {
V struct{ N interface{ ASD() } }
}{},
`V.N = 999`,
`toml: line 1 (last key "V.N"): unsupported type interface { ASD() }`,
},
{
&struct{ V struct{ N struct{ V int } } }{},
`V.N = 999`,
`toml: line 1 (last key "V.N"): type mismatch for struct { V int }: expected table but found int64`,
},
{
&struct{ V struct{ N [1]int } }{},
`V.N = [1,2,3]`,
`toml: line 1 (last key "V.N"): expected array length 1; got TOML array of length 3`,
},
}

for _, tt := range tests {
_, err := Decode(tt.toml, tt.s)
if err == nil {
t.Fatal("err is nil")
}
if err.Error() != tt.wantErr {
t.Errorf("\nhave: %q\nwant: %q", err, tt.wantErr)
}
}
}

func TestDecodeIgnoreFields(t *testing.T) {
const input = `
Number = 123
Expand Down Expand Up @@ -368,11 +449,11 @@ func TestDecodeTypes(t *testing.T) {
{(*Unmarshaler)(nil), `toml: cannot decode to nil value of "*toml.Unmarshaler"`},
{nil, `toml: cannot decode to non-pointer <nil>`},

{new(map[int]string), "cannot decode to a map with non-string key type"},
{new(map[interface{}]string), "cannot decode to a map with non-string key type"},
{new(map[int]string), "toml: cannot decode to a map with non-string key type"},
{new(map[interface{}]string), "toml: cannot decode to a map with non-string key type"},

{new(struct{ F int }), `toml: incompatible types: TOML key "F" has type bool; destination has type integer`},
{new(map[string]int), `toml: incompatible types: TOML key "F" has type bool; destination has type integer`},
{new(struct{ F int }), `toml: line 1 (last key "F"): incompatible types: TOML value has type bool; destination has type integer`},
{new(map[string]int), `toml: line 1 (last key "F"): incompatible types: TOML value has type bool; destination has type integer`},
{new(int), `toml: cannot decode to type int`},
{new([]int), "toml: cannot decode to type []int"},
} {
Expand Down Expand Up @@ -513,8 +594,8 @@ Locations = {NY = {Temp = "not cold", Rating = 4}, MI = {Temp = "freezing", Rati
if len(meta.keys) != 12 {
t.Errorf("after decode, got %d meta keys; want 12", len(meta.keys))
}
if len(meta.types) != 12 {
t.Errorf("after decode, got %d meta types; want 12", len(meta.types))
if len(meta.keyInfo) != 12 {
t.Errorf("after decode, got %d meta keyInfo; want 12", len(meta.keyInfo))
}
}

Expand Down
Loading

0 comments on commit 2004196

Please sign in to comment.