Browse Source

增加直接触发事件的命令

pull/1/head
Sydonian 2 years ago
parent
commit
91fd58ed9d
2 changed files with 332 additions and 49 deletions
  1. +224
    -37
      pkg/cmdtrie/command_trie.go
  2. +108
    -12
      pkg/cmdtrie/command_trie_test.go

+ 224
- 37
pkg/cmdtrie/command_trie.go View File

@@ -4,10 +4,15 @@ import (
"fmt" "fmt"
"reflect" "reflect"
"strconv" "strconv"

myreflect "gitlink.org.cn/cloudream/common/utils/reflect"
) )


type command struct { type command struct {
fn reflect.Value
fn reflect.Value
fnType reflect.Type
staticArgTypes []reflect.Type
lastIsArray bool
} }


type trieNode struct { type trieNode struct {
@@ -15,29 +20,36 @@ type trieNode struct {
cmd *command cmd *command
} }


type CommandTrie[TCtx any] struct {
root trieNode
type anyCommandTrie struct {
root trieNode
ctxType reflect.Type
retType reflect.Type
} }


func NewCommandTrie[TCtx any]() CommandTrie[TCtx] {
return CommandTrie[TCtx]{
func newAnyCommandTrie(ctxType reflect.Type, retType reflect.Type) anyCommandTrie {
return anyCommandTrie{
root: trieNode{ root: trieNode{
nexts: make(map[string]*trieNode), nexts: make(map[string]*trieNode),
}, },
ctxType: ctxType,
retType: retType,
} }
} }


func (t *CommandTrie[TCtx]) Add(fn any, prefixWords ...string) error {
func (t *anyCommandTrie) Add(fn any, prefixWords ...string) error {
typ := reflect.TypeOf(fn) typ := reflect.TypeOf(fn)
if typ.Kind() != reflect.Func { if typ.Kind() != reflect.Func {
return fmt.Errorf("fn must be a function, but get a %v", typ.Kind()) return fmt.Errorf("fn must be a function, but get a %v", typ.Kind())
} }


for i := 0; i < typ.NumIn(); i++ {
argType := typ.In(i)
if argType.Kind() == reflect.Array && i < typ.NumIn()-1 {
return fmt.Errorf("array argument must at the last one")
}
err := t.checkFnReturn(typ)
if err != nil {
return err
}

err = t.checkFnArgs(typ)
if err != nil {
return err
} }


ptr := &t.root ptr := &t.root
@@ -52,16 +64,124 @@ func (t *CommandTrie[TCtx]) Add(fn any, prefixWords ...string) error {
ptr = next ptr = next
} }


fnType := reflect.TypeOf(fn)
var staticArgTypes []reflect.Type
if t.ctxType != nil {
for i := 1; i < fnType.NumIn(); i++ {
staticArgTypes = append(staticArgTypes, fnType.In(i))
}
} else {
for i := 0; i < fnType.NumIn(); i++ {
staticArgTypes = append(staticArgTypes, fnType.In(i))
}
}

var lastIsArray = false
if len(staticArgTypes) > 0 {
kind := staticArgTypes[len(staticArgTypes)-1].Kind()
lastIsArray = kind == reflect.Array || kind == reflect.Slice
}

ptr.cmd = &command{ ptr.cmd = &command{
fn: reflect.ValueOf(fn),
fn: reflect.ValueOf(fn),
fnType: reflect.TypeOf(fn),
staticArgTypes: staticArgTypes,
lastIsArray: lastIsArray,
}
return nil
}

func (t *anyCommandTrie) checkFnReturn(typ reflect.Type) error {
if t.retType != nil {
if typ.NumOut() != 1 {
return fmt.Errorf("fn must have one return value with type %s", t.retType.Name())
}

fnRetType := typ.Out(0)
if t.retType.Kind() == reflect.Interface {

// 如果TRet是接口类型,那么fn的返回值只要实现了此接口,就也可以接受
if !fnRetType.Implements(t.retType) {
return fmt.Errorf("fn must have one return value with type %s", t.retType.Name())
}

} else if fnRetType != t.retType {
return fmt.Errorf("fn must have one return value with type %s", t.retType.Name())
}
} }
return nil return nil
} }


