/*****************************************************************************
   Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
******************************************************************************/

using NumSharp;
using System;
using System.Collections.Generic;
using static Tensorflow.Binding;

namespace Tensorflow.Keras
{
    /// <summary>
    /// Keras backend helpers: per-graph layer-name counters, learning-phase
    /// bookkeeping, session reset, and a few tensor utilities.
    /// </summary>
    public class BackendImpl : BackendBase
    {
        /* ----------------------------------------  KERAS BACKEND NATIVE OBJECTS  ---------------------------------------- */

        /// <summary>Delegate to <c>Tensorflow.Binding.sum</c> over an <see cref="Array"/>.</summary>
        public Func<Array, double> py_sum = sum;

        /// <summary>Delegate to <c>Tensorflow.Binding.all</c> over an <see cref="Array"/>.</summary>
        public Func<Array, bool> py_all = all;

        // Not yet ported from the Python backend: py_any (any) and py_slice (slice).

        /// <summary>The current default session (<c>ops.get_default_session()</c>).</summary>
        public Session _SESSION => ops.get_default_session();

        public Graph _GRAPH;

        /// <summary>
        /// Learning phase recorded per graph: 1 after <c>set_learning_phase(true)</c>,
        /// 0 after <c>set_learning_phase(false)</c> or when first created.
        /// Initialized eagerly so that <c>learning_phase()</c> and <c>set_learning_phase()</c>
        /// do not throw <see cref="NullReferenceException"/> before <c>clear_session()</c> has run.
        /// </summary>
        public Dictionary<Graph, GraphLearningPhase> _GRAPH_LEARNING_PHASES = new Dictionary<Graph, GraphLearningPhase>();

        /// <summary>Flag set by <c>manual_variable_initialization</c>; not read by this class.</summary>
        public bool _MANUAL_VAR_INIT = false;

        public List<string> _LOCAL_DEVICES = null;

        /* --------------------------------------  KERAS BACKEND NATIVE OBJECTS END  -------------------------------------- */

        /// <summary>
        /// A global dictionary mapping graph objects to per-prefix counters for layer names
        /// in each graph. Used by <c>get_uid</c> to give layers unique, auto-generated names
        /// on a per-graph basis.
        /// </summary>
        public Dictionary<Graph, Dictionary<string, int>> PER_GRAPH_LAYER_NAME_UIDS = new Dictionary<Graph, Dictionary<string, int>>();

        /// <summary>Variables recorded by <c>track_variable</c>, keyed by <c>Graph.graph_key</c>.</summary>
        public Dictionary<string, IVariableV1> _GRAPH_VARIABLES = new Dictionary<string, IVariableV1>();

        public Dictionary<string, Optimizer> _GRAPH_TF_OPTIMIZERS = new Dictionary<string, Optimizer>();

        public _DummyEagerGraph _DUMMY_EAGER_GRAPH = new _DummyEagerGraph();

        public BackendImpl()
        {
        }

        /// <summary>
        /// Records <paramref name="v"/> in <see cref="_GRAPH_VARIABLES"/> under its graph's key.
        /// </summary>
        /// <param name="v">The variable to track; its <c>Graph</c> supplies the key.</param>
        /// <remarks>
        /// Only one variable is stored per graph: tracking a second variable from the same
        /// graph replaces the first. NOTE(review): confirm this is intended.
        /// </remarks>
        public void track_variable(IVariableV1 v)
        {
            var graph = v.Graph;
            _GRAPH_VARIABLES[graph.graph_key] = v;
        }

        /// <summary>
        /// Creates a dense placeholder tensor via <c>array_ops.placeholder</c>.
        /// </summary>
        /// <param name="shape">Static shape of the placeholder, or null.</param>
        /// <param name="ndim">Currently ignored.</param>
        /// <param name="dtype">Element type of the placeholder.</param>
        /// <param name="sparse">Sparse placeholders are not supported.</param>
        /// <param name="name">Optional op name.</param>
        /// <param name="ragged">Currently ignored.</param>
        /// <returns>The placeholder tensor.</returns>
        /// <exception cref="NotImplementedException"><paramref name="sparse"/> is true.</exception>
        public Tensor placeholder(TensorShape shape = null,
            int ndim = -1,
            TF_DataType dtype = TF_DataType.DtInvalid,
            bool sparse = false,
            string name = null,
            bool ragged = false)
        {
            if (sparse)
                throw new NotImplementedException("placeholder sparse is true");

            return array_ops.placeholder(dtype: dtype, shape: shape, name: name);
        }

