Browse Source

(WIP)refactor:seata conn (#287)

* refactor:seata conn

* test: add unit test

* test: add unit test
tags/1.0.2-RC1
liaochuntao GitHub 3 years ago
parent
commit
b396c81a4e
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
24 changed files with 1749 additions and 120 deletions
  1. +15
    -1
      .github/workflows/build.yml
  2. +94
    -74
      pkg/datasource/sql/conn.go
  3. +85
    -0
      pkg/datasource/sql/conn_at.go
  4. +127
    -0
      pkg/datasource/sql/conn_at_test.go
  5. +86
    -0
      pkg/datasource/sql/conn_xa.go
  6. +179
    -0
      pkg/datasource/sql/conn_xa_test.go
  7. +72
    -7
      pkg/datasource/sql/connector.go
  8. +129
    -0
      pkg/datasource/sql/connector_test.go
  9. +29
    -0
      pkg/datasource/sql/context.go
  10. +63
    -14
      pkg/datasource/sql/driver.go
  11. +93
    -0
      pkg/datasource/sql/driver_test.go
  12. +14
    -5
      pkg/datasource/sql/exec/executor.go
  13. +4
    -0
      pkg/datasource/sql/exec/hook.go
  14. +237
    -0
      pkg/datasource/sql/mock/mock_datasource_manager.go
  15. +411
    -0
      pkg/datasource/sql/mock/mock_driver.go
  16. +46
    -0
      pkg/datasource/sql/mock/test_driver.go
  17. +1
    -1
      pkg/datasource/sql/sql_test.go
  18. +4
    -4
      pkg/datasource/sql/stmt.go
  19. +51
    -5
      pkg/datasource/sql/tx.go
  20. +1
    -1
      pkg/datasource/sql/types/types.go
  21. +2
    -2
      pkg/datasource/sql/undo/base/undo.go
  22. +2
    -2
      pkg/datasource/sql/undo/mysql/undo.go
  23. +2
    -2
      pkg/datasource/sql/undo/undo.go
  24. +2
    -2
      pkg/datasource/sql/undo_test.go

+ 15
- 1
.github/workflows/build.yml View File

@@ -23,7 +23,7 @@ on:
push:
branches: [ master ]
pull_request:
branches: "*"
branches: ["*"]

jobs:
build:
@@ -39,6 +39,20 @@ jobs:
with:
go-version: 1.18

# close default MySQL-Server
- name: Shutdown default mysql
run: sudo service mysql stop

# run mysql server
- name: Create mysql database auth
uses: icomponent/mysql-action@master
with:
VERSION: 5.7
CONTAINER_NAME: mysql
PORT_MAPPING: 3306:3306
ROOT_PASSWORD: seata_go
DATABASE: seata_go_test

# Checks-out your repository under $GITHUB_WORKSPACE, so your job can access it
- name: "checkout ${{ github.ref }}"
uses: actions/checkout@v3


+ 94
- 74
pkg/datasource/sql/conn.go View File

@@ -19,29 +19,35 @@ package sql

import (
"context"
gosql "database/sql"
"database/sql/driver"
"errors"

"github.com/seata/seata-go/pkg/datasource/sql/exec"
"github.com/seata/seata-go/pkg/datasource/sql/types"
)

// Conn is a connection to a database. It is not used concurrently
// by multiple goroutines.
//
// Conn is assumed to be stateful.

type Conn struct {
res *DBResource
txCtx *types.TransactionContext
targetConn driver.Conn
isInTransaction bool
autoCommit bool
autoCommitChanged bool
txType types.TransactionType
res *DBResource
txCtx *types.TransactionContext
targetConn driver.Conn
autoCommit bool
}

// ResetSession is called prior to executing a query on the connection
// if the connection has been used before. If the driver returns ErrBadConn
// the connection is discarded.
func (c *Conn) ResetSession(ctx context.Context) error {
conn, ok := c.targetConn.(driver.SessionResetter)
if !ok {
return driver.ErrSkip
}

c.txType = types.Local
c.txCtx = nil
return conn.ResetSession(ctx)
}
@@ -95,9 +101,8 @@ func (c *Conn) Exec(query string, args []driver.Value) (driver.Result, error) {
return nil, driver.ErrSkip
}

if c.txCtx != nil {
// in transaction, need run Executor
executor, err := exec.BuildExecutor(c.res.dbType, query)
ret, err := c.createNewTxOnExecIfNeed(func() (types.ExecResult, error) {
executor, err := exec.BuildExecutor(c.res.dbType, c.txCtx.TransType, query)
if err != nil {
return nil, err
}
@@ -108,7 +113,7 @@ func (c *Conn) Exec(query string, args []driver.Value) (driver.Result, error) {
Values: args,
}

ret, err := executor.ExecWithValue(context.Background(), execCtx,
return executor.ExecWithValue(context.Background(), execCtx,
func(ctx context.Context, query string, args []driver.Value) (types.ExecResult, error) {
ret, err := conn.Exec(query, args)
if err != nil {
@@ -117,16 +122,12 @@ func (c *Conn) Exec(query string, args []driver.Value) (driver.Result, error) {

return types.NewResult(types.WithResult(ret)), nil
})
if err != nil {
return nil, err
}
})

// todo if user has not opened a transaction, it may call tx.Commit() method to flush undo log

return ret.GetResult(), nil
if err != nil {
return nil, err
}

return conn.Exec(query, args)
return ret.GetResult(), nil
}

// ExecContext
@@ -142,41 +143,45 @@ func (c *Conn) ExecContext(ctx context.Context, query string, args []driver.Name
return c.Exec(query, values)
}

executor, err := exec.BuildExecutor(c.res.dbType, query)
if err != nil {
return nil, err
}
ret, err := c.createNewTxOnExecIfNeed(func() (types.ExecResult, error) {
executor, err := exec.BuildExecutor(c.res.dbType, c.txCtx.TransType, query)
if err != nil {
return nil, err
}

execCtx := &exec.ExecContext{
TxCtx: c.txCtx,
Query: query,
NamedValues: args,
}
execCtx := &exec.ExecContext{
TxCtx: c.txCtx,
Query: query,
NamedValues: args,
}

ret, err := executor.ExecWithNamedValue(ctx, execCtx,
func(ctx context.Context, query string, args []driver.NamedValue) (types.ExecResult, error) {
ret, err := targetConn.ExecContext(ctx, query, args)
if err != nil {
return nil, err
}
ret, err := executor.ExecWithNamedValue(ctx, execCtx,
func(ctx context.Context, query string, args []driver.NamedValue) (types.ExecResult, error) {
ret, err := targetConn.ExecContext(ctx, query, args)
if err != nil {
return nil, err
}

return types.NewResult(types.WithResult(ret)), nil
})

return ret, err
})

return types.NewResult(types.WithResult(ret)), nil
})
if err != nil {
return nil, err
}

return ret.GetResult(), nil
}

// QueryContext
// Query
func (c *Conn) Query(query string, args []driver.Value) (driver.Rows, error) {
conn, ok := c.targetConn.(driver.Queryer)
if !ok {
return nil, driver.ErrSkip
}

executor, err := exec.BuildExecutor(c.res.dbType, query)
executor, err := exec.BuildExecutor(c.res.dbType, c.txCtx.TransType, query)
if err != nil {
return nil, err
}
@@ -216,7 +221,7 @@ func (c *Conn) QueryContext(ctx context.Context, query string, args []driver.Nam
return c.Query(query, values)
}

executor, err := exec.BuildExecutor(c.res.dbType, query)
executor, err := exec.BuildExecutor(c.res.dbType, c.txCtx.TransType, query)
if err != nil {
return nil, err
}
@@ -252,10 +257,11 @@ func (c *Conn) Begin() (driver.Tx, error) {
return nil, err
}

c.txCtx = types.NewTxCtx()
c.txCtx.DBType = c.res.dbType
c.txCtx.TxOpt = driver.TxOptions{}
c.autoCommit = true
if c.txCtx == nil {
c.txCtx = types.NewTxCtx()
c.txCtx.DBType = c.res.dbType
c.txCtx.TxOpt = driver.TxOptions{}
}

return newTx(
withDriverConn(c),
@@ -264,6 +270,8 @@ func (c *Conn) Begin() (driver.Tx, error) {
)
}

// BeginTx Open a transaction and judge whether the current transaction needs to open a
// global transaction according to ctx. If so, it needs to be included in the transaction management of seata
func (c *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
if conn, ok := c.targetConn.(driver.ConnBeginTx); ok {
tx, err := conn.BeginTx(ctx, opts)
@@ -271,11 +279,6 @@ func (c *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, e
return nil, err
}

c.txCtx = types.NewTxCtx()
c.txCtx.DBType = c.res.dbType
c.txCtx.TxOpt = opts
c.autoCommit = true

return newTx(
withDriverConn(c),
withTxCtx(c.txCtx),
@@ -283,32 +286,15 @@ func (c *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, e
)
}

// Check the transaction level. If the transaction level is non-default
// then return an error here as the BeginTx driver value is not supported.
if opts.Isolation != driver.IsolationLevel(gosql.LevelDefault) {
return nil, errors.New("sql: driver does not support non-default isolation level")
}

// If a read-only transaction is requested return an error as the
// BeginTx driver value is not supported.
if opts.ReadOnly {
return nil, errors.New("sql: driver does not support read-only transactions")
}

if ctx.Done() == nil {
return c.Begin()
}

txi, err := c.Begin()
if err == nil {
select {
case <-ctx.Done():
txi.Rollback()
return nil, ctx.Err()
default:
}
if err != nil {
return nil, err
}
return txi, err
return newTx(
withDriverConn(c),
withTxCtx(c.txCtx),
withOriginTx(txi),
)
}

// Close invalidates and potentially stops any current
@@ -326,3 +312,37 @@ func (c *Conn) Close() error {
c.txCtx = nil
return c.targetConn.Close()
}

func (c *Conn) createNewTxOnExecIfNeed(f func() (types.ExecResult, error)) (types.ExecResult, error) {
var (
tx driver.Tx
err error
)

if c.txCtx.TransType != types.Local && c.autoCommit {
tx, err = c.Begin()
if err != nil {
return nil, err
}
}

defer func() {
if tx != nil {
tx.Rollback()
}
}()

ret, err := f()

if err != nil {
return nil, err
}

if tx != nil {
if err := tx.Commit(); err != nil {
return nil, err
}
}

return ret, nil
}

+ 85
- 0
pkg/datasource/sql/conn_at.go View File

@@ -0,0 +1,85 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package sql

import (
"context"
"database/sql/driver"

"github.com/seata/seata-go/pkg/datasource/sql/types"
"github.com/seata/seata-go/pkg/tm"
)

type ATConn struct {
*Conn
}

func (c *ATConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
if c.createTxCtxIfAbsent(ctx) {
defer func() {
c.txCtx = nil
}()
}

return c.Conn.PrepareContext(ctx, query)
}

// ExecContext
func (c *ATConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
if c.createTxCtxIfAbsent(ctx) {
defer func() {
c.txCtx = nil
}()
}

return c.Conn.ExecContext(ctx, query, args)
}

// BeginTx
func (c *ATConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
c.txCtx = types.NewTxCtx()
c.txCtx.DBType = c.res.dbType
c.txCtx.TxOpt = opts

if IsGlobalTx(ctx) {
c.txCtx.XaID = tm.GetXID(ctx)
c.txCtx.TransType = c.txType
}

return c.Conn.BeginTx(ctx, opts)
}

func (c *ATConn) createTxCtxIfAbsent(ctx context.Context) bool {
var onceTx bool

if IsGlobalTx(ctx) && c.txCtx == nil {
c.txCtx = types.NewTxCtx()
c.txCtx.DBType = c.res.dbType
c.txCtx.XaID = tm.GetXID(ctx)
c.txCtx.TransType = types.ATMode
c.autoCommit = true
onceTx = true
}

if c.txCtx == nil {
c.txCtx = types.NewTxCtx()
onceTx = true
}

return onceTx
}

+ 127
- 0
pkg/datasource/sql/conn_at_test.go View File

@@ -0,0 +1,127 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package sql

import (
"context"
"database/sql"
"sync/atomic"
"testing"

"github.com/golang/mock/gomock"
"github.com/google/uuid"
"github.com/seata/seata-go/pkg/datasource/sql/exec"
"github.com/seata/seata-go/pkg/datasource/sql/mock"
"github.com/seata/seata-go/pkg/datasource/sql/types"
"github.com/seata/seata-go/pkg/tm"
"github.com/stretchr/testify/assert"
)

func TestATConn_ExecContext(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()

mockMgr := initMockResourceManager(t, ctrl)
_ = mockMgr

db, err := sql.Open("seata-at-mysql", "root:seata_go@tcp(127.0.0.1:3306)/seata_go_test?multiStatements=true")
if err != nil {
t.Fatal(err)
}

defer db.Close()

_ = initMockAtConnector(t, ctrl, db, func(t *testing.T, ctrl *gomock.Controller) *mock.MockTestDriverConnector {
mockTx := mock.NewMockTestDriverTx(ctrl)
mockTx.EXPECT().Commit().AnyTimes().Return(nil)
mockTx.EXPECT().Rollback().AnyTimes().Return(nil)

mockConn := mock.NewMockTestDriverConn(ctrl)
mockConn.EXPECT().Begin().AnyTimes().Return(mockTx, nil)
mockConn.EXPECT().BeginTx(gomock.Any(), gomock.Any()).AnyTimes().Return(mockTx, nil)
baseMoclConn(mockConn)

connector := mock.NewMockTestDriverConnector(ctrl)
connector.EXPECT().Connect(gomock.Any()).AnyTimes().Return(mockConn, nil)
return connector
})

mi := &mockSQLInterceptor{}

ti := &mockTxHook{}

exec.CleanCommonHook()
CleanTxHooks()
exec.RegisCommonHook(mi)
RegisterTxHook(ti)

t.Run("have xid", func(t *testing.T) {
ctx := tm.InitSeataContext(context.Background())
tm.SetXID(ctx, uuid.New().String())
t.Logf("set xid=%s", tm.GetXID(ctx))

before := func(_ context.Context, execCtx *exec.ExecContext) {
t.Logf("on exec xid=%s", execCtx.TxCtx.XaID)
assert.Equal(t, tm.GetXID(ctx), execCtx.TxCtx.XaID)
assert.Equal(t, types.ATMode, execCtx.TxCtx.TransType)
}
mi.before = before

var comitCnt int32
beforeCommit := func(tx *Tx) {
atomic.AddInt32(&comitCnt, 1)
assert.Equal(t, types.ATMode, tx.ctx.TransType)
}
ti.beforeCommit = beforeCommit

conn, err := db.Conn(context.Background())
assert.NoError(t, err)

_, err = conn.ExecContext(ctx, "SELECT 1")
assert.NoError(t, err)
_, err = db.ExecContext(ctx, "SELECT 1")
assert.NoError(t, err)

assert.Equal(t, int32(2), atomic.LoadInt32(&comitCnt))
})

t.Run("not xid", func(t *testing.T) {
mi.before = func(_ context.Context, execCtx *exec.ExecContext) {
assert.Equal(t, "", execCtx.TxCtx.XaID)
assert.Equal(t, types.Local, execCtx.TxCtx.TransType)
}

var comitCnt int32
ti.beforeCommit = func(tx *Tx) {
atomic.AddInt32(&comitCnt, 1)
}

conn, err := db.Conn(context.Background())
assert.NoError(t, err)

_, err = conn.ExecContext(context.Background(), "SELECT 1")
assert.NoError(t, err)
_, err = db.ExecContext(context.Background(), "SELECT 1")
assert.NoError(t, err)

_, err = db.Exec("SELECT 1")
assert.NoError(t, err)

assert.Equal(t, int32(0), atomic.LoadInt32(&comitCnt))
})
}

+ 86
- 0
pkg/datasource/sql/conn_xa.go View File

@@ -0,0 +1,86 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package sql

import (
"context"
"database/sql/driver"

"github.com/seata/seata-go/pkg/datasource/sql/types"
"github.com/seata/seata-go/pkg/tm"
)

type XAConn struct {
*Conn
}

// PrepareContext
func (c *XAConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
if c.createTxCtxIfAbsent(ctx) {
defer func() {
c.txCtx = nil
}()
}

return c.Conn.PrepareContext(ctx, query)
}

// ExecContext
func (c *XAConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
if c.createTxCtxIfAbsent(ctx) {
defer func() {
c.txCtx = nil
}()
}

return c.Conn.ExecContext(ctx, query, args)
}

// BeginTx
func (c *XAConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
c.txCtx = types.NewTxCtx()
c.txCtx.DBType = c.res.dbType
c.txCtx.TxOpt = opts

if IsGlobalTx(ctx) {
c.txCtx.TransType = types.XAMode
c.txCtx.XaID = tm.GetXID(ctx)
}

return c.Conn.BeginTx(ctx, opts)
}

func (c *XAConn) createTxCtxIfAbsent(ctx context.Context) bool {
var onceTx bool

if IsGlobalTx(ctx) && c.txCtx == nil {
c.txCtx = types.NewTxCtx()
c.txCtx.DBType = c.res.dbType
c.txCtx.XaID = tm.GetXID(ctx)
c.txCtx.TransType = types.XAMode
c.autoCommit = true
onceTx = true
}

if c.txCtx == nil {
c.txCtx = types.NewTxCtx()
onceTx = true
}

return onceTx
}

+ 179
- 0
pkg/datasource/sql/conn_xa_test.go View File

@@ -0,0 +1,179 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package sql

import (
"context"
"database/sql"
"database/sql/driver"
"sync/atomic"
"testing"

"github.com/golang/mock/gomock"
"github.com/google/uuid"
"github.com/seata/seata-go/pkg/datasource/sql/exec"
"github.com/seata/seata-go/pkg/datasource/sql/mock"
"github.com/seata/seata-go/pkg/datasource/sql/types"
"github.com/seata/seata-go/pkg/tm"
"github.com/stretchr/testify/assert"
)

type mockSQLInterceptor struct {
before func(ctx context.Context, execCtx *exec.ExecContext)
after func(ctx context.Context, execCtx *exec.ExecContext)
}

func (mi *mockSQLInterceptor) Type() types.SQLType {
return types.SQLTypeUnknown
}

// Before
func (mi *mockSQLInterceptor) Before(ctx context.Context, execCtx *exec.ExecContext) {
if mi.before != nil {
mi.before(ctx, execCtx)
}
}

// After
func (mi *mockSQLInterceptor) After(ctx context.Context, execCtx *exec.ExecContext) {
if mi.after != nil {
mi.after(ctx, execCtx)
}
}

type mockTxHook struct {
beforeCommit func(tx *Tx)
beforeRollback func(tx *Tx)
}

// BeforeCommit
func (mi *mockTxHook) BeforeCommit(tx *Tx) {
if mi.beforeCommit != nil {
mi.beforeCommit(tx)
}
}

// BeforeRollback
func (mi *mockTxHook) BeforeRollback(tx *Tx) {
if mi.beforeRollback != nil {
mi.beforeRollback(tx)
}
}

func baseMoclConn(mockConn *mock.MockTestDriverConn) {
mockConn.EXPECT().ExecContext(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes().Return(&driver.ResultNoRows, nil)
mockConn.EXPECT().Exec(gomock.Any(), gomock.Any()).AnyTimes().Return(&driver.ResultNoRows, nil)
mockConn.EXPECT().ResetSession(gomock.Any()).AnyTimes().Return(nil)
mockConn.EXPECT().Close().AnyTimes().Return(nil)
}

func TestXAConn_ExecContext(t *testing.T) {

ctrl := gomock.NewController(t)
defer ctrl.Finish()

mockMgr := initMockResourceManager(t, ctrl)
_ = mockMgr

db, err := sql.Open("seata-xa-mysql", "root:seata_go@tcp(127.0.0.1:3306)/seata_go_test?multiStatements=true")
if err != nil {
t.Fatal(err)
}

defer db.Close()

_ = initMockXaConnector(t, ctrl, db, func(t *testing.T, ctrl *gomock.Controller) *mock.MockTestDriverConnector {
mockTx := mock.NewMockTestDriverTx(ctrl)
mockTx.EXPECT().Commit().AnyTimes().Return(nil)
mockTx.EXPECT().Rollback().AnyTimes().Return(nil)

mockConn := mock.NewMockTestDriverConn(ctrl)
mockConn.EXPECT().Begin().AnyTimes().Return(mockTx, nil)
mockConn.EXPECT().BeginTx(gomock.Any(), gomock.Any()).AnyTimes().Return(mockTx, nil)
baseMoclConn(mockConn)

connector := mock.NewMockTestDriverConnector(ctrl)
connector.EXPECT().Connect(gomock.Any()).AnyTimes().Return(mockConn, nil)
return connector
})

mi := &mockSQLInterceptor{}
ti := &mockTxHook{}

exec.CleanCommonHook()
CleanTxHooks()
exec.RegisCommonHook(mi)
RegisterTxHook(ti)

t.Run("have xid", func(t *testing.T) {
ctx := tm.InitSeataContext(context.Background())
tm.SetXID(ctx, uuid.New().String())
t.Logf("set xid=%s", tm.GetXID(ctx))

before := func(_ context.Context, execCtx *exec.ExecContext) {
t.Logf("on exec xid=%s", execCtx.TxCtx.XaID)
assert.Equal(t, tm.GetXID(ctx), execCtx.TxCtx.XaID)
assert.Equal(t, types.XAMode, execCtx.TxCtx.TransType)
}
mi.before = before

var comitCnt int32
beforeCommit := func(tx *Tx) {
atomic.AddInt32(&comitCnt, 1)
assert.Equal(t, tx.ctx.TransType, types.XAMode)
}
ti.beforeCommit = beforeCommit

conn, err := db.Conn(context.Background())
assert.NoError(t, err)

_, err = conn.ExecContext(ctx, "SELECT 1")
assert.NoError(t, err)
_, err = db.ExecContext(ctx, "SELECT 1")
assert.NoError(t, err)

assert.Equal(t, int32(2), atomic.LoadInt32(&comitCnt))
})

t.Run("not xid", func(t *testing.T) {
before := func(_ context.Context, execCtx *exec.ExecContext) {
assert.Equal(t, "", execCtx.TxCtx.XaID)
assert.Equal(t, types.Local, execCtx.TxCtx.TransType)
}
mi.before = before

var comitCnt int32
beforeCommit := func(tx *Tx) {
atomic.AddInt32(&comitCnt, 1)
}
ti.beforeCommit = beforeCommit

conn, err := db.Conn(context.Background())
assert.NoError(t, err)

_, err = conn.ExecContext(context.Background(), "SELECT 1")
assert.NoError(t, err)
_, err = db.ExecContext(context.Background(), "SELECT 1")
assert.NoError(t, err)

_, err = db.Exec("SELECT 1")
assert.NoError(t, err)

assert.Equal(t, int32(0), atomic.LoadInt32(&comitCnt))
})
}

+ 72
- 7
pkg/datasource/sql/connector.go View File

@@ -21,14 +21,77 @@ import (
"context"
"database/sql/driver"
"sync"

"github.com/seata/seata-go/pkg/datasource/sql/types"
)

type seataATConnector struct {
*seataConnector
transType types.TransactionType
}

func (c *seataATConnector) Connect(ctx context.Context) (driver.Conn, error) {
conn, err := c.seataConnector.Connect(ctx)
if err != nil {
return nil, err
}

_conn, _ := conn.(*Conn)

return &ATConn{
Conn: _conn,
}, nil
}

func (c *seataATConnector) Driver() driver.Driver {
return &seataATDriver{
seataDriver: c.seataConnector.Driver().(*seataDriver),
}
}

type seataXAConnector struct {
*seataConnector
transType types.TransactionType
}

func (c *seataXAConnector) Connect(ctx context.Context) (driver.Conn, error) {
conn, err := c.seataConnector.Connect(ctx)
if err != nil {
return nil, err
}

_conn, _ := conn.(*Conn)

return &XAConn{
Conn: _conn,
}, nil
}

func (c *seataXAConnector) Driver() driver.Driver {
return &seataXADriver{
seataDriver: c.seataConnector.Driver().(*seataDriver),
}
}

// A Connector represents a driver in a fixed configuration
// and can create any number of equivalent Conns for use
// by multiple goroutines.
//
// A Connector can be passed to sql.OpenDB, to allow drivers
// to implement their own sql.DB constructors, or returned by
// DriverContext's OpenConnector method, to allow drivers
// access to context and to avoid repeated parsing of driver
// configuration.
//
// If a Connector implements io.Closer, the sql package's DB.Close
// method will call Close and return error (if any).
type seataConnector struct {
conf *seataServerConfig
res *DBResource
once sync.Once
driver driver.Driver
target driver.Connector
transType types.TransactionType
conf *seataServerConfig
res *DBResource
once sync.Once
driver driver.Driver
target driver.Connector
}

// Connect returns a connection to the database.
@@ -50,7 +113,7 @@ func (c *seataConnector) Connect(ctx context.Context) (driver.Conn, error) {
return nil, err
}

return &Conn{targetConn: conn, res: c.res}, nil
return &Conn{txType: types.Local, targetConn: conn, res: c.res}, nil
}

// Driver returns the underlying Driver of the Connector,
@@ -62,5 +125,7 @@ func (c *seataConnector) Driver() driver.Driver {
c.driver = d
})

return &SeataDriver{target: c.driver}
return &seataDriver{
target: c.driver,
}
}

+ 129
- 0
pkg/datasource/sql/connector_test.go View File

@@ -0,0 +1,129 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package sql

import (
"context"
"database/sql"
"database/sql/driver"
"reflect"
"testing"

"github.com/golang/mock/gomock"
"github.com/seata/seata-go/pkg/datasource/sql/mock"
"github.com/seata/seata-go/pkg/datasource/sql/types"
"github.com/stretchr/testify/assert"
)

type initConnectorFunc func(t *testing.T, ctrl *gomock.Controller) *mock.MockTestDriverConnector

func initMockConnector(t *testing.T, ctrl *gomock.Controller) *mock.MockTestDriverConnector {
mockConn := mock.NewMockTestDriverConn(ctrl)

connector := mock.NewMockTestDriverConnector(ctrl)
connector.EXPECT().Connect(gomock.Any()).AnyTimes().Return(mockConn, nil)
return connector
}

func initMockAtConnector(t *testing.T, ctrl *gomock.Controller, db *sql.DB, f initConnectorFunc) driver.Connector {
v := reflect.ValueOf(db)
if v.Kind() == reflect.Ptr {
v = v.Elem()
}

field := v.FieldByName("connector")
fieldVal := GetUnexportedField(field)

atConnector, ok := fieldVal.(*seataATConnector)
assert.True(t, ok, "need return seata at connector")

v = reflect.ValueOf(atConnector)
if v.Kind() == reflect.Ptr {
v = v.Elem()
}
SetUnexportedField(v.FieldByName("target"), f(t, ctrl))

return fieldVal.(driver.Connector)
}

func Test_seataATConnector_Connect(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()

mockMgr := initMockResourceManager(t, ctrl)
_ = mockMgr

db, err := sql.Open("seata-at-mysql", "root:seata_go@tcp(127.0.0.1:3306)/seata_go_test?multiStatements=true")
if err != nil {
t.Fatal(err)
}

defer db.Close()

proxyConnector := initMockAtConnector(t, ctrl, db, initMockConnector)
conn, err := proxyConnector.Connect(context.Background())
assert.NoError(t, err)

atConn, ok := conn.(*ATConn)
assert.True(t, ok, "need return seata at connection")
assert.True(t, atConn.txType == types.Local, "init need local tx")
}

func initMockXaConnector(t *testing.T, ctrl *gomock.Controller, db *sql.DB, f initConnectorFunc) driver.Connector {
v := reflect.ValueOf(db)
if v.Kind() == reflect.Ptr {
v = v.Elem()
}

field := v.FieldByName("connector")
fieldVal := GetUnexportedField(field)

atConnector, ok := fieldVal.(*seataXAConnector)
assert.True(t, ok, "need return seata xa connector")

v = reflect.ValueOf(atConnector)
if v.Kind() == reflect.Ptr {
v = v.Elem()
}
SetUnexportedField(v.FieldByName("target"), f(t, ctrl))

return fieldVal.(driver.Connector)
}

func Test_seataXAConnector_Connect(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()

mockMgr := initMockResourceManager(t, ctrl)
_ = mockMgr

db, err := sql.Open("seata-xa-mysql", "root:seata_go@tcp(127.0.0.1:3306)/seata_go_test?multiStatements=true")
if err != nil {
t.Fatal(err)
}

defer db.Close()

proxyConnector := initMockXaConnector(t, ctrl, db, initMockConnector)
conn, err := proxyConnector.Connect(context.Background())
assert.NoError(t, err)

xaConn, ok := conn.(*XAConn)
assert.True(t, ok, "need return seata xa connection")
assert.True(t, xaConn.txType == types.Local, "init need local tx")
}

+ 29
- 0
pkg/datasource/sql/context.go View File

@@ -0,0 +1,29 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package sql

import (
"context"

"github.com/seata/seata-go/pkg/tm"
)

// IsGlobalTx check is open global transactions
func IsGlobalTx(ctx context.Context) bool {
return tm.IsTransactionOpened(ctx)
}

+ 63
- 14
pkg/datasource/sql/driver.go View File

@@ -35,20 +35,69 @@ import (
)

const (
SeataMySQLDriver = "seata-mysql"
// SeataATMySQLDriver MySQL driver for AT mode
SeataATMySQLDriver = "seata-at-mysql"
// SeataXAMySQLDriver MySQL driver for XA mode
SeataXAMySQLDriver = "seata-xa-mysql"
)

func init() {
sql.Register(SeataMySQLDriver, &SeataDriver{
target: mysql.MySQLDriver{},
sql.Register(SeataATMySQLDriver, &seataATDriver{
seataDriver: &seataDriver{
transType: types.ATMode,
target: mysql.MySQLDriver{},
},
})
sql.Register(SeataXAMySQLDriver, &seataXADriver{
seataDriver: &seataDriver{
transType: types.XAMode,
target: mysql.MySQLDriver{},
},
})
}

type seataATDriver struct {
*seataDriver
}

func (d *seataATDriver) OpenConnector(name string) (c driver.Connector, err error) {
connector, err := d.seataDriver.OpenConnector(name)
if err != nil {
return nil, err
}

_connector, _ := connector.(*seataConnector)
_connector.transType = types.ATMode

return &seataATConnector{
seataConnector: _connector,
}, nil
}

type seataXADriver struct {
*seataDriver
}

func (d *seataXADriver) OpenConnector(name string) (c driver.Connector, err error) {
connector, err := d.seataDriver.OpenConnector(name)
if err != nil {
return nil, err
}

_connector, _ := connector.(*seataConnector)
_connector.transType = types.XAMode

return &seataXAConnector{
seataConnector: _connector,
}, nil
}

type SeataDriver struct {
target driver.Driver
type seataDriver struct {
transType types.TransactionType
target driver.Driver
}

func (d *SeataDriver) Open(name string) (driver.Conn, error) {
func (d *seataDriver) Open(name string) (driver.Conn, error) {
conn, err := d.target.Open(name)
if err != nil {
log.Errorf("open target connection: %w", err)
@@ -71,7 +120,7 @@ func (d *SeataDriver) Open(name string) (driver.Conn, error) {
return conn, nil
}

func (d *SeataDriver) OpenConnector(name string) (c driver.Connector, err error) {
func (d *seataDriver) OpenConnector(name string) (c driver.Connector, err error) {
c = &dsnConnector{dsn: name, driver: d.target}
if driverCtx, ok := d.target.(driver.DriverContext); ok {
c, err = driverCtx.OpenConnector(name)
@@ -86,7 +135,7 @@ func (d *SeataDriver) OpenConnector(name string) (c driver.Connector, err error)
return nil, fmt.Errorf("unsupport conn type %s", d.getTargetDriverName())
}

proxy, err := registerResource(c, dbType, sql.OpenDB(c), name)
proxy, err := registerResource(c, d.transType, dbType, sql.OpenDB(c), name)
if err != nil {
log.Errorf("register resource: %w", err)
return nil, err
@@ -95,8 +144,8 @@ func (d *SeataDriver) OpenConnector(name string) (c driver.Connector, err error)
return proxy, nil
}

func (d *SeataDriver) getTargetDriverName() string {
return strings.ReplaceAll(SeataMySQLDriver, "seata-", "")
func (d *seataDriver) getTargetDriverName() string {
return "mysql"
}

type dsnConnector struct {
@@ -112,7 +161,7 @@ func (t *dsnConnector) Driver() driver.Driver {
return t.driver
}

func registerResource(connector driver.Connector, dbType types.DBType, db *sql.DB,
func registerResource(connector driver.Connector, txType types.TransactionType, dbType types.DBType, db *sql.DB,
dataSourceName string, opts ...seataOption) (driver.Connector, error) {
conf := loadConfig()
for i := range opts {
@@ -144,9 +193,9 @@ func registerResource(connector driver.Connector, dbType types.DBType, db *sql.D
}

return &seataConnector{
res: res,
target: connector,
conf: conf,
res: res,
target: connector,
conf: conf,
}, nil
}



+ 93
- 0
pkg/datasource/sql/driver_test.go View File

@@ -0,0 +1,93 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package sql

import (
"database/sql"
"reflect"
"testing"

"github.com/golang/mock/gomock"

"github.com/seata/seata-go/pkg/datasource/sql/datasource"
"github.com/seata/seata-go/pkg/datasource/sql/mock"
"github.com/seata/seata-go/pkg/protocol/branch"
"github.com/stretchr/testify/assert"
)

func initMockResourceManager(t *testing.T, ctrl *gomock.Controller) *mock.MockDataSourceManager {
mockResourceMgr := mock.NewMockDataSourceManager(ctrl)
datasource.RegisterResourceManager(branch.BranchTypeAT, mockResourceMgr)

mockResourceMgr.EXPECT().RegisterResource(gomock.Any()).AnyTimes().Return(nil)
mockResourceMgr.EXPECT().CreateTableMetaCache(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes().Return(nil, nil)

return mockResourceMgr
}

func Test_seataATDriver_OpenConnector(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()

mockMgr := initMockResourceManager(t, ctrl)
_ = mockMgr

db, err := sql.Open("seata-at-mysql", "root:seata_go@tcp(127.0.0.1:3306)/seata_go_test?multiStatements=true")
if err != nil {
t.Fatal(err)
}

defer db.Close()

v := reflect.ValueOf(db)
if v.Kind() == reflect.Ptr {
v = v.Elem()
}

field := v.FieldByName("connector")
fieldVal := GetUnexportedField(field)

_, ok := fieldVal.(*seataATConnector)
assert.True(t, ok, "need return seata at connector")
}

func Test_seataXADriver_OpenConnector(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()

mockMgr := initMockResourceManager(t, ctrl)
_ = mockMgr

db, err := sql.Open("seata-xa-mysql", "root:seata_go@tcp(127.0.0.1:3306)/seata_go_test?multiStatements=true")
if err != nil {
t.Fatal(err)
}

defer db.Close()

v := reflect.ValueOf(db)
if v.Kind() == reflect.Ptr {
v = v.Elem()
}

field := v.FieldByName("connector")
fieldVal := GetUnexportedField(field)

_, ok := fieldVal.(*seataXAConnector)
assert.True(t, ok, "need return seata xa connector")
}

+ 14
- 5
pkg/datasource/sql/exec/executor.go View File

@@ -56,8 +56,17 @@ type (
}
)

// buildExecutor
func BuildExecutor(dbType types.DBType, query string) (SQLExecutor, error) {
// BuildExecutor
func BuildExecutor(dbType types.DBType, txType types.TransactionType, query string) (SQLExecutor, error) {
if txType == types.XAMode {
hooks := make([]SQLInterceptor, 0, 4)
hooks = append(hooks, commonHook...)

e := &BaseExecutor{}
e.interceptors(hooks)
return e, nil
}

parseCtx, err := parser.DoParser(query)
if err != nil {
return nil, err
@@ -94,12 +103,12 @@ type BaseExecutor struct {
ex SQLExecutor
}

// Interceptors
// interceptors
func (e *BaseExecutor) interceptors(interceptors []SQLInterceptor) {
e.is = interceptors
}

// Exec
// ExecWithNamedValue
func (e *BaseExecutor) ExecWithNamedValue(ctx context.Context, execCtx *ExecContext, f CallbackWithNamedValue) (types.ExecResult, error) {
for i := range e.is {
e.is[i].Before(ctx, execCtx)
@@ -118,7 +127,7 @@ func (e *BaseExecutor) ExecWithNamedValue(ctx context.Context, execCtx *ExecCont
return f(ctx, execCtx.Query, execCtx.NamedValues)
}

// Exec
// ExecWithValue
func (e *BaseExecutor) ExecWithValue(ctx context.Context, execCtx *ExecContext, f CallbackWithValue) (types.ExecResult, error) {
for i := range e.is {
e.is[i].Before(ctx, execCtx)


+ 4
- 0
pkg/datasource/sql/exec/hook.go View File

@@ -35,6 +35,10 @@ func RegisCommonHook(hook SQLInterceptor) {
commonHook = append(commonHook, hook)
}

func CleanCommonHook() {
commonHook = make([]SQLInterceptor, 0, 4)
}

// RegisHook not goroutine safe
func RegisHook(hook SQLInterceptor) {
_, ok := hookSolts[hook.Type()]


+ 237
- 0
pkg/datasource/sql/mock/mock_datasource_manager.go View File

@@ -0,0 +1,237 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: datasource_manager.go

// Package mock is a generated GoMock package.
package mock

import (
context "context"
sql "database/sql"
gomock "github.com/golang/mock/gomock"
datasource "github.com/seata/seata-go/pkg/datasource/sql/datasource"
types "github.com/seata/seata-go/pkg/datasource/sql/types"
branch "github.com/seata/seata-go/pkg/protocol/branch"
message "github.com/seata/seata-go/pkg/protocol/message"
rm "github.com/seata/seata-go/pkg/rm"
reflect "reflect"
)

// MockDataSourceManager is a mock of DataSourceManager interface
type MockDataSourceManager struct {
ctrl *gomock.Controller
recorder *MockDataSourceManagerMockRecorder
}

// MockDataSourceManagerMockRecorder is the mock recorder for MockDataSourceManager
type MockDataSourceManagerMockRecorder struct {
mock *MockDataSourceManager
}

// NewMockDataSourceManager creates a new mock instance
func NewMockDataSourceManager(ctrl *gomock.Controller) *MockDataSourceManager {
mock := &MockDataSourceManager{ctrl: ctrl}
mock.recorder = &MockDataSourceManagerMockRecorder{mock}
return mock
}

// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockDataSourceManager) EXPECT() *MockDataSourceManagerMockRecorder {
return m.recorder
}

// RegisterResource mocks base method
func (m *MockDataSourceManager) RegisterResource(resource rm.Resource) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RegisterResource", resource)
ret0, _ := ret[0].(error)
return ret0
}

// RegisterResource indicates an expected call of RegisterResource
func (mr *MockDataSourceManagerMockRecorder) RegisterResource(resource interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegisterResource", reflect.TypeOf((*MockDataSourceManager)(nil).RegisterResource), resource)
}

// UnregisterResource mocks base method
func (m *MockDataSourceManager) UnregisterResource(resource rm.Resource) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UnregisterResource", resource)
ret0, _ := ret[0].(error)
return ret0
}

// UnregisterResource indicates an expected call of UnregisterResource
func (mr *MockDataSourceManagerMockRecorder) UnregisterResource(resource interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UnregisterResource", reflect.TypeOf((*MockDataSourceManager)(nil).UnregisterResource), resource)
}

// GetManagedResources mocks base method
func (m *MockDataSourceManager) GetManagedResources() map[string]rm.Resource {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetManagedResources")
ret0, _ := ret[0].(map[string]rm.Resource)
return ret0
}

// GetManagedResources indicates an expected call of GetManagedResources
func (mr *MockDataSourceManagerMockRecorder) GetManagedResources() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetManagedResources", reflect.TypeOf((*MockDataSourceManager)(nil).GetManagedResources))
}

// BranchRollback mocks base method
func (m *MockDataSourceManager) BranchRollback(ctx context.Context, req message.BranchRollbackRequest) (branch.BranchStatus, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "BranchRollback", ctx, req)
ret0, _ := ret[0].(branch.BranchStatus)
ret1, _ := ret[1].(error)
return ret0, ret1
}

// BranchRollback indicates an expected call of BranchRollback
func (mr *MockDataSourceManagerMockRecorder) BranchRollback(ctx, req interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BranchRollback", reflect.TypeOf((*MockDataSourceManager)(nil).BranchRollback), ctx, req)
}

// BranchCommit mocks base method
func (m *MockDataSourceManager) BranchCommit(ctx context.Context, req message.BranchCommitRequest) (branch.BranchStatus, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "BranchCommit", ctx, req)
ret0, _ := ret[0].(branch.BranchStatus)
ret1, _ := ret[1].(error)
return ret0, ret1
}

// BranchCommit indicates an expected call of BranchCommit
func (mr *MockDataSourceManagerMockRecorder) BranchCommit(ctx, req interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BranchCommit", reflect.TypeOf((*MockDataSourceManager)(nil).BranchCommit), ctx, req)
}

// LockQuery mocks base method
func (m *MockDataSourceManager) LockQuery(ctx context.Context, req message.GlobalLockQueryRequest) (bool, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "LockQuery", ctx, req)
ret0, _ := ret[0].(bool)
ret1, _ := ret[1].(error)
return ret0, ret1
}

// LockQuery indicates an expected call of LockQuery
func (mr *MockDataSourceManagerMockRecorder) LockQuery(ctx, req interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LockQuery", reflect.TypeOf((*MockDataSourceManager)(nil).LockQuery), ctx, req)
}

// BranchRegister mocks base method
func (m *MockDataSourceManager) BranchRegister(ctx context.Context, clientId string, req message.BranchRegisterRequest) (int64, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "BranchRegister", ctx, clientId, req)
ret0, _ := ret[0].(int64)
ret1, _ := ret[1].(error)
return ret0, ret1
}

// BranchRegister indicates an expected call of BranchRegister
func (mr *MockDataSourceManagerMockRecorder) BranchRegister(ctx, clientId, req interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BranchRegister", reflect.TypeOf((*MockDataSourceManager)(nil).BranchRegister), ctx, clientId, req)
}

// BranchReport mocks base method
func (m *MockDataSourceManager) BranchReport(ctx context.Context, req message.BranchReportRequest) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "BranchReport", ctx, req)
ret0, _ := ret[0].(error)
return ret0
}

// BranchReport indicates an expected call of BranchReport
func (mr *MockDataSourceManagerMockRecorder) BranchReport(ctx, req interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BranchReport", reflect.TypeOf((*MockDataSourceManager)(nil).BranchReport), ctx, req)
}

// CreateTableMetaCache mocks base method
func (m *MockDataSourceManager) CreateTableMetaCache(ctx context.Context, resID string, dbType types.DBType, db *sql.DB) (datasource.TableMetaCache, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CreateTableMetaCache", ctx, resID, dbType, db)
ret0, _ := ret[0].(datasource.TableMetaCache)
ret1, _ := ret[1].(error)
return ret0, ret1
}

// CreateTableMetaCache indicates an expected call of CreateTableMetaCache
func (mr *MockDataSourceManagerMockRecorder) CreateTableMetaCache(ctx, resID, dbType, db interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateTableMetaCache", reflect.TypeOf((*MockDataSourceManager)(nil).CreateTableMetaCache), ctx, resID, dbType, db)
}

// MockTableMetaCache is a mock of TableMetaCache interface
type MockTableMetaCache struct {
ctrl *gomock.Controller
recorder *MockTableMetaCacheMockRecorder
}

// MockTableMetaCacheMockRecorder is the mock recorder for MockTableMetaCache
type MockTableMetaCacheMockRecorder struct {
mock *MockTableMetaCache
}

// NewMockTableMetaCache creates a new mock instance
func NewMockTableMetaCache(ctrl *gomock.Controller) *MockTableMetaCache {
mock := &MockTableMetaCache{ctrl: ctrl}
mock.recorder = &MockTableMetaCacheMockRecorder{mock}
return mock
}

// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockTableMetaCache) EXPECT() *MockTableMetaCacheMockRecorder {
return m.recorder
}

// Init mocks base method
func (m *MockTableMetaCache) Init(ctx context.Context, conn *sql.DB) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Init", ctx, conn)
ret0, _ := ret[0].(error)
return ret0
}

// Init indicates an expected call of Init
func (mr *MockTableMetaCacheMockRecorder) Init(ctx, conn interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Init", reflect.TypeOf((*MockTableMetaCache)(nil).Init), ctx, conn)
}

// GetTableMeta mocks base method
func (m *MockTableMetaCache) GetTableMeta(table string) (types.TableMeta, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetTableMeta", table)
ret0, _ := ret[0].(types.TableMeta)
ret1, _ := ret[1].(error)
return ret0, ret1
}

// GetTableMeta indicates an expected call of GetTableMeta
func (mr *MockTableMetaCacheMockRecorder) GetTableMeta(table interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTableMeta", reflect.TypeOf((*MockTableMetaCache)(nil).GetTableMeta), table)
}

// Destroy mocks base method
func (m *MockTableMetaCache) Destroy() error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Destroy")
ret0, _ := ret[0].(error)
return ret0
}

// Destroy indicates an expected call of Destroy
func (mr *MockTableMetaCacheMockRecorder) Destroy() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Destroy", reflect.TypeOf((*MockTableMetaCache)(nil).Destroy))
}

+ 411
- 0
pkg/datasource/sql/mock/mock_driver.go View File

@@ -0,0 +1,411 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: test_driver.go

// Package mock is a generated GoMock package.
package mock

import (
context "context"
driver "database/sql/driver"
gomock "github.com/golang/mock/gomock"
reflect "reflect"
)

// MockTestDriverConnector is a mock of TestDriverConnector interface
type MockTestDriverConnector struct {
ctrl *gomock.Controller
recorder *MockTestDriverConnectorMockRecorder
}

// MockTestDriverConnectorMockRecorder is the mock recorder for MockTestDriverConnector
type MockTestDriverConnectorMockRecorder struct {
mock *MockTestDriverConnector
}

// NewMockTestDriverConnector creates a new mock instance
func NewMockTestDriverConnector(ctrl *gomock.Controller) *MockTestDriverConnector {
mock := &MockTestDriverConnector{ctrl: ctrl}
mock.recorder = &MockTestDriverConnectorMockRecorder{mock}
return mock
}

// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockTestDriverConnector) EXPECT() *MockTestDriverConnectorMockRecorder {
return m.recorder
}

// Connect mocks base method
func (m *MockTestDriverConnector) Connect(arg0 context.Context) (driver.Conn, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Connect", arg0)
ret0, _ := ret[0].(driver.Conn)
ret1, _ := ret[1].(error)
return ret0, ret1
}

// Connect indicates an expected call of Connect
func (mr *MockTestDriverConnectorMockRecorder) Connect(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Connect", reflect.TypeOf((*MockTestDriverConnector)(nil).Connect), arg0)
}

// Driver mocks base method
func (m *MockTestDriverConnector) Driver() driver.Driver {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Driver")
ret0, _ := ret[0].(driver.Driver)
return ret0
}

// Driver indicates an expected call of Driver
func (mr *MockTestDriverConnectorMockRecorder) Driver() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Driver", reflect.TypeOf((*MockTestDriverConnector)(nil).Driver))
}

// MockTestDriverConn is a mock of TestDriverConn interface
type MockTestDriverConn struct {
ctrl *gomock.Controller
recorder *MockTestDriverConnMockRecorder
}

// MockTestDriverConnMockRecorder is the mock recorder for MockTestDriverConn
type MockTestDriverConnMockRecorder struct {
mock *MockTestDriverConn
}

// NewMockTestDriverConn creates a new mock instance
func NewMockTestDriverConn(ctrl *gomock.Controller) *MockTestDriverConn {
mock := &MockTestDriverConn{ctrl: ctrl}
mock.recorder = &MockTestDriverConnMockRecorder{mock}
return mock
}

// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockTestDriverConn) EXPECT() *MockTestDriverConnMockRecorder {
return m.recorder
}

// Prepare mocks base method
func (m *MockTestDriverConn) Prepare(query string) (driver.Stmt, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Prepare", query)
ret0, _ := ret[0].(driver.Stmt)
ret1, _ := ret[1].(error)
return ret0, ret1
}

// Prepare indicates an expected call of Prepare
func (mr *MockTestDriverConnMockRecorder) Prepare(query interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Prepare", reflect.TypeOf((*MockTestDriverConn)(nil).Prepare), query)
}

// Close mocks base method
func (m *MockTestDriverConn) Close() error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Close")
ret0, _ := ret[0].(error)
return ret0
}

// Close indicates an expected call of Close
func (mr *MockTestDriverConnMockRecorder) Close() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockTestDriverConn)(nil).Close))
}

// Begin mocks base method
func (m *MockTestDriverConn) Begin() (driver.Tx, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Begin")
ret0, _ := ret[0].(driver.Tx)
ret1, _ := ret[1].(error)
return ret0, ret1
}

// Begin indicates an expected call of Begin
func (mr *MockTestDriverConnMockRecorder) Begin() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Begin", reflect.TypeOf((*MockTestDriverConn)(nil).Begin))
}

// BeginTx mocks base method
func (m *MockTestDriverConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "BeginTx", ctx, opts)
ret0, _ := ret[0].(driver.Tx)
ret1, _ := ret[1].(error)
return ret0, ret1
}