func (t *CommandTrie[TCtx]) Execute(ctx TCtx, cmdWords ...string) error {
func (t *anyCommandTrie) checkFnArgs(typ reflect.Type) error {
if t.ctxType != nil {
if typ.NumIn() < 1 {
return fmt.Errorf("fn must have a ctx argument")
}

for i := 0; i < typ.NumIn(); i++ {
argType := typ.In(i)
if i == 0 && argType != t.ctxType {
return fmt.Errorf("first argument of fn must be %s", t.ctxType.Name())
}

if argType.Kind() == reflect.Array && i < typ.NumIn()-1 {
return fmt.Errorf("array argument must at the last one")
}
}
} else {
for i := 0; i < typ.NumIn(); i++ {
argType := typ.In(i)
if argType.Kind() == reflect.Array && i < typ.NumIn()-1 {
return fmt.Errorf("array argument must at the last one")
}
}
}
return nil
}

func (t *anyCommandTrie) Execute(ctx any, cmdWords ...string) ([]reflect.Value, error) {
var cmd *command var cmd *command
var argWords []string var argWords []string


cmd, argWords, err := t.findCommand(cmdWords, argWords)
if err != nil {
return nil, err
}

if cmd.lastIsArray {
// 最后一个参数如果是数组,那么可以少一个参数
if len(argWords) < len(cmd.staticArgTypes)-1 {
return nil, fmt.Errorf("no enough arguments for command")
}
} else if len(argWords) < len(cmd.staticArgTypes) {
return nil, fmt.Errorf("no enough arguments for command")
}

var callArgs []reflect.Value

// 如果有Ctx参数,则加上Ctx参数
if t.ctxType != nil {
callArgs = append(callArgs, reflect.ValueOf(ctx))
}

// 数组参数只能是最后一个,所以先处理最后一个参数前的参数
callArgs, err = t.parseFrontArgs(cmd, argWords, callArgs)
if err != nil {
return nil, err
}

// 解析最后一个参数
callArgs, err = t.parseLastArg(cmd, argWords, callArgs)
if err != nil {
return nil, err
}

return cmd.fn.Call(callArgs), nil
}

func (t *anyCommandTrie) findCommand(cmdWords []string, argWords []string) (*command, []string, error) {
var cmd *command

ptr := &t.root ptr := &t.root
for i := 0; i < len(cmdWords); i++ { for i := 0; i < len(cmdWords); i++ {
next, ok := ptr.nexts[cmdWords[i]] next, ok := ptr.nexts[cmdWords[i]]
@@ -76,31 +196,33 @@ func (t *CommandTrie[TCtx]) Execute(ctx TCtx, cmdWords ...string) error {
ptr = next ptr = next
} }
if cmd == nil { if cmd == nil {
return fmt.Errorf("command not found")
}

fnType := cmd.fn.Type()

// 最后一个参数如果是数组,那么可以少一个参数
if len(argWords) < fnType.NumIn()-1 {
return fmt.Errorf("no enough arguments for command")
return nil, nil, fmt.Errorf("command not found")
} }
return cmd, argWords, nil
}


var callArgs []reflect.Value

// 数组参数只能是最后一个,所以先处理最后一个参数前的参数
for i := 0; i < fnType.NumIn()-1; i++ {
val, err := t.parseValue(argWords[i], fnType.In(i))
func (t *anyCommandTrie) parseFrontArgs(cmd *command, argWords []string, callArgs []reflect.Value) ([]reflect.Value, error) {
for i := 0; i < len(cmd.staticArgTypes)-1; i++ {
val, err := t.parseValue(argWords[i], cmd.staticArgTypes[i])
if err != nil { if err != nil {
return fmt.Errorf("cannot parse function argument at %d, err: %s", i, err.Error())
// 如果有Ctx参数,则参数的位置要往后一个
argIndex := i
if t.ctxType != nil {
argIndex++
}

return nil, fmt.Errorf("cannot parse function argument at %d, err: %s", argIndex, err.Error())
} }


callArgs = append(callArgs, val) callArgs = append(callArgs, val)
} }
return callArgs, nil
}


if fnType.NumIn() > 0 {
lastArgType := fnType.In(fnType.NumIn() - 1)
lastArgWords := argWords[fnType.NumIn()-1:]
func (t *anyCommandTrie) parseLastArg(cmd *command, argWords []string, callArgs []reflect.Value) ([]reflect.Value, error) {
if len(cmd.staticArgTypes) > 0 {
lastArgType := cmd.staticArgTypes[len(cmd.staticArgTypes)-1]
lastArgWords := argWords[len(cmd.staticArgTypes)-1:]
lastArgTypeKind := lastArgType.Kind() lastArgTypeKind := lastArgType.Kind()


var lastArg reflect.Value var lastArg reflect.Value
@@ -114,31 +236,29 @@ func (t *CommandTrie[TCtx]) Execute(ctx TCtx, cmdWords ...string) error {
for i := 0; i < len(lastArgWords); i++ { for i := 0; i < len(lastArgWords); i++ {
eleVal, err := t.parseValue(lastArgWords[i], lastArgType.Elem()) eleVal, err := t.parseValue(lastArgWords[i], lastArgType.Elem())
if err != nil { if err != nil {
return fmt.Errorf("cannot parse as array element, err: %s", err.Error())
return nil, fmt.Errorf("cannot parse as array element, err: %s", err.Error())
} }
lastArg.Index(i).Set(eleVal) lastArg.Index(i).Set(eleVal)
} }


} else { } else {
if len(lastArgWords) == 0 { if len(lastArgWords) == 0 {
return fmt.Errorf("no enough arguments for command")
return nil, fmt.Errorf("no enough arguments for command")
} }


var err error var err error
lastArg, err = t.parseValue(lastArgWords[0], lastArgType) lastArg, err = t.parseValue(lastArgWords[0], lastArgType)
if err != nil { if err != nil {
return fmt.Errorf("cannot parse function argument at %d, err: %s", fnType.NumIn()-1, err.Error())
return nil, fmt.Errorf("cannot parse function argument at %d, err: %s", cmd.fnType.NumIn()-1, err.Error())
} }
} }


callArgs = append(callArgs, lastArg) callArgs = append(callArgs, lastArg)
} }

cmd.fn.Call(callArgs)
return nil
return callArgs, nil
} }


