diff --git a/pkg/cmdtrie/command_trie.go b/pkg/cmdtrie/command_trie.go index ea79965..0de80f7 100644 --- a/pkg/cmdtrie/command_trie.go +++ b/pkg/cmdtrie/command_trie.go @@ -4,10 +4,15 @@ import ( "fmt" "reflect" "strconv" + + myreflect "gitlink.org.cn/cloudream/common/utils/reflect" ) type command struct { - fn reflect.Value + fn reflect.Value + fnType reflect.Type + staticArgTypes []reflect.Type + lastIsArray bool } type trieNode struct { @@ -15,29 +20,36 @@ type trieNode struct { cmd *command } -type CommandTrie[TCtx any] struct { - root trieNode +type anyCommandTrie struct { + root trieNode + ctxType reflect.Type + retType reflect.Type } -func NewCommandTrie[TCtx any]() CommandTrie[TCtx] { - return CommandTrie[TCtx]{ +func newAnyCommandTrie(ctxType reflect.Type, retType reflect.Type) anyCommandTrie { + return anyCommandTrie{ root: trieNode{ nexts: make(map[string]*trieNode), }, + ctxType: ctxType, + retType: retType, } } -func (t *CommandTrie[TCtx]) Add(fn any, prefixWords ...string) error { +func (t *anyCommandTrie) Add(fn any, prefixWords ...string) error { typ := reflect.TypeOf(fn) if typ.Kind() != reflect.Func { return fmt.Errorf("fn must be a function, but get a %v", typ.Kind()) } - for i := 0; i < typ.NumIn(); i++ { - argType := typ.In(i) - if argType.Kind() == reflect.Array && i < typ.NumIn()-1 { - return fmt.Errorf("array argument must at the last one") - } + err := t.checkFnReturn(typ) + if err != nil { + return err + } + + err = t.checkFnArgs(typ) + if err != nil { + return err } ptr := &t.root @@ -52,16 +64,124 @@ func (t *CommandTrie[TCtx]) Add(fn any, prefixWords ...string) error { ptr = next } + fnType := reflect.TypeOf(fn) + var staticArgTypes []reflect.Type + if t.ctxType != nil { + for i := 1; i < fnType.NumIn(); i++ { + staticArgTypes = append(staticArgTypes, fnType.In(i)) + } + } else { + for i := 0; i < fnType.NumIn(); i++ { + staticArgTypes = append(staticArgTypes, fnType.In(i)) + } + } + + var lastIsArray = false + if len(staticArgTypes) > 0 { + kind := staticArgTypes[len(staticArgTypes)-1].Kind() + lastIsArray = kind == reflect.Array || kind == reflect.Slice + } + ptr.cmd = &command{ - fn: reflect.ValueOf(fn), + fn: reflect.ValueOf(fn), + fnType: reflect.TypeOf(fn), + staticArgTypes: staticArgTypes, + lastIsArray: lastIsArray, + } + return nil +} + +func (t *anyCommandTrie) checkFnReturn(typ reflect.Type) error { + if t.retType != nil { + if typ.NumOut() != 1 { + return fmt.Errorf("fn must have one return value with type %s", t.retType.Name()) + } + + fnRetType := typ.Out(0) + if t.retType.Kind() == reflect.Interface { + + // 如果TRet是接口类型,那么fn的返回值只要实现了此接口,就也可以接受 + if !fnRetType.Implements(t.retType) { + return fmt.Errorf("fn must have one return value with type %s", t.retType.Name()) + } + + } else if fnRetType != t.retType { + return fmt.Errorf("fn must have one return value with type %s", t.retType.Name()) + } } return nil } -func (t *CommandTrie[TCtx]) Execute(ctx TCtx, cmdWords ...string) error { +func (t *anyCommandTrie) checkFnArgs(typ reflect.Type) error { + if t.ctxType != nil { + if typ.NumIn() < 1 { + return fmt.Errorf("fn must have a ctx argument") + } + + for i := 0; i < typ.NumIn(); i++ { + argType := typ.In(i) + if i == 0 && argType != t.ctxType { + return fmt.Errorf("first argument of fn must be %s", t.ctxType.Name()) + } + + if argType.Kind() == reflect.Array && i < typ.NumIn()-1 { + return fmt.Errorf("array argument must at the last one") + } + } + } else { + for i := 0; i < typ.NumIn(); i++ { + argType := typ.In(i) + if argType.Kind() == reflect.Array && i < typ.NumIn()-1 { + return fmt.Errorf("array argument must at the last one") + } + } + } + return nil +} + +func (t *anyCommandTrie) Execute(ctx any, cmdWords ...string) ([]reflect.Value, error) { var cmd *command var argWords []string + cmd, argWords, err := t.findCommand(cmdWords, argWords) + if err != nil { + return nil, err + } + + if cmd.lastIsArray { + // 最后一个参数如果是数组,那么可以少一个参数 + if len(argWords) < len(cmd.staticArgTypes)-1 { + return nil, fmt.Errorf("no enough arguments for command") + } + } else if len(argWords) < len(cmd.staticArgTypes) { + return nil, fmt.Errorf("no enough arguments for command") + } + + var callArgs []reflect.Value + + // 如果有Ctx参数,则加上Ctx参数 + if t.ctxType != nil { + callArgs = append(callArgs, reflect.ValueOf(ctx)) + } + + // 数组参数只能是最后一个,所以先处理最后一个参数前的参数 + callArgs, err = t.parseFrontArgs(cmd, argWords, callArgs) + if err != nil { + return nil, err + } + + // 解析最后一个参数 + callArgs, err = t.parseLastArg(cmd, argWords, callArgs) + if err != nil { + return nil, err + } + + return cmd.fn.Call(callArgs), nil +} + +func (t *anyCommandTrie) findCommand(cmdWords []string, argWords []string) (*command, []string, error) { + var cmd *command + ptr := &t.root for i := 0; i < len(cmdWords); i++ { next, ok := ptr.nexts[cmdWords[i]] @@ -76,31 +196,33 @@ func (t *CommandTrie[TCtx]) Execute(ctx TCtx, cmdWords ...string) error { ptr = next } if cmd == nil { - return fmt.Errorf("command not found") - } - - fnType := cmd.fn.Type() - - // 最后一个参数如果是数组,那么可以少一个参数 - if len(argWords) < fnType.NumIn()-1 { - return fmt.Errorf("no enough arguments for command") + return nil, nil, fmt.Errorf("command not found") } + return cmd, argWords, nil +} - var callArgs []reflect.Value - - // 数组参数只能是最后一个,所以先处理最后一个参数前的参数 - for i := 0; i < fnType.NumIn()-1; i++ { - val, err := t.parseValue(argWords[i], fnType.In(i)) +func (t *anyCommandTrie) parseFrontArgs(cmd *command, argWords []string, callArgs []reflect.Value) ([]reflect.Value, error) { + for i := 0; i < len(cmd.staticArgTypes)-1; i++ { + val, err := t.parseValue(argWords[i], cmd.staticArgTypes[i]) if err != nil { - return fmt.Errorf("cannot parse function argument at %d, err: %s", i, err.Error()) + // 如果有Ctx参数,则参数的位置要往后一个 + argIndex := i + if t.ctxType != nil { + argIndex++ + } + + return nil, fmt.Errorf("cannot parse function argument at %d, err: %s", argIndex, err.Error()) } callArgs = append(callArgs, val) } + return callArgs, nil +} - if fnType.NumIn() > 0 { - lastArgType := fnType.In(fnType.NumIn() - 1) - lastArgWords := argWords[fnType.NumIn()-1:] +func (t *anyCommandTrie) parseLastArg(cmd *command, argWords []string, callArgs []reflect.Value) ([]reflect.Value, error) { + if len(cmd.staticArgTypes) > 0 { + lastArgType := cmd.staticArgTypes[len(cmd.staticArgTypes)-1] + lastArgWords := argWords[len(cmd.staticArgTypes)-1:] lastArgTypeKind := lastArgType.Kind() var lastArg reflect.Value @@ -114,31 +236,29 @@ func (t *CommandTrie[TCtx]) Execute(ctx TCtx, cmdWords ...string) error { for i := 0; i < len(lastArgWords); i++ { eleVal, err := t.parseValue(lastArgWords[i], lastArgType.Elem()) if err != nil { - return fmt.Errorf("cannot parse as array element, err: %s", err.Error()) + return nil, fmt.Errorf("cannot parse as array element, err: %s", err.Error()) } lastArg.Index(i).Set(eleVal) } } else { if len(lastArgWords) == 0 { - return fmt.Errorf("no enough arguments for command") + return nil, fmt.Errorf("no enough arguments for command") } var err error lastArg, err = t.parseValue(lastArgWords[0], lastArgType) if err != nil { - return fmt.Errorf("cannot parse function argument at %d, err: %s", fnType.NumIn()-1, err.Error()) + return nil, fmt.Errorf("cannot parse function argument at %d, err: %s", cmd.fnType.NumIn()-1, err.Error()) } } callArgs = append(callArgs, lastArg) } - - cmd.fn.Call(callArgs) - return nil + return callArgs, nil } -func (t *CommandTrie[TCtx]) parseValue(word string, valueType reflect.Type) (reflect.Value, error) { +func (t *anyCommandTrie) parseValue(word string, valueType reflect.Type) (reflect.Value, error) { valTypeKind := valueType.Kind() if valTypeKind == reflect.String { @@ -183,3 +303,70 @@ func (t *CommandTrie[TCtx]) parseValue(word string, valueType reflect.Type) (ref return reflect.Value{}, fmt.Errorf("cannot parse string as %s", valueType.Name()) } + +type CommandTrie[TCtx any, TRet any] struct { + anyTrie anyCommandTrie +} + +func NewCommandTrie[TCtx any, TRet any]() CommandTrie[TCtx, TRet] { + return CommandTrie[TCtx, TRet]{ + anyTrie: newAnyCommandTrie(myreflect.TypeOf[TCtx](), myreflect.TypeOf[TRet]()), + } +} + +func (t *CommandTrie[TCtx, TRet]) Add(fn any, prefixWords ...string) error { + return t.anyTrie.Add(fn, prefixWords...) +} + +func (t *CommandTrie[TCtx, TRet]) Execute(ctx TCtx, cmdWords ...string) (TRet, error) { + retValues, err := t.anyTrie.Execute(ctx, cmdWords...) + if err != nil { + var defRet TRet + return defRet, err + } + + return retValues[0].Interface().(TRet), nil +} + +type VoidCommandTrie[TCtx any] struct { + anyTrie anyCommandTrie +} + +func NewVoidCommandTrie[TCtx any]() VoidCommandTrie[TCtx] { + return VoidCommandTrie[TCtx]{ + anyTrie: newAnyCommandTrie(myreflect.TypeOf[TCtx](), nil), + } +} + +func (t *VoidCommandTrie[TCtx]) Add(fn any, prefixWords ...string) error { + return t.anyTrie.Add(fn, prefixWords...) +} + +func (t *VoidCommandTrie[TCtx]) Execute(ctx TCtx, cmdWords ...string) error { + _, err := t.anyTrie.Execute(ctx, cmdWords...) + return err +} + +type StaticCommandTrie[TRet any] struct { + anyTrie anyCommandTrie +} + +func NewStaticCommandTrie[TRet any]() StaticCommandTrie[TRet] { + return StaticCommandTrie[TRet]{ + anyTrie: newAnyCommandTrie(nil, myreflect.TypeOf[TRet]()), + } +} + +func (t *StaticCommandTrie[TRet]) Add(fn any, prefixWords ...string) error { + return t.anyTrie.Add(fn, prefixWords...) +} + +func (t *StaticCommandTrie[TRet]) Execute(cmdWords ...string) (TRet, error) { + retValues, err := t.anyTrie.Execute(nil, cmdWords...) + if err != nil { + var defRet TRet + return defRet, err + } + + return retValues[0].Interface().(TRet), nil +} diff --git a/pkg/cmdtrie/command_trie_test.go b/pkg/cmdtrie/command_trie_test.go index 6bfa095..6417ee5 100644 --- a/pkg/cmdtrie/command_trie_test.go +++ b/pkg/cmdtrie/command_trie_test.go @@ -8,37 +8,39 @@ import ( func Test_CommandTrie(t *testing.T) { Convey("无参数命令", t, func() { - trie := NewCommandTrie[int]() + trie := NewVoidCommandTrie[int]() var ret string - trie.Add(func() { + err := trie.Add(func(int) { ret = "ok" }, "a") + So(err, ShouldBeNil) - err := trie.Execute(0, "a") + err = trie.Execute(0, "a") So(err, ShouldBeNil) So(ret, ShouldEqual, "ok") }) Convey("各种参数", t, func() { - trie := NewCommandTrie[int]() + trie := NewVoidCommandTrie[int]() var argI int var argStr string var argBl bool var argFP float32 - trie.Add(func(i int, str string, bl bool, fp float32) { + err := trie.Add(func(int, i int, str string, bl bool, fp float32) { argI = i argStr = str argBl = bl argFP = fp }, "a", "b") + So(err, ShouldBeNil) - err := trie.Execute(0, "a", "b", "1", "2", "true", "3") + err = trie.Execute(0, "a", "b", "1", "2", "true", "3") So(err, ShouldBeNil) So(argI, ShouldEqual, 1) @@ -48,18 +50,19 @@ func Test_CommandTrie(t *testing.T) { }) Convey("有数组参数", t, func() { - trie := NewCommandTrie[int]() + trie := NewVoidCommandTrie[int]() var argI int var argArr []int64 - trie.Add(func(i int, arr []int64) { + err := trie.Add(func(int, i int, arr []int64) { argI = i argArr = arr }, "a", "b") + So(err, ShouldBeNil) - err := trie.Execute(0, "a", "b", "1", "2", "3", "4") + err = trie.Execute(0, "a", "b", "1", "2", "3", "4") So(err, ShouldBeNil) So(argI, ShouldEqual, 1) @@ -67,21 +70,114 @@ func Test_CommandTrie(t *testing.T) { }) Convey("有数组参数,但为空", t, func() { - trie := NewCommandTrie[int]() + trie := NewVoidCommandTrie[int]() var argI int var argArr []int64 - trie.Add(func(i int, arr []int64) { + err := trie.Add(func(int, i int, arr []int64) { argI = i argArr = arr }, "a", "b") + So(err, ShouldBeNil) + + err = trie.Execute(0, "a", "b", "1") + So(err, ShouldBeNil) + + So(argI, ShouldEqual, 1) + So(argArr, ShouldResemble, []int64{}) + }) + + Convey("带返回值", t, func() { + trie := NewCommandTrie[int, int]() + + var argI int + var argArr []int64 + + err := trie.Add(func(int, i int, arr []int64) int { + argI = i + argArr = arr + return 123 + }, "a", "b") + So(err, ShouldBeNil) + + ret, err := trie.Execute(0, "a", "b", "1") + So(err, ShouldBeNil) + + So(argI, ShouldEqual, 1) + So(argArr, ShouldResemble, []int64{}) + So(ret, ShouldEqual, 123) + }) + + Convey("返回值是接口类型", t, func() { + trie := NewCommandTrie[int, any]() + + var argI int + var argArr []int64 + + err := trie.Add(func(int, i int, arr []int64) int { + argI = i + argArr = arr + return 123 + }, "a", "b") + So(err, ShouldBeNil) + + err = trie.Add(func(int, i int, arr []int64) string { + return "123" + }, "a", "c") + So(err, ShouldBeNil) + + ret, err := trie.Execute(0, "a", "b", "1") + So(err, ShouldBeNil) + So(argI, ShouldEqual, 1) + So(argArr, ShouldResemble, []int64{}) + So(ret, ShouldEqual, 123) + + ret2, err := trie.Execute(0, "a", "c", "1") + So(err, ShouldBeNil) + So(ret2, ShouldEqual, "123") + }) + + Convey("无Ctx参数", t, func() { + trie := NewStaticCommandTrie[int]() + + var argI int + var argArr []int64 + + err := trie.Add(func(i int, arr []int64) int { + argI = i + argArr = arr + return 123 + }, "a", "b") + So(err, ShouldBeNil) + + ret, err := trie.Execute("a", "b", "1") + So(err, ShouldBeNil) + + So(argI, ShouldEqual, 1) + So(argArr, ShouldResemble, []int64{}) + So(ret, ShouldEqual, 123) + }) + + Convey("完全无参数", t, func() { + trie := NewStaticCommandTrie[int]() + + var argI int + var argArr []int64 + + err := trie.Add(func() int { + argI = 1 + argArr = []int64{} + return 123 + }, "a", "b") + So(err, ShouldBeNil) - err := trie.Execute(0, "a", "b", "1") + ret, err := trie.Execute("a", "b") So(err, ShouldBeNil) So(argI, ShouldEqual, 1) So(argArr, ShouldResemble, []int64{}) + So(ret, ShouldEqual, 123) }) }