Skip to content
This repository has been archived by the owner on Nov 20, 2024. It is now read-only.

Commit

Permalink
Merge pull request #1 from tedsmitt/dev
Browse files Browse the repository at this point in the history
Addition of unit tests
  • Loading branch information
Ed Smith authored Mar 26, 2021
2 parents e3d0f21 + 28d58e1 commit 57b2401
Show file tree
Hide file tree
Showing 8 changed files with 311 additions and 43 deletions.
3 changes: 3 additions & 0 deletions .github/workflows/release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,20 @@ jobs:
uses: actions/checkout@v2
with:
fetch-depth: 0

- name: Set up Go
uses: actions/setup-go@v2
with:
go-version: 1.15

- name: Run GoReleaser
uses: goreleaser/goreleaser-action@v2
with:
version: latest
args: build
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

- name: Run GoReleaser
uses: goreleaser/goreleaser-action@v2
with:
Expand Down
30 changes: 20 additions & 10 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,27 @@ on:

jobs:
test:
strategy:
matrix:
go-version: [1.15.x, 1.16.x]
os: [ubuntu-latest, macos-latest, windows-latest]
runs-on: ${{ matrix.os }}
runs-on: ubuntu-latest

steps:
- name: Checkout code
uses: actions/checkout@v2

- name: Install Go
uses: actions/setup-go@v2
with:
go-version: ${{ matrix.go-version }}
- name: Checkout code
uses: actions/checkout@v2
- name: Test
run: go test ./pkg
go-version: 1.15

#- name: Test
# run: |
# go mod tidy
# go test -v ./pkg/cmd

- name: Run GoReleaser
uses: goreleaser/goreleaser-action@v2
with:
version: latest
args: --snapshot --skip-publish --rm-dist
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
HOMEBREW_TAP_GITHUB_TOKEN: ${{ secrets.HOMEBREW_TAP_GITHUB_TOKEN }}
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,5 @@ require (
github.com/fatih/color v1.10.0
github.com/spf13/cobra v1.1.3
github.com/spf13/viper v1.7.1
github.com/stretchr/testify v1.3.0
)
32 changes: 15 additions & 17 deletions pkg/cmd/execute.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,30 +3,30 @@ package cmd
import (
"encoding/json"
"fmt"
"log"
"os"

"github.com/aws/aws-sdk-go/service/ecs/ecsiface"
"github.com/aws/aws-sdk-go/service/ssm"
"github.com/spf13/viper"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/ecs"
)

