Browse Source

feat: add xa (#467)

* feat: add xa
tags/v1.2.0
georgehao GitHub 2 years ago
parent
commit
27cba64d6f
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
46 changed files with 1184 additions and 1209 deletions
  1. +3
    -2
      pkg/client/client.go
  2. +5
    -3
      pkg/client/config.go
  3. +4
    -2
      pkg/datasource/init.go
  4. +7
    -9
      pkg/datasource/sql/at_resource_manager.go
  5. +7
    -1
      pkg/datasource/sql/conn.go
  6. +0
    -91
      pkg/datasource/sql/conn/resource_xa.go
  7. +1
    -2
      pkg/datasource/sql/conn_at.go
  8. +12
    -2
      pkg/datasource/sql/conn_at_test.go
  9. +306
    -12
      pkg/datasource/sql/conn_xa.go
  10. +54
    -2
      pkg/datasource/sql/conn_xa_test.go
  11. +46
    -1
      pkg/datasource/sql/connector.go
  12. +15
    -5
      pkg/datasource/sql/connector_test.go
  13. +4
    -56
      pkg/datasource/sql/datasource/datasource_manager.go
  14. +0
    -124
      pkg/datasource/sql/datasource_resource.go
  15. +137
    -43
      pkg/datasource/sql/db.go
  16. +30
    -67
      pkg/datasource/sql/driver.go
  17. +6
    -4
      pkg/datasource/sql/driver_test.go
  18. +2
    -2
      pkg/datasource/sql/exec/at/insert_executor.go
  19. +0
    -16
      pkg/datasource/sql/exec/executor.go
  20. +0
    -93
      pkg/datasource/sql/exec/resource_xa.go
  21. +1
    -0
      pkg/datasource/sql/exec/select_for_update_executor.go
  22. +0
    -84
      pkg/datasource/sql/exec/xa/executor_xa.go
  23. +6
    -1
      pkg/datasource/sql/mock/mock_datasource_manager.go
  24. +0
    -2
      pkg/datasource/sql/plugin.go
  25. +0
    -31
      pkg/datasource/sql/root_context.go
  26. +34
    -22
      pkg/datasource/sql/tx.go
  27. +1
    -0
      pkg/datasource/sql/tx_at.go
  28. +5
    -12
      pkg/datasource/sql/tx_xa.go
  29. +3
    -2
      pkg/datasource/sql/types/types.go
  30. +29
    -0
      pkg/datasource/sql/util/convert.go
  31. +14
    -9
      pkg/datasource/sql/util/convert_test.go
  32. +29
    -26
      pkg/datasource/sql/xa/mysql_xa_connection.go
  33. +13
    -17
      pkg/datasource/sql/xa/mysql_xa_connection_test.go
  34. +1
    -1
      pkg/datasource/sql/xa/oracle_xa_connection.go
  35. +1
    -1
      pkg/datasource/sql/xa/oracle_xa_connection_test.go
  36. +0
    -26
      pkg/datasource/sql/xa/xa_connection.go
  37. +0
    -325
      pkg/datasource/sql/xa/xa_connection_proxy.go
  38. +71
    -0
      pkg/datasource/sql/xa/xa_resource.go
  39. +24
    -5
      pkg/datasource/sql/xa/xa_resource_factory.go
  40. +9
    -14
      pkg/datasource/sql/xa_branch_xid.go
  41. +6
    -6
      pkg/datasource/sql/xa_branch_xid_test.go
  42. +245
    -0
      pkg/datasource/sql/xa_resource_manager.go
  43. +3
    -9
      pkg/datasource/sql/xa_xid.go
  44. +3
    -3
      pkg/datasource/sql/xa_xid_builder.go
  45. +36
    -65
      pkg/protocol/branch/branch.go
  46. +11
    -11
      pkg/rm/rm_api.go

+ 3
- 2
pkg/client/client.go View File

@@ -43,7 +43,7 @@ func InitPath(configFilePath string) {

initRmClient(cfg)
initTmClient(cfg)
initDatasource(cfg)
initDatasource()
}

var (
@@ -84,10 +84,11 @@ func initRmClient(cfg *Config) {
integration.Init()
tcc.InitTCC()
at.InitAT(cfg.ClientConfig.UndoConfig, cfg.AsyncWorkerConfig)
at.InitXA(cfg.ClientConfig.XaConfig)
})
}

