Browse Source

Merge branch 'master' of github.com:AsakusaRinne/TensorFlow.NET into support_function_load

tags/v0.100.5-BERT-load
Yaohui Liu 2 years ago
parent
commit
1903700e04
No known key found for this signature in database. GPG Key ID: E86D01E1809BD23E
19 changed files with 537 additions and 60 deletions
  1. +9
    -4
      src/TensorFlowNET.Core/Checkpoint/SaveUtilV1.cs
  2. +13
    -7
      src/TensorFlowNET.Core/Checkpoint/checkpoint.cs
  3. +40
    -22
      src/TensorFlowNET.Core/Checkpoint/functional_saver.cs
  4. +73
    -0
      src/TensorFlowNET.Core/Contexts/Context.Device.cs
  5. +5
    -1
      src/TensorFlowNET.Core/Contexts/Context.cs
  6. +71
    -0
      src/TensorFlowNET.Core/Contexts/EagerDeviceContext.cs
  7. +205
    -0
      src/TensorFlowNET.Core/Device/DeviceSpec.cs
  8. +26
    -0
      src/TensorFlowNET.Core/Device/DeviceUtils.cs
  9. +9
    -2
      src/TensorFlowNET.Core/Graphs/Graph.cs
  10. +31
    -0
      src/TensorFlowNET.Core/Graphs/GraphDeviceContext.cs
  11. +1
    -1
      src/TensorFlowNET.Core/Keras/Layers/ILayer.cs
  12. +2
    -0
      src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs
  13. +14
    -10
      src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.cs
  14. +4
    -2
      src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs
  15. +8
    -5
      src/TensorFlowNET.Core/Variables/UninitializedVariable.cs
  16. +17
    -0
      src/TensorFlowNET.Core/ops.cs
  17. +2
    -4
      src/TensorFlowNET.Keras/Callbacks/Earlystopping.cs
  18. +6
    -2
      src/TensorFlowNET.Keras/Losses/SparseCategoricalCrossentropy.cs
  19. +1
    -0
      test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs

+ 9
- 4
src/TensorFlowNET.Core/Checkpoint/SaveUtilV1.cs View File

@@ -54,8 +54,11 @@ public static class SaveUtilV1
var g = to_graph.as_default();
var (named_saveable_objects, graph_proto, _, registered_savers) = serialize_gathered_objects(graph_view,
object_map, call_with_mapped_captures, saveables_cache);
tf.device("/cpu:0");
var object_graph_tensor = constant_op.constant(graph_proto.ToByteArray());
var object_graph_tensor = tf_with(ops.device("/cpu:0"), _ =>
{
// TODO(Rinne): locate the error that causes transferring TF_STRING to this function throws an exception.
return constant_op.constant(graph_proto.ToByteArray());
});
named_saveable_objects.Add(new NoRestoreSaveable(object_graph_tensor, Trackable.Constants.OBJECT_GRAPH_PROTO_KEY));
g.Exit();
return (named_saveable_objects, registered_savers);
@@ -66,8 +69,10 @@ public static class SaveUtilV1
{
var (named_saveable_objects, graph_proto, _, registered_savers) = serialize_gathered_objects(graph_view,
object_map, call_with_mapped_captures, saveables_cache);
tf.device("/cpu:0");
var object_graph_tensor = constant_op.constant(graph_proto.ToString(), TF_DataType.TF_STRING);
var object_graph_tensor = tf_with(ops.device("/cpu:0"), _ =>
{
return constant_op.constant(graph_proto.ToString());
});
named_saveable_objects.Add(new NoRestoreSaveable(object_graph_tensor, Trackable.Constants.OBJECT_GRAPH_PROTO_KEY));
return (named_saveable_objects, registered_savers);
}


+ 13
- 7
src/TensorFlowNET.Core/Checkpoint/checkpoint.cs View File

@@ -59,8 +59,10 @@ public class TrackableSaver

