* implement state machine repository * temporary recording * temporary recording * temporary storage * Supplementary testpull/808/head
@@ -67,6 +67,7 @@ const ( | |||
// TODO: this lock in process context only has one, try to add more to add concurrent | |||
VarNameProcessContextMutexLock string = "_current_context_mutex_lock" | |||
VarNameFailEndStateFlag string = "_fail_end_state_flag_" | |||
VarNameGlobalTx string = "_global_transaction_" | |||
// end region | |||
// region of loop | |||
@@ -79,10 +80,12 @@ const ( | |||
// end region | |||
// region others | |||
SeqEntityStateMachine string = "STATE_MACHINE" | |||
SeqEntityStateMachineInst string = "STATE_MACHINE_INST" | |||
SeqEntityStateInst string = "STATE_INST" | |||
OperationNameForward string = "forward" | |||
LoopStateNamePattern string = "-loop-" | |||
SagaTransNamePrefix string = "$Saga_" | |||
// end region | |||
SeperatorParentId string = ":" | |||
@@ -25,16 +25,24 @@ import ( | |||
) | |||
const ( | |||
DefaultTransOperTimeout = 60000 * 30 | |||
DefaultServiceInvokeTimeout = 60000 * 5 | |||
DefaultTransOperTimeout = 60000 * 30 | |||
DefaultServiceInvokeTimeout = 60000 * 5 | |||
DefaultClientSagaRetryPersistModeUpdate = false | |||
DefaultClientSagaCompensatePersistModeUpdate = false | |||
DefaultClientReportSuccessEnable = false | |||
DefaultClientSagaBranchRegisterEnable = true | |||
) | |||
type DefaultStateMachineConfig struct { | |||
// Configuration | |||
transOperationTimeout int | |||
serviceInvokeTimeout int | |||
charset string | |||
defaultTenantId string | |||
transOperationTimeout int | |||
serviceInvokeTimeout int | |||
charset string | |||
defaultTenantId string | |||
sagaRetryPersistModeUpdate bool | |||
sagaCompensatePersistModeUpdate bool | |||
sagaBranchRegisterEnable bool | |||
rmReportSuccessEnable bool | |||
// Components | |||
@@ -202,13 +210,49 @@ func (c *DefaultStateMachineConfig) ServiceInvokeTimeout() int { | |||
return c.serviceInvokeTimeout | |||
} | |||
func (c *DefaultStateMachineConfig) IsSagaRetryPersistModeUpdate() bool { | |||
return c.sagaRetryPersistModeUpdate | |||
} | |||
func (c *DefaultStateMachineConfig) SetSagaRetryPersistModeUpdate(sagaRetryPersistModeUpdate bool) { | |||
c.sagaRetryPersistModeUpdate = sagaRetryPersistModeUpdate | |||
} | |||
func (c *DefaultStateMachineConfig) IsSagaCompensatePersistModeUpdate() bool { | |||
return c.sagaCompensatePersistModeUpdate | |||
} | |||
func (c *DefaultStateMachineConfig) SetSagaCompensatePersistModeUpdate(sagaCompensatePersistModeUpdate bool) { | |||
c.sagaCompensatePersistModeUpdate = sagaCompensatePersistModeUpdate | |||
} | |||
func (c *DefaultStateMachineConfig) IsSagaBranchRegisterEnable() bool { | |||
return c.sagaBranchRegisterEnable | |||
} | |||
func (c *DefaultStateMachineConfig) SetSagaBranchRegisterEnable(sagaBranchRegisterEnable bool) { | |||
c.sagaBranchRegisterEnable = sagaBranchRegisterEnable | |||
} | |||
func (c *DefaultStateMachineConfig) IsRmReportSuccessEnable() bool { | |||
return c.rmReportSuccessEnable | |||
} | |||
func (c *DefaultStateMachineConfig) SetRmReportSuccessEnable(rmReportSuccessEnable bool) { | |||
c.rmReportSuccessEnable = rmReportSuccessEnable | |||
} | |||
func NewDefaultStateMachineConfig() *DefaultStateMachineConfig { | |||
c := &DefaultStateMachineConfig{ | |||
transOperationTimeout: DefaultTransOperTimeout, | |||
serviceInvokeTimeout: DefaultServiceInvokeTimeout, | |||
charset: "UTF-8", | |||
defaultTenantId: "000001", | |||
componentLock: &sync.Mutex{}, | |||
transOperationTimeout: DefaultTransOperTimeout, | |||
serviceInvokeTimeout: DefaultServiceInvokeTimeout, | |||
charset: "UTF-8", | |||
defaultTenantId: "000001", | |||
sagaRetryPersistModeUpdate: DefaultClientSagaRetryPersistModeUpdate, | |||
sagaCompensatePersistModeUpdate: DefaultClientSagaCompensatePersistModeUpdate, | |||
sagaBranchRegisterEnable: DefaultClientSagaBranchRegisterEnable, | |||
rmReportSuccessEnable: DefaultClientReportSuccessEnable, | |||
componentLock: &sync.Mutex{}, | |||
} | |||
// TODO: init config | |||
@@ -55,6 +55,8 @@ type StateLogStore interface { | |||
GetStateInstance(stateInstanceId string, stateMachineInstanceId string) (statelang.StateInstance, error) | |||
GetStateInstanceListByMachineInstanceId(stateMachineInstanceId string) ([]statelang.StateInstance, error) | |||
ClearUp(context ProcessContext) | |||
} | |||
type StateMachineRepository interface { | |||
@@ -72,6 +72,10 @@ type StateMachineInstance interface { | |||
SetStatus(status ExecutionStatus) | |||
StateMap() map[string]StateInstance | |||
SetStateMap(stateMap map[string]StateInstance) | |||
CompensationStatus() ExecutionStatus | |||
SetCompensationStatus(compensationStatus ExecutionStatus) | |||
@@ -234,6 +238,14 @@ func (s *StateMachineInstanceImpl) SetStatus(status ExecutionStatus) { | |||
s.status = status | |||
} | |||
func (s *StateMachineInstanceImpl) StateMap() map[string]StateInstance { | |||
return s.stateMap | |||
} | |||
func (s *StateMachineInstanceImpl) SetStateMap(stateMap map[string]StateInstance) { | |||
s.stateMap = stateMap | |||
} | |||
func (s *StateMachineInstanceImpl) CompensationStatus() ExecutionStatus { | |||
return s.compensationStatus | |||
} | |||
@@ -104,7 +104,7 @@ func scanRowsToStateMachine(rows *sql.Rows) (statelang.StateMachine, error) { | |||
if recoverStrategy != "" { | |||
stateMachine.SetRecoverStrategy(statelang.RecoverStrategy(recoverStrategy)) | |||
} | |||
stateMachine.SetTenantId(t) | |||
stateMachine.SetTenantId(tenantId) | |||
stateMachine.SetStatus(statelang.StateMachineStatus(status)) | |||
return stateMachine, nil | |||
} | |||
@@ -21,17 +21,24 @@ import ( | |||
"context" | |||
"database/sql" | |||
"fmt" | |||
"regexp" | |||
"strconv" | |||
"strings" | |||
"time" | |||
"github.com/pkg/errors" | |||
constant2 "github.com/seata/seata-go/pkg/constant" | |||
"github.com/seata/seata-go/pkg/protocol/branch" | |||
"github.com/seata/seata-go/pkg/protocol/message" | |||
"github.com/seata/seata-go/pkg/rm" | |||
"github.com/seata/seata-go/pkg/saga/statemachine/constant" | |||
"github.com/seata/seata-go/pkg/saga/statemachine/engine/core" | |||
"github.com/seata/seata-go/pkg/saga/statemachine/engine/sequence" | |||
"github.com/seata/seata-go/pkg/saga/statemachine/engine/serializer" | |||
"github.com/seata/seata-go/pkg/saga/statemachine/statelang" | |||
"github.com/seata/seata-go/pkg/saga/statemachine/statelang/state" | |||
"github.com/seata/seata-go/pkg/tm" | |||
"github.com/seata/seata-go/pkg/util/log" | |||
"regexp" | |||
"strconv" | |||
"strings" | |||
"time" | |||
) | |||
const ( | |||
@@ -124,14 +131,19 @@ func (s *StateLogStore) RecordStateMachineStarted(ctx context.Context, machineIn | |||
parentId := machineInstance.ParentID() | |||
if parentId == "" { | |||
//TODO begin transaction | |||
// begin transaction | |||
err = s.beginTransaction(ctx, machineInstance, context) | |||
if err != nil { | |||
return err | |||
} | |||
} | |||
if machineInstance.ID() == "" && s.seqGenerator != nil { | |||
machineInstance.SetID(s.seqGenerator.GenerateId(constant.SeqEntityStateMachineInst, "")) | |||
} | |||
//TODO bind SAGA branch type | |||
// bind SAGA branch type | |||
context.SetVariable(constant2.BranchTypeKey, branch.BranchTypeSAGA) | |||
serializedStartParams, err := s.paramsSerializer.Serialize(machineInstance.StartParams()) | |||
if err != nil { | |||
@@ -149,14 +161,48 @@ func (s *StateLogStore) RecordStateMachineStarted(ctx context.Context, machineIn | |||
return nil | |||
} | |||
func (s *StateLogStore) beginTransaction(ctx context.Context, machineInstance statelang.StateMachineInstance, context core.ProcessContext) error { | |||
cfg, ok := context.GetVariable(constant.VarNameStateMachineConfig).(core.StateMachineConfig) | |||
if !ok { | |||
return errors.New("begin transaction fail, stateMachineConfig is required in context") | |||
} | |||
defer func() { | |||
isAsync, ok := context.GetVariable(constant.VarNameIsAsyncExecution).(bool) | |||
if ok && isAsync { | |||
s.ClearUp(context) | |||
} | |||
}() | |||
tm.SetTxRole(ctx, tm.Launcher) | |||
tm.SetTxStatus(ctx, message.GlobalStatusUnKnown) | |||
tm.SetTxName(ctx, constant.SagaTransNamePrefix+machineInstance.StateMachine().Name()) | |||
err := tm.GetGlobalTransactionManager().Begin(ctx, time.Duration(cfg.TransOperationTimeout())) | |||
if err != nil { | |||
return err | |||
} | |||
machineInstance.SetID(tm.GetXID(ctx)) | |||
return nil | |||
} | |||
func (s *StateLogStore) RecordStateMachineFinished(ctx context.Context, machineInstance statelang.StateMachineInstance, | |||
context core.ProcessContext) error { | |||
if machineInstance == nil { | |||
return nil | |||
} | |||
defer func() { | |||
s.ClearUp(context) | |||
}() | |||
endParams := machineInstance.EndParams() | |||
if endParams != nil { | |||
delete(endParams, constant.VarNameGlobalTx) | |||
} | |||
// if success, clear exception | |||
if statelang.SU == machineInstance.Status() && machineInstance.Exception() != nil { | |||
machineInstance.SetException(nil) | |||
} | |||
@@ -183,10 +229,86 @@ func (s *StateLogStore) RecordStateMachineFinished(ctx context.Context, machineI | |||
return nil | |||
} | |||
//TODO check if timeout or else report transaction finished | |||
// check if timeout or else report transaction finished | |||
cfg, ok := context.GetVariable(constant.VarNameStateMachineConfig).(core.StateMachineConfig) | |||
if !ok { | |||
return errors.New("stateMachineConfig is required in context") | |||
} | |||
if core.IsTimeout(machineInstance.UpdatedTime(), cfg.TransOperationTimeout()) { | |||
log.Warnf("StateMachineInstance[%s] is execution timeout, skip report transaction finished to server.", machineInstance.ID()) | |||
} else if machineInstance.ParentID() == "" { | |||
//if parentId is not null, machineInstance is a SubStateMachine, do not report global transaction. | |||
err = s.reportTransactionFinished(ctx, machineInstance, context) | |||
if err != nil { | |||
return err | |||
} | |||
} | |||
return nil | |||
} | |||
func (s *StateLogStore) reportTransactionFinished(ctx context.Context, machineInstance statelang.StateMachineInstance, context core.ProcessContext) error { | |||
var err error | |||
defer func() { | |||
s.ClearUp(context) | |||
if err != nil { | |||
log.Errorf("Report transaction finish to server error: %v, StateMachine: %s, XID: %s, Reason: %s", | |||
err, machineInstance.StateMachine().Name(), machineInstance.ID(), err.Error()) | |||
} | |||
}() | |||
globalTransaction, err := s.getGlobalTransaction(machineInstance, context) | |||
if err != nil { | |||
log.Errorf("Failed to get global transaction: %v", err) | |||
return err | |||
} | |||
var globalStatus message.GlobalStatus | |||
if statelang.SU == machineInstance.Status() && machineInstance.CompensationStatus() == "" { | |||
globalStatus = message.GlobalStatusCommitted | |||
} else if statelang.SU == machineInstance.CompensationStatus() { | |||
globalStatus = message.GlobalStatusRollbacked | |||
} else if statelang.FA == machineInstance.CompensationStatus() || statelang.UN == machineInstance.CompensationStatus() { | |||
globalStatus = message.GlobalStatusRollbackRetrying | |||
} else if statelang.FA == machineInstance.Status() && machineInstance.CompensationStatus() == "" { | |||
globalStatus = message.GlobalStatusFinished | |||
} else if statelang.UN == machineInstance.Status() && machineInstance.CompensationStatus() == "" { | |||
globalStatus = message.GlobalStatusCommitRetrying | |||
} else { | |||
globalStatus = message.GlobalStatusUnKnown | |||
} | |||
globalTransaction.TxStatus = globalStatus | |||
_, err = tm.GetGlobalTransactionManager().GlobalReport(ctx, globalTransaction) | |||
if err != nil { | |||
return err | |||
} | |||
return nil | |||
} | |||
func (s *StateLogStore) getGlobalTransaction(machineInstance statelang.StateMachineInstance, context core.ProcessContext) (*tm.GlobalTransaction, error) { | |||
globalTransaction, ok := context.GetVariable(constant.VarNameGlobalTx).(*tm.GlobalTransaction) | |||
if ok { | |||
return globalTransaction, nil | |||
} | |||
var xid string | |||
parentId := machineInstance.ParentID() | |||
if parentId == "" { | |||
xid = machineInstance.ID() | |||
} else { | |||
xid = parentId[:strings.LastIndex(parentId, constant.SeperatorParentId)] | |||
} | |||
globalTransaction = &tm.GlobalTransaction{ | |||
Xid: xid, | |||
TxStatus: message.GlobalStatusUnKnown, | |||
TxRole: tm.Launcher, | |||
} | |||
context.SetVariable(constant.VarNameGlobalTx, globalTransaction) | |||
return globalTransaction, nil | |||
} | |||
func (s *StateLogStore) RecordStateMachineRestarted(ctx context.Context, machineInstance statelang.StateMachineInstance, | |||
context core.ProcessContext) error { | |||
if machineInstance == nil { | |||
@@ -211,18 +333,25 @@ func (s *StateLogStore) RecordStateStarted(ctx context.Context, stateInstance st | |||
if stateInstance == nil { | |||
return nil | |||
} | |||
isUpdateMode := s.isUpdateMode(stateInstance, context) | |||
isUpdateMode, err := s.isUpdateMode(stateInstance, context) | |||
if err != nil { | |||
return err | |||
} | |||
// if this state is for retry, do not register branch | |||
if stateInstance.StateIDRetriedFor() != "" { | |||
if isUpdateMode { | |||
stateInstance.SetID(stateInstance.StateIDRetriedFor()) | |||
} else { | |||
// generate id by default | |||
stateInstance.SetID(s.generateRetryStateInstanceId(stateInstance)) | |||
} | |||
} else if stateInstance.StateIDCompensatedFor() != "" { | |||
// if this state is for compensation, do not register branch | |||
stateInstance.SetID(s.generateCompensateStateInstanceId(stateInstance, isUpdateMode)) | |||
} else { | |||
//TODO register branch | |||
// register branch | |||
s.branchRegister(stateInstance, context) | |||
} | |||
if stateInstance.ID() == "" && s.seqGenerator != nil { | |||
@@ -252,9 +381,45 @@ func (s *StateLogStore) RecordStateStarted(ctx context.Context, stateInstance st | |||
return nil | |||
} | |||
func (s *StateLogStore) isUpdateMode(instance statelang.StateInstance, context core.ProcessContext) bool { | |||
//TODO implement me, add forward logic | |||
return false | |||
func (s *StateLogStore) isUpdateMode(stateInstance statelang.StateInstance, context core.ProcessContext) (bool, error) { | |||
cfg, ok := context.GetVariable(constant.VarNameStateMachineConfig).(core.DefaultStateMachineConfig) | |||
if !ok { | |||
return false, errors.New("stateMachineConfig is required in context") | |||
} | |||
instruction, ok := context.GetInstruction().(*core.StateInstruction) | |||
if !ok { | |||
return false, errors.New("stateInstruction is required in processContext") | |||
} | |||
instructionState, err := instruction.GetState(context) | |||
if err != nil { | |||
return false, err | |||
} | |||
taskState, _ := instructionState.(*state.AbstractTaskState) | |||
stateMachine := stateInstance.StateMachineInstance().StateMachine() | |||
if stateInstance.StateIDRetriedFor() != "" { | |||
if taskState != nil && taskState.RetryPersistModeUpdate() { | |||
return taskState.RetryPersistModeUpdate(), nil | |||
} else if stateMachine.IsRetryPersistModeUpdate() { | |||
return stateMachine.IsRetryPersistModeUpdate(), nil | |||
} | |||
return cfg.IsSagaRetryPersistModeUpdate(), nil | |||
} else if stateInstance.StateIDCompensatedFor() != "" { | |||
// find if this compensate has been executed | |||
stateList := stateInstance.StateMachineInstance().StateList() | |||
for _, instance := range stateList { | |||
if instance.IsForCompensation() && instance.Name() == stateInstance.Name() { | |||
if taskState != nil && taskState.CompensatePersistModeUpdate() { | |||
return taskState.CompensatePersistModeUpdate(), nil | |||
} else if stateMachine.IsCompensatePersistModeUpdate() { | |||
return stateMachine.IsCompensatePersistModeUpdate(), nil | |||
} | |||
return cfg.IsSagaCompensatePersistModeUpdate(), nil | |||
} | |||
} | |||
} | |||
return false, nil | |||
} | |||
func (s *StateLogStore) generateRetryStateInstanceId(stateInstance statelang.StateInstance) string { | |||
@@ -293,6 +458,49 @@ func (s *StateLogStore) generateCompensateStateInstanceId(stateInstance statelan | |||
return fmt.Sprintf("%s-%d", originalCompensateStateInstId, maxIndex) | |||
} | |||
func (s *StateLogStore) branchRegister(stateInstance statelang.StateInstance, context core.ProcessContext) error { | |||
cfg, ok := context.GetVariable(constant.VarNameStateMachineConfig).(core.DefaultStateMachineConfig) | |||
if !ok { | |||
return errors.New("stateMachineConfig is required in context") | |||
} | |||
if !cfg.IsSagaBranchRegisterEnable() { | |||
log.Debugf("sagaBranchRegisterEnable = false, skip branch report. state[%s]", stateInstance.Name()) | |||
return nil | |||
} | |||
//Register branch | |||
var err error | |||
machineInstance := stateInstance.StateMachineInstance() | |||
defer func() { | |||
if err != nil { | |||
log.Errorf("Branch transaction failure. StateMachine: %s, XID: %s, State: %s, stateId: %s, err: %v", | |||
machineInstance.StateMachine().Name(), machineInstance.ID(), stateInstance.Name(), stateInstance.ID(), err) | |||
} | |||
}() | |||
globalTransaction, err := s.getGlobalTransaction(machineInstance, context) | |||
if err != nil { | |||
return err | |||
} | |||
if globalTransaction == nil { | |||
err = errors.New("Global transaction is not exists") | |||
return err | |||
} | |||
branchId, err := rm.GetRMRemotingInstance().BranchRegister(rm.BranchRegisterParam{ | |||
BranchType: branch.BranchTypeSAGA, | |||
ResourceId: machineInstance.StateMachine().Name() + "#" + stateInstance.Name(), | |||
Xid: globalTransaction.Xid, | |||
}) | |||
if err != nil { | |||
return err | |||
} | |||
stateInstance.SetID(strconv.FormatInt(branchId, 10)) | |||
return nil | |||
} | |||
func (s *StateLogStore) getIdIndex(stateInstanceId string, separator string) int { | |||
if stateInstanceId != "" { | |||
start := strings.LastIndex(stateInstanceId, separator) | |||
@@ -332,11 +540,124 @@ func (s *StateLogStore) RecordStateFinished(ctx context.Context, stateInstance s | |||
return err | |||
} | |||
//TODO report branch | |||
// A switch to skip branch report on branch success, in order to optimize performance | |||
cfg, ok := context.GetVariable(constant.VarNameStateMachineConfig).(core.DefaultStateMachineConfig) | |||
if !(ok && !cfg.IsRmReportSuccessEnable() && statelang.SU == stateInstance.Status()) { | |||
err = s.branchReport(stateInstance, context) | |||
return err | |||
} | |||
return nil | |||
} | |||
func (s *StateLogStore) branchReport(stateInstance statelang.StateInstance, context core.ProcessContext) error { | |||
cfg, ok := context.GetVariable(constant.VarNameStateMachineConfig).(core.DefaultStateMachineConfig) | |||
if ok && !cfg.IsSagaBranchRegisterEnable() { | |||
log.Debugf("sagaBranchRegisterEnable = false, skip branch report. state[%s]", stateInstance.Name()) | |||
return nil | |||
} | |||
var branchStatus branch.BranchStatus | |||
// find out the original state instance, only the original state instance is registered on the server, | |||
// and its status should be reported. | |||
var originalStateInst statelang.StateInstance | |||
if stateInstance.StateIDRetriedFor() != "" { | |||
isUpdateMode, err := s.isUpdateMode(stateInstance, context) | |||
if err != nil { | |||
return err | |||
} | |||
if isUpdateMode { | |||
originalStateInst = stateInstance | |||
} else { | |||
originalStateInst = s.findOutOriginalStateInstanceOfRetryState(stateInstance) | |||
} | |||
if statelang.SU == stateInstance.Status() { | |||
branchStatus = branch.BranchStatusPhasetwoCommitted | |||
} else if statelang.FA == stateInstance.Status() || statelang.UN == stateInstance.Status() { | |||
branchStatus = branch.BranchStatusPhaseoneFailed | |||
} else { | |||
branchStatus = branch.BranchStatusUnknown | |||
} | |||
} else if stateInstance.StateIDCompensatedFor() != "" { | |||
isUpdateMode, err := s.isUpdateMode(stateInstance, context) | |||
if err != nil { | |||
return err | |||
} | |||
if isUpdateMode { | |||
originalStateInst = stateInstance.StateMachineInstance().StateMap()[stateInstance.StateIDCompensatedFor()] | |||
} else { | |||
originalStateInst = s.findOutOriginalStateInstanceOfCompensateState(stateInstance) | |||
} | |||
} | |||
if originalStateInst == nil { | |||
originalStateInst = stateInstance | |||
} | |||
if branchStatus == branch.BranchStatusUnknown { | |||
if statelang.SU == originalStateInst.Status() && originalStateInst.CompensationStatus() == "" { | |||
branchStatus = branch.BranchStatusPhasetwoCommitted | |||
} else if statelang.SU == originalStateInst.CompensationStatus() { | |||
branchStatus = branch.BranchStatusPhasetwoRollbacked | |||
} else if statelang.FA == originalStateInst.CompensationStatus() || statelang.UN == originalStateInst.CompensationStatus() { | |||
branchStatus = branch.BranchStatusPhasetwoRollbackFailedRetryable | |||
} else if (statelang.FA == originalStateInst.Status() || statelang.UN == originalStateInst.Status()) && originalStateInst.CompensationStatus() == "" { | |||
branchStatus = branch.BranchStatusPhaseoneFailed | |||
} else { | |||
branchStatus = branch.BranchStatusUnknown | |||
} | |||
} | |||
var err error | |||
defer func() { | |||
if err != nil { | |||
log.Errorf("Report branch status to server error:%s, StateMachine:%s, StateName:%s, XID:%s, branchId:%s, branchStatus:%s, err:%v", | |||
err.Error(), originalStateInst.StateMachineInstance().StateMachine().Name(), originalStateInst.Name(), | |||
originalStateInst.StateMachineInstance().ID(), originalStateInst.ID(), branchStatus, err) | |||
} | |||
}() | |||
globalTransaction, err := s.getGlobalTransaction(stateInstance.StateMachineInstance(), context) | |||
if err != nil { | |||
return err | |||
} | |||
if globalTransaction == nil { | |||
err = errors.New("Global transaction is not exists") | |||
return err | |||
} | |||
branchId, err := strconv.ParseInt(originalStateInst.ID(), 10, 0) | |||
err = rm.GetRMRemotingInstance().BranchReport(rm.BranchReportParam{ | |||
BranchType: branch.BranchTypeSAGA, | |||
Xid: globalTransaction.Xid, | |||
BranchId: branchId, | |||
Status: branchStatus, | |||
}) | |||
return err | |||
} | |||
func (s *StateLogStore) findOutOriginalStateInstanceOfRetryState(stateInstance statelang.StateInstance) statelang.StateInstance { | |||
stateMap := stateInstance.StateMachineInstance().StateMap() | |||
originalStateInst := stateMap[stateInstance.StateIDRetriedFor()] | |||
for originalStateInst.StateIDRetriedFor() != "" { | |||
originalStateInst = stateMap[stateInstance.StateIDRetriedFor()] | |||
} | |||
return originalStateInst | |||
} | |||
func (s *StateLogStore) findOutOriginalStateInstanceOfCompensateState(stateInstance statelang.StateInstance) statelang.StateInstance { | |||
stateMap := stateInstance.StateMachineInstance().StateMap() | |||
originalStateInst := stateMap[stateInstance.StateIDCompensatedFor()] | |||
for originalStateInst.StateIDRetriedFor() != "" { | |||
originalStateInst = stateMap[stateInstance.StateIDRetriedFor()] | |||
} | |||
return originalStateInst | |||
} | |||
func (s *StateLogStore) GetStateMachineInstance(stateMachineInstanceId string) (statelang.StateMachineInstance, error) { | |||
stateMachineInstance, err := SelectOne(s.db, s.getStateMachineInstanceByIdSql, scanRowsToStateMachineInstance, | |||
stateMachineInstanceId) | |||
@@ -437,7 +758,7 @@ func (s *StateLogStore) GetStateInstanceListByMachineInstanceId(stateMachineInst | |||
if lastStateInstance.EndTime().IsZero() { | |||
lastStateInstance.SetStatus(statelang.RU) | |||
} | |||
//TODO add forward and compensate logic | |||
originStateMap := make(map[string]statelang.StateInstance) | |||
compensatedStateMap := make(map[string]statelang.StateInstance) | |||
retriedStateMap := make(map[string]statelang.StateInstance) | |||
@@ -522,6 +843,11 @@ func (s *StateLogStore) SetSeqGenerator(seqGenerator sequence.SeqGenerator) { | |||
s.seqGenerator = seqGenerator | |||
} | |||
func (s *StateLogStore) ClearUp(context core.ProcessContext) { | |||
context.RemoveVariable(constant2.XidKey) | |||
context.RemoveVariable(constant2.BranchTypeKey) | |||
} | |||
func execStateMachineInstanceStatementForInsert(obj statelang.StateMachineInstance, stmt *sql.Stmt) (int64, error) { | |||
result, err := stmt.Exec( | |||
obj.ID(), | |||
@@ -57,6 +57,12 @@ func mockMachineInstance(stateMachineName string) statelang.StateMachineInstance | |||
return inst | |||
} | |||
func mockStateMachineConfig(context core.ProcessContext) core.StateMachineConfig { | |||
cfg := core.NewDefaultStateMachineConfig() | |||
context.SetVariable(constant.VarNameStateMachineConfig, cfg) | |||
return cfg | |||
} | |||
func TestStateLogStore_RecordStateMachineStarted(t *testing.T) { | |||
prepareDB() | |||
@@ -65,6 +71,7 @@ func TestStateLogStore_RecordStateMachineStarted(t *testing.T) { | |||
expected := mockMachineInstance(stateMachineName) | |||
expected.SetBusinessKey("test_started") | |||
ctx := mockProcessContext(stateMachineName, expected) | |||
mockStateMachineConfig(ctx) | |||
err := stateLogStore.RecordStateMachineStarted(context.Background(), expected, ctx) | |||
assert.Nil(t, err) | |||
actual, err := stateLogStore.GetStateMachineInstance(expected.ID()) | |||
@@ -0,0 +1,18 @@ | |||
/* | |||
* Licensed to the Apache Software Foundation (ASF) under one or more | |||
* contributor license agreements. See the NOTICE file distributed with | |||
* this work for additional information regarding copyright ownership. | |||
* The ASF licenses this file to You under the Apache License, Version 2.0 | |||
* (the "License"); you may not use this file except in compliance with | |||
* the License. You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
package repository |
@@ -18,34 +18,221 @@ | |||
package repository | |||
import ( | |||
"github.com/seata/seata-go/pkg/saga/statemachine/statelang" | |||
"io" | |||
"sync" | |||
"time" | |||
"github.com/seata/seata-go/pkg/saga/statemachine/constant" | |||
"github.com/seata/seata-go/pkg/saga/statemachine/engine/sequence" | |||
"github.com/seata/seata-go/pkg/saga/statemachine/statelang" | |||
"github.com/seata/seata-go/pkg/saga/statemachine/statelang/parser" | |||
"github.com/seata/seata-go/pkg/saga/statemachine/store/db" | |||
"github.com/seata/seata-go/pkg/util/log" | |||
) | |||
const ( | |||
DefaultJsonParser = "fastjson" | |||
) | |||
var ( | |||
stateMachineRepositoryImpl *StateMachineRepositoryImpl | |||
onceStateMachineRepositoryImpl sync.Once | |||
) | |||
type StateMachineRepositoryImpl struct { | |||
stateMachineMapById map[string]statelang.StateMachine | |||
stateMachineMapByNameAndTenant map[string]statelang.StateMachine | |||
stateLangStore *db.StateLangStore | |||
seqGenerator sequence.SeqGenerator | |||
defaultTenantId string | |||
jsonParserName string | |||
charset string | |||
mutex *sync.Mutex | |||
} | |||
func GetStateMachineRepositoryImpl() *StateMachineRepositoryImpl { | |||
if stateMachineRepositoryImpl == nil { | |||
onceStateMachineRepositoryImpl.Do(func() { | |||
//TODO get charset by config | |||
//TODO charset is not use | |||
//TODO using json parser | |||
stateMachineRepositoryImpl = &StateMachineRepositoryImpl{ | |||
stateMachineMapById: make(map[string]statelang.StateMachine), | |||
stateMachineMapByNameAndTenant: make(map[string]statelang.StateMachine), | |||
seqGenerator: sequence.NewUUIDSeqGenerator(), | |||
jsonParserName: DefaultJsonParser, | |||
charset: "UTF-8", | |||
mutex: &sync.Mutex{}, | |||
} | |||
}) | |||
} | |||
return stateMachineRepositoryImpl | |||
} | |||
func (s *StateMachineRepositoryImpl) GetStateMachineById(stateMachineId string) (statelang.StateMachine, error) { | |||
stateMachine := s.stateMachineMapById[stateMachineId] | |||
if stateMachine == nil && s.stateLangStore != nil { | |||
s.mutex.Lock() | |||
defer s.mutex.Unlock() | |||
stateMachine = s.stateMachineMapById[stateMachineId] | |||
if stateMachine == nil { | |||
oldStateMachine, err := s.stateLangStore.GetStateMachineById(stateMachineId) | |||
if err != nil { | |||
return oldStateMachine, err | |||
} | |||
parseStatMachine, err := parser.NewJSONStateMachineParser().Parse(oldStateMachine.Content()) | |||
if err != nil { | |||
return oldStateMachine, err | |||
} | |||
oldStateMachine.SetStartState(parseStatMachine.StartState()) | |||
for key, val := range parseStatMachine.States() { | |||
oldStateMachine.States()[key] = val | |||
} | |||
s.stateMachineMapById[stateMachineId] = oldStateMachine | |||
s.stateMachineMapByNameAndTenant[oldStateMachine.Name()+"_"+oldStateMachine.TenantId()] = oldStateMachine | |||
return oldStateMachine, nil | |||
} | |||
} | |||
return stateMachine, nil | |||
} | |||
func (s *StateMachineRepositoryImpl) GetStateMachineByNameAndTenantId(stateMachineName string, tenantId string) (statelang.StateMachine, error) { | |||
return s.GetLastVersionStateMachine(stateMachineName, tenantId) | |||
} | |||
func (s *StateMachineRepositoryImpl) GetLastVersionStateMachine(stateMachineName string, tenantId string) (statelang.StateMachine, error) { | |||
key := stateMachineName + "_" + tenantId | |||
stateMachine := s.stateMachineMapByNameAndTenant[key] | |||
if stateMachine == nil && s.stateLangStore != nil { | |||
s.mutex.Lock() | |||
defer s.mutex.Unlock() | |||
stateMachine = s.stateMachineMapById[key] | |||
if stateMachine == nil { | |||
oldStateMachine, err := s.stateLangStore.GetLastVersionStateMachine(stateMachineName, tenantId) | |||
if err != nil { | |||
return oldStateMachine, err | |||
} | |||
parseStatMachine, err := parser.NewJSONStateMachineParser().Parse(oldStateMachine.Content()) | |||
if err != nil { | |||
return oldStateMachine, err | |||
} | |||
oldStateMachine.SetStartState(parseStatMachine.StartState()) | |||
for key, val := range parseStatMachine.States() { | |||
oldStateMachine.States()[key] = val | |||
} | |||
s.stateMachineMapById[oldStateMachine.ID()] = oldStateMachine | |||
s.stateMachineMapByNameAndTenant[key] = oldStateMachine | |||
return oldStateMachine, nil | |||
} | |||
} | |||
return stateMachine, nil | |||
} | |||
func (s *StateMachineRepositoryImpl) RegistryStateMachine(machine statelang.StateMachine) error { | |||
stateMachineName := machine.Name() | |||
tenantId := machine.TenantId() | |||
if s.stateLangStore != nil { | |||
oldStateMachine, err := s.stateLangStore.GetLastVersionStateMachine(stateMachineName, tenantId) | |||
if err != nil { | |||
return err | |||
} | |||
if oldStateMachine != nil { | |||
if oldStateMachine.Content() == machine.Content() && machine.Version() != "" && machine.Version() == oldStateMachine.Version() { | |||
log.Debugf("StateMachine[%s] is already exist a same version", stateMachineName) | |||
machine.SetID(oldStateMachine.ID()) | |||
machine.SetCreateTime(oldStateMachine.CreateTime()) | |||
s.stateMachineMapById[machine.ID()] = machine | |||
s.stateMachineMapByNameAndTenant[machine.Name()+"_"+machine.TenantId()] = machine | |||
return nil | |||
} | |||
} | |||
if machine.ID() == "" { | |||
machine.SetID(s.seqGenerator.GenerateId(constant.SeqEntityStateMachine, "")) | |||
} | |||
machine.SetCreateTime(time.Now()) | |||
err = s.stateLangStore.StoreStateMachine(machine) | |||
if err != nil { | |||
return err | |||
} | |||
} | |||
if machine.ID() == "" { | |||
machine.SetID(s.seqGenerator.GenerateId(constant.SeqEntityStateMachine, "")) | |||
} | |||
s.stateMachineMapById[machine.ID()] = machine | |||
s.stateMachineMapByNameAndTenant[machine.Name()+"_"+machine.TenantId()] = machine | |||
return nil | |||
} | |||
func (s *StateMachineRepositoryImpl) RegistryStateMachineByReader(reader io.Reader) error { | |||
jsonByte, err := io.ReadAll(reader) | |||
if err != nil { | |||
return err | |||
} | |||
json := string(jsonByte) | |||
parseStatMachine, err := parser.NewJSONStateMachineParser().Parse(json) | |||
if err != nil { | |||
return err | |||
} | |||
if parseStatMachine == nil { | |||
return nil | |||
} | |||
parseStatMachine.SetContent(json) | |||
s.RegistryStateMachine(parseStatMachine) | |||
log.Debugf("===== StateMachine Loaded: %s", json) | |||
return nil | |||
} | |||
func (s *StateMachineRepositoryImpl) SetStateLangStore(stateLangStore *db.StateLangStore) { | |||
s.stateLangStore = stateLangStore | |||
} | |||
func (s *StateMachineRepositoryImpl) SetSeqGenerator(seqGenerator sequence.SeqGenerator) { | |||
s.seqGenerator = seqGenerator | |||
} | |||
func (s *StateMachineRepositoryImpl) SetCharset(charset string) { | |||
s.charset = charset | |||
} | |||
func (s StateMachineRepositoryImpl) GetStateMachineById(stateMachineId string) (statelang.StateMachine, error) { | |||
//TODO implement me | |||
panic("implement me") | |||
func (s *StateMachineRepositoryImpl) GetCharset() string { | |||
return s.charset | |||
} | |||
func (s StateMachineRepositoryImpl) GetStateMachineByNameAndTenantId(stateMachineName string, tenantId string) (statelang.StateMachine, error) { | |||
//TODO implement me | |||
panic("implement me") | |||
func (s *StateMachineRepositoryImpl) SetDefaultTenantId(defaultTenantId string) { | |||
s.defaultTenantId = defaultTenantId | |||
} | |||
func (s StateMachineRepositoryImpl) GetLastVersionStateMachine(stateMachineName string, tenantId string) (statelang.StateMachine, error) { | |||
//TODO implement me | |||
panic("implement me") | |||
func (s *StateMachineRepositoryImpl) GetDefaultTenantId() string { | |||
return s.defaultTenantId | |||
} | |||
func (s StateMachineRepositoryImpl) RegistryStateMachine(machine statelang.StateMachine) error { | |||
//TODO implement me | |||
panic("implement me") | |||
func (s *StateMachineRepositoryImpl) SetJsonParserName(jsonParserName string) { | |||
s.jsonParserName = jsonParserName | |||
} | |||
func (s StateMachineRepositoryImpl) RegistryStateMachineByReader(reader io.Reader) error { | |||
//TODO implement me | |||
panic("implement me") | |||
func (s *StateMachineRepositoryImpl) GetJsonParserName() string { | |||
return s.jsonParserName | |||
} |
@@ -0,0 +1,118 @@ | |||
/* | |||
* Licensed to the Apache Software Foundation (ASF) under one or more | |||
* contributor license agreements. See the NOTICE file distributed with | |||
* this work for additional information regarding copyright ownership. | |||
* The ASF licenses this file to You under the Apache License, Version 2.0 | |||
* (the "License"); you may not use this file except in compliance with | |||
* the License. You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
package repository | |||
import ( | |||
"database/sql" | |||
"github.com/seata/seata-go/pkg/saga/statemachine/statelang/parser" | |||
"os" | |||
"sync" | |||
"testing" | |||
"time" | |||
_ "github.com/mattn/go-sqlite3" | |||
"github.com/stretchr/testify/assert" | |||
"github.com/seata/seata-go/pkg/saga/statemachine/statelang" | |||
"github.com/seata/seata-go/pkg/saga/statemachine/store/db" | |||
) | |||
var ( | |||
oncePrepareDB sync.Once | |||
testdb *sql.DB | |||
) | |||
func prepareDB() { | |||
oncePrepareDB.Do(func() { | |||
var err error | |||
testdb, err = sql.Open("sqlite3", ":memory:") | |||
query_, err := os.ReadFile("../../../../../testdata/sql/saga/sqlite_init.sql") | |||
initScript := string(query_) | |||
if err != nil { | |||
panic(err) | |||
} | |||
if _, err := testdb.Exec(initScript); err != nil { | |||
panic(err) | |||
} | |||
}) | |||
} | |||
func loadStateMachineByYaml() string { | |||
query, _ := os.ReadFile("../../../../../testdata/saga/statelang/simple_statemachine.json") | |||
return string(query) | |||
} | |||
func TestStateMachineInMemory(t *testing.T) { | |||
const stateMachineId, stateMachineName, tenantId = "simpleStateMachine", "simpleStateMachine", "test" | |||
stateMachine := statelang.NewStateMachineImpl() | |||
stateMachine.SetID(stateMachineId) | |||
stateMachine.SetName(stateMachineName) | |||
stateMachine.SetTenantId(tenantId) | |||
stateMachine.SetComment("This is a test state machine") | |||
stateMachine.SetCreateTime(time.Now()) | |||
repository := GetStateMachineRepositoryImpl() | |||
err := repository.RegistryStateMachine(stateMachine) | |||
assert.Nil(t, err) | |||
machineById, err := repository.GetStateMachineById(stateMachine.ID()) | |||
assert.Nil(t, err) | |||
assert.Equal(t, stateMachine.Name(), machineById.Name()) | |||
assert.Equal(t, stateMachine.TenantId(), machineById.TenantId()) | |||
assert.Equal(t, stateMachine.Comment(), machineById.Comment()) | |||
assert.Equal(t, stateMachine.CreateTime().UnixNano(), machineById.CreateTime().UnixNano()) | |||
machineByNameAndTenantId, err := repository.GetLastVersionStateMachine(stateMachine.Name(), stateMachine.TenantId()) | |||
assert.Nil(t, err) | |||
assert.Equal(t, stateMachine.ID(), machineByNameAndTenantId.ID()) | |||
assert.Equal(t, stateMachine.Comment(), machineById.Comment()) | |||
assert.Equal(t, stateMachine.CreateTime().UnixNano(), machineById.CreateTime().UnixNano()) | |||
} | |||
func TestStateMachineInDb(t *testing.T) { | |||
prepareDB() | |||
const tenantId = "test" | |||
yaml := loadStateMachineByYaml() | |||
stateMachine, err := parser.NewJSONStateMachineParser().Parse(yaml) | |||
assert.Nil(t, err) | |||
stateMachine.SetTenantId(tenantId) | |||
stateMachine.SetContent(yaml) | |||
repository := GetStateMachineRepositoryImpl() | |||
repository.SetStateLangStore(db.NewStateLangStore(testdb, "seata_")) | |||
err = repository.RegistryStateMachine(stateMachine) | |||
assert.Nil(t, err) | |||
repository.stateMachineMapById[stateMachine.ID()] = nil | |||
machineById, err := repository.GetStateMachineById(stateMachine.ID()) | |||
assert.Nil(t, err) | |||
assert.Equal(t, stateMachine.Name(), machineById.Name()) | |||
assert.Equal(t, stateMachine.TenantId(), machineById.TenantId()) | |||
assert.Equal(t, stateMachine.Comment(), machineById.Comment()) | |||
assert.Equal(t, stateMachine.CreateTime().UnixNano(), machineById.CreateTime().UnixNano()) | |||
repository.stateMachineMapByNameAndTenant[stateMachine.Name()+"_"+stateMachine.TenantId()] = nil | |||
machineByNameAndTenantId, err := repository.GetLastVersionStateMachine(stateMachine.Name(), stateMachine.TenantId()) | |||
assert.Nil(t, err) | |||
assert.Equal(t, stateMachine.ID(), machineByNameAndTenantId.ID()) | |||
assert.Equal(t, stateMachine.Comment(), machineById.Comment()) | |||
assert.Equal(t, stateMachine.CreateTime().UnixNano(), machineById.CreateTime().UnixNano()) | |||
} |
@@ -151,3 +151,28 @@ func (g *GlobalTransactionManager) Rollback(ctx context.Context, gtr *GlobalTran | |||
return nil | |||
} | |||
// GlobalReport Global report. | |||
func (g *GlobalTransactionManager) GlobalReport(ctx context.Context, gtr *GlobalTransaction) (message.GlobalStatus, error) { | |||
if gtr.Xid == "" { | |||
return message.GlobalStatusUnKnown, fmt.Errorf("GlobalReport xid should not be empty") | |||
} | |||
req := message.GlobalReportRequest{ | |||
AbstractGlobalEndRequest: message.AbstractGlobalEndRequest{ | |||
Xid: gtr.Xid, | |||
}, | |||
GlobalStatus: gtr.TxStatus, | |||
} | |||
res, err := getty.GetGettyRemotingClient().SendSyncRequest(req) | |||
if err != nil { | |||
log.Errorf("GlobalBeginRequest error %v", err) | |||
return message.GlobalStatusUnKnown, err | |||
} | |||
if res == nil || res.(message.GlobalReportResponse).ResultCode == message.ResultCodeFailed { | |||
log.Errorf("GlobalReportRequest result is empty or result code is failed, res %v", res) | |||
return message.GlobalStatusUnKnown, fmt.Errorf("GlobalReportRequest result is empty or result code is failed.") | |||
} | |||
log.Infof("GlobalReportRequest success, res %v", res) | |||
return res.(message.GlobalReportResponse).GlobalStatus, nil | |||
} |