From 291f13de3e0a6a7a0ffc059aca062b0932c1162a Mon Sep 17 00:00:00 2001 From: charlie Date: Sat, 18 Feb 2023 14:38:25 +0800 Subject: [PATCH] reactor: select for update executor (#478) * reactor: select for update executor --- pkg/client/client.go | 4 +- pkg/client/config.go | 1 + pkg/client/config_test.go | 10 +- pkg/datasource/sql/exec/at/at_executor.go | 8 +- pkg/datasource/sql/exec/at/base_executor.go | 6 +- pkg/datasource/sql/exec/at/config.go | 22 ++ pkg/datasource/sql/exec/at/delete_executor.go | 6 +- pkg/datasource/sql/exec/at/insert_executor.go | 14 +- .../sql/exec/at/select_for_update_executor.go | 323 ++++++++++++++++++ .../at/select_for_update_executor_test.go | 155 +++++++++ pkg/datasource/sql/exec/at/update_executor.go | 7 +- pkg/datasource/sql/exec/config/config.go | 27 ++ .../sql/exec/select_for_update_executor.go | 4 +- pkg/datasource/sql/types/executor.go | 2 +- .../builder/mysql_update_undo_log_builder.go | 6 +- pkg/datasource/sql/util/sql.go | 17 +- pkg/rm/config.go | 8 +- pkg/rm/init.go | 2 + testdata/conf/seatago.yml | 4 +- 19 files changed, 588 insertions(+), 38 deletions(-) create mode 100644 pkg/datasource/sql/exec/at/config.go create mode 100644 pkg/datasource/sql/exec/at/select_for_update_executor.go create mode 100644 pkg/datasource/sql/exec/at/select_for_update_executor_test.go create mode 100644 pkg/datasource/sql/exec/config/config.go diff --git a/pkg/client/client.go b/pkg/client/client.go index 641ab9a7..b81c21cd 100644 --- a/pkg/client/client.go +++ b/pkg/client/client.go @@ -21,8 +21,8 @@ import ( "sync" "github.com/seata/seata-go/pkg/datasource" - at "github.com/seata/seata-go/pkg/datasource/sql" + "github.com/seata/seata-go/pkg/datasource/sql/exec/config" "github.com/seata/seata-go/pkg/integration" "github.com/seata/seata-go/pkg/remoting/getty" "github.com/seata/seata-go/pkg/remoting/processor/client" @@ -75,9 +75,11 @@ func initRmClient(cfg *Config) { log.Init() initRemoting(cfg) rm.InitRm(rm.RmConfig{ + Config: cfg.ClientConfig.RmConfig, ApplicationID: cfg.ApplicationID, TxServiceGroup: cfg.TxServiceGroup, }) + config.Init(cfg.ClientConfig.RmConfig.LockConfig) client.RegisterProcessor() integration.Init() tcc.InitTCC() diff --git a/pkg/client/config.go b/pkg/client/config.go index 3e099fa3..a7520bac 100644 --- a/pkg/client/config.go +++ b/pkg/client/config.go @@ -31,6 +31,7 @@ import ( "github.com/knadh/koanf/parsers/toml" "github.com/knadh/koanf/parsers/yaml" "github.com/knadh/koanf/providers/rawbytes" + "github.com/seata/seata-go/pkg/datasource/sql" "github.com/seata/seata-go/pkg/datasource/sql/undo" "github.com/seata/seata-go/pkg/remoting/getty" diff --git a/pkg/client/config_test.go b/pkg/client/config_test.go index b60a0c30..b54ae20f 100644 --- a/pkg/client/config_test.go +++ b/pkg/client/config_test.go @@ -62,8 +62,8 @@ func TestLoadPath(t *testing.T) { assert.Equal(t, false, cfg.ClientConfig.RmConfig.SagaRetryPersistModeUpdate) assert.Equal(t, -2147482648, cfg.ClientConfig.RmConfig.TccActionInterceptorOrder) assert.Equal(t, "druid", cfg.ClientConfig.RmConfig.SqlParserType) - assert.Equal(t, 10, cfg.ClientConfig.RmConfig.LockConfig.RetryInterval) - assert.Equal(t, time.Second*30, cfg.ClientConfig.RmConfig.LockConfig.RetryTimes) + assert.Equal(t, 30*time.Second, cfg.ClientConfig.RmConfig.LockConfig.RetryInterval) + assert.Equal(t, 10, cfg.ClientConfig.RmConfig.LockConfig.RetryTimes) assert.Equal(t, true, cfg.ClientConfig.RmConfig.LockConfig.RetryPolicyBranchRollbackOnConflict) assert.NotNil(t, cfg.ClientConfig.UndoConfig) @@ -117,7 +117,7 @@ func TestLoadPath(t *testing.T) { } func TestLoadJson(t *testing.T) { - confJson := `{"enabled":false,"application-id":"application_test","tx-service-group":"default_tx_group","access-key":"test","secret-key":"test","enable-auto-data-source-proxy":false,"data-source-proxy-mode":"AT","client":{"rm":{"async-commit-buffer-limit":10000,"report-retry-count":5,"table-meta-check-enable":false,"report-success-enable":false,"saga-branch-register-enable":false,"saga-json-parser":"fastjson","saga-retry-persist-mode-update":false,"saga-compensate-persist-mode-update":false,"tcc-action-interceptor-order":-2147482648,"sql-parser-type":"druid","lock":{"retry-interval":10,"retry-times":"30s","retry-policy-branch-rollback-on-conflict":true}},"tm":{"commit-retry-count":5,"rollback-retry-count":5,"default-global-transaction-timeout":"60s","degrade-check":false,"degrade-check-period":2000,"degrade-check-allow-times":"10s","interceptor-order":-2147482648},"undo":{"data-validation":false,"log-serialization":"jackson222","only-care-update-columns":false,"log-table":"undo_log333","compress":{"enable":false,"type":"zip111","threshold":"128k"}}},"tcc":{"fence":{"log-table-name":"tcc_fence_log_test2","clean-period":80000000000}},"getty":{"reconnect-interval":1,"connection-num":10,"session":{"compress-encoding":true,"tcp-no-delay":false,"tcp-keep-alive":false,"keep-alive-period":"120s","tcp-r-buf-size":261120,"tcp-w-buf-size":32768,"tcp-read-timeout":"2s","tcp-write-timeout":"8s","wait-timeout":"2s","max-msg-len":261120,"session-name":"client_test","cron-period":"2s"}},"transport":{"shutdown":{"wait":"3s"},"type":"TCP","server":"NIO","heartbeat":true,"serialization":"seata","compressor":"none"," enable-tm-client-batch-send-request":false,"enable-rm-client-batch-send-request":true,"rpc-rm-request-timeout":"30s","rpc-tm-request-timeout":"30s"},"service":{"enable-degrade":true,"disable-global-transaction":true,"vgroup-mapping":{"default_tx_group":"default_test"},"grouplist":{"default":"127.0.0.1:8092"}}}` + confJson := `{"enabled":false,"application-id":"application_test","tx-service-group":"default_tx_group","access-key":"test","secret-key":"test","enable-auto-data-source-proxy":false,"data-source-proxy-mode":"AT","client":{"rm":{"async-commit-buffer-limit":10000,"report-retry-count":5,"table-meta-check-enable":false,"report-success-enable":false,"saga-branch-register-enable":false,"saga-json-parser":"fastjson","saga-retry-persist-mode-update":false,"saga-compensate-persist-mode-update":false,"tcc-action-interceptor-order":-2147482648,"sql-parser-type":"druid","lock":{"retry-interval":"30s","retry-times":10,"retry-policy-branch-rollback-on-conflict":true}},"tm":{"commit-retry-count":5,"rollback-retry-count":5,"default-global-transaction-timeout":"60s","degrade-check":false,"degrade-check-period":2000,"degrade-check-allow-times":"10s","interceptor-order":-2147482648},"undo":{"data-validation":false,"log-serialization":"jackson222","only-care-update-columns":false,"log-table":"undo_log333","compress":{"enable":false,"type":"zip111","threshold":"128k"}}},"tcc":{"fence":{"log-table-name":"tcc_fence_log_test2","clean-period":80000000000}},"getty":{"reconnect-interval":1,"connection-num":10,"session":{"compress-encoding":true,"tcp-no-delay":false,"tcp-keep-alive":false,"keep-alive-period":"120s","tcp-r-buf-size":261120,"tcp-w-buf-size":32768,"tcp-read-timeout":"2s","tcp-write-timeout":"8s","wait-timeout":"2s","max-msg-len":261120,"session-name":"client_test","cron-period":"2s"}},"transport":{"shutdown":{"wait":"3s"},"type":"TCP","server":"NIO","heartbeat":true,"serialization":"seata","compressor":"none"," enable-tm-client-batch-send-request":false,"enable-rm-client-batch-send-request":true,"rpc-rm-request-timeout":"30s","rpc-tm-request-timeout":"30s"},"service":{"enable-degrade":true,"disable-global-transaction":true,"vgroup-mapping":{"default_tx_group":"default_test"},"grouplist":{"default":"127.0.0.1:8092"}}}` cfg := LoadJson([]byte(confJson)) assert.NotNil(t, cfg) assert.Equal(t, false, cfg.Enabled) @@ -138,8 +138,8 @@ func TestLoadJson(t *testing.T) { assert.Equal(t, false, cfg.ClientConfig.RmConfig.SagaRetryPersistModeUpdate) assert.Equal(t, -2147482648, cfg.ClientConfig.RmConfig.TccActionInterceptorOrder) assert.Equal(t, "druid", cfg.ClientConfig.RmConfig.SqlParserType) - assert.Equal(t, 10, cfg.ClientConfig.RmConfig.LockConfig.RetryInterval) - assert.Equal(t, time.Second*30, cfg.ClientConfig.RmConfig.LockConfig.RetryTimes) + assert.Equal(t, 30*time.Second, cfg.ClientConfig.RmConfig.LockConfig.RetryInterval) + assert.Equal(t, 10, cfg.ClientConfig.RmConfig.LockConfig.RetryTimes) assert.Equal(t, true, cfg.ClientConfig.RmConfig.LockConfig.RetryPolicyBranchRollbackOnConflict) assert.NotNil(t, cfg.ClientConfig.UndoConfig) diff --git a/pkg/datasource/sql/exec/at/at_executor.go b/pkg/datasource/sql/exec/at/at_executor.go index adf39eaf..e32c1324 100644 --- a/pkg/datasource/sql/exec/at/at_executor.go +++ b/pkg/datasource/sql/exec/at/at_executor.go @@ -20,12 +20,11 @@ package at import ( "context" - "github.com/seata/seata-go/pkg/datasource/sql/util" - "github.com/seata/seata-go/pkg/tm" - "github.com/seata/seata-go/pkg/datasource/sql/exec" "github.com/seata/seata-go/pkg/datasource/sql/parser" "github.com/seata/seata-go/pkg/datasource/sql/types" + "github.com/seata/seata-go/pkg/datasource/sql/util" + "github.com/seata/seata-go/pkg/tm" ) func Init() { @@ -63,7 +62,8 @@ func (e *ATExecutor) ExecWithNamedValue(ctx context.Context, execCtx *types.Exec exec = NewUpdateExecutor(parser, execCtx, e.hooks) case types.SQLTypeDelete: exec = NewDeleteExecutor(parser, execCtx, e.hooks) - //case types.SQLTypeSelectForUpdate: + case types.SQLTypeSelectForUpdate: + exec = NewSelectForUpdateExecutor(parser, execCtx, e.hooks) //case types.SQLTypeMultiDelete: //case types.SQLTypeMultiUpdate: default: diff --git a/pkg/datasource/sql/exec/at/base_executor.go b/pkg/datasource/sql/exec/at/base_executor.go index bdbf8360..0f1e854b 100644 --- a/pkg/datasource/sql/exec/at/base_executor.go +++ b/pkg/datasource/sql/exec/at/base_executor.go @@ -54,10 +54,10 @@ func (b *baseExecutor) afterHooks(ctx context.Context, execCtx *types.ExecContex // todo to use ColumnInfo get slice func (*baseExecutor) GetScanSlice(columnNames []string, tableMeta *types.TableMeta) []interface{} { scanSlice := make([]interface{}, 0, len(columnNames)) - for _, columnNmae := range columnNames { + for _, columnName := range columnNames { var ( - // 从metData获取该列的元信息 - columnMeta = tableMeta.Columns[columnNmae] + // get from metaData from this column + columnMeta = tableMeta.Columns[columnName] ) switch strings.ToUpper(columnMeta.DatabaseTypeString) { case "VARCHAR", "NVARCHAR", "VARCHAR2", "CHAR", "TEXT", "JSON", "TINYTEXT": diff --git a/pkg/datasource/sql/exec/at/config.go b/pkg/datasource/sql/exec/at/config.go new file mode 100644 index 00000000..69df44de --- /dev/null +++ b/pkg/datasource/sql/exec/at/config.go @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package at + +import "github.com/seata/seata-go/pkg/rm" + +var LockConfig rm.LockConfig diff --git a/pkg/datasource/sql/exec/at/delete_executor.go b/pkg/datasource/sql/exec/at/delete_executor.go index b32cadfe..179b1f1e 100644 --- a/pkg/datasource/sql/exec/at/delete_executor.go +++ b/pkg/datasource/sql/exec/at/delete_executor.go @@ -101,8 +101,9 @@ func (d *deleteExecutor) beforeImage(ctx context.Context) (*types.RecordImage, e return nil, fmt.Errorf("invalid conn") } - tableName, _ := d.parserCtx.GteTableName() + tableName, _ := d.parserCtx.GetTableName() metaData, err := datasource.GetTableCache(types.DBTypeMySQL).GetTableMeta(ctx, d.execContext.DBName, tableName) + if err != nil { return nil, err } @@ -155,10 +156,11 @@ func (d *deleteExecutor) buildBeforeImageSQL(query string, args []driver.NamedVa // afterImage build after image func (d *deleteExecutor) afterImage(ctx context.Context) (*types.RecordImage, error) { - tableName, _ := d.parserCtx.GteTableName() + tableName, _ := d.parserCtx.GetTableName() metaData, err := datasource.GetTableCache(types.DBTypeMySQL).GetTableMeta(ctx, d.execContext.DBName, tableName) if err != nil { return nil, err } + return types.NewEmptyRecordImage(metaData, types.SQLTypeDelete), nil } diff --git a/pkg/datasource/sql/exec/at/insert_executor.go b/pkg/datasource/sql/exec/at/insert_executor.go index 8cc26d6a..e27ae877 100644 --- a/pkg/datasource/sql/exec/at/insert_executor.go +++ b/pkg/datasource/sql/exec/at/insert_executor.go @@ -24,6 +24,7 @@ import ( "strings" "github.com/arana-db/parser/ast" + "github.com/seata/seata-go/pkg/datasource/sql/datasource" "github.com/seata/seata-go/pkg/datasource/sql/exec" "github.com/seata/seata-go/pkg/datasource/sql/types" @@ -82,7 +83,7 @@ func (i *insertExecutor) ExecContext(ctx context.Context, f exec.CallbackWithNam // beforeImage build before image func (i *insertExecutor) beforeImage(ctx context.Context) (*types.RecordImage, error) { - tableName, _ := i.parserCtx.GteTableName() + tableName, _ := i.parserCtx.GetTableName() metaData, err := datasource.GetTableCache(types.DBTypeMySQL).GetTableMeta(ctx, i.execContext.DBName, tableName) if err != nil { return nil, err @@ -96,7 +97,7 @@ func (i *insertExecutor) afterImage(ctx context.Context) (*types.RecordImage, er return nil, nil } - tableName, _ := i.parserCtx.GteTableName() + tableName, _ := i.parserCtx.GetTableName() metaData, err := datasource.GetTableCache(types.DBTypeMySQL).GetTableMeta(ctx, i.execContext.DBName, tableName) if err != nil { return nil, err @@ -141,7 +142,7 @@ func (i *insertExecutor) afterImage(ctx context.Context) (*types.RecordImage, er // buildAfterImageSQL build select sql from insert sql func (i *insertExecutor) buildAfterImageSQL(ctx context.Context) (string, []driver.NamedValue, error) { // get all pk value - tableName, _ := i.parserCtx.GteTableName() + tableName, _ := i.parserCtx.GetTableName() meta, err := datasource.GetTableCache(types.DBTypeMySQL).GetTableMeta(ctx, i.execContext.DBName, tableName) if err != nil { @@ -414,7 +415,8 @@ func (i *insertExecutor) getPkValuesByColumn(ctx context.Context, execCtx *types if !i.isAstStmtValid() { return nil, nil } - tableName, _ := i.parserCtx.GteTableName() + + tableName, _ := i.parserCtx.GetTableName() meta, err := datasource.GetTableCache(types.DBTypeMySQL).GetTableMeta(ctx, i.execContext.DBName, tableName) if err != nil { return nil, err @@ -452,11 +454,13 @@ func (i *insertExecutor) getPkValuesByAuto(ctx context.Context, execCtx *types.E if !i.isAstStmtValid() { return nil, nil } - tableName, _ := i.parserCtx.GteTableName() + + tableName, _ := i.parserCtx.GetTableName() metaData, err := datasource.GetTableCache(types.DBTypeMySQL).GetTableMeta(ctx, i.execContext.DBName, tableName) if err != nil { return nil, err } + pkValuesMap := make(map[string][]interface{}) pkMetaMap := metaData.GetPrimaryKeyMap() if len(pkMetaMap) == 0 { diff --git a/pkg/datasource/sql/exec/at/select_for_update_executor.go b/pkg/datasource/sql/exec/at/select_for_update_executor.go new file mode 100644 index 00000000..7031adc3 --- /dev/null +++ b/pkg/datasource/sql/exec/at/select_for_update_executor.go @@ -0,0 +1,323 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package at + +import ( + "bytes" + "context" + "database/sql/driver" + "fmt" + "io" + "reflect" + "strconv" + "time" + + "github.com/arana-db/parser/ast" + "github.com/arana-db/parser/format" + "github.com/arana-db/parser/model" + + "github.com/seata/seata-go/pkg/datasource/sql/datasource" + "github.com/seata/seata-go/pkg/datasource/sql/exec" + "github.com/seata/seata-go/pkg/datasource/sql/types" + "github.com/seata/seata-go/pkg/datasource/sql/util" + "github.com/seata/seata-go/pkg/protocol/branch" + "github.com/seata/seata-go/pkg/rm" + "github.com/seata/seata-go/pkg/util/backoff" + seatabytes "github.com/seata/seata-go/pkg/util/bytes" + "github.com/seata/seata-go/pkg/util/log" +) + +type selectForUpdateExecutor struct { + baseExecutor + + parserCtx *types.ParseContext + execContext *types.ExecContext + cfg *rm.LockConfig + tx driver.Tx + tableName string + selectPKSQL string + metaData *types.TableMeta + savepointName string +} + +func NewSelectForUpdateExecutor(parserCtx *types.ParseContext, execContext *types.ExecContext, hooks []exec.SQLHook) executor { + return &selectForUpdateExecutor{ + baseExecutor: baseExecutor{ + hooks: hooks, + }, + parserCtx: parserCtx, + execContext: execContext, + cfg: &LockConfig, + } +} + +func (s *selectForUpdateExecutor) ExecContext(ctx context.Context, f exec.CallbackWithNamedValue) (types.ExecResult, error) { + s.beforeHooks(ctx, s.execContext) + defer func() { + s.afterHooks(ctx, s.execContext) + }() + + if !s.execContext.IsInGlobalTransaction && !s.execContext.IsRequireGlobalLock { + return f(ctx, s.execContext.Query, s.execContext.NamedValues) + } + + var ( + result types.ExecResult + originalAutoCommit = s.execContext.IsAutoCommit + err error + ) + + if s.tableName, err = s.execContext.ParseContext.GetTableName(); err != nil { + return nil, err + } + + if s.metaData, err = datasource.GetTableCache(types.DBTypeMySQL).GetTableMeta(ctx, s.execContext.DBName, s.tableName); err != nil { + return nil, err + } + + // build query primary key sql + if s.selectPKSQL, err = s.buildSelectPKSQL(s.execContext.ParseContext.SelectStmt, s.metaData); err != nil { + return nil, err + } + + bf := backoff.New(ctx, backoff.Config{ + MaxRetries: s.cfg.RetryTimes, + MinBackoff: s.cfg.RetryInterval, + MaxBackoff: s.cfg.RetryInterval, + }) + + for bf.Ongoing() { + if result, err = s.doExecContext(ctx, f); err == nil { + break + } + + // if there is an err in doExecContext, we should rollback first + if s.savepointName != "" { + if _, err := s.exec(fmt.Sprintf("rollback to %s;", s.savepointName), nil); err != nil { + log.Error(err) + return nil, err + } + } else { + if err = s.tx.Rollback(); err != nil { + return nil, err + } + } + + bf.Wait() + } + + if bf.Err() != nil { + lastErr := fmt.Errorf("lastErr %v, backoff error: %v", err, bf.Err()) + log.Warnf("select for update executor failed: %v", lastErr) + return nil, lastErr + } + + if originalAutoCommit { + if err = s.tx.Commit(); err != nil { + return nil, err + } + s.execContext.IsAutoCommit = true + } + + return result, nil +} + +func (s *selectForUpdateExecutor) doExecContext(ctx context.Context, f exec.CallbackWithNamedValue) (types.ExecResult, error) { + var ( + now = time.Now().Unix() + result types.ExecResult + originalAutoCommit = s.execContext.IsAutoCommit + err error + ) + + if originalAutoCommit { + // In order to hold the local db lock during global lock checking + // set auto commit value to false first if original auto commit was true + s.execContext.IsAutoCommit = false + s.tx, err = s.execContext.Conn.Begin() + if err != nil { + return nil, err + } + } else if s.execContext.IsSupportsSavepoints { + // In order to release the local db lock when global lock conflict + // create a save point if original auto commit was false, then use the save point here to release db + // lock during global lock checking if necessary + if _, err = s.exec(fmt.Sprintf("savepoint %d;", now), nil); err != nil { + return nil, err + } + s.savepointName = strconv.FormatInt(now, 10) + } else { + return nil, fmt.Errorf("not support savepoint. please check your db version") + } + + // execute business SQL, try to get local lock + result, err = f(ctx, s.execContext.Query, s.execContext.NamedValues) + if err != nil { + return nil, err + } + + // query primary key values + var lockKey string + if _, err = s.exec(s.selectPKSQL, func(rows driver.Rows) { + lockKey = s.buildLockKey(rows, s.metaData) + }); err != nil { + return nil, err + } + if lockKey == "" { + return nil, nil + } + + // check global lock + lockable, err := datasource.GetDataSourceManager(branch.BranchTypeAT).LockQuery(ctx, rm.LockQueryParam{ + Xid: s.execContext.TxCtx.XID, + BranchType: branch.BranchTypeAT, + ResourceId: s.execContext.TxCtx.ResourceID, + LockKeys: lockKey, + }) + if err != nil { + return nil, err + } + + if !lockable { + return nil, fmt.Errorf("get lock failed, lockKey: %v", lockKey) + } + + return result, nil +} + +// buildSelectSQLByUpdate build select sql from update sql +func (s *selectForUpdateExecutor) buildSelectPKSQL(stmt *ast.SelectStmt, meta *types.TableMeta) (string, error) { + pks := meta.GetPrimaryKeyOnlyName() + if len(pks) == 0 { + return "", fmt.Errorf("%s needs to contain the primary key.", meta.TableName) + } + + var fields []*ast.SelectField + for _, column := range pks { + fields = append(fields, &ast.SelectField{ + Expr: &ast.ColumnNameExpr{ + Name: &ast.ColumnName{ + Name: model.CIStr{ + O: column, + L: column, + }, + }, + }, + }) + } + + selStmt := ast.SelectStmt{ + SelectStmtOpts: &ast.SelectStmtOpts{}, + From: stmt.From, + Where: stmt.Where, + Fields: &ast.FieldList{Fields: fields}, + OrderBy: stmt.OrderBy, + Limit: stmt.Limit, + TableHints: stmt.TableHints, + LockInfo: &ast.SelectLockInfo{ + LockType: ast.SelectLockForUpdate, + }, + } + + b := seatabytes.NewByteBuffer([]byte{}) + selStmt.Restore(format.NewRestoreCtx(format.RestoreKeyWordUppercase, b)) + sql := string(b.Bytes()) + log.Infof("build select sql by update sourceQuery, sql {}", sql) + + return sql, nil +} + +// the string as local key. the local key example(multi pk): "t_user:1_a,2_b" +func (s *selectForUpdateExecutor) buildLockKey(rows driver.Rows, meta *types.TableMeta) string { + var ( + lockKeys bytes.Buffer + idx int + columnNames []string + ) + lockKeys.WriteString(meta.TableName) + lockKeys.WriteString(":") + + columnNames = meta.GetPrimaryKeyOnlyName() + sqlRows := util.NewScanRows(rows) + for sqlRows.Next() { + ss := s.GetScanSlice(columnNames, meta) + if err := sqlRows.Scan(ss...); err != nil { + if err == io.EOF { + break + } + return "" + } + + if idx > 0 { + lockKeys.WriteString(",") + } + idx++ + + for i, value := range ss { + if i > 0 { + lockKeys.WriteString("_") + } + + // if the value is NullInt64 or NullString etc. then call its Value() + ty := reflect.TypeOf(value) + if f, ok := ty.MethodByName("Value"); ok { + res := f.Func.Call([]reflect.Value{reflect.ValueOf(value)}) + if res[1].IsNil() { // res[0]: driver.Value, [1]: error + lockKeys.WriteString(res[0].Elem().String()) + } + continue + } + + // if the value type is *int64, *string etc. then get the true value + lockKeys.WriteString(fmt.Sprintf("%v", reflect.ValueOf(value).Elem())) + } + } + return lockKeys.String() +} + +func (s *selectForUpdateExecutor) exec(sql string, f func(rows driver.Rows)) (driver.Rows, error) { + var ( + querierContext driver.QueryerContext + querier driver.Queryer + ok bool + ) + if querierContext, ok = s.execContext.Conn.(driver.QueryerContext); !ok { + err := fmt.Sprintf("invalid conn, can't convert %v to driver.QueryerContext", s.execContext.Conn) + if querier, ok = s.execContext.Conn.(driver.Queryer); !ok { + err = err + fmt.Sprintf(", also can't convert %v to drvier.Queryer", s.execContext.Conn) + return nil, fmt.Errorf(err) + } + } + + rows, err := util.CtxDriverQuery(context.TODO(), querierContext, querier, sql, nil) + defer func() { + if rows != nil { + _ = rows.Close() + } + }() + + if err != nil { + return nil, err + } + + if f != nil { + f(rows) + } + + return nil, nil +} diff --git a/pkg/datasource/sql/exec/at/select_for_update_executor_test.go b/pkg/datasource/sql/exec/at/select_for_update_executor_test.go new file mode 100644 index 00000000..491b8e6e --- /dev/null +++ b/pkg/datasource/sql/exec/at/select_for_update_executor_test.go @@ -0,0 +1,155 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package at + +import ( + "database/sql/driver" + "io" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/seata/seata-go/pkg/datasource/sql/parser" + "github.com/seata/seata-go/pkg/datasource/sql/types" +) + +var ( + index = 0 + rowValues = [][]interface{}{ + {1, "oid11"}, + {2, "oid22"}, + {3, "oid33"}, + } +) + +func TestBuildSelectPKSQL(t *testing.T) { + e := selectForUpdateExecutor{} + sql := "select name, order_id from t_user where age > ? for update" + + ctx, err := parser.DoParser(sql) + + metaData := types.TableMeta{ + TableName: "t_user", + Indexs: map[string]types.IndexMeta{ + "id": { + IType: types.IndexTypePrimaryKey, + ColumnName: "id", + Columns: []types.ColumnMeta{ + {ColumnName: "id"}, + }, + }, + "order_id": { + IType: types.IndexTypePrimaryKey, + ColumnName: "order_id", + Columns: []types.ColumnMeta{ + {ColumnName: "order_id"}, + }, + }, + "age": { + IType: types.IndexTypeNull, + ColumnName: "age", + Columns: []types.ColumnMeta{ + {ColumnName: "age"}, + }, + }, + }, + } + + assert.Nil(t, err) + assert.NotNil(t, ctx) + assert.NotNil(t, ctx.SelectStmt) + + selSQL, err := e.buildSelectPKSQL(ctx.SelectStmt, &metaData) + assert.Nil(t, err) + equal := "SELECT SQL_NO_CACHE order_id,id FROM t_user WHERE age>? FOR UPDATE" == selSQL || + "SELECT SQL_NO_CACHE id,order_id FROM t_user WHERE age>? FOR UPDATE" == selSQL + assert.Equal(t, equal, true) +} + +func TestBuildLockKey(t *testing.T) { + e := selectForUpdateExecutor{} + + metaData := types.TableMeta{ + TableName: "t_user", + Indexs: map[string]types.IndexMeta{ + "id": { + IType: types.IndexTypePrimaryKey, + ColumnName: "id", + Columns: []types.ColumnMeta{ + {ColumnName: "id"}, + }, + }, + "order_id": { + IType: types.IndexTypePrimaryKey, + ColumnName: "order_id", + Columns: []types.ColumnMeta{ + {ColumnName: "order_id"}, + }, + }, + "age": { + IType: types.IndexTypeNull, + ColumnName: "age", + Columns: []types.ColumnMeta{ + {ColumnName: "age"}, + }, + }, + }, + Columns: map[string]types.ColumnMeta{ + "id": { + DatabaseTypeString: "INT", + ColumnName: "id", + }, + "order_id": { + DatabaseTypeString: "VARCHAR", + ColumnName: "order_id", + }, + "age": { + DatabaseTypeString: "INT", + ColumnName: "age", + }, + }, + } + rows := mockRows{} + lockKey := e.buildLockKey(rows, &metaData) + assert.Equal(t, "t_user:1_oid11,2_oid22,3_oid33", lockKey) +} + +type mockRows struct{} + +func (m mockRows) Columns() []string { + return []string{"id", "order_id"} +} + +func (m mockRows) Close() error { + //TODO implement me + panic("implement me") +} + +func (m mockRows) Next(dest []driver.Value) error { + if index == len(rowValues) { + return io.EOF + } + + if len(dest) >= 1 { + dest[0] = rowValues[index][0] + dest[1] = rowValues[index][1] + index++ + } + + return nil +} diff --git a/pkg/datasource/sql/exec/at/update_executor.go b/pkg/datasource/sql/exec/at/update_executor.go index b772dd8c..699fed6c 100644 --- a/pkg/datasource/sql/exec/at/update_executor.go +++ b/pkg/datasource/sql/exec/at/update_executor.go @@ -26,6 +26,7 @@ import ( "github.com/arana-db/parser/ast" "github.com/arana-db/parser/format" "github.com/arana-db/parser/model" + "github.com/seata/seata-go/pkg/datasource/sql/datasource" "github.com/seata/seata-go/pkg/datasource/sql/exec" "github.com/seata/seata-go/pkg/datasource/sql/types" @@ -94,7 +95,7 @@ func (u *updateExecutor) beforeImage(ctx context.Context) (*types.RecordImage, e return nil, err } - tableName, _ := u.parserCtx.GteTableName() + tableName, _ := u.parserCtx.GetTableName() metaData, err := datasource.GetTableCache(types.DBTypeMySQL).GetTableMeta(ctx, u.execContext.DBName, tableName) if err != nil { return nil, err @@ -143,7 +144,7 @@ func (u *updateExecutor) afterImage(ctx context.Context, beforeImage types.Recor return &types.RecordImage{}, nil } - tableName, _ := u.parserCtx.GteTableName() + tableName, _ := u.parserCtx.GetTableName() metaData, err := datasource.GetTableCache(types.DBTypeMySQL).GetTableMeta(ctx, u.execContext.DBName, tableName) if err != nil { return nil, err @@ -230,7 +231,7 @@ func (u *updateExecutor) buildBeforeImageSQL(ctx context.Context, args []driver. } // select indexes columns - tableName, _ := u.parserCtx.GteTableName() + tableName, _ := u.parserCtx.GetTableName() metaData, err := datasource.GetTableCache(types.DBTypeMySQL).GetTableMeta(ctx, u.execContext.DBName, tableName) if err != nil { return "", nil, err diff --git a/pkg/datasource/sql/exec/config/config.go b/pkg/datasource/sql/exec/config/config.go new file mode 100644 index 00000000..43707373 --- /dev/null +++ b/pkg/datasource/sql/exec/config/config.go @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package config + +import ( + "github.com/seata/seata-go/pkg/datasource/sql/exec/at" + "github.com/seata/seata-go/pkg/rm" +) + +func Init(config rm.LockConfig) { + at.LockConfig = config +} diff --git a/pkg/datasource/sql/exec/select_for_update_executor.go b/pkg/datasource/sql/exec/select_for_update_executor.go index 961d0733..d334ba7b 100644 --- a/pkg/datasource/sql/exec/select_for_update_executor.go +++ b/pkg/datasource/sql/exec/select_for_update_executor.go @@ -63,7 +63,7 @@ func (s SelectForUpdateExecutor) ExecWithNamedValue(ctx context.Context, execCtx originalAutoCommit = execCtx.IsAutoCommit ) - table, err := execCtx.ParseContext.GteTableName() + table, err := execCtx.ParseContext.GetTableName() if err != nil { return nil, err } @@ -177,7 +177,7 @@ func (s SelectForUpdateExecutor) ExecWithValue(ctx context.Context, execCtx *typ originalAutoCommit = execCtx.IsAutoCommit ) - table, err := execCtx.ParseContext.GteTableName() + table, err := execCtx.ParseContext.GetTableName() if err != nil { return nil, err } diff --git a/pkg/datasource/sql/types/executor.go b/pkg/datasource/sql/types/executor.go index 2a4b5281..e06abd9c 100644 --- a/pkg/datasource/sql/types/executor.go +++ b/pkg/datasource/sql/types/executor.go @@ -55,7 +55,7 @@ func (p *ParseContext) HasValidStmt() bool { return p.InsertStmt != nil || p.UpdateStmt != nil || p.DeleteStmt != nil } -func (p *ParseContext) GteTableName() (string, error) { +func (p *ParseContext) GetTableName() (string, error) { var table *ast.TableRefsClause if p.InsertStmt != nil { diff --git a/pkg/datasource/sql/undo/builder/mysql_update_undo_log_builder.go b/pkg/datasource/sql/undo/builder/mysql_update_undo_log_builder.go index a85bbe04..7c971519 100644 --- a/pkg/datasource/sql/undo/builder/mysql_update_undo_log_builder.go +++ b/pkg/datasource/sql/undo/builder/mysql_update_undo_log_builder.go @@ -67,7 +67,7 @@ func (u *MySQLUpdateUndoLogBuilder) BeforeImage(ctx context.Context, execCtx *ty return nil, err } - tableName, _ := execCtx.ParseContext.GteTableName() + tableName, _ := execCtx.ParseContext.GetTableName() metaData, err := datasource.GetTableCache(types.DBTypeMySQL).GetTableMeta(ctx, execCtx.DBName, tableName) if err != nil { return nil, err @@ -114,7 +114,7 @@ func (u *MySQLUpdateUndoLogBuilder) AfterImage(ctx context.Context, execCtx *typ beforeImage = beforeImages[0] } - tableName, _ := execCtx.ParseContext.GteTableName() + tableName, _ := execCtx.ParseContext.GetTableName() metaData, err := datasource.GetTableCache(types.DBTypeMySQL).GetTableMeta(ctx, execCtx.DBName, tableName) if err != nil { return nil, err @@ -188,7 +188,7 @@ func (u *MySQLUpdateUndoLogBuilder) buildBeforeImageSQL(ctx context.Context, exe } // select indexes columns - tableName, _ := execCtx.ParseContext.GteTableName() + tableName, _ := execCtx.ParseContext.GetTableName() metaData, err := datasource.GetTableCache(types.DBTypeMySQL).GetTableMeta(ctx, execCtx.DBName, tableName) if err != nil { return "", nil, err diff --git a/pkg/datasource/sql/util/sql.go b/pkg/datasource/sql/util/sql.go index d8163de6..260c664d 100644 --- a/pkg/datasource/sql/util/sql.go +++ b/pkg/datasource/sql/util/sql.go @@ -38,6 +38,7 @@ import ( "errors" "fmt" "io" + "reflect" "sync" ) @@ -294,9 +295,19 @@ func (rs *ScanRows) Scan(dest ...interface{}) error { return fmt.Errorf("sql: expected %d destination arguments in Scan, not %d", len(rs.lastcols), len(dest)) } for i, sv := range rs.lastcols { - err := convertAssignRows(dest[i], sv, rs) - if err != nil { - return fmt.Errorf(`sql: Scan error on column index %d, name %q: %w`, i, rs.rowsi.Columns()[i], err) + // the type of dest may be NullString, NullInt64, int64, etc, we should call its Scan() + ty := reflect.TypeOf(dest[i]) + fn, ok := ty.MethodByName("Scan") + if !ok { + err := convertAssignRows(dest[i], sv, rs) + if err != nil { + return fmt.Errorf(`sql: Scan error on column index %d, name %q: %w`, i, rs.rowsi.Columns()[i], err) + } + } else { + res := fn.Func.Call([]reflect.Value{reflect.ValueOf(dest[i]), reflect.ValueOf(sv)}) + if len(res) > 0 && !res[0].IsNil() { + return fmt.Errorf(`sql: Scan error on column index %d, name %q: %v`, i, rs.rowsi.Columns()[i], res[0].Elem().String()) + } } } return nil diff --git a/pkg/rm/config.go b/pkg/rm/config.go index 8d0bbcbb..ada232e3 100644 --- a/pkg/rm/config.go +++ b/pkg/rm/config.go @@ -37,8 +37,8 @@ type Config struct { } type LockConfig struct { - RetryInterval int `yaml:"retry-interval" json:"retry-interval,omitempty" koanf:"retry-interval"` - RetryTimes time.Duration `yaml:"retry-times" json:"retry-times,omitempty" koanf:"retry-times"` + RetryInterval time.Duration `yaml:"retry-interval" json:"retry-interval,omitempty" koanf:"retry-interval"` + RetryTimes int `yaml:"retry-times" json:"retry-times,omitempty" koanf:"retry-times"` RetryPolicyBranchRollbackOnConflict bool `yaml:"retry-policy-branch-rollback-on-conflict" json:"retry-policy-branch-rollback-on-conflict,omitempty" koanf:"retry-policy-branch-rollback-on-conflict"` } @@ -57,7 +57,7 @@ func (cfg *Config) RegisterFlagsWithPrefix(prefix string, f *flag.FlagSet) { } func (cfg *LockConfig) RegisterFlagsWithPrefix(prefix string, f *flag.FlagSet) { - f.IntVar(&cfg.RetryInterval, prefix+".retry-interval", 10, "The maximum number of retries when lock fail.") - f.DurationVar(&cfg.RetryTimes, prefix+".retry-times", 30*time.Second, "The duration allowed for lock retrying.") + f.DurationVar(&cfg.RetryInterval, prefix+".retry-interval", 30*time.Second, "The maximum number of retries when lock fail.") + f.IntVar(&cfg.RetryTimes, prefix+".retry-times", 10, "The duration allowed for lock retrying.") f.BoolVar(&cfg.RetryPolicyBranchRollbackOnConflict, prefix+".retry-policy-branch-rollback-on-conflict", true, "The switch for lock conflict.") } diff --git a/pkg/rm/init.go b/pkg/rm/init.go index dfccd05f..469cdfaa 100644 --- a/pkg/rm/init.go +++ b/pkg/rm/init.go @@ -20,6 +20,8 @@ package rm var rmConfig RmConfig type RmConfig struct { + Config + ApplicationID string TxServiceGroup string } diff --git a/testdata/conf/seatago.yml b/testdata/conf/seatago.yml index cfc14803..6b82536b 100644 --- a/testdata/conf/seatago.yml +++ b/testdata/conf/seatago.yml @@ -44,8 +44,8 @@ seata: # Parse SQL parser selection sql-parser-type: druid lock: - retry-interval: 10 - retry-times: 30s + retry-interval: 30s + retry-times: 10 retry-policy-branch-rollback-on-conflict: true tm: commit-retry-count: 5