if(object_graph_tensor is null)
{
tf.device("/cpu:0");
object_graph_tensor = constant_op.constant(graph_proto.ToByteArray());
tf_with(ops.device("/cpu:0"), _ =>
{
object_graph_tensor = constant_op.constant(graph_proto.ToByteArray());
});
}
else
{
@@ -232,13 +234,15 @@ public class TrackableSaver
Tensor object_graph_string = reader.GetTensor(Trackable.Constants.OBJECT_GRAPH_PROTO_KEY, dtype: TF_DataType.TF_STRING);

Dictionary<Tensor, string> file_prefix_feed_dict;
Tensor file_prefix_tensor;
Tensor file_prefix_tensor = null;
if (graph_building)
{
if(_file_prefix_placeholder is null)
{
tf.device("/cpu:0");
_file_prefix_placeholder = constant_op.constant("model");
_file_prefix_placeholder = tf_with(ops.device("/cpu:0"), _ =>
{
return constant_op.constant("model");
});
}
file_prefix_tensor = _file_prefix_placeholder;
file_prefix_feed_dict = new();
@@ -246,8 +250,10 @@ public class TrackableSaver
}
else
{
tf.device("/cpu:0");
file_prefix_tensor = constant_op.constant(save_path);
file_prefix_tensor = tf_with(ops.device("/cpu:0"), _ =>
{
return constant_op.constant(save_path);
});
file_prefix_feed_dict = null;
}
TrackableObjectGraph object_graph_proto = new();


+ 40
- 22
src/TensorFlowNET.Core/Checkpoint/functional_saver.cs View File

@@ -117,9 +117,11 @@ namespace Tensorflow.Checkpoint

string restore_device = string.IsNullOrEmpty(options.experimental_io_device) ? "cpu:0": options.experimental_io_device!;

// tf python has code `with ops.device(restore_device):` here.
tf.device(restore_device); // may be risky.
var restored_tensors = gen_ops.restore_v2(file_prefix, tensor_names.ToArray(), slice_specs.ToArray(), tensor_dtypes.ToArray());
Tensor[] restored_tensors = null;
tf_with(ops.device(restore_device), _ =>
{
restored_tensors = gen_ops.restore_v2(file_prefix, tensor_names.ToArray(), slice_specs.ToArray(), tensor_dtypes.ToArray());
});

Dictionary<string, IDictionary<string, Tensor>> restored_tensor_dict = new();
int idx = 0;
@@ -243,11 +245,14 @@ namespace Tensorflow.Checkpoint
options = new CheckpointOptions();
}

tf.device("CPU"); // may be risky.
var sharded_suffix = array_ops.where(gen_ops.regex_full_match(file_prefix, tf.constant(@"^s3://.*")),
Tensor tmp_checkpoint_prefix = null;
tf_with(ops.device("CPU"), _ =>
{
var sharded_suffix = array_ops.where(gen_ops.regex_full_match(file_prefix, tf.constant(@"^s3://.*")),
constant_op.constant(".part"), constant_op.constant("_temp/part"));
var tmp_checkpoint_prefix = gen_ops.string_join(new Tensor[] { file_prefix, sharded_suffix });
IDictionary<string, Tensor> registered_paths = _registered_savers.Keys.ToDictionary(x => x, x => registered_saver_filename(file_prefix, x));
tmp_checkpoint_prefix = gen_ops.string_join(new Tensor[] { file_prefix, sharded_suffix });
IDictionary<string, Tensor> registered_paths = _registered_savers.Keys.ToDictionary(x => x, x => registered_saver_filename(file_prefix, x));
});

Operation save_fn()
{
@@ -269,16 +274,24 @@ namespace Tensorflow.Checkpoint
var saver = pair.Value;
last_device = device;
// skip the extra process of device name because of lack of API.
tf.device(device);
var shard_prefix = sharded_filename(tmp_checkpoint_prefix, shard, num_shards_tensor);
Tensor shard_prefix = null;
tf_with(ops.device(device), _ =>
{
shard_prefix = sharded_filename(tmp_checkpoint_prefix, shard, num_shards_tensor);
});
saved_prefixes.Add(shard_prefix);
sharded_saves.Add(saver.save(shard_prefix, options));
tf_with(ops.device(device), _ =>
{
sharded_saves.Add(saver.save(shard_prefix, options));
});
}
using (var controller = ops.control_dependencies(sharded_saves.ToArray()))
{
string merge_device = string.IsNullOrEmpty(options.experimental_io_device) ? last_device : options.experimental_io_device;
tf.device(merge_device);
return gen_ops.merge_v2_checkpoints(saved_prefixes.ToArray(), tf.constant(file_prefix), delete_old_dirs: true);
return tf_with(ops.device(merge_device), _ =>
{
return gen_ops.merge_v2_checkpoints(saved_prefixes.ToArray(), tf.constant(file_prefix), delete_old_dirs: true);
});
}
}

