diff --git a/pkg/datasource/sql/datasource/base/meta_cache.go b/pkg/datasource/sql/datasource/base/meta_cache.go index c4694d8b..2ee9b463 100644 --- a/pkg/datasource/sql/datasource/base/meta_cache.go +++ b/pkg/datasource/sql/datasource/base/meta_cache.go @@ -21,6 +21,7 @@ import ( "context" "database/sql" "fmt" + "strings" "sync" "time" @@ -110,8 +111,9 @@ func (c *BaseTableMetaCache) refresh(ctx context.Context) { for i := range v { tm := v[i] - if _, ok := c.cache[tm.TableName]; ok { - c.cache[tm.TableName] = &entry{ + upperTableName := strings.ToUpper(tm.TableName) + if _, ok := c.cache[upperTableName]; ok { + c.cache[upperTableName] = &entry{ value: tm, } } @@ -157,16 +159,16 @@ func (c *BaseTableMetaCache) GetTableMeta(ctx context.Context, dbName, tableName defer c.lock.Unlock() defer conn.Close() - - v, ok := c.cache[tableName] + upperTableName := strings.ToUpper(tableName) + v, ok := c.cache[upperTableName] if !ok { - meta, err := c.trigger.LoadOne(ctx, dbName, tableName, conn) + meta, err := c.trigger.LoadOne(ctx, dbName, upperTableName, conn) if err != nil { return types.TableMeta{}, err } if meta != nil && !meta.IsEmpty() { - c.cache[tableName] = &entry{ + c.cache[upperTableName] = &entry{ value: *meta, lastAccess: time.Now(), } @@ -178,7 +180,7 @@ func (c *BaseTableMetaCache) GetTableMeta(ctx context.Context, dbName, tableName } v.lastAccess = time.Now() - c.cache[tableName] = v + c.cache[upperTableName] = v return v.value, nil } diff --git a/pkg/datasource/sql/datasource/base/meta_cache_test.go b/pkg/datasource/sql/datasource/base/meta_cache_test.go index 570674e1..134993b8 100644 --- a/pkg/datasource/sql/datasource/base/meta_cache_test.go +++ b/pkg/datasource/sql/datasource/base/meta_cache_test.go @@ -24,6 +24,7 @@ import ( "testing" "time" + "github.com/DATA-DOG/go-sqlmock" "github.com/agiledragon/gomonkey/v2" "github.com/go-sql-driver/mysql" "github.com/stretchr/testify/assert" @@ -40,8 +41,30 @@ var ( type mockTrigger struct { } +// LoadOne simulates loading table metadata, including id, name, and age columns. func (m *mockTrigger) LoadOne(ctx context.Context, dbName string, table string, conn *sql.Conn) (*types.TableMeta, error) { - return nil, nil + + return &types.TableMeta{ + TableName: table, + Columns: map[string]types.ColumnMeta{ + "id": {ColumnName: "id"}, + "name": {ColumnName: "name"}, + "age": {ColumnName: "age"}, + }, + Indexs: map[string]types.IndexMeta{ + "id": { + Name: "PRIMARY", + IType: types.IndexTypePrimaryKey, + Columns: []types.ColumnMeta{{ColumnName: "id"}}, + }, + "id_name_age": { + Name: "name_age_idx", + IType: types.IndexUnique, + Columns: []types.ColumnMeta{{ColumnName: "name"}, {ColumnName: "age"}}, + }, + }, + ColumnNames: []string{"id", "name", "age"}, + }, nil } func (m *mockTrigger) LoadAll(ctx context.Context, dbName string, conn *sql.Conn, tables ...string) ([]types.TableMeta, error) { @@ -77,7 +100,7 @@ func TestBaseTableMetaCache_refresh(t *testing.T) { size: 0, expireDuration: EexpireTime, cache: map[string]*entry{ - "test": { + "TEST": { value: types.TableMeta{}, lastAccess: time.Now(), }, @@ -120,7 +143,116 @@ func TestBaseTableMetaCache_refresh(t *testing.T) { go c.refresh(tt.args.ctx) time.Sleep(time.Second * 3) - assert.Equal(t, c.cache["test"].value, tt.want) + assert.Equal(t, c.cache["TEST"].value, tt.want) + }) + } +} + +func TestBaseTableMetaCache_GetTableMeta(t *testing.T) { + var ( + tableMeta1 types.TableMeta + tableMeta2 types.TableMeta + columns = make(map[string]types.ColumnMeta) + index = make(map[string]types.IndexMeta) + index2 = make(map[string]types.IndexMeta) + columnMeta1 []types.ColumnMeta + columnMeta2 []types.ColumnMeta + ColumnNames []string + ) + columnId := types.ColumnMeta{ + ColumnDef: nil, + ColumnName: "id", + } + columnName := types.ColumnMeta{ + ColumnDef: nil, + ColumnName: "name", + } + columnAge := types.ColumnMeta{ + ColumnDef: nil, + ColumnName: "age", + } + columns["id"] = columnId + columns["name"] = columnName + columns["age"] = columnAge + columnMeta1 = append(columnMeta1, columnId) + columnMeta2 = append(columnMeta2, columnName, columnAge) + index["id"] = types.IndexMeta{ + Name: "PRIMARY", + IType: types.IndexTypePrimaryKey, + Columns: columnMeta1, + } + index["id_name_age"] = types.IndexMeta{ + Name: "name_age_idx", + IType: types.IndexUnique, + Columns: columnMeta2, + } + + ColumnNames = []string{"id", "name", "age"} + tableMeta1 = types.TableMeta{ + TableName: "T_USER1", + Columns: columns, + Indexs: index, + ColumnNames: ColumnNames, + } + + index2["id_name_age"] = types.IndexMeta{ + Name: "name_age_idx", + IType: types.IndexUnique, + Columns: columnMeta2, + } + + tableMeta2 = types.TableMeta{ + TableName: "T_USER2", + Columns: columns, + Indexs: index2, + ColumnNames: ColumnNames, + } + tests := []types.TableMeta{tableMeta1, tableMeta2} + // Use sqlmock to simulate a database connection + db, mock, err := sqlmock.New() + if err != nil { + t.Fatalf("Failed to create sqlmock: %v", err) + } + defer db.Close() + for _, tt := range tests { + t.Run(tt.TableName, func(t *testing.T) { + mockTrigger := &mockTrigger{} + // Mock a query response + mock.ExpectQuery("SELECT").WillReturnRows(sqlmock.NewRows([]string{"id", "name", "age"})) + // Create a mock database connection + conn, err := db.Conn(context.Background()) + if err != nil { + t.Fatalf("Failed to get connection: %v", err) + } + defer conn.Close() + cache := &BaseTableMetaCache{ + trigger: mockTrigger, + cache: map[string]*entry{ + "T_USER": { + value: tableMeta2, + lastAccess: time.Now(), + }, + "T_USER1": { + value: tableMeta1, + lastAccess: time.Now(), + }, + }, + lock: sync.RWMutex{}, + } + + meta, _ := cache.GetTableMeta(context.Background(), "db", tt.TableName, conn) + + if meta.TableName != tt.TableName { + t.Errorf("GetTableMeta() got TableName = %v, want %v", meta.TableName, tt.TableName) + } + // Ensure the retrieved table is cached + cache.lock.RLock() + _, cached := cache.cache[tt.TableName] + cache.lock.RUnlock() + + if !cached { + t.Errorf("GetTableMeta() got TableName = %v, want %v", meta.TableName, tt.TableName) + } }) } }