// BeginTx indicates an expected call of BeginTx
func (mr *MockTestDriverConnMockRecorder) BeginTx(ctx, opts interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BeginTx", reflect.TypeOf((*MockTestDriverConn)(nil).BeginTx), ctx, opts)
}

// Ping mocks base method
func (m *MockTestDriverConn) Ping(ctx context.Context) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Ping", ctx)
ret0, _ := ret[0].(error)
return ret0
}

// Ping indicates an expected call of Ping
func (mr *MockTestDriverConnMockRecorder) Ping(ctx interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Ping", reflect.TypeOf((*MockTestDriverConn)(nil).Ping), ctx)
}

// PrepareContext mocks base method
func (m *MockTestDriverConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "PrepareContext", ctx, query)
ret0, _ := ret[0].(driver.Stmt)
ret1, _ := ret[1].(error)
return ret0, ret1
}

// PrepareContext indicates an expected call of PrepareContext
func (mr *MockTestDriverConnMockRecorder) PrepareContext(ctx, query interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PrepareContext", reflect.TypeOf((*MockTestDriverConn)(nil).PrepareContext), ctx, query)
}

// Query mocks base method
func (m *MockTestDriverConn) Query(query string, args []driver.Value) (driver.Rows, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Query", query, args)
ret0, _ := ret[0].(driver.Rows)
ret1, _ := ret[1].(error)
return ret0, ret1
}