@@ -312,8 +325,9 @@ namespace Tensorflow.Checkpoint
{
var device = single_saver.Key;
var saver = single_saver.Value;
tf.device(device);
var restored_tensor_dict = saver.restore(file_prefix, options);
tf_with(ops.device(device), _ =>
{
var restored_tensor_dict = saver.restore(file_prefix, options);

foreach(var pair in restored_tensor_dict)
{
@@ -405,21 +419,25 @@ namespace Tensorflow.Checkpoint
private Tensor _traced_save(Tensor file_prefix)
{
var save_op = save(file_prefix);
tf.device("cpu:0");
using (ops.control_dependencies(new object[]{ save_op }))
return tf_with(ops.device("cpu:0"), _ =>
{
return array_ops.identity(file_prefix);
}
return tf_with(ops.control_dependencies(new object[] { save_op }), __ =>
{
return array_ops.identity(file_prefix);
});
});
}

private Tensor _traced_restore(Tensor file_prefix)
{
var restore_op = restore(file_prefix);
tf.device("cpu:0");
using (ops.control_dependencies(restore_op.Values.ToArray()))
return tf_with(ops.device("cpu:0"), _ =>
{
return array_ops.identity(file_prefix);
}
return tf_with(ops.control_dependencies(restore_op.Values.ToArray()), __ =>
{
return array_ops.identity(file_prefix);
});
});
}

public static MultiDeviceSaver from_saveables(IEnumerable<MySaveableObject> saveables, IDictionary<string, IDictionary<string, Trackable>>? registered_savers = null, bool call_with_mapped_captures = false)


+ 73
- 0
src/TensorFlowNET.Core/Contexts/Context.Device.cs View File

@@ -21,6 +21,7 @@ using Tensorflow.Eager;
using static Tensorflow.Binding;
using Google.Protobuf;
using Tensorflow.Device;
using Tensorflow.Exceptions;
using System.Collections.Generic;

namespace Tensorflow.Contexts
@@ -30,10 +31,30 @@ namespace Tensorflow.Contexts
/// </summary>
public sealed partial class Context
{
internal static Dictionary<(string, string), (string, DeviceSpec)> _device_parsing_cache = new();
internal List<LogicalDevice> _logical_devices = null;
internal List<string> _context_devices = null;
ContextDevicePlacementPolicy _device_policy;
bool _log_device_placement;
int _num_gpus;
Dictionary<PhysicalDevice, bool> _memory_growth_map = new Dictionary<PhysicalDevice, bool>();

public string DeviceName { get; set; } = "";
public DeviceSpec DeviceSpec { get; set; } = null;

/// <summary>
/// Canonical names of all devices known to this context. The list is populated
/// by <c>_initialize_logical_devices</c> during context initialization, so
/// reading it before the context is initialized is a programming error.
/// </summary>
/// <exception cref="AssertionError">Thrown when the context has not been initialized yet.</exception>
internal List<string> Devices
{
    get
    {
        if(_context_devices is null)
        {
            throw new AssertionError("Context must be initialized first.");
        }
        return _context_devices;
    }
}

public void log_device_placement(bool enable)
{
if (_handle != null)
@@ -89,5 +110,57 @@ namespace Tensorflow.Contexts

return results.ToArray();
}

/// <summary>
/// Returns a context manager that places eager operations executed inside it on
/// the device described by <paramref name="name"/> (e.g. "/cpu:0"), restoring
/// the previous device when the scope exits.
/// </summary>
public EagerDeviceContext device(string name)
    => new EagerDeviceContext(this, name);

/// <summary>
/// Records the device that subsequent eager operations should be placed on.
/// Called by <c>EagerDeviceContext</c> when entering and exiting a device scope;
/// both values are always updated together.
/// </summary>
internal void _set_device(string device_name, DeviceSpec device_spec)
{
    DeviceSpec = device_spec;
    DeviceName = device_name;
}

/// <summary>
/// Queries the native eager runtime for its device list, fills
/// <c>_context_devices</c> (canonical names) and <c>_logical_devices</c>, and
/// counts local GPUs into <c>_num_gpus</c>. Mirrors the Python
/// Context._initialize_logical_devices.
/// </summary>
internal void _initialize_logical_devices()
{
    List<LogicalDevice> logical_devices = new();
    List<string> context_devices = new();
    Status status = new();
    var device_list = c_api.TFE_ContextListDevices(_handle, status);
    status.Check(true);
    // NOTE(review): device_list is never released here; the Python counterpart
    // deletes it in a finally block. Confirm whether a TF_DeleteDeviceList
    // binding exists and should be called — TODO.
    try
    {
        this._num_gpus = 0;
        // With no server def, local devices have no job/task, so a GPU is local
        // when both sides of the comparison below are unset.
        string current_job = null;
        int current_task = -1;
        for(int i = 0; i < c_api.TF_DeviceListCount(device_list); i++)
        {
            var dev_name = c_api.TF_DeviceListName(device_list, i, status);
            status.Check(true);
            context_devices.Add(DeviceUtils.canonical_name(dev_name));
            var spec = DeviceSpec.from_string(dev_name);
            if(spec.Job == "localhost")
            {
                // BUG FIX: `spec.replace(job: null, replica: -1, task: -1)` was a
                // no-op — DeviceSpec.replace treats null/-1 as "keep the current
                // value", so the localhost job was never stripped and the GPU
                // comparison below (against a null job) could never match.
                // Build the stripped spec explicitly instead.
                spec = new DeviceSpec(device_type: spec.DeviceType, device_index: spec.DeviceIndex);
            }
            logical_devices.Add(new LogicalDevice(spec.ToString(), spec.DeviceType));
            var dev_type_memory = c_api.TF_DeviceListType(device_list, i, status);
            var dev_type = c_api.StringPiece(dev_type_memory);
            status.Check(true);
            if(dev_type == "GPU" && spec.Job == current_job && spec.Task == current_task)
            {
                _num_gpus++;
            }
        }
    }
    finally
    {
        // Publish even on failure so the context never observes half-built state.
        _logical_devices = logical_devices;
        _context_devices = context_devices;
    }
}
}

public record class LogicalDevice(string name, string device_type);
}

+ 5
- 1
src/TensorFlowNET.Core/Contexts/Context.cs View File

@@ -34,7 +34,6 @@ namespace Tensorflow.Contexts
public const int EAGER_MODE = 1;

int defaultExecutionMode = EAGER_MODE;
public string DeviceName { get; set; } = "";
public string ScopeName { get; set; } = "";
bool initialized = false;
ContextSwitchStack context_switches;
@@ -81,6 +80,9 @@ namespace Tensorflow.Contexts
if (initialized)
return;

Debug.Assert(_context_devices is null);

Config = MergeConfig();
FunctionCallOptions.Config = Config;
var config_str = Config.ToByteArray();
var opts = new ContextOptions();
@@ -90,6 +92,7 @@ namespace Tensorflow.Contexts
c_api.TFE_ContextOptionsSetDevicePlacementPolicy(opts, _device_policy);
_handle = c_api.TFE_NewContext(opts, status);
status.Check(true);
_initialize_logical_devices();
initialized = true;
}