func StartExecuteCommand() {
func StartExecuteCommand(client ecsiface.ECSAPI) error {

clusterName, err := getCluster()
clusterName, err := getCluster(client)
if err != nil {
log.Fatalf(red(err))
return err
}
task, err := getTask(clusterName)
task, err := getTask(client, clusterName)
if err != nil {
log.Fatalf(red(err))
return err
}
container, err := getContainer(task)
if err != nil {
log.Fatalf(red(err))
return err
}

// Check if command has been passed to the tool, otherwise default
// to /bin/sh
var command string
Expand All @@ -44,25 +44,23 @@ func StartExecuteCommand() {
Container: container.Name,
})
if err != nil {
log.Fatal(err)
return err
}
execSess, err := json.Marshal(execCommand.Session)
if err != nil {
log.Fatal(err)
return err
}
target := ssm.StartSessionInput{
Target: aws.String(fmt.Sprintf("ecs:%s_%s_%s", "execCommand", *task.TaskArn, *container.RuntimeId)),
Target: aws.String(fmt.Sprintf("ecs:%s_%s_%s", clusterName, *task.TaskArn, *container.RuntimeId)),
}
targetJson, err := json.Marshal(target)
if err != nil {
log.Println(err)
return err
}

// Expecting session-manager-plugin to be found in $PATH
runCommand("session-manager-plugin", string(execSess),
region, "StartSession", "", string(targetJson), endpoint)
if err != nil {
log.Fatalf(err.Error())
if err = runCommand("session-manager-plugin", string(execSess), region, "StartSession", "", string(targetJson), endpoint); err != nil {
return err
}
os.Exit(0)
return nil
}
79 changes: 79 additions & 0 deletions pkg/cmd/execute_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
package cmd

import (
"os"
"testing"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/ecs"
"github.com/stretchr/testify/assert"
)

func init() {
os.Setenv("AWS_DEFAULT_REGION", "eu-west-1")
}

func (m *MockECSAPI) ExecuteCommand(input *ecs.ExecuteCommandInput) (*ecs.ExecuteCommandOutput, error) { // This allows the test to use the same method
if m.ExecuteCommandMock != nil {
return m.ExecuteCommandMock(input) // We intercept and return a made up reply
}
return nil, nil // return any value you think is good for you
}

func TestStartExecuteCommand(t *testing.T) {
cases := []struct {
name string
client *MockECSAPI
expected error
}{
{
name: "TestStartExecuteCommandWithClusters",
client: &MockECSAPI{
ListClustersMock: func(input *ecs.ListClustersInput) (*ecs.ListClustersOutput, error) {
return &ecs.ListClustersOutput{
ClusterArns: []*string{
aws.String("arn:aws:ecs:eu-west-1:1111111111:cluster/execCommand"),
aws.String("arn:aws:ecs:eu-west-1:1111111111:cluster/bluegreen"),
},
}, nil
},
ListTasksMock: func(input *ecs.ListTasksInput) (*ecs.ListTasksOutput, error) {
return &ecs.ListTasksOutput{
TaskArns: []*string{
aws.String("arn:aws:ecs:eu-west-1:111111111111:task/execCommand/8a58117dac38436ba5547e9da5d3ac3d"),
},
}, nil
},
DescribeTasksMock: func(input *ecs.DescribeTasksInput) (*ecs.DescribeTasksOutput, error) {
return &ecs.DescribeTasksOutput{
Tasks: []*ecs.Task{
{
TaskArn: aws.String("arn:aws:ecs:eu-west-1:111111111111:task/execCommand/8a58117dac38436ba5547e9da5d3ac3d"),
Containers: []*ecs.Container{
{
Name: aws.String("echo-server"),
RuntimeId: aws.String("8a58117dac38436ba5547e9da5d3ac3d-1527056392"),
},
},
},
},
}, nil
},
ExecuteCommandMock: func(input *ecs.ExecuteCommandInput) (*ecs.ExecuteCommandOutput, error) {
return &ecs.ExecuteCommandOutput{
Session: &ecs.Session{
SessionId: aws.String("ecs-execute-command-05b8e510e3433762c"),
StreamUrl: aws.String("wss://ssmmessages.eu-west-1.amazonaws.com/v1/data-channel/ecs-execute-command-05b8e510e3433762c?role=publish_subscribe"),
TokenValue: aws.String("ABCDEF123456"),
},
}, nil
},
},
expected: nil, // If we execute with the session details above then we actually get a clean exit from session-manager-plugin, so we don't expect an error
},
}
for _, c := range cases {
result := StartExecuteCommand(c.client)
assert.Equal(t, c.expected, result)
}
}
34 changes: 19 additions & 15 deletions pkg/cmd/internal.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package cmd

import (
"errors"
"flag"
"fmt"
"os"
"os/exec"
Expand All @@ -13,11 +14,11 @@ import (
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/ecs"
"github.com/aws/aws-sdk-go/service/ecs/ecsiface"
"github.com/fatih/color"
)

var (
client *ecs.ECS
region string
endpoint string

Expand All @@ -26,29 +27,24 @@ var (
yellow = color.New(color.FgYellow).SprintFunc()
)

func init() {
client = createEcsClient()
region = client.SigningRegion
endpoint = client.Endpoint
}

func createEcsClient() *ecs.ECS {
sess := session.Must(session.NewSessionWithOptions(session.Options{
SharedConfigState: session.SharedConfigEnable,
}))

client := ecs.New(sess)
region = client.SigningRegion
endpoint = client.Endpoint

return client
}

// Lists available clusters and prompts the user to select one
func getCluster() (string, error) {
func getCluster(client ecsiface.ECSAPI) (string, error) {
list, err := client.ListClusters(&ecs.ListClustersInput{})
if err != nil {
return "", err
}

var clusterName string
if len(list.ClusterArns) > 0 {
var clusterNames []string
Expand All @@ -70,31 +66,31 @@ func getCluster() (string, error) {
}

// Lists tasks in a cluster and prompts the user to select one
func getTask(clusterName string) (*ecs.Task, error) {
func getTask(client ecsiface.ECSAPI, clusterName string) (*ecs.Task, error) {
list, err := client.ListTasks(&ecs.ListTasksInput{
Cluster: aws.String(clusterName),
})
if err != nil {
return nil, err
return &ecs.Task{}, err
}
if len(list.TaskArns) > 0 {
describe, err := client.DescribeTasks(&ecs.DescribeTasksInput{
Cluster: aws.String(clusterName),
Tasks: list.TaskArns,
})
if err != nil {
return nil, err
return &ecs.Task{}, err
}
// Ask the user to select which Task to connect to
selection, err := selectTask(describe.Tasks)
if err != nil {
return nil, err
return &ecs.Task{}, err
}
task := selection
return task, nil
} else {
err := errors.New(fmt.Sprintf("There are no running tasks in the cluster %s", clusterName))
return nil, err
return &ecs.Task{}, err
}
}

Expand All @@ -116,6 +112,9 @@ func getContainer(task *ecs.Task) (*ecs.Container, error) {

// selectCluster provides the prompt for choosing a cluster
func selectCluster(clusterNames []string) (string, error) {
if flag.Lookup("test.v") != nil {
return clusterNames[0], nil
}
var clusterName string
var qs = []*survey.Question{
{
Expand All @@ -138,6 +137,9 @@ func selectCluster(clusterNames []string) (string, error) {

// selectTask provides the prompt for choosing a Task
func selectTask(tasks []*ecs.Task) (*ecs.Task, error) {
if flag.Lookup("test.v") != nil {
return tasks[0], nil
}
var options []string
for _, t := range tasks {
var containers []string
Expand Down Expand Up @@ -181,6 +183,9 @@ func selectTask(tasks []*ecs.Task) (*ecs.Task, error) {

// selectContainer prompts the user to choose a container within a task
func selectContainer(containers []*ecs.Container) (*ecs.Container, error) {
if flag.Lookup("test.v") != nil {
return containers[0], nil
}
var containerNames []string
for _, c := range containers {
containerNames = append(containerNames, *c.Name)
Expand Down Expand Up @@ -227,7 +232,6 @@ func runCommand(process string, args ...string) error {
for {
select {
case <-sigs:
os.Exit(0)
}
}
}()
Expand Down
Loading

0 comments on commit 57b2401

Please sign in to comment.