func initDatasource(cfg *Config) {
func initDatasource() {
onceInitDatasource.Do(func() {
datasource.Init()
})


+ 5
- 3
pkg/client/config.go View File

@@ -54,15 +54,17 @@ const (
)

type ClientConfig struct {
TmConfig tm.TmConfig `yaml:"tm" json:"tm,omitempty" koanf:"tm"`
RmConfig rm.Config `yaml:"rm" json:"rm,omitempty" koanf:"rm"`
UndoConfig undo.Config `yaml:"undo" json:"undo,omitempty" koanf:"undo"`
TmConfig tm.TmConfig `yaml:"tm" json:"tm,omitempty" koanf:"tm"`
RmConfig rm.Config `yaml:"rm" json:"rm,omitempty" koanf:"rm"`
UndoConfig undo.Config `yaml:"undo" json:"undo,omitempty" koanf:"undo"`
XaConfig sql.XAConfig `yaml:"xa" json:"xa" koanf:"xa"`
}

func (c *ClientConfig) RegisterFlagsWithPrefix(prefix string, f *flag.FlagSet) {
c.TmConfig.RegisterFlagsWithPrefix(prefix+".tm", f)
c.RmConfig.RegisterFlagsWithPrefix(prefix+".rm", f)
c.UndoConfig.RegisterFlagsWithPrefix(prefix+".undo", f)
c.XaConfig.RegisterFlagsWithPrefix(prefix+".xa", f)
}

type Config struct {


+ 4
- 2
pkg/datasource/init.go View File

@@ -17,8 +17,10 @@

package datasource

import sql2 "github.com/seata/seata-go/pkg/datasource/sql"
import (
"github.com/seata/seata-go/pkg/datasource/sql"
)

func Init() {
sql2.Init()
sql.Init()
}

pkg/datasource/sql/at.go → pkg/datasource/sql/at_resource_manager.go View File

@@ -56,23 +56,23 @@ func (a *ATSourceManager) GetBranchType() branch.BranchType {
return branch.BranchTypeAT
}

// Get all resources managed by this manager
// GetCachedResources get all resources managed by this manager
func (a *ATSourceManager) GetCachedResources() *sync.Map {
return &a.resourceCache
}

// Register a Resource to be managed by Resource Manager
// RegisterResource register a Resource to be managed by Resource Manager
func (a *ATSourceManager) RegisterResource(res rm.Resource) error {
a.resourceCache.Store(res.GetResourceId(), res)
return a.basic.RegisterResource(res)
}

// Unregister a Resource from the Resource Manager
// UnregisterResource unregister a Resource from the Resource Manager
func (a *ATSourceManager) UnregisterResource(res rm.Resource) error {
return a.basic.UnregisterResource(res)
}

// Rollback a branch transaction
// BranchRollback rollback a branch transaction
func (a *ATSourceManager) BranchRollback(ctx context.Context, branchResource rm.BranchResource) (branch.BranchStatus, error) {
var dbResource *DBResource
if resource, ok := a.resourceCache.Load(branchResource.ResourceId); !ok {
@@ -103,28 +103,26 @@ func (a *ATSourceManager) BranchRollback(ctx context.Context, branchResource rm.
return branch.BranchStatusPhasetwoRollbacked, nil
}

// BranchCommit
// BranchCommit commit the branch transaction
func (a *ATSourceManager) BranchCommit(ctx context.Context, resource rm.BranchResource) (branch.BranchStatus, error) {
a.worker.BranchCommit(ctx, resource)
return branch.BranchStatusPhasetwoCommitted, nil
}

// LockQuery
func (a *ATSourceManager) LockQuery(ctx context.Context, param rm.LockQueryParam) (bool, error) {
return a.rmRemoting.LockQuery(param)
}

// BranchRegister
// BranchRegister branch transaction register
func (a *ATSourceManager) BranchRegister(ctx context.Context, req rm.BranchRegisterParam) (int64, error) {
return a.rmRemoting.BranchRegister(req)
}

// BranchReport
// BranchReport Report status of transaction branch
func (a *ATSourceManager) BranchReport(ctx context.Context, param rm.BranchReportParam) error {
return a.rmRemoting.BranchReport(param)
}

// CreateTableMetaCache
func (a *ATSourceManager) CreateTableMetaCache(ctx context.Context, resID string, dbType types.DBType,
db *sql.DB) (datasource.TableMetaCache, error) {
return a.basic.CreateTableMetaCache(ctx, resID, dbType, db)

+ 7
- 1
pkg/datasource/sql/conn.go View File

@@ -211,7 +211,13 @@ func (c *Conn) Begin() (driver.Tx, error) {
//
// global transaction according to tranCtx. If so, it needs to be included in the transaction management of seata
func (c *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
c.autoCommit = false
if c.txCtx.TransactionMode == types.XAMode {
return newTx(
withDriverConn(c),
withTxCtx(c.txCtx),
withOriginTx(nil),
)
}

if conn, ok := c.targetConn.(driver.ConnBeginTx); ok {
tx, err := conn.BeginTx(ctx, opts)


+ 0
- 91
pkg/datasource/sql/conn/resource_xa.go View File

@@ -1,91 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package conn

import "time"

const (
// TMENDICANT Ends a recovery scan.
TMENDRSCAN = 0x00800000

/**
* Disassociates the caller and marks the transaction branch
* rollback-only.
*/
TMFAIL = 0x20000000

/**
* Caller is joining existing transaction branch.
*/
TMJOIN = 0x00200000

/**
* Use TMNOFLAGS to indicate no flags value is selected.
*/
TMNOFLAGS = 0x00000000

/**
* Caller is using one-phase optimization.
*/
TMONEPHASE = 0x40000000

/**
* Caller is resuming association with a suspended
* transaction branch.
*/
TMRESUME = 0x08000000

/**
* Starts a recovery scan.
*/
TMSTARTRSCAN = 0x01000000

/**
* Disassociates caller from a transaction branch.
*/
TMSUCCESS = 0x04000000

/**
* Caller is suspending (not ending) its association with
* a transaction branch.
*/
TMSUSPEND = 0x02000000

/**
* The transaction branch has been read-only and has been committed.
*/
XA_RDONLY = 0x00000003

/**
* The transaction work has been prepared normally.
*/
XA_OK = 0
)

type XAResource interface {
Commit(xid string, onePhase bool) error
End(xid string, flags int) error
Forget(xid string) error
GetTransactionTimeout() time.Duration
IsSameRM(resource XAResource) bool
XAPrepare(xid string) (int, error)
Recover(flag int) []string
Rollback(xid string) error
SetTransactionTimeout(duration time.Duration) bool
Start(xid string, flags int) error
}

+ 1
- 2
pkg/datasource/sql/conn_at.go View File

@@ -22,11 +22,10 @@ import (
gosql "database/sql"
"database/sql/driver"

"github.com/seata/seata-go/pkg/util/log"

"github.com/seata/seata-go/pkg/datasource/sql/exec"
"github.com/seata/seata-go/pkg/datasource/sql/types"
"github.com/seata/seata-go/pkg/tm"
"github.com/seata/seata-go/pkg/util/log"
)

// ATConn Database connection proxy object under XA transaction model


+ 12
- 2
pkg/datasource/sql/conn_at_test.go View File

@@ -26,11 +26,13 @@ import (

"github.com/golang/mock/gomock"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"

"github.com/seata/seata-go/pkg/datasource/sql/exec"
"github.com/seata/seata-go/pkg/datasource/sql/mock"
"github.com/seata/seata-go/pkg/datasource/sql/types"
"github.com/seata/seata-go/pkg/protocol/branch"
"github.com/seata/seata-go/pkg/tm"
"github.com/stretchr/testify/assert"
)

func TestMain(m *testing.M) {
@@ -41,7 +43,7 @@ func TestMain(m *testing.M) {
func initAtConnTestResource(t *testing.T) (*gomock.Controller, *sql.DB, *mockSQLInterceptor, *mockTxHook) {
ctrl := gomock.NewController(t)

mockMgr := initMockResourceManager(t, ctrl)
mockMgr := initMockResourceManager(branch.BranchTypeAT, ctrl)
_ = mockMgr

db, err := sql.Open(SeataATMySQLDriver, "root:12345678@tcp(127.0.0.1:3306)/seata_client?multiStatements=true")
@@ -57,6 +59,14 @@ func initAtConnTestResource(t *testing.T) (*gomock.Controller, *sql.DB, *mockSQL
mockConn := mock.NewMockTestDriverConn(ctrl)
mockConn.EXPECT().Begin().AnyTimes().Return(mockTx, nil)
mockConn.EXPECT().BeginTx(gomock.Any(), gomock.Any()).AnyTimes().Return(mockTx, nil)
mockConn.EXPECT().QueryContext(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes().DoAndReturn(
func(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
rows := &mysqlMockRows{}
rows.data = [][]interface{}{
{"8.0.29"},
}
return rows, nil
})
baseMockConn(mockConn)

connector := mock.NewMockTestDriverConnector(ctrl)


+ 306
- 12
pkg/datasource/sql/conn_xa.go View File

@@ -19,41 +19,74 @@ package sql

import (
"context"
gosql "database/sql"
"database/sql/driver"
"fmt"
"time"

"github.com/seata/seata-go/pkg/datasource/sql/types"
"github.com/seata/seata-go/pkg/datasource/sql/xa"
"github.com/seata/seata-go/pkg/tm"
"github.com/seata/seata-go/pkg/util/log"
)

var xaConnTimeout time.Duration

// XAConn Database connection proxy object under XA transaction model
// Conn is assumed to be stateful.
type XAConn struct {
*Conn

tx driver.Tx
xaResource xa.XAResource
xaBranchXid *XABranchXid
xaActive bool
rollBacked bool
branchRegisterTime time.Time
prepareTime time.Time
isConnKept bool
}

// QueryContext
func (c *XAConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
func (c *XAConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
if c.createOnceTxContext(ctx) {
defer func() {
c.txCtx = types.NewTxCtx()
}()
}

return c.Conn.QueryContext(ctx, query, args)
//ret, err := c.createNewTxOnExecIfNeed(ctx, func() (types, error) {
// ret, err := c.Conn.PrepareContext(ctx, query)
// if err != nil {
// return nil, err
// }
// return types.NewResult(types.WithRows(ret)), nil
//})

return c.Conn.PrepareContext(ctx, query)
}

// PrepareContext
func (c *XAConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
// QueryContext exec xa sql
func (c *XAConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
if c.createOnceTxContext(ctx) {
defer func() {
c.txCtx = types.NewTxCtx()
}()
}

return c.Conn.PrepareContext(ctx, query)
ret, err := c.createNewTxOnExecIfNeed(ctx, func() (types.ExecResult, error) {
ret, err := c.Conn.QueryContext(ctx, query, args)
if err != nil {
return nil, err
}
return types.NewResult(types.WithRows(ret)), nil
})

if err != nil {
return nil, err
}
return ret.GetRows(), nil
}

// ExecContext
func (c *XAConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
if c.createOnceTxContext(ctx) {
defer func() {
@@ -61,26 +94,56 @@ func (c *XAConn) ExecContext(ctx context.Context, query string, args []driver.Na
}()
}

return c.Conn.ExecContext(ctx, query, args)
ret, err := c.createNewTxOnExecIfNeed(ctx, func() (types.ExecResult, error) {
ret, err := c.Conn.ExecContext(ctx, query, args)
if err != nil {
return nil, err
}
return types.NewResult(types.WithResult(ret)), nil
})

return ret.GetResult(), err
}

// BeginTx
// BeginTx like common transaction. but it just exec XA START
func (c *XAConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
c.autoCommit = false

c.txCtx = types.NewTxCtx()
c.txCtx.DBType = c.res.dbType
c.txCtx.TxOpt = opts
c.txCtx.ResourceID = c.res.resourceID

if tm.IsGlobalTx(ctx) {
c.txCtx.TransactionMode = types.XAMode
c.txCtx.XID = tm.GetXID(ctx)
c.txCtx.TransactionMode = types.XAMode
}

tx, err := c.Conn.BeginTx(ctx, opts)
if err != nil {
return nil, err
}
c.tx = tx

if c.autoCommit {
baseTx, ok := tx.(*Tx)
if !ok {
return nil, fmt.Errorf("start xa %s transaction failure for the tx is a wrong type", c.txCtx.XID)
}

c.branchRegisterTime = time.Now()
if err := baseTx.register(c.txCtx); err != nil {
c.cleanXABranchContext()
return nil, fmt.Errorf("failed to register xa branch %s, err:%w", c.txCtx.XID, err)
}

c.xaBranchXid = XaIdBuild(c.txCtx.XID, c.txCtx.BranchID)
c.keepIfNecessary()

if err = c.start(ctx); err != nil {
c.cleanXABranchContext()
return nil, fmt.Errorf("failed to start xa branch xid:%s err:%w", c.txCtx.XID, err)
}
c.xaActive = true
}

return &XATx{tx: tx.(*Tx)}, nil
}
@@ -91,9 +154,240 @@ func (c *XAConn) createOnceTxContext(ctx context.Context) bool {
if onceTx {
c.txCtx = types.NewTxCtx()
c.txCtx.DBType = c.res.dbType
c.txCtx.ResourceID = c.res.resourceID
c.txCtx.XID = tm.GetXID(ctx)
c.txCtx.TransactionMode = types.XAMode
c.txCtx.GlobalLockRequire = true
}

return onceTx
}

func (c *XAConn) createNewTxOnExecIfNeed(ctx context.Context, f func() (types.ExecResult, error)) (types.ExecResult, error) {
var err error
if c.txCtx.TransactionMode != types.Local && c.autoCommit {
_, err = c.BeginTx(ctx, driver.TxOptions{Isolation: driver.IsolationLevel(gosql.LevelDefault)})
if err != nil {
return nil, err
}
}
defer func() {
recoverErr := recover()
if err != nil || recoverErr != nil {
log.Errorf("conn at rollback error:%v or recoverErr:%v", err, recoverErr)
if c.tx != nil {
rollbackErr := c.tx.Rollback()
if rollbackErr != nil {
log.Errorf("conn at rollback error:%v", rollbackErr)
}
}
}
}()

// execute SQL
ret, err := f()
if err != nil {
// XA End & Rollback
if rollbackErr := c.Rollback(ctx); rollbackErr != nil {
log.Errorf("failed to rollback xa branch of :%s, err:%w", c.txCtx.XID, rollbackErr)
}
return nil, err
}

if c.autoCommit {
if err := c.Commit(ctx); err != nil {
log.Errorf("xa connection proxy commit failure xid:%s, err:%v", c.txCtx.XID, err)
// XA End & Rollback
if err := c.Rollback(ctx); err != nil {
log.Errorf("xa connection proxy rollback failure xid:%s, err:%v", c.txCtx.XID, err)
}
}
}

return ret, nil
}

func (c *XAConn) keepIfNecessary() {
if c.ShouldBeHeld() {
if err := c.res.Hold(c.xaBranchXid.String(), c); err == nil {
c.isConnKept = true
}
}
}

func (c *XAConn) releaseIfNecessary() {
if c.ShouldBeHeld() && c.xaBranchXid.String() != "" {
if c.isConnKept {
c.res.Release(c.xaBranchXid.String())
c.isConnKept = false
}
}
}

func (c *XAConn) start(ctx context.Context) error {
xaResource, err := xa.CreateXAResource(c.Conn.targetConn, c.dbType)
if err != nil {
return fmt.Errorf("create xa xid:%s resoruce err:%w", c.txCtx.XID, err)
}
c.xaResource = xaResource

if err := c.xaResource.Start(ctx, c.xaBranchXid.String(), xa.TMNoFlags); err != nil {
return fmt.Errorf("xa xid %s resource connection start err:%w", c.txCtx.XID, err)
}

if err := c.termination(c.xaBranchXid.String()); err != nil {
c.xaResource.End(ctx, c.xaBranchXid.String(), xa.TMFail)
c.XaRollback(ctx, c.xaBranchXid)
return err
}
return err
}

func (c *XAConn) end(ctx context.Context, flags int) error {
err := c.termination(c.xaBranchXid.String())
if err != nil {
return err
}
err = c.xaResource.End(ctx, c.xaBranchXid.String(), flags)
if err != nil {
return err
}
return nil
}

func (c *XAConn) termination(xaBranchXid string) error {
branchStatus, err := branchStatus(xaBranchXid)
if err != nil {
c.releaseIfNecessary()
return fmt.Errorf("failed xa branch [%v] the global transaction has finish, branch status: [%v]", c.txCtx.XID, branchStatus)
}
return nil
}

func (c *XAConn) cleanXABranchContext() {
h, _ := time.ParseDuration("-1000h")
c.branchRegisterTime = time.Now().Add(h)
c.prepareTime = time.Now().Add(h)
c.xaActive = false
if !c.isConnKept {
c.xaBranchXid = nil
}
}

func (c *XAConn) Rollback(ctx context.Context) error {
if c.autoCommit {
return nil
}

if !c.xaActive || c.xaBranchXid == nil {
return fmt.Errorf("should NOT rollback on an inactive session")
}

if !c.rollBacked {
if c.xaResource.End(ctx, c.xaBranchXid.String(), xa.TMFail) != nil {
return c.rollbackErrorHandle()
}
if c.XaRollback(ctx, c.xaBranchXid) != nil {
c.cleanXABranchContext()
return c.rollbackErrorHandle()
}
if err := c.tx.Rollback(); err != nil {
c.cleanXABranchContext()
return fmt.Errorf("failed to report XA branch commit-failure on xid:%s err:%w", c.txCtx.XID, err)
}
}
c.cleanXABranchContext()
return nil
}

func (c *XAConn) rollbackErrorHandle() error {
return fmt.Errorf("failed to end(TMFAIL) xa branch on [%v] - [%v]", c.txCtx.XID, c.xaBranchXid.GetBranchId())
}

func (c *XAConn) Commit(ctx context.Context) error {
if c.autoCommit {
return nil
}

if !c.xaActive || c.xaBranchXid == nil {
return fmt.Errorf("should NOT commit on an inactive session")
}

now := time.Now()
if c.end(ctx, xa.TMSuccess) != nil {
return c.commitErrorHandle()
}

if c.checkTimeout(ctx, now) != nil {
return c.commitErrorHandle()
}

if c.xaResource.XAPrepare(ctx, c.xaBranchXid.String()) != nil {
return c.commitErrorHandle()
}
return nil
}

func (c *XAConn) commitErrorHandle() error {
var err error
if err = c.tx.Rollback(); err != nil {
err = fmt.Errorf("failed to report XA branch commit-failure xid:%s, err:%w", c.txCtx.XID, err)
}
c.cleanXABranchContext()
return err
}

func (c *XAConn) ShouldBeHeld() bool {
return c.res.IsShouldBeHeld() || (c.res.GetDbType().String() != "" && c.res.GetDbType() != types.DBTypeUnknown)
}

func (c *XAConn) checkTimeout(ctx context.Context, now time.Time) error {
if now.Sub(c.branchRegisterTime) > xaConnTimeout {
c.XaRollback(ctx, c.xaBranchXid)
return fmt.Errorf("XA branch timeout error xid:%s", c.txCtx.XID)
}
return nil
}

func (c *XAConn) Close() error {
c.rollBacked = false
if c.isConnKept && c.ShouldBeHeld() {
return nil
}
c.cleanXABranchContext()
if err := c.Conn.Close(); err != nil {
return err
}
return nil
}

func (c *XAConn) CloseForce() error {
if err := c.Conn.Close(); err != nil {
return err
}
c.rollBacked = false
c.cleanXABranchContext()
if err := c.Conn.Close(); err != nil {
return err
}
c.releaseIfNecessary()
return nil
}

func (c *XAConn) XaCommit(ctx context.Context, xid string, branchId int64) error {
xaXid := XaIdBuild(xid, uint64(branchId))
err := c.xaResource.Commit(ctx, xaXid.String(), false)
c.releaseIfNecessary()
return err
}

func (c *XAConn) XaRollbackByBranchId(ctx context.Context, xid string, branchId int64) error {
xaXid := XaIdBuild(xid, uint64(branchId))
return c.XaRollback(ctx, xaXid)
}

func (c *XAConn) XaRollback(ctx context.Context, xaXid XAXid) error {
err := c.xaResource.Rollback(ctx, xaXid.GetGlobalXid())
c.releaseIfNecessary()
return err
}

+ 54
- 2
pkg/datasource/sql/conn_xa_test.go View File

@@ -21,18 +21,59 @@ import (
"context"
"database/sql"
"database/sql/driver"
"io"
"sync/atomic"
"testing"
"time"

"github.com/bluele/gcache"
"github.com/golang/mock/gomock"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"

"github.com/seata/seata-go/pkg/datasource/sql/exec"
"github.com/seata/seata-go/pkg/datasource/sql/mock"
"github.com/seata/seata-go/pkg/datasource/sql/types"
"github.com/seata/seata-go/pkg/protocol/branch"
"github.com/seata/seata-go/pkg/tm"
"github.com/stretchr/testify/assert"
)

type mysqlMockRows struct {
idx int
data [][]interface{}
}

func (m *mysqlMockRows) Columns() []string {
//TODO implement me
panic("implement me")
}

func (m *mysqlMockRows) Close() error {
//TODO implement me
panic("implement me")
}

func (m *mysqlMockRows) Next(dest []driver.Value) error {
if m.idx == len(m.data) {
return io.EOF
}

min := func(a, b int) int {
if a < b {
return a
}
return b
}

cnt := min(len(m.data[0]), len(dest))

for i := 0; i < cnt; i++ {
dest[i] = m.data[m.idx][i]
}
m.idx++
return nil
}

type mockSQLInterceptor struct {
before func(ctx context.Context, execCtx *types.ExecContext)
after func(ctx context.Context, execCtx *types.ExecContext)
@@ -78,16 +119,27 @@ func (mi *mockTxHook) BeforeRollback(tx *Tx) {
}

func baseMockConn(mockConn *mock.MockTestDriverConn) {
branchStatusCache = gcache.New(1024).LRU().Expiration(time.Minute * 10).Build()

mockConn.EXPECT().ExecContext(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes().Return(&driver.ResultNoRows, nil)
mockConn.EXPECT().Exec(gomock.Any(), gomock.Any()).AnyTimes().Return(&driver.ResultNoRows, nil)
mockConn.EXPECT().ResetSession(gomock.Any()).AnyTimes().Return(nil)
mockConn.EXPECT().Close().AnyTimes().Return(nil)

mockConn.EXPECT().QueryContext(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes().DoAndReturn(
func(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
rows := &mysqlMockRows{}
rows.data = [][]interface{}{
{"8.0.29"},
}
return rows, nil
})
}

func initXAConnTestResource(t *testing.T) (*gomock.Controller, *sql.DB, *mockSQLInterceptor, *mockTxHook) {
ctrl := gomock.NewController(t)

mockMgr := initMockResourceManager(t, ctrl)
mockMgr := initMockResourceManager(branch.BranchTypeXA, ctrl)
_ = mockMgr
//db, err := sql.Open("seata-xa-mysql", "root:seata_go@tcp(127.0.0.1:3306)/seata_go_test?multiStatements=true")
db, err := sql.Open("seata-xa-mysql", "root:12345678@tcp(127.0.0.1:3306)/seata_client?multiStatements=true&interpolateParams=true")


+ 46
- 1
pkg/datasource/sql/connector.go View File

@@ -20,11 +20,14 @@ package sql
import (
"context"
"database/sql/driver"
"errors"
"io"
"sync"

"github.com/go-sql-driver/mysql"

"github.com/seata/seata-go/pkg/datasource/sql/types"
"github.com/seata/seata-go/pkg/util/log"
)

type seataATConnector struct {
@@ -89,7 +92,6 @@ func (c *seataXAConnector) Driver() driver.Driver {
// method will call Close and return error (if any).
type seataConnector struct {
transType types.TransactionMode
conf *seataServerConfig
res *DBResource
once sync.Once
driver driver.Driver
@@ -116,6 +118,15 @@ func (c *seataConnector) Connect(ctx context.Context) (driver.Conn, error) {
return nil, err
}

// get the version of mysql for xa.
if c.transType == types.XAMode {
version, err := c.dbVersion(ctx, conn)
if err != nil {
return nil, err
}
c.res.SetDbVersion(version)
}

return &Conn{
targetConn: conn,
res: c.res,
@@ -126,6 +137,40 @@ func (c *seataConnector) Connect(ctx context.Context) (driver.Conn, error) {
}, nil
}

func (c *seataConnector) dbVersion(ctx context.Context, conn driver.Conn) (string, error) {
queryConn, isQueryContext := conn.(driver.QueryerContext)
if !isQueryContext {
return "", errors.New("get db version error for unexpected driver conn")
}

res, err := queryConn.QueryContext(ctx, "SELECT VERSION()", nil)
if err != nil {
log.Errorf("seata connector get the xa mysql version err:%v", err)
return "", err
}

dest := make([]driver.Value, 1)
var version string
for true {
if err = res.Next(dest); err != nil {
if err == io.EOF {
return version, nil
}
return "", err
}
if len(dest) != 1 {
return "", errors.New("get the mysql version is not column 1")
}

var isVersionOk bool
version, isVersionOk = dest[0].(string)
if !isVersionOk {
return "", errors.New("get the mysql version is not a string")
}
}
return "", errors.New("get the mysql version is error")
}

// Driver returns the underlying Driver of the Connector,
// mainly to maintain compatibility with the Driver method
// on sql.DB.


+ 15
- 5
pkg/datasource/sql/connector_test.go View File

@@ -25,10 +25,12 @@ import (
"testing"

"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"

"github.com/seata/seata-go/pkg/datasource/sql/mock"
"github.com/seata/seata-go/pkg/datasource/sql/types"
"github.com/seata/seata-go/pkg/protocol/branch"
"github.com/seata/seata-go/pkg/util/reflectx"
"github.com/stretchr/testify/assert"
)

type initConnectorFunc func(t *testing.T, ctrl *gomock.Controller) driver.Connector
@@ -38,6 +40,14 @@ func initMockConnector(t *testing.T, ctrl *gomock.Controller) driver.Connector {

connector := mock.NewMockTestDriverConnector(ctrl)
connector.EXPECT().Connect(gomock.Any()).AnyTimes().Return(mockConn, nil)
mockConn.EXPECT().QueryContext(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes().DoAndReturn(
func(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
rows := &mysqlMockRows{}
rows.data = [][]interface{}{
{"8.0.29"},
}
return rows, nil
})
return connector
}

@@ -66,10 +76,10 @@ func Test_seataATConnector_Connect(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()

mockMgr := initMockResourceManager(t, ctrl)
mockMgr := initMockResourceManager(branch.BranchTypeAT, ctrl)
_ = mockMgr

db, err := sql.Open("seata-at-mysql", "root:seata_go@tcp(127.0.0.1:3306)/seata_go_test?multiStatements=true")
db, err := sql.Open(SeataATMySQLDriver, "root:seata_go@tcp(127.0.0.1:3306)/seata_go_test?multiStatements=true")
if err != nil {
t.Fatal(err)
}
@@ -110,10 +120,10 @@ func Test_seataXAConnector_Connect(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()

mockMgr := initMockResourceManager(t, ctrl)
mockMgr := initMockResourceManager(branch.BranchTypeXA, ctrl)
_ = mockMgr

db, err := sql.Open("seata-xa-mysql", "root:seata_go@tcp(127.0.0.1:3306)/seata_go_test?multiStatements=true")
db, err := sql.Open(SeataXAMySQLDriver, "root:seata_go@tcp(127.0.0.1:3306)/seata_go_test?multiStatements=true")
if err != nil {
t.Fatal(err)
}


+ 4
- 56
pkg/datasource/sql/datasource/datasource_manager.go View File

@@ -25,7 +25,6 @@ import (

"github.com/seata/seata-go/pkg/datasource/sql/types"
"github.com/seata/seata-go/pkg/protocol/branch"
"github.com/seata/seata-go/pkg/protocol/message"
"github.com/seata/seata-go/pkg/rm"
)

@@ -34,7 +33,7 @@ var (
tableMetaCacheMap = map[types.DBType]TableMetaCache{}
)

// RegisterTableCache
// RegisterTableCache register the table meta cache for at and xa
func RegisterTableCache(dbType types.DBType, tableMetaCache TableMetaCache) {
tableMetaCacheMap[dbType] = tableMetaCache
}
@@ -54,11 +53,8 @@ func GetDataSourceManager(branchType branch.BranchType) DataSourceManager {
return nil
}

// todo implements ResourceManagerOutbound interface
// DataSourceManager
type DataSourceManager interface {
rm.ResourceManager
// CreateTableMetaCache
CreateTableMetaCache(ctx context.Context, resID string, dbType types.DBType, db *sql.DB) (TableMetaCache, error)
}

@@ -67,9 +63,8 @@ type entry struct {
metaCache TableMetaCache
}

// BasicSourceManager
// BasicSourceManager the basic source manager for xa and at
type BasicSourceManager struct {
// lock
lock sync.RWMutex
// tableMetaCache
// todo do not put meta cache here
@@ -82,34 +77,7 @@ func NewBasicSourceManager() *BasicSourceManager {
}
}

// Commit a branch transaction
// TODO wait finish
func (dm *BasicSourceManager) BranchCommit(ctx context.Context, req message.BranchCommitRequest) (branch.BranchStatus, error) {
return branch.BranchStatusPhaseoneDone, nil
}

// Rollback a branch transaction
// TODO wait finish
func (dm *BasicSourceManager) BranchRollback(ctx context.Context, req message.BranchRollbackRequest) (branch.BranchStatus, error) {
return branch.BranchStatusPhaseoneFailed, nil
}

// Branch register long
func (dm *BasicSourceManager) BranchRegister(ctx context.Context, req rm.BranchRegisterParam) (int64, error) {
return 0, nil
}

// Branch report
func (dm *BasicSourceManager) BranchReport(ctx context.Context, req message.BranchReportRequest) error {
return nil
}

// Lock query boolean
func (dm *BasicSourceManager) LockQuery(ctx context.Context, branchType branch.BranchType, resourceId, xid, lockKeys string) (bool, error) {
return true, nil
}

// Register a model.Resource to be managed by model.Resource Manager
// RegisterResource register a model.Resource to be managed by model.Resource Manager
func (dm *BasicSourceManager) RegisterResource(resource rm.Resource) error {
err := rm.GetRMRemotingInstance().RegisterResource(resource)
if err != nil {
@@ -118,22 +86,11 @@ func (dm *BasicSourceManager) RegisterResource(resource rm.Resource) error {
return nil
}

// Unregister a model.Resource from the model.Resource Manager
func (dm *BasicSourceManager) UnregisterResource(resource rm.Resource) error {
return fmt.Errorf("unsupport unregister resource")
}

// Get all resources managed by this manager
func (dm *BasicSourceManager) GetManagedResources() *sync.Map {
return nil
}

// Get the model.BranchType
func (dm *BasicSourceManager) GetBranchType() branch.BranchType {
return branch.BranchTypeAT
}

// CreateTableMetaCache
// CreateTableMetaCache create a table meta cache
func (dm *BasicSourceManager) CreateTableMetaCache(ctx context.Context, resID string, dbType types.DBType, db *sql.DB) (TableMetaCache, error) {
dm.lock.Lock()
defer dm.lock.Unlock()
@@ -144,28 +101,19 @@ func (dm *BasicSourceManager) CreateTableMetaCache(ctx context.Context, resID st
}

dm.tableMetaCache[resID] = res

// 注册 AT 数据资源
// dm.resourceMgr.RegisterResource(ATResource)

return res.metaCache, err
}

// TableMetaCache tables metadata cache, default is open
type TableMetaCache interface {
// Init
Init(ctx context.Context, conn *sql.DB) error
// GetTableMeta
GetTableMeta(ctx context.Context, dbName, table string) (*types.TableMeta, error)
// Destroy
Destroy() error
}

// buildResource
// todo not here
func buildResource(ctx context.Context, dbType types.DBType, db *sql.DB) (*entry, error) {
cache := tableMetaCacheMap[dbType]

if err := cache.Init(ctx, db); err != nil {
return nil, err
}


+ 0
- 124
pkg/datasource/sql/datasource_resource.go View File

@@ -1,124 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package sql

import (
"fmt"
"time"

"github.com/bluele/gcache"

"github.com/seata/seata-go/pkg/datasource/sql/types"
"github.com/seata/seata-go/pkg/protocol/branch"
)

type Holdable interface {
SetHeld(held bool)
IsHeld() bool
ShouldBeHeld() bool
}

type BaseDataSourceResource struct {
db *DBResource
shouldBeHeld bool
keeper map[string]interface{}
Cache map[string]branch.BranchStatus
}

var BranchStatusCache = gcache.New(1024).LRU().Expiration(time.Minute * 10).Build()

func (b *BaseDataSourceResource) init() error {
return nil
}

func (b *BaseDataSourceResource) GetDB() *DBResource {
return b.db
}

func (b *BaseDataSourceResource) SetDB(db *DBResource) {
b.db = db
}

func (b *BaseDataSourceResource) IsShouldBeHeld() bool {
return b.shouldBeHeld
}

func (b *BaseDataSourceResource) SetShouldBeHeld(shouldBeHeld bool) {
b.shouldBeHeld = shouldBeHeld
}

func (b *BaseDataSourceResource) GetKeeper() map[string]interface{} {
return b.keeper
}

func (b *BaseDataSourceResource) SetKeeper(keeper map[string]interface{}) {
b.keeper = keeper
}

func (b *BaseDataSourceResource) GetCache() map[string]branch.BranchStatus {
return b.Cache
}

func (b *BaseDataSourceResource) SetCache(cache map[string]branch.BranchStatus) {
b.Cache = cache
}

func (b *BaseDataSourceResource) GetResourceId() string {
return b.db.GetResourceId()
}

func (b *BaseDataSourceResource) Hold(key string, value Holdable) (interface{}, error) {
if value.IsHeld() {
var x = b.keeper[key]
if x != value {
return nil, fmt.Errorf("something wrong with keeper, keeping[%v] but[%v] is also kept with the same key[%v]", x, value, key)
}
return value, nil
}
var x = b.keeper[key]
b.keeper[key] = value
value.SetHeld(true)
return x, nil
}

func (b *BaseDataSourceResource) Release(key string, value Holdable) (interface{}, error) {
if value.IsHeld() {
var x = b.keeper[key]
if x != value {
return nil, fmt.Errorf("something wrong with keeper, keeping[%v] but[%v] is also kept with the same key[%v]", x, value, key)
}
return value, nil
}
var x = b.keeper[key]
b.keeper[key] = value
value.SetHeld(true)
return x, nil
}

func (b *BaseDataSourceResource) GetBranchStatus(xaBranchXid string) (interface{}, error) {
branchStatus, err := BranchStatusCache.GetIFPresent(xaBranchXid)
return branchStatus, err
}

func (b *BaseDataSourceResource) GetDbType() string {
return b.db.dbType.String()
}

func (b *BaseDataSourceResource) SetDbType(dbType types.DBType) {
b.db.dbType = dbType
}

+ 137
- 43
pkg/datasource/sql/db.go View File

@@ -18,19 +18,24 @@
package sql

import (
"context"
"database/sql"
"database/sql/driver"
"fmt"
"sync"

"github.com/seata/seata-go/pkg/datasource/sql/datasource"
"github.com/seata/seata-go/pkg/datasource/sql/types"
"github.com/seata/seata-go/pkg/datasource/sql/undo"
"github.com/seata/seata-go/pkg/datasource/sql/util"
"github.com/seata/seata-go/pkg/protocol/branch"
)

type dbOption func(db *DBResource)

func withGroupID(id string) dbOption {
func withDsn(dsn string) dbOption {
return func(db *DBResource) {
db.groupID = id
db.dsn = dsn
}
}

@@ -52,21 +57,33 @@ func withDBType(dt types.DBType) dbOption {
}
}

func withBranchType(dt branch.BranchType) dbOption {
return func(db *DBResource) {
db.branchType = dt
}
}

func withTarget(source *sql.DB) dbOption {
return func(db *DBResource) {
db.db = source
}
}

func withConnector(ci driver.Connector) dbOption {
return func(db *DBResource) {
db.connector = ci
}
}

func withDBName(dbName string) dbOption {
return func(db *DBResource) {
db.dbName = dbName
}
}

func withConf(conf *seataServerConfig) dbOption {
func withConf(conf *XAConnConf) dbOption {
return func(db *DBResource) {
db.conf = *conf
db.xaConnConf = conf
}
}

@@ -77,47 +94,36 @@ func newResource(opts ...dbOption) (*DBResource, error) {
opts[i](db)
}

return db, db.init()
db.init()
return db, nil
}

// DBResource proxy sql.DB, enchance database/sql.DB to add distribute transaction ability
type DBResource struct {
// groupID
groupID string
// resourceID
xaConnConf *XAConnConf
// only use by mysql
dbVersion string
dsn string
resourceID string
// conf
conf seataServerConfig
// db
db *sql.DB
dbName string
// dbType
dbType types.DBType
// undoLogMgr
db *sql.DB
connector driver.Connector
dbName string
dbType types.DBType
undoLogMgr undo.UndoLogManager
// metaCache
metaCache datasource.TableMetaCache
}
branchType branch.BranchType

func (db *DBResource) init() error {
return nil
// for xa
metaCache datasource.TableMetaCache
shouldBeHeld bool
keeper sync.Map
}

// todo do not put meta data to rm
//func (db *DBResource) init() error {
// mgr := datasource.GetDataSourceManager(db.GetBranchType())
// metaCache, err := mgr.CreateTableMetaCache(context.Background(), db.resourceID, db.dbType, db.db)
// if err != nil {
// return err
// }
//
// db.metaCache = metaCache
//
// return nil
//}

func (db *DBResource) GetResourceGroupId() string {
return db.groupID
panic("implement me")
}

func (db *DBResource) init() {
db.checkDbVersion()
}

func (db *DBResource) GetResourceId() string {
@@ -125,18 +131,106 @@ func (db *DBResource) GetResourceId() string {
}

func (db *DBResource) GetBranchType() branch.BranchType {
return db.conf.BranchType
return db.branchType
}

func (db *DBResource) GetDB() *sql.DB {
return db.db
}

type SqlDBProxy struct {
db *sql.DB
dbName string
func (db *DBResource) GetDBName() string {
return db.dbName
}

func (s *SqlDBProxy) GetDB() *sql.DB {
return s.db
func (db *DBResource) GetDbType() types.DBType {
return db.dbType
}

func (s *SqlDBProxy) GetDBName() string {
return s.dbName
func (db *DBResource) SetDbType(dbType types.DBType) {
db.dbType = dbType
}

func (db *DBResource) SetDbVersion(v string) {
db.dbVersion = v
}

func (db *DBResource) GetDbVersion() string {
return db.dbVersion
}

func (db *DBResource) IsShouldBeHeld() bool {
return db.shouldBeHeld
}

// Hold the xa connection.
func (db *DBResource) Hold(xaBranchID string, v interface{}) error {
_, exist := db.keeper.Load(xaBranchID)
if !exist {
db.keeper.Store(xaBranchID, v)
return nil
}
return nil
}

func (db *DBResource) Release(xaBranchID string) {
db.keeper.Delete(xaBranchID)
}

func (db *DBResource) Lookup(xaBranchID string) (interface{}, bool) {
return db.keeper.Load(xaBranchID)
}

func (db *DBResource) GetKeeper() *sync.Map {
return &db.keeper
}

func (db *DBResource) ConnectionForXA(ctx context.Context, xaXid XAXid) (*XAConn, error) {
xaBranchXid := xaXid.String()
tmpConn, ok := db.Lookup(xaBranchXid)
if ok && tmpConn != nil {
connectionProxyXa, isConnectionProxyXa := tmpConn.(*XAConn)
if !isConnectionProxyXa {
return nil, fmt.Errorf("get connection proxy xa from cache error, xid:%s", xaXid.String())
}
return connectionProxyXa, nil
}

// why here need a new connection?
// 1. because there maybe a rm cluster
// 2. the first phase select a rm1, and store the connection is the keeper
// 3. tc request the second phase. but the rm1 is shutdown, so the tc select another rm (like rm2)
// 4. so when the second phase request coming to rm2, rm2 must not store the connection.
// 5. the rm2 get the second phase do the two thing.
// 1. in mysql version >= 8.0.29, mysql support the xa transaction commit by another connection. so just commit
// 2. when the version < 8.0.29. so just make the transaction rollback
newDriverConn, err := db.connector.Connect(ctx)
if err != nil {
return nil, fmt.Errorf("get xa new connection failure, xid:%s, err:%v", xaXid.String(), err)
}
xaConn := &XAConn{
Conn: newDriverConn.(*Conn),
}
return xaConn, nil
}

func (db *DBResource) checkDbVersion() error {
switch db.dbType {
case types.DBTypeMySQL:
currentVersion, err := util.ConvertDbVersion(db.dbVersion)
if err != nil {
return fmt.Errorf("new connection xa proxy convert db version:%s err:%v", db.GetDbVersion(), err)
}

shouldKeptVersion, err := util.ConvertDbVersion("8.0.29")
if err != nil {
return fmt.Errorf("new connection xa proxy convert db version 8.0.29 err:%v", err)
}

if currentVersion < shouldKeptVersion {
db.shouldBeHeld = true
}
case types.DBTypeMARIADB:
db.shouldBeHeld = true
}
return nil
}

+ 30
- 67
pkg/datasource/sql/driver.go View File

@@ -25,10 +25,10 @@ import (
"fmt"
"strings"

mysql2 "github.com/seata/seata-go/pkg/datasource/sql/datasource/mysql"

"github.com/go-sql-driver/mysql"

"github.com/seata/seata-go/pkg/datasource/sql/datasource"
mysql2 "github.com/seata/seata-go/pkg/datasource/sql/datasource/mysql"
"github.com/seata/seata-go/pkg/datasource/sql/types"
"github.com/seata/seata-go/pkg/protocol/branch"
"github.com/seata/seata-go/pkg/util/log"
@@ -44,15 +44,17 @@ const (
func initDriver() {
sql.Register(SeataATMySQLDriver, &seataATDriver{
seataDriver: &seataDriver{
transType: types.ATMode,
target: mysql.MySQLDriver{},
branchType: branch.BranchTypeAT,
transType: types.ATMode,
target: mysql.MySQLDriver{},
},
})

sql.Register(SeataXAMySQLDriver, &seataXADriver{
seataDriver: &seataDriver{
transType: types.XAMode,
target: mysql.MySQLDriver{},
branchType: branch.BranchTypeXA,
transType: types.XAMode,
target: mysql.MySQLDriver{},
},
})
}
@@ -98,8 +100,9 @@ func (d *seataXADriver) OpenConnector(name string) (c driver.Connector, err erro
}

type seataDriver struct {
transType types.TransactionMode
target driver.Driver
branchType branch.BranchType
transType types.TransactionMode
target driver.Driver
}

// Open never be called, because seataDriver implemented dri.DriverContext interface.
@@ -124,7 +127,7 @@ func (d *seataDriver) OpenConnector(name string) (c driver.Connector, err error)
return nil, fmt.Errorf("unsupport conn type %s", d.getTargetDriverName())
}

proxy, err := getOpenConnectorProxy(c, dbType, sql.OpenDB(c), name)
proxy, err := d.getOpenConnectorProxy(c, dbType, sql.OpenDB(c), name)
if err != nil {
log.Errorf("register resource: %w", err)
return nil, err
@@ -133,43 +136,15 @@ func (d *seataDriver) OpenConnector(name string) (c driver.Connector, err error)
return proxy, nil
}

func (d *seataDriver) getTargetDriverName() string {
return "mysql"
}

type dsnConnector struct {
dsn string
driver driver.Driver
}

func (t *dsnConnector) Connect(_ context.Context) (driver.Conn, error) {
return t.driver.Open(t.dsn)
}

func (t *dsnConnector) Driver() driver.Driver {
return t.driver
}

func getOpenConnectorProxy(connector driver.Connector, dbType types.DBType, db *sql.DB,
dataSourceName string, opts ...seataOption) (driver.Connector, error) {
conf := loadConfig()
for i := range opts {
opts[i](conf)
}

if err := conf.validate(); err != nil {
log.Errorf("invalid conf: %w", err)
return nil, err
}

func (d *seataDriver) getOpenConnectorProxy(connector driver.Connector, dbType types.DBType,
db *sql.DB, dataSourceName string) (driver.Connector, error) {
cfg, _ := mysql.ParseDSN(dataSourceName)
options := []dbOption{
withGroupID(conf.GroupID),
withResourceID(parseResourceID(dataSourceName)),
withConf(conf),
withTarget(db),
withBranchType(d.branchType),
withDBType(dbType),
withDBName(cfg.DBName),
withConnector(connector),
}

res, err := newResource(options...)
@@ -179,7 +154,8 @@ func getOpenConnectorProxy(connector driver.Connector, dbType types.DBType, db *
}

datasource.RegisterTableCache(types.DBTypeMySQL, mysql2.NewTableMetaInstance(db))
if err = datasource.GetDataSourceManager(conf.BranchType).RegisterResource(res); err != nil {

if err = datasource.GetDataSourceManager(d.branchType).RegisterResource(res); err != nil {
log.Errorf("regisiter resource: %w", err)
return nil, err
}
@@ -187,38 +163,25 @@ func getOpenConnectorProxy(connector driver.Connector, dbType types.DBType, db *
return &seataConnector{
res: res,
target: connector,
conf: conf,
cfg: cfg,
}, nil
}

type (
seataOption func(cfg *seataServerConfig)

// seataServerConfig
seataServerConfig struct {
// GroupID
GroupID string `yaml:"groupID"`
// BranchType
BranchType branch.BranchType
// Endpoints
Endpoints []string `yaml:"endpoints" json:"endpoints"`
}
)
func (d *seataDriver) getTargetDriverName() string {
return "mysql"
}

func (c *seataServerConfig) validate() error {
return nil
type dsnConnector struct {
dsn string
driver driver.Driver
}

// loadConfig
func loadConfig() *seataServerConfig {
// set default value first.
// todo read from configuration file.
return &seataServerConfig{
GroupID: "DEFAULT_GROUP",
BranchType: branch.BranchTypeAT,
Endpoints: []string{"127.0.0.1:8888"},
}
func (t *dsnConnector) Connect(_ context.Context) (driver.Conn, error) {
return t.driver.Open(t.dsn)
}

func (t *dsnConnector) Driver() driver.Driver {
return t.driver
}

func parseResourceID(dsn string) string {


+ 6
- 4
pkg/datasource/sql/driver_test.go View File

@@ -28,12 +28,14 @@ import (

"github.com/golang/mock/gomock"
"github.com/seata/seata-go/pkg/datasource/sql/mock"
"github.com/seata/seata-go/pkg/protocol/branch"
"github.com/seata/seata-go/pkg/util/reflectx"
"github.com/stretchr/testify/assert"
)

func initMockResourceManager(t *testing.T, ctrl *gomock.Controller) *mock.MockDataSourceManager {
func initMockResourceManager(branchType branch.BranchType, ctrl *gomock.Controller) *mock.MockDataSourceManager {
mockResourceMgr := mock.NewMockDataSourceManager(ctrl)
mockResourceMgr.SetBranchType(branchType)
rm.GetRmCacheInstance().RegisterResourceManager(mockResourceMgr)
mockResourceMgr.EXPECT().RegisterResource(gomock.Any()).AnyTimes().Return(nil)
mockResourceMgr.EXPECT().CreateTableMetaCache(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes().Return(nil, nil)
@@ -45,7 +47,7 @@ func Test_seataATDriver_Open(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()

mockMgr := initMockResourceManager(t, ctrl)
mockMgr := initMockResourceManager(branch.BranchTypeAT, ctrl)
_ = mockMgr

db, err := sql.Open("seata-at-mysql", "root:seata_go@tcp(127.0.0.1:3306)/seata_go_test?multiStatements=true")
@@ -93,7 +95,7 @@ func Test_seataATDriver_OpenConnector(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()

mockMgr := initMockResourceManager(t, ctrl)
mockMgr := initMockResourceManager(branch.BranchTypeAT, ctrl)
_ = mockMgr

db, err := sql.Open("seata-at-mysql", "root:seata_go@tcp(127.0.0.1:3306)/seata_go_test?multiStatements=true")
@@ -119,7 +121,7 @@ func Test_seataXADriver_OpenConnector(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()

mockMgr := initMockResourceManager(t, ctrl)
mockMgr := initMockResourceManager(branch.BranchTypeAT, ctrl)
_ = mockMgr

db, err := sql.Open("seata-xa-mysql", "root:seata_go@tcp(127.0.0.1:3306)/seata_go_test?multiStatements=true")


+ 2
- 2
pkg/datasource/sql/exec/at/insert_executor.go View File

@@ -514,8 +514,8 @@ func canAutoIncrement(pkMetaMap map[string]types.ColumnMeta) bool {
return false
}

func (u *insertExecutor) isAstStmtValid() bool {
return u.parserCtx != nil && u.parserCtx.InsertStmt != nil
func (i *insertExecutor) isAstStmtValid() bool {
return i.parserCtx != nil && i.parserCtx.InsertStmt != nil
}

func (i *insertExecutor) autoGeneratePks(execCtx *types.ExecContext, autoColumnName string, lastInsetId, updateCount int64) (map[string][]interface{}, error) {


+ 0
- 16
pkg/datasource/sql/exec/executor.go View File

@@ -27,7 +27,6 @@ import (

var (
atExecutors = make(map[types.DBType]func() SQLExecutor)
xaExecutors = make(map[types.DBType]func() SQLExecutor)
)

// RegisterATExecutor AT executor
@@ -35,13 +34,6 @@ func RegisterATExecutor(dt types.DBType, builder func() SQLExecutor) {
atExecutors[dt] = builder
}

// RegisterXAExecutor XA executor
func RegisterXAExecutor(dt types.DBType, builder func() SQLExecutor) {
xaExecutors[dt] = func() SQLExecutor {
return builder()
}
}

type (
CallbackWithNamedValue func(ctx context.Context, query string, args []driver.NamedValue) (types.ExecResult, error)

@@ -66,12 +58,6 @@ func BuildExecutor(dbType types.DBType, transactionMode types.TransactionMode, q
hooks = append(hooks, commonHook...)
hooks = append(hooks, hookSolts[parseContext.SQLType]...)

if transactionMode == types.XAMode {
e := xaExecutors[dbType]()
e.Interceptors(hooks)
return e, nil
}

e := atExecutors[dbType]()
e.Interceptors(hooks)
return e, nil
@@ -82,12 +68,10 @@ type BaseExecutor struct {
ex SQLExecutor
}

// Interceptors
func (e *BaseExecutor) Interceptors(interceptors []SQLHook) {
e.hooks = interceptors
}

// ExecWithNamedValue
func (e *BaseExecutor) ExecWithNamedValue(ctx context.Context, execCtx *types.ExecContext, f CallbackWithNamedValue) (types.ExecResult, error) {
for i := range e.hooks {
e.hooks[i].Before(ctx, execCtx)


+ 0
- 93
pkg/datasource/sql/exec/resource_xa.go View File

@@ -1,93 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package exec

import (
"time"
)

const (
// TMENDICANT Ends a recovery scan.
TMENDRSCAN = 0x00800000

/**
* Disassociates the caller and marks the transaction branch
* rollback-only.
*/
TMFAIL = 0x20000000

/**
* Caller is joining existing transaction branch.
*/
TMJOIN = 0x00200000

/**
* Use TMNOFLAGS to indicate no flags value is selected.
*/
TMNOFLAGS = 0x00000000

/**
* Caller is using one-phase optimization.
*/
TMONEPHASE = 0x40000000

/**
* Caller is resuming association with a suspended
* transaction branch.
*/
TMRESUME = 0x08000000

/**
* Starts a recovery scan.
*/
TMSTARTRSCAN = 0x01000000

/**
* Disassociates caller from a transaction branch.
*/
TMSUCCESS = 0x04000000

/**
* Caller is suspending (not ending) its association with
* a transaction branch.
*/
TMSUSPEND = 0x02000000

/**
* The transaction branch has been read-only and has been committed.
*/
XA_RDONLY = 0x00000003

/**
* The transaction work has been prepared normally.
*/
XA_OK = 0
)

type XAResource interface {
Commit(xid string, onePhase bool) error
End(xid string, flags int) error
Forget(xid string) error
GetTransactionTimeout() time.Duration
IsSameRM(resource XAResource) bool
XAPrepare(xid string) error
Recover(flag int) ([]string, error)
Rollback(xid string) error
SetTransactionTimeout(duration time.Duration) bool
Start(xid string, flags int) error
}

+ 1
- 0
pkg/datasource/sql/exec/select_for_update_executor.go View File

@@ -30,6 +30,7 @@ import (
"github.com/arana-db/parser/ast"
"github.com/arana-db/parser/format"
"github.com/arana-db/parser/model"

"github.com/seata/seata-go/pkg/datasource/sql/datasource"
"github.com/seata/seata-go/pkg/datasource/sql/types"
"github.com/seata/seata-go/pkg/datasource/sql/undo/builder"


+ 0
- 84
pkg/datasource/sql/exec/xa/executor_xa.go View File

@@ -1,84 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package xa

import (
"context"
"database/sql/driver"

"github.com/seata/seata-go/pkg/datasource/sql/exec"
"github.com/seata/seata-go/pkg/datasource/sql/types"
)

// XAExecutor The XA transaction manager.
type XAExecutor struct {
hooks []exec.SQLHook
ex exec.SQLExecutor
}

// Interceptors set xa executor hooks
func (e *XAExecutor) Interceptors(hooks []exec.SQLHook) {
e.hooks = hooks
}

// ExecWithNamedValue
func (e *XAExecutor) ExecWithNamedValue(ctx context.Context, execCtx *types.ExecContext, f exec.CallbackWithNamedValue) (types.ExecResult, error) {
for _, hook := range e.hooks {
hook.Before(ctx, execCtx)
}

defer func() {
for _, hook := range e.hooks {
hook.After(ctx, execCtx)
}
}()

if e.ex != nil {
return e.ex.ExecWithNamedValue(ctx, execCtx, f)
}

return f(ctx, execCtx.Query, execCtx.NamedValues)
}

// ExecWithValue
func (e *XAExecutor) ExecWithValue(ctx context.Context, execCtx *types.ExecContext, f exec.CallbackWithNamedValue) (types.ExecResult, error) {
for _, hook := range e.hooks {
hook.Before(ctx, execCtx)
}

defer func() {
for _, hook := range e.hooks {
hook.After(ctx, execCtx)
}
}()

if e.ex != nil {
return e.ex.ExecWithValue(ctx, execCtx, f)
}

nvargs := make([]driver.NamedValue, len(execCtx.Values))
for i, value := range execCtx.Values {
nvargs = append(nvargs, driver.NamedValue{
Value: value,
Ordinal: i,
})
}
execCtx.NamedValues = nvargs

return f(ctx, execCtx.Query, execCtx.NamedValues)
}

+ 6
- 1
pkg/datasource/sql/mock/mock_datasource_manager.go View File

@@ -38,6 +38,7 @@ import (
type MockDataSourceManager struct {
ctrl *gomock.Controller
recorder *MockDataSourceManagerMockRecorder
branchType branch.BranchType
}

// MockDataSourceManagerMockRecorder is the mock recorder for MockDataSourceManager.
@@ -131,9 +132,13 @@ func (mr *MockDataSourceManagerMockRecorder) CreateTableMetaCache(ctx, resID, db
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateTableMetaCache", reflect.TypeOf((*MockDataSourceManager)(nil).CreateTableMetaCache), ctx, resID, dbType, db)
}

func (m *MockDataSourceManager) SetBranchType(branchType branch.BranchType) {
m.branchType = branchType
}

// GetBranchType mocks base method.
func (m *MockDataSourceManager) GetBranchType() branch.BranchType {
return branch.BranchTypeAT
return m.branchType
}

// GetBranchType indicates an expected call of GetBranchType.


+ 0
- 2
pkg/datasource/sql/plugin.go View File

@@ -20,7 +20,6 @@ package sql
import (
"github.com/seata/seata-go/pkg/datasource/sql/exec"
"github.com/seata/seata-go/pkg/datasource/sql/exec/at"
"github.com/seata/seata-go/pkg/datasource/sql/exec/xa"
"github.com/seata/seata-go/pkg/datasource/sql/hook"
"github.com/seata/seata-go/pkg/datasource/sql/types"
"github.com/seata/seata-go/pkg/datasource/sql/undo"
@@ -42,7 +41,6 @@ func hookRegister() {

func executorRegister() {
at.Init()
xa.Init()
}

func undoInit() {


+ 0
- 31
pkg/datasource/sql/root_context.go View File

@@ -1,31 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package sql

import (
"github.com/seata/seata-go/pkg/protocol/branch"
)

type RootContext interface {
RootContext()
SetDefaultBranchType(branchType branch.BranchType)
GetXID() string
Bind(xid string)
GetTimeout() (int, bool)
SetTimeout(timeout int)
}

+ 34
- 22
pkg/datasource/sql/tx.go View File

@@ -22,17 +22,16 @@ import (
"database/sql/driver"
"fmt"
"sync"
"time"

"github.com/seata/seata-go/pkg/datasource/sql/datasource"
"github.com/seata/seata-go/pkg/datasource/sql/types"
"github.com/seata/seata-go/pkg/protocol/branch"
"github.com/seata/seata-go/pkg/rm"
"github.com/seata/seata-go/pkg/util/backoff"
"github.com/seata/seata-go/pkg/util/log"

"github.com/seata/seata-go/pkg/datasource/sql/types"
)

const REPORT_RETRY_COUNT = 5

var (
hl sync.RWMutex
txHooks []txHook
@@ -146,19 +145,31 @@ func (tx *Tx) commitOnLocal() error {

// register
func (tx *Tx) register(ctx *types.TransactionContext) error {
if !ctx.HasUndoLog() || !ctx.HasLockKey() {
if ctx.TransactionMode.BranchType() == branch.BranchTypeUnknow {
return nil
}
lockKey := ""
for k, _ := range ctx.LockKeys {
lockKey += k + ";"
if ctx.TransactionMode.BranchType() == branch.BranchTypeAT && !ctx.HasUndoLog() || !ctx.HasLockKey() {
return nil
}

request := rm.BranchRegisterParam{
Xid: ctx.XID,
BranchType: ctx.TransactionMode.BranchType(),
ResourceId: ctx.ResourceID,
LockKeys: lockKey,
}

var lockKey string
if ctx.TransactionMode == types.ATMode {
if !ctx.HasUndoLog() || !ctx.HasLockKey() {
return nil
}
for k, _ := range ctx.LockKeys {
lockKey += k + ";"
}
request.LockKeys = lockKey
}

dataSourceManager := datasource.GetDataSourceManager(ctx.TransactionMode.BranchType())
branchId, err := dataSourceManager.BranchRegister(context.Background(), request)
if err != nil {
@@ -184,21 +195,22 @@ func (tx *Tx) report(success bool) error {
if dataSourceManager == nil {
return fmt.Errorf("get dataSourceManager failed")
}
retry := REPORT_RETRY_COUNT
for retry > 0 {
err := dataSourceManager.BranchReport(context.Background(), request)
if err != nil {
retry--
log.Infof("Failed to report [%s / %s] commit done [%s] Retry Countdown: %s", tx.tranCtx.BranchID, tx.tranCtx.XID, success, retry)
if retry == 0 {
log.Errorf("Failed to report branch status: %s", err.Error())
return err
}
} else {
return nil

retry := backoff.New(context.Background(), backoff.Config{
MinBackoff: 100 * time.Millisecond,
MaxBackoff: 200 * time.Millisecond,
MaxRetries: 5,
})

var err error
for retry.Ongoing() {
if err = dataSourceManager.BranchReport(context.Background(), request); err == nil {
break
}
log.Infof("Failed to report [%s / %s] commit done [%s] Retry Countdown: %s", tx.tranCtx.BranchID, tx.tranCtx.XID, success, retry)
retry.Wait()
}
return nil
return err
}

func getStatus(success bool) branch.BranchStatus {


+ 1
- 0
pkg/datasource/sql/tx_at.go View File

@@ -19,6 +19,7 @@ package sql

import (
"github.com/pkg/errors"

"github.com/seata/seata-go/pkg/datasource/sql/undo"
)



+ 5
- 12
pkg/datasource/sql/tx_xa.go View File

@@ -17,7 +17,6 @@

package sql

// XATx
type XATx struct {
tx *Tx
}
@@ -32,20 +31,14 @@ func (tx *XATx) Commit() error {
}

func (tx *XATx) Rollback() error {
err := tx.tx.Rollback()
if err != nil {

originTx := tx.tx

if originTx.tranCtx.OpenGlobalTransaction() && originTx.tranCtx.IsBranchRegistered() {
originTx.report(false)
}
originTx := tx.tx
if originTx.tranCtx.OpenGlobalTransaction() && originTx.tranCtx.IsBranchRegistered() {
return originTx.report(false)
}

return err
return nil
}

// commitOnXA
// commitOnXA commit xa and register branch transaction
func (tx *XATx) commitOnXA() error {
return nil
}

+ 3
- 2
pkg/datasource/sql/types/types.go View File

@@ -22,9 +22,9 @@ import (
"fmt"
"strings"

"github.com/seata/seata-go/pkg/protocol/branch"

"github.com/google/uuid"

"github.com/seata/seata-go/pkg/protocol/branch"
)

type DBType int16
@@ -76,6 +76,7 @@ const (
DBTypePostgreSQL
DBTypeSQLServer
DBTypeOracle
DBTypeMARIADB

BranchPhase_Unknown = 0
BranchPhase_Done = 1


+ 29
- 0
pkg/datasource/sql/util/convert.go View File

@@ -28,8 +28,10 @@ import (
"database/sql/driver"
"errors"
"fmt"
"math"
"reflect"
"strconv"
"strings"
"time"
)

@@ -375,3 +377,30 @@ type decimalCompose interface {
// represented then an error should be returned.
Compose(form byte, negative bool, coefficient []byte, exponent int32) error
}

// ConvertDbVersion convert a string db version to a number.
func ConvertDbVersion(version string) (int, error) {
parts := strings.Split(version, ".")
size := len(parts)
maxVersionDot := 3
if size > maxVersionDot+1 {
return 0, fmt.Errorf("incompatible version format: %s", version)
}

var res int
for idx, part := range parts {
if partInt, err := strconv.Atoi(part); err == nil {
res += calculatePartValue(partInt, size, idx)
} else {
subParts := strings.Split(part, "-")
if subPartInt, err := strconv.Atoi(subParts[0]); err == nil {
res += calculatePartValue(subPartInt, size, idx)
}
}
}
return res, nil
}

func calculatePartValue(partNumeric, size, index int) int {
return partNumeric * int(math.Pow(100, float64(size-index)))
}

pkg/datasource/sql/conn_test.go → pkg/datasource/sql/util/convert_test.go View File

@@ -15,22 +15,27 @@
* limitations under the License.
*/

package sql
package util

import (
"testing"

"github.com/seata/seata-go/pkg/datasource/sql/exec"
"github.com/seata/seata-go/pkg/datasource/sql/exec/xa"
"github.com/seata/seata-go/pkg/datasource/sql/types"
"github.com/stretchr/testify/assert"
)

func TestConn_BuildXAExecutor(t *testing.T) {
executor, err := exec.BuildExecutor(types.DBTypeMySQL, types.XAMode, "SELECT * FROM user")
func TestConvertDbVersion(t *testing.T) {
version1 := "3.1.2"
v1Int, err1 := ConvertDbVersion(version1)
assert.NoError(t, err1)

assert.NoError(t, err)
version2 := "3.1.3"
v2Int, err2 := ConvertDbVersion(version2)
assert.NoError(t, err2)

_, ok := executor.(*xa.XAExecutor)
assert.True(t, ok, "need xa executor")
assert.Less(t, v1Int, v2Int)

version3 := "3.1.3"
v3Int, err3 := ConvertDbVersion(version3)
assert.NoError(t, err3)
assert.Equal(t, v2Int, v3Int)
}

pkg/datasource/sql/exec/mysql_xa_resource.go → pkg/datasource/sql/xa/mysql_xa_connection.go View File

@@ -15,24 +15,27 @@
* limitations under the License.
*/

package exec
package xa

import (
"context"
"database/sql/driver"
"errors"
"fmt"
"io"
"strings"
"time"

"github.com/pkg/errors"
)

type MysqlXAConn struct {
driver.Conn
}

func (c *MysqlXAConn) Commit(xid string, onePhase bool) error {
func NewMysqlXaConn(conn driver.Conn) *MysqlXAConn {
return &MysqlXAConn{Conn: conn}
}

func (c *MysqlXAConn) Commit(ctx context.Context, xid string, onePhase bool) error {
var sb strings.Builder
sb.WriteString("XA COMMIT ")
sb.WriteString(xid)
@@ -41,33 +44,33 @@ func (c *MysqlXAConn) Commit(xid string, onePhase bool) error {
}

conn, _ := c.Conn.(driver.ExecerContext)
_, err := conn.ExecContext(context.TODO(), sb.String(), nil)
_, err := conn.ExecContext(ctx, sb.String(), nil)
return err
}

func (c *MysqlXAConn) End(xid string, flags int) error {
func (c *MysqlXAConn) End(ctx context.Context, xid string, flags int) error {
var sb strings.Builder
sb.WriteString("XA END ")
sb.WriteString(xid)

switch flags {
case TMSUCCESS:
case TMSuccess:
break
case TMSUSPEND:
case TMSuspend:
sb.WriteString(" SUSPEND")
break
case TMFAIL:
case TMFail:
break
default:
return errors.New("invalid arguments")
}

conn, _ := c.Conn.(driver.ExecerContext)
_, err := conn.ExecContext(context.TODO(), sb.String(), nil)
_, err := conn.ExecContext(ctx, sb.String(), nil)
return err
}

func (c *MysqlXAConn) Forget(xid string) error {
func (c *MysqlXAConn) Forget(ctx context.Context, xid string) error {
// mysql doesn't support this
return errors.New("mysql doesn't support this")
}
@@ -78,29 +81,29 @@ func (c *MysqlXAConn) GetTransactionTimeout() time.Duration {

// IsSameRM is called to determine if the resource manager instance represented by the target object
// is the same as the resource manager instance represented by the parameter xares.
func (c *MysqlXAConn) IsSameRM(xares XAResource) bool {
func (c *MysqlXAConn) IsSameRM(ctx context.Context, xares XAResource) bool {
// todo: the fn depends on the driver.Conn, but it doesn't support
return false
}

func (c *MysqlXAConn) XAPrepare(xid string) error {
func (c *MysqlXAConn) XAPrepare(ctx context.Context, xid string) error {
var sb strings.Builder
sb.WriteString("XA PREPARE ")
sb.WriteString(xid)

conn, _ := c.Conn.(driver.ExecerContext)
_, err := conn.ExecContext(context.TODO(), sb.String(), nil)
_, err := conn.ExecContext(ctx, sb.String(), nil)
return err
}

// Recover Obtains a list of prepared transaction branches from a resource manager.
// The transaction manager calls this method during recovery to obtain the list of transaction branches
// that are currently in prepared or heuristically completed states.
func (c *MysqlXAConn) Recover(flag int) (xids []string, err error) {
startRscan := (flag & TMSTARTRSCAN) > 0
endRscan := (flag & TMENDRSCAN) > 0
func (c *MysqlXAConn) Recover(ctx context.Context, flag int) (xids []string, err error) {
startRscan := (flag & TMStartRScan) > 0
endRscan := (flag & TMEndRScan) > 0

if !startRscan && !endRscan && flag != TMNOFLAGS {
if !startRscan && !endRscan && flag != TMNoFlags {
return nil, errors.New("invalid arguments")
}

@@ -109,7 +112,7 @@ func (c *MysqlXAConn) Recover(flag int) (xids []string, err error) {
}

conn := c.Conn.(driver.QueryerContext)
res, err := conn.QueryContext(context.TODO(), "XA RECOVER", nil)
res, err := conn.QueryContext(ctx, "XA RECOVER", nil)
if err != nil {
return nil, err
}
@@ -133,13 +136,13 @@ func (c *MysqlXAConn) Recover(flag int) (xids []string, err error) {
return xids, err
}

func (c *MysqlXAConn) Rollback(xid string) error {
func (c *MysqlXAConn) Rollback(ctx context.Context, xid string) error {
var sb strings.Builder
sb.WriteString("XA ROLLBACK ")
sb.WriteString(xid)

conn, _ := c.Conn.(driver.ExecerContext)
_, err := conn.ExecContext(context.TODO(), sb.String(), nil)
_, err := conn.ExecContext(ctx, sb.String(), nil)
return err
}

@@ -147,25 +150,25 @@ func (c *MysqlXAConn) SetTransactionTimeout(duration time.Duration) bool {
return false
}

func (c *MysqlXAConn) Start(xid string, flags int) error {
func (c *MysqlXAConn) Start(ctx context.Context, xid string, flags int) error {
var sb strings.Builder
sb.WriteString("XA START")
sb.WriteString(xid)

switch flags {
case TMJOIN:
case TMJoin:
sb.WriteString(" JOIN")
break
case TMRESUME:
case TMResume:
sb.WriteString(" RESUME")
break
case TMNOFLAGS:
case TMNoFlags:
break
default:
return errors.New("invalid arguments")
}

conn, _ := c.Conn.(driver.ExecerContext)
_, err := conn.ExecContext(context.TODO(), sb.String(), nil)
_, err := conn.ExecContext(ctx, sb.String(), nil)
return err
}

pkg/datasource/sql/exec/mysql_xa_resource_test.go → pkg/datasource/sql/xa/mysql_xa_connection_test.go View File

@@ -15,19 +15,18 @@
* limitations under the License.
*/

package exec
package xa

import (
"context"
"database/sql/driver"
"fmt"
"errors"
"io"
"reflect"
"strings"
"testing"

"github.com/golang/mock/gomock"
"github.com/pkg/errors"

"github.com/seata/seata-go/pkg/datasource/sql/mock"
)
@@ -78,7 +77,7 @@ func TestMysqlXAConn_Commit(t *testing.T) {
c := &MysqlXAConn{
Conn: mockConn,
}
if err := c.Commit(tt.input.xid, tt.input.onePhase); (err != nil) != tt.wantErr {
if err := c.Commit(context.Background(), tt.input.xid, tt.input.onePhase); (err != nil) != tt.wantErr {
t.Errorf("Commit() error = %v, wantErr %v", err, tt.wantErr)
}
})
@@ -102,7 +101,7 @@ func TestMysqlXAConn_End(t *testing.T) {
name: "tm success",
input: args{
xid: "xid",
flags: TMSUCCESS,
flags: TMSuccess,
},
wantErr: false,
},
@@ -110,7 +109,7 @@ func TestMysqlXAConn_End(t *testing.T) {
name: "tm failed",
input: args{
xid: "xid",
flags: TMFAIL,
flags: TMFail,
},
wantErr: false,
},
@@ -124,7 +123,7 @@ func TestMysqlXAConn_End(t *testing.T) {
c := &MysqlXAConn{
Conn: mockConn,
}
if err := c.End(tt.input.xid, tt.input.flags); (err != nil) != tt.wantErr {
if err := c.End(context.Background(), tt.input.xid, tt.input.flags); (err != nil) != tt.wantErr {
t.Errorf("End() error = %v, wantErr %v", err, tt.wantErr)
}
})
@@ -148,7 +147,7 @@ func TestMysqlXAConn_Start(t *testing.T) {
name: "normal start",
input: args{
xid: "xid",
flags: TMNOFLAGS,
flags: TMNoFlags,
},
wantErr: false,
},
@@ -161,7 +160,7 @@ func TestMysqlXAConn_Start(t *testing.T) {
c := &MysqlXAConn{
Conn: mockConn,
}
if err := c.Start(tt.input.xid, tt.input.flags); (err != nil) != tt.wantErr {
if err := c.Start(context.Background(), tt.input.xid, tt.input.flags); (err != nil) != tt.wantErr {
t.Errorf("Start() error = %v, wantErr %v", err, tt.wantErr)
}
})
@@ -196,7 +195,7 @@ func TestMysqlXAConn_XAPrepare(t *testing.T) {
c := &MysqlXAConn{
Conn: mockConn,
}
if err := c.XAPrepare(tt.input.xid); (err != nil) != tt.wantErr {
if err := c.XAPrepare(context.Background(), tt.input.xid); (err != nil) != tt.wantErr {
t.Errorf("XAPrepare() error = %v, wantErr %v", err, tt.wantErr)
}
})
@@ -219,7 +218,7 @@ func TestMysqlXAConn_Recover(t *testing.T) {
{
name: "normal recover",
args: args{
flag: TMSTARTRSCAN | TMENDRSCAN,
flag: TMStartRScan | TMEndRScan,
},
want: []string{"xid", "another_xid"},
wantErr: false,
@@ -227,14 +226,14 @@ func TestMysqlXAConn_Recover(t *testing.T) {
{
name: "invalid flag for recover",
args: args{
flag: TMFAIL,
flag: TMFail,
},
wantErr: true,
},
{
name: "valid flag for recover but don't scan",
args: args{
flag: TMENDRSCAN,
flag: TMEndRScan,
},
want: nil,
wantErr: false,
@@ -257,7 +256,7 @@ func TestMysqlXAConn_Recover(t *testing.T) {
c := &MysqlXAConn{
Conn: mockConn,
}
got, err := c.Recover(tt.args.flag)
got, err := c.Recover(context.Background(), tt.args.flag)
if (err != nil) != tt.wantErr {
t.Errorf("Recover() error = %v, wantErr %v", err, tt.wantErr)
return
@@ -295,10 +294,7 @@ func (m *mysqlMockRows) Next(dest []driver.Value) error {
}
return b
}

cnt := min(len(m.data[0]), len(dest))
fmt.Printf("cnt: %d", cnt)

for i := 0; i < cnt; i++ {
dest[i] = m.data[m.idx][i]
}

pkg/datasource/sql/conn/oracle.go → pkg/datasource/sql/xa/oracle_xa_connection.go View File

@@ -15,7 +15,7 @@
* limitations under the License.
*/

package conn
package xa

import (
"context"

pkg/datasource/sql/conn/oracle_test.go → pkg/datasource/sql/xa/oracle_xa_connection_test.go View File

@@ -15,7 +15,7 @@
* limitations under the License.
*/

package conn
package xa

import (
"database/sql/driver"

+ 0
- 26
pkg/datasource/sql/xa/xa_connection.go View File

@@ -1,26 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package xa

import (
"github.com/seata/seata-go/pkg/datasource/sql/exec"
)

type XAConnection interface {
getXAResource() (exec.XAResource, error)
}

+ 0
- 325
pkg/datasource/sql/xa/xa_connection_proxy.go View File

@@ -1,325 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package xa

import (
"context"
"fmt"
"time"

"github.com/seata/seata-go/pkg/datasource/sql"
"github.com/seata/seata-go/pkg/datasource/sql/datasource"
"github.com/seata/seata-go/pkg/datasource/sql/exec"
"github.com/seata/seata-go/pkg/protocol/branch"
"github.com/seata/seata-go/pkg/protocol/message"
"github.com/seata/seata-go/pkg/rm"
)

type ConnectionProxyXA struct {
xaBranchXid *XABranchXid
currentAutoCommitStatus bool `default:"true"`
xaActive bool `default:"false"`
kept bool `default:"false"`
rollBacked bool `default:"false"`
branchRegisterTime int64 `default:"0"`
prepareTime int64 `default:"0"`
timeout int `default:"0"`
proxyShouldBeHeld bool `default:"false"`
originalConnection sql.Conn
xaConnection XAConnection
xaResource exec.XAResource
resource sql.BaseDataSourceResource
xid string
}

const timeout int = 60000

func NewConnectionProxyXA(originalConnection sql.Conn, xaConnection XAConnection, resource sql.BaseDataSourceResource, xid string) (*ConnectionProxyXA, error) {
connectionProxyXA := &ConnectionProxyXA{}

connectionProxyXA.originalConnection = originalConnection
connectionProxyXA.xaConnection = xaConnection
connectionProxyXA.resource = resource
connectionProxyXA.xid = xid

connectionProxyXA.proxyShouldBeHeld = connectionProxyXA.resource.IsShouldBeHeld()

xaResource, err := xaConnection.getXAResource()
if err != nil {
return nil, fmt.Errorf("get xa resource failed")
} else {
connectionProxyXA.xaResource = xaResource
}
var rootContext sql.RootContext
transactionTimeout, ok := rootContext.GetTimeout()
if !ok {
transactionTimeout = timeout
}
if transactionTimeout < timeout {
transactionTimeout = timeout
}
connectionProxyXA.timeout = transactionTimeout
connectionProxyXA.currentAutoCommitStatus = connectionProxyXA.originalConnection.GetAutoCommit()
if !connectionProxyXA.currentAutoCommitStatus {
return nil, fmt.Errorf("connection[autocommit=false] as default is NOT supported")
}

return connectionProxyXA, nil
}

func (c *ConnectionProxyXA) keepIfNecessary() {
if c.ShouldBeHeld() {
c.resource.Hold(c.xaBranchXid.String(), c)
}
}

func (c *ConnectionProxyXA) releaseIfNecessary() {
if c.ShouldBeHeld() {
if c.xaBranchXid == nil {
if c.IsHeld() {
c.resource.Release(c.xaBranchXid.String(), c)
}
}
}
}

func (c *ConnectionProxyXA) XaCommit(xid string, branchId int64) error {
xaXid := Build(xid, branchId)
err := c.xaResource.Commit(xaXid.String(), false)
c.releaseIfNecessary()
return err
}

func (c *ConnectionProxyXA) XaRollbackByBranchId(xid string, branchId int64) {
xaXid := Build(xid, branchId)
c.XaRollback(xaXid)
}

func (c *ConnectionProxyXA) XaRollback(xaXid XAXid) error {
err := c.xaResource.Rollback(xaXid.GetGlobalXid())
c.releaseIfNecessary()
return err
}

func (c *ConnectionProxyXA) SetAutoCommit(autoCommit bool) error {
if c.currentAutoCommitStatus == autoCommit {
return nil
}
if autoCommit {
if c.xaActive {
_ = c.Commit()
}
} else {
if c.xaActive {
return fmt.Errorf("should NEVER happen: setAutoCommit from true to false while xa branch is active")
}

c.branchRegisterTime = time.Now().UnixMilli()
var branchRegisterParam rm.BranchRegisterParam
branchRegisterParam.BranchType = branch.BranchTypeXA
branchRegisterParam.ResourceId = c.resource.GetResourceId()
branchRegisterParam.Xid = c.xid
branchId, err := datasource.GetDataSourceManager(branch.BranchTypeXA).BranchRegister(context.TODO(), branchRegisterParam)
if err != nil {
c.cleanXABranchContext()
return fmt.Errorf("failed to register xa branch [%v]", c.xid)
}
c.xaBranchXid = Build(c.xid, branchId)
c.keepIfNecessary()
err = c.start()
if err != nil {
c.cleanXABranchContext()
return fmt.Errorf("failed to start xa branch [%v]", c.xid)
}
c.xaActive = true
}
c.currentAutoCommitStatus = autoCommit
return nil
}

func (c *ConnectionProxyXA) GetAutoCommit() bool {
return c.currentAutoCommitStatus
}

func (c *ConnectionProxyXA) Commit() error {
if c.currentAutoCommitStatus {
return nil
}
if !c.xaActive || c.xaBranchXid == nil {
return fmt.Errorf("should NOT commit on an inactive session")
}
now := time.Now().UnixMilli()
if c.end(exec.TMSUCCESS) != nil {
return c.commitErrorHandle()
}
if c.checkTimeout(now) != nil {
return c.commitErrorHandle()
}
if c.xaResource.XAPrepare(c.xaBranchXid.String()) != nil {
return c.commitErrorHandle()
}
return nil
}

func (c *ConnectionProxyXA) commitErrorHandle() error {
req := message.BranchReportRequest{
BranchType: branch.BranchTypeXA,
Xid: c.xid,
BranchId: c.xaBranchXid.GetBranchId(),
Status: branch.BranchStatusPhaseoneFailed,
ApplicationData: nil,
ResourceId: c.resource.GetResourceId(),
}
if datasource.NewBasicSourceManager().BranchReport(context.TODO(), req) != nil {
c.cleanXABranchContext()
return fmt.Errorf("Failed to report XA branch commit-failure on [%v] - [%v]", c.xid, c.xaBranchXid.GetBranchId())
}
c.cleanXABranchContext()
return fmt.Errorf("Failed to end(TMSUCCESS)/prepare xa branch on [%v] - [%v]", c.xid, c.xaBranchXid.GetBranchId())
}

func (c *ConnectionProxyXA) Rollback() error {
if c.currentAutoCommitStatus {
return nil
}
if !c.xaActive || c.xaBranchXid == nil {
return fmt.Errorf("should NOT rollback on an inactive session")
}
if !c.rollBacked {
if c.xaResource.End(c.xaBranchXid.String(), exec.TMFAIL) != nil {
return c.rollbackErrorHandle()
}
if c.XaRollback(c.xaBranchXid) != nil {
c.cleanXABranchContext()
return c.rollbackErrorHandle()
}
req := message.BranchReportRequest{
BranchType: branch.BranchTypeXA,
Xid: c.xid,
BranchId: c.xaBranchXid.GetBranchId(),
Status: branch.BranchStatusPhaseoneFailed,
ApplicationData: nil,
ResourceId: c.resource.GetResourceId(),
}
if datasource.NewBasicSourceManager().BranchReport(context.TODO(), req) != nil {
c.cleanXABranchContext()
return fmt.Errorf("failed to report XA branch commit-failure on [%v] - [%v]", c.xid, c.xaBranchXid.GetBranchId())
}
}
c.cleanXABranchContext()
return nil
}

func (c *ConnectionProxyXA) rollbackErrorHandle() error {
return fmt.Errorf("failed to end(TMFAIL) xa branch on [%v] - [%v]", c.xid, c.xaBranchXid.GetBranchId())
}

func (c *ConnectionProxyXA) start() error {
err := c.xaResource.Start(c.xaBranchXid.String(), exec.TMNOFLAGS)
if err := c.termination(c.xaBranchXid.String()); err != nil {
c.xaResource.End(c.xaBranchXid.String(), exec.TMFAIL)
c.XaRollback(c.xaBranchXid)
return err
}
return err
}

func (c *ConnectionProxyXA) end(flags int) error {
err := c.termination(c.xaBranchXid.String())
if err != nil {
return err
}
err = c.xaResource.End(c.xaBranchXid.String(), flags)
if err != nil {
return err
}
return nil
}

func (c *ConnectionProxyXA) cleanXABranchContext() {
c.branchRegisterTime = 0
c.prepareTime = 0
c.timeout = 0
c.xaActive = false
if !c.IsHeld() {
c.xaBranchXid = nil
}
}

func (c *ConnectionProxyXA) checkTimeout(now int64) error {
if now-c.branchRegisterTime > int64(c.timeout) {
c.XaRollback(c.xaBranchXid)
return fmt.Errorf("XA branch timeout error")
}
return nil
}

func (c *ConnectionProxyXA) Close() error {
c.rollBacked = false
if c.IsHeld() && c.ShouldBeHeld() {
return nil
}
c.cleanXABranchContext()
if err := c.originalConnection.Close(); err != nil {
return err
}
return nil
}

func (c *ConnectionProxyXA) CloseForce() error {
physicalConn := c.originalConnection
if err := physicalConn.Close(); err != nil {
return err
}
c.rollBacked = false
c.cleanXABranchContext()
if err := c.originalConnection.Close(); err != nil {
return err
}
c.releaseIfNecessary()
return nil
}

func (c *ConnectionProxyXA) SetHeld(kept bool) {
c.kept = kept
}

func (c *ConnectionProxyXA) IsHeld() bool {
return c.kept
}

func (c *ConnectionProxyXA) ShouldBeHeld() bool {
return c.proxyShouldBeHeld || c.resource.GetDB() != nil
}

func (c *ConnectionProxyXA) GetPrepareTime() int64 {
return c.prepareTime
}

func (c *ConnectionProxyXA) setPrepareTime(prepareTime int64) {
c.prepareTime = prepareTime
}

func (c *ConnectionProxyXA) termination(xaBranchXid string) error {
branchStatus, err := c.resource.GetBranchStatus(xaBranchXid)
if err != nil {
c.releaseIfNecessary()
return fmt.Errorf("failed xa branch [%v] the global transaction has finish, branch status: [%v]", c.xid, branchStatus)
}
return nil
}

+ 71
- 0
pkg/datasource/sql/xa/xa_resource.go View File

@@ -0,0 +1,71 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package xa

import (
"context"
"time"
)

const (
// TMEndRScan ends a recovery scan.
TMEndRScan = 0x00800000
// TMFail disassociates the caller and marks the transaction branch
// rollback-only.
TMFail = 0x20000000

// TMJoin joining existing transaction branch.
TMJoin = 0x00200000

// TMNoFlags indicate no flags value is selected.
TMNoFlags = 0x00000000

// TMOnePhase using one-phase optimization.
TMOnePhase = 0x40000000

// TMResume is resuming association with a suspended transaction branch.
TMResume = 0x08000000

// TMStartRScan starts a recovery scan.
TMStartRScan = 0x01000000

// TMSuccess disassociates caller from a transaction branch.
TMSuccess = 0x04000000

// TMSuspend is suspending (not ending) its association with a transaction branch.
TMSuspend = 0x02000000

// XAReadOnly the transaction branch has been read-only and has been committed.
XAReadOnly = 0x00000003

// XAOk The transaction work has been prepared normally.
XAOk = 0
)

type XAResource interface {
Commit(ctx context.Context, xid string, onePhase bool) error
End(ctx context.Context, xid string, flags int) error
Forget(ctx context.Context, xid string) error
GetTransactionTimeout() time.Duration
IsSameRM(ctx context.Context, resource XAResource) bool
XAPrepare(ctx context.Context, xid string) error
Recover(ctx context.Context, flag int) ([]string, error)
Rollback(ctx context.Context, xid string) error
SetTransactionTimeout(duration time.Duration) bool
Start(ctx context.Context, xid string, flags int) error
}

pkg/datasource/sql/exec/xa/default.go → pkg/datasource/sql/xa/xa_resource_factory.go View File

@@ -18,12 +18,31 @@
package xa

import (
"github.com/seata/seata-go/pkg/datasource/sql/exec"
"database/sql/driver"
"fmt"

"github.com/seata/seata-go/pkg/datasource/sql/types"
"github.com/seata/seata-go/pkg/util/log"
)

func Init() {
exec.RegisterXAExecutor(types.DBTypeMySQL, func() exec.SQLExecutor {
return &XAExecutor{}
})
// CreateXAResource create a connection for xa with the different db type.
// Such as mysql, oracle, MARIADB, POSTGRESQL
func CreateXAResource(conn driver.Conn, dbType types.DBType) (XAResource, error) {
var err error
var xaConnection XAResource
switch dbType {
case types.DBTypeMySQL:
xaConnection = NewMysqlXaConn(conn)
case types.DBTypeOracle:
case types.DBTypePostgreSQL:
default:
err = fmt.Errorf("not support db type for :%s", dbType.String())
}

if err != nil {
log.Errorf(err.Error())
return nil, err
}

return xaConnection, nil
}

pkg/datasource/sql/xa/xa_branch_xid.go → pkg/datasource/sql/xa_branch_xid.go View File

@@ -15,7 +15,7 @@
* limitations under the License.
*/

package xa
package sql

import (
"strconv"
@@ -23,13 +23,12 @@ import (
)

const (
BranchIdPrefix = "-"
SeataXaXidFormatId = 9752
branchIdPrefix = "-"
)

type XABranchXid struct {
xid string
branchId int64
branchId uint64
globalTransactionId []byte
branchQualifier []byte
}
@@ -63,14 +62,10 @@ func (x *XABranchXid) GetGlobalXid() string {
return x.xid
}

func (x *XABranchXid) GetBranchId() int64 {
func (x *XABranchXid) GetBranchId() uint64 {
return x.branchId
}

func (x *XABranchXid) GetFormatId() int {
return SeataXaXidFormatId
}

func (x *XABranchXid) GetGlobalTransactionId() []byte {
return x.globalTransactionId
}
@@ -80,7 +75,7 @@ func (x *XABranchXid) GetBranchQualifier() []byte {
}

func (x *XABranchXid) String() string {
return x.xid + BranchIdPrefix + strconv.FormatInt(x.branchId, 10)
return x.xid + branchIdPrefix + strconv.FormatUint(x.branchId, 10)
}

func WithXid(xid string) Option {
@@ -89,7 +84,7 @@ func WithXid(xid string) Option {
}
}

func WithBranchId(branchId int64) Option {
func WithBranchId(branchId uint64) Option {
return func(x *XABranchXid) {
x.branchId = branchId
}
@@ -113,7 +108,7 @@ func encode(x *XABranchXid) {
}

if x.branchId != 0 {
x.branchQualifier = []byte(BranchIdPrefix + strconv.FormatInt(x.branchId, 10))
x.branchQualifier = []byte(branchIdPrefix + strconv.FormatUint(x.branchId, 10))
}
}

@@ -123,7 +118,7 @@ func decode(x *XABranchXid) {
}

if len(x.branchQualifier) > 0 {
branchId := strings.TrimLeft(string(x.branchQualifier), BranchIdPrefix)
x.branchId, _ = strconv.ParseInt(branchId, 10, 64)
branchId := strings.TrimLeft(string(x.branchQualifier), branchIdPrefix)
x.branchId, _ = strconv.ParseUint(branchId, 10, 64)
}
}

pkg/datasource/sql/xa/xa_branch_xid_test.go → pkg/datasource/sql/xa_branch_xid_test.go View File

@@ -15,7 +15,7 @@
* limitations under the License.
*/

package xa
package sql

import (
"testing"
@@ -25,8 +25,8 @@ import (

func TestXABranchXidBuild(t *testing.T) {
xid := "111"
branchId := int64(222)
x := Build(xid, branchId)
branchId := uint64(222)
x := XaIdBuild(xid, branchId)
assert.Equal(t, x.GetGlobalXid(), xid)
assert.Equal(t, x.GetBranchId(), branchId)

@@ -36,11 +36,11 @@ func TestXABranchXidBuild(t *testing.T) {

func TestXABranchXidBuildWithByte(t *testing.T) {
xid := []byte("111")
branchId := []byte(BranchIdPrefix + "222")
x := BuildWithByte(xid, branchId)
branchId := []byte(branchIdPrefix + "222")
x := XaIdBuildWithByte(xid, branchId)
assert.Equal(t, x.GetGlobalTransactionId(), xid)
assert.Equal(t, x.GetBranchQualifier(), branchId)

assert.Equal(t, x.GetGlobalXid(), "111")
assert.Equal(t, x.GetBranchId(), int64(222))
assert.Equal(t, x.GetBranchId(), uint64(222))
}

+ 245
- 0
pkg/datasource/sql/xa_resource_manager.go View File

@@ -0,0 +1,245 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package sql

import (
"context"
"database/sql"
"errors"
"flag"
"fmt"
"sync"
"time"

"github.com/bluele/gcache"

"github.com/seata/seata-go/pkg/datasource/sql/datasource"
"github.com/seata/seata-go/pkg/datasource/sql/types"
"github.com/seata/seata-go/pkg/protocol/branch"
"github.com/seata/seata-go/pkg/rm"
"github.com/seata/seata-go/pkg/util/log"
)

var branchStatusCache gcache.Cache

type XAConnConf struct {
XaBranchExecutionTimeout time.Duration `json:"xa_branch_execution_timeout" xml:"xa_branch_execution_timeout" koanf:"xa_branch_execution_timeout"`
}

func (cfg *XAConnConf) RegisterFlagsWithPrefix(prefix string, f *flag.FlagSet) {
f.DurationVar(&cfg.XaBranchExecutionTimeout, prefix+".xa_branch_execution_timeout", time.Minute, "Undo log table name.")
}

type XAConfig struct {
xaConnConf XAConnConf
TwoPhaseHoldTime time.Duration `json:"two_phase_hold_time" yaml:"xa_two_phase_hold_time" koanf:"xa_two_phase_hold_time"`
}

func (cfg *XAConfig) RegisterFlagsWithPrefix(prefix string, f *flag.FlagSet) {
f.DurationVar(&cfg.TwoPhaseHoldTime, prefix+".two_phase_hold_time", time.Millisecond*1000, "Undo log table name.")
cfg.xaConnConf.RegisterFlagsWithPrefix(prefix, f)
}

func InitXA(config XAConfig) *XAResourceManager {
xaSourceManager := &XAResourceManager{
resourceCache: sync.Map{},
basic: datasource.NewBasicSourceManager(),
rmRemoting: rm.GetRMRemotingInstance(),
config: config,
}

xaConnTimeout = config.xaConnConf.XaBranchExecutionTimeout

branchStatusCache = gcache.New(1024).LRU().Expiration(time.Minute * 10).Build()

rm.GetRmCacheInstance().RegisterResourceManager(xaSourceManager)

go xaSourceManager.xaTwoPhaseTimeoutChecker()

return xaSourceManager
}

type XAResourceManager struct {
config XAConfig
resourceCache sync.Map
basic *datasource.BasicSourceManager
rmRemoting *rm.RMRemoting
}

func (xaManager *XAResourceManager) xaTwoPhaseTimeoutChecker() {
var dbResource *DBResource
xaManager.resourceCache.Range(func(key, value any) bool {
if source, ok := value.(*DBResource); ok {
dbResource = source
}
return false
})

if dbResource.IsShouldBeHeld() {
ticker := time.NewTicker(time.Second)
for {
select {
case <-ticker.C:
xaManager.resourceCache.Range(func(key, value any) bool {
source, ok := value.(*DBResource)
if !ok {
return true
}
if source.IsShouldBeHeld() {
return true
}

source.GetKeeper().Range(func(key, value any) bool {
connectionXA, isConnectionXA := value.(*XAConn)
if !isConnectionXA {
return true
}

if time.Now().Sub(connectionXA.prepareTime) > xaManager.config.TwoPhaseHoldTime {
if err := connectionXA.CloseForce(); err != nil {
log.Errorf("Force close the xa xid:%s physical connection fail", connectionXA.txCtx.XID)
}
}
return true
})
return true
})
}
}

}

}

func (xaManager *XAResourceManager) GetBranchType() branch.BranchType {
return branch.BranchTypeXA
}

func (xaManager *XAResourceManager) GetCachedResources() *sync.Map {
return &xaManager.resourceCache
}

func (xaManager *XAResourceManager) RegisterResource(res rm.Resource) error {
xaManager.resourceCache.Store(res.GetResourceId(), res)
return xaManager.basic.RegisterResource(res)
}

func (xaManager *XAResourceManager) UnregisterResource(resource rm.Resource) error {
return xaManager.basic.UnregisterResource(resource)
}

func (xaManager *XAResourceManager) xaIDBuilder(xid string, branchId uint64) XAXid {
return XaIdBuild(xid, branchId)
}

func (xaManager *XAResourceManager) finishBranch(ctx context.Context, xaID XAXid, branchResource rm.BranchResource) (*XAConn, error) {
resource, ok := xaManager.resourceCache.Load(branchResource.ResourceId)
if !ok {
err := fmt.Errorf("unknow resource for rollback xa, resourceId: %s", branchResource.ResourceId)
log.Errorf(err.Error())
return nil, err
}

dbResource, ok := resource.(*DBResource)
if !ok {
err := fmt.Errorf("unknow resource for rollback xa, resourceId: %s", branchResource.ResourceId)
log.Errorf(err.Error())
return nil, err
}

connectionProxyXA, err := dbResource.ConnectionForXA(ctx, xaID)
if err != nil {
err := fmt.Errorf("get connection for rollback xa, resourceId: %s", branchResource.ResourceId)
log.Errorf(err.Error())
return nil, err
}

return connectionProxyXA, nil
}

func (xaManager *XAResourceManager) BranchCommit(ctx context.Context, branchResource rm.BranchResource) (branch.BranchStatus, error) {
xaID := xaManager.xaIDBuilder(branchResource.Xid, uint64(branchResource.BranchId))
connectionProxyXA, err := xaManager.finishBranch(ctx, xaID, branchResource)
if err != nil {
return branch.BranchStatusPhasetwoRollbackFailedUnretryable, err
}

if commitErr := connectionProxyXA.XaCommit(ctx, xaID.String(), branchResource.BranchId); commitErr != nil {
err := fmt.Errorf("rollback xa, resourceId: %s", branchResource.ResourceId)
log.Errorf(err.Error())
setBranchStatus(xaID.String(), branch.BranchStatusPhasetwoCommitted)
return branch.BranchStatusPhasetwoCommitFailedUnretryable, err
}

log.Infof("%s was committed", xaID.String())
return branch.BranchStatusPhasetwoCommitted, nil
}

func (xaManager *XAResourceManager) BranchRollback(ctx context.Context, branchResource rm.BranchResource) (branch.BranchStatus, error) {
xaID := xaManager.xaIDBuilder(branchResource.Xid, uint64(branchResource.BranchId))
connectionProxyXA, err := xaManager.finishBranch(ctx, xaID, branchResource)
if err != nil {
return branch.BranchStatusPhasetwoRollbackFailedUnretryable, err
}

if rollbackErr := connectionProxyXA.XaRollbackByBranchId(ctx, xaID.String(), branchResource.BranchId); rollbackErr != nil {
err := fmt.Errorf("rollback xa, resourceId: %s", branchResource.ResourceId)
log.Errorf(err.Error())
setBranchStatus(xaID.String(), branch.BranchStatusPhasetwoRollbacked)
return branch.BranchStatusPhasetwoRollbackFailedUnretryable, err
}

log.Infof("%s was rollback", xaID.String())
return branch.BranchStatusPhasetwoRollbacked, nil
}

func (xaManager *XAResourceManager) LockQuery(ctx context.Context, param rm.LockQueryParam) (bool, error) {
return false, nil
}

func (xaManager *XAResourceManager) BranchRegister(ctx context.Context, req rm.BranchRegisterParam) (int64, error) {
return xaManager.rmRemoting.BranchRegister(req)
}

func (xaManager *XAResourceManager) BranchReport(ctx context.Context, param rm.BranchReportParam) error {
return xaManager.rmRemoting.BranchReport(param)
}

func (xaManager *XAResourceManager) CreateTableMetaCache(ctx context.Context, resID string, dbType types.DBType, db *sql.DB) (datasource.TableMetaCache, error) {
return xaManager.basic.CreateTableMetaCache(ctx, resID, dbType, db)
}

func branchStatus(xaBranchXid string) (branch.BranchStatus, error) {
tmpBranchStatus, err := branchStatusCache.GetIFPresent(xaBranchXid)
if err != nil {
if errors.Is(err, gcache.KeyNotFoundError) {
return branch.BranchStatusUnknown, nil
}
return branch.BranchStatusUnknown, err
}

branchStatus, isBranchStatus := tmpBranchStatus.(branch.BranchStatus)
if !isBranchStatus {
return branch.BranchStatusUnknown, fmt.Errorf("branchId:%s get result isn't branch status", xaBranchXid)
}
return branchStatus, nil
}

func setBranchStatus(xaBranchXid string, status branch.BranchStatus) {
branchStatusCache.Set(xaBranchXid, status)
}

pkg/datasource/sql/xa/xa_xid.go → pkg/datasource/sql/xa_xid.go View File

@@ -15,16 +15,10 @@
* limitations under the License.
*/

package xa

type Xid interface {
GetFormatId() int
GetGlobalTransactionId() []byte
GetBranchQualifier() []byte
}
package sql

type XAXid interface {
Xid
GetGlobalXid() string
GetBranchId() int64
GetBranchId() uint64
String() string
}

pkg/datasource/sql/xa/xa_xid_builder.go → pkg/datasource/sql/xa_xid_builder.go View File

@@ -15,12 +15,12 @@
* limitations under the License.
*/

package xa
package sql

func Build(xid string, branchId int64) *XABranchXid {
func XaIdBuild(xid string, branchId uint64) *XABranchXid {
return NewXABranchXid(WithXid(xid), WithBranchId(branchId))
}

func BuildWithByte(globalTransactionId []byte, branchQualifier []byte) *XABranchXid {
func XaIdBuildWithByte(globalTransactionId []byte, branchQualifier []byte) *XABranchXid {
return NewXABranchXid(WithGlobalTransactionId(globalTransactionId), WithBranchQualifier(branchQualifier))
}

+ 36
- 65
pkg/protocol/branch/branch.go View File

@@ -35,71 +35,42 @@ const (
)

const (
/**
* The BranchStatus_Unknown.
* description:BranchStatus_Unknown branch status.
*/
BranchStatusUnknown = BranchStatus(0)

/**
* The BranchStatus_Registered.
* description:BranchStatus_Registered to TC.
*/
BranchStatusRegistered = BranchStatus(1)

/**
* The Phase one done.
* description:Branch logic is successfully done at phase one.
*/
BranchStatusPhaseoneDone = BranchStatus(2)

/**
* The Phase one failed.
* description:Branch logic is failed at phase one.
*/
BranchStatusPhaseoneFailed = BranchStatus(3)

/**
* The Phase one timeout.
* description:Branch logic is NOT reported for a timeout.
*/
BranchStatusPhaseoneTimeout = BranchStatus(4)

/**
* The Phase two committed.
* description:Commit logic is successfully done at phase two.
*/
BranchStatusPhasetwoCommitted = BranchStatus(5)

/**
* The Phase two commit failed retryable.
* description:Commit logic is failed but retryable.
*/
BranchStatusPhasetwoCommitFailedRetryable = BranchStatus(6)

/**
* The Phase two commit failed unretryable.
* description:Commit logic is failed and NOT retryable.
*/
BranchStatusPhasetwoCommitFailedUnretryable = BranchStatus(7)

/**
* The Phase two rollbacked.
* description:Rollback logic is successfully done at phase two.
*/
BranchStatusPhasetwoRollbacked = BranchStatus(8)

/**
* The Phase two rollback failed retryable.
* description:Rollback logic is failed but retryable.
*/
BranchStatusPhasetwoRollbackFailedRetryable = BranchStatus(9)

/**
* The Phase two rollback failed unretryable.
* description:Rollback logic is failed but NOT retryable.
*/
BranchStatusPhasetwoRollbackFailedUnretryable = BranchStatus(10)
// BranchStatusUnknown the BranchStatus_Unknown. description:BranchStatus_Unknown branch status.
BranchStatusUnknown = iota

// BranchStatusRegistered the BranchStatus_Registered. description:BranchStatus_Registered to TC.
BranchStatusRegistered

// BranchStatusPhaseoneDone the Phase one done. description:Branch logic is successfully done at phase one.
BranchStatusPhaseoneDone

// BranchStatusPhaseoneFailed the Phase one failed. description:Branch logic is failed at phase one.
BranchStatusPhaseoneFailed

// BranchStatusPhaseoneTimeout the Phase one timeout. description:Branch logic is NOT reported for a timeout.
BranchStatusPhaseoneTimeout

// BranchStatusPhasetwoCommitted the Phase two committed. description:Commit logic is successfully done at phase two.
BranchStatusPhasetwoCommitted

// BranchStatusPhasetwoCommitFailedRetryable the Phase two commit failed retryable. description:Commit logic is failed but retryable.
BranchStatusPhasetwoCommitFailedRetryable

// BranchStatusPhasetwoCommitFailedUnretryable the Phase two commit failed unretryable.
// description:Commit logic is failed and NOT retryable.
BranchStatusPhasetwoCommitFailedUnretryable

// BranchStatusPhasetwoRollbacked The Phase two rollbacked.
// description:Rollback logic is successfully done at phase two.
BranchStatusPhasetwoRollbacked

// BranchStatusPhasetwoRollbackFailedRetryable the Phase two rollback failed retryable.
// description:Rollback logic is failed but retryable.
BranchStatusPhasetwoRollbackFailedRetryable

// BranchStatusPhasetwoRollbackFailedUnretryable the Phase two rollback failed unretryable.
// description:Rollback logic is failed but NOT retryable.
BranchStatusPhasetwoRollbackFailedUnretryable
)

func (s BranchStatus) String() string {


+ 11
- 11
pkg/rm/rm_api.go View File

@@ -31,7 +31,7 @@ type Resource interface {
GetBranchType() branch.BranchType
}

// branch resource which contains branch to commit or rollback
// BranchResource contains branch to commit or rollback
type BranchResource struct {
BranchType branch.BranchType
Xid string
@@ -42,9 +42,9 @@ type BranchResource struct {

// ResourceManagerInbound Control a branch transaction commit or rollback
type ResourceManagerInbound interface {
// Commit a branch transaction
// BranchCommit commit a branch transaction
BranchCommit(ctx context.Context, resource BranchResource) (branch.BranchStatus, error)
// Rollback a branch transaction
// BranchRollback rollback a branch transaction
BranchRollback(ctx context.Context, resource BranchResource) (branch.BranchStatus, error)
}

@@ -75,13 +75,13 @@ type LockQueryParam struct {
LockKeys string
}

// Resource Manager: send outbound request to TC
// ResourceManagerOutbound Resource Manager: send outbound request to TC
type ResourceManagerOutbound interface {
// Branch register long
// BranchRegister rm register the branch transaction
BranchRegister(ctx context.Context, param BranchRegisterParam) (int64, error)
// Branch report
// BranchReport branch transaction report the status
BranchReport(ctx context.Context, param BranchReportParam) error
// Lock query boolean
// LockQuery lock query boolean
LockQuery(ctx context.Context, param LockQueryParam) (bool, error)
}

@@ -90,13 +90,13 @@ type ResourceManager interface {
ResourceManagerInbound
ResourceManagerOutbound

// Register a Resource to be managed by Resource Manager
// RegisterResource register a resource to be managed by resource manager
RegisterResource(resource Resource) error
// Unregister a Resource from the Resource Manager
// UnregisterResource unregister a resource from the Resource Manager
UnregisterResource(resource Resource) error
// Get all resources managed by this manager
// GetCachedResources get all resources managed by this manager
GetCachedResources() *sync.Map
// Get the BranchType
// GetBranchType get the branch type
GetBranchType() branch.BranchType
}



Loading…
Cancel
Save