Browse Source

增加直接触发事件的命令

pull/1/head
Sydonian 2 years ago
parent
commit
91fd58ed9d
2 changed files with 332 additions and 49 deletions
  1. +224
    -37
      pkg/cmdtrie/command_trie.go
  2. +108
    -12
      pkg/cmdtrie/command_trie_test.go

+ 224
- 37
pkg/cmdtrie/command_trie.go View File

@@ -4,10 +4,15 @@ import (
"fmt"
"reflect"
"strconv"

myreflect "gitlink.org.cn/cloudream/common/utils/reflect"
)

type command struct {
fn reflect.Value
fn reflect.Value
fnType reflect.Type
staticArgTypes []reflect.Type
lastIsArray bool
}

type trieNode struct {
@@ -15,29 +20,36 @@ type trieNode struct {
cmd *command
}

type CommandTrie[TCtx any] struct {
root trieNode
type anyCommandTrie struct {
root trieNode
ctxType reflect.Type
retType reflect.Type
}

func NewCommandTrie[TCtx any]() CommandTrie[TCtx] {
return CommandTrie[TCtx]{
func newAnyCommandTrie(ctxType reflect.Type, retType reflect.Type) anyCommandTrie {
return anyCommandTrie{
root: trieNode{
nexts: make(map[string]*trieNode),
},
ctxType: ctxType,
retType: retType,
}
}

func (t *CommandTrie[TCtx]) Add(fn any, prefixWords ...string) error {
func (t *anyCommandTrie) Add(fn any, prefixWords ...string) error {
typ := reflect.TypeOf(fn)
if typ.Kind() != reflect.Func {
return fmt.Errorf("fn must be a function, but get a %v", typ.Kind())
}

for i := 0; i < typ.NumIn(); i++ {
argType := typ.In(i)
if argType.Kind() == reflect.Array && i < typ.NumIn()-1 {
return fmt.Errorf("array argument must at the last one")
}
err := t.checkFnReturn(typ)
if err != nil {
return err
}

err = t.checkFnArgs(typ)
if err != nil {
return err
}

ptr := &t.root
@@ -52,16 +64,124 @@ func (t *CommandTrie[TCtx]) Add(fn any, prefixWords ...string) error {
ptr = next
}

fnType := reflect.TypeOf(fn)
var staticArgTypes []reflect.Type
if t.ctxType != nil {
for i := 1; i < fnType.NumIn(); i++ {
staticArgTypes = append(staticArgTypes, fnType.In(i))
}
} else {
for i := 0; i < fnType.NumIn(); i++ {
staticArgTypes = append(staticArgTypes, fnType.In(i))
}
}

var lastIsArray = false
if len(staticArgTypes) > 0 {
kind := staticArgTypes[len(staticArgTypes)-1].Kind()
lastIsArray = kind == reflect.Array || kind == reflect.Slice
}

ptr.cmd = &command{
fn: reflect.ValueOf(fn),
fn: reflect.ValueOf(fn),
fnType: reflect.TypeOf(fn),
staticArgTypes: staticArgTypes,
lastIsArray: lastIsArray,
}
return nil
}

func (t *anyCommandTrie) checkFnReturn(typ reflect.Type) error {
if t.retType != nil {
if typ.NumOut() != 1 {
return fmt.Errorf("fn must have one return value with type %s", t.retType.Name())
}

fnRetType := typ.Out(0)
if t.retType.Kind() == reflect.Interface {

// 如果TRet是接口类型,那么fn的返回值只要实现了此接口,就也可以接受
if !fnRetType.Implements(t.retType) {
return fmt.Errorf("fn must have one return value with type %s", t.retType.Name())
}

} else if fnRetType != t.retType {
return fmt.Errorf("fn must have one return value with type %s", t.retType.Name())
}
}
return nil
}

func (t *CommandTrie[TCtx]) Execute(ctx TCtx, cmdWords ...string) error {
func (t *anyCommandTrie) checkFnArgs(typ reflect.Type) error {
if t.ctxType != nil {
if typ.NumIn() < 1 {
return fmt.Errorf("fn must have a ctx argument")
}

for i := 0; i < typ.NumIn(); i++ {
argType := typ.In(i)
if i == 0 && argType != t.ctxType {
return fmt.Errorf("first argument of fn must be %s", t.ctxType.Name())
}

if argType.Kind() == reflect.Array && i < typ.NumIn()-1 {
return fmt.Errorf("array argument must at the last one")
}
}
} else {
for i := 0; i < typ.NumIn(); i++ {
argType := typ.In(i)
if argType.Kind() == reflect.Array && i < typ.NumIn()-1 {
return fmt.Errorf("array argument must at the last one")
}
}
}
return nil
}

func (t *anyCommandTrie) Execute(ctx any, cmdWords ...string) ([]reflect.Value, error) {
var cmd *command
var argWords []string

cmd, argWords, err := t.findCommand(cmdWords, argWords)
if err != nil {
return nil, err
}

if cmd.lastIsArray {
// 最后一个参数如果是数组,那么可以少一个参数
if len(argWords) < len(cmd.staticArgTypes)-1 {
return nil, fmt.Errorf("no enough arguments for command")
}
} else if len(argWords) < len(cmd.staticArgTypes) {
return nil, fmt.Errorf("no enough arguments for command")
}

var callArgs []reflect.Value

// 如果有Ctx参数,则加上Ctx参数
if t.ctxType != nil {
callArgs = append(callArgs, reflect.ValueOf(ctx))
}

// 数组参数只能是最后一个,所以先处理最后一个参数前的参数
callArgs, err = t.parseFrontArgs(cmd, argWords, callArgs)
if err != nil {
return nil, err
}

// 解析最后一个参数
callArgs, err = t.parseLastArg(cmd, argWords, callArgs)
if err != nil {
return nil, err
}

return cmd.fn.Call(callArgs), nil
}

func (t *anyCommandTrie) findCommand(cmdWords []string, argWords []string) (*command, []string, error) {
var cmd *command

ptr := &t.root
for i := 0; i < len(cmdWords); i++ {
next, ok := ptr.nexts[cmdWords[i]]
@@ -76,31 +196,33 @@ func (t *CommandTrie[TCtx]) Execute(ctx TCtx, cmdWords ...string) error {
ptr = next
}
if cmd == nil {
return fmt.Errorf("command not found")
}

fnType := cmd.fn.Type()

// 最后一个参数如果是数组,那么可以少一个参数
if len(argWords) < fnType.NumIn()-1 {
return fmt.Errorf("no enough arguments for command")
return nil, nil, fmt.Errorf("command not found")
}
return cmd, argWords, nil
}

var callArgs []reflect.Value

// 数组参数只能是最后一个,所以先处理最后一个参数前的参数
for i := 0; i < fnType.NumIn()-1; i++ {
val, err := t.parseValue(argWords[i], fnType.In(i))
func (t *anyCommandTrie) parseFrontArgs(cmd *command, argWords []string, callArgs []reflect.Value) ([]reflect.Value, error) {
for i := 0; i < len(cmd.staticArgTypes)-1; i++ {
val, err := t.parseValue(argWords[i], cmd.staticArgTypes[i])
if err != nil {
return fmt.Errorf("cannot parse function argument at %d, err: %s", i, err.Error())
// 如果有Ctx参数,则参数的位置要往后一个
argIndex := i
if t.ctxType != nil {
argIndex++
}

return nil, fmt.Errorf("cannot parse function argument at %d, err: %s", argIndex, err.Error())
}

callArgs = append(callArgs, val)
}
return callArgs, nil
}

if fnType.NumIn() > 0 {
lastArgType := fnType.In(fnType.NumIn() - 1)
lastArgWords := argWords[fnType.NumIn()-1:]
func (t *anyCommandTrie) parseLastArg(cmd *command, argWords []string, callArgs []reflect.Value) ([]reflect.Value, error) {
if len(cmd.staticArgTypes) > 0 {
lastArgType := cmd.staticArgTypes[len(cmd.staticArgTypes)-1]
lastArgWords := argWords[len(cmd.staticArgTypes)-1:]
lastArgTypeKind := lastArgType.Kind()

var lastArg reflect.Value
@@ -114,31 +236,29 @@ func (t *CommandTrie[TCtx]) Execute(ctx TCtx, cmdWords ...string) error {
for i := 0; i < len(lastArgWords); i++ {
eleVal, err := t.parseValue(lastArgWords[i], lastArgType.Elem())
if err != nil {
return fmt.Errorf("cannot parse as array element, err: %s", err.Error())
return nil, fmt.Errorf("cannot parse as array element, err: %s", err.Error())
}
lastArg.Index(i).Set(eleVal)
}

} else {
if len(lastArgWords) == 0 {
return fmt.Errorf("no enough arguments for command")
return nil, fmt.Errorf("no enough arguments for command")
}

var err error
lastArg, err = t.parseValue(lastArgWords[0], lastArgType)
if err != nil {
return fmt.Errorf("cannot parse function argument at %d, err: %s", fnType.NumIn()-1, err.Error())
return nil, fmt.Errorf("cannot parse function argument at %d, err: %s", cmd.fnType.NumIn()-1, err.Error())
}
}

callArgs = append(callArgs, lastArg)
}

cmd.fn.Call(callArgs)
return nil
return callArgs, nil
}

func (t *CommandTrie[TCtx]) parseValue(word string, valueType reflect.Type) (reflect.Value, error) {
func (t *anyCommandTrie) parseValue(word string, valueType reflect.Type) (reflect.Value, error) {
valTypeKind := valueType.Kind()

if valTypeKind == reflect.String {
@@ -183,3 +303,70 @@ func (t *CommandTrie[TCtx]) parseValue(word string, valueType reflect.Type) (ref

return reflect.Value{}, fmt.Errorf("cannot parse string as %s", valueType.Name())
}

type CommandTrie[TCtx any, TRet any] struct {
anyTrie anyCommandTrie
}

func NewCommandTrie[TCtx any, TRet any]() CommandTrie[TCtx, TRet] {
return CommandTrie[TCtx, TRet]{
anyTrie: newAnyCommandTrie(myreflect.TypeOf[TCtx](), myreflect.TypeOf[TRet]()),
}
}

func (t *CommandTrie[TCtx, TRet]) Add(fn any, prefixWords ...string) error {
return t.anyTrie.Add(fn, prefixWords...)
}

func (t *CommandTrie[TCtx, TRet]) Execute(ctx TCtx, cmdWords ...string) (TRet, error) {
retValues, err := t.anyTrie.Execute(ctx, cmdWords...)
if err != nil {
var defRet TRet
return defRet, err
}

return retValues[0].Interface().(TRet), nil
}

type VoidCommandTrie[TCtx any] struct {
anyTrie anyCommandTrie
}

func NewVoidCommandTrie[TCtx any]() VoidCommandTrie[TCtx] {
return VoidCommandTrie[TCtx]{
anyTrie: newAnyCommandTrie(myreflect.TypeOf[TCtx](), nil),
}
}

func (t *VoidCommandTrie[TCtx]) Add(fn any, prefixWords ...string) error {
return t.anyTrie.Add(fn, prefixWords...)
}

func (t *VoidCommandTrie[TCtx]) Execute(ctx TCtx, cmdWords ...string) error {
_, err := t.anyTrie.Execute(ctx, cmdWords...)
return err
}

type StaticCommandTrie[TRet any] struct {
anyTrie anyCommandTrie
}

func NewStaticCommandTrie[TRet any]() StaticCommandTrie[TRet] {
return StaticCommandTrie[TRet]{
anyTrie: newAnyCommandTrie(nil, myreflect.TypeOf[TRet]()),
}
}

func (t *StaticCommandTrie[TRet]) Add(fn any, prefixWords ...string) error {
return t.anyTrie.Add(fn, prefixWords...)
}

func (t *StaticCommandTrie[TRet]) Execute(cmdWords ...string) (TRet, error) {
retValues, err := t.anyTrie.Execute(nil, cmdWords...)
if err != nil {
var defRet TRet
return defRet, err
}

return retValues[0].Interface().(TRet), nil
}

+ 108
- 12
pkg/cmdtrie/command_trie_test.go View File

@@ -8,37 +8,39 @@ import (

func Test_CommandTrie(t *testing.T) {
Convey("无参数命令", t, func() {
trie := NewCommandTrie[int]()
trie := NewVoidCommandTrie[int]()

var ret string

trie.Add(func() {
err := trie.Add(func(int) {
ret = "ok"
}, "a")
So(err, ShouldBeNil)

err := trie.Execute(0, "a")
err = trie.Execute(0, "a")
So(err, ShouldBeNil)

So(ret, ShouldEqual, "ok")
})

Convey("各种参数", t, func() {
trie := NewCommandTrie[int]()
trie := NewVoidCommandTrie[int]()

var argI int
var argStr string
var argBl bool
var argFP float32

trie.Add(func(i int, str string, bl bool, fp float32) {
err := trie.Add(func(int, i int, str string, bl bool, fp float32) {
argI = i
argStr = str
argBl = bl
argFP = fp

}, "a", "b")
So(err, ShouldBeNil)

err := trie.Execute(0, "a", "b", "1", "2", "true", "3")
err = trie.Execute(0, "a", "b", "1", "2", "true", "3")
So(err, ShouldBeNil)

So(argI, ShouldEqual, 1)
@@ -48,18 +50,19 @@ func Test_CommandTrie(t *testing.T) {
})

Convey("有数组参数", t, func() {
trie := NewCommandTrie[int]()
trie := NewVoidCommandTrie[int]()

var argI int
var argArr []int64

trie.Add(func(i int, arr []int64) {
err := trie.Add(func(int, i int, arr []int64) {
argI = i
argArr = arr

}, "a", "b")
So(err, ShouldBeNil)

err := trie.Execute(0, "a", "b", "1", "2", "3", "4")
err = trie.Execute(0, "a", "b", "1", "2", "3", "4")
So(err, ShouldBeNil)

So(argI, ShouldEqual, 1)
@@ -67,21 +70,114 @@ func Test_CommandTrie(t *testing.T) {
})

Convey("有数组参数,但为空", t, func() {
trie := NewCommandTrie[int]()
trie := NewVoidCommandTrie[int]()

var argI int
var argArr []int64

trie.Add(func(i int, arr []int64) {
err := trie.Add(func(int, i int, arr []int64) {
argI = i
argArr = arr

}, "a", "b")
So(err, ShouldBeNil)

err = trie.Execute(0, "a", "b", "1")
So(err, ShouldBeNil)

So(argI, ShouldEqual, 1)
So(argArr, ShouldResemble, []int64{})
})

Convey("带返回值", t, func() {
trie := NewCommandTrie[int, int]()

var argI int
var argArr []int64

err := trie.Add(func(int, i int, arr []int64) int {
argI = i
argArr = arr
return 123
}, "a", "b")
So(err, ShouldBeNil)

ret, err := trie.Execute(0, "a", "b", "1")
So(err, ShouldBeNil)

So(argI, ShouldEqual, 1)
So(argArr, ShouldResemble, []int64{})
So(ret, ShouldEqual, 123)
})

Convey("返回值是接口类型", t, func() {
trie := NewCommandTrie[int, any]()

var argI int
var argArr []int64

err := trie.Add(func(int, i int, arr []int64) int {
argI = i
argArr = arr
return 123
}, "a", "b")
So(err, ShouldBeNil)

err = trie.Add(func(int, i int, arr []int64) string {
return "123"
}, "a", "c")
So(err, ShouldBeNil)

ret, err := trie.Execute(0, "a", "b", "1")
So(err, ShouldBeNil)
So(argI, ShouldEqual, 1)
So(argArr, ShouldResemble, []int64{})
So(ret, ShouldEqual, 123)

ret2, err := trie.Execute(0, "a", "c", "1")
So(err, ShouldBeNil)
So(ret2, ShouldEqual, "123")
})

Convey("无Ctx参数", t, func() {
trie := NewStaticCommandTrie[int]()

var argI int
var argArr []int64

err := trie.Add(func(i int, arr []int64) int {
argI = i
argArr = arr
return 123
}, "a", "b")
So(err, ShouldBeNil)

ret, err := trie.Execute("a", "b", "1")
So(err, ShouldBeNil)

So(argI, ShouldEqual, 1)
So(argArr, ShouldResemble, []int64{})
So(ret, ShouldEqual, 123)
})

Convey("完全无参数", t, func() {
trie := NewStaticCommandTrie[int]()

var argI int
var argArr []int64

err := trie.Add(func() int {
argI = 1
argArr = []int64{}
return 123
}, "a", "b")
So(err, ShouldBeNil)

err := trie.Execute(0, "a", "b", "1")
ret, err := trie.Execute("a", "b")
So(err, ShouldBeNil)

So(argI, ShouldEqual, 1)
So(argArr, ShouldResemble, []int64{})
So(ret, ShouldEqual, 123)
})
}

Loading…
Cancel
Save