@@ -15,7 +15,6 @@ require ( | |||
github.com/golang/mock v1.6.0 | |||
github.com/google/uuid v1.3.0 | |||
github.com/knadh/koanf v1.4.3 | |||
github.com/mitchellh/copystructure v1.2.0 | |||
github.com/natefinch/lumberjack v2.0.0+incompatible | |||
github.com/parnurzeal/gorequest v0.2.16 | |||
github.com/pierrec/lz4/v4 v4.1.17 | |||
@@ -88,6 +87,7 @@ require ( | |||
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect | |||
github.com/magiconair/properties v1.8.6 // indirect | |||
github.com/matttproud/golang_protobuf_extensions v1.0.1 // indirect | |||
github.com/mitchellh/copystructure v1.2.0 // indirect | |||
github.com/mitchellh/go-homedir v1.1.0 // indirect | |||
github.com/mitchellh/mapstructure v1.5.0 // indirect | |||
github.com/mitchellh/reflectwalk v1.0.2 // indirect | |||
@@ -32,8 +32,8 @@ import ( | |||
"github.com/knadh/koanf/parsers/yaml" | |||
"github.com/knadh/koanf/providers/rawbytes" | |||
"github.com/seata/seata-go/pkg/remoting/getty" | |||
"github.com/seata/seata-go/pkg/tm" | |||
"github.com/seata/seata-go/pkg/rm/tcc" | |||
"github.com/seata/seata-go/pkg/tm" | |||
"github.com/seata/seata-go/pkg/util/flagext" | |||
) | |||
@@ -23,6 +23,7 @@ import ( | |||
"github.com/seata/seata-go/pkg/datasource/sql/exec" | |||
"github.com/seata/seata-go/pkg/datasource/sql/types" | |||
"github.com/seata/seata-go/pkg/datasource/sql/util" | |||
) | |||
// Conn is a connection to a database. It is not used concurrently | |||
@@ -146,8 +147,8 @@ func (c *Conn) Query(query string, args []driver.Value) (driver.Rows, error) { | |||
} | |||
ret, err := executor.ExecWithValue(context.Background(), execCtx, | |||
func(ctx context.Context, query string, args []driver.Value) (types.ExecResult, error) { | |||
ret, err := conn.Query(query, args) | |||
func(ctx context.Context, query string, args []driver.NamedValue) (types.ExecResult, error) { | |||
ret, err := conn.Query(query, util.NamedValueToValue(args)) | |||
if err != nil { | |||
return nil, err | |||
} | |||
@@ -26,14 +26,6 @@ import ( | |||
"github.com/stretchr/testify/assert" | |||
) | |||
func TestConn_BuildATExecutor(t *testing.T) { | |||
executor, err := exec.BuildExecutor(types.DBTypeMySQL, types.ATMode, "SELECT * FROM user") | |||
assert.NoError(t, err) | |||
_, ok := executor.(*exec.BaseExecutor) | |||
assert.True(t, ok, "need base executor") | |||
} | |||
func TestConn_BuildXAExecutor(t *testing.T) { | |||
executor, err := exec.BuildExecutor(types.DBTypeMySQL, types.XAMode, "SELECT * FROM user") | |||
@@ -0,0 +1,296 @@ | |||
/* | |||
* Licensed to the Apache Software Foundation (ASF) under one or more | |||
* contributor license agreements. See the NOTICE file distributed with | |||
* this work for additional information regarding copyright ownership. | |||
* The ASF licenses this file to You under the Apache License, Version 2.0 | |||
* (the "License"); you may not use this file except in compliance with | |||
* the License. You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
package at | |||
import ( | |||
"bytes" | |||
"context" | |||
"database/sql" | |||
"database/sql/driver" | |||
"fmt" | |||
"strings" | |||
"github.com/arana-db/parser/ast" | |||
"github.com/arana-db/parser/test_driver" | |||
gxsort "github.com/dubbogo/gost/sort" | |||
"github.com/seata/seata-go/pkg/datasource/sql/exec" | |||
"github.com/seata/seata-go/pkg/datasource/sql/types" | |||
"github.com/seata/seata-go/pkg/datasource/sql/util" | |||
"github.com/seata/seata-go/pkg/util/reflectx" | |||
) | |||
type baseExecutor struct { | |||
hooks []exec.SQLHook | |||
} | |||
func (b *baseExecutor) beforeHooks(ctx context.Context, execCtx *types.ExecContext) { | |||
for _, hook := range b.hooks { | |||
hook.Before(ctx, execCtx) | |||
} | |||
} | |||
func (b *baseExecutor) afterHooks(ctx context.Context, execCtx *types.ExecContext) { | |||
for _, hook := range b.hooks { | |||
hook.After(ctx, execCtx) | |||
} | |||
} | |||
// GetScanSlice get the column type for scann | |||
// todo to use ColumnInfo get slice | |||
func (*baseExecutor) GetScanSlice(columnNames []string, tableMeta *types.TableMeta) []interface{} { | |||
scanSlice := make([]interface{}, 0, len(columnNames)) | |||
for _, columnNmae := range columnNames { | |||
var ( | |||
// 从metData获取该列的元信息 | |||
columnMeta = tableMeta.Columns[columnNmae] | |||
) | |||
switch strings.ToUpper(columnMeta.DatabaseTypeString) { | |||
case "VARCHAR", "NVARCHAR", "VARCHAR2", "CHAR", "TEXT", "JSON", "TINYTEXT": | |||
var scanVal string | |||
scanSlice = append(scanSlice, &scanVal) | |||
case "BIT", "INT", "LONGBLOB", "SMALLINT", "TINYINT", "BIGINT", "MEDIUMINT": | |||
if columnMeta.IsNullable == 0 { | |||
scanVal := int64(0) | |||
scanSlice = append(scanSlice, &scanVal) | |||
} else { | |||
scanVal := sql.NullInt64{} | |||
scanSlice = append(scanSlice, &scanVal) | |||
} | |||
case "DATE", "DATETIME", "TIME", "TIMESTAMP", "YEAR": | |||
scanVal := sql.NullTime{} | |||
scanSlice = append(scanSlice, &scanVal) | |||
case "DECIMAL", "DOUBLE", "FLOAT": | |||
if columnMeta.IsNullable == 0 { | |||
scanVal := float64(0) | |||
scanSlice = append(scanSlice, &scanVal) | |||
} else { | |||
scanVal := sql.NullFloat64{} | |||
scanSlice = append(scanSlice, &scanVal) | |||
} | |||
default: | |||
scanVal := sql.RawBytes{} | |||
scanSlice = append(scanSlice, &scanVal) | |||
} | |||
} | |||
return scanSlice | |||
} | |||
func (b *baseExecutor) buildSelectArgs(stmt *ast.SelectStmt, args []driver.NamedValue) []driver.NamedValue { | |||
var ( | |||
selectArgsIndexs = make([]int32, 0) | |||
selectArgs = make([]driver.NamedValue, 0) | |||
) | |||
b.traversalArgs(stmt.Where, &selectArgsIndexs) | |||
if stmt.OrderBy != nil { | |||
for _, item := range stmt.OrderBy.Items { | |||
b.traversalArgs(item, &selectArgsIndexs) | |||
} | |||
} | |||
if stmt.Limit != nil { | |||
if stmt.Limit.Offset != nil { | |||
b.traversalArgs(stmt.Limit.Offset, &selectArgsIndexs) | |||
} | |||
if stmt.Limit.Count != nil { | |||
b.traversalArgs(stmt.Limit.Count, &selectArgsIndexs) | |||
} | |||
} | |||
// sort selectArgs index array | |||
gxsort.Int32(selectArgsIndexs) | |||
for _, index := range selectArgsIndexs { | |||
selectArgs = append(selectArgs, args[index]) | |||
} | |||
return selectArgs | |||
} | |||
// todo perfect all sql operation | |||
func (b *baseExecutor) traversalArgs(node ast.Node, argsIndex *[]int32) { | |||
if node == nil { | |||
return | |||
} | |||
switch node.(type) { | |||
case *ast.BinaryOperationExpr: | |||
expr := node.(*ast.BinaryOperationExpr) | |||
b.traversalArgs(expr.L, argsIndex) | |||
b.traversalArgs(expr.R, argsIndex) | |||
break | |||
case *ast.BetweenExpr: | |||
expr := node.(*ast.BetweenExpr) | |||
b.traversalArgs(expr.Left, argsIndex) | |||
b.traversalArgs(expr.Right, argsIndex) | |||
break | |||
case *ast.PatternInExpr: | |||
exprs := node.(*ast.PatternInExpr).List | |||
for i := 0; i < len(exprs); i++ { | |||
b.traversalArgs(exprs[i], argsIndex) | |||
} | |||
break | |||
case *test_driver.ParamMarkerExpr: | |||
*argsIndex = append(*argsIndex, int32(node.(*test_driver.ParamMarkerExpr).Order)) | |||
break | |||
} | |||
} | |||
func (b *baseExecutor) buildRecordImages(rowsi driver.Rows, tableMetaData *types.TableMeta) (*types.RecordImage, error) { | |||
// select column names | |||
columnNames := rowsi.Columns() | |||
rowImages := make([]types.RowImage, 0) | |||
sqlRows := util.NewScanRows(rowsi) | |||
for sqlRows.Next() { | |||
ss := b.GetScanSlice(columnNames, tableMetaData) | |||
err := sqlRows.Scan(ss...) | |||
if err != nil { | |||
return nil, err | |||
} | |||
columns := make([]types.ColumnImage, 0) | |||
// build record image | |||
for i, name := range columnNames { | |||
columnMeta := tableMetaData.Columns[name] | |||
keyType := types.IndexTypeNull | |||
if _, ok := tableMetaData.GetPrimaryKeyMap()[name]; ok { | |||
keyType = types.IndexTypePrimaryKey | |||
} | |||
jdbcType := types.MySQLStrToJavaType(columnMeta.DatabaseTypeString) | |||
columns = append(columns, types.ColumnImage{ | |||
KeyType: keyType, | |||
ColumnName: name, | |||
ColumnType: jdbcType, | |||
Value: reflectx.GetElemDataValue(ss[i]), | |||
}) | |||
} | |||
rowImages = append(rowImages, types.RowImage{Columns: columns}) | |||
} | |||
return &types.RecordImage{TableName: tableMetaData.TableName, Rows: rowImages}, nil | |||
} | |||
// buildWhereConditionByPKs build where condition by primary keys | |||
// each pk is a condition.the result will like :" (id,userCode) in ((?,?),(?,?)) or (id,userCode) in ((?,?),(?,?) ) or (id,userCode) in ((?,?))" | |||
func (b *baseExecutor) buildWhereConditionByPKs(pkNameList []string, rowSize int, dbType string, maxInSize int) string { | |||
var ( | |||
whereStr = &strings.Builder{} | |||
batchSize = rowSize/maxInSize + 1 | |||
) | |||
if rowSize%maxInSize == 0 { | |||
batchSize = rowSize / maxInSize | |||
} | |||
for batch := 0; batch < batchSize; batch++ { | |||
if batch > 0 { | |||
whereStr.WriteString(" OR ") | |||
} | |||
whereStr.WriteString("(") | |||
for i := 0; i < len(pkNameList); i++ { | |||
if i > 0 { | |||
whereStr.WriteString(",") | |||
} | |||
// todo add escape | |||
whereStr.WriteString(fmt.Sprintf("`%s`", pkNameList[i])) | |||
} | |||
whereStr.WriteString(") IN (") | |||
var eachSize int | |||
if batch == batchSize-1 { | |||
if rowSize%maxInSize == 0 { | |||
eachSize = maxInSize | |||
} else { | |||
eachSize = rowSize % maxInSize | |||
} | |||
} else { | |||
eachSize = maxInSize | |||
} | |||
for i := 0; i < eachSize; i++ { | |||
if i > 0 { | |||
whereStr.WriteString(",") | |||
} | |||
whereStr.WriteString("(") | |||
for j := 0; j < len(pkNameList); j++ { | |||
if j > 0 { | |||
whereStr.WriteString(",") | |||
} | |||
whereStr.WriteString("?") | |||
} | |||
whereStr.WriteString(")") | |||
} | |||
whereStr.WriteString(")") | |||
} | |||
return whereStr.String() | |||
} | |||
func (b *baseExecutor) buildPKParams(rows []types.RowImage, pkNameList []string) []driver.NamedValue { | |||
params := make([]driver.NamedValue, 0) | |||
for _, row := range rows { | |||
coumnMap := row.GetColumnMap() | |||
for i, pk := range pkNameList { | |||
if col, ok := coumnMap[pk]; ok { | |||
params = append(params, driver.NamedValue{ | |||
Ordinal: i, Value: col.Value, | |||
}) | |||
} | |||
} | |||
} | |||
return params | |||
} | |||
// the string as local key. the local key example(multi pk): "t_user:1_a,2_b" | |||
func (b *baseExecutor) buildLockKey(records *types.RecordImage, meta types.TableMeta) string { | |||
var ( | |||
lockKeys bytes.Buffer | |||
filedSequence int | |||
) | |||
lockKeys.WriteString(meta.TableName) | |||
lockKeys.WriteString(":") | |||
keys := meta.GetPrimaryKeyOnlyName() | |||
for _, row := range records.Rows { | |||
if filedSequence > 0 { | |||
lockKeys.WriteString(",") | |||
} | |||
pkSplitIndex := 0 | |||
for _, column := range row.Columns { | |||
var hasKeyColumn bool | |||
for _, key := range keys { | |||
if column.ColumnName == key { | |||
hasKeyColumn = true | |||
if pkSplitIndex > 0 { | |||
lockKeys.WriteString("_") | |||
} | |||
lockKeys.WriteString(fmt.Sprintf("%v", column.Value)) | |||
pkSplitIndex++ | |||
} | |||
} | |||
if hasKeyColumn { | |||
filedSequence++ | |||
} | |||
} | |||
} | |||
return lockKeys.String() | |||
} |
@@ -19,182 +19,61 @@ package at | |||
import ( | |||
"context" | |||
"fmt" | |||
"github.com/mitchellh/copystructure" | |||
"github.com/pkg/errors" | |||
"github.com/seata/seata-go/pkg/datasource/sql/parser" | |||
"github.com/seata/seata-go/pkg/datasource/sql/undo" | |||
"github.com/seata/seata-go/pkg/datasource/sql/util" | |||
"github.com/seata/seata-go/pkg/tm" | |||
"github.com/seata/seata-go/pkg/datasource/sql/exec" | |||
"github.com/seata/seata-go/pkg/datasource/sql/parser" | |||
"github.com/seata/seata-go/pkg/datasource/sql/types" | |||
) | |||
type ATExecutor struct { | |||
hooks []exec.SQLHook | |||
ex exec.SQLExecutor | |||
func init() { | |||
exec.RegisterATExecutor(types.DBTypeMySQL, func() exec.SQLExecutor { return &AtExecutor{} }) | |||
} | |||
func (e *ATExecutor) Interceptors(hooks []exec.SQLHook) { | |||
e.hooks = hooks | |||
type executor interface { | |||
ExecContext(ctx context.Context, f exec.CallbackWithNamedValue) (types.ExecResult, error) | |||
} | |||
func (e *ATExecutor) ExecWithNamedValue(ctx context.Context, execCtx *types.ExecContext, f exec.CallbackWithNamedValue) (types.ExecResult, error) { | |||
for _, hook := range e.hooks { | |||
hook.Before(ctx, execCtx) | |||
} | |||
var ( | |||
beforeImages []*types.RecordImage | |||
afterImages []*types.RecordImage | |||
result types.ExecResult | |||
err error | |||
) | |||
beforeImages, err = e.beforeImage(ctx, execCtx) | |||
if err != nil { | |||
return nil, err | |||
} | |||
if beforeImages != nil { | |||
beforeImagesTmp, err := copystructure.Copy(beforeImages) | |||
if err != nil { | |||
return nil, err | |||
} | |||
newBeforeImages, ok := beforeImagesTmp.([]*types.RecordImage) | |||
if !ok { | |||
return nil, errors.New("copy beforeImages failed") | |||
} | |||
execCtx.TxCtx.RoundImages.AppendBeofreImages(newBeforeImages) | |||
} | |||
defer func() { | |||
for _, hook := range e.hooks { | |||
hook.After(ctx, execCtx) | |||
} | |||
}() | |||
if e.ex != nil { | |||
result, err = e.ex.ExecWithNamedValue(ctx, execCtx, f) | |||
} else { | |||
result, err = f(ctx, execCtx.Query, execCtx.NamedValues) | |||
} | |||
if err != nil { | |||
return nil, err | |||
} | |||
afterImages, err = e.afterImage(ctx, execCtx, beforeImages) | |||
if err != nil { | |||
return nil, err | |||
} | |||
if afterImages != nil { | |||
execCtx.TxCtx.RoundImages.AppendAfterImages(afterImages) | |||
} | |||
return result, err | |||
type AtExecutor struct { | |||
hooks []exec.SQLHook | |||
} | |||
func (e *ATExecutor) prepareUndoLog(ctx context.Context, execCtx *types.ExecContext) error { | |||
if execCtx.TxCtx.RoundImages.IsEmpty() { | |||
return nil | |||
} | |||
if execCtx.ParseContext.UpdateStmt != nil { | |||
if !execCtx.TxCtx.RoundImages.IsBeforeAfterSizeEq() { | |||
return fmt.Errorf("Before image size is not equaled to after image size, probably because you updated the primary keys.") | |||
} | |||
} | |||
undoLogManager, err := undo.GetUndoLogManager(execCtx.DBType) | |||
if err != nil { | |||
return err | |||
} | |||
return undoLogManager.FlushUndoLog(execCtx.TxCtx, execCtx.Conn) | |||
func (e *AtExecutor) Interceptors(hooks []exec.SQLHook) { | |||
e.hooks = hooks | |||
} | |||
func (e *ATExecutor) ExecWithValue(ctx context.Context, execCtx *types.ExecContext, f exec.CallbackWithValue) (types.ExecResult, error) { | |||
for _, hook := range e.hooks { | |||
hook.Before(ctx, execCtx) | |||
} | |||
var ( | |||
beforeImages []*types.RecordImage | |||
afterImages []*types.RecordImage | |||
result types.ExecResult | |||
err error | |||
) | |||
beforeImages, err = e.beforeImage(ctx, execCtx) | |||
if err != nil { | |||
return nil, err | |||
} | |||
if beforeImages != nil { | |||
execCtx.TxCtx.RoundImages.AppendBeofreImages(beforeImages) | |||
} | |||
defer func() { | |||
for _, hook := range e.hooks { | |||
hook.After(ctx, execCtx) | |||
} | |||
}() | |||
if e.ex != nil { | |||
result, err = e.ex.ExecWithValue(ctx, execCtx, f) | |||
} else { | |||
result, err = f(ctx, execCtx.Query, execCtx.Values) | |||
} | |||
if err != nil { | |||
return nil, err | |||
} | |||
afterImages, err = e.afterImage(ctx, execCtx, beforeImages) | |||
//ExecWithNamedValue | |||
func (e *AtExecutor) ExecWithNamedValue(ctx context.Context, execCtx *types.ExecContext, f exec.CallbackWithNamedValue) (types.ExecResult, error) { | |||
parser, err := parser.DoParser(execCtx.Query) | |||
if err != nil { | |||
return nil, err | |||
} | |||
if afterImages != nil { | |||
execCtx.TxCtx.RoundImages.AppendAfterImages(afterImages) | |||
} | |||
return result, err | |||
} | |||
var exec executor | |||
func (e *ATExecutor) beforeImage(ctx context.Context, execCtx *types.ExecContext) ([]*types.RecordImage, error) { | |||
if !tm.IsGlobalTx(ctx) { | |||
return nil, nil | |||
} | |||
pc, err := parser.DoParser(execCtx.Query) | |||
if err != nil { | |||
return nil, err | |||
} | |||
if !pc.HasValidStmt() { | |||
return nil, nil | |||
exec = NewPlainExecutor(parser, execCtx) | |||
} else { | |||
switch parser.SQLType { | |||
//case types.SQLTypeInsert: | |||
case types.SQLTypeUpdate: | |||
exec = NewUpdateExecutor(parser, execCtx, e.hooks) | |||
//case types.SQLTypeDelete: | |||
//case types.SQLTypeSelectForUpdate: | |||
//case types.SQLTypeMultiDelete: | |||
//case types.SQLTypeMultiUpdate: | |||
default: | |||
exec = NewPlainExecutor(parser, execCtx) | |||
} | |||
} | |||
execCtx.ParseContext = pc | |||
builder := undo.GetUndologBuilder(pc.ExecutorType) | |||
if builder == nil { | |||
return nil, nil | |||
} | |||
return builder.BeforeImage(ctx, execCtx) | |||
return exec.ExecContext(ctx, f) | |||
} | |||
// After | |||
func (e *ATExecutor) afterImage(ctx context.Context, execCtx *types.ExecContext, beforeImages []*types.RecordImage) ([]*types.RecordImage, error) { | |||
if !tm.IsGlobalTx(ctx) { | |||
return nil, nil | |||
} | |||
pc, err := parser.DoParser(execCtx.Query) | |||
if err != nil { | |||
return nil, err | |||
} | |||
if !pc.HasValidStmt() { | |||
return nil, nil | |||
} | |||
execCtx.ParseContext = pc | |||
builder := undo.GetUndologBuilder(pc.ExecutorType) | |||
if builder == nil { | |||
return nil, nil | |||
} | |||
return builder.AfterImage(ctx, execCtx, beforeImages) | |||
//ExecWithValue | |||
func (e *AtExecutor) ExecWithValue(ctx context.Context, execCtx *types.ExecContext, f exec.CallbackWithNamedValue) (types.ExecResult, error) { | |||
execCtx.NamedValues = util.ValueToNamedValue(execCtx.Values) | |||
return e.ExecWithNamedValue(ctx, execCtx, f) | |||
} |
@@ -18,21 +18,21 @@ | |||
package at | |||
import ( | |||
"context" | |||
"github.com/seata/seata-go/pkg/datasource/sql/exec" | |||
"github.com/seata/seata-go/pkg/datasource/sql/types" | |||
) | |||
func init() { | |||
exec.RegisterATExecutor(types.DBTypeMySQL, types.UpdateExecutor, func() exec.SQLExecutor { | |||
return &ATExecutor{} | |||
}) | |||
exec.RegisterATExecutor(types.DBTypeMySQL, types.SelectExecutor, func() exec.SQLExecutor { | |||
return &ATExecutor{} | |||
}) | |||
exec.RegisterATExecutor(types.DBTypeMySQL, types.InsertExecutor, func() exec.SQLExecutor { | |||
return &ATExecutor{} | |||
}) | |||
exec.RegisterATExecutor(types.DBTypeMySQL, types.DeleteExecutor, func() exec.SQLExecutor { | |||
return &ATExecutor{} | |||
}) | |||
type PlainExecutor struct { | |||
parserCtx *types.ParseContext | |||
execCtx *types.ExecContext | |||
} | |||
func NewPlainExecutor(parserCtx *types.ParseContext, execCtx *types.ExecContext) *PlainExecutor { | |||
return &PlainExecutor{parserCtx: parserCtx, execCtx: execCtx} | |||
} | |||
func (u *PlainExecutor) ExecContext(ctx context.Context, f exec.CallbackWithNamedValue) (types.ExecResult, error) { | |||
return f(ctx, u.execCtx.Query, u.execCtx.NamedValues) | |||
} |
@@ -0,0 +1,272 @@ | |||
/* | |||
* Licensed to the Apache Software Foundation (ASF) under one or more | |||
* contributor license agreements. See the NOTICE file distributed with | |||
* this work for additional information regarding copyright ownership. | |||
* The ASF licenses this file to You under the Apache License, Version 2.0 | |||
* (the "License"); you may not use this file except in compliance with | |||
* the License. You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
package at | |||
import ( | |||
"context" | |||
"database/sql/driver" | |||
"fmt" | |||
"strings" | |||
"github.com/arana-db/parser/ast" | |||
"github.com/arana-db/parser/format" | |||
"github.com/arana-db/parser/model" | |||
"github.com/seata/seata-go/pkg/datasource/sql/datasource" | |||
"github.com/seata/seata-go/pkg/datasource/sql/exec" | |||
"github.com/seata/seata-go/pkg/datasource/sql/types" | |||
"github.com/seata/seata-go/pkg/datasource/sql/util" | |||
"github.com/seata/seata-go/pkg/util/bytes" | |||
"github.com/seata/seata-go/pkg/util/log" | |||
) | |||
var ( | |||
// todo: OnlyCareUpdateColumns should load from config first | |||
onlyCareUpdateColumns = true | |||
maxInSize = 1000 | |||
) | |||
type updateExecutor struct { | |||
baseExecutor | |||
parserCtx *types.ParseContext | |||
execContent *types.ExecContext | |||
} | |||
//NewUpdateExecutor get update executor | |||
func NewUpdateExecutor(parserCtx *types.ParseContext, execContent *types.ExecContext, hooks []exec.SQLHook) executor { | |||
return &updateExecutor{parserCtx: parserCtx, execContent: execContent, baseExecutor: baseExecutor{hooks: hooks}} | |||
} | |||
//ExecContext exec SQL, and generate before image and after image | |||
func (u *updateExecutor) ExecContext(ctx context.Context, f exec.CallbackWithNamedValue) (types.ExecResult, error) { | |||
u.beforeHooks(ctx, u.execContent) | |||
defer func() { | |||
u.afterHooks(ctx, u.execContent) | |||
}() | |||
beforeImage, err := u.beforeImage(ctx) | |||
if err != nil { | |||
return nil, err | |||
} | |||
res, err := f(ctx, u.execContent.Query, u.execContent.NamedValues) | |||
if err != nil { | |||
return nil, err | |||
} | |||
afterImage, err := u.afterImage(ctx, *beforeImage) | |||
if err != nil { | |||
return nil, err | |||
} | |||
if len(beforeImage.Rows) != len(afterImage.Rows) { | |||
return nil, fmt.Errorf("Before image size is not equaled to after image size, probably because you updated the primary keys.") | |||
} | |||
u.execContent.TxCtx.RoundImages.AppendBeofreImage(beforeImage) | |||
u.execContent.TxCtx.RoundImages.AppendAfterImage(afterImage) | |||
return res, nil | |||
} | |||
//beforeImage build before image | |||
func (u *updateExecutor) beforeImage(ctx context.Context) (*types.RecordImage, error) { | |||
if !u.isAstStmtValid() { | |||
return nil, nil | |||
} | |||
selectSQL, selectArgs, err := u.buildBeforeImageSQL(ctx, u.execContent.NamedValues) | |||
if err != nil { | |||
return nil, err | |||
} | |||
tableName, _ := u.parserCtx.GteTableName() | |||
metaData, err := datasource.GetTableCache(types.DBTypeMySQL).GetTableMeta(ctx, u.execContent.DBName, tableName) | |||
if err != nil { | |||
return nil, err | |||
} | |||
var rowsi driver.Rows | |||
queryerCtx, ok := u.execContent.Conn.(driver.QueryerContext) | |||
var queryer driver.Queryer | |||
if !ok { | |||
queryer, ok = u.execContent.Conn.(driver.Queryer) | |||
} | |||
if ok { | |||
rowsi, err = util.CtxDriverQuery(ctx, queryerCtx, queryer, selectSQL, selectArgs) | |||
if err != nil { | |||
log.Errorf("ctx driver query: %+v", err) | |||
return nil, err | |||
} | |||
} else { | |||
log.Errorf("target conn should been driver.QueryerContext or driver.Queryer") | |||
return nil, fmt.Errorf("invalid conn") | |||
} | |||
image, err := u.buildRecordImages(rowsi, metaData) | |||
if err != nil { | |||
return nil, err | |||
} | |||
lockKey := u.buildLockKey(image, *metaData) | |||
u.execContent.TxCtx.LockKeys[lockKey] = struct{}{} | |||
image.SQLType = u.parserCtx.SQLType | |||
return image, nil | |||
} | |||
//afterImage build after image | |||
func (u *updateExecutor) afterImage(ctx context.Context, beforeImage types.RecordImage) (*types.RecordImage, error) { | |||
if !u.isAstStmtValid() { | |||
return nil, nil | |||
} | |||
if len(beforeImage.Rows) == 0 { | |||
return &types.RecordImage{}, nil | |||
} | |||
tableName, _ := u.parserCtx.GteTableName() | |||
metaData, err := datasource.GetTableCache(types.DBTypeMySQL).GetTableMeta(ctx, u.execContent.DBName, tableName) | |||
if err != nil { | |||
return nil, err | |||
} | |||
selectSQL, selectArgs := u.buildAfterImageSQL(beforeImage, metaData) | |||
var rowsi driver.Rows | |||
queryerCtx, ok := u.execContent.Conn.(driver.QueryerContext) | |||
var queryer driver.Queryer | |||
if !ok { | |||
queryer, ok = u.execContent.Conn.(driver.Queryer) | |||
} | |||
if ok { | |||
rowsi, err = util.CtxDriverQuery(ctx, queryerCtx, queryer, selectSQL, selectArgs) | |||
if err != nil { | |||
log.Errorf("ctx driver query: %+v", err) | |||
return nil, err | |||
} | |||
} else { | |||
log.Errorf("target conn should been driver.QueryerContext or driver.Queryer") | |||
return nil, fmt.Errorf("invalid conn") | |||
} | |||
afterImage, err := u.buildRecordImages(rowsi, metaData) | |||
if err != nil { | |||
return nil, err | |||
} | |||
afterImage.SQLType = u.parserCtx.SQLType | |||
return afterImage, nil | |||
} | |||
func (u *updateExecutor) isAstStmtValid() bool { | |||
return u.parserCtx != nil && u.parserCtx.UpdateStmt != nil | |||
} | |||
//buildAfterImageSQL build the SQL to query after image data | |||
func (u *updateExecutor) buildAfterImageSQL(beforeImage types.RecordImage, meta *types.TableMeta) (string, []driver.NamedValue) { | |||
if len(beforeImage.Rows) == 0 { | |||
return "", nil | |||
} | |||
sb := strings.Builder{} | |||
// todo: OnlyCareUpdateColumns should load from config first | |||
var selectFields string | |||
var separator = "," | |||
if onlyCareUpdateColumns { | |||
for _, row := range beforeImage.Rows { | |||
for _, column := range row.Columns { | |||
selectFields += column.ColumnName + separator | |||
} | |||
} | |||
selectFields = strings.TrimSuffix(selectFields, separator) | |||
} else { | |||
selectFields = "*" | |||
} | |||
sb.WriteString("SELECT " + selectFields + " FROM " + meta.TableName + " WHERE ") | |||
whereSQL := u.buildWhereConditionByPKs(meta.GetPrimaryKeyOnlyName(), len(beforeImage.Rows), "mysql", maxInSize) | |||
sb.WriteString(" " + whereSQL + " ") | |||
return sb.String(), u.buildPKParams(beforeImage.Rows, meta.GetPrimaryKeyOnlyName()) | |||
} | |||
//buildAfterImageSQL build the SQL to query before image data | |||
func (u *updateExecutor) buildBeforeImageSQL(ctx context.Context, args []driver.NamedValue) (string, []driver.NamedValue, error) { | |||
if !u.isAstStmtValid() { | |||
log.Errorf("invalid update stmt") | |||
return "", nil, fmt.Errorf("invalid update stmt") | |||
} | |||
updateStmt := u.parserCtx.UpdateStmt | |||
fields := make([]*ast.SelectField, 0, len(updateStmt.List)) | |||
if onlyCareUpdateColumns { | |||
for _, column := range updateStmt.List { | |||
fields = append(fields, &ast.SelectField{ | |||
Expr: &ast.ColumnNameExpr{ | |||
Name: column.Column, | |||
}, | |||
}) | |||
} | |||
// select indexes columns | |||
tableName, _ := u.parserCtx.GteTableName() | |||
metaData, err := datasource.GetTableCache(types.DBTypeMySQL).GetTableMeta(ctx, u.execContent.DBName, tableName) | |||
if err != nil { | |||
return "", nil, err | |||
} | |||
for _, columnName := range metaData.GetPrimaryKeyOnlyName() { | |||
fields = append(fields, &ast.SelectField{ | |||
Expr: &ast.ColumnNameExpr{ | |||
Name: &ast.ColumnName{ | |||
Name: model.CIStr{ | |||
O: columnName, | |||
L: columnName, | |||
}, | |||
}, | |||
}, | |||
}) | |||
} | |||
} else { | |||
fields = append(fields, &ast.SelectField{ | |||
Expr: &ast.ColumnNameExpr{ | |||
Name: &ast.ColumnName{ | |||
Name: model.CIStr{ | |||
O: "*", | |||
L: "*", | |||
}, | |||
}, | |||
}, | |||
}) | |||
} | |||
selStmt := ast.SelectStmt{ | |||
SelectStmtOpts: &ast.SelectStmtOpts{}, | |||
From: updateStmt.TableRefs, | |||
Where: updateStmt.Where, | |||
Fields: &ast.FieldList{Fields: fields}, | |||
OrderBy: updateStmt.Order, | |||
Limit: updateStmt.Limit, | |||
TableHints: updateStmt.TableHints, | |||
LockInfo: &ast.SelectLockInfo{ | |||
LockType: ast.SelectLockForUpdate, | |||
}, | |||
} | |||
b := bytes.NewByteBuffer([]byte{}) | |||
_ = selStmt.Restore(format.NewRestoreCtx(format.RestoreKeyWordUppercase, b)) | |||
sql := string(b.Bytes()) | |||
log.Infof("build select sql by update sourceQuery, sql {%s}", sql) | |||
return sql, u.buildSelectArgs(&selStmt, args), nil | |||
} |
@@ -0,0 +1,101 @@ | |||
/* | |||
* Licensed to the Apache Software Foundation (ASF) under one or more | |||
* contributor license agreements. See the NOTICE file distributed with | |||
* this work for additional information regarding copyright ownership. | |||
* The ASF licenses this file to You under the Apache License, Version 2.0 | |||
* (the "License"); you may not use this file except in compliance with | |||
* the License. You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
package at | |||
import ( | |||
"context" | |||
"database/sql/driver" | |||
"reflect" | |||
"testing" | |||
"github.com/agiledragon/gomonkey" | |||
"github.com/seata/seata-go/pkg/datasource/sql/datasource" | |||
"github.com/seata/seata-go/pkg/datasource/sql/datasource/mysql" | |||
"github.com/seata/seata-go/pkg/datasource/sql/exec" | |||
"github.com/seata/seata-go/pkg/datasource/sql/util" | |||
"github.com/seata/seata-go/pkg/datasource/sql/types" | |||
"github.com/seata/seata-go/pkg/datasource/sql/parser" | |||
_ "github.com/arana-db/parser/test_driver" | |||
_ "github.com/seata/seata-go/pkg/util/log" | |||
"github.com/stretchr/testify/assert" | |||
) | |||
func TestBuildSelectSQLByUpdate(t *testing.T) { | |||
datasource.RegisterTableCache(types.DBTypeMySQL, mysql.NewTableMetaInstance(nil)) | |||
stub := gomonkey.ApplyMethod(reflect.TypeOf(datasource.GetTableCache(types.DBTypeMySQL)), "GetTableMeta", | |||
func(_ *mysql.TableMetaCache, ctx context.Context, dbName, tableName string) (*types.TableMeta, error) { | |||
return &types.TableMeta{ | |||
Indexs: map[string]types.IndexMeta{ | |||
"id": { | |||
IType: types.IndexTypePrimaryKey, | |||
Columns: []types.ColumnMeta{ | |||
{ColumnName: "id"}, | |||
}, | |||
}, | |||
}, | |||
}, nil | |||
}) | |||
defer stub.Reset() | |||
tests := []struct { | |||
name string | |||
sourceQuery string | |||
sourceQueryArgs []driver.Value | |||
expectQuery string | |||
expectQueryArgs []driver.Value | |||
}{ | |||
{ | |||
sourceQuery: "update t_user set name = ?, age = ? where id = ?", | |||
sourceQueryArgs: []driver.Value{"Jack", 1, 100}, | |||
expectQuery: "SELECT SQL_NO_CACHE name,age,id FROM t_user WHERE id=? FOR UPDATE", | |||
expectQueryArgs: []driver.Value{100}, | |||
}, | |||
{ | |||
sourceQuery: "update t_user set name = ?, age = ? where id = ? and name = 'Jack' and age between ? and ?", | |||
sourceQueryArgs: []driver.Value{"Jack", 1, 100, 18, 28}, | |||
expectQuery: "SELECT SQL_NO_CACHE name,age,id FROM t_user WHERE id=? AND name=_UTF8MB4Jack AND age BETWEEN ? AND ? FOR UPDATE", | |||
expectQueryArgs: []driver.Value{100, 18, 28}, | |||
}, | |||
{ | |||
sourceQuery: "update t_user set name = ?, age = ? where id = ? and name = 'Jack' and age in (?,?)", | |||
sourceQueryArgs: []driver.Value{"Jack", 1, 100, 18, 28}, | |||
expectQuery: "SELECT SQL_NO_CACHE name,age,id FROM t_user WHERE id=? AND name=_UTF8MB4Jack AND age IN (?,?) FOR UPDATE", | |||
expectQueryArgs: []driver.Value{100, 18, 28}, | |||
}, | |||
{ | |||
sourceQuery: "update t_user set name = ?, age = ? where kk between ? and ? and id = ? and addr in(?,?) and age > ? order by name desc limit ?", | |||
sourceQueryArgs: []driver.Value{"Jack", 1, 10, 20, 17, "Beijing", "Guangzhou", 18, 2}, | |||
expectQuery: "SELECT SQL_NO_CACHE name,age,id FROM t_user WHERE kk BETWEEN ? AND ? AND id=? AND addr IN (?,?) AND age>? ORDER BY name DESC LIMIT ? FOR UPDATE", | |||
expectQueryArgs: []driver.Value{10, 20, 17, "Beijing", "Guangzhou", 18, 2}, | |||
}, | |||
} | |||
for _, tt := range tests { | |||
t.Run(tt.name, func(t *testing.T) { | |||
c, err := parser.DoParser(tt.sourceQuery) | |||
assert.Nil(t, err) | |||
executor := NewUpdateExecutor(c, &types.ExecContext{Values: tt.sourceQueryArgs, NamedValues: util.ValueToNamedValue(tt.sourceQueryArgs)}, []exec.SQLHook{}) | |||
query, args, err := executor.(*updateExecutor).buildBeforeImageSQL(context.Background(), util.ValueToNamedValue(tt.sourceQueryArgs)) | |||
assert.Nil(t, err) | |||
assert.Equal(t, tt.expectQuery, query) | |||
assert.Equal(t, tt.expectQueryArgs, util.NamedValueToValue(args)) | |||
}) | |||
} | |||
} |
@@ -25,35 +25,25 @@ import ( | |||
"github.com/seata/seata-go/pkg/datasource/sql/types" | |||
"github.com/seata/seata-go/pkg/datasource/sql/undo" | |||
"github.com/seata/seata-go/pkg/datasource/sql/undo/builder" | |||
"github.com/seata/seata-go/pkg/util/log" | |||
) | |||
func init() { | |||
undo.RegisterUndoLogBuilder(types.UpdateExecutor, builder.GetMySQLUpdateUndoLogBuilder) | |||
undo.RegisterUndoLogBuilder(types.MultiExecutor, builder.GetMySQLMultiUndoLogBuilder) | |||
} | |||
var ( | |||
executorSoltsAT = make(map[types.DBType]map[types.ExecutorType]func() SQLExecutor) | |||
executorSoltsXA = make(map[types.DBType]func() SQLExecutor) | |||
atExecutors = make(map[types.DBType]func() SQLExecutor) | |||
xaExecutors = make(map[types.DBType]func() SQLExecutor) | |||
) | |||
// RegisterATExecutor AT executor | |||
func RegisterATExecutor(dt types.DBType, et types.ExecutorType, builder func() SQLExecutor) { | |||
if _, ok := executorSoltsAT[dt]; !ok { | |||
executorSoltsAT[dt] = make(map[types.ExecutorType]func() SQLExecutor) | |||
} | |||
val := executorSoltsAT[dt] | |||
val[et] = func() SQLExecutor { | |||
return &BaseExecutor{ex: builder()} | |||
} | |||
func RegisterATExecutor(dt types.DBType, builder func() SQLExecutor) { | |||
atExecutors[dt] = builder | |||
} | |||
// RegisterXAExecutor XA executor | |||
func RegisterXAExecutor(dt types.DBType, builder func() SQLExecutor) { | |||
executorSoltsXA[dt] = func() SQLExecutor { | |||
xaExecutors[dt] = func() SQLExecutor { | |||
return builder() | |||
} | |||
} | |||
@@ -66,7 +56,7 @@ type ( | |||
SQLExecutor interface { | |||
Interceptors(interceptors []SQLHook) | |||
ExecWithNamedValue(ctx context.Context, execCtx *types.ExecContext, f CallbackWithNamedValue) (types.ExecResult, error) | |||
ExecWithValue(ctx context.Context, execCtx *types.ExecContext, f CallbackWithValue) (types.ExecResult, error) | |||
ExecWithValue(ctx context.Context, execCtx *types.ExecContext, f CallbackWithNamedValue) (types.ExecResult, error) | |||
} | |||
) | |||
@@ -83,37 +73,14 @@ func BuildExecutor(dbType types.DBType, transactionMode types.TransactionMode, q | |||
hooks = append(hooks, hookSolts[parseContext.SQLType]...) | |||
if transactionMode == types.XAMode { | |||
e := executorSoltsXA[dbType]() | |||
e.Interceptors(hooks) | |||
return e, nil | |||
} | |||
if transactionMode == types.ATMode { | |||
e := executorSoltsAT[dbType][parseContext.ExecutorType]() | |||
e.Interceptors(hooks) | |||
return e, nil | |||
} | |||
factories, ok := executorSoltsAT[dbType] | |||
if !ok { | |||
log.Debugf("%s not found executor factories, return default Executor", dbType.String()) | |||
e := &BaseExecutor{} | |||
e := xaExecutors[dbType]() | |||
e.Interceptors(hooks) | |||
return e, nil | |||
} | |||
supplier, ok := factories[parseContext.ExecutorType] | |||
if !ok { | |||
log.Debugf("%s not found executor for %s, return default Executor", | |||
dbType.String(), parseContext.ExecutorType) | |||
e := &BaseExecutor{} | |||
e.Interceptors(hooks) | |||
return e, nil | |||
} | |||
executor := supplier() | |||
executor.Interceptors(hooks) | |||
return executor, nil | |||
e := atExecutors[dbType]() | |||
e.Interceptors(hooks) | |||
return e, nil | |||
} | |||
type BaseExecutor struct { | |||
@@ -146,7 +113,7 @@ func (e *BaseExecutor) ExecWithNamedValue(ctx context.Context, execCtx *types.Ex | |||
} | |||
// ExecWithValue | |||
func (e *BaseExecutor) ExecWithValue(ctx context.Context, execCtx *types.ExecContext, f CallbackWithValue) (types.ExecResult, error) { | |||
func (e *BaseExecutor) ExecWithValue(ctx context.Context, execCtx *types.ExecContext, f CallbackWithNamedValue) (types.ExecResult, error) { | |||
for i := range e.hooks { | |||
e.hooks[i].Before(ctx, execCtx) | |||
} | |||
@@ -161,5 +128,13 @@ func (e *BaseExecutor) ExecWithValue(ctx context.Context, execCtx *types.ExecCon | |||
return e.ex.ExecWithValue(ctx, execCtx, f) | |||
} | |||
return f(ctx, execCtx.Query, execCtx.Values) | |||
nvargs := make([]driver.NamedValue, len(execCtx.Values)) | |||
for i, value := range execCtx.Values { | |||
nvargs = append(nvargs, driver.NamedValue{ | |||
Value: value, | |||
Ordinal: i, | |||
}) | |||
} | |||
return f(ctx, execCtx.Query, nvargs) | |||
} |
@@ -19,6 +19,7 @@ package xa | |||
import ( | |||
"context" | |||
"database/sql/driver" | |||
"github.com/seata/seata-go/pkg/datasource/sql/exec" | |||
"github.com/seata/seata-go/pkg/datasource/sql/types" | |||
@@ -55,7 +56,7 @@ func (e *XAExecutor) ExecWithNamedValue(ctx context.Context, execCtx *types.Exec | |||
} | |||
// ExecWithValue | |||
func (e *XAExecutor) ExecWithValue(ctx context.Context, execCtx *types.ExecContext, f exec.CallbackWithValue) (types.ExecResult, error) { | |||
func (e *XAExecutor) ExecWithValue(ctx context.Context, execCtx *types.ExecContext, f exec.CallbackWithNamedValue) (types.ExecResult, error) { | |||
for _, hook := range e.hooks { | |||
hook.Before(ctx, execCtx) | |||
} | |||
@@ -70,5 +71,14 @@ func (e *XAExecutor) ExecWithValue(ctx context.Context, execCtx *types.ExecConte | |||
return e.ex.ExecWithValue(ctx, execCtx, f) | |||
} | |||
return f(ctx, execCtx.Query, execCtx.Values) | |||
nvargs := make([]driver.NamedValue, len(execCtx.Values)) | |||
for i, value := range execCtx.Values { | |||
nvargs = append(nvargs, driver.NamedValue{ | |||
Value: value, | |||
Ordinal: i, | |||
}) | |||
} | |||
execCtx.NamedValues = nvargs | |||
return f(ctx, execCtx.Query, execCtx.NamedValues) | |||
} |
@@ -23,6 +23,7 @@ import ( | |||
"github.com/seata/seata-go/pkg/datasource/sql/exec" | |||
"github.com/seata/seata-go/pkg/datasource/sql/types" | |||
"github.com/seata/seata-go/pkg/datasource/sql/util" | |||
) | |||
type Stmt struct { | |||
@@ -75,8 +76,8 @@ func (s *Stmt) Query(args []driver.Value) (driver.Rows, error) { | |||
} | |||
ret, err := executor.ExecWithValue(context.Background(), execCtx, | |||
func(ctx context.Context, query string, args []driver.Value) (types.ExecResult, error) { | |||
ret, err := s.stmt.Query(args) | |||
func(ctx context.Context, query string, args []driver.NamedValue) (types.ExecResult, error) { | |||
ret, err := s.stmt.Query(util.NamedValueToValue(args)) | |||
if err != nil { | |||
return nil, err | |||
} | |||
@@ -144,8 +145,8 @@ func (s *Stmt) Exec(args []driver.Value) (driver.Result, error) { | |||
} | |||
ret, err := executor.ExecWithValue(context.Background(), execCtx, | |||
func(ctx context.Context, query string, args []driver.Value) (types.ExecResult, error) { | |||
ret, err := s.stmt.Exec(args) | |||
func(ctx context.Context, query string, args []driver.NamedValue) (types.ExecResult, error) { | |||
ret, err := s.stmt.Exec(util.NamedValueToValue(args)) | |||
if err != nil { | |||
return nil, err | |||
} | |||
@@ -129,7 +129,8 @@ type RowImage struct { | |||
func (r *RowImage) GetColumnMap() map[string]*ColumnImage { | |||
m := make(map[string]*ColumnImage, 0) | |||
for _, column := range r.Columns { | |||
m[column.ColumnName] = &column | |||
tmpColumn := column | |||
m[column.ColumnName] = &tmpColumn | |||
} | |||
return m | |||
} | |||
@@ -245,7 +246,7 @@ func (c *ColumnImage) UnmarshalJSON(data []byte) error { | |||
case JDBCTypeChar, JDBCTypeVarchar: | |||
var val []byte | |||
if val, err = base64.StdEncoding.DecodeString(value.(string)); err != nil { | |||
return err | |||
val = []byte(value.(string)) | |||
} | |||
actualValue = string(val) | |||
case JDBCTypeBinary, JDBCTypeVarBinary, JDBCTypeLongVarBinary, JDBCTypeBit: | |||
@@ -0,0 +1,377 @@ | |||
/* | |||
* Licensed to the Apache Software Foundation (ASF) under one or more | |||
* contributor license agreements. See the NOTICE file distributed with | |||
* this work for additional information regarding copyright ownership. | |||
* The ASF licenses this file to You under the Apache License, Version 2.0 | |||
* (the "License"); you may not use this file except in compliance with | |||
* the License. You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
// Copyright 2011 The Go Authors. All rights reserved. | |||
// Use of this source code is governed by a BSD-style | |||
// license that can be found in the LICENSE file. | |||
// Type conversions for Scan. | |||
package util | |||
import ( | |||
"database/sql" | |||
"database/sql/driver" | |||
"errors" | |||
"fmt" | |||
"reflect" | |||
"strconv" | |||
"time" | |||
) | |||
var errNilPtr = errors.New("destination pointer is nil") // embedded in descriptive error | |||
// convertAssignRows copies to dest the value in src, converting it if possible. | |||
// An error is returned if the copy would result in loss of information. | |||
// dest should be a pointer type. If rows is passed in, the rows will | |||
// be used as the parent for any cursor values converted from a | |||
// driver.Rows to a *ScanRows. | |||
func convertAssignRows(dest, src interface{}, rows *ScanRows) error { | |||
// Common cases, without reflect. | |||
switch s := src.(type) { | |||
case string: | |||
switch d := dest.(type) { | |||
case *string: | |||
if d == nil { | |||
return errNilPtr | |||
} | |||
*d = s | |||
return nil | |||
case *[]byte: | |||
if d == nil { | |||
return errNilPtr | |||
} | |||
*d = []byte(s) | |||
return nil | |||
case *sql.RawBytes: | |||
if d == nil { | |||
return errNilPtr | |||
} | |||
*d = append((*d)[:0], s...) | |||
return nil | |||
} | |||
case []byte: | |||
switch d := dest.(type) { | |||
case *string: | |||
if d == nil { | |||
return errNilPtr | |||
} | |||
*d = string(s) | |||
return nil | |||
case *interface{}: | |||
if d == nil { | |||
return errNilPtr | |||
} | |||
*d = cloneBytes(s) | |||
return nil | |||
case *[]byte: | |||
if d == nil { | |||
return errNilPtr | |||
} | |||
*d = cloneBytes(s) | |||
return nil | |||
case *sql.RawBytes: | |||
if d == nil { | |||
return errNilPtr | |||
} | |||
*d = s | |||
return nil | |||
} | |||
case time.Time: | |||
switch d := dest.(type) { | |||
case *time.Time: | |||
*d = s | |||
return nil | |||
case *string: | |||
*d = s.Format(time.RFC3339Nano) | |||
return nil | |||
case *[]byte: | |||
if d == nil { | |||
return errNilPtr | |||
} | |||
*d = []byte(s.Format(time.RFC3339Nano)) | |||
return nil | |||
case *sql.RawBytes: | |||
if d == nil { | |||
return errNilPtr | |||
} | |||
*d = s.AppendFormat((*d)[:0], time.RFC3339Nano) | |||
return nil | |||
} | |||
case decimalDecompose: | |||
switch d := dest.(type) { | |||
case decimalCompose: | |||
return d.Compose(s.Decompose(nil)) | |||
} | |||
case nil: | |||
switch d := dest.(type) { | |||
case *interface{}: | |||
if d == nil { | |||
return errNilPtr | |||
} | |||
*d = nil | |||
return nil | |||
case *[]byte: | |||
if d == nil { | |||
return errNilPtr | |||
} | |||
*d = nil | |||
return nil | |||
case *sql.RawBytes: | |||
if d == nil { | |||
return errNilPtr | |||
} | |||
*d = nil | |||
return nil | |||
} | |||
// The driver is returning a cursor the client may iterate over. | |||
case driver.Rows: | |||
switch d := dest.(type) { | |||
case *ScanRows: | |||
if d == nil { | |||
return errNilPtr | |||
} | |||
if rows == nil { | |||
return errors.New("invalid context to convert cursor rows, missing parent *ScanRows") | |||
} | |||
rows.closemu.Lock() | |||
*d = ScanRows{ | |||
dc: rows.dc, | |||
releaseConn: func(error) {}, | |||
rowsi: s, | |||
} | |||
// Chain the cancel function. | |||
parentCancel := rows.cancel | |||
rows.cancel = func() { | |||
// When ScanRows.cancel is called, the closemu will be locked as well. | |||
// So we can access rs.lasterr. | |||
d.close(rows.lasterr) | |||
if parentCancel != nil { | |||
parentCancel() | |||
} | |||
} | |||
rows.closemu.Unlock() | |||
return nil | |||
} | |||
} | |||
var sv reflect.Value | |||
switch d := dest.(type) { | |||
case *string: | |||
sv = reflect.ValueOf(src) | |||
switch sv.Kind() { | |||
case reflect.Bool, | |||
reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, | |||
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, | |||
reflect.Float32, reflect.Float64: | |||
*d = asString(src) | |||
return nil | |||
} | |||
case *[]byte: | |||
sv = reflect.ValueOf(src) | |||
if b, ok := asBytes(nil, sv); ok { | |||
*d = b | |||
return nil | |||
} | |||
case *sql.RawBytes: | |||
sv = reflect.ValueOf(src) | |||
if b, ok := asBytes([]byte(*d)[:0], sv); ok { | |||
*d = sql.RawBytes(b) | |||
return nil | |||
} | |||
case *bool: | |||
bv, err := driver.Bool.ConvertValue(src) | |||
if err == nil { | |||
*d = bv.(bool) | |||
} | |||
return err | |||
case *interface{}: | |||
*d = src | |||
return nil | |||
} | |||
if scanner, ok := dest.(sql.Scanner); ok { | |||
return scanner.Scan(src) | |||
} | |||
dpv := reflect.ValueOf(dest) | |||
if dpv.Kind() != reflect.Ptr { | |||
return errors.New("destination not a pointer") | |||
} | |||
if dpv.IsNil() { | |||
return errNilPtr | |||
} | |||
if !sv.IsValid() { | |||
sv = reflect.ValueOf(src) | |||
} | |||
dv := reflect.Indirect(dpv) | |||
if sv.IsValid() && sv.Type().AssignableTo(dv.Type()) { | |||
switch b := src.(type) { | |||
case []byte: | |||
dv.Set(reflect.ValueOf(cloneBytes(b))) | |||
default: | |||
dv.Set(sv) | |||
} | |||
return nil | |||
} | |||
if dv.Kind() == sv.Kind() && sv.Type().ConvertibleTo(dv.Type()) { | |||
dv.Set(sv.Convert(dv.Type())) | |||
return nil | |||
} | |||
// The following conversions use a string value as an intermediate representation | |||
// to convert between various numeric types. | |||
// | |||
// This also allows scanning into user defined types such as "type Int int64". | |||
// For symmetry, also check for string destination types. | |||
switch dv.Kind() { | |||
case reflect.Ptr: | |||
if src == nil { | |||
dv.Set(reflect.Zero(dv.Type())) | |||
return nil | |||
} | |||
dv.Set(reflect.New(dv.Type().Elem())) | |||
return convertAssignRows(dv.Interface(), src, rows) | |||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: | |||
if src == nil { | |||
return fmt.Errorf("converting NULL to %s is unsupported", dv.Kind()) | |||
} | |||
s := asString(src) | |||
i64, err := strconv.ParseInt(s, 10, dv.Type().Bits()) | |||
if err != nil { | |||
err = strconvErr(err) | |||
return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err) | |||
} | |||
dv.SetInt(i64) | |||
return nil | |||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: | |||
if src == nil { | |||
return fmt.Errorf("converting NULL to %s is unsupported", dv.Kind()) | |||
} | |||
s := asString(src) | |||
u64, err := strconv.ParseUint(s, 10, dv.Type().Bits()) | |||
if err != nil { | |||
err = strconvErr(err) | |||
return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err) | |||
} | |||
dv.SetUint(u64) | |||
return nil | |||
case reflect.Float32, reflect.Float64: | |||
if src == nil { | |||
return fmt.Errorf("converting NULL to %s is unsupported", dv.Kind()) | |||
} | |||
s := asString(src) | |||
f64, err := strconv.ParseFloat(s, dv.Type().Bits()) | |||
if err != nil { | |||
err = strconvErr(err) | |||
return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err) | |||
} | |||
dv.SetFloat(f64) | |||
return nil | |||
case reflect.String: | |||
if src == nil { | |||
return fmt.Errorf("converting NULL to %s is unsupported", dv.Kind()) | |||
} | |||
switch v := src.(type) { | |||
case string: | |||
dv.SetString(v) | |||
return nil | |||
case []byte: | |||
dv.SetString(string(v)) | |||
return nil | |||
} | |||
} | |||
return fmt.Errorf("unsupported Scan, storing driver.Value type %T into type %T", src, dest) | |||
} | |||
func strconvErr(err error) error { | |||
if ne, ok := err.(*strconv.NumError); ok { | |||
return ne.Err | |||
} | |||
return err | |||
} | |||
func cloneBytes(b []byte) []byte { | |||
if b == nil { | |||
return nil | |||
} | |||
c := make([]byte, len(b)) | |||
copy(c, b) | |||
return c | |||
} | |||
func asString(src interface{}) string { | |||
switch v := src.(type) { | |||
case string: | |||
return v | |||
case []byte: | |||
return string(v) | |||
} | |||
rv := reflect.ValueOf(src) | |||
switch rv.Kind() { | |||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: | |||
return strconv.FormatInt(rv.Int(), 10) | |||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: | |||
return strconv.FormatUint(rv.Uint(), 10) | |||
case reflect.Float64: | |||
return strconv.FormatFloat(rv.Float(), 'g', -1, 64) | |||
case reflect.Float32: | |||
return strconv.FormatFloat(rv.Float(), 'g', -1, 32) | |||
case reflect.Bool: | |||
return strconv.FormatBool(rv.Bool()) | |||
} | |||
return fmt.Sprintf("%v", src) | |||
} | |||
func asBytes(buf []byte, rv reflect.Value) (b []byte, ok bool) { | |||
switch rv.Kind() { | |||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: | |||
return strconv.AppendInt(buf, rv.Int(), 10), true | |||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: | |||
return strconv.AppendUint(buf, rv.Uint(), 10), true | |||
case reflect.Float32: | |||
return strconv.AppendFloat(buf, rv.Float(), 'g', -1, 32), true | |||
case reflect.Float64: | |||
return strconv.AppendFloat(buf, rv.Float(), 'g', -1, 64), true | |||
case reflect.Bool: | |||
return strconv.AppendBool(buf, rv.Bool()), true | |||
case reflect.String: | |||
s := rv.String() | |||
return append(buf, s...), true | |||
} | |||
return | |||
} | |||
var valuerReflectType = reflect.TypeOf((*driver.Valuer)(nil)).Elem() | |||
type decimalDecompose interface { | |||
// Decompose returns the internal decimal state in parts. | |||
// If the provided buf has sufficient capacity, buf may be returned as the coefficient with | |||
// the value set and length set as appropriate. | |||
Decompose(buf []byte) (form byte, negative bool, coefficient []byte, exponent int32) | |||
} | |||
type decimalCompose interface { | |||
// Compose sets the internal decimal value from parts. If the value cannot be | |||
// represented then an error should be returned. | |||
Compose(form byte, negative bool, coefficient []byte, exponent int32) error | |||
} |
@@ -0,0 +1,123 @@ | |||
/* | |||
* Licensed to the Apache Software Foundation (ASF) under one or more | |||
* contributor license agreements. See the NOTICE file distributed with | |||
* this work for additional information regarding copyright ownership. | |||
* The ASF licenses this file to You under the Apache License, Version 2.0 | |||
* (the "License"); you may not use this file except in compliance with | |||
* the License. You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
// Copyright 2016 The Go Authors. All rights reserved. | |||
// Use of this source code is governed by a BSD-style | |||
// license that can be found in the LICENSE file. | |||
package util | |||
import ( | |||
"context" | |||
"database/sql/driver" | |||
"errors" | |||
) | |||
func ctxDriverPrepare(ctx context.Context, ci driver.Conn, query string) (driver.Stmt, error) { | |||
if ciCtx, is := ci.(driver.ConnPrepareContext); is { | |||
return ciCtx.PrepareContext(ctx, query) | |||
} | |||
si, err := ci.Prepare(query) | |||
if err == nil { | |||
select { | |||
default: | |||
case <-ctx.Done(): | |||
si.Close() | |||
return nil, ctx.Err() | |||
} | |||
} | |||
return si, err | |||
} | |||
func ctxDriverExec(ctx context.Context, execerCtx driver.ExecerContext, execer driver.Execer, query string, nvdargs []driver.NamedValue) (driver.Result, error) { | |||
if execerCtx != nil { | |||
return execerCtx.ExecContext(ctx, query, nvdargs) | |||
} | |||
dargs, err := namedValueToValue(nvdargs) | |||
if err != nil { | |||
return nil, err | |||
} | |||
select { | |||
default: | |||
case <-ctx.Done(): | |||
return nil, ctx.Err() | |||
} | |||
return execer.Exec(query, dargs) | |||
} | |||
func CtxDriverQuery(ctx context.Context, queryerCtx driver.QueryerContext, queryer driver.Queryer, query string, nvdargs []driver.NamedValue) (driver.Rows, error) { | |||
if queryerCtx != nil { | |||
return queryerCtx.QueryContext(ctx, query, nvdargs) | |||
} | |||
dargs, err := namedValueToValue(nvdargs) | |||
if err != nil { | |||
return nil, err | |||
} | |||
select { | |||
default: | |||
case <-ctx.Done(): | |||
return nil, ctx.Err() | |||
} | |||
return queryer.Query(query, dargs) | |||
} | |||
func ctxDriverStmtExec(ctx context.Context, si driver.Stmt, nvdargs []driver.NamedValue) (driver.Result, error) { | |||
if siCtx, is := si.(driver.StmtExecContext); is { | |||
return siCtx.ExecContext(ctx, nvdargs) | |||
} | |||
dargs, err := namedValueToValue(nvdargs) | |||
if err != nil { | |||
return nil, err | |||
} | |||
select { | |||
default: | |||
case <-ctx.Done(): | |||
return nil, ctx.Err() | |||
} | |||
return si.Exec(dargs) | |||
} | |||
func ctxDriverStmtQuery(ctx context.Context, si driver.Stmt, nvdargs []driver.NamedValue) (driver.Rows, error) { | |||
if siCtx, is := si.(driver.StmtQueryContext); is { | |||
return siCtx.QueryContext(ctx, nvdargs) | |||
} | |||
dargs, err := namedValueToValue(nvdargs) | |||
if err != nil { | |||
return nil, err | |||
} | |||
select { | |||
default: | |||
case <-ctx.Done(): | |||
return nil, ctx.Err() | |||
} | |||
return si.Query(dargs) | |||
} | |||
func namedValueToValue(named []driver.NamedValue) ([]driver.Value, error) { | |||
dargs := make([]driver.Value, len(named)) | |||
for n, param := range named { | |||
if len(param.Name) > 0 { | |||
return nil, errors.New("sql: driver does not support the use of Named Parameters") | |||
} | |||
dargs[n] = param.Value | |||
} | |||
return dargs, nil | |||
} |
@@ -0,0 +1,39 @@ | |||
/* | |||
* Licensed to the Apache Software Foundation (ASF) under one or more | |||
* contributor license agreements. See the NOTICE file distributed with | |||
* this work for additional information regarding copyright ownership. | |||
* The ASF licenses this file to You under the Apache License, Version 2.0 | |||
* (the "License"); you may not use this file except in compliance with | |||
* the License. You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
package util | |||
import "database/sql/driver" | |||
func NamedValueToValue(nvs []driver.NamedValue) []driver.Value { | |||
vs := make([]driver.Value, 0, len(nvs)) | |||
for _, nv := range nvs { | |||
vs = append(vs, nv.Value) | |||
} | |||
return vs | |||
} | |||
func ValueToNamedValue(vs []driver.Value) []driver.NamedValue { | |||
nvs := make([]driver.NamedValue, 0, len(vs)) | |||
for i, v := range vs { | |||
nvs = append(nvs, driver.NamedValue{ | |||
Value: v, | |||
Ordinal: i, | |||
}) | |||
} | |||
return nvs | |||
} |
@@ -0,0 +1,355 @@ | |||
/* | |||
* Licensed to the Apache Software Foundation (ASF) under one or more | |||
* contributor license agreements. See the NOTICE file distributed with | |||
* this work for additional information regarding copyright ownership. | |||
* The ASF licenses this file to You under the Apache License, Version 2.0 | |||
* (the "License"); you may not use this file except in compliance with | |||
* the License. You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
// Copyright 2011 The Go Authors. All rights reserved. | |||
// Use of this source code is governed by a BSD-style | |||
// license that can be found in the LICENSE file. | |||
// Package sql provides a generic interface around SQL (or SQL-like) | |||
// databases. | |||
// | |||
// The sql package must be used in conjunction with a database driver. | |||
// See https://golang.org/s/sqldrivers for a list of drivers. | |||
// | |||
// Drivers that do not support context cancellation will not return until | |||
// after the query is completed. | |||
// | |||
// For usage examples, see the wiki page at | |||
// https://golang.org/s/sqlwiki. | |||
package util | |||
import ( | |||
"context" | |||
"database/sql/driver" | |||
"errors" | |||
"fmt" | |||
"io" | |||
"sync" | |||
) | |||
// ScanRows is the result of a query. Its cursor starts before the first row | |||
// of the result set. Use Next to advance from row to row. | |||
type ScanRows struct { | |||
//dc *driverConn // owned; must call releaseConn when closed to release | |||
dc driver.Conn | |||
releaseConn func(error) | |||
rowsi driver.Rows | |||
cancel func() // called when ScanRows is closed, may be nil. | |||
//closeStmt *driverStmt // if non-nil, statement to Close on close | |||
// closemu prevents ScanRows from closing while there | |||
// is an active streaming result. It is held for read during non-close operations | |||
// and exclusively during close. | |||
// | |||
// closemu guards lasterr and closed. | |||
closemu sync.RWMutex | |||
closed bool | |||
lasterr error // non-nil only if closed is true | |||
// lastcols is only used in Scan, Next, and NextResultSet which are expected | |||
// not to be called concurrently. | |||
lastcols []driver.Value | |||
} | |||
func NewScanRows(rowsi driver.Rows) *ScanRows { | |||
return &ScanRows{rowsi: rowsi} | |||
} | |||
// lasterrOrErrLocked returns either lasterr or the provided err. | |||
// rs.closemu must be read-locked. | |||
func (rs *ScanRows) lasterrOrErrLocked(err error) error { | |||
if rs.lasterr != nil && rs.lasterr != io.EOF { | |||
return rs.lasterr | |||
} | |||
return err | |||
} | |||
// bypassRowsAwaitDone is only used for testing. | |||
// If true, it will not close the ScanRows automatically from the context. | |||
var bypassRowsAwaitDone = false | |||
func (rs *ScanRows) initContextClose(ctx, txctx context.Context) { | |||
if ctx.Done() == nil && (txctx == nil || txctx.Done() == nil) { | |||
return | |||
} | |||
if bypassRowsAwaitDone { | |||
return | |||
} | |||
ctx, rs.cancel = context.WithCancel(ctx) | |||
go rs.awaitDone(ctx, txctx) | |||
} | |||
// awaitDone blocks until either ctx or txctx is canceled. The ctx is provided | |||
// from the query context and is canceled when the query ScanRows is closed. | |||
// If the query was issued in a transaction, the transaction's context | |||
// is also provided in txctx to ensure ScanRows is closed if the Tx is closed. | |||
func (rs *ScanRows) awaitDone(ctx, txctx context.Context) { | |||
var txctxDone <-chan struct{} | |||
if txctx != nil { | |||
txctxDone = txctx.Done() | |||
} | |||
select { | |||
case <-ctx.Done(): | |||
case <-txctxDone: | |||
} | |||
rs.close(ctx.Err()) | |||
} | |||
// Next prepares the next result row for reading with the Scan method. It | |||
// returns true on success, or false if there is no next result row or an error | |||
// happened while preparing it. Err should be consulted to distinguish between | |||
// the two cases. | |||
// | |||
// Every call to Scan, even the first one, must be preceded by a call to Next. | |||
func (rs *ScanRows) Next() bool { | |||
var doClose, ok bool | |||
withLock(rs.closemu.RLocker(), func() { | |||
doClose, ok = rs.nextLocked() | |||
}) | |||
if doClose { | |||
rs.Close() | |||
} | |||
return ok | |||
} | |||
func (rs *ScanRows) nextLocked() (doClose, ok bool) { | |||
if rs.closed { | |||
return false, false | |||
} | |||
// Lock the driver connection before calling the driver interface | |||
// rowsi to prevent a Tx from rolling back the connection at the same time. | |||
//rs.dc.Lock() | |||
//defer rs.dc.Unlock() | |||
if rs.lastcols == nil { | |||
rs.lastcols = make([]driver.Value, len(rs.rowsi.Columns())) | |||
} | |||
rs.lasterr = rs.rowsi.Next(rs.lastcols) | |||
if rs.lasterr != nil { | |||
// Close the connection if there is a driver error. | |||
if rs.lasterr != io.EOF { | |||
return true, false | |||
} | |||
nextResultSet, ok := rs.rowsi.(driver.RowsNextResultSet) | |||
if !ok { | |||
return true, false | |||
} | |||
// The driver is at the end of the current result set. | |||
// Test to see if there is another result set after the current one. | |||
// Only close ScanRows if there is no further result sets to read. | |||
if !nextResultSet.HasNextResultSet() { | |||
doClose = true | |||
} | |||
return doClose, false | |||
} | |||
return false, true | |||
} | |||
// NextResultSet prepares the next result set for reading. It reports whether | |||
// there is further result sets, or false if there is no further result set | |||
// or if there is an error advancing to it. The Err method should be consulted | |||
// to distinguish between the two cases. | |||
// | |||
// After calling NextResultSet, the Next method should always be called before | |||
// scanning. If there are further result sets they may not have rows in the result | |||
// set. | |||
func (rs *ScanRows) NextResultSet() bool { | |||
var doClose bool | |||
defer func() { | |||
if doClose { | |||
rs.Close() | |||
} | |||
}() | |||
rs.closemu.RLock() | |||
defer rs.closemu.RUnlock() | |||
if rs.closed { | |||
return false | |||
} | |||
rs.lastcols = nil | |||
nextResultSet, ok := rs.rowsi.(driver.RowsNextResultSet) | |||
if !ok { | |||
doClose = true | |||
return false | |||
} | |||
// Lock the driver connection before calling the driver interface | |||
// rowsi to prevent a Tx from rolling back the connection at the same time. | |||
//rs.dc.Lock() | |||
//defer rs.dc.Unlock() | |||
rs.lasterr = nextResultSet.NextResultSet() | |||
if rs.lasterr != nil { | |||
doClose = true | |||
return false | |||
} | |||
return true | |||
} | |||
// Err returns the error, if any, that was encountered during iteration. | |||
// Err may be called after an explicit or implicit Close. | |||
func (rs *ScanRows) Err() error { | |||
rs.closemu.RLock() | |||
defer rs.closemu.RUnlock() | |||
return rs.lasterrOrErrLocked(nil) | |||
} | |||
// | |||
var errRowsClosed = errors.New("sql: ScanRows are closed") | |||
// Scan copies the columns in the current row into the values pointed | |||
// at by dest. The number of values in dest must be the same as the | |||
// number of columns in ScanRows. | |||
// | |||
// Scan converts columns read from the database into the following | |||
// common Go types and special types provided by the sql package: | |||
// | |||
// *string | |||
// *[]byte | |||
// *int, *int8, *int16, *int32, *int64 | |||
// *uint, *uint8, *uint16, *uint32, *uint64 | |||
// *bool | |||
// *float32, *float64 | |||
// *interface{} | |||
// *RawBytes | |||
// *ScanRows (cursor value) | |||
// any type implementing Scanner (see Scanner docs) | |||
// | |||
// In the most simple case, if the type of the value from the source | |||
// column is an integer, bool or string type T and dest is of type *T, | |||
// Scan simply assigns the value through the pointer. | |||
// | |||
// Scan also converts between string and numeric types, as long as no | |||
// information would be lost. While Scan stringifies all numbers | |||
// scanned from numeric database columns into *string, scans into | |||
// numeric types are checked for overflow. For example, a float64 with | |||
// value 300 or a string with value "300" can scan into a uint16, but | |||
// not into a uint8, though float64(255) or "255" can scan into a | |||
// uint8. One exception is that scans of some float64 numbers to | |||
// strings may lose information when stringifying. In general, scan | |||
// floating point columns into *float64. | |||
// | |||
// If a dest argument has type *[]byte, Scan saves in that argument a | |||
// copy of the corresponding data. The copy is owned by the caller and | |||
// can be modified and held indefinitely. The copy can be avoided by | |||
// using an argument of type *RawBytes instead; see the documentation | |||
// for RawBytes for restrictions on its use. | |||
// | |||
// If an argument has type *interface{}, Scan copies the value | |||
// provided by the underlying driver without conversion. When scanning | |||
// from a source value of type []byte to *interface{}, a copy of the | |||
// slice is made and the caller owns the result. | |||
// | |||
// Source values of type time.Time may be scanned into values of type | |||
// *time.Time, *interface{}, *string, or *[]byte. When converting to | |||
// the latter two, time.RFC3339Nano is used. | |||
// | |||
// Source values of type bool may be scanned into types *bool, | |||
// *interface{}, *string, *[]byte, or *RawBytes. | |||
// | |||
// For scanning into *bool, the source may be true, false, 1, 0, or | |||
// string inputs parseable by strconv.ParseBool. | |||
// | |||
// Scan can also convert a cursor returned from a query, such as | |||
// "select cursor(select * from my_table) from dual", into a | |||
// *ScanRows value that can itself be scanned from. The parent | |||
// select query will close any cursor *ScanRows if the parent *ScanRows is closed. | |||
// | |||
// If any of the first arguments implementing Scanner returns an error, | |||
// that error will be wrapped in the returned error | |||
func (rs *ScanRows) Scan(dest ...interface{}) error { | |||
rs.closemu.RLock() | |||
if rs.lasterr != nil && rs.lasterr != io.EOF { | |||
rs.closemu.RUnlock() | |||
return rs.lasterr | |||
} | |||
if rs.closed { | |||
err := rs.lasterrOrErrLocked(errRowsClosed) | |||
rs.closemu.RUnlock() | |||
return err | |||
} | |||
rs.closemu.RUnlock() | |||
if rs.lastcols == nil { | |||
return errors.New("sql: Scan called without calling Next") | |||
} | |||
if len(dest) != len(rs.lastcols) { | |||
return fmt.Errorf("sql: expected %d destination arguments in Scan, not %d", len(rs.lastcols), len(dest)) | |||
} | |||
for i, sv := range rs.lastcols { | |||
err := convertAssignRows(dest[i], sv, rs) | |||
if err != nil { | |||
return fmt.Errorf(`sql: Scan error on column index %d, name %q: %w`, i, rs.rowsi.Columns()[i], err) | |||
} | |||
} | |||
return nil | |||
} | |||
// rowsCloseHook returns a function so tests may install the | |||
// hook through a test only mutex. | |||
var rowsCloseHook = func() func(*ScanRows, *error) { return nil } | |||
// Close closes the ScanRows, preventing further enumeration. If Next is called | |||
// and returns false and there are no further result sets, | |||
// the ScanRows are closed automatically and it will suffice to check the | |||
// result of Err. Close is idempotent and does not affect the result of Err. | |||
func (rs *ScanRows) Close() error { | |||
//return rs.close(nil) | |||
return nil | |||
} | |||
func (rs *ScanRows) close(err error) error { | |||
rs.closemu.Lock() | |||
defer rs.closemu.Unlock() | |||
if rs.closed { | |||
return nil | |||
} | |||
rs.closed = true | |||
if rs.lasterr == nil { | |||
rs.lasterr = err | |||
} | |||
err = rs.rowsi.Close() | |||
//withLock(rs.dc, func() { | |||
// err = rs.rowsi.Close() | |||
//}) | |||
if fn := rowsCloseHook(); fn != nil { | |||
fn(rs, &err) | |||
} | |||
if rs.cancel != nil { | |||
rs.cancel() | |||
} | |||
//if rs.closeStmt != nil { | |||
// rs.closeStmt.Close() | |||
//} | |||
rs.releaseConn(err) | |||
return err | |||
} | |||
// withLock runs while holding lk. | |||
func withLock(lk sync.Locker, fn func()) { | |||
lk.Lock() | |||
defer lk.Unlock() // in case fn panics | |||
fn() | |||
} |
@@ -85,6 +85,6 @@ func (f *rmBranchRollbackProcessor) Process(ctx context.Context, rpcMessage mess | |||
log.Errorf("send branch rollback response error: {%#v}", err.Error()) | |||
return err | |||
} | |||
log.Infof("send branch rollback response success: xid %s, branchID %s, resourceID %s, applicationData %s", xid, branchID, resourceID, applicationData) | |||
log.Infof("send branch rollback response success: xid %s, branchID %d, resourceID %s, applicationData %s", xid, branchID, resourceID, applicationData) | |||
return nil | |||
} |
@@ -31,3 +31,16 @@ func SetUnexportedField(field reflect.Value, value interface{}) { | |||
func GetUnexportedField(field reflect.Value) interface{} { | |||
return reflect.NewAt(field.Type(), unsafe.Pointer(field.UnsafeAddr())).Elem().Interface() | |||
} | |||
func GetElemDataValue(data interface{}) interface{} { | |||
if data == nil { | |||
return data | |||
} | |||
value := reflect.ValueOf(data) | |||
kind := reflect.TypeOf(data).Kind() | |||
switch kind { | |||
case reflect.Ptr: | |||
return value.Elem().Interface() | |||
} | |||
return value.Interface() | |||
} |
@@ -0,0 +1,48 @@ | |||
/* | |||
* Licensed to the Apache Software Foundation (ASF) under one or more | |||
* contributor license agreements. See the NOTICE file distributed with | |||
* this work for additional information regarding copyright ownership. | |||
* The ASF licenses this file to You under the Apache License, Version 2.0 | |||
* (the "License"); you may not use this file except in compliance with | |||
* the License. You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
package reflectx | |||
import ( | |||
"testing" | |||
) | |||
func TestGetElemDataValue(t *testing.T) { | |||
var aa = 10 | |||
var bb = "name" | |||
var cc bool | |||
tests := []struct { | |||
name string | |||
args interface{} | |||
want interface{} | |||
}{ | |||
{name: "test1", args: aa, want: aa}, | |||
{name: "test2", args: &aa, want: aa}, | |||
{name: "test3", args: bb, want: bb}, | |||
{name: "test4", args: &bb, want: bb}, | |||
{name: "test5", args: cc, want: cc}, | |||
{name: "test6", args: &cc, want: cc}, | |||
} | |||
for _, tt := range tests { | |||
t.Run(tt.name, func(t *testing.T) { | |||
if got := GetElemDataValue(tt.args); got != tt.want { | |||
t.Errorf("GetElemDataValue() = %v, want %v", got, tt.want) | |||
} | |||
}) | |||
} | |||
} |