Browse Source

optimize tm

tags/v1.0.0-rc2
liuxiaomin 5 years ago
parent
commit
b088469904
5 changed files with 128 additions and 241 deletions
  1. +14
    -0
      client/context/root_context.go
  2. +84
    -209
      client/proxy/service.go
  3. +1
    -0
      client/tcc/business_action_context.go
  4. +25
    -20
      client/tm/proxy.go
  5. +4
    -12
      client/tm/proxy_test.go

+ 14
- 0
client/context/root_context.go View File

@@ -24,6 +24,20 @@ type RootContext struct {
localMap map[string]interface{}
}

func NewRootContext(ctx context.Context) *RootContext {
rootCtx := &RootContext{
Context: ctx,
localMap: make(map[string]interface{}),
}

xId := ctx.Value(KEY_XID)
if xId != nil {
xid := xId.(string)
rootCtx.Bind(xid)
}
return rootCtx
}

func (c *RootContext) Set(key string, value interface{}) {
if c.localMap == nil {
c.localMap = make(map[string]interface{})


+ 84
- 209
client/proxy/service.go View File

@@ -3,184 +3,80 @@ package proxy
import (
"context"
"reflect"
"strings"
"sync"
"unicode"
"unicode/utf8"
)

import (
"github.com/pkg/errors"
)

import (
context2 "github.com/dk-lockdown/seata-golang/client/context"
"github.com/dk-lockdown/seata-golang/pkg/logging"
)

const (
// PROXY_METHOD ...
PROXY_METHOD = "ProxyMethods"
)

var (
// Precompute the reflect type for error. Can't use error directly
// because Typeof takes an empty interface value. This is annoying.
typeOfError = reflect.TypeOf((*error)(nil)).Elem()

// ServiceMap ...
ServiceMap = &serviceMap{
serviceMap: make(map[string]*Service),
}
// serviceDescriptorMap, string -> *ServiceDescriptor
serviceDescriptorMap = sync.Map{}
)

type ProxyService interface {
// Methods
ProxyMethods() map[string]bool
}

// MethodType ...
type MethodType struct {
method reflect.Method
ctxType reflect.Type
argsType []reflect.Type
returnValuesType []reflect.Type
}

// Method ...
func (m *MethodType) Method() reflect.Method {
return m.method
}

// ArgsType ...
func (m *MethodType) ArgsType() []reflect.Type {
return m.argsType
}

// ReturnValuesType ...
func (m *MethodType) ReturnValuesType() []reflect.Type {
return m.returnValuesType
}

// Service ...
type Service struct {
name string
rcvr reflect.Value
rcvrType reflect.Type
methods map[string]*MethodType
}

// Method ...
func (s *Service) Method() map[string]*MethodType {
return s.methods
}

// RcvrType ...
func (s *Service) RcvrType() reflect.Type {
return s.rcvrType
}

// Rcvr ...
func (s *Service) Rcvr() reflect.Value {
return s.rcvr
}

type serviceMap struct {
mutex sync.RWMutex // protects the serviceMap
serviceMap map[string]*Service // service name -> service
}

func (sm *serviceMap) GetService(name string) *Service {
sm.mutex.RLock()
defer sm.mutex.RUnlock()
if s, ok := sm.serviceMap[name]; ok {
return s
}
return nil
}

func (sm *serviceMap) Register(rcvr ProxyService) (string, error) {
s := new(Service)
s.rcvrType = reflect.TypeOf(rcvr)
s.rcvr = reflect.ValueOf(rcvr)
sname := reflect.Indirect(s.rcvr).Type().Name()
if sname == "" {
s := "no service name for type " + s.rcvrType.String()
logging.Logger.Errorf(s)
return "", errors.New(s)
}
if !isExported(sname) {
s := "type " + sname + " is not exported"
logging.Logger.Errorf(s)
return "", errors.New(s)
}

if server := sm.GetService(sname); server != nil {
return "", errors.New("service already defined: " + sname)
}
s.name = sname
s.methods = make(map[string]*MethodType)

// Install the methods
methods := ""
methods, s.methods = suitableMethods(s.rcvrType)

if len(s.methods) == 0 {
s := "type " + sname + " has no exported methods of suitable type"
logging.Logger.Errorf(s)
return "", errors.New(s)
}
sm.mutex.Lock()
sm.serviceMap[s.name] = s
sm.mutex.Unlock()

return strings.TrimSuffix(methods, ","), nil
}

func (sm *serviceMap) UnRegister(serviceId string) error {
sm.mutex.RLock()
_, ok := sm.serviceMap[serviceId]
if !ok {
sm.mutex.RUnlock()
return errors.New("no service for " + serviceId)
}
sm.mutex.RUnlock()

sm.mutex.Lock()
defer sm.mutex.Unlock()
delete(sm.serviceMap, serviceId)
return nil
}

func suitableMethods(typ reflect.Type) (string, map[string]*MethodType) {
methods := make(map[string]*MethodType)
var mts []string
logging.Logger.Debugf("[%s] NumMethod is %d", typ.String(), typ.NumMethod())
method, ok := typ.MethodByName(PROXY_METHOD)
var transactionMethods map[string]bool
if ok && method.Type.NumIn() == 1 && method.Type.NumOut() == 1 && method.Type.Out(0).String() == "map[string]bool" {
transactionMethods = method.Func.Call([]reflect.Value{reflect.New(typ.Elem())})[0].Interface().(map[string]bool)
}

for m := 0; m < typ.NumMethod(); m++ {
method = typ.Method(m)
_, ok := transactionMethods[method.Name]
if ok {
if mt := suiteMethod(method); mt != nil {
methods[method.Name] = mt
}
mts = append(mts, method.Name)
type MethodDescriptor struct {
Method reflect.Method
CallerValue reflect.Value
CtxType reflect.Type
ArgsType []reflect.Type
ArgsNum int
ReturnValuesType []reflect.Type
ReturnValuesNum int
}

type ServiceDescriptor struct {
Name string
ReflectType reflect.Type
ReflectValue reflect.Value
Methods sync.Map //string -> *MethodDescriptor
}

// Register
func Register(service interface{},methodName string) *MethodDescriptor {
serviceType := reflect.TypeOf(service)
serviceValue := reflect.ValueOf(service)
svcName := reflect.Indirect(serviceValue).Type().Name()

svcDesc, _ := serviceDescriptorMap.LoadOrStore(svcName,&ServiceDescriptor{
Name: svcName,
ReflectType: serviceType,
ReflectValue: serviceValue,
Methods: sync.Map{},
})
svcDescriptor := svcDesc.(*ServiceDescriptor)
methodDesc, methodExist := svcDescriptor.Methods.Load(methodName)
if methodExist {
methodDescriptor := methodDesc.(*MethodDescriptor)
return methodDescriptor
}

method, methodFounded := serviceType.MethodByName(methodName)
if methodFounded {
methodDescriptor := describeMethod(method)
if methodDescriptor != nil {
methodDescriptor.CallerValue = serviceValue
svcDescriptor.Methods.Store(methodName, methodDescriptor)
return methodDescriptor
}

}
return strings.Join(mts, ","), methods
return nil
}

func suiteMethod(method reflect.Method) *MethodType {
mtype := method.Type
mname := method.Name
inNum := mtype.NumIn()
outNum := mtype.NumOut()
// describeMethod
// might return nil when method is not exported or some other error
func describeMethod(method reflect.Method) *MethodDescriptor {
methodType := method.Type
methodName := method.Name
inNum := methodType.NumIn()
outNum := methodType.NumOut()

// Method must be exported.
if method.PkgPath != "" {
@@ -193,42 +89,42 @@ func suiteMethod(method reflect.Method) *MethodType {
)

for index := 1; index < inNum; index++ {
if mtype.In(index).String() == "context.Context" {
ctxType = mtype.In(index)
if methodType.In(index).String() == "context.Context" {
ctxType = methodType.In(index)
}
argsType = append(argsType, mtype.In(index))
argsType = append(argsType, methodType.In(index))
// need not be a pointer.
if !isExportedOrBuiltinType(mtype.In(index)) {
logging.Logger.Errorf("argument type of method %q is not exported %v", mname, mtype.In(index))
if !isExportedOrBuiltinType(methodType.In(index)) {
logging.Logger.Errorf("argument type of method %q is not exported %v", methodName, methodType.In(index))
return nil
}
}


// The latest return type of the method must be error.
if returnType := mtype.Out(outNum - 1); returnType != typeOfError {
logging.Logger.Warnf("the latest return type %s of method %q is not error", returnType, mname)
if returnType := methodType.Out(outNum - 1); returnType != typeOfError {
logging.Logger.Warnf("the latest return type %s of method %q is not error", returnType, methodName)
return nil
}

// returnValuesType
for num := 0; num < outNum; num++ {
returnValuesType = append(returnValuesType, mtype.Out(num))
returnValuesType = append(returnValuesType, methodType.Out(num))
// need not be a pointer.
if !isExportedOrBuiltinType(mtype.Out(num)) {
logging.Logger.Errorf("reply type of method %s not exported{%v}", mname, mtype.Out(num))
if !isExportedOrBuiltinType(methodType.Out(num)) {
logging.Logger.Errorf("reply type of method %s not exported{%v}", methodName, methodType.Out(num))
return nil
}
}

return &MethodType{method: method, argsType: argsType, returnValuesType: returnValuesType, ctxType: ctxType}
}

func SuiteContext(m *MethodType,ctx context.Context) reflect.Value {
if contextv := reflect.ValueOf(ctx); contextv.IsValid() {
return contextv
return &MethodDescriptor{
Method: method,
CtxType: ctxType,
ArgsType: argsType,
ArgsNum: inNum,
ReturnValuesType: returnValuesType,
ReturnValuesNum: outNum,
}
return reflect.Zero(m.ctxType)
}

// Is this an exported - upper case - name
@@ -247,29 +143,18 @@ func isExportedOrBuiltinType(t reflect.Type) bool {
return isExported(t.Name()) || t.PkgPath() == ""
}

func Invoke(methodType *MethodType,ctx *context2.RootContext,serviceName string,methodName string,args []interface{}) []reflect.Value {
svc := ServiceMap.GetService(serviceName)
if svc == nil {
logging.Logger.Errorf("cannot find service [%s]", serviceName)
panic(errors.Errorf("cannot find service [%s]", serviceName))
}
// Invoke
func Invoke(methodDesc *MethodDescriptor, ctx *context2.RootContext, args []interface{}) []reflect.Value {

// get method
method := svc.Method()[methodName]
if method == nil {
logging.Logger.Errorf("cannot find method [%s] of service [%s]", methodName, serviceName)
panic(errors.Errorf("cannot find method [%s] of service [%s]", methodName, serviceName))
}

in := []reflect.Value{svc.Rcvr()}
in := []reflect.Value{methodDesc.CallerValue}

for i := 0; i < len(args); i++ {
t := reflect.ValueOf(args[i])
if method.ArgsType()[i].String() == "context.Context" {
t = SuiteContext(method,ctx.Context)
if methodDesc.ArgsType[i].String() == "context.Context" {
t = SuiteContext(methodDesc,ctx.Context)
}
if !t.IsValid() {
at := method.ArgsType()[i]
at := methodDesc.ArgsType[i]
if at.Kind() == reflect.Ptr {
at = at.Elem()
}
@@ -278,32 +163,22 @@ func Invoke(methodType *MethodType,ctx *context2.RootContext,serviceName string,
in = append(in, t)
}

returnValues := method.Method().Func.Call(in)
returnValues := methodDesc.Method.Func.Call(in)

return returnValues
}

func GetMethod(serviceName, methodName string) *MethodType {
svc := ServiceMap.GetService(serviceName)
if svc == nil {
logging.Logger.Errorf("cannot find service [%s]", serviceName)
panic(errors.Errorf("cannot find service [%s]", serviceName))
}

// get method
method := svc.Method()[methodName]
if method == nil {
logging.Logger.Errorf("cannot find method [%s] of service [%s]", methodName, serviceName)
panic(errors.Errorf("cannot find method [%s] of service [%s]", methodName, serviceName))
func SuiteContext(methodDesc *MethodDescriptor, ctx context.Context) reflect.Value {
if contextValue := reflect.ValueOf(ctx); contextValue.IsValid() {
return contextValue
}
return method
return reflect.Zero(methodDesc.CtxType)
}

func ReturnWithError(method *MethodType,err error) []reflect.Value {
func ReturnWithError(methodDesc *MethodDescriptor,err error) []reflect.Value {
var result = make([]reflect.Value,0)
returnValuesLen := len(method.returnValuesType)
for i := 0;i < returnValuesLen - 1; i++ {
result = append(result,reflect.Zero(method.returnValuesType[i]))
for i := 0;i < methodDesc.ReturnValuesNum - 1; i++ {
result = append(result,reflect.Zero(methodDesc.ReturnValuesType[i]))
}
result = append(result,reflect.ValueOf(err))
return result

+ 1
- 0
client/tcc/business_action_context.go View File

@@ -0,0 +1 @@
package tcc

+ 25
- 20
client/tm/proxy.go View File

@@ -16,7 +16,7 @@ import (
)

type GlobalTransactionProxyService interface {
GetProxyService() proxy.ProxyService
GetProxyService() interface{}
GetMethodTransactionInfo(methodName string) *TransactionInfo
}

@@ -37,11 +37,9 @@ func Implement(v GlobalTransactionProxyService) {
logging.Logger.Errorf("%s must be a struct ptr", valueOf.String())
return
}
proxiedService := v.GetProxyService()
pxdService := reflect.ValueOf(proxiedService)
serviceName := reflect.Indirect(pxdService).Type().Name()
proxyService := v.GetProxyService()

makeCallProxy := func(serviceName, methodName string,txInfo *TransactionInfo) func(in []reflect.Value) []reflect.Value {
makeCallProxy := func(methodDesc *proxy.MethodDescriptor, txInfo *TransactionInfo) func(in []reflect.Value) []reflect.Value {
return func(in []reflect.Value) []reflect.Value {
var (
args = make([]interface{},0)
@@ -50,17 +48,22 @@ func Implement(v GlobalTransactionProxyService) {
)

if txInfo == nil {
// testing phase, this problem should be resolved
panic(errors.New("transactionInfo does not exist"))
}
method := proxy.GetMethod(serviceName,methodName)

inNum := len(in)
invCtx := &context2.RootContext{Context: context.Background()}
if inNum + 1 != methodDesc.ArgsNum {
// testing phase, this problem should be resolved
panic(errors.New("args does not match"))
}

invCtx := context2.NewRootContext(context.Background())
for i := 0; i < inNum; i++ {
if in[i].Type().String() == "context.Context" {
if !in[i].IsNil() {
// the user declared context as method's parameter
invCtx = &context2.RootContext{Context:in[i].Interface().(context.Context)}
invCtx = context2.NewRootContext(in[i].Interface().(context.Context))
}
}
args = append(args,in[i].Interface())
@@ -72,14 +75,14 @@ func Implement(v GlobalTransactionProxyService) {
switch txInfo.Propagation {
case NOT_SUPPORTED:
suspendedResourcesHolder,_ = tx.Suspend(true,invCtx)
returnValues = proxy.Invoke(method,invCtx,serviceName,methodName,args)
returnValues = proxy.Invoke(methodDesc, invCtx, args)
return returnValues
case REQUIRES_NEW:
suspendedResourcesHolder,_ = tx.Suspend(true,invCtx)
break
case SUPPORTS:
if !invCtx.InGlobalTransaction() {
returnValues = proxy.Invoke(method,invCtx,serviceName,methodName,args)
returnValues = proxy.Invoke(methodDesc, invCtx, args)
return returnValues
}
break
@@ -87,26 +90,26 @@ func Implement(v GlobalTransactionProxyService) {
break
case NEVER:
if invCtx.InGlobalTransaction() {
return proxy.ReturnWithError(method,errors.Errorf("Existing transaction found for transaction marked with propagation 'never',xid = %s",invCtx.GetXID()))
return proxy.ReturnWithError(methodDesc,errors.Errorf("Existing transaction found for transaction marked with propagation 'never',xid = %s",invCtx.GetXID()))
} else {
returnValues = proxy.Invoke(method,invCtx,serviceName,methodName,args)
returnValues = proxy.Invoke(methodDesc, invCtx, args)
return returnValues
}
case MANDATORY:
if !invCtx.InGlobalTransaction() {
return proxy.ReturnWithError(method,errors.New("No existing transaction found for transaction marked with propagation 'mandatory'"))
return proxy.ReturnWithError(methodDesc,errors.New("No existing transaction found for transaction marked with propagation 'mandatory'"))
}
break
default:
return proxy.ReturnWithError(method,errors.Errorf("Not Supported Propagation: %s",txInfo.Propagation.String()))
return proxy.ReturnWithError(methodDesc,errors.Errorf("Not Supported Propagation: %s",txInfo.Propagation.String()))
}

beginErr := tx.BeginWithTimeoutAndName(txInfo.TimeOut,txInfo.Name,invCtx)
if beginErr != nil {
return proxy.ReturnWithError(method, errors.WithStack(beginErr))
return proxy.ReturnWithError(methodDesc, errors.WithStack(beginErr))
}

returnValues = proxy.Invoke(method,invCtx,serviceName,methodName,args)
returnValues = proxy.Invoke(methodDesc, invCtx, args)

errValue := returnValues[len(returnValues)-1]

@@ -114,14 +117,14 @@ func Implement(v GlobalTransactionProxyService) {
if errValue.IsValid() && !errValue.IsNil() {
rollbackErr := tx.Rollback(invCtx)
if rollbackErr != nil {
return proxy.ReturnWithError(method,errors.WithStack(rollbackErr))
return proxy.ReturnWithError(methodDesc,errors.WithStack(rollbackErr))
}
return proxy.ReturnWithError(method,errors.New("rollback failure"))
return proxy.ReturnWithError(methodDesc,errors.New("rollback failure"))
}

commitErr := tx.Commit(invCtx)
if commitErr != nil {
return proxy.ReturnWithError(method,errors.WithStack(commitErr))
return proxy.ReturnWithError(methodDesc,errors.WithStack(commitErr))
}

return returnValues
@@ -142,8 +145,10 @@ func Implement(v GlobalTransactionProxyService) {
continue
}

methodDescriptor := proxy.Register(proxyService,methodName)

// do method proxy here:
f.Set(reflect.MakeFunc(f.Type(), makeCallProxy(serviceName,methodName,v.GetMethodTransactionInfo(methodName))))
f.Set(reflect.MakeFunc(f.Type(), makeCallProxy(methodDescriptor,v.GetMethodTransactionInfo(methodName))))
logging.Logger.Debugf("set method [%s]", methodName)
}
}


+ 4
- 12
client/tm/proxy_test.go View File

@@ -4,7 +4,6 @@ import (
"context"
"github.com/dk-lockdown/seata-golang/client/config"
getty2 "github.com/dk-lockdown/seata-golang/client/getty"
"github.com/dk-lockdown/seata-golang/client/proxy"
"github.com/stretchr/testify/assert"
"testing"
)
@@ -24,14 +23,8 @@ func (svc *ZooService) ManageAnimal(ctx context.Context,dog *Dog,cat *Cat) (bool
return true,nil
}

func (svc *ZooService) ProxyMethods() map[string]bool {
mp := make(map[string]bool)
mp["ManageAnimal"] = true
return mp
}

type TestService struct {
proxy.ProxyService
*ZooService
ManageAnimal func(ctx context.Context,dog *Dog,cat *Cat) (bool,error)
}

@@ -45,8 +38,8 @@ func init() {
}
}

func (svc *TestService) GetProxyService() proxy.ProxyService {
return svc.ProxyService
func (svc *TestService) GetProxyService() interface{} {
return svc.ZooService
}

func (svc *TestService) GetMethodTransactionInfo(methodName string) *TransactionInfo {
@@ -58,9 +51,8 @@ func TestProxy_Implement(t *testing.T) {
getty2.InitRpcClient()
NewTMClient()
zooSvc := &ZooService{}
proxy.ServiceMap.Register(zooSvc)

ts := &TestService{ProxyService: zooSvc}
ts := &TestService{ZooService: zooSvc}
Implement(ts)
result, err := ts.ManageAnimal(context.Background(),&Dog{},&Cat{})
assert.True(t,result)


Loading…
Cancel
Save