@@ -53,8 +53,11 @@ public static class SaveUtilV1 | |||
var g = to_graph.as_default(); | |||
var (named_saveable_objects, graph_proto, _, registered_savers) = serialize_gathered_objects(graph_view, | |||
object_map, call_with_mapped_captures, saveables_cache); | |||
tf.device("/cpu:0"); | |||
var object_graph_tensor = constant_op.constant(graph_proto.ToByteArray()); | |||
var object_graph_tensor = tf_with(ops.device("/cpu:0"), _ => | |||
{ | |||
// TODO(Rinne): locate the error that causes transferring TF_STRING to this function throws an exception. | |||
return constant_op.constant(graph_proto.ToByteArray()); | |||
}); | |||
named_saveable_objects.Add(new NoRestoreSaveable(object_graph_tensor, Trackable.Constants.OBJECT_GRAPH_PROTO_KEY)); | |||
g.Exit(); | |||
return (named_saveable_objects, registered_savers); | |||
@@ -65,8 +68,10 @@ public static class SaveUtilV1 | |||
{ | |||
var (named_saveable_objects, graph_proto, _, registered_savers) = serialize_gathered_objects(graph_view, | |||
object_map, call_with_mapped_captures, saveables_cache); | |||
tf.device("/cpu:0"); | |||
var object_graph_tensor = constant_op.constant(graph_proto.ToString(), TF_DataType.TF_STRING); | |||
var object_graph_tensor = tf_with(ops.device("/cpu:0"), _ => | |||
{ | |||
return constant_op.constant(graph_proto.ToString()); | |||
}); | |||
named_saveable_objects.Add(new NoRestoreSaveable(object_graph_tensor, Trackable.Constants.OBJECT_GRAPH_PROTO_KEY)); | |||
return (named_saveable_objects, registered_savers); | |||
} | |||
@@ -58,8 +58,10 @@ public class TrackableSaver | |||
if(object_graph_tensor is null) | |||
{ | |||
tf.device("/cpu:0"); | |||
object_graph_tensor = constant_op.constant(graph_proto.ToByteArray()); | |||
tf_with(ops.device("/cpu:0"), _ => | |||
{ | |||
object_graph_tensor = constant_op.constant(graph_proto.ToByteArray()); | |||
}); | |||
} | |||
else | |||
{ | |||
@@ -230,13 +232,15 @@ public class TrackableSaver | |||
Tensor object_graph_string = reader.GetTensor(Trackable.Constants.OBJECT_GRAPH_PROTO_KEY, dtype: TF_DataType.TF_STRING); | |||
Dictionary<Tensor, string> file_prefix_feed_dict; | |||
Tensor file_prefix_tensor; | |||
Tensor file_prefix_tensor = null; | |||
if (graph_building) | |||
{ | |||
if(_file_prefix_placeholder is null) | |||
{ | |||
tf.device("/cpu:0"); | |||
_file_prefix_placeholder = constant_op.constant("model"); | |||
_file_prefix_placeholder = tf_with(ops.device("/cpu:0"), _ => | |||
{ | |||
return constant_op.constant("model"); | |||
}); | |||
} | |||
file_prefix_tensor = _file_prefix_placeholder; | |||
file_prefix_feed_dict = new(); | |||
@@ -244,8 +248,10 @@ public class TrackableSaver | |||
} | |||
else | |||
{ | |||
tf.device("/cpu:0"); | |||
file_prefix_tensor = constant_op.constant(save_path); | |||
file_prefix_tensor = tf_with(ops.device("/cpu:0"), _ => | |||
{ | |||
return constant_op.constant(save_path); | |||
}); | |||
file_prefix_feed_dict = null; | |||
} | |||
TrackableObjectGraph object_graph_proto = new(); | |||
@@ -211,9 +211,11 @@ namespace Tensorflow.Checkpoint | |||
string restore_device = string.IsNullOrEmpty(options.experimental_io_device) ? "cpu:0": options.experimental_io_device!; | |||
// tf python has code `with ops.device(restore_device):` here. | |||
tf.device(restore_device); // may be risky. | |||
var restored_tensors = gen_ops.restore_v2(file_prefix, tensor_names.ToArray(), slice_specs.ToArray(), tensor_dtypes.ToArray()); | |||
Tensor[] restored_tensors = null; | |||
tf_with(ops.device(restore_device), _ => | |||
{ | |||
restored_tensors = gen_ops.restore_v2(file_prefix, tensor_names.ToArray(), slice_specs.ToArray(), tensor_dtypes.ToArray()); | |||
}); | |||
Dictionary<string, IDictionary<string, Tensor>> restored_tensor_dict = new(); | |||
int idx = 0; | |||
@@ -338,11 +340,14 @@ namespace Tensorflow.Checkpoint | |||
options = new CheckpointOptions(); | |||
} | |||
tf.device("CPU"); // may be risky. | |||
var sharded_suffix = array_ops.where(gen_ops.regex_full_match(file_prefix, tf.constant(@"^s3://.*")), | |||
Tensor tmp_checkpoint_prefix = null; | |||
tf_with(ops.device("CPU"), _ => | |||
{ | |||
var sharded_suffix = array_ops.where(gen_ops.regex_full_match(file_prefix, tf.constant(@"^s3://.*")), | |||
constant_op.constant(".part"), constant_op.constant("_temp/part")); | |||
var tmp_checkpoint_prefix = gen_ops.string_join(new Tensor[] { file_prefix, sharded_suffix }); | |||
IDictionary<string, Tensor> registered_paths = _registered_savers.Keys.ToDictionary(x => x, x => registered_saver_filename(file_prefix, x)); | |||
tmp_checkpoint_prefix = gen_ops.string_join(new Tensor[] { file_prefix, sharded_suffix }); | |||
IDictionary<string, Tensor> registered_paths = _registered_savers.Keys.ToDictionary(x => x, x => registered_saver_filename(file_prefix, x)); | |||
}); | |||
Operation save_fn() | |||
{ | |||
@@ -364,16 +369,24 @@ namespace Tensorflow.Checkpoint | |||
var saver = pair.Value; | |||
last_device = device; | |||
// skip the extra process of device name because of lack of API. | |||
tf.device(device); | |||
var shard_prefix = sharded_filename(tmp_checkpoint_prefix, shard, num_shards_tensor); | |||
Tensor shard_prefix = null; | |||
tf_with(ops.device(device), _ => | |||
{ | |||
shard_prefix = sharded_filename(tmp_checkpoint_prefix, shard, num_shards_tensor); | |||
}); | |||
saved_prefixes.Add(shard_prefix); | |||
sharded_saves.Add(saver.save(shard_prefix, options)); | |||
tf_with(ops.device(device), _ => | |||
{ | |||
sharded_saves.Add(saver.save(shard_prefix, options)); | |||
}); | |||
} | |||
using (var controller = ops.control_dependencies(sharded_saves.ToArray())) | |||
{ | |||
string merge_device = string.IsNullOrEmpty(options.experimental_io_device) ? last_device : options.experimental_io_device; | |||
tf.device(merge_device); | |||
return gen_ops.merge_v2_checkpoints(saved_prefixes.ToArray(), tf.constant(file_prefix), delete_old_dirs: true); | |||
return tf_with(ops.device(merge_device), _ => | |||
{ | |||
return gen_ops.merge_v2_checkpoints(saved_prefixes.ToArray(), tf.constant(file_prefix), delete_old_dirs: true); | |||
}); | |||
} | |||
} | |||
@@ -407,54 +420,56 @@ namespace Tensorflow.Checkpoint | |||
{ | |||
var device = single_saver.Key; | |||
var saver = single_saver.Value; | |||
tf.device(device); | |||
var restored_tensor_dict = saver.restore(file_prefix, options); | |||
foreach(var pair in restored_tensor_dict) | |||
tf_with(ops.device(device), _ => | |||
{ | |||
var checkpoint_key = pair.Key; | |||
var slice_and_tensor = pair.Value; | |||
foreach(var item in slice_and_tensor) | |||
var restored_tensor_dict = saver.restore(file_prefix, options); | |||
foreach (var pair in restored_tensor_dict) | |||
{ | |||
var slice_spec = item.Key; | |||
var tensor = item.Value; | |||
var restore_fn = _keys_to_restore_fn[(checkpoint_key, slice_spec)]; | |||
var internal_dict = restore_fn_inputs.SetDefault(restore_fn, new Dictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>()); | |||
if (!string.IsNullOrEmpty(slice_spec)) | |||
var checkpoint_key = pair.Key; | |||
var slice_and_tensor = pair.Value; | |||
foreach (var item in slice_and_tensor) | |||
{ | |||
if (!internal_dict.ContainsKey(checkpoint_key)) | |||
var slice_spec = item.Key; | |||
var tensor = item.Value; | |||
var restore_fn = _keys_to_restore_fn[(checkpoint_key, slice_spec)]; | |||
var internal_dict = restore_fn_inputs.SetDefault(restore_fn, new Dictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>()); | |||
if (!string.IsNullOrEmpty(slice_spec)) | |||
{ | |||
Dictionary<string, Tensor> dict = new(); | |||
dict[slice_spec] = tensor; | |||
internal_dict[checkpoint_key] = new Maybe<Tensor, IDictionary<string, Tensor>>(dict); | |||
if (!internal_dict.ContainsKey(checkpoint_key)) | |||
{ | |||
Dictionary<string, Tensor> dict = new(); | |||
dict[slice_spec] = tensor; | |||
internal_dict[checkpoint_key] = new Maybe<Tensor, IDictionary<string, Tensor>>(dict); | |||
} | |||
else | |||
{ | |||
internal_dict[checkpoint_key].GetValue<IDictionary<string, Tensor>>()[slice_spec] = tensor; | |||
} | |||
} | |||
else | |||
{ | |||
internal_dict[checkpoint_key].GetValue<IDictionary<string, Tensor>>()[slice_spec] = tensor; | |||
internal_dict[checkpoint_key] = new Maybe<Tensor, IDictionary<string, Tensor>>(tensor); | |||
} | |||
} | |||
else | |||
{ | |||
internal_dict[checkpoint_key] = new Maybe<Tensor, IDictionary<string, Tensor>>(tensor); | |||
} | |||
restore_fn_input_count[restore_fn]--; | |||
restore_fn_input_count[restore_fn]--; | |||
if (restore_fn_input_count[restore_fn] == 0) | |||
{ | |||
Dictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> restored_tensors = new(); | |||
foreach(var input in restore_fn_inputs[restore_fn]) | |||
if (restore_fn_input_count[restore_fn] == 0) | |||
{ | |||
restored_tensors[TrackableUtils.extract_local_name(input.Key)] = input.Value; | |||
} | |||
var ret = restore_fn.DynamicInvoke(restored_tensors); | |||
if(ret is IDictionary<string, Operation>) | |||
{ | |||
var dict = (IDictionary<string, Operation>)ret; | |||
restore_ops = restore_ops.Concat(dict).ToDictionary(x => x.Key, x => x.Value); | |||
Dictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> restored_tensors = new(); | |||
foreach (var input in restore_fn_inputs[restore_fn]) | |||
{ | |||
restored_tensors[TrackableUtils.extract_local_name(input.Key)] = input.Value; | |||
} | |||
var ret = restore_fn.DynamicInvoke(restored_tensors); | |||
if (ret is IDictionary<string, Operation>) | |||
{ | |||
var dict = (IDictionary<string, Operation>)ret; | |||
restore_ops = restore_ops.Concat(dict).ToDictionary(x => x.Key, x => x.Value); | |||
} | |||
} | |||
} | |||
} | |||
} | |||
}); | |||
} | |||
foreach(var item in _registered_savers) | |||
@@ -500,21 +515,25 @@ namespace Tensorflow.Checkpoint | |||
private Tensor _traced_save(Tensor file_prefix) | |||
{ | |||
var save_op = save(file_prefix); | |||
tf.device("cpu:0"); | |||
using (ops.control_dependencies(new object[]{ save_op })) | |||
return tf_with(ops.device("cpu:0"), _ => | |||
{ | |||
return array_ops.identity(file_prefix); | |||
} | |||
return tf_with(ops.control_dependencies(new object[] { save_op }), __ => | |||
{ | |||
return array_ops.identity(file_prefix); | |||
}); | |||
}); | |||
} | |||
private Tensor _traced_restore(Tensor file_prefix) | |||
{ | |||
var restore_op = restore(file_prefix); | |||
tf.device("cpu:0"); | |||
using (ops.control_dependencies(restore_op.Values.ToArray())) | |||
return tf_with(ops.device("cpu:0"), _ => | |||
{ | |||
return array_ops.identity(file_prefix); | |||
} | |||
return tf_with(ops.control_dependencies(restore_op.Values.ToArray()), __ => | |||
{ | |||
return array_ops.identity(file_prefix); | |||
}); | |||
}); | |||
} | |||
public static MultiDeviceSaver from_saveables(IEnumerable<MySaveableObject> saveables, IDictionary<string, IDictionary<string, Trackable>>? registered_savers = null, bool call_with_mapped_captures = false) | |||
@@ -21,6 +21,7 @@ using Tensorflow.Eager; | |||
using static Tensorflow.Binding; | |||
using Google.Protobuf; | |||
using Tensorflow.Device; | |||
using Tensorflow.Exceptions; | |||
using System.Collections.Generic; | |||
namespace Tensorflow.Contexts | |||
@@ -30,10 +31,30 @@ namespace Tensorflow.Contexts | |||
/// </summary> | |||
public sealed partial class Context | |||
{ | |||
internal static Dictionary<(string, string), (string, DeviceSpec)> _device_parsing_cache = new(); | |||
internal List<LogicalDevice> _logical_devices = null; | |||
internal List<string> _context_devices = null; | |||
ContextDevicePlacementPolicy _device_policy; | |||
bool _log_device_placement; | |||
int _num_gpus; | |||
Dictionary<PhysicalDevice, bool> _memory_growth_map = new Dictionary<PhysicalDevice, bool>(); | |||
public string DeviceName { get; set; } = ""; | |||
public DeviceSpec DeviceSpec { get; set; } = null; | |||
/// <summary>
/// Names of the devices recorded for this context by <c>_initialize_logical_devices</c>.
/// </summary>
/// <exception cref="AssertionError">
/// Thrown when the device list has not been populated yet.
/// </exception>
internal List<string> Devices
    => _context_devices ?? throw new AssertionError("Context must be initialized first.");
public void log_device_placement(bool enable) | |||
{ | |||
if (_handle != null) | |||
@@ -89,5 +110,57 @@ namespace Tensorflow.Contexts | |||
return results.ToArray(); | |||
} | |||
/// <summary>
/// Creates a device scope bound to this context.
/// </summary>
/// <param name="name">Device name handed to the new <see cref="EagerDeviceContext"/>.</param>
/// <returns>A new, not yet entered, device scope.</returns>
public EagerDeviceContext device(string name) => new EagerDeviceContext(this, name);
/// <summary>
/// Overwrites the context's current device name and device spec in one step.
/// </summary>
/// <param name="device_name">Value stored in <c>DeviceName</c>.</param>
/// <param name="device_spec">Value stored in <c>DeviceSpec</c>.</param>
internal void _set_device(string device_name, DeviceSpec device_spec)
{
    DeviceName = device_name;
    DeviceSpec = device_spec;
}
/// <summary>
/// Lists the devices of the native context and records them:
/// canonical device names go to <c>_context_devices</c>, one
/// <see cref="LogicalDevice"/> per device goes to <c>_logical_devices</c>,
/// and GPUs belonging to the current job/task are counted in <c>_num_gpus</c>.
/// </summary>
/// <remarks>
/// The two lists are assigned in a <c>finally</c> block, so whatever was
/// collected before a failure is still published.
/// </remarks>
internal void _initialize_logical_devices()
{
    List<LogicalDevice> logical_devices = new();
    List<string> context_devices = new();
    Status status = new();
    var device_list = c_api.TFE_ContextListDevices(_handle, status);
    status.Check(true);
    try
    {
        this._num_gpus = 0;
        // No server definition is consulted here, so the "current" job/task are the unset values.
        string current_job = null;
        int current_task = -1;
        // The device list does not change while we iterate; query its size once.
        var device_count = c_api.TF_DeviceListCount(device_list);
        for (int i = 0; i < device_count; i++)
        {
            var dev_name = c_api.TF_DeviceListName(device_list, i, status);
            status.Check(true);
            context_devices.Add(DeviceUtils.canonical_name(dev_name));

            var spec = DeviceSpec.from_string(dev_name);
            if (spec.Job == "localhost")
            {
                // DeviceSpec.replace treats null / -1 as "keep the current value",
                // so it cannot clear job/replica/task. Build a fresh spec that keeps
                // only the device type and index instead.
                spec = new DeviceSpec(device_type: spec.DeviceType, device_index: spec.DeviceIndex);
            }
            logical_devices.Add(new LogicalDevice(spec.ToString(), spec.DeviceType));

            var dev_type_memory = c_api.TF_DeviceListType(device_list, i, status);
            // Validate the call before decoding the returned memory.
            status.Check(true);
            var dev_type = c_api.StringPiece(dev_type_memory);
            if (dev_type == "GPU" && spec.Job == current_job && spec.Task == current_task)
            {
                _num_gpus++;
            }
        }
    }
    finally
    {
        _logical_devices = logical_devices;
        _context_devices = context_devices;
    }
}
} | |||
public record class LogicalDevice(string name, string device_type); | |||
} |
@@ -34,7 +34,6 @@ namespace Tensorflow.Contexts | |||
public const int EAGER_MODE = 1; | |||
int defaultExecutionMode = EAGER_MODE; | |||
public string DeviceName { get; set; } = ""; | |||
public string ScopeName { get; set; } = ""; | |||
bool initialized = false; | |||
ContextSwitchStack context_switches; | |||
@@ -62,6 +61,8 @@ namespace Tensorflow.Contexts | |||
if (initialized) | |||
return; | |||
Debug.Assert(_context_devices is null); | |||
Config = MergeConfig(); | |||
FunctionCallOptions.Config = Config; | |||
var config_str = Config.ToByteArray(); | |||
@@ -72,6 +73,7 @@ namespace Tensorflow.Contexts | |||
c_api.TFE_ContextOptionsSetDevicePlacementPolicy(opts, _device_policy); | |||
_handle = c_api.TFE_NewContext(opts, status); | |||
status.Check(true); | |||
_initialize_logical_devices(); | |||
initialized = true; | |||
} | |||
@@ -174,6 +176,7 @@ namespace Tensorflow.Contexts | |||
{ | |||
c_api.TFE_ContextClearCaches(_handle); | |||
} | |||
_device_parsing_cache.Clear(); | |||
} | |||
public static implicit operator SafeContextHandle(Context ctx) | |||
@@ -0,0 +1,71 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
using Tensorflow.Device; | |||
namespace Tensorflow.Contexts | |||
{ | |||
/// <summary>
/// Device scope for eager execution. Entering the scope switches the owning
/// <see cref="Context"/> to the requested device (merged onto the device that
/// was active before); leaving it restores the previous device.
/// </summary>
public class EagerDeviceContext : ITensorFlowObject
{
    private Context _ctx;
    private string _device_name;
    // One entry per active __enter__: (previous name, previous spec, spec installed on enter).
    private Stack<(string, DeviceSpec, DeviceSpec)> _stack;

