diff --git a/pkg/cmdtrie/command_trie.go b/pkg/cmdtrie/command_trie.go index 0de80f7..a3a2339 100644 --- a/pkg/cmdtrie/command_trie.go +++ b/pkg/cmdtrie/command_trie.go @@ -318,6 +318,13 @@ func (t *CommandTrie[TCtx, TRet]) Add(fn any, prefixWords ...string) error { return t.anyTrie.Add(fn, prefixWords...) } +func (t *CommandTrie[TCtx, TRet]) MustAdd(fn any, prefixWords ...string) { + err := t.anyTrie.Add(fn, prefixWords...) + if err != nil { + panic(err.Error()) + } +} + func (t *CommandTrie[TCtx, TRet]) Execute(ctx TCtx, cmdWords ...string) (TRet, error) { retValues, err := t.anyTrie.Execute(ctx, cmdWords...) if err != nil { @@ -325,6 +332,11 @@ func (t *CommandTrie[TCtx, TRet]) Execute(ctx TCtx, cmdWords ...string) (TRet, e return defRet, err } + if retValues[0].Kind() == reflect.Interface && retValues[0].IsNil() { + var ret TRet + return ret, nil + } + return retValues[0].Interface().(TRet), nil } @@ -342,6 +354,13 @@ func (t *VoidCommandTrie[TCtx]) Add(fn any, prefixWords ...string) error { return t.anyTrie.Add(fn, prefixWords...) } +func (t *VoidCommandTrie[TCtx]) MustAdd(fn any, prefixWords ...string) { + err := t.anyTrie.Add(fn, prefixWords...) + if err != nil { + panic(err.Error()) + } +} + func (t *VoidCommandTrie[TCtx]) Execute(ctx TCtx, cmdWords ...string) error { _, err := t.anyTrie.Execute(ctx, cmdWords...) return err @@ -361,6 +380,13 @@ func (t *StaticCommandTrie[TRet]) Add(fn any, prefixWords ...string) error { return t.anyTrie.Add(fn, prefixWords...) } +func (t *StaticCommandTrie[TRet]) MustAdd(fn any, prefixWords ...string) { + err := t.anyTrie.Add(fn, prefixWords...) + if err != nil { + panic(err.Error()) + } +} + func (t *StaticCommandTrie[TRet]) Execute(cmdWords ...string) (TRet, error) { retValues, err := t.anyTrie.Execute(nil, cmdWords...) if err != nil { @@ -368,5 +394,10 @@ func (t *StaticCommandTrie[TRet]) Execute(cmdWords ...string) (TRet, error) { return defRet, err } + if retValues[0].Kind() == reflect.Interface && retValues[0].IsNil() { + var ret TRet + return ret, nil + } + return retValues[0].Interface().(TRet), nil }