| @@ -11,6 +11,7 @@ require ( | |||
| github.com/dubbogo/gost v1.12.6-0.20220824084206-300e27e9e524 | |||
| github.com/gin-gonic/gin v1.8.0 | |||
| github.com/go-sql-driver/mysql v1.6.0 | |||
| github.com/goccy/go-json v0.9.7 | |||
| github.com/golang/mock v1.6.0 | |||
| github.com/google/uuid v1.3.0 | |||
| github.com/mitchellh/copystructure v1.2.0 | |||
| @@ -67,7 +68,6 @@ require ( | |||
| github.com/go-playground/locales v0.14.0 // indirect | |||
| github.com/go-playground/universal-translator v0.18.0 // indirect | |||
| github.com/go-resty/resty/v2 v2.7.0 // indirect | |||
| github.com/goccy/go-json v0.9.7 // indirect | |||
| github.com/gogo/protobuf v1.3.2 // indirect | |||
| github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect | |||
| github.com/golang/protobuf v1.5.2 // indirect | |||
| @@ -0,0 +1,114 @@ | |||
| /* | |||
| * Licensed to the Apache Software Foundation (ASF) under one or more | |||
| * contributor license agreements. See the NOTICE file distributed with | |||
| * this work for additional information regarding copyright ownership. | |||
| * The ASF licenses this file to You under the Apache License, Version 2.0 | |||
| * (the "License"); you may not use this file except in compliance with | |||
| * the License. You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| package datasource | |||
| import ( | |||
| "database/sql" | |||
| "reflect" | |||
| ) | |||
| type nullTime = sql.NullTime | |||
| var ( | |||
| ScanTypeFloat32 = reflect.TypeOf(float32(0)) | |||
| ScanTypeFloat64 = reflect.TypeOf(float64(0)) | |||
| ScanTypeInt8 = reflect.TypeOf(int8(0)) | |||
| ScanTypeInt16 = reflect.TypeOf(int16(0)) | |||
| ScanTypeInt32 = reflect.TypeOf(int32(0)) | |||
| ScanTypeInt64 = reflect.TypeOf(int64(0)) | |||
| ScanTypeNullFloat = reflect.TypeOf(sql.NullFloat64{}) | |||
| ScanTypeNullInt = reflect.TypeOf(sql.NullInt64{}) | |||
| ScanTypeNullTime = reflect.TypeOf(nullTime{}) | |||
| ScanTypeUint8 = reflect.TypeOf(uint8(0)) | |||
| ScanTypeUint16 = reflect.TypeOf(uint16(0)) | |||
| ScanTypeUint32 = reflect.TypeOf(uint32(0)) | |||
| ScanTypeUint64 = reflect.TypeOf(uint64(0)) | |||
| ScanTypeRawBytes = reflect.TypeOf(sql.RawBytes{}) | |||
| ScanTypeUnknown = reflect.TypeOf(new(interface{})) | |||
| ) | |||
| func GetScanSlice(types []*sql.ColumnType) []interface{} { | |||
| scanSlice := make([]interface{}, 0, len(types)) | |||
| for _, tpy := range types { | |||
| switch tpy.ScanType() { | |||
| case ScanTypeFloat32: | |||
| scanVal := float32(0) | |||
| scanSlice = append(scanSlice, &scanVal) | |||
| case ScanTypeFloat64: | |||
| scanVal := float64(0) | |||
| scanSlice = append(scanSlice, &scanVal) | |||
| case ScanTypeInt8: | |||
| scanVal := int8(0) | |||
| scanSlice = append(scanSlice, &scanVal) | |||
| case ScanTypeInt16: | |||
| scanVal := int16(0) | |||
| scanSlice = append(scanSlice, &scanVal) | |||
| case ScanTypeInt32: | |||
| scanVal := int32(0) | |||
| scanSlice = append(scanSlice, &scanVal) | |||
| case ScanTypeInt64: | |||
| scanVal := int64(0) | |||
| scanSlice = append(scanSlice, &scanVal) | |||
| case ScanTypeNullFloat: | |||
| scanVal := sql.NullFloat64{} | |||
| scanSlice = append(scanSlice, &scanVal) | |||
| case ScanTypeNullInt: | |||
| scanVal := sql.NullInt64{} | |||
| scanSlice = append(scanSlice, &scanVal) | |||
| case ScanTypeNullTime: | |||
| scanVal := sql.NullTime{} | |||
| scanSlice = append(scanSlice, &scanVal) | |||
| case ScanTypeUint8: | |||
| scanVal := uint8(0) | |||
| scanSlice = append(scanSlice, &scanVal) | |||
| case ScanTypeUint16: | |||
| scanVal := uint16(0) | |||
| scanSlice = append(scanSlice, &scanVal) | |||
| case ScanTypeUint32: | |||
| scanVal := uint32(0) | |||
| scanSlice = append(scanSlice, &scanVal) | |||
| case ScanTypeUint64: | |||
| scanVal := uint64(0) | |||
| scanSlice = append(scanSlice, &scanVal) | |||
| case ScanTypeRawBytes: | |||
| scanVal := "" | |||
| scanSlice = append(scanSlice, &scanVal) | |||
| case ScanTypeUnknown: | |||
| scanVal := new(interface{}) | |||
| scanSlice = append(scanSlice, &scanVal) | |||
| } | |||
| } | |||
| return scanSlice | |||
| } | |||
| func DeepEqual(x, y interface{}) bool { | |||
| typx := reflect.ValueOf(x) | |||
| typy := reflect.ValueOf(y) | |||
| switch typx.Kind() { | |||
| case reflect.Ptr: | |||
| typx = typx.Elem() | |||
| } | |||
| switch typy.Kind() { | |||
| case reflect.Ptr: | |||
| typy = typy.Elem() | |||
| } | |||
| return reflect.DeepEqual(typx.Interface(), typy.Interface()) | |||
| } | |||
| @@ -20,9 +20,10 @@ package sql | |||
| import ( | |||
| "context" | |||
| "database/sql/driver" | |||
| "errors" | |||
| "sync" | |||
| "github.com/pkg/errors" | |||
| "github.com/seata/seata-go/pkg/datasource/sql/datasource" | |||
| "github.com/seata/seata-go/pkg/protocol/branch" | |||
| "github.com/seata/seata-go/pkg/rm" | |||
| @@ -21,6 +21,7 @@ import ( | |||
| "encoding/base64" | |||
| "encoding/json" | |||
| "fmt" | |||
| "reflect" | |||
| "time" | |||
| ) | |||
| @@ -93,6 +94,19 @@ func (rs RecordImages) Reserve() { | |||
| } | |||
| } | |||
| func (rs RecordImages) IsEmptyImage() bool { | |||
| if len(rs) == 0 { | |||
| return true | |||
| } | |||
| for _, r := range rs { | |||
| if r == nil || len(r.Rows) == 0 { | |||
| continue | |||
| } | |||
| return false | |||
| } | |||
| return true | |||
| } | |||
| // RecordImage | |||
| type RecordImage struct { | |||
| // index | |||
| @@ -251,3 +265,16 @@ func (c *ColumnImage) UnmarshalJSON(data []byte) error { | |||
| func getTypeStr(src interface{}) string { | |||
| return fmt.Sprintf("%T", src) | |||
| } | |||
| func (c *ColumnImage) GetActualValue() interface{} { | |||
| if c.Value == nil { | |||
| return nil | |||
| } | |||
| value := reflect.ValueOf(c.Value) | |||
| kind := reflect.TypeOf(c.Value).Kind() | |||
| switch kind { | |||
| case reflect.Ptr: | |||
| return value.Elem().Interface() | |||
| } | |||
| return c.Value | |||
| } | |||
| @@ -162,6 +162,10 @@ func (m *BaseUndoLogManager) FlushUndoLog(tranCtx *types.TransactionContext, con | |||
| beforeImages := tranCtx.RoundImages.BeofreImages() | |||
| afterImages := tranCtx.RoundImages.AfterImages() | |||
| if beforeImages.IsEmptyImage() && afterImages.IsEmptyImage() { | |||
| return nil | |||
| } | |||
| for i := 0; i < len(beforeImages); i++ { | |||
| var ( | |||
| tableName string | |||
| @@ -474,7 +478,7 @@ func (m *BaseUndoLogManager) DecodeMap(str string) map[string]string { | |||
| // getRollbackInfo parser rollback info | |||
| func (m *BaseUndoLogManager) getRollbackInfo(rollbackInfo []byte, undoContext map[string]string) []byte { | |||
| // Todo 目前 insert undo log 未实现压缩功能,实现后补齐这块功能 | |||
| // Todo use compressor | |||
| // get compress type | |||
| /*compressorType, ok := undoContext[constant.CompressorTypeKey] | |||
| if ok { | |||
| @@ -35,6 +35,7 @@ import ( | |||
| type BasicUndoLogBuilder struct{} | |||
| // GetScanSlice get the column type for scann | |||
| // todo to use ColumnInfo get slice | |||
| func (*BasicUndoLogBuilder) GetScanSlice(columnNames []string, tableMeta *types.TableMeta) []driver.Value { | |||
| scanSlice := make([]driver.Value, 0, len(columnNames)) | |||
| for _, columnNmae := range columnNames { | |||
| @@ -20,7 +20,11 @@ package executor | |||
| import ( | |||
| "context" | |||
| "database/sql" | |||
| "fmt" | |||
| "strings" | |||
| "github.com/goccy/go-json" | |||
| "github.com/seata/seata-go/pkg/datasource/sql/datasource" | |||
| "github.com/seata/seata-go/pkg/datasource/sql/types" | |||
| "github.com/seata/seata-go/pkg/datasource/sql/undo" | |||
| "github.com/seata/seata-go/pkg/util/log" | |||
| @@ -29,7 +33,8 @@ import ( | |||
| var _ undo.UndoExecutor = (*BaseExecutor)(nil) | |||
| const ( | |||
| selectSQL = "SELECT * FROM %s WHERE %s FOR UPDATE" | |||
| checkSQLTemplate = "SELECT * FROM %s WHERE %s FOR UPDATE" | |||
| maxInSize = 1000 | |||
| ) | |||
| type BaseExecutor struct { | |||
| @@ -48,33 +53,130 @@ func (b *BaseExecutor) UndoPrepare(undoPST *sql.Stmt, undoValues []types.ColumnI | |||
| } | |||
| func (b *BaseExecutor) dataValidationAndGoOn(conn *sql.Conn) (bool, error) { | |||
| func (b *BaseExecutor) dataValidationAndGoOn(ctx context.Context, conn *sql.Conn) (bool, error) { | |||
| beforeImage := b.sqlUndoLog.BeforeImage | |||
| afterImage := b.sqlUndoLog.AfterImage | |||
| equal, err := IsRecordsEquals(beforeImage, afterImage) | |||
| equals, err := IsRecordsEquals(beforeImage, afterImage) | |||
| if err != nil { | |||
| return false, err | |||
| } | |||
| if equal { | |||
| if equals { | |||
| log.Infof("Stop rollback because there is no data change between the before data snapshot and the after data snapshot.") | |||
| return false, nil | |||
| } | |||
| // todo compare from current db data to old image data | |||
| // Validate if data is dirty. | |||
| currentImage, err := b.queryCurrentRecords(ctx, conn) | |||
| if err != nil { | |||
| return false, err | |||
| } | |||
| // compare with current data and after image. | |||
| equals, err = IsRecordsEquals(afterImage, currentImage) | |||
| if err != nil { | |||
| return false, err | |||
| } | |||
| if !equals { | |||
| // If current data is not equivalent to the after data, then compare the current data with the before | |||
| // data, too. No need continue to undo if current data is equivalent to the before data snapshot | |||
| equals, err = IsRecordsEquals(beforeImage, currentImage) | |||
| if err != nil { | |||
| return false, err | |||
| } | |||
| if equals { | |||
| log.Infof("Stop rollback because there is no data change between the before data snapshot and the current data snapshot.") | |||
| // no need continue undo. | |||
| return false, nil | |||
| } else { | |||
| oldRowJson, _ := json.Marshal(afterImage.Rows) | |||
| newRowJson, _ := json.Marshal(currentImage.Rows) | |||
| log.Infof("check dirty data failed, old and new data are not equal, "+ | |||
| "tableName:[%s], oldRows:[%s],newRows:[%s].", afterImage.TableName, oldRowJson, newRowJson) | |||
| return false, fmt.Errorf("Has dirty records when undo.") | |||
| } | |||
| } | |||
| return true, nil | |||
| } | |||
| // todo | |||
| //func (b *BaseExecutor) queryCurrentRecords(conn *sql.Conn) *types.RecordImage { | |||
| // tableMeta := b.undoImage.TableMeta | |||
| // pkNameList := tableMeta.GetPrimaryKeyOnlyName() | |||
| // | |||
| // b.undoImage.Rows | |||
| // | |||
| //} | |||
| // | |||
| //func (b *BaseExecutor) parsePkValues(rows []types.RowImage, pkNameList []string) { | |||
| // | |||
| //} | |||
| func (b *BaseExecutor) queryCurrentRecords(ctx context.Context, conn *sql.Conn) (*types.RecordImage, error) { | |||
| if b.undoImage == nil { | |||
| return nil, fmt.Errorf("undo image is nil") | |||
| } | |||
| tableMeta := b.undoImage.TableMeta | |||
| pkNameList := tableMeta.GetPrimaryKeyOnlyName() | |||
| pkValues := b.parsePkValues(b.undoImage.Rows, pkNameList) | |||
| if len(pkValues) == 0 { | |||
| return nil, nil | |||
| } | |||
| var rowSize int | |||
| for _, images := range pkValues { | |||
| rowSize = len(images) | |||
| break | |||
| } | |||
| where := buildWhereConditionByPKs(pkNameList, rowSize, maxInSize) | |||
| checkSQL := fmt.Sprintf(checkSQLTemplate, b.undoImage.TableName, where) | |||
| params := buildPKParams(b.undoImage.Rows, pkNameList) | |||
| rows, err := conn.QueryContext(ctx, checkSQL, params...) | |||
| if err != nil { | |||
| return nil, err | |||
| } | |||
| image := types.RecordImage{ | |||
| TableName: b.undoImage.TableName, | |||
| TableMeta: tableMeta, | |||
| SQLType: types.SQLTypeSelect, | |||
| } | |||
| rowImages := make([]types.RowImage, 0) | |||
| for rows.Next() { | |||
| columnTypes, err := rows.ColumnTypes() | |||
| if err != nil { | |||
| return nil, err | |||
| } | |||
| slice := datasource.GetScanSlice(columnTypes) | |||
| if err = rows.Scan(slice...); err != nil { | |||
| return nil, err | |||
| } | |||
| colNames, err := rows.Columns() | |||
| if err != nil { | |||
| return nil, err | |||
| } | |||
| columns := make([]types.ColumnImage, 0) | |||
| for i, val := range slice { | |||
| columns = append(columns, types.ColumnImage{ | |||
| ColumnName: colNames[i], | |||
| Value: val, | |||
| }) | |||
| } | |||
| rowImages = append(rowImages, types.RowImage{Columns: columns}) | |||
| } | |||
| image.Rows = rowImages | |||
| return &image, nil | |||
| } | |||
| func (b *BaseExecutor) parsePkValues(rows []types.RowImage, pkNameList []string) map[string][]types.ColumnImage { | |||
| pkValues := make(map[string][]types.ColumnImage) | |||
| // todo optimize 3 fors | |||
| for _, row := range rows { | |||
| for _, column := range row.Columns { | |||
| for _, pk := range pkNameList { | |||
| if strings.EqualFold(pk, column.ColumnName) { | |||
| values := pkValues[strings.ToUpper(pk)] | |||
| if values == nil { | |||
| values = make([]types.ColumnImage, 0) | |||
| } | |||
| values = append(values, column) | |||
| pkValues[pk] = values | |||
| } | |||
| } | |||
| } | |||
| } | |||
| return pkValues | |||
| } | |||
| @@ -36,12 +36,12 @@ type mySQLUndoUpdateExecutor struct { | |||
| func newMySQLUndoUpdateExecutor(sqlUndoLog undo.SQLUndoLog) *mySQLUndoUpdateExecutor { | |||
| return &mySQLUndoUpdateExecutor{ | |||
| sqlUndoLog: sqlUndoLog, | |||
| baseExecutor: &BaseExecutor{sqlUndoLog: sqlUndoLog}, | |||
| baseExecutor: &BaseExecutor{sqlUndoLog: sqlUndoLog, undoImage: sqlUndoLog.AfterImage}, | |||
| } | |||
| } | |||
| func (m *mySQLUndoUpdateExecutor) ExecuteOn(ctx context.Context, dbType types.DBType, conn *sql.Conn) error { | |||
| ok, err := m.baseExecutor.dataValidationAndGoOn(conn) | |||
| ok, err := m.baseExecutor.dataValidationAndGoOn(ctx, conn) | |||
| if err != nil { | |||
| return err | |||
| } | |||
| @@ -19,10 +19,11 @@ package executor | |||
| import ( | |||
| "fmt" | |||
| "reflect" | |||
| "strings" | |||
| "github.com/seata/seata-go/pkg/datasource/sql/datasource" | |||
| "github.com/seata/seata-go/pkg/datasource/sql/types" | |||
| "github.com/seata/seata-go/pkg/util/log" | |||
| ) | |||
| // IsRecordsEquals check before record and after record if equal | |||
| @@ -51,14 +52,16 @@ func compareRows(tableMeta types.TableMeta, oldRows []types.RowImage, newRows [] | |||
| for key, oldRow := range oldRowMap { | |||
| newRow := newRowMap[key] | |||
| if newRow == nil { | |||
| return false, fmt.Errorf("compare row failed, rowKey %s, reason [newField is null]", key) | |||
| log.Errorf("compare row failed, rowKey %s, reason new field is null", key) | |||
| return false, fmt.Errorf("compare image failed for new row is null") | |||
| } | |||
| for fieldName, oldValue := range oldRow { | |||
| newValue := newRow[fieldName] | |||
| if newValue == nil { | |||
| return false, fmt.Errorf("compare row failed, rowKey %s, fieldName %s, reason [newField is null]", key, fieldName) | |||
| log.Errorf("compare row failed, rowKey %s, fieldName %s, reason new value is null", key, fieldName) | |||
| return false, fmt.Errorf("compare image failed for new value is null") | |||
| } | |||
| if !reflect.DeepEqual(newValue, oldValue) { | |||
| if !datasource.DeepEqual(newValue, oldValue) { | |||
| return false, nil | |||
| } | |||
| } | |||
| @@ -80,7 +83,7 @@ func rowListToMap(rows []types.RowImage, primaryKeyList []string) map[string]map | |||
| rowKey += "_##$$_" | |||
| } | |||
| // todo make value more accurate | |||
| rowKey = fmt.Sprintf("%v%v", rowKey, column.Value) | |||
| rowKey = fmt.Sprintf("%v%v", rowKey, column.GetActualValue()) | |||
| firstUnderline = true | |||
| } | |||
| } | |||
| @@ -90,3 +93,74 @@ func rowListToMap(rows []types.RowImage, primaryKeyList []string) map[string]map | |||
| } | |||
| return rowMap | |||
| } | |||
| // buildWhereConditionByPKs build where condition by primary keys | |||
| // each pk is a condition.the result will like :" (id,userCode) in ((?,?),(?,?)) or (id,userCode) in ((?,?),(?,?) ) or (id,userCode) in ((?,?))" | |||
| func buildWhereConditionByPKs(pkNameList []string, rowSize int, maxInSize int) string { | |||
| var ( | |||
| whereStr = &strings.Builder{} | |||
| batchSize = rowSize/maxInSize + 1 | |||
| ) | |||
| if rowSize%maxInSize == 0 { | |||
| batchSize = rowSize / maxInSize | |||
| } | |||
| for batch := 0; batch < batchSize; batch++ { | |||
| if batch > 0 { | |||
| whereStr.WriteString(" OR ") | |||
| } | |||
| whereStr.WriteString("(") | |||
| for i := 0; i < len(pkNameList); i++ { | |||
| if i > 0 { | |||
| whereStr.WriteString(",") | |||
| } | |||
| // todo add escape | |||
| whereStr.WriteString(fmt.Sprintf("`%s`", pkNameList[i])) | |||
| } | |||
| whereStr.WriteString(") IN (") | |||
| var eachSize int | |||
| if batch == batchSize-1 { | |||
| if rowSize%maxInSize == 0 { | |||
| eachSize = maxInSize | |||
| } else { | |||
| eachSize = rowSize % maxInSize | |||
| } | |||
| } else { | |||
| eachSize = maxInSize | |||
| } | |||
| for i := 0; i < eachSize; i++ { | |||
| if i > 0 { | |||
| whereStr.WriteString(",") | |||
| } | |||
| whereStr.WriteString("(") | |||
| for j := 0; j < len(pkNameList); j++ { | |||
| if j > 0 { | |||
| whereStr.WriteString(",") | |||
| } | |||
| whereStr.WriteString("?") | |||
| } | |||
| whereStr.WriteString(")") | |||
| } | |||
| whereStr.WriteString(")") | |||
| } | |||
| return whereStr.String() | |||
| } | |||
| func buildPKParams(rows []types.RowImage, pkNameList []string) []interface{} { | |||
| params := make([]interface{}, 0) | |||
| for _, row := range rows { | |||
| coumnMap := row.GetColumnMap() | |||
| for _, pk := range pkNameList { | |||
| col := coumnMap[pk] | |||
| if col != nil { | |||
| params = append(params, col.Value) | |||
| } | |||
| } | |||
| } | |||
| return params | |||
| } | |||