    /// <summary>
    /// Creates a scope for <paramref name="device_name"/> on <paramref name="ctx"/>.
    /// Nothing is changed on the context until <see cref="__enter__"/> is called.
    /// </summary>
    /// <param name="ctx">Context whose current device is switched.</param>
    /// <param name="device_name">Requested device; may be <c>null</c>.</param>
    public EagerDeviceContext(Context ctx, string device_name)
    {
        _ctx = ctx;
        _device_name = device_name;
        _stack = new Stack<(string, DeviceSpec, DeviceSpec)>();
    }

    /// <summary>
    /// Resolves the requested device against the context's current device,
    /// installs the result on the context and remembers the previous device.
    /// Resolved (name, spec) pairs are cached in <c>Context._device_parsing_cache</c>,
    /// keyed by (previous device name, requested device name).
    /// </summary>
    public void __enter__()
    {
        var ctx = _ctx;
        var old_device_name = ctx.DeviceName;
        var old_device_spec = ctx.DeviceSpec;
        var new_device_name = _device_name;
        var cache_key = (old_device_name, new_device_name);
        DeviceSpec new_device_spec;
        if (Context._device_parsing_cache.TryGetValue(cache_key, out var cached))
        {
            (new_device_name, new_device_spec) = cached;
        }
        else
        {
            if (new_device_name is not null)
            {
                var device_spec = DeviceSpec.from_string(new_device_name);
                if (!string.IsNullOrEmpty(old_device_name))
                {
                    // Start from a copy of the currently active device.
                    new_device_spec = new DeviceSpec(old_device_spec);
                }
                else
                {
                    // No device active yet: start from the context's first device.
                    ctx.ensure_initialized();
                    new_device_spec = DeviceSpec.from_string(ctx._context_devices[0]);
                }
                // Fields given in the requested spec override the starting spec.
                new_device_spec = new_device_spec.make_merged_spec(device_spec);
            }
            else
            {
                // The device list only exists once the context is initialized;
                // without this call _context_devices may still be null here.
                ctx.ensure_initialized();
                new_device_spec = DeviceSpec.from_string(ctx._context_devices[0]);
            }
            new_device_name = new_device_spec.ToString();
            Context._device_parsing_cache[cache_key] = (new_device_name, new_device_spec);
        }
        ctx._set_device(new_device_name, new_device_spec);
        _stack.Push((old_device_name, old_device_spec, new_device_spec));
    }

