Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Set config default values before parse #525

Merged
merged 20 commits into from
Dec 21, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 87 additions & 31 deletions cfg/config.go
Original file line number Diff line number Diff line change
@@ -239,6 +239,17 @@ func tryApplyFunc(app funcApplier, field *simplejson.Json) (string, bool) {
return "", false
}

func DecodeConfig(config any, configJson []byte) error {
err := SetDefaultValues(config)
if err != nil {
return err
}

dec := json.NewDecoder(bytes.NewReader(configJson))
dec.DisallowUnknownFields()
return dec.Decode(config)
}

// Parse holy shit! who write this function?
func Parse(ptr any, values map[string]int) error {
v := reflect.ValueOf(ptr).Elem()
@@ -325,37 +336,7 @@ func ParseSlice(v reflect.Value, values map[string]int) error {
}

func ParseField(v reflect.Value, vField reflect.Value, tField *reflect.StructField, values map[string]int) error {
tag := tField.Tag.Get("required")
required := tag == trueValue

tag = tField.Tag.Get("default")
if tag != "" {
// need to check current value of field and then set default in case of empty value
switch vField.Kind() {
case reflect.String:
if vField.String() == "" {
vField.SetString(tag)
}
case reflect.Int:
if vField.Int() == 0 { // like in vField.IsZero
val, err := strconv.Atoi(tag)
if err != nil {
return fmt.Errorf("default value for field %s should be int, got=%s: %w", tField.Name, tag, err)
}
vField.SetInt(int64(val))
}
case reflect.Slice:
if vField.Len() == 0 {
val := strings.Fields(tag)
vField.Set(reflect.MakeSlice(vField.Type(), len(val), len(val)))
for i, v := range val {
vField.Index(i).SetString(v)
}
}
}
}

tag = tField.Tag.Get("options")
tag := tField.Tag.Get("options")
if tag != "" {
parts := strings.Split(tag, "|")
if vField.Kind() != reflect.String {
@@ -520,6 +501,9 @@ func ParseField(v reflect.Value, vField reflect.Value, tField *reflect.StructFie
}
}

tag = tField.Tag.Get("required")
required := tag == trueValue

if required && vField.IsZero() {
return fmt.Errorf("field %s should set as non-zero value", tField.Name)
}
@@ -561,6 +545,78 @@ func ParseFieldSelector(selector string) []string {
return result
}

func SetDefaultValues(data interface{}) error {
t := reflect.TypeOf(data).Elem()
v := reflect.ValueOf(data).Elem()

if t.Kind() != reflect.Struct {
return nil
}

for i := 0; i < t.NumField(); i++ {
tField := t.Field(i)
vField := v.Field(i)

vFieldKind := vField.Kind()

var err error
switch vFieldKind {
case reflect.Struct:
err = SetDefaultValues(vField.Addr().Interface())
if err != nil {
return err
}
case reflect.Slice:
for i := 0; i < vField.Len(); i++ {
item := vField.Index(i)
if item.Kind() == reflect.Struct {
err = SetDefaultValues(item.Addr().Interface())
if err != nil {
return err
}
}
}
}

defaultValue := tField.Tag.Get("default")
if defaultValue != "" {
switch vFieldKind {
case reflect.Bool:
currentValue := vField.Bool()
if !currentValue {
if defaultValue == "true" {
vField.SetBool(true)
} else if defaultValue == "false" {
vField.SetBool(false)
}
}
case reflect.String:
if vField.String() == "" {
vField.SetString(defaultValue)
}
case reflect.Int:
if vField.Int() == 0 { // like in vField.IsZero
val, err := strconv.Atoi(defaultValue)
if err != nil {
return fmt.Errorf("default value for field %s should be int, got=%s: %w", tField.Name, defaultValue, err)
}
vField.SetInt(int64(val))
}
case reflect.Slice:
if vField.Len() == 0 {
val := strings.Fields(defaultValue)
vField.Set(reflect.MakeSlice(vField.Type(), len(val), len(val)))
for i, v := range val {
vField.Index(i).SetString(v)
}
}
}
}
}

return nil
}

func ListToMap(a []string) map[string]bool {
result := make(map[string]bool, len(a))
for _, key := range a {
44 changes: 44 additions & 0 deletions cfg/config_test.go
Original file line number Diff line number Diff line change
@@ -35,6 +35,10 @@ type strDefault struct {
T string `default:"sync"`
}

type boolDefault struct {
T bool `default:"true"`
}

type PersistenceMode byte

const (
@@ -70,6 +74,23 @@ type sliceChild struct {
Value string `default:"child"`
}

func (s *sliceChild) UnmarshalJSON(raw []byte) error {
SetDefaultValues(s)
var childPtr struct {
Value *string
}

if err := json.Unmarshal(raw, &childPtr); err != nil {
return err
}

if childPtr.Value != nil {
s.Value = *childPtr.Value
}

return nil
}

type sliceStruct struct {
Value string `default:"parent"`
Childs []sliceChild `default:"" slice:"true"`
@@ -96,6 +117,7 @@ func TestParseRequiredErr(t *testing.T) {

func TestParseDefault(t *testing.T) {
s := &strDefault{}
SetDefaultValues(s)
err := Parse(s, nil)

assert.NoError(t, err, "shouldn't be an error")
@@ -113,6 +135,7 @@ func TestParseDuration(t *testing.T) {
T Duration `default:"5s" parse:"duration"`
T_ time.Duration
}{}
SetDefaultValues(s)
r.NoError(Parse(s, nil))
r.Equal(time.Second*5, s.T_)
})
@@ -124,6 +147,7 @@ func TestParseDuration(t *testing.T) {
T Duration `parse:"duration"`
T_ time.Duration
}{}
SetDefaultValues(s)
r.NoError(Parse(s, nil))
r.Equal(time.Duration(0), s.T_)
})
@@ -333,6 +357,7 @@ func TestHierarchy(t *testing.T) {

func TestSlice(t *testing.T) {
s := &sliceStruct{Value: "parent_value", Childs: []sliceChild{{"child_1"}, {}}}
SetDefaultValues(s)
err := Parse(s, map[string]int{})

assert.Nil(t, err, "shouldn't be an error")
@@ -343,6 +368,7 @@ func TestSlice(t *testing.T) {

func TestDefaultSlice(t *testing.T) {
s := &sliceStruct{Value: "parent_value"}
SetDefaultValues(s)
err := Parse(s, map[string]int{})

assert.Nil(t, err, "shouldn't be an error")
@@ -353,6 +379,7 @@ func TestDefaultSlice(t *testing.T) {

func TestBase8Default(t *testing.T) {
s := &strBase8{}
SetDefaultValues(s)
err := Parse(s, nil)
assert.Nil(t, err, "shouldn't be an error")
assert.Equal(t, int64(438), s.T_)
@@ -561,6 +588,23 @@ func TestParseDefaultInt(t *testing.T) {
{s: &intDefault{T: 17}, expected: 17},
}
for i, tc := range testCases {
SetDefaultValues(tc.s)
err := Parse(tc.s, nil)

assert.NoError(t, err, "shouldn't be an error tc: %d", i)
assert.Equal(t, tc.expected, tc.s.T, "wrong value tc: %d", i)
}
}

func TestParseDefaultBool(t *testing.T) {
testCases := []struct {
s *boolDefault
expected bool
}{
{s: &boolDefault{}, expected: true},
}
for i, tc := range testCases {
SetDefaultValues(tc.s)
err := Parse(tc.s, nil)

assert.NoError(t, err, "shouldn't be an error tc: %d", i)
44 changes: 25 additions & 19 deletions cmd/file.d/file.d_test.go
Original file line number Diff line number Diff line change
@@ -21,13 +21,9 @@ import (
_ "github.com/ozontech/file.d/plugin/action/rename"
_ "github.com/ozontech/file.d/plugin/action/throttle"
_ "github.com/ozontech/file.d/plugin/input/fake"
"github.com/ozontech/file.d/plugin/input/file"
http2 "github.com/ozontech/file.d/plugin/input/http"
k8s2 "github.com/ozontech/file.d/plugin/input/k8s"
_ "github.com/ozontech/file.d/plugin/output/devnull"
"github.com/ozontech/file.d/plugin/output/gelf"
_ "github.com/ozontech/file.d/plugin/output/kafka"
"github.com/ozontech/file.d/plugin/output/splunk"
uuid "github.com/satori/go.uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -218,44 +214,44 @@ func TestThatPluginsAreImported(t *testing.T) {

type testConfig struct {
name string
factory func() (pipeline.AnyPlugin, pipeline.AnyConfig)
kind pipeline.PluginKind
configJSON string
}

func TestConfigParseValid(t *testing.T) {
testList := []testConfig{
{
name: "file",
factory: file.Factory,
configJSON: `{"offsets_op":"tail","persistence_mode":"sync","watching_dir":"/var/"}`,
kind: pipeline.PluginKindInput,
configJSON: `{"offsets_op":"tail","persistence_mode":"sync","watching_dir":"/var/", "offsets_file": "./offset.yaml"}`,
},
{
name: "http",
factory: http2.Factory,
configJSON: `{"address": ":9001","emulate_mode":"yes"}`,
kind: pipeline.PluginKindInput,
configJSON: `{"address": ":9001","emulate_mode":"elasticsearch"}`,
},
{
name: "k8s",
factory: k8s2.Factory,
kind: pipeline.PluginKindInput,
configJSON: `{"split_event_size":1000,"watching_dir":"/var/log/containers/","offsets_file":"/data/k8s-offsets.yaml"}`,
},
{
name: "gelf",
factory: gelf.Factory,
kind: pipeline.PluginKindOutput,
configJSON: `{"endpoint":"graylog.svc.cluster.local:12201","reconnect_interval":"1m","default_short_message_value":"message isn't provided"}`,
},
{
name: "splunk",
factory: splunk.Factory,
kind: pipeline.PluginKindOutput,
configJSON: `{"endpoint":"splunk_endpoint","token":"value_token"}`,
},
}
for _, tl := range testList {
tl := tl
t.Run(tl.name, func(t *testing.T) {
t.Parallel()
_, config := tl.factory()
err := fd.DecodeConfig(config, []byte(tl.configJSON))
pluginInfo := fd.DefaultPluginRegistry.Get(tl.kind, tl.name)
_, err := pipeline.GetConfig(pluginInfo, []byte(tl.configJSON), map[string]int{"gomaxprocs": 1, "capacity": 64})
assert.NoError(t, err, "shouldn't be an error")
})
}
@@ -265,26 +261,36 @@ func TestConfigParseInvalid(t *testing.T) {
testList := []testConfig{
{
name: "http",
factory: http2.Factory,
kind: pipeline.PluginKindInput,
configJSON: `{"address": ":9001","emulate_mode":"yes","un_exist_field":"bla-bla"}`,
},
{
name: "k8s",
factory: k8s2.Factory,
kind: pipeline.PluginKindInput,
configJSON: `{"split_event_size":pp,"watching_dir":"/var/log/containers/","offsets_file":"/data/k8s-offsets.yaml"}`,
},
{
name: "gelf",
factory: gelf.Factory,
kind: pipeline.PluginKindOutput,
configJSON: `{"reconnect_interval_1":"1m","default_short_message_value":"message isn't provided"}`,
},
{
name: "http",
kind: pipeline.PluginKindInput,
configJSON: `{"address": ":9001","emulate_mode":"yes"}`,
},
{
name: "file",
kind: pipeline.PluginKindInput,
configJSON: `{"offsets_op":"tail","persistence_mode":"sync","watching_dir":"/var/"}`,
},
}
for _, tl := range testList {
tl := tl
t.Run(tl.name, func(t *testing.T) {
t.Parallel()
_, config := tl.factory()
err := fd.DecodeConfig(config, []byte(tl.configJSON))
pluginInfo := fd.DefaultPluginRegistry.Get(tl.kind, tl.name)
_, err := pipeline.GetConfig(pluginInfo, []byte(tl.configJSON), map[string]int{"gomaxprocs": 1, "capacity": 64})
assert.Error(t, err, "should be an error")
})
}
Loading