// Query indicates an expected call of Query
func (mr *MockTestDriverConnMockRecorder) Query(query, args interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Query", reflect.TypeOf((*MockTestDriverConn)(nil).Query), query, args)
}

// QueryContext mocks base method
func (m *MockTestDriverConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "QueryContext", ctx, query, args)
ret0, _ := ret[0].(driver.Rows)
ret1, _ := ret[1].(error)
return ret0, ret1
}

// QueryContext indicates an expected call of QueryContext
func (mr *MockTestDriverConnMockRecorder) QueryContext(ctx, query, args interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryContext", reflect.TypeOf((*MockTestDriverConn)(nil).QueryContext), ctx, query, args)
}

// Exec mocks base method
func (m *MockTestDriverConn) Exec(query string, args []driver.Value) (driver.Result, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Exec", query, args)
ret0, _ := ret[0].(driver.Result)
ret1, _ := ret[1].(error)
return ret0, ret1
}

// Exec indicates an expected call of Exec
func (mr *MockTestDriverConnMockRecorder) Exec(query, args interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockTestDriverConn)(nil).Exec), query, args)
}

// ExecContext mocks base method
func (m *MockTestDriverConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ExecContext", ctx, query, args)
ret0, _ := ret[0].(driver.Result)
ret1, _ := ret[1].(error)
return ret0, ret1
}