@@ -228,6 +231,7 @@ namespace Tensorflow.Contexts
{
c_api.TFE_ContextClearCaches(_handle);
}
_device_parsing_cache.Clear();
}

public static implicit operator SafeContextHandle(Context ctx)


+ 71
- 0
src/TensorFlowNET.Core/Contexts/EagerDeviceContext.cs View File

@@ -0,0 +1,71 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Device;

namespace Tensorflow.Contexts
{
    /// <summary>
    /// Context manager that temporarily switches the eager context onto a given
    /// device and restores the previous device on exit. Resolution results are
    /// memoized in <c>Context._device_parsing_cache</c> keyed by
    /// (old device name, requested device name).
    /// </summary>
    public class EagerDeviceContext : ITensorFlowObject
    {
        private readonly Context _ctx;
        private readonly string _device_name;
        // (old name, old spec, new spec) pushed per __enter__ so nested/re-entrant
        // uses of the same instance unwind correctly.
        private readonly Stack<(string, DeviceSpec, DeviceSpec)> _stack;

        public EagerDeviceContext(Context ctx, string device_name)
        {
            _ctx = ctx;
            _device_name = device_name;
            _stack = new Stack<(string, DeviceSpec, DeviceSpec)>();
        }

        public void __enter__()
        {
            var ctx = _ctx;
            var old_device_name = ctx.DeviceName;
            var old_device_spec = ctx.DeviceSpec;
            var new_device_name = _device_name;
            var cache_key = (old_device_name, new_device_name);
            DeviceSpec new_device_spec;
            // Single TryGetValue instead of ContainsKey + indexer: avoids a
            // redundant second hash lookup on the cache hit path.
            if (Context._device_parsing_cache.TryGetValue(cache_key, out var cached))
            {
                (new_device_name, new_device_spec) = cached;
            }
            else
            {
                if (new_device_name is not null)
                {
                    var device_spec = DeviceSpec.from_string(new_device_name);
                    if (!string.IsNullOrEmpty(old_device_name))
                    {
                        // Merge the request onto whatever device is currently active.
                        new_device_spec = new DeviceSpec(old_device_spec);
                    }
                    else
                    {
                        // No device set yet: start from the context's first device.
                        ctx.ensure_initialized();
                        new_device_spec = DeviceSpec.from_string(ctx._context_devices[0]);
                    }
                    new_device_spec = new_device_spec.make_merged_spec(device_spec);
                }
                else
                {
                    // A null request resets to the default (first) device.
                    // NOTE(review): this branch reads _context_devices without
                    // calling ensure_initialized() first — confirm callers always
                    // hit this after initialization.
                    new_device_spec = DeviceSpec.from_string(ctx._context_devices[0]);
                }
                new_device_name = new_device_spec.ToString();
                Context._device_parsing_cache[cache_key] = (new_device_name, new_device_spec);
            }
            ctx._set_device(new_device_name, new_device_spec);
            _stack.Push((old_device_name, old_device_spec, new_device_spec));
        }

