You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

test_tensorflow_parser.cc 122 kB

4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
72778277927802781278227832784278527862787278827892790279127922793279427952796279727982799280028012802280328042805280628072808280928102811281228132814281528162817281828192820282128222823282428252826282728282829283028312832283328342835283628372838283928402841284228432844284528462847284828492850285128522853285428552856285728582859286028612862286328642865286628672868286928702871287228732874287528762877287828792880288128822883288428852886288728882889289028912892289328942895289628972898289929002901290229032904290529062907290829092910291129122913291429152916291729182919292029212922292329242925292629272928292929302931293229332934293529362937293829392940294129422943294429452946294729482949295029512952295329542955295629572958295929602961296229632964296529662967296829692970297129722973297429752976297729782979298029812982298329842985298629872988298929902991299229932994299529962997299829993000300130023003300430053006300730083009301030113012301330143015301630173018301930203021302230233024302530263027302830293030303130323033303430353036303730383039304030413042304330443045304630473048304930503051305230533054305530563057305830593060306130623063306430653066306730683069307030713072307330743075307630773078307930803081308230833084308530863087308830893090309130923093309430953096309730983099
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <gtest/gtest.h>
  17. #define protected public
  18. #define private public
  19. #include "parser/common/op_parser_factory.h"
  20. #include "parser/tensorflow/tensorflow_parser.h"
  21. #include "graph/operator_reg.h"
  22. #include "register/op_registry.h"
  23. #include "external/register/register.h"
  24. #include "parser/common/register_tbe.h"
  25. #include "st/parser_st_utils.h"
  26. #include "tests/depends/ops_stub/ops_stub.h"
  27. #include "parser/common/acl_graph_parser_util.h"
  28. #include "metadef/third_party/graphengine/inc/external/ge/ge_api_types.h"
  29. #include "omg/parser/parser_factory.h"
  30. #include "common/pre_checker.h"
  31. #include "common/util.h"
  32. #include "external/parser/tensorflow_parser.h"
  33. #include "parser/tensorflow/tensorflow_constant_parser.h"
  34. #include "common/types.h"
  35. #include "parser/common/op_def/variable_op.h"
  36. #include "parser/tensorflow/tensorflow_ref_switch_parser.h"
  37. #include "parser/tensorflow/tensorflow_fusion_op_parser.h"
  38. #include "parser/tensorflow/tensorflow_auto_mapping_parser_adapter.h"
  39. #include "parser/common/op_def/arg_op.h"
  40. #include "parser/tensorflow/tensorflow_fusion_custom_parser_adapter.h"
  41. #include "parser/tensorflow/tensorflow_reshape_parser.h"
  42. #include "parser/tensorflow/tensorflow_custom_parser_adapter.h"
  43. #include "parser/tensorflow/tensorflow_squeeze_parser.h"
  44. #include "parser/tensorflow/graph_functiondef.h"
  45. #include "parser/tensorflow/graph_optimizer.h"
  46. #include "cce/dnn_base_def.hpp"
  47. #include "parser/tensorflow/scope/scope_pass_manager.h"
  48. #include "parser/tensorflow/tensorflow_util.h"
  49. #include "compute_graph_impl.h"
  50. #include "parser/tensorflow/tensorflow_enter_parser.h"
  51. #undef protected
  52. #undef private
  53. using namespace std;
  54. using namespace domi::tensorflow;
  55. using namespace domi;
  56. using namespace cce;
  57. using namespace testing;
  58. using namespace std;
  59. using namespace google::protobuf;
  // File-local constant holding the string "default" (named as the default graph name).
  static const std::string GRAPH_DEFAULT_NAME{"default"};
  61. namespace ge {
  62. class STestTensorflowParser : public testing::Test {
  63. protected:
  64. void SetUp() {
  65. ParerSTestsUtils::ClearParserInnerCtx();
  66. }
  67. void TearDown() {}
  68. public:
  69. void RegisterCustomOp();
  70. };
  71. class ScopeTestPass : public ScopeBasePass {
  72. protected:
  73. vector<ScopeFusionPatterns> DefinePatterns() {
  74. vector<ScopeFusionPatterns> patterns_list;
  75. return patterns_list;
  76. };
  77. string PassName() {
  78. return "test";
  79. };
  80. Status LastMatchScopesAndOPs(shared_ptr<ScopeGraph> &scope_graph, vector<ScopesResult> &results) {
  81. return domi::SUCCESS;
  82. };
  83. void GenerateFusionResult(const vector<Scope *> &scopes, FusionScopesResult *fusion_rlt) {
  84. return;
  85. };
  86. };
  87. static Status ParseParams(const google::protobuf::Message* op_src, ge::Operator& op_dest) {
  88. return SUCCESS;
  89. }
  90. static Status ParseParamByOpFunc(const ge::Operator &op_src, ge::Operator& op_dest) {
  91. return SUCCESS;
  92. }
  93. void STestTensorflowParser::RegisterCustomOp() {
  94. REGISTER_CUSTOM_OP("Add")
  95. .FrameworkType(domi::TENSORFLOW)
  96. .OriginOpType("Add")
  97. .ParseParamsFn(ParseParams);
  98. std::vector<OpRegistrationData> reg_datas = domi::OpRegistry::Instance()->registrationDatas;
  99. for (auto reg_data : reg_datas) {
  100. OpRegistrationTbe::Instance()->Finalize(reg_data);
  101. domi::OpRegistry::Instance()->Register(reg_data);
  102. }
  103. domi::OpRegistry::Instance()->registrationDatas.clear();
  104. }
  105. namespace {
  106. NodeDef* AddNode(GraphDef& graph, string type, string name) {
  107. NodeDef* nodeDef = graph.add_node();
  108. nodeDef->set_op(type);
  109. nodeDef->set_name(name);
  110. tensorflow::OpDef op_def;
  111. string op_def_string;
  112. op_def.SerializeToString(&op_def_string);
  113. tensorflow::AttrValue value;
  114. value.set_s(op_def_string);
  115. nodeDef->mutable_attr()->insert({"op_def", value});
  116. return nodeDef;
  117. }
  118. void AddInput(NodeDef* src, NodeDef* dst, int srcIndex) {
  119. if(srcIndex == -1){
  120. dst->add_input("^"+src->name());
  121. } else {
  122. if (srcIndex == 0) {
  123. dst->add_input(src->name());
  124. } else {
  125. dst->add_input(src->name() + ":" + std::to_string(srcIndex));
  126. }
  127. {
  128. auto input = (*dst->mutable_attr())[ge::ATTR_NAME_INPUT_TENSOR_DESC].mutable_list()->add_func();
  129. tensorflow::AttrValue val1;
  130. val1.set_i(0);
  131. (*input->mutable_attr())["serialize_format"] = val1;
  132. tensorflow::AttrValue val2;
  133. val2.set_i(tensorflow::DT_FLOAT);
  134. (*input->mutable_attr())["serialize_datatype"] = val2;
  135. tensorflow::AttrValue val3;
  136. val3.mutable_list()->add_i(10);
  137. (*input->mutable_attr())["serialize_shape"] = val3;
  138. }
  139. {
  140. auto output = (*src->mutable_attr())[ge::ATTR_NAME_OUTPUT_TENSOR_DESC].mutable_list()->add_func();
  141. tensorflow::AttrValue val1;
  142. val1.set_i(0);
  143. (*output->mutable_attr())["serialize_format"] = val1;
  144. tensorflow::AttrValue val2;
  145. val2.set_i(tensorflow::DT_FLOAT);
  146. (*output->mutable_attr())["serialize_datatype"] = val2;
  147. tensorflow::AttrValue val3;
  148. val3.mutable_list()->add_i(10);
  149. (*output->mutable_attr())["serialize_shape"] = val3;
  150. }
  151. }
  152. }
  153. NodeDef *initNodeDef() {
  154. NodeDef * nodeDef = new NodeDef();
  155. nodeDef->set_op("Const");
  156. ::google::protobuf::Map<std::string, tensorflow::AttrValue >* node_attr_map = nodeDef->mutable_attr();
  157. //设置 T属性
  158. domi::tensorflow::AttrValue t_attr_value;
  159. t_attr_value.set_type(domi::tensorflow::DT_INT32);
  160. (*node_attr_map)[TENSORFLOW_ATTR_T] = t_attr_value;
  161. domi::tensorflow::AttrValue dtype_attr_value;
  162. dtype_attr_value.set_type(domi::tensorflow::DT_INT32);
  163. (*node_attr_map)[TENSORFLOW_ATTR_DTYPE] = dtype_attr_value;
  164. // out_put
  165. domi::tensorflow::AttrValue outputs_attr_value;
  166. ::tensorflow::AttrValue_ListValue* list = outputs_attr_value.mutable_list();
  167. list->add_s("MatMul");
  168. (*node_attr_map)[TENSORFLOW_ATTR_OUTPUT_OP] = outputs_attr_value;
  169. // 设置 tensor 属性
  170. domi::tensorflow::AttrValue value_attr_value;
  171. tensorflow::TensorProto* tensor = value_attr_value.mutable_tensor();
  172. tensorflow::TensorShapeProto* tensor_shape = tensor->mutable_tensor_shape();
  173. tensor_shape->clear_dim();
  174. tensor_shape->add_dim()->set_size(4);
  175. tensor_shape->add_dim()->set_size(6);
  176. tensor->set_dtype(domi::tensorflow::DT_INT32);
  177. float *addr = new float[24];
  178. for (int32_t i = 0; i < 24; i++) {
  179. *(addr + i) = 1.0 + i;
  180. }
  181. tensor->set_tensor_content((void *)addr, 24 * sizeof(float));
  182. (*node_attr_map)[TENSORFLOW_ATTR_VALUE] = value_attr_value;
  183. delete[] addr;
  184. return nodeDef;
  185. }
  186. NodeDef * initOpNodeDef_VariableV2() {
  187. NodeDef * nodeDef = new NodeDef();
  188. nodeDef->set_op("VariableV2");
  189. google::protobuf::Map<std::string, tensorflow::AttrValue > *node_attr_map = nodeDef->mutable_attr();
  190. //设置data_format属性
  191. domi::tensorflow::AttrValue format_attr_value;
  192. format_attr_value.set_s("_FZ");
  193. (*node_attr_map)[VAR_ATTR_FORMAT] = format_attr_value;
  194. domi::tensorflow::AttrValue type_attr;
  195. type_attr.set_type(domi::tensorflow::DT_FLOAT);
  196. (*node_attr_map)[VAR_ATTR_DTYPE] = type_attr;
  197. domi::tensorflow::AttrValue container_attr_value;
  198. container_attr_value.set_s("container");
  199. (*node_attr_map)[VAR_ATTR_CONTAINER] = container_attr_value;
  200. domi::tensorflow::AttrValue shard_name_attr_value;
  201. shard_name_attr_value.set_s("shard_name");
  202. (*node_attr_map)[VAR_ATTR_SHARED_NAME] = shard_name_attr_value;
  203. domi::tensorflow::AttrValue shape_attr_value;
  204. shape_attr_value.mutable_shape()->add_dim()->set_size(1);
  205. shape_attr_value.mutable_shape()->add_dim()->set_size(2);
  206. shape_attr_value.mutable_shape()->add_dim()->set_size(3);
  207. shape_attr_value.mutable_shape()->add_dim()->set_size(4);
  208. (*node_attr_map)[ge::VAR_ATTR_SHAPE] = shape_attr_value;
  209. domi::tensorflow::AttrValue shape;
  210. shape.mutable_list()->add_i((int64)32);
  211. shape.mutable_list()->add_i((int64)32);
  212. shape.mutable_list()->add_i((int64)14);
  213. shape.mutable_list()->add_i((int64)14);
  214. //设置data_format属性
  215. domi::tensorflow::AttrValue df_attr_value;
  216. domi::tensorflow::AttrValue df_attr_value2;
  217. df_attr_value2.set_s(TENSORFLOWF_TENSOR_NHWC);
  218. df_attr_value.set_i((int64_t)ccTensorFormat_t::CC_TENSOR_NHWC);
  219. (*node_attr_map)[TENSORFLOW_ATTR_DATA_FORMAT] = df_attr_value2;
  220. //设置padding属性
  221. domi::tensorflow::AttrValue pad_attr_value;
  222. domi::tensorflow::AttrValue pad_attr_value2;
  223. pad_attr_value2.set_s(TENSORFLOWF_OP_PADDING_SAME);
  224. (*node_attr_map)[TENSORFLOW_ATTR_PADDING] = pad_attr_value2;
  225. pad_attr_value.set_i((int64_t)tensorflow::DT_FLOAT);
  226. domi::tensorflow::NameAttrList name_attr_list;
  227. name_attr_list.set_name(std::to_string(0));
  228. name_attr_list.mutable_attr()->insert({"serialize_shape", shape});
  229. name_attr_list.mutable_attr()->insert({"serialize_format", df_attr_value});
  230. name_attr_list.mutable_attr()->insert({"serialize_datatype", pad_attr_value});
  231. domi::tensorflow::AttrValue output_tensor_descs;
  232. *(output_tensor_descs.mutable_list()->add_func()) = name_attr_list;
  233. nodeDef->mutable_attr()->insert({ge::ATTR_NAME_OUTPUT_TENSOR_DESC, output_tensor_descs});
  234. return nodeDef;
  235. }
  236. NodeDef *initOpNodeDef_TemporaryVariable() {
  237. NodeDef * nodeDef = new NodeDef();
  238. nodeDef->set_op("TemporaryVariable");
  239. google::protobuf::Map<std::string, tensorflow::AttrValue> *node_attr_map = nodeDef->mutable_attr();
  240. //设置dtype属性
  241. domi::tensorflow::AttrValue type_attr;
  242. type_attr.set_type(domi::tensorflow::DT_FLOAT);
  243. (*node_attr_map)[VAR_ATTR_DTYPE] = type_attr;
  244. //设置var_name属性
  245. domi::tensorflow::AttrValue var_name_attr_value;
  246. var_name_attr_value.set_s("temporary_variable_name");
  247. (*node_attr_map)[ge::VAR_ATTR_NAME] = var_name_attr_value;
  248. //设置shape属性
  249. domi::tensorflow::AttrValue shape_attr_value;
  250. shape_attr_value.mutable_shape()->add_dim()->set_size(1);
  251. shape_attr_value.mutable_shape()->add_dim()->set_size(2);
  252. shape_attr_value.mutable_shape()->add_dim()->set_size(3);
  253. shape_attr_value.mutable_shape()->add_dim()->set_size(4);
  254. (*node_attr_map)[ge::VAR_ATTR_SHAPE] = shape_attr_value;
  255. domi::tensorflow::AttrValue shape;
  256. shape.mutable_list()->add_i((int64)32);
  257. shape.mutable_list()->add_i((int64)32);
  258. shape.mutable_list()->add_i((int64)14);
  259. shape.mutable_list()->add_i((int64)14);
  260. //设置data_format属性
  261. domi::tensorflow::AttrValue df_attr_value2;
  262. df_attr_value2.set_s(TENSORFLOWF_TENSOR_NHWC);
  263. (*node_attr_map)[TENSORFLOW_ATTR_DATA_FORMAT] = df_attr_value2;
  264. domi::tensorflow::AttrValue df_attr_value;
  265. df_attr_value.set_i((int64_t)ccTensorFormat_t::CC_TENSOR_NHWC);
  266. //设置padding属性
  267. domi::tensorflow::AttrValue pad_attr_value2;
  268. pad_attr_value2.set_s(TENSORFLOWF_OP_PADDING_SAME);
  269. (*node_attr_map)[TENSORFLOW_ATTR_PADDING] = pad_attr_value2;
  270. domi::tensorflow::AttrValue pad_attr_value;
  271. pad_attr_value.set_i((int64_t)tensorflow::DT_FLOAT);
  272. domi::tensorflow::NameAttrList name_attr_list;
  273. name_attr_list.set_name(std::to_string(0));
  274. name_attr_list.mutable_attr()->insert({"serialize_shape", shape});
  275. name_attr_list.mutable_attr()->insert({"serialize_format", df_attr_value});
  276. name_attr_list.mutable_attr()->insert({"serialize_datatype", pad_attr_value});
  277. domi::tensorflow::AttrValue output_tensor_descs;
  278. *(output_tensor_descs.mutable_list()->add_func()) = name_attr_list;
  279. nodeDef->mutable_attr()->insert({ge::ATTR_NAME_OUTPUT_TENSOR_DESC, output_tensor_descs});
  280. return nodeDef;
  281. }
  282. NodeDef *fusioninitNodeDef(int index) {
  283. NodeDef *nodeDef = new NodeDef();
  284. google::protobuf::Map<std::string, tensorflow::AttrValue> *node_attr_map = nodeDef->mutable_attr();
  285. //设置 type属性
  286. domi::tensorflow::AttrValue dtype_attr_value ;
  287. if (index == 0) {
  288. dtype_attr_value.set_type(domi::tensorflow::DT_FLOAT);
  289. } else if (index == 1) {
  290. dtype_attr_value.set_type(domi::tensorflow::DT_INT32);
  291. } else if (index == 2) {
  292. dtype_attr_value.set_type(tensorflow::DT_HALF);
  293. }
  294. (*node_attr_map)[ge::TENSORFLOW_ATTR_DTYPE] = dtype_attr_value;
  295. //设置data_format属性
  296. domi::tensorflow::AttrValue df_attr_value;
  297. df_attr_value.set_s(TENSORFLOWF_TENSOR_NCHW);
  298. (*node_attr_map)[TENSORFLOW_ATTR_DATA_FORMAT] = df_attr_value;
  299. // 设置 tensor 属性
  300. domi::tensorflow::AttrValue value_attr_value;
  301. ::tensorflow::TensorProto* tensor = value_attr_value.mutable_tensor();
  302. ::tensorflow::TensorShapeProto* tensor_shape = tensor->mutable_tensor_shape();
  303. tensor_shape->clear_dim();
  304. ::tensorflow::TensorShapeProto_Dim* dim = tensor_shape->add_dim();
  305. dim->set_name("tensor dim");
  306. dim->set_size(1);
  307. if (index == 0) {
  308. tensor->set_dtype(domi::tensorflow::DT_FLOAT);
  309. float *addr = new float[1];
  310. *addr = 1.0;
  311. tensor->set_tensor_content((void *)addr, sizeof(float));
  312. (*node_attr_map)[TENSORFLOW_ATTR_VALUE] = value_attr_value;
  313. delete[] addr;
  314. } else if (index == 1) {
  315. tensor->set_dtype(domi::tensorflow::DT_INT32);
  316. int32_t *addr = new int32_t[1];
  317. *addr = 1;
  318. tensor->set_tensor_content((void *)addr, sizeof(int32_t));
  319. (*node_attr_map)[TENSORFLOW_ATTR_VALUE] = value_attr_value;
  320. delete[] addr;
  321. } else if (index == 2) {
  322. tensor->set_dtype(tensorflow::DT_HALF);
  323. tensor->add_half_val(1);
  324. (*node_attr_map)[TENSORFLOW_ATTR_VALUE] = value_attr_value;
  325. }
  326. return nodeDef;
  327. }
  328. NodeDef *MallocNodeDef(const string &name, const string &type) {
  329. NodeDef* node_def = new (std::nothrow) NodeDef();
  330. if (node_def != nullptr) {
  331. node_def->set_name(name);
  332. node_def->set_op(type);
  333. }
  334. return node_def;
  335. }
  336. void GenOriginNodeDef(ge::TensorFlowModelParser *tensorflow_parser, vector<string> &node_name_list) {
  337. NodeDef* pre_node_a = MallocNodeDef("pre_node_a", "Const");
  338. EXPECT_NE(pre_node_a, nullptr);
  339. {
  340. google::protobuf::Map< ::std::string, ::tensorflow::AttrValue >* node_attr_map = pre_node_a->mutable_attr();
  341. tensorflow::AttrValue attr_dtype;
  342. attr_dtype.set_type(tensorflow::DT_FLOAT);
  343. (*node_attr_map)["dtype"] = attr_dtype;
  344. tensorflow::AttrValue attr_value;
  345. tensorflow::TensorProto* tensor = attr_value.mutable_tensor();
  346. tensor->add_bool_val(true);
  347. tensor->set_dtype(tensorflow::DT_BOOL);
  348. (*node_attr_map)["value"] = attr_value;
  349. }
  350. tensorflow_parser->nodedef_map_["pre_node_a"] = pre_node_a;
  351. node_name_list.push_back("pre_node_a");
  352. NodeDef* pre_node_ctrl_in = MallocNodeDef("pre_node_ctrl_in", "Const");
  353. EXPECT_NE(pre_node_ctrl_in, nullptr);
  354. {
  355. ::google::protobuf::Map< ::std::string, ::tensorflow::AttrValue >* node_attr_map = pre_node_ctrl_in->mutable_attr();
  356. tensorflow::AttrValue attr_dtype;
  357. attr_dtype.set_type(tensorflow::DT_FLOAT);
  358. (*node_attr_map)["dtype"] = attr_dtype;
  359. tensorflow::AttrValue attr_value;
  360. tensorflow::TensorProto* tensor = attr_value.mutable_tensor();
  361. tensor->add_bool_val(true);
  362. tensor->set_dtype(tensorflow::DT_BOOL);
  363. (*node_attr_map)["value"] = attr_value;
  364. }
  365. tensorflow_parser->nodedef_map_["pre_node_ctrl_in"] = pre_node_ctrl_in;
  366. node_name_list.push_back("pre_node_ctrl_in");
  367. NodeDef* post_node_b = MallocNodeDef("post_node_b", "Identity");
  368. EXPECT_NE(post_node_b, nullptr);
  369. tensorflow_parser->nodedef_map_["post_node_b"] = post_node_b;
  370. node_name_list.push_back("post_node_b");
  371. NodeDef* post_node_c = MallocNodeDef("post_node_c", "Identity");
  372. EXPECT_NE(post_node_c, nullptr);
  373. tensorflow_parser->nodedef_map_["post_node_c"] = post_node_c;
  374. node_name_list.push_back("post_node_c");
  375. NodeDef* post_node_d = MallocNodeDef("post_node_d", "Identity");
  376. EXPECT_NE(post_node_d, nullptr);
  377. tensorflow_parser->nodedef_map_["post_node_d"] = post_node_d;
  378. node_name_list.push_back("post_node_d");
  379. }
  380. void FreeNodeDefMap(ge::TensorFlowModelParser *tensorflow_parser, set<string> &malloc_node_name_list) {
  381. for (auto &item : tensorflow_parser->nodedef_map_) {
  382. if (item.second != nullptr && malloc_node_name_list.count(item.first) > 0) {
  383. delete (item.second);
  384. item.second = nullptr;
  385. }
  386. }
  387. }
  388. void GenFusionScopesResult(shared_ptr<ScopeGraph> &scope_graph, FusionScopesResult *fusion_rlt,
  389. const string &fusion_op_name) {
  390. if (fusion_rlt == nullptr) {
  391. return;
  392. }
  393. fusion_rlt->InsertInputs("scope_node_1", {0}); // scope input 0
  394. fusion_rlt->InsertOutputs("scope_node_m", {0}); // scope output 0
  395. fusion_rlt->InsertOutputs("scope_node_n", {1}); // scope output 1
  396. fusion_rlt->SetType(ge::kScopeToMultiNodes);
  397. fusion_rlt->SetName(fusion_op_name);
  398. fusion_rlt->SetDescription("Description for fusion node");
  399. // Add inner nodes in sequence.
  400. auto node1 = fusion_rlt->AddInnerNode("inner_node_1", "Unique"); // add inner node1
  401. CHECK_INNER_NODE_CONDITION(node1 != nullptr, fusion_rlt);
  402. auto ret = node1
  403. ->InsertInput(ge::kInputFromFusionScope, 0) // Input from 0th of boundary (a)
  404. .InsertOutput(ge::kOutputToFusionScope, 0) // Output to 0th of boundary (b)
  405. .InsertOutput("inner_node_2", 0) // Output to input 0th of internal node 2
  406. .BuildInnerNode(); // Construct an internal Operator
  407. CHECK_INNER_NODE_CONDITION(ret == ge::GRAPH_SUCCESS, fusion_rlt);
  408. string str_val = "This is a string.";
  409. node1->MutableOperator()->SetAttr("key1", 2); // Set integer attribute
  410. node1->MutableOperator()->SetAttr("key2", str_val); // Set the string attribute
  411. node1->MutableOperator()->SetAttr("key3", true); // Set boolean attribute
  412. auto node2 = fusion_rlt->AddInnerNode("inner_node_2", "Identity"); // add inner node2
  413. CHECK_INNER_NODE_CONDITION(node2 != nullptr, fusion_rlt);
  414. ret = node2
  415. ->InsertInput("inner_node_1", 1) // The input comes from the 1st output of internal node 1
  416. .InsertOutput("inner_node_3", 0) // Output to input 0th of internal node 3
  417. .BuildInnerNode();
  418. CHECK_INNER_NODE_CONDITION(ret == ge::GRAPH_SUCCESS, fusion_rlt);
  419. node2->SetInputFormat("x", "NHWC");
  420. node2->SetOutputFormat("y", "NHWC");
  421. auto node3 = fusion_rlt->AddInnerNode("inner_node_3", "Identity"); // add inner node3
  422. CHECK_INNER_NODE_CONDITION(node3 != nullptr, fusion_rlt);
  423. ret = node3
  424. ->InsertInput("inner_node_2", 0) // The input comes from the 0th output of internal node 2
  425. .InsertOutput(ge::kOutputToFusionScope, 1) // Output to 1st of boundary (c)
  426. .BuildInnerNode();
  427. CHECK_INNER_NODE_CONDITION(ret == ge::GRAPH_SUCCESS, fusion_rlt);
  428. scope_graph->impl_->AddFusionScopesResult(fusion_rlt);
  429. }
  430. void GenOriginContext(ge::TensorFlowModelParser *tensorflow_parser, const string &fusion_op_name) {
  431. // op_node_context for fusion op
  432. ge::OpNodeContext op_node_context;
  433. op_node_context.input_map["pre_node_a"].push_back({0, 0});
  434. op_node_context.input_map["pre_node_ctrl_in"].push_back({-1, -1}); // ctrl edges
  435. op_node_context.output_map["post_node_b"].push_back({0, 0});
  436. op_node_context.output_map["post_node_c"].push_back({1, 0});
  437. op_node_context.output_map["post_node_d"].push_back({-1, -1});
  438. op_node_context.output_map["_Retval"].push_back({0, 1});
  439. // ctrl edges
  440. tensorflow_parser->op_node_context_map_[fusion_op_name] = op_node_context;
  441. tensorflow_parser->SaveEdgesControlInfo(fusion_op_name, -1);
  442. // op_node_context for pre_node_a
  443. ge::OpNodeContext op_node_context_a;
  444. op_node_context_a.output_map[fusion_op_name].push_back({0, 0});
  445. tensorflow_parser->op_node_context_map_["pre_node_a"] = op_node_context_a;
  446. // op_node_context for pre_node_ctrl_in
  447. ge::OpNodeContext op_node_context_ctrl_in;
  448. op_node_context_ctrl_in.output_map[fusion_op_name].push_back({-1, -1}); // ctrl edges
  449. tensorflow_parser->op_node_context_map_["pre_node_ctrl_in"] = op_node_context_ctrl_in;
  450. // op_node_context for post_node_b
  451. ge::OpNodeContext op_node_context_b;
  452. op_node_context_b.input_map[fusion_op_name].push_back({0, 0});
  453. tensorflow_parser->op_node_context_map_["post_node_b"] = op_node_context_b;
  454. // op_node_context for post_node_c
  455. ge::OpNodeContext op_node_context_c;
  456. op_node_context_c.output_map["post_node_d"].push_back({0, 0});
  457. tensorflow_parser->op_node_context_map_["post_node_c"] = op_node_context_c;
  458. // op_node_context for post_node_d
  459. ge::OpNodeContext op_node_context_d;
  460. op_node_context_d.input_map[fusion_op_name].push_back({-1, -1}); // ctrl edges
  461. tensorflow_parser->op_node_context_map_["post_node_d"] = op_node_context_d;
  462. // op_node_context for Retval
  463. ge::OpNodeContext op_node_context_Retval;
  464. op_node_context_d.input_map["post_node_d"].push_back({-1, -1});
  465. op_node_context_c.output_map["fusion_op_name"].push_back({0,1});
  466. tensorflow_parser->op_node_context_map_["_Retval"] = op_node_context_Retval;
  467. tensorflow_parser->SaveEdgesControlInfo("op_node_context_Retval", -1);
  468. string fusion_op_type = ge::kScopeToMultiNodes;
  469. string description = "fusion op description";
  470. tensorflow_parser->fusion_op_type_map_[fusion_op_name].push_back(fusion_op_type);
  471. tensorflow_parser->fusion_op_type_map_[fusion_op_name].push_back(description);
  472. }
  473. void register_tbe_op() {
  474. std::vector<OpRegistrationData> registrationDatas = OpRegistry::Instance()->registrationDatas;
  475. for (OpRegistrationData reg_data : registrationDatas) {
  476. OpRegistrationTbe::Instance()->Finalize(reg_data);
  477. OpRegistry::Instance()->Register(reg_data);
  478. }
  479. OpRegistry::Instance()->registrationDatas.clear();
  480. }
  481. NodeDef *initNodeDef_axis_dims() {
  482. NodeDef *nodeDef = new NodeDef();
  483. google::protobuf::Map<std::string, tensorflow::AttrValue> *node_attr_map = nodeDef->mutable_attr();
  484. //设置T属性
  485. domi::tensorflow::AttrValue dtype_attr_value ;
  486. dtype_attr_value.set_type(domi::tensorflow::DT_FLOAT);
  487. (*node_attr_map)[TENSORFLOW_ATTR_T] = dtype_attr_value;
  488. //设置strides属性
  489. domi::tensorflow::AttrValue axis_attr_value;
  490. ::tensorflow::AttrValue_ListValue* list = axis_attr_value.mutable_list();
  491. list->add_i(1);
  492. list->add_i(2);
  493. (*node_attr_map)[ge::SQUEEZE_ATTR_AXIS] = axis_attr_value;
  494. (*node_attr_map)[ge::SQUEEZE_ATTR_DIMS] = axis_attr_value;
  495. return nodeDef;
  496. }
  497. NodeDef *initNodeDef_dims() {
  498. NodeDef *nodeDef = new NodeDef();
  499. ::google::protobuf::Map<std::string, tensorflow::AttrValue > *node_attr_map = nodeDef->mutable_attr();
  500. //设置T属性
  501. domi::tensorflow::AttrValue dtype_attr_value ;
  502. dtype_attr_value.set_type(domi::tensorflow::DT_FLOAT);
  503. (*node_attr_map)[TENSORFLOW_ATTR_T] = dtype_attr_value;
  504. //设置strides属性
  505. domi::tensorflow::AttrValue axis_attr_value;
  506. ::tensorflow::AttrValue_ListValue* list = axis_attr_value.mutable_list();
  507. list->add_i(1);
  508. list->add_i(2);
  509. (*node_attr_map)[ge::SQUEEZE_ATTR_DIMS] = axis_attr_value;
  510. return nodeDef;
  511. }
  512. void CreateOpDef(const string& _name, const string& _type, ge::OpDescPtr opDef) {
  513. tensorflow::OpDef tsOpDef;
  514. tsOpDef.set_name(_name);
  515. tensorflow::OpDef_ArgDef* outArgDef = tsOpDef.add_output_arg();
  516. outArgDef->set_name(_name);
  517. outArgDef->set_description("outArgDef");
  518. outArgDef->set_type((tensorflow::DataType)3);
  519. if ((_name == "A") || (_name == "B")) {
  520. tensorflow::OpDef_ArgDef* argDef1 = tsOpDef.add_output_arg();
  521. string name = _name+"t";
  522. argDef1->set_name(name);
  523. argDef1->set_description("this is a test 2");
  524. argDef1->set_type((tensorflow::DataType)3);
  525. }
  526. if ((_name == "C") ) {
  527. outArgDef->set_number_attr("num");
  528. }
  529. if ((_name == "D") ) {
  530. outArgDef->set_type_list_attr("type_list");
  531. }
  532. string strTsOpDef;
  533. tsOpDef.SerializeToString(&strTsOpDef);
  534. ge::AttrUtils::SetStr(opDef, "op_def", strTsOpDef);
  535. tensorflow::NodeDef nodedef;
  536. nodedef.set_name(_name);
  537. nodedef.set_op(_name);
  538. string name("op_def");
  539. tensorflow::AttrValue value;
  540. value.set_s(strTsOpDef);
  541. TensorFlowUtil::AddNodeAttr(name, value, &nodedef);
  542. value.set_i(1);
  543. TensorFlowUtil::AddNodeAttr("num", value, &nodedef);
  544. value.mutable_list();
  545. TensorFlowUtil::AddNodeAttr("type_list", value, &nodedef);
  546. string strNodeDef;
  547. nodedef.SerializeToString(&strNodeDef);
  548. ge::GeAttrValue::BYTES nodedefBytes;
  549. nodedefBytes = ge::GeAttrValue::BYTES::CopyFrom((uint8_t*)strNodeDef.data(), strNodeDef.length());
  550. ge::AttrUtils::SetBytes(opDef, "node_def", nodedefBytes);
  551. if ((_name== "S") || (_name == "K")) {
  552. int index = 0;
  553. ge::AttrUtils::SetInt(opDef, "T", 1);
  554. ge::AttrUtils::SetInt(opDef, "arg_index", index);
  555. ge::AttrUtils::SetInt(opDef, "ret_index", index);
  556. }
  557. }
  558. ge::NodePtr AddNode(ge::ComputeGraphPtr graph, const string& _name, const string& _type,int32_t i_n, int32_t o_n) {
  559. ge::OpDescPtr opDef = std::make_shared<ge::OpDesc>();
  560. opDef->SetName(_name);
  561. opDef->SetType(_type);
  562. for(int32_t i = 0; i < i_n; i++) {
  563. ge::GeTensorDesc input;
  564. input.SetDataType((ge::DataType)1);
  565. opDef->AddInputDesc(input);
  566. }
  567. for(int32_t i = 0;i < o_n; i++) {
  568. ge::GeTensorDesc output;
  569. output.SetDataType((ge::DataType)1);
  570. opDef->AddOutputDesc(output);
  571. }
  572. CreateOpDef(_name, _type, opDef);
  573. return graph->AddNode(opDef);
  574. }
  575. void MakeDagGraph(ge::ComputeGraphPtr graph, const string& input_node_type) {
  576. ge::NodePtr node_s = AddNode(graph, "S", parser::DATA,1,1);
  577. ge::NodePtr node_a = AddNode(graph, "A", "testa",1,2);
  578. ge::NodePtr node_b = AddNode(graph, "B", "testb",1,2);
  579. ge::NodePtr node_c = AddNode(graph, "C", "testc",1,1);
  580. ge::NodePtr node_d = AddNode(graph, "D", "testd",1,1);
  581. ge::NodePtr node_e = AddNode(graph, "E", "teste",1,1);
  582. ge::NodePtr node_f = AddNode(graph, "F", "testf",1,1);
  583. ge::NodePtr node_g = AddNode(graph, "G", "testg",2,1);
  584. ge::NodePtr node_h = AddNode(graph, "H", "testh",1,1);
  585. ge::NodePtr node_i = AddNode(graph, "I", "testi",1,1);
  586. ge::NodePtr node_j = AddNode(graph, "J", "testj",2,1);
  587. ge::NodePtr node_k = AddNode(graph, "K", parser::NETOUTPUT,1,1);
  588. ge::GraphUtils::AddEdge(node_s->GetOutDataAnchor(0), node_a->GetInDataAnchor(0));
  589. ge::GraphUtils::AddEdge(node_a->GetOutDataAnchor(0), node_b->GetInDataAnchor(0));
  590. ge::GraphUtils::AddEdge(node_a->GetOutDataAnchor(1), node_c->GetInDataAnchor(0));
  591. ge::GraphUtils::AddEdge(node_b->GetOutDataAnchor(0), node_d->GetInDataAnchor(0));
  592. ge::GraphUtils::AddEdge(node_b->GetOutDataAnchor(1), node_e->GetInDataAnchor(0));
  593. ge::GraphUtils::AddEdge(node_c->GetOutDataAnchor(0), node_g->GetInDataAnchor(0));
  594. ge::GraphUtils::AddEdge(node_d->GetOutDataAnchor(0), node_f->GetInDataAnchor(0));
  595. ge::GraphUtils::AddEdge(node_e->GetOutDataAnchor(0), node_g->GetInDataAnchor(1));
  596. ge::GraphUtils::AddEdge(node_f->GetOutDataAnchor(0), node_h->GetInDataAnchor(0));
  597. ge::GraphUtils::AddEdge(node_g->GetOutDataAnchor(0), node_j->GetInDataAnchor(0));
  598. ge::GraphUtils::AddEdge(node_h->GetOutDataAnchor(0), node_i->GetInDataAnchor(0));
  599. ge::GraphUtils::AddEdge(node_i->GetOutDataAnchor(0), node_j->GetInDataAnchor(1));
  600. ge::GraphUtils::AddEdge(node_j->GetOutDataAnchor(0), node_k->GetInDataAnchor(0));
  601. ge::GraphUtils::AddEdge(node_h->GetOutControlAnchor(), node_j->GetInControlAnchor());
  602. }
  603. void ChangeDataType(tensorflow::NodeDef* node_tf, int32_t data_type)
  604. {
  605. domi::tensorflow::AttrValue input_attr_value;
  606. google::protobuf::Map<std::string, tensorflow::AttrValue>* attr = node_tf->mutable_attr();
  607. google::protobuf::Map<std::string, tensorflow::AttrValue>::const_iterator it = attr->find(ge::ATTR_NAME_INPUT_TENSOR_DESC);
  608. if (it != attr->end()) {
  609. input_attr_value = it->second;
  610. }
  611. (*attr)[ge::ATTR_NAME_INPUT_TENSOR_DESC] = input_attr_value;
  612. }
  613. NodeDef* AddGraphNode(GraphDef *graph, string name, string optype, string input)
  614. {
  615. NodeDef * node_def = graph->add_node();
  616. node_def->set_name(name);
  617. node_def->set_op(optype);
  618. node_def->add_input(input);
  619. return node_def;
  620. }
  621. }
namespace {
// IR prototypes for the Data and Add operators used by these tests, declared
// through the REG_OP macro family (the macro expansion lives outside this file).

// Data: input x and output y accept TensorType::ALL(); Int attribute "index"
// with default 0.
REG_OP(Data)
    .INPUT(x, TensorType::ALL())
    .OUTPUT(y, TensorType::ALL())
    .ATTR(index, Int, 0)
    .OP_END_FACTORY_REG(Data)
// Add: inputs x1, x2 and output y, each restricted to the same list of
// numeric and string data types.
REG_OP(Add)
    .INPUT(x1, TensorType({DT_FLOAT, DT_INT32, DT_INT64, DT_FLOAT16, DT_INT16,
                           DT_INT8, DT_UINT8, DT_DOUBLE, DT_COMPLEX128,
                           DT_COMPLEX64, DT_STRING}))
    .INPUT(x2, TensorType({DT_FLOAT, DT_INT32, DT_INT64, DT_FLOAT16, DT_INT16,
                           DT_INT8, DT_UINT8, DT_DOUBLE, DT_COMPLEX128,
                           DT_COMPLEX64, DT_STRING}))
    .OUTPUT(y, TensorType({DT_FLOAT, DT_INT32, DT_INT64, DT_FLOAT16, DT_INT16,
                           DT_INT8, DT_UINT8, DT_DOUBLE, DT_COMPLEX128,
                           DT_COMPLEX64, DT_STRING}))
    .OP_END_FACTORY_REG(Add)
}
  640. static Status FusionParserParams(const std::vector<const google::protobuf::Message *> inside_nodes, ge::Operator &op) {
  641. return domi::SUCCESS;
  642. }
  643. static MemBuffer* MemBufferFromFile(const char *path)
  644. {
  645. char path_temp[PATH_MAX + 1] = {0x00};
  646. if(strlen(path) > PATH_MAX || nullptr == realpath(path, path_temp)) {
  647. return nullptr;
  648. }
  649. FILE *fp = fopen(path_temp, "r+");
  650. if (fp == nullptr) {
  651. return nullptr;
  652. }
  653. // get model file length
  654. if (0 != fseek(fp, 0, SEEK_END)) {
  655. fclose(fp);
  656. return nullptr;
  657. }
  658. long file_length = ftell(fp);
  659. if (fseek(fp, 0, SEEK_SET)) {
  660. fclose(fp);
  661. return nullptr;
  662. }
  663. if (file_length <= 0) {
  664. fclose(fp);
  665. return nullptr;
  666. }
  667. // alloc model buffer
  668. void *data = malloc((unsigned int)file_length);
  669. if (!data) {
  670. fclose(fp);
  671. return nullptr;
  672. }
  673. // read file into memory
  674. uint32_t read_size = (uint32_t)fread(data, 1, (unsigned int)file_length, fp);
  675. // check if read success
  676. if ((long)read_size != file_length) {
  677. free(data);
  678. data = nullptr;
  679. fclose(fp);
  680. return nullptr;
  681. }
  682. // close model file
  683. fclose(fp);
  684. // create an MemBuffer
  685. MemBuffer* membuf = new MemBuffer();
  686. if (!membuf) {
  687. free(data);
  688. data = nullptr;
  689. return nullptr;
  690. }
  691. membuf->data = malloc((unsigned int)read_size);
  692. // set size && data
  693. membuf->size = (uint32_t)read_size;
  694. memcpy((char*)membuf->data, (char*)data, read_size);
  695. free(data);
  696. return membuf;
  697. }
  698. /// placeholder0 placeholder1
  699. /// | /\ /\ |
  700. /// | / \/ \ |
  701. /// | / /\ \ |
  702. /// | | / \ | |
  703. /// | add0 mul0 |
  704. /// | / /c | \ |
  705. /// mul1 --- / | add1
  706. /// \ | |
  707. /// \ ---- add2 |
  708. /// | |
  709. /// retval0 retval1
  710. void CreateGraphDef(domi::tensorflow::GraphDef &graph_def) {
  711. // 1. add node
  712. auto placeholder0 = graph_def.add_node();
  713. auto placeholder1 = graph_def.add_node();
  714. auto add0 = graph_def.add_node();
  715. auto add1 = graph_def.add_node();
  716. auto mul0 = graph_def.add_node();
  717. auto mul1 = graph_def.add_node();
  718. auto add2 = graph_def.add_node();
  719. auto retval0 = graph_def.add_node();
  720. auto retval1 = graph_def.add_node();
  721. auto softmax0 = graph_def.add_node();
  722. auto softmax1 = graph_def.add_node();
  723. // 2. set info
  724. placeholder0->set_name("placeholder0");
  725. placeholder0->set_op("PlaceHolder");
  726. placeholder1->set_name("placeholder1");
  727. placeholder1->set_op("PlaceHolder");
  728. add0->set_name("add0");
  729. add0->set_op("Add");
  730. add1->set_name("add1");
  731. add1->set_op("Add");
  732. add2->set_name("add2");
  733. add2->set_op("Add");
  734. mul0->set_name("mul0");
  735. mul0->set_op("Mul");
  736. mul1->set_name("mul1");
  737. mul1->set_op("Mul");
  738. retval0->set_name("retval0");
  739. retval0->set_op("_RetVal");
  740. retval1->set_name("retval1");
  741. retval1->set_op("_RetVal");
  742. retval0->set_name("retval0");
  743. retval0->set_op("_RetVal");
  744. retval1->set_name("retval1");
  745. retval1->set_op("_RetVal");
  746. softmax0->set_name("Softmax0");
  747. softmax0->set_op("Softmax");
  748. softmax1->set_name("Softmax1");
  749. softmax1->set_op("Softmax");
  750. // 3. add edges
  751. add0->add_input("placeholder0");
  752. add0->add_input("placeholder1");
  753. mul0->add_input("placeholder0");
  754. mul0->add_input("placeholder1");
  755. mul1->add_input("placeholder0");
  756. mul1->add_input("add0");
  757. mul1->add_input("^mul0");
  758. add1->add_input("mul0");
  759. add1->add_input("placeholder1");
  760. add2->add_input("mul1");
  761. add2->add_input("mul0");
  762. retval0->add_input("add2:0");
  763. retval1->add_input("add1:0");
  764. softmax0->add_input("add3:0");
  765. softmax0->add_input("add2:0");
  766. }
  767. TEST_F(STestTensorflowParser, tensorflow_parser_success) {
  768. RegisterCustomOp();
  769. std::string case_dir = __FILE__;
  770. ParserOperator unused("Add");
  771. case_dir = case_dir.substr(0, case_dir.find_last_of("/"));
  772. std::string model_file = case_dir + "/origin_models/tf_add.pb";
  773. std::map<ge::AscendString, ge::AscendString> parser_params;
  774. ge::Graph graph;
  775. auto ret = ge::aclgrphParseTensorFlow(model_file.c_str(), parser_params, graph);
  776. ASSERT_EQ(ret, SUCCESS);
  777. ge::ComputeGraphPtr compute_graph = ge::GraphUtils::GetComputeGraph(graph);
  778. auto output_nodes_info = compute_graph->GetGraphOutNodesInfo();
  779. ASSERT_EQ(output_nodes_info.size(), 1);
  780. EXPECT_EQ((output_nodes_info.at(0).first->GetName()), "add_test_1");
  781. EXPECT_EQ((output_nodes_info.at(0).second), 0);
  782. auto &net_out_name = ge::GetParserContext().net_out_nodes;
  783. ASSERT_EQ(net_out_name.size(), 1);
  784. EXPECT_EQ(net_out_name.at(0), "add_test_1:0");
  785. }
  786. TEST_F(STestTensorflowParser, tensorflow_model_Failed) {
  787. ge::Graph graph;
  788. std::string caseDir = __FILE__;
  789. std::size_t idx = caseDir.find_last_of("/");
  790. caseDir = caseDir.substr(0, idx);
  791. std::string modelFile = caseDir + "/origin_models/model.pb";
  792. auto status = ge::aclgrphParseTensorFlow(modelFile.c_str(), graph);
  793. EXPECT_EQ(status, ge::SUCCESS);
  794. modelFile = caseDir + "/origin_models/test_depth_wise_conv2d.pb";
  795. status = ge::aclgrphParseTensorFlow(modelFile.c_str(), graph);
  796. EXPECT_EQ(status, ge::GRAPH_FAILED);
  797. }
  798. TEST_F(STestTensorflowParser, tensorflow_model_not_exist) {
  799. ge::Graph graph;
  800. std::string caseDir = __FILE__;
  801. std::size_t idx = caseDir.find_last_of("/");
  802. caseDir = caseDir.substr(0, idx);
  803. // model file is not exist
  804. std::string modelFile = caseDir + "/origin_models/conv2d_explicit1_pad.pb";
  805. auto status = ge::aclgrphParseTensorFlow(modelFile.c_str(), graph);
  806. EXPECT_EQ(status, ge::GRAPH_FAILED);
  807. }
  808. TEST_F(STestTensorflowParser, parser_tensorflow_model) {
  809. std::string caseDir = __FILE__;
  810. std::size_t idx = caseDir.find_last_of("/");
  811. caseDir = caseDir.substr(0, idx);
  812. std::string modelFile = caseDir + "/origin_models/tf_add.pb";
  813. const char *model_file = modelFile.c_str();
  814. std::string op_name = "ge_ascend_irgraph";
  815. ge::Graph graph(op_name);
  816. std::map<ge::AscendString, ge::AscendString> parser_options = {
  817. {ge::AscendString(ge::ir_option::INPUT_FORMAT), ge::AscendString("NHWC")},
  818. };
  819. auto ret_graph = ge::aclgrphParseTensorFlow(model_file, parser_options, graph);
  820. EXPECT_EQ(ret_graph, ge::FAILED);
  821. // parser tensorflow model out_node_size is equal to index
  822. string graph_name;
  823. AclGrphParseUtil acl_graph_parse_util;
  824. std::map<AscendString, AscendString> out_nodes_with_node_and_index = {
  825. {AscendString(ge::ir_option::OUT_NODES), AscendString("Placeholder:0;Placeholder_1:1")}};
  826. ParerSTestsUtils::ClearParserInnerCtx();
  827. auto ret = acl_graph_parse_util.ParseParamsBeforeGraph(out_nodes_with_node_and_index, graph_name);
  828. ret_graph = ge::aclgrphParseTensorFlow(model_file, graph);
  829. EXPECT_EQ(ret_graph, domi::FAILED);
  830. // parser tensorflow model success
  831. modelFile = caseDir + "/origin_models/model.pb";
  832. model_file = modelFile.c_str();
  833. out_nodes_with_node_and_index = {{AscendString(ge::ir_option::OUT_NODES), AscendString("x:0;y:0")}};
  834. ParerSTestsUtils::ClearParserInnerCtx();
  835. ret = acl_graph_parse_util.ParseParamsBeforeGraph(out_nodes_with_node_and_index, graph_name);
  836. ret_graph = ge::aclgrphParseTensorFlow(model_file, graph);
  837. EXPECT_EQ(ret_graph, domi::SUCCESS);
  838. }
  839. TEST_F(STestTensorflowParser, tensorflow_parser_to_json)
  840. {
  841. TensorFlowModelParser modelParser;
  842. std::string caseDir = __FILE__;
  843. std::size_t idx = caseDir.find_last_of("/");
  844. caseDir = caseDir.substr(0, idx);
  845. std::string modelFile = caseDir + "/origin_models/tf_add.pb";
  846. std::string jsonFile = caseDir + "/origin_models/test.json";
  847. const char *model_file = modelFile.c_str();
  848. const char *json_file = jsonFile.c_str();
  849. Status ret = modelParser.ToJson(model_file, json_file);
  850. EXPECT_EQ(ret, SUCCESS);
  851. }
  852. TEST_F(STestTensorflowParser, tensorflow_parserfrommemory_failed)
  853. {
  854. TensorFlowModelParser modelParser;
  855. std::string caseDir = __FILE__;
  856. std::size_t idx = caseDir.find_last_of("/");
  857. caseDir = caseDir.substr(0, idx);
  858. std::string modelFile = caseDir + "/origin_models/tf_add.pb";
  859. const char *data = modelFile.c_str();
  860. uint32_t size = 1;
  861. ge::Graph graph;
  862. std::map<ge::AscendString, ge::AscendString> parser_params;
  863. Status ret = ge::aclgrphParseTensorFlow(modelFile.c_str(), parser_params, graph);
  864. ASSERT_EQ(ret, SUCCESS);
  865. modelFile = caseDir + "/origin_models/tf_add.pb";
  866. parser_params = {{AscendString(ge::ir_option::OUT_NODES), AscendString("Placeholder:0;Placeholder_1:0")}};
  867. ret = ge::aclgrphParseTensorFlow(modelFile.c_str(), parser_params, graph);
  868. ge::ComputeGraphPtr compute_graph = ge::GraphUtils::GetComputeGraph(graph);
  869. ret = modelParser.ParseFromMemory(data, size, compute_graph);
  870. EXPECT_EQ(ret, INTERNAL_ERROR);
  871. }
  872. TEST_F(STestTensorflowParser, modelparser_parsefrommemory_success)
  873. {
  874. std::string caseDir = __FILE__;
  875. std::size_t idx = caseDir.find_last_of("/");
  876. caseDir = caseDir.substr(0, idx);
  877. std::string modelFile = caseDir + "/origin_models/tf_add.pb";
  878. const char* tmp_tf_pb_model = modelFile.c_str();
  879. ge::Graph graph;
  880. std::map<ge::AscendString, ge::AscendString> parser_params;
  881. Status ret = ge::aclgrphParseTensorFlow(modelFile.c_str(), parser_params, graph);
  882. ASSERT_EQ(ret, SUCCESS);
  883. ge::ComputeGraphPtr compute_graph = ge::GraphUtils::GetComputeGraph(graph);
  884. TensorFlowModelParser modelParser;
  885. MemBuffer* memBuffer = MemBufferFromFile(tmp_tf_pb_model);
  886. PreChecker::Instance().HasError() == false;
  887. ret = modelParser.ParseFromMemory((char*)memBuffer->data, memBuffer->size, compute_graph);
  888. free(memBuffer->data);
  889. delete memBuffer;
  890. }
  891. TEST_F(STestTensorflowParser, weightsparser_parsefrommemory_success)
  892. {
  893. std::string caseDir = __FILE__;
  894. std::size_t idx = caseDir.find_last_of("/");
  895. caseDir = caseDir.substr(0, idx);
  896. std::string modelFile = caseDir + "/origin_models/tf_add.pb";
  897. const char* tmp_tf_pb_model = modelFile.c_str();
  898. ge::Graph graph;
  899. std::map<ge::AscendString, ge::AscendString> parser_params;
  900. Status ret = ge::aclgrphParseTensorFlow(modelFile.c_str(), parser_params, graph);
  901. ASSERT_EQ(ret, SUCCESS);
  902. ge::ComputeGraphPtr compute_graph = ge::GraphUtils::GetComputeGraph(graph);
  903. auto weights_parser = domi::WeightsParserFactory::Instance()->CreateWeightsParser(domi::TENSORFLOW);
  904. MemBuffer* memBuffer = MemBufferFromFile(tmp_tf_pb_model);
  905. ret = weights_parser->ParseFromMemory((char*)memBuffer->data, memBuffer->size, compute_graph);
  906. free(memBuffer->data);
  907. delete memBuffer;
  908. EXPECT_EQ(SUCCESS, ret);
  909. }
  910. std::string getGraphCallbackV2(string subgraph_name)
  911. {
  912. std::string caseDir = __FILE__;
  913. std::size_t idx = caseDir.find_last_of("/");
  914. caseDir = caseDir.substr(0, idx);
  915. subgraph_name = caseDir + "/origin_models/tf_add.pb";
  916. return subgraph_name;
  917. }
  918. TEST_F(STestTensorflowParser, parser_ParseProtoWithSubgraphV2)
  919. {
  920. std::string caseDir = __FILE__;
  921. std::size_t idx = caseDir.find_last_of("/");
  922. caseDir = caseDir.substr(0, idx);
  923. const std::string root_proto = caseDir + "/origin_models/tf_add.pb";
  924. ge::Graph graph;
  925. std::map<ge::AscendString, ge::AscendString> parser_params;
  926. Status ret = ge::aclgrphParseTensorFlow(root_proto.c_str(), parser_params, graph);
  927. ASSERT_EQ(ret, SUCCESS);
  928. ge::ComputeGraphPtr root_graph = ge::GraphUtils::GetComputeGraph(graph);
  929. domi::GetGraphCallbackV2 callback(&getGraphCallbackV2);
  930. TensorFlowModelParser parser;
  931. ret = parser.ParseProtoWithSubgraph(root_proto, callback, root_graph);
  932. }
  933. TEST_F(STestTensorflowParser, parser_ConvertToGeDataType)
  934. {
  935. // convert to ge type success
  936. const uint32_t type1 = domi::tensorflow::DataType::DT_FLOAT;
  937. TensorFlowModelParser parser;
  938. ge::DataType dataType = parser.ConvertToGeDataType(type1);
  939. ASSERT_EQ(dataType, ge::DataType::DT_FLOAT);
  940. const uint32_t type2 = 80; // invalid type
  941. dataType = parser.ConvertToGeDataType(type2);
  942. ASSERT_EQ(dataType, ge::DataType::DT_UNDEFINED);
  943. }
  944. TEST_F(STestTensorflowParser, tensorflow_ParserProto_failed)
  945. {
  946. std::string caseDir = __FILE__;
  947. std::size_t idx = caseDir.find_last_of("/");
  948. caseDir = caseDir.substr(0, idx);
  949. const std::string root_proto = caseDir + "/origin_models/avgpool3dgrad.pb.txt";
  950. domi::tensorflow::GraphDef graphDef;
  951. ge::Graph graph;
  952. std::map<ge::AscendString, ge::AscendString> parser_params;
  953. Status ret = ge::aclgrphParseTensorFlow(root_proto.c_str(), parser_params, graph);
  954. ASSERT_EQ(ret, SUCCESS);
  955. ge::ComputeGraphPtr root_graph = ge::GraphUtils::GetComputeGraph(graph);
  956. TensorFlowModelParser tensorflow_parser;
  957. ret = tensorflow_parser.ParseProto(reinterpret_cast<google::protobuf::Message *>(&graphDef), root_graph);
  958. EXPECT_EQ(PARAM_INVALID, ret);
  959. // proto解析失败
  960. bool protoRet = parser::ReadProtoFromText(root_proto.c_str(), &graphDef);
  961. ASSERT_EQ(protoRet, false);
  962. ret = tensorflow_parser.ParseProto(reinterpret_cast<google::protobuf::Message *>(&graphDef), root_graph);
  963. ASSERT_EQ(ret, PARAM_INVALID);
  964. std::string serialized_proto = "";
  965. ret = tensorflow_parser.ParseProto(serialized_proto, root_graph);
  966. ASSERT_EQ(ret, FAILED);
  967. }
  968. TEST_F(STestTensorflowParser, tensorflow_parserAllGraph_failed)
  969. {
  970. std::string caseDir = __FILE__;
  971. std::size_t idx = caseDir.find_last_of("/");
  972. caseDir = caseDir.substr(0, idx);
  973. const std::string root_proto = caseDir + "/origin_models/conv2d.pb";
  974. domi::tensorflow::GraphDef graphDef;
  975. CreateGraphDef(graphDef);
  976. auto no_op = graphDef.add_node();
  977. no_op->set_name("no_op");
  978. no_op->set_op("NoOp");
  979. no_op->add_input("placeholder0");
  980. no_op->add_input("placeholder1");
  981. ge::Graph graph;
  982. std::map<ge::AscendString, ge::AscendString> parser_params;
  983. Status ret = ge::aclgrphParseTensorFlow(root_proto.c_str(), parser_params, graph);
  984. ASSERT_EQ(ret, SUCCESS);
  985. ge::ComputeGraphPtr root_graph = ge::GraphUtils::GetComputeGraph(graph);
  986. TensorFlowModelParser tensorflow_parser;
  987. ret = tensorflow_parser.ParseAllGraph(reinterpret_cast<google::protobuf::Message *>(&graphDef), root_graph);
  988. EXPECT_EQ(INTERNAL_ERROR, ret);
  989. }
  990. TEST_F(STestTensorflowParser, test_parse_acl_output_nodes)
  991. {
  992. AclGrphParseUtil acl_graph_parse_util;
  993. string graph_name;
  994. // case 1: Normal with 'node and index'
  995. ParerSTestsUtils::ClearParserInnerCtx();
  996. GetParserContext().type = domi::ONNX;
  997. std::map<AscendString, AscendString> out_nodes_with_node_and_index = {
  998. {AscendString(ge::ir_option::OUT_NODES), AscendString("Out1:0;Out2:1")}};
  999. ParerSTestsUtils::ClearParserInnerCtx();
  1000. auto ret = acl_graph_parse_util.ParseParamsBeforeGraph(out_nodes_with_node_and_index, graph_name);
  1001. ASSERT_EQ(ret, SUCCESS);
  1002. EXPECT_EQ(ge::GetParserContext().user_out_nodes.size(), 2);
  1003. EXPECT_EQ(ge::GetParserContext().out_nodes_map.size(), 2);
  1004. EXPECT_EQ(ge::GetParserContext().user_out_tensors.size(), 0);
  1005. // case 2: Normal with 'tensor name'
  1006. ParerSTestsUtils::ClearParserInnerCtx();
  1007. GetParserContext().type = domi::ONNX;
  1008. std::map<AscendString, AscendString> out_nodes_with_tensor_name = {
  1009. {AscendString(ge::ir_option::OUT_NODES), AscendString("Out_tensor_1;Out_tensor_2")}};
  1010. ret = acl_graph_parse_util.ParseParamsBeforeGraph(out_nodes_with_tensor_name, graph_name);
  1011. ASSERT_EQ(ret, SUCCESS);
  1012. EXPECT_EQ(ge::GetParserContext().user_out_nodes.size(), 0);
  1013. EXPECT_EQ(ge::GetParserContext().out_nodes_map.size(), 0);
  1014. EXPECT_EQ(ge::GetParserContext().user_out_tensors.size(), 2);
  1015. // case 3: Failed with 'node and index' before 'tensor name'
  1016. ParerSTestsUtils::ClearParserInnerCtx();
  1017. GetParserContext().type = domi::ONNX;
  1018. std::map<AscendString, AscendString> out_nodes_mode_mixex_pre = {
  1019. {AscendString(ge::ir_option::OUT_NODES), AscendString("Out1:0;Out2:1;Out_tensor_1;Out_tensor_2")}};
  1020. ret = acl_graph_parse_util.ParseParamsBeforeGraph(out_nodes_mode_mixex_pre, graph_name);
  1021. ASSERT_EQ(ret, PARAM_INVALID);
  1022. EXPECT_EQ(ge::GetParserContext().user_out_nodes.size(), 2);
  1023. EXPECT_EQ(ge::GetParserContext().out_nodes_map.size(), 2);
  1024. EXPECT_EQ(ge::GetParserContext().user_out_tensors.size(), 0);
  1025. // case 4: Failed with 'node and index' inserted in 'tensor name'
  1026. ParerSTestsUtils::ClearParserInnerCtx();
  1027. GetParserContext().type = domi::ONNX;
  1028. std::map<AscendString, AscendString> out_nodes_mode_mixex_mid = {
  1029. {AscendString(ge::ir_option::OUT_NODES), AscendString("Out_tensor_1;Out1:0;Out2:1;Out_tensor_2")}};
  1030. ret = acl_graph_parse_util.ParseParamsBeforeGraph(out_nodes_mode_mixex_mid, graph_name);
  1031. ASSERT_EQ(ret, PARAM_INVALID);
  1032. EXPECT_EQ(ge::GetParserContext().user_out_nodes.size(), 0);
  1033. EXPECT_EQ(ge::GetParserContext().out_nodes_map.size(), 0);
  1034. EXPECT_EQ(ge::GetParserContext().user_out_tensors.size(), 1);
  1035. // case 5: Failed with 'node and index' after 'tensor name'
  1036. ParerSTestsUtils::ClearParserInnerCtx();
  1037. GetParserContext().type = domi::ONNX;
  1038. std::map<AscendString, AscendString> out_nodes_mode_mixex_post = {
  1039. {AscendString(ge::ir_option::OUT_NODES), AscendString("Out_tensor_1;Out_tensor_2;Out1:0;Out2:1")}};
  1040. ret = acl_graph_parse_util.ParseParamsBeforeGraph(out_nodes_mode_mixex_post, graph_name);
  1041. ASSERT_EQ(ret, PARAM_INVALID);
  1042. EXPECT_EQ(ge::GetParserContext().user_out_nodes.size(), 0);
  1043. EXPECT_EQ(ge::GetParserContext().out_nodes_map.size(), 0);
  1044. EXPECT_EQ(ge::GetParserContext().user_out_tensors.size(), 2);
  1045. }
  1046. TEST_F(STestTensorflowParser, parse_AutoMappingByOp) {
  1047. static const string KEY_STRING = "key_string";
  1048. static const string KEY_INT = "key_int";
  1049. static const string KEY_FLOAT = "key_float";
  1050. static const string KEY_BOOL = "key_bool";
  1051. static const string KEY_TYPE = "key_type";
  1052. static const string VALUE_STRING = "string";
  1053. static const int64_t VALUE_INT = 1;
  1054. static const float VALUE_FLOAT = 1.0;
  1055. static const bool VALUE_BOOL = true;
  1056. static const domi::tensorflow::DataType VALUE_TYPE = domi::tensorflow::DataType::DT_FLOAT;
  1057. static const string VALUE_NAME = "test_name";
  1058. ge::OpDescPtr op_desc = std::make_shared<ge::OpDesc>();
  1059. NodeDef node_def;
  1060. domi::tensorflow::AttrValue value;
  1061. ge::Operator op = ge::OpDescUtils::CreateOperatorFromOpDesc(op_desc);
  1062. node_def.set_name(VALUE_NAME);
  1063. value.set_s(VALUE_STRING);
  1064. TensorFlowUtil::AddNodeAttr(KEY_STRING, value, &node_def);
  1065. value.set_i(VALUE_INT);
  1066. TensorFlowUtil::AddNodeAttr(KEY_INT, value, &node_def);
  1067. value.set_f(VALUE_FLOAT);
  1068. TensorFlowUtil::AddNodeAttr(KEY_FLOAT, value, &node_def);
  1069. value.set_b(VALUE_BOOL);
  1070. TensorFlowUtil::AddNodeAttr(KEY_BOOL, value, &node_def);
  1071. value.set_type(VALUE_TYPE);
  1072. TensorFlowUtil::AddNodeAttr(KEY_TYPE, value, &node_def);
  1073. domi::Status status = domi::AutoMappingFn(reinterpret_cast<google::protobuf::Message *>(&node_def), op);
  1074. EXPECT_EQ(domi::SUCCESS, status);
  1075. EXPECT_EQ(VALUE_NAME, op_desc->GetName());
  1076. string value_string = "";
  1077. ge::AttrUtils::GetStr(op_desc, KEY_STRING, value_string);
  1078. EXPECT_EQ(VALUE_STRING, value_string);
  1079. int64_t value_int = 0;
  1080. ge::AttrUtils::GetInt(op_desc, KEY_INT, value_int);
  1081. EXPECT_EQ(VALUE_INT, value_int);
  1082. float value_float = 0.0;
  1083. ge::AttrUtils::GetFloat(op_desc, KEY_FLOAT, value_float);
  1084. EXPECT_EQ(VALUE_FLOAT, value_float);
  1085. bool value_bool = false;
  1086. ge::AttrUtils::GetBool(op_desc, KEY_BOOL, value_bool);
  1087. EXPECT_EQ(VALUE_BOOL, value_bool);
  1088. ge::DataType data_type = ge::DT_UNDEFINED;
  1089. ge::AttrUtils::GetDataType(op_desc, KEY_TYPE, data_type);
  1090. EXPECT_EQ(ge::DT_FLOAT, data_type);
  1091. // test AutoMappingByOpFn
  1092. ge::OpDescPtr op_desc_dest = std::make_shared<ge::OpDesc>();
  1093. ge::Operator op_dest = ge::OpDescUtils::CreateOperatorFromOpDesc(op_desc_dest);
  1094. status = domi::AutoMappingByOpFn(op, op_dest);
  1095. EXPECT_EQ(domi::SUCCESS, status);
  1096. EXPECT_EQ(VALUE_NAME, op_dest.GetName());
  1097. value_string = "";
  1098. ge::AttrUtils::GetStr(op_desc_dest, KEY_STRING, value_string);
  1099. EXPECT_EQ(VALUE_STRING, value_string);
  1100. value_int = 0;
  1101. ge::AttrUtils::GetInt(op_desc_dest, KEY_INT, value_int);
  1102. EXPECT_EQ(VALUE_INT, value_int);
  1103. value_float = 0.0;
  1104. ge::AttrUtils::GetFloat(op_desc_dest, KEY_FLOAT, value_float);
  1105. EXPECT_EQ(VALUE_FLOAT, value_float);
  1106. value_bool = false;
  1107. ge::AttrUtils::GetBool(op_desc_dest, KEY_BOOL, value_bool);
  1108. EXPECT_EQ(VALUE_BOOL, value_bool);
  1109. data_type = ge::DT_UNDEFINED;
  1110. ge::AttrUtils::GetDataType(op_desc_dest, KEY_TYPE, data_type);
  1111. EXPECT_EQ(ge::DT_FLOAT, data_type);
  1112. }
  1113. TEST_F(STestTensorflowParser, parse_ParseNodeDef)
  1114. {
  1115. NodeDef * node_def = new NodeDef();
  1116. node_def->set_name("test_name");
  1117. node_def->set_op("PlaceholderWithDefault");
  1118. bool isDatasetInit = true;
  1119. TensorFlowModelParser model_parser;
  1120. Status ret = model_parser.AdaptOpType(node_def, isDatasetInit);
  1121. EXPECT_EQ(domi::SUCCESS, ret);
  1122. node_def->set_op("Add");
  1123. ret = model_parser.AdaptOpType(node_def, isDatasetInit);
  1124. EXPECT_EQ(domi::SUCCESS, ret);
  1125. delete node_def;
  1126. }
  1127. TEST_F(STestTensorflowParser, parse_AddFmkNode)
  1128. {
  1129. TensorFlowModelParser modelParser;
  1130. std::string caseDir = __FILE__;
  1131. std::size_t idx = caseDir.find_last_of("/");
  1132. caseDir = caseDir.substr(0, idx);
  1133. std::string modelFile = caseDir + "/origin_models/tf_add.pb";
  1134. ge::Graph graph;
  1135. string graph_name;
  1136. AclGrphParseUtil acl_graph_parse_util;
  1137. std::map<ge::AscendString, ge::AscendString> parser_options = {{AscendString(ge::ir_option::OUT_NODES), AscendString("Placeholder:0;Placeholder_1:0")}};
  1138. ParerSTestsUtils::ClearParserInnerCtx();
  1139. Status ret = acl_graph_parse_util.ParseParamsBeforeGraph(parser_options, graph_name);
  1140. ret = aclgrphParseTensorFlow(modelFile.c_str(), parser_options, graph);
  1141. ASSERT_EQ(ret, SUCCESS);
  1142. ge::ComputeGraphPtr compute_graph = std::make_shared<ge::ComputeGraph>(GRAPH_DEFAULT_NAME);
  1143. tensorflow::GraphDef *graphDef = new (std::nothrow) tensorflow::GraphDef();
  1144. ScopePassManager pass_manager;
  1145. std::shared_ptr<ScopeGraph> scope_graph = pass_manager.BuildScopeGraph(graphDef);
  1146. std::string fusion_op_name = "fusion_op_name";
  1147. FusionScopesResult *fusion_rlt = new (std::nothrow) FusionScopesResult();
  1148. EXPECT_NE(fusion_rlt, nullptr);
  1149. fusion_rlt->Init();
  1150. GenFusionScopesResult(scope_graph, fusion_rlt, fusion_op_name);
  1151. GenOriginContext(&modelParser, fusion_op_name);
  1152. // origin inner node def
  1153. NodeDef* node_def = MallocNodeDef("scope_node_1", "Add");
  1154. EXPECT_NE(node_def, nullptr);
  1155. modelParser.fusion_op_nodedef_map_[fusion_op_name].push_back(node_def);
  1156. bool train_flag_backup = ge::GetParserContext().train_flag;
  1157. ge::GetParserContext().train_flag = true;
  1158. REGISTER_CUSTOM_OP("Identity")
  1159. .FrameworkType(domi::TENSORFLOW)
  1160. .OriginOpType("Identity")
  1161. .ParseParamsFn(ParseParams)
  1162. .ImplyType(ImplyType::TVM);
  1163. REGISTER_CUSTOM_OP("Constant")
  1164. .FrameworkType(domi::TENSORFLOW)
  1165. .OriginOpType("Const")
  1166. .ParseParamsFn(ParseParams)
  1167. .ImplyType(ImplyType::TVM);
  1168. register_tbe_op();
  1169. std::vector<std::string> node_name_list;
  1170. GenOriginNodeDef(&modelParser, node_name_list);
  1171. std::set<std::string> malloc_node_name_list(node_name_list.begin(), node_name_list.end());
  1172. node_name_list.push_back(fusion_op_name);
  1173. ret = modelParser.AddFmkNode(compute_graph, scope_graph, node_name_list, false);
  1174. EXPECT_EQ(ret, PARAM_INVALID);
  1175. EXPECT_EQ(modelParser.scope_inner_node_map_.size(), 0);
  1176. EXPECT_EQ(modelParser.nodedef_map_.size(), 5);
  1177. ret = modelParser.AddEdges(compute_graph);
  1178. EXPECT_EQ(ret, SUCCESS);
  1179. // release resource
  1180. delete graphDef;
  1181. delete node_def;
  1182. modelParser.DeleteFuisonNodeDef();
  1183. FreeNodeDefMap(&modelParser, malloc_node_name_list);
  1184. ge::GetParserContext().train_flag = train_flag_backup;
  1185. }
  1186. TEST_F(STestTensorflowParser, parse_AddScopeInnerNode)
  1187. {
  1188. TensorFlowModelParser modelParser;
  1189. std::string caseDir = __FILE__;
  1190. std::size_t idx = caseDir.find_last_of("/");
  1191. caseDir = caseDir.substr(0, idx);
  1192. std::string modelFile = caseDir + "/origin_models/tf_add.pb";
  1193. std::string op_name = "ge_ascend_irgraph";
  1194. ge::Graph graph(op_name);
  1195. ge::ComputeGraphPtr compute_graph = ge::GraphUtils::GetComputeGraph(graph);
  1196. std::map<ge::AscendString, ge::AscendString> parser_params = {
  1197. {AscendString(ge::ir_option::OUT_NODES), AscendString("Placeholder:0;Placeholder_1:0")}};
  1198. Status ret = ge::aclgrphParseTensorFlow(modelFile.c_str(), parser_params, graph);
  1199. EXPECT_EQ(ret, SUCCESS);
  1200. std::mutex graph_mutex;
  1201. tensorflow::NodeDef *node_def = initNodeDef();
  1202. node_def->set_name("FastrcnnPredictions");
  1203. node_def->set_op("FastrcnnPredictions");
  1204. // can't find in scope_inner_node_map
  1205. ret = modelParser.AddScopeInnerNode(&modelParser, compute_graph, &graph_mutex, node_def);
  1206. EXPECT_EQ(ret, PARAM_INVALID);
  1207. delete node_def;
  1208. }
  1209. TEST_F(STestTensorflowParser, dyncmic_rnn_scope_pass_plugin_test) {
  1210. ge::Graph graph;
  1211. std::cout << __FILE__ << std::endl;
  1212. std::string caseDir = __FILE__;
  1213. std::size_t idx = caseDir.find_last_of("/");
  1214. caseDir = caseDir.substr(0, idx);
  1215. std::string modelFile = caseDir + "/origin_models/tensor_array.pb";
  1216. std::map<ge::AscendString, ge::AscendString> params;
  1217. string key ="enable_scope_fusion_passes";
  1218. string value ="ScopeDynamicRNNPass";
  1219. params.insert(std::make_pair(ge::AscendString(key.c_str()), ge::AscendString(value.c_str())));
  1220. auto status = aclgrphParseTensorFlow(modelFile.c_str(), params, graph);
  1221. EXPECT_EQ(status, SUCCESS);
  1222. }
  1223. TEST_F(STestTensorflowParser, avgpool3dgrad_plugin_test_format_NDHWC) {
  1224. ge::Graph graph;
  1225. std::cout << __FILE__ << std::endl;
  1226. std::string caseDir = __FILE__;
  1227. std::size_t idx = caseDir.find_last_of("/");
  1228. caseDir = caseDir.substr(0, idx);
  1229. std::string modelFile = caseDir + "/origin_models/avgpool3dgrad_case_1.pb";
  1230. auto status = aclgrphParseTensorFlow(modelFile.c_str(), graph);
  1231. EXPECT_EQ(status, SUCCESS);
  1232. }
  1233. TEST_F(STestTensorflowParser, tensorflow_merge_test) {
  1234. ge::Graph graph;
  1235. std::cout << __FILE__ << std::endl;
  1236. std::string caseDir = __FILE__;
  1237. std::size_t idx = caseDir.find_last_of("/");
  1238. caseDir = caseDir.substr(0, idx);
  1239. std::string modelFile = caseDir + "/origin_models/merge.pb";
  1240. auto status = aclgrphParseTensorFlow(modelFile.c_str(), graph);
  1241. EXPECT_EQ(status, FAILED);
  1242. }
  1243. TEST_F(STestTensorflowParser, tensorflow_no_op_test) {
  1244. ge::Graph graph;
  1245. std::cout << __FILE__ << std::endl;
  1246. std::string caseDir = __FILE__;
  1247. std::size_t idx = caseDir.find_last_of("/");
  1248. caseDir = caseDir.substr(0, idx);
  1249. std::string modelFile = caseDir + "/origin_models/test_no_op.pb";
  1250. auto status = aclgrphParseTensorFlow(modelFile.c_str(), graph);
  1251. EXPECT_EQ(status, SUCCESS);
  1252. }
  1253. TEST_F(STestTensorflowParser, tensorflow_identity_test) {
  1254. ge::Graph graph;
  1255. std::cout << __FILE__ << std::endl;
  1256. std::string caseDir = __FILE__;
  1257. std::size_t idx = caseDir.find_last_of("/");
  1258. caseDir = caseDir.substr(0, idx);
  1259. std::string modelFile = caseDir + "/origin_models/test_identity.pb";
  1260. auto status = aclgrphParseTensorFlow(modelFile.c_str(), graph);
  1261. EXPECT_EQ(status, SUCCESS);
  1262. }
  1263. TEST_F(STestTensorflowParser, tensorflow_constant_test) {
  1264. ge::Graph graph;
  1265. std::cout << __FILE__ << std::endl;
  1266. std::string caseDir = __FILE__;
  1267. std::size_t idx = caseDir.find_last_of("/");
  1268. caseDir = caseDir.substr(0, idx);
  1269. std::string modelFile = caseDir + "/origin_models/test_constant.pb";
  1270. auto status = aclgrphParseTensorFlow(modelFile.c_str(), graph);
  1271. EXPECT_EQ(status, SUCCESS);
  1272. TensorFlowConstantParser constantParser;
  1273. ge::OpDescPtr op_dest = make_shared<ge::OpDesc>("constant", ge::parser::CONSTANT);
  1274. NodeDef* node_def = initNodeDef();
  1275. node_def->set_name("Constant");
  1276. auto params = constantParser.ParseParams(node_def, op_dest);
  1277. EXPECT_EQ(params, SUCCESS);
  1278. auto value = constantParser.ParseValue(node_def, op_dest);
  1279. EXPECT_EQ(value, SUCCESS);
  1280. ConstantOperator op;
  1281. auto type = constantParser.ParseDType(node_def, &op);
  1282. EXPECT_EQ(type, SUCCESS);
  1283. }
  1284. TEST_F(STestTensorflowParser, tensorflow_reshpae_test) {
  1285. ge::Graph graph;
  1286. std::cout << __FILE__ << std::endl;
  1287. std::string caseDir = __FILE__;
  1288. std::size_t idx = caseDir.find_last_of("/");
  1289. caseDir = caseDir.substr(0, idx);
  1290. std::string modelFile = caseDir + "/origin_models/test_reshape.pb";
  1291. auto status = aclgrphParseTensorFlow(modelFile.c_str(), graph);
  1292. EXPECT_EQ(status, SUCCESS);
  1293. TensorFlowReshapeParser parser;
  1294. NodeDef * nodeDef = new NodeDef();
  1295. ge::OpDescPtr opdef_ = make_shared<::ge::OpDesc>("","");
  1296. google::protobuf::Map<std::string, tensorflow::AttrValue > *attr_map = nodeDef->mutable_attr();
  1297. domi::tensorflow::AttrValue tshape_attr_value;
  1298. tshape_attr_value.set_type(domi::tensorflow::DT_INT32);
  1299. (*attr_map)[TENSORFLOW_ATTR_TSHAPE] = tshape_attr_value;
  1300. domi::tensorflow::AttrValue t_attr_value;
  1301. t_attr_value.set_type(domi::tensorflow::DT_FLOAT);
  1302. (*attr_map)[TENSORFLOW_ATTR_T] = t_attr_value;
  1303. Status ret = parser.ParseParams(nodeDef, opdef_);
  1304. EXPECT_EQ(domi::SUCCESS, ret);
  1305. delete nodeDef;
  1306. }
  1307. TEST_F(STestTensorflowParser, tensorflow_squeeze_test) {
  1308. ge::Graph graph;
  1309. std::cout << __FILE__ << std::endl;
  1310. std::string caseDir = __FILE__;
  1311. std::size_t idx = caseDir.find_last_of("/");
  1312. caseDir = caseDir.substr(0, idx);
  1313. std::string modelFile = caseDir + "/origin_models/test_sequeeze.pb";
  1314. auto status = aclgrphParseTensorFlow(modelFile.c_str(), graph);
  1315. EXPECT_EQ(status, SUCCESS);
  1316. TensorFlowSqueezeParser parser;
  1317. NodeDef *nodeDef = initNodeDef();
  1318. ge::OpDescPtr opDef = make_shared<::ge::OpDesc>("Squeeze","Squeeze");
  1319. Status ret = parser.ParseParams(nodeDef, opDef);
  1320. EXPECT_EQ(ret, SUCCESS);
  1321. NodeDef *nodeDef_dim = initNodeDef_dims();
  1322. ret = parser.ParseParams(nodeDef_dim, opDef);
  1323. EXPECT_EQ(SUCCESS, ret);
  1324. NodeDef *nodeDef_axis_dims = initNodeDef_axis_dims();
  1325. ret = parser.ParseParams(nodeDef_axis_dims, opDef);
  1326. EXPECT_EQ(GRAPH_PARAM_INVALID, ret);
  1327. static const string KEY_SHAPE_LIST = "key_shape_list";
  1328. static const string KEY_TENSOR_LIST = "key_tensor_list";
  1329. static const string KEY_DEFAULT = "key_default";
  1330. NodeDef *nodeDef2 = new NodeDef();
  1331. google::protobuf::Map<std::string, tensorflow::AttrValue> *node_attr_map = nodeDef2->mutable_attr();
  1332. domi::tensorflow::AttrValue dtype_attr_value ;
  1333. dtype_attr_value.set_type(domi::tensorflow::DT_FLOAT);
  1334. (*node_attr_map)[TENSORFLOW_ATTR_T] = dtype_attr_value;
  1335. //设置strides属性
  1336. tensorflow::AttrValue axis_attr_value;
  1337. tensorflow::AttrValue_ListValue *list = axis_attr_value.mutable_list();
  1338. list->add_i(1);
  1339. list->add_i(2);
  1340. (*node_attr_map)[ge::SQUEEZE_ATTR_AXIS] = axis_attr_value;
  1341. domi::tensorflow::AttrValue value;
  1342. domi::tensorflow::AttrValue df_attr_value;
  1343. df_attr_value.set_i((int64_t)ccTensorFormat_t::CC_TENSOR_NHWC);
  1344. domi::tensorflow::AttrValue pad_attr_value;
  1345. pad_attr_value.set_i((int64_t)tensorflow::DT_FLOAT);
  1346. domi::tensorflow::AttrValue shape;
  1347. shape.mutable_list()->add_i((int64)32);
  1348. shape.mutable_list()->add_i((int64)32);
  1349. shape.mutable_list()->add_i((int64)14);
  1350. static const string KEY_TYPE_LIST = "key_type_list";
  1351. const std::string ATTR_NAME_INPUT_TENSOR_DESC = "input_tensor_desc";
  1352. const std::string ATTR_NAME_OUTPUT_TENSOR_DESC = "output_tensor_desc";
  1353. static const domi::tensorflow::DataType VALUE_TYPE = domi::tensorflow::DataType::DT_FLOAT;
  1354. value.clear_value();
  1355. value.mutable_list()->add_type(VALUE_TYPE);
  1356. TensorFlowUtil::AddNodeAttr(KEY_TYPE_LIST, value, nodeDef2);
  1357. value.clear_value();
  1358. domi::tensorflow::NameAttrList name_attr_list;
  1359. name_attr_list.mutable_attr()->insert({"serialize_datatype", pad_attr_value});
  1360. name_attr_list.mutable_attr()->insert({"serialize_format", df_attr_value});
  1361. name_attr_list.mutable_attr()->insert({"serialize_shape", shape});
  1362. *(value.mutable_list()->add_func()) = name_attr_list;
  1363. nodeDef2->mutable_attr()->insert({ge::ATTR_NAME_INPUT_TENSOR_DESC, value});
  1364. nodeDef2->mutable_attr()->insert({ge::ATTR_NAME_OUTPUT_TENSOR_DESC, value});
  1365. ret = parser.ParseParams(nodeDef2, opDef);
  1366. EXPECT_EQ(domi::SUCCESS, ret);
  1367. GeTensorDesc ge_desc;
  1368. ge_desc.SetFormat(ge::FORMAT_C1HWNCoC0);
  1369. ge_desc.SetDataType(ge::DT_FLOAT);
  1370. ge_desc.SetShape(GeShape({1,1,1,1,1,1}));
  1371. ret = parser.ParseDesc(value, ge_desc);
  1372. EXPECT_EQ(ret, SUCCESS);
  1373. delete nodeDef2;
  1374. delete nodeDef_axis_dims;
  1375. delete nodeDef_dim;
  1376. delete nodeDef;
  1377. }
  1378. TEST_F(STestTensorflowParser, tensorflow_fill_test) {
  1379. ge::Graph graph;
  1380. std::cout << __FILE__ << std::endl;
  1381. std::string caseDir = __FILE__;
  1382. std::size_t idx = caseDir.find_last_of("/");
  1383. caseDir = caseDir.substr(0, idx);
  1384. std::string modelFile = caseDir + "/origin_models/test_fill.pb";
  1385. auto status = aclgrphParseTensorFlow(modelFile.c_str(), graph);
  1386. EXPECT_EQ(status, SUCCESS);
  1387. }
  1388. TEST_F(STestTensorflowParser, tensorflow_shape_n_test) {
  1389. ge::Graph graph;
  1390. std::cout << __FILE__ << std::endl;
  1391. std::string caseDir = __FILE__;
  1392. std::size_t idx = caseDir.find_last_of("/");
  1393. caseDir = caseDir.substr(0, idx);
  1394. std::string modelFile = caseDir + "/origin_models/test_shape_n.pb";
  1395. auto status = aclgrphParseTensorFlow(modelFile.c_str(), graph);
  1396. EXPECT_EQ(status, SUCCESS);
  1397. }
  1398. TEST_F(STestTensorflowParser, tensorflow_switch_test) {
  1399. ge::Graph graph;
  1400. std::cout << __FILE__ << std::endl;
  1401. std::string caseDir = __FILE__;
  1402. std::size_t idx = caseDir.find_last_of("/");
  1403. caseDir = caseDir.substr(0, idx);
  1404. std::string modelFile = caseDir + "/origin_models/test_switch.pb";
  1405. auto status = aclgrphParseTensorFlow(modelFile.c_str(), graph);
  1406. EXPECT_EQ(status, SUCCESS);
  1407. TensorFlowRefSwitchParser refSwitchParser;
  1408. ge::OpDescPtr op_dest = make_shared<ge::OpDesc>("constant", ge::parser::CONSTANT);
  1409. NodeDef* node_def = initNodeDef();
  1410. node_def->set_name("RefSwitch");
  1411. auto params = refSwitchParser.ParseParams(node_def, op_dest);
  1412. EXPECT_EQ(params, SUCCESS);
  1413. RefSwitchOperator op;
  1414. auto parseRet = refSwitchParser.ParseT(node_def, &op);
  1415. EXPECT_EQ(parseRet, SUCCESS);
  1416. }
  1417. TEST_F(STestTensorflowParser, tensorflow_enter_test) {
  1418. ge::Graph graph;
  1419. std::cout << __FILE__ << std::endl;
  1420. std::string caseDir = __FILE__;
  1421. std::size_t idx = caseDir.find_last_of("/");
  1422. caseDir = caseDir.substr(0, idx);
  1423. std::string modelFile = caseDir + "/origin_models/test_enter.pb";
  1424. auto status = aclgrphParseTensorFlow(modelFile.c_str(), graph);
  1425. EXPECT_EQ(status, SUCCESS);
  1426. TensorFlowEnterParser enterParser;
  1427. ge::OpDescPtr op_dest = make_shared<ge::OpDesc>("Enter", ge::parser::ENTER);
  1428. NodeDef* node_def = initNodeDef();
  1429. node_def->set_name("Enter");
  1430. Status ret = enterParser.ParseParams(node_def, op_dest);
  1431. EXPECT_EQ(ret, FAILED);
  1432. static const string KEY_SHAPE_LIST = "key_shape_list";
  1433. static const string KEY_TENSOR_LIST = "key_tensor_list";
  1434. static const string KEY_DEFAULT = "key_default";
  1435. google::protobuf::Map<std::string, tensorflow::AttrValue> *node_attr_map = node_def->mutable_attr();
  1436. domi::tensorflow::AttrValue dtype_attr_value;
  1437. dtype_attr_value.set_type(domi::tensorflow::DT_FLOAT);
  1438. (*node_attr_map)[TENSORFLOW_ATTR_T] = dtype_attr_value;
  1439. //设置strides属性
  1440. domi::tensorflow::AttrValue axis_attr_value;
  1441. ::tensorflow::AttrValue_ListValue* list = axis_attr_value.mutable_list();
  1442. list->add_i(1);
  1443. list->add_i(2);
  1444. (*node_attr_map)[ge::SQUEEZE_ATTR_AXIS] = axis_attr_value;
  1445. domi::tensorflow::AttrValue value;
  1446. domi::tensorflow::AttrValue df_attr_value;
  1447. df_attr_value.set_i((int64_t)ccTensorFormat_t::CC_TENSOR_NHWC);
  1448. domi::tensorflow::AttrValue pad_attr_value;
  1449. pad_attr_value.set_i((int64_t)tensorflow::DT_FLOAT);
  1450. domi::tensorflow::AttrValue shape;
  1451. shape.mutable_list()->add_i((int64)32);
  1452. shape.mutable_list()->add_i((int64)32);
  1453. shape.mutable_list()->add_i((int64)14);
  1454. static const string KEY_TYPE_LIST = "key_type_list";
  1455. const std::string ENTER_ATTR_FRAME_NAME = "frame_name";
  1456. const std::string ATTR_NAME_OUTPUT_TENSOR_DESC = "output_tensor_desc";
  1457. static const domi::tensorflow::DataType VALUE_TYPE = domi::tensorflow::DataType::DT_FLOAT;
  1458. value.clear_value();
  1459. value.mutable_list()->add_type(VALUE_TYPE);
  1460. TensorFlowUtil::AddNodeAttr(KEY_TYPE_LIST, value, node_def);
  1461. value.clear_value();
  1462. domi::tensorflow::NameAttrList name_attr_list;
  1463. name_attr_list.mutable_attr()->insert({"serialize_datatype", pad_attr_value});
  1464. name_attr_list.mutable_attr()->insert({"serialize_format", df_attr_value});
  1465. name_attr_list.mutable_attr()->insert({"serialize_shape", shape});
  1466. *(value.mutable_list()->add_func()) = name_attr_list;
  1467. node_def->mutable_attr()->insert({ge::ENTER_ATTR_FRAME_NAME, value});
  1468. node_def->mutable_attr()->insert({ge::ATTR_NAME_OUTPUT_TENSOR_DESC, value});
  1469. ret = enterParser.ParseParams(node_def, op_dest);
  1470. EXPECT_EQ(ret, FAILED);
  1471. }
  1472. TEST_F(STestTensorflowParser, tensorflow_VariableV2_test) {
  1473. ge::Graph graph;
  1474. std::string caseDir = __FILE__;
  1475. std::size_t idx = caseDir.find_last_of("/");
  1476. caseDir = caseDir.substr(0, idx);
  1477. std::string modelFile = caseDir + "/origin_models/test_VariableV2.pb";
  1478. auto status = aclgrphParseTensorFlow(modelFile.c_str(), graph);
  1479. EXPECT_EQ(status, SUCCESS);
  1480. }
  1481. TEST_F(STestTensorflowParser, tensorflow_fusion_op_parser_test)
  1482. {
  1483. TensorFlowFusionOpParser fusionOpParser;
  1484. ge::OpDescPtr op_dest = make_shared<ge::OpDesc>("FusionOp", ge::parser::CONSTANT);
  1485. int index = 0;
  1486. NodeDef* node_def = fusioninitNodeDef(index);
  1487. node_def->set_name("FusionOp");
  1488. auto ret = fusionOpParser.ParseParams(node_def, op_dest);
  1489. EXPECT_EQ(ret, SUCCESS);
  1490. int32_t param = 1;
  1491. ret = fusionOpParser.ParseParamFromConst(node_def, param);
  1492. EXPECT_EQ(ret, SUCCESS);
  1493. ret = fusionOpParser.ParseParamFromConst(node_def, param, index);
  1494. EXPECT_EQ(ret, SUCCESS);
  1495. float params = 0.0;
  1496. ret = fusionOpParser.ParseParamFromConst(node_def, params);
  1497. EXPECT_EQ(ret, SUCCESS);
  1498. index = 2;
  1499. node_def = fusioninitNodeDef(index);
  1500. ret = fusionOpParser.ParseParamFromConst(node_def, params, index);
  1501. EXPECT_EQ(ret, domi::PARAM_INVALID);
  1502. ret = fusionOpParser.ParseHalfFromConst(node_def, params, 0);
  1503. EXPECT_EQ(ret, SUCCESS);
  1504. ret = fusionOpParser.ParseHalfFromConst(node_def, params, 3);
  1505. EXPECT_EQ(ret, domi::PARAM_INVALID);
  1506. node_def = fusioninitNodeDef(0);
  1507. ret = fusionOpParser.ParseHalfFromConst(node_def, params, 3);
  1508. EXPECT_EQ(ret, domi::PARAM_INVALID);
  1509. static const float VALUE_FLOAT = 1.0;
  1510. ge::GeTensorPtr weight = nullptr;
  1511. ret = fusionOpParser.ParseWeightFromConst(node_def, weight);
  1512. EXPECT_EQ(ret, domi::SUCCESS);
  1513. EXPECT_NE(weight, nullptr);
  1514. ge::DataType ge_data_type = weight->GetTensorDesc().GetDataType();
  1515. EXPECT_EQ(ge_data_type, ge::DataType::DT_FLOAT);
  1516. const uint8_t* data_buff = weight->GetData().GetData();
  1517. size_t data_size = weight->GetData().size();
  1518. EXPECT_NE(data_buff, nullptr);
  1519. EXPECT_EQ(data_size, sizeof(float));
  1520. float value_float = *((float*)data_buff);
  1521. EXPECT_EQ(value_float, VALUE_FLOAT);
  1522. delete node_def;
  1523. }
  1524. TEST_F(STestTensorflowParser, tensorflow_auto_mapping_parser_adapter_test)
  1525. {
  1526. ge::OpDescPtr op_dest = nullptr;
  1527. Message *op_src = nullptr;
  1528. TensorFlowAutoMappingParserAdapter autoMappingParser;
  1529. NodeDef* node_def = initNodeDef();
  1530. Status ret = autoMappingParser.ParseParams(op_src, op_dest);
  1531. EXPECT_EQ(ret, PARAM_INVALID);
  1532. ret = autoMappingParser.ParseParams(node_def, op_dest);
  1533. EXPECT_EQ(ret, PARAM_INVALID);
  1534. op_dest = make_shared<ge::OpDesc>("AutoMapping", ge::parser::CONSTANT);
  1535. op_dest->SetType(ge::parser::EMPTY);
  1536. ret = autoMappingParser.ParseParams(node_def, op_dest);
  1537. EXPECT_EQ(ret, SUCCESS);
  1538. op_dest->SetType(ge::parser::IDENTITYN);
  1539. ret = autoMappingParser.ParseParams(node_def, op_dest);
  1540. EXPECT_EQ(ret, SUCCESS);
  1541. op_dest->SetType(ge::parser::SIZE);
  1542. ret = autoMappingParser.ParseParams(node_def, op_dest);
  1543. EXPECT_EQ(ret, SUCCESS);
  1544. op_dest->SetType(ge::parser::SHAPE);
  1545. ret = autoMappingParser.ParseParams(node_def, op_dest);
  1546. EXPECT_EQ(ret, SUCCESS);
  1547. }
  1548. TEST_F(STestTensorflowParser, tensorflow_fusion_custom_parser_adapter_test)
  1549. {
  1550. REGISTER_CUSTOM_OP("FusionCustom")
  1551. .FrameworkType(domi::TENSORFLOW)
  1552. .OriginOpType("FusionCustom")
  1553. .FusionParseParamsFn(FusionParserParams)
  1554. .ImplyType(ImplyType::TVM);
  1555. register_tbe_op();
  1556. auto graph = std::make_shared<ge::ComputeGraph>("FusionCustom");
  1557. auto op_desc = std::make_shared<ge::OpDesc>("FusionCustom", "FusionCustom");
  1558. auto node = graph->AddNode(op_desc);
  1559. NodeDef *node_def = new NodeDef();
  1560. std::vector<const NodeDef *> v_input_const1;
  1561. v_input_const1.push_back(node_def);
  1562. TensorFlowFusionCustomParserAdapter parser;
  1563. domi::Status status = parser.ParseParams(v_input_const1, node);
  1564. EXPECT_EQ(SUCCESS, status);
  1565. ge::Operator op_src("pool", "pooling");
  1566. std::vector<ge::Operator> v_input_const2;
  1567. v_input_const2.push_back(op_src);
  1568. Status ret = parser.ParseParams(v_input_const2, node);
  1569. EXPECT_EQ(FAILED, ret);
  1570. delete node_def;
  1571. }
  1572. TEST_F(STestTensorflowParser, tensorflow_custom_parser_adapter_test)
  1573. {
  1574. ge::Operator op_src("pool", "pooling");
  1575. ge::OpDescPtr op_dest = std::make_shared<ge::OpDesc>();
  1576. TensorFlowCustomParserAdapter parser;
  1577. Status ret = parser.ParseParams(op_src, op_dest);
  1578. EXPECT_EQ(ret, FAILED);
  1579. REGISTER_CUSTOM_OP("Variable")
  1580. .FrameworkType(domi::TENSORFLOW)
  1581. .OriginOpType("VariableV2")
  1582. .ParseParamsFn(ParseParams)
  1583. .ParseParamsByOperatorFn(ParseParamByOpFunc)
  1584. .ImplyType(ImplyType::CUSTOM);
  1585. register_tbe_op();
  1586. Operator opSrc(ge::parser::VARIABLE, "VariableV2");
  1587. ret = parser.ParseParams(opSrc, op_dest);
  1588. EXPECT_EQ(ret, SUCCESS);
  1589. }
  1590. TEST_F(STestTensorflowParser, tensorflow_graph_functiondef_FindAttrValue_test)
  1591. {
  1592. GraphToFunctionDef functionDef;
  1593. NodeDef *node_def = nullptr;
  1594. std::string attr_name = "Const";
  1595. tensorflow::AttrValue attr_value;
  1596. bool ret = functionDef.FindAttrValue(node_def, attr_name, attr_value);
  1597. EXPECT_EQ(ret, false);
  1598. node_def = initNodeDef();
  1599. attr_name = ge::ATTR_NAME_INPUT_TENSOR_DESC;
  1600. node_def->set_name("Const");
  1601. ret = functionDef.FindAttrValue(node_def, attr_name, attr_value);
  1602. EXPECT_EQ(ret, false);
  1603. }
  1604. TEST_F(STestTensorflowParser, tensorflow_graph_functiondef_BuildFunctionDef_test)
  1605. {
  1606. ge::ComputeGraphPtr subGraph = std::make_shared<ge::ComputeGraph>("default");
  1607. string inputNodeType = "DATA";
  1608. MakeDagGraph(subGraph, inputNodeType);
  1609. FunctionDefLibrary library;
  1610. tensorflow::NodeDef call_node_def;
  1611. call_node_def.set_op("fusionop");
  1612. call_node_def.set_name("fusionop");
  1613. vector<ge::InDataAnchorPtr> in_anchor;
  1614. vector<ge::OutDataAnchorPtr> out_anchor;
  1615. for (ge::NodePtr node : subGraph->GetAllNodes()) {
  1616. for (auto in : node->GetAllInDataAnchors()) {
  1617. if (in->GetPeerOutAnchor() != nullptr && in->GetPeerOutAnchor()->GetOwnerNode()->GetOpDesc()->GetType() == parser::DATA) {
  1618. in_anchor.push_back(in);
  1619. }
  1620. }
  1621. for (auto out : node->GetAllOutDataAnchors()) {
  1622. for (auto i : out->GetPeerInDataAnchors()) {
  1623. if (i->GetOwnerNode()->GetOpDesc()->GetType() == parser::NETOUTPUT) {
  1624. out_anchor.push_back(out);
  1625. }
  1626. }
  1627. }
  1628. }
  1629. Status ret = GraphToFunctionDef::BuildFunctionDef(subGraph,
  1630. "fusionop",
  1631. &library,
  1632. &call_node_def,
  1633. in_anchor,
  1634. out_anchor);
  1635. EXPECT_EQ(domi::INTERNAL_ERROR, ret);
  1636. }
  1637. TEST_F(STestTensorflowParser, tensorflow_CheckOpShapeDim_test)
  1638. {
  1639. NodeDef *node_def = initNodeDef();
  1640. std::set<int> dims;
  1641. dims.insert(1);
  1642. dims.insert(2);
  1643. bool valid = true;
  1644. TensorFlowModelParser parser;
  1645. Status ret = parser.CheckOpShapeDim(node_def, dims, valid);
  1646. EXPECT_EQ(ret, SUCCESS);
  1647. static const string KEY_SHAPE_LIST = "key_shape_list";
  1648. static const string KEY_TENSOR_LIST = "key_tensor_list";
  1649. static const string KEY_DEFAULT = "key_default";
  1650. google::protobuf::Map<std::string, tensorflow::AttrValue> *node_attr_map = node_def->mutable_attr();
  1651. domi::tensorflow::AttrValue dtype_attr_value;
  1652. dtype_attr_value.set_type(domi::tensorflow::DT_FLOAT);
  1653. (*node_attr_map)[TENSORFLOW_ATTR_T] = dtype_attr_value;
  1654. //设置strides属性
  1655. domi::tensorflow::AttrValue axis_attr_value;
  1656. ::tensorflow::AttrValue_ListValue* list = axis_attr_value.mutable_list();
  1657. list->add_i(1);
  1658. list->add_i(2);
  1659. (*node_attr_map)[ge::SQUEEZE_ATTR_AXIS] = axis_attr_value;
  1660. domi::tensorflow::AttrValue value;
  1661. domi::tensorflow::AttrValue df_attr_value;
  1662. df_attr_value.set_i((int64_t)ccTensorFormat_t::CC_TENSOR_NHWC);
  1663. domi::tensorflow::AttrValue pad_attr_value;
  1664. pad_attr_value.set_i((int64_t)tensorflow::DT_FLOAT);
  1665. domi::tensorflow::AttrValue shape;
  1666. shape.mutable_list()->add_i((int64)32);
  1667. shape.mutable_list()->add_i((int64)32);
  1668. shape.mutable_list()->add_i((int64)14);
  1669. static const string KEY_TYPE_LIST = "key_type_list";
  1670. const std::string ATTR_NAME_INPUT_TENSOR_DESC = "input_tensor_desc";
  1671. const std::string ATTR_NAME_OUTPUT_TENSOR_DESC = "output_tensor_desc";
  1672. static const domi::tensorflow::DataType VALUE_TYPE = domi::tensorflow::DataType::DT_FLOAT;
  1673. value.clear_value();
  1674. value.mutable_list()->add_type(VALUE_TYPE);
  1675. TensorFlowUtil::AddNodeAttr(KEY_TYPE_LIST, value, node_def);
  1676. value.clear_value();
  1677. domi::tensorflow::NameAttrList name_attr_list;
  1678. name_attr_list.mutable_attr()->insert({"serialize_datatype", pad_attr_value});
  1679. name_attr_list.mutable_attr()->insert({"serialize_format", df_attr_value});
  1680. name_attr_list.mutable_attr()->insert({"serialize_shape", shape});
  1681. *(value.mutable_list()->add_func()) = name_attr_list;
  1682. node_def->mutable_attr()->insert({ge::ATTR_NAME_INPUT_TENSOR_DESC, value});
  1683. node_def->mutable_attr()->insert({ge::ATTR_NAME_OUTPUT_TENSOR_DESC, value});
  1684. ret = parser.CheckOpShapeDim(node_def, dims, valid);
  1685. EXPECT_EQ(ret, SUCCESS);
  1686. }
  1687. TEST_F(STestTensorflowParser, tensorflow_Scope_pass_test)
  1688. {
  1689. ScopePassManager passmanager;
  1690. auto scope_graph = ge::parser::MakeShared<ge::ScopeGraph>();
  1691. if (scope_graph == nullptr) {
  1692. GELOGE(FAILED, "Scope graph make shared failed.");
  1693. return;
  1694. }
  1695. if (scope_graph->Init() != SUCCESS) {
  1696. GELOGE(FAILED, "Scope graph init failed.");
  1697. return;
  1698. }
  1699. ge::TensorFlowModelParser tf_model_parser;
  1700. std::vector<string> scope_passes_list = {"pass_1", "pass_2"};
  1701. tf_model_parser.RunScopeFusionPass(scope_passes_list, passmanager, scope_graph);
  1702. Status ret = tf_model_parser.RunScopeFusionPass(scope_passes_list, passmanager, scope_graph);
  1703. EXPECT_NE(ge::SUCCESS, ret);
  1704. }
  1705. TEST_F(STestTensorflowParser, tensorflow_variable_v2_parser_test)
  1706. {
  1707. TensorFlowCustomParserAdapter parser;
  1708. ge::OpDescPtr op_dest = std::make_shared<ge::OpDesc>();
  1709. NodeDef *node_def = initNodeDef();
  1710. TensorFlowModelParser modelParser;
  1711. std::shared_ptr<OpParserFactory> factory = OpParserFactory::Instance(domi::TENSORFLOW);
  1712. std::shared_ptr<OpParser> op_parser = factory->CreateOpParser("Variable");
  1713. shared_ptr<TensorFlowOpParser> tensorflow_op_parser = std::dynamic_pointer_cast<TensorFlowOpParser>(op_parser);
  1714. Status ret = tensorflow_op_parser->ParseParams(node_def, op_dest);
  1715. EXPECT_EQ(ret, PARAM_INVALID);
  1716. node_def->set_name("TemporaryVariable");
  1717. node_def->set_op("TemporaryVariable");
  1718. op_parser = factory->CreateOpParser("TemporaryVariable");
  1719. tensorflow_op_parser = std::dynamic_pointer_cast<TensorFlowOpParser>(op_parser);
  1720. ret = tensorflow_op_parser->ParseParams(node_def, op_dest);
  1721. EXPECT_EQ(ret, PARAM_INVALID);
  1722. NodeDef *nodeDef_temporaryVariable = initOpNodeDef_TemporaryVariable();
  1723. op_parser = factory->CreateOpParser("TemporaryVariable");
  1724. tensorflow_op_parser = std::dynamic_pointer_cast<TensorFlowOpParser>(op_parser);
  1725. ret = tensorflow_op_parser->ParseParams(nodeDef_temporaryVariable, op_dest);
  1726. EXPECT_EQ(ret, SUCCESS);
  1727. NodeDef *nodeDef_VariableV2 = initOpNodeDef_VariableV2();
  1728. op_parser = factory->CreateOpParser("Variable");
  1729. tensorflow_op_parser = std::dynamic_pointer_cast<TensorFlowOpParser>(op_parser);
  1730. ret = tensorflow_op_parser->ParseParams(nodeDef_VariableV2, op_dest);
  1731. EXPECT_EQ(ret, SUCCESS);
  1732. }
  1733. TEST_F(STestTensorflowParser, tensorflow_var_is_initialized_op_test)
  1734. {
  1735. TensorFlowCustomParserAdapter parser;
  1736. ge::OpDescPtr op_dest = std::make_shared<ge::OpDesc>();
  1737. NodeDef *node_def = initNodeDef();
  1738. TensorFlowModelParser modelParser;
  1739. std::shared_ptr<OpParserFactory> factory = OpParserFactory::Instance(domi::TENSORFLOW);
  1740. std::shared_ptr<OpParser> op_parser = factory->CreateOpParser("VarIsInitializedOp");
  1741. shared_ptr<TensorFlowOpParser> tensorflow_op_parser = std::dynamic_pointer_cast<TensorFlowOpParser>(op_parser);
  1742. Status ret = tensorflow_op_parser->ParseParams(node_def, op_dest);
  1743. EXPECT_EQ(ret, SUCCESS);
  1744. }
  1745. TEST_F(STestTensorflowParser, tensorflow_arg_parser_test)
  1746. {
  1747. TensorFlowCustomParserAdapter parser;
  1748. ge::OpDescPtr op_dest = std::make_shared<ge::OpDesc>();
  1749. NodeDef *node_def = initNodeDef();
  1750. TensorFlowModelParser modelParser;
  1751. std::shared_ptr<OpParserFactory> factory = OpParserFactory::Instance(domi::TENSORFLOW);
  1752. std::shared_ptr<OpParser> op_parser = factory->CreateOpParser("_Arg");
  1753. shared_ptr<TensorFlowOpParser> tensorflow_op_parser = std::dynamic_pointer_cast<TensorFlowOpParser>(op_parser);
  1754. Status ret = tensorflow_op_parser->ParseParams(node_def, op_dest);
  1755. EXPECT_EQ(ret, SUCCESS);
  1756. static const string KEY_SHAPE_LIST = "key_shape_list";
  1757. static const string KEY_TENSOR_LIST = "key_tensor_list";
  1758. static const string KEY_DEFAULT = "key_default";
  1759. google::protobuf::Map<std::string, tensorflow::AttrValue> *node_attr_map = node_def->mutable_attr();
  1760. domi::tensorflow::AttrValue dtype_attr_value;
  1761. dtype_attr_value.set_type(domi::tensorflow::DT_FLOAT);
  1762. (*node_attr_map)[TENSORFLOW_ATTR_T] = dtype_attr_value;
  1763. //设置strides属性
  1764. domi::tensorflow::AttrValue axis_attr_value;
  1765. ::tensorflow::AttrValue_ListValue* list = axis_attr_value.mutable_list();
  1766. list->add_i(1);
  1767. list->add_i(2);
  1768. (*node_attr_map)[ge::SQUEEZE_ATTR_AXIS] = axis_attr_value;
  1769. domi::tensorflow::AttrValue value;
  1770. domi::tensorflow::AttrValue df_attr_value;
  1771. df_attr_value.set_i((int64_t)ccTensorFormat_t::CC_TENSOR_NHWC);
  1772. domi::tensorflow::AttrValue pad_attr_value;
  1773. pad_attr_value.set_i((int64_t)tensorflow::DT_FLOAT);
  1774. domi::tensorflow::AttrValue shape;
  1775. shape.mutable_list()->add_i((int64)32);
  1776. shape.mutable_list()->add_i((int64)32);
  1777. shape.mutable_list()->add_i((int64)14);
  1778. static const string KEY_TYPE_LIST = "key_type_list";
  1779. const std::string ATTR_NAME_INPUT_TENSOR_DESC = "input_tensor_desc";
  1780. const std::string ATTR_NAME_OUTPUT_TENSOR_DESC = "output_tensor_desc";
  1781. static const domi::tensorflow::DataType VALUE_TYPE = domi::tensorflow::DataType::DT_FLOAT;
  1782. value.clear_value();
  1783. value.mutable_list()->add_type(VALUE_TYPE);
  1784. TensorFlowUtil::AddNodeAttr(KEY_TYPE_LIST, value, node_def);
  1785. value.clear_value();
  1786. domi::tensorflow::NameAttrList name_attr_list;
  1787. name_attr_list.mutable_attr()->insert({"serialize_datatype", pad_attr_value});
  1788. name_attr_list.mutable_attr()->insert({"serialize_format", df_attr_value});
  1789. name_attr_list.mutable_attr()->insert({"serialize_shape", shape});
  1790. *(value.mutable_list()->add_func()) = name_attr_list;
  1791. node_def->mutable_attr()->insert({ge::ATTR_NAME_INPUT_TENSOR_DESC, value});
  1792. node_def->mutable_attr()->insert({ge::ATTR_NAME_OUTPUT_TENSOR_DESC, value});
  1793. ret = tensorflow_op_parser->ParseParams(node_def, op_dest);
  1794. EXPECT_EQ(ret, SUCCESS);
  1795. }
  1796. TEST_F(STestTensorflowParser, tensorflow_frameworkop_parser_test1)
  1797. {
  1798. TensorFlowCustomParserAdapter parser;
  1799. ge::OpDescPtr op_dest = std::make_shared<ge::OpDesc>();
  1800. NodeDef *node_def = initNodeDef();
  1801. TensorFlowModelParser modelParser;
  1802. std::shared_ptr<OpParserFactory> factory = OpParserFactory::Instance(domi::TENSORFLOW);
  1803. std::shared_ptr<OpParser> op_parser = factory->CreateOpParser("FrameworkOp");
  1804. shared_ptr<TensorFlowOpParser> tensorflow_op_parser = std::dynamic_pointer_cast<TensorFlowOpParser>(op_parser);
  1805. Status ret = tensorflow_op_parser->ParseParams(node_def, op_dest);
  1806. EXPECT_EQ(ret, PARAM_INVALID);
  1807. ChangeDataType(node_def, tensorflow::DT_UINT16);
  1808. ret = tensorflow_op_parser->ParseParams(node_def, op_dest);
  1809. EXPECT_EQ(ret, PARAM_INVALID);
  1810. }
  1811. TEST_F(STestTensorflowParser, tensorflow_frameworkop_parser_test2)
  1812. {
  1813. TensorFlowCustomParserAdapter parser;
  1814. ge::OpDescPtr op_dest = std::make_shared<ge::OpDesc>();
  1815. NodeDef *node_def = initNodeDef();
  1816. node_def->set_name("FrameworkOp");
  1817. node_def->set_op("_Retval");
  1818. TensorFlowModelParser modelParser;
  1819. std::shared_ptr<OpParserFactory> factory = OpParserFactory::Instance(domi::TENSORFLOW);
  1820. std::shared_ptr<OpParser> op_parser = factory->CreateOpParser("FrameworkOp");
  1821. shared_ptr<TensorFlowOpParser> tensorflow_op_parser = std::dynamic_pointer_cast<TensorFlowOpParser>(op_parser);
  1822. static const string KEY_SHAPE_LIST = "key_shape_list";
  1823. static const string KEY_TENSOR_LIST = "key_tensor_list";
  1824. static const string KEY_DEFAULT = "key_default";
  1825. google::protobuf::Map<std::string, tensorflow::AttrValue> *node_attr_map = node_def->mutable_attr();
  1826. domi::tensorflow::AttrValue dtype_attr_value;
  1827. dtype_attr_value.set_type(domi::tensorflow::DT_FLOAT);
  1828. (*node_attr_map)[TENSORFLOW_ATTR_T] = dtype_attr_value;
  1829. //设置strides属性
  1830. domi::tensorflow::AttrValue axis_attr_value;
  1831. ::tensorflow::AttrValue_ListValue* list = axis_attr_value.mutable_list();
  1832. list->add_i(1);
  1833. list->add_i(2);
  1834. (*node_attr_map)[ge::SQUEEZE_ATTR_AXIS] = axis_attr_value;
  1835. domi::tensorflow::AttrValue value;
  1836. domi::tensorflow::AttrValue df_attr_value;
  1837. df_attr_value.set_i((int64_t)ccTensorFormat_t::CC_TENSOR_NHWC);
  1838. domi::tensorflow::AttrValue pad_attr_value;
  1839. pad_attr_value.set_i((int64_t)tensorflow::DT_FLOAT);
  1840. domi::tensorflow::AttrValue shape;
  1841. shape.mutable_list()->add_i((int64)32);
  1842. shape.mutable_list()->add_i((int64)32);
  1843. shape.mutable_list()->add_i((int64)14);
  1844. static const string KEY_TYPE_LIST = "key_type_list";
  1845. const std::string ATTR_NAME_INPUT_TENSOR_DESC = "ATTR_NAME_FRAMEWORK_OP_DEF";
  1846. const std::string ATTR_NAME_OUTPUT_TENSOR_DESC = "output_tensor_desc";
  1847. static const domi::tensorflow::DataType VALUE_TYPE = domi::tensorflow::DataType::DT_FLOAT;
  1848. value.clear_value();
  1849. value.mutable_list()->add_type(VALUE_TYPE);
  1850. TensorFlowUtil::AddNodeAttr(KEY_TYPE_LIST, value, node_def);
  1851. value.clear_value();
  1852. domi::tensorflow::NameAttrList name_attr_list;
  1853. name_attr_list.mutable_attr()->insert({"serialize_datatype", pad_attr_value});
  1854. name_attr_list.mutable_attr()->insert({"serialize_format", df_attr_value});
  1855. name_attr_list.mutable_attr()->insert({"serialize_shape", shape});
  1856. *(value.mutable_list()->add_func()) = name_attr_list;
  1857. node_def->mutable_attr()->insert({ge::ATTR_NAME_INPUT_TENSOR_DESC, value});
  1858. node_def->mutable_attr()->insert({ge::ATTR_NAME_OUTPUT_TENSOR_DESC, value});
  1859. Status ret = tensorflow_op_parser->ParseParams(node_def, op_dest);
  1860. EXPECT_EQ(ret, SUCCESS);
  1861. }
  1862. TEST_F(STestTensorflowParser, tensorflow_reshape_parser_test)
  1863. {
  1864. TensorFlowCustomParserAdapter parser;
  1865. ge::OpDescPtr op_dest = std::make_shared<ge::OpDesc>();
  1866. NodeDef *node_def = initNodeDef();
  1867. TensorFlowModelParser modelParser;
  1868. std::shared_ptr<OpParserFactory> factory = OpParserFactory::Instance(domi::TENSORFLOW);
  1869. std::shared_ptr<OpParser> op_parser = factory->CreateOpParser("Reshape");
  1870. shared_ptr<TensorFlowOpParser> tensorflow_op_parser = std::dynamic_pointer_cast<TensorFlowOpParser>(op_parser);
  1871. Status ret = tensorflow_op_parser->ParseParams(node_def, op_dest);
  1872. EXPECT_EQ(ret, SUCCESS);
  1873. NodeDef * nodeDef = new NodeDef();
  1874. nodeDef->set_op("Reshape");
  1875. google::protobuf::Map< ::std::string, ::tensorflow::AttrValue >* node_attr_map = nodeDef->mutable_attr();
  1876. domi::tensorflow::AttrValue attr_value;
  1877. attr_value.mutable_list()->add_i((int64)32);
  1878. attr_value.mutable_list()->add_i((int64)32);
  1879. attr_value.mutable_list()->add_i((int64)14);
  1880. domi::tensorflow::AttrValue df_attr_value2;
  1881. df_attr_value2.set_s(TENSORFLOWF_TENSOR_NHWC);
  1882. (*node_attr_map)[TENSORFLOW_ATTR_DATA_FORMAT] = df_attr_value2;
  1883. domi::tensorflow::AttrValue df_attr_value;
  1884. df_attr_value.set_i((int64_t)ccTensorFormat_t::CC_TENSOR_NHWC);
  1885. //设置padding属性
  1886. domi::tensorflow::AttrValue pad_attr_value2;
  1887. pad_attr_value2.set_s(TENSORFLOWF_OP_PADDING_SAME);
  1888. (*node_attr_map)[TENSORFLOW_ATTR_PADDING] = pad_attr_value2;
  1889. domi::tensorflow::AttrValue pad_attr_value;
  1890. pad_attr_value.set_i((int64_t)tensorflow::DT_FLOAT);
  1891. domi::tensorflow::NameAttrList name_attr_list;
  1892. name_attr_list.mutable_attr()->insert({"serialize_shape", attr_value});
  1893. name_attr_list.mutable_attr()->insert({"serialize_format", df_attr_value});
  1894. name_attr_list.mutable_attr()->insert({"serialize_datatype", pad_attr_value});
  1895. *(attr_value.mutable_list()->add_func()) = name_attr_list;
  1896. GeTensorDesc ge_desc;
  1897. ge_desc.SetFormat(ge::FORMAT_C1HWNCoC0);
  1898. ge_desc.SetDataType(ge::DT_FLOAT);
  1899. ge_desc.SetShape(GeShape({1,1,1,1,1,1}));
  1900. TensorFlowReshapeParser reshapeParser;
  1901. ret = reshapeParser.ParseDesc(attr_value, ge_desc);
  1902. EXPECT_EQ(ret, SUCCESS);
  1903. }
  1904. TEST_F(STestTensorflowParser, tensorflow_DefunToPartitionedCall_parser_test)
  1905. {
  1906. TensorFlowModelParser parser;
  1907. NodeDef *node_def = initNodeDef();
  1908. node_def->set_name("ShapeN");
  1909. ge::OpDescPtr op = make_shared<ge::OpDesc>("ShapeN", ge::parser::PARTITIONEDCALL);
  1910. Status ret = parser.DefunToPartitionedCall(node_def, op);
  1911. EXPECT_EQ(ret, FAILED);
  1912. static const string KEY_SHAPE_LIST = "key_shape_list";
  1913. static const string KEY_TENSOR_LIST = "key_tensor_list";
  1914. static const string KEY_DEFAULT = "key_default";
  1915. google::protobuf::Map<std::string, tensorflow::AttrValue> *node_attr_map = node_def->mutable_attr();
  1916. domi::tensorflow::AttrValue dtype_attr_value;
  1917. dtype_attr_value.set_type(domi::tensorflow::DT_FLOAT);
  1918. (*node_attr_map)[TENSORFLOW_ATTR_T] = dtype_attr_value;
  1919. //设置strides属性
  1920. domi::tensorflow::AttrValue axis_attr_value;
  1921. ::tensorflow::AttrValue_ListValue* list = axis_attr_value.mutable_list();
  1922. list->add_i(1);
  1923. list->add_i(2);
  1924. (*node_attr_map)[ge::SQUEEZE_ATTR_AXIS] = axis_attr_value;
  1925. domi::tensorflow::AttrValue value;
  1926. domi::tensorflow::AttrValue df_attr_value;
  1927. df_attr_value.set_i((int64_t)ccTensorFormat_t::CC_TENSOR_NHWC);
  1928. domi::tensorflow::AttrValue pad_attr_value;
  1929. pad_attr_value.set_i((int64_t)tensorflow::DT_FLOAT);
  1930. domi::tensorflow::AttrValue shape;
  1931. shape.mutable_list()->add_i((int64)32);
  1932. shape.mutable_list()->add_i((int64)32);
  1933. shape.mutable_list()->add_i((int64)14);
  1934. static const string KEY_TYPE_LIST = "key_type_list";
  1935. static const domi::tensorflow::DataType VALUE_TYPE = domi::tensorflow::DataType::DT_FLOAT;
  1936. value.clear_value();
  1937. value.mutable_list()->add_type(VALUE_TYPE);
  1938. TensorFlowUtil::AddNodeAttr(KEY_TYPE_LIST, value, node_def);
  1939. value.clear_value();
  1940. domi::tensorflow::NameAttrList name_attr_list;
  1941. name_attr_list.mutable_attr()->insert({"serialize_datatype", pad_attr_value});
  1942. name_attr_list.mutable_attr()->insert({"serialize_format", df_attr_value});
  1943. name_attr_list.mutable_attr()->insert({"serialize_shape", shape});
  1944. *(value.mutable_list()->add_func()) = name_attr_list;
  1945. node_def->mutable_attr()->insert({"_disable_call_shape_inference", value});
  1946. node_def->mutable_attr()->insert({"_disable_call_shape_inference", value});
  1947. std::string fusion_op_name = "pre_node_a";
  1948. GenOriginContext(&parser, fusion_op_name);
  1949. node_def->set_name("pre_node_a");
  1950. ret = parser.DefunToPartitionedCall(node_def, op);
  1951. EXPECT_EQ(ret, SUCCESS);
  1952. }
  1953. TEST_F(STestTensorflowParser, tensorflow_TransNodeToOpDesc_parser_test)
  1954. {
  1955. TensorFlowModelParser parser;
  1956. NodeDef *node_def = initNodeDef();
  1957. node_def->set_name("ge::parser::DATA");
  1958. std::string op_type = "ge::parser::DATA";
  1959. ge::OpDescPtr op = make_shared<ge::OpDesc>("constant", ge::parser::CONSTANT);
  1960. Status ret = parser.TransNodeToOpDesc(node_def, op, op_type);
  1961. EXPECT_EQ(ret, FAILED);
  1962. }
  1963. domi::Status fusion_parse_param_by_op(const std::vector<ge::Operator> &op_src, ge::Operator &op) {
  1964. return domi::SUCCESS;
  1965. }
  1966. TEST_F(STestTensorflowParser, Fusion_node_parse_params_success) {
  1967. ge::ComputeGraphPtr compute_graph = std::make_shared<ge::ComputeGraph>(GRAPH_DEFAULT_NAME);
  1968. ModelParserFactory* factory = ModelParserFactory::Instance();
  1969. shared_ptr<ModelParser> model_parser= factory->CreateModelParser(domi::TENSORFLOW);
  1970. ASSERT_TRUE(NULL != model_parser);
  1971. TensorFlowModelParser tensorflow_parser;
  1972. domi::tensorflow::NodeDef node_def;
  1973. node_def.set_name("data");
  1974. node_def.set_op("FusionCustom");
  1975. FusionParseParamByOpFunc function = fusion_parse_param_by_op;
  1976. shared_ptr<ge::OpParserFactory> op_parser = ge::OpParserFactory::Instance(domi::TENSORFLOW);
  1977. shared_ptr<OpParser> fusion_op_parser = op_parser->CreateFusionOpParser("FusionCustom");
  1978. ge::ComputeGraphPtr graph = std::make_shared<ge::ComputeGraph>(GRAPH_DEFAULT_NAME);
  1979. ge::OpDescPtr op1 = std::make_shared<ge::OpDesc>("data", "FusionCustom");
  1980. ge::NodePtr node1 = std::make_shared<ge::Node>(op1, graph);
  1981. vector<const NodeDef *> node_defs;
  1982. node_defs.push_back(&node_def);
  1983. tensorflow_parser.fusion_op_nodedef_map_["data"] = node_defs;
  1984. Status ret = tensorflow_parser.FusionNodeParseParams(fusion_op_parser, &node_def, node1);
  1985. EXPECT_EQ(domi::SUCCESS, ret);
  1986. }
  1987. TEST_F(STestTensorflowParser, Tensorflow_recordFusionResult_parser_test)
  1988. {
  1989. auto scope_graph = ge::parser::MakeShared<ge::ScopeGraph>();
  1990. if (scope_graph == nullptr) {
  1991. GELOGE(FAILED, "Scope graph make shared failed.");
  1992. return;
  1993. }
  1994. if (scope_graph->Init() != SUCCESS) {
  1995. GELOGE(FAILED, "Scope graph init failed.");
  1996. return;
  1997. }
  1998. domi::tensorflow::NodeDef node_def;
  1999. node_def.set_name("OP");
  2000. FusionScopesResult *fusion_scope_rlt = new (std::nothrow) FusionScopesResult();
  2001. if (fusion_scope_rlt == nullptr) {
  2002. GELOGE(FAILED, "FusionScopesResult make shared failed.");
  2003. return;
  2004. }
  2005. fusion_scope_rlt->Init();
  2006. fusion_scope_rlt->SetName("OP");
  2007. auto &impl_scope_graph = scope_graph->impl_;
  2008. std::string scope_name = fusion_scope_rlt->Name();
  2009. impl_scope_graph->fusion_results_.insert(std::make_pair(scope_name, fusion_scope_rlt));
  2010. std::vector<ge::OperatorPtr> nodes;
  2011. ge::OperatorPtr op = ge::parser::MakeShared<ge::Operator>("op_name", "op_type");
  2012. if (op == nullptr) {
  2013. GELOGE(FAILED, "Operator make shared failed.");
  2014. return;
  2015. }
  2016. nodes.push_back(op);
  2017. fusion_scope_rlt->impl_->AddNodes(nodes);
  2018. ge::OpDescPtr opDesc = std::make_shared<ge::OpDesc>();
  2019. ge::TensorFlowModelParser tf_model_parser;
  2020. Status ret = tf_model_parser.RecordFusionResult(scope_graph, &node_def, opDesc);
  2021. EXPECT_EQ(SUCCESS, ret);
  2022. }
  2023. TEST_F(STestTensorflowParser, Tensorflow_UpdateFusionOpContext_test)
  2024. {
  2025. ModelParserFactory* factory = ModelParserFactory::Instance();
  2026. shared_ptr<domi::ModelParser> model_parser = factory->CreateModelParser(domi::TENSORFLOW);
  2027. TensorFlowModelParser tensorflow_parser;
  2028. ScopeFusionOpInfo info;
  2029. ge::OpNodeContext normal_op_node_context;
  2030. ge::OpNodeContext fusion_op_node_context;
  2031. /* 1.预置条件 */
  2032. tensorflow::GraphDef *graph = new tensorflow::GraphDef();
  2033. ScopePassManager passmanager;
  2034. shared_ptr<ScopeGraph> scope_graph = passmanager.BuildScopeGraph(graph);
  2035. NodeDef * node1 = graph->add_node();
  2036. node1->set_name("conv_conv5/BatchNorm/batchnorm/add");
  2037. node1->set_op("Add");
  2038. node1->add_input("conv_conv5/BatchNorm/moving_variance");
  2039. node1->add_input("conv_conv5/BatchNorm/batchnorm/add/y");
  2040. NodeDef * node2 = graph->add_node();
  2041. node2->set_name("conv_conv5/BatchNorm/moving_variance");
  2042. node2->set_op("Const");
  2043. NodeDef * node3 = graph->add_node();
  2044. node3->set_name("conv_conv5/BatchNorm/batchnorm/add/y");
  2045. node3->set_op("Const");
  2046. info.fusion_node_name = "conv_conv5/BatchNorm/batchnorm";
  2047. info.fusion_op_type = ge::parser::FUSIONBATCHNORM;
  2048. info.node_name = "conv_conv5/BatchNorm/batchnorm/add";
  2049. info.description = "";
  2050. info.scope_pass = false;
  2051. EXPECT_EQ(scope_graph->impl_->GetFusionScopesResults(nullptr), nullptr);
  2052. EXPECT_EQ(scope_graph->impl_->GetFusionScopesResults(node1), nullptr);
  2053. Status ret = tensorflow_parser.UpdateFusionOpContext(scope_graph, info, fusion_op_node_context, normal_op_node_context);
  2054. EXPECT_EQ(ret, domi::SUCCESS);
  2055. delete graph;
  2056. }
  2057. TEST_F(STestTensorflowParser, Tensorflow_GetInOutPutIndex_scope_pass)
  2058. {
  2059. ModelParserFactory* factory = ModelParserFactory::Instance();
  2060. shared_ptr<domi::ModelParser> model_parser = factory->CreateModelParser(domi::TENSORFLOW);
  2061. TensorFlowModelParser tensorflow_parser;
  2062. tensorflow::GraphDef *graph = new tensorflow::GraphDef();
  2063. ScopePassManager passmanager;
  2064. shared_ptr<ScopeGraph> scope_graph = passmanager.BuildScopeGraph(graph);
  2065. FusionScopesResult* fusion_rlt = new FusionScopesResult();
  2066. fusion_rlt->Init();
  2067. fusion_rlt->impl_->inputs_.insert(std::make_pair<string, vector<int32_t>>("fw/fw/ToInt32" ,{0}));
  2068. fusion_rlt->impl_->inputs_.insert(std::make_pair<string, vector<int32_t>>("bw/bw/ToInt32" ,{0}));
  2069. fusion_rlt->impl_->inputs_.insert(std::make_pair<string, vector<int32_t>>("bw/ReverseSequence" ,{0, 1}));
  2070. fusion_rlt->impl_->inputs_.insert(std::make_pair<string, vector<int32_t>>("bw/ReverseSequence" ,{1}));
  2071. fusion_rlt->impl_->outputs_.insert(std::make_pair<string, vector<int32_t>>("concat" ,{0}));
  2072. fusion_rlt->impl_->outputs_.insert(std::make_pair<string, vector<int32_t>>("fw/fw/while/Exit_3" ,{1}));
  2073. fusion_rlt->impl_->outputs_.insert(std::make_pair<string, vector<int32_t>>("fw/fw/while/Exit_4" ,{2}));
  2074. fusion_rlt->impl_->outputs_.insert(std::make_pair<string, vector<int32_t>>("bw/bw/while/Exit_3" ,{3}));
  2075. fusion_rlt->impl_->outputs_.insert(std::make_pair<string, vector<int32_t>>("bw/bw/while/Exit_4" ,{4}));
  2076. fusion_rlt->SetType("dynamic_rnn");
  2077. fusion_rlt->SetName("dynamic_rnn_node1");
  2078. scope_graph->impl_->AddFusionScopesResult(fusion_rlt);
  2079. ScopeFusionOpInfo info1;
  2080. info1.node_name = "fw/fw/ToInt32";
  2081. info1.fusion_node_name = "dynamic_rnn_node1";
  2082. info1.fusion_op_type = "dynamic_rnn";
  2083. info1.description = "";
  2084. info1.scope_pass = true;
  2085. bool ignore = false;
  2086. ignore = tensorflow_parser.FusionOpChildIgnore(scope_graph, info1);
  2087. EXPECT_EQ(true, !ignore);
  2088. ScopeFusionOpInfo info2;
  2089. info2.node_name = "fw/fw/others";
  2090. info2.fusion_node_name = "dynamic_rnn_node1";
  2091. info2.fusion_op_type = "dynamic_rnn";
  2092. info2.description = "";
  2093. info2.scope_pass = true;
  2094. ignore = tensorflow_parser.FusionOpChildIgnore(scope_graph, info2);
  2095. EXPECT_EQ(true, ignore);
  2096. ScopeFusionOpInfo input_node_info;
  2097. input_node_info.node_name = "fw/fw/ToInt32";
  2098. input_node_info.fusion_node_name = "dynamic_rnn_node1";
  2099. input_node_info.fusion_op_type = "dynamic_rnn";
  2100. input_node_info.description = "";
  2101. input_node_info.scope_pass = true;
  2102. ScopeFusionOpInfo output_node_info;
  2103. output_node_info.node_name = "fw/fw/while/Exit_3";
  2104. output_node_info.fusion_node_name = "dynamic_rnn_node1";
  2105. output_node_info.fusion_op_type = "dynamic_rnn";
  2106. output_node_info.description = "";
  2107. output_node_info.scope_pass = true;
  2108. int32_t old_index = 0, new_index = -1;
  2109. Status ret = tensorflow_parser.GetInPutIndex(scope_graph, input_node_info, old_index, new_index);
  2110. EXPECT_EQ(domi::SUCCESS, ret);
  2111. EXPECT_EQ(true, (new_index == 0));
  2112. ret = tensorflow_parser.GetOutPutIndex(scope_graph, output_node_info, old_index, new_index);
  2113. EXPECT_EQ(domi::SUCCESS, ret);
  2114. EXPECT_EQ(true, (new_index == 1));
  2115. delete graph;
  2116. }
  2117. TEST_F(STestTensorflowParser, Tensorflow_AddFusionNodeDef_add_fusion_op_succ)
  2118. {
  2119. ModelParserFactory* factory = ModelParserFactory::Instance();
  2120. shared_ptr<domi::ModelParser> model_parser = factory->CreateModelParser(domi::TENSORFLOW);
  2121. TensorFlowModelParser tensorflow_parser;
  2122. string fusion_op_name = "dropout";
  2123. string fusion_op_type = "Dropout";
  2124. string description = "test/dropout";
  2125. tensorflow_parser.fusion_op_type_map_[fusion_op_name].push_back(fusion_op_type);
  2126. tensorflow_parser.fusion_op_type_map_[fusion_op_name].push_back(description);
  2127. // op_node_context for fusion op
  2128. ge::OpNodeContext op_node_context;
  2129. op_node_context.input_map["pre_node_a"].push_back({0, 0});
  2130. op_node_context.input_map["pre_node_b"].push_back({0, 1});
  2131. tensorflow_parser.op_node_context_map_[fusion_op_name] = op_node_context;
  2132. // origin inner node def
  2133. NodeDef* node_def = new (std::nothrow) NodeDef();
  2134. node_def->set_name("scope_node_1");
  2135. node_def->set_op("Add");
  2136. tensorflow_parser.fusion_op_nodedef_map_[fusion_op_name].push_back(node_def);
  2137. ScopePassManager pass_manager;
  2138. tensorflow::GraphDef *graph = new (std::nothrow) tensorflow::GraphDef();
  2139. shared_ptr<ScopeGraph> scope_graph = pass_manager.BuildScopeGraph(graph);
  2140. vector<string> node_name_list = {fusion_op_name};
  2141. Status ret = tensorflow_parser.AddFusionNodeDef(scope_graph, node_name_list);
  2142. EXPECT_EQ(ret, SUCCESS);
  2143. EXPECT_EQ(tensorflow_parser.nodedef_map_.size(), 1);
  2144. auto fusion_node_def = tensorflow_parser.nodedef_map_[fusion_op_name];
  2145. EXPECT_NE(fusion_node_def, nullptr);
  2146. EXPECT_EQ(fusion_node_def->op(), fusion_op_type);
  2147. delete node_def;
  2148. delete graph;
  2149. tensorflow_parser.DeleteFuisonNodeDef();
  2150. }
  2151. TEST_F(STestTensorflowParser, remain_dpop_node)
  2152. {
  2153. ge::ComputeGraphPtr graph = std::make_shared<ge::ComputeGraph>(GRAPH_DEFAULT_NAME);
  2154. ge::OpDescPtr op = std::make_shared<ge::OpDesc>("dpop_123", "FrameworkOp");
  2155. ge::NodePtr node = std::make_shared<ge::Node>(op, graph);
  2156. graph->AddNode(node);
  2157. ModelParserFactory* factory = ModelParserFactory::Instance();
  2158. shared_ptr<domi::ModelParser> model_parser= factory->CreateModelParser(domi::TENSORFLOW);
  2159. ASSERT_TRUE(NULL != model_parser);
  2160. TensorFlowModelParser tensorflow_parser;
  2161. Status ret = tensorflow_parser.RemoveIsolateNode(graph);
  2162. EXPECT_EQ(domi::SUCCESS, ret);
  2163. }
  2164. TEST_F(STestTensorflowParser, tensorflow_UpdateEdgesControlInfo_test)
  2165. {
  2166. TensorFlowModelParser model_parser;
  2167. ge::ScopeFusionOpInfo info;
  2168. info.fusion_node_name = "conv_conv5/BatchNorm/batchnorm";
  2169. info.fusion_op_type = ge::parser::FUSIONBATCHNORM;
  2170. info.node_name = "conv_conv5/BatchNorm/batchnorm/add";
  2171. info.description = "";
  2172. info.scope_pass = false;
  2173. model_parser.UpdateEdgesControlInfo(info);
  2174. }
  2175. TEST_F(STestTensorflowParser, tensorflow_OptimizeIdentityByOutput_test)
  2176. {
  2177. TensorFlowModelParser model_parser;
  2178. NodeDef *node_def = new NodeDef();
  2179. node_def->set_name("Placeholder");
  2180. node_def->set_op("Placeholder_0");
  2181. std::map<string, NodeDef *> nodedef_map;
  2182. nodedef_map.emplace("Placeholder", node_def);
  2183. std::string curr_node_name = "Placeholder";
  2184. bool clear_input_flag = true;
  2185. Status ret = model_parser.OptimizeIdentityByOutput(nodedef_map, curr_node_name, clear_input_flag);
  2186. EXPECT_EQ(ret, INTERNAL_ERROR);
  2187. GraphDef graph;
  2188. curr_node_name = "pre_node_a";
  2189. nodedef_map.emplace("pre_node_a", node_def);
  2190. node_def->set_op("pre_node_a");
  2191. GenOriginContext(&model_parser, curr_node_name);
  2192. ret = model_parser.OptimizeIdentityByOutput(nodedef_map, curr_node_name, clear_input_flag);
  2193. EXPECT_EQ(ret, SUCCESS);
  2194. delete node_def;
  2195. }
  2196. TEST_F(STestTensorflowParser, tensorflow_OptimizeSnapShot_test)
  2197. {
  2198. TensorFlowModelParser model_parser;
  2199. tensorflow::NodeDef *curr_mode_def = initNodeDef();
  2200. std::map<string, NodeDef *> nodedef_map;
  2201. nodedef_map.emplace("pre_node_a", curr_mode_def);
  2202. std::pair<string, int> input_data;
  2203. std::vector<string> control_list;
  2204. std::string curr_node_name = "pre_node_a";
  2205. GenOriginContext(&model_parser, curr_node_name);
  2206. Status ret = model_parser.OptimizeSnapShot(curr_mode_def, nodedef_map, input_data, control_list);
  2207. EXPECT_EQ(ret, INTERNAL_ERROR);
  2208. curr_mode_def->set_name("pre_node_a");
  2209. GenOriginContext(&model_parser, curr_node_name);
  2210. ret = model_parser.OptimizeSnapShot(curr_mode_def, nodedef_map, input_data, control_list);
  2211. EXPECT_EQ(ret, SUCCESS);
  2212. }
  2213. TEST_F(STestTensorflowParser, tensorflow_GraphDefOptimizeSnapShot_test)
  2214. {
  2215. TensorFlowModelParser model_parser;
  2216. tensorflow::GraphDef graph_def;
  2217. tensorflow::NodeDef *curr_mode_def = initNodeDef();
  2218. std::map<string, NodeDef *> nodedef_map;
  2219. nodedef_map.emplace("pre_node_a", curr_mode_def);
  2220. std::vector<NodeDef *> nodedef_to_optimize;
  2221. nodedef_to_optimize.emplace_back(curr_mode_def);
  2222. Status ret = model_parser.GraphDefOptimizeSnapShot(&graph_def, nodedef_map, nodedef_to_optimize);
  2223. EXPECT_EQ(ret, FAILED);
  2224. }
  2225. TEST_F(STestTensorflowParser, tensorflow_SetDestNodeName_test)
  2226. {
  2227. TensorFlowModelParser model_parser;
  2228. GraphDef graph;
  2229. auto arg0 = AddNode(graph, "_Arg", "arg0");
  2230. auto identity0 = AddNode(graph, "Identity", "identity0");
  2231. auto add0 = AddNode(graph, "Add", "add0");
  2232. int32_t input_idx = 0;
  2233. bool is_control = true;
  2234. bool clear_input_flag = true;
  2235. AddInput(arg0, identity0, 0);
  2236. AddInput(identity0, add0, 0);
  2237. Status ret = model_parser.SetDestNodeName(identity0, add0, input_idx, is_control, clear_input_flag);
  2238. EXPECT_EQ(ret, SUCCESS);
  2239. }
  2240. TEST_F(STestTensorflowParser, tensorflow_OptimizeDestroyTemporaryVariable_test)
  2241. {
  2242. ModelParserFactory* factory = ModelParserFactory::Instance();
  2243. shared_ptr<domi::ModelParser> model_parser= factory->CreateModelParser(domi::TENSORFLOW);
  2244. TensorFlowModelParser tensorflow_parser;
  2245. GraphDef graph;
  2246. auto const0 = AddNode(graph, "Const", "Const0");
  2247. auto tmpVar0 = AddNode(graph, "TemporaryVariable", "TemporaryVariable0");
  2248. auto assign0 = AddNode(graph, "Assign", "Assign0");
  2249. auto destroy0 = AddNode(graph, "DestroyTemporaryVariable", "DestroyTemporaryVariable0");
  2250. auto add0 = AddNode(graph, "Add", "Add0");
  2251. google::protobuf::Map< std::string, tensorflow::AttrValue> *node_attr_map = tmpVar0->mutable_attr();
  2252. tensorflow::AttrValue var_name_attr_value;
  2253. var_name_attr_value.set_s("temporary_variable_name");
  2254. (*node_attr_map)[ge::VAR_ATTR_NAME] = var_name_attr_value;
  2255. google::protobuf::Map<std::string, tensorflow::AttrValue>* node_attr_map_destroy = destroy0->mutable_attr();
  2256. tensorflow::AttrValue var_name_attr_value_destroy;
  2257. var_name_attr_value_destroy.set_s("destroy_temporary_variable_name");
  2258. (*node_attr_map_destroy)[ge::VAR_ATTR_NAME] = var_name_attr_value_destroy;
  2259. AddInput(tmpVar0, assign0, 0);
  2260. AddInput(assign0, destroy0, 0);
  2261. AddInput(const0, add0, 0);
  2262. AddInput(destroy0, add0, 1);
  2263. GraphDef* graphDef = &graph;
  2264. int32_t no_input_node_size_original = 0;
  2265. for (int w = 0; w < graphDef->node_size(); w++) {
  2266. tensorflow::NodeDef* nodeTmp = graphDef->mutable_node(w);
  2267. if (nodeTmp->input_size() == 0) {
  2268. no_input_node_size_original++;
  2269. }
  2270. }
  2271. Status ret = tensorflow_parser.GraphDefOptimize(graphDef);
  2272. int32_t no_input_node_size_result = 0;
  2273. for (int w = 0; w < graphDef->node_size(); w++) {
  2274. tensorflow::NodeDef* nodeTmp = graphDef->mutable_node(w);
  2275. if (nodeTmp->input_size() == 0) {
  2276. no_input_node_size_result ++;
  2277. }
  2278. }
  2279. ASSERT_EQ(ret, domi::FAILED);
  2280. ASSERT_EQ(no_input_node_size_original, no_input_node_size_result);
  2281. }
  2282. TEST_F(STestTensorflowParser, tensorflow_OptimizeDestroyTemporaryVariable_test2)
  2283. {
  2284. ModelParserFactory* factory = ModelParserFactory::Instance();
  2285. shared_ptr<domi::ModelParser> model_parser= factory->CreateModelParser(domi::TENSORFLOW);
  2286. TensorFlowModelParser tensorflow_parser;
  2287. GraphDef graph;
  2288. auto const0 = AddNode(graph, "Const", "Const0");
  2289. auto tmpVar0 = AddNode(graph, "TemporaryVariable", "TemporaryVariable0");
  2290. auto assign0 = AddNode(graph, "Assign", "Assign0");
  2291. auto destroy0 = AddNode(graph, "DestroyTemporaryVariable", "DestroyTemporaryVariable0");
  2292. auto add0 = AddNode(graph, "Add", "Add0");
  2293. google::protobuf::Map<std::string, tensorflow::AttrValue> *node_attr_map = tmpVar0->mutable_attr();
  2294. tensorflow::AttrValue var_name_attr_value;
  2295. var_name_attr_value.set_s("temporary_variable_name");
  2296. (*node_attr_map)[ge::VAR_ATTR_NAME] = var_name_attr_value;
  2297. google::protobuf::Map<std::string, tensorflow::AttrValue> *node_attr_map_destroy = destroy0->mutable_attr();
  2298. tensorflow::AttrValue var_name_attr_value_destroy;
  2299. var_name_attr_value_destroy.set_s("temporary_variable_name");
  2300. (*node_attr_map_destroy)[ge::VAR_ATTR_NAME] = var_name_attr_value_destroy;
  2301. AddInput(tmpVar0, assign0, 0);
  2302. AddInput(assign0, destroy0, 0);
  2303. AddInput(const0, add0, 0);
  2304. AddInput(destroy0, add0, 1);
  2305. GraphDef* graphDef = &graph;
  2306. int32_t no_input_node_size_original = 0;
  2307. for (int w = 0; w < graphDef->node_size(); w++) {
  2308. tensorflow::NodeDef* nodeTmp = graphDef->mutable_node(w);
  2309. if (nodeTmp->input_size() == 0) {
  2310. no_input_node_size_original ++;
  2311. }
  2312. }
  2313. Status ret = tensorflow_parser.GraphDefOptimize(graphDef);
  2314. int32_t no_input_node_size_result = 0;
  2315. for (int w = 0; w < graphDef->node_size(); w++) {
  2316. tensorflow::NodeDef* nodeTmp = graphDef->mutable_node(w);
  2317. if (nodeTmp->input_size() == 0) {
  2318. no_input_node_size_result ++;
  2319. }
  2320. }
  2321. ASSERT_EQ(ret, domi::SUCCESS);
  2322. ASSERT_EQ(no_input_node_size_original, (no_input_node_size_result - 1));
  2323. }
  2324. TEST_F(STestTensorflowParser, tensorflow_AddControlEdgeAfterRemoveInputs_test)
  2325. {
  2326. tensorflow::GraphDef graph_def;
  2327. TensorFlowModelParser tensorflow_parser;
  2328. tensorflow::NodeDef *node_def = initNodeDef();
  2329. node_def->set_name("Add0");
  2330. node_def->set_op("add");
  2331. std::map<std::string, NodeDef *> all_node_map;
  2332. all_node_map.emplace("Add0", node_def);
  2333. std::vector<std::string> removed_inputs_vec;
  2334. removed_inputs_vec.emplace_back("Add0");
  2335. Status ret = tensorflow_parser.AddControlEdgeAfterRemoveInputs(&graph_def, node_def, all_node_map, removed_inputs_vec);
  2336. EXPECT_EQ(ret, SUCCESS);
  2337. }
  2338. TEST_F(STestTensorflowParser, tensorflow_GraphDefOptimizeIdentity_test)
  2339. {
  2340. tensorflow::GraphDef graph_def;
  2341. TensorFlowModelParser tensorflow_parser;
  2342. tensorflow::NodeDef *node_def = initNodeDef();
  2343. node_def->set_name("post_node_d");
  2344. std::map<string, NodeDef *> nodedef_map;
  2345. nodedef_map.emplace("post_node_d", node_def);
  2346. nodedef_map.emplace("post_node_a", node_def);
  2347. nodedef_map.emplace("post_node_b", node_def);
  2348. std::vector<NodeDef *> nodedef_to_optimize;
  2349. nodedef_to_optimize.emplace_back(node_def);
  2350. std::string curr_node_name = "post_node_b";
  2351. GenOriginContext(&tensorflow_parser, curr_node_name);
  2352. Status ret = tensorflow_parser.GraphDefOptimizeIdentity(&graph_def, nodedef_map, nodedef_to_optimize);
  2353. EXPECT_EQ(ret, ge::PARAM_INVALID);
  2354. }
  2355. TEST_F(STestTensorflowParser, tensorflow_RemoveInputs_test)
  2356. {
  2357. tensorflow::GraphDef graph_def;
  2358. tensorflow::NodeDef *node_def = initNodeDef();
  2359. node_def->set_name("OP");
  2360. node_def->add_input("OP/Input_1");
  2361. node_def->add_input("OP/Input_2");
  2362. std::set<uint32_t> remove_index_set;
  2363. std::map<std::string, NodeDef *> all_node_map;
  2364. TensorFlowModelParser model_parser;
  2365. Status ret = model_parser.RemoveInputs(&graph_def, node_def, remove_index_set, all_node_map);
  2366. EXPECT_EQ(ret, SUCCESS);
  2367. remove_index_set.emplace(0);
  2368. ret = model_parser.RemoveInputs(&graph_def, node_def, remove_index_set, all_node_map);
  2369. EXPECT_EQ(ret, FAILED);
  2370. }
  2371. TEST_F(STestTensorflowParser, tensorflow_UpdateInnerNodeContext_test)
  2372. {
  2373. std::string fusion_op_name = "post_node_a";
  2374. std::vector<std::string> inner_nodes_name;
  2375. inner_nodes_name.emplace_back("post_node_a");
  2376. TensorFlowModelParser model_parser;
  2377. Status ret = model_parser.UpdateInnerNodeContext(fusion_op_name, inner_nodes_name);
  2378. EXPECT_EQ(ret, INTERNAL_ERROR);
  2379. GenOriginContext(&model_parser, fusion_op_name);
  2380. ret = model_parser.UpdateInnerNodeContext(fusion_op_name, inner_nodes_name);
  2381. EXPECT_EQ(ret, SUCCESS);
  2382. }
  2383. TEST_F(STestTensorflowParser, tensorflow_UpdateInnerInputMap_test)
  2384. {
  2385. string fusion_op_name = "post_node_a";
  2386. OpNodeContext fusion_context;
  2387. std::vector<std::string> inner_nodes_name;
  2388. inner_nodes_name.emplace_back("post_node_a");
  2389. std::set<string> fusion_input_nodes;
  2390. fusion_input_nodes.insert("post_node_a");
  2391. TensorFlowModelParser model_parser;
  2392. GenOriginContext(&model_parser, fusion_op_name);
  2393. model_parser.UpdateInnerInputMap(fusion_op_name, fusion_context, inner_nodes_name, fusion_input_nodes);
  2394. }
  2395. TEST_F(STestTensorflowParser, tensorflow_UpdateInnerOutputMap_test)
  2396. {
  2397. string fusion_op_name = "post_node_a";
  2398. OpNodeContext fusion_context;
  2399. std::vector<std::string> inner_nodes_name;
  2400. inner_nodes_name.emplace_back("post_node_a");
  2401. std::set<string> fusion_output_nodes;
  2402. fusion_output_nodes.insert("post_node_a");
  2403. TensorFlowModelParser model_parser;
  2404. GenOriginContext(&model_parser, fusion_op_name);
  2405. model_parser.UpdateInnerOutputMap(fusion_op_name, fusion_context, inner_nodes_name, fusion_output_nodes);
  2406. }
  2407. TEST_F(STestTensorflowParser, tensorflow_ScopePassManager_AddPass_test)
  2408. {
  2409. ScopePassManager passmanager;
  2410. tensorflow::GraphDef *graph = new tensorflow::GraphDef();
  2411. shared_ptr<ScopeGraph> scope_graph = passmanager.BuildScopeGraph(graph);
  2412. unique_ptr<ScopeBasePass> pass;
  2413. pass.reset(new ScopeTestPass());
  2414. EXPECT_EQ(ge::SUCCESS, passmanager.AddPass(pass));
  2415. EXPECT_NE(ge::SUCCESS, passmanager.Run(scope_graph));
  2416. delete graph;
  2417. graph = nullptr;
  2418. }
  2419. TEST_F(STestTensorflowParser, tensorflow_CheckAttrHasType_test1)
  2420. {
  2421. tensorflow::AttrValue attr_value;
  2422. attr_value.mutable_list();
  2423. Status ret = TensorFlowUtil::CheckAttrHasType(attr_value, "int");
  2424. EXPECT_EQ(FAILED, ret);
  2425. attr_value.set_type(DT_INVALID);
  2426. ret = TensorFlowUtil::CheckAttrHasType(attr_value, "type");
  2427. EXPECT_EQ(FAILED, ret);
  2428. tensorflow::AttrValue attr_value2;
  2429. AttrValue_ListValue *list = attr_value2.mutable_list();
  2430. list->add_type(tensorflow::DT_FLOAT);
  2431. list->add_type((tensorflow::DataType)30);
  2432. ret = TensorFlowUtil::CheckAttrHasType(attr_value2, "list(type)");
  2433. EXPECT_EQ(FAILED, ret);
  2434. }
  2435. TEST_F(STestTensorflowParser, tensorflow_CheckAttrHasType_test2)
  2436. {
  2437. tensorflow::AttrValue attr_value;
  2438. AttrValue_ListValue * list = attr_value.mutable_list();
  2439. list->add_type(tensorflow::DT_FLOAT);
  2440. list->add_type(tensorflow::DT_INVALID);
  2441. Status ret = TensorFlowUtil::CheckAttrHasType(attr_value, "list(type)");
  2442. EXPECT_EQ(FAILED, ret);
  2443. attr_value.set_placeholder("test");
  2444. ret = TensorFlowUtil::CheckAttrHasType(attr_value, "");
  2445. EXPECT_EQ(FAILED, ret);
  2446. }
  2447. TEST_F(STestTensorflowParser, tensorflow_TransTensorDescriptor_test)
  2448. {
  2449. tensorflow::AttrValue attr_value;
  2450. AttrValue_ListValue *list = attr_value.mutable_list();
  2451. list->add_type(tensorflow::DT_FLOAT);
  2452. ParserOperator op;
  2453. uint32_t io = TENSORFLOW_NORMAL_INPUT_TENSOR_FLAG;
  2454. std::string type = ge::parser::FUSEDBATCHNORMGRAD;
  2455. Status ret = TensorFlowUtil::TransTensorDescriptor(attr_value, &op, io, type);
  2456. EXPECT_EQ(ret, SUCCESS);
  2457. io = TENSORFLOW_NORMAL_OUTPUT_TENSOR_FLAG;
  2458. ret = TensorFlowUtil::TransTensorDescriptor(attr_value, &op, io, type);
  2459. EXPECT_EQ(ret, SUCCESS);
  2460. }
  2461. TEST_F(STestTensorflowParser, tensorflow_GraphDefOptimizeDestroyTemporaryVariable_test)
  2462. {
  2463. tensorflow::GraphDef *graph_def = nullptr;
  2464. tensorflow::NodeDef *nodeCurrent = initNodeDef();
  2465. TensorFlowModelParser model_parser;
  2466. Status ret = model_parser.GraphDefOptimizeDestroyTemporaryVariable(graph_def, nodeCurrent);
  2467. EXPECT_EQ(ret, FAILED);
  2468. }
  2469. TEST_F(STestTensorflowParser, tensorflow_GetFunctionProto_test)
  2470. {
  2471. std::cout << __FILE__ << std::endl;
  2472. std::string caseDir = __FILE__;
  2473. std::size_t idx = caseDir.find_last_of("/");
  2474. caseDir = caseDir.substr(0, idx);
  2475. std::string file = caseDir + "/origin_models/test_enter.pb";
  2476. domi::tensorflow::GraphDefLibrary graph_def_library;
  2477. TensorFlowModelParser model_parser;
  2478. Status ret = model_parser.GetFunctionProto(file, graph_def_library);
  2479. EXPECT_EQ(ret, FAILED);
  2480. }
  2481. TEST_F(STestTensorflowParser, tensorflow_GetNodeFormat_test)
  2482. {
  2483. NodeDef *node_def1 = initNodeDef();
  2484. node_def1->set_op("NoOp");
  2485. node_def1->set_name("NoOp");
  2486. NodeDef *node_def2 = initNodeDef();
  2487. node_def2->set_op("Add");
  2488. node_def2->set_name("Add0");
  2489. TfTranspose pred_transpose = TO_NCHW;
  2490. domiTensorFormat_t format = domi::DOMI_TENSOR_NC1HWC0;
  2491. std::set<const NodeDef *> visited_node;
  2492. visited_node.emplace(node_def2);
  2493. TensorFlowModelParser model_parser;
  2494. Status ret = model_parser.GetNodeFormat(node_def1, pred_transpose, format, visited_node);
  2495. EXPECT_EQ(ret, FAILED);
  2496. delete node_def1;
  2497. delete node_def2;
  2498. }
  2499. TEST_F(STestTensorflowParser, tensorflow_GetFormatTranspose_test)
  2500. {
  2501. NodeDef *transpose_node = initNodeDef();
  2502. transpose_node->set_op("Transpose");
  2503. TfTranspose transpose_direc = NO_TRANSPOSE;
  2504. TensorFlowModelParser modelParser;
  2505. Status ret = modelParser.GetFormatTranspose(transpose_node, transpose_direc);
  2506. EXPECT_EQ(ret, FAILED);
  2507. ge::TensorFlowModelParser parser;
  2508. GraphDef graph;
  2509. auto arg0 = AddNode(graph, "_Arg", "arg0");
  2510. auto snapshot0 = AddNode(graph, "Snapshot", "snapshot0");
  2511. auto ret0 = AddNode(graph, "_Retval", "retval0");
  2512. auto arg1 = AddNode(graph, "_Arg", "arg1");
  2513. auto snapshot1 = AddNode(graph, "Snapshot", "snapshot1");
  2514. auto ret1 = AddNode(graph, "_Retval", "retval1");
  2515. auto arg2 = AddNode(graph, "_Arg", "arg2");
  2516. auto snapshot2 = AddNode(graph, "Snapshot", "snapshot2");
  2517. auto ret2 = AddNode(graph, "_Retval", "retval2");
  2518. AddInput(arg0, snapshot0, 0);
  2519. AddInput(snapshot0, ret0, 0);
  2520. AddInput(arg1, snapshot1, 0);
  2521. AddInput(snapshot1, ret1, 0);
  2522. AddInput(arg2, snapshot2, 0);
  2523. AddInput(snapshot2, ret2, 0);
  2524. AddInput(snapshot0, snapshot1, -1);
  2525. AddInput(snapshot1, snapshot2, -1);
  2526. ASSERT_EQ(parser.GraphDefOptimize(&graph), domi::SUCCESS);
  2527. ASSERT_EQ(ret1->input_size(), 2);
  2528. ret = modelParser.GetFormatTranspose(ret1, transpose_direc);
  2529. EXPECT_EQ(ret, SUCCESS);
  2530. delete transpose_node;
  2531. }
  2532. TEST_F(STestTensorflowParser, tensorflow_GetTensorflowGraphInOutMap_test)
  2533. {
  2534. TensorFlowModelParser model_parser;
  2535. tensorflow::GraphDef *graph = new tensorflow::GraphDef();
  2536. tensorflow::NodeDef *node_input = graph->add_node();
  2537. node_input->set_name("name_input");
  2538. node_input->set_op("op_input");
  2539. AddGraphNode(graph, "t_lstm/t_lstm_cell/Sigmoid5", "Sigmoid", "node_input");
  2540. AddGraphNode(graph, "t_lstm/t_lstm_cell/Sigmoid6", "Sigmoid", "node_input");
  2541. AddGraphNode(graph, "t_lstm/t_lstm_cell/Sigmoid7", "Sigmoid", "node_input");
  2542. AddGraphNode(graph, "t_lstm/t_lstm_cell/Mul5", "Mul", "node_input");
  2543. AddGraphNode(graph, "t_lstm/t_lstm_cell/Mul6", "Mul", "node_input");
  2544. AddGraphNode(graph, "t_lstm/t_lstm_cell/Mul7", "Mul", "node_input");
  2545. AddGraphNode(graph, "t_lstm/t_lstm_cell/Relu5", "Relu", "node_input");
  2546. AddGraphNode(graph, "t_lstm/t_lstm_cell/Relu6", "Relu", "node_input");
  2547. Status ret = model_parser.GetTensorflowGraphInOutMap(graph);
  2548. EXPECT_EQ(ret, SUCCESS);
  2549. delete graph;
  2550. }
  2551. TEST_F(STestTensorflowParser, tensorflow_RemoveIsolateNode_test)
  2552. {
  2553. TensorFlowModelParser model_parser;
  2554. tensorflow::GraphDef graph;
  2555. CreateGraphDef(graph);
  2556. Status ret = model_parser.RemoveIsolateNode(&graph);
  2557. EXPECT_EQ(ret, FAILED);
  2558. }
  2559. TEST_F(STestTensorflowParser, tensorflow_AddNodeToGraphAndMarkFormat_test)
  2560. {
  2561. TensorFlowModelParser model_parser;
  2562. ComputeGraphPtr graph = make_shared<ge::ComputeGraph>("default");
  2563. std::vector<std::string> op_node_name_list = {"Const", "placeholder0"};
  2564. GenOriginNodeDef(&model_parser, op_node_name_list);
  2565. Status ret = model_parser.AddNodeToGraphAndMarkFormat(graph, op_node_name_list);
  2566. EXPECT_EQ(ret, INTERNAL_ERROR);
  2567. }
  2568. TEST_F(STestTensorflowParser, tensorflow_ParserNodeDef1_test)
  2569. {
  2570. ge::ComputeGraphPtr compute_graph = std::make_shared<ge::ComputeGraph>(GRAPH_DEFAULT_NAME);
  2571. ModelParserFactory* factory = ModelParserFactory::Instance();
  2572. shared_ptr<ModelParser> model_parser= factory->CreateModelParser(domi::TENSORFLOW);
  2573. ASSERT_TRUE(NULL != model_parser);
  2574. TensorFlowModelParser tensorflow_parser;
  2575. tensorflow_parser.adaptedOpTypeMap_["test_name"] = "POOLING";
  2576. std::mutex graphMutex;
  2577. tensorflow::GraphDef *graph = new tensorflow::GraphDef();
  2578. ScopePassManager passmanager;
  2579. shared_ptr<ScopeGraph> scope_graph = passmanager.BuildScopeGraph(graph);
  2580. domi::tensorflow::NodeDef node_def;
  2581. node_def.set_name("test_name");
  2582. node_def.set_op("POOLING");
  2583. error_message::Context error_context;
  2584. Status ret = ge::TensorFlowModelParser::ParseNodeDef(&tensorflow_parser, compute_graph, &graphMutex, scope_graph, &node_def, error_context);
  2585. EXPECT_EQ(FAILED, ret);
  2586. delete graph;
  2587. }
  2588. TEST_F(STestTensorflowParser, tensorflow_ParserNodeDef2_test)
  2589. {
  2590. ge::ComputeGraphPtr compute_graph = std::make_shared<ge::ComputeGraph>(GRAPH_DEFAULT_NAME);
  2591. ModelParserFactory* factory = ModelParserFactory::Instance();
  2592. shared_ptr<ModelParser> model_parser= factory->CreateModelParser(domi::TENSORFLOW);
  2593. ASSERT_TRUE(NULL != model_parser);
  2594. TensorFlowModelParser tensorflow_parser;
  2595. tensorflow_parser.adaptedOpTypeMap_["Pooling"] = "Pooling";
  2596. std::mutex graphMutex;
  2597. tensorflow::GraphDef *graph = new tensorflow::GraphDef();
  2598. ScopePassManager passmanager;
  2599. shared_ptr<ScopeGraph> scope_graph = passmanager.BuildScopeGraph(graph);
  2600. REGISTER_CUSTOM_OP("Pooling")
  2601. .FrameworkType(domi::TENSORFLOW)
  2602. .OriginOpType("Pooling")
  2603. .ParseParamsFn(ParseParams)
  2604. .ImplyType(ImplyType::TVM);
  2605. register_tbe_op();
  2606. domi::tensorflow::NodeDef node_def;
  2607. node_def.set_name("Pooling");
  2608. node_def.set_op("Pooling");
  2609. error_message::Context error_context;
  2610. Status ret = ge::TensorFlowModelParser::ParseNodeDef(&tensorflow_parser, compute_graph, &graphMutex, scope_graph, &node_def, error_context);
  2611. EXPECT_EQ(FAILED, ret);
  2612. delete graph;
  2613. }
  2614. TEST_F(STestTensorflowParser, tensorflow_AddExternalGraph_test)
  2615. {
  2616. TensorFlowModelParser modelParser;
  2617. ge::ComputeGraphPtr subGraph = std::make_shared<ge::ComputeGraph>("default");
  2618. std::string inputNodeType = "DATA";
  2619. MakeDagGraph(subGraph, inputNodeType);
  2620. Status ret = modelParser.AddExternalGraph(subGraph);
  2621. EXPECT_EQ(ret, SUCCESS);
  2622. }
  2623. TEST_F(STestTensorflowParser, tensorflow_AddFmkNode_test)
  2624. {
  2625. TensorFlowModelParser model_parser;
  2626. ge::ComputeGraphPtr compute_graph = std::make_shared<ge::ComputeGraph>(GRAPH_DEFAULT_NAME);
  2627. tensorflow::GraphDef *graphDef = new (std::nothrow) tensorflow::GraphDef();
  2628. ScopePassManager pass_manager;
  2629. std::shared_ptr<ScopeGraph> scope_graph = pass_manager.BuildScopeGraph(graphDef);
  2630. std::vector<std::string> op_node_name_list = {"Const", "placeholder0"};
  2631. GenOriginNodeDef(&model_parser, op_node_name_list);
  2632. Status ret = model_parser.AddFmkNode(compute_graph, scope_graph, op_node_name_list, false);
  2633. EXPECT_EQ(ret, PARAM_INVALID);
  2634. delete graphDef;
  2635. }
  2636. TEST_F(STestTensorflowParser, tensorflow_OptimizeConstNodes4CustomOp_test)
  2637. {
  2638. TensorFlowModelParser model_parser;
  2639. tensorflow::GraphDef graph_def;
  2640. CreateGraphDef(graph_def);
  2641. Status ret = model_parser.OptimizeConstNodes4CustomOp(&graph_def);
  2642. EXPECT_EQ(ret, SUCCESS);
  2643. }
  2644. TEST_F(STestTensorflowParser, tensorflow_ParseOpParams_test)
  2645. {
  2646. TensorFlowModelParser model_parser;
  2647. tensorflow::NodeDef *node_def = initNodeDef();
  2648. node_def->set_name("Pooling");
  2649. node_def->set_op("Pooling");
  2650. ge::OpDescPtr op = std::make_shared<ge::OpDesc>();
  2651. std::shared_ptr<OpParserFactory> factory = OpParserFactory::Instance(domi::TENSORFLOW);
  2652. std::shared_ptr<OpParser> op_parser = factory->CreateOpParser("Pooling");
  2653. Status ret = model_parser.ParseOpParams(node_def, op, op_parser);
  2654. EXPECT_EQ(ret, FAILED);
  2655. node_def->set_name("TensorArrayWrite");
  2656. node_def->set_op("TensorArrayWriteV3");
  2657. op_parser = factory->CreateOpParser("TensorArrayWrite");
  2658. ret = model_parser.ParseOpParams(node_def, op, op_parser);
  2659. EXPECT_EQ(ret, SUCCESS);
  2660. delete node_def;
  2661. }
  2662. TEST_F(STestTensorflowParser, tensorflow_AddFusionInnerNodeDef_test)
  2663. {
  2664. TensorFlowModelParser model_parser;
  2665. ge::ComputeGraphPtr compute_graph = std::make_shared<ge::ComputeGraph>(GRAPH_DEFAULT_NAME);
  2666. tensorflow::GraphDef *graphDef = new (std::nothrow) tensorflow::GraphDef();
  2667. ScopePassManager pass_manager;
  2668. std::shared_ptr<ScopeGraph> scope_graph = pass_manager.BuildScopeGraph(graphDef);
  2669. std::vector<std::string> op_node_name_list = {"Const", "placeholder0"};
  2670. FusionScopesResult *fusion_scope_rlt = new (std::nothrow) FusionScopesResult();
  2671. fusion_scope_rlt->Init();
  2672. fusion_scope_rlt->SetName("FusionCustom");
  2673. auto &impl_scope_graph = scope_graph->impl_;
  2674. std::string scope_name = fusion_scope_rlt->Name();
  2675. impl_scope_graph->fusion_results_.insert(std::make_pair(scope_name, fusion_scope_rlt));
  2676. std::string fusion_op_name = "FusionCustom";
  2677. GenOriginNodeDef(&model_parser, op_node_name_list);
  2678. GenFusionScopesResult(scope_graph, fusion_scope_rlt, fusion_op_name);
  2679. Status ret = model_parser.AddFusionInnerNodeDef(scope_graph, fusion_op_name, op_node_name_list);
  2680. delete graphDef;
  2681. }
  2682. } // namespace ge