func (t *CommandTrie[TCtx]) parseValue(word string, valueType reflect.Type) (reflect.Value, error) {
func (t *anyCommandTrie) parseValue(word string, valueType reflect.Type) (reflect.Value, error) {
valTypeKind := valueType.Kind() valTypeKind := valueType.Kind()


if valTypeKind == reflect.String { if valTypeKind == reflect.String {
@@ -183,3 +303,70 @@ func (t *CommandTrie[TCtx]) parseValue(word string, valueType reflect.Type) (ref


return reflect.Value{}, fmt.Errorf("cannot parse string as %s", valueType.Name()) return reflect.Value{}, fmt.Errorf("cannot parse string as %s", valueType.Name())
} }

type CommandTrie[TCtx any, TRet any] struct {
anyTrie anyCommandTrie
}

func NewCommandTrie[TCtx any, TRet any]() CommandTrie[TCtx, TRet] {
return CommandTrie[TCtx, TRet]{
anyTrie: newAnyCommandTrie(myreflect.TypeOf[TCtx](), myreflect.TypeOf[TRet]()),
}
}

func (t *CommandTrie[TCtx, TRet]) Add(fn any, prefixWords ...string) error {
return t.anyTrie.Add(fn, prefixWords...)
}

func (t *CommandTrie[TCtx, TRet]) Execute(ctx TCtx, cmdWords ...string) (TRet, error) {
retValues, err := t.anyTrie.Execute(ctx, cmdWords...)
if err != nil {
var defRet TRet
return defRet, err
}

return retValues[0].Interface().(TRet), nil
}

type VoidCommandTrie[TCtx any] struct {
anyTrie anyCommandTrie
}

func NewVoidCommandTrie[TCtx any]() VoidCommandTrie[TCtx] {
return VoidCommandTrie[TCtx]{
anyTrie: newAnyCommandTrie(myreflect.TypeOf[TCtx](), nil),
}
}

func (t *VoidCommandTrie[TCtx]) Add(fn any, prefixWords ...string) error {
return t.anyTrie.Add(fn, prefixWords...)
}

func (t *VoidCommandTrie[TCtx]) Execute(ctx TCtx, cmdWords ...string) error {
_, err := t.anyTrie.Execute(ctx, cmdWords...)
return err
}

type StaticCommandTrie[TRet any] struct {
anyTrie anyCommandTrie
}

func NewStaticCommandTrie[TRet any]() StaticCommandTrie[TRet] {
return StaticCommandTrie[TRet]{
anyTrie: newAnyCommandTrie(nil, myreflect.TypeOf[TRet]()),
}
}

func (t *StaticCommandTrie[TRet]) Add(fn any, prefixWords ...string) error {
return t.anyTrie.Add(fn, prefixWords...)
}

func (t *StaticCommandTrie[TRet]) Execute(cmdWords ...string) (TRet, error) {
retValues, err := t.anyTrie.Execute(nil, cmdWords...)
if err != nil {
var defRet TRet
return defRet, err
}

return retValues[0].Interface().(TRet), nil
}

+ 108
- 12
pkg/cmdtrie/command_trie_test.go View File

@@ -8,37 +8,39 @@ import (


func Test_CommandTrie(t *testing.T) { func Test_CommandTrie(t *testing.T) {
Convey("无参数命令", t, func() { Convey("无参数命令", t, func() {
trie := NewCommandTrie[int]()
trie := NewVoidCommandTrie[int]()


var ret string var ret string


trie.Add(func() {
err := trie.Add(func(int) {
ret = "ok" ret = "ok"
}, "a") }, "a")
So(err, ShouldBeNil)


err := trie.Execute(0, "a")
err = trie.Execute(0, "a")
So(err, ShouldBeNil) So(err, ShouldBeNil)


So(ret, ShouldEqual, "ok") So(ret, ShouldEqual, "ok")
}) })


Convey("各种参数", t, func() { Convey("各种参数", t, func() {
trie := NewCommandTrie[int]()
trie := NewVoidCommandTrie[int]()


var argI int var argI int
var argStr string var argStr string
var argBl bool var argBl bool
var argFP float32 var argFP float32


trie.Add(func(i int, str string, bl bool, fp float32) {
err := trie.Add(func(int, i int, str string, bl bool, fp float32) {
argI = i argI = i
argStr = str argStr = str
argBl = bl argBl = bl
argFP = fp argFP = fp


}, "a", "b") }, "a", "b")
So(err, ShouldBeNil)


err := trie.Execute(0, "a", "b", "1", "2", "true", "3")
err = trie.Execute(0, "a", "b", "1", "2", "true", "3")
So(err, ShouldBeNil) So(err, ShouldBeNil)


So(argI, ShouldEqual, 1) So(argI, ShouldEqual, 1)
@@ -48,18 +50,19 @@ func Test_CommandTrie(t *testing.T) {
}) })


Convey("有数组参数", t, func() { Convey("有数组参数", t, func() {
trie := NewCommandTrie[int]()
trie := NewVoidCommandTrie[int]()


var argI int var argI int
var argArr []int64 var argArr []int64


trie.Add(func(i int, arr []int64) {
err := trie.Add(func(int, i int, arr []int64) {
argI = i argI = i
argArr = arr argArr = arr


}, "a", "b") }, "a", "b")
So(err, ShouldBeNil)


err := trie.Execute(0, "a", "b", "1", "2", "3", "4")
err = trie.Execute(0, "a", "b", "1", "2", "3", "4")
So(err, ShouldBeNil) So(err, ShouldBeNil)


So(argI, ShouldEqual, 1) So(argI, ShouldEqual, 1)
@@ -67,21 +70,114 @@ func Test_CommandTrie(t *testing.T) {
}) })


Convey("有数组参数,但为空", t, func() { Convey("有数组参数,但为空", t, func() {
trie := NewCommandTrie[int]()
trie := NewVoidCommandTrie[int]()


var argI int var argI int
var argArr []int64 var argArr []int64


trie.Add(func(i int, arr []int64) {
err := trie.Add(func(int, i int, arr []int64) {
argI = i argI = i
argArr = arr argArr = arr


}, "a", "b") }, "a", "b")
So(err, ShouldBeNil)

err = trie.Execute(0, "a", "b", "1")
So(err, ShouldBeNil)

So(argI, ShouldEqual, 1)
So(argArr, ShouldResemble, []int64{})
})

Convey("带返回值", t, func() {
trie := NewCommandTrie[int, int]()

var argI int
var argArr []int64

err := trie.Add(func(int, i int, arr []int64) int {
argI = i
argArr = arr
return 123
}, "a", "b")
So(err, ShouldBeNil)

ret, err := trie.Execute(0, "a", "b", "1")
So(err, ShouldBeNil)

So(argI, ShouldEqual, 1)
So(argArr, ShouldResemble, []int64{})
So(ret, ShouldEqual, 123)
})

Convey("返回值是接口类型", t, func() {
trie := NewCommandTrie[int, any]()

var argI int
var argArr []int64

err := trie.Add(func(int, i int, arr []int64) int {
argI = i
argArr = arr
return 123
}, "a", "b")
So(err, ShouldBeNil)

err = trie.Add(func(int, i int, arr []int64) string {
return "123"
}, "a", "c")
So(err, ShouldBeNil)

ret, err := trie.Execute(0, "a", "b", "1")
So(err, ShouldBeNil)
So(argI, ShouldEqual, 1)
So(argArr, ShouldResemble, []int64{})
So(ret, ShouldEqual, 123)

ret2, err := trie.Execute(0, "a", "c", "1")
So(err, ShouldBeNil)
So(ret2, ShouldEqual, "123")
})

Convey("无Ctx参数", t, func() {
trie := NewStaticCommandTrie[int]()

var argI int
var argArr []int64

err := trie.Add(func(i int, arr []int64) int {
argI = i
argArr = arr
return 123
}, "a", "b")
So(err, ShouldBeNil)

ret, err := trie.Execute("a", "b", "1")
So(err, ShouldBeNil)

So(argI, ShouldEqual, 1)
So(argArr, ShouldResemble, []int64{})
So(ret, ShouldEqual, 123)
})

Convey("完全无参数", t, func() {
trie := NewStaticCommandTrie[int]()

var argI int
var argArr []int64

err := trie.Add(func() int {
argI = 1
argArr = []int64{}
return 123
}, "a", "b")
So(err, ShouldBeNil)


err := trie.Execute(0, "a", "b", "1")
ret, err := trie.Execute("a", "b")
So(err, ShouldBeNil) So(err, ShouldBeNil)


So(argI, ShouldEqual, 1) So(argI, ShouldEqual, 1)
So(argArr, ShouldResemble, []int64{}) So(argArr, ShouldResemble, []int64{})
So(ret, ShouldEqual, 123)
}) })
} }

Loading…
Cancel
Save