Browse Source

增加trie数据结构,优化cmdtrie的实现方式

pull/1/head
Sydonian 2 years ago
parent
commit
8c9e40f772
4 changed files with 224 additions and 32 deletions
  1. +10
    -32
      pkg/cmdtrie/command_trie.go
  2. +23
    -0
      pkg/cmdtrie/command_trie_test.go
  3. +111
    -0
      pkg/trie/trie.go
  4. +80
    -0
      pkg/trie/trie_test.go

+ 10
- 32
pkg/cmdtrie/command_trie.go View File

@@ -5,6 +5,8 @@ import (
"reflect"
"strconv"

"github.com/samber/lo"
"gitlink.org.cn/cloudream/common/pkg/trie"
myreflect "gitlink.org.cn/cloudream/common/utils/reflect"
)

@@ -19,22 +21,14 @@ type command struct {
lastIsArray bool
}

type trieNode struct {
nexts map[string]*trieNode
cmd *command
}

type anyCommandTrie struct {
root trieNode
trie trie.Trie[*command]
ctxType reflect.Type
retType reflect.Type
}

func newAnyCommandTrie(ctxType reflect.Type, retType reflect.Type) anyCommandTrie {
return anyCommandTrie{
root: trieNode{
nexts: make(map[string]*trieNode),
},
ctxType: ctxType,
retType: retType,
}
@@ -56,17 +50,7 @@ func (t *anyCommandTrie) Add(fn any, prefixWords ...string) error {
return err
}

ptr := &t.root
for _, word := range prefixWords {
next, ok := ptr.nexts[word]
if !ok {
next = &trieNode{
nexts: make(map[string]*trieNode),
}
ptr.nexts[word] = next
}
ptr = next
}
node := t.trie.Create(lo.Map(prefixWords, func(val string, index int) any { return val }))

fnType := reflect.TypeOf(fn)
var staticArgTypes []reflect.Type
@@ -86,7 +70,7 @@ func (t *anyCommandTrie) Add(fn any, prefixWords ...string) error {
lastIsArray = kind == reflect.Array || kind == reflect.Slice
}

ptr.cmd = &command{
node.Value = &command{
fn: reflect.ValueOf(fn),
fnType: reflect.TypeOf(fn),
staticArgTypes: staticArgTypes,
@@ -186,19 +170,13 @@ func (t *anyCommandTrie) Execute(ctx any, cmdWords []string, opt ExecuteOption)
func (t *anyCommandTrie) findCommand(cmdWords []string, argWords []string) (*command, []string, error) {
var cmd *command

ptr := &t.root
for i := 0; i < len(cmdWords); i++ {
next, ok := ptr.nexts[cmdWords[i]]
if !ok {
break
}
if next != nil {
cmd = next.cmd
argWords = cmdWords[i+1:]
t.trie.Walk(cmdWords, func(word string, index int, node *trie.Node[*command], isWordNode bool) {
if node.Value != nil {
cmd = node.Value
argWords = cmdWords[index+1:]
}
})

ptr = next
}
if cmd == nil {
return nil, nil, fmt.Errorf("command not found")
}


+ 23
- 0
pkg/cmdtrie/command_trie_test.go View File

@@ -198,4 +198,27 @@ func Test_CommandTrie(t *testing.T) {
So(argStrs, ShouldBeNil)
So(ret, ShouldEqual, 123)
})

Convey("前缀有重叠的命令,取前缀最长的", t, func() {
trie := NewStaticCommandTrie[int]()

var argStrs []string
err := trie.Add(func(strs []string) int {
argStrs = strs
return 123
}, "a", "b", "c")
So(err, ShouldBeNil)

err = trie.Add(func(strs []string) int {
argStrs = strs
return 456
}, "a", "b", "c", "d")
So(err, ShouldBeNil)

ret, err := trie.Execute([]string{"a", "b", "c", "d", "e"}, ExecuteOption{ReplaceEmptyArrayWithNil: true})
So(err, ShouldBeNil)

So(argStrs, ShouldResemble, []string{"e"})
So(ret, ShouldEqual, 456)
})
}

+ 111
- 0
pkg/trie/trie.go View File

@@ -0,0 +1,111 @@
package trie

const (
WORD_ANY = 0
)

type Node[T any] struct {
WordNexts map[string]*Node[T]
AnyNext *Node[T]
Value T
}

func (n *Node[T]) WalkNext(word string) *Node[T] {
if n.WordNexts == nil {
return n.AnyNext
}

node, ok := n.WordNexts[word]
if ok {
return node
}

return n.AnyNext
}

func (n *Node[T]) walkWordNext(word string) (*Node[T], bool) {
if n.WordNexts == nil {
return n.AnyNext, false
}

node, ok := n.WordNexts[word]
if ok {
return node, true
}

return n.AnyNext, false
}

func (n *Node[T]) Create(word string) *Node[T] {
if n.WordNexts == nil {
n.WordNexts = make(map[string]*Node[T])
}

node, ok := n.WordNexts[word]
if !ok {
node = &Node[T]{}
n.WordNexts[word] = node
}

return node
}

func (n *Node[T]) CreateAny() *Node[T] {
if n.AnyNext == nil {
n.AnyNext = &Node[T]{}
}

return n.AnyNext
}

type Trie[T any] struct {
Root Node[T]
}

func (t *Trie[T]) Walk(words []string, visitorFn func(word string, wordIndex int, node *Node[T], isWordNode bool)) bool {
ptr := &t.Root

for index, word := range words {
var isWord bool
ptr, isWord = ptr.walkWordNext(word)
if ptr == nil {
return false
}

visitorFn(word, index, ptr, isWord)
}

return true
}

func (t *Trie[T]) WalkEnd(words []string) (*Node[T], bool) {
ptr := &t.Root

for _, word := range words {
ptr = ptr.WalkNext(word)
if ptr == nil {
return nil, false
}
}

return ptr, true
}

func (t *Trie[T]) Create(words []any) *Node[T] {
ptr := &t.Root

for _, word := range words {
switch val := word.(type) {
case string:
ptr = ptr.Create(val)

case int:
ptr = ptr.CreateAny()

default:
panic("word can only be string or int 0")
}
}

return ptr
}

+ 80
- 0
pkg/trie/trie_test.go View File

@@ -0,0 +1,80 @@
package trie

import (
"testing"

. "github.com/smartystreets/goconvey/convey"
)

func Test_CommandTrie(t *testing.T) {
Convey("全是Word节点", t, func() {
trie := Trie[int]{}

{
n := trie.Create([]any{"a", "b"})
So(n, ShouldNotBeNil)
n.Value = 123
}

{
n, ok := trie.WalkEnd([]string{"a", "b"})
So(n, ShouldNotBeNil)
So(ok, ShouldBeTrue)
So(n.Value, ShouldEqual, 123)
}
})

Convey("包含Any节点", t, func() {
trie := Trie[int]{}

{
n := trie.Create([]any{"a", WORD_ANY, "b"})
So(n, ShouldNotBeNil)
n.Value = 123
}

{
n, ok := trie.WalkEnd([]string{"a", "11", "b"})
So(n, ShouldNotBeNil)
So(ok, ShouldBeTrue)
So(n.Value, ShouldEqual, 123)
}

{
n, ok := trie.WalkEnd([]string{"a", "22", "b"})
So(n, ShouldNotBeNil)
So(ok, ShouldBeTrue)
So(n.Value, ShouldEqual, 123)
}
})

Convey("优先经过Word节点", t, func() {
trie := Trie[int]{}

{
n := trie.Create([]any{"a", "b", "c"})
So(n, ShouldNotBeNil)
n.Value = 123
}

{
n := trie.Create([]any{"a", WORD_ANY, "c"})
So(n, ShouldNotBeNil)
n.Value = 456
}

{
n, ok := trie.WalkEnd([]string{"a", "b", "c"})
So(n, ShouldNotBeNil)
So(ok, ShouldBeTrue)
So(n.Value, ShouldEqual, 123)
}

{
n, ok := trie.WalkEnd([]string{"a", "d", "c"})
So(n, ShouldNotBeNil)
So(ok, ShouldBeTrue)
So(n.Value, ShouldEqual, 456)
}
})
}

Loading…
Cancel
Save