Browse Source

Feature: Database persistence for seata-go Saga state machine (#794)

* implement state machine repository

* temporary recording

* temporary recording

* temporary storage

* Supplementary test
pull/808/head
lxfeng1997 GitHub 6 months ago
parent
commit
689c5d6f7c
No known key found for this signature in database GPG Key ID: B5690EEEBB952194
11 changed files with 784 additions and 42 deletions
  1. +3
    -0
      pkg/saga/statemachine/constant/constant.go
  2. +55
    -11
      pkg/saga/statemachine/engine/core/default_statemachine_config.go
  3. +2
    -0
      pkg/saga/statemachine/engine/core/statemachine_store.go
  4. +12
    -0
      pkg/saga/statemachine/statelang/statemachine_instance.go
  5. +1
    -1
      pkg/saga/statemachine/store/db/statelang.go
  6. +340
    -14
      pkg/saga/statemachine/store/db/statelog.go
  7. +7
    -0
      pkg/saga/statemachine/store/db/statelog_test.go
  8. +18
    -0
      pkg/saga/statemachine/store/repository/state_log_repository.go
  9. +203
    -16
      pkg/saga/statemachine/store/repository/state_machine_repository.go
  10. +118
    -0
      pkg/saga/statemachine/store/repository/state_machine_repository_test.go
  11. +25
    -0
      pkg/tm/global_transaction.go

+ 3
- 0
pkg/saga/statemachine/constant/constant.go View File

@@ -67,6 +67,7 @@ const (
// TODO: this lock in process context only has one, try to add more to add concurrent
VarNameProcessContextMutexLock string = "_current_context_mutex_lock"
VarNameFailEndStateFlag string = "_fail_end_state_flag_"
VarNameGlobalTx string = "_global_transaction_"
// end region

// region of loop
@@ -79,10 +80,12 @@ const (
// end region

// region others
SeqEntityStateMachine string = "STATE_MACHINE"
SeqEntityStateMachineInst string = "STATE_MACHINE_INST"
SeqEntityStateInst string = "STATE_INST"
OperationNameForward string = "forward"
LoopStateNamePattern string = "-loop-"
SagaTransNamePrefix string = "$Saga_"
// end region

SeperatorParentId string = ":"


+ 55
- 11
pkg/saga/statemachine/engine/core/default_statemachine_config.go View File

@@ -25,16 +25,24 @@ import (
)

const (
DefaultTransOperTimeout = 60000 * 30
DefaultServiceInvokeTimeout = 60000 * 5
DefaultTransOperTimeout = 60000 * 30
DefaultServiceInvokeTimeout = 60000 * 5
DefaultClientSagaRetryPersistModeUpdate = false
DefaultClientSagaCompensatePersistModeUpdate = false
DefaultClientReportSuccessEnable = false
DefaultClientSagaBranchRegisterEnable = true
)

type DefaultStateMachineConfig struct {
// Configuration
transOperationTimeout int
serviceInvokeTimeout int
charset string
defaultTenantId string
transOperationTimeout int
serviceInvokeTimeout int
charset string
defaultTenantId string
sagaRetryPersistModeUpdate bool
sagaCompensatePersistModeUpdate bool
sagaBranchRegisterEnable bool
rmReportSuccessEnable bool

// Components

@@ -202,13 +210,49 @@ func (c *DefaultStateMachineConfig) ServiceInvokeTimeout() int {
return c.serviceInvokeTimeout
}

func (c *DefaultStateMachineConfig) IsSagaRetryPersistModeUpdate() bool {
return c.sagaRetryPersistModeUpdate
}

func (c *DefaultStateMachineConfig) SetSagaRetryPersistModeUpdate(sagaRetryPersistModeUpdate bool) {
c.sagaRetryPersistModeUpdate = sagaRetryPersistModeUpdate
}

func (c *DefaultStateMachineConfig) IsSagaCompensatePersistModeUpdate() bool {
return c.sagaCompensatePersistModeUpdate
}

func (c *DefaultStateMachineConfig) SetSagaCompensatePersistModeUpdate(sagaCompensatePersistModeUpdate bool) {
c.sagaCompensatePersistModeUpdate = sagaCompensatePersistModeUpdate
}

func (c *DefaultStateMachineConfig) IsSagaBranchRegisterEnable() bool {
return c.sagaBranchRegisterEnable
}

func (c *DefaultStateMachineConfig) SetSagaBranchRegisterEnable(sagaBranchRegisterEnable bool) {
c.sagaBranchRegisterEnable = sagaBranchRegisterEnable
}

func (c *DefaultStateMachineConfig) IsRmReportSuccessEnable() bool {
return c.rmReportSuccessEnable
}

func (c *DefaultStateMachineConfig) SetRmReportSuccessEnable(rmReportSuccessEnable bool) {
c.rmReportSuccessEnable = rmReportSuccessEnable
}

func NewDefaultStateMachineConfig() *DefaultStateMachineConfig {
c := &DefaultStateMachineConfig{
transOperationTimeout: DefaultTransOperTimeout,
serviceInvokeTimeout: DefaultServiceInvokeTimeout,
charset: "UTF-8",
defaultTenantId: "000001",
componentLock: &sync.Mutex{},
transOperationTimeout: DefaultTransOperTimeout,
serviceInvokeTimeout: DefaultServiceInvokeTimeout,
charset: "UTF-8",
defaultTenantId: "000001",
sagaRetryPersistModeUpdate: DefaultClientSagaRetryPersistModeUpdate,
sagaCompensatePersistModeUpdate: DefaultClientSagaCompensatePersistModeUpdate,
sagaBranchRegisterEnable: DefaultClientSagaBranchRegisterEnable,
rmReportSuccessEnable: DefaultClientReportSuccessEnable,
componentLock: &sync.Mutex{},
}

// TODO: init config


+ 2
- 0
pkg/saga/statemachine/engine/core/statemachine_store.go View File

@@ -55,6 +55,8 @@ type StateLogStore interface {
GetStateInstance(stateInstanceId string, stateMachineInstanceId string) (statelang.StateInstance, error)

GetStateInstanceListByMachineInstanceId(stateMachineInstanceId string) ([]statelang.StateInstance, error)

ClearUp(context ProcessContext)
}

type StateMachineRepository interface {


+ 12
- 0
pkg/saga/statemachine/statelang/statemachine_instance.go View File

@@ -72,6 +72,10 @@ type StateMachineInstance interface {

SetStatus(status ExecutionStatus)

StateMap() map[string]StateInstance

SetStateMap(stateMap map[string]StateInstance)

CompensationStatus() ExecutionStatus

SetCompensationStatus(compensationStatus ExecutionStatus)
@@ -234,6 +238,14 @@ func (s *StateMachineInstanceImpl) SetStatus(status ExecutionStatus) {
s.status = status
}

func (s *StateMachineInstanceImpl) StateMap() map[string]StateInstance {
return s.stateMap
}

func (s *StateMachineInstanceImpl) SetStateMap(stateMap map[string]StateInstance) {
s.stateMap = stateMap
}

func (s *StateMachineInstanceImpl) CompensationStatus() ExecutionStatus {
return s.compensationStatus
}


+ 1
- 1
pkg/saga/statemachine/store/db/statelang.go View File

@@ -104,7 +104,7 @@ func scanRowsToStateMachine(rows *sql.Rows) (statelang.StateMachine, error) {
if recoverStrategy != "" {
stateMachine.SetRecoverStrategy(statelang.RecoverStrategy(recoverStrategy))
}
stateMachine.SetTenantId(t)
stateMachine.SetTenantId(tenantId)
stateMachine.SetStatus(statelang.StateMachineStatus(status))
return stateMachine, nil
}


+ 340
- 14
pkg/saga/statemachine/store/db/statelog.go View File

@@ -21,17 +21,24 @@ import (
"context"
"database/sql"
"fmt"
"regexp"
"strconv"
"strings"
"time"

"github.com/pkg/errors"
constant2 "github.com/seata/seata-go/pkg/constant"
"github.com/seata/seata-go/pkg/protocol/branch"
"github.com/seata/seata-go/pkg/protocol/message"
"github.com/seata/seata-go/pkg/rm"
"github.com/seata/seata-go/pkg/saga/statemachine/constant"
"github.com/seata/seata-go/pkg/saga/statemachine/engine/core"
"github.com/seata/seata-go/pkg/saga/statemachine/engine/sequence"
"github.com/seata/seata-go/pkg/saga/statemachine/engine/serializer"
"github.com/seata/seata-go/pkg/saga/statemachine/statelang"
"github.com/seata/seata-go/pkg/saga/statemachine/statelang/state"
"github.com/seata/seata-go/pkg/tm"
"github.com/seata/seata-go/pkg/util/log"
"regexp"
"strconv"
"strings"
"time"
)

const (
@@ -124,14 +131,19 @@ func (s *StateLogStore) RecordStateMachineStarted(ctx context.Context, machineIn
parentId := machineInstance.ParentID()

if parentId == "" {
//TODO begin transaction
// begin transaction
err = s.beginTransaction(ctx, machineInstance, context)
if err != nil {
return err
}
}

if machineInstance.ID() == "" && s.seqGenerator != nil {
machineInstance.SetID(s.seqGenerator.GenerateId(constant.SeqEntityStateMachineInst, ""))
}

//TODO bind SAGA branch type
// bind SAGA branch type
context.SetVariable(constant2.BranchTypeKey, branch.BranchTypeSAGA)

serializedStartParams, err := s.paramsSerializer.Serialize(machineInstance.StartParams())
if err != nil {
@@ -149,14 +161,48 @@ func (s *StateLogStore) RecordStateMachineStarted(ctx context.Context, machineIn
return nil
}

func (s *StateLogStore) beginTransaction(ctx context.Context, machineInstance statelang.StateMachineInstance, context core.ProcessContext) error {
cfg, ok := context.GetVariable(constant.VarNameStateMachineConfig).(core.StateMachineConfig)
if !ok {
return errors.New("begin transaction fail, stateMachineConfig is required in context")
}

defer func() {
isAsync, ok := context.GetVariable(constant.VarNameIsAsyncExecution).(bool)
if ok && isAsync {
s.ClearUp(context)
}
}()

tm.SetTxRole(ctx, tm.Launcher)
tm.SetTxStatus(ctx, message.GlobalStatusUnKnown)
tm.SetTxName(ctx, constant.SagaTransNamePrefix+machineInstance.StateMachine().Name())

err := tm.GetGlobalTransactionManager().Begin(ctx, time.Duration(cfg.TransOperationTimeout()))
if err != nil {
return err
}

machineInstance.SetID(tm.GetXID(ctx))
return nil
}

func (s *StateLogStore) RecordStateMachineFinished(ctx context.Context, machineInstance statelang.StateMachineInstance,
context core.ProcessContext) error {
if machineInstance == nil {
return nil
}

defer func() {
s.ClearUp(context)
}()

endParams := machineInstance.EndParams()
if endParams != nil {
delete(endParams, constant.VarNameGlobalTx)
}

// if success, clear exception
if statelang.SU == machineInstance.Status() && machineInstance.Exception() != nil {
machineInstance.SetException(nil)
}
@@ -183,10 +229,86 @@ func (s *StateLogStore) RecordStateMachineFinished(ctx context.Context, machineI
return nil
}

//TODO check if timeout or else report transaction finished
// check if timeout or else report transaction finished
cfg, ok := context.GetVariable(constant.VarNameStateMachineConfig).(core.StateMachineConfig)
if !ok {
return errors.New("stateMachineConfig is required in context")
}

if core.IsTimeout(machineInstance.UpdatedTime(), cfg.TransOperationTimeout()) {
log.Warnf("StateMachineInstance[%s] is execution timeout, skip report transaction finished to server.", machineInstance.ID())
} else if machineInstance.ParentID() == "" {
//if parentId is not null, machineInstance is a SubStateMachine, do not report global transaction.
err = s.reportTransactionFinished(ctx, machineInstance, context)
if err != nil {
return err
}
}
return nil
}

func (s *StateLogStore) reportTransactionFinished(ctx context.Context, machineInstance statelang.StateMachineInstance, context core.ProcessContext) error {
var err error
defer func() {
s.ClearUp(context)
if err != nil {
log.Errorf("Report transaction finish to server error: %v, StateMachine: %s, XID: %s, Reason: %s",
err, machineInstance.StateMachine().Name(), machineInstance.ID(), err.Error())
}
}()

globalTransaction, err := s.getGlobalTransaction(machineInstance, context)
if err != nil {
log.Errorf("Failed to get global transaction: %v", err)
return err
}

var globalStatus message.GlobalStatus
if statelang.SU == machineInstance.Status() && machineInstance.CompensationStatus() == "" {
globalStatus = message.GlobalStatusCommitted
} else if statelang.SU == machineInstance.CompensationStatus() {
globalStatus = message.GlobalStatusRollbacked
} else if statelang.FA == machineInstance.CompensationStatus() || statelang.UN == machineInstance.CompensationStatus() {
globalStatus = message.GlobalStatusRollbackRetrying
} else if statelang.FA == machineInstance.Status() && machineInstance.CompensationStatus() == "" {
globalStatus = message.GlobalStatusFinished
} else if statelang.UN == machineInstance.Status() && machineInstance.CompensationStatus() == "" {
globalStatus = message.GlobalStatusCommitRetrying
} else {
globalStatus = message.GlobalStatusUnKnown
}

globalTransaction.TxStatus = globalStatus
_, err = tm.GetGlobalTransactionManager().GlobalReport(ctx, globalTransaction)
if err != nil {
return err
}
return nil
}

func (s *StateLogStore) getGlobalTransaction(machineInstance statelang.StateMachineInstance, context core.ProcessContext) (*tm.GlobalTransaction, error) {
globalTransaction, ok := context.GetVariable(constant.VarNameGlobalTx).(*tm.GlobalTransaction)
if ok {
return globalTransaction, nil
}

var xid string
parentId := machineInstance.ParentID()
if parentId == "" {
xid = machineInstance.ID()
} else {
xid = parentId[:strings.LastIndex(parentId, constant.SeperatorParentId)]
}
globalTransaction = &tm.GlobalTransaction{
Xid: xid,
TxStatus: message.GlobalStatusUnKnown,
TxRole: tm.Launcher,
}

context.SetVariable(constant.VarNameGlobalTx, globalTransaction)
return globalTransaction, nil
}

func (s *StateLogStore) RecordStateMachineRestarted(ctx context.Context, machineInstance statelang.StateMachineInstance,
context core.ProcessContext) error {
if machineInstance == nil {
@@ -211,18 +333,25 @@ func (s *StateLogStore) RecordStateStarted(ctx context.Context, stateInstance st
if stateInstance == nil {
return nil
}
isUpdateMode := s.isUpdateMode(stateInstance, context)
isUpdateMode, err := s.isUpdateMode(stateInstance, context)
if err != nil {
return err
}

// if this state is for retry, do not register branch
if stateInstance.StateIDRetriedFor() != "" {
if isUpdateMode {
stateInstance.SetID(stateInstance.StateIDRetriedFor())
} else {
// generate id by default
stateInstance.SetID(s.generateRetryStateInstanceId(stateInstance))
}
} else if stateInstance.StateIDCompensatedFor() != "" {
// if this state is for compensation, do not register branch
stateInstance.SetID(s.generateCompensateStateInstanceId(stateInstance, isUpdateMode))
} else {
//TODO register branch
// register branch
s.branchRegister(stateInstance, context)
}

if stateInstance.ID() == "" && s.seqGenerator != nil {
@@ -252,9 +381,45 @@ func (s *StateLogStore) RecordStateStarted(ctx context.Context, stateInstance st
return nil
}

func (s *StateLogStore) isUpdateMode(instance statelang.StateInstance, context core.ProcessContext) bool {
//TODO implement me, add forward logic
return false
func (s *StateLogStore) isUpdateMode(stateInstance statelang.StateInstance, context core.ProcessContext) (bool, error) {
cfg, ok := context.GetVariable(constant.VarNameStateMachineConfig).(core.DefaultStateMachineConfig)
if !ok {
return false, errors.New("stateMachineConfig is required in context")
}

instruction, ok := context.GetInstruction().(*core.StateInstruction)
if !ok {
return false, errors.New("stateInstruction is required in processContext")
}
instructionState, err := instruction.GetState(context)
if err != nil {
return false, err
}
taskState, _ := instructionState.(*state.AbstractTaskState)
stateMachine := stateInstance.StateMachineInstance().StateMachine()

if stateInstance.StateIDRetriedFor() != "" {
if taskState != nil && taskState.RetryPersistModeUpdate() {
return taskState.RetryPersistModeUpdate(), nil
} else if stateMachine.IsRetryPersistModeUpdate() {
return stateMachine.IsRetryPersistModeUpdate(), nil
}
return cfg.IsSagaRetryPersistModeUpdate(), nil
} else if stateInstance.StateIDCompensatedFor() != "" {
// find if this compensate has been executed
stateList := stateInstance.StateMachineInstance().StateList()
for _, instance := range stateList {
if instance.IsForCompensation() && instance.Name() == stateInstance.Name() {
if taskState != nil && taskState.CompensatePersistModeUpdate() {
return taskState.CompensatePersistModeUpdate(), nil
} else if stateMachine.IsCompensatePersistModeUpdate() {
return stateMachine.IsCompensatePersistModeUpdate(), nil
}
return cfg.IsSagaCompensatePersistModeUpdate(), nil
}
}
}
return false, nil
}

func (s *StateLogStore) generateRetryStateInstanceId(stateInstance statelang.StateInstance) string {
@@ -293,6 +458,49 @@ func (s *StateLogStore) generateCompensateStateInstanceId(stateInstance statelan
return fmt.Sprintf("%s-%d", originalCompensateStateInstId, maxIndex)
}

func (s *StateLogStore) branchRegister(stateInstance statelang.StateInstance, context core.ProcessContext) error {
cfg, ok := context.GetVariable(constant.VarNameStateMachineConfig).(core.DefaultStateMachineConfig)
if !ok {
return errors.New("stateMachineConfig is required in context")
}

if !cfg.IsSagaBranchRegisterEnable() {
log.Debugf("sagaBranchRegisterEnable = false, skip branch report. state[%s]", stateInstance.Name())
return nil
}

//Register branch
var err error
machineInstance := stateInstance.StateMachineInstance()
defer func() {
if err != nil {
log.Errorf("Branch transaction failure. StateMachine: %s, XID: %s, State: %s, stateId: %s, err: %v",
machineInstance.StateMachine().Name(), machineInstance.ID(), stateInstance.Name(), stateInstance.ID(), err)
}
}()

globalTransaction, err := s.getGlobalTransaction(machineInstance, context)
if err != nil {
return err
}
if globalTransaction == nil {
err = errors.New("Global transaction is not exists")
return err
}

branchId, err := rm.GetRMRemotingInstance().BranchRegister(rm.BranchRegisterParam{
BranchType: branch.BranchTypeSAGA,
ResourceId: machineInstance.StateMachine().Name() + "#" + stateInstance.Name(),
Xid: globalTransaction.Xid,
})
if err != nil {
return err
}

stateInstance.SetID(strconv.FormatInt(branchId, 10))
return nil
}

func (s *StateLogStore) getIdIndex(stateInstanceId string, separator string) int {
if stateInstanceId != "" {
start := strings.LastIndex(stateInstanceId, separator)
@@ -332,11 +540,124 @@ func (s *StateLogStore) RecordStateFinished(ctx context.Context, stateInstance s
return err
}

//TODO report branch
// A switch to skip branch report on branch success, in order to optimize performance
cfg, ok := context.GetVariable(constant.VarNameStateMachineConfig).(core.DefaultStateMachineConfig)
if !(ok && !cfg.IsRmReportSuccessEnable() && statelang.SU == stateInstance.Status()) {
err = s.branchReport(stateInstance, context)
return err
}

return nil

}

func (s *StateLogStore) branchReport(stateInstance statelang.StateInstance, context core.ProcessContext) error {
cfg, ok := context.GetVariable(constant.VarNameStateMachineConfig).(core.DefaultStateMachineConfig)
if ok && !cfg.IsSagaBranchRegisterEnable() {
log.Debugf("sagaBranchRegisterEnable = false, skip branch report. state[%s]", stateInstance.Name())
return nil
}

var branchStatus branch.BranchStatus
// find out the original state instance, only the original state instance is registered on the server,
// and its status should be reported.
var originalStateInst statelang.StateInstance
if stateInstance.StateIDRetriedFor() != "" {
isUpdateMode, err := s.isUpdateMode(stateInstance, context)
if err != nil {
return err
}

if isUpdateMode {
originalStateInst = stateInstance
} else {
originalStateInst = s.findOutOriginalStateInstanceOfRetryState(stateInstance)
}

if statelang.SU == stateInstance.Status() {
branchStatus = branch.BranchStatusPhasetwoCommitted
} else if statelang.FA == stateInstance.Status() || statelang.UN == stateInstance.Status() {
branchStatus = branch.BranchStatusPhaseoneFailed
} else {
branchStatus = branch.BranchStatusUnknown
}
} else if stateInstance.StateIDCompensatedFor() != "" {
isUpdateMode, err := s.isUpdateMode(stateInstance, context)
if err != nil {
return err
}

if isUpdateMode {
originalStateInst = stateInstance.StateMachineInstance().StateMap()[stateInstance.StateIDCompensatedFor()]
} else {
originalStateInst = s.findOutOriginalStateInstanceOfCompensateState(stateInstance)
}
}

if originalStateInst == nil {
originalStateInst = stateInstance
}

if branchStatus == branch.BranchStatusUnknown {
if statelang.SU == originalStateInst.Status() && originalStateInst.CompensationStatus() == "" {
branchStatus = branch.BranchStatusPhasetwoCommitted
} else if statelang.SU == originalStateInst.CompensationStatus() {
branchStatus = branch.BranchStatusPhasetwoRollbacked
} else if statelang.FA == originalStateInst.CompensationStatus() || statelang.UN == originalStateInst.CompensationStatus() {
branchStatus = branch.BranchStatusPhasetwoRollbackFailedRetryable
} else if (statelang.FA == originalStateInst.Status() || statelang.UN == originalStateInst.Status()) && originalStateInst.CompensationStatus() == "" {
branchStatus = branch.BranchStatusPhaseoneFailed
} else {
branchStatus = branch.BranchStatusUnknown
}
}

var err error
defer func() {
if err != nil {
log.Errorf("Report branch status to server error:%s, StateMachine:%s, StateName:%s, XID:%s, branchId:%s, branchStatus:%s, err:%v",
err.Error(), originalStateInst.StateMachineInstance().StateMachine().Name(), originalStateInst.Name(),
originalStateInst.StateMachineInstance().ID(), originalStateInst.ID(), branchStatus, err)
}
}()

globalTransaction, err := s.getGlobalTransaction(stateInstance.StateMachineInstance(), context)
if err != nil {
return err
}
if globalTransaction == nil {
err = errors.New("Global transaction is not exists")
return err
}

branchId, err := strconv.ParseInt(originalStateInst.ID(), 10, 0)
err = rm.GetRMRemotingInstance().BranchReport(rm.BranchReportParam{
BranchType: branch.BranchTypeSAGA,
Xid: globalTransaction.Xid,
BranchId: branchId,
Status: branchStatus,
})
return err
}

func (s *StateLogStore) findOutOriginalStateInstanceOfRetryState(stateInstance statelang.StateInstance) statelang.StateInstance {
stateMap := stateInstance.StateMachineInstance().StateMap()
originalStateInst := stateMap[stateInstance.StateIDRetriedFor()]
for originalStateInst.StateIDRetriedFor() != "" {
originalStateInst = stateMap[stateInstance.StateIDRetriedFor()]
}
return originalStateInst
}

func (s *StateLogStore) findOutOriginalStateInstanceOfCompensateState(stateInstance statelang.StateInstance) statelang.StateInstance {
stateMap := stateInstance.StateMachineInstance().StateMap()
originalStateInst := stateMap[stateInstance.StateIDCompensatedFor()]
for originalStateInst.StateIDRetriedFor() != "" {
originalStateInst = stateMap[stateInstance.StateIDRetriedFor()]
}
return originalStateInst
}

func (s *StateLogStore) GetStateMachineInstance(stateMachineInstanceId string) (statelang.StateMachineInstance, error) {
stateMachineInstance, err := SelectOne(s.db, s.getStateMachineInstanceByIdSql, scanRowsToStateMachineInstance,
stateMachineInstanceId)
@@ -437,7 +758,7 @@ func (s *StateLogStore) GetStateInstanceListByMachineInstanceId(stateMachineInst
if lastStateInstance.EndTime().IsZero() {
lastStateInstance.SetStatus(statelang.RU)
}
//TODO add forward and compensate logic
originStateMap := make(map[string]statelang.StateInstance)
compensatedStateMap := make(map[string]statelang.StateInstance)
retriedStateMap := make(map[string]statelang.StateInstance)
@@ -522,6 +843,11 @@ func (s *StateLogStore) SetSeqGenerator(seqGenerator sequence.SeqGenerator) {
s.seqGenerator = seqGenerator
}

func (s *StateLogStore) ClearUp(context core.ProcessContext) {
context.RemoveVariable(constant2.XidKey)
context.RemoveVariable(constant2.BranchTypeKey)
}

func execStateMachineInstanceStatementForInsert(obj statelang.StateMachineInstance, stmt *sql.Stmt) (int64, error) {
result, err := stmt.Exec(
obj.ID(),


+ 7
- 0
pkg/saga/statemachine/store/db/statelog_test.go View File

@@ -57,6 +57,12 @@ func mockMachineInstance(stateMachineName string) statelang.StateMachineInstance
return inst
}

func mockStateMachineConfig(context core.ProcessContext) core.StateMachineConfig {
cfg := core.NewDefaultStateMachineConfig()
context.SetVariable(constant.VarNameStateMachineConfig, cfg)
return cfg
}

func TestStateLogStore_RecordStateMachineStarted(t *testing.T) {
prepareDB()

@@ -65,6 +71,7 @@ func TestStateLogStore_RecordStateMachineStarted(t *testing.T) {
expected := mockMachineInstance(stateMachineName)
expected.SetBusinessKey("test_started")
ctx := mockProcessContext(stateMachineName, expected)
mockStateMachineConfig(ctx)
err := stateLogStore.RecordStateMachineStarted(context.Background(), expected, ctx)
assert.Nil(t, err)
actual, err := stateLogStore.GetStateMachineInstance(expected.ID())


+ 18
- 0
pkg/saga/statemachine/store/repository/state_log_repository.go View File

@@ -0,0 +1,18 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package repository

+ 203
- 16
pkg/saga/statemachine/store/repository/state_machine_repository.go View File

@@ -18,34 +18,221 @@
package repository

import (
"github.com/seata/seata-go/pkg/saga/statemachine/statelang"
"io"
"sync"
"time"

"github.com/seata/seata-go/pkg/saga/statemachine/constant"
"github.com/seata/seata-go/pkg/saga/statemachine/engine/sequence"
"github.com/seata/seata-go/pkg/saga/statemachine/statelang"
"github.com/seata/seata-go/pkg/saga/statemachine/statelang/parser"
"github.com/seata/seata-go/pkg/saga/statemachine/store/db"
"github.com/seata/seata-go/pkg/util/log"
)

const (
DefaultJsonParser = "fastjson"
)

var (
stateMachineRepositoryImpl *StateMachineRepositoryImpl
onceStateMachineRepositoryImpl sync.Once
)

type StateMachineRepositoryImpl struct {
stateMachineMapById map[string]statelang.StateMachine
stateMachineMapByNameAndTenant map[string]statelang.StateMachine

stateLangStore *db.StateLangStore
seqGenerator sequence.SeqGenerator
defaultTenantId string
jsonParserName string
charset string
mutex *sync.Mutex
}

func GetStateMachineRepositoryImpl() *StateMachineRepositoryImpl {
if stateMachineRepositoryImpl == nil {
onceStateMachineRepositoryImpl.Do(func() {
//TODO get charset by config
//TODO charset is not use
//TODO using json parser
stateMachineRepositoryImpl = &StateMachineRepositoryImpl{
stateMachineMapById: make(map[string]statelang.StateMachine),
stateMachineMapByNameAndTenant: make(map[string]statelang.StateMachine),
seqGenerator: sequence.NewUUIDSeqGenerator(),
jsonParserName: DefaultJsonParser,
charset: "UTF-8",
mutex: &sync.Mutex{},
}
})
}

return stateMachineRepositoryImpl
}

func (s *StateMachineRepositoryImpl) GetStateMachineById(stateMachineId string) (statelang.StateMachine, error) {
stateMachine := s.stateMachineMapById[stateMachineId]
if stateMachine == nil && s.stateLangStore != nil {
s.mutex.Lock()
defer s.mutex.Unlock()

stateMachine = s.stateMachineMapById[stateMachineId]
if stateMachine == nil {
oldStateMachine, err := s.stateLangStore.GetStateMachineById(stateMachineId)
if err != nil {
return oldStateMachine, err
}

parseStatMachine, err := parser.NewJSONStateMachineParser().Parse(oldStateMachine.Content())
if err != nil {
return oldStateMachine, err
}

oldStateMachine.SetStartState(parseStatMachine.StartState())
for key, val := range parseStatMachine.States() {
oldStateMachine.States()[key] = val
}

s.stateMachineMapById[stateMachineId] = oldStateMachine
s.stateMachineMapByNameAndTenant[oldStateMachine.Name()+"_"+oldStateMachine.TenantId()] = oldStateMachine
return oldStateMachine, nil
}
}
return stateMachine, nil
}

func (s *StateMachineRepositoryImpl) GetStateMachineByNameAndTenantId(stateMachineName string, tenantId string) (statelang.StateMachine, error) {
return s.GetLastVersionStateMachine(stateMachineName, tenantId)
}

func (s *StateMachineRepositoryImpl) GetLastVersionStateMachine(stateMachineName string, tenantId string) (statelang.StateMachine, error) {
key := stateMachineName + "_" + tenantId
stateMachine := s.stateMachineMapByNameAndTenant[key]
if stateMachine == nil && s.stateLangStore != nil {
s.mutex.Lock()
defer s.mutex.Unlock()

stateMachine = s.stateMachineMapById[key]
if stateMachine == nil {
oldStateMachine, err := s.stateLangStore.GetLastVersionStateMachine(stateMachineName, tenantId)
if err != nil {
return oldStateMachine, err
}

parseStatMachine, err := parser.NewJSONStateMachineParser().Parse(oldStateMachine.Content())
if err != nil {
return oldStateMachine, err
}

oldStateMachine.SetStartState(parseStatMachine.StartState())
for key, val := range parseStatMachine.States() {
oldStateMachine.States()[key] = val
}

s.stateMachineMapById[oldStateMachine.ID()] = oldStateMachine
s.stateMachineMapByNameAndTenant[key] = oldStateMachine
return oldStateMachine, nil
}
}
return stateMachine, nil
}

func (s *StateMachineRepositoryImpl) RegistryStateMachine(machine statelang.StateMachine) error {
stateMachineName := machine.Name()
tenantId := machine.TenantId()

if s.stateLangStore != nil {
oldStateMachine, err := s.stateLangStore.GetLastVersionStateMachine(stateMachineName, tenantId)
if err != nil {
return err
}

if oldStateMachine != nil {
if oldStateMachine.Content() == machine.Content() && machine.Version() != "" && machine.Version() == oldStateMachine.Version() {
log.Debugf("StateMachine[%s] is already exist a same version", stateMachineName)
machine.SetID(oldStateMachine.ID())
machine.SetCreateTime(oldStateMachine.CreateTime())

s.stateMachineMapById[machine.ID()] = machine
s.stateMachineMapByNameAndTenant[machine.Name()+"_"+machine.TenantId()] = machine
return nil
}
}

if machine.ID() == "" {
machine.SetID(s.seqGenerator.GenerateId(constant.SeqEntityStateMachine, ""))
}

machine.SetCreateTime(time.Now())

err = s.stateLangStore.StoreStateMachine(machine)
if err != nil {
return err
}
}

if machine.ID() == "" {
machine.SetID(s.seqGenerator.GenerateId(constant.SeqEntityStateMachine, ""))
}

s.stateMachineMapById[machine.ID()] = machine
s.stateMachineMapByNameAndTenant[machine.Name()+"_"+machine.TenantId()] = machine
return nil
}

func (s *StateMachineRepositoryImpl) RegistryStateMachineByReader(reader io.Reader) error {
jsonByte, err := io.ReadAll(reader)
if err != nil {
return err
}

json := string(jsonByte)
parseStatMachine, err := parser.NewJSONStateMachineParser().Parse(json)
if err != nil {
return err
}

if parseStatMachine == nil {
return nil
}

parseStatMachine.SetContent(json)
s.RegistryStateMachine(parseStatMachine)

log.Debugf("===== StateMachine Loaded: %s", json)

return nil
}

func (s *StateMachineRepositoryImpl) SetStateLangStore(stateLangStore *db.StateLangStore) {
s.stateLangStore = stateLangStore
}

func (s *StateMachineRepositoryImpl) SetSeqGenerator(seqGenerator sequence.SeqGenerator) {
s.seqGenerator = seqGenerator
}

func (s *StateMachineRepositoryImpl) SetCharset(charset string) {
s.charset = charset
}

func (s StateMachineRepositoryImpl) GetStateMachineById(stateMachineId string) (statelang.StateMachine, error) {
//TODO implement me
panic("implement me")
func (s *StateMachineRepositoryImpl) GetCharset() string {
return s.charset
}

func (s StateMachineRepositoryImpl) GetStateMachineByNameAndTenantId(stateMachineName string, tenantId string) (statelang.StateMachine, error) {
//TODO implement me
panic("implement me")
func (s *StateMachineRepositoryImpl) SetDefaultTenantId(defaultTenantId string) {
s.defaultTenantId = defaultTenantId
}

func (s StateMachineRepositoryImpl) GetLastVersionStateMachine(stateMachineName string, tenantId string) (statelang.StateMachine, error) {
//TODO implement me
panic("implement me")
func (s *StateMachineRepositoryImpl) GetDefaultTenantId() string {
return s.defaultTenantId
}

func (s StateMachineRepositoryImpl) RegistryStateMachine(machine statelang.StateMachine) error {
//TODO implement me
panic("implement me")
func (s *StateMachineRepositoryImpl) SetJsonParserName(jsonParserName string) {
s.jsonParserName = jsonParserName
}

func (s StateMachineRepositoryImpl) RegistryStateMachineByReader(reader io.Reader) error {
//TODO implement me
panic("implement me")
func (s *StateMachineRepositoryImpl) GetJsonParserName() string {
return s.jsonParserName
}

+ 118
- 0
pkg/saga/statemachine/store/repository/state_machine_repository_test.go View File

@@ -0,0 +1,118 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package repository

import (
"database/sql"
"github.com/seata/seata-go/pkg/saga/statemachine/statelang/parser"
"os"
"sync"
"testing"
"time"

_ "github.com/mattn/go-sqlite3"
"github.com/stretchr/testify/assert"

"github.com/seata/seata-go/pkg/saga/statemachine/statelang"
"github.com/seata/seata-go/pkg/saga/statemachine/store/db"
)

var (
oncePrepareDB sync.Once
testdb *sql.DB
)

func prepareDB() {
oncePrepareDB.Do(func() {
var err error
testdb, err = sql.Open("sqlite3", ":memory:")
query_, err := os.ReadFile("../../../../../testdata/sql/saga/sqlite_init.sql")
initScript := string(query_)
if err != nil {
panic(err)
}
if _, err := testdb.Exec(initScript); err != nil {
panic(err)
}
})
}

func loadStateMachineByYaml() string {
query, _ := os.ReadFile("../../../../../testdata/saga/statelang/simple_statemachine.json")
return string(query)
}

func TestStateMachineInMemory(t *testing.T) {
const stateMachineId, stateMachineName, tenantId = "simpleStateMachine", "simpleStateMachine", "test"
stateMachine := statelang.NewStateMachineImpl()
stateMachine.SetID(stateMachineId)
stateMachine.SetName(stateMachineName)
stateMachine.SetTenantId(tenantId)
stateMachine.SetComment("This is a test state machine")
stateMachine.SetCreateTime(time.Now())

repository := GetStateMachineRepositoryImpl()

err := repository.RegistryStateMachine(stateMachine)
assert.Nil(t, err)

machineById, err := repository.GetStateMachineById(stateMachine.ID())
assert.Nil(t, err)
assert.Equal(t, stateMachine.Name(), machineById.Name())
assert.Equal(t, stateMachine.TenantId(), machineById.TenantId())
assert.Equal(t, stateMachine.Comment(), machineById.Comment())
assert.Equal(t, stateMachine.CreateTime().UnixNano(), machineById.CreateTime().UnixNano())

machineByNameAndTenantId, err := repository.GetLastVersionStateMachine(stateMachine.Name(), stateMachine.TenantId())
assert.Nil(t, err)
assert.Equal(t, stateMachine.ID(), machineByNameAndTenantId.ID())
assert.Equal(t, stateMachine.Comment(), machineById.Comment())
assert.Equal(t, stateMachine.CreateTime().UnixNano(), machineById.CreateTime().UnixNano())
}

func TestStateMachineInDb(t *testing.T) {
prepareDB()

const tenantId = "test"
yaml := loadStateMachineByYaml()
stateMachine, err := parser.NewJSONStateMachineParser().Parse(yaml)
assert.Nil(t, err)
stateMachine.SetTenantId(tenantId)
stateMachine.SetContent(yaml)

repository := GetStateMachineRepositoryImpl()
repository.SetStateLangStore(db.NewStateLangStore(testdb, "seata_"))

err = repository.RegistryStateMachine(stateMachine)
assert.Nil(t, err)

repository.stateMachineMapById[stateMachine.ID()] = nil
machineById, err := repository.GetStateMachineById(stateMachine.ID())
assert.Nil(t, err)
assert.Equal(t, stateMachine.Name(), machineById.Name())
assert.Equal(t, stateMachine.TenantId(), machineById.TenantId())
assert.Equal(t, stateMachine.Comment(), machineById.Comment())
assert.Equal(t, stateMachine.CreateTime().UnixNano(), machineById.CreateTime().UnixNano())

repository.stateMachineMapByNameAndTenant[stateMachine.Name()+"_"+stateMachine.TenantId()] = nil
machineByNameAndTenantId, err := repository.GetLastVersionStateMachine(stateMachine.Name(), stateMachine.TenantId())
assert.Nil(t, err)
assert.Equal(t, stateMachine.ID(), machineByNameAndTenantId.ID())
assert.Equal(t, stateMachine.Comment(), machineById.Comment())
assert.Equal(t, stateMachine.CreateTime().UnixNano(), machineById.CreateTime().UnixNano())
}

+ 25
- 0
pkg/tm/global_transaction.go View File

@@ -151,3 +151,28 @@ func (g *GlobalTransactionManager) Rollback(ctx context.Context, gtr *GlobalTran

return nil
}

// GlobalReport Global report.
func (g *GlobalTransactionManager) GlobalReport(ctx context.Context, gtr *GlobalTransaction) (message.GlobalStatus, error) {
if gtr.Xid == "" {
return message.GlobalStatusUnKnown, fmt.Errorf("GlobalReport xid should not be empty")
}

req := message.GlobalReportRequest{
AbstractGlobalEndRequest: message.AbstractGlobalEndRequest{
Xid: gtr.Xid,
},
GlobalStatus: gtr.TxStatus,
}
res, err := getty.GetGettyRemotingClient().SendSyncRequest(req)
if err != nil {
log.Errorf("GlobalBeginRequest error %v", err)
return message.GlobalStatusUnKnown, err
}
if res == nil || res.(message.GlobalReportResponse).ResultCode == message.ResultCodeFailed {
log.Errorf("GlobalReportRequest result is empty or result code is failed, res %v", res)
return message.GlobalStatusUnKnown, fmt.Errorf("GlobalReportRequest result is empty or result code is failed.")
}
log.Infof("GlobalReportRequest success, res %v", res)
return res.(message.GlobalReportResponse).GlobalStatus, nil
}

Loading…
Cancel
Save