// ExecContext indicates an expected call of ExecContext
func (mr *MockTestDriverConnMockRecorder) ExecContext(ctx, query, args interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ExecContext", reflect.TypeOf((*MockTestDriverConn)(nil).ExecContext), ctx, query, args)
}

// ResetSession mocks base method
func (m *MockTestDriverConn) ResetSession(ctx context.Context) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ResetSession", ctx)
ret0, _ := ret[0].(error)
return ret0
}

// ResetSession indicates an expected call of ResetSession
func (mr *MockTestDriverConnMockRecorder) ResetSession(ctx interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ResetSession", reflect.TypeOf((*MockTestDriverConn)(nil).ResetSession), ctx)
}

// MockTestDriverStmt is a mock of TestDriverStmt interface
type MockTestDriverStmt struct {
ctrl *gomock.Controller
recorder *MockTestDriverStmtMockRecorder
}

// MockTestDriverStmtMockRecorder is the mock recorder for MockTestDriverStmt
type MockTestDriverStmtMockRecorder struct {
mock *MockTestDriverStmt
}

// NewMockTestDriverStmt creates a new mock instance
func NewMockTestDriverStmt(ctrl *gomock.Controller) *MockTestDriverStmt {
mock := &MockTestDriverStmt{ctrl: ctrl}
mock.recorder = &MockTestDriverStmtMockRecorder{mock}
return mock
}

// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockTestDriverStmt) EXPECT() *MockTestDriverStmtMockRecorder {
return m.recorder
}

// Close mocks base method
func (m *MockTestDriverStmt) Close() error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Close")
ret0, _ := ret[0].(error)
return ret0
}

// Close indicates an expected call of Close
func (mr *MockTestDriverStmtMockRecorder) Close() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockTestDriverStmt)(nil).Close))
}

// NumInput mocks base method
func (m *MockTestDriverStmt) NumInput() int {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "NumInput")
ret0, _ := ret[0].(int)
return ret0
}

// NumInput indicates an expected call of NumInput
func (mr *MockTestDriverStmtMockRecorder) NumInput() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NumInput", reflect.TypeOf((*MockTestDriverStmt)(nil).NumInput))
}

// Exec mocks base method
func (m *MockTestDriverStmt) Exec(args []driver.Value) (driver.Result, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Exec", args)
ret0, _ := ret[0].(driver.Result)
ret1, _ := ret[1].(error)
return ret0, ret1
}

// Exec indicates an expected call of Exec
func (mr *MockTestDriverStmtMockRecorder) Exec(args interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockTestDriverStmt)(nil).Exec), args)
}

// Query mocks base method
func (m *MockTestDriverStmt) Query(args []driver.Value) (driver.Rows, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Query", args)
ret0, _ := ret[0].(driver.Rows)
ret1, _ := ret[1].(error)
return ret0, ret1
}

