From c0fd13503d9686264b3d4b488ede9bd14d81f739 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Thu, 30 May 2019 07:33:44 -0500 Subject: [PATCH] word_cnn save training step works. --- graph/word_cnn.meta | Bin 87561 -> 87561 bytes src/TensorFlowNET.Core/Train/Saving/Saver.cs | 25 +++++++-------- .../TextProcess/CnnTextClassification.cs | 29 ++++++------------ 3 files changed, 21 insertions(+), 33 deletions(-) diff --git a/graph/word_cnn.meta b/graph/word_cnn.meta index de19687b222527b918083507064ca1037af50972..141947b19b16be2167b8162ff6a80360d97b7491 100644 GIT binary patch delta 3070 zcmb_edrXs86!&~zDeVUo{3tCJ?E^uupnx(hRsj(onUfV?`$Gl_G>j<)EtvSo#JFjU zVOuzHce)MH%}rOP=nZznxk;wOOwCXX8XxFpVtmB;h?8tSw)??ahb?x=_Sbj6d(Qct z-#z!<-?<6vVnE;rgGGj`qdog*bX|cpiiZf^2SSd4+!VT{~Zw~iRl24F{oT5}X00Mz&5Q(d; zN;-S=Ou-lHV~d^_E&|E7F}L(-8rox-#Tfzs(-Lc)AFMl5^;AD?Md0uQGq_-NumE$} zLcz>RQCU-6RK23gj@`uudh6BpP+y|^3N$j|7+X-p3{ca%1u22jsAy(U1&D)%km_ZN z9L}mrr@g3H2UR{ zXw{kx$^oOOd1KuRtT*d{8!#62n^!128;O&Ts43k%PQ{HZ#Qe`SSn#RH>r$j~3Z42{ zJda~s(b&C9&Fj(%vF$)KhVKr;hOa`g$fd)u-3i#}ik0lx2Nqp)$e8h_fFfZ02g zIHM_vcJDHA3?vz!emRyFx?CK(7Ai64NF~Fo=&kYk>bZBfd;rcxqT+yOsV-+^Egwv*PE^}b1 zX9TZ&XW(=heIp$n{#st&8q@1`r%jGkZ6eJ+%nf*u47BEm5qPFS^uwbZkd#~yK#6JR zj9jIFa4a1k{UCVC-tT-8(-as=Uuw^W;J?*^Y}3FnF|K3k$qysPp1+&R1(Bfn9QWip;;|($x)|Ie#C7AKjw zR^)>Yw_Q_`a2^!&^nezm{b9M;I&#v~-nwK?KYS#9Sq`vnuSJ z2Sd=-qjcXE;R~;V_q*7v_Zi*CG%$l#L!_)4+RG~J zj%ug$LXf&vzc&QDL*CaztovjPyes#fak4HJ)_d;35OR1DXrxEUL+_N&TLjKaI0kYh z7G&TuBrgfp zJ^TAT3AR4_JD3ccJSoPzYe&H(4rl+StG}#>W`(I?o!wz~R@;|VE|(g-P{#`+Nd9Pu z{$Mhw$lh@f>K>U2BLF*38PsM!5u_s>hLMg;5XqnEpmrZk0}J33OI(y6XJHG;8x1tB?YEPKZM+eGN1^^tT<2; zy$FI|2$E-oIixWQqR7uy(2))+TWnnxL=xjfCMmW|h#;yg$&d-7$YU!cku?*UbyFsr z+cu6_TPLzPYZfHnzEe3QZxTHIAK`kkpq)!7Buh0gl(@3kmHCe>6Un#Pkl-%Oft`_m E0j;jCjQ{`u delta 2900 zcmcImX-rgC6z04)Gt4|>1|BfL0JDIA4vVNnmJvoNB_vj56zvZcktK=|L{Ss8sTj8! zqP3jZTWt!~l@AgzT|O$Ce^S6UNmHMYcPNsC?H46~RS6qBaE-aYqx_nhyY zbMJTWu4d`3W@%e?7(VUO(xR|5Mye!ouNq?Ufp~f;A{*E}Es-=pmkGY&fM?}3!0iMs z1pW0PxItfPb>m&V$@UF5#NfSadbtCMD-A2G4)Xa#2u8nw5Mqh~9Y#gIW1H7ydwoYPEpTG zktTUE(D0^YV0X8u?}M-f~AraC7A;EXoSO{j2?H*_vMPb9fxR!X;S|qE|&XR5i(L zxe_V+*m_)r zlfIH+jlkof0}GY%;*w>gg9Z8R`qJhCPQqOHO3?=KT*K~s~B zn4OHFrG{lvL=G)UsBDwr%*GHLJQ0SSr=-;XWdBQ0G_=Ii3GF2;lU9Hqok?R-t<0ua zkta{zcLU;tb`9#g47j63jn(bJ;#i6G?JC;*^*sjC>O{s#-;Be`PNg*~qe*qm1k{}l zv%!Dsl3`^RPiJ&8u12r_Ru8Or_~|DX7{Ik-QS zT5zdT0;vxH9es2)%_dQ3x|xDYuC2;)Xpd3MRIQ~;}Zs2!<&(|ve}Q=r^gcyZP< zbYeCHlXD7CVBGb0N71kT>kFOwg_au&JV*`QnCT#OH1od^LhY?Bqfi*Uwa7_v${V4 z7l!aSh}7hP%vNGjF9kWS{MDD%J&AqQvWLJ$h|2;`KHWC9u3r{ASZ)CgHat)e6$^6O z_sl#t>!cx*#H!jsqp~J7^FWOiPlD<9!^t+B=e5%w51+jmy)E-0u!Dt#%=oeD-!B%7 zzfQ0xO!6g->)(D@DF>mBh#;yIC%VSP~&bP$j_(7;*OX8c1(4gbNLmp};!Yuo=<$07KLN_{5;aHii06VaL+DpK2VrIZw(51pkJGyAGK{%HN zOYP>&LOq#lf_X$|7LT_w2~;H23>tF91U%V0T^tQ(LL3oE@nj$q!brn3@FPBENF;mA z5KGQX7rnKa5b&`cd`PniIDD~#!?V4qWKIUW_#eU!nW39xWQn "CNN Text Classification"; public int? DataLimit = null; public bool ImportGraph { get; set; } = true; - public bool UseSubset = false; // <----- set this true to use a limited subset of dbpedia - private string dataDir = "text_classification"; + private string dataDir = "word_cnn"; private string dataFileName = "dbpedia_csv.tar.gz"; private const string TRAIN_PATH = "text_classification/dbpedia_csv/train.csv"; private const string TEST_PATH = "text_classification/dbpedia_csv/test.csv"; - + private const int NUM_CLASS = 14; private const int BATCH_SIZE = 64; private const int NUM_EPOCHS = 10; @@ -41,6 +40,7 @@ namespace TensorFlowNET.Examples private const int CHAR_MAX_LEN = 1014; protected float loss_value = 0; + int vocabulary_size = 50000; public bool Run() { @@ -63,10 +63,9 @@ namespace TensorFlowNET.Examples int[][] x = null; int[] y = null; int alphabet_size = 0; - int vocabulary_size = 0; var word_dict = DataHelpers.build_word_dict(TRAIN_PATH); - vocabulary_size = len(word_dict); + // vocabulary_size = len(word_dict); (x, y) = DataHelpers.build_word_dataset(TRAIN_PATH, word_dict, WORD_MAX_LEN); Console.WriteLine("\tDONE "); @@ -142,7 +141,7 @@ namespace TensorFlowNET.Examples if (valid_accuracy > max_accuracy) { max_accuracy = valid_accuracy; - saver.save(sess, $"{dataDir}/word_cnn.ckpt", global_step: step.ToString()); + saver.save(sess, $"{dataDir}/word_cnn.ckpt", global_step: step); print("Model is saved.\n"); } } @@ -218,18 +217,10 @@ namespace TensorFlowNET.Examples public void PrepareData() { - if (UseSubset) - { - var url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/data/dbpedia_subset.zip"; - Web.Download(url, dataDir, "dbpedia_subset.zip"); - Compress.UnZip(Path.Combine(dataDir, "dbpedia_subset.zip"), Path.Combine(dataDir, "dbpedia_csv")); - } - else - { - string url = "https://github.com/le-scientifique/torchDatasets/raw/master/dbpedia_csv.tar.gz"; - Web.Download(url, dataDir, dataFileName); - Compress.ExtractTGZ(Path.Join(dataDir, dataFileName), dataDir); - } + // full dataset https://github.com/le-scientifique/torchDatasets/raw/master/dbpedia_csv.tar.gz + var url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/data/dbpedia_subset.zip"; + Web.Download(url, dataDir, "dbpedia_subset.zip"); + Compress.UnZip(Path.Combine(dataDir, "dbpedia_subset.zip"), Path.Combine(dataDir, "dbpedia_csv")); if (ImportGraph) { @@ -242,7 +233,7 @@ namespace TensorFlowNET.Examples Console.WriteLine("Discarding cached file: " + meta_path); File.Delete(meta_path); } - var url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/graph/" + meta_file; + url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/graph/" + meta_file; Web.Download(url, "graph", meta_file); } }