| @@ -5,6 +5,8 @@ import ( | |||
| "reflect" | |||
| "strconv" | |||
| "github.com/samber/lo" | |||
| "gitlink.org.cn/cloudream/common/pkg/trie" | |||
| myreflect "gitlink.org.cn/cloudream/common/utils/reflect" | |||
| ) | |||
| @@ -19,22 +21,14 @@ type command struct { | |||
| lastIsArray bool | |||
| } | |||
| type trieNode struct { | |||
| nexts map[string]*trieNode | |||
| cmd *command | |||
| } | |||
| type anyCommandTrie struct { | |||
| root trieNode | |||
| trie trie.Trie[*command] | |||
| ctxType reflect.Type | |||
| retType reflect.Type | |||
| } | |||
| func newAnyCommandTrie(ctxType reflect.Type, retType reflect.Type) anyCommandTrie { | |||
| return anyCommandTrie{ | |||
| root: trieNode{ | |||
| nexts: make(map[string]*trieNode), | |||
| }, | |||
| ctxType: ctxType, | |||
| retType: retType, | |||
| } | |||
| @@ -56,17 +50,7 @@ func (t *anyCommandTrie) Add(fn any, prefixWords ...string) error { | |||
| return err | |||
| } | |||
| ptr := &t.root | |||
| for _, word := range prefixWords { | |||
| next, ok := ptr.nexts[word] | |||
| if !ok { | |||
| next = &trieNode{ | |||
| nexts: make(map[string]*trieNode), | |||
| } | |||
| ptr.nexts[word] = next | |||
| } | |||
| ptr = next | |||
| } | |||
| node := t.trie.Create(lo.Map(prefixWords, func(val string, index int) any { return val })) | |||
| fnType := reflect.TypeOf(fn) | |||
| var staticArgTypes []reflect.Type | |||
| @@ -86,7 +70,7 @@ func (t *anyCommandTrie) Add(fn any, prefixWords ...string) error { | |||
| lastIsArray = kind == reflect.Array || kind == reflect.Slice | |||
| } | |||
| ptr.cmd = &command{ | |||
| node.Value = &command{ | |||
| fn: reflect.ValueOf(fn), | |||
| fnType: reflect.TypeOf(fn), | |||
| staticArgTypes: staticArgTypes, | |||
| @@ -186,19 +170,13 @@ func (t *anyCommandTrie) Execute(ctx any, cmdWords []string, opt ExecuteOption) | |||
| func (t *anyCommandTrie) findCommand(cmdWords []string, argWords []string) (*command, []string, error) { | |||
| var cmd *command | |||
| ptr := &t.root | |||
| for i := 0; i < len(cmdWords); i++ { | |||
| next, ok := ptr.nexts[cmdWords[i]] | |||
| if !ok { | |||
| break | |||
| } | |||
| if next != nil { | |||
| cmd = next.cmd | |||
| argWords = cmdWords[i+1:] | |||
| t.trie.Walk(cmdWords, func(word string, index int, node *trie.Node[*command], isWordNode bool) { | |||
| if node.Value != nil { | |||
| cmd = node.Value | |||
| argWords = cmdWords[index+1:] | |||
| } | |||
| }) | |||
| ptr = next | |||
| } | |||
| if cmd == nil { | |||
| return nil, nil, fmt.Errorf("command not found") | |||
| } | |||
| @@ -198,4 +198,27 @@ func Test_CommandTrie(t *testing.T) { | |||
| So(argStrs, ShouldBeNil) | |||
| So(ret, ShouldEqual, 123) | |||
| }) | |||
| Convey("前缀有重叠的命令,取前缀最长的", t, func() { | |||
| trie := NewStaticCommandTrie[int]() | |||
| var argStrs []string | |||
| err := trie.Add(func(strs []string) int { | |||
| argStrs = strs | |||
| return 123 | |||
| }, "a", "b", "c") | |||
| So(err, ShouldBeNil) | |||
| err = trie.Add(func(strs []string) int { | |||
| argStrs = strs | |||
| return 456 | |||
| }, "a", "b", "c", "d") | |||
| So(err, ShouldBeNil) | |||
| ret, err := trie.Execute([]string{"a", "b", "c", "d", "e"}, ExecuteOption{ReplaceEmptyArrayWithNil: true}) | |||
| So(err, ShouldBeNil) | |||
| So(argStrs, ShouldResemble, []string{"e"}) | |||
| So(ret, ShouldEqual, 456) | |||
| }) | |||
| } | |||
| @@ -0,0 +1,111 @@ | |||
| package trie | |||
| const ( | |||
| WORD_ANY = 0 | |||
| ) | |||
| type Node[T any] struct { | |||
| WordNexts map[string]*Node[T] | |||
| AnyNext *Node[T] | |||
| Value T | |||
| } | |||
| func (n *Node[T]) WalkNext(word string) *Node[T] { | |||
| if n.WordNexts == nil { | |||
| return n.AnyNext | |||
| } | |||
| node, ok := n.WordNexts[word] | |||
| if ok { | |||
| return node | |||
| } | |||
| return n.AnyNext | |||
| } | |||
| func (n *Node[T]) walkWordNext(word string) (*Node[T], bool) { | |||
| if n.WordNexts == nil { | |||
| return n.AnyNext, false | |||
| } | |||
| node, ok := n.WordNexts[word] | |||
| if ok { | |||
| return node, true | |||
| } | |||
| return n.AnyNext, false | |||
| } | |||
| func (n *Node[T]) Create(word string) *Node[T] { | |||
| if n.WordNexts == nil { | |||
| n.WordNexts = make(map[string]*Node[T]) | |||
| } | |||
| node, ok := n.WordNexts[word] | |||
| if !ok { | |||
| node = &Node[T]{} | |||
| n.WordNexts[word] = node | |||
| } | |||
| return node | |||
| } | |||
| func (n *Node[T]) CreateAny() *Node[T] { | |||
| if n.AnyNext == nil { | |||
| n.AnyNext = &Node[T]{} | |||
| } | |||
| return n.AnyNext | |||
| } | |||
| type Trie[T any] struct { | |||
| Root Node[T] | |||
| } | |||
| func (t *Trie[T]) Walk(words []string, visitorFn func(word string, wordIndex int, node *Node[T], isWordNode bool)) bool { | |||
| ptr := &t.Root | |||
| for index, word := range words { | |||
| var isWord bool | |||
| ptr, isWord = ptr.walkWordNext(word) | |||
| if ptr == nil { | |||
| return false | |||
| } | |||
| visitorFn(word, index, ptr, isWord) | |||
| } | |||
| return true | |||
| } | |||
| func (t *Trie[T]) WalkEnd(words []string) (*Node[T], bool) { | |||
| ptr := &t.Root | |||
| for _, word := range words { | |||
| ptr = ptr.WalkNext(word) | |||
| if ptr == nil { | |||
| return nil, false | |||
| } | |||
| } | |||
| return ptr, true | |||
| } | |||
| func (t *Trie[T]) Create(words []any) *Node[T] { | |||
| ptr := &t.Root | |||
| for _, word := range words { | |||
| switch val := word.(type) { | |||
| case string: | |||
| ptr = ptr.Create(val) | |||
| case int: | |||
| ptr = ptr.CreateAny() | |||
| default: | |||
| panic("word can only be string or int 0") | |||
| } | |||
| } | |||
| return ptr | |||
| } | |||
| @@ -0,0 +1,80 @@ | |||
| package trie | |||
| import ( | |||
| "testing" | |||
| . "github.com/smartystreets/goconvey/convey" | |||
| ) | |||
| func Test_CommandTrie(t *testing.T) { | |||
| Convey("全是Word节点", t, func() { | |||
| trie := Trie[int]{} | |||
| { | |||
| n := trie.Create([]any{"a", "b"}) | |||
| So(n, ShouldNotBeNil) | |||
| n.Value = 123 | |||
| } | |||
| { | |||
| n, ok := trie.WalkEnd([]string{"a", "b"}) | |||
| So(n, ShouldNotBeNil) | |||
| So(ok, ShouldBeTrue) | |||
| So(n.Value, ShouldEqual, 123) | |||
| } | |||
| }) | |||
| Convey("包含Any节点", t, func() { | |||
| trie := Trie[int]{} | |||
| { | |||
| n := trie.Create([]any{"a", WORD_ANY, "b"}) | |||
| So(n, ShouldNotBeNil) | |||
| n.Value = 123 | |||
| } | |||
| { | |||
| n, ok := trie.WalkEnd([]string{"a", "11", "b"}) | |||
| So(n, ShouldNotBeNil) | |||
| So(ok, ShouldBeTrue) | |||
| So(n.Value, ShouldEqual, 123) | |||
| } | |||
| { | |||
| n, ok := trie.WalkEnd([]string{"a", "22", "b"}) | |||
| So(n, ShouldNotBeNil) | |||
| So(ok, ShouldBeTrue) | |||
| So(n.Value, ShouldEqual, 123) | |||
| } | |||
| }) | |||
| Convey("优先经过Word节点", t, func() { | |||
| trie := Trie[int]{} | |||
| { | |||
| n := trie.Create([]any{"a", "b", "c"}) | |||
| So(n, ShouldNotBeNil) | |||
| n.Value = 123 | |||
| } | |||
| { | |||
| n := trie.Create([]any{"a", WORD_ANY, "c"}) | |||
| So(n, ShouldNotBeNil) | |||
| n.Value = 456 | |||
| } | |||
| { | |||
| n, ok := trie.WalkEnd([]string{"a", "b", "c"}) | |||
| So(n, ShouldNotBeNil) | |||
| So(ok, ShouldBeTrue) | |||
| So(n.Value, ShouldEqual, 123) | |||
| } | |||
| { | |||
| n, ok := trie.WalkEnd([]string{"a", "d", "c"}) | |||
| So(n, ShouldNotBeNil) | |||
| So(ok, ShouldBeTrue) | |||
| So(n.Value, ShouldEqual, 456) | |||
| } | |||
| }) | |||
| } | |||