From a44063609e47afc26dea42353d8d81091ad7aad9 Mon Sep 17 00:00:00 2001 From: Yuecai Liu <38887641+luky116@users.noreply.github.com> Date: Tue, 20 Dec 2022 22:51:52 +0800 Subject: [PATCH] optimize: refactor at executor (#397) * refactor at executor --- go.mod | 2 +- pkg/client/config.go | 2 +- pkg/datasource/sql/conn.go | 5 +- pkg/datasource/sql/conn_test.go | 8 - pkg/datasource/sql/exec/at/base_executor.go | 296 ++++++++++++++ pkg/datasource/sql/exec/at/executor_at.go | 185 ++------- .../exec/at/{default.go => plain_executor.go} | 26 +- pkg/datasource/sql/exec/at/update_executor.go | 272 +++++++++++++ .../sql/exec/at/update_executor_test.go | 101 +++++ pkg/datasource/sql/exec/executor.go | 65 +-- pkg/datasource/sql/exec/xa/executor_xa.go | 14 +- pkg/datasource/sql/stmt.go | 9 +- pkg/datasource/sql/types/image.go | 5 +- pkg/datasource/sql/util/convert.go | 377 ++++++++++++++++++ pkg/datasource/sql/util/ctxutil.go | 123 ++++++ pkg/datasource/sql/util/params.go | 39 ++ pkg/datasource/sql/util/sql.go | 355 +++++++++++++++++ .../client/rm_branch_rollback_processor.go | 2 +- pkg/util/reflectx/unexpoert_field.go | 13 + pkg/util/reflectx/unexpoert_field_test.go | 48 +++ 20 files changed, 1715 insertions(+), 232 deletions(-) create mode 100644 pkg/datasource/sql/exec/at/base_executor.go rename pkg/datasource/sql/exec/at/{default.go => plain_executor.go} (65%) create mode 100644 pkg/datasource/sql/exec/at/update_executor.go create mode 100644 pkg/datasource/sql/exec/at/update_executor_test.go create mode 100644 pkg/datasource/sql/util/convert.go create mode 100644 pkg/datasource/sql/util/ctxutil.go create mode 100644 pkg/datasource/sql/util/params.go create mode 100644 pkg/datasource/sql/util/sql.go create mode 100644 pkg/util/reflectx/unexpoert_field_test.go diff --git a/go.mod b/go.mod index bace7086..d128204c 100644 --- a/go.mod +++ b/go.mod @@ -15,7 +15,6 @@ require ( github.com/golang/mock v1.6.0 github.com/google/uuid v1.3.0 github.com/knadh/koanf v1.4.3 - github.com/mitchellh/copystructure v1.2.0 github.com/natefinch/lumberjack v2.0.0+incompatible github.com/parnurzeal/gorequest v0.2.16 github.com/pierrec/lz4/v4 v4.1.17 @@ -88,6 +87,7 @@ require ( github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect github.com/magiconair/properties v1.8.6 // indirect github.com/matttproud/golang_protobuf_extensions v1.0.1 // indirect + github.com/mitchellh/copystructure v1.2.0 // indirect github.com/mitchellh/go-homedir v1.1.0 // indirect github.com/mitchellh/mapstructure v1.5.0 // indirect github.com/mitchellh/reflectwalk v1.0.2 // indirect diff --git a/pkg/client/config.go b/pkg/client/config.go index 64f2ec96..62ce9a2e 100644 --- a/pkg/client/config.go +++ b/pkg/client/config.go @@ -32,8 +32,8 @@ import ( "github.com/knadh/koanf/parsers/yaml" "github.com/knadh/koanf/providers/rawbytes" "github.com/seata/seata-go/pkg/remoting/getty" - "github.com/seata/seata-go/pkg/tm" "github.com/seata/seata-go/pkg/rm/tcc" + "github.com/seata/seata-go/pkg/tm" "github.com/seata/seata-go/pkg/util/flagext" ) diff --git a/pkg/datasource/sql/conn.go b/pkg/datasource/sql/conn.go index 3f224c99..a9a4aed7 100644 --- a/pkg/datasource/sql/conn.go +++ b/pkg/datasource/sql/conn.go @@ -23,6 +23,7 @@ import ( "github.com/seata/seata-go/pkg/datasource/sql/exec" "github.com/seata/seata-go/pkg/datasource/sql/types" + "github.com/seata/seata-go/pkg/datasource/sql/util" ) // Conn is a connection to a database. It is not used concurrently @@ -146,8 +147,8 @@ func (c *Conn) Query(query string, args []driver.Value) (driver.Rows, error) { } ret, err := executor.ExecWithValue(context.Background(), execCtx, - func(ctx context.Context, query string, args []driver.Value) (types.ExecResult, error) { - ret, err := conn.Query(query, args) + func(ctx context.Context, query string, args []driver.NamedValue) (types.ExecResult, error) { + ret, err := conn.Query(query, util.NamedValueToValue(args)) if err != nil { return nil, err } diff --git a/pkg/datasource/sql/conn_test.go b/pkg/datasource/sql/conn_test.go index 3be45a6c..4484b028 100644 --- a/pkg/datasource/sql/conn_test.go +++ b/pkg/datasource/sql/conn_test.go @@ -26,14 +26,6 @@ import ( "github.com/stretchr/testify/assert" ) -func TestConn_BuildATExecutor(t *testing.T) { - executor, err := exec.BuildExecutor(types.DBTypeMySQL, types.ATMode, "SELECT * FROM user") - - assert.NoError(t, err) - _, ok := executor.(*exec.BaseExecutor) - assert.True(t, ok, "need base executor") -} - func TestConn_BuildXAExecutor(t *testing.T) { executor, err := exec.BuildExecutor(types.DBTypeMySQL, types.XAMode, "SELECT * FROM user") diff --git a/pkg/datasource/sql/exec/at/base_executor.go b/pkg/datasource/sql/exec/at/base_executor.go new file mode 100644 index 00000000..d9de2315 --- /dev/null +++ b/pkg/datasource/sql/exec/at/base_executor.go @@ -0,0 +1,296 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package at + +import ( + "bytes" + "context" + "database/sql" + "database/sql/driver" + "fmt" + "strings" + + "github.com/arana-db/parser/ast" + "github.com/arana-db/parser/test_driver" + gxsort "github.com/dubbogo/gost/sort" + "github.com/seata/seata-go/pkg/datasource/sql/exec" + "github.com/seata/seata-go/pkg/datasource/sql/types" + "github.com/seata/seata-go/pkg/datasource/sql/util" + "github.com/seata/seata-go/pkg/util/reflectx" +) + +type baseExecutor struct { + hooks []exec.SQLHook +} + +func (b *baseExecutor) beforeHooks(ctx context.Context, execCtx *types.ExecContext) { + for _, hook := range b.hooks { + hook.Before(ctx, execCtx) + } +} + +func (b *baseExecutor) afterHooks(ctx context.Context, execCtx *types.ExecContext) { + for _, hook := range b.hooks { + hook.After(ctx, execCtx) + } +} + +// GetScanSlice get the column type for scann +// todo to use ColumnInfo get slice +func (*baseExecutor) GetScanSlice(columnNames []string, tableMeta *types.TableMeta) []interface{} { + scanSlice := make([]interface{}, 0, len(columnNames)) + for _, columnNmae := range columnNames { + var ( + // 从metData获取该列的元信息 + columnMeta = tableMeta.Columns[columnNmae] + ) + switch strings.ToUpper(columnMeta.DatabaseTypeString) { + case "VARCHAR", "NVARCHAR", "VARCHAR2", "CHAR", "TEXT", "JSON", "TINYTEXT": + var scanVal string + scanSlice = append(scanSlice, &scanVal) + case "BIT", "INT", "LONGBLOB", "SMALLINT", "TINYINT", "BIGINT", "MEDIUMINT": + if columnMeta.IsNullable == 0 { + scanVal := int64(0) + scanSlice = append(scanSlice, &scanVal) + } else { + scanVal := sql.NullInt64{} + scanSlice = append(scanSlice, &scanVal) + } + case "DATE", "DATETIME", "TIME", "TIMESTAMP", "YEAR": + scanVal := sql.NullTime{} + scanSlice = append(scanSlice, &scanVal) + case "DECIMAL", "DOUBLE", "FLOAT": + if columnMeta.IsNullable == 0 { + scanVal := float64(0) + scanSlice = append(scanSlice, &scanVal) + } else { + scanVal := sql.NullFloat64{} + scanSlice = append(scanSlice, &scanVal) + } + default: + scanVal := sql.RawBytes{} + scanSlice = append(scanSlice, &scanVal) + } + } + return scanSlice +} + +func (b *baseExecutor) buildSelectArgs(stmt *ast.SelectStmt, args []driver.NamedValue) []driver.NamedValue { + var ( + selectArgsIndexs = make([]int32, 0) + selectArgs = make([]driver.NamedValue, 0) + ) + + b.traversalArgs(stmt.Where, &selectArgsIndexs) + if stmt.OrderBy != nil { + for _, item := range stmt.OrderBy.Items { + b.traversalArgs(item, &selectArgsIndexs) + } + } + if stmt.Limit != nil { + if stmt.Limit.Offset != nil { + b.traversalArgs(stmt.Limit.Offset, &selectArgsIndexs) + } + if stmt.Limit.Count != nil { + b.traversalArgs(stmt.Limit.Count, &selectArgsIndexs) + } + } + // sort selectArgs index array + gxsort.Int32(selectArgsIndexs) + for _, index := range selectArgsIndexs { + selectArgs = append(selectArgs, args[index]) + } + + return selectArgs +} + +// todo perfect all sql operation +func (b *baseExecutor) traversalArgs(node ast.Node, argsIndex *[]int32) { + if node == nil { + return + } + switch node.(type) { + case *ast.BinaryOperationExpr: + expr := node.(*ast.BinaryOperationExpr) + b.traversalArgs(expr.L, argsIndex) + b.traversalArgs(expr.R, argsIndex) + break + case *ast.BetweenExpr: + expr := node.(*ast.BetweenExpr) + b.traversalArgs(expr.Left, argsIndex) + b.traversalArgs(expr.Right, argsIndex) + break + case *ast.PatternInExpr: + exprs := node.(*ast.PatternInExpr).List + for i := 0; i < len(exprs); i++ { + b.traversalArgs(exprs[i], argsIndex) + } + break + case *test_driver.ParamMarkerExpr: + *argsIndex = append(*argsIndex, int32(node.(*test_driver.ParamMarkerExpr).Order)) + break + } +} + +func (b *baseExecutor) buildRecordImages(rowsi driver.Rows, tableMetaData *types.TableMeta) (*types.RecordImage, error) { + // select column names + columnNames := rowsi.Columns() + rowImages := make([]types.RowImage, 0) + + sqlRows := util.NewScanRows(rowsi) + + for sqlRows.Next() { + ss := b.GetScanSlice(columnNames, tableMetaData) + + err := sqlRows.Scan(ss...) + if err != nil { + return nil, err + } + + columns := make([]types.ColumnImage, 0) + // build record image + for i, name := range columnNames { + columnMeta := tableMetaData.Columns[name] + + keyType := types.IndexTypeNull + if _, ok := tableMetaData.GetPrimaryKeyMap()[name]; ok { + keyType = types.IndexTypePrimaryKey + } + jdbcType := types.MySQLStrToJavaType(columnMeta.DatabaseTypeString) + + columns = append(columns, types.ColumnImage{ + KeyType: keyType, + ColumnName: name, + ColumnType: jdbcType, + Value: reflectx.GetElemDataValue(ss[i]), + }) + } + rowImages = append(rowImages, types.RowImage{Columns: columns}) + } + + return &types.RecordImage{TableName: tableMetaData.TableName, Rows: rowImages}, nil +} + +// buildWhereConditionByPKs build where condition by primary keys +// each pk is a condition.the result will like :" (id,userCode) in ((?,?),(?,?)) or (id,userCode) in ((?,?),(?,?) ) or (id,userCode) in ((?,?))" +func (b *baseExecutor) buildWhereConditionByPKs(pkNameList []string, rowSize int, dbType string, maxInSize int) string { + var ( + whereStr = &strings.Builder{} + batchSize = rowSize/maxInSize + 1 + ) + + if rowSize%maxInSize == 0 { + batchSize = rowSize / maxInSize + } + + for batch := 0; batch < batchSize; batch++ { + if batch > 0 { + whereStr.WriteString(" OR ") + } + whereStr.WriteString("(") + + for i := 0; i < len(pkNameList); i++ { + if i > 0 { + whereStr.WriteString(",") + } + // todo add escape + whereStr.WriteString(fmt.Sprintf("`%s`", pkNameList[i])) + } + whereStr.WriteString(") IN (") + + var eachSize int + + if batch == batchSize-1 { + if rowSize%maxInSize == 0 { + eachSize = maxInSize + } else { + eachSize = rowSize % maxInSize + } + } else { + eachSize = maxInSize + } + + for i := 0; i < eachSize; i++ { + if i > 0 { + whereStr.WriteString(",") + } + whereStr.WriteString("(") + for j := 0; j < len(pkNameList); j++ { + if j > 0 { + whereStr.WriteString(",") + } + whereStr.WriteString("?") + } + whereStr.WriteString(")") + } + whereStr.WriteString(")") + } + return whereStr.String() +} + +func (b *baseExecutor) buildPKParams(rows []types.RowImage, pkNameList []string) []driver.NamedValue { + params := make([]driver.NamedValue, 0) + for _, row := range rows { + coumnMap := row.GetColumnMap() + for i, pk := range pkNameList { + if col, ok := coumnMap[pk]; ok { + params = append(params, driver.NamedValue{ + Ordinal: i, Value: col.Value, + }) + } + } + } + return params +} + +// the string as local key. the local key example(multi pk): "t_user:1_a,2_b" +func (b *baseExecutor) buildLockKey(records *types.RecordImage, meta types.TableMeta) string { + var ( + lockKeys bytes.Buffer + filedSequence int + ) + lockKeys.WriteString(meta.TableName) + lockKeys.WriteString(":") + + keys := meta.GetPrimaryKeyOnlyName() + + for _, row := range records.Rows { + if filedSequence > 0 { + lockKeys.WriteString(",") + } + pkSplitIndex := 0 + for _, column := range row.Columns { + var hasKeyColumn bool + for _, key := range keys { + if column.ColumnName == key { + hasKeyColumn = true + if pkSplitIndex > 0 { + lockKeys.WriteString("_") + } + lockKeys.WriteString(fmt.Sprintf("%v", column.Value)) + pkSplitIndex++ + } + } + if hasKeyColumn { + filedSequence++ + } + } + } + + return lockKeys.String() +} diff --git a/pkg/datasource/sql/exec/at/executor_at.go b/pkg/datasource/sql/exec/at/executor_at.go index 22d7c65c..ca9167bd 100644 --- a/pkg/datasource/sql/exec/at/executor_at.go +++ b/pkg/datasource/sql/exec/at/executor_at.go @@ -19,182 +19,61 @@ package at import ( "context" - "fmt" - "github.com/mitchellh/copystructure" - "github.com/pkg/errors" - "github.com/seata/seata-go/pkg/datasource/sql/parser" - "github.com/seata/seata-go/pkg/datasource/sql/undo" + "github.com/seata/seata-go/pkg/datasource/sql/util" "github.com/seata/seata-go/pkg/tm" "github.com/seata/seata-go/pkg/datasource/sql/exec" + "github.com/seata/seata-go/pkg/datasource/sql/parser" "github.com/seata/seata-go/pkg/datasource/sql/types" ) -type ATExecutor struct { - hooks []exec.SQLHook - ex exec.SQLExecutor +func init() { + exec.RegisterATExecutor(types.DBTypeMySQL, func() exec.SQLExecutor { return &AtExecutor{} }) } -func (e *ATExecutor) Interceptors(hooks []exec.SQLHook) { - e.hooks = hooks +type executor interface { + ExecContext(ctx context.Context, f exec.CallbackWithNamedValue) (types.ExecResult, error) } -func (e *ATExecutor) ExecWithNamedValue(ctx context.Context, execCtx *types.ExecContext, f exec.CallbackWithNamedValue) (types.ExecResult, error) { - for _, hook := range e.hooks { - hook.Before(ctx, execCtx) - } - - var ( - beforeImages []*types.RecordImage - afterImages []*types.RecordImage - result types.ExecResult - err error - ) - - beforeImages, err = e.beforeImage(ctx, execCtx) - if err != nil { - return nil, err - } - if beforeImages != nil { - beforeImagesTmp, err := copystructure.Copy(beforeImages) - if err != nil { - return nil, err - } - newBeforeImages, ok := beforeImagesTmp.([]*types.RecordImage) - if !ok { - return nil, errors.New("copy beforeImages failed") - } - execCtx.TxCtx.RoundImages.AppendBeofreImages(newBeforeImages) - } - - defer func() { - for _, hook := range e.hooks { - hook.After(ctx, execCtx) - } - }() - - if e.ex != nil { - result, err = e.ex.ExecWithNamedValue(ctx, execCtx, f) - } else { - result, err = f(ctx, execCtx.Query, execCtx.NamedValues) - } - - if err != nil { - return nil, err - } - - afterImages, err = e.afterImage(ctx, execCtx, beforeImages) - if err != nil { - return nil, err - } - if afterImages != nil { - execCtx.TxCtx.RoundImages.AppendAfterImages(afterImages) - } - - return result, err +type AtExecutor struct { + hooks []exec.SQLHook } -func (e *ATExecutor) prepareUndoLog(ctx context.Context, execCtx *types.ExecContext) error { - if execCtx.TxCtx.RoundImages.IsEmpty() { - return nil - } - - if execCtx.ParseContext.UpdateStmt != nil { - if !execCtx.TxCtx.RoundImages.IsBeforeAfterSizeEq() { - return fmt.Errorf("Before image size is not equaled to after image size, probably because you updated the primary keys.") - } - } - undoLogManager, err := undo.GetUndoLogManager(execCtx.DBType) - if err != nil { - return err - } - return undoLogManager.FlushUndoLog(execCtx.TxCtx, execCtx.Conn) +func (e *AtExecutor) Interceptors(hooks []exec.SQLHook) { + e.hooks = hooks } -func (e *ATExecutor) ExecWithValue(ctx context.Context, execCtx *types.ExecContext, f exec.CallbackWithValue) (types.ExecResult, error) { - for _, hook := range e.hooks { - hook.Before(ctx, execCtx) - } - - var ( - beforeImages []*types.RecordImage - afterImages []*types.RecordImage - result types.ExecResult - err error - ) - - beforeImages, err = e.beforeImage(ctx, execCtx) - if err != nil { - return nil, err - } - if beforeImages != nil { - execCtx.TxCtx.RoundImages.AppendBeofreImages(beforeImages) - } - - defer func() { - for _, hook := range e.hooks { - hook.After(ctx, execCtx) - } - }() - - if e.ex != nil { - result, err = e.ex.ExecWithValue(ctx, execCtx, f) - } else { - result, err = f(ctx, execCtx.Query, execCtx.Values) - } - if err != nil { - return nil, err - } - - afterImages, err = e.afterImage(ctx, execCtx, beforeImages) +//ExecWithNamedValue +func (e *AtExecutor) ExecWithNamedValue(ctx context.Context, execCtx *types.ExecContext, f exec.CallbackWithNamedValue) (types.ExecResult, error) { + parser, err := parser.DoParser(execCtx.Query) if err != nil { return nil, err } - if afterImages != nil { - execCtx.TxCtx.RoundImages.AppendAfterImages(afterImages) - } - return result, err -} + var exec executor -func (e *ATExecutor) beforeImage(ctx context.Context, execCtx *types.ExecContext) ([]*types.RecordImage, error) { if !tm.IsGlobalTx(ctx) { - return nil, nil - } - - pc, err := parser.DoParser(execCtx.Query) - if err != nil { - return nil, err - } - if !pc.HasValidStmt() { - return nil, nil + exec = NewPlainExecutor(parser, execCtx) + } else { + switch parser.SQLType { + //case types.SQLTypeInsert: + case types.SQLTypeUpdate: + exec = NewUpdateExecutor(parser, execCtx, e.hooks) + //case types.SQLTypeDelete: + //case types.SQLTypeSelectForUpdate: + //case types.SQLTypeMultiDelete: + //case types.SQLTypeMultiUpdate: + default: + exec = NewPlainExecutor(parser, execCtx) + } } - execCtx.ParseContext = pc - builder := undo.GetUndologBuilder(pc.ExecutorType) - if builder == nil { - return nil, nil - } - return builder.BeforeImage(ctx, execCtx) + return exec.ExecContext(ctx, f) } -// After -func (e *ATExecutor) afterImage(ctx context.Context, execCtx *types.ExecContext, beforeImages []*types.RecordImage) ([]*types.RecordImage, error) { - if !tm.IsGlobalTx(ctx) { - return nil, nil - } - pc, err := parser.DoParser(execCtx.Query) - if err != nil { - return nil, err - } - if !pc.HasValidStmt() { - return nil, nil - } - execCtx.ParseContext = pc - builder := undo.GetUndologBuilder(pc.ExecutorType) - if builder == nil { - return nil, nil - } - return builder.AfterImage(ctx, execCtx, beforeImages) +//ExecWithValue +func (e *AtExecutor) ExecWithValue(ctx context.Context, execCtx *types.ExecContext, f exec.CallbackWithNamedValue) (types.ExecResult, error) { + execCtx.NamedValues = util.ValueToNamedValue(execCtx.Values) + return e.ExecWithNamedValue(ctx, execCtx, f) } diff --git a/pkg/datasource/sql/exec/at/default.go b/pkg/datasource/sql/exec/at/plain_executor.go similarity index 65% rename from pkg/datasource/sql/exec/at/default.go rename to pkg/datasource/sql/exec/at/plain_executor.go index 0eb5699f..bb02a5b1 100644 --- a/pkg/datasource/sql/exec/at/default.go +++ b/pkg/datasource/sql/exec/at/plain_executor.go @@ -18,21 +18,21 @@ package at import ( + "context" + "github.com/seata/seata-go/pkg/datasource/sql/exec" "github.com/seata/seata-go/pkg/datasource/sql/types" ) -func init() { - exec.RegisterATExecutor(types.DBTypeMySQL, types.UpdateExecutor, func() exec.SQLExecutor { - return &ATExecutor{} - }) - exec.RegisterATExecutor(types.DBTypeMySQL, types.SelectExecutor, func() exec.SQLExecutor { - return &ATExecutor{} - }) - exec.RegisterATExecutor(types.DBTypeMySQL, types.InsertExecutor, func() exec.SQLExecutor { - return &ATExecutor{} - }) - exec.RegisterATExecutor(types.DBTypeMySQL, types.DeleteExecutor, func() exec.SQLExecutor { - return &ATExecutor{} - }) +type PlainExecutor struct { + parserCtx *types.ParseContext + execCtx *types.ExecContext +} + +func NewPlainExecutor(parserCtx *types.ParseContext, execCtx *types.ExecContext) *PlainExecutor { + return &PlainExecutor{parserCtx: parserCtx, execCtx: execCtx} +} + +func (u *PlainExecutor) ExecContext(ctx context.Context, f exec.CallbackWithNamedValue) (types.ExecResult, error) { + return f(ctx, u.execCtx.Query, u.execCtx.NamedValues) } diff --git a/pkg/datasource/sql/exec/at/update_executor.go b/pkg/datasource/sql/exec/at/update_executor.go new file mode 100644 index 00000000..57493441 --- /dev/null +++ b/pkg/datasource/sql/exec/at/update_executor.go @@ -0,0 +1,272 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package at + +import ( + "context" + "database/sql/driver" + "fmt" + "strings" + + "github.com/arana-db/parser/ast" + "github.com/arana-db/parser/format" + "github.com/arana-db/parser/model" + "github.com/seata/seata-go/pkg/datasource/sql/datasource" + "github.com/seata/seata-go/pkg/datasource/sql/exec" + "github.com/seata/seata-go/pkg/datasource/sql/types" + "github.com/seata/seata-go/pkg/datasource/sql/util" + "github.com/seata/seata-go/pkg/util/bytes" + "github.com/seata/seata-go/pkg/util/log" +) + +var ( + // todo: OnlyCareUpdateColumns should load from config first + onlyCareUpdateColumns = true + maxInSize = 1000 +) + +type updateExecutor struct { + baseExecutor + parserCtx *types.ParseContext + execContent *types.ExecContext +} + +//NewUpdateExecutor get update executor +func NewUpdateExecutor(parserCtx *types.ParseContext, execContent *types.ExecContext, hooks []exec.SQLHook) executor { + return &updateExecutor{parserCtx: parserCtx, execContent: execContent, baseExecutor: baseExecutor{hooks: hooks}} +} + +//ExecContext exec SQL, and generate before image and after image +func (u *updateExecutor) ExecContext(ctx context.Context, f exec.CallbackWithNamedValue) (types.ExecResult, error) { + u.beforeHooks(ctx, u.execContent) + defer func() { + u.afterHooks(ctx, u.execContent) + }() + + beforeImage, err := u.beforeImage(ctx) + if err != nil { + return nil, err + } + + res, err := f(ctx, u.execContent.Query, u.execContent.NamedValues) + if err != nil { + return nil, err + } + + afterImage, err := u.afterImage(ctx, *beforeImage) + if err != nil { + return nil, err + } + + if len(beforeImage.Rows) != len(afterImage.Rows) { + return nil, fmt.Errorf("Before image size is not equaled to after image size, probably because you updated the primary keys.") + } + + u.execContent.TxCtx.RoundImages.AppendBeofreImage(beforeImage) + u.execContent.TxCtx.RoundImages.AppendAfterImage(afterImage) + + return res, nil +} + +//beforeImage build before image +func (u *updateExecutor) beforeImage(ctx context.Context) (*types.RecordImage, error) { + if !u.isAstStmtValid() { + return nil, nil + } + + selectSQL, selectArgs, err := u.buildBeforeImageSQL(ctx, u.execContent.NamedValues) + if err != nil { + return nil, err + } + + tableName, _ := u.parserCtx.GteTableName() + metaData, err := datasource.GetTableCache(types.DBTypeMySQL).GetTableMeta(ctx, u.execContent.DBName, tableName) + if err != nil { + return nil, err + } + + var rowsi driver.Rows + queryerCtx, ok := u.execContent.Conn.(driver.QueryerContext) + var queryer driver.Queryer + if !ok { + queryer, ok = u.execContent.Conn.(driver.Queryer) + } + if ok { + rowsi, err = util.CtxDriverQuery(ctx, queryerCtx, queryer, selectSQL, selectArgs) + if err != nil { + log.Errorf("ctx driver query: %+v", err) + return nil, err + } + } else { + log.Errorf("target conn should been driver.QueryerContext or driver.Queryer") + return nil, fmt.Errorf("invalid conn") + } + + image, err := u.buildRecordImages(rowsi, metaData) + if err != nil { + return nil, err + } + + lockKey := u.buildLockKey(image, *metaData) + u.execContent.TxCtx.LockKeys[lockKey] = struct{}{} + image.SQLType = u.parserCtx.SQLType + + return image, nil +} + +//afterImage build after image +func (u *updateExecutor) afterImage(ctx context.Context, beforeImage types.RecordImage) (*types.RecordImage, error) { + if !u.isAstStmtValid() { + return nil, nil + } + if len(beforeImage.Rows) == 0 { + return &types.RecordImage{}, nil + } + + tableName, _ := u.parserCtx.GteTableName() + metaData, err := datasource.GetTableCache(types.DBTypeMySQL).GetTableMeta(ctx, u.execContent.DBName, tableName) + if err != nil { + return nil, err + } + selectSQL, selectArgs := u.buildAfterImageSQL(beforeImage, metaData) + + var rowsi driver.Rows + queryerCtx, ok := u.execContent.Conn.(driver.QueryerContext) + var queryer driver.Queryer + if !ok { + queryer, ok = u.execContent.Conn.(driver.Queryer) + } + if ok { + rowsi, err = util.CtxDriverQuery(ctx, queryerCtx, queryer, selectSQL, selectArgs) + if err != nil { + log.Errorf("ctx driver query: %+v", err) + return nil, err + } + } else { + log.Errorf("target conn should been driver.QueryerContext or driver.Queryer") + return nil, fmt.Errorf("invalid conn") + } + + afterImage, err := u.buildRecordImages(rowsi, metaData) + if err != nil { + return nil, err + } + afterImage.SQLType = u.parserCtx.SQLType + + return afterImage, nil +} + +func (u *updateExecutor) isAstStmtValid() bool { + return u.parserCtx != nil && u.parserCtx.UpdateStmt != nil +} + +//buildAfterImageSQL build the SQL to query after image data +func (u *updateExecutor) buildAfterImageSQL(beforeImage types.RecordImage, meta *types.TableMeta) (string, []driver.NamedValue) { + if len(beforeImage.Rows) == 0 { + return "", nil + } + sb := strings.Builder{} + // todo: OnlyCareUpdateColumns should load from config first + var selectFields string + var separator = "," + if onlyCareUpdateColumns { + for _, row := range beforeImage.Rows { + for _, column := range row.Columns { + selectFields += column.ColumnName + separator + } + } + selectFields = strings.TrimSuffix(selectFields, separator) + } else { + selectFields = "*" + } + sb.WriteString("SELECT " + selectFields + " FROM " + meta.TableName + " WHERE ") + whereSQL := u.buildWhereConditionByPKs(meta.GetPrimaryKeyOnlyName(), len(beforeImage.Rows), "mysql", maxInSize) + sb.WriteString(" " + whereSQL + " ") + return sb.String(), u.buildPKParams(beforeImage.Rows, meta.GetPrimaryKeyOnlyName()) +} + +//buildAfterImageSQL build the SQL to query before image data +func (u *updateExecutor) buildBeforeImageSQL(ctx context.Context, args []driver.NamedValue) (string, []driver.NamedValue, error) { + if !u.isAstStmtValid() { + log.Errorf("invalid update stmt") + return "", nil, fmt.Errorf("invalid update stmt") + } + + updateStmt := u.parserCtx.UpdateStmt + fields := make([]*ast.SelectField, 0, len(updateStmt.List)) + + if onlyCareUpdateColumns { + for _, column := range updateStmt.List { + fields = append(fields, &ast.SelectField{ + Expr: &ast.ColumnNameExpr{ + Name: column.Column, + }, + }) + } + + // select indexes columns + tableName, _ := u.parserCtx.GteTableName() + metaData, err := datasource.GetTableCache(types.DBTypeMySQL).GetTableMeta(ctx, u.execContent.DBName, tableName) + if err != nil { + return "", nil, err + } + for _, columnName := range metaData.GetPrimaryKeyOnlyName() { + fields = append(fields, &ast.SelectField{ + Expr: &ast.ColumnNameExpr{ + Name: &ast.ColumnName{ + Name: model.CIStr{ + O: columnName, + L: columnName, + }, + }, + }, + }) + } + } else { + fields = append(fields, &ast.SelectField{ + Expr: &ast.ColumnNameExpr{ + Name: &ast.ColumnName{ + Name: model.CIStr{ + O: "*", + L: "*", + }, + }, + }, + }) + } + + selStmt := ast.SelectStmt{ + SelectStmtOpts: &ast.SelectStmtOpts{}, + From: updateStmt.TableRefs, + Where: updateStmt.Where, + Fields: &ast.FieldList{Fields: fields}, + OrderBy: updateStmt.Order, + Limit: updateStmt.Limit, + TableHints: updateStmt.TableHints, + LockInfo: &ast.SelectLockInfo{ + LockType: ast.SelectLockForUpdate, + }, + } + + b := bytes.NewByteBuffer([]byte{}) + _ = selStmt.Restore(format.NewRestoreCtx(format.RestoreKeyWordUppercase, b)) + sql := string(b.Bytes()) + log.Infof("build select sql by update sourceQuery, sql {%s}", sql) + + return sql, u.buildSelectArgs(&selStmt, args), nil +} diff --git a/pkg/datasource/sql/exec/at/update_executor_test.go b/pkg/datasource/sql/exec/at/update_executor_test.go new file mode 100644 index 00000000..ee04c48d --- /dev/null +++ b/pkg/datasource/sql/exec/at/update_executor_test.go @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package at + +import ( + "context" + "database/sql/driver" + "reflect" + "testing" + + "github.com/agiledragon/gomonkey" + "github.com/seata/seata-go/pkg/datasource/sql/datasource" + "github.com/seata/seata-go/pkg/datasource/sql/datasource/mysql" + "github.com/seata/seata-go/pkg/datasource/sql/exec" + "github.com/seata/seata-go/pkg/datasource/sql/util" + + "github.com/seata/seata-go/pkg/datasource/sql/types" + + "github.com/seata/seata-go/pkg/datasource/sql/parser" + + _ "github.com/arana-db/parser/test_driver" + _ "github.com/seata/seata-go/pkg/util/log" + "github.com/stretchr/testify/assert" +) + +func TestBuildSelectSQLByUpdate(t *testing.T) { + datasource.RegisterTableCache(types.DBTypeMySQL, mysql.NewTableMetaInstance(nil)) + stub := gomonkey.ApplyMethod(reflect.TypeOf(datasource.GetTableCache(types.DBTypeMySQL)), "GetTableMeta", + func(_ *mysql.TableMetaCache, ctx context.Context, dbName, tableName string) (*types.TableMeta, error) { + return &types.TableMeta{ + Indexs: map[string]types.IndexMeta{ + "id": { + IType: types.IndexTypePrimaryKey, + Columns: []types.ColumnMeta{ + {ColumnName: "id"}, + }, + }, + }, + }, nil + }) + defer stub.Reset() + + tests := []struct { + name string + sourceQuery string + sourceQueryArgs []driver.Value + expectQuery string + expectQueryArgs []driver.Value + }{ + { + sourceQuery: "update t_user set name = ?, age = ? where id = ?", + sourceQueryArgs: []driver.Value{"Jack", 1, 100}, + expectQuery: "SELECT SQL_NO_CACHE name,age,id FROM t_user WHERE id=? FOR UPDATE", + expectQueryArgs: []driver.Value{100}, + }, + { + sourceQuery: "update t_user set name = ?, age = ? where id = ? and name = 'Jack' and age between ? and ?", + sourceQueryArgs: []driver.Value{"Jack", 1, 100, 18, 28}, + expectQuery: "SELECT SQL_NO_CACHE name,age,id FROM t_user WHERE id=? AND name=_UTF8MB4Jack AND age BETWEEN ? AND ? FOR UPDATE", + expectQueryArgs: []driver.Value{100, 18, 28}, + }, + { + sourceQuery: "update t_user set name = ?, age = ? where id = ? and name = 'Jack' and age in (?,?)", + sourceQueryArgs: []driver.Value{"Jack", 1, 100, 18, 28}, + expectQuery: "SELECT SQL_NO_CACHE name,age,id FROM t_user WHERE id=? AND name=_UTF8MB4Jack AND age IN (?,?) FOR UPDATE", + expectQueryArgs: []driver.Value{100, 18, 28}, + }, + { + sourceQuery: "update t_user set name = ?, age = ? where kk between ? and ? and id = ? and addr in(?,?) and age > ? order by name desc limit ?", + sourceQueryArgs: []driver.Value{"Jack", 1, 10, 20, 17, "Beijing", "Guangzhou", 18, 2}, + expectQuery: "SELECT SQL_NO_CACHE name,age,id FROM t_user WHERE kk BETWEEN ? AND ? AND id=? AND addr IN (?,?) AND age>? ORDER BY name DESC LIMIT ? FOR UPDATE", + expectQueryArgs: []driver.Value{10, 20, 17, "Beijing", "Guangzhou", 18, 2}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c, err := parser.DoParser(tt.sourceQuery) + assert.Nil(t, err) + executor := NewUpdateExecutor(c, &types.ExecContext{Values: tt.sourceQueryArgs, NamedValues: util.ValueToNamedValue(tt.sourceQueryArgs)}, []exec.SQLHook{}) + query, args, err := executor.(*updateExecutor).buildBeforeImageSQL(context.Background(), util.ValueToNamedValue(tt.sourceQueryArgs)) + assert.Nil(t, err) + assert.Equal(t, tt.expectQuery, query) + assert.Equal(t, tt.expectQueryArgs, util.NamedValueToValue(args)) + }) + } +} diff --git a/pkg/datasource/sql/exec/executor.go b/pkg/datasource/sql/exec/executor.go index 09337330..37feb854 100644 --- a/pkg/datasource/sql/exec/executor.go +++ b/pkg/datasource/sql/exec/executor.go @@ -25,35 +25,25 @@ import ( "github.com/seata/seata-go/pkg/datasource/sql/types" "github.com/seata/seata-go/pkg/datasource/sql/undo" "github.com/seata/seata-go/pkg/datasource/sql/undo/builder" - "github.com/seata/seata-go/pkg/util/log" ) func init() { - undo.RegisterUndoLogBuilder(types.UpdateExecutor, builder.GetMySQLUpdateUndoLogBuilder) undo.RegisterUndoLogBuilder(types.MultiExecutor, builder.GetMySQLMultiUndoLogBuilder) } var ( - executorSoltsAT = make(map[types.DBType]map[types.ExecutorType]func() SQLExecutor) - executorSoltsXA = make(map[types.DBType]func() SQLExecutor) + atExecutors = make(map[types.DBType]func() SQLExecutor) + xaExecutors = make(map[types.DBType]func() SQLExecutor) ) // RegisterATExecutor AT executor -func RegisterATExecutor(dt types.DBType, et types.ExecutorType, builder func() SQLExecutor) { - if _, ok := executorSoltsAT[dt]; !ok { - executorSoltsAT[dt] = make(map[types.ExecutorType]func() SQLExecutor) - } - - val := executorSoltsAT[dt] - - val[et] = func() SQLExecutor { - return &BaseExecutor{ex: builder()} - } +func RegisterATExecutor(dt types.DBType, builder func() SQLExecutor) { + atExecutors[dt] = builder } // RegisterXAExecutor XA executor func RegisterXAExecutor(dt types.DBType, builder func() SQLExecutor) { - executorSoltsXA[dt] = func() SQLExecutor { + xaExecutors[dt] = func() SQLExecutor { return builder() } } @@ -66,7 +56,7 @@ type ( SQLExecutor interface { Interceptors(interceptors []SQLHook) ExecWithNamedValue(ctx context.Context, execCtx *types.ExecContext, f CallbackWithNamedValue) (types.ExecResult, error) - ExecWithValue(ctx context.Context, execCtx *types.ExecContext, f CallbackWithValue) (types.ExecResult, error) + ExecWithValue(ctx context.Context, execCtx *types.ExecContext, f CallbackWithNamedValue) (types.ExecResult, error) } ) @@ -83,37 +73,14 @@ func BuildExecutor(dbType types.DBType, transactionMode types.TransactionMode, q hooks = append(hooks, hookSolts[parseContext.SQLType]...) if transactionMode == types.XAMode { - e := executorSoltsXA[dbType]() - e.Interceptors(hooks) - return e, nil - } - - if transactionMode == types.ATMode { - e := executorSoltsAT[dbType][parseContext.ExecutorType]() - e.Interceptors(hooks) - return e, nil - } - - factories, ok := executorSoltsAT[dbType] - if !ok { - log.Debugf("%s not found executor factories, return default Executor", dbType.String()) - e := &BaseExecutor{} + e := xaExecutors[dbType]() e.Interceptors(hooks) return e, nil } - supplier, ok := factories[parseContext.ExecutorType] - if !ok { - log.Debugf("%s not found executor for %s, return default Executor", - dbType.String(), parseContext.ExecutorType) - e := &BaseExecutor{} - e.Interceptors(hooks) - return e, nil - } - - executor := supplier() - executor.Interceptors(hooks) - return executor, nil + e := atExecutors[dbType]() + e.Interceptors(hooks) + return e, nil } type BaseExecutor struct { @@ -146,7 +113,7 @@ func (e *BaseExecutor) ExecWithNamedValue(ctx context.Context, execCtx *types.Ex } // ExecWithValue -func (e *BaseExecutor) ExecWithValue(ctx context.Context, execCtx *types.ExecContext, f CallbackWithValue) (types.ExecResult, error) { +func (e *BaseExecutor) ExecWithValue(ctx context.Context, execCtx *types.ExecContext, f CallbackWithNamedValue) (types.ExecResult, error) { for i := range e.hooks { e.hooks[i].Before(ctx, execCtx) } @@ -161,5 +128,13 @@ func (e *BaseExecutor) ExecWithValue(ctx context.Context, execCtx *types.ExecCon return e.ex.ExecWithValue(ctx, execCtx, f) } - return f(ctx, execCtx.Query, execCtx.Values) + nvargs := make([]driver.NamedValue, len(execCtx.Values)) + for i, value := range execCtx.Values { + nvargs = append(nvargs, driver.NamedValue{ + Value: value, + Ordinal: i, + }) + } + + return f(ctx, execCtx.Query, nvargs) } diff --git a/pkg/datasource/sql/exec/xa/executor_xa.go b/pkg/datasource/sql/exec/xa/executor_xa.go index e7a0c3c6..8075a454 100644 --- a/pkg/datasource/sql/exec/xa/executor_xa.go +++ b/pkg/datasource/sql/exec/xa/executor_xa.go @@ -19,6 +19,7 @@ package xa import ( "context" + "database/sql/driver" "github.com/seata/seata-go/pkg/datasource/sql/exec" "github.com/seata/seata-go/pkg/datasource/sql/types" @@ -55,7 +56,7 @@ func (e *XAExecutor) ExecWithNamedValue(ctx context.Context, execCtx *types.Exec } // ExecWithValue -func (e *XAExecutor) ExecWithValue(ctx context.Context, execCtx *types.ExecContext, f exec.CallbackWithValue) (types.ExecResult, error) { +func (e *XAExecutor) ExecWithValue(ctx context.Context, execCtx *types.ExecContext, f exec.CallbackWithNamedValue) (types.ExecResult, error) { for _, hook := range e.hooks { hook.Before(ctx, execCtx) } @@ -70,5 +71,14 @@ func (e *XAExecutor) ExecWithValue(ctx context.Context, execCtx *types.ExecConte return e.ex.ExecWithValue(ctx, execCtx, f) } - return f(ctx, execCtx.Query, execCtx.Values) + nvargs := make([]driver.NamedValue, len(execCtx.Values)) + for i, value := range execCtx.Values { + nvargs = append(nvargs, driver.NamedValue{ + Value: value, + Ordinal: i, + }) + } + execCtx.NamedValues = nvargs + + return f(ctx, execCtx.Query, execCtx.NamedValues) } diff --git a/pkg/datasource/sql/stmt.go b/pkg/datasource/sql/stmt.go index 9e5a84ed..2314e344 100644 --- a/pkg/datasource/sql/stmt.go +++ b/pkg/datasource/sql/stmt.go @@ -23,6 +23,7 @@ import ( "github.com/seata/seata-go/pkg/datasource/sql/exec" "github.com/seata/seata-go/pkg/datasource/sql/types" + "github.com/seata/seata-go/pkg/datasource/sql/util" ) type Stmt struct { @@ -75,8 +76,8 @@ func (s *Stmt) Query(args []driver.Value) (driver.Rows, error) { } ret, err := executor.ExecWithValue(context.Background(), execCtx, - func(ctx context.Context, query string, args []driver.Value) (types.ExecResult, error) { - ret, err := s.stmt.Query(args) + func(ctx context.Context, query string, args []driver.NamedValue) (types.ExecResult, error) { + ret, err := s.stmt.Query(util.NamedValueToValue(args)) if err != nil { return nil, err } @@ -144,8 +145,8 @@ func (s *Stmt) Exec(args []driver.Value) (driver.Result, error) { } ret, err := executor.ExecWithValue(context.Background(), execCtx, - func(ctx context.Context, query string, args []driver.Value) (types.ExecResult, error) { - ret, err := s.stmt.Exec(args) + func(ctx context.Context, query string, args []driver.NamedValue) (types.ExecResult, error) { + ret, err := s.stmt.Exec(util.NamedValueToValue(args)) if err != nil { return nil, err } diff --git a/pkg/datasource/sql/types/image.go b/pkg/datasource/sql/types/image.go index fced189a..eb9251a1 100644 --- a/pkg/datasource/sql/types/image.go +++ b/pkg/datasource/sql/types/image.go @@ -129,7 +129,8 @@ type RowImage struct { func (r *RowImage) GetColumnMap() map[string]*ColumnImage { m := make(map[string]*ColumnImage, 0) for _, column := range r.Columns { - m[column.ColumnName] = &column + tmpColumn := column + m[column.ColumnName] = &tmpColumn } return m } @@ -245,7 +246,7 @@ func (c *ColumnImage) UnmarshalJSON(data []byte) error { case JDBCTypeChar, JDBCTypeVarchar: var val []byte if val, err = base64.StdEncoding.DecodeString(value.(string)); err != nil { - return err + val = []byte(value.(string)) } actualValue = string(val) case JDBCTypeBinary, JDBCTypeVarBinary, JDBCTypeLongVarBinary, JDBCTypeBit: diff --git a/pkg/datasource/sql/util/convert.go b/pkg/datasource/sql/util/convert.go new file mode 100644 index 00000000..bfcf82f2 --- /dev/null +++ b/pkg/datasource/sql/util/convert.go @@ -0,0 +1,377 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Type conversions for Scan. + +package util + +import ( + "database/sql" + "database/sql/driver" + "errors" + "fmt" + "reflect" + "strconv" + "time" +) + +var errNilPtr = errors.New("destination pointer is nil") // embedded in descriptive error + +// convertAssignRows copies to dest the value in src, converting it if possible. +// An error is returned if the copy would result in loss of information. +// dest should be a pointer type. If rows is passed in, the rows will +// be used as the parent for any cursor values converted from a +// driver.Rows to a *ScanRows. +func convertAssignRows(dest, src interface{}, rows *ScanRows) error { + // Common cases, without reflect. + switch s := src.(type) { + case string: + switch d := dest.(type) { + case *string: + if d == nil { + return errNilPtr + } + *d = s + return nil + case *[]byte: + if d == nil { + return errNilPtr + } + *d = []byte(s) + return nil + case *sql.RawBytes: + if d == nil { + return errNilPtr + } + *d = append((*d)[:0], s...) + return nil + } + case []byte: + switch d := dest.(type) { + case *string: + if d == nil { + return errNilPtr + } + *d = string(s) + return nil + case *interface{}: + if d == nil { + return errNilPtr + } + *d = cloneBytes(s) + return nil + case *[]byte: + if d == nil { + return errNilPtr + } + *d = cloneBytes(s) + return nil + case *sql.RawBytes: + if d == nil { + return errNilPtr + } + *d = s + return nil + } + case time.Time: + switch d := dest.(type) { + case *time.Time: + *d = s + return nil + case *string: + *d = s.Format(time.RFC3339Nano) + return nil + case *[]byte: + if d == nil { + return errNilPtr + } + *d = []byte(s.Format(time.RFC3339Nano)) + return nil + case *sql.RawBytes: + if d == nil { + return errNilPtr + } + *d = s.AppendFormat((*d)[:0], time.RFC3339Nano) + return nil + } + case decimalDecompose: + switch d := dest.(type) { + case decimalCompose: + return d.Compose(s.Decompose(nil)) + } + case nil: + switch d := dest.(type) { + case *interface{}: + if d == nil { + return errNilPtr + } + *d = nil + return nil + case *[]byte: + if d == nil { + return errNilPtr + } + *d = nil + return nil + case *sql.RawBytes: + if d == nil { + return errNilPtr + } + *d = nil + return nil + } + // The driver is returning a cursor the client may iterate over. + case driver.Rows: + switch d := dest.(type) { + case *ScanRows: + if d == nil { + return errNilPtr + } + if rows == nil { + return errors.New("invalid context to convert cursor rows, missing parent *ScanRows") + } + rows.closemu.Lock() + *d = ScanRows{ + dc: rows.dc, + releaseConn: func(error) {}, + rowsi: s, + } + // Chain the cancel function. + parentCancel := rows.cancel + rows.cancel = func() { + // When ScanRows.cancel is called, the closemu will be locked as well. + // So we can access rs.lasterr. + d.close(rows.lasterr) + if parentCancel != nil { + parentCancel() + } + } + rows.closemu.Unlock() + return nil + } + } + + var sv reflect.Value + + switch d := dest.(type) { + case *string: + sv = reflect.ValueOf(src) + switch sv.Kind() { + case reflect.Bool, + reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, + reflect.Float32, reflect.Float64: + *d = asString(src) + return nil + } + case *[]byte: + sv = reflect.ValueOf(src) + if b, ok := asBytes(nil, sv); ok { + *d = b + return nil + } + case *sql.RawBytes: + sv = reflect.ValueOf(src) + if b, ok := asBytes([]byte(*d)[:0], sv); ok { + *d = sql.RawBytes(b) + return nil + } + case *bool: + bv, err := driver.Bool.ConvertValue(src) + if err == nil { + *d = bv.(bool) + } + return err + case *interface{}: + *d = src + return nil + } + + if scanner, ok := dest.(sql.Scanner); ok { + return scanner.Scan(src) + } + + dpv := reflect.ValueOf(dest) + if dpv.Kind() != reflect.Ptr { + return errors.New("destination not a pointer") + } + if dpv.IsNil() { + return errNilPtr + } + + if !sv.IsValid() { + sv = reflect.ValueOf(src) + } + + dv := reflect.Indirect(dpv) + if sv.IsValid() && sv.Type().AssignableTo(dv.Type()) { + switch b := src.(type) { + case []byte: + dv.Set(reflect.ValueOf(cloneBytes(b))) + default: + dv.Set(sv) + } + return nil + } + + if dv.Kind() == sv.Kind() && sv.Type().ConvertibleTo(dv.Type()) { + dv.Set(sv.Convert(dv.Type())) + return nil + } + + // The following conversions use a string value as an intermediate representation + // to convert between various numeric types. + // + // This also allows scanning into user defined types such as "type Int int64". + // For symmetry, also check for string destination types. + switch dv.Kind() { + case reflect.Ptr: + if src == nil { + dv.Set(reflect.Zero(dv.Type())) + return nil + } + dv.Set(reflect.New(dv.Type().Elem())) + return convertAssignRows(dv.Interface(), src, rows) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + if src == nil { + return fmt.Errorf("converting NULL to %s is unsupported", dv.Kind()) + } + s := asString(src) + i64, err := strconv.ParseInt(s, 10, dv.Type().Bits()) + if err != nil { + err = strconvErr(err) + return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err) + } + dv.SetInt(i64) + return nil + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + if src == nil { + return fmt.Errorf("converting NULL to %s is unsupported", dv.Kind()) + } + s := asString(src) + u64, err := strconv.ParseUint(s, 10, dv.Type().Bits()) + if err != nil { + err = strconvErr(err) + return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err) + } + dv.SetUint(u64) + return nil + case reflect.Float32, reflect.Float64: + if src == nil { + return fmt.Errorf("converting NULL to %s is unsupported", dv.Kind()) + } + s := asString(src) + f64, err := strconv.ParseFloat(s, dv.Type().Bits()) + if err != nil { + err = strconvErr(err) + return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err) + } + dv.SetFloat(f64) + return nil + case reflect.String: + if src == nil { + return fmt.Errorf("converting NULL to %s is unsupported", dv.Kind()) + } + switch v := src.(type) { + case string: + dv.SetString(v) + return nil + case []byte: + dv.SetString(string(v)) + return nil + } + } + + return fmt.Errorf("unsupported Scan, storing driver.Value type %T into type %T", src, dest) +} + +func strconvErr(err error) error { + if ne, ok := err.(*strconv.NumError); ok { + return ne.Err + } + return err +} + +func cloneBytes(b []byte) []byte { + if b == nil { + return nil + } + c := make([]byte, len(b)) + copy(c, b) + return c +} + +func asString(src interface{}) string { + switch v := src.(type) { + case string: + return v + case []byte: + return string(v) + } + rv := reflect.ValueOf(src) + switch rv.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return strconv.FormatInt(rv.Int(), 10) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return strconv.FormatUint(rv.Uint(), 10) + case reflect.Float64: + return strconv.FormatFloat(rv.Float(), 'g', -1, 64) + case reflect.Float32: + return strconv.FormatFloat(rv.Float(), 'g', -1, 32) + case reflect.Bool: + return strconv.FormatBool(rv.Bool()) + } + return fmt.Sprintf("%v", src) +} + +func asBytes(buf []byte, rv reflect.Value) (b []byte, ok bool) { + switch rv.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return strconv.AppendInt(buf, rv.Int(), 10), true + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return strconv.AppendUint(buf, rv.Uint(), 10), true + case reflect.Float32: + return strconv.AppendFloat(buf, rv.Float(), 'g', -1, 32), true + case reflect.Float64: + return strconv.AppendFloat(buf, rv.Float(), 'g', -1, 64), true + case reflect.Bool: + return strconv.AppendBool(buf, rv.Bool()), true + case reflect.String: + s := rv.String() + return append(buf, s...), true + } + return +} + +var valuerReflectType = reflect.TypeOf((*driver.Valuer)(nil)).Elem() + +type decimalDecompose interface { + // Decompose returns the internal decimal state in parts. + // If the provided buf has sufficient capacity, buf may be returned as the coefficient with + // the value set and length set as appropriate. + Decompose(buf []byte) (form byte, negative bool, coefficient []byte, exponent int32) +} + +type decimalCompose interface { + // Compose sets the internal decimal value from parts. If the value cannot be + // represented then an error should be returned. + Compose(form byte, negative bool, coefficient []byte, exponent int32) error +} diff --git a/pkg/datasource/sql/util/ctxutil.go b/pkg/datasource/sql/util/ctxutil.go new file mode 100644 index 00000000..fa29440b --- /dev/null +++ b/pkg/datasource/sql/util/ctxutil.go @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package util + +import ( + "context" + "database/sql/driver" + "errors" +) + +func ctxDriverPrepare(ctx context.Context, ci driver.Conn, query string) (driver.Stmt, error) { + if ciCtx, is := ci.(driver.ConnPrepareContext); is { + return ciCtx.PrepareContext(ctx, query) + } + si, err := ci.Prepare(query) + if err == nil { + select { + default: + case <-ctx.Done(): + si.Close() + return nil, ctx.Err() + } + } + return si, err +} + +func ctxDriverExec(ctx context.Context, execerCtx driver.ExecerContext, execer driver.Execer, query string, nvdargs []driver.NamedValue) (driver.Result, error) { + if execerCtx != nil { + return execerCtx.ExecContext(ctx, query, nvdargs) + } + dargs, err := namedValueToValue(nvdargs) + if err != nil { + return nil, err + } + + select { + default: + case <-ctx.Done(): + return nil, ctx.Err() + } + return execer.Exec(query, dargs) +} + +func CtxDriverQuery(ctx context.Context, queryerCtx driver.QueryerContext, queryer driver.Queryer, query string, nvdargs []driver.NamedValue) (driver.Rows, error) { + if queryerCtx != nil { + return queryerCtx.QueryContext(ctx, query, nvdargs) + } + dargs, err := namedValueToValue(nvdargs) + if err != nil { + return nil, err + } + + select { + default: + case <-ctx.Done(): + return nil, ctx.Err() + } + return queryer.Query(query, dargs) +} + +func ctxDriverStmtExec(ctx context.Context, si driver.Stmt, nvdargs []driver.NamedValue) (driver.Result, error) { + if siCtx, is := si.(driver.StmtExecContext); is { + return siCtx.ExecContext(ctx, nvdargs) + } + dargs, err := namedValueToValue(nvdargs) + if err != nil { + return nil, err + } + + select { + default: + case <-ctx.Done(): + return nil, ctx.Err() + } + return si.Exec(dargs) +} + +func ctxDriverStmtQuery(ctx context.Context, si driver.Stmt, nvdargs []driver.NamedValue) (driver.Rows, error) { + if siCtx, is := si.(driver.StmtQueryContext); is { + return siCtx.QueryContext(ctx, nvdargs) + } + dargs, err := namedValueToValue(nvdargs) + if err != nil { + return nil, err + } + + select { + default: + case <-ctx.Done(): + return nil, ctx.Err() + } + return si.Query(dargs) +} + +func namedValueToValue(named []driver.NamedValue) ([]driver.Value, error) { + dargs := make([]driver.Value, len(named)) + for n, param := range named { + if len(param.Name) > 0 { + return nil, errors.New("sql: driver does not support the use of Named Parameters") + } + dargs[n] = param.Value + } + return dargs, nil +} diff --git a/pkg/datasource/sql/util/params.go b/pkg/datasource/sql/util/params.go new file mode 100644 index 00000000..160e9f70 --- /dev/null +++ b/pkg/datasource/sql/util/params.go @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package util + +import "database/sql/driver" + +func NamedValueToValue(nvs []driver.NamedValue) []driver.Value { + vs := make([]driver.Value, 0, len(nvs)) + for _, nv := range nvs { + vs = append(vs, nv.Value) + } + return vs +} + +func ValueToNamedValue(vs []driver.Value) []driver.NamedValue { + nvs := make([]driver.NamedValue, 0, len(vs)) + for i, v := range vs { + nvs = append(nvs, driver.NamedValue{ + Value: v, + Ordinal: i, + }) + } + return nvs +} diff --git a/pkg/datasource/sql/util/sql.go b/pkg/datasource/sql/util/sql.go new file mode 100644 index 00000000..7c6e609f --- /dev/null +++ b/pkg/datasource/sql/util/sql.go @@ -0,0 +1,355 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package sql provides a generic interface around SQL (or SQL-like) +// databases. +// +// The sql package must be used in conjunction with a database driver. +// See https://golang.org/s/sqldrivers for a list of drivers. +// +// Drivers that do not support context cancellation will not return until +// after the query is completed. +// +// For usage examples, see the wiki page at +// https://golang.org/s/sqlwiki. +package util + +import ( + "context" + "database/sql/driver" + "errors" + "fmt" + "io" + "sync" +) + +// ScanRows is the result of a query. Its cursor starts before the first row +// of the result set. Use Next to advance from row to row. +type ScanRows struct { + //dc *driverConn // owned; must call releaseConn when closed to release + dc driver.Conn + releaseConn func(error) + rowsi driver.Rows + cancel func() // called when ScanRows is closed, may be nil. + //closeStmt *driverStmt // if non-nil, statement to Close on close + + // closemu prevents ScanRows from closing while there + // is an active streaming result. It is held for read during non-close operations + // and exclusively during close. + // + // closemu guards lasterr and closed. + closemu sync.RWMutex + closed bool + lasterr error // non-nil only if closed is true + + // lastcols is only used in Scan, Next, and NextResultSet which are expected + // not to be called concurrently. + lastcols []driver.Value +} + +func NewScanRows(rowsi driver.Rows) *ScanRows { + return &ScanRows{rowsi: rowsi} +} + +// lasterrOrErrLocked returns either lasterr or the provided err. +// rs.closemu must be read-locked. +func (rs *ScanRows) lasterrOrErrLocked(err error) error { + if rs.lasterr != nil && rs.lasterr != io.EOF { + return rs.lasterr + } + return err +} + +// bypassRowsAwaitDone is only used for testing. +// If true, it will not close the ScanRows automatically from the context. +var bypassRowsAwaitDone = false + +func (rs *ScanRows) initContextClose(ctx, txctx context.Context) { + if ctx.Done() == nil && (txctx == nil || txctx.Done() == nil) { + return + } + if bypassRowsAwaitDone { + return + } + ctx, rs.cancel = context.WithCancel(ctx) + go rs.awaitDone(ctx, txctx) +} + +// awaitDone blocks until either ctx or txctx is canceled. The ctx is provided +// from the query context and is canceled when the query ScanRows is closed. +// If the query was issued in a transaction, the transaction's context +// is also provided in txctx to ensure ScanRows is closed if the Tx is closed. +func (rs *ScanRows) awaitDone(ctx, txctx context.Context) { + var txctxDone <-chan struct{} + if txctx != nil { + txctxDone = txctx.Done() + } + select { + case <-ctx.Done(): + case <-txctxDone: + } + rs.close(ctx.Err()) +} + +// Next prepares the next result row for reading with the Scan method. It +// returns true on success, or false if there is no next result row or an error +// happened while preparing it. Err should be consulted to distinguish between +// the two cases. +// +// Every call to Scan, even the first one, must be preceded by a call to Next. +func (rs *ScanRows) Next() bool { + var doClose, ok bool + withLock(rs.closemu.RLocker(), func() { + doClose, ok = rs.nextLocked() + }) + if doClose { + rs.Close() + } + return ok +} + +func (rs *ScanRows) nextLocked() (doClose, ok bool) { + if rs.closed { + return false, false + } + + // Lock the driver connection before calling the driver interface + // rowsi to prevent a Tx from rolling back the connection at the same time. + //rs.dc.Lock() + //defer rs.dc.Unlock() + + if rs.lastcols == nil { + rs.lastcols = make([]driver.Value, len(rs.rowsi.Columns())) + } + + rs.lasterr = rs.rowsi.Next(rs.lastcols) + if rs.lasterr != nil { + // Close the connection if there is a driver error. + if rs.lasterr != io.EOF { + return true, false + } + nextResultSet, ok := rs.rowsi.(driver.RowsNextResultSet) + if !ok { + return true, false + } + // The driver is at the end of the current result set. + // Test to see if there is another result set after the current one. + // Only close ScanRows if there is no further result sets to read. + if !nextResultSet.HasNextResultSet() { + doClose = true + } + return doClose, false + } + return false, true +} + +// NextResultSet prepares the next result set for reading. It reports whether +// there is further result sets, or false if there is no further result set +// or if there is an error advancing to it. The Err method should be consulted +// to distinguish between the two cases. +// +// After calling NextResultSet, the Next method should always be called before +// scanning. If there are further result sets they may not have rows in the result +// set. +func (rs *ScanRows) NextResultSet() bool { + var doClose bool + defer func() { + if doClose { + rs.Close() + } + }() + rs.closemu.RLock() + defer rs.closemu.RUnlock() + + if rs.closed { + return false + } + + rs.lastcols = nil + nextResultSet, ok := rs.rowsi.(driver.RowsNextResultSet) + if !ok { + doClose = true + return false + } + + // Lock the driver connection before calling the driver interface + // rowsi to prevent a Tx from rolling back the connection at the same time. + //rs.dc.Lock() + //defer rs.dc.Unlock() + + rs.lasterr = nextResultSet.NextResultSet() + if rs.lasterr != nil { + doClose = true + return false + } + return true +} + +// Err returns the error, if any, that was encountered during iteration. +// Err may be called after an explicit or implicit Close. +func (rs *ScanRows) Err() error { + rs.closemu.RLock() + defer rs.closemu.RUnlock() + return rs.lasterrOrErrLocked(nil) +} + +// +var errRowsClosed = errors.New("sql: ScanRows are closed") + +// Scan copies the columns in the current row into the values pointed +// at by dest. The number of values in dest must be the same as the +// number of columns in ScanRows. +// +// Scan converts columns read from the database into the following +// common Go types and special types provided by the sql package: +// +// *string +// *[]byte +// *int, *int8, *int16, *int32, *int64 +// *uint, *uint8, *uint16, *uint32, *uint64 +// *bool +// *float32, *float64 +// *interface{} +// *RawBytes +// *ScanRows (cursor value) +// any type implementing Scanner (see Scanner docs) +// +// In the most simple case, if the type of the value from the source +// column is an integer, bool or string type T and dest is of type *T, +// Scan simply assigns the value through the pointer. +// +// Scan also converts between string and numeric types, as long as no +// information would be lost. While Scan stringifies all numbers +// scanned from numeric database columns into *string, scans into +// numeric types are checked for overflow. For example, a float64 with +// value 300 or a string with value "300" can scan into a uint16, but +// not into a uint8, though float64(255) or "255" can scan into a +// uint8. One exception is that scans of some float64 numbers to +// strings may lose information when stringifying. In general, scan +// floating point columns into *float64. +// +// If a dest argument has type *[]byte, Scan saves in that argument a +// copy of the corresponding data. The copy is owned by the caller and +// can be modified and held indefinitely. The copy can be avoided by +// using an argument of type *RawBytes instead; see the documentation +// for RawBytes for restrictions on its use. +// +// If an argument has type *interface{}, Scan copies the value +// provided by the underlying driver without conversion. When scanning +// from a source value of type []byte to *interface{}, a copy of the +// slice is made and the caller owns the result. +// +// Source values of type time.Time may be scanned into values of type +// *time.Time, *interface{}, *string, or *[]byte. When converting to +// the latter two, time.RFC3339Nano is used. +// +// Source values of type bool may be scanned into types *bool, +// *interface{}, *string, *[]byte, or *RawBytes. +// +// For scanning into *bool, the source may be true, false, 1, 0, or +// string inputs parseable by strconv.ParseBool. +// +// Scan can also convert a cursor returned from a query, such as +// "select cursor(select * from my_table) from dual", into a +// *ScanRows value that can itself be scanned from. The parent +// select query will close any cursor *ScanRows if the parent *ScanRows is closed. +// +// If any of the first arguments implementing Scanner returns an error, +// that error will be wrapped in the returned error +func (rs *ScanRows) Scan(dest ...interface{}) error { + rs.closemu.RLock() + + if rs.lasterr != nil && rs.lasterr != io.EOF { + rs.closemu.RUnlock() + return rs.lasterr + } + if rs.closed { + err := rs.lasterrOrErrLocked(errRowsClosed) + rs.closemu.RUnlock() + return err + } + rs.closemu.RUnlock() + + if rs.lastcols == nil { + return errors.New("sql: Scan called without calling Next") + } + if len(dest) != len(rs.lastcols) { + return fmt.Errorf("sql: expected %d destination arguments in Scan, not %d", len(rs.lastcols), len(dest)) + } + for i, sv := range rs.lastcols { + err := convertAssignRows(dest[i], sv, rs) + if err != nil { + return fmt.Errorf(`sql: Scan error on column index %d, name %q: %w`, i, rs.rowsi.Columns()[i], err) + } + } + return nil +} + +// rowsCloseHook returns a function so tests may install the +// hook through a test only mutex. +var rowsCloseHook = func() func(*ScanRows, *error) { return nil } + +// Close closes the ScanRows, preventing further enumeration. If Next is called +// and returns false and there are no further result sets, +// the ScanRows are closed automatically and it will suffice to check the +// result of Err. Close is idempotent and does not affect the result of Err. +func (rs *ScanRows) Close() error { + //return rs.close(nil) + return nil +} + +func (rs *ScanRows) close(err error) error { + rs.closemu.Lock() + defer rs.closemu.Unlock() + + if rs.closed { + return nil + } + rs.closed = true + + if rs.lasterr == nil { + rs.lasterr = err + } + + err = rs.rowsi.Close() + //withLock(rs.dc, func() { + // err = rs.rowsi.Close() + //}) + if fn := rowsCloseHook(); fn != nil { + fn(rs, &err) + } + if rs.cancel != nil { + rs.cancel() + } + + //if rs.closeStmt != nil { + // rs.closeStmt.Close() + //} + rs.releaseConn(err) + return err +} + +// withLock runs while holding lk. +func withLock(lk sync.Locker, fn func()) { + lk.Lock() + defer lk.Unlock() // in case fn panics + fn() +} diff --git a/pkg/remoting/processor/client/rm_branch_rollback_processor.go b/pkg/remoting/processor/client/rm_branch_rollback_processor.go index 96f0faf8..0a33d402 100644 --- a/pkg/remoting/processor/client/rm_branch_rollback_processor.go +++ b/pkg/remoting/processor/client/rm_branch_rollback_processor.go @@ -85,6 +85,6 @@ func (f *rmBranchRollbackProcessor) Process(ctx context.Context, rpcMessage mess log.Errorf("send branch rollback response error: {%#v}", err.Error()) return err } - log.Infof("send branch rollback response success: xid %s, branchID %s, resourceID %s, applicationData %s", xid, branchID, resourceID, applicationData) + log.Infof("send branch rollback response success: xid %s, branchID %d, resourceID %s, applicationData %s", xid, branchID, resourceID, applicationData) return nil } diff --git a/pkg/util/reflectx/unexpoert_field.go b/pkg/util/reflectx/unexpoert_field.go index 5f844af7..8e7bc038 100644 --- a/pkg/util/reflectx/unexpoert_field.go +++ b/pkg/util/reflectx/unexpoert_field.go @@ -31,3 +31,16 @@ func SetUnexportedField(field reflect.Value, value interface{}) { func GetUnexportedField(field reflect.Value) interface{} { return reflect.NewAt(field.Type(), unsafe.Pointer(field.UnsafeAddr())).Elem().Interface() } + +func GetElemDataValue(data interface{}) interface{} { + if data == nil { + return data + } + value := reflect.ValueOf(data) + kind := reflect.TypeOf(data).Kind() + switch kind { + case reflect.Ptr: + return value.Elem().Interface() + } + return value.Interface() +} diff --git a/pkg/util/reflectx/unexpoert_field_test.go b/pkg/util/reflectx/unexpoert_field_test.go new file mode 100644 index 00000000..b355e477 --- /dev/null +++ b/pkg/util/reflectx/unexpoert_field_test.go @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package reflectx + +import ( + "testing" +) + +func TestGetElemDataValue(t *testing.T) { + var aa = 10 + var bb = "name" + var cc bool + + tests := []struct { + name string + args interface{} + want interface{} + }{ + {name: "test1", args: aa, want: aa}, + {name: "test2", args: &aa, want: aa}, + {name: "test3", args: bb, want: bb}, + {name: "test4", args: &bb, want: bb}, + {name: "test5", args: cc, want: cc}, + {name: "test6", args: &cc, want: cc}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := GetElemDataValue(tt.args); got != tt.want { + t.Errorf("GetElemDataValue() = %v, want %v", got, tt.want) + } + }) + } +}