Browse Source

optimize: refactor at executor (#397)

* refactor at executor
tags/v1.0.3
Yuecai Liu GitHub 2 years ago
parent
commit
a44063609e
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
20 changed files with 1715 additions and 232 deletions
  1. +1
    -1
      go.mod
  2. +1
    -1
      pkg/client/config.go
  3. +3
    -2
      pkg/datasource/sql/conn.go
  4. +0
    -8
      pkg/datasource/sql/conn_test.go
  5. +296
    -0
      pkg/datasource/sql/exec/at/base_executor.go
  6. +32
    -153
      pkg/datasource/sql/exec/at/executor_at.go
  7. +13
    -13
      pkg/datasource/sql/exec/at/plain_executor.go
  8. +272
    -0
      pkg/datasource/sql/exec/at/update_executor.go
  9. +101
    -0
      pkg/datasource/sql/exec/at/update_executor_test.go
  10. +20
    -45
      pkg/datasource/sql/exec/executor.go
  11. +12
    -2
      pkg/datasource/sql/exec/xa/executor_xa.go
  12. +5
    -4
      pkg/datasource/sql/stmt.go
  13. +3
    -2
      pkg/datasource/sql/types/image.go
  14. +377
    -0
      pkg/datasource/sql/util/convert.go
  15. +123
    -0
      pkg/datasource/sql/util/ctxutil.go
  16. +39
    -0
      pkg/datasource/sql/util/params.go
  17. +355
    -0
      pkg/datasource/sql/util/sql.go
  18. +1
    -1
      pkg/remoting/processor/client/rm_branch_rollback_processor.go
  19. +13
    -0
      pkg/util/reflectx/unexpoert_field.go
  20. +48
    -0
      pkg/util/reflectx/unexpoert_field_test.go

+ 1
- 1
go.mod View File

@@ -15,7 +15,6 @@ require (
github.com/golang/mock v1.6.0
github.com/google/uuid v1.3.0
github.com/knadh/koanf v1.4.3
github.com/mitchellh/copystructure v1.2.0
github.com/natefinch/lumberjack v2.0.0+incompatible
github.com/parnurzeal/gorequest v0.2.16
github.com/pierrec/lz4/v4 v4.1.17
@@ -88,6 +87,7 @@ require (
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect
github.com/magiconair/properties v1.8.6 // indirect
github.com/matttproud/golang_protobuf_extensions v1.0.1 // indirect
github.com/mitchellh/copystructure v1.2.0 // indirect
github.com/mitchellh/go-homedir v1.1.0 // indirect
github.com/mitchellh/mapstructure v1.5.0 // indirect
github.com/mitchellh/reflectwalk v1.0.2 // indirect


+ 1
- 1
pkg/client/config.go View File

@@ -32,8 +32,8 @@ import (
"github.com/knadh/koanf/parsers/yaml"
"github.com/knadh/koanf/providers/rawbytes"
"github.com/seata/seata-go/pkg/remoting/getty"
"github.com/seata/seata-go/pkg/tm"
"github.com/seata/seata-go/pkg/rm/tcc"
"github.com/seata/seata-go/pkg/tm"
"github.com/seata/seata-go/pkg/util/flagext"
)



+ 3
- 2
pkg/datasource/sql/conn.go View File

@@ -23,6 +23,7 @@ import (

"github.com/seata/seata-go/pkg/datasource/sql/exec"
"github.com/seata/seata-go/pkg/datasource/sql/types"
"github.com/seata/seata-go/pkg/datasource/sql/util"
)

// Conn is a connection to a database. It is not used concurrently
@@ -146,8 +147,8 @@ func (c *Conn) Query(query string, args []driver.Value) (driver.Rows, error) {
}

ret, err := executor.ExecWithValue(context.Background(), execCtx,
func(ctx context.Context, query string, args []driver.Value) (types.ExecResult, error) {
ret, err := conn.Query(query, args)
func(ctx context.Context, query string, args []driver.NamedValue) (types.ExecResult, error) {
ret, err := conn.Query(query, util.NamedValueToValue(args))
if err != nil {
return nil, err
}


+ 0
- 8
pkg/datasource/sql/conn_test.go View File

@@ -26,14 +26,6 @@ import (
"github.com/stretchr/testify/assert"
)

func TestConn_BuildATExecutor(t *testing.T) {
executor, err := exec.BuildExecutor(types.DBTypeMySQL, types.ATMode, "SELECT * FROM user")

assert.NoError(t, err)
_, ok := executor.(*exec.BaseExecutor)
assert.True(t, ok, "need base executor")
}

func TestConn_BuildXAExecutor(t *testing.T) {
executor, err := exec.BuildExecutor(types.DBTypeMySQL, types.XAMode, "SELECT * FROM user")



+ 296
- 0
pkg/datasource/sql/exec/at/base_executor.go View File

@@ -0,0 +1,296 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package at

import (
"bytes"
"context"
"database/sql"
"database/sql/driver"
"fmt"
"strings"

"github.com/arana-db/parser/ast"
"github.com/arana-db/parser/test_driver"
gxsort "github.com/dubbogo/gost/sort"
"github.com/seata/seata-go/pkg/datasource/sql/exec"
"github.com/seata/seata-go/pkg/datasource/sql/types"
"github.com/seata/seata-go/pkg/datasource/sql/util"
"github.com/seata/seata-go/pkg/util/reflectx"
)

type baseExecutor struct {
hooks []exec.SQLHook
}

func (b *baseExecutor) beforeHooks(ctx context.Context, execCtx *types.ExecContext) {
for _, hook := range b.hooks {
hook.Before(ctx, execCtx)
}
}

func (b *baseExecutor) afterHooks(ctx context.Context, execCtx *types.ExecContext) {
for _, hook := range b.hooks {
hook.After(ctx, execCtx)
}
}

// GetScanSlice get the column type for scann
// todo to use ColumnInfo get slice
func (*baseExecutor) GetScanSlice(columnNames []string, tableMeta *types.TableMeta) []interface{} {
scanSlice := make([]interface{}, 0, len(columnNames))
for _, columnNmae := range columnNames {
var (
// 从metData获取该列的元信息
columnMeta = tableMeta.Columns[columnNmae]
)
switch strings.ToUpper(columnMeta.DatabaseTypeString) {
case "VARCHAR", "NVARCHAR", "VARCHAR2", "CHAR", "TEXT", "JSON", "TINYTEXT":
var scanVal string
scanSlice = append(scanSlice, &scanVal)
case "BIT", "INT", "LONGBLOB", "SMALLINT", "TINYINT", "BIGINT", "MEDIUMINT":
if columnMeta.IsNullable == 0 {
scanVal := int64(0)
scanSlice = append(scanSlice, &scanVal)
} else {
scanVal := sql.NullInt64{}
scanSlice = append(scanSlice, &scanVal)
}
case "DATE", "DATETIME", "TIME", "TIMESTAMP", "YEAR":
scanVal := sql.NullTime{}
scanSlice = append(scanSlice, &scanVal)
case "DECIMAL", "DOUBLE", "FLOAT":
if columnMeta.IsNullable == 0 {
scanVal := float64(0)
scanSlice = append(scanSlice, &scanVal)
} else {
scanVal := sql.NullFloat64{}
scanSlice = append(scanSlice, &scanVal)
}
default:
scanVal := sql.RawBytes{}
scanSlice = append(scanSlice, &scanVal)
}
}
return scanSlice
}

func (b *baseExecutor) buildSelectArgs(stmt *ast.SelectStmt, args []driver.NamedValue) []driver.NamedValue {
var (
selectArgsIndexs = make([]int32, 0)
selectArgs = make([]driver.NamedValue, 0)
)

b.traversalArgs(stmt.Where, &selectArgsIndexs)
if stmt.OrderBy != nil {
for _, item := range stmt.OrderBy.Items {
b.traversalArgs(item, &selectArgsIndexs)
}
}
if stmt.Limit != nil {
if stmt.Limit.Offset != nil {
b.traversalArgs(stmt.Limit.Offset, &selectArgsIndexs)
}
if stmt.Limit.Count != nil {
b.traversalArgs(stmt.Limit.Count, &selectArgsIndexs)
}
}
// sort selectArgs index array
gxsort.Int32(selectArgsIndexs)
for _, index := range selectArgsIndexs {
selectArgs = append(selectArgs, args[index])
}

return selectArgs
}

// todo perfect all sql operation
func (b *baseExecutor) traversalArgs(node ast.Node, argsIndex *[]int32) {
if node == nil {
return
}
switch node.(type) {
case *ast.BinaryOperationExpr:
expr := node.(*ast.BinaryOperationExpr)
b.traversalArgs(expr.L, argsIndex)
b.traversalArgs(expr.R, argsIndex)
break
case *ast.BetweenExpr:
expr := node.(*ast.BetweenExpr)
b.traversalArgs(expr.Left, argsIndex)
b.traversalArgs(expr.Right, argsIndex)
break
case *ast.PatternInExpr:
exprs := node.(*ast.PatternInExpr).List
for i := 0; i < len(exprs); i++ {
b.traversalArgs(exprs[i], argsIndex)
}
break
case *test_driver.ParamMarkerExpr:
*argsIndex = append(*argsIndex, int32(node.(*test_driver.ParamMarkerExpr).Order))
break
}
}

func (b *baseExecutor) buildRecordImages(rowsi driver.Rows, tableMetaData *types.TableMeta) (*types.RecordImage, error) {
// select column names
columnNames := rowsi.Columns()
rowImages := make([]types.RowImage, 0)

sqlRows := util.NewScanRows(rowsi)

for sqlRows.Next() {
ss := b.GetScanSlice(columnNames, tableMetaData)

err := sqlRows.Scan(ss...)
if err != nil {
return nil, err
}

columns := make([]types.ColumnImage, 0)
// build record image
for i, name := range columnNames {
columnMeta := tableMetaData.Columns[name]

keyType := types.IndexTypeNull
if _, ok := tableMetaData.GetPrimaryKeyMap()[name]; ok {
keyType = types.IndexTypePrimaryKey
}
jdbcType := types.MySQLStrToJavaType(columnMeta.DatabaseTypeString)

columns = append(columns, types.ColumnImage{
KeyType: keyType,
ColumnName: name,
ColumnType: jdbcType,
Value: reflectx.GetElemDataValue(ss[i]),
})
}
rowImages = append(rowImages, types.RowImage{Columns: columns})
}

return &types.RecordImage{TableName: tableMetaData.TableName, Rows: rowImages}, nil
}

// buildWhereConditionByPKs build where condition by primary keys
// each pk is a condition.the result will like :" (id,userCode) in ((?,?),(?,?)) or (id,userCode) in ((?,?),(?,?) ) or (id,userCode) in ((?,?))"
func (b *baseExecutor) buildWhereConditionByPKs(pkNameList []string, rowSize int, dbType string, maxInSize int) string {
var (
whereStr = &strings.Builder{}
batchSize = rowSize/maxInSize + 1
)

if rowSize%maxInSize == 0 {
batchSize = rowSize / maxInSize
}

for batch := 0; batch < batchSize; batch++ {
if batch > 0 {
whereStr.WriteString(" OR ")
}
whereStr.WriteString("(")

for i := 0; i < len(pkNameList); i++ {
if i > 0 {
whereStr.WriteString(",")
}
// todo add escape
whereStr.WriteString(fmt.Sprintf("`%s`", pkNameList[i]))
}
whereStr.WriteString(") IN (")

var eachSize int

if batch == batchSize-1 {
if rowSize%maxInSize == 0 {
eachSize = maxInSize
} else {
eachSize = rowSize % maxInSize
}
} else {
eachSize = maxInSize
}

for i := 0; i < eachSize; i++ {
if i > 0 {
whereStr.WriteString(",")
}
whereStr.WriteString("(")
for j := 0; j < len(pkNameList); j++ {
if j > 0 {
whereStr.WriteString(",")
}
whereStr.WriteString("?")
}
whereStr.WriteString(")")
}
whereStr.WriteString(")")
}
return whereStr.String()
}

func (b *baseExecutor) buildPKParams(rows []types.RowImage, pkNameList []string) []driver.NamedValue {
params := make([]driver.NamedValue, 0)
for _, row := range rows {
coumnMap := row.GetColumnMap()
for i, pk := range pkNameList {
if col, ok := coumnMap[pk]; ok {
params = append(params, driver.NamedValue{
Ordinal: i, Value: col.Value,
})
}
}
}
return params
}

// the string as local key. the local key example(multi pk): "t_user:1_a,2_b"
func (b *baseExecutor) buildLockKey(records *types.RecordImage, meta types.TableMeta) string {
var (
lockKeys bytes.Buffer
filedSequence int
)
lockKeys.WriteString(meta.TableName)
lockKeys.WriteString(":")

keys := meta.GetPrimaryKeyOnlyName()

for _, row := range records.Rows {
if filedSequence > 0 {
lockKeys.WriteString(",")
}
pkSplitIndex := 0
for _, column := range row.Columns {
var hasKeyColumn bool
for _, key := range keys {
if column.ColumnName == key {
hasKeyColumn = true
if pkSplitIndex > 0 {
lockKeys.WriteString("_")
}
lockKeys.WriteString(fmt.Sprintf("%v", column.Value))
pkSplitIndex++
}
}
if hasKeyColumn {
filedSequence++
}
}
}

return lockKeys.String()
}

+ 32
- 153
pkg/datasource/sql/exec/at/executor_at.go View File

@@ -19,182 +19,61 @@ package at

import (
"context"
"fmt"

"github.com/mitchellh/copystructure"
"github.com/pkg/errors"
"github.com/seata/seata-go/pkg/datasource/sql/parser"
"github.com/seata/seata-go/pkg/datasource/sql/undo"
"github.com/seata/seata-go/pkg/datasource/sql/util"
"github.com/seata/seata-go/pkg/tm"

"github.com/seata/seata-go/pkg/datasource/sql/exec"
"github.com/seata/seata-go/pkg/datasource/sql/parser"
"github.com/seata/seata-go/pkg/datasource/sql/types"
)

type ATExecutor struct {
hooks []exec.SQLHook
ex exec.SQLExecutor
func init() {
exec.RegisterATExecutor(types.DBTypeMySQL, func() exec.SQLExecutor { return &AtExecutor{} })
}

func (e *ATExecutor) Interceptors(hooks []exec.SQLHook) {
e.hooks = hooks
type executor interface {
ExecContext(ctx context.Context, f exec.CallbackWithNamedValue) (types.ExecResult, error)
}

func (e *ATExecutor) ExecWithNamedValue(ctx context.Context, execCtx *types.ExecContext, f exec.CallbackWithNamedValue) (types.ExecResult, error) {
for _, hook := range e.hooks {
hook.Before(ctx, execCtx)
}

var (
beforeImages []*types.RecordImage
afterImages []*types.RecordImage
result types.ExecResult
err error
)

beforeImages, err = e.beforeImage(ctx, execCtx)
if err != nil {
return nil, err
}
if beforeImages != nil {
beforeImagesTmp, err := copystructure.Copy(beforeImages)
if err != nil {
return nil, err
}
newBeforeImages, ok := beforeImagesTmp.([]*types.RecordImage)
if !ok {
return nil, errors.New("copy beforeImages failed")
}
execCtx.TxCtx.RoundImages.AppendBeofreImages(newBeforeImages)
}

defer func() {
for _, hook := range e.hooks {
hook.After(ctx, execCtx)
}
}()

if e.ex != nil {
result, err = e.ex.ExecWithNamedValue(ctx, execCtx, f)
} else {
result, err = f(ctx, execCtx.Query, execCtx.NamedValues)
}

if err != nil {
return nil, err
}

afterImages, err = e.afterImage(ctx, execCtx, beforeImages)
if err != nil {
return nil, err
}
if afterImages != nil {
execCtx.TxCtx.RoundImages.AppendAfterImages(afterImages)
}

return result, err
type AtExecutor struct {
hooks []exec.SQLHook
}

func (e *ATExecutor) prepareUndoLog(ctx context.Context, execCtx *types.ExecContext) error {
if execCtx.TxCtx.RoundImages.IsEmpty() {
return nil
}

if execCtx.ParseContext.UpdateStmt != nil {
if !execCtx.TxCtx.RoundImages.IsBeforeAfterSizeEq() {
return fmt.Errorf("Before image size is not equaled to after image size, probably because you updated the primary keys.")
}
}
undoLogManager, err := undo.GetUndoLogManager(execCtx.DBType)
if err != nil {
return err
}
return undoLogManager.FlushUndoLog(execCtx.TxCtx, execCtx.Conn)
func (e *AtExecutor) Interceptors(hooks []exec.SQLHook) {
e.hooks = hooks
}

func (e *ATExecutor) ExecWithValue(ctx context.Context, execCtx *types.ExecContext, f exec.CallbackWithValue) (types.ExecResult, error) {
for _, hook := range e.hooks {
hook.Before(ctx, execCtx)
}

var (
beforeImages []*types.RecordImage
afterImages []*types.RecordImage
result types.ExecResult
err error
)

beforeImages, err = e.beforeImage(ctx, execCtx)
if err != nil {
return nil, err
}
if beforeImages != nil {
execCtx.TxCtx.RoundImages.AppendBeofreImages(beforeImages)
}

defer func() {
for _, hook := range e.hooks {
hook.After(ctx, execCtx)
}
}()

if e.ex != nil {
result, err = e.ex.ExecWithValue(ctx, execCtx, f)
} else {
result, err = f(ctx, execCtx.Query, execCtx.Values)
}
if err != nil {
return nil, err
}

afterImages, err = e.afterImage(ctx, execCtx, beforeImages)
//ExecWithNamedValue
func (e *AtExecutor) ExecWithNamedValue(ctx context.Context, execCtx *types.ExecContext, f exec.CallbackWithNamedValue) (types.ExecResult, error) {
parser, err := parser.DoParser(execCtx.Query)
if err != nil {
return nil, err
}
if afterImages != nil {
execCtx.TxCtx.RoundImages.AppendAfterImages(afterImages)
}

return result, err
}
var exec executor

func (e *ATExecutor) beforeImage(ctx context.Context, execCtx *types.ExecContext) ([]*types.RecordImage, error) {
if !tm.IsGlobalTx(ctx) {
return nil, nil
}

pc, err := parser.DoParser(execCtx.Query)
if err != nil {
return nil, err
}
if !pc.HasValidStmt() {
return nil, nil
exec = NewPlainExecutor(parser, execCtx)
} else {
switch parser.SQLType {
//case types.SQLTypeInsert:
case types.SQLTypeUpdate:
exec = NewUpdateExecutor(parser, execCtx, e.hooks)
//case types.SQLTypeDelete:
//case types.SQLTypeSelectForUpdate:
//case types.SQLTypeMultiDelete:
//case types.SQLTypeMultiUpdate:
default:
exec = NewPlainExecutor(parser, execCtx)
}
}
execCtx.ParseContext = pc

builder := undo.GetUndologBuilder(pc.ExecutorType)
if builder == nil {
return nil, nil
}
return builder.BeforeImage(ctx, execCtx)
return exec.ExecContext(ctx, f)
}

// After
func (e *ATExecutor) afterImage(ctx context.Context, execCtx *types.ExecContext, beforeImages []*types.RecordImage) ([]*types.RecordImage, error) {
if !tm.IsGlobalTx(ctx) {
return nil, nil
}
pc, err := parser.DoParser(execCtx.Query)
if err != nil {
return nil, err
}
if !pc.HasValidStmt() {
return nil, nil
}
execCtx.ParseContext = pc
builder := undo.GetUndologBuilder(pc.ExecutorType)
if builder == nil {
return nil, nil
}
return builder.AfterImage(ctx, execCtx, beforeImages)
//ExecWithValue
func (e *AtExecutor) ExecWithValue(ctx context.Context, execCtx *types.ExecContext, f exec.CallbackWithNamedValue) (types.ExecResult, error) {
execCtx.NamedValues = util.ValueToNamedValue(execCtx.Values)
return e.ExecWithNamedValue(ctx, execCtx, f)
}

pkg/datasource/sql/exec/at/default.go → pkg/datasource/sql/exec/at/plain_executor.go View File

@@ -18,21 +18,21 @@
package at

import (
"context"

"github.com/seata/seata-go/pkg/datasource/sql/exec"
"github.com/seata/seata-go/pkg/datasource/sql/types"
)

func init() {
exec.RegisterATExecutor(types.DBTypeMySQL, types.UpdateExecutor, func() exec.SQLExecutor {
return &ATExecutor{}
})
exec.RegisterATExecutor(types.DBTypeMySQL, types.SelectExecutor, func() exec.SQLExecutor {
return &ATExecutor{}
})
exec.RegisterATExecutor(types.DBTypeMySQL, types.InsertExecutor, func() exec.SQLExecutor {
return &ATExecutor{}
})
exec.RegisterATExecutor(types.DBTypeMySQL, types.DeleteExecutor, func() exec.SQLExecutor {
return &ATExecutor{}
})
type PlainExecutor struct {
parserCtx *types.ParseContext
execCtx *types.ExecContext
}

func NewPlainExecutor(parserCtx *types.ParseContext, execCtx *types.ExecContext) *PlainExecutor {
return &PlainExecutor{parserCtx: parserCtx, execCtx: execCtx}
}

func (u *PlainExecutor) ExecContext(ctx context.Context, f exec.CallbackWithNamedValue) (types.ExecResult, error) {
return f(ctx, u.execCtx.Query, u.execCtx.NamedValues)
}

+ 272
- 0
pkg/datasource/sql/exec/at/update_executor.go View File

@@ -0,0 +1,272 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package at

import (
"context"
"database/sql/driver"
"fmt"
"strings"

"github.com/arana-db/parser/ast"
"github.com/arana-db/parser/format"
"github.com/arana-db/parser/model"
"github.com/seata/seata-go/pkg/datasource/sql/datasource"
"github.com/seata/seata-go/pkg/datasource/sql/exec"
"github.com/seata/seata-go/pkg/datasource/sql/types"
"github.com/seata/seata-go/pkg/datasource/sql/util"
"github.com/seata/seata-go/pkg/util/bytes"
"github.com/seata/seata-go/pkg/util/log"
)

var (
// todo: OnlyCareUpdateColumns should load from config first
onlyCareUpdateColumns = true
maxInSize = 1000
)

type updateExecutor struct {
baseExecutor
parserCtx *types.ParseContext
execContent *types.ExecContext
}

//NewUpdateExecutor get update executor
func NewUpdateExecutor(parserCtx *types.ParseContext, execContent *types.ExecContext, hooks []exec.SQLHook) executor {
return &updateExecutor{parserCtx: parserCtx, execContent: execContent, baseExecutor: baseExecutor{hooks: hooks}}
}

//ExecContext exec SQL, and generate before image and after image
func (u *updateExecutor) ExecContext(ctx context.Context, f exec.CallbackWithNamedValue) (types.ExecResult, error) {
u.beforeHooks(ctx, u.execContent)
defer func() {
u.afterHooks(ctx, u.execContent)
}()

beforeImage, err := u.beforeImage(ctx)
if err != nil {
return nil, err
}

res, err := f(ctx, u.execContent.Query, u.execContent.NamedValues)
if err != nil {
return nil, err
}

afterImage, err := u.afterImage(ctx, *beforeImage)
if err != nil {
return nil, err
}

if len(beforeImage.Rows) != len(afterImage.Rows) {
return nil, fmt.Errorf("Before image size is not equaled to after image size, probably because you updated the primary keys.")
}

u.execContent.TxCtx.RoundImages.AppendBeofreImage(beforeImage)
u.execContent.TxCtx.RoundImages.AppendAfterImage(afterImage)

return res, nil
}

//beforeImage build before image
func (u *updateExecutor) beforeImage(ctx context.Context) (*types.RecordImage, error) {
if !u.isAstStmtValid() {
return nil, nil
}

selectSQL, selectArgs, err := u.buildBeforeImageSQL(ctx, u.execContent.NamedValues)
if err != nil {
return nil, err
}

tableName, _ := u.parserCtx.GteTableName()
metaData, err := datasource.GetTableCache(types.DBTypeMySQL).GetTableMeta(ctx, u.execContent.DBName, tableName)
if err != nil {
return nil, err
}

var rowsi driver.Rows
queryerCtx, ok := u.execContent.Conn.(driver.QueryerContext)
var queryer driver.Queryer
if !ok {
queryer, ok = u.execContent.Conn.(driver.Queryer)
}
if ok {
rowsi, err = util.CtxDriverQuery(ctx, queryerCtx, queryer, selectSQL, selectArgs)
if err != nil {
log.Errorf("ctx driver query: %+v", err)
return nil, err
}
} else {
log.Errorf("target conn should been driver.QueryerContext or driver.Queryer")
return nil, fmt.Errorf("invalid conn")
}

image, err := u.buildRecordImages(rowsi, metaData)
if err != nil {
return nil, err
}

lockKey := u.buildLockKey(image, *metaData)
u.execContent.TxCtx.LockKeys[lockKey] = struct{}{}
image.SQLType = u.parserCtx.SQLType

return image, nil
}

//afterImage build after image
func (u *updateExecutor) afterImage(ctx context.Context, beforeImage types.RecordImage) (*types.RecordImage, error) {
if !u.isAstStmtValid() {
return nil, nil
}
if len(beforeImage.Rows) == 0 {
return &types.RecordImage{}, nil
}

tableName, _ := u.parserCtx.GteTableName()
metaData, err := datasource.GetTableCache(types.DBTypeMySQL).GetTableMeta(ctx, u.execContent.DBName, tableName)
if err != nil {
return nil, err
}
selectSQL, selectArgs := u.buildAfterImageSQL(beforeImage, metaData)

var rowsi driver.Rows
queryerCtx, ok := u.execContent.Conn.(driver.QueryerContext)
var queryer driver.Queryer
if !ok {
queryer, ok = u.execContent.Conn.(driver.Queryer)
}
if ok {
rowsi, err = util.CtxDriverQuery(ctx, queryerCtx, queryer, selectSQL, selectArgs)
if err != nil {
log.Errorf("ctx driver query: %+v", err)
return nil, err
}
} else {
log.Errorf("target conn should been driver.QueryerContext or driver.Queryer")
return nil, fmt.Errorf("invalid conn")
}

afterImage, err := u.buildRecordImages(rowsi, metaData)
if err != nil {
return nil, err
}
afterImage.SQLType = u.parserCtx.SQLType

return afterImage, nil
}

func (u *updateExecutor) isAstStmtValid() bool {
return u.parserCtx != nil && u.parserCtx.UpdateStmt != nil
}

//buildAfterImageSQL build the SQL to query after image data
func (u *updateExecutor) buildAfterImageSQL(beforeImage types.RecordImage, meta *types.TableMeta) (string, []driver.NamedValue) {
if len(beforeImage.Rows) == 0 {
return "", nil
}
sb := strings.Builder{}
// todo: OnlyCareUpdateColumns should load from config first
var selectFields string
var separator = ","
if onlyCareUpdateColumns {
for _, row := range beforeImage.Rows {
for _, column := range row.Columns {
selectFields += column.ColumnName + separator
}
}
selectFields = strings.TrimSuffix(selectFields, separator)
} else {
selectFields = "*"
}
sb.WriteString("SELECT " + selectFields + " FROM " + meta.TableName + " WHERE ")
whereSQL := u.buildWhereConditionByPKs(meta.GetPrimaryKeyOnlyName(), len(beforeImage.Rows), "mysql", maxInSize)
sb.WriteString(" " + whereSQL + " ")
return sb.String(), u.buildPKParams(beforeImage.Rows, meta.GetPrimaryKeyOnlyName())
}

//buildAfterImageSQL build the SQL to query before image data
func (u *updateExecutor) buildBeforeImageSQL(ctx context.Context, args []driver.NamedValue) (string, []driver.NamedValue, error) {
if !u.isAstStmtValid() {
log.Errorf("invalid update stmt")
return "", nil, fmt.Errorf("invalid update stmt")
}

updateStmt := u.parserCtx.UpdateStmt
fields := make([]*ast.SelectField, 0, len(updateStmt.List))

if onlyCareUpdateColumns {
for _, column := range updateStmt.List {
fields = append(fields, &ast.SelectField{
Expr: &ast.ColumnNameExpr{
Name: column.Column,
},
})
}

// select indexes columns
tableName, _ := u.parserCtx.GteTableName()
metaData, err := datasource.GetTableCache(types.DBTypeMySQL).GetTableMeta(ctx, u.execContent.DBName, tableName)
if err != nil {
return "", nil, err
}
for _, columnName := range metaData.GetPrimaryKeyOnlyName() {
fields = append(fields, &ast.SelectField{
Expr: &ast.ColumnNameExpr{
Name: &ast.ColumnName{
Name: model.CIStr{
O: columnName,
L: columnName,
},
},
},
})
}
} else {
fields = append(fields, &ast.SelectField{
Expr: &ast.ColumnNameExpr{
Name: &ast.ColumnName{
Name: model.CIStr{
O: "*",
L: "*",
},
},
},
})
}

selStmt := ast.SelectStmt{
SelectStmtOpts: &ast.SelectStmtOpts{},
From: updateStmt.TableRefs,
Where: updateStmt.Where,
Fields: &ast.FieldList{Fields: fields},
OrderBy: updateStmt.Order,
Limit: updateStmt.Limit,
TableHints: updateStmt.TableHints,
LockInfo: &ast.SelectLockInfo{
LockType: ast.SelectLockForUpdate,
},
}

b := bytes.NewByteBuffer([]byte{})
_ = selStmt.Restore(format.NewRestoreCtx(format.RestoreKeyWordUppercase, b))
sql := string(b.Bytes())
log.Infof("build select sql by update sourceQuery, sql {%s}", sql)

return sql, u.buildSelectArgs(&selStmt, args), nil
}

+ 101
- 0
pkg/datasource/sql/exec/at/update_executor_test.go View File

@@ -0,0 +1,101 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package at

import (
"context"
"database/sql/driver"
"reflect"
"testing"

"github.com/agiledragon/gomonkey"
"github.com/seata/seata-go/pkg/datasource/sql/datasource"
"github.com/seata/seata-go/pkg/datasource/sql/datasource/mysql"
"github.com/seata/seata-go/pkg/datasource/sql/exec"
"github.com/seata/seata-go/pkg/datasource/sql/util"

"github.com/seata/seata-go/pkg/datasource/sql/types"

"github.com/seata/seata-go/pkg/datasource/sql/parser"

_ "github.com/arana-db/parser/test_driver"
_ "github.com/seata/seata-go/pkg/util/log"
"github.com/stretchr/testify/assert"
)

func TestBuildSelectSQLByUpdate(t *testing.T) {
datasource.RegisterTableCache(types.DBTypeMySQL, mysql.NewTableMetaInstance(nil))
stub := gomonkey.ApplyMethod(reflect.TypeOf(datasource.GetTableCache(types.DBTypeMySQL)), "GetTableMeta",
func(_ *mysql.TableMetaCache, ctx context.Context, dbName, tableName string) (*types.TableMeta, error) {
return &types.TableMeta{
Indexs: map[string]types.IndexMeta{
"id": {
IType: types.IndexTypePrimaryKey,
Columns: []types.ColumnMeta{
{ColumnName: "id"},
},
},
},
}, nil
})
defer stub.Reset()

tests := []struct {
name string
sourceQuery string
sourceQueryArgs []driver.Value
expectQuery string
expectQueryArgs []driver.Value
}{
{
sourceQuery: "update t_user set name = ?, age = ? where id = ?",
sourceQueryArgs: []driver.Value{"Jack", 1, 100},
expectQuery: "SELECT SQL_NO_CACHE name,age,id FROM t_user WHERE id=? FOR UPDATE",
expectQueryArgs: []driver.Value{100},
},
{
sourceQuery: "update t_user set name = ?, age = ? where id = ? and name = 'Jack' and age between ? and ?",
sourceQueryArgs: []driver.Value{"Jack", 1, 100, 18, 28},
expectQuery: "SELECT SQL_NO_CACHE name,age,id FROM t_user WHERE id=? AND name=_UTF8MB4Jack AND age BETWEEN ? AND ? FOR UPDATE",
expectQueryArgs: []driver.Value{100, 18, 28},
},
{
sourceQuery: "update t_user set name = ?, age = ? where id = ? and name = 'Jack' and age in (?,?)",
sourceQueryArgs: []driver.Value{"Jack", 1, 100, 18, 28},
expectQuery: "SELECT SQL_NO_CACHE name,age,id FROM t_user WHERE id=? AND name=_UTF8MB4Jack AND age IN (?,?) FOR UPDATE",
expectQueryArgs: []driver.Value{100, 18, 28},
},
{
sourceQuery: "update t_user set name = ?, age = ? where kk between ? and ? and id = ? and addr in(?,?) and age > ? order by name desc limit ?",
sourceQueryArgs: []driver.Value{"Jack", 1, 10, 20, 17, "Beijing", "Guangzhou", 18, 2},
expectQuery: "SELECT SQL_NO_CACHE name,age,id FROM t_user WHERE kk BETWEEN ? AND ? AND id=? AND addr IN (?,?) AND age>? ORDER BY name DESC LIMIT ? FOR UPDATE",
expectQueryArgs: []driver.Value{10, 20, 17, "Beijing", "Guangzhou", 18, 2},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c, err := parser.DoParser(tt.sourceQuery)
assert.Nil(t, err)
executor := NewUpdateExecutor(c, &types.ExecContext{Values: tt.sourceQueryArgs, NamedValues: util.ValueToNamedValue(tt.sourceQueryArgs)}, []exec.SQLHook{})
query, args, err := executor.(*updateExecutor).buildBeforeImageSQL(context.Background(), util.ValueToNamedValue(tt.sourceQueryArgs))
assert.Nil(t, err)
assert.Equal(t, tt.expectQuery, query)
assert.Equal(t, tt.expectQueryArgs, util.NamedValueToValue(args))
})
}
}

+ 20
- 45
pkg/datasource/sql/exec/executor.go View File

@@ -25,35 +25,25 @@ import (
"github.com/seata/seata-go/pkg/datasource/sql/types"
"github.com/seata/seata-go/pkg/datasource/sql/undo"
"github.com/seata/seata-go/pkg/datasource/sql/undo/builder"
"github.com/seata/seata-go/pkg/util/log"
)

func init() {
undo.RegisterUndoLogBuilder(types.UpdateExecutor, builder.GetMySQLUpdateUndoLogBuilder)
undo.RegisterUndoLogBuilder(types.MultiExecutor, builder.GetMySQLMultiUndoLogBuilder)
}

var (
executorSoltsAT = make(map[types.DBType]map[types.ExecutorType]func() SQLExecutor)
executorSoltsXA = make(map[types.DBType]func() SQLExecutor)
atExecutors = make(map[types.DBType]func() SQLExecutor)
xaExecutors = make(map[types.DBType]func() SQLExecutor)
)

// RegisterATExecutor AT executor
func RegisterATExecutor(dt types.DBType, et types.ExecutorType, builder func() SQLExecutor) {
if _, ok := executorSoltsAT[dt]; !ok {
executorSoltsAT[dt] = make(map[types.ExecutorType]func() SQLExecutor)
}

val := executorSoltsAT[dt]

val[et] = func() SQLExecutor {
return &BaseExecutor{ex: builder()}
}
func RegisterATExecutor(dt types.DBType, builder func() SQLExecutor) {
atExecutors[dt] = builder
}

// RegisterXAExecutor XA executor
func RegisterXAExecutor(dt types.DBType, builder func() SQLExecutor) {
executorSoltsXA[dt] = func() SQLExecutor {
xaExecutors[dt] = func() SQLExecutor {
return builder()
}
}
@@ -66,7 +56,7 @@ type (
SQLExecutor interface {
Interceptors(interceptors []SQLHook)
ExecWithNamedValue(ctx context.Context, execCtx *types.ExecContext, f CallbackWithNamedValue) (types.ExecResult, error)
ExecWithValue(ctx context.Context, execCtx *types.ExecContext, f CallbackWithValue) (types.ExecResult, error)
ExecWithValue(ctx context.Context, execCtx *types.ExecContext, f CallbackWithNamedValue) (types.ExecResult, error)
}
)

@@ -83,37 +73,14 @@ func BuildExecutor(dbType types.DBType, transactionMode types.TransactionMode, q
hooks = append(hooks, hookSolts[parseContext.SQLType]...)

if transactionMode == types.XAMode {
e := executorSoltsXA[dbType]()
e.Interceptors(hooks)
return e, nil
}

if transactionMode == types.ATMode {
e := executorSoltsAT[dbType][parseContext.ExecutorType]()
e.Interceptors(hooks)
return e, nil
}

factories, ok := executorSoltsAT[dbType]
if !ok {
log.Debugf("%s not found executor factories, return default Executor", dbType.String())
e := &BaseExecutor{}
e := xaExecutors[dbType]()
e.Interceptors(hooks)
return e, nil
}

supplier, ok := factories[parseContext.ExecutorType]
if !ok {
log.Debugf("%s not found executor for %s, return default Executor",
dbType.String(), parseContext.ExecutorType)
e := &BaseExecutor{}
e.Interceptors(hooks)
return e, nil
}

executor := supplier()
executor.Interceptors(hooks)
return executor, nil
e := atExecutors[dbType]()
e.Interceptors(hooks)
return e, nil
}

type BaseExecutor struct {
@@ -146,7 +113,7 @@ func (e *BaseExecutor) ExecWithNamedValue(ctx context.Context, execCtx *types.Ex
}

// ExecWithValue
func (e *BaseExecutor) ExecWithValue(ctx context.Context, execCtx *types.ExecContext, f CallbackWithValue) (types.ExecResult, error) {
func (e *BaseExecutor) ExecWithValue(ctx context.Context, execCtx *types.ExecContext, f CallbackWithNamedValue) (types.ExecResult, error) {
for i := range e.hooks {
e.hooks[i].Before(ctx, execCtx)
}
@@ -161,5 +128,13 @@ func (e *BaseExecutor) ExecWithValue(ctx context.Context, execCtx *types.ExecCon
return e.ex.ExecWithValue(ctx, execCtx, f)
}

return f(ctx, execCtx.Query, execCtx.Values)
nvargs := make([]driver.NamedValue, len(execCtx.Values))
for i, value := range execCtx.Values {
nvargs = append(nvargs, driver.NamedValue{
Value: value,
Ordinal: i,
})
}

return f(ctx, execCtx.Query, nvargs)
}

+ 12
- 2
pkg/datasource/sql/exec/xa/executor_xa.go View File

@@ -19,6 +19,7 @@ package xa

import (
"context"
"database/sql/driver"

"github.com/seata/seata-go/pkg/datasource/sql/exec"
"github.com/seata/seata-go/pkg/datasource/sql/types"
@@ -55,7 +56,7 @@ func (e *XAExecutor) ExecWithNamedValue(ctx context.Context, execCtx *types.Exec
}

// ExecWithValue
func (e *XAExecutor) ExecWithValue(ctx context.Context, execCtx *types.ExecContext, f exec.CallbackWithValue) (types.ExecResult, error) {
func (e *XAExecutor) ExecWithValue(ctx context.Context, execCtx *types.ExecContext, f exec.CallbackWithNamedValue) (types.ExecResult, error) {
for _, hook := range e.hooks {
hook.Before(ctx, execCtx)
}
@@ -70,5 +71,14 @@ func (e *XAExecutor) ExecWithValue(ctx context.Context, execCtx *types.ExecConte
return e.ex.ExecWithValue(ctx, execCtx, f)
}

return f(ctx, execCtx.Query, execCtx.Values)
nvargs := make([]driver.NamedValue, len(execCtx.Values))
for i, value := range execCtx.Values {
nvargs = append(nvargs, driver.NamedValue{
Value: value,
Ordinal: i,
})
}
execCtx.NamedValues = nvargs

return f(ctx, execCtx.Query, execCtx.NamedValues)
}

+ 5
- 4
pkg/datasource/sql/stmt.go View File

@@ -23,6 +23,7 @@ import (

"github.com/seata/seata-go/pkg/datasource/sql/exec"
"github.com/seata/seata-go/pkg/datasource/sql/types"
"github.com/seata/seata-go/pkg/datasource/sql/util"
)

type Stmt struct {
@@ -75,8 +76,8 @@ func (s *Stmt) Query(args []driver.Value) (driver.Rows, error) {
}

ret, err := executor.ExecWithValue(context.Background(), execCtx,
func(ctx context.Context, query string, args []driver.Value) (types.ExecResult, error) {
ret, err := s.stmt.Query(args)
func(ctx context.Context, query string, args []driver.NamedValue) (types.ExecResult, error) {
ret, err := s.stmt.Query(util.NamedValueToValue(args))
if err != nil {
return nil, err
}
@@ -144,8 +145,8 @@ func (s *Stmt) Exec(args []driver.Value) (driver.Result, error) {
}

ret, err := executor.ExecWithValue(context.Background(), execCtx,
func(ctx context.Context, query string, args []driver.Value) (types.ExecResult, error) {
ret, err := s.stmt.Exec(args)
func(ctx context.Context, query string, args []driver.NamedValue) (types.ExecResult, error) {
ret, err := s.stmt.Exec(util.NamedValueToValue(args))
if err != nil {
return nil, err
}


+ 3
- 2
pkg/datasource/sql/types/image.go View File

@@ -129,7 +129,8 @@ type RowImage struct {
func (r *RowImage) GetColumnMap() map[string]*ColumnImage {
m := make(map[string]*ColumnImage, 0)
for _, column := range r.Columns {
m[column.ColumnName] = &column
tmpColumn := column
m[column.ColumnName] = &tmpColumn
}
return m
}
@@ -245,7 +246,7 @@ func (c *ColumnImage) UnmarshalJSON(data []byte) error {
case JDBCTypeChar, JDBCTypeVarchar:
var val []byte
if val, err = base64.StdEncoding.DecodeString(value.(string)); err != nil {
return err
val = []byte(value.(string))
}
actualValue = string(val)
case JDBCTypeBinary, JDBCTypeVarBinary, JDBCTypeLongVarBinary, JDBCTypeBit:


+ 377
- 0
pkg/datasource/sql/util/convert.go View File

@@ -0,0 +1,377 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// Type conversions for Scan.

package util

import (
"database/sql"
"database/sql/driver"
"errors"
"fmt"
"reflect"
"strconv"
"time"
)

var errNilPtr = errors.New("destination pointer is nil") // embedded in descriptive error

// convertAssignRows copies to dest the value in src, converting it if possible.
// An error is returned if the copy would result in loss of information.
// dest should be a pointer type. If rows is passed in, the rows will
// be used as the parent for any cursor values converted from a
// driver.Rows to a *ScanRows.
func convertAssignRows(dest, src interface{}, rows *ScanRows) error {
// Common cases, without reflect.
switch s := src.(type) {
case string:
switch d := dest.(type) {
case *string:
if d == nil {
return errNilPtr
}
*d = s
return nil
case *[]byte:
if d == nil {
return errNilPtr
}
*d = []byte(s)
return nil
case *sql.RawBytes:
if d == nil {
return errNilPtr
}
*d = append((*d)[:0], s...)
return nil
}
case []byte:
switch d := dest.(type) {
case *string:
if d == nil {
return errNilPtr
}
*d = string(s)
return nil
case *interface{}:
if d == nil {
return errNilPtr
}
*d = cloneBytes(s)
return nil
case *[]byte:
if d == nil {
return errNilPtr
}
*d = cloneBytes(s)
return nil
case *sql.RawBytes:
if d == nil {
return errNilPtr
}
*d = s
return nil
}
case time.Time:
switch d := dest.(type) {
case *time.Time:
*d = s
return nil
case *string:
*d = s.Format(time.RFC3339Nano)
return nil
case *[]byte:
if d == nil {
return errNilPtr
}
*d = []byte(s.Format(time.RFC3339Nano))
return nil
case *sql.RawBytes:
if d == nil {
return errNilPtr
}
*d = s.AppendFormat((*d)[:0], time.RFC3339Nano)
return nil
}
case decimalDecompose:
switch d := dest.(type) {
case decimalCompose:
return d.Compose(s.Decompose(nil))
}
case nil:
switch d := dest.(type) {
case *interface{}:
if d == nil {
return errNilPtr
}
*d = nil
return nil
case *[]byte:
if d == nil {
return errNilPtr
}
*d = nil
return nil
case *sql.RawBytes:
if d == nil {
return errNilPtr
}
*d = nil
return nil
}
// The driver is returning a cursor the client may iterate over.
case driver.Rows:
switch d := dest.(type) {
case *ScanRows:
if d == nil {
return errNilPtr
}
if rows == nil {
return errors.New("invalid context to convert cursor rows, missing parent *ScanRows")
}
rows.closemu.Lock()
*d = ScanRows{
dc: rows.dc,
releaseConn: func(error) {},
rowsi: s,
}
// Chain the cancel function.
parentCancel := rows.cancel
rows.cancel = func() {
// When ScanRows.cancel is called, the closemu will be locked as well.
// So we can access rs.lasterr.
d.close(rows.lasterr)
if parentCancel != nil {
parentCancel()
}
}
rows.closemu.Unlock()
return nil
}
}

var sv reflect.Value

switch d := dest.(type) {
case *string:
sv = reflect.ValueOf(src)
switch sv.Kind() {
case reflect.Bool,
reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
reflect.Float32, reflect.Float64:
*d = asString(src)
return nil
}
case *[]byte:
sv = reflect.ValueOf(src)
if b, ok := asBytes(nil, sv); ok {
*d = b
return nil
}
case *sql.RawBytes:
sv = reflect.ValueOf(src)
if b, ok := asBytes([]byte(*d)[:0], sv); ok {
*d = sql.RawBytes(b)
return nil
}
case *bool:
bv, err := driver.Bool.ConvertValue(src)
if err == nil {
*d = bv.(bool)
}
return err
case *interface{}:
*d = src
return nil
}

if scanner, ok := dest.(sql.Scanner); ok {
return scanner.Scan(src)
}

dpv := reflect.ValueOf(dest)
if dpv.Kind() != reflect.Ptr {
return errors.New("destination not a pointer")
}
if dpv.IsNil() {
return errNilPtr
}

if !sv.IsValid() {
sv = reflect.ValueOf(src)
}

dv := reflect.Indirect(dpv)
if sv.IsValid() && sv.Type().AssignableTo(dv.Type()) {
switch b := src.(type) {
case []byte:
dv.Set(reflect.ValueOf(cloneBytes(b)))
default:
dv.Set(sv)
}
return nil
}

if dv.Kind() == sv.Kind() && sv.Type().ConvertibleTo(dv.Type()) {
dv.Set(sv.Convert(dv.Type()))
return nil
}

// The following conversions use a string value as an intermediate representation
// to convert between various numeric types.
//
// This also allows scanning into user defined types such as "type Int int64".
// For symmetry, also check for string destination types.
switch dv.Kind() {
case reflect.Ptr:
if src == nil {
dv.Set(reflect.Zero(dv.Type()))
return nil
}
dv.Set(reflect.New(dv.Type().Elem()))
return convertAssignRows(dv.Interface(), src, rows)
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
if src == nil {
return fmt.Errorf("converting NULL to %s is unsupported", dv.Kind())
}
s := asString(src)
i64, err := strconv.ParseInt(s, 10, dv.Type().Bits())
if err != nil {
err = strconvErr(err)
return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err)
}
dv.SetInt(i64)
return nil
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
if src == nil {
return fmt.Errorf("converting NULL to %s is unsupported", dv.Kind())
}
s := asString(src)
u64, err := strconv.ParseUint(s, 10, dv.Type().Bits())
if err != nil {
err = strconvErr(err)
return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err)
}
dv.SetUint(u64)
return nil
case reflect.Float32, reflect.Float64:
if src == nil {
return fmt.Errorf("converting NULL to %s is unsupported", dv.Kind())
}
s := asString(src)
f64, err := strconv.ParseFloat(s, dv.Type().Bits())
if err != nil {
err = strconvErr(err)
return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err)
}
dv.SetFloat(f64)
return nil
case reflect.String:
if src == nil {
return fmt.Errorf("converting NULL to %s is unsupported", dv.Kind())
}
switch v := src.(type) {
case string:
dv.SetString(v)
return nil
case []byte:
dv.SetString(string(v))
return nil
}
}

return fmt.Errorf("unsupported Scan, storing driver.Value type %T into type %T", src, dest)
}

func strconvErr(err error) error {
if ne, ok := err.(*strconv.NumError); ok {
return ne.Err
}
return err
}

func cloneBytes(b []byte) []byte {
if b == nil {
return nil
}
c := make([]byte, len(b))
copy(c, b)
return c
}

func asString(src interface{}) string {
switch v := src.(type) {
case string:
return v
case []byte:
return string(v)
}
rv := reflect.ValueOf(src)
switch rv.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return strconv.FormatInt(rv.Int(), 10)
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
return strconv.FormatUint(rv.Uint(), 10)
case reflect.Float64:
return strconv.FormatFloat(rv.Float(), 'g', -1, 64)
case reflect.Float32:
return strconv.FormatFloat(rv.Float(), 'g', -1, 32)
case reflect.Bool:
return strconv.FormatBool(rv.Bool())
}
return fmt.Sprintf("%v", src)
}

func asBytes(buf []byte, rv reflect.Value) (b []byte, ok bool) {
switch rv.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return strconv.AppendInt(buf, rv.Int(), 10), true
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
return strconv.AppendUint(buf, rv.Uint(), 10), true
case reflect.Float32:
return strconv.AppendFloat(buf, rv.Float(), 'g', -1, 32), true
case reflect.Float64:
return strconv.AppendFloat(buf, rv.Float(), 'g', -1, 64), true
case reflect.Bool:
return strconv.AppendBool(buf, rv.Bool()), true
case reflect.String:
s := rv.String()
return append(buf, s...), true
}
return
}

var valuerReflectType = reflect.TypeOf((*driver.Valuer)(nil)).Elem()

type decimalDecompose interface {
// Decompose returns the internal decimal state in parts.
// If the provided buf has sufficient capacity, buf may be returned as the coefficient with
// the value set and length set as appropriate.
Decompose(buf []byte) (form byte, negative bool, coefficient []byte, exponent int32)
}

type decimalCompose interface {
// Compose sets the internal decimal value from parts. If the value cannot be
// represented then an error should be returned.
Compose(form byte, negative bool, coefficient []byte, exponent int32) error
}

+ 123
- 0
pkg/datasource/sql/util/ctxutil.go View File

@@ -0,0 +1,123 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package util

import (
"context"
"database/sql/driver"
"errors"
)

func ctxDriverPrepare(ctx context.Context, ci driver.Conn, query string) (driver.Stmt, error) {
if ciCtx, is := ci.(driver.ConnPrepareContext); is {
return ciCtx.PrepareContext(ctx, query)
}
si, err := ci.Prepare(query)
if err == nil {
select {
default:
case <-ctx.Done():
si.Close()
return nil, ctx.Err()
}
}
return si, err
}

func ctxDriverExec(ctx context.Context, execerCtx driver.ExecerContext, execer driver.Execer, query string, nvdargs []driver.NamedValue) (driver.Result, error) {
if execerCtx != nil {
return execerCtx.ExecContext(ctx, query, nvdargs)
}
dargs, err := namedValueToValue(nvdargs)
if err != nil {
return nil, err
}

select {
default:
case <-ctx.Done():
return nil, ctx.Err()
}
return execer.Exec(query, dargs)
}

func CtxDriverQuery(ctx context.Context, queryerCtx driver.QueryerContext, queryer driver.Queryer, query string, nvdargs []driver.NamedValue) (driver.Rows, error) {
if queryerCtx != nil {
return queryerCtx.QueryContext(ctx, query, nvdargs)
}
dargs, err := namedValueToValue(nvdargs)
if err != nil {
return nil, err
}

select {
default:
case <-ctx.Done():
return nil, ctx.Err()
}
return queryer.Query(query, dargs)
}

func ctxDriverStmtExec(ctx context.Context, si driver.Stmt, nvdargs []driver.NamedValue) (driver.Result, error) {
if siCtx, is := si.(driver.StmtExecContext); is {
return siCtx.ExecContext(ctx, nvdargs)
}
dargs, err := namedValueToValue(nvdargs)
if err != nil {
return nil, err
}

select {
default:
case <-ctx.Done():
return nil, ctx.Err()
}
return si.Exec(dargs)
}

func ctxDriverStmtQuery(ctx context.Context, si driver.Stmt, nvdargs []driver.NamedValue) (driver.Rows, error) {
if siCtx, is := si.(driver.StmtQueryContext); is {
return siCtx.QueryContext(ctx, nvdargs)
}
dargs, err := namedValueToValue(nvdargs)
if err != nil {
return nil, err
}

select {
default:
case <-ctx.Done():
return nil, ctx.Err()
}
return si.Query(dargs)
}

func namedValueToValue(named []driver.NamedValue) ([]driver.Value, error) {
dargs := make([]driver.Value, len(named))
for n, param := range named {
if len(param.Name) > 0 {
return nil, errors.New("sql: driver does not support the use of Named Parameters")
}
dargs[n] = param.Value
}
return dargs, nil
}

+ 39
- 0
pkg/datasource/sql/util/params.go View File

@@ -0,0 +1,39 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package util

import "database/sql/driver"

func NamedValueToValue(nvs []driver.NamedValue) []driver.Value {
vs := make([]driver.Value, 0, len(nvs))
for _, nv := range nvs {
vs = append(vs, nv.Value)
}
return vs
}

func ValueToNamedValue(vs []driver.Value) []driver.NamedValue {
nvs := make([]driver.NamedValue, 0, len(vs))
for i, v := range vs {
nvs = append(nvs, driver.NamedValue{
Value: v,
Ordinal: i,
})
}
return nvs
}

+ 355
- 0
pkg/datasource/sql/util/sql.go View File

@@ -0,0 +1,355 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// Package sql provides a generic interface around SQL (or SQL-like)
// databases.
//
// The sql package must be used in conjunction with a database driver.
// See https://golang.org/s/sqldrivers for a list of drivers.
//
// Drivers that do not support context cancellation will not return until
// after the query is completed.
//
// For usage examples, see the wiki page at
// https://golang.org/s/sqlwiki.
package util

import (
"context"
"database/sql/driver"
"errors"
"fmt"
"io"
"sync"
)

// ScanRows is the result of a query. Its cursor starts before the first row
// of the result set. Use Next to advance from row to row.
type ScanRows struct {
//dc *driverConn // owned; must call releaseConn when closed to release
dc driver.Conn
releaseConn func(error)
rowsi driver.Rows
cancel func() // called when ScanRows is closed, may be nil.
//closeStmt *driverStmt // if non-nil, statement to Close on close

// closemu prevents ScanRows from closing while there
// is an active streaming result. It is held for read during non-close operations
// and exclusively during close.
//
// closemu guards lasterr and closed.
closemu sync.RWMutex
closed bool
lasterr error // non-nil only if closed is true

// lastcols is only used in Scan, Next, and NextResultSet which are expected
// not to be called concurrently.
lastcols []driver.Value
}

func NewScanRows(rowsi driver.Rows) *ScanRows {
return &ScanRows{rowsi: rowsi}
}

// lasterrOrErrLocked returns either lasterr or the provided err.
// rs.closemu must be read-locked.
func (rs *ScanRows) lasterrOrErrLocked(err error) error {
if rs.lasterr != nil && rs.lasterr != io.EOF {
return rs.lasterr
}
return err
}

// bypassRowsAwaitDone is only used for testing.
// If true, it will not close the ScanRows automatically from the context.
var bypassRowsAwaitDone = false

func (rs *ScanRows) initContextClose(ctx, txctx context.Context) {
if ctx.Done() == nil && (txctx == nil || txctx.Done() == nil) {
return
}
if bypassRowsAwaitDone {
return
}
ctx, rs.cancel = context.WithCancel(ctx)
go rs.awaitDone(ctx, txctx)
}

// awaitDone blocks until either ctx or txctx is canceled. The ctx is provided
// from the query context and is canceled when the query ScanRows is closed.
// If the query was issued in a transaction, the transaction's context
// is also provided in txctx to ensure ScanRows is closed if the Tx is closed.
func (rs *ScanRows) awaitDone(ctx, txctx context.Context) {
var txctxDone <-chan struct{}
if txctx != nil {
txctxDone = txctx.Done()
}
select {
case <-ctx.Done():
case <-txctxDone:
}
rs.close(ctx.Err())
}

// Next prepares the next result row for reading with the Scan method. It
// returns true on success, or false if there is no next result row or an error
// happened while preparing it. Err should be consulted to distinguish between
// the two cases.
//
// Every call to Scan, even the first one, must be preceded by a call to Next.
func (rs *ScanRows) Next() bool {
var doClose, ok bool
withLock(rs.closemu.RLocker(), func() {
doClose, ok = rs.nextLocked()
})
if doClose {
rs.Close()
}
return ok
}

func (rs *ScanRows) nextLocked() (doClose, ok bool) {
if rs.closed {
return false, false
}

// Lock the driver connection before calling the driver interface
// rowsi to prevent a Tx from rolling back the connection at the same time.
//rs.dc.Lock()
//defer rs.dc.Unlock()

if rs.lastcols == nil {
rs.lastcols = make([]driver.Value, len(rs.rowsi.Columns()))
}

rs.lasterr = rs.rowsi.Next(rs.lastcols)
if rs.lasterr != nil {
// Close the connection if there is a driver error.
if rs.lasterr != io.EOF {
return true, false
}
nextResultSet, ok := rs.rowsi.(driver.RowsNextResultSet)
if !ok {
return true, false
}
// The driver is at the end of the current result set.
// Test to see if there is another result set after the current one.
// Only close ScanRows if there is no further result sets to read.
if !nextResultSet.HasNextResultSet() {
doClose = true
}
return doClose, false
}
return false, true
}

// NextResultSet prepares the next result set for reading. It reports whether
// there is further result sets, or false if there is no further result set
// or if there is an error advancing to it. The Err method should be consulted
// to distinguish between the two cases.
//
// After calling NextResultSet, the Next method should always be called before
// scanning. If there are further result sets they may not have rows in the result
// set.
func (rs *ScanRows) NextResultSet() bool {
var doClose bool
defer func() {
if doClose {
rs.Close()
}
}()
rs.closemu.RLock()
defer rs.closemu.RUnlock()

if rs.closed {
return false
}

rs.lastcols = nil
nextResultSet, ok := rs.rowsi.(driver.RowsNextResultSet)
if !ok {
doClose = true
return false
}

// Lock the driver connection before calling the driver interface
// rowsi to prevent a Tx from rolling back the connection at the same time.
//rs.dc.Lock()
//defer rs.dc.Unlock()

rs.lasterr = nextResultSet.NextResultSet()
if rs.lasterr != nil {
doClose = true
return false
}
return true
}

// Err returns the error, if any, that was encountered during iteration.
// Err may be called after an explicit or implicit Close.
func (rs *ScanRows) Err() error {
rs.closemu.RLock()
defer rs.closemu.RUnlock()
return rs.lasterrOrErrLocked(nil)
}

//
var errRowsClosed = errors.New("sql: ScanRows are closed")

// Scan copies the columns in the current row into the values pointed
// at by dest. The number of values in dest must be the same as the
// number of columns in ScanRows.
//
// Scan converts columns read from the database into the following
// common Go types and special types provided by the sql package:
//
// *string
// *[]byte
// *int, *int8, *int16, *int32, *int64
// *uint, *uint8, *uint16, *uint32, *uint64
// *bool
// *float32, *float64
// *interface{}
// *RawBytes
// *ScanRows (cursor value)
// any type implementing Scanner (see Scanner docs)
//
// In the most simple case, if the type of the value from the source
// column is an integer, bool or string type T and dest is of type *T,
// Scan simply assigns the value through the pointer.
//
// Scan also converts between string and numeric types, as long as no
// information would be lost. While Scan stringifies all numbers
// scanned from numeric database columns into *string, scans into
// numeric types are checked for overflow. For example, a float64 with
// value 300 or a string with value "300" can scan into a uint16, but
// not into a uint8, though float64(255) or "255" can scan into a
// uint8. One exception is that scans of some float64 numbers to
// strings may lose information when stringifying. In general, scan
// floating point columns into *float64.
//
// If a dest argument has type *[]byte, Scan saves in that argument a
// copy of the corresponding data. The copy is owned by the caller and
// can be modified and held indefinitely. The copy can be avoided by
// using an argument of type *RawBytes instead; see the documentation
// for RawBytes for restrictions on its use.
//
// If an argument has type *interface{}, Scan copies the value
// provided by the underlying driver without conversion. When scanning
// from a source value of type []byte to *interface{}, a copy of the
// slice is made and the caller owns the result.
//
// Source values of type time.Time may be scanned into values of type
// *time.Time, *interface{}, *string, or *[]byte. When converting to
// the latter two, time.RFC3339Nano is used.
//
// Source values of type bool may be scanned into types *bool,
// *interface{}, *string, *[]byte, or *RawBytes.
//
// For scanning into *bool, the source may be true, false, 1, 0, or
// string inputs parseable by strconv.ParseBool.
//
// Scan can also convert a cursor returned from a query, such as
// "select cursor(select * from my_table) from dual", into a
// *ScanRows value that can itself be scanned from. The parent
// select query will close any cursor *ScanRows if the parent *ScanRows is closed.
//
// If any of the first arguments implementing Scanner returns an error,
// that error will be wrapped in the returned error
func (rs *ScanRows) Scan(dest ...interface{}) error {
rs.closemu.RLock()

if rs.lasterr != nil && rs.lasterr != io.EOF {
rs.closemu.RUnlock()
return rs.lasterr
}
if rs.closed {
err := rs.lasterrOrErrLocked(errRowsClosed)
rs.closemu.RUnlock()
return err
}
rs.closemu.RUnlock()

if rs.lastcols == nil {
return errors.New("sql: Scan called without calling Next")
}
if len(dest) != len(rs.lastcols) {
return fmt.Errorf("sql: expected %d destination arguments in Scan, not %d", len(rs.lastcols), len(dest))
}
for i, sv := range rs.lastcols {
err := convertAssignRows(dest[i], sv, rs)
if err != nil {
return fmt.Errorf(`sql: Scan error on column index %d, name %q: %w`, i, rs.rowsi.Columns()[i], err)
}
}
return nil
}

// rowsCloseHook returns a function so tests may install the
// hook through a test only mutex.
var rowsCloseHook = func() func(*ScanRows, *error) { return nil }

// Close closes the ScanRows, preventing further enumeration. If Next is called
// and returns false and there are no further result sets,
// the ScanRows are closed automatically and it will suffice to check the
// result of Err. Close is idempotent and does not affect the result of Err.
func (rs *ScanRows) Close() error {
//return rs.close(nil)
return nil
}

func (rs *ScanRows) close(err error) error {
rs.closemu.Lock()
defer rs.closemu.Unlock()

if rs.closed {
return nil
}
rs.closed = true

if rs.lasterr == nil {
rs.lasterr = err
}

err = rs.rowsi.Close()
//withLock(rs.dc, func() {
// err = rs.rowsi.Close()
//})
if fn := rowsCloseHook(); fn != nil {
fn(rs, &err)
}
if rs.cancel != nil {
rs.cancel()
}

//if rs.closeStmt != nil {
// rs.closeStmt.Close()
//}
rs.releaseConn(err)
return err
}

// withLock runs while holding lk.
func withLock(lk sync.Locker, fn func()) {
lk.Lock()
defer lk.Unlock() // in case fn panics
fn()
}

+ 1
- 1
pkg/remoting/processor/client/rm_branch_rollback_processor.go View File

@@ -85,6 +85,6 @@ func (f *rmBranchRollbackProcessor) Process(ctx context.Context, rpcMessage mess
log.Errorf("send branch rollback response error: {%#v}", err.Error())
return err
}
log.Infof("send branch rollback response success: xid %s, branchID %s, resourceID %s, applicationData %s", xid, branchID, resourceID, applicationData)
log.Infof("send branch rollback response success: xid %s, branchID %d, resourceID %s, applicationData %s", xid, branchID, resourceID, applicationData)
return nil
}

+ 13
- 0
pkg/util/reflectx/unexpoert_field.go View File

@@ -31,3 +31,16 @@ func SetUnexportedField(field reflect.Value, value interface{}) {
func GetUnexportedField(field reflect.Value) interface{} {
return reflect.NewAt(field.Type(), unsafe.Pointer(field.UnsafeAddr())).Elem().Interface()
}

func GetElemDataValue(data interface{}) interface{} {
if data == nil {
return data
}
value := reflect.ValueOf(data)
kind := reflect.TypeOf(data).Kind()
switch kind {
case reflect.Ptr:
return value.Elem().Interface()
}
return value.Interface()
}

+ 48
- 0
pkg/util/reflectx/unexpoert_field_test.go View File

@@ -0,0 +1,48 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package reflectx

import (
"testing"
)

func TestGetElemDataValue(t *testing.T) {
var aa = 10
var bb = "name"
var cc bool

tests := []struct {
name string
args interface{}
want interface{}
}{
{name: "test1", args: aa, want: aa},
{name: "test2", args: &aa, want: aa},
{name: "test3", args: bb, want: bb},
{name: "test4", args: &bb, want: bb},
{name: "test5", args: cc, want: cc},
{name: "test6", args: &cc, want: cc},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := GetElemDataValue(tt.args); got != tt.want {
t.Errorf("GetElemDataValue() = %v, want %v", got, tt.want)
}
})
}
}

Loading…
Cancel
Save