From a5bb191a1d63ac10b744d252a6c4d9d5642524b8 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sat, 21 Sep 2019 11:27:52 -0500 Subject: [PATCH] add tf.FIFOQueue and docs updated. --- docs/source/Queue.md | 82 ++++++++++++++++++ docs/source/_static/FIFOQueue-example.jpg | Bin 0 -> 30423 bytes docs/source/index.rst | 1 + src/TensorFlowNET.Core/APIs/tf.queue.cs | 47 ++++++++++ .../Operations/Queues/FIFOQueue.cs | 28 ++++++ .../Operations/Queues/QueueBase.cs | 53 +++++++++++ .../Operations/gen_data_flow_ops.cs | 52 +++++++++++ src/TensorFlowNET.Core/tensorflow.cs | 6 +- test/TensorFlowNET.UnitTest/QueueTest.cs | 40 ++++++++- 9 files changed, 305 insertions(+), 4 deletions(-) create mode 100644 docs/source/Queue.md create mode 100644 docs/source/_static/FIFOQueue-example.jpg create mode 100644 src/TensorFlowNET.Core/Operations/Queues/FIFOQueue.cs diff --git a/docs/source/Queue.md b/docs/source/Queue.md new file mode 100644 index 00000000..bd73fd5a --- /dev/null +++ b/docs/source/Queue.md @@ -0,0 +1,82 @@ +# Chapter. Queue + +ThensorFlow is capable to handle multiple threads, and queues are powerful mechanism for asynchronous computation. If we have large datasets this can significantly speed up the training process of our models. This functionality is especially handy when reading, pre-processing and extracting in mini-batches our training data. The secret to being able to do professional and high performance training of our model is understanding TensorFlow queuing operations. TensorFlow has implemented 4 types of Queue: **FIFOQueue**, **PaddingFIFOQueue**, **PriorityQueue** and **RandomShuffleQueue**. + +![FIFOQueue](_static/FIFOQueue-example.jpg) + +Like everything in TensorFlow, a queue is a node in a computation graph. It's a stateful node, like a variable: other nodes can modify its content, In particular, nodes can enqueue new items into the queue, or dequeue existing items from the queue. + +To get started with queue, let's consider a simple example. We will create a "first in, first out" queue (FIFOQueue) and fill it with numbers. Then we'll construct a graph that takes an item off the queue, adds one to that item, and puts it back on the end of the queue. + +```csharp +[TestMethod] +public void FIFOQueue() +{ + // create a first in first out queue with capacity up to 2 + // and data type set as int32 + var queue = tf.FIFOQueue(2, tf.int32); + // init queue, push 2 elements into queue. + var init = queue.enqueue_many(new[] { 10, 20 }); + // pop out the first element + var x = queue.dequeue(); + // add 1 + var y = x + 1; + // push back into queue + var inc = queue.enqueue(y); + + using (var sess = tf.Session()) + { + // init queue + init.run(); + + // pop out first element and push back calculated y + (int dequeued, _) = sess.run((x, inc)); + Assert.AreEqual(10, dequeued); + + (dequeued, _) = sess.run((x, inc)); + Assert.AreEqual(20, dequeued); + + (dequeued, _) = sess.run((x, inc)); + Assert.AreEqual(11, dequeued); + + (dequeued, _) = sess.run((x, inc)); + Assert.AreEqual(21, dequeued); + + // thread will hang or block if you run sess.run(x) again + // until queue has more element. + } +} +``` + +`Enqueue`, `EnqueueMany` and `Dequeue` are special nodes. They take a pointer to the queue instead of a normal value, allowing them to change it. I first create a FIFOQueue *queue* of size up to 3, I enqueue two values into the *queue*. Then I immediately attempt to *dequeue* a value from it and assign it to *y* where I simply add 1 to the dequeued variable. Next, we start up a *session* and run. After we've run this operation a few times the queue will be empty - if we try and run the operation again, the main thread of the program will hang or block - this is because it will be waiting for another operation to be run to put more values in the queue. + +#### FIFOQueue + +Creates a queue that dequeues elements in a first-in first-out order. A `FIFOQueue` has bounded capacity; supports multiple concurrent producers and consumers; and provides exactly-once delivery. A `FIFOQueue` holds a list of up to `capacity` elements. Each element is a fixed-length tuple of tensors whose dtypes are described by `dtypes`, and whose shapes are optionally described by the `shapes` argument. + +#### PaddingFIFOQueue + +A FIFOQueue that supports batching variable-sized tensors by padding. A `PaddingFIFOQueue` may contain components with dynamic shape, while also supporting `dequeue_many`. A `PaddingFIFOQueue` holds a list of up to `capacity` elements. Each element is a fixed-length tuple of tensors whose dtypes are described by `dtypes`, and whose shapes are described by the `shapes` argument. + +#### PriorityQueue + +A queue implementation that dequeues elements in prioritized order. A `PriorityQueue` has bounded capacity; supports multiple concurrent producers and consumers; and provides exactly-once delivery. A `PriorityQueue` holds a list of up to `capacity` elements. Each element is a fixed-length tuple of tensors whose dtypes are described by `types`, and whose shapes are optionally described by the `shapes` argument. + +#### RandomShuffleQueue + +A queue implementation that dequeues elements in a random order. A `RandomShuffleQueue` has bounded capacity; supports multiple concurrent producers and consumers; and provides exactly-once delivery. A `RandomShuffleQueue` holds a list of up to `capacity` elements. Each element is a fixed-length tuple of tensors whose dtypes are described by `dtypes`, and whose shapes are optionally described by the `shapes` argument. + + + +Queue methods must run on the same device as the queue. `FIFOQueue` and `RandomShuffleQueue` are important TensorFlow objects for computing tensor asynchronously in a graph. For example, a typical input architecture is to use a `RandomShuffleQueue` to prepare inputs for training a model: + +* Multiple threads prepare training examples and push them in the queue. +* A training thread executes a training op that dequeues mini-batches from the queue. + +This architecture simplifies the construction of input pipelines. + + + +From the above example, once the output gets to the point above you’ll actually have to terminate the program as it is blocked. Now, this isn’t very useful. What we really want to happen is for our little program to reload or enqueue more values whenever our queue is empty or is about to become empty. We could fix this by explicitly running our *enqueue_op* again in the code above to reload our queue with values. However, for large, more realistic programs, this will become unwieldy. Thankfully, TensorFlow has a solution. + +TensorFlow provides two classes to help multi-threading task: `tf.Coordinator` and `tf.QueueRunner`. There two classes are designed to be used together. The `Coordinator` class helps multiple threads stop together and report exceptions to a main thread. The `QueueRunner` class is used to create a number of threads cooperating to enqueue tensors in the same queue. diff --git a/docs/source/_static/FIFOQueue-example.jpg b/docs/source/_static/FIFOQueue-example.jpg new file mode 100644 index 0000000000000000000000000000000000000000..ac27493461a4e8a73675ab831b285538543a1af4 GIT binary patch literal 30423 zcmce-Wmp`|5-_?12=4Cg5Oi@TSa5fD_r)Q>gFC^4ySoQ>cXtVHfdsvqyvN?}d(Lyu z{d2o!tE;-FtE*>fdTOeBUzc9D0H_}%q$L0l5D(a@|4vGLrq6b6P(x5t z{3p`LI{jap!$){UQoD5EP$Xt9K6Tr1Wt8y6czDw2sHuHg@-FkSswuBKLi%5Ysp>3j zccnJp!}~6Xb_xtY$yz*U95>*N1P4jiPInb^{I7CQh+~Ljc+MlzcnP101&F==){cL+ zJ6p{1c5R#Y+LW%{XlYiah&TKZ*)fS&mb_QKBGApB(=BCfV)yG!NZ$qFv-HaMZHEZ2 z%FtNeJ04i}ytJI9>)`z$q`p^xhn0m6%9q*ZpHiM@5BY$y5*ZO*$=~jzNn?? zq<}zPy4KDaIt6Ca+&0iCr2e;}=8`&gh|1;$2pv<{{Cur{bPZd}6EDB&Gelqp`^fqC zP5}VIP02BW1U80yiB>W_HSHuNCf8(+6gM+Tje%QAZT)y2ox;2J&o+|w8a>S$W(HhJ z;+s>h&>GB(8u&n|v4VCFyJ=<-b~;q4G~XIJuLmBp2>AH;Nxg2M9^$OQj5-l|Yu{lj zvALG}#4V<2DjMVX*FPE%drL@sH&pmPwfTz&4+Vha4FI4Xz|IK2joSMAS&(`-qpf-0 z|6cev4n^Q^EC7JSjjj}wXhenU!*_zUu0(|^ErG7oK?wlBgjT`)Cw%=!LloL7?!YDw zy0c;jq^C{EUZS8`Ty2ruqjYjk0kt&4Am+Go z?>?}KVA~q>GU9|f$<}t){7W!Ikg(!O18tAs@9O;FcKG}#1Q;cSOeU>nb|*kthpl&%)Rm}Y zIAJo;&EcQZ-OD+8#C~(0&>sO2v_=^`AAidu zx9BAivD{FfI23>3(4Ht-V-s-HleWTFAa-lHSuCu`&^#45xaO5@<`XLzja_p4a&1ys z75d4Dym4+7C*}-xX>iKb@VB5p38D}5B=4n(?AI^Qr^EAe-A5;^ejSlQH8piqPiI<1 zYXT6dmjJ*l30m%k+iOObKs!EJGPbS{a3IUu~cNk-_np}DzF~S!WrOl z85x|a3;~5LmkAvi^POwO`CnkOO>(=~Ev)w_znTbTiFa4vPDxIB@!ODUG=9 zW7O?>)45dERPvz>HiIRi9*BMUr{w>Vf8zJiICNrx~PXbq{riInoc+csG!f z+4);u0N`z=X};24>oQcn5?z@k_{LW6^*8q#su@rHs(8oa|1U z8*DdajUQ#P{n+|nU^yA@nGpso6;U=#;-Csyqk5Nj)@&QZ9E`cW-@~R7UXJ z_^VWDt1P370f6~No-!ezy+wvzpHE@KnLU0Vx15JLaa4*{WNKHnJ|eq;gu1 z6~9N3G+Jz@K=Sp5$+c-ovWimzao~S|UVi8MJ>GACH+ji*1t*mCp7T^zTAF6)r&ZS@ zk%_aXdT#t&E{8c_7Xn5g8U*cXvoh1_hrO@oAot!)82nV?qV3u(w-5~LON5+85eSS_ zr@;4efE65!JO0Vuv7F~QFx_sv%v{$~FX+i}O?OGN;#I2d^6uMBn>$v9P*jSqyC6|7Wt? z_r`L{0!48Tn*ZwvFH<1}pT~3RQq7Q@aRt*l1qRoTLMAK0dQnF$MP|d}ETOUMT zzyp;hjSu34wS!pPWbcU+a4>`C!V6+(|0e4{8Q8rb6sl@r##;yhfYheJcL9KcFvb+rEnO2S*;tSq zqD=pB7KtMr2LO{bnK|q1sU+^^ryTee>Xy_Q2s8}p) zKvFVRb`C|ux8xMUoRrieV%St4W5KR6EZBjDfPV7|IC3E-i=irET39%-acgK@S+}!q zY|_!GsQ6t_Lb|!Lnc{g-#x-*H3b@NrKTUME7+oB33KWi|;NO~j6Iq=4mROlBa(UH7 z*?;?~l|-}p8KU}Jz3QNUq?R@Y>Mx@f;xy?8QKhhDJv)cUQk@8nJE$VqwSAC&dO zva+`IgyEmar6u-#H_ib*+jMedy;%}8dk;&bQ63YX<*3BtyNWY6u@j=@w4_;GH_)nb zB|$Q*j24?ZPh~SL{EqBl!>BPgI&(t;Kk@uy4pp{*T^R&HYQglq zm8>bO=P1kTx7IYo!mj{J(0tV{q2@PV-euxdOwZ`47y5>k0+l(F!6~cGK%&gc)h9jhAgLF6tj9%h_BME&ago4 zK_@5jiSnjAtpH)mEFqxs5%)2?0#xtBRn3_ehY$$ROnX0cc#0NFa8^=dAS019^d|DO z2Td-{#)-xIDIPfQcuozz5v#1k-Eq`zU@uP{B6|fCmnCWGIG3oZXQ(fi>kM%UPj{zr zT4_NPG1i&FGrH)+W|*fRXkS4g*BWDZuTo1#;7c~Zt?--jVh?l7G_9*MqH+pbEsco% z5>nig+8I;~L=Jof#O^)c=Sq*VuSz&;a=G{4nQuNLwM=H$R#k#ve%>r|I}FUGC#T%z zKS;EjC0D?{cYO4IBCZs6|WvgsE;v+VxI{+4%KB56ggla0x$`AxtJIxv*QiS=aH>|NNb;bE0((wD_zi)JIuq>@p~D+Oa!st;2# z%5|z`vLI;0pvg&4nNArzD9U)N9Mk_w)Mr^Ye;k#om1AqKBVnCBfB;r#+TF6g=NE#% zuT%Ho>5xJ?9u=RsE5BxA*ze{0{#yCK zHs65QMwS4&OEBl!@%#jP2_5(yn3kM2^5MYnLnkheJ^jz)U_+r%hXEetsDLjRX_HrgSf{h;W)NAJyCL61L1_3FTzfNvLN7X zpGw{Q0>lG-d^L}Fo zfj*&H=VV;H){?E8xqWg#6upYxrkcB(+WHR-ckn#$hn)X6`KNSUO`VnX&No2r$FA8R zbS*+O!+!y~{N0p$ctIs~6giD0(0s|lcZNq@)MHp!d;`mvvdR73ZmFcQ@ai-?y3A4? z8^-%XcJ{rsYg1*WFzW8mR*q(x`lSi#FaG+zoG}i|-olZf zaR0$ugWS`KPR8bgj!qnZy|QolwRsGUk)VzbrKaRVLV_rgrGlcVal}?jU+Ri%o+~=+ zP#=~zS}j6;i~i<1+41Zzj}f3p!~9TbcXG<(b1#7%b-Rg{s*V&4@}fMrdL*6_#XA?pA-5sn-{x3?;8+rhR5O| zrgcTW0^sN679Lp%$zVk(!ZSbYWssruo;&_RPvF>x+11R+f<1B> z`b4x^13d7yTR+ zTcg|X3Q*7noGbxyOx{sjZdJ9p))}X8&3?gsjJMyd?tkY^r2DN#Y@S^(dD?1mFGY@) z<#@Y;{G$9703xnfO}bqrIO!Ska7tlBhaM8XBo=SNVEDcuHW?Y$*8LnAGc!26gRs5Z z)Vdhuk17}f6BbGUnchi>PBBNAT@W+ab zosO~+lzFUuVO+HP2#giwgE6nS6P=e@RvutQD|(DJ2>SI32wxoWZ}HL6K%A>dT6cvGSdGIU$?DzR-BZaH0e{Hj zflm#s@(BL58Lq$YB^O+?f#u0-aAf{WFDOns;55A8sT`9ew%9J_Gf7Y!p*3$^V9d&2 zI29q|>h{;H$+`p4ej65r5%KQ{^j|aSe1@yoFaYaYM# z+b?fRu7wtvoA`;&ApR?b?cja-i0){r3vFD{Hg8ZQ|M+!g4|U^PsCJh{3e-2d08VEGryqTd01>3I`4kQ8xhV| z!LNtp9N&L0IzM+;o7f|TgNqkK686Hf{|ZPXaT<9@>QZL69b2=PO)kGkis;}EY8Zg| z*`s-m%z8U$LQ%jkS%fk7#Q)n+!XCR+4-IVnw-ufOza>&oT8g3>My&=QfB{~KfFJ-a zclz@54W~`BR4#daJ0~G-G&hgnZmL~q6I{n|H1=EGOUyvEV6)}?=|b9}dCf-EQh4R8 z0BjByEbJHC$2hXDJ=q_M}|KpacL{##2}+|H=!MmzK#O7DUocAhDgl(vH+3LB#yv;)@kcN2y+r5w9h?XTwY@?ee z?<8K*rIU1N>(lrXlydMB1pHnIep&U2+#X$r`kAJ&w-s*AaW@6ZW$_zwpMDO-d&wWI z6YU6)&>CThSn7L_$C<6mO4~%~QHU-x*Bcgbr-DXg!m;1`nvNing&MZIwe-vG@{DWj zirKj5L%B3O8sCK8vfO47`c5ca@dE!YY@ZSAH$hs~545<(d75ltIWV|ev~QXqCUvp~ zw{w4bB!itPKDomy9;3FJ(RU?#V|mjB$D1f zp{^z8?zqn|eny2Lf7oje4JsC)?;m}6NhKKXRBdfUFc~Z?HL!+z&`_6`SKLA^ejvYp zQ>DLHp;1VH?IeHkx#t1oGE1=i6Oa1z$+mf`Irl%{>wnPT@-%rKZ=|PWKG?^?X>)BR zJfEh^svQ$VA4WvcsMzuS@W?3o*!0_}nR&hndlO9B<=)#HD=jj zu;E-4??972NbV=a(=9hg?kiy6Izimod!MN+<^A=v^j-hnI5C*WytNx|iu2LTX6l3+ z#x2SS&v_fLg0Y?tle9M6rqIt z+d#X|f(DrDt*+gH-a_}}AcbKRn%&xaA}6*;e0n9TcRYGJCnq~ zrn{>`JKQV-d;j|3ote|@4z89iq6~by>d1dNX6iFcU}x-KEH>CTQzVW#8)2f})>YPh z_mvYI%?QZ`2U}>-en(dzVBbJM|1;JDf%XOi5*3|{oRv+HMTC?>$?+fI74Y(<5X38h zpjv~he}u4eA|d}fuY4S%R=v`>0m2U*W(umYAWP{C#!v)w&JCTEo@e#ZQ5PpdpAx0M zh%USR&cp%|=fut`i~W{6VruDHcl<-HGZlxF{l3ung3dM-ioXGwMf~Zw&&#j0FM>a_ z^Irk3Ry&WTb~PT?<`kz!o9p{g+{h3nB&B}gNm!W0d#cKb1#}<&iAZjl?(Y?YNT*>H zY&&tzO8>EjOL9f?*N7n?!I33!Z0X-4R(b=8jzNY>%Bo1t!Y1r!X!P++_k$pv|pc;-lKtd|Sy~*>2gr%6j;h z;*FMAQ@pn#f1&=jT7qJ)0PzN=_t$E>>oeu%=1dAdGW9$KWzU_z4>iy;#~oX^9MTLw zg_V}e%9bY3=a2jkMBN?DL6v=5bUXgJaO1g1>24c-B|B^I68icOP}kvYdaPB; z|0e{UubPAAZTa4U8w=GG4;ixp&mYbgNMbpd$a6D|me_Ld9i%448+puLJdwxSWmdiM zgO|D83G6q&0>GD85*)#U1TSDi!NWnmfr0^F=HH9oXy{O6tf&}lBIGPePMD+=?4pWB z&Y$8qgvFE%T@rF?ppD~A0B=v0#2$Eg$VvmA2`;H9O*pe}7}qec3|$X*8sBIyLeXGE6oL&$zX&ldu{Z$h>LxgED%a4F zPI*7M^CbD`9~Fx8F)!lhF8_Ra53GO7xu!IrIPx|;#F0N;Q-p4*o#6ea-6A5(G@_t& z$?rHkLTC!q#*LlxfpGy9QGG4q5k1vX+N(0sx#d^qoA)+)@!p459$9Ek1%w+@t@71( z+|0Z~4p_Mknkl&+?&($zs)~+gYW;Hxqk6lF>^r!3D#9a-I0q$y@a*r% zqN6j&B1rVy!5NDwEO z4of_7qO0#fBj2{EP7lwp@Eul5hs6rPFGDcC35FcCM2BCohkLA%S$De7>DZKXauaA! zX=sHV33~+;Vt#4)%$}Bh(o@}uMo<_phl@lXBudmvf=89222Ah#9GTS2bnRMzk#Ayo z2tesJTH_8nD=MYWX1HvWM*HgYLk>Q(W}GYxzLAQgmD4e*A~U{~4p5-8RPItak|g`7 zFee<(!Ev1k2H58GwV5bDEId}*dWwe`CRu~;vE?ftMqz+^&*9bTxCa$~>f$S)0kR-* zhum=F;(f)FiXaV|q|Uf>DA1b{2$UnnBa7+`$LMFz^`*5Mqsyc0s7-6#!vz~rNR+bJ zcrV47NbLd-f~D0kSC5nPhk>*umEDi7nJLbuZWXy%G{G7LW{L79H#WiWO@8;aiAe60 zcamY=WS%9Fv_bZ)yu>ibQpN`C`SFS4oaRzT`IZIaB9(4^E_|fL*GjIF_18a$8n8~4 z?Xy8ApQXd18zn1cdZAZ4_lfDb*GUtEt7oSaZ zU0h&JoCZAdIUtQ(eMo0IN~$&Q`NMBTH)Ji*&=MTy^A&%Fa5c|4DXNF@gnT^$r3_h4 zm#DL(x2P(&nMMM~T)je(D(j{>NDD7GoC^HPd@c|j8br3nk$d^Lc%&oId8;rRq>l0c z!~JLCkx@C7b4*K7jlL_SQ_~`&^8py_S%dHOB;Cs?S%%xt^_np!_sBsf_uEFRX}0J$ETa$cAA6z z^da+Xj-^;za7*l39j?B(ISo|KzK;EU(j+i_M`6!BeNAlT=3dm2|1;!aL&e^AeE6Q` z-d4s8*3RiBg0;zA+f#gjPb1dXcB;(0$tWdh@L7G}^%&(|Et>~g)_^E{6iw3*>S|p) zGj>g_VPYwQPI_56al&sa)W!_Zny`=@M+KxjXF0iO6DBVBvLH)PDWH;pmEWVMDv%q-xj61H*YEaNZ0rk+%_x<%m83(w`zK;M zDhcaKzh!8}&{;zA?3aQlV2_^igdM%oQTUEceHWn?xVtUq*G4Wg$joHy9KHJbKYM_l z=Zoc{j}UZ{p1#*gX$!xDcX7Oq4S~j7GVVDTG(yo#rH-p~9p4}Jpz1lc&Lpe^tWICZ zqTED(7iupqg3MuXx4$Enz9J%Cp{^Ac+dU!)Fxeoo3Sas7;ExVR^@&?iBb;c>Z*&$7WSD?s_nRIZ@5cl#j`h*Iy|r5-VI zsWU^Xi?gF#m_mctQ{`~AxXe!<8%?MzIk`E5!l~KerQ$B$9;1t)a6Mg$cS>Y!_h#sW5MOH(sRVWv9{*cpno^L}}+_9~32^ zIz(fV-8o9t*4wy{U@)k`vL+C&ZtR#-5qV4KPc$|}#e_qcU}$UW($clK>~&|kLNA*N zJ#$mJBUm?VV6?Grev+wJDUQP@+M8-sW5Wtm2(a zyD9(me`ncvjFIE(5Nd~7H~Mbbz`tjSQNJGQn~!jwblfM$RrvV4r(9fVL(Oz4 zIVE+s$Dl;B;V6hYa1>S4!0xjgoGt5Nj%EgBIq2?C8?aN_&AaGHq4D-a94~5C;A~89 z*f=M%vSs3(84Nmf1!d^ELU=c>_Rc~ZyLi}T9XnKXOQm_R@mr)SF5dIx0z-EuPo(%t zj#(4@R;3Z6{?z|pF6_FTk0bgD64tmz^u%X{^ z@J0`J$vgap%rmxqNZ*v14JQ?0Ecih5Uj-a6&ZY{Ue#BxT?zujS3&-l7KUZvL2WHL9 zD}?l1{4NVr9TytmH3rtzPfvC)N@fc%5CaXFIblqwJwnHYm;c<^1?UD zVh-sCc>$Z;KvCR0wh@d8H#mj ziB|or(Q?&1LUbT-n{4C5YH8Zt(Y8|E1lK_+5=chJM4qWA9J-a}5~`-IHiDZqaWmcQh|Xnj;{ng9nj={m)J3!I6X)T0|6@$+f<- zl`|{zl8j5~r&C(upNYOQvK-)B7a^v*?o5AZY=GyI3vV4%b)~EJ-bTJaTVF-;#u7Gw zn)q=kcOfg*0aR@EN#JC4LWzgKN;o+cZFxtXln$ZCnOym~YEJK&P(3DX`PUGsV zbi^EJTB)^15|qd(x|A-)NEZe>Fvp#fcn#c9q zqrjBNBA^c;%&|DXYwXz^iX;%Lo{-D(!smhwK1pfhqC2eM-P+wX< z=aEwtSK+kxp$wK{xo%`=L8HEGG&r+Cg7d@7)et6EEimbfw)JiW2oz3%(8DtG*4j|N zC0z)nJQHYpK>^&8dd3E^|K~&?F`H@NY2;OQ-7M!hf2YG29 zbk8g^L*f%%?7<5tLi)dr=!!BIM0ChG-P>HTD_zS&fGfOom@r(-W|@r!N}+ZTpm@(5 zd}cJgFR|!b7>q~qRnCo0$p95>D(Ch0&mv=|0YMejnPML;5oJDa2VaN_t&N&50d(Q2x}zgKY9 z)8bLd#1$zwu$rn zpRdk|&r`o6TzCb{HqMzsUp1c?6)2jx5;Np~=V%M#xS}4oTM^-*IS(bMWbo8reHulg zkH%fobOF+{y#i{1)*I<*&u~@+>g0G@61-w@r*akN)$P1hoXRRRUcy~I)!+Cf8`ydMA5Ek(4fT1l-qk@d|kim zXgoR>w_Fgl9Hd>OzmR@k!=@!<`VMY}CA^I{G=}!BjGnVqbIs|tqnI!I#+EacgtVaS z90frXgqX)MfG#Zs&5l`s_yU^~BRvbXQ6f7=S$9_$?UGif?eUsO0}&wvC2dILMwf;D z%eH$prg%X-`{&RU?Tw4>3bNG}t)9!0mT-e}l-vgbAthg=?BdRl$vFcti(}@Tx*s@G z%MTGpw$JZPaMg9$DLj7H>G31lUwCL4{#7G>%jx&0sYk@K~R z)MWUPJjn(0>_?$>hEXvJL4*mF4p)7YLEHMncuRbU>TL(caO(Mm8#$H_&&!U*R7jCU zMR-*@B-<7r6hy@=^kzO6ZMBj%HCIuDdt7Mo@?b1;60Z$PVe&QdYSo5=u0WVemj{=G z2pJr(wxQRP;D%ZRp;&P}pFDu48zCECXfR{pNs5OzP|ueoKT0;dHOjl~&J!gdd#YO) z+*}N!cVKF>F$T8ugvqkb?{UOY%)X97XBzW0Q2U5!! zSp>)T=Z9s2)AWvzkZr=F*;&1bXDlVcMVDm&_J%G>EjKM&)vi*rrZWyrDtVy-?#~ciBaN8xIFfj>wsR6)QQ9k_ zv<98yjjStJKO7Dv`Z#q{-E#`0DWK%M$q^Nzu5ce z<`tcF)nvVfIqqy>bJpilh~9uWM8_v{8M3`>+QdPUgja*zF2u{DRV)8I`FZdQLY%ut!PbuSRvlhq;21dX=E-zc zI6hBb7`5a{WwPm_+pXqgWK+_k=ts>od(u;oyzwavzVM5VcV4XE#(Zdc_a>n&3>Tx3 zi&V!ChX&YvyTT&!{NA>3n(X)(#0|Q9$n6tGl71t;kxz$`91TXeJm@>My!0{R-|`eb ze575Hca7&EG1Bo8CQK(0|E%BX;X6j^))I3|^g zlsj!oH{XfJVAGF;iZ*T|T)?E}dXGHn)pI%RH23v-+HvY(6al%!&`h~|$RrFpngQ*VP7hF0}d;(J-U z%rx4T7leM>XdvN03a|qYaHR%Fu<~7N@gwOQ{&^dTiQA8p0E*NFNDC6McOco*toJq&G4* zjk`ua#z7`I(D;A(sUdugSz(t5>XdZiKlTFKiO>PX>??rDmn^%mGk$W`K-%JaResa$ z;*8FNHPUU$FZPj4b+NRnuy1ch%i=2L>Vr<{VEh)ewd64Uj+UX)Cn7XOI2Zs2=BvkD zck1MX1_{)nTT(j+9xJOKa@e&G0$iQJyQ4*;ZYd=hbgh=E3*lRGiVDPIMm4^2j1TpB zSk*r0WXf#c%m={lq@Th8AW zS=;pp%i#OrtU5N*7Hib%o0ghlTMzC^5Q6Q3ltitBX{;Dv@wdFF{1qu7XNA2-GFI*( zZ^YZz4nEyH;kZ@@oX;QlyB+wu+TaNOru6!~M;bnIAd6MSbX8L44ghW%j*{CNvWD2$dP z&t!>n<*ioJIK5QhqBj-0u+fv_9V1HN)=_lnbaCP*$iVk7!)zx>e`bScWX7v?fNdV} zb?bb8)>~%-bejQdI!g&31p>lyhw}?D!D1Re43N8tgN&Nh)>c(KBd2 zji6`9lV9a@iNkqk?7|gOR>4j@+Q}!=nFUYmthXlVw8qoPj-dAIh*VlFMP9hrNzmEI zPY<)1P{vD5N}nqyFOH-owMI8Ss*XR>@E$y3B8|EZ>R&imY-zaLj>B8k<Q|fqP3p7Hc{UGKRLiFFifc66&LUt3^aHH^Fr;Zy={@iIE8M1qi5f|UNUy_6HDfDT>}d_LVPOai?8#0T&lDT z`4TH#DhBi~iv-PgFPkS%$O}wY+s7Mlf}CVHqsmukcJd}eBe2%N}$$;qhAe~JyJQR% zVg~It`4qO5<0kD5X4?}d<-{OE%T>G-5ORiYX}6EAB2yoPrF@)>J5z&%-bW&nCqSkR zYjtFs|JBcUgWGp8eFq#ciSKPB8`t?{J3^?mJ~eN;-YYzX2eeepysypN|L9QZ7A`zr z!O9?R=2u>;m0V4-HXo<4;j(<2(pv%EJsgFphh3I>jAPxm;y(!!)I(Tr^ej3e8ZBcM z%4nl^!8)Y9;TG`Y`=SF_RRF=br%M8>TFr_5crI%1#>vfOYpT8bJbYDcoM zykhr5qWbW6x+#GsIeMJLwAlqLYFIa}@0g-1Iu(7M>-sZmh*c_*Ar2f*0^{!@a*ZL} zQ?5w2jg`fZU--O0pCA@rF!U_G3BL)5d2^C-4?8b9^uweFh_JDX#26Lf0@q%iN&0oi zDtD!5#(JvEtAL?4FNw+a148VubL-0@F;$()_yRf6lNHHZzKns~NTVVg+?$QdFg<{b zV~VaW!-rF{xSP~a?dulOX-7qP7mP*<`DnAf*yYWmjjqD7W}qgF1k9&P+bw1)3++p+ zZnP%5nQqVOHKJTvD&EOY9f(ceNdoc9bw=4Cf3|rS_XwWBz+tyY?g-<@mZ`!)3gcAZ zztGBpE^+kfnjGK4C$%=Fzb8+M5gM&oX#L^yPD@&sJ7b(y318{iu+%p0;Q<@{q}?glNOm@6QrUgR|vt z&LMl`oRzL$2zb8^(#ama|Hguk6f>l~q`$~cWIR{UN!Gg)HLdxDN5w5Z6+40KSq2_r zhMk9}pzSAbvrhih7+FG31mY~?fdwp|swj?OmTGxP;w{B_O=k9!F&F*YIaSiA9$ab? ztzM>Kx!|#Wd*Og>A|eSBiEey8(UsLIp4bsdSy^c4Jo^c=_?!bAIU}D(X;Sh1AMR9i z!O@s!4!v5+rax#tJdbWWPn6@dwvHoK-KL$#@Y{D{p0-3npr<%y*-0-)sVtGTQ>9zmwKqRgKc$O=yD}7f*LgCDXz!uc4ZR%QtCng8AYTU57_&@;(;XnHLE1B|av4R|h7{{>$Ba2*y&?__oLme6}oKMl>TWnET1k1T_xvEH($9oHpmWen8)fB(rnP zbp;)oz+@*YvY@#;O|ods&}ph@WI6hbC_9~~cc=GpApy955I#%LV_~>4jdhzifyuHH z*#(`M280~jkI?PT)Ax#N&XHum@>RWll1MND?(mM~tT}hW)%TsC_)XeR`hmd%!o2GK9cTF1j4H4o*W9~{T zj6}GgZOG)Nx2MJX1!A zVfR&FRjo=H$$7+xjEPMd(G&|K(U5XVpLW8@-a2y zDciVk7*=Aj2A^A0OA3qi3Q66$LX<1D&?@ruqsyP~=xoBH|8GvN`EIocOv$cEdR#xN}^!Z72rjO%_7k}?m?i_b=Hr;K$~}YniASt6QnL3& z6RfVUD*1LU`slZ3lA4eIu7!++NZo8iJPoAP!XN5|G5-n(!ah&mXK;7jXuJLht>USm z_zF5ljxOW@}YIs^qHTGWV)u*pcJac`*bXkqowROTO3$RXpYewRELO;jE(;tho~i z^(zjN@Ezx>x5khNCtlN$QiY91ksNO0Yq0VMErxxK1>mKIcKkk5tMMv^U%le)tZKI$ zQNvJ#LGQS$-?WNoi60NqbMT?e0qr0?R=Fti9(r3*zk3NuQNjE5eqjpwXeoIL4WZp8 ze!LJ)QX(jqWOMMf68}1zdYmU}sCF5@A=@qB=9y!-wlR}9ZfNc}yU<7CA7Bp7&6Qme zx$&;PXwpupLP3Ol%SOw_%SKCc7p13Q$ArX@uEctfd})?zaze@zhR@|(*_ylczS17) z25Blx315~VF)ObZygTz&kZ#@W4R~4TJV)AmDE(v~bqc z5X&)O?6wNBA9WC7?7DtP^f2%7=114O#F0s9$3b>2_FR=Kq~{AtyNdurx0rg*5R-j3 zqAbMj#iYi4_(qLlHKK84L-K*nmuwxw0kUdfnWd4gX8NcL(^wPR|5MCcM#b@MZ^GR` zWeZovs8xCD0z?(Xgm!JR;`km=;=@4o-_t~)ayrp~GAUE7XS zRrl)J=h;}eth7S)00hRvsvv9=4oN$h~-C*5>I;T+M7<1AftCaeIL`l7356kYYm%3uH* zs4ziwhBnNpf@)p+ePBUVH=dm{O>;n5#KTs`Vkq1pfziHElYi)i&!qU}d{!4>scVG6 z@Oz^bB4GK&c~xk`_e*SDusyzUV~UvvT`ppl)bdv~n~bt=4(sDHWe!}Q7{MDuTHTG7 z2kgdxu~qw50cZ*>ZL+;=Lo3yb$&4I5d-#3O71_2FB*Nd)##y>Y$aIlXjT4CRgfLsr z9%ncOg&>7NkPlXFs2y{+Zue&QW~YY0yJ+3}7l3`5pCb?SGw%S8IIOzc$oesJyP&N? zAAAIrbroSpoB9>{h?Pm3svD|SNl5^A<>1Zw)_Ei;(wYy^2I>)n&YZ4E&m;d*@k%@Z z^f?>OJ#u~fJadD_J-EeP@dD2E{L)?VXq~Uf^*nPH!ucy%NMh65)YEgO$gQ#j!G`?^ zwsG8W^Yfg4__jiO(melgpmmU9cieY-Yo>99T+6p$_f5cWr|3?<^ptxtsoPFB2z!{B zmzPPb;NkQE2BhX2YU>^rpO;d(o47UBw!|4F+OUJBAM3iFG&?DB*9y+J*F0W1Tt_6A zpxM$$S*(}G8`dqpYFrIT*YM;o9&p>&7KCM%t#_BRWGz&@WYHkTW}U2>p_=s)fcc!8 zde|kti(o`oa2ss=`9+G_CcA)}2A=cXe1Wwa%9revN?`uhm1<;9k1+iWN-vE~*h0{D zaW&VVZ7T1fLk0JTn1;pj{)UO8AH=^P&xsPsiG-Hh5iXRTX3|;QT)O0JJ?*kisw`H& zFh>}D4b%)CO8!Qb7MK>LW5Tn{XBHW%0@w9NB}I@Mf1W-|Nc;p}zf^o{JNYH`qkT^d z&IYd>DXkZ|#{6AHQxsQ_Exh_l!j~;Rg7FmeFsEX@LTxeea~g#+nzka^~rYvAhHxsYvKx9|M3Tq^bolQ76ohgfoWXo zIKq4J1CY+v5JN*xEx;Kn+HUyigFJ|$4Fy5(!aOoRXnv-R!SH_26HGrIO2nDc+FLx8 zg)O!AW5O!WgD_C0UT+S_M<(Kp^i(wP$%Jj($4;WgYoAK5`K}uySjUSa2`zawxoq(psYGqcQr5BlFYVf3_IF*sR%5Q(t=uuuv0EMz%CH-zfhF8N=VU$Y zMT~x|Fz@jY;-72l)lFE0Bs0E0(hsu$9=Pg@Hue(n%icn+UcO0U_w3c^1#r_Kt{s-? zP0NEmSuHy>FvZSEo@d3rP9F;%!CD04ouC+mW`Xq8-#R#$q|CqFzXV_$rqbspI?AN+vl&I2wN3L;`fxz6U0PP z1vQ}tV1t8h3iN^RsSg&d>vm`UpRDg51+3P-{sM@q_q*Z=oWw!1c5=CZq>nBlkfGJ| z_PL76@BMPmX0#83Vb7qj=2KT0=|bVq-6`t(OKFE2M056t?fZnZ%_zKPrNJK>elirP zZ)fdh?S7g)q+x|vkZ!~~b~4G9)Dvc(@rbA@meQ!etk}Zu zsV;$ii;jx-^8|VwiSz3^FA-v$PI5mwG+lkI>L+b34QB7PTFMsq3+Vp*jG$M zSlGYL=K24)@-Oh9&VNHMjAO31AN`U<2HceIM5R_Sy!EY%Z1HoPMDiiQF{?GSFPUQ zHv^B5zs=!;+RNtc)@8{3+(xUSIfH7Cm2+eaX!(Tf|$U4dI zGRbol%zrb^GaLJNonAS@fA{>G?m?a!!DH=(uR-AlPlDB;`Fq0|d;bjWI`lt0|K-%> zKZZYoH~Sx)znxA%ta$#jQDl}bWR_Va{=b@l^N{&P#^5oA{?gFRBvVV>F97it=Jk@s z=F8WUyUV}3o8ak|;AicHjPoxBzW|4dbwX7?JRfP0oH5WZa5 zdeY%>Qu%bz@uy}%ly}GfHa33DKa<9ZWxz1We_5O-hDXhi;XFwAuR3`Fxf5RC0n~aB z|HFTV6lFw00CA2bEX)%qMf?Bl8!?<<0Ln9jFB7_#ym#><1^rBD*O4)$?yZ|EQyco&TgK`O%d) zz$ky8fq&`zQz5Vio?H~3$>(N-{bZ8=N*t^Vf*1a$^gjtfv=Ptj_}wQI{^yTv|AV*x z&|f0$$q*JnYP;#l{5_@r4=N0lV7$EBimAtzhd*|hxeLPp=ST|rDegl%XUeksA2?^GUkan2ao(Qd0ooH{f2jijLO6Q4sk-GocV!J846jC`;D zYcHS=ZYN2ejT~aOw6>B$?G)k)V z*=0qohQ*W^S^RCCwm3ODye1hqLeIXq4^#Rx!i<}FvJlqingHb``>IwA4K4#hb@J-` zFa>?GR}uwOugy5J($YxUYq887FSqW z`X--Fl%DzDl$h1*$4)hKCCaSPHr#0;mnbIIL^TU4mG3PEbG9Saj>D1TSEVvjjQ z=^gSfFpDTC%orH&RGiV|_FG6}nU_0B*Jo!5?BzpDir6ARfrVegFk-?WLma-t*7Jhc zDk4y3V>2;HZk%0KA)$GHz(@Wq9WrP2pas9muG{=a1R3gg-YwAIQ`7#5AghR&I0q(N zf?0CAuSk`RYo>+&J%S8|B;Rs+$>Yoc&GtyC zm+YQ?vTi+KYf9u;B)M8Z`s+e|_u!A`o*ozcFUt+zI>Vc?r1}}R&7@$p zYd#PbO3Ckw&&tt-8{;1Jk8IOHx_b4cZS!ze=+)8D}EM!Cj_iPG$2xpk&XdI z&%4nQ!@Ge!>r zvoA`nzW_<=LVG{Z9&jzUp~qv|Bf6(7jxBus(`?}Sev%3UL{gW?=<9FrA7t19#h>90 zA>S;@Dyw(0pL6P$gp}81YmVqIoOD6&>qJLwq-lGmPcB_Taz!My`@RA>>t9(Q&fYR~ z?cId^L|Q5nkr&Tf@M#HbcG#-HXKG+q8IqG`EJ}cv@|wbp&Ye!BN3a~CDDWW{I0VYb zGj@>1&Tux7Vx1!K7Y4DugZ*$`_>FuFW3(ZYn5A_pcD|k&9U7QtDCttQ)|MqW=kN~zR#sv1J8^6 z`UD6?>Y@JFh>Xa$Cy>3OJqt=RtiL;f3gGaYl2k@jq3qL%opAruT^cD`THEgQVQ2_< zdwQWMwtP{_aOH=wIvf&+vY{{=apKztglc^eRP>W{jD>o17%W`Tb1+O#e{h+a%$E27 z*S$8GNkBV+X!KIplm+Lrng4y^X(QFmYlf2V!|RqW30+B;!ORJqtSmOYGwO{{B$I~U zBnLUr%w_k0ZNpX})s$@QaXh%_Y4c`9&f6yNWWvheK~hoU4bTELK;`XJF}e2gWtr<)9kd!Z(k7Ff`%>gKjfq*I6=$fjCeEVuk#2p*qFl zwcTmEK`V@MUrpJ_2O40Frq5Hj2XrWBi`6iIBkCy7?WV?>_^Jyy-5)YzG_W-k2Q-hF zfm@s^d(B03sCam`=#ET2uV=m-PMz635roACDDM{%b>*~dKE`p+9Ci^BJ4>Nt>Q#uS z3D0M$x?`60&$#XAZ7`_oWv|3|WCR!+7}-?#5L^P~lBUBKbA|T_FJq@5&tT=1>)R$K z9%x68dc7d^7!y(Tw&)W)WAR*Hy>3u+K!3>;k+w>oD)g6iArYr<@yw~VP&0MLuFopu z={dq^tB*u)>$anDqVvn~>d#2E9(@dK)Rmj5DAuWt_|P7lae!Rys3zcOV*7h?7tBgV zGtrmE2vC3)mENt@M(aSP>AvI?`{k!ZexKk1MB45rH9V|O#X}$hSRP&qdYsW?*Q@02 zjk{19d4*or5GO2Kp?z%+oYelUBZB-b+IVK-vKB&JP*9P9&b?U{zB>*aCxQk~_q1X* zaIDe3PO)CBJJrMmLA$5`q|&20v5{%X+=~exhuoceLe|jlW<`{zsDD(r7$d^mZPKe4 zV%VTA#*A0IJj{edE!f!NBFoPdCBQv^!l+z-7laz{5$HnnP2#7i6<7|MfsoKWfERH> zu^*w}r!dyk>1~j4SKn==tXr}lYGhnhq`E0#geUDybcePZC4V4#3MMa|)@?Ykh^OA` zs@{3|;=_7$N7serq(|aZpM9_g_6&KH%??`8ASsRCbS-aZkfmtKOo!m)=uTRIorvx8 zK3+C%7mKShDW{ zhc94$!i#)$mE|f%T4HsCv|yc<%ShT=+9n?gqv2A5RKilPI8pX!ri{1~u`asYwG~6F zbRqy$dTSwdl_iFZ(v^p9njywdVE!IVK8e<_sRegTn!m2N)U0l*U&{opuxmwla&n-O zbUNn8lS%LLict|4kLwn_A2~{)&8;Tun9f0tZ|&Y*3{g53F1Kt5_mpJh-xAfV5E=In z!^Ut#z#gchM^nXf14=7Y+tX56<>KBvO}@qhI%xEz^68>)0{0_3q2g%Q8<%Um1DCg1 z6(e_2-VNI`>IJ+b<>smqWhM_JMT#HUsJBeC{4A}}(5gN9jH4vn2ZhK6>!go<)D@k^ zhkJkn-k&Nh`MmCeK2(C-=+kc!7j=AQZOcHd0Ifg1J>Oa7@5bb=CD=W9b!^awB;*-k|Pm>nD2epDX%Oswsacy@}8{QNbuSXE641aw8 zEj0g|(CF_%6u&9|0syU)%wuMkN8R=8joQsD*h}0!3WONiuT!k#2 zD|@6s)t4M)LQQZ&m{7VYhQDO{22b5HvGc@)AlFT2Y_goJF=D8}NOa!sT~{WIEX<(( zy<|g;Iig`1MM+QNd`1-2oee-A$_vjqgvPcQ;E}XZ>k8@3~ka=BEbLTzflMBC1 zz+%1--%HUq|GDRj$Lvf<J-273#D9fb*xg$tm zRaiw*NUo0;-CFhoULkh4ZlG?TGARVeA>o?(5Ie1fvNJ_jMm||v-g=X}5xq>TQ{8Cu;gyd&6Zj{cXIPc)tT9&-wF|V4B{tnsbA{e$;KN zO*1ndh-qn)TZw%4a4u_Dfbssrk||S#Y4Uro!8rrXtGg^@27Y?M;y#k*AQPb5Ua*s3 zJRdP+xt(RGmIj@15}c29WD6n|7j%_l7~Co$XRKa|W6|NrS(p!pG2L5Bcq&d6w|!x5 z2L3)U^e#5RlQGFTZQ@E6J$y9<>ZR{;*IP^L1~i%sD@HP2F^=4+FpLn|^SvJBE9_CMWyq5BtGzt`>~ zRBvHTd~+E%o?=wcMa3dqJFI?%<`fostTRV~=CUgg95<)S5ipB@%&|(V$;k)1)pO9o z-_mP{iR~6X13+D&Zd`-vjtZ^KO*F*54ltUBasVM2;KgkqdJCA<1#68xx;%v?h|!z| zj&w0YkB%C=}9?(42UF`;7=7j`SOkGQ)ob2!>Kk&4BMDLl+3reQB$18 zoDz`rR$R!5i|pp`CZw|pUeU({hZ0d=y>Z?}t*K61{)$+oVWtd7K)uD2(PcYp+_;1y zh=-2m#_GZpY~o~Cmt72o6<49fD%qgXL-R+Zren0Yx7$AVv8r}{M2Qjm>QdA_V_l)F zyT(Ln$8e9a=Fn5vTbx3v*&^Wx#ty2$$BPvfuTlN7JX_EE#+II4GGUl_z+HT*S0@lr zF)M_&>1as!fW1Ccwug`sK-=X*2+xPH!-z&+d!3mC%H?@Gs6XM3*l z`hD4^Geo)g^pgxKxbPp$pOb_rvs0CZSq%#97@rgMMfYg-lCI^SqlG9ZhwH9o-qT*# z`vP)H`AMq2%+RQna8$Vm-;C0hVJK{vq=F*01yNBE5SVFsUAWY#eDH0zcCgj!+7-NS z)xJTAEib(xnJwn^@?=5O4TQ>^3=1{EdJuUHh%v3_uBhI&_(ZG%or%$gi|M2s>fKH{ zngezSPK#WL*XucugO_Yub3jDlfotzUu7E-6_Wm@iU6hVP5UgAOv~7Y~dZ3M*<>Cb7 z#htg<08V1u2Y=t5o1o(+RuR?+F`GdWkuN9{7|4cuxiH^;BxE5|Pm|0DTA}!ERvlW+ zTQx;y?{Q>%)0exI?ucWPD#wG=wF*7jz|~?`&<%H&3C%dm%4GoERwyavL3QxmI~)I$ z&}XyQ=TbgM4I`nWjU?0mK0m0c+$-L_uf4WJZJEc1fCqqE$9RUX$x97gS8cq=mg{?J zE0&kLW*Ejp<6NYd$!$5Z8B@u!_wHnM>lkUj3y!eY5F>WkE1Y7cuG+4=sxD=mB5-{K z$=dJowS{G#ZDAvGOVhGwV}E;anb+W2^{cKC)gI_V% zpR#$L6UXcN>b+^Qsw_vYY#UZzgQq(~MEA%9Bola>+dm72iPu6$PSN_18x1WwJXdy} z-=UJ^9}4AZXrem^-?Lj1wZ7H6MWb!|5Y{(ppyNmJX5fzH7}IPKMq}5Hz~OB|v193cLO!J{a(ZBjG+ax^x0OFZ}{+WV&X=TWIkQAp$+j{(jRciita8JL&N@2vldWtkKu1(A6)n-#vb6%cD%#Wi zUioqWr;;CKh4~nlSM-{rlR4WR?0PI!2RG>$c{EGA4ra{eq5#Mmx5kymIqhn_;2>CY zI3#m65S5KuoHo1Vw?dusdJmEx_6or>wQZQ+@e3K>sX1SFKpU(3=!H8VfKaSLQ!x-rUc6$lknMM5JA~K`o0Yga!}` ziS3a`(MCnkg*kVM+(R9S_QFk&i07~fiIV5>uIzNV=d%woKX37QkNu#u=d=Beh<5Bs zI8W`WznXqMrk`x1hf$7mEYyZ2>XF7>6vHqhp7nQV9j0rkby#%LPz{-v9H{ac*B}_e7PE43c z`?GL&p<#0zOjKgxC-K9z^P<0NsI_Fy{hypPLDI^0{S3$c1ppt3&d$^X(euaQC#`sm zkE}uWdU}gUr4}`JEqb1#wcUkp=-okc&%DU)I6YP2KZyEiXD3s+y#ba$Nac)VkA!bT z!37US8VE~2h5+y=Fj#Npm0?GZEpw2Q_o_=bWW+h)q636YqnYA))ayp<_WMK%KsWcM z{^uN7ta%Uk=Y;Ap)Jl0sscVKqEZi;?~d zav|4=5j7c+QiHw-Y*v!Xp3!;00-EEKlB@4Qi{Q;_w&(d|n+{B$8vo?`Kb&x&=c)gI zH=%gDUk9bbAK!Yv$^`OOZ4qI5&A0vnY@oZfa&LGST?ownNEb9Qew=judC=l_C9K>+ zxT(#hHu!Tty_qtrj;0hu)zQ0epaVA;hI1og`u#G4?rev^TzpYR+cX>1bOA4e_@KTY z+4jp9MK7bFN8BGA_p*HbC7!Wy(xt=PkMpY@!E9gH`i!X}d@w$5kY)>K8)i$sbjTIi zJ_GG-ev%bQ)2G(-u?zvH*btAAFb5{x9y2g%O6JucjY_`Zjv($u#eCK37 zv)wg6;0Gmyw_T!Z8ScU7oMu{lR?^lK=`BD3@O7vANMS65UV!D{Q80qv7*B=g0t6M( zC4K?0l-i_KWZI+LN+etal0q~&>J5lY6-~bthD;QSegd?kCSjdu+$dT_rglnX~JPMWmjRy#(Q6idPg#zxjREY zY0jR?U5dH9#}no}C@a1y_49U@mfLZeOof(=k}H>*OA@9qGz(!5%MfW-z{*`LW5*%d z7*M(rA=hhEr#bLFn%vJ}#Jh78q?H{d{nC&DtR^HB)NefE_C9!`IPZcrUEfnfX@+2F zxEHpgg97ZJr$2U?swG>ef;CyBat|(z@HYeFXOVMC2o_i^z&BKC#`Sah@7YU4&a1}m z1X&hqyLsl^dhKV$EkF*1iD&N_z_*UEiTH*B*HSdWRtY8-ujwtP%}bZ>)$(2lFf4}_U(l)N^4@N zlsER`7;0J!l&kQQE?X+C8MOE|?oMP@Kjsg96A4R{A&(dSDk`%cj!8?u`UE)i>B&1m zDf?10M**BtzgyNZ+X+KWuvinxIv2AKkt?APfOob`f98Q4}|Rxo*OX@$P^$X5FwpDfyyaBZ zlJd%eV~NBpUTJvk!nnP_Oz^#4+94bl6Y&?I85U8#qt8a^0*??tUoo(^yUHnO#!QgG zRjf6HW8t0Cxh;4*;I^^j^14ed-Wsw(GBxOK=eTt?qr3a^j@VCMFyJR1jr)FEbT!WA z8rJ7@VVnrI+i~GDm#gMGl`4c8JdFjjoq-fifHGjG&=?gYzrv`xV3t58oysCiBa0Y1 zf?CQ=E6#%$7-3^>qv$=_iOqs&Gw$|+eMV?i+!nz-5A;GmCpew^6d+`zBZg#F%xDH{P7J~Deg_NaBZbpB8j;@uj+ zgy!kZT};|yNlLm%%xFLCsgunEzr@$@wY0Cl*Y`8ZELI8}BGto)#%c97W**R3a{-^{ zk!{J>I2(~Ork!i{?dOqEH?A#>niW-!s8sDzarl1@``rgZJ)p9S-3vJuf8r zpoe-1?0hM@ncuM?#ZII*NQRX`3j?C-fF|ou>BFnYGQ6mbYK(vyLLpSt-Q;f(taOM@ z>}C{%9bgI48n!I<1tc~cijaU>q*(q3d2|(*bEHGncA;R}w`lk)UK6tm@3ZRaZ!?R^ z&jhS&PB>H7AxAzXoqvf^UAtYf; zs|t|oa>Lj?P7}_&thYOGP312p7H=3PmqlBh}+3NaGly|}D z27gayN~4}vcGAS4=9Rc~5$)lRH9Vkdy)d(i)hGPQm&CA9$geYjKHiN#Oh3!&?X1aH zTSyikmsK@L(K#!Dlwu}4JsYdWqJNgW7O9|^n&MWYIFiEUG+NL5$|VVwuX(Bwm)T(- z@xSL3d5aY615l{4snP=qyy9Vt=+Bl?XNnDE>sPi)Vlc%dp?p>OiM#KhbE1_sM>HJj zaOAuc4NlTUK!d})fp4V zur+((U&n&hXmBy>*<@WYwQ*&Xz1|-&tiv<3VUL35iod;0*9iRShu&LVla-jXz-QQ} z%w+x1?CzDtOezc38GOboPtW@&KN1*PQi|P;tHDtXTkm+B9&eJq_QAH*ZEnavfhMj- z%?}~P%~3e|>u)@Q=YY@~^bXT%xJTL&AEpFMdV~v0f+#%bJ~pwZL095WaTt6dxcq2k zMhSTp6oUosM&pRYk(nkVE?DYn_AT{0z*oo;1?7s(t0q@4v=Pl>Ar-4Ja;AAjcm_wFLN31MZ_-;<-nwq5+m#k~B#03Q=SYXATM literal 0 HcmV?d00001 diff --git a/docs/source/index.rst b/docs/source/index.rst index 20126917..61f0d752 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -22,6 +22,7 @@ Welcome to TensorFlow.NET's documentation! Graph Session Operation + Queue Gradient Train EagerMode diff --git a/src/TensorFlowNET.Core/APIs/tf.queue.cs b/src/TensorFlowNET.Core/APIs/tf.queue.cs index 717f7d2a..a5dfac15 100644 --- a/src/TensorFlowNET.Core/APIs/tf.queue.cs +++ b/src/TensorFlowNET.Core/APIs/tf.queue.cs @@ -43,5 +43,52 @@ namespace Tensorflow names, shared_name: shared_name, name: name); + + public PaddingFIFOQueue PaddingFIFOQueue(int capacity, + TF_DataType dtype, + TensorShape shape, + string shared_name = null, + string name = "padding_fifo_queue") + => new PaddingFIFOQueue(capacity, + new [] { dtype }, + new[] { shape }, + new[] { name }, + shared_name: shared_name, + name: name); + + /// + /// A queue implementation that dequeues elements in first-in first-out order. + /// + /// + /// + /// + /// + /// + /// + /// + public FIFOQueue FIFOQueue(int capacity, + TF_DataType[] dtypes, + TensorShape[] shapes = null, + string[] names = null, + string shared_name = null, + string name = "fifo_queue") + => new FIFOQueue(capacity, + dtypes, + shapes, + names, + shared_name: shared_name, + name: name); + + public FIFOQueue FIFOQueue(int capacity, + TF_DataType dtype, + TensorShape shape = null, + string shared_name = null, + string name = "fifo_queue") + => new FIFOQueue(capacity, + new[] { dtype }, + new[] { shape ?? new TensorShape() }, + new[] { name }, + shared_name: shared_name, + name: name); } } diff --git a/src/TensorFlowNET.Core/Operations/Queues/FIFOQueue.cs b/src/TensorFlowNET.Core/Operations/Queues/FIFOQueue.cs new file mode 100644 index 00000000..fd4aa13f --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/Queues/FIFOQueue.cs @@ -0,0 +1,28 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; + +namespace Tensorflow.Queues +{ + public class FIFOQueue : QueueBase + { + public FIFOQueue(int capacity, + TF_DataType[] dtypes, + TensorShape[] shapes, + string[] names = null, + string shared_name = null, + string name = "fifo_queue") + : base(dtypes: dtypes, shapes: shapes, names: names) + { + _queue_ref = gen_data_flow_ops.fifo_queue_v2( + component_types: dtypes, + shapes: shapes, + capacity: capacity, + shared_name: shared_name, + name: name); + + _name = _queue_ref.op.name.Split('/').Last(); + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/Queues/QueueBase.cs b/src/TensorFlowNET.Core/Operations/Queues/QueueBase.cs index cd8cd0c0..0eb5816d 100644 --- a/src/TensorFlowNET.Core/Operations/Queues/QueueBase.cs +++ b/src/TensorFlowNET.Core/Operations/Queues/QueueBase.cs @@ -33,6 +33,59 @@ namespace Tensorflow.Queues }); } + public Operation enqueue_many(T[] vals, string name = null) + { + return tf_with(ops.name_scope(name, $"{_name}_EnqueueMany", vals), scope => + { + var vals_tensor = _check_enqueue_dtypes(vals); + return gen_data_flow_ops.queue_enqueue_many_v2(_queue_ref, vals_tensor, name: scope); + }); + } + + private Tensor[] _check_enqueue_dtypes(object vals) + { + var tensors = new List(); + + switch (vals) + { + case int[][] vals1: + { + int i = 0; + foreach (var (val, dtype) in zip(vals1, _dtypes)) + tensors.Add(ops.convert_to_tensor(val, dtype: dtype, name: $"component_{i++}")); + } + break; + + case int[] vals1: + tensors.Add(ops.convert_to_tensor(vals1, dtype: _dtypes[0], name: $"component_0")); + break; + + default: + throw new NotImplementedException(""); + } + + return tensors.ToArray(); + } + + /// + /// Dequeues one element from this queue. + /// + /// + /// + public Tensor dequeue(string name = null) + { + Tensor ret; + if (name == null) + name = $"{_name}_Dequeue"; + + if (_queue_ref.dtype == TF_DataType.TF_RESOURCE) + ret = gen_data_flow_ops.queue_dequeue_v2(_queue_ref, _dtypes, name: name)[0]; + else + ret = gen_data_flow_ops.queue_dequeue(_queue_ref, _dtypes, name: name)[0]; + + return ret; + } + public Tensor[] dequeue_many(int n, string name = null) { if (name == null) diff --git a/src/TensorFlowNET.Core/Operations/gen_data_flow_ops.cs b/src/TensorFlowNET.Core/Operations/gen_data_flow_ops.cs index a431380a..4fd394d2 100644 --- a/src/TensorFlowNET.Core/Operations/gen_data_flow_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_data_flow_ops.cs @@ -61,6 +61,22 @@ namespace Tensorflow return _op.output; } + public static Tensor fifo_queue_v2(TF_DataType[] component_types, TensorShape[] shapes, + int capacity = -1, string container = "", string shared_name = "", + string name = null) + { + var _op = _op_def_lib._apply_op_helper("FIFOQueueV2", name, new + { + component_types, + shapes, + capacity, + container, + shared_name + }); + + return _op.output; + } + public static Operation queue_enqueue(Tensor handle, Tensor[] components, int timeout_ms = -1, string name = null) { var _op = _op_def_lib._apply_op_helper("QueueEnqueue", name, new @@ -85,6 +101,42 @@ namespace Tensorflow return _op; } + public static Tensor[] queue_dequeue_v2(Tensor handle, TF_DataType[] component_types, int timeout_ms = -1, string name = null) + { + var _op = _op_def_lib._apply_op_helper("QueueDequeueV2", name, new + { + handle, + component_types, + timeout_ms + }); + + return _op.outputs; + } + + public static Tensor[] queue_dequeue(Tensor handle, TF_DataType[] component_types, int timeout_ms = -1, string name = null) + { + var _op = _op_def_lib._apply_op_helper("QueueDequeue", name, new + { + handle, + component_types, + timeout_ms + }); + + return _op.outputs; + } + + public static Operation queue_enqueue_many_v2(Tensor handle, Tensor[] components, int timeout_ms = -1, string name = null) + { + var _op = _op_def_lib._apply_op_helper("QueueEnqueueManyV2", name, new + { + handle, + components, + timeout_ms + }); + + return _op; + } + public static Tensor[] queue_dequeue_many_v2(Tensor handle, int n, TF_DataType[] component_types, int timeout_ms = -1, string name = null) { var _op = _op_def_lib._apply_op_helper("QueueDequeueManyV2", name, new diff --git a/src/TensorFlowNET.Core/tensorflow.cs b/src/TensorFlowNET.Core/tensorflow.cs index bdb2f537..8fade290 100644 --- a/src/TensorFlowNET.Core/tensorflow.cs +++ b/src/TensorFlowNET.Core/tensorflow.cs @@ -73,17 +73,17 @@ namespace Tensorflow public Session Session() { - return new Session(); + return new Session().as_default(); } public Session Session(Graph graph, SessionOptions opts = null) { - return new Session(graph, opts: opts); + return new Session(graph, opts: opts).as_default(); } public Session Session(SessionOptions opts) { - return new Session(null, opts); + return new Session(null, opts).as_default(); } public void __init__() diff --git a/test/TensorFlowNET.UnitTest/QueueTest.cs b/test/TensorFlowNET.UnitTest/QueueTest.cs index 451ded4a..14afbae5 100644 --- a/test/TensorFlowNET.UnitTest/QueueTest.cs +++ b/test/TensorFlowNET.UnitTest/QueueTest.cs @@ -15,7 +15,7 @@ namespace TensorFlowNET.UnitTest public void PaddingFIFOQueue() { var numbers = tf.placeholder(tf.int32); - var queue = tf.PaddingFIFOQueue(capacity: 10, dtypes: new[] { tf.int32 }, shapes: new[] { new TensorShape(-1) }); + var queue = tf.PaddingFIFOQueue(10, tf.int32, new TensorShape(-1)); var enqueue = queue.enqueue(numbers); var dequeue_many = queue.dequeue_many(n: 3); @@ -32,5 +32,43 @@ namespace TensorFlowNET.UnitTest Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 3, 4, 5 }, result[2].ToArray())); } } + + [TestMethod] + public void FIFOQueue() + { + // create a first in first out queue with capacity up to 2 + // and data type set as int32 + var queue = tf.FIFOQueue(2, tf.int32); + // init queue, push 3 elements into queue. + var init = queue.enqueue_many(new[] { 10, 20 }); + // pop out the first element + var x = queue.dequeue(); + // add 1 + var y = x + 1; + // push back into queue + var inc = queue.enqueue(y); + + using (var sess = tf.Session()) + { + // init queue + init.run(); + + // pop out first element and push back calculated y + (int dequeued, _) = sess.run((x, inc)); + Assert.AreEqual(10, dequeued); + + (dequeued, _) = sess.run((x, inc)); + Assert.AreEqual(20, dequeued); + + (dequeued, _) = sess.run((x, inc)); + Assert.AreEqual(11, dequeued); + + (dequeued, _) = sess.run((x, inc)); + Assert.AreEqual(21, dequeued); + + // thread will hang or block if you run sess.run(x) again + // until queue has more element. + } + } } }