        /// <summary>Returns the current default graph.</summary>
        public Graph get_graph()
        {
            return ops.get_default_graph();
        }

        /// <summary>
        /// Returns the next counter value for <paramref name="prefix"/> in the current default graph.
        /// The first call for a given prefix in a graph returns 1, and each later call increments it.
        /// </summary>
        /// <param name="prefix">Name prefix whose counter is advanced.</param>
        /// <returns>The incremented counter value.</returns>
        public int get_uid(string prefix)
        {
            var graph = tf.get_default_graph();
            if (!PER_GRAPH_LAYER_NAME_UIDS.TryGetValue(graph, out var uids))
            {
                uids = new defaultdict<string, int>();
                PER_GRAPH_LAYER_NAME_UIDS.Add(graph, uids);
            }

            // A missing prefix leaves count at default(int) == 0.
            uids.TryGetValue(prefix, out var count);
            count += 1;
            uids[prefix] = count;
            return count;
        }

        /// <summary>Discards all per-graph layer-name counters.</summary>
        public void reset_uids() => PER_GRAPH_LAYER_NAME_UIDS = new Dictionary<Graph, Dictionary<string, int>>();

        /// <summary>
        /// Resets the default graph, the layer-name counters and the default session.
        /// It then creates a "keras_learning_phase" placeholder in the new default graph and
        /// records learning phase 0 for that graph.
        /// </summary>
        public void clear_session()
        {
            ops.reset_default_graph();
            reset_uids();
            ops.set_default_session(tf.Session(ops.get_default_graph()));
            // The placeholder is created for its effect on the new graph; the tensor itself is not kept.
            tf.placeholder_with_default(false, new int[] { }, name: "keras_learning_phase");
            _GRAPH_LEARNING_PHASES = new Dictionary<Graph, GraphLearningPhase>();
            _GRAPH_LEARNING_PHASES[tf.get_default_graph()] = 0;
        }

        /// <summary>Stores <paramref name="value"/> in <see cref="_MANUAL_VAR_INIT"/>.</summary>
        public void manual_variable_initialization(bool value)
        {
            _MANUAL_VAR_INIT = value;
        }

        /// <summary>
        /// Returns the learning phase recorded for the current default graph.
        /// The first time a graph is queried, a "keras_learning_phase" placeholder is created in it
        /// and phase 0 is recorded. Later calls return the stored value unchanged, including one
        /// set by <c>set_learning_phase</c>.
        /// </summary>
        public GraphLearningPhase learning_phase()
        {
            var graph = tf.get_default_graph();
            // Only initialize graphs we have not seen yet. The previous check was inverted:
            // it threw KeyNotFoundException for new graphs and reset known graphs to 0.
            if (!_GRAPH_LEARNING_PHASES.TryGetValue(graph, out var phase))
            {
                tf.placeholder_with_default(false, shape: new int[] { }, name: "keras_learning_phase");
                phase = 0;
                _GRAPH_LEARNING_PHASES[graph] = phase;
            }
            return phase;
        }

        /// <summary>
        /// Records the learning phase for the current default graph: 1 when
        /// <paramref name="value"/> is true, 0 otherwise.
        /// </summary>
        public void set_learning_phase(bool value)
        {
            _GRAPH_LEARNING_PHASES[tf.get_default_graph()] = (GraphLearningPhase)(value ? 1 : 0);
        }

        /// <summary>
        /// Pads the 2nd and 3rd dimensions of a 4D tensor.
        /// </summary>
        /// <param name="x">Tensor to pad.</param>
        /// <param name="padding">
        /// 2x2 array {{before, after} for the first padded dimension, {before, after} for the second};
        /// defaults to {{1, 1}, {1, 1}}.
        /// </param>
        /// <param name="data_format">
        /// "channels_first" pads dimensions 2 and 3; any other value (including null) pads dimensions 1 and 2.
        /// </param>
        /// <returns>The padded tensor from <c>array_ops.pad</c>.</returns>
        public Tensor spatial_2d_padding(Tensor x, NDArray padding = null, string data_format = null)
        {
            if (padding == null)
                padding = new[,] { { 1, 1 }, { 1, 1 } };

            NDArray pattern;
            if (data_format == "channels_first")
                pattern = new int[,] { { 0, 0 }, { 0, 0 }, { padding[0][0], padding[0][1] }, { padding[1][0], padding[1][1] } };
            else
                pattern = new int[,] { { 0, 0 }, { padding[0][0], padding[0][1] }, { padding[1][0], padding[1][1] }, { 0, 0 } };

            return array_ops.pad(x, pattern);
        }

