@@ -54,8 +54,11 @@ public static class SaveUtilV1 | |||
var g = to_graph.as_default(); | |||
var (named_saveable_objects, graph_proto, _, registered_savers) = serialize_gathered_objects(graph_view, | |||
object_map, call_with_mapped_captures, saveables_cache); | |||
tf.device("/cpu:0"); | |||
var object_graph_tensor = constant_op.constant(graph_proto.ToByteArray()); | |||
var object_graph_tensor = tf_with(ops.device("/cpu:0"), _ => | |||
{ | |||
// TODO(Rinne): locate the error that causes transferring TF_STRING to this function throws an exception. | |||
return constant_op.constant(graph_proto.ToByteArray()); | |||
}); | |||
named_saveable_objects.Add(new NoRestoreSaveable(object_graph_tensor, Trackable.Constants.OBJECT_GRAPH_PROTO_KEY)); | |||
g.Exit(); | |||
return (named_saveable_objects, registered_savers); | |||
@@ -66,8 +69,10 @@ public static class SaveUtilV1 | |||
{ | |||
var (named_saveable_objects, graph_proto, _, registered_savers) = serialize_gathered_objects(graph_view, | |||
object_map, call_with_mapped_captures, saveables_cache); | |||
tf.device("/cpu:0"); | |||
var object_graph_tensor = constant_op.constant(graph_proto.ToString(), TF_DataType.TF_STRING); | |||
var object_graph_tensor = tf_with(ops.device("/cpu:0"), _ => | |||
{ | |||
return constant_op.constant(graph_proto.ToString()); | |||
}); | |||
named_saveable_objects.Add(new NoRestoreSaveable(object_graph_tensor, Trackable.Constants.OBJECT_GRAPH_PROTO_KEY)); | |||
return (named_saveable_objects, registered_savers); | |||
} | |||
@@ -59,8 +59,10 @@ public class TrackableSaver | |||
if(object_graph_tensor is null) | |||
{ | |||
tf.device("/cpu:0"); | |||
object_graph_tensor = constant_op.constant(graph_proto.ToByteArray()); | |||
tf_with(ops.device("/cpu:0"), _ => | |||
{ | |||
object_graph_tensor = constant_op.constant(graph_proto.ToByteArray()); | |||
}); | |||
} | |||
else | |||
{ | |||
@@ -232,13 +234,15 @@ public class TrackableSaver | |||
Tensor object_graph_string = reader.GetTensor(Trackable.Constants.OBJECT_GRAPH_PROTO_KEY, dtype: TF_DataType.TF_STRING); | |||
Dictionary<Tensor, string> file_prefix_feed_dict; | |||
Tensor file_prefix_tensor; | |||
Tensor file_prefix_tensor = null; | |||
if (graph_building) | |||
{ | |||
if(_file_prefix_placeholder is null) | |||
{ | |||
tf.device("/cpu:0"); | |||
_file_prefix_placeholder = constant_op.constant("model"); | |||
_file_prefix_placeholder = tf_with(ops.device("/cpu:0"), _ => | |||
{ | |||
return constant_op.constant("model"); | |||
}); | |||
} | |||
file_prefix_tensor = _file_prefix_placeholder; | |||
file_prefix_feed_dict = new(); | |||
@@ -246,8 +250,10 @@ public class TrackableSaver | |||
} | |||
else | |||
{ | |||
tf.device("/cpu:0"); | |||
file_prefix_tensor = constant_op.constant(save_path); | |||
file_prefix_tensor = tf_with(ops.device("/cpu:0"), _ => | |||
{ | |||
return constant_op.constant(save_path); | |||
}); | |||
file_prefix_feed_dict = null; | |||
} | |||
TrackableObjectGraph object_graph_proto = new(); | |||
@@ -117,9 +117,11 @@ namespace Tensorflow.Checkpoint | |||
string restore_device = string.IsNullOrEmpty(options.experimental_io_device) ? "cpu:0": options.experimental_io_device!; | |||
// tf python has code `with ops.device(restore_device):` here. | |||
tf.device(restore_device); // may be risky. | |||
var restored_tensors = gen_ops.restore_v2(file_prefix, tensor_names.ToArray(), slice_specs.ToArray(), tensor_dtypes.ToArray()); | |||
Tensor[] restored_tensors = null; | |||
tf_with(ops.device(restore_device), _ => | |||
{ | |||
restored_tensors = gen_ops.restore_v2(file_prefix, tensor_names.ToArray(), slice_specs.ToArray(), tensor_dtypes.ToArray()); | |||
}); | |||
Dictionary<string, IDictionary<string, Tensor>> restored_tensor_dict = new(); | |||
int idx = 0; | |||
@@ -243,11 +245,14 @@ namespace Tensorflow.Checkpoint | |||
options = new CheckpointOptions(); | |||
} | |||
tf.device("CPU"); // may be risky. | |||
var sharded_suffix = array_ops.where(gen_ops.regex_full_match(file_prefix, tf.constant(@"^s3://.*")), | |||
Tensor tmp_checkpoint_prefix = null; | |||
tf_with(ops.device("CPU"), _ => | |||
{ | |||
var sharded_suffix = array_ops.where(gen_ops.regex_full_match(file_prefix, tf.constant(@"^s3://.*")), | |||
constant_op.constant(".part"), constant_op.constant("_temp/part")); | |||
var tmp_checkpoint_prefix = gen_ops.string_join(new Tensor[] { file_prefix, sharded_suffix }); | |||
IDictionary<string, Tensor> registered_paths = _registered_savers.Keys.ToDictionary(x => x, x => registered_saver_filename(file_prefix, x)); | |||
tmp_checkpoint_prefix = gen_ops.string_join(new Tensor[] { file_prefix, sharded_suffix }); | |||
IDictionary<string, Tensor> registered_paths = _registered_savers.Keys.ToDictionary(x => x, x => registered_saver_filename(file_prefix, x)); | |||
}); | |||
Operation save_fn() | |||
{ | |||
@@ -269,16 +274,24 @@ namespace Tensorflow.Checkpoint | |||
var saver = pair.Value; | |||
last_device = device; | |||
// skip the extra process of device name because of lack of API. | |||
tf.device(device); | |||
var shard_prefix = sharded_filename(tmp_checkpoint_prefix, shard, num_shards_tensor); | |||
Tensor shard_prefix = null; | |||
tf_with(ops.device(device), _ => | |||
{ | |||
shard_prefix = sharded_filename(tmp_checkpoint_prefix, shard, num_shards_tensor); | |||
}); | |||
saved_prefixes.Add(shard_prefix); | |||
sharded_saves.Add(saver.save(shard_prefix, options)); | |||
tf_with(ops.device(device), _ => | |||
{ | |||
sharded_saves.Add(saver.save(shard_prefix, options)); | |||
}); | |||
} | |||
using (var controller = ops.control_dependencies(sharded_saves.ToArray())) | |||
{ | |||
string merge_device = string.IsNullOrEmpty(options.experimental_io_device) ? last_device : options.experimental_io_device; | |||
tf.device(merge_device); | |||
return gen_ops.merge_v2_checkpoints(saved_prefixes.ToArray(), tf.constant(file_prefix), delete_old_dirs: true); | |||
return tf_with(ops.device(merge_device), _ => | |||
{ | |||
return gen_ops.merge_v2_checkpoints(saved_prefixes.ToArray(), tf.constant(file_prefix), delete_old_dirs: true); | |||
}); | |||
} | |||
} | |||
@@ -312,8 +325,9 @@ namespace Tensorflow.Checkpoint | |||
{ | |||
var device = single_saver.Key; | |||
var saver = single_saver.Value; | |||
tf.device(device); | |||
var restored_tensor_dict = saver.restore(file_prefix, options); | |||
tf_with(ops.device(device), _ => | |||
{ | |||
var restored_tensor_dict = saver.restore(file_prefix, options); | |||
foreach(var pair in restored_tensor_dict) | |||
{ | |||
@@ -405,21 +419,25 @@ namespace Tensorflow.Checkpoint | |||
private Tensor _traced_save(Tensor file_prefix) | |||
{ | |||
var save_op = save(file_prefix); | |||
tf.device("cpu:0"); | |||
using (ops.control_dependencies(new object[]{ save_op })) | |||
return tf_with(ops.device("cpu:0"), _ => | |||
{ | |||
return array_ops.identity(file_prefix); | |||
} | |||
return tf_with(ops.control_dependencies(new object[] { save_op }), __ => | |||
{ | |||
return array_ops.identity(file_prefix); | |||
}); | |||
}); | |||
} | |||
private Tensor _traced_restore(Tensor file_prefix) | |||
{ | |||
var restore_op = restore(file_prefix); | |||
tf.device("cpu:0"); | |||
using (ops.control_dependencies(restore_op.Values.ToArray())) | |||
return tf_with(ops.device("cpu:0"), _ => | |||
{ | |||
return array_ops.identity(file_prefix); | |||
} | |||
return tf_with(ops.control_dependencies(restore_op.Values.ToArray()), __ => | |||
{ | |||
return array_ops.identity(file_prefix); | |||
}); | |||
}); | |||
} | |||
public static MultiDeviceSaver from_saveables(IEnumerable<MySaveableObject> saveables, IDictionary<string, IDictionary<string, Trackable>>? registered_savers = null, bool call_with_mapped_captures = false) | |||
@@ -21,6 +21,7 @@ using Tensorflow.Eager; | |||
using static Tensorflow.Binding; | |||
using Google.Protobuf; | |||
using Tensorflow.Device; | |||
using Tensorflow.Exceptions; | |||
using System.Collections.Generic; | |||
namespace Tensorflow.Contexts | |||
@@ -30,10 +31,30 @@ namespace Tensorflow.Contexts | |||
/// </summary> | |||
public sealed partial class Context | |||
{ | |||
internal static Dictionary<(string, string), (string, DeviceSpec)> _device_parsing_cache = new(); | |||
internal List<LogicalDevice> _logical_devices = null; | |||
internal List<string> _context_devices = null; | |||
ContextDevicePlacementPolicy _device_policy; | |||
bool _log_device_placement; | |||
int _num_gpus; | |||
Dictionary<PhysicalDevice, bool> _memory_growth_map = new Dictionary<PhysicalDevice, bool>(); | |||
public string DeviceName { get; set; } = ""; | |||
public DeviceSpec DeviceSpec { get; set; } = null; | |||
internal List<string> Devices | |||
{ | |||
get | |||
{ | |||
if(_context_devices is null) | |||
{ | |||
throw new AssertionError("Context must be initialized first."); | |||
} | |||
return _context_devices; | |||
} | |||
} | |||
public void log_device_placement(bool enable) | |||
{ | |||
if (_handle != null) | |||
@@ -89,5 +110,57 @@ namespace Tensorflow.Contexts | |||
return results.ToArray(); | |||
} | |||
public EagerDeviceContext device(string name) | |||
{ | |||
return new EagerDeviceContext(this, name); | |||
} | |||
internal void _set_device(string device_name, DeviceSpec device_spec) | |||
{ | |||
DeviceSpec = device_spec; | |||
DeviceName = device_name; | |||
} | |||
internal void _initialize_logical_devices() | |||
{ | |||
List<LogicalDevice> logical_devices = new(); | |||
List<string> context_devices = new(); | |||
Status status = new(); | |||
var device_list = c_api.TFE_ContextListDevices(_handle, status); | |||
status.Check(true); | |||
try | |||
{ | |||
this._num_gpus = 0; | |||
string current_job = null; | |||
int current_task = -1; | |||
for(int i = 0; i < c_api.TF_DeviceListCount(device_list); i++) | |||
{ | |||
var dev_name = c_api.TF_DeviceListName(device_list, i, status); | |||
status.Check(true); | |||
context_devices.Add(DeviceUtils.canonical_name(dev_name)); | |||
var spec = DeviceSpec.from_string(dev_name); | |||
if(spec.Job == "localhost") | |||
{ | |||
spec = spec.replace(job: null, replica: -1, task: -1); | |||
} | |||
logical_devices.Add(new LogicalDevice(spec.ToString(), spec.DeviceType)); | |||
var dev_type_memory = c_api.TF_DeviceListType(device_list, i, status); | |||
var dev_type = c_api.StringPiece(dev_type_memory); | |||
status.Check(true); | |||
if(dev_type == "GPU" && spec.Job == current_job && spec.Task == current_task) | |||
{ | |||
_num_gpus++; | |||
} | |||
} | |||
} | |||
finally | |||
{ | |||
_logical_devices = logical_devices; | |||
_context_devices = context_devices; | |||
} | |||
} | |||
} | |||
public record class LogicalDevice(string name, string device_type); | |||
} |
@@ -34,7 +34,6 @@ namespace Tensorflow.Contexts | |||
public const int EAGER_MODE = 1; | |||
int defaultExecutionMode = EAGER_MODE; | |||
public string DeviceName { get; set; } = ""; | |||
public string ScopeName { get; set; } = ""; | |||
bool initialized = false; | |||
ContextSwitchStack context_switches; | |||
@@ -81,6 +80,9 @@ namespace Tensorflow.Contexts | |||
if (initialized) | |||
return; | |||
Debug.Assert(_context_devices is null); | |||
Config = MergeConfig(); | |||
FunctionCallOptions.Config = Config; | |||
var config_str = Config.ToByteArray(); | |||
var opts = new ContextOptions(); | |||
@@ -90,6 +92,7 @@ namespace Tensorflow.Contexts | |||
c_api.TFE_ContextOptionsSetDevicePlacementPolicy(opts, _device_policy); | |||
_handle = c_api.TFE_NewContext(opts, status); | |||
status.Check(true); | |||
_initialize_logical_devices(); | |||
initialized = true; | |||
} | |||
@@ -228,6 +231,7 @@ namespace Tensorflow.Contexts | |||
{ | |||
c_api.TFE_ContextClearCaches(_handle); | |||
} | |||
_device_parsing_cache.Clear(); | |||
} | |||
public static implicit operator SafeContextHandle(Context ctx) | |||
@@ -0,0 +1,71 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
using Tensorflow.Device; | |||
namespace Tensorflow.Contexts | |||
{ | |||
public class EagerDeviceContext : ITensorFlowObject | |||
{ | |||
private Context _ctx; | |||
private string _device_name; | |||
private Stack<(string, DeviceSpec, DeviceSpec)> _stack; | |||
public EagerDeviceContext(Context ctx, string device_name) | |||
{ | |||
_ctx = ctx; | |||
_device_name = device_name; | |||
_stack = new Stack<(string, DeviceSpec, DeviceSpec)>(); | |||
} | |||
public void __enter__() | |||
{ | |||
var ctx = _ctx; | |||
var old_device_name = ctx.DeviceName; | |||
var old_device_spec = ctx.DeviceSpec; | |||
var new_device_name = _device_name; | |||
var cache_key = (old_device_name, new_device_name); | |||
DeviceSpec new_device_spec; | |||
if (Context._device_parsing_cache.ContainsKey(cache_key)) | |||
{ | |||
(new_device_name, new_device_spec) = Context._device_parsing_cache[cache_key]; | |||
} | |||
else | |||
{ | |||
if(new_device_name is not null) | |||
{ | |||
var device_spec = DeviceSpec.from_string(new_device_name); | |||
if (!string.IsNullOrEmpty(old_device_name)) | |||
{ | |||
new_device_spec = new DeviceSpec(old_device_spec); | |||
} | |||
else | |||
{ | |||
ctx.ensure_initialized(); | |||
new_device_spec = DeviceSpec.from_string(ctx._context_devices[0]); | |||
} | |||
new_device_spec = new_device_spec.make_merged_spec(device_spec); | |||
} | |||
else | |||
{ | |||
new_device_spec = DeviceSpec.from_string(ctx._context_devices[0]); | |||
} | |||
new_device_name = new_device_spec.ToString(); | |||
Context._device_parsing_cache[cache_key] = (new_device_name, new_device_spec); | |||
} | |||
ctx._set_device(new_device_name, new_device_spec); | |||
_stack.Push((old_device_name, old_device_spec, new_device_spec)); | |||
} | |||
public void __exit__() | |||
{ | |||
var ctx = _ctx; | |||
var (old_device_name, old_device_spec, new_device_spec) = _stack.Pop(); | |||
ctx._set_device(old_device_name, old_device_spec); | |||
} | |||
public void Dispose() | |||
{ | |||
} | |||
} | |||
} |
@@ -0,0 +1,205 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
using System.Threading.Tasks; | |||
namespace Tensorflow.Device | |||
{ | |||
public class DeviceSpec | |||
{ | |||
private static Dictionary<string, Components> _STRING_TO_COMPONENTS_CACHE = new(); | |||
private static Dictionary<Components, string> _COMPONENTS_TO_STRING_CACHE = new(); | |||
private string _job; | |||
private int _replica; | |||
private int _task; | |||
private string _device_type; | |||
private int _device_index; | |||
private string _as_string; | |||
public string Job => _job; | |||
public int Replica => _replica; | |||
public int Task => _task; | |||
public string DeviceType => _device_type; | |||
public int DeviceIndex => _device_index; | |||
public DeviceSpec(string job = null, int replica = -1, int task = -1, | |||
string device_type = null, int device_index = -1) | |||
{ | |||
_job = job; | |||
_replica = replica; | |||
_task = task; | |||
_device_type = device_type; | |||
_device_index = device_index; | |||
_as_string = _components_to_string(job, replica, task, device_type, _device_index); | |||
} | |||
public DeviceSpec(DeviceSpec other) | |||
{ | |||
_job = other._job; | |||
_replica = other._replica; | |||
_task = other._task; | |||
_device_type = other._device_type; | |||
_device_index = other._device_index; | |||
_as_string = other._as_string; | |||
} | |||
protected DeviceSpec(Components com) | |||
{ | |||
_job = com.Job; | |||
_replica = com.Replica; | |||
_task = com.Task; | |||
_device_type = com.DeviceType; | |||
_device_index = com.DeviceIndex; | |||
_as_string = _components_to_string(_job, _replica, _task, _device_type, _device_index); | |||
} | |||
public DeviceSpec replace(string job = null, int replica = -1, int task = -1, | |||
string device_type = null, int device_index = -1) | |||
{ | |||
job = job ?? _job; | |||
replica = replica == -1 ? _replica : replica; | |||
task = task == -1 ? _task : task; | |||
device_type = device_type ?? _device_type; | |||
device_index = device_index == -1 ? _device_index : device_index; | |||
return new DeviceSpec(job, replica, task, device_type, device_index); | |||
} | |||
public static DeviceSpec from_string(string spec) | |||
{ | |||
var components = _string_to_components(spec); | |||
return new DeviceSpec(components.Job, components.Replica, components.Task, components.DeviceType, components.DeviceIndex); | |||
} | |||
public DeviceSpec make_merged_spec(DeviceSpec dev) | |||
{ | |||
return new DeviceSpec(_get_combined_properties(dev)); | |||
} | |||
private Components _get_combined_properties(DeviceSpec dev) | |||
{ | |||
return new Components( | |||
dev.Job ?? _job, | |||
dev.Replica == -1 ? _replica : dev.Replica, | |||
dev.Task == -1 ? _task : dev.Task, | |||
dev.DeviceType ?? _device_type, | |||
dev.DeviceIndex == -1 ? _device_index : dev.DeviceIndex | |||
); | |||
} | |||
private static string _components_to_string(string job, int replica, int task, string device_type, int device_index) | |||
{ | |||
var key = new Components(job, replica, task, device_type, device_index); | |||
if(_COMPONENTS_TO_STRING_CACHE.TryGetValue(key, out var cache_result)) | |||
{ | |||
return cache_result; | |||
} | |||
StringBuilder output = new(); | |||
if(job is not null) | |||
{ | |||
output.Append($"/job:{job}"); | |||
} | |||
if(replica != -1) | |||
{ | |||
output.Append($"/replica:{replica}"); | |||
} | |||
if(task != -1) | |||
{ | |||
output.Append($"/task:{task}"); | |||
} | |||
if (device_type is not null) | |||
{ | |||
string device_index_string = "*"; | |||
if (device_index != -1) | |||
{ | |||
device_index_string = device_index.ToString(); | |||
} | |||
output.Append($"/device:{device_type}:{device_index_string}"); | |||
} | |||
var result = output.ToString(); | |||
_COMPONENTS_TO_STRING_CACHE[key] = result; | |||
return result; | |||
} | |||
private static Components _string_to_components(string spec) | |||
{ | |||
if(_STRING_TO_COMPONENTS_CACHE.TryGetValue(spec, out var cached_result)) | |||
{ | |||
return cached_result; | |||
} | |||
var raw_spec = spec; | |||
var splits = spec.Split('/').Select(x => x.Split(':')); | |||
var valid_device_types = _get_valid_device_types(); | |||
string job = null, device_type = null; | |||
int replica = -1, task = -1, device_index = -1; | |||
foreach (var y in splits) | |||
{ | |||
var ly = y.Length; | |||
if (ly > 0) | |||
{ | |||
if(ly == 2 && y[0] == "job") | |||
{ | |||
job = y[1]; | |||
} | |||
else if(ly == 2 && y[0] == "replica") | |||
{ | |||
replica = int.Parse(y[1]); | |||
} | |||
else if(ly == 2 && y[0] == "task") | |||
{ | |||
task = int.Parse(y[1]); | |||
} | |||
else if((ly == 1 || ly == 2) && valid_device_types.Contains(y[0].ToUpper())) | |||
{ | |||
if (device_type is not null) | |||
{ | |||
throw new ValueError($"Multiple device types are not allowed " + | |||
$"while parsing the device spec: {spec}."); | |||
} | |||
device_type = y[0].ToUpper(); | |||
if(ly == 2 && y[1] != "*") | |||
{ | |||
device_index = int.Parse(y[1]); | |||
} | |||
} | |||
else if(ly == 3 && y[0] == "device") | |||
{ | |||
if(device_type is not null) | |||
{ | |||
throw new ValueError($"Multiple device types are not allowed " + | |||
$"while parsing the device spec: {spec}."); | |||
} | |||
device_type = y[1]; | |||
if (y[2] != "*") | |||
{ | |||
device_index = int.Parse(y[2]); | |||
} | |||
} | |||
else if (y[0] != "") | |||
{ | |||
throw new ValueError($"Unknown attribute '{y[0]}' is encountered " + | |||
$"while parsing the device spec: {spec}."); | |||
} | |||
} | |||
} | |||
var output = new Components(job, replica, task, device_type, device_index); | |||
_STRING_TO_COMPONENTS_CACHE[raw_spec] = output; | |||
return output; | |||
} | |||
private static HashSet<string> _get_valid_device_types() | |||
{ | |||
// TODO(Rinne): revise it to calling C API (need customized API). | |||
return new HashSet<string>(new string[] { "CPU", "GPU" }); | |||
} | |||
public override string ToString() | |||
{ | |||
return _as_string; | |||
} | |||
protected record class Components(string Job, int Replica, int Task, string DeviceType, int DeviceIndex); | |||
} | |||
} |
@@ -0,0 +1,26 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow.Device | |||
{ | |||
internal static class DeviceUtils | |||
{ | |||
public static string canonical_name(string device) | |||
{ | |||
if(device is null) | |||
{ | |||
return ""; | |||
} | |||
return DeviceSpec.from_string(device).ToString(); | |||
} | |||
public static string canonical_name(DeviceSpec device) | |||
{ | |||
if (device is null) | |||
{ | |||
return ""; | |||
} | |||
return device.ToString(); | |||
} | |||
} | |||
} |
@@ -22,6 +22,7 @@ using System.Linq; | |||
using Tensorflow.Framework; | |||
using Tensorflow.Functions; | |||
using Tensorflow.Common.Extensions; | |||
using Tensorflow.Graphs; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow | |||
@@ -344,9 +345,15 @@ namespace Tensorflow | |||
return op; | |||
} | |||
public void device(string device_name) | |||
public ITensorFlowObject device(string device_name) | |||
{ | |||
return new GraphDeviceContext(this, device_name); | |||
} | |||
private void add_device_to_stack(string device_name, int offset = 0) | |||
{ | |||
// TODO(Rinne): deal with device spec. | |||
int total_offset = offset + 1; | |||
} | |||
private void _create_op_helper(Operation op, bool compute_device = true) | |||
@@ -0,0 +1,31 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow.Graphs | |||
{ | |||
public class GraphDeviceContext : ITensorFlowObject | |||
{ | |||
private Graph _graph; | |||
public GraphDeviceContext(Graph graph, string device_name) | |||
{ | |||
_graph = graph; | |||
} | |||
public void __enter__() | |||
{ | |||
} | |||
public void __exit__() | |||
{ | |||
} | |||
public void Dispose() | |||
{ | |||
} | |||
} | |||
} |
@@ -17,7 +17,7 @@ namespace Tensorflow.Keras | |||
List<IVariableV1> TrainableVariables { get; } | |||
List<IVariableV1> TrainableWeights { get; } | |||
List<IVariableV1> NonTrainableWeights { get; } | |||
List<IVariableV1> Weights { get; } | |||
List<IVariableV1> Weights { get; set; } | |||
Shape OutputShape { get; } | |||
Shape BatchInputShape { get; } | |||
TensorShapeConfig BuildInputShape { get; } | |||
@@ -84,6 +84,8 @@ namespace Tensorflow | |||
protected bool built = false; | |||
public bool Built => built; | |||
List<IVariableV1> ILayer.Weights { get => throw new NotImplementedException(); set => throw new NotImplementedException(); } | |||
public RnnCell(bool trainable = true, | |||
string name = null, | |||
TF_DataType dtype = TF_DataType.DtInvalid, | |||
@@ -413,8 +413,10 @@ namespace Tensorflow | |||
{ | |||
var variables_path = SavedModelUtils.get_variables_path(_export_dir); | |||
var saver = new TrackableSaver(new ObjectGraphView(get(0))); | |||
tf.device("CPU"); | |||
saver.FilePrefixPlaceHolder = constant_op.constant(variables_path); | |||
tf_with(ops.device("CPU"), _ => | |||
{ | |||
saver.FilePrefixPlaceHolder = constant_op.constant(variables_path); | |||
}); | |||
LoadStatus load_status; | |||
if (_save_options.allow_partial_checkpoint) | |||
{ | |||
@@ -600,14 +602,16 @@ namespace Tensorflow | |||
if (load_with_device) | |||
{ | |||
tf.device(saved_device); | |||
return (new UninitializedVariable( | |||
shape: new Shape(proto.Shape.Dim.Select(x => (int)x.Size).ToArray()), | |||
dtype: (TF_DataType)proto.Dtype, | |||
name: name, | |||
trainable: trainable, | |||
aggregation: aggregation | |||
), setattr); | |||
return tf_with(ops.device(saved_device), _ => | |||
{ | |||
return (new UninitializedVariable( | |||
shape: new Shape(proto.Shape.Dim.Select(x => (int)x.Size).ToArray()), | |||
dtype: (TF_DataType)proto.Dtype, | |||
name: name, | |||
trainable: trainable, | |||
aggregation: aggregation | |||
), setattr); | |||
}); | |||
} | |||
else | |||
{ | |||
@@ -282,9 +282,11 @@ namespace Tensorflow | |||
BaseResourceVariable new_variable; | |||
if (save_options.experimental_variable_policy.save_variable_devices()) | |||
{ | |||
tf.device(this.Device); | |||
Debug.Assert(this is ResourceVariable); | |||
new_variable = resource_variable_ops.copy_to_graph_uninitialized((ResourceVariable)this); | |||
new_variable = tf_with(ops.device(this.Device), _ => | |||
{ | |||
return resource_variable_ops.copy_to_graph_uninitialized((ResourceVariable)this); | |||
}); | |||
} | |||
else | |||
{ | |||
@@ -9,7 +9,7 @@ namespace Tensorflow.Variables | |||
/// <summary> | |||
/// A variable with no initializer. | |||
/// </summary> | |||
public sealed class UninitializedVariable: BaseResourceVariable, IVariableV1 | |||
public sealed class UninitializedVariable : BaseResourceVariable, IVariableV1 | |||
{ | |||
// TODO: complete the arg list. | |||
public UninitializedVariable( | |||
@@ -19,7 +19,7 @@ namespace Tensorflow.Variables | |||
TF_DataType dtype = TF_DataType.DtInvalid, | |||
VariableAggregation aggregation = VariableAggregation.None, | |||
Shape shape = null, | |||
Tensor extra_handle_data = null) | |||
Tensor extra_handle_data = null) | |||
{ | |||
string unique_id = ""; | |||
string handle_name = ""; | |||
@@ -50,9 +50,12 @@ namespace Tensorflow.Variables | |||
{ | |||
tf_with(ops.name_scope("Read"), _ => | |||
{ | |||
tf.device(created_handle.Device); | |||
var value = gen_resource_variable_ops.read_variable_op(created_handle, dtype); | |||
resource_variable_ops._maybe_set_handle_data(dtype, created_handle, value); | |||
var value = tf_with(ops.device(created_handle.Device), _ => | |||
{ | |||
var result = gen_resource_variable_ops.read_variable_op(created_handle, dtype); | |||
resource_variable_ops._maybe_set_handle_data(dtype, created_handle, result); | |||
return result; | |||
}); | |||
_graph_element = value; | |||
}); | |||
ops.add_to_collection(ops.GraphKeys.GLOBAL_VARIABLES_, this); | |||
@@ -584,6 +584,23 @@ namespace Tensorflow | |||
} | |||
public static ITensorFlowObject device(string device_name) | |||
{ | |||
if (tf.Context.executing_eagerly()) | |||
{ | |||
return tf.Context.device(device_name); | |||
} | |||
//else if (ops.executing_eagerly_outside_functions()) | |||
//{ | |||
// throw new NotImplementedException(); | |||
//} | |||
else | |||
{ | |||
return get_default_graph().device(device_name); | |||
} | |||
// TODO(Rinne): deal with `ops.executing_eagerly_outside_functions()`. | |||
} | |||
public class NullContextManager: IDisposable | |||
{ | |||
public void Dispose() | |||
@@ -77,7 +77,7 @@ public class EarlyStopping: ICallback | |||
// Restore the weights after first epoch if no progress is ever made. | |||
if (_restore_best_weights && _best_weights == null) | |||
{ | |||
_best_weights = _parameters.Model.TrainableWeights; | |||
_best_weights = _parameters.Model.Weights; | |||
} | |||
_wait += 1; | |||
@@ -102,10 +102,8 @@ public class EarlyStopping: ICallback | |||
{ | |||
Console.WriteLine($"Restoring model weights from the end of the best epoch: {_best_epoch + 1}"); | |||
} | |||
_parameters.Model.Weights = _best_weights; | |||
} | |||
// Because loading the weight variable into the model has not yet been implemented, so Earlystopping can't load best_weight yet. | |||
// TODO(Wanglongzhi2001): implement it. | |||
// _parameters.Model.load_weights(best_weights); | |||
} | |||
} | |||
public void on_train_end() | |||
@@ -4,17 +4,21 @@ namespace Tensorflow.Keras.Losses | |||
{ | |||
public class SparseCategoricalCrossentropy : LossFunctionWrapper, ILossFunc | |||
{ | |||
private bool _from_logits = false; | |||
public SparseCategoricalCrossentropy( | |||
bool from_logits = false, | |||
string reduction = null, | |||
string name = null) : | |||
base(reduction: reduction, name: name == null ? "sparse_categorical_crossentropy" : name){ } | |||
base(reduction: reduction, name: name == null ? "sparse_categorical_crossentropy" : name) | |||
{ | |||
_from_logits = from_logits; | |||
} | |||
public override Tensor Apply(Tensor target, Tensor output, bool from_logits = false, int axis = -1) | |||
{ | |||
target = tf.cast(target, dtype: TF_DataType.TF_INT64); | |||
if (!from_logits) | |||
if (!_from_logits) | |||
{ | |||
var epsilon = tf.constant(KerasApi.keras.backend.epsilon(), output.dtype); | |||
output = tf.clip_by_value(output, epsilon, 1 - epsilon); | |||
@@ -12,6 +12,7 @@ namespace TensorFlowNET.Keras.UnitTest.SaveModel; | |||
[TestClass] | |||
public class SequentialModelLoad | |||
{ | |||
[Ignore] | |||
[TestMethod] | |||
public void SimpleModelFromAutoCompile() | |||
{ | |||