Skip to content

Commit

Permalink
refactor: replace schema struct with interface to enable feature flag…
Browse files Browse the repository at this point in the history
…ging (#5426)
  • Loading branch information
shekhar-rudder authored Jan 16, 2025
1 parent 3c7106e commit cd4a7e1
Showing 4 changed files with 62 additions and 36 deletions.
2 changes: 1 addition & 1 deletion warehouse/router/upload.go
Original file line number Diff line number Diff line change
@@ -70,7 +70,7 @@ type UploadJob struct {
stagingFileRepo *repo.StagingFiles
loadFilesRepo *repo.LoadFiles
whManager manager.Manager
schemaHandle *schema.Schema
schemaHandle schema.Handler
conf *config.Config
logger logger.Logger
statsFactory stats.Stats
28 changes: 20 additions & 8 deletions warehouse/router/upload_test.go
Original file line number Diff line number Diff line change
@@ -19,10 +19,10 @@ import (

backendconfig "github.com/rudderlabs/rudder-server/backend-config"
"github.com/rudderlabs/rudder-server/services/alerta"
"github.com/rudderlabs/rudder-server/warehouse/integrations/manager"
sqlmiddleware "github.com/rudderlabs/rudder-server/warehouse/integrations/middleware/sqlquerywrapper"
"github.com/rudderlabs/rudder-server/warehouse/integrations/redshift"
"github.com/rudderlabs/rudder-server/warehouse/internal/model"
"github.com/rudderlabs/rudder-server/warehouse/schema"
warehouseutils "github.com/rudderlabs/rudder-server/warehouse/utils"
)

@@ -127,17 +127,31 @@ func TestColumnCountStat(t *testing.T) {
tc := tc

t.Run(tc.name, func(t *testing.T) {
t.Parallel()
conf := config.New()
conf.Set(fmt.Sprintf("Warehouse.%s.columnCountLimit", strings.ToLower(warehouseutils.WHDestNameMap[tc.destinationType])), tc.columnCountLimit)

j := UploadJob{
conf: conf,
upload: model.Upload{
pool, err := dockertest.NewPool("")
require.NoError(t, err)

pgResource, err := postgres.Setup(pool, t)
require.NoError(t, err)

uploadJobFactory := &UploadJobFactory{
logger: logger.NOP,
statsFactory: statsStore,
conf: conf,
db: sqlmiddleware.New(pgResource.DB),
}
whManager, err := manager.New(warehouseutils.POSTGRES, conf, logger.NOP, statsStore)
require.NoError(t, err)
j := uploadJobFactory.NewUploadJob(context.Background(), &model.UploadJob{
Upload: model.Upload{
WorkspaceID: workspaceID,
DestinationID: destinationID,
SourceID: sourceID,
},
warehouse: model.Warehouse{
Warehouse: model.Warehouse{
Type: tc.destinationType,
Destination: backendconfig.DestinationT{
ID: destinationID,
@@ -148,9 +162,7 @@ func TestColumnCountStat(t *testing.T) {
Name: sourceName,
},
},
statsFactory: statsStore,
schemaHandle: &schema.Schema{}, // TODO use constructor
}
}, whManager)
j.schemaHandle.UpdateWarehouseTableSchema(tableName, model.TableSchema{
"test-column-1": "string",
"test-column-2": "string",
50 changes: 32 additions & 18 deletions warehouse/schema/schema.go
Original file line number Diff line number Diff line change
@@ -45,7 +45,21 @@ type fetchSchemaRepo interface {
FetchSchema(ctx context.Context) (model.Schema, error)
}

type Schema struct {
type Handler interface {
SyncRemoteSchema(ctx context.Context, fetchSchemaRepo fetchSchemaRepo, uploadID int64) (bool, error)
IsWarehouseSchemaEmpty() bool
GetTableSchemaInWarehouse(tableName string) model.TableSchema
GetLocalSchema(ctx context.Context) (model.Schema, error)
UpdateLocalSchema(ctx context.Context, updatedSchema model.Schema) error
UpdateWarehouseTableSchema(tableName string, tableSchema model.TableSchema)
GetColumnsCountInWarehouseSchema(tableName string) int
ConsolidateStagingFilesUsingLocalSchema(ctx context.Context, stagingFiles []*model.StagingFile) (model.Schema, error)
UpdateLocalSchemaWithWarehouse(ctx context.Context) error
TableSchemaDiff(tableName string, tableSchema model.TableSchema) whutils.TableSchemaDiff
FetchSchemaFromWarehouse(ctx context.Context, repo fetchSchemaRepo) error
}

type schema struct {
warehouse model.Warehouse
schemaRepo schemaRepo
stagingFileRepo stagingFileRepo
@@ -69,8 +83,8 @@ func New(
conf *config.Config,
logger logger.Logger,
statsFactory stats.Stats,
) *Schema {
s := &Schema{
) Handler {
s := &schema{
warehouse: warehouse,
schemaRepo: repo.NewWHSchemas(db),
stagingFileRepo: repo.NewStagingFiles(db),
@@ -95,7 +109,7 @@ func New(
// 4. Enhances the consolidated schema with discards schema
// 5. Enhances the consolidated schema with ID resolution schema
// 6. Returns the consolidated schema
func (sh *Schema) ConsolidateStagingFilesUsingLocalSchema(ctx context.Context, stagingFiles []*model.StagingFile) (model.Schema, error) {
func (sh *schema) ConsolidateStagingFilesUsingLocalSchema(ctx context.Context, stagingFiles []*model.StagingFile) (model.Schema, error) {
consolidatedSchema := model.Schema{}
batches := lo.Chunk(stagingFiles, sh.stagingFilesSchemaPaginationSize)
for _, batch := range batches {
@@ -244,24 +258,24 @@ func enhanceSchemaWithIDResolution(consolidatedSchema model.Schema, isIDResoluti
return consolidatedSchema
}

func (sh *Schema) isIDResolutionEnabled() bool {
func (sh *schema) isIDResolutionEnabled() bool {
return sh.enableIDResolution && slices.Contains(whutils.IdentityEnabledWarehouses, sh.warehouse.Type)
}

func (sh *Schema) UpdateLocalSchemaWithWarehouse(ctx context.Context) error {
func (sh *schema) UpdateLocalSchemaWithWarehouse(ctx context.Context) error {
sh.schemaInWarehouseMu.RLock()
defer sh.schemaInWarehouseMu.RUnlock()
return sh.updateLocalSchema(ctx, sh.schemaInWarehouse)
}

func (sh *Schema) UpdateLocalSchema(ctx context.Context, updatedSchema model.Schema) error {
func (sh *schema) UpdateLocalSchema(ctx context.Context, updatedSchema model.Schema) error {
return sh.updateLocalSchema(ctx, updatedSchema)
}

// updateLocalSchema
// 1. Inserts the updated schema into the local schema table
// 2. Updates the local schema instance
func (sh *Schema) updateLocalSchema(ctx context.Context, updatedSchema model.Schema) error {
func (sh *schema) updateLocalSchema(ctx context.Context, updatedSchema model.Schema) error {
updatedSchemaInBytes, err := json.Marshal(updatedSchema)
if err != nil {
return fmt.Errorf("marshaling schema: %w", err)
@@ -292,7 +306,7 @@ func (sh *Schema) updateLocalSchema(ctx context.Context, updatedSchema model.Sch
// 3. Initialize local schema
// 4. Updates local schema with warehouse schema if it has changed
// 5. Returns true if schema has changed
func (sh *Schema) SyncRemoteSchema(ctx context.Context, fetchSchemaRepo fetchSchemaRepo, uploadID int64) (bool, error) {
func (sh *schema) SyncRemoteSchema(ctx context.Context, fetchSchemaRepo fetchSchemaRepo, uploadID int64) (bool, error) {
localSchema, err := sh.GetLocalSchema(ctx)
if err != nil {
return false, fmt.Errorf("fetching schema from local: %w", err)
@@ -321,7 +335,7 @@ func (sh *Schema) SyncRemoteSchema(ctx context.Context, fetchSchemaRepo fetchSch
}

// GetLocalSchema returns the local schema from wh_schemas table
func (sh *Schema) GetLocalSchema(ctx context.Context) (model.Schema, error) {
func (sh *schema) GetLocalSchema(ctx context.Context) (model.Schema, error) {
whSchema, err := sh.schemaRepo.GetForNamespace(
ctx,
sh.warehouse.Source.ID,
@@ -341,7 +355,7 @@ func (sh *Schema) GetLocalSchema(ctx context.Context) (model.Schema, error) {
// 1. Fetches schema from warehouse
// 2. Removes deprecated columns from schema
// 3. Updates local warehouse schema and unrecognized schema instance
func (sh *Schema) FetchSchemaFromWarehouse(ctx context.Context, repo fetchSchemaRepo) error {
func (sh *schema) FetchSchemaFromWarehouse(ctx context.Context, repo fetchSchemaRepo) error {
warehouseSchema, err := repo.FetchSchema(ctx)
if err != nil {
return fmt.Errorf("fetching schema: %w", err)
@@ -356,7 +370,7 @@ func (sh *Schema) FetchSchemaFromWarehouse(ctx context.Context, repo fetchSchema
}

// removeDeprecatedColumns skips deprecated columns from the schema map
func (sh *Schema) removeDeprecatedColumns(schema model.Schema) {
func (sh *schema) removeDeprecatedColumns(schema model.Schema) {
for tableName, columnMap := range schema {
for columnName := range columnMap {
if deprecatedColumnsRegex.MatchString(columnName) {
@@ -376,12 +390,12 @@ func (sh *Schema) removeDeprecatedColumns(schema model.Schema) {
}

// hasSchemaChanged compares the localSchema with the schemaInWarehouse
func (sh *Schema) hasSchemaChanged(localSchema model.Schema) bool {
func (sh *schema) hasSchemaChanged(localSchema model.Schema) bool {
return !reflect.DeepEqual(localSchema, sh.schemaInWarehouse)
}

// TableSchemaDiff returns the diff between the warehouse schema and the upload schema
func (sh *Schema) TableSchemaDiff(tableName string, tableSchema model.TableSchema) whutils.TableSchemaDiff {
func (sh *schema) TableSchemaDiff(tableName string, tableSchema model.TableSchema) whutils.TableSchemaDiff {
diff := whutils.TableSchemaDiff{
ColumnMap: make(model.TableSchema),
UpdatedSchema: make(model.TableSchema),
@@ -422,13 +436,13 @@ func (sh *Schema) TableSchemaDiff(tableName string, tableSchema model.TableSchem
return diff
}

func (sh *Schema) GetTableSchemaInWarehouse(tableName string) model.TableSchema {
func (sh *schema) GetTableSchemaInWarehouse(tableName string) model.TableSchema {
sh.schemaInWarehouseMu.RLock()
defer sh.schemaInWarehouseMu.RUnlock()
return sh.schemaInWarehouse[tableName]
}

func (sh *Schema) UpdateWarehouseTableSchema(tableName string, tableSchema model.TableSchema) {
func (sh *schema) UpdateWarehouseTableSchema(tableName string, tableSchema model.TableSchema) {
sh.schemaInWarehouseMu.Lock()
defer sh.schemaInWarehouseMu.Unlock()
if sh.schemaInWarehouse == nil {
@@ -437,13 +451,13 @@ func (sh *Schema) UpdateWarehouseTableSchema(tableName string, tableSchema model
sh.schemaInWarehouse[tableName] = tableSchema
}

func (sh *Schema) IsWarehouseSchemaEmpty() bool {
func (sh *schema) IsWarehouseSchemaEmpty() bool {
sh.schemaInWarehouseMu.RLock()
defer sh.schemaInWarehouseMu.RUnlock()
return len(sh.schemaInWarehouse) == 0
}

func (sh *Schema) GetColumnsCountInWarehouseSchema(tableName string) int {
func (sh *schema) GetColumnsCountInWarehouseSchema(tableName string) int {
sh.schemaInWarehouseMu.RLock()
defer sh.schemaInWarehouseMu.RUnlock()
return len(sh.schemaInWarehouse[tableName])
18 changes: 9 additions & 9 deletions warehouse/schema/schema_test.go
Original file line number Diff line number Diff line change
@@ -154,7 +154,7 @@ func TestSchema_UpdateLocalSchema(t *testing.T) {
statsStore, err := memstats.New()
require.NoError(t, err)

s := Schema{
s := schema{
warehouse: model.Warehouse{
WorkspaceID: workspaceID,
Source: backendconfig.SourceT{
@@ -352,7 +352,7 @@ func TestSchema_FetchSchemaFromWarehouse(t *testing.T) {
err: tc.mockErr,
}

s := &Schema{
s := &schema{
warehouse: model.Warehouse{
Source: backendconfig.SourceT{
ID: sourceID,
@@ -511,7 +511,7 @@ func TestSchema_TableSchemaDiff(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()

s := Schema{
s := schema{
schemaInWarehouse: tc.currentSchema,
}
diff := s.TableSchemaDiff(tc.tableName, tc.uploadTableSchema)
@@ -592,7 +592,7 @@ func TestSchema_HasLocalSchemaChanged(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()

s := &Schema{
s := &schema{
warehouse: model.Warehouse{
Type: warehouseutils.SNOWFLAKE,
},
@@ -1625,7 +1625,7 @@ func TestSchema_ConsolidateStagingFilesUsingLocalSchema(t *testing.T) {
err: tc.mockErr,
}

s := &Schema{
s := &schema{
warehouse: model.Warehouse{
Source: backendconfig.SourceT{
ID: sourceID,
@@ -1668,7 +1668,7 @@ func TestSchema_SyncRemoteSchema(t *testing.T) {
tableName := "test_table_name"

t.Run("should return error if unable to fetch local schema", func(t *testing.T) {
s := &Schema{
s := &schema{
warehouse: model.Warehouse{
Source: backendconfig.SourceT{
ID: sourceID,
@@ -1697,7 +1697,7 @@ func TestSchema_SyncRemoteSchema(t *testing.T) {
require.False(t, schemaChanged)
})
t.Run("should return error if unable to fetch remote schema", func(t *testing.T) {
s := &Schema{
s := &schema{
warehouse: model.Warehouse{
Source: backendconfig.SourceT{
ID: sourceID,
@@ -1766,7 +1766,7 @@ func TestSchema_SyncRemoteSchema(t *testing.T) {
},
}

s := &Schema{
s := &schema{
warehouse: model.Warehouse{
Source: backendconfig.SourceT{
ID: sourceID,
@@ -1835,7 +1835,7 @@ func TestSchema_SyncRemoteSchema(t *testing.T) {
},
}

s := &Schema{
s := &schema{
warehouse: model.Warehouse{
Source: backendconfig.SourceT{
ID: sourceID,

0 comments on commit cd4a7e1

Please sign in to comment.