diff --git a/pkg/airgap/file_scp.go b/pkg/airgap/file_scp.go index 149cf088..49cdbb14 100644 --- a/pkg/airgap/file_scp.go +++ b/pkg/airgap/file_scp.go @@ -7,6 +7,7 @@ import ( "io" "os" "path/filepath" + "strings" "sync" "github.com/cnrancher/autok3s/pkg/common" @@ -21,12 +22,16 @@ import ( ) type fileMap struct { - mode os.FileMode - targetPath string + mode os.FileMode + dataDirSubpath string + targetPath string } const ( - installScriptName = "install.sh" + installScriptName = "install.sh" + defaultDataDirPath = "/var/lib/rancher/k3s" + dataDirParamPrefix = "--data-dir" + dataDirParamPrefixShort = "-d" ) var ( @@ -40,8 +45,8 @@ var ( targetPath: "/usr/local/bin", }, "k3s-airgap-images": { - mode: 0644, - targetPath: "/var/lib/rancher/k3s/agent/images", + mode: 0644, + dataDirSubpath: "agent/images", }, installScriptName: { mode: 0755, @@ -55,9 +60,10 @@ var ( } ) -func ScpFiles(clusterName string, pkg *common.Package, dialer *hosts.SSHDialer) error { +func ScpFiles(logger *logrus.Logger, clusterName string, pkg *common.Package, dialer *hosts.SSHDialer, extraArgs string) (er error) { + dataPath := getDataPath(extraArgs) conn := dialer.GetClient() - fieldLogger := logrus.WithFields(logrus.Fields{ + fieldLogger := logger.WithFields(logrus.Fields{ "cluster": clusterName, "component": "airgap", }) @@ -127,17 +133,30 @@ func ScpFiles(clusterName string, pkg *common.Package, dialer *hosts.SSHDialer) return err } - targetFilename := filepath.Join(remote.targetPath, filename) + targetPath := remote.targetPath + if remote.dataDirSubpath != "" { + targetPath = filepath.Join(dataPath, remote.dataDirSubpath) + } + targetFilename := filepath.Join(targetPath, filename) + var stdout, stderr bytes.Buffer dialer = dialer.SetStdio(&stdout, &stderr, nil) - moveCMD := fmt.Sprintf("sudo mkdir -p %s;sudo mv %s %s", remote.targetPath, remoteFileName, targetFilename) + moveCMD := fmt.Sprintf("sudo mkdir -p %s;sudo mv %s %s", targetPath, remoteFileName, targetFilename) fieldLogger.Infof("executing cmd in remote server %s", moveCMD) if err := dialer.Cmd(moveCMD).Run(); err != nil { fieldLogger.Errorf("failed to execute cmd %s, stdout: %s, stderr: %s, %v", moveCMD, stdout.String(), stderr.String(), err) return err } + fieldLogger.Infof("file moved to %s", targetFilename) fieldLogger.Infof("remote file %s transferred", filename) + + defer func(tmpFilename, targetFilename string) { + //clean up process when return error + if er != nil { + fieldLogger.Warnf("error occurs when transferring resources, following resources should be clean later: %s %s", tmpFilename, targetFilename) + } + }(remoteFileName, targetFilename) } fieldLogger.Info("all files transferred") @@ -283,3 +302,33 @@ func getScpFileMap(arch string, pkg *common.Package) (map[string]fileMap, error) func getRemoteTmpDir(clustername string) string { return filepath.Join(remoteTmpDir, clustername) } + +func getDataPath(extraArgs string) string { + dataPath := defaultDataDirPath + args := strings.Split(extraArgs, " ") + for i, arg := range args { + var prefix string + if strings.HasPrefix(arg, dataDirParamPrefix) { + prefix = dataDirParamPrefix + } + if strings.HasPrefix(arg, dataDirParamPrefixShort) { + prefix = dataDirParamPrefixShort + } + if prefix == "" { + continue + } + // if the arg == dataPath prefix, return the next arg + if len(arg) == len(prefix) && i < len(args)-1 { + return args[i+1] + } + // this is the case as -d=/data/xxx + if len(arg) > len(prefix) && arg[len(prefix)] == '=' { + return strings.TrimPrefix(arg, prefix+"=") + } + // only two cases above are validated, otherwise return the default path + if prefix != "" { + break + } + } + return dataPath +} diff --git a/pkg/airgap/file_scp_test.go b/pkg/airgap/file_scp_test.go new file mode 100644 index 00000000..28aadc41 --- /dev/null +++ b/pkg/airgap/file_scp_test.go @@ -0,0 +1,27 @@ +package airgap + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestGetDataPath(t *testing.T) { + type testcase struct { + name string + args string + expectPath string + } + for _, c := range []testcase{ + {name: "no extra args", expectPath: defaultDataDirPath}, + {name: "no data path args", args: "--bind-address=0.0.0.0", expectPath: defaultDataDirPath}, + {name: "data dir args", args: "--data-dir /data", expectPath: "/data"}, + {name: "data dir args with equal sign", args: "--data-dir=/data", expectPath: "/data"}, + {name: "data dir args with short name", args: "-d /data", expectPath: "/data"}, + {name: "data dir args with short name and equal sign", args: "-d=/data", expectPath: "/data"}, + {name: "wrong data dir args", args: "--data-dir", expectPath: defaultDataDirPath}, + } { + path := getDataPath(c.args) + assert.Equalf(t, c.expectPath, path, "test: %s failed", c.name) + } +} diff --git a/pkg/cluster/cluster.go b/pkg/cluster/cluster.go index 41c36d26..b8605097 100644 --- a/pkg/cluster/cluster.go +++ b/pkg/cluster/cluster.go @@ -449,7 +449,7 @@ func (p *ProviderBase) initNode(isFirstMaster bool, fixedIP string, cluster *typ } if pkg != nil { - if err := p.scpFiles(cluster.Name, pkg, &node); err != nil { + if err := p.scpFiles(cluster.Name, pkg, &node, extraArgs); err != nil { return err } } @@ -784,7 +784,7 @@ func (p *ProviderBase) Upgrade(cluster *types.Cluster) error { var cmd string if pkg != nil { - if err := p.scpFiles(cluster.Name, pkg, &node); err != nil { + if err := p.scpFiles(cluster.Name, pkg, &node, extraArgs); err != nil { return err } cmd = k3sRestart @@ -809,7 +809,7 @@ func (p *ProviderBase) Upgrade(cluster *types.Cluster) error { var cmd string if pkg != nil { - if err := p.scpFiles(cluster.Name, pkg, &node); err != nil { + if err := p.scpFiles(cluster.Name, pkg, &node, extraArgs); err != nil { return err } cmd = k3sAgentRestart @@ -835,14 +835,14 @@ func nodeByInstanceID(nodes []types.Node) map[string]types.Node { return rtn } -func (p *ProviderBase) scpFiles(clusterName string, pkg *common.Package, node *types.Node) error { +func (p *ProviderBase) scpFiles(clusterName string, pkg *common.Package, node *types.Node, extraArgs string) error { dialer, err := hosts.NewSSHDialer(node, true, p.Logger) if err != nil { return err } defer dialer.Close() dialer.SetWriter(p.Logger.Out) - return airgap.ScpFiles(clusterName, pkg, dialer) + return airgap.ScpFiles(p.Logger, clusterName, pkg, dialer, extraArgs) } func (p *ProviderBase) handleDataStoreCertificate(n *types.Node, c *types.Cluster) error {