Compare commits

...

26 Commits

Author SHA1 Message Date
  Wiggins dc0a7a1410
Implement SagaResource and SagaResourceManager (#855) 1 month ago
  flypiggy fcc1cf2527
Feature/add script manager&actuator (#868) 1 month ago
  flypiggy 17237abd3a
Resolve Merge Conflicts, Add Persistence Initialization, and Optimize Tests (#847) 2 months ago
  lxfeng1997 0b467cb183
Decouple the transaction from the statelog (#841) 3 months ago
  flypiggy c2dfcf0e38
Feature: finish implementation for StateMachineConfig (#805) 3 months ago
  flypiggy 269b081669
refactor-engine/core (#838) 3 months ago
  FengZhang 911fc21045
Revert "bugfix:Remove issue translation workflow as usthe/issues-translate-action@v2.7 does not allow issues in apache/incubator-seata-go" (#822) 4 months ago
  深几许 018486951c
bugfix:Remove issue translation workflow as usthe/issues-translate-action@v2.7 does not allow issues in apache/incubator-seata-go (#821) 4 months ago
  lxfeng1997 d51b344002
change the StateLangStore type (#812) 4 months ago
  marsevilspirit 8e84b0ef2c
Feat cel expression (#788) 4 months ago
  深几许 0866141db3
refactor: implement pending TODOs in machine status_decision logic (#808) 4 months ago
  Jingliu d3a4cb1689
Feature/saga interface optimization (#778) 4 months ago
  1kasa 6c148cfb21
feat: Supplement the statelog_repository section in Database persistence for Saga state machine (#800) 6 months ago
  lxfeng1997 689c5d6f7c
Feature: Database persistence for seata-go Saga state machine (#794) 6 months ago
  FengZhang bb1c1262d4
feat: HttpServiceTaskState Support (#769) 6 months ago
  flypiggy e70cabe8c7
[Refactor] Migrate StateMachineObject to client/config and unify config parsing with koanf (#785) 6 months ago
  FinnTew 6734025528
[to-reply] feature: support saga multi type config (#741) 7 months ago
  marsevilspirit 1385f9e856
feat: add func invoker (#744) 9 months ago
  A Cabbage 8d68ae4494
feature: sequential execution of state machine in Saga (#681) 1 year ago
  FanOne d1fad744aa
feature saga :support generate id by Snowflake (#670) 1 year ago
  Xiangkun Yin 540dab4ea4
feature: add default implementation for StateMachineConfig (#669) 1 year ago
  Jingliu Xiong 90a8721983
feat: add grpc invoker in saga mode (#668) 1 year ago
  Xiangkun Yin 33ff4e59cc
feature: add saga persistence layer (#649) 1 year ago
  Jingliu Xiong 9c4c0f44a0
feature: add serverice task parse in statelang (#650) 1 year ago
  Xiangkun Yin 34c5e527c0
refactor: refactor saga scaffold to break import cycles (#647) 1 year ago
  wt_better 58fc3e13df
feature: init saga framework (#635) 1 year ago
100 changed files with 11904 additions and 21 deletions
Split View
  1. +18
    -9
      go.mod
  2. +32
    -12
      go.sum
  3. +2
    -0
      pkg/client/client.go
  4. +4
    -0
      pkg/client/config.go
  5. +2
    -0
      pkg/remoting/loadbalance/loadbalance.go
  6. +69
    -0
      pkg/remoting/loadbalance/round_robin_loadbalance.go
  7. +100
    -0
      pkg/remoting/loadbalance/round_robin_loadbalance_test.go
  8. +34
    -0
      pkg/saga/config.go
  9. +31
    -0
      pkg/saga/readme.md
  10. +35
    -0
      pkg/saga/rm/handler_saga.go
  11. +52
    -0
      pkg/saga/rm/saga_resource.go
  12. +159
    -0
      pkg/saga/rm/saga_resource_manager.go
  13. +43
    -0
      pkg/saga/rm/saga_resource_manager_test.go
  14. +38
    -0
      pkg/saga/rm/state_machine_engine_holder.go
  15. +103
    -0
      pkg/saga/statemachine/constant/constant.go
  16. +734
    -0
      pkg/saga/statemachine/engine/config/default_statemachine_config.go
  17. +265
    -0
      pkg/saga/statemachine/engine/config/default_statemachine_config_test.go
  18. +86
    -0
      pkg/saga/statemachine/engine/config/noop_store.go
  19. +3
    -0
      pkg/saga/statemachine/engine/config/testdata/invalid.json
  20. +16
    -0
      pkg/saga/statemachine/engine/config/testdata/invalid.json.comment
  21. +19
    -0
      pkg/saga/statemachine/engine/config/testdata/invalid.yaml
  22. +274
    -0
      pkg/saga/statemachine/engine/config/testdata/order_saga.json
  23. +16
    -0
      pkg/saga/statemachine/engine/config/testdata/order_saga.json.comment
  24. +204
    -0
      pkg/saga/statemachine/engine/config/testdata/order_saga.yaml
  25. +722
    -0
      pkg/saga/statemachine/engine/core/process_ctrl_statemachine_engine.go
  26. +83
    -0
      pkg/saga/statemachine/engine/exception/exception.go
  27. +68
    -0
      pkg/saga/statemachine/engine/exception/exception_test.go
  28. +32
    -0
      pkg/saga/statemachine/engine/exception/forward_invalid_exception.go
  29. +68
    -0
      pkg/saga/statemachine/engine/exception/forward_invalid_exception_test.go
  30. +96
    -0
      pkg/saga/statemachine/engine/expr/cel_expression.go
  31. +36
    -0
      pkg/saga/statemachine/engine/expr/cel_expression_factory.go
  32. +31
    -0
      pkg/saga/statemachine/engine/expr/cel_expression_factory_test.go
  33. +42
    -0
      pkg/saga/statemachine/engine/expr/cel_expression_test.go
  34. +21
    -0
      pkg/saga/statemachine/engine/expr/error_expression.go
  35. +30
    -0
      pkg/saga/statemachine/engine/expr/expression.go
  36. +22
    -0
      pkg/saga/statemachine/engine/expr/expression_factory.go
  37. +50
    -0
      pkg/saga/statemachine/engine/expr/expression_factory_manager.go
  38. +105
    -0
      pkg/saga/statemachine/engine/expr/expression_resolver.go
  39. +53
    -0
      pkg/saga/statemachine/engine/expr/expression_resolver_test.go
  40. +64
    -0
      pkg/saga/statemachine/engine/expr/sequence_expression.go
  41. +39
    -0
      pkg/saga/statemachine/engine/expr/sequence_expression_factory.go
  42. +225
    -0
      pkg/saga/statemachine/engine/invoker/func_invoker.go
  43. +168
    -0
      pkg/saga/statemachine/engine/invoker/func_invoker_test.go
  44. +261
    -0
      pkg/saga/statemachine/engine/invoker/grpc_invoker.go
  45. +185
    -0
      pkg/saga/statemachine/engine/invoker/grpc_invoker_test.go
  46. +220
    -0
      pkg/saga/statemachine/engine/invoker/http_invoker.go
  47. +176
    -0
      pkg/saga/statemachine/engine/invoker/http_invoker_test.go
  48. +127
    -0
      pkg/saga/statemachine/engine/invoker/invoker.go
  49. +162
    -0
      pkg/saga/statemachine/engine/invoker/javascript_script_invoker.go
  50. +262
    -0
      pkg/saga/statemachine/engine/invoker/javascript_script_invoker_test.go
  51. +195
    -0
      pkg/saga/statemachine/engine/invoker/local_invoker.go
  52. +212
    -0
      pkg/saga/statemachine/engine/invoker/local_invoker_test.go
  53. +81
    -0
      pkg/saga/statemachine/engine/pcext/compensation_holder.go
  54. +186
    -0
      pkg/saga/statemachine/engine/pcext/engine_utils.go
  55. +112
    -0
      pkg/saga/statemachine/engine/pcext/instruction.go
  56. +110
    -0
      pkg/saga/statemachine/engine/pcext/loop_context_holder.go
  57. +59
    -0
      pkg/saga/statemachine/engine/pcext/loop_task_utils.go
  58. +151
    -0
      pkg/saga/statemachine/engine/pcext/parameter_utils.go
  59. +115
    -0
      pkg/saga/statemachine/engine/pcext/process_handler.go
  60. +131
    -0
      pkg/saga/statemachine/engine/pcext/process_router.go
  61. +169
    -0
      pkg/saga/statemachine/engine/pcext/state_router_impl.go
  62. +137
    -0
      pkg/saga/statemachine/engine/repo/repository/state_log_repository.go
  63. +237
    -0
      pkg/saga/statemachine/engine/repo/repository/state_machine_repository.go
  64. +118
    -0
      pkg/saga/statemachine/engine/repo/repository/state_machine_repository_test.go
  65. +47
    -0
      pkg/saga/statemachine/engine/repo/statemachine_store.go
  66. +22
    -0
      pkg/saga/statemachine/engine/sequence/sequence.go
  67. +133
    -0
      pkg/saga/statemachine/engine/sequence/snowflake.go
  68. +45
    -0
      pkg/saga/statemachine/engine/sequence/snowflake_test.go
  69. +31
    -0
      pkg/saga/statemachine/engine/sequence/uuid.go
  70. +59
    -0
      pkg/saga/statemachine/engine/serializer/serializer.go
  71. +34
    -0
      pkg/saga/statemachine/engine/serializer/serializer_test.go
  72. +77
    -0
      pkg/saga/statemachine/engine/statemachine_config.go
  73. +59
    -0
      pkg/saga/statemachine/engine/statemachine_engine.go
  74. +33
    -0
      pkg/saga/statemachine/engine/statemachine_engine_test.go
  75. +19
    -0
      pkg/saga/statemachine/engine/strategy.go
  76. +246
    -0
      pkg/saga/statemachine/engine/strategy/status_decision.go
  77. +113
    -0
      pkg/saga/statemachine/engine/utils/process_context_utils.go
  78. +109
    -0
      pkg/saga/statemachine/process_ctrl/bussiness_processor.go
  79. +7
    -0
      pkg/saga/statemachine/process_ctrl/default_process_handler.go
  80. +21
    -0
      pkg/saga/statemachine/process_ctrl/event.go
  81. +146
    -0
      pkg/saga/statemachine/process_ctrl/event_bus.go
  82. +56
    -0
      pkg/saga/statemachine/process_ctrl/event_consumer.go
  83. +36
    -0
      pkg/saga/statemachine/process_ctrl/event_publisher.go
  84. +193
    -0
      pkg/saga/statemachine/process_ctrl/handlers/service_task_state_handler.go
  85. +21
    -0
      pkg/saga/statemachine/process_ctrl/instruction.go
  86. +24
    -0
      pkg/saga/statemachine/process_ctrl/process/process_type.go
  87. +225
    -0
      pkg/saga/statemachine/process_ctrl/process_context.go
  88. +48
    -0
      pkg/saga/statemachine/process_ctrl/process_controller.go
  89. +107
    -0
      pkg/saga/statemachine/process_ctrl/process_router.go
  90. +91
    -0
      pkg/saga/statemachine/statelang/parser/choice_state_json_parser.go
  91. +48
    -0
      pkg/saga/statemachine/statelang/parser/compensation_trigger_state_parser.go
  92. +83
    -0
      pkg/saga/statemachine/statelang/parser/end_state_parser.go
  93. +139
    -0
      pkg/saga/statemachine/statelang/parser/statemachine_config_parser.go
  94. +882
    -0
      pkg/saga/statemachine/statelang/parser/statemachine_config_parser_test.go
  95. +130
    -0
      pkg/saga/statemachine/statelang/parser/statemachine_json_parser.go
  96. +127
    -0
      pkg/saga/statemachine/statelang/parser/statemachine_json_parser_test.go
  97. +253
    -0
      pkg/saga/statemachine/statelang/parser/statemachine_parser.go
  98. +101
    -0
      pkg/saga/statemachine/statelang/parser/sub_state_machine_parser.go
  99. +347
    -0
      pkg/saga/statemachine/statelang/parser/task_state_json_parser.go
  100. +92
    -0
      pkg/saga/statemachine/statelang/state.go

+ 18
- 9
go.mod View File

@@ -1,6 +1,6 @@
module github.com/seata/seata-go

go 1.18
go 1.20

require (
dubbo.apache.org/dubbo-go/v3 v3.0.4
@@ -25,7 +25,7 @@ require (
github.com/stretchr/testify v1.8.3
go.uber.org/atomic v1.9.0
go.uber.org/zap v1.21.0
google.golang.org/grpc v1.51.0
google.golang.org/grpc v1.57.0
gopkg.in/yaml.v2 v2.4.0
vimagination.zapto.org/byteio v0.0.0-20200222190125-d27cba0f0b10
)
@@ -33,17 +33,24 @@ require (
require (
github.com/agiledragon/gomonkey v2.0.2+incompatible
github.com/agiledragon/gomonkey/v2 v2.9.0
github.com/google/cel-go v0.18.0
github.com/mattn/go-sqlite3 v1.14.19
github.com/robertkrimen/otto v0.4.0
golang.org/x/sync v0.16.0
google.golang.org/protobuf v1.33.0
gopkg.in/yaml.v3 v3.0.1
)

require (
github.com/RoaringBitmap/roaring v1.2.0 // indirect
github.com/Workiva/go-datastructures v1.0.52 // indirect
github.com/antlr/antlr4/runtime/Go/antlr/v4 v4.0.0-20230305170008-8188dc5388df // indirect
github.com/apache/dubbo-go-hessian2 v1.11.4 // indirect
github.com/benbjohnson/clock v1.1.0 // indirect
github.com/beorn7/perks v1.0.1 // indirect
github.com/bits-and-blooms/bitset v1.2.0 // indirect
github.com/bytedance/sonic v1.9.1 // indirect
github.com/cespare/xxhash/v2 v2.1.2 // indirect
github.com/cespare/xxhash/v2 v2.2.0 // indirect
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
github.com/creasty/defaults v1.5.2 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
@@ -52,7 +59,7 @@ require (
github.com/go-ole/go-ole v1.2.6 // indirect
github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/golang/protobuf v1.5.2 // indirect
github.com/golang/protobuf v1.5.3 // indirect
github.com/golang/snappy v0.0.4 // indirect
github.com/gorilla/websocket v1.4.2 // indirect
github.com/jinzhu/copier v0.3.5 // indirect
@@ -84,10 +91,9 @@ require (
github.com/yusufpapurcu/wmi v1.2.2 // indirect
go.uber.org/multierr v1.8.0 // indirect
golang.org/x/arch v0.3.0 // indirect
golang.org/x/text v0.14.0 // indirect
google.golang.org/protobuf v1.30.0 // indirect
golang.org/x/text v0.27.0 // indirect
gopkg.in/natefinch/lumberjack.v2 v2.0.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
gopkg.in/sourcemap.v1 v1.0.5 // indirect
)

require (
@@ -98,10 +104,13 @@ require (
github.com/mattn/go-isatty v0.0.19 // indirect
github.com/pelletier/go-toml v1.9.3 // indirect
github.com/pingcap/log v0.0.0-20210906054005-afc726e70354 // indirect
github.com/stoewer/go-strcase v1.2.0 // indirect
golang.org/x/crypto v0.17.0 // indirect
golang.org/x/exp v0.0.0-20220827204233-334a2380cb91 // indirect
golang.org/x/net v0.10.0 // indirect
golang.org/x/sys v0.15.0 // indirect
google.golang.org/genproto v0.0.0-20220630174209-ad1d48641aa7 // indirect
golang.org/x/sys v0.32.0 // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20230803162519-f966b187b2e5 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20230803162519-f966b187b2e5 // indirect
vimagination.zapto.org/memio v0.0.0-20200222190306-588ebc67b97d // indirect
)



+ 32
- 12
go.sum View File

@@ -67,6 +67,8 @@ github.com/alecthomas/units v0.0.0-20190924025748-f65c72e2690d/go.mod h1:rBZYJk5
github.com/alibaba/sentinel-golang v1.0.4/go.mod h1:Lag5rIYyJiPOylK8Kku2P+a23gdKMMqzQS7wTnjWEpk=
github.com/aliyun/alibaba-cloud-sdk-go v1.61.18/go.mod h1:v8ESoHo4SyHmuB4b1tJqDHxfTGEciD+yhvOU/5s1Rfk=
github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY=
github.com/antlr/antlr4/runtime/Go/antlr/v4 v4.0.0-20230305170008-8188dc5388df h1:7RFfzj4SSt6nnvCPbCqijJi1nWCd+TqAT3bYCStRC18=
github.com/antlr/antlr4/runtime/Go/antlr/v4 v4.0.0-20230305170008-8188dc5388df/go.mod h1:pSwJ0fSY5KhvocuWSx4fz3BA8OrA1bQn+K1Eli3BRwM=
github.com/apache/dubbo-getty v1.4.9-0.20221022181821-4dc6252ce98c/go.mod h1:6qmrqBSPGs3B35zwEuGhEYNVsx1nfGT/xzV2yOt2amM=
github.com/apache/dubbo-getty v1.4.10-0.20230731065302-7c0f0039e59c h1:e1pJKY0lFvO6rik7m3qmpMRA98cc9Zkg6AJeB1/7QFQ=
github.com/apache/dubbo-getty v1.4.10-0.20230731065302-7c0f0039e59c/go.mod h1:TqHfi87Ufv7wpwI7nER5Kx8FCb/jjwlyazxiYwEmTs8=
@@ -119,8 +121,9 @@ github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA
github.com/certifi/gocertifi v0.0.0-20191021191039-0944d244cd40/go.mod h1:sGbDF6GwGcLpkNXPUTkMRoywsNa/ol15pxFe6ERfguA=
github.com/cespare/xxhash v1.1.0/go.mod h1:XrSqR1VqqWfGrhpAt58auRo0WTKS1nRRg3ghfAqPWnc=
github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/cespare/xxhash/v2 v2.1.2 h1:YRXhKfTDauu4ajMg1TPgFO5jnlC2HCbmLXMcTG5cbYE=
github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44=
github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY=
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams=
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk=
@@ -304,14 +307,17 @@ github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QD
github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI=
github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI=
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
github.com/golang/protobuf v1.5.2 h1:ROPKBNFfQgOUMifHyP+KYbvpjbdoFNs+aK7DXlji0Tw=
github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg=
github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
github.com/golang/snappy v0.0.4 h1:yAGX7huGHXlcLOEtBnF4w7FQwA26wojNCwOYAEhLjQM=
github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ=
github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ=
github.com/google/cel-go v0.18.0 h1:u74MPiEC8mejBrkXqrTWT102g5IFEUjxOngzQIijMzU=
github.com/google/cel-go v0.18.0/go.mod h1:PVAybmSnWkNMUZR/tEWFUiJ1Np4Hz0MHsZJcgC4zln4=
github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M=
github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
@@ -515,6 +521,8 @@ github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27k
github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA=
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-runewidth v0.0.2/go.mod h1:LwmH8dsx7+W8Uxz3IHJYH5QSwggIsqBzpuz5H//U1FU=
github.com/mattn/go-sqlite3 v1.14.19 h1:fhGleo2h1p8tVChob4I9HpmVFIAkKGpiukdrgQbWfGI=
github.com/mattn/go-sqlite3 v1.14.19/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0=
github.com/matttproud/golang_protobuf_extensions v1.0.4 h1:mmDVorXM7PCGKw94cs5zkfA9PSy5pEvNWRP0ET0TIVo=
github.com/matttproud/golang_protobuf_extensions v1.0.4/go.mod h1:BSXmuO+STAnVfrANrmjBb36TMTDstsz7MSK+HVaYKv4=
@@ -664,6 +672,8 @@ github.com/prometheus/tsdb v0.7.1/go.mod h1:qhTCs0VvXwvX/y3TZrWD7rabWM+ijKTux40T
github.com/rcrowley/go-metrics v0.0.0-20181016184325-3113b8401b8a/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4=
github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
github.com/rhnvrm/simples3 v0.6.1/go.mod h1:Y+3vYm2V7Y4VijFoJHHTrja6OgPrJ2cBti8dPGkC3sA=
github.com/robertkrimen/otto v0.4.0 h1:/c0GRrK1XDPcgIasAsnlpBT5DelIeB9U/Z/JCQsgr7E=
github.com/robertkrimen/otto v0.4.0/go.mod h1:uW9yN1CYflmUQYvAMS0m+ZiNo3dMzRUDQJX0jWbzgxw=
github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro=
github.com/rogpeppe/fastuuid v0.0.0-20150106093220-6724a57986af/go.mod h1:XWv6SoW27p1b0cqNHllgS5HIMJraePCO15w5zCzIWYg=
github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ=
@@ -706,6 +716,8 @@ github.com/spf13/pflag v1.0.3/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnIn
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/spf13/viper v1.7.0/go.mod h1:8WkrPz2fc9jxqZNCJI/76HCieCp4Q8HaLFoCha5qpdg=
github.com/stoewer/go-strcase v1.2.0 h1:Z2iHWqGXH00XYgqDmNgQbIBxf3wrNq0F3feEy0ainaU=
github.com/stoewer/go-strcase v1.2.0/go.mod h1:IBiWB2sKIp3wVVQ3Y035++gc+knqhUQag1KpM8ahLw8=
github.com/streadway/amqp v0.0.0-20190404075320-75d898a42a94/go.mod h1:AZpEONHx3DKn8O/DFsRAY58/XVQiIPMTMB1SddzLXVw=
github.com/streadway/amqp v0.0.0-20190827072141-edfb9018d271/go.mod h1:AZpEONHx3DKn8O/DFsRAY58/XVQiIPMTMB1SddzLXVw=
github.com/streadway/handy v0.0.0-20190108123426-d5acb3125c2a/go.mod h1:qNTQ5P5JnDBl6z3cMAg/SywNDC5ABu5ApDIw6lUbRmI=
@@ -846,6 +858,8 @@ golang.org/x/exp v0.0.0-20200119233911-0405dc783f0a/go.mod h1:2RIsYlXP63K8oxa1u0
golang.org/x/exp v0.0.0-20200207192155-f17229e696bd/go.mod h1:J/WKrq2StrnmMY6+EHIKF9dgMWnmCNThgcyBT1FY9mM=
golang.org/x/exp v0.0.0-20200224162631-6cc2880d07d6/go.mod h1:3jZMyOhIsHpP37uCMkUooju7aAi5cS1Q23tOzKc+0MU=
golang.org/x/exp v0.0.0-20200331195152-e8c3332aa8e5/go.mod h1:4M0jN8W1tt0AVLNr8HDosyJCDCDuyL9N9+3m7wDWgKw=
golang.org/x/exp v0.0.0-20220827204233-334a2380cb91 h1:tnebWN09GYg9OLPss1KXj8txwZc6X6uMr6VFdcGNbHw=
golang.org/x/exp v0.0.0-20220827204233-334a2380cb91/go.mod h1:cyybsKvd6eL0RnXn6p/Grxp8F5bW7iYuBgsNCOHpMYE=
golang.org/x/image v0.0.0-20180708004352-c73c2afc3b81/go.mod h1:ux5Hcp/YLpHSI86hEcLt0YII63i6oz57MZXIpbrjZUs=
golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js=
golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
@@ -940,6 +954,8 @@ golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJ
golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw=
golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
@@ -1017,8 +1033,8 @@ golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220908164124-27713097b956/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc=
golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.32.0 h1:s77OFDvIQeibCmezSnk/q6iAfkdiQaJi4VzroCFrN20=
golang.org/x/sys v0.32.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
@@ -1031,8 +1047,8 @@ golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.27.0 h1:4fGWRpyh641NLlecmyl4LOe6yDdfaYNrGb2zdfo4JV4=
golang.org/x/text v0.27.0/go.mod h1:1D28KMCvyooCX9hBiosv5Tz/+YLxj0j7XhWjpSUF7CU=
golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
@@ -1167,8 +1183,10 @@ google.golang.org/genproto v0.0.0-20210106152847-07624b53cd92/go.mod h1:FWY/as6D
google.golang.org/genproto v0.0.0-20210602131652-f16073e35f0c/go.mod h1:UODoCrxHCcBojKKwX1terBiRUaqAsFqJiF615XL43r0=
google.golang.org/genproto v0.0.0-20211104193956-4c6863e31247/go.mod h1:5CzLGKJ67TSI2B9POpiiyGha0AjJvZIUgRMt1dSmuhc=
google.golang.org/genproto v0.0.0-20220504150022-98cd25cafc72/go.mod h1:RAyBrSAP7Fh3Nc84ghnVLDPuV51xc9agzmm4Ph6i0Q4=
google.golang.org/genproto v0.0.0-20220630174209-ad1d48641aa7 h1:q4zUJDd0+knPFB9x20S3vnxzlYNBbt8Yd7zBMVMteeM=
google.golang.org/genproto v0.0.0-20220630174209-ad1d48641aa7/go.mod h1:KEWEmljWE5zPzLBa/oHl6DaEt9LmfH6WtH1OHIvleBA=
google.golang.org/genproto/googleapis/api v0.0.0-20230803162519-f966b187b2e5 h1:nIgk/EEq3/YlnmVVXVnm14rC2oxgs1o0ong4sD/rd44=
google.golang.org/genproto/googleapis/api v0.0.0-20230803162519-f966b187b2e5/go.mod h1:5DZzOUPCLYL3mNkQ0ms0F3EuUNZ7py1Bqeq6sxzI7/Q=
google.golang.org/genproto/googleapis/rpc v0.0.0-20230803162519-f966b187b2e5 h1:eSaPbMR4T7WfH9FvABk36NBMacoTUKdWCvV0dx+KfOg=
google.golang.org/genproto/googleapis/rpc v0.0.0-20230803162519-f966b187b2e5/go.mod h1:zBEcrKX2ZOcEkHWxBPAIvYUWOKKMIhYcmNiUIu2ji3I=
google.golang.org/grpc v1.8.0/go.mod h1:yo6s7OP7yaDglbqo1J04qKzAhqBH6lvTonzMVmEdcZw=
google.golang.org/grpc v1.14.0/go.mod h1:yo6s7OP7yaDglbqo1J04qKzAhqBH6lvTonzMVmEdcZw=
google.golang.org/grpc v1.17.0/go.mod h1:6QZJwpn2B+Zp71q/5VxRsJ6NXXVCE5NRUHRo+f3cWCs=
@@ -1197,10 +1215,10 @@ google.golang.org/grpc v1.38.0/go.mod h1:NREThFqKR1f3iQ6oBuvc5LadQuXVGo9rkm5ZGrQ
google.golang.org/grpc v1.40.0/go.mod h1:ogyxbiOoUXAkP+4+xa6PZSE9DZgIHtSpzjDTB9KAK34=
google.golang.org/grpc v1.41.0/go.mod h1:U3l9uK9J0sini8mHphKoXyaqDA/8VyGnDee1zzIUK6k=
google.golang.org/grpc v1.46.0/go.mod h1:vN9eftEi1UMyUsIF80+uQXhHjbXYbm0uXoFCACuMGWk=
google.golang.org/grpc v1.47.0/go.mod h1:vN9eftEi1UMyUsIF80+uQXhHjbXYbm0uXoFCACuMGWk=
google.golang.org/grpc v1.48.0/go.mod h1:vN9eftEi1UMyUsIF80+uQXhHjbXYbm0uXoFCACuMGWk=
google.golang.org/grpc v1.51.0 h1:E1eGv1FTqoLIdnBCZufiSHgKjlqG6fKFf6pPWtMTh8U=
google.golang.org/grpc v1.51.0/go.mod h1:wgNDFcnuBGmxLKI/qn4T+m5BtEBYXJPvibbUPsAIPww=
google.golang.org/grpc v1.57.0 h1:kfzNeI/klCGD2YPMUlaGNT3pxvYfga7smW3Vth8Zsiw=
google.golang.org/grpc v1.57.0/go.mod h1:Sd+9RMTACXwmub0zcNY2c4arhtrbBYD1AUHI/dt16Mo=
google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8=
google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0=
google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM=
@@ -1216,8 +1234,8 @@ google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQ
google.golang.org/protobuf v1.27.1/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
google.golang.org/protobuf v1.28.1/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng=
google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI=
google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw=
gopkg.in/asn1-ber.v1 v1.0.0-20181015200546-f715ec2f112d/go.mod h1:cuepJuh7vyXfUyUwEgHQXw849cJrilpS5NeIjOWESAw=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
@@ -1235,6 +1253,8 @@ gopkg.in/ini.v1 v1.51.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k=
gopkg.in/natefinch/lumberjack.v2 v2.0.0 h1:1Lc07Kr7qY4U2YPouBjpCLxpiyxIVoxqXgkXLknAOE8=
gopkg.in/natefinch/lumberjack.v2 v2.0.0/go.mod h1:l0ndWWf7gzL7RNwBG7wST/UCcT4T24xpD6X8LsfU/+k=
gopkg.in/resty.v1 v1.12.0/go.mod h1:mDo4pnntr5jdWRML875a/NmxYqAlA73dVijT2AXvQQo=
gopkg.in/sourcemap.v1 v1.0.5 h1:inv58fC9f9J3TK2Y2R1NPntXEn3/wjWHkonhIUODNTI=
gopkg.in/sourcemap.v1 v1.0.5/go.mod h1:2RlvNNSMglmRrcvhfuzp4hQHwOtjxlbjX7UPY/GXb78=
gopkg.in/square/go-jose.v2 v2.3.1/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76mk0e1AI=
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw=
gopkg.in/warnings.v0 v0.1.2/go.mod h1:jksf8JmL6Qr/oQM2OXTHunEvvTAsrWBLb6OOjuVWRNI=


+ 2
- 0
pkg/client/client.go View File

@@ -30,6 +30,7 @@ import (
"github.com/seata/seata-go/pkg/remoting/processor/client"
"github.com/seata/seata-go/pkg/rm"
"github.com/seata/seata-go/pkg/rm/tcc"
saga "github.com/seata/seata-go/pkg/saga/rm"
"github.com/seata/seata-go/pkg/tm"
"github.com/seata/seata-go/pkg/util/log"
)
@@ -87,6 +88,7 @@ func initRmClient(cfg *Config) {
client.RegisterProcessor()
integration.Init()
tcc.InitTCC()
saga.InitSaga()
at.InitAT(cfg.ClientConfig.UndoConfig, cfg.AsyncWorkerConfig)
at.InitXA(cfg.ClientConfig.XaConfig)
})


+ 4
- 0
pkg/client/config.go View File

@@ -20,6 +20,7 @@ package client
import (
"flag"
"fmt"
"github.com/seata/seata-go/pkg/saga"
"io/ioutil"
"os"
"path/filepath"
@@ -84,6 +85,8 @@ type Config struct {
TransportConfig remoteConfig.TransportConfig `yaml:"transport" json:"transport" koanf:"transport"`
ServiceConfig discovery.ServiceConfig `yaml:"service" json:"service" koanf:"service"`
RegistryConfig discovery.RegistryConfig `yaml:"registry" json:"registry" koanf:"registry"`

SagaConfig saga.Config `yaml:"saga" json:"saga" koanf:"saga"`
}

func (c *Config) RegisterFlags(f *flag.FlagSet) {
@@ -102,6 +105,7 @@ func (c *Config) RegisterFlags(f *flag.FlagSet) {
c.TransportConfig.RegisterFlagsWithPrefix("transport", f)
c.RegistryConfig.RegisterFlagsWithPrefix("registry", f)
c.ServiceConfig.RegisterFlagsWithPrefix("service", f)
c.SagaConfig.RegisterFlagsWithPrefix("saga", f)
}

type loaderConf struct {


+ 2
- 0
pkg/remoting/loadbalance/loadbalance.go View File

@@ -37,6 +37,8 @@ func Select(loadBalanceType string, sessions *sync.Map, xid string) getty.Sessio
return RandomLoadBalance(sessions, xid)
case xidLoadBalance:
return XidLoadBalance(sessions, xid)
case roundRobinLoadBalance:
return RoundRobinLoadBalance(sessions, xid)
default:
return RandomLoadBalance(sessions, xid)
}


+ 69
- 0
pkg/remoting/loadbalance/round_robin_loadbalance.go View File

@@ -0,0 +1,69 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package loadbalance

import (
"math"
"sort"
"sync"
"sync/atomic"

getty "github.com/apache/dubbo-getty"
)

var sequence int32

func RoundRobinLoadBalance(sessions *sync.Map, s string) getty.Session {
// collect sync.Map adderToSession
// filter out closed session instance
adderToSession := make(map[string]getty.Session, 0)
// map has no sequence, we should sort it to make sure the sequence is always the same
adders := make([]string, 0)
sessions.Range(func(key, value interface{}) bool {
session := key.(getty.Session)
if session.IsClosed() {
sessions.Delete(key)
} else {
adderToSession[session.RemoteAddr()] = session
adders = append(adders, session.RemoteAddr())
}
return true
})
sort.Strings(adders)
// adderToSession eq 0 means there are no available session
if len(adderToSession) == 0 {
return nil
}
index := getPositiveSequence() % len(adderToSession)
return adderToSession[adders[index]]
}

func getPositiveSequence() int {
for {
current := atomic.LoadInt32(&sequence)
var next int32
if current == math.MaxInt32 {
next = 0
} else {
next = current + 1
}
if atomic.CompareAndSwapInt32(&sequence, current, next) {
return int(current)
}
}
}

+ 100
- 0
pkg/remoting/loadbalance/round_robin_loadbalance_test.go View File

@@ -0,0 +1,100 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package loadbalance

import (
"fmt"
"math"
"sync"
"testing"

"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"

"github.com/seata/seata-go/pkg/remoting/mock"
)

func TestRoundRobinLoadBalance_Normal(t *testing.T) {
ctrl := gomock.NewController(t)
sessions := &sync.Map{}

for i := 0; i < 10; i++ {
session := mock.NewMockTestSession(ctrl)
session.EXPECT().IsClosed().Return(i == 2).AnyTimes()
session.EXPECT().RemoteAddr().Return(fmt.Sprintf("%d", i)).AnyTimes()
sessions.Store(session, fmt.Sprintf("session-%d", i+1))
}

for i := 0; i < 10; i++ {
if i == 2 {
continue
}
result := RoundRobinLoadBalance(sessions, "some_xid")
assert.Equal(t, fmt.Sprintf("%d", i), result.RemoteAddr())
assert.NotNil(t, result)
assert.False(t, result.IsClosed())
}
}

func TestRoundRobinLoadBalance_OverSequence(t *testing.T) {
ctrl := gomock.NewController(t)
sessions := &sync.Map{}
sequence = math.MaxInt32

for i := 0; i < 10; i++ {
session := mock.NewMockTestSession(ctrl)
session.EXPECT().IsClosed().Return(false).AnyTimes()
session.EXPECT().RemoteAddr().Return(fmt.Sprintf("%d", i)).AnyTimes()
sessions.Store(session, fmt.Sprintf("session-%d", i+1))
}

for i := 0; i < 10; i++ {
// over sequence here
if i == 0 {
result := RoundRobinLoadBalance(sessions, "some_xid")
assert.Equal(t, "7", result.RemoteAddr())
assert.NotNil(t, result)
assert.False(t, result.IsClosed())
continue
}
result := RoundRobinLoadBalance(sessions, "some_xid")
assert.Equal(t, fmt.Sprintf("%d", i-1), result.RemoteAddr())
assert.NotNil(t, result)
assert.False(t, result.IsClosed())
}
}

func TestRoundRobinLoadBalance_All_Closed(t *testing.T) {
ctrl := gomock.NewController(t)
sessions := &sync.Map{}
for i := 0; i < 10; i++ {
session := mock.NewMockTestSession(ctrl)
session.EXPECT().IsClosed().Return(true).AnyTimes()
sessions.Store(session, fmt.Sprintf("session-%d", i+1))
}
if result := RoundRobinLoadBalance(sessions, "some_xid"); result != nil {
t.Errorf("Expected nil, actual got %+v", result)
}
}

func TestRoundRobinLoadBalance_Empty(t *testing.T) {
sessions := &sync.Map{}
if result := RoundRobinLoadBalance(sessions, "some_xid"); result != nil {
t.Errorf("Expected nil, actual got %+v", result)
}
}

+ 34
- 0
pkg/saga/config.go View File

@@ -0,0 +1,34 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package saga

import (
"flag"

"github.com/seata/seata-go/pkg/saga/statemachine"
)

type Config struct {
StateMachine *statemachine.StateMachineObject `yaml:"state-machine" json:"state-machine" koanf:"state-machine"`
}

func (cfg *Config) RegisterFlagsWithPrefix(prefix string, f *flag.FlagSet) {
if cfg.StateMachine != nil {
cfg.StateMachine.RegisterFlagsWithPrefix(prefix+".state-machine", f)
}
}

+ 31
- 0
pkg/saga/readme.md View File

@@ -0,0 +1,31 @@

# seata saga

未来计划有三种使用方式

- 基于状态机引擎的 json
link: statemachine_engine#Start
- stream builder
stateMachine.serviceTask().build().Start
- 二阶段方式saga,类似tcc使用

上面1、2是以来[statemachine](statemachine),状态机引擎实现的,3相对比较独立。


状态机的实现在:saga-statemachine包中
其中[statelang](statemachine%2Fstatelang)是状态机语言的解析,目前实现的是json解析方式,状态机语言可以参考:
https://seata.io/docs/user/mode/saga

状态机json执行的入口类是:[statemachine_engine.go](statemachine%2Fengine%2Fstatemachine_engine.go)

下面简单说下engine中各个包的作用:
events:saga的是基于事件处理的,其中是event、eventBus的实现
expr:表达式声明、解析、执行
invoker:声明了serviceInvoker、scriptInvoker等接口、task调用管理、执行都在这个包中,例如httpInvoker
process_ctrl:状态机处理流程:上下文、执行、事件流转
sequence:分布式id
store:状态机存储接口、实现
status_decision:状态机状态决策




+ 35
- 0
pkg/saga/rm/handler_saga.go View File

@@ -0,0 +1,35 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package rm

import (
"github.com/seata/seata-go/pkg/protocol/branch"
"github.com/seata/seata-go/pkg/rm"
)

type RMHandlerSaga struct{}

func (h *RMHandlerSaga) HandleUndoLogDeleteRequest(request interface{}) {
// do nothing
}
func (h *RMHandlerSaga) GetResourceManager() rm.ResourceManager {
return rm.GetRmCacheInstance().GetResourceManager(branch.BranchTypeSAGA)
}
func (h *RMHandlerSaga) GetBranchType() branch.BranchType {
return branch.BranchTypeSAGA
}

+ 52
- 0
pkg/saga/rm/saga_resource.go View File

@@ -0,0 +1,52 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package rm

import (
"fmt"
"github.com/seata/seata-go/pkg/protocol/branch"
)

type SagaResource struct {
resourceGroupId string
applicationId string
}

func (r *SagaResource) GetResourceGroupId() string {
return r.resourceGroupId
}

func (r *SagaResource) SetResourceGroupId(resourceGroupId string) {
r.resourceGroupId = resourceGroupId
}

func (r *SagaResource) GetResourceId() string {
return fmt.Sprintf("%s#%s", r.applicationId, r.resourceGroupId)
}

func (r *SagaResource) GetBranchType() branch.BranchType {
return branch.BranchTypeSAGA
}

func (r *SagaResource) GetApplicationId() string {
return r.applicationId
}

func (r *SagaResource) SetApplicationId(applicationId string) {
r.applicationId = applicationId
}

+ 159
- 0
pkg/saga/rm/saga_resource_manager.go View File

@@ -0,0 +1,159 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package rm

import (
"bytes"
"context"
"fmt"
"log"
"sync"

"github.com/seata/seata-go/pkg/protocol/branch"
"github.com/seata/seata-go/pkg/protocol/message"
"github.com/seata/seata-go/pkg/rm"
"github.com/seata/seata-go/pkg/saga/statemachine/engine/exception"
"github.com/seata/seata-go/pkg/saga/statemachine/statelang"
seataErrors "github.com/seata/seata-go/pkg/util/errors"
)

var (
sagaResourceManagerInstance *SagaResourceManager
once sync.Once
)

type SagaResourceManager struct {
rmRemoting *rm.RMRemoting
resourceCache sync.Map
}

func InitSaga() {
rm.GetRmCacheInstance().RegisterResourceManager(GetSagaResourceManager())
}

func GetSagaResourceManager() *SagaResourceManager {
once.Do(func() {
sagaResourceManagerInstance = &SagaResourceManager{
rmRemoting: rm.GetRMRemotingInstance(),
resourceCache: sync.Map{},
}
})
return sagaResourceManagerInstance
}

func (s *SagaResourceManager) RegisterResource(resource rm.Resource) error {
if _, ok := resource.(*SagaResource); !ok {
return fmt.Errorf("register saga resource error, SagaResource is needed, param %v", resource)
}
s.resourceCache.Store(resource.GetResourceId(), resource)
return s.rmRemoting.RegisterResource(resource)

}

func (s *SagaResourceManager) GetCachedResources() *sync.Map {
return &s.resourceCache
}

func (s *SagaResourceManager) GetBranchType() branch.BranchType {
return branch.BranchTypeSAGA
}

func (s *SagaResourceManager) BranchCommit(ctx context.Context, resource rm.BranchResource) (branch.BranchStatus, error) {
engine := GetStateMachineEngine()
stMaInst, err := engine.Forward(ctx, resource.Xid, nil)
if err != nil {
if fie, ok := exception.IsForwardInvalidException(err); ok {
log.Printf("StateMachine forward failed, xid: %s, err: %v", resource.Xid, err)
if isInstanceNotExists(fie.ErrCode) {
return branch.BranchStatusPhasetwoCommitted, nil
}
}
log.Printf("StateMachine forward failed, xid: %s, err: %v", resource.Xid, err)
return branch.BranchStatusPhasetwoCommitFailedRetryable, err
}

status := stMaInst.Status()
compStatus := stMaInst.CompensationStatus()

switch {
case status == statelang.SU && compStatus == "":
return branch.BranchStatusPhasetwoCommitted, nil
case compStatus == statelang.SU:
return branch.BranchStatusPhasetwoRollbacked, nil
case compStatus == statelang.FA || compStatus == statelang.UN:
return branch.BranchStatusPhasetwoRollbackFailedRetryable, nil
case status == statelang.FA && compStatus == "":
return branch.BranchStatusPhaseoneFailed, nil
default:
return branch.BranchStatusPhasetwoCommitFailedRetryable, nil
}
}

func (s *SagaResourceManager) BranchRollback(ctx context.Context, resource rm.BranchResource) (branch.BranchStatus, error) {
engine := GetStateMachineEngine()
stMaInst, err := engine.ReloadStateMachineInstance(ctx, resource.Xid)
if err != nil || stMaInst == nil {
return branch.BranchStatusPhasetwoRollbacked, nil
}

strategy := stMaInst.StateMachine().RecoverStrategy()
appData := resource.ApplicationData
isTimeoutRollback := bytes.Equal(appData, []byte{byte(message.GlobalStatusTimeoutRollbacking)}) || bytes.Equal(appData, []byte{byte(message.GlobalStatusTimeoutRollbackRetrying)})

if strategy == statelang.Forward && isTimeoutRollback {
log.Printf("Retry by custom recover strategy [Forward] on timeout, SAGA global[%s]", resource.Xid)
return branch.BranchStatusPhasetwoCommitFailedRetryable, nil
}

stMaInst, err = engine.Compensate(ctx, resource.Xid, nil)
if err == nil && stMaInst.CompensationStatus() == statelang.SU {
return branch.BranchStatusPhasetwoRollbacked, nil
}

if fie, ok := exception.IsEngineExecutionException(err); ok {
log.Printf("StateMachine compensate failed, xid: %s, err: %v", resource.Xid, err)
if isInstanceNotExists(fie.ErrCode) {
return branch.BranchStatusPhasetwoRollbacked, nil
}
}
log.Printf("StateMachine compensate failed, xid: %s, err: %v", resource.Xid, err)
return branch.BranchStatusPhasetwoRollbackFailedRetryable, err
}

func (s *SagaResourceManager) BranchRegister(ctx context.Context, param rm.BranchRegisterParam) (int64, error) {
return s.rmRemoting.BranchRegister(param)
}

func (s *SagaResourceManager) BranchReport(ctx context.Context, param rm.BranchReportParam) error {
return s.rmRemoting.BranchReport(param)
}

func (s *SagaResourceManager) LockQuery(ctx context.Context, param rm.LockQueryParam) (bool, error) {
// LockQuery is not supported for Saga resources
return false, fmt.Errorf("LockQuery is not supported for Saga resources")
}

func (s *SagaResourceManager) UnregisterResource(resource rm.Resource) error {
// UnregisterResource is not supported for SagaResourceManager
return fmt.Errorf("UnregisterResource is not supported for SagaResourceManager")
}

// isInstanceNotExists checks if the error code indicates StateMachineInstanceNotExists
func isInstanceNotExists(errCode string) bool {
return errCode == fmt.Sprintf("%v", seataErrors.StateMachineInstanceNotExists)
}

+ 43
- 0
pkg/saga/rm/saga_resource_manager_test.go View File

@@ -0,0 +1,43 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package rm

import (
"sync"
"testing"
)

func TestGetSagaResourceManager_Singleton(t *testing.T) {
var wg sync.WaitGroup
instances := make([]*SagaResourceManager, 10)
for i := 0; i < 10; i++ {
wg.Add(1)
go func(idx int) {
defer wg.Done()
instances[idx] = GetSagaResourceManager()
}(i)
}
wg.Wait()

first := instances[0]
for i, inst := range instances {
if inst != first {
t.Errorf("Instance at index %d is not the same as the first instance", i)
}
}
}

+ 38
- 0
pkg/saga/rm/state_machine_engine_holder.go View File

@@ -0,0 +1,38 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package rm

import (
"github.com/seata/seata-go/pkg/saga/statemachine/engine"
"sync"
)

var (
stateMachineEngine engine.StateMachineEngine
stateMachineEngineOnce sync.Once
)

func GetStateMachineEngine() engine.StateMachineEngine {
return stateMachineEngine
}

func SetStateMachineEngine(smEngine engine.StateMachineEngine) {
stateMachineEngineOnce.Do(func() {
stateMachineEngine = smEngine
})
}

+ 103
- 0
pkg/saga/statemachine/constant/constant.go View File

@@ -0,0 +1,103 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package constant

type NetExceptionType string

const (
// region State Types
StateTypeServiceTask string = "ServiceTask"
StateTypeChoice string = "Choice"
StateTypeFail string = "Fail"
StateTypeSucceed string = "Succeed"
StateTypeCompensationTrigger string = "CompensationTrigger"
StateTypeSubStateMachine string = "SubStateMachine"
StateTypeCompensateSubMachine string = "CompensateSubMachine"
StateTypeScriptTask string = "ScriptTask"
StateTypeLoopStart string = "LoopStart"
// end region

// region Service Types
ServiceTypeGRPC string = "GRPC"
// end region

// region System Variables
VarNameOutputParams string = "outputParams"
VarNameProcessType string = "_ProcessType_"
VarNameOperationName string = "_operation_name_"
OperationNameStart string = "start"
OperationNameCompensate string = "compensate"
VarNameAsyncCallback string = "_async_callback_"
VarNameCurrentExceptionRoute string = "_current_exception_route_"
VarNameIsExceptionNotCatch string = "_is_exception_not_catch_"
VarNameSubMachineParentId string = "_sub_machine_parent_id_"
VarNameCurrentChoice string = "_current_choice_"
VarNameStateMachineInst string = "_current_statemachine_instance_"
VarNameStateMachine string = "_current_statemachine_"
VarNameStateMachineEngine string = "_current_statemachine_engine_"
VarNameStateMachineConfig string = "_statemachine_config_"
VarNameStateMachineContext string = "context"
VarNameIsAsyncExecution string = "_is_async_execution_"
VarNameStateInst string = "_current_state_instance_"
VarNameBusinesskey string = "_business_key_"
VarNameParentId string = "_parent_id_"
VarNameCurrentException string = "currentException"
CompensateSubMachineStateNamePrefix string = "_compensate_sub_machine_state_"
DefaultScriptType string = "groovy"
VarNameSyncExeStack string = "_sync_execution_stack_"
VarNameInputParams string = "inputParams"
VarNameIsLoopState string = "_is_loop_state_"
VarNameCurrentCompensateTriggerState string = "_is_compensating_"
VarNameCurrentCompensationHolder string = "_current_compensation_holder_"
VarNameFirstCompensationStateStarted string = "_first_compensation_state_started"
VarNameCurrentLoopContextHolder string = "_current_loop_context_holder_"
VarNameRetriedStateInstId string = "_retried_state_instance_id"
VarNameIsForSubStatMachineForward string = "_is_for_sub_statemachine_forward_"
// TODO: this lock in process context only has one, try to add more to add concurrent
VarNameProcessContextMutexLock string = "_current_context_mutex_lock"
VarNameFailEndStateFlag string = "_fail_end_state_flag_"
VarNameGlobalTx string = "_global_transaction_"
// end region

// region of loop
LoopCounter string = "loopCounter"
LoopSemaphore string = "loopSemaphore"
LoopResult string = "loopResult"
NumberOfInstances string = "nrOfInstances"
NumberOfActiveInstances string = "nrOfActiveInstances"
NumberOfCompletedInstances string = "nrOfCompletedInstances"
// end region

// region others
SeqEntityStateMachine string = "STATE_MACHINE"
SeqEntityStateMachineInst string = "STATE_MACHINE_INST"
SeqEntityStateInst string = "STATE_INST"
OperationNameForward string = "forward"
LoopStateNamePattern string = "-loop-"
SagaTransNamePrefix string = "$Saga_"
// end region

SeperatorParentId string = ":"

// Machine execution timeout error code
FrameworkErrorCodeStateMachineExecutionTimeout = "StateMachineExecutionTimeout"
ConnectException NetExceptionType = "CONNECT_EXCEPTION"
ConnectTimeoutException NetExceptionType = "CONNECT_TIMEOUT_EXCEPTION"
ReadTimeoutException NetExceptionType = "READ_TIMEOUT_EXCEPTION"
NotNetException NetExceptionType = "NOT_NET_EXCEPTION"
)

+ 734
- 0
pkg/saga/statemachine/engine/config/default_statemachine_config.go View File

@@ -0,0 +1,734 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package config

import (
"context"
"encoding/json"
"fmt"
"github.com/seata/seata-go/pkg/saga/statemachine/engine/repo/repository"
"github.com/seata/seata-go/pkg/saga/statemachine/engine/strategy"
"gopkg.in/yaml.v3"
"log"
"os"
"path/filepath"
"strings"
"sync"

"github.com/seata/seata-go/pkg/saga/statemachine"
"github.com/seata/seata-go/pkg/saga/statemachine/engine"
"github.com/seata/seata-go/pkg/saga/statemachine/engine/expr"
"github.com/seata/seata-go/pkg/saga/statemachine/engine/invoker"
"github.com/seata/seata-go/pkg/saga/statemachine/engine/repo"
"github.com/seata/seata-go/pkg/saga/statemachine/engine/sequence"
"github.com/seata/seata-go/pkg/saga/statemachine/process_ctrl"
"github.com/seata/seata-go/pkg/saga/statemachine/statelang/parser"
"github.com/seata/seata-go/pkg/saga/statemachine/store"
)

const (
DefaultTransOperTimeout = 60000 * 30
DefaultServiceInvokeTimeout = 60000 * 5
DefaultClientSagaRetryPersistModeUpdate = false
DefaultClientSagaCompensatePersistModeUpdate = false
DefaultClientReportSuccessEnable = false
DefaultClientSagaBranchRegisterEnable = true
)

type DefaultStateMachineConfig struct {
// Configuration
transOperationTimeout int
serviceInvokeTimeout int
charset string
defaultTenantId string
sagaRetryPersistModeUpdate bool
sagaCompensatePersistModeUpdate bool
sagaBranchRegisterEnable bool
rmReportSuccessEnable bool
stateMachineResources []string

// State machine definitions
stateMachineDefs map[string]*statemachine.StateMachineObject

// Components
processController process_ctrl.ProcessController

// Event Bus
syncEventBus process_ctrl.EventBus
asyncEventBus process_ctrl.EventBus

// Event publisher
syncProcessCtrlEventPublisher process_ctrl.EventPublisher
asyncProcessCtrlEventPublisher process_ctrl.EventPublisher

// Store related components
stateLogRepository repo.StateLogRepository
stateLogStore store.StateLogStore
stateLangStore store.StateLangStore
stateMachineRepository repo.StateMachineRepository

// Expression related components
expressionFactoryManager *expr.ExpressionFactoryManager
expressionResolver expr.ExpressionResolver

// Invoker related components
serviceInvokerManager invoker.ServiceInvokerManager
scriptInvokerManager invoker.ScriptInvokerManager

// Other components
statusDecisionStrategy engine.StatusDecisionStrategy
seqGenerator sequence.SeqGenerator
componentLock *sync.Mutex
}

func (c *DefaultStateMachineConfig) ComponentLock() *sync.Mutex {
return c.componentLock
}

func (c *DefaultStateMachineConfig) SetComponentLock(componentLock *sync.Mutex) {
c.componentLock = componentLock
}

func (c *DefaultStateMachineConfig) SetTransOperationTimeout(transOperationTimeout int) {
c.transOperationTimeout = transOperationTimeout
}

func (c *DefaultStateMachineConfig) SetServiceInvokeTimeout(serviceInvokeTimeout int) {
c.serviceInvokeTimeout = serviceInvokeTimeout
}

func (c *DefaultStateMachineConfig) SetCharset(charset string) {
c.charset = charset
}

func (c *DefaultStateMachineConfig) SetDefaultTenantId(defaultTenantId string) {
c.defaultTenantId = defaultTenantId
}

func (c *DefaultStateMachineConfig) SetSyncEventBus(syncEventBus process_ctrl.EventBus) {
c.syncEventBus = syncEventBus
}

func (c *DefaultStateMachineConfig) SetAsyncEventBus(asyncEventBus process_ctrl.EventBus) {
c.asyncEventBus = asyncEventBus
}

func (c *DefaultStateMachineConfig) SetSyncProcessCtrlEventPublisher(syncProcessCtrlEventPublisher process_ctrl.EventPublisher) {
c.syncProcessCtrlEventPublisher = syncProcessCtrlEventPublisher
}

func (c *DefaultStateMachineConfig) SetAsyncProcessCtrlEventPublisher(asyncProcessCtrlEventPublisher process_ctrl.EventPublisher) {
c.asyncProcessCtrlEventPublisher = asyncProcessCtrlEventPublisher
}

func (c *DefaultStateMachineConfig) SetStateLogRepository(stateLogRepository repo.StateLogRepository) {
c.stateLogRepository = stateLogRepository
}

func (c *DefaultStateMachineConfig) SetStateLogStore(stateLogStore store.StateLogStore) {
c.stateLogStore = stateLogStore
}

func (c *DefaultStateMachineConfig) SetStateLangStore(stateLangStore store.StateLangStore) {
c.stateLangStore = stateLangStore
}

func (c *DefaultStateMachineConfig) SetStateMachineRepository(stateMachineRepository repo.StateMachineRepository) {
c.stateMachineRepository = stateMachineRepository
}

func (c *DefaultStateMachineConfig) SetExpressionFactoryManager(expressionFactoryManager *expr.ExpressionFactoryManager) {
c.expressionFactoryManager = expressionFactoryManager
}

func (c *DefaultStateMachineConfig) SetExpressionResolver(expressionResolver expr.ExpressionResolver) {
c.expressionResolver = expressionResolver
}

func (c *DefaultStateMachineConfig) SetServiceInvokerManager(serviceInvokerManager invoker.ServiceInvokerManager) {
c.serviceInvokerManager = serviceInvokerManager
}

func (c *DefaultStateMachineConfig) SetScriptInvokerManager(scriptInvokerManager invoker.ScriptInvokerManager) {
c.scriptInvokerManager = scriptInvokerManager
}

func (c *DefaultStateMachineConfig) SetStatusDecisionStrategy(statusDecisionStrategy engine.StatusDecisionStrategy) {
c.statusDecisionStrategy = statusDecisionStrategy
}

func (c *DefaultStateMachineConfig) SetSeqGenerator(seqGenerator sequence.SeqGenerator) {
c.seqGenerator = seqGenerator
}

func (c *DefaultStateMachineConfig) StateLogRepository() repo.StateLogRepository {
return c.stateLogRepository
}

func (c *DefaultStateMachineConfig) StateMachineRepository() repo.StateMachineRepository {
return c.stateMachineRepository
}

func (c *DefaultStateMachineConfig) StateLogStore() store.StateLogStore {
return c.stateLogStore
}

func (c *DefaultStateMachineConfig) StateLangStore() store.StateLangStore {
return c.stateLangStore
}

func (c *DefaultStateMachineConfig) ExpressionFactoryManager() *expr.ExpressionFactoryManager {
return c.expressionFactoryManager
}

func (c *DefaultStateMachineConfig) ExpressionResolver() expr.ExpressionResolver {
return c.expressionResolver
}

func (c *DefaultStateMachineConfig) SeqGenerator() sequence.SeqGenerator {
return c.seqGenerator
}

func (c *DefaultStateMachineConfig) StatusDecisionStrategy() engine.StatusDecisionStrategy {
return c.statusDecisionStrategy
}

func (c *DefaultStateMachineConfig) SyncEventBus() process_ctrl.EventBus {
return c.syncEventBus
}

func (c *DefaultStateMachineConfig) AsyncEventBus() process_ctrl.EventBus {
return c.asyncEventBus
}

func (c *DefaultStateMachineConfig) EventPublisher() process_ctrl.EventPublisher {
return c.syncProcessCtrlEventPublisher
}

func (c *DefaultStateMachineConfig) AsyncEventPublisher() process_ctrl.EventPublisher {
return c.asyncProcessCtrlEventPublisher
}

func (c *DefaultStateMachineConfig) ServiceInvokerManager() invoker.ServiceInvokerManager {
return c.serviceInvokerManager
}

func (c *DefaultStateMachineConfig) ScriptInvokerManager() invoker.ScriptInvokerManager {
return c.scriptInvokerManager
}

func (c *DefaultStateMachineConfig) CharSet() string {
return c.charset
}

func (c *DefaultStateMachineConfig) SetCharSet(charset string) {
c.charset = charset
}

func (c *DefaultStateMachineConfig) GetDefaultTenantId() string {
return c.defaultTenantId
}

func (c *DefaultStateMachineConfig) GetTransOperationTimeout() int {
return c.transOperationTimeout
}

func (c *DefaultStateMachineConfig) GetServiceInvokeTimeout() int {
return c.serviceInvokeTimeout
}

func (c *DefaultStateMachineConfig) IsSagaRetryPersistModeUpdate() bool {
return c.sagaRetryPersistModeUpdate
}

func (c *DefaultStateMachineConfig) SetSagaRetryPersistModeUpdate(sagaRetryPersistModeUpdate bool) {
c.sagaRetryPersistModeUpdate = sagaRetryPersistModeUpdate
}

func (c *DefaultStateMachineConfig) IsSagaCompensatePersistModeUpdate() bool {
return c.sagaCompensatePersistModeUpdate
}

func (c *DefaultStateMachineConfig) SetSagaCompensatePersistModeUpdate(sagaCompensatePersistModeUpdate bool) {
c.sagaCompensatePersistModeUpdate = sagaCompensatePersistModeUpdate
}

func (c *DefaultStateMachineConfig) IsSagaBranchRegisterEnable() bool {
return c.sagaBranchRegisterEnable
}

func (c *DefaultStateMachineConfig) SetSagaBranchRegisterEnable(sagaBranchRegisterEnable bool) {
c.sagaBranchRegisterEnable = sagaBranchRegisterEnable
}

func (c *DefaultStateMachineConfig) IsRmReportSuccessEnable() bool {
return c.rmReportSuccessEnable
}

func (c *DefaultStateMachineConfig) SetRmReportSuccessEnable(rmReportSuccessEnable bool) {
c.rmReportSuccessEnable = rmReportSuccessEnable
}

func (c *DefaultStateMachineConfig) GetStateMachineDefinition(name string) *statemachine.StateMachineObject {
return c.stateMachineDefs[name]
}

func (c *DefaultStateMachineConfig) GetExpressionFactory(expressionType string) expr.ExpressionFactory {
return c.expressionFactoryManager.GetExpressionFactory(expressionType)
}

func (c *DefaultStateMachineConfig) GetServiceInvoker(serviceType string) (invoker.ServiceInvoker, error) {
if serviceType == "" {
serviceType = "local"
}

invoker := c.serviceInvokerManager.ServiceInvoker(serviceType)
if invoker == nil {
return nil, fmt.Errorf("service invoker not found for type: %s", serviceType)
}

return invoker, nil
}

func (c *DefaultStateMachineConfig) RegisterStateMachineDef(resources []string) error {
var allFiles []string

for _, pattern := range resources {
matches, err := filepath.Glob(pattern)
if err != nil {
return fmt.Errorf("failed to expand glob pattern: pattern=%s, err=%w", pattern, err)
}
if len(matches) == 0 {
return fmt.Errorf("open resource file failed: pattern=%s", pattern)
}
allFiles = append(allFiles, matches...)
}

for _, realPath := range allFiles {
file, err := os.Open(realPath)
if err != nil {
return fmt.Errorf("open resource file failed: path=%s, err=%w", realPath, err)
}
defer file.Close()

if err := c.stateMachineRepository.RegistryStateMachineByReader(file); err != nil {
return fmt.Errorf("register state machine from file failed: path=%s, err=%w", realPath, err)
}
}

return nil
}

func (c *DefaultStateMachineConfig) RegisterExpressionFactory(expressionType string, factory expr.ExpressionFactory) {
c.expressionFactoryManager.PutExpressionFactory(expressionType, factory)
}

func (c *DefaultStateMachineConfig) RegisterServiceInvoker(serviceType string, invoker invoker.ServiceInvoker) {
c.serviceInvokerManager.PutServiceInvoker(serviceType, invoker)
}

type ConfigFileParams struct {
TransOperationTimeout int `json:"trans_operation_timeout" yaml:"trans_operation_timeout"`
ServiceInvokeTimeout int `json:"service_invoke_timeout" yaml:"service_invoke_timeout"`
Charset string `json:"charset" yaml:"charset"`
DefaultTenantId string `json:"default_tenant_id" yaml:"default_tenant_id"`
SagaRetryPersistModeUpdate bool `json:"saga_retry_persist_mode_update" yaml:"saga_retry_persist_mode_update"`
SagaCompensatePersistModeUpdate bool `json:"saga_compensate_persist_mode_update" yaml:"saga_compensate_persist_mode_update"`
SagaBranchRegisterEnable bool `json:"saga_branch_register_enable" yaml:"saga_branch_register_enable"`
RmReportSuccessEnable bool `json:"rm_report_success_enable" yaml:"rm_report_success_enable"`
StateMachineResources []string `json:"state_machine_resources" yaml:"state_machine_resources"`
}

func (c *DefaultStateMachineConfig) LoadConfig(configPath string) error {
if c.seqGenerator == nil {
c.seqGenerator = sequence.NewUUIDSeqGenerator()
}

content, err := os.ReadFile(configPath)
if err != nil {
return fmt.Errorf("failed to read config file: path=%s, error=%w", configPath, err)
}

parser := parser.NewStateMachineConfigParser()
smo, err := parser.Parse(content)
if err != nil {
return fmt.Errorf("failed to parse state machine definition: path=%s, error=%w", configPath, err)
}

var configFileParams ConfigFileParams
if err := json.Unmarshal(content, &configFileParams); err != nil {
if err := yaml.Unmarshal(content, &configFileParams); err != nil {
return fmt.Errorf("failed to unmarshal config file as YAML: %w", err)
} else {
c.applyConfigFileParams(&configFileParams)
}
} else {
c.applyConfigFileParams(&configFileParams)
}

if _, exists := c.stateMachineDefs[smo.Name]; exists {
return fmt.Errorf("state machine definition with name %s already exists", smo.Name)
}
c.stateMachineDefs[smo.Name] = smo

return nil
}

func (c *DefaultStateMachineConfig) applyConfigFileParams(rc *ConfigFileParams) {
if rc.TransOperationTimeout > 0 {
c.transOperationTimeout = rc.TransOperationTimeout
}
if rc.ServiceInvokeTimeout > 0 {
c.serviceInvokeTimeout = rc.ServiceInvokeTimeout
}
if rc.Charset != "" {
c.charset = rc.Charset
}
if rc.DefaultTenantId != "" {
c.defaultTenantId = rc.DefaultTenantId
}
c.sagaRetryPersistModeUpdate = rc.SagaRetryPersistModeUpdate
c.sagaCompensatePersistModeUpdate = rc.SagaCompensatePersistModeUpdate
c.sagaBranchRegisterEnable = rc.SagaBranchRegisterEnable
c.rmReportSuccessEnable = rc.RmReportSuccessEnable
if len(rc.StateMachineResources) > 0 {
c.stateMachineResources = rc.StateMachineResources
}
}

func (c *DefaultStateMachineConfig) registerEventConsumers() error {
if c.processController == nil {
return fmt.Errorf("ProcessController is not initialized")
}

pcImpl, ok := c.processController.(*process_ctrl.ProcessControllerImpl)
if !ok {
return fmt.Errorf("ProcessController is not an instance of ProcessControllerImpl")
}

if pcImpl.BusinessProcessor() == nil {
return fmt.Errorf("BusinessProcessor in ProcessController is not initialized")
}

consumer := process_ctrl.NewProcessCtrlEventConsumer(c.processController)

c.syncEventBus.RegisterEventConsumer(consumer)
c.asyncEventBus.RegisterEventConsumer(consumer)

return nil
}

func (c *DefaultStateMachineConfig) Init() error {
if err := c.initExpressionComponents(); err != nil {
return fmt.Errorf("initialize expression components failed: %w", err)
}

if err := c.initServiceInvokers(); err != nil {
return fmt.Errorf("initialize service invokers failed: %w", err)
}

if err := c.registerEventConsumers(); err != nil {
return fmt.Errorf("register event consumers failed: %w", err)
}

if c.stateMachineRepository != nil && len(c.stateMachineResources) > 0 {
if err := c.RegisterStateMachineDef(c.stateMachineResources); err != nil {
return fmt.Errorf("register state machine def failed: %w", err)
}
}

if err := c.Validate(); err != nil {
return fmt.Errorf("configuration validation failed: %w", err)
}

return nil
}

func (c *DefaultStateMachineConfig) initExpressionComponents() error {
if c.expressionFactoryManager == nil {
c.expressionFactoryManager = expr.NewExpressionFactoryManager()
}

defaultType := expr.DefaultExpressionType
if defaultType == "" {
defaultType = "Default"
}

if factory := c.expressionFactoryManager.GetExpressionFactory(defaultType); factory == nil {
c.RegisterExpressionFactory(defaultType, expr.NewCELExpressionFactory())
}

if factory := c.expressionFactoryManager.GetExpressionFactory("CEL"); factory == nil {
c.RegisterExpressionFactory("CEL", expr.NewCELExpressionFactory())
}

if factory := c.expressionFactoryManager.GetExpressionFactory("el"); factory == nil {
c.RegisterExpressionFactory("el", expr.NewCELExpressionFactory())
}

if c.seqGenerator != nil {
sequenceFactory := expr.NewSequenceExpressionFactory(c.seqGenerator)
c.RegisterExpressionFactory("SEQUENCE", sequenceFactory)
c.RegisterExpressionFactory("SEQ", sequenceFactory)
}

if c.expressionResolver == nil {
resolver := &expr.DefaultExpressionResolver{}
resolver.SetExpressionFactoryManager(*c.expressionFactoryManager)
c.expressionResolver = resolver
}

return nil
}

func (c *DefaultStateMachineConfig) initServiceInvokers() error {
if c.serviceInvokerManager == nil {
c.serviceInvokerManager = invoker.NewServiceInvokerManagerImpl()
}

if existing := c.serviceInvokerManager.ServiceInvoker("local"); existing == nil {
c.RegisterServiceInvoker("local", invoker.NewLocalServiceInvoker())
}

if existing := c.serviceInvokerManager.ServiceInvoker("http"); existing == nil {
c.RegisterServiceInvoker("http", invoker.NewHTTPInvoker())
}

if existing := c.serviceInvokerManager.ServiceInvoker("grpc"); existing == nil {
c.RegisterServiceInvoker("grpc", invoker.NewGRPCInvoker())
}

if existing := c.serviceInvokerManager.ServiceInvoker("func"); existing == nil {
c.RegisterServiceInvoker("func", invoker.NewFuncInvoker())
}

return nil
}

func (c *DefaultStateMachineConfig) Validate() error {
var errs []error

if c.expressionFactoryManager == nil {
errs = append(errs, fmt.Errorf("expression factory manager is nil"))
}
if c.expressionResolver == nil {
errs = append(errs, fmt.Errorf("expression resolver is nil"))
}
if c.serviceInvokerManager == nil {
errs = append(errs, fmt.Errorf("service invoker manager is nil"))
}

if c.transOperationTimeout <= 0 {
errs = append(errs, fmt.Errorf("invalid trans operation timeout: %d", c.transOperationTimeout))
}
if c.serviceInvokeTimeout <= 0 {
errs = append(errs, fmt.Errorf("invalid service invoke timeout: %d", c.serviceInvokeTimeout))
}
if c.charset == "" {
errs = append(errs, fmt.Errorf("charset is empty"))
}

if c.stateMachineRepository != nil {
if c.stateLogStore == nil {
errs = append(errs, fmt.Errorf("state log store is nil"))
}
if c.stateLangStore == nil {
errs = append(errs, fmt.Errorf("state lang store is nil"))
}
if c.stateLogRepository == nil {
errs = append(errs, fmt.Errorf("state log repository is nil"))
}
}

if c.statusDecisionStrategy == nil {
errs = append(errs, fmt.Errorf("status decision strategy is nil"))
}
if c.syncEventBus == nil {
errs = append(errs, fmt.Errorf("sync event bus is nil"))
}
if c.asyncEventBus == nil {
errs = append(errs, fmt.Errorf("async event bus is nil"))
}

if len(errs) > 0 {
return fmt.Errorf("configuration validation failed with %d errors: %v", len(errs), errs)
}
return nil
}

func (c *DefaultStateMachineConfig) EvaluateExpression(expressionStr string, context any) (any, error) {
if c.expressionResolver == nil {
return nil, fmt.Errorf("expression resolver not initialized")
}

expression := c.expressionResolver.Expression(expressionStr)
if expression == nil {
return nil, fmt.Errorf("failed to parse expression: %s", expressionStr)
}

var result any
var evalErr error

func() {
defer func() {
if r := recover(); r != nil {
evalErr = fmt.Errorf("expression evaluation panicked: %v", r)
}
}()

result = expression.Value(context)
}()

if evalErr != nil {
return nil, evalErr
}

if err, ok := result.(error); ok {
return nil, fmt.Errorf("expression evaluation returned error: %w", err)
}

return result, nil
}

func NewDefaultStateMachineConfig(opts ...Option) *DefaultStateMachineConfig {
ctx := context.Background()
defaultBP := process_ctrl.NewBusinessProcessor()

c := &DefaultStateMachineConfig{
transOperationTimeout: DefaultTransOperTimeout,
serviceInvokeTimeout: DefaultServiceInvokeTimeout,
charset: "UTF-8",
defaultTenantId: "000001",
stateMachineResources: parseEnvResources(),
sagaRetryPersistModeUpdate: DefaultClientSagaRetryPersistModeUpdate,
sagaCompensatePersistModeUpdate: DefaultClientSagaCompensatePersistModeUpdate,
sagaBranchRegisterEnable: DefaultClientSagaBranchRegisterEnable,
rmReportSuccessEnable: DefaultClientReportSuccessEnable,
stateMachineDefs: make(map[string]*statemachine.StateMachineObject),
componentLock: &sync.Mutex{},
seqGenerator: sequence.NewUUIDSeqGenerator(),
statusDecisionStrategy: strategy.NewDefaultStatusDecisionStrategy(),
processController: func() process_ctrl.ProcessController {
pc := &process_ctrl.ProcessControllerImpl{}
pc.SetBusinessProcessor(defaultBP)
return pc
}(),
syncEventBus: process_ctrl.NewDirectEventBus(),
asyncEventBus: process_ctrl.NewAsyncEventBus(ctx, 1000, 5),

syncProcessCtrlEventPublisher: nil,
asyncProcessCtrlEventPublisher: nil,

stateLogStore: &NoopStateLogStore{},
stateLangStore: &NoopStateLangStore{},
}

c.stateMachineRepository = repository.GetStateMachineRepositoryImpl()
c.stateLogRepository = repository.NewStateLogRepositoryImpl()

c.syncProcessCtrlEventPublisher = process_ctrl.NewProcessCtrlEventPublisher(c.syncEventBus)
c.asyncProcessCtrlEventPublisher = process_ctrl.NewProcessCtrlEventPublisher(c.asyncEventBus)

for _, opt := range opts {
opt(c)
}

if err := c.LoadConfig("config.yaml"); err == nil {
log.Printf("Successfully loaded config from config.yaml")
} else {
log.Printf("Failed to load config file (using default/env values): %v", err)
}

return c
}

func parseEnvResources() []string {
if env := os.Getenv("SEATA_STATE_MACHINE_RESOURCES"); env != "" {
parts := strings.Split(env, ",")
var res []string
for _, p := range parts {
if p = strings.TrimSpace(p); p != "" {
res = append(res, p)
}
}
return res
}
return nil
}

type Option func(*DefaultStateMachineConfig)

func WithStatusDecisionStrategy(strategy engine.StatusDecisionStrategy) Option {
return func(c *DefaultStateMachineConfig) {
c.statusDecisionStrategy = strategy
}
}

func WithSeqGenerator(gen sequence.SeqGenerator) Option {
return func(c *DefaultStateMachineConfig) {
c.seqGenerator = gen
}
}

func WithProcessController(ctrl process_ctrl.ProcessController) Option {
return func(c *DefaultStateMachineConfig) {
c.processController = ctrl
}
}

func WithBusinessProcessor(bp process_ctrl.BusinessProcessor) Option {
return func(c *DefaultStateMachineConfig) {
if pc, ok := c.processController.(*process_ctrl.ProcessControllerImpl); ok {
pc.SetBusinessProcessor(bp)
} else {
log.Printf("ProcessController is not of type *ProcessControllerImpl, unable to set BusinessProcessor")
}
}
}

func WithStateMachineResources(paths []string) Option {
return func(c *DefaultStateMachineConfig) {
if len(paths) > 0 {
c.stateMachineResources = paths
}
}
}

func WithStateLogRepository(logRepo repo.StateLogRepository) Option {
return func(c *DefaultStateMachineConfig) {
c.stateLogRepository = logRepo
}
}

func WithStateLogStore(logStore store.StateLogStore) Option {
return func(c *DefaultStateMachineConfig) {
c.stateLogStore = logStore
}
}

func WithStateLangStore(langStore store.StateLangStore) Option {
return func(c *DefaultStateMachineConfig) {
c.stateLangStore = langStore
}
}

func WithStateMachineRepository(machineRepo repo.StateMachineRepository) Option {
return func(c *DefaultStateMachineConfig) {
c.stateMachineRepository = machineRepo
}
}

+ 265
- 0
pkg/saga/statemachine/engine/config/default_statemachine_config_test.go View File

@@ -0,0 +1,265 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package config

import (
"errors"
"io"
"os"
"path/filepath"
"testing"

"github.com/seata/seata-go/pkg/saga/statemachine/statelang"
"github.com/stretchr/testify/assert"
)

func TestDefaultStateMachineConfig_LoadValidJSON(t *testing.T) {
os.Unsetenv("SEATA_STATE_MACHINE_RESOURCES")

config := NewDefaultStateMachineConfig()
testFile := filepath.Join("testdata", "order_saga.json")

err := config.LoadConfig(testFile)
assert.NoError(t, err, "Loading JSON configuration should succeed")

smo := config.GetStateMachineDefinition("OrderSaga")
assert.NotNil(t, smo, "State machine definition should not be nil")
assert.Equal(t, "CreateOrder", smo.StartState, "The start state should be correct")
assert.Contains(t, smo.States, "CreateOrder", "The state node should exist")

assert.Equal(t, 30000, config.transOperationTimeout, "The timeout should be read correctly")
}

func TestDefaultStateMachineConfig_LoadValidYAML(t *testing.T) {
os.Unsetenv("SEATA_STATE_MACHINE_RESOURCES")

config := NewDefaultStateMachineConfig()
testFile := filepath.Join("testdata", "order_saga.yaml")

err := config.LoadConfig(testFile)
assert.NoError(t, err, "Loading YAML configuration should succeed")

smo := config.GetStateMachineDefinition("OrderSaga")
assert.NotNil(t, smo)
}

func TestLoadNonExistentFile(t *testing.T) {
os.Unsetenv("SEATA_STATE_MACHINE_RESOURCES")

config := NewDefaultStateMachineConfig()
err := config.LoadConfig("non_existent.json")
assert.Error(t, err, "Loading a non-existent file should report an error")
assert.Contains(t, err.Error(), "failed to read config file", "The error message should contain file read failure")
}

func TestGetStateMachineDefinition_Exists(t *testing.T) {
os.Unsetenv("SEATA_STATE_MACHINE_RESOURCES")

config := NewDefaultStateMachineConfig()
_ = config.LoadConfig(filepath.Join("testdata", "order_saga.json"))

smo := config.GetStateMachineDefinition("OrderSaga")
assert.NotNil(t, smo)
assert.Equal(t, "1.0", smo.Version, "The version number should be correct")
}

func TestGetNonExistentStateMachine(t *testing.T) {
os.Unsetenv("SEATA_STATE_MACHINE_RESOURCES")

config := NewDefaultStateMachineConfig()
smo := config.GetStateMachineDefinition("NonExistent")
assert.Nil(t, smo, "An unloaded state machine should return nil")
}

func TestLoadDuplicateStateMachine(t *testing.T) {
os.Unsetenv("SEATA_STATE_MACHINE_RESOURCES")

config := NewDefaultStateMachineConfig()
testFile := filepath.Join("testdata", "order_saga.json")

err := config.LoadConfig(testFile)
assert.NoError(t, err)

err = config.LoadConfig(testFile)
assert.Error(t, err, "Duplicate loading should trigger a name conflict")
assert.Contains(t, err.Error(), "already exists", "The error message should contain a conflict prompt")
}

func TestRuntimeConfig_OverrideDefaults(t *testing.T) {
os.Unsetenv("SEATA_STATE_MACHINE_RESOURCES")

config := NewDefaultStateMachineConfig()
assert.Equal(t, "UTF-8", config.charset, "The default character set should be UTF-8")

_ = config.LoadConfig(filepath.Join("testdata", "order_saga.json"))
assert.Equal(t, "UTF-8", config.charset, "If the configuration does not specify, the default value should be used")

customConfig := &ConfigFileParams{
Charset: "GBK",
}
config.applyConfigFileParams(customConfig)
assert.Equal(t, "GBK", config.charset, "Runtime parameters should be correctly overridden")
}

func TestGetDefaultExpressionFactory(t *testing.T) {
os.Unsetenv("SEATA_STATE_MACHINE_RESOURCES")

config := NewDefaultStateMachineConfig()

err := config.Init()
assert.NoError(t, err, "Init should not return error")

factory := config.GetExpressionFactory("el")
assert.NotNil(t, factory, "The default EL factory should exist")

unknownFactory := config.GetExpressionFactory("unknown")
assert.Nil(t, unknownFactory, "An unknown expression type should return nil")
}

func TestGetServiceInvoker(t *testing.T) {
os.Unsetenv("SEATA_STATE_MACHINE_RESOURCES")

config := NewDefaultStateMachineConfig()
if err := config.Init(); err != nil {
t.Fatalf("init config failed: %v", err)
}

invoker := config.GetServiceInvoker("local")
if invoker == nil {
t.Errorf("expected non-nil invoker, got nil")
}
}

func TestLoadConfig_InvalidJSON(t *testing.T) {
os.Unsetenv("SEATA_STATE_MACHINE_RESOURCES")

config := NewDefaultStateMachineConfig()
testFile := filepath.Join("testdata", "invalid.json")

err := config.LoadConfig(testFile)
assert.Error(t, err, "Loading an invalid JSON configuration should report an error")
assert.Contains(t, err.Error(), "failed to parse state machine definition", "The error message should contain parsing failure")
}

func TestLoadConfig_InvalidYAML(t *testing.T) {
os.Unsetenv("SEATA_STATE_MACHINE_RESOURCES")

config := NewDefaultStateMachineConfig()
testFile := filepath.Join("testdata", "invalid.yaml")

err := config.LoadConfig(testFile)
assert.Error(t, err, "Loading an invalid YAML configuration should report an error")
assert.Contains(t, err.Error(), "failed to parse state machine definition", "The error message should contain parsing failure")
}

func TestRegisterStateMachineDef_Fail(t *testing.T) {
os.Unsetenv("SEATA_STATE_MACHINE_RESOURCES")

config := NewDefaultStateMachineConfig()
invalidResource := []string{"invalid_path.json"}

err := config.RegisterStateMachineDef(invalidResource)
assert.Error(t, err, "Registering an invalid resource should report an error")
assert.Contains(t, err.Error(), "open resource file failed", "The error message should contain file opening failure")
}

func TestInit_ExpressionFactoryManagerNil(t *testing.T) {
os.Unsetenv("SEATA_STATE_MACHINE_RESOURCES")

config := NewDefaultStateMachineConfig()
config.expressionFactoryManager = nil

err := config.Init()
assert.NoError(t, err, "Initialization should succeed when the expression factory manager is nil")
}

func TestInit_ServiceInvokerManagerNil(t *testing.T) {
os.Unsetenv("SEATA_STATE_MACHINE_RESOURCES")

config := NewDefaultStateMachineConfig()
config.serviceInvokerManager = nil

err := config.Init()
assert.NoError(t, err, "Initialization should succeed when the service invoker manager is nil")
}

func TestInit_StateMachineRepositoryNil(t *testing.T) {
os.Unsetenv("SEATA_STATE_MACHINE_RESOURCES")

config := NewDefaultStateMachineConfig()
config.stateMachineRepository = nil

err := config.Init()
assert.NoError(t, err, "Initialization should succeed when the state machine repository is nil")
}

func TestApplyRuntimeConfig_BoundaryValues(t *testing.T) {
os.Unsetenv("SEATA_STATE_MACHINE_RESOURCES")

config := NewDefaultStateMachineConfig()
customConfig := &ConfigFileParams{
TransOperationTimeout: 1,
ServiceInvokeTimeout: 1,
}
config.applyConfigFileParams(customConfig)
assert.Equal(t, 1, config.transOperationTimeout, "The minimum transaction operation timeout should be correctly applied")
assert.Equal(t, 1, config.serviceInvokeTimeout, "The minimum service invocation timeout should be correctly applied")

maxTimeout := int(^uint(0) >> 1)
customConfig = &ConfigFileParams{
TransOperationTimeout: maxTimeout,
ServiceInvokeTimeout: maxTimeout,
}
config.applyConfigFileParams(customConfig)
assert.Equal(t, maxTimeout, config.transOperationTimeout, "The maximum transaction operation timeout should be correctly applied")
assert.Equal(t, maxTimeout, config.serviceInvokeTimeout, "The maximum service invocation timeout should be correctly applied")
}

type TestStateMachineRepositoryMock struct{}

func (m *TestStateMachineRepositoryMock) GetStateMachineById(stateMachineId string) (statelang.StateMachine, error) {
return nil, errors.New("get state machine by id failed")
}

func (m *TestStateMachineRepositoryMock) GetStateMachineByNameAndTenantId(stateMachineName string, tenantId string) (statelang.StateMachine, error) {
return nil, errors.New("get state machine by name and tenant id failed")
}

func (m *TestStateMachineRepositoryMock) GetLastVersionStateMachine(stateMachineName string, tenantId string) (statelang.StateMachine, error) {
return nil, errors.New("get last version state machine failed")
}

func (m *TestStateMachineRepositoryMock) RegistryStateMachine(machine statelang.StateMachine) error {
return errors.New("registry state machine failed")
}

func (m *TestStateMachineRepositoryMock) RegistryStateMachineByReader(reader io.Reader) error {
return errors.New("registry state machine by reader failed")
}

func TestRegisterStateMachineDef_RepositoryError(t *testing.T) {
os.Unsetenv("SEATA_STATE_MACHINE_RESOURCES")

config := NewDefaultStateMachineConfig()
config.stateMachineRepository = &TestStateMachineRepositoryMock{}
resource := []string{filepath.Join("testdata", "order_saga.json")}

err := config.RegisterStateMachineDef(resource)
assert.Error(t, err, "Registration should fail when the state machine repository reports an error")
assert.Contains(t, err.Error(), "register state machine from file failed", "The error message should contain registration failure")
}

+ 86
- 0
pkg/saga/statemachine/engine/config/noop_store.go View File

@@ -0,0 +1,86 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package config

import (
"context"
"github.com/seata/seata-go/pkg/saga/statemachine/process_ctrl"
"github.com/seata/seata-go/pkg/saga/statemachine/statelang"
)

// NoopStateLogStore is a no-op implementation of StateLogStore for out-of-the-box scenarios.
// All methods perform no actual operations and return nil or zero values to ensure validation passes.
type NoopStateLogStore struct{}

func (s *NoopStateLogStore) RecordStateMachineStarted(ctx context.Context, machineInstance statelang.StateMachineInstance, pc process_ctrl.ProcessContext) error {
return nil
}

func (s *NoopStateLogStore) RecordStateMachineFinished(ctx context.Context, machineInstance statelang.StateMachineInstance, pc process_ctrl.ProcessContext) error {
return nil
}

func (s *NoopStateLogStore) RecordStateMachineRestarted(ctx context.Context, machineInstance statelang.StateMachineInstance, pc process_ctrl.ProcessContext) error {
return nil
}

func (s *NoopStateLogStore) RecordStateStarted(ctx context.Context, stateInstance statelang.StateInstance, pc process_ctrl.ProcessContext) error {
return nil
}

func (s *NoopStateLogStore) RecordStateFinished(ctx context.Context, stateInstance statelang.StateInstance, pc process_ctrl.ProcessContext) error {
return nil
}

func (s *NoopStateLogStore) GetStateMachineInstance(stateMachineInstanceId string) (statelang.StateMachineInstance, error) {
return nil, nil
}

func (s *NoopStateLogStore) GetStateMachineInstanceByBusinessKey(businessKey string, tenantId string) (statelang.StateMachineInstance, error) {
return nil, nil
}

func (s *NoopStateLogStore) GetStateMachineInstanceByParentId(parentId string) ([]statelang.StateMachineInstance, error) {
return nil, nil
}

func (s *NoopStateLogStore) GetStateInstance(stateInstanceId string, stateMachineInstanceId string) (statelang.StateInstance, error) {
return nil, nil
}

func (s *NoopStateLogStore) GetStateInstanceListByMachineInstanceId(stateMachineInstanceId string) ([]statelang.StateInstance, error) {
return nil, nil
}

func (s *NoopStateLogStore) ClearUp(pc process_ctrl.ProcessContext) {
// no-op
}

type NoopStateLangStore struct{}

func (s *NoopStateLangStore) GetStateMachineById(stateMachineId string) (statelang.StateMachine, error) {
return nil, nil
}

func (s *NoopStateLangStore) GetLastVersionStateMachine(stateMachineName string, tenantId string) (statelang.StateMachine, error) {
return nil, nil
}

func (s *NoopStateLangStore) StoreStateMachine(stateMachine statelang.StateMachine) error {
return nil
}

+ 3
- 0
pkg/saga/statemachine/engine/config/testdata/invalid.json View File

@@ -0,0 +1,3 @@
{
"name": "John",
"age": 30,

+ 16
- 0
pkg/saga/statemachine/engine/config/testdata/invalid.json.comment View File

@@ -0,0 +1,16 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

+ 19
- 0
pkg/saga/statemachine/engine/config/testdata/invalid.yaml View File

@@ -0,0 +1,19 @@
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

parent:
child1: value1
child2: value2
child3: value3

+ 274
- 0
pkg/saga/statemachine/engine/config/testdata/order_saga.json View File

@@ -0,0 +1,274 @@
{
"Name": "OrderSaga",
"Version": "1.0",
"StartState": "CreateOrder",
"trans_operation_timeout": 30000,
"States": {
"CreateOrder": {
"Type": "ServiceTask",
"serviceType": "local",
"serviceName": "orderService",
"serviceMethod": "createOrder",
"CompensateState": "CancelOrder",
"ForCompensation": false,
"ForUpdate": false,
"Retry": [
{
"Exceptions": [
"OrderCreationException",
"InventoryUnavailableException"
],
"IntervalSeconds": 2,
"MaxAttempts": 3,
"BackoffRate": 1.5
}
],
"Catches": [
{
"Exceptions": [
"OrderCreationException",
"InventoryUnavailableException"
],
"Next": "ErrorHandler"
}
],
"Status": {
"return.code == 'SUCCESS'": "SUCCEEDED",
"return.code == 'FAIL'": "FAILED",
"$exception{*}": "UNKNOWN"
},
"Input": [
{
"orderInfo": "$.orderInfo"
}
],
"Output": {
"orderId": "$.#root"
},
"Next": "CheckStock",
"Loop": {
"Parallel": 1,
"Collection": "$.orderItems",
"ElementVariableName": "item",
"ElementIndexName": "index",
"CompletionCondition": "[nrOfInstances] == [nrOfCompletedInstances]"
}
},
"CheckStock": {
"Type": "ServiceTask",
"serviceType": "local",
"serviceName": "inventoryService",
"serviceMethod": "checkAvailability",
"CompensateState": "RollbackStock",
"ForCompensation": false,
"ForUpdate": false,
"Retry": [
{
"Exceptions": [
"StockCheckException"
],
"IntervalSeconds": 2,
"MaxAttempts": 2,
"BackoffRate": 1.2
}
],
"Catches": [
{
"Exceptions": [
"StockCheckException"
],
"Next": "ErrorHandler"
}
],
"Status": {
"return.available == true": "IN_STOCK",
"return.available == false": "OUT_OF_STOCK",
"$exception{*}": "UNKNOWN"
},
"Input": [
{
"orderId": "$.orderId"
},
{
"itemsList": "$.orderItems"
}
],
"Output": {
"stockAvailable": "$.#root"
},
"Next": "DecideStock"
},
"DecideStock": {
"Type": "Choice",
"Choices": [
{
"Expression": "stockAvailable == true",
"Next": "ReserveStock"
},
{
"Expression": "stockAvailable == false",
"Next": "CancelOrder"
}
],
"Default": "ErrorHandler"
},
"ReserveStock": {
"Type": "ServiceTask",
"serviceType": "local",
"serviceName": "inventoryService",
"serviceMethod": "reserveItems",
"CompensateState": "RollbackStock",
"ForCompensation": false,
"ForUpdate": false,
"Retry": [
{
"Exceptions": [
"StockReservationException"
],
"IntervalSeconds": 2,
"MaxAttempts": 2,
"BackoffRate": 1.2
}
],
"Catches": [
{
"Exceptions": [
"StockReservationException"
],
"Next": "ErrorHandler"
}
],
"Status": {
"return.code == 'RESERVED'": "STOCK_RESERVED",
"return.code == 'FAILED'": "FAILED",
"$exception{*}": "UNKNOWN"
},
"Input": [
{
"orderId": "$.orderId"
},
{
"itemList": "$.orderItems"
}
],
"Output": {
"stockReservationId": "$.#root"
},
"Next": "ProcessPayment"
},
"ProcessPayment": {
"Type": "ServiceTask",
"serviceType": "local",
"serviceName": "paymentService",
"serviceMethod": "processPayment",
"CompensateState": "RefundPayment",
"ForCompensation": false,
"ForUpdate": false,
"Retry": [
{
"Exceptions": [
"PaymentProcessingException"
],
"IntervalSeconds": 3,
"MaxAttempts": 3,
"BackoffRate": 1.5
}
],
"Catches": [
{
"Exceptions": [
"PaymentProcessingException"
],
"Next": "ErrorHandler"
}
],
"Status": {
"return.code == 'PAID'": "PAYMENT_SUCCESS",
"return.code == 'DECLINED'": "PAYMENT_FAILED",
"$exception{*}": "UNKNOWN"
},
"Input": [
{
"orderId": "$.orderId"
},
{
"amount": "$.orderTotal"
}
],
"Output": {
"paymentTransactionId": "$.#root"
},
"Next": "CompleteOrder"
},
"CompleteOrder": {
"Type": "Succeed"
},
"CancelOrder": {
"Type": "ServiceTask",
"serviceType": "local",
"serviceName": "orderService",
"serviceMethod": "cancelOrder",
"ForCompensation": true,
"ForUpdate": true,
"Input": [
{
"orderId": "$.orderId"
}
],
"Output": {
"cancelResult": "$.#root"
},
"Next": "RollbackStock"
},
"RollbackStock": {
"Type": "ServiceTask",
"serviceType": "local",
"serviceName": "inventoryService",
"serviceMethod": "releaseItems",
"ForCompensation": true,
"ForUpdate": true,
"Input": [
{
"orderId": "$.orderId"
},
{
"stockReservationId": "$.stockReservationId"
}
],
"Output": {
"rollbackResult": "$.#root"
},
"Next": "RefundPayment"
},
"RefundPayment": {
"Type": "ServiceTask",
"serviceType": "local",
"serviceName": "paymentService",
"serviceMethod": "refundPayment",
"ForCompensation": true,
"ForUpdate": true,
"Input": [
{
"orderId": "$.orderId"
},
{
"paymentTransactionId": "$.paymentTransactionId"
}
],
"Output": {
"refundResult": "$.#root"
},
"Next": "FailState"
},
"ErrorHandler": {
"Type": "Fail",
"ErrorCode": "ORDER_PROCESSING_ERROR",
"Message": "An unrecoverable error occurred during order processing."
},
"FailState": {
"Type": "Fail",
"ErrorCode": "ORDER_CANCELLED",
"Message": "The order has been cancelled and compensation actions have been completed."
}
}
}

+ 16
- 0
pkg/saga/statemachine/engine/config/testdata/order_saga.json.comment View File

@@ -0,0 +1,16 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

+ 204
- 0
pkg/saga/statemachine/engine/config/testdata/order_saga.yaml View File

@@ -0,0 +1,204 @@
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Name: "OrderSaga"
Version: "1.0"
StartState: "CreateOrder"
trans_operation_timeout: 30000
States:
CreateOrder:
Type: "ServiceTask"
serviceType: "local"
serviceName: "orderService"
serviceMethod: "createOrder"
CompensateState: "CancelOrder"
ForCompensation: false
ForUpdate: false
Retry:
- Exceptions:
- "OrderCreationException"
- "InventoryUnavailableException"
IntervalSeconds: 2
MaxAttempts: 3
BackoffRate: 1.5
Catches:
- Exceptions:
- "OrderCreationException"
- "InventoryUnavailableException"
Next: "ErrorHandler"
Status:
"return.code == 'SUCCESS'": "SUCCEEDED"
"return.code == 'FAIL'": "FAILED"
"$exception{*}": "UNKNOWN"
Input:
- orderInfo: "$.orderInfo"
Output:
orderId: "$.#root"
Next: "CheckStock"
Loop:
Parallel: 1
Collection: "$.orderItems"
ElementVariableName: "item"
ElementIndexName: "index"
CompletionCondition: "[nrOfInstances] == [nrOfCompletedInstances]"

CheckStock:
Type: "ServiceTask"
serviceType: "local"
serviceName: "inventoryService"
serviceMethod: "checkAvailability"
CompensateState: "RollbackStock"
ForCompensation: false
ForUpdate: false
Retry:
- Exceptions:
- "StockCheckException"
IntervalSeconds: 2
MaxAttempts: 2
BackoffRate: 1.2
Catches:
- Exceptions:
- "StockCheckException"
Next: "ErrorHandler"
Status:
"return.available == true": "IN_STOCK"
"return.available == false": "OUT_OF_STOCK"
"$exception{*}": "UNKNOWN"
Input:
- orderId: "$.orderId"
- itemsList: "$.orderItems"
Output:
stockAvailable: "$.#root"
Next: "DecideStock"

DecideStock:
Type: "Choice"
Choices:
- Expression: "stockAvailable == true"
Next: "ReserveStock"
- Expression: "stockAvailable == false"
Next: "CancelOrder"
Default: "ErrorHandler"

ReserveStock:
Type: "ServiceTask"
serviceType: "local"
serviceName: "inventoryService"
serviceMethod: "reserveItems"
CompensateState: "RollbackStock"
ForCompensation: false
ForUpdate: false
Retry:
- Exceptions:
- "StockReservationException"
IntervalSeconds: 2
MaxAttempts: 2
BackoffRate: 1.2
Catches:
- Exceptions:
- "StockReservationException"
Next: "ErrorHandler"
Status:
"return.code == 'RESERVED'": "STOCK_RESERVED"
"return.code == 'FAILED'": "FAILED"
"$exception{*}": "UNKNOWN"
Input:
- orderId: "$.orderId"
- itemList: "$.orderItems"
Output:
stockReservationId: "$.#root"
Next: "ProcessPayment"

ProcessPayment:
Type: "ServiceTask"
serviceType: "local"
serviceName: "paymentService"
serviceMethod: "processPayment"
CompensateState: "RefundPayment"
ForCompensation: false
ForUpdate: false
Retry:
- Exceptions:
- "PaymentProcessingException"
IntervalSeconds: 3
MaxAttempts: 3
BackoffRate: 1.5
Catches:
- Exceptions:
- "PaymentProcessingException"
Next: "ErrorHandler"
Status:
"return.code == 'PAID'": "PAYMENT_SUCCESS"
"return.code == 'DECLINED'": "PAYMENT_FAILED"
"$exception{*}": "UNKNOWN"
Input:
- orderId: "$.orderId"
- amount: "$.orderTotal"
Output:
paymentTransactionId: "$.#root"
Next: "CompleteOrder"

CompleteOrder:
Type: "Succeed"

CancelOrder:
Type: "ServiceTask"
serviceType: "local"
serviceName: "orderService"
serviceMethod: "cancelOrder"
ForCompensation: true
ForUpdate: true
Input:
- orderId: "$.orderId"
Output:
cancelResult: "$.#root"
Next: "RollbackStock"

RollbackStock:
Type: "ServiceTask"
serviceType: "local"
serviceName: "inventoryService"
serviceMethod: "releaseItems"
ForCompensation: true
ForUpdate: true
Input:
- orderId: "$.orderId"
- stockReservationId: "$.stockReservationId"
Output:
rollbackResult: "$.#root"
Next: "RefundPayment"

RefundPayment:
Type: "ServiceTask"
serviceType: "local"
serviceName: "paymentService"
serviceMethod: "refundPayment"
ForCompensation: true
ForUpdate: true
Input:
- orderId: "$.orderId"
- paymentTransactionId: "$.paymentTransactionId"
Output:
refundResult: "$.#root"
Next: "FailState"

ErrorHandler:
Type: "Fail"
ErrorCode: "ORDER_PROCESSING_ERROR"
Message: "订单处理过程中发生不可恢复的错误。"

FailState:
Type: "Fail"
ErrorCode: "ORDER_CANCELLED"
Message: "订单已取消,触发补偿完成。"

+ 722
- 0
pkg/saga/statemachine/engine/core/process_ctrl_statemachine_engine.go View File

@@ -0,0 +1,722 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package core

import (
"context"
"fmt"
"github.com/pkg/errors"
"github.com/seata/seata-go/pkg/saga/statemachine/constant"
"github.com/seata/seata-go/pkg/saga/statemachine/engine"
"github.com/seata/seata-go/pkg/saga/statemachine/engine/config"
"github.com/seata/seata-go/pkg/saga/statemachine/engine/exception"
"github.com/seata/seata-go/pkg/saga/statemachine/engine/pcext"
"github.com/seata/seata-go/pkg/saga/statemachine/engine/utils"
"github.com/seata/seata-go/pkg/saga/statemachine/process_ctrl"
"github.com/seata/seata-go/pkg/saga/statemachine/process_ctrl/process"
"github.com/seata/seata-go/pkg/saga/statemachine/statelang"
"github.com/seata/seata-go/pkg/saga/statemachine/statelang/state"
seataErrors "github.com/seata/seata-go/pkg/util/errors"
"github.com/seata/seata-go/pkg/util/log"
"time"
)

type ProcessCtrlStateMachineEngine struct {
StateMachineConfig engine.StateMachineConfig
}

func NewProcessCtrlStateMachineEngine() *ProcessCtrlStateMachineEngine {
return &ProcessCtrlStateMachineEngine{
StateMachineConfig: config.NewDefaultStateMachineConfig(),
}
}

func (p ProcessCtrlStateMachineEngine) Start(ctx context.Context, stateMachineName string, tenantId string,
startParams map[string]interface{}) (statelang.StateMachineInstance, error) {
return p.startInternal(ctx, stateMachineName, tenantId, "", startParams, false, nil)
}

func (p ProcessCtrlStateMachineEngine) StartAsync(ctx context.Context, stateMachineName string, tenantId string,
startParams map[string]interface{}, callback engine.CallBack) (statelang.StateMachineInstance, error) {
return p.startInternal(ctx, stateMachineName, tenantId, "", startParams, true, callback)
}

func (p ProcessCtrlStateMachineEngine) StartWithBusinessKey(ctx context.Context, stateMachineName string,
tenantId string, businessKey string, startParams map[string]interface{}) (statelang.StateMachineInstance, error) {
return p.startInternal(ctx, stateMachineName, tenantId, businessKey, startParams, false, nil)
}

func (p ProcessCtrlStateMachineEngine) StartWithBusinessKeyAsync(ctx context.Context, stateMachineName string,
tenantId string, businessKey string, startParams map[string]interface{}, callback engine.CallBack) (statelang.StateMachineInstance, error) {
return p.startInternal(ctx, stateMachineName, tenantId, businessKey, startParams, true, callback)
}

func (p ProcessCtrlStateMachineEngine) Forward(ctx context.Context, stateMachineInstId string,
replaceParams map[string]interface{}) (statelang.StateMachineInstance, error) {
return p.forwardInternal(ctx, stateMachineInstId, replaceParams, false, false, nil)
}

func (p ProcessCtrlStateMachineEngine) ForwardAsync(ctx context.Context, stateMachineInstId string, replaceParams map[string]interface{}, callback engine.CallBack) (statelang.StateMachineInstance, error) {
return p.forwardInternal(ctx, stateMachineInstId, replaceParams, false, true, callback)
}

func (p ProcessCtrlStateMachineEngine) Compensate(ctx context.Context, stateMachineInstId string,
replaceParams map[string]any) (statelang.StateMachineInstance, error) {
return p.compensateInternal(ctx, stateMachineInstId, replaceParams, false, nil)
}

func (p ProcessCtrlStateMachineEngine) CompensateAsync(ctx context.Context, stateMachineInstId string, replaceParams map[string]interface{}, callback engine.CallBack) (statelang.StateMachineInstance, error) {
return p.compensateInternal(ctx, stateMachineInstId, replaceParams, true, callback)
}

func (p ProcessCtrlStateMachineEngine) SkipAndForward(ctx context.Context, stateMachineInstId string, replaceParams map[string]interface{}) (statelang.StateMachineInstance, error) {
return p.forwardInternal(ctx, stateMachineInstId, replaceParams, true, false, nil)
}

func (p ProcessCtrlStateMachineEngine) SkipAndForwardAsync(ctx context.Context, stateMachineInstId string, callback engine.CallBack) (statelang.StateMachineInstance, error) {
return p.forwardInternal(ctx, stateMachineInstId, nil, true, true, callback)
}

func (p ProcessCtrlStateMachineEngine) GetStateMachineConfig() engine.StateMachineConfig {
return p.StateMachineConfig
}

func (p ProcessCtrlStateMachineEngine) ReloadStateMachineInstance(ctx context.Context, instId string) (statelang.StateMachineInstance, error) {
inst, err := p.StateMachineConfig.StateLogStore().GetStateMachineInstance(instId)
if err != nil {
return nil, err
}
if inst != nil {
stateMachine := inst.StateMachine()
if stateMachine == nil {
stateMachine, err = p.StateMachineConfig.StateMachineRepository().GetStateMachineById(inst.MachineID())
if err != nil {
return nil, err
}
inst.SetStateMachine(stateMachine)
}
if stateMachine == nil {
return nil, exception.NewEngineExecutionException(seataErrors.ObjectNotExists,
"StateMachine[id:"+inst.MachineID()+"] not exist.", nil)
}

stateList := inst.StateList()
if len(stateList) == 0 {
stateList, err = p.StateMachineConfig.StateLogStore().GetStateInstanceListByMachineInstanceId(instId)
if err != nil {
return nil, err
}
if len(stateList) > 0 {
for _, tmpStateInstance := range stateList {
inst.PutState(tmpStateInstance.ID(), tmpStateInstance)
}
}
}

if len(inst.EndParams()) == 0 {
endParams, err := p.replayContextVariables(ctx, inst)
if err != nil {
return nil, err
}
inst.SetEndParams(endParams)
}
}
return inst, nil
}

func (p ProcessCtrlStateMachineEngine) startInternal(ctx context.Context, stateMachineName string, tenantId string,
businessKey string, startParams map[string]interface{}, async bool, callback engine.CallBack) (statelang.StateMachineInstance, error) {
if tenantId == "" {
tenantId = p.StateMachineConfig.GetDefaultTenantId()
}

stateMachineInstance, err := p.createMachineInstance(stateMachineName, tenantId, businessKey, startParams)
if err != nil {
return nil, err
}

// Build the process_ctrl context.
processContextBuilder := utils.NewProcessContextBuilder().
WithProcessType(process.StateLang).
WithOperationName(constant.OperationNameStart).
WithAsyncCallback(callback).
WithInstruction(pcext.NewStateInstruction(stateMachineName, tenantId)).
WithStateMachineInstance(stateMachineInstance).
WithStateMachineConfig(p.StateMachineConfig).
WithStateMachineEngine(p).
WithIsAsyncExecution(async)

contextMap := p.copyMap(startParams)

stateMachineInstance.SetContext(contextMap)

processContext := processContextBuilder.WithStateMachineContextVariables(contextMap).Build()

if stateMachineInstance.StateMachine().IsPersist() && p.StateMachineConfig.StateLogStore() != nil {
err := p.StateMachineConfig.StateLogStore().RecordStateMachineStarted(ctx, stateMachineInstance, processContext)
if err != nil {
return nil, err
}
}

if stateMachineInstance.ID() == "" {
stateMachineInstance.SetID(p.StateMachineConfig.SeqGenerator().GenerateId(constant.SeqEntityStateMachineInst, ""))
}

var eventPublisher process_ctrl.EventPublisher
if async {
eventPublisher = p.StateMachineConfig.AsyncEventPublisher()
} else {
eventPublisher = p.StateMachineConfig.EventPublisher()
}

_, err = eventPublisher.PushEvent(ctx, processContext)
if err != nil {
return nil, err
}

return stateMachineInstance, nil
}

func (p ProcessCtrlStateMachineEngine) forwardInternal(ctx context.Context, stateMachineInstId string,
replaceParams map[string]interface{}, skip bool, async bool, callback engine.CallBack) (statelang.StateMachineInstance, error) {
stateMachineInstance, err := p.reloadStateMachineInstance(ctx, stateMachineInstId)
if err != nil {
return nil, err
}

if stateMachineInstance == nil {
return nil, exception.NewEngineExecutionException(seataErrors.StateMachineInstanceNotExists, "StateMachineInstance is not exists", nil)
}

if stateMachineInstance.Status() == statelang.SU && stateMachineInstance.CompensationStatus() == "" {
return stateMachineInstance, nil
}

acceptStatus := []statelang.ExecutionStatus{statelang.FA, statelang.UN, statelang.RU}
if _, err := p.checkStatus(ctx, stateMachineInstance, acceptStatus, nil, stateMachineInstance.Status(), "", "forward"); err != nil {
return nil, err
}

actList := stateMachineInstance.StateList()
if len(actList) == 0 {
return nil, exception.NewEngineExecutionException(seataErrors.OperationDenied,
fmt.Sprintf("StateMachineInstance[id:%s] has no stateInstance, please start a new StateMachine execution instead", stateMachineInstId), nil)
}

lastForwardState, err := p.findOutLastForwardStateInstance(actList)
if err != nil {
return nil, err
}
if lastForwardState == nil {
return nil, exception.NewEngineExecutionException(seataErrors.OperationDenied,
fmt.Sprintf("StateMachineInstance[id:%s] Cannot find last forward execution stateInstance", stateMachineInstId), nil)
}

contextBuilder := utils.NewProcessContextBuilder().
WithProcessType(process.StateLang).
WithOperationName(constant.OperationNameForward).
WithAsyncCallback(callback).
WithStateMachineInstance(stateMachineInstance).
WithStateInstance(lastForwardState).
WithStateMachineConfig(p.StateMachineConfig).
WithStateMachineEngine(p).
WithIsAsyncExecution(async)

context := contextBuilder.Build()

contextVariables, err := p.getStateMachineContextVariables(ctx, stateMachineInstance)
if err != nil {
return nil, err
}

if replaceParams != nil {
for k, v := range replaceParams {
contextVariables[k] = v
}
}
p.putBusinesskeyToContextariables(stateMachineInstance, contextVariables)

concurrentContextVariables := p.copyMap(contextVariables)

context.SetVariable(constant.VarNameStateMachineContext, concurrentContextVariables)
stateMachineInstance.SetContext(concurrentContextVariables)

originStateName := pcext.GetOriginStateName(lastForwardState)
lastState := stateMachineInstance.StateMachine().State(originStateName)
loop := pcext.GetLoopConfig(ctx, context, lastState)
if loop != nil && lastForwardState.Status() == statelang.SU {
lastForwardState = p.findOutLastNeedForwardStateInstance(ctx, context)
}

context.SetVariable(lastForwardState.Name()+constant.VarNameRetriedStateInstId, lastForwardState.ID())
if lastForwardState.Type() == constant.StateTypeSubStateMachine && lastForwardState.CompensationStatus() != statelang.SU {
context.SetVariable(constant.VarNameIsForSubStatMachineForward, true)
}

if lastForwardState.Status() != statelang.SU {
lastForwardState.SetIgnoreStatus(true)
}

inst := pcext.NewStateInstruction(stateMachineInstance.StateMachine().Name(), stateMachineInstance.TenantID())
if skip || lastForwardState.Status() == statelang.SU {
next := ""
curState := stateMachineInstance.StateMachine().State(pcext.GetOriginStateName(lastForwardState))
if taskState, ok := curState.(*state.AbstractTaskState); ok {
next = taskState.Next()
}
if next == "" {
log.Warn(fmt.Sprintf("Last Forward execution StateInstance was succeed, and it has not Next State, skip forward operation"))
return stateMachineInstance, nil
}
inst.SetStateName(next)
} else {
if lastForwardState.Status() == statelang.RU && !pcext.IsTimeout(lastForwardState.StartedTime(), p.StateMachineConfig.GetServiceInvokeTimeout()) {
return nil, exception.NewEngineExecutionException(seataErrors.OperationDenied,
fmt.Sprintf("State [%s] is running, operation[forward] denied", lastForwardState.Name()), nil)
}
inst.SetStateName(pcext.GetOriginStateName(lastForwardState))
}
context.SetInstruction(inst)

stateMachineInstance.SetStatus(statelang.RU)
stateMachineInstance.SetRunning(true)

log.Info(fmt.Sprintf("Operation [forward] started stateMachineInstance[id:%s]", stateMachineInstance.ID()))

if stateMachineInstance.StateMachine().IsPersist() {
if err := p.StateMachineConfig.StateLogStore().RecordStateMachineRestarted(ctx, stateMachineInstance, context); err != nil {
return nil, err
}
}

curState, err := inst.GetState(context)
if err != nil {
return nil, err
}
loop = pcext.GetLoopConfig(ctx, context, curState)
if loop != nil {
inst.SetTemporaryState(state.NewLoopStartStateImpl())
}

if async {
if _, err := p.StateMachineConfig.AsyncEventPublisher().PushEvent(ctx, context); err != nil {
return nil, err
}
} else {
if _, err := p.StateMachineConfig.EventPublisher().PushEvent(ctx, context); err != nil {
return nil, err
}
}

return stateMachineInstance, nil
}

func (p ProcessCtrlStateMachineEngine) findOutLastForwardStateInstance(stateInstanceList []statelang.StateInstance) (statelang.StateInstance, error) {
var lastForwardStateInstance statelang.StateInstance
var err error
for i := len(stateInstanceList) - 1; i >= 0; i-- {
stateInstance := stateInstanceList[i]
if !stateInstance.IsForCompensation() {
if stateInstance.CompensationStatus() == statelang.SU {
continue
}

if stateInstance.Type() == constant.StateTypeSubStateMachine {
finalState := stateInstance
for finalState.StateIDRetriedFor() != "" {
if finalState, err = p.StateMachineConfig.StateLogStore().GetStateInstance(finalState.StateIDRetriedFor(),
finalState.MachineInstanceID()); err != nil {
return nil, err
}
}

subInst, _ := p.StateMachineConfig.StateLogStore().GetStateMachineInstanceByParentId(pcext.GenerateParentId(finalState))
if len(subInst) > 0 {
if subInst[0].CompensationStatus() == statelang.SU {
continue
}

if subInst[0].CompensationStatus() == statelang.UN {
return nil, exception.NewEngineExecutionException(seataErrors.ForwardInvalid,
"Last forward execution state instance is SubStateMachine and compensation status is [UN], Operation[forward] denied, stateInstanceId:"+stateInstance.ID(),
nil)
}
}
} else if stateInstance.CompensationStatus() == statelang.UN {
return nil, exception.NewEngineExecutionException(seataErrors.ForwardInvalid,
"Last forward execution state instance compensation status is [UN], Operation[forward] denied, stateInstanceId:"+stateInstance.ID(),
nil)
}

lastForwardStateInstance = stateInstance
break
}
}
return lastForwardStateInstance, nil
}

// copyMap not deep copy, so best practice: Don’t pass by reference
func (p ProcessCtrlStateMachineEngine) copyMap(startParams map[string]interface{}) map[string]interface{} {
copyMap := make(map[string]interface{}, len(startParams))
for k, v := range startParams {
copyMap[k] = v
}
return copyMap
}

func (p ProcessCtrlStateMachineEngine) createMachineInstance(stateMachineName string, tenantId string, businessKey string, startParams map[string]interface{}) (statelang.StateMachineInstance, error) {
stateMachine, err := p.StateMachineConfig.StateMachineRepository().GetLastVersionStateMachine(stateMachineName, tenantId)
if err != nil {
return nil, err
}

if stateMachine == nil {
return nil, errors.New("StateMachine [" + stateMachineName + "] is not exists")
}

stateMachineInstance := statelang.NewStateMachineInstanceImpl()
stateMachineInstance.SetStateMachine(stateMachine)
stateMachineInstance.SetTenantID(tenantId)
stateMachineInstance.SetBusinessKey(businessKey)
stateMachineInstance.SetStartParams(startParams)
if startParams != nil {
if businessKey != "" {
startParams[constant.VarNameBusinesskey] = businessKey
}

if startParams[constant.VarNameParentId] != nil {
parentId, ok := startParams[constant.VarNameParentId].(string)
if !ok {

}
stateMachineInstance.SetParentID(parentId)
delete(startParams, constant.VarNameParentId)
}
}

stateMachineInstance.SetStatus(statelang.RU)
stateMachineInstance.SetRunning(true)

now := time.Now()
stateMachineInstance.SetStartedTime(now)
stateMachineInstance.SetUpdatedTime(now)
return stateMachineInstance, nil
}

func (p ProcessCtrlStateMachineEngine) compensateInternal(ctx context.Context, stateMachineInstId string, replaceParams map[string]any,
async bool, callback engine.CallBack) (statelang.StateMachineInstance, error) {
stateMachineInstance, err := p.reloadStateMachineInstance(ctx, stateMachineInstId)
if err != nil {
return nil, err
}

if stateMachineInstance == nil {
return nil, exception.NewEngineExecutionException(seataErrors.StateMachineInstanceNotExists,
"StateMachineInstance is not exits", nil)
}

if statelang.SU == stateMachineInstance.CompensationStatus() {
return stateMachineInstance, nil
}

if stateMachineInstance.CompensationStatus() != "" {
denyStatus := make([]statelang.ExecutionStatus, 0)
denyStatus = append(denyStatus, statelang.SU)
p.checkStatus(ctx, stateMachineInstance, nil, denyStatus, "", stateMachineInstance.CompensationStatus(),
"compensate")
}

if replaceParams != nil {
for key, value := range replaceParams {
stateMachineInstance.EndParams()[key] = value
}
}

contextBuilder := utils.NewProcessContextBuilder().WithProcessType(process.StateLang).
WithOperationName(constant.OperationNameCompensate).WithAsyncCallback(callback).
WithStateMachineInstance(stateMachineInstance).
WithStateMachineConfig(p.StateMachineConfig).WithStateMachineEngine(p).WithIsAsyncExecution(async)

context := contextBuilder.Build()

contextVariables, err := p.getStateMachineContextVariables(ctx, stateMachineInstance)

if replaceParams != nil {
for key, value := range replaceParams {
contextVariables[key] = value
}
}

p.putBusinesskeyToContextariables(stateMachineInstance, contextVariables)

// TODO: Here is not use sync.map, make sure whether to use it
concurrentContextVariables := make(map[string]any)
p.nullSafeCopy(contextVariables, concurrentContextVariables)

context.SetVariable(constant.VarNameStateMachineContext, concurrentContextVariables)
stateMachineInstance.SetContext(concurrentContextVariables)

tempCompensationTriggerState := state.NewCompensationTriggerStateImpl()
tempCompensationTriggerState.SetStateMachine(stateMachineInstance.StateMachine())

stateMachineInstance.SetRunning(true)

log.Info("Operation [compensate] start. stateMachineInstance[id:" + stateMachineInstance.ID() + "]")

if stateMachineInstance.StateMachine().IsPersist() {
err := p.StateMachineConfig.StateLogStore().RecordStateMachineRestarted(ctx, stateMachineInstance, context)
if err != nil {
return nil, err
}
}

inst := pcext.NewStateInstruction(stateMachineInstance.TenantID(), stateMachineInstance.StateMachine().Name())
inst.SetTemporaryState(tempCompensationTriggerState)
context.SetInstruction(inst)

if async {
_, err := p.StateMachineConfig.AsyncEventPublisher().PushEvent(ctx, context)
if err != nil {
return nil, err
}
} else {
_, err := p.StateMachineConfig.EventPublisher().PushEvent(ctx, context)
if err != nil {
return nil, err
}
}

return stateMachineInstance, nil
}

func (p ProcessCtrlStateMachineEngine) reloadStateMachineInstance(ctx context.Context, instId string) (statelang.StateMachineInstance, error) {
instance, err := p.StateMachineConfig.StateLogStore().GetStateMachineInstance(instId)
if err != nil {
return nil, err
}
if instance != nil {
stateMachine := instance.StateMachine()
if stateMachine == nil {
stateMachine, err = p.StateMachineConfig.StateMachineRepository().GetStateMachineById(instance.MachineID())
if err != nil {
return nil, err
}
instance.SetStateMachine(stateMachine)
}
if stateMachine == nil {
return nil, exception.NewEngineExecutionException(seataErrors.ObjectNotExists,
"StateMachine[id:"+instance.MachineID()+"] not exist.", nil)
}

stateList := instance.StateList()
if stateList == nil || len(stateList) == 0 {
stateList, err = p.StateMachineConfig.StateLogStore().GetStateInstanceListByMachineInstanceId(instId)
if err != nil {
return nil, err
}
if stateList != nil && len(stateList) > 0 {
for _, tmpStateInstance := range stateList {
instance.PutState(tmpStateInstance.ID(), tmpStateInstance)
}
}
}

if instance.EndParams() == nil || len(instance.EndParams()) == 0 {
variables, err := p.replayContextVariables(ctx, instance)
if err != nil {
return nil, err
}
instance.SetEndParams(variables)
}
}
return instance, nil
}

func (p ProcessCtrlStateMachineEngine) replayContextVariables(ctx context.Context, stateMachineInstance statelang.StateMachineInstance) (map[string]any, error) {
contextVariables := make(map[string]any)
if stateMachineInstance.StartParams() != nil {
for key, value := range stateMachineInstance.StartParams() {
contextVariables[key] = value
}
}

stateInstanceList := stateMachineInstance.StateList()
if stateInstanceList == nil || len(stateInstanceList) == 0 {
return contextVariables, nil
}

for _, stateInstance := range stateInstanceList {
serviceOutputParams := stateInstance.OutputParams()
if serviceOutputParams != nil {
serviceTaskStateImpl, ok := stateMachineInstance.StateMachine().State(pcext.GetOriginStateName(stateInstance)).(*state.ServiceTaskStateImpl)
if !ok {
return nil, exception.NewEngineExecutionException(seataErrors.ObjectNotExists,
"Cannot find State by state name ["+stateInstance.Name()+"], may be this is a bug", nil)
}

if serviceTaskStateImpl.Output() != nil && len(serviceTaskStateImpl.Output()) != 0 {
outputVariablesToContext, err := pcext.CreateOutputParams(p.StateMachineConfig,
p.StateMachineConfig.ExpressionResolver(), serviceTaskStateImpl.AbstractTaskState, serviceOutputParams)
if err != nil {
return nil, exception.NewEngineExecutionException(seataErrors.ObjectNotExists,
"Context variable replay failed", err)
}
if outputVariablesToContext != nil && len(outputVariablesToContext) != 0 {
for key, value := range outputVariablesToContext {
contextVariables[key] = value
}
}
if len(stateInstance.BusinessKey()) > 0 {
contextVariables[serviceTaskStateImpl.Name()+constant.VarNameBusinesskey] = stateInstance.BusinessKey()
}
}
}
}

return contextVariables, nil
}

func (p ProcessCtrlStateMachineEngine) checkStatus(ctx context.Context, stateMachineInstance statelang.StateMachineInstance,
acceptStatus []statelang.ExecutionStatus, denyStatus []statelang.ExecutionStatus, status statelang.ExecutionStatus,
compenStatus statelang.ExecutionStatus, operation string) (bool, error) {
if status != "" && compenStatus != "" {
return false, exception.NewEngineExecutionException(seataErrors.InvalidParameter,
"status and compensationStatus are not supported at the same time", nil)
}
if status == "" && compenStatus == "" {
return false, exception.NewEngineExecutionException(seataErrors.InvalidParameter,
"status and compensationStatus must input at least one", nil)
}
if statelang.SU == compenStatus {
message := p.buildExceptionMessage(stateMachineInstance, nil, nil, "", statelang.SU, operation)
return false, exception.NewEngineExecutionException(seataErrors.OperationDenied,
message, nil)
}

if stateMachineInstance.IsRunning() &&
!pcext.IsTimeout(stateMachineInstance.UpdatedTime(), p.StateMachineConfig.GetTransOperationTimeout()) {
return false, exception.NewEngineExecutionException(seataErrors.OperationDenied,
"StateMachineInstance [id:"+stateMachineInstance.ID()+"] is running, operation["+operation+
"] denied", nil)
}

if (denyStatus == nil || len(denyStatus) == 0) && (acceptStatus == nil || len(acceptStatus) == 0) {
return false, exception.NewEngineExecutionException(seataErrors.InvalidParameter,
"StateMachineInstance[id:"+stateMachineInstance.ID()+
"], acceptable status and deny status must input at least one", nil)
}

currentStatus := compenStatus
if status != "" {
currentStatus = status
}

if denyStatus != nil && len(denyStatus) == 0 {
for _, tempDenyStatus := range denyStatus {
if tempDenyStatus == currentStatus {
message := p.buildExceptionMessage(stateMachineInstance, acceptStatus, denyStatus, status,
compenStatus, operation)
return false, exception.NewEngineExecutionException(seataErrors.OperationDenied,
message, nil)
}
}
}

if acceptStatus == nil || len(acceptStatus) == 0 {
return true, nil
} else {
for _, tempStatus := range acceptStatus {
if tempStatus == currentStatus {
return true, nil
}
}
}

message := p.buildExceptionMessage(stateMachineInstance, acceptStatus, denyStatus, status, compenStatus,
operation)
return false, exception.NewEngineExecutionException(seataErrors.OperationDenied,
message, nil)
}

func (p ProcessCtrlStateMachineEngine) getStateMachineContextVariables(ctx context.Context,
stateMachineInstance statelang.StateMachineInstance) (map[string]any, error) {
contextVariables := stateMachineInstance.EndParams()
if contextVariables == nil || len(contextVariables) == 0 {
return p.replayContextVariables(ctx, stateMachineInstance)
}
return contextVariables, nil
}

func (p ProcessCtrlStateMachineEngine) buildExceptionMessage(instance statelang.StateMachineInstance,
acceptStatus []statelang.ExecutionStatus, denyStatus []statelang.ExecutionStatus, status statelang.ExecutionStatus,
compenStatus statelang.ExecutionStatus, operation string) string {
message := fmt.Sprintf("StateMachineInstance[id:%s]", instance.ID())
if len(acceptStatus) > 0 {
message += ",acceptable status :"
for _, tempStatus := range acceptStatus {
message += string(tempStatus) + " "
}
}

if len(denyStatus) > 0 {
message += ",deny status:"
for _, tempStatus := range denyStatus {
message += string(tempStatus) + " "
}
}

if status != "" {
message += ",current status:" + string(status)
}

if compenStatus != "" {
message += ",current compensation status:" + string(compenStatus)
}

message += fmt.Sprintf(",so operation [%s] denied", operation)
return message
}

func (p ProcessCtrlStateMachineEngine) putBusinesskeyToContextariables(instance statelang.StateMachineInstance, variables map[string]any) {
if instance.BusinessKey() != "" && variables[constant.VarNameBusinesskey] == "" {
variables[constant.VarNameBusinesskey] = instance.BusinessKey()
}
}

func (p ProcessCtrlStateMachineEngine) nullSafeCopy(srcMap map[string]any, destMap map[string]any) {
for key, value := range srcMap {
if value == nil {
destMap[key] = value
}
}
}

func (p ProcessCtrlStateMachineEngine) findOutLastNeedForwardStateInstance(ctx context.Context, processContext process_ctrl.ProcessContext) statelang.StateInstance {
stateMachineInstance := processContext.GetVariable(constant.VarNameStateMachineInst).(statelang.StateMachineInstance)
lastForwardState := processContext.GetVariable(constant.VarNameStateInst).(statelang.StateInstance)

actList := stateMachineInstance.StateList()
for i := len(actList) - 1; i >= 0; i-- {
stateInstance := actList[i]
if pcext.GetOriginStateName(stateInstance) == pcext.GetOriginStateName(lastForwardState) && stateInstance.Status() != statelang.SU {
return stateInstance
}
}
return lastForwardState
}

+ 83
- 0
pkg/saga/statemachine/engine/exception/exception.go View File

@@ -0,0 +1,83 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package exception

import (
perror "errors"
"fmt"
"github.com/seata/seata-go/pkg/util/errors"
)

type EngineExecutionException struct {
errors.SeataError
stateName string
stateMachineName string
stateMachineInstanceId string
stateInstanceId string
ErrCode string
}

func (e *EngineExecutionException) Error() string {
return fmt.Sprintf("EngineExecutionException: %s", e.ErrCode)
}

func NewEngineExecutionException(code errors.TransactionErrorCode, msg string, parent error) *EngineExecutionException {
seataError := errors.New(code, msg, parent)
return &EngineExecutionException{
SeataError: *seataError,
}
}
func IsEngineExecutionException(err error) (*EngineExecutionException, bool) {
var fie *EngineExecutionException
if perror.As(err, &fie) {
return fie, true
}
return nil, false
}

func (e *EngineExecutionException) StateName() string {
return e.stateName
}

func (e *EngineExecutionException) SetStateName(stateName string) {
e.stateName = stateName
}

func (e *EngineExecutionException) StateMachineName() string {
return e.stateMachineName
}

func (e *EngineExecutionException) SetStateMachineName(stateMachineName string) {
e.stateMachineName = stateMachineName
}

func (e *EngineExecutionException) StateMachineInstanceId() string {
return e.stateMachineInstanceId
}

func (e *EngineExecutionException) SetStateMachineInstanceId(stateMachineInstanceId string) {
e.stateMachineInstanceId = stateMachineInstanceId
}

func (e *EngineExecutionException) StateInstanceId() string {
return e.stateInstanceId
}

func (e *EngineExecutionException) SetStateInstanceId(stateInstanceId string) {
e.stateInstanceId = stateInstanceId
}

+ 68
- 0
pkg/saga/statemachine/engine/exception/exception_test.go View File

@@ -0,0 +1,68 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package exception

import (
"errors"
"testing"

pkgerr "github.com/seata/seata-go/pkg/util/errors"
)

func TestIsEngineExecutionException(t *testing.T) {
cases := []struct {
name string
err error
wantOk bool
wantMsg string
}{
{
name: "EngineExecutionException",
err: &EngineExecutionException{SeataError: pkgerr.SeataError{Message: "engine error"}},
wantOk: true,
wantMsg: "engine error",
},
{
name: "Other error",
err: errors.New("some other error"),
wantOk: false,
wantMsg: "",
},
{
name: "nil error",
err: nil,
wantOk: false,
wantMsg: "",
},
}

for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
fie, ok := IsEngineExecutionException(c.err)
if ok != c.wantOk {
t.Errorf("expected ok=%v, got %v", c.wantOk, ok)
}
if ok && fie.SeataError.Message != c.wantMsg {
t.Errorf("expected Message=%q, got %q", c.wantMsg, fie.SeataError.Message)
}
if !ok && fie != nil {
t.Errorf("expected fie=nil, got %v", fie)
}
})
}
}

+ 32
- 0
pkg/saga/statemachine/engine/exception/forward_invalid_exception.go View File

@@ -0,0 +1,32 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package exception

import "errors"

type ForwardInvalidException struct {
EngineExecutionException
}

func IsForwardInvalidException(err error) (*ForwardInvalidException, bool) {
var fie *ForwardInvalidException
if errors.As(err, &fie) {
return fie, true
}
return nil, false
}

+ 68
- 0
pkg/saga/statemachine/engine/exception/forward_invalid_exception_test.go View File

@@ -0,0 +1,68 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package exception

import (
"errors"
"testing"

pkgerr "github.com/seata/seata-go/pkg/util/errors"
)

func TestIsForwardInvalidException(t *testing.T) {
cases := []struct {
name string
err error
wantOk bool
wantMsg string
}{
{
name: "ForwardInvalidException",
err: &ForwardInvalidException{EngineExecutionException: EngineExecutionException{SeataError: pkgerr.SeataError{Message: "forward invalid"}}},
wantOk: true,
wantMsg: "forward invalid",
},
{
name: "Other error",
err: errors.New("some other error"),
wantOk: false,
wantMsg: "",
},
{
name: "nil error",
err: nil,
wantOk: false,
wantMsg: "",
},
}

for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
fie, ok := IsForwardInvalidException(c.err)
if ok != c.wantOk {
t.Errorf("expected ok=%v, got %v", c.wantOk, ok)
}
if ok && fie.SeataError.Message != c.wantMsg {
t.Errorf("expected Message=%q, got %q", c.wantMsg, fie.SeataError.Message)
}
if !ok && fie != nil {
t.Errorf("expected fie=nil, got %v", fie)
}
})
}
}

+ 96
- 0
pkg/saga/statemachine/engine/expr/cel_expression.go View File

@@ -0,0 +1,96 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package expr

import (
"github.com/google/cel-go/cel"
)

type CELExpression struct {
env *cel.Env
program cel.Program
expression string
}

// This is used to make sure that the CELExpression implements the Expression interface.
var _ Expression = (*CELExpression)(nil)

// NewCELExpression creates a new CELExpression instance
// by compiling the provided expression.
// how to use cel: https://codelabs.developers.google.com/codelabs/cel-go
func NewCELExpression(expression string) (*CELExpression, error) {
// Create the standard environment.
env, err := cel.NewEnv(
cel.Variable(
"elContext", cel.DynType,
),
)

if err != nil {
return nil, err
}

// Check that the expression compiles and returns a String.
ast, issues := env.Compile(expression)
// Report syntax errors, if present.
if issues != nil && issues.Err() != nil {
return nil, issues.Err()
}

// Type-check the expression ofr correctness.
checkedAst, issues := env.Check(ast)
if issues.Err() != nil {
return nil, issues.Err()
}

program, err := env.Program(checkedAst)
if err != nil {
return nil, err
}

CELExpression := &CELExpression{
env: env,
program: program,
expression: expression,
}

return CELExpression, nil
}

// Value evaluates the expression with the provided context and returns the result.
func (c *CELExpression) Value(elContext any) any {
result, _, err := c.program.Eval(map[string]any{
"elContext": elContext,
})
if err != nil {
return err
}
return result.Value()
}

// TODO: I think this is not needed.
// I see seata-java doesn't use this method.
// Do we need to implement this?
func (c *CELExpression) SetValue(val any, elContext any) {
panic("implement me")
}

// ExpressionString returns the expression string.
func (c *CELExpression) ExpressionString() string {
return c.expression
}

+ 36
- 0
pkg/saga/statemachine/engine/expr/cel_expression_factory.go View File

@@ -0,0 +1,36 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package expr

type CELExpressionFactory struct {
expr *CELExpression
}

// This is used to make sure that CELExpressionFactory implements ExpressionFactory
var _ ExpressionFactory = (*CELExpressionFactory)(nil)

// NewCELExpressionFactory creates a new instance of CELExpressionFactory
func NewCELExpressionFactory() *CELExpressionFactory {
return &CELExpressionFactory{}
}

// CreateExpression creates a new instance of CELExpression
func (f *CELExpressionFactory) CreateExpression(expression string) Expression {
f.expr, _ = NewCELExpression(expression)
return f.expr
}

+ 31
- 0
pkg/saga/statemachine/engine/expr/cel_expression_factory_test.go View File

@@ -0,0 +1,31 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package expr

import (
"testing"

"github.com/stretchr/testify/assert"
)

func TestCelExpressionFactory(t *testing.T) {
factory := NewCELExpressionFactory()
expression := factory.CreateExpression("'Hello' + ' World!'")
value := expression.Value(nil)
assert.Equal(t, "Hello World!", value, "Expected 'Hello World!'")
}

+ 42
- 0
pkg/saga/statemachine/engine/expr/cel_expression_test.go View File

@@ -0,0 +1,42 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package expr

import (
"testing"

"github.com/stretchr/testify/assert"
)

func TestValueWithNil(t *testing.T) {
expr, err := NewCELExpression("'Hello' + ' World!'")
assert.NoError(t, err, "Error creating expression")
value := expr.Value(nil)
assert.Equal(t, "Hello World!", value, "Expected 'Hello World!'")
}

func TestValue(t *testing.T) {
expr, err := NewCELExpression("elContext['name'] + ' World!'")
assert.NoError(t, err, "Error creating expression")
elContext := map[string]any{
"name": "Hello",
}

value := expr.Value(elContext)
assert.Equal(t, "Hello World!", value, "Expected 'Hello World!'")
}

+ 21
- 0
pkg/saga/statemachine/engine/expr/error_expression.go View File

@@ -0,0 +1,21 @@
package expr

// ErrorExpression is a placeholder implementation that always reports an error.
// When parsing/constructing an expression fails, this type is returned directly.
// The Value() method returns the error as-is, and SetValue() is a no-op.
type ErrorExpression struct {
err error
expressionStr string
}

func (e *ErrorExpression) Value(elContext any) any {
return e.err
}

func (e *ErrorExpression) SetValue(value any, elContext any) {
// No write operation for error expressions
}

func (e *ErrorExpression) ExpressionString() string {
return e.expressionStr
}

+ 30
- 0
pkg/saga/statemachine/engine/expr/expression.go View File

@@ -0,0 +1,30 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package expr

// expression interface
type Expression interface {
// get the value of the expression
// elContext is the el context
Value(elContext any) any

SetValue(value any, elContext any)

// return the expression string
ExpressionString() string
}

+ 22
- 0
pkg/saga/statemachine/engine/expr/expression_factory.go View File

@@ -0,0 +1,22 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package expr

type ExpressionFactory interface {
CreateExpression(expression string) Expression
}

+ 50
- 0
pkg/saga/statemachine/engine/expr/expression_factory_manager.go View File

@@ -0,0 +1,50 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package expr

import (
"maps"
"strings"
)

const DefaultExpressionType = "Default"

type ExpressionFactoryManager struct {
expressionFactoryMap map[string]ExpressionFactory
}

func NewExpressionFactoryManager() *ExpressionFactoryManager {
return &ExpressionFactoryManager{
expressionFactoryMap: make(map[string]ExpressionFactory),
}
}

func (e *ExpressionFactoryManager) GetExpressionFactory(expressionType string) ExpressionFactory {
if strings.TrimSpace(expressionType) == "" {
expressionType = DefaultExpressionType
}
return e.expressionFactoryMap[expressionType]
}

func (e *ExpressionFactoryManager) SetExpressionFactoryMap(expressionFactoryMap map[string]ExpressionFactory) {
maps.Copy(e.expressionFactoryMap, expressionFactoryMap)
}

func (e *ExpressionFactoryManager) PutExpressionFactory(expressionType string, factory ExpressionFactory) {
e.expressionFactoryMap[expressionType] = factory
}

+ 105
- 0
pkg/saga/statemachine/engine/expr/expression_resolver.go View File

@@ -0,0 +1,105 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package expr

import (
"errors"
"strings"
)

type ExpressionResolver interface {
Expression(expressionStr string) Expression
ExpressionFactoryManager() ExpressionFactoryManager
SetExpressionFactoryManager(expressionFactoryManager ExpressionFactoryManager)
}

type DefaultExpressionResolver struct {
expressionFactoryManager ExpressionFactoryManager
}

func (resolver *DefaultExpressionResolver) Expression(expressionStr string) Expression {
expressionStruct, err := parseExpressionStruct(expressionStr)
if err != nil {
return nil
}
expressionFactory := resolver.expressionFactoryManager.GetExpressionFactory(expressionStruct.typ)
if expressionFactory == nil {
return nil
}
return expressionFactory.CreateExpression(expressionStruct.content)
}

func (resolver *DefaultExpressionResolver) ExpressionFactoryManager() ExpressionFactoryManager {
return resolver.expressionFactoryManager
}

func (resolver *DefaultExpressionResolver) SetExpressionFactoryManager(expressionFactoryManager ExpressionFactoryManager) {
resolver.expressionFactoryManager = expressionFactoryManager
}

type ExpressionStruct struct {
typeStart int
typeEnd int
end int
typ string
content string
}

// old style: $type{content}
// new style: $type.content
func parseExpressionStruct(expressionStr string) (*ExpressionStruct, error) {
eStruct := &ExpressionStruct{}
eStruct.typeStart = strings.Index(expressionStr, "$")
if eStruct.typeStart == -1 {
return nil, errors.New("invalid expression")
}

dot := strings.Index(expressionStr, ".")
leftBracket := strings.Index(expressionStr, "{")

isOldEvaluatorStyle := false
if eStruct.typeStart == 0 {
if leftBracket < 0 && dot < 0 {
return nil, errors.New("invalid expression")
}
// Backward compatible for structure: $expressionType{expressionContent}
if leftBracket > 0 && (leftBracket < dot || dot < 0) {
eStruct.typeEnd = leftBracket
isOldEvaluatorStyle = true
}
if dot > 0 && (dot < leftBracket || leftBracket < 0) {
eStruct.typeEnd = dot
}
}

if eStruct.typeStart == 0 && leftBracket != -1 && leftBracket < dot {
// Backward compatible for structure: $expressionType{expressionContent}
eStruct.typeEnd = strings.Index(expressionStr, "{")
isOldEvaluatorStyle = true
}

eStruct.typ = expressionStr[eStruct.typeStart+1 : eStruct.typeEnd]

if isOldEvaluatorStyle {
eStruct.end = strings.Index(expressionStr, "}")
} else {
eStruct.end = len(expressionStr)
}
eStruct.content = expressionStr[eStruct.typeEnd+1 : eStruct.end]
return eStruct, nil
}

+ 53
- 0
pkg/saga/statemachine/engine/expr/expression_resolver_test.go View File

@@ -0,0 +1,53 @@
package expr

import (
"testing"

"github.com/stretchr/testify/assert"
)

func TestParseExpressionStruct(t *testing.T) {
tests := []struct {
expressionStr string
expected *ExpressionStruct
expectError bool
}{
{
expressionStr: "$type{content}",
expected: &ExpressionStruct{
typeStart: 0,
typeEnd: 5,
typ: "type",
end: 13,
content: "content",
},
expectError: false,
},
{
expressionStr: "$type.content",
expected: &ExpressionStruct{
typeStart: 0,
typeEnd: 5,
typ: "type",
end: 13,
content: "content",
},
expectError: false,
},
{
expressionStr: "invalid expression",
expected: nil,
expectError: true,
},
}
for _, test := range tests {
result, err := parseExpressionStruct(test.expressionStr)
if test.expectError {
assert.Error(t, err, "Expected an error for input '%s'", test.expressionStr)
} else {
assert.NoError(t, err, "Did not expect an error for input '%s'", test.expressionStr)
assert.NotNil(t, result, "Expected a non-nil result for input '%s'", test.expressionStr)
assert.Equal(t, *test.expected, *result, "Expected result %+v, got %+v for input '%s'", test.expected, result, test.expressionStr)
}
}
}

+ 64
- 0
pkg/saga/statemachine/engine/expr/sequence_expression.go View File

@@ -0,0 +1,64 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package expr

import (
"github.com/seata/seata-go/pkg/saga/statemachine/engine/sequence"
)

type SequenceExpression struct {
seqGenerator sequence.SeqGenerator
entity string
rule string
}

func (s *SequenceExpression) SeqGenerator() sequence.SeqGenerator {
return s.seqGenerator
}

func (s *SequenceExpression) SetSeqGenerator(seqGenerator sequence.SeqGenerator) {
s.seqGenerator = seqGenerator
}

func (s *SequenceExpression) Entity() string {
return s.entity
}

func (s *SequenceExpression) SetEntity(entity string) {
s.entity = entity
}

func (s *SequenceExpression) Rule() string {
return s.rule
}

func (s *SequenceExpression) SetRule(rule string) {
s.rule = rule
}

func (s SequenceExpression) Value(elContext any) any {
return s.seqGenerator.GenerateId(s.entity, s.rule)
}

func (s SequenceExpression) SetValue(value any, elContext any) {

}

func (s SequenceExpression) ExpressionString() string {
return s.entity + "|" + s.rule
}

+ 39
- 0
pkg/saga/statemachine/engine/expr/sequence_expression_factory.go View File

@@ -0,0 +1,39 @@
package expr

import (
"fmt"
"strings"

"github.com/seata/seata-go/pkg/saga/statemachine/engine/sequence"
)

// SequenceExpressionFactory implements the ExpressionFactory interface,
// designed to parse strings in the format "entity|rule" and create SequenceExpression instances.
// If the format is invalid, it returns an *ErrorExpression containing the parsing error.
type SequenceExpressionFactory struct {
seqGenerator sequence.SeqGenerator
}

func NewSequenceExpressionFactory(seqGenerator sequence.SeqGenerator) *SequenceExpressionFactory {
return &SequenceExpressionFactory{seqGenerator: seqGenerator}
}

// CreateExpression parses the input string into a SequenceExpression.
// The input must be in the format "entity|rule". If the format is invalid,
// it returns an ErrorExpression with a descriptive error message.
func (f *SequenceExpressionFactory) CreateExpression(expression string) Expression {
parts := strings.Split(expression, "|")
if len(parts) != 2 {
return &ErrorExpression{
err: fmt.Errorf("invalid sequence expression format: %s, expected 'entity|rule'", expression),
expressionStr: expression,
}
}

seqExpr := &SequenceExpression{
seqGenerator: f.seqGenerator,
entity: strings.TrimSpace(parts[0]),
rule: strings.TrimSpace(parts[1]),
}
return seqExpr
}

+ 225
- 0
pkg/saga/statemachine/engine/invoker/func_invoker.go View File

@@ -0,0 +1,225 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package invoker

import (
"context"
"errors"
"fmt"
"reflect"
"strings"
"sync"
"time"

"github.com/seata/seata-go/pkg/saga/statemachine/statelang/state"
"github.com/seata/seata-go/pkg/util/log"
)

type FuncInvoker struct {
ServicesMapLock sync.Mutex
servicesMap map[string]FuncService
}

func NewFuncInvoker() *FuncInvoker {
return &FuncInvoker{
servicesMap: make(map[string]FuncService),
}
}

func (f *FuncInvoker) RegisterService(serviceName string, service FuncService) {
f.ServicesMapLock.Lock()
defer f.ServicesMapLock.Unlock()
f.servicesMap[serviceName] = service
}

func (f *FuncInvoker) GetService(serviceName string) FuncService {
f.ServicesMapLock.Lock()
defer f.ServicesMapLock.Unlock()
return f.servicesMap[serviceName]
}

func (f *FuncInvoker) Invoke(ctx context.Context, input []any, service state.ServiceTaskState) (output []reflect.Value, err error) {
serviceTaskStateImpl := service.(*state.ServiceTaskStateImpl)
FuncService := f.GetService(serviceTaskStateImpl.ServiceName())
if FuncService == nil {
return nil, errors.New("no func service " + serviceTaskStateImpl.ServiceName() + " for service task state")
}

if serviceTaskStateImpl.IsAsync() {
go func() {
_, err := FuncService.CallMethod(serviceTaskStateImpl, input)
if err != nil {
log.Errorf("invoke Service[%s].%s failed, err is %s", serviceTaskStateImpl.ServiceName(), serviceTaskStateImpl.ServiceMethod(), err.Error())
}
}()
return nil, nil
}

return FuncService.CallMethod(serviceTaskStateImpl, input)
}

func (f *FuncInvoker) Close(ctx context.Context) error {
return nil
}

type FuncService interface {
CallMethod(ServiceTaskStateImpl *state.ServiceTaskStateImpl, input []any) ([]reflect.Value, error)
}

type FuncServiceImpl struct {
serviceName string
methodLock sync.Mutex
method any
}

func NewFuncService(serviceName string, method any) *FuncServiceImpl {
return &FuncServiceImpl{
serviceName: serviceName,
method: method,
}
}

func (f *FuncServiceImpl) getMethod(serviceTaskStateImpl *state.ServiceTaskStateImpl) (*reflect.Value, error) {
method := serviceTaskStateImpl.Method()
if method == nil {
return f.initMethod(serviceTaskStateImpl)
}
return method, nil
}

func (f *FuncServiceImpl) prepareArguments(input []any) []reflect.Value {
args := make([]reflect.Value, len(input))
for i, arg := range input {
args[i] = reflect.ValueOf(arg)
}
return args
}

func (f *FuncServiceImpl) CallMethod(serviceTaskStateImpl *state.ServiceTaskStateImpl, input []any) ([]reflect.Value, error) {
method, err := f.getMethod(serviceTaskStateImpl)
if err != nil {
return nil, err
}

args := f.prepareArguments(input)

retryCountMap := make(map[state.Retry]int)
for {
res, err, shouldRetry := f.invokeMethod(method, args, serviceTaskStateImpl, retryCountMap)

if !shouldRetry {
if err != nil {
return nil, errors.New("invoke service[" + serviceTaskStateImpl.ServiceName() + "]." + serviceTaskStateImpl.ServiceMethod() + " failed, err is " + err.Error())
}
return res, nil
}
}
}

func (f *FuncServiceImpl) initMethod(serviceTaskStateImpl *state.ServiceTaskStateImpl) (*reflect.Value, error) {
methodName := serviceTaskStateImpl.ServiceMethod()
f.methodLock.Lock()
defer f.methodLock.Unlock()
methodValue := reflect.ValueOf(f.method)
if methodValue.IsZero() {
return nil, errors.New("invalid method when func call, serviceName: " + f.serviceName)
}

if methodValue.Kind() == reflect.Func {
serviceTaskStateImpl.SetMethod(&methodValue)
return &methodValue, nil
}

method := methodValue.MethodByName(methodName)
if method.IsZero() {
return nil, errors.New("invalid method name when func call, serviceName: " + f.serviceName + ", methodName: " + methodName)
}
serviceTaskStateImpl.SetMethod(&method)
return &method, nil
}

func (f *FuncServiceImpl) invokeMethod(method *reflect.Value, args []reflect.Value, serviceTaskStateImpl *state.ServiceTaskStateImpl, retryCountMap map[state.Retry]int) ([]reflect.Value, error, bool) {
var res []reflect.Value
var resErr error
var shouldRetry bool

defer func() {
if r := recover(); r != nil {
errStr := fmt.Sprintf("%v", r)
retry := f.matchRetry(serviceTaskStateImpl, errStr)
resErr = errors.New(errStr)
if retry != nil {
shouldRetry = f.needRetry(serviceTaskStateImpl, retryCountMap, retry, resErr)
}
}
}()

outs := method.Call(args)
if err, ok := outs[len(outs)-1].Interface().(error); ok {
resErr = err
errStr := err.Error()
retry := f.matchRetry(serviceTaskStateImpl, errStr)
if retry != nil {
shouldRetry = f.needRetry(serviceTaskStateImpl, retryCountMap, retry, resErr)
}
return nil, resErr, shouldRetry
}

res = outs
return res, nil, false
}

func (f *FuncServiceImpl) matchRetry(impl *state.ServiceTaskStateImpl, str string) state.Retry {
if impl.Retry() != nil {
for _, retry := range impl.Retry() {
if retry.Exceptions() != nil {
for _, exception := range retry.Exceptions() {
if strings.Contains(str, exception) {
return retry
}
}
}
}
}
return nil
}

func (f *FuncServiceImpl) needRetry(impl *state.ServiceTaskStateImpl, countMap map[state.Retry]int, retry state.Retry, err error) bool {
attempt, exist := countMap[retry]
if !exist {
countMap[retry] = 0
}

if attempt >= retry.MaxAttempt() {
return false
}

interval := retry.IntervalSecond()
backoffRate := retry.BackoffRate()
curInterval := int64(interval * 1000)
if attempt != 0 {
curInterval = int64(interval * backoffRate * float64(attempt) * 1000)
}

log.Warnf("invoke service[%s.%s] failed, will retry after %s millis, current retry count: %s, current err: %s",
impl.ServiceName(), impl.ServiceMethod(), curInterval, attempt, err)

time.Sleep(time.Duration(curInterval) * time.Millisecond)
countMap[retry] = attempt + 1
return true
}

+ 168
- 0
pkg/saga/statemachine/engine/invoker/func_invoker_test.go View File

@@ -0,0 +1,168 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package invoker

import (
"context"
"errors"
"fmt"
"testing"

"github.com/seata/seata-go/pkg/saga/statemachine/statelang/state"
)

// struct's method test
type mockFuncImpl struct {
invokeCount int
}

func (m *mockFuncImpl) SayHelloRight(word string) (string, error) {
m.invokeCount++
fmt.Println("invoke right")
return word, nil
}

func (m *mockFuncImpl) SayHelloRightLater(word string, delay int) (string, error) {
m.invokeCount++
if delay == m.invokeCount {
fmt.Println("invoke right")
return word, nil
}
fmt.Println("invoke fail")
return "", errors.New("invoke failed")
}

func TestFuncInvokerInvokeSucceed(t *testing.T) {
tests := []struct {
name string
input []any
taskState state.ServiceTaskState
expected string
expectErr bool
}{
{
name: "Invoke Struct Succeed",
input: []any{"hello"},
taskState: newFuncHelloServiceTaskState(),
expected: "hello",
expectErr: false,
},
{
name: "Invoke Struct In Retry",
input: []any{"hello", 2},
taskState: newFuncHelloServiceTaskStateWithRetry(),
expected: "hello",
expectErr: false,
},
}

ctx := context.Background()
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
invoker := newFuncServiceInvoker()
values, err := invoker.Invoke(ctx, tt.input, tt.taskState)

if (err != nil) != tt.expectErr {
t.Errorf("expected error: %v, got: %v", tt.expectErr, err)
}

if values == nil || len(values) == 0 {
t.Fatal("no value in values")
}

if resultString, ok := values[0].Interface().(string); ok {
if resultString != tt.expected {
t.Errorf("expect %s, but got %s", tt.expected, resultString)
}
} else {
t.Errorf("expected string, but got %v", values[0].Interface())
}

if resultError, ok := values[1].Interface().(error); ok {
if resultError != nil {
t.Errorf("expect nil, but got %s", resultError)
}
}
})
}
}

func TestFuncInvokerInvokeFailed(t *testing.T) {
tests := []struct {
name string
input []any
taskState state.ServiceTaskState
expected string
expectErr bool
}{
{
name: "Invoke Struct Failed In Retry",
input: []any{"hello", 5},
taskState: newFuncHelloServiceTaskStateWithRetry(),
expected: "",
expectErr: true,
},
}

ctx := context.Background()
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
invoker := newFuncServiceInvoker()
_, err := invoker.Invoke(ctx, tt.input, tt.taskState)

if (err != nil) != tt.expectErr {
t.Errorf("expected error: %v, got: %v", tt.expectErr, err)
}
})
}
}

func newFuncServiceInvoker() ServiceInvoker {
mockFuncInvoker := NewFuncInvoker()
mockFuncService := &mockFuncImpl{}
mockService := NewFuncService("hello", mockFuncService)
mockFuncInvoker.RegisterService("hello", mockService)
return mockFuncInvoker
}

func newFuncHelloServiceTaskState() state.ServiceTaskState {
serviceTaskStateImpl := state.NewServiceTaskStateImpl()
serviceTaskStateImpl.SetName("hello")
serviceTaskStateImpl.SetIsAsync(false)
serviceTaskStateImpl.SetServiceName("hello")
serviceTaskStateImpl.SetServiceType("func")
serviceTaskStateImpl.SetServiceMethod("SayHelloRight")
return serviceTaskStateImpl
}

func newFuncHelloServiceTaskStateWithRetry() state.ServiceTaskState {
serviceTaskStateImpl := state.NewServiceTaskStateImpl()
serviceTaskStateImpl.SetName("hello")
serviceTaskStateImpl.SetIsAsync(false)
serviceTaskStateImpl.SetServiceName("hello")
serviceTaskStateImpl.SetServiceType("func")
serviceTaskStateImpl.SetServiceMethod("SayHelloRightLater")

retryImpl := &state.RetryImpl{}
retryImpl.SetExceptions([]string{"fail"})
retryImpl.SetIntervalSecond(1)
retryImpl.SetMaxAttempt(3)
retryImpl.SetBackoffRate(0.9)
serviceTaskStateImpl.SetRetry([]state.Retry{retryImpl})
return serviceTaskStateImpl
}

+ 261
- 0
pkg/saga/statemachine/engine/invoker/grpc_invoker.go View File

@@ -0,0 +1,261 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package invoker

import (
"context"
"errors"
"fmt"
"github.com/seata/seata-go/pkg/saga/statemachine/statelang/state"
"github.com/seata/seata-go/pkg/util/log"
"google.golang.org/grpc"
"reflect"
"strings"
"sync"
"time"
)

type GRPCInvoker struct {
clients map[string]GRPCClient
clientsMapLock sync.Mutex
needClose bool
}

func NewGRPCInvoker() *GRPCInvoker {
return &GRPCInvoker{
clients: make(map[string]GRPCClient),
}
}

func (g *GRPCInvoker) NeedClose() bool {
return g.needClose
}

func (g *GRPCInvoker) SetNeedClose(needClose bool) {
g.needClose = needClose
}

func (g *GRPCInvoker) RegisterClient(serviceName string, client GRPCClient) {
g.clientsMapLock.Lock()
defer g.clientsMapLock.Unlock()

g.clients[serviceName] = client
}

func (g *GRPCInvoker) GetClient(serviceName string) GRPCClient {
g.clientsMapLock.Lock()
defer g.clientsMapLock.Unlock()

if client, ok := g.clients[serviceName]; ok {
return client
}

return nil
}

func (g *GRPCInvoker) Invoke(ctx context.Context, input []any, service state.ServiceTaskState) (output []reflect.Value, err error) {
serviceTaskStateImpl := service.(*state.ServiceTaskStateImpl)
client := g.GetClient(serviceTaskStateImpl.ServiceName())
if client == nil {
return nil, errors.New(fmt.Sprintf("no grpc client %s for service task state", serviceTaskStateImpl.ServiceName()))
}

// context is the first arg in grpc client method
input = append([]any{ctx}, input...)
if serviceTaskStateImpl.IsAsync() {
go func() {
_, err := client.CallMethod(serviceTaskStateImpl, input)
if err != nil {
log.Errorf("invoke Service[%s].%s failed, err is %s", serviceTaskStateImpl.ServiceName(),
serviceTaskStateImpl.ServiceMethod(), err)
}
}()
return nil, nil
} else {
return client.CallMethod(serviceTaskStateImpl, input)
}
}

func (g *GRPCInvoker) Close(ctx context.Context) error {
if g.needClose {
for _, client := range g.clients {
err := client.CloseConnection()
if err != nil {
return err
}
}
}
return nil
}

type GRPCClient interface {
CallMethod(serviceTaskStateImpl *state.ServiceTaskStateImpl, input []any) ([]reflect.Value, error)

CloseConnection() error
}

type GPRCClientImpl struct {
serviceName string
client any
connection *grpc.ClientConn
methodLock sync.Mutex
}

func NewGRPCClient(serviceName string, client any, connection *grpc.ClientConn) *GPRCClientImpl {
return &GPRCClientImpl{
serviceName: serviceName,
client: client,
connection: connection,
}
}

func (g *GPRCClientImpl) CallMethod(serviceTaskStateImpl *state.ServiceTaskStateImpl, input []any) ([]reflect.Value, error) {

if serviceTaskStateImpl.Method() == nil {
err := g.initMethod(serviceTaskStateImpl)
if err != nil {
return nil, err
}
}
method := serviceTaskStateImpl.Method()

args := make([]reflect.Value, 0, len(input))
for _, arg := range input {
args = append(args, reflect.ValueOf(arg))
}

retryCountMap := make(map[state.Retry]int)
for {
res, err, shouldRetry := func() (res []reflect.Value, resErr error, shouldRetry bool) {
defer func() {
// err may happen in the method invoke (panic) and method return, we try to find err and use it to decide retry by
// whether contains exception or not
if r := recover(); r != nil {
errStr := fmt.Sprintf("%v", r)
retry := g.matchRetry(serviceTaskStateImpl, errStr)
res = nil
resErr = errors.New(errStr)

if retry == nil {
return
}
shouldRetry = g.needRetry(serviceTaskStateImpl, retryCountMap, retry, resErr)
return
}
}()

outs := method.Call(args)
// err is the last arg in grpc client method
if err, ok := outs[len(outs)-1].Interface().(error); ok {
errStr := err.Error()
retry := g.matchRetry(serviceTaskStateImpl, errStr)
res = nil
resErr = err

if retry == nil {
return
}
shouldRetry = g.needRetry(serviceTaskStateImpl, retryCountMap, retry, resErr)
return
}

// invoke success
res = outs
resErr = nil
shouldRetry = false
return
}()

if !shouldRetry {
if err != nil {
return nil, errors.New(fmt.Sprintf("invoke Service[%s] failed, not satisfy retry config, the last err is %s",
serviceTaskStateImpl.ServiceName(), err))
}
return res, nil
}
}
}

func (g *GPRCClientImpl) CloseConnection() error {
if g.connection != nil {
err := g.connection.Close()
if err != nil {
return err
}
}
return nil
}

func (g *GPRCClientImpl) initMethod(serviceTaskStateImpl *state.ServiceTaskStateImpl) error {
methodName := serviceTaskStateImpl.ServiceMethod()
g.methodLock.Lock()
defer g.methodLock.Unlock()
clientValue := reflect.ValueOf(g.client)
if clientValue.IsZero() {
return errors.New(fmt.Sprintf("invalid client value when grpc client call, serviceName: %s", g.serviceName))
}
method := clientValue.MethodByName(methodName)
if method.IsZero() {
return errors.New(fmt.Sprintf("invalid client method when grpc client call, serviceName: %s, serviceMethod: %s",
g.serviceName, methodName))
}
serviceTaskStateImpl.SetMethod(&method)
return nil
}

func (g *GPRCClientImpl) matchRetry(impl *state.ServiceTaskStateImpl, str string) state.Retry {
if impl.Retry() != nil {
for _, retry := range impl.Retry() {
if retry.Exceptions() != nil {
for _, exception := range retry.Exceptions() {
if strings.Contains(str, exception) {
return retry
}
}
}
}
}
return nil
}

func (g *GPRCClientImpl) needRetry(impl *state.ServiceTaskStateImpl, countMap map[state.Retry]int, retry state.Retry, err error) bool {
attempt, exist := countMap[retry]
if !exist {
countMap[retry] = 0
}

if attempt >= retry.MaxAttempt() {
return false
}

intervalSecond := retry.IntervalSecond()
backoffRate := retry.BackoffRate()
var currentInterval int64
if attempt == 0 {
currentInterval = int64(intervalSecond * 1000)
} else {
currentInterval = int64(intervalSecond * backoffRate * float64(attempt) * 1000)
}

log.Warnf("invoke service[%s.%s] failed, will retry after %s millis, current retry count: %s, current err: %s",
impl.ServiceName(), impl.ServiceMethod(), currentInterval, attempt, err)

time.Sleep(time.Duration(currentInterval) * time.Millisecond)
countMap[retry] = attempt + 1
return true
}

+ 185
- 0
pkg/saga/statemachine/engine/invoker/grpc_invoker_test.go View File

@@ -0,0 +1,185 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package invoker

import (
"context"
"errors"
"fmt"
"github.com/seata/seata-go/pkg/saga/statemachine/statelang/state"
"github.com/stretchr/testify/assert"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"testing"
"time"

pb "github.com/seata/seata-go/testdata/saga/engine/invoker/grpc"
)

type MockGPRCClientImpl struct {
GPRCClientImpl
}

type mockClientImpl struct {
invokeCount int
}

func (m *mockClientImpl) SayHelloRight(ctx context.Context, word string) (string, error) {
m.invokeCount++
fmt.Println("invoke right")
return word, nil
}

func (m *mockClientImpl) SayHelloRightLater(ctx context.Context, word string, delay int) (string, error) {
m.invokeCount++
if delay == m.invokeCount {
fmt.Println("invoke right")
return word, nil
}
fmt.Println("invoke fail")
return "", errors.New("invoke failed")
}

func TestGRPCInvokerInvokeSucceedWithOutRetry(t *testing.T) {
ctx := context.Background()
invoker := newGRPCServiceInvoker()
values, err := invoker.Invoke(ctx, []any{"hello"}, newHelloServiceTaskState())
if err != nil {
t.Error(err)
return
}
if values == nil || len(values) == 0 {
t.Error("no value in values")
return
}
if values[0].Interface().(string) != "hello" {
t.Errorf("expect hello, but got %v", values[0].Interface())
}
if _, ok := values[1].Interface().(error); ok {
t.Errorf("expect nil, but got %v", values[1].Interface())
}
}

func TestGRPCInvokerInvokeSucceedInRetry(t *testing.T) {
ctx := context.Background()
invoker := newGRPCServiceInvoker()
values, err := invoker.Invoke(ctx, []any{"hello", 2}, newHelloServiceTaskStateWithRetry())
if err != nil {
t.Error(err)
return
}
if values == nil || len(values) == 0 {
t.Error("no value in values")
return
}
if values[0].Interface().(string) != "hello" {
t.Errorf("expect hello, but got %v", values[0].Interface())
}
if _, ok := values[1].Interface().(error); ok {
t.Errorf("expect nil, but got %v", values[1].Interface())
}
}

func TestGRPCInvokerInvokeFailedInRetry(t *testing.T) {
ctx := context.Background()
invoker := newGRPCServiceInvoker()
_, err := invoker.Invoke(ctx, []any{"hello", 5}, newHelloServiceTaskStateWithRetry())
if err != nil {
assert.Error(t, err)
}
}

func TestGRPCInvokerInvokeE2E(t *testing.T) {
go func() {
pb.StartProductServer()
}()
time.Sleep(3000 * time.Millisecond)
conn, err := grpc.Dial("localhost:8080", grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
t.Fatalf("did not connect: %v", err)
}
c := pb.NewProductInfoClient(conn)
grpcClient := NewGRPCClient("product", c, conn)

invoker := NewGRPCInvoker()
invoker.RegisterClient("product", grpcClient)
ctx := context.Background()
values, err := invoker.Invoke(ctx, []any{&pb.Product{Id: "123"}}, newProductServiceTaskState())
if err != nil {
t.Error(err)
return
}
t.Log(values)
err = invoker.Close(ctx)
if err != nil {
t.Error(err)
return
}
}

func newGRPCServiceInvoker() ServiceInvoker {
mockGRPCInvoker := NewGRPCInvoker()
mockGRPCClient := &mockClientImpl{}
mockClient := NewGRPCClient("hello", mockGRPCClient, &grpc.ClientConn{})
mockGRPCInvoker.RegisterClient("hello", mockClient)
return mockGRPCInvoker
}

func newProductServiceTaskState() state.ServiceTaskState {
serviceTaskStateImpl := state.NewServiceTaskStateImpl()
serviceTaskStateImpl.SetName("product")
serviceTaskStateImpl.SetIsAsync(false)
serviceTaskStateImpl.SetServiceName("product")
serviceTaskStateImpl.SetServiceType("GRPC")
serviceTaskStateImpl.SetServiceMethod("AddProduct")

retryImpl := &state.RetryImpl{}
retryImpl.SetExceptions([]string{"fail"})
retryImpl.SetIntervalSecond(1)
retryImpl.SetMaxAttempt(3)
retryImpl.SetBackoffRate(0.9)
serviceTaskStateImpl.SetRetry([]state.Retry{retryImpl})
return serviceTaskStateImpl
}

func newHelloServiceTaskState() state.ServiceTaskState {
serviceTaskStateImpl := state.NewServiceTaskStateImpl()
serviceTaskStateImpl.SetName("hello")
serviceTaskStateImpl.SetIsAsync(false)
serviceTaskStateImpl.SetServiceName("hello")
serviceTaskStateImpl.SetServiceType("GRPC")
serviceTaskStateImpl.SetServiceMethod("SayHelloRight")
return serviceTaskStateImpl
}

func newHelloServiceTaskStateWithRetry() state.ServiceTaskState {
serviceTaskStateImpl := state.NewServiceTaskStateImpl()
serviceTaskStateImpl.SetName("hello")
serviceTaskStateImpl.SetIsAsync(false)
serviceTaskStateImpl.SetServiceName("hello")
serviceTaskStateImpl.SetServiceType("GRPC")
serviceTaskStateImpl.SetServiceMethod("SayHelloRightLater")

retryImpl := &state.RetryImpl{}
retryImpl.SetExceptions([]string{"fail"})
retryImpl.SetIntervalSecond(1)
retryImpl.SetMaxAttempt(3)
retryImpl.SetBackoffRate(0.9)
serviceTaskStateImpl.SetRetry([]state.Retry{retryImpl})
return serviceTaskStateImpl
}

+ 220
- 0
pkg/saga/statemachine/engine/invoker/http_invoker.go View File

@@ -0,0 +1,220 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package invoker

import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"reflect"
"strings"
"sync"
"time"

"github.com/seata/seata-go/pkg/saga/statemachine/statelang/state"
"github.com/seata/seata-go/pkg/util/log"
)

const errHttpCode = 400

type HTTPInvoker struct {
clientsMapLock sync.Mutex
clients map[string]HTTPClient
}

func NewHTTPInvoker() *HTTPInvoker {
return &HTTPInvoker{
clients: make(map[string]HTTPClient),
}
}

func (h *HTTPInvoker) RegisterClient(serviceName string, client HTTPClient) {
h.clientsMapLock.Lock()
defer h.clientsMapLock.Unlock()
h.clients[serviceName] = client
}

func (h *HTTPInvoker) GetClient(serviceName string) HTTPClient {
h.clientsMapLock.Lock()
defer h.clientsMapLock.Unlock()
return h.clients[serviceName]
}

func (h *HTTPInvoker) Invoke(ctx context.Context, input []any, service state.ServiceTaskState) (output []reflect.Value, err error) {
serviceTaskStateImpl := service.(*state.ServiceTaskStateImpl)
client := h.GetClient(serviceTaskStateImpl.ServiceName())
if client == nil {
return nil, fmt.Errorf("no http client %s for service task state", serviceTaskStateImpl.ServiceName())
}

if serviceTaskStateImpl.IsAsync() {
go func() {
_, err := client.Call(ctx, serviceTaskStateImpl, input)
if err != nil {
log.Errorf("invoke Service[%s].%s failed, err is %s", serviceTaskStateImpl.ServiceName(),
serviceTaskStateImpl.ServiceMethod(), err)
}
}()
return nil, nil
}

return client.Call(ctx, serviceTaskStateImpl, input)
}

func (h *HTTPInvoker) Close(ctx context.Context) error {
return nil
}

type HTTPClient interface {
Call(ctx context.Context, serviceTaskStateImpl *state.ServiceTaskStateImpl, input []any) ([]reflect.Value, error)
}

type HTTPClientImpl struct {
serviceName string
baseURL string
client *http.Client
}

func NewHTTPClient(serviceName string, baseURL string, client *http.Client) *HTTPClientImpl {
if client == nil {
client = &http.Client{
Timeout: time.Second * 30,
}
}
return &HTTPClientImpl{
serviceName: serviceName,
baseURL: baseURL,
client: client,
}
}

func (h *HTTPClientImpl) Call(ctx context.Context, serviceTaskStateImpl *state.ServiceTaskStateImpl, input []any) ([]reflect.Value, error) {
retryCountMap := make(map[state.Retry]int)
for {
res, err, shouldRetry := func() (res []reflect.Value, resErr error, shouldRetry bool) {
defer func() {
if r := recover(); r != nil {
errStr := fmt.Sprintf("%v", r)
retry := h.matchRetry(serviceTaskStateImpl, errStr)
resErr = errors.New(errStr)
if retry != nil {
shouldRetry = h.needRetry(serviceTaskStateImpl, retryCountMap, retry, resErr)
}
}
}()

reqBody, err := json.Marshal(input)
if err != nil {
return nil, err, false
}

req, err := http.NewRequestWithContext(ctx,
serviceTaskStateImpl.ServiceMethod(),
h.baseURL+serviceTaskStateImpl.Name(),
bytes.NewBuffer(reqBody))
if err != nil {
return nil, err, false
}

req.Header.Set("Content-Type", "application/json")

resp, err := h.client.Do(req)
if err != nil {
retry := h.matchRetry(serviceTaskStateImpl, err.Error())
if retry != nil {
return nil, err, h.needRetry(serviceTaskStateImpl, retryCountMap, retry, err)
}
return nil, err, false
}
defer resp.Body.Close()

body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err, false
}

if resp.StatusCode >= errHttpCode {
errStr := fmt.Sprintf("HTTP error: %d - %s", resp.StatusCode, string(body))
retry := h.matchRetry(serviceTaskStateImpl, errStr)
if retry != nil {
return nil, errors.New(errStr), h.needRetry(serviceTaskStateImpl, retryCountMap, retry, err)
}
return nil, errors.New(errStr), false
}

return []reflect.Value{
reflect.ValueOf(string(body)),
reflect.Zero(reflect.TypeOf((*error)(nil)).Elem()),
}, nil, false
}()

if !shouldRetry {
if err != nil {
return nil, fmt.Errorf("invoke Service[%s] failed, not satisfy retry config, the last err is %s",
serviceTaskStateImpl.ServiceName(), err)
}
return res, nil
}
}
}

func (h *HTTPClientImpl) matchRetry(impl *state.ServiceTaskStateImpl, str string) state.Retry {
if impl.Retry() != nil {
for _, retry := range impl.Retry() {
if retry.Exceptions() != nil {
for _, exception := range retry.Exceptions() {
if strings.Contains(str, exception) {
return retry
}
}
}
}
}
return nil
}

func (h *HTTPClientImpl) needRetry(impl *state.ServiceTaskStateImpl, countMap map[state.Retry]int, retry state.Retry, err error) bool {
attempt, exist := countMap[retry]
if !exist {
countMap[retry] = 0
}

if attempt >= retry.MaxAttempt() {
return false
}

intervalSecond := retry.IntervalSecond()
backoffRate := retry.BackoffRate()
var currentInterval int64
if attempt == 0 {
currentInterval = int64(intervalSecond * 1000)
} else {
currentInterval = int64(intervalSecond * backoffRate * float64(attempt) * 1000)
}

log.Warnf("invoke service[%s.%s] failed, will retry after %s millis, current retry count: %s, current err: %s",
impl.ServiceName(), impl.ServiceMethod(), currentInterval, attempt, err)

time.Sleep(time.Duration(currentInterval) * time.Millisecond)
countMap[retry] = attempt + 1
return true
}

+ 176
- 0
pkg/saga/statemachine/engine/invoker/http_invoker_test.go View File

@@ -0,0 +1,176 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package invoker

import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"time"

"github.com/seata/seata-go/pkg/saga/statemachine/statelang/state"
"github.com/stretchr/testify/assert"
)

func TestHTTPInvokerInvokeSucceedWithOutRetry(t *testing.T) {
// create test server
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var input []interface{}
err := json.NewDecoder(r.Body).Decode(&input)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
w.WriteHeader(http.StatusOK)
w.Write([]byte(input[0].(string)))
}))
defer server.Close()

// create HTTP Invoker
invoker := NewHTTPInvoker()
client := NewHTTPClient("test", server.URL+"/", &http.Client{})
invoker.RegisterClient("test", client)

// invoke
ctx := context.Background()
values, err := invoker.Invoke(ctx, []any{"hello"}, newHTTPServiceTaskState())

// verify
assert.NoError(t, err)
assert.NotNil(t, values)
assert.Equal(t, "hello", values[0].Interface())
}

func TestHTTPInvokerInvokeWithRetry(t *testing.T) {
attemptCount := 0
// create test server
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
attemptCount++
if attemptCount < 2 {
w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte("fail"))
return
}
var input []interface{}
json.NewDecoder(r.Body).Decode(&input)
w.WriteHeader(http.StatusOK)
w.Write([]byte(input[0].(string)))
}))
defer server.Close()

// create HTTP Invoker
invoker := NewHTTPInvoker()
client := NewHTTPClient("test", server.URL+"/", &http.Client{})
invoker.RegisterClient("test", client)

// invoker
ctx := context.Background()
values, err := invoker.Invoke(ctx, []any{"hello"}, newHTTPServiceTaskStateWithRetry())

// verify
assert.NoError(t, err)
assert.NotNil(t, values)
assert.Equal(t, "hello", values[0].Interface())
assert.Equal(t, 2, attemptCount)
}

func TestHTTPInvokerInvokeFailedInRetry(t *testing.T) {
// create test server
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte("fail"))
}))
defer server.Close()

// create HTTP Invoker
invoker := NewHTTPInvoker()
client := NewHTTPClient("test", server.URL+"/", &http.Client{})
invoker.RegisterClient("test", client)

// invoker
ctx := context.Background()
_, err := invoker.Invoke(ctx, []any{"hello"}, newHTTPServiceTaskStateWithRetry())

// verify
assert.Error(t, err)
}

func TestHTTPInvokerAsyncInvoke(t *testing.T) {
called := false
// create test server
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
called = true
w.WriteHeader(http.StatusOK)
w.Write([]byte("success"))
}))
defer server.Close()

// create HTTP Invoker
invoker := NewHTTPInvoker()
client := NewHTTPClient("test", server.URL+"/", &http.Client{})
invoker.RegisterClient("test", client)

// async invoke
ctx := context.Background()
taskState := newHTTPServiceTaskStateWithAsync()
_, err := invoker.Invoke(ctx, []any{"hello"}, taskState)

// verify
assert.NoError(t, err)
time.Sleep(100 * time.Millisecond)
assert.True(t, called)
}

func newHTTPServiceTaskState() state.ServiceTaskState {
serviceTaskStateImpl := state.NewServiceTaskStateImpl()
serviceTaskStateImpl.SetName("test")
serviceTaskStateImpl.SetIsAsync(false)
serviceTaskStateImpl.SetServiceName("test")
serviceTaskStateImpl.SetServiceType("HTTP")
serviceTaskStateImpl.SetServiceMethod("POST")
return serviceTaskStateImpl
}

func newHTTPServiceTaskStateWithAsync() state.ServiceTaskState {
serviceTaskStateImpl := state.NewServiceTaskStateImpl()
serviceTaskStateImpl.SetName("test")
serviceTaskStateImpl.SetIsAsync(true)
serviceTaskStateImpl.SetServiceName("test")
serviceTaskStateImpl.SetServiceType("HTTP")
serviceTaskStateImpl.SetServiceMethod("POST")
return serviceTaskStateImpl
}

func newHTTPServiceTaskStateWithRetry() state.ServiceTaskState {
serviceTaskStateImpl := state.NewServiceTaskStateImpl()
serviceTaskStateImpl.SetName("test")
serviceTaskStateImpl.SetIsAsync(false)
serviceTaskStateImpl.SetServiceName("test")
serviceTaskStateImpl.SetServiceType("HTTP")
serviceTaskStateImpl.SetServiceMethod("POST")

retryImpl := &state.RetryImpl{}
retryImpl.SetExceptions([]string{"fail"})
retryImpl.SetIntervalSecond(1)
retryImpl.SetMaxAttempt(3)
retryImpl.SetBackoffRate(0.9)
serviceTaskStateImpl.SetRetry([]state.Retry{retryImpl})
return serviceTaskStateImpl
}

+ 127
- 0
pkg/saga/statemachine/engine/invoker/invoker.go View File

@@ -0,0 +1,127 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package invoker

import (
"context"
"encoding/json"
"reflect"
"sync"

"github.com/seata/seata-go/pkg/saga/statemachine/statelang/state"
)

type JsonParser interface {
Unmarshal(data []byte, v any) error
Marshal(v any) ([]byte, error)
}

type DefaultJsonParser struct{}

func (p *DefaultJsonParser) Unmarshal(data []byte, v any) error {
return json.Unmarshal(data, v)
}

func (p *DefaultJsonParser) Marshal(v any) ([]byte, error) {
return json.Marshal(v)
}

type ScriptInvokerManager interface {
GetInvoker(scriptType string) (ScriptInvoker, error)
RegisterInvoker(invoker ScriptInvoker)
Execute(ctx context.Context, scriptType string, script string, params map[string]interface{}) (interface{}, error)
}

type ScriptInvoker interface {
Invoke(ctx context.Context, script string, params map[string]interface{}) (interface{}, error)
Type() string
Close(ctx context.Context) error
}

type ScriptInvokerManagerImpl struct {
invokers map[string]ScriptInvoker
mutex sync.Mutex
}

func NewScriptInvokerManager() *ScriptInvokerManagerImpl {
return &ScriptInvokerManagerImpl{
invokers: make(map[string]ScriptInvoker),
}
}

func (m *ScriptInvokerManagerImpl) GetInvoker(scriptType string) (ScriptInvoker, error) {
if scriptType == "" {
return nil, nil
}
m.mutex.Lock()
defer m.mutex.Unlock()

invoker, exists := m.invokers[scriptType]
if !exists {
return nil, nil
}
return invoker, nil
}

func (m *ScriptInvokerManagerImpl) RegisterInvoker(invoker ScriptInvoker) {
if invoker == nil || invoker.Type() == "" {
return
}
m.mutex.Lock()
defer m.mutex.Unlock()
m.invokers[invoker.Type()] = invoker
}

func (m *ScriptInvokerManagerImpl) Execute(ctx context.Context, scriptType string, script string, params map[string]interface{}) (interface{}, error) {
invoker, err := m.GetInvoker(scriptType)
if err != nil || invoker == nil {
return nil, err
}
return invoker.Invoke(ctx, script, params)
}

type ServiceInvokerManager interface {
ServiceInvoker(serviceType string) ServiceInvoker
PutServiceInvoker(serviceType string, invoker ServiceInvoker)
}

type ServiceInvoker interface {
Invoke(ctx context.Context, input []any, service state.ServiceTaskState) (output []reflect.Value, err error)
Close(ctx context.Context) error
}

type ServiceInvokerManagerImpl struct {
invokers map[string]ServiceInvoker
mutex sync.Mutex
}

func NewServiceInvokerManagerImpl() *ServiceInvokerManagerImpl {
return &ServiceInvokerManagerImpl{
invokers: make(map[string]ServiceInvoker),
}
}

func (manager *ServiceInvokerManagerImpl) ServiceInvoker(serviceType string) ServiceInvoker {
return manager.invokers[serviceType]
}

func (manager *ServiceInvokerManagerImpl) PutServiceInvoker(serviceType string, invoker ServiceInvoker) {
manager.mutex.Lock()
defer manager.mutex.Unlock()
manager.invokers[serviceType] = invoker
}

+ 162
- 0
pkg/saga/statemachine/engine/invoker/javascript_script_invoker.go View File

@@ -0,0 +1,162 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package invoker

import (
"context"
"fmt"
"sync"

"github.com/robertkrimen/otto"
)

const defaultPoolSize = 10

type JavaScriptScriptInvoker struct {
mutex sync.Mutex
jsonParser JsonParser
closed bool
vmPool chan *otto.Otto
poolSize int
}

func NewJavaScriptScriptInvoker() *JavaScriptScriptInvoker {
return &JavaScriptScriptInvoker{
jsonParser: &DefaultJsonParser{},
closed: false,
poolSize: defaultPoolSize,
vmPool: make(chan *otto.Otto, defaultPoolSize),
}
}

func NewJavaScriptScriptInvokerWithPoolSize(poolSize int) *JavaScriptScriptInvoker {
if poolSize <= 0 {
poolSize = defaultPoolSize
}
return &JavaScriptScriptInvoker{
jsonParser: &DefaultJsonParser{},
closed: false,
poolSize: poolSize,
vmPool: make(chan *otto.Otto, poolSize),
}
}

func (j *JavaScriptScriptInvoker) Type() string {
return "javascript"
}

func (j *JavaScriptScriptInvoker) Invoke(ctx context.Context, script string, params map[string]interface{}) (interface{}, error) {
j.mutex.Lock()
closed := j.closed
j.mutex.Unlock()

if closed {
return nil, fmt.Errorf("javascript invoker has been closed")
}

var vm *otto.Otto
select {
case vm = <-j.vmPool:
if err := cleanVMState(vm); err != nil {
vm = otto.New()
}
default:
vm = otto.New()
}

defer func() {
j.mutex.Lock()
defer j.mutex.Unlock()
if !j.closed {
select {
case j.vmPool <- vm:
default:
// Pool full, discard current instance
}
}
}()

for key, value := range params {
if err := vm.Set(key, value); err != nil {
return nil, fmt.Errorf("javascript set param %s error: %w", key, err)
}
}

resultChan := make(chan struct {
val otto.Value
err error
}, 1)

go func() {
defer func() {
if r := recover(); r != nil {
resultChan <- struct {
val otto.Value
err error
}{otto.UndefinedValue(), fmt.Errorf("javascript engine panic: %v", r)}
}
}()

val, err := vm.Run(script)
resultChan <- struct {
val otto.Value
err error
}{val, err}
}()

select {
case <-ctx.Done():
return nil, fmt.Errorf("javascript execution timeout: %w", ctx.Err())
case res := <-resultChan:
if res.err != nil {
return nil, fmt.Errorf("javascript execute error: %w", res.err)
}
val, err := res.val.Export()
if err != nil {
return nil, fmt.Errorf("failed to export javascript result: %w", err)
}
return val, nil
}
}

func (j *JavaScriptScriptInvoker) Close(ctx context.Context) error {
j.mutex.Lock()
defer j.mutex.Unlock()

if j.closed {
return nil
}

j.closed = true
close(j.vmPool)
for range j.vmPool {
// Let GC recycle VM resources
}
return nil
}

func cleanVMState(vm *otto.Otto) error {
_, err := vm.Run(`
for (const prop in global) {
if (!['Object', 'Array', 'Function', 'String', 'Number', 'Boolean', 'JSON', 'Date', 'RegExp'].includes(prop)) {
delete global[prop];
}
}
`)
return err
}

+ 262
- 0
pkg/saga/statemachine/engine/invoker/javascript_script_invoker_test.go View File

@@ -0,0 +1,262 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package invoker

import (
"context"
"sync"
"testing"
"time"

"github.com/robertkrimen/otto"
"github.com/stretchr/testify/assert"
)

func TestJavaScriptScriptInvoker_Type(t *testing.T) {
invoker := NewJavaScriptScriptInvoker()
assert.Equal(t, "javascript", invoker.Type())
}

func TestJavaScriptScriptInvoker_Invoke_Basic(t *testing.T) {
tests := []struct {
name string
script string
params map[string]interface{}
expected interface{}
}{
{
name: "simple expression",
script: "1 + 2",
params: nil,
expected: float64(3),
},
{
name: "param calculation",
script: "a * b + c",
params: map[string]interface{}{"a": 2, "b": 3, "c": 4},
expected: float64(10),
},
{
name: "return string",
script: "['hello', name].join(' ')",
params: map[string]interface{}{"name": "world"},
expected: "hello world",
},
{
name: "return object",
script: `var obj = {id: 1, name: name}; obj;`,
params: map[string]interface{}{"name": "test"},
expected: map[string]interface{}{"id": float64(1), "name": "test"},
},
}

invoker := NewJavaScriptScriptInvoker()
ctx := context.Background()

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := invoker.Invoke(ctx, tt.script, tt.params)
assert.NoError(t, err)

if resultMap, ok := result.(map[string]interface{}); ok {
for k, v := range resultMap {
if intVal, isInt := v.(int64); isInt {
resultMap[k] = float64(intVal)
}
}
}

assert.Equal(t, tt.expected, result)
})
}
}

func TestJavaScriptScriptInvoker_Invoke_Error(t *testing.T) {
tests := []struct {
name string
script string
params map[string]interface{}
errMsg string
}{
{
name: "syntax error",
script: "1 + ",
params: nil,
errMsg: "javascript execute error",
},
{
name: "reference undefined variable",
script: "undefinedVar",
params: nil,
errMsg: "javascript execute error",
},
}

invoker := NewJavaScriptScriptInvoker()
ctx := context.Background()

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := invoker.Invoke(ctx, tt.script, tt.params)

if err == nil {
t.Fatalf("Test case [%s] expected error but got none", tt.name)
}
assert.Contains(t, err.Error(), tt.errMsg, "Test case [%s] error message mismatch", tt.name)
})
}
}

func TestJavaScriptScriptInvoker_Invoke_Timeout(t *testing.T) {

script := `var target = 300; var start = new Date().getTime(); var elapsed = 0; while (elapsed < target) { elapsed = new Date().getTime() - start; } "done";`
invoker := NewJavaScriptScriptInvoker()

ctx1, cancel1 := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel1()
_, err := invoker.Invoke(ctx1, script, nil)
assert.Error(t, err)
assert.Contains(t, err.Error(), "javascript execution timeout")

ctx2, cancel2 := context.WithTimeout(context.Background(), 400*time.Millisecond)
defer cancel2()
result, err := invoker.Invoke(ctx2, script, nil)
assert.NoError(t, err, "Scenario 2: script execution should not return error")
assert.Equal(t, "done", result, "Scenario 2: should return 'done'")
}

func TestJavaScriptScriptInvoker_Invoke_Concurrent(t *testing.T) {
invoker := NewJavaScriptScriptInvoker()
ctx := context.Background()
var wg sync.WaitGroup
concurrency := 100
errChan := make(chan error, concurrency)

script := `a + b`
params := map[string]interface{}{"a": 10, "b": 20}

for i := 0; i < concurrency; i++ {
wg.Add(1)
go func() {
defer wg.Done()
result, err := invoker.Invoke(ctx, script, params)
if err != nil {
errChan <- err
return
}
if result != float64(30) {
errChan <- assert.AnError
}
}()
}

wg.Wait()
close(errChan)

assert.Empty(t, errChan, "Concurrent execution has errors")
}

func TestJavaScriptScriptInvoker_Close(t *testing.T) {
invoker := NewJavaScriptScriptInvoker()
ctx := context.Background()

result, err := invoker.Invoke(ctx, "1 + 1", nil)
assert.NoError(t, err)
assert.Equal(t, float64(2), result)

err = invoker.Close(ctx)
assert.NoError(t, err)

_, err = invoker.Invoke(ctx, "1 + 1", nil)
assert.Error(t, err)
assert.Contains(t, err.Error(), "javascript invoker has been closed")
}

func TestOttoScript(t *testing.T) {
vm := otto.New()
script := `var target = 300; var start = new Date().getTime(); var elapsed = 0; while (elapsed < target) { elapsed = new Date().getTime() - start; } "done";`
val, err := vm.Run(script)
if err != nil {
t.Fatalf("otto failed to parse script: %v", err)
}

result, exportErr := val.Export()
if exportErr != nil {
t.Fatalf("failed to export otto value: %v", exportErr)
}
t.Logf("Script execution result: %v", result)
}

func TestJavaScriptScriptInvoker_VMPoolReuse(t *testing.T) {
poolSize := 2
invoker := NewJavaScriptScriptInvokerWithPoolSize(poolSize)
ctx := context.Background()

vmIDs := make([]string, 0, 5)

script := `
if (!this.vmId) {
this.vmId = Math.random().toString(36).substr(2, 8);
}
this.vmId;
`

for i := 0; i < 5; i++ {
result, err := invoker.Invoke(ctx, script, nil)
assert.NoError(t, err, "Error occurred while executing script")

id, ok := result.(string)
assert.True(t, ok, "VM ID should be a string type")
vmIDs = append(vmIDs, id)
}

uniqueIDs := make(map[string]bool)
for _, id := range vmIDs {
uniqueIDs[id] = true
}

assert.True(t, len(uniqueIDs) <= 5, "Abnormal number of VM instances created")
assert.True(t, len(uniqueIDs) >= 1, "No VM instances reused from the pool")
}

func TestJavaScriptScriptInvoker_VMStateClean(t *testing.T) {
invoker := NewJavaScriptScriptInvokerWithPoolSize(1)
ctx := context.Background()

_, err := invoker.Invoke(ctx, `this.foo = "polluted data"`, nil)
assert.NoError(t, err)

result, err := invoker.Invoke(ctx, `typeof this.foo`, nil)
assert.NoError(t, err)
assert.Equal(t, "undefined", result, "VM state not cleaned, residual global variable exists")

_, err = invoker.Invoke(ctx, `this.bar = function() { return "residual function"; }`, nil)
assert.NoError(t, err)

result, err = invoker.Invoke(ctx, `typeof this.bar`, nil)
assert.NoError(t, err)
assert.Equal(t, "undefined", result, "VM state not cleaned, residual function exists")
}

func TestJavaScriptScriptInvoker_PoolSizeDefault(t *testing.T) {
invoker := NewJavaScriptScriptInvokerWithPoolSize(0)
assert.Equal(t, defaultPoolSize, invoker.poolSize, "Default pool size not used when pool size is 0")

invoker = NewJavaScriptScriptInvokerWithPoolSize(-5)
assert.Equal(t, defaultPoolSize, invoker.poolSize, "Default pool size not used when pool size is negative")
}

+ 195
- 0
pkg/saga/statemachine/engine/invoker/local_invoker.go View File

@@ -0,0 +1,195 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package invoker

import (
"context"
"fmt"
"github.com/seata/seata-go/pkg/saga/statemachine/statelang/state"
"reflect"
"sync"
)

type LocalServiceInvoker struct {
serviceRegistry map[string]interface{}
methodCache map[string]*reflect.Method
jsonParser JsonParser
mutex sync.RWMutex
}

func NewLocalServiceInvoker() *LocalServiceInvoker {
return &LocalServiceInvoker{
serviceRegistry: make(map[string]interface{}),
methodCache: make(map[string]*reflect.Method),
jsonParser: &DefaultJsonParser{},
}
}

func (l *LocalServiceInvoker) RegisterService(serviceName string, instance interface{}) {
l.mutex.Lock()
defer l.mutex.Unlock()
l.serviceRegistry[serviceName] = instance
}

func (l *LocalServiceInvoker) Invoke(ctx context.Context, input []any, service state.ServiceTaskState) ([]reflect.Value, error) {
serviceName := service.ServiceName()
instance, exists := l.serviceRegistry[serviceName]
if !exists {
return nil, fmt.Errorf("service %s not registered", serviceName)
}

methodName := service.ServiceMethod()
method, err := l.getMethod(serviceName, methodName, service.ParameterTypes())
if err != nil {
return nil, err
}

params, err := l.resolveParameters(input, method.Type)
if err != nil {
return nil, err
}

return l.invokeMethod(instance, method, params), nil
}

func (l *LocalServiceInvoker) resolveMethod(key, serviceName, methodName string) (*reflect.Method, error) {
l.mutex.Lock()
defer l.mutex.Unlock()

if cachedMethod, ok := l.methodCache[key]; ok {
return cachedMethod, nil
}

instance, exists := l.serviceRegistry[serviceName]
if !exists {
return nil, fmt.Errorf("service %s not found", serviceName)
}

objType := reflect.TypeOf(instance)
method, ok := objType.MethodByName(methodName)
if !ok {
return nil, fmt.Errorf("method %s not found in service %s", methodName, serviceName)
}

l.methodCache[key] = &method
return &method, nil
}

func (l *LocalServiceInvoker) getMethod(serviceName, methodName string, paramTypes []string) (*reflect.Method, error) {
key := fmt.Sprintf("%s.%s", serviceName, methodName)

l.mutex.RLock()
if method, ok := l.methodCache[key]; ok {
l.mutex.RUnlock()
return method, nil
}
l.mutex.RUnlock()

return l.resolveMethod(key, serviceName, methodName)
}

func (l *LocalServiceInvoker) resolveParameters(input []any, methodType reflect.Type) ([]reflect.Value, error) {
numIn := methodType.NumIn()
paramStart, paramCount := 1, 0

if numIn > 0 {
paramCount = numIn - paramStart
}

if paramCount == 0 {
if len(input) > 0 {
return nil, fmt.Errorf("unexpected parameters: expected 0, got %d", len(input))
}
return []reflect.Value{}, nil
}

if len(input) < paramCount {
return nil, fmt.Errorf("insufficient parameters: expected %d, got %d", paramCount, len(input))
}

if len(input) > paramCount {
return nil, fmt.Errorf("too many parameters: expected %d, got %d", paramCount, len(input))
}

params := make([]reflect.Value, paramCount)
for i := 0; i < paramCount; i++ {
methodParamIndex := i + paramStart
paramType := methodType.In(methodParamIndex)

converted, err := l.convertParam(input[i], paramType)
if err != nil {
return nil, fmt.Errorf("parameter %d conversion error: %w", i, err)
}

params[i] = reflect.ValueOf(converted)
}

return params, nil
}

func (l *LocalServiceInvoker) convertParam(value any, targetType reflect.Type) (any, error) {
if targetType.Kind() == reflect.Ptr {
elemType := targetType.Elem()
instance := reflect.New(elemType).Interface()
jsonData, err := l.jsonParser.Marshal(value)
if err != nil {
return nil, err
}
if err := l.jsonParser.Unmarshal(jsonData, instance); err != nil {
return nil, err
}
return instance, nil
}

if targetType.Kind() == reflect.Struct {
instance := reflect.New(targetType).Interface()
jsonData, err := l.jsonParser.Marshal(value)
if err != nil {
return nil, err
}
if err := l.jsonParser.Unmarshal(jsonData, instance); err != nil {
return nil, err
}
return reflect.ValueOf(instance).Elem().Interface(), nil
}

if targetType.Kind() == reflect.Int && reflect.TypeOf(value).Kind() == reflect.Float64 {
return int(value.(float64)), nil
} else if targetType == reflect.TypeOf("") && reflect.TypeOf(value).Kind() == reflect.Int {
return fmt.Sprintf("%d", value), nil
}

return value, nil
}

func (l *LocalServiceInvoker) invokeMethod(instance interface{}, method *reflect.Method, params []reflect.Value) []reflect.Value {
instanceValue := reflect.ValueOf(instance)
if method.Func.IsValid() {
allParams := append([]reflect.Value{instanceValue}, params...)
return method.Func.Call(allParams)
}
return nil
}

func (l *LocalServiceInvoker) Close(ctx context.Context) error {
l.mutex.Lock()
defer l.mutex.Unlock()
l.serviceRegistry = nil
l.methodCache = nil
return nil
}

+ 212
- 0
pkg/saga/statemachine/engine/invoker/local_invoker_test.go View File

@@ -0,0 +1,212 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package invoker

import (
"context"
"errors"
"fmt"
"reflect"
"testing"

"github.com/seata/seata-go/pkg/saga/statemachine/statelang/state"
)

type MockLocalService struct {
invokeCount int
}

func (m *MockLocalService) GetServiceName() string {
return "MockLocalService"
}

func (m *MockLocalService) Add(a, b int) int {
m.invokeCount++
return a + b
}

func (m *MockLocalService) Multiply(f float64, i int) float64 {
m.invokeCount++
return f * float64(i)
}

type User struct {
Name string `json:"name"`
Age int `json:"age"`
}

func (m *MockLocalService) GetUserName(user User) string {
m.invokeCount++
return user.Name
}

func (m *MockLocalService) ErrorMethod() error {
return errors.New("expected error")
}

func TestLocalInvoker_ServiceNotRegistered(t *testing.T) {
invoker := NewLocalServiceInvoker()
ctx := context.Background()
taskState := newLocalServiceTaskState("unregisteredService", "AnyMethod")

_, err := invoker.Invoke(ctx, []any{}, taskState)
if err == nil {
t.Error("expected error when service not registered, but got nil")
}
if err.Error() != "service unregisteredService not registered" {
t.Errorf("unexpected error message: %v", err)
}
}

func TestLocalInvoker_MethodNotFound(t *testing.T) {
invoker := NewLocalServiceInvoker()
service := &MockLocalService{}
invoker.RegisterService("mockService", service)

ctx := context.Background()
taskState := newLocalServiceTaskState("mockService", "NonExistentMethod")

_, err := invoker.Invoke(ctx, []any{}, taskState)
if err == nil {
t.Error("expected error when method not found, but got nil")
}
if err.Error() != "method NonExistentMethod not found in service mockService" {
t.Errorf("unexpected error message: %v", err)
}
}

func TestLocalInvoker_InvokeSuccess(t *testing.T) {
tests := []struct {
name string
service interface{}
serviceName string
methodName string
input []any
expected interface{}
}{
{
name: "test basic method call",
service: &MockLocalService{},
serviceName: "mockService",
methodName: "GetServiceName",
input: []any{},
expected: "MockLocalService",
},
{
name: "test method with parameters",
service: &MockLocalService{},
serviceName: "mockService",
methodName: "Add",
input: []any{2, 3},
expected: 5,
},
{
name: "test parameter type conversion",
service: &MockLocalService{},
serviceName: "mockService",
methodName: "Multiply",
input: []any{2.5, 4},
expected: 10.0,
},
}

invoker := NewLocalServiceInvoker()
ctx := context.Background()

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
invoker.RegisterService(tt.serviceName, tt.service)
taskState := newLocalServiceTaskState(tt.serviceName, tt.methodName)

results, err := invoker.Invoke(ctx, tt.input, taskState)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}

if len(results) == 0 {
t.Fatal("no results returned")
}

result := results[0].Interface()
if !reflect.DeepEqual(result, tt.expected) {
t.Errorf("expected %v, got %v", tt.expected, result)
}
})
}
}

func TestLocalInvoker_StructParameterConversion(t *testing.T) {
invoker := NewLocalServiceInvoker()
service := &MockLocalService{}
invoker.RegisterService("userService", service)

ctx := context.Background()
taskState := newLocalServiceTaskState("userService", "GetUserName")

input := []any{map[string]interface{}{"name": "Alice", "age": 30}}
results, err := invoker.Invoke(ctx, input, taskState)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}

if len(results) == 0 {
t.Fatal("no results returned")
}

result := results[0].Interface()
if result != "Alice" {
t.Errorf("expected 'Alice', got %v", result)
}
}

func TestLocalInvoker_MethodCaching(t *testing.T) {
invoker := NewLocalServiceInvoker()
service := &MockLocalService{}
invoker.RegisterService("cacheTestService", service)

ctx := context.Background()
taskState := newLocalServiceTaskState("cacheTestService", "Add")

_, err := invoker.Invoke(ctx, []any{1, 1}, taskState)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}

results, err := invoker.Invoke(ctx, []any{2, 3}, taskState)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}

if results[0].Interface() != 5 {
t.Errorf("expected 5, got %v", results[0].Interface())
}

if service.invokeCount != 2 {
t.Errorf("expected 2 invocations, got %d", service.invokeCount)
}
}

func newLocalServiceTaskState(serviceName, methodName string) state.ServiceTaskState {
serviceTaskStateImpl := state.NewServiceTaskStateImpl()
serviceTaskStateImpl.SetName(fmt.Sprintf("%s_%s", serviceName, methodName))
serviceTaskStateImpl.SetIsAsync(false)
serviceTaskStateImpl.SetServiceName(serviceName)
serviceTaskStateImpl.SetServiceType("local")
serviceTaskStateImpl.SetServiceMethod(methodName)
return serviceTaskStateImpl
}

+ 81
- 0
pkg/saga/statemachine/engine/pcext/compensation_holder.go View File

@@ -0,0 +1,81 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package pcext

import (
"context"
"github.com/seata/seata-go/pkg/saga/statemachine/constant"
"github.com/seata/seata-go/pkg/saga/statemachine/process_ctrl"
"github.com/seata/seata-go/pkg/saga/statemachine/statelang"
"github.com/seata/seata-go/pkg/util/collection"
"sync"
)

type CompensationHolder struct {
statesNeedCompensation *sync.Map
statesForCompensation *sync.Map
stateStackNeedCompensation *collection.Stack
}

func (c *CompensationHolder) StatesNeedCompensation() *sync.Map {
return c.statesNeedCompensation
}

func (c *CompensationHolder) SetStatesNeedCompensation(statesNeedCompensation *sync.Map) {
c.statesNeedCompensation = statesNeedCompensation
}

func (c *CompensationHolder) StatesForCompensation() *sync.Map {
return c.statesForCompensation
}

func (c *CompensationHolder) SetStatesForCompensation(statesForCompensation *sync.Map) {
c.statesForCompensation = statesForCompensation
}

func (c *CompensationHolder) StateStackNeedCompensation() *collection.Stack {
return c.stateStackNeedCompensation
}

func (c *CompensationHolder) SetStateStackNeedCompensation(stateStackNeedCompensation *collection.Stack) {
c.stateStackNeedCompensation = stateStackNeedCompensation
}

func (c *CompensationHolder) AddToBeCompensatedState(stateName string, toBeCompensatedState statelang.StateInstance) {
c.statesNeedCompensation.Store(stateName, toBeCompensatedState)
}

func NewCompensationHolder() *CompensationHolder {
return &CompensationHolder{
statesNeedCompensation: &sync.Map{},
statesForCompensation: &sync.Map{},
stateStackNeedCompensation: collection.NewStack(),
}
}

func GetCurrentCompensationHolder(ctx context.Context, processContext process_ctrl.ProcessContext, forceCreate bool) *CompensationHolder {
compensationholder := processContext.GetVariable(constant.VarNameCurrentCompensationHolder).(*CompensationHolder)
lock := processContext.GetVariable(constant.VarNameProcessContextMutexLock).(*sync.Mutex)
lock.Lock()
defer lock.Unlock()
if compensationholder == nil && forceCreate {
compensationholder = NewCompensationHolder()
processContext.SetVariable(constant.VarNameCurrentCompensationHolder, compensationholder)
}
return compensationholder
}

+ 186
- 0
pkg/saga/statemachine/engine/pcext/engine_utils.go View File

@@ -0,0 +1,186 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package pcext

import (
"context"
"errors"
"github.com/seata/seata-go/pkg/saga/statemachine/constant"
"github.com/seata/seata-go/pkg/saga/statemachine/engine"
"github.com/seata/seata-go/pkg/saga/statemachine/process_ctrl"
"github.com/seata/seata-go/pkg/saga/statemachine/statelang"
"github.com/seata/seata-go/pkg/saga/statemachine/statelang/state"
"github.com/seata/seata-go/pkg/util/log"
"golang.org/x/sync/semaphore"
"reflect"
"strings"
"sync"
"time"
)

func EndStateMachine(ctx context.Context, processContext process_ctrl.ProcessContext) error {
if processContext.HasVariable(constant.VarNameIsLoopState) {
if processContext.HasVariable(constant.LoopSemaphore) {
weighted, ok := processContext.GetVariable(constant.LoopSemaphore).(semaphore.Weighted)
if !ok {
return errors.New("semaphore type is not weighted")
}
weighted.Release(1)
}
}

stateMachineInstance, ok := processContext.GetVariable(constant.VarNameStateMachineInst).(statelang.StateMachineInstance)
if !ok {
return errors.New("state machine instance type is not statelang.StateMachineInstance")
}

stateMachineInstance.SetEndTime(time.Now())

exp, ok := processContext.GetVariable(constant.VarNameCurrentException).(error)
if !ok {
return errors.New("exception type is not error")
}

if exp != nil {
stateMachineInstance.SetException(exp)
log.Debugf("Exception Occurred: %s", exp)
}

stateMachineConfig, ok := processContext.GetVariable(constant.VarNameStateMachineConfig).(engine.StateMachineConfig)

if err := stateMachineConfig.StatusDecisionStrategy().DecideOnEndState(ctx, processContext, stateMachineInstance, exp); err != nil {
return err
}

contextParams, ok := processContext.GetVariable(constant.VarNameStateMachineContext).(map[string]interface{})
if !ok {
return errors.New("state machine context type is not map[string]interface{}")
}
endParams := stateMachineInstance.EndParams()
for k, v := range contextParams {
endParams[k] = v
}
stateMachineInstance.SetEndParams(endParams)

stateInstruction, ok := processContext.GetInstruction().(StateInstruction)
if !ok {
return errors.New("state instruction type is not process_ctrl.StateInstruction")
}
stateInstruction.SetEnd(true)

stateMachineInstance.SetRunning(false)
stateMachineInstance.SetEndTime(time.Now())

if stateMachineInstance.StateMachine().IsPersist() && stateMachineConfig.StateLangStore() != nil {
err := stateMachineConfig.StateLogStore().RecordStateMachineFinished(ctx, stateMachineInstance, processContext)
if err != nil {
return err
}
}

callBack, ok := processContext.GetVariable(constant.VarNameAsyncCallback).(engine.CallBack)
if ok {
if exp != nil {
callBack.OnError(ctx, processContext, stateMachineInstance, exp)
} else {
callBack.OnFinished(ctx, processContext, stateMachineInstance)
}
}

return nil
}

func HandleException(processContext process_ctrl.ProcessContext, abstractTaskState *state.AbstractTaskState, err error) {
catches := abstractTaskState.Catches()
if catches != nil && len(catches) != 0 {
for _, exceptionMatch := range catches {
exceptions := exceptionMatch.Exceptions()
exceptionTypes := exceptionMatch.ExceptionTypes()
if exceptions != nil && len(exceptions) != 0 {
if exceptionTypes == nil {
lock := processContext.GetVariable(constant.VarNameProcessContextMutexLock).(*sync.Mutex)
lock.Lock()
defer lock.Unlock()
error := errors.New("")
for i := 0; i < len(exceptions); i++ {
exceptionTypes = append(exceptionTypes, reflect.TypeOf(error))
}
}

exceptionMatch.SetExceptionTypes(exceptionTypes)
}

for i, _ := range exceptionTypes {
if reflect.TypeOf(err) == exceptionTypes[i] {
// HACK: we can not get error type in config file during runtime, so we use exception str
if strings.Contains(err.Error(), exceptions[i]) {
hierarchicalProcessContext := processContext.(process_ctrl.HierarchicalProcessContext)
hierarchicalProcessContext.SetVariable(constant.VarNameCurrentExceptionRoute, exceptionMatch.Next())
return
}
}
}
}
}

log.Error("Task execution failed and no catches configured")
hierarchicalProcessContext := processContext.(process_ctrl.HierarchicalProcessContext)
hierarchicalProcessContext.SetVariable(constant.VarNameIsExceptionNotCatch, true)
}

// GetOriginStateName get origin state name without suffix like fork
func GetOriginStateName(stateInstance statelang.StateInstance) string {
stateName := stateInstance.Name()
if stateName != "" {
end := strings.LastIndex(stateName, constant.LoopStateNamePattern)
if end > -1 {
return stateName[:end+1]
}
}
return stateName
}

// IsTimeout test if is timeout
func IsTimeout(gmtUpdated time.Time, timeoutMillis int) bool {
if timeoutMillis < 0 {
return false
}
return time.Now().Unix()-gmtUpdated.Unix() > int64(timeoutMillis)
}

func GenerateParentId(stateInstance statelang.StateInstance) string {
return stateInstance.MachineInstanceID() + constant.SeperatorParentId + stateInstance.ID()
}

// GetNetExceptionType Speculate what kind of network anomaly is caused by the error
func GetNetExceptionType(err error) constant.NetExceptionType {
if err == nil {
return constant.NotNetException
}

// If it contains a specific error message, simply guess
errMsg := err.Error()
if strings.Contains(errMsg, "connection refused") {
return constant.ConnectException
} else if strings.Contains(errMsg, "timeout") {
return constant.ConnectTimeoutException
} else if strings.Contains(errMsg, "i/o timeout") {
return constant.ReadTimeoutException
}
return constant.NotNetException
}

+ 112
- 0
pkg/saga/statemachine/engine/pcext/instruction.go View File

@@ -0,0 +1,112 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package pcext

import (
"errors"
"fmt"
"github.com/seata/seata-go/pkg/saga/statemachine/constant"
"github.com/seata/seata-go/pkg/saga/statemachine/engine"
"github.com/seata/seata-go/pkg/saga/statemachine/process_ctrl"
"github.com/seata/seata-go/pkg/saga/statemachine/statelang"
)

type StateInstruction struct {
stateName string
stateMachineName string
tenantId string
end bool
temporaryState statelang.State
}

func NewStateInstruction(stateMachineName string, tenantId string) *StateInstruction {
return &StateInstruction{stateMachineName: stateMachineName, tenantId: tenantId}
}

func (s *StateInstruction) StateName() string {
return s.stateName
}

func (s *StateInstruction) SetStateName(stateName string) {
s.stateName = stateName
}

func (s *StateInstruction) StateMachineName() string {
return s.stateMachineName
}

func (s *StateInstruction) SetStateMachineName(stateMachineName string) {
s.stateMachineName = stateMachineName
}

func (s *StateInstruction) TenantId() string {
return s.tenantId
}

func (s *StateInstruction) SetTenantId(tenantId string) {
s.tenantId = tenantId
}

func (s *StateInstruction) End() bool {
return s.end
}

func (s *StateInstruction) SetEnd(end bool) {
s.end = end
}

func (s *StateInstruction) TemporaryState() statelang.State {
return s.temporaryState
}

func (s *StateInstruction) SetTemporaryState(temporaryState statelang.State) {
s.temporaryState = temporaryState
}

func (s *StateInstruction) GetState(context process_ctrl.ProcessContext) (statelang.State, error) {
if s.temporaryState != nil {
return s.temporaryState, nil
}

if s.stateMachineName == "" {
return nil, errors.New("stateMachineName is required")
}

stateMachineConfig, ok := context.GetVariable(constant.VarNameStateMachineConfig).(engine.StateMachineConfig)
if !ok {
return nil, errors.New("stateMachineConfig is required in context")
}
stateMachine, err := stateMachineConfig.StateMachineRepository().GetLastVersionStateMachine(s.stateMachineName, s.tenantId)
if err != nil {
return nil, errors.New("get stateMachine in state machine repository error")
}
if stateMachine == nil {
return nil, errors.New(fmt.Sprintf("stateMachine [%s] is not exist", s.stateMachineName))
}

if s.stateName == "" {
s.stateName = stateMachine.StartState()
}

state := stateMachine.States()[s.stateName]
if state == nil {
return nil, errors.New(fmt.Sprintf("state [%s] is not exist", s.stateName))
}

return state, nil
}

+ 110
- 0
pkg/saga/statemachine/engine/pcext/loop_context_holder.go View File

@@ -0,0 +1,110 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package pcext

import (
"context"
"github.com/seata/seata-go/pkg/saga/statemachine/constant"
"github.com/seata/seata-go/pkg/saga/statemachine/process_ctrl"
"sync"
)

type LoopContextHolder struct {
nrOfInstances int32
nrOfActiveInstances int32
nrOfCompletedInstances int32
failEnd bool
completionConditionSatisfied bool
loopCounterStack []int
forwardCounterStack []int
collection interface{}
}

func NewLoopContextHolder() *LoopContextHolder {
return &LoopContextHolder{
nrOfInstances: 0,
nrOfActiveInstances: 0,
nrOfCompletedInstances: 0,
failEnd: false,
completionConditionSatisfied: false,
loopCounterStack: make([]int, 0),
forwardCounterStack: make([]int, 0),
collection: nil,
}
}

func GetCurrentLoopContextHolder(ctx context.Context, processContext process_ctrl.ProcessContext, forceCreate bool) *LoopContextHolder {
mutex := processContext.GetVariable(constant.VarNameProcessContextMutexLock).(*sync.Mutex)
mutex.Lock()
defer mutex.Unlock()

loopContextHolder := processContext.GetVariable(constant.VarNameCurrentLoopContextHolder).(*LoopContextHolder)
if loopContextHolder == nil && forceCreate {
loopContextHolder = &LoopContextHolder{}
processContext.SetVariable(constant.VarNameCurrentLoopContextHolder, loopContextHolder)
}
return loopContextHolder
}

func ClearCurrent(ctx context.Context, processContext process_ctrl.ProcessContext) {
processContext.RemoveVariable(constant.VarNameCurrentLoopContextHolder)
}

func (l *LoopContextHolder) NrOfInstances() int32 {
return l.nrOfInstances
}

func (l *LoopContextHolder) NrOfActiveInstances() int32 {
return l.nrOfActiveInstances
}

func (l *LoopContextHolder) NrOfCompletedInstances() int32 {
return l.nrOfCompletedInstances
}

func (l *LoopContextHolder) FailEnd() bool {
return l.failEnd
}

func (l *LoopContextHolder) SetFailEnd(failEnd bool) {
l.failEnd = failEnd
}

func (l *LoopContextHolder) CompletionConditionSatisfied() bool {
return l.completionConditionSatisfied
}

func (l *LoopContextHolder) SetCompletionConditionSatisfied(completionConditionSatisfied bool) {
l.completionConditionSatisfied = completionConditionSatisfied
}

func (l *LoopContextHolder) LoopCounterStack() []int {
return l.loopCounterStack
}

func (l *LoopContextHolder) ForwardCounterStack() []int {
return l.forwardCounterStack
}

func (l *LoopContextHolder) Collection() interface{} {
return l.collection
}

func (l *LoopContextHolder) SetCollection(collection interface{}) {
l.collection = collection
}

+ 59
- 0
pkg/saga/statemachine/engine/pcext/loop_task_utils.go View File

@@ -0,0 +1,59 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package pcext

import (
"context"
"github.com/seata/seata-go/pkg/saga/statemachine/constant"
"github.com/seata/seata-go/pkg/saga/statemachine/engine"
"github.com/seata/seata-go/pkg/saga/statemachine/process_ctrl"
"github.com/seata/seata-go/pkg/saga/statemachine/statelang"
"github.com/seata/seata-go/pkg/saga/statemachine/statelang/state"
"github.com/seata/seata-go/pkg/util/log"
)

func GetLoopConfig(ctx context.Context, processContext process_ctrl.ProcessContext, currentState statelang.State) state.Loop {
if matchLoop(currentState) {
taskState := currentState.(state.AbstractTaskState)
stateMachineInstance := processContext.GetVariable(constant.VarNameStateMachineInst).(statelang.StateMachineInstance)
stateMachineConfig := processContext.GetVariable(constant.VarNameStateMachineConfig).(engine.StateMachineConfig)

if taskState.Loop() != nil {
loop := taskState.Loop()
collectionName := loop.Collection()
if collectionName != "" {
expression := CreateValueExpression(stateMachineConfig.ExpressionResolver(), collectionName)
collection := GetValue(expression, stateMachineInstance.Context(), nil)
collectionList := collection.([]any)
if len(collectionList) > 0 {
current := GetCurrentLoopContextHolder(ctx, processContext, true)
current.SetCollection(collection)
return loop
}
}
log.Warn("State [{}] loop collection param [{}] invalid", currentState.Name(), collectionName)
}

}
return nil
}

func matchLoop(currentState statelang.State) bool {
return currentState != nil && (constant.StateTypeServiceTask == currentState.Type() ||
constant.StateTypeScriptTask == currentState.Type() || constant.StateTypeSubStateMachine == currentState.Type())
}

+ 151
- 0
pkg/saga/statemachine/engine/pcext/parameter_utils.go View File

@@ -0,0 +1,151 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package pcext

import (
"fmt"
"github.com/seata/seata-go/pkg/saga/statemachine/constant"
"github.com/seata/seata-go/pkg/saga/statemachine/engine"
"github.com/seata/seata-go/pkg/saga/statemachine/engine/expr"
"github.com/seata/seata-go/pkg/saga/statemachine/process_ctrl"
"github.com/seata/seata-go/pkg/saga/statemachine/statelang"
"github.com/seata/seata-go/pkg/saga/statemachine/statelang/state"
"strings"
"sync"
)

func CreateInputParams(processContext process_ctrl.ProcessContext, expressionResolver expr.ExpressionResolver,
stateInstance *statelang.StateInstanceImpl, serviceTaskState *state.AbstractTaskState, variablesFrom any) []any {
inputAssignments := serviceTaskState.Input()
if inputAssignments == nil || len(inputAssignments) == 0 {
return inputAssignments
}

inputExpressions := serviceTaskState.InputExpressions()
if inputExpressions == nil || len(inputExpressions) == 0 {
lock := processContext.GetVariable(constant.VarNameProcessContextMutexLock).(*sync.Mutex)
lock.Lock()
defer lock.Unlock()
inputExpressions = serviceTaskState.InputExpressions()
if inputExpressions == nil || len(inputExpressions) == 0 {
inputExpressions = make([]any, 0, len(inputAssignments))

for _, assignment := range inputAssignments {
inputExpressions = append(inputExpressions, CreateValueExpression(expressionResolver, assignment))
}
}
serviceTaskState.SetInputExpressions(inputExpressions)
}
inputValues := make([]any, 0, len(inputExpressions))
for _, valueExpression := range inputExpressions {
value := GetValue(valueExpression, variablesFrom, stateInstance)
inputValues = append(inputValues, value)
}

return inputValues
}

func CreateOutputParams(config engine.StateMachineConfig, expressionResolver expr.ExpressionResolver,
serviceTaskState *state.AbstractTaskState, variablesFrom any) (map[string]any, error) {
outputAssignments := serviceTaskState.Output()
if outputAssignments == nil || len(outputAssignments) == 0 {
return make(map[string]any, 0), nil
}

outputExpressions := serviceTaskState.OutputExpressions()
if outputExpressions == nil {
config.ComponentLock().Lock()
defer config.ComponentLock().Unlock()
outputExpressions = serviceTaskState.OutputExpressions()
if outputExpressions == nil {
outputExpressions = make(map[string]any, len(outputAssignments))
for key, value := range outputAssignments {
outputExpressions[key] = CreateValueExpression(expressionResolver, value)
}
}
serviceTaskState.SetOutputExpressions(outputExpressions)
}
outputValues := make(map[string]any, len(outputExpressions))
for paramName, _ := range outputExpressions {
outputValues[paramName] = GetValue(outputExpressions[paramName], variablesFrom, nil)
}
return outputValues, nil
}

func CreateValueExpression(expressionResolver expr.ExpressionResolver, paramAssignment any) any {
var valueExpression any

switch paramAssignment.(type) {
case expr.Expression:
valueExpression = paramAssignment
case map[string]any:
paramMapAssignment := paramAssignment.(map[string]any)
paramMap := make(map[string]any, len(paramMapAssignment))
for key, value := range paramMapAssignment {
paramMap[key] = CreateValueExpression(expressionResolver, value)
}
valueExpression = paramMap
case []any:
paramListAssignment := paramAssignment.([]any)
paramList := make([]any, 0, len(paramListAssignment))
for _, value := range paramListAssignment {
paramList = append(paramList, CreateValueExpression(expressionResolver, value))
}
valueExpression = paramList
case string:
value := paramAssignment.(string)
if !strings.HasPrefix(value, "$") {
valueExpression = paramAssignment
}
valueExpression = expressionResolver.Expression(value)
default:
valueExpression = paramAssignment
}
return valueExpression
}

func GetValue(valueExpression any, variablesFrom any, stateInstance statelang.StateInstance) any {
switch valueExpression.(type) {
case expr.Expression:
expression := valueExpression.(expr.Expression)
value := expression.Value(variablesFrom)
if _, ok := valueExpression.(expr.SequenceExpression); value != nil && stateInstance != nil && stateInstance.BusinessKey() == "" && ok {
stateInstance.SetBusinessKey(fmt.Sprintf("%v", value))
}
return value
case map[string]any:
mapValueExpression := valueExpression.(map[string]any)
mapValue := make(map[string]any, len(mapValueExpression))
for key, value := range mapValueExpression {
value = GetValue(value, variablesFrom, stateInstance)
if value != nil {
mapValue[key] = value
}
}
return mapValue
case []any:
valueExpressionList := valueExpression.([]any)
listValue := make([]any, 0, len(valueExpression.([]any)))
for i, _ := range valueExpressionList {
listValue = append(listValue, GetValue(valueExpressionList[i], variablesFrom, stateInstance))
}
return listValue
default:
return valueExpression
}
}

+ 115
- 0
pkg/saga/statemachine/engine/pcext/process_handler.go View File

@@ -0,0 +1,115 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package pcext

import (
"context"
"errors"
"github.com/seata/seata-go/pkg/saga/statemachine/process_ctrl"
"sync"
)

type StateHandler interface {
State() string
process_ctrl.ProcessHandler
}

type InterceptAbleStateHandler interface {
StateHandler
StateHandlerInterceptorList() []StateHandlerInterceptor
RegistryStateHandlerInterceptor(stateHandlerInterceptor StateHandlerInterceptor)
}

type StateHandlerInterceptor interface {
PreProcess(ctx context.Context, processContext process_ctrl.ProcessContext) error
PostProcess(ctx context.Context, processContext process_ctrl.ProcessContext) error
Match(stateType string) bool
}

type StateMachineProcessHandler struct {
mp map[string]StateHandler
mu sync.RWMutex
}

func NewStateMachineProcessHandler() *StateMachineProcessHandler {
return &StateMachineProcessHandler{
mp: make(map[string]StateHandler),
}
}

func (s *StateMachineProcessHandler) Process(ctx context.Context, processContext process_ctrl.ProcessContext) error {
stateInstruction, _ := processContext.GetInstruction().(StateInstruction)

state, err := stateInstruction.GetState(processContext)
if err != nil {
return err
}

stateType := state.Type()
stateHandler := s.GetStateHandler(stateType)
if stateHandler == nil {
return errors.New("Not support [" + stateType + "] state handler")
}

interceptAbleStateHandler, ok := stateHandler.(InterceptAbleStateHandler)

var stateHandlerInterceptorList []StateHandlerInterceptor
if ok {
stateHandlerInterceptorList = interceptAbleStateHandler.StateHandlerInterceptorList()
}

if stateHandlerInterceptorList != nil && len(stateHandlerInterceptorList) > 0 {
for _, stateHandlerInterceptor := range stateHandlerInterceptorList {
err = stateHandlerInterceptor.PreProcess(ctx, processContext)
if err != nil {
return err
}
}
}

err = stateHandler.Process(ctx, processContext)
if err != nil {
return err
}

if stateHandlerInterceptorList != nil && len(stateHandlerInterceptorList) > 0 {
for _, stateHandlerInterceptor := range stateHandlerInterceptorList {
err = stateHandlerInterceptor.PostProcess(ctx, processContext)
if err != nil {
return err
}
}
}

return nil
}

func (s *StateMachineProcessHandler) GetStateHandler(stateType string) StateHandler {
s.mu.RLock()
defer s.mu.RUnlock()
return s.mp[stateType]
}

func (s *StateMachineProcessHandler) RegistryStateHandler(stateType string, stateHandler StateHandler) {
s.mu.Lock()
defer s.mu.Unlock()
if s.mp == nil {
s.mp = make(map[string]StateHandler)
}
s.mp[stateType] = stateHandler
}

+ 131
- 0
pkg/saga/statemachine/engine/pcext/process_router.go View File

@@ -0,0 +1,131 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package pcext

import (
"context"
"github.com/pkg/errors"
"github.com/seata/seata-go/pkg/saga/statemachine/constant"
"github.com/seata/seata-go/pkg/saga/statemachine/engine"
"github.com/seata/seata-go/pkg/saga/statemachine/process_ctrl"
"github.com/seata/seata-go/pkg/saga/statemachine/statelang"
)

type StateMachineProcessRouter struct {
stateRouters map[string]process_ctrl.StateRouter
}

func (s *StateMachineProcessRouter) Route(ctx context.Context, processContext process_ctrl.ProcessContext) (process_ctrl.Instruction, error) {
stateInstruction, ok := processContext.GetInstruction().(StateInstruction)
if !ok {
return nil, errors.New("instruction is not a state instruction")
}

var state statelang.State
if stateInstruction.TemporaryState() != nil {
state = stateInstruction.TemporaryState()
stateInstruction.SetTemporaryState(nil)
} else {
stateMachineConfig, ok := processContext.GetVariable(constant.VarNameStateMachineConfig).(engine.StateMachineConfig)
if !ok {
return nil, errors.New("state machine config not found")
}

stateMachine, err := stateMachineConfig.StateMachineRepository().GetStateMachineByNameAndTenantId(stateInstruction.StateMachineName(),
stateInstruction.TenantId())
if err != nil {
return nil, err
}

state = stateMachine.States()[stateInstruction.StateName()]
}

stateType := state.Type()
router := s.stateRouters[stateType]

var interceptors []process_ctrl.StateRouterInterceptor
if interceptAbleStateRouter, ok := router.(process_ctrl.InterceptAbleStateRouter); ok {
interceptors = interceptAbleStateRouter.StateRouterInterceptor()
}

var executedInterceptors []process_ctrl.StateRouterInterceptor
var exception error
instruction, exception := func() (process_ctrl.Instruction, error) {
if interceptors == nil || len(executedInterceptors) == 0 {
executedInterceptors = make([]process_ctrl.StateRouterInterceptor, 0, len(interceptors))
for _, interceptor := range interceptors {
executedInterceptors = append(executedInterceptors, interceptor)
err := interceptor.PreRoute(ctx, processContext, state)
if err != nil {
return nil, err
}
}
}

instruction, err := router.Route(ctx, processContext, state)
if err != nil {
return nil, err
}
return instruction, nil
}()

if interceptors == nil || len(executedInterceptors) == 0 {
for i := len(executedInterceptors) - 1; i >= 0; i-- {
err := executedInterceptors[i].PostRoute(ctx, processContext, instruction, exception)
if err != nil {
return nil, err
}
}

// if 'Succeed' or 'Fail' State did not configured, we must end the state machine
if instruction == nil && !stateInstruction.End() {
err := EndStateMachine(ctx, processContext)
if err != nil {
return nil, err
}
}
}

return instruction, nil
}

func (s *StateMachineProcessRouter) InitDefaultStateRouters() {
if s.stateRouters == nil || len(s.stateRouters) == 0 {
s.stateRouters = make(map[string]process_ctrl.StateRouter)
taskStateRouter := &TaskStateRouter{}
s.stateRouters[constant.StateTypeServiceTask] = taskStateRouter
s.stateRouters[constant.StateTypeScriptTask] = taskStateRouter
s.stateRouters[constant.StateTypeChoice] = taskStateRouter
s.stateRouters[constant.StateTypeCompensationTrigger] = taskStateRouter
s.stateRouters[constant.StateTypeSubStateMachine] = taskStateRouter
s.stateRouters[constant.StateTypeCompensateSubMachine] = taskStateRouter
s.stateRouters[constant.StateTypeLoopStart] = taskStateRouter

endStateRouter := &EndStateRouter{}
s.stateRouters[constant.StateTypeSucceed] = endStateRouter
s.stateRouters[constant.StateTypeFail] = endStateRouter
}
}

func (s *StateMachineProcessRouter) StateRouters() map[string]process_ctrl.StateRouter {
return s.stateRouters
}

func (s *StateMachineProcessRouter) SetStateRouters(stateRouters map[string]process_ctrl.StateRouter) {
s.stateRouters = stateRouters
}

+ 169
- 0
pkg/saga/statemachine/engine/pcext/state_router_impl.go View File

@@ -0,0 +1,169 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package pcext

import (
"context"
"github.com/seata/seata-go/pkg/saga/statemachine/constant"
"github.com/seata/seata-go/pkg/saga/statemachine/engine/exception"
"github.com/seata/seata-go/pkg/saga/statemachine/process_ctrl"
"github.com/seata/seata-go/pkg/saga/statemachine/statelang"
sagaState "github.com/seata/seata-go/pkg/saga/statemachine/statelang/state"
seataErrors "github.com/seata/seata-go/pkg/util/errors"
"github.com/seata/seata-go/pkg/util/log"
)

type EndStateRouter struct {
}

func (e EndStateRouter) Route(ctx context.Context, processContext process_ctrl.ProcessContext, state statelang.State) (process_ctrl.Instruction, error) {
return nil, nil
}

type TaskStateRouter struct {
}

func (t TaskStateRouter) Route(ctx context.Context, processContext process_ctrl.ProcessContext, state statelang.State) (process_ctrl.Instruction, error) {
stateInstruction, _ := processContext.GetInstruction().(StateInstruction)
if stateInstruction.End() {
log.Infof("StateInstruction is ended, Stop the StateMachine executing. StateMachine[%s] Current State[%s]",
stateInstruction.StateMachineName(), stateInstruction.StateName())
}

// check if in loop async condition
isLoop, ok := processContext.GetVariable(constant.VarNameIsLoopState).(bool)
if ok && isLoop {
log.Infof("StateMachine[%s] Current State[%s] is in loop async condition, skip route processing.",
stateInstruction.StateMachineName(), stateInstruction.StateName())
return nil, nil
}

// The current CompensationTriggerState can mark the compensation process is started and perform compensation
// route processing.
compensationTriggerState, ok := processContext.GetVariable(constant.VarNameCurrentCompensateTriggerState).(statelang.State)
if ok {
return t.compensateRoute(ctx, processContext, compensationTriggerState)
}

// There is an exception route, indicating that an exception is thrown, and the exception route is prioritized.
next := processContext.GetVariable(constant.VarNameCurrentExceptionRoute).(string)

if next != "" {
processContext.RemoveVariable(constant.VarNameCurrentExceptionRoute)
} else {
next = state.Next()
}

// If next is empty, the state selected by the Choice state was taken.
if next == "" && processContext.HasVariable(constant.VarNameCurrentChoice) {
next = processContext.GetVariable(constant.VarNameCurrentChoice).(string)
processContext.RemoveVariable(constant.VarNameCurrentChoice)
}

if next == "" {
return nil, nil
}

stateMachine := state.StateMachine()
nextState := stateMachine.State(next)
if nextState == nil {
return nil, exception.NewEngineExecutionException(seataErrors.ObjectNotExists,
"Next state["+next+"] is not exits", nil)
}

stateInstruction.SetStateName(next)

if nil != GetLoopConfig(ctx, processContext, nextState) {
stateInstruction.SetTemporaryState(sagaState.NewLoopStartStateImpl())
}

return stateInstruction, nil
}

func (t *TaskStateRouter) compensateRoute(ctx context.Context, processContext process_ctrl.ProcessContext,
compensationTriggerState statelang.State) (process_ctrl.Instruction, error) {
//If there is already a compensation state that has been executed,
// it is judged whether it is wrong or unsuccessful,
// and the compensation process is interrupted.
isFirstCompensationStateStart := processContext.GetVariable(constant.VarNameFirstCompensationStateStarted).(bool)
if isFirstCompensationStateStart {
exception := processContext.GetVariable(constant.VarNameCurrentException).(error)
if exception != nil {
return nil, EndStateMachine(ctx, processContext)
}

stateInstance := processContext.GetVariable(constant.VarNameStateInst).(statelang.StateInstance)
if stateInstance != nil && statelang.SU != stateInstance.Status() {
return nil, EndStateMachine(ctx, processContext)
}
}

stateStackToBeCompensated := GetCurrentCompensationHolder(ctx, processContext, true).StateStackNeedCompensation()
if stateStackToBeCompensated != nil {
stateToBeCompensated := stateStackToBeCompensated.Pop().(statelang.StateInstance)

stateMachine := processContext.GetVariable(constant.VarNameStateMachine).(statelang.StateMachine)
state := stateMachine.State(GetOriginStateName(stateToBeCompensated))
if taskState, ok := state.(sagaState.AbstractTaskState); ok {
instruction := processContext.GetInstruction().(StateInstruction)

var compensateState statelang.State
compensateStateName := taskState.CompensateState()
if len(compensateStateName) != 0 {
compensateState = stateMachine.State(compensateStateName)
}

if subStateMachine, ok := state.(sagaState.SubStateMachine); compensateState == nil && ok {
compensateState = subStateMachine.CompensateStateImpl()
instruction.SetTemporaryState(compensateState)
}

if compensateState == nil {
return nil, EndStateMachine(ctx, processContext)
}

instruction.SetStateName(compensateState.Name())

GetCurrentCompensationHolder(ctx, processContext, true).AddToBeCompensatedState(compensateState.Name(),
stateToBeCompensated)

hierarchicalProcessContext := processContext.(process_ctrl.HierarchicalProcessContext)
hierarchicalProcessContext.SetVariableLocally(constant.VarNameFirstCompensationStateStarted, true)

if _, ok := compensateState.(sagaState.CompensateSubStateMachineState); ok {
hierarchicalProcessContext = processContext.(process_ctrl.HierarchicalProcessContext)
hierarchicalProcessContext.SetVariableLocally(
compensateState.Name()+constant.VarNameSubMachineParentId,
GenerateParentId(stateToBeCompensated))
}

return instruction, nil
}
}

processContext.RemoveVariable(constant.VarNameCurrentCompensateTriggerState)

compensationTriggerStateNext := compensationTriggerState.Next()
if compensationTriggerStateNext == "" {
return nil, EndStateMachine(ctx, processContext)
}

instruction := processContext.GetInstruction().(StateInstruction)
instruction.SetStateName(compensationTriggerStateNext)
return instruction, nil
}

+ 137
- 0
pkg/saga/statemachine/engine/repo/repository/state_log_repository.go View File

@@ -0,0 +1,137 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package repository

import (
"context"
"github.com/pkg/errors"
"github.com/seata/seata-go/pkg/saga/statemachine/process_ctrl"
"github.com/seata/seata-go/pkg/saga/statemachine/statelang"
"github.com/seata/seata-go/pkg/saga/statemachine/store"
"sync"
)

var (
stateLogRepositoryImpl *StateLogRepositoryImpl
onceStateLogRepositoryImpl sync.Once
)

type StateLogRepositoryImpl struct {
stateLogStore store.StateLogStore
}

func NewStateLogRepositoryImpl() *StateLogRepositoryImpl {
onceStateLogRepositoryImpl.Do(func() {
stateLogRepositoryImpl = &StateLogRepositoryImpl{}
})
return stateLogRepositoryImpl
}

func (s *StateLogRepositoryImpl) RecordStateMachineStarted(
ctx context.Context,
machineInstance statelang.StateMachineInstance,
processContext process_ctrl.ProcessContext,
) error {
if s.stateLogStore == nil {
return errors.New("stateLogStore is not initialized")
}
return s.stateLogStore.RecordStateMachineStarted(ctx, machineInstance, processContext)
}

func (s *StateLogRepositoryImpl) RecordStateMachineFinished(
ctx context.Context,
machineInstance statelang.StateMachineInstance,
processContext process_ctrl.ProcessContext,
) error {
if s.stateLogStore == nil {
return errors.New("stateLogStore is not initialized")
}
return s.stateLogStore.RecordStateMachineFinished(ctx, machineInstance, processContext)
}

func (s *StateLogRepositoryImpl) RecordStateMachineRestarted(
ctx context.Context,
machineInstance statelang.StateMachineInstance,
processContext process_ctrl.ProcessContext,
) error {
if s.stateLogStore == nil {
return errors.New("stateLogStore is not initialized")
}
return s.stateLogStore.RecordStateMachineRestarted(ctx, machineInstance, processContext)
}

func (s *StateLogRepositoryImpl) RecordStateStarted(
ctx context.Context,
stateInstance statelang.StateInstance,
processContext process_ctrl.ProcessContext,
) error {
if s.stateLogStore == nil {
return errors.New("stateLogStore is not initialized")
}
return s.stateLogStore.RecordStateStarted(ctx, stateInstance, processContext)
}

func (s *StateLogRepositoryImpl) RecordStateFinished(
ctx context.Context,
stateInstance statelang.StateInstance,
processContext process_ctrl.ProcessContext,
) error {
if s.stateLogStore == nil {
return errors.New("stateLogStore is not initialized")
}
return s.stateLogStore.RecordStateFinished(ctx, stateInstance, processContext)
}

func (s *StateLogRepositoryImpl) GetStateMachineInstance(stateMachineInstanceId string) (statelang.StateMachineInstance, error) {
if s.stateLogStore == nil {
return nil, errors.New("stateLogStore is not initialized")
}
return s.stateLogStore.GetStateMachineInstance(stateMachineInstanceId)
}

func (s *StateLogRepositoryImpl) GetStateMachineInstanceByBusinessKey(businessKey, tenantId string) (statelang.StateMachineInstance, error) {
if s.stateLogStore == nil {
return nil, errors.New("stateLogStore is not initialized")
}
return s.stateLogStore.GetStateMachineInstanceByBusinessKey(businessKey, tenantId)
}

func (s *StateLogRepositoryImpl) GetStateMachineInstanceByParentId(parentId string) ([]statelang.StateMachineInstance, error) {
if s.stateLogStore == nil {
return nil, errors.New("stateLogStore is not initialized")
}
return s.stateLogStore.GetStateMachineInstanceByParentId(parentId)
}

func (s *StateLogRepositoryImpl) GetStateInstance(stateInstanceId, machineInstId string) (statelang.StateInstance, error) {
if s.stateLogStore == nil {
return nil, errors.New("stateLogStore is not initialized")
}
return s.stateLogStore.GetStateInstance(stateInstanceId, machineInstId)
}

func (s *StateLogRepositoryImpl) GetStateInstanceListByMachineInstanceId(stateMachineInstanceId string) ([]statelang.StateInstance, error) {
if s.stateLogStore == nil {
return nil, errors.New("stateLogStore is not initialized")
}
return s.stateLogStore.GetStateInstanceListByMachineInstanceId(stateMachineInstanceId)
}

func (s *StateLogRepositoryImpl) SetStateLogStore(stateLogStore store.StateLogStore) {
s.stateLogStore = stateLogStore
}

+ 237
- 0
pkg/saga/statemachine/engine/repo/repository/state_machine_repository.go View File

@@ -0,0 +1,237 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package repository

import (
"io"
"sync"
"time"

"github.com/seata/seata-go/pkg/saga/statemachine/constant"
"github.com/seata/seata-go/pkg/saga/statemachine/engine/sequence"
"github.com/seata/seata-go/pkg/saga/statemachine/statelang"
"github.com/seata/seata-go/pkg/saga/statemachine/statelang/parser"
"github.com/seata/seata-go/pkg/saga/statemachine/store"
"github.com/seata/seata-go/pkg/util/log"
)

const (
DefaultJsonParser = "fastjson"
)

var (
stateMachineRepositoryImpl *StateMachineRepositoryImpl
onceStateMachineRepositoryImpl sync.Once
)

type StateMachineRepositoryImpl struct {
stateMachineMapById map[string]statelang.StateMachine
stateMachineMapByNameAndTenant map[string]statelang.StateMachine

stateLangStore store.StateLangStore
seqGenerator sequence.SeqGenerator
defaultTenantId string
jsonParserName string
charset string
mutex *sync.Mutex
}

func GetStateMachineRepositoryImpl() *StateMachineRepositoryImpl {
if stateMachineRepositoryImpl == nil {
onceStateMachineRepositoryImpl.Do(func() {
//TODO charset is not use
//TODO using json parser
stateMachineRepositoryImpl = &StateMachineRepositoryImpl{
stateMachineMapById: make(map[string]statelang.StateMachine),
stateMachineMapByNameAndTenant: make(map[string]statelang.StateMachine),
seqGenerator: sequence.NewUUIDSeqGenerator(),
jsonParserName: DefaultJsonParser,
charset: "UTF-8",
mutex: &sync.Mutex{},
}
})
}

return stateMachineRepositoryImpl
}

func (s *StateMachineRepositoryImpl) GetStateMachineById(stateMachineId string) (statelang.StateMachine, error) {
stateMachine := s.stateMachineMapById[stateMachineId]
if stateMachine == nil && s.stateLangStore != nil {
s.mutex.Lock()
defer s.mutex.Unlock()

stateMachine = s.stateMachineMapById[stateMachineId]
if stateMachine == nil {
oldStateMachine, err := s.stateLangStore.GetStateMachineById(stateMachineId)
if err != nil {
return oldStateMachine, err
}

parseStatMachine, err := parser.NewJSONStateMachineParser().Parse(oldStateMachine.Content())
if err != nil {
return oldStateMachine, err
}

oldStateMachine.SetStartState(parseStatMachine.StartState())
for key, val := range parseStatMachine.States() {
oldStateMachine.States()[key] = val
}

s.stateMachineMapById[stateMachineId] = oldStateMachine
s.stateMachineMapByNameAndTenant[oldStateMachine.Name()+"_"+oldStateMachine.TenantId()] = oldStateMachine
return oldStateMachine, nil
}
}
return stateMachine, nil
}

func (s *StateMachineRepositoryImpl) GetStateMachineByNameAndTenantId(stateMachineName string, tenantId string) (statelang.StateMachine, error) {
return s.GetLastVersionStateMachine(stateMachineName, tenantId)
}

func (s *StateMachineRepositoryImpl) GetLastVersionStateMachine(stateMachineName string, tenantId string) (statelang.StateMachine, error) {
key := stateMachineName + "_" + tenantId
stateMachine := s.stateMachineMapByNameAndTenant[key]
if stateMachine == nil && s.stateLangStore != nil {
s.mutex.Lock()
defer s.mutex.Unlock()

stateMachine = s.stateMachineMapById[key]
if stateMachine == nil {
oldStateMachine, err := s.stateLangStore.GetLastVersionStateMachine(stateMachineName, tenantId)
if err != nil {
return oldStateMachine, err
}

parseStatMachine, err := parser.NewJSONStateMachineParser().Parse(oldStateMachine.Content())
if err != nil {
return oldStateMachine, err
}

oldStateMachine.SetStartState(parseStatMachine.StartState())
for key, val := range parseStatMachine.States() {
oldStateMachine.States()[key] = val
}

s.stateMachineMapById[oldStateMachine.ID()] = oldStateMachine
s.stateMachineMapByNameAndTenant[key] = oldStateMachine
return oldStateMachine, nil
}
}
return stateMachine, nil
}

func (s *StateMachineRepositoryImpl) RegistryStateMachine(machine statelang.StateMachine) error {
stateMachineName := machine.Name()
tenantId := machine.TenantId()

if s.stateLangStore != nil {
oldStateMachine, err := s.stateLangStore.GetLastVersionStateMachine(stateMachineName, tenantId)
if err != nil {
return err
}

if oldStateMachine != nil {
if oldStateMachine.Content() == machine.Content() && machine.Version() != "" && machine.Version() == oldStateMachine.Version() {
log.Debugf("StateMachine[%s] is already exist a same version", stateMachineName)
machine.SetID(oldStateMachine.ID())
machine.SetCreateTime(oldStateMachine.CreateTime())

s.stateMachineMapById[machine.ID()] = machine
s.stateMachineMapByNameAndTenant[machine.Name()+"_"+machine.TenantId()] = machine
return nil
}
}

if machine.ID() == "" {
machine.SetID(s.seqGenerator.GenerateId(constant.SeqEntityStateMachine, ""))
}

machine.SetCreateTime(time.Now())

err = s.stateLangStore.StoreStateMachine(machine)
if err != nil {
return err
}
}

if machine.ID() == "" {
machine.SetID(s.seqGenerator.GenerateId(constant.SeqEntityStateMachine, ""))
}

s.stateMachineMapById[machine.ID()] = machine
s.stateMachineMapByNameAndTenant[machine.Name()+"_"+machine.TenantId()] = machine
return nil
}

func (s *StateMachineRepositoryImpl) RegistryStateMachineByReader(reader io.Reader) error {
jsonByte, err := io.ReadAll(reader)
if err != nil {
return err
}

json := string(jsonByte)
parseStatMachine, err := parser.NewJSONStateMachineParser().Parse(json)
if err != nil {
return err
}

if parseStatMachine == nil {
return nil
}

parseStatMachine.SetContent(json)
s.RegistryStateMachine(parseStatMachine)

log.Debugf("===== StateMachine Loaded: %s", json)

return nil
}

func (s *StateMachineRepositoryImpl) SetStateLangStore(stateLangStore store.StateLangStore) {
s.stateLangStore = stateLangStore
}

func (s *StateMachineRepositoryImpl) SetSeqGenerator(seqGenerator sequence.SeqGenerator) {
s.seqGenerator = seqGenerator
}

func (s *StateMachineRepositoryImpl) SetCharset(charset string) {
s.charset = charset
}

func (s *StateMachineRepositoryImpl) GetCharset() string {
return s.charset
}

func (s *StateMachineRepositoryImpl) SetDefaultTenantId(defaultTenantId string) {
s.defaultTenantId = defaultTenantId
}

func (s *StateMachineRepositoryImpl) GetDefaultTenantId() string {
return s.defaultTenantId
}

func (s *StateMachineRepositoryImpl) SetJsonParserName(jsonParserName string) {
s.jsonParserName = jsonParserName
}

func (s *StateMachineRepositoryImpl) GetJsonParserName() string {
return s.jsonParserName
}

+ 118
- 0
pkg/saga/statemachine/engine/repo/repository/state_machine_repository_test.go View File

@@ -0,0 +1,118 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package repository

import (
"database/sql"
"github.com/seata/seata-go/pkg/saga/statemachine/statelang/parser"
"os"
"sync"
"testing"
"time"

_ "github.com/mattn/go-sqlite3"
"github.com/stretchr/testify/assert"

"github.com/seata/seata-go/pkg/saga/statemachine/statelang"
"github.com/seata/seata-go/pkg/saga/statemachine/store/db"
)

var (
oncePrepareDB sync.Once
testdb *sql.DB
)

func prepareDB() {
oncePrepareDB.Do(func() {
var err error
testdb, err = sql.Open("sqlite3", ":memory:")
query_, err := os.ReadFile("../../../../../testdata/sql/saga/sqlite_init.sql")
initScript := string(query_)
if err != nil {
panic(err)
}
if _, err := testdb.Exec(initScript); err != nil {
panic(err)
}
})
}

func loadStateMachineByYaml() string {
query, _ := os.ReadFile("../../../../../testdata/saga/statelang/simple_statemachine.json")
return string(query)
}

func TestStateMachineInMemory(t *testing.T) {
const stateMachineId, stateMachineName, tenantId = "simpleStateMachine", "simpleStateMachine", "test"
stateMachine := statelang.NewStateMachineImpl()
stateMachine.SetID(stateMachineId)
stateMachine.SetName(stateMachineName)
stateMachine.SetTenantId(tenantId)
stateMachine.SetComment("This is a test state machine")
stateMachine.SetCreateTime(time.Now())

repository := GetStateMachineRepositoryImpl()

err := repository.RegistryStateMachine(stateMachine)
assert.Nil(t, err)

machineById, err := repository.GetStateMachineById(stateMachine.ID())
assert.Nil(t, err)
assert.Equal(t, stateMachine.Name(), machineById.Name())
assert.Equal(t, stateMachine.TenantId(), machineById.TenantId())
assert.Equal(t, stateMachine.Comment(), machineById.Comment())
assert.Equal(t, stateMachine.CreateTime().UnixNano(), machineById.CreateTime().UnixNano())

machineByNameAndTenantId, err := repository.GetLastVersionStateMachine(stateMachine.Name(), stateMachine.TenantId())
assert.Nil(t, err)
assert.Equal(t, stateMachine.ID(), machineByNameAndTenantId.ID())
assert.Equal(t, stateMachine.Comment(), machineById.Comment())
assert.Equal(t, stateMachine.CreateTime().UnixNano(), machineById.CreateTime().UnixNano())
}

func TestStateMachineInDb(t *testing.T) {
prepareDB()

const tenantId = "test"
yaml := loadStateMachineByYaml()
stateMachine, err := parser.NewJSONStateMachineParser().Parse(yaml)
assert.Nil(t, err)
stateMachine.SetTenantId(tenantId)
stateMachine.SetContent(yaml)

repository := GetStateMachineRepositoryImpl()
repository.SetStateLangStore(db.NewStateLangStore(testdb, "seata_"))

err = repository.RegistryStateMachine(stateMachine)
assert.Nil(t, err)

repository.stateMachineMapById[stateMachine.ID()] = nil
machineById, err := repository.GetStateMachineById(stateMachine.ID())
assert.Nil(t, err)
assert.Equal(t, stateMachine.Name(), machineById.Name())
assert.Equal(t, stateMachine.TenantId(), machineById.TenantId())
assert.Equal(t, stateMachine.Comment(), machineById.Comment())
assert.Equal(t, stateMachine.CreateTime().UnixNano(), machineById.CreateTime().UnixNano())

repository.stateMachineMapByNameAndTenant[stateMachine.Name()+"_"+stateMachine.TenantId()] = nil
machineByNameAndTenantId, err := repository.GetLastVersionStateMachine(stateMachine.Name(), stateMachine.TenantId())
assert.Nil(t, err)
assert.Equal(t, stateMachine.ID(), machineByNameAndTenantId.ID())
assert.Equal(t, stateMachine.Comment(), machineById.Comment())
assert.Equal(t, stateMachine.CreateTime().UnixNano(), machineById.CreateTime().UnixNano())
}

+ 47
- 0
pkg/saga/statemachine/engine/repo/statemachine_store.go View File

@@ -0,0 +1,47 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package repo

import (
"github.com/seata/seata-go/pkg/saga/statemachine/statelang"
"io"
)

type StateLogRepository interface {
GetStateMachineInstance(stateMachineInstanceId string) (statelang.StateMachineInstance, error)

GetStateMachineInstanceByBusinessKey(businessKey string, tenantId string) (statelang.StateMachineInstance, error)

GetStateMachineInstanceByParentId(parentId string) ([]statelang.StateMachineInstance, error)

GetStateInstance(stateInstanceId string, stateMachineInstanceId string) (statelang.StateInstance, error)

GetStateInstanceListByMachineInstanceId(stateMachineInstanceId string) ([]statelang.StateInstance, error)
}

type StateMachineRepository interface {
GetStateMachineById(stateMachineId string) (statelang.StateMachine, error)

GetStateMachineByNameAndTenantId(stateMachineName string, tenantId string) (statelang.StateMachine, error)

GetLastVersionStateMachine(stateMachineName string, tenantId string) (statelang.StateMachine, error)

RegistryStateMachine(statelang.StateMachine) error

RegistryStateMachineByReader(reader io.Reader) error
}

+ 22
- 0
pkg/saga/statemachine/engine/sequence/sequence.go View File

@@ -0,0 +1,22 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package sequence

type SeqGenerator interface {
GenerateId(entity string, ruleName string) string
}

+ 133
- 0
pkg/saga/statemachine/engine/sequence/snowflake.go View File

@@ -0,0 +1,133 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package sequence

import (
"fmt"
"sync"
"time"

"github.com/seata/seata-go/pkg/util/log"
)

// SnowflakeSeqGenerator snowflake gen ids
// ref: https://en.wikipedia.org/wiki/Snowflake_ID

var (
// set the beginning time
epoch = time.Date(2024, time.January, 01, 00, 00, 00, 00, time.UTC).UnixMilli()
)

const (
// timestamp occupancy bits
timestampBits = 41
// dataCenterId occupancy bits
dataCenterIdBits = 5
// workerId occupancy bits
workerIdBits = 5
// sequence occupancy bits
seqBits = 12

// timestamp max value, just like 2^41-1 = 2199023255551
timestampMaxValue = -1 ^ (-1 << timestampBits)
// dataCenterId max value, just like 2^5-1 = 31
dataCenterIdMaxValue = -1 ^ (-1 << dataCenterIdBits)
// workId max value, just like 2^5-1 = 31
workerIdMaxValue = -1 ^ (-1 << workerIdBits)
// sequence max value, just like 2^12-1 = 4095
seqMaxValue = -1 ^ (-1 << seqBits)

// number of workId offsets (seqBits)
workIdShift = 12
// number of dataCenterId offsets (seqBits + workerIdBits)
dataCenterIdShift = 17
// number of timestamp offsets (seqBits + workerIdBits + dataCenterIdBits)
timestampShift = 22

defaultInitValue = 0
)

type SnowflakeSeqGenerator struct {
mu *sync.Mutex
timestamp int64
dataCenterId int64
workerId int64
sequence int64
}

// NewSnowflakeSeqGenerator initiates the snowflake generator
func NewSnowflakeSeqGenerator(dataCenterId, workId int64) (r *SnowflakeSeqGenerator, err error) {
if dataCenterId < 0 || dataCenterId > dataCenterIdMaxValue {
err = fmt.Errorf("dataCenterId should between 0 and %d", dataCenterIdMaxValue-1)
return
}

if workId < 0 || workId > workerIdMaxValue {
err = fmt.Errorf("workId should between 0 and %d", dataCenterIdMaxValue-1)
return
}

return &SnowflakeSeqGenerator{
mu: new(sync.Mutex),
timestamp: defaultInitValue - 1,
dataCenterId: dataCenterId,
workerId: workId,
sequence: defaultInitValue,
}, nil
}

// GenerateId timestamp + dataCenterId + workId + sequence
func (S *SnowflakeSeqGenerator) GenerateId(entity string, ruleName string) string {
S.mu.Lock()
defer S.mu.Unlock()

now := time.Now().UnixMilli()

if S.timestamp > now { // Clock callback
log.Errorf("Clock moved backwards. Refusing to generate ID, last timestamp is %d, now is %d", S.timestamp, now)
return ""
}

if S.timestamp == now {
// generate multiple IDs in the same millisecond, incrementing the sequence number to prevent conflicts
S.sequence = (S.sequence + 1) & seqMaxValue
if S.sequence == 0 {
// sequence overflow, waiting for next millisecond
for now <= S.timestamp {
now = time.Now().UnixMilli()
}
}
} else {
// initialized sequences are used directly at different millisecond timestamps
S.sequence = defaultInitValue
}
tmp := now - epoch
if tmp > timestampMaxValue {
log.Errorf("epoch should between 0 and %d", timestampMaxValue-1)
return ""
}
S.timestamp = now

// combine the parts to generate the final ID and convert the 64-bit binary to decimal digits.
r := (tmp)<<timestampShift |
(S.dataCenterId << dataCenterIdShift) |
(S.workerId << workIdShift) |
(S.sequence)

return fmt.Sprintf("%d", r)
}

+ 45
- 0
pkg/saga/statemachine/engine/sequence/snowflake_test.go View File

@@ -0,0 +1,45 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package sequence

import (
"strconv"
"testing"
)

func TestSnowflakeSeqGenerator_GenerateId(t *testing.T) {
var dataCenterId, workId int64 = 1, 1
generator, err := NewSnowflakeSeqGenerator(dataCenterId, workId)
if err != nil {
t.Error(err)
return
}
var x, y string
for i := 0; i < 100; i++ {
y = generator.GenerateId("", "")
if x == y {
t.Errorf("x(%s) & y(%s) are the same", x, y)
}
x = y
}
}

func TestEpoch(t *testing.T) {
t.Log(epoch)
t.Log(len(strconv.FormatInt(epoch, 10)))
}

+ 31
- 0
pkg/saga/statemachine/engine/sequence/uuid.go View File

@@ -0,0 +1,31 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package sequence

import "github.com/google/uuid"

type UUIDSeqGenerator struct {
}

func NewUUIDSeqGenerator() *UUIDSeqGenerator {
return &UUIDSeqGenerator{}
}

func (U UUIDSeqGenerator) GenerateId(entity string, ruleName string) string {
return uuid.New().String()
}

+ 59
- 0
pkg/saga/statemachine/engine/serializer/serializer.go View File

@@ -0,0 +1,59 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package serializer

import (
"bytes"
"encoding/gob"
"encoding/json"
"github.com/pkg/errors"
)

type ParamsSerializer struct{}

func (ParamsSerializer) Serialize(object any) (string, error) {
result, err := json.Marshal(object)
return string(result), err
}

func (ParamsSerializer) Deserialize(object string) (any, error) {
var result any
err := json.Unmarshal([]byte(object), &result)
return result, err
}

type ErrorSerializer struct{}

func (ErrorSerializer) Serialize(object error) ([]byte, error) {
var buffer bytes.Buffer
encoder := gob.NewEncoder(&buffer)
if object != nil {
err := encoder.Encode(object.Error())
return buffer.Bytes(), err
}
return nil, nil
}

func (ErrorSerializer) Deserialize(object []byte) (error, error) {
var errorMsg string
buffer := bytes.NewReader(object)
encoder := gob.NewDecoder(buffer)
err := encoder.Decode(&errorMsg)

return errors.New(errorMsg), err
}

+ 34
- 0
pkg/saga/statemachine/engine/serializer/serializer_test.go View File

@@ -0,0 +1,34 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package serializer

import (
"github.com/pkg/errors"
"github.com/stretchr/testify/assert"
"testing"
)

func TestErrorSerializer(t *testing.T) {
serializer := ErrorSerializer{}
expected := errors.New("This is a test error")
serialized, err := serializer.Serialize(expected)
assert.Nil(t, err)
actual, err := serializer.Deserialize(serialized)
assert.Nil(t, err)
assert.Equal(t, expected.Error(), actual.Error())
}

+ 77
- 0
pkg/saga/statemachine/engine/statemachine_config.go View File

@@ -0,0 +1,77 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package engine

import (
"github.com/seata/seata-go/pkg/saga/statemachine"
"github.com/seata/seata-go/pkg/saga/statemachine/engine/expr"
"github.com/seata/seata-go/pkg/saga/statemachine/engine/invoker"
"github.com/seata/seata-go/pkg/saga/statemachine/engine/repo"
"github.com/seata/seata-go/pkg/saga/statemachine/engine/sequence"
"github.com/seata/seata-go/pkg/saga/statemachine/process_ctrl"
"github.com/seata/seata-go/pkg/saga/statemachine/store"
"sync"
)

type StateMachineConfig interface {
StateLogRepository() repo.StateLogRepository

StateMachineRepository() repo.StateMachineRepository

StateLogStore() store.StateLogStore

StateLangStore() store.StateLangStore

ExpressionFactoryManager() *expr.ExpressionFactoryManager

ExpressionResolver() expr.ExpressionResolver

SeqGenerator() sequence.SeqGenerator

StatusDecisionStrategy() StatusDecisionStrategy

EventPublisher() process_ctrl.EventPublisher

AsyncEventPublisher() process_ctrl.EventPublisher

ServiceInvokerManager() invoker.ServiceInvokerManager

ScriptInvokerManager() invoker.ScriptInvokerManager

CharSet() string

GetDefaultTenantId() string

GetTransOperationTimeout() int

GetServiceInvokeTimeout() int

ComponentLock() *sync.Mutex

RegisterStateMachineDef(resources []string) error

RegisterExpressionFactory(expressionType string, factory expr.ExpressionFactory)

RegisterServiceInvoker(serviceType string, invoker invoker.ServiceInvoker)

GetStateMachineDefinition(name string) *statemachine.StateMachineObject

GetExpressionFactory(expressionType string) expr.ExpressionFactory

GetServiceInvoker(serviceType string) invoker.ServiceInvoker
}

+ 59
- 0
pkg/saga/statemachine/engine/statemachine_engine.go View File

@@ -0,0 +1,59 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package engine

import (
"context"
"github.com/seata/seata-go/pkg/saga/statemachine/process_ctrl"
"github.com/seata/seata-go/pkg/saga/statemachine/statelang"
)

type StateMachineEngine interface {
// Start starts a state machine instance
Start(ctx context.Context, stateMachineName string, tenantId string, startParams map[string]interface{}) (statelang.StateMachineInstance, error)
// StartAsync start a state machine instance asynchronously
StartAsync(ctx context.Context, stateMachineName string, tenantId string, startParams map[string]interface{},
callback CallBack) (statelang.StateMachineInstance, error)
// StartWithBusinessKey starts a state machine instance with a business key
StartWithBusinessKey(ctx context.Context, stateMachineName string, tenantId string, businessKey string,
startParams map[string]interface{}) (statelang.StateMachineInstance, error)
// StartWithBusinessKeyAsync starts a state machine instance with a business key asynchronously
StartWithBusinessKeyAsync(ctx context.Context, stateMachineName string, tenantId string, businessKey string,
startParams map[string]interface{}, callback CallBack) (statelang.StateMachineInstance, error)
// Forward restart a failed state machine instance
Forward(ctx context.Context, stateMachineInstId string, replaceParams map[string]interface{}) (statelang.StateMachineInstance, error)
// ForwardAsync restart a failed state machine instance asynchronously
ForwardAsync(ctx context.Context, stateMachineInstId string, replaceParams map[string]interface{}, callback CallBack) (statelang.StateMachineInstance, error)
// Compensate compensate a state machine instance
Compensate(ctx context.Context, stateMachineInstId string, replaceParams map[string]interface{}) (statelang.StateMachineInstance, error)
// CompensateAsync compensate a state machine instance asynchronously
CompensateAsync(ctx context.Context, stateMachineInstId string, replaceParams map[string]interface{}, callback CallBack) (statelang.StateMachineInstance, error)
// SkipAndForward skips the current failed state instance and restarts the state machine instance
SkipAndForward(ctx context.Context, stateMachineInstId string, replaceParams map[string]interface{}) (statelang.StateMachineInstance, error)
// SkipAndForwardAsync skips the current failed state instance and restarts the state machine instance asynchronously
SkipAndForwardAsync(ctx context.Context, stateMachineInstId string, callback CallBack) (statelang.StateMachineInstance, error)
// GetStateMachineConfig gets the state machine configurations
GetStateMachineConfig() StateMachineConfig
// ReloadStateMachineInstance reloads a state machine instance
ReloadStateMachineInstance(ctx context.Context, instId string) (statelang.StateMachineInstance, error)
}

type CallBack interface {
OnFinished(ctx context.Context, context process_ctrl.ProcessContext, stateMachineInstance statelang.StateMachineInstance)
OnError(ctx context.Context, context process_ctrl.ProcessContext, stateMachineInstance statelang.StateMachineInstance, err error)
}

+ 33
- 0
pkg/saga/statemachine/engine/statemachine_engine_test.go View File

@@ -0,0 +1,33 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package engine

import (
"context"
"github.com/seata/seata-go/pkg/saga/statemachine/engine/core"
"testing"
)

func TestEngine(t *testing.T) {

}

func TestSimpleStateMachine(t *testing.T) {
engine := core.NewProcessCtrlStateMachineEngine()
engine.Start(context.Background(), "simpleStateMachine", "tenantId", nil)
}

+ 19
- 0
pkg/saga/statemachine/engine/strategy.go View File

@@ -0,0 +1,19 @@
package engine

import (
"context"
"github.com/seata/seata-go/pkg/saga/statemachine/process_ctrl"
"github.com/seata/seata-go/pkg/saga/statemachine/statelang"
)

type StatusDecisionStrategy interface {
// DecideOnEndState Determine state machine execution status when executing to EndState
DecideOnEndState(ctx context.Context, processContext process_ctrl.ProcessContext,
stateMachineInstance statelang.StateMachineInstance, exp error) error
// DecideOnTaskStateFail Determine state machine execution status when executing TaskState error
DecideOnTaskStateFail(ctx context.Context, processContext process_ctrl.ProcessContext,
stateMachineInstance statelang.StateMachineInstance, exp error) error
// DecideMachineForwardExecutionStatus Determine the forward execution state of the state machine
DecideMachineForwardExecutionStatus(ctx context.Context,
stateMachineInstance statelang.StateMachineInstance, exp error, specialPolicy bool) error
}

+ 246
- 0
pkg/saga/statemachine/engine/strategy/status_decision.go View File

@@ -0,0 +1,246 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package strategy

import (
"context"
"errors"
"github.com/seata/seata-go/pkg/saga/statemachine/constant"
"github.com/seata/seata-go/pkg/saga/statemachine/engine/exception"
"github.com/seata/seata-go/pkg/saga/statemachine/engine/pcext"
"github.com/seata/seata-go/pkg/saga/statemachine/process_ctrl"
"github.com/seata/seata-go/pkg/saga/statemachine/statelang"
"github.com/seata/seata-go/pkg/util/log"
)

type DefaultStatusDecisionStrategy struct {
}

func NewDefaultStatusDecisionStrategy() *DefaultStatusDecisionStrategy {
return &DefaultStatusDecisionStrategy{}
}

func (d DefaultStatusDecisionStrategy) DecideOnEndState(ctx context.Context, processContext process_ctrl.ProcessContext,
stateMachineInstance statelang.StateMachineInstance, exp error) error {
if statelang.RU == stateMachineInstance.CompensationStatus() {
compensationHolder := pcext.GetCurrentCompensationHolder(ctx, processContext, true)
if err := decideMachineCompensateStatus(ctx, stateMachineInstance, compensationHolder); err != nil {
return err
}
} else {
failEndStateFlag, ok := processContext.GetVariable(constant.VarNameFailEndStateFlag).(bool)
if !ok {
failEndStateFlag = false
}
if _, err := decideMachineForwardExecutionStatus(ctx, stateMachineInstance, exp, failEndStateFlag); err != nil {
return err
}
}

if stateMachineInstance.CompensationStatus() != "" && constant.OperationNameForward ==
processContext.GetVariable(constant.VarNameOperationName).(string) && statelang.SU == stateMachineInstance.Status() {
stateMachineInstance.SetCompensationStatus(statelang.FA)
}

log.Debugf("StateMachine Instance[id:%s,name:%s] execute finish with status[%s], compensation status [%s].",
stateMachineInstance.ID(), stateMachineInstance.StateMachine().Name(),
stateMachineInstance.Status(), stateMachineInstance.CompensationStatus())

return nil
}

func decideMachineCompensateStatus(ctx context.Context, stateMachineInstance statelang.StateMachineInstance, compensationHolder *pcext.CompensationHolder) error {
if stateMachineInstance.Status() == "" || statelang.RU == stateMachineInstance.Status() {
stateMachineInstance.SetStatus(statelang.UN)
}
if !compensationHolder.StateStackNeedCompensation().Empty() {
hasCompensateSUorUN := false
compensationHolder.StatesForCompensation().Range(
func(key, value any) bool {
stateInstance, ok := value.(statelang.StateInstance)
if !ok {
return false
}
if statelang.UN == stateInstance.Status() || statelang.SU == stateInstance.Status() {
hasCompensateSUorUN = true
return true
}
return false
})

if hasCompensateSUorUN {
stateMachineInstance.SetCompensationStatus(statelang.UN)
} else {
stateMachineInstance.SetCompensationStatus(statelang.FA)
}
} else {
hasCompensateError := false
compensationHolder.StatesForCompensation().Range(
func(key, value any) bool {
stateInstance, ok := value.(statelang.StateInstance)
if !ok {
return false
}
if statelang.SU != stateInstance.Status() {
hasCompensateError = true
return true
}
return false
})

if hasCompensateError {
stateMachineInstance.SetCompensationStatus(statelang.UN)
} else {
stateMachineInstance.SetCompensationStatus(statelang.SU)
}
}
return nil
}

func decideMachineForwardExecutionStatus(ctx context.Context, stateMachineInstance statelang.StateMachineInstance, exp error, specialPolicy bool) (bool, error) {
result := false

if stateMachineInstance.Status() == "" || statelang.RU == stateMachineInstance.Status() {
result = true
stateList := stateMachineInstance.StateList()
//Determine the final state of the entire state machine based on the state of each StateInstance
setMachineStatusBasedOnStateListAndException(stateMachineInstance, stateList, exp)

if specialPolicy && statelang.SU == stateMachineInstance.Status() {
for _, stateInstance := range stateMachineInstance.StateList() {
if !stateInstance.IsIgnoreStatus() && (stateInstance.IsForUpdate() || stateInstance.IsForCompensation()) {
stateMachineInstance.SetStatus(statelang.UN)
break
}
}
if statelang.SU == stateMachineInstance.Status() {
stateMachineInstance.SetStatus(statelang.FA)
}
}
}
return result, nil
}

func setMachineStatusBasedOnStateListAndException(stateMachineInstance statelang.StateMachineInstance,
stateList []statelang.StateInstance, exp error) {
hasSetStatus := false
hasSuccessUpdateService := false
if stateList != nil && len(stateList) > 0 {
hasUnsuccessService := false

for i := len(stateList) - 1; i >= 0; i-- {
stateInstance := stateList[i]

if stateInstance.IsIgnoreStatus() || stateInstance.IsForCompensation() {
continue
}
if statelang.UN == stateInstance.Status() {
stateMachineInstance.SetStatus(statelang.UN)
hasSetStatus = true
} else if statelang.SU == stateInstance.Status() {
if constant.StateTypeServiceTask == stateInstance.Type() {
if stateInstance.IsForUpdate() && !stateInstance.IsForCompensation() {
hasSuccessUpdateService = true
}
}
} else if statelang.SK == stateInstance.Status() {
// ignore
} else {
hasUnsuccessService = true
}
}

if !hasSetStatus && hasUnsuccessService {
if hasSuccessUpdateService {
stateMachineInstance.SetStatus(statelang.UN)
} else {
stateMachineInstance.SetStatus(statelang.FA)
}
hasSetStatus = true
}
}

if !hasSetStatus {
setMachineStatusBasedOnException(stateMachineInstance, exp, hasSuccessUpdateService)
}
}

func setMachineStatusBasedOnException(stateMachineInstance statelang.StateMachineInstance, exp error, hasSuccessUpdateService bool) {

if exp == nil {
log.Debugf("No error found, setting StateMachineInstance[id:%s] status to SU", stateMachineInstance.ID())
stateMachineInstance.SetStatus(statelang.SU)
return
}

var engineExp *exception.EngineExecutionException
if errors.As(exp, &engineExp) && engineExp.ErrCode == constant.FrameworkErrorCodeStateMachineExecutionTimeout {
log.Warnf("Execution timeout detected, setting StateMachineInstance[id:%s] status to UN", stateMachineInstance.ID())
stateMachineInstance.SetStatus(statelang.UN)
return
}

if hasSuccessUpdateService {
log.Infof("Has successful update service, setting StateMachineInstance[id:%s] status to UN", stateMachineInstance.ID())
stateMachineInstance.SetStatus(statelang.UN)
return
}

netType := pcext.GetNetExceptionType(exp)
switch netType {
case constant.ConnectException, constant.ConnectTimeoutException, constant.NotNetException:
log.Warnf("Detected network connect issue, setting StateMachineInstance[id:%s] status to FA", stateMachineInstance.ID())
stateMachineInstance.SetStatus(statelang.FA)
case constant.ReadTimeoutException:
log.Warnf("Detected read timeout, setting StateMachineInstance[id:%s] status to UN", stateMachineInstance.ID())
stateMachineInstance.SetStatus(statelang.UN)
default:
//Default failure
log.Errorf("Unknown exception type, setting StateMachineInstance[id:%s] status to FA", stateMachineInstance.ID())
stateMachineInstance.SetStatus(statelang.FA)

}
}

func (d DefaultStatusDecisionStrategy) DecideOnTaskStateFail(ctx context.Context, processContext process_ctrl.ProcessContext,
stateMachineInstance statelang.StateMachineInstance, exp error) error {

log.Debugf("Starting DecideOnTaskStateFail for StateMachineInstance[id:%s]", stateMachineInstance.ID())
result, err := decideMachineForwardExecutionStatus(ctx, stateMachineInstance, exp, true)
if err != nil {
log.Errorf("DecideMachineForwardExecutionStatus failed: %v", err)
return err
}

if !result {
log.Warnf("Forward execution result is false, setting compensation status UN for StateMachineInstance[id:%s]", stateMachineInstance.ID())
stateMachineInstance.SetCompensationStatus(statelang.UN)
}
return nil
}

func (d DefaultStatusDecisionStrategy) DecideMachineForwardExecutionStatus(ctx context.Context,
stateMachineInstance statelang.StateMachineInstance, exp error, specialPolicy bool) error {

log.Debugf("Starting DecideMachineForwardExecutionStatus for StateMachineInstance[id:%s], specialPolicy: %v", stateMachineInstance.ID(), specialPolicy)
_, err := decideMachineForwardExecutionStatus(ctx, stateMachineInstance, exp, specialPolicy)
if err != nil {
log.Errorf("DecideMachineForwardExecutionStatus failed: %v", err)
}
return err
}

+ 113
- 0
pkg/saga/statemachine/engine/utils/process_context_utils.go View File

@@ -0,0 +1,113 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package utils

import (
"github.com/seata/seata-go/pkg/saga/statemachine/constant"
"github.com/seata/seata-go/pkg/saga/statemachine/engine"
"github.com/seata/seata-go/pkg/saga/statemachine/process_ctrl"
"github.com/seata/seata-go/pkg/saga/statemachine/process_ctrl/process"
"github.com/seata/seata-go/pkg/saga/statemachine/statelang"
)

// ProcessContextBuilder process_ctrl builder
type ProcessContextBuilder struct {
processContext process_ctrl.ProcessContext
}

func NewProcessContextBuilder() *ProcessContextBuilder {
processContextImpl := process_ctrl.NewProcessContextImpl()
return &ProcessContextBuilder{processContextImpl}
}

func (p *ProcessContextBuilder) WithProcessType(processType process.ProcessType) *ProcessContextBuilder {
p.processContext.SetVariable(constant.VarNameProcessType, processType)
return p
}

func (p *ProcessContextBuilder) WithOperationName(operationName string) *ProcessContextBuilder {
p.processContext.SetVariable(constant.VarNameOperationName, operationName)
return p
}

func (p *ProcessContextBuilder) WithAsyncCallback(callBack engine.CallBack) *ProcessContextBuilder {
if callBack != nil {
p.processContext.SetVariable(constant.VarNameAsyncCallback, callBack)
}

return p
}

func (p *ProcessContextBuilder) WithInstruction(instruction process_ctrl.Instruction) *ProcessContextBuilder {
if instruction != nil {
p.processContext.SetInstruction(instruction)
}

return p
}

func (p *ProcessContextBuilder) WithStateMachineInstance(stateMachineInstance statelang.StateMachineInstance) *ProcessContextBuilder {
if stateMachineInstance != nil {
p.processContext.SetVariable(constant.VarNameStateMachineInst, stateMachineInstance)
p.processContext.SetVariable(constant.VarNameStateMachine, stateMachineInstance.StateMachine())
}

return p
}

func (p *ProcessContextBuilder) WithStateMachineEngine(stateMachineEngine engine.StateMachineEngine) *ProcessContextBuilder {
if stateMachineEngine != nil {
p.processContext.SetVariable(constant.VarNameStateMachineEngine, stateMachineEngine)
}

return p
}

func (p *ProcessContextBuilder) WithStateMachineConfig(stateMachineConfig engine.StateMachineConfig) *ProcessContextBuilder {
if stateMachineConfig != nil {
p.processContext.SetVariable(constant.VarNameStateMachineConfig, stateMachineConfig)
}

return p
}

func (p *ProcessContextBuilder) WithStateMachineContextVariables(contextMap map[string]interface{}) *ProcessContextBuilder {
if contextMap != nil {
p.processContext.SetVariable(constant.VarNameStateMachineContext, contextMap)
}

return p
}

func (p *ProcessContextBuilder) WithIsAsyncExecution(async bool) *ProcessContextBuilder {
p.processContext.SetVariable(constant.VarNameIsAsyncExecution, async)

return p
}

func (p *ProcessContextBuilder) WithStateInstance(state statelang.StateInstance) *ProcessContextBuilder {
if state != nil {
p.processContext.SetVariable(constant.VarNameStateInst, state)
}

return p
}

func (p *ProcessContextBuilder) Build() process_ctrl.ProcessContext {
return p.processContext
}

+ 109
- 0
pkg/saga/statemachine/process_ctrl/bussiness_processor.go View File

@@ -0,0 +1,109 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package process_ctrl

import (
"context"
"github.com/pkg/errors"
"github.com/seata/seata-go/pkg/saga/statemachine/constant"
"github.com/seata/seata-go/pkg/saga/statemachine/process_ctrl/process"
"sync"
)

type BusinessProcessor interface {
Process(ctx context.Context, processContext ProcessContext) error

Route(ctx context.Context, processContext ProcessContext) error
}

func NewBusinessProcessor() BusinessProcessor {
return &DefaultBusinessProcessor{
processHandlers: make(map[string]ProcessHandler),
routerHandlers: make(map[string]RouterHandler),
}
}

type DefaultBusinessProcessor struct {
processHandlers map[string]ProcessHandler
routerHandlers map[string]RouterHandler
mu sync.RWMutex
}

func (d *DefaultBusinessProcessor) RegistryProcessHandler(processType process.ProcessType, processHandler ProcessHandler) {
d.mu.Lock()
defer d.mu.Unlock()

d.processHandlers[string(processType)] = processHandler
}

func (d *DefaultBusinessProcessor) RegistryRouterHandler(processType process.ProcessType, routerHandler RouterHandler) {
d.mu.Lock()
defer d.mu.Unlock()

d.routerHandlers[string(processType)] = routerHandler
}

func (d *DefaultBusinessProcessor) Process(ctx context.Context, processContext ProcessContext) error {
processType := d.matchProcessType(processContext)

processHandler, err := d.getProcessHandler(processType)
if err != nil {
return err
}

return processHandler.Process(ctx, processContext)
}

func (d *DefaultBusinessProcessor) Route(ctx context.Context, processContext ProcessContext) error {
processType := d.matchProcessType(processContext)

routerHandler, err := d.getRouterHandler(processType)
if err != nil {
return err
}

return routerHandler.Route(ctx, processContext)
}

func (d *DefaultBusinessProcessor) getProcessHandler(processType process.ProcessType) (ProcessHandler, error) {
d.mu.RLock()
defer d.mu.RUnlock()
processHandler, ok := d.processHandlers[string(processType)]
if !ok {
return nil, errors.New("Cannot find Process handler by type " + string(processType))
}
return processHandler, nil
}

func (d *DefaultBusinessProcessor) getRouterHandler(processType process.ProcessType) (RouterHandler, error) {
d.mu.RLock()
defer d.mu.RUnlock()
routerHandler, ok := d.routerHandlers[string(processType)]
if !ok {
return nil, errors.New("Cannot find router handler by type " + string(processType))
}
return routerHandler, nil
}

func (d *DefaultBusinessProcessor) matchProcessType(processContext ProcessContext) process.ProcessType {
ok := processContext.HasVariable(constant.VarNameProcessType)
if ok {
return processContext.GetVariable(constant.VarNameProcessType).(process.ProcessType)
}
return process.StateLang
}

+ 7
- 0
pkg/saga/statemachine/process_ctrl/default_process_handler.go View File

@@ -0,0 +1,7 @@
package process_ctrl

import "context"

type ProcessHandler interface {
Process(ctx context.Context, processContext ProcessContext) error
}

+ 21
- 0
pkg/saga/statemachine/process_ctrl/event.go View File

@@ -0,0 +1,21 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package process_ctrl

type Event interface {
}

+ 146
- 0
pkg/saga/statemachine/process_ctrl/event_bus.go View File

@@ -0,0 +1,146 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package process_ctrl

import (
"context"
"fmt"
"github.com/pkg/errors"
"github.com/seata/seata-go/pkg/saga/statemachine/constant"
"github.com/seata/seata-go/pkg/util/collection"
"github.com/seata/seata-go/pkg/util/log"
)

type EventBus interface {
Offer(ctx context.Context, event Event) (bool, error)

EventConsumerList(event Event) []EventConsumer

RegisterEventConsumer(consumer EventConsumer)
}

type BaseEventBus struct {
eventConsumerList []EventConsumer
}

func (b *BaseEventBus) RegisterEventConsumer(consumer EventConsumer) {
if b.eventConsumerList == nil {
b.eventConsumerList = make([]EventConsumer, 0)
}
b.eventConsumerList = append(b.eventConsumerList, consumer)
}

func (b *BaseEventBus) EventConsumerList(event Event) []EventConsumer {
var acceptedConsumerList = make([]EventConsumer, 0)
for i := range b.eventConsumerList {
eventConsumer := b.eventConsumerList[i]
if eventConsumer.Accept(event) {
acceptedConsumerList = append(acceptedConsumerList, eventConsumer)
}
}
return acceptedConsumerList
}

type DirectEventBus struct {
BaseEventBus
}

func (d DirectEventBus) Offer(ctx context.Context, event Event) (bool, error) {
eventConsumerList := d.EventConsumerList(event)
if len(eventConsumerList) == 0 {
log.Debugf("cannot find event handler by type: %T", event)
return false, nil
}

isFirstEvent := true
processContext, ok := event.(ProcessContext)
if !ok {
log.Errorf("event %T is illegal, required process_ctrl.ProcessContext", event)
return false, nil
}

stack := processContext.GetVariable(constant.VarNameSyncExeStack).(*collection.Stack)
if stack == nil {
stack = collection.NewStack()
processContext.SetVariable(constant.VarNameSyncExeStack, stack)
isFirstEvent = true
}

stack.Push(processContext)
if isFirstEvent {
for stack.Len() > 0 {
currentContext := stack.Pop().(ProcessContext)
for _, eventConsumer := range eventConsumerList {
err := eventConsumer.Process(ctx, currentContext)
if err != nil {
log.Errorf("process event %T error: %s", event, err.Error())
return false, err
}
}
}
}

return true, nil
}

type AsyncEventBus struct {
BaseEventBus
}

func (a AsyncEventBus) Offer(ctx context.Context, event Event) (bool, error) {
eventConsumerList := a.EventConsumerList(event)
if len(eventConsumerList) == 0 {
errStr := fmt.Sprintf("cannot find event handler by type: %T", event)
log.Errorf(errStr)
return false, errors.New(errStr)
}

processContext, ok := event.(ProcessContext)
if !ok {
errStr := fmt.Sprintf("event %T is illegal, required process_ctrl.ProcessContext", event)
log.Errorf(errStr)
return false, errors.New(errStr)
}

for _, eventConsumer := range eventConsumerList {
go func() {
err := eventConsumer.Process(ctx, processContext)
if err != nil {
log.Errorf("process event %T error: %s", event, err.Error())
}
}()
}

return true, nil
}

func NewDirectEventBus() *DirectEventBus {
return &DirectEventBus{
BaseEventBus: BaseEventBus{
eventConsumerList: make([]EventConsumer, 0),
},
}
}

func NewAsyncEventBus(ctx context.Context, queueSize int, workerCount int) *AsyncEventBus {
return &AsyncEventBus{
BaseEventBus: BaseEventBus{
eventConsumerList: make([]EventConsumer, 0),
},
}
}

+ 56
- 0
pkg/saga/statemachine/process_ctrl/event_consumer.go View File

@@ -0,0 +1,56 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package process_ctrl

import (
"context"
"fmt"
)

type EventConsumer interface {
Accept(event Event) bool

Process(ctx context.Context, event Event) error
}

type ProcessCtrlEventConsumer struct {
processController ProcessController
}

func (p ProcessCtrlEventConsumer) Accept(event Event) bool {
if event == nil {
return false
}

_, ok := event.(ProcessContext)
return ok
}

func (p ProcessCtrlEventConsumer) Process(ctx context.Context, event Event) error {
processContext, ok := event.(ProcessContext)
if !ok {
return fmt.Errorf("event %T is illegal, required process_ctrl.ProcessContext", event)
}
return p.processController.Process(ctx, processContext)
}

func NewProcessCtrlEventConsumer(controller ProcessController) EventConsumer {
return &ProcessCtrlEventConsumer{
processController: controller,
}
}

+ 36
- 0
pkg/saga/statemachine/process_ctrl/event_publisher.go View File

@@ -0,0 +1,36 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package process_ctrl

import "context"

type EventPublisher interface {
PushEvent(ctx context.Context, event Event) (bool, error)
}

type ProcessCtrlEventPublisher struct {
eventBus EventBus
}

func NewProcessCtrlEventPublisher(eventBus EventBus) *ProcessCtrlEventPublisher {
return &ProcessCtrlEventPublisher{eventBus: eventBus}
}

func (p ProcessCtrlEventPublisher) PushEvent(ctx context.Context, event Event) (bool, error) {
return p.eventBus.Offer(ctx, event)
}

+ 193
- 0
pkg/saga/statemachine/process_ctrl/handlers/service_task_state_handler.go View File

@@ -0,0 +1,193 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package handlers

import (
"context"
"errors"
"github.com/seata/seata-go/pkg/saga/statemachine/constant"
"github.com/seata/seata-go/pkg/saga/statemachine/engine"
"github.com/seata/seata-go/pkg/saga/statemachine/engine/exception"
"github.com/seata/seata-go/pkg/saga/statemachine/engine/pcext"
"github.com/seata/seata-go/pkg/saga/statemachine/process_ctrl"
"github.com/seata/seata-go/pkg/saga/statemachine/statelang"
"github.com/seata/seata-go/pkg/saga/statemachine/statelang/state"
seataErrors "github.com/seata/seata-go/pkg/util/errors"
"github.com/seata/seata-go/pkg/util/log"
)

type ServiceTaskStateHandler struct {
interceptors []pcext.StateHandlerInterceptor
}

func NewServiceTaskStateHandler() *ServiceTaskStateHandler {
return &ServiceTaskStateHandler{}
}

func (s *ServiceTaskStateHandler) State() string {
return constant.StateTypeServiceTask
}

func (s *ServiceTaskStateHandler) Process(ctx context.Context, processContext process_ctrl.ProcessContext) error {
stateInstruction, ok := processContext.GetInstruction().(pcext.StateInstruction)
if !ok {
return errors.New("invalid state instruction from processContext")
}
stateInterface, err := stateInstruction.GetState(processContext)
if err != nil {
return err
}
serviceTaskStateImpl, ok := stateInterface.(*state.ServiceTaskStateImpl)

serviceName := serviceTaskStateImpl.ServiceName()
methodName := serviceTaskStateImpl.ServiceMethod()
stateInstance, ok := processContext.GetVariable(constant.VarNameStateInst).(statelang.StateInstance)
if !ok {
return errors.New("invalid state instance type from processContext")
}

// invoke service task and record
var result any
var resultErr error
handleResultErr := func(err error) {
log.Error("<<<<<<<<<<<<<<<<<<<<<< State[%s], ServiceName[%s], Method[%s] Execute failed.",
serviceTaskStateImpl.Name(), serviceName, methodName, err)

hierarchicalProcessContext, ok := processContext.(process_ctrl.HierarchicalProcessContext)
if !ok {
return
}
hierarchicalProcessContext.SetVariable(constant.VarNameCurrentException, err)
pcext.HandleException(processContext, serviceTaskStateImpl.AbstractTaskState, err)
}

input, ok := processContext.GetVariable(constant.VarNameInputParams).([]any)
if !ok {
handleResultErr(errors.New("invalid input params type from processContext"))
return nil
}

stateInstance.SetStatus(statelang.RU)
log.Debugf(">>>>>>>>>>>>>>>>>>>>>> Start to execute State[%s], ServiceName[%s], Method[%s], Input:%s",
serviceTaskStateImpl.Name(), serviceName, methodName, input)

if _, ok := stateInterface.(state.CompensateSubStateMachineState); ok {
// If it is the compensation of the subState machine,
// directly call the state machine's compensate method
stateMachineEngine, ok := processContext.GetVariable(constant.VarNameStateMachineEngine).(engine.StateMachineEngine)
if !ok {
handleResultErr(errors.New("invalid stateMachineEngine type from processContext"))
return nil
}

result, resultErr = s.compensateSubStateMachine(ctx, processContext, serviceTaskStateImpl, input,
stateInstance, stateMachineEngine)
if resultErr != nil {
handleResultErr(resultErr)
return nil
}
} else {
stateMachineConfig, ok := processContext.GetVariable(constant.VarNameStateMachineConfig).(engine.StateMachineConfig)
if !ok {
handleResultErr(errors.New("invalid stateMachineConfig type from processContext"))
return nil
}

serviceInvoker := stateMachineConfig.ServiceInvokerManager().ServiceInvoker(serviceTaskStateImpl.ServiceType())
if serviceInvoker == nil {
resultErr = exception.NewEngineExecutionException(seataErrors.ObjectNotExists,
"No such ServiceInvoker["+serviceTaskStateImpl.ServiceType()+"]", nil)
handleResultErr(resultErr)
return nil
}

result, resultErr = serviceInvoker.Invoke(ctx, input, serviceTaskStateImpl)
if resultErr != nil {
handleResultErr(resultErr)
return nil
}
}

log.Debugf("<<<<<<<<<<<<<<<<<<<<<< State[%s], ServiceName[%s], Method[%s] Execute finish. result: %s",
serviceTaskStateImpl.Name(), serviceName, methodName, result)

if result != nil {
stateInstance.SetOutputParams(result)
hierarchicalProcessContext, ok := processContext.(process_ctrl.HierarchicalProcessContext)
if !ok {
handleResultErr(errors.New("invalid hierarchical process context type from processContext"))
return nil
}

hierarchicalProcessContext.SetVariable(constant.VarNameOutputParams, result)
}

return nil
}

func (s *ServiceTaskStateHandler) StateHandlerInterceptorList() []pcext.StateHandlerInterceptor {
return s.interceptors
}

func (s *ServiceTaskStateHandler) RegistryStateHandlerInterceptor(stateHandlerInterceptor pcext.StateHandlerInterceptor) {
s.interceptors = append(s.interceptors, stateHandlerInterceptor)
}

func (s *ServiceTaskStateHandler) compensateSubStateMachine(ctx context.Context, processContext process_ctrl.ProcessContext,
serviceTaskState state.ServiceTaskState, input any, instance statelang.StateInstance,
machineEngine engine.StateMachineEngine) (any, error) {
subStateMachineParentId, ok := processContext.GetVariable(serviceTaskState.Name() + constant.VarNameSubMachineParentId).(string)
if !ok {
return nil, errors.New("invalid subStateMachineParentId type from processContext")
}

if subStateMachineParentId == "" {
return nil, exception.NewEngineExecutionException(seataErrors.ObjectNotExists,
"sub statemachine parentId is required", nil)
}

stateMachineConfig := processContext.GetVariable(constant.VarNameStateMachineConfig).(engine.StateMachineConfig)
subInst, err := stateMachineConfig.StateLogStore().GetStateMachineInstanceByParentId(subStateMachineParentId)
if err != nil {
return nil, err
}

if subInst == nil || len(subInst) == 0 {
return nil, exception.NewEngineExecutionException(seataErrors.ObjectNotExists,
"cannot find sub statemachine instance by parentId:"+subStateMachineParentId, nil)
}

subStateMachineInstId := subInst[0].ID()
log.Debugf(">>>>>>>>>>>>>>>>>>>>>> Start to compensate sub statemachine [id:%s]", subStateMachineInstId)

startParams := make(map[string]any)

if inputList, ok := input.([]any); ok {
if len(inputList) > 0 {
startParams = inputList[0].(map[string]any)
}
} else if inputMap, ok := input.(map[string]any); ok {
startParams = inputMap
}

compensateInst, err := machineEngine.Compensate(ctx, subStateMachineInstId, startParams)
instance.SetStatus(compensateInst.CompensationStatus())
log.Debugf("<<<<<<<<<<<<<<<<<<<<<< Compensate sub statemachine [id:%s] finished with status[%s], "+"compensateState[%s]",
subStateMachineInstId, compensateInst.Status(), compensateInst.CompensationStatus())
return compensateInst.EndParams(), nil
}

+ 21
- 0
pkg/saga/statemachine/process_ctrl/instruction.go View File

@@ -0,0 +1,21 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package process_ctrl

type Instruction interface {
}

+ 24
- 0
pkg/saga/statemachine/process_ctrl/process/process_type.go View File

@@ -0,0 +1,24 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package process

type ProcessType string

const (
StateLang ProcessType = "STATE_LANG" // SEATA State Language
)

+ 225
- 0
pkg/saga/statemachine/process_ctrl/process_context.go View File

@@ -0,0 +1,225 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package process_ctrl

import (
"sync"
)

type ProcessContext interface {
GetVariable(name string) interface{}

SetVariable(name string, value interface{})

GetVariables() map[string]interface{}

SetVariables(variables map[string]interface{})

RemoveVariable(name string) interface{}

HasVariable(name string) bool

GetInstruction() Instruction

SetInstruction(instruction Instruction)
}

type HierarchicalProcessContext interface {
ProcessContext

GetVariableLocally(name string) interface{}

SetVariableLocally(name string, value interface{})

GetVariablesLocally() map[string]interface{}

SetVariablesLocally(variables map[string]interface{})

RemoveVariableLocally(name string) interface{}

HasVariableLocally(name string) bool

ClearLocally()
}

type ProcessContextImpl struct {
parent ProcessContext
mu sync.RWMutex
mp map[string]interface{}
instruction Instruction
}

func NewProcessContextImpl() *ProcessContextImpl {
return &ProcessContextImpl{
mp: make(map[string]interface{}),
}
}

func (p *ProcessContextImpl) GetVariable(name string) interface{} {
p.mu.RLock()
defer p.mu.RUnlock()

value, ok := p.mp[name]
if ok {
return value
}

if p.parent != nil {
return p.parent.GetVariable(name)
}

return nil
}

func (p *ProcessContextImpl) SetVariable(name string, value interface{}) {
p.mu.Lock()
defer p.mu.Unlock()

_, ok := p.mp[name]
if ok {
p.mp[name] = value
} else {
if p.parent != nil {
p.parent.SetVariable(name, value)
} else {
p.mp[name] = value
}
}
}

func (p *ProcessContextImpl) GetVariables() map[string]interface{} {
p.mu.RLock()
defer p.mu.RUnlock()

newVariablesMap := make(map[string]interface{})
if p.parent != nil {
variables := p.parent.GetVariables()
for k, v := range variables {
newVariablesMap[k] = v
}
}

for k, v := range p.mp {
newVariablesMap[k] = v
}

return newVariablesMap
}

func (p *ProcessContextImpl) SetVariables(variables map[string]interface{}) {
for k, v := range variables {
p.SetVariable(k, v)
}
}

func (p *ProcessContextImpl) RemoveVariable(name string) interface{} {
p.mu.Lock()
defer p.mu.Unlock()

value, ok := p.mp[name]
if ok {
delete(p.mp, name)
return value
}

if p.parent != nil {
return p.parent.RemoveVariable(name)
}

return nil
}

func (p *ProcessContextImpl) HasVariable(name string) bool {
p.mu.RLock()
defer p.mu.RUnlock()

_, ok := p.mp[name]
if ok {
return true
}

if p.parent != nil {
return p.parent.HasVariable(name)
}

return false
}

func (p *ProcessContextImpl) GetInstruction() Instruction {
return p.instruction
}

func (p *ProcessContextImpl) SetInstruction(instruction Instruction) {
p.instruction = instruction
}

func (p *ProcessContextImpl) GetVariableLocally(name string) interface{} {
p.mu.RLock()
defer p.mu.RUnlock()

value, _ := p.mp[name]
return value
}

func (p *ProcessContextImpl) SetVariableLocally(name string, value interface{}) {
p.mu.Lock()
defer p.mu.Unlock()

p.mp[name] = value
}

func (p *ProcessContextImpl) GetVariablesLocally() map[string]interface{} {
p.mu.RLock()
defer p.mu.RUnlock()

newVariablesMap := make(map[string]interface{}, len(p.mp))
for k, v := range p.mp {
newVariablesMap[k] = v
}
return newVariablesMap
}

func (p *ProcessContextImpl) SetVariablesLocally(variables map[string]interface{}) {
for k, v := range variables {
p.SetVariableLocally(k, v)
}
}

func (p *ProcessContextImpl) RemoveVariableLocally(name string) interface{} {
p.mu.Lock()
defer p.mu.Unlock()

value, _ := p.mp[name]
delete(p.mp, name)
return value
}

func (p *ProcessContextImpl) HasVariableLocally(name string) bool {
p.mu.RLock()
defer p.mu.RUnlock()

_, ok := p.mp[name]
return ok
}

func (p *ProcessContextImpl) ClearLocally() {
p.mu.Lock()
defer p.mu.Unlock()

p.mp = map[string]interface{}{}
}

+ 48
- 0
pkg/saga/statemachine/process_ctrl/process_controller.go View File

@@ -0,0 +1,48 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package process_ctrl

import (
"context"
)

type ProcessController interface {
Process(ctx context.Context, context ProcessContext) error
}

type ProcessControllerImpl struct {
businessProcessor BusinessProcessor
}

func (p *ProcessControllerImpl) Process(ctx context.Context, context ProcessContext) error {
if err := p.businessProcessor.Process(ctx, context); err != nil {
return err
}
if err := p.businessProcessor.Route(ctx, context); err != nil {
return err
}
return nil
}

func (p *ProcessControllerImpl) BusinessProcessor() BusinessProcessor {
return p.businessProcessor
}

func (p *ProcessControllerImpl) SetBusinessProcessor(businessProcessor BusinessProcessor) {
p.businessProcessor = businessProcessor
}

+ 107
- 0
pkg/saga/statemachine/process_ctrl/process_router.go View File

@@ -0,0 +1,107 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package process_ctrl

import (
"context"
"github.com/pkg/errors"
"github.com/seata/seata-go/pkg/saga/statemachine/constant"
"github.com/seata/seata-go/pkg/saga/statemachine/process_ctrl/process"
"github.com/seata/seata-go/pkg/saga/statemachine/statelang"
"github.com/seata/seata-go/pkg/util/log"
)

type RouterHandler interface {
Route(ctx context.Context, processContext ProcessContext) error
}

type ProcessRouter interface {
Route(ctx context.Context, processContext ProcessContext) error
}

type InterceptAbleStateRouter interface {
StateRouter
StateRouterInterceptor() []StateRouterInterceptor
RegistryStateRouterInterceptor(stateRouterInterceptor StateRouterInterceptor)
}

type StateRouter interface {
Route(ctx context.Context, processContext ProcessContext, state statelang.State) (Instruction, error)
}

type StateRouterInterceptor interface {
PreRoute(ctx context.Context, processContext ProcessContext, state statelang.State) error
PostRoute(ctx context.Context, processContext ProcessContext, instruction Instruction, err error) error
Match(stateType string) bool
}

type DefaultRouterHandler struct {
eventPublisher EventPublisher
processRouters map[string]ProcessRouter
}

func (d *DefaultRouterHandler) Route(ctx context.Context, processContext ProcessContext) error {
processType := d.matchProcessType(ctx, processContext)
if processType == "" {
log.Warnf("Process type not found, context= %s", processContext)
return errors.New("Process type not found")
}

processRouter := d.processRouters[string(processType)]
if processRouter == nil {
log.Errorf("Cannot find process router by type %s, context = %s", processType, processContext)
return errors.New("Process router not found")
}

instruction := processRouter.Route(ctx, processContext)
if instruction == nil {
log.Info("route instruction is null, process end")
} else {
processContext.SetInstruction(instruction)
_, err := d.eventPublisher.PushEvent(ctx, processContext)
if err != nil {
return err
}
}

return nil
}

func (d *DefaultRouterHandler) matchProcessType(ctx context.Context, processContext ProcessContext) process.ProcessType {
processType, ok := processContext.GetVariable(constant.VarNameProcessType).(process.ProcessType)
if !ok || processType == "" {
processType = process.StateLang
}
return processType
}

func (d *DefaultRouterHandler) EventPublisher() EventPublisher {
return d.eventPublisher
}

func (d *DefaultRouterHandler) SetEventPublisher(eventPublisher EventPublisher) {
d.eventPublisher = eventPublisher
}

func (d *DefaultRouterHandler) ProcessRouters() map[string]ProcessRouter {
return d.processRouters
}

func (d *DefaultRouterHandler) SetProcessRouters(processRouters map[string]ProcessRouter) {
d.processRouters = processRouters
}

+ 91
- 0
pkg/saga/statemachine/statelang/parser/choice_state_json_parser.go View File

@@ -0,0 +1,91 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package parser

import (
"fmt"
"github.com/pkg/errors"
"github.com/seata/seata-go/pkg/saga/statemachine/constant"
"github.com/seata/seata-go/pkg/saga/statemachine/statelang"
"github.com/seata/seata-go/pkg/saga/statemachine/statelang/state"
)

type ChoiceStateParser struct {
*BaseStateParser
}

func NewChoiceStateParser() *ChoiceStateParser {
return &ChoiceStateParser{
&BaseStateParser{},
}
}

func (c ChoiceStateParser) StateType() string {
return constant.StateTypeChoice
}

func (c ChoiceStateParser) Parse(stateName string, stateMap map[string]interface{}) (statelang.State, error) {
choiceState := state.NewChoiceStateImpl()
choiceState.SetName(stateName)

//parse Type
typeName, err := c.GetString(stateName, stateMap, "Type")
if err != nil {
return nil, err
}
choiceState.SetType(typeName)

//parse Default
defaultChoice, err := c.GetString(stateName, stateMap, "Default")
if err != nil {
return nil, err
}
choiceState.SetDefault(defaultChoice)

//parse Choices
slice, err := c.GetSlice(stateName, stateMap, "Choices")
if err != nil {
return nil, err
}

var choices []state.Choice
for i := range slice {
choiceValMap, ok := slice[i].(map[string]interface{})
if !ok {
return nil, errors.New(fmt.Sprintf("State [%s] Choices element required struct", stateName))
}

choice := state.NewChoiceImpl()
expression, err := c.GetString(stateName, choiceValMap, "Expression")
if err != nil {
return nil, err
}
choice.SetExpression(expression)

next, err := c.GetString(stateName, choiceValMap, "Next")
if err != nil {
return nil, err
}
choice.SetNext(next)

choices = append(choices, choice)
}
choiceState.SetChoices(choices)

return choiceState, nil
}

+ 48
- 0
pkg/saga/statemachine/statelang/parser/compensation_trigger_state_parser.go View File

@@ -0,0 +1,48 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package parser

import (
"github.com/seata/seata-go/pkg/saga/statemachine/constant"
"github.com/seata/seata-go/pkg/saga/statemachine/statelang"
"github.com/seata/seata-go/pkg/saga/statemachine/statelang/state"
)

type CompensationTriggerStateParser struct {
*BaseStateParser
}

func NewCompensationTriggerStateParser() *CompensationTriggerStateParser {
return &CompensationTriggerStateParser{
&BaseStateParser{},
}
}

func (c CompensationTriggerStateParser) StateType() string {
return constant.StateTypeCompensationTrigger
}

func (c CompensationTriggerStateParser) Parse(stateName string, stateMap map[string]interface{}) (statelang.State, error) {
compensateSubStateMachineStateImpl := state.NewCompensationTriggerStateImpl()
err := c.ParseBaseAttributes(stateName, compensateSubStateMachineStateImpl, stateMap)
if err != nil {
return nil, err
}

return compensateSubStateMachineStateImpl, nil
}

+ 83
- 0
pkg/saga/statemachine/statelang/parser/end_state_parser.go View File

@@ -0,0 +1,83 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package parser

import (
"github.com/seata/seata-go/pkg/saga/statemachine/constant"
"github.com/seata/seata-go/pkg/saga/statemachine/statelang"
"github.com/seata/seata-go/pkg/saga/statemachine/statelang/state"
)

type SucceedEndStateParser struct {
*BaseStateParser
}

func NewSucceedEndStateParser() *SucceedEndStateParser {
return &SucceedEndStateParser{
&BaseStateParser{},
}
}

func (s SucceedEndStateParser) StateType() string {
return constant.StateTypeSucceed
}

func (s SucceedEndStateParser) Parse(stateName string, stateMap map[string]interface{}) (statelang.State, error) {
succeedEndStateImpl := state.NewSucceedEndStateImpl()
err := s.ParseBaseAttributes(stateName, succeedEndStateImpl, stateMap)
if err != nil {
return nil, err
}

return succeedEndStateImpl, nil
}

type FailEndStateParser struct {
*BaseStateParser
}

func NewFailEndStateParser() *FailEndStateParser {
return &FailEndStateParser{
&BaseStateParser{},
}
}

func (f FailEndStateParser) StateType() string {
return constant.StateTypeFail
}

func (f FailEndStateParser) Parse(stateName string, stateMap map[string]interface{}) (statelang.State, error) {
failEndStateImpl := state.NewFailEndStateImpl()
err := f.ParseBaseAttributes(stateName, failEndStateImpl, stateMap)
if err != nil {
return nil, err
}

errorCode, err := f.GetStringOrDefault(stateName, stateMap, "ErrorCode", "")
if err != nil {
return nil, err
}
failEndStateImpl.SetErrorCode(errorCode)

message, err := f.GetStringOrDefault(stateName, stateMap, "Message", "")
if err != nil {
return nil, err
}
failEndStateImpl.SetMessage(message)
return failEndStateImpl, nil
}

+ 139
- 0
pkg/saga/statemachine/statelang/parser/statemachine_config_parser.go View File

@@ -0,0 +1,139 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package parser

import (
"bytes"
"fmt"
"github.com/seata/seata-go/pkg/saga/statemachine"
"io"
"os"

"github.com/knadh/koanf"
"github.com/knadh/koanf/parsers/json"
"github.com/knadh/koanf/parsers/yaml"
"github.com/knadh/koanf/providers/rawbytes"
)

// ConfigParser is a general configuration parser interface, used to agree on the implementation of different types of parsers
type ConfigParser interface {
Parse(configContent []byte) (*statemachine.StateMachineObject, error)
}

type JSONConfigParser struct{}

func NewJSONConfigParser() *JSONConfigParser {
return &JSONConfigParser{}
}

func (p *JSONConfigParser) Parse(configContent []byte) (*statemachine.StateMachineObject, error) {
if configContent == nil || len(configContent) == 0 {
return nil, fmt.Errorf("empty JSON config content")
}

k := koanf.New(".")
if err := k.Load(rawbytes.Provider(configContent), json.Parser()); err != nil {
return nil, fmt.Errorf("failed to parse JSON config content: %w", err)
}

var stateMachineObject statemachine.StateMachineObject
if err := k.Unmarshal("", &stateMachineObject); err != nil {
return nil, fmt.Errorf("failed to unmarshal JSON config to struct: %w", err)
}

return &stateMachineObject, nil
}

type YAMLConfigParser struct{}

func NewYAMLConfigParser() *YAMLConfigParser {
return &YAMLConfigParser{}
}

func (p *YAMLConfigParser) Parse(configContent []byte) (*statemachine.StateMachineObject, error) {
if configContent == nil || len(configContent) == 0 {
return nil, fmt.Errorf("empty YAML config content")
}

k := koanf.New(".")
if err := k.Load(rawbytes.Provider(configContent), yaml.Parser()); err != nil {
return nil, fmt.Errorf("failed to parse YAML config content: %w", err)
}

var stateMachineObject statemachine.StateMachineObject
if err := k.Unmarshal("", &stateMachineObject); err != nil {
return nil, fmt.Errorf("failed to unmarshal YAML config to struct: %w", err)
}

return &stateMachineObject, nil
}

type StateMachineConfigParser struct{}

func NewStateMachineConfigParser() *StateMachineConfigParser {
return &StateMachineConfigParser{}
}

func (p *StateMachineConfigParser) CheckConfigFile(filePath string) error {
_, err := os.Stat(filePath)
if os.IsNotExist(err) {
return fmt.Errorf("config file %s does not exist: %w", filePath, err)
}
if err != nil {
return fmt.Errorf("failed to access config file %s: %w", filePath, err)
}
return nil
}

func (p *StateMachineConfigParser) ReadConfigFile(configFilePath string) ([]byte, error) {
file, _ := os.Open(configFilePath)
defer func(file *os.File) {
_ = file.Close()
}(file)

var buf bytes.Buffer
_, err := io.Copy(&buf, file)
if err != nil {
return nil, fmt.Errorf("failed to read config file %s: %w", configFilePath, err)
}

return buf.Bytes(), nil
}

func (p *StateMachineConfigParser) getParser(content []byte) (ConfigParser, error) {
k := koanf.New(".")
if err := k.Load(rawbytes.Provider(content), json.Parser()); err == nil {
return NewJSONConfigParser(), nil
}

k = koanf.New(".")
if err := k.Load(rawbytes.Provider(content), yaml.Parser()); err == nil {
return NewYAMLConfigParser(), nil
}

return nil, fmt.Errorf("unsupported config file format")
}

func (p *StateMachineConfigParser) Parse(content []byte) (*statemachine.StateMachineObject, error) {
parser, err := p.getParser(content)
if err != nil {
return nil, err
}

return parser.Parse(content)
}

+ 882
- 0
pkg/saga/statemachine/statelang/parser/statemachine_config_parser_test.go View File

@@ -0,0 +1,882 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package parser

import (
"github.com/seata/seata-go/pkg/saga/statemachine"
"github.com/stretchr/testify/assert"
"testing"
)

func TestStateMachineConfigParser_Parse(t *testing.T) {
parser := NewStateMachineConfigParser()

tests := []struct {
name string
configFilePath string
expectedObject *statemachine.StateMachineObject
}{
{
name: "JSON Simple 1",
configFilePath: "../../../../../testdata/saga/statelang/simple_statelang_with_choice.json",
expectedObject: GetStateMachineObject1("json"),
},
{
name: "JSON Simple 2",
configFilePath: "../../../../../testdata/saga/statelang/simple_statemachine.json",
expectedObject: GetStateMachineObject2("json"),
},
{
name: "JSON Simple 3",
configFilePath: "../../../../../testdata/saga/statelang/state_machine_new_designer.json",
expectedObject: GetStateMachineObject3("json"),
},
{
name: "YAML Simple 1",
configFilePath: "../../../../../testdata/saga/statelang/simple_statelang_with_choice.yaml",
expectedObject: GetStateMachineObject1("yaml"),
},
{
name: "YAML Simple 2",
configFilePath: "../../../../../testdata/saga/statelang/simple_statemachine.yaml",
expectedObject: GetStateMachineObject2("yaml"),
},
{
name: "YAML Simple 3",
configFilePath: "../../../../../testdata/saga/statelang/state_machine_new_designer.yaml",
expectedObject: GetStateMachineObject3("yaml"),
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
content, err := parser.ReadConfigFile(tt.configFilePath)
if err != nil {
t.Error("parse fail: " + err.Error())
}
object, err := parser.Parse(content)
if err != nil {
t.Error("parse fail: " + err.Error())
}
assert.Equal(t, tt.expectedObject, object)
})
}
}

func GetStateMachineObject1(format string) *statemachine.StateMachineObject {
switch format {
case "json":
case "yaml":
}

return &statemachine.StateMachineObject{
Name: "simpleChoiceTestStateMachine",
Comment: "带条件分支的测试状态机定义",
StartState: "FirstState",
Version: "0.0.1",
States: map[string]interface{}{
"FirstState": map[string]interface{}{
"Type": "ServiceTask",
"ServiceName": "demoService",
"ServiceMethod": "foo",
"Next": "ChoiceState",
},
"ChoiceState": map[string]interface{}{
"Type": "Choice",
"Choices": []interface{}{
map[string]interface{}{
"Expression": "[a] == 1",
"Next": "SecondState",
},
map[string]interface{}{
"Expression": "[a] == 2",
"Next": "ThirdState",
},
},
"Default": "SecondState",
},
"SecondState": map[string]interface{}{
"Type": "ServiceTask",
"ServiceName": "demoService",
"ServiceMethod": "bar",
},
"ThirdState": map[string]interface{}{
"Type": "ServiceTask",
"ServiceName": "demoService",
"ServiceMethod": "foo",
},
},
}
}

func GetStateMachineObject2(format string) *statemachine.StateMachineObject {
var retryMap map[string]interface{}

switch format {
case "json":
retryMap = map[string]interface{}{
"Exceptions": []interface{}{
"java.lang.Exception",
},
"IntervalSeconds": float64(2),
"MaxAttempts": float64(3),
"BackoffRate": 1.5,
}
case "yaml":
retryMap = map[string]interface{}{
"Exceptions": []interface{}{
"java.lang.Exception",
},
"IntervalSeconds": 2,
"MaxAttempts": 3,
"BackoffRate": 1.5,
}
}

return &statemachine.StateMachineObject{
Name: "simpleTestStateMachine",
Comment: "测试状态机定义",
StartState: "FirstState",
Version: "0.0.1",
States: map[string]interface{}{
"FirstState": map[string]interface{}{
"Type": "ServiceTask",
"ServiceName": "is.seata.saga.DemoService",
"ServiceMethod": "foo",
"IsPersist": false,
"Next": "ScriptState",
},
"ScriptState": map[string]interface{}{
"Type": "ScriptTask",
"ScriptType": "groovy",
"ScriptContent": "return 'hello ' + inputA",
"Input": []interface{}{
map[string]interface{}{
"inputA": "$.data1",
},
},
"Output": map[string]interface{}{
"scriptStateResult": "$.#root",
},
"Next": "ChoiceState",
},
"ChoiceState": map[string]interface{}{
"Type": "Choice",
"Choices": []interface{}{
map[string]interface{}{
"Expression": "foo == 1",
"Next": "FirstMatchState",
},
map[string]interface{}{
"Expression": "foo == 2",
"Next": "SecondMatchState",
},
},
"Default": "FailState",
},
"FirstMatchState": map[string]interface{}{
"Type": "ServiceTask",
"ServiceName": "is.seata.saga.DemoService",
"ServiceMethod": "bar",
"CompensateState": "CompensateFirst",
"Status": map[string]interface{}{
"return.code == 'S'": "SU",
"return.code == 'F'": "FA",
"$exception{java.lang.Throwable}": "UN",
},
"Input": []interface{}{
map[string]interface{}{
"inputA1": "$.data1",
"inputA2": map[string]interface{}{
"a": "$.data2.a",
},
},
map[string]interface{}{
"inputB": "$.header",
},
},
"Output": map[string]interface{}{
"firstMatchStateResult": "$.#root",
},
"Retry": []interface{}{
retryMap,
},
"Catch": []interface{}{
map[string]interface{}{
"Exceptions": []interface{}{
"java.lang.Exception",
},
"Next": "CompensationTrigger",
},
},
"Next": "SuccessState",
},
"CompensateFirst": map[string]interface{}{
"Type": "ServiceTask",
"ServiceName": "is.seata.saga.DemoService",
"ServiceMethod": "compensateBar",
"IsForCompensation": true,
"IsForUpdate": true,
"Input": []interface{}{
map[string]interface{}{
"input": "$.data",
},
},
"Output": map[string]interface{}{
"firstMatchStateResult": "$.#root",
},
"Status": map[string]interface{}{
"return.code == 'S'": "SU",
"return.code == 'F'": "FA",
"$exception{java.lang.Throwable}": "UN",
},
},
"CompensationTrigger": map[string]interface{}{
"Type": "CompensationTrigger",
"Next": "CompensateEndState",
},
"CompensateEndState": map[string]interface{}{
"Type": "Fail",
"ErrorCode": "StateCompensated",
"Message": "State Compensated!",
},
"SecondMatchState": map[string]interface{}{
"Type": "SubStateMachine",
"StateMachineName": "simpleTestSubStateMachine",
"Input": []interface{}{
map[string]interface{}{
"input": "$.data",
},
map[string]interface{}{
"header": "$.header",
},
},
"Output": map[string]interface{}{
"firstMatchStateResult": "$.#root",
},
"Next": "SuccessState",
},
"FailState": map[string]interface{}{
"Type": "Fail",
"ErrorCode": "DefaultStateError",
"Message": "No Matches!",
},
"SuccessState": map[string]interface{}{
"Type": "Succeed",
},
},
}
}

func GetStateMachineObject3(format string) *statemachine.StateMachineObject {
var (
boundsMap1 map[string]interface{}
boundsMap2 map[string]interface{}
boundsMap3 map[string]interface{}
boundsMap4 map[string]interface{}
boundsMap5 map[string]interface{}
boundsMap6 map[string]interface{}
boundsMap7 map[string]interface{}
boundsMap8 map[string]interface{}
boundsMap9 map[string]interface{}

waypoints1 []interface{}
waypoints2 []interface{}
waypoints3 []interface{}
waypoints4 []interface{}
waypoints5 []interface{}
waypoints6 []interface{}
waypoints7 []interface{}
)

switch format {
case "json":
boundsMap1 = map[string]interface{}{
"x": float64(300),
"y": float64(178),
"width": float64(100),
"height": float64(80),
}
boundsMap2 = map[string]interface{}{
"x": float64(455),
"y": float64(193),
"width": float64(50),
"height": float64(50),
}
boundsMap3 = map[string]interface{}{
"x": float64(300),
"y": float64(310),
"width": float64(100),
"height": float64(80),
}
boundsMap4 = map[string]interface{}{
"x": float64(550),
"y": float64(178),
"width": float64(100),
"height": float64(80),
}
boundsMap5 = map[string]interface{}{
"x": float64(550),
"y": float64(310),
"width": float64(100),
"height": float64(80),
}
boundsMap6 = map[string]interface{}{
"x": float64(632),
"y": float64(372),
"width": float64(36),
"height": float64(36),
}
boundsMap7 = map[string]interface{}{
"x": float64(722),
"y": float64(200),
"width": float64(36),
"height": float64(36),
}
boundsMap8 = map[string]interface{}{
"x": float64(722),
"y": float64(372),
"width": float64(36),
"height": float64(36),
}
boundsMap9 = map[string]interface{}{
"x": float64(812),
"y": float64(372),
"width": float64(36),
"height": float64(36),
}

waypoints1 = []interface{}{
map[string]interface{}{
"original": map[string]interface{}{
"x": float64(400),
"y": float64(218),
},
"x": float64(400),
"y": float64(218),
},
map[string]interface{}{"x": float64(435), "y": float64(218)},
map[string]interface{}{
"original": map[string]interface{}{
"x": float64(455),
"y": float64(218),
},
"x": float64(455),
"y": float64(218),
},
}
waypoints2 = []interface{}{
map[string]interface{}{
"original": map[string]interface{}{
"x": float64(505),
"y": float64(218),
},
"x": float64(505),
"y": float64(218),
},
map[string]interface{}{"x": float64(530), "y": float64(218)},
map[string]interface{}{
"original": map[string]interface{}{
"x": float64(550),
"y": float64(218),
},
"x": float64(550),
"y": float64(218),
},
}
waypoints3 = []interface{}{
map[string]interface{}{
"original": map[string]interface{}{
"x": float64(480),
"y": float64(243),
},
"x": float64(480),
"y": float64(243),
},
map[string]interface{}{"x": float64(600), "y": float64(290)},
map[string]interface{}{
"original": map[string]interface{}{
"x": float64(600),
"y": float64(310),
},
"x": float64(600),
"y": float64(310),
},
}
waypoints4 = []interface{}{
map[string]interface{}{
"original": map[string]interface{}{
"x": float64(650),
"y": float64(218),
},
"x": float64(650),
"y": float64(218),
},
map[string]interface{}{"x": float64(702), "y": float64(218)},
map[string]interface{}{
"original": map[string]interface{}{
"x": float64(722),
"y": float64(218),
},
"x": float64(722),
"y": float64(218),
},
}
waypoints5 = []interface{}{
map[string]interface{}{
"original": map[string]interface{}{
"x": float64(668),
"y": float64(390),
},
"x": float64(668),
"y": float64(390),
},
map[string]interface{}{"x": float64(702), "y": float64(390)},
map[string]interface{}{
"original": map[string]interface{}{
"x": float64(722),
"y": float64(390),
},
"x": float64(722),
"y": float64(390),
},
}
waypoints6 = []interface{}{
map[string]interface{}{
"original": map[string]interface{}{
"x": float64(600),
"y": float64(310),
},
"x": float64(600),
"y": float64(310),
},
map[string]interface{}{"x": float64(740), "y": float64(256)},
map[string]interface{}{
"original": map[string]interface{}{
"x": float64(740),
"y": float64(236),
},
"x": float64(740),
"y": float64(236),
},
}
waypoints7 = []interface{}{
map[string]interface{}{
"original": map[string]interface{}{
"x": float64(758),
"y": float64(390),
},
"x": float64(758),
"y": float64(390),
},
map[string]interface{}{"x": float64(792), "y": float64(390)},
map[string]interface{}{
"original": map[string]interface{}{
"x": float64(812),
"y": float64(390),
},
"x": float64(812),
"y": float64(390),
},
}

case "yaml":
boundsMap1 = map[string]interface{}{
"x": 300,
"y": 178,
"width": 100,
"height": 80,
}
boundsMap2 = map[string]interface{}{
"x": 455,
"y": 193,
"width": 50,
"height": 50,
}
boundsMap3 = map[string]interface{}{
"x": 300,
"y": 310,
"width": 100,
"height": 80,
}
boundsMap4 = map[string]interface{}{
"x": 550,
"y": 178,
"width": 100,
"height": 80,
}
boundsMap5 = map[string]interface{}{
"x": 550,
"y": 310,
"width": 100,
"height": 80,
}
boundsMap6 = map[string]interface{}{
"x": 632,
"y": 372,
"width": 36,
"height": 36,
}
boundsMap7 = map[string]interface{}{
"x": 722,
"y": 200,
"width": 36,
"height": 36,
}
boundsMap8 = map[string]interface{}{
"x": 722,
"y": 372,
"width": 36,
"height": 36,
}
boundsMap9 = map[string]interface{}{
"x": 812,
"y": 372,
"width": 36,
"height": 36,
}

waypoints1 = []interface{}{
map[string]interface{}{
"original": map[string]interface{}{
"x": 400,
"y": 218,
},
"x": 400,
"y": 218,
},
map[string]interface{}{"x": 435, "y": 218},
map[string]interface{}{
"original": map[string]interface{}{
"x": 455,
"y": 218,
},
"x": 455,
"y": 218,
},
}
waypoints2 = []interface{}{
map[string]interface{}{
"original": map[string]interface{}{
"x": 505,
"y": 218,
},
"x": 505,
"y": 218,
},
map[string]interface{}{"x": 530, "y": 218},
map[string]interface{}{
"original": map[string]interface{}{
"x": 550,
"y": 218,
},
"x": 550,
"y": 218,
},
}
waypoints3 = []interface{}{
map[string]interface{}{
"original": map[string]interface{}{
"x": 480,
"y": 243,
},
"x": 480,
"y": 243,
},
map[string]interface{}{"x": 600, "y": 290},
map[string]interface{}{
"original": map[string]interface{}{
"x": 600,
"y": 310,
},
"x": 600,
"y": 310,
},
}
waypoints4 = []interface{}{
map[string]interface{}{
"original": map[string]interface{}{
"x": 650,
"y": 218,
},
"x": 650,
"y": 218,
},
map[string]interface{}{"x": 702, "y": 218},
map[string]interface{}{
"original": map[string]interface{}{
"x": 722,
"y": 218,
},
"x": 722,
"y": 218,
},
}
waypoints5 = []interface{}{
map[string]interface{}{
"original": map[string]interface{}{
"x": 668,
"y": 390,
},
"x": 668,
"y": 390,
},
map[string]interface{}{"x": 702, "y": 390},
map[string]interface{}{
"original": map[string]interface{}{
"x": 722,
"y": 390,
},
"x": 722,
"y": 390,
},
}
waypoints6 = []interface{}{
map[string]interface{}{
"original": map[string]interface{}{
"x": 600,
"y": 310,
},
"x": 600,
"y": 310,
},
map[string]interface{}{"x": 740, "y": 256},
map[string]interface{}{
"original": map[string]interface{}{
"x": 740,
"y": 236,
},
"x": 740,
"y": 236,
},
}
waypoints7 = []interface{}{
map[string]interface{}{
"original": map[string]interface{}{
"x": 758,
"y": 390,
},
"x": 758,
"y": 390,
},
map[string]interface{}{"x": 792, "y": 390},
map[string]interface{}{
"original": map[string]interface{}{
"x": 812,
"y": 390,
},
"x": 812,
"y": 390,
},
}
}

return &statemachine.StateMachineObject{
Name: "StateMachineNewDesigner",
Comment: "This state machine is modeled by designer tools.",
Version: "0.0.1",
StartState: "ServiceTask-a9h2o51",
RecoverStrategy: "",
Persist: false,
RetryPersistModeUpdate: false,
CompensatePersistModeUpdate: false,
Type: "",
States: map[string]interface{}{
"ServiceTask-a9h2o51": map[string]interface{}{
"style": map[string]interface{}{
"bounds": boundsMap1,
},
"Name": "ServiceTask-a9h2o51",
"IsForCompensation": false,
"Input": []interface{}{map[string]interface{}{}},
"Output": map[string]interface{}{},
"Status": map[string]interface{}{},
"Retry": []interface{}{},
"ServiceName": "",
"ServiceMethod": "",
"Type": "ServiceTask",
"Next": "Choice-4ajl8nt",
"edge": map[string]interface{}{
"Choice-4ajl8nt": map[string]interface{}{
"style": map[string]interface{}{
"waypoints": waypoints1,
"source": "ServiceTask-a9h2o51",
"target": "Choice-4ajl8nt",
},
"Type": "Transition",
},
},
"CompensateState": "CompensateFirstState",
},
"Choice-4ajl8nt": map[string]interface{}{
"style": map[string]interface{}{
"bounds": boundsMap2,
},
"Name": "Choice-4ajl8nt",
"Type": "Choice",
"Choices": []interface{}{
map[string]interface{}{
"Expression": "",
"Next": "SubStateMachine-cauj9uy",
},
map[string]interface{}{
"Expression": "",
"Next": "ServiceTask-vdij28l",
},
},
"Default": "SubStateMachine-cauj9uy",
"edge": map[string]interface{}{
"SubStateMachine-cauj9uy": map[string]interface{}{
"style": map[string]interface{}{
"waypoints": waypoints2,
"source": "Choice-4ajl8nt",
"target": "SubStateMachine-cauj9uy",
},
"Type": "ChoiceEntry",
},
"ServiceTask-vdij28l": map[string]interface{}{
"style": map[string]interface{}{
"waypoints": waypoints3,
"source": "Choice-4ajl8nt",
"target": "ServiceTask-vdij28l",
},
"Type": "ChoiceEntry",
},
},
},
"CompensateFirstState": map[string]interface{}{
"style": map[string]interface{}{
"bounds": boundsMap3,
},
"Name": "CompensateFirstState",
"IsForCompensation": true,
"Input": []interface{}{map[string]interface{}{}},
"Output": map[string]interface{}{},
"Status": map[string]interface{}{},
"Retry": []interface{}{},
"ServiceName": "",
"ServiceMethod": "",
"Type": "ServiceTask",
},
"SubStateMachine-cauj9uy": map[string]interface{}{
"style": map[string]interface{}{
"bounds": boundsMap4,
},
"Name": "SubStateMachine-cauj9uy",
"IsForCompensation": false,
"Input": []interface{}{map[string]interface{}{}},
"Output": map[string]interface{}{},
"Status": map[string]interface{}{},
"Retry": []interface{}{},
"StateMachineName": "",
"Type": "SubStateMachine",
"Next": "Succeed-5x3z98u",
"edge": map[string]interface{}{
"Succeed-5x3z98u": map[string]interface{}{
"style": map[string]interface{}{
"waypoints": waypoints4,
"source": "SubStateMachine-cauj9uy",
"target": "Succeed-5x3z98u",
},
"Type": "Transition",
},
},
},
"ServiceTask-vdij28l": map[string]interface{}{
"style": map[string]interface{}{
"bounds": boundsMap5,
},
"Name": "ServiceTask-vdij28l",
"IsForCompensation": false,
"Input": []interface{}{map[string]interface{}{}},
"Output": map[string]interface{}{},
"Status": map[string]interface{}{},
"Retry": []interface{}{},
"ServiceName": "",
"ServiceMethod": "",
"Catch": []interface{}{
map[string]interface{}{
"Exceptions": []interface{}{},
"Next": "CompensationTrigger-uldp2ou",
},
},
"Type": "ServiceTask",
"catch": map[string]interface{}{
"style": map[string]interface{}{
"bounds": boundsMap6,
},
"edge": map[string]interface{}{
"CompensationTrigger-uldp2ou": map[string]interface{}{
"style": map[string]interface{}{
"waypoints": waypoints5,
"source": "ServiceTask-vdij28l",
"target": "CompensationTrigger-uldp2ou",
},
"Type": "ExceptionMatch",
},
},
},
"Next": "Succeed-5x3z98u",
"edge": map[string]interface{}{
"Succeed-5x3z98u": map[string]interface{}{
"style": map[string]interface{}{
"waypoints": waypoints6,
"source": "ServiceTask-vdij28l",
"target": "Succeed-5x3z98u",
},
"Type": "Transition",
},
},
},
"Succeed-5x3z98u": map[string]interface{}{
"style": map[string]interface{}{
"bounds": boundsMap7,
},
"Name": "Succeed-5x3z98u",
"Type": "Succeed",
},
"CompensationTrigger-uldp2ou": map[string]interface{}{
"style": map[string]interface{}{
"bounds": boundsMap8,
},
"Name": "CompensationTrigger-uldp2ou",
"Type": "CompensationTrigger",
"Next": "Fail-9roxcv5",
"edge": map[string]interface{}{
"Fail-9roxcv5": map[string]interface{}{
"style": map[string]interface{}{
"waypoints": waypoints7,
"source": "CompensationTrigger-uldp2ou",
"target": "Fail-9roxcv5",
},
"Type": "Transition",
},
},
},
"Fail-9roxcv5": map[string]interface{}{
"style": map[string]interface{}{
"bounds": boundsMap9,
},
"Name": "Fail-9roxcv5",
"ErrorCode": "",
"Message": "",
"Type": "Fail",
},
},
}
}

+ 130
- 0
pkg/saga/statemachine/statelang/parser/statemachine_json_parser.go View File

@@ -0,0 +1,130 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package parser

import (
"github.com/pkg/errors"
"github.com/seata/seata-go/pkg/saga/statemachine/constant"
"github.com/seata/seata-go/pkg/saga/statemachine/statelang"
"github.com/seata/seata-go/pkg/saga/statemachine/statelang/state"
)

type JSONStateMachineParser struct {
*BaseStateParser
}

func NewJSONStateMachineParser() *JSONStateMachineParser {
return &JSONStateMachineParser{
&BaseStateParser{},
}
}

func (stateMachineParser JSONStateMachineParser) GetType() string {
return "JSON"
}

func (stateMachineParser JSONStateMachineParser) Parse(content string) (statelang.StateMachine, error) {
stateMachineJsonObject, err := NewStateMachineConfigParser().Parse([]byte(content))
if err != nil {
return nil, err
}

stateMachine := statelang.NewStateMachineImpl()
stateMachine.SetName(stateMachineJsonObject.Name)
stateMachine.SetComment(stateMachineJsonObject.Comment)
stateMachine.SetVersion(stateMachineJsonObject.Version)
stateMachine.SetStartState(stateMachineJsonObject.StartState)
stateMachine.SetPersist(stateMachineJsonObject.Persist)

if stateMachineJsonObject.Type != "" {
stateMachine.SetType(stateMachineJsonObject.Type)
}

if stateMachineJsonObject.RecoverStrategy != "" {
recoverStrategy, ok := statelang.ValueOfRecoverStrategy(stateMachineJsonObject.RecoverStrategy)
if !ok {
return nil, errors.New("Not support " + stateMachineJsonObject.RecoverStrategy)
}
stateMachine.SetRecoverStrategy(recoverStrategy)
}

stateParserFactory := NewDefaultStateParserFactory()
stateParserFactory.InitDefaultStateParser()
for stateName, v := range stateMachineJsonObject.States {
stateMap, ok := v.(map[string]interface{})
if !ok {
return nil, errors.New("State [" + stateName + "] scheme illegal, required map")
}

stateType, ok := stateMap["Type"].(string)
if !ok {
return nil, errors.New("State [" + stateName + "] Type illegal, required string")
}

//stateMap
stateParser := stateParserFactory.GetStateParser(stateType)
if stateParser == nil {
return nil, errors.New("State Type [" + stateType + "] is not support")
}

_, stateExist := stateMachine.States()[stateName]
if stateExist {
return nil, errors.New("State [name:" + stateName + "] already exists")
}

state, err := stateParser.Parse(stateName, stateMap)
if err != nil {
return nil, err
}

state.SetStateMachine(stateMachine)
stateMachine.States()[stateName] = state
}

for _, stateValue := range stateMachine.States() {
if stateMachineParser.isTaskState(stateValue.Type()) {
stateMachineParser.setForCompensation(stateValue, stateMachine)
}
}

return stateMachine, nil
}

func (stateMachineParser JSONStateMachineParser) setForCompensation(stateValue statelang.State, stateMachine *statelang.StateMachineImpl) {
if stateValue.Type() == constant.StateTypeServiceTask {
serviceTaskStateImpl, ok := stateValue.(*state.ServiceTaskStateImpl)
if ok {
if serviceTaskStateImpl.CompensateState() != "" {
compState := stateMachine.States()[serviceTaskStateImpl.CompensateState()]
if stateMachineParser.isTaskState(compState.Type()) {
compStateImpl, ok := compState.(*state.ServiceTaskStateImpl)
if ok {
compStateImpl.SetForCompensation(true)
}
}
}
}
}
}

func (stateMachineParser JSONStateMachineParser) isTaskState(stateType string) bool {
if stateType == constant.StateTypeServiceTask {
return true
}
return false
}

+ 127
- 0
pkg/saga/statemachine/statelang/parser/statemachine_json_parser_test.go View File

@@ -0,0 +1,127 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package parser

import (
"os"
"testing"
)

func readFileContent(filePath string) (string, error) {
content, err := os.ReadFile(filePath)
if err != nil {
return "", err
}
return string(content), nil
}

func TestParseChoice(t *testing.T) {
parser := NewJSONStateMachineParser()

tests := []struct {
name string
configFilePath string
}{
{
name: "JSON Simple: StateLang With Choice",
configFilePath: "../../../../../testdata/saga/statelang/simple_statelang_with_choice.json",
},
{
name: "YAML Simple: StateLang With Choice",
configFilePath: "../../../../../testdata/saga/statelang/simple_statelang_with_choice.yaml",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
content, err := readFileContent(tt.configFilePath)
if err != nil {
t.Error("read file fail: " + err.Error())
return
}
_, err = parser.Parse(content)
if err != nil {
t.Error("parse fail: " + err.Error())
}
})
}
}

func TestParseServiceTaskForSimpleStateMachine(t *testing.T) {
parser := NewJSONStateMachineParser()

tests := []struct {
name string
configFilePath string
}{
{
name: "JSON Simple: StateMachine",
configFilePath: "../../../../../testdata/saga/statelang/simple_statemachine.json",
},
{
name: "YAML Simple: StateMachine",
configFilePath: "../../../../../testdata/saga/statelang/simple_statemachine.yaml",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
content, err := readFileContent(tt.configFilePath)
if err != nil {
t.Error("read file fail: " + err.Error())
return
}
_, err = parser.Parse(content)
if err != nil {
t.Error("parse fail: " + err.Error())
}
})
}
}

func TestParseServiceTaskForNewDesigner(t *testing.T) {
parser := NewJSONStateMachineParser()

tests := []struct {
name string
configFilePath string
}{
{
name: "JSON Simple: StateMachine New Designer",
configFilePath: "../../../../../testdata/saga/statelang/state_machine_new_designer.json",
},
{
name: "YAML Simple: StateMachine New Designer",
configFilePath: "../../../../../testdata/saga/statelang/state_machine_new_designer.yaml",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
content, err := readFileContent(tt.configFilePath)
if err != nil {
t.Error("read file fail: " + err.Error())
return
}
_, err = parser.Parse(content)
if err != nil {
t.Error("parse fail: " + err.Error())
}
})
}
}

+ 253
- 0
pkg/saga/statemachine/statelang/parser/statemachine_parser.go View File

@@ -0,0 +1,253 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package parser

import (
"github.com/pkg/errors"
"github.com/seata/seata-go/pkg/saga/statemachine/statelang"
"strconv"
"strings"
"sync"
)

type StateMachineParser interface {
GetType() string
Parse(content string) (statelang.StateMachine, error)
}

type StateParser interface {
StateType() string
Parse(stateName string, stateMap map[string]interface{}) (statelang.State, error)
}

type BaseStateParser struct {
}

func NewBaseStateParser() *BaseStateParser {
return &BaseStateParser{}
}

func (b BaseStateParser) ParseBaseAttributes(stateName string, state statelang.State, stateMap map[string]interface{}) error {
state.SetName(stateName)

comment, err := b.GetStringOrDefault(stateName, stateMap, "Comment", "")
if err != nil {
return err
}
state.SetComment(comment)

next, err := b.GetStringOrDefault(stateName, stateMap, "Next", "")
if err != nil {
return err
}
state.SetNext(next)
return nil
}

func (b BaseStateParser) GetString(stateName string, stateMap map[string]interface{}, key string) (string, error) {
value := stateMap[key]
if value == nil {
var result string
return result, errors.New("State [" + stateName + "] " + key + " not exist")
}

valueAsString, ok := value.(string)
if !ok {
var s string
return s, errors.New("State [" + stateName + "] " + key + " illegal, required string")
}
return valueAsString, nil
}

func (b BaseStateParser) GetStringOrDefault(stateName string, stateMap map[string]interface{}, key string, defaultValue string) (string, error) {
value := stateMap[key]
if value == nil {
return defaultValue, nil
}

valueAsString, ok := value.(string)
if !ok {
return defaultValue, errors.New("State [" + stateName + "] " + key + " illegal, required string")
}
return valueAsString, nil
}

func (b BaseStateParser) GetSlice(stateName string, stateMap map[string]interface{}, key string) ([]interface{}, error) {
value := stateMap[key]
if value == nil {
var result []interface{}
return result, errors.New("State [" + stateName + "] " + key + " not exist")
}

valueAsSlice, ok := value.([]interface{})
if !ok {
var slice []interface{}
return slice, errors.New("State [" + stateName + "] " + key + " illegal, required []interface{}")
}
return valueAsSlice, nil
}

func (b BaseStateParser) GetSliceOrDefault(stateName string, stateMap map[string]interface{}, key string, defaultValue []interface{}) ([]interface{}, error) {
value := stateMap[key]

if value == nil {
return defaultValue, nil
}

valueAsSlice, ok := value.([]interface{})
if !ok {
return defaultValue, errors.New("State [" + stateName + "] " + key + " illegal, required []interface{}")
}
return valueAsSlice, nil
}

func (b BaseStateParser) GetMapOrDefault(stateMap map[string]interface{}, key string, defaultValue map[string]interface{}) (map[string]interface{}, error) {
value := stateMap[key]

if value == nil {
return defaultValue, nil
}

valueAsMap, ok := value.(map[string]interface{})
if !ok {
return defaultValue, nil
}
return valueAsMap, nil
}

func (b BaseStateParser) GetBool(stateName string, stateMap map[string]interface{}, key string) (bool, error) {
value := stateMap[key]

if value == nil {
return false, errors.New("State [" + stateName + "] " + key + " not exist")
}

valueAsBool, ok := value.(bool)
if !ok {
return false, errors.New("State [" + stateName + "] " + key + " illegal, required bool")
}
return valueAsBool, nil
}

func (b BaseStateParser) GetBoolOrDefault(stateName string, stateMap map[string]interface{}, key string, defaultValue bool) (bool, error) {
value := stateMap[key]

if value == nil {
return defaultValue, nil
}

valueAsBool, ok := value.(bool)
if !ok {
return defaultValue, errors.New("State [" + stateName + "] " + key + " illegal, required bool")
}
return valueAsBool, nil
}

func (b BaseStateParser) GetIntOrDefault(stateName string, stateMap map[string]interface{}, key string, defaultValue int) (int, error) {
value := stateMap[key]

if value == nil {
return defaultValue, nil
}

// use float64 conversion when the configuration file is json, and use int conversion when the configuration file is yaml
valueAsFloat64, okToFloat64 := value.(float64)
valueAsInt, okToInt := value.(int)
if !okToFloat64 && !okToInt {
return defaultValue, errors.New("State [" + stateName + "] " + key + " illegal, required int")
}

if okToFloat64 {
floatStr := strconv.FormatFloat(valueAsFloat64, 'f', -1, 64)
if strings.Contains(floatStr, ".") {
return defaultValue, errors.New("State [" + stateName + "] " + key + " illegal, required int")
}

return int(valueAsFloat64), nil
}

return valueAsInt, nil
}

func (b BaseStateParser) GetFloat64OrDefault(stateName string, stateMap map[string]interface{}, key string, defaultValue float64) (float64, error) {
value := stateMap[key]

if value == nil {
return defaultValue, nil
}

// use float64 conversion when the configuration file is json, and use int conversion when the configuration file is yaml
valueAsFloat64, okToFloat64 := value.(float64)
valueAsInt, okToInt := value.(int)
if !okToFloat64 && !okToInt {
return defaultValue, errors.New("State [" + stateName + "] " + key + " illegal, required float64")
}

if okToFloat64 {
return valueAsFloat64, nil
}
return float64(valueAsInt), nil
}

type StateParserFactory interface {
RegistryStateParser(stateType string, stateParser StateParser)

GetStateParser(stateType string) StateParser
}

type DefaultStateParserFactory struct {
stateParserMap map[string]StateParser
mutex sync.Mutex
}

func NewDefaultStateParserFactory() *DefaultStateParserFactory {
var stateParserMap map[string]StateParser = make(map[string]StateParser)
return &DefaultStateParserFactory{
stateParserMap: stateParserMap,
}
}

// InitDefaultStateParser init StateParser by default
func (d *DefaultStateParserFactory) InitDefaultStateParser() {
choiceStateParser := NewChoiceStateParser()
serviceTaskStateParser := NewServiceTaskStateParser()
subStateMachineParser := NewSubStateMachineParser()
succeedEndStateParser := NewSucceedEndStateParser()
compensationTriggerStateParser := NewCompensationTriggerStateParser()
failEndStateParser := NewFailEndStateParser()
scriptTaskStateParser := NewScriptTaskStateParser()

d.RegistryStateParser(choiceStateParser.StateType(), choiceStateParser)
d.RegistryStateParser(serviceTaskStateParser.StateType(), serviceTaskStateParser)
d.RegistryStateParser(subStateMachineParser.StateType(), subStateMachineParser)
d.RegistryStateParser(succeedEndStateParser.StateType(), succeedEndStateParser)
d.RegistryStateParser(compensationTriggerStateParser.StateType(), compensationTriggerStateParser)
d.RegistryStateParser(compensationTriggerStateParser.StateType(), compensationTriggerStateParser)
d.RegistryStateParser(failEndStateParser.StateType(), failEndStateParser)
d.RegistryStateParser(scriptTaskStateParser.StateType(), scriptTaskStateParser)
}

func (d *DefaultStateParserFactory) RegistryStateParser(stateType string, stateParser StateParser) {
d.mutex.Lock()
defer d.mutex.Unlock()
d.stateParserMap[stateType] = stateParser
}

func (d *DefaultStateParserFactory) GetStateParser(stateType string) StateParser {
return d.stateParserMap[stateType]
}

+ 101
- 0
pkg/saga/statemachine/statelang/parser/sub_state_machine_parser.go View File

@@ -0,0 +1,101 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package parser

import (
"fmt"
"github.com/pkg/errors"
"github.com/seata/seata-go/pkg/saga/statemachine/constant"
"github.com/seata/seata-go/pkg/saga/statemachine/statelang"
"github.com/seata/seata-go/pkg/saga/statemachine/statelang/state"
)

type SubStateMachineParser struct {
*AbstractTaskStateParser
}

func NewSubStateMachineParser() *SubStateMachineParser {
return &SubStateMachineParser{
NewAbstractTaskStateParser(),
}
}

func (s SubStateMachineParser) StateType() string {
return constant.StateTypeSubStateMachine
}

func (s SubStateMachineParser) Parse(stateName string, stateMap map[string]interface{}) (statelang.State, error) {
subStateMachineImpl := state.NewSubStateMachineImpl()

err := s.ParseTaskAttributes(stateName, subStateMachineImpl.AbstractTaskState, stateMap)
if err != nil {
return nil, err
}

stateMachineName, err := s.BaseStateParser.GetString(stateName, stateMap, "StateMachineName")
if err != nil {
return nil, err
}
subStateMachineImpl.SetName(stateMachineName)

if subStateMachineImpl.CompensateState() == "" {
// build default SubStateMachine compensate state
compensateSubStateMachineStateParser := NewCompensateSubStateMachineStateParser()
compensateState, err := compensateSubStateMachineStateParser.Parse(stateName, nil)
if err != nil {
return nil, err
}
compensateStateImpl, ok := compensateState.(state.TaskState)
if !ok {
return nil, errors.New(fmt.Sprintf("State [name:%s] has wrong compensateState type", stateName))
}
subStateMachineImpl.SetCompensateStateImpl(compensateStateImpl)
subStateMachineImpl.SetCompensateState(compensateStateImpl.Name())
}
return subStateMachineImpl, nil
}

type CompensateSubStateMachineStateParser struct {
*AbstractTaskStateParser
}

func NewCompensateSubStateMachineStateParser() *CompensateSubStateMachineStateParser {
return &CompensateSubStateMachineStateParser{
NewAbstractTaskStateParser(),
}
}

func (c CompensateSubStateMachineStateParser) StateType() string {
return constant.StateTypeCompensateSubMachine
}

func (c CompensateSubStateMachineStateParser) Parse(stateName string, stateMap map[string]interface{}) (statelang.State, error) {
compensateSubStateMachineStateImpl := state.NewCompensateSubStateMachineStateImpl()
compensateSubStateMachineStateImpl.SetForCompensation(true)

if stateMap != nil {
err := c.ParseTaskAttributes(stateName, compensateSubStateMachineStateImpl.ServiceTaskStateImpl.AbstractTaskState, stateMap)
if err != nil {
return nil, err
}
}
if compensateSubStateMachineStateImpl.Name() == "" {
compensateSubStateMachineStateImpl.SetName(constant.CompensateSubMachineStateNamePrefix + compensateSubStateMachineStateImpl.Hashcode())
}
return compensateSubStateMachineStateImpl, nil
}

+ 347
- 0
pkg/saga/statemachine/statelang/parser/task_state_json_parser.go View File

@@ -0,0 +1,347 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package parser

import (
"fmt"
"github.com/pkg/errors"
"github.com/seata/seata-go/pkg/saga/statemachine/constant"
"github.com/seata/seata-go/pkg/saga/statemachine/statelang"
"github.com/seata/seata-go/pkg/saga/statemachine/statelang/state"
)

type AbstractTaskStateParser struct {
*BaseStateParser
}

func NewAbstractTaskStateParser() *AbstractTaskStateParser {
return &AbstractTaskStateParser{
&BaseStateParser{},
}
}

func (a *AbstractTaskStateParser) ParseTaskAttributes(stateName string, state *state.AbstractTaskState, stateMap map[string]interface{}) error {
err := a.ParseBaseAttributes(state.Name(), state.BaseState, stateMap)
if err != nil {
return err
}

compensateState, err := a.GetStringOrDefault(stateName, stateMap, "CompensateState", "")
if err != nil {
return err
}
state.SetCompensateState(compensateState)

isForCompensation, err := a.GetBoolOrDefault(stateName, stateMap, "IsForCompensation", false)
if err != nil {
return err
}
state.SetForCompensation(isForCompensation)

isForUpdate, err := a.GetBoolOrDefault(stateName, stateMap, "IsForUpdate", false)
if err != nil {
return err
}
state.SetForUpdate(isForUpdate)

isPersist, err := a.GetBoolOrDefault(stateName, stateMap, "IsPersist", false)
if err != nil {
return err
}
state.SetPersist(isPersist)

isRetryPersistModeUpdate, err := a.GetBoolOrDefault(stateName, stateMap, "IsRetryPersistModeUpdate", false)
if err != nil {
return err
}
state.SetRetryPersistModeUpdate(isRetryPersistModeUpdate)

isCompensatePersistModeUpdate, err := a.GetBoolOrDefault(stateName, stateMap, "IsCompensatePersistModeUpdate", false)
if err != nil {
return err
}
state.SetCompensatePersistModeUpdate(isCompensatePersistModeUpdate)

retryInterfaces, err := a.GetSliceOrDefault(stateName, stateMap, "Retry", nil)
if err != nil {
return err
}
if retryInterfaces != nil {
retries, err := a.parseRetries(state.Name(), retryInterfaces)
if err != nil {
return err
}
state.SetRetry(retries)
}

catchInterfaces, err := a.GetSliceOrDefault(stateName, stateMap, "Catch", nil)
if err != nil {
return err
}
if catchInterfaces != nil {
catches, err := a.parseCatches(state.Name(), catchInterfaces)
if err != nil {
return err
}
state.SetCatches(catches)
}

inputInterfaces, err := a.GetSliceOrDefault(stateName, stateMap, "Input", nil)
if err != nil {
return err
}
if inputInterfaces != nil {
state.SetInput(inputInterfaces)
}

output, err := a.GetMapOrDefault(stateMap, "Output", nil)
if err != nil {
return err
}
if output != nil {
state.SetOutput(output)
}

statusMap, ok := stateMap["Status"].(map[string]string)
if ok {
state.SetStatus(statusMap)
}

loopMap, ok := stateMap["Loop"].(map[string]interface{})
if ok {
loop := a.parseLoop(stateName, loopMap)
state.SetLoop(loop)
}

return nil
}

func (a *AbstractTaskStateParser) parseLoop(stateName string, loopMap map[string]interface{}) state.Loop {
loopImpl := &state.LoopImpl{}
parallel, err := a.GetIntOrDefault(stateName, loopMap, "Parallel", 1)
if err != nil {
return nil
}
loopImpl.SetParallel(parallel)

collection, err := a.GetStringOrDefault(stateName, loopMap, "Collection", "")
if err != nil {
return nil
}
loopImpl.SetCollection(collection)

elementVariableName, err := a.GetStringOrDefault(stateName, loopMap, "ElementVariableName", "loopElement")
if err != nil {
return nil
}
loopImpl.SetElementVariableName(elementVariableName)

elementIndexName, err := a.GetStringOrDefault(stateName, loopMap, "ElementIndexName", "loopCounter")
if err != nil {
return nil
}
loopImpl.SetElementIndexName(elementIndexName)

completionCondition, err := a.GetStringOrDefault(stateName, loopMap, "CompletionCondition", "[nrOfInstances] == [nrOfCompletedInstances]")
if err != nil {
return nil
}
loopImpl.SetElementIndexName(completionCondition)
return loopImpl
}

func (a *AbstractTaskStateParser) parseRetries(stateName string, retryInterfaces []interface{}) ([]state.Retry, error) {
retries := make([]state.Retry, 0)
for _, retryInterface := range retryInterfaces {
retryMap, ok := retryInterface.(map[string]interface{})
if !ok {

return nil, errors.New("State [" + stateName + "] " + "Retry illegal, require map[string]interface{}")
}
retry := &state.RetryImpl{}
exceptions, err := a.GetSliceOrDefault(stateName, retryMap, "Exceptions", nil)
if err != nil {
return nil, err
}
if exceptions != nil {
errors := make([]string, 0)
for _, errorType := range exceptions {
errors = append(errors, errorType.(string))
}
retry.SetExceptions(errors)
}

maxAttempts, err := a.GetIntOrDefault(stateName, retryMap, "MaxAttempts", 0)
if err != nil {
return nil, err
}
retry.SetMaxAttempt(maxAttempts)

backoffInterval, err := a.GetFloat64OrDefault(stateName, retryMap, "BackoffInterval", 0)
if err != nil {
return nil, err
}
retry.SetBackoffRate(backoffInterval)

intervalSeconds, err := a.GetFloat64OrDefault(stateName, retryMap, "IntervalSeconds", 0)
if err != nil {
return nil, err
}
retry.SetIntervalSecond(intervalSeconds)
retries = append(retries, retry)
}
return retries, nil
}

func (a *AbstractTaskStateParser) parseCatches(stateName string, catchInterfaces []interface{}) ([]state.ExceptionMatch, error) {
errorMatches := make([]state.ExceptionMatch, 0, len(catchInterfaces))
for _, catchInterface := range catchInterfaces {
catchMap, ok := catchInterface.(map[string]interface{})
if !ok {
return nil, errors.New("State [" + stateName + "] " + "Catch illegal, require map[string]interface{}")
}
errorMatch := &state.ExceptionMatchImpl{}
errorInterfaces, err := a.GetSliceOrDefault(stateName, catchMap, "Exceptions", nil)
if err != nil {
return nil, err
}
if errorInterfaces != nil {
errorNames := make([]string, 0)
for _, errorType := range errorInterfaces {
errorNames = append(errorNames, errorType.(string))
}
errorMatch.SetExceptions(errorNames)
}
next, err := a.GetStringOrDefault(stateName, catchMap, "Next", "")
if err != nil {
return nil, err
}
errorMatch.SetNext(next)
errorMatches = append(errorMatches, errorMatch)
}
return errorMatches, nil
}

type ServiceTaskStateParser struct {
*AbstractTaskStateParser
}

func NewServiceTaskStateParser() *ServiceTaskStateParser {
return &ServiceTaskStateParser{
NewAbstractTaskStateParser(),
}
}

func (s ServiceTaskStateParser) StateType() string {
return constant.StateTypeServiceTask
}

func (s ServiceTaskStateParser) Parse(stateName string, stateMap map[string]interface{}) (statelang.State, error) {
serviceTaskStateImpl := state.NewServiceTaskStateImpl()

err := s.ParseTaskAttributes(stateName, serviceTaskStateImpl.AbstractTaskState, stateMap)
if err != nil {
return nil, err
}

serviceName, err := s.GetString(stateName, stateMap, "ServiceName")
if err != nil {
return nil, err
}
serviceTaskStateImpl.SetServiceName(serviceName)

serviceMethod, err := s.GetString(stateName, stateMap, "ServiceMethod")
if err != nil {
return nil, err
}
serviceTaskStateImpl.SetServiceMethod(serviceMethod)

serviceType, err := s.GetStringOrDefault(stateName, stateMap, "ServiceType", "")
if err != nil {
return nil, err
}
serviceTaskStateImpl.SetServiceType(serviceType)

parameterTypeInterfaces, err := s.GetSliceOrDefault(stateName, stateMap, "ParameterTypes", nil)
if err != nil {
return nil, err
}
if parameterTypeInterfaces != nil {
var parameterTypes []string
for i := range parameterTypeInterfaces {
parameterType, ok := parameterTypeInterfaces[i].(string)
if !ok {
return nil, errors.New(fmt.Sprintf("State [%s] parameterType required string", stateName))
}

parameterTypes = append(parameterTypes, parameterType)
}
serviceTaskStateImpl.SetParameterTypes(parameterTypes)
}

isAsync, err := s.GetBoolOrDefault(stateName, stateMap, "IsAsync", false)
if err != nil {
return nil, err
}
serviceTaskStateImpl.SetIsAsync(isAsync)

return serviceTaskStateImpl, nil
}

type ScriptTaskStateParser struct {
*AbstractTaskStateParser
}

func NewScriptTaskStateParser() *ScriptTaskStateParser {
return &ScriptTaskStateParser{
NewAbstractTaskStateParser(),
}
}

func (s ScriptTaskStateParser) StateType() string {
return constant.StateTypeScriptTask
}

func (s ScriptTaskStateParser) Parse(stateName string, stateMap map[string]interface{}) (statelang.State, error) {
scriptTaskStateImpl := state.NewScriptTaskStateImpl()

err := s.ParseTaskAttributes(stateName, scriptTaskStateImpl.AbstractTaskState, stateMap)
if err != nil {
return nil, err
}

scriptType, err := s.GetStringOrDefault(stateName, stateMap, "ScriptType", "")
if err != nil {
return nil, err
}
if scriptType != "" {
scriptTaskStateImpl.SetScriptType(scriptType)
}

scriptContent, err := s.GetStringOrDefault(stateName, stateMap, "ScriptContent", "")
if err != nil {
return nil, err
}
scriptTaskStateImpl.SetScriptContent(scriptContent)

scriptTaskStateImpl.SetForCompensation(false)
scriptTaskStateImpl.SetForUpdate(false)
scriptTaskStateImpl.SetPersist(false)

return scriptTaskStateImpl, nil
}

+ 92
- 0
pkg/saga/statemachine/statelang/state.go View File

@@ -0,0 +1,92 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package statelang

type State interface {
Name() string

SetName(name string)

Comment() string

SetComment(comment string)

Type() string

SetType(typeName string)

Next() string

SetNext(next string)

StateMachine() StateMachine

SetStateMachine(machine StateMachine)
}

type BaseState struct {
name string `alias:"Name"`
comment string `alias:"Comment"`
typeName string `alias:"Type"`
next string `alias:"Next"`
stateMachine StateMachine
}

func NewBaseState() *BaseState {
return &BaseState{}
}

func (b *BaseState) Name() string {
return b.name
}

func (b *BaseState) SetName(name string) {
b.name = name
}

func (b *BaseState) Comment() string {
return b.comment
}

func (b *BaseState) SetComment(comment string) {
b.comment = comment
}

func (b *BaseState) Type() string {
return b.typeName
}

func (b *BaseState) SetType(typeName string) {
b.typeName = typeName
}

func (b *BaseState) Next() string {
return b.next
}

func (b *BaseState) SetNext(next string) {
b.next = next
}

func (b *BaseState) StateMachine() StateMachine {
return b.stateMachine
}

func (b *BaseState) SetStateMachine(machine StateMachine) {
b.stateMachine = machine
}

Some files were not shown because too many files changed in this diff

Loading…
Cancel
Save