    /// <summary>
    /// Restores the device that was active when the matching
    /// <see cref="__enter__"/> ran.
    /// </summary>
    public void __exit__()
    {
        var (old_device_name, old_device_spec, _) = _stack.Pop();
        _ctx._set_device(old_device_name, old_device_spec);
    }

    /// <summary>
    /// No resources are held; the device is restored by <see cref="__exit__"/>.
    /// </summary>
    public void Dispose()
    {
    }
}
} |
@@ -0,0 +1,205 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
using System.Threading.Tasks; | |||
namespace Tensorflow.Device | |||
{ | |||
/// <summary>
/// Describes a device by job, replica, task, device type and device index,
/// and converts between that description and its string form
/// (for example <c>/job:worker/replica:0/task:1/device:GPU:0</c>).
/// A <c>null</c> string or a <c>-1</c> integer marks a component as unspecified.
/// </summary>
public class DeviceSpec
{
    // Memoized results of the two conversions below.
    private static Dictionary<string, Components> _STRING_TO_COMPONENTS_CACHE = new();
    private static Dictionary<Components, string> _COMPONENTS_TO_STRING_CACHE = new();

    private string _job;
    private int _replica;
    private int _task;
    private string _device_type;
    private int _device_index;
    private string _as_string;

    public string Job => _job;
    public int Replica => _replica;
    public int Task => _task;
    public string DeviceType => _device_type;
    public int DeviceIndex => _device_index;

