| @@ -43,7 +43,7 @@ func InitPath(configFilePath string) { | |||
| initRmClient(cfg) | |||
| initTmClient(cfg) | |||
| initDatasource(cfg) | |||
| initDatasource() | |||
| } | |||
| var ( | |||
| @@ -84,10 +84,11 @@ func initRmClient(cfg *Config) { | |||
| integration.Init() | |||
| tcc.InitTCC() | |||
| at.InitAT(cfg.ClientConfig.UndoConfig, cfg.AsyncWorkerConfig) | |||
| at.InitXA(cfg.ClientConfig.XaConfig) | |||
| }) | |||
| } | |||
| func initDatasource(cfg *Config) { | |||
| func initDatasource() { | |||
| onceInitDatasource.Do(func() { | |||
| datasource.Init() | |||
| }) | |||
| @@ -54,15 +54,17 @@ const ( | |||
| ) | |||
| type ClientConfig struct { | |||
| TmConfig tm.TmConfig `yaml:"tm" json:"tm,omitempty" koanf:"tm"` | |||
| RmConfig rm.Config `yaml:"rm" json:"rm,omitempty" koanf:"rm"` | |||
| UndoConfig undo.Config `yaml:"undo" json:"undo,omitempty" koanf:"undo"` | |||
| TmConfig tm.TmConfig `yaml:"tm" json:"tm,omitempty" koanf:"tm"` | |||
| RmConfig rm.Config `yaml:"rm" json:"rm,omitempty" koanf:"rm"` | |||
| UndoConfig undo.Config `yaml:"undo" json:"undo,omitempty" koanf:"undo"` | |||
| XaConfig sql.XAConfig `yaml:"xa" json:"xa" koanf:"xa"` | |||
| } | |||
| func (c *ClientConfig) RegisterFlagsWithPrefix(prefix string, f *flag.FlagSet) { | |||
| c.TmConfig.RegisterFlagsWithPrefix(prefix+".tm", f) | |||
| c.RmConfig.RegisterFlagsWithPrefix(prefix+".rm", f) | |||
| c.UndoConfig.RegisterFlagsWithPrefix(prefix+".undo", f) | |||
| c.XaConfig.RegisterFlagsWithPrefix(prefix+".xa", f) | |||
| } | |||
| type Config struct { | |||
| @@ -17,8 +17,10 @@ | |||
| package datasource | |||
| import sql2 "github.com/seata/seata-go/pkg/datasource/sql" | |||
| import ( | |||
| "github.com/seata/seata-go/pkg/datasource/sql" | |||
| ) | |||
| func Init() { | |||
| sql2.Init() | |||
| sql.Init() | |||
| } | |||
| @@ -56,23 +56,23 @@ func (a *ATSourceManager) GetBranchType() branch.BranchType { | |||
| return branch.BranchTypeAT | |||
| } | |||
| // Get all resources managed by this manager | |||
| // GetCachedResources get all resources managed by this manager | |||
| func (a *ATSourceManager) GetCachedResources() *sync.Map { | |||
| return &a.resourceCache | |||
| } | |||
| // Register a Resource to be managed by Resource Manager | |||
| // RegisterResource register a Resource to be managed by Resource Manager | |||
| func (a *ATSourceManager) RegisterResource(res rm.Resource) error { | |||
| a.resourceCache.Store(res.GetResourceId(), res) | |||
| return a.basic.RegisterResource(res) | |||
| } | |||
| // Unregister a Resource from the Resource Manager | |||
| // UnregisterResource unregister a Resource from the Resource Manager | |||
| func (a *ATSourceManager) UnregisterResource(res rm.Resource) error { | |||
| return a.basic.UnregisterResource(res) | |||
| } | |||
| // Rollback a branch transaction | |||
| // BranchRollback rollback a branch transaction | |||
| func (a *ATSourceManager) BranchRollback(ctx context.Context, branchResource rm.BranchResource) (branch.BranchStatus, error) { | |||
| var dbResource *DBResource | |||
| if resource, ok := a.resourceCache.Load(branchResource.ResourceId); !ok { | |||
| @@ -103,28 +103,26 @@ func (a *ATSourceManager) BranchRollback(ctx context.Context, branchResource rm. | |||
| return branch.BranchStatusPhasetwoRollbacked, nil | |||
| } | |||
| // BranchCommit | |||
| // BranchCommit commit the branch transaction | |||
| func (a *ATSourceManager) BranchCommit(ctx context.Context, resource rm.BranchResource) (branch.BranchStatus, error) { | |||
| a.worker.BranchCommit(ctx, resource) | |||
| return branch.BranchStatusPhasetwoCommitted, nil | |||
| } | |||
| // LockQuery | |||
| func (a *ATSourceManager) LockQuery(ctx context.Context, param rm.LockQueryParam) (bool, error) { | |||
| return a.rmRemoting.LockQuery(param) | |||
| } | |||
| // BranchRegister | |||
| // BranchRegister branch transaction register | |||
| func (a *ATSourceManager) BranchRegister(ctx context.Context, req rm.BranchRegisterParam) (int64, error) { | |||
| return a.rmRemoting.BranchRegister(req) | |||
| } | |||
| // BranchReport | |||
| // BranchReport Report status of transaction branch | |||
| func (a *ATSourceManager) BranchReport(ctx context.Context, param rm.BranchReportParam) error { | |||
| return a.rmRemoting.BranchReport(param) | |||
| } | |||
| // CreateTableMetaCache | |||
| func (a *ATSourceManager) CreateTableMetaCache(ctx context.Context, resID string, dbType types.DBType, | |||
| db *sql.DB) (datasource.TableMetaCache, error) { | |||
| return a.basic.CreateTableMetaCache(ctx, resID, dbType, db) | |||
| @@ -211,7 +211,13 @@ func (c *Conn) Begin() (driver.Tx, error) { | |||
| // | |||
| // global transaction according to tranCtx. If so, it needs to be included in the transaction management of seata | |||
| func (c *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { | |||
| c.autoCommit = false | |||
| if c.txCtx.TransactionMode == types.XAMode { | |||
| return newTx( | |||
| withDriverConn(c), | |||
| withTxCtx(c.txCtx), | |||
| withOriginTx(nil), | |||
| ) | |||
| } | |||
| if conn, ok := c.targetConn.(driver.ConnBeginTx); ok { | |||
| tx, err := conn.BeginTx(ctx, opts) | |||
| @@ -1,91 +0,0 @@ | |||
| /* | |||
| * Licensed to the Apache Software Foundation (ASF) under one or more | |||
| * contributor license agreements. See the NOTICE file distributed with | |||
| * this work for additional information regarding copyright ownership. | |||
| * The ASF licenses this file to You under the Apache License, Version 2.0 | |||
| * (the "License"); you may not use this file except in compliance with | |||
| * the License. You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| package conn | |||
| import "time" | |||
| const ( | |||
| // TMENDICANT Ends a recovery scan. | |||
| TMENDRSCAN = 0x00800000 | |||
| /** | |||
| * Disassociates the caller and marks the transaction branch | |||
| * rollback-only. | |||
| */ | |||
| TMFAIL = 0x20000000 | |||
| /** | |||
| * Caller is joining existing transaction branch. | |||
| */ | |||
| TMJOIN = 0x00200000 | |||
| /** | |||
| * Use TMNOFLAGS to indicate no flags value is selected. | |||
| */ | |||
| TMNOFLAGS = 0x00000000 | |||
| /** | |||
| * Caller is using one-phase optimization. | |||
| */ | |||
| TMONEPHASE = 0x40000000 | |||
| /** | |||
| * Caller is resuming association with a suspended | |||
| * transaction branch. | |||
| */ | |||
| TMRESUME = 0x08000000 | |||
| /** | |||
| * Starts a recovery scan. | |||
| */ | |||
| TMSTARTRSCAN = 0x01000000 | |||
| /** | |||
| * Disassociates caller from a transaction branch. | |||
| */ | |||
| TMSUCCESS = 0x04000000 | |||
| /** | |||
| * Caller is suspending (not ending) its association with | |||
| * a transaction branch. | |||
| */ | |||
| TMSUSPEND = 0x02000000 | |||
| /** | |||
| * The transaction branch has been read-only and has been committed. | |||
| */ | |||
| XA_RDONLY = 0x00000003 | |||
| /** | |||
| * The transaction work has been prepared normally. | |||
| */ | |||
| XA_OK = 0 | |||
| ) | |||
| type XAResource interface { | |||
| Commit(xid string, onePhase bool) error | |||
| End(xid string, flags int) error | |||
| Forget(xid string) error | |||
| GetTransactionTimeout() time.Duration | |||
| IsSameRM(resource XAResource) bool | |||
| XAPrepare(xid string) (int, error) | |||
| Recover(flag int) []string | |||
| Rollback(xid string) error | |||
| SetTransactionTimeout(duration time.Duration) bool | |||
| Start(xid string, flags int) error | |||
| } | |||
| @@ -22,11 +22,10 @@ import ( | |||
| gosql "database/sql" | |||
| "database/sql/driver" | |||
| "github.com/seata/seata-go/pkg/util/log" | |||
| "github.com/seata/seata-go/pkg/datasource/sql/exec" | |||
| "github.com/seata/seata-go/pkg/datasource/sql/types" | |||
| "github.com/seata/seata-go/pkg/tm" | |||
| "github.com/seata/seata-go/pkg/util/log" | |||
| ) | |||
| // ATConn Database connection proxy object under XA transaction model | |||
| @@ -26,11 +26,13 @@ import ( | |||
| "github.com/golang/mock/gomock" | |||
| "github.com/google/uuid" | |||
| "github.com/stretchr/testify/assert" | |||
| "github.com/seata/seata-go/pkg/datasource/sql/exec" | |||
| "github.com/seata/seata-go/pkg/datasource/sql/mock" | |||
| "github.com/seata/seata-go/pkg/datasource/sql/types" | |||
| "github.com/seata/seata-go/pkg/protocol/branch" | |||
| "github.com/seata/seata-go/pkg/tm" | |||
| "github.com/stretchr/testify/assert" | |||
| ) | |||
| func TestMain(m *testing.M) { | |||
| @@ -41,7 +43,7 @@ func TestMain(m *testing.M) { | |||
| func initAtConnTestResource(t *testing.T) (*gomock.Controller, *sql.DB, *mockSQLInterceptor, *mockTxHook) { | |||
| ctrl := gomock.NewController(t) | |||
| mockMgr := initMockResourceManager(t, ctrl) | |||
| mockMgr := initMockResourceManager(branch.BranchTypeAT, ctrl) | |||
| _ = mockMgr | |||
| db, err := sql.Open(SeataATMySQLDriver, "root:12345678@tcp(127.0.0.1:3306)/seata_client?multiStatements=true") | |||
| @@ -57,6 +59,14 @@ func initAtConnTestResource(t *testing.T) (*gomock.Controller, *sql.DB, *mockSQL | |||
| mockConn := mock.NewMockTestDriverConn(ctrl) | |||
| mockConn.EXPECT().Begin().AnyTimes().Return(mockTx, nil) | |||
| mockConn.EXPECT().BeginTx(gomock.Any(), gomock.Any()).AnyTimes().Return(mockTx, nil) | |||
| mockConn.EXPECT().QueryContext(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes().DoAndReturn( | |||
| func(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { | |||
| rows := &mysqlMockRows{} | |||
| rows.data = [][]interface{}{ | |||
| {"8.0.29"}, | |||
| } | |||
| return rows, nil | |||
| }) | |||
| baseMockConn(mockConn) | |||
| connector := mock.NewMockTestDriverConnector(ctrl) | |||
| @@ -19,41 +19,74 @@ package sql | |||
| import ( | |||
| "context" | |||
| gosql "database/sql" | |||
| "database/sql/driver" | |||
| "fmt" | |||
| "time" | |||
| "github.com/seata/seata-go/pkg/datasource/sql/types" | |||
| "github.com/seata/seata-go/pkg/datasource/sql/xa" | |||
| "github.com/seata/seata-go/pkg/tm" | |||
| "github.com/seata/seata-go/pkg/util/log" | |||
| ) | |||
| var xaConnTimeout time.Duration | |||
| // XAConn Database connection proxy object under XA transaction model | |||
| // Conn is assumed to be stateful. | |||
| type XAConn struct { | |||
| *Conn | |||
| tx driver.Tx | |||
| xaResource xa.XAResource | |||
| xaBranchXid *XABranchXid | |||
| xaActive bool | |||
| rollBacked bool | |||
| branchRegisterTime time.Time | |||
| prepareTime time.Time | |||
| isConnKept bool | |||
| } | |||
| // QueryContext | |||
| func (c *XAConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { | |||
| func (c *XAConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { | |||
| if c.createOnceTxContext(ctx) { | |||
| defer func() { | |||
| c.txCtx = types.NewTxCtx() | |||
| }() | |||
| } | |||
| return c.Conn.QueryContext(ctx, query, args) | |||
| //ret, err := c.createNewTxOnExecIfNeed(ctx, func() (types, error) { | |||
| // ret, err := c.Conn.PrepareContext(ctx, query) | |||
| // if err != nil { | |||
| // return nil, err | |||
| // } | |||
| // return types.NewResult(types.WithRows(ret)), nil | |||
| //}) | |||
| return c.Conn.PrepareContext(ctx, query) | |||
| } | |||
| // PrepareContext | |||
| func (c *XAConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { | |||
| // QueryContext exec xa sql | |||
| func (c *XAConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { | |||
| if c.createOnceTxContext(ctx) { | |||
| defer func() { | |||
| c.txCtx = types.NewTxCtx() | |||
| }() | |||
| } | |||
| return c.Conn.PrepareContext(ctx, query) | |||
| ret, err := c.createNewTxOnExecIfNeed(ctx, func() (types.ExecResult, error) { | |||
| ret, err := c.Conn.QueryContext(ctx, query, args) | |||
| if err != nil { | |||
| return nil, err | |||
| } | |||
| return types.NewResult(types.WithRows(ret)), nil | |||
| }) | |||
| if err != nil { | |||
| return nil, err | |||
| } | |||
| return ret.GetRows(), nil | |||
| } | |||
| // ExecContext | |||
| func (c *XAConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { | |||
| if c.createOnceTxContext(ctx) { | |||
| defer func() { | |||
| @@ -61,26 +94,56 @@ func (c *XAConn) ExecContext(ctx context.Context, query string, args []driver.Na | |||
| }() | |||
| } | |||
| return c.Conn.ExecContext(ctx, query, args) | |||
| ret, err := c.createNewTxOnExecIfNeed(ctx, func() (types.ExecResult, error) { | |||
| ret, err := c.Conn.ExecContext(ctx, query, args) | |||
| if err != nil { | |||
| return nil, err | |||
| } | |||
| return types.NewResult(types.WithResult(ret)), nil | |||
| }) | |||
| return ret.GetResult(), err | |||
| } | |||
| // BeginTx | |||
| // BeginTx like common transaction. but it just exec XA START | |||
| func (c *XAConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { | |||
| c.autoCommit = false | |||
| c.txCtx = types.NewTxCtx() | |||
| c.txCtx.DBType = c.res.dbType | |||
| c.txCtx.TxOpt = opts | |||
| c.txCtx.ResourceID = c.res.resourceID | |||
| if tm.IsGlobalTx(ctx) { | |||
| c.txCtx.TransactionMode = types.XAMode | |||
| c.txCtx.XID = tm.GetXID(ctx) | |||
| c.txCtx.TransactionMode = types.XAMode | |||
| } | |||
| tx, err := c.Conn.BeginTx(ctx, opts) | |||
| if err != nil { | |||
| return nil, err | |||
| } | |||
| c.tx = tx | |||
| if c.autoCommit { | |||
| baseTx, ok := tx.(*Tx) | |||
| if !ok { | |||
| return nil, fmt.Errorf("start xa %s transaction failure for the tx is a wrong type", c.txCtx.XID) | |||
| } | |||
| c.branchRegisterTime = time.Now() | |||
| if err := baseTx.register(c.txCtx); err != nil { | |||
| c.cleanXABranchContext() | |||
| return nil, fmt.Errorf("failed to register xa branch %s, err:%w", c.txCtx.XID, err) | |||
| } | |||
| c.xaBranchXid = XaIdBuild(c.txCtx.XID, c.txCtx.BranchID) | |||
| c.keepIfNecessary() | |||
| if err = c.start(ctx); err != nil { | |||
| c.cleanXABranchContext() | |||
| return nil, fmt.Errorf("failed to start xa branch xid:%s err:%w", c.txCtx.XID, err) | |||
| } | |||
| c.xaActive = true | |||
| } | |||
| return &XATx{tx: tx.(*Tx)}, nil | |||
| } | |||
| @@ -91,9 +154,240 @@ func (c *XAConn) createOnceTxContext(ctx context.Context) bool { | |||
| if onceTx { | |||
| c.txCtx = types.NewTxCtx() | |||
| c.txCtx.DBType = c.res.dbType | |||
| c.txCtx.ResourceID = c.res.resourceID | |||
| c.txCtx.XID = tm.GetXID(ctx) | |||
| c.txCtx.TransactionMode = types.XAMode | |||
| c.txCtx.GlobalLockRequire = true | |||
| } | |||
| return onceTx | |||
| } | |||
| func (c *XAConn) createNewTxOnExecIfNeed(ctx context.Context, f func() (types.ExecResult, error)) (types.ExecResult, error) { | |||
| var err error | |||
| if c.txCtx.TransactionMode != types.Local && c.autoCommit { | |||
| _, err = c.BeginTx(ctx, driver.TxOptions{Isolation: driver.IsolationLevel(gosql.LevelDefault)}) | |||
| if err != nil { | |||
| return nil, err | |||
| } | |||
| } | |||
| defer func() { | |||
| recoverErr := recover() | |||
| if err != nil || recoverErr != nil { | |||
| log.Errorf("conn at rollback error:%v or recoverErr:%v", err, recoverErr) | |||
| if c.tx != nil { | |||
| rollbackErr := c.tx.Rollback() | |||
| if rollbackErr != nil { | |||
| log.Errorf("conn at rollback error:%v", rollbackErr) | |||
| } | |||
| } | |||
| } | |||
| }() | |||
| // execute SQL | |||
| ret, err := f() | |||
| if err != nil { | |||
| // XA End & Rollback | |||
| if rollbackErr := c.Rollback(ctx); rollbackErr != nil { | |||
| log.Errorf("failed to rollback xa branch of :%s, err:%w", c.txCtx.XID, rollbackErr) | |||
| } | |||
| return nil, err | |||
| } | |||
| if c.autoCommit { | |||
| if err := c.Commit(ctx); err != nil { | |||
| log.Errorf("xa connection proxy commit failure xid:%s, err:%v", c.txCtx.XID, err) | |||
| // XA End & Rollback | |||
| if err := c.Rollback(ctx); err != nil { | |||
| log.Errorf("xa connection proxy rollback failure xid:%s, err:%v", c.txCtx.XID, err) | |||
| } | |||
| } | |||
| } | |||
| return ret, nil | |||
| } | |||
| func (c *XAConn) keepIfNecessary() { | |||
| if c.ShouldBeHeld() { | |||
| if err := c.res.Hold(c.xaBranchXid.String(), c); err == nil { | |||
| c.isConnKept = true | |||
| } | |||
| } | |||
| } | |||
| func (c *XAConn) releaseIfNecessary() { | |||
| if c.ShouldBeHeld() && c.xaBranchXid.String() != "" { | |||
| if c.isConnKept { | |||
| c.res.Release(c.xaBranchXid.String()) | |||
| c.isConnKept = false | |||
| } | |||
| } | |||
| } | |||
| func (c *XAConn) start(ctx context.Context) error { | |||
| xaResource, err := xa.CreateXAResource(c.Conn.targetConn, c.dbType) | |||
| if err != nil { | |||
| return fmt.Errorf("create xa xid:%s resoruce err:%w", c.txCtx.XID, err) | |||
| } | |||
| c.xaResource = xaResource | |||
| if err := c.xaResource.Start(ctx, c.xaBranchXid.String(), xa.TMNoFlags); err != nil { | |||
| return fmt.Errorf("xa xid %s resource connection start err:%w", c.txCtx.XID, err) | |||
| } | |||
| if err := c.termination(c.xaBranchXid.String()); err != nil { | |||
| c.xaResource.End(ctx, c.xaBranchXid.String(), xa.TMFail) | |||
| c.XaRollback(ctx, c.xaBranchXid) | |||
| return err | |||
| } | |||
| return err | |||
| } | |||
| func (c *XAConn) end(ctx context.Context, flags int) error { | |||
| err := c.termination(c.xaBranchXid.String()) | |||
| if err != nil { | |||
| return err | |||
| } | |||
| err = c.xaResource.End(ctx, c.xaBranchXid.String(), flags) | |||
| if err != nil { | |||
| return err | |||
| } | |||
| return nil | |||
| } | |||
| func (c *XAConn) termination(xaBranchXid string) error { | |||
| branchStatus, err := branchStatus(xaBranchXid) | |||
| if err != nil { | |||
| c.releaseIfNecessary() | |||
| return fmt.Errorf("failed xa branch [%v] the global transaction has finish, branch status: [%v]", c.txCtx.XID, branchStatus) | |||
| } | |||
| return nil | |||
| } | |||
| func (c *XAConn) cleanXABranchContext() { | |||
| h, _ := time.ParseDuration("-1000h") | |||
| c.branchRegisterTime = time.Now().Add(h) | |||
| c.prepareTime = time.Now().Add(h) | |||
| c.xaActive = false | |||
| if !c.isConnKept { | |||
| c.xaBranchXid = nil | |||
| } | |||
| } | |||
| func (c *XAConn) Rollback(ctx context.Context) error { | |||
| if c.autoCommit { | |||
| return nil | |||
| } | |||
| if !c.xaActive || c.xaBranchXid == nil { | |||
| return fmt.Errorf("should NOT rollback on an inactive session") | |||
| } | |||
| if !c.rollBacked { | |||
| if c.xaResource.End(ctx, c.xaBranchXid.String(), xa.TMFail) != nil { | |||
| return c.rollbackErrorHandle() | |||
| } | |||
| if c.XaRollback(ctx, c.xaBranchXid) != nil { | |||
| c.cleanXABranchContext() | |||
| return c.rollbackErrorHandle() | |||
| } | |||
| if err := c.tx.Rollback(); err != nil { | |||
| c.cleanXABranchContext() | |||
| return fmt.Errorf("failed to report XA branch commit-failure on xid:%s err:%w", c.txCtx.XID, err) | |||
| } | |||
| } | |||
| c.cleanXABranchContext() | |||
| return nil | |||
| } | |||
| func (c *XAConn) rollbackErrorHandle() error { | |||
| return fmt.Errorf("failed to end(TMFAIL) xa branch on [%v] - [%v]", c.txCtx.XID, c.xaBranchXid.GetBranchId()) | |||
| } | |||
| func (c *XAConn) Commit(ctx context.Context) error { | |||
| if c.autoCommit { | |||
| return nil | |||
| } | |||
| if !c.xaActive || c.xaBranchXid == nil { | |||
| return fmt.Errorf("should NOT commit on an inactive session") | |||
| } | |||
| now := time.Now() | |||
| if c.end(ctx, xa.TMSuccess) != nil { | |||
| return c.commitErrorHandle() | |||
| } | |||
| if c.checkTimeout(ctx, now) != nil { | |||
| return c.commitErrorHandle() | |||
| } | |||
| if c.xaResource.XAPrepare(ctx, c.xaBranchXid.String()) != nil { | |||
| return c.commitErrorHandle() | |||
| } | |||
| return nil | |||
| } | |||
| func (c *XAConn) commitErrorHandle() error { | |||
| var err error | |||
| if err = c.tx.Rollback(); err != nil { | |||
| err = fmt.Errorf("failed to report XA branch commit-failure xid:%s, err:%w", c.txCtx.XID, err) | |||
| } | |||
| c.cleanXABranchContext() | |||
| return err | |||
| } | |||
| func (c *XAConn) ShouldBeHeld() bool { | |||
| return c.res.IsShouldBeHeld() || (c.res.GetDbType().String() != "" && c.res.GetDbType() != types.DBTypeUnknown) | |||
| } | |||
| func (c *XAConn) checkTimeout(ctx context.Context, now time.Time) error { | |||
| if now.Sub(c.branchRegisterTime) > xaConnTimeout { | |||
| c.XaRollback(ctx, c.xaBranchXid) | |||
| return fmt.Errorf("XA branch timeout error xid:%s", c.txCtx.XID) | |||
| } | |||
| return nil | |||
| } | |||
| func (c *XAConn) Close() error { | |||
| c.rollBacked = false | |||
| if c.isConnKept && c.ShouldBeHeld() { | |||
| return nil | |||
| } | |||
| c.cleanXABranchContext() | |||
| if err := c.Conn.Close(); err != nil { | |||
| return err | |||
| } | |||
| return nil | |||
| } | |||
| func (c *XAConn) CloseForce() error { | |||
| if err := c.Conn.Close(); err != nil { | |||
| return err | |||
| } | |||
| c.rollBacked = false | |||
| c.cleanXABranchContext() | |||
| if err := c.Conn.Close(); err != nil { | |||
| return err | |||
| } | |||
| c.releaseIfNecessary() | |||
| return nil | |||
| } | |||
| func (c *XAConn) XaCommit(ctx context.Context, xid string, branchId int64) error { | |||
| xaXid := XaIdBuild(xid, uint64(branchId)) | |||
| err := c.xaResource.Commit(ctx, xaXid.String(), false) | |||
| c.releaseIfNecessary() | |||
| return err | |||
| } | |||
| func (c *XAConn) XaRollbackByBranchId(ctx context.Context, xid string, branchId int64) error { | |||
| xaXid := XaIdBuild(xid, uint64(branchId)) | |||
| return c.XaRollback(ctx, xaXid) | |||
| } | |||
| func (c *XAConn) XaRollback(ctx context.Context, xaXid XAXid) error { | |||
| err := c.xaResource.Rollback(ctx, xaXid.GetGlobalXid()) | |||
| c.releaseIfNecessary() | |||
| return err | |||
| } | |||
| @@ -21,18 +21,59 @@ import ( | |||
| "context" | |||
| "database/sql" | |||
| "database/sql/driver" | |||
| "io" | |||
| "sync/atomic" | |||
| "testing" | |||
| "time" | |||
| "github.com/bluele/gcache" | |||
| "github.com/golang/mock/gomock" | |||
| "github.com/google/uuid" | |||
| "github.com/stretchr/testify/assert" | |||
| "github.com/seata/seata-go/pkg/datasource/sql/exec" | |||
| "github.com/seata/seata-go/pkg/datasource/sql/mock" | |||
| "github.com/seata/seata-go/pkg/datasource/sql/types" | |||
| "github.com/seata/seata-go/pkg/protocol/branch" | |||
| "github.com/seata/seata-go/pkg/tm" | |||
| "github.com/stretchr/testify/assert" | |||
| ) | |||
| type mysqlMockRows struct { | |||
| idx int | |||
| data [][]interface{} | |||
| } | |||
| func (m *mysqlMockRows) Columns() []string { | |||
| //TODO implement me | |||
| panic("implement me") | |||
| } | |||
| func (m *mysqlMockRows) Close() error { | |||
| //TODO implement me | |||
| panic("implement me") | |||
| } | |||
| func (m *mysqlMockRows) Next(dest []driver.Value) error { | |||
| if m.idx == len(m.data) { | |||
| return io.EOF | |||
| } | |||
| min := func(a, b int) int { | |||
| if a < b { | |||
| return a | |||
| } | |||
| return b | |||
| } | |||
| cnt := min(len(m.data[0]), len(dest)) | |||
| for i := 0; i < cnt; i++ { | |||
| dest[i] = m.data[m.idx][i] | |||
| } | |||
| m.idx++ | |||
| return nil | |||
| } | |||
| type mockSQLInterceptor struct { | |||
| before func(ctx context.Context, execCtx *types.ExecContext) | |||
| after func(ctx context.Context, execCtx *types.ExecContext) | |||
| @@ -78,16 +119,27 @@ func (mi *mockTxHook) BeforeRollback(tx *Tx) { | |||
| } | |||
| func baseMockConn(mockConn *mock.MockTestDriverConn) { | |||
| branchStatusCache = gcache.New(1024).LRU().Expiration(time.Minute * 10).Build() | |||
| mockConn.EXPECT().ExecContext(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes().Return(&driver.ResultNoRows, nil) | |||
| mockConn.EXPECT().Exec(gomock.Any(), gomock.Any()).AnyTimes().Return(&driver.ResultNoRows, nil) | |||
| mockConn.EXPECT().ResetSession(gomock.Any()).AnyTimes().Return(nil) | |||
| mockConn.EXPECT().Close().AnyTimes().Return(nil) | |||
| mockConn.EXPECT().QueryContext(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes().DoAndReturn( | |||
| func(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { | |||
| rows := &mysqlMockRows{} | |||
| rows.data = [][]interface{}{ | |||
| {"8.0.29"}, | |||
| } | |||
| return rows, nil | |||
| }) | |||
| } | |||
| func initXAConnTestResource(t *testing.T) (*gomock.Controller, *sql.DB, *mockSQLInterceptor, *mockTxHook) { | |||
| ctrl := gomock.NewController(t) | |||
| mockMgr := initMockResourceManager(t, ctrl) | |||
| mockMgr := initMockResourceManager(branch.BranchTypeXA, ctrl) | |||
| _ = mockMgr | |||
| //db, err := sql.Open("seata-xa-mysql", "root:seata_go@tcp(127.0.0.1:3306)/seata_go_test?multiStatements=true") | |||
| db, err := sql.Open("seata-xa-mysql", "root:12345678@tcp(127.0.0.1:3306)/seata_client?multiStatements=true&interpolateParams=true") | |||
| @@ -20,11 +20,14 @@ package sql | |||
| import ( | |||
| "context" | |||
| "database/sql/driver" | |||
| "errors" | |||
| "io" | |||
| "sync" | |||
| "github.com/go-sql-driver/mysql" | |||
| "github.com/seata/seata-go/pkg/datasource/sql/types" | |||
| "github.com/seata/seata-go/pkg/util/log" | |||
| ) | |||
| type seataATConnector struct { | |||
| @@ -89,7 +92,6 @@ func (c *seataXAConnector) Driver() driver.Driver { | |||
| // method will call Close and return error (if any). | |||
| type seataConnector struct { | |||
| transType types.TransactionMode | |||
| conf *seataServerConfig | |||
| res *DBResource | |||
| once sync.Once | |||
| driver driver.Driver | |||
| @@ -116,6 +118,15 @@ func (c *seataConnector) Connect(ctx context.Context) (driver.Conn, error) { | |||
| return nil, err | |||
| } | |||
| // get the version of mysql for xa. | |||
| if c.transType == types.XAMode { | |||
| version, err := c.dbVersion(ctx, conn) | |||
| if err != nil { | |||
| return nil, err | |||
| } | |||
| c.res.SetDbVersion(version) | |||
| } | |||
| return &Conn{ | |||
| targetConn: conn, | |||
| res: c.res, | |||
| @@ -126,6 +137,40 @@ func (c *seataConnector) Connect(ctx context.Context) (driver.Conn, error) { | |||
| }, nil | |||
| } | |||
| func (c *seataConnector) dbVersion(ctx context.Context, conn driver.Conn) (string, error) { | |||
| queryConn, isQueryContext := conn.(driver.QueryerContext) | |||
| if !isQueryContext { | |||
| return "", errors.New("get db version error for unexpected driver conn") | |||
| } | |||
| res, err := queryConn.QueryContext(ctx, "SELECT VERSION()", nil) | |||
| if err != nil { | |||
| log.Errorf("seata connector get the xa mysql version err:%v", err) | |||
| return "", err | |||
| } | |||
| dest := make([]driver.Value, 1) | |||
| var version string | |||
| for true { | |||
| if err = res.Next(dest); err != nil { | |||
| if err == io.EOF { | |||
| return version, nil | |||
| } | |||
| return "", err | |||
| } | |||
| if len(dest) != 1 { | |||
| return "", errors.New("get the mysql version is not column 1") | |||
| } | |||
| var isVersionOk bool | |||
| version, isVersionOk = dest[0].(string) | |||
| if !isVersionOk { | |||
| return "", errors.New("get the mysql version is not a string") | |||
| } | |||
| } | |||
| return "", errors.New("get the mysql version is error") | |||
| } | |||
| // Driver returns the underlying Driver of the Connector, | |||
| // mainly to maintain compatibility with the Driver method | |||
| // on sql.DB. | |||
| @@ -25,10 +25,12 @@ import ( | |||
| "testing" | |||
| "github.com/golang/mock/gomock" | |||
| "github.com/stretchr/testify/assert" | |||
| "github.com/seata/seata-go/pkg/datasource/sql/mock" | |||
| "github.com/seata/seata-go/pkg/datasource/sql/types" | |||
| "github.com/seata/seata-go/pkg/protocol/branch" | |||
| "github.com/seata/seata-go/pkg/util/reflectx" | |||
| "github.com/stretchr/testify/assert" | |||
| ) | |||
| type initConnectorFunc func(t *testing.T, ctrl *gomock.Controller) driver.Connector | |||
| @@ -38,6 +40,14 @@ func initMockConnector(t *testing.T, ctrl *gomock.Controller) driver.Connector { | |||
| connector := mock.NewMockTestDriverConnector(ctrl) | |||
| connector.EXPECT().Connect(gomock.Any()).AnyTimes().Return(mockConn, nil) | |||
| mockConn.EXPECT().QueryContext(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes().DoAndReturn( | |||
| func(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { | |||
| rows := &mysqlMockRows{} | |||
| rows.data = [][]interface{}{ | |||
| {"8.0.29"}, | |||
| } | |||
| return rows, nil | |||
| }) | |||
| return connector | |||
| } | |||
| @@ -66,10 +76,10 @@ func Test_seataATConnector_Connect(t *testing.T) { | |||
| ctrl := gomock.NewController(t) | |||
| defer ctrl.Finish() | |||
| mockMgr := initMockResourceManager(t, ctrl) | |||
| mockMgr := initMockResourceManager(branch.BranchTypeAT, ctrl) | |||
| _ = mockMgr | |||
| db, err := sql.Open("seata-at-mysql", "root:seata_go@tcp(127.0.0.1:3306)/seata_go_test?multiStatements=true") | |||
| db, err := sql.Open(SeataATMySQLDriver, "root:seata_go@tcp(127.0.0.1:3306)/seata_go_test?multiStatements=true") | |||
| if err != nil { | |||
| t.Fatal(err) | |||
| } | |||
| @@ -110,10 +120,10 @@ func Test_seataXAConnector_Connect(t *testing.T) { | |||
| ctrl := gomock.NewController(t) | |||
| defer ctrl.Finish() | |||
| mockMgr := initMockResourceManager(t, ctrl) | |||
| mockMgr := initMockResourceManager(branch.BranchTypeXA, ctrl) | |||
| _ = mockMgr | |||
| db, err := sql.Open("seata-xa-mysql", "root:seata_go@tcp(127.0.0.1:3306)/seata_go_test?multiStatements=true") | |||
| db, err := sql.Open(SeataXAMySQLDriver, "root:seata_go@tcp(127.0.0.1:3306)/seata_go_test?multiStatements=true") | |||
| if err != nil { | |||
| t.Fatal(err) | |||
| } | |||
| @@ -25,7 +25,6 @@ import ( | |||
| "github.com/seata/seata-go/pkg/datasource/sql/types" | |||
| "github.com/seata/seata-go/pkg/protocol/branch" | |||
| "github.com/seata/seata-go/pkg/protocol/message" | |||
| "github.com/seata/seata-go/pkg/rm" | |||
| ) | |||
| @@ -34,7 +33,7 @@ var ( | |||
| tableMetaCacheMap = map[types.DBType]TableMetaCache{} | |||
| ) | |||
| // RegisterTableCache | |||
| // RegisterTableCache register the table meta cache for at and xa | |||
| func RegisterTableCache(dbType types.DBType, tableMetaCache TableMetaCache) { | |||
| tableMetaCacheMap[dbType] = tableMetaCache | |||
| } | |||
| @@ -54,11 +53,8 @@ func GetDataSourceManager(branchType branch.BranchType) DataSourceManager { | |||
| return nil | |||
| } | |||
| // todo implements ResourceManagerOutbound interface | |||
| // DataSourceManager | |||
| type DataSourceManager interface { | |||
| rm.ResourceManager | |||
| // CreateTableMetaCache | |||
| CreateTableMetaCache(ctx context.Context, resID string, dbType types.DBType, db *sql.DB) (TableMetaCache, error) | |||
| } | |||
| @@ -67,9 +63,8 @@ type entry struct { | |||
| metaCache TableMetaCache | |||
| } | |||
| // BasicSourceManager | |||
| // BasicSourceManager the basic source manager for xa and at | |||
| type BasicSourceManager struct { | |||
| // lock | |||
| lock sync.RWMutex | |||
| // tableMetaCache | |||
| // todo do not put meta cache here | |||
| @@ -82,34 +77,7 @@ func NewBasicSourceManager() *BasicSourceManager { | |||
| } | |||
| } | |||
| // Commit a branch transaction | |||
| // TODO wait finish | |||
| func (dm *BasicSourceManager) BranchCommit(ctx context.Context, req message.BranchCommitRequest) (branch.BranchStatus, error) { | |||
| return branch.BranchStatusPhaseoneDone, nil | |||
| } | |||
| // Rollback a branch transaction | |||
| // TODO wait finish | |||
| func (dm *BasicSourceManager) BranchRollback(ctx context.Context, req message.BranchRollbackRequest) (branch.BranchStatus, error) { | |||
| return branch.BranchStatusPhaseoneFailed, nil | |||
| } | |||
| // Branch register long | |||
| func (dm *BasicSourceManager) BranchRegister(ctx context.Context, req rm.BranchRegisterParam) (int64, error) { | |||
| return 0, nil | |||
| } | |||
| // Branch report | |||
| func (dm *BasicSourceManager) BranchReport(ctx context.Context, req message.BranchReportRequest) error { | |||
| return nil | |||
| } | |||
| // Lock query boolean | |||
| func (dm *BasicSourceManager) LockQuery(ctx context.Context, branchType branch.BranchType, resourceId, xid, lockKeys string) (bool, error) { | |||
| return true, nil | |||
| } | |||
| // Register a model.Resource to be managed by model.Resource Manager | |||
| // RegisterResource register a model.Resource to be managed by model.Resource Manager | |||
| func (dm *BasicSourceManager) RegisterResource(resource rm.Resource) error { | |||
| err := rm.GetRMRemotingInstance().RegisterResource(resource) | |||
| if err != nil { | |||
| @@ -118,22 +86,11 @@ func (dm *BasicSourceManager) RegisterResource(resource rm.Resource) error { | |||
| return nil | |||
| } | |||
| // Unregister a model.Resource from the model.Resource Manager | |||
| func (dm *BasicSourceManager) UnregisterResource(resource rm.Resource) error { | |||
| return fmt.Errorf("unsupport unregister resource") | |||
| } | |||
| // Get all resources managed by this manager | |||
| func (dm *BasicSourceManager) GetManagedResources() *sync.Map { | |||
| return nil | |||
| } | |||
| // Get the model.BranchType | |||
| func (dm *BasicSourceManager) GetBranchType() branch.BranchType { | |||
| return branch.BranchTypeAT | |||
| } | |||
| // CreateTableMetaCache | |||
| // CreateTableMetaCache create a table meta cache | |||
| func (dm *BasicSourceManager) CreateTableMetaCache(ctx context.Context, resID string, dbType types.DBType, db *sql.DB) (TableMetaCache, error) { | |||
| dm.lock.Lock() | |||
| defer dm.lock.Unlock() | |||
| @@ -144,28 +101,19 @@ func (dm *BasicSourceManager) CreateTableMetaCache(ctx context.Context, resID st | |||
| } | |||
| dm.tableMetaCache[resID] = res | |||
| // 注册 AT 数据资源 | |||
| // dm.resourceMgr.RegisterResource(ATResource) | |||
| return res.metaCache, err | |||
| } | |||
| // TableMetaCache tables metadata cache, default is open | |||
| type TableMetaCache interface { | |||
| // Init | |||
| Init(ctx context.Context, conn *sql.DB) error | |||
| // GetTableMeta | |||
| GetTableMeta(ctx context.Context, dbName, table string) (*types.TableMeta, error) | |||
| // Destroy | |||
| Destroy() error | |||
| } | |||
| // buildResource | |||
| // todo not here | |||
| func buildResource(ctx context.Context, dbType types.DBType, db *sql.DB) (*entry, error) { | |||
| cache := tableMetaCacheMap[dbType] | |||
| if err := cache.Init(ctx, db); err != nil { | |||
| return nil, err | |||
| } | |||
| @@ -1,124 +0,0 @@ | |||
| /* | |||
| * Licensed to the Apache Software Foundation (ASF) under one or more | |||
| * contributor license agreements. See the NOTICE file distributed with | |||
| * this work for additional information regarding copyright ownership. | |||
| * The ASF licenses this file to You under the Apache License, Version 2.0 | |||
| * (the "License"); you may not use this file except in compliance with | |||
| * the License. You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| package sql | |||
| import ( | |||
| "fmt" | |||
| "time" | |||
| "github.com/bluele/gcache" | |||
| "github.com/seata/seata-go/pkg/datasource/sql/types" | |||
| "github.com/seata/seata-go/pkg/protocol/branch" | |||
| ) | |||
| type Holdable interface { | |||
| SetHeld(held bool) | |||
| IsHeld() bool | |||
| ShouldBeHeld() bool | |||
| } | |||
| type BaseDataSourceResource struct { | |||
| db *DBResource | |||
| shouldBeHeld bool | |||
| keeper map[string]interface{} | |||
| Cache map[string]branch.BranchStatus | |||
| } | |||
| var BranchStatusCache = gcache.New(1024).LRU().Expiration(time.Minute * 10).Build() | |||
| func (b *BaseDataSourceResource) init() error { | |||
| return nil | |||
| } | |||
| func (b *BaseDataSourceResource) GetDB() *DBResource { | |||
| return b.db | |||
| } | |||
| func (b *BaseDataSourceResource) SetDB(db *DBResource) { | |||
| b.db = db | |||
| } | |||
| func (b *BaseDataSourceResource) IsShouldBeHeld() bool { | |||
| return b.shouldBeHeld | |||
| } | |||
| func (b *BaseDataSourceResource) SetShouldBeHeld(shouldBeHeld bool) { | |||
| b.shouldBeHeld = shouldBeHeld | |||
| } | |||
| func (b *BaseDataSourceResource) GetKeeper() map[string]interface{} { | |||
| return b.keeper | |||
| } | |||
| func (b *BaseDataSourceResource) SetKeeper(keeper map[string]interface{}) { | |||
| b.keeper = keeper | |||
| } | |||
| func (b *BaseDataSourceResource) GetCache() map[string]branch.BranchStatus { | |||
| return b.Cache | |||
| } | |||
| func (b *BaseDataSourceResource) SetCache(cache map[string]branch.BranchStatus) { | |||
| b.Cache = cache | |||
| } | |||
| func (b *BaseDataSourceResource) GetResourceId() string { | |||
| return b.db.GetResourceId() | |||
| } | |||
| func (b *BaseDataSourceResource) Hold(key string, value Holdable) (interface{}, error) { | |||
| if value.IsHeld() { | |||
| var x = b.keeper[key] | |||
| if x != value { | |||
| return nil, fmt.Errorf("something wrong with keeper, keeping[%v] but[%v] is also kept with the same key[%v]", x, value, key) | |||
| } | |||
| return value, nil | |||
| } | |||
| var x = b.keeper[key] | |||
| b.keeper[key] = value | |||
| value.SetHeld(true) | |||
| return x, nil | |||
| } | |||
| func (b *BaseDataSourceResource) Release(key string, value Holdable) (interface{}, error) { | |||
| if value.IsHeld() { | |||
| var x = b.keeper[key] | |||
| if x != value { | |||
| return nil, fmt.Errorf("something wrong with keeper, keeping[%v] but[%v] is also kept with the same key[%v]", x, value, key) | |||
| } | |||
| return value, nil | |||
| } | |||
| var x = b.keeper[key] | |||
| b.keeper[key] = value | |||
| value.SetHeld(true) | |||
| return x, nil | |||
| } | |||
| func (b *BaseDataSourceResource) GetBranchStatus(xaBranchXid string) (interface{}, error) { | |||
| branchStatus, err := BranchStatusCache.GetIFPresent(xaBranchXid) | |||
| return branchStatus, err | |||
| } | |||
| func (b *BaseDataSourceResource) GetDbType() string { | |||
| return b.db.dbType.String() | |||
| } | |||
| func (b *BaseDataSourceResource) SetDbType(dbType types.DBType) { | |||
| b.db.dbType = dbType | |||
| } | |||
| @@ -18,19 +18,24 @@ | |||
| package sql | |||
| import ( | |||
| "context" | |||
| "database/sql" | |||
| "database/sql/driver" | |||
| "fmt" | |||
| "sync" | |||
| "github.com/seata/seata-go/pkg/datasource/sql/datasource" | |||
| "github.com/seata/seata-go/pkg/datasource/sql/types" | |||
| "github.com/seata/seata-go/pkg/datasource/sql/undo" | |||
| "github.com/seata/seata-go/pkg/datasource/sql/util" | |||
| "github.com/seata/seata-go/pkg/protocol/branch" | |||
| ) | |||
| type dbOption func(db *DBResource) | |||
| func withGroupID(id string) dbOption { | |||
| func withDsn(dsn string) dbOption { | |||
| return func(db *DBResource) { | |||
| db.groupID = id | |||
| db.dsn = dsn | |||
| } | |||
| } | |||
| @@ -52,21 +57,33 @@ func withDBType(dt types.DBType) dbOption { | |||
| } | |||
| } | |||
| func withBranchType(dt branch.BranchType) dbOption { | |||
| return func(db *DBResource) { | |||
| db.branchType = dt | |||
| } | |||
| } | |||
| func withTarget(source *sql.DB) dbOption { | |||
| return func(db *DBResource) { | |||
| db.db = source | |||
| } | |||
| } | |||
| func withConnector(ci driver.Connector) dbOption { | |||
| return func(db *DBResource) { | |||
| db.connector = ci | |||
| } | |||
| } | |||
| func withDBName(dbName string) dbOption { | |||
| return func(db *DBResource) { | |||
| db.dbName = dbName | |||
| } | |||
| } | |||
| func withConf(conf *seataServerConfig) dbOption { | |||
| func withConf(conf *XAConnConf) dbOption { | |||
| return func(db *DBResource) { | |||
| db.conf = *conf | |||
| db.xaConnConf = conf | |||
| } | |||
| } | |||
| @@ -77,47 +94,36 @@ func newResource(opts ...dbOption) (*DBResource, error) { | |||
| opts[i](db) | |||
| } | |||
| return db, db.init() | |||
| db.init() | |||
| return db, nil | |||
| } | |||
| // DBResource proxy sql.DB, enchance database/sql.DB to add distribute transaction ability | |||
| type DBResource struct { | |||
| // groupID | |||
| groupID string | |||
| // resourceID | |||
| xaConnConf *XAConnConf | |||
| // only use by mysql | |||
| dbVersion string | |||
| dsn string | |||
| resourceID string | |||
| // conf | |||
| conf seataServerConfig | |||
| // db | |||
| db *sql.DB | |||
| dbName string | |||
| // dbType | |||
| dbType types.DBType | |||
| // undoLogMgr | |||
| db *sql.DB | |||
| connector driver.Connector | |||
| dbName string | |||
| dbType types.DBType | |||
| undoLogMgr undo.UndoLogManager | |||
| // metaCache | |||
| metaCache datasource.TableMetaCache | |||
| } | |||
| branchType branch.BranchType | |||
| func (db *DBResource) init() error { | |||
| return nil | |||
| // for xa | |||
| metaCache datasource.TableMetaCache | |||
| shouldBeHeld bool | |||
| keeper sync.Map | |||
| } | |||
| // todo do not put meta data to rm | |||
| //func (db *DBResource) init() error { | |||
| // mgr := datasource.GetDataSourceManager(db.GetBranchType()) | |||
| // metaCache, err := mgr.CreateTableMetaCache(context.Background(), db.resourceID, db.dbType, db.db) | |||
| // if err != nil { | |||
| // return err | |||
| // } | |||
| // | |||
| // db.metaCache = metaCache | |||
| // | |||
| // return nil | |||
| //} | |||
| func (db *DBResource) GetResourceGroupId() string { | |||
| return db.groupID | |||
| panic("implement me") | |||
| } | |||
| func (db *DBResource) init() { | |||
| db.checkDbVersion() | |||
| } | |||
| func (db *DBResource) GetResourceId() string { | |||
| @@ -125,18 +131,106 @@ func (db *DBResource) GetResourceId() string { | |||
| } | |||
| func (db *DBResource) GetBranchType() branch.BranchType { | |||
| return db.conf.BranchType | |||
| return db.branchType | |||
| } | |||
| func (db *DBResource) GetDB() *sql.DB { | |||
| return db.db | |||
| } | |||
| type SqlDBProxy struct { | |||
| db *sql.DB | |||
| dbName string | |||
| func (db *DBResource) GetDBName() string { | |||
| return db.dbName | |||
| } | |||
| func (s *SqlDBProxy) GetDB() *sql.DB { | |||
| return s.db | |||
| func (db *DBResource) GetDbType() types.DBType { | |||
| return db.dbType | |||
| } | |||
| func (s *SqlDBProxy) GetDBName() string { | |||
| return s.dbName | |||
| func (db *DBResource) SetDbType(dbType types.DBType) { | |||
| db.dbType = dbType | |||
| } | |||
| func (db *DBResource) SetDbVersion(v string) { | |||
| db.dbVersion = v | |||
| } | |||
| func (db *DBResource) GetDbVersion() string { | |||
| return db.dbVersion | |||
| } | |||
| func (db *DBResource) IsShouldBeHeld() bool { | |||
| return db.shouldBeHeld | |||
| } | |||
| // Hold the xa connection. | |||
| func (db *DBResource) Hold(xaBranchID string, v interface{}) error { | |||
| _, exist := db.keeper.Load(xaBranchID) | |||
| if !exist { | |||
| db.keeper.Store(xaBranchID, v) | |||
| return nil | |||
| } | |||
| return nil | |||
| } | |||
| func (db *DBResource) Release(xaBranchID string) { | |||
| db.keeper.Delete(xaBranchID) | |||
| } | |||
| func (db *DBResource) Lookup(xaBranchID string) (interface{}, bool) { | |||
| return db.keeper.Load(xaBranchID) | |||
| } | |||
| func (db *DBResource) GetKeeper() *sync.Map { | |||
| return &db.keeper | |||
| } | |||
| func (db *DBResource) ConnectionForXA(ctx context.Context, xaXid XAXid) (*XAConn, error) { | |||
| xaBranchXid := xaXid.String() | |||
| tmpConn, ok := db.Lookup(xaBranchXid) | |||
| if ok && tmpConn != nil { | |||
| connectionProxyXa, isConnectionProxyXa := tmpConn.(*XAConn) | |||
| if !isConnectionProxyXa { | |||
| return nil, fmt.Errorf("get connection proxy xa from cache error, xid:%s", xaXid.String()) | |||
| } | |||
| return connectionProxyXa, nil | |||
| } | |||
| // why here need a new connection? | |||
| // 1. because there maybe a rm cluster | |||
| // 2. the first phase select a rm1, and store the connection is the keeper | |||
| // 3. tc request the second phase. but the rm1 is shutdown, so the tc select another rm (like rm2) | |||
| // 4. so when the second phase request coming to rm2, rm2 must not store the connection. | |||
| // 5. the rm2 get the second phase do the two thing. | |||
| // 1. in mysql version >= 8.0.29, mysql support the xa transaction commit by another connection. so just commit | |||
| // 2. when the version < 8.0.29. so just make the transaction rollback | |||
| newDriverConn, err := db.connector.Connect(ctx) | |||
| if err != nil { | |||
| return nil, fmt.Errorf("get xa new connection failure, xid:%s, err:%v", xaXid.String(), err) | |||
| } | |||
| xaConn := &XAConn{ | |||
| Conn: newDriverConn.(*Conn), | |||
| } | |||
| return xaConn, nil | |||
| } | |||
| func (db *DBResource) checkDbVersion() error { | |||
| switch db.dbType { | |||
| case types.DBTypeMySQL: | |||
| currentVersion, err := util.ConvertDbVersion(db.dbVersion) | |||
| if err != nil { | |||
| return fmt.Errorf("new connection xa proxy convert db version:%s err:%v", db.GetDbVersion(), err) | |||
| } | |||
| shouldKeptVersion, err := util.ConvertDbVersion("8.0.29") | |||
| if err != nil { | |||
| return fmt.Errorf("new connection xa proxy convert db version 8.0.29 err:%v", err) | |||
| } | |||
| if currentVersion < shouldKeptVersion { | |||
| db.shouldBeHeld = true | |||
| } | |||
| case types.DBTypeMARIADB: | |||
| db.shouldBeHeld = true | |||
| } | |||
| return nil | |||
| } | |||
| @@ -25,10 +25,10 @@ import ( | |||
| "fmt" | |||
| "strings" | |||
| mysql2 "github.com/seata/seata-go/pkg/datasource/sql/datasource/mysql" | |||
| "github.com/go-sql-driver/mysql" | |||
| "github.com/seata/seata-go/pkg/datasource/sql/datasource" | |||
| mysql2 "github.com/seata/seata-go/pkg/datasource/sql/datasource/mysql" | |||
| "github.com/seata/seata-go/pkg/datasource/sql/types" | |||
| "github.com/seata/seata-go/pkg/protocol/branch" | |||
| "github.com/seata/seata-go/pkg/util/log" | |||
| @@ -44,15 +44,17 @@ const ( | |||
| func initDriver() { | |||
| sql.Register(SeataATMySQLDriver, &seataATDriver{ | |||
| seataDriver: &seataDriver{ | |||
| transType: types.ATMode, | |||
| target: mysql.MySQLDriver{}, | |||
| branchType: branch.BranchTypeAT, | |||
| transType: types.ATMode, | |||
| target: mysql.MySQLDriver{}, | |||
| }, | |||
| }) | |||
| sql.Register(SeataXAMySQLDriver, &seataXADriver{ | |||
| seataDriver: &seataDriver{ | |||
| transType: types.XAMode, | |||
| target: mysql.MySQLDriver{}, | |||
| branchType: branch.BranchTypeXA, | |||
| transType: types.XAMode, | |||
| target: mysql.MySQLDriver{}, | |||
| }, | |||
| }) | |||
| } | |||
| @@ -98,8 +100,9 @@ func (d *seataXADriver) OpenConnector(name string) (c driver.Connector, err erro | |||
| } | |||
| type seataDriver struct { | |||
| transType types.TransactionMode | |||
| target driver.Driver | |||
| branchType branch.BranchType | |||
| transType types.TransactionMode | |||
| target driver.Driver | |||
| } | |||
| // Open never be called, because seataDriver implemented dri.DriverContext interface. | |||
| @@ -124,7 +127,7 @@ func (d *seataDriver) OpenConnector(name string) (c driver.Connector, err error) | |||
| return nil, fmt.Errorf("unsupport conn type %s", d.getTargetDriverName()) | |||
| } | |||
| proxy, err := getOpenConnectorProxy(c, dbType, sql.OpenDB(c), name) | |||
| proxy, err := d.getOpenConnectorProxy(c, dbType, sql.OpenDB(c), name) | |||
| if err != nil { | |||
| log.Errorf("register resource: %w", err) | |||
| return nil, err | |||
| @@ -133,43 +136,15 @@ func (d *seataDriver) OpenConnector(name string) (c driver.Connector, err error) | |||
| return proxy, nil | |||
| } | |||
| func (d *seataDriver) getTargetDriverName() string { | |||
| return "mysql" | |||
| } | |||
| type dsnConnector struct { | |||
| dsn string | |||
| driver driver.Driver | |||
| } | |||
| func (t *dsnConnector) Connect(_ context.Context) (driver.Conn, error) { | |||
| return t.driver.Open(t.dsn) | |||
| } | |||
| func (t *dsnConnector) Driver() driver.Driver { | |||
| return t.driver | |||
| } | |||
| func getOpenConnectorProxy(connector driver.Connector, dbType types.DBType, db *sql.DB, | |||
| dataSourceName string, opts ...seataOption) (driver.Connector, error) { | |||
| conf := loadConfig() | |||
| for i := range opts { | |||
| opts[i](conf) | |||
| } | |||
| if err := conf.validate(); err != nil { | |||
| log.Errorf("invalid conf: %w", err) | |||
| return nil, err | |||
| } | |||
| func (d *seataDriver) getOpenConnectorProxy(connector driver.Connector, dbType types.DBType, | |||
| db *sql.DB, dataSourceName string) (driver.Connector, error) { | |||
| cfg, _ := mysql.ParseDSN(dataSourceName) | |||
| options := []dbOption{ | |||
| withGroupID(conf.GroupID), | |||
| withResourceID(parseResourceID(dataSourceName)), | |||
| withConf(conf), | |||
| withTarget(db), | |||
| withBranchType(d.branchType), | |||
| withDBType(dbType), | |||
| withDBName(cfg.DBName), | |||
| withConnector(connector), | |||
| } | |||
| res, err := newResource(options...) | |||
| @@ -179,7 +154,8 @@ func getOpenConnectorProxy(connector driver.Connector, dbType types.DBType, db * | |||
| } | |||
| datasource.RegisterTableCache(types.DBTypeMySQL, mysql2.NewTableMetaInstance(db)) | |||
| if err = datasource.GetDataSourceManager(conf.BranchType).RegisterResource(res); err != nil { | |||
| if err = datasource.GetDataSourceManager(d.branchType).RegisterResource(res); err != nil { | |||
| log.Errorf("regisiter resource: %w", err) | |||
| return nil, err | |||
| } | |||
| @@ -187,38 +163,25 @@ func getOpenConnectorProxy(connector driver.Connector, dbType types.DBType, db * | |||
| return &seataConnector{ | |||
| res: res, | |||
| target: connector, | |||
| conf: conf, | |||
| cfg: cfg, | |||
| }, nil | |||
| } | |||
| type ( | |||
| seataOption func(cfg *seataServerConfig) | |||
| // seataServerConfig | |||
| seataServerConfig struct { | |||
| // GroupID | |||
| GroupID string `yaml:"groupID"` | |||
| // BranchType | |||
| BranchType branch.BranchType | |||
| // Endpoints | |||
| Endpoints []string `yaml:"endpoints" json:"endpoints"` | |||
| } | |||
| ) | |||
| func (d *seataDriver) getTargetDriverName() string { | |||
| return "mysql" | |||
| } | |||
| func (c *seataServerConfig) validate() error { | |||
| return nil | |||
| type dsnConnector struct { | |||
| dsn string | |||
| driver driver.Driver | |||
| } | |||
| // loadConfig | |||
| func loadConfig() *seataServerConfig { | |||
| // set default value first. | |||
| // todo read from configuration file. | |||
| return &seataServerConfig{ | |||
| GroupID: "DEFAULT_GROUP", | |||
| BranchType: branch.BranchTypeAT, | |||
| Endpoints: []string{"127.0.0.1:8888"}, | |||
| } | |||
| func (t *dsnConnector) Connect(_ context.Context) (driver.Conn, error) { | |||
| return t.driver.Open(t.dsn) | |||
| } | |||
| func (t *dsnConnector) Driver() driver.Driver { | |||
| return t.driver | |||
| } | |||
| func parseResourceID(dsn string) string { | |||
| @@ -28,12 +28,14 @@ import ( | |||
| "github.com/golang/mock/gomock" | |||
| "github.com/seata/seata-go/pkg/datasource/sql/mock" | |||
| "github.com/seata/seata-go/pkg/protocol/branch" | |||
| "github.com/seata/seata-go/pkg/util/reflectx" | |||
| "github.com/stretchr/testify/assert" | |||
| ) | |||
| func initMockResourceManager(t *testing.T, ctrl *gomock.Controller) *mock.MockDataSourceManager { | |||
| func initMockResourceManager(branchType branch.BranchType, ctrl *gomock.Controller) *mock.MockDataSourceManager { | |||
| mockResourceMgr := mock.NewMockDataSourceManager(ctrl) | |||
| mockResourceMgr.SetBranchType(branchType) | |||
| rm.GetRmCacheInstance().RegisterResourceManager(mockResourceMgr) | |||
| mockResourceMgr.EXPECT().RegisterResource(gomock.Any()).AnyTimes().Return(nil) | |||
| mockResourceMgr.EXPECT().CreateTableMetaCache(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes().Return(nil, nil) | |||
| @@ -45,7 +47,7 @@ func Test_seataATDriver_Open(t *testing.T) { | |||
| ctrl := gomock.NewController(t) | |||
| defer ctrl.Finish() | |||
| mockMgr := initMockResourceManager(t, ctrl) | |||
| mockMgr := initMockResourceManager(branch.BranchTypeAT, ctrl) | |||
| _ = mockMgr | |||
| db, err := sql.Open("seata-at-mysql", "root:seata_go@tcp(127.0.0.1:3306)/seata_go_test?multiStatements=true") | |||
| @@ -93,7 +95,7 @@ func Test_seataATDriver_OpenConnector(t *testing.T) { | |||
| ctrl := gomock.NewController(t) | |||
| defer ctrl.Finish() | |||
| mockMgr := initMockResourceManager(t, ctrl) | |||
| mockMgr := initMockResourceManager(branch.BranchTypeAT, ctrl) | |||
| _ = mockMgr | |||
| db, err := sql.Open("seata-at-mysql", "root:seata_go@tcp(127.0.0.1:3306)/seata_go_test?multiStatements=true") | |||
| @@ -119,7 +121,7 @@ func Test_seataXADriver_OpenConnector(t *testing.T) { | |||
| ctrl := gomock.NewController(t) | |||
| defer ctrl.Finish() | |||
| mockMgr := initMockResourceManager(t, ctrl) | |||
| mockMgr := initMockResourceManager(branch.BranchTypeAT, ctrl) | |||
| _ = mockMgr | |||
| db, err := sql.Open("seata-xa-mysql", "root:seata_go@tcp(127.0.0.1:3306)/seata_go_test?multiStatements=true") | |||
| @@ -514,8 +514,8 @@ func canAutoIncrement(pkMetaMap map[string]types.ColumnMeta) bool { | |||
| return false | |||
| } | |||
| func (u *insertExecutor) isAstStmtValid() bool { | |||
| return u.parserCtx != nil && u.parserCtx.InsertStmt != nil | |||
| func (i *insertExecutor) isAstStmtValid() bool { | |||
| return i.parserCtx != nil && i.parserCtx.InsertStmt != nil | |||
| } | |||
| func (i *insertExecutor) autoGeneratePks(execCtx *types.ExecContext, autoColumnName string, lastInsetId, updateCount int64) (map[string][]interface{}, error) { | |||
| @@ -27,7 +27,6 @@ import ( | |||
| var ( | |||
| atExecutors = make(map[types.DBType]func() SQLExecutor) | |||
| xaExecutors = make(map[types.DBType]func() SQLExecutor) | |||
| ) | |||
| // RegisterATExecutor AT executor | |||
| @@ -35,13 +34,6 @@ func RegisterATExecutor(dt types.DBType, builder func() SQLExecutor) { | |||
| atExecutors[dt] = builder | |||
| } | |||
| // RegisterXAExecutor XA executor | |||
| func RegisterXAExecutor(dt types.DBType, builder func() SQLExecutor) { | |||
| xaExecutors[dt] = func() SQLExecutor { | |||
| return builder() | |||
| } | |||
| } | |||
| type ( | |||
| CallbackWithNamedValue func(ctx context.Context, query string, args []driver.NamedValue) (types.ExecResult, error) | |||
| @@ -66,12 +58,6 @@ func BuildExecutor(dbType types.DBType, transactionMode types.TransactionMode, q | |||
| hooks = append(hooks, commonHook...) | |||
| hooks = append(hooks, hookSolts[parseContext.SQLType]...) | |||
| if transactionMode == types.XAMode { | |||
| e := xaExecutors[dbType]() | |||
| e.Interceptors(hooks) | |||
| return e, nil | |||
| } | |||
| e := atExecutors[dbType]() | |||
| e.Interceptors(hooks) | |||
| return e, nil | |||
| @@ -82,12 +68,10 @@ type BaseExecutor struct { | |||
| ex SQLExecutor | |||
| } | |||
| // Interceptors | |||
| func (e *BaseExecutor) Interceptors(interceptors []SQLHook) { | |||
| e.hooks = interceptors | |||
| } | |||
| // ExecWithNamedValue | |||
| func (e *BaseExecutor) ExecWithNamedValue(ctx context.Context, execCtx *types.ExecContext, f CallbackWithNamedValue) (types.ExecResult, error) { | |||
| for i := range e.hooks { | |||
| e.hooks[i].Before(ctx, execCtx) | |||
| @@ -1,93 +0,0 @@ | |||
| /* | |||
| * Licensed to the Apache Software Foundation (ASF) under one or more | |||
| * contributor license agreements. See the NOTICE file distributed with | |||
| * this work for additional information regarding copyright ownership. | |||
| * The ASF licenses this file to You under the Apache License, Version 2.0 | |||
| * (the "License"); you may not use this file except in compliance with | |||
| * the License. You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| package exec | |||
| import ( | |||
| "time" | |||
| ) | |||
| const ( | |||
| // TMENDICANT Ends a recovery scan. | |||
| TMENDRSCAN = 0x00800000 | |||
| /** | |||
| * Disassociates the caller and marks the transaction branch | |||
| * rollback-only. | |||
| */ | |||
| TMFAIL = 0x20000000 | |||
| /** | |||
| * Caller is joining existing transaction branch. | |||
| */ | |||
| TMJOIN = 0x00200000 | |||
| /** | |||
| * Use TMNOFLAGS to indicate no flags value is selected. | |||
| */ | |||
| TMNOFLAGS = 0x00000000 | |||
| /** | |||
| * Caller is using one-phase optimization. | |||
| */ | |||
| TMONEPHASE = 0x40000000 | |||
| /** | |||
| * Caller is resuming association with a suspended | |||
| * transaction branch. | |||
| */ | |||
| TMRESUME = 0x08000000 | |||
| /** | |||
| * Starts a recovery scan. | |||
| */ | |||
| TMSTARTRSCAN = 0x01000000 | |||
| /** | |||
| * Disassociates caller from a transaction branch. | |||
| */ | |||
| TMSUCCESS = 0x04000000 | |||
| /** | |||
| * Caller is suspending (not ending) its association with | |||
| * a transaction branch. | |||
| */ | |||
| TMSUSPEND = 0x02000000 | |||
| /** | |||
| * The transaction branch has been read-only and has been committed. | |||
| */ | |||
| XA_RDONLY = 0x00000003 | |||
| /** | |||
| * The transaction work has been prepared normally. | |||
| */ | |||
| XA_OK = 0 | |||
| ) | |||
| type XAResource interface { | |||
| Commit(xid string, onePhase bool) error | |||
| End(xid string, flags int) error | |||
| Forget(xid string) error | |||
| GetTransactionTimeout() time.Duration | |||
| IsSameRM(resource XAResource) bool | |||
| XAPrepare(xid string) error | |||
| Recover(flag int) ([]string, error) | |||
| Rollback(xid string) error | |||
| SetTransactionTimeout(duration time.Duration) bool | |||
| Start(xid string, flags int) error | |||
| } | |||
| @@ -30,6 +30,7 @@ import ( | |||
| "github.com/arana-db/parser/ast" | |||
| "github.com/arana-db/parser/format" | |||
| "github.com/arana-db/parser/model" | |||
| "github.com/seata/seata-go/pkg/datasource/sql/datasource" | |||
| "github.com/seata/seata-go/pkg/datasource/sql/types" | |||
| "github.com/seata/seata-go/pkg/datasource/sql/undo/builder" | |||
| @@ -1,84 +0,0 @@ | |||
| /* | |||
| * Licensed to the Apache Software Foundation (ASF) under one or more | |||
| * contributor license agreements. See the NOTICE file distributed with | |||
| * this work for additional information regarding copyright ownership. | |||
| * The ASF licenses this file to You under the Apache License, Version 2.0 | |||
| * (the "License"); you may not use this file except in compliance with | |||
| * the License. You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| package xa | |||
| import ( | |||
| "context" | |||
| "database/sql/driver" | |||
| "github.com/seata/seata-go/pkg/datasource/sql/exec" | |||
| "github.com/seata/seata-go/pkg/datasource/sql/types" | |||
| ) | |||
| // XAExecutor The XA transaction manager. | |||
| type XAExecutor struct { | |||
| hooks []exec.SQLHook | |||
| ex exec.SQLExecutor | |||
| } | |||
| // Interceptors set xa executor hooks | |||
| func (e *XAExecutor) Interceptors(hooks []exec.SQLHook) { | |||
| e.hooks = hooks | |||
| } | |||
| // ExecWithNamedValue | |||
| func (e *XAExecutor) ExecWithNamedValue(ctx context.Context, execCtx *types.ExecContext, f exec.CallbackWithNamedValue) (types.ExecResult, error) { | |||
| for _, hook := range e.hooks { | |||
| hook.Before(ctx, execCtx) | |||
| } | |||
| defer func() { | |||
| for _, hook := range e.hooks { | |||
| hook.After(ctx, execCtx) | |||
| } | |||
| }() | |||
| if e.ex != nil { | |||
| return e.ex.ExecWithNamedValue(ctx, execCtx, f) | |||
| } | |||
| return f(ctx, execCtx.Query, execCtx.NamedValues) | |||
| } | |||
| // ExecWithValue | |||
| func (e *XAExecutor) ExecWithValue(ctx context.Context, execCtx *types.ExecContext, f exec.CallbackWithNamedValue) (types.ExecResult, error) { | |||
| for _, hook := range e.hooks { | |||
| hook.Before(ctx, execCtx) | |||
| } | |||
| defer func() { | |||
| for _, hook := range e.hooks { | |||
| hook.After(ctx, execCtx) | |||
| } | |||
| }() | |||
| if e.ex != nil { | |||
| return e.ex.ExecWithValue(ctx, execCtx, f) | |||
| } | |||
| nvargs := make([]driver.NamedValue, len(execCtx.Values)) | |||
| for i, value := range execCtx.Values { | |||
| nvargs = append(nvargs, driver.NamedValue{ | |||
| Value: value, | |||
| Ordinal: i, | |||
| }) | |||
| } | |||
| execCtx.NamedValues = nvargs | |||
| return f(ctx, execCtx.Query, execCtx.NamedValues) | |||
| } | |||
| @@ -38,6 +38,7 @@ import ( | |||
| type MockDataSourceManager struct { | |||
| ctrl *gomock.Controller | |||
| recorder *MockDataSourceManagerMockRecorder | |||
| branchType branch.BranchType | |||
| } | |||
| // MockDataSourceManagerMockRecorder is the mock recorder for MockDataSourceManager. | |||
| @@ -131,9 +132,13 @@ func (mr *MockDataSourceManagerMockRecorder) CreateTableMetaCache(ctx, resID, db | |||
| return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateTableMetaCache", reflect.TypeOf((*MockDataSourceManager)(nil).CreateTableMetaCache), ctx, resID, dbType, db) | |||
| } | |||
| func (m *MockDataSourceManager) SetBranchType(branchType branch.BranchType) { | |||
| m.branchType = branchType | |||
| } | |||
| // GetBranchType mocks base method. | |||
| func (m *MockDataSourceManager) GetBranchType() branch.BranchType { | |||
| return branch.BranchTypeAT | |||
| return m.branchType | |||
| } | |||
| // GetBranchType indicates an expected call of GetBranchType. | |||
| @@ -20,7 +20,6 @@ package sql | |||
| import ( | |||
| "github.com/seata/seata-go/pkg/datasource/sql/exec" | |||
| "github.com/seata/seata-go/pkg/datasource/sql/exec/at" | |||
| "github.com/seata/seata-go/pkg/datasource/sql/exec/xa" | |||
| "github.com/seata/seata-go/pkg/datasource/sql/hook" | |||
| "github.com/seata/seata-go/pkg/datasource/sql/types" | |||
| "github.com/seata/seata-go/pkg/datasource/sql/undo" | |||
| @@ -42,7 +41,6 @@ func hookRegister() { | |||
| func executorRegister() { | |||
| at.Init() | |||
| xa.Init() | |||
| } | |||
| func undoInit() { | |||
| @@ -1,31 +0,0 @@ | |||
| /* | |||
| * Licensed to the Apache Software Foundation (ASF) under one or more | |||
| * contributor license agreements. See the NOTICE file distributed with | |||
| * this work for additional information regarding copyright ownership. | |||
| * The ASF licenses this file to You under the Apache License, Version 2.0 | |||
| * (the "License"); you may not use this file except in compliance with | |||
| * the License. You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| package sql | |||
| import ( | |||
| "github.com/seata/seata-go/pkg/protocol/branch" | |||
| ) | |||
| type RootContext interface { | |||
| RootContext() | |||
| SetDefaultBranchType(branchType branch.BranchType) | |||
| GetXID() string | |||
| Bind(xid string) | |||
| GetTimeout() (int, bool) | |||
| SetTimeout(timeout int) | |||
| } | |||
| @@ -22,17 +22,16 @@ import ( | |||
| "database/sql/driver" | |||
| "fmt" | |||
| "sync" | |||
| "time" | |||
| "github.com/seata/seata-go/pkg/datasource/sql/datasource" | |||
| "github.com/seata/seata-go/pkg/datasource/sql/types" | |||
| "github.com/seata/seata-go/pkg/protocol/branch" | |||
| "github.com/seata/seata-go/pkg/rm" | |||
| "github.com/seata/seata-go/pkg/util/backoff" | |||
| "github.com/seata/seata-go/pkg/util/log" | |||
| "github.com/seata/seata-go/pkg/datasource/sql/types" | |||
| ) | |||
| const REPORT_RETRY_COUNT = 5 | |||
| var ( | |||
| hl sync.RWMutex | |||
| txHooks []txHook | |||
| @@ -146,19 +145,31 @@ func (tx *Tx) commitOnLocal() error { | |||
| // register | |||
| func (tx *Tx) register(ctx *types.TransactionContext) error { | |||
| if !ctx.HasUndoLog() || !ctx.HasLockKey() { | |||
| if ctx.TransactionMode.BranchType() == branch.BranchTypeUnknow { | |||
| return nil | |||
| } | |||
| lockKey := "" | |||
| for k, _ := range ctx.LockKeys { | |||
| lockKey += k + ";" | |||
| if ctx.TransactionMode.BranchType() == branch.BranchTypeAT && !ctx.HasUndoLog() || !ctx.HasLockKey() { | |||
| return nil | |||
| } | |||
| request := rm.BranchRegisterParam{ | |||
| Xid: ctx.XID, | |||
| BranchType: ctx.TransactionMode.BranchType(), | |||
| ResourceId: ctx.ResourceID, | |||
| LockKeys: lockKey, | |||
| } | |||
| var lockKey string | |||
| if ctx.TransactionMode == types.ATMode { | |||
| if !ctx.HasUndoLog() || !ctx.HasLockKey() { | |||
| return nil | |||
| } | |||
| for k, _ := range ctx.LockKeys { | |||
| lockKey += k + ";" | |||
| } | |||
| request.LockKeys = lockKey | |||
| } | |||
| dataSourceManager := datasource.GetDataSourceManager(ctx.TransactionMode.BranchType()) | |||
| branchId, err := dataSourceManager.BranchRegister(context.Background(), request) | |||
| if err != nil { | |||
| @@ -184,21 +195,22 @@ func (tx *Tx) report(success bool) error { | |||
| if dataSourceManager == nil { | |||
| return fmt.Errorf("get dataSourceManager failed") | |||
| } | |||
| retry := REPORT_RETRY_COUNT | |||
| for retry > 0 { | |||
| err := dataSourceManager.BranchReport(context.Background(), request) | |||
| if err != nil { | |||
| retry-- | |||
| log.Infof("Failed to report [%s / %s] commit done [%s] Retry Countdown: %s", tx.tranCtx.BranchID, tx.tranCtx.XID, success, retry) | |||
| if retry == 0 { | |||
| log.Errorf("Failed to report branch status: %s", err.Error()) | |||
| return err | |||
| } | |||
| } else { | |||
| return nil | |||
| retry := backoff.New(context.Background(), backoff.Config{ | |||
| MinBackoff: 100 * time.Millisecond, | |||
| MaxBackoff: 200 * time.Millisecond, | |||
| MaxRetries: 5, | |||
| }) | |||
| var err error | |||
| for retry.Ongoing() { | |||
| if err = dataSourceManager.BranchReport(context.Background(), request); err == nil { | |||
| break | |||
| } | |||
| log.Infof("Failed to report [%s / %s] commit done [%s] Retry Countdown: %s", tx.tranCtx.BranchID, tx.tranCtx.XID, success, retry) | |||
| retry.Wait() | |||
| } | |||
| return nil | |||
| return err | |||
| } | |||
| func getStatus(success bool) branch.BranchStatus { | |||
| @@ -19,6 +19,7 @@ package sql | |||
| import ( | |||
| "github.com/pkg/errors" | |||
| "github.com/seata/seata-go/pkg/datasource/sql/undo" | |||
| ) | |||
| @@ -17,7 +17,6 @@ | |||
| package sql | |||
| // XATx | |||
| type XATx struct { | |||
| tx *Tx | |||
| } | |||
| @@ -32,20 +31,14 @@ func (tx *XATx) Commit() error { | |||
| } | |||
| func (tx *XATx) Rollback() error { | |||
| err := tx.tx.Rollback() | |||
| if err != nil { | |||
| originTx := tx.tx | |||
| if originTx.tranCtx.OpenGlobalTransaction() && originTx.tranCtx.IsBranchRegistered() { | |||
| originTx.report(false) | |||
| } | |||
| originTx := tx.tx | |||
| if originTx.tranCtx.OpenGlobalTransaction() && originTx.tranCtx.IsBranchRegistered() { | |||
| return originTx.report(false) | |||
| } | |||
| return err | |||
| return nil | |||
| } | |||
| // commitOnXA | |||
| // commitOnXA commit xa and register branch transaction | |||
| func (tx *XATx) commitOnXA() error { | |||
| return nil | |||
| } | |||
| @@ -22,9 +22,9 @@ import ( | |||
| "fmt" | |||
| "strings" | |||
| "github.com/seata/seata-go/pkg/protocol/branch" | |||
| "github.com/google/uuid" | |||
| "github.com/seata/seata-go/pkg/protocol/branch" | |||
| ) | |||
| type DBType int16 | |||
| @@ -76,6 +76,7 @@ const ( | |||
| DBTypePostgreSQL | |||
| DBTypeSQLServer | |||
| DBTypeOracle | |||
| DBTypeMARIADB | |||
| BranchPhase_Unknown = 0 | |||
| BranchPhase_Done = 1 | |||
| @@ -28,8 +28,10 @@ import ( | |||
| "database/sql/driver" | |||
| "errors" | |||
| "fmt" | |||
| "math" | |||
| "reflect" | |||
| "strconv" | |||
| "strings" | |||
| "time" | |||
| ) | |||
| @@ -375,3 +377,30 @@ type decimalCompose interface { | |||
| // represented then an error should be returned. | |||
| Compose(form byte, negative bool, coefficient []byte, exponent int32) error | |||
| } | |||
| // ConvertDbVersion convert a string db version to a number. | |||
| func ConvertDbVersion(version string) (int, error) { | |||
| parts := strings.Split(version, ".") | |||
| size := len(parts) | |||
| maxVersionDot := 3 | |||
| if size > maxVersionDot+1 { | |||
| return 0, fmt.Errorf("incompatible version format: %s", version) | |||
| } | |||
| var res int | |||
| for idx, part := range parts { | |||
| if partInt, err := strconv.Atoi(part); err == nil { | |||
| res += calculatePartValue(partInt, size, idx) | |||
| } else { | |||
| subParts := strings.Split(part, "-") | |||
| if subPartInt, err := strconv.Atoi(subParts[0]); err == nil { | |||
| res += calculatePartValue(subPartInt, size, idx) | |||
| } | |||
| } | |||
| } | |||
| return res, nil | |||
| } | |||
| func calculatePartValue(partNumeric, size, index int) int { | |||
| return partNumeric * int(math.Pow(100, float64(size-index))) | |||
| } | |||
| @@ -15,22 +15,27 @@ | |||
| * limitations under the License. | |||
| */ | |||
| package sql | |||
| package util | |||
| import ( | |||
| "testing" | |||
| "github.com/seata/seata-go/pkg/datasource/sql/exec" | |||
| "github.com/seata/seata-go/pkg/datasource/sql/exec/xa" | |||
| "github.com/seata/seata-go/pkg/datasource/sql/types" | |||
| "github.com/stretchr/testify/assert" | |||
| ) | |||
| func TestConn_BuildXAExecutor(t *testing.T) { | |||
| executor, err := exec.BuildExecutor(types.DBTypeMySQL, types.XAMode, "SELECT * FROM user") | |||
| func TestConvertDbVersion(t *testing.T) { | |||
| version1 := "3.1.2" | |||
| v1Int, err1 := ConvertDbVersion(version1) | |||
| assert.NoError(t, err1) | |||
| assert.NoError(t, err) | |||
| version2 := "3.1.3" | |||
| v2Int, err2 := ConvertDbVersion(version2) | |||
| assert.NoError(t, err2) | |||
| _, ok := executor.(*xa.XAExecutor) | |||
| assert.True(t, ok, "need xa executor") | |||
| assert.Less(t, v1Int, v2Int) | |||
| version3 := "3.1.3" | |||
| v3Int, err3 := ConvertDbVersion(version3) | |||
| assert.NoError(t, err3) | |||
| assert.Equal(t, v2Int, v3Int) | |||
| } | |||
| @@ -15,24 +15,27 @@ | |||
| * limitations under the License. | |||
| */ | |||
| package exec | |||
| package xa | |||
| import ( | |||
| "context" | |||
| "database/sql/driver" | |||
| "errors" | |||
| "fmt" | |||
| "io" | |||
| "strings" | |||
| "time" | |||
| "github.com/pkg/errors" | |||
| ) | |||
| type MysqlXAConn struct { | |||
| driver.Conn | |||
| } | |||
| func (c *MysqlXAConn) Commit(xid string, onePhase bool) error { | |||
| func NewMysqlXaConn(conn driver.Conn) *MysqlXAConn { | |||
| return &MysqlXAConn{Conn: conn} | |||
| } | |||
| func (c *MysqlXAConn) Commit(ctx context.Context, xid string, onePhase bool) error { | |||
| var sb strings.Builder | |||
| sb.WriteString("XA COMMIT ") | |||
| sb.WriteString(xid) | |||
| @@ -41,33 +44,33 @@ func (c *MysqlXAConn) Commit(xid string, onePhase bool) error { | |||
| } | |||
| conn, _ := c.Conn.(driver.ExecerContext) | |||
| _, err := conn.ExecContext(context.TODO(), sb.String(), nil) | |||
| _, err := conn.ExecContext(ctx, sb.String(), nil) | |||
| return err | |||
| } | |||
| func (c *MysqlXAConn) End(xid string, flags int) error { | |||
| func (c *MysqlXAConn) End(ctx context.Context, xid string, flags int) error { | |||
| var sb strings.Builder | |||
| sb.WriteString("XA END ") | |||
| sb.WriteString(xid) | |||
| switch flags { | |||
| case TMSUCCESS: | |||
| case TMSuccess: | |||
| break | |||
| case TMSUSPEND: | |||
| case TMSuspend: | |||
| sb.WriteString(" SUSPEND") | |||
| break | |||
| case TMFAIL: | |||
| case TMFail: | |||
| break | |||
| default: | |||
| return errors.New("invalid arguments") | |||
| } | |||
| conn, _ := c.Conn.(driver.ExecerContext) | |||
| _, err := conn.ExecContext(context.TODO(), sb.String(), nil) | |||
| _, err := conn.ExecContext(ctx, sb.String(), nil) | |||
| return err | |||
| } | |||
| func (c *MysqlXAConn) Forget(xid string) error { | |||
| func (c *MysqlXAConn) Forget(ctx context.Context, xid string) error { | |||
| // mysql doesn't support this | |||
| return errors.New("mysql doesn't support this") | |||
| } | |||
| @@ -78,29 +81,29 @@ func (c *MysqlXAConn) GetTransactionTimeout() time.Duration { | |||
| // IsSameRM is called to determine if the resource manager instance represented by the target object | |||
| // is the same as the resource manager instance represented by the parameter xares. | |||
| func (c *MysqlXAConn) IsSameRM(xares XAResource) bool { | |||
| func (c *MysqlXAConn) IsSameRM(ctx context.Context, xares XAResource) bool { | |||
| // todo: the fn depends on the driver.Conn, but it doesn't support | |||
| return false | |||
| } | |||
| func (c *MysqlXAConn) XAPrepare(xid string) error { | |||
| func (c *MysqlXAConn) XAPrepare(ctx context.Context, xid string) error { | |||
| var sb strings.Builder | |||
| sb.WriteString("XA PREPARE ") | |||
| sb.WriteString(xid) | |||
| conn, _ := c.Conn.(driver.ExecerContext) | |||
| _, err := conn.ExecContext(context.TODO(), sb.String(), nil) | |||
| _, err := conn.ExecContext(ctx, sb.String(), nil) | |||
| return err | |||
| } | |||
| // Recover Obtains a list of prepared transaction branches from a resource manager. | |||
| // The transaction manager calls this method during recovery to obtain the list of transaction branches | |||
| // that are currently in prepared or heuristically completed states. | |||
| func (c *MysqlXAConn) Recover(flag int) (xids []string, err error) { | |||
| startRscan := (flag & TMSTARTRSCAN) > 0 | |||
| endRscan := (flag & TMENDRSCAN) > 0 | |||
| func (c *MysqlXAConn) Recover(ctx context.Context, flag int) (xids []string, err error) { | |||
| startRscan := (flag & TMStartRScan) > 0 | |||
| endRscan := (flag & TMEndRScan) > 0 | |||
| if !startRscan && !endRscan && flag != TMNOFLAGS { | |||
| if !startRscan && !endRscan && flag != TMNoFlags { | |||
| return nil, errors.New("invalid arguments") | |||
| } | |||
| @@ -109,7 +112,7 @@ func (c *MysqlXAConn) Recover(flag int) (xids []string, err error) { | |||
| } | |||
| conn := c.Conn.(driver.QueryerContext) | |||
| res, err := conn.QueryContext(context.TODO(), "XA RECOVER", nil) | |||
| res, err := conn.QueryContext(ctx, "XA RECOVER", nil) | |||
| if err != nil { | |||
| return nil, err | |||
| } | |||
| @@ -133,13 +136,13 @@ func (c *MysqlXAConn) Recover(flag int) (xids []string, err error) { | |||
| return xids, err | |||
| } | |||
| func (c *MysqlXAConn) Rollback(xid string) error { | |||
| func (c *MysqlXAConn) Rollback(ctx context.Context, xid string) error { | |||
| var sb strings.Builder | |||
| sb.WriteString("XA ROLLBACK ") | |||
| sb.WriteString(xid) | |||
| conn, _ := c.Conn.(driver.ExecerContext) | |||
| _, err := conn.ExecContext(context.TODO(), sb.String(), nil) | |||
| _, err := conn.ExecContext(ctx, sb.String(), nil) | |||
| return err | |||
| } | |||
| @@ -147,25 +150,25 @@ func (c *MysqlXAConn) SetTransactionTimeout(duration time.Duration) bool { | |||
| return false | |||
| } | |||
| func (c *MysqlXAConn) Start(xid string, flags int) error { | |||
| func (c *MysqlXAConn) Start(ctx context.Context, xid string, flags int) error { | |||
| var sb strings.Builder | |||
| sb.WriteString("XA START") | |||
| sb.WriteString(xid) | |||
| switch flags { | |||
| case TMJOIN: | |||
| case TMJoin: | |||
| sb.WriteString(" JOIN") | |||
| break | |||
| case TMRESUME: | |||
| case TMResume: | |||
| sb.WriteString(" RESUME") | |||
| break | |||
| case TMNOFLAGS: | |||
| case TMNoFlags: | |||
| break | |||
| default: | |||
| return errors.New("invalid arguments") | |||
| } | |||
| conn, _ := c.Conn.(driver.ExecerContext) | |||
| _, err := conn.ExecContext(context.TODO(), sb.String(), nil) | |||
| _, err := conn.ExecContext(ctx, sb.String(), nil) | |||
| return err | |||
| } | |||
| @@ -15,19 +15,18 @@ | |||
| * limitations under the License. | |||
| */ | |||
| package exec | |||
| package xa | |||
| import ( | |||
| "context" | |||
| "database/sql/driver" | |||
| "fmt" | |||
| "errors" | |||
| "io" | |||
| "reflect" | |||
| "strings" | |||
| "testing" | |||
| "github.com/golang/mock/gomock" | |||
| "github.com/pkg/errors" | |||
| "github.com/seata/seata-go/pkg/datasource/sql/mock" | |||
| ) | |||
| @@ -78,7 +77,7 @@ func TestMysqlXAConn_Commit(t *testing.T) { | |||
| c := &MysqlXAConn{ | |||
| Conn: mockConn, | |||
| } | |||
| if err := c.Commit(tt.input.xid, tt.input.onePhase); (err != nil) != tt.wantErr { | |||
| if err := c.Commit(context.Background(), tt.input.xid, tt.input.onePhase); (err != nil) != tt.wantErr { | |||
| t.Errorf("Commit() error = %v, wantErr %v", err, tt.wantErr) | |||
| } | |||
| }) | |||
| @@ -102,7 +101,7 @@ func TestMysqlXAConn_End(t *testing.T) { | |||
| name: "tm success", | |||
| input: args{ | |||
| xid: "xid", | |||
| flags: TMSUCCESS, | |||
| flags: TMSuccess, | |||
| }, | |||
| wantErr: false, | |||
| }, | |||
| @@ -110,7 +109,7 @@ func TestMysqlXAConn_End(t *testing.T) { | |||
| name: "tm failed", | |||
| input: args{ | |||
| xid: "xid", | |||
| flags: TMFAIL, | |||
| flags: TMFail, | |||
| }, | |||
| wantErr: false, | |||
| }, | |||
| @@ -124,7 +123,7 @@ func TestMysqlXAConn_End(t *testing.T) { | |||
| c := &MysqlXAConn{ | |||
| Conn: mockConn, | |||
| } | |||
| if err := c.End(tt.input.xid, tt.input.flags); (err != nil) != tt.wantErr { | |||
| if err := c.End(context.Background(), tt.input.xid, tt.input.flags); (err != nil) != tt.wantErr { | |||
| t.Errorf("End() error = %v, wantErr %v", err, tt.wantErr) | |||
| } | |||
| }) | |||
| @@ -148,7 +147,7 @@ func TestMysqlXAConn_Start(t *testing.T) { | |||
| name: "normal start", | |||
| input: args{ | |||
| xid: "xid", | |||
| flags: TMNOFLAGS, | |||
| flags: TMNoFlags, | |||
| }, | |||
| wantErr: false, | |||
| }, | |||
| @@ -161,7 +160,7 @@ func TestMysqlXAConn_Start(t *testing.T) { | |||
| c := &MysqlXAConn{ | |||
| Conn: mockConn, | |||
| } | |||
| if err := c.Start(tt.input.xid, tt.input.flags); (err != nil) != tt.wantErr { | |||
| if err := c.Start(context.Background(), tt.input.xid, tt.input.flags); (err != nil) != tt.wantErr { | |||
| t.Errorf("Start() error = %v, wantErr %v", err, tt.wantErr) | |||
| } | |||
| }) | |||
| @@ -196,7 +195,7 @@ func TestMysqlXAConn_XAPrepare(t *testing.T) { | |||
| c := &MysqlXAConn{ | |||
| Conn: mockConn, | |||
| } | |||
| if err := c.XAPrepare(tt.input.xid); (err != nil) != tt.wantErr { | |||
| if err := c.XAPrepare(context.Background(), tt.input.xid); (err != nil) != tt.wantErr { | |||
| t.Errorf("XAPrepare() error = %v, wantErr %v", err, tt.wantErr) | |||
| } | |||
| }) | |||
| @@ -219,7 +218,7 @@ func TestMysqlXAConn_Recover(t *testing.T) { | |||
| { | |||
| name: "normal recover", | |||
| args: args{ | |||
| flag: TMSTARTRSCAN | TMENDRSCAN, | |||
| flag: TMStartRScan | TMEndRScan, | |||
| }, | |||
| want: []string{"xid", "another_xid"}, | |||
| wantErr: false, | |||
| @@ -227,14 +226,14 @@ func TestMysqlXAConn_Recover(t *testing.T) { | |||
| { | |||
| name: "invalid flag for recover", | |||
| args: args{ | |||
| flag: TMFAIL, | |||
| flag: TMFail, | |||
| }, | |||
| wantErr: true, | |||
| }, | |||
| { | |||
| name: "valid flag for recover but don't scan", | |||
| args: args{ | |||
| flag: TMENDRSCAN, | |||
| flag: TMEndRScan, | |||
| }, | |||
| want: nil, | |||
| wantErr: false, | |||
| @@ -257,7 +256,7 @@ func TestMysqlXAConn_Recover(t *testing.T) { | |||
| c := &MysqlXAConn{ | |||
| Conn: mockConn, | |||
| } | |||
| got, err := c.Recover(tt.args.flag) | |||
| got, err := c.Recover(context.Background(), tt.args.flag) | |||
| if (err != nil) != tt.wantErr { | |||
| t.Errorf("Recover() error = %v, wantErr %v", err, tt.wantErr) | |||
| return | |||
| @@ -295,10 +294,7 @@ func (m *mysqlMockRows) Next(dest []driver.Value) error { | |||
| } | |||
| return b | |||
| } | |||
| cnt := min(len(m.data[0]), len(dest)) | |||
| fmt.Printf("cnt: %d", cnt) | |||
| for i := 0; i < cnt; i++ { | |||
| dest[i] = m.data[m.idx][i] | |||
| } | |||
| @@ -15,7 +15,7 @@ | |||
| * limitations under the License. | |||
| */ | |||
| package conn | |||
| package xa | |||
| import ( | |||
| "context" | |||
| @@ -15,7 +15,7 @@ | |||
| * limitations under the License. | |||
| */ | |||
| package conn | |||
| package xa | |||
| import ( | |||
| "database/sql/driver" | |||
| @@ -1,26 +0,0 @@ | |||
| /* | |||
| * Licensed to the Apache Software Foundation (ASF) under one or more | |||
| * contributor license agreements. See the NOTICE file distributed with | |||
| * this work for additional information regarding copyright ownership. | |||
| * The ASF licenses this file to You under the Apache License, Version 2.0 | |||
| * (the "License"); you may not use this file except in compliance with | |||
| * the License. You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| package xa | |||
| import ( | |||
| "github.com/seata/seata-go/pkg/datasource/sql/exec" | |||
| ) | |||
| type XAConnection interface { | |||
| getXAResource() (exec.XAResource, error) | |||
| } | |||
| @@ -1,325 +0,0 @@ | |||
| /* | |||
| * Licensed to the Apache Software Foundation (ASF) under one or more | |||
| * contributor license agreements. See the NOTICE file distributed with | |||
| * this work for additional information regarding copyright ownership. | |||
| * The ASF licenses this file to You under the Apache License, Version 2.0 | |||
| * (the "License"); you may not use this file except in compliance with | |||
| * the License. You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| package xa | |||
| import ( | |||
| "context" | |||
| "fmt" | |||
| "time" | |||
| "github.com/seata/seata-go/pkg/datasource/sql" | |||
| "github.com/seata/seata-go/pkg/datasource/sql/datasource" | |||
| "github.com/seata/seata-go/pkg/datasource/sql/exec" | |||
| "github.com/seata/seata-go/pkg/protocol/branch" | |||
| "github.com/seata/seata-go/pkg/protocol/message" | |||
| "github.com/seata/seata-go/pkg/rm" | |||
| ) | |||
| type ConnectionProxyXA struct { | |||
| xaBranchXid *XABranchXid | |||
| currentAutoCommitStatus bool `default:"true"` | |||
| xaActive bool `default:"false"` | |||
| kept bool `default:"false"` | |||
| rollBacked bool `default:"false"` | |||
| branchRegisterTime int64 `default:"0"` | |||
| prepareTime int64 `default:"0"` | |||
| timeout int `default:"0"` | |||
| proxyShouldBeHeld bool `default:"false"` | |||
| originalConnection sql.Conn | |||
| xaConnection XAConnection | |||
| xaResource exec.XAResource | |||
| resource sql.BaseDataSourceResource | |||
| xid string | |||
| } | |||
| const timeout int = 60000 | |||
| func NewConnectionProxyXA(originalConnection sql.Conn, xaConnection XAConnection, resource sql.BaseDataSourceResource, xid string) (*ConnectionProxyXA, error) { | |||
| connectionProxyXA := &ConnectionProxyXA{} | |||
| connectionProxyXA.originalConnection = originalConnection | |||
| connectionProxyXA.xaConnection = xaConnection | |||
| connectionProxyXA.resource = resource | |||
| connectionProxyXA.xid = xid | |||
| connectionProxyXA.proxyShouldBeHeld = connectionProxyXA.resource.IsShouldBeHeld() | |||
| xaResource, err := xaConnection.getXAResource() | |||
| if err != nil { | |||
| return nil, fmt.Errorf("get xa resource failed") | |||
| } else { | |||
| connectionProxyXA.xaResource = xaResource | |||
| } | |||
| var rootContext sql.RootContext | |||
| transactionTimeout, ok := rootContext.GetTimeout() | |||
| if !ok { | |||
| transactionTimeout = timeout | |||
| } | |||
| if transactionTimeout < timeout { | |||
| transactionTimeout = timeout | |||
| } | |||
| connectionProxyXA.timeout = transactionTimeout | |||
| connectionProxyXA.currentAutoCommitStatus = connectionProxyXA.originalConnection.GetAutoCommit() | |||
| if !connectionProxyXA.currentAutoCommitStatus { | |||
| return nil, fmt.Errorf("connection[autocommit=false] as default is NOT supported") | |||
| } | |||
| return connectionProxyXA, nil | |||
| } | |||
| func (c *ConnectionProxyXA) keepIfNecessary() { | |||
| if c.ShouldBeHeld() { | |||
| c.resource.Hold(c.xaBranchXid.String(), c) | |||
| } | |||
| } | |||
| func (c *ConnectionProxyXA) releaseIfNecessary() { | |||
| if c.ShouldBeHeld() { | |||
| if c.xaBranchXid == nil { | |||
| if c.IsHeld() { | |||
| c.resource.Release(c.xaBranchXid.String(), c) | |||
| } | |||
| } | |||
| } | |||
| } | |||
| func (c *ConnectionProxyXA) XaCommit(xid string, branchId int64) error { | |||
| xaXid := Build(xid, branchId) | |||
| err := c.xaResource.Commit(xaXid.String(), false) | |||
| c.releaseIfNecessary() | |||
| return err | |||
| } | |||
| func (c *ConnectionProxyXA) XaRollbackByBranchId(xid string, branchId int64) { | |||
| xaXid := Build(xid, branchId) | |||
| c.XaRollback(xaXid) | |||
| } | |||
| func (c *ConnectionProxyXA) XaRollback(xaXid XAXid) error { | |||
| err := c.xaResource.Rollback(xaXid.GetGlobalXid()) | |||
| c.releaseIfNecessary() | |||
| return err | |||
| } | |||
| func (c *ConnectionProxyXA) SetAutoCommit(autoCommit bool) error { | |||
| if c.currentAutoCommitStatus == autoCommit { | |||
| return nil | |||
| } | |||
| if autoCommit { | |||
| if c.xaActive { | |||
| _ = c.Commit() | |||
| } | |||
| } else { | |||
| if c.xaActive { | |||
| return fmt.Errorf("should NEVER happen: setAutoCommit from true to false while xa branch is active") | |||
| } | |||
| c.branchRegisterTime = time.Now().UnixMilli() | |||
| var branchRegisterParam rm.BranchRegisterParam | |||
| branchRegisterParam.BranchType = branch.BranchTypeXA | |||
| branchRegisterParam.ResourceId = c.resource.GetResourceId() | |||
| branchRegisterParam.Xid = c.xid | |||
| branchId, err := datasource.GetDataSourceManager(branch.BranchTypeXA).BranchRegister(context.TODO(), branchRegisterParam) | |||
| if err != nil { | |||
| c.cleanXABranchContext() | |||
| return fmt.Errorf("failed to register xa branch [%v]", c.xid) | |||
| } | |||
| c.xaBranchXid = Build(c.xid, branchId) | |||
| c.keepIfNecessary() | |||
| err = c.start() | |||
| if err != nil { | |||
| c.cleanXABranchContext() | |||
| return fmt.Errorf("failed to start xa branch [%v]", c.xid) | |||
| } | |||
| c.xaActive = true | |||
| } | |||
| c.currentAutoCommitStatus = autoCommit | |||
| return nil | |||
| } | |||
| func (c *ConnectionProxyXA) GetAutoCommit() bool { | |||
| return c.currentAutoCommitStatus | |||
| } | |||
| func (c *ConnectionProxyXA) Commit() error { | |||
| if c.currentAutoCommitStatus { | |||
| return nil | |||
| } | |||
| if !c.xaActive || c.xaBranchXid == nil { | |||
| return fmt.Errorf("should NOT commit on an inactive session") | |||
| } | |||
| now := time.Now().UnixMilli() | |||
| if c.end(exec.TMSUCCESS) != nil { | |||
| return c.commitErrorHandle() | |||
| } | |||
| if c.checkTimeout(now) != nil { | |||
| return c.commitErrorHandle() | |||
| } | |||
| if c.xaResource.XAPrepare(c.xaBranchXid.String()) != nil { | |||
| return c.commitErrorHandle() | |||
| } | |||
| return nil | |||
| } | |||
| func (c *ConnectionProxyXA) commitErrorHandle() error { | |||
| req := message.BranchReportRequest{ | |||
| BranchType: branch.BranchTypeXA, | |||
| Xid: c.xid, | |||
| BranchId: c.xaBranchXid.GetBranchId(), | |||
| Status: branch.BranchStatusPhaseoneFailed, | |||
| ApplicationData: nil, | |||
| ResourceId: c.resource.GetResourceId(), | |||
| } | |||
| if datasource.NewBasicSourceManager().BranchReport(context.TODO(), req) != nil { | |||
| c.cleanXABranchContext() | |||
| return fmt.Errorf("Failed to report XA branch commit-failure on [%v] - [%v]", c.xid, c.xaBranchXid.GetBranchId()) | |||
| } | |||
| c.cleanXABranchContext() | |||
| return fmt.Errorf("Failed to end(TMSUCCESS)/prepare xa branch on [%v] - [%v]", c.xid, c.xaBranchXid.GetBranchId()) | |||
| } | |||
| func (c *ConnectionProxyXA) Rollback() error { | |||
| if c.currentAutoCommitStatus { | |||
| return nil | |||
| } | |||
| if !c.xaActive || c.xaBranchXid == nil { | |||
| return fmt.Errorf("should NOT rollback on an inactive session") | |||
| } | |||
| if !c.rollBacked { | |||
| if c.xaResource.End(c.xaBranchXid.String(), exec.TMFAIL) != nil { | |||
| return c.rollbackErrorHandle() | |||
| } | |||
| if c.XaRollback(c.xaBranchXid) != nil { | |||
| c.cleanXABranchContext() | |||
| return c.rollbackErrorHandle() | |||
| } | |||
| req := message.BranchReportRequest{ | |||
| BranchType: branch.BranchTypeXA, | |||
| Xid: c.xid, | |||
| BranchId: c.xaBranchXid.GetBranchId(), | |||
| Status: branch.BranchStatusPhaseoneFailed, | |||
| ApplicationData: nil, | |||
| ResourceId: c.resource.GetResourceId(), | |||
| } | |||
| if datasource.NewBasicSourceManager().BranchReport(context.TODO(), req) != nil { | |||
| c.cleanXABranchContext() | |||
| return fmt.Errorf("failed to report XA branch commit-failure on [%v] - [%v]", c.xid, c.xaBranchXid.GetBranchId()) | |||
| } | |||
| } | |||
| c.cleanXABranchContext() | |||
| return nil | |||
| } | |||
| func (c *ConnectionProxyXA) rollbackErrorHandle() error { | |||
| return fmt.Errorf("failed to end(TMFAIL) xa branch on [%v] - [%v]", c.xid, c.xaBranchXid.GetBranchId()) | |||
| } | |||
| func (c *ConnectionProxyXA) start() error { | |||
| err := c.xaResource.Start(c.xaBranchXid.String(), exec.TMNOFLAGS) | |||
| if err := c.termination(c.xaBranchXid.String()); err != nil { | |||
| c.xaResource.End(c.xaBranchXid.String(), exec.TMFAIL) | |||
| c.XaRollback(c.xaBranchXid) | |||
| return err | |||
| } | |||
| return err | |||
| } | |||
| func (c *ConnectionProxyXA) end(flags int) error { | |||
| err := c.termination(c.xaBranchXid.String()) | |||
| if err != nil { | |||
| return err | |||
| } | |||
| err = c.xaResource.End(c.xaBranchXid.String(), flags) | |||
| if err != nil { | |||
| return err | |||
| } | |||
| return nil | |||
| } | |||
| func (c *ConnectionProxyXA) cleanXABranchContext() { | |||
| c.branchRegisterTime = 0 | |||
| c.prepareTime = 0 | |||
| c.timeout = 0 | |||
| c.xaActive = false | |||
| if !c.IsHeld() { | |||
| c.xaBranchXid = nil | |||
| } | |||
| } | |||
| func (c *ConnectionProxyXA) checkTimeout(now int64) error { | |||
| if now-c.branchRegisterTime > int64(c.timeout) { | |||
| c.XaRollback(c.xaBranchXid) | |||
| return fmt.Errorf("XA branch timeout error") | |||
| } | |||
| return nil | |||
| } | |||
| func (c *ConnectionProxyXA) Close() error { | |||
| c.rollBacked = false | |||
| if c.IsHeld() && c.ShouldBeHeld() { | |||
| return nil | |||
| } | |||
| c.cleanXABranchContext() | |||
| if err := c.originalConnection.Close(); err != nil { | |||
| return err | |||
| } | |||
| return nil | |||
| } | |||
| func (c *ConnectionProxyXA) CloseForce() error { | |||
| physicalConn := c.originalConnection | |||
| if err := physicalConn.Close(); err != nil { | |||
| return err | |||
| } | |||
| c.rollBacked = false | |||
| c.cleanXABranchContext() | |||
| if err := c.originalConnection.Close(); err != nil { | |||
| return err | |||
| } | |||
| c.releaseIfNecessary() | |||
| return nil | |||
| } | |||
| func (c *ConnectionProxyXA) SetHeld(kept bool) { | |||
| c.kept = kept | |||
| } | |||
| func (c *ConnectionProxyXA) IsHeld() bool { | |||
| return c.kept | |||
| } | |||
| func (c *ConnectionProxyXA) ShouldBeHeld() bool { | |||
| return c.proxyShouldBeHeld || c.resource.GetDB() != nil | |||
| } | |||
| func (c *ConnectionProxyXA) GetPrepareTime() int64 { | |||
| return c.prepareTime | |||
| } | |||
| func (c *ConnectionProxyXA) setPrepareTime(prepareTime int64) { | |||
| c.prepareTime = prepareTime | |||
| } | |||
| func (c *ConnectionProxyXA) termination(xaBranchXid string) error { | |||
| branchStatus, err := c.resource.GetBranchStatus(xaBranchXid) | |||
| if err != nil { | |||
| c.releaseIfNecessary() | |||
| return fmt.Errorf("failed xa branch [%v] the global transaction has finish, branch status: [%v]", c.xid, branchStatus) | |||
| } | |||
| return nil | |||
| } | |||
| @@ -0,0 +1,71 @@ | |||
| /* | |||
| * Licensed to the Apache Software Foundation (ASF) under one or more | |||
| * contributor license agreements. See the NOTICE file distributed with | |||
| * this work for additional information regarding copyright ownership. | |||
| * The ASF licenses this file to You under the Apache License, Version 2.0 | |||
| * (the "License"); you may not use this file except in compliance with | |||
| * the License. You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| package xa | |||
| import ( | |||
| "context" | |||
| "time" | |||
| ) | |||
| const ( | |||
| // TMEndRScan ends a recovery scan. | |||
| TMEndRScan = 0x00800000 | |||
| // TMFail disassociates the caller and marks the transaction branch | |||
| // rollback-only. | |||
| TMFail = 0x20000000 | |||
| // TMJoin joining existing transaction branch. | |||
| TMJoin = 0x00200000 | |||
| // TMNoFlags indicate no flags value is selected. | |||
| TMNoFlags = 0x00000000 | |||
| // TMOnePhase using one-phase optimization. | |||
| TMOnePhase = 0x40000000 | |||
| // TMResume is resuming association with a suspended transaction branch. | |||
| TMResume = 0x08000000 | |||
| // TMStartRScan starts a recovery scan. | |||
| TMStartRScan = 0x01000000 | |||
| // TMSuccess disassociates caller from a transaction branch. | |||
| TMSuccess = 0x04000000 | |||
| // TMSuspend is suspending (not ending) its association with a transaction branch. | |||
| TMSuspend = 0x02000000 | |||
| // XAReadOnly the transaction branch has been read-only and has been committed. | |||
| XAReadOnly = 0x00000003 | |||
| // XAOk The transaction work has been prepared normally. | |||
| XAOk = 0 | |||
| ) | |||
| type XAResource interface { | |||
| Commit(ctx context.Context, xid string, onePhase bool) error | |||
| End(ctx context.Context, xid string, flags int) error | |||
| Forget(ctx context.Context, xid string) error | |||
| GetTransactionTimeout() time.Duration | |||
| IsSameRM(ctx context.Context, resource XAResource) bool | |||
| XAPrepare(ctx context.Context, xid string) error | |||
| Recover(ctx context.Context, flag int) ([]string, error) | |||
| Rollback(ctx context.Context, xid string) error | |||
| SetTransactionTimeout(duration time.Duration) bool | |||
| Start(ctx context.Context, xid string, flags int) error | |||
| } | |||
| @@ -18,12 +18,31 @@ | |||
| package xa | |||
| import ( | |||
| "github.com/seata/seata-go/pkg/datasource/sql/exec" | |||
| "database/sql/driver" | |||
| "fmt" | |||
| "github.com/seata/seata-go/pkg/datasource/sql/types" | |||
| "github.com/seata/seata-go/pkg/util/log" | |||
| ) | |||
| func Init() { | |||
| exec.RegisterXAExecutor(types.DBTypeMySQL, func() exec.SQLExecutor { | |||
| return &XAExecutor{} | |||
| }) | |||
| // CreateXAResource create a connection for xa with the different db type. | |||
| // Such as mysql, oracle, MARIADB, POSTGRESQL | |||
| func CreateXAResource(conn driver.Conn, dbType types.DBType) (XAResource, error) { | |||
| var err error | |||
| var xaConnection XAResource | |||
| switch dbType { | |||
| case types.DBTypeMySQL: | |||
| xaConnection = NewMysqlXaConn(conn) | |||
| case types.DBTypeOracle: | |||
| case types.DBTypePostgreSQL: | |||
| default: | |||
| err = fmt.Errorf("not support db type for :%s", dbType.String()) | |||
| } | |||
| if err != nil { | |||
| log.Errorf(err.Error()) | |||
| return nil, err | |||
| } | |||
| return xaConnection, nil | |||
| } | |||
| @@ -15,7 +15,7 @@ | |||
| * limitations under the License. | |||
| */ | |||
| package xa | |||
| package sql | |||
| import ( | |||
| "strconv" | |||
| @@ -23,13 +23,12 @@ import ( | |||
| ) | |||
| const ( | |||
| BranchIdPrefix = "-" | |||
| SeataXaXidFormatId = 9752 | |||
| branchIdPrefix = "-" | |||
| ) | |||
| type XABranchXid struct { | |||
| xid string | |||
| branchId int64 | |||
| branchId uint64 | |||
| globalTransactionId []byte | |||
| branchQualifier []byte | |||
| } | |||
| @@ -63,14 +62,10 @@ func (x *XABranchXid) GetGlobalXid() string { | |||
| return x.xid | |||
| } | |||
| func (x *XABranchXid) GetBranchId() int64 { | |||
| func (x *XABranchXid) GetBranchId() uint64 { | |||
| return x.branchId | |||
| } | |||
| func (x *XABranchXid) GetFormatId() int { | |||
| return SeataXaXidFormatId | |||
| } | |||
| func (x *XABranchXid) GetGlobalTransactionId() []byte { | |||
| return x.globalTransactionId | |||
| } | |||
| @@ -80,7 +75,7 @@ func (x *XABranchXid) GetBranchQualifier() []byte { | |||
| } | |||
| func (x *XABranchXid) String() string { | |||
| return x.xid + BranchIdPrefix + strconv.FormatInt(x.branchId, 10) | |||
| return x.xid + branchIdPrefix + strconv.FormatUint(x.branchId, 10) | |||
| } | |||
| func WithXid(xid string) Option { | |||
| @@ -89,7 +84,7 @@ func WithXid(xid string) Option { | |||
| } | |||
| } | |||
| func WithBranchId(branchId int64) Option { | |||
| func WithBranchId(branchId uint64) Option { | |||
| return func(x *XABranchXid) { | |||
| x.branchId = branchId | |||
| } | |||
| @@ -113,7 +108,7 @@ func encode(x *XABranchXid) { | |||
| } | |||
| if x.branchId != 0 { | |||
| x.branchQualifier = []byte(BranchIdPrefix + strconv.FormatInt(x.branchId, 10)) | |||
| x.branchQualifier = []byte(branchIdPrefix + strconv.FormatUint(x.branchId, 10)) | |||
| } | |||
| } | |||
| @@ -123,7 +118,7 @@ func decode(x *XABranchXid) { | |||
| } | |||
| if len(x.branchQualifier) > 0 { | |||
| branchId := strings.TrimLeft(string(x.branchQualifier), BranchIdPrefix) | |||
| x.branchId, _ = strconv.ParseInt(branchId, 10, 64) | |||
| branchId := strings.TrimLeft(string(x.branchQualifier), branchIdPrefix) | |||
| x.branchId, _ = strconv.ParseUint(branchId, 10, 64) | |||
| } | |||
| } | |||
| @@ -15,7 +15,7 @@ | |||
| * limitations under the License. | |||
| */ | |||
| package xa | |||
| package sql | |||
| import ( | |||
| "testing" | |||
| @@ -25,8 +25,8 @@ import ( | |||
| func TestXABranchXidBuild(t *testing.T) { | |||
| xid := "111" | |||
| branchId := int64(222) | |||
| x := Build(xid, branchId) | |||
| branchId := uint64(222) | |||
| x := XaIdBuild(xid, branchId) | |||
| assert.Equal(t, x.GetGlobalXid(), xid) | |||
| assert.Equal(t, x.GetBranchId(), branchId) | |||
| @@ -36,11 +36,11 @@ func TestXABranchXidBuild(t *testing.T) { | |||
| func TestXABranchXidBuildWithByte(t *testing.T) { | |||
| xid := []byte("111") | |||
| branchId := []byte(BranchIdPrefix + "222") | |||
| x := BuildWithByte(xid, branchId) | |||
| branchId := []byte(branchIdPrefix + "222") | |||
| x := XaIdBuildWithByte(xid, branchId) | |||
| assert.Equal(t, x.GetGlobalTransactionId(), xid) | |||
| assert.Equal(t, x.GetBranchQualifier(), branchId) | |||
| assert.Equal(t, x.GetGlobalXid(), "111") | |||
| assert.Equal(t, x.GetBranchId(), int64(222)) | |||
| assert.Equal(t, x.GetBranchId(), uint64(222)) | |||
| } | |||
| @@ -0,0 +1,245 @@ | |||
| /* | |||
| * Licensed to the Apache Software Foundation (ASF) under one or more | |||
| * contributor license agreements. See the NOTICE file distributed with | |||
| * this work for additional information regarding copyright ownership. | |||
| * The ASF licenses this file to You under the Apache License, Version 2.0 | |||
| * (the "License"); you may not use this file except in compliance with | |||
| * the License. You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| package sql | |||
| import ( | |||
| "context" | |||
| "database/sql" | |||
| "errors" | |||
| "flag" | |||
| "fmt" | |||
| "sync" | |||
| "time" | |||
| "github.com/bluele/gcache" | |||
| "github.com/seata/seata-go/pkg/datasource/sql/datasource" | |||
| "github.com/seata/seata-go/pkg/datasource/sql/types" | |||
| "github.com/seata/seata-go/pkg/protocol/branch" | |||
| "github.com/seata/seata-go/pkg/rm" | |||
| "github.com/seata/seata-go/pkg/util/log" | |||
| ) | |||
| var branchStatusCache gcache.Cache | |||
| type XAConnConf struct { | |||
| XaBranchExecutionTimeout time.Duration `json:"xa_branch_execution_timeout" xml:"xa_branch_execution_timeout" koanf:"xa_branch_execution_timeout"` | |||
| } | |||
| func (cfg *XAConnConf) RegisterFlagsWithPrefix(prefix string, f *flag.FlagSet) { | |||
| f.DurationVar(&cfg.XaBranchExecutionTimeout, prefix+".xa_branch_execution_timeout", time.Minute, "Undo log table name.") | |||
| } | |||
| type XAConfig struct { | |||
| xaConnConf XAConnConf | |||
| TwoPhaseHoldTime time.Duration `json:"two_phase_hold_time" yaml:"xa_two_phase_hold_time" koanf:"xa_two_phase_hold_time"` | |||
| } | |||
| func (cfg *XAConfig) RegisterFlagsWithPrefix(prefix string, f *flag.FlagSet) { | |||
| f.DurationVar(&cfg.TwoPhaseHoldTime, prefix+".two_phase_hold_time", time.Millisecond*1000, "Undo log table name.") | |||
| cfg.xaConnConf.RegisterFlagsWithPrefix(prefix, f) | |||
| } | |||
| func InitXA(config XAConfig) *XAResourceManager { | |||
| xaSourceManager := &XAResourceManager{ | |||
| resourceCache: sync.Map{}, | |||
| basic: datasource.NewBasicSourceManager(), | |||
| rmRemoting: rm.GetRMRemotingInstance(), | |||
| config: config, | |||
| } | |||
| xaConnTimeout = config.xaConnConf.XaBranchExecutionTimeout | |||
| branchStatusCache = gcache.New(1024).LRU().Expiration(time.Minute * 10).Build() | |||
| rm.GetRmCacheInstance().RegisterResourceManager(xaSourceManager) | |||
| go xaSourceManager.xaTwoPhaseTimeoutChecker() | |||
| return xaSourceManager | |||
| } | |||
| type XAResourceManager struct { | |||
| config XAConfig | |||
| resourceCache sync.Map | |||
| basic *datasource.BasicSourceManager | |||
| rmRemoting *rm.RMRemoting | |||
| } | |||
| func (xaManager *XAResourceManager) xaTwoPhaseTimeoutChecker() { | |||
| var dbResource *DBResource | |||
| xaManager.resourceCache.Range(func(key, value any) bool { | |||
| if source, ok := value.(*DBResource); ok { | |||
| dbResource = source | |||
| } | |||
| return false | |||
| }) | |||
| if dbResource.IsShouldBeHeld() { | |||
| ticker := time.NewTicker(time.Second) | |||
| for { | |||
| select { | |||
| case <-ticker.C: | |||
| xaManager.resourceCache.Range(func(key, value any) bool { | |||
| source, ok := value.(*DBResource) | |||
| if !ok { | |||
| return true | |||
| } | |||
| if source.IsShouldBeHeld() { | |||
| return true | |||
| } | |||
| source.GetKeeper().Range(func(key, value any) bool { | |||
| connectionXA, isConnectionXA := value.(*XAConn) | |||
| if !isConnectionXA { | |||
| return true | |||
| } | |||
| if time.Now().Sub(connectionXA.prepareTime) > xaManager.config.TwoPhaseHoldTime { | |||
| if err := connectionXA.CloseForce(); err != nil { | |||
| log.Errorf("Force close the xa xid:%s physical connection fail", connectionXA.txCtx.XID) | |||
| } | |||
| } | |||
| return true | |||
| }) | |||
| return true | |||
| }) | |||
| } | |||
| } | |||
| } | |||
| } | |||
| func (xaManager *XAResourceManager) GetBranchType() branch.BranchType { | |||
| return branch.BranchTypeXA | |||
| } | |||
| func (xaManager *XAResourceManager) GetCachedResources() *sync.Map { | |||
| return &xaManager.resourceCache | |||
| } | |||
| func (xaManager *XAResourceManager) RegisterResource(res rm.Resource) error { | |||
| xaManager.resourceCache.Store(res.GetResourceId(), res) | |||
| return xaManager.basic.RegisterResource(res) | |||
| } | |||
| func (xaManager *XAResourceManager) UnregisterResource(resource rm.Resource) error { | |||
| return xaManager.basic.UnregisterResource(resource) | |||
| } | |||
| func (xaManager *XAResourceManager) xaIDBuilder(xid string, branchId uint64) XAXid { | |||
| return XaIdBuild(xid, branchId) | |||
| } | |||
| func (xaManager *XAResourceManager) finishBranch(ctx context.Context, xaID XAXid, branchResource rm.BranchResource) (*XAConn, error) { | |||
| resource, ok := xaManager.resourceCache.Load(branchResource.ResourceId) | |||
| if !ok { | |||
| err := fmt.Errorf("unknow resource for rollback xa, resourceId: %s", branchResource.ResourceId) | |||
| log.Errorf(err.Error()) | |||
| return nil, err | |||
| } | |||
| dbResource, ok := resource.(*DBResource) | |||
| if !ok { | |||
| err := fmt.Errorf("unknow resource for rollback xa, resourceId: %s", branchResource.ResourceId) | |||
| log.Errorf(err.Error()) | |||
| return nil, err | |||
| } | |||
| connectionProxyXA, err := dbResource.ConnectionForXA(ctx, xaID) | |||
| if err != nil { | |||
| err := fmt.Errorf("get connection for rollback xa, resourceId: %s", branchResource.ResourceId) | |||
| log.Errorf(err.Error()) | |||
| return nil, err | |||
| } | |||
| return connectionProxyXA, nil | |||
| } | |||
| func (xaManager *XAResourceManager) BranchCommit(ctx context.Context, branchResource rm.BranchResource) (branch.BranchStatus, error) { | |||
| xaID := xaManager.xaIDBuilder(branchResource.Xid, uint64(branchResource.BranchId)) | |||
| connectionProxyXA, err := xaManager.finishBranch(ctx, xaID, branchResource) | |||
| if err != nil { | |||
| return branch.BranchStatusPhasetwoRollbackFailedUnretryable, err | |||
| } | |||
| if commitErr := connectionProxyXA.XaCommit(ctx, xaID.String(), branchResource.BranchId); commitErr != nil { | |||
| err := fmt.Errorf("rollback xa, resourceId: %s", branchResource.ResourceId) | |||
| log.Errorf(err.Error()) | |||
| setBranchStatus(xaID.String(), branch.BranchStatusPhasetwoCommitted) | |||
| return branch.BranchStatusPhasetwoCommitFailedUnretryable, err | |||
| } | |||
| log.Infof("%s was committed", xaID.String()) | |||
| return branch.BranchStatusPhasetwoCommitted, nil | |||
| } | |||
| func (xaManager *XAResourceManager) BranchRollback(ctx context.Context, branchResource rm.BranchResource) (branch.BranchStatus, error) { | |||
| xaID := xaManager.xaIDBuilder(branchResource.Xid, uint64(branchResource.BranchId)) | |||
| connectionProxyXA, err := xaManager.finishBranch(ctx, xaID, branchResource) | |||
| if err != nil { | |||
| return branch.BranchStatusPhasetwoRollbackFailedUnretryable, err | |||
| } | |||
| if rollbackErr := connectionProxyXA.XaRollbackByBranchId(ctx, xaID.String(), branchResource.BranchId); rollbackErr != nil { | |||
| err := fmt.Errorf("rollback xa, resourceId: %s", branchResource.ResourceId) | |||
| log.Errorf(err.Error()) | |||
| setBranchStatus(xaID.String(), branch.BranchStatusPhasetwoRollbacked) | |||
| return branch.BranchStatusPhasetwoRollbackFailedUnretryable, err | |||
| } | |||
| log.Infof("%s was rollback", xaID.String()) | |||
| return branch.BranchStatusPhasetwoRollbacked, nil | |||
| } | |||
| func (xaManager *XAResourceManager) LockQuery(ctx context.Context, param rm.LockQueryParam) (bool, error) { | |||
| return false, nil | |||
| } | |||
| func (xaManager *XAResourceManager) BranchRegister(ctx context.Context, req rm.BranchRegisterParam) (int64, error) { | |||
| return xaManager.rmRemoting.BranchRegister(req) | |||
| } | |||
| func (xaManager *XAResourceManager) BranchReport(ctx context.Context, param rm.BranchReportParam) error { | |||
| return xaManager.rmRemoting.BranchReport(param) | |||
| } | |||
| func (xaManager *XAResourceManager) CreateTableMetaCache(ctx context.Context, resID string, dbType types.DBType, db *sql.DB) (datasource.TableMetaCache, error) { | |||
| return xaManager.basic.CreateTableMetaCache(ctx, resID, dbType, db) | |||
| } | |||
| func branchStatus(xaBranchXid string) (branch.BranchStatus, error) { | |||
| tmpBranchStatus, err := branchStatusCache.GetIFPresent(xaBranchXid) | |||
| if err != nil { | |||
| if errors.Is(err, gcache.KeyNotFoundError) { | |||
| return branch.BranchStatusUnknown, nil | |||
| } | |||
| return branch.BranchStatusUnknown, err | |||
| } | |||
| branchStatus, isBranchStatus := tmpBranchStatus.(branch.BranchStatus) | |||
| if !isBranchStatus { | |||
| return branch.BranchStatusUnknown, fmt.Errorf("branchId:%s get result isn't branch status", xaBranchXid) | |||
| } | |||
| return branchStatus, nil | |||
| } | |||
| func setBranchStatus(xaBranchXid string, status branch.BranchStatus) { | |||
| branchStatusCache.Set(xaBranchXid, status) | |||
| } | |||
| @@ -15,16 +15,10 @@ | |||
| * limitations under the License. | |||
| */ | |||
| package xa | |||
| type Xid interface { | |||
| GetFormatId() int | |||
| GetGlobalTransactionId() []byte | |||
| GetBranchQualifier() []byte | |||
| } | |||
| package sql | |||
| type XAXid interface { | |||
| Xid | |||
| GetGlobalXid() string | |||
| GetBranchId() int64 | |||
| GetBranchId() uint64 | |||
| String() string | |||
| } | |||
| @@ -15,12 +15,12 @@ | |||
| * limitations under the License. | |||
| */ | |||
| package xa | |||
| package sql | |||
| func Build(xid string, branchId int64) *XABranchXid { | |||
| func XaIdBuild(xid string, branchId uint64) *XABranchXid { | |||
| return NewXABranchXid(WithXid(xid), WithBranchId(branchId)) | |||
| } | |||
| func BuildWithByte(globalTransactionId []byte, branchQualifier []byte) *XABranchXid { | |||
| func XaIdBuildWithByte(globalTransactionId []byte, branchQualifier []byte) *XABranchXid { | |||
| return NewXABranchXid(WithGlobalTransactionId(globalTransactionId), WithBranchQualifier(branchQualifier)) | |||
| } | |||
| @@ -35,71 +35,42 @@ const ( | |||
| ) | |||
| const ( | |||
| /** | |||
| * The BranchStatus_Unknown. | |||
| * description:BranchStatus_Unknown branch status. | |||
| */ | |||
| BranchStatusUnknown = BranchStatus(0) | |||
| /** | |||
| * The BranchStatus_Registered. | |||
| * description:BranchStatus_Registered to TC. | |||
| */ | |||
| BranchStatusRegistered = BranchStatus(1) | |||
| /** | |||
| * The Phase one done. | |||
| * description:Branch logic is successfully done at phase one. | |||
| */ | |||
| BranchStatusPhaseoneDone = BranchStatus(2) | |||
| /** | |||
| * The Phase one failed. | |||
| * description:Branch logic is failed at phase one. | |||
| */ | |||
| BranchStatusPhaseoneFailed = BranchStatus(3) | |||
| /** | |||
| * The Phase one timeout. | |||
| * description:Branch logic is NOT reported for a timeout. | |||
| */ | |||
| BranchStatusPhaseoneTimeout = BranchStatus(4) | |||
| /** | |||
| * The Phase two committed. | |||
| * description:Commit logic is successfully done at phase two. | |||
| */ | |||
| BranchStatusPhasetwoCommitted = BranchStatus(5) | |||
| /** | |||
| * The Phase two commit failed retryable. | |||
| * description:Commit logic is failed but retryable. | |||
| */ | |||
| BranchStatusPhasetwoCommitFailedRetryable = BranchStatus(6) | |||
| /** | |||
| * The Phase two commit failed unretryable. | |||
| * description:Commit logic is failed and NOT retryable. | |||
| */ | |||
| BranchStatusPhasetwoCommitFailedUnretryable = BranchStatus(7) | |||
| /** | |||
| * The Phase two rollbacked. | |||
| * description:Rollback logic is successfully done at phase two. | |||
| */ | |||
| BranchStatusPhasetwoRollbacked = BranchStatus(8) | |||
| /** | |||
| * The Phase two rollback failed retryable. | |||
| * description:Rollback logic is failed but retryable. | |||
| */ | |||
| BranchStatusPhasetwoRollbackFailedRetryable = BranchStatus(9) | |||
| /** | |||
| * The Phase two rollback failed unretryable. | |||
| * description:Rollback logic is failed but NOT retryable. | |||
| */ | |||
| BranchStatusPhasetwoRollbackFailedUnretryable = BranchStatus(10) | |||
| // BranchStatusUnknown the BranchStatus_Unknown. description:BranchStatus_Unknown branch status. | |||
| BranchStatusUnknown = iota | |||
| // BranchStatusRegistered the BranchStatus_Registered. description:BranchStatus_Registered to TC. | |||
| BranchStatusRegistered | |||
| // BranchStatusPhaseoneDone the Phase one done. description:Branch logic is successfully done at phase one. | |||
| BranchStatusPhaseoneDone | |||
| // BranchStatusPhaseoneFailed the Phase one failed. description:Branch logic is failed at phase one. | |||
| BranchStatusPhaseoneFailed | |||
| // BranchStatusPhaseoneTimeout the Phase one timeout. description:Branch logic is NOT reported for a timeout. | |||
| BranchStatusPhaseoneTimeout | |||
| // BranchStatusPhasetwoCommitted the Phase two committed. description:Commit logic is successfully done at phase two. | |||
| BranchStatusPhasetwoCommitted | |||
| // BranchStatusPhasetwoCommitFailedRetryable the Phase two commit failed retryable. description:Commit logic is failed but retryable. | |||
| BranchStatusPhasetwoCommitFailedRetryable | |||
| // BranchStatusPhasetwoCommitFailedUnretryable the Phase two commit failed unretryable. | |||
| // description:Commit logic is failed and NOT retryable. | |||
| BranchStatusPhasetwoCommitFailedUnretryable | |||
| // BranchStatusPhasetwoRollbacked The Phase two rollbacked. | |||
| // description:Rollback logic is successfully done at phase two. | |||
| BranchStatusPhasetwoRollbacked | |||
| // BranchStatusPhasetwoRollbackFailedRetryable the Phase two rollback failed retryable. | |||
| // description:Rollback logic is failed but retryable. | |||
| BranchStatusPhasetwoRollbackFailedRetryable | |||
| // BranchStatusPhasetwoRollbackFailedUnretryable the Phase two rollback failed unretryable. | |||
| // description:Rollback logic is failed but NOT retryable. | |||
| BranchStatusPhasetwoRollbackFailedUnretryable | |||
| ) | |||
| func (s BranchStatus) String() string { | |||
| @@ -31,7 +31,7 @@ type Resource interface { | |||
| GetBranchType() branch.BranchType | |||
| } | |||
| // branch resource which contains branch to commit or rollback | |||
| // BranchResource contains branch to commit or rollback | |||
| type BranchResource struct { | |||
| BranchType branch.BranchType | |||
| Xid string | |||
| @@ -42,9 +42,9 @@ type BranchResource struct { | |||
| // ResourceManagerInbound Control a branch transaction commit or rollback | |||
| type ResourceManagerInbound interface { | |||
| // Commit a branch transaction | |||
| // BranchCommit commit a branch transaction | |||
| BranchCommit(ctx context.Context, resource BranchResource) (branch.BranchStatus, error) | |||
| // Rollback a branch transaction | |||
| // BranchRollback rollback a branch transaction | |||
| BranchRollback(ctx context.Context, resource BranchResource) (branch.BranchStatus, error) | |||
| } | |||
| @@ -75,13 +75,13 @@ type LockQueryParam struct { | |||
| LockKeys string | |||
| } | |||
| // Resource Manager: send outbound request to TC | |||
| // ResourceManagerOutbound Resource Manager: send outbound request to TC | |||
| type ResourceManagerOutbound interface { | |||
| // Branch register long | |||
| // BranchRegister rm register the branch transaction | |||
| BranchRegister(ctx context.Context, param BranchRegisterParam) (int64, error) | |||
| // Branch report | |||
| // BranchReport branch transaction report the status | |||
| BranchReport(ctx context.Context, param BranchReportParam) error | |||
| // Lock query boolean | |||
| // LockQuery lock query boolean | |||
| LockQuery(ctx context.Context, param LockQueryParam) (bool, error) | |||
| } | |||
| @@ -90,13 +90,13 @@ type ResourceManager interface { | |||
| ResourceManagerInbound | |||
| ResourceManagerOutbound | |||
| // Register a Resource to be managed by Resource Manager | |||
| // RegisterResource register a resource to be managed by resource manager | |||
| RegisterResource(resource Resource) error | |||
| // Unregister a Resource from the Resource Manager | |||
| // UnregisterResource unregister a resource from the Resource Manager | |||
| UnregisterResource(resource Resource) error | |||
| // Get all resources managed by this manager | |||
| // GetCachedResources get all resources managed by this manager | |||
| GetCachedResources() *sync.Map | |||
| // Get the BranchType | |||
| // GetBranchType get the branch type | |||
| GetBranchType() branch.BranchType | |||
| } | |||