        public void __exit__()
        {
            // Restore the device that was active when the matching __enter__ ran.
            var ctx = _ctx;
            var (old_device_name, old_device_spec, new_device_spec) = _stack.Pop();
            ctx._set_device(old_device_name, old_device_spec);
        }

        public void Dispose()
        {
            // Intentionally empty: tf_with drives __enter__/__exit__ explicitly,
            // and no unmanaged state is held — presumably by design; confirm.
        }
    }
}

+ 205
- 0
src/TensorFlowNET.Core/Device/DeviceSpec.cs View File

@@ -0,0 +1,205 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;

namespace Tensorflow.Device
{
    /// <summary>
    /// A (possibly partial) specification of a TensorFlow device, e.g.
    /// "/job:worker/replica:0/task:1/device:GPU:0". Unspecified string components
    /// are null and unspecified integer components are -1. Modeled on Python's
    /// tf.DeviceSpec. Instances are immutable; "mutating" operations return new specs.
    /// </summary>
    public class DeviceSpec
    {
        // Memoization for parse/format round-trips.
        // NOTE(review): these static dictionaries are mutated without locking; if
        // specs can be built from multiple threads this should become a
        // ConcurrentDictionary — TODO confirm threading expectations.
        private static Dictionary<string, Components> _STRING_TO_COMPONENTS_CACHE = new();
        private static Dictionary<Components, string> _COMPONENTS_TO_STRING_CACHE = new();
        private string _job;
        private int _replica;
        private int _task;
        private string _device_type;
        private int _device_index;
        // Canonical string form, computed once at construction time.
        private string _as_string;

        public string Job => _job;
        public int Replica => _replica;
        public int Task => _task;
        public string DeviceType => _device_type;
        public int DeviceIndex => _device_index;

        /// <summary>
        /// Creates a device spec from individual components; null / -1 mean
        /// "unspecified" and are omitted from the string form.
        /// </summary>
        public DeviceSpec(string job = null, int replica = -1, int task = -1,
            string device_type = null, int device_index = -1)
        {
            _job = job;
            _replica = replica;
            _task = task;
            _device_type = device_type;
            _device_index = device_index;
            _as_string = _components_to_string(job, replica, task, device_type, device_index);
        }

        /// <summary>
        /// Copy constructor.
        /// </summary>
        public DeviceSpec(DeviceSpec other)
        {
            _job = other._job;
            _replica = other._replica;
            _task = other._task;
            _device_type = other._device_type;
            _device_index = other._device_index;
            _as_string = other._as_string;
        }

        protected DeviceSpec(Components com)
        {
            _job = com.Job;
            _replica = com.Replica;
            _task = com.Task;
            _device_type = com.DeviceType;
            _device_index = com.DeviceIndex;
            _as_string = _components_to_string(_job, _replica, _task, _device_type, _device_index);
        }

        /// <summary>
        /// Returns a copy of this spec with the given components overridden.
        /// WARNING: null / -1 double as the "keep the current value" sentinels, so
        /// this method cannot CLEAR a component back to unspecified (Python's
        /// DeviceSpec.replace can). Callers that need to drop a component must
        /// construct a new DeviceSpec directly.
        /// </summary>
        public DeviceSpec replace(string job = null, int replica = -1, int task = -1,
            string device_type = null, int device_index = -1)
        {
            job = job ?? _job;
            replica = replica == -1 ? _replica : replica;
            task = task == -1 ? _task : task;
            device_type = device_type ?? _device_type;
            device_index = device_index == -1 ? _device_index : device_index;
            return new DeviceSpec(job, replica, task, device_type, device_index);
        }

        /// <summary>
        /// Parses a device string such as "/job:j/replica:0/task:0/device:GPU:1"
        /// into a spec. Parse results are memoized per input string.
        /// </summary>
        /// <exception cref="ValueError">Thrown on malformed or ambiguous input.</exception>
        public static DeviceSpec from_string(string spec)
        {
            var components = _string_to_components(spec);
            return new DeviceSpec(components.Job, components.Replica, components.Task, components.DeviceType, components.DeviceIndex);
        }

        /// <summary>
        /// Combines this spec with <paramref name="dev"/>; components specified on
        /// <paramref name="dev"/> take precedence over this spec's.
        /// </summary>
        public DeviceSpec make_merged_spec(DeviceSpec dev)
        {
            return new DeviceSpec(_get_combined_properties(dev));
        }

