Browse Source

optimize some format (#392)

Co-authored-by: haohongfan1 <haohongfan1@jd.com>
tags/v1.0.3
georgehao GitHub 2 years ago
parent
commit
c3cd13b761
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 113 additions and 141 deletions
  1. +1
    -2
      pkg/datasource/sql/conn.go
  2. +5
    -5
      pkg/datasource/sql/conn_at.go
  3. +6
    -6
      pkg/datasource/sql/conn_at_test.go
  4. +2
    -2
      pkg/datasource/sql/conn_xa.go
  5. +6
    -6
      pkg/datasource/sql/conn_xa_test.go
  6. +3
    -3
      pkg/datasource/sql/connector.go
  7. +2
    -2
      pkg/datasource/sql/connector_test.go
  8. +2
    -1
      pkg/datasource/sql/driver.go
  9. +27
    -36
      pkg/datasource/sql/exec/executor.go
  10. +8
    -9
      pkg/datasource/sql/exec/hook.go
  11. +14
    -16
      pkg/datasource/sql/exec/xa/executor_xa.go
  12. +9
    -15
      pkg/datasource/sql/stmt.go
  13. +3
    -3
      pkg/datasource/sql/tx.go
  14. +1
    -1
      pkg/datasource/sql/tx_at.go
  15. +1
    -1
      pkg/datasource/sql/tx_xa.go
  16. +6
    -14
      pkg/datasource/sql/types/executor.go
  17. +17
    -19
      pkg/datasource/sql/types/types.go

+ 1
- 2
pkg/datasource/sql/conn.go View File

@@ -29,7 +29,6 @@ import (
// by multiple goroutines.
//
// Conn is assumed to be stateful.

type Conn struct {
res *DBResource
txCtx *types.TransactionContext
@@ -135,7 +134,7 @@ func (c *Conn) Query(query string, args []driver.Value) (driver.Rows, error) {
return nil, driver.ErrSkip
}

executor, err := exec.BuildExecutor(c.res.dbType, c.txCtx.TxType, query)
executor, err := exec.BuildExecutor(c.res.dbType, c.txCtx.TransactionMode, query)
if err != nil {
return nil, err
}


+ 5
- 5
pkg/datasource/sql/conn_at.go View File

@@ -53,7 +53,7 @@ func (c *ATConn) QueryContext(ctx context.Context, query string, args []driver.N
}

ret, err := c.createNewTxOnExecIfNeed(ctx, func() (types.ExecResult, error) {
executor, err := exec.BuildExecutor(c.res.dbType, c.txCtx.TxType, query)
executor, err := exec.BuildExecutor(c.res.dbType, c.txCtx.TransactionMode, query)
if err != nil {
return nil, err
}
@@ -89,7 +89,7 @@ func (c *ATConn) ExecContext(ctx context.Context, query string, args []driver.Na
}

ret, err := c.createNewTxOnExecIfNeed(ctx, func() (types.ExecResult, error) {
executor, err := exec.BuildExecutor(c.res.dbType, c.txCtx.TxType, query)
executor, err := exec.BuildExecutor(c.res.dbType, c.txCtx.TransactionMode, query)
if err != nil {
return nil, err
}
@@ -130,7 +130,7 @@ func (c *ATConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx,

if tm.IsGlobalTx(ctx) {
c.txCtx.XID = tm.GetXID(ctx)
c.txCtx.TxType = types.ATMode
c.txCtx.TransactionMode = types.ATMode
}

tx, err := c.Conn.BeginTx(ctx, opts)
@@ -149,7 +149,7 @@ func (c *ATConn) createOnceTxContext(ctx context.Context) bool {
c.txCtx.DBType = c.res.dbType
c.txCtx.ResourceID = c.res.resourceID
c.txCtx.XID = tm.GetXID(ctx)
c.txCtx.TxType = types.ATMode
c.txCtx.TransactionMode = types.ATMode
c.txCtx.GlobalLockRequire = true
}

@@ -162,7 +162,7 @@ func (c *ATConn) createNewTxOnExecIfNeed(ctx context.Context, f func() (types.Ex
err error
)

if c.txCtx.TxType != types.Local && c.autoCommit {
if c.txCtx.TransactionMode != types.Local && c.autoCommit {
tx, err = c.BeginTx(ctx, driver.TxOptions{Isolation: driver.IsolationLevel(gosql.LevelDefault)})
if err != nil {
return nil, err


+ 6
- 6
pkg/datasource/sql/conn_at_test.go View File

@@ -87,14 +87,14 @@ func TestATConn_ExecContext(t *testing.T) {
beforeHook := func(_ context.Context, execCtx *types.ExecContext) {
t.Logf("on exec xid=%s", execCtx.TxCtx.XID)
assert.Equal(t, tm.GetXID(ctx), execCtx.TxCtx.XID)
assert.Equal(t, types.ATMode, execCtx.TxCtx.TxType)
assert.Equal(t, types.ATMode, execCtx.TxCtx.TransactionMode)
}
mi.before = beforeHook

var comitCnt int32
beforeCommit := func(tx *Tx) {
atomic.AddInt32(&comitCnt, 1)
assert.Equal(t, types.ATMode, tx.tranCtx.TxType)
assert.Equal(t, types.ATMode, tx.tranCtx.TransactionMode)
}
ti.beforeCommit = beforeCommit

@@ -112,7 +112,7 @@ func TestATConn_ExecContext(t *testing.T) {
t.Run("not xid", func(t *testing.T) {
mi.before = func(_ context.Context, execCtx *types.ExecContext) {
assert.Equal(t, "", execCtx.TxCtx.XID)
assert.Equal(t, types.Local, execCtx.TxCtx.TxType)
assert.Equal(t, types.Local, execCtx.TxCtx.TransactionMode)
}

var comitCnt int32
@@ -149,7 +149,7 @@ func TestATConn_BeginTx(t *testing.T) {

mi.before = func(_ context.Context, execCtx *types.ExecContext) {
assert.Equal(t, "", execCtx.TxCtx.XID)
assert.Equal(t, types.Local, execCtx.TxCtx.TxType)
assert.Equal(t, types.Local, execCtx.TxCtx.TransactionMode)
}

var comitCnt int32
@@ -175,7 +175,7 @@ func TestATConn_BeginTx(t *testing.T) {

mi.before = func(_ context.Context, execCtx *types.ExecContext) {
assert.Equal(t, "", execCtx.TxCtx.XID)
assert.Equal(t, types.Local, execCtx.TxCtx.TxType)
assert.Equal(t, types.Local, execCtx.TxCtx.TransactionMode)
}

var comitCnt int32
@@ -203,7 +203,7 @@ func TestATConn_BeginTx(t *testing.T) {

mi.before = func(_ context.Context, execCtx *types.ExecContext) {
assert.Equal(t, tm.GetXID(ctx), execCtx.TxCtx.XID)
assert.Equal(t, types.ATMode, execCtx.TxCtx.TxType)
assert.Equal(t, types.ATMode, execCtx.TxCtx.TransactionMode)
}

var comitCnt int32


+ 2
- 2
pkg/datasource/sql/conn_xa.go View File

@@ -73,7 +73,7 @@ func (c *XAConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx,
c.txCtx.TxOpt = opts

if tm.IsGlobalTx(ctx) {
c.txCtx.TxType = types.XAMode
c.txCtx.TransactionMode = types.XAMode
c.txCtx.XID = tm.GetXID(ctx)
}

@@ -92,7 +92,7 @@ func (c *XAConn) createOnceTxContext(ctx context.Context) bool {
c.txCtx = types.NewTxCtx()
c.txCtx.DBType = c.res.dbType
c.txCtx.XID = tm.GetXID(ctx)
c.txCtx.TxType = types.XAMode
c.txCtx.TransactionMode = types.XAMode
}

return onceTx


+ 6
- 6
pkg/datasource/sql/conn_xa_test.go View File

@@ -138,14 +138,14 @@ func TestXAConn_ExecContext(t *testing.T) {
before := func(_ context.Context, execCtx *types.ExecContext) {
t.Logf("on exec xid=%s", execCtx.TxCtx.XID)
assert.Equal(t, tm.GetXID(ctx), execCtx.TxCtx.XID)
assert.Equal(t, types.XAMode, execCtx.TxCtx.TxType)
assert.Equal(t, types.XAMode, execCtx.TxCtx.TransactionMode)
}
mi.before = before

var comitCnt int32
beforeCommit := func(tx *Tx) {
atomic.AddInt32(&comitCnt, 1)
assert.Equal(t, tx.tranCtx.TxType, types.XAMode)
assert.Equal(t, tx.tranCtx.TransactionMode, types.XAMode)
}
ti.beforeCommit = beforeCommit

@@ -164,7 +164,7 @@ func TestXAConn_ExecContext(t *testing.T) {
t.Run("not xid", func(t *testing.T) {
before := func(_ context.Context, execCtx *types.ExecContext) {
assert.Equal(t, "", execCtx.TxCtx.XID)
assert.Equal(t, types.Local, execCtx.TxCtx.TxType)
assert.Equal(t, types.Local, execCtx.TxCtx.TransactionMode)
}
mi.before = before

@@ -203,7 +203,7 @@ func TestXAConn_BeginTx(t *testing.T) {

mi.before = func(_ context.Context, execCtx *types.ExecContext) {
assert.Equal(t, "", execCtx.TxCtx.XID)
assert.Equal(t, types.Local, execCtx.TxCtx.TxType)
assert.Equal(t, types.Local, execCtx.TxCtx.TransactionMode)
}

var comitCnt int32
@@ -229,7 +229,7 @@ func TestXAConn_BeginTx(t *testing.T) {

mi.before = func(_ context.Context, execCtx *types.ExecContext) {
assert.Equal(t, "", execCtx.TxCtx.XID)
assert.Equal(t, types.Local, execCtx.TxCtx.TxType)
assert.Equal(t, types.Local, execCtx.TxCtx.TransactionMode)
}

var comitCnt int32
@@ -257,7 +257,7 @@ func TestXAConn_BeginTx(t *testing.T) {

mi.before = func(_ context.Context, execCtx *types.ExecContext) {
assert.Equal(t, tm.GetXID(ctx), execCtx.TxCtx.XID)
assert.Equal(t, types.XAMode, execCtx.TxCtx.TxType)
assert.Equal(t, types.XAMode, execCtx.TxCtx.TransactionMode)
}

var comitCnt int32


+ 3
- 3
pkg/datasource/sql/connector.go View File

@@ -29,7 +29,7 @@ import (

type seataATConnector struct {
*seataConnector
transType types.TransactionType
transType types.TransactionMode
}

func (c *seataATConnector) Connect(ctx context.Context) (driver.Conn, error) {
@@ -53,7 +53,7 @@ func (c *seataATConnector) Driver() driver.Driver {

type seataXAConnector struct {
*seataConnector
transType types.TransactionType
transType types.TransactionMode
}

func (c *seataXAConnector) Connect(ctx context.Context) (driver.Conn, error) {
@@ -88,7 +88,7 @@ func (c *seataXAConnector) Driver() driver.Driver {
// If a Connector implements io.Closer, the sql package's DB.Close
// method will call Close and return error (if any).
type seataConnector struct {
transType types.TransactionType
transType types.TransactionMode
conf *seataServerConfig
res *DBResource
once sync.Once


+ 2
- 2
pkg/datasource/sql/connector_test.go View File

@@ -82,7 +82,7 @@ func Test_seataATConnector_Connect(t *testing.T) {

atConn, ok := conn.(*ATConn)
assert.True(t, ok, "need return seata at connection")
assert.True(t, atConn.txCtx.TxType == types.Local, "init need local tx")
assert.True(t, atConn.txCtx.TransactionMode == types.Local, "init need local tx")
}

func initMockXaConnector(t *testing.T, ctrl *gomock.Controller, db *sql.DB, f initConnectorFunc) driver.Connector {
@@ -126,5 +126,5 @@ func Test_seataXAConnector_Connect(t *testing.T) {

xaConn, ok := conn.(*XAConn)
assert.True(t, ok, "need return seata xa connection")
assert.True(t, xaConn.txCtx.TxType == types.Local, "init need local tx")
assert.True(t, xaConn.txCtx.TransactionMode == types.Local, "init need local tx")
}

+ 2
- 1
pkg/datasource/sql/driver.go View File

@@ -47,6 +47,7 @@ func init() {
target: mysql.MySQLDriver{},
},
})

sql.Register(SeataXAMySQLDriver, &seataXADriver{
seataDriver: &seataDriver{
transType: types.XAMode,
@@ -96,7 +97,7 @@ func (d *seataXADriver) OpenConnector(name string) (c driver.Connector, err erro
}

type seataDriver struct {
transType types.TransactionType
transType types.TransactionMode
target driver.Driver
}



+ 27
- 36
pkg/datasource/sql/exec/executor.go View File

@@ -22,6 +22,7 @@ import (
"database/sql/driver"
"fmt"

"github.com/mitchellh/copystructure"
"github.com/pkg/errors"

"github.com/seata/seata-go/pkg/datasource/sql/parser"
@@ -30,8 +31,6 @@ import (
"github.com/seata/seata-go/pkg/datasource/sql/undo/builder"
"github.com/seata/seata-go/pkg/tm"
"github.com/seata/seata-go/pkg/util/log"

"github.com/mitchellh/copystructure"
)

func init() {
@@ -39,13 +38,12 @@ func init() {
undo.RegisterUndoLogBuilder(types.MultiExecutor, builder.GetMySQLMultiUndoLogBuilder)
}

// executorSolts
var (
executorSoltsAT = make(map[types.DBType]map[types.ExecutorType]func() SQLExecutor)
executorSoltsXA = make(map[types.DBType]func() SQLExecutor)
)

// RegisterATExecutor
// RegisterATExecutor AT executor
func RegisterATExecutor(dt types.DBType, et types.ExecutorType, builder func() SQLExecutor) {
if _, ok := executorSoltsAT[dt]; !ok {
executorSoltsAT[dt] = make(map[types.ExecutorType]func() SQLExecutor)
@@ -58,7 +56,7 @@ func RegisterATExecutor(dt types.DBType, et types.ExecutorType, builder func() S
}
}

// RegisterXAExecutor
// RegisterXAExecutor XA executor
func RegisterXAExecutor(dt types.DBType, builder func() SQLExecutor) {
executorSoltsXA[dt] = func() SQLExecutor {
return &BaseExecutor{ex: builder()}
@@ -71,41 +69,37 @@ type (
CallbackWithValue func(ctx context.Context, query string, args []driver.Value) (types.ExecResult, error)

SQLExecutor interface {
// Interceptors
Interceptors(interceptors []SQLHook)
// Exec
ExecWithNamedValue(ctx context.Context, execCtx *types.ExecContext, f CallbackWithNamedValue) (types.ExecResult, error)
// Exec
ExecWithValue(ctx context.Context, execCtx *types.ExecContext, f CallbackWithValue) (types.ExecResult, error)
}
)

// BuildExecutor
func BuildExecutor(dbType types.DBType, txType types.TransactionType, query string) (SQLExecutor, error) {
parseCtx, err := parser.DoParser(query)
// BuildExecutor use db type and transaction type to build an executor. the executor can
// add custom hook, and intercept the user's business sql to generate the undo log.
func BuildExecutor(dbType types.DBType, transactionMode types.TransactionMode, query string) (SQLExecutor, error) {
parseContext, err := parser.DoParser(query)
if err != nil {
return nil, err
}

hooks := make([]SQLHook, 0, 4)
hooks = append(hooks, commonHook...)
hooks = append(hooks, hookSolts[parseCtx.SQLType]...)
hooks = append(hooks, commonHook...)
hooks = append(hooks, hookSolts[parseContext.SQLType]...)

if txType == types.XAMode {
if transactionMode == types.XAMode {
e := executorSoltsXA[dbType]()
e.Interceptors(hooks)
return e, nil
}

if txType == types.ATMode {
e := executorSoltsAT[dbType][parseCtx.ExecutorType]()
if transactionMode == types.ATMode {
e := executorSoltsAT[dbType][parseContext.ExecutorType]()
e.Interceptors(hooks)
return e, nil
}

factories, ok := executorSoltsAT[dbType]

if !ok {
log.Debugf("%s not found executor factories, return default Executor", dbType.String())
e := &BaseExecutor{}
@@ -113,10 +107,10 @@ func BuildExecutor(dbType types.DBType, txType types.TransactionType, query stri
return e, nil
}

supplier, ok := factories[parseCtx.ExecutorType]
supplier, ok := factories[parseContext.ExecutorType]
if !ok {
log.Debugf("%s not found executor for %s, return default Executor",
dbType.String(), parseCtx.ExecutorType)
dbType.String(), parseContext.ExecutorType)
e := &BaseExecutor{}
e.Interceptors(hooks)
return e, nil
@@ -128,19 +122,17 @@ func BuildExecutor(dbType types.DBType, txType types.TransactionType, query stri
}

type BaseExecutor struct {
is []SQLHook
ex SQLExecutor
hooks []SQLHook
ex SQLExecutor
}

// Interceptors
func (e *BaseExecutor) Interceptors(interceptors []SQLHook) {
e.is = interceptors
func (e *BaseExecutor) Interceptors(hooks []SQLHook) {
e.hooks = hooks
}

// ExecWithNamedValue
func (e *BaseExecutor) ExecWithNamedValue(ctx context.Context, execCtx *types.ExecContext, f CallbackWithNamedValue) (types.ExecResult, error) {
for i := range e.is {
_ = e.is[i].Before(ctx, execCtx)
for _, hook := range e.hooks {
hook.Before(ctx, execCtx)
}

var (
@@ -167,8 +159,8 @@ func (e *BaseExecutor) ExecWithNamedValue(ctx context.Context, execCtx *types.Ex
}

defer func() {
for i := range e.is {
_ = e.is[i].After(ctx, execCtx)
for _, hook := range e.hooks {
hook.After(ctx, execCtx)
}
}()

@@ -210,10 +202,9 @@ func (e *BaseExecutor) prepareUndoLog(ctx context.Context, execCtx *types.ExecCo
return undoLogManager.FlushUndoLog(execCtx.TxCtx, execCtx.Conn)
}

// ExecWithValue
func (e *BaseExecutor) ExecWithValue(ctx context.Context, execCtx *types.ExecContext, f CallbackWithValue) (types.ExecResult, error) {
for i := range e.is {
e.is[i].Before(ctx, execCtx)
for _, hook := range e.hooks {
hook.Before(ctx, execCtx)
}

var (
@@ -232,8 +223,8 @@ func (e *BaseExecutor) ExecWithValue(ctx context.Context, execCtx *types.ExecCon
}

defer func() {
for i := range e.is {
_ = e.is[i].After(ctx, execCtx)
for _, hook := range e.hooks {
hook.After(ctx, execCtx)
}
}()

@@ -257,7 +248,7 @@ func (e *BaseExecutor) ExecWithValue(ctx context.Context, execCtx *types.ExecCon
return result, err
}

func (h *BaseExecutor) beforeImage(ctx context.Context, execCtx *types.ExecContext) ([]*types.RecordImage, error) {
func (e *BaseExecutor) beforeImage(ctx context.Context, execCtx *types.ExecContext) ([]*types.RecordImage, error) {
if !tm.IsGlobalTx(ctx) {
return nil, nil
}
@@ -279,7 +270,7 @@ func (h *BaseExecutor) beforeImage(ctx context.Context, execCtx *types.ExecConte
}

// After
func (h *BaseExecutor) afterImage(ctx context.Context, execCtx *types.ExecContext, beforeImages []*types.RecordImage) ([]*types.RecordImage, error) {
func (e *BaseExecutor) afterImage(ctx context.Context, execCtx *types.ExecContext, beforeImages []*types.RecordImage) ([]*types.RecordImage, error) {
if !tm.IsGlobalTx(ctx) {
return nil, nil
}


+ 8
- 9
pkg/datasource/sql/exec/hook.go View File

@@ -29,7 +29,7 @@ var (
hookSolts = map[types.SQLType][]SQLHook{}
)

// RegisCommonHook not goroutine safe
// RegisterCommonHook not goroutine safe
func RegisterCommonHook(hook SQLHook) {
commonHook = append(commonHook, hook)
}
@@ -40,13 +40,16 @@ func CleanCommonHook() {

// RegisterHook not goroutine safe
func RegisterHook(hook SQLHook) {
_, ok := hookSolts[hook.Type()]
sqlType := hook.Type()
if sqlType == types.SQLTypeUnknown {
return
}

_, ok := hookSolts[sqlType]
if !ok {
hookSolts[hook.Type()] = make([]SQLHook, 0, 4)
hookSolts[sqlType] = make([]SQLHook, 0, 4)
}

hookSolts[hook.Type()] = append(hookSolts[hook.Type()], hook)
hookSolts[sqlType] = append(hookSolts[sqlType], hook)
}

// SQLHook SQL execution front and back interceptor
@@ -55,10 +58,6 @@ func RegisterHook(hook SQLHook) {
// case 3. SQL black and white list
type SQLHook interface {
Type() types.SQLType

// Before
Before(ctx context.Context, execCtx *types.ExecContext) error

// After
After(ctx context.Context, execCtx *types.ExecContext) error
}

+ 14
- 16
pkg/datasource/sql/exec/xa/executor_xa.go View File

@@ -24,28 +24,26 @@ import (
"github.com/seata/seata-go/pkg/datasource/sql/types"
)

// todo
// 完善XA prepare
//
// XAExecutor The XA transaction manager.
type XAExecutor struct {
is []exec.SQLHook
ex exec.SQLExecutor
hooks []exec.SQLHook
ex exec.SQLExecutor
}

// Interceptors
func (e *XAExecutor) Interceptors(interceptors []exec.SQLHook) {
e.is = interceptors
// Interceptors set xa executor hooks
func (e *XAExecutor) Interceptors(hooks []exec.SQLHook) {
e.hooks = hooks
}

// ExecWithNamedValue
func (e *XAExecutor) ExecWithNamedValue(ctx context.Context, execCtx *types.ExecContext, f exec.CallbackWithNamedValue) (types.ExecResult, error) {
for i := range e.is {
e.is[i].Before(ctx, execCtx)
for _, hook := range e.hooks {
hook.Before(ctx, execCtx)
}

defer func() {
for i := range e.is {
e.is[i].After(ctx, execCtx)
for _, hook := range e.hooks {
hook.After(ctx, execCtx)
}
}()

@@ -58,13 +56,13 @@ func (e *XAExecutor) ExecWithNamedValue(ctx context.Context, execCtx *types.Exec

// ExecWithValue
func (e *XAExecutor) ExecWithValue(ctx context.Context, execCtx *types.ExecContext, f exec.CallbackWithValue) (types.ExecResult, error) {
for i := range e.is {
e.is[i].Before(ctx, execCtx)
for _, hook := range e.hooks {
hook.Before(ctx, execCtx)
}

defer func() {
for i := range e.is {
e.is[i].After(ctx, execCtx)
for _, hook := range e.hooks {
hook.After(ctx, execCtx)
}
}()



+ 9
- 15
pkg/datasource/sql/stmt.go View File

@@ -26,15 +26,11 @@ import (
)

type Stmt struct {
conn *Conn
// res
res *DBResource
// txCtx
conn *Conn
res *DBResource
txCtx *types.TransactionContext
// query
query string
// stmt
stmt driver.Stmt
stmt driver.Stmt
}

// Close closes the statement.
@@ -67,7 +63,7 @@ func (s *Stmt) NumInput() int {
//
// Deprecated: Drivers should implement StmtQueryContext instead (or additionally).
func (s *Stmt) Query(args []driver.Value) (driver.Rows, error) {
executor, err := exec.BuildExecutor(s.res.dbType, s.txCtx.TxType, s.query)
executor, err := exec.BuildExecutor(s.res.dbType, s.txCtx.TransactionMode, s.query)
if err != nil {
return nil, err
}
@@ -94,10 +90,8 @@ func (s *Stmt) Query(args []driver.Value) (driver.Rows, error) {
return ret.GetRows(), nil
}

// StmtQueryContext enhances the Stmt interface by providing Query with context.
// QueryContext executes a query that may return rows, such as a
// SELECT.
//
// QueryContext StmtQueryContext enhances the Stmt interface by providing Query with context.
// QueryContext executes a query that may return rows, such as a SELECT.
// QueryContext must honor the context timeout and return when it is canceled.
func (s *Stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
stmt, ok := s.stmt.(driver.StmtQueryContext)
@@ -105,7 +99,7 @@ func (s *Stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driv
return nil, driver.ErrSkip
}

executor, err := exec.BuildExecutor(s.res.dbType, s.txCtx.TxType, s.query)
executor, err := exec.BuildExecutor(s.res.dbType, s.txCtx.TransactionMode, s.query)
if err != nil {
return nil, err
}
@@ -138,7 +132,7 @@ func (s *Stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driv
// Deprecated: Drivers should implement StmtExecContext instead (or additionally).
func (s *Stmt) Exec(args []driver.Value) (driver.Result, error) {
// in transaction, need run Executor
executor, err := exec.BuildExecutor(s.res.dbType, s.txCtx.TxType, s.query)
executor, err := exec.BuildExecutor(s.res.dbType, s.txCtx.TransactionMode, s.query)
if err != nil {
return nil, err
}
@@ -173,7 +167,7 @@ func (s *Stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (drive
}

// in transaction, need run Executor
executor, err := exec.BuildExecutor(s.res.dbType, s.txCtx.TxType, s.query)
executor, err := exec.BuildExecutor(s.res.dbType, s.txCtx.TransactionMode, s.query)
if err != nil {
return nil, err
}


+ 3
- 3
pkg/datasource/sql/tx.go View File

@@ -156,11 +156,11 @@ func (tx *Tx) register(ctx *types.TransactionContext) error {
}
request := rm.BranchRegisterParam{
Xid: ctx.XID,
BranchType: ctx.TxType.GetBranchType(),
BranchType: ctx.TransactionMode.BranchType(),
ResourceId: ctx.ResourceID,
LockKeys: lockKey,
}
dataSourceManager := datasource.GetDataSourceManager(ctx.TxType.GetBranchType())
dataSourceManager := datasource.GetDataSourceManager(ctx.TransactionMode.BranchType())
branchId, err := dataSourceManager.BranchRegister(context.Background(), request)
if err != nil {
log.Infof("Failed to report branch status: %s", err.Error())
@@ -181,7 +181,7 @@ func (tx *Tx) report(success bool) error {
BranchId: int64(tx.tranCtx.BranchID),
Status: status,
}
dataSourceManager := datasource.GetDataSourceManager(tx.tranCtx.TxType.GetBranchType())
dataSourceManager := datasource.GetDataSourceManager(tx.tranCtx.TransactionMode.BranchType())
if dataSourceManager == nil {
return errors.New("get dataSourceManager failed")
}


+ 1
- 1
pkg/datasource/sql/tx_at.go View File

@@ -43,7 +43,7 @@ func (tx *ATTx) Rollback() error {

originTx := tx.tx

if originTx.tranCtx.OpenGlobalTrsnaction() && originTx.tranCtx.IsBranchRegistered() {
if originTx.tranCtx.OpenGlobalTransaction() && originTx.tranCtx.IsBranchRegistered() {
originTx.report(false)
}
}


+ 1
- 1
pkg/datasource/sql/tx_xa.go View File

@@ -37,7 +37,7 @@ func (tx *XATx) Rollback() error {

originTx := tx.tx

if originTx.tranCtx.OpenGlobalTrsnaction() && originTx.tranCtx.IsBranchRegistered() {
if originTx.tranCtx.OpenGlobalTransaction() && originTx.tranCtx.IsBranchRegistered() {
originTx.report(false)
}
}


+ 6
- 14
pkg/datasource/sql/types/executor.go View File

@@ -25,9 +25,6 @@ import (
seatabytes "github.com/seata/seata-go/pkg/util/bytes"
)

// ExecutorType
//
//go:generate stringer -type=ExecutorType
type ExecutorType int32

const (
@@ -45,18 +42,13 @@ const (
)

type ParseContext struct {
// SQLType
SQLType SQLType
// ExecutorType
SQLType SQLType
ExecutorType ExecutorType
// InsertStmt
InsertStmt *ast.InsertStmt
// UpdateStmt
UpdateStmt *ast.UpdateStmt
SelectStmt *ast.SelectStmt
// DeleteStmt
DeleteStmt *ast.DeleteStmt
MultiStmt []*ParseContext
InsertStmt *ast.InsertStmt
UpdateStmt *ast.UpdateStmt
SelectStmt *ast.SelectStmt
DeleteStmt *ast.DeleteStmt
MultiStmt []*ParseContext
}

func (p *ParseContext) HasValidStmt() bool {


+ 17
- 19
pkg/datasource/sql/types/types.go View File

@@ -27,11 +27,9 @@ import (
"github.com/google/uuid"
)

//go:generate stringer -type=DBType
type DBType int16

type (
// DBType
// BranchPhase
BranchPhase int8
// IndexType index type
@@ -102,24 +100,24 @@ func ParseDBType(driverName string) DBType {
}
}

// TransactionType
type TransactionType int8
type TransactionMode int8

const (
_ TransactionType = iota
_ TransactionMode = iota
Local
XAMode
ATMode
)

func (t TransactionType) GetBranchType() branch.BranchType {
if t == XAMode {
func (t TransactionMode) BranchType() branch.BranchType {
switch t {
case XAMode:
return branch.BranchTypeXA
}
if t == ATMode {
case ATMode:
return branch.BranchTypeAT
default:
return branch.BranchTypeUnknow
}
return branch.BranchTypeUnknow
}

// TransactionContext seata-go‘s context of transaction
@@ -132,8 +130,8 @@ type TransactionContext struct {
DBType DBType
// TxOpt transaction option
TxOpt driver.TxOptions
// TxType transaction mode, eg. XA/AT
TxType TransactionType
// TransactionMode transaction mode, eg. XA/AT
TransactionMode TransactionMode
// ResourceID resource id, database-table
ResourceID string
// BranchID transaction branch unique id
@@ -167,16 +165,16 @@ type ExecContext struct {

func NewTxCtx() *TransactionContext {
return &TransactionContext{
LockKeys: make(map[string]struct{}, 0),
TxType: Local,
LocalTransID: uuid.New().String(),
RoundImages: &RoundRecordImage{},
LockKeys: make(map[string]struct{}, 0),
TransactionMode: Local,
LocalTransID: uuid.New().String(),
RoundImages: &RoundRecordImage{},
}
}

// HasUndoLog
func (t *TransactionContext) HasUndoLog() bool {
return t.TxType == ATMode && !t.RoundImages.IsEmpty()
return t.TransactionMode == ATMode && !t.RoundImages.IsEmpty()
}

// HasLockKey
@@ -184,8 +182,8 @@ func (t *TransactionContext) HasLockKey() bool {
return len(t.LockKeys) != 0
}

func (t *TransactionContext) OpenGlobalTrsnaction() bool {
return t.TxType != Local
func (t *TransactionContext) OpenGlobalTransaction() bool {
return t.TransactionMode != Local
}

func (t *TransactionContext) IsBranchRegistered() bool {


Loading…
Cancel
Save