From b396c81a4e5d26c4b375695e6fcb75b5d9aa901c Mon Sep 17 00:00:00 2001 From: liaochuntao Date: Tue, 27 Sep 2022 13:00:46 +0800 Subject: [PATCH] (WIP)refactor:seata conn (#287) * refactor:seata conn * test: add unit test * test: add unit test --- .github/workflows/build.yml | 16 +- pkg/datasource/sql/conn.go | 168 +++---- pkg/datasource/sql/conn_at.go | 85 ++++ pkg/datasource/sql/conn_at_test.go | 127 ++++++ pkg/datasource/sql/conn_xa.go | 86 ++++ pkg/datasource/sql/conn_xa_test.go | 179 ++++++++ pkg/datasource/sql/connector.go | 79 +++- pkg/datasource/sql/connector_test.go | 129 ++++++ pkg/datasource/sql/context.go | 29 ++ pkg/datasource/sql/driver.go | 77 +++- pkg/datasource/sql/driver_test.go | 93 ++++ pkg/datasource/sql/exec/executor.go | 19 +- pkg/datasource/sql/exec/hook.go | 4 + .../sql/mock/mock_datasource_manager.go | 237 ++++++++++ pkg/datasource/sql/mock/mock_driver.go | 411 ++++++++++++++++++ pkg/datasource/sql/mock/test_driver.go | 46 ++ pkg/datasource/sql/sql_test.go | 2 +- pkg/datasource/sql/stmt.go | 8 +- pkg/datasource/sql/tx.go | 56 ++- pkg/datasource/sql/types/types.go | 2 +- pkg/datasource/sql/undo/base/undo.go | 4 +- pkg/datasource/sql/undo/mysql/undo.go | 4 +- pkg/datasource/sql/undo/undo.go | 4 +- pkg/datasource/sql/undo_test.go | 4 +- 24 files changed, 1749 insertions(+), 120 deletions(-) create mode 100644 pkg/datasource/sql/conn_at.go create mode 100644 pkg/datasource/sql/conn_at_test.go create mode 100644 pkg/datasource/sql/conn_xa.go create mode 100644 pkg/datasource/sql/conn_xa_test.go create mode 100644 pkg/datasource/sql/connector_test.go create mode 100644 pkg/datasource/sql/context.go create mode 100644 pkg/datasource/sql/driver_test.go create mode 100644 pkg/datasource/sql/mock/mock_datasource_manager.go create mode 100644 pkg/datasource/sql/mock/mock_driver.go create mode 100644 pkg/datasource/sql/mock/test_driver.go diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 352aaf55..4794be1d 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -23,7 +23,7 @@ on: push: branches: [ master ] pull_request: - branches: "*" + branches: ["*"] jobs: build: @@ -39,6 +39,20 @@ jobs: with: go-version: 1.18 + # close default MySQL-Server + - name: Shutdown default mysql + run: sudo service mysql stop + + # run mysql server + - name: Create mysql database auth + uses: icomponent/mysql-action@master + with: + VERSION: 5.7 + CONTAINER_NAME: mysql + PORT_MAPPING: 3306:3306 + ROOT_PASSWORD: seata_go + DATABASE: seata_go_test + # Checks-out your repository under $GITHUB_WORKSPACE, so your job can access it - name: "checkout ${{ github.ref }}" uses: actions/checkout@v3 diff --git a/pkg/datasource/sql/conn.go b/pkg/datasource/sql/conn.go index 9c0df58a..ceb1a8ed 100644 --- a/pkg/datasource/sql/conn.go +++ b/pkg/datasource/sql/conn.go @@ -19,29 +19,35 @@ package sql import ( "context" - gosql "database/sql" "database/sql/driver" - "errors" "github.com/seata/seata-go/pkg/datasource/sql/exec" "github.com/seata/seata-go/pkg/datasource/sql/types" ) +// Conn is a connection to a database. It is not used concurrently +// by multiple goroutines. +// +// Conn is assumed to be stateful. + type Conn struct { - res *DBResource - txCtx *types.TransactionContext - targetConn driver.Conn - isInTransaction bool - autoCommit bool - autoCommitChanged bool + txType types.TransactionType + res *DBResource + txCtx *types.TransactionContext + targetConn driver.Conn + autoCommit bool } +// ResetSession is called prior to executing a query on the connection +// if the connection has been used before. If the driver returns ErrBadConn +// the connection is discarded. func (c *Conn) ResetSession(ctx context.Context) error { conn, ok := c.targetConn.(driver.SessionResetter) if !ok { return driver.ErrSkip } + c.txType = types.Local c.txCtx = nil return conn.ResetSession(ctx) } @@ -95,9 +101,8 @@ func (c *Conn) Exec(query string, args []driver.Value) (driver.Result, error) { return nil, driver.ErrSkip } - if c.txCtx != nil { - // in transaction, need run Executor - executor, err := exec.BuildExecutor(c.res.dbType, query) + ret, err := c.createNewTxOnExecIfNeed(func() (types.ExecResult, error) { + executor, err := exec.BuildExecutor(c.res.dbType, c.txCtx.TransType, query) if err != nil { return nil, err } @@ -108,7 +113,7 @@ func (c *Conn) Exec(query string, args []driver.Value) (driver.Result, error) { Values: args, } - ret, err := executor.ExecWithValue(context.Background(), execCtx, + return executor.ExecWithValue(context.Background(), execCtx, func(ctx context.Context, query string, args []driver.Value) (types.ExecResult, error) { ret, err := conn.Exec(query, args) if err != nil { @@ -117,16 +122,12 @@ func (c *Conn) Exec(query string, args []driver.Value) (driver.Result, error) { return types.NewResult(types.WithResult(ret)), nil }) - if err != nil { - return nil, err - } + }) - // todo if user has not opened a transaction, it may call tx.Commit() method to flush undo log - - return ret.GetResult(), nil + if err != nil { + return nil, err } - - return conn.Exec(query, args) + return ret.GetResult(), nil } // ExecContext @@ -142,41 +143,45 @@ func (c *Conn) ExecContext(ctx context.Context, query string, args []driver.Name return c.Exec(query, values) } - executor, err := exec.BuildExecutor(c.res.dbType, query) - if err != nil { - return nil, err - } + ret, err := c.createNewTxOnExecIfNeed(func() (types.ExecResult, error) { + executor, err := exec.BuildExecutor(c.res.dbType, c.txCtx.TransType, query) + if err != nil { + return nil, err + } - execCtx := &exec.ExecContext{ - TxCtx: c.txCtx, - Query: query, - NamedValues: args, - } + execCtx := &exec.ExecContext{ + TxCtx: c.txCtx, + Query: query, + NamedValues: args, + } - ret, err := executor.ExecWithNamedValue(ctx, execCtx, - func(ctx context.Context, query string, args []driver.NamedValue) (types.ExecResult, error) { - ret, err := targetConn.ExecContext(ctx, query, args) - if err != nil { - return nil, err - } + ret, err := executor.ExecWithNamedValue(ctx, execCtx, + func(ctx context.Context, query string, args []driver.NamedValue) (types.ExecResult, error) { + ret, err := targetConn.ExecContext(ctx, query, args) + if err != nil { + return nil, err + } + + return types.NewResult(types.WithResult(ret)), nil + }) + + return ret, err + }) - return types.NewResult(types.WithResult(ret)), nil - }) if err != nil { return nil, err } - return ret.GetResult(), nil } -// QueryContext +// Query func (c *Conn) Query(query string, args []driver.Value) (driver.Rows, error) { conn, ok := c.targetConn.(driver.Queryer) if !ok { return nil, driver.ErrSkip } - executor, err := exec.BuildExecutor(c.res.dbType, query) + executor, err := exec.BuildExecutor(c.res.dbType, c.txCtx.TransType, query) if err != nil { return nil, err } @@ -216,7 +221,7 @@ func (c *Conn) QueryContext(ctx context.Context, query string, args []driver.Nam return c.Query(query, values) } - executor, err := exec.BuildExecutor(c.res.dbType, query) + executor, err := exec.BuildExecutor(c.res.dbType, c.txCtx.TransType, query) if err != nil { return nil, err } @@ -252,10 +257,11 @@ func (c *Conn) Begin() (driver.Tx, error) { return nil, err } - c.txCtx = types.NewTxCtx() - c.txCtx.DBType = c.res.dbType - c.txCtx.TxOpt = driver.TxOptions{} - c.autoCommit = true + if c.txCtx == nil { + c.txCtx = types.NewTxCtx() + c.txCtx.DBType = c.res.dbType + c.txCtx.TxOpt = driver.TxOptions{} + } return newTx( withDriverConn(c), @@ -264,6 +270,8 @@ func (c *Conn) Begin() (driver.Tx, error) { ) } +// BeginTx Open a transaction and judge whether the current transaction needs to open a +// global transaction according to ctx. If so, it needs to be included in the transaction management of seata func (c *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { if conn, ok := c.targetConn.(driver.ConnBeginTx); ok { tx, err := conn.BeginTx(ctx, opts) @@ -271,11 +279,6 @@ func (c *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, e return nil, err } - c.txCtx = types.NewTxCtx() - c.txCtx.DBType = c.res.dbType - c.txCtx.TxOpt = opts - c.autoCommit = true - return newTx( withDriverConn(c), withTxCtx(c.txCtx), @@ -283,32 +286,15 @@ func (c *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, e ) } - // Check the transaction level. If the transaction level is non-default - // then return an error here as the BeginTx driver value is not supported. - if opts.Isolation != driver.IsolationLevel(gosql.LevelDefault) { - return nil, errors.New("sql: driver does not support non-default isolation level") - } - - // If a read-only transaction is requested return an error as the - // BeginTx driver value is not supported. - if opts.ReadOnly { - return nil, errors.New("sql: driver does not support read-only transactions") - } - - if ctx.Done() == nil { - return c.Begin() - } - txi, err := c.Begin() - if err == nil { - select { - case <-ctx.Done(): - txi.Rollback() - return nil, ctx.Err() - default: - } + if err != nil { + return nil, err } - return txi, err + return newTx( + withDriverConn(c), + withTxCtx(c.txCtx), + withOriginTx(txi), + ) } // Close invalidates and potentially stops any current @@ -326,3 +312,37 @@ func (c *Conn) Close() error { c.txCtx = nil return c.targetConn.Close() } + +func (c *Conn) createNewTxOnExecIfNeed(f func() (types.ExecResult, error)) (types.ExecResult, error) { + var ( + tx driver.Tx + err error + ) + + if c.txCtx.TransType != types.Local && c.autoCommit { + tx, err = c.Begin() + if err != nil { + return nil, err + } + } + + defer func() { + if tx != nil { + tx.Rollback() + } + }() + + ret, err := f() + + if err != nil { + return nil, err + } + + if tx != nil { + if err := tx.Commit(); err != nil { + return nil, err + } + } + + return ret, nil +} diff --git a/pkg/datasource/sql/conn_at.go b/pkg/datasource/sql/conn_at.go new file mode 100644 index 00000000..e511c2ac --- /dev/null +++ b/pkg/datasource/sql/conn_at.go @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package sql + +import ( + "context" + "database/sql/driver" + + "github.com/seata/seata-go/pkg/datasource/sql/types" + "github.com/seata/seata-go/pkg/tm" +) + +type ATConn struct { + *Conn +} + +func (c *ATConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { + if c.createTxCtxIfAbsent(ctx) { + defer func() { + c.txCtx = nil + }() + } + + return c.Conn.PrepareContext(ctx, query) +} + +// ExecContext +func (c *ATConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { + if c.createTxCtxIfAbsent(ctx) { + defer func() { + c.txCtx = nil + }() + } + + return c.Conn.ExecContext(ctx, query, args) +} + +// BeginTx +func (c *ATConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { + c.txCtx = types.NewTxCtx() + c.txCtx.DBType = c.res.dbType + c.txCtx.TxOpt = opts + + if IsGlobalTx(ctx) { + c.txCtx.XaID = tm.GetXID(ctx) + c.txCtx.TransType = c.txType + } + + return c.Conn.BeginTx(ctx, opts) +} + +func (c *ATConn) createTxCtxIfAbsent(ctx context.Context) bool { + var onceTx bool + + if IsGlobalTx(ctx) && c.txCtx == nil { + c.txCtx = types.NewTxCtx() + c.txCtx.DBType = c.res.dbType + c.txCtx.XaID = tm.GetXID(ctx) + c.txCtx.TransType = types.ATMode + c.autoCommit = true + onceTx = true + } + + if c.txCtx == nil { + c.txCtx = types.NewTxCtx() + onceTx = true + } + + return onceTx +} diff --git a/pkg/datasource/sql/conn_at_test.go b/pkg/datasource/sql/conn_at_test.go new file mode 100644 index 00000000..badc696a --- /dev/null +++ b/pkg/datasource/sql/conn_at_test.go @@ -0,0 +1,127 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package sql + +import ( + "context" + "database/sql" + "sync/atomic" + "testing" + + "github.com/golang/mock/gomock" + "github.com/google/uuid" + "github.com/seata/seata-go/pkg/datasource/sql/exec" + "github.com/seata/seata-go/pkg/datasource/sql/mock" + "github.com/seata/seata-go/pkg/datasource/sql/types" + "github.com/seata/seata-go/pkg/tm" + "github.com/stretchr/testify/assert" +) + +func TestATConn_ExecContext(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockMgr := initMockResourceManager(t, ctrl) + _ = mockMgr + + db, err := sql.Open("seata-at-mysql", "root:seata_go@tcp(127.0.0.1:3306)/seata_go_test?multiStatements=true") + if err != nil { + t.Fatal(err) + } + + defer db.Close() + + _ = initMockAtConnector(t, ctrl, db, func(t *testing.T, ctrl *gomock.Controller) *mock.MockTestDriverConnector { + mockTx := mock.NewMockTestDriverTx(ctrl) + mockTx.EXPECT().Commit().AnyTimes().Return(nil) + mockTx.EXPECT().Rollback().AnyTimes().Return(nil) + + mockConn := mock.NewMockTestDriverConn(ctrl) + mockConn.EXPECT().Begin().AnyTimes().Return(mockTx, nil) + mockConn.EXPECT().BeginTx(gomock.Any(), gomock.Any()).AnyTimes().Return(mockTx, nil) + baseMoclConn(mockConn) + + connector := mock.NewMockTestDriverConnector(ctrl) + connector.EXPECT().Connect(gomock.Any()).AnyTimes().Return(mockConn, nil) + return connector + }) + + mi := &mockSQLInterceptor{} + + ti := &mockTxHook{} + + exec.CleanCommonHook() + CleanTxHooks() + exec.RegisCommonHook(mi) + RegisterTxHook(ti) + + t.Run("have xid", func(t *testing.T) { + ctx := tm.InitSeataContext(context.Background()) + tm.SetXID(ctx, uuid.New().String()) + t.Logf("set xid=%s", tm.GetXID(ctx)) + + before := func(_ context.Context, execCtx *exec.ExecContext) { + t.Logf("on exec xid=%s", execCtx.TxCtx.XaID) + assert.Equal(t, tm.GetXID(ctx), execCtx.TxCtx.XaID) + assert.Equal(t, types.ATMode, execCtx.TxCtx.TransType) + } + mi.before = before + + var comitCnt int32 + beforeCommit := func(tx *Tx) { + atomic.AddInt32(&comitCnt, 1) + assert.Equal(t, types.ATMode, tx.ctx.TransType) + } + ti.beforeCommit = beforeCommit + + conn, err := db.Conn(context.Background()) + assert.NoError(t, err) + + _, err = conn.ExecContext(ctx, "SELECT 1") + assert.NoError(t, err) + _, err = db.ExecContext(ctx, "SELECT 1") + assert.NoError(t, err) + + assert.Equal(t, int32(2), atomic.LoadInt32(&comitCnt)) + }) + + t.Run("not xid", func(t *testing.T) { + mi.before = func(_ context.Context, execCtx *exec.ExecContext) { + assert.Equal(t, "", execCtx.TxCtx.XaID) + assert.Equal(t, types.Local, execCtx.TxCtx.TransType) + } + + var comitCnt int32 + ti.beforeCommit = func(tx *Tx) { + atomic.AddInt32(&comitCnt, 1) + } + + conn, err := db.Conn(context.Background()) + assert.NoError(t, err) + + _, err = conn.ExecContext(context.Background(), "SELECT 1") + assert.NoError(t, err) + _, err = db.ExecContext(context.Background(), "SELECT 1") + assert.NoError(t, err) + + _, err = db.Exec("SELECT 1") + assert.NoError(t, err) + + assert.Equal(t, int32(0), atomic.LoadInt32(&comitCnt)) + }) +} diff --git a/pkg/datasource/sql/conn_xa.go b/pkg/datasource/sql/conn_xa.go new file mode 100644 index 00000000..7078e4c8 --- /dev/null +++ b/pkg/datasource/sql/conn_xa.go @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package sql + +import ( + "context" + "database/sql/driver" + + "github.com/seata/seata-go/pkg/datasource/sql/types" + "github.com/seata/seata-go/pkg/tm" +) + +type XAConn struct { + *Conn +} + +// PrepareContext +func (c *XAConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { + if c.createTxCtxIfAbsent(ctx) { + defer func() { + c.txCtx = nil + }() + } + + return c.Conn.PrepareContext(ctx, query) +} + +// ExecContext +func (c *XAConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { + if c.createTxCtxIfAbsent(ctx) { + defer func() { + c.txCtx = nil + }() + } + + return c.Conn.ExecContext(ctx, query, args) +} + +// BeginTx +func (c *XAConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { + c.txCtx = types.NewTxCtx() + c.txCtx.DBType = c.res.dbType + c.txCtx.TxOpt = opts + + if IsGlobalTx(ctx) { + c.txCtx.TransType = types.XAMode + c.txCtx.XaID = tm.GetXID(ctx) + } + + return c.Conn.BeginTx(ctx, opts) +} + +func (c *XAConn) createTxCtxIfAbsent(ctx context.Context) bool { + var onceTx bool + + if IsGlobalTx(ctx) && c.txCtx == nil { + c.txCtx = types.NewTxCtx() + c.txCtx.DBType = c.res.dbType + c.txCtx.XaID = tm.GetXID(ctx) + c.txCtx.TransType = types.XAMode + c.autoCommit = true + onceTx = true + } + + if c.txCtx == nil { + c.txCtx = types.NewTxCtx() + onceTx = true + } + + return onceTx +} diff --git a/pkg/datasource/sql/conn_xa_test.go b/pkg/datasource/sql/conn_xa_test.go new file mode 100644 index 00000000..198a7dd6 --- /dev/null +++ b/pkg/datasource/sql/conn_xa_test.go @@ -0,0 +1,179 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package sql + +import ( + "context" + "database/sql" + "database/sql/driver" + "sync/atomic" + "testing" + + "github.com/golang/mock/gomock" + "github.com/google/uuid" + "github.com/seata/seata-go/pkg/datasource/sql/exec" + "github.com/seata/seata-go/pkg/datasource/sql/mock" + "github.com/seata/seata-go/pkg/datasource/sql/types" + "github.com/seata/seata-go/pkg/tm" + "github.com/stretchr/testify/assert" +) + +type mockSQLInterceptor struct { + before func(ctx context.Context, execCtx *exec.ExecContext) + after func(ctx context.Context, execCtx *exec.ExecContext) +} + +func (mi *mockSQLInterceptor) Type() types.SQLType { + return types.SQLTypeUnknown +} + +// Before +func (mi *mockSQLInterceptor) Before(ctx context.Context, execCtx *exec.ExecContext) { + if mi.before != nil { + mi.before(ctx, execCtx) + } +} + +// After +func (mi *mockSQLInterceptor) After(ctx context.Context, execCtx *exec.ExecContext) { + if mi.after != nil { + mi.after(ctx, execCtx) + } +} + +type mockTxHook struct { + beforeCommit func(tx *Tx) + beforeRollback func(tx *Tx) +} + +// BeforeCommit +func (mi *mockTxHook) BeforeCommit(tx *Tx) { + if mi.beforeCommit != nil { + mi.beforeCommit(tx) + } +} + +// BeforeRollback +func (mi *mockTxHook) BeforeRollback(tx *Tx) { + if mi.beforeRollback != nil { + mi.beforeRollback(tx) + } +} + +func baseMoclConn(mockConn *mock.MockTestDriverConn) { + mockConn.EXPECT().ExecContext(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes().Return(&driver.ResultNoRows, nil) + mockConn.EXPECT().Exec(gomock.Any(), gomock.Any()).AnyTimes().Return(&driver.ResultNoRows, nil) + mockConn.EXPECT().ResetSession(gomock.Any()).AnyTimes().Return(nil) + mockConn.EXPECT().Close().AnyTimes().Return(nil) +} + +func TestXAConn_ExecContext(t *testing.T) { + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockMgr := initMockResourceManager(t, ctrl) + _ = mockMgr + + db, err := sql.Open("seata-xa-mysql", "root:seata_go@tcp(127.0.0.1:3306)/seata_go_test?multiStatements=true") + if err != nil { + t.Fatal(err) + } + + defer db.Close() + + _ = initMockXaConnector(t, ctrl, db, func(t *testing.T, ctrl *gomock.Controller) *mock.MockTestDriverConnector { + mockTx := mock.NewMockTestDriverTx(ctrl) + mockTx.EXPECT().Commit().AnyTimes().Return(nil) + mockTx.EXPECT().Rollback().AnyTimes().Return(nil) + + mockConn := mock.NewMockTestDriverConn(ctrl) + mockConn.EXPECT().Begin().AnyTimes().Return(mockTx, nil) + mockConn.EXPECT().BeginTx(gomock.Any(), gomock.Any()).AnyTimes().Return(mockTx, nil) + baseMoclConn(mockConn) + + connector := mock.NewMockTestDriverConnector(ctrl) + connector.EXPECT().Connect(gomock.Any()).AnyTimes().Return(mockConn, nil) + return connector + }) + + mi := &mockSQLInterceptor{} + ti := &mockTxHook{} + + exec.CleanCommonHook() + CleanTxHooks() + exec.RegisCommonHook(mi) + RegisterTxHook(ti) + + t.Run("have xid", func(t *testing.T) { + ctx := tm.InitSeataContext(context.Background()) + tm.SetXID(ctx, uuid.New().String()) + t.Logf("set xid=%s", tm.GetXID(ctx)) + + before := func(_ context.Context, execCtx *exec.ExecContext) { + t.Logf("on exec xid=%s", execCtx.TxCtx.XaID) + assert.Equal(t, tm.GetXID(ctx), execCtx.TxCtx.XaID) + assert.Equal(t, types.XAMode, execCtx.TxCtx.TransType) + } + mi.before = before + + var comitCnt int32 + beforeCommit := func(tx *Tx) { + atomic.AddInt32(&comitCnt, 1) + assert.Equal(t, tx.ctx.TransType, types.XAMode) + } + ti.beforeCommit = beforeCommit + + conn, err := db.Conn(context.Background()) + assert.NoError(t, err) + + _, err = conn.ExecContext(ctx, "SELECT 1") + assert.NoError(t, err) + _, err = db.ExecContext(ctx, "SELECT 1") + assert.NoError(t, err) + + assert.Equal(t, int32(2), atomic.LoadInt32(&comitCnt)) + }) + + t.Run("not xid", func(t *testing.T) { + before := func(_ context.Context, execCtx *exec.ExecContext) { + assert.Equal(t, "", execCtx.TxCtx.XaID) + assert.Equal(t, types.Local, execCtx.TxCtx.TransType) + } + mi.before = before + + var comitCnt int32 + beforeCommit := func(tx *Tx) { + atomic.AddInt32(&comitCnt, 1) + } + ti.beforeCommit = beforeCommit + + conn, err := db.Conn(context.Background()) + assert.NoError(t, err) + + _, err = conn.ExecContext(context.Background(), "SELECT 1") + assert.NoError(t, err) + _, err = db.ExecContext(context.Background(), "SELECT 1") + assert.NoError(t, err) + + _, err = db.Exec("SELECT 1") + assert.NoError(t, err) + + assert.Equal(t, int32(0), atomic.LoadInt32(&comitCnt)) + }) +} diff --git a/pkg/datasource/sql/connector.go b/pkg/datasource/sql/connector.go index 58a24649..146f50eb 100644 --- a/pkg/datasource/sql/connector.go +++ b/pkg/datasource/sql/connector.go @@ -21,14 +21,77 @@ import ( "context" "database/sql/driver" "sync" + + "github.com/seata/seata-go/pkg/datasource/sql/types" ) +type seataATConnector struct { + *seataConnector + transType types.TransactionType +} + +func (c *seataATConnector) Connect(ctx context.Context) (driver.Conn, error) { + conn, err := c.seataConnector.Connect(ctx) + if err != nil { + return nil, err + } + + _conn, _ := conn.(*Conn) + + return &ATConn{ + Conn: _conn, + }, nil +} + +func (c *seataATConnector) Driver() driver.Driver { + return &seataATDriver{ + seataDriver: c.seataConnector.Driver().(*seataDriver), + } +} + +type seataXAConnector struct { + *seataConnector + transType types.TransactionType +} + +func (c *seataXAConnector) Connect(ctx context.Context) (driver.Conn, error) { + conn, err := c.seataConnector.Connect(ctx) + if err != nil { + return nil, err + } + + _conn, _ := conn.(*Conn) + + return &XAConn{ + Conn: _conn, + }, nil +} + +func (c *seataXAConnector) Driver() driver.Driver { + return &seataXADriver{ + seataDriver: c.seataConnector.Driver().(*seataDriver), + } +} + +// A Connector represents a driver in a fixed configuration +// and can create any number of equivalent Conns for use +// by multiple goroutines. +// +// A Connector can be passed to sql.OpenDB, to allow drivers +// to implement their own sql.DB constructors, or returned by +// DriverContext's OpenConnector method, to allow drivers +// access to context and to avoid repeated parsing of driver +// configuration. +// +// If a Connector implements io.Closer, the sql package's DB.Close +// method will call Close and return error (if any). type seataConnector struct { - conf *seataServerConfig - res *DBResource - once sync.Once - driver driver.Driver - target driver.Connector + transType types.TransactionType + conf *seataServerConfig + res *DBResource + once sync.Once + driver driver.Driver + target driver.Connector } // Connect returns a connection to the database. @@ -50,7 +113,7 @@ func (c *seataConnector) Connect(ctx context.Context) (driver.Conn, error) { return nil, err } - return &Conn{targetConn: conn, res: c.res}, nil + return &Conn{txType: types.Local, targetConn: conn, res: c.res}, nil } // Driver returns the underlying Driver of the Connector, @@ -62,5 +125,7 @@ func (c *seataConnector) Driver() driver.Driver { c.driver = d }) - return &SeataDriver{target: c.driver} + return &seataDriver{ + target: c.driver, + } } diff --git a/pkg/datasource/sql/connector_test.go b/pkg/datasource/sql/connector_test.go new file mode 100644 index 00000000..bcdf5f3b --- /dev/null +++ b/pkg/datasource/sql/connector_test.go @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package sql + +import ( + "context" + "database/sql" + "database/sql/driver" + "reflect" + "testing" + + "github.com/golang/mock/gomock" + "github.com/seata/seata-go/pkg/datasource/sql/mock" + "github.com/seata/seata-go/pkg/datasource/sql/types" + "github.com/stretchr/testify/assert" +) + +type initConnectorFunc func(t *testing.T, ctrl *gomock.Controller) *mock.MockTestDriverConnector + +func initMockConnector(t *testing.T, ctrl *gomock.Controller) *mock.MockTestDriverConnector { + mockConn := mock.NewMockTestDriverConn(ctrl) + + connector := mock.NewMockTestDriverConnector(ctrl) + connector.EXPECT().Connect(gomock.Any()).AnyTimes().Return(mockConn, nil) + return connector +} + +func initMockAtConnector(t *testing.T, ctrl *gomock.Controller, db *sql.DB, f initConnectorFunc) driver.Connector { + v := reflect.ValueOf(db) + if v.Kind() == reflect.Ptr { + v = v.Elem() + } + + field := v.FieldByName("connector") + fieldVal := GetUnexportedField(field) + + atConnector, ok := fieldVal.(*seataATConnector) + assert.True(t, ok, "need return seata at connector") + + v = reflect.ValueOf(atConnector) + if v.Kind() == reflect.Ptr { + v = v.Elem() + } + SetUnexportedField(v.FieldByName("target"), f(t, ctrl)) + + return fieldVal.(driver.Connector) +} + +func Test_seataATConnector_Connect(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockMgr := initMockResourceManager(t, ctrl) + _ = mockMgr + + db, err := sql.Open("seata-at-mysql", "root:seata_go@tcp(127.0.0.1:3306)/seata_go_test?multiStatements=true") + if err != nil { + t.Fatal(err) + } + + defer db.Close() + + proxyConnector := initMockAtConnector(t, ctrl, db, initMockConnector) + conn, err := proxyConnector.Connect(context.Background()) + assert.NoError(t, err) + + atConn, ok := conn.(*ATConn) + assert.True(t, ok, "need return seata at connection") + assert.True(t, atConn.txType == types.Local, "init need local tx") +} + +func initMockXaConnector(t *testing.T, ctrl *gomock.Controller, db *sql.DB, f initConnectorFunc) driver.Connector { + v := reflect.ValueOf(db) + if v.Kind() == reflect.Ptr { + v = v.Elem() + } + + field := v.FieldByName("connector") + fieldVal := GetUnexportedField(field) + + atConnector, ok := fieldVal.(*seataXAConnector) + assert.True(t, ok, "need return seata xa connector") + + v = reflect.ValueOf(atConnector) + if v.Kind() == reflect.Ptr { + v = v.Elem() + } + SetUnexportedField(v.FieldByName("target"), f(t, ctrl)) + + return fieldVal.(driver.Connector) +} + +func Test_seataXAConnector_Connect(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockMgr := initMockResourceManager(t, ctrl) + _ = mockMgr + + db, err := sql.Open("seata-xa-mysql", "root:seata_go@tcp(127.0.0.1:3306)/seata_go_test?multiStatements=true") + if err != nil { + t.Fatal(err) + } + + defer db.Close() + + proxyConnector := initMockXaConnector(t, ctrl, db, initMockConnector) + conn, err := proxyConnector.Connect(context.Background()) + assert.NoError(t, err) + + xaConn, ok := conn.(*XAConn) + assert.True(t, ok, "need return seata xa connection") + assert.True(t, xaConn.txType == types.Local, "init need local tx") +} diff --git a/pkg/datasource/sql/context.go b/pkg/datasource/sql/context.go new file mode 100644 index 00000000..a7a5c57f --- /dev/null +++ b/pkg/datasource/sql/context.go @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package sql + +import ( + "context" + + "github.com/seata/seata-go/pkg/tm" +) + +// IsGlobalTx check is open global transactions +func IsGlobalTx(ctx context.Context) bool { + return tm.IsTransactionOpened(ctx) +} diff --git a/pkg/datasource/sql/driver.go b/pkg/datasource/sql/driver.go index dbf455e3..87f32515 100644 --- a/pkg/datasource/sql/driver.go +++ b/pkg/datasource/sql/driver.go @@ -35,20 +35,69 @@ import ( ) const ( - SeataMySQLDriver = "seata-mysql" + // SeataATMySQLDriver MySQL driver for AT mode + SeataATMySQLDriver = "seata-at-mysql" + // SeataXAMySQLDriver MySQL driver for XA mode + SeataXAMySQLDriver = "seata-xa-mysql" ) func init() { - sql.Register(SeataMySQLDriver, &SeataDriver{ - target: mysql.MySQLDriver{}, + sql.Register(SeataATMySQLDriver, &seataATDriver{ + seataDriver: &seataDriver{ + transType: types.ATMode, + target: mysql.MySQLDriver{}, + }, }) + sql.Register(SeataXAMySQLDriver, &seataXADriver{ + seataDriver: &seataDriver{ + transType: types.XAMode, + target: mysql.MySQLDriver{}, + }, + }) +} + +type seataATDriver struct { + *seataDriver +} + +func (d *seataATDriver) OpenConnector(name string) (c driver.Connector, err error) { + connector, err := d.seataDriver.OpenConnector(name) + if err != nil { + return nil, err + } + + _connector, _ := connector.(*seataConnector) + _connector.transType = types.ATMode + + return &seataATConnector{ + seataConnector: _connector, + }, nil +} + +type seataXADriver struct { + *seataDriver +} + +func (d *seataXADriver) OpenConnector(name string) (c driver.Connector, err error) { + connector, err := d.seataDriver.OpenConnector(name) + if err != nil { + return nil, err + } + + _connector, _ := connector.(*seataConnector) + _connector.transType = types.XAMode + + return &seataXAConnector{ + seataConnector: _connector, + }, nil } -type SeataDriver struct { - target driver.Driver +type seataDriver struct { + transType types.TransactionType + target driver.Driver } -func (d *SeataDriver) Open(name string) (driver.Conn, error) { +func (d *seataDriver) Open(name string) (driver.Conn, error) { conn, err := d.target.Open(name) if err != nil { log.Errorf("open target connection: %w", err) @@ -71,7 +120,7 @@ func (d *SeataDriver) Open(name string) (driver.Conn, error) { return conn, nil } -func (d *SeataDriver) OpenConnector(name string) (c driver.Connector, err error) { +func (d *seataDriver) OpenConnector(name string) (c driver.Connector, err error) { c = &dsnConnector{dsn: name, driver: d.target} if driverCtx, ok := d.target.(driver.DriverContext); ok { c, err = driverCtx.OpenConnector(name) @@ -86,7 +135,7 @@ func (d *SeataDriver) OpenConnector(name string) (c driver.Connector, err error) return nil, fmt.Errorf("unsupport conn type %s", d.getTargetDriverName()) } - proxy, err := registerResource(c, dbType, sql.OpenDB(c), name) + proxy, err := registerResource(c, d.transType, dbType, sql.OpenDB(c), name) if err != nil { log.Errorf("register resource: %w", err) return nil, err @@ -95,8 +144,8 @@ func (d *SeataDriver) OpenConnector(name string) (c driver.Connector, err error) return proxy, nil } -func (d *SeataDriver) getTargetDriverName() string { - return strings.ReplaceAll(SeataMySQLDriver, "seata-", "") +func (d *seataDriver) getTargetDriverName() string { + return "mysql" } type dsnConnector struct { @@ -112,7 +161,7 @@ func (t *dsnConnector) Driver() driver.Driver { return t.driver } -func registerResource(connector driver.Connector, dbType types.DBType, db *sql.DB, +func registerResource(connector driver.Connector, txType types.TransactionType, dbType types.DBType, db *sql.DB, dataSourceName string, opts ...seataOption) (driver.Connector, error) { conf := loadConfig() for i := range opts { @@ -144,9 +193,9 @@ func registerResource(connector driver.Connector, dbType types.DBType, db *sql.D } return &seataConnector{ - res: res, - target: connector, - conf: conf, + res: res, + target: connector, + conf: conf, }, nil } diff --git a/pkg/datasource/sql/driver_test.go b/pkg/datasource/sql/driver_test.go new file mode 100644 index 00000000..13c5d512 --- /dev/null +++ b/pkg/datasource/sql/driver_test.go @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package sql + +import ( + "database/sql" + "reflect" + "testing" + + "github.com/golang/mock/gomock" + + "github.com/seata/seata-go/pkg/datasource/sql/datasource" + "github.com/seata/seata-go/pkg/datasource/sql/mock" + "github.com/seata/seata-go/pkg/protocol/branch" + "github.com/stretchr/testify/assert" +) + +func initMockResourceManager(t *testing.T, ctrl *gomock.Controller) *mock.MockDataSourceManager { + mockResourceMgr := mock.NewMockDataSourceManager(ctrl) + datasource.RegisterResourceManager(branch.BranchTypeAT, mockResourceMgr) + + mockResourceMgr.EXPECT().RegisterResource(gomock.Any()).AnyTimes().Return(nil) + mockResourceMgr.EXPECT().CreateTableMetaCache(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes().Return(nil, nil) + + return mockResourceMgr +} + +func Test_seataATDriver_OpenConnector(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockMgr := initMockResourceManager(t, ctrl) + _ = mockMgr + + db, err := sql.Open("seata-at-mysql", "root:seata_go@tcp(127.0.0.1:3306)/seata_go_test?multiStatements=true") + if err != nil { + t.Fatal(err) + } + + defer db.Close() + + v := reflect.ValueOf(db) + if v.Kind() == reflect.Ptr { + v = v.Elem() + } + + field := v.FieldByName("connector") + fieldVal := GetUnexportedField(field) + + _, ok := fieldVal.(*seataATConnector) + assert.True(t, ok, "need return seata at connector") +} + +func Test_seataXADriver_OpenConnector(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockMgr := initMockResourceManager(t, ctrl) + _ = mockMgr + + db, err := sql.Open("seata-xa-mysql", "root:seata_go@tcp(127.0.0.1:3306)/seata_go_test?multiStatements=true") + if err != nil { + t.Fatal(err) + } + + defer db.Close() + + v := reflect.ValueOf(db) + if v.Kind() == reflect.Ptr { + v = v.Elem() + } + + field := v.FieldByName("connector") + fieldVal := GetUnexportedField(field) + + _, ok := fieldVal.(*seataXAConnector) + assert.True(t, ok, "need return seata xa connector") +} diff --git a/pkg/datasource/sql/exec/executor.go b/pkg/datasource/sql/exec/executor.go index c67253f6..819724ed 100644 --- a/pkg/datasource/sql/exec/executor.go +++ b/pkg/datasource/sql/exec/executor.go @@ -56,8 +56,17 @@ type ( } ) -// buildExecutor -func BuildExecutor(dbType types.DBType, query string) (SQLExecutor, error) { +// BuildExecutor +func BuildExecutor(dbType types.DBType, txType types.TransactionType, query string) (SQLExecutor, error) { + if txType == types.XAMode { + hooks := make([]SQLInterceptor, 0, 4) + hooks = append(hooks, commonHook...) + + e := &BaseExecutor{} + e.interceptors(hooks) + return e, nil + } + parseCtx, err := parser.DoParser(query) if err != nil { return nil, err @@ -94,12 +103,12 @@ type BaseExecutor struct { ex SQLExecutor } -// Interceptors +// interceptors func (e *BaseExecutor) interceptors(interceptors []SQLInterceptor) { e.is = interceptors } -// Exec +// ExecWithNamedValue func (e *BaseExecutor) ExecWithNamedValue(ctx context.Context, execCtx *ExecContext, f CallbackWithNamedValue) (types.ExecResult, error) { for i := range e.is { e.is[i].Before(ctx, execCtx) @@ -118,7 +127,7 @@ func (e *BaseExecutor) ExecWithNamedValue(ctx context.Context, execCtx *ExecCont return f(ctx, execCtx.Query, execCtx.NamedValues) } -// Exec +// ExecWithValue func (e *BaseExecutor) ExecWithValue(ctx context.Context, execCtx *ExecContext, f CallbackWithValue) (types.ExecResult, error) { for i := range e.is { e.is[i].Before(ctx, execCtx) diff --git a/pkg/datasource/sql/exec/hook.go b/pkg/datasource/sql/exec/hook.go index c4342fc5..fc550772 100644 --- a/pkg/datasource/sql/exec/hook.go +++ b/pkg/datasource/sql/exec/hook.go @@ -35,6 +35,10 @@ func RegisCommonHook(hook SQLInterceptor) { commonHook = append(commonHook, hook) } +func CleanCommonHook() { + commonHook = make([]SQLInterceptor, 0, 4) +} + // RegisHook not goroutine safe func RegisHook(hook SQLInterceptor) { _, ok := hookSolts[hook.Type()] diff --git a/pkg/datasource/sql/mock/mock_datasource_manager.go b/pkg/datasource/sql/mock/mock_datasource_manager.go new file mode 100644 index 00000000..5978bd5b --- /dev/null +++ b/pkg/datasource/sql/mock/mock_datasource_manager.go @@ -0,0 +1,237 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: datasource_manager.go + +// Package mock is a generated GoMock package. +package mock + +import ( + context "context" + sql "database/sql" + gomock "github.com/golang/mock/gomock" + datasource "github.com/seata/seata-go/pkg/datasource/sql/datasource" + types "github.com/seata/seata-go/pkg/datasource/sql/types" + branch "github.com/seata/seata-go/pkg/protocol/branch" + message "github.com/seata/seata-go/pkg/protocol/message" + rm "github.com/seata/seata-go/pkg/rm" + reflect "reflect" +) + +// MockDataSourceManager is a mock of DataSourceManager interface +type MockDataSourceManager struct { + ctrl *gomock.Controller + recorder *MockDataSourceManagerMockRecorder +} + +// MockDataSourceManagerMockRecorder is the mock recorder for MockDataSourceManager +type MockDataSourceManagerMockRecorder struct { + mock *MockDataSourceManager +} + +// NewMockDataSourceManager creates a new mock instance +func NewMockDataSourceManager(ctrl *gomock.Controller) *MockDataSourceManager { + mock := &MockDataSourceManager{ctrl: ctrl} + mock.recorder = &MockDataSourceManagerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockDataSourceManager) EXPECT() *MockDataSourceManagerMockRecorder { + return m.recorder +} + +// RegisterResource mocks base method +func (m *MockDataSourceManager) RegisterResource(resource rm.Resource) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RegisterResource", resource) + ret0, _ := ret[0].(error) + return ret0 +} + +// RegisterResource indicates an expected call of RegisterResource +func (mr *MockDataSourceManagerMockRecorder) RegisterResource(resource interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegisterResource", reflect.TypeOf((*MockDataSourceManager)(nil).RegisterResource), resource) +} + +// UnregisterResource mocks base method +func (m *MockDataSourceManager) UnregisterResource(resource rm.Resource) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UnregisterResource", resource) + ret0, _ := ret[0].(error) + return ret0 +} + +// UnregisterResource indicates an expected call of UnregisterResource +func (mr *MockDataSourceManagerMockRecorder) UnregisterResource(resource interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UnregisterResource", reflect.TypeOf((*MockDataSourceManager)(nil).UnregisterResource), resource) +} + +// GetManagedResources mocks base method +func (m *MockDataSourceManager) GetManagedResources() map[string]rm.Resource { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetManagedResources") + ret0, _ := ret[0].(map[string]rm.Resource) + return ret0 +} + +// GetManagedResources indicates an expected call of GetManagedResources +func (mr *MockDataSourceManagerMockRecorder) GetManagedResources() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetManagedResources", reflect.TypeOf((*MockDataSourceManager)(nil).GetManagedResources)) +} + +// BranchRollback mocks base method +func (m *MockDataSourceManager) BranchRollback(ctx context.Context, req message.BranchRollbackRequest) (branch.BranchStatus, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "BranchRollback", ctx, req) + ret0, _ := ret[0].(branch.BranchStatus) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// BranchRollback indicates an expected call of BranchRollback +func (mr *MockDataSourceManagerMockRecorder) BranchRollback(ctx, req interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BranchRollback", reflect.TypeOf((*MockDataSourceManager)(nil).BranchRollback), ctx, req) +} + +// BranchCommit mocks base method +func (m *MockDataSourceManager) BranchCommit(ctx context.Context, req message.BranchCommitRequest) (branch.BranchStatus, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "BranchCommit", ctx, req) + ret0, _ := ret[0].(branch.BranchStatus) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// BranchCommit indicates an expected call of BranchCommit +func (mr *MockDataSourceManagerMockRecorder) BranchCommit(ctx, req interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BranchCommit", reflect.TypeOf((*MockDataSourceManager)(nil).BranchCommit), ctx, req) +} + +// LockQuery mocks base method +func (m *MockDataSourceManager) LockQuery(ctx context.Context, req message.GlobalLockQueryRequest) (bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "LockQuery", ctx, req) + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// LockQuery indicates an expected call of LockQuery +func (mr *MockDataSourceManagerMockRecorder) LockQuery(ctx, req interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LockQuery", reflect.TypeOf((*MockDataSourceManager)(nil).LockQuery), ctx, req) +} + +// BranchRegister mocks base method +func (m *MockDataSourceManager) BranchRegister(ctx context.Context, clientId string, req message.BranchRegisterRequest) (int64, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "BranchRegister", ctx, clientId, req) + ret0, _ := ret[0].(int64) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// BranchRegister indicates an expected call of BranchRegister +func (mr *MockDataSourceManagerMockRecorder) BranchRegister(ctx, clientId, req interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BranchRegister", reflect.TypeOf((*MockDataSourceManager)(nil).BranchRegister), ctx, clientId, req) +} + +// BranchReport mocks base method +func (m *MockDataSourceManager) BranchReport(ctx context.Context, req message.BranchReportRequest) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "BranchReport", ctx, req) + ret0, _ := ret[0].(error) + return ret0 +} + +// BranchReport indicates an expected call of BranchReport +func (mr *MockDataSourceManagerMockRecorder) BranchReport(ctx, req interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BranchReport", reflect.TypeOf((*MockDataSourceManager)(nil).BranchReport), ctx, req) +} + +// CreateTableMetaCache mocks base method +func (m *MockDataSourceManager) CreateTableMetaCache(ctx context.Context, resID string, dbType types.DBType, db *sql.DB) (datasource.TableMetaCache, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CreateTableMetaCache", ctx, resID, dbType, db) + ret0, _ := ret[0].(datasource.TableMetaCache) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CreateTableMetaCache indicates an expected call of CreateTableMetaCache +func (mr *MockDataSourceManagerMockRecorder) CreateTableMetaCache(ctx, resID, dbType, db interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateTableMetaCache", reflect.TypeOf((*MockDataSourceManager)(nil).CreateTableMetaCache), ctx, resID, dbType, db) +} + +// MockTableMetaCache is a mock of TableMetaCache interface +type MockTableMetaCache struct { + ctrl *gomock.Controller + recorder *MockTableMetaCacheMockRecorder +} + +// MockTableMetaCacheMockRecorder is the mock recorder for MockTableMetaCache +type MockTableMetaCacheMockRecorder struct { + mock *MockTableMetaCache +} + +// NewMockTableMetaCache creates a new mock instance +func NewMockTableMetaCache(ctrl *gomock.Controller) *MockTableMetaCache { + mock := &MockTableMetaCache{ctrl: ctrl} + mock.recorder = &MockTableMetaCacheMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockTableMetaCache) EXPECT() *MockTableMetaCacheMockRecorder { + return m.recorder +} + +// Init mocks base method +func (m *MockTableMetaCache) Init(ctx context.Context, conn *sql.DB) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Init", ctx, conn) + ret0, _ := ret[0].(error) + return ret0 +} + +// Init indicates an expected call of Init +func (mr *MockTableMetaCacheMockRecorder) Init(ctx, conn interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Init", reflect.TypeOf((*MockTableMetaCache)(nil).Init), ctx, conn) +} + +// GetTableMeta mocks base method +func (m *MockTableMetaCache) GetTableMeta(table string) (types.TableMeta, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetTableMeta", table) + ret0, _ := ret[0].(types.TableMeta) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetTableMeta indicates an expected call of GetTableMeta +func (mr *MockTableMetaCacheMockRecorder) GetTableMeta(table interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTableMeta", reflect.TypeOf((*MockTableMetaCache)(nil).GetTableMeta), table) +} + +// Destroy mocks base method +func (m *MockTableMetaCache) Destroy() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Destroy") + ret0, _ := ret[0].(error) + return ret0 +} + +// Destroy indicates an expected call of Destroy +func (mr *MockTableMetaCacheMockRecorder) Destroy() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Destroy", reflect.TypeOf((*MockTableMetaCache)(nil).Destroy)) +} diff --git a/pkg/datasource/sql/mock/mock_driver.go b/pkg/datasource/sql/mock/mock_driver.go new file mode 100644 index 00000000..89012310 --- /dev/null +++ b/pkg/datasource/sql/mock/mock_driver.go @@ -0,0 +1,411 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: test_driver.go + +// Package mock is a generated GoMock package. +package mock + +import ( + context "context" + driver "database/sql/driver" + gomock "github.com/golang/mock/gomock" + reflect "reflect" +) + +// MockTestDriverConnector is a mock of TestDriverConnector interface +type MockTestDriverConnector struct { + ctrl *gomock.Controller + recorder *MockTestDriverConnectorMockRecorder +} + +// MockTestDriverConnectorMockRecorder is the mock recorder for MockTestDriverConnector +type MockTestDriverConnectorMockRecorder struct { + mock *MockTestDriverConnector +} + +// NewMockTestDriverConnector creates a new mock instance +func NewMockTestDriverConnector(ctrl *gomock.Controller) *MockTestDriverConnector { + mock := &MockTestDriverConnector{ctrl: ctrl} + mock.recorder = &MockTestDriverConnectorMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockTestDriverConnector) EXPECT() *MockTestDriverConnectorMockRecorder { + return m.recorder +} + +// Connect mocks base method +func (m *MockTestDriverConnector) Connect(arg0 context.Context) (driver.Conn, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Connect", arg0) + ret0, _ := ret[0].(driver.Conn) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Connect indicates an expected call of Connect +func (mr *MockTestDriverConnectorMockRecorder) Connect(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Connect", reflect.TypeOf((*MockTestDriverConnector)(nil).Connect), arg0) +} + +// Driver mocks base method +func (m *MockTestDriverConnector) Driver() driver.Driver { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Driver") + ret0, _ := ret[0].(driver.Driver) + return ret0 +} + +// Driver indicates an expected call of Driver +func (mr *MockTestDriverConnectorMockRecorder) Driver() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Driver", reflect.TypeOf((*MockTestDriverConnector)(nil).Driver)) +} + +// MockTestDriverConn is a mock of TestDriverConn interface +type MockTestDriverConn struct { + ctrl *gomock.Controller + recorder *MockTestDriverConnMockRecorder +} + +// MockTestDriverConnMockRecorder is the mock recorder for MockTestDriverConn +type MockTestDriverConnMockRecorder struct { + mock *MockTestDriverConn +} + +// NewMockTestDriverConn creates a new mock instance +func NewMockTestDriverConn(ctrl *gomock.Controller) *MockTestDriverConn { + mock := &MockTestDriverConn{ctrl: ctrl} + mock.recorder = &MockTestDriverConnMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockTestDriverConn) EXPECT() *MockTestDriverConnMockRecorder { + return m.recorder +} + +// Prepare mocks base method +func (m *MockTestDriverConn) Prepare(query string) (driver.Stmt, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Prepare", query) + ret0, _ := ret[0].(driver.Stmt) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Prepare indicates an expected call of Prepare +func (mr *MockTestDriverConnMockRecorder) Prepare(query interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Prepare", reflect.TypeOf((*MockTestDriverConn)(nil).Prepare), query) +} + +// Close mocks base method +func (m *MockTestDriverConn) Close() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Close") + ret0, _ := ret[0].(error) + return ret0 +} + +// Close indicates an expected call of Close +func (mr *MockTestDriverConnMockRecorder) Close() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockTestDriverConn)(nil).Close)) +} + +// Begin mocks base method +func (m *MockTestDriverConn) Begin() (driver.Tx, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Begin") + ret0, _ := ret[0].(driver.Tx) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Begin indicates an expected call of Begin +func (mr *MockTestDriverConnMockRecorder) Begin() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Begin", reflect.TypeOf((*MockTestDriverConn)(nil).Begin)) +} + +// BeginTx mocks base method +func (m *MockTestDriverConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "BeginTx", ctx, opts) + ret0, _ := ret[0].(driver.Tx) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// BeginTx indicates an expected call of BeginTx +func (mr *MockTestDriverConnMockRecorder) BeginTx(ctx, opts interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BeginTx", reflect.TypeOf((*MockTestDriverConn)(nil).BeginTx), ctx, opts) +} + +// Ping mocks base method +func (m *MockTestDriverConn) Ping(ctx context.Context) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Ping", ctx) + ret0, _ := ret[0].(error) + return ret0 +} + +// Ping indicates an expected call of Ping +func (mr *MockTestDriverConnMockRecorder) Ping(ctx interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Ping", reflect.TypeOf((*MockTestDriverConn)(nil).Ping), ctx) +} + +// PrepareContext mocks base method +func (m *MockTestDriverConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "PrepareContext", ctx, query) + ret0, _ := ret[0].(driver.Stmt) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// PrepareContext indicates an expected call of PrepareContext +func (mr *MockTestDriverConnMockRecorder) PrepareContext(ctx, query interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PrepareContext", reflect.TypeOf((*MockTestDriverConn)(nil).PrepareContext), ctx, query) +} + +// Query mocks base method +func (m *MockTestDriverConn) Query(query string, args []driver.Value) (driver.Rows, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Query", query, args) + ret0, _ := ret[0].(driver.Rows) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Query indicates an expected call of Query +func (mr *MockTestDriverConnMockRecorder) Query(query, args interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Query", reflect.TypeOf((*MockTestDriverConn)(nil).Query), query, args) +} + +// QueryContext mocks base method +func (m *MockTestDriverConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "QueryContext", ctx, query, args) + ret0, _ := ret[0].(driver.Rows) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// QueryContext indicates an expected call of QueryContext +func (mr *MockTestDriverConnMockRecorder) QueryContext(ctx, query, args interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryContext", reflect.TypeOf((*MockTestDriverConn)(nil).QueryContext), ctx, query, args) +} + +// Exec mocks base method +func (m *MockTestDriverConn) Exec(query string, args []driver.Value) (driver.Result, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Exec", query, args) + ret0, _ := ret[0].(driver.Result) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Exec indicates an expected call of Exec +func (mr *MockTestDriverConnMockRecorder) Exec(query, args interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockTestDriverConn)(nil).Exec), query, args) +} + +// ExecContext mocks base method +func (m *MockTestDriverConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ExecContext", ctx, query, args) + ret0, _ := ret[0].(driver.Result) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ExecContext indicates an expected call of ExecContext +func (mr *MockTestDriverConnMockRecorder) ExecContext(ctx, query, args interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ExecContext", reflect.TypeOf((*MockTestDriverConn)(nil).ExecContext), ctx, query, args) +} + +// ResetSession mocks base method +func (m *MockTestDriverConn) ResetSession(ctx context.Context) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ResetSession", ctx) + ret0, _ := ret[0].(error) + return ret0 +} + +// ResetSession indicates an expected call of ResetSession +func (mr *MockTestDriverConnMockRecorder) ResetSession(ctx interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ResetSession", reflect.TypeOf((*MockTestDriverConn)(nil).ResetSession), ctx) +} + +// MockTestDriverStmt is a mock of TestDriverStmt interface +type MockTestDriverStmt struct { + ctrl *gomock.Controller + recorder *MockTestDriverStmtMockRecorder +} + +// MockTestDriverStmtMockRecorder is the mock recorder for MockTestDriverStmt +type MockTestDriverStmtMockRecorder struct { + mock *MockTestDriverStmt +} + +// NewMockTestDriverStmt creates a new mock instance +func NewMockTestDriverStmt(ctrl *gomock.Controller) *MockTestDriverStmt { + mock := &MockTestDriverStmt{ctrl: ctrl} + mock.recorder = &MockTestDriverStmtMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockTestDriverStmt) EXPECT() *MockTestDriverStmtMockRecorder { + return m.recorder +} + +// Close mocks base method +func (m *MockTestDriverStmt) Close() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Close") + ret0, _ := ret[0].(error) + return ret0 +} + +// Close indicates an expected call of Close +func (mr *MockTestDriverStmtMockRecorder) Close() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockTestDriverStmt)(nil).Close)) +} + +// NumInput mocks base method +func (m *MockTestDriverStmt) NumInput() int { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "NumInput") + ret0, _ := ret[0].(int) + return ret0 +} + +// NumInput indicates an expected call of NumInput +func (mr *MockTestDriverStmtMockRecorder) NumInput() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NumInput", reflect.TypeOf((*MockTestDriverStmt)(nil).NumInput)) +} + +// Exec mocks base method +func (m *MockTestDriverStmt) Exec(args []driver.Value) (driver.Result, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Exec", args) + ret0, _ := ret[0].(driver.Result) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Exec indicates an expected call of Exec +func (mr *MockTestDriverStmtMockRecorder) Exec(args interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockTestDriverStmt)(nil).Exec), args) +} + +// Query mocks base method +func (m *MockTestDriverStmt) Query(args []driver.Value) (driver.Rows, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Query", args) + ret0, _ := ret[0].(driver.Rows) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Query indicates an expected call of Query +func (mr *MockTestDriverStmtMockRecorder) Query(args interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Query", reflect.TypeOf((*MockTestDriverStmt)(nil).Query), args) +} + +// QueryContext mocks base method +func (m *MockTestDriverStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "QueryContext", ctx, args) + ret0, _ := ret[0].(driver.Rows) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// QueryContext indicates an expected call of QueryContext +func (mr *MockTestDriverStmtMockRecorder) QueryContext(ctx, args interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryContext", reflect.TypeOf((*MockTestDriverStmt)(nil).QueryContext), ctx, args) +} + +// ExecContext mocks base method +func (m *MockTestDriverStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ExecContext", ctx, args) + ret0, _ := ret[0].(driver.Result) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ExecContext indicates an expected call of ExecContext +func (mr *MockTestDriverStmtMockRecorder) ExecContext(ctx, args interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ExecContext", reflect.TypeOf((*MockTestDriverStmt)(nil).ExecContext), ctx, args) +} + +// MockTestDriverTx is a mock of TestDriverTx interface +type MockTestDriverTx struct { + ctrl *gomock.Controller + recorder *MockTestDriverTxMockRecorder +} + +// MockTestDriverTxMockRecorder is the mock recorder for MockTestDriverTx +type MockTestDriverTxMockRecorder struct { + mock *MockTestDriverTx +} + +// NewMockTestDriverTx creates a new mock instance +func NewMockTestDriverTx(ctrl *gomock.Controller) *MockTestDriverTx { + mock := &MockTestDriverTx{ctrl: ctrl} + mock.recorder = &MockTestDriverTxMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockTestDriverTx) EXPECT() *MockTestDriverTxMockRecorder { + return m.recorder +} + +// Commit mocks base method +func (m *MockTestDriverTx) Commit() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Commit") + ret0, _ := ret[0].(error) + return ret0 +} + +// Commit indicates an expected call of Commit +func (mr *MockTestDriverTxMockRecorder) Commit() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Commit", reflect.TypeOf((*MockTestDriverTx)(nil).Commit)) +} + +// Rollback mocks base method +func (m *MockTestDriverTx) Rollback() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Rollback") + ret0, _ := ret[0].(error) + return ret0 +} + +// Rollback indicates an expected call of Rollback +func (mr *MockTestDriverTxMockRecorder) Rollback() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Rollback", reflect.TypeOf((*MockTestDriverTx)(nil).Rollback)) +} diff --git a/pkg/datasource/sql/mock/test_driver.go b/pkg/datasource/sql/mock/test_driver.go new file mode 100644 index 00000000..0d193f62 --- /dev/null +++ b/pkg/datasource/sql/mock/test_driver.go @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package mock + +import "database/sql/driver" + +type TestDriverConnector interface { + driver.Connector +} + +type TestDriverConn interface { + driver.Conn + driver.ConnBeginTx + driver.Pinger + driver.ConnPrepareContext + driver.Queryer + driver.QueryerContext + driver.Execer + driver.ExecerContext + driver.SessionResetter +} + +type TestDriverStmt interface { + driver.Stmt + driver.StmtQueryContext + driver.StmtExecContext +} + +type TestDriverTx interface { + driver.Tx +} diff --git a/pkg/datasource/sql/sql_test.go b/pkg/datasource/sql/sql_test.go index 3b082db7..d9e8ff74 100644 --- a/pkg/datasource/sql/sql_test.go +++ b/pkg/datasource/sql/sql_test.go @@ -36,7 +36,7 @@ func Test_SQLOpen(t *testing.T) { t.SkipNow() log.Info("begin test") var err error - db, err = sql.Open(SeataMySQLDriver, "root:12345678@tcp(127.0.0.1:3306)/polaris_server?multiStatements=true") + db, err = sql.Open("seata-at-mysql", "root:seata_go@tcp(127.0.0.1:3306)/seata_go_test?multiStatements=true") if err != nil { t.Fatal(err) } diff --git a/pkg/datasource/sql/stmt.go b/pkg/datasource/sql/stmt.go index 41aafecf..402d91f5 100644 --- a/pkg/datasource/sql/stmt.go +++ b/pkg/datasource/sql/stmt.go @@ -67,7 +67,7 @@ func (s *Stmt) NumInput() int { // // Deprecated: Drivers should implement StmtQueryContext instead (or additionally). func (s *Stmt) Query(args []driver.Value) (driver.Rows, error) { - executor, err := exec.BuildExecutor(s.res.dbType, s.query) + executor, err := exec.BuildExecutor(s.res.dbType, s.txCtx.TransType, s.query) if err != nil { return nil, err } @@ -105,7 +105,7 @@ func (s *Stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driv return nil, driver.ErrSkip } - executor, err := exec.BuildExecutor(s.res.dbType, s.query) + executor, err := exec.BuildExecutor(s.res.dbType, s.txCtx.TransType, s.query) if err != nil { return nil, err } @@ -138,7 +138,7 @@ func (s *Stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driv // Deprecated: Drivers should implement StmtExecContext instead (or additionally). func (s *Stmt) Exec(args []driver.Value) (driver.Result, error) { // in transaction, need run Executor - executor, err := exec.BuildExecutor(s.res.dbType, s.query) + executor, err := exec.BuildExecutor(s.res.dbType, s.txCtx.TransType, s.query) if err != nil { return nil, err } @@ -173,7 +173,7 @@ func (s *Stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (drive } // in transaction, need run Executor - executor, err := exec.BuildExecutor(s.res.dbType, s.query) + executor, err := exec.BuildExecutor(s.res.dbType, s.txCtx.TransType, s.query) if err != nil { return nil, err } diff --git a/pkg/datasource/sql/tx.go b/pkg/datasource/sql/tx.go index 7b8c1046..2a2c4e1b 100644 --- a/pkg/datasource/sql/tx.go +++ b/pkg/datasource/sql/tx.go @@ -20,6 +20,7 @@ package sql import ( "context" "database/sql/driver" + "sync" "github.com/seata/seata-go/pkg/datasource/sql/undo" @@ -34,7 +35,34 @@ import ( const REPORT_RETRY_COUNT = 5 -type txOption func(tx *Tx) +var ( + hl sync.RWMutex + txHooks []txHook +) + +func RegisterTxHook(h txHook) { + hl.Lock() + defer hl.Unlock() + + txHooks = append(txHooks, h) +} + +func CleanTxHooks() { + hl.Lock() + defer hl.Unlock() + + txHooks = make([]txHook, 0, 4) +} + +type ( + txOption func(tx *Tx) + + txHook interface { + BeforeCommit(tx *Tx) + + BeforeRollback(tx *Tx) + } +) func newTx(opts ...txOption) (driver.Tx, error) { tx := new(Tx) @@ -83,6 +111,15 @@ type Tx struct { // case 2. not need flush undolog, is XA mode, do local transaction commit // case 3. need run AT transaction func (tx *Tx) Commit() error { + if len(txHooks) != 0 { + hl.RLock() + defer hl.RUnlock() + + for i := range txHooks { + txHooks[i].BeforeCommit(tx) + } + } + if tx.ctx.TransType == types.Local { return tx.commitOnLocal() } @@ -96,6 +133,15 @@ func (tx *Tx) Commit() error { } func (tx *Tx) Rollback() error { + if len(txHooks) != 0 { + hl.RLock() + defer hl.RUnlock() + + for i := range txHooks { + txHooks[i].BeforeRollback(tx) + } + } + err := tx.target.Rollback() if err != nil { if tx.ctx.OpenGlobalTrsnaction() && tx.ctx.IsBranchRegistered() { @@ -124,7 +170,7 @@ func (tx *Tx) commitOnXA() error { // commitOnAT func (tx *Tx) commitOnAT() error { // if TX-Mode is AT, run regis this transaction branch - if err := tx.regis(tx.ctx); err != nil { + if err := tx.register(tx.ctx); err != nil { return err } @@ -133,7 +179,7 @@ func (tx *Tx) commitOnAT() error { return err } - if err := undoLogMgr.FlushUndoLog(tx.ctx, nil); err != nil { + if err := undoLogMgr.FlushUndoLog(tx.ctx, tx.conn.targetConn); err != nil { if rerr := tx.report(false); rerr != nil { return errors.WithStack(rerr) } @@ -151,8 +197,8 @@ func (tx *Tx) commitOnAT() error { return nil } -// regis -func (tx *Tx) regis(ctx *types.TransactionContext) error { +// register +func (tx *Tx) register(ctx *types.TransactionContext) error { if !ctx.HasUndoLog() || !ctx.HasLockKey() { return nil } diff --git a/pkg/datasource/sql/types/types.go b/pkg/datasource/sql/types/types.go index 493f7403..ca2a9b31 100644 --- a/pkg/datasource/sql/types/types.go +++ b/pkg/datasource/sql/types/types.go @@ -103,7 +103,7 @@ type TransactionContext struct { func NewTxCtx() *TransactionContext { return &TransactionContext{ LockKeys: make([]string, 0, 4), - TransType: ATMode, + TransType: Local, LocalTransID: uuid.New().String(), RoundImages: &RoundRecordImage{}, } diff --git a/pkg/datasource/sql/undo/base/undo.go b/pkg/datasource/sql/undo/base/undo.go index a01b6f44..61292b62 100644 --- a/pkg/datasource/sql/undo/base/undo.go +++ b/pkg/datasource/sql/undo/base/undo.go @@ -48,7 +48,7 @@ func (m *BaseUndoLogManager) Init() { } // InsertUndoLog -func (m *BaseUndoLogManager) InsertUndoLog(l []undo.BranchUndoLog, tx driver.Tx) error { +func (m *BaseUndoLogManager) InsertUndoLog(l []undo.BranchUndoLog, tx driver.Conn) error { return nil } @@ -102,7 +102,7 @@ func (m *BaseUndoLogManager) BatchDeleteUndoLog(xid []string, branchID []int64, } // FlushUndoLog -func (m *BaseUndoLogManager) FlushUndoLog(txCtx *types.TransactionContext, tx driver.Tx) error { +func (m *BaseUndoLogManager) FlushUndoLog(txCtx *types.TransactionContext, tx driver.Conn) error { return nil } diff --git a/pkg/datasource/sql/undo/mysql/undo.go b/pkg/datasource/sql/undo/mysql/undo.go index d9e4029f..1e659dae 100644 --- a/pkg/datasource/sql/undo/mysql/undo.go +++ b/pkg/datasource/sql/undo/mysql/undo.go @@ -39,7 +39,7 @@ func (m *undoLogManager) Init() { } // InsertUndoLog -func (m *undoLogManager) InsertUndoLog(l []undo.BranchUndoLog, tx driver.Tx) error { +func (m *undoLogManager) InsertUndoLog(l []undo.BranchUndoLog, tx driver.Conn) error { return m.Base.InsertUndoLog(l, tx) } @@ -54,7 +54,7 @@ func (m *undoLogManager) BatchDeleteUndoLog(xid []string, branchID []int64, conn } // FlushUndoLog -func (m *undoLogManager) FlushUndoLog(txCtx *types.TransactionContext, tx driver.Tx) error { +func (m *undoLogManager) FlushUndoLog(txCtx *types.TransactionContext, tx driver.Conn) error { return m.Base.FlushUndoLog(txCtx, tx) } diff --git a/pkg/datasource/sql/undo/undo.go b/pkg/datasource/sql/undo/undo.go index 9e52f8df..e23c5a93 100644 --- a/pkg/datasource/sql/undo/undo.go +++ b/pkg/datasource/sql/undo/undo.go @@ -50,13 +50,13 @@ func Regis(m UndoLogManager) error { type UndoLogManager interface { Init() // InsertUndoLog - InsertUndoLog(l []BranchUndoLog, tx driver.Tx) error + InsertUndoLog(l []BranchUndoLog, tx driver.Conn) error // DeleteUndoLog DeleteUndoLog(ctx context.Context, xid string, branchID int64, conn *sql.Conn) error // BatchDeleteUndoLog BatchDeleteUndoLog(xid []string, branchID []int64, conn *sql.Conn) error // FlushUndoLog - FlushUndoLog(txCtx *types.TransactionContext, tx driver.Tx) error + FlushUndoLog(txCtx *types.TransactionContext, tx driver.Conn) error // RunUndo RunUndo(xid string, branchID int64, conn *sql.Conn) error // DBType diff --git a/pkg/datasource/sql/undo_test.go b/pkg/datasource/sql/undo_test.go index 9acd6f55..06660944 100644 --- a/pkg/datasource/sql/undo_test.go +++ b/pkg/datasource/sql/undo_test.go @@ -32,7 +32,7 @@ func TestBatchDeleteUndoLogs(t *testing.T) { t.SkipNow() testBatchDeleteUndoLogs := func() { - db, err := sql.Open(SeataMySQLDriver, "root:12345678@tcp(127.0.0.1:3306)/seata_order?multiStatements=true") + db, err := sql.Open(SeataATMySQLDriver, "root:12345678@tcp(127.0.0.1:3306)/seata_order?multiStatements=true") assert.Nil(t, err) sqlConn, err := db.Conn(context.Background()) @@ -54,7 +54,7 @@ func TestDeleteUndoLogs(t *testing.T) { t.SkipNow() testDeleteUndoLogs := func() { - db, err := sql.Open(SeataMySQLDriver, "root:12345678@tcp(127.0.0.1:3306)/seata_order?multiStatements=true") + db, err := sql.Open(SeataATMySQLDriver, "root:12345678@tcp(127.0.0.1:3306)/seata_order?multiStatements=true") assert.Nil(t, err) ctx := context.Background()