You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

command_trie.go 11 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426
  1. package cmdtrie
  2. import (
  3. "fmt"
  4. "reflect"
  5. "strconv"
  6. myreflect "gitlink.org.cn/cloudream/common/utils/reflect"
  7. )
  // ExecuteOption tunes how a matched command is invoked.
  type ExecuteOption struct {
  	// ReplaceEmptyArrayWithNil: when the last parameter is an array or slice and
  	// it ends up with no elements, pass its zero value (nil for a slice) instead.
  	ReplaceEmptyArrayWithNil bool
  }

  // command is a registered handler plus the reflection data needed to call it.
  type command struct {
  	fn             reflect.Value
  	fnType         reflect.Type
  	staticArgTypes []reflect.Type // parameters filled from command words (the ctx parameter, if any, is excluded)
  	lastIsArray    bool           // last static parameter is an array/slice that collects all trailing words
  }

  // trieNode is one word position in the command trie. cmd is non-nil only if a
  // handler was registered for exactly the word path that leads to this node.
  type trieNode struct {
  	nexts map[string]*trieNode
  	cmd   *command
  }

  // anyCommandTrie is the untyped core behind the generic trie wrappers. Add
  // validates handler signatures with reflection; Execute finds the handler for a
  // word sequence, parses the remaining words into the handler's parameter types
  // and calls it.
  type anyCommandTrie struct {
  	root    trieNode
  	ctxType reflect.Type // type of the leading ctx parameter; nil means handlers take no ctx
  	retType reflect.Type // required return type; nil means the return type is not checked
  }

  // newAnyCommandTrie creates an empty trie. A nil ctxType means handlers take no
  // ctx parameter; a nil retType disables return-type checking.
  func newAnyCommandTrie(ctxType reflect.Type, retType reflect.Type) anyCommandTrie {
  	return anyCommandTrie{
  		root: trieNode{
  			nexts: make(map[string]*trieNode),
  		},
  		ctxType: ctxType,
  		retType: retType,
  	}
  }

  // Add registers fn under the word path prefixWords, replacing any handler
  // already registered at that path. With no prefixWords, fn is registered at
  // the root and runs when no registered word path matches the input.
  //
  // fn must be a non-nil function. If the trie has a ctx type, fn's first
  // parameter must be exactly that type. The remaining parameters are parsed
  // from command words at execution time; only the last of them may be an array
  // or slice (including a variadic parameter), which collects all trailing words.
  func (t *anyCommandTrie) Add(fn any, prefixWords ...string) error {
  	fnVal := reflect.ValueOf(fn)
  	if !fnVal.IsValid() {
  		// reflect.TypeOf(nil) is nil; calling Kind on it would panic.
  		return fmt.Errorf("fn must be a function, but get nil")
  	}
  	fnType := fnVal.Type()
  	if fnType.Kind() != reflect.Func {
  		return fmt.Errorf("fn must be a function, but get a %v", fnType.Kind())
  	}
  	if fnVal.IsNil() {
  		// A typed nil func would only fail later inside reflect.Value.Call.
  		return fmt.Errorf("fn must be a non-nil function")
  	}
  	if err := t.checkFnReturn(fnType); err != nil {
  		return err
  	}
  	if err := t.checkFnArgs(fnType); err != nil {
  		return err
  	}

  	staticArgTypes := make([]reflect.Type, 0, fnType.NumIn())
  	for i := t.firstStaticArg(); i < fnType.NumIn(); i++ {
  		staticArgTypes = append(staticArgTypes, fnType.In(i))
  	}
  	lastIsArray := false
  	if n := len(staticArgTypes); n > 0 {
  		kind := staticArgTypes[n-1].Kind()
  		lastIsArray = kind == reflect.Array || kind == reflect.Slice
  	}

  	node := &t.root
  	for _, word := range prefixWords {
  		next, ok := node.nexts[word]
  		if !ok {
  			next = &trieNode{nexts: make(map[string]*trieNode)}
  			node.nexts[word] = next
  		}
  		node = next
  	}
  	node.cmd = &command{
  		fn:             fnVal,
  		fnType:         fnType,
  		staticArgTypes: staticArgTypes,
  		lastIsArray:    lastIsArray,
  	}
  	return nil
  }

  // firstStaticArg returns the index of the first handler parameter that is
  // filled from command words: 1 when a ctx parameter comes first, otherwise 0.
  func (t *anyCommandTrie) firstStaticArg() int {
  	if t.ctxType != nil {
  		return 1
  	}
  	return 0
  }

  // checkFnReturn verifies that typ has exactly one return value of the trie's
  // return type. When that type is an interface, any implementing type is
  // accepted. Nothing is checked when the trie has no return type.
  func (t *anyCommandTrie) checkFnReturn(typ reflect.Type) error {
  	if t.retType == nil {
  		return nil
  	}
  	// %v prints unnamed types (pointers, slices) too; Type.Name() would be "".
  	if typ.NumOut() != 1 {
  		return fmt.Errorf("fn must have one return value with type %v", t.retType)
  	}
  	outType := typ.Out(0)
  	if t.retType.Kind() == reflect.Interface {
  		if !outType.Implements(t.retType) {
  			return fmt.Errorf("fn must have one return value with type %v", t.retType)
  		}
  		return nil
  	}
  	if outType != t.retType {
  		return fmt.Errorf("fn must have one return value with type %v", t.retType)
  	}
  	return nil
  }

  // checkFnArgs verifies the ctx parameter (when the trie has a ctx type) and
  // that no array or slice parameter appears before the last one. An array or
  // slice parameter swallows all remaining words, and parseValue cannot parse
  // one from a single word.
  func (t *anyCommandTrie) checkFnArgs(typ reflect.Type) error {
  	numIn := typ.NumIn()
  	if t.ctxType != nil {
  		if numIn < 1 {
  			return fmt.Errorf("fn must have a ctx argument")
  		}
  		if typ.In(0) != t.ctxType {
  			return fmt.Errorf("first argument of fn must be %v", t.ctxType)
  		}
  	}
  	// Start after the ctx parameter: its own type is not subject to this rule.
  	for i := t.firstStaticArg(); i < numIn-1; i++ {
  		if kind := typ.In(i).Kind(); kind == reflect.Array || kind == reflect.Slice {
  			return fmt.Errorf("array argument must at the last one")
  		}
  	}
  	return nil
  }

  // Execute finds the handler for cmdWords (longest registered word-path
  // prefix), parses the words after that prefix into the handler's parameters
  // and calls it, returning the handler's return values.
  //
  // ctx is passed as the first argument when the trie has a ctx type. A nil ctx
  // is passed as the zero value of that type. Surplus words beyond a non-array
  // last parameter are ignored.
  func (t *anyCommandTrie) Execute(ctx any, cmdWords []string, opt ExecuteOption) ([]reflect.Value, error) {
  	cmd, argWords, err := t.findCommand(cmdWords, nil)
  	if err != nil {
  		return nil, err
  	}

  	// A trailing array/slice parameter may receive zero words.
  	minArgs := len(cmd.staticArgTypes)
  	if cmd.lastIsArray {
  		minArgs--
  	}
  	if len(argWords) < minArgs {
  		return nil, fmt.Errorf("no enough arguments for command")
  	}

  	callArgs := make([]reflect.Value, 0, cmd.fnType.NumIn())
  	if t.ctxType != nil {
  		ctxVal := reflect.ValueOf(ctx)
  		if !ctxVal.IsValid() {
  			// A nil interface yields the zero reflect.Value, which Call rejects.
  			ctxVal = reflect.Zero(t.ctxType)
  		}
  		callArgs = append(callArgs, ctxVal)
  	}

  	// Only the last parameter can be an array, so the others are parsed first.
  	callArgs, err = t.parseFrontArgs(cmd, argWords, callArgs)
  	if err != nil {
  		return nil, err
  	}
  	callArgs, err = t.parseLastArg(cmd, argWords, opt, callArgs)
  	if err != nil {
  		return nil, err
  	}

  	// Call would treat the final slice as a single variadic element; CallSlice
  	// passes it as the whole variadic parameter.
  	if cmd.fnType.IsVariadic() {
  		return cmd.fn.CallSlice(callArgs), nil
  	}
  	return cmd.fn.Call(callArgs), nil
  }

  // findCommand walks cmdWords down the trie and returns the handler of the
  // deepest node on the path that has one, together with the words after that
  // node. A handler registered at the root matches when nothing deeper does.
  // The second parameter is unused and kept only for signature compatibility.
  func (t *anyCommandTrie) findCommand(cmdWords []string, _ []string) (*command, []string, error) {
  	cmd := t.root.cmd
  	argWords := cmdWords
  	node := &t.root
  	for i, word := range cmdWords {
  		next, ok := node.nexts[word]
  		if !ok || next == nil {
  			break
  		}
  		// Intermediate nodes without a handler must not discard a shallower match.
  		if next.cmd != nil {
  			cmd = next.cmd
  			argWords = cmdWords[i+1:]
  		}
  		node = next
  	}
  	if cmd == nil {
  		return nil, nil, fmt.Errorf("command not found")
  	}
  	return cmd, argWords, nil
  }

  // parseFrontArgs parses every static parameter except the last from the
  // matching word and appends it to callArgs. Execute has already checked that
  // argWords holds at least len(staticArgTypes)-1 words.
  func (t *anyCommandTrie) parseFrontArgs(cmd *command, argWords []string, callArgs []reflect.Value) ([]reflect.Value, error) {
  	// Reported positions count the ctx parameter when there is one.
  	offset := t.firstStaticArg()
  	for i := 0; i < len(cmd.staticArgTypes)-1; i++ {
  		val, err := t.parseValue(argWords[i], cmd.staticArgTypes[i])
  		if err != nil {
  			return nil, fmt.Errorf("cannot parse function argument at %d, err: %w", i+offset, err)
  		}
  		callArgs = append(callArgs, val)
  	}
  	return callArgs, nil
  }

  // parseLastArg parses the last static parameter and appends it to callArgs.
  // An array or slice parameter takes every remaining word as one element each.
  // A fixed-size array rejects more words than its length, and unfilled slots
  // keep their zero value. Any other type is parsed from the next word alone.
  func (t *anyCommandTrie) parseLastArg(cmd *command, argWords []string, opt ExecuteOption, callArgs []reflect.Value) ([]reflect.Value, error) {
  	n := len(cmd.staticArgTypes)
  	if n == 0 {
  		return callArgs, nil
  	}
  	lastType := cmd.staticArgTypes[n-1]
  	lastWords := argWords[n-1:]

  	var lastArg reflect.Value
  	switch lastType.Kind() {
  	case reflect.Array, reflect.Slice:
  		if lastType.Kind() == reflect.Array {
  			if len(lastWords) > lastType.Len() {
  				return nil, fmt.Errorf("too many arguments for command, the last argument accepts at most %d elements", lastType.Len())
  			}
  			// reflect.New returns a pointer; Elem gives a settable array value.
  			lastArg = reflect.New(lastType).Elem()
  		} else {
  			lastArg = reflect.MakeSlice(lastType, len(lastWords), len(lastWords))
  		}
  		for i, word := range lastWords {
  			elem, err := t.parseValue(word, lastType.Elem())
  			if err != nil {
  				return nil, fmt.Errorf("cannot parse as array element, err: %w", err)
  			}
  			lastArg.Index(i).Set(elem)
  		}
  		if opt.ReplaceEmptyArrayWithNil && lastArg.Len() == 0 {
  			lastArg = reflect.Zero(lastType)
  		}
  	default:
  		if len(lastWords) == 0 {
  			return nil, fmt.Errorf("no enough arguments for command")
  		}
  		var err error
  		lastArg, err = t.parseValue(lastWords[0], lastType)
  		if err != nil {
  			return nil, fmt.Errorf("cannot parse function argument at %d, err: %w", cmd.fnType.NumIn()-1, err)
  		}
  	}
  	return append(callArgs, lastArg), nil
  }

  // parseValue converts word into a value of valueType. It supports string,
  // bool, signed/unsigned integer (base prefixes such as 0x accepted) and float
  // kinds, including named types with those kinds. Numbers are range-checked
  // against the target size, so "300" is rejected for int8 instead of wrapping.
  func (t *anyCommandTrie) parseValue(word string, valueType reflect.Type) (reflect.Value, error) {
  	kind := valueType.Kind()
  	switch {
  	case kind == reflect.String:
  		// Convert so that named string types are accepted by Call/Set.
  		return reflect.ValueOf(word).Convert(valueType), nil
  	case reflect.Int <= kind && kind <= reflect.Int64:
  		i, err := strconv.ParseInt(word, 0, valueType.Bits())
  		if err != nil {
  			return reflect.Value{}, err
  		}
  		return reflect.ValueOf(i).Convert(valueType), nil
  	case reflect.Uint <= kind && kind <= reflect.Uint64:
  		u, err := strconv.ParseUint(word, 0, valueType.Bits())
  		if err != nil {
  			return reflect.Value{}, err
  		}
  		return reflect.ValueOf(u).Convert(valueType), nil
  	case kind == reflect.Float32 || kind == reflect.Float64:
  		f, err := strconv.ParseFloat(word, valueType.Bits())
  		if err != nil {
  			return reflect.Value{}, err
  		}
  		return reflect.ValueOf(f).Convert(valueType), nil
  	case kind == reflect.Bool:
  		b, err := strconv.ParseBool(word)
  		if err != nil {
  			return reflect.Value{}, err
  		}
  		return reflect.ValueOf(b).Convert(valueType), nil
  	}
  	return reflect.Value{}, fmt.Errorf("cannot parse string as %v", valueType)
  }
  261. type CommandTrie[TCtx any, TRet any] struct {
  262. anyTrie anyCommandTrie
  263. }
  264. func NewCommandTrie[TCtx any, TRet any]() CommandTrie[TCtx, TRet] {
  265. return CommandTrie[TCtx, TRet]{
  266. anyTrie: newAnyCommandTrie(myreflect.TypeOf[TCtx](), myreflect.TypeOf[TRet]()),
  267. }
  268. }
  269. func (t *CommandTrie[TCtx, TRet]) Add(fn any, prefixWords ...string) error {
  270. return t.anyTrie.Add(fn, prefixWords...)
  271. }
  272. func (t *CommandTrie[TCtx, TRet]) MustAdd(fn any, prefixWords ...string) {
  273. err := t.anyTrie.Add(fn, prefixWords...)
  274. if err != nil {
  275. panic(err.Error())
  276. }
  277. }
  278. func (t *CommandTrie[TCtx, TRet]) Execute(ctx TCtx, cmdWords []string, opts ...ExecuteOption) (TRet, error) {
  279. opt := ExecuteOption{}
  280. if len(opts) > 0 {
  281. opt = opts[0]
  282. }
  283. retValues, err := t.anyTrie.Execute(ctx, cmdWords, opt)
  284. if err != nil {
  285. var defRet TRet
  286. return defRet, err
  287. }
  288. if retValues[0].Kind() == reflect.Interface && retValues[0].IsNil() {
  289. var ret TRet
  290. return ret, nil
  291. }
  292. return retValues[0].Interface().(TRet), nil
  293. }
  294. type VoidCommandTrie[TCtx any] struct {
  295. anyTrie anyCommandTrie
  296. }
  297. func NewVoidCommandTrie[TCtx any]() VoidCommandTrie[TCtx] {
  298. return VoidCommandTrie[TCtx]{
  299. anyTrie: newAnyCommandTrie(myreflect.TypeOf[TCtx](), nil),
  300. }
  301. }
  302. func (t *VoidCommandTrie[TCtx]) Add(fn any, prefixWords ...string) error {
  303. return t.anyTrie.Add(fn, prefixWords...)
  304. }
  305. func (t *VoidCommandTrie[TCtx]) MustAdd(fn any, prefixWords ...string) {
  306. err := t.anyTrie.Add(fn, prefixWords...)
  307. if err != nil {
  308. panic(err.Error())
  309. }
  310. }
  311. func (t *VoidCommandTrie[TCtx]) Execute(ctx TCtx, cmdWords []string, opts ...ExecuteOption) error {
  312. opt := ExecuteOption{}
  313. if len(opts) > 0 {
  314. opt = opts[0]
  315. }
  316. _, err := t.anyTrie.Execute(ctx, cmdWords, opt)
  317. return err
  318. }
  319. type StaticCommandTrie[TRet any] struct {
  320. anyTrie anyCommandTrie
  321. }
  322. func NewStaticCommandTrie[TRet any]() StaticCommandTrie[TRet] {
  323. return StaticCommandTrie[TRet]{
  324. anyTrie: newAnyCommandTrie(nil, myreflect.TypeOf[TRet]()),
  325. }
  326. }
  327. func (t *StaticCommandTrie[TRet]) Add(fn any, prefixWords ...string) error {
  328. return t.anyTrie.Add(fn, prefixWords...)
  329. }
  330. func (t *StaticCommandTrie[TRet]) MustAdd(fn any, prefixWords ...string) {
  331. err := t.anyTrie.Add(fn, prefixWords...)
  332. if err != nil {
  333. panic(err.Error())
  334. }
  335. }
  336. func (t *StaticCommandTrie[TRet]) Execute(cmdWords []string, opts ...ExecuteOption) (TRet, error) {
  337. opt := ExecuteOption{}
  338. if len(opts) > 0 {
  339. opt = opts[0]
  340. }
  341. retValues, err := t.anyTrie.Execute(nil, cmdWords, opt)
  342. if err != nil {
  343. var defRet TRet
  344. return defRet, err
  345. }
  346. if retValues[0].Kind() == reflect.Interface && retValues[0].IsNil() {
  347. var ret TRet
  348. return ret, nil
  349. }
  350. return retValues[0].Interface().(TRet), nil
  351. }

公共库

Contributors (1)