Browse Source

add data check before rollbeck (#366)

* add data check before rollbeck
tags/1.0.2-RC1
Yuecai Liu GitHub 2 years ago
parent
commit
f4251b5ba8
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 350 additions and 27 deletions
  1. +1
    -1
      go.mod
  2. +114
    -0
      pkg/datasource/sql/datasource/utils.go
  3. +2
    -1
      pkg/datasource/sql/tx.go
  4. +27
    -0
      pkg/datasource/sql/types/image.go
  5. +5
    -1
      pkg/datasource/sql/undo/base/undo.go
  6. +1
    -0
      pkg/datasource/sql/undo/builder/basic_undo_log_builder.go
  7. +119
    -17
      pkg/datasource/sql/undo/executor/executor.go
  8. +2
    -2
      pkg/datasource/sql/undo/executor/mysql_undo_update_executor.go
  9. +79
    -5
      pkg/datasource/sql/undo/executor/utils.go

+ 1
- 1
go.mod View File

@@ -11,6 +11,7 @@ require (
github.com/dubbogo/gost v1.12.6-0.20220824084206-300e27e9e524
github.com/gin-gonic/gin v1.8.0
github.com/go-sql-driver/mysql v1.6.0
github.com/goccy/go-json v0.9.7
github.com/golang/mock v1.6.0
github.com/google/uuid v1.3.0
github.com/mitchellh/copystructure v1.2.0
@@ -67,7 +68,6 @@ require (
github.com/go-playground/locales v0.14.0 // indirect
github.com/go-playground/universal-translator v0.18.0 // indirect
github.com/go-resty/resty/v2 v2.7.0 // indirect
github.com/goccy/go-json v0.9.7 // indirect
github.com/gogo/protobuf v1.3.2 // indirect
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
github.com/golang/protobuf v1.5.2 // indirect


+ 114
- 0
pkg/datasource/sql/datasource/utils.go View File

@@ -0,0 +1,114 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package datasource

import (
"database/sql"
"reflect"
)

type nullTime = sql.NullTime

var (
ScanTypeFloat32 = reflect.TypeOf(float32(0))
ScanTypeFloat64 = reflect.TypeOf(float64(0))
ScanTypeInt8 = reflect.TypeOf(int8(0))
ScanTypeInt16 = reflect.TypeOf(int16(0))
ScanTypeInt32 = reflect.TypeOf(int32(0))
ScanTypeInt64 = reflect.TypeOf(int64(0))
ScanTypeNullFloat = reflect.TypeOf(sql.NullFloat64{})
ScanTypeNullInt = reflect.TypeOf(sql.NullInt64{})
ScanTypeNullTime = reflect.TypeOf(nullTime{})
ScanTypeUint8 = reflect.TypeOf(uint8(0))
ScanTypeUint16 = reflect.TypeOf(uint16(0))
ScanTypeUint32 = reflect.TypeOf(uint32(0))
ScanTypeUint64 = reflect.TypeOf(uint64(0))
ScanTypeRawBytes = reflect.TypeOf(sql.RawBytes{})
ScanTypeUnknown = reflect.TypeOf(new(interface{}))
)

func GetScanSlice(types []*sql.ColumnType) []interface{} {
scanSlice := make([]interface{}, 0, len(types))
for _, tpy := range types {
switch tpy.ScanType() {
case ScanTypeFloat32:
scanVal := float32(0)
scanSlice = append(scanSlice, &scanVal)
case ScanTypeFloat64:
scanVal := float64(0)
scanSlice = append(scanSlice, &scanVal)
case ScanTypeInt8:
scanVal := int8(0)
scanSlice = append(scanSlice, &scanVal)
case ScanTypeInt16:
scanVal := int16(0)
scanSlice = append(scanSlice, &scanVal)
case ScanTypeInt32:
scanVal := int32(0)
scanSlice = append(scanSlice, &scanVal)
case ScanTypeInt64:
scanVal := int64(0)
scanSlice = append(scanSlice, &scanVal)
case ScanTypeNullFloat:
scanVal := sql.NullFloat64{}
scanSlice = append(scanSlice, &scanVal)
case ScanTypeNullInt:
scanVal := sql.NullInt64{}
scanSlice = append(scanSlice, &scanVal)
case ScanTypeNullTime:
scanVal := sql.NullTime{}
scanSlice = append(scanSlice, &scanVal)
case ScanTypeUint8:
scanVal := uint8(0)
scanSlice = append(scanSlice, &scanVal)
case ScanTypeUint16:
scanVal := uint16(0)
scanSlice = append(scanSlice, &scanVal)
case ScanTypeUint32:
scanVal := uint32(0)
scanSlice = append(scanSlice, &scanVal)
case ScanTypeUint64:
scanVal := uint64(0)
scanSlice = append(scanSlice, &scanVal)
case ScanTypeRawBytes:
scanVal := ""
scanSlice = append(scanSlice, &scanVal)
case ScanTypeUnknown:
scanVal := new(interface{})
scanSlice = append(scanSlice, &scanVal)
}
}
return scanSlice
}

func DeepEqual(x, y interface{}) bool {
typx := reflect.ValueOf(x)
typy := reflect.ValueOf(y)

switch typx.Kind() {
case reflect.Ptr:
typx = typx.Elem()
}

switch typy.Kind() {
case reflect.Ptr:
typy = typy.Elem()
}

return reflect.DeepEqual(typx.Interface(), typy.Interface())
}

+ 2
- 1
pkg/datasource/sql/tx.go View File

@@ -20,9 +20,10 @@ package sql
import (
"context"
"database/sql/driver"
"errors"
"sync"

"github.com/pkg/errors"

"github.com/seata/seata-go/pkg/datasource/sql/datasource"
"github.com/seata/seata-go/pkg/protocol/branch"
"github.com/seata/seata-go/pkg/rm"


+ 27
- 0
pkg/datasource/sql/types/image.go View File

@@ -21,6 +21,7 @@ import (
"encoding/base64"
"encoding/json"
"fmt"
"reflect"
"time"
)

@@ -93,6 +94,19 @@ func (rs RecordImages) Reserve() {
}
}

func (rs RecordImages) IsEmptyImage() bool {
if len(rs) == 0 {
return true
}
for _, r := range rs {
if r == nil || len(r.Rows) == 0 {
continue
}
return false
}
return true
}

// RecordImage
type RecordImage struct {
// index
@@ -251,3 +265,16 @@ func (c *ColumnImage) UnmarshalJSON(data []byte) error {
func getTypeStr(src interface{}) string {
return fmt.Sprintf("%T", src)
}

func (c *ColumnImage) GetActualValue() interface{} {
if c.Value == nil {
return nil
}
value := reflect.ValueOf(c.Value)
kind := reflect.TypeOf(c.Value).Kind()
switch kind {
case reflect.Ptr:
return value.Elem().Interface()
}
return c.Value
}

+ 5
- 1
pkg/datasource/sql/undo/base/undo.go View File

@@ -162,6 +162,10 @@ func (m *BaseUndoLogManager) FlushUndoLog(tranCtx *types.TransactionContext, con
beforeImages := tranCtx.RoundImages.BeofreImages()
afterImages := tranCtx.RoundImages.AfterImages()

if beforeImages.IsEmptyImage() && afterImages.IsEmptyImage() {
return nil
}

for i := 0; i < len(beforeImages); i++ {
var (
tableName string
@@ -474,7 +478,7 @@ func (m *BaseUndoLogManager) DecodeMap(str string) map[string]string {

// getRollbackInfo parser rollback info
func (m *BaseUndoLogManager) getRollbackInfo(rollbackInfo []byte, undoContext map[string]string) []byte {
// Todo 目前 insert undo log 未实现压缩功能,实现后补齐这块功能
// Todo use compressor
// get compress type
/*compressorType, ok := undoContext[constant.CompressorTypeKey]
if ok {


+ 1
- 0
pkg/datasource/sql/undo/builder/basic_undo_log_builder.go View File

@@ -35,6 +35,7 @@ import (
type BasicUndoLogBuilder struct{}

// GetScanSlice get the column type for scann
// todo to use ColumnInfo get slice
func (*BasicUndoLogBuilder) GetScanSlice(columnNames []string, tableMeta *types.TableMeta) []driver.Value {
scanSlice := make([]driver.Value, 0, len(columnNames))
for _, columnNmae := range columnNames {


+ 119
- 17
pkg/datasource/sql/undo/executor/executor.go View File

@@ -20,7 +20,11 @@ package executor
import (
"context"
"database/sql"
"fmt"
"strings"

"github.com/goccy/go-json"
"github.com/seata/seata-go/pkg/datasource/sql/datasource"
"github.com/seata/seata-go/pkg/datasource/sql/types"
"github.com/seata/seata-go/pkg/datasource/sql/undo"
"github.com/seata/seata-go/pkg/util/log"
@@ -29,7 +33,8 @@ import (
var _ undo.UndoExecutor = (*BaseExecutor)(nil)

const (
selectSQL = "SELECT * FROM %s WHERE %s FOR UPDATE"
checkSQLTemplate = "SELECT * FROM %s WHERE %s FOR UPDATE"
maxInSize = 1000
)

type BaseExecutor struct {
@@ -48,33 +53,130 @@ func (b *BaseExecutor) UndoPrepare(undoPST *sql.Stmt, undoValues []types.ColumnI

}

func (b *BaseExecutor) dataValidationAndGoOn(conn *sql.Conn) (bool, error) {
func (b *BaseExecutor) dataValidationAndGoOn(ctx context.Context, conn *sql.Conn) (bool, error) {
beforeImage := b.sqlUndoLog.BeforeImage
afterImage := b.sqlUndoLog.AfterImage

equal, err := IsRecordsEquals(beforeImage, afterImage)
equals, err := IsRecordsEquals(beforeImage, afterImage)
if err != nil {
return false, err
}
if equal {
if equals {
log.Infof("Stop rollback because there is no data change between the before data snapshot and the after data snapshot.")
return false, nil
}

// todo compare from current db data to old image data
// Validate if data is dirty.
currentImage, err := b.queryCurrentRecords(ctx, conn)
if err != nil {
return false, err
}
// compare with current data and after image.
equals, err = IsRecordsEquals(afterImage, currentImage)
if err != nil {
return false, err
}
if !equals {
// If current data is not equivalent to the after data, then compare the current data with the before
// data, too. No need continue to undo if current data is equivalent to the before data snapshot
equals, err = IsRecordsEquals(beforeImage, currentImage)
if err != nil {
return false, err
}

if equals {
log.Infof("Stop rollback because there is no data change between the before data snapshot and the current data snapshot.")
// no need continue undo.
return false, nil
} else {
oldRowJson, _ := json.Marshal(afterImage.Rows)
newRowJson, _ := json.Marshal(currentImage.Rows)
log.Infof("check dirty data failed, old and new data are not equal, "+
"tableName:[%s], oldRows:[%s],newRows:[%s].", afterImage.TableName, oldRowJson, newRowJson)
return false, fmt.Errorf("Has dirty records when undo.")
}
}
return true, nil
}

// todo
//func (b *BaseExecutor) queryCurrentRecords(conn *sql.Conn) *types.RecordImage {
// tableMeta := b.undoImage.TableMeta
// pkNameList := tableMeta.GetPrimaryKeyOnlyName()
//
// b.undoImage.Rows
//
//}
//
//func (b *BaseExecutor) parsePkValues(rows []types.RowImage, pkNameList []string) {
//
//}
func (b *BaseExecutor) queryCurrentRecords(ctx context.Context, conn *sql.Conn) (*types.RecordImage, error) {
if b.undoImage == nil {
return nil, fmt.Errorf("undo image is nil")
}
tableMeta := b.undoImage.TableMeta
pkNameList := tableMeta.GetPrimaryKeyOnlyName()
pkValues := b.parsePkValues(b.undoImage.Rows, pkNameList)

if len(pkValues) == 0 {
return nil, nil
}

var rowSize int
for _, images := range pkValues {
rowSize = len(images)
break
}

where := buildWhereConditionByPKs(pkNameList, rowSize, maxInSize)
checkSQL := fmt.Sprintf(checkSQLTemplate, b.undoImage.TableName, where)
params := buildPKParams(b.undoImage.Rows, pkNameList)

rows, err := conn.QueryContext(ctx, checkSQL, params...)
if err != nil {
return nil, err
}

image := types.RecordImage{
TableName: b.undoImage.TableName,
TableMeta: tableMeta,
SQLType: types.SQLTypeSelect,
}
rowImages := make([]types.RowImage, 0)
for rows.Next() {
columnTypes, err := rows.ColumnTypes()
if err != nil {
return nil, err
}
slice := datasource.GetScanSlice(columnTypes)
if err = rows.Scan(slice...); err != nil {
return nil, err
}

colNames, err := rows.Columns()
if err != nil {
return nil, err
}

columns := make([]types.ColumnImage, 0)
for i, val := range slice {
columns = append(columns, types.ColumnImage{
ColumnName: colNames[i],
Value: val,
})
}
rowImages = append(rowImages, types.RowImage{Columns: columns})
}

image.Rows = rowImages
return &image, nil
}

func (b *BaseExecutor) parsePkValues(rows []types.RowImage, pkNameList []string) map[string][]types.ColumnImage {
pkValues := make(map[string][]types.ColumnImage)
// todo optimize 3 fors
for _, row := range rows {
for _, column := range row.Columns {
for _, pk := range pkNameList {
if strings.EqualFold(pk, column.ColumnName) {
values := pkValues[strings.ToUpper(pk)]
if values == nil {
values = make([]types.ColumnImage, 0)
}
values = append(values, column)
pkValues[pk] = values
}
}
}
}
return pkValues
}

+ 2
- 2
pkg/datasource/sql/undo/executor/mysql_undo_update_executor.go View File

@@ -36,12 +36,12 @@ type mySQLUndoUpdateExecutor struct {
func newMySQLUndoUpdateExecutor(sqlUndoLog undo.SQLUndoLog) *mySQLUndoUpdateExecutor {
return &mySQLUndoUpdateExecutor{
sqlUndoLog: sqlUndoLog,
baseExecutor: &BaseExecutor{sqlUndoLog: sqlUndoLog},
baseExecutor: &BaseExecutor{sqlUndoLog: sqlUndoLog, undoImage: sqlUndoLog.AfterImage},
}
}

func (m *mySQLUndoUpdateExecutor) ExecuteOn(ctx context.Context, dbType types.DBType, conn *sql.Conn) error {
ok, err := m.baseExecutor.dataValidationAndGoOn(conn)
ok, err := m.baseExecutor.dataValidationAndGoOn(ctx, conn)
if err != nil {
return err
}


+ 79
- 5
pkg/datasource/sql/undo/executor/utils.go View File

@@ -19,10 +19,11 @@ package executor

import (
"fmt"
"reflect"
"strings"

"github.com/seata/seata-go/pkg/datasource/sql/datasource"
"github.com/seata/seata-go/pkg/datasource/sql/types"
"github.com/seata/seata-go/pkg/util/log"
)

// IsRecordsEquals check before record and after record if equal
@@ -51,14 +52,16 @@ func compareRows(tableMeta types.TableMeta, oldRows []types.RowImage, newRows []
for key, oldRow := range oldRowMap {
newRow := newRowMap[key]
if newRow == nil {
return false, fmt.Errorf("compare row failed, rowKey %s, reason [newField is null]", key)
log.Errorf("compare row failed, rowKey %s, reason new field is null", key)
return false, fmt.Errorf("compare image failed for new row is null")
}
for fieldName, oldValue := range oldRow {
newValue := newRow[fieldName]
if newValue == nil {
return false, fmt.Errorf("compare row failed, rowKey %s, fieldName %s, reason [newField is null]", key, fieldName)
log.Errorf("compare row failed, rowKey %s, fieldName %s, reason new value is null", key, fieldName)
return false, fmt.Errorf("compare image failed for new value is null")
}
if !reflect.DeepEqual(newValue, oldValue) {
if !datasource.DeepEqual(newValue, oldValue) {
return false, nil
}
}
@@ -80,7 +83,7 @@ func rowListToMap(rows []types.RowImage, primaryKeyList []string) map[string]map
rowKey += "_##$$_"
}
// todo make value more accurate
rowKey = fmt.Sprintf("%v%v", rowKey, column.Value)
rowKey = fmt.Sprintf("%v%v", rowKey, column.GetActualValue())
firstUnderline = true
}
}
@@ -90,3 +93,74 @@ func rowListToMap(rows []types.RowImage, primaryKeyList []string) map[string]map
}
return rowMap
}

// buildWhereConditionByPKs build where condition by primary keys
// each pk is a condition.the result will like :" (id,userCode) in ((?,?),(?,?)) or (id,userCode) in ((?,?),(?,?) ) or (id,userCode) in ((?,?))"
func buildWhereConditionByPKs(pkNameList []string, rowSize int, maxInSize int) string {
var (
whereStr = &strings.Builder{}
batchSize = rowSize/maxInSize + 1
)

if rowSize%maxInSize == 0 {
batchSize = rowSize / maxInSize
}

for batch := 0; batch < batchSize; batch++ {
if batch > 0 {
whereStr.WriteString(" OR ")
}
whereStr.WriteString("(")

for i := 0; i < len(pkNameList); i++ {
if i > 0 {
whereStr.WriteString(",")
}
// todo add escape
whereStr.WriteString(fmt.Sprintf("`%s`", pkNameList[i]))
}
whereStr.WriteString(") IN (")

var eachSize int

if batch == batchSize-1 {
if rowSize%maxInSize == 0 {
eachSize = maxInSize
} else {
eachSize = rowSize % maxInSize
}
} else {
eachSize = maxInSize
}

for i := 0; i < eachSize; i++ {
if i > 0 {
whereStr.WriteString(",")
}
whereStr.WriteString("(")
for j := 0; j < len(pkNameList); j++ {
if j > 0 {
whereStr.WriteString(",")
}
whereStr.WriteString("?")
}
whereStr.WriteString(")")
}
whereStr.WriteString(")")
}
return whereStr.String()
}

func buildPKParams(rows []types.RowImage, pkNameList []string) []interface{} {
params := make([]interface{}, 0)
for _, row := range rows {
coumnMap := row.GetColumnMap()
for _, pk := range pkNameList {
col := coumnMap[pk]
if col != nil {
params = append(params, col.Value)
}
}
}
return params
}

Loading…
Cancel
Save