Fix the error when saving a model with GPU.
@@ -53,8 +53,11 @@ public static class SaveUtilV1
     var g = to_graph.as_default();
     var (named_saveable_objects, graph_proto, _, registered_savers) = serialize_gathered_objects(graph_view,
         object_map, call_with_mapped_captures, saveables_cache);
-    tf.device("/cpu:0");
-    var object_graph_tensor = constant_op.constant(graph_proto.ToByteArray());
+    var object_graph_tensor = tf_with(ops.device("/cpu:0"), _ =>
+    {
+        // TODO(Rinne): locate the error that causes passing a TF_STRING to this function to throw an exception.
+        return constant_op.constant(graph_proto.ToByteArray());
+    });
     named_saveable_objects.Add(new NoRestoreSaveable(object_graph_tensor, Trackable.Constants.OBJECT_GRAPH_PROTO_KEY));
     g.Exit();
     return (named_saveable_objects, registered_savers);
@@ -65,8 +68,10 @@ public static class SaveUtilV1
 {
     var (named_saveable_objects, graph_proto, _, registered_savers) = serialize_gathered_objects(graph_view,
         object_map, call_with_mapped_captures, saveables_cache);
-    tf.device("/cpu:0");
-    var object_graph_tensor = constant_op.constant(graph_proto.ToString(), TF_DataType.TF_STRING);
+    var object_graph_tensor = tf_with(ops.device("/cpu:0"), _ =>
+    {
+        return constant_op.constant(graph_proto.ToString());
+    });
     named_saveable_objects.Add(new NoRestoreSaveable(object_graph_tensor, Trackable.Constants.OBJECT_GRAPH_PROTO_KEY));
     return (named_saveable_objects, registered_savers);
 }
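
Note: the pattern applied throughout this change is the same in every file. A bare `tf.device(...)` call, which did not actually scope the subsequent ops in eager mode, is replaced by a `tf_with(ops.device(...), ...)` scope so that string tensors such as the serialized object graph are created on the CPU (TF_STRING tensors have no GPU kernels, which is what broke saving on a GPU). A minimal standalone sketch of the pattern, assuming the usual `using static Tensorflow.Binding;` import and the `ops.device` entry point added later in this diff:

```csharp
using Tensorflow;
using static Tensorflow.Binding;

public static class DevicePinningSketch
{
    // Create the serialized object-graph constant on the CPU so that saving
    // does not fail when the default device is a GPU.
    public static Tensor MakeObjectGraphTensor(byte[] serializedProto)
    {
        return tf_with(ops.device("/cpu:0"), _ =>
        {
            return constant_op.constant(serializedProto);
        });
    }
}
```
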
@@ -58,8 +58,10 @@ public class TrackableSaver
 if(object_graph_tensor is null)
 {
-    tf.device("/cpu:0");
-    object_graph_tensor = constant_op.constant(graph_proto.ToByteArray());
+    tf_with(ops.device("/cpu:0"), _ =>
+    {
+        object_graph_tensor = constant_op.constant(graph_proto.ToByteArray());
+    });
 }
 else
 {
@@ -230,13 +232,15 @@ public class TrackableSaver
 Tensor object_graph_string = reader.GetTensor(Trackable.Constants.OBJECT_GRAPH_PROTO_KEY, dtype: TF_DataType.TF_STRING);
 Dictionary<Tensor, string> file_prefix_feed_dict;
-Tensor file_prefix_tensor;
+Tensor file_prefix_tensor = null;
 if (graph_building)
 {
     if(_file_prefix_placeholder is null)
     {
-        tf.device("/cpu:0");
-        _file_prefix_placeholder = constant_op.constant("model");
+        _file_prefix_placeholder = tf_with(ops.device("/cpu:0"), _ =>
+        {
+            return constant_op.constant("model");
+        });
     }
     file_prefix_tensor = _file_prefix_placeholder;
     file_prefix_feed_dict = new();
@@ -244,8 +248,10 @@ public class TrackableSaver
 }
 else
 {
-    tf.device("/cpu:0");
-    file_prefix_tensor = constant_op.constant(save_path);
+    file_prefix_tensor = tf_with(ops.device("/cpu:0"), _ =>
+    {
+        return constant_op.constant(save_path);
+    });
     file_prefix_feed_dict = null;
 }
 TrackableObjectGraph object_graph_proto = new();
@@ -211,9 +211,11 @@ namespace Tensorflow.Checkpoint
 string restore_device = string.IsNullOrEmpty(options.experimental_io_device) ? "cpu:0": options.experimental_io_device!;
-// tf python has code `with ops.device(restore_device):` here.
-tf.device(restore_device); // may be risky.
-var restored_tensors = gen_ops.restore_v2(file_prefix, tensor_names.ToArray(), slice_specs.ToArray(), tensor_dtypes.ToArray());
+Tensor[] restored_tensors = null;
+tf_with(ops.device(restore_device), _ =>
+{
+    restored_tensors = gen_ops.restore_v2(file_prefix, tensor_names.ToArray(), slice_specs.ToArray(), tensor_dtypes.ToArray());
+});
 Dictionary<string, IDictionary<string, Tensor>> restored_tensor_dict = new();
 int idx = 0;
@@ -338,11 +340,14 @@ namespace Tensorflow.Checkpoint
     options = new CheckpointOptions();
 }
-tf.device("CPU"); // may be risky.
-var sharded_suffix = array_ops.where(gen_ops.regex_full_match(file_prefix, tf.constant(@"^s3://.*")),
+Tensor tmp_checkpoint_prefix = null;
+tf_with(ops.device("CPU"), _ =>
+{
+    var sharded_suffix = array_ops.where(gen_ops.regex_full_match(file_prefix, tf.constant(@"^s3://.*")),
     constant_op.constant(".part"), constant_op.constant("_temp/part"));
-var tmp_checkpoint_prefix = gen_ops.string_join(new Tensor[] { file_prefix, sharded_suffix });
-IDictionary<string, Tensor> registered_paths = _registered_savers.Keys.ToDictionary(x => x, x => registered_saver_filename(file_prefix, x));
+    tmp_checkpoint_prefix = gen_ops.string_join(new Tensor[] { file_prefix, sharded_suffix });
+    IDictionary<string, Tensor> registered_paths = _registered_savers.Keys.ToDictionary(x => x, x => registered_saver_filename(file_prefix, x));
+});
 Operation save_fn()
 {
@@ -364,16 +369,24 @@ namespace Tensorflow.Checkpoint
         var saver = pair.Value;
         last_device = device;
         // skip the extra process of device name because of lack of API.
-        tf.device(device);
-        var shard_prefix = sharded_filename(tmp_checkpoint_prefix, shard, num_shards_tensor);
+        Tensor shard_prefix = null;
+        tf_with(ops.device(device), _ =>
+        {
+            shard_prefix = sharded_filename(tmp_checkpoint_prefix, shard, num_shards_tensor);
+        });
         saved_prefixes.Add(shard_prefix);
-        sharded_saves.Add(saver.save(shard_prefix, options));
+        tf_with(ops.device(device), _ =>
+        {
+            sharded_saves.Add(saver.save(shard_prefix, options));
+        });
     }
     using (var controller = ops.control_dependencies(sharded_saves.ToArray()))
     {
         string merge_device = string.IsNullOrEmpty(options.experimental_io_device) ? last_device : options.experimental_io_device;
-        tf.device(merge_device);
-        return gen_ops.merge_v2_checkpoints(saved_prefixes.ToArray(), tf.constant(file_prefix), delete_old_dirs: true);
+        return tf_with(ops.device(merge_device), _ =>
+        {
+            return gen_ops.merge_v2_checkpoints(saved_prefixes.ToArray(), tf.constant(file_prefix), delete_old_dirs: true);
+        });
     }
 }
@@ -407,54 +420,56 @@ namespace Tensorflow.Checkpoint
 {
     var device = single_saver.Key;
     var saver = single_saver.Value;
-    tf.device(device);
-    var restored_tensor_dict = saver.restore(file_prefix, options);
-    foreach(var pair in restored_tensor_dict)
+    tf_with(ops.device(device), _ =>
     {
-        var checkpoint_key = pair.Key;
-        var slice_and_tensor = pair.Value;
-        foreach(var item in slice_and_tensor)
+        var restored_tensor_dict = saver.restore(file_prefix, options);
+        foreach (var pair in restored_tensor_dict)
         {
-            var slice_spec = item.Key;
-            var tensor = item.Value;
-            var restore_fn = _keys_to_restore_fn[(checkpoint_key, slice_spec)];
-            var internal_dict = restore_fn_inputs.SetDefault(restore_fn, new Dictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>());
-            if (!string.IsNullOrEmpty(slice_spec))
+            var checkpoint_key = pair.Key;
+            var slice_and_tensor = pair.Value;
+            foreach (var item in slice_and_tensor)
             {
-                if (!internal_dict.ContainsKey(checkpoint_key))
+                var slice_spec = item.Key;
+                var tensor = item.Value;
+                var restore_fn = _keys_to_restore_fn[(checkpoint_key, slice_spec)];
+                var internal_dict = restore_fn_inputs.SetDefault(restore_fn, new Dictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>());
+                if (!string.IsNullOrEmpty(slice_spec))
                 {
-                    Dictionary<string, Tensor> dict = new();
-                    dict[slice_spec] = tensor;
-                    internal_dict[checkpoint_key] = new Maybe<Tensor, IDictionary<string, Tensor>>(dict);
+                    if (!internal_dict.ContainsKey(checkpoint_key))
+                    {
+                        Dictionary<string, Tensor> dict = new();
+                        dict[slice_spec] = tensor;
+                        internal_dict[checkpoint_key] = new Maybe<Tensor, IDictionary<string, Tensor>>(dict);
+                    }
+                    else
+                    {
+                        internal_dict[checkpoint_key].GetValue<IDictionary<string, Tensor>>()[slice_spec] = tensor;
+                    }
                 }
                 else
                 {
-                    internal_dict[checkpoint_key].GetValue<IDictionary<string, Tensor>>()[slice_spec] = tensor;
+                    internal_dict[checkpoint_key] = new Maybe<Tensor, IDictionary<string, Tensor>>(tensor);
                 }
-            }
-            else
-            {
-                internal_dict[checkpoint_key] = new Maybe<Tensor, IDictionary<string, Tensor>>(tensor);
-            }
-            restore_fn_input_count[restore_fn]--;
+                restore_fn_input_count[restore_fn]--;
-            if (restore_fn_input_count[restore_fn] == 0)
-            {
-                Dictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> restored_tensors = new();
-                foreach(var input in restore_fn_inputs[restore_fn])
+                if (restore_fn_input_count[restore_fn] == 0)
                 {
-                    restored_tensors[TrackableUtils.extract_local_name(input.Key)] = input.Value;
-                }
-                var ret = restore_fn.DynamicInvoke(restored_tensors);
-                if(ret is IDictionary<string, Operation>)
-                {
-                    var dict = (IDictionary<string, Operation>)ret;
-                    restore_ops = restore_ops.Concat(dict).ToDictionary(x => x.Key, x => x.Value);
+                    Dictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> restored_tensors = new();
+                    foreach (var input in restore_fn_inputs[restore_fn])
+                    {
+                        restored_tensors[TrackableUtils.extract_local_name(input.Key)] = input.Value;
+                    }
+                    var ret = restore_fn.DynamicInvoke(restored_tensors);
+                    if (ret is IDictionary<string, Operation>)
+                    {
+                        var dict = (IDictionary<string, Operation>)ret;
+                        restore_ops = restore_ops.Concat(dict).ToDictionary(x => x.Key, x => x.Value);
+                    }
                 }
             }
         }
-    }
+    });
 }
 foreach(var item in _registered_savers)
@@ -500,21 +515,25 @@ namespace Tensorflow.Checkpoint
         private Tensor _traced_save(Tensor file_prefix)
         {
             var save_op = save(file_prefix);
-            tf.device("cpu:0");
-            using (ops.control_dependencies(new object[]{ save_op }))
+            return tf_with(ops.device("cpu:0"), _ =>
             {
-                return array_ops.identity(file_prefix);
-            }
+                return tf_with(ops.control_dependencies(new object[] { save_op }), __ =>
+                {
+                    return array_ops.identity(file_prefix);
+                });
+            });
         }

         private Tensor _traced_restore(Tensor file_prefix)
         {
             var restore_op = restore(file_prefix);
-            tf.device("cpu:0");
-            using (ops.control_dependencies(restore_op.Values.ToArray()))
+            return tf_with(ops.device("cpu:0"), _ =>
             {
-                return array_ops.identity(file_prefix);
-            }
+                return tf_with(ops.control_dependencies(restore_op.Values.ToArray()), __ =>
+                {
+                    return array_ops.identity(file_prefix);
+                });
+            });
         }

         public static MultiDeviceSaver from_saveables(IEnumerable<MySaveableObject> saveables, IDictionary<string, IDictionary<string, Trackable>>? registered_savers = null, bool call_with_mapped_captures = false)
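
For reference, this is the shape of the nested scopes used by `_traced_save`/`_traced_restore` above, in isolation: an identity of the file prefix is pinned to the CPU and ordered after the save op via control dependencies. The names below are illustrative; the helpers are the same TensorFlow.NET ones used in this diff.

```csharp
using Tensorflow;
using static Tensorflow.Binding;

public static class TracedSaveSketch
{
    // Returns an identity of `file_prefix` that lives on the CPU and only
    // runs after `save_op`, mirroring the nested tf_with scopes above.
    public static Tensor IdentityAfterSave(Operation save_op, Tensor file_prefix)
    {
        return tf_with(ops.device("cpu:0"), _ =>
        {
            return tf_with(ops.control_dependencies(new object[] { save_op }), __ =>
            {
                return array_ops.identity(file_prefix);
            });
        });
    }
}
```
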
@@ -21,6 +21,7 @@ using Tensorflow.Eager;
 using static Tensorflow.Binding;
 using Google.Protobuf;
 using Tensorflow.Device;
+using Tensorflow.Exceptions;
 using System.Collections.Generic;

 namespace Tensorflow.Contexts

@@ -30,10 +31,30 @@ namespace Tensorflow.Contexts
     /// </summary>
     public sealed partial class Context
     {
+        internal static Dictionary<(string, string), (string, DeviceSpec)> _device_parsing_cache = new();
+        internal List<LogicalDevice> _logical_devices = null;
+        internal List<string> _context_devices = null;
         ContextDevicePlacementPolicy _device_policy;
         bool _log_device_placement;
+        int _num_gpus;
         Dictionary<PhysicalDevice, bool> _memory_growth_map = new Dictionary<PhysicalDevice, bool>();
+        public string DeviceName { get; set; } = "";
+        public DeviceSpec DeviceSpec { get; set; } = null;
+
+        internal List<string> Devices
+        {
+            get
+            {
+                if(_context_devices is null)
+                {
+                    throw new AssertionError("Context must be initialized first.");
+                }
+                return _context_devices;
+            }
+        }

         public void log_device_placement(bool enable)
         {
             if (_handle != null)

@@ -89,5 +110,57 @@ namespace Tensorflow.Contexts
             return results.ToArray();
         }

+        public EagerDeviceContext device(string name)
+        {
+            return new EagerDeviceContext(this, name);
+        }
+
+        internal void _set_device(string device_name, DeviceSpec device_spec)
+        {
+            DeviceSpec = device_spec;
+            DeviceName = device_name;
+        }
+
+        internal void _initialize_logical_devices()
+        {
+            List<LogicalDevice> logical_devices = new();
+            List<string> context_devices = new();
+            Status status = new();
+            var device_list = c_api.TFE_ContextListDevices(_handle, status);
+            status.Check(true);
+            try
+            {
+                this._num_gpus = 0;
+                string current_job = null;
+                int current_task = -1;
+                for(int i = 0; i < c_api.TF_DeviceListCount(device_list); i++)
+                {
+                    var dev_name = c_api.TF_DeviceListName(device_list, i, status);
+                    status.Check(true);
+                    context_devices.Add(DeviceUtils.canonical_name(dev_name));
+                    var spec = DeviceSpec.from_string(dev_name);
+                    if(spec.Job == "localhost")
+                    {
+                        spec = spec.replace(job: null, replica: -1, task: -1);
+                    }
+                    logical_devices.Add(new LogicalDevice(spec.ToString(), spec.DeviceType));
+                    var dev_type_memory = c_api.TF_DeviceListType(device_list, i, status);
+                    var dev_type = c_api.StringPiece(dev_type_memory);
+                    status.Check(true);
+                    if(dev_type == "GPU" && spec.Job == current_job && spec.Task == current_task)
+                    {
+                        _num_gpus++;
+                    }
+                }
+            }
+            finally
+            {
+                _logical_devices = logical_devices;
+                _context_devices = context_devices;
+            }
+        }
     }

+    public record class LogicalDevice(string name, string device_type);
 }
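
With these additions, an eager device scope can be entered through the context; the scope swaps `Context.DeviceName`/`DeviceSpec` on entry and restores them on exit. A hedged usage sketch (it assumes `tf_with` drives `__enter__`/`__exit__` on `ITensorFlowObject`, as it does for name scopes, and that `ensure_initialized` is callable as elsewhere in the code base):

```csharp
using System;
using Tensorflow;
using static Tensorflow.Binding;

class ContextDeviceUsage
{
    static void Main()
    {
        tf.Context.ensure_initialized();
        Console.WriteLine(tf.Context.DeviceName);   // "" before any scope is entered

        tf_with(tf.Context.device("/device:CPU:0"), _ =>
        {
            // Ops created here are placed according to the merged device spec.
            Console.WriteLine(tf.Context.DeviceName);
            var t = tf.constant(1);
        });

        Console.WriteLine(tf.Context.DeviceName);   // restored on exit
    }
}
```
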
@@ -34,7 +34,6 @@ namespace Tensorflow.Contexts
         public const int EAGER_MODE = 1;
         int defaultExecutionMode = EAGER_MODE;
-        public string DeviceName { get; set; } = "";
         public string ScopeName { get; set; } = "";
         bool initialized = false;
         ContextSwitchStack context_switches;

@@ -62,6 +61,8 @@ namespace Tensorflow.Contexts
             if (initialized)
                 return;

+            Debug.Assert(_context_devices is null);
+
             Config = MergeConfig();
             FunctionCallOptions.Config = Config;
             var config_str = Config.ToByteArray();

@@ -72,6 +73,7 @@ namespace Tensorflow.Contexts
             c_api.TFE_ContextOptionsSetDevicePlacementPolicy(opts, _device_policy);
             _handle = c_api.TFE_NewContext(opts, status);
             status.Check(true);
+            _initialize_logical_devices();
             initialized = true;
         }

@@ -174,6 +176,7 @@ namespace Tensorflow.Contexts
             {
                 c_api.TFE_ContextClearCaches(_handle);
             }
+            _device_parsing_cache.Clear();
         }

         public static implicit operator SafeContextHandle(Context ctx)
@@ -0,0 +1,71 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Device;

namespace Tensorflow.Contexts
{
    public class EagerDeviceContext : ITensorFlowObject
    {
        private Context _ctx;
        private string _device_name;
        private Stack<(string, DeviceSpec, DeviceSpec)> _stack;

        public EagerDeviceContext(Context ctx, string device_name)
        {
            _ctx = ctx;
            _device_name = device_name;
            _stack = new Stack<(string, DeviceSpec, DeviceSpec)>();
        }

        public void __enter__()
        {
            var ctx = _ctx;
            var old_device_name = ctx.DeviceName;
            var old_device_spec = ctx.DeviceSpec;
            var new_device_name = _device_name;
            var cache_key = (old_device_name, new_device_name);
            DeviceSpec new_device_spec;
            if (Context._device_parsing_cache.ContainsKey(cache_key))
            {
                (new_device_name, new_device_spec) = Context._device_parsing_cache[cache_key];
            }
            else
            {
                if(new_device_name is not null)
                {
                    var device_spec = DeviceSpec.from_string(new_device_name);
                    if (!string.IsNullOrEmpty(old_device_name))
                    {
                        new_device_spec = new DeviceSpec(old_device_spec);
                    }
                    else
                    {
                        ctx.ensure_initialized();
                        new_device_spec = DeviceSpec.from_string(ctx._context_devices[0]);
                    }
                    new_device_spec = new_device_spec.make_merged_spec(device_spec);
                }
                else
                {
                    new_device_spec = DeviceSpec.from_string(ctx._context_devices[0]);
                }
                new_device_name = new_device_spec.ToString();
                Context._device_parsing_cache[cache_key] = (new_device_name, new_device_spec);
            }
            ctx._set_device(new_device_name, new_device_spec);
            _stack.Push((old_device_name, old_device_spec, new_device_spec));
        }

        public void __exit__()
        {
            var ctx = _ctx;
            var (old_device_name, old_device_spec, new_device_spec) = _stack.Pop();
            ctx._set_device(old_device_name, old_device_spec);
        }

        public void Dispose()
        {
        }
    }
}
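
The context object itself is small: `__enter__` resolves and caches the merged device spec (repeated transitions between the same pair of device names are served from the internal `_device_parsing_cache`), records the previous state on a stack, and sets it on the `Context`; `__exit__` pops and restores it. A rough manual walkthrough, illustrative only, since in practice the object is driven by `tf_with` rather than called directly:

```csharp
using System;
using Tensorflow;
using Tensorflow.Contexts;
using static Tensorflow.Binding;

class EagerDeviceContextWalkthrough
{
    static void Main()
    {
        var scope = new EagerDeviceContext(tf.Context, "/device:CPU:0");

        scope.__enter__();   // parse + merge the spec, cache it, set Context.DeviceName
        Console.WriteLine(tf.Context.DeviceName);

        scope.__exit__();    // pop the stack and restore the previous device name/spec
        Console.WriteLine(tf.Context.DeviceName);
    }
}
```
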
@@ -0,0 +1,205 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;

namespace Tensorflow.Device
{
    public class DeviceSpec
    {
        private static Dictionary<string, Components> _STRING_TO_COMPONENTS_CACHE = new();
        private static Dictionary<Components, string> _COMPONENTS_TO_STRING_CACHE = new();
        private string _job;
        private int _replica;
        private int _task;
        private string _device_type;
        private int _device_index;
        private string _as_string;

        public string Job => _job;
        public int Replica => _replica;
        public int Task => _task;
        public string DeviceType => _device_type;
        public int DeviceIndex => _device_index;

        public DeviceSpec(string job = null, int replica = -1, int task = -1,
            string device_type = null, int device_index = -1)
        {
            _job = job;
            _replica = replica;
            _task = task;
            _device_type = device_type;
            _device_index = device_index;
            _as_string = _components_to_string(job, replica, task, device_type, _device_index);
        }

        public DeviceSpec(DeviceSpec other)
        {
            _job = other._job;
            _replica = other._replica;
            _task = other._task;
            _device_type = other._device_type;
            _device_index = other._device_index;
            _as_string = other._as_string;
        }

        protected DeviceSpec(Components com)
        {
            _job = com.Job;
            _replica = com.Replica;
            _task = com.Task;
            _device_type = com.DeviceType;
            _device_index = com.DeviceIndex;
            _as_string = _components_to_string(_job, _replica, _task, _device_type, _device_index);
        }

        public DeviceSpec replace(string job = null, int replica = -1, int task = -1,
            string device_type = null, int device_index = -1)
        {
            job = job ?? _job;
            replica = replica == -1 ? _replica : replica;
            task = task == -1 ? _task : task;
            device_type = device_type ?? _device_type;
            device_index = device_index == -1 ? _device_index : device_index;
            return new DeviceSpec(job, replica, task, device_type, device_index);
        }

        public static DeviceSpec from_string(string spec)
        {
            var components = _string_to_components(spec);
            return new DeviceSpec(components.Job, components.Replica, components.Task, components.DeviceType, components.DeviceIndex);
        }

        public DeviceSpec make_merged_spec(DeviceSpec dev)
        {
            return new DeviceSpec(_get_combined_properties(dev));
        }

        private Components _get_combined_properties(DeviceSpec dev)
        {
            return new Components(
                dev.Job ?? _job,
                dev.Replica == -1 ? _replica : dev.Replica,
                dev.Task == -1 ? _task : dev.Task,
                dev.DeviceType ?? _device_type,
                dev.DeviceIndex == -1 ? _device_index : dev.DeviceIndex
            );
        }

        private static string _components_to_string(string job, int replica, int task, string device_type, int device_index)
        {
            var key = new Components(job, replica, task, device_type, device_index);
            if(_COMPONENTS_TO_STRING_CACHE.TryGetValue(key, out var cache_result))
            {
                return cache_result;
            }

            StringBuilder output = new();
            if(job is not null)
            {
                output.Append($"/job:{job}");
            }
            if(replica != -1)
            {
                output.Append($"/replica:{replica}");
            }
            if(task != -1)
            {
                output.Append($"/task:{task}");
            }
            if (device_type is not null)
            {
                string device_index_string = "*";
                if (device_index != -1)
                {
                    device_index_string = device_index.ToString();
                }
                output.Append($"/device:{device_type}:{device_index_string}");
            }

            var result = output.ToString();
            _COMPONENTS_TO_STRING_CACHE[key] = result;
            return result;
        }

        private static Components _string_to_components(string spec)
        {
            if(_STRING_TO_COMPONENTS_CACHE.TryGetValue(spec, out var cached_result))
            {
                return cached_result;
            }

            var raw_spec = spec;
            var splits = spec.Split('/').Select(x => x.Split(':'));
            var valid_device_types = _get_valid_device_types();
            string job = null, device_type = null;
            int replica = -1, task = -1, device_index = -1;

            foreach (var y in splits)
            {
                var ly = y.Length;
                if (ly > 0)
                {
                    if(ly == 2 && y[0] == "job")
                    {
                        job = y[1];
                    }
                    else if(ly == 2 && y[0] == "replica")
                    {
                        replica = int.Parse(y[1]);
                    }
                    else if(ly == 2 && y[0] == "task")
                    {
                        task = int.Parse(y[1]);
                    }
                    else if((ly == 1 || ly == 2) && valid_device_types.Contains(y[0].ToUpper()))
                    {
                        if (device_type is not null)
                        {
                            throw new ValueError($"Multiple device types are not allowed " +
                                $"while parsing the device spec: {spec}.");
                        }
                        device_type = y[0].ToUpper();
                        if(ly == 2 && y[1] != "*")
                        {
                            device_index = int.Parse(y[1]);
                        }
                    }
                    else if(ly == 3 && y[0] == "device")
                    {
                        if(device_type is not null)
                        {
                            throw new ValueError($"Multiple device types are not allowed " +
                                $"while parsing the device spec: {spec}.");
                        }
                        device_type = y[1];
                        if (y[2] != "*")
                        {
                            device_index = int.Parse(y[2]);
                        }
                    }
                    else if (y[0] != "")
                    {
                        throw new ValueError($"Unknown attribute '{y[0]}' is encountered " +
                            $"while parsing the device spec: {spec}.");
                    }
                }
            }

            var output = new Components(job, replica, task, device_type, device_index);
            _STRING_TO_COMPONENTS_CACHE[raw_spec] = output;
            return output;
        }

        private static HashSet<string> _get_valid_device_types()
        {
            // TODO(Rinne): revise it to call the C API (needs a customized API).
            return new HashSet<string>(new string[] { "CPU", "GPU" });
        }

        public override string ToString()
        {
            return _as_string;
        }

        protected record class Components(string Job, int Replica, int Task, string DeviceType, int DeviceIndex);
    }
}
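
A quick sanity check of the parsing and merging semantics above; the expected values are worked out from `_components_to_string` and `_get_combined_properties` and are purely illustrative:

```csharp
using System;
using Tensorflow.Device;

class DeviceSpecExample
{
    static void Main()
    {
        var gpu = DeviceSpec.from_string("/job:localhost/replica:0/task:0/device:GPU:1");
        Console.WriteLine(gpu.DeviceType);   // "GPU"
        Console.WriteLine(gpu.DeviceIndex);  // 1

        // Merging: unset fields of the argument fall back to the receiver's values.
        var cpu = DeviceSpec.from_string("/device:CPU:0");
        var merged = gpu.make_merged_spec(cpu);
        Console.WriteLine(merged);           // "/job:localhost/replica:0/task:0/device:CPU:0"
    }
}
```
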
@@ -0,0 +1,26 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Device
{
    internal static class DeviceUtils
    {
        public static string canonical_name(string device)
        {
            if(device is null)
            {
                return "";
            }
            return DeviceSpec.from_string(device).ToString();
        }

        public static string canonical_name(DeviceSpec device)
        {
            if (device is null)
            {
                return "";
            }
            return device.ToString();
        }
    }
}
@@ -19,6 +19,7 @@ using System.Collections;
 using System.Collections.Generic;
 using System.Collections.Specialized;
 using System.Linq;
+using Tensorflow.Graphs;
 using static Tensorflow.Binding;

 namespace Tensorflow

@@ -294,9 +295,15 @@ namespace Tensorflow
             return op;
         }

-        public void device(string device_name)
+        public ITensorFlowObject device(string device_name)
         {
+            return new GraphDeviceContext(this, device_name);
+        }
+
+        private void add_device_to_stack(string device_name, int offset = 0)
+        {
+            // TODO(Rinne): deal with device spec.
+            int total_offset = offset + 1;
         }

         private void _create_op_helper(Operation op, bool compute_device = true)
@@ -0,0 +1,31 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Graphs
{
    public class GraphDeviceContext : ITensorFlowObject
    {
        private Graph _graph;

        public GraphDeviceContext(Graph graph, string device_name)
        {
            _graph = graph;
        }

        public void __enter__()
        {
        }

        public void __exit__()
        {
        }

        public void Dispose()
        {
        }
    }
}
@@ -42,16 +42,20 @@ namespace Tensorflow
        _var_device = var.Device;
        _var_shape = var.shape;

-       Tensor _read_variable_closure(BaseResourceVariable v)
+       Tensor? _read_variable_closure(BaseResourceVariable v)
        {
-           tf.device(v.Device);
-           if(tf.Context.executing_eagerly() && !((bool)v.is_initialized().numpy()))
+           return tf_with(ops.device(v.Device), _ =>
            {
-               return null;
-           }
-           var x = v.read_value_no_copy();
-           tf.device("/device:CPU:0");
-           return array_ops.identity(x);
+               if (tf.Context.executing_eagerly() && !((bool)v.is_initialized().numpy()))
+               {
+                   return null;
+               }
+               var x = v.read_value_no_copy();
+               return tf_with(ops.device("/device:CPU:0"), __ =>
+               {
+                   return array_ops.identity(x);
+               });
+           });
        }

        this.handle_op = var.Handle;
@@ -412,8 +412,10 @@ namespace Tensorflow
        {
            var variables_path = SavedModelUtils.get_variables_path(_export_dir);
            var saver = new TrackableSaver(new ObjectGraphView(get(0)));
-           tf.device("CPU");
-           saver.FilePrefixPlaceHolder = constant_op.constant(variables_path);
+           tf_with(ops.device("CPU"), _ =>
+           {
+               saver.FilePrefixPlaceHolder = constant_op.constant(variables_path);
+           });
            LoadStatus load_status;
            if (_save_options.allow_partial_checkpoint)
            {
@@ -598,14 +600,16 @@ namespace Tensorflow
            if (load_with_device)
            {
-               tf.device(saved_device);
-               return (new UninitializedVariable(
-                   shape: new Shape(proto.Shape.Dim.Select(x => (int)x.Size).ToArray()),
-                   dtype: (TF_DataType)proto.Dtype,
-                   name: name,
-                   trainable: trainable,
-                   aggregation: aggregation
-               ), setattr);
+               return tf_with(ops.device(saved_device), _ =>
+               {
+                   return (new UninitializedVariable(
+                       shape: new Shape(proto.Shape.Dim.Select(x => (int)x.Size).ToArray()),
+                       dtype: (TF_DataType)proto.Dtype,
+                       name: name,
+                       trainable: trainable,
+                       aggregation: aggregation
+                   ), setattr);
+               });
            }
            else
            {
@@ -266,9 +266,11 @@ namespace Tensorflow
            BaseResourceVariable new_variable;
            if (save_options.experimental_variable_policy.save_variable_devices())
            {
-               tf.device(this.Device);
                Debug.Assert(this is ResourceVariable);
-               new_variable = resource_variable_ops.copy_to_graph_uninitialized((ResourceVariable)this);
+               new_variable = tf_with(ops.device(this.Device), _ =>
+               {
+                   return resource_variable_ops.copy_to_graph_uninitialized((ResourceVariable)this);
+               });
            }
            else
            {
@@ -49,9 +49,12 @@ namespace Tensorflow.Variables
            {
                tf_with(ops.name_scope("Read"), _ =>
                {
-                   tf.device(handle.Device);
-                   var value = gen_resource_variable_ops.read_variable_op(handle, dtype);
-                   // _maybe_set_handle_data(dtype, handle, value)
+                   var value = tf_with(ops.device(handle.Device), _ =>
+                   {
+                       var result = gen_resource_variable_ops.read_variable_op(handle, dtype);
+                       // TODO(Rinne): _maybe_set_handle_data(dtype, handle, value)
+                       return result;
+                   });
                    _graph_element = value;
                });
                ops.add_to_collection(ops.GraphKeys.GLOBAL_VARIABLES_, this);
@@ -577,6 +577,23 @@ namespace Tensorflow
        }

+       public static ITensorFlowObject device(string device_name)
+       {
+           if (tf.Context.executing_eagerly())
+           {
+               return tf.Context.device(device_name);
+           }
+           //else if (ops.executing_eagerly_outside_functions())
+           //{
+           //    throw new NotImplementedException();
+           //}
+           else
+           {
+               return get_default_graph().device(device_name);
+           }
+           // TODO(Rinne): deal with `ops.executing_eagerly_outside_functions()`.
+       }
+
        public class NullContextManager: IDisposable
        {
            public void Dispose()
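
Taken together, `ops.device(...)` is now the single entry point used by the saving and loading code: in eager mode it yields an `EagerDeviceContext`, in graph mode the (currently no-op) `GraphDeviceContext`, and either one is driven by `tf_with`. A final end-to-end sketch of how the checkpoint code above consumes it, illustrative only:

```csharp
using System;
using Tensorflow;
using static Tensorflow.Binding;

class OpsDeviceDispatchSketch
{
    static void Main()
    {
        // Eager mode (the default): ops.device(...) forwards to tf.Context.device(...),
        // so the string constant below is created under a CPU device scope.
        var prefix = tf_with(ops.device("/cpu:0"), _ => constant_op.constant("model"));

        // The same call site works in graph mode, where ops.device(...) currently
        // returns the GraphDeviceContext stub and therefore has no effect yet.
        Console.WriteLine(prefix.Device);
    }
}
```
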