// Query indicates an expected call of Query
func (mr *MockTestDriverStmtMockRecorder) Query(args interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Query", reflect.TypeOf((*MockTestDriverStmt)(nil).Query), args)
}

// QueryContext mocks base method
func (m *MockTestDriverStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "QueryContext", ctx, args)
ret0, _ := ret[0].(driver.Rows)
ret1, _ := ret[1].(error)
return ret0, ret1
}

// QueryContext indicates an expected call of QueryContext
func (mr *MockTestDriverStmtMockRecorder) QueryContext(ctx, args interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryContext", reflect.TypeOf((*MockTestDriverStmt)(nil).QueryContext), ctx, args)
}

// ExecContext mocks base method
func (m *MockTestDriverStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ExecContext", ctx, args)
ret0, _ := ret[0].(driver.Result)
ret1, _ := ret[1].(error)
return ret0, ret1
}

// ExecContext indicates an expected call of ExecContext
func (mr *MockTestDriverStmtMockRecorder) ExecContext(ctx, args interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ExecContext", reflect.TypeOf((*MockTestDriverStmt)(nil).ExecContext), ctx, args)
}

// MockTestDriverTx is a mock of TestDriverTx interface
type MockTestDriverTx struct {
ctrl *gomock.Controller
recorder *MockTestDriverTxMockRecorder
}