    /// <summary>
    /// Builds a spec from individual components; omitted components are unspecified.
    /// </summary>
    public DeviceSpec(string job = null, int replica = -1, int task = -1,
        string device_type = null, int device_index = -1)
    {
        _job = job;
        _replica = replica;
        _task = task;
        _device_type = device_type;
        _device_index = device_index;
        _as_string = _components_to_string(_job, _replica, _task, _device_type, _device_index);
    }

    /// <summary>
    /// Copy constructor: duplicates every component and the cached string form.
    /// </summary>
    public DeviceSpec(DeviceSpec other)
    {
        _job = other._job;
        _replica = other._replica;
        _task = other._task;
        _device_type = other._device_type;
        _device_index = other._device_index;
        _as_string = other._as_string;
    }

    /// <summary>
    /// Builds a spec from an already assembled component record.
    /// </summary>
    protected DeviceSpec(Components com)
    {
        _job = com.Job;
        _replica = com.Replica;
        _task = com.Task;
        _device_type = com.DeviceType;
        _device_index = com.DeviceIndex;
        _as_string = _components_to_string(_job, _replica, _task, _device_type, _device_index);
    }

    /// <summary>
    /// Returns a new spec in which every component that is passed explicitly
    /// replaces the current one. Passing <c>null</c> / <c>-1</c> keeps the
    /// current value, so this method cannot reset a component to "unspecified".
    /// </summary>
    public DeviceSpec replace(string job = null, int replica = -1, int task = -1,
        string device_type = null, int device_index = -1)
    {
        return new DeviceSpec(
            job ?? _job,
            replica != -1 ? replica : _replica,
            task != -1 ? task : _task,
            device_type ?? _device_type,
            device_index != -1 ? device_index : _device_index);
    }

