* implement state machine repository * temporary recording * temporary recording * temporary storage * Supplementary testpull/808/head
| @@ -67,6 +67,7 @@ const ( | |||
| // TODO: this lock in process context only has one, try to add more to add concurrent | |||
| VarNameProcessContextMutexLock string = "_current_context_mutex_lock" | |||
| VarNameFailEndStateFlag string = "_fail_end_state_flag_" | |||
| VarNameGlobalTx string = "_global_transaction_" | |||
| // end region | |||
| // region of loop | |||
| @@ -79,10 +80,12 @@ const ( | |||
| // end region | |||
| // region others | |||
| SeqEntityStateMachine string = "STATE_MACHINE" | |||
| SeqEntityStateMachineInst string = "STATE_MACHINE_INST" | |||
| SeqEntityStateInst string = "STATE_INST" | |||
| OperationNameForward string = "forward" | |||
| LoopStateNamePattern string = "-loop-" | |||
| SagaTransNamePrefix string = "$Saga_" | |||
| // end region | |||
| SeperatorParentId string = ":" | |||
| @@ -25,16 +25,24 @@ import ( | |||
| ) | |||
| const ( | |||
| DefaultTransOperTimeout = 60000 * 30 | |||
| DefaultServiceInvokeTimeout = 60000 * 5 | |||
| DefaultTransOperTimeout = 60000 * 30 | |||
| DefaultServiceInvokeTimeout = 60000 * 5 | |||
| DefaultClientSagaRetryPersistModeUpdate = false | |||
| DefaultClientSagaCompensatePersistModeUpdate = false | |||
| DefaultClientReportSuccessEnable = false | |||
| DefaultClientSagaBranchRegisterEnable = true | |||
| ) | |||
| type DefaultStateMachineConfig struct { | |||
| // Configuration | |||
| transOperationTimeout int | |||
| serviceInvokeTimeout int | |||
| charset string | |||
| defaultTenantId string | |||
| transOperationTimeout int | |||
| serviceInvokeTimeout int | |||
| charset string | |||
| defaultTenantId string | |||
| sagaRetryPersistModeUpdate bool | |||
| sagaCompensatePersistModeUpdate bool | |||
| sagaBranchRegisterEnable bool | |||
| rmReportSuccessEnable bool | |||
| // Components | |||
| @@ -202,13 +210,49 @@ func (c *DefaultStateMachineConfig) ServiceInvokeTimeout() int { | |||
| return c.serviceInvokeTimeout | |||
| } | |||
| func (c *DefaultStateMachineConfig) IsSagaRetryPersistModeUpdate() bool { | |||
| return c.sagaRetryPersistModeUpdate | |||
| } | |||
| func (c *DefaultStateMachineConfig) SetSagaRetryPersistModeUpdate(sagaRetryPersistModeUpdate bool) { | |||
| c.sagaRetryPersistModeUpdate = sagaRetryPersistModeUpdate | |||
| } | |||
| func (c *DefaultStateMachineConfig) IsSagaCompensatePersistModeUpdate() bool { | |||
| return c.sagaCompensatePersistModeUpdate | |||
| } | |||
| func (c *DefaultStateMachineConfig) SetSagaCompensatePersistModeUpdate(sagaCompensatePersistModeUpdate bool) { | |||
| c.sagaCompensatePersistModeUpdate = sagaCompensatePersistModeUpdate | |||
| } | |||
| func (c *DefaultStateMachineConfig) IsSagaBranchRegisterEnable() bool { | |||
| return c.sagaBranchRegisterEnable | |||
| } | |||
| func (c *DefaultStateMachineConfig) SetSagaBranchRegisterEnable(sagaBranchRegisterEnable bool) { | |||
| c.sagaBranchRegisterEnable = sagaBranchRegisterEnable | |||
| } | |||
| func (c *DefaultStateMachineConfig) IsRmReportSuccessEnable() bool { | |||
| return c.rmReportSuccessEnable | |||
| } | |||
| func (c *DefaultStateMachineConfig) SetRmReportSuccessEnable(rmReportSuccessEnable bool) { | |||
| c.rmReportSuccessEnable = rmReportSuccessEnable | |||
| } | |||
| func NewDefaultStateMachineConfig() *DefaultStateMachineConfig { | |||
| c := &DefaultStateMachineConfig{ | |||
| transOperationTimeout: DefaultTransOperTimeout, | |||
| serviceInvokeTimeout: DefaultServiceInvokeTimeout, | |||
| charset: "UTF-8", | |||
| defaultTenantId: "000001", | |||
| componentLock: &sync.Mutex{}, | |||
| transOperationTimeout: DefaultTransOperTimeout, | |||
| serviceInvokeTimeout: DefaultServiceInvokeTimeout, | |||
| charset: "UTF-8", | |||
| defaultTenantId: "000001", | |||
| sagaRetryPersistModeUpdate: DefaultClientSagaRetryPersistModeUpdate, | |||
| sagaCompensatePersistModeUpdate: DefaultClientSagaCompensatePersistModeUpdate, | |||
| sagaBranchRegisterEnable: DefaultClientSagaBranchRegisterEnable, | |||
| rmReportSuccessEnable: DefaultClientReportSuccessEnable, | |||
| componentLock: &sync.Mutex{}, | |||
| } | |||
| // TODO: init config | |||
| @@ -55,6 +55,8 @@ type StateLogStore interface { | |||
| GetStateInstance(stateInstanceId string, stateMachineInstanceId string) (statelang.StateInstance, error) | |||
| GetStateInstanceListByMachineInstanceId(stateMachineInstanceId string) ([]statelang.StateInstance, error) | |||
| ClearUp(context ProcessContext) | |||
| } | |||
| type StateMachineRepository interface { | |||
| @@ -72,6 +72,10 @@ type StateMachineInstance interface { | |||
| SetStatus(status ExecutionStatus) | |||
| StateMap() map[string]StateInstance | |||
| SetStateMap(stateMap map[string]StateInstance) | |||
| CompensationStatus() ExecutionStatus | |||
| SetCompensationStatus(compensationStatus ExecutionStatus) | |||
| @@ -234,6 +238,14 @@ func (s *StateMachineInstanceImpl) SetStatus(status ExecutionStatus) { | |||
| s.status = status | |||
| } | |||
| func (s *StateMachineInstanceImpl) StateMap() map[string]StateInstance { | |||
| return s.stateMap | |||
| } | |||
| func (s *StateMachineInstanceImpl) SetStateMap(stateMap map[string]StateInstance) { | |||
| s.stateMap = stateMap | |||
| } | |||
| func (s *StateMachineInstanceImpl) CompensationStatus() ExecutionStatus { | |||
| return s.compensationStatus | |||
| } | |||
| @@ -104,7 +104,7 @@ func scanRowsToStateMachine(rows *sql.Rows) (statelang.StateMachine, error) { | |||
| if recoverStrategy != "" { | |||
| stateMachine.SetRecoverStrategy(statelang.RecoverStrategy(recoverStrategy)) | |||
| } | |||
| stateMachine.SetTenantId(t) | |||
| stateMachine.SetTenantId(tenantId) | |||
| stateMachine.SetStatus(statelang.StateMachineStatus(status)) | |||
| return stateMachine, nil | |||
| } | |||
| @@ -21,17 +21,24 @@ import ( | |||
| "context" | |||
| "database/sql" | |||
| "fmt" | |||
| "regexp" | |||
| "strconv" | |||
| "strings" | |||
| "time" | |||
| "github.com/pkg/errors" | |||
| constant2 "github.com/seata/seata-go/pkg/constant" | |||
| "github.com/seata/seata-go/pkg/protocol/branch" | |||
| "github.com/seata/seata-go/pkg/protocol/message" | |||
| "github.com/seata/seata-go/pkg/rm" | |||
| "github.com/seata/seata-go/pkg/saga/statemachine/constant" | |||
| "github.com/seata/seata-go/pkg/saga/statemachine/engine/core" | |||
| "github.com/seata/seata-go/pkg/saga/statemachine/engine/sequence" | |||
| "github.com/seata/seata-go/pkg/saga/statemachine/engine/serializer" | |||
| "github.com/seata/seata-go/pkg/saga/statemachine/statelang" | |||
| "github.com/seata/seata-go/pkg/saga/statemachine/statelang/state" | |||
| "github.com/seata/seata-go/pkg/tm" | |||
| "github.com/seata/seata-go/pkg/util/log" | |||
| "regexp" | |||
| "strconv" | |||
| "strings" | |||
| "time" | |||
| ) | |||
| const ( | |||
| @@ -124,14 +131,19 @@ func (s *StateLogStore) RecordStateMachineStarted(ctx context.Context, machineIn | |||
| parentId := machineInstance.ParentID() | |||
| if parentId == "" { | |||
| //TODO begin transaction | |||
| // begin transaction | |||
| err = s.beginTransaction(ctx, machineInstance, context) | |||
| if err != nil { | |||
| return err | |||
| } | |||
| } | |||
| if machineInstance.ID() == "" && s.seqGenerator != nil { | |||
| machineInstance.SetID(s.seqGenerator.GenerateId(constant.SeqEntityStateMachineInst, "")) | |||
| } | |||
| //TODO bind SAGA branch type | |||
| // bind SAGA branch type | |||
| context.SetVariable(constant2.BranchTypeKey, branch.BranchTypeSAGA) | |||
| serializedStartParams, err := s.paramsSerializer.Serialize(machineInstance.StartParams()) | |||
| if err != nil { | |||
| @@ -149,14 +161,48 @@ func (s *StateLogStore) RecordStateMachineStarted(ctx context.Context, machineIn | |||
| return nil | |||
| } | |||
| func (s *StateLogStore) beginTransaction(ctx context.Context, machineInstance statelang.StateMachineInstance, context core.ProcessContext) error { | |||
| cfg, ok := context.GetVariable(constant.VarNameStateMachineConfig).(core.StateMachineConfig) | |||
| if !ok { | |||
| return errors.New("begin transaction fail, stateMachineConfig is required in context") | |||
| } | |||
| defer func() { | |||
| isAsync, ok := context.GetVariable(constant.VarNameIsAsyncExecution).(bool) | |||
| if ok && isAsync { | |||
| s.ClearUp(context) | |||
| } | |||
| }() | |||
| tm.SetTxRole(ctx, tm.Launcher) | |||
| tm.SetTxStatus(ctx, message.GlobalStatusUnKnown) | |||
| tm.SetTxName(ctx, constant.SagaTransNamePrefix+machineInstance.StateMachine().Name()) | |||
| err := tm.GetGlobalTransactionManager().Begin(ctx, time.Duration(cfg.TransOperationTimeout())) | |||
| if err != nil { | |||
| return err | |||
| } | |||
| machineInstance.SetID(tm.GetXID(ctx)) | |||
| return nil | |||
| } | |||
| func (s *StateLogStore) RecordStateMachineFinished(ctx context.Context, machineInstance statelang.StateMachineInstance, | |||
| context core.ProcessContext) error { | |||
| if machineInstance == nil { | |||
| return nil | |||
| } | |||
| defer func() { | |||
| s.ClearUp(context) | |||
| }() | |||
| endParams := machineInstance.EndParams() | |||
| if endParams != nil { | |||
| delete(endParams, constant.VarNameGlobalTx) | |||
| } | |||
| // if success, clear exception | |||
| if statelang.SU == machineInstance.Status() && machineInstance.Exception() != nil { | |||
| machineInstance.SetException(nil) | |||
| } | |||
| @@ -183,10 +229,86 @@ func (s *StateLogStore) RecordStateMachineFinished(ctx context.Context, machineI | |||
| return nil | |||
| } | |||
| //TODO check if timeout or else report transaction finished | |||
| // check if timeout or else report transaction finished | |||
| cfg, ok := context.GetVariable(constant.VarNameStateMachineConfig).(core.StateMachineConfig) | |||
| if !ok { | |||
| return errors.New("stateMachineConfig is required in context") | |||
| } | |||
| if core.IsTimeout(machineInstance.UpdatedTime(), cfg.TransOperationTimeout()) { | |||
| log.Warnf("StateMachineInstance[%s] is execution timeout, skip report transaction finished to server.", machineInstance.ID()) | |||
| } else if machineInstance.ParentID() == "" { | |||
| //if parentId is not null, machineInstance is a SubStateMachine, do not report global transaction. | |||
| err = s.reportTransactionFinished(ctx, machineInstance, context) | |||
| if err != nil { | |||
| return err | |||
| } | |||
| } | |||
| return nil | |||
| } | |||
| func (s *StateLogStore) reportTransactionFinished(ctx context.Context, machineInstance statelang.StateMachineInstance, context core.ProcessContext) error { | |||
| var err error | |||
| defer func() { | |||
| s.ClearUp(context) | |||
| if err != nil { | |||
| log.Errorf("Report transaction finish to server error: %v, StateMachine: %s, XID: %s, Reason: %s", | |||
| err, machineInstance.StateMachine().Name(), machineInstance.ID(), err.Error()) | |||
| } | |||
| }() | |||
| globalTransaction, err := s.getGlobalTransaction(machineInstance, context) | |||
| if err != nil { | |||
| log.Errorf("Failed to get global transaction: %v", err) | |||
| return err | |||
| } | |||
| var globalStatus message.GlobalStatus | |||
| if statelang.SU == machineInstance.Status() && machineInstance.CompensationStatus() == "" { | |||
| globalStatus = message.GlobalStatusCommitted | |||
| } else if statelang.SU == machineInstance.CompensationStatus() { | |||
| globalStatus = message.GlobalStatusRollbacked | |||
| } else if statelang.FA == machineInstance.CompensationStatus() || statelang.UN == machineInstance.CompensationStatus() { | |||
| globalStatus = message.GlobalStatusRollbackRetrying | |||
| } else if statelang.FA == machineInstance.Status() && machineInstance.CompensationStatus() == "" { | |||
| globalStatus = message.GlobalStatusFinished | |||
| } else if statelang.UN == machineInstance.Status() && machineInstance.CompensationStatus() == "" { | |||
| globalStatus = message.GlobalStatusCommitRetrying | |||
| } else { | |||
| globalStatus = message.GlobalStatusUnKnown | |||
| } | |||
| globalTransaction.TxStatus = globalStatus | |||
| _, err = tm.GetGlobalTransactionManager().GlobalReport(ctx, globalTransaction) | |||
| if err != nil { | |||
| return err | |||
| } | |||
| return nil | |||
| } | |||
| func (s *StateLogStore) getGlobalTransaction(machineInstance statelang.StateMachineInstance, context core.ProcessContext) (*tm.GlobalTransaction, error) { | |||
| globalTransaction, ok := context.GetVariable(constant.VarNameGlobalTx).(*tm.GlobalTransaction) | |||
| if ok { | |||
| return globalTransaction, nil | |||
| } | |||
| var xid string | |||
| parentId := machineInstance.ParentID() | |||
| if parentId == "" { | |||
| xid = machineInstance.ID() | |||
| } else { | |||
| xid = parentId[:strings.LastIndex(parentId, constant.SeperatorParentId)] | |||
| } | |||
| globalTransaction = &tm.GlobalTransaction{ | |||
| Xid: xid, | |||
| TxStatus: message.GlobalStatusUnKnown, | |||
| TxRole: tm.Launcher, | |||
| } | |||
| context.SetVariable(constant.VarNameGlobalTx, globalTransaction) | |||
| return globalTransaction, nil | |||
| } | |||
| func (s *StateLogStore) RecordStateMachineRestarted(ctx context.Context, machineInstance statelang.StateMachineInstance, | |||
| context core.ProcessContext) error { | |||
| if machineInstance == nil { | |||
| @@ -211,18 +333,25 @@ func (s *StateLogStore) RecordStateStarted(ctx context.Context, stateInstance st | |||
| if stateInstance == nil { | |||
| return nil | |||
| } | |||
| isUpdateMode := s.isUpdateMode(stateInstance, context) | |||
| isUpdateMode, err := s.isUpdateMode(stateInstance, context) | |||
| if err != nil { | |||
| return err | |||
| } | |||
| // if this state is for retry, do not register branch | |||
| if stateInstance.StateIDRetriedFor() != "" { | |||
| if isUpdateMode { | |||
| stateInstance.SetID(stateInstance.StateIDRetriedFor()) | |||
| } else { | |||
| // generate id by default | |||
| stateInstance.SetID(s.generateRetryStateInstanceId(stateInstance)) | |||
| } | |||
| } else if stateInstance.StateIDCompensatedFor() != "" { | |||
| // if this state is for compensation, do not register branch | |||
| stateInstance.SetID(s.generateCompensateStateInstanceId(stateInstance, isUpdateMode)) | |||
| } else { | |||
| //TODO register branch | |||
| // register branch | |||
| s.branchRegister(stateInstance, context) | |||
| } | |||
| if stateInstance.ID() == "" && s.seqGenerator != nil { | |||
| @@ -252,9 +381,45 @@ func (s *StateLogStore) RecordStateStarted(ctx context.Context, stateInstance st | |||
| return nil | |||
| } | |||
| func (s *StateLogStore) isUpdateMode(instance statelang.StateInstance, context core.ProcessContext) bool { | |||
| //TODO implement me, add forward logic | |||
| return false | |||
| func (s *StateLogStore) isUpdateMode(stateInstance statelang.StateInstance, context core.ProcessContext) (bool, error) { | |||
| cfg, ok := context.GetVariable(constant.VarNameStateMachineConfig).(core.DefaultStateMachineConfig) | |||
| if !ok { | |||
| return false, errors.New("stateMachineConfig is required in context") | |||
| } | |||
| instruction, ok := context.GetInstruction().(*core.StateInstruction) | |||
| if !ok { | |||
| return false, errors.New("stateInstruction is required in processContext") | |||
| } | |||
| instructionState, err := instruction.GetState(context) | |||
| if err != nil { | |||
| return false, err | |||
| } | |||
| taskState, _ := instructionState.(*state.AbstractTaskState) | |||
| stateMachine := stateInstance.StateMachineInstance().StateMachine() | |||
| if stateInstance.StateIDRetriedFor() != "" { | |||
| if taskState != nil && taskState.RetryPersistModeUpdate() { | |||
| return taskState.RetryPersistModeUpdate(), nil | |||
| } else if stateMachine.IsRetryPersistModeUpdate() { | |||
| return stateMachine.IsRetryPersistModeUpdate(), nil | |||
| } | |||
| return cfg.IsSagaRetryPersistModeUpdate(), nil | |||
| } else if stateInstance.StateIDCompensatedFor() != "" { | |||
| // find if this compensate has been executed | |||
| stateList := stateInstance.StateMachineInstance().StateList() | |||
| for _, instance := range stateList { | |||
| if instance.IsForCompensation() && instance.Name() == stateInstance.Name() { | |||
| if taskState != nil && taskState.CompensatePersistModeUpdate() { | |||
| return taskState.CompensatePersistModeUpdate(), nil | |||
| } else if stateMachine.IsCompensatePersistModeUpdate() { | |||
| return stateMachine.IsCompensatePersistModeUpdate(), nil | |||
| } | |||
| return cfg.IsSagaCompensatePersistModeUpdate(), nil | |||
| } | |||
| } | |||
| } | |||
| return false, nil | |||
| } | |||
| func (s *StateLogStore) generateRetryStateInstanceId(stateInstance statelang.StateInstance) string { | |||
| @@ -293,6 +458,49 @@ func (s *StateLogStore) generateCompensateStateInstanceId(stateInstance statelan | |||
| return fmt.Sprintf("%s-%d", originalCompensateStateInstId, maxIndex) | |||
| } | |||
| func (s *StateLogStore) branchRegister(stateInstance statelang.StateInstance, context core.ProcessContext) error { | |||
| cfg, ok := context.GetVariable(constant.VarNameStateMachineConfig).(core.DefaultStateMachineConfig) | |||
| if !ok { | |||
| return errors.New("stateMachineConfig is required in context") | |||
| } | |||
| if !cfg.IsSagaBranchRegisterEnable() { | |||
| log.Debugf("sagaBranchRegisterEnable = false, skip branch report. state[%s]", stateInstance.Name()) | |||
| return nil | |||
| } | |||
| //Register branch | |||
| var err error | |||
| machineInstance := stateInstance.StateMachineInstance() | |||
| defer func() { | |||
| if err != nil { | |||
| log.Errorf("Branch transaction failure. StateMachine: %s, XID: %s, State: %s, stateId: %s, err: %v", | |||
| machineInstance.StateMachine().Name(), machineInstance.ID(), stateInstance.Name(), stateInstance.ID(), err) | |||
| } | |||
| }() | |||
| globalTransaction, err := s.getGlobalTransaction(machineInstance, context) | |||
| if err != nil { | |||
| return err | |||
| } | |||
| if globalTransaction == nil { | |||
| err = errors.New("Global transaction is not exists") | |||
| return err | |||
| } | |||
| branchId, err := rm.GetRMRemotingInstance().BranchRegister(rm.BranchRegisterParam{ | |||
| BranchType: branch.BranchTypeSAGA, | |||
| ResourceId: machineInstance.StateMachine().Name() + "#" + stateInstance.Name(), | |||
| Xid: globalTransaction.Xid, | |||
| }) | |||
| if err != nil { | |||
| return err | |||
| } | |||
| stateInstance.SetID(strconv.FormatInt(branchId, 10)) | |||
| return nil | |||
| } | |||
| func (s *StateLogStore) getIdIndex(stateInstanceId string, separator string) int { | |||
| if stateInstanceId != "" { | |||
| start := strings.LastIndex(stateInstanceId, separator) | |||
| @@ -332,11 +540,124 @@ func (s *StateLogStore) RecordStateFinished(ctx context.Context, stateInstance s | |||
| return err | |||
| } | |||
| //TODO report branch | |||
| // A switch to skip branch report on branch success, in order to optimize performance | |||
| cfg, ok := context.GetVariable(constant.VarNameStateMachineConfig).(core.DefaultStateMachineConfig) | |||
| if !(ok && !cfg.IsRmReportSuccessEnable() && statelang.SU == stateInstance.Status()) { | |||
| err = s.branchReport(stateInstance, context) | |||
| return err | |||
| } | |||
| return nil | |||
| } | |||
| func (s *StateLogStore) branchReport(stateInstance statelang.StateInstance, context core.ProcessContext) error { | |||
| cfg, ok := context.GetVariable(constant.VarNameStateMachineConfig).(core.DefaultStateMachineConfig) | |||
| if ok && !cfg.IsSagaBranchRegisterEnable() { | |||
| log.Debugf("sagaBranchRegisterEnable = false, skip branch report. state[%s]", stateInstance.Name()) | |||
| return nil | |||
| } | |||
| var branchStatus branch.BranchStatus | |||
| // find out the original state instance, only the original state instance is registered on the server, | |||
| // and its status should be reported. | |||
| var originalStateInst statelang.StateInstance | |||
| if stateInstance.StateIDRetriedFor() != "" { | |||
| isUpdateMode, err := s.isUpdateMode(stateInstance, context) | |||
| if err != nil { | |||
| return err | |||
| } | |||
| if isUpdateMode { | |||
| originalStateInst = stateInstance | |||
| } else { | |||
| originalStateInst = s.findOutOriginalStateInstanceOfRetryState(stateInstance) | |||
| } | |||
| if statelang.SU == stateInstance.Status() { | |||
| branchStatus = branch.BranchStatusPhasetwoCommitted | |||
| } else if statelang.FA == stateInstance.Status() || statelang.UN == stateInstance.Status() { | |||
| branchStatus = branch.BranchStatusPhaseoneFailed | |||
| } else { | |||
| branchStatus = branch.BranchStatusUnknown | |||
| } | |||
| } else if stateInstance.StateIDCompensatedFor() != "" { | |||
| isUpdateMode, err := s.isUpdateMode(stateInstance, context) | |||
| if err != nil { | |||
| return err | |||
| } | |||
| if isUpdateMode { | |||
| originalStateInst = stateInstance.StateMachineInstance().StateMap()[stateInstance.StateIDCompensatedFor()] | |||
| } else { | |||
| originalStateInst = s.findOutOriginalStateInstanceOfCompensateState(stateInstance) | |||
| } | |||
| } | |||
| if originalStateInst == nil { | |||
| originalStateInst = stateInstance | |||
| } | |||
| if branchStatus == branch.BranchStatusUnknown { | |||
| if statelang.SU == originalStateInst.Status() && originalStateInst.CompensationStatus() == "" { | |||
| branchStatus = branch.BranchStatusPhasetwoCommitted | |||
| } else if statelang.SU == originalStateInst.CompensationStatus() { | |||
| branchStatus = branch.BranchStatusPhasetwoRollbacked | |||
| } else if statelang.FA == originalStateInst.CompensationStatus() || statelang.UN == originalStateInst.CompensationStatus() { | |||
| branchStatus = branch.BranchStatusPhasetwoRollbackFailedRetryable | |||
| } else if (statelang.FA == originalStateInst.Status() || statelang.UN == originalStateInst.Status()) && originalStateInst.CompensationStatus() == "" { | |||
| branchStatus = branch.BranchStatusPhaseoneFailed | |||
| } else { | |||
| branchStatus = branch.BranchStatusUnknown | |||
| } | |||
| } | |||
| var err error | |||
| defer func() { | |||
| if err != nil { | |||
| log.Errorf("Report branch status to server error:%s, StateMachine:%s, StateName:%s, XID:%s, branchId:%s, branchStatus:%s, err:%v", | |||
| err.Error(), originalStateInst.StateMachineInstance().StateMachine().Name(), originalStateInst.Name(), | |||
| originalStateInst.StateMachineInstance().ID(), originalStateInst.ID(), branchStatus, err) | |||
| } | |||
| }() | |||
| globalTransaction, err := s.getGlobalTransaction(stateInstance.StateMachineInstance(), context) | |||
| if err != nil { | |||
| return err | |||
| } | |||
| if globalTransaction == nil { | |||
| err = errors.New("Global transaction is not exists") | |||
| return err | |||
| } | |||
| branchId, err := strconv.ParseInt(originalStateInst.ID(), 10, 0) | |||
| err = rm.GetRMRemotingInstance().BranchReport(rm.BranchReportParam{ | |||
| BranchType: branch.BranchTypeSAGA, | |||
| Xid: globalTransaction.Xid, | |||
| BranchId: branchId, | |||
| Status: branchStatus, | |||
| }) | |||
| return err | |||
| } | |||
| func (s *StateLogStore) findOutOriginalStateInstanceOfRetryState(stateInstance statelang.StateInstance) statelang.StateInstance { | |||
| stateMap := stateInstance.StateMachineInstance().StateMap() | |||
| originalStateInst := stateMap[stateInstance.StateIDRetriedFor()] | |||
| for originalStateInst.StateIDRetriedFor() != "" { | |||
| originalStateInst = stateMap[stateInstance.StateIDRetriedFor()] | |||
| } | |||
| return originalStateInst | |||
| } | |||
| func (s *StateLogStore) findOutOriginalStateInstanceOfCompensateState(stateInstance statelang.StateInstance) statelang.StateInstance { | |||
| stateMap := stateInstance.StateMachineInstance().StateMap() | |||
| originalStateInst := stateMap[stateInstance.StateIDCompensatedFor()] | |||
| for originalStateInst.StateIDRetriedFor() != "" { | |||
| originalStateInst = stateMap[stateInstance.StateIDRetriedFor()] | |||
| } | |||
| return originalStateInst | |||
| } | |||
| func (s *StateLogStore) GetStateMachineInstance(stateMachineInstanceId string) (statelang.StateMachineInstance, error) { | |||
| stateMachineInstance, err := SelectOne(s.db, s.getStateMachineInstanceByIdSql, scanRowsToStateMachineInstance, | |||
| stateMachineInstanceId) | |||
| @@ -437,7 +758,7 @@ func (s *StateLogStore) GetStateInstanceListByMachineInstanceId(stateMachineInst | |||
| if lastStateInstance.EndTime().IsZero() { | |||
| lastStateInstance.SetStatus(statelang.RU) | |||
| } | |||
| //TODO add forward and compensate logic | |||
| originStateMap := make(map[string]statelang.StateInstance) | |||
| compensatedStateMap := make(map[string]statelang.StateInstance) | |||
| retriedStateMap := make(map[string]statelang.StateInstance) | |||
| @@ -522,6 +843,11 @@ func (s *StateLogStore) SetSeqGenerator(seqGenerator sequence.SeqGenerator) { | |||
| s.seqGenerator = seqGenerator | |||
| } | |||
| func (s *StateLogStore) ClearUp(context core.ProcessContext) { | |||
| context.RemoveVariable(constant2.XidKey) | |||
| context.RemoveVariable(constant2.BranchTypeKey) | |||
| } | |||
| func execStateMachineInstanceStatementForInsert(obj statelang.StateMachineInstance, stmt *sql.Stmt) (int64, error) { | |||
| result, err := stmt.Exec( | |||
| obj.ID(), | |||
| @@ -57,6 +57,12 @@ func mockMachineInstance(stateMachineName string) statelang.StateMachineInstance | |||
| return inst | |||
| } | |||
| func mockStateMachineConfig(context core.ProcessContext) core.StateMachineConfig { | |||
| cfg := core.NewDefaultStateMachineConfig() | |||
| context.SetVariable(constant.VarNameStateMachineConfig, cfg) | |||
| return cfg | |||
| } | |||
| func TestStateLogStore_RecordStateMachineStarted(t *testing.T) { | |||
| prepareDB() | |||
| @@ -65,6 +71,7 @@ func TestStateLogStore_RecordStateMachineStarted(t *testing.T) { | |||
| expected := mockMachineInstance(stateMachineName) | |||
| expected.SetBusinessKey("test_started") | |||
| ctx := mockProcessContext(stateMachineName, expected) | |||
| mockStateMachineConfig(ctx) | |||
| err := stateLogStore.RecordStateMachineStarted(context.Background(), expected, ctx) | |||
| assert.Nil(t, err) | |||
| actual, err := stateLogStore.GetStateMachineInstance(expected.ID()) | |||
| @@ -0,0 +1,18 @@ | |||
| /* | |||
| * Licensed to the Apache Software Foundation (ASF) under one or more | |||
| * contributor license agreements. See the NOTICE file distributed with | |||
| * this work for additional information regarding copyright ownership. | |||
| * The ASF licenses this file to You under the Apache License, Version 2.0 | |||
| * (the "License"); you may not use this file except in compliance with | |||
| * the License. You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| package repository | |||
| @@ -18,34 +18,221 @@ | |||
| package repository | |||
| import ( | |||
| "github.com/seata/seata-go/pkg/saga/statemachine/statelang" | |||
| "io" | |||
| "sync" | |||
| "time" | |||
| "github.com/seata/seata-go/pkg/saga/statemachine/constant" | |||
| "github.com/seata/seata-go/pkg/saga/statemachine/engine/sequence" | |||
| "github.com/seata/seata-go/pkg/saga/statemachine/statelang" | |||
| "github.com/seata/seata-go/pkg/saga/statemachine/statelang/parser" | |||
| "github.com/seata/seata-go/pkg/saga/statemachine/store/db" | |||
| "github.com/seata/seata-go/pkg/util/log" | |||
| ) | |||
| const ( | |||
| DefaultJsonParser = "fastjson" | |||
| ) | |||
| var ( | |||
| stateMachineRepositoryImpl *StateMachineRepositoryImpl | |||
| onceStateMachineRepositoryImpl sync.Once | |||
| ) | |||
| type StateMachineRepositoryImpl struct { | |||
| stateMachineMapById map[string]statelang.StateMachine | |||
| stateMachineMapByNameAndTenant map[string]statelang.StateMachine | |||
| stateLangStore *db.StateLangStore | |||
| seqGenerator sequence.SeqGenerator | |||
| defaultTenantId string | |||
| jsonParserName string | |||
| charset string | |||
| mutex *sync.Mutex | |||
| } | |||
| func GetStateMachineRepositoryImpl() *StateMachineRepositoryImpl { | |||
| if stateMachineRepositoryImpl == nil { | |||
| onceStateMachineRepositoryImpl.Do(func() { | |||
| //TODO get charset by config | |||
| //TODO charset is not use | |||
| //TODO using json parser | |||
| stateMachineRepositoryImpl = &StateMachineRepositoryImpl{ | |||
| stateMachineMapById: make(map[string]statelang.StateMachine), | |||
| stateMachineMapByNameAndTenant: make(map[string]statelang.StateMachine), | |||
| seqGenerator: sequence.NewUUIDSeqGenerator(), | |||
| jsonParserName: DefaultJsonParser, | |||
| charset: "UTF-8", | |||
| mutex: &sync.Mutex{}, | |||
| } | |||
| }) | |||
| } | |||
| return stateMachineRepositoryImpl | |||
| } | |||
| func (s *StateMachineRepositoryImpl) GetStateMachineById(stateMachineId string) (statelang.StateMachine, error) { | |||
| stateMachine := s.stateMachineMapById[stateMachineId] | |||
| if stateMachine == nil && s.stateLangStore != nil { | |||
| s.mutex.Lock() | |||
| defer s.mutex.Unlock() | |||
| stateMachine = s.stateMachineMapById[stateMachineId] | |||
| if stateMachine == nil { | |||
| oldStateMachine, err := s.stateLangStore.GetStateMachineById(stateMachineId) | |||
| if err != nil { | |||
| return oldStateMachine, err | |||
| } | |||
| parseStatMachine, err := parser.NewJSONStateMachineParser().Parse(oldStateMachine.Content()) | |||
| if err != nil { | |||
| return oldStateMachine, err | |||
| } | |||
| oldStateMachine.SetStartState(parseStatMachine.StartState()) | |||
| for key, val := range parseStatMachine.States() { | |||
| oldStateMachine.States()[key] = val | |||
| } | |||
| s.stateMachineMapById[stateMachineId] = oldStateMachine | |||
| s.stateMachineMapByNameAndTenant[oldStateMachine.Name()+"_"+oldStateMachine.TenantId()] = oldStateMachine | |||
| return oldStateMachine, nil | |||
| } | |||
| } | |||
| return stateMachine, nil | |||
| } | |||
| func (s *StateMachineRepositoryImpl) GetStateMachineByNameAndTenantId(stateMachineName string, tenantId string) (statelang.StateMachine, error) { | |||
| return s.GetLastVersionStateMachine(stateMachineName, tenantId) | |||
| } | |||
| func (s *StateMachineRepositoryImpl) GetLastVersionStateMachine(stateMachineName string, tenantId string) (statelang.StateMachine, error) { | |||
| key := stateMachineName + "_" + tenantId | |||
| stateMachine := s.stateMachineMapByNameAndTenant[key] | |||
| if stateMachine == nil && s.stateLangStore != nil { | |||
| s.mutex.Lock() | |||
| defer s.mutex.Unlock() | |||
| stateMachine = s.stateMachineMapById[key] | |||
| if stateMachine == nil { | |||
| oldStateMachine, err := s.stateLangStore.GetLastVersionStateMachine(stateMachineName, tenantId) | |||
| if err != nil { | |||
| return oldStateMachine, err | |||
| } | |||
| parseStatMachine, err := parser.NewJSONStateMachineParser().Parse(oldStateMachine.Content()) | |||
| if err != nil { | |||
| return oldStateMachine, err | |||
| } | |||
| oldStateMachine.SetStartState(parseStatMachine.StartState()) | |||
| for key, val := range parseStatMachine.States() { | |||
| oldStateMachine.States()[key] = val | |||
| } | |||
| s.stateMachineMapById[oldStateMachine.ID()] = oldStateMachine | |||
| s.stateMachineMapByNameAndTenant[key] = oldStateMachine | |||
| return oldStateMachine, nil | |||
| } | |||
| } | |||
| return stateMachine, nil | |||
| } | |||
| func (s *StateMachineRepositoryImpl) RegistryStateMachine(machine statelang.StateMachine) error { | |||
| stateMachineName := machine.Name() | |||
| tenantId := machine.TenantId() | |||
| if s.stateLangStore != nil { | |||
| oldStateMachine, err := s.stateLangStore.GetLastVersionStateMachine(stateMachineName, tenantId) | |||
| if err != nil { | |||
| return err | |||
| } | |||
| if oldStateMachine != nil { | |||
| if oldStateMachine.Content() == machine.Content() && machine.Version() != "" && machine.Version() == oldStateMachine.Version() { | |||
| log.Debugf("StateMachine[%s] is already exist a same version", stateMachineName) | |||
| machine.SetID(oldStateMachine.ID()) | |||
| machine.SetCreateTime(oldStateMachine.CreateTime()) | |||
| s.stateMachineMapById[machine.ID()] = machine | |||
| s.stateMachineMapByNameAndTenant[machine.Name()+"_"+machine.TenantId()] = machine | |||
| return nil | |||
| } | |||
| } | |||
| if machine.ID() == "" { | |||
| machine.SetID(s.seqGenerator.GenerateId(constant.SeqEntityStateMachine, "")) | |||
| } | |||
| machine.SetCreateTime(time.Now()) | |||
| err = s.stateLangStore.StoreStateMachine(machine) | |||
| if err != nil { | |||
| return err | |||
| } | |||
| } | |||
| if machine.ID() == "" { | |||
| machine.SetID(s.seqGenerator.GenerateId(constant.SeqEntityStateMachine, "")) | |||
| } | |||
| s.stateMachineMapById[machine.ID()] = machine | |||
| s.stateMachineMapByNameAndTenant[machine.Name()+"_"+machine.TenantId()] = machine | |||
| return nil | |||
| } | |||
| func (s *StateMachineRepositoryImpl) RegistryStateMachineByReader(reader io.Reader) error { | |||
| jsonByte, err := io.ReadAll(reader) | |||
| if err != nil { | |||
| return err | |||
| } | |||
| json := string(jsonByte) | |||
| parseStatMachine, err := parser.NewJSONStateMachineParser().Parse(json) | |||
| if err != nil { | |||
| return err | |||
| } | |||
| if parseStatMachine == nil { | |||
| return nil | |||
| } | |||
| parseStatMachine.SetContent(json) | |||
| s.RegistryStateMachine(parseStatMachine) | |||
| log.Debugf("===== StateMachine Loaded: %s", json) | |||
| return nil | |||
| } | |||
| func (s *StateMachineRepositoryImpl) SetStateLangStore(stateLangStore *db.StateLangStore) { | |||
| s.stateLangStore = stateLangStore | |||
| } | |||
| func (s *StateMachineRepositoryImpl) SetSeqGenerator(seqGenerator sequence.SeqGenerator) { | |||
| s.seqGenerator = seqGenerator | |||
| } | |||
| func (s *StateMachineRepositoryImpl) SetCharset(charset string) { | |||
| s.charset = charset | |||
| } | |||
| func (s StateMachineRepositoryImpl) GetStateMachineById(stateMachineId string) (statelang.StateMachine, error) { | |||
| //TODO implement me | |||
| panic("implement me") | |||
| func (s *StateMachineRepositoryImpl) GetCharset() string { | |||
| return s.charset | |||
| } | |||
| func (s StateMachineRepositoryImpl) GetStateMachineByNameAndTenantId(stateMachineName string, tenantId string) (statelang.StateMachine, error) { | |||
| //TODO implement me | |||
| panic("implement me") | |||
| func (s *StateMachineRepositoryImpl) SetDefaultTenantId(defaultTenantId string) { | |||
| s.defaultTenantId = defaultTenantId | |||
| } | |||
| func (s StateMachineRepositoryImpl) GetLastVersionStateMachine(stateMachineName string, tenantId string) (statelang.StateMachine, error) { | |||
| //TODO implement me | |||
| panic("implement me") | |||
| func (s *StateMachineRepositoryImpl) GetDefaultTenantId() string { | |||
| return s.defaultTenantId | |||
| } | |||
| func (s StateMachineRepositoryImpl) RegistryStateMachine(machine statelang.StateMachine) error { | |||
| //TODO implement me | |||
| panic("implement me") | |||
| func (s *StateMachineRepositoryImpl) SetJsonParserName(jsonParserName string) { | |||
| s.jsonParserName = jsonParserName | |||
| } | |||
| func (s StateMachineRepositoryImpl) RegistryStateMachineByReader(reader io.Reader) error { | |||
| //TODO implement me | |||
| panic("implement me") | |||
| func (s *StateMachineRepositoryImpl) GetJsonParserName() string { | |||
| return s.jsonParserName | |||
| } | |||
| @@ -0,0 +1,118 @@ | |||
| /* | |||
| * Licensed to the Apache Software Foundation (ASF) under one or more | |||
| * contributor license agreements. See the NOTICE file distributed with | |||
| * this work for additional information regarding copyright ownership. | |||
| * The ASF licenses this file to You under the Apache License, Version 2.0 | |||
| * (the "License"); you may not use this file except in compliance with | |||
| * the License. You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| package repository | |||
| import ( | |||
| "database/sql" | |||
| "github.com/seata/seata-go/pkg/saga/statemachine/statelang/parser" | |||
| "os" | |||
| "sync" | |||
| "testing" | |||
| "time" | |||
| _ "github.com/mattn/go-sqlite3" | |||
| "github.com/stretchr/testify/assert" | |||
| "github.com/seata/seata-go/pkg/saga/statemachine/statelang" | |||
| "github.com/seata/seata-go/pkg/saga/statemachine/store/db" | |||
| ) | |||
| var ( | |||
| oncePrepareDB sync.Once | |||
| testdb *sql.DB | |||
| ) | |||
| func prepareDB() { | |||
| oncePrepareDB.Do(func() { | |||
| var err error | |||
| testdb, err = sql.Open("sqlite3", ":memory:") | |||
| query_, err := os.ReadFile("../../../../../testdata/sql/saga/sqlite_init.sql") | |||
| initScript := string(query_) | |||
| if err != nil { | |||
| panic(err) | |||
| } | |||
| if _, err := testdb.Exec(initScript); err != nil { | |||
| panic(err) | |||
| } | |||
| }) | |||
| } | |||
| func loadStateMachineByYaml() string { | |||
| query, _ := os.ReadFile("../../../../../testdata/saga/statelang/simple_statemachine.json") | |||
| return string(query) | |||
| } | |||
| func TestStateMachineInMemory(t *testing.T) { | |||
| const stateMachineId, stateMachineName, tenantId = "simpleStateMachine", "simpleStateMachine", "test" | |||
| stateMachine := statelang.NewStateMachineImpl() | |||
| stateMachine.SetID(stateMachineId) | |||
| stateMachine.SetName(stateMachineName) | |||
| stateMachine.SetTenantId(tenantId) | |||
| stateMachine.SetComment("This is a test state machine") | |||
| stateMachine.SetCreateTime(time.Now()) | |||
| repository := GetStateMachineRepositoryImpl() | |||
| err := repository.RegistryStateMachine(stateMachine) | |||
| assert.Nil(t, err) | |||
| machineById, err := repository.GetStateMachineById(stateMachine.ID()) | |||
| assert.Nil(t, err) | |||
| assert.Equal(t, stateMachine.Name(), machineById.Name()) | |||
| assert.Equal(t, stateMachine.TenantId(), machineById.TenantId()) | |||
| assert.Equal(t, stateMachine.Comment(), machineById.Comment()) | |||
| assert.Equal(t, stateMachine.CreateTime().UnixNano(), machineById.CreateTime().UnixNano()) | |||
| machineByNameAndTenantId, err := repository.GetLastVersionStateMachine(stateMachine.Name(), stateMachine.TenantId()) | |||
| assert.Nil(t, err) | |||
| assert.Equal(t, stateMachine.ID(), machineByNameAndTenantId.ID()) | |||
| assert.Equal(t, stateMachine.Comment(), machineById.Comment()) | |||
| assert.Equal(t, stateMachine.CreateTime().UnixNano(), machineById.CreateTime().UnixNano()) | |||
| } | |||
| func TestStateMachineInDb(t *testing.T) { | |||
| prepareDB() | |||
| const tenantId = "test" | |||
| yaml := loadStateMachineByYaml() | |||
| stateMachine, err := parser.NewJSONStateMachineParser().Parse(yaml) | |||
| assert.Nil(t, err) | |||
| stateMachine.SetTenantId(tenantId) | |||
| stateMachine.SetContent(yaml) | |||
| repository := GetStateMachineRepositoryImpl() | |||
| repository.SetStateLangStore(db.NewStateLangStore(testdb, "seata_")) | |||
| err = repository.RegistryStateMachine(stateMachine) | |||
| assert.Nil(t, err) | |||
| repository.stateMachineMapById[stateMachine.ID()] = nil | |||
| machineById, err := repository.GetStateMachineById(stateMachine.ID()) | |||
| assert.Nil(t, err) | |||
| assert.Equal(t, stateMachine.Name(), machineById.Name()) | |||
| assert.Equal(t, stateMachine.TenantId(), machineById.TenantId()) | |||
| assert.Equal(t, stateMachine.Comment(), machineById.Comment()) | |||
| assert.Equal(t, stateMachine.CreateTime().UnixNano(), machineById.CreateTime().UnixNano()) | |||
| repository.stateMachineMapByNameAndTenant[stateMachine.Name()+"_"+stateMachine.TenantId()] = nil | |||
| machineByNameAndTenantId, err := repository.GetLastVersionStateMachine(stateMachine.Name(), stateMachine.TenantId()) | |||
| assert.Nil(t, err) | |||
| assert.Equal(t, stateMachine.ID(), machineByNameAndTenantId.ID()) | |||
| assert.Equal(t, stateMachine.Comment(), machineById.Comment()) | |||
| assert.Equal(t, stateMachine.CreateTime().UnixNano(), machineById.CreateTime().UnixNano()) | |||
| } | |||
| @@ -151,3 +151,28 @@ func (g *GlobalTransactionManager) Rollback(ctx context.Context, gtr *GlobalTran | |||
| return nil | |||
| } | |||
| // GlobalReport Global report. | |||
| func (g *GlobalTransactionManager) GlobalReport(ctx context.Context, gtr *GlobalTransaction) (message.GlobalStatus, error) { | |||
| if gtr.Xid == "" { | |||
| return message.GlobalStatusUnKnown, fmt.Errorf("GlobalReport xid should not be empty") | |||
| } | |||
| req := message.GlobalReportRequest{ | |||
| AbstractGlobalEndRequest: message.AbstractGlobalEndRequest{ | |||
| Xid: gtr.Xid, | |||
| }, | |||
| GlobalStatus: gtr.TxStatus, | |||
| } | |||
| res, err := getty.GetGettyRemotingClient().SendSyncRequest(req) | |||
| if err != nil { | |||
| log.Errorf("GlobalBeginRequest error %v", err) | |||
| return message.GlobalStatusUnKnown, err | |||
| } | |||
| if res == nil || res.(message.GlobalReportResponse).ResultCode == message.ResultCodeFailed { | |||
| log.Errorf("GlobalReportRequest result is empty or result code is failed, res %v", res) | |||
| return message.GlobalStatusUnKnown, fmt.Errorf("GlobalReportRequest result is empty or result code is failed.") | |||
| } | |||
| log.Infof("GlobalReportRequest success, res %v", res) | |||
| return res.(message.GlobalReportResponse).GlobalStatus, nil | |||
| } | |||