Browse Source

Feature/add script manager&actuator (#868)

* add script manager & actuator

* Otto package replaces Goja package

* Resolve the issues raised by Copilot

* add test for vm-pool

* update GetServiceInvoker method
pull/651/merge
flypiggy GitHub 1 month ago
parent
commit
fcc1cf2527
No known key found for this signature in database GPG Key ID: B5690EEEBB952194
8 changed files with 910 additions and 146 deletions
  1. +6
    -4
      go.mod
  2. +10
    -6
      go.sum
  3. +25
    -5
      pkg/saga/statemachine/engine/config/default_statemachine_config.go
  4. +38
    -131
      pkg/saga/statemachine/engine/invoker/invoker.go
  5. +162
    -0
      pkg/saga/statemachine/engine/invoker/javascript_script_invoker.go
  6. +262
    -0
      pkg/saga/statemachine/engine/invoker/javascript_script_invoker_test.go
  7. +195
    -0
      pkg/saga/statemachine/engine/invoker/local_invoker.go
  8. +212
    -0
      pkg/saga/statemachine/engine/invoker/local_invoker_test.go

+ 6
- 4
go.mod View File

@@ -1,6 +1,6 @@
module github.com/seata/seata-go

go 1.18
go 1.20

require (
dubbo.apache.org/dubbo-go/v3 v3.0.4
@@ -35,7 +35,8 @@ require (
github.com/agiledragon/gomonkey/v2 v2.9.0
github.com/google/cel-go v0.18.0
github.com/mattn/go-sqlite3 v1.14.19
golang.org/x/sync v0.6.0
github.com/robertkrimen/otto v0.4.0
golang.org/x/sync v0.16.0
google.golang.org/protobuf v1.33.0
gopkg.in/yaml.v3 v3.0.1
)
@@ -90,8 +91,9 @@ require (
github.com/yusufpapurcu/wmi v1.2.2 // indirect
go.uber.org/multierr v1.8.0 // indirect
golang.org/x/arch v0.3.0 // indirect
golang.org/x/text v0.14.0 // indirect
golang.org/x/text v0.27.0 // indirect
gopkg.in/natefinch/lumberjack.v2 v2.0.0 // indirect
gopkg.in/sourcemap.v1 v1.0.5 // indirect
)

require (
@@ -106,7 +108,7 @@ require (
golang.org/x/crypto v0.17.0 // indirect
golang.org/x/exp v0.0.0-20220827204233-334a2380cb91 // indirect
golang.org/x/net v0.10.0 // indirect
golang.org/x/sys v0.15.0 // indirect
golang.org/x/sys v0.32.0 // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20230803162519-f966b187b2e5 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20230803162519-f966b187b2e5 // indirect
vimagination.zapto.org/memio v0.0.0-20200222190306-588ebc67b97d // indirect


+ 10
- 6
go.sum View File

@@ -672,6 +672,8 @@ github.com/prometheus/tsdb v0.7.1/go.mod h1:qhTCs0VvXwvX/y3TZrWD7rabWM+ijKTux40T
github.com/rcrowley/go-metrics v0.0.0-20181016184325-3113b8401b8a/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4=
github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
github.com/rhnvrm/simples3 v0.6.1/go.mod h1:Y+3vYm2V7Y4VijFoJHHTrja6OgPrJ2cBti8dPGkC3sA=
github.com/robertkrimen/otto v0.4.0 h1:/c0GRrK1XDPcgIasAsnlpBT5DelIeB9U/Z/JCQsgr7E=
github.com/robertkrimen/otto v0.4.0/go.mod h1:uW9yN1CYflmUQYvAMS0m+ZiNo3dMzRUDQJX0jWbzgxw=
github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro=
github.com/rogpeppe/fastuuid v0.0.0-20150106093220-6724a57986af/go.mod h1:XWv6SoW27p1b0cqNHllgS5HIMJraePCO15w5zCzIWYg=
github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ=
@@ -952,8 +954,8 @@ golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJ
golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.6.0 h1:5BMeUDZ7vkXGfEr1x9B4bRcTH4lpkTkpdh0T/J+qjbQ=
golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw=
golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
@@ -1031,8 +1033,8 @@ golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220908164124-27713097b956/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc=
golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.32.0 h1:s77OFDvIQeibCmezSnk/q6iAfkdiQaJi4VzroCFrN20=
golang.org/x/sys v0.32.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
@@ -1045,8 +1047,8 @@ golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.27.0 h1:4fGWRpyh641NLlecmyl4LOe6yDdfaYNrGb2zdfo4JV4=
golang.org/x/text v0.27.0/go.mod h1:1D28KMCvyooCX9hBiosv5Tz/+YLxj0j7XhWjpSUF7CU=
golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
@@ -1251,6 +1253,8 @@ gopkg.in/ini.v1 v1.51.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k=
gopkg.in/natefinch/lumberjack.v2 v2.0.0 h1:1Lc07Kr7qY4U2YPouBjpCLxpiyxIVoxqXgkXLknAOE8=
gopkg.in/natefinch/lumberjack.v2 v2.0.0/go.mod h1:l0ndWWf7gzL7RNwBG7wST/UCcT4T24xpD6X8LsfU/+k=
gopkg.in/resty.v1 v1.12.0/go.mod h1:mDo4pnntr5jdWRML875a/NmxYqAlA73dVijT2AXvQQo=
gopkg.in/sourcemap.v1 v1.0.5 h1:inv58fC9f9J3TK2Y2R1NPntXEn3/wjWHkonhIUODNTI=
gopkg.in/sourcemap.v1 v1.0.5/go.mod h1:2RlvNNSMglmRrcvhfuzp4hQHwOtjxlbjX7UPY/GXb78=
gopkg.in/square/go-jose.v2 v2.3.1/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76mk0e1AI=
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw=
gopkg.in/warnings.v0 v0.1.2/go.mod h1:jksf8JmL6Qr/oQM2OXTHunEvvTAsrWBLb6OOjuVWRNI=


+ 25
- 5
pkg/saga/statemachine/engine/config/default_statemachine_config.go View File

@@ -292,8 +292,17 @@ func (c *DefaultStateMachineConfig) GetExpressionFactory(expressionType string)
return c.expressionFactoryManager.GetExpressionFactory(expressionType)
}

func (c *DefaultStateMachineConfig) GetServiceInvoker(serviceType string) invoker.ServiceInvoker {
return c.serviceInvokerManager.ServiceInvoker(serviceType)
func (c *DefaultStateMachineConfig) GetServiceInvoker(serviceType string) (invoker.ServiceInvoker, error) {
if serviceType == "" {
serviceType = "local"
}

invoker := c.serviceInvokerManager.ServiceInvoker(serviceType)
if invoker == nil {
return nil, fmt.Errorf("service invoker not found for type: %s", serviceType)
}

return invoker, nil
}

func (c *DefaultStateMachineConfig) RegisterStateMachineDef(resources []string) error {
@@ -492,9 +501,20 @@ func (c *DefaultStateMachineConfig) initServiceInvokers() error {
c.serviceInvokerManager = invoker.NewServiceInvokerManagerImpl()
}

defaultServiceType := "local"
if existingInvoker := c.serviceInvokerManager.ServiceInvoker(defaultServiceType); existingInvoker == nil {
c.RegisterServiceInvoker(defaultServiceType, invoker.NewLocalServiceInvoker())
if existing := c.serviceInvokerManager.ServiceInvoker("local"); existing == nil {
c.RegisterServiceInvoker("local", invoker.NewLocalServiceInvoker())
}

if existing := c.serviceInvokerManager.ServiceInvoker("http"); existing == nil {
c.RegisterServiceInvoker("http", invoker.NewHTTPInvoker())
}

if existing := c.serviceInvokerManager.ServiceInvoker("grpc"); existing == nil {
c.RegisterServiceInvoker("grpc", invoker.NewGRPCInvoker())
}

if existing := c.serviceInvokerManager.ServiceInvoker("func"); existing == nil {
c.RegisterServiceInvoker("func", invoker.NewFuncInvoker())
}

return nil


+ 38
- 131
pkg/saga/statemachine/engine/invoker/invoker.go View File

@@ -20,7 +20,6 @@ package invoker
import (
"context"
"encoding/json"
"fmt"
"reflect"
"sync"

@@ -43,164 +42,72 @@ func (p *DefaultJsonParser) Marshal(v any) ([]byte, error) {
}

type ScriptInvokerManager interface {
GetInvoker(scriptType string) (ScriptInvoker, error)
RegisterInvoker(invoker ScriptInvoker)
Execute(ctx context.Context, scriptType string, script string, params map[string]interface{}) (interface{}, error)
}

type ScriptInvoker interface {
}

type ServiceInvokerManager interface {
ServiceInvoker(serviceType string) ServiceInvoker
PutServiceInvoker(serviceType string, invoker ServiceInvoker)
}

type ServiceInvoker interface {
Invoke(ctx context.Context, input []any, service state.ServiceTaskState) (output []reflect.Value, err error)
Invoke(ctx context.Context, script string, params map[string]interface{}) (interface{}, error)
Type() string
Close(ctx context.Context) error
}

type ServiceInvokerManagerImpl struct {
invokers map[string]ServiceInvoker
type ScriptInvokerManagerImpl struct {
invokers map[string]ScriptInvoker
mutex sync.Mutex
}

type LocalServiceInvoker struct {
serviceRegistry map[string]interface{}
methodCache map[string]*reflect.Method
jsonParser JsonParser
mutex sync.RWMutex
}

func NewLocalServiceInvoker() *LocalServiceInvoker {
return &LocalServiceInvoker{
serviceRegistry: make(map[string]interface{}),
methodCache: make(map[string]*reflect.Method),
jsonParser: &DefaultJsonParser{},
}
}

func (l *LocalServiceInvoker) RegisterService(serviceName string, instance interface{}) {
l.mutex.Lock()
defer l.mutex.Unlock()
l.serviceRegistry[serviceName] = instance
}

func (l *LocalServiceInvoker) Invoke(ctx context.Context, input []any, service state.ServiceTaskState) ([]reflect.Value, error) {
serviceName := service.ServiceName()
instance, exists := l.serviceRegistry[serviceName]
if !exists {
return nil, fmt.Errorf("service %s not registered", serviceName)
}

methodName := service.ServiceMethod()
method, err := l.getMethod(serviceName, methodName, service.ParameterTypes())
if err != nil {
return nil, err
}

params, err := l.resolveParameters(input, method.Type)
if err != nil {
return nil, err
func NewScriptInvokerManager() *ScriptInvokerManagerImpl {
return &ScriptInvokerManagerImpl{
invokers: make(map[string]ScriptInvoker),
}

return l.invokeMethod(instance, method, params), nil
}

func (l *LocalServiceInvoker) resolveMethod(key, serviceName, methodName string) (*reflect.Method, error) {
l.mutex.Lock()
defer l.mutex.Unlock()

if cachedMethod, ok := l.methodCache[key]; ok {
return cachedMethod, nil
func (m *ScriptInvokerManagerImpl) GetInvoker(scriptType string) (ScriptInvoker, error) {
if scriptType == "" {
return nil, nil
}
m.mutex.Lock()
defer m.mutex.Unlock()

instance, exists := l.serviceRegistry[serviceName]
invoker, exists := m.invokers[scriptType]
if !exists {
return nil, fmt.Errorf("service %s not found", serviceName)
}

objType := reflect.TypeOf(instance)
method, ok := objType.MethodByName(methodName)
if !ok {
return nil, fmt.Errorf("method %s not found in service %s", methodName, serviceName)
return nil, nil
}

l.methodCache[key] = &method
return &method, nil
return invoker, nil
}

func (l *LocalServiceInvoker) getMethod(serviceName, methodName string, paramTypes []string) (*reflect.Method, error) {
key := fmt.Sprintf("%s.%s", serviceName, methodName)

l.mutex.RLock()
if method, ok := l.methodCache[key]; ok {
l.mutex.RUnlock()
return method, nil
func (m *ScriptInvokerManagerImpl) RegisterInvoker(invoker ScriptInvoker) {
if invoker == nil || invoker.Type() == "" {
return
}
l.mutex.RUnlock()
return l.resolveMethod(key, serviceName, methodName)
m.mutex.Lock()
defer m.mutex.Unlock()
m.invokers[invoker.Type()] = invoker
}

func (l *LocalServiceInvoker) resolveParameters(input []any, methodType reflect.Type) ([]reflect.Value, error) {
params := make([]reflect.Value, methodType.NumIn())
for i := 0; i < methodType.NumIn(); i++ {
paramType := methodType.In(i)
if i >= len(input) {
params[i] = reflect.Zero(paramType)
continue
}

converted, err := l.convertParam(input[i], paramType)
if err != nil {
return nil, err
}
params[i] = reflect.ValueOf(converted)
func (m *ScriptInvokerManagerImpl) Execute(ctx context.Context, scriptType string, script string, params map[string]interface{}) (interface{}, error) {
invoker, err := m.GetInvoker(scriptType)
if err != nil || invoker == nil {
return nil, err
}
return params, nil
return invoker.Invoke(ctx, script, params)
}

func (l *LocalServiceInvoker) convertParam(value any, targetType reflect.Type) (any, error) {
if targetType.Kind() == reflect.Ptr {
targetType = targetType.Elem()
value = reflect.ValueOf(value).Interface()
}

if targetType.Kind() == reflect.Int && reflect.TypeOf(value).Kind() == reflect.Float64 {
return int(value.(float64)), nil
} else if targetType == reflect.TypeOf("") && reflect.TypeOf(value).Kind() == reflect.Int {
return fmt.Sprintf("%d", value), nil
}

if targetType.Kind() == reflect.Struct {
jsonData, err := l.jsonParser.Marshal(value)
if err != nil {
return nil, err
}
instance := reflect.New(targetType).Interface()
if err := l.jsonParser.Unmarshal(jsonData, instance); err != nil {
return nil, err
}
return instance, nil
}

return value, nil
type ServiceInvokerManager interface {
ServiceInvoker(serviceType string) ServiceInvoker
PutServiceInvoker(serviceType string, invoker ServiceInvoker)
}

func (l *LocalServiceInvoker) invokeMethod(instance interface{}, method *reflect.Method, params []reflect.Value) []reflect.Value {
instanceValue := reflect.ValueOf(instance)
if method.Func.IsValid() {
allParams := append([]reflect.Value{instanceValue}, params...)
return method.Func.Call(allParams)
}
return nil
type ServiceInvoker interface {
Invoke(ctx context.Context, input []any, service state.ServiceTaskState) (output []reflect.Value, err error)
Close(ctx context.Context) error
}

func (l *LocalServiceInvoker) Close(ctx context.Context) error {
l.mutex.Lock()
defer l.mutex.Unlock()
l.serviceRegistry = nil
l.methodCache = nil
return nil
type ServiceInvokerManagerImpl struct {
invokers map[string]ServiceInvoker
mutex sync.Mutex
}

func NewServiceInvokerManagerImpl() *ServiceInvokerManagerImpl {


+ 162
- 0
pkg/saga/statemachine/engine/invoker/javascript_script_invoker.go View File

@@ -0,0 +1,162 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package invoker

import (
"context"
"fmt"
"sync"

"github.com/robertkrimen/otto"
)

const defaultPoolSize = 10

type JavaScriptScriptInvoker struct {
mutex sync.Mutex
jsonParser JsonParser
closed bool
vmPool chan *otto.Otto
poolSize int
}

func NewJavaScriptScriptInvoker() *JavaScriptScriptInvoker {
return &JavaScriptScriptInvoker{
jsonParser: &DefaultJsonParser{},
closed: false,
poolSize: defaultPoolSize,
vmPool: make(chan *otto.Otto, defaultPoolSize),
}
}

func NewJavaScriptScriptInvokerWithPoolSize(poolSize int) *JavaScriptScriptInvoker {
if poolSize <= 0 {
poolSize = defaultPoolSize
}
return &JavaScriptScriptInvoker{
jsonParser: &DefaultJsonParser{},
closed: false,
poolSize: poolSize,
vmPool: make(chan *otto.Otto, poolSize),
}
}

func (j *JavaScriptScriptInvoker) Type() string {
return "javascript"
}

func (j *JavaScriptScriptInvoker) Invoke(ctx context.Context, script string, params map[string]interface{}) (interface{}, error) {
j.mutex.Lock()
closed := j.closed
j.mutex.Unlock()

if closed {
return nil, fmt.Errorf("javascript invoker has been closed")
}

var vm *otto.Otto
select {
case vm = <-j.vmPool:
if err := cleanVMState(vm); err != nil {
vm = otto.New()
}
default:
vm = otto.New()
}

defer func() {
j.mutex.Lock()
defer j.mutex.Unlock()
if !j.closed {
select {
case j.vmPool <- vm:
default:
// Pool full, discard current instance
}
}
}()

for key, value := range params {
if err := vm.Set(key, value); err != nil {
return nil, fmt.Errorf("javascript set param %s error: %w", key, err)
}
}

resultChan := make(chan struct {
val otto.Value
err error
}, 1)

go func() {
defer func() {
if r := recover(); r != nil {
resultChan <- struct {
val otto.Value
err error
}{otto.UndefinedValue(), fmt.Errorf("javascript engine panic: %v", r)}
}
}()

val, err := vm.Run(script)
resultChan <- struct {
val otto.Value
err error
}{val, err}
}()

select {
case <-ctx.Done():
return nil, fmt.Errorf("javascript execution timeout: %w", ctx.Err())
case res := <-resultChan:
if res.err != nil {
return nil, fmt.Errorf("javascript execute error: %w", res.err)
}
val, err := res.val.Export()
if err != nil {
return nil, fmt.Errorf("failed to export javascript result: %w", err)
}
return val, nil
}
}

func (j *JavaScriptScriptInvoker) Close(ctx context.Context) error {
j.mutex.Lock()
defer j.mutex.Unlock()

if j.closed {
return nil
}

j.closed = true
close(j.vmPool)
for range j.vmPool {
// Let GC recycle VM resources
}
return nil
}

func cleanVMState(vm *otto.Otto) error {
_, err := vm.Run(`
for (const prop in global) {
if (!['Object', 'Array', 'Function', 'String', 'Number', 'Boolean', 'JSON', 'Date', 'RegExp'].includes(prop)) {
delete global[prop];
}
}
`)
return err
}

+ 262
- 0
pkg/saga/statemachine/engine/invoker/javascript_script_invoker_test.go View File

@@ -0,0 +1,262 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package invoker

import (
"context"
"sync"
"testing"
"time"

"github.com/robertkrimen/otto"
"github.com/stretchr/testify/assert"
)

func TestJavaScriptScriptInvoker_Type(t *testing.T) {
invoker := NewJavaScriptScriptInvoker()
assert.Equal(t, "javascript", invoker.Type())
}

func TestJavaScriptScriptInvoker_Invoke_Basic(t *testing.T) {
tests := []struct {
name string
script string
params map[string]interface{}
expected interface{}
}{
{
name: "simple expression",
script: "1 + 2",
params: nil,
expected: float64(3),
},
{
name: "param calculation",
script: "a * b + c",
params: map[string]interface{}{"a": 2, "b": 3, "c": 4},
expected: float64(10),
},
{
name: "return string",
script: "['hello', name].join(' ')",
params: map[string]interface{}{"name": "world"},
expected: "hello world",
},
{
name: "return object",
script: `var obj = {id: 1, name: name}; obj;`,
params: map[string]interface{}{"name": "test"},
expected: map[string]interface{}{"id": float64(1), "name": "test"},
},
}

invoker := NewJavaScriptScriptInvoker()
ctx := context.Background()

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := invoker.Invoke(ctx, tt.script, tt.params)
assert.NoError(t, err)

if resultMap, ok := result.(map[string]interface{}); ok {
for k, v := range resultMap {
if intVal, isInt := v.(int64); isInt {
resultMap[k] = float64(intVal)
}
}
}

assert.Equal(t, tt.expected, result)
})
}
}

func TestJavaScriptScriptInvoker_Invoke_Error(t *testing.T) {
tests := []struct {
name string
script string
params map[string]interface{}
errMsg string
}{
{
name: "syntax error",
script: "1 + ",
params: nil,
errMsg: "javascript execute error",
},
{
name: "reference undefined variable",
script: "undefinedVar",
params: nil,
errMsg: "javascript execute error",
},
}

invoker := NewJavaScriptScriptInvoker()
ctx := context.Background()

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := invoker.Invoke(ctx, tt.script, tt.params)

if err == nil {
t.Fatalf("Test case [%s] expected error but got none", tt.name)
}
assert.Contains(t, err.Error(), tt.errMsg, "Test case [%s] error message mismatch", tt.name)
})
}
}

func TestJavaScriptScriptInvoker_Invoke_Timeout(t *testing.T) {

script := `var target = 300; var start = new Date().getTime(); var elapsed = 0; while (elapsed < target) { elapsed = new Date().getTime() - start; } "done";`
invoker := NewJavaScriptScriptInvoker()

ctx1, cancel1 := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel1()
_, err := invoker.Invoke(ctx1, script, nil)
assert.Error(t, err)
assert.Contains(t, err.Error(), "javascript execution timeout")

ctx2, cancel2 := context.WithTimeout(context.Background(), 400*time.Millisecond)
defer cancel2()
result, err := invoker.Invoke(ctx2, script, nil)
assert.NoError(t, err, "Scenario 2: script execution should not return error")
assert.Equal(t, "done", result, "Scenario 2: should return 'done'")
}

func TestJavaScriptScriptInvoker_Invoke_Concurrent(t *testing.T) {
invoker := NewJavaScriptScriptInvoker()
ctx := context.Background()
var wg sync.WaitGroup
concurrency := 100
errChan := make(chan error, concurrency)

script := `a + b`
params := map[string]interface{}{"a": 10, "b": 20}

for i := 0; i < concurrency; i++ {
wg.Add(1)
go func() {
defer wg.Done()
result, err := invoker.Invoke(ctx, script, params)
if err != nil {
errChan <- err
return
}
if result != float64(30) {
errChan <- assert.AnError
}
}()
}

wg.Wait()
close(errChan)

assert.Empty(t, errChan, "Concurrent execution has errors")
}

func TestJavaScriptScriptInvoker_Close(t *testing.T) {
invoker := NewJavaScriptScriptInvoker()
ctx := context.Background()

result, err := invoker.Invoke(ctx, "1 + 1", nil)
assert.NoError(t, err)
assert.Equal(t, float64(2), result)

err = invoker.Close(ctx)
assert.NoError(t, err)

_, err = invoker.Invoke(ctx, "1 + 1", nil)
assert.Error(t, err)
assert.Contains(t, err.Error(), "javascript invoker has been closed")
}

func TestOttoScript(t *testing.T) {
vm := otto.New()
script := `var target = 300; var start = new Date().getTime(); var elapsed = 0; while (elapsed < target) { elapsed = new Date().getTime() - start; } "done";`
val, err := vm.Run(script)
if err != nil {
t.Fatalf("otto failed to parse script: %v", err)
}

result, exportErr := val.Export()
if exportErr != nil {
t.Fatalf("failed to export otto value: %v", exportErr)
}
t.Logf("Script execution result: %v", result)
}

func TestJavaScriptScriptInvoker_VMPoolReuse(t *testing.T) {
poolSize := 2
invoker := NewJavaScriptScriptInvokerWithPoolSize(poolSize)
ctx := context.Background()

vmIDs := make([]string, 0, 5)

script := `
if (!this.vmId) {
this.vmId = Math.random().toString(36).substr(2, 8);
}
this.vmId;
`

for i := 0; i < 5; i++ {
result, err := invoker.Invoke(ctx, script, nil)
assert.NoError(t, err, "Error occurred while executing script")

id, ok := result.(string)
assert.True(t, ok, "VM ID should be a string type")
vmIDs = append(vmIDs, id)
}

uniqueIDs := make(map[string]bool)
for _, id := range vmIDs {
uniqueIDs[id] = true
}

assert.True(t, len(uniqueIDs) <= 5, "Abnormal number of VM instances created")
assert.True(t, len(uniqueIDs) >= 1, "No VM instances reused from the pool")
}

func TestJavaScriptScriptInvoker_VMStateClean(t *testing.T) {
invoker := NewJavaScriptScriptInvokerWithPoolSize(1)
ctx := context.Background()

_, err := invoker.Invoke(ctx, `this.foo = "polluted data"`, nil)
assert.NoError(t, err)

result, err := invoker.Invoke(ctx, `typeof this.foo`, nil)
assert.NoError(t, err)
assert.Equal(t, "undefined", result, "VM state not cleaned, residual global variable exists")

_, err = invoker.Invoke(ctx, `this.bar = function() { return "residual function"; }`, nil)
assert.NoError(t, err)

result, err = invoker.Invoke(ctx, `typeof this.bar`, nil)
assert.NoError(t, err)
assert.Equal(t, "undefined", result, "VM state not cleaned, residual function exists")
}

func TestJavaScriptScriptInvoker_PoolSizeDefault(t *testing.T) {
invoker := NewJavaScriptScriptInvokerWithPoolSize(0)
assert.Equal(t, defaultPoolSize, invoker.poolSize, "Default pool size not used when pool size is 0")

invoker = NewJavaScriptScriptInvokerWithPoolSize(-5)
assert.Equal(t, defaultPoolSize, invoker.poolSize, "Default pool size not used when pool size is negative")
}

+ 195
- 0
pkg/saga/statemachine/engine/invoker/local_invoker.go View File

@@ -0,0 +1,195 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package invoker

import (
"context"
"fmt"
"github.com/seata/seata-go/pkg/saga/statemachine/statelang/state"
"reflect"
"sync"
)

type LocalServiceInvoker struct {
serviceRegistry map[string]interface{}
methodCache map[string]*reflect.Method
jsonParser JsonParser
mutex sync.RWMutex
}

func NewLocalServiceInvoker() *LocalServiceInvoker {
return &LocalServiceInvoker{
serviceRegistry: make(map[string]interface{}),
methodCache: make(map[string]*reflect.Method),
jsonParser: &DefaultJsonParser{},
}
}

func (l *LocalServiceInvoker) RegisterService(serviceName string, instance interface{}) {
l.mutex.Lock()
defer l.mutex.Unlock()
l.serviceRegistry[serviceName] = instance
}

func (l *LocalServiceInvoker) Invoke(ctx context.Context, input []any, service state.ServiceTaskState) ([]reflect.Value, error) {
serviceName := service.ServiceName()
instance, exists := l.serviceRegistry[serviceName]
if !exists {
return nil, fmt.Errorf("service %s not registered", serviceName)
}

methodName := service.ServiceMethod()
method, err := l.getMethod(serviceName, methodName, service.ParameterTypes())
if err != nil {
return nil, err
}

params, err := l.resolveParameters(input, method.Type)
if err != nil {
return nil, err
}

return l.invokeMethod(instance, method, params), nil
}

func (l *LocalServiceInvoker) resolveMethod(key, serviceName, methodName string) (*reflect.Method, error) {
l.mutex.Lock()
defer l.mutex.Unlock()

if cachedMethod, ok := l.methodCache[key]; ok {
return cachedMethod, nil
}

instance, exists := l.serviceRegistry[serviceName]
if !exists {
return nil, fmt.Errorf("service %s not found", serviceName)
}

objType := reflect.TypeOf(instance)
method, ok := objType.MethodByName(methodName)
if !ok {
return nil, fmt.Errorf("method %s not found in service %s", methodName, serviceName)
}

l.methodCache[key] = &method
return &method, nil
}

func (l *LocalServiceInvoker) getMethod(serviceName, methodName string, paramTypes []string) (*reflect.Method, error) {
key := fmt.Sprintf("%s.%s", serviceName, methodName)

l.mutex.RLock()
if method, ok := l.methodCache[key]; ok {
l.mutex.RUnlock()
return method, nil
}
l.mutex.RUnlock()

return l.resolveMethod(key, serviceName, methodName)
}

func (l *LocalServiceInvoker) resolveParameters(input []any, methodType reflect.Type) ([]reflect.Value, error) {
numIn := methodType.NumIn()
paramStart, paramCount := 1, 0

if numIn > 0 {
paramCount = numIn - paramStart
}

if paramCount == 0 {
if len(input) > 0 {
return nil, fmt.Errorf("unexpected parameters: expected 0, got %d", len(input))
}
return []reflect.Value{}, nil
}

if len(input) < paramCount {
return nil, fmt.Errorf("insufficient parameters: expected %d, got %d", paramCount, len(input))
}

if len(input) > paramCount {
return nil, fmt.Errorf("too many parameters: expected %d, got %d", paramCount, len(input))
}

params := make([]reflect.Value, paramCount)
for i := 0; i < paramCount; i++ {
methodParamIndex := i + paramStart
paramType := methodType.In(methodParamIndex)

converted, err := l.convertParam(input[i], paramType)
if err != nil {
return nil, fmt.Errorf("parameter %d conversion error: %w", i, err)
}

params[i] = reflect.ValueOf(converted)
}

return params, nil
}

func (l *LocalServiceInvoker) convertParam(value any, targetType reflect.Type) (any, error) {
if targetType.Kind() == reflect.Ptr {
elemType := targetType.Elem()
instance := reflect.New(elemType).Interface()
jsonData, err := l.jsonParser.Marshal(value)
if err != nil {
return nil, err
}
if err := l.jsonParser.Unmarshal(jsonData, instance); err != nil {
return nil, err
}
return instance, nil
}

if targetType.Kind() == reflect.Struct {
instance := reflect.New(targetType).Interface()
jsonData, err := l.jsonParser.Marshal(value)
if err != nil {
return nil, err
}
if err := l.jsonParser.Unmarshal(jsonData, instance); err != nil {
return nil, err
}
return reflect.ValueOf(instance).Elem().Interface(), nil
}

if targetType.Kind() == reflect.Int && reflect.TypeOf(value).Kind() == reflect.Float64 {
return int(value.(float64)), nil
} else if targetType == reflect.TypeOf("") && reflect.TypeOf(value).Kind() == reflect.Int {
return fmt.Sprintf("%d", value), nil
}

return value, nil
}

func (l *LocalServiceInvoker) invokeMethod(instance interface{}, method *reflect.Method, params []reflect.Value) []reflect.Value {
instanceValue := reflect.ValueOf(instance)
if method.Func.IsValid() {
allParams := append([]reflect.Value{instanceValue}, params...)
return method.Func.Call(allParams)
}
return nil
}

func (l *LocalServiceInvoker) Close(ctx context.Context) error {
l.mutex.Lock()
defer l.mutex.Unlock()
l.serviceRegistry = nil
l.methodCache = nil
return nil
}

+ 212
- 0
pkg/saga/statemachine/engine/invoker/local_invoker_test.go View File

@@ -0,0 +1,212 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package invoker

import (
"context"
"errors"
"fmt"
"reflect"
"testing"

"github.com/seata/seata-go/pkg/saga/statemachine/statelang/state"
)

type MockLocalService struct {
invokeCount int
}

func (m *MockLocalService) GetServiceName() string {
return "MockLocalService"
}

func (m *MockLocalService) Add(a, b int) int {
m.invokeCount++
return a + b
}

func (m *MockLocalService) Multiply(f float64, i int) float64 {
m.invokeCount++
return f * float64(i)
}

type User struct {
Name string `json:"name"`
Age int `json:"age"`
}

func (m *MockLocalService) GetUserName(user User) string {
m.invokeCount++
return user.Name
}

func (m *MockLocalService) ErrorMethod() error {
return errors.New("expected error")
}

func TestLocalInvoker_ServiceNotRegistered(t *testing.T) {
invoker := NewLocalServiceInvoker()
ctx := context.Background()
taskState := newLocalServiceTaskState("unregisteredService", "AnyMethod")

_, err := invoker.Invoke(ctx, []any{}, taskState)
if err == nil {
t.Error("expected error when service not registered, but got nil")
}
if err.Error() != "service unregisteredService not registered" {
t.Errorf("unexpected error message: %v", err)
}
}

func TestLocalInvoker_MethodNotFound(t *testing.T) {
invoker := NewLocalServiceInvoker()
service := &MockLocalService{}
invoker.RegisterService("mockService", service)

ctx := context.Background()
taskState := newLocalServiceTaskState("mockService", "NonExistentMethod")

_, err := invoker.Invoke(ctx, []any{}, taskState)
if err == nil {
t.Error("expected error when method not found, but got nil")
}
if err.Error() != "method NonExistentMethod not found in service mockService" {
t.Errorf("unexpected error message: %v", err)
}
}

func TestLocalInvoker_InvokeSuccess(t *testing.T) {
tests := []struct {
name string
service interface{}
serviceName string
methodName string
input []any
expected interface{}
}{
{
name: "test basic method call",
service: &MockLocalService{},
serviceName: "mockService",
methodName: "GetServiceName",
input: []any{},
expected: "MockLocalService",
},
{
name: "test method with parameters",
service: &MockLocalService{},
serviceName: "mockService",
methodName: "Add",
input: []any{2, 3},
expected: 5,
},
{
name: "test parameter type conversion",
service: &MockLocalService{},
serviceName: "mockService",
methodName: "Multiply",
input: []any{2.5, 4},
expected: 10.0,
},
}

invoker := NewLocalServiceInvoker()
ctx := context.Background()

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
invoker.RegisterService(tt.serviceName, tt.service)
taskState := newLocalServiceTaskState(tt.serviceName, tt.methodName)

results, err := invoker.Invoke(ctx, tt.input, taskState)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}

if len(results) == 0 {
t.Fatal("no results returned")
}

result := results[0].Interface()
if !reflect.DeepEqual(result, tt.expected) {
t.Errorf("expected %v, got %v", tt.expected, result)
}
})
}
}

func TestLocalInvoker_StructParameterConversion(t *testing.T) {
invoker := NewLocalServiceInvoker()
service := &MockLocalService{}
invoker.RegisterService("userService", service)

ctx := context.Background()
taskState := newLocalServiceTaskState("userService", "GetUserName")

input := []any{map[string]interface{}{"name": "Alice", "age": 30}}
results, err := invoker.Invoke(ctx, input, taskState)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}

if len(results) == 0 {
t.Fatal("no results returned")
}

result := results[0].Interface()
if result != "Alice" {
t.Errorf("expected 'Alice', got %v", result)
}
}

func TestLocalInvoker_MethodCaching(t *testing.T) {
invoker := NewLocalServiceInvoker()
service := &MockLocalService{}
invoker.RegisterService("cacheTestService", service)

ctx := context.Background()
taskState := newLocalServiceTaskState("cacheTestService", "Add")

_, err := invoker.Invoke(ctx, []any{1, 1}, taskState)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}

results, err := invoker.Invoke(ctx, []any{2, 3}, taskState)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}

if results[0].Interface() != 5 {
t.Errorf("expected 5, got %v", results[0].Interface())
}

if service.invokeCount != 2 {
t.Errorf("expected 2 invocations, got %d", service.invokeCount)
}
}

func newLocalServiceTaskState(serviceName, methodName string) state.ServiceTaskState {
serviceTaskStateImpl := state.NewServiceTaskStateImpl()
serviceTaskStateImpl.SetName(fmt.Sprintf("%s_%s", serviceName, methodName))
serviceTaskStateImpl.SetIsAsync(false)
serviceTaskStateImpl.SetServiceName(serviceName)
serviceTaskStateImpl.SetServiceType("local")
serviceTaskStateImpl.SetServiceMethod(methodName)
return serviceTaskStateImpl
}

Loading…
Cancel
Save