@@ -15,7 +15,6 @@ require (
 	github.com/golang/mock v1.6.0
 	github.com/google/uuid v1.3.0
 	github.com/knadh/koanf v1.4.3
-	github.com/mitchellh/copystructure v1.2.0
 	github.com/natefinch/lumberjack v2.0.0+incompatible
 	github.com/parnurzeal/gorequest v0.2.16
 	github.com/pierrec/lz4/v4 v4.1.17
@@ -88,6 +87,7 @@ require (
 	github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect
 	github.com/magiconair/properties v1.8.6 // indirect
 	github.com/matttproud/golang_protobuf_extensions v1.0.1 // indirect
+	github.com/mitchellh/copystructure v1.2.0 // indirect
 	github.com/mitchellh/go-homedir v1.1.0 // indirect
 	github.com/mitchellh/mapstructure v1.5.0 // indirect
 	github.com/mitchellh/reflectwalk v1.0.2 // indirect
@@ -32,8 +32,8 @@ import (
 	"github.com/knadh/koanf/parsers/yaml"
 	"github.com/knadh/koanf/providers/rawbytes"
 	"github.com/seata/seata-go/pkg/remoting/getty"
-	"github.com/seata/seata-go/pkg/tm"
 	"github.com/seata/seata-go/pkg/rm/tcc"
+	"github.com/seata/seata-go/pkg/tm"
 	"github.com/seata/seata-go/pkg/util/flagext"
 )
@@ -23,6 +23,7 @@ import (
 	"github.com/seata/seata-go/pkg/datasource/sql/exec"
 	"github.com/seata/seata-go/pkg/datasource/sql/types"
+	"github.com/seata/seata-go/pkg/datasource/sql/util"
 )

 // Conn is a connection to a database. It is not used concurrently
@@ -146,8 +147,8 @@ func (c *Conn) Query(query string, args []driver.Value) (driver.Rows, error) {
 	}

 	ret, err := executor.ExecWithValue(context.Background(), execCtx,
-		func(ctx context.Context, query string, args []driver.Value) (types.ExecResult, error) {
-			ret, err := conn.Query(query, args)
+		func(ctx context.Context, query string, args []driver.NamedValue) (types.ExecResult, error) {
+			ret, err := conn.Query(query, util.NamedValueToValue(args))
 			if err != nil {
 				return nil, err
 			}
@@ -26,14 +26,6 @@ import (
 	"github.com/stretchr/testify/assert"
 )

-func TestConn_BuildATExecutor(t *testing.T) {
-	executor, err := exec.BuildExecutor(types.DBTypeMySQL, types.ATMode, "SELECT * FROM user")
-	assert.NoError(t, err)
-	_, ok := executor.(*exec.BaseExecutor)
-	assert.True(t, ok, "need base executor")
-}

 func TestConn_BuildXAExecutor(t *testing.T) {
 	executor, err := exec.BuildExecutor(types.DBTypeMySQL, types.XAMode, "SELECT * FROM user")
@@ -0,0 +1,296 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package at

import (
	"bytes"
	"context"
	"database/sql"
	"database/sql/driver"
	"fmt"
	"strings"

	"github.com/arana-db/parser/ast"
	"github.com/arana-db/parser/test_driver"
	gxsort "github.com/dubbogo/gost/sort"
	"github.com/seata/seata-go/pkg/datasource/sql/exec"
	"github.com/seata/seata-go/pkg/datasource/sql/types"
	"github.com/seata/seata-go/pkg/datasource/sql/util"
	"github.com/seata/seata-go/pkg/util/reflectx"
)
type baseExecutor struct {
	hooks []exec.SQLHook
}

func (b *baseExecutor) beforeHooks(ctx context.Context, execCtx *types.ExecContext) {
	for _, hook := range b.hooks {
		hook.Before(ctx, execCtx)
	}
}

func (b *baseExecutor) afterHooks(ctx context.Context, execCtx *types.ExecContext) {
	for _, hook := range b.hooks {
		hook.After(ctx, execCtx)
	}
}
// GetScanSlice returns scan destinations matching each column's database type
// TODO: use ColumnInfo to build the slice
func (*baseExecutor) GetScanSlice(columnNames []string, tableMeta *types.TableMeta) []interface{} {
	scanSlice := make([]interface{}, 0, len(columnNames))
	for _, columnName := range columnNames {
		// get the column's meta information from the table metadata
		columnMeta := tableMeta.Columns[columnName]
		switch strings.ToUpper(columnMeta.DatabaseTypeString) {
		case "VARCHAR", "NVARCHAR", "VARCHAR2", "CHAR", "TEXT", "JSON", "TINYTEXT":
			var scanVal string
			scanSlice = append(scanSlice, &scanVal)
		case "BIT", "INT", "LONGBLOB", "SMALLINT", "TINYINT", "BIGINT", "MEDIUMINT":
			if columnMeta.IsNullable == 0 {
				scanVal := int64(0)
				scanSlice = append(scanSlice, &scanVal)
			} else {
				scanVal := sql.NullInt64{}
				scanSlice = append(scanSlice, &scanVal)
			}
		case "DATE", "DATETIME", "TIME", "TIMESTAMP", "YEAR":
			scanVal := sql.NullTime{}
			scanSlice = append(scanSlice, &scanVal)
		case "DECIMAL", "DOUBLE", "FLOAT":
			if columnMeta.IsNullable == 0 {
				scanVal := float64(0)
				scanSlice = append(scanSlice, &scanVal)
			} else {
				scanVal := sql.NullFloat64{}
				scanSlice = append(scanSlice, &scanVal)
			}
		default:
			scanVal := sql.RawBytes{}
			scanSlice = append(scanSlice, &scanVal)
		}
	}
	return scanSlice
}
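
// Editor's illustration (not part of this patch): with hypothetical metadata for
// columns id (BIGINT, NOT NULL), name (VARCHAR) and birthday (DATETIME), the
// returned slice holds pointers ready to be passed to Scan.
func exampleGetScanSlice(b *baseExecutor, meta *types.TableMeta) {
	ss := b.GetScanSlice([]string{"id", "name", "birthday"}, meta)
	_ = ss[0].(*int64)        // BIGINT with IsNullable == 0
	_ = ss[1].(*string)       // VARCHAR
	_ = ss[2].(*sql.NullTime) // DATETIME
}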
func (b *baseExecutor) buildSelectArgs(stmt *ast.SelectStmt, args []driver.NamedValue) []driver.NamedValue {
	var (
		selectArgsIndexs = make([]int32, 0)
		selectArgs       = make([]driver.NamedValue, 0)
	)

	b.traversalArgs(stmt.Where, &selectArgsIndexs)
	if stmt.OrderBy != nil {
		for _, item := range stmt.OrderBy.Items {
			b.traversalArgs(item, &selectArgsIndexs)
		}
	}
	if stmt.Limit != nil {
		if stmt.Limit.Offset != nil {
			b.traversalArgs(stmt.Limit.Offset, &selectArgsIndexs)
		}
		if stmt.Limit.Count != nil {
			b.traversalArgs(stmt.Limit.Count, &selectArgsIndexs)
		}
	}
	// sort selectArgs index array
	gxsort.Int32(selectArgsIndexs)
	for _, index := range selectArgsIndexs {
		selectArgs = append(selectArgs, args[index])
	}
	return selectArgs
}
// TODO: support the remaining SQL expression types
func (b *baseExecutor) traversalArgs(node ast.Node, argsIndex *[]int32) {
	if node == nil {
		return
	}
	switch expr := node.(type) {
	case *ast.BinaryOperationExpr:
		b.traversalArgs(expr.L, argsIndex)
		b.traversalArgs(expr.R, argsIndex)
	case *ast.BetweenExpr:
		b.traversalArgs(expr.Left, argsIndex)
		b.traversalArgs(expr.Right, argsIndex)
	case *ast.PatternInExpr:
		for i := 0; i < len(expr.List); i++ {
			b.traversalArgs(expr.List[i], argsIndex)
		}
	case *test_driver.ParamMarkerExpr:
		*argsIndex = append(*argsIndex, int32(expr.Order))
	}
}
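
// Editor's note (not part of this patch): buildSelectArgs keeps only the
// arguments referenced by placeholders in the WHERE, ORDER BY and LIMIT
// clauses, ordered by marker position. For
// "update t_user set name = ?, age = ? where id = ? and age between ? and ?"
// the WHERE placeholders have orders 2, 3 and 4, so of the five statement
// arguments ("Jack", 1, 100, 18, 28) only 100, 18 and 28 are returned; the
// expectations in update_executor_test.go below rely on exactly this.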
func (b *baseExecutor) buildRecordImages(rowsi driver.Rows, tableMetaData *types.TableMeta) (*types.RecordImage, error) {
	// select column names
	columnNames := rowsi.Columns()
	rowImages := make([]types.RowImage, 0)

	sqlRows := util.NewScanRows(rowsi)
	for sqlRows.Next() {
		ss := b.GetScanSlice(columnNames, tableMetaData)
		err := sqlRows.Scan(ss...)
		if err != nil {
			return nil, err
		}

		columns := make([]types.ColumnImage, 0)
		// build record image
		for i, name := range columnNames {
			columnMeta := tableMetaData.Columns[name]

			keyType := types.IndexTypeNull
			if _, ok := tableMetaData.GetPrimaryKeyMap()[name]; ok {
				keyType = types.IndexTypePrimaryKey
			}
			jdbcType := types.MySQLStrToJavaType(columnMeta.DatabaseTypeString)

			columns = append(columns, types.ColumnImage{
				KeyType:    keyType,
				ColumnName: name,
				ColumnType: jdbcType,
				Value:      reflectx.GetElemDataValue(ss[i]),
			})
		}
		rowImages = append(rowImages, types.RowImage{Columns: columns})
	}

	return &types.RecordImage{TableName: tableMetaData.TableName, Rows: rowImages}, nil
}
// buildWhereConditionByPKs builds a where condition from the primary keys.
// Each row contributes one placeholder tuple; the result looks like:
// "(id,userCode) IN ((?,?),(?,?)) OR (id,userCode) IN ((?,?),(?,?)) OR (id,userCode) IN ((?,?))"
func (b *baseExecutor) buildWhereConditionByPKs(pkNameList []string, rowSize int, dbType string, maxInSize int) string {
	var (
		whereStr  = &strings.Builder{}
		batchSize = rowSize/maxInSize + 1
	)
	if rowSize%maxInSize == 0 {
		batchSize = rowSize / maxInSize
	}

	for batch := 0; batch < batchSize; batch++ {
		if batch > 0 {
			whereStr.WriteString(" OR ")
		}
		whereStr.WriteString("(")
		for i := 0; i < len(pkNameList); i++ {
			if i > 0 {
				whereStr.WriteString(",")
			}
			// TODO: add escaping
			whereStr.WriteString(fmt.Sprintf("`%s`", pkNameList[i]))
		}
		whereStr.WriteString(") IN (")

		var eachSize int
		if batch == batchSize-1 {
			if rowSize%maxInSize == 0 {
				eachSize = maxInSize
			} else {
				eachSize = rowSize % maxInSize
			}
		} else {
			eachSize = maxInSize
		}

		for i := 0; i < eachSize; i++ {
			if i > 0 {
				whereStr.WriteString(",")
			}
			whereStr.WriteString("(")
			for j := 0; j < len(pkNameList); j++ {
				if j > 0 {
					whereStr.WriteString(",")
				}
				whereStr.WriteString("?")
			}
			whereStr.WriteString(")")
		}
		whereStr.WriteString(")")
	}
	return whereStr.String()
}
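
// Editor's sketch (not part of this patch): a hypothetical test, using the
// testify assertions already used elsewhere in this change and belonging in a
// _test.go file, showing the clause shape for a composite key (id, userCode),
// five rows and maxInSize = 2.
func TestBuildWhereConditionByPKsShape(t *testing.T) {
	b := &baseExecutor{}
	got := b.buildWhereConditionByPKs([]string{"id", "userCode"}, 5, "mysql", 2)
	want := "(`id`,`userCode`) IN ((?,?),(?,?)) OR (`id`,`userCode`) IN ((?,?),(?,?)) OR (`id`,`userCode`) IN ((?,?))"
	assert.Equal(t, want, got)
}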
func (b *baseExecutor) buildPKParams(rows []types.RowImage, pkNameList []string) []driver.NamedValue {
	params := make([]driver.NamedValue, 0)
	for _, row := range rows {
		columnMap := row.GetColumnMap()
		for i, pk := range pkNameList {
			if col, ok := columnMap[pk]; ok {
				params = append(params, driver.NamedValue{
					Ordinal: i, Value: col.Value,
				})
			}
		}
	}
	return params
}
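
// Editor's note (not part of this patch): for two rows whose primary key values
// are (1, "a") and (2, "b") and pkNameList ["id", "userCode"], buildPKParams
// yields the flattened parameters 1, "a", 2, "b", which line up with the
// placeholder tuples produced by buildWhereConditionByPKs above.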
// buildLockKey builds the row lock key for the record image.
// Example with a composite primary key: "t_user:1_a,2_b"
func (b *baseExecutor) buildLockKey(records *types.RecordImage, meta types.TableMeta) string {
	var (
		lockKeys      bytes.Buffer
		fieldSequence int
	)
	lockKeys.WriteString(meta.TableName)
	lockKeys.WriteString(":")

	keys := meta.GetPrimaryKeyOnlyName()

	for _, row := range records.Rows {
		if fieldSequence > 0 {
			lockKeys.WriteString(",")
		}
		pkSplitIndex := 0
		for _, column := range row.Columns {
			var hasKeyColumn bool
			for _, key := range keys {
				if column.ColumnName == key {
					hasKeyColumn = true
					if pkSplitIndex > 0 {
						lockKeys.WriteString("_")
					}
					lockKeys.WriteString(fmt.Sprintf("%v", column.Value))
					pkSplitIndex++
				}
			}
			if hasKeyColumn {
				fieldSequence++
			}
		}
	}
	return lockKeys.String()
}
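
// Editor's note (not part of this patch): for table t_user with composite
// primary key (id, userCode) and two rows (1, "a") and (2, "b"), the key built
// above is "t_user:1_a,2_b": the table name, a colon, then one
// underscore-joined primary-key tuple per row, comma separated.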
@@ -19,182 +19,61 @@ package at
 import (
 	"context"
-	"fmt"
-	"github.com/mitchellh/copystructure"
-	"github.com/pkg/errors"
-	"github.com/seata/seata-go/pkg/datasource/sql/parser"
-	"github.com/seata/seata-go/pkg/datasource/sql/undo"
+	"github.com/seata/seata-go/pkg/datasource/sql/util"
 	"github.com/seata/seata-go/pkg/tm"
 	"github.com/seata/seata-go/pkg/datasource/sql/exec"
+	"github.com/seata/seata-go/pkg/datasource/sql/parser"
 	"github.com/seata/seata-go/pkg/datasource/sql/types"
 )
-type ATExecutor struct {
-	hooks []exec.SQLHook
-	ex    exec.SQLExecutor
+func init() {
+	exec.RegisterATExecutor(types.DBTypeMySQL, func() exec.SQLExecutor { return &AtExecutor{} })
 }
-func (e *ATExecutor) Interceptors(hooks []exec.SQLHook) {
-	e.hooks = hooks
+type executor interface {
+	ExecContext(ctx context.Context, f exec.CallbackWithNamedValue) (types.ExecResult, error)
 }
-func (e *ATExecutor) ExecWithNamedValue(ctx context.Context, execCtx *types.ExecContext, f exec.CallbackWithNamedValue) (types.ExecResult, error) {
-	for _, hook := range e.hooks {
-		hook.Before(ctx, execCtx)
-	}
-	var (
-		beforeImages []*types.RecordImage
-		afterImages  []*types.RecordImage
-		result       types.ExecResult
-		err          error
-	)
-	beforeImages, err = e.beforeImage(ctx, execCtx)
-	if err != nil {
-		return nil, err
-	}
-	if beforeImages != nil {
-		beforeImagesTmp, err := copystructure.Copy(beforeImages)
-		if err != nil {
-			return nil, err
-		}
-		newBeforeImages, ok := beforeImagesTmp.([]*types.RecordImage)
-		if !ok {
-			return nil, errors.New("copy beforeImages failed")
-		}
-		execCtx.TxCtx.RoundImages.AppendBeofreImages(newBeforeImages)
-	}
-	defer func() {
-		for _, hook := range e.hooks {
-			hook.After(ctx, execCtx)
-		}
-	}()
-	if e.ex != nil {
-		result, err = e.ex.ExecWithNamedValue(ctx, execCtx, f)
-	} else {
-		result, err = f(ctx, execCtx.Query, execCtx.NamedValues)
-	}
-	if err != nil {
-		return nil, err
-	}
-	afterImages, err = e.afterImage(ctx, execCtx, beforeImages)
-	if err != nil {
-		return nil, err
-	}
-	if afterImages != nil {
-		execCtx.TxCtx.RoundImages.AppendAfterImages(afterImages)
-	}
-	return result, err
+type AtExecutor struct {
+	hooks []exec.SQLHook
 }
-func (e *ATExecutor) prepareUndoLog(ctx context.Context, execCtx *types.ExecContext) error {
-	if execCtx.TxCtx.RoundImages.IsEmpty() {
-		return nil
-	}
-	if execCtx.ParseContext.UpdateStmt != nil {
-		if !execCtx.TxCtx.RoundImages.IsBeforeAfterSizeEq() {
-			return fmt.Errorf("Before image size is not equaled to after image size, probably because you updated the primary keys.")
-		}
-	}
-	undoLogManager, err := undo.GetUndoLogManager(execCtx.DBType)
-	if err != nil {
-		return err
-	}
-	return undoLogManager.FlushUndoLog(execCtx.TxCtx, execCtx.Conn)
+func (e *AtExecutor) Interceptors(hooks []exec.SQLHook) {
+	e.hooks = hooks
 }
-func (e *ATExecutor) ExecWithValue(ctx context.Context, execCtx *types.ExecContext, f exec.CallbackWithValue) (types.ExecResult, error) {
-	for _, hook := range e.hooks {
-		hook.Before(ctx, execCtx)
-	}
-	var (
-		beforeImages []*types.RecordImage
-		afterImages  []*types.RecordImage
-		result       types.ExecResult
-		err          error
-	)
-	beforeImages, err = e.beforeImage(ctx, execCtx)
-	if err != nil {
-		return nil, err
-	}
-	if beforeImages != nil {
-		execCtx.TxCtx.RoundImages.AppendBeofreImages(beforeImages)
-	}
-	defer func() {
-		for _, hook := range e.hooks {
-			hook.After(ctx, execCtx)
-		}
-	}()
-	if e.ex != nil {
-		result, err = e.ex.ExecWithValue(ctx, execCtx, f)
-	} else {
-		result, err = f(ctx, execCtx.Query, execCtx.Values)
-	}
-	if err != nil {
-		return nil, err
-	}
-	afterImages, err = e.afterImage(ctx, execCtx, beforeImages)
+// ExecWithNamedValue
+func (e *AtExecutor) ExecWithNamedValue(ctx context.Context, execCtx *types.ExecContext, f exec.CallbackWithNamedValue) (types.ExecResult, error) {
+	parser, err := parser.DoParser(execCtx.Query)
 	if err != nil {
 		return nil, err
 	}
-	if afterImages != nil {
-		execCtx.TxCtx.RoundImages.AppendAfterImages(afterImages)
-	}
-	return result, err
 }
+	var exec executor
-func (e *ATExecutor) beforeImage(ctx context.Context, execCtx *types.ExecContext) ([]*types.RecordImage, error) {
 	if !tm.IsGlobalTx(ctx) {
-		return nil, nil
-	}
-	pc, err := parser.DoParser(execCtx.Query)
-	if err != nil {
-		return nil, err
-	}
-	if !pc.HasValidStmt() {
-		return nil, nil
+		exec = NewPlainExecutor(parser, execCtx)
+	} else {
+		switch parser.SQLType {
+		//case types.SQLTypeInsert:
+		case types.SQLTypeUpdate:
+			exec = NewUpdateExecutor(parser, execCtx, e.hooks)
+		//case types.SQLTypeDelete:
+		//case types.SQLTypeSelectForUpdate:
+		//case types.SQLTypeMultiDelete:
+		//case types.SQLTypeMultiUpdate:
+		default:
+			exec = NewPlainExecutor(parser, execCtx)
+		}
 	}
-	execCtx.ParseContext = pc
-	builder := undo.GetUndologBuilder(pc.ExecutorType)
-	if builder == nil {
-		return nil, nil
-	}
-	return builder.BeforeImage(ctx, execCtx)
+	return exec.ExecContext(ctx, f)
 }
-// After
-func (e *ATExecutor) afterImage(ctx context.Context, execCtx *types.ExecContext, beforeImages []*types.RecordImage) ([]*types.RecordImage, error) {
-	if !tm.IsGlobalTx(ctx) {
-		return nil, nil
-	}
-	pc, err := parser.DoParser(execCtx.Query)
-	if err != nil {
-		return nil, err
-	}
-	if !pc.HasValidStmt() {
-		return nil, nil
-	}
-	execCtx.ParseContext = pc
-	builder := undo.GetUndologBuilder(pc.ExecutorType)
-	if builder == nil {
-		return nil, nil
-	}
-	return builder.AfterImage(ctx, execCtx, beforeImages)
+// ExecWithValue
+func (e *AtExecutor) ExecWithValue(ctx context.Context, execCtx *types.ExecContext, f exec.CallbackWithNamedValue) (types.ExecResult, error) {
+	execCtx.NamedValues = util.ValueToNamedValue(execCtx.Values)
+	return e.ExecWithNamedValue(ctx, execCtx, f)
 }
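
The util.ValueToNamedValue and util.NamedValueToValue helpers used above and in the Conn/Stmt changes are not shown in this diff; a minimal sketch of such conversions, assuming nothing beyond database/sql/driver (the real helpers in pkg/datasource/sql/util may differ), could look like:

package util

import "database/sql/driver"

// ValueToNamedValue wraps plain driver.Value arguments as NamedValues;
// driver ordinals are 1-based.
func ValueToNamedValue(values []driver.Value) []driver.NamedValue {
	nvs := make([]driver.NamedValue, 0, len(values))
	for i, v := range values {
		nvs = append(nvs, driver.NamedValue{Ordinal: i + 1, Value: v})
	}
	return nvs
}

// NamedValueToValue drops the names and ordinals again.
func NamedValueToValue(nvs []driver.NamedValue) []driver.Value {
	values := make([]driver.Value, 0, len(nvs))
	for _, nv := range nvs {
		values = append(values, nv.Value)
	}
	return values
}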
@@ -18,21 +18,21 @@
 package at
 import (
+	"context"
 	"github.com/seata/seata-go/pkg/datasource/sql/exec"
 	"github.com/seata/seata-go/pkg/datasource/sql/types"
 )
-func init() {
-	exec.RegisterATExecutor(types.DBTypeMySQL, types.UpdateExecutor, func() exec.SQLExecutor {
-		return &ATExecutor{}
-	})
-	exec.RegisterATExecutor(types.DBTypeMySQL, types.SelectExecutor, func() exec.SQLExecutor {
-		return &ATExecutor{}
-	})
-	exec.RegisterATExecutor(types.DBTypeMySQL, types.InsertExecutor, func() exec.SQLExecutor {
-		return &ATExecutor{}
-	})
-	exec.RegisterATExecutor(types.DBTypeMySQL, types.DeleteExecutor, func() exec.SQLExecutor {
-		return &ATExecutor{}
-	})
+type PlainExecutor struct {
+	parserCtx *types.ParseContext
+	execCtx   *types.ExecContext
+}
+
+func NewPlainExecutor(parserCtx *types.ParseContext, execCtx *types.ExecContext) *PlainExecutor {
+	return &PlainExecutor{parserCtx: parserCtx, execCtx: execCtx}
+}
+
+func (u *PlainExecutor) ExecContext(ctx context.Context, f exec.CallbackWithNamedValue) (types.ExecResult, error) {
+	return f(ctx, u.execCtx.Query, u.execCtx.NamedValues)
 }
@@ -0,0 +1,272 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package at

import (
	"context"
	"database/sql/driver"
	"fmt"
	"strings"

	"github.com/arana-db/parser/ast"
	"github.com/arana-db/parser/format"
	"github.com/arana-db/parser/model"
	"github.com/seata/seata-go/pkg/datasource/sql/datasource"
	"github.com/seata/seata-go/pkg/datasource/sql/exec"
	"github.com/seata/seata-go/pkg/datasource/sql/types"
	"github.com/seata/seata-go/pkg/datasource/sql/util"
	"github.com/seata/seata-go/pkg/util/bytes"
	"github.com/seata/seata-go/pkg/util/log"
)
var (
	// todo: OnlyCareUpdateColumns should load from config first
	onlyCareUpdateColumns = true
	maxInSize             = 1000
)

type updateExecutor struct {
	baseExecutor
	parserCtx   *types.ParseContext
	execContent *types.ExecContext
}
// NewUpdateExecutor creates an update executor
func NewUpdateExecutor(parserCtx *types.ParseContext, execContent *types.ExecContext, hooks []exec.SQLHook) executor {
	return &updateExecutor{parserCtx: parserCtx, execContent: execContent, baseExecutor: baseExecutor{hooks: hooks}}
}

// ExecContext executes the SQL and generates the before image and the after image
func (u *updateExecutor) ExecContext(ctx context.Context, f exec.CallbackWithNamedValue) (types.ExecResult, error) {
	u.beforeHooks(ctx, u.execContent)
	defer func() {
		u.afterHooks(ctx, u.execContent)
	}()

	beforeImage, err := u.beforeImage(ctx)
	if err != nil {
		return nil, err
	}

	res, err := f(ctx, u.execContent.Query, u.execContent.NamedValues)
	if err != nil {
		return nil, err
	}

	afterImage, err := u.afterImage(ctx, *beforeImage)
	if err != nil {
		return nil, err
	}
	if len(beforeImage.Rows) != len(afterImage.Rows) {
		return nil, fmt.Errorf("before image size is not equal to after image size, probably because you updated the primary keys")
	}

	u.execContent.TxCtx.RoundImages.AppendBeofreImage(beforeImage)
	u.execContent.TxCtx.RoundImages.AppendAfterImage(afterImage)
	return res, nil
}
// beforeImage builds the before image
func (u *updateExecutor) beforeImage(ctx context.Context) (*types.RecordImage, error) {
	if !u.isAstStmtValid() {
		return nil, nil
	}

	selectSQL, selectArgs, err := u.buildBeforeImageSQL(ctx, u.execContent.NamedValues)
	if err != nil {
		return nil, err
	}

	tableName, _ := u.parserCtx.GteTableName()
	metaData, err := datasource.GetTableCache(types.DBTypeMySQL).GetTableMeta(ctx, u.execContent.DBName, tableName)
	if err != nil {
		return nil, err
	}

	var rowsi driver.Rows
	queryerCtx, ok := u.execContent.Conn.(driver.QueryerContext)
	var queryer driver.Queryer
	if !ok {
		queryer, ok = u.execContent.Conn.(driver.Queryer)
	}
	if ok {
		rowsi, err = util.CtxDriverQuery(ctx, queryerCtx, queryer, selectSQL, selectArgs)
		if err != nil {
			log.Errorf("ctx driver query: %+v", err)
			return nil, err
		}
	} else {
		log.Errorf("target conn should be driver.QueryerContext or driver.Queryer")
		return nil, fmt.Errorf("invalid conn")
	}

	image, err := u.buildRecordImages(rowsi, metaData)
	if err != nil {
		return nil, err
	}

	lockKey := u.buildLockKey(image, *metaData)
	u.execContent.TxCtx.LockKeys[lockKey] = struct{}{}
	image.SQLType = u.parserCtx.SQLType

	return image, nil
}
// afterImage builds the after image
func (u *updateExecutor) afterImage(ctx context.Context, beforeImage types.RecordImage) (*types.RecordImage, error) {
	if !u.isAstStmtValid() {
		return nil, nil
	}
	if len(beforeImage.Rows) == 0 {
		return &types.RecordImage{}, nil
	}

	tableName, _ := u.parserCtx.GteTableName()
	metaData, err := datasource.GetTableCache(types.DBTypeMySQL).GetTableMeta(ctx, u.execContent.DBName, tableName)
	if err != nil {
		return nil, err
	}
	selectSQL, selectArgs := u.buildAfterImageSQL(beforeImage, metaData)

	var rowsi driver.Rows
	queryerCtx, ok := u.execContent.Conn.(driver.QueryerContext)
	var queryer driver.Queryer
	if !ok {
		queryer, ok = u.execContent.Conn.(driver.Queryer)
	}
	if ok {
		rowsi, err = util.CtxDriverQuery(ctx, queryerCtx, queryer, selectSQL, selectArgs)
		if err != nil {
			log.Errorf("ctx driver query: %+v", err)
			return nil, err
		}
	} else {
		log.Errorf("target conn should be driver.QueryerContext or driver.Queryer")
		return nil, fmt.Errorf("invalid conn")
	}

	afterImage, err := u.buildRecordImages(rowsi, metaData)
	if err != nil {
		return nil, err
	}
	afterImage.SQLType = u.parserCtx.SQLType

	return afterImage, nil
}
func (u *updateExecutor) isAstStmtValid() bool {
	return u.parserCtx != nil && u.parserCtx.UpdateStmt != nil
}
// buildAfterImageSQL builds the SQL used to query the after-image data
func (u *updateExecutor) buildAfterImageSQL(beforeImage types.RecordImage, meta *types.TableMeta) (string, []driver.NamedValue) {
	if len(beforeImage.Rows) == 0 {
		return "", nil
	}

	sb := strings.Builder{}
	// todo: OnlyCareUpdateColumns should load from config first
	var selectFields string
	var separator = ","
	if onlyCareUpdateColumns {
		for _, row := range beforeImage.Rows {
			for _, column := range row.Columns {
				selectFields += column.ColumnName + separator
			}
		}
		selectFields = strings.TrimSuffix(selectFields, separator)
	} else {
		selectFields = "*"
	}
	sb.WriteString("SELECT " + selectFields + " FROM " + meta.TableName + " WHERE ")
	whereSQL := u.buildWhereConditionByPKs(meta.GetPrimaryKeyOnlyName(), len(beforeImage.Rows), "mysql", maxInSize)
	sb.WriteString(" " + whereSQL + " ")
	return sb.String(), u.buildPKParams(beforeImage.Rows, meta.GetPrimaryKeyOnlyName())
}
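
// Editor's note (not part of this patch): assuming a single-row before image
// with columns name, age and id on table t_user whose primary key is id, the
// statement built above is roughly
//	SELECT name,age,id FROM t_user WHERE  (`id`) IN ((?))
// with a single NamedValue carrying that row's id value.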
// buildBeforeImageSQL builds the SQL used to query the before-image data
func (u *updateExecutor) buildBeforeImageSQL(ctx context.Context, args []driver.NamedValue) (string, []driver.NamedValue, error) {
	if !u.isAstStmtValid() {
		log.Errorf("invalid update stmt")
		return "", nil, fmt.Errorf("invalid update stmt")
	}

	updateStmt := u.parserCtx.UpdateStmt
	fields := make([]*ast.SelectField, 0, len(updateStmt.List))

	if onlyCareUpdateColumns {
		for _, column := range updateStmt.List {
			fields = append(fields, &ast.SelectField{
				Expr: &ast.ColumnNameExpr{
					Name: column.Column,
				},
			})
		}

		// also select the indexed (primary key) columns
		tableName, _ := u.parserCtx.GteTableName()
		metaData, err := datasource.GetTableCache(types.DBTypeMySQL).GetTableMeta(ctx, u.execContent.DBName, tableName)
		if err != nil {
			return "", nil, err
		}
		for _, columnName := range metaData.GetPrimaryKeyOnlyName() {
			fields = append(fields, &ast.SelectField{
				Expr: &ast.ColumnNameExpr{
					Name: &ast.ColumnName{
						Name: model.CIStr{
							O: columnName,
							L: columnName,
						},
					},
				},
			})
		}
	} else {
		fields = append(fields, &ast.SelectField{
			Expr: &ast.ColumnNameExpr{
				Name: &ast.ColumnName{
					Name: model.CIStr{
						O: "*",
						L: "*",
					},
				},
			},
		})
	}

	selStmt := ast.SelectStmt{
		SelectStmtOpts: &ast.SelectStmtOpts{},
		From:           updateStmt.TableRefs,
		Where:          updateStmt.Where,
		Fields:         &ast.FieldList{Fields: fields},
		OrderBy:        updateStmt.Order,
		Limit:          updateStmt.Limit,
		TableHints:     updateStmt.TableHints,
		LockInfo: &ast.SelectLockInfo{
			LockType: ast.SelectLockForUpdate,
		},
	}

	b := bytes.NewByteBuffer([]byte{})
	_ = selStmt.Restore(format.NewRestoreCtx(format.RestoreKeyWordUppercase, b))
	sql := string(b.Bytes())
	log.Infof("build select sql by update sourceQuery, sql {%s}", sql)

	return sql, u.buildSelectArgs(&selStmt, args), nil
}
@@ -0,0 +1,101 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package at

import (
	"context"
	"database/sql/driver"
	"reflect"
	"testing"

	"github.com/agiledragon/gomonkey"
	"github.com/seata/seata-go/pkg/datasource/sql/datasource"
	"github.com/seata/seata-go/pkg/datasource/sql/datasource/mysql"
	"github.com/seata/seata-go/pkg/datasource/sql/exec"
	"github.com/seata/seata-go/pkg/datasource/sql/util"
	"github.com/seata/seata-go/pkg/datasource/sql/types"
	"github.com/seata/seata-go/pkg/datasource/sql/parser"

	_ "github.com/arana-db/parser/test_driver"
	_ "github.com/seata/seata-go/pkg/util/log"
	"github.com/stretchr/testify/assert"
)
func TestBuildSelectSQLByUpdate(t *testing.T) {
	datasource.RegisterTableCache(types.DBTypeMySQL, mysql.NewTableMetaInstance(nil))
	stub := gomonkey.ApplyMethod(reflect.TypeOf(datasource.GetTableCache(types.DBTypeMySQL)), "GetTableMeta",
		func(_ *mysql.TableMetaCache, ctx context.Context, dbName, tableName string) (*types.TableMeta, error) {
			return &types.TableMeta{
				Indexs: map[string]types.IndexMeta{
					"id": {
						IType: types.IndexTypePrimaryKey,
						Columns: []types.ColumnMeta{
							{ColumnName: "id"},
						},
					},
				},
			}, nil
		})
	defer stub.Reset()

	tests := []struct {
		name            string
		sourceQuery     string
		sourceQueryArgs []driver.Value
		expectQuery     string
		expectQueryArgs []driver.Value
	}{
		{
			sourceQuery:     "update t_user set name = ?, age = ? where id = ?",
			sourceQueryArgs: []driver.Value{"Jack", 1, 100},
			expectQuery:     "SELECT SQL_NO_CACHE name,age,id FROM t_user WHERE id=? FOR UPDATE",
			expectQueryArgs: []driver.Value{100},
		},
		{
			sourceQuery:     "update t_user set name = ?, age = ? where id = ? and name = 'Jack' and age between ? and ?",
			sourceQueryArgs: []driver.Value{"Jack", 1, 100, 18, 28},
			expectQuery:     "SELECT SQL_NO_CACHE name,age,id FROM t_user WHERE id=? AND name=_UTF8MB4Jack AND age BETWEEN ? AND ? FOR UPDATE",
			expectQueryArgs: []driver.Value{100, 18, 28},
		},
		{
			sourceQuery:     "update t_user set name = ?, age = ? where id = ? and name = 'Jack' and age in (?,?)",
			sourceQueryArgs: []driver.Value{"Jack", 1, 100, 18, 28},
			expectQuery:     "SELECT SQL_NO_CACHE name,age,id FROM t_user WHERE id=? AND name=_UTF8MB4Jack AND age IN (?,?) FOR UPDATE",
			expectQueryArgs: []driver.Value{100, 18, 28},
		},
		{
			sourceQuery:     "update t_user set name = ?, age = ? where kk between ? and ? and id = ? and addr in(?,?) and age > ? order by name desc limit ?",
			sourceQueryArgs: []driver.Value{"Jack", 1, 10, 20, 17, "Beijing", "Guangzhou", 18, 2},
			expectQuery:     "SELECT SQL_NO_CACHE name,age,id FROM t_user WHERE kk BETWEEN ? AND ? AND id=? AND addr IN (?,?) AND age>? ORDER BY name DESC LIMIT ? FOR UPDATE",
			expectQueryArgs: []driver.Value{10, 20, 17, "Beijing", "Guangzhou", 18, 2},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			c, err := parser.DoParser(tt.sourceQuery)
			assert.Nil(t, err)
			executor := NewUpdateExecutor(c, &types.ExecContext{Values: tt.sourceQueryArgs, NamedValues: util.ValueToNamedValue(tt.sourceQueryArgs)}, []exec.SQLHook{})
			query, args, err := executor.(*updateExecutor).buildBeforeImageSQL(context.Background(), util.ValueToNamedValue(tt.sourceQueryArgs))
			assert.Nil(t, err)
			assert.Equal(t, tt.expectQuery, query)
			assert.Equal(t, tt.expectQueryArgs, util.NamedValueToValue(args))
		})
	}
}
@@ -25,35 +25,25 @@ import (
 	"github.com/seata/seata-go/pkg/datasource/sql/types"
 	"github.com/seata/seata-go/pkg/datasource/sql/undo"
 	"github.com/seata/seata-go/pkg/datasource/sql/undo/builder"
-	"github.com/seata/seata-go/pkg/util/log"
 )
 func init() {
-	undo.RegisterUndoLogBuilder(types.UpdateExecutor, builder.GetMySQLUpdateUndoLogBuilder)
 	undo.RegisterUndoLogBuilder(types.MultiExecutor, builder.GetMySQLMultiUndoLogBuilder)
 }
 var (
-	executorSoltsAT = make(map[types.DBType]map[types.ExecutorType]func() SQLExecutor)
-	executorSoltsXA = make(map[types.DBType]func() SQLExecutor)
+	atExecutors = make(map[types.DBType]func() SQLExecutor)
+	xaExecutors = make(map[types.DBType]func() SQLExecutor)
 )
 // RegisterATExecutor AT executor
-func RegisterATExecutor(dt types.DBType, et types.ExecutorType, builder func() SQLExecutor) {
-	if _, ok := executorSoltsAT[dt]; !ok {
-		executorSoltsAT[dt] = make(map[types.ExecutorType]func() SQLExecutor)
-	}
-	val := executorSoltsAT[dt]
-	val[et] = func() SQLExecutor {
-		return &BaseExecutor{ex: builder()}
-	}
+func RegisterATExecutor(dt types.DBType, builder func() SQLExecutor) {
+	atExecutors[dt] = builder
 }
 // RegisterXAExecutor XA executor
 func RegisterXAExecutor(dt types.DBType, builder func() SQLExecutor) {
-	executorSoltsXA[dt] = func() SQLExecutor {
+	xaExecutors[dt] = func() SQLExecutor {
 		return builder()
 	}
 }
@@ -66,7 +56,7 @@ type (
 	SQLExecutor interface {
 		Interceptors(interceptors []SQLHook)
 		ExecWithNamedValue(ctx context.Context, execCtx *types.ExecContext, f CallbackWithNamedValue) (types.ExecResult, error)
-		ExecWithValue(ctx context.Context, execCtx *types.ExecContext, f CallbackWithValue) (types.ExecResult, error)
+		ExecWithValue(ctx context.Context, execCtx *types.ExecContext, f CallbackWithNamedValue) (types.ExecResult, error)
 	}
 )
@@ -83,37 +73,14 @@ func BuildExecutor(dbType types.DBType, transactionMode types.TransactionMode, q
 	hooks = append(hooks, hookSolts[parseContext.SQLType]...)
 	if transactionMode == types.XAMode {
-		e := executorSoltsXA[dbType]()
-		e.Interceptors(hooks)
-		return e, nil
-	}
-	if transactionMode == types.ATMode {
-		e := executorSoltsAT[dbType][parseContext.ExecutorType]()
-		e.Interceptors(hooks)
-		return e, nil
-	}
-	factories, ok := executorSoltsAT[dbType]
-	if !ok {
-		log.Debugf("%s not found executor factories, return default Executor", dbType.String())
-		e := &BaseExecutor{}
+		e := xaExecutors[dbType]()
 		e.Interceptors(hooks)
 		return e, nil
 	}
-	supplier, ok := factories[parseContext.ExecutorType]
-	if !ok {
-		log.Debugf("%s not found executor for %s, return default Executor",
-			dbType.String(), parseContext.ExecutorType)
-		e := &BaseExecutor{}
-		e.Interceptors(hooks)
-		return e, nil
-	}
-	executor := supplier()
-	executor.Interceptors(hooks)
-	return executor, nil
+	e := atExecutors[dbType]()
+	e.Interceptors(hooks)
+	return e, nil
 }
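// Editor's sketch (not part of this patch): how a caller such as Conn.Query or
// Stmt.Query uses the registry after this change; execCtx and the callback body
// are assumptions for illustration.
//
//	executor, err := exec.BuildExecutor(types.DBTypeMySQL, types.ATMode, execCtx.Query)
//	if err != nil {
//		return nil, err
//	}
//	return executor.ExecWithNamedValue(context.Background(), execCtx,
//		func(ctx context.Context, query string, args []driver.NamedValue) (types.ExecResult, error) {
//			// delegate to the underlying driver connection here
//			return nil, nil
//		})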
 type BaseExecutor struct {
@@ -146,7 +113,7 @@ func (e *BaseExecutor) ExecWithNamedValue(ctx context.Context, execCtx *types.Ex
 	}
 // ExecWithValue
-func (e *BaseExecutor) ExecWithValue(ctx context.Context, execCtx *types.ExecContext, f CallbackWithValue) (types.ExecResult, error) {
+func (e *BaseExecutor) ExecWithValue(ctx context.Context, execCtx *types.ExecContext, f CallbackWithNamedValue) (types.ExecResult, error) {
 	for i := range e.hooks {
 		e.hooks[i].Before(ctx, execCtx)
 	}
@@ -161,5 +128,13 @@ func (e *BaseExecutor) ExecWithValue(ctx context.Context, execCtx *types.ExecCon
 		return e.ex.ExecWithValue(ctx, execCtx, f)
 	}
-	return f(ctx, execCtx.Query, execCtx.Values)
+	// wrap the plain values as named values; driver ordinals are 1-based
+	nvargs := make([]driver.NamedValue, 0, len(execCtx.Values))
+	for i, value := range execCtx.Values {
+		nvargs = append(nvargs, driver.NamedValue{
+			Value:   value,
+			Ordinal: i + 1,
+		})
+	}
+	return f(ctx, execCtx.Query, nvargs)
 }
@@ -19,6 +19,7 @@ package xa
 import (
 	"context"
+	"database/sql/driver"
 	"github.com/seata/seata-go/pkg/datasource/sql/exec"
 	"github.com/seata/seata-go/pkg/datasource/sql/types"
@@ -55,7 +56,7 @@ func (e *XAExecutor) ExecWithNamedValue(ctx context.Context, execCtx *types.Exec
 }
 // ExecWithValue
-func (e *XAExecutor) ExecWithValue(ctx context.Context, execCtx *types.ExecContext, f exec.CallbackWithValue) (types.ExecResult, error) {
+func (e *XAExecutor) ExecWithValue(ctx context.Context, execCtx *types.ExecContext, f exec.CallbackWithNamedValue) (types.ExecResult, error) {
 	for _, hook := range e.hooks {
 		hook.Before(ctx, execCtx)
 	}
@@ -70,5 +71,14 @@ func (e *XAExecutor) ExecWithValue(ctx context.Context, execCtx *types.ExecConte
 		return e.ex.ExecWithValue(ctx, execCtx, f)
 	}
-	return f(ctx, execCtx.Query, execCtx.Values)
+	// wrap the plain values as named values; driver ordinals are 1-based
+	nvargs := make([]driver.NamedValue, 0, len(execCtx.Values))
+	for i, value := range execCtx.Values {
+		nvargs = append(nvargs, driver.NamedValue{
+			Value:   value,
+			Ordinal: i + 1,
+		})
+	}
+	execCtx.NamedValues = nvargs
+	return f(ctx, execCtx.Query, execCtx.NamedValues)
 }
@@ -23,6 +23,7 @@ import (
 	"github.com/seata/seata-go/pkg/datasource/sql/exec"
 	"github.com/seata/seata-go/pkg/datasource/sql/types"
+	"github.com/seata/seata-go/pkg/datasource/sql/util"
 )
 type Stmt struct {
@@ -75,8 +76,8 @@ func (s *Stmt) Query(args []driver.Value) (driver.Rows, error) {
 	}
 	ret, err := executor.ExecWithValue(context.Background(), execCtx,
-		func(ctx context.Context, query string, args []driver.Value) (types.ExecResult, error) {
-			ret, err := s.stmt.Query(args)
+		func(ctx context.Context, query string, args []driver.NamedValue) (types.ExecResult, error) {
+			ret, err := s.stmt.Query(util.NamedValueToValue(args))
 			if err != nil {
 				return nil, err
 			}
@@ -144,8 +145,8 @@ func (s *Stmt) Exec(args []driver.Value) (driver.Result, error) {
 	}
 	ret, err := executor.ExecWithValue(context.Background(), execCtx,
-		func(ctx context.Context, query string, args []driver.Value) (types.ExecResult, error) {
-			ret, err := s.stmt.Exec(args)
+		func(ctx context.Context, query string, args []driver.NamedValue) (types.ExecResult, error) {
+			ret, err := s.stmt.Exec(util.NamedValueToValue(args))
 			if err != nil {
 				return nil, err
 			}
@@ -129,7 +129,8 @@ type RowImage struct {
 func (r *RowImage) GetColumnMap() map[string]*ColumnImage {
 	m := make(map[string]*ColumnImage, 0)
 	for _, column := range r.Columns {
-		m[column.ColumnName] = &column
+		tmpColumn := column
+		m[column.ColumnName] = &tmpColumn
 	}
 	return m
 }
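// Editor's note (not part of this patch): the copy into tmpColumn avoids the
// range-variable aliasing pitfall; before Go 1.22 the loop variable is a single
// reused variable, so storing &column would leave every map entry pointing at
// the last element. Minimal illustration:
//
//	vals := []int{1, 2, 3}
//	m := map[int]*int{}
//	for i, v := range vals {
//		m[i] = &v // without a per-iteration copy, *m[0] == *m[1] == *m[2] == 3
//	}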
@@ -245,7 +246,7 @@ func (c *ColumnImage) UnmarshalJSON(data []byte) error {
 	case JDBCTypeChar, JDBCTypeVarchar:
 		var val []byte
 		if val, err = base64.StdEncoding.DecodeString(value.(string)); err != nil {
-			return err
+			val = []byte(value.(string))
 		}
 		actualValue = string(val)
 	case JDBCTypeBinary, JDBCTypeVarBinary, JDBCTypeLongVarBinary, JDBCTypeBit:
@@ -0,0 +1,377 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// Type conversions for Scan.

package util

import (
	"database/sql"
	"database/sql/driver"
	"errors"
	"fmt"
	"reflect"
	"strconv"
	"time"
)
var errNilPtr = errors.New("destination pointer is nil") // embedded in descriptive error

// convertAssignRows copies to dest the value in src, converting it if possible.
// An error is returned if the copy would result in loss of information.
// dest should be a pointer type. If rows is passed in, the rows will
// be used as the parent for any cursor values converted from a
// driver.Rows to a *ScanRows.
func convertAssignRows(dest, src interface{}, rows *ScanRows) error {
	// Common cases, without reflect.
	switch s := src.(type) {
	case string:
		switch d := dest.(type) {
		case *string:
			if d == nil {
				return errNilPtr
			}
			*d = s
			return nil
		case *[]byte:
			if d == nil {
				return errNilPtr
			}
			*d = []byte(s)
			return nil
		case *sql.RawBytes:
			if d == nil {
				return errNilPtr
			}
			*d = append((*d)[:0], s...)
			return nil
		}
	case []byte:
		switch d := dest.(type) {
		case *string:
			if d == nil {
				return errNilPtr
			}
			*d = string(s)
			return nil
		case *interface{}:
			if d == nil {
				return errNilPtr
			}
			*d = cloneBytes(s)
			return nil
		case *[]byte:
			if d == nil {
				return errNilPtr
			}
			*d = cloneBytes(s)
			return nil
		case *sql.RawBytes:
			if d == nil {
				return errNilPtr
			}
			*d = s
			return nil
		}
	case time.Time:
		switch d := dest.(type) {
		case *time.Time:
			*d = s
			return nil
		case *string:
			*d = s.Format(time.RFC3339Nano)
			return nil
		case *[]byte:
			if d == nil {
				return errNilPtr
			}
			*d = []byte(s.Format(time.RFC3339Nano))
			return nil
		case *sql.RawBytes:
			if d == nil {
				return errNilPtr
			}
			*d = s.AppendFormat((*d)[:0], time.RFC3339Nano)
			return nil
		}
	case decimalDecompose:
		switch d := dest.(type) {
		case decimalCompose:
			return d.Compose(s.Decompose(nil))
		}
	case nil:
		switch d := dest.(type) {
		case *interface{}:
			if d == nil {
				return errNilPtr
			}
			*d = nil
			return nil
		case *[]byte:
			if d == nil {
				return errNilPtr
			}
			*d = nil
			return nil
		case *sql.RawBytes:
			if d == nil {
				return errNilPtr
			}
			*d = nil
			return nil
		}
	// The driver is returning a cursor the client may iterate over.
	case driver.Rows:
		switch d := dest.(type) {
		case *ScanRows:
			if d == nil {
				return errNilPtr
			}
			if rows == nil {
				return errors.New("invalid context to convert cursor rows, missing parent *ScanRows")
			}
			rows.closemu.Lock()
			*d = ScanRows{
				dc:          rows.dc,
				releaseConn: func(error) {},
				rowsi:       s,
			}
			// Chain the cancel function.
			parentCancel := rows.cancel
			rows.cancel = func() {
				// When ScanRows.cancel is called, the closemu will be locked as well.
				// So we can access rs.lasterr.
				d.close(rows.lasterr)
				if parentCancel != nil {
					parentCancel()
				}
			}
			rows.closemu.Unlock()
			return nil
		}
	}
	var sv reflect.Value

	switch d := dest.(type) {
	case *string:
		sv = reflect.ValueOf(src)
		switch sv.Kind() {
		case reflect.Bool,
			reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
			reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
			reflect.Float32, reflect.Float64:
			*d = asString(src)
			return nil
		}
	case *[]byte:
		sv = reflect.ValueOf(src)
		if b, ok := asBytes(nil, sv); ok {
			*d = b
			return nil
		}
	case *sql.RawBytes:
		sv = reflect.ValueOf(src)
		if b, ok := asBytes([]byte(*d)[:0], sv); ok {
			*d = sql.RawBytes(b)
			return nil
		}
	case *bool:
		bv, err := driver.Bool.ConvertValue(src)
		if err == nil {
			*d = bv.(bool)
		}
		return err
	case *interface{}:
		*d = src
		return nil
	}

	if scanner, ok := dest.(sql.Scanner); ok {
		return scanner.Scan(src)
	}

	dpv := reflect.ValueOf(dest)
	if dpv.Kind() != reflect.Ptr {
		return errors.New("destination not a pointer")
	}
	if dpv.IsNil() {
		return errNilPtr
	}

	if !sv.IsValid() {
		sv = reflect.ValueOf(src)
	}

	dv := reflect.Indirect(dpv)
	if sv.IsValid() && sv.Type().AssignableTo(dv.Type()) {
		switch b := src.(type) {
		case []byte:
			dv.Set(reflect.ValueOf(cloneBytes(b)))
		default:
			dv.Set(sv)
		}
		return nil
	}

	if dv.Kind() == sv.Kind() && sv.Type().ConvertibleTo(dv.Type()) {
		dv.Set(sv.Convert(dv.Type()))
		return nil
	}

	// The following conversions use a string value as an intermediate representation
	// to convert between various numeric types.
	//
	// This also allows scanning into user defined types such as "type Int int64".
	// For symmetry, also check for string destination types.
	switch dv.Kind() {
	case reflect.Ptr:
		if src == nil {
			dv.Set(reflect.Zero(dv.Type()))
			return nil
		}
		dv.Set(reflect.New(dv.Type().Elem()))
		return convertAssignRows(dv.Interface(), src, rows)
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		if src == nil {
			return fmt.Errorf("converting NULL to %s is unsupported", dv.Kind())
		}
		s := asString(src)
		i64, err := strconv.ParseInt(s, 10, dv.Type().Bits())
		if err != nil {
			err = strconvErr(err)
			return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err)
		}
		dv.SetInt(i64)
		return nil
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		if src == nil {
			return fmt.Errorf("converting NULL to %s is unsupported", dv.Kind())
		}
		s := asString(src)
		u64, err := strconv.ParseUint(s, 10, dv.Type().Bits())
		if err != nil {
			err = strconvErr(err)
			return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err)
		}
		dv.SetUint(u64)
		return nil
	case reflect.Float32, reflect.Float64:
		if src == nil {
			return fmt.Errorf("converting NULL to %s is unsupported", dv.Kind())
		}
		s := asString(src)
		f64, err := strconv.ParseFloat(s, dv.Type().Bits())
		if err != nil {
			err = strconvErr(err)
			return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err)
		}
		dv.SetFloat(f64)
		return nil
	case reflect.String:
		if src == nil {
			return fmt.Errorf("converting NULL to %s is unsupported", dv.Kind())
		}
		switch v := src.(type) {
		case string:
			dv.SetString(v)
			return nil
		case []byte:
			dv.SetString(string(v))
			return nil
		}
	}

	return fmt.Errorf("unsupported Scan, storing driver.Value type %T into type %T", src, dest)
}
func strconvErr(err error) error { | |||||
if ne, ok := err.(*strconv.NumError); ok { | |||||
return ne.Err | |||||
} | |||||
return err | |||||
} | |||||
func cloneBytes(b []byte) []byte { | |||||
if b == nil { | |||||
return nil | |||||
} | |||||
c := make([]byte, len(b)) | |||||
copy(c, b) | |||||
return c | |||||
} | |||||
func asString(src interface{}) string { | |||||
switch v := src.(type) { | |||||
case string: | |||||
return v | |||||
case []byte: | |||||
return string(v) | |||||
} | |||||
rv := reflect.ValueOf(src) | |||||
switch rv.Kind() { | |||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: | |||||
return strconv.FormatInt(rv.Int(), 10) | |||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: | |||||
return strconv.FormatUint(rv.Uint(), 10) | |||||
case reflect.Float64: | |||||
return strconv.FormatFloat(rv.Float(), 'g', -1, 64) | |||||
case reflect.Float32: | |||||
return strconv.FormatFloat(rv.Float(), 'g', -1, 32) | |||||
case reflect.Bool: | |||||
return strconv.FormatBool(rv.Bool()) | |||||
} | |||||
return fmt.Sprintf("%v", src) | |||||
} | |||||
func asBytes(buf []byte, rv reflect.Value) (b []byte, ok bool) { | |||||
switch rv.Kind() { | |||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: | |||||
return strconv.AppendInt(buf, rv.Int(), 10), true | |||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: | |||||
return strconv.AppendUint(buf, rv.Uint(), 10), true | |||||
case reflect.Float32: | |||||
return strconv.AppendFloat(buf, rv.Float(), 'g', -1, 32), true | |||||
case reflect.Float64: | |||||
return strconv.AppendFloat(buf, rv.Float(), 'g', -1, 64), true | |||||
case reflect.Bool: | |||||
return strconv.AppendBool(buf, rv.Bool()), true | |||||
case reflect.String: | |||||
s := rv.String() | |||||
return append(buf, s...), true | |||||
} | |||||
return | |||||
} | |||||
var valuerReflectType = reflect.TypeOf((*driver.Valuer)(nil)).Elem() | |||||
type decimalDecompose interface { | |||||
// Decompose returns the internal decimal state in parts. | |||||
// If the provided buf has sufficient capacity, buf may be returned as the coefficient with | |||||
// the value set and length set as appropriate. | |||||
Decompose(buf []byte) (form byte, negative bool, coefficient []byte, exponent int32) | |||||
} | |||||
type decimalCompose interface { | |||||
// Compose sets the internal decimal value from parts. If the value cannot be | |||||
// represented then an error should be returned. | |||||
Compose(form byte, negative bool, coefficient []byte, exponent int32) error | |||||
} |
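Note on the numeric branches above: the overflow checking that Scan advertises comes from parsing with the destination type's bit width (dv.Type().Bits()), and strconvErr unwraps the *strconv.NumError so only the underlying cause is reported. A minimal standalone sketch of that behavior (not part of this package):

```go
package main

import (
	"fmt"
	"strconv"
)

func main() {
	// "300" fits in 16 bits but not in 8, mirroring the uint16/uint8 example
	// in Scan's doc comment later in this diff: parsing with the destination's
	// bit width is what rejects overflowing values.
	if v, err := strconv.ParseUint("300", 10, 16); err == nil {
		fmt.Println("uint16 ok:", v) // uint16 ok: 300
	}
	if _, err := strconv.ParseUint("300", 10, 8); err != nil {
		// err is a *strconv.NumError wrapping strconv.ErrRange; the convert
		// code unwraps it via strconvErr before formatting its own error.
		fmt.Println("uint8 rejected:", err)
	}
}
```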
@@ -0,0 +1,123 @@ | |||||
/* | |||||
* Licensed to the Apache Software Foundation (ASF) under one or more | |||||
* contributor license agreements. See the NOTICE file distributed with | |||||
* this work for additional information regarding copyright ownership. | |||||
* The ASF licenses this file to You under the Apache License, Version 2.0 | |||||
* (the "License"); you may not use this file except in compliance with | |||||
* the License. You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
// Copyright 2016 The Go Authors. All rights reserved. | |||||
// Use of this source code is governed by a BSD-style | |||||
// license that can be found in the LICENSE file. | |||||
package util | |||||
import ( | |||||
"context" | |||||
"database/sql/driver" | |||||
"errors" | |||||
) | |||||
func ctxDriverPrepare(ctx context.Context, ci driver.Conn, query string) (driver.Stmt, error) { | |||||
if ciCtx, is := ci.(driver.ConnPrepareContext); is { | |||||
return ciCtx.PrepareContext(ctx, query) | |||||
} | |||||
si, err := ci.Prepare(query) | |||||
if err == nil { | |||||
select { | |||||
default: | |||||
case <-ctx.Done(): | |||||
si.Close() | |||||
return nil, ctx.Err() | |||||
} | |||||
} | |||||
return si, err | |||||
} | |||||
func ctxDriverExec(ctx context.Context, execerCtx driver.ExecerContext, execer driver.Execer, query string, nvdargs []driver.NamedValue) (driver.Result, error) { | |||||
if execerCtx != nil { | |||||
return execerCtx.ExecContext(ctx, query, nvdargs) | |||||
} | |||||
dargs, err := namedValueToValue(nvdargs) | |||||
if err != nil { | |||||
return nil, err | |||||
} | |||||
select { | |||||
default: | |||||
case <-ctx.Done(): | |||||
return nil, ctx.Err() | |||||
} | |||||
return execer.Exec(query, dargs) | |||||
} | |||||
func CtxDriverQuery(ctx context.Context, queryerCtx driver.QueryerContext, queryer driver.Queryer, query string, nvdargs []driver.NamedValue) (driver.Rows, error) { | |||||
if queryerCtx != nil { | |||||
return queryerCtx.QueryContext(ctx, query, nvdargs) | |||||
} | |||||
dargs, err := namedValueToValue(nvdargs) | |||||
if err != nil { | |||||
return nil, err | |||||
} | |||||
select { | |||||
default: | |||||
case <-ctx.Done(): | |||||
return nil, ctx.Err() | |||||
} | |||||
return queryer.Query(query, dargs) | |||||
} | |||||
func ctxDriverStmtExec(ctx context.Context, si driver.Stmt, nvdargs []driver.NamedValue) (driver.Result, error) { | |||||
if siCtx, is := si.(driver.StmtExecContext); is { | |||||
return siCtx.ExecContext(ctx, nvdargs) | |||||
} | |||||
dargs, err := namedValueToValue(nvdargs) | |||||
if err != nil { | |||||
return nil, err | |||||
} | |||||
select { | |||||
default: | |||||
case <-ctx.Done(): | |||||
return nil, ctx.Err() | |||||
} | |||||
return si.Exec(dargs) | |||||
} | |||||
func ctxDriverStmtQuery(ctx context.Context, si driver.Stmt, nvdargs []driver.NamedValue) (driver.Rows, error) { | |||||
if siCtx, is := si.(driver.StmtQueryContext); is { | |||||
return siCtx.QueryContext(ctx, nvdargs) | |||||
} | |||||
dargs, err := namedValueToValue(nvdargs) | |||||
if err != nil { | |||||
return nil, err | |||||
} | |||||
select { | |||||
default: | |||||
case <-ctx.Done(): | |||||
return nil, ctx.Err() | |||||
} | |||||
return si.Query(dargs) | |||||
} | |||||
func namedValueToValue(named []driver.NamedValue) ([]driver.Value, error) { | |||||
dargs := make([]driver.Value, len(named)) | |||||
for n, param := range named { | |||||
if len(param.Name) > 0 { | |||||
return nil, errors.New("sql: driver does not support the use of Named Parameters") | |||||
} | |||||
dargs[n] = param.Value | |||||
} | |||||
return dargs, nil | |||||
} |
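All of these fallbacks share one pattern: prefer the driver's *Context interface when it is implemented, otherwise do a non-blocking check of the context before issuing the legacy call. A self-contained sketch of just that non-blocking check (checkCancelled is a hypothetical helper, not part of this file):

```go
package main

import (
	"context"
	"errors"
	"fmt"
)

// checkCancelled mirrors the "select { default: case <-ctx.Done(): }" shape
// used above: it returns ctx.Err() immediately if the context is already done,
// and otherwise falls through without blocking.
func checkCancelled(ctx context.Context) error {
	select {
	default:
	case <-ctx.Done():
		return ctx.Err()
	}
	return nil
}

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	fmt.Println(checkCancelled(ctx)) // <nil>: not cancelled, the legacy call may proceed
	cancel()
	fmt.Println(errors.Is(checkCancelled(ctx), context.Canceled)) // true
}
```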
@@ -0,0 +1,39 @@ | |||||
/* | |||||
* Licensed to the Apache Software Foundation (ASF) under one or more | |||||
* contributor license agreements. See the NOTICE file distributed with | |||||
* this work for additional information regarding copyright ownership. | |||||
* The ASF licenses this file to You under the Apache License, Version 2.0 | |||||
* (the "License"); you may not use this file except in compliance with | |||||
* the License. You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
package util | |||||
import "database/sql/driver" | |||||
func NamedValueToValue(nvs []driver.NamedValue) []driver.Value { | |||||
vs := make([]driver.Value, 0, len(nvs)) | |||||
for _, nv := range nvs { | |||||
vs = append(vs, nv.Value) | |||||
} | |||||
return vs | |||||
} | |||||
func ValueToNamedValue(vs []driver.Value) []driver.NamedValue { | |||||
nvs := make([]driver.NamedValue, 0, len(vs)) | |||||
for i, v := range vs { | |||||
nvs = append(nvs, driver.NamedValue{ | |||||
Value: v, | |||||
// driver.NamedValue.Ordinal is 1-based, so offset the slice index. | |||||
Ordinal: i + 1, | |||||
}) | |||||
} | |||||
return nvs | |||||
} |
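A short usage sketch of the round trip between the two helpers (values are made up; the import path is the util package added in this diff):

```go
package main

import (
	"database/sql/driver"
	"fmt"

	"github.com/seata/seata-go/pkg/datasource/sql/util"
)

func main() {
	vals := []driver.Value{int64(1), "jack"}

	// Wrap plain values so they can be handed to the *Context driver APIs;
	// ordinals are assigned starting from one.
	named := util.ValueToNamedValue(vals)
	for _, nv := range named {
		fmt.Printf("ordinal=%d value=%v\n", nv.Ordinal, nv.Value)
	}

	// Strip the ordinals again for legacy driver.Queryer/Execer calls.
	fmt.Println(util.NamedValueToValue(named)) // [1 jack]
}
```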
@@ -0,0 +1,355 @@ | |||||
/* | |||||
* Licensed to the Apache Software Foundation (ASF) under one or more | |||||
* contributor license agreements. See the NOTICE file distributed with | |||||
* this work for additional information regarding copyright ownership. | |||||
* The ASF licenses this file to You under the Apache License, Version 2.0 | |||||
* (the "License"); you may not use this file except in compliance with | |||||
* the License. You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
// Copyright 2011 The Go Authors. All rights reserved. | |||||
// Use of this source code is governed by a BSD-style | |||||
// license that can be found in the LICENSE file. | |||||
// Package sql provides a generic interface around SQL (or SQL-like) | |||||
// databases. | |||||
// | |||||
// The sql package must be used in conjunction with a database driver. | |||||
// See https://golang.org/s/sqldrivers for a list of drivers. | |||||
// | |||||
// Drivers that do not support context cancellation will not return until | |||||
// after the query is completed. | |||||
// | |||||
// For usage examples, see the wiki page at | |||||
// https://golang.org/s/sqlwiki. | |||||
package util | |||||
import ( | |||||
"context" | |||||
"database/sql/driver" | |||||
"errors" | |||||
"fmt" | |||||
"io" | |||||
"sync" | |||||
) | |||||
// ScanRows is the result of a query. Its cursor starts before the first row | |||||
// of the result set. Use Next to advance from row to row. | |||||
type ScanRows struct { | |||||
//dc *driverConn // owned; must call releaseConn when closed to release | |||||
dc driver.Conn | |||||
releaseConn func(error) | |||||
rowsi driver.Rows | |||||
cancel func() // called when ScanRows is closed, may be nil. | |||||
//closeStmt *driverStmt // if non-nil, statement to Close on close | |||||
// closemu prevents ScanRows from closing while there | |||||
// is an active streaming result. It is held for read during non-close operations | |||||
// and exclusively during close. | |||||
// | |||||
// closemu guards lasterr and closed. | |||||
closemu sync.RWMutex | |||||
closed bool | |||||
lasterr error // non-nil only if closed is true | |||||
// lastcols is only used in Scan, Next, and NextResultSet which are expected | |||||
// not to be called concurrently. | |||||
lastcols []driver.Value | |||||
} | |||||
func NewScanRows(rowsi driver.Rows) *ScanRows { | |||||
return &ScanRows{rowsi: rowsi} | |||||
} | |||||
// lasterrOrErrLocked returns either lasterr or the provided err. | |||||
// rs.closemu must be read-locked. | |||||
func (rs *ScanRows) lasterrOrErrLocked(err error) error { | |||||
if rs.lasterr != nil && rs.lasterr != io.EOF { | |||||
return rs.lasterr | |||||
} | |||||
return err | |||||
} | |||||
// bypassRowsAwaitDone is only used for testing. | |||||
// If true, it will not close the ScanRows automatically from the context. | |||||
var bypassRowsAwaitDone = false | |||||
func (rs *ScanRows) initContextClose(ctx, txctx context.Context) { | |||||
if ctx.Done() == nil && (txctx == nil || txctx.Done() == nil) { | |||||
return | |||||
} | |||||
if bypassRowsAwaitDone { | |||||
return | |||||
} | |||||
ctx, rs.cancel = context.WithCancel(ctx) | |||||
go rs.awaitDone(ctx, txctx) | |||||
} | |||||
// awaitDone blocks until either ctx or txctx is canceled. The ctx is provided | |||||
// from the query context and is canceled when the query ScanRows is closed. | |||||
// If the query was issued in a transaction, the transaction's context | |||||
// is also provided in txctx to ensure ScanRows is closed if the Tx is closed. | |||||
func (rs *ScanRows) awaitDone(ctx, txctx context.Context) { | |||||
var txctxDone <-chan struct{} | |||||
if txctx != nil { | |||||
txctxDone = txctx.Done() | |||||
} | |||||
select { | |||||
case <-ctx.Done(): | |||||
case <-txctxDone: | |||||
} | |||||
rs.close(ctx.Err()) | |||||
} | |||||
// Next prepares the next result row for reading with the Scan method. It | |||||
// returns true on success, or false if there is no next result row or an error | |||||
// happened while preparing it. Err should be consulted to distinguish between | |||||
// the two cases. | |||||
// | |||||
// Every call to Scan, even the first one, must be preceded by a call to Next. | |||||
func (rs *ScanRows) Next() bool { | |||||
var doClose, ok bool | |||||
withLock(rs.closemu.RLocker(), func() { | |||||
doClose, ok = rs.nextLocked() | |||||
}) | |||||
if doClose { | |||||
rs.Close() | |||||
} | |||||
return ok | |||||
} | |||||
func (rs *ScanRows) nextLocked() (doClose, ok bool) { | |||||
if rs.closed { | |||||
return false, false | |||||
} | |||||
// Lock the driver connection before calling the driver interface | |||||
// rowsi to prevent a Tx from rolling back the connection at the same time. | |||||
//rs.dc.Lock() | |||||
//defer rs.dc.Unlock() | |||||
if rs.lastcols == nil { | |||||
rs.lastcols = make([]driver.Value, len(rs.rowsi.Columns())) | |||||
} | |||||
rs.lasterr = rs.rowsi.Next(rs.lastcols) | |||||
if rs.lasterr != nil { | |||||
// Close the connection if there is a driver error. | |||||
if rs.lasterr != io.EOF { | |||||
return true, false | |||||
} | |||||
nextResultSet, ok := rs.rowsi.(driver.RowsNextResultSet) | |||||
if !ok { | |||||
return true, false | |||||
} | |||||
// The driver is at the end of the current result set. | |||||
// Test to see if there is another result set after the current one. | |||||
// Only close ScanRows if there is no further result sets to read. | |||||
if !nextResultSet.HasNextResultSet() { | |||||
doClose = true | |||||
} | |||||
return doClose, false | |||||
} | |||||
return false, true | |||||
} | |||||
// NextResultSet prepares the next result set for reading. It reports whether | |||||
// there are further result sets, or false if there is no further result set | |||||
// or if there is an error advancing to it. The Err method should be consulted | |||||
// to distinguish between the two cases. | |||||
// | |||||
// After calling NextResultSet, the Next method should always be called before | |||||
// scanning. If there are further result sets they may not have rows in the result | |||||
// set. | |||||
func (rs *ScanRows) NextResultSet() bool { | |||||
var doClose bool | |||||
defer func() { | |||||
if doClose { | |||||
rs.Close() | |||||
} | |||||
}() | |||||
rs.closemu.RLock() | |||||
defer rs.closemu.RUnlock() | |||||
if rs.closed { | |||||
return false | |||||
} | |||||
rs.lastcols = nil | |||||
nextResultSet, ok := rs.rowsi.(driver.RowsNextResultSet) | |||||
if !ok { | |||||
doClose = true | |||||
return false | |||||
} | |||||
// Lock the driver connection before calling the driver interface | |||||
// rowsi to prevent a Tx from rolling back the connection at the same time. | |||||
//rs.dc.Lock() | |||||
//defer rs.dc.Unlock() | |||||
rs.lasterr = nextResultSet.NextResultSet() | |||||
if rs.lasterr != nil { | |||||
doClose = true | |||||
return false | |||||
} | |||||
return true | |||||
} | |||||
// Err returns the error, if any, that was encountered during iteration. | |||||
// Err may be called after an explicit or implicit Close. | |||||
func (rs *ScanRows) Err() error { | |||||
rs.closemu.RLock() | |||||
defer rs.closemu.RUnlock() | |||||
return rs.lasterrOrErrLocked(nil) | |||||
} | |||||
var errRowsClosed = errors.New("sql: ScanRows are closed") | |||||
// Scan copies the columns in the current row into the values pointed | |||||
// at by dest. The number of values in dest must be the same as the | |||||
// number of columns in ScanRows. | |||||
// | |||||
// Scan converts columns read from the database into the following | |||||
// common Go types and special types provided by the sql package: | |||||
// | |||||
// *string | |||||
// *[]byte | |||||
// *int, *int8, *int16, *int32, *int64 | |||||
// *uint, *uint8, *uint16, *uint32, *uint64 | |||||
// *bool | |||||
// *float32, *float64 | |||||
// *interface{} | |||||
// *RawBytes | |||||
// *ScanRows (cursor value) | |||||
// any type implementing Scanner (see Scanner docs) | |||||
// | |||||
// In the most simple case, if the type of the value from the source | |||||
// column is an integer, bool or string type T and dest is of type *T, | |||||
// Scan simply assigns the value through the pointer. | |||||
// | |||||
// Scan also converts between string and numeric types, as long as no | |||||
// information would be lost. While Scan stringifies all numbers | |||||
// scanned from numeric database columns into *string, scans into | |||||
// numeric types are checked for overflow. For example, a float64 with | |||||
// value 300 or a string with value "300" can scan into a uint16, but | |||||
// not into a uint8, though float64(255) or "255" can scan into a | |||||
// uint8. One exception is that scans of some float64 numbers to | |||||
// strings may lose information when stringifying. In general, scan | |||||
// floating point columns into *float64. | |||||
// | |||||
// If a dest argument has type *[]byte, Scan saves in that argument a | |||||
// copy of the corresponding data. The copy is owned by the caller and | |||||
// can be modified and held indefinitely. The copy can be avoided by | |||||
// using an argument of type *RawBytes instead; see the documentation | |||||
// for RawBytes for restrictions on its use. | |||||
// | |||||
// If an argument has type *interface{}, Scan copies the value | |||||
// provided by the underlying driver without conversion. When scanning | |||||
// from a source value of type []byte to *interface{}, a copy of the | |||||
// slice is made and the caller owns the result. | |||||
// | |||||
// Source values of type time.Time may be scanned into values of type | |||||
// *time.Time, *interface{}, *string, or *[]byte. When converting to | |||||
// the latter two, time.RFC3339Nano is used. | |||||
// | |||||
// Source values of type bool may be scanned into types *bool, | |||||
// *interface{}, *string, *[]byte, or *RawBytes. | |||||
// | |||||
// For scanning into *bool, the source may be true, false, 1, 0, or | |||||
// string inputs parseable by strconv.ParseBool. | |||||
// | |||||
// Scan can also convert a cursor returned from a query, such as | |||||
// "select cursor(select * from my_table) from dual", into a | |||||
// *ScanRows value that can itself be scanned from. The parent | |||||
// select query will close any cursor *ScanRows if the parent *ScanRows is closed. | |||||
// | |||||
// If any of the first arguments implementing Scanner returns an error, | |||||
// that error will be wrapped in the returned error. | |||||
func (rs *ScanRows) Scan(dest ...interface{}) error { | |||||
rs.closemu.RLock() | |||||
if rs.lasterr != nil && rs.lasterr != io.EOF { | |||||
rs.closemu.RUnlock() | |||||
return rs.lasterr | |||||
} | |||||
if rs.closed { | |||||
err := rs.lasterrOrErrLocked(errRowsClosed) | |||||
rs.closemu.RUnlock() | |||||
return err | |||||
} | |||||
rs.closemu.RUnlock() | |||||
if rs.lastcols == nil { | |||||
return errors.New("sql: Scan called without calling Next") | |||||
} | |||||
if len(dest) != len(rs.lastcols) { | |||||
return fmt.Errorf("sql: expected %d destination arguments in Scan, not %d", len(rs.lastcols), len(dest)) | |||||
} | |||||
for i, sv := range rs.lastcols { | |||||
err := convertAssignRows(dest[i], sv, rs) | |||||
if err != nil { | |||||
return fmt.Errorf(`sql: Scan error on column index %d, name %q: %w`, i, rs.rowsi.Columns()[i], err) | |||||
} | |||||
} | |||||
return nil | |||||
} | |||||
// rowsCloseHook returns a function so tests may install the | |||||
// hook through a test only mutex. | |||||
var rowsCloseHook = func() func(*ScanRows, *error) { return nil } | |||||
// Close is a no-op here: ScanRows does not own the underlying driver.Rows, so | |||||
// closing it is left to whoever created that driver.Rows. The internal close | |||||
// path below is only triggered from awaitDone when the query or transaction | |||||
// context is canceled. Close is idempotent and does not affect the result of Err. | |||||
func (rs *ScanRows) Close() error { | |||||
return nil | |||||
} | |||||
func (rs *ScanRows) close(err error) error { | |||||
rs.closemu.Lock() | |||||
defer rs.closemu.Unlock() | |||||
if rs.closed { | |||||
return nil | |||||
} | |||||
rs.closed = true | |||||
if rs.lasterr == nil { | |||||
rs.lasterr = err | |||||
} | |||||
err = rs.rowsi.Close() | |||||
//withLock(rs.dc, func() { | |||||
// err = rs.rowsi.Close() | |||||
//}) | |||||
if fn := rowsCloseHook(); fn != nil { | |||||
fn(rs, &err) | |||||
} | |||||
if rs.cancel != nil { | |||||
rs.cancel() | |||||
} | |||||
//if rs.closeStmt != nil { | |||||
// rs.closeStmt.Close() | |||||
//} | |||||
// releaseConn is never set by NewScanRows, so guard against a nil callback. | |||||
if rs.releaseConn != nil { | |||||
rs.releaseConn(err) | |||||
} | |||||
return err | |||||
} | |||||
// withLock runs while holding lk. | |||||
func withLock(lk sync.Locker, fn func()) { | |||||
lk.Lock() | |||||
defer lk.Unlock() // in case fn panics | |||||
fn() | |||||
} |
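A usage sketch of the Next/Scan loop (stubRows is a hypothetical in-memory driver.Rows used only for illustration; real callers pass the driver.Rows returned by a query):

```go
package main

import (
	"database/sql/driver"
	"fmt"
	"io"

	"github.com/seata/seata-go/pkg/datasource/sql/util"
)

// stubRows is a minimal driver.Rows serving two fixed rows.
type stubRows struct {
	data [][]driver.Value
	pos  int
}

func (r *stubRows) Columns() []string { return []string{"id", "name"} }
func (r *stubRows) Close() error      { return nil }
func (r *stubRows) Next(dest []driver.Value) error {
	if r.pos >= len(r.data) {
		return io.EOF
	}
	copy(dest, r.data[r.pos])
	r.pos++
	return nil
}

func main() {
	rows := util.NewScanRows(&stubRows{data: [][]driver.Value{
		{int64(1), "jack"},
		{int64(2), "lily"},
	}})
	for rows.Next() {
		var id int64
		var name string
		if err := rows.Scan(&id, &name); err != nil {
			panic(err)
		}
		fmt.Println(id, name)
	}
	if err := rows.Err(); err != nil {
		panic(err)
	}
}
```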
@@ -85,6 +85,6 @@ func (f *rmBranchRollbackProcessor) Process(ctx context.Context, rpcMessage mess | |||||
log.Errorf("send branch rollback response error: {%#v}", err.Error()) | log.Errorf("send branch rollback response error: {%#v}", err.Error()) | ||||
return err | return err | ||||
} | } | ||||
log.Infof("send branch rollback response success: xid %s, branchID %s, resourceID %s, applicationData %s", xid, branchID, resourceID, applicationData) | |||||
log.Infof("send branch rollback response success: xid %s, branchID %d, resourceID %s, applicationData %s", xid, branchID, resourceID, applicationData) | |||||
return nil | return nil | ||||
} | } |
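For context on the %s → %d change above: the branch id is numeric (an int64 in the branch transaction messages, as the switch to %d implies), and with %s fmt flags the verb mismatch instead of printing the id. A tiny sketch:

```go
package main

import "fmt"

func main() {
	var branchID int64 = 2008502699
	// Old verb: fmt reports the mismatch inline in the log line.
	fmt.Printf("branchID %s\n", branchID) // branchID %!s(int64=2008502699)
	// New verb: the id is rendered as intended.
	fmt.Printf("branchID %d\n", branchID) // branchID 2008502699
}
```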
@@ -31,3 +31,16 @@ func SetUnexportedField(field reflect.Value, value interface{}) { | |||||
func GetUnexportedField(field reflect.Value) interface{} { | func GetUnexportedField(field reflect.Value) interface{} { | ||||
return reflect.NewAt(field.Type(), unsafe.Pointer(field.UnsafeAddr())).Elem().Interface() | return reflect.NewAt(field.Type(), unsafe.Pointer(field.UnsafeAddr())).Elem().Interface() | ||||
} | } | ||||
// GetElemDataValue returns the value pointed to when data is a non-nil | |||||
// pointer; otherwise it returns data unchanged. | |||||
func GetElemDataValue(data interface{}) interface{} { | |||||
if data == nil { | |||||
return data | |||||
} | |||||
value := reflect.ValueOf(data) | |||||
if value.Kind() == reflect.Ptr { | |||||
// A typed nil pointer has no element to unwrap; return it as-is | |||||
// instead of panicking on Elem().Interface(). | |||||
if value.IsNil() { | |||||
return data | |||||
} | |||||
return value.Elem().Interface() | |||||
} | |||||
return data | |||||
} | |||||
@@ -0,0 +1,48 @@ | |||||
/* | |||||
* Licensed to the Apache Software Foundation (ASF) under one or more | |||||
* contributor license agreements. See the NOTICE file distributed with | |||||
* this work for additional information regarding copyright ownership. | |||||
* The ASF licenses this file to You under the Apache License, Version 2.0 | |||||
* (the "License"); you may not use this file except in compliance with | |||||
* the License. You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
package reflectx | |||||
import ( | |||||
"testing" | |||||
) | |||||
func TestGetElemDataValue(t *testing.T) { | |||||
var aa = 10 | |||||
var bb = "name" | |||||
var cc bool | |||||
tests := []struct { | |||||
name string | |||||
args interface{} | |||||
want interface{} | |||||
}{ | |||||
{name: "test1", args: aa, want: aa}, | |||||
{name: "test2", args: &aa, want: aa}, | |||||
{name: "test3", args: bb, want: bb}, | |||||
{name: "test4", args: &bb, want: bb}, | |||||
{name: "test5", args: cc, want: cc}, | |||||
{name: "test6", args: &cc, want: cc}, | |||||
} | |||||
for _, tt := range tests { | |||||
t.Run(tt.name, func(t *testing.T) { | |||||
if got := GetElemDataValue(tt.args); got != tt.want { | |||||
t.Errorf("GetElemDataValue() = %v, want %v", got, tt.want) | |||||
} | |||||
}) | |||||
} | |||||
} |