Skip to content

Commit

Permalink
fix: primary key length
Browse files Browse the repository at this point in the history
  • Loading branch information
Seann-Moser committed Nov 14, 2024
1 parent 5826cda commit 3ad720b
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 4 deletions.
6 changes: 4 additions & 2 deletions column.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,8 @@ func GetAllNumbersAsInt(input string) ([]int, error) {
// Calculates the byte length of the column based on charset and type
func (c *Column) GetByteLength() int {
// Handle variable-length types like CHAR and VARCHAR with overrides
if strings.HasPrefix(c.Type, "CHAR") || strings.HasPrefix(c.Type, "VARCHAR") {
t := strings.ToLower(c.Type)
if strings.HasPrefix(t, "char") || strings.HasPrefix(t, "varchar") {
// Multiply length by charset bytes for CHAR and VARCHAR
l, _ := GetAllNumbersAsInt(c.Type)
baseLength := 256
Expand Down Expand Up @@ -121,7 +122,8 @@ func (col *Column) GetDefinition() string {
definition += " AUTO_INCREMENT"
}
// Add charset if specified and relevant to the type (e.g., CHAR, VARCHAR, TEXT)
if col.Charset != "" && (strings.HasPrefix(col.Type, "CHAR") || strings.HasPrefix(col.Type, "VARCHAR") || col.Type == "TEXT" || col.Type == "JSON") {
t := strings.ToLower(col.Type)
if col.Charset != "" && (strings.HasPrefix(t, "char") || strings.HasPrefix(t, "varchar") || t == "text" || t == "json") {
definition += fmt.Sprintf(" CHARACTER SET %s", col.Charset)
}
return definition
Expand Down
15 changes: 14 additions & 1 deletion query_builder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import (
)

type Resource struct {
ID string `json:"id" db:"id" qc:"primary;join;join_name::resource_id;group_by_modifier::count"` // ID ("resource.*")
ID string `json:"id" db:"id" qc:"primary;join;join_name::resource_id;data_type::varchar(1024);charset::utf8"` // ID ("resource.*")
Description string `json:"description" db:"description" qc:"data_type::varchar(512);update"`
ResourceType string `json:"resource_type" db:"resource_type" qc:"update"` // ResourceType "url"
Data string `json:"data" db:"data" qc:"update;text"`
Expand All @@ -18,6 +18,19 @@ type Resource struct {
CreatedTimestamp string `json:"created_timestamp" db:"created_timestamp" qc:"skip;default::created_timestamp"`
}

func TestSqlDB_CreateTable(t *testing.T) {
table, err := NewTable[Resource]("test", QueryTypeSQL)
if err != nil {
t.Fatal(err)
}
db := SqlDB{}
schema, tableCreate, err := db.BuildCreateTableQueries("test", "resource", table.Columns)
if err != nil {
t.Fatal(err)
}
println(schema)
println(tableCreate)
}
func TestQuery_Build(t *testing.T) {
table, err := NewTable[Resource]("test", QueryTypeSQL)
if err != nil {
Expand Down
71 changes: 70 additions & 1 deletion sql.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"go.uber.org/zap"
"os"
"sort"
"strconv"
"strings"
"time"
)
Expand Down Expand Up @@ -69,6 +70,7 @@ func (s *SqlDB) BuildCreateTableQueries(dataset, table string, columns map[strin
sort.Slice(cols, func(i, j int) bool {
return cols[i].ColumnOrder < cols[j].ColumnOrder
})
var primaryKeyColumns []Column

// Build column definitions
for _, column := range cols {
Expand All @@ -83,9 +85,12 @@ func (s *SqlDB) BuildCreateTableQueries(dataset, table string, columns map[strin
}
if column.Primary {
primaryKeys = append(primaryKeys, column.Name)
primaryKeyColumns = append(primaryKeyColumns, column)
}
}

if _, err := s.CheckPrimaryKeyLength(primaryKeyColumns); err != nil {
return "", "", err
}
// Handle primary keys
if len(primaryKeys) == 0 {
return "", "", MissingPrimaryKeyErr
Expand Down Expand Up @@ -362,13 +367,44 @@ func (s *SqlDB) GetTableIndexes(database, tableName string) ([]IndexInfo, error)
return indexes, nil
}
func (s *SqlDB) Version() string {
if s.sql == nil {
return "8.0.40"
}
v, err := GetMySQLVersion(s.sql)
if err != nil {
return "unknown"
}
return v
}

const defaultMaxPrimaryKeyLength = 767

// CheckPrimaryKeyLength checks if the combined byte length of primary key columns exceeds the limit
func (s *SqlDB) CheckPrimaryKeyLength(columns []Column) (bool, error) {
// Get MySQL version to adjust max primary key length if needed
version := s.Version()
maxPrimaryKeyLength := defaultMaxPrimaryKeyLength

// Example check for newer MySQL versions (adjust as needed for version-specific handling)
if CompareVersions(version, "8.0.17") > 0 {
maxPrimaryKeyLength = 3072 // Increased max length for MySQL 8.0.17+ with InnoDB and utf8mb4
}

// Calculate total byte length of the primary key columns
totalLength := 0
for _, col := range columns {
if col.Primary {
totalLength += col.GetByteLength()
}
}

// Check if total length exceeds the maximum allowed length
if totalLength > maxPrimaryKeyLength {
return true, fmt.Errorf("primary key length exceeds the maximum allowed length of %d bytes", maxPrimaryKeyLength)
}
return false, nil
}

// GetMySQLVersion retrieves the MySQL version from the database.
func GetMySQLVersion(db *sqlx.DB) (string, error) {
var version string
Expand All @@ -378,3 +414,36 @@ func GetMySQLVersion(db *sqlx.DB) (string, error) {
}
return version, nil
}

func CompareVersions(version1, version2 string) int {
v1Parts := strings.Split(version1, ".")
v2Parts := strings.Split(version2, ".")

// Compare each part numerically
maxParts := len(v1Parts)
if len(v2Parts) > maxParts {
maxParts = len(v2Parts)
}

for i := 0; i < maxParts; i++ {
var v1, v2 int

// Convert the current part to an integer or assume 0 if part is missing
if i < len(v1Parts) {
v1, _ = strconv.Atoi(v1Parts[i])
}
if i < len(v2Parts) {
v2, _ = strconv.Atoi(v2Parts[i])
}

// Compare the individual parts
if v1 > v2 {
return 1
} else if v1 < v2 {
return -1
}
}

// Versions are equal
return 0
}

0 comments on commit 3ad720b

Please sign in to comment.