// MockTestDriverTxMockRecorder is the mock recorder for MockTestDriverTx
type MockTestDriverTxMockRecorder struct {
mock *MockTestDriverTx
}

// NewMockTestDriverTx creates a new mock instance
func NewMockTestDriverTx(ctrl *gomock.Controller) *MockTestDriverTx {
mock := &MockTestDriverTx{ctrl: ctrl}
mock.recorder = &MockTestDriverTxMockRecorder{mock}
return mock
}

// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockTestDriverTx) EXPECT() *MockTestDriverTxMockRecorder {
return m.recorder
}

// Commit mocks base method
func (m *MockTestDriverTx) Commit() error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Commit")
ret0, _ := ret[0].(error)
return ret0
}

// Commit indicates an expected call of Commit
func (mr *MockTestDriverTxMockRecorder) Commit() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Commit", reflect.TypeOf((*MockTestDriverTx)(nil).Commit))
}

// Rollback mocks base method
func (m *MockTestDriverTx) Rollback() error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Rollback")
ret0, _ := ret[0].(error)
return ret0
}

// Rollback indicates an expected call of Rollback
func (mr *MockTestDriverTxMockRecorder) Rollback() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Rollback", reflect.TypeOf((*MockTestDriverTx)(nil).Rollback))
}

+ 46
- 0
pkg/datasource/sql/mock/test_driver.go View File

@@ -0,0 +1,46 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package mock

import "database/sql/driver"

type TestDriverConnector interface {
driver.Connector
}

type TestDriverConn interface {
driver.Conn
driver.ConnBeginTx
driver.Pinger
driver.ConnPrepareContext
driver.Queryer
driver.QueryerContext
driver.Execer
driver.ExecerContext
driver.SessionResetter
}

type TestDriverStmt interface {
driver.Stmt
driver.StmtQueryContext
driver.StmtExecContext
}

type TestDriverTx interface {
driver.Tx
}

+ 1
- 1
pkg/datasource/sql/sql_test.go View File