    /// <summary>
    /// Parses a device string into a spec.
    /// </summary>
    /// <param name="spec">Slash-separated device string, e.g. <c>/cpu:0</c>.</param>
    /// <exception cref="ValueError">
    /// Thrown for an unknown attribute or more than one device type.
    /// </exception>
    public static DeviceSpec from_string(string spec)
    {
        return new DeviceSpec(_string_to_components(spec));
    }

    /// <summary>
    /// Returns a new spec where every component specified in
    /// <paramref name="dev"/> overrides the corresponding component of this spec.
    /// </summary>
    public DeviceSpec make_merged_spec(DeviceSpec dev)
    {
        return new DeviceSpec(_get_combined_properties(dev));
    }

    // Components of `dev` win wherever they are specified; otherwise ours are kept.
    private Components _get_combined_properties(DeviceSpec dev)
    {
        var merged_job = dev.Job ?? _job;
        var merged_replica = dev.Replica != -1 ? dev.Replica : _replica;
        var merged_task = dev.Task != -1 ? dev.Task : _task;
        var merged_type = dev.DeviceType ?? _device_type;
        var merged_index = dev.DeviceIndex != -1 ? dev.DeviceIndex : _device_index;
        return new Components(merged_job, merged_replica, merged_task, merged_type, merged_index);
    }

