Skip to content

Commit

Permalink
Refactor NewStandardSQLTable to input columns and add test cases for …
Browse files Browse the repository at this point in the history
…sql_table
  • Loading branch information
Racso-3141 committed Oct 8, 2023
1 parent 8e36e6e commit d3adfb6
Show file tree
Hide file tree
Showing 5 changed files with 169 additions and 36 deletions.
6 changes: 3 additions & 3 deletions internal/stackql/datasource/sql_datasource/generic.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import (
_ "github.com/snowflakedb/gosnowflake" //nolint:revive,nolintlint // this is a DB driver pattern

"github.com/stackql/stackql/internal/stackql/constants"
"github.com/stackql/stackql/internal/stackql/datasource/sql_table"
"github.com/stackql/stackql/internal/stackql/datasource/sqltable"
"github.com/stackql/stackql/internal/stackql/db_util"
"github.com/stackql/stackql/internal/stackql/dto"
)
Expand Down Expand Up @@ -67,11 +67,11 @@ func (ds *genericSQLDataSource) Begin() (*sql.Tx, error) {
return ds.db.Begin()
}

func (ds *genericSQLDataSource) GetTableMetadata(args ...string) (sql_table.SQLTable, error) {
func (ds *genericSQLDataSource) GetTableMetadata(args ...string) (sqltable.SQLTable, error) {
return nil, fmt.Errorf("could not obtain sql data source table metadata for args = '%v'", args)
}

// func (ds *genericSQLDataSource) getPostgresTableMetadata(schemaName, tableName string) (sql_table.SQLTable, error) {
// func (ds *genericSQLDataSource) getPostgresTableMetadata(schemaName, tableName string) (sqltable.SQLTable, error) {
// queryTemplate := `
// SELECT
// column_name,
Expand Down
4 changes: 2 additions & 2 deletions internal/stackql/datasource/sql_datasource/sql_datasource.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import (
"fmt"

"github.com/stackql/stackql/internal/stackql/constants"
"github.com/stackql/stackql/internal/stackql/datasource/sql_table"
"github.com/stackql/stackql/internal/stackql/datasource/sqltable"
"github.com/stackql/stackql/internal/stackql/dto"
)

Expand All @@ -14,7 +14,7 @@ type SQLDataSource interface {
Exec(string, ...interface{}) (sql.Result, error)
Query(string, ...interface{}) (*sql.Rows, error)
QueryRow(string, ...any) *sql.Row
GetTableMetadata(...string) (sql_table.SQLTable, error)
GetTableMetadata(...string) (sqltable.SQLTable, error)
GetSchemaType() string
GetDBName() string
}
Expand Down
31 changes: 0 additions & 31 deletions internal/stackql/datasource/sql_table/sql_table.go

This file was deleted.

34 changes: 34 additions & 0 deletions internal/stackql/datasource/sqltable/sqltable.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
package sqltable

import (
"github.com/stackql/stackql/internal/stackql/symtab"
"github.com/stackql/stackql/internal/stackql/typing"
)

type SQLTable interface {
GetColumns() []typing.RelationalColumn
GetSymTab() symtab.SymTab
}

type StandardSQLTable struct {
symTab symtab.SymTab
columns []typing.RelationalColumn
}

func NewStandardSQLTable(relationalColumns []typing.RelationalColumn) (SQLTable, error) {
copiedSlice := make([]typing.RelationalColumn, len(relationalColumns))
copy(copiedSlice, relationalColumns)
rv := &StandardSQLTable{
symTab: symtab.NewHashMapTreeSymTab(),
columns: copiedSlice,
}
return rv, nil
}

func (sqt *StandardSQLTable) GetSymTab() symtab.SymTab {
return sqt.symTab
}

func (sqt *StandardSQLTable) GetColumns() []typing.RelationalColumn {
return sqt.columns
}
130 changes: 130 additions & 0 deletions internal/stackql/datasource/sqltable/sqltable_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
package sqltable_test

import (
"math/rand"
"reflect"
"testing"
"time"

"github.com/stackql/stackql/internal/stackql/datasource/sqltable"
"github.com/stackql/stackql/internal/stackql/symtab"
"github.com/stackql/stackql/internal/stackql/typing"
)

func randString() string {
r := rand.New(rand.NewSource(time.Now().UnixNano()))
const letters = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_"
maxLength := 256
stringLength := r.Intn(maxLength + 1) // Randomly decide the length of the string
s := make([]byte, stringLength)
for i := range s {
s[i] = letters[r.Intn(len(letters))]
}
return string(s)
}

func generateRandomColumns(n int) []typing.RelationalColumn {
columns := make([]typing.RelationalColumn, n)
for i := range columns {
// Assuming RelationalColumn is a type like string for simplicity
columns[i] = typing.NewRelationalColumn(randString(), randString()) // Generate a random string of length 10
}
return columns
}

func TestNewStandardSQLTable(t *testing.T) {
table, err := sqltable.NewStandardSQLTable(generateRandomColumns(10))
if err != nil {
t.Fatalf("Expected no error, but got: %v", err)
}

if table == nil {
t.Fatal("Expected table to be non-nil")
}

_, ok := table.(*sqltable.StandardSQLTable)
if !ok {
t.Fatal("Expected table to be of type *standardSQLTable")
}
}

func TestGetSymTab(t *testing.T) {
columns := generateRandomColumns(10)
table, _ := sqltable.NewStandardSQLTable(columns)

// Initialize symTab with different values
symTab := table.GetSymTab()

// Set symbols in the symTab
err := symTab.SetSymbol("testKey", symtab.NewSymTabEntry("testType", "testData", "testIn"))
if err != nil {
t.Fatalf("Failed to set symbol: %v", err)
}

// Test if the symbol was set correctly
entry, exists := symTab.GetSymbol("testKey")
if exists != nil {
t.Fatalf("Symbol not found in symTab")
}
if !reflect.DeepEqual(entry, symtab.NewSymTabEntry("testType", "testData", "testIn")) {
t.Fatalf("Symbol not set correctly in symTab")
}

// Create a new leaf and set symbols in it
leafSymTab, err := symTab.NewLeaf("testLeafKey")
if err != nil {
t.Fatalf("Failed to create new leaf: %v", err)
}
err = leafSymTab.SetSymbol("leafKey", symtab.NewSymTabEntry("leafType", "leafData", "leafIn"))
if err != nil {
t.Fatalf("Failed to set symbol in leaf: %v", err)
}

// Test if the symbol was set correctly in the leaf
entry, exists = leafSymTab.GetSymbol("leafKey")
if exists != nil {
t.Fatalf("Symbol not found in leafSymTab")
}
if !reflect.DeepEqual(entry, symtab.NewSymTabEntry("leafType", "leafData", "leafIn")) {
t.Fatalf("Symbol not set correctly in leafSymTab")
}
}

func TestGetColumns(t *testing.T) {
testCases := []struct {
name string
numColumns int
}{
{
name: "Test with 0 columns",
numColumns: 0,
},
{
name: "Test with 5 columns",
numColumns: 5,
},
{
name: "Test with 10 columns",
numColumns: 10,
},
{
name: "Test with 15 columns",
numColumns: 15,
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
inputColumns := generateRandomColumns(tc.numColumns)
table, err := sqltable.NewStandardSQLTable(inputColumns)
if err != nil {
t.Fatalf("Expected no error, but got: %v", err)
}

returnedColumns := table.GetColumns()
if !reflect.DeepEqual(returnedColumns, inputColumns) {
t.Fatalf("Expected columns %v, but got %v", inputColumns, returnedColumns)
}
})
}
}

0 comments on commit d3adfb6

Please sign in to comment.