@@ -36,7 +36,7 @@ func Test_SQLOpen(t *testing.T) {
t.SkipNow()
log.Info("begin test")
var err error
db, err = sql.Open(SeataMySQLDriver, "root:12345678@tcp(127.0.0.1:3306)/polaris_server?multiStatements=true")
db, err = sql.Open("seata-at-mysql", "root:seata_go@tcp(127.0.0.1:3306)/seata_go_test?multiStatements=true")
if err != nil {
t.Fatal(err)
}


+ 4
- 4
pkg/datasource/sql/stmt.go View File

@@ -67,7 +67,7 @@ func (s *Stmt) NumInput() int {
//
// Deprecated: Drivers should implement StmtQueryContext instead (or additionally).
func (s *Stmt) Query(args []driver.Value) (driver.Rows, error) {
executor, err := exec.BuildExecutor(s.res.dbType, s.query)
executor, err := exec.BuildExecutor(s.res.dbType, s.txCtx.TransType, s.query)
if err != nil {
return nil, err
}
@@ -105,7 +105,7 @@ func (s *Stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driv
return nil, driver.ErrSkip
}

executor, err := exec.BuildExecutor(s.res.dbType, s.query)
executor, err := exec.BuildExecutor(s.res.dbType, s.txCtx.TransType, s.query)
if err != nil {
return nil, err
}
@@ -138,7 +138,7 @@ func (s *Stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driv
// Deprecated: Drivers should implement StmtExecContext instead (or additionally).
func (s *Stmt) Exec(args []driver.Value) (driver.Result, error) {
// in transaction, need run Executor
executor, err := exec.BuildExecutor(s.res.dbType, s.query)
executor, err := exec.BuildExecutor(s.res.dbType, s.txCtx.TransType, s.query)
if err != nil {
return nil, err
}
@@ -173,7 +173,7 @@ func (s *Stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (drive
}

// in transaction, need run Executor
executor, err := exec.BuildExecutor(s.res.dbType, s.query)
executor, err := exec.BuildExecutor(s.res.dbType, s.txCtx.TransType, s.query)
if err != nil {
return nil, err
}


+ 51
- 5
pkg/datasource/sql/tx.go View File

@@ -20,6 +20,7 @@ package sql
import (
"context"
"database/sql/driver"
"sync"

"github.com/seata/seata-go/pkg/datasource/sql/undo"

@@ -34,7 +35,34 @@ import (

const REPORT_RETRY_COUNT = 5

type txOption func(tx *Tx)
var (
hl sync.RWMutex
txHooks []txHook
)

func RegisterTxHook(h txHook) {
hl.Lock()
defer hl.Unlock()

txHooks = append(txHooks, h)
}

func CleanTxHooks() {
hl.Lock()
defer hl.Unlock()

txHooks = make([]txHook, 0, 4)
}

type (
txOption func(tx *Tx)

txHook interface {
BeforeCommit(tx *Tx)

BeforeRollback(tx *Tx)
}
)

func newTx(opts ...txOption) (driver.Tx, error) {
tx := new(Tx)
@@ -83,6 +111,15 @@ type Tx struct {
// case 2. not need flush undolog, is XA mode, do local transaction commit
// case 3. need run AT transaction
func (tx *Tx) Commit() error {
if len(txHooks) != 0 {
hl.RLock()
defer hl.RUnlock()

for i := range txHooks {
txHooks[i].BeforeCommit(tx)
}
}

if tx.ctx.TransType == types.Local {
return tx.commitOnLocal()
}
@@ -96,6 +133,15 @@ func (tx *Tx) Commit() error {
}

func (tx *Tx) Rollback() error {
if len(txHooks) != 0 {
hl.RLock()
defer hl.RUnlock()

for i := range txHooks {
txHooks[i].BeforeRollback(tx)
}
}

err := tx.target.Rollback()
if err != nil {
if tx.ctx.OpenGlobalTrsnaction() && tx.ctx.IsBranchRegistered() {
@@ -124,7 +170,7 @@ func (tx *Tx) commitOnXA() error {
// commitOnAT
func (tx *Tx) commitOnAT() error {
// if TX-Mode is AT, run regis this transaction branch
if err := tx.regis(tx.ctx); err != nil {
if err := tx.register(tx.ctx); err != nil {
return err
}

@@ -133,7 +179,7 @@ func (tx *Tx) commitOnAT() error {
return err
}

if err := undoLogMgr.FlushUndoLog(tx.ctx, nil); err != nil {
if err := undoLogMgr.FlushUndoLog(tx.ctx, tx.conn.targetConn); err != nil {
if rerr := tx.report(false); rerr != nil {
return errors.WithStack(rerr)
}
@@ -151,8 +197,8 @@ func (tx *Tx) commitOnAT() error {
return nil
}

// regis
func (tx *Tx) regis(ctx *types.TransactionContext) error {
// register
func (tx *Tx) register(ctx *types.TransactionContext) error {
if !ctx.HasUndoLog() || !ctx.HasLockKey() {
return nil
}


+ 1
- 1
pkg/datasource/sql/types/types.go View File

@@ -103,7 +103,7 @@ type TransactionContext struct {
func NewTxCtx() *TransactionContext {
return &TransactionContext{
LockKeys: make([]string, 0, 4),
TransType: ATMode,
TransType: Local,
LocalTransID: uuid.New().String(),
RoundImages: &RoundRecordImage{},
}


+ 2
- 2
pkg/datasource/sql/undo/base/undo.go View File

@@ -48,7 +48,7 @@ func (m *BaseUndoLogManager) Init() {
}

// InsertUndoLog
func (m *BaseUndoLogManager) InsertUndoLog(l []undo.BranchUndoLog, tx driver.Tx) error {
func (m *BaseUndoLogManager) InsertUndoLog(l []undo.BranchUndoLog, tx driver.Conn) error {
return nil
}

@@ -102,7 +102,7 @@ func (m *BaseUndoLogManager) BatchDeleteUndoLog(xid []string, branchID []int64,
}

// FlushUndoLog
func (m *BaseUndoLogManager) FlushUndoLog(txCtx *types.TransactionContext, tx driver.Tx) error {
func (m *BaseUndoLogManager) FlushUndoLog(txCtx *types.TransactionContext, tx driver.Conn) error {
return nil
}



+ 2
- 2
pkg/datasource/sql/undo/mysql/undo.go View File

@@ -39,7 +39,7 @@ func (m *undoLogManager) Init() {
}

// InsertUndoLog
func (m *undoLogManager) InsertUndoLog(l []undo.BranchUndoLog, tx driver.Tx) error {
func (m *undoLogManager) InsertUndoLog(l []undo.BranchUndoLog, tx driver.Conn) error {
return m.Base.InsertUndoLog(l, tx)
}

@@ -54,7 +54,7 @@ func (m *undoLogManager) BatchDeleteUndoLog(xid []string, branchID []int64, conn
}

// FlushUndoLog
func (m *undoLogManager) FlushUndoLog(txCtx *types.TransactionContext, tx driver.Tx) error {
func (m *undoLogManager) FlushUndoLog(txCtx *types.TransactionContext, tx driver.Conn) error {
return m.Base.FlushUndoLog(txCtx, tx)
}



+ 2
- 2
pkg/datasource/sql/undo/undo.go View File

@@ -50,13 +50,13 @@ func Regis(m UndoLogManager) error {
type UndoLogManager interface {
Init()
// InsertUndoLog
InsertUndoLog(l []BranchUndoLog, tx driver.Tx) error
InsertUndoLog(l []BranchUndoLog, tx driver.Conn) error
// DeleteUndoLog
DeleteUndoLog(ctx context.Context, xid string, branchID int64, conn *sql.Conn) error
// BatchDeleteUndoLog
BatchDeleteUndoLog(xid []string, branchID []int64, conn *sql.Conn) error
// FlushUndoLog
FlushUndoLog(txCtx *types.TransactionContext, tx driver.Tx) error
FlushUndoLog(txCtx *types.TransactionContext, tx driver.Conn) error
// RunUndo
RunUndo(xid string, branchID int64, conn *sql.Conn) error
// DBType


+ 2
- 2
pkg/datasource/sql/undo_test.go View File

@@ -32,7 +32,7 @@ func TestBatchDeleteUndoLogs(t *testing.T) {
t.SkipNow()

testBatchDeleteUndoLogs := func() {
db, err := sql.Open(SeataMySQLDriver, "root:12345678@tcp(127.0.0.1:3306)/seata_order?multiStatements=true")
db, err := sql.Open(SeataATMySQLDriver, "root:12345678@tcp(127.0.0.1:3306)/seata_order?multiStatements=true")
assert.Nil(t, err)

sqlConn, err := db.Conn(context.Background())
@@ -54,7 +54,7 @@ func TestDeleteUndoLogs(t *testing.T) {
t.SkipNow()

testDeleteUndoLogs := func() {
db, err := sql.Open(SeataMySQLDriver, "root:12345678@tcp(127.0.0.1:3306)/seata_order?multiStatements=true")
db, err := sql.Open(SeataATMySQLDriver, "root:12345678@tcp(127.0.0.1:3306)/seata_order?multiStatements=true")
assert.Nil(t, err)

ctx := context.Background()


Loading…
Cancel
Save