    // Renders components as "/job:..../replica:N/task:N/device:TYPE:INDEX",
    // leaving out unspecified parts; an unspecified index is rendered as "*".
    private static string _components_to_string(string job, int replica, int task, string device_type, int device_index)
    {
        var cache_key = new Components(job, replica, task, device_type, device_index);
        if (_COMPONENTS_TO_STRING_CACHE.TryGetValue(cache_key, out var known))
        {
            return known;
        }

        var text = new StringBuilder();
        if (job is not null)
        {
            text.Append($"/job:{job}");
        }
        if (replica != -1)
        {
            text.Append($"/replica:{replica}");
        }
        if (task != -1)
        {
            text.Append($"/task:{task}");
        }
        if (device_type is not null)
        {
            var index_text = device_index != -1 ? device_index.ToString() : "*";
            text.Append($"/device:{device_type}:{index_text}");
        }

        var rendered = text.ToString();
        _COMPONENTS_TO_STRING_CACHE[cache_key] = rendered;
        return rendered;
    }

    // Splits the spec on '/' and each piece on ':' and recognizes
    // "job:x", "replica:n", "task:n", "TYPE[:n|*]" and "device:TYPE:n|*".
    private static Components _string_to_components(string spec)
    {
        if (_STRING_TO_COMPONENTS_CACHE.TryGetValue(spec, out var known))
        {
            return known;
        }

        var valid_device_types = _get_valid_device_types();
        string job = null;
        string device_type = null;
        int replica = -1;
        int task = -1;
        int device_index = -1;

        foreach (var segment in spec.Split('/'))
        {
            // Split always yields at least one field, so fields[0] is safe.
            var fields = segment.Split(':');
            var count = fields.Length;
            var head = fields[0];

            if (count == 2 && head == "job")
            {
                job = fields[1];
            }
            else if (count == 2 && head == "replica")
            {
                replica = int.Parse(fields[1]);
            }
            else if (count == 2 && head == "task")
            {
                task = int.Parse(fields[1]);
            }
            else if ((count == 1 || count == 2) && valid_device_types.Contains(head.ToUpper()))
            {
                // Short form: "cpu", "gpu:0", "GPU:*".
                _ensure_no_device_type_yet(device_type, spec);
                device_type = head.ToUpper();
                if (count == 2 && fields[1] != "*")
                {
                    device_index = int.Parse(fields[1]);
                }
            }
            else if (count == 3 && head == "device")
            {
                // Long form: "device:GPU:0"; the type is taken verbatim.
                _ensure_no_device_type_yet(device_type, spec);
                device_type = fields[1];
                if (fields[2] != "*")
                {
                    device_index = int.Parse(fields[2]);
                }
            }
            else if (head != "")
            {
                throw new ValueError($"Unknown attribute '{head}' is encountered " +
                    $"while parsing the device spec: {spec}.");
            }
        }

        var parsed = new Components(job, replica, task, device_type, device_index);
        _STRING_TO_COMPONENTS_CACHE[spec] = parsed;
        return parsed;
    }

