* refactor:seata conn * test: add unit test * test: add unit testtags/1.0.2-RC1
@@ -23,7 +23,7 @@ on: | |||
push: | |||
branches: [ master ] | |||
pull_request: | |||
branches: "*" | |||
branches: ["*"] | |||
jobs: | |||
build: | |||
@@ -39,6 +39,20 @@ jobs: | |||
with: | |||
go-version: 1.18 | |||
# close default MySQL-Server | |||
- name: Shutdown default mysql | |||
run: sudo service mysql stop | |||
# run mysql server | |||
- name: Create mysql database auth | |||
uses: icomponent/mysql-action@master | |||
with: | |||
VERSION: 5.7 | |||
CONTAINER_NAME: mysql | |||
PORT_MAPPING: 3306:3306 | |||
ROOT_PASSWORD: seata_go | |||
DATABASE: seata_go_test | |||
# Checks-out your repository under $GITHUB_WORKSPACE, so your job can access it | |||
- name: "checkout ${{ github.ref }}" | |||
uses: actions/checkout@v3 | |||
@@ -19,29 +19,35 @@ package sql | |||
import ( | |||
"context" | |||
gosql "database/sql" | |||
"database/sql/driver" | |||
"errors" | |||
"github.com/seata/seata-go/pkg/datasource/sql/exec" | |||
"github.com/seata/seata-go/pkg/datasource/sql/types" | |||
) | |||
// Conn is a connection to a database. It is not used concurrently | |||
// by multiple goroutines. | |||
// | |||
// Conn is assumed to be stateful. | |||
type Conn struct { | |||
res *DBResource | |||
txCtx *types.TransactionContext | |||
targetConn driver.Conn | |||
isInTransaction bool | |||
autoCommit bool | |||
autoCommitChanged bool | |||
txType types.TransactionType | |||
res *DBResource | |||
txCtx *types.TransactionContext | |||
targetConn driver.Conn | |||
autoCommit bool | |||
} | |||
// ResetSession is called prior to executing a query on the connection | |||
// if the connection has been used before. If the driver returns ErrBadConn | |||
// the connection is discarded. | |||
func (c *Conn) ResetSession(ctx context.Context) error { | |||
conn, ok := c.targetConn.(driver.SessionResetter) | |||
if !ok { | |||
return driver.ErrSkip | |||
} | |||
c.txType = types.Local | |||
c.txCtx = nil | |||
return conn.ResetSession(ctx) | |||
} | |||
@@ -95,9 +101,8 @@ func (c *Conn) Exec(query string, args []driver.Value) (driver.Result, error) { | |||
return nil, driver.ErrSkip | |||
} | |||
if c.txCtx != nil { | |||
// in transaction, need run Executor | |||
executor, err := exec.BuildExecutor(c.res.dbType, query) | |||
ret, err := c.createNewTxOnExecIfNeed(func() (types.ExecResult, error) { | |||
executor, err := exec.BuildExecutor(c.res.dbType, c.txCtx.TransType, query) | |||
if err != nil { | |||
return nil, err | |||
} | |||
@@ -108,7 +113,7 @@ func (c *Conn) Exec(query string, args []driver.Value) (driver.Result, error) { | |||
Values: args, | |||
} | |||
ret, err := executor.ExecWithValue(context.Background(), execCtx, | |||
return executor.ExecWithValue(context.Background(), execCtx, | |||
func(ctx context.Context, query string, args []driver.Value) (types.ExecResult, error) { | |||
ret, err := conn.Exec(query, args) | |||
if err != nil { | |||
@@ -117,16 +122,12 @@ func (c *Conn) Exec(query string, args []driver.Value) (driver.Result, error) { | |||
return types.NewResult(types.WithResult(ret)), nil | |||
}) | |||
if err != nil { | |||
return nil, err | |||
} | |||
}) | |||
// todo if user has not opened a transaction, it may call tx.Commit() method to flush undo log | |||
return ret.GetResult(), nil | |||
if err != nil { | |||
return nil, err | |||
} | |||
return conn.Exec(query, args) | |||
return ret.GetResult(), nil | |||
} | |||
// ExecContext | |||
@@ -142,41 +143,45 @@ func (c *Conn) ExecContext(ctx context.Context, query string, args []driver.Name | |||
return c.Exec(query, values) | |||
} | |||
executor, err := exec.BuildExecutor(c.res.dbType, query) | |||
if err != nil { | |||
return nil, err | |||
} | |||
ret, err := c.createNewTxOnExecIfNeed(func() (types.ExecResult, error) { | |||
executor, err := exec.BuildExecutor(c.res.dbType, c.txCtx.TransType, query) | |||
if err != nil { | |||
return nil, err | |||
} | |||
execCtx := &exec.ExecContext{ | |||
TxCtx: c.txCtx, | |||
Query: query, | |||
NamedValues: args, | |||
} | |||
execCtx := &exec.ExecContext{ | |||
TxCtx: c.txCtx, | |||
Query: query, | |||
NamedValues: args, | |||
} | |||
ret, err := executor.ExecWithNamedValue(ctx, execCtx, | |||
func(ctx context.Context, query string, args []driver.NamedValue) (types.ExecResult, error) { | |||
ret, err := targetConn.ExecContext(ctx, query, args) | |||
if err != nil { | |||
return nil, err | |||
} | |||
ret, err := executor.ExecWithNamedValue(ctx, execCtx, | |||
func(ctx context.Context, query string, args []driver.NamedValue) (types.ExecResult, error) { | |||
ret, err := targetConn.ExecContext(ctx, query, args) | |||
if err != nil { | |||
return nil, err | |||
} | |||
return types.NewResult(types.WithResult(ret)), nil | |||
}) | |||
return ret, err | |||
}) | |||
return types.NewResult(types.WithResult(ret)), nil | |||
}) | |||
if err != nil { | |||
return nil, err | |||
} | |||
return ret.GetResult(), nil | |||
} | |||
// QueryContext | |||
// Query | |||
func (c *Conn) Query(query string, args []driver.Value) (driver.Rows, error) { | |||
conn, ok := c.targetConn.(driver.Queryer) | |||
if !ok { | |||
return nil, driver.ErrSkip | |||
} | |||
executor, err := exec.BuildExecutor(c.res.dbType, query) | |||
executor, err := exec.BuildExecutor(c.res.dbType, c.txCtx.TransType, query) | |||
if err != nil { | |||
return nil, err | |||
} | |||
@@ -216,7 +221,7 @@ func (c *Conn) QueryContext(ctx context.Context, query string, args []driver.Nam | |||
return c.Query(query, values) | |||
} | |||
executor, err := exec.BuildExecutor(c.res.dbType, query) | |||
executor, err := exec.BuildExecutor(c.res.dbType, c.txCtx.TransType, query) | |||
if err != nil { | |||
return nil, err | |||
} | |||
@@ -252,10 +257,11 @@ func (c *Conn) Begin() (driver.Tx, error) { | |||
return nil, err | |||
} | |||
c.txCtx = types.NewTxCtx() | |||
c.txCtx.DBType = c.res.dbType | |||
c.txCtx.TxOpt = driver.TxOptions{} | |||
c.autoCommit = true | |||
if c.txCtx == nil { | |||
c.txCtx = types.NewTxCtx() | |||
c.txCtx.DBType = c.res.dbType | |||
c.txCtx.TxOpt = driver.TxOptions{} | |||
} | |||
return newTx( | |||
withDriverConn(c), | |||
@@ -264,6 +270,8 @@ func (c *Conn) Begin() (driver.Tx, error) { | |||
) | |||
} | |||
// BeginTx Open a transaction and judge whether the current transaction needs to open a | |||
// global transaction according to ctx. If so, it needs to be included in the transaction management of seata | |||
func (c *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { | |||
if conn, ok := c.targetConn.(driver.ConnBeginTx); ok { | |||
tx, err := conn.BeginTx(ctx, opts) | |||
@@ -271,11 +279,6 @@ func (c *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, e | |||
return nil, err | |||
} | |||
c.txCtx = types.NewTxCtx() | |||
c.txCtx.DBType = c.res.dbType | |||
c.txCtx.TxOpt = opts | |||
c.autoCommit = true | |||
return newTx( | |||
withDriverConn(c), | |||
withTxCtx(c.txCtx), | |||
@@ -283,32 +286,15 @@ func (c *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, e | |||
) | |||
} | |||
// Check the transaction level. If the transaction level is non-default | |||
// then return an error here as the BeginTx driver value is not supported. | |||
if opts.Isolation != driver.IsolationLevel(gosql.LevelDefault) { | |||
return nil, errors.New("sql: driver does not support non-default isolation level") | |||
} | |||
// If a read-only transaction is requested return an error as the | |||
// BeginTx driver value is not supported. | |||
if opts.ReadOnly { | |||
return nil, errors.New("sql: driver does not support read-only transactions") | |||
} | |||
if ctx.Done() == nil { | |||
return c.Begin() | |||
} | |||
txi, err := c.Begin() | |||
if err == nil { | |||
select { | |||
case <-ctx.Done(): | |||
txi.Rollback() | |||
return nil, ctx.Err() | |||
default: | |||
} | |||
if err != nil { | |||
return nil, err | |||
} | |||
return txi, err | |||
return newTx( | |||
withDriverConn(c), | |||
withTxCtx(c.txCtx), | |||
withOriginTx(txi), | |||
) | |||
} | |||
// Close invalidates and potentially stops any current | |||
@@ -326,3 +312,37 @@ func (c *Conn) Close() error { | |||
c.txCtx = nil | |||
return c.targetConn.Close() | |||
} | |||
func (c *Conn) createNewTxOnExecIfNeed(f func() (types.ExecResult, error)) (types.ExecResult, error) { | |||
var ( | |||
tx driver.Tx | |||
err error | |||
) | |||
if c.txCtx.TransType != types.Local && c.autoCommit { | |||
tx, err = c.Begin() | |||
if err != nil { | |||
return nil, err | |||
} | |||
} | |||
defer func() { | |||
if tx != nil { | |||
tx.Rollback() | |||
} | |||
}() | |||
ret, err := f() | |||
if err != nil { | |||
return nil, err | |||
} | |||
if tx != nil { | |||
if err := tx.Commit(); err != nil { | |||
return nil, err | |||
} | |||
} | |||
return ret, nil | |||
} |
@@ -0,0 +1,85 @@ | |||
/* | |||
* Licensed to the Apache Software Foundation (ASF) under one or more | |||
* contributor license agreements. See the NOTICE file distributed with | |||
* this work for additional information regarding copyright ownership. | |||
* The ASF licenses this file to You under the Apache License, Version 2.0 | |||
* (the "License"); you may not use this file except in compliance with | |||
* the License. You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
package sql | |||
import ( | |||
"context" | |||
"database/sql/driver" | |||
"github.com/seata/seata-go/pkg/datasource/sql/types" | |||
"github.com/seata/seata-go/pkg/tm" | |||
) | |||
type ATConn struct { | |||
*Conn | |||
} | |||
func (c *ATConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { | |||
if c.createTxCtxIfAbsent(ctx) { | |||
defer func() { | |||
c.txCtx = nil | |||
}() | |||
} | |||
return c.Conn.PrepareContext(ctx, query) | |||
} | |||
// ExecContext | |||
func (c *ATConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { | |||
if c.createTxCtxIfAbsent(ctx) { | |||
defer func() { | |||
c.txCtx = nil | |||
}() | |||
} | |||
return c.Conn.ExecContext(ctx, query, args) | |||
} | |||
// BeginTx | |||
func (c *ATConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { | |||
c.txCtx = types.NewTxCtx() | |||
c.txCtx.DBType = c.res.dbType | |||
c.txCtx.TxOpt = opts | |||
if IsGlobalTx(ctx) { | |||
c.txCtx.XaID = tm.GetXID(ctx) | |||
c.txCtx.TransType = c.txType | |||
} | |||
return c.Conn.BeginTx(ctx, opts) | |||
} | |||
func (c *ATConn) createTxCtxIfAbsent(ctx context.Context) bool { | |||
var onceTx bool | |||
if IsGlobalTx(ctx) && c.txCtx == nil { | |||
c.txCtx = types.NewTxCtx() | |||
c.txCtx.DBType = c.res.dbType | |||
c.txCtx.XaID = tm.GetXID(ctx) | |||
c.txCtx.TransType = types.ATMode | |||
c.autoCommit = true | |||
onceTx = true | |||
} | |||
if c.txCtx == nil { | |||
c.txCtx = types.NewTxCtx() | |||
onceTx = true | |||
} | |||
return onceTx | |||
} |
@@ -0,0 +1,127 @@ | |||
/* | |||
* Licensed to the Apache Software Foundation (ASF) under one or more | |||
* contributor license agreements. See the NOTICE file distributed with | |||
* this work for additional information regarding copyright ownership. | |||
* The ASF licenses this file to You under the Apache License, Version 2.0 | |||
* (the "License"); you may not use this file except in compliance with | |||
* the License. You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
package sql | |||
import ( | |||
"context" | |||
"database/sql" | |||
"sync/atomic" | |||
"testing" | |||
"github.com/golang/mock/gomock" | |||
"github.com/google/uuid" | |||
"github.com/seata/seata-go/pkg/datasource/sql/exec" | |||
"github.com/seata/seata-go/pkg/datasource/sql/mock" | |||
"github.com/seata/seata-go/pkg/datasource/sql/types" | |||
"github.com/seata/seata-go/pkg/tm" | |||
"github.com/stretchr/testify/assert" | |||
) | |||
func TestATConn_ExecContext(t *testing.T) { | |||
ctrl := gomock.NewController(t) | |||
defer ctrl.Finish() | |||
mockMgr := initMockResourceManager(t, ctrl) | |||
_ = mockMgr | |||
db, err := sql.Open("seata-at-mysql", "root:seata_go@tcp(127.0.0.1:3306)/seata_go_test?multiStatements=true") | |||
if err != nil { | |||
t.Fatal(err) | |||
} | |||
defer db.Close() | |||
_ = initMockAtConnector(t, ctrl, db, func(t *testing.T, ctrl *gomock.Controller) *mock.MockTestDriverConnector { | |||
mockTx := mock.NewMockTestDriverTx(ctrl) | |||
mockTx.EXPECT().Commit().AnyTimes().Return(nil) | |||
mockTx.EXPECT().Rollback().AnyTimes().Return(nil) | |||
mockConn := mock.NewMockTestDriverConn(ctrl) | |||
mockConn.EXPECT().Begin().AnyTimes().Return(mockTx, nil) | |||
mockConn.EXPECT().BeginTx(gomock.Any(), gomock.Any()).AnyTimes().Return(mockTx, nil) | |||
baseMoclConn(mockConn) | |||
connector := mock.NewMockTestDriverConnector(ctrl) | |||
connector.EXPECT().Connect(gomock.Any()).AnyTimes().Return(mockConn, nil) | |||
return connector | |||
}) | |||
mi := &mockSQLInterceptor{} | |||
ti := &mockTxHook{} | |||
exec.CleanCommonHook() | |||
CleanTxHooks() | |||
exec.RegisCommonHook(mi) | |||
RegisterTxHook(ti) | |||
t.Run("have xid", func(t *testing.T) { | |||
ctx := tm.InitSeataContext(context.Background()) | |||
tm.SetXID(ctx, uuid.New().String()) | |||
t.Logf("set xid=%s", tm.GetXID(ctx)) | |||
before := func(_ context.Context, execCtx *exec.ExecContext) { | |||
t.Logf("on exec xid=%s", execCtx.TxCtx.XaID) | |||
assert.Equal(t, tm.GetXID(ctx), execCtx.TxCtx.XaID) | |||
assert.Equal(t, types.ATMode, execCtx.TxCtx.TransType) | |||
} | |||
mi.before = before | |||
var comitCnt int32 | |||
beforeCommit := func(tx *Tx) { | |||
atomic.AddInt32(&comitCnt, 1) | |||
assert.Equal(t, types.ATMode, tx.ctx.TransType) | |||
} | |||
ti.beforeCommit = beforeCommit | |||
conn, err := db.Conn(context.Background()) | |||
assert.NoError(t, err) | |||
_, err = conn.ExecContext(ctx, "SELECT 1") | |||
assert.NoError(t, err) | |||
_, err = db.ExecContext(ctx, "SELECT 1") | |||
assert.NoError(t, err) | |||
assert.Equal(t, int32(2), atomic.LoadInt32(&comitCnt)) | |||
}) | |||
t.Run("not xid", func(t *testing.T) { | |||
mi.before = func(_ context.Context, execCtx *exec.ExecContext) { | |||
assert.Equal(t, "", execCtx.TxCtx.XaID) | |||
assert.Equal(t, types.Local, execCtx.TxCtx.TransType) | |||
} | |||
var comitCnt int32 | |||
ti.beforeCommit = func(tx *Tx) { | |||
atomic.AddInt32(&comitCnt, 1) | |||
} | |||
conn, err := db.Conn(context.Background()) | |||
assert.NoError(t, err) | |||
_, err = conn.ExecContext(context.Background(), "SELECT 1") | |||
assert.NoError(t, err) | |||
_, err = db.ExecContext(context.Background(), "SELECT 1") | |||
assert.NoError(t, err) | |||
_, err = db.Exec("SELECT 1") | |||
assert.NoError(t, err) | |||
assert.Equal(t, int32(0), atomic.LoadInt32(&comitCnt)) | |||
}) | |||
} |
@@ -0,0 +1,86 @@ | |||
/* | |||
* Licensed to the Apache Software Foundation (ASF) under one or more | |||
* contributor license agreements. See the NOTICE file distributed with | |||
* this work for additional information regarding copyright ownership. | |||
* The ASF licenses this file to You under the Apache License, Version 2.0 | |||
* (the "License"); you may not use this file except in compliance with | |||
* the License. You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
package sql | |||
import ( | |||
"context" | |||
"database/sql/driver" | |||
"github.com/seata/seata-go/pkg/datasource/sql/types" | |||
"github.com/seata/seata-go/pkg/tm" | |||
) | |||
type XAConn struct { | |||
*Conn | |||
} | |||
// PrepareContext | |||
func (c *XAConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { | |||
if c.createTxCtxIfAbsent(ctx) { | |||
defer func() { | |||
c.txCtx = nil | |||
}() | |||
} | |||
return c.Conn.PrepareContext(ctx, query) | |||
} | |||
// ExecContext | |||
func (c *XAConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { | |||
if c.createTxCtxIfAbsent(ctx) { | |||
defer func() { | |||
c.txCtx = nil | |||
}() | |||
} | |||
return c.Conn.ExecContext(ctx, query, args) | |||
} | |||
// BeginTx | |||
func (c *XAConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { | |||
c.txCtx = types.NewTxCtx() | |||
c.txCtx.DBType = c.res.dbType | |||
c.txCtx.TxOpt = opts | |||
if IsGlobalTx(ctx) { | |||
c.txCtx.TransType = types.XAMode | |||
c.txCtx.XaID = tm.GetXID(ctx) | |||
} | |||
return c.Conn.BeginTx(ctx, opts) | |||
} | |||
func (c *XAConn) createTxCtxIfAbsent(ctx context.Context) bool { | |||
var onceTx bool | |||
if IsGlobalTx(ctx) && c.txCtx == nil { | |||
c.txCtx = types.NewTxCtx() | |||
c.txCtx.DBType = c.res.dbType | |||
c.txCtx.XaID = tm.GetXID(ctx) | |||
c.txCtx.TransType = types.XAMode | |||
c.autoCommit = true | |||
onceTx = true | |||
} | |||
if c.txCtx == nil { | |||
c.txCtx = types.NewTxCtx() | |||
onceTx = true | |||
} | |||
return onceTx | |||
} |
@@ -0,0 +1,179 @@ | |||
/* | |||
* Licensed to the Apache Software Foundation (ASF) under one or more | |||
* contributor license agreements. See the NOTICE file distributed with | |||
* this work for additional information regarding copyright ownership. | |||
* The ASF licenses this file to You under the Apache License, Version 2.0 | |||
* (the "License"); you may not use this file except in compliance with | |||
* the License. You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
package sql | |||
import ( | |||
"context" | |||
"database/sql" | |||
"database/sql/driver" | |||
"sync/atomic" | |||
"testing" | |||
"github.com/golang/mock/gomock" | |||
"github.com/google/uuid" | |||
"github.com/seata/seata-go/pkg/datasource/sql/exec" | |||
"github.com/seata/seata-go/pkg/datasource/sql/mock" | |||
"github.com/seata/seata-go/pkg/datasource/sql/types" | |||
"github.com/seata/seata-go/pkg/tm" | |||
"github.com/stretchr/testify/assert" | |||
) | |||
type mockSQLInterceptor struct { | |||
before func(ctx context.Context, execCtx *exec.ExecContext) | |||
after func(ctx context.Context, execCtx *exec.ExecContext) | |||
} | |||
func (mi *mockSQLInterceptor) Type() types.SQLType { | |||
return types.SQLTypeUnknown | |||
} | |||
// Before | |||
func (mi *mockSQLInterceptor) Before(ctx context.Context, execCtx *exec.ExecContext) { | |||
if mi.before != nil { | |||
mi.before(ctx, execCtx) | |||
} | |||
} | |||
// After | |||
func (mi *mockSQLInterceptor) After(ctx context.Context, execCtx *exec.ExecContext) { | |||
if mi.after != nil { | |||
mi.after(ctx, execCtx) | |||
} | |||
} | |||
type mockTxHook struct { | |||
beforeCommit func(tx *Tx) | |||
beforeRollback func(tx *Tx) | |||
} | |||
// BeforeCommit | |||
func (mi *mockTxHook) BeforeCommit(tx *Tx) { | |||
if mi.beforeCommit != nil { | |||
mi.beforeCommit(tx) | |||
} | |||
} | |||
// BeforeRollback | |||
func (mi *mockTxHook) BeforeRollback(tx *Tx) { | |||
if mi.beforeRollback != nil { | |||
mi.beforeRollback(tx) | |||
} | |||
} | |||
func baseMoclConn(mockConn *mock.MockTestDriverConn) { | |||
mockConn.EXPECT().ExecContext(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes().Return(&driver.ResultNoRows, nil) | |||
mockConn.EXPECT().Exec(gomock.Any(), gomock.Any()).AnyTimes().Return(&driver.ResultNoRows, nil) | |||
mockConn.EXPECT().ResetSession(gomock.Any()).AnyTimes().Return(nil) | |||
mockConn.EXPECT().Close().AnyTimes().Return(nil) | |||
} | |||
func TestXAConn_ExecContext(t *testing.T) { | |||
ctrl := gomock.NewController(t) | |||
defer ctrl.Finish() | |||
mockMgr := initMockResourceManager(t, ctrl) | |||
_ = mockMgr | |||
db, err := sql.Open("seata-xa-mysql", "root:seata_go@tcp(127.0.0.1:3306)/seata_go_test?multiStatements=true") | |||
if err != nil { | |||
t.Fatal(err) | |||
} | |||
defer db.Close() | |||
_ = initMockXaConnector(t, ctrl, db, func(t *testing.T, ctrl *gomock.Controller) *mock.MockTestDriverConnector { | |||
mockTx := mock.NewMockTestDriverTx(ctrl) | |||
mockTx.EXPECT().Commit().AnyTimes().Return(nil) | |||
mockTx.EXPECT().Rollback().AnyTimes().Return(nil) | |||
mockConn := mock.NewMockTestDriverConn(ctrl) | |||
mockConn.EXPECT().Begin().AnyTimes().Return(mockTx, nil) | |||
mockConn.EXPECT().BeginTx(gomock.Any(), gomock.Any()).AnyTimes().Return(mockTx, nil) | |||
baseMoclConn(mockConn) | |||
connector := mock.NewMockTestDriverConnector(ctrl) | |||
connector.EXPECT().Connect(gomock.Any()).AnyTimes().Return(mockConn, nil) | |||
return connector | |||
}) | |||
mi := &mockSQLInterceptor{} | |||
ti := &mockTxHook{} | |||
exec.CleanCommonHook() | |||
CleanTxHooks() | |||
exec.RegisCommonHook(mi) | |||
RegisterTxHook(ti) | |||
t.Run("have xid", func(t *testing.T) { | |||
ctx := tm.InitSeataContext(context.Background()) | |||
tm.SetXID(ctx, uuid.New().String()) | |||
t.Logf("set xid=%s", tm.GetXID(ctx)) | |||
before := func(_ context.Context, execCtx *exec.ExecContext) { | |||
t.Logf("on exec xid=%s", execCtx.TxCtx.XaID) | |||
assert.Equal(t, tm.GetXID(ctx), execCtx.TxCtx.XaID) | |||
assert.Equal(t, types.XAMode, execCtx.TxCtx.TransType) | |||
} | |||
mi.before = before | |||
var comitCnt int32 | |||
beforeCommit := func(tx *Tx) { | |||
atomic.AddInt32(&comitCnt, 1) | |||
assert.Equal(t, tx.ctx.TransType, types.XAMode) | |||
} | |||
ti.beforeCommit = beforeCommit | |||
conn, err := db.Conn(context.Background()) | |||
assert.NoError(t, err) | |||
_, err = conn.ExecContext(ctx, "SELECT 1") | |||
assert.NoError(t, err) | |||
_, err = db.ExecContext(ctx, "SELECT 1") | |||
assert.NoError(t, err) | |||
assert.Equal(t, int32(2), atomic.LoadInt32(&comitCnt)) | |||
}) | |||
t.Run("not xid", func(t *testing.T) { | |||
before := func(_ context.Context, execCtx *exec.ExecContext) { | |||
assert.Equal(t, "", execCtx.TxCtx.XaID) | |||
assert.Equal(t, types.Local, execCtx.TxCtx.TransType) | |||
} | |||
mi.before = before | |||
var comitCnt int32 | |||
beforeCommit := func(tx *Tx) { | |||
atomic.AddInt32(&comitCnt, 1) | |||
} | |||
ti.beforeCommit = beforeCommit | |||
conn, err := db.Conn(context.Background()) | |||
assert.NoError(t, err) | |||
_, err = conn.ExecContext(context.Background(), "SELECT 1") | |||
assert.NoError(t, err) | |||
_, err = db.ExecContext(context.Background(), "SELECT 1") | |||
assert.NoError(t, err) | |||
_, err = db.Exec("SELECT 1") | |||
assert.NoError(t, err) | |||
assert.Equal(t, int32(0), atomic.LoadInt32(&comitCnt)) | |||
}) | |||
} |
@@ -21,14 +21,77 @@ import ( | |||
"context" | |||
"database/sql/driver" | |||
"sync" | |||
"github.com/seata/seata-go/pkg/datasource/sql/types" | |||
) | |||
type seataATConnector struct { | |||
*seataConnector | |||
transType types.TransactionType | |||
} | |||
func (c *seataATConnector) Connect(ctx context.Context) (driver.Conn, error) { | |||
conn, err := c.seataConnector.Connect(ctx) | |||
if err != nil { | |||
return nil, err | |||
} | |||
_conn, _ := conn.(*Conn) | |||
return &ATConn{ | |||
Conn: _conn, | |||
}, nil | |||
} | |||
func (c *seataATConnector) Driver() driver.Driver { | |||
return &seataATDriver{ | |||
seataDriver: c.seataConnector.Driver().(*seataDriver), | |||
} | |||
} | |||
type seataXAConnector struct { | |||
*seataConnector | |||
transType types.TransactionType | |||
} | |||
func (c *seataXAConnector) Connect(ctx context.Context) (driver.Conn, error) { | |||
conn, err := c.seataConnector.Connect(ctx) | |||
if err != nil { | |||
return nil, err | |||
} | |||
_conn, _ := conn.(*Conn) | |||
return &XAConn{ | |||
Conn: _conn, | |||
}, nil | |||
} | |||
func (c *seataXAConnector) Driver() driver.Driver { | |||
return &seataXADriver{ | |||
seataDriver: c.seataConnector.Driver().(*seataDriver), | |||
} | |||
} | |||
// A Connector represents a driver in a fixed configuration | |||
// and can create any number of equivalent Conns for use | |||
// by multiple goroutines. | |||
// | |||
// A Connector can be passed to sql.OpenDB, to allow drivers | |||
// to implement their own sql.DB constructors, or returned by | |||
// DriverContext's OpenConnector method, to allow drivers | |||
// access to context and to avoid repeated parsing of driver | |||
// configuration. | |||
// | |||
// If a Connector implements io.Closer, the sql package's DB.Close | |||
// method will call Close and return error (if any). | |||
type seataConnector struct { | |||
conf *seataServerConfig | |||
res *DBResource | |||
once sync.Once | |||
driver driver.Driver | |||
target driver.Connector | |||
transType types.TransactionType | |||
conf *seataServerConfig | |||
res *DBResource | |||
once sync.Once | |||
driver driver.Driver | |||
target driver.Connector | |||
} | |||
// Connect returns a connection to the database. | |||
@@ -50,7 +113,7 @@ func (c *seataConnector) Connect(ctx context.Context) (driver.Conn, error) { | |||
return nil, err | |||
} | |||
return &Conn{targetConn: conn, res: c.res}, nil | |||
return &Conn{txType: types.Local, targetConn: conn, res: c.res}, nil | |||
} | |||
// Driver returns the underlying Driver of the Connector, | |||
@@ -62,5 +125,7 @@ func (c *seataConnector) Driver() driver.Driver { | |||
c.driver = d | |||
}) | |||
return &SeataDriver{target: c.driver} | |||
return &seataDriver{ | |||
target: c.driver, | |||
} | |||
} |
@@ -0,0 +1,129 @@ | |||
/* | |||
* Licensed to the Apache Software Foundation (ASF) under one or more | |||
* contributor license agreements. See the NOTICE file distributed with | |||
* this work for additional information regarding copyright ownership. | |||
* The ASF licenses this file to You under the Apache License, Version 2.0 | |||
* (the "License"); you may not use this file except in compliance with | |||
* the License. You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
package sql | |||
import ( | |||
"context" | |||
"database/sql" | |||
"database/sql/driver" | |||
"reflect" | |||
"testing" | |||
"github.com/golang/mock/gomock" | |||
"github.com/seata/seata-go/pkg/datasource/sql/mock" | |||
"github.com/seata/seata-go/pkg/datasource/sql/types" | |||
"github.com/stretchr/testify/assert" | |||
) | |||
type initConnectorFunc func(t *testing.T, ctrl *gomock.Controller) *mock.MockTestDriverConnector | |||
func initMockConnector(t *testing.T, ctrl *gomock.Controller) *mock.MockTestDriverConnector { | |||
mockConn := mock.NewMockTestDriverConn(ctrl) | |||
connector := mock.NewMockTestDriverConnector(ctrl) | |||
connector.EXPECT().Connect(gomock.Any()).AnyTimes().Return(mockConn, nil) | |||
return connector | |||
} | |||
func initMockAtConnector(t *testing.T, ctrl *gomock.Controller, db *sql.DB, f initConnectorFunc) driver.Connector { | |||
v := reflect.ValueOf(db) | |||
if v.Kind() == reflect.Ptr { | |||
v = v.Elem() | |||
} | |||
field := v.FieldByName("connector") | |||
fieldVal := GetUnexportedField(field) | |||
atConnector, ok := fieldVal.(*seataATConnector) | |||
assert.True(t, ok, "need return seata at connector") | |||
v = reflect.ValueOf(atConnector) | |||
if v.Kind() == reflect.Ptr { | |||
v = v.Elem() | |||
} | |||
SetUnexportedField(v.FieldByName("target"), f(t, ctrl)) | |||
return fieldVal.(driver.Connector) | |||
} | |||
func Test_seataATConnector_Connect(t *testing.T) { | |||
ctrl := gomock.NewController(t) | |||
defer ctrl.Finish() | |||
mockMgr := initMockResourceManager(t, ctrl) | |||
_ = mockMgr | |||
db, err := sql.Open("seata-at-mysql", "root:seata_go@tcp(127.0.0.1:3306)/seata_go_test?multiStatements=true") | |||
if err != nil { | |||
t.Fatal(err) | |||
} | |||
defer db.Close() | |||
proxyConnector := initMockAtConnector(t, ctrl, db, initMockConnector) | |||
conn, err := proxyConnector.Connect(context.Background()) | |||
assert.NoError(t, err) | |||
atConn, ok := conn.(*ATConn) | |||
assert.True(t, ok, "need return seata at connection") | |||
assert.True(t, atConn.txType == types.Local, "init need local tx") | |||
} | |||
func initMockXaConnector(t *testing.T, ctrl *gomock.Controller, db *sql.DB, f initConnectorFunc) driver.Connector { | |||
v := reflect.ValueOf(db) | |||
if v.Kind() == reflect.Ptr { | |||
v = v.Elem() | |||
} | |||
field := v.FieldByName("connector") | |||
fieldVal := GetUnexportedField(field) | |||
atConnector, ok := fieldVal.(*seataXAConnector) | |||
assert.True(t, ok, "need return seata xa connector") | |||
v = reflect.ValueOf(atConnector) | |||
if v.Kind() == reflect.Ptr { | |||
v = v.Elem() | |||
} | |||
SetUnexportedField(v.FieldByName("target"), f(t, ctrl)) | |||
return fieldVal.(driver.Connector) | |||
} | |||
func Test_seataXAConnector_Connect(t *testing.T) { | |||
ctrl := gomock.NewController(t) | |||
defer ctrl.Finish() | |||
mockMgr := initMockResourceManager(t, ctrl) | |||
_ = mockMgr | |||
db, err := sql.Open("seata-xa-mysql", "root:seata_go@tcp(127.0.0.1:3306)/seata_go_test?multiStatements=true") | |||
if err != nil { | |||
t.Fatal(err) | |||
} | |||
defer db.Close() | |||
proxyConnector := initMockXaConnector(t, ctrl, db, initMockConnector) | |||
conn, err := proxyConnector.Connect(context.Background()) | |||
assert.NoError(t, err) | |||
xaConn, ok := conn.(*XAConn) | |||
assert.True(t, ok, "need return seata xa connection") | |||
assert.True(t, xaConn.txType == types.Local, "init need local tx") | |||
} |
@@ -0,0 +1,29 @@ | |||
/* | |||
* Licensed to the Apache Software Foundation (ASF) under one or more | |||
* contributor license agreements. See the NOTICE file distributed with | |||
* this work for additional information regarding copyright ownership. | |||
* The ASF licenses this file to You under the Apache License, Version 2.0 | |||
* (the "License"); you may not use this file except in compliance with | |||
* the License. You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
package sql | |||
import ( | |||
"context" | |||
"github.com/seata/seata-go/pkg/tm" | |||
) | |||
// IsGlobalTx check is open global transactions | |||
func IsGlobalTx(ctx context.Context) bool { | |||
return tm.IsTransactionOpened(ctx) | |||
} |
@@ -35,20 +35,69 @@ import ( | |||
) | |||
const ( | |||
SeataMySQLDriver = "seata-mysql" | |||
// SeataATMySQLDriver MySQL driver for AT mode | |||
SeataATMySQLDriver = "seata-at-mysql" | |||
// SeataXAMySQLDriver MySQL driver for XA mode | |||
SeataXAMySQLDriver = "seata-xa-mysql" | |||
) | |||
func init() { | |||
sql.Register(SeataMySQLDriver, &SeataDriver{ | |||
target: mysql.MySQLDriver{}, | |||
sql.Register(SeataATMySQLDriver, &seataATDriver{ | |||
seataDriver: &seataDriver{ | |||
transType: types.ATMode, | |||
target: mysql.MySQLDriver{}, | |||
}, | |||
}) | |||
sql.Register(SeataXAMySQLDriver, &seataXADriver{ | |||
seataDriver: &seataDriver{ | |||
transType: types.XAMode, | |||
target: mysql.MySQLDriver{}, | |||
}, | |||
}) | |||
} | |||
type seataATDriver struct { | |||
*seataDriver | |||
} | |||
func (d *seataATDriver) OpenConnector(name string) (c driver.Connector, err error) { | |||
connector, err := d.seataDriver.OpenConnector(name) | |||
if err != nil { | |||
return nil, err | |||
} | |||
_connector, _ := connector.(*seataConnector) | |||
_connector.transType = types.ATMode | |||
return &seataATConnector{ | |||
seataConnector: _connector, | |||
}, nil | |||
} | |||
type seataXADriver struct { | |||
*seataDriver | |||
} | |||
func (d *seataXADriver) OpenConnector(name string) (c driver.Connector, err error) { | |||
connector, err := d.seataDriver.OpenConnector(name) | |||
if err != nil { | |||
return nil, err | |||
} | |||
_connector, _ := connector.(*seataConnector) | |||
_connector.transType = types.XAMode | |||
return &seataXAConnector{ | |||
seataConnector: _connector, | |||
}, nil | |||
} | |||
type SeataDriver struct { | |||
target driver.Driver | |||
type seataDriver struct { | |||
transType types.TransactionType | |||
target driver.Driver | |||
} | |||
func (d *SeataDriver) Open(name string) (driver.Conn, error) { | |||
func (d *seataDriver) Open(name string) (driver.Conn, error) { | |||
conn, err := d.target.Open(name) | |||
if err != nil { | |||
log.Errorf("open target connection: %w", err) | |||
@@ -71,7 +120,7 @@ func (d *SeataDriver) Open(name string) (driver.Conn, error) { | |||
return conn, nil | |||
} | |||
func (d *SeataDriver) OpenConnector(name string) (c driver.Connector, err error) { | |||
func (d *seataDriver) OpenConnector(name string) (c driver.Connector, err error) { | |||
c = &dsnConnector{dsn: name, driver: d.target} | |||
if driverCtx, ok := d.target.(driver.DriverContext); ok { | |||
c, err = driverCtx.OpenConnector(name) | |||
@@ -86,7 +135,7 @@ func (d *SeataDriver) OpenConnector(name string) (c driver.Connector, err error) | |||
return nil, fmt.Errorf("unsupport conn type %s", d.getTargetDriverName()) | |||
} | |||
proxy, err := registerResource(c, dbType, sql.OpenDB(c), name) | |||
proxy, err := registerResource(c, d.transType, dbType, sql.OpenDB(c), name) | |||
if err != nil { | |||
log.Errorf("register resource: %w", err) | |||
return nil, err | |||
@@ -95,8 +144,8 @@ func (d *SeataDriver) OpenConnector(name string) (c driver.Connector, err error) | |||
return proxy, nil | |||
} | |||
func (d *SeataDriver) getTargetDriverName() string { | |||
return strings.ReplaceAll(SeataMySQLDriver, "seata-", "") | |||
func (d *seataDriver) getTargetDriverName() string { | |||
return "mysql" | |||
} | |||
type dsnConnector struct { | |||
@@ -112,7 +161,7 @@ func (t *dsnConnector) Driver() driver.Driver { | |||
return t.driver | |||
} | |||
func registerResource(connector driver.Connector, dbType types.DBType, db *sql.DB, | |||
func registerResource(connector driver.Connector, txType types.TransactionType, dbType types.DBType, db *sql.DB, | |||
dataSourceName string, opts ...seataOption) (driver.Connector, error) { | |||
conf := loadConfig() | |||
for i := range opts { | |||
@@ -144,9 +193,9 @@ func registerResource(connector driver.Connector, dbType types.DBType, db *sql.D | |||
} | |||
return &seataConnector{ | |||
res: res, | |||
target: connector, | |||
conf: conf, | |||
res: res, | |||
target: connector, | |||
conf: conf, | |||
}, nil | |||
} | |||
@@ -0,0 +1,93 @@ | |||
/* | |||
* Licensed to the Apache Software Foundation (ASF) under one or more | |||
* contributor license agreements. See the NOTICE file distributed with | |||
* this work for additional information regarding copyright ownership. | |||
* The ASF licenses this file to You under the Apache License, Version 2.0 | |||
* (the "License"); you may not use this file except in compliance with | |||
* the License. You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
package sql | |||
import ( | |||
"database/sql" | |||
"reflect" | |||
"testing" | |||
"github.com/golang/mock/gomock" | |||
"github.com/seata/seata-go/pkg/datasource/sql/datasource" | |||
"github.com/seata/seata-go/pkg/datasource/sql/mock" | |||
"github.com/seata/seata-go/pkg/protocol/branch" | |||
"github.com/stretchr/testify/assert" | |||
) | |||
func initMockResourceManager(t *testing.T, ctrl *gomock.Controller) *mock.MockDataSourceManager { | |||
mockResourceMgr := mock.NewMockDataSourceManager(ctrl) | |||
datasource.RegisterResourceManager(branch.BranchTypeAT, mockResourceMgr) | |||
mockResourceMgr.EXPECT().RegisterResource(gomock.Any()).AnyTimes().Return(nil) | |||
mockResourceMgr.EXPECT().CreateTableMetaCache(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes().Return(nil, nil) | |||
return mockResourceMgr | |||
} | |||
func Test_seataATDriver_OpenConnector(t *testing.T) { | |||
ctrl := gomock.NewController(t) | |||
defer ctrl.Finish() | |||
mockMgr := initMockResourceManager(t, ctrl) | |||
_ = mockMgr | |||
db, err := sql.Open("seata-at-mysql", "root:seata_go@tcp(127.0.0.1:3306)/seata_go_test?multiStatements=true") | |||
if err != nil { | |||
t.Fatal(err) | |||
} | |||
defer db.Close() | |||
v := reflect.ValueOf(db) | |||
if v.Kind() == reflect.Ptr { | |||
v = v.Elem() | |||
} | |||
field := v.FieldByName("connector") | |||
fieldVal := GetUnexportedField(field) | |||
_, ok := fieldVal.(*seataATConnector) | |||
assert.True(t, ok, "need return seata at connector") | |||
} | |||
func Test_seataXADriver_OpenConnector(t *testing.T) { | |||
ctrl := gomock.NewController(t) | |||
defer ctrl.Finish() | |||
mockMgr := initMockResourceManager(t, ctrl) | |||
_ = mockMgr | |||
db, err := sql.Open("seata-xa-mysql", "root:seata_go@tcp(127.0.0.1:3306)/seata_go_test?multiStatements=true") | |||
if err != nil { | |||
t.Fatal(err) | |||
} | |||
defer db.Close() | |||
v := reflect.ValueOf(db) | |||
if v.Kind() == reflect.Ptr { | |||
v = v.Elem() | |||
} | |||
field := v.FieldByName("connector") | |||
fieldVal := GetUnexportedField(field) | |||
_, ok := fieldVal.(*seataXAConnector) | |||
assert.True(t, ok, "need return seata xa connector") | |||
} |
@@ -56,8 +56,17 @@ type ( | |||
} | |||
) | |||
// buildExecutor | |||
func BuildExecutor(dbType types.DBType, query string) (SQLExecutor, error) { | |||
// BuildExecutor | |||
func BuildExecutor(dbType types.DBType, txType types.TransactionType, query string) (SQLExecutor, error) { | |||
if txType == types.XAMode { | |||
hooks := make([]SQLInterceptor, 0, 4) | |||
hooks = append(hooks, commonHook...) | |||
e := &BaseExecutor{} | |||
e.interceptors(hooks) | |||
return e, nil | |||
} | |||
parseCtx, err := parser.DoParser(query) | |||
if err != nil { | |||
return nil, err | |||
@@ -94,12 +103,12 @@ type BaseExecutor struct { | |||
ex SQLExecutor | |||
} | |||
// Interceptors | |||
// interceptors | |||
func (e *BaseExecutor) interceptors(interceptors []SQLInterceptor) { | |||
e.is = interceptors | |||
} | |||
// Exec | |||
// ExecWithNamedValue | |||
func (e *BaseExecutor) ExecWithNamedValue(ctx context.Context, execCtx *ExecContext, f CallbackWithNamedValue) (types.ExecResult, error) { | |||
for i := range e.is { | |||
e.is[i].Before(ctx, execCtx) | |||
@@ -118,7 +127,7 @@ func (e *BaseExecutor) ExecWithNamedValue(ctx context.Context, execCtx *ExecCont | |||
return f(ctx, execCtx.Query, execCtx.NamedValues) | |||
} | |||
// Exec | |||
// ExecWithValue | |||
func (e *BaseExecutor) ExecWithValue(ctx context.Context, execCtx *ExecContext, f CallbackWithValue) (types.ExecResult, error) { | |||
for i := range e.is { | |||
e.is[i].Before(ctx, execCtx) | |||
@@ -35,6 +35,10 @@ func RegisCommonHook(hook SQLInterceptor) { | |||
commonHook = append(commonHook, hook) | |||
} | |||
func CleanCommonHook() { | |||
commonHook = make([]SQLInterceptor, 0, 4) | |||
} | |||
// RegisHook not goroutine safe | |||
func RegisHook(hook SQLInterceptor) { | |||
_, ok := hookSolts[hook.Type()] | |||
@@ -0,0 +1,237 @@ | |||
// Code generated by MockGen. DO NOT EDIT. | |||
// Source: datasource_manager.go | |||
// Package mock is a generated GoMock package. | |||
package mock | |||
import ( | |||
context "context" | |||
sql "database/sql" | |||
gomock "github.com/golang/mock/gomock" | |||
datasource "github.com/seata/seata-go/pkg/datasource/sql/datasource" | |||
types "github.com/seata/seata-go/pkg/datasource/sql/types" | |||
branch "github.com/seata/seata-go/pkg/protocol/branch" | |||
message "github.com/seata/seata-go/pkg/protocol/message" | |||
rm "github.com/seata/seata-go/pkg/rm" | |||
reflect "reflect" | |||
) | |||
// MockDataSourceManager is a mock of DataSourceManager interface | |||
type MockDataSourceManager struct { | |||
ctrl *gomock.Controller | |||
recorder *MockDataSourceManagerMockRecorder | |||
} | |||
// MockDataSourceManagerMockRecorder is the mock recorder for MockDataSourceManager | |||
type MockDataSourceManagerMockRecorder struct { | |||
mock *MockDataSourceManager | |||
} | |||
// NewMockDataSourceManager creates a new mock instance | |||
func NewMockDataSourceManager(ctrl *gomock.Controller) *MockDataSourceManager { | |||
mock := &MockDataSourceManager{ctrl: ctrl} | |||
mock.recorder = &MockDataSourceManagerMockRecorder{mock} | |||
return mock | |||
} | |||
// EXPECT returns an object that allows the caller to indicate expected use | |||
func (m *MockDataSourceManager) EXPECT() *MockDataSourceManagerMockRecorder { | |||
return m.recorder | |||
} | |||
// RegisterResource mocks base method | |||
func (m *MockDataSourceManager) RegisterResource(resource rm.Resource) error { | |||
m.ctrl.T.Helper() | |||
ret := m.ctrl.Call(m, "RegisterResource", resource) | |||
ret0, _ := ret[0].(error) | |||
return ret0 | |||
} | |||
// RegisterResource indicates an expected call of RegisterResource | |||
func (mr *MockDataSourceManagerMockRecorder) RegisterResource(resource interface{}) *gomock.Call { | |||
mr.mock.ctrl.T.Helper() | |||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegisterResource", reflect.TypeOf((*MockDataSourceManager)(nil).RegisterResource), resource) | |||
} | |||
// UnregisterResource mocks base method | |||
func (m *MockDataSourceManager) UnregisterResource(resource rm.Resource) error { | |||
m.ctrl.T.Helper() | |||
ret := m.ctrl.Call(m, "UnregisterResource", resource) | |||
ret0, _ := ret[0].(error) | |||
return ret0 | |||
} | |||
// UnregisterResource indicates an expected call of UnregisterResource | |||
func (mr *MockDataSourceManagerMockRecorder) UnregisterResource(resource interface{}) *gomock.Call { | |||
mr.mock.ctrl.T.Helper() | |||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UnregisterResource", reflect.TypeOf((*MockDataSourceManager)(nil).UnregisterResource), resource) | |||
} | |||
// GetManagedResources mocks base method | |||
func (m *MockDataSourceManager) GetManagedResources() map[string]rm.Resource { | |||
m.ctrl.T.Helper() | |||
ret := m.ctrl.Call(m, "GetManagedResources") | |||
ret0, _ := ret[0].(map[string]rm.Resource) | |||
return ret0 | |||
} | |||
// GetManagedResources indicates an expected call of GetManagedResources | |||
func (mr *MockDataSourceManagerMockRecorder) GetManagedResources() *gomock.Call { | |||
mr.mock.ctrl.T.Helper() | |||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetManagedResources", reflect.TypeOf((*MockDataSourceManager)(nil).GetManagedResources)) | |||
} | |||
// BranchRollback mocks base method | |||
func (m *MockDataSourceManager) BranchRollback(ctx context.Context, req message.BranchRollbackRequest) (branch.BranchStatus, error) { | |||
m.ctrl.T.Helper() | |||
ret := m.ctrl.Call(m, "BranchRollback", ctx, req) | |||
ret0, _ := ret[0].(branch.BranchStatus) | |||
ret1, _ := ret[1].(error) | |||
return ret0, ret1 | |||
} | |||
// BranchRollback indicates an expected call of BranchRollback | |||
func (mr *MockDataSourceManagerMockRecorder) BranchRollback(ctx, req interface{}) *gomock.Call { | |||
mr.mock.ctrl.T.Helper() | |||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BranchRollback", reflect.TypeOf((*MockDataSourceManager)(nil).BranchRollback), ctx, req) | |||
} | |||
// BranchCommit mocks base method | |||
func (m *MockDataSourceManager) BranchCommit(ctx context.Context, req message.BranchCommitRequest) (branch.BranchStatus, error) { | |||
m.ctrl.T.Helper() | |||
ret := m.ctrl.Call(m, "BranchCommit", ctx, req) | |||
ret0, _ := ret[0].(branch.BranchStatus) | |||
ret1, _ := ret[1].(error) | |||
return ret0, ret1 | |||
} | |||
// BranchCommit indicates an expected call of BranchCommit | |||
func (mr *MockDataSourceManagerMockRecorder) BranchCommit(ctx, req interface{}) *gomock.Call { | |||
mr.mock.ctrl.T.Helper() | |||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BranchCommit", reflect.TypeOf((*MockDataSourceManager)(nil).BranchCommit), ctx, req) | |||
} | |||
// LockQuery mocks base method | |||
func (m *MockDataSourceManager) LockQuery(ctx context.Context, req message.GlobalLockQueryRequest) (bool, error) { | |||
m.ctrl.T.Helper() | |||
ret := m.ctrl.Call(m, "LockQuery", ctx, req) | |||
ret0, _ := ret[0].(bool) | |||
ret1, _ := ret[1].(error) | |||
return ret0, ret1 | |||
} | |||
// LockQuery indicates an expected call of LockQuery | |||
func (mr *MockDataSourceManagerMockRecorder) LockQuery(ctx, req interface{}) *gomock.Call { | |||
mr.mock.ctrl.T.Helper() | |||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LockQuery", reflect.TypeOf((*MockDataSourceManager)(nil).LockQuery), ctx, req) | |||
} | |||
// BranchRegister mocks base method | |||
func (m *MockDataSourceManager) BranchRegister(ctx context.Context, clientId string, req message.BranchRegisterRequest) (int64, error) { | |||
m.ctrl.T.Helper() | |||
ret := m.ctrl.Call(m, "BranchRegister", ctx, clientId, req) | |||
ret0, _ := ret[0].(int64) | |||
ret1, _ := ret[1].(error) | |||
return ret0, ret1 | |||
} | |||
// BranchRegister indicates an expected call of BranchRegister | |||
func (mr *MockDataSourceManagerMockRecorder) BranchRegister(ctx, clientId, req interface{}) *gomock.Call { | |||
mr.mock.ctrl.T.Helper() | |||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BranchRegister", reflect.TypeOf((*MockDataSourceManager)(nil).BranchRegister), ctx, clientId, req) | |||
} | |||
// BranchReport mocks base method | |||
func (m *MockDataSourceManager) BranchReport(ctx context.Context, req message.BranchReportRequest) error { | |||
m.ctrl.T.Helper() | |||
ret := m.ctrl.Call(m, "BranchReport", ctx, req) | |||
ret0, _ := ret[0].(error) | |||
return ret0 | |||
} | |||
// BranchReport indicates an expected call of BranchReport | |||
func (mr *MockDataSourceManagerMockRecorder) BranchReport(ctx, req interface{}) *gomock.Call { | |||
mr.mock.ctrl.T.Helper() | |||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BranchReport", reflect.TypeOf((*MockDataSourceManager)(nil).BranchReport), ctx, req) | |||
} | |||
// CreateTableMetaCache mocks base method | |||
func (m *MockDataSourceManager) CreateTableMetaCache(ctx context.Context, resID string, dbType types.DBType, db *sql.DB) (datasource.TableMetaCache, error) { | |||
m.ctrl.T.Helper() | |||
ret := m.ctrl.Call(m, "CreateTableMetaCache", ctx, resID, dbType, db) | |||
ret0, _ := ret[0].(datasource.TableMetaCache) | |||
ret1, _ := ret[1].(error) | |||
return ret0, ret1 | |||
} | |||
// CreateTableMetaCache indicates an expected call of CreateTableMetaCache | |||
func (mr *MockDataSourceManagerMockRecorder) CreateTableMetaCache(ctx, resID, dbType, db interface{}) *gomock.Call { | |||
mr.mock.ctrl.T.Helper() | |||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateTableMetaCache", reflect.TypeOf((*MockDataSourceManager)(nil).CreateTableMetaCache), ctx, resID, dbType, db) | |||
} | |||
// MockTableMetaCache is a mock of TableMetaCache interface | |||
type MockTableMetaCache struct { | |||
ctrl *gomock.Controller | |||
recorder *MockTableMetaCacheMockRecorder | |||
} | |||
// MockTableMetaCacheMockRecorder is the mock recorder for MockTableMetaCache | |||
type MockTableMetaCacheMockRecorder struct { | |||
mock *MockTableMetaCache | |||
} | |||
// NewMockTableMetaCache creates a new mock instance | |||
func NewMockTableMetaCache(ctrl *gomock.Controller) *MockTableMetaCache { | |||
mock := &MockTableMetaCache{ctrl: ctrl} | |||
mock.recorder = &MockTableMetaCacheMockRecorder{mock} | |||
return mock | |||
} | |||
// EXPECT returns an object that allows the caller to indicate expected use | |||
func (m *MockTableMetaCache) EXPECT() *MockTableMetaCacheMockRecorder { | |||
return m.recorder | |||
} | |||
// Init mocks base method | |||
func (m *MockTableMetaCache) Init(ctx context.Context, conn *sql.DB) error { | |||
m.ctrl.T.Helper() | |||
ret := m.ctrl.Call(m, "Init", ctx, conn) | |||
ret0, _ := ret[0].(error) | |||
return ret0 | |||
} | |||
// Init indicates an expected call of Init | |||
func (mr *MockTableMetaCacheMockRecorder) Init(ctx, conn interface{}) *gomock.Call { | |||
mr.mock.ctrl.T.Helper() | |||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Init", reflect.TypeOf((*MockTableMetaCache)(nil).Init), ctx, conn) | |||
} | |||
// GetTableMeta mocks base method | |||
func (m *MockTableMetaCache) GetTableMeta(table string) (types.TableMeta, error) { | |||
m.ctrl.T.Helper() | |||
ret := m.ctrl.Call(m, "GetTableMeta", table) | |||
ret0, _ := ret[0].(types.TableMeta) | |||
ret1, _ := ret[1].(error) | |||
return ret0, ret1 | |||
} | |||
// GetTableMeta indicates an expected call of GetTableMeta | |||
func (mr *MockTableMetaCacheMockRecorder) GetTableMeta(table interface{}) *gomock.Call { | |||
mr.mock.ctrl.T.Helper() | |||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTableMeta", reflect.TypeOf((*MockTableMetaCache)(nil).GetTableMeta), table) | |||
} | |||
// Destroy mocks base method | |||
func (m *MockTableMetaCache) Destroy() error { | |||
m.ctrl.T.Helper() | |||
ret := m.ctrl.Call(m, "Destroy") | |||
ret0, _ := ret[0].(error) | |||
return ret0 | |||
} | |||
// Destroy indicates an expected call of Destroy | |||
func (mr *MockTableMetaCacheMockRecorder) Destroy() *gomock.Call { | |||
mr.mock.ctrl.T.Helper() | |||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Destroy", reflect.TypeOf((*MockTableMetaCache)(nil).Destroy)) | |||
} |
@@ -0,0 +1,411 @@ | |||
// Code generated by MockGen. DO NOT EDIT. | |||
// Source: test_driver.go | |||
// Package mock is a generated GoMock package. | |||
package mock | |||
import ( | |||
context "context" | |||
driver "database/sql/driver" | |||
gomock "github.com/golang/mock/gomock" | |||
reflect "reflect" | |||
) | |||
// MockTestDriverConnector is a mock of TestDriverConnector interface | |||
type MockTestDriverConnector struct { | |||
ctrl *gomock.Controller | |||
recorder *MockTestDriverConnectorMockRecorder | |||
} | |||
// MockTestDriverConnectorMockRecorder is the mock recorder for MockTestDriverConnector | |||
type MockTestDriverConnectorMockRecorder struct { | |||
mock *MockTestDriverConnector | |||
} | |||
// NewMockTestDriverConnector creates a new mock instance | |||
func NewMockTestDriverConnector(ctrl *gomock.Controller) *MockTestDriverConnector { | |||
mock := &MockTestDriverConnector{ctrl: ctrl} | |||
mock.recorder = &MockTestDriverConnectorMockRecorder{mock} | |||
return mock | |||
} | |||
// EXPECT returns an object that allows the caller to indicate expected use | |||
func (m *MockTestDriverConnector) EXPECT() *MockTestDriverConnectorMockRecorder { | |||
return m.recorder | |||
} | |||
// Connect mocks base method | |||
func (m *MockTestDriverConnector) Connect(arg0 context.Context) (driver.Conn, error) { | |||
m.ctrl.T.Helper() | |||
ret := m.ctrl.Call(m, "Connect", arg0) | |||
ret0, _ := ret[0].(driver.Conn) | |||
ret1, _ := ret[1].(error) | |||
return ret0, ret1 | |||
} | |||
// Connect indicates an expected call of Connect | |||
func (mr *MockTestDriverConnectorMockRecorder) Connect(arg0 interface{}) *gomock.Call { | |||
mr.mock.ctrl.T.Helper() | |||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Connect", reflect.TypeOf((*MockTestDriverConnector)(nil).Connect), arg0) | |||
} | |||
// Driver mocks base method | |||
func (m *MockTestDriverConnector) Driver() driver.Driver { | |||
m.ctrl.T.Helper() | |||
ret := m.ctrl.Call(m, "Driver") | |||
ret0, _ := ret[0].(driver.Driver) | |||
return ret0 | |||
} | |||
// Driver indicates an expected call of Driver | |||
func (mr *MockTestDriverConnectorMockRecorder) Driver() *gomock.Call { | |||
mr.mock.ctrl.T.Helper() | |||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Driver", reflect.TypeOf((*MockTestDriverConnector)(nil).Driver)) | |||
} | |||
// MockTestDriverConn is a mock of TestDriverConn interface | |||
type MockTestDriverConn struct { | |||
ctrl *gomock.Controller | |||
recorder *MockTestDriverConnMockRecorder | |||
} | |||
// MockTestDriverConnMockRecorder is the mock recorder for MockTestDriverConn | |||
type MockTestDriverConnMockRecorder struct { | |||
mock *MockTestDriverConn | |||
} | |||
// NewMockTestDriverConn creates a new mock instance | |||
func NewMockTestDriverConn(ctrl *gomock.Controller) *MockTestDriverConn { | |||
mock := &MockTestDriverConn{ctrl: ctrl} | |||
mock.recorder = &MockTestDriverConnMockRecorder{mock} | |||
return mock | |||
} | |||
// EXPECT returns an object that allows the caller to indicate expected use | |||
func (m *MockTestDriverConn) EXPECT() *MockTestDriverConnMockRecorder { | |||
return m.recorder | |||
} | |||
// Prepare mocks base method | |||
func (m *MockTestDriverConn) Prepare(query string) (driver.Stmt, error) { | |||
m.ctrl.T.Helper() | |||
ret := m.ctrl.Call(m, "Prepare", query) | |||
ret0, _ := ret[0].(driver.Stmt) | |||
ret1, _ := ret[1].(error) | |||
return ret0, ret1 | |||
} | |||
// Prepare indicates an expected call of Prepare | |||
func (mr *MockTestDriverConnMockRecorder) Prepare(query interface{}) *gomock.Call { | |||
mr.mock.ctrl.T.Helper() | |||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Prepare", reflect.TypeOf((*MockTestDriverConn)(nil).Prepare), query) | |||
} | |||
// Close mocks base method | |||
func (m *MockTestDriverConn) Close() error { | |||
m.ctrl.T.Helper() | |||
ret := m.ctrl.Call(m, "Close") | |||
ret0, _ := ret[0].(error) | |||
return ret0 | |||
} | |||
// Close indicates an expected call of Close | |||
func (mr *MockTestDriverConnMockRecorder) Close() *gomock.Call { | |||
mr.mock.ctrl.T.Helper() | |||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockTestDriverConn)(nil).Close)) | |||
} | |||
// Begin mocks base method | |||
func (m *MockTestDriverConn) Begin() (driver.Tx, error) { | |||
m.ctrl.T.Helper() | |||
ret := m.ctrl.Call(m, "Begin") | |||
ret0, _ := ret[0].(driver.Tx) | |||
ret1, _ := ret[1].(error) | |||
return ret0, ret1 | |||
} | |||
// Begin indicates an expected call of Begin | |||
func (mr *MockTestDriverConnMockRecorder) Begin() *gomock.Call { | |||
mr.mock.ctrl.T.Helper() | |||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Begin", reflect.TypeOf((*MockTestDriverConn)(nil).Begin)) | |||
} | |||
// BeginTx mocks base method | |||
func (m *MockTestDriverConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { | |||
m.ctrl.T.Helper() | |||
ret := m.ctrl.Call(m, "BeginTx", ctx, opts) | |||
ret0, _ := ret[0].(driver.Tx) | |||
ret1, _ := ret[1].(error) | |||
return ret0, ret1 | |||
} | |||
// BeginTx indicates an expected call of BeginTx | |||
func (mr *MockTestDriverConnMockRecorder) BeginTx(ctx, opts interface{}) *gomock.Call { | |||
mr.mock.ctrl.T.Helper() | |||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BeginTx", reflect.TypeOf((*MockTestDriverConn)(nil).BeginTx), ctx, opts) | |||
} | |||
// Ping mocks base method | |||
func (m *MockTestDriverConn) Ping(ctx context.Context) error { | |||
m.ctrl.T.Helper() | |||
ret := m.ctrl.Call(m, "Ping", ctx) | |||
ret0, _ := ret[0].(error) | |||
return ret0 | |||
} | |||
// Ping indicates an expected call of Ping | |||
func (mr *MockTestDriverConnMockRecorder) Ping(ctx interface{}) *gomock.Call { | |||
mr.mock.ctrl.T.Helper() | |||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Ping", reflect.TypeOf((*MockTestDriverConn)(nil).Ping), ctx) | |||
} | |||
// PrepareContext mocks base method | |||
func (m *MockTestDriverConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { | |||
m.ctrl.T.Helper() | |||
ret := m.ctrl.Call(m, "PrepareContext", ctx, query) | |||
ret0, _ := ret[0].(driver.Stmt) | |||
ret1, _ := ret[1].(error) | |||
return ret0, ret1 | |||
} | |||
// PrepareContext indicates an expected call of PrepareContext | |||
func (mr *MockTestDriverConnMockRecorder) PrepareContext(ctx, query interface{}) *gomock.Call { | |||
mr.mock.ctrl.T.Helper() | |||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PrepareContext", reflect.TypeOf((*MockTestDriverConn)(nil).PrepareContext), ctx, query) | |||
} | |||
// Query mocks base method | |||
func (m *MockTestDriverConn) Query(query string, args []driver.Value) (driver.Rows, error) { | |||
m.ctrl.T.Helper() | |||
ret := m.ctrl.Call(m, "Query", query, args) | |||
ret0, _ := ret[0].(driver.Rows) | |||
ret1, _ := ret[1].(error) | |||
return ret0, ret1 | |||
} | |||
// Query indicates an expected call of Query | |||
func (mr *MockTestDriverConnMockRecorder) Query(query, args interface{}) *gomock.Call { | |||
mr.mock.ctrl.T.Helper() | |||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Query", reflect.TypeOf((*MockTestDriverConn)(nil).Query), query, args) | |||
} | |||
// QueryContext mocks base method | |||
func (m *MockTestDriverConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { | |||
m.ctrl.T.Helper() | |||
ret := m.ctrl.Call(m, "QueryContext", ctx, query, args) | |||
ret0, _ := ret[0].(driver.Rows) | |||
ret1, _ := ret[1].(error) | |||
return ret0, ret1 | |||
} | |||
// QueryContext indicates an expected call of QueryContext | |||
func (mr *MockTestDriverConnMockRecorder) QueryContext(ctx, query, args interface{}) *gomock.Call { | |||
mr.mock.ctrl.T.Helper() | |||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryContext", reflect.TypeOf((*MockTestDriverConn)(nil).QueryContext), ctx, query, args) | |||
} | |||
// Exec mocks base method | |||
func (m *MockTestDriverConn) Exec(query string, args []driver.Value) (driver.Result, error) { | |||
m.ctrl.T.Helper() | |||
ret := m.ctrl.Call(m, "Exec", query, args) | |||
ret0, _ := ret[0].(driver.Result) | |||
ret1, _ := ret[1].(error) | |||
return ret0, ret1 | |||
} | |||
// Exec indicates an expected call of Exec | |||
func (mr *MockTestDriverConnMockRecorder) Exec(query, args interface{}) *gomock.Call { | |||
mr.mock.ctrl.T.Helper() | |||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockTestDriverConn)(nil).Exec), query, args) | |||
} | |||
// ExecContext mocks base method | |||
func (m *MockTestDriverConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { | |||
m.ctrl.T.Helper() | |||
ret := m.ctrl.Call(m, "ExecContext", ctx, query, args) | |||
ret0, _ := ret[0].(driver.Result) | |||
ret1, _ := ret[1].(error) | |||
return ret0, ret1 | |||
} | |||
// ExecContext indicates an expected call of ExecContext | |||
func (mr *MockTestDriverConnMockRecorder) ExecContext(ctx, query, args interface{}) *gomock.Call { | |||
mr.mock.ctrl.T.Helper() | |||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ExecContext", reflect.TypeOf((*MockTestDriverConn)(nil).ExecContext), ctx, query, args) | |||
} | |||
// ResetSession mocks base method | |||
func (m *MockTestDriverConn) ResetSession(ctx context.Context) error { | |||
m.ctrl.T.Helper() | |||
ret := m.ctrl.Call(m, "ResetSession", ctx) | |||
ret0, _ := ret[0].(error) | |||
return ret0 | |||
} | |||
// ResetSession indicates an expected call of ResetSession | |||
func (mr *MockTestDriverConnMockRecorder) ResetSession(ctx interface{}) *gomock.Call { | |||
mr.mock.ctrl.T.Helper() | |||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ResetSession", reflect.TypeOf((*MockTestDriverConn)(nil).ResetSession), ctx) | |||
} | |||
// MockTestDriverStmt is a mock of TestDriverStmt interface | |||
type MockTestDriverStmt struct { | |||
ctrl *gomock.Controller | |||
recorder *MockTestDriverStmtMockRecorder | |||
} | |||
// MockTestDriverStmtMockRecorder is the mock recorder for MockTestDriverStmt | |||
type MockTestDriverStmtMockRecorder struct { | |||
mock *MockTestDriverStmt | |||
} | |||
// NewMockTestDriverStmt creates a new mock instance | |||
func NewMockTestDriverStmt(ctrl *gomock.Controller) *MockTestDriverStmt { | |||
mock := &MockTestDriverStmt{ctrl: ctrl} | |||
mock.recorder = &MockTestDriverStmtMockRecorder{mock} | |||
return mock | |||
} | |||
// EXPECT returns an object that allows the caller to indicate expected use | |||
func (m *MockTestDriverStmt) EXPECT() *MockTestDriverStmtMockRecorder { | |||
return m.recorder | |||
} | |||
// Close mocks base method | |||
func (m *MockTestDriverStmt) Close() error { | |||
m.ctrl.T.Helper() | |||
ret := m.ctrl.Call(m, "Close") | |||
ret0, _ := ret[0].(error) | |||
return ret0 | |||
} | |||
// Close indicates an expected call of Close | |||
func (mr *MockTestDriverStmtMockRecorder) Close() *gomock.Call { | |||
mr.mock.ctrl.T.Helper() | |||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockTestDriverStmt)(nil).Close)) | |||
} | |||
// NumInput mocks base method | |||
func (m *MockTestDriverStmt) NumInput() int { | |||
m.ctrl.T.Helper() | |||
ret := m.ctrl.Call(m, "NumInput") | |||
ret0, _ := ret[0].(int) | |||
return ret0 | |||
} | |||
// NumInput indicates an expected call of NumInput | |||
func (mr *MockTestDriverStmtMockRecorder) NumInput() *gomock.Call { | |||
mr.mock.ctrl.T.Helper() | |||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NumInput", reflect.TypeOf((*MockTestDriverStmt)(nil).NumInput)) | |||
} | |||
// Exec mocks base method | |||
func (m *MockTestDriverStmt) Exec(args []driver.Value) (driver.Result, error) { | |||
m.ctrl.T.Helper() | |||
ret := m.ctrl.Call(m, "Exec", args) | |||
ret0, _ := ret[0].(driver.Result) | |||
ret1, _ := ret[1].(error) | |||
return ret0, ret1 | |||
} | |||
// Exec indicates an expected call of Exec | |||
func (mr *MockTestDriverStmtMockRecorder) Exec(args interface{}) *gomock.Call { | |||
mr.mock.ctrl.T.Helper() | |||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockTestDriverStmt)(nil).Exec), args) | |||
} | |||
// Query mocks base method | |||
func (m *MockTestDriverStmt) Query(args []driver.Value) (driver.Rows, error) { | |||
m.ctrl.T.Helper() | |||
ret := m.ctrl.Call(m, "Query", args) | |||
ret0, _ := ret[0].(driver.Rows) | |||
ret1, _ := ret[1].(error) | |||
return ret0, ret1 | |||
} | |||
// Query indicates an expected call of Query | |||
func (mr *MockTestDriverStmtMockRecorder) Query(args interface{}) *gomock.Call { | |||
mr.mock.ctrl.T.Helper() | |||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Query", reflect.TypeOf((*MockTestDriverStmt)(nil).Query), args) | |||
} | |||
// QueryContext mocks base method | |||
func (m *MockTestDriverStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { | |||
m.ctrl.T.Helper() | |||
ret := m.ctrl.Call(m, "QueryContext", ctx, args) | |||
ret0, _ := ret[0].(driver.Rows) | |||
ret1, _ := ret[1].(error) | |||
return ret0, ret1 | |||
} | |||
// QueryContext indicates an expected call of QueryContext | |||
func (mr *MockTestDriverStmtMockRecorder) QueryContext(ctx, args interface{}) *gomock.Call { | |||
mr.mock.ctrl.T.Helper() | |||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryContext", reflect.TypeOf((*MockTestDriverStmt)(nil).QueryContext), ctx, args) | |||
} | |||
// ExecContext mocks base method | |||
func (m *MockTestDriverStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { | |||
m.ctrl.T.Helper() | |||
ret := m.ctrl.Call(m, "ExecContext", ctx, args) | |||
ret0, _ := ret[0].(driver.Result) | |||
ret1, _ := ret[1].(error) | |||
return ret0, ret1 | |||
} | |||
// ExecContext indicates an expected call of ExecContext | |||
func (mr *MockTestDriverStmtMockRecorder) ExecContext(ctx, args interface{}) *gomock.Call { | |||
mr.mock.ctrl.T.Helper() | |||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ExecContext", reflect.TypeOf((*MockTestDriverStmt)(nil).ExecContext), ctx, args) | |||
} | |||
// MockTestDriverTx is a mock of TestDriverTx interface | |||
type MockTestDriverTx struct { | |||
ctrl *gomock.Controller | |||
recorder *MockTestDriverTxMockRecorder | |||
} | |||
// MockTestDriverTxMockRecorder is the mock recorder for MockTestDriverTx | |||
type MockTestDriverTxMockRecorder struct { | |||
mock *MockTestDriverTx | |||
} | |||
// NewMockTestDriverTx creates a new mock instance | |||
func NewMockTestDriverTx(ctrl *gomock.Controller) *MockTestDriverTx { | |||
mock := &MockTestDriverTx{ctrl: ctrl} | |||
mock.recorder = &MockTestDriverTxMockRecorder{mock} | |||
return mock | |||
} | |||
// EXPECT returns an object that allows the caller to indicate expected use | |||
func (m *MockTestDriverTx) EXPECT() *MockTestDriverTxMockRecorder { | |||
return m.recorder | |||
} | |||
// Commit mocks base method | |||
func (m *MockTestDriverTx) Commit() error { | |||
m.ctrl.T.Helper() | |||
ret := m.ctrl.Call(m, "Commit") | |||
ret0, _ := ret[0].(error) | |||
return ret0 | |||
} | |||
// Commit indicates an expected call of Commit | |||
func (mr *MockTestDriverTxMockRecorder) Commit() *gomock.Call { | |||
mr.mock.ctrl.T.Helper() | |||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Commit", reflect.TypeOf((*MockTestDriverTx)(nil).Commit)) | |||
} | |||
// Rollback mocks base method | |||
func (m *MockTestDriverTx) Rollback() error { | |||
m.ctrl.T.Helper() | |||
ret := m.ctrl.Call(m, "Rollback") | |||
ret0, _ := ret[0].(error) | |||
return ret0 | |||
} | |||
// Rollback indicates an expected call of Rollback | |||
func (mr *MockTestDriverTxMockRecorder) Rollback() *gomock.Call { | |||
mr.mock.ctrl.T.Helper() | |||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Rollback", reflect.TypeOf((*MockTestDriverTx)(nil).Rollback)) | |||
} |
@@ -0,0 +1,46 @@ | |||
/* | |||
* Licensed to the Apache Software Foundation (ASF) under one or more | |||
* contributor license agreements. See the NOTICE file distributed with | |||
* this work for additional information regarding copyright ownership. | |||
* The ASF licenses this file to You under the Apache License, Version 2.0 | |||
* (the "License"); you may not use this file except in compliance with | |||
* the License. You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
package mock | |||
import "database/sql/driver" | |||
type TestDriverConnector interface { | |||
driver.Connector | |||
} | |||
type TestDriverConn interface { | |||
driver.Conn | |||
driver.ConnBeginTx | |||
driver.Pinger | |||
driver.ConnPrepareContext | |||
driver.Queryer | |||
driver.QueryerContext | |||
driver.Execer | |||
driver.ExecerContext | |||
driver.SessionResetter | |||
} | |||
type TestDriverStmt interface { | |||
driver.Stmt | |||
driver.StmtQueryContext | |||
driver.StmtExecContext | |||
} | |||
type TestDriverTx interface { | |||
driver.Tx | |||
} |
@@ -36,7 +36,7 @@ func Test_SQLOpen(t *testing.T) { | |||
t.SkipNow() | |||
log.Info("begin test") | |||
var err error | |||
db, err = sql.Open(SeataMySQLDriver, "root:12345678@tcp(127.0.0.1:3306)/polaris_server?multiStatements=true") | |||
db, err = sql.Open("seata-at-mysql", "root:seata_go@tcp(127.0.0.1:3306)/seata_go_test?multiStatements=true") | |||
if err != nil { | |||
t.Fatal(err) | |||
} | |||
@@ -67,7 +67,7 @@ func (s *Stmt) NumInput() int { | |||
// | |||
// Deprecated: Drivers should implement StmtQueryContext instead (or additionally). | |||
func (s *Stmt) Query(args []driver.Value) (driver.Rows, error) { | |||
executor, err := exec.BuildExecutor(s.res.dbType, s.query) | |||
executor, err := exec.BuildExecutor(s.res.dbType, s.txCtx.TransType, s.query) | |||
if err != nil { | |||
return nil, err | |||
} | |||
@@ -105,7 +105,7 @@ func (s *Stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driv | |||
return nil, driver.ErrSkip | |||
} | |||
executor, err := exec.BuildExecutor(s.res.dbType, s.query) | |||
executor, err := exec.BuildExecutor(s.res.dbType, s.txCtx.TransType, s.query) | |||
if err != nil { | |||
return nil, err | |||
} | |||
@@ -138,7 +138,7 @@ func (s *Stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driv | |||
// Deprecated: Drivers should implement StmtExecContext instead (or additionally). | |||
func (s *Stmt) Exec(args []driver.Value) (driver.Result, error) { | |||
// in transaction, need run Executor | |||
executor, err := exec.BuildExecutor(s.res.dbType, s.query) | |||
executor, err := exec.BuildExecutor(s.res.dbType, s.txCtx.TransType, s.query) | |||
if err != nil { | |||
return nil, err | |||
} | |||
@@ -173,7 +173,7 @@ func (s *Stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (drive | |||
} | |||
// in transaction, need run Executor | |||
executor, err := exec.BuildExecutor(s.res.dbType, s.query) | |||
executor, err := exec.BuildExecutor(s.res.dbType, s.txCtx.TransType, s.query) | |||
if err != nil { | |||
return nil, err | |||
} | |||
@@ -20,6 +20,7 @@ package sql | |||
import ( | |||
"context" | |||
"database/sql/driver" | |||
"sync" | |||
"github.com/seata/seata-go/pkg/datasource/sql/undo" | |||
@@ -34,7 +35,34 @@ import ( | |||
const REPORT_RETRY_COUNT = 5 | |||
type txOption func(tx *Tx) | |||
var ( | |||
hl sync.RWMutex | |||
txHooks []txHook | |||
) | |||
func RegisterTxHook(h txHook) { | |||
hl.Lock() | |||
defer hl.Unlock() | |||
txHooks = append(txHooks, h) | |||
} | |||
func CleanTxHooks() { | |||
hl.Lock() | |||
defer hl.Unlock() | |||
txHooks = make([]txHook, 0, 4) | |||
} | |||
type ( | |||
txOption func(tx *Tx) | |||
txHook interface { | |||
BeforeCommit(tx *Tx) | |||
BeforeRollback(tx *Tx) | |||
} | |||
) | |||
func newTx(opts ...txOption) (driver.Tx, error) { | |||
tx := new(Tx) | |||
@@ -83,6 +111,15 @@ type Tx struct { | |||
// case 2. not need flush undolog, is XA mode, do local transaction commit | |||
// case 3. need run AT transaction | |||
func (tx *Tx) Commit() error { | |||
if len(txHooks) != 0 { | |||
hl.RLock() | |||
defer hl.RUnlock() | |||
for i := range txHooks { | |||
txHooks[i].BeforeCommit(tx) | |||
} | |||
} | |||
if tx.ctx.TransType == types.Local { | |||
return tx.commitOnLocal() | |||
} | |||
@@ -96,6 +133,15 @@ func (tx *Tx) Commit() error { | |||
} | |||
func (tx *Tx) Rollback() error { | |||
if len(txHooks) != 0 { | |||
hl.RLock() | |||
defer hl.RUnlock() | |||
for i := range txHooks { | |||
txHooks[i].BeforeRollback(tx) | |||
} | |||
} | |||
err := tx.target.Rollback() | |||
if err != nil { | |||
if tx.ctx.OpenGlobalTrsnaction() && tx.ctx.IsBranchRegistered() { | |||
@@ -124,7 +170,7 @@ func (tx *Tx) commitOnXA() error { | |||
// commitOnAT | |||
func (tx *Tx) commitOnAT() error { | |||
// if TX-Mode is AT, run regis this transaction branch | |||
if err := tx.regis(tx.ctx); err != nil { | |||
if err := tx.register(tx.ctx); err != nil { | |||
return err | |||
} | |||
@@ -133,7 +179,7 @@ func (tx *Tx) commitOnAT() error { | |||
return err | |||
} | |||
if err := undoLogMgr.FlushUndoLog(tx.ctx, nil); err != nil { | |||
if err := undoLogMgr.FlushUndoLog(tx.ctx, tx.conn.targetConn); err != nil { | |||
if rerr := tx.report(false); rerr != nil { | |||
return errors.WithStack(rerr) | |||
} | |||
@@ -151,8 +197,8 @@ func (tx *Tx) commitOnAT() error { | |||
return nil | |||
} | |||
// regis | |||
func (tx *Tx) regis(ctx *types.TransactionContext) error { | |||
// register | |||
func (tx *Tx) register(ctx *types.TransactionContext) error { | |||
if !ctx.HasUndoLog() || !ctx.HasLockKey() { | |||
return nil | |||
} | |||
@@ -103,7 +103,7 @@ type TransactionContext struct { | |||
func NewTxCtx() *TransactionContext { | |||
return &TransactionContext{ | |||
LockKeys: make([]string, 0, 4), | |||
TransType: ATMode, | |||
TransType: Local, | |||
LocalTransID: uuid.New().String(), | |||
RoundImages: &RoundRecordImage{}, | |||
} | |||
@@ -48,7 +48,7 @@ func (m *BaseUndoLogManager) Init() { | |||
} | |||
// InsertUndoLog | |||
func (m *BaseUndoLogManager) InsertUndoLog(l []undo.BranchUndoLog, tx driver.Tx) error { | |||
func (m *BaseUndoLogManager) InsertUndoLog(l []undo.BranchUndoLog, tx driver.Conn) error { | |||
return nil | |||
} | |||
@@ -102,7 +102,7 @@ func (m *BaseUndoLogManager) BatchDeleteUndoLog(xid []string, branchID []int64, | |||
} | |||
// FlushUndoLog | |||
func (m *BaseUndoLogManager) FlushUndoLog(txCtx *types.TransactionContext, tx driver.Tx) error { | |||
func (m *BaseUndoLogManager) FlushUndoLog(txCtx *types.TransactionContext, tx driver.Conn) error { | |||
return nil | |||
} | |||
@@ -39,7 +39,7 @@ func (m *undoLogManager) Init() { | |||
} | |||
// InsertUndoLog | |||
func (m *undoLogManager) InsertUndoLog(l []undo.BranchUndoLog, tx driver.Tx) error { | |||
func (m *undoLogManager) InsertUndoLog(l []undo.BranchUndoLog, tx driver.Conn) error { | |||
return m.Base.InsertUndoLog(l, tx) | |||
} | |||
@@ -54,7 +54,7 @@ func (m *undoLogManager) BatchDeleteUndoLog(xid []string, branchID []int64, conn | |||
} | |||
// FlushUndoLog | |||
func (m *undoLogManager) FlushUndoLog(txCtx *types.TransactionContext, tx driver.Tx) error { | |||
func (m *undoLogManager) FlushUndoLog(txCtx *types.TransactionContext, tx driver.Conn) error { | |||
return m.Base.FlushUndoLog(txCtx, tx) | |||
} | |||
@@ -50,13 +50,13 @@ func Regis(m UndoLogManager) error { | |||
type UndoLogManager interface { | |||
Init() | |||
// InsertUndoLog | |||
InsertUndoLog(l []BranchUndoLog, tx driver.Tx) error | |||
InsertUndoLog(l []BranchUndoLog, tx driver.Conn) error | |||
// DeleteUndoLog | |||
DeleteUndoLog(ctx context.Context, xid string, branchID int64, conn *sql.Conn) error | |||
// BatchDeleteUndoLog | |||
BatchDeleteUndoLog(xid []string, branchID []int64, conn *sql.Conn) error | |||
// FlushUndoLog | |||
FlushUndoLog(txCtx *types.TransactionContext, tx driver.Tx) error | |||
FlushUndoLog(txCtx *types.TransactionContext, tx driver.Conn) error | |||
// RunUndo | |||
RunUndo(xid string, branchID int64, conn *sql.Conn) error | |||
// DBType | |||
@@ -32,7 +32,7 @@ func TestBatchDeleteUndoLogs(t *testing.T) { | |||
t.SkipNow() | |||
testBatchDeleteUndoLogs := func() { | |||
db, err := sql.Open(SeataMySQLDriver, "root:12345678@tcp(127.0.0.1:3306)/seata_order?multiStatements=true") | |||
db, err := sql.Open(SeataATMySQLDriver, "root:12345678@tcp(127.0.0.1:3306)/seata_order?multiStatements=true") | |||
assert.Nil(t, err) | |||
sqlConn, err := db.Conn(context.Background()) | |||
@@ -54,7 +54,7 @@ func TestDeleteUndoLogs(t *testing.T) { | |||
t.SkipNow() | |||
testDeleteUndoLogs := func() { | |||
db, err := sql.Open(SeataMySQLDriver, "root:12345678@tcp(127.0.0.1:3306)/seata_order?multiStatements=true") | |||
db, err := sql.Open(SeataATMySQLDriver, "root:12345678@tcp(127.0.0.1:3306)/seata_order?multiStatements=true") | |||
assert.Nil(t, err) | |||
ctx := context.Background() | |||