You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

flag_parser.h 8.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef PREDICT_COMMON_FLAG_PARSER_H_
  17. #define PREDICT_COMMON_FLAG_PARSER_H_
  #include <functional>
  #include <map>
  #include <string>
  #include <type_traits>
  #include <utility>
  #include "common/utils.h"
  #include "common/option.h"
  24. namespace mindspore {
  25. namespace predict {
  26. struct FlagInfo;
  27. struct Nothing {};
  28. class FlagParser {
  29. public:
  30. FlagParser() { AddFlag(&FlagParser::help, "help", "print usage message", false); }
  31. virtual ~FlagParser() = default;
  32. // only support read flags from command line
  33. virtual Option<std::string> ParseFlags(int argc, const char *const *argv, bool supportUnknown = false,
  34. bool supportDuplicate = false);
  35. std::string Usage(const Option<std::string> &usgMsg = Option<std::string>(None())) const;
  36. template <typename Flags, typename T1, typename T2>
  37. void AddFlag(T1 *t1, const std::string &flagName, const std::string &helpInfo, const T2 &t2);
  38. template <typename Flags, typename T1, typename T2>
  39. void AddFlag(T1 Flags::*t1, const std::string &flagName, const std::string &helpInfo, const T2 &t2);
  40. template <typename Flags, typename T>
  41. void AddFlag(T Flags::*t, const std::string &flagName, const std::string &helpInfo);
  42. // Option-type fields
  43. template <typename Flags, typename T>
  44. void AddFlag(Option<T> Flags::*t, const std::string &flagName, const std::string &helpInfo);
  45. bool help;
  46. protected:
  47. std::string binName;
  48. Option<std::string> usageMsg;
  49. private:
  50. struct FlagInfo {
  51. std::string flagName;
  52. bool isRequired;
  53. bool isBoolean;
  54. std::string helpInfo;
  55. bool isParsed;
  56. std::function<Option<Nothing>(FlagParser *, const std::string &)> parse;
  57. };
  58. inline void AddFlag(const FlagInfo &flag);
  59. // construct a temporary flag
  60. template <typename Flags, typename T>
  61. void ConstructFlag(Option<T> Flags::*t, const std::string &flagName, const std::string &helpInfo, FlagInfo *flag);
  62. // construct a temporary flag
  63. template <typename Flags, typename T1>
  64. void ConstructFlag(T1 Flags::*t1, const std::string &flagName, const std::string &helpInfo, FlagInfo *flag);
  65. Option<std::string> InnerParseFlags(std::multimap<std::string, Option<std::string>> *values);
  66. bool GetRealFlagName(const std::string &oriFlagName, std::string *flagName);
  67. std::map<std::string, FlagInfo> flags;
  68. };
  69. // convert to std::string
  70. template <typename Flags, typename T>
  71. Option<std::string> ConvertToString(T Flags::*t, const FlagParser &baseFlag) {
  72. const Flags *flag = dynamic_cast<Flags *>(&baseFlag);
  73. if (flag != nullptr) {
  74. return std::to_string(flag->*t);
  75. }
  76. return Option<std::string>(None());
  77. }
  78. // construct for a Option-type flag
  79. template <typename Flags, typename T>
  80. void FlagParser::ConstructFlag(Option<T> Flags::*t1, const std::string &flagName, const std::string &helpInfo,
  81. FlagInfo *flag) {
  82. if (flag == nullptr) {
  83. MS_LOGE("FlagInfo is nullptr");
  84. return;
  85. }
  86. flag->flagName = flagName;
  87. flag->helpInfo = helpInfo;
  88. flag->isBoolean = typeid(T) == typeid(bool);
  89. flag->isParsed = false;
  90. }
  91. // construct a temporary flag
  92. template <typename Flags, typename T>
  93. void FlagParser::ConstructFlag(T Flags::*t1, const std::string &flagName, const std::string &helpInfo, FlagInfo *flag) {
  94. if (flag == nullptr) {
  95. MS_LOGE("FlagInfo is nullptr");
  96. return;
  97. }
  98. if (t1 == nullptr) {
  99. MS_LOGE("t1 is nullptr");
  100. return;
  101. }
  102. flag->flagName = flagName;
  103. flag->helpInfo = helpInfo;
  104. flag->isBoolean = typeid(T) == typeid(bool);
  105. flag->isParsed = false;
  106. }
  107. inline void FlagParser::AddFlag(const FlagInfo &flagItem) { flags[flagItem.flagName] = flagItem; }
  108. template <typename Flags, typename T>
  109. void FlagParser::AddFlag(T Flags::*t, const std::string &flagName, const std::string &helpInfo) {
  110. if (t == nullptr) {
  111. MS_LOGE("t1 is nullptr");
  112. return;
  113. }
  114. Flags *flag = dynamic_cast<Flags *>(this);
  115. if (flag == nullptr) {
  116. MS_LOGI("dynamic_cast failed");
  117. return;
  118. }
  119. FlagInfo flagItem;
  120. // flagItem is as a output parameter
  121. ConstructFlag(t, flagName, helpInfo, &flagItem);
  122. flagItem.parse = [t](FlagParser *base, const std::string &value) -> Option<Nothing> {
  123. Flags *flag = dynamic_cast<Flags *>(base);
  124. if (base != nullptr) {
  125. Option<T> ret = Option<T>(GenericParseValue<T>(value));
  126. if (ret.IsNone()) {
  127. return Option<Nothing>(None());
  128. } else {
  129. flag->*t = ret.Get();
  130. }
  131. }
  132. return Option<Nothing>(Nothing());
  133. };
  134. flagItem.isRequired = true;
  135. flagItem.helpInfo +=
  136. !helpInfo.empty() && helpInfo.find_last_of("\n\r") != helpInfo.size() - 1 ? " (default: " : "(default: ";
  137. flagItem.helpInfo += ")";
  138. // add this flag to a std::map
  139. AddFlag(flagItem);
  140. }
  141. template <typename Flags, typename T1, typename T2>
  142. void FlagParser::AddFlag(T1 *t1, const std::string &flagName, const std::string &helpInfo, const T2 &t2) {
  143. if (t1 == nullptr) {
  144. MS_LOGE("t1 is nullptr");
  145. return;
  146. }
  147. FlagInfo flagItem;
  148. // flagItem is as a output parameter
  149. ConstructFlag(t1, flagName, helpInfo, flagItem);
  150. flagItem.parse = [t1](FlagParser *base, const std::string &value) -> Option<Nothing> {
  151. if (base != nullptr) {
  152. Option<T1> ret = Option<T1>(GenericParseValue<T1>(value));
  153. if (ret.IsNone()) {
  154. return Option<T1>(None());
  155. } else {
  156. *t1 = ret.Get();
  157. }
  158. }
  159. return Option<Nothing>(Nothing());
  160. };
  161. flagItem.isRequired = false;
  162. *t1 = t2;
  163. flagItem.helpInfo +=
  164. !helpInfo.empty() && helpInfo.find_last_of("\n\r") != helpInfo.size() - 1 ? " (default: " : "(default: ";
  165. flagItem.helpInfo += ToString(t2).Get();
  166. flagItem.helpInfo += ")";
  167. // add this flag to a std::map
  168. AddFlag(flagItem);
  169. }
  170. template <typename Flags, typename T1, typename T2>
  171. void FlagParser::AddFlag(T1 Flags::*t1, const std::string &flagName, const std::string &helpInfo, const T2 &t2) {
  172. if (t1 == nullptr) {
  173. MS_LOGE("t1 is nullptr");
  174. return;
  175. }
  176. Flags *flag = dynamic_cast<Flags *>(this);
  177. if (flag == nullptr) {
  178. MS_LOGI("dynamic_cast failed");
  179. return;
  180. }
  181. FlagInfo flagItem;
  182. // flagItem is as a output parameter
  183. ConstructFlag(t1, flagName, helpInfo, &flagItem);
  184. flagItem.parse = [t1](FlagParser *base, const std::string &value) -> Option<Nothing> {
  185. Flags *flag = dynamic_cast<Flags *>(base);
  186. if (base != nullptr) {
  187. Option<T1> ret = Option<T1>(GenericParseValue<T1>(value));
  188. if (ret.IsNone()) {
  189. return Option<Nothing>(None());
  190. } else {
  191. flag->*t1 = ret.Get();
  192. }
  193. }
  194. return Option<Nothing>(Nothing());
  195. };
  196. flagItem.isRequired = false;
  197. flag->*t1 = t2;
  198. flagItem.helpInfo +=
  199. !helpInfo.empty() && helpInfo.find_last_of("\n\r") != helpInfo.size() - 1 ? " (default: " : "(default: ";
  200. flagItem.helpInfo += ToString(t2).Get();
  201. flagItem.helpInfo += ")";
  202. // add this flag to a std::map
  203. AddFlag(flagItem);
  204. }
  205. // option-type add flag
  206. template <typename Flags, typename T>
  207. void FlagParser::AddFlag(Option<T> Flags::*t, const std::string &flagName, const std::string &helpInfo) {
  208. if (t == nullptr) {
  209. MS_LOGE("t is nullptr");
  210. return;
  211. }
  212. Flags *flag = dynamic_cast<Flags *>(this);
  213. if (flag == nullptr) {
  214. MS_LOGE("dynamic_cast failed");
  215. return;
  216. }
  217. FlagInfo flagItem;
  218. // flagItem is as a output parameter
  219. ConstructFlag(t, flagName, helpInfo, &flagItem);
  220. flagItem.isRequired = false;
  221. flagItem.parse = [t](FlagParser *base, const std::string &value) -> Option<Nothing> {
  222. Flags *flag = dynamic_cast<Flags *>(base);
  223. if (base != nullptr) {
  224. Option<T> ret = Option<std::string>(GenericParseValue<T>(value));
  225. if (ret.IsNone()) {
  226. return Option<Nothing>(None());
  227. } else {
  228. flag->*t = Option<T>(Some(ret.Get()));
  229. }
  230. }
  231. return Option<Nothing>(Nothing());
  232. };
  233. // add this flag to a std::map
  234. AddFlag(flagItem);
  235. }
  236. } // namespace predict
  237. } // namespace mindspore
  238. #endif // PREDICT_COMMON_FLAG_PARSER_H_