        // Overlays `dev` on top of this spec: any component `dev` specifies wins.
        private Components _get_combined_properties(DeviceSpec dev)
        {
            return new Components(
                dev.Job ?? _job,
                dev.Replica == -1 ? _replica : dev.Replica,
                dev.Task == -1 ? _task : dev.Task,
                dev.DeviceType ?? _device_type,
                dev.DeviceIndex == -1 ? _device_index : dev.DeviceIndex
            );
        }

        // Formats components as "/job:j/replica:r/task:t/device:TYPE:i", omitting
        // unspecified parts; an unspecified device index renders as "*".
        private static string _components_to_string(string job, int replica, int task, string device_type, int device_index)
        {
            var key = new Components(job, replica, task, device_type, device_index);
            if (_COMPONENTS_TO_STRING_CACHE.TryGetValue(key, out var cache_result))
            {
                return cache_result;
            }

            StringBuilder output = new();
            if (job is not null)
            {
                output.Append($"/job:{job}");
            }
            if (replica != -1)
            {
                output.Append($"/replica:{replica}");
            }
            if (task != -1)
            {
                output.Append($"/task:{task}");
            }
            if (device_type is not null)
            {
                string device_index_string = "*";
                if (device_index != -1)
                {
                    device_index_string = device_index.ToString();
                }
                output.Append($"/device:{device_type}:{device_index_string}");
            }
            var result = output.ToString();
            _COMPONENTS_TO_STRING_CACHE[key] = result;
            return result;
        }

        // Parses "/job:j/replica:r/task:t/device:TYPE:i" segments; bare "GPU" or
        // "GPU:0" segments are also accepted for known device types.
        // BUG FIX: the Split(...).Select(...) chain below requires System.Linq,
        // which was missing from this file's using directives.
        private static Components _string_to_components(string spec)
        {
            if (_STRING_TO_COMPONENTS_CACHE.TryGetValue(spec, out var cached_result))
            {
                return cached_result;
            }
            var raw_spec = spec;
            var splits = spec.Split('/').Select(x => x.Split(':'));
            var valid_device_types = _get_valid_device_types();
            string job = null, device_type = null;
            int replica = -1, task = -1, device_index = -1;
            foreach (var y in splits)
            {
                var ly = y.Length;
                if (ly > 0)
                {
                    if (ly == 2 && y[0] == "job")
                    {
                        job = y[1];
                    }
                    else if (ly == 2 && y[0] == "replica")
                    {
                        replica = int.Parse(y[1]);
                    }
                    else if (ly == 2 && y[0] == "task")
                    {
                        task = int.Parse(y[1]);
                    }
                    else if ((ly == 1 || ly == 2) && valid_device_types.Contains(y[0].ToUpper()))
                    {
                        if (device_type is not null)
                        {
                            throw new ValueError($"Multiple device types are not allowed " +
                                $"while parsing the device spec: {spec}.");
                        }
                        device_type = y[0].ToUpper();
                        if (ly == 2 && y[1] != "*")
                        {
                            device_index = int.Parse(y[1]);
                        }
                    }
                    else if (ly == 3 && y[0] == "device")
                    {
                        if (device_type is not null)
                        {
                            throw new ValueError($"Multiple device types are not allowed " +
                                $"while parsing the device spec: {spec}.");
                        }
                        device_type = y[1];
                        if (y[2] != "*")
                        {
                            device_index = int.Parse(y[2]);
                        }
                    }
                    else if (y[0] != "")
                    {
                        throw new ValueError($"Unknown attribute '{y[0]}' is encountered " +
                            $"while parsing the device spec: {spec}.");
                    }
                }
            }

            var output = new Components(job, replica, task, device_type, device_index);
            _STRING_TO_COMPONENTS_CACHE[raw_spec] = output;
            return output;
        }

        // Device types accepted in bare "TYPE[:index]" segments.
        private static HashSet<string> _get_valid_device_types()
        {
            // TODO(Rinne): revise it to calling C API (need customized API).
            return new HashSet<string>(new string[] { "CPU", "GPU" });
        }

        /// <summary>
        /// Canonical string form of this spec (precomputed at construction).
        /// </summary>
        public override string ToString()
        {
            return _as_string;
        }

        protected record class Components(string Job, int Replica, int Task, string DeviceType, int DeviceIndex);
    }
}

+ 26
- 0
src/TensorFlowNET.Core/Device/DeviceUtils.cs View File

