@@ -43,7 +43,7 @@ func InitPath(configFilePath string) { | |||
initRmClient(cfg) | |||
initTmClient(cfg) | |||
initDatasource(cfg) | |||
initDatasource() | |||
} | |||
var ( | |||
@@ -84,10 +84,11 @@ func initRmClient(cfg *Config) { | |||
integration.Init() | |||
tcc.InitTCC() | |||
at.InitAT(cfg.ClientConfig.UndoConfig, cfg.AsyncWorkerConfig) | |||
at.InitXA(cfg.ClientConfig.XaConfig) | |||
}) | |||
} | |||
func initDatasource(cfg *Config) { | |||
func initDatasource() { | |||
onceInitDatasource.Do(func() { | |||
datasource.Init() | |||
}) | |||
@@ -54,15 +54,17 @@ const ( | |||
) | |||
type ClientConfig struct { | |||
TmConfig tm.TmConfig `yaml:"tm" json:"tm,omitempty" koanf:"tm"` | |||
RmConfig rm.Config `yaml:"rm" json:"rm,omitempty" koanf:"rm"` | |||
UndoConfig undo.Config `yaml:"undo" json:"undo,omitempty" koanf:"undo"` | |||
TmConfig tm.TmConfig `yaml:"tm" json:"tm,omitempty" koanf:"tm"` | |||
RmConfig rm.Config `yaml:"rm" json:"rm,omitempty" koanf:"rm"` | |||
UndoConfig undo.Config `yaml:"undo" json:"undo,omitempty" koanf:"undo"` | |||
XaConfig sql.XAConfig `yaml:"xa" json:"xa" koanf:"xa"` | |||
} | |||
func (c *ClientConfig) RegisterFlagsWithPrefix(prefix string, f *flag.FlagSet) { | |||
c.TmConfig.RegisterFlagsWithPrefix(prefix+".tm", f) | |||
c.RmConfig.RegisterFlagsWithPrefix(prefix+".rm", f) | |||
c.UndoConfig.RegisterFlagsWithPrefix(prefix+".undo", f) | |||
c.XaConfig.RegisterFlagsWithPrefix(prefix+".xa", f) | |||
} | |||
type Config struct { | |||
@@ -17,8 +17,10 @@ | |||
package datasource | |||
import sql2 "github.com/seata/seata-go/pkg/datasource/sql" | |||
import ( | |||
"github.com/seata/seata-go/pkg/datasource/sql" | |||
) | |||
func Init() { | |||
sql2.Init() | |||
sql.Init() | |||
} |
@@ -56,23 +56,23 @@ func (a *ATSourceManager) GetBranchType() branch.BranchType { | |||
return branch.BranchTypeAT | |||
} | |||
// Get all resources managed by this manager | |||
// GetCachedResources get all resources managed by this manager | |||
func (a *ATSourceManager) GetCachedResources() *sync.Map { | |||
return &a.resourceCache | |||
} | |||
// Register a Resource to be managed by Resource Manager | |||
// RegisterResource register a Resource to be managed by Resource Manager | |||
func (a *ATSourceManager) RegisterResource(res rm.Resource) error { | |||
a.resourceCache.Store(res.GetResourceId(), res) | |||
return a.basic.RegisterResource(res) | |||
} | |||
// Unregister a Resource from the Resource Manager | |||
// UnregisterResource unregister a Resource from the Resource Manager | |||
func (a *ATSourceManager) UnregisterResource(res rm.Resource) error { | |||
return a.basic.UnregisterResource(res) | |||
} | |||
// Rollback a branch transaction | |||
// BranchRollback rollback a branch transaction | |||
func (a *ATSourceManager) BranchRollback(ctx context.Context, branchResource rm.BranchResource) (branch.BranchStatus, error) { | |||
var dbResource *DBResource | |||
if resource, ok := a.resourceCache.Load(branchResource.ResourceId); !ok { | |||
@@ -103,28 +103,26 @@ func (a *ATSourceManager) BranchRollback(ctx context.Context, branchResource rm. | |||
return branch.BranchStatusPhasetwoRollbacked, nil | |||
} | |||
// BranchCommit | |||
// BranchCommit commit the branch transaction | |||
func (a *ATSourceManager) BranchCommit(ctx context.Context, resource rm.BranchResource) (branch.BranchStatus, error) { | |||
a.worker.BranchCommit(ctx, resource) | |||
return branch.BranchStatusPhasetwoCommitted, nil | |||
} | |||
// LockQuery | |||
func (a *ATSourceManager) LockQuery(ctx context.Context, param rm.LockQueryParam) (bool, error) { | |||
return a.rmRemoting.LockQuery(param) | |||
} | |||
// BranchRegister | |||
// BranchRegister branch transaction register | |||
func (a *ATSourceManager) BranchRegister(ctx context.Context, req rm.BranchRegisterParam) (int64, error) { | |||
return a.rmRemoting.BranchRegister(req) | |||
} | |||
// BranchReport | |||
// BranchReport Report status of transaction branch | |||
func (a *ATSourceManager) BranchReport(ctx context.Context, param rm.BranchReportParam) error { | |||
return a.rmRemoting.BranchReport(param) | |||
} | |||
// CreateTableMetaCache | |||
func (a *ATSourceManager) CreateTableMetaCache(ctx context.Context, resID string, dbType types.DBType, | |||
db *sql.DB) (datasource.TableMetaCache, error) { | |||
return a.basic.CreateTableMetaCache(ctx, resID, dbType, db) |
@@ -211,7 +211,13 @@ func (c *Conn) Begin() (driver.Tx, error) { | |||
// | |||
// global transaction according to tranCtx. If so, it needs to be included in the transaction management of seata | |||
func (c *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { | |||
c.autoCommit = false | |||
if c.txCtx.TransactionMode == types.XAMode { | |||
return newTx( | |||
withDriverConn(c), | |||
withTxCtx(c.txCtx), | |||
withOriginTx(nil), | |||
) | |||
} | |||
if conn, ok := c.targetConn.(driver.ConnBeginTx); ok { | |||
tx, err := conn.BeginTx(ctx, opts) | |||
@@ -1,91 +0,0 @@ | |||
/* | |||
* Licensed to the Apache Software Foundation (ASF) under one or more | |||
* contributor license agreements. See the NOTICE file distributed with | |||
* this work for additional information regarding copyright ownership. | |||
* The ASF licenses this file to You under the Apache License, Version 2.0 | |||
* (the "License"); you may not use this file except in compliance with | |||
* the License. You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
package conn | |||
import "time" | |||
const ( | |||
// TMENDICANT Ends a recovery scan. | |||
TMENDRSCAN = 0x00800000 | |||
/** | |||
* Disassociates the caller and marks the transaction branch | |||
* rollback-only. | |||
*/ | |||
TMFAIL = 0x20000000 | |||
/** | |||
* Caller is joining existing transaction branch. | |||
*/ | |||
TMJOIN = 0x00200000 | |||
/** | |||
* Use TMNOFLAGS to indicate no flags value is selected. | |||
*/ | |||
TMNOFLAGS = 0x00000000 | |||
/** | |||
* Caller is using one-phase optimization. | |||
*/ | |||
TMONEPHASE = 0x40000000 | |||
/** | |||
* Caller is resuming association with a suspended | |||
* transaction branch. | |||
*/ | |||
TMRESUME = 0x08000000 | |||
/** | |||
* Starts a recovery scan. | |||
*/ | |||
TMSTARTRSCAN = 0x01000000 | |||
/** | |||
* Disassociates caller from a transaction branch. | |||
*/ | |||
TMSUCCESS = 0x04000000 | |||
/** | |||
* Caller is suspending (not ending) its association with | |||
* a transaction branch. | |||
*/ | |||
TMSUSPEND = 0x02000000 | |||
/** | |||
* The transaction branch has been read-only and has been committed. | |||
*/ | |||
XA_RDONLY = 0x00000003 | |||
/** | |||
* The transaction work has been prepared normally. | |||
*/ | |||
XA_OK = 0 | |||
) | |||
type XAResource interface { | |||
Commit(xid string, onePhase bool) error | |||
End(xid string, flags int) error | |||
Forget(xid string) error | |||
GetTransactionTimeout() time.Duration | |||
IsSameRM(resource XAResource) bool | |||
XAPrepare(xid string) (int, error) | |||
Recover(flag int) []string | |||
Rollback(xid string) error | |||
SetTransactionTimeout(duration time.Duration) bool | |||
Start(xid string, flags int) error | |||
} |
@@ -22,11 +22,10 @@ import ( | |||
gosql "database/sql" | |||
"database/sql/driver" | |||
"github.com/seata/seata-go/pkg/util/log" | |||
"github.com/seata/seata-go/pkg/datasource/sql/exec" | |||
"github.com/seata/seata-go/pkg/datasource/sql/types" | |||
"github.com/seata/seata-go/pkg/tm" | |||
"github.com/seata/seata-go/pkg/util/log" | |||
) | |||
// ATConn Database connection proxy object under XA transaction model | |||
@@ -26,11 +26,13 @@ import ( | |||
"github.com/golang/mock/gomock" | |||
"github.com/google/uuid" | |||
"github.com/stretchr/testify/assert" | |||
"github.com/seata/seata-go/pkg/datasource/sql/exec" | |||
"github.com/seata/seata-go/pkg/datasource/sql/mock" | |||
"github.com/seata/seata-go/pkg/datasource/sql/types" | |||
"github.com/seata/seata-go/pkg/protocol/branch" | |||
"github.com/seata/seata-go/pkg/tm" | |||
"github.com/stretchr/testify/assert" | |||
) | |||
func TestMain(m *testing.M) { | |||
@@ -41,7 +43,7 @@ func TestMain(m *testing.M) { | |||
func initAtConnTestResource(t *testing.T) (*gomock.Controller, *sql.DB, *mockSQLInterceptor, *mockTxHook) { | |||
ctrl := gomock.NewController(t) | |||
mockMgr := initMockResourceManager(t, ctrl) | |||
mockMgr := initMockResourceManager(branch.BranchTypeAT, ctrl) | |||
_ = mockMgr | |||
db, err := sql.Open(SeataATMySQLDriver, "root:12345678@tcp(127.0.0.1:3306)/seata_client?multiStatements=true") | |||
@@ -57,6 +59,14 @@ func initAtConnTestResource(t *testing.T) (*gomock.Controller, *sql.DB, *mockSQL | |||
mockConn := mock.NewMockTestDriverConn(ctrl) | |||
mockConn.EXPECT().Begin().AnyTimes().Return(mockTx, nil) | |||
mockConn.EXPECT().BeginTx(gomock.Any(), gomock.Any()).AnyTimes().Return(mockTx, nil) | |||
mockConn.EXPECT().QueryContext(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes().DoAndReturn( | |||
func(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { | |||
rows := &mysqlMockRows{} | |||
rows.data = [][]interface{}{ | |||
{"8.0.29"}, | |||
} | |||
return rows, nil | |||
}) | |||
baseMockConn(mockConn) | |||
connector := mock.NewMockTestDriverConnector(ctrl) | |||
@@ -19,41 +19,74 @@ package sql | |||
import ( | |||
"context" | |||
gosql "database/sql" | |||
"database/sql/driver" | |||
"fmt" | |||
"time" | |||
"github.com/seata/seata-go/pkg/datasource/sql/types" | |||
"github.com/seata/seata-go/pkg/datasource/sql/xa" | |||
"github.com/seata/seata-go/pkg/tm" | |||
"github.com/seata/seata-go/pkg/util/log" | |||
) | |||
var xaConnTimeout time.Duration | |||
// XAConn Database connection proxy object under XA transaction model | |||
// Conn is assumed to be stateful. | |||
type XAConn struct { | |||
*Conn | |||
tx driver.Tx | |||
xaResource xa.XAResource | |||
xaBranchXid *XABranchXid | |||
xaActive bool | |||
rollBacked bool | |||
branchRegisterTime time.Time | |||
prepareTime time.Time | |||
isConnKept bool | |||
} | |||
// QueryContext | |||
func (c *XAConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { | |||
func (c *XAConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { | |||
if c.createOnceTxContext(ctx) { | |||
defer func() { | |||
c.txCtx = types.NewTxCtx() | |||
}() | |||
} | |||
return c.Conn.QueryContext(ctx, query, args) | |||
//ret, err := c.createNewTxOnExecIfNeed(ctx, func() (types, error) { | |||
// ret, err := c.Conn.PrepareContext(ctx, query) | |||
// if err != nil { | |||
// return nil, err | |||
// } | |||
// return types.NewResult(types.WithRows(ret)), nil | |||
//}) | |||
return c.Conn.PrepareContext(ctx, query) | |||
} | |||
// PrepareContext | |||
func (c *XAConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { | |||
// QueryContext exec xa sql | |||
func (c *XAConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { | |||
if c.createOnceTxContext(ctx) { | |||
defer func() { | |||
c.txCtx = types.NewTxCtx() | |||
}() | |||
} | |||
return c.Conn.PrepareContext(ctx, query) | |||
ret, err := c.createNewTxOnExecIfNeed(ctx, func() (types.ExecResult, error) { | |||
ret, err := c.Conn.QueryContext(ctx, query, args) | |||
if err != nil { | |||
return nil, err | |||
} | |||
return types.NewResult(types.WithRows(ret)), nil | |||
}) | |||
if err != nil { | |||
return nil, err | |||
} | |||
return ret.GetRows(), nil | |||
} | |||
// ExecContext | |||
func (c *XAConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { | |||
if c.createOnceTxContext(ctx) { | |||
defer func() { | |||
@@ -61,26 +94,56 @@ func (c *XAConn) ExecContext(ctx context.Context, query string, args []driver.Na | |||
}() | |||
} | |||
return c.Conn.ExecContext(ctx, query, args) | |||
ret, err := c.createNewTxOnExecIfNeed(ctx, func() (types.ExecResult, error) { | |||
ret, err := c.Conn.ExecContext(ctx, query, args) | |||
if err != nil { | |||
return nil, err | |||
} | |||
return types.NewResult(types.WithResult(ret)), nil | |||
}) | |||
return ret.GetResult(), err | |||
} | |||
// BeginTx | |||
// BeginTx like common transaction. but it just exec XA START | |||
func (c *XAConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { | |||
c.autoCommit = false | |||
c.txCtx = types.NewTxCtx() | |||
c.txCtx.DBType = c.res.dbType | |||
c.txCtx.TxOpt = opts | |||
c.txCtx.ResourceID = c.res.resourceID | |||
if tm.IsGlobalTx(ctx) { | |||
c.txCtx.TransactionMode = types.XAMode | |||
c.txCtx.XID = tm.GetXID(ctx) | |||
c.txCtx.TransactionMode = types.XAMode | |||
} | |||
tx, err := c.Conn.BeginTx(ctx, opts) | |||
if err != nil { | |||
return nil, err | |||
} | |||
c.tx = tx | |||
if c.autoCommit { | |||
baseTx, ok := tx.(*Tx) | |||
if !ok { | |||
return nil, fmt.Errorf("start xa %s transaction failure for the tx is a wrong type", c.txCtx.XID) | |||
} | |||
c.branchRegisterTime = time.Now() | |||
if err := baseTx.register(c.txCtx); err != nil { | |||
c.cleanXABranchContext() | |||
return nil, fmt.Errorf("failed to register xa branch %s, err:%w", c.txCtx.XID, err) | |||
} | |||
c.xaBranchXid = XaIdBuild(c.txCtx.XID, c.txCtx.BranchID) | |||
c.keepIfNecessary() | |||
if err = c.start(ctx); err != nil { | |||
c.cleanXABranchContext() | |||
return nil, fmt.Errorf("failed to start xa branch xid:%s err:%w", c.txCtx.XID, err) | |||
} | |||
c.xaActive = true | |||
} | |||
return &XATx{tx: tx.(*Tx)}, nil | |||
} | |||
@@ -91,9 +154,240 @@ func (c *XAConn) createOnceTxContext(ctx context.Context) bool { | |||
if onceTx { | |||
c.txCtx = types.NewTxCtx() | |||
c.txCtx.DBType = c.res.dbType | |||
c.txCtx.ResourceID = c.res.resourceID | |||
c.txCtx.XID = tm.GetXID(ctx) | |||
c.txCtx.TransactionMode = types.XAMode | |||
c.txCtx.GlobalLockRequire = true | |||
} | |||
return onceTx | |||
} | |||
func (c *XAConn) createNewTxOnExecIfNeed(ctx context.Context, f func() (types.ExecResult, error)) (types.ExecResult, error) { | |||
var err error | |||
if c.txCtx.TransactionMode != types.Local && c.autoCommit { | |||
_, err = c.BeginTx(ctx, driver.TxOptions{Isolation: driver.IsolationLevel(gosql.LevelDefault)}) | |||
if err != nil { | |||
return nil, err | |||
} | |||
} | |||
defer func() { | |||
recoverErr := recover() | |||
if err != nil || recoverErr != nil { | |||
log.Errorf("conn at rollback error:%v or recoverErr:%v", err, recoverErr) | |||
if c.tx != nil { | |||
rollbackErr := c.tx.Rollback() | |||
if rollbackErr != nil { | |||
log.Errorf("conn at rollback error:%v", rollbackErr) | |||
} | |||
} | |||
} | |||
}() | |||
// execute SQL | |||
ret, err := f() | |||
if err != nil { | |||
// XA End & Rollback | |||
if rollbackErr := c.Rollback(ctx); rollbackErr != nil { | |||
log.Errorf("failed to rollback xa branch of :%s, err:%w", c.txCtx.XID, rollbackErr) | |||
} | |||
return nil, err | |||
} | |||
if c.autoCommit { | |||
if err := c.Commit(ctx); err != nil { | |||
log.Errorf("xa connection proxy commit failure xid:%s, err:%v", c.txCtx.XID, err) | |||
// XA End & Rollback | |||
if err := c.Rollback(ctx); err != nil { | |||
log.Errorf("xa connection proxy rollback failure xid:%s, err:%v", c.txCtx.XID, err) | |||
} | |||
} | |||
} | |||
return ret, nil | |||
} | |||
func (c *XAConn) keepIfNecessary() { | |||
if c.ShouldBeHeld() { | |||
if err := c.res.Hold(c.xaBranchXid.String(), c); err == nil { | |||
c.isConnKept = true | |||
} | |||
} | |||
} | |||
func (c *XAConn) releaseIfNecessary() { | |||
if c.ShouldBeHeld() && c.xaBranchXid.String() != "" { | |||
if c.isConnKept { | |||
c.res.Release(c.xaBranchXid.String()) | |||
c.isConnKept = false | |||
} | |||
} | |||
} | |||
func (c *XAConn) start(ctx context.Context) error { | |||
xaResource, err := xa.CreateXAResource(c.Conn.targetConn, c.dbType) | |||
if err != nil { | |||
return fmt.Errorf("create xa xid:%s resoruce err:%w", c.txCtx.XID, err) | |||
} | |||
c.xaResource = xaResource | |||
if err := c.xaResource.Start(ctx, c.xaBranchXid.String(), xa.TMNoFlags); err != nil { | |||
return fmt.Errorf("xa xid %s resource connection start err:%w", c.txCtx.XID, err) | |||
} | |||
if err := c.termination(c.xaBranchXid.String()); err != nil { | |||
c.xaResource.End(ctx, c.xaBranchXid.String(), xa.TMFail) | |||
c.XaRollback(ctx, c.xaBranchXid) | |||
return err | |||
} | |||
return err | |||
} | |||
func (c *XAConn) end(ctx context.Context, flags int) error { | |||
err := c.termination(c.xaBranchXid.String()) | |||
if err != nil { | |||
return err | |||
} | |||
err = c.xaResource.End(ctx, c.xaBranchXid.String(), flags) | |||
if err != nil { | |||
return err | |||
} | |||
return nil | |||
} | |||
func (c *XAConn) termination(xaBranchXid string) error { | |||
branchStatus, err := branchStatus(xaBranchXid) | |||
if err != nil { | |||
c.releaseIfNecessary() | |||
return fmt.Errorf("failed xa branch [%v] the global transaction has finish, branch status: [%v]", c.txCtx.XID, branchStatus) | |||
} | |||
return nil | |||
} | |||
func (c *XAConn) cleanXABranchContext() { | |||
h, _ := time.ParseDuration("-1000h") | |||
c.branchRegisterTime = time.Now().Add(h) | |||
c.prepareTime = time.Now().Add(h) | |||
c.xaActive = false | |||
if !c.isConnKept { | |||
c.xaBranchXid = nil | |||
} | |||
} | |||
func (c *XAConn) Rollback(ctx context.Context) error { | |||
if c.autoCommit { | |||
return nil | |||
} | |||
if !c.xaActive || c.xaBranchXid == nil { | |||
return fmt.Errorf("should NOT rollback on an inactive session") | |||
} | |||
if !c.rollBacked { | |||
if c.xaResource.End(ctx, c.xaBranchXid.String(), xa.TMFail) != nil { | |||
return c.rollbackErrorHandle() | |||
} | |||
if c.XaRollback(ctx, c.xaBranchXid) != nil { | |||
c.cleanXABranchContext() | |||
return c.rollbackErrorHandle() | |||
} | |||
if err := c.tx.Rollback(); err != nil { | |||
c.cleanXABranchContext() | |||
return fmt.Errorf("failed to report XA branch commit-failure on xid:%s err:%w", c.txCtx.XID, err) | |||
} | |||
} | |||
c.cleanXABranchContext() | |||
return nil | |||
} | |||
func (c *XAConn) rollbackErrorHandle() error { | |||
return fmt.Errorf("failed to end(TMFAIL) xa branch on [%v] - [%v]", c.txCtx.XID, c.xaBranchXid.GetBranchId()) | |||
} | |||
func (c *XAConn) Commit(ctx context.Context) error { | |||
if c.autoCommit { | |||
return nil | |||
} | |||
if !c.xaActive || c.xaBranchXid == nil { | |||
return fmt.Errorf("should NOT commit on an inactive session") | |||
} | |||
now := time.Now() | |||
if c.end(ctx, xa.TMSuccess) != nil { | |||
return c.commitErrorHandle() | |||
} | |||
if c.checkTimeout(ctx, now) != nil { | |||
return c.commitErrorHandle() | |||
} | |||
if c.xaResource.XAPrepare(ctx, c.xaBranchXid.String()) != nil { | |||
return c.commitErrorHandle() | |||
} | |||
return nil | |||
} | |||
func (c *XAConn) commitErrorHandle() error { | |||
var err error | |||
if err = c.tx.Rollback(); err != nil { | |||
err = fmt.Errorf("failed to report XA branch commit-failure xid:%s, err:%w", c.txCtx.XID, err) | |||
} | |||
c.cleanXABranchContext() | |||
return err | |||
} | |||
func (c *XAConn) ShouldBeHeld() bool { | |||
return c.res.IsShouldBeHeld() || (c.res.GetDbType().String() != "" && c.res.GetDbType() != types.DBTypeUnknown) | |||
} | |||
func (c *XAConn) checkTimeout(ctx context.Context, now time.Time) error { | |||
if now.Sub(c.branchRegisterTime) > xaConnTimeout { | |||
c.XaRollback(ctx, c.xaBranchXid) | |||
return fmt.Errorf("XA branch timeout error xid:%s", c.txCtx.XID) | |||
} | |||
return nil | |||
} | |||
func (c *XAConn) Close() error { | |||
c.rollBacked = false | |||
if c.isConnKept && c.ShouldBeHeld() { | |||
return nil | |||
} | |||
c.cleanXABranchContext() | |||
if err := c.Conn.Close(); err != nil { | |||
return err | |||
} | |||
return nil | |||
} | |||
func (c *XAConn) CloseForce() error { | |||
if err := c.Conn.Close(); err != nil { | |||
return err | |||
} | |||
c.rollBacked = false | |||
c.cleanXABranchContext() | |||
if err := c.Conn.Close(); err != nil { | |||
return err | |||
} | |||
c.releaseIfNecessary() | |||
return nil | |||
} | |||
func (c *XAConn) XaCommit(ctx context.Context, xid string, branchId int64) error { | |||
xaXid := XaIdBuild(xid, uint64(branchId)) | |||
err := c.xaResource.Commit(ctx, xaXid.String(), false) | |||
c.releaseIfNecessary() | |||
return err | |||
} | |||
func (c *XAConn) XaRollbackByBranchId(ctx context.Context, xid string, branchId int64) error { | |||
xaXid := XaIdBuild(xid, uint64(branchId)) | |||
return c.XaRollback(ctx, xaXid) | |||
} | |||
func (c *XAConn) XaRollback(ctx context.Context, xaXid XAXid) error { | |||
err := c.xaResource.Rollback(ctx, xaXid.GetGlobalXid()) | |||
c.releaseIfNecessary() | |||
return err | |||
} |
@@ -21,18 +21,59 @@ import ( | |||
"context" | |||
"database/sql" | |||
"database/sql/driver" | |||
"io" | |||
"sync/atomic" | |||
"testing" | |||
"time" | |||
"github.com/bluele/gcache" | |||
"github.com/golang/mock/gomock" | |||
"github.com/google/uuid" | |||
"github.com/stretchr/testify/assert" | |||
"github.com/seata/seata-go/pkg/datasource/sql/exec" | |||
"github.com/seata/seata-go/pkg/datasource/sql/mock" | |||
"github.com/seata/seata-go/pkg/datasource/sql/types" | |||
"github.com/seata/seata-go/pkg/protocol/branch" | |||
"github.com/seata/seata-go/pkg/tm" | |||
"github.com/stretchr/testify/assert" | |||
) | |||
type mysqlMockRows struct { | |||
idx int | |||
data [][]interface{} | |||
} | |||
func (m *mysqlMockRows) Columns() []string { | |||
//TODO implement me | |||
panic("implement me") | |||
} | |||
func (m *mysqlMockRows) Close() error { | |||
//TODO implement me | |||
panic("implement me") | |||
} | |||
func (m *mysqlMockRows) Next(dest []driver.Value) error { | |||
if m.idx == len(m.data) { | |||
return io.EOF | |||
} | |||
min := func(a, b int) int { | |||
if a < b { | |||
return a | |||
} | |||
return b | |||
} | |||
cnt := min(len(m.data[0]), len(dest)) | |||
for i := 0; i < cnt; i++ { | |||
dest[i] = m.data[m.idx][i] | |||
} | |||
m.idx++ | |||
return nil | |||
} | |||
type mockSQLInterceptor struct { | |||
before func(ctx context.Context, execCtx *types.ExecContext) | |||
after func(ctx context.Context, execCtx *types.ExecContext) | |||
@@ -78,16 +119,27 @@ func (mi *mockTxHook) BeforeRollback(tx *Tx) { | |||
} | |||
func baseMockConn(mockConn *mock.MockTestDriverConn) { | |||
branchStatusCache = gcache.New(1024).LRU().Expiration(time.Minute * 10).Build() | |||
mockConn.EXPECT().ExecContext(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes().Return(&driver.ResultNoRows, nil) | |||
mockConn.EXPECT().Exec(gomock.Any(), gomock.Any()).AnyTimes().Return(&driver.ResultNoRows, nil) | |||
mockConn.EXPECT().ResetSession(gomock.Any()).AnyTimes().Return(nil) | |||
mockConn.EXPECT().Close().AnyTimes().Return(nil) | |||
mockConn.EXPECT().QueryContext(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes().DoAndReturn( | |||
func(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { | |||
rows := &mysqlMockRows{} | |||
rows.data = [][]interface{}{ | |||
{"8.0.29"}, | |||
} | |||
return rows, nil | |||
}) | |||
} | |||
func initXAConnTestResource(t *testing.T) (*gomock.Controller, *sql.DB, *mockSQLInterceptor, *mockTxHook) { | |||
ctrl := gomock.NewController(t) | |||
mockMgr := initMockResourceManager(t, ctrl) | |||
mockMgr := initMockResourceManager(branch.BranchTypeXA, ctrl) | |||
_ = mockMgr | |||
//db, err := sql.Open("seata-xa-mysql", "root:seata_go@tcp(127.0.0.1:3306)/seata_go_test?multiStatements=true") | |||
db, err := sql.Open("seata-xa-mysql", "root:12345678@tcp(127.0.0.1:3306)/seata_client?multiStatements=true&interpolateParams=true") | |||
@@ -20,11 +20,14 @@ package sql | |||
import ( | |||
"context" | |||
"database/sql/driver" | |||
"errors" | |||
"io" | |||
"sync" | |||
"github.com/go-sql-driver/mysql" | |||
"github.com/seata/seata-go/pkg/datasource/sql/types" | |||
"github.com/seata/seata-go/pkg/util/log" | |||
) | |||
type seataATConnector struct { | |||
@@ -89,7 +92,6 @@ func (c *seataXAConnector) Driver() driver.Driver { | |||
// method will call Close and return error (if any). | |||
type seataConnector struct { | |||
transType types.TransactionMode | |||
conf *seataServerConfig | |||
res *DBResource | |||
once sync.Once | |||
driver driver.Driver | |||
@@ -116,6 +118,15 @@ func (c *seataConnector) Connect(ctx context.Context) (driver.Conn, error) { | |||
return nil, err | |||
} | |||
// get the version of mysql for xa. | |||
if c.transType == types.XAMode { | |||
version, err := c.dbVersion(ctx, conn) | |||
if err != nil { | |||
return nil, err | |||
} | |||
c.res.SetDbVersion(version) | |||
} | |||
return &Conn{ | |||
targetConn: conn, | |||
res: c.res, | |||
@@ -126,6 +137,40 @@ func (c *seataConnector) Connect(ctx context.Context) (driver.Conn, error) { | |||
}, nil | |||
} | |||
func (c *seataConnector) dbVersion(ctx context.Context, conn driver.Conn) (string, error) { | |||
queryConn, isQueryContext := conn.(driver.QueryerContext) | |||
if !isQueryContext { | |||
return "", errors.New("get db version error for unexpected driver conn") | |||
} | |||
res, err := queryConn.QueryContext(ctx, "SELECT VERSION()", nil) | |||
if err != nil { | |||
log.Errorf("seata connector get the xa mysql version err:%v", err) | |||
return "", err | |||
} | |||
dest := make([]driver.Value, 1) | |||
var version string | |||
for true { | |||
if err = res.Next(dest); err != nil { | |||
if err == io.EOF { | |||
return version, nil | |||
} | |||
return "", err | |||
} | |||
if len(dest) != 1 { | |||
return "", errors.New("get the mysql version is not column 1") | |||
} | |||
var isVersionOk bool | |||
version, isVersionOk = dest[0].(string) | |||
if !isVersionOk { | |||
return "", errors.New("get the mysql version is not a string") | |||
} | |||
} | |||
return "", errors.New("get the mysql version is error") | |||
} | |||
// Driver returns the underlying Driver of the Connector, | |||
// mainly to maintain compatibility with the Driver method | |||
// on sql.DB. | |||
@@ -25,10 +25,12 @@ import ( | |||
"testing" | |||
"github.com/golang/mock/gomock" | |||
"github.com/stretchr/testify/assert" | |||
"github.com/seata/seata-go/pkg/datasource/sql/mock" | |||
"github.com/seata/seata-go/pkg/datasource/sql/types" | |||
"github.com/seata/seata-go/pkg/protocol/branch" | |||
"github.com/seata/seata-go/pkg/util/reflectx" | |||
"github.com/stretchr/testify/assert" | |||
) | |||
type initConnectorFunc func(t *testing.T, ctrl *gomock.Controller) driver.Connector | |||
@@ -38,6 +40,14 @@ func initMockConnector(t *testing.T, ctrl *gomock.Controller) driver.Connector { | |||
connector := mock.NewMockTestDriverConnector(ctrl) | |||
connector.EXPECT().Connect(gomock.Any()).AnyTimes().Return(mockConn, nil) | |||
mockConn.EXPECT().QueryContext(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes().DoAndReturn( | |||
func(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { | |||
rows := &mysqlMockRows{} | |||
rows.data = [][]interface{}{ | |||
{"8.0.29"}, | |||
} | |||
return rows, nil | |||
}) | |||
return connector | |||
} | |||
@@ -66,10 +76,10 @@ func Test_seataATConnector_Connect(t *testing.T) { | |||
ctrl := gomock.NewController(t) | |||
defer ctrl.Finish() | |||
mockMgr := initMockResourceManager(t, ctrl) | |||
mockMgr := initMockResourceManager(branch.BranchTypeAT, ctrl) | |||
_ = mockMgr | |||
db, err := sql.Open("seata-at-mysql", "root:seata_go@tcp(127.0.0.1:3306)/seata_go_test?multiStatements=true") | |||
db, err := sql.Open(SeataATMySQLDriver, "root:seata_go@tcp(127.0.0.1:3306)/seata_go_test?multiStatements=true") | |||
if err != nil { | |||
t.Fatal(err) | |||
} | |||
@@ -110,10 +120,10 @@ func Test_seataXAConnector_Connect(t *testing.T) { | |||
ctrl := gomock.NewController(t) | |||
defer ctrl.Finish() | |||
mockMgr := initMockResourceManager(t, ctrl) | |||
mockMgr := initMockResourceManager(branch.BranchTypeXA, ctrl) | |||
_ = mockMgr | |||
db, err := sql.Open("seata-xa-mysql", "root:seata_go@tcp(127.0.0.1:3306)/seata_go_test?multiStatements=true") | |||
db, err := sql.Open(SeataXAMySQLDriver, "root:seata_go@tcp(127.0.0.1:3306)/seata_go_test?multiStatements=true") | |||
if err != nil { | |||
t.Fatal(err) | |||
} | |||
@@ -25,7 +25,6 @@ import ( | |||
"github.com/seata/seata-go/pkg/datasource/sql/types" | |||
"github.com/seata/seata-go/pkg/protocol/branch" | |||
"github.com/seata/seata-go/pkg/protocol/message" | |||
"github.com/seata/seata-go/pkg/rm" | |||
) | |||
@@ -34,7 +33,7 @@ var ( | |||
tableMetaCacheMap = map[types.DBType]TableMetaCache{} | |||
) | |||
// RegisterTableCache | |||
// RegisterTableCache register the table meta cache for at and xa | |||
func RegisterTableCache(dbType types.DBType, tableMetaCache TableMetaCache) { | |||
tableMetaCacheMap[dbType] = tableMetaCache | |||
} | |||
@@ -54,11 +53,8 @@ func GetDataSourceManager(branchType branch.BranchType) DataSourceManager { | |||
return nil | |||
} | |||
// todo implements ResourceManagerOutbound interface | |||
// DataSourceManager | |||
type DataSourceManager interface { | |||
rm.ResourceManager | |||
// CreateTableMetaCache | |||
CreateTableMetaCache(ctx context.Context, resID string, dbType types.DBType, db *sql.DB) (TableMetaCache, error) | |||
} | |||
@@ -67,9 +63,8 @@ type entry struct { | |||
metaCache TableMetaCache | |||
} | |||
// BasicSourceManager | |||
// BasicSourceManager the basic source manager for xa and at | |||
type BasicSourceManager struct { | |||
// lock | |||
lock sync.RWMutex | |||
// tableMetaCache | |||
// todo do not put meta cache here | |||
@@ -82,34 +77,7 @@ func NewBasicSourceManager() *BasicSourceManager { | |||
} | |||
} | |||
// Commit a branch transaction | |||
// TODO wait finish | |||
func (dm *BasicSourceManager) BranchCommit(ctx context.Context, req message.BranchCommitRequest) (branch.BranchStatus, error) { | |||
return branch.BranchStatusPhaseoneDone, nil | |||
} | |||
// Rollback a branch transaction | |||
// TODO wait finish | |||
func (dm *BasicSourceManager) BranchRollback(ctx context.Context, req message.BranchRollbackRequest) (branch.BranchStatus, error) { | |||
return branch.BranchStatusPhaseoneFailed, nil | |||
} | |||
// Branch register long | |||
func (dm *BasicSourceManager) BranchRegister(ctx context.Context, req rm.BranchRegisterParam) (int64, error) { | |||
return 0, nil | |||
} | |||
// Branch report | |||
func (dm *BasicSourceManager) BranchReport(ctx context.Context, req message.BranchReportRequest) error { | |||
return nil | |||
} | |||
// Lock query boolean | |||
func (dm *BasicSourceManager) LockQuery(ctx context.Context, branchType branch.BranchType, resourceId, xid, lockKeys string) (bool, error) { | |||
return true, nil | |||
} | |||
// Register a model.Resource to be managed by model.Resource Manager | |||
// RegisterResource register a model.Resource to be managed by model.Resource Manager | |||
func (dm *BasicSourceManager) RegisterResource(resource rm.Resource) error { | |||
err := rm.GetRMRemotingInstance().RegisterResource(resource) | |||
if err != nil { | |||
@@ -118,22 +86,11 @@ func (dm *BasicSourceManager) RegisterResource(resource rm.Resource) error { | |||
return nil | |||
} | |||
// Unregister a model.Resource from the model.Resource Manager | |||
func (dm *BasicSourceManager) UnregisterResource(resource rm.Resource) error { | |||
return fmt.Errorf("unsupport unregister resource") | |||
} | |||
// Get all resources managed by this manager | |||
func (dm *BasicSourceManager) GetManagedResources() *sync.Map { | |||
return nil | |||
} | |||
// Get the model.BranchType | |||
func (dm *BasicSourceManager) GetBranchType() branch.BranchType { | |||
return branch.BranchTypeAT | |||
} | |||
// CreateTableMetaCache | |||
// CreateTableMetaCache create a table meta cache | |||
func (dm *BasicSourceManager) CreateTableMetaCache(ctx context.Context, resID string, dbType types.DBType, db *sql.DB) (TableMetaCache, error) { | |||
dm.lock.Lock() | |||
defer dm.lock.Unlock() | |||
@@ -144,28 +101,19 @@ func (dm *BasicSourceManager) CreateTableMetaCache(ctx context.Context, resID st | |||
} | |||
dm.tableMetaCache[resID] = res | |||
// 注册 AT 数据资源 | |||
// dm.resourceMgr.RegisterResource(ATResource) | |||
return res.metaCache, err | |||
} | |||
// TableMetaCache tables metadata cache, default is open | |||
type TableMetaCache interface { | |||
// Init | |||
Init(ctx context.Context, conn *sql.DB) error | |||
// GetTableMeta | |||
GetTableMeta(ctx context.Context, dbName, table string) (*types.TableMeta, error) | |||
// Destroy | |||
Destroy() error | |||
} | |||
// buildResource | |||
// todo not here | |||
func buildResource(ctx context.Context, dbType types.DBType, db *sql.DB) (*entry, error) { | |||
cache := tableMetaCacheMap[dbType] | |||
if err := cache.Init(ctx, db); err != nil { | |||
return nil, err | |||
} | |||
@@ -1,124 +0,0 @@ | |||
/* | |||
* Licensed to the Apache Software Foundation (ASF) under one or more | |||
* contributor license agreements. See the NOTICE file distributed with | |||
* this work for additional information regarding copyright ownership. | |||
* The ASF licenses this file to You under the Apache License, Version 2.0 | |||
* (the "License"); you may not use this file except in compliance with | |||
* the License. You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
package sql | |||
import ( | |||
"fmt" | |||
"time" | |||
"github.com/bluele/gcache" | |||
"github.com/seata/seata-go/pkg/datasource/sql/types" | |||
"github.com/seata/seata-go/pkg/protocol/branch" | |||
) | |||
type Holdable interface { | |||
SetHeld(held bool) | |||
IsHeld() bool | |||
ShouldBeHeld() bool | |||
} | |||
type BaseDataSourceResource struct { | |||
db *DBResource | |||
shouldBeHeld bool | |||
keeper map[string]interface{} | |||
Cache map[string]branch.BranchStatus | |||
} | |||
var BranchStatusCache = gcache.New(1024).LRU().Expiration(time.Minute * 10).Build() | |||
func (b *BaseDataSourceResource) init() error { | |||
return nil | |||
} | |||
func (b *BaseDataSourceResource) GetDB() *DBResource { | |||
return b.db | |||
} | |||
func (b *BaseDataSourceResource) SetDB(db *DBResource) { | |||
b.db = db | |||
} | |||
func (b *BaseDataSourceResource) IsShouldBeHeld() bool { | |||
return b.shouldBeHeld | |||
} | |||
func (b *BaseDataSourceResource) SetShouldBeHeld(shouldBeHeld bool) { | |||
b.shouldBeHeld = shouldBeHeld | |||
} | |||
func (b *BaseDataSourceResource) GetKeeper() map[string]interface{} { | |||
return b.keeper | |||
} | |||
func (b *BaseDataSourceResource) SetKeeper(keeper map[string]interface{}) { | |||
b.keeper = keeper | |||
} | |||
func (b *BaseDataSourceResource) GetCache() map[string]branch.BranchStatus { | |||
return b.Cache | |||
} | |||
func (b *BaseDataSourceResource) SetCache(cache map[string]branch.BranchStatus) { | |||
b.Cache = cache | |||
} | |||
func (b *BaseDataSourceResource) GetResourceId() string { | |||
return b.db.GetResourceId() | |||
} | |||
func (b *BaseDataSourceResource) Hold(key string, value Holdable) (interface{}, error) { | |||
if value.IsHeld() { | |||
var x = b.keeper[key] | |||
if x != value { | |||
return nil, fmt.Errorf("something wrong with keeper, keeping[%v] but[%v] is also kept with the same key[%v]", x, value, key) | |||
} | |||
return value, nil | |||
} | |||
var x = b.keeper[key] | |||
b.keeper[key] = value | |||
value.SetHeld(true) | |||
return x, nil | |||
} | |||
func (b *BaseDataSourceResource) Release(key string, value Holdable) (interface{}, error) { | |||
if value.IsHeld() { | |||
var x = b.keeper[key] | |||
if x != value { | |||
return nil, fmt.Errorf("something wrong with keeper, keeping[%v] but[%v] is also kept with the same key[%v]", x, value, key) | |||
} | |||
return value, nil | |||
} | |||
var x = b.keeper[key] | |||
b.keeper[key] = value | |||
value.SetHeld(true) | |||
return x, nil | |||
} | |||
func (b *BaseDataSourceResource) GetBranchStatus(xaBranchXid string) (interface{}, error) { | |||
branchStatus, err := BranchStatusCache.GetIFPresent(xaBranchXid) | |||
return branchStatus, err | |||
} | |||
func (b *BaseDataSourceResource) GetDbType() string { | |||
return b.db.dbType.String() | |||
} | |||
func (b *BaseDataSourceResource) SetDbType(dbType types.DBType) { | |||
b.db.dbType = dbType | |||
} |
@@ -18,19 +18,24 @@ | |||
package sql | |||
import ( | |||
"context" | |||
"database/sql" | |||
"database/sql/driver" | |||
"fmt" | |||
"sync" | |||
"github.com/seata/seata-go/pkg/datasource/sql/datasource" | |||
"github.com/seata/seata-go/pkg/datasource/sql/types" | |||
"github.com/seata/seata-go/pkg/datasource/sql/undo" | |||
"github.com/seata/seata-go/pkg/datasource/sql/util" | |||
"github.com/seata/seata-go/pkg/protocol/branch" | |||
) | |||
type dbOption func(db *DBResource) | |||
func withGroupID(id string) dbOption { | |||
func withDsn(dsn string) dbOption { | |||
return func(db *DBResource) { | |||
db.groupID = id | |||
db.dsn = dsn | |||
} | |||
} | |||
@@ -52,21 +57,33 @@ func withDBType(dt types.DBType) dbOption { | |||
} | |||
} | |||
func withBranchType(dt branch.BranchType) dbOption { | |||
return func(db *DBResource) { | |||
db.branchType = dt | |||
} | |||
} | |||
func withTarget(source *sql.DB) dbOption { | |||
return func(db *DBResource) { | |||
db.db = source | |||
} | |||
} | |||
func withConnector(ci driver.Connector) dbOption { | |||
return func(db *DBResource) { | |||
db.connector = ci | |||
} | |||
} | |||
func withDBName(dbName string) dbOption { | |||
return func(db *DBResource) { | |||
db.dbName = dbName | |||
} | |||
} | |||
func withConf(conf *seataServerConfig) dbOption { | |||
func withConf(conf *XAConnConf) dbOption { | |||
return func(db *DBResource) { | |||
db.conf = *conf | |||
db.xaConnConf = conf | |||
} | |||
} | |||
@@ -77,47 +94,36 @@ func newResource(opts ...dbOption) (*DBResource, error) { | |||
opts[i](db) | |||
} | |||
return db, db.init() | |||
db.init() | |||
return db, nil | |||
} | |||
// DBResource proxy sql.DB, enchance database/sql.DB to add distribute transaction ability | |||
type DBResource struct { | |||
// groupID | |||
groupID string | |||
// resourceID | |||
xaConnConf *XAConnConf | |||
// only use by mysql | |||
dbVersion string | |||
dsn string | |||
resourceID string | |||
// conf | |||
conf seataServerConfig | |||
// db | |||
db *sql.DB | |||
dbName string | |||
// dbType | |||
dbType types.DBType | |||
// undoLogMgr | |||
db *sql.DB | |||
connector driver.Connector | |||
dbName string | |||
dbType types.DBType | |||
undoLogMgr undo.UndoLogManager | |||
// metaCache | |||
metaCache datasource.TableMetaCache | |||
} | |||
branchType branch.BranchType | |||
func (db *DBResource) init() error { | |||
return nil | |||
// for xa | |||
metaCache datasource.TableMetaCache | |||
shouldBeHeld bool | |||
keeper sync.Map | |||
} | |||
// todo do not put meta data to rm | |||
//func (db *DBResource) init() error { | |||
// mgr := datasource.GetDataSourceManager(db.GetBranchType()) | |||
// metaCache, err := mgr.CreateTableMetaCache(context.Background(), db.resourceID, db.dbType, db.db) | |||
// if err != nil { | |||
// return err | |||
// } | |||
// | |||
// db.metaCache = metaCache | |||
// | |||
// return nil | |||
//} | |||
func (db *DBResource) GetResourceGroupId() string { | |||
return db.groupID | |||
panic("implement me") | |||
} | |||
func (db *DBResource) init() { | |||
db.checkDbVersion() | |||
} | |||
func (db *DBResource) GetResourceId() string { | |||
@@ -125,18 +131,106 @@ func (db *DBResource) GetResourceId() string { | |||
} | |||
func (db *DBResource) GetBranchType() branch.BranchType { | |||
return db.conf.BranchType | |||
return db.branchType | |||
} | |||
func (db *DBResource) GetDB() *sql.DB { | |||
return db.db | |||
} | |||
type SqlDBProxy struct { | |||
db *sql.DB | |||
dbName string | |||
func (db *DBResource) GetDBName() string { | |||
return db.dbName | |||
} | |||
func (s *SqlDBProxy) GetDB() *sql.DB { | |||
return s.db | |||
func (db *DBResource) GetDbType() types.DBType { | |||
return db.dbType | |||
} | |||
func (s *SqlDBProxy) GetDBName() string { | |||
return s.dbName | |||
func (db *DBResource) SetDbType(dbType types.DBType) { | |||
db.dbType = dbType | |||
} | |||
func (db *DBResource) SetDbVersion(v string) { | |||
db.dbVersion = v | |||
} | |||
func (db *DBResource) GetDbVersion() string { | |||
return db.dbVersion | |||
} | |||
func (db *DBResource) IsShouldBeHeld() bool { | |||
return db.shouldBeHeld | |||
} | |||
// Hold the xa connection. | |||
func (db *DBResource) Hold(xaBranchID string, v interface{}) error { | |||
_, exist := db.keeper.Load(xaBranchID) | |||
if !exist { | |||
db.keeper.Store(xaBranchID, v) | |||
return nil | |||
} | |||
return nil | |||
} | |||
func (db *DBResource) Release(xaBranchID string) { | |||
db.keeper.Delete(xaBranchID) | |||
} | |||
func (db *DBResource) Lookup(xaBranchID string) (interface{}, bool) { | |||
return db.keeper.Load(xaBranchID) | |||
} | |||
func (db *DBResource) GetKeeper() *sync.Map { | |||
return &db.keeper | |||
} | |||
func (db *DBResource) ConnectionForXA(ctx context.Context, xaXid XAXid) (*XAConn, error) { | |||
xaBranchXid := xaXid.String() | |||
tmpConn, ok := db.Lookup(xaBranchXid) | |||
if ok && tmpConn != nil { | |||
connectionProxyXa, isConnectionProxyXa := tmpConn.(*XAConn) | |||
if !isConnectionProxyXa { | |||
return nil, fmt.Errorf("get connection proxy xa from cache error, xid:%s", xaXid.String()) | |||
} | |||
return connectionProxyXa, nil | |||
} | |||
// why here need a new connection? | |||
// 1. because there maybe a rm cluster | |||
// 2. the first phase select a rm1, and store the connection is the keeper | |||
// 3. tc request the second phase. but the rm1 is shutdown, so the tc select another rm (like rm2) | |||
// 4. so when the second phase request coming to rm2, rm2 must not store the connection. | |||
// 5. the rm2 get the second phase do the two thing. | |||
// 1. in mysql version >= 8.0.29, mysql support the xa transaction commit by another connection. so just commit | |||
// 2. when the version < 8.0.29. so just make the transaction rollback | |||
newDriverConn, err := db.connector.Connect(ctx) | |||
if err != nil { | |||
return nil, fmt.Errorf("get xa new connection failure, xid:%s, err:%v", xaXid.String(), err) | |||
} | |||
xaConn := &XAConn{ | |||
Conn: newDriverConn.(*Conn), | |||
} | |||
return xaConn, nil | |||
} | |||
func (db *DBResource) checkDbVersion() error { | |||
switch db.dbType { | |||
case types.DBTypeMySQL: | |||
currentVersion, err := util.ConvertDbVersion(db.dbVersion) | |||
if err != nil { | |||
return fmt.Errorf("new connection xa proxy convert db version:%s err:%v", db.GetDbVersion(), err) | |||
} | |||
shouldKeptVersion, err := util.ConvertDbVersion("8.0.29") | |||
if err != nil { | |||
return fmt.Errorf("new connection xa proxy convert db version 8.0.29 err:%v", err) | |||
} | |||
if currentVersion < shouldKeptVersion { | |||
db.shouldBeHeld = true | |||
} | |||
case types.DBTypeMARIADB: | |||
db.shouldBeHeld = true | |||
} | |||
return nil | |||
} |
@@ -25,10 +25,10 @@ import ( | |||
"fmt" | |||
"strings" | |||
mysql2 "github.com/seata/seata-go/pkg/datasource/sql/datasource/mysql" | |||
"github.com/go-sql-driver/mysql" | |||
"github.com/seata/seata-go/pkg/datasource/sql/datasource" | |||
mysql2 "github.com/seata/seata-go/pkg/datasource/sql/datasource/mysql" | |||
"github.com/seata/seata-go/pkg/datasource/sql/types" | |||
"github.com/seata/seata-go/pkg/protocol/branch" | |||
"github.com/seata/seata-go/pkg/util/log" | |||
@@ -44,15 +44,17 @@ const ( | |||
func initDriver() { | |||
sql.Register(SeataATMySQLDriver, &seataATDriver{ | |||
seataDriver: &seataDriver{ | |||
transType: types.ATMode, | |||
target: mysql.MySQLDriver{}, | |||
branchType: branch.BranchTypeAT, | |||
transType: types.ATMode, | |||
target: mysql.MySQLDriver{}, | |||
}, | |||
}) | |||
sql.Register(SeataXAMySQLDriver, &seataXADriver{ | |||
seataDriver: &seataDriver{ | |||
transType: types.XAMode, | |||
target: mysql.MySQLDriver{}, | |||
branchType: branch.BranchTypeXA, | |||
transType: types.XAMode, | |||
target: mysql.MySQLDriver{}, | |||
}, | |||
}) | |||
} | |||
@@ -98,8 +100,9 @@ func (d *seataXADriver) OpenConnector(name string) (c driver.Connector, err erro | |||
} | |||
type seataDriver struct { | |||
transType types.TransactionMode | |||
target driver.Driver | |||
branchType branch.BranchType | |||
transType types.TransactionMode | |||
target driver.Driver | |||
} | |||
// Open never be called, because seataDriver implemented dri.DriverContext interface. | |||
@@ -124,7 +127,7 @@ func (d *seataDriver) OpenConnector(name string) (c driver.Connector, err error) | |||
return nil, fmt.Errorf("unsupport conn type %s", d.getTargetDriverName()) | |||
} | |||
proxy, err := getOpenConnectorProxy(c, dbType, sql.OpenDB(c), name) | |||
proxy, err := d.getOpenConnectorProxy(c, dbType, sql.OpenDB(c), name) | |||
if err != nil { | |||
log.Errorf("register resource: %w", err) | |||
return nil, err | |||
@@ -133,43 +136,15 @@ func (d *seataDriver) OpenConnector(name string) (c driver.Connector, err error) | |||
return proxy, nil | |||
} | |||
func (d *seataDriver) getTargetDriverName() string { | |||
return "mysql" | |||
} | |||
type dsnConnector struct { | |||
dsn string | |||
driver driver.Driver | |||
} | |||
func (t *dsnConnector) Connect(_ context.Context) (driver.Conn, error) { | |||
return t.driver.Open(t.dsn) | |||
} | |||
func (t *dsnConnector) Driver() driver.Driver { | |||
return t.driver | |||
} | |||
func getOpenConnectorProxy(connector driver.Connector, dbType types.DBType, db *sql.DB, | |||
dataSourceName string, opts ...seataOption) (driver.Connector, error) { | |||
conf := loadConfig() | |||
for i := range opts { | |||
opts[i](conf) | |||
} | |||
if err := conf.validate(); err != nil { | |||
log.Errorf("invalid conf: %w", err) | |||
return nil, err | |||
} | |||
func (d *seataDriver) getOpenConnectorProxy(connector driver.Connector, dbType types.DBType, | |||
db *sql.DB, dataSourceName string) (driver.Connector, error) { | |||
cfg, _ := mysql.ParseDSN(dataSourceName) | |||
options := []dbOption{ | |||
withGroupID(conf.GroupID), | |||
withResourceID(parseResourceID(dataSourceName)), | |||
withConf(conf), | |||
withTarget(db), | |||
withBranchType(d.branchType), | |||
withDBType(dbType), | |||
withDBName(cfg.DBName), | |||
withConnector(connector), | |||
} | |||
res, err := newResource(options...) | |||
@@ -179,7 +154,8 @@ func getOpenConnectorProxy(connector driver.Connector, dbType types.DBType, db * | |||
} | |||
datasource.RegisterTableCache(types.DBTypeMySQL, mysql2.NewTableMetaInstance(db)) | |||
if err = datasource.GetDataSourceManager(conf.BranchType).RegisterResource(res); err != nil { | |||
if err = datasource.GetDataSourceManager(d.branchType).RegisterResource(res); err != nil { | |||
log.Errorf("regisiter resource: %w", err) | |||
return nil, err | |||
} | |||
@@ -187,38 +163,25 @@ func getOpenConnectorProxy(connector driver.Connector, dbType types.DBType, db * | |||
return &seataConnector{ | |||
res: res, | |||
target: connector, | |||
conf: conf, | |||
cfg: cfg, | |||
}, nil | |||
} | |||
type ( | |||
seataOption func(cfg *seataServerConfig) | |||
// seataServerConfig | |||
seataServerConfig struct { | |||
// GroupID | |||
GroupID string `yaml:"groupID"` | |||
// BranchType | |||
BranchType branch.BranchType | |||
// Endpoints | |||
Endpoints []string `yaml:"endpoints" json:"endpoints"` | |||
} | |||
) | |||
func (d *seataDriver) getTargetDriverName() string { | |||
return "mysql" | |||
} | |||
func (c *seataServerConfig) validate() error { | |||
return nil | |||
type dsnConnector struct { | |||
dsn string | |||
driver driver.Driver | |||
} | |||
// loadConfig | |||
func loadConfig() *seataServerConfig { | |||
// set default value first. | |||
// todo read from configuration file. | |||
return &seataServerConfig{ | |||
GroupID: "DEFAULT_GROUP", | |||
BranchType: branch.BranchTypeAT, | |||
Endpoints: []string{"127.0.0.1:8888"}, | |||
} | |||
func (t *dsnConnector) Connect(_ context.Context) (driver.Conn, error) { | |||
return t.driver.Open(t.dsn) | |||
} | |||
func (t *dsnConnector) Driver() driver.Driver { | |||
return t.driver | |||
} | |||
func parseResourceID(dsn string) string { | |||
@@ -28,12 +28,14 @@ import ( | |||
"github.com/golang/mock/gomock" | |||
"github.com/seata/seata-go/pkg/datasource/sql/mock" | |||
"github.com/seata/seata-go/pkg/protocol/branch" | |||
"github.com/seata/seata-go/pkg/util/reflectx" | |||
"github.com/stretchr/testify/assert" | |||
) | |||
func initMockResourceManager(t *testing.T, ctrl *gomock.Controller) *mock.MockDataSourceManager { | |||
func initMockResourceManager(branchType branch.BranchType, ctrl *gomock.Controller) *mock.MockDataSourceManager { | |||
mockResourceMgr := mock.NewMockDataSourceManager(ctrl) | |||
mockResourceMgr.SetBranchType(branchType) | |||
rm.GetRmCacheInstance().RegisterResourceManager(mockResourceMgr) | |||
mockResourceMgr.EXPECT().RegisterResource(gomock.Any()).AnyTimes().Return(nil) | |||
mockResourceMgr.EXPECT().CreateTableMetaCache(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes().Return(nil, nil) | |||
@@ -45,7 +47,7 @@ func Test_seataATDriver_Open(t *testing.T) { | |||
ctrl := gomock.NewController(t) | |||
defer ctrl.Finish() | |||
mockMgr := initMockResourceManager(t, ctrl) | |||
mockMgr := initMockResourceManager(branch.BranchTypeAT, ctrl) | |||
_ = mockMgr | |||
db, err := sql.Open("seata-at-mysql", "root:seata_go@tcp(127.0.0.1:3306)/seata_go_test?multiStatements=true") | |||
@@ -93,7 +95,7 @@ func Test_seataATDriver_OpenConnector(t *testing.T) { | |||
ctrl := gomock.NewController(t) | |||
defer ctrl.Finish() | |||
mockMgr := initMockResourceManager(t, ctrl) | |||
mockMgr := initMockResourceManager(branch.BranchTypeAT, ctrl) | |||
_ = mockMgr | |||
db, err := sql.Open("seata-at-mysql", "root:seata_go@tcp(127.0.0.1:3306)/seata_go_test?multiStatements=true") | |||
@@ -119,7 +121,7 @@ func Test_seataXADriver_OpenConnector(t *testing.T) { | |||
ctrl := gomock.NewController(t) | |||
defer ctrl.Finish() | |||
mockMgr := initMockResourceManager(t, ctrl) | |||
mockMgr := initMockResourceManager(branch.BranchTypeAT, ctrl) | |||
_ = mockMgr | |||
db, err := sql.Open("seata-xa-mysql", "root:seata_go@tcp(127.0.0.1:3306)/seata_go_test?multiStatements=true") | |||
@@ -514,8 +514,8 @@ func canAutoIncrement(pkMetaMap map[string]types.ColumnMeta) bool { | |||
return false | |||
} | |||
func (u *insertExecutor) isAstStmtValid() bool { | |||
return u.parserCtx != nil && u.parserCtx.InsertStmt != nil | |||
func (i *insertExecutor) isAstStmtValid() bool { | |||
return i.parserCtx != nil && i.parserCtx.InsertStmt != nil | |||
} | |||
func (i *insertExecutor) autoGeneratePks(execCtx *types.ExecContext, autoColumnName string, lastInsetId, updateCount int64) (map[string][]interface{}, error) { | |||
@@ -27,7 +27,6 @@ import ( | |||
var ( | |||
atExecutors = make(map[types.DBType]func() SQLExecutor) | |||
xaExecutors = make(map[types.DBType]func() SQLExecutor) | |||
) | |||
// RegisterATExecutor AT executor | |||
@@ -35,13 +34,6 @@ func RegisterATExecutor(dt types.DBType, builder func() SQLExecutor) { | |||
atExecutors[dt] = builder | |||
} | |||
// RegisterXAExecutor XA executor | |||
func RegisterXAExecutor(dt types.DBType, builder func() SQLExecutor) { | |||
xaExecutors[dt] = func() SQLExecutor { | |||
return builder() | |||
} | |||
} | |||
type ( | |||
CallbackWithNamedValue func(ctx context.Context, query string, args []driver.NamedValue) (types.ExecResult, error) | |||
@@ -66,12 +58,6 @@ func BuildExecutor(dbType types.DBType, transactionMode types.TransactionMode, q | |||
hooks = append(hooks, commonHook...) | |||
hooks = append(hooks, hookSolts[parseContext.SQLType]...) | |||
if transactionMode == types.XAMode { | |||
e := xaExecutors[dbType]() | |||
e.Interceptors(hooks) | |||
return e, nil | |||
} | |||
e := atExecutors[dbType]() | |||
e.Interceptors(hooks) | |||
return e, nil | |||
@@ -82,12 +68,10 @@ type BaseExecutor struct { | |||
ex SQLExecutor | |||
} | |||
// Interceptors | |||
func (e *BaseExecutor) Interceptors(interceptors []SQLHook) { | |||
e.hooks = interceptors | |||
} | |||
// ExecWithNamedValue | |||
func (e *BaseExecutor) ExecWithNamedValue(ctx context.Context, execCtx *types.ExecContext, f CallbackWithNamedValue) (types.ExecResult, error) { | |||
for i := range e.hooks { | |||
e.hooks[i].Before(ctx, execCtx) | |||
@@ -1,93 +0,0 @@ | |||
/* | |||
* Licensed to the Apache Software Foundation (ASF) under one or more | |||
* contributor license agreements. See the NOTICE file distributed with | |||
* this work for additional information regarding copyright ownership. | |||
* The ASF licenses this file to You under the Apache License, Version 2.0 | |||
* (the "License"); you may not use this file except in compliance with | |||
* the License. You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
package exec | |||
import ( | |||
"time" | |||
) | |||
const ( | |||
// TMENDICANT Ends a recovery scan. | |||
TMENDRSCAN = 0x00800000 | |||
/** | |||
* Disassociates the caller and marks the transaction branch | |||
* rollback-only. | |||
*/ | |||
TMFAIL = 0x20000000 | |||
/** | |||
* Caller is joining existing transaction branch. | |||
*/ | |||
TMJOIN = 0x00200000 | |||
/** | |||
* Use TMNOFLAGS to indicate no flags value is selected. | |||
*/ | |||
TMNOFLAGS = 0x00000000 | |||
/** | |||
* Caller is using one-phase optimization. | |||
*/ | |||
TMONEPHASE = 0x40000000 | |||
/** | |||
* Caller is resuming association with a suspended | |||
* transaction branch. | |||
*/ | |||
TMRESUME = 0x08000000 | |||
/** | |||
* Starts a recovery scan. | |||
*/ | |||
TMSTARTRSCAN = 0x01000000 | |||
/** | |||
* Disassociates caller from a transaction branch. | |||
*/ | |||
TMSUCCESS = 0x04000000 | |||
/** | |||
* Caller is suspending (not ending) its association with | |||
* a transaction branch. | |||
*/ | |||
TMSUSPEND = 0x02000000 | |||
/** | |||
* The transaction branch has been read-only and has been committed. | |||
*/ | |||
XA_RDONLY = 0x00000003 | |||
/** | |||
* The transaction work has been prepared normally. | |||
*/ | |||
XA_OK = 0 | |||
) | |||
type XAResource interface { | |||
Commit(xid string, onePhase bool) error | |||
End(xid string, flags int) error | |||
Forget(xid string) error | |||
GetTransactionTimeout() time.Duration | |||
IsSameRM(resource XAResource) bool | |||
XAPrepare(xid string) error | |||
Recover(flag int) ([]string, error) | |||
Rollback(xid string) error | |||
SetTransactionTimeout(duration time.Duration) bool | |||
Start(xid string, flags int) error | |||
} |
@@ -30,6 +30,7 @@ import ( | |||
"github.com/arana-db/parser/ast" | |||
"github.com/arana-db/parser/format" | |||
"github.com/arana-db/parser/model" | |||
"github.com/seata/seata-go/pkg/datasource/sql/datasource" | |||
"github.com/seata/seata-go/pkg/datasource/sql/types" | |||
"github.com/seata/seata-go/pkg/datasource/sql/undo/builder" | |||
@@ -1,84 +0,0 @@ | |||
/* | |||
* Licensed to the Apache Software Foundation (ASF) under one or more | |||
* contributor license agreements. See the NOTICE file distributed with | |||
* this work for additional information regarding copyright ownership. | |||
* The ASF licenses this file to You under the Apache License, Version 2.0 | |||
* (the "License"); you may not use this file except in compliance with | |||
* the License. You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
package xa | |||
import ( | |||
"context" | |||
"database/sql/driver" | |||
"github.com/seata/seata-go/pkg/datasource/sql/exec" | |||
"github.com/seata/seata-go/pkg/datasource/sql/types" | |||
) | |||
// XAExecutor The XA transaction manager. | |||
type XAExecutor struct { | |||
hooks []exec.SQLHook | |||
ex exec.SQLExecutor | |||
} | |||
// Interceptors set xa executor hooks | |||
func (e *XAExecutor) Interceptors(hooks []exec.SQLHook) { | |||
e.hooks = hooks | |||
} | |||
// ExecWithNamedValue | |||
func (e *XAExecutor) ExecWithNamedValue(ctx context.Context, execCtx *types.ExecContext, f exec.CallbackWithNamedValue) (types.ExecResult, error) { | |||
for _, hook := range e.hooks { | |||
hook.Before(ctx, execCtx) | |||
} | |||
defer func() { | |||
for _, hook := range e.hooks { | |||
hook.After(ctx, execCtx) | |||
} | |||
}() | |||
if e.ex != nil { | |||
return e.ex.ExecWithNamedValue(ctx, execCtx, f) | |||
} | |||
return f(ctx, execCtx.Query, execCtx.NamedValues) | |||
} | |||
// ExecWithValue | |||
func (e *XAExecutor) ExecWithValue(ctx context.Context, execCtx *types.ExecContext, f exec.CallbackWithNamedValue) (types.ExecResult, error) { | |||
for _, hook := range e.hooks { | |||
hook.Before(ctx, execCtx) | |||
} | |||
defer func() { | |||
for _, hook := range e.hooks { | |||
hook.After(ctx, execCtx) | |||
} | |||
}() | |||
if e.ex != nil { | |||
return e.ex.ExecWithValue(ctx, execCtx, f) | |||
} | |||
nvargs := make([]driver.NamedValue, len(execCtx.Values)) | |||
for i, value := range execCtx.Values { | |||
nvargs = append(nvargs, driver.NamedValue{ | |||
Value: value, | |||
Ordinal: i, | |||
}) | |||
} | |||
execCtx.NamedValues = nvargs | |||
return f(ctx, execCtx.Query, execCtx.NamedValues) | |||
} |
@@ -38,6 +38,7 @@ import ( | |||
type MockDataSourceManager struct { | |||
ctrl *gomock.Controller | |||
recorder *MockDataSourceManagerMockRecorder | |||
branchType branch.BranchType | |||
} | |||
// MockDataSourceManagerMockRecorder is the mock recorder for MockDataSourceManager. | |||
@@ -131,9 +132,13 @@ func (mr *MockDataSourceManagerMockRecorder) CreateTableMetaCache(ctx, resID, db | |||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateTableMetaCache", reflect.TypeOf((*MockDataSourceManager)(nil).CreateTableMetaCache), ctx, resID, dbType, db) | |||
} | |||
func (m *MockDataSourceManager) SetBranchType(branchType branch.BranchType) { | |||
m.branchType = branchType | |||
} | |||
// GetBranchType mocks base method. | |||
func (m *MockDataSourceManager) GetBranchType() branch.BranchType { | |||
return branch.BranchTypeAT | |||
return m.branchType | |||
} | |||
// GetBranchType indicates an expected call of GetBranchType. | |||
@@ -20,7 +20,6 @@ package sql | |||
import ( | |||
"github.com/seata/seata-go/pkg/datasource/sql/exec" | |||
"github.com/seata/seata-go/pkg/datasource/sql/exec/at" | |||
"github.com/seata/seata-go/pkg/datasource/sql/exec/xa" | |||
"github.com/seata/seata-go/pkg/datasource/sql/hook" | |||
"github.com/seata/seata-go/pkg/datasource/sql/types" | |||
"github.com/seata/seata-go/pkg/datasource/sql/undo" | |||
@@ -42,7 +41,6 @@ func hookRegister() { | |||
func executorRegister() { | |||
at.Init() | |||
xa.Init() | |||
} | |||
func undoInit() { | |||
@@ -1,31 +0,0 @@ | |||
/* | |||
* Licensed to the Apache Software Foundation (ASF) under one or more | |||
* contributor license agreements. See the NOTICE file distributed with | |||
* this work for additional information regarding copyright ownership. | |||
* The ASF licenses this file to You under the Apache License, Version 2.0 | |||
* (the "License"); you may not use this file except in compliance with | |||
* the License. You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
package sql | |||
import ( | |||
"github.com/seata/seata-go/pkg/protocol/branch" | |||
) | |||
type RootContext interface { | |||
RootContext() | |||
SetDefaultBranchType(branchType branch.BranchType) | |||
GetXID() string | |||
Bind(xid string) | |||
GetTimeout() (int, bool) | |||
SetTimeout(timeout int) | |||
} |
@@ -22,17 +22,16 @@ import ( | |||
"database/sql/driver" | |||
"fmt" | |||
"sync" | |||
"time" | |||
"github.com/seata/seata-go/pkg/datasource/sql/datasource" | |||
"github.com/seata/seata-go/pkg/datasource/sql/types" | |||
"github.com/seata/seata-go/pkg/protocol/branch" | |||
"github.com/seata/seata-go/pkg/rm" | |||
"github.com/seata/seata-go/pkg/util/backoff" | |||
"github.com/seata/seata-go/pkg/util/log" | |||
"github.com/seata/seata-go/pkg/datasource/sql/types" | |||
) | |||
const REPORT_RETRY_COUNT = 5 | |||
var ( | |||
hl sync.RWMutex | |||
txHooks []txHook | |||
@@ -146,19 +145,31 @@ func (tx *Tx) commitOnLocal() error { | |||
// register | |||
func (tx *Tx) register(ctx *types.TransactionContext) error { | |||
if !ctx.HasUndoLog() || !ctx.HasLockKey() { | |||
if ctx.TransactionMode.BranchType() == branch.BranchTypeUnknow { | |||
return nil | |||
} | |||
lockKey := "" | |||
for k, _ := range ctx.LockKeys { | |||
lockKey += k + ";" | |||
if ctx.TransactionMode.BranchType() == branch.BranchTypeAT && !ctx.HasUndoLog() || !ctx.HasLockKey() { | |||
return nil | |||
} | |||
request := rm.BranchRegisterParam{ | |||
Xid: ctx.XID, | |||
BranchType: ctx.TransactionMode.BranchType(), | |||
ResourceId: ctx.ResourceID, | |||
LockKeys: lockKey, | |||
} | |||
var lockKey string | |||
if ctx.TransactionMode == types.ATMode { | |||
if !ctx.HasUndoLog() || !ctx.HasLockKey() { | |||
return nil | |||
} | |||
for k, _ := range ctx.LockKeys { | |||
lockKey += k + ";" | |||
} | |||
request.LockKeys = lockKey | |||
} | |||
dataSourceManager := datasource.GetDataSourceManager(ctx.TransactionMode.BranchType()) | |||
branchId, err := dataSourceManager.BranchRegister(context.Background(), request) | |||
if err != nil { | |||
@@ -184,21 +195,22 @@ func (tx *Tx) report(success bool) error { | |||
if dataSourceManager == nil { | |||
return fmt.Errorf("get dataSourceManager failed") | |||
} | |||
retry := REPORT_RETRY_COUNT | |||
for retry > 0 { | |||
err := dataSourceManager.BranchReport(context.Background(), request) | |||
if err != nil { | |||
retry-- | |||
log.Infof("Failed to report [%s / %s] commit done [%s] Retry Countdown: %s", tx.tranCtx.BranchID, tx.tranCtx.XID, success, retry) | |||
if retry == 0 { | |||
log.Errorf("Failed to report branch status: %s", err.Error()) | |||
return err | |||
} | |||
} else { | |||
return nil | |||
retry := backoff.New(context.Background(), backoff.Config{ | |||
MinBackoff: 100 * time.Millisecond, | |||
MaxBackoff: 200 * time.Millisecond, | |||
MaxRetries: 5, | |||
}) | |||
var err error | |||
for retry.Ongoing() { | |||
if err = dataSourceManager.BranchReport(context.Background(), request); err == nil { | |||
break | |||
} | |||
log.Infof("Failed to report [%s / %s] commit done [%s] Retry Countdown: %s", tx.tranCtx.BranchID, tx.tranCtx.XID, success, retry) | |||
retry.Wait() | |||
} | |||
return nil | |||
return err | |||
} | |||
func getStatus(success bool) branch.BranchStatus { | |||
@@ -19,6 +19,7 @@ package sql | |||
import ( | |||
"github.com/pkg/errors" | |||
"github.com/seata/seata-go/pkg/datasource/sql/undo" | |||
) | |||
@@ -17,7 +17,6 @@ | |||
package sql | |||
// XATx | |||
type XATx struct { | |||
tx *Tx | |||
} | |||
@@ -32,20 +31,14 @@ func (tx *XATx) Commit() error { | |||
} | |||
func (tx *XATx) Rollback() error { | |||
err := tx.tx.Rollback() | |||
if err != nil { | |||
originTx := tx.tx | |||
if originTx.tranCtx.OpenGlobalTransaction() && originTx.tranCtx.IsBranchRegistered() { | |||
originTx.report(false) | |||
} | |||
originTx := tx.tx | |||
if originTx.tranCtx.OpenGlobalTransaction() && originTx.tranCtx.IsBranchRegistered() { | |||
return originTx.report(false) | |||
} | |||
return err | |||
return nil | |||
} | |||
// commitOnXA | |||
// commitOnXA commit xa and register branch transaction | |||
func (tx *XATx) commitOnXA() error { | |||
return nil | |||
} |
@@ -22,9 +22,9 @@ import ( | |||
"fmt" | |||
"strings" | |||
"github.com/seata/seata-go/pkg/protocol/branch" | |||
"github.com/google/uuid" | |||
"github.com/seata/seata-go/pkg/protocol/branch" | |||
) | |||
type DBType int16 | |||
@@ -76,6 +76,7 @@ const ( | |||
DBTypePostgreSQL | |||
DBTypeSQLServer | |||
DBTypeOracle | |||
DBTypeMARIADB | |||
BranchPhase_Unknown = 0 | |||
BranchPhase_Done = 1 | |||
@@ -28,8 +28,10 @@ import ( | |||
"database/sql/driver" | |||
"errors" | |||
"fmt" | |||
"math" | |||
"reflect" | |||
"strconv" | |||
"strings" | |||
"time" | |||
) | |||
@@ -375,3 +377,30 @@ type decimalCompose interface { | |||
// represented then an error should be returned. | |||
Compose(form byte, negative bool, coefficient []byte, exponent int32) error | |||
} | |||
// ConvertDbVersion convert a string db version to a number. | |||
func ConvertDbVersion(version string) (int, error) { | |||
parts := strings.Split(version, ".") | |||
size := len(parts) | |||
maxVersionDot := 3 | |||
if size > maxVersionDot+1 { | |||
return 0, fmt.Errorf("incompatible version format: %s", version) | |||
} | |||
var res int | |||
for idx, part := range parts { | |||
if partInt, err := strconv.Atoi(part); err == nil { | |||
res += calculatePartValue(partInt, size, idx) | |||
} else { | |||
subParts := strings.Split(part, "-") | |||
if subPartInt, err := strconv.Atoi(subParts[0]); err == nil { | |||
res += calculatePartValue(subPartInt, size, idx) | |||
} | |||
} | |||
} | |||
return res, nil | |||
} | |||
func calculatePartValue(partNumeric, size, index int) int { | |||
return partNumeric * int(math.Pow(100, float64(size-index))) | |||
} |
@@ -15,22 +15,27 @@ | |||
* limitations under the License. | |||
*/ | |||
package sql | |||
package util | |||
import ( | |||
"testing" | |||
"github.com/seata/seata-go/pkg/datasource/sql/exec" | |||
"github.com/seata/seata-go/pkg/datasource/sql/exec/xa" | |||
"github.com/seata/seata-go/pkg/datasource/sql/types" | |||
"github.com/stretchr/testify/assert" | |||
) | |||
func TestConn_BuildXAExecutor(t *testing.T) { | |||
executor, err := exec.BuildExecutor(types.DBTypeMySQL, types.XAMode, "SELECT * FROM user") | |||
func TestConvertDbVersion(t *testing.T) { | |||
version1 := "3.1.2" | |||
v1Int, err1 := ConvertDbVersion(version1) | |||
assert.NoError(t, err1) | |||
assert.NoError(t, err) | |||
version2 := "3.1.3" | |||
v2Int, err2 := ConvertDbVersion(version2) | |||
assert.NoError(t, err2) | |||
_, ok := executor.(*xa.XAExecutor) | |||
assert.True(t, ok, "need xa executor") | |||
assert.Less(t, v1Int, v2Int) | |||
version3 := "3.1.3" | |||
v3Int, err3 := ConvertDbVersion(version3) | |||
assert.NoError(t, err3) | |||
assert.Equal(t, v2Int, v3Int) | |||
} |
@@ -15,24 +15,27 @@ | |||
* limitations under the License. | |||
*/ | |||
package exec | |||
package xa | |||
import ( | |||
"context" | |||
"database/sql/driver" | |||
"errors" | |||
"fmt" | |||
"io" | |||
"strings" | |||
"time" | |||
"github.com/pkg/errors" | |||
) | |||
type MysqlXAConn struct { | |||
driver.Conn | |||
} | |||
func (c *MysqlXAConn) Commit(xid string, onePhase bool) error { | |||
func NewMysqlXaConn(conn driver.Conn) *MysqlXAConn { | |||
return &MysqlXAConn{Conn: conn} | |||
} | |||
func (c *MysqlXAConn) Commit(ctx context.Context, xid string, onePhase bool) error { | |||
var sb strings.Builder | |||
sb.WriteString("XA COMMIT ") | |||
sb.WriteString(xid) | |||
@@ -41,33 +44,33 @@ func (c *MysqlXAConn) Commit(xid string, onePhase bool) error { | |||
} | |||
conn, _ := c.Conn.(driver.ExecerContext) | |||
_, err := conn.ExecContext(context.TODO(), sb.String(), nil) | |||
_, err := conn.ExecContext(ctx, sb.String(), nil) | |||
return err | |||
} | |||
func (c *MysqlXAConn) End(xid string, flags int) error { | |||
func (c *MysqlXAConn) End(ctx context.Context, xid string, flags int) error { | |||
var sb strings.Builder | |||
sb.WriteString("XA END ") | |||
sb.WriteString(xid) | |||
switch flags { | |||
case TMSUCCESS: | |||
case TMSuccess: | |||
break | |||
case TMSUSPEND: | |||
case TMSuspend: | |||
sb.WriteString(" SUSPEND") | |||
break | |||
case TMFAIL: | |||
case TMFail: | |||
break | |||
default: | |||
return errors.New("invalid arguments") | |||
} | |||
conn, _ := c.Conn.(driver.ExecerContext) | |||
_, err := conn.ExecContext(context.TODO(), sb.String(), nil) | |||
_, err := conn.ExecContext(ctx, sb.String(), nil) | |||
return err | |||
} | |||
func (c *MysqlXAConn) Forget(xid string) error { | |||
func (c *MysqlXAConn) Forget(ctx context.Context, xid string) error { | |||
// mysql doesn't support this | |||
return errors.New("mysql doesn't support this") | |||
} | |||
@@ -78,29 +81,29 @@ func (c *MysqlXAConn) GetTransactionTimeout() time.Duration { | |||
// IsSameRM is called to determine if the resource manager instance represented by the target object | |||
// is the same as the resource manager instance represented by the parameter xares. | |||
func (c *MysqlXAConn) IsSameRM(xares XAResource) bool { | |||
func (c *MysqlXAConn) IsSameRM(ctx context.Context, xares XAResource) bool { | |||
// todo: the fn depends on the driver.Conn, but it doesn't support | |||
return false | |||
} | |||
func (c *MysqlXAConn) XAPrepare(xid string) error { | |||
func (c *MysqlXAConn) XAPrepare(ctx context.Context, xid string) error { | |||
var sb strings.Builder | |||
sb.WriteString("XA PREPARE ") | |||
sb.WriteString(xid) | |||
conn, _ := c.Conn.(driver.ExecerContext) | |||
_, err := conn.ExecContext(context.TODO(), sb.String(), nil) | |||
_, err := conn.ExecContext(ctx, sb.String(), nil) | |||
return err | |||
} | |||
// Recover Obtains a list of prepared transaction branches from a resource manager. | |||
// The transaction manager calls this method during recovery to obtain the list of transaction branches | |||
// that are currently in prepared or heuristically completed states. | |||
func (c *MysqlXAConn) Recover(flag int) (xids []string, err error) { | |||
startRscan := (flag & TMSTARTRSCAN) > 0 | |||
endRscan := (flag & TMENDRSCAN) > 0 | |||
func (c *MysqlXAConn) Recover(ctx context.Context, flag int) (xids []string, err error) { | |||
startRscan := (flag & TMStartRScan) > 0 | |||
endRscan := (flag & TMEndRScan) > 0 | |||
if !startRscan && !endRscan && flag != TMNOFLAGS { | |||
if !startRscan && !endRscan && flag != TMNoFlags { | |||
return nil, errors.New("invalid arguments") | |||
} | |||
@@ -109,7 +112,7 @@ func (c *MysqlXAConn) Recover(flag int) (xids []string, err error) { | |||
} | |||
conn := c.Conn.(driver.QueryerContext) | |||
res, err := conn.QueryContext(context.TODO(), "XA RECOVER", nil) | |||
res, err := conn.QueryContext(ctx, "XA RECOVER", nil) | |||
if err != nil { | |||
return nil, err | |||
} | |||
@@ -133,13 +136,13 @@ func (c *MysqlXAConn) Recover(flag int) (xids []string, err error) { | |||
return xids, err | |||
} | |||
func (c *MysqlXAConn) Rollback(xid string) error { | |||
func (c *MysqlXAConn) Rollback(ctx context.Context, xid string) error { | |||
var sb strings.Builder | |||
sb.WriteString("XA ROLLBACK ") | |||
sb.WriteString(xid) | |||
conn, _ := c.Conn.(driver.ExecerContext) | |||
_, err := conn.ExecContext(context.TODO(), sb.String(), nil) | |||
_, err := conn.ExecContext(ctx, sb.String(), nil) | |||
return err | |||
} | |||
@@ -147,25 +150,25 @@ func (c *MysqlXAConn) SetTransactionTimeout(duration time.Duration) bool { | |||
return false | |||
} | |||
func (c *MysqlXAConn) Start(xid string, flags int) error { | |||
func (c *MysqlXAConn) Start(ctx context.Context, xid string, flags int) error { | |||
var sb strings.Builder | |||
sb.WriteString("XA START") | |||
sb.WriteString(xid) | |||
switch flags { | |||
case TMJOIN: | |||
case TMJoin: | |||
sb.WriteString(" JOIN") | |||
break | |||
case TMRESUME: | |||
case TMResume: | |||
sb.WriteString(" RESUME") | |||
break | |||
case TMNOFLAGS: | |||
case TMNoFlags: | |||
break | |||
default: | |||
return errors.New("invalid arguments") | |||
} | |||
conn, _ := c.Conn.(driver.ExecerContext) | |||
_, err := conn.ExecContext(context.TODO(), sb.String(), nil) | |||
_, err := conn.ExecContext(ctx, sb.String(), nil) | |||
return err | |||
} |
@@ -15,19 +15,18 @@ | |||
* limitations under the License. | |||
*/ | |||
package exec | |||
package xa | |||
import ( | |||
"context" | |||
"database/sql/driver" | |||
"fmt" | |||
"errors" | |||
"io" | |||
"reflect" | |||
"strings" | |||
"testing" | |||
"github.com/golang/mock/gomock" | |||
"github.com/pkg/errors" | |||
"github.com/seata/seata-go/pkg/datasource/sql/mock" | |||
) | |||
@@ -78,7 +77,7 @@ func TestMysqlXAConn_Commit(t *testing.T) { | |||
c := &MysqlXAConn{ | |||
Conn: mockConn, | |||
} | |||
if err := c.Commit(tt.input.xid, tt.input.onePhase); (err != nil) != tt.wantErr { | |||
if err := c.Commit(context.Background(), tt.input.xid, tt.input.onePhase); (err != nil) != tt.wantErr { | |||
t.Errorf("Commit() error = %v, wantErr %v", err, tt.wantErr) | |||
} | |||
}) | |||
@@ -102,7 +101,7 @@ func TestMysqlXAConn_End(t *testing.T) { | |||
name: "tm success", | |||
input: args{ | |||
xid: "xid", | |||
flags: TMSUCCESS, | |||
flags: TMSuccess, | |||
}, | |||
wantErr: false, | |||
}, | |||
@@ -110,7 +109,7 @@ func TestMysqlXAConn_End(t *testing.T) { | |||
name: "tm failed", | |||
input: args{ | |||
xid: "xid", | |||
flags: TMFAIL, | |||
flags: TMFail, | |||
}, | |||
wantErr: false, | |||
}, | |||
@@ -124,7 +123,7 @@ func TestMysqlXAConn_End(t *testing.T) { | |||
c := &MysqlXAConn{ | |||
Conn: mockConn, | |||
} | |||
if err := c.End(tt.input.xid, tt.input.flags); (err != nil) != tt.wantErr { | |||
if err := c.End(context.Background(), tt.input.xid, tt.input.flags); (err != nil) != tt.wantErr { | |||
t.Errorf("End() error = %v, wantErr %v", err, tt.wantErr) | |||
} | |||
}) | |||
@@ -148,7 +147,7 @@ func TestMysqlXAConn_Start(t *testing.T) { | |||
name: "normal start", | |||
input: args{ | |||
xid: "xid", | |||
flags: TMNOFLAGS, | |||
flags: TMNoFlags, | |||
}, | |||
wantErr: false, | |||
}, | |||
@@ -161,7 +160,7 @@ func TestMysqlXAConn_Start(t *testing.T) { | |||
c := &MysqlXAConn{ | |||
Conn: mockConn, | |||
} | |||
if err := c.Start(tt.input.xid, tt.input.flags); (err != nil) != tt.wantErr { | |||
if err := c.Start(context.Background(), tt.input.xid, tt.input.flags); (err != nil) != tt.wantErr { | |||
t.Errorf("Start() error = %v, wantErr %v", err, tt.wantErr) | |||
} | |||
}) | |||
@@ -196,7 +195,7 @@ func TestMysqlXAConn_XAPrepare(t *testing.T) { | |||
c := &MysqlXAConn{ | |||
Conn: mockConn, | |||
} | |||
if err := c.XAPrepare(tt.input.xid); (err != nil) != tt.wantErr { | |||
if err := c.XAPrepare(context.Background(), tt.input.xid); (err != nil) != tt.wantErr { | |||
t.Errorf("XAPrepare() error = %v, wantErr %v", err, tt.wantErr) | |||
} | |||
}) | |||
@@ -219,7 +218,7 @@ func TestMysqlXAConn_Recover(t *testing.T) { | |||
{ | |||
name: "normal recover", | |||
args: args{ | |||
flag: TMSTARTRSCAN | TMENDRSCAN, | |||
flag: TMStartRScan | TMEndRScan, | |||
}, | |||
want: []string{"xid", "another_xid"}, | |||
wantErr: false, | |||
@@ -227,14 +226,14 @@ func TestMysqlXAConn_Recover(t *testing.T) { | |||
{ | |||
name: "invalid flag for recover", | |||
args: args{ | |||
flag: TMFAIL, | |||
flag: TMFail, | |||
}, | |||
wantErr: true, | |||
}, | |||
{ | |||
name: "valid flag for recover but don't scan", | |||
args: args{ | |||
flag: TMENDRSCAN, | |||
flag: TMEndRScan, | |||
}, | |||
want: nil, | |||
wantErr: false, | |||
@@ -257,7 +256,7 @@ func TestMysqlXAConn_Recover(t *testing.T) { | |||
c := &MysqlXAConn{ | |||
Conn: mockConn, | |||
} | |||
got, err := c.Recover(tt.args.flag) | |||
got, err := c.Recover(context.Background(), tt.args.flag) | |||
if (err != nil) != tt.wantErr { | |||
t.Errorf("Recover() error = %v, wantErr %v", err, tt.wantErr) | |||
return | |||
@@ -295,10 +294,7 @@ func (m *mysqlMockRows) Next(dest []driver.Value) error { | |||
} | |||
return b | |||
} | |||
cnt := min(len(m.data[0]), len(dest)) | |||
fmt.Printf("cnt: %d", cnt) | |||
for i := 0; i < cnt; i++ { | |||
dest[i] = m.data[m.idx][i] | |||
} |
@@ -15,7 +15,7 @@ | |||
* limitations under the License. | |||
*/ | |||
package conn | |||
package xa | |||
import ( | |||
"context" |
@@ -15,7 +15,7 @@ | |||
* limitations under the License. | |||
*/ | |||
package conn | |||
package xa | |||
import ( | |||
"database/sql/driver" |
@@ -1,26 +0,0 @@ | |||
/* | |||
* Licensed to the Apache Software Foundation (ASF) under one or more | |||
* contributor license agreements. See the NOTICE file distributed with | |||
* this work for additional information regarding copyright ownership. | |||
* The ASF licenses this file to You under the Apache License, Version 2.0 | |||
* (the "License"); you may not use this file except in compliance with | |||
* the License. You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
package xa | |||
import ( | |||
"github.com/seata/seata-go/pkg/datasource/sql/exec" | |||
) | |||
type XAConnection interface { | |||
getXAResource() (exec.XAResource, error) | |||
} |
@@ -1,325 +0,0 @@ | |||
/* | |||
* Licensed to the Apache Software Foundation (ASF) under one or more | |||
* contributor license agreements. See the NOTICE file distributed with | |||
* this work for additional information regarding copyright ownership. | |||
* The ASF licenses this file to You under the Apache License, Version 2.0 | |||
* (the "License"); you may not use this file except in compliance with | |||
* the License. You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
package xa | |||
import ( | |||
"context" | |||
"fmt" | |||
"time" | |||
"github.com/seata/seata-go/pkg/datasource/sql" | |||
"github.com/seata/seata-go/pkg/datasource/sql/datasource" | |||
"github.com/seata/seata-go/pkg/datasource/sql/exec" | |||
"github.com/seata/seata-go/pkg/protocol/branch" | |||
"github.com/seata/seata-go/pkg/protocol/message" | |||
"github.com/seata/seata-go/pkg/rm" | |||
) | |||
type ConnectionProxyXA struct { | |||
xaBranchXid *XABranchXid | |||
currentAutoCommitStatus bool `default:"true"` | |||
xaActive bool `default:"false"` | |||
kept bool `default:"false"` | |||
rollBacked bool `default:"false"` | |||
branchRegisterTime int64 `default:"0"` | |||
prepareTime int64 `default:"0"` | |||
timeout int `default:"0"` | |||
proxyShouldBeHeld bool `default:"false"` | |||
originalConnection sql.Conn | |||
xaConnection XAConnection | |||
xaResource exec.XAResource | |||
resource sql.BaseDataSourceResource | |||
xid string | |||
} | |||
const timeout int = 60000 | |||
func NewConnectionProxyXA(originalConnection sql.Conn, xaConnection XAConnection, resource sql.BaseDataSourceResource, xid string) (*ConnectionProxyXA, error) { | |||
connectionProxyXA := &ConnectionProxyXA{} | |||
connectionProxyXA.originalConnection = originalConnection | |||
connectionProxyXA.xaConnection = xaConnection | |||
connectionProxyXA.resource = resource | |||
connectionProxyXA.xid = xid | |||
connectionProxyXA.proxyShouldBeHeld = connectionProxyXA.resource.IsShouldBeHeld() | |||
xaResource, err := xaConnection.getXAResource() | |||
if err != nil { | |||
return nil, fmt.Errorf("get xa resource failed") | |||
} else { | |||
connectionProxyXA.xaResource = xaResource | |||
} | |||
var rootContext sql.RootContext | |||
transactionTimeout, ok := rootContext.GetTimeout() | |||
if !ok { | |||
transactionTimeout = timeout | |||
} | |||
if transactionTimeout < timeout { | |||
transactionTimeout = timeout | |||
} | |||
connectionProxyXA.timeout = transactionTimeout | |||
connectionProxyXA.currentAutoCommitStatus = connectionProxyXA.originalConnection.GetAutoCommit() | |||
if !connectionProxyXA.currentAutoCommitStatus { | |||
return nil, fmt.Errorf("connection[autocommit=false] as default is NOT supported") | |||
} | |||
return connectionProxyXA, nil | |||
} | |||
func (c *ConnectionProxyXA) keepIfNecessary() { | |||
if c.ShouldBeHeld() { | |||
c.resource.Hold(c.xaBranchXid.String(), c) | |||
} | |||
} | |||
func (c *ConnectionProxyXA) releaseIfNecessary() { | |||
if c.ShouldBeHeld() { | |||
if c.xaBranchXid == nil { | |||
if c.IsHeld() { | |||
c.resource.Release(c.xaBranchXid.String(), c) | |||
} | |||
} | |||
} | |||
} | |||
func (c *ConnectionProxyXA) XaCommit(xid string, branchId int64) error { | |||
xaXid := Build(xid, branchId) | |||
err := c.xaResource.Commit(xaXid.String(), false) | |||
c.releaseIfNecessary() | |||
return err | |||
} | |||
func (c *ConnectionProxyXA) XaRollbackByBranchId(xid string, branchId int64) { | |||
xaXid := Build(xid, branchId) | |||
c.XaRollback(xaXid) | |||
} | |||
func (c *ConnectionProxyXA) XaRollback(xaXid XAXid) error { | |||
err := c.xaResource.Rollback(xaXid.GetGlobalXid()) | |||
c.releaseIfNecessary() | |||
return err | |||
} | |||
func (c *ConnectionProxyXA) SetAutoCommit(autoCommit bool) error { | |||
if c.currentAutoCommitStatus == autoCommit { | |||
return nil | |||
} | |||
if autoCommit { | |||
if c.xaActive { | |||
_ = c.Commit() | |||
} | |||
} else { | |||
if c.xaActive { | |||
return fmt.Errorf("should NEVER happen: setAutoCommit from true to false while xa branch is active") | |||
} | |||
c.branchRegisterTime = time.Now().UnixMilli() | |||
var branchRegisterParam rm.BranchRegisterParam | |||
branchRegisterParam.BranchType = branch.BranchTypeXA | |||
branchRegisterParam.ResourceId = c.resource.GetResourceId() | |||
branchRegisterParam.Xid = c.xid | |||
branchId, err := datasource.GetDataSourceManager(branch.BranchTypeXA).BranchRegister(context.TODO(), branchRegisterParam) | |||
if err != nil { | |||
c.cleanXABranchContext() | |||
return fmt.Errorf("failed to register xa branch [%v]", c.xid) | |||
} | |||
c.xaBranchXid = Build(c.xid, branchId) | |||
c.keepIfNecessary() | |||
err = c.start() | |||
if err != nil { | |||
c.cleanXABranchContext() | |||
return fmt.Errorf("failed to start xa branch [%v]", c.xid) | |||
} | |||
c.xaActive = true | |||
} | |||
c.currentAutoCommitStatus = autoCommit | |||
return nil | |||
} | |||
func (c *ConnectionProxyXA) GetAutoCommit() bool { | |||
return c.currentAutoCommitStatus | |||
} | |||
func (c *ConnectionProxyXA) Commit() error { | |||
if c.currentAutoCommitStatus { | |||
return nil | |||
} | |||
if !c.xaActive || c.xaBranchXid == nil { | |||
return fmt.Errorf("should NOT commit on an inactive session") | |||
} | |||
now := time.Now().UnixMilli() | |||
if c.end(exec.TMSUCCESS) != nil { | |||
return c.commitErrorHandle() | |||
} | |||
if c.checkTimeout(now) != nil { | |||
return c.commitErrorHandle() | |||
} | |||
if c.xaResource.XAPrepare(c.xaBranchXid.String()) != nil { | |||
return c.commitErrorHandle() | |||
} | |||
return nil | |||
} | |||
func (c *ConnectionProxyXA) commitErrorHandle() error { | |||
req := message.BranchReportRequest{ | |||
BranchType: branch.BranchTypeXA, | |||
Xid: c.xid, | |||
BranchId: c.xaBranchXid.GetBranchId(), | |||
Status: branch.BranchStatusPhaseoneFailed, | |||
ApplicationData: nil, | |||
ResourceId: c.resource.GetResourceId(), | |||
} | |||
if datasource.NewBasicSourceManager().BranchReport(context.TODO(), req) != nil { | |||
c.cleanXABranchContext() | |||
return fmt.Errorf("Failed to report XA branch commit-failure on [%v] - [%v]", c.xid, c.xaBranchXid.GetBranchId()) | |||
} | |||
c.cleanXABranchContext() | |||
return fmt.Errorf("Failed to end(TMSUCCESS)/prepare xa branch on [%v] - [%v]", c.xid, c.xaBranchXid.GetBranchId()) | |||
} | |||
func (c *ConnectionProxyXA) Rollback() error { | |||
if c.currentAutoCommitStatus { | |||
return nil | |||
} | |||
if !c.xaActive || c.xaBranchXid == nil { | |||
return fmt.Errorf("should NOT rollback on an inactive session") | |||
} | |||
if !c.rollBacked { | |||
if c.xaResource.End(c.xaBranchXid.String(), exec.TMFAIL) != nil { | |||
return c.rollbackErrorHandle() | |||
} | |||
if c.XaRollback(c.xaBranchXid) != nil { | |||
c.cleanXABranchContext() | |||
return c.rollbackErrorHandle() | |||
} | |||
req := message.BranchReportRequest{ | |||
BranchType: branch.BranchTypeXA, | |||
Xid: c.xid, | |||
BranchId: c.xaBranchXid.GetBranchId(), | |||
Status: branch.BranchStatusPhaseoneFailed, | |||
ApplicationData: nil, | |||
ResourceId: c.resource.GetResourceId(), | |||
} | |||
if datasource.NewBasicSourceManager().BranchReport(context.TODO(), req) != nil { | |||
c.cleanXABranchContext() | |||
return fmt.Errorf("failed to report XA branch commit-failure on [%v] - [%v]", c.xid, c.xaBranchXid.GetBranchId()) | |||
} | |||
} | |||
c.cleanXABranchContext() | |||
return nil | |||
} | |||
func (c *ConnectionProxyXA) rollbackErrorHandle() error { | |||
return fmt.Errorf("failed to end(TMFAIL) xa branch on [%v] - [%v]", c.xid, c.xaBranchXid.GetBranchId()) | |||
} | |||
func (c *ConnectionProxyXA) start() error { | |||
err := c.xaResource.Start(c.xaBranchXid.String(), exec.TMNOFLAGS) | |||
if err := c.termination(c.xaBranchXid.String()); err != nil { | |||
c.xaResource.End(c.xaBranchXid.String(), exec.TMFAIL) | |||
c.XaRollback(c.xaBranchXid) | |||
return err | |||
} | |||
return err | |||
} | |||
func (c *ConnectionProxyXA) end(flags int) error { | |||
err := c.termination(c.xaBranchXid.String()) | |||
if err != nil { | |||
return err | |||
} | |||
err = c.xaResource.End(c.xaBranchXid.String(), flags) | |||
if err != nil { | |||
return err | |||
} | |||
return nil | |||
} | |||
func (c *ConnectionProxyXA) cleanXABranchContext() { | |||
c.branchRegisterTime = 0 | |||
c.prepareTime = 0 | |||
c.timeout = 0 | |||
c.xaActive = false | |||
if !c.IsHeld() { | |||
c.xaBranchXid = nil | |||
} | |||
} | |||
func (c *ConnectionProxyXA) checkTimeout(now int64) error { | |||
if now-c.branchRegisterTime > int64(c.timeout) { | |||
c.XaRollback(c.xaBranchXid) | |||
return fmt.Errorf("XA branch timeout error") | |||
} | |||
return nil | |||
} | |||
func (c *ConnectionProxyXA) Close() error { | |||
c.rollBacked = false | |||
if c.IsHeld() && c.ShouldBeHeld() { | |||
return nil | |||
} | |||
c.cleanXABranchContext() | |||
if err := c.originalConnection.Close(); err != nil { | |||
return err | |||
} | |||
return nil | |||
} | |||
func (c *ConnectionProxyXA) CloseForce() error { | |||
physicalConn := c.originalConnection | |||
if err := physicalConn.Close(); err != nil { | |||
return err | |||
} | |||
c.rollBacked = false | |||
c.cleanXABranchContext() | |||
if err := c.originalConnection.Close(); err != nil { | |||
return err | |||
} | |||
c.releaseIfNecessary() | |||
return nil | |||
} | |||
func (c *ConnectionProxyXA) SetHeld(kept bool) { | |||
c.kept = kept | |||
} | |||
func (c *ConnectionProxyXA) IsHeld() bool { | |||
return c.kept | |||
} | |||
func (c *ConnectionProxyXA) ShouldBeHeld() bool { | |||
return c.proxyShouldBeHeld || c.resource.GetDB() != nil | |||
} | |||
func (c *ConnectionProxyXA) GetPrepareTime() int64 { | |||
return c.prepareTime | |||
} | |||
func (c *ConnectionProxyXA) setPrepareTime(prepareTime int64) { | |||
c.prepareTime = prepareTime | |||
} | |||
func (c *ConnectionProxyXA) termination(xaBranchXid string) error { | |||
branchStatus, err := c.resource.GetBranchStatus(xaBranchXid) | |||
if err != nil { | |||
c.releaseIfNecessary() | |||
return fmt.Errorf("failed xa branch [%v] the global transaction has finish, branch status: [%v]", c.xid, branchStatus) | |||
} | |||
return nil | |||
} |
@@ -0,0 +1,71 @@ | |||
/* | |||
* Licensed to the Apache Software Foundation (ASF) under one or more | |||
* contributor license agreements. See the NOTICE file distributed with | |||
* this work for additional information regarding copyright ownership. | |||
* The ASF licenses this file to You under the Apache License, Version 2.0 | |||
* (the "License"); you may not use this file except in compliance with | |||
* the License. You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
package xa | |||
import ( | |||
"context" | |||
"time" | |||
) | |||
const ( | |||
// TMEndRScan ends a recovery scan. | |||
TMEndRScan = 0x00800000 | |||
// TMFail disassociates the caller and marks the transaction branch | |||
// rollback-only. | |||
TMFail = 0x20000000 | |||
// TMJoin joining existing transaction branch. | |||
TMJoin = 0x00200000 | |||
// TMNoFlags indicate no flags value is selected. | |||
TMNoFlags = 0x00000000 | |||
// TMOnePhase using one-phase optimization. | |||
TMOnePhase = 0x40000000 | |||
// TMResume is resuming association with a suspended transaction branch. | |||
TMResume = 0x08000000 | |||
// TMStartRScan starts a recovery scan. | |||
TMStartRScan = 0x01000000 | |||
// TMSuccess disassociates caller from a transaction branch. | |||
TMSuccess = 0x04000000 | |||
// TMSuspend is suspending (not ending) its association with a transaction branch. | |||
TMSuspend = 0x02000000 | |||
// XAReadOnly the transaction branch has been read-only and has been committed. | |||
XAReadOnly = 0x00000003 | |||
// XAOk The transaction work has been prepared normally. | |||
XAOk = 0 | |||
) | |||
type XAResource interface { | |||
Commit(ctx context.Context, xid string, onePhase bool) error | |||
End(ctx context.Context, xid string, flags int) error | |||
Forget(ctx context.Context, xid string) error | |||
GetTransactionTimeout() time.Duration | |||
IsSameRM(ctx context.Context, resource XAResource) bool | |||
XAPrepare(ctx context.Context, xid string) error | |||
Recover(ctx context.Context, flag int) ([]string, error) | |||
Rollback(ctx context.Context, xid string) error | |||
SetTransactionTimeout(duration time.Duration) bool | |||
Start(ctx context.Context, xid string, flags int) error | |||
} |
@@ -18,12 +18,31 @@ | |||
package xa | |||
import ( | |||
"github.com/seata/seata-go/pkg/datasource/sql/exec" | |||
"database/sql/driver" | |||
"fmt" | |||
"github.com/seata/seata-go/pkg/datasource/sql/types" | |||
"github.com/seata/seata-go/pkg/util/log" | |||
) | |||
func Init() { | |||
exec.RegisterXAExecutor(types.DBTypeMySQL, func() exec.SQLExecutor { | |||
return &XAExecutor{} | |||
}) | |||
// CreateXAResource create a connection for xa with the different db type. | |||
// Such as mysql, oracle, MARIADB, POSTGRESQL | |||
func CreateXAResource(conn driver.Conn, dbType types.DBType) (XAResource, error) { | |||
var err error | |||
var xaConnection XAResource | |||
switch dbType { | |||
case types.DBTypeMySQL: | |||
xaConnection = NewMysqlXaConn(conn) | |||
case types.DBTypeOracle: | |||
case types.DBTypePostgreSQL: | |||
default: | |||
err = fmt.Errorf("not support db type for :%s", dbType.String()) | |||
} | |||
if err != nil { | |||
log.Errorf(err.Error()) | |||
return nil, err | |||
} | |||
return xaConnection, nil | |||
} |
@@ -15,7 +15,7 @@ | |||
* limitations under the License. | |||
*/ | |||
package xa | |||
package sql | |||
import ( | |||
"strconv" | |||
@@ -23,13 +23,12 @@ import ( | |||
) | |||
const ( | |||
BranchIdPrefix = "-" | |||
SeataXaXidFormatId = 9752 | |||
branchIdPrefix = "-" | |||
) | |||
type XABranchXid struct { | |||
xid string | |||
branchId int64 | |||
branchId uint64 | |||
globalTransactionId []byte | |||
branchQualifier []byte | |||
} | |||
@@ -63,14 +62,10 @@ func (x *XABranchXid) GetGlobalXid() string { | |||
return x.xid | |||
} | |||
func (x *XABranchXid) GetBranchId() int64 { | |||
func (x *XABranchXid) GetBranchId() uint64 { | |||
return x.branchId | |||
} | |||
func (x *XABranchXid) GetFormatId() int { | |||
return SeataXaXidFormatId | |||
} | |||
func (x *XABranchXid) GetGlobalTransactionId() []byte { | |||
return x.globalTransactionId | |||
} | |||
@@ -80,7 +75,7 @@ func (x *XABranchXid) GetBranchQualifier() []byte { | |||
} | |||
func (x *XABranchXid) String() string { | |||
return x.xid + BranchIdPrefix + strconv.FormatInt(x.branchId, 10) | |||
return x.xid + branchIdPrefix + strconv.FormatUint(x.branchId, 10) | |||
} | |||
func WithXid(xid string) Option { | |||
@@ -89,7 +84,7 @@ func WithXid(xid string) Option { | |||
} | |||
} | |||
func WithBranchId(branchId int64) Option { | |||
func WithBranchId(branchId uint64) Option { | |||
return func(x *XABranchXid) { | |||
x.branchId = branchId | |||
} | |||
@@ -113,7 +108,7 @@ func encode(x *XABranchXid) { | |||
} | |||
if x.branchId != 0 { | |||
x.branchQualifier = []byte(BranchIdPrefix + strconv.FormatInt(x.branchId, 10)) | |||
x.branchQualifier = []byte(branchIdPrefix + strconv.FormatUint(x.branchId, 10)) | |||
} | |||
} | |||
@@ -123,7 +118,7 @@ func decode(x *XABranchXid) { | |||
} | |||
if len(x.branchQualifier) > 0 { | |||
branchId := strings.TrimLeft(string(x.branchQualifier), BranchIdPrefix) | |||
x.branchId, _ = strconv.ParseInt(branchId, 10, 64) | |||
branchId := strings.TrimLeft(string(x.branchQualifier), branchIdPrefix) | |||
x.branchId, _ = strconv.ParseUint(branchId, 10, 64) | |||
} | |||
} |
@@ -15,7 +15,7 @@ | |||
* limitations under the License. | |||
*/ | |||
package xa | |||
package sql | |||
import ( | |||
"testing" | |||
@@ -25,8 +25,8 @@ import ( | |||
func TestXABranchXidBuild(t *testing.T) { | |||
xid := "111" | |||
branchId := int64(222) | |||
x := Build(xid, branchId) | |||
branchId := uint64(222) | |||
x := XaIdBuild(xid, branchId) | |||
assert.Equal(t, x.GetGlobalXid(), xid) | |||
assert.Equal(t, x.GetBranchId(), branchId) | |||
@@ -36,11 +36,11 @@ func TestXABranchXidBuild(t *testing.T) { | |||
func TestXABranchXidBuildWithByte(t *testing.T) { | |||
xid := []byte("111") | |||
branchId := []byte(BranchIdPrefix + "222") | |||
x := BuildWithByte(xid, branchId) | |||
branchId := []byte(branchIdPrefix + "222") | |||
x := XaIdBuildWithByte(xid, branchId) | |||
assert.Equal(t, x.GetGlobalTransactionId(), xid) | |||
assert.Equal(t, x.GetBranchQualifier(), branchId) | |||
assert.Equal(t, x.GetGlobalXid(), "111") | |||
assert.Equal(t, x.GetBranchId(), int64(222)) | |||
assert.Equal(t, x.GetBranchId(), uint64(222)) | |||
} |
@@ -0,0 +1,245 @@ | |||
/* | |||
* Licensed to the Apache Software Foundation (ASF) under one or more | |||
* contributor license agreements. See the NOTICE file distributed with | |||
* this work for additional information regarding copyright ownership. | |||
* The ASF licenses this file to You under the Apache License, Version 2.0 | |||
* (the "License"); you may not use this file except in compliance with | |||
* the License. You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
package sql | |||
import ( | |||
"context" | |||
"database/sql" | |||
"errors" | |||
"flag" | |||
"fmt" | |||
"sync" | |||
"time" | |||
"github.com/bluele/gcache" | |||
"github.com/seata/seata-go/pkg/datasource/sql/datasource" | |||
"github.com/seata/seata-go/pkg/datasource/sql/types" | |||
"github.com/seata/seata-go/pkg/protocol/branch" | |||
"github.com/seata/seata-go/pkg/rm" | |||
"github.com/seata/seata-go/pkg/util/log" | |||
) | |||
var branchStatusCache gcache.Cache | |||
type XAConnConf struct { | |||
XaBranchExecutionTimeout time.Duration `json:"xa_branch_execution_timeout" xml:"xa_branch_execution_timeout" koanf:"xa_branch_execution_timeout"` | |||
} | |||
func (cfg *XAConnConf) RegisterFlagsWithPrefix(prefix string, f *flag.FlagSet) { | |||
f.DurationVar(&cfg.XaBranchExecutionTimeout, prefix+".xa_branch_execution_timeout", time.Minute, "Undo log table name.") | |||
} | |||
type XAConfig struct { | |||
xaConnConf XAConnConf | |||
TwoPhaseHoldTime time.Duration `json:"two_phase_hold_time" yaml:"xa_two_phase_hold_time" koanf:"xa_two_phase_hold_time"` | |||
} | |||
func (cfg *XAConfig) RegisterFlagsWithPrefix(prefix string, f *flag.FlagSet) { | |||
f.DurationVar(&cfg.TwoPhaseHoldTime, prefix+".two_phase_hold_time", time.Millisecond*1000, "Undo log table name.") | |||
cfg.xaConnConf.RegisterFlagsWithPrefix(prefix, f) | |||
} | |||
func InitXA(config XAConfig) *XAResourceManager { | |||
xaSourceManager := &XAResourceManager{ | |||
resourceCache: sync.Map{}, | |||
basic: datasource.NewBasicSourceManager(), | |||
rmRemoting: rm.GetRMRemotingInstance(), | |||
config: config, | |||
} | |||
xaConnTimeout = config.xaConnConf.XaBranchExecutionTimeout | |||
branchStatusCache = gcache.New(1024).LRU().Expiration(time.Minute * 10).Build() | |||
rm.GetRmCacheInstance().RegisterResourceManager(xaSourceManager) | |||
go xaSourceManager.xaTwoPhaseTimeoutChecker() | |||
return xaSourceManager | |||
} | |||
type XAResourceManager struct { | |||
config XAConfig | |||
resourceCache sync.Map | |||
basic *datasource.BasicSourceManager | |||
rmRemoting *rm.RMRemoting | |||
} | |||
func (xaManager *XAResourceManager) xaTwoPhaseTimeoutChecker() { | |||
var dbResource *DBResource | |||
xaManager.resourceCache.Range(func(key, value any) bool { | |||
if source, ok := value.(*DBResource); ok { | |||
dbResource = source | |||
} | |||
return false | |||
}) | |||
if dbResource.IsShouldBeHeld() { | |||
ticker := time.NewTicker(time.Second) | |||
for { | |||
select { | |||
case <-ticker.C: | |||
xaManager.resourceCache.Range(func(key, value any) bool { | |||
source, ok := value.(*DBResource) | |||
if !ok { | |||
return true | |||
} | |||
if source.IsShouldBeHeld() { | |||
return true | |||
} | |||
source.GetKeeper().Range(func(key, value any) bool { | |||
connectionXA, isConnectionXA := value.(*XAConn) | |||
if !isConnectionXA { | |||
return true | |||
} | |||
if time.Now().Sub(connectionXA.prepareTime) > xaManager.config.TwoPhaseHoldTime { | |||
if err := connectionXA.CloseForce(); err != nil { | |||
log.Errorf("Force close the xa xid:%s physical connection fail", connectionXA.txCtx.XID) | |||
} | |||
} | |||
return true | |||
}) | |||
return true | |||
}) | |||
} | |||
} | |||
} | |||
} | |||
func (xaManager *XAResourceManager) GetBranchType() branch.BranchType { | |||
return branch.BranchTypeXA | |||
} | |||
func (xaManager *XAResourceManager) GetCachedResources() *sync.Map { | |||
return &xaManager.resourceCache | |||
} | |||
func (xaManager *XAResourceManager) RegisterResource(res rm.Resource) error { | |||
xaManager.resourceCache.Store(res.GetResourceId(), res) | |||
return xaManager.basic.RegisterResource(res) | |||
} | |||
func (xaManager *XAResourceManager) UnregisterResource(resource rm.Resource) error { | |||
return xaManager.basic.UnregisterResource(resource) | |||
} | |||
func (xaManager *XAResourceManager) xaIDBuilder(xid string, branchId uint64) XAXid { | |||
return XaIdBuild(xid, branchId) | |||
} | |||
func (xaManager *XAResourceManager) finishBranch(ctx context.Context, xaID XAXid, branchResource rm.BranchResource) (*XAConn, error) { | |||
resource, ok := xaManager.resourceCache.Load(branchResource.ResourceId) | |||
if !ok { | |||
err := fmt.Errorf("unknow resource for rollback xa, resourceId: %s", branchResource.ResourceId) | |||
log.Errorf(err.Error()) | |||
return nil, err | |||
} | |||
dbResource, ok := resource.(*DBResource) | |||
if !ok { | |||
err := fmt.Errorf("unknow resource for rollback xa, resourceId: %s", branchResource.ResourceId) | |||
log.Errorf(err.Error()) | |||
return nil, err | |||
} | |||
connectionProxyXA, err := dbResource.ConnectionForXA(ctx, xaID) | |||
if err != nil { | |||
err := fmt.Errorf("get connection for rollback xa, resourceId: %s", branchResource.ResourceId) | |||
log.Errorf(err.Error()) | |||
return nil, err | |||
} | |||
return connectionProxyXA, nil | |||
} | |||
func (xaManager *XAResourceManager) BranchCommit(ctx context.Context, branchResource rm.BranchResource) (branch.BranchStatus, error) { | |||
xaID := xaManager.xaIDBuilder(branchResource.Xid, uint64(branchResource.BranchId)) | |||
connectionProxyXA, err := xaManager.finishBranch(ctx, xaID, branchResource) | |||
if err != nil { | |||
return branch.BranchStatusPhasetwoRollbackFailedUnretryable, err | |||
} | |||
if commitErr := connectionProxyXA.XaCommit(ctx, xaID.String(), branchResource.BranchId); commitErr != nil { | |||
err := fmt.Errorf("rollback xa, resourceId: %s", branchResource.ResourceId) | |||
log.Errorf(err.Error()) | |||
setBranchStatus(xaID.String(), branch.BranchStatusPhasetwoCommitted) | |||
return branch.BranchStatusPhasetwoCommitFailedUnretryable, err | |||
} | |||
log.Infof("%s was committed", xaID.String()) | |||
return branch.BranchStatusPhasetwoCommitted, nil | |||
} | |||
func (xaManager *XAResourceManager) BranchRollback(ctx context.Context, branchResource rm.BranchResource) (branch.BranchStatus, error) { | |||
xaID := xaManager.xaIDBuilder(branchResource.Xid, uint64(branchResource.BranchId)) | |||
connectionProxyXA, err := xaManager.finishBranch(ctx, xaID, branchResource) | |||
if err != nil { | |||
return branch.BranchStatusPhasetwoRollbackFailedUnretryable, err | |||
} | |||
if rollbackErr := connectionProxyXA.XaRollbackByBranchId(ctx, xaID.String(), branchResource.BranchId); rollbackErr != nil { | |||
err := fmt.Errorf("rollback xa, resourceId: %s", branchResource.ResourceId) | |||
log.Errorf(err.Error()) | |||
setBranchStatus(xaID.String(), branch.BranchStatusPhasetwoRollbacked) | |||
return branch.BranchStatusPhasetwoRollbackFailedUnretryable, err | |||
} | |||
log.Infof("%s was rollback", xaID.String()) | |||
return branch.BranchStatusPhasetwoRollbacked, nil | |||
} | |||
func (xaManager *XAResourceManager) LockQuery(ctx context.Context, param rm.LockQueryParam) (bool, error) { | |||
return false, nil | |||
} | |||
func (xaManager *XAResourceManager) BranchRegister(ctx context.Context, req rm.BranchRegisterParam) (int64, error) { | |||
return xaManager.rmRemoting.BranchRegister(req) | |||
} | |||
func (xaManager *XAResourceManager) BranchReport(ctx context.Context, param rm.BranchReportParam) error { | |||
return xaManager.rmRemoting.BranchReport(param) | |||
} | |||
func (xaManager *XAResourceManager) CreateTableMetaCache(ctx context.Context, resID string, dbType types.DBType, db *sql.DB) (datasource.TableMetaCache, error) { | |||
return xaManager.basic.CreateTableMetaCache(ctx, resID, dbType, db) | |||
} | |||
func branchStatus(xaBranchXid string) (branch.BranchStatus, error) { | |||
tmpBranchStatus, err := branchStatusCache.GetIFPresent(xaBranchXid) | |||
if err != nil { | |||
if errors.Is(err, gcache.KeyNotFoundError) { | |||
return branch.BranchStatusUnknown, nil | |||
} | |||
return branch.BranchStatusUnknown, err | |||
} | |||
branchStatus, isBranchStatus := tmpBranchStatus.(branch.BranchStatus) | |||
if !isBranchStatus { | |||
return branch.BranchStatusUnknown, fmt.Errorf("branchId:%s get result isn't branch status", xaBranchXid) | |||
} | |||
return branchStatus, nil | |||
} | |||
func setBranchStatus(xaBranchXid string, status branch.BranchStatus) { | |||
branchStatusCache.Set(xaBranchXid, status) | |||
} |
@@ -15,16 +15,10 @@ | |||
* limitations under the License. | |||
*/ | |||
package xa | |||
type Xid interface { | |||
GetFormatId() int | |||
GetGlobalTransactionId() []byte | |||
GetBranchQualifier() []byte | |||
} | |||
package sql | |||
type XAXid interface { | |||
Xid | |||
GetGlobalXid() string | |||
GetBranchId() int64 | |||
GetBranchId() uint64 | |||
String() string | |||
} |
@@ -15,12 +15,12 @@ | |||
* limitations under the License. | |||
*/ | |||
package xa | |||
package sql | |||
func Build(xid string, branchId int64) *XABranchXid { | |||
func XaIdBuild(xid string, branchId uint64) *XABranchXid { | |||
return NewXABranchXid(WithXid(xid), WithBranchId(branchId)) | |||
} | |||
func BuildWithByte(globalTransactionId []byte, branchQualifier []byte) *XABranchXid { | |||
func XaIdBuildWithByte(globalTransactionId []byte, branchQualifier []byte) *XABranchXid { | |||
return NewXABranchXid(WithGlobalTransactionId(globalTransactionId), WithBranchQualifier(branchQualifier)) | |||
} |
@@ -35,71 +35,42 @@ const ( | |||
) | |||
const ( | |||
/** | |||
* The BranchStatus_Unknown. | |||
* description:BranchStatus_Unknown branch status. | |||
*/ | |||
BranchStatusUnknown = BranchStatus(0) | |||
/** | |||
* The BranchStatus_Registered. | |||
* description:BranchStatus_Registered to TC. | |||
*/ | |||
BranchStatusRegistered = BranchStatus(1) | |||
/** | |||
* The Phase one done. | |||
* description:Branch logic is successfully done at phase one. | |||
*/ | |||
BranchStatusPhaseoneDone = BranchStatus(2) | |||
/** | |||
* The Phase one failed. | |||
* description:Branch logic is failed at phase one. | |||
*/ | |||
BranchStatusPhaseoneFailed = BranchStatus(3) | |||
/** | |||
* The Phase one timeout. | |||
* description:Branch logic is NOT reported for a timeout. | |||
*/ | |||
BranchStatusPhaseoneTimeout = BranchStatus(4) | |||
/** | |||
* The Phase two committed. | |||
* description:Commit logic is successfully done at phase two. | |||
*/ | |||
BranchStatusPhasetwoCommitted = BranchStatus(5) | |||
/** | |||
* The Phase two commit failed retryable. | |||
* description:Commit logic is failed but retryable. | |||
*/ | |||
BranchStatusPhasetwoCommitFailedRetryable = BranchStatus(6) | |||
/** | |||
* The Phase two commit failed unretryable. | |||
* description:Commit logic is failed and NOT retryable. | |||
*/ | |||
BranchStatusPhasetwoCommitFailedUnretryable = BranchStatus(7) | |||
/** | |||
* The Phase two rollbacked. | |||
* description:Rollback logic is successfully done at phase two. | |||
*/ | |||
BranchStatusPhasetwoRollbacked = BranchStatus(8) | |||
/** | |||
* The Phase two rollback failed retryable. | |||
* description:Rollback logic is failed but retryable. | |||
*/ | |||
BranchStatusPhasetwoRollbackFailedRetryable = BranchStatus(9) | |||
/** | |||
* The Phase two rollback failed unretryable. | |||
* description:Rollback logic is failed but NOT retryable. | |||
*/ | |||
BranchStatusPhasetwoRollbackFailedUnretryable = BranchStatus(10) | |||
// BranchStatusUnknown the BranchStatus_Unknown. description:BranchStatus_Unknown branch status. | |||
BranchStatusUnknown = iota | |||
// BranchStatusRegistered the BranchStatus_Registered. description:BranchStatus_Registered to TC. | |||
BranchStatusRegistered | |||
// BranchStatusPhaseoneDone the Phase one done. description:Branch logic is successfully done at phase one. | |||
BranchStatusPhaseoneDone | |||
// BranchStatusPhaseoneFailed the Phase one failed. description:Branch logic is failed at phase one. | |||
BranchStatusPhaseoneFailed | |||
// BranchStatusPhaseoneTimeout the Phase one timeout. description:Branch logic is NOT reported for a timeout. | |||
BranchStatusPhaseoneTimeout | |||
// BranchStatusPhasetwoCommitted the Phase two committed. description:Commit logic is successfully done at phase two. | |||
BranchStatusPhasetwoCommitted | |||
// BranchStatusPhasetwoCommitFailedRetryable the Phase two commit failed retryable. description:Commit logic is failed but retryable. | |||
BranchStatusPhasetwoCommitFailedRetryable | |||
// BranchStatusPhasetwoCommitFailedUnretryable the Phase two commit failed unretryable. | |||
// description:Commit logic is failed and NOT retryable. | |||
BranchStatusPhasetwoCommitFailedUnretryable | |||
// BranchStatusPhasetwoRollbacked The Phase two rollbacked. | |||
// description:Rollback logic is successfully done at phase two. | |||
BranchStatusPhasetwoRollbacked | |||
// BranchStatusPhasetwoRollbackFailedRetryable the Phase two rollback failed retryable. | |||
// description:Rollback logic is failed but retryable. | |||
BranchStatusPhasetwoRollbackFailedRetryable | |||
// BranchStatusPhasetwoRollbackFailedUnretryable the Phase two rollback failed unretryable. | |||
// description:Rollback logic is failed but NOT retryable. | |||
BranchStatusPhasetwoRollbackFailedUnretryable | |||
) | |||
func (s BranchStatus) String() string { | |||
@@ -31,7 +31,7 @@ type Resource interface { | |||
GetBranchType() branch.BranchType | |||
} | |||
// branch resource which contains branch to commit or rollback | |||
// BranchResource contains branch to commit or rollback | |||
type BranchResource struct { | |||
BranchType branch.BranchType | |||
Xid string | |||
@@ -42,9 +42,9 @@ type BranchResource struct { | |||
// ResourceManagerInbound Control a branch transaction commit or rollback | |||
type ResourceManagerInbound interface { | |||
// Commit a branch transaction | |||
// BranchCommit commit a branch transaction | |||
BranchCommit(ctx context.Context, resource BranchResource) (branch.BranchStatus, error) | |||
// Rollback a branch transaction | |||
// BranchRollback rollback a branch transaction | |||
BranchRollback(ctx context.Context, resource BranchResource) (branch.BranchStatus, error) | |||
} | |||
@@ -75,13 +75,13 @@ type LockQueryParam struct { | |||
LockKeys string | |||
} | |||
// Resource Manager: send outbound request to TC | |||
// ResourceManagerOutbound Resource Manager: send outbound request to TC | |||
type ResourceManagerOutbound interface { | |||
// Branch register long | |||
// BranchRegister rm register the branch transaction | |||
BranchRegister(ctx context.Context, param BranchRegisterParam) (int64, error) | |||
// Branch report | |||
// BranchReport branch transaction report the status | |||
BranchReport(ctx context.Context, param BranchReportParam) error | |||
// Lock query boolean | |||
// LockQuery lock query boolean | |||
LockQuery(ctx context.Context, param LockQueryParam) (bool, error) | |||
} | |||
@@ -90,13 +90,13 @@ type ResourceManager interface { | |||
ResourceManagerInbound | |||
ResourceManagerOutbound | |||
// Register a Resource to be managed by Resource Manager | |||
// RegisterResource register a resource to be managed by resource manager | |||
RegisterResource(resource Resource) error | |||
// Unregister a Resource from the Resource Manager | |||
// UnregisterResource unregister a resource from the Resource Manager | |||
UnregisterResource(resource Resource) error | |||
// Get all resources managed by this manager | |||
// GetCachedResources get all resources managed by this manager | |||
GetCachedResources() *sync.Map | |||
// Get the BranchType | |||
// GetBranchType get the branch type | |||
GetBranchType() branch.BranchType | |||
} | |||