    // Rejects a spec that names a device type twice.
    private static void _ensure_no_device_type_yet(string device_type, string spec)
    {
        if (device_type is not null)
        {
            throw new ValueError($"Multiple device types are not allowed " +
                $"while parsing the device spec: {spec}.");
        }
    }

    private static HashSet<string> _get_valid_device_types()
    {
        // TODO(Rinne): revise it to calling C API (need customized API).
        return new HashSet<string>(new string[] { "CPU", "GPU" });
    }

    /// <summary>
    /// Returns the string form computed when the spec was constructed.
    /// </summary>
    public override string ToString()
    {
        return _as_string;
    }

    protected record class Components(string Job, int Replica, int Task, string DeviceType, int DeviceIndex);
}
} |
@@ -0,0 +1,26 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow.Device | |||
{ | |||
/// <summary>
/// Helpers for turning device descriptions into their canonical string form.
/// </summary>
internal static class DeviceUtils
{
    /// <summary>
    /// Parses <paramref name="device"/> as a device spec and returns its
    /// string form; a <c>null</c> input yields the empty string.
    /// </summary>
    public static string canonical_name(string device)
        => device is null ? "" : DeviceSpec.from_string(device).ToString();

    /// <summary>
    /// Returns the string form of <paramref name="device"/>; a <c>null</c>
    /// input yields the empty string.
    /// </summary>
    public static string canonical_name(DeviceSpec device)
        => device is null ? "" : device.ToString();
}
} |
@@ -19,6 +19,7 @@ using System.Collections; | |||
using System.Collections.Generic; | |||
using System.Collections.Specialized; | |||
using System.Linq; | |||
using Tensorflow.Graphs; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow | |||
@@ -294,9 +295,15 @@ namespace Tensorflow | |||
return op; | |||
} | |||
public void device(string device_name) | |||
public ITensorFlowObject device(string device_name) | |||
{ | |||
return new GraphDeviceContext(this, device_name); | |||
} | |||
private void add_device_to_stack(string device_name, int offset = 0) | |||
{ | |||
// TODO(Rinne): deal with device spec. | |||
int total_offset = offset + 1; | |||
} | |||
private void _create_op_helper(Operation op, bool compute_device = true) | |||
@@ -0,0 +1,31 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow.Graphs | |||
{ | |||
/// <summary>
/// Device scope object returned for graph mode. It currently only remembers
/// the graph it was created for: the device name is not stored, and entering
/// or leaving the scope does nothing.
/// </summary>
public class GraphDeviceContext : ITensorFlowObject
{
    private Graph _graph;

    /// <summary>
    /// Creates the scope for <paramref name="graph"/>.
    /// </summary>
    /// <param name="graph">Graph this scope belongs to.</param>
    /// <param name="device_name">Requested device; currently unused.</param>
    public GraphDeviceContext(Graph graph, string device_name) => _graph = graph;

    /// <summary>Placeholder: no device is applied on enter.</summary>
    public void __enter__() { }

    /// <summary>Placeholder: nothing to restore on exit.</summary>
    public void __exit__() { }

    /// <summary>No resources are held.</summary>
    public void Dispose() { }
}
} |
@@ -42,16 +42,20 @@ namespace Tensorflow | |||
_var_device = var.Device; | |||
_var_shape = var.shape; | |||
Tensor _read_variable_closure(BaseResourceVariable v) | |||
Tensor? _read_variable_closure(BaseResourceVariable v) | |||
{ | |||
tf.device(v.Device); | |||
if(tf.Context.executing_eagerly() && !((bool)v.is_initialized().numpy())) | |||
return tf_with(ops.device(v.Device), _ => | |||
{ | |||
return null; | |||
} | |||
var x = v.read_value_no_copy(); | |||
tf.device("/device:CPU:0"); | |||
return array_ops.identity(x); | |||
if (tf.Context.executing_eagerly() && !((bool)v.is_initialized().numpy())) | |||
{ | |||
return null; | |||
} | |||
var x = v.read_value_no_copy(); | |||
return tf_with(ops.device("/device:CPU:0"), __ => | |||
{ | |||
return array_ops.identity(x); | |||
}); | |||
}); | |||
} | |||
this.handle_op = var.Handle; | |||
@@ -412,8 +412,10 @@ namespace Tensorflow | |||
{ | |||
var variables_path = SavedModelUtils.get_variables_path(_export_dir); | |||
var saver = new TrackableSaver(new ObjectGraphView(get(0))); | |||
tf.device("CPU"); | |||
saver.FilePrefixPlaceHolder = constant_op.constant(variables_path); | |||
tf_with(ops.device("CPU"), _ => | |||
{ | |||
saver.FilePrefixPlaceHolder = constant_op.constant(variables_path); | |||
}); | |||
LoadStatus load_status; | |||
if (_save_options.allow_partial_checkpoint) | |||
{ | |||
@@ -598,14 +600,16 @@ namespace Tensorflow | |||
if (load_with_device) | |||
{ | |||
tf.device(saved_device); | |||
return (new UninitializedVariable( | |||
shape: new Shape(proto.Shape.Dim.Select(x => (int)x.Size).ToArray()), | |||
dtype: (TF_DataType)proto.Dtype, | |||
name: name, | |||
trainable: trainable, | |||
aggregation: aggregation | |||
), setattr); | |||
return tf_with(ops.device(saved_device), _ => | |||
{ | |||
return (new UninitializedVariable( | |||
shape: new Shape(proto.Shape.Dim.Select(x => (int)x.Size).ToArray()), | |||
dtype: (TF_DataType)proto.Dtype, | |||
name: name, | |||
trainable: trainable, | |||
aggregation: aggregation | |||
), setattr); | |||
}); | |||
} | |||
else | |||
{ | |||
@@ -266,9 +266,11 @@ namespace Tensorflow | |||
BaseResourceVariable new_variable; | |||
if (save_options.experimental_variable_policy.save_variable_devices()) | |||
{ | |||
tf.device(this.Device); | |||
Debug.Assert(this is ResourceVariable); | |||
new_variable = resource_variable_ops.copy_to_graph_uninitialized((ResourceVariable)this); | |||
new_variable = tf_with(ops.device(this.Device), _ => | |||
{ | |||
return resource_variable_ops.copy_to_graph_uninitialized((ResourceVariable)this); | |||
}); | |||
} | |||
else | |||
{ | |||
@@ -49,9 +49,12 @@ namespace Tensorflow.Variables | |||
{ | |||
tf_with(ops.name_scope("Read"), _ => | |||
{ | |||
tf.device(handle.Device); | |||
var value = gen_resource_variable_ops.read_variable_op(handle, dtype); | |||
// _maybe_set_handle_data(dtype, handle, value) | |||
var value = tf_with(ops.device(handle.Device), _ => | |||
{ | |||
var result = gen_resource_variable_ops.read_variable_op(handle, dtype); | |||
// TODO(Rinne): _maybe_set_handle_data(dtype, handle, value) | |||
return result; | |||
}); | |||
_graph_element = value; | |||
}); | |||
ops.add_to_collection(ops.GraphKeys.GLOBAL_VARIABLES_, this); | |||
@@ -577,6 +577,23 @@ namespace Tensorflow | |||
} | |||
/// <summary>
/// Returns a device scope for <paramref name="device_name"/>: the eager
/// context's scope when executing eagerly, otherwise the default graph's scope.
/// </summary>
/// <param name="device_name">Name of the device to scope to.</param>
/// <returns>A scope object intended to be used with <c>tf_with</c>.</returns>
public static ITensorFlowObject device(string device_name)
{
    // TODO(Rinne): deal with `ops.executing_eagerly_outside_functions()`.
    // That case is not handled yet; graph mode falls through to the default graph.
    if (!tf.Context.executing_eagerly())
    {
        return get_default_graph().device(device_name);
    }
    return tf.Context.device(device_name);
}
public class NullContextManager: IDisposable | |||
{ | |||
public void Dispose() | |||