Browse Source

add data check before rollbeck (#366)

* add data check before rollbeck
tags/1.0.2-RC1
Yuecai Liu GitHub 2 years ago
parent
commit
f4251b5ba8
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 350 additions and 27 deletions
  1. +1
    -1
      go.mod
  2. +114
    -0
      pkg/datasource/sql/datasource/utils.go
  3. +2
    -1
      pkg/datasource/sql/tx.go
  4. +27
    -0
      pkg/datasource/sql/types/image.go
  5. +5
    -1
      pkg/datasource/sql/undo/base/undo.go
  6. +1
    -0
      pkg/datasource/sql/undo/builder/basic_undo_log_builder.go
  7. +119
    -17
      pkg/datasource/sql/undo/executor/executor.go
  8. +2
    -2
      pkg/datasource/sql/undo/executor/mysql_undo_update_executor.go
  9. +79
    -5
      pkg/datasource/sql/undo/executor/utils.go

+ 1
- 1
go.mod View File

@@ -11,6 +11,7 @@ require (
github.com/dubbogo/gost v1.12.6-0.20220824084206-300e27e9e524 github.com/dubbogo/gost v1.12.6-0.20220824084206-300e27e9e524
github.com/gin-gonic/gin v1.8.0 github.com/gin-gonic/gin v1.8.0
github.com/go-sql-driver/mysql v1.6.0 github.com/go-sql-driver/mysql v1.6.0
github.com/goccy/go-json v0.9.7
github.com/golang/mock v1.6.0 github.com/golang/mock v1.6.0
github.com/google/uuid v1.3.0 github.com/google/uuid v1.3.0
github.com/mitchellh/copystructure v1.2.0 github.com/mitchellh/copystructure v1.2.0
@@ -67,7 +68,6 @@ require (
github.com/go-playground/locales v0.14.0 // indirect github.com/go-playground/locales v0.14.0 // indirect
github.com/go-playground/universal-translator v0.18.0 // indirect github.com/go-playground/universal-translator v0.18.0 // indirect
github.com/go-resty/resty/v2 v2.7.0 // indirect github.com/go-resty/resty/v2 v2.7.0 // indirect
github.com/goccy/go-json v0.9.7 // indirect
github.com/gogo/protobuf v1.3.2 // indirect github.com/gogo/protobuf v1.3.2 // indirect
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
github.com/golang/protobuf v1.5.2 // indirect github.com/golang/protobuf v1.5.2 // indirect


+ 114
- 0
pkg/datasource/sql/datasource/utils.go View File

@@ -0,0 +1,114 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package datasource

import (
"database/sql"
"reflect"
)

type nullTime = sql.NullTime

var (
ScanTypeFloat32 = reflect.TypeOf(float32(0))
ScanTypeFloat64 = reflect.TypeOf(float64(0))
ScanTypeInt8 = reflect.TypeOf(int8(0))
ScanTypeInt16 = reflect.TypeOf(int16(0))
ScanTypeInt32 = reflect.TypeOf(int32(0))
ScanTypeInt64 = reflect.TypeOf(int64(0))
ScanTypeNullFloat = reflect.TypeOf(sql.NullFloat64{})
ScanTypeNullInt = reflect.TypeOf(sql.NullInt64{})
ScanTypeNullTime = reflect.TypeOf(nullTime{})
ScanTypeUint8 = reflect.TypeOf(uint8(0))
ScanTypeUint16 = reflect.TypeOf(uint16(0))
ScanTypeUint32 = reflect.TypeOf(uint32(0))
ScanTypeUint64 = reflect.TypeOf(uint64(0))
ScanTypeRawBytes = reflect.TypeOf(sql.RawBytes{})
ScanTypeUnknown = reflect.TypeOf(new(interface{}))
)

func GetScanSlice(types []*sql.ColumnType) []interface{} {
scanSlice := make([]interface{}, 0, len(types))
for _, tpy := range types {
switch tpy.ScanType() {
case ScanTypeFloat32:
scanVal := float32(0)
scanSlice = append(scanSlice, &scanVal)
case ScanTypeFloat64:
scanVal := float64(0)
scanSlice = append(scanSlice, &scanVal)
case ScanTypeInt8:
scanVal := int8(0)
scanSlice = append(scanSlice, &scanVal)
case ScanTypeInt16:
scanVal := int16(0)
scanSlice = append(scanSlice, &scanVal)
case ScanTypeInt32:
scanVal := int32(0)
scanSlice = append(scanSlice, &scanVal)
case ScanTypeInt64:
scanVal := int64(0)
scanSlice = append(scanSlice, &scanVal)
case ScanTypeNullFloat:
scanVal := sql.NullFloat64{}
scanSlice = append(scanSlice, &scanVal)
case ScanTypeNullInt:
scanVal := sql.NullInt64{}
scanSlice = append(scanSlice, &scanVal)
case ScanTypeNullTime:
scanVal := sql.NullTime{}
scanSlice = append(scanSlice, &scanVal)
case ScanTypeUint8:
scanVal := uint8(0)
scanSlice = append(scanSlice, &scanVal)
case ScanTypeUint16:
scanVal := uint16(0)
scanSlice = append(scanSlice, &scanVal)
case ScanTypeUint32:
scanVal := uint32(0)
scanSlice = append(scanSlice, &scanVal)
case ScanTypeUint64:
scanVal := uint64(0)
scanSlice = append(scanSlice, &scanVal)
case ScanTypeRawBytes:
scanVal := ""
scanSlice = append(scanSlice, &scanVal)
case ScanTypeUnknown:
scanVal := new(interface{})
scanSlice = append(scanSlice, &scanVal)
}
}
return scanSlice
}

func DeepEqual(x, y interface{}) bool {
typx := reflect.ValueOf(x)
typy := reflect.ValueOf(y)

switch typx.Kind() {
case reflect.Ptr:
typx = typx.Elem()
}

switch typy.Kind() {
case reflect.Ptr:
typy = typy.Elem()
}

return reflect.DeepEqual(typx.Interface(), typy.Interface())
}

+ 2
- 1
pkg/datasource/sql/tx.go View File

@@ -20,9 +20,10 @@ package sql
import ( import (
"context" "context"
"database/sql/driver" "database/sql/driver"
"errors"
"sync" "sync"


"github.com/pkg/errors"

"github.com/seata/seata-go/pkg/datasource/sql/datasource" "github.com/seata/seata-go/pkg/datasource/sql/datasource"
"github.com/seata/seata-go/pkg/protocol/branch" "github.com/seata/seata-go/pkg/protocol/branch"
"github.com/seata/seata-go/pkg/rm" "github.com/seata/seata-go/pkg/rm"


+ 27
- 0
pkg/datasource/sql/types/image.go View File

@@ -21,6 +21,7 @@ import (
"encoding/base64" "encoding/base64"
"encoding/json" "encoding/json"
"fmt" "fmt"
"reflect"
"time" "time"
) )


@@ -93,6 +94,19 @@ func (rs RecordImages) Reserve() {
} }
} }


func (rs RecordImages) IsEmptyImage() bool {
if len(rs) == 0 {
return true
}
for _, r := range rs {
if r == nil || len(r.Rows) == 0 {
continue
}
return false
}
return true
}

// RecordImage // RecordImage
type RecordImage struct { type RecordImage struct {
// index // index
@@ -251,3 +265,16 @@ func (c *ColumnImage) UnmarshalJSON(data []byte) error {
func getTypeStr(src interface{}) string { func getTypeStr(src interface{}) string {
return fmt.Sprintf("%T", src) return fmt.Sprintf("%T", src)
} }

func (c *ColumnImage) GetActualValue() interface{} {
if c.Value == nil {
return nil
}
value := reflect.ValueOf(c.Value)
kind := reflect.TypeOf(c.Value).Kind()
switch kind {
case reflect.Ptr:
return value.Elem().Interface()
}
return c.Value
}

+ 5
- 1
pkg/datasource/sql/undo/base/undo.go View File

@@ -162,6 +162,10 @@ func (m *BaseUndoLogManager) FlushUndoLog(tranCtx *types.TransactionContext, con
beforeImages := tranCtx.RoundImages.BeofreImages() beforeImages := tranCtx.RoundImages.BeofreImages()
afterImages := tranCtx.RoundImages.AfterImages() afterImages := tranCtx.RoundImages.AfterImages()


if beforeImages.IsEmptyImage() && afterImages.IsEmptyImage() {
return nil
}

for i := 0; i < len(beforeImages); i++ { for i := 0; i < len(beforeImages); i++ {
var ( var (
tableName string tableName string
@@ -474,7 +478,7 @@ func (m *BaseUndoLogManager) DecodeMap(str string) map[string]string {


// getRollbackInfo parser rollback info // getRollbackInfo parser rollback info
func (m *BaseUndoLogManager) getRollbackInfo(rollbackInfo []byte, undoContext map[string]string) []byte { func (m *BaseUndoLogManager) getRollbackInfo(rollbackInfo []byte, undoContext map[string]string) []byte {
// Todo 目前 insert undo log 未实现压缩功能,实现后补齐这块功能
// Todo use compressor
// get compress type // get compress type
/*compressorType, ok := undoContext[constant.CompressorTypeKey] /*compressorType, ok := undoContext[constant.CompressorTypeKey]
if ok { if ok {


+ 1
- 0
pkg/datasource/sql/undo/builder/basic_undo_log_builder.go View File

@@ -35,6 +35,7 @@ import (
type BasicUndoLogBuilder struct{} type BasicUndoLogBuilder struct{}


// GetScanSlice get the column type for scann // GetScanSlice get the column type for scann
// todo to use ColumnInfo get slice
func (*BasicUndoLogBuilder) GetScanSlice(columnNames []string, tableMeta *types.TableMeta) []driver.Value { func (*BasicUndoLogBuilder) GetScanSlice(columnNames []string, tableMeta *types.TableMeta) []driver.Value {
scanSlice := make([]driver.Value, 0, len(columnNames)) scanSlice := make([]driver.Value, 0, len(columnNames))
for _, columnNmae := range columnNames { for _, columnNmae := range columnNames {


+ 119
- 17
pkg/datasource/sql/undo/executor/executor.go View File

@@ -20,7 +20,11 @@ package executor
import ( import (
"context" "context"
"database/sql" "database/sql"
"fmt"
"strings"


"github.com/goccy/go-json"
"github.com/seata/seata-go/pkg/datasource/sql/datasource"
"github.com/seata/seata-go/pkg/datasource/sql/types" "github.com/seata/seata-go/pkg/datasource/sql/types"
"github.com/seata/seata-go/pkg/datasource/sql/undo" "github.com/seata/seata-go/pkg/datasource/sql/undo"
"github.com/seata/seata-go/pkg/util/log" "github.com/seata/seata-go/pkg/util/log"
@@ -29,7 +33,8 @@ import (
var _ undo.UndoExecutor = (*BaseExecutor)(nil) var _ undo.UndoExecutor = (*BaseExecutor)(nil)


const ( const (
selectSQL = "SELECT * FROM %s WHERE %s FOR UPDATE"
checkSQLTemplate = "SELECT * FROM %s WHERE %s FOR UPDATE"
maxInSize = 1000
) )


type BaseExecutor struct { type BaseExecutor struct {
@@ -48,33 +53,130 @@ func (b *BaseExecutor) UndoPrepare(undoPST *sql.Stmt, undoValues []types.ColumnI


} }


func (b *BaseExecutor) dataValidationAndGoOn(conn *sql.Conn) (bool, error) {
func (b *BaseExecutor) dataValidationAndGoOn(ctx context.Context, conn *sql.Conn) (bool, error) {
beforeImage := b.sqlUndoLog.BeforeImage beforeImage := b.sqlUndoLog.BeforeImage
afterImage := b.sqlUndoLog.AfterImage afterImage := b.sqlUndoLog.AfterImage


equal, err := IsRecordsEquals(beforeImage, afterImage)
equals, err := IsRecordsEquals(beforeImage, afterImage)
if err != nil { if err != nil {
return false, err return false, err
} }
if equal {
if equals {
log.Infof("Stop rollback because there is no data change between the before data snapshot and the after data snapshot.") log.Infof("Stop rollback because there is no data change between the before data snapshot and the after data snapshot.")
return false, nil return false, nil
} }


// todo compare from current db data to old image data
// Validate if data is dirty.
currentImage, err := b.queryCurrentRecords(ctx, conn)
if err != nil {
return false, err
}
// compare with current data and after image.
equals, err = IsRecordsEquals(afterImage, currentImage)
if err != nil {
return false, err
}
if !equals {
// If current data is not equivalent to the after data, then compare the current data with the before
// data, too. No need continue to undo if current data is equivalent to the before data snapshot
equals, err = IsRecordsEquals(beforeImage, currentImage)
if err != nil {
return false, err
}


if equals {
log.Infof("Stop rollback because there is no data change between the before data snapshot and the current data snapshot.")
// no need continue undo.
return false, nil
} else {
oldRowJson, _ := json.Marshal(afterImage.Rows)
newRowJson, _ := json.Marshal(currentImage.Rows)
log.Infof("check dirty data failed, old and new data are not equal, "+
"tableName:[%s], oldRows:[%s],newRows:[%s].", afterImage.TableName, oldRowJson, newRowJson)
return false, fmt.Errorf("Has dirty records when undo.")
}
}
return true, nil return true, nil
} }


// todo
//func (b *BaseExecutor) queryCurrentRecords(conn *sql.Conn) *types.RecordImage {
// tableMeta := b.undoImage.TableMeta
// pkNameList := tableMeta.GetPrimaryKeyOnlyName()
//
// b.undoImage.Rows
//
//}
//
//func (b *BaseExecutor) parsePkValues(rows []types.RowImage, pkNameList []string) {
//
//}
func (b *BaseExecutor) queryCurrentRecords(ctx context.Context, conn *sql.Conn) (*types.RecordImage, error) {
if b.undoImage == nil {
return nil, fmt.Errorf("undo image is nil")
}
tableMeta := b.undoImage.TableMeta
pkNameList := tableMeta.GetPrimaryKeyOnlyName()
pkValues := b.parsePkValues(b.undoImage.Rows, pkNameList)

if len(pkValues) == 0 {
return nil, nil
}

var rowSize int
for _, images := range pkValues {
rowSize = len(images)
break
}

where := buildWhereConditionByPKs(pkNameList, rowSize, maxInSize)
checkSQL := fmt.Sprintf(checkSQLTemplate, b.undoImage.TableName, where)
params := buildPKParams(b.undoImage.Rows, pkNameList)

rows, err := conn.QueryContext(ctx, checkSQL, params...)
if err != nil {
return nil, err
}

image := types.RecordImage{
TableName: b.undoImage.TableName,
TableMeta: tableMeta,
SQLType: types.SQLTypeSelect,
}
rowImages := make([]types.RowImage, 0)
for rows.Next() {
columnTypes, err := rows.ColumnTypes()
if err != nil {
return nil, err
}
slice := datasource.GetScanSlice(columnTypes)
if err = rows.Scan(slice...); err != nil {
return nil, err
}

colNames, err := rows.Columns()
if err != nil {
return nil, err
}

columns := make([]types.ColumnImage, 0)
for i, val := range slice {
columns = append(columns, types.ColumnImage{
ColumnName: colNames[i],
Value: val,
})
}
rowImages = append(rowImages, types.RowImage{Columns: columns})
}

image.Rows = rowImages
return &image, nil
}

func (b *BaseExecutor) parsePkValues(rows []types.RowImage, pkNameList []string) map[string][]types.ColumnImage {
pkValues := make(map[string][]types.ColumnImage)
// todo optimize 3 fors
for _, row := range rows {
for _, column := range row.Columns {
for _, pk := range pkNameList {
if strings.EqualFold(pk, column.ColumnName) {
values := pkValues[strings.ToUpper(pk)]
if values == nil {
values = make([]types.ColumnImage, 0)
}
values = append(values, column)
pkValues[pk] = values
}
}
}
}
return pkValues
}

+ 2
- 2
pkg/datasource/sql/undo/executor/mysql_undo_update_executor.go View File

@@ -36,12 +36,12 @@ type mySQLUndoUpdateExecutor struct {
func newMySQLUndoUpdateExecutor(sqlUndoLog undo.SQLUndoLog) *mySQLUndoUpdateExecutor { func newMySQLUndoUpdateExecutor(sqlUndoLog undo.SQLUndoLog) *mySQLUndoUpdateExecutor {
return &mySQLUndoUpdateExecutor{ return &mySQLUndoUpdateExecutor{
sqlUndoLog: sqlUndoLog, sqlUndoLog: sqlUndoLog,
baseExecutor: &BaseExecutor{sqlUndoLog: sqlUndoLog},
baseExecutor: &BaseExecutor{sqlUndoLog: sqlUndoLog, undoImage: sqlUndoLog.AfterImage},
} }
} }


func (m *mySQLUndoUpdateExecutor) ExecuteOn(ctx context.Context, dbType types.DBType, conn *sql.Conn) error { func (m *mySQLUndoUpdateExecutor) ExecuteOn(ctx context.Context, dbType types.DBType, conn *sql.Conn) error {
ok, err := m.baseExecutor.dataValidationAndGoOn(conn)
ok, err := m.baseExecutor.dataValidationAndGoOn(ctx, conn)
if err != nil { if err != nil {
return err return err
} }


+ 79
- 5
pkg/datasource/sql/undo/executor/utils.go View File

@@ -19,10 +19,11 @@ package executor


import ( import (
"fmt" "fmt"
"reflect"
"strings" "strings"


"github.com/seata/seata-go/pkg/datasource/sql/datasource"
"github.com/seata/seata-go/pkg/datasource/sql/types" "github.com/seata/seata-go/pkg/datasource/sql/types"
"github.com/seata/seata-go/pkg/util/log"
) )


// IsRecordsEquals check before record and after record if equal // IsRecordsEquals check before record and after record if equal
@@ -51,14 +52,16 @@ func compareRows(tableMeta types.TableMeta, oldRows []types.RowImage, newRows []
for key, oldRow := range oldRowMap { for key, oldRow := range oldRowMap {
newRow := newRowMap[key] newRow := newRowMap[key]
if newRow == nil { if newRow == nil {
return false, fmt.Errorf("compare row failed, rowKey %s, reason [newField is null]", key)
log.Errorf("compare row failed, rowKey %s, reason new field is null", key)
return false, fmt.Errorf("compare image failed for new row is null")
} }
for fieldName, oldValue := range oldRow { for fieldName, oldValue := range oldRow {
newValue := newRow[fieldName] newValue := newRow[fieldName]
if newValue == nil { if newValue == nil {
return false, fmt.Errorf("compare row failed, rowKey %s, fieldName %s, reason [newField is null]", key, fieldName)
log.Errorf("compare row failed, rowKey %s, fieldName %s, reason new value is null", key, fieldName)
return false, fmt.Errorf("compare image failed for new value is null")
} }
if !reflect.DeepEqual(newValue, oldValue) {
if !datasource.DeepEqual(newValue, oldValue) {
return false, nil return false, nil
} }
} }
@@ -80,7 +83,7 @@ func rowListToMap(rows []types.RowImage, primaryKeyList []string) map[string]map
rowKey += "_##$$_" rowKey += "_##$$_"
} }
// todo make value more accurate // todo make value more accurate
rowKey = fmt.Sprintf("%v%v", rowKey, column.Value)
rowKey = fmt.Sprintf("%v%v", rowKey, column.GetActualValue())
firstUnderline = true firstUnderline = true
} }
} }
@@ -90,3 +93,74 @@ func rowListToMap(rows []types.RowImage, primaryKeyList []string) map[string]map
} }
return rowMap return rowMap
} }

// buildWhereConditionByPKs build where condition by primary keys
// each pk is a condition.the result will like :" (id,userCode) in ((?,?),(?,?)) or (id,userCode) in ((?,?),(?,?) ) or (id,userCode) in ((?,?))"
func buildWhereConditionByPKs(pkNameList []string, rowSize int, maxInSize int) string {
var (
whereStr = &strings.Builder{}
batchSize = rowSize/maxInSize + 1
)

if rowSize%maxInSize == 0 {
batchSize = rowSize / maxInSize
}

for batch := 0; batch < batchSize; batch++ {
if batch > 0 {
whereStr.WriteString(" OR ")
}
whereStr.WriteString("(")

for i := 0; i < len(pkNameList); i++ {
if i > 0 {
whereStr.WriteString(",")
}
// todo add escape
whereStr.WriteString(fmt.Sprintf("`%s`", pkNameList[i]))
}
whereStr.WriteString(") IN (")

var eachSize int

if batch == batchSize-1 {
if rowSize%maxInSize == 0 {
eachSize = maxInSize
} else {
eachSize = rowSize % maxInSize
}
} else {
eachSize = maxInSize
}

for i := 0; i < eachSize; i++ {
if i > 0 {
whereStr.WriteString(",")
}
whereStr.WriteString("(")
for j := 0; j < len(pkNameList); j++ {
if j > 0 {
whereStr.WriteString(",")
}
whereStr.WriteString("?")
}
whereStr.WriteString(")")
}
whereStr.WriteString(")")
}
return whereStr.String()
}

func buildPKParams(rows []types.RowImage, pkNameList []string) []interface{} {
params := make([]interface{}, 0)
for _, row := range rows {
coumnMap := row.GetColumnMap()
for _, pk := range pkNameList {
col := coumnMap[pk]
if col != nil {
params = append(params, col.Value)
}
}
}
return params
}

Loading…
Cancel
Save