@@ -11,6 +11,7 @@ require ( | |||
github.com/dubbogo/gost v1.12.6-0.20220824084206-300e27e9e524 | |||
github.com/gin-gonic/gin v1.8.0 | |||
github.com/go-sql-driver/mysql v1.6.0 | |||
github.com/goccy/go-json v0.9.7 | |||
github.com/golang/mock v1.6.0 | |||
github.com/google/uuid v1.3.0 | |||
github.com/mitchellh/copystructure v1.2.0 | |||
@@ -67,7 +68,6 @@ require ( | |||
github.com/go-playground/locales v0.14.0 // indirect | |||
github.com/go-playground/universal-translator v0.18.0 // indirect | |||
github.com/go-resty/resty/v2 v2.7.0 // indirect | |||
github.com/goccy/go-json v0.9.7 // indirect | |||
github.com/gogo/protobuf v1.3.2 // indirect | |||
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect | |||
github.com/golang/protobuf v1.5.2 // indirect | |||
@@ -0,0 +1,114 @@ | |||
/* | |||
* Licensed to the Apache Software Foundation (ASF) under one or more | |||
* contributor license agreements. See the NOTICE file distributed with | |||
* this work for additional information regarding copyright ownership. | |||
* The ASF licenses this file to You under the Apache License, Version 2.0 | |||
* (the "License"); you may not use this file except in compliance with | |||
* the License. You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
package datasource | |||
import ( | |||
"database/sql" | |||
"reflect" | |||
) | |||
type nullTime = sql.NullTime | |||
var ( | |||
ScanTypeFloat32 = reflect.TypeOf(float32(0)) | |||
ScanTypeFloat64 = reflect.TypeOf(float64(0)) | |||
ScanTypeInt8 = reflect.TypeOf(int8(0)) | |||
ScanTypeInt16 = reflect.TypeOf(int16(0)) | |||
ScanTypeInt32 = reflect.TypeOf(int32(0)) | |||
ScanTypeInt64 = reflect.TypeOf(int64(0)) | |||
ScanTypeNullFloat = reflect.TypeOf(sql.NullFloat64{}) | |||
ScanTypeNullInt = reflect.TypeOf(sql.NullInt64{}) | |||
ScanTypeNullTime = reflect.TypeOf(nullTime{}) | |||
ScanTypeUint8 = reflect.TypeOf(uint8(0)) | |||
ScanTypeUint16 = reflect.TypeOf(uint16(0)) | |||
ScanTypeUint32 = reflect.TypeOf(uint32(0)) | |||
ScanTypeUint64 = reflect.TypeOf(uint64(0)) | |||
ScanTypeRawBytes = reflect.TypeOf(sql.RawBytes{}) | |||
ScanTypeUnknown = reflect.TypeOf(new(interface{})) | |||
) | |||
func GetScanSlice(types []*sql.ColumnType) []interface{} { | |||
scanSlice := make([]interface{}, 0, len(types)) | |||
for _, tpy := range types { | |||
switch tpy.ScanType() { | |||
case ScanTypeFloat32: | |||
scanVal := float32(0) | |||
scanSlice = append(scanSlice, &scanVal) | |||
case ScanTypeFloat64: | |||
scanVal := float64(0) | |||
scanSlice = append(scanSlice, &scanVal) | |||
case ScanTypeInt8: | |||
scanVal := int8(0) | |||
scanSlice = append(scanSlice, &scanVal) | |||
case ScanTypeInt16: | |||
scanVal := int16(0) | |||
scanSlice = append(scanSlice, &scanVal) | |||
case ScanTypeInt32: | |||
scanVal := int32(0) | |||
scanSlice = append(scanSlice, &scanVal) | |||
case ScanTypeInt64: | |||
scanVal := int64(0) | |||
scanSlice = append(scanSlice, &scanVal) | |||
case ScanTypeNullFloat: | |||
scanVal := sql.NullFloat64{} | |||
scanSlice = append(scanSlice, &scanVal) | |||
case ScanTypeNullInt: | |||
scanVal := sql.NullInt64{} | |||
scanSlice = append(scanSlice, &scanVal) | |||
case ScanTypeNullTime: | |||
scanVal := sql.NullTime{} | |||
scanSlice = append(scanSlice, &scanVal) | |||
case ScanTypeUint8: | |||
scanVal := uint8(0) | |||
scanSlice = append(scanSlice, &scanVal) | |||
case ScanTypeUint16: | |||
scanVal := uint16(0) | |||
scanSlice = append(scanSlice, &scanVal) | |||
case ScanTypeUint32: | |||
scanVal := uint32(0) | |||
scanSlice = append(scanSlice, &scanVal) | |||
case ScanTypeUint64: | |||
scanVal := uint64(0) | |||
scanSlice = append(scanSlice, &scanVal) | |||
case ScanTypeRawBytes: | |||
scanVal := "" | |||
scanSlice = append(scanSlice, &scanVal) | |||
case ScanTypeUnknown: | |||
scanVal := new(interface{}) | |||
scanSlice = append(scanSlice, &scanVal) | |||
} | |||
} | |||
return scanSlice | |||
} | |||
func DeepEqual(x, y interface{}) bool { | |||
typx := reflect.ValueOf(x) | |||
typy := reflect.ValueOf(y) | |||
switch typx.Kind() { | |||
case reflect.Ptr: | |||
typx = typx.Elem() | |||
} | |||
switch typy.Kind() { | |||
case reflect.Ptr: | |||
typy = typy.Elem() | |||
} | |||
return reflect.DeepEqual(typx.Interface(), typy.Interface()) | |||
} |
@@ -20,9 +20,10 @@ package sql | |||
import ( | |||
"context" | |||
"database/sql/driver" | |||
"errors" | |||
"sync" | |||
"github.com/pkg/errors" | |||
"github.com/seata/seata-go/pkg/datasource/sql/datasource" | |||
"github.com/seata/seata-go/pkg/protocol/branch" | |||
"github.com/seata/seata-go/pkg/rm" | |||
@@ -21,6 +21,7 @@ import ( | |||
"encoding/base64" | |||
"encoding/json" | |||
"fmt" | |||
"reflect" | |||
"time" | |||
) | |||
@@ -93,6 +94,19 @@ func (rs RecordImages) Reserve() { | |||
} | |||
} | |||
func (rs RecordImages) IsEmptyImage() bool { | |||
if len(rs) == 0 { | |||
return true | |||
} | |||
for _, r := range rs { | |||
if r == nil || len(r.Rows) == 0 { | |||
continue | |||
} | |||
return false | |||
} | |||
return true | |||
} | |||
// RecordImage | |||
type RecordImage struct { | |||
// index | |||
@@ -251,3 +265,16 @@ func (c *ColumnImage) UnmarshalJSON(data []byte) error { | |||
func getTypeStr(src interface{}) string { | |||
return fmt.Sprintf("%T", src) | |||
} | |||
func (c *ColumnImage) GetActualValue() interface{} { | |||
if c.Value == nil { | |||
return nil | |||
} | |||
value := reflect.ValueOf(c.Value) | |||
kind := reflect.TypeOf(c.Value).Kind() | |||
switch kind { | |||
case reflect.Ptr: | |||
return value.Elem().Interface() | |||
} | |||
return c.Value | |||
} |
@@ -162,6 +162,10 @@ func (m *BaseUndoLogManager) FlushUndoLog(tranCtx *types.TransactionContext, con | |||
beforeImages := tranCtx.RoundImages.BeofreImages() | |||
afterImages := tranCtx.RoundImages.AfterImages() | |||
if beforeImages.IsEmptyImage() && afterImages.IsEmptyImage() { | |||
return nil | |||
} | |||
for i := 0; i < len(beforeImages); i++ { | |||
var ( | |||
tableName string | |||
@@ -474,7 +478,7 @@ func (m *BaseUndoLogManager) DecodeMap(str string) map[string]string { | |||
// getRollbackInfo parser rollback info | |||
func (m *BaseUndoLogManager) getRollbackInfo(rollbackInfo []byte, undoContext map[string]string) []byte { | |||
// Todo 目前 insert undo log 未实现压缩功能,实现后补齐这块功能 | |||
// Todo use compressor | |||
// get compress type | |||
/*compressorType, ok := undoContext[constant.CompressorTypeKey] | |||
if ok { | |||
@@ -35,6 +35,7 @@ import ( | |||
type BasicUndoLogBuilder struct{} | |||
// GetScanSlice get the column type for scann | |||
// todo to use ColumnInfo get slice | |||
func (*BasicUndoLogBuilder) GetScanSlice(columnNames []string, tableMeta *types.TableMeta) []driver.Value { | |||
scanSlice := make([]driver.Value, 0, len(columnNames)) | |||
for _, columnNmae := range columnNames { | |||
@@ -20,7 +20,11 @@ package executor | |||
import ( | |||
"context" | |||
"database/sql" | |||
"fmt" | |||
"strings" | |||
"github.com/goccy/go-json" | |||
"github.com/seata/seata-go/pkg/datasource/sql/datasource" | |||
"github.com/seata/seata-go/pkg/datasource/sql/types" | |||
"github.com/seata/seata-go/pkg/datasource/sql/undo" | |||
"github.com/seata/seata-go/pkg/util/log" | |||
@@ -29,7 +33,8 @@ import ( | |||
var _ undo.UndoExecutor = (*BaseExecutor)(nil) | |||
const ( | |||
selectSQL = "SELECT * FROM %s WHERE %s FOR UPDATE" | |||
checkSQLTemplate = "SELECT * FROM %s WHERE %s FOR UPDATE" | |||
maxInSize = 1000 | |||
) | |||
type BaseExecutor struct { | |||
@@ -48,33 +53,130 @@ func (b *BaseExecutor) UndoPrepare(undoPST *sql.Stmt, undoValues []types.ColumnI | |||
} | |||
func (b *BaseExecutor) dataValidationAndGoOn(conn *sql.Conn) (bool, error) { | |||
func (b *BaseExecutor) dataValidationAndGoOn(ctx context.Context, conn *sql.Conn) (bool, error) { | |||
beforeImage := b.sqlUndoLog.BeforeImage | |||
afterImage := b.sqlUndoLog.AfterImage | |||
equal, err := IsRecordsEquals(beforeImage, afterImage) | |||
equals, err := IsRecordsEquals(beforeImage, afterImage) | |||
if err != nil { | |||
return false, err | |||
} | |||
if equal { | |||
if equals { | |||
log.Infof("Stop rollback because there is no data change between the before data snapshot and the after data snapshot.") | |||
return false, nil | |||
} | |||
// todo compare from current db data to old image data | |||
// Validate if data is dirty. | |||
currentImage, err := b.queryCurrentRecords(ctx, conn) | |||
if err != nil { | |||
return false, err | |||
} | |||
// compare with current data and after image. | |||
equals, err = IsRecordsEquals(afterImage, currentImage) | |||
if err != nil { | |||
return false, err | |||
} | |||
if !equals { | |||
// If current data is not equivalent to the after data, then compare the current data with the before | |||
// data, too. No need continue to undo if current data is equivalent to the before data snapshot | |||
equals, err = IsRecordsEquals(beforeImage, currentImage) | |||
if err != nil { | |||
return false, err | |||
} | |||
if equals { | |||
log.Infof("Stop rollback because there is no data change between the before data snapshot and the current data snapshot.") | |||
// no need continue undo. | |||
return false, nil | |||
} else { | |||
oldRowJson, _ := json.Marshal(afterImage.Rows) | |||
newRowJson, _ := json.Marshal(currentImage.Rows) | |||
log.Infof("check dirty data failed, old and new data are not equal, "+ | |||
"tableName:[%s], oldRows:[%s],newRows:[%s].", afterImage.TableName, oldRowJson, newRowJson) | |||
return false, fmt.Errorf("Has dirty records when undo.") | |||
} | |||
} | |||
return true, nil | |||
} | |||
// todo | |||
//func (b *BaseExecutor) queryCurrentRecords(conn *sql.Conn) *types.RecordImage { | |||
// tableMeta := b.undoImage.TableMeta | |||
// pkNameList := tableMeta.GetPrimaryKeyOnlyName() | |||
// | |||
// b.undoImage.Rows | |||
// | |||
//} | |||
// | |||
//func (b *BaseExecutor) parsePkValues(rows []types.RowImage, pkNameList []string) { | |||
// | |||
//} | |||
func (b *BaseExecutor) queryCurrentRecords(ctx context.Context, conn *sql.Conn) (*types.RecordImage, error) { | |||
if b.undoImage == nil { | |||
return nil, fmt.Errorf("undo image is nil") | |||
} | |||
tableMeta := b.undoImage.TableMeta | |||
pkNameList := tableMeta.GetPrimaryKeyOnlyName() | |||
pkValues := b.parsePkValues(b.undoImage.Rows, pkNameList) | |||
if len(pkValues) == 0 { | |||
return nil, nil | |||
} | |||
var rowSize int | |||
for _, images := range pkValues { | |||
rowSize = len(images) | |||
break | |||
} | |||
where := buildWhereConditionByPKs(pkNameList, rowSize, maxInSize) | |||
checkSQL := fmt.Sprintf(checkSQLTemplate, b.undoImage.TableName, where) | |||
params := buildPKParams(b.undoImage.Rows, pkNameList) | |||
rows, err := conn.QueryContext(ctx, checkSQL, params...) | |||
if err != nil { | |||
return nil, err | |||
} | |||
image := types.RecordImage{ | |||
TableName: b.undoImage.TableName, | |||
TableMeta: tableMeta, | |||
SQLType: types.SQLTypeSelect, | |||
} | |||
rowImages := make([]types.RowImage, 0) | |||
for rows.Next() { | |||
columnTypes, err := rows.ColumnTypes() | |||
if err != nil { | |||
return nil, err | |||
} | |||
slice := datasource.GetScanSlice(columnTypes) | |||
if err = rows.Scan(slice...); err != nil { | |||
return nil, err | |||
} | |||
colNames, err := rows.Columns() | |||
if err != nil { | |||
return nil, err | |||
} | |||
columns := make([]types.ColumnImage, 0) | |||
for i, val := range slice { | |||
columns = append(columns, types.ColumnImage{ | |||
ColumnName: colNames[i], | |||
Value: val, | |||
}) | |||
} | |||
rowImages = append(rowImages, types.RowImage{Columns: columns}) | |||
} | |||
image.Rows = rowImages | |||
return &image, nil | |||
} | |||
func (b *BaseExecutor) parsePkValues(rows []types.RowImage, pkNameList []string) map[string][]types.ColumnImage { | |||
pkValues := make(map[string][]types.ColumnImage) | |||
// todo optimize 3 fors | |||
for _, row := range rows { | |||
for _, column := range row.Columns { | |||
for _, pk := range pkNameList { | |||
if strings.EqualFold(pk, column.ColumnName) { | |||
values := pkValues[strings.ToUpper(pk)] | |||
if values == nil { | |||
values = make([]types.ColumnImage, 0) | |||
} | |||
values = append(values, column) | |||
pkValues[pk] = values | |||
} | |||
} | |||
} | |||
} | |||
return pkValues | |||
} |
@@ -36,12 +36,12 @@ type mySQLUndoUpdateExecutor struct { | |||
func newMySQLUndoUpdateExecutor(sqlUndoLog undo.SQLUndoLog) *mySQLUndoUpdateExecutor { | |||
return &mySQLUndoUpdateExecutor{ | |||
sqlUndoLog: sqlUndoLog, | |||
baseExecutor: &BaseExecutor{sqlUndoLog: sqlUndoLog}, | |||
baseExecutor: &BaseExecutor{sqlUndoLog: sqlUndoLog, undoImage: sqlUndoLog.AfterImage}, | |||
} | |||
} | |||
func (m *mySQLUndoUpdateExecutor) ExecuteOn(ctx context.Context, dbType types.DBType, conn *sql.Conn) error { | |||
ok, err := m.baseExecutor.dataValidationAndGoOn(conn) | |||
ok, err := m.baseExecutor.dataValidationAndGoOn(ctx, conn) | |||
if err != nil { | |||
return err | |||
} | |||
@@ -19,10 +19,11 @@ package executor | |||
import ( | |||
"fmt" | |||
"reflect" | |||
"strings" | |||
"github.com/seata/seata-go/pkg/datasource/sql/datasource" | |||
"github.com/seata/seata-go/pkg/datasource/sql/types" | |||
"github.com/seata/seata-go/pkg/util/log" | |||
) | |||
// IsRecordsEquals check before record and after record if equal | |||
@@ -51,14 +52,16 @@ func compareRows(tableMeta types.TableMeta, oldRows []types.RowImage, newRows [] | |||
for key, oldRow := range oldRowMap { | |||
newRow := newRowMap[key] | |||
if newRow == nil { | |||
return false, fmt.Errorf("compare row failed, rowKey %s, reason [newField is null]", key) | |||
log.Errorf("compare row failed, rowKey %s, reason new field is null", key) | |||
return false, fmt.Errorf("compare image failed for new row is null") | |||
} | |||
for fieldName, oldValue := range oldRow { | |||
newValue := newRow[fieldName] | |||
if newValue == nil { | |||
return false, fmt.Errorf("compare row failed, rowKey %s, fieldName %s, reason [newField is null]", key, fieldName) | |||
log.Errorf("compare row failed, rowKey %s, fieldName %s, reason new value is null", key, fieldName) | |||
return false, fmt.Errorf("compare image failed for new value is null") | |||
} | |||
if !reflect.DeepEqual(newValue, oldValue) { | |||
if !datasource.DeepEqual(newValue, oldValue) { | |||
return false, nil | |||
} | |||
} | |||
@@ -80,7 +83,7 @@ func rowListToMap(rows []types.RowImage, primaryKeyList []string) map[string]map | |||
rowKey += "_##$$_" | |||
} | |||
// todo make value more accurate | |||
rowKey = fmt.Sprintf("%v%v", rowKey, column.Value) | |||
rowKey = fmt.Sprintf("%v%v", rowKey, column.GetActualValue()) | |||
firstUnderline = true | |||
} | |||
} | |||
@@ -90,3 +93,74 @@ func rowListToMap(rows []types.RowImage, primaryKeyList []string) map[string]map | |||
} | |||
return rowMap | |||
} | |||
// buildWhereConditionByPKs build where condition by primary keys | |||
// each pk is a condition.the result will like :" (id,userCode) in ((?,?),(?,?)) or (id,userCode) in ((?,?),(?,?) ) or (id,userCode) in ((?,?))" | |||
func buildWhereConditionByPKs(pkNameList []string, rowSize int, maxInSize int) string { | |||
var ( | |||
whereStr = &strings.Builder{} | |||
batchSize = rowSize/maxInSize + 1 | |||
) | |||
if rowSize%maxInSize == 0 { | |||
batchSize = rowSize / maxInSize | |||
} | |||
for batch := 0; batch < batchSize; batch++ { | |||
if batch > 0 { | |||
whereStr.WriteString(" OR ") | |||
} | |||
whereStr.WriteString("(") | |||
for i := 0; i < len(pkNameList); i++ { | |||
if i > 0 { | |||
whereStr.WriteString(",") | |||
} | |||
// todo add escape | |||
whereStr.WriteString(fmt.Sprintf("`%s`", pkNameList[i])) | |||
} | |||
whereStr.WriteString(") IN (") | |||
var eachSize int | |||
if batch == batchSize-1 { | |||
if rowSize%maxInSize == 0 { | |||
eachSize = maxInSize | |||
} else { | |||
eachSize = rowSize % maxInSize | |||
} | |||
} else { | |||
eachSize = maxInSize | |||
} | |||
for i := 0; i < eachSize; i++ { | |||
if i > 0 { | |||
whereStr.WriteString(",") | |||
} | |||
whereStr.WriteString("(") | |||
for j := 0; j < len(pkNameList); j++ { | |||
if j > 0 { | |||
whereStr.WriteString(",") | |||
} | |||
whereStr.WriteString("?") | |||
} | |||
whereStr.WriteString(")") | |||
} | |||
whereStr.WriteString(")") | |||
} | |||
return whereStr.String() | |||
} | |||
func buildPKParams(rows []types.RowImage, pkNameList []string) []interface{} { | |||
params := make([]interface{}, 0) | |||
for _, row := range rows { | |||
coumnMap := row.GetColumnMap() | |||
for _, pk := range pkNameList { | |||
col := coumnMap[pk] | |||
if col != nil { | |||
params = append(params, col.Value) | |||
} | |||
} | |||
} | |||
return params | |||
} |