@@ -0,0 +1,26 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Device
{
    /// <summary>
    /// Small helpers for converting devices to their canonical string form.
    /// </summary>
    internal static class DeviceUtils
    {
        /// <summary>
        /// Parses a raw device string and returns its fully-qualified canonical
        /// form; a null device yields the empty string.
        /// </summary>
        public static string canonical_name(string device)
            => device is null ? "" : DeviceSpec.from_string(device).ToString();

        /// <summary>
        /// Returns the canonical string form of an already-parsed spec;
        /// a null device yields the empty string.
        /// </summary>
        public static string canonical_name(DeviceSpec device)
            => device is null ? "" : device.ToString();
    }
}

+ 9
- 2
src/TensorFlowNET.Core/Graphs/Graph.cs View File

@@ -22,6 +22,7 @@ using System.Linq;
using Tensorflow.Framework;
using Tensorflow.Functions;
using Tensorflow.Common.Extensions;
using Tensorflow.Graphs;
using static Tensorflow.Binding;

namespace Tensorflow
@@ -344,9 +345,15 @@ namespace Tensorflow
return op;
}

public void device(string device_name)
public ITensorFlowObject device(string device_name)
{
return new GraphDeviceContext(this, device_name);
}

private void add_device_to_stack(string device_name, int offset = 0)
{
// TODO(Rinne): deal with device spec.
int total_offset = offset + 1;
}

private void _create_op_helper(Operation op, bool compute_device = true)


+ 31
- 0
src/TensorFlowNET.Core/Graphs/GraphDeviceContext.cs View File

@@ -0,0 +1,31 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Graphs
{
    /// <summary>
    /// Device scope for graph mode, returned by <c>Graph.device(...)</c>.
    /// Currently a placeholder: entering and exiting the scope performs no
    /// device placement (graph-mode device stacks are not implemented yet).
    /// </summary>
    public class GraphDeviceContext : ITensorFlowObject
    {
        private Graph _graph;

        public GraphDeviceContext(Graph graph, string device_name)
        {
            // NOTE(review): device_name is currently ignored — device handling for
            // graphs is still a TODO (see Graph.add_device_to_stack).
            _graph = graph;
        }

        public void __enter__()
        {
            // No-op until graph device stacks are implemented.
        }

        public void __exit__()
        {
            // No-op: nothing was changed on enter, so nothing to restore.
        }

        public void Dispose()
        {
            // No-op: no unmanaged or disposable state is held.
        }
    }
}

+ 1
- 1
src/TensorFlowNET.Core/Keras/Layers/ILayer.cs View File

@@ -17,7 +17,7 @@ namespace Tensorflow.Keras
List<IVariableV1> TrainableVariables { get; }
List<IVariableV1> TrainableWeights { get; }
List<IVariableV1> NonTrainableWeights { get; }
List<IVariableV1> Weights { get; }
List<IVariableV1> Weights { get; set; }
Shape OutputShape { get; }
Shape BatchInputShape { get; }
TensorShapeConfig BuildInputShape { get; }


+ 2
- 0
src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs View File

@@ -84,6 +84,8 @@ namespace Tensorflow
protected bool built = false;
public bool Built => built;

List<IVariableV1> ILayer.Weights { get => throw new NotImplementedException(); set => throw new NotImplementedException(); }

