From 8c9e40f772b0999ac4f703a7e58fca705e232d09 Mon Sep 17 00:00:00 2001 From: Sydonian <794346190@qq.com> Date: Wed, 7 Jun 2023 15:16:32 +0800 Subject: [PATCH] =?UTF-8?q?=E5=A2=9E=E5=8A=A0trie=E6=95=B0=E6=8D=AE?= =?UTF-8?q?=E7=BB=93=E6=9E=84=EF=BC=8C=E4=BC=98=E5=8C=96cmdtrie=E7=9A=84?= =?UTF-8?q?=E5=AE=9E=E7=8E=B0=E6=96=B9=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pkg/cmdtrie/command_trie.go | 42 +++--------- pkg/cmdtrie/command_trie_test.go | 23 +++++++ pkg/trie/trie.go | 111 +++++++++++++++++++++++++++++++ pkg/trie/trie_test.go | 80 ++++++++++++++++++++++ 4 files changed, 224 insertions(+), 32 deletions(-) create mode 100644 pkg/trie/trie.go create mode 100644 pkg/trie/trie_test.go diff --git a/pkg/cmdtrie/command_trie.go b/pkg/cmdtrie/command_trie.go index 35333bb..f2f5e5c 100644 --- a/pkg/cmdtrie/command_trie.go +++ b/pkg/cmdtrie/command_trie.go @@ -5,6 +5,8 @@ import ( "reflect" "strconv" + "github.com/samber/lo" + "gitlink.org.cn/cloudream/common/pkg/trie" myreflect "gitlink.org.cn/cloudream/common/utils/reflect" ) @@ -19,22 +21,14 @@ type command struct { lastIsArray bool } -type trieNode struct { - nexts map[string]*trieNode - cmd *command -} - type anyCommandTrie struct { - root trieNode + trie trie.Trie[*command] ctxType reflect.Type retType reflect.Type } func newAnyCommandTrie(ctxType reflect.Type, retType reflect.Type) anyCommandTrie { return anyCommandTrie{ - root: trieNode{ - nexts: make(map[string]*trieNode), - }, ctxType: ctxType, retType: retType, } @@ -56,17 +50,7 @@ func (t *anyCommandTrie) Add(fn any, prefixWords ...string) error { return err } - ptr := &t.root - for _, word := range prefixWords { - next, ok := ptr.nexts[word] - if !ok { - next = &trieNode{ - nexts: make(map[string]*trieNode), - } - ptr.nexts[word] = next - } - ptr = next - } + node := t.trie.Create(lo.Map(prefixWords, func(val string, index int) any { return val })) fnType := reflect.TypeOf(fn) var staticArgTypes []reflect.Type @@ -86,7 +70,7 @@ func (t *anyCommandTrie) Add(fn any, prefixWords ...string) error { lastIsArray = kind == reflect.Array || kind == reflect.Slice } - ptr.cmd = &command{ + node.Value = &command{ fn: reflect.ValueOf(fn), fnType: reflect.TypeOf(fn), staticArgTypes: staticArgTypes, @@ -186,19 +170,13 @@ func (t *anyCommandTrie) Execute(ctx any, cmdWords []string, opt ExecuteOption) func (t *anyCommandTrie) findCommand(cmdWords []string, argWords []string) (*command, []string, error) { var cmd *command - ptr := &t.root - for i := 0; i < len(cmdWords); i++ { - next, ok := ptr.nexts[cmdWords[i]] - if !ok { - break - } - if next != nil { - cmd = next.cmd - argWords = cmdWords[i+1:] + t.trie.Walk(cmdWords, func(word string, index int, node *trie.Node[*command], isWordNode bool) { + if node.Value != nil { + cmd = node.Value + argWords = cmdWords[index+1:] } + }) - ptr = next - } if cmd == nil { return nil, nil, fmt.Errorf("command not found") } diff --git a/pkg/cmdtrie/command_trie_test.go b/pkg/cmdtrie/command_trie_test.go index 2edf073..07dcb96 100644 --- a/pkg/cmdtrie/command_trie_test.go +++ b/pkg/cmdtrie/command_trie_test.go @@ -198,4 +198,27 @@ func Test_CommandTrie(t *testing.T) { So(argStrs, ShouldBeNil) So(ret, ShouldEqual, 123) }) + + Convey("前缀有重叠的命令,取前缀最长的", t, func() { + trie := NewStaticCommandTrie[int]() + + var argStrs []string + err := trie.Add(func(strs []string) int { + argStrs = strs + return 123 + }, "a", "b", "c") + So(err, ShouldBeNil) + + err = trie.Add(func(strs []string) int { + argStrs = strs + return 456 + }, "a", "b", "c", "d") + So(err, ShouldBeNil) + + ret, err := trie.Execute([]string{"a", "b", "c", "d", "e"}, ExecuteOption{ReplaceEmptyArrayWithNil: true}) + So(err, ShouldBeNil) + + So(argStrs, ShouldResemble, []string{"e"}) + So(ret, ShouldEqual, 456) + }) } diff --git a/pkg/trie/trie.go b/pkg/trie/trie.go new file mode 100644 index 0000000..9fd84ba --- /dev/null +++ b/pkg/trie/trie.go @@ -0,0 +1,111 @@ +package trie + +const ( + WORD_ANY = 0 +) + +type Node[T any] struct { + WordNexts map[string]*Node[T] + AnyNext *Node[T] + Value T +} + +func (n *Node[T]) WalkNext(word string) *Node[T] { + if n.WordNexts == nil { + return n.AnyNext + } + + node, ok := n.WordNexts[word] + if ok { + return node + } + + return n.AnyNext +} + +func (n *Node[T]) walkWordNext(word string) (*Node[T], bool) { + if n.WordNexts == nil { + return n.AnyNext, false + } + + node, ok := n.WordNexts[word] + if ok { + return node, true + } + + return n.AnyNext, false +} + +func (n *Node[T]) Create(word string) *Node[T] { + if n.WordNexts == nil { + n.WordNexts = make(map[string]*Node[T]) + } + + node, ok := n.WordNexts[word] + if !ok { + node = &Node[T]{} + n.WordNexts[word] = node + } + + return node +} + +func (n *Node[T]) CreateAny() *Node[T] { + if n.AnyNext == nil { + n.AnyNext = &Node[T]{} + } + + return n.AnyNext +} + +type Trie[T any] struct { + Root Node[T] +} + +func (t *Trie[T]) Walk(words []string, visitorFn func(word string, wordIndex int, node *Node[T], isWordNode bool)) bool { + ptr := &t.Root + + for index, word := range words { + var isWord bool + ptr, isWord = ptr.walkWordNext(word) + if ptr == nil { + return false + } + + visitorFn(word, index, ptr, isWord) + } + + return true +} + +func (t *Trie[T]) WalkEnd(words []string) (*Node[T], bool) { + ptr := &t.Root + + for _, word := range words { + ptr = ptr.WalkNext(word) + if ptr == nil { + return nil, false + } + } + + return ptr, true +} + +func (t *Trie[T]) Create(words []any) *Node[T] { + ptr := &t.Root + + for _, word := range words { + switch val := word.(type) { + case string: + ptr = ptr.Create(val) + + case int: + ptr = ptr.CreateAny() + + default: + panic("word can only be string or int 0") + } + } + + return ptr +} diff --git a/pkg/trie/trie_test.go b/pkg/trie/trie_test.go new file mode 100644 index 0000000..cdf50ee --- /dev/null +++ b/pkg/trie/trie_test.go @@ -0,0 +1,80 @@ +package trie + +import ( + "testing" + + . "github.com/smartystreets/goconvey/convey" +) + +func Test_CommandTrie(t *testing.T) { + Convey("全是Word节点", t, func() { + trie := Trie[int]{} + + { + n := trie.Create([]any{"a", "b"}) + So(n, ShouldNotBeNil) + n.Value = 123 + } + + { + n, ok := trie.WalkEnd([]string{"a", "b"}) + So(n, ShouldNotBeNil) + So(ok, ShouldBeTrue) + So(n.Value, ShouldEqual, 123) + } + }) + + Convey("包含Any节点", t, func() { + trie := Trie[int]{} + + { + n := trie.Create([]any{"a", WORD_ANY, "b"}) + So(n, ShouldNotBeNil) + n.Value = 123 + } + + { + n, ok := trie.WalkEnd([]string{"a", "11", "b"}) + So(n, ShouldNotBeNil) + So(ok, ShouldBeTrue) + So(n.Value, ShouldEqual, 123) + } + + { + n, ok := trie.WalkEnd([]string{"a", "22", "b"}) + So(n, ShouldNotBeNil) + So(ok, ShouldBeTrue) + So(n.Value, ShouldEqual, 123) + } + }) + + Convey("优先经过Word节点", t, func() { + trie := Trie[int]{} + + { + n := trie.Create([]any{"a", "b", "c"}) + So(n, ShouldNotBeNil) + n.Value = 123 + } + + { + n := trie.Create([]any{"a", WORD_ANY, "c"}) + So(n, ShouldNotBeNil) + n.Value = 456 + } + + { + n, ok := trie.WalkEnd([]string{"a", "b", "c"}) + So(n, ShouldNotBeNil) + So(ok, ShouldBeTrue) + So(n.Value, ShouldEqual, 123) + } + + { + n, ok := trie.WalkEnd([]string{"a", "d", "c"}) + So(n, ShouldNotBeNil) + So(ok, ShouldBeTrue) + So(n.Value, ShouldEqual, 456) + } + }) +}