From cc876a6e0f9d4188e31d5f2ea9ee4ef5798ea45d Mon Sep 17 00:00:00 2001 From: Marin Nozhchev Date: Thu, 2 May 2024 19:44:41 +0300 Subject: [PATCH] Dereference fix in common copy implementation --- drivers/clickhouse/clickhouse.go | 77 +------------------- drivers/drivers.go | 20 +++--- drivers/drivers_test.go | 117 ++++++++++++++++++++----------- drivers/testdata/csvq/.gitignore | 1 + drivers/testdata/csvq/staff.csv | 2 + 5 files changed, 93 insertions(+), 124 deletions(-) create mode 100644 drivers/testdata/csvq/.gitignore create mode 100644 drivers/testdata/csvq/staff.csv diff --git a/drivers/clickhouse/clickhouse.go b/drivers/clickhouse/clickhouse.go index 59ccc6f1dd6..b397736f571 100644 --- a/drivers/clickhouse/clickhouse.go +++ b/drivers/clickhouse/clickhouse.go @@ -5,10 +5,7 @@ package clickhouse import ( - "context" "database/sql" - "fmt" - "reflect" "strconv" "strings" @@ -38,79 +35,7 @@ func init() { } return false }, - Copy: CopyWithInsert, + Copy: drivers.CopyWithInsert(func(int) string { return "?" }), NewMetadataReader: NewMetadataReader, }) } - -// CopyWithInsert builds a copy handler based on insert. -func CopyWithInsert(ctx context.Context, db *sql.DB, rows *sql.Rows, table string) (int64, error) { - columns, err := rows.Columns() - if err != nil { - return 0, fmt.Errorf("failed to fetch source rows columns: %w", err) - } - clen := len(columns) - query := table - if !strings.HasPrefix(strings.ToLower(query), "insert into") { - leftParen := strings.IndexRune(table, '(') - if leftParen == -1 { - colRows, err := db.QueryContext(ctx, "SELECT * FROM "+table+" WHERE 1=0") - if err != nil { - return 0, fmt.Errorf("failed to execute query to determine target table columns: %w", err) - } - columns, err := colRows.Columns() - _ = colRows.Close() - if err != nil { - return 0, fmt.Errorf("failed to fetch target table columns: %w", err) - } - table += "(" + strings.Join(columns, ", ") + ")" - } - query = "INSERT INTO " + table + " VALUES (" + strings.Repeat("?, ", clen-1) + "?)" - } - tx, err := db.BeginTx(ctx, nil) - if err != nil { - return 0, fmt.Errorf("failed to begin transaction: %w", err) - } - stmt, err := tx.PrepareContext(ctx, query) - if err != nil { - return 0, fmt.Errorf("failed to prepare insert query: %w", err) - } - defer stmt.Close() - columnTypes, err := rows.ColumnTypes() - if err != nil { - return 0, fmt.Errorf("failed to fetch source column types: %w", err) - } - values := make([]interface{}, clen) - valueRefs := make([]reflect.Value, clen) - actuals := make([]interface{}, clen) - for i := 0; i < len(columnTypes); i++ { - valueRefs[i] = reflect.New(columnTypes[i].ScanType()) - values[i] = valueRefs[i].Interface() - } - var n int64 - for rows.Next() { - err = rows.Scan(values...) - if err != nil { - return n, fmt.Errorf("failed to scan row: %w", err) - } - //We can't use values... in Exec() below, because, in some cases, clickhouse - //driver doesn't accept pointer to an argument instead of the arg itself. - for i := range values { - actuals[i] = valueRefs[i].Elem().Interface() - } - res, err := stmt.ExecContext(ctx, actuals...) - if err != nil { - return n, fmt.Errorf("failed to exec insert: %w", err) - } - rn, err := res.RowsAffected() - if err != nil { - return n, fmt.Errorf("failed to check rows affected: %w", err) - } - n += rn - } - err = tx.Commit() - if err != nil { - return n, fmt.Errorf("failed to commit transaction: %w", err) - } - return n, rows.Err() -} diff --git a/drivers/drivers.go b/drivers/drivers.go index 76fd9466c0f..2a721a9c0d4 100644 --- a/drivers/drivers.go +++ b/drivers/drivers.go @@ -540,16 +540,12 @@ func CopyWithInsert(placeholder func(int) string) func(ctx context.Context, db * if !strings.HasPrefix(strings.ToLower(query), "insert into") { leftParen := strings.IndexRune(table, '(') if leftParen == -1 { - colStmt, err := db.PrepareContext(ctx, "SELECT * FROM "+table+" WHERE 1=0") - if err != nil { - return 0, fmt.Errorf("failed to prepare query to determine target table columns: %w", err) - } - defer colStmt.Close() - colRows, err := colStmt.QueryContext(ctx) + colRows, err := db.QueryContext(ctx, "SELECT * FROM "+table+" WHERE 1=0") if err != nil { return 0, fmt.Errorf("failed to execute query to determine target table columns: %w", err) } columns, err := colRows.Columns() + _ = colRows.Close() if err != nil { return 0, fmt.Errorf("failed to fetch target table columns: %w", err) } @@ -576,8 +572,11 @@ func CopyWithInsert(placeholder func(int) string) func(ctx context.Context, db * return 0, fmt.Errorf("failed to fetch source column types: %w", err) } values := make([]interface{}, clen) + valueRefs := make([]reflect.Value, clen) + actuals := make([]interface{}, clen) for i := 0; i < len(columnTypes); i++ { - values[i] = reflect.New(columnTypes[i].ScanType()).Interface() + valueRefs[i] = reflect.New(columnTypes[i].ScanType()) + values[i] = valueRefs[i].Interface() } var n int64 for rows.Next() { @@ -585,7 +584,12 @@ func CopyWithInsert(placeholder func(int) string) func(ctx context.Context, db * if err != nil { return n, fmt.Errorf("failed to scan row: %w", err) } - res, err := stmt.ExecContext(ctx, values...) + //We can't use values... in Exec() below, because some drivers + //don't accept pointer to an argument instead of the arg itself. + for i := range values { + actuals[i] = valueRefs[i].Elem().Interface() + } + res, err := stmt.ExecContext(ctx, actuals...) if err != nil { return n, fmt.Errorf("failed to exec insert: %w", err) } diff --git a/drivers/drivers_test.go b/drivers/drivers_test.go index 3649b002998..ead41f786ee 100644 --- a/drivers/drivers_test.go +++ b/drivers/drivers_test.go @@ -115,6 +115,10 @@ var ( DSN: "trino://test@localhost:%s/tpch/sf1", DockerPort: "8080/tcp", }, + "csvq": { + // go test sets working directory to current package regardless of initial working directory + DSN: "csvq://./testdata/csvq", + }, } cleanup bool ) @@ -144,30 +148,21 @@ func TestMain(m *testing.M) { } for dbName, db := range dbs { - var ok bool - db.Resource, ok = pool.ContainerByName(db.RunOptions.Name) - if !ok { - buildOpts := &dt.BuildOptions{ - ContextDir: "./testdata/docker", - BuildArgs: db.BuildArgs, - } - db.Resource, err = pool.BuildAndRunWithBuildOptions(buildOpts, db.RunOptions) - if err != nil { - log.Fatalf("Could not start %s: %s", dbName, err) - } - } - - hostPort := db.Resource.GetPort(db.DockerPort) - db.URL, err = dburl.Parse(fmt.Sprintf(db.DSN, hostPort)) + dsn, hostPort := getConnInfo(dbName, db, pool) + db.URL, err = dburl.Parse(dsn) if err != nil { log.Fatalf("Failed to parse %s URL %s: %v", dbName, db.DSN, err) } if len(db.Exec) != 0 { + readyDSN := db.ReadyDSN if db.ReadyDSN == "" { - db.ReadyDSN = db.DSN + readyDSN = db.DSN + } + if hostPort != "" { + readyDSN = fmt.Sprintf(db.ReadyDSN, hostPort) } - readyURL, err := dburl.Parse(fmt.Sprintf(db.ReadyDSN, hostPort)) + readyURL, err := dburl.Parse(readyDSN) if err != nil { log.Fatalf("Failed to parse %s ready URL %s: %v", dbName, db.ReadyDSN, err) } @@ -205,8 +200,10 @@ func TestMain(m *testing.M) { // You can't defer this because os.Exit doesn't care for defer if cleanup { for _, db := range dbs { - if err := pool.Purge(db.Resource); err != nil { - log.Fatal("Could not purge resource: ", err) + if db.Resource != nil { + if err := pool.Purge(db.Resource); err != nil { + log.Fatal("Could not purge resource: ", err) + } } } } @@ -214,6 +211,35 @@ func TestMain(m *testing.M) { os.Exit(code) } +func getConnInfo(dbName string, db *Database, pool *dt.Pool) (string, string) { + if db.RunOptions == nil { + return db.DSN, "" + } + + var ok bool + db.Resource, ok = pool.ContainerByName(db.RunOptions.Name) + if ok && !db.Resource.Container.State.Running { + err := db.Resource.Close() + if err != nil { + log.Fatalf("Failed to clean up stale container %s: %s", dbName, err) + } + ok = false + } + if !ok { + buildOpts := &dt.BuildOptions{ + ContextDir: "./testdata/docker", + BuildArgs: db.BuildArgs, + } + var err error + db.Resource, err = pool.BuildAndRunWithBuildOptions(buildOpts, db.RunOptions) + if err != nil { + log.Fatalf("Failed to start %s: %s", dbName, err) + } + } + hostPort := db.Resource.GetPort(db.DockerPort) + return fmt.Sprintf(db.DSN, hostPort), hostPort +} + func TestWriter(t *testing.T) { type testFunc struct { label string @@ -467,6 +493,14 @@ func TestCopy(t *testing.T) { src: "select first_name, last_name, address_id, picture, email, store_id, active, username, password, last_update from staff", dest: "staff_copy(first_name, last_name, address_id, picture, email, store_id, active, username, password, last_update)", }, + { + dbName: "csvq", + setupQueries: []setupQuery{ + {query: "CREATE TABLE IF NOT EXISTS staff_copy AS SELECT * FROM `staff.csv` WHERE 0=1", check: true}, + }, + src: "select first_name, last_name, address_id, email, store_id, active, username, password, last_update from staff", + dest: "staff_copy", + }, } for _, test := range testCases { db, ok := dbs[test.dbName] @@ -474,30 +508,33 @@ func TestCopy(t *testing.T) { continue } - // TODO test copy from a different DB, maybe csvq? - // TODO test copy from same DB + t.Run(test.dbName, func(t *testing.T) { + + // TODO test copy from a different DB, maybe csvq? + // TODO test copy from same DB - for _, q := range test.setupQueries { - _, err := db.DB.Exec(q.query) - if q.check && err != nil { - log.Fatalf("Failed to run setup query `%s`: %v", q.query, err) + for _, q := range test.setupQueries { + _, err := db.DB.Exec(q.query) + if q.check && err != nil { + t.Fatalf("Failed to run setup query `%s`: %v", q.query, err) + } + } + rows, err := pg.DB.Query(test.src) + if err != nil { + t.Fatalf("Could not get rows to copy: %v", err) } - } - rows, err := pg.DB.Query(test.src) - if err != nil { - log.Fatalf("Could not get rows to copy: %v", err) - } - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - var rlen int64 = 1 - n, err := drivers.Copy(ctx, db.URL, nil, nil, rows, test.dest) - if err != nil { - log.Fatalf("Could not copy: %v", err) - } - if n != rlen { - log.Fatalf("Expected to copy %d rows but got %d", rlen, n) - } + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + var rlen int64 = 1 + n, err := drivers.Copy(ctx, db.URL, nil, nil, rows, test.dest) + if err != nil { + t.Fatalf("Could not copy: %v", err) + } + if n != rlen { + t.Fatalf("Expected to copy %d rows but got %d", rlen, n) + } + }) } } diff --git a/drivers/testdata/csvq/.gitignore b/drivers/testdata/csvq/.gitignore new file mode 100644 index 00000000000..8ad2a688ce3 --- /dev/null +++ b/drivers/testdata/csvq/.gitignore @@ -0,0 +1 @@ +*_copy diff --git a/drivers/testdata/csvq/staff.csv b/drivers/testdata/csvq/staff.csv new file mode 100644 index 00000000000..454999dfe7e --- /dev/null +++ b/drivers/testdata/csvq/staff.csv @@ -0,0 +1,2 @@ +first_name,last_name,address_id,email,store_id,active,username,password,last_update +John,Doe,1,john@invalid.com,1,true,jdoe,abc,2024-05-10T08:12:05.46875Z