From 27cba64d6f639bac43f06d2ca088968c2a66da13 Mon Sep 17 00:00:00 2001 From: georgehao Date: Fri, 17 Mar 2023 14:17:59 +0800 Subject: [PATCH] feat: add xa (#467) * feat: add xa --- pkg/client/client.go | 5 +- pkg/client/config.go | 8 +- pkg/datasource/init.go | 6 +- .../sql/{at.go => at_resource_manager.go} | 16 +- pkg/datasource/sql/conn.go | 8 +- pkg/datasource/sql/conn/resource_xa.go | 91 ----- pkg/datasource/sql/conn_at.go | 3 +- pkg/datasource/sql/conn_at_test.go | 14 +- pkg/datasource/sql/conn_xa.go | 318 ++++++++++++++++- pkg/datasource/sql/conn_xa_test.go | 56 ++- pkg/datasource/sql/connector.go | 47 ++- pkg/datasource/sql/connector_test.go | 20 +- .../sql/datasource/datasource_manager.go | 60 +--- pkg/datasource/sql/datasource_resource.go | 124 ------- pkg/datasource/sql/db.go | 180 +++++++--- pkg/datasource/sql/driver.go | 97 ++---- pkg/datasource/sql/driver_test.go | 10 +- pkg/datasource/sql/exec/at/insert_executor.go | 4 +- pkg/datasource/sql/exec/executor.go | 16 - pkg/datasource/sql/exec/resource_xa.go | 93 ----- .../sql/exec/select_for_update_executor.go | 1 + pkg/datasource/sql/exec/xa/executor_xa.go | 84 ----- .../sql/mock/mock_datasource_manager.go | 7 +- pkg/datasource/sql/plugin.go | 2 - pkg/datasource/sql/root_context.go | 31 -- pkg/datasource/sql/tx.go | 56 +-- pkg/datasource/sql/tx_at.go | 1 + pkg/datasource/sql/tx_xa.go | 17 +- pkg/datasource/sql/types/types.go | 5 +- pkg/datasource/sql/util/convert.go | 29 ++ .../{conn_test.go => util/convert_test.go} | 23 +- .../mysql_xa_connection.go} | 55 +-- .../mysql_xa_connection_test.go} | 30 +- .../oracle.go => xa/oracle_xa_connection.go} | 2 +- .../oracle_xa_connection_test.go} | 2 +- pkg/datasource/sql/xa/xa_connection.go | 26 -- pkg/datasource/sql/xa/xa_connection_proxy.go | 325 ------------------ pkg/datasource/sql/xa/xa_resource.go | 71 ++++ .../default.go => xa/xa_resource_factory.go} | 29 +- pkg/datasource/sql/{xa => }/xa_branch_xid.go | 23 +- .../sql/{xa => }/xa_branch_xid_test.go | 12 +- pkg/datasource/sql/xa_resource_manager.go | 245 +++++++++++++ pkg/datasource/sql/{xa => }/xa_xid.go | 12 +- pkg/datasource/sql/{xa => }/xa_xid_builder.go | 6 +- pkg/protocol/branch/branch.go | 101 ++---- pkg/rm/rm_api.go | 22 +- 46 files changed, 1184 insertions(+), 1209 deletions(-) rename pkg/datasource/sql/{at.go => at_resource_manager.go} (91%) delete mode 100644 pkg/datasource/sql/conn/resource_xa.go delete mode 100644 pkg/datasource/sql/datasource_resource.go delete mode 100644 pkg/datasource/sql/exec/resource_xa.go delete mode 100644 pkg/datasource/sql/exec/xa/executor_xa.go delete mode 100644 pkg/datasource/sql/root_context.go rename pkg/datasource/sql/{conn_test.go => util/convert_test.go} (67%) rename pkg/datasource/sql/{exec/mysql_xa_resource.go => xa/mysql_xa_connection.go} (71%) rename pkg/datasource/sql/{exec/mysql_xa_resource_test.go => xa/mysql_xa_connection_test.go} (90%) rename pkg/datasource/sql/{conn/oracle.go => xa/oracle_xa_connection.go} (99%) rename pkg/datasource/sql/{conn/oracle_test.go => xa/oracle_xa_connection_test.go} (99%) delete mode 100644 pkg/datasource/sql/xa/xa_connection.go delete mode 100644 pkg/datasource/sql/xa/xa_connection_proxy.go create mode 100644 pkg/datasource/sql/xa/xa_resource.go rename pkg/datasource/sql/{exec/xa/default.go => xa/xa_resource_factory.go} (58%) rename pkg/datasource/sql/{xa => }/xa_branch_xid.go (83%) rename pkg/datasource/sql/{xa => }/xa_branch_xid_test.go (87%) create mode 100644 pkg/datasource/sql/xa_resource_manager.go rename pkg/datasource/sql/{xa => }/xa_xid.go (85%) rename pkg/datasource/sql/{xa => }/xa_xid_builder.go (85%) diff --git a/pkg/client/client.go b/pkg/client/client.go index b81c21cd..8fcf943c 100644 --- a/pkg/client/client.go +++ b/pkg/client/client.go @@ -43,7 +43,7 @@ func InitPath(configFilePath string) { initRmClient(cfg) initTmClient(cfg) - initDatasource(cfg) + initDatasource() } var ( @@ -84,10 +84,11 @@ func initRmClient(cfg *Config) { integration.Init() tcc.InitTCC() at.InitAT(cfg.ClientConfig.UndoConfig, cfg.AsyncWorkerConfig) + at.InitXA(cfg.ClientConfig.XaConfig) }) } -func initDatasource(cfg *Config) { +func initDatasource() { onceInitDatasource.Do(func() { datasource.Init() }) diff --git a/pkg/client/config.go b/pkg/client/config.go index a7520bac..e2abe347 100644 --- a/pkg/client/config.go +++ b/pkg/client/config.go @@ -54,15 +54,17 @@ const ( ) type ClientConfig struct { - TmConfig tm.TmConfig `yaml:"tm" json:"tm,omitempty" koanf:"tm"` - RmConfig rm.Config `yaml:"rm" json:"rm,omitempty" koanf:"rm"` - UndoConfig undo.Config `yaml:"undo" json:"undo,omitempty" koanf:"undo"` + TmConfig tm.TmConfig `yaml:"tm" json:"tm,omitempty" koanf:"tm"` + RmConfig rm.Config `yaml:"rm" json:"rm,omitempty" koanf:"rm"` + UndoConfig undo.Config `yaml:"undo" json:"undo,omitempty" koanf:"undo"` + XaConfig sql.XAConfig `yaml:"xa" json:"xa" koanf:"xa"` } func (c *ClientConfig) RegisterFlagsWithPrefix(prefix string, f *flag.FlagSet) { c.TmConfig.RegisterFlagsWithPrefix(prefix+".tm", f) c.RmConfig.RegisterFlagsWithPrefix(prefix+".rm", f) c.UndoConfig.RegisterFlagsWithPrefix(prefix+".undo", f) + c.XaConfig.RegisterFlagsWithPrefix(prefix+".xa", f) } type Config struct { diff --git a/pkg/datasource/init.go b/pkg/datasource/init.go index e9e17cfa..19aaf420 100644 --- a/pkg/datasource/init.go +++ b/pkg/datasource/init.go @@ -17,8 +17,10 @@ package datasource -import sql2 "github.com/seata/seata-go/pkg/datasource/sql" +import ( + "github.com/seata/seata-go/pkg/datasource/sql" +) func Init() { - sql2.Init() + sql.Init() } diff --git a/pkg/datasource/sql/at.go b/pkg/datasource/sql/at_resource_manager.go similarity index 91% rename from pkg/datasource/sql/at.go rename to pkg/datasource/sql/at_resource_manager.go index 350a1651..0e2da590 100644 --- a/pkg/datasource/sql/at.go +++ b/pkg/datasource/sql/at_resource_manager.go @@ -56,23 +56,23 @@ func (a *ATSourceManager) GetBranchType() branch.BranchType { return branch.BranchTypeAT } -// Get all resources managed by this manager +// GetCachedResources get all resources managed by this manager func (a *ATSourceManager) GetCachedResources() *sync.Map { return &a.resourceCache } -// Register a Resource to be managed by Resource Manager +// RegisterResource register a Resource to be managed by Resource Manager func (a *ATSourceManager) RegisterResource(res rm.Resource) error { a.resourceCache.Store(res.GetResourceId(), res) return a.basic.RegisterResource(res) } -// Unregister a Resource from the Resource Manager +// UnregisterResource unregister a Resource from the Resource Manager func (a *ATSourceManager) UnregisterResource(res rm.Resource) error { return a.basic.UnregisterResource(res) } -// Rollback a branch transaction +// BranchRollback rollback a branch transaction func (a *ATSourceManager) BranchRollback(ctx context.Context, branchResource rm.BranchResource) (branch.BranchStatus, error) { var dbResource *DBResource if resource, ok := a.resourceCache.Load(branchResource.ResourceId); !ok { @@ -103,28 +103,26 @@ func (a *ATSourceManager) BranchRollback(ctx context.Context, branchResource rm. return branch.BranchStatusPhasetwoRollbacked, nil } -// BranchCommit +// BranchCommit commit the branch transaction func (a *ATSourceManager) BranchCommit(ctx context.Context, resource rm.BranchResource) (branch.BranchStatus, error) { a.worker.BranchCommit(ctx, resource) return branch.BranchStatusPhasetwoCommitted, nil } -// LockQuery func (a *ATSourceManager) LockQuery(ctx context.Context, param rm.LockQueryParam) (bool, error) { return a.rmRemoting.LockQuery(param) } -// BranchRegister +// BranchRegister branch transaction register func (a *ATSourceManager) BranchRegister(ctx context.Context, req rm.BranchRegisterParam) (int64, error) { return a.rmRemoting.BranchRegister(req) } -// BranchReport +// BranchReport Report status of transaction branch func (a *ATSourceManager) BranchReport(ctx context.Context, param rm.BranchReportParam) error { return a.rmRemoting.BranchReport(param) } -// CreateTableMetaCache func (a *ATSourceManager) CreateTableMetaCache(ctx context.Context, resID string, dbType types.DBType, db *sql.DB) (datasource.TableMetaCache, error) { return a.basic.CreateTableMetaCache(ctx, resID, dbType, db) diff --git a/pkg/datasource/sql/conn.go b/pkg/datasource/sql/conn.go index 33b397d1..e786a1c3 100644 --- a/pkg/datasource/sql/conn.go +++ b/pkg/datasource/sql/conn.go @@ -211,7 +211,13 @@ func (c *Conn) Begin() (driver.Tx, error) { // // global transaction according to tranCtx. If so, it needs to be included in the transaction management of seata func (c *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { - c.autoCommit = false + if c.txCtx.TransactionMode == types.XAMode { + return newTx( + withDriverConn(c), + withTxCtx(c.txCtx), + withOriginTx(nil), + ) + } if conn, ok := c.targetConn.(driver.ConnBeginTx); ok { tx, err := conn.BeginTx(ctx, opts) diff --git a/pkg/datasource/sql/conn/resource_xa.go b/pkg/datasource/sql/conn/resource_xa.go deleted file mode 100644 index 52bbced5..00000000 --- a/pkg/datasource/sql/conn/resource_xa.go +++ /dev/null @@ -1,91 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package conn - -import "time" - -const ( - // TMENDICANT Ends a recovery scan. - TMENDRSCAN = 0x00800000 - - /** - * Disassociates the caller and marks the transaction branch - * rollback-only. - */ - TMFAIL = 0x20000000 - - /** - * Caller is joining existing transaction branch. - */ - TMJOIN = 0x00200000 - - /** - * Use TMNOFLAGS to indicate no flags value is selected. - */ - TMNOFLAGS = 0x00000000 - - /** - * Caller is using one-phase optimization. - */ - TMONEPHASE = 0x40000000 - - /** - * Caller is resuming association with a suspended - * transaction branch. - */ - TMRESUME = 0x08000000 - - /** - * Starts a recovery scan. - */ - TMSTARTRSCAN = 0x01000000 - - /** - * Disassociates caller from a transaction branch. - */ - TMSUCCESS = 0x04000000 - - /** - * Caller is suspending (not ending) its association with - * a transaction branch. - */ - TMSUSPEND = 0x02000000 - - /** - * The transaction branch has been read-only and has been committed. - */ - XA_RDONLY = 0x00000003 - - /** - * The transaction work has been prepared normally. - */ - XA_OK = 0 -) - -type XAResource interface { - Commit(xid string, onePhase bool) error - End(xid string, flags int) error - Forget(xid string) error - GetTransactionTimeout() time.Duration - IsSameRM(resource XAResource) bool - XAPrepare(xid string) (int, error) - Recover(flag int) []string - Rollback(xid string) error - SetTransactionTimeout(duration time.Duration) bool - Start(xid string, flags int) error -} diff --git a/pkg/datasource/sql/conn_at.go b/pkg/datasource/sql/conn_at.go index 19f09690..1ddfba51 100644 --- a/pkg/datasource/sql/conn_at.go +++ b/pkg/datasource/sql/conn_at.go @@ -22,11 +22,10 @@ import ( gosql "database/sql" "database/sql/driver" - "github.com/seata/seata-go/pkg/util/log" - "github.com/seata/seata-go/pkg/datasource/sql/exec" "github.com/seata/seata-go/pkg/datasource/sql/types" "github.com/seata/seata-go/pkg/tm" + "github.com/seata/seata-go/pkg/util/log" ) // ATConn Database connection proxy object under XA transaction model diff --git a/pkg/datasource/sql/conn_at_test.go b/pkg/datasource/sql/conn_at_test.go index 405728e3..fa92620a 100644 --- a/pkg/datasource/sql/conn_at_test.go +++ b/pkg/datasource/sql/conn_at_test.go @@ -26,11 +26,13 @@ import ( "github.com/golang/mock/gomock" "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/seata/seata-go/pkg/datasource/sql/exec" "github.com/seata/seata-go/pkg/datasource/sql/mock" "github.com/seata/seata-go/pkg/datasource/sql/types" + "github.com/seata/seata-go/pkg/protocol/branch" "github.com/seata/seata-go/pkg/tm" - "github.com/stretchr/testify/assert" ) func TestMain(m *testing.M) { @@ -41,7 +43,7 @@ func TestMain(m *testing.M) { func initAtConnTestResource(t *testing.T) (*gomock.Controller, *sql.DB, *mockSQLInterceptor, *mockTxHook) { ctrl := gomock.NewController(t) - mockMgr := initMockResourceManager(t, ctrl) + mockMgr := initMockResourceManager(branch.BranchTypeAT, ctrl) _ = mockMgr db, err := sql.Open(SeataATMySQLDriver, "root:12345678@tcp(127.0.0.1:3306)/seata_client?multiStatements=true") @@ -57,6 +59,14 @@ func initAtConnTestResource(t *testing.T) (*gomock.Controller, *sql.DB, *mockSQL mockConn := mock.NewMockTestDriverConn(ctrl) mockConn.EXPECT().Begin().AnyTimes().Return(mockTx, nil) mockConn.EXPECT().BeginTx(gomock.Any(), gomock.Any()).AnyTimes().Return(mockTx, nil) + mockConn.EXPECT().QueryContext(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes().DoAndReturn( + func(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { + rows := &mysqlMockRows{} + rows.data = [][]interface{}{ + {"8.0.29"}, + } + return rows, nil + }) baseMockConn(mockConn) connector := mock.NewMockTestDriverConnector(ctrl) diff --git a/pkg/datasource/sql/conn_xa.go b/pkg/datasource/sql/conn_xa.go index 7d1b66f7..b42cdac6 100644 --- a/pkg/datasource/sql/conn_xa.go +++ b/pkg/datasource/sql/conn_xa.go @@ -19,41 +19,74 @@ package sql import ( "context" + gosql "database/sql" "database/sql/driver" + "fmt" + "time" "github.com/seata/seata-go/pkg/datasource/sql/types" + "github.com/seata/seata-go/pkg/datasource/sql/xa" "github.com/seata/seata-go/pkg/tm" + "github.com/seata/seata-go/pkg/util/log" ) +var xaConnTimeout time.Duration + // XAConn Database connection proxy object under XA transaction model // Conn is assumed to be stateful. type XAConn struct { *Conn + + tx driver.Tx + xaResource xa.XAResource + xaBranchXid *XABranchXid + xaActive bool + rollBacked bool + branchRegisterTime time.Time + prepareTime time.Time + isConnKept bool } -// QueryContext -func (c *XAConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { +func (c *XAConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { if c.createOnceTxContext(ctx) { defer func() { c.txCtx = types.NewTxCtx() }() } - return c.Conn.QueryContext(ctx, query, args) + //ret, err := c.createNewTxOnExecIfNeed(ctx, func() (types, error) { + // ret, err := c.Conn.PrepareContext(ctx, query) + // if err != nil { + // return nil, err + // } + // return types.NewResult(types.WithRows(ret)), nil + //}) + + return c.Conn.PrepareContext(ctx, query) } -// PrepareContext -func (c *XAConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { +// QueryContext exec xa sql +func (c *XAConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { if c.createOnceTxContext(ctx) { defer func() { c.txCtx = types.NewTxCtx() }() } - return c.Conn.PrepareContext(ctx, query) + ret, err := c.createNewTxOnExecIfNeed(ctx, func() (types.ExecResult, error) { + ret, err := c.Conn.QueryContext(ctx, query, args) + if err != nil { + return nil, err + } + return types.NewResult(types.WithRows(ret)), nil + }) + + if err != nil { + return nil, err + } + return ret.GetRows(), nil } -// ExecContext func (c *XAConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { if c.createOnceTxContext(ctx) { defer func() { @@ -61,26 +94,56 @@ func (c *XAConn) ExecContext(ctx context.Context, query string, args []driver.Na }() } - return c.Conn.ExecContext(ctx, query, args) + ret, err := c.createNewTxOnExecIfNeed(ctx, func() (types.ExecResult, error) { + ret, err := c.Conn.ExecContext(ctx, query, args) + if err != nil { + return nil, err + } + return types.NewResult(types.WithResult(ret)), nil + }) + + return ret.GetResult(), err } -// BeginTx +// BeginTx like common transaction. but it just exec XA START func (c *XAConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { - c.autoCommit = false - c.txCtx = types.NewTxCtx() c.txCtx.DBType = c.res.dbType c.txCtx.TxOpt = opts + c.txCtx.ResourceID = c.res.resourceID if tm.IsGlobalTx(ctx) { - c.txCtx.TransactionMode = types.XAMode c.txCtx.XID = tm.GetXID(ctx) + c.txCtx.TransactionMode = types.XAMode } tx, err := c.Conn.BeginTx(ctx, opts) if err != nil { return nil, err } + c.tx = tx + + if c.autoCommit { + baseTx, ok := tx.(*Tx) + if !ok { + return nil, fmt.Errorf("start xa %s transaction failure for the tx is a wrong type", c.txCtx.XID) + } + + c.branchRegisterTime = time.Now() + if err := baseTx.register(c.txCtx); err != nil { + c.cleanXABranchContext() + return nil, fmt.Errorf("failed to register xa branch %s, err:%w", c.txCtx.XID, err) + } + + c.xaBranchXid = XaIdBuild(c.txCtx.XID, c.txCtx.BranchID) + c.keepIfNecessary() + + if err = c.start(ctx); err != nil { + c.cleanXABranchContext() + return nil, fmt.Errorf("failed to start xa branch xid:%s err:%w", c.txCtx.XID, err) + } + c.xaActive = true + } return &XATx{tx: tx.(*Tx)}, nil } @@ -91,9 +154,240 @@ func (c *XAConn) createOnceTxContext(ctx context.Context) bool { if onceTx { c.txCtx = types.NewTxCtx() c.txCtx.DBType = c.res.dbType + c.txCtx.ResourceID = c.res.resourceID c.txCtx.XID = tm.GetXID(ctx) c.txCtx.TransactionMode = types.XAMode + c.txCtx.GlobalLockRequire = true } return onceTx } + +func (c *XAConn) createNewTxOnExecIfNeed(ctx context.Context, f func() (types.ExecResult, error)) (types.ExecResult, error) { + var err error + if c.txCtx.TransactionMode != types.Local && c.autoCommit { + _, err = c.BeginTx(ctx, driver.TxOptions{Isolation: driver.IsolationLevel(gosql.LevelDefault)}) + if err != nil { + return nil, err + } + } + defer func() { + recoverErr := recover() + if err != nil || recoverErr != nil { + log.Errorf("conn at rollback error:%v or recoverErr:%v", err, recoverErr) + if c.tx != nil { + rollbackErr := c.tx.Rollback() + if rollbackErr != nil { + log.Errorf("conn at rollback error:%v", rollbackErr) + } + } + } + }() + + // execute SQL + ret, err := f() + if err != nil { + // XA End & Rollback + if rollbackErr := c.Rollback(ctx); rollbackErr != nil { + log.Errorf("failed to rollback xa branch of :%s, err:%w", c.txCtx.XID, rollbackErr) + } + return nil, err + } + + if c.autoCommit { + if err := c.Commit(ctx); err != nil { + log.Errorf("xa connection proxy commit failure xid:%s, err:%v", c.txCtx.XID, err) + // XA End & Rollback + if err := c.Rollback(ctx); err != nil { + log.Errorf("xa connection proxy rollback failure xid:%s, err:%v", c.txCtx.XID, err) + } + } + } + + return ret, nil +} + +func (c *XAConn) keepIfNecessary() { + if c.ShouldBeHeld() { + if err := c.res.Hold(c.xaBranchXid.String(), c); err == nil { + c.isConnKept = true + } + } +} + +func (c *XAConn) releaseIfNecessary() { + if c.ShouldBeHeld() && c.xaBranchXid.String() != "" { + if c.isConnKept { + c.res.Release(c.xaBranchXid.String()) + c.isConnKept = false + } + } +} + +func (c *XAConn) start(ctx context.Context) error { + xaResource, err := xa.CreateXAResource(c.Conn.targetConn, c.dbType) + if err != nil { + return fmt.Errorf("create xa xid:%s resoruce err:%w", c.txCtx.XID, err) + } + c.xaResource = xaResource + + if err := c.xaResource.Start(ctx, c.xaBranchXid.String(), xa.TMNoFlags); err != nil { + return fmt.Errorf("xa xid %s resource connection start err:%w", c.txCtx.XID, err) + } + + if err := c.termination(c.xaBranchXid.String()); err != nil { + c.xaResource.End(ctx, c.xaBranchXid.String(), xa.TMFail) + c.XaRollback(ctx, c.xaBranchXid) + return err + } + return err +} + +func (c *XAConn) end(ctx context.Context, flags int) error { + err := c.termination(c.xaBranchXid.String()) + if err != nil { + return err + } + err = c.xaResource.End(ctx, c.xaBranchXid.String(), flags) + if err != nil { + return err + } + return nil +} + +func (c *XAConn) termination(xaBranchXid string) error { + branchStatus, err := branchStatus(xaBranchXid) + if err != nil { + c.releaseIfNecessary() + return fmt.Errorf("failed xa branch [%v] the global transaction has finish, branch status: [%v]", c.txCtx.XID, branchStatus) + } + return nil +} + +func (c *XAConn) cleanXABranchContext() { + h, _ := time.ParseDuration("-1000h") + c.branchRegisterTime = time.Now().Add(h) + c.prepareTime = time.Now().Add(h) + c.xaActive = false + if !c.isConnKept { + c.xaBranchXid = nil + } +} + +func (c *XAConn) Rollback(ctx context.Context) error { + if c.autoCommit { + return nil + } + + if !c.xaActive || c.xaBranchXid == nil { + return fmt.Errorf("should NOT rollback on an inactive session") + } + + if !c.rollBacked { + if c.xaResource.End(ctx, c.xaBranchXid.String(), xa.TMFail) != nil { + return c.rollbackErrorHandle() + } + if c.XaRollback(ctx, c.xaBranchXid) != nil { + c.cleanXABranchContext() + return c.rollbackErrorHandle() + } + if err := c.tx.Rollback(); err != nil { + c.cleanXABranchContext() + return fmt.Errorf("failed to report XA branch commit-failure on xid:%s err:%w", c.txCtx.XID, err) + } + } + c.cleanXABranchContext() + return nil +} + +func (c *XAConn) rollbackErrorHandle() error { + return fmt.Errorf("failed to end(TMFAIL) xa branch on [%v] - [%v]", c.txCtx.XID, c.xaBranchXid.GetBranchId()) +} + +func (c *XAConn) Commit(ctx context.Context) error { + if c.autoCommit { + return nil + } + + if !c.xaActive || c.xaBranchXid == nil { + return fmt.Errorf("should NOT commit on an inactive session") + } + + now := time.Now() + if c.end(ctx, xa.TMSuccess) != nil { + return c.commitErrorHandle() + } + + if c.checkTimeout(ctx, now) != nil { + return c.commitErrorHandle() + } + + if c.xaResource.XAPrepare(ctx, c.xaBranchXid.String()) != nil { + return c.commitErrorHandle() + } + return nil +} + +func (c *XAConn) commitErrorHandle() error { + var err error + if err = c.tx.Rollback(); err != nil { + err = fmt.Errorf("failed to report XA branch commit-failure xid:%s, err:%w", c.txCtx.XID, err) + } + c.cleanXABranchContext() + return err +} + +func (c *XAConn) ShouldBeHeld() bool { + return c.res.IsShouldBeHeld() || (c.res.GetDbType().String() != "" && c.res.GetDbType() != types.DBTypeUnknown) +} + +func (c *XAConn) checkTimeout(ctx context.Context, now time.Time) error { + if now.Sub(c.branchRegisterTime) > xaConnTimeout { + c.XaRollback(ctx, c.xaBranchXid) + return fmt.Errorf("XA branch timeout error xid:%s", c.txCtx.XID) + } + return nil +} + +func (c *XAConn) Close() error { + c.rollBacked = false + if c.isConnKept && c.ShouldBeHeld() { + return nil + } + c.cleanXABranchContext() + if err := c.Conn.Close(); err != nil { + return err + } + return nil +} + +func (c *XAConn) CloseForce() error { + if err := c.Conn.Close(); err != nil { + return err + } + c.rollBacked = false + c.cleanXABranchContext() + if err := c.Conn.Close(); err != nil { + return err + } + c.releaseIfNecessary() + return nil +} + +func (c *XAConn) XaCommit(ctx context.Context, xid string, branchId int64) error { + xaXid := XaIdBuild(xid, uint64(branchId)) + err := c.xaResource.Commit(ctx, xaXid.String(), false) + c.releaseIfNecessary() + return err +} + +func (c *XAConn) XaRollbackByBranchId(ctx context.Context, xid string, branchId int64) error { + xaXid := XaIdBuild(xid, uint64(branchId)) + return c.XaRollback(ctx, xaXid) +} + +func (c *XAConn) XaRollback(ctx context.Context, xaXid XAXid) error { + err := c.xaResource.Rollback(ctx, xaXid.GetGlobalXid()) + c.releaseIfNecessary() + return err +} diff --git a/pkg/datasource/sql/conn_xa_test.go b/pkg/datasource/sql/conn_xa_test.go index 0f421069..e154238a 100644 --- a/pkg/datasource/sql/conn_xa_test.go +++ b/pkg/datasource/sql/conn_xa_test.go @@ -21,18 +21,59 @@ import ( "context" "database/sql" "database/sql/driver" + "io" "sync/atomic" "testing" + "time" + "github.com/bluele/gcache" "github.com/golang/mock/gomock" "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/seata/seata-go/pkg/datasource/sql/exec" "github.com/seata/seata-go/pkg/datasource/sql/mock" "github.com/seata/seata-go/pkg/datasource/sql/types" + "github.com/seata/seata-go/pkg/protocol/branch" "github.com/seata/seata-go/pkg/tm" - "github.com/stretchr/testify/assert" ) +type mysqlMockRows struct { + idx int + data [][]interface{} +} + +func (m *mysqlMockRows) Columns() []string { + //TODO implement me + panic("implement me") +} + +func (m *mysqlMockRows) Close() error { + //TODO implement me + panic("implement me") +} + +func (m *mysqlMockRows) Next(dest []driver.Value) error { + if m.idx == len(m.data) { + return io.EOF + } + + min := func(a, b int) int { + if a < b { + return a + } + return b + } + + cnt := min(len(m.data[0]), len(dest)) + + for i := 0; i < cnt; i++ { + dest[i] = m.data[m.idx][i] + } + m.idx++ + return nil +} + type mockSQLInterceptor struct { before func(ctx context.Context, execCtx *types.ExecContext) after func(ctx context.Context, execCtx *types.ExecContext) @@ -78,16 +119,27 @@ func (mi *mockTxHook) BeforeRollback(tx *Tx) { } func baseMockConn(mockConn *mock.MockTestDriverConn) { + branchStatusCache = gcache.New(1024).LRU().Expiration(time.Minute * 10).Build() + mockConn.EXPECT().ExecContext(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes().Return(&driver.ResultNoRows, nil) mockConn.EXPECT().Exec(gomock.Any(), gomock.Any()).AnyTimes().Return(&driver.ResultNoRows, nil) mockConn.EXPECT().ResetSession(gomock.Any()).AnyTimes().Return(nil) mockConn.EXPECT().Close().AnyTimes().Return(nil) + + mockConn.EXPECT().QueryContext(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes().DoAndReturn( + func(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { + rows := &mysqlMockRows{} + rows.data = [][]interface{}{ + {"8.0.29"}, + } + return rows, nil + }) } func initXAConnTestResource(t *testing.T) (*gomock.Controller, *sql.DB, *mockSQLInterceptor, *mockTxHook) { ctrl := gomock.NewController(t) - mockMgr := initMockResourceManager(t, ctrl) + mockMgr := initMockResourceManager(branch.BranchTypeXA, ctrl) _ = mockMgr //db, err := sql.Open("seata-xa-mysql", "root:seata_go@tcp(127.0.0.1:3306)/seata_go_test?multiStatements=true") db, err := sql.Open("seata-xa-mysql", "root:12345678@tcp(127.0.0.1:3306)/seata_client?multiStatements=true&interpolateParams=true") diff --git a/pkg/datasource/sql/connector.go b/pkg/datasource/sql/connector.go index 36d861a5..fb614904 100644 --- a/pkg/datasource/sql/connector.go +++ b/pkg/datasource/sql/connector.go @@ -20,11 +20,14 @@ package sql import ( "context" "database/sql/driver" + "errors" + "io" "sync" "github.com/go-sql-driver/mysql" "github.com/seata/seata-go/pkg/datasource/sql/types" + "github.com/seata/seata-go/pkg/util/log" ) type seataATConnector struct { @@ -89,7 +92,6 @@ func (c *seataXAConnector) Driver() driver.Driver { // method will call Close and return error (if any). type seataConnector struct { transType types.TransactionMode - conf *seataServerConfig res *DBResource once sync.Once driver driver.Driver @@ -116,6 +118,15 @@ func (c *seataConnector) Connect(ctx context.Context) (driver.Conn, error) { return nil, err } + // get the version of mysql for xa. + if c.transType == types.XAMode { + version, err := c.dbVersion(ctx, conn) + if err != nil { + return nil, err + } + c.res.SetDbVersion(version) + } + return &Conn{ targetConn: conn, res: c.res, @@ -126,6 +137,40 @@ func (c *seataConnector) Connect(ctx context.Context) (driver.Conn, error) { }, nil } +func (c *seataConnector) dbVersion(ctx context.Context, conn driver.Conn) (string, error) { + queryConn, isQueryContext := conn.(driver.QueryerContext) + if !isQueryContext { + return "", errors.New("get db version error for unexpected driver conn") + } + + res, err := queryConn.QueryContext(ctx, "SELECT VERSION()", nil) + if err != nil { + log.Errorf("seata connector get the xa mysql version err:%v", err) + return "", err + } + + dest := make([]driver.Value, 1) + var version string + for true { + if err = res.Next(dest); err != nil { + if err == io.EOF { + return version, nil + } + return "", err + } + if len(dest) != 1 { + return "", errors.New("get the mysql version is not column 1") + } + + var isVersionOk bool + version, isVersionOk = dest[0].(string) + if !isVersionOk { + return "", errors.New("get the mysql version is not a string") + } + } + return "", errors.New("get the mysql version is error") +} + // Driver returns the underlying Driver of the Connector, // mainly to maintain compatibility with the Driver method // on sql.DB. diff --git a/pkg/datasource/sql/connector_test.go b/pkg/datasource/sql/connector_test.go index 950e8133..0e99513a 100644 --- a/pkg/datasource/sql/connector_test.go +++ b/pkg/datasource/sql/connector_test.go @@ -25,10 +25,12 @@ import ( "testing" "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" + "github.com/seata/seata-go/pkg/datasource/sql/mock" "github.com/seata/seata-go/pkg/datasource/sql/types" + "github.com/seata/seata-go/pkg/protocol/branch" "github.com/seata/seata-go/pkg/util/reflectx" - "github.com/stretchr/testify/assert" ) type initConnectorFunc func(t *testing.T, ctrl *gomock.Controller) driver.Connector @@ -38,6 +40,14 @@ func initMockConnector(t *testing.T, ctrl *gomock.Controller) driver.Connector { connector := mock.NewMockTestDriverConnector(ctrl) connector.EXPECT().Connect(gomock.Any()).AnyTimes().Return(mockConn, nil) + mockConn.EXPECT().QueryContext(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes().DoAndReturn( + func(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { + rows := &mysqlMockRows{} + rows.data = [][]interface{}{ + {"8.0.29"}, + } + return rows, nil + }) return connector } @@ -66,10 +76,10 @@ func Test_seataATConnector_Connect(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - mockMgr := initMockResourceManager(t, ctrl) + mockMgr := initMockResourceManager(branch.BranchTypeAT, ctrl) _ = mockMgr - db, err := sql.Open("seata-at-mysql", "root:seata_go@tcp(127.0.0.1:3306)/seata_go_test?multiStatements=true") + db, err := sql.Open(SeataATMySQLDriver, "root:seata_go@tcp(127.0.0.1:3306)/seata_go_test?multiStatements=true") if err != nil { t.Fatal(err) } @@ -110,10 +120,10 @@ func Test_seataXAConnector_Connect(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - mockMgr := initMockResourceManager(t, ctrl) + mockMgr := initMockResourceManager(branch.BranchTypeXA, ctrl) _ = mockMgr - db, err := sql.Open("seata-xa-mysql", "root:seata_go@tcp(127.0.0.1:3306)/seata_go_test?multiStatements=true") + db, err := sql.Open(SeataXAMySQLDriver, "root:seata_go@tcp(127.0.0.1:3306)/seata_go_test?multiStatements=true") if err != nil { t.Fatal(err) } diff --git a/pkg/datasource/sql/datasource/datasource_manager.go b/pkg/datasource/sql/datasource/datasource_manager.go index 94f12352..b5457298 100644 --- a/pkg/datasource/sql/datasource/datasource_manager.go +++ b/pkg/datasource/sql/datasource/datasource_manager.go @@ -25,7 +25,6 @@ import ( "github.com/seata/seata-go/pkg/datasource/sql/types" "github.com/seata/seata-go/pkg/protocol/branch" - "github.com/seata/seata-go/pkg/protocol/message" "github.com/seata/seata-go/pkg/rm" ) @@ -34,7 +33,7 @@ var ( tableMetaCacheMap = map[types.DBType]TableMetaCache{} ) -// RegisterTableCache +// RegisterTableCache register the table meta cache for at and xa func RegisterTableCache(dbType types.DBType, tableMetaCache TableMetaCache) { tableMetaCacheMap[dbType] = tableMetaCache } @@ -54,11 +53,8 @@ func GetDataSourceManager(branchType branch.BranchType) DataSourceManager { return nil } -// todo implements ResourceManagerOutbound interface -// DataSourceManager type DataSourceManager interface { rm.ResourceManager - // CreateTableMetaCache CreateTableMetaCache(ctx context.Context, resID string, dbType types.DBType, db *sql.DB) (TableMetaCache, error) } @@ -67,9 +63,8 @@ type entry struct { metaCache TableMetaCache } -// BasicSourceManager +// BasicSourceManager the basic source manager for xa and at type BasicSourceManager struct { - // lock lock sync.RWMutex // tableMetaCache // todo do not put meta cache here @@ -82,34 +77,7 @@ func NewBasicSourceManager() *BasicSourceManager { } } -// Commit a branch transaction -// TODO wait finish -func (dm *BasicSourceManager) BranchCommit(ctx context.Context, req message.BranchCommitRequest) (branch.BranchStatus, error) { - return branch.BranchStatusPhaseoneDone, nil -} - -// Rollback a branch transaction -// TODO wait finish -func (dm *BasicSourceManager) BranchRollback(ctx context.Context, req message.BranchRollbackRequest) (branch.BranchStatus, error) { - return branch.BranchStatusPhaseoneFailed, nil -} - -// Branch register long -func (dm *BasicSourceManager) BranchRegister(ctx context.Context, req rm.BranchRegisterParam) (int64, error) { - return 0, nil -} - -// Branch report -func (dm *BasicSourceManager) BranchReport(ctx context.Context, req message.BranchReportRequest) error { - return nil -} - -// Lock query boolean -func (dm *BasicSourceManager) LockQuery(ctx context.Context, branchType branch.BranchType, resourceId, xid, lockKeys string) (bool, error) { - return true, nil -} - -// Register a model.Resource to be managed by model.Resource Manager +// RegisterResource register a model.Resource to be managed by model.Resource Manager func (dm *BasicSourceManager) RegisterResource(resource rm.Resource) error { err := rm.GetRMRemotingInstance().RegisterResource(resource) if err != nil { @@ -118,22 +86,11 @@ func (dm *BasicSourceManager) RegisterResource(resource rm.Resource) error { return nil } -// Unregister a model.Resource from the model.Resource Manager func (dm *BasicSourceManager) UnregisterResource(resource rm.Resource) error { return fmt.Errorf("unsupport unregister resource") } -// Get all resources managed by this manager -func (dm *BasicSourceManager) GetManagedResources() *sync.Map { - return nil -} - -// Get the model.BranchType -func (dm *BasicSourceManager) GetBranchType() branch.BranchType { - return branch.BranchTypeAT -} - -// CreateTableMetaCache +// CreateTableMetaCache create a table meta cache func (dm *BasicSourceManager) CreateTableMetaCache(ctx context.Context, resID string, dbType types.DBType, db *sql.DB) (TableMetaCache, error) { dm.lock.Lock() defer dm.lock.Unlock() @@ -144,28 +101,19 @@ func (dm *BasicSourceManager) CreateTableMetaCache(ctx context.Context, resID st } dm.tableMetaCache[resID] = res - - // 注册 AT 数据资源 - // dm.resourceMgr.RegisterResource(ATResource) - return res.metaCache, err } // TableMetaCache tables metadata cache, default is open type TableMetaCache interface { - // Init Init(ctx context.Context, conn *sql.DB) error - // GetTableMeta GetTableMeta(ctx context.Context, dbName, table string) (*types.TableMeta, error) - // Destroy Destroy() error } // buildResource -// todo not here func buildResource(ctx context.Context, dbType types.DBType, db *sql.DB) (*entry, error) { cache := tableMetaCacheMap[dbType] - if err := cache.Init(ctx, db); err != nil { return nil, err } diff --git a/pkg/datasource/sql/datasource_resource.go b/pkg/datasource/sql/datasource_resource.go deleted file mode 100644 index 5b52d531..00000000 --- a/pkg/datasource/sql/datasource_resource.go +++ /dev/null @@ -1,124 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package sql - -import ( - "fmt" - "time" - - "github.com/bluele/gcache" - - "github.com/seata/seata-go/pkg/datasource/sql/types" - "github.com/seata/seata-go/pkg/protocol/branch" -) - -type Holdable interface { - SetHeld(held bool) - IsHeld() bool - ShouldBeHeld() bool -} - -type BaseDataSourceResource struct { - db *DBResource - shouldBeHeld bool - keeper map[string]interface{} - Cache map[string]branch.BranchStatus -} - -var BranchStatusCache = gcache.New(1024).LRU().Expiration(time.Minute * 10).Build() - -func (b *BaseDataSourceResource) init() error { - return nil -} - -func (b *BaseDataSourceResource) GetDB() *DBResource { - return b.db -} - -func (b *BaseDataSourceResource) SetDB(db *DBResource) { - b.db = db -} - -func (b *BaseDataSourceResource) IsShouldBeHeld() bool { - return b.shouldBeHeld -} - -func (b *BaseDataSourceResource) SetShouldBeHeld(shouldBeHeld bool) { - b.shouldBeHeld = shouldBeHeld -} - -func (b *BaseDataSourceResource) GetKeeper() map[string]interface{} { - return b.keeper -} - -func (b *BaseDataSourceResource) SetKeeper(keeper map[string]interface{}) { - b.keeper = keeper -} - -func (b *BaseDataSourceResource) GetCache() map[string]branch.BranchStatus { - return b.Cache -} - -func (b *BaseDataSourceResource) SetCache(cache map[string]branch.BranchStatus) { - b.Cache = cache -} - -func (b *BaseDataSourceResource) GetResourceId() string { - return b.db.GetResourceId() -} - -func (b *BaseDataSourceResource) Hold(key string, value Holdable) (interface{}, error) { - if value.IsHeld() { - var x = b.keeper[key] - if x != value { - return nil, fmt.Errorf("something wrong with keeper, keeping[%v] but[%v] is also kept with the same key[%v]", x, value, key) - } - return value, nil - } - var x = b.keeper[key] - b.keeper[key] = value - value.SetHeld(true) - return x, nil -} - -func (b *BaseDataSourceResource) Release(key string, value Holdable) (interface{}, error) { - if value.IsHeld() { - var x = b.keeper[key] - if x != value { - return nil, fmt.Errorf("something wrong with keeper, keeping[%v] but[%v] is also kept with the same key[%v]", x, value, key) - } - return value, nil - } - var x = b.keeper[key] - b.keeper[key] = value - value.SetHeld(true) - return x, nil -} - -func (b *BaseDataSourceResource) GetBranchStatus(xaBranchXid string) (interface{}, error) { - branchStatus, err := BranchStatusCache.GetIFPresent(xaBranchXid) - return branchStatus, err -} - -func (b *BaseDataSourceResource) GetDbType() string { - return b.db.dbType.String() -} - -func (b *BaseDataSourceResource) SetDbType(dbType types.DBType) { - b.db.dbType = dbType -} diff --git a/pkg/datasource/sql/db.go b/pkg/datasource/sql/db.go index 6d9c4bf6..6cfaaa77 100644 --- a/pkg/datasource/sql/db.go +++ b/pkg/datasource/sql/db.go @@ -18,19 +18,24 @@ package sql import ( + "context" "database/sql" + "database/sql/driver" + "fmt" + "sync" "github.com/seata/seata-go/pkg/datasource/sql/datasource" "github.com/seata/seata-go/pkg/datasource/sql/types" "github.com/seata/seata-go/pkg/datasource/sql/undo" + "github.com/seata/seata-go/pkg/datasource/sql/util" "github.com/seata/seata-go/pkg/protocol/branch" ) type dbOption func(db *DBResource) -func withGroupID(id string) dbOption { +func withDsn(dsn string) dbOption { return func(db *DBResource) { - db.groupID = id + db.dsn = dsn } } @@ -52,21 +57,33 @@ func withDBType(dt types.DBType) dbOption { } } +func withBranchType(dt branch.BranchType) dbOption { + return func(db *DBResource) { + db.branchType = dt + } +} + func withTarget(source *sql.DB) dbOption { return func(db *DBResource) { db.db = source } } +func withConnector(ci driver.Connector) dbOption { + return func(db *DBResource) { + db.connector = ci + } +} + func withDBName(dbName string) dbOption { return func(db *DBResource) { db.dbName = dbName } } -func withConf(conf *seataServerConfig) dbOption { +func withConf(conf *XAConnConf) dbOption { return func(db *DBResource) { - db.conf = *conf + db.xaConnConf = conf } } @@ -77,47 +94,36 @@ func newResource(opts ...dbOption) (*DBResource, error) { opts[i](db) } - return db, db.init() + db.init() + return db, nil } // DBResource proxy sql.DB, enchance database/sql.DB to add distribute transaction ability type DBResource struct { - // groupID - groupID string - // resourceID + xaConnConf *XAConnConf + // only use by mysql + dbVersion string + dsn string resourceID string - // conf - conf seataServerConfig - // db - db *sql.DB - dbName string - // dbType - dbType types.DBType - // undoLogMgr + db *sql.DB + connector driver.Connector + dbName string + dbType types.DBType undoLogMgr undo.UndoLogManager - // metaCache - metaCache datasource.TableMetaCache -} + branchType branch.BranchType -func (db *DBResource) init() error { - return nil + // for xa + metaCache datasource.TableMetaCache + shouldBeHeld bool + keeper sync.Map } -// todo do not put meta data to rm -//func (db *DBResource) init() error { -// mgr := datasource.GetDataSourceManager(db.GetBranchType()) -// metaCache, err := mgr.CreateTableMetaCache(context.Background(), db.resourceID, db.dbType, db.db) -// if err != nil { -// return err -// } -// -// db.metaCache = metaCache -// -// return nil -//} - func (db *DBResource) GetResourceGroupId() string { - return db.groupID + panic("implement me") +} + +func (db *DBResource) init() { + db.checkDbVersion() } func (db *DBResource) GetResourceId() string { @@ -125,18 +131,106 @@ func (db *DBResource) GetResourceId() string { } func (db *DBResource) GetBranchType() branch.BranchType { - return db.conf.BranchType + return db.branchType +} + +func (db *DBResource) GetDB() *sql.DB { + return db.db } -type SqlDBProxy struct { - db *sql.DB - dbName string +func (db *DBResource) GetDBName() string { + return db.dbName } -func (s *SqlDBProxy) GetDB() *sql.DB { - return s.db +func (db *DBResource) GetDbType() types.DBType { + return db.dbType } -func (s *SqlDBProxy) GetDBName() string { - return s.dbName +func (db *DBResource) SetDbType(dbType types.DBType) { + db.dbType = dbType +} + +func (db *DBResource) SetDbVersion(v string) { + db.dbVersion = v +} + +func (db *DBResource) GetDbVersion() string { + return db.dbVersion +} + +func (db *DBResource) IsShouldBeHeld() bool { + return db.shouldBeHeld +} + +// Hold the xa connection. +func (db *DBResource) Hold(xaBranchID string, v interface{}) error { + _, exist := db.keeper.Load(xaBranchID) + if !exist { + db.keeper.Store(xaBranchID, v) + return nil + } + return nil +} + +func (db *DBResource) Release(xaBranchID string) { + db.keeper.Delete(xaBranchID) +} + +func (db *DBResource) Lookup(xaBranchID string) (interface{}, bool) { + return db.keeper.Load(xaBranchID) +} + +func (db *DBResource) GetKeeper() *sync.Map { + return &db.keeper +} + +func (db *DBResource) ConnectionForXA(ctx context.Context, xaXid XAXid) (*XAConn, error) { + xaBranchXid := xaXid.String() + tmpConn, ok := db.Lookup(xaBranchXid) + if ok && tmpConn != nil { + connectionProxyXa, isConnectionProxyXa := tmpConn.(*XAConn) + if !isConnectionProxyXa { + return nil, fmt.Errorf("get connection proxy xa from cache error, xid:%s", xaXid.String()) + } + return connectionProxyXa, nil + } + + // why here need a new connection? + // 1. because there maybe a rm cluster + // 2. the first phase select a rm1, and store the connection is the keeper + // 3. tc request the second phase. but the rm1 is shutdown, so the tc select another rm (like rm2) + // 4. so when the second phase request coming to rm2, rm2 must not store the connection. + // 5. the rm2 get the second phase do the two thing. + // 1. in mysql version >= 8.0.29, mysql support the xa transaction commit by another connection. so just commit + // 2. when the version < 8.0.29. so just make the transaction rollback + newDriverConn, err := db.connector.Connect(ctx) + if err != nil { + return nil, fmt.Errorf("get xa new connection failure, xid:%s, err:%v", xaXid.String(), err) + } + xaConn := &XAConn{ + Conn: newDriverConn.(*Conn), + } + return xaConn, nil +} + +func (db *DBResource) checkDbVersion() error { + switch db.dbType { + case types.DBTypeMySQL: + currentVersion, err := util.ConvertDbVersion(db.dbVersion) + if err != nil { + return fmt.Errorf("new connection xa proxy convert db version:%s err:%v", db.GetDbVersion(), err) + } + + shouldKeptVersion, err := util.ConvertDbVersion("8.0.29") + if err != nil { + return fmt.Errorf("new connection xa proxy convert db version 8.0.29 err:%v", err) + } + + if currentVersion < shouldKeptVersion { + db.shouldBeHeld = true + } + case types.DBTypeMARIADB: + db.shouldBeHeld = true + } + return nil } diff --git a/pkg/datasource/sql/driver.go b/pkg/datasource/sql/driver.go index c10dbbfb..6c1e3baa 100644 --- a/pkg/datasource/sql/driver.go +++ b/pkg/datasource/sql/driver.go @@ -25,10 +25,10 @@ import ( "fmt" "strings" - mysql2 "github.com/seata/seata-go/pkg/datasource/sql/datasource/mysql" - "github.com/go-sql-driver/mysql" + "github.com/seata/seata-go/pkg/datasource/sql/datasource" + mysql2 "github.com/seata/seata-go/pkg/datasource/sql/datasource/mysql" "github.com/seata/seata-go/pkg/datasource/sql/types" "github.com/seata/seata-go/pkg/protocol/branch" "github.com/seata/seata-go/pkg/util/log" @@ -44,15 +44,17 @@ const ( func initDriver() { sql.Register(SeataATMySQLDriver, &seataATDriver{ seataDriver: &seataDriver{ - transType: types.ATMode, - target: mysql.MySQLDriver{}, + branchType: branch.BranchTypeAT, + transType: types.ATMode, + target: mysql.MySQLDriver{}, }, }) sql.Register(SeataXAMySQLDriver, &seataXADriver{ seataDriver: &seataDriver{ - transType: types.XAMode, - target: mysql.MySQLDriver{}, + branchType: branch.BranchTypeXA, + transType: types.XAMode, + target: mysql.MySQLDriver{}, }, }) } @@ -98,8 +100,9 @@ func (d *seataXADriver) OpenConnector(name string) (c driver.Connector, err erro } type seataDriver struct { - transType types.TransactionMode - target driver.Driver + branchType branch.BranchType + transType types.TransactionMode + target driver.Driver } // Open never be called, because seataDriver implemented dri.DriverContext interface. @@ -124,7 +127,7 @@ func (d *seataDriver) OpenConnector(name string) (c driver.Connector, err error) return nil, fmt.Errorf("unsupport conn type %s", d.getTargetDriverName()) } - proxy, err := getOpenConnectorProxy(c, dbType, sql.OpenDB(c), name) + proxy, err := d.getOpenConnectorProxy(c, dbType, sql.OpenDB(c), name) if err != nil { log.Errorf("register resource: %w", err) return nil, err @@ -133,43 +136,15 @@ func (d *seataDriver) OpenConnector(name string) (c driver.Connector, err error) return proxy, nil } -func (d *seataDriver) getTargetDriverName() string { - return "mysql" -} - -type dsnConnector struct { - dsn string - driver driver.Driver -} - -func (t *dsnConnector) Connect(_ context.Context) (driver.Conn, error) { - return t.driver.Open(t.dsn) -} - -func (t *dsnConnector) Driver() driver.Driver { - return t.driver -} - -func getOpenConnectorProxy(connector driver.Connector, dbType types.DBType, db *sql.DB, - dataSourceName string, opts ...seataOption) (driver.Connector, error) { - conf := loadConfig() - for i := range opts { - opts[i](conf) - } - - if err := conf.validate(); err != nil { - log.Errorf("invalid conf: %w", err) - return nil, err - } - +func (d *seataDriver) getOpenConnectorProxy(connector driver.Connector, dbType types.DBType, + db *sql.DB, dataSourceName string) (driver.Connector, error) { cfg, _ := mysql.ParseDSN(dataSourceName) options := []dbOption{ - withGroupID(conf.GroupID), - withResourceID(parseResourceID(dataSourceName)), - withConf(conf), withTarget(db), + withBranchType(d.branchType), withDBType(dbType), withDBName(cfg.DBName), + withConnector(connector), } res, err := newResource(options...) @@ -179,7 +154,8 @@ func getOpenConnectorProxy(connector driver.Connector, dbType types.DBType, db * } datasource.RegisterTableCache(types.DBTypeMySQL, mysql2.NewTableMetaInstance(db)) - if err = datasource.GetDataSourceManager(conf.BranchType).RegisterResource(res); err != nil { + + if err = datasource.GetDataSourceManager(d.branchType).RegisterResource(res); err != nil { log.Errorf("regisiter resource: %w", err) return nil, err } @@ -187,38 +163,25 @@ func getOpenConnectorProxy(connector driver.Connector, dbType types.DBType, db * return &seataConnector{ res: res, target: connector, - conf: conf, cfg: cfg, }, nil } -type ( - seataOption func(cfg *seataServerConfig) - - // seataServerConfig - seataServerConfig struct { - // GroupID - GroupID string `yaml:"groupID"` - // BranchType - BranchType branch.BranchType - // Endpoints - Endpoints []string `yaml:"endpoints" json:"endpoints"` - } -) +func (d *seataDriver) getTargetDriverName() string { + return "mysql" +} -func (c *seataServerConfig) validate() error { - return nil +type dsnConnector struct { + dsn string + driver driver.Driver } -// loadConfig -func loadConfig() *seataServerConfig { - // set default value first. - // todo read from configuration file. - return &seataServerConfig{ - GroupID: "DEFAULT_GROUP", - BranchType: branch.BranchTypeAT, - Endpoints: []string{"127.0.0.1:8888"}, - } +func (t *dsnConnector) Connect(_ context.Context) (driver.Conn, error) { + return t.driver.Open(t.dsn) +} + +func (t *dsnConnector) Driver() driver.Driver { + return t.driver } func parseResourceID(dsn string) string { diff --git a/pkg/datasource/sql/driver_test.go b/pkg/datasource/sql/driver_test.go index 90ec58e0..e94dfc54 100644 --- a/pkg/datasource/sql/driver_test.go +++ b/pkg/datasource/sql/driver_test.go @@ -28,12 +28,14 @@ import ( "github.com/golang/mock/gomock" "github.com/seata/seata-go/pkg/datasource/sql/mock" + "github.com/seata/seata-go/pkg/protocol/branch" "github.com/seata/seata-go/pkg/util/reflectx" "github.com/stretchr/testify/assert" ) -func initMockResourceManager(t *testing.T, ctrl *gomock.Controller) *mock.MockDataSourceManager { +func initMockResourceManager(branchType branch.BranchType, ctrl *gomock.Controller) *mock.MockDataSourceManager { mockResourceMgr := mock.NewMockDataSourceManager(ctrl) + mockResourceMgr.SetBranchType(branchType) rm.GetRmCacheInstance().RegisterResourceManager(mockResourceMgr) mockResourceMgr.EXPECT().RegisterResource(gomock.Any()).AnyTimes().Return(nil) mockResourceMgr.EXPECT().CreateTableMetaCache(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes().Return(nil, nil) @@ -45,7 +47,7 @@ func Test_seataATDriver_Open(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - mockMgr := initMockResourceManager(t, ctrl) + mockMgr := initMockResourceManager(branch.BranchTypeAT, ctrl) _ = mockMgr db, err := sql.Open("seata-at-mysql", "root:seata_go@tcp(127.0.0.1:3306)/seata_go_test?multiStatements=true") @@ -93,7 +95,7 @@ func Test_seataATDriver_OpenConnector(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - mockMgr := initMockResourceManager(t, ctrl) + mockMgr := initMockResourceManager(branch.BranchTypeAT, ctrl) _ = mockMgr db, err := sql.Open("seata-at-mysql", "root:seata_go@tcp(127.0.0.1:3306)/seata_go_test?multiStatements=true") @@ -119,7 +121,7 @@ func Test_seataXADriver_OpenConnector(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - mockMgr := initMockResourceManager(t, ctrl) + mockMgr := initMockResourceManager(branch.BranchTypeAT, ctrl) _ = mockMgr db, err := sql.Open("seata-xa-mysql", "root:seata_go@tcp(127.0.0.1:3306)/seata_go_test?multiStatements=true") diff --git a/pkg/datasource/sql/exec/at/insert_executor.go b/pkg/datasource/sql/exec/at/insert_executor.go index ab92e4ea..24426e5f 100644 --- a/pkg/datasource/sql/exec/at/insert_executor.go +++ b/pkg/datasource/sql/exec/at/insert_executor.go @@ -514,8 +514,8 @@ func canAutoIncrement(pkMetaMap map[string]types.ColumnMeta) bool { return false } -func (u *insertExecutor) isAstStmtValid() bool { - return u.parserCtx != nil && u.parserCtx.InsertStmt != nil +func (i *insertExecutor) isAstStmtValid() bool { + return i.parserCtx != nil && i.parserCtx.InsertStmt != nil } func (i *insertExecutor) autoGeneratePks(execCtx *types.ExecContext, autoColumnName string, lastInsetId, updateCount int64) (map[string][]interface{}, error) { diff --git a/pkg/datasource/sql/exec/executor.go b/pkg/datasource/sql/exec/executor.go index 0ea477ac..e95892dd 100644 --- a/pkg/datasource/sql/exec/executor.go +++ b/pkg/datasource/sql/exec/executor.go @@ -27,7 +27,6 @@ import ( var ( atExecutors = make(map[types.DBType]func() SQLExecutor) - xaExecutors = make(map[types.DBType]func() SQLExecutor) ) // RegisterATExecutor AT executor @@ -35,13 +34,6 @@ func RegisterATExecutor(dt types.DBType, builder func() SQLExecutor) { atExecutors[dt] = builder } -// RegisterXAExecutor XA executor -func RegisterXAExecutor(dt types.DBType, builder func() SQLExecutor) { - xaExecutors[dt] = func() SQLExecutor { - return builder() - } -} - type ( CallbackWithNamedValue func(ctx context.Context, query string, args []driver.NamedValue) (types.ExecResult, error) @@ -66,12 +58,6 @@ func BuildExecutor(dbType types.DBType, transactionMode types.TransactionMode, q hooks = append(hooks, commonHook...) hooks = append(hooks, hookSolts[parseContext.SQLType]...) - if transactionMode == types.XAMode { - e := xaExecutors[dbType]() - e.Interceptors(hooks) - return e, nil - } - e := atExecutors[dbType]() e.Interceptors(hooks) return e, nil @@ -82,12 +68,10 @@ type BaseExecutor struct { ex SQLExecutor } -// Interceptors func (e *BaseExecutor) Interceptors(interceptors []SQLHook) { e.hooks = interceptors } -// ExecWithNamedValue func (e *BaseExecutor) ExecWithNamedValue(ctx context.Context, execCtx *types.ExecContext, f CallbackWithNamedValue) (types.ExecResult, error) { for i := range e.hooks { e.hooks[i].Before(ctx, execCtx) diff --git a/pkg/datasource/sql/exec/resource_xa.go b/pkg/datasource/sql/exec/resource_xa.go deleted file mode 100644 index 22f9ade4..00000000 --- a/pkg/datasource/sql/exec/resource_xa.go +++ /dev/null @@ -1,93 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package exec - -import ( - "time" -) - -const ( - // TMENDICANT Ends a recovery scan. - TMENDRSCAN = 0x00800000 - - /** - * Disassociates the caller and marks the transaction branch - * rollback-only. - */ - TMFAIL = 0x20000000 - - /** - * Caller is joining existing transaction branch. - */ - TMJOIN = 0x00200000 - - /** - * Use TMNOFLAGS to indicate no flags value is selected. - */ - TMNOFLAGS = 0x00000000 - - /** - * Caller is using one-phase optimization. - */ - TMONEPHASE = 0x40000000 - - /** - * Caller is resuming association with a suspended - * transaction branch. - */ - TMRESUME = 0x08000000 - - /** - * Starts a recovery scan. - */ - TMSTARTRSCAN = 0x01000000 - - /** - * Disassociates caller from a transaction branch. - */ - TMSUCCESS = 0x04000000 - - /** - * Caller is suspending (not ending) its association with - * a transaction branch. - */ - TMSUSPEND = 0x02000000 - - /** - * The transaction branch has been read-only and has been committed. - */ - XA_RDONLY = 0x00000003 - - /** - * The transaction work has been prepared normally. - */ - XA_OK = 0 -) - -type XAResource interface { - Commit(xid string, onePhase bool) error - End(xid string, flags int) error - Forget(xid string) error - GetTransactionTimeout() time.Duration - IsSameRM(resource XAResource) bool - XAPrepare(xid string) error - Recover(flag int) ([]string, error) - Rollback(xid string) error - SetTransactionTimeout(duration time.Duration) bool - Start(xid string, flags int) error -} diff --git a/pkg/datasource/sql/exec/select_for_update_executor.go b/pkg/datasource/sql/exec/select_for_update_executor.go index b79a7376..50b138d6 100644 --- a/pkg/datasource/sql/exec/select_for_update_executor.go +++ b/pkg/datasource/sql/exec/select_for_update_executor.go @@ -30,6 +30,7 @@ import ( "github.com/arana-db/parser/ast" "github.com/arana-db/parser/format" "github.com/arana-db/parser/model" + "github.com/seata/seata-go/pkg/datasource/sql/datasource" "github.com/seata/seata-go/pkg/datasource/sql/types" "github.com/seata/seata-go/pkg/datasource/sql/undo/builder" diff --git a/pkg/datasource/sql/exec/xa/executor_xa.go b/pkg/datasource/sql/exec/xa/executor_xa.go deleted file mode 100644 index 8075a454..00000000 --- a/pkg/datasource/sql/exec/xa/executor_xa.go +++ /dev/null @@ -1,84 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package xa - -import ( - "context" - "database/sql/driver" - - "github.com/seata/seata-go/pkg/datasource/sql/exec" - "github.com/seata/seata-go/pkg/datasource/sql/types" -) - -// XAExecutor The XA transaction manager. -type XAExecutor struct { - hooks []exec.SQLHook - ex exec.SQLExecutor -} - -// Interceptors set xa executor hooks -func (e *XAExecutor) Interceptors(hooks []exec.SQLHook) { - e.hooks = hooks -} - -// ExecWithNamedValue -func (e *XAExecutor) ExecWithNamedValue(ctx context.Context, execCtx *types.ExecContext, f exec.CallbackWithNamedValue) (types.ExecResult, error) { - for _, hook := range e.hooks { - hook.Before(ctx, execCtx) - } - - defer func() { - for _, hook := range e.hooks { - hook.After(ctx, execCtx) - } - }() - - if e.ex != nil { - return e.ex.ExecWithNamedValue(ctx, execCtx, f) - } - - return f(ctx, execCtx.Query, execCtx.NamedValues) -} - -// ExecWithValue -func (e *XAExecutor) ExecWithValue(ctx context.Context, execCtx *types.ExecContext, f exec.CallbackWithNamedValue) (types.ExecResult, error) { - for _, hook := range e.hooks { - hook.Before(ctx, execCtx) - } - - defer func() { - for _, hook := range e.hooks { - hook.After(ctx, execCtx) - } - }() - - if e.ex != nil { - return e.ex.ExecWithValue(ctx, execCtx, f) - } - - nvargs := make([]driver.NamedValue, len(execCtx.Values)) - for i, value := range execCtx.Values { - nvargs = append(nvargs, driver.NamedValue{ - Value: value, - Ordinal: i, - }) - } - execCtx.NamedValues = nvargs - - return f(ctx, execCtx.Query, execCtx.NamedValues) -} diff --git a/pkg/datasource/sql/mock/mock_datasource_manager.go b/pkg/datasource/sql/mock/mock_datasource_manager.go index 6b1f9209..cd788a24 100644 --- a/pkg/datasource/sql/mock/mock_datasource_manager.go +++ b/pkg/datasource/sql/mock/mock_datasource_manager.go @@ -38,6 +38,7 @@ import ( type MockDataSourceManager struct { ctrl *gomock.Controller recorder *MockDataSourceManagerMockRecorder + branchType branch.BranchType } // MockDataSourceManagerMockRecorder is the mock recorder for MockDataSourceManager. @@ -131,9 +132,13 @@ func (mr *MockDataSourceManagerMockRecorder) CreateTableMetaCache(ctx, resID, db return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateTableMetaCache", reflect.TypeOf((*MockDataSourceManager)(nil).CreateTableMetaCache), ctx, resID, dbType, db) } +func (m *MockDataSourceManager) SetBranchType(branchType branch.BranchType) { + m.branchType = branchType +} + // GetBranchType mocks base method. func (m *MockDataSourceManager) GetBranchType() branch.BranchType { - return branch.BranchTypeAT + return m.branchType } // GetBranchType indicates an expected call of GetBranchType. diff --git a/pkg/datasource/sql/plugin.go b/pkg/datasource/sql/plugin.go index 4c7ffc59..8cab6ade 100644 --- a/pkg/datasource/sql/plugin.go +++ b/pkg/datasource/sql/plugin.go @@ -20,7 +20,6 @@ package sql import ( "github.com/seata/seata-go/pkg/datasource/sql/exec" "github.com/seata/seata-go/pkg/datasource/sql/exec/at" - "github.com/seata/seata-go/pkg/datasource/sql/exec/xa" "github.com/seata/seata-go/pkg/datasource/sql/hook" "github.com/seata/seata-go/pkg/datasource/sql/types" "github.com/seata/seata-go/pkg/datasource/sql/undo" @@ -42,7 +41,6 @@ func hookRegister() { func executorRegister() { at.Init() - xa.Init() } func undoInit() { diff --git a/pkg/datasource/sql/root_context.go b/pkg/datasource/sql/root_context.go deleted file mode 100644 index 993d317d..00000000 --- a/pkg/datasource/sql/root_context.go +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package sql - -import ( - "github.com/seata/seata-go/pkg/protocol/branch" -) - -type RootContext interface { - RootContext() - SetDefaultBranchType(branchType branch.BranchType) - GetXID() string - Bind(xid string) - GetTimeout() (int, bool) - SetTimeout(timeout int) -} diff --git a/pkg/datasource/sql/tx.go b/pkg/datasource/sql/tx.go index 8923899f..405c03c1 100644 --- a/pkg/datasource/sql/tx.go +++ b/pkg/datasource/sql/tx.go @@ -22,17 +22,16 @@ import ( "database/sql/driver" "fmt" "sync" + "time" "github.com/seata/seata-go/pkg/datasource/sql/datasource" + "github.com/seata/seata-go/pkg/datasource/sql/types" "github.com/seata/seata-go/pkg/protocol/branch" "github.com/seata/seata-go/pkg/rm" + "github.com/seata/seata-go/pkg/util/backoff" "github.com/seata/seata-go/pkg/util/log" - - "github.com/seata/seata-go/pkg/datasource/sql/types" ) -const REPORT_RETRY_COUNT = 5 - var ( hl sync.RWMutex txHooks []txHook @@ -146,19 +145,31 @@ func (tx *Tx) commitOnLocal() error { // register func (tx *Tx) register(ctx *types.TransactionContext) error { - if !ctx.HasUndoLog() || !ctx.HasLockKey() { + if ctx.TransactionMode.BranchType() == branch.BranchTypeUnknow { return nil } - lockKey := "" - for k, _ := range ctx.LockKeys { - lockKey += k + ";" + + if ctx.TransactionMode.BranchType() == branch.BranchTypeAT && !ctx.HasUndoLog() || !ctx.HasLockKey() { + return nil } + request := rm.BranchRegisterParam{ Xid: ctx.XID, BranchType: ctx.TransactionMode.BranchType(), ResourceId: ctx.ResourceID, - LockKeys: lockKey, } + + var lockKey string + if ctx.TransactionMode == types.ATMode { + if !ctx.HasUndoLog() || !ctx.HasLockKey() { + return nil + } + for k, _ := range ctx.LockKeys { + lockKey += k + ";" + } + request.LockKeys = lockKey + } + dataSourceManager := datasource.GetDataSourceManager(ctx.TransactionMode.BranchType()) branchId, err := dataSourceManager.BranchRegister(context.Background(), request) if err != nil { @@ -184,21 +195,22 @@ func (tx *Tx) report(success bool) error { if dataSourceManager == nil { return fmt.Errorf("get dataSourceManager failed") } - retry := REPORT_RETRY_COUNT - for retry > 0 { - err := dataSourceManager.BranchReport(context.Background(), request) - if err != nil { - retry-- - log.Infof("Failed to report [%s / %s] commit done [%s] Retry Countdown: %s", tx.tranCtx.BranchID, tx.tranCtx.XID, success, retry) - if retry == 0 { - log.Errorf("Failed to report branch status: %s", err.Error()) - return err - } - } else { - return nil + + retry := backoff.New(context.Background(), backoff.Config{ + MinBackoff: 100 * time.Millisecond, + MaxBackoff: 200 * time.Millisecond, + MaxRetries: 5, + }) + + var err error + for retry.Ongoing() { + if err = dataSourceManager.BranchReport(context.Background(), request); err == nil { + break } + log.Infof("Failed to report [%s / %s] commit done [%s] Retry Countdown: %s", tx.tranCtx.BranchID, tx.tranCtx.XID, success, retry) + retry.Wait() } - return nil + return err } func getStatus(success bool) branch.BranchStatus { diff --git a/pkg/datasource/sql/tx_at.go b/pkg/datasource/sql/tx_at.go index deca88b1..1cae0f77 100644 --- a/pkg/datasource/sql/tx_at.go +++ b/pkg/datasource/sql/tx_at.go @@ -19,6 +19,7 @@ package sql import ( "github.com/pkg/errors" + "github.com/seata/seata-go/pkg/datasource/sql/undo" ) diff --git a/pkg/datasource/sql/tx_xa.go b/pkg/datasource/sql/tx_xa.go index 5e6dea17..cdc87f8c 100644 --- a/pkg/datasource/sql/tx_xa.go +++ b/pkg/datasource/sql/tx_xa.go @@ -17,7 +17,6 @@ package sql -// XATx type XATx struct { tx *Tx } @@ -32,20 +31,14 @@ func (tx *XATx) Commit() error { } func (tx *XATx) Rollback() error { - err := tx.tx.Rollback() - if err != nil { - - originTx := tx.tx - - if originTx.tranCtx.OpenGlobalTransaction() && originTx.tranCtx.IsBranchRegistered() { - originTx.report(false) - } + originTx := tx.tx + if originTx.tranCtx.OpenGlobalTransaction() && originTx.tranCtx.IsBranchRegistered() { + return originTx.report(false) } - - return err + return nil } -// commitOnXA +// commitOnXA commit xa and register branch transaction func (tx *XATx) commitOnXA() error { return nil } diff --git a/pkg/datasource/sql/types/types.go b/pkg/datasource/sql/types/types.go index bfe98e15..cfe68615 100644 --- a/pkg/datasource/sql/types/types.go +++ b/pkg/datasource/sql/types/types.go @@ -22,9 +22,9 @@ import ( "fmt" "strings" - "github.com/seata/seata-go/pkg/protocol/branch" - "github.com/google/uuid" + + "github.com/seata/seata-go/pkg/protocol/branch" ) type DBType int16 @@ -76,6 +76,7 @@ const ( DBTypePostgreSQL DBTypeSQLServer DBTypeOracle + DBTypeMARIADB BranchPhase_Unknown = 0 BranchPhase_Done = 1 diff --git a/pkg/datasource/sql/util/convert.go b/pkg/datasource/sql/util/convert.go index bfcf82f2..fb4a9810 100644 --- a/pkg/datasource/sql/util/convert.go +++ b/pkg/datasource/sql/util/convert.go @@ -28,8 +28,10 @@ import ( "database/sql/driver" "errors" "fmt" + "math" "reflect" "strconv" + "strings" "time" ) @@ -375,3 +377,30 @@ type decimalCompose interface { // represented then an error should be returned. Compose(form byte, negative bool, coefficient []byte, exponent int32) error } + +// ConvertDbVersion convert a string db version to a number. +func ConvertDbVersion(version string) (int, error) { + parts := strings.Split(version, ".") + size := len(parts) + maxVersionDot := 3 + if size > maxVersionDot+1 { + return 0, fmt.Errorf("incompatible version format: %s", version) + } + + var res int + for idx, part := range parts { + if partInt, err := strconv.Atoi(part); err == nil { + res += calculatePartValue(partInt, size, idx) + } else { + subParts := strings.Split(part, "-") + if subPartInt, err := strconv.Atoi(subParts[0]); err == nil { + res += calculatePartValue(subPartInt, size, idx) + } + } + } + return res, nil +} + +func calculatePartValue(partNumeric, size, index int) int { + return partNumeric * int(math.Pow(100, float64(size-index))) +} diff --git a/pkg/datasource/sql/conn_test.go b/pkg/datasource/sql/util/convert_test.go similarity index 67% rename from pkg/datasource/sql/conn_test.go rename to pkg/datasource/sql/util/convert_test.go index 4484b028..86a7bb2c 100644 --- a/pkg/datasource/sql/conn_test.go +++ b/pkg/datasource/sql/util/convert_test.go @@ -15,22 +15,27 @@ * limitations under the License. */ -package sql +package util import ( "testing" - "github.com/seata/seata-go/pkg/datasource/sql/exec" - "github.com/seata/seata-go/pkg/datasource/sql/exec/xa" - "github.com/seata/seata-go/pkg/datasource/sql/types" "github.com/stretchr/testify/assert" ) -func TestConn_BuildXAExecutor(t *testing.T) { - executor, err := exec.BuildExecutor(types.DBTypeMySQL, types.XAMode, "SELECT * FROM user") +func TestConvertDbVersion(t *testing.T) { + version1 := "3.1.2" + v1Int, err1 := ConvertDbVersion(version1) + assert.NoError(t, err1) - assert.NoError(t, err) + version2 := "3.1.3" + v2Int, err2 := ConvertDbVersion(version2) + assert.NoError(t, err2) - _, ok := executor.(*xa.XAExecutor) - assert.True(t, ok, "need xa executor") + assert.Less(t, v1Int, v2Int) + + version3 := "3.1.3" + v3Int, err3 := ConvertDbVersion(version3) + assert.NoError(t, err3) + assert.Equal(t, v2Int, v3Int) } diff --git a/pkg/datasource/sql/exec/mysql_xa_resource.go b/pkg/datasource/sql/xa/mysql_xa_connection.go similarity index 71% rename from pkg/datasource/sql/exec/mysql_xa_resource.go rename to pkg/datasource/sql/xa/mysql_xa_connection.go index 5bb66e5a..bb538e15 100644 --- a/pkg/datasource/sql/exec/mysql_xa_resource.go +++ b/pkg/datasource/sql/xa/mysql_xa_connection.go @@ -15,24 +15,27 @@ * limitations under the License. */ -package exec +package xa import ( "context" "database/sql/driver" + "errors" "fmt" "io" "strings" "time" - - "github.com/pkg/errors" ) type MysqlXAConn struct { driver.Conn } -func (c *MysqlXAConn) Commit(xid string, onePhase bool) error { +func NewMysqlXaConn(conn driver.Conn) *MysqlXAConn { + return &MysqlXAConn{Conn: conn} +} + +func (c *MysqlXAConn) Commit(ctx context.Context, xid string, onePhase bool) error { var sb strings.Builder sb.WriteString("XA COMMIT ") sb.WriteString(xid) @@ -41,33 +44,33 @@ func (c *MysqlXAConn) Commit(xid string, onePhase bool) error { } conn, _ := c.Conn.(driver.ExecerContext) - _, err := conn.ExecContext(context.TODO(), sb.String(), nil) + _, err := conn.ExecContext(ctx, sb.String(), nil) return err } -func (c *MysqlXAConn) End(xid string, flags int) error { +func (c *MysqlXAConn) End(ctx context.Context, xid string, flags int) error { var sb strings.Builder sb.WriteString("XA END ") sb.WriteString(xid) switch flags { - case TMSUCCESS: + case TMSuccess: break - case TMSUSPEND: + case TMSuspend: sb.WriteString(" SUSPEND") break - case TMFAIL: + case TMFail: break default: return errors.New("invalid arguments") } conn, _ := c.Conn.(driver.ExecerContext) - _, err := conn.ExecContext(context.TODO(), sb.String(), nil) + _, err := conn.ExecContext(ctx, sb.String(), nil) return err } -func (c *MysqlXAConn) Forget(xid string) error { +func (c *MysqlXAConn) Forget(ctx context.Context, xid string) error { // mysql doesn't support this return errors.New("mysql doesn't support this") } @@ -78,29 +81,29 @@ func (c *MysqlXAConn) GetTransactionTimeout() time.Duration { // IsSameRM is called to determine if the resource manager instance represented by the target object // is the same as the resource manager instance represented by the parameter xares. -func (c *MysqlXAConn) IsSameRM(xares XAResource) bool { +func (c *MysqlXAConn) IsSameRM(ctx context.Context, xares XAResource) bool { // todo: the fn depends on the driver.Conn, but it doesn't support return false } -func (c *MysqlXAConn) XAPrepare(xid string) error { +func (c *MysqlXAConn) XAPrepare(ctx context.Context, xid string) error { var sb strings.Builder sb.WriteString("XA PREPARE ") sb.WriteString(xid) conn, _ := c.Conn.(driver.ExecerContext) - _, err := conn.ExecContext(context.TODO(), sb.String(), nil) + _, err := conn.ExecContext(ctx, sb.String(), nil) return err } // Recover Obtains a list of prepared transaction branches from a resource manager. // The transaction manager calls this method during recovery to obtain the list of transaction branches // that are currently in prepared or heuristically completed states. -func (c *MysqlXAConn) Recover(flag int) (xids []string, err error) { - startRscan := (flag & TMSTARTRSCAN) > 0 - endRscan := (flag & TMENDRSCAN) > 0 +func (c *MysqlXAConn) Recover(ctx context.Context, flag int) (xids []string, err error) { + startRscan := (flag & TMStartRScan) > 0 + endRscan := (flag & TMEndRScan) > 0 - if !startRscan && !endRscan && flag != TMNOFLAGS { + if !startRscan && !endRscan && flag != TMNoFlags { return nil, errors.New("invalid arguments") } @@ -109,7 +112,7 @@ func (c *MysqlXAConn) Recover(flag int) (xids []string, err error) { } conn := c.Conn.(driver.QueryerContext) - res, err := conn.QueryContext(context.TODO(), "XA RECOVER", nil) + res, err := conn.QueryContext(ctx, "XA RECOVER", nil) if err != nil { return nil, err } @@ -133,13 +136,13 @@ func (c *MysqlXAConn) Recover(flag int) (xids []string, err error) { return xids, err } -func (c *MysqlXAConn) Rollback(xid string) error { +func (c *MysqlXAConn) Rollback(ctx context.Context, xid string) error { var sb strings.Builder sb.WriteString("XA ROLLBACK ") sb.WriteString(xid) conn, _ := c.Conn.(driver.ExecerContext) - _, err := conn.ExecContext(context.TODO(), sb.String(), nil) + _, err := conn.ExecContext(ctx, sb.String(), nil) return err } @@ -147,25 +150,25 @@ func (c *MysqlXAConn) SetTransactionTimeout(duration time.Duration) bool { return false } -func (c *MysqlXAConn) Start(xid string, flags int) error { +func (c *MysqlXAConn) Start(ctx context.Context, xid string, flags int) error { var sb strings.Builder sb.WriteString("XA START") sb.WriteString(xid) switch flags { - case TMJOIN: + case TMJoin: sb.WriteString(" JOIN") break - case TMRESUME: + case TMResume: sb.WriteString(" RESUME") break - case TMNOFLAGS: + case TMNoFlags: break default: return errors.New("invalid arguments") } conn, _ := c.Conn.(driver.ExecerContext) - _, err := conn.ExecContext(context.TODO(), sb.String(), nil) + _, err := conn.ExecContext(ctx, sb.String(), nil) return err } diff --git a/pkg/datasource/sql/exec/mysql_xa_resource_test.go b/pkg/datasource/sql/xa/mysql_xa_connection_test.go similarity index 90% rename from pkg/datasource/sql/exec/mysql_xa_resource_test.go rename to pkg/datasource/sql/xa/mysql_xa_connection_test.go index 9f9d644c..08a370f1 100644 --- a/pkg/datasource/sql/exec/mysql_xa_resource_test.go +++ b/pkg/datasource/sql/xa/mysql_xa_connection_test.go @@ -15,19 +15,18 @@ * limitations under the License. */ -package exec +package xa import ( "context" "database/sql/driver" - "fmt" + "errors" "io" "reflect" "strings" "testing" "github.com/golang/mock/gomock" - "github.com/pkg/errors" "github.com/seata/seata-go/pkg/datasource/sql/mock" ) @@ -78,7 +77,7 @@ func TestMysqlXAConn_Commit(t *testing.T) { c := &MysqlXAConn{ Conn: mockConn, } - if err := c.Commit(tt.input.xid, tt.input.onePhase); (err != nil) != tt.wantErr { + if err := c.Commit(context.Background(), tt.input.xid, tt.input.onePhase); (err != nil) != tt.wantErr { t.Errorf("Commit() error = %v, wantErr %v", err, tt.wantErr) } }) @@ -102,7 +101,7 @@ func TestMysqlXAConn_End(t *testing.T) { name: "tm success", input: args{ xid: "xid", - flags: TMSUCCESS, + flags: TMSuccess, }, wantErr: false, }, @@ -110,7 +109,7 @@ func TestMysqlXAConn_End(t *testing.T) { name: "tm failed", input: args{ xid: "xid", - flags: TMFAIL, + flags: TMFail, }, wantErr: false, }, @@ -124,7 +123,7 @@ func TestMysqlXAConn_End(t *testing.T) { c := &MysqlXAConn{ Conn: mockConn, } - if err := c.End(tt.input.xid, tt.input.flags); (err != nil) != tt.wantErr { + if err := c.End(context.Background(), tt.input.xid, tt.input.flags); (err != nil) != tt.wantErr { t.Errorf("End() error = %v, wantErr %v", err, tt.wantErr) } }) @@ -148,7 +147,7 @@ func TestMysqlXAConn_Start(t *testing.T) { name: "normal start", input: args{ xid: "xid", - flags: TMNOFLAGS, + flags: TMNoFlags, }, wantErr: false, }, @@ -161,7 +160,7 @@ func TestMysqlXAConn_Start(t *testing.T) { c := &MysqlXAConn{ Conn: mockConn, } - if err := c.Start(tt.input.xid, tt.input.flags); (err != nil) != tt.wantErr { + if err := c.Start(context.Background(), tt.input.xid, tt.input.flags); (err != nil) != tt.wantErr { t.Errorf("Start() error = %v, wantErr %v", err, tt.wantErr) } }) @@ -196,7 +195,7 @@ func TestMysqlXAConn_XAPrepare(t *testing.T) { c := &MysqlXAConn{ Conn: mockConn, } - if err := c.XAPrepare(tt.input.xid); (err != nil) != tt.wantErr { + if err := c.XAPrepare(context.Background(), tt.input.xid); (err != nil) != tt.wantErr { t.Errorf("XAPrepare() error = %v, wantErr %v", err, tt.wantErr) } }) @@ -219,7 +218,7 @@ func TestMysqlXAConn_Recover(t *testing.T) { { name: "normal recover", args: args{ - flag: TMSTARTRSCAN | TMENDRSCAN, + flag: TMStartRScan | TMEndRScan, }, want: []string{"xid", "another_xid"}, wantErr: false, @@ -227,14 +226,14 @@ func TestMysqlXAConn_Recover(t *testing.T) { { name: "invalid flag for recover", args: args{ - flag: TMFAIL, + flag: TMFail, }, wantErr: true, }, { name: "valid flag for recover but don't scan", args: args{ - flag: TMENDRSCAN, + flag: TMEndRScan, }, want: nil, wantErr: false, @@ -257,7 +256,7 @@ func TestMysqlXAConn_Recover(t *testing.T) { c := &MysqlXAConn{ Conn: mockConn, } - got, err := c.Recover(tt.args.flag) + got, err := c.Recover(context.Background(), tt.args.flag) if (err != nil) != tt.wantErr { t.Errorf("Recover() error = %v, wantErr %v", err, tt.wantErr) return @@ -295,10 +294,7 @@ func (m *mysqlMockRows) Next(dest []driver.Value) error { } return b } - cnt := min(len(m.data[0]), len(dest)) - fmt.Printf("cnt: %d", cnt) - for i := 0; i < cnt; i++ { dest[i] = m.data[m.idx][i] } diff --git a/pkg/datasource/sql/conn/oracle.go b/pkg/datasource/sql/xa/oracle_xa_connection.go similarity index 99% rename from pkg/datasource/sql/conn/oracle.go rename to pkg/datasource/sql/xa/oracle_xa_connection.go index 78820965..df03920f 100644 --- a/pkg/datasource/sql/conn/oracle.go +++ b/pkg/datasource/sql/xa/oracle_xa_connection.go @@ -15,7 +15,7 @@ * limitations under the License. */ -package conn +package xa import ( "context" diff --git a/pkg/datasource/sql/conn/oracle_test.go b/pkg/datasource/sql/xa/oracle_xa_connection_test.go similarity index 99% rename from pkg/datasource/sql/conn/oracle_test.go rename to pkg/datasource/sql/xa/oracle_xa_connection_test.go index 8daeb585..6b8471e7 100644 --- a/pkg/datasource/sql/conn/oracle_test.go +++ b/pkg/datasource/sql/xa/oracle_xa_connection_test.go @@ -15,7 +15,7 @@ * limitations under the License. */ -package conn +package xa import ( "database/sql/driver" diff --git a/pkg/datasource/sql/xa/xa_connection.go b/pkg/datasource/sql/xa/xa_connection.go deleted file mode 100644 index 0b07a879..00000000 --- a/pkg/datasource/sql/xa/xa_connection.go +++ /dev/null @@ -1,26 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package xa - -import ( - "github.com/seata/seata-go/pkg/datasource/sql/exec" -) - -type XAConnection interface { - getXAResource() (exec.XAResource, error) -} diff --git a/pkg/datasource/sql/xa/xa_connection_proxy.go b/pkg/datasource/sql/xa/xa_connection_proxy.go deleted file mode 100644 index e447cc8c..00000000 --- a/pkg/datasource/sql/xa/xa_connection_proxy.go +++ /dev/null @@ -1,325 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package xa - -import ( - "context" - "fmt" - "time" - - "github.com/seata/seata-go/pkg/datasource/sql" - "github.com/seata/seata-go/pkg/datasource/sql/datasource" - "github.com/seata/seata-go/pkg/datasource/sql/exec" - "github.com/seata/seata-go/pkg/protocol/branch" - "github.com/seata/seata-go/pkg/protocol/message" - "github.com/seata/seata-go/pkg/rm" -) - -type ConnectionProxyXA struct { - xaBranchXid *XABranchXid - currentAutoCommitStatus bool `default:"true"` - xaActive bool `default:"false"` - kept bool `default:"false"` - rollBacked bool `default:"false"` - branchRegisterTime int64 `default:"0"` - prepareTime int64 `default:"0"` - timeout int `default:"0"` - proxyShouldBeHeld bool `default:"false"` - originalConnection sql.Conn - xaConnection XAConnection - xaResource exec.XAResource - resource sql.BaseDataSourceResource - xid string -} - -const timeout int = 60000 - -func NewConnectionProxyXA(originalConnection sql.Conn, xaConnection XAConnection, resource sql.BaseDataSourceResource, xid string) (*ConnectionProxyXA, error) { - connectionProxyXA := &ConnectionProxyXA{} - - connectionProxyXA.originalConnection = originalConnection - connectionProxyXA.xaConnection = xaConnection - connectionProxyXA.resource = resource - connectionProxyXA.xid = xid - - connectionProxyXA.proxyShouldBeHeld = connectionProxyXA.resource.IsShouldBeHeld() - - xaResource, err := xaConnection.getXAResource() - if err != nil { - return nil, fmt.Errorf("get xa resource failed") - } else { - connectionProxyXA.xaResource = xaResource - } - var rootContext sql.RootContext - transactionTimeout, ok := rootContext.GetTimeout() - if !ok { - transactionTimeout = timeout - } - if transactionTimeout < timeout { - transactionTimeout = timeout - } - connectionProxyXA.timeout = transactionTimeout - connectionProxyXA.currentAutoCommitStatus = connectionProxyXA.originalConnection.GetAutoCommit() - if !connectionProxyXA.currentAutoCommitStatus { - return nil, fmt.Errorf("connection[autocommit=false] as default is NOT supported") - } - - return connectionProxyXA, nil -} - -func (c *ConnectionProxyXA) keepIfNecessary() { - if c.ShouldBeHeld() { - c.resource.Hold(c.xaBranchXid.String(), c) - } -} - -func (c *ConnectionProxyXA) releaseIfNecessary() { - if c.ShouldBeHeld() { - if c.xaBranchXid == nil { - if c.IsHeld() { - c.resource.Release(c.xaBranchXid.String(), c) - } - } - } -} - -func (c *ConnectionProxyXA) XaCommit(xid string, branchId int64) error { - xaXid := Build(xid, branchId) - err := c.xaResource.Commit(xaXid.String(), false) - c.releaseIfNecessary() - return err -} - -func (c *ConnectionProxyXA) XaRollbackByBranchId(xid string, branchId int64) { - xaXid := Build(xid, branchId) - c.XaRollback(xaXid) -} - -func (c *ConnectionProxyXA) XaRollback(xaXid XAXid) error { - err := c.xaResource.Rollback(xaXid.GetGlobalXid()) - c.releaseIfNecessary() - return err -} - -func (c *ConnectionProxyXA) SetAutoCommit(autoCommit bool) error { - if c.currentAutoCommitStatus == autoCommit { - return nil - } - if autoCommit { - if c.xaActive { - _ = c.Commit() - } - } else { - if c.xaActive { - return fmt.Errorf("should NEVER happen: setAutoCommit from true to false while xa branch is active") - } - - c.branchRegisterTime = time.Now().UnixMilli() - var branchRegisterParam rm.BranchRegisterParam - branchRegisterParam.BranchType = branch.BranchTypeXA - branchRegisterParam.ResourceId = c.resource.GetResourceId() - branchRegisterParam.Xid = c.xid - branchId, err := datasource.GetDataSourceManager(branch.BranchTypeXA).BranchRegister(context.TODO(), branchRegisterParam) - if err != nil { - c.cleanXABranchContext() - return fmt.Errorf("failed to register xa branch [%v]", c.xid) - } - c.xaBranchXid = Build(c.xid, branchId) - c.keepIfNecessary() - err = c.start() - if err != nil { - c.cleanXABranchContext() - return fmt.Errorf("failed to start xa branch [%v]", c.xid) - } - c.xaActive = true - } - c.currentAutoCommitStatus = autoCommit - return nil -} - -func (c *ConnectionProxyXA) GetAutoCommit() bool { - return c.currentAutoCommitStatus -} - -func (c *ConnectionProxyXA) Commit() error { - if c.currentAutoCommitStatus { - return nil - } - if !c.xaActive || c.xaBranchXid == nil { - return fmt.Errorf("should NOT commit on an inactive session") - } - now := time.Now().UnixMilli() - if c.end(exec.TMSUCCESS) != nil { - return c.commitErrorHandle() - } - if c.checkTimeout(now) != nil { - return c.commitErrorHandle() - } - if c.xaResource.XAPrepare(c.xaBranchXid.String()) != nil { - return c.commitErrorHandle() - } - return nil -} - -func (c *ConnectionProxyXA) commitErrorHandle() error { - req := message.BranchReportRequest{ - BranchType: branch.BranchTypeXA, - Xid: c.xid, - BranchId: c.xaBranchXid.GetBranchId(), - Status: branch.BranchStatusPhaseoneFailed, - ApplicationData: nil, - ResourceId: c.resource.GetResourceId(), - } - if datasource.NewBasicSourceManager().BranchReport(context.TODO(), req) != nil { - c.cleanXABranchContext() - return fmt.Errorf("Failed to report XA branch commit-failure on [%v] - [%v]", c.xid, c.xaBranchXid.GetBranchId()) - } - c.cleanXABranchContext() - return fmt.Errorf("Failed to end(TMSUCCESS)/prepare xa branch on [%v] - [%v]", c.xid, c.xaBranchXid.GetBranchId()) -} - -func (c *ConnectionProxyXA) Rollback() error { - if c.currentAutoCommitStatus { - return nil - } - if !c.xaActive || c.xaBranchXid == nil { - return fmt.Errorf("should NOT rollback on an inactive session") - } - if !c.rollBacked { - if c.xaResource.End(c.xaBranchXid.String(), exec.TMFAIL) != nil { - return c.rollbackErrorHandle() - } - if c.XaRollback(c.xaBranchXid) != nil { - c.cleanXABranchContext() - return c.rollbackErrorHandle() - } - req := message.BranchReportRequest{ - BranchType: branch.BranchTypeXA, - Xid: c.xid, - BranchId: c.xaBranchXid.GetBranchId(), - Status: branch.BranchStatusPhaseoneFailed, - ApplicationData: nil, - ResourceId: c.resource.GetResourceId(), - } - if datasource.NewBasicSourceManager().BranchReport(context.TODO(), req) != nil { - c.cleanXABranchContext() - return fmt.Errorf("failed to report XA branch commit-failure on [%v] - [%v]", c.xid, c.xaBranchXid.GetBranchId()) - } - } - c.cleanXABranchContext() - return nil -} - -func (c *ConnectionProxyXA) rollbackErrorHandle() error { - return fmt.Errorf("failed to end(TMFAIL) xa branch on [%v] - [%v]", c.xid, c.xaBranchXid.GetBranchId()) -} - -func (c *ConnectionProxyXA) start() error { - err := c.xaResource.Start(c.xaBranchXid.String(), exec.TMNOFLAGS) - if err := c.termination(c.xaBranchXid.String()); err != nil { - c.xaResource.End(c.xaBranchXid.String(), exec.TMFAIL) - c.XaRollback(c.xaBranchXid) - return err - } - return err -} - -func (c *ConnectionProxyXA) end(flags int) error { - err := c.termination(c.xaBranchXid.String()) - if err != nil { - return err - } - err = c.xaResource.End(c.xaBranchXid.String(), flags) - if err != nil { - return err - } - return nil -} - -func (c *ConnectionProxyXA) cleanXABranchContext() { - c.branchRegisterTime = 0 - c.prepareTime = 0 - c.timeout = 0 - c.xaActive = false - if !c.IsHeld() { - c.xaBranchXid = nil - } -} - -func (c *ConnectionProxyXA) checkTimeout(now int64) error { - if now-c.branchRegisterTime > int64(c.timeout) { - c.XaRollback(c.xaBranchXid) - return fmt.Errorf("XA branch timeout error") - } - return nil -} - -func (c *ConnectionProxyXA) Close() error { - c.rollBacked = false - if c.IsHeld() && c.ShouldBeHeld() { - return nil - } - c.cleanXABranchContext() - if err := c.originalConnection.Close(); err != nil { - return err - } - return nil -} - -func (c *ConnectionProxyXA) CloseForce() error { - physicalConn := c.originalConnection - if err := physicalConn.Close(); err != nil { - return err - } - c.rollBacked = false - c.cleanXABranchContext() - if err := c.originalConnection.Close(); err != nil { - return err - } - c.releaseIfNecessary() - return nil -} - -func (c *ConnectionProxyXA) SetHeld(kept bool) { - c.kept = kept -} - -func (c *ConnectionProxyXA) IsHeld() bool { - return c.kept -} - -func (c *ConnectionProxyXA) ShouldBeHeld() bool { - return c.proxyShouldBeHeld || c.resource.GetDB() != nil -} - -func (c *ConnectionProxyXA) GetPrepareTime() int64 { - return c.prepareTime -} - -func (c *ConnectionProxyXA) setPrepareTime(prepareTime int64) { - c.prepareTime = prepareTime -} - -func (c *ConnectionProxyXA) termination(xaBranchXid string) error { - branchStatus, err := c.resource.GetBranchStatus(xaBranchXid) - if err != nil { - c.releaseIfNecessary() - return fmt.Errorf("failed xa branch [%v] the global transaction has finish, branch status: [%v]", c.xid, branchStatus) - } - return nil -} diff --git a/pkg/datasource/sql/xa/xa_resource.go b/pkg/datasource/sql/xa/xa_resource.go new file mode 100644 index 00000000..db26a643 --- /dev/null +++ b/pkg/datasource/sql/xa/xa_resource.go @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package xa + +import ( + "context" + "time" +) + +const ( + // TMEndRScan ends a recovery scan. + TMEndRScan = 0x00800000 + // TMFail disassociates the caller and marks the transaction branch + // rollback-only. + TMFail = 0x20000000 + + // TMJoin joining existing transaction branch. + TMJoin = 0x00200000 + + // TMNoFlags indicate no flags value is selected. + TMNoFlags = 0x00000000 + + // TMOnePhase using one-phase optimization. + TMOnePhase = 0x40000000 + + // TMResume is resuming association with a suspended transaction branch. + TMResume = 0x08000000 + + // TMStartRScan starts a recovery scan. + TMStartRScan = 0x01000000 + + // TMSuccess disassociates caller from a transaction branch. + TMSuccess = 0x04000000 + + // TMSuspend is suspending (not ending) its association with a transaction branch. + TMSuspend = 0x02000000 + + // XAReadOnly the transaction branch has been read-only and has been committed. + XAReadOnly = 0x00000003 + + // XAOk The transaction work has been prepared normally. + XAOk = 0 +) + +type XAResource interface { + Commit(ctx context.Context, xid string, onePhase bool) error + End(ctx context.Context, xid string, flags int) error + Forget(ctx context.Context, xid string) error + GetTransactionTimeout() time.Duration + IsSameRM(ctx context.Context, resource XAResource) bool + XAPrepare(ctx context.Context, xid string) error + Recover(ctx context.Context, flag int) ([]string, error) + Rollback(ctx context.Context, xid string) error + SetTransactionTimeout(duration time.Duration) bool + Start(ctx context.Context, xid string, flags int) error +} diff --git a/pkg/datasource/sql/exec/xa/default.go b/pkg/datasource/sql/xa/xa_resource_factory.go similarity index 58% rename from pkg/datasource/sql/exec/xa/default.go rename to pkg/datasource/sql/xa/xa_resource_factory.go index f5e76a44..1c6bac8d 100644 --- a/pkg/datasource/sql/exec/xa/default.go +++ b/pkg/datasource/sql/xa/xa_resource_factory.go @@ -18,12 +18,31 @@ package xa import ( - "github.com/seata/seata-go/pkg/datasource/sql/exec" + "database/sql/driver" + "fmt" + "github.com/seata/seata-go/pkg/datasource/sql/types" + "github.com/seata/seata-go/pkg/util/log" ) -func Init() { - exec.RegisterXAExecutor(types.DBTypeMySQL, func() exec.SQLExecutor { - return &XAExecutor{} - }) +// CreateXAResource create a connection for xa with the different db type. +// Such as mysql, oracle, MARIADB, POSTGRESQL +func CreateXAResource(conn driver.Conn, dbType types.DBType) (XAResource, error) { + var err error + var xaConnection XAResource + switch dbType { + case types.DBTypeMySQL: + xaConnection = NewMysqlXaConn(conn) + case types.DBTypeOracle: + case types.DBTypePostgreSQL: + default: + err = fmt.Errorf("not support db type for :%s", dbType.String()) + } + + if err != nil { + log.Errorf(err.Error()) + return nil, err + } + + return xaConnection, nil } diff --git a/pkg/datasource/sql/xa/xa_branch_xid.go b/pkg/datasource/sql/xa_branch_xid.go similarity index 83% rename from pkg/datasource/sql/xa/xa_branch_xid.go rename to pkg/datasource/sql/xa_branch_xid.go index 52ad1470..b57a3266 100644 --- a/pkg/datasource/sql/xa/xa_branch_xid.go +++ b/pkg/datasource/sql/xa_branch_xid.go @@ -15,7 +15,7 @@ * limitations under the License. */ -package xa +package sql import ( "strconv" @@ -23,13 +23,12 @@ import ( ) const ( - BranchIdPrefix = "-" - SeataXaXidFormatId = 9752 + branchIdPrefix = "-" ) type XABranchXid struct { xid string - branchId int64 + branchId uint64 globalTransactionId []byte branchQualifier []byte } @@ -63,14 +62,10 @@ func (x *XABranchXid) GetGlobalXid() string { return x.xid } -func (x *XABranchXid) GetBranchId() int64 { +func (x *XABranchXid) GetBranchId() uint64 { return x.branchId } -func (x *XABranchXid) GetFormatId() int { - return SeataXaXidFormatId -} - func (x *XABranchXid) GetGlobalTransactionId() []byte { return x.globalTransactionId } @@ -80,7 +75,7 @@ func (x *XABranchXid) GetBranchQualifier() []byte { } func (x *XABranchXid) String() string { - return x.xid + BranchIdPrefix + strconv.FormatInt(x.branchId, 10) + return x.xid + branchIdPrefix + strconv.FormatUint(x.branchId, 10) } func WithXid(xid string) Option { @@ -89,7 +84,7 @@ func WithXid(xid string) Option { } } -func WithBranchId(branchId int64) Option { +func WithBranchId(branchId uint64) Option { return func(x *XABranchXid) { x.branchId = branchId } @@ -113,7 +108,7 @@ func encode(x *XABranchXid) { } if x.branchId != 0 { - x.branchQualifier = []byte(BranchIdPrefix + strconv.FormatInt(x.branchId, 10)) + x.branchQualifier = []byte(branchIdPrefix + strconv.FormatUint(x.branchId, 10)) } } @@ -123,7 +118,7 @@ func decode(x *XABranchXid) { } if len(x.branchQualifier) > 0 { - branchId := strings.TrimLeft(string(x.branchQualifier), BranchIdPrefix) - x.branchId, _ = strconv.ParseInt(branchId, 10, 64) + branchId := strings.TrimLeft(string(x.branchQualifier), branchIdPrefix) + x.branchId, _ = strconv.ParseUint(branchId, 10, 64) } } diff --git a/pkg/datasource/sql/xa/xa_branch_xid_test.go b/pkg/datasource/sql/xa_branch_xid_test.go similarity index 87% rename from pkg/datasource/sql/xa/xa_branch_xid_test.go rename to pkg/datasource/sql/xa_branch_xid_test.go index 16b2340b..454e8a53 100644 --- a/pkg/datasource/sql/xa/xa_branch_xid_test.go +++ b/pkg/datasource/sql/xa_branch_xid_test.go @@ -15,7 +15,7 @@ * limitations under the License. */ -package xa +package sql import ( "testing" @@ -25,8 +25,8 @@ import ( func TestXABranchXidBuild(t *testing.T) { xid := "111" - branchId := int64(222) - x := Build(xid, branchId) + branchId := uint64(222) + x := XaIdBuild(xid, branchId) assert.Equal(t, x.GetGlobalXid(), xid) assert.Equal(t, x.GetBranchId(), branchId) @@ -36,11 +36,11 @@ func TestXABranchXidBuild(t *testing.T) { func TestXABranchXidBuildWithByte(t *testing.T) { xid := []byte("111") - branchId := []byte(BranchIdPrefix + "222") - x := BuildWithByte(xid, branchId) + branchId := []byte(branchIdPrefix + "222") + x := XaIdBuildWithByte(xid, branchId) assert.Equal(t, x.GetGlobalTransactionId(), xid) assert.Equal(t, x.GetBranchQualifier(), branchId) assert.Equal(t, x.GetGlobalXid(), "111") - assert.Equal(t, x.GetBranchId(), int64(222)) + assert.Equal(t, x.GetBranchId(), uint64(222)) } diff --git a/pkg/datasource/sql/xa_resource_manager.go b/pkg/datasource/sql/xa_resource_manager.go new file mode 100644 index 00000000..e25505a9 --- /dev/null +++ b/pkg/datasource/sql/xa_resource_manager.go @@ -0,0 +1,245 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package sql + +import ( + "context" + "database/sql" + "errors" + "flag" + "fmt" + "sync" + "time" + + "github.com/bluele/gcache" + + "github.com/seata/seata-go/pkg/datasource/sql/datasource" + "github.com/seata/seata-go/pkg/datasource/sql/types" + "github.com/seata/seata-go/pkg/protocol/branch" + "github.com/seata/seata-go/pkg/rm" + "github.com/seata/seata-go/pkg/util/log" +) + +var branchStatusCache gcache.Cache + +type XAConnConf struct { + XaBranchExecutionTimeout time.Duration `json:"xa_branch_execution_timeout" xml:"xa_branch_execution_timeout" koanf:"xa_branch_execution_timeout"` +} + +func (cfg *XAConnConf) RegisterFlagsWithPrefix(prefix string, f *flag.FlagSet) { + f.DurationVar(&cfg.XaBranchExecutionTimeout, prefix+".xa_branch_execution_timeout", time.Minute, "Undo log table name.") +} + +type XAConfig struct { + xaConnConf XAConnConf + TwoPhaseHoldTime time.Duration `json:"two_phase_hold_time" yaml:"xa_two_phase_hold_time" koanf:"xa_two_phase_hold_time"` +} + +func (cfg *XAConfig) RegisterFlagsWithPrefix(prefix string, f *flag.FlagSet) { + f.DurationVar(&cfg.TwoPhaseHoldTime, prefix+".two_phase_hold_time", time.Millisecond*1000, "Undo log table name.") + cfg.xaConnConf.RegisterFlagsWithPrefix(prefix, f) +} + +func InitXA(config XAConfig) *XAResourceManager { + xaSourceManager := &XAResourceManager{ + resourceCache: sync.Map{}, + basic: datasource.NewBasicSourceManager(), + rmRemoting: rm.GetRMRemotingInstance(), + config: config, + } + + xaConnTimeout = config.xaConnConf.XaBranchExecutionTimeout + + branchStatusCache = gcache.New(1024).LRU().Expiration(time.Minute * 10).Build() + + rm.GetRmCacheInstance().RegisterResourceManager(xaSourceManager) + + go xaSourceManager.xaTwoPhaseTimeoutChecker() + + return xaSourceManager +} + +type XAResourceManager struct { + config XAConfig + resourceCache sync.Map + basic *datasource.BasicSourceManager + rmRemoting *rm.RMRemoting +} + +func (xaManager *XAResourceManager) xaTwoPhaseTimeoutChecker() { + var dbResource *DBResource + xaManager.resourceCache.Range(func(key, value any) bool { + if source, ok := value.(*DBResource); ok { + dbResource = source + } + return false + }) + + if dbResource.IsShouldBeHeld() { + ticker := time.NewTicker(time.Second) + for { + select { + case <-ticker.C: + xaManager.resourceCache.Range(func(key, value any) bool { + source, ok := value.(*DBResource) + if !ok { + return true + } + if source.IsShouldBeHeld() { + return true + } + + source.GetKeeper().Range(func(key, value any) bool { + connectionXA, isConnectionXA := value.(*XAConn) + if !isConnectionXA { + return true + } + + if time.Now().Sub(connectionXA.prepareTime) > xaManager.config.TwoPhaseHoldTime { + if err := connectionXA.CloseForce(); err != nil { + log.Errorf("Force close the xa xid:%s physical connection fail", connectionXA.txCtx.XID) + } + } + return true + }) + return true + }) + } + } + + } + +} + +func (xaManager *XAResourceManager) GetBranchType() branch.BranchType { + return branch.BranchTypeXA +} + +func (xaManager *XAResourceManager) GetCachedResources() *sync.Map { + return &xaManager.resourceCache +} + +func (xaManager *XAResourceManager) RegisterResource(res rm.Resource) error { + xaManager.resourceCache.Store(res.GetResourceId(), res) + return xaManager.basic.RegisterResource(res) +} + +func (xaManager *XAResourceManager) UnregisterResource(resource rm.Resource) error { + return xaManager.basic.UnregisterResource(resource) +} + +func (xaManager *XAResourceManager) xaIDBuilder(xid string, branchId uint64) XAXid { + return XaIdBuild(xid, branchId) +} + +func (xaManager *XAResourceManager) finishBranch(ctx context.Context, xaID XAXid, branchResource rm.BranchResource) (*XAConn, error) { + resource, ok := xaManager.resourceCache.Load(branchResource.ResourceId) + if !ok { + err := fmt.Errorf("unknow resource for rollback xa, resourceId: %s", branchResource.ResourceId) + log.Errorf(err.Error()) + return nil, err + } + + dbResource, ok := resource.(*DBResource) + if !ok { + err := fmt.Errorf("unknow resource for rollback xa, resourceId: %s", branchResource.ResourceId) + log.Errorf(err.Error()) + return nil, err + } + + connectionProxyXA, err := dbResource.ConnectionForXA(ctx, xaID) + if err != nil { + err := fmt.Errorf("get connection for rollback xa, resourceId: %s", branchResource.ResourceId) + log.Errorf(err.Error()) + return nil, err + } + + return connectionProxyXA, nil +} + +func (xaManager *XAResourceManager) BranchCommit(ctx context.Context, branchResource rm.BranchResource) (branch.BranchStatus, error) { + xaID := xaManager.xaIDBuilder(branchResource.Xid, uint64(branchResource.BranchId)) + connectionProxyXA, err := xaManager.finishBranch(ctx, xaID, branchResource) + if err != nil { + return branch.BranchStatusPhasetwoRollbackFailedUnretryable, err + } + + if commitErr := connectionProxyXA.XaCommit(ctx, xaID.String(), branchResource.BranchId); commitErr != nil { + err := fmt.Errorf("rollback xa, resourceId: %s", branchResource.ResourceId) + log.Errorf(err.Error()) + setBranchStatus(xaID.String(), branch.BranchStatusPhasetwoCommitted) + return branch.BranchStatusPhasetwoCommitFailedUnretryable, err + } + + log.Infof("%s was committed", xaID.String()) + return branch.BranchStatusPhasetwoCommitted, nil +} + +func (xaManager *XAResourceManager) BranchRollback(ctx context.Context, branchResource rm.BranchResource) (branch.BranchStatus, error) { + xaID := xaManager.xaIDBuilder(branchResource.Xid, uint64(branchResource.BranchId)) + connectionProxyXA, err := xaManager.finishBranch(ctx, xaID, branchResource) + if err != nil { + return branch.BranchStatusPhasetwoRollbackFailedUnretryable, err + } + + if rollbackErr := connectionProxyXA.XaRollbackByBranchId(ctx, xaID.String(), branchResource.BranchId); rollbackErr != nil { + err := fmt.Errorf("rollback xa, resourceId: %s", branchResource.ResourceId) + log.Errorf(err.Error()) + setBranchStatus(xaID.String(), branch.BranchStatusPhasetwoRollbacked) + return branch.BranchStatusPhasetwoRollbackFailedUnretryable, err + } + + log.Infof("%s was rollback", xaID.String()) + return branch.BranchStatusPhasetwoRollbacked, nil +} + +func (xaManager *XAResourceManager) LockQuery(ctx context.Context, param rm.LockQueryParam) (bool, error) { + return false, nil +} + +func (xaManager *XAResourceManager) BranchRegister(ctx context.Context, req rm.BranchRegisterParam) (int64, error) { + return xaManager.rmRemoting.BranchRegister(req) +} + +func (xaManager *XAResourceManager) BranchReport(ctx context.Context, param rm.BranchReportParam) error { + return xaManager.rmRemoting.BranchReport(param) +} + +func (xaManager *XAResourceManager) CreateTableMetaCache(ctx context.Context, resID string, dbType types.DBType, db *sql.DB) (datasource.TableMetaCache, error) { + return xaManager.basic.CreateTableMetaCache(ctx, resID, dbType, db) +} + +func branchStatus(xaBranchXid string) (branch.BranchStatus, error) { + tmpBranchStatus, err := branchStatusCache.GetIFPresent(xaBranchXid) + if err != nil { + if errors.Is(err, gcache.KeyNotFoundError) { + return branch.BranchStatusUnknown, nil + } + return branch.BranchStatusUnknown, err + } + + branchStatus, isBranchStatus := tmpBranchStatus.(branch.BranchStatus) + if !isBranchStatus { + return branch.BranchStatusUnknown, fmt.Errorf("branchId:%s get result isn't branch status", xaBranchXid) + } + return branchStatus, nil +} + +func setBranchStatus(xaBranchXid string, status branch.BranchStatus) { + branchStatusCache.Set(xaBranchXid, status) +} diff --git a/pkg/datasource/sql/xa/xa_xid.go b/pkg/datasource/sql/xa_xid.go similarity index 85% rename from pkg/datasource/sql/xa/xa_xid.go rename to pkg/datasource/sql/xa_xid.go index 5114f278..51e7e74e 100644 --- a/pkg/datasource/sql/xa/xa_xid.go +++ b/pkg/datasource/sql/xa_xid.go @@ -15,16 +15,10 @@ * limitations under the License. */ -package xa - -type Xid interface { - GetFormatId() int - GetGlobalTransactionId() []byte - GetBranchQualifier() []byte -} +package sql type XAXid interface { - Xid GetGlobalXid() string - GetBranchId() int64 + GetBranchId() uint64 + String() string } diff --git a/pkg/datasource/sql/xa/xa_xid_builder.go b/pkg/datasource/sql/xa_xid_builder.go similarity index 85% rename from pkg/datasource/sql/xa/xa_xid_builder.go rename to pkg/datasource/sql/xa_xid_builder.go index 8afdab83..9efb6b62 100644 --- a/pkg/datasource/sql/xa/xa_xid_builder.go +++ b/pkg/datasource/sql/xa_xid_builder.go @@ -15,12 +15,12 @@ * limitations under the License. */ -package xa +package sql -func Build(xid string, branchId int64) *XABranchXid { +func XaIdBuild(xid string, branchId uint64) *XABranchXid { return NewXABranchXid(WithXid(xid), WithBranchId(branchId)) } -func BuildWithByte(globalTransactionId []byte, branchQualifier []byte) *XABranchXid { +func XaIdBuildWithByte(globalTransactionId []byte, branchQualifier []byte) *XABranchXid { return NewXABranchXid(WithGlobalTransactionId(globalTransactionId), WithBranchQualifier(branchQualifier)) } diff --git a/pkg/protocol/branch/branch.go b/pkg/protocol/branch/branch.go index 0b136ee9..e0e32c55 100644 --- a/pkg/protocol/branch/branch.go +++ b/pkg/protocol/branch/branch.go @@ -35,71 +35,42 @@ const ( ) const ( - /** - * The BranchStatus_Unknown. - * description:BranchStatus_Unknown branch status. - */ - BranchStatusUnknown = BranchStatus(0) - - /** - * The BranchStatus_Registered. - * description:BranchStatus_Registered to TC. - */ - BranchStatusRegistered = BranchStatus(1) - - /** - * The Phase one done. - * description:Branch logic is successfully done at phase one. - */ - BranchStatusPhaseoneDone = BranchStatus(2) - - /** - * The Phase one failed. - * description:Branch logic is failed at phase one. - */ - BranchStatusPhaseoneFailed = BranchStatus(3) - - /** - * The Phase one timeout. - * description:Branch logic is NOT reported for a timeout. - */ - BranchStatusPhaseoneTimeout = BranchStatus(4) - - /** - * The Phase two committed. - * description:Commit logic is successfully done at phase two. - */ - BranchStatusPhasetwoCommitted = BranchStatus(5) - - /** - * The Phase two commit failed retryable. - * description:Commit logic is failed but retryable. - */ - BranchStatusPhasetwoCommitFailedRetryable = BranchStatus(6) - - /** - * The Phase two commit failed unretryable. - * description:Commit logic is failed and NOT retryable. - */ - BranchStatusPhasetwoCommitFailedUnretryable = BranchStatus(7) - - /** - * The Phase two rollbacked. - * description:Rollback logic is successfully done at phase two. - */ - BranchStatusPhasetwoRollbacked = BranchStatus(8) - - /** - * The Phase two rollback failed retryable. - * description:Rollback logic is failed but retryable. - */ - BranchStatusPhasetwoRollbackFailedRetryable = BranchStatus(9) - - /** - * The Phase two rollback failed unretryable. - * description:Rollback logic is failed but NOT retryable. - */ - BranchStatusPhasetwoRollbackFailedUnretryable = BranchStatus(10) + // BranchStatusUnknown the BranchStatus_Unknown. description:BranchStatus_Unknown branch status. + BranchStatusUnknown = iota + + // BranchStatusRegistered the BranchStatus_Registered. description:BranchStatus_Registered to TC. + BranchStatusRegistered + + // BranchStatusPhaseoneDone the Phase one done. description:Branch logic is successfully done at phase one. + BranchStatusPhaseoneDone + + // BranchStatusPhaseoneFailed the Phase one failed. description:Branch logic is failed at phase one. + BranchStatusPhaseoneFailed + + // BranchStatusPhaseoneTimeout the Phase one timeout. description:Branch logic is NOT reported for a timeout. + BranchStatusPhaseoneTimeout + + // BranchStatusPhasetwoCommitted the Phase two committed. description:Commit logic is successfully done at phase two. + BranchStatusPhasetwoCommitted + + // BranchStatusPhasetwoCommitFailedRetryable the Phase two commit failed retryable. description:Commit logic is failed but retryable. + BranchStatusPhasetwoCommitFailedRetryable + + // BranchStatusPhasetwoCommitFailedUnretryable the Phase two commit failed unretryable. + // description:Commit logic is failed and NOT retryable. + BranchStatusPhasetwoCommitFailedUnretryable + + // BranchStatusPhasetwoRollbacked The Phase two rollbacked. + // description:Rollback logic is successfully done at phase two. + BranchStatusPhasetwoRollbacked + + // BranchStatusPhasetwoRollbackFailedRetryable the Phase two rollback failed retryable. + // description:Rollback logic is failed but retryable. + BranchStatusPhasetwoRollbackFailedRetryable + + // BranchStatusPhasetwoRollbackFailedUnretryable the Phase two rollback failed unretryable. + // description:Rollback logic is failed but NOT retryable. + BranchStatusPhasetwoRollbackFailedUnretryable ) func (s BranchStatus) String() string { diff --git a/pkg/rm/rm_api.go b/pkg/rm/rm_api.go index 7d2244b5..6e5149b6 100644 --- a/pkg/rm/rm_api.go +++ b/pkg/rm/rm_api.go @@ -31,7 +31,7 @@ type Resource interface { GetBranchType() branch.BranchType } -// branch resource which contains branch to commit or rollback +// BranchResource contains branch to commit or rollback type BranchResource struct { BranchType branch.BranchType Xid string @@ -42,9 +42,9 @@ type BranchResource struct { // ResourceManagerInbound Control a branch transaction commit or rollback type ResourceManagerInbound interface { - // Commit a branch transaction + // BranchCommit commit a branch transaction BranchCommit(ctx context.Context, resource BranchResource) (branch.BranchStatus, error) - // Rollback a branch transaction + // BranchRollback rollback a branch transaction BranchRollback(ctx context.Context, resource BranchResource) (branch.BranchStatus, error) } @@ -75,13 +75,13 @@ type LockQueryParam struct { LockKeys string } -// Resource Manager: send outbound request to TC +// ResourceManagerOutbound Resource Manager: send outbound request to TC type ResourceManagerOutbound interface { - // Branch register long + // BranchRegister rm register the branch transaction BranchRegister(ctx context.Context, param BranchRegisterParam) (int64, error) - // Branch report + // BranchReport branch transaction report the status BranchReport(ctx context.Context, param BranchReportParam) error - // Lock query boolean + // LockQuery lock query boolean LockQuery(ctx context.Context, param LockQueryParam) (bool, error) } @@ -90,13 +90,13 @@ type ResourceManager interface { ResourceManagerInbound ResourceManagerOutbound - // Register a Resource to be managed by Resource Manager + // RegisterResource register a resource to be managed by resource manager RegisterResource(resource Resource) error - // Unregister a Resource from the Resource Manager + // UnregisterResource unregister a resource from the Resource Manager UnregisterResource(resource Resource) error - // Get all resources managed by this manager + // GetCachedResources get all resources managed by this manager GetCachedResources() *sync.Map - // Get the BranchType + // GetBranchType get the branch type GetBranchType() branch.BranchType }