You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

command_trie.go 9.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372
  1. package cmdtrie
  2. import (
  3. "fmt"
  4. "reflect"
  5. "strconv"
  6. myreflect "gitlink.org.cn/cloudream/common/utils/reflect"
  7. )
  // command is a registered handler together with the reflection metadata
// Execute needs to call it with arguments parsed from command words.
type command struct {
	// fn is the handler function value that gets called.
	fn reflect.Value
	// fnType is the reflected function type of fn.
	fnType reflect.Type
	// staticArgTypes holds the handler's parameter types that are filled
	// from command words: all parameters except the leading ctx parameter
	// when the trie has a ctx type.
	staticArgTypes []reflect.Type
	// lastIsArray reports whether the final entry of staticArgTypes is an
	// array or slice kind.
	lastIsArray bool
}
  14. type trieNode struct {
  15. nexts map[string]*trieNode
  16. cmd *command
  17. }
  18. type anyCommandTrie struct {
  19. root trieNode
  20. ctxType reflect.Type
  21. retType reflect.Type
  22. }
  23. func newAnyCommandTrie(ctxType reflect.Type, retType reflect.Type) anyCommandTrie {
  24. return anyCommandTrie{
  25. root: trieNode{
  26. nexts: make(map[string]*trieNode),
  27. },
  28. ctxType: ctxType,
  29. retType: retType,
  30. }
  31. }
  32. func (t *anyCommandTrie) Add(fn any, prefixWords ...string) error {
  33. typ := reflect.TypeOf(fn)
  34. if typ.Kind() != reflect.Func {
  35. return fmt.Errorf("fn must be a function, but get a %v", typ.Kind())
  36. }
  37. err := t.checkFnReturn(typ)
  38. if err != nil {
  39. return err
  40. }
  41. err = t.checkFnArgs(typ)
  42. if err != nil {
  43. return err
  44. }
  45. ptr := &t.root
  46. for _, word := range prefixWords {
  47. next, ok := ptr.nexts[word]
  48. if !ok {
  49. next = &trieNode{
  50. nexts: make(map[string]*trieNode),
  51. }
  52. ptr.nexts[word] = next
  53. }
  54. ptr = next
  55. }
  56. fnType := reflect.TypeOf(fn)
  57. var staticArgTypes []reflect.Type
  58. if t.ctxType != nil {
  59. for i := 1; i < fnType.NumIn(); i++ {
  60. staticArgTypes = append(staticArgTypes, fnType.In(i))
  61. }
  62. } else {
  63. for i := 0; i < fnType.NumIn(); i++ {
  64. staticArgTypes = append(staticArgTypes, fnType.In(i))
  65. }
  66. }
  67. var lastIsArray = false
  68. if len(staticArgTypes) > 0 {
  69. kind := staticArgTypes[len(staticArgTypes)-1].Kind()
  70. lastIsArray = kind == reflect.Array || kind == reflect.Slice
  71. }
  72. ptr.cmd = &command{
  73. fn: reflect.ValueOf(fn),
  74. fnType: reflect.TypeOf(fn),
  75. staticArgTypes: staticArgTypes,
  76. lastIsArray: lastIsArray,
  77. }
  78. return nil
  79. }
  80. func (t *anyCommandTrie) checkFnReturn(typ reflect.Type) error {
  81. if t.retType != nil {
  82. if typ.NumOut() != 1 {
  83. return fmt.Errorf("fn must have one return value with type %s", t.retType.Name())
  84. }
  85. fnRetType := typ.Out(0)
  86. if t.retType.Kind() == reflect.Interface {
  87. // 如果TRet是接口类型,那么fn的返回值只要实现了此接口,就也可以接受
  88. if !fnRetType.Implements(t.retType) {
  89. return fmt.Errorf("fn must have one return value with type %s", t.retType.Name())
  90. }
  91. } else if fnRetType != t.retType {
  92. return fmt.Errorf("fn must have one return value with type %s", t.retType.Name())
  93. }
  94. }
  95. return nil
  96. }
  97. func (t *anyCommandTrie) checkFnArgs(typ reflect.Type) error {
  98. if t.ctxType != nil {
  99. if typ.NumIn() < 1 {
  100. return fmt.Errorf("fn must have a ctx argument")
  101. }
  102. for i := 0; i < typ.NumIn(); i++ {
  103. argType := typ.In(i)
  104. if i == 0 && argType != t.ctxType {
  105. return fmt.Errorf("first argument of fn must be %s", t.ctxType.Name())
  106. }
  107. if argType.Kind() == reflect.Array && i < typ.NumIn()-1 {
  108. return fmt.Errorf("array argument must at the last one")
  109. }
  110. }
  111. } else {
  112. for i := 0; i < typ.NumIn(); i++ {
  113. argType := typ.In(i)
  114. if argType.Kind() == reflect.Array && i < typ.NumIn()-1 {
  115. return fmt.Errorf("array argument must at the last one")
  116. }
  117. }
  118. }
  119. return nil
  120. }
  121. func (t *anyCommandTrie) Execute(ctx any, cmdWords ...string) ([]reflect.Value, error) {
  122. var cmd *command
  123. var argWords []string
  124. cmd, argWords, err := t.findCommand(cmdWords, argWords)
  125. if err != nil {
  126. return nil, err
  127. }
  128. if cmd.lastIsArray {
  129. // 最后一个参数如果是数组,那么可以少一个参数
  130. if len(argWords) < len(cmd.staticArgTypes)-1 {
  131. return nil, fmt.Errorf("no enough arguments for command")
  132. }
  133. } else if len(argWords) < len(cmd.staticArgTypes) {
  134. return nil, fmt.Errorf("no enough arguments for command")
  135. }
  136. var callArgs []reflect.Value
  137. // 如果有Ctx参数,则加上Ctx参数
  138. if t.ctxType != nil {
  139. callArgs = append(callArgs, reflect.ValueOf(ctx))
  140. }
  141. // 数组参数只能是最后一个,所以先处理最后一个参数前的参数
  142. callArgs, err = t.parseFrontArgs(cmd, argWords, callArgs)
  143. if err != nil {
  144. return nil, err
  145. }
  146. // 解析最后一个参数
  147. callArgs, err = t.parseLastArg(cmd, argWords, callArgs)
  148. if err != nil {
  149. return nil, err
  150. }
  151. return cmd.fn.Call(callArgs), nil
  152. }
  153. func (t *anyCommandTrie) findCommand(cmdWords []string, argWords []string) (*command, []string, error) {
  154. var cmd *command
  155. ptr := &t.root
  156. for i := 0; i < len(cmdWords); i++ {
  157. next, ok := ptr.nexts[cmdWords[i]]
  158. if !ok {
  159. break
  160. }
  161. if next != nil {
  162. cmd = next.cmd
  163. argWords = cmdWords[i+1:]
  164. }
  165. ptr = next
  166. }
  167. if cmd == nil {
  168. return nil, nil, fmt.Errorf("command not found")
  169. }
  170. return cmd, argWords, nil
  171. }
  172. func (t *anyCommandTrie) parseFrontArgs(cmd *command, argWords []string, callArgs []reflect.Value) ([]reflect.Value, error) {
  173. for i := 0; i < len(cmd.staticArgTypes)-1; i++ {
  174. val, err := t.parseValue(argWords[i], cmd.staticArgTypes[i])
  175. if err != nil {
  176. // 如果有Ctx参数,则参数的位置要往后一个
  177. argIndex := i
  178. if t.ctxType != nil {
  179. argIndex++
  180. }
  181. return nil, fmt.Errorf("cannot parse function argument at %d, err: %s", argIndex, err.Error())
  182. }
  183. callArgs = append(callArgs, val)
  184. }
  185. return callArgs, nil
  186. }
  187. func (t *anyCommandTrie) parseLastArg(cmd *command, argWords []string, callArgs []reflect.Value) ([]reflect.Value, error) {
  188. if len(cmd.staticArgTypes) > 0 {
  189. lastArgType := cmd.staticArgTypes[len(cmd.staticArgTypes)-1]
  190. lastArgWords := argWords[len(cmd.staticArgTypes)-1:]
  191. lastArgTypeKind := lastArgType.Kind()
  192. var lastArg reflect.Value
  193. if lastArgTypeKind == reflect.Array || lastArgTypeKind == reflect.Slice {
  194. if lastArgType.Kind() == reflect.Array {
  195. lastArg = reflect.New(lastArgType)
  196. } else if lastArgType.Kind() == reflect.Slice {
  197. lastArg = reflect.MakeSlice(lastArgType, len(lastArgWords), len(lastArgWords))
  198. }
  199. for i := 0; i < len(lastArgWords); i++ {
  200. eleVal, err := t.parseValue(lastArgWords[i], lastArgType.Elem())
  201. if err != nil {
  202. return nil, fmt.Errorf("cannot parse as array element, err: %s", err.Error())
  203. }
  204. lastArg.Index(i).Set(eleVal)
  205. }
  206. } else {
  207. if len(lastArgWords) == 0 {
  208. return nil, fmt.Errorf("no enough arguments for command")
  209. }
  210. var err error
  211. lastArg, err = t.parseValue(lastArgWords[0], lastArgType)
  212. if err != nil {
  213. return nil, fmt.Errorf("cannot parse function argument at %d, err: %s", cmd.fnType.NumIn()-1, err.Error())
  214. }
  215. }
  216. callArgs = append(callArgs, lastArg)
  217. }
  218. return callArgs, nil
  219. }
  220. func (t *anyCommandTrie) parseValue(word string, valueType reflect.Type) (reflect.Value, error) {
  221. valTypeKind := valueType.Kind()
  222. if valTypeKind == reflect.String {
  223. return reflect.ValueOf(word), nil
  224. }
  225. if reflect.Int <= valTypeKind && valTypeKind <= reflect.Int64 {
  226. i, err := strconv.ParseInt(word, 0, 64)
  227. if err != nil {
  228. return reflect.Value{}, err
  229. }
  230. return reflect.ValueOf(i).Convert(valueType), nil
  231. }
  232. if reflect.Uint <= valTypeKind && valTypeKind <= reflect.Uint64 {
  233. i, err := strconv.ParseUint(word, 0, 64)
  234. if err != nil {
  235. return reflect.Value{}, err
  236. }
  237. return reflect.ValueOf(i).Convert(valueType), nil
  238. }
  239. if reflect.Float32 <= valTypeKind && valTypeKind <= reflect.Float64 {
  240. i, err := strconv.ParseFloat(word, 64)
  241. if err != nil {
  242. return reflect.Value{}, err
  243. }
  244. return reflect.ValueOf(i).Convert(valueType), nil
  245. }
  246. if valTypeKind == reflect.Bool {
  247. b, err := strconv.ParseBool(word)
  248. if err != nil {
  249. return reflect.Value{}, err
  250. }
  251. return reflect.ValueOf(b), nil
  252. }
  253. return reflect.Value{}, fmt.Errorf("cannot parse string as %s", valueType.Name())
  254. }
  255. type CommandTrie[TCtx any, TRet any] struct {
  256. anyTrie anyCommandTrie
  257. }
  258. func NewCommandTrie[TCtx any, TRet any]() CommandTrie[TCtx, TRet] {
  259. return CommandTrie[TCtx, TRet]{
  260. anyTrie: newAnyCommandTrie(myreflect.TypeOf[TCtx](), myreflect.TypeOf[TRet]()),
  261. }
  262. }
  263. func (t *CommandTrie[TCtx, TRet]) Add(fn any, prefixWords ...string) error {
  264. return t.anyTrie.Add(fn, prefixWords...)
  265. }
  266. func (t *CommandTrie[TCtx, TRet]) Execute(ctx TCtx, cmdWords ...string) (TRet, error) {
  267. retValues, err := t.anyTrie.Execute(ctx, cmdWords...)
  268. if err != nil {
  269. var defRet TRet
  270. return defRet, err
  271. }
  272. return retValues[0].Interface().(TRet), nil
  273. }
  274. type VoidCommandTrie[TCtx any] struct {
  275. anyTrie anyCommandTrie
  276. }
  277. func NewVoidCommandTrie[TCtx any]() VoidCommandTrie[TCtx] {
  278. return VoidCommandTrie[TCtx]{
  279. anyTrie: newAnyCommandTrie(myreflect.TypeOf[TCtx](), nil),
  280. }
  281. }
  282. func (t *VoidCommandTrie[TCtx]) Add(fn any, prefixWords ...string) error {
  283. return t.anyTrie.Add(fn, prefixWords...)
  284. }
  285. func (t *VoidCommandTrie[TCtx]) Execute(ctx TCtx, cmdWords ...string) error {
  286. _, err := t.anyTrie.Execute(ctx, cmdWords...)
  287. return err
  288. }
  289. type StaticCommandTrie[TRet any] struct {
  290. anyTrie anyCommandTrie
  291. }
  292. func NewStaticCommandTrie[TRet any]() StaticCommandTrie[TRet] {
  293. return StaticCommandTrie[TRet]{
  294. anyTrie: newAnyCommandTrie(nil, myreflect.TypeOf[TRet]()),
  295. }
  296. }
  297. func (t *StaticCommandTrie[TRet]) Add(fn any, prefixWords ...string) error {
  298. return t.anyTrie.Add(fn, prefixWords...)
  299. }
  300. func (t *StaticCommandTrie[TRet]) Execute(cmdWords ...string) (TRet, error) {
  301. retValues, err := t.anyTrie.Execute(nil, cmdWords...)
  302. if err != nil {
  303. var defRet TRet
  304. return defRet, err
  305. }
  306. return retValues[0].Interface().(TRet), nil
  307. }

公共库

Contributors (1)