Skip to content

Commit

Permalink
feat(airgap): Support scp files to node's data-dir
Browse files Browse the repository at this point in the history
  • Loading branch information
orangedeng authored and JacieChao committed Apr 23, 2023
1 parent 8a3777b commit e93e779
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 14 deletions.
67 changes: 58 additions & 9 deletions pkg/airgap/file_scp.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"io"
"os"
"path/filepath"
"strings"
"sync"

"github.com/cnrancher/autok3s/pkg/common"
Expand All @@ -21,12 +22,16 @@ import (
)

type fileMap struct {
mode os.FileMode
targetPath string
mode os.FileMode
dataDirSubpath string
targetPath string
}

const (
installScriptName = "install.sh"
installScriptName = "install.sh"
defaultDataDirPath = "/var/lib/rancher/k3s"
dataDirParamPrefix = "--data-dir"
dataDirParamPrefixShort = "-d"
)

var (
Expand All @@ -40,8 +45,8 @@ var (
targetPath: "/usr/local/bin",
},
"k3s-airgap-images": {
mode: 0644,
targetPath: "/var/lib/rancher/k3s/agent/images",
mode: 0644,
dataDirSubpath: "agent/images",
},
installScriptName: {
mode: 0755,
Expand All @@ -55,9 +60,10 @@ var (
}
)

func ScpFiles(clusterName string, pkg *common.Package, dialer *hosts.SSHDialer) error {
func ScpFiles(logger *logrus.Logger, clusterName string, pkg *common.Package, dialer *hosts.SSHDialer, extraArgs string) (er error) {
dataPath := getDataPath(extraArgs)
conn := dialer.GetClient()
fieldLogger := logrus.WithFields(logrus.Fields{
fieldLogger := logger.WithFields(logrus.Fields{
"cluster": clusterName,
"component": "airgap",
})
Expand Down Expand Up @@ -127,17 +133,30 @@ func ScpFiles(clusterName string, pkg *common.Package, dialer *hosts.SSHDialer)
return err
}

targetFilename := filepath.Join(remote.targetPath, filename)
targetPath := remote.targetPath
if remote.dataDirSubpath != "" {
targetPath = filepath.Join(dataPath, remote.dataDirSubpath)
}
targetFilename := filepath.Join(targetPath, filename)

var stdout, stderr bytes.Buffer
dialer = dialer.SetStdio(&stdout, &stderr, nil)
moveCMD := fmt.Sprintf("sudo mkdir -p %s;sudo mv %s %s", remote.targetPath, remoteFileName, targetFilename)
moveCMD := fmt.Sprintf("sudo mkdir -p %s;sudo mv %s %s", targetPath, remoteFileName, targetFilename)
fieldLogger.Infof("executing cmd in remote server %s", moveCMD)
if err := dialer.Cmd(moveCMD).Run(); err != nil {
fieldLogger.Errorf("failed to execute cmd %s, stdout: %s, stderr: %s, %v", moveCMD, stdout.String(), stderr.String(), err)
return err
}

fieldLogger.Infof("file moved to %s", targetFilename)
fieldLogger.Infof("remote file %s transferred", filename)

defer func(tmpFilename, targetFilename string) {
//clean up process when return error
if er != nil {
fieldLogger.Warnf("error occurs when transferring resources, following resources should be clean later: %s %s", tmpFilename, targetFilename)
}
}(remoteFileName, targetFilename)
}

fieldLogger.Info("all files transferred")
Expand Down Expand Up @@ -283,3 +302,33 @@ func getScpFileMap(arch string, pkg *common.Package) (map[string]fileMap, error)
func getRemoteTmpDir(clustername string) string {
return filepath.Join(remoteTmpDir, clustername)
}

func getDataPath(extraArgs string) string {
dataPath := defaultDataDirPath
args := strings.Split(extraArgs, " ")
for i, arg := range args {
var prefix string
if strings.HasPrefix(arg, dataDirParamPrefix) {
prefix = dataDirParamPrefix
}
if strings.HasPrefix(arg, dataDirParamPrefixShort) {
prefix = dataDirParamPrefixShort
}
if prefix == "" {
continue
}
// if the arg == dataPath prefix, return the next arg
if len(arg) == len(prefix) && i < len(args)-1 {
return args[i+1]
}
// this is the case as -d=/data/xxx
if len(arg) > len(prefix) && arg[len(prefix)] == '=' {
return strings.TrimPrefix(arg, prefix+"=")
}
// only two cases above are validated, otherwise return the default path
if prefix != "" {
break
}
}
return dataPath
}
27 changes: 27 additions & 0 deletions pkg/airgap/file_scp_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package airgap

import (
"testing"

"github.com/stretchr/testify/assert"
)

func TestGetDataPath(t *testing.T) {
type testcase struct {
name string
args string
expectPath string
}
for _, c := range []testcase{
{name: "no extra args", expectPath: defaultDataDirPath},
{name: "no data path args", args: "--bind-address=0.0.0.0", expectPath: defaultDataDirPath},
{name: "data dir args", args: "--data-dir /data", expectPath: "/data"},
{name: "data dir args with equal sign", args: "--data-dir=/data", expectPath: "/data"},
{name: "data dir args with short name", args: "-d /data", expectPath: "/data"},
{name: "data dir args with short name and equal sign", args: "-d=/data", expectPath: "/data"},
{name: "wrong data dir args", args: "--data-dir", expectPath: defaultDataDirPath},
} {
path := getDataPath(c.args)
assert.Equalf(t, c.expectPath, path, "test: %s failed", c.name)
}
}
10 changes: 5 additions & 5 deletions pkg/cluster/cluster.go
Original file line number Diff line number Diff line change
Expand Up @@ -449,7 +449,7 @@ func (p *ProviderBase) initNode(isFirstMaster bool, fixedIP string, cluster *typ
}

if pkg != nil {
if err := p.scpFiles(cluster.Name, pkg, &node); err != nil {
if err := p.scpFiles(cluster.Name, pkg, &node, extraArgs); err != nil {
return err
}
}
Expand Down Expand Up @@ -784,7 +784,7 @@ func (p *ProviderBase) Upgrade(cluster *types.Cluster) error {
var cmd string

if pkg != nil {
if err := p.scpFiles(cluster.Name, pkg, &node); err != nil {
if err := p.scpFiles(cluster.Name, pkg, &node, extraArgs); err != nil {
return err
}
cmd = k3sRestart
Expand All @@ -809,7 +809,7 @@ func (p *ProviderBase) Upgrade(cluster *types.Cluster) error {

var cmd string
if pkg != nil {
if err := p.scpFiles(cluster.Name, pkg, &node); err != nil {
if err := p.scpFiles(cluster.Name, pkg, &node, extraArgs); err != nil {
return err
}
cmd = k3sAgentRestart
Expand All @@ -835,14 +835,14 @@ func nodeByInstanceID(nodes []types.Node) map[string]types.Node {
return rtn
}

func (p *ProviderBase) scpFiles(clusterName string, pkg *common.Package, node *types.Node) error {
func (p *ProviderBase) scpFiles(clusterName string, pkg *common.Package, node *types.Node, extraArgs string) error {
dialer, err := hosts.NewSSHDialer(node, true, p.Logger)
if err != nil {
return err
}
defer dialer.Close()
dialer.SetWriter(p.Logger.Out)
return airgap.ScpFiles(clusterName, pkg, dialer)
return airgap.ScpFiles(p.Logger, clusterName, pkg, dialer, extraArgs)
}

func (p *ProviderBase) handleDataStoreCertificate(n *types.Node, c *types.Cluster) error {
Expand Down

0 comments on commit e93e779

Please sign in to comment.