public RnnCell(bool trainable = true,
string name = null,
TF_DataType dtype = TF_DataType.DtInvalid,


+ 14
- 10
src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.cs View File

@@ -413,8 +413,10 @@ namespace Tensorflow
{
var variables_path = SavedModelUtils.get_variables_path(_export_dir);
var saver = new TrackableSaver(new ObjectGraphView(get(0)));
tf.device("CPU");
saver.FilePrefixPlaceHolder = constant_op.constant(variables_path);
tf_with(ops.device("CPU"), _ =>
{
saver.FilePrefixPlaceHolder = constant_op.constant(variables_path);
});
LoadStatus load_status;
if (_save_options.allow_partial_checkpoint)
{
@@ -600,14 +602,16 @@ namespace Tensorflow

if (load_with_device)
{
tf.device(saved_device);
return (new UninitializedVariable(
shape: new Shape(proto.Shape.Dim.Select(x => (int)x.Size).ToArray()),
dtype: (TF_DataType)proto.Dtype,
name: name,
trainable: trainable,
aggregation: aggregation
), setattr);
return tf_with(ops.device(saved_device), _ =>
{
return (new UninitializedVariable(
shape: new Shape(proto.Shape.Dim.Select(x => (int)x.Size).ToArray()),
dtype: (TF_DataType)proto.Dtype,
name: name,
trainable: trainable,
aggregation: aggregation
), setattr);
});
}
else
{


+ 4
- 2
src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs View File

@@ -282,9 +282,11 @@ namespace Tensorflow
BaseResourceVariable new_variable;
if (save_options.experimental_variable_policy.save_variable_devices())
{
tf.device(this.Device);
Debug.Assert(this is ResourceVariable);
new_variable = resource_variable_ops.copy_to_graph_uninitialized((ResourceVariable)this);
new_variable = tf_with(ops.device(this.Device), _ =>
{
return resource_variable_ops.copy_to_graph_uninitialized((ResourceVariable)this);
});
}
else
{


+ 8
- 5
src/TensorFlowNET.Core/Variables/UninitializedVariable.cs View File

@@ -9,7 +9,7 @@ namespace Tensorflow.Variables
/// <summary>
/// A variable with no initializer.
/// </summary>
public sealed class UninitializedVariable: BaseResourceVariable, IVariableV1
public sealed class UninitializedVariable : BaseResourceVariable, IVariableV1
{
// TODO: complete the arg list.
public UninitializedVariable(
@@ -19,7 +19,7 @@ namespace Tensorflow.Variables
TF_DataType dtype = TF_DataType.DtInvalid,
VariableAggregation aggregation = VariableAggregation.None,
Shape shape = null,
Tensor extra_handle_data = null)
Tensor extra_handle_data = null)
{
string unique_id = "";
string handle_name = "";
@@ -50,9 +50,12 @@ namespace Tensorflow.Variables
{
tf_with(ops.name_scope("Read"), _ =>
{
tf.device(created_handle.Device);
var value = gen_resource_variable_ops.read_variable_op(created_handle, dtype);
resource_variable_ops._maybe_set_handle_data(dtype, created_handle, value);
var value = tf_with(ops.device(created_handle.Device), _ =>
{
var result = gen_resource_variable_ops.read_variable_op(created_handle, dtype);
resource_variable_ops._maybe_set_handle_data(dtype, created_handle, result);
return result;
});
_graph_element = value;
});
ops.add_to_collection(ops.GraphKeys.GLOBAL_VARIABLES_, this);


+ 17
- 0
src/TensorFlowNET.Core/ops.cs View File

@@ -584,6 +584,23 @@ namespace Tensorflow
}

public static ITensorFlowObject device(string device_name)
{
if (tf.Context.executing_eagerly())
{
return tf.Context.device(device_name);
}
//else if (ops.executing_eagerly_outside_functions())
//{
// throw new NotImplementedException();
//}
else
{
return get_default_graph().device(device_name);
}
// TODO(Rinne): deal with `ops.executing_eagerly_outside_functions()`.
}

public class NullContextManager: IDisposable
{
public void Dispose()


+ 2
- 4
src/TensorFlowNET.Keras/Callbacks/Earlystopping.cs View File

@@ -77,7 +77,7 @@ public class EarlyStopping: ICallback
// Restore the weights after first epoch if no progress is ever made.
if (_restore_best_weights && _best_weights == null)
{
_best_weights = _parameters.Model.TrainableWeights;
_best_weights = _parameters.Model.Weights;
}
_wait += 1;

@@ -102,10 +102,8 @@ public class EarlyStopping: ICallback
{
Console.WriteLine($"Restoring model weights from the end of the best epoch: {_best_epoch + 1}");
}
_parameters.Model.Weights = _best_weights;
}
// Because loading the weight variable into the model has not yet been implemented, so Earlystopping can't load best_weight yet.
// TODO(Wanglongzhi2001): implement it.
// _parameters.Model.load_weights(best_weights);
}
}
public void on_train_end()


+ 6
- 2
src/TensorFlowNET.Keras/Losses/SparseCategoricalCrossentropy.cs View File

@@ -4,17 +4,21 @@ namespace Tensorflow.Keras.Losses
{
public class SparseCategoricalCrossentropy : LossFunctionWrapper, ILossFunc
{
private bool _from_logits = false;
public SparseCategoricalCrossentropy(
bool from_logits = false,
string reduction = null,
string name = null) :
base(reduction: reduction, name: name == null ? "sparse_categorical_crossentropy" : name){ }
base(reduction: reduction, name: name == null ? "sparse_categorical_crossentropy" : name)
{
_from_logits = from_logits;
}

public override Tensor Apply(Tensor target, Tensor output, bool from_logits = false, int axis = -1)
{
target = tf.cast(target, dtype: TF_DataType.TF_INT64);

if (!from_logits)
if (!_from_logits)
{
var epsilon = tf.constant(KerasApi.keras.backend.epsilon(), output.dtype);
output = tf.clip_by_value(output, epsilon, 1 - epsilon);


+ 1
- 0
test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs View File

@@ -12,6 +12,7 @@ namespace TensorFlowNET.Keras.UnitTest.SaveModel;
[TestClass]
public class SequentialModelLoad
{
[Ignore]
[TestMethod]
public void SimpleModelFromAutoCompile()
{


Loading…
Cancel
Save