diff --git a/pkg/tools/etcd_tools.go b/pkg/tools/etcd_tools.go index 3bffc224e375c..a7a467da2d2a6 100644 --- a/pkg/tools/etcd_tools.go +++ b/pkg/tools/etcd_tools.go @@ -17,6 +17,7 @@ limitations under the License. package tools import ( + "encoding/json" "errors" "fmt" "io/ioutil" @@ -24,7 +25,6 @@ import ( "os/exec" "reflect" "strconv" - "strings" "github.com/GoogleCloudPlatform/kubernetes/pkg/conversion" "github.com/GoogleCloudPlatform/kubernetes/pkg/runtime" @@ -425,20 +425,35 @@ func (h *EtcdHelper) AtomicUpdate(key string, ptrToType runtime.Object, ignoreNo } } -func checkEtcd(host string) error { +// GetEtcdVersion performs a version check against the provided Etcd server, returning a triplet +// of the release version, internal version, and error (if any). +func GetEtcdVersion(host string) (releaseVersion, internalVersion string, err error) { response, err := http.Get(host + "/version") if err != nil { - return err + return "", "", err } defer response.Body.Close() + body, err := ioutil.ReadAll(response.Body) if err != nil { - return err + return "", "", err } - if !strings.HasPrefix(string(body), "etcd") { - return fmt.Errorf("unknown server: %s", string(body)) + + var dat map[string]interface{} + if err := json.Unmarshal(body, &dat); err != nil { + return "", "", fmt.Errorf("unknown server: %s", string(body)) } - return nil + if obj := dat["releaseVersion"]; obj != nil { + if s, ok := obj.(string); ok { + releaseVersion = s + } + } + if obj := dat["internalVersion"]; obj != nil { + if s, ok := obj.(string); ok { + internalVersion = s + } + } + return } func startEtcd() (*exec.Cmd, error) { @@ -451,7 +466,7 @@ func startEtcd() (*exec.Cmd, error) { } func NewEtcdClientStartServerIfNecessary(server string) (EtcdClient, error) { - err := checkEtcd(server) + _, _, err := GetEtcdVersion(server) if err != nil { glog.Infof("Failed to find etcd, attempting to start.") _, err := startEtcd() diff --git a/pkg/tools/etcd_tools_test.go b/pkg/tools/etcd_tools_test.go index 9f582c8bfe8a3..0f42f00bc926f 100644 --- a/pkg/tools/etcd_tools_test.go +++ b/pkg/tools/etcd_tools_test.go @@ -19,6 +19,8 @@ package tools import ( "errors" "fmt" + "net/http" + "net/http/httptest" "reflect" "sync" "testing" @@ -29,6 +31,7 @@ import ( "github.com/GoogleCloudPlatform/kubernetes/pkg/conversion" "github.com/GoogleCloudPlatform/kubernetes/pkg/runtime" "github.com/coreos/go-etcd/etcd" + "github.com/stretchr/testify/assert" ) type TestResource struct { @@ -643,3 +646,48 @@ func TestAtomicUpdate_CreateCollision(t *testing.T) { t.Errorf("Some of the writes were lost. Stored value: %d", stored.Value) } } + +func TestGetEtcdVersion_ValidVersion(t *testing.T) { + testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, "{\"releaseVersion\":\"2.0.3\",\"internalVersion\":\"2\"}") + })) + defer testServer.Close() + + var relVersion string + var intVersion string + var err error + if relVersion, intVersion, err = GetEtcdVersion(testServer.URL); err != nil { + t.Errorf("Unexpected error: %v", err) + } + assert.Equal(t, "2.0.3", relVersion, "Unexpected external version") + assert.Equal(t, "2", intVersion, "Unexpected internal version") + assert.Nil(t, err) +} + +func TestGetEtcdVersion_UnknownVersion(t *testing.T) { + testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, "{\"unknownAttribute\":\"foobar\",\"internalVersion\":\"2\"}") + })) + defer testServer.Close() + + var relVersion string + var intVersion string + var err error + if relVersion, intVersion, err = GetEtcdVersion(testServer.URL); err != nil { + t.Errorf("Unexpected error: %v", err) + } + assert.Equal(t, "", relVersion, "Unexpected external version") + assert.Equal(t, "2", intVersion, "Unexpected internal version") + assert.Nil(t, err) +} + +func TestGetEtcdVersion_ErrorStatus(t *testing.T) { + testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusServiceUnavailable) + })) + defer testServer.Close() + + var err error + _, _, err = GetEtcdVersion(testServer.URL) + assert.NotNil(t, err) +}