        /// <summary>
        /// Evaluates <paramref name="outputs"/>. Currently this always delegates to
        /// <c>outputs.eval()</c>, regardless of execution mode.
        /// </summary>
        /// <param name="outputs">Tensor to evaluate.</param>
        /// <returns>The evaluated value.</returns>
        public NDArray eval_in_eager_or_function(Tensor outputs)
        {
            return outputs.eval();
        }

        /// <summary>
        /// Empty marker type; a single instance is held in <see cref="_DUMMY_EAGER_GRAPH"/>.
        /// NOTE(review): presumably mirrors the Python backend's eager-mode sentinel graph; confirm.
        /// </summary>
        public class _DummyEagerGraph
        { }

        /// <summary>
        /// Categorical crossentropy between an output tensor and a target tensor.
        /// </summary>
        /// <param name="target">Target tensor, passed as <c>labels</c>.</param>
        /// <param name="output">Output tensor, passed as <c>logits</c>; must be logits.</param>
        /// <param name="from_logits">Must be true; the probability-input path is not implemented.</param>
        /// <param name="axis">Class axis, forwarded to the softmax cross-entropy op.</param>
        /// <returns>The result of <c>tf.nn.softmax_cross_entropy_with_logits_v2</c>.</returns>
        /// <exception cref="NotImplementedException"><paramref name="from_logits"/> is false.</exception>
        public Tensor categorical_crossentropy(Tensor target, Tensor output, bool from_logits = false, int axis = -1)
        {
            if (!from_logits)
                throw new NotImplementedException("categorical_crossentropy is only implemented for from_logits: true");

            return tf.nn.softmax_cross_entropy_with_logits_v2(labels: target, logits: output, axis: axis);
        }

        /// <summary>
        /// Resizes the images contained in a 4D tensor by integer factors.
        /// </summary>
        /// <param name="x">Tensor to resize.</param>
        /// <param name="height_factor">Positive integer factor for the rows dimension.</param>
        /// <param name="width_factor">Positive integer factor for the columns dimension.</param>
        /// <param name="data_format">"channels_last" (rows = dim 1, cols = dim 2); "channels_first" is not supported yet.</param>
        /// <param name="interpolation">"nearest" (default) or "bilinear".</param>
        /// <returns>The resized tensor, with its static shape set where known.</returns>
        /// <exception cref="ValueError"><paramref name="data_format"/> or <paramref name="interpolation"/> is invalid.</exception>
        /// <exception cref="NotImplementedException"><paramref name="data_format"/> is "channels_first".</exception>
        public Tensor resize_images(Tensor x, int height_factor, int width_factor, string data_format, string interpolation = "nearest")
        {
            var (rows, cols) = (0, 0);
            if (data_format == "channels_first")
                (rows, cols) = (2, 3);
            else if (data_format == "channels_last")
                (rows, cols) = (1, 2);
            else
                throw new ValueError($"Invalid `data_format` argument: {data_format}");

            // channels_first would need permute_dimensions(x, [0, 2, 3, 1]) before the resize and
            // permute_dimensions(x, [0, 3, 1, 2]) after it. Reject it before adding any ops to the graph.
            if (data_format == "channels_first")
                throw new NotImplementedException("resize_images does not support data_format \"channels_first\" yet");

            // Previously an unknown interpolation silently skipped the resize but still set the scaled shape.
            if (interpolation != "nearest" && interpolation != "bilinear")
                throw new ValueError($"interpolation should be one of \"nearest\" or \"bilinear\", got: {interpolation}");

            var original_shape = x.shape;
            // Dynamic (rows, cols) slice of the input shape, scaled element-wise by the factors.
            var new_shape = array_ops.shape(x)[new Slice(rows, cols + 1)];
            new_shape *= constant_op.constant(np.array(height_factor, width_factor));

            var method = interpolation == "bilinear" ? ResizeMethod.BILINEAR : ResizeMethod.NEAREST_NEIGHBOR;
            x = tf.image.resize_images_v2(x, new_shape, method: method);

            // Propagate the static spatial size when it is known; -1 marks an unknown dimension.
            int new_height = original_shape[rows] < 0 ? -1 : original_shape[rows] * height_factor;
            int new_width = original_shape[cols] < 0 ? -1 : original_shape[cols] * width_factor;

            // Only channels_last reaches this point.
            TensorShape output_shape = (-1, new_height, new_width, -1);
            x.set_shape(output_shape);
            return x;
        }
    }
}