From 4cbf362fc67f52cea34d3b587bbeb334fb21df44 Mon Sep 17 00:00:00 2001 From: Luke Bollam Date: Wed, 1 Sep 2021 11:32:59 +0800 Subject: [PATCH 1/6] adding saved model cleanup benchmark --- .../Leak/SavedModelCleanup.cs | 31 ++++++++++++++++++ .../Leak/TestModel/saved_model/saved_model.pb | Bin 0 -> 24775 bytes .../variables/variables.data-00000-of-00001 | Bin 0 -> 1465 bytes .../saved_model/variables/variables.index | Bin 0 -> 274 bytes .../Tensorflow.Benchmark.csproj | 16 +++++++++ 5 files changed, 47 insertions(+) create mode 100644 src/TensorFlowNet.Benchmarks/Leak/SavedModelCleanup.cs create mode 100644 src/TensorFlowNet.Benchmarks/Leak/TestModel/saved_model/saved_model.pb create mode 100644 src/TensorFlowNet.Benchmarks/Leak/TestModel/saved_model/variables/variables.data-00000-of-00001 create mode 100644 src/TensorFlowNet.Benchmarks/Leak/TestModel/saved_model/variables/variables.index diff --git a/src/TensorFlowNet.Benchmarks/Leak/SavedModelCleanup.cs b/src/TensorFlowNet.Benchmarks/Leak/SavedModelCleanup.cs new file mode 100644 index 00000000..e9e1e75f --- /dev/null +++ b/src/TensorFlowNet.Benchmarks/Leak/SavedModelCleanup.cs @@ -0,0 +1,31 @@ +using BenchmarkDotNet.Attributes; +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Reflection; +using System.Text; +using System.Threading.Tasks; + +namespace Tensorflow.Benchmark.Leak +{ + + public class SavedModelCleanup + { + [Benchmark] + public void Run() + { + var modelDir = Path.GetDirectoryName(Assembly.GetExecutingAssembly().Location); + var ClassifierModelPath = Path.Combine(modelDir, "Leak", "TestModel", "saved_model"); + + for (var i = 0; i < 1000; i++) + { + var session = Session.LoadFromSavedModel(ClassifierModelPath); + + session.graph.Exit(); + session.graph.Dispose(); + session.Dispose(); + } + } + } +} diff --git a/src/TensorFlowNet.Benchmarks/Leak/TestModel/saved_model/saved_model.pb b/src/TensorFlowNet.Benchmarks/Leak/TestModel/saved_model/saved_model.pb new file mode 100644 index 0000000000000000000000000000000000000000..f75f28564731f1d6bf48b02c41d8d304881b5e70 GIT binary patch literal 24775 zcmeHPYiwM{b>`hAxtE9U;Zs9Bs4L5sH@3*VvJ-ieIuR))R%*+tNZUoy==yT^l3H`Q zyWP7>Q8LsbK?}D|-;We={7+M^W@o(Iow5 z?)ydV%CaN3ks#f@b7$t9Ip@qdXU?1%Rbk)ymO`IjVE;HzYczGqwre|$m(6y~+^SnI zw%8b{ZOd+U+ZAhZmQB#mzFF^D+EBG~r)6p5xjOXhxq_a{0lbVRdA5YmD zAV$@)*+{)+cg_L*$F<`$)wM07(ycZcMzdk;w7P6?tJ$pc+&M*if(k9OS_K6?N|uk& z0vp`8^uiT3^zx+`)4gD3_DNi)vDF#)H-*< z-DdSfAE2SD_BFG?hDkXPJQ*V8&LU|u0XKlPJ~%O>X_DpsoS&C$rNWwi{TZTq|$$e0@=6VsR$m>-i;t~x{mMcXMFH6fz=&YTpz91nV#rb&I=QOd61 z(WeC(&H$mYTnz`X(0Z~|8{U-9D78Z~l#^Q{WsWv!1f6%9ZAdPpYS@ucNaz%eIiQS& zxhG|ImZl^e1FFKXp`=79pQeIS&(?>fjj}OI++U2OniQt$ap-v|!F3wmF!%2h8tWI@ zM*$7VWVSqtnF_7y4i-J263$Fx>>KoI1ZZa9b3B!9uT^eXJJ!L+yY>1DW~XvXP*hrZ$R}jl@qS=PAU!R!1A`9} zE_5^7t}qS|YLMr^BPwoPE>^(8Q-4^w`lyOpUg;xj$0 zFXp9o-{|gznfvb8w6MRI?QnfTp;4&M=P=)eYbt~>a7{Ik0DSA1wT2ZTfp&~Wh5TAo zsX4;biTPz3R4{7#id#EFlMub;Z8WV~`$EC6F{wcuW<4&%RE^}Khqf;&>?>XRoC4RY zQCY!|_Dj{3UC4-qNe~)-ah@g&N0h|;Yyr7J%9RI-qn=NmTUCk+-jF8BV=SlU@h^ve zQ$;OGl^K=^p_E|1PKnHn8@e1^O%|kgBkPv&yjXl4ITyf@UEcQ+9XD zhIJDDs$IKl*)*>XG8QTgnhvo7=+uNmgnouup=voVli%w(TU47bgOdVx&ZEL4nN6}r!H?22q8;FgoV{9a- zIj)X}p)L`r6Igy3vj1^)GW=1Xa0?Zt*swFNfXX!dZipCAVOglmWTP@0MP)7w4N3)& zlvO;YnNkAmc*)l=*VTD8UN`SpZDY@Z{;vp3GKg|YU0_R2-E`9QPPHX@?z! zs>fKt!-iDOh9yv&CPDnA@DK;+CAL8mFJHQG_0rn&R}3s8U%HZ_Q22RQVx(GgmDzLU%L@UmP#9v?u8dEs%XC{8#oA&8>qmD$-M7Inf^WKMyi z6i?6sv_oAu{p7$!<#WnFu_!#+vp3Gr zX#<^ktoBu|al37{S{RvLS|v;;8>^kR;d${v%-pe?V7Ocp$OuhAQ3}OzR|1A8ujBOy{%4V5`A{wj$?2biX9WkBZ=9s# zVGatVT!HH;7-r~CmFQ0op|##St&`5#;j^cJ+4Jzy?wp|AnH#6+R2uCt%S0@}(4YPg z`}R%x?h1V&M~mT_AE}#rTUGN}qqO==uufB0nLY@WDgB31)q5Zqip7Tr(-|c>mpuoh z+43heqU7hP7-}ml5<;%o^~FL4QCu0ATvV{QCGX8B#oqNr5wv()IgS3kwZR(tN4P!a zlfSJz8LHzhBZ9c9QHs^)a1M`&7Fh7DB%+qCGun&ynJ+GR&L>-&f>SwArtkWuL|-}t z3jaJ*oQng)x^<5`%Vw@+B35=*KyZu;B4{O+JReACZbZ1tb5YHmgC8sm*z@CrLK9J_2w?i#0uC8~oIgaq6s;|A z^gjmU??X5g9V^smr-4YA2k(VQq(KQp0=o@j%OO)F(psmGq;HV`$lEj7eO$%s zeevK`eLMn*Z#!5xTXxRQo#WWJsJsJ6?f_-iyyIwvyCULe)J4cIi2gk&fUf2VVDMj2 zXiO4KZXY1}@9zUR#FV&|?|`;aSqCj;nOq*db2?zNJ&Xf491eJ!0N!o@p9|OW#V%m; zV!{ex`6G78Jf6Be@;yHYh#mbEv3dXC><+fJ1e$P-yuLXi%oLO{Ka+qfT6O`8(0IbNeE^2&JDz)@=!k|K1&sfdLdOrQ6fsazl-#lSl;Mvb);`0RSxL;I z6SP!5mhk0-qbmuC1?FL3_}?qk-=cv{M7Om0eXg1M-W||a`oxsKsgPVnxZ;wKn)`5I zm^YtB)(Uc+qYUVe^a6$dvqC2i<3!A-dHV3Wv6Wm|M*d+4K5#^{bPxH_gjB&nFHb>$ z#EQw^K+>0w8Up{P=pL3sa^_|}MM6FaZUhtKvBL-9R}6GZAJal?iA@#dHmsWySg`)=DPe$TS>h!5XdXBYU411TKvS3!a0t(yqaQmg~`X$VGjJ6P~V{ zU1P5AkkjNA!8e89fFOckVhoB{tw~ScFAsbNxhiu}XZ2-|W-#V*R<4=nANZAm&s zvj&o%+lJk;Dz)vJm-AC#P9T~7ZjQbUXFa?D5Z`7&9^0grbI5seD2< z4c@2RuKbcXzV7aY7I-ud?wn%R))$}3Kmvy%;-L8u-hLxtimE8c{A zLs|PN1yfE2ht~Sc*IDZ;6=&_ZvHskZ^&h_e;?--LGH>X`wNGAI-!wjP;??vwKv+saYpijLs(1l1O!!_A~3p{Q<6RW8%T^zozwyyn;o`|csANuM_rcEF}XtWMw^_uw~UyFca> zb6PA9L^F$>wA=Ck1R#ZwX<LQ#VI$*3 zP=rG_#YvDj21l+C!*tg<4iE)VwhJLk9ndg{EcJ?7f9Cr_#C%@}#k4N?QypozG4qXK zz}lD``aBr0%f28!n*@U;Z%CCIEB?xNu+fKScNs^QCR8k9#-#% z=)KVVn0Nxdj@xvl`@aQ$gWb(4uzVR4r^WlCWklI4{NTcH|N~( zdbe@Jc|{ye#;q_bxNrgYlyu6!SEB#Of8URbYX6?2Nu#rEI0s6lSg=7=!GQiY_qtZ8 zo3<@OpjvrZ!xaE`v?rD|aixKH$8{$1-9OxjLXiWBD%Jruh&r}&?h(T_aJNA4)w$aU zJ%BQ1=YlRe-4X}+!@F;&-oOEWeA?^U9c=74Cr!57%{{S4glh#*W819T0-j&+x+V`% zF5fooWhv^*w`-kS5TRwiia_xm+SF`g2S>MV84as*yV>3ib>~n6)6T|8U6gC`B)NlT zW!Gv{MOnaH@xhrJ?lv0n&T& zh`5S)eYeHYc3qT^y!0obN8_8AAdsA|=dcdWI&<9@*)CfVvf*I)`-*JY_@Z8+BCV-ES@0Bag7lb(hp6-iHir(yISncIf(1<8 z(g%T~n`R&`V?a@yMuySpwveh|2%^-mcWtB6r!Ra)m=AGB3GVdjxz$D4%Dv?ctu$Mh zQFIi#eNPluZkY|-MS`ObrrnX>-Rs;8+&bgy zOg|BrdLBAiGWdB1ejA=-@TDcUy$pxqSj*@DMI=$HmpdH1R=t z3*uiKF_1@kF>ad1i6d(Mn~ALuc@({LbNYs8i3Dc7=X>T*|1wpTZ)F~nE(@5 z0^z*}-+Fhe9uAvuB=+9q46i~0>!4$%{J=Q;z&IRf93Urdgd(u*(((ya`_LE-Sc`jf zlrQ{P`7AcnBlno$$YAn53p!C=LPdGw7bkdp56ggJc%;O8zM<%Utgx>u^t;L*=CPYQDz3G> z)ZP*9BzxSsm6M%*D}R!PUmLWo`ZlSrxwlply{1aDEUud#rvh?J#2udEekI)Ji4@=V z?rr%pLu%F9HoJA)QCxT{kBryW0Pa?$-P_`>B-tkaR(^oZ(S%dv*UoE!KI&#aml?ql}Q@L_%QA90bJX? zElE8Q;)Y;V3L~%`9%F>4n8F9ig|8_3Z!7F=g}w;nuyYy_p%bnM&ApXpNfA0!Y#t<{ zuE4y1P-SoCdXbHP!KuxuX@2?b{3Xi2ogbj`SMwjDr@d{DC*I8A`ZniYxFhyDAYZ?s od^L|9yp_C%c0klT?*WN9W*GhNP`kPs(~JDpr<$ZUmK{h1ONa4 literal 0 HcmV?d00001 diff --git a/src/TensorFlowNet.Benchmarks/Leak/TestModel/saved_model/variables/variables.data-00000-of-00001 b/src/TensorFlowNet.Benchmarks/Leak/TestModel/saved_model/variables/variables.data-00000-of-00001 new file mode 100644 index 0000000000000000000000000000000000000000..4c7f99dbacb2456aa9893d404a996e02b588ce42 GIT binary patch literal 1465 zcmd0sf7fo;&Zm18GMwF8>)O2UbGzKW9YwUfyiBZ4+N(w3O@)7L z|Jyv@>$p_Vj{n`eV;SlSaFL zzuN7VXlC0b&MLOuw>H%F`b9BYkIRvEzcp{~6|VkZyYX&=-T8k>`&OE7+S?nt$*$Te z&h9~Q!@jyTEc>i_x$QbdlkK`@G43t3kJ`7B*~2cFeTtoMv9sOI=PI^7(Fbg!C!gMv z@pH|dAJeOLd3v3;{a3?jJEPrf?>pv|dv|WqvTeNnW{)bf$zBE^XnbR`C+CLh-cJsF zyC2K3?ETDc@Y92htC5k5n}bn^JuNe*B(3I z6Ec7b1ccat?#xTg0Xjuch$SgAu^8wSAzV%o#_bdlRHuj{+s4Pm3bcvQ07wZ6F&YBx zMi*Jj$R)@jCd8YQSeaTBpPLGdkYtcca0H4xhq#b5L@oB9AwJZkIV6O{^MH|!E%3QH zB!$>vuHoX45@G{eTFk`{418&1FEVh23I%iVg*gU!Iy(8d#)mojgt`g|0+WY6DEtjm z^daFar3XzhU}31BzMf-9NRX#fXozdDKB^I1fkOTm#_}T>3yLx+9TY=BLimhC4}2Wa tA%htmI1+|Dhb*D!AU5SlbI9R~4teD0P=H4VEJv{eQy3>Gp}_=zGytKt^)dhe literal 0 HcmV?d00001 diff --git a/src/TensorFlowNet.Benchmarks/Leak/TestModel/saved_model/variables/variables.index b/src/TensorFlowNet.Benchmarks/Leak/TestModel/saved_model/variables/variables.index new file mode 100644 index 0000000000000000000000000000000000000000..ee0efb7c0c418c80c8493ce4c605384820ba484e GIT binary patch literal 274 zcmZQzVB=tvV&Y(Akl~Ma_HcFf4)FK%3vqPvagFzP@^W + + + + + + + PreserveNewest + + + PreserveNewest + + + PreserveNewest + + + From b3d08626f66d912012243e79505af66bdf79f2d6 Mon Sep 17 00:00:00 2001 From: Luke Bollam Date: Wed, 1 Sep 2021 11:38:09 +0800 Subject: [PATCH 2/6] less loops so it doesnt waste so much time --- src/TensorFlowNet.Benchmarks/Leak/SavedModelCleanup.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/TensorFlowNet.Benchmarks/Leak/SavedModelCleanup.cs b/src/TensorFlowNet.Benchmarks/Leak/SavedModelCleanup.cs index e9e1e75f..36b2c0ba 100644 --- a/src/TensorFlowNet.Benchmarks/Leak/SavedModelCleanup.cs +++ b/src/TensorFlowNet.Benchmarks/Leak/SavedModelCleanup.cs @@ -18,7 +18,7 @@ namespace Tensorflow.Benchmark.Leak var modelDir = Path.GetDirectoryName(Assembly.GetExecutingAssembly().Location); var ClassifierModelPath = Path.Combine(modelDir, "Leak", "TestModel", "saved_model"); - for (var i = 0; i < 1000; i++) + for (var i = 0; i < 50; i++) { var session = Session.LoadFromSavedModel(ClassifierModelPath); From 85139ed131350b48617503a2fec05808ed03c75e Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sat, 4 Sep 2021 09:31:29 -0500 Subject: [PATCH 3/6] Fix Session.LoadFromSavedModel memroy leak. --- src/TensorFlowNET.Core/Graphs/c_api.graph.cs | 2 +- .../Sessions/BaseSession.cs | 14 ++-- src/TensorFlowNET.Core/Sessions/Session.cs | 70 ++++++------------- .../Sessions/c_api.session.cs | 12 ++++ .../Leak/SavedModelCleanup.cs | 11 ++- src/TensorFlowNet.Benchmarks/Program.cs | 4 +- .../Tensorflow.Benchmark.csproj | 2 +- 7 files changed, 52 insertions(+), 63 deletions(-) diff --git a/src/TensorFlowNET.Core/Graphs/c_api.graph.cs b/src/TensorFlowNET.Core/Graphs/c_api.graph.cs index 2f5af971..6eb8f367 100644 --- a/src/TensorFlowNET.Core/Graphs/c_api.graph.cs +++ b/src/TensorFlowNET.Core/Graphs/c_api.graph.cs @@ -289,7 +289,7 @@ namespace Tensorflow [DllImport(TensorFlowLibName)] public static extern IntPtr TF_LoadSessionFromSavedModel(SafeSessionOptionsHandle session_options, IntPtr run_options, string export_dir, string[] tags, int tags_len, - IntPtr graph, ref TF_Buffer meta_graph_def, SafeStatusHandle status); + IntPtr graph, IntPtr meta_graph_def, SafeStatusHandle status); [DllImport(TensorFlowLibName)] public static extern IntPtr TF_NewGraph(); diff --git a/src/TensorFlowNET.Core/Sessions/BaseSession.cs b/src/TensorFlowNET.Core/Sessions/BaseSession.cs index 3c994a6e..a740226f 100644 --- a/src/TensorFlowNET.Core/Sessions/BaseSession.cs +++ b/src/TensorFlowNET.Core/Sessions/BaseSession.cs @@ -36,6 +36,12 @@ namespace Tensorflow protected byte[] _target; public Graph graph => _graph; + public BaseSession(IntPtr handle, Graph g) + { + _handle = handle; + _graph = g ?? ops.get_default_graph(); + } + public BaseSession(string target = "", Graph g = null, ConfigProto config = null, Status status = null) { _graph = g ?? ops.get_default_graph(); @@ -291,12 +297,8 @@ namespace Tensorflow protected override void DisposeUnmanagedResources(IntPtr handle) { - lock (Locks.ProcessWide) - using (var status = new Status()) - { - c_api.TF_DeleteSession(handle, status.Handle); - status.Check(true); - } + // c_api.TF_CloseSession(handle, tf.Status.Handle); + c_api.TF_DeleteSession(handle, tf.Status.Handle); } } } diff --git a/src/TensorFlowNET.Core/Sessions/Session.cs b/src/TensorFlowNET.Core/Sessions/Session.cs index c48715a2..1e94b882 100644 --- a/src/TensorFlowNET.Core/Sessions/Session.cs +++ b/src/TensorFlowNET.Core/Sessions/Session.cs @@ -26,10 +26,8 @@ namespace Tensorflow public Session(string target = "", Graph g = null) : base(target, g, null) { } - public Session(IntPtr handle, Graph g = null) : base("", g, null) - { - _handle = handle; - } + public Session(IntPtr handle, Graph g = null) : base(handle, g) + { } public Session(Graph g, ConfigProto config = null, Status s = null) : base("", g, config, s) { } @@ -39,51 +37,29 @@ namespace Tensorflow return ops.set_default_session(this); } - [MethodImpl(MethodImplOptions.NoOptimization)] public static Session LoadFromSavedModel(string path) { - lock (Locks.ProcessWide) - { - var graph = c_api.TF_NewGraph(); - using var status = new Status(); - var opt = new SessionOptions(); - - var tags = new string[] { "serve" }; - var buffer = new TF_Buffer(); - - IntPtr sess; - try - { - sess = c_api.TF_LoadSessionFromSavedModel(opt.Handle, - IntPtr.Zero, - path, - tags, - tags.Length, - graph, - ref buffer, - status.Handle); - status.Check(true); - } - catch (TensorflowException ex) when (ex.Message.Contains("Could not find SavedModel")) - { - sess = c_api.TF_LoadSessionFromSavedModel(opt.Handle, - IntPtr.Zero, - Path.GetFullPath(path), - tags, - tags.Length, - graph, - ref buffer, - status.Handle); - status.Check(true); - } - - // load graph bytes - // var data = new byte[buffer.length]; - // Marshal.Copy(buffer.data, data, 0, (int)buffer.length); - // var meta_graph = MetaGraphDef.Parser.ParseFrom(data);*/ - - return new Session(sess, g: new Graph(graph)).as_default(); - } + using var graph = new Graph(); + using var status = new Status(); + using var opt = c_api.TF_NewSessionOptions(); + + var tags = new string[] { "serve" }; + + var sess = c_api.TF_LoadSessionFromSavedModel(opt, + IntPtr.Zero, + path, + tags, + tags.Length, + graph, + IntPtr.Zero, + status.Handle); + status.Check(true); + + // load graph bytes + // var data = new byte[buffer.length]; + // Marshal.Copy(buffer.data, data, 0, (int)buffer.length); + // var meta_graph = MetaGraphDef.Parser.ParseFrom(data);*/ + return new Session(sess, g: graph); } public static implicit operator IntPtr(Session session) => session._handle; diff --git a/src/TensorFlowNET.Core/Sessions/c_api.session.cs b/src/TensorFlowNET.Core/Sessions/c_api.session.cs index 8ac4d53e..548d79e7 100644 --- a/src/TensorFlowNET.Core/Sessions/c_api.session.cs +++ b/src/TensorFlowNET.Core/Sessions/c_api.session.cs @@ -21,6 +21,18 @@ namespace Tensorflow { public partial class c_api { + /// + /// Close a session. + /// + /// Contacts any other processes associated with the session, if applicable. + /// May not be called after TF_DeleteSession(). + /// + /// + /// + + [DllImport(TensorFlowLibName)] + public static extern void TF_CloseSession(IntPtr session, SafeStatusHandle status); + /// /// Destroy a session object. /// diff --git a/src/TensorFlowNet.Benchmarks/Leak/SavedModelCleanup.cs b/src/TensorFlowNet.Benchmarks/Leak/SavedModelCleanup.cs index 36b2c0ba..5cdb28f7 100644 --- a/src/TensorFlowNet.Benchmarks/Leak/SavedModelCleanup.cs +++ b/src/TensorFlowNet.Benchmarks/Leak/SavedModelCleanup.cs @@ -6,6 +6,7 @@ using System.Linq; using System.Reflection; using System.Text; using System.Threading.Tasks; +using static Tensorflow.Binding; namespace Tensorflow.Benchmark.Leak { @@ -18,13 +19,9 @@ namespace Tensorflow.Benchmark.Leak var modelDir = Path.GetDirectoryName(Assembly.GetExecutingAssembly().Location); var ClassifierModelPath = Path.Combine(modelDir, "Leak", "TestModel", "saved_model"); - for (var i = 0; i < 50; i++) - { - var session = Session.LoadFromSavedModel(ClassifierModelPath); - - session.graph.Exit(); - session.graph.Dispose(); - session.Dispose(); + for (var i = 0; i < 1024; i++) + { + using var sess = Session.LoadFromSavedModel(ClassifierModelPath); } } } diff --git a/src/TensorFlowNet.Benchmarks/Program.cs b/src/TensorFlowNet.Benchmarks/Program.cs index 598d7a03..22abf730 100644 --- a/src/TensorFlowNet.Benchmarks/Program.cs +++ b/src/TensorFlowNet.Benchmarks/Program.cs @@ -13,7 +13,9 @@ namespace TensorFlowBenchmark static void Main(string[] args) { print(tf.VERSION); - /*new RepeatDataSetCrash().Run(); + + /*new SavedModelCleanup().Run(); + new RepeatDataSetCrash().Run(); new GpuLeakByCNN().Run();*/ if (args?.Length > 0) diff --git a/src/TensorFlowNet.Benchmarks/Tensorflow.Benchmark.csproj b/src/TensorFlowNet.Benchmarks/Tensorflow.Benchmark.csproj index ea799b02..ceba6cbb 100644 --- a/src/TensorFlowNet.Benchmarks/Tensorflow.Benchmark.csproj +++ b/src/TensorFlowNet.Benchmarks/Tensorflow.Benchmark.csproj @@ -37,7 +37,7 @@ - + From f3102b9be76fe0adafbc3b4f9da5d48263b8e4b4 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sat, 4 Sep 2021 09:39:35 -0500 Subject: [PATCH 4/6] Release v0.60.3. --- TensorFlow.NET.sln | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/TensorFlow.NET.sln b/TensorFlow.NET.sln index 36034437..8846d5bf 100644 --- a/TensorFlow.NET.sln +++ b/TensorFlow.NET.sln @@ -1,7 +1,7 @@  Microsoft Visual Studio Solution File, Format Version 12.00 -# Visual Studio Version 17 -VisualStudioVersion = 17.0.31423.177 +# Visual Studio Version 16 +VisualStudioVersion = 16.0.31624.102 MinimumVisualStudioVersion = 10.0.40219.1 Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Binding", "src\TensorFlowNET.Core\Tensorflow.Binding.csproj", "{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}" EndProject @@ -77,8 +77,8 @@ Global {03F06299-3F4B-4449-A709-3A647657BC0C}.Debug|x86.Build.0 = Debug|Any CPU {03F06299-3F4B-4449-A709-3A647657BC0C}.Release|Any CPU.ActiveCfg = Release|Any CPU {03F06299-3F4B-4449-A709-3A647657BC0C}.Release|Any CPU.Build.0 = Release|Any CPU - {03F06299-3F4B-4449-A709-3A647657BC0C}.Release|x64.ActiveCfg = Release|Any CPU - {03F06299-3F4B-4449-A709-3A647657BC0C}.Release|x64.Build.0 = Release|Any CPU + {03F06299-3F4B-4449-A709-3A647657BC0C}.Release|x64.ActiveCfg = Release|x64 + {03F06299-3F4B-4449-A709-3A647657BC0C}.Release|x64.Build.0 = Release|x64 {03F06299-3F4B-4449-A709-3A647657BC0C}.Release|x86.ActiveCfg = Release|Any CPU {03F06299-3F4B-4449-A709-3A647657BC0C}.Release|x86.Build.0 = Release|Any CPU {49D71826-C03D-4FA7-9BAC-22C1327E65CF}.Debug|Any CPU.ActiveCfg = Debug|Any CPU @@ -101,8 +101,8 @@ Global {1AB8108D-4FFE-4A16-88E7-328EAF686370}.Debug|x86.Build.0 = Debug|Any CPU {1AB8108D-4FFE-4A16-88E7-328EAF686370}.Release|Any CPU.ActiveCfg = Release|Any CPU {1AB8108D-4FFE-4A16-88E7-328EAF686370}.Release|Any CPU.Build.0 = Release|Any CPU - {1AB8108D-4FFE-4A16-88E7-328EAF686370}.Release|x64.ActiveCfg = Release|Any CPU - {1AB8108D-4FFE-4A16-88E7-328EAF686370}.Release|x64.Build.0 = Release|Any CPU + {1AB8108D-4FFE-4A16-88E7-328EAF686370}.Release|x64.ActiveCfg = Release|x64 + {1AB8108D-4FFE-4A16-88E7-328EAF686370}.Release|x64.Build.0 = Release|x64 {1AB8108D-4FFE-4A16-88E7-328EAF686370}.Release|x86.ActiveCfg = Release|Any CPU {1AB8108D-4FFE-4A16-88E7-328EAF686370}.Release|x86.Build.0 = Release|Any CPU {F17AAECB-960A-4E18-A270-BAD776F0E55B}.Debug|Any CPU.ActiveCfg = Debug|Any CPU @@ -113,8 +113,8 @@ Global {F17AAECB-960A-4E18-A270-BAD776F0E55B}.Debug|x86.Build.0 = Debug|Any CPU {F17AAECB-960A-4E18-A270-BAD776F0E55B}.Release|Any CPU.ActiveCfg = Release|Any CPU {F17AAECB-960A-4E18-A270-BAD776F0E55B}.Release|Any CPU.Build.0 = Release|Any CPU - {F17AAECB-960A-4E18-A270-BAD776F0E55B}.Release|x64.ActiveCfg = Release|Any CPU - {F17AAECB-960A-4E18-A270-BAD776F0E55B}.Release|x64.Build.0 = Release|Any CPU + {F17AAECB-960A-4E18-A270-BAD776F0E55B}.Release|x64.ActiveCfg = Release|x64 + {F17AAECB-960A-4E18-A270-BAD776F0E55B}.Release|x64.Build.0 = Release|x64 {F17AAECB-960A-4E18-A270-BAD776F0E55B}.Release|x86.ActiveCfg = Release|Any CPU {F17AAECB-960A-4E18-A270-BAD776F0E55B}.Release|x86.Build.0 = Release|Any CPU {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Debug|Any CPU.ActiveCfg = Debug|Any CPU @@ -125,8 +125,8 @@ Global {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Debug|x86.Build.0 = Debug|Any CPU {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Release|Any CPU.ActiveCfg = Release|Any CPU {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Release|Any CPU.Build.0 = Release|Any CPU - {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Release|x64.ActiveCfg = Release|Any CPU - {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Release|x64.Build.0 = Release|Any CPU + {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Release|x64.ActiveCfg = Release|x64 + {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Release|x64.Build.0 = Release|x64 {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Release|x86.ActiveCfg = Release|Any CPU {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Release|x86.Build.0 = Release|Any CPU {79EB56DF-E29E-4AE2-A7D9-FE403FD919BA}.Debug|Any CPU.ActiveCfg = Debug|Any CPU @@ -137,8 +137,8 @@ Global {79EB56DF-E29E-4AE2-A7D9-FE403FD919BA}.Debug|x86.Build.0 = Debug|Any CPU {79EB56DF-E29E-4AE2-A7D9-FE403FD919BA}.Release|Any CPU.ActiveCfg = Release|Any CPU {79EB56DF-E29E-4AE2-A7D9-FE403FD919BA}.Release|Any CPU.Build.0 = Release|Any CPU - {79EB56DF-E29E-4AE2-A7D9-FE403FD919BA}.Release|x64.ActiveCfg = Release|Any CPU - {79EB56DF-E29E-4AE2-A7D9-FE403FD919BA}.Release|x64.Build.0 = Release|Any CPU + {79EB56DF-E29E-4AE2-A7D9-FE403FD919BA}.Release|x64.ActiveCfg = Release|x64 + {79EB56DF-E29E-4AE2-A7D9-FE403FD919BA}.Release|x64.Build.0 = Release|x64 {79EB56DF-E29E-4AE2-A7D9-FE403FD919BA}.Release|x86.ActiveCfg = Release|Any CPU {79EB56DF-E29E-4AE2-A7D9-FE403FD919BA}.Release|x86.Build.0 = Release|Any CPU {3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.Debug|Any CPU.ActiveCfg = Debug|Any CPU @@ -149,8 +149,8 @@ Global {3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.Debug|x86.Build.0 = Debug|Any CPU {3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.Release|Any CPU.ActiveCfg = Release|Any CPU {3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.Release|Any CPU.Build.0 = Release|Any CPU - {3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.Release|x64.ActiveCfg = Release|Any CPU - {3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.Release|x64.Build.0 = Release|Any CPU + {3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.Release|x64.ActiveCfg = Release|x64 + {3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.Release|x64.Build.0 = Release|x64 {3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.Release|x86.ActiveCfg = Release|Any CPU {3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.Release|x86.Build.0 = Release|Any CPU EndGlobalSection From 824dfe6aaf58b0c5b3c34e02dfa5d8404bcf8a23 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sun, 5 Sep 2021 15:17:37 -0500 Subject: [PATCH 5/6] Pack/Unpack gradient. #847 --- src/TensorFlowNET.Core/Gradients/array_grad.cs | 16 ++++++++++++++++ src/TensorFlowNET.Core/Operations/array_ops.cs | 10 +--------- .../Operations/gen_array_ops.cs | 6 ++---- 3 files changed, 19 insertions(+), 13 deletions(-) diff --git a/src/TensorFlowNET.Core/Gradients/array_grad.cs b/src/TensorFlowNET.Core/Gradients/array_grad.cs index f80f8ac6..528b5208 100644 --- a/src/TensorFlowNET.Core/Gradients/array_grad.cs +++ b/src/TensorFlowNET.Core/Gradients/array_grad.cs @@ -223,6 +223,22 @@ namespace Tensorflow.Gradients return new Tensor[] { array_ops.reshape(grads[0], array_ops.shape(op.inputs[0])), null }; } + [RegisterGradient("Pack")] + public static Tensor[] _PackGrad(Operation op, Tensor[] grads) + { + var grad = grads[0]; + var num = op.get_attr("N"); + var axis = op.get_attr("axis"); + return array_ops.unstack(grad, num: num, axis: axis); + } + + [RegisterGradient("Unpack")] + public static Tensor[] _UnpackGrad(Operation op, Tensor[] grads) + { + var axis = op.get_attr("axis"); + return new[] { array_ops.stack(grads, axis: axis) }; + } + [RegisterGradient("Pad")] public static Tensor[] _PadGrad(Operation op, Tensor[] grads) { diff --git a/src/TensorFlowNET.Core/Operations/array_ops.cs b/src/TensorFlowNET.Core/Operations/array_ops.cs index b0ef1f2d..d13e0005 100644 --- a/src/TensorFlowNET.Core/Operations/array_ops.cs +++ b/src/TensorFlowNET.Core/Operations/array_ops.cs @@ -494,20 +494,12 @@ namespace Tensorflow return ops.convert_to_tensor(values, name: name); } - var value_shape = ops.convert_to_tensor(values[0], name: name).shape; - return gen_array_ops.pack(values, axis: axis, name: name); } public static Tensor[] unstack(Tensor value, int? num = null, int axis = 0, string name = "unstack") { - if (num == null) - { - value = ops.convert_to_tensor(value); - var value_shape = value.shape; - num = (int)value_shape.dims[axis]; - } - + num = num ?? value.shape.as_int_list()[axis]; return gen_array_ops.unpack(value, num: num.Value, axis: axis, name: name); } diff --git a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs index 65599a4c..dd1604f6 100644 --- a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs @@ -265,10 +265,8 @@ namespace Tensorflow } public static Tensor[] unpack(Tensor value, int num, int axis = 0, string name = null) - { - var _op = tf.OpDefLib._apply_op_helper("Unpack", name, new { value, num, axis }); - return _op.outputs; - } + => tf.Context.ExecuteOp("Unpack", name, new ExecuteOpArgs(value, num) + .SetAttributes(new { axis })); public static Tensor where(Tensor condition, string name = null) { From f3a51daba4ab7b13d97ef0c8ec5e5dfe6cdbc0ee Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sun, 5 Sep 2021 16:28:03 -0500 Subject: [PATCH 6/6] Update protobuf per tf 2.6. --- src/TensorFlowNET.Core/Protobuf/Config.cs | 393 ++++++++++----- .../Protobuf/CppShapeInference.cs | 84 ++-- src/TensorFlowNET.Core/Protobuf/FullType.cs | 450 ++++++++++++++++++ src/TensorFlowNET.Core/Protobuf/Function.cs | 249 ++++++++-- src/TensorFlowNET.Core/Protobuf/Gen.bat | 2 + .../Protobuf/MemmappedFileSystem.cs | 360 ++++++++++++++ src/TensorFlowNET.Core/Protobuf/OpDef.cs | 135 ++++-- .../Protobuf/RewriterConfig.cs | 4 +- .../Protobuf/SavedObjectGraph.cs | 309 +++++++++--- src/TensorFlowNET.Core/Protobuf/Struct.cs | 31 +- src/TensorFlowNET.Core/Protobuf/Tensor.cs | 2 +- src/TensorFlowNET.Core/Protobuf/Types.cs | 32 +- 12 files changed, 1756 insertions(+), 295 deletions(-) create mode 100644 src/TensorFlowNET.Core/Protobuf/FullType.cs create mode 100644 src/TensorFlowNET.Core/Protobuf/MemmappedFileSystem.cs diff --git a/src/TensorFlowNET.Core/Protobuf/Config.cs b/src/TensorFlowNET.Core/Protobuf/Config.cs index af7391d3..cd34fd78 100644 --- a/src/TensorFlowNET.Core/Protobuf/Config.cs +++ b/src/TensorFlowNET.Core/Protobuf/Config.cs @@ -30,7 +30,7 @@ namespace Tensorflow { "KnRlbnNvcmZsb3cvY29yZS9mcmFtZXdvcmsvc3RlcF9zdGF0cy5wcm90bxom", "dGVuc29yZmxvdy9jb3JlL3Byb3RvYnVmL2NsdXN0ZXIucHJvdG8aJHRlbnNv", "cmZsb3cvY29yZS9wcm90b2J1Zi9kZWJ1Zy5wcm90bxoudGVuc29yZmxvdy9j", - "b3JlL3Byb3RvYnVmL3Jld3JpdGVyX2NvbmZpZy5wcm90byLJBQoKR1BVT3B0", + "b3JlL3Byb3RvYnVmL3Jld3JpdGVyX2NvbmZpZy5wcm90byKRBgoKR1BVT3B0", "aW9ucxInCh9wZXJfcHJvY2Vzc19ncHVfbWVtb3J5X2ZyYWN0aW9uGAEgASgB", "EhQKDGFsbG93X2dyb3d0aBgEIAEoCBIWCg5hbGxvY2F0b3JfdHlwZRgCIAEo", "CRIfChdkZWZlcnJlZF9kZWxldGlvbl9ieXRlcxgDIAEoAxIbChN2aXNpYmxl", @@ -38,123 +38,127 @@ namespace Tensorflow { "ZWNzGAYgASgFEiQKHHBvbGxpbmdfaW5hY3RpdmVfZGVsYXlfbXNlY3MYByAB", "KAUSHAoUZm9yY2VfZ3B1X2NvbXBhdGlibGUYCCABKAgSOQoMZXhwZXJpbWVu", "dGFsGAkgASgLMiMudGVuc29yZmxvdy5HUFVPcHRpb25zLkV4cGVyaW1lbnRh", - "bBqCAwoMRXhwZXJpbWVudGFsEksKD3ZpcnR1YWxfZGV2aWNlcxgBIAMoCzIy", + "bBrKAwoMRXhwZXJpbWVudGFsEksKD3ZpcnR1YWxfZGV2aWNlcxgBIAMoCzIy", "LnRlbnNvcmZsb3cuR1BVT3B0aW9ucy5FeHBlcmltZW50YWwuVmlydHVhbERl", "dmljZXMSGgoSdXNlX3VuaWZpZWRfbWVtb3J5GAIgASgIEiMKG251bV9kZXZf", "dG9fZGV2X2NvcHlfc3RyZWFtcxgDIAEoBRIdChVjb2xsZWN0aXZlX3Jpbmdf", "b3JkZXIYBCABKAkSHQoVdGltZXN0YW1wZWRfYWxsb2NhdG9yGAUgASgIEiMK", "G2tlcm5lbF90cmFja2VyX21heF9pbnRlcnZhbBgHIAEoBRIgChhrZXJuZWxf", "dHJhY2tlcl9tYXhfYnl0ZXMYCCABKAUSIgoaa2VybmVsX3RyYWNrZXJfbWF4", - "X3BlbmRpbmcYCSABKAUaOwoOVmlydHVhbERldmljZXMSFwoPbWVtb3J5X2xp", - "bWl0X21iGAEgAygCEhAKCHByaW9yaXR5GAIgAygFIoUDChBPcHRpbWl6ZXJP", - "cHRpb25zEisKI2RvX2NvbW1vbl9zdWJleHByZXNzaW9uX2VsaW1pbmF0aW9u", - "GAEgASgIEhsKE2RvX2NvbnN0YW50X2ZvbGRpbmcYAiABKAgSJAocbWF4X2Zv", - "bGRlZF9jb25zdGFudF9pbl9ieXRlcxgGIAEoAxIcChRkb19mdW5jdGlvbl9p", - "bmxpbmluZxgEIAEoCBI1CglvcHRfbGV2ZWwYAyABKA4yIi50ZW5zb3JmbG93", - "Lk9wdGltaXplck9wdGlvbnMuTGV2ZWwSRQoQZ2xvYmFsX2ppdF9sZXZlbBgF", - "IAEoDjIrLnRlbnNvcmZsb3cuT3B0aW1pemVyT3B0aW9ucy5HbG9iYWxKaXRM", - "ZXZlbCIgCgVMZXZlbBIGCgJMMRAAEg8KAkwwEP///////////wEiQwoOR2xv", - "YmFsSml0TGV2ZWwSCwoHREVGQVVMVBAAEhAKA09GRhD///////////8BEggK", - "BE9OXzEQARIICgRPTl8yEAIi7gIKDEdyYXBoT3B0aW9ucxIeChZlbmFibGVf", - "cmVjdl9zY2hlZHVsaW5nGAIgASgIEjcKEW9wdGltaXplcl9vcHRpb25zGAMg", - "ASgLMhwudGVuc29yZmxvdy5PcHRpbWl6ZXJPcHRpb25zEhgKEGJ1aWxkX2Nv", - "c3RfbW9kZWwYBCABKAMSHgoWYnVpbGRfY29zdF9tb2RlbF9hZnRlchgJIAEo", - "AxIUCgxpbmZlcl9zaGFwZXMYBSABKAgSGgoScGxhY2VfcHJ1bmVkX2dyYXBo", - "GAYgASgIEiAKGGVuYWJsZV9iZmxvYXQxNl9zZW5kcmVjdhgHIAEoCBIVCg10", - "aW1lbGluZV9zdGVwGAggASgFEjMKD3Jld3JpdGVfb3B0aW9ucxgKIAEoCzIa", - "LnRlbnNvcmZsb3cuUmV3cml0ZXJDb25maWdKBAgBEAJSJXNraXBfY29tbW9u", - "X3N1YmV4cHJlc3Npb25fZWxpbWluYXRpb24iQQoVVGhyZWFkUG9vbE9wdGlv", - "blByb3RvEhMKC251bV90aHJlYWRzGAEgASgFEhMKC2dsb2JhbF9uYW1lGAIg", - "ASgJIrQBCgpSUENPcHRpb25zEiQKHHVzZV9ycGNfZm9yX2lucHJvY2Vzc19t", - "YXN0ZXIYASABKAgSHQoVY29tcHJlc3Npb25fYWxnb3JpdGhtGAIgASgJEhkK", - "EWNvbXByZXNzaW9uX2xldmVsGAMgASgFEhoKEmNhY2hlX3JwY19yZXNwb25z", - "ZRgEIAEoCBIqCiJkaXNhYmxlX3Nlc3Npb25fY29ubmVjdGlvbl9zaGFyaW5n", - "GAUgASgIIjAKD1Nlc3Npb25NZXRhZGF0YRIMCgRuYW1lGAEgASgJEg8KB3Zl", - "cnNpb24YAiABKAMijA0KC0NvbmZpZ1Byb3RvEj4KDGRldmljZV9jb3VudBgB", - "IAMoCzIoLnRlbnNvcmZsb3cuQ29uZmlnUHJvdG8uRGV2aWNlQ291bnRFbnRy", - "eRIkChxpbnRyYV9vcF9wYXJhbGxlbGlzbV90aHJlYWRzGAIgASgFEiQKHGlu", - "dGVyX29wX3BhcmFsbGVsaXNtX3RocmVhZHMYBSABKAUSHwoXdXNlX3Blcl9z", - "ZXNzaW9uX3RocmVhZHMYCSABKAgSRwocc2Vzc2lvbl9pbnRlcl9vcF90aHJl", - "YWRfcG9vbBgMIAMoCzIhLnRlbnNvcmZsb3cuVGhyZWFkUG9vbE9wdGlvblBy", - "b3RvEhgKEHBsYWNlbWVudF9wZXJpb2QYAyABKAUSFgoOZGV2aWNlX2ZpbHRl", - "cnMYBCADKAkSKwoLZ3B1X29wdGlvbnMYBiABKAsyFi50ZW5zb3JmbG93LkdQ", - "VU9wdGlvbnMSHAoUYWxsb3dfc29mdF9wbGFjZW1lbnQYByABKAgSHAoUbG9n", - "X2RldmljZV9wbGFjZW1lbnQYCCABKAgSLwoNZ3JhcGhfb3B0aW9ucxgKIAEo", - "CzIYLnRlbnNvcmZsb3cuR3JhcGhPcHRpb25zEh8KF29wZXJhdGlvbl90aW1l", - "b3V0X2luX21zGAsgASgDEisKC3JwY19vcHRpb25zGA0gASgLMhYudGVuc29y", - "Zmxvdy5SUENPcHRpb25zEisKC2NsdXN0ZXJfZGVmGA4gASgLMhYudGVuc29y", - "Zmxvdy5DbHVzdGVyRGVmEh0KFWlzb2xhdGVfc2Vzc2lvbl9zdGF0ZRgPIAEo", - "CBIoCiBzaGFyZV9jbHVzdGVyX2RldmljZXNfaW5fc2Vzc2lvbhgRIAEoCBI6", - "CgxleHBlcmltZW50YWwYECABKAsyJC50ZW5zb3JmbG93LkNvbmZpZ1Byb3Rv", - "LkV4cGVyaW1lbnRhbBoyChBEZXZpY2VDb3VudEVudHJ5EgsKA2tleRgBIAEo", - "CRINCgV2YWx1ZRgCIAEoBToCOAEahgcKDEV4cGVyaW1lbnRhbBIfChdjb2xs", - "ZWN0aXZlX2dyb3VwX2xlYWRlchgBIAEoCRIVCg1leGVjdXRvcl90eXBlGAMg", - "ASgJEhoKEnJlY3ZfYnVmX21heF9jaHVuaxgEIAEoBRIZChF1c2VfbnVtYV9h", - "ZmZpbml0eRgFIAEoCBI1Ci1jb2xsZWN0aXZlX2RldGVybWluaXN0aWNfc2Vx", - "dWVudGlhbF9leGVjdXRpb24YBiABKAgSFwoPY29sbGVjdGl2ZV9uY2NsGAcg", - "ASgIEjYKLnNoYXJlX3Nlc3Npb25fc3RhdGVfaW5fY2x1c3RlcnNwZWNfcHJv", - "cGFnYXRpb24YCCABKAgSHwoXZGlzYWJsZV90aHJlYWRfc3Bpbm5pbmcYCSAB", - "KAgSKAogc2hhcmVfY2x1c3Rlcl9kZXZpY2VzX2luX3Nlc3Npb24YCiABKAgS", - "NQoQc2Vzc2lvbl9tZXRhZGF0YRgLIAEoCzIbLnRlbnNvcmZsb3cuU2Vzc2lv", - "bk1ldGFkYXRhEiEKGW9wdGltaXplX2Zvcl9zdGF0aWNfZ3JhcGgYDCABKAgS", - "GgoSZW5hYmxlX21saXJfYnJpZGdlGA0gASgIElMKE21saXJfYnJpZGdlX3Jv", - "bGxvdXQYESABKA4yNi50ZW5zb3JmbG93LkNvbmZpZ1Byb3RvLkV4cGVyaW1l", - "bnRhbC5NbGlyQnJpZGdlUm9sbG91dBImCh5lbmFibGVfbWxpcl9ncmFwaF9v", - "cHRpbWl6YXRpb24YECABKAgSJwofZGlzYWJsZV9vdXRwdXRfcGFydGl0aW9u", - "X2dyYXBocxgOIAEoCBIjCht4bGFfZnVzaW9uX2F1dG90dW5lcl90aHJlc2gY", - "DyABKAMSEAoIdXNlX3RmcnQYEiABKAgi2gEKEU1saXJCcmlkZ2VSb2xsb3V0", - "EiMKH01MSVJfQlJJREdFX1JPTExPVVRfVU5TUEVDSUZJRUQQABIfChtNTElS", - "X0JSSURHRV9ST0xMT1VUX0VOQUJMRUQQARIgChxNTElSX0JSSURHRV9ST0xM", - "T1VUX0RJU0FCTEVEEAISKQolTUxJUl9CUklER0VfUk9MTE9VVF9TQUZFX01P", - "REVfRU5BQkxFRBADEjIKLk1MSVJfQlJJREdFX1JPTExPVVRfU0FGRV9NT0RF", - "X0ZBTExCQUNLX0VOQUJMRUQQBEoECAIQAyLhBAoKUnVuT3B0aW9ucxI2Cgt0", - "cmFjZV9sZXZlbBgBIAEoDjIhLnRlbnNvcmZsb3cuUnVuT3B0aW9ucy5UcmFj", - "ZUxldmVsEhUKDXRpbWVvdXRfaW5fbXMYAiABKAMSHAoUaW50ZXJfb3BfdGhy", - "ZWFkX3Bvb2wYAyABKAUSHwoXb3V0cHV0X3BhcnRpdGlvbl9ncmFwaHMYBSAB", - "KAgSLwoNZGVidWdfb3B0aW9ucxgGIAEoCzIYLnRlbnNvcmZsb3cuRGVidWdP", - "cHRpb25zEioKInJlcG9ydF90ZW5zb3JfYWxsb2NhdGlvbnNfdXBvbl9vb20Y", - "ByABKAgSOQoMZXhwZXJpbWVudGFsGAggASgLMiMudGVuc29yZmxvdy5SdW5P", - "cHRpb25zLkV4cGVyaW1lbnRhbBrSAQoMRXhwZXJpbWVudGFsEhwKFGNvbGxl", - "Y3RpdmVfZ3JhcGhfa2V5GAEgASgDEhwKFHVzZV9ydW5faGFuZGxlcl9wb29s", - "GAIgASgIElsKGHJ1bl9oYW5kbGVyX3Bvb2xfb3B0aW9ucxgDIAEoCzI5LnRl", - "bnNvcmZsb3cuUnVuT3B0aW9ucy5FeHBlcmltZW50YWwuUnVuSGFuZGxlclBv", - "b2xPcHRpb25zGikKFVJ1bkhhbmRsZXJQb29sT3B0aW9ucxIQCghwcmlvcml0", - "eRgBIAEoAyJSCgpUcmFjZUxldmVsEgwKCE5PX1RSQUNFEAASEgoOU09GVFdB", - "UkVfVFJBQ0UQARISCg5IQVJEV0FSRV9UUkFDRRACEg4KCkZVTExfVFJBQ0UQ", - "A0oECAQQBSKHAwoLUnVuTWV0YWRhdGESKQoKc3RlcF9zdGF0cxgBIAEoCzIV", - "LnRlbnNvcmZsb3cuU3RlcFN0YXRzEiwKCmNvc3RfZ3JhcGgYAiABKAsyGC50", - "ZW5zb3JmbG93LkNvc3RHcmFwaERlZhIuChBwYXJ0aXRpb25fZ3JhcGhzGAMg", - "AygLMhQudGVuc29yZmxvdy5HcmFwaERlZhI/Cg9mdW5jdGlvbl9ncmFwaHMY", - "BCADKAsyJi50ZW5zb3JmbG93LlJ1bk1ldGFkYXRhLkZ1bmN0aW9uR3JhcGhz", - "Gq0BCg5GdW5jdGlvbkdyYXBocxIuChBwYXJ0aXRpb25fZ3JhcGhzGAEgAygL", - "MhQudGVuc29yZmxvdy5HcmFwaERlZhI0ChZwcmVfb3B0aW1pemF0aW9uX2dy", - "YXBoGAIgASgLMhQudGVuc29yZmxvdy5HcmFwaERlZhI1Chdwb3N0X29wdGlt", - "aXphdGlvbl9ncmFwaBgDIAEoCzIULnRlbnNvcmZsb3cuR3JhcGhEZWYiOgoQ", - "VGVuc29yQ29ubmVjdGlvbhITCgtmcm9tX3RlbnNvchgBIAEoCRIRCgl0b190", - "ZW5zb3IYAiABKAkisAMKD0NhbGxhYmxlT3B0aW9ucxIMCgRmZWVkGAEgAygJ", - "Eg0KBWZldGNoGAIgAygJEg4KBnRhcmdldBgDIAMoCRIrCgtydW5fb3B0aW9u", - "cxgEIAEoCzIWLnRlbnNvcmZsb3cuUnVuT3B0aW9ucxI3ChF0ZW5zb3JfY29u", - "bmVjdGlvbhgFIAMoCzIcLnRlbnNvcmZsb3cuVGVuc29yQ29ubmVjdGlvbhJC", - "CgxmZWVkX2RldmljZXMYBiADKAsyLC50ZW5zb3JmbG93LkNhbGxhYmxlT3B0", - "aW9ucy5GZWVkRGV2aWNlc0VudHJ5EkQKDWZldGNoX2RldmljZXMYByADKAsy", - "LS50ZW5zb3JmbG93LkNhbGxhYmxlT3B0aW9ucy5GZXRjaERldmljZXNFbnRy", - "eRIXCg9mZXRjaF9za2lwX3N5bmMYCCABKAgaMgoQRmVlZERldmljZXNFbnRy", - "eRILCgNrZXkYASABKAkSDQoFdmFsdWUYAiABKAk6AjgBGjMKEUZldGNoRGV2", - "aWNlc0VudHJ5EgsKA2tleRgBIAEoCRINCgV2YWx1ZRgCIAEoCToCOAFChAEK", - "GG9yZy50ZW5zb3JmbG93LmZyYW1ld29ya0IMQ29uZmlnUHJvdG9zUAFaVWdp", - "dGh1Yi5jb20vdGVuc29yZmxvdy90ZW5zb3JmbG93L3RlbnNvcmZsb3cvZ28v", - "Y29yZS9wcm90b2J1Zi9mb3JfY29yZV9wcm90b3NfZ29fcHJvdG/4AQFiBnBy", - "b3RvMw==")); + "X3BlbmRpbmcYCSABKAUSJwofaW50ZXJuYWxfZnJhZ21lbnRhdGlvbl9mcmFj", + "dGlvbhgKIAEoARIdChV1c2VfY3VkYV9tYWxsb2NfYXN5bmMYCyABKAgaOwoO", + "VmlydHVhbERldmljZXMSFwoPbWVtb3J5X2xpbWl0X21iGAEgAygCEhAKCHBy", + "aW9yaXR5GAIgAygFIoUDChBPcHRpbWl6ZXJPcHRpb25zEisKI2RvX2NvbW1v", + "bl9zdWJleHByZXNzaW9uX2VsaW1pbmF0aW9uGAEgASgIEhsKE2RvX2NvbnN0", + "YW50X2ZvbGRpbmcYAiABKAgSJAocbWF4X2ZvbGRlZF9jb25zdGFudF9pbl9i", + "eXRlcxgGIAEoAxIcChRkb19mdW5jdGlvbl9pbmxpbmluZxgEIAEoCBI1Cglv", + "cHRfbGV2ZWwYAyABKA4yIi50ZW5zb3JmbG93Lk9wdGltaXplck9wdGlvbnMu", + "TGV2ZWwSRQoQZ2xvYmFsX2ppdF9sZXZlbBgFIAEoDjIrLnRlbnNvcmZsb3cu", + "T3B0aW1pemVyT3B0aW9ucy5HbG9iYWxKaXRMZXZlbCIgCgVMZXZlbBIGCgJM", + "MRAAEg8KAkwwEP///////////wEiQwoOR2xvYmFsSml0TGV2ZWwSCwoHREVG", + "QVVMVBAAEhAKA09GRhD///////////8BEggKBE9OXzEQARIICgRPTl8yEAIi", + "7gIKDEdyYXBoT3B0aW9ucxIeChZlbmFibGVfcmVjdl9zY2hlZHVsaW5nGAIg", + "ASgIEjcKEW9wdGltaXplcl9vcHRpb25zGAMgASgLMhwudGVuc29yZmxvdy5P", + "cHRpbWl6ZXJPcHRpb25zEhgKEGJ1aWxkX2Nvc3RfbW9kZWwYBCABKAMSHgoW", + "YnVpbGRfY29zdF9tb2RlbF9hZnRlchgJIAEoAxIUCgxpbmZlcl9zaGFwZXMY", + "BSABKAgSGgoScGxhY2VfcHJ1bmVkX2dyYXBoGAYgASgIEiAKGGVuYWJsZV9i", + "ZmxvYXQxNl9zZW5kcmVjdhgHIAEoCBIVCg10aW1lbGluZV9zdGVwGAggASgF", + "EjMKD3Jld3JpdGVfb3B0aW9ucxgKIAEoCzIaLnRlbnNvcmZsb3cuUmV3cml0", + "ZXJDb25maWdKBAgBEAJSJXNraXBfY29tbW9uX3N1YmV4cHJlc3Npb25fZWxp", + "bWluYXRpb24iQQoVVGhyZWFkUG9vbE9wdGlvblByb3RvEhMKC251bV90aHJl", + "YWRzGAEgASgFEhMKC2dsb2JhbF9uYW1lGAIgASgJItUBCgpSUENPcHRpb25z", + "EiQKHHVzZV9ycGNfZm9yX2lucHJvY2Vzc19tYXN0ZXIYASABKAgSHQoVY29t", + "cHJlc3Npb25fYWxnb3JpdGhtGAIgASgJEhkKEWNvbXByZXNzaW9uX2xldmVs", + "GAMgASgFEhoKEmNhY2hlX3JwY19yZXNwb25zZRgEIAEoCBIqCiJkaXNhYmxl", + "X3Nlc3Npb25fY29ubmVjdGlvbl9zaGFyaW5nGAUgASgIEh8KF251bV9jaGFu", + "bmVsc19wZXJfdGFyZ2V0GAYgASgFIjAKD1Nlc3Npb25NZXRhZGF0YRIMCgRu", + "YW1lGAEgASgJEg8KB3ZlcnNpb24YAiABKAMi2A0KC0NvbmZpZ1Byb3RvEj4K", + "DGRldmljZV9jb3VudBgBIAMoCzIoLnRlbnNvcmZsb3cuQ29uZmlnUHJvdG8u", + "RGV2aWNlQ291bnRFbnRyeRIkChxpbnRyYV9vcF9wYXJhbGxlbGlzbV90aHJl", + "YWRzGAIgASgFEiQKHGludGVyX29wX3BhcmFsbGVsaXNtX3RocmVhZHMYBSAB", + "KAUSHwoXdXNlX3Blcl9zZXNzaW9uX3RocmVhZHMYCSABKAgSRwocc2Vzc2lv", + "bl9pbnRlcl9vcF90aHJlYWRfcG9vbBgMIAMoCzIhLnRlbnNvcmZsb3cuVGhy", + "ZWFkUG9vbE9wdGlvblByb3RvEhgKEHBsYWNlbWVudF9wZXJpb2QYAyABKAUS", + "FgoOZGV2aWNlX2ZpbHRlcnMYBCADKAkSKwoLZ3B1X29wdGlvbnMYBiABKAsy", + "Fi50ZW5zb3JmbG93LkdQVU9wdGlvbnMSHAoUYWxsb3dfc29mdF9wbGFjZW1l", + "bnQYByABKAgSHAoUbG9nX2RldmljZV9wbGFjZW1lbnQYCCABKAgSLwoNZ3Jh", + "cGhfb3B0aW9ucxgKIAEoCzIYLnRlbnNvcmZsb3cuR3JhcGhPcHRpb25zEh8K", + "F29wZXJhdGlvbl90aW1lb3V0X2luX21zGAsgASgDEisKC3JwY19vcHRpb25z", + "GA0gASgLMhYudGVuc29yZmxvdy5SUENPcHRpb25zEisKC2NsdXN0ZXJfZGVm", + "GA4gASgLMhYudGVuc29yZmxvdy5DbHVzdGVyRGVmEh0KFWlzb2xhdGVfc2Vz", + "c2lvbl9zdGF0ZRgPIAEoCBIoCiBzaGFyZV9jbHVzdGVyX2RldmljZXNfaW5f", + "c2Vzc2lvbhgRIAEoCBI6CgxleHBlcmltZW50YWwYECABKAsyJC50ZW5zb3Jm", + "bG93LkNvbmZpZ1Byb3RvLkV4cGVyaW1lbnRhbBoyChBEZXZpY2VDb3VudEVu", + "dHJ5EgsKA2tleRgBIAEoCRINCgV2YWx1ZRgCIAEoBToCOAEa0gcKDEV4cGVy", + "aW1lbnRhbBIfChdjb2xsZWN0aXZlX2dyb3VwX2xlYWRlchgBIAEoCRIVCg1l", + "eGVjdXRvcl90eXBlGAMgASgJEhoKEnJlY3ZfYnVmX21heF9jaHVuaxgEIAEo", + "BRIZChF1c2VfbnVtYV9hZmZpbml0eRgFIAEoCBI1Ci1jb2xsZWN0aXZlX2Rl", + "dGVybWluaXN0aWNfc2VxdWVudGlhbF9leGVjdXRpb24YBiABKAgSFwoPY29s", + "bGVjdGl2ZV9uY2NsGAcgASgIEjYKLnNoYXJlX3Nlc3Npb25fc3RhdGVfaW5f", + "Y2x1c3RlcnNwZWNfcHJvcGFnYXRpb24YCCABKAgSHwoXZGlzYWJsZV90aHJl", + "YWRfc3Bpbm5pbmcYCSABKAgSKAogc2hhcmVfY2x1c3Rlcl9kZXZpY2VzX2lu", + "X3Nlc3Npb24YCiABKAgSNQoQc2Vzc2lvbl9tZXRhZGF0YRgLIAEoCzIbLnRl", + "bnNvcmZsb3cuU2Vzc2lvbk1ldGFkYXRhEiEKGW9wdGltaXplX2Zvcl9zdGF0", + "aWNfZ3JhcGgYDCABKAgSGgoSZW5hYmxlX21saXJfYnJpZGdlGA0gASgIElMK", + "E21saXJfYnJpZGdlX3JvbGxvdXQYESABKA4yNi50ZW5zb3JmbG93LkNvbmZp", + "Z1Byb3RvLkV4cGVyaW1lbnRhbC5NbGlyQnJpZGdlUm9sbG91dBImCh5lbmFi", + "bGVfbWxpcl9ncmFwaF9vcHRpbWl6YXRpb24YECABKAgSJwofZGlzYWJsZV9v", + "dXRwdXRfcGFydGl0aW9uX2dyYXBocxgOIAEoCBIjCht4bGFfZnVzaW9uX2F1", + "dG90dW5lcl90aHJlc2gYDyABKAMSEAoIdXNlX3RmcnQYEiABKAgSHAoUY29v", + "cmRpbmF0aW9uX3NlcnZpY2UYEyABKAkSLAokZmV0Y2hfcmVtb3RlX2Rldmlj", + "ZXNfaW5fbXVsdGlfY2xpZW50GBQgASgIItoBChFNbGlyQnJpZGdlUm9sbG91", + "dBIjCh9NTElSX0JSSURHRV9ST0xMT1VUX1VOU1BFQ0lGSUVEEAASHwobTUxJ", + "Ul9CUklER0VfUk9MTE9VVF9FTkFCTEVEEAESIAocTUxJUl9CUklER0VfUk9M", + "TE9VVF9ESVNBQkxFRBACEikKJU1MSVJfQlJJREdFX1JPTExPVVRfU0FGRV9N", + "T0RFX0VOQUJMRUQQAxIyCi5NTElSX0JSSURHRV9ST0xMT1VUX1NBRkVfTU9E", + "RV9GQUxMQkFDS19FTkFCTEVEEARKBAgCEAMi4QQKClJ1bk9wdGlvbnMSNgoL", + "dHJhY2VfbGV2ZWwYASABKA4yIS50ZW5zb3JmbG93LlJ1bk9wdGlvbnMuVHJh", + "Y2VMZXZlbBIVCg10aW1lb3V0X2luX21zGAIgASgDEhwKFGludGVyX29wX3Ro", + "cmVhZF9wb29sGAMgASgFEh8KF291dHB1dF9wYXJ0aXRpb25fZ3JhcGhzGAUg", + "ASgIEi8KDWRlYnVnX29wdGlvbnMYBiABKAsyGC50ZW5zb3JmbG93LkRlYnVn", + "T3B0aW9ucxIqCiJyZXBvcnRfdGVuc29yX2FsbG9jYXRpb25zX3Vwb25fb29t", + "GAcgASgIEjkKDGV4cGVyaW1lbnRhbBgIIAEoCzIjLnRlbnNvcmZsb3cuUnVu", + "T3B0aW9ucy5FeHBlcmltZW50YWwa0gEKDEV4cGVyaW1lbnRhbBIcChRjb2xs", + "ZWN0aXZlX2dyYXBoX2tleRgBIAEoAxIcChR1c2VfcnVuX2hhbmRsZXJfcG9v", + "bBgCIAEoCBJbChhydW5faGFuZGxlcl9wb29sX29wdGlvbnMYAyABKAsyOS50", + "ZW5zb3JmbG93LlJ1bk9wdGlvbnMuRXhwZXJpbWVudGFsLlJ1bkhhbmRsZXJQ", + "b29sT3B0aW9ucxopChVSdW5IYW5kbGVyUG9vbE9wdGlvbnMSEAoIcHJpb3Jp", + "dHkYASABKAMiUgoKVHJhY2VMZXZlbBIMCghOT19UUkFDRRAAEhIKDlNPRlRX", + "QVJFX1RSQUNFEAESEgoOSEFSRFdBUkVfVFJBQ0UQAhIOCgpGVUxMX1RSQUNF", + "EANKBAgEEAUihwMKC1J1bk1ldGFkYXRhEikKCnN0ZXBfc3RhdHMYASABKAsy", + "FS50ZW5zb3JmbG93LlN0ZXBTdGF0cxIsCgpjb3N0X2dyYXBoGAIgASgLMhgu", + "dGVuc29yZmxvdy5Db3N0R3JhcGhEZWYSLgoQcGFydGl0aW9uX2dyYXBocxgD", + "IAMoCzIULnRlbnNvcmZsb3cuR3JhcGhEZWYSPwoPZnVuY3Rpb25fZ3JhcGhz", + "GAQgAygLMiYudGVuc29yZmxvdy5SdW5NZXRhZGF0YS5GdW5jdGlvbkdyYXBo", + "cxqtAQoORnVuY3Rpb25HcmFwaHMSLgoQcGFydGl0aW9uX2dyYXBocxgBIAMo", + "CzIULnRlbnNvcmZsb3cuR3JhcGhEZWYSNAoWcHJlX29wdGltaXphdGlvbl9n", + "cmFwaBgCIAEoCzIULnRlbnNvcmZsb3cuR3JhcGhEZWYSNQoXcG9zdF9vcHRp", + "bWl6YXRpb25fZ3JhcGgYAyABKAsyFC50ZW5zb3JmbG93LkdyYXBoRGVmIjoK", + "EFRlbnNvckNvbm5lY3Rpb24SEwoLZnJvbV90ZW5zb3IYASABKAkSEQoJdG9f", + "dGVuc29yGAIgASgJIrADCg9DYWxsYWJsZU9wdGlvbnMSDAoEZmVlZBgBIAMo", + "CRINCgVmZXRjaBgCIAMoCRIOCgZ0YXJnZXQYAyADKAkSKwoLcnVuX29wdGlv", + "bnMYBCABKAsyFi50ZW5zb3JmbG93LlJ1bk9wdGlvbnMSNwoRdGVuc29yX2Nv", + "bm5lY3Rpb24YBSADKAsyHC50ZW5zb3JmbG93LlRlbnNvckNvbm5lY3Rpb24S", + "QgoMZmVlZF9kZXZpY2VzGAYgAygLMiwudGVuc29yZmxvdy5DYWxsYWJsZU9w", + "dGlvbnMuRmVlZERldmljZXNFbnRyeRJECg1mZXRjaF9kZXZpY2VzGAcgAygL", + "Mi0udGVuc29yZmxvdy5DYWxsYWJsZU9wdGlvbnMuRmV0Y2hEZXZpY2VzRW50", + "cnkSFwoPZmV0Y2hfc2tpcF9zeW5jGAggASgIGjIKEEZlZWREZXZpY2VzRW50", + "cnkSCwoDa2V5GAEgASgJEg0KBXZhbHVlGAIgASgJOgI4ARozChFGZXRjaERl", + "dmljZXNFbnRyeRILCgNrZXkYASABKAkSDQoFdmFsdWUYAiABKAk6AjgBQoQB", + "ChhvcmcudGVuc29yZmxvdy5mcmFtZXdvcmtCDENvbmZpZ1Byb3Rvc1ABWlVn", + "aXRodWIuY29tL3RlbnNvcmZsb3cvdGVuc29yZmxvdy90ZW5zb3JmbG93L2dv", + "L2NvcmUvcHJvdG9idWYvZm9yX2NvcmVfcHJvdG9zX2dvX3Byb3Rv+AEBYgZw", + "cm90bzM=")); descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, new pbr::FileDescriptor[] { global::Tensorflow.CostGraphReflection.Descriptor, global::Tensorflow.GraphReflection.Descriptor, global::Tensorflow.StepStatsReflection.Descriptor, global::Tensorflow.ClusterReflection.Descriptor, global::Tensorflow.DebugReflection.Descriptor, global::Tensorflow.RewriterConfigReflection.Descriptor, }, new pbr::GeneratedClrTypeInfo(null, null, new pbr::GeneratedClrTypeInfo[] { - new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.GPUOptions), global::Tensorflow.GPUOptions.Parser, new[]{ "PerProcessGpuMemoryFraction", "AllowGrowth", "AllocatorType", "DeferredDeletionBytes", "VisibleDeviceList", "PollingActiveDelayUsecs", "PollingInactiveDelayMsecs", "ForceGpuCompatible", "Experimental" }, null, null, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.GPUOptions.Types.Experimental), global::Tensorflow.GPUOptions.Types.Experimental.Parser, new[]{ "VirtualDevices", "UseUnifiedMemory", "NumDevToDevCopyStreams", "CollectiveRingOrder", "TimestampedAllocator", "KernelTrackerMaxInterval", "KernelTrackerMaxBytes", "KernelTrackerMaxPending" }, null, null, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.GPUOptions.Types.Experimental.Types.VirtualDevices), global::Tensorflow.GPUOptions.Types.Experimental.Types.VirtualDevices.Parser, new[]{ "MemoryLimitMb", "Priority" }, null, null, null, null)})}), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.GPUOptions), global::Tensorflow.GPUOptions.Parser, new[]{ "PerProcessGpuMemoryFraction", "AllowGrowth", "AllocatorType", "DeferredDeletionBytes", "VisibleDeviceList", "PollingActiveDelayUsecs", "PollingInactiveDelayMsecs", "ForceGpuCompatible", "Experimental" }, null, null, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.GPUOptions.Types.Experimental), global::Tensorflow.GPUOptions.Types.Experimental.Parser, new[]{ "VirtualDevices", "UseUnifiedMemory", "NumDevToDevCopyStreams", "CollectiveRingOrder", "TimestampedAllocator", "KernelTrackerMaxInterval", "KernelTrackerMaxBytes", "KernelTrackerMaxPending", "InternalFragmentationFraction", "UseCudaMallocAsync" }, null, null, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.GPUOptions.Types.Experimental.Types.VirtualDevices), global::Tensorflow.GPUOptions.Types.Experimental.Types.VirtualDevices.Parser, new[]{ "MemoryLimitMb", "Priority" }, null, null, null, null)})}), new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.OptimizerOptions), global::Tensorflow.OptimizerOptions.Parser, new[]{ "DoCommonSubexpressionElimination", "DoConstantFolding", "MaxFoldedConstantInBytes", "DoFunctionInlining", "OptLevel", "GlobalJitLevel" }, null, new[]{ typeof(global::Tensorflow.OptimizerOptions.Types.Level), typeof(global::Tensorflow.OptimizerOptions.Types.GlobalJitLevel) }, null, null), new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.GraphOptions), global::Tensorflow.GraphOptions.Parser, new[]{ "EnableRecvScheduling", "OptimizerOptions", "BuildCostModel", "BuildCostModelAfter", "InferShapes", "PlacePrunedGraph", "EnableBfloat16Sendrecv", "TimelineStep", "RewriteOptions" }, null, null, null, null), new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.ThreadPoolOptionProto), global::Tensorflow.ThreadPoolOptionProto.Parser, new[]{ "NumThreads", "GlobalName" }, null, null, null, null), - new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.RPCOptions), global::Tensorflow.RPCOptions.Parser, new[]{ "UseRpcForInprocessMaster", "CompressionAlgorithm", "CompressionLevel", "CacheRpcResponse", "DisableSessionConnectionSharing" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.RPCOptions), global::Tensorflow.RPCOptions.Parser, new[]{ "UseRpcForInprocessMaster", "CompressionAlgorithm", "CompressionLevel", "CacheRpcResponse", "DisableSessionConnectionSharing", "NumChannelsPerTarget" }, null, null, null, null), new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.SessionMetadata), global::Tensorflow.SessionMetadata.Parser, new[]{ "Name", "Version" }, null, null, null, null), - new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.ConfigProto), global::Tensorflow.ConfigProto.Parser, new[]{ "DeviceCount", "IntraOpParallelismThreads", "InterOpParallelismThreads", "UsePerSessionThreads", "SessionInterOpThreadPool", "PlacementPeriod", "DeviceFilters", "GpuOptions", "AllowSoftPlacement", "LogDevicePlacement", "GraphOptions", "OperationTimeoutInMs", "RpcOptions", "ClusterDef", "IsolateSessionState", "ShareClusterDevicesInSession", "Experimental" }, null, null, null, new pbr::GeneratedClrTypeInfo[] { null, new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.ConfigProto.Types.Experimental), global::Tensorflow.ConfigProto.Types.Experimental.Parser, new[]{ "CollectiveGroupLeader", "ExecutorType", "RecvBufMaxChunk", "UseNumaAffinity", "CollectiveDeterministicSequentialExecution", "CollectiveNccl", "ShareSessionStateInClusterspecPropagation", "DisableThreadSpinning", "ShareClusterDevicesInSession", "SessionMetadata", "OptimizeForStaticGraph", "EnableMlirBridge", "MlirBridgeRollout", "EnableMlirGraphOptimization", "DisableOutputPartitionGraphs", "XlaFusionAutotunerThresh", "UseTfrt" }, null, new[]{ typeof(global::Tensorflow.ConfigProto.Types.Experimental.Types.MlirBridgeRollout) }, null, null)}), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.ConfigProto), global::Tensorflow.ConfigProto.Parser, new[]{ "DeviceCount", "IntraOpParallelismThreads", "InterOpParallelismThreads", "UsePerSessionThreads", "SessionInterOpThreadPool", "PlacementPeriod", "DeviceFilters", "GpuOptions", "AllowSoftPlacement", "LogDevicePlacement", "GraphOptions", "OperationTimeoutInMs", "RpcOptions", "ClusterDef", "IsolateSessionState", "ShareClusterDevicesInSession", "Experimental" }, null, null, null, new pbr::GeneratedClrTypeInfo[] { null, new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.ConfigProto.Types.Experimental), global::Tensorflow.ConfigProto.Types.Experimental.Parser, new[]{ "CollectiveGroupLeader", "ExecutorType", "RecvBufMaxChunk", "UseNumaAffinity", "CollectiveDeterministicSequentialExecution", "CollectiveNccl", "ShareSessionStateInClusterspecPropagation", "DisableThreadSpinning", "ShareClusterDevicesInSession", "SessionMetadata", "OptimizeForStaticGraph", "EnableMlirBridge", "MlirBridgeRollout", "EnableMlirGraphOptimization", "DisableOutputPartitionGraphs", "XlaFusionAutotunerThresh", "UseTfrt", "CoordinationService", "FetchRemoteDevicesInMultiClient" }, null, new[]{ typeof(global::Tensorflow.ConfigProto.Types.Experimental.Types.MlirBridgeRollout) }, null, null)}), new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.RunOptions), global::Tensorflow.RunOptions.Parser, new[]{ "TraceLevel", "TimeoutInMs", "InterOpThreadPool", "OutputPartitionGraphs", "DebugOptions", "ReportTensorAllocationsUponOom", "Experimental" }, null, new[]{ typeof(global::Tensorflow.RunOptions.Types.TraceLevel) }, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.RunOptions.Types.Experimental), global::Tensorflow.RunOptions.Types.Experimental.Parser, new[]{ "CollectiveGraphKey", "UseRunHandlerPool", "RunHandlerPoolOptions" }, null, null, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.RunOptions.Types.Experimental.Types.RunHandlerPoolOptions), global::Tensorflow.RunOptions.Types.Experimental.Types.RunHandlerPoolOptions.Parser, new[]{ "Priority" }, null, null, null, null)})}), new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.RunMetadata), global::Tensorflow.RunMetadata.Parser, new[]{ "StepStats", "CostGraph", "PartitionGraphs", "FunctionGraphs" }, null, null, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.RunMetadata.Types.FunctionGraphs), global::Tensorflow.RunMetadata.Types.FunctionGraphs.Parser, new[]{ "PartitionGraphs", "PreOptimizationGraph", "PostOptimizationGraph" }, null, null, null, null)}), new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.TensorConnection), global::Tensorflow.TensorConnection.Parser, new[]{ "FromTensor", "ToTensor" }, null, null, null, null), @@ -645,6 +649,8 @@ namespace Tensorflow { kernelTrackerMaxInterval_ = other.kernelTrackerMaxInterval_; kernelTrackerMaxBytes_ = other.kernelTrackerMaxBytes_; kernelTrackerMaxPending_ = other.kernelTrackerMaxPending_; + internalFragmentationFraction_ = other.internalFragmentationFraction_; + useCudaMallocAsync_ = other.useCudaMallocAsync_; _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); } @@ -814,6 +820,42 @@ namespace Tensorflow { } } + /// Field number for the "internal_fragmentation_fraction" field. + public const int InternalFragmentationFractionFieldNumber = 10; + private double internalFragmentationFraction_; + /// + /// BFC Allocator can return an allocated chunk of memory upto 2x the + /// requested size. For virtual devices with tight memory constraints, and + /// proportionately large allocation requests, this can lead to a significant + /// reduction in available memory. The threshold below controls when a chunk + /// should be split if the chunk size exceeds requested memory size. It is + /// expressed as a fraction of total available memory for the tf device. For + /// example setting it to 0.05 would imply a chunk needs to be split if its + /// size exceeds the requested memory by 5% of the total virtual device/gpu + /// memory size. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public double InternalFragmentationFraction { + get { return internalFragmentationFraction_; } + set { + internalFragmentationFraction_ = value; + } + } + + /// Field number for the "use_cuda_malloc_async" field. + public const int UseCudaMallocAsyncFieldNumber = 11; + private bool useCudaMallocAsync_; + /// + /// When true, use CUDA cudaMallocAsync API instead of TF gpu allocator. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool UseCudaMallocAsync { + get { return useCudaMallocAsync_; } + set { + useCudaMallocAsync_ = value; + } + } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] public override bool Equals(object other) { return Equals(other as Experimental); @@ -835,6 +877,8 @@ namespace Tensorflow { if (KernelTrackerMaxInterval != other.KernelTrackerMaxInterval) return false; if (KernelTrackerMaxBytes != other.KernelTrackerMaxBytes) return false; if (KernelTrackerMaxPending != other.KernelTrackerMaxPending) return false; + if (!pbc::ProtobufEqualityComparers.BitwiseDoubleEqualityComparer.Equals(InternalFragmentationFraction, other.InternalFragmentationFraction)) return false; + if (UseCudaMallocAsync != other.UseCudaMallocAsync) return false; return Equals(_unknownFields, other._unknownFields); } @@ -849,6 +893,8 @@ namespace Tensorflow { if (KernelTrackerMaxInterval != 0) hash ^= KernelTrackerMaxInterval.GetHashCode(); if (KernelTrackerMaxBytes != 0) hash ^= KernelTrackerMaxBytes.GetHashCode(); if (KernelTrackerMaxPending != 0) hash ^= KernelTrackerMaxPending.GetHashCode(); + if (InternalFragmentationFraction != 0D) hash ^= pbc::ProtobufEqualityComparers.BitwiseDoubleEqualityComparer.GetHashCode(InternalFragmentationFraction); + if (UseCudaMallocAsync != false) hash ^= UseCudaMallocAsync.GetHashCode(); if (_unknownFields != null) { hash ^= _unknownFields.GetHashCode(); } @@ -891,6 +937,14 @@ namespace Tensorflow { output.WriteRawTag(72); output.WriteInt32(KernelTrackerMaxPending); } + if (InternalFragmentationFraction != 0D) { + output.WriteRawTag(81); + output.WriteDouble(InternalFragmentationFraction); + } + if (UseCudaMallocAsync != false) { + output.WriteRawTag(88); + output.WriteBool(UseCudaMallocAsync); + } if (_unknownFields != null) { _unknownFields.WriteTo(output); } @@ -921,6 +975,12 @@ namespace Tensorflow { if (KernelTrackerMaxPending != 0) { size += 1 + pb::CodedOutputStream.ComputeInt32Size(KernelTrackerMaxPending); } + if (InternalFragmentationFraction != 0D) { + size += 1 + 8; + } + if (UseCudaMallocAsync != false) { + size += 1 + 1; + } if (_unknownFields != null) { size += _unknownFields.CalculateSize(); } @@ -954,6 +1014,12 @@ namespace Tensorflow { if (other.KernelTrackerMaxPending != 0) { KernelTrackerMaxPending = other.KernelTrackerMaxPending; } + if (other.InternalFragmentationFraction != 0D) { + InternalFragmentationFraction = other.InternalFragmentationFraction; + } + if (other.UseCudaMallocAsync != false) { + UseCudaMallocAsync = other.UseCudaMallocAsync; + } _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); } @@ -997,6 +1063,14 @@ namespace Tensorflow { KernelTrackerMaxPending = input.ReadInt32(); break; } + case 81: { + InternalFragmentationFraction = input.ReadDouble(); + break; + } + case 88: { + UseCudaMallocAsync = input.ReadBool(); + break; + } } } } @@ -1231,6 +1305,9 @@ namespace Tensorflow { private bool doCommonSubexpressionElimination_; /// /// If true, optimize the graph using common subexpression elimination. + /// Note: the optimization Level L1 will override this setting to true. So in + /// order to disable common subexpression elimination the opt_level has to be + /// set to L0. /// [global::System.Diagnostics.DebuggerNonUserCodeAttribute] public bool DoCommonSubexpressionElimination { @@ -1245,6 +1322,8 @@ namespace Tensorflow { private bool doConstantFolding_; /// /// If true, perform constant folding optimization on the graph. + /// Note: the optimization Level L1 will override this setting to true. So in + /// order to disable constant folding the opt_level has to be set to L0. /// [global::System.Diagnostics.DebuggerNonUserCodeAttribute] public bool DoConstantFolding { @@ -2135,6 +2214,7 @@ namespace Tensorflow { compressionLevel_ = other.compressionLevel_; cacheRpcResponse_ = other.cacheRpcResponse_; disableSessionConnectionSharing_ = other.disableSessionConnectionSharing_; + numChannelsPerTarget_ = other.numChannelsPerTarget_; _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); } @@ -2223,6 +2303,25 @@ namespace Tensorflow { } } + /// Field number for the "num_channels_per_target" field. + public const int NumChannelsPerTargetFieldNumber = 6; + private int numChannelsPerTarget_; + /// + /// Setting num_channels_per_target > 0 allows uses of multiple channels to + /// communicate to the same target. This can be used to improve the aggregate + /// throughput on high speed links (e.g 100G) where single connection is not + /// sufficient to maximize link utilization. Note that a single RPC only goes + /// on a single channel, this only helps in situations where there are multiple + /// transfers to the same target overlapping in time. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int NumChannelsPerTarget { + get { return numChannelsPerTarget_; } + set { + numChannelsPerTarget_ = value; + } + } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] public override bool Equals(object other) { return Equals(other as RPCOptions); @@ -2241,6 +2340,7 @@ namespace Tensorflow { if (CompressionLevel != other.CompressionLevel) return false; if (CacheRpcResponse != other.CacheRpcResponse) return false; if (DisableSessionConnectionSharing != other.DisableSessionConnectionSharing) return false; + if (NumChannelsPerTarget != other.NumChannelsPerTarget) return false; return Equals(_unknownFields, other._unknownFields); } @@ -2252,6 +2352,7 @@ namespace Tensorflow { if (CompressionLevel != 0) hash ^= CompressionLevel.GetHashCode(); if (CacheRpcResponse != false) hash ^= CacheRpcResponse.GetHashCode(); if (DisableSessionConnectionSharing != false) hash ^= DisableSessionConnectionSharing.GetHashCode(); + if (NumChannelsPerTarget != 0) hash ^= NumChannelsPerTarget.GetHashCode(); if (_unknownFields != null) { hash ^= _unknownFields.GetHashCode(); } @@ -2285,6 +2386,10 @@ namespace Tensorflow { output.WriteRawTag(40); output.WriteBool(DisableSessionConnectionSharing); } + if (NumChannelsPerTarget != 0) { + output.WriteRawTag(48); + output.WriteInt32(NumChannelsPerTarget); + } if (_unknownFields != null) { _unknownFields.WriteTo(output); } @@ -2308,6 +2413,9 @@ namespace Tensorflow { if (DisableSessionConnectionSharing != false) { size += 1 + 1; } + if (NumChannelsPerTarget != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(NumChannelsPerTarget); + } if (_unknownFields != null) { size += _unknownFields.CalculateSize(); } @@ -2334,6 +2442,9 @@ namespace Tensorflow { if (other.DisableSessionConnectionSharing != false) { DisableSessionConnectionSharing = other.DisableSessionConnectionSharing; } + if (other.NumChannelsPerTarget != 0) { + NumChannelsPerTarget = other.NumChannelsPerTarget; + } _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); } @@ -2365,6 +2476,10 @@ namespace Tensorflow { DisableSessionConnectionSharing = input.ReadBool(); break; } + case 48: { + NumChannelsPerTarget = input.ReadInt32(); + break; + } } } } @@ -2624,7 +2739,7 @@ namespace Tensorflow { /// The first session created determines the number of threads in this pool. /// All subsequent sessions reuse/share this one global pool. /// - /// There are notable exceptions to the default behavior describe above: + /// There are notable exceptions to the default behavior described above: /// 1. There is an environment variable for overriding this thread pool, /// named TF_OVERRIDE_GLOBAL_THREADPOOL. /// 2. When connecting to a server, such as a remote `tf.train.Server` @@ -3292,6 +3407,8 @@ namespace Tensorflow { disableOutputPartitionGraphs_ = other.disableOutputPartitionGraphs_; xlaFusionAutotunerThresh_ = other.xlaFusionAutotunerThresh_; useTfrt_ = other.useTfrt_; + coordinationService_ = other.coordinationService_; + fetchRemoteDevicesInMultiClient_ = other.fetchRemoteDevicesInMultiClient_; _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); } @@ -3606,6 +3723,38 @@ namespace Tensorflow { } } + /// Field number for the "coordination_service" field. + public const int CoordinationServiceFieldNumber = 19; + private string coordinationService_ = ""; + /// + /// Distributed coordination service to be enabled if set. + /// Currently only effective in multi-client setup. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string CoordinationService { + get { return coordinationService_; } + set { + coordinationService_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "fetch_remote_devices_in_multi_client" field. + public const int FetchRemoteDevicesInMultiClientFieldNumber = 20; + private bool fetchRemoteDevicesInMultiClient_; + /// + /// Whether the remote devices in the cluster should be fetched during setup + /// of multi-client cluster. If enabled, the workers will run an extra device + /// information exchange step during startup and the workers' EagerContexts + /// will become aware of remote devices in the cluster as well. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool FetchRemoteDevicesInMultiClient { + get { return fetchRemoteDevicesInMultiClient_; } + set { + fetchRemoteDevicesInMultiClient_ = value; + } + } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] public override bool Equals(object other) { return Equals(other as Experimental); @@ -3636,6 +3785,8 @@ namespace Tensorflow { if (DisableOutputPartitionGraphs != other.DisableOutputPartitionGraphs) return false; if (XlaFusionAutotunerThresh != other.XlaFusionAutotunerThresh) return false; if (UseTfrt != other.UseTfrt) return false; + if (CoordinationService != other.CoordinationService) return false; + if (FetchRemoteDevicesInMultiClient != other.FetchRemoteDevicesInMultiClient) return false; return Equals(_unknownFields, other._unknownFields); } @@ -3659,6 +3810,8 @@ namespace Tensorflow { if (DisableOutputPartitionGraphs != false) hash ^= DisableOutputPartitionGraphs.GetHashCode(); if (XlaFusionAutotunerThresh != 0L) hash ^= XlaFusionAutotunerThresh.GetHashCode(); if (UseTfrt != false) hash ^= UseTfrt.GetHashCode(); + if (CoordinationService.Length != 0) hash ^= CoordinationService.GetHashCode(); + if (FetchRemoteDevicesInMultiClient != false) hash ^= FetchRemoteDevicesInMultiClient.GetHashCode(); if (_unknownFields != null) { hash ^= _unknownFields.GetHashCode(); } @@ -3740,6 +3893,14 @@ namespace Tensorflow { output.WriteRawTag(144, 1); output.WriteBool(UseTfrt); } + if (CoordinationService.Length != 0) { + output.WriteRawTag(154, 1); + output.WriteString(CoordinationService); + } + if (FetchRemoteDevicesInMultiClient != false) { + output.WriteRawTag(160, 1); + output.WriteBool(FetchRemoteDevicesInMultiClient); + } if (_unknownFields != null) { _unknownFields.WriteTo(output); } @@ -3799,6 +3960,12 @@ namespace Tensorflow { if (UseTfrt != false) { size += 2 + 1; } + if (CoordinationService.Length != 0) { + size += 2 + pb::CodedOutputStream.ComputeStringSize(CoordinationService); + } + if (FetchRemoteDevicesInMultiClient != false) { + size += 2 + 1; + } if (_unknownFields != null) { size += _unknownFields.CalculateSize(); } @@ -3864,6 +4031,12 @@ namespace Tensorflow { if (other.UseTfrt != false) { UseTfrt = other.UseTfrt; } + if (other.CoordinationService.Length != 0) { + CoordinationService = other.CoordinationService; + } + if (other.FetchRemoteDevicesInMultiClient != false) { + FetchRemoteDevicesInMultiClient = other.FetchRemoteDevicesInMultiClient; + } _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); } @@ -3946,6 +4119,14 @@ namespace Tensorflow { UseTfrt = input.ReadBool(); break; } + case 154: { + CoordinationService = input.ReadString(); + break; + } + case 160: { + FetchRemoteDevicesInMultiClient = input.ReadBool(); + break; + } } } } diff --git a/src/TensorFlowNET.Core/Protobuf/CppShapeInference.cs b/src/TensorFlowNET.Core/Protobuf/CppShapeInference.cs index 0503e546..f76bf2f0 100644 --- a/src/TensorFlowNET.Core/Protobuf/CppShapeInference.cs +++ b/src/TensorFlowNET.Core/Protobuf/CppShapeInference.cs @@ -25,27 +25,28 @@ namespace Tensorflow { byte[] descriptorData = global::System.Convert.FromBase64String( string.Concat( "CjV0ZW5zb3JmbG93L3B5dGhvbi9mcmFtZXdvcmsvY3BwX3NoYXBlX2luZmVy", - "ZW5jZS5wcm90bxIKdGVuc29yZmxvdxoldGVuc29yZmxvdy9jb3JlL2ZyYW1l", - "d29yay90eXBlcy5wcm90bxosdGVuc29yZmxvdy9jb3JlL2ZyYW1ld29yay90", - "ZW5zb3Jfc2hhcGUucHJvdG8ipQMKF0NwcFNoYXBlSW5mZXJlbmNlUmVzdWx0", + "ZW5jZS5wcm90bxIKdGVuc29yZmxvdxopdGVuc29yZmxvdy9jb3JlL2ZyYW1l", + "d29yay9mdWxsX3R5cGUucHJvdG8aLHRlbnNvcmZsb3cvY29yZS9mcmFtZXdv", + "cmsvdGVuc29yX3NoYXBlLnByb3RvGiV0ZW5zb3JmbG93L2NvcmUvZnJhbWV3", + "b3JrL3R5cGVzLnByb3RvIpsDChdDcHBTaGFwZUluZmVyZW5jZVJlc3VsdBIr", + "CgVzaGFwZRgBIAEoCzIcLnRlbnNvcmZsb3cuVGVuc29yU2hhcGVQcm90bxJD", + "CgtoYW5kbGVfZGF0YRgEIAEoCzIuLnRlbnNvcmZsb3cuQ3BwU2hhcGVJbmZl", + "cmVuY2VSZXN1bHQuSGFuZGxlRGF0YRqTAQoSSGFuZGxlU2hhcGVBbmRUeXBl", "EisKBXNoYXBlGAEgASgLMhwudGVuc29yZmxvdy5UZW5zb3JTaGFwZVByb3Rv", - "EkMKC2hhbmRsZV9kYXRhGAQgASgLMi4udGVuc29yZmxvdy5DcHBTaGFwZUlu", - "ZmVyZW5jZVJlc3VsdC5IYW5kbGVEYXRhGp0BChJIYW5kbGVTaGFwZUFuZFR5", - "cGUSKwoFc2hhcGUYASABKAsyHC50ZW5zb3JmbG93LlRlbnNvclNoYXBlUHJv", - "dG8SIwoFZHR5cGUYAiABKA4yFC50ZW5zb3JmbG93LkRhdGFUeXBlEjUKEHNw", - "ZWNpYWxpemVkX3R5cGUYAyABKA4yGy50ZW5zb3JmbG93LlNwZWNpYWxpemVk", - "VHlwZRpsCgpIYW5kbGVEYXRhEg4KBmlzX3NldBgBIAEoCBJOCg5zaGFwZV9h", - "bmRfdHlwZRgCIAMoCzI2LnRlbnNvcmZsb3cuQ3BwU2hhcGVJbmZlcmVuY2VS", - "ZXN1bHQuSGFuZGxlU2hhcGVBbmRUeXBlSgQIAhADSgQIAxAEImUKHUNwcFNo", - "YXBlSW5mZXJlbmNlSW5wdXRzTmVlZGVkEhwKFGlucHV0X3RlbnNvcnNfbmVl", - "ZGVkGAEgAygFEiYKHmlucHV0X3RlbnNvcnNfYXNfc2hhcGVzX25lZWRlZBgC", - "IAMoBUJhWlxnaXRodWIuY29tL3RlbnNvcmZsb3cvdGVuc29yZmxvdy90ZW5z", - "b3JmbG93L2dvL3B5dGhvbi9mcmFtZXdvcmsvY3BwX3NoYXBlX2luZmVyZW5j", - "ZV9nb19wcm90b/gBAWIGcHJvdG8z")); + "EiMKBWR0eXBlGAIgASgOMhQudGVuc29yZmxvdy5EYXRhVHlwZRIlCgR0eXBl", + "GAQgASgLMhcudGVuc29yZmxvdy5GdWxsVHlwZURlZkoECAMQBBpsCgpIYW5k", + "bGVEYXRhEg4KBmlzX3NldBgBIAEoCBJOCg5zaGFwZV9hbmRfdHlwZRgCIAMo", + "CzI2LnRlbnNvcmZsb3cuQ3BwU2hhcGVJbmZlcmVuY2VSZXN1bHQuSGFuZGxl", + "U2hhcGVBbmRUeXBlSgQIAhADSgQIAxAEImUKHUNwcFNoYXBlSW5mZXJlbmNl", + "SW5wdXRzTmVlZGVkEhwKFGlucHV0X3RlbnNvcnNfbmVlZGVkGAEgAygFEiYK", + "HmlucHV0X3RlbnNvcnNfYXNfc2hhcGVzX25lZWRlZBgCIAMoBUJhWlxnaXRo", + "dWIuY29tL3RlbnNvcmZsb3cvdGVuc29yZmxvdy90ZW5zb3JmbG93L2dvL3B5", + "dGhvbi9mcmFtZXdvcmsvY3BwX3NoYXBlX2luZmVyZW5jZV9nb19wcm90b/gB", + "AWIGcHJvdG8z")); descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, - new pbr::FileDescriptor[] { global::Tensorflow.TypesReflection.Descriptor, global::Tensorflow.TensorShapeReflection.Descriptor, }, + new pbr::FileDescriptor[] { global::Tensorflow.FullTypeReflection.Descriptor, global::Tensorflow.TensorShapeReflection.Descriptor, global::Tensorflow.TypesReflection.Descriptor, }, new pbr::GeneratedClrTypeInfo(null, null, new pbr::GeneratedClrTypeInfo[] { - new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.CppShapeInferenceResult), global::Tensorflow.CppShapeInferenceResult.Parser, new[]{ "Shape", "HandleData" }, null, null, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.CppShapeInferenceResult.Types.HandleShapeAndType), global::Tensorflow.CppShapeInferenceResult.Types.HandleShapeAndType.Parser, new[]{ "Shape", "Dtype", "SpecializedType" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.CppShapeInferenceResult), global::Tensorflow.CppShapeInferenceResult.Parser, new[]{ "Shape", "HandleData" }, null, null, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.CppShapeInferenceResult.Types.HandleShapeAndType), global::Tensorflow.CppShapeInferenceResult.Types.HandleShapeAndType.Parser, new[]{ "Shape", "Dtype", "Type" }, null, null, null, null), new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.CppShapeInferenceResult.Types.HandleData), global::Tensorflow.CppShapeInferenceResult.Types.HandleData.Parser, new[]{ "IsSet", "ShapeAndType" }, null, null, null, null)}), new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.CppShapeInferenceInputsNeeded), global::Tensorflow.CppShapeInferenceInputsNeeded.Parser, new[]{ "InputTensorsNeeded", "InputTensorsAsShapesNeeded" }, null, null, null, null) })); @@ -252,7 +253,7 @@ namespace Tensorflow { public HandleShapeAndType(HandleShapeAndType other) : this() { shape_ = other.shape_ != null ? other.shape_.Clone() : null; dtype_ = other.dtype_; - specializedType_ = other.specializedType_; + type_ = other.type_ != null ? other.type_.Clone() : null; _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); } @@ -283,19 +284,14 @@ namespace Tensorflow { } } - /// Field number for the "specialized_type" field. - public const int SpecializedTypeFieldNumber = 3; - private global::Tensorflow.SpecializedType specializedType_ = global::Tensorflow.SpecializedType.StInvalid; - /// - /// For dtype==DT_VARIANT, specialized_type may indicate a more specific - /// type. For other dtypes or when the information is unavailable it is set - /// to ST_INVALID. - /// + /// Field number for the "type" field. + public const int TypeFieldNumber = 4; + private global::Tensorflow.FullTypeDef type_; [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public global::Tensorflow.SpecializedType SpecializedType { - get { return specializedType_; } + public global::Tensorflow.FullTypeDef Type { + get { return type_; } set { - specializedType_ = value; + type_ = value; } } @@ -314,7 +310,7 @@ namespace Tensorflow { } if (!object.Equals(Shape, other.Shape)) return false; if (Dtype != other.Dtype) return false; - if (SpecializedType != other.SpecializedType) return false; + if (!object.Equals(Type, other.Type)) return false; return Equals(_unknownFields, other._unknownFields); } @@ -323,7 +319,7 @@ namespace Tensorflow { int hash = 1; if (shape_ != null) hash ^= Shape.GetHashCode(); if (Dtype != global::Tensorflow.DataType.DtInvalid) hash ^= Dtype.GetHashCode(); - if (SpecializedType != global::Tensorflow.SpecializedType.StInvalid) hash ^= SpecializedType.GetHashCode(); + if (type_ != null) hash ^= Type.GetHashCode(); if (_unknownFields != null) { hash ^= _unknownFields.GetHashCode(); } @@ -345,9 +341,9 @@ namespace Tensorflow { output.WriteRawTag(16); output.WriteEnum((int) Dtype); } - if (SpecializedType != global::Tensorflow.SpecializedType.StInvalid) { - output.WriteRawTag(24); - output.WriteEnum((int) SpecializedType); + if (type_ != null) { + output.WriteRawTag(34); + output.WriteMessage(Type); } if (_unknownFields != null) { _unknownFields.WriteTo(output); @@ -363,8 +359,8 @@ namespace Tensorflow { if (Dtype != global::Tensorflow.DataType.DtInvalid) { size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) Dtype); } - if (SpecializedType != global::Tensorflow.SpecializedType.StInvalid) { - size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) SpecializedType); + if (type_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(Type); } if (_unknownFields != null) { size += _unknownFields.CalculateSize(); @@ -386,8 +382,11 @@ namespace Tensorflow { if (other.Dtype != global::Tensorflow.DataType.DtInvalid) { Dtype = other.Dtype; } - if (other.SpecializedType != global::Tensorflow.SpecializedType.StInvalid) { - SpecializedType = other.SpecializedType; + if (other.type_ != null) { + if (type_ == null) { + Type = new global::Tensorflow.FullTypeDef(); + } + Type.MergeFrom(other.Type); } _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); } @@ -411,8 +410,11 @@ namespace Tensorflow { Dtype = (global::Tensorflow.DataType) input.ReadEnum(); break; } - case 24: { - SpecializedType = (global::Tensorflow.SpecializedType) input.ReadEnum(); + case 34: { + if (type_ == null) { + Type = new global::Tensorflow.FullTypeDef(); + } + input.ReadMessage(Type); break; } } diff --git a/src/TensorFlowNET.Core/Protobuf/FullType.cs b/src/TensorFlowNET.Core/Protobuf/FullType.cs new file mode 100644 index 00000000..a8b54b2a --- /dev/null +++ b/src/TensorFlowNET.Core/Protobuf/FullType.cs @@ -0,0 +1,450 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: tensorflow/core/framework/full_type.proto +// +#pragma warning disable 1591, 0612, 3021 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace Tensorflow { + + /// Holder for reflection information generated from tensorflow/core/framework/full_type.proto + public static partial class FullTypeReflection { + + #region Descriptor + /// File descriptor for tensorflow/core/framework/full_type.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static FullTypeReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "Cil0ZW5zb3JmbG93L2NvcmUvZnJhbWV3b3JrL2Z1bGxfdHlwZS5wcm90bxIK", + "dGVuc29yZmxvdyJyCgtGdWxsVHlwZURlZhInCgd0eXBlX2lkGAEgASgOMhYu", + "dGVuc29yZmxvdy5GdWxsVHlwZUlkEiUKBGFyZ3MYAiADKAsyFy50ZW5zb3Jm", + "bG93LkZ1bGxUeXBlRGVmEgsKAXMYAyABKAlIAEIGCgRhdHRyKqwDCgpGdWxs", + "VHlwZUlkEg0KCVRGVF9VTlNFVBAAEgsKB1RGVF9WQVIQARILCgdURlRfQU5Z", + "EAISDwoLVEZUX1BST0RVQ1QQAxIQCgxURlRfQ0FMTEFCTEUQZBIPCgpURlRf", + "VEVOU09SEOgHEg4KCVRGVF9BUlJBWRDpBxIRCgxURlRfT1BUSU9OQUwQ6gcS", + "EAoLVEZUX0RBVEFTRVQQ9k4SDQoIVEZUX0JPT0wQyAESDgoJVEZUX1VJTlQ4", + "EMkBEg8KClRGVF9VSU5UMTYQygESDwoKVEZUX1VJTlQzMhDLARIPCgpURlRf", + "VUlOVDY0EMwBEg0KCFRGVF9JTlQ4EM0BEg4KCVRGVF9JTlQxNhDOARIOCglU", + "RlRfSU5UMzIQzwESDgoJVEZUX0lOVDY0ENABEg0KCFRGVF9IQUxGENEBEg4K", + "CVRGVF9GTE9BVBDSARIPCgpURlRfRE9VQkxFENMBEhEKDFRGVF9CRkxPQVQx", + "NhDXARISCg1URlRfQ09NUExFWDY0ENQBEhMKDlRGVF9DT01QTEVYMTI4ENUB", + "Eg8KClRGVF9TVFJJTkcQ1gFCfQoYb3JnLnRlbnNvcmZsb3cuZnJhbWV3b3Jr", + "Qg5GdWxsVHlwZVByb3Rvc1ABWkxnaXRodWIuY29tL3RlbnNvcmZsb3cvdGVu", + "c29yZmxvdy90ZW5zb3JmbG93L2dvL2NvcmUvZnJhbWV3b3JrL3R5cGVzX2dv", + "X3Byb3Rv+AEBYgZwcm90bzM=")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { }, + new pbr::GeneratedClrTypeInfo(new[] {typeof(global::Tensorflow.FullTypeId), }, null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.FullTypeDef), global::Tensorflow.FullTypeDef.Parser, new[]{ "TypeId", "Args", "S" }, new[]{ "Attr" }, null, null, null) + })); + } + #endregion + + } + #region Enums + /// + /// Experimental. Represents the complete type information of a TensorFlow value. + /// + public enum FullTypeId { + /// + /// The default represents an uninitialized values. + /// + [pbr::OriginalName("TFT_UNSET")] TftUnset = 0, + /// + /// Type variables may serve as placeholder for any other type ID in type + /// templates. + /// + /// Examples: + /// TFT_DATASET[TFT_VAR["T"]] is a Dataset returning a type indicated by "T". + /// TFT_TENSOR[TFT_VAR["T"]] is a Tensor of n element type indicated by "T". + /// TFT_TENSOR[TFT_VAR["T"]], TFT_TENSOR[TFT_VAR["T"]] are two tensors of + /// identical element types. + /// TFT_TENSOR[TFT_VAR["P"]], TFT_TENSOR[TFT_VAR["Q"]] are two tensors of + /// potentially different element types. + /// + [pbr::OriginalName("TFT_VAR")] TftVar = 1, + /// + /// Wildcard type. Describes a parameter of unknown type. In TensorFlow, that + /// can mean either a "Top" type (accepts any type), or a dynamically typed + /// object whose type is unknown in context. + /// Important: "unknown" does not necessarily mean undeterminable! + /// + [pbr::OriginalName("TFT_ANY")] TftAny = 2, + /// + /// The algebraic product type. This is an algebraic type that may be used just + /// for logical grouping. Not to confused with TFT_TUPLE which describes a + /// concrete object of several elements. + /// + /// Example: + /// TFT_DATASET[TFT_PRODUCT[TFT_TENSOR[TFT_INT32], TFT_TENSOR[TFT_FLOAT64]]] + /// is a Dataset producing two tensors, an integer one and a float one. + /// + [pbr::OriginalName("TFT_PRODUCT")] TftProduct = 3, + /// + /// Callable types describe functions and ops. + /// + /// Parametrization: + /// TFT_CALLABLE[<arg type>, <return type>] + /// * <arg_type> is the type of the arguments; TFT_PRODUCT represents + /// multiple + /// arguments. + /// * <return_type> is the return type; TFT_PRODUCT represents multiple + /// return values (that means that callables returning multiple things + /// don't necessarily return a single tuple). + /// + /// Example: + /// TFT_CALLABLE[ + /// TFT_ANY, + /// TFT_PRODUCT[TFT_TENSOR[TFT_INT32], TFT_TENSOR[TFT_FLOAT64]], + /// ] + /// is a callable with unspecified (for now) input arguments, and + /// two return values of type tensor. + /// + [pbr::OriginalName("TFT_CALLABLE")] TftCallable = 100, + /// + /// The usual Tensor. This is a parametric type. + /// + /// Parametrization: + /// TFT_TENSOR[<element type>, <shape type>] + /// * <element_type> is currently limited to one of the element types + /// defined below. + /// * <shape_type> is not yet defined, and may only be TFT_UNKNOWN for now. + /// + /// A TFT_SHAPE type will be defined in the future. + /// + /// Example: + /// TFT_TENSOR[TFT_INT32, TFT_UNKNOWN] + /// is a Tensor of int32 element type and unknown shape. + /// + /// TODO(mdan): Define TFT_SHAPE and add more examples. + /// + [pbr::OriginalName("TFT_TENSOR")] TftTensor = 1000, + /// + /// Array (or tensorflow::TensorList in the variant type registry). + /// Note: this is not to be confused with the deprecated `TensorArray*` ops + /// which are not supported by FullType. + /// This type represents a random-access list whose elements can be + /// described by a single type. Although immutable, Array is expected to + /// support efficient mutation semantics (i.e. element update) in the + /// user-facing API. + /// The element type may be generic or even TFT_ANY for a heterogenous list. + /// + /// Parametrization: + /// TFT_ARRAY[<element type>] + /// * <element_type> may be any concrete type. + /// + /// Examples: + /// TFT_ARRAY[TFT_TENSOR[TFT_INT32]] is a TensorArray holding int32 Tensors + /// of any shape. + /// TFT_ARRAY[TFT_TENSOR[TFT_UNKNOWN]] is a TensorArray holding Tensors of + /// mixed element types. + /// TFT_ARRAY[TFT_UNKNOWN] is a TensorArray holding any element type. + /// TFT_ARRAY[] is equivalent to TFT_ARRAY[TFT_UNKNOWN]. + /// TFT_ARRAY[TFT_ARRAY[]] is an array or arrays (of unknown types). + /// + [pbr::OriginalName("TFT_ARRAY")] TftArray = 1001, + /// + /// Optional (or tensorflow::OptionalVariant in the variant type registry). + /// This type represents a value that may either hold an element of a single + /// specified type, or nothing at all. + /// + /// Parametrization: + /// TFT_OPTIONAL[<element type>] + /// * <element_type> may be any concrete type. + /// + /// Examples: + /// TFT_OPTIONAL[TFT_TENSOR[TFT_INT32]] is an Optional holding an int32 + /// Tensor of any shape. + /// + [pbr::OriginalName("TFT_OPTIONAL")] TftOptional = 1002, + /// + /// Datasets created by tf.data ops and APIs. Datasets have generator/iterable + /// semantics, that is, one can construct an iterator from them. Like + /// Array, they are considered to return elements that can be described + /// by a single type. Unlike Array, they do not support random access or + /// mutation, and can potentially produce an infinite number of elements. + /// A datasets can produce logical structures (e.g. multiple elements). This + /// is expressed using TFT_PRODUCT. + /// + /// Parametrization: TFT_ARRAY[<element type>]. + /// <element_type> may be a concrete type or a type symbol. It represents the + /// data type of the elements produced by the dataset. + /// + /// Examples: + /// TFT_DATSET[TFT_TENSOR[TFT_INT32]] is a Dataset producing single int32 + /// Tensors of unknown shape. + /// TFT_DATSET[TFT_PRODUCT[TFT_TENSOR[TFT_INT32], TFT_TENSOR[TFT_FLOAT32]] is + /// a + /// Dataset producing pairs of Tensors, one integer and one float. + /// Note: The high ID number is to prepare for the eventuality that Datasets + /// will be supported by user types in the future. + /// + [pbr::OriginalName("TFT_DATASET")] TftDataset = 10102, + /// + /// The bool element type. + /// TODO(mdan): Quantized types, legacy representations (e.g. ref) + /// + [pbr::OriginalName("TFT_BOOL")] TftBool = 200, + /// + /// Integer element types. + /// + [pbr::OriginalName("TFT_UINT8")] TftUint8 = 201, + [pbr::OriginalName("TFT_UINT16")] TftUint16 = 202, + [pbr::OriginalName("TFT_UINT32")] TftUint32 = 203, + [pbr::OriginalName("TFT_UINT64")] TftUint64 = 204, + [pbr::OriginalName("TFT_INT8")] TftInt8 = 205, + [pbr::OriginalName("TFT_INT16")] TftInt16 = 206, + [pbr::OriginalName("TFT_INT32")] TftInt32 = 207, + [pbr::OriginalName("TFT_INT64")] TftInt64 = 208, + /// + /// Floating-point element types. + /// + [pbr::OriginalName("TFT_HALF")] TftHalf = 209, + [pbr::OriginalName("TFT_FLOAT")] TftFloat = 210, + [pbr::OriginalName("TFT_DOUBLE")] TftDouble = 211, + [pbr::OriginalName("TFT_BFLOAT16")] TftBfloat16 = 215, + /// + /// Complex element types. + /// TODO(mdan): Represent as TFT_COMPLEX[TFT_DOUBLE] instead? + /// + [pbr::OriginalName("TFT_COMPLEX64")] TftComplex64 = 212, + [pbr::OriginalName("TFT_COMPLEX128")] TftComplex128 = 213, + /// + /// The string element type. + /// + [pbr::OriginalName("TFT_STRING")] TftString = 214, + } + + #endregion + + #region Messages + /// + /// Highly experimental and very likely to change. + /// This encoding uses tags instead of dedicated messages for regularity. In + /// particular the encoding imposes no restrictions on what the parameters of any + /// type should be, which in particular needs to be true for type symbols. + /// + public sealed partial class FullTypeDef : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new FullTypeDef()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.FullTypeReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public FullTypeDef() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public FullTypeDef(FullTypeDef other) : this() { + typeId_ = other.typeId_; + args_ = other.args_.Clone(); + switch (other.AttrCase) { + case AttrOneofCase.S: + S = other.S; + break; + } + + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public FullTypeDef Clone() { + return new FullTypeDef(this); + } + + /// Field number for the "type_id" field. + public const int TypeIdFieldNumber = 1; + private global::Tensorflow.FullTypeId typeId_ = global::Tensorflow.FullTypeId.TftUnset; + /// + /// The principal type represented by this object. This may be a concrete type + /// (Tensor, Dataset) a type variable (used for dependent types) a type + /// symbol (Any, Union). See FullTypeId for details. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.FullTypeId TypeId { + get { return typeId_; } + set { + typeId_ = value; + } + } + + /// Field number for the "args" field. + public const int ArgsFieldNumber = 2; + private static readonly pb::FieldCodec _repeated_args_codec + = pb::FieldCodec.ForMessage(18, global::Tensorflow.FullTypeDef.Parser); + private readonly pbc::RepeatedField args_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField Args { + get { return args_; } + } + + /// Field number for the "s" field. + public const int SFieldNumber = 3; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string S { + get { return attrCase_ == AttrOneofCase.S ? (string) attr_ : ""; } + set { + attr_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + attrCase_ = AttrOneofCase.S; + } + } + + private object attr_; + /// Enum of possible cases for the "attr" oneof. + public enum AttrOneofCase { + None = 0, + S = 3, + } + private AttrOneofCase attrCase_ = AttrOneofCase.None; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public AttrOneofCase AttrCase { + get { return attrCase_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void ClearAttr() { + attrCase_ = AttrOneofCase.None; + attr_ = null; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as FullTypeDef); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(FullTypeDef other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (TypeId != other.TypeId) return false; + if(!args_.Equals(other.args_)) return false; + if (S != other.S) return false; + if (AttrCase != other.AttrCase) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (TypeId != global::Tensorflow.FullTypeId.TftUnset) hash ^= TypeId.GetHashCode(); + hash ^= args_.GetHashCode(); + if (attrCase_ == AttrOneofCase.S) hash ^= S.GetHashCode(); + hash ^= (int) attrCase_; + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (TypeId != global::Tensorflow.FullTypeId.TftUnset) { + output.WriteRawTag(8); + output.WriteEnum((int) TypeId); + } + args_.WriteTo(output, _repeated_args_codec); + if (attrCase_ == AttrOneofCase.S) { + output.WriteRawTag(26); + output.WriteString(S); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (TypeId != global::Tensorflow.FullTypeId.TftUnset) { + size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) TypeId); + } + size += args_.CalculateSize(_repeated_args_codec); + if (attrCase_ == AttrOneofCase.S) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(S); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(FullTypeDef other) { + if (other == null) { + return; + } + if (other.TypeId != global::Tensorflow.FullTypeId.TftUnset) { + TypeId = other.TypeId; + } + args_.Add(other.args_); + switch (other.AttrCase) { + case AttrOneofCase.S: + S = other.S; + break; + } + + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + TypeId = (global::Tensorflow.FullTypeId) input.ReadEnum(); + break; + } + case 18: { + args_.AddEntriesFrom(input, _repeated_args_codec); + break; + } + case 26: { + S = input.ReadString(); + break; + } + } + } + } + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/src/TensorFlowNET.Core/Protobuf/Function.cs b/src/TensorFlowNET.Core/Protobuf/Function.cs index 78665f0d..63cdc44f 100644 --- a/src/TensorFlowNET.Core/Protobuf/Function.cs +++ b/src/TensorFlowNET.Core/Protobuf/Function.cs @@ -28,39 +28,43 @@ namespace Tensorflow { "ZW5zb3JmbG93Gip0ZW5zb3JmbG93L2NvcmUvZnJhbWV3b3JrL2F0dHJfdmFs", "dWUucHJvdG8aKHRlbnNvcmZsb3cvY29yZS9mcmFtZXdvcmsvbm9kZV9kZWYu", "cHJvdG8aJnRlbnNvcmZsb3cvY29yZS9mcmFtZXdvcmsvb3BfZGVmLnByb3Rv", - "ImoKEkZ1bmN0aW9uRGVmTGlicmFyeRIpCghmdW5jdGlvbhgBIAMoCzIXLnRl", - "bnNvcmZsb3cuRnVuY3Rpb25EZWYSKQoIZ3JhZGllbnQYAiADKAsyFy50ZW5z", - "b3JmbG93LkdyYWRpZW50RGVmIsQGCgtGdW5jdGlvbkRlZhIkCglzaWduYXR1", - "cmUYASABKAsyES50ZW5zb3JmbG93Lk9wRGVmEi8KBGF0dHIYBSADKAsyIS50", - "ZW5zb3JmbG93LkZ1bmN0aW9uRGVmLkF0dHJFbnRyeRI2CghhcmdfYXR0chgH", - "IAMoCzIkLnRlbnNvcmZsb3cuRnVuY3Rpb25EZWYuQXJnQXR0ckVudHJ5ElAK", - "FnJlc291cmNlX2FyZ191bmlxdWVfaWQYCCADKAsyMC50ZW5zb3JmbG93LkZ1", - "bmN0aW9uRGVmLlJlc291cmNlQXJnVW5pcXVlSWRFbnRyeRIlCghub2RlX2Rl", - "ZhgDIAMoCzITLnRlbnNvcmZsb3cuTm9kZURlZhItCgNyZXQYBCADKAsyIC50", - "ZW5zb3JmbG93LkZ1bmN0aW9uRGVmLlJldEVudHJ5EjwKC2NvbnRyb2xfcmV0", - "GAYgAygLMicudGVuc29yZmxvdy5GdW5jdGlvbkRlZi5Db250cm9sUmV0RW50", - "cnkaQgoJQXR0ckVudHJ5EgsKA2tleRgBIAEoCRIkCgV2YWx1ZRgCIAEoCzIV", - "LnRlbnNvcmZsb3cuQXR0clZhbHVlOgI4ARqIAQoIQXJnQXR0cnMSOAoEYXR0", - "chgBIAMoCzIqLnRlbnNvcmZsb3cuRnVuY3Rpb25EZWYuQXJnQXR0cnMuQXR0", - "ckVudHJ5GkIKCUF0dHJFbnRyeRILCgNrZXkYASABKAkSJAoFdmFsdWUYAiAB", - "KAsyFS50ZW5zb3JmbG93LkF0dHJWYWx1ZToCOAEaUAoMQXJnQXR0ckVudHJ5", - "EgsKA2tleRgBIAEoDRIvCgV2YWx1ZRgCIAEoCzIgLnRlbnNvcmZsb3cuRnVu", - "Y3Rpb25EZWYuQXJnQXR0cnM6AjgBGjoKGFJlc291cmNlQXJnVW5pcXVlSWRF", - "bnRyeRILCgNrZXkYASABKA0SDQoFdmFsdWUYAiABKA06AjgBGioKCFJldEVu", - "dHJ5EgsKA2tleRgBIAEoCRINCgV2YWx1ZRgCIAEoCToCOAEaMQoPQ29udHJv", - "bFJldEVudHJ5EgsKA2tleRgBIAEoCRINCgV2YWx1ZRgCIAEoCToCOAFKBAgC", - "EAMiOwoLR3JhZGllbnREZWYSFQoNZnVuY3Rpb25fbmFtZRgBIAEoCRIVCg1n", - "cmFkaWVudF9mdW5jGAIgASgJQoABChhvcmcudGVuc29yZmxvdy5mcmFtZXdv", - "cmtCDkZ1bmN0aW9uUHJvdG9zUAFaT2dpdGh1Yi5jb20vdGVuc29yZmxvdy90", - "ZW5zb3JmbG93L3RlbnNvcmZsb3cvZ28vY29yZS9mcmFtZXdvcmsvZnVuY3Rp", - "b25fZ29fcHJvdG/4AQFiBnByb3RvMw==")); + "IqgBChJGdW5jdGlvbkRlZkxpYnJhcnkSKQoIZnVuY3Rpb24YASADKAsyFy50", + "ZW5zb3JmbG93LkZ1bmN0aW9uRGVmEikKCGdyYWRpZW50GAIgAygLMhcudGVu", + "c29yZmxvdy5HcmFkaWVudERlZhI8ChRyZWdpc3RlcmVkX2dyYWRpZW50cxgD", + "IAMoCzIeLnRlbnNvcmZsb3cuUmVnaXN0ZXJlZEdyYWRpZW50IsQGCgtGdW5j", + "dGlvbkRlZhIkCglzaWduYXR1cmUYASABKAsyES50ZW5zb3JmbG93Lk9wRGVm", + "Ei8KBGF0dHIYBSADKAsyIS50ZW5zb3JmbG93LkZ1bmN0aW9uRGVmLkF0dHJF", + "bnRyeRI2CghhcmdfYXR0chgHIAMoCzIkLnRlbnNvcmZsb3cuRnVuY3Rpb25E", + "ZWYuQXJnQXR0ckVudHJ5ElAKFnJlc291cmNlX2FyZ191bmlxdWVfaWQYCCAD", + "KAsyMC50ZW5zb3JmbG93LkZ1bmN0aW9uRGVmLlJlc291cmNlQXJnVW5pcXVl", + "SWRFbnRyeRIlCghub2RlX2RlZhgDIAMoCzITLnRlbnNvcmZsb3cuTm9kZURl", + "ZhItCgNyZXQYBCADKAsyIC50ZW5zb3JmbG93LkZ1bmN0aW9uRGVmLlJldEVu", + "dHJ5EjwKC2NvbnRyb2xfcmV0GAYgAygLMicudGVuc29yZmxvdy5GdW5jdGlv", + "bkRlZi5Db250cm9sUmV0RW50cnkaQgoJQXR0ckVudHJ5EgsKA2tleRgBIAEo", + "CRIkCgV2YWx1ZRgCIAEoCzIVLnRlbnNvcmZsb3cuQXR0clZhbHVlOgI4ARqI", + "AQoIQXJnQXR0cnMSOAoEYXR0chgBIAMoCzIqLnRlbnNvcmZsb3cuRnVuY3Rp", + "b25EZWYuQXJnQXR0cnMuQXR0ckVudHJ5GkIKCUF0dHJFbnRyeRILCgNrZXkY", + "ASABKAkSJAoFdmFsdWUYAiABKAsyFS50ZW5zb3JmbG93LkF0dHJWYWx1ZToC", + "OAEaUAoMQXJnQXR0ckVudHJ5EgsKA2tleRgBIAEoDRIvCgV2YWx1ZRgCIAEo", + "CzIgLnRlbnNvcmZsb3cuRnVuY3Rpb25EZWYuQXJnQXR0cnM6AjgBGjoKGFJl", + "c291cmNlQXJnVW5pcXVlSWRFbnRyeRILCgNrZXkYASABKA0SDQoFdmFsdWUY", + "AiABKA06AjgBGioKCFJldEVudHJ5EgsKA2tleRgBIAEoCRINCgV2YWx1ZRgC", + "IAEoCToCOAEaMQoPQ29udHJvbFJldEVudHJ5EgsKA2tleRgBIAEoCRINCgV2", + "YWx1ZRgCIAEoCToCOAFKBAgCEAMiOwoLR3JhZGllbnREZWYSFQoNZnVuY3Rp", + "b25fbmFtZRgBIAEoCRIVCg1ncmFkaWVudF9mdW5jGAIgASgJIkcKElJlZ2lz", + "dGVyZWRHcmFkaWVudBIVCg1ncmFkaWVudF9mdW5jGAEgASgJEhoKEnJlZ2lz", + "dGVyZWRfb3BfdHlwZRgCIAEoCUKAAQoYb3JnLnRlbnNvcmZsb3cuZnJhbWV3", + "b3JrQg5GdW5jdGlvblByb3Rvc1ABWk9naXRodWIuY29tL3RlbnNvcmZsb3cv", + "dGVuc29yZmxvdy90ZW5zb3JmbG93L2dvL2NvcmUvZnJhbWV3b3JrL2Z1bmN0", + "aW9uX2dvX3Byb3Rv+AEBYgZwcm90bzM=")); descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, new pbr::FileDescriptor[] { global::Tensorflow.AttrValueReflection.Descriptor, global::Tensorflow.NodeDefReflection.Descriptor, global::Tensorflow.OpDefReflection.Descriptor, }, new pbr::GeneratedClrTypeInfo(null, null, new pbr::GeneratedClrTypeInfo[] { - new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.FunctionDefLibrary), global::Tensorflow.FunctionDefLibrary.Parser, new[]{ "Function", "Gradient" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.FunctionDefLibrary), global::Tensorflow.FunctionDefLibrary.Parser, new[]{ "Function", "Gradient", "RegisteredGradients" }, null, null, null, null), new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.FunctionDef), global::Tensorflow.FunctionDef.Parser, new[]{ "Signature", "Attr", "ArgAttr", "ResourceArgUniqueId", "NodeDef", "Ret", "ControlRet" }, null, null, null, new pbr::GeneratedClrTypeInfo[] { null, new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.FunctionDef.Types.ArgAttrs), global::Tensorflow.FunctionDef.Types.ArgAttrs.Parser, new[]{ "Attr" }, null, null, null, new pbr::GeneratedClrTypeInfo[] { null, }), null, null, null, null, }), - new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.GradientDef), global::Tensorflow.GradientDef.Parser, new[]{ "FunctionName", "GradientFunc" }, null, null, null, null) + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.GradientDef), global::Tensorflow.GradientDef.Parser, new[]{ "FunctionName", "GradientFunc" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.RegisteredGradient), global::Tensorflow.RegisteredGradient.Parser, new[]{ "GradientFunc", "RegisteredOpType" }, null, null, null, null) })); } #endregion @@ -97,6 +101,7 @@ namespace Tensorflow { public FunctionDefLibrary(FunctionDefLibrary other) : this() { function_ = other.function_.Clone(); gradient_ = other.gradient_.Clone(); + registeredGradients_ = other.registeredGradients_.Clone(); _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); } @@ -125,6 +130,16 @@ namespace Tensorflow { get { return gradient_; } } + /// Field number for the "registered_gradients" field. + public const int RegisteredGradientsFieldNumber = 3; + private static readonly pb::FieldCodec _repeated_registeredGradients_codec + = pb::FieldCodec.ForMessage(26, global::Tensorflow.RegisteredGradient.Parser); + private readonly pbc::RepeatedField registeredGradients_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField RegisteredGradients { + get { return registeredGradients_; } + } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] public override bool Equals(object other) { return Equals(other as FunctionDefLibrary); @@ -140,6 +155,7 @@ namespace Tensorflow { } if(!function_.Equals(other.function_)) return false; if(!gradient_.Equals(other.gradient_)) return false; + if(!registeredGradients_.Equals(other.registeredGradients_)) return false; return Equals(_unknownFields, other._unknownFields); } @@ -148,6 +164,7 @@ namespace Tensorflow { int hash = 1; hash ^= function_.GetHashCode(); hash ^= gradient_.GetHashCode(); + hash ^= registeredGradients_.GetHashCode(); if (_unknownFields != null) { hash ^= _unknownFields.GetHashCode(); } @@ -163,6 +180,7 @@ namespace Tensorflow { public void WriteTo(pb::CodedOutputStream output) { function_.WriteTo(output, _repeated_function_codec); gradient_.WriteTo(output, _repeated_gradient_codec); + registeredGradients_.WriteTo(output, _repeated_registeredGradients_codec); if (_unknownFields != null) { _unknownFields.WriteTo(output); } @@ -173,6 +191,7 @@ namespace Tensorflow { int size = 0; size += function_.CalculateSize(_repeated_function_codec); size += gradient_.CalculateSize(_repeated_gradient_codec); + size += registeredGradients_.CalculateSize(_repeated_registeredGradients_codec); if (_unknownFields != null) { size += _unknownFields.CalculateSize(); } @@ -186,6 +205,7 @@ namespace Tensorflow { } function_.Add(other.function_); gradient_.Add(other.gradient_); + registeredGradients_.Add(other.registeredGradients_); _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); } @@ -205,6 +225,10 @@ namespace Tensorflow { gradient_.AddEntriesFrom(input, _repeated_gradient_codec); break; } + case 26: { + registeredGradients_.AddEntriesFrom(input, _repeated_registeredGradients_codec); + break; + } } } } @@ -820,6 +844,175 @@ namespace Tensorflow { } + /// + /// RegisteredGradient stores a gradient function that is registered in the + /// gradients library and used in the ops of a function in the function library. + /// Unlike GradientDef, these gradients are identified by op type, and not + /// directly linked to any function. + /// + public sealed partial class RegisteredGradient : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new RegisteredGradient()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.FunctionReflection.Descriptor.MessageTypes[3]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public RegisteredGradient() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public RegisteredGradient(RegisteredGradient other) : this() { + gradientFunc_ = other.gradientFunc_; + registeredOpType_ = other.registeredOpType_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public RegisteredGradient Clone() { + return new RegisteredGradient(this); + } + + /// Field number for the "gradient_func" field. + public const int GradientFuncFieldNumber = 1; + private string gradientFunc_ = ""; + /// + /// The gradient function's name. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string GradientFunc { + get { return gradientFunc_; } + set { + gradientFunc_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "registered_op_type" field. + public const int RegisteredOpTypeFieldNumber = 2; + private string registeredOpType_ = ""; + /// + /// The gradient function's registered op type. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string RegisteredOpType { + get { return registeredOpType_; } + set { + registeredOpType_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as RegisteredGradient); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(RegisteredGradient other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (GradientFunc != other.GradientFunc) return false; + if (RegisteredOpType != other.RegisteredOpType) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (GradientFunc.Length != 0) hash ^= GradientFunc.GetHashCode(); + if (RegisteredOpType.Length != 0) hash ^= RegisteredOpType.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (GradientFunc.Length != 0) { + output.WriteRawTag(10); + output.WriteString(GradientFunc); + } + if (RegisteredOpType.Length != 0) { + output.WriteRawTag(18); + output.WriteString(RegisteredOpType); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (GradientFunc.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(GradientFunc); + } + if (RegisteredOpType.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(RegisteredOpType); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(RegisteredGradient other) { + if (other == null) { + return; + } + if (other.GradientFunc.Length != 0) { + GradientFunc = other.GradientFunc; + } + if (other.RegisteredOpType.Length != 0) { + RegisteredOpType = other.RegisteredOpType; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + GradientFunc = input.ReadString(); + break; + } + case 18: { + RegisteredOpType = input.ReadString(); + break; + } + } + } + } + + } + #endregion } diff --git a/src/TensorFlowNET.Core/Protobuf/Gen.bat b/src/TensorFlowNET.Core/Protobuf/Gen.bat index 165d8a3a..fdb962f8 100644 --- a/src/TensorFlowNET.Core/Protobuf/Gen.bat +++ b/src/TensorFlowNET.Core/Protobuf/Gen.bat @@ -24,6 +24,7 @@ protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/framework/kernel_def. protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/framework/log_memory.proto protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/framework/tensor_slice.proto protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/framework/summary.proto +protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/framework/full_type.proto protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/framework/op_def.proto protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/protobuf/saver.proto protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/protobuf/saved_object_graph.proto @@ -39,6 +40,7 @@ protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/protobuf/trackable_ob protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/protobuf/struct.proto protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/protobuf/verifier_config.proto protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/util/event.proto +protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/util/memmapped_file_system.proto protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/python/training/checkpoint_state.proto protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/python/framework/cpp_shape_inference.proto diff --git a/src/TensorFlowNET.Core/Protobuf/MemmappedFileSystem.cs b/src/TensorFlowNET.Core/Protobuf/MemmappedFileSystem.cs new file mode 100644 index 00000000..9a013fd7 --- /dev/null +++ b/src/TensorFlowNET.Core/Protobuf/MemmappedFileSystem.cs @@ -0,0 +1,360 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: tensorflow/core/util/memmapped_file_system.proto +// +#pragma warning disable 1591, 0612, 3021 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace Tensorflow { + + /// Holder for reflection information generated from tensorflow/core/util/memmapped_file_system.proto + public static partial class MemmappedFileSystemReflection { + + #region Descriptor + /// File descriptor for tensorflow/core/util/memmapped_file_system.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static MemmappedFileSystemReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "CjB0ZW5zb3JmbG93L2NvcmUvdXRpbC9tZW1tYXBwZWRfZmlsZV9zeXN0ZW0u", + "cHJvdG8SCnRlbnNvcmZsb3ciUwojTWVtbWFwcGVkRmlsZVN5c3RlbURpcmVj", + "dG9yeUVsZW1lbnQSDgoGb2Zmc2V0GAEgASgEEgwKBG5hbWUYAiABKAkSDgoG", + "bGVuZ3RoGAMgASgEImAKHE1lbW1hcHBlZEZpbGVTeXN0ZW1EaXJlY3RvcnkS", + "QAoHZWxlbWVudBgBIAMoCzIvLnRlbnNvcmZsb3cuTWVtbWFwcGVkRmlsZVN5", + "c3RlbURpcmVjdG9yeUVsZW1lbnRCA/gBAWIGcHJvdG8z")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { }, + new pbr::GeneratedClrTypeInfo(null, null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.MemmappedFileSystemDirectoryElement), global::Tensorflow.MemmappedFileSystemDirectoryElement.Parser, new[]{ "Offset", "Name", "Length" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.MemmappedFileSystemDirectory), global::Tensorflow.MemmappedFileSystemDirectory.Parser, new[]{ "Element" }, null, null, null, null) + })); + } + #endregion + + } + #region Messages + /// + /// A message that describes one region of memmapped file. + /// + public sealed partial class MemmappedFileSystemDirectoryElement : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new MemmappedFileSystemDirectoryElement()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.MemmappedFileSystemReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public MemmappedFileSystemDirectoryElement() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public MemmappedFileSystemDirectoryElement(MemmappedFileSystemDirectoryElement other) : this() { + offset_ = other.offset_; + name_ = other.name_; + length_ = other.length_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public MemmappedFileSystemDirectoryElement Clone() { + return new MemmappedFileSystemDirectoryElement(this); + } + + /// Field number for the "offset" field. + public const int OffsetFieldNumber = 1; + private ulong offset_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public ulong Offset { + get { return offset_; } + set { + offset_ = value; + } + } + + /// Field number for the "name" field. + public const int NameFieldNumber = 2; + private string name_ = ""; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string Name { + get { return name_; } + set { + name_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "length" field. + public const int LengthFieldNumber = 3; + private ulong length_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public ulong Length { + get { return length_; } + set { + length_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as MemmappedFileSystemDirectoryElement); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(MemmappedFileSystemDirectoryElement other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (Offset != other.Offset) return false; + if (Name != other.Name) return false; + if (Length != other.Length) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (Offset != 0UL) hash ^= Offset.GetHashCode(); + if (Name.Length != 0) hash ^= Name.GetHashCode(); + if (Length != 0UL) hash ^= Length.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (Offset != 0UL) { + output.WriteRawTag(8); + output.WriteUInt64(Offset); + } + if (Name.Length != 0) { + output.WriteRawTag(18); + output.WriteString(Name); + } + if (Length != 0UL) { + output.WriteRawTag(24); + output.WriteUInt64(Length); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (Offset != 0UL) { + size += 1 + pb::CodedOutputStream.ComputeUInt64Size(Offset); + } + if (Name.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(Name); + } + if (Length != 0UL) { + size += 1 + pb::CodedOutputStream.ComputeUInt64Size(Length); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(MemmappedFileSystemDirectoryElement other) { + if (other == null) { + return; + } + if (other.Offset != 0UL) { + Offset = other.Offset; + } + if (other.Name.Length != 0) { + Name = other.Name; + } + if (other.Length != 0UL) { + Length = other.Length; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 8: { + Offset = input.ReadUInt64(); + break; + } + case 18: { + Name = input.ReadString(); + break; + } + case 24: { + Length = input.ReadUInt64(); + break; + } + } + } + } + + } + + /// + /// A directory of regions in a memmapped file. + /// + public sealed partial class MemmappedFileSystemDirectory : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new MemmappedFileSystemDirectory()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.MemmappedFileSystemReflection.Descriptor.MessageTypes[1]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public MemmappedFileSystemDirectory() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public MemmappedFileSystemDirectory(MemmappedFileSystemDirectory other) : this() { + element_ = other.element_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public MemmappedFileSystemDirectory Clone() { + return new MemmappedFileSystemDirectory(this); + } + + /// Field number for the "element" field. + public const int ElementFieldNumber = 1; + private static readonly pb::FieldCodec _repeated_element_codec + = pb::FieldCodec.ForMessage(10, global::Tensorflow.MemmappedFileSystemDirectoryElement.Parser); + private readonly pbc::RepeatedField element_ = new pbc::RepeatedField(); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField Element { + get { return element_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as MemmappedFileSystemDirectory); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(MemmappedFileSystemDirectory other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if(!element_.Equals(other.element_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + hash ^= element_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + element_.WriteTo(output, _repeated_element_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + size += element_.CalculateSize(_repeated_element_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(MemmappedFileSystemDirectory other) { + if (other == null) { + return; + } + element_.Add(other.element_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + element_.AddEntriesFrom(input, _repeated_element_codec); + break; + } + } + } + } + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/src/TensorFlowNET.Core/Protobuf/OpDef.cs b/src/TensorFlowNET.Core/Protobuf/OpDef.cs index 03c093fb..df26be91 100644 --- a/src/TensorFlowNET.Core/Protobuf/OpDef.cs +++ b/src/TensorFlowNET.Core/Protobuf/OpDef.cs @@ -26,35 +26,38 @@ namespace Tensorflow { string.Concat( "CiZ0ZW5zb3JmbG93L2NvcmUvZnJhbWV3b3JrL29wX2RlZi5wcm90bxIKdGVu", "c29yZmxvdxoqdGVuc29yZmxvdy9jb3JlL2ZyYW1ld29yay9hdHRyX3ZhbHVl", - "LnByb3RvGiV0ZW5zb3JmbG93L2NvcmUvZnJhbWV3b3JrL3R5cGVzLnByb3Rv", - "Gi90ZW5zb3JmbG93L2NvcmUvZnJhbWV3b3JrL3Jlc291cmNlX2hhbmRsZS5w", - "cm90byKUBgoFT3BEZWYSDAoEbmFtZRgBIAEoCRIrCglpbnB1dF9hcmcYAiAD", - "KAsyGC50ZW5zb3JmbG93Lk9wRGVmLkFyZ0RlZhIsCgpvdXRwdXRfYXJnGAMg", - "AygLMhgudGVuc29yZmxvdy5PcERlZi5BcmdEZWYSFgoOY29udHJvbF9vdXRw", - "dXQYFCADKAkSJwoEYXR0chgEIAMoCzIZLnRlbnNvcmZsb3cuT3BEZWYuQXR0", - "ckRlZhIuCgtkZXByZWNhdGlvbhgIIAEoCzIZLnRlbnNvcmZsb3cuT3BEZXBy", - "ZWNhdGlvbhIPCgdzdW1tYXJ5GAUgASgJEhMKC2Rlc2NyaXB0aW9uGAYgASgJ", - "EhYKDmlzX2NvbW11dGF0aXZlGBIgASgIEhQKDGlzX2FnZ3JlZ2F0ZRgQIAEo", - "CBITCgtpc19zdGF0ZWZ1bBgRIAEoCBIiChphbGxvd3NfdW5pbml0aWFsaXpl", - "ZF9pbnB1dBgTIAEoCBrjAQoGQXJnRGVmEgwKBG5hbWUYASABKAkSEwoLZGVz", - "Y3JpcHRpb24YAiABKAkSIgoEdHlwZRgDIAEoDjIULnRlbnNvcmZsb3cuRGF0", - "YVR5cGUSEQoJdHlwZV9hdHRyGAQgASgJEhMKC251bWJlcl9hdHRyGAUgASgJ", - "EhYKDnR5cGVfbGlzdF9hdHRyGAYgASgJEkIKC2hhbmRsZV9kYXRhGAcgAygL", - "Mi0udGVuc29yZmxvdy5SZXNvdXJjZUhhbmRsZVByb3RvLkR0eXBlQW5kU2hh", - "cGUSDgoGaXNfcmVmGBAgASgIGr0BCgdBdHRyRGVmEgwKBG5hbWUYASABKAkS", - "DAoEdHlwZRgCIAEoCRIsCg1kZWZhdWx0X3ZhbHVlGAMgASgLMhUudGVuc29y", - "Zmxvdy5BdHRyVmFsdWUSEwoLZGVzY3JpcHRpb24YBCABKAkSEwoLaGFzX21p", - "bmltdW0YBSABKAgSDwoHbWluaW11bRgGIAEoAxItCg5hbGxvd2VkX3ZhbHVl", - "cxgHIAEoCzIVLnRlbnNvcmZsb3cuQXR0clZhbHVlIjUKDU9wRGVwcmVjYXRp", - "b24SDwoHdmVyc2lvbhgBIAEoBRITCgtleHBsYW5hdGlvbhgCIAEoCSInCgZP", - "cExpc3QSHQoCb3AYASADKAsyES50ZW5zb3JmbG93Lk9wRGVmQnsKGG9yZy50", - "ZW5zb3JmbG93LmZyYW1ld29ya0ILT3BEZWZQcm90b3NQAVpNZ2l0aHViLmNv", - "bS90ZW5zb3JmbG93L3RlbnNvcmZsb3cvdGVuc29yZmxvdy9nby9jb3JlL2Zy", - "YW1ld29yay9vcF9kZWZfZ29fcHJvdG/4AQFiBnByb3RvMw==")); + "LnByb3RvGil0ZW5zb3JmbG93L2NvcmUvZnJhbWV3b3JrL2Z1bGxfdHlwZS5w", + "cm90bxovdGVuc29yZmxvdy9jb3JlL2ZyYW1ld29yay9yZXNvdXJjZV9oYW5k", + "bGUucHJvdG8aJXRlbnNvcmZsb3cvY29yZS9mcmFtZXdvcmsvdHlwZXMucHJv", + "dG8i8wYKBU9wRGVmEgwKBG5hbWUYASABKAkSKwoJaW5wdXRfYXJnGAIgAygL", + "MhgudGVuc29yZmxvdy5PcERlZi5BcmdEZWYSLAoKb3V0cHV0X2FyZxgDIAMo", + "CzIYLnRlbnNvcmZsb3cuT3BEZWYuQXJnRGVmEhYKDmNvbnRyb2xfb3V0cHV0", + "GBQgAygJEicKBGF0dHIYBCADKAsyGS50ZW5zb3JmbG93Lk9wRGVmLkF0dHJE", + "ZWYSLgoLZGVwcmVjYXRpb24YCCABKAsyGS50ZW5zb3JmbG93Lk9wRGVwcmVj", + "YXRpb24SDwoHc3VtbWFyeRgFIAEoCRITCgtkZXNjcmlwdGlvbhgGIAEoCRIW", + "Cg5pc19jb21tdXRhdGl2ZRgSIAEoCBIUCgxpc19hZ2dyZWdhdGUYECABKAgS", + "EwoLaXNfc3RhdGVmdWwYESABKAgSIgoaYWxsb3dzX3VuaW5pdGlhbGl6ZWRf", + "aW5wdXQYEyABKAgSJAocaXNfZGlzdHJpYnV0ZWRfY29tbXVuaWNhdGlvbhgV", + "IAEoCBqcAgoGQXJnRGVmEgwKBG5hbWUYASABKAkSEwoLZGVzY3JpcHRpb24Y", + "AiABKAkSIgoEdHlwZRgDIAEoDjIULnRlbnNvcmZsb3cuRGF0YVR5cGUSEQoJ", + "dHlwZV9hdHRyGAQgASgJEhMKC251bWJlcl9hdHRyGAUgASgJEhYKDnR5cGVf", + "bGlzdF9hdHRyGAYgASgJEkIKC2hhbmRsZV9kYXRhGAcgAygLMi0udGVuc29y", + "Zmxvdy5SZXNvdXJjZUhhbmRsZVByb3RvLkR0eXBlQW5kU2hhcGUSDgoGaXNf", + "cmVmGBAgASgIEjcKFmV4cGVyaW1lbnRhbF9mdWxsX3R5cGUYESABKAsyFy50", + "ZW5zb3JmbG93LkZ1bGxUeXBlRGVmGr0BCgdBdHRyRGVmEgwKBG5hbWUYASAB", + "KAkSDAoEdHlwZRgCIAEoCRIsCg1kZWZhdWx0X3ZhbHVlGAMgASgLMhUudGVu", + "c29yZmxvdy5BdHRyVmFsdWUSEwoLZGVzY3JpcHRpb24YBCABKAkSEwoLaGFz", + "X21pbmltdW0YBSABKAgSDwoHbWluaW11bRgGIAEoAxItCg5hbGxvd2VkX3Zh", + "bHVlcxgHIAEoCzIVLnRlbnNvcmZsb3cuQXR0clZhbHVlIjUKDU9wRGVwcmVj", + "YXRpb24SDwoHdmVyc2lvbhgBIAEoBRITCgtleHBsYW5hdGlvbhgCIAEoCSIn", + "CgZPcExpc3QSHQoCb3AYASADKAsyES50ZW5zb3JmbG93Lk9wRGVmQnsKGG9y", + "Zy50ZW5zb3JmbG93LmZyYW1ld29ya0ILT3BEZWZQcm90b3NQAVpNZ2l0aHVi", + "LmNvbS90ZW5zb3JmbG93L3RlbnNvcmZsb3cvdGVuc29yZmxvdy9nby9jb3Jl", + "L2ZyYW1ld29yay9vcF9kZWZfZ29fcHJvdG/4AQFiBnByb3RvMw==")); descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, - new pbr::FileDescriptor[] { global::Tensorflow.AttrValueReflection.Descriptor, global::Tensorflow.TypesReflection.Descriptor, global::Tensorflow.ResourceHandleReflection.Descriptor, }, + new pbr::FileDescriptor[] { global::Tensorflow.AttrValueReflection.Descriptor, global::Tensorflow.FullTypeReflection.Descriptor, global::Tensorflow.ResourceHandleReflection.Descriptor, global::Tensorflow.TypesReflection.Descriptor, }, new pbr::GeneratedClrTypeInfo(null, null, new pbr::GeneratedClrTypeInfo[] { - new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.OpDef), global::Tensorflow.OpDef.Parser, new[]{ "Name", "InputArg", "OutputArg", "ControlOutput", "Attr", "Deprecation", "Summary", "Description", "IsCommutative", "IsAggregate", "IsStateful", "AllowsUninitializedInput" }, null, null, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.OpDef.Types.ArgDef), global::Tensorflow.OpDef.Types.ArgDef.Parser, new[]{ "Name", "Description", "Type", "TypeAttr", "NumberAttr", "TypeListAttr", "HandleData", "IsRef" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.OpDef), global::Tensorflow.OpDef.Parser, new[]{ "Name", "InputArg", "OutputArg", "ControlOutput", "Attr", "Deprecation", "Summary", "Description", "IsCommutative", "IsAggregate", "IsStateful", "AllowsUninitializedInput", "IsDistributedCommunication" }, null, null, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.OpDef.Types.ArgDef), global::Tensorflow.OpDef.Types.ArgDef.Parser, new[]{ "Name", "Description", "Type", "TypeAttr", "NumberAttr", "TypeListAttr", "HandleData", "IsRef", "ExperimentalFullType" }, null, null, null, null), new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.OpDef.Types.AttrDef), global::Tensorflow.OpDef.Types.AttrDef.Parser, new[]{ "Name", "Type", "DefaultValue", "Description", "HasMinimum", "Minimum", "AllowedValues" }, null, null, null, null)}), new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.OpDeprecation), global::Tensorflow.OpDeprecation.Parser, new[]{ "Version", "Explanation" }, null, null, null, null), new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.OpList), global::Tensorflow.OpList.Parser, new[]{ "Op" }, null, null, null, null) @@ -106,6 +109,7 @@ namespace Tensorflow { isAggregate_ = other.isAggregate_; isStateful_ = other.isStateful_; allowsUninitializedInput_ = other.allowsUninitializedInput_; + isDistributedCommunication_ = other.isDistributedCommunication_; _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); } @@ -297,6 +301,22 @@ namespace Tensorflow { } } + /// Field number for the "is_distributed_communication" field. + public const int IsDistributedCommunicationFieldNumber = 21; + private bool isDistributedCommunication_; + /// + /// Indicates whether the op implementation uses distributed communication. + /// If True, the op is allowed to return errors for network disconnection and + /// trigger TF network failure handling logics. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool IsDistributedCommunication { + get { return isDistributedCommunication_; } + set { + isDistributedCommunication_ = value; + } + } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] public override bool Equals(object other) { return Equals(other as OpDef); @@ -322,6 +342,7 @@ namespace Tensorflow { if (IsAggregate != other.IsAggregate) return false; if (IsStateful != other.IsStateful) return false; if (AllowsUninitializedInput != other.AllowsUninitializedInput) return false; + if (IsDistributedCommunication != other.IsDistributedCommunication) return false; return Equals(_unknownFields, other._unknownFields); } @@ -340,6 +361,7 @@ namespace Tensorflow { if (IsAggregate != false) hash ^= IsAggregate.GetHashCode(); if (IsStateful != false) hash ^= IsStateful.GetHashCode(); if (AllowsUninitializedInput != false) hash ^= AllowsUninitializedInput.GetHashCode(); + if (IsDistributedCommunication != false) hash ^= IsDistributedCommunication.GetHashCode(); if (_unknownFields != null) { hash ^= _unknownFields.GetHashCode(); } @@ -389,6 +411,10 @@ namespace Tensorflow { output.WriteBool(AllowsUninitializedInput); } controlOutput_.WriteTo(output, _repeated_controlOutput_codec); + if (IsDistributedCommunication != false) { + output.WriteRawTag(168, 1); + output.WriteBool(IsDistributedCommunication); + } if (_unknownFields != null) { _unknownFields.WriteTo(output); } @@ -425,6 +451,9 @@ namespace Tensorflow { if (AllowsUninitializedInput != false) { size += 2 + 1; } + if (IsDistributedCommunication != false) { + size += 2 + 1; + } if (_unknownFields != null) { size += _unknownFields.CalculateSize(); } @@ -467,6 +496,9 @@ namespace Tensorflow { if (other.AllowsUninitializedInput != false) { AllowsUninitializedInput = other.AllowsUninitializedInput; } + if (other.IsDistributedCommunication != false) { + IsDistributedCommunication = other.IsDistributedCommunication; + } _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); } @@ -529,6 +561,10 @@ namespace Tensorflow { controlOutput_.AddEntriesFrom(input, _repeated_controlOutput_codec); break; } + case 168: { + IsDistributedCommunication = input.ReadBool(); + break; + } } } } @@ -573,6 +609,7 @@ namespace Tensorflow { typeListAttr_ = other.typeListAttr_; handleData_ = other.handleData_.Clone(); isRef_ = other.isRef_; + experimentalFullType_ = other.experimentalFullType_ != null ? other.experimentalFullType_.Clone() : null; _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); } @@ -704,6 +741,28 @@ namespace Tensorflow { } } + /// Field number for the "experimental_full_type" field. + public const int ExperimentalFullTypeFieldNumber = 17; + private global::Tensorflow.FullTypeDef experimentalFullType_; + /// + /// Experimental. Full type declaration for this argument. + /// The full type specification combines type, type_attr, type_list_attr, + /// etc. into a unified representation. + /// This declaration may contain non-concrete types (for example, + /// Tensor<TypeVar<'T'>> is a valid type declaration. + /// + /// Note: this is a transient field. The long-term aim is to represent the + /// entire OpDef as a single type: a callable. In that context, this field is + /// just the type of a single argument. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.FullTypeDef ExperimentalFullType { + get { return experimentalFullType_; } + set { + experimentalFullType_ = value; + } + } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] public override bool Equals(object other) { return Equals(other as ArgDef); @@ -725,6 +784,7 @@ namespace Tensorflow { if (TypeListAttr != other.TypeListAttr) return false; if(!handleData_.Equals(other.handleData_)) return false; if (IsRef != other.IsRef) return false; + if (!object.Equals(ExperimentalFullType, other.ExperimentalFullType)) return false; return Equals(_unknownFields, other._unknownFields); } @@ -739,6 +799,7 @@ namespace Tensorflow { if (TypeListAttr.Length != 0) hash ^= TypeListAttr.GetHashCode(); hash ^= handleData_.GetHashCode(); if (IsRef != false) hash ^= IsRef.GetHashCode(); + if (experimentalFullType_ != null) hash ^= ExperimentalFullType.GetHashCode(); if (_unknownFields != null) { hash ^= _unknownFields.GetHashCode(); } @@ -781,6 +842,10 @@ namespace Tensorflow { output.WriteRawTag(128, 1); output.WriteBool(IsRef); } + if (experimentalFullType_ != null) { + output.WriteRawTag(138, 1); + output.WriteMessage(ExperimentalFullType); + } if (_unknownFields != null) { _unknownFields.WriteTo(output); } @@ -811,6 +876,9 @@ namespace Tensorflow { if (IsRef != false) { size += 2 + 1; } + if (experimentalFullType_ != null) { + size += 2 + pb::CodedOutputStream.ComputeMessageSize(ExperimentalFullType); + } if (_unknownFields != null) { size += _unknownFields.CalculateSize(); } @@ -844,6 +912,12 @@ namespace Tensorflow { if (other.IsRef != false) { IsRef = other.IsRef; } + if (other.experimentalFullType_ != null) { + if (experimentalFullType_ == null) { + ExperimentalFullType = new global::Tensorflow.FullTypeDef(); + } + ExperimentalFullType.MergeFrom(other.ExperimentalFullType); + } _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); } @@ -887,6 +961,13 @@ namespace Tensorflow { IsRef = input.ReadBool(); break; } + case 138: { + if (experimentalFullType_ == null) { + ExperimentalFullType = new global::Tensorflow.FullTypeDef(); + } + input.ReadMessage(ExperimentalFullType); + break; + } } } } diff --git a/src/TensorFlowNET.Core/Protobuf/RewriterConfig.cs b/src/TensorFlowNET.Core/Protobuf/RewriterConfig.cs index db8951c5..2804ca26 100644 --- a/src/TensorFlowNET.Core/Protobuf/RewriterConfig.cs +++ b/src/TensorFlowNET.Core/Protobuf/RewriterConfig.cs @@ -842,8 +842,8 @@ namespace Tensorflow { private long metaOptimizerTimeoutMs_; /// /// Maximum number of milliseconds to spend optimizing a single graph before - /// timing out. If equal to 0 the system picks a default (currently 5 minutes). - /// If less than 0 the optimizer will never time out. + /// timing out. If less than or equal to 0 (default value) the optimizer will + /// never time out. /// [global::System.Diagnostics.DebuggerNonUserCodeAttribute] public long MetaOptimizerTimeoutMs { diff --git a/src/TensorFlowNET.Core/Protobuf/SavedObjectGraph.cs b/src/TensorFlowNET.Core/Protobuf/SavedObjectGraph.cs index dc3c5318..9d3e854a 100644 --- a/src/TensorFlowNET.Core/Protobuf/SavedObjectGraph.cs +++ b/src/TensorFlowNET.Core/Protobuf/SavedObjectGraph.cs @@ -36,7 +36,7 @@ namespace Tensorflow { "aW9ucxgCIAMoCzIzLnRlbnNvcmZsb3cuU2F2ZWRPYmplY3RHcmFwaC5Db25j", "cmV0ZUZ1bmN0aW9uc0VudHJ5GlsKFkNvbmNyZXRlRnVuY3Rpb25zRW50cnkS", "CwoDa2V5GAEgASgJEjAKBXZhbHVlGAIgASgLMiEudGVuc29yZmxvdy5TYXZl", - "ZENvbmNyZXRlRnVuY3Rpb246AjgBItkFCgtTYXZlZE9iamVjdBJSCghjaGls", + "ZENvbmNyZXRlRnVuY3Rpb246AjgBIpAGCgtTYXZlZE9iamVjdBJSCghjaGls", "ZHJlbhgBIAMoCzJALnRlbnNvcmZsb3cuVHJhY2thYmxlT2JqZWN0R3JhcGgu", "VHJhY2thYmxlT2JqZWN0Lk9iamVjdFJlZmVyZW5jZRJeCg5zbG90X3Zhcmlh", "YmxlcxgDIAMoCzJGLnRlbnNvcmZsb3cuVHJhY2thYmxlT2JqZWN0R3JhcGgu", @@ -48,51 +48,54 @@ namespace Tensorflow { "RwoWYmFyZV9jb25jcmV0ZV9mdW5jdGlvbhgIIAEoCzIlLnRlbnNvcmZsb3cu", "U2F2ZWRCYXJlQ29uY3JldGVGdW5jdGlvbkgAEi0KCGNvbnN0YW50GAkgASgL", "MhkudGVuc29yZmxvdy5TYXZlZENvbnN0YW50SAASLQoIcmVzb3VyY2UYCiAB", - "KAsyGS50ZW5zb3JmbG93LlNhdmVkUmVzb3VyY2VIABJGChBzYXZlYWJsZV9v", - "YmplY3RzGAsgAygLMiwudGVuc29yZmxvdy5TYXZlZE9iamVjdC5TYXZlYWJs", - "ZU9iamVjdHNFbnRyeRpSChRTYXZlYWJsZU9iamVjdHNFbnRyeRILCgNrZXkY", - "ASABKAkSKQoFdmFsdWUYAiABKAsyGi50ZW5zb3JmbG93LlNhdmVhYmxlT2Jq", - "ZWN0OgI4AUIGCgRraW5kSgQIAhADUgphdHRyaWJ1dGVzImAKD1NhdmVkVXNl", - "ck9iamVjdBISCgppZGVudGlmaWVyGAEgASgJEicKB3ZlcnNpb24YAiABKAsy", - "Fi50ZW5zb3JmbG93LlZlcnNpb25EZWYSEAoIbWV0YWRhdGEYAyABKAkiKgoK", - "U2F2ZWRBc3NldBIcChRhc3NldF9maWxlX2RlZl9pbmRleBgBIAEoBSJcCg1T", - "YXZlZEZ1bmN0aW9uEhoKEmNvbmNyZXRlX2Z1bmN0aW9ucxgBIAMoCRIvCg1m", - "dW5jdGlvbl9zcGVjGAIgASgLMhgudGVuc29yZmxvdy5GdW5jdGlvblNwZWMi", - "qAEKFVNhdmVkQ29uY3JldGVGdW5jdGlvbhIUCgxib3VuZF9pbnB1dHMYAiAD", - "KAUSQgodY2Fub25pY2FsaXplZF9pbnB1dF9zaWduYXR1cmUYAyABKAsyGy50", - "ZW5zb3JmbG93LlN0cnVjdHVyZWRWYWx1ZRI1ChBvdXRwdXRfc2lnbmF0dXJl", - "GAQgASgLMhsudGVuc29yZmxvdy5TdHJ1Y3R1cmVkVmFsdWUirQEKGVNhdmVk", - "QmFyZUNvbmNyZXRlRnVuY3Rpb24SHgoWY29uY3JldGVfZnVuY3Rpb25fbmFt", - "ZRgBIAEoCRIZChFhcmd1bWVudF9rZXl3b3JkcxgCIAMoCRIkChxhbGxvd2Vk", - "X3Bvc2l0aW9uYWxfYXJndW1lbnRzGAMgASgDEi8KDWZ1bmN0aW9uX3NwZWMY", - "BCABKAsyGC50ZW5zb3JmbG93LkZ1bmN0aW9uU3BlYyIiCg1TYXZlZENvbnN0", - "YW50EhEKCW9wZXJhdGlvbhgBIAEoCSLXAgoNU2F2ZWRWYXJpYWJsZRIjCgVk", - "dHlwZRgBIAEoDjIULnRlbnNvcmZsb3cuRGF0YVR5cGUSKwoFc2hhcGUYAiAB", - "KAsyHC50ZW5zb3JmbG93LlRlbnNvclNoYXBlUHJvdG8SEQoJdHJhaW5hYmxl", - "GAMgASgIEjwKD3N5bmNocm9uaXphdGlvbhgEIAEoDjIjLnRlbnNvcmZsb3cu", - "VmFyaWFibGVTeW5jaHJvbml6YXRpb24SNAoLYWdncmVnYXRpb24YBSABKA4y", - "Hy50ZW5zb3JmbG93LlZhcmlhYmxlQWdncmVnYXRpb24SDAoEbmFtZRgGIAEo", - "CRIOCgZkZXZpY2UYByABKAkSTwosZXhwZXJpbWVudGFsX2Rpc3RyaWJ1dGVk", - "X3ZhcmlhYmxlX2NvbXBvbmVudHMYCCADKAsyGS50ZW5zb3JmbG93LlNhdmVk", - "VmFyaWFibGUi+wEKDEZ1bmN0aW9uU3BlYxIwCgtmdWxsYXJnc3BlYxgBIAEo", - "CzIbLnRlbnNvcmZsb3cuU3RydWN0dXJlZFZhbHVlEhEKCWlzX21ldGhvZBgC", - "IAEoCBI0Cg9pbnB1dF9zaWduYXR1cmUYBSABKAsyGy50ZW5zb3JmbG93LlN0", - "cnVjdHVyZWRWYWx1ZRI4CgtqaXRfY29tcGlsZRgGIAEoDjIjLnRlbnNvcmZs", - "b3cuRnVuY3Rpb25TcGVjLkppdENvbXBpbGUiKgoKSml0Q29tcGlsZRILCgdE", - "RUZBVUxUEAASBgoCT04QARIHCgNPRkYQAkoECAMQBEoECAQQBSIfCg1TYXZl", - "ZFJlc291cmNlEg4KBmRldmljZRgBIAEoCSJBCg5TYXZlYWJsZU9iamVjdBIV", - "Cg1zYXZlX2Z1bmN0aW9uGAIgASgFEhgKEHJlc3RvcmVfZnVuY3Rpb24YAyAB", - "KAVCWlpVZ2l0aHViLmNvbS90ZW5zb3JmbG93L3RlbnNvcmZsb3cvdGVuc29y", - "Zmxvdy9nby9jb3JlL3Byb3RvYnVmL2Zvcl9jb3JlX3Byb3Rvc19nb19wcm90", - "b/gBAWIGcHJvdG8z")); + "KAsyGS50ZW5zb3JmbG93LlNhdmVkUmVzb3VyY2VIABI1Cg9jYXB0dXJlZF90", + "ZW5zb3IYDCABKAsyGi50ZW5zb3JmbG93LkNhcHR1cmVkVGVuc29ySAASRgoQ", + "c2F2ZWFibGVfb2JqZWN0cxgLIAMoCzIsLnRlbnNvcmZsb3cuU2F2ZWRPYmpl", + "Y3QuU2F2ZWFibGVPYmplY3RzRW50cnkaUgoUU2F2ZWFibGVPYmplY3RzRW50", + "cnkSCwoDa2V5GAEgASgJEikKBXZhbHVlGAIgASgLMhoudGVuc29yZmxvdy5T", + "YXZlYWJsZU9iamVjdDoCOAFCBgoEa2luZEoECAIQA1IKYXR0cmlidXRlcyJk", + "Cg9TYXZlZFVzZXJPYmplY3QSEgoKaWRlbnRpZmllchgBIAEoCRInCgd2ZXJz", + "aW9uGAIgASgLMhYudGVuc29yZmxvdy5WZXJzaW9uRGVmEhQKCG1ldGFkYXRh", + "GAMgASgJQgIYASIqCgpTYXZlZEFzc2V0EhwKFGFzc2V0X2ZpbGVfZGVmX2lu", + "ZGV4GAEgASgFIlwKDVNhdmVkRnVuY3Rpb24SGgoSY29uY3JldGVfZnVuY3Rp", + "b25zGAEgAygJEi8KDWZ1bmN0aW9uX3NwZWMYAiABKAsyGC50ZW5zb3JmbG93", + "LkZ1bmN0aW9uU3BlYyI5Cg5DYXB0dXJlZFRlbnNvchIMCgRuYW1lGAEgASgJ", + "EhkKEWNvbmNyZXRlX2Z1bmN0aW9uGAIgASgJIqgBChVTYXZlZENvbmNyZXRl", + "RnVuY3Rpb24SFAoMYm91bmRfaW5wdXRzGAIgAygFEkIKHWNhbm9uaWNhbGl6", + "ZWRfaW5wdXRfc2lnbmF0dXJlGAMgASgLMhsudGVuc29yZmxvdy5TdHJ1Y3R1", + "cmVkVmFsdWUSNQoQb3V0cHV0X3NpZ25hdHVyZRgEIAEoCzIbLnRlbnNvcmZs", + "b3cuU3RydWN0dXJlZFZhbHVlIq0BChlTYXZlZEJhcmVDb25jcmV0ZUZ1bmN0", + "aW9uEh4KFmNvbmNyZXRlX2Z1bmN0aW9uX25hbWUYASABKAkSGQoRYXJndW1l", + "bnRfa2V5d29yZHMYAiADKAkSJAocYWxsb3dlZF9wb3NpdGlvbmFsX2FyZ3Vt", + "ZW50cxgDIAEoAxIvCg1mdW5jdGlvbl9zcGVjGAQgASgLMhgudGVuc29yZmxv", + "dy5GdW5jdGlvblNwZWMiIgoNU2F2ZWRDb25zdGFudBIRCglvcGVyYXRpb24Y", + "ASABKAki1wIKDVNhdmVkVmFyaWFibGUSIwoFZHR5cGUYASABKA4yFC50ZW5z", + "b3JmbG93LkRhdGFUeXBlEisKBXNoYXBlGAIgASgLMhwudGVuc29yZmxvdy5U", + "ZW5zb3JTaGFwZVByb3RvEhEKCXRyYWluYWJsZRgDIAEoCBI8Cg9zeW5jaHJv", + "bml6YXRpb24YBCABKA4yIy50ZW5zb3JmbG93LlZhcmlhYmxlU3luY2hyb25p", + "emF0aW9uEjQKC2FnZ3JlZ2F0aW9uGAUgASgOMh8udGVuc29yZmxvdy5WYXJp", + "YWJsZUFnZ3JlZ2F0aW9uEgwKBG5hbWUYBiABKAkSDgoGZGV2aWNlGAcgASgJ", + "Ek8KLGV4cGVyaW1lbnRhbF9kaXN0cmlidXRlZF92YXJpYWJsZV9jb21wb25l", + "bnRzGAggAygLMhkudGVuc29yZmxvdy5TYXZlZFZhcmlhYmxlIvsBCgxGdW5j", + "dGlvblNwZWMSMAoLZnVsbGFyZ3NwZWMYASABKAsyGy50ZW5zb3JmbG93LlN0", + "cnVjdHVyZWRWYWx1ZRIRCglpc19tZXRob2QYAiABKAgSNAoPaW5wdXRfc2ln", + "bmF0dXJlGAUgASgLMhsudGVuc29yZmxvdy5TdHJ1Y3R1cmVkVmFsdWUSOAoL", + "aml0X2NvbXBpbGUYBiABKA4yIy50ZW5zb3JmbG93LkZ1bmN0aW9uU3BlYy5K", + "aXRDb21waWxlIioKCkppdENvbXBpbGUSCwoHREVGQVVMVBAAEgYKAk9OEAES", + "BwoDT0ZGEAJKBAgDEARKBAgEEAUiHwoNU2F2ZWRSZXNvdXJjZRIOCgZkZXZp", + "Y2UYASABKAkiQQoOU2F2ZWFibGVPYmplY3QSFQoNc2F2ZV9mdW5jdGlvbhgC", + "IAEoBRIYChByZXN0b3JlX2Z1bmN0aW9uGAMgASgFQlpaVWdpdGh1Yi5jb20v", + "dGVuc29yZmxvdy90ZW5zb3JmbG93L3RlbnNvcmZsb3cvZ28vY29yZS9wcm90", + "b2J1Zi9mb3JfY29yZV9wcm90b3NfZ29fcHJvdG/4AQFiBnByb3RvMw==")); descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, new pbr::FileDescriptor[] { global::Tensorflow.TensorShapeReflection.Descriptor, global::Tensorflow.TypesReflection.Descriptor, global::Tensorflow.VariableReflection.Descriptor, global::Tensorflow.VersionsReflection.Descriptor, global::Tensorflow.StructReflection.Descriptor, global::Tensorflow.TrackableObjectGraphReflection.Descriptor, }, new pbr::GeneratedClrTypeInfo(null, null, new pbr::GeneratedClrTypeInfo[] { new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.SavedObjectGraph), global::Tensorflow.SavedObjectGraph.Parser, new[]{ "Nodes", "ConcreteFunctions" }, null, null, null, new pbr::GeneratedClrTypeInfo[] { null, }), - new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.SavedObject), global::Tensorflow.SavedObject.Parser, new[]{ "Children", "SlotVariables", "UserObject", "Asset", "Function", "Variable", "BareConcreteFunction", "Constant", "Resource", "SaveableObjects" }, new[]{ "Kind" }, null, null, new pbr::GeneratedClrTypeInfo[] { null, }), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.SavedObject), global::Tensorflow.SavedObject.Parser, new[]{ "Children", "SlotVariables", "UserObject", "Asset", "Function", "Variable", "BareConcreteFunction", "Constant", "Resource", "CapturedTensor", "SaveableObjects" }, new[]{ "Kind" }, null, null, new pbr::GeneratedClrTypeInfo[] { null, }), new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.SavedUserObject), global::Tensorflow.SavedUserObject.Parser, new[]{ "Identifier", "Version", "Metadata" }, null, null, null, null), new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.SavedAsset), global::Tensorflow.SavedAsset.Parser, new[]{ "AssetFileDefIndex" }, null, null, null, null), new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.SavedFunction), global::Tensorflow.SavedFunction.Parser, new[]{ "ConcreteFunctions", "FunctionSpec" }, null, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.CapturedTensor), global::Tensorflow.CapturedTensor.Parser, new[]{ "Name", "ConcreteFunction" }, null, null, null, null), new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.SavedConcreteFunction), global::Tensorflow.SavedConcreteFunction.Parser, new[]{ "BoundInputs", "CanonicalizedInputSignature", "OutputSignature" }, null, null, null, null), new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.SavedBareConcreteFunction), global::Tensorflow.SavedBareConcreteFunction.Parser, new[]{ "ConcreteFunctionName", "ArgumentKeywords", "AllowedPositionalArguments", "FunctionSpec" }, null, null, null, null), new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.SavedConstant), global::Tensorflow.SavedConstant.Parser, new[]{ "Operation" }, null, null, null, null), @@ -307,6 +310,9 @@ namespace Tensorflow { case KindOneofCase.Resource: Resource = other.Resource.Clone(); break; + case KindOneofCase.CapturedTensor: + CapturedTensor = other.CapturedTensor.Clone(); + break; } _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); @@ -427,6 +433,17 @@ namespace Tensorflow { } } + /// Field number for the "captured_tensor" field. + public const int CapturedTensorFieldNumber = 12; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.CapturedTensor CapturedTensor { + get { return kindCase_ == KindOneofCase.CapturedTensor ? (global::Tensorflow.CapturedTensor) kind_ : null; } + set { + kind_ = value; + kindCase_ = value == null ? KindOneofCase.None : KindOneofCase.CapturedTensor; + } + } + /// Field number for the "saveable_objects" field. public const int SaveableObjectsFieldNumber = 11; private static readonly pbc::MapField.Codec _map_saveableObjects_codec @@ -448,6 +465,7 @@ namespace Tensorflow { BareConcreteFunction = 8, Constant = 9, Resource = 10, + CapturedTensor = 12, } private KindOneofCase kindCase_ = KindOneofCase.None; [global::System.Diagnostics.DebuggerNonUserCodeAttribute] @@ -483,6 +501,7 @@ namespace Tensorflow { if (!object.Equals(BareConcreteFunction, other.BareConcreteFunction)) return false; if (!object.Equals(Constant, other.Constant)) return false; if (!object.Equals(Resource, other.Resource)) return false; + if (!object.Equals(CapturedTensor, other.CapturedTensor)) return false; if (!SaveableObjects.Equals(other.SaveableObjects)) return false; if (KindCase != other.KindCase) return false; return Equals(_unknownFields, other._unknownFields); @@ -500,6 +519,7 @@ namespace Tensorflow { if (kindCase_ == KindOneofCase.BareConcreteFunction) hash ^= BareConcreteFunction.GetHashCode(); if (kindCase_ == KindOneofCase.Constant) hash ^= Constant.GetHashCode(); if (kindCase_ == KindOneofCase.Resource) hash ^= Resource.GetHashCode(); + if (kindCase_ == KindOneofCase.CapturedTensor) hash ^= CapturedTensor.GetHashCode(); hash ^= SaveableObjects.GetHashCode(); hash ^= (int) kindCase_; if (_unknownFields != null) { @@ -546,6 +566,10 @@ namespace Tensorflow { output.WriteMessage(Resource); } saveableObjects_.WriteTo(output, _map_saveableObjects_codec); + if (kindCase_ == KindOneofCase.CapturedTensor) { + output.WriteRawTag(98); + output.WriteMessage(CapturedTensor); + } if (_unknownFields != null) { _unknownFields.WriteTo(output); } @@ -577,6 +601,9 @@ namespace Tensorflow { if (kindCase_ == KindOneofCase.Resource) { size += 1 + pb::CodedOutputStream.ComputeMessageSize(Resource); } + if (kindCase_ == KindOneofCase.CapturedTensor) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(CapturedTensor); + } size += saveableObjects_.CalculateSize(_map_saveableObjects_codec); if (_unknownFields != null) { size += _unknownFields.CalculateSize(); @@ -635,6 +662,12 @@ namespace Tensorflow { } Resource.MergeFrom(other.Resource); break; + case KindOneofCase.CapturedTensor: + if (CapturedTensor == null) { + CapturedTensor = new global::Tensorflow.CapturedTensor(); + } + CapturedTensor.MergeFrom(other.CapturedTensor); + break; } _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); @@ -723,6 +756,15 @@ namespace Tensorflow { saveableObjects_.AddEntriesFrom(input, _map_saveableObjects_codec); break; } + case 98: { + global::Tensorflow.CapturedTensor subBuilder = new global::Tensorflow.CapturedTensor(); + if (kindCase_ == KindOneofCase.CapturedTensor) { + subBuilder.MergeFrom(CapturedTensor); + } + input.ReadMessage(subBuilder); + CapturedTensor = subBuilder; + break; + } } } } @@ -805,11 +847,13 @@ namespace Tensorflow { public const int MetadataFieldNumber = 3; private string metadata_ = ""; /// + /// Metadata for deserializing this object. + /// /// Deprecated! At the time of deprecation, Keras was the only user of this /// field, and its saving and loading code will be updated shortly. - /// Please save your application-specific metadata to separate file - /// Initialization-related metadata. + /// Please save your application-specific metadata to a separate file. /// + [global::System.ObsoleteAttribute] [global::System.Diagnostics.DebuggerNonUserCodeAttribute] public string Metadata { get { return metadata_; } @@ -1240,6 +1284,169 @@ namespace Tensorflow { } + public sealed partial class CapturedTensor : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new CapturedTensor()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.SavedObjectGraphReflection.Descriptor.MessageTypes[5]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public CapturedTensor() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public CapturedTensor(CapturedTensor other) : this() { + name_ = other.name_; + concreteFunction_ = other.concreteFunction_; + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public CapturedTensor Clone() { + return new CapturedTensor(this); + } + + /// Field number for the "name" field. + public const int NameFieldNumber = 1; + private string name_ = ""; + /// + /// Name of captured tensor + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string Name { + get { return name_; } + set { + name_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "concrete_function" field. + public const int ConcreteFunctionFieldNumber = 2; + private string concreteFunction_ = ""; + /// + /// Name of concrete function which contains the computed graph tensor. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string ConcreteFunction { + get { return concreteFunction_; } + set { + concreteFunction_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as CapturedTensor); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(CapturedTensor other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (Name != other.Name) return false; + if (ConcreteFunction != other.ConcreteFunction) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (Name.Length != 0) hash ^= Name.GetHashCode(); + if (ConcreteFunction.Length != 0) hash ^= ConcreteFunction.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (Name.Length != 0) { + output.WriteRawTag(10); + output.WriteString(Name); + } + if (ConcreteFunction.Length != 0) { + output.WriteRawTag(18); + output.WriteString(ConcreteFunction); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (Name.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(Name); + } + if (ConcreteFunction.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(ConcreteFunction); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(CapturedTensor other) { + if (other == null) { + return; + } + if (other.Name.Length != 0) { + Name = other.Name; + } + if (other.ConcreteFunction.Length != 0) { + ConcreteFunction = other.ConcreteFunction; + } + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + Name = input.ReadString(); + break; + } + case 18: { + ConcreteFunction = input.ReadString(); + break; + } + } + } + } + + } + /// /// Stores low-level information about a concrete function. Referenced in either /// a SavedFunction or a SavedBareConcreteFunction. @@ -1252,7 +1459,7 @@ namespace Tensorflow { [global::System.Diagnostics.DebuggerNonUserCodeAttribute] public static pbr::MessageDescriptor Descriptor { - get { return global::Tensorflow.SavedObjectGraphReflection.Descriptor.MessageTypes[5]; } + get { return global::Tensorflow.SavedObjectGraphReflection.Descriptor.MessageTypes[6]; } } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] @@ -1285,12 +1492,6 @@ namespace Tensorflow { private static readonly pb::FieldCodec _repeated_boundInputs_codec = pb::FieldCodec.ForInt32(18); private readonly pbc::RepeatedField boundInputs_ = new pbc::RepeatedField(); - /// - /// Bound inputs to the function. The SavedObjects identified by the node ids - /// given here are appended as extra inputs to the caller-supplied inputs. - /// The only types of SavedObjects valid here are SavedVariable, SavedResource - /// and SavedAsset. - /// [global::System.Diagnostics.DebuggerNonUserCodeAttribute] public pbc::RepeatedField BoundInputs { get { return boundInputs_; } @@ -1457,7 +1658,7 @@ namespace Tensorflow { [global::System.Diagnostics.DebuggerNonUserCodeAttribute] public static pbr::MessageDescriptor Descriptor { - get { return global::Tensorflow.SavedObjectGraphReflection.Descriptor.MessageTypes[6]; } + get { return global::Tensorflow.SavedObjectGraphReflection.Descriptor.MessageTypes[7]; } } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] @@ -1685,7 +1886,7 @@ namespace Tensorflow { [global::System.Diagnostics.DebuggerNonUserCodeAttribute] public static pbr::MessageDescriptor Descriptor { - get { return global::Tensorflow.SavedObjectGraphReflection.Descriptor.MessageTypes[7]; } + get { return global::Tensorflow.SavedObjectGraphReflection.Descriptor.MessageTypes[8]; } } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] @@ -1821,7 +2022,7 @@ namespace Tensorflow { [global::System.Diagnostics.DebuggerNonUserCodeAttribute] public static pbr::MessageDescriptor Descriptor { - get { return global::Tensorflow.SavedObjectGraphReflection.Descriptor.MessageTypes[8]; } + get { return global::Tensorflow.SavedObjectGraphReflection.Descriptor.MessageTypes[9]; } } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] @@ -2156,7 +2357,7 @@ namespace Tensorflow { [global::System.Diagnostics.DebuggerNonUserCodeAttribute] public static pbr::MessageDescriptor Descriptor { - get { return global::Tensorflow.SavedObjectGraphReflection.Descriptor.MessageTypes[9]; } + get { return global::Tensorflow.SavedObjectGraphReflection.Descriptor.MessageTypes[10]; } } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] @@ -2418,7 +2619,7 @@ namespace Tensorflow { [global::System.Diagnostics.DebuggerNonUserCodeAttribute] public static pbr::MessageDescriptor Descriptor { - get { return global::Tensorflow.SavedObjectGraphReflection.Descriptor.MessageTypes[10]; } + get { return global::Tensorflow.SavedObjectGraphReflection.Descriptor.MessageTypes[11]; } } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] @@ -2552,7 +2753,7 @@ namespace Tensorflow { [global::System.Diagnostics.DebuggerNonUserCodeAttribute] public static pbr::MessageDescriptor Descriptor { - get { return global::Tensorflow.SavedObjectGraphReflection.Descriptor.MessageTypes[11]; } + get { return global::Tensorflow.SavedObjectGraphReflection.Descriptor.MessageTypes[12]; } } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] diff --git a/src/TensorFlowNET.Core/Protobuf/Struct.cs b/src/TensorFlowNET.Core/Protobuf/Struct.cs index c79f468d..c0879bc9 100644 --- a/src/TensorFlowNET.Core/Protobuf/Struct.cs +++ b/src/TensorFlowNET.Core/Protobuf/Struct.cs @@ -58,19 +58,20 @@ namespace Tensorflow { "YW1lGAEgASgJEisKBXNoYXBlGAIgASgLMhwudGVuc29yZmxvdy5UZW5zb3JT", "aGFwZVByb3RvEiMKBWR0eXBlGAMgASgOMhQudGVuc29yZmxvdy5EYXRhVHlw", "ZRIoCgdtaW5pbXVtGAQgASgLMhcudGVuc29yZmxvdy5UZW5zb3JQcm90bxIo", - "CgdtYXhpbXVtGAUgASgLMhcudGVuc29yZmxvdy5UZW5zb3JQcm90byKoAwoN", + "CgdtYXhpbXVtGAUgASgLMhcudGVuc29yZmxvdy5UZW5zb3JQcm90byLbAwoN", "VHlwZVNwZWNQcm90bxJACg90eXBlX3NwZWNfY2xhc3MYASABKA4yJy50ZW5z", "b3JmbG93LlR5cGVTcGVjUHJvdG8uVHlwZVNwZWNDbGFzcxIvCgp0eXBlX3N0", "YXRlGAIgASgLMhsudGVuc29yZmxvdy5TdHJ1Y3R1cmVkVmFsdWUSHAoUdHlw", - "ZV9zcGVjX2NsYXNzX25hbWUYAyABKAkihQIKDVR5cGVTcGVjQ2xhc3MSCwoH", + "ZV9zcGVjX2NsYXNzX25hbWUYAyABKAkiuAIKDVR5cGVTcGVjQ2xhc3MSCwoH", "VU5LTk9XThAAEhYKElNQQVJTRV9URU5TT1JfU1BFQxABEhcKE0lOREVYRURf", "U0xJQ0VTX1NQRUMQAhIWChJSQUdHRURfVEVOU09SX1NQRUMQAxIVChFURU5T", "T1JfQVJSQVlfU1BFQxAEEhUKEURBVEFfREFUQVNFVF9TUEVDEAUSFgoSREFU", "QV9JVEVSQVRPUl9TUEVDEAYSEQoNT1BUSU9OQUxfU1BFQxAHEhQKEFBFUl9S", "RVBMSUNBX1NQRUMQCBIRCg1WQVJJQUJMRV9TUEVDEAkSFgoSUk9XX1BBUlRJ", - "VElPTl9TUEVDEAoiBAgLEAtCV1pVZ2l0aHViLmNvbS90ZW5zb3JmbG93L3Rl", - "bnNvcmZsb3cvdGVuc29yZmxvdy9nby9jb3JlL3Byb3RvYnVmL2Zvcl9jb3Jl", - "X3Byb3Rvc19nb19wcm90b2IGcHJvdG8z")); + "VElPTl9TUEVDEAoSGAoUUkVHSVNURVJFRF9UWVBFX1NQRUMQDBIXChNFWFRF", + "TlNJT05fVFlQRV9TUEVDEA0iBAgLEAtCV1pVZ2l0aHViLmNvbS90ZW5zb3Jm", + "bG93L3RlbnNvcmZsb3cvdGVuc29yZmxvdy9nby9jb3JlL3Byb3RvYnVmL2Zv", + "cl9jb3JlX3Byb3Rvc19nb19wcm90b2IGcHJvdG8z")); descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, new pbr::FileDescriptor[] { global::Tensorflow.TensorReflection.Descriptor, global::Tensorflow.TensorShapeReflection.Descriptor, global::Tensorflow.TypesReflection.Descriptor, }, new pbr::GeneratedClrTypeInfo(null, null, new pbr::GeneratedClrTypeInfo[] { @@ -2116,10 +2117,14 @@ namespace Tensorflow { public const int TypeSpecClassNameFieldNumber = 3; private string typeSpecClassName_ = ""; /// - /// This is currently redundant with the type_spec_class enum, and is only - /// used for error reporting. In particular, if you use an older binary to - /// load a newer model, and the model uses a TypeSpecClass that the older - /// binary doesn't support, then this lets us display a useful error message. + /// The name of the TypeSpec class. + /// * If type_spec_class == REGISTERED_TYPE_SPEC, the TypeSpec class is + /// the one registered under this name. For types registered outside + /// core TensorFlow by an add-on library, that library must be loaded + /// before this value can be deserialized by StructureCoder. + /// * If type_spec_class specifies a particular TypeSpec class, this field is + /// redundant with the type_spec_class enum, and is only used for error + /// reporting in older binaries that do not know the tupe_spec_class enum. /// [global::System.Diagnostics.DebuggerNonUserCodeAttribute] public string TypeSpecClassName { @@ -2295,6 +2300,14 @@ namespace Tensorflow { /// RowPartitionSpec from ragged/row_partition.py /// [pbr::OriginalName("ROW_PARTITION_SPEC")] RowPartitionSpec = 10, + /// + /// The type registered as type_spec_class_name. + /// + [pbr::OriginalName("REGISTERED_TYPE_SPEC")] RegisteredTypeSpec = 12, + /// + /// Subclasses of tf.ExtensionType + /// + [pbr::OriginalName("EXTENSION_TYPE_SPEC")] ExtensionTypeSpec = 13, } } diff --git a/src/TensorFlowNET.Core/Protobuf/Tensor.cs b/src/TensorFlowNET.Core/Protobuf/Tensor.cs index f7db83e1..1ab87133 100644 --- a/src/TensorFlowNET.Core/Protobuf/Tensor.cs +++ b/src/TensorFlowNET.Core/Protobuf/Tensor.cs @@ -217,7 +217,7 @@ namespace Tensorflow { = pb::FieldCodec.ForInt32(58); private readonly pbc::RepeatedField intVal_ = new pbc::RepeatedField(); /// - /// DT_INT32, DT_INT16, DT_INT8, DT_UINT8. + /// DT_INT32, DT_INT16, DT_UINT16, DT_INT8, DT_UINT8. /// [global::System.Diagnostics.DebuggerNonUserCodeAttribute] public pbc::RepeatedField IntVal { diff --git a/src/TensorFlowNET.Core/Protobuf/Types.cs b/src/TensorFlowNET.Core/Protobuf/Types.cs index 32fa84de..6483cddf 100644 --- a/src/TensorFlowNET.Core/Protobuf/Types.cs +++ b/src/TensorFlowNET.Core/Protobuf/Types.cs @@ -43,14 +43,13 @@ namespace Tensorflow { "X1JFRhB0EhEKDURUX1VJTlQxNl9SRUYQdRIVChFEVF9DT01QTEVYMTI4X1JF", "RhB2Eg8KC0RUX0hBTEZfUkVGEHcSEwoPRFRfUkVTT1VSQ0VfUkVGEHgSEgoO", "RFRfVkFSSUFOVF9SRUYQeRIRCg1EVF9VSU5UMzJfUkVGEHoSEQoNRFRfVUlO", - "VDY0X1JFRhB7KkYKD1NwZWNpYWxpemVkVHlwZRIOCgpTVF9JTlZBTElEEAAS", - "EgoOU1RfVEVOU09SX0xJU1QQARIPCgtTVF9PUFRJT05BTBACQnoKGG9yZy50", - "ZW5zb3JmbG93LmZyYW1ld29ya0ILVHlwZXNQcm90b3NQAVpMZ2l0aHViLmNv", - "bS90ZW5zb3JmbG93L3RlbnNvcmZsb3cvdGVuc29yZmxvdy9nby9jb3JlL2Zy", - "YW1ld29yay90eXBlc19nb19wcm90b/gBAWIGcHJvdG8z")); + "VDY0X1JFRhB7QnoKGG9yZy50ZW5zb3JmbG93LmZyYW1ld29ya0ILVHlwZXNQ", + "cm90b3NQAVpMZ2l0aHViLmNvbS90ZW5zb3JmbG93L3RlbnNvcmZsb3cvdGVu", + "c29yZmxvdy9nby9jb3JlL2ZyYW1ld29yay90eXBlc19nb19wcm90b/gBAWIG", + "cHJvdG8z")); descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, new pbr::FileDescriptor[] { }, - new pbr::GeneratedClrTypeInfo(new[] {typeof(global::Tensorflow.DataType), typeof(global::Tensorflow.SpecializedType), }, null, null)); + new pbr::GeneratedClrTypeInfo(new[] {typeof(global::Tensorflow.DataType), }, null, null)); } #endregion @@ -149,27 +148,6 @@ namespace Tensorflow { [pbr::OriginalName("DT_UINT64_REF")] DtUint64Ref = 123, } - /// - /// For identifying the underlying type of a variant. For variants, the types - /// listed here are a subset of the types in the variant type registry, - /// corresponding to commonly used variants which must occasionally be - /// special-cased. - /// - public enum SpecializedType { - /// - /// Invalid/unknown specialized type. - /// - [pbr::OriginalName("ST_INVALID")] StInvalid = 0, - /// - /// "tensorflow::TensorList" in the variant type registry. - /// - [pbr::OriginalName("ST_TENSOR_LIST")] StTensorList = 1, - /// - /// "tensorflow::data::Optional" in the variant type registry. - /// - [pbr::OriginalName("ST_OPTIONAL")] StOptional = 2, - } - #endregion }