@@ -23,6 +23,10 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Keras.UnitTest", | |||||
EndProject | EndProject | ||||
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Graph.UnitTest", "test\TensorFlowNET.Graph.UnitTest\TensorFlowNET.Graph.UnitTest.csproj", "{3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}" | Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Graph.UnitTest", "test\TensorFlowNET.Graph.UnitTest\TensorFlowNET.Graph.UnitTest.csproj", "{3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}" | ||||
EndProject | EndProject | ||||
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Hub", "src\TensorflowNET.Hub\Tensorflow.Hub.csproj", "{9738D16A-CFA0-405C-A7DF-D3D203B0CB18}" | |||||
EndProject | |||||
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Hub.Unittest", "test\TensorflowNET.Hub.Unittest\Tensorflow.Hub.Unittest.csproj", "{7DEA8760-E401-4872-81F3-405F185A13A0}" | |||||
EndProject | |||||
Global | Global | ||||
GlobalSection(SolutionConfigurationPlatforms) = preSolution | GlobalSection(SolutionConfigurationPlatforms) = preSolution | ||||
Debug|Any CPU = Debug|Any CPU | Debug|Any CPU = Debug|Any CPU | ||||
@@ -153,6 +157,30 @@ Global | |||||
{3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.Release|x64.Build.0 = Release|x64 | {3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.Release|x64.Build.0 = Release|x64 | ||||
{3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.Release|x86.ActiveCfg = Release|Any CPU | {3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.Release|x86.ActiveCfg = Release|Any CPU | ||||
{3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.Release|x86.Build.0 = Release|Any CPU | {3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.Release|x86.Build.0 = Release|Any CPU | ||||
{9738D16A-CFA0-405C-A7DF-D3D203B0CB18}.Debug|Any CPU.ActiveCfg = Debug|Any CPU | |||||
{9738D16A-CFA0-405C-A7DF-D3D203B0CB18}.Debug|Any CPU.Build.0 = Debug|Any CPU | |||||
{9738D16A-CFA0-405C-A7DF-D3D203B0CB18}.Debug|x64.ActiveCfg = Debug|Any CPU | |||||
{9738D16A-CFA0-405C-A7DF-D3D203B0CB18}.Debug|x64.Build.0 = Debug|Any CPU | |||||
{9738D16A-CFA0-405C-A7DF-D3D203B0CB18}.Debug|x86.ActiveCfg = Debug|Any CPU | |||||
{9738D16A-CFA0-405C-A7DF-D3D203B0CB18}.Debug|x86.Build.0 = Debug|Any CPU | |||||
{9738D16A-CFA0-405C-A7DF-D3D203B0CB18}.Release|Any CPU.ActiveCfg = Release|Any CPU | |||||
{9738D16A-CFA0-405C-A7DF-D3D203B0CB18}.Release|Any CPU.Build.0 = Release|Any CPU | |||||
{9738D16A-CFA0-405C-A7DF-D3D203B0CB18}.Release|x64.ActiveCfg = Release|Any CPU | |||||
{9738D16A-CFA0-405C-A7DF-D3D203B0CB18}.Release|x64.Build.0 = Release|Any CPU | |||||
{9738D16A-CFA0-405C-A7DF-D3D203B0CB18}.Release|x86.ActiveCfg = Release|Any CPU | |||||
{9738D16A-CFA0-405C-A7DF-D3D203B0CB18}.Release|x86.Build.0 = Release|Any CPU | |||||
{7DEA8760-E401-4872-81F3-405F185A13A0}.Debug|Any CPU.ActiveCfg = Debug|Any CPU | |||||
{7DEA8760-E401-4872-81F3-405F185A13A0}.Debug|Any CPU.Build.0 = Debug|Any CPU | |||||
{7DEA8760-E401-4872-81F3-405F185A13A0}.Debug|x64.ActiveCfg = Debug|Any CPU | |||||
{7DEA8760-E401-4872-81F3-405F185A13A0}.Debug|x64.Build.0 = Debug|Any CPU | |||||
{7DEA8760-E401-4872-81F3-405F185A13A0}.Debug|x86.ActiveCfg = Debug|Any CPU | |||||
{7DEA8760-E401-4872-81F3-405F185A13A0}.Debug|x86.Build.0 = Debug|Any CPU | |||||
{7DEA8760-E401-4872-81F3-405F185A13A0}.Release|Any CPU.ActiveCfg = Release|Any CPU | |||||
{7DEA8760-E401-4872-81F3-405F185A13A0}.Release|Any CPU.Build.0 = Release|Any CPU | |||||
{7DEA8760-E401-4872-81F3-405F185A13A0}.Release|x64.ActiveCfg = Release|Any CPU | |||||
{7DEA8760-E401-4872-81F3-405F185A13A0}.Release|x64.Build.0 = Release|Any CPU | |||||
{7DEA8760-E401-4872-81F3-405F185A13A0}.Release|x86.ActiveCfg = Release|Any CPU | |||||
{7DEA8760-E401-4872-81F3-405F185A13A0}.Release|x86.Build.0 = Release|Any CPU | |||||
EndGlobalSection | EndGlobalSection | ||||
GlobalSection(SolutionProperties) = preSolution | GlobalSection(SolutionProperties) = preSolution | ||||
HideSolutionNode = FALSE | HideSolutionNode = FALSE | ||||
@@ -207,9 +207,24 @@ namespace Tensorflow | |||||
} | } | ||||
/// <summary>
/// Renders the single wrapped tensor directly, or a multi-line listing of all
/// contained tensors with their names.
/// </summary>
public override string ToString()
{
    if (items.Count == 1)
        return items[0].ToString();

    var sb = new StringBuilder();
    sb.Append($"Totally {items.Count} tensors, which are {string.Join(", ", items.Select(x => x.name))}\n[\n");
    int index = 0;
    foreach (var tensor in items)
    {
        sb.Append($"Tensor {index}({tensor.name}): {tensor}\n");
        index++;
    }
    sb.Append("]\n");
    return sb.ToString();
}
public void Dispose() | public void Dispose() | ||||
{ | { | ||||
@@ -301,6 +301,17 @@ namespace Tensorflow | |||||
type == TF_DataType.DtInt32Ref || type == TF_DataType.DtInt64Ref; | type == TF_DataType.DtInt32Ref || type == TF_DataType.DtInt64Ref; | ||||
} | } | ||||
/// <summary>
/// Whether <paramref name="type"/> is one of the unsigned integer dtypes.
/// </summary>
public static bool is_unsigned(this TF_DataType type)
    => type is TF_DataType.TF_UINT8 or TF_DataType.TF_UINT16
            or TF_DataType.TF_UINT32 or TF_DataType.TF_UINT64;
/// <summary>
/// Whether <paramref name="type"/> is the boolean dtype.
/// </summary>
public static bool is_bool(this TF_DataType type)
    => type == TF_DataType.TF_BOOL;
public static bool is_floating(this TF_DataType type) | public static bool is_floating(this TF_DataType type) | ||||
{ | { | ||||
return type == TF_DataType.TF_HALF || type == TF_DataType.TF_FLOAT || type == TF_DataType.TF_DOUBLE; | return type == TF_DataType.TF_HALF || type == TF_DataType.TF_FLOAT || type == TF_DataType.TF_DOUBLE; | ||||
@@ -22,9 +22,9 @@ namespace Tensorflow.Keras.Engine | |||||
// If dtype is DT_FLOAT, provide a uniform unit scaling initializer | // If dtype is DT_FLOAT, provide a uniform unit scaling initializer | ||||
if (dtype.is_floating()) | if (dtype.is_floating()) | ||||
initializer = tf.glorot_uniform_initializer; | initializer = tf.glorot_uniform_initializer; | ||||
else if (dtype.is_integer()) | |||||
else if (dtype.is_integer() || dtype.is_unsigned() || dtype.is_bool()) | |||||
initializer = tf.zeros_initializer; | initializer = tf.zeros_initializer; | ||||
else | |||||
else if(getter is null) | |||||
throw new ValueError($"An initializer for variable {name} of type {dtype.as_base_dtype()} is required for layer {name}"); | throw new ValueError($"An initializer for variable {name} of type {dtype.as_base_dtype()} is required for layer {name}"); | ||||
} | } | ||||
@@ -36,5 +36,9 @@ namespace Tensorflow.Keras.Saving | |||||
public bool? Stateful { get; set; } | public bool? Stateful { get; set; } | ||||
[JsonProperty("model_config")] | [JsonProperty("model_config")] | ||||
public KerasModelConfig? ModelConfig { get; set; } | public KerasModelConfig? ModelConfig { get; set; } | ||||
[JsonProperty("sparse")] | |||||
public bool Sparse { get; set; } | |||||
[JsonProperty("ragged")] | |||||
public bool Ragged { get; set; } | |||||
} | } | ||||
} | } |
@@ -401,13 +401,22 @@ namespace Tensorflow.Keras.Saving | |||||
/// <summary>
/// Revives a keras object whose original class is unavailable, dispatching on
/// the SavedModel identifier to the matching Revived* implementation.
/// </summary>
private (Trackable, Action<object, object, object>) revive_custom_object(string identifier, KerasMetaData metadata)
{
    if (identifier == SavedModel.Constants.LAYER_IDENTIFIER)
    {
        return RevivedLayer.init_from_metadata(metadata);
    }
    if (identifier == SavedModel.Constants.MODEL_IDENTIFIER
        || identifier == SavedModel.Constants.SEQUENTIAL_IDENTIFIER
        || identifier == SavedModel.Constants.NETWORK_IDENTIFIER)
    {
        return RevivedNetwork.init_from_metadata(metadata);
    }
    if (identifier == SavedModel.Constants.INPUT_LAYER_IDENTIFIER)
    {
        return RevivedInputLayer.init_from_metadata(metadata);
    }
    throw new ValueError($"Cannot revive the layer {identifier}.");
}
@@ -1,15 +1,46 @@ | |||||
using System; | using System; | ||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Text; | using System.Text; | ||||
using Tensorflow.Keras.ArgsDefinition; | |||||
using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
using Tensorflow.Keras.Layers; | |||||
namespace Tensorflow.Keras.Saving.SavedModel | namespace Tensorflow.Keras.Saving.SavedModel | ||||
{ | { | ||||
/// <summary>
/// An input layer revived from a SavedModel when the original class is not
/// registered/available at load time.
/// </summary>
public class RevivedInputLayer : InputLayer
{
    // Config captured from the SavedModel metadata; returned verbatim by get_config.
    protected RevivedConfig _config = null;

    private RevivedInputLayer(InputLayerArgs args) : base(args)
    {
    }

    public override IKerasConfig get_config()
    {
        return _config;
    }

    /// <summary>
    /// Builds a revived input layer from SavedModel keras metadata and returns it
    /// together with the attribute setter used during restoration.
    /// </summary>
    public static (RevivedInputLayer, Action<object, object, object>) init_from_metadata(KerasMetaData metadata)
    {
        var args = new InputLayerArgs()
        {
            Name = metadata.Name,
            DType = metadata.DType,
            Sparse = metadata.Sparse,
            Ragged = metadata.Ragged,
            BatchInputShape = metadata.BatchInputShape
        };
        var revived = new RevivedInputLayer(args);
        revived._config = new RevivedConfig() { Config = metadata.Config };
        return (revived, Loader.setattr);
    }

    public override string ToString() => $"Customized keras input layer: {Name}.";
}
} | } |
@@ -53,7 +53,7 @@ namespace Tensorflow.Keras.Saving.SavedModel | |||||
return (revived_obj, ReviveUtils._revive_setter); | return (revived_obj, ReviveUtils._revive_setter); | ||||
} | } | ||||
private RevivedConfig _config = null; | |||||
protected RevivedConfig _config = null; | |||||
public object keras_api | public object keras_api | ||||
{ | { | ||||
@@ -70,7 +70,7 @@ namespace Tensorflow.Keras.Saving.SavedModel | |||||
} | } | ||||
} | } | ||||
public RevivedLayer(LayerArgs args): base(args) | |||||
protected RevivedLayer(LayerArgs args): base(args) | |||||
{ | { | ||||
} | } | ||||
@@ -84,17 +84,5 @@ namespace Tensorflow.Keras.Saving.SavedModel | |||||
{ | { | ||||
return _config; | return _config; | ||||
} | } | ||||
//protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | |||||
//{ | |||||
// if(SerializedAttributes is null || !SerializedAttributes.TryGetValue("__call__", out var func) || func is not Function) | |||||
// { | |||||
// return base.Call(inputs, state, training); | |||||
// } | |||||
// else | |||||
// { | |||||
// return (func as Function).Apply(inputs); | |||||
// } | |||||
//} | |||||
} | } | ||||
} | } |
@@ -0,0 +1,40 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
using Tensorflow.Keras.ArgsDefinition; | |||||
using Tensorflow.Keras.Utils; | |||||
namespace Tensorflow.Keras.Saving.SavedModel | |||||
{ | |||||
/// <summary>
/// A network/model revived from a SavedModel when the original class is not
/// registered/available at load time.
/// </summary>
public class RevivedNetwork : RevivedLayer
{
    private RevivedNetwork(LayerArgs args) : base(args)
    {
    }

    /// <summary>
    /// Builds a revived network from SavedModel keras metadata and returns it
    /// together with the attribute setter used during restoration.
    /// </summary>
    public static (RevivedNetwork, Action<object, object, object>) init_from_metadata(KerasMetaData metadata)
    {
        var network = new RevivedNetwork(new LayerArgs() { Name = metadata.Name });
        // TODO(Rinne): with utils.no_automatic_dependency_tracking_scope(revived_obj)
        // TODO(Rinne): revived_obj._expects_training_arg
        var config = metadata.Config;
        if (generic_utils.validate_config(config))
        {
            network._config = new RevivedConfig() { Config = config };
        }
        if (metadata.ActivityRegularizer is not null)
        {
            // Activity regularizers are not restorable yet.
            throw new NotImplementedException();
        }
        return (network, ReviveUtils._revive_setter);
    }

    public override string ToString() => $"Customized keras Network: {Name}.";
}
} |
@@ -0,0 +1,57 @@ | |||||
using System.IO; | |||||
using System.Threading.Tasks; | |||||
namespace Tensorflow.Hub | |||||
{ | |||||
/// <summary>
/// Resolves "gs://" handles pointing at compressed (tar) modules by downloading
/// and extracting them into the local tfhub cache directory.
/// </summary>
public class GcsCompressedFileResolver : IResolver
{
    // Max time to wait on another process's download lock before taking over.
    const int LOCK_FILE_TIMEOUT_SEC = 10 * 60;

    /// <summary>
    /// Downloads (if not already cached) and returns the local module directory
    /// for <paramref name="handle"/>.
    /// </summary>
    public string Call(string handle)
    {
        var module_dir = _module_dir(handle);
        // Sync-over-async is deliberate: IResolver.Call is a synchronous API.
        return resolver.atomic_download_async(handle, download, module_dir, LOCK_FILE_TIMEOUT_SEC)
            .GetAwaiter().GetResult();
    }

    public bool IsSupported(string handle)
    {
        return handle.StartsWith("gs://") && _is_tarfile(handle);
    }

    private async Task download(string handle, string tmp_dir)
    {
        // NOTE(review): FileStream cannot open a "gs://" URI directly; this
        // presumably relies on a mounted/aliased GCS path — confirm with callers.
        // The stream was previously never disposed; 'using' fixes the leak.
        using (var stream = new FileStream(handle, FileMode.Open, FileAccess.Read))
        {
            new resolver.DownloadManager(handle).download_and_uncompress(stream, tmp_dir);
        }
        // Replaces the old no-op 'await Task.Run(() => { })'.
        await Task.CompletedTask;
    }

    // Maps a handle to its deterministic cache sub-directory.
    private static string _module_dir(string handle)
    {
        var cache_dir = resolver.tfhub_cache_dir(use_temp: true);
        var sha1 = ComputeSha1(handle);
        return resolver.create_local_module_dir(cache_dir, sha1);
    }

    // Matches the archive extensions DownloadManager knows how to uncompress.
    private static bool _is_tarfile(string filename)
    {
        return filename.EndsWith(".tar") || filename.EndsWith(".tar.gz") || filename.EndsWith(".tgz");
    }

    // SHA1 here is only a cache key, not a security boundary.
    private static string ComputeSha1(string s)
    {
        // SHA1Managed is obsolete; SHA1.Create() is the supported factory.
        using (var sha = System.Security.Cryptography.SHA1.Create())
        {
            var bytes = System.Text.Encoding.UTF8.GetBytes(s);
            var hash = sha.ComputeHash(bytes);
            var stringBuilder = new System.Text.StringBuilder(hash.Length * 2);
            foreach (var b in hash)
            {
                stringBuilder.Append(b.ToString("x2"));
            }
            return stringBuilder.ToString();
        }
    }
}
} |
@@ -0,0 +1,78 @@ | |||||
using System; | |||||
using System.Net.Http; | |||||
using System.Threading.Tasks; | |||||
namespace Tensorflow.Hub | |||||
{ | |||||
/// <summary>
/// Resolves http(s) handles pointing at compressed (tar) modules by downloading
/// and extracting them into the local tfhub cache directory.
/// </summary>
public class HttpCompressedFileResolver : HttpResolverBase
{
    const int LOCK_FILE_TIMEOUT_SEC = 10 * 60; // 10 minutes

    // One shared client: allocating HttpClient per request exhausts sockets.
    private static readonly HttpClient _httpClient = new HttpClient();

    // Query parameter asking tfhub.dev for the compressed serving format.
    private static readonly (string, string) _COMPRESSED_FORMAT_QUERY =
        ("tf-hub-format", "compressed");

    // Maps a handle to its deterministic cache sub-directory.
    private static string _module_dir(string handle)
    {
        var cache_dir = resolver.tfhub_cache_dir(use_temp: true);
        var sha1 = ComputeSha1(handle);
        return resolver.create_local_module_dir(cache_dir, sha1);
    }

    /// <summary>
    /// Supported for http(s) handles when the load format is COMPRESSED or AUTO.
    /// </summary>
    public override bool IsSupported(string handle)
    {
        if (!is_http_protocol(handle))
        {
            return false;
        }
        var load_format = resolver.model_load_format();
        return load_format == Enum.GetName(typeof(resolver.ModelLoadFormat), resolver.ModelLoadFormat.COMPRESSED)
            || load_format == Enum.GetName(typeof(resolver.ModelLoadFormat), resolver.ModelLoadFormat.AUTO);
    }

    /// <summary>
    /// Downloads (if not already cached) and returns the local module directory
    /// for <paramref name="handle"/>.
    /// </summary>
    public override string Call(string handle)
    {
        var module_dir = _module_dir(handle);
        // Sync-over-async is deliberate: Call is a synchronous API.
        return resolver.atomic_download_async(
            handle,
            download,
            module_dir,
            LOCK_FILE_TIMEOUT_SEC
        ).GetAwaiter().GetResult();
    }

    private async Task download(string handle, string tmp_dir)
    {
        var response = await _httpClient.GetAsync(_append_compressed_format_query(handle));
        // Fail early on an HTTP error rather than trying to un-tar an error page.
        response.EnsureSuccessStatusCode();
        using (var httpStream = await response.Content.ReadAsStreamAsync())
        {
            new resolver.DownloadManager(handle).download_and_uncompress(httpStream, tmp_dir);
        }
    }

    private string _append_compressed_format_query(string handle)
    {
        return append_format_query(handle, _COMPRESSED_FORMAT_QUERY);
    }

    // SHA1 here is only a cache key, not a security boundary.
    private static string ComputeSha1(string s)
    {
        // SHA1Managed is obsolete; SHA1.Create() is the supported factory.
        using (var sha = System.Security.Cryptography.SHA1.Create())
        {
            var bytes = System.Text.Encoding.UTF8.GetBytes(s);
            var hash = sha.ComputeHash(bytes);
            var stringBuilder = new System.Text.StringBuilder(hash.Length * 2);
            foreach (var b in hash)
            {
                stringBuilder.Append(b.ToString("x2"));
            }
            return stringBuilder.ToString();
        }
    }
}
} |
@@ -0,0 +1,65 @@ | |||||
using System; | |||||
using System.Net; | |||||
namespace Tensorflow.Hub | |||||
{ | |||||
/// <summary>
/// Resolves http(s) handles in UNCOMPRESSED mode: asks the server for the GCS
/// location of the uncompressed model (via a 303 redirect) and delegates to the
/// path resolver.
/// </summary>
public class HttpUncompressedFileResolver : HttpResolverBase
{
    private readonly PathResolver _pathResolver;

    public HttpUncompressedFileResolver()
    {
        _pathResolver = new PathResolver();
    }

    public override string Call(string handle)
    {
        handle = AppendUncompressedFormatQuery(handle);
        var gsLocation = RequestGcsLocation(handle);
        return _pathResolver.Call(gsLocation);
    }

    /// <summary>
    /// Supported only for http(s) handles when the load format is UNCOMPRESSED.
    /// </summary>
    public override bool IsSupported(string handle)
    {
        if (!is_http_protocol(handle))
        {
            return false;
        }
        var load_format = resolver.model_load_format();
        return load_format == Enum.GetName(typeof(resolver.ModelLoadFormat), resolver.ModelLoadFormat.UNCOMPRESSED);
    }

    protected virtual string AppendUncompressedFormatQuery(string handle)
    {
        return append_format_query(handle, ("tf-hub-format", "uncompressed"));
    }

    /// <summary>
    /// Performs the lookup request and returns the gs:// path from the 303
    /// response's Location header.
    /// </summary>
    protected virtual string RequestGcsLocation(string handleWithParams)
    {
        var request = (HttpWebRequest)WebRequest.Create(handleWithParams);
        // Bug fix: HttpWebRequest follows redirects by default, so the expected
        // 303 would be transparently consumed and never observed below.
        request.AllowAutoRedirect = false;

        using (var response = request.GetResponse() as HttpWebResponse)
        {
            if (response == null)
            {
                throw new Exception("Failed to get a response from the server.");
            }

            var statusCode = (int)response.StatusCode;
            if (statusCode != 303)
            {
                throw new Exception($"Expected 303 for GCS location lookup but got HTTP {statusCode} {response.StatusDescription}");
            }

            var location = response.Headers["Location"];
            if (location == null || !location.StartsWith("gs://"))
            {
                throw new Exception($"Expected Location:GS path but received {location}");
            }
            return location;
        }
    }
}
} |
@@ -0,0 +1,157 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Linq; | |||||
using Tensorflow.Keras.Engine; | |||||
using Tensorflow.Train; | |||||
using Tensorflow.Training; | |||||
using Tensorflow.Training.Saving.SavedModel; | |||||
using static Tensorflow.Binding; | |||||
namespace Tensorflow.Hub | |||||
{ | |||||
/// <summary>
/// A keras layer wrapping a SavedModel loaded from TF-Hub, mirroring python's
/// hub.KerasLayer: the loaded module's variables become this layer's weights
/// and its callable becomes this layer's forward pass.
/// </summary>
public class KerasLayer : Layer
{
    private string _handle;
    private LoadOptions? _load_options;
    // The loaded module (a Layer or a generic restored Trackable).
    private Trackable _func;
    // Forward function extracted from the module.
    private Func<Tensors, Tensors> _callable;

    public KerasLayer(string handle, bool trainable = false, LoadOptions? load_options = null) :
        base(new Keras.ArgsDefinition.LayerArgs() { Trainable = trainable })
    {
        _handle = handle;
        _load_options = load_options;
        _func = load_module(_handle, _load_options);
        _track_trackable(_func, "_func");
        // TODO(Rinne): deal with _is_hub_module_v1.
        _callable = _get_callable();
        _setup_layer(trainable);
    }

    /// <summary>
    /// Adopts the module's variables as weights of this layer: first the
    /// trainable ones, then every remaining variable as non-trainable.
    /// </summary>
    private void _setup_layer(bool trainable = false)
    {
        // Ids of variables adopted as trainable, so the second pass can skip them.
        HashSet<string> trainableIds;
        if (_func is Layer trainableSource)
        {
            foreach (var v in trainableSource.TrainableVariables)
            {
                _add_existing_weight(v, true);
            }
            trainableIds = new HashSet<string>(trainableSource.TrainableVariables.Select(v => v.UniqueId));
        }
        else if (_func.CustomizedFields.TryGetValue("trainable_variables", out var obj)
                 && obj is IEnumerable<Trackable> trackables)
        {
            foreach (var v in trackables.OfType<IVariableV1>())
            {
                _add_existing_weight(v, true);
            }
            trainableIds = new HashSet<string>(trackables.OfType<IVariableV1>().Select(v => v.UniqueId));
        }
        else
        {
            trainableIds = new HashSet<string>();
        }

        if (_func is Layer allSource)
        {
            foreach (var v in allSource.Variables)
            {
                if (!trainableIds.Contains(v.UniqueId))
                {
                    _add_existing_weight(v, false);
                }
            }
        }
        else if (_func.CustomizedFields.TryGetValue("variables", out var allObj)
                 && allObj is IEnumerable<Trackable> allTrackables)
        {
            foreach (var v in allTrackables.OfType<IVariableV1>())
            {
                if (!trainableIds.Contains(v.UniqueId))
                {
                    _add_existing_weight(v, false);
                }
            }
        }

        // Restoring regularization losses from the SavedModel is unsupported.
        if (_func.CustomizedFields.TryGetValue("regularization_losses", out var losses)
            && (losses as ListWrapper)?.Count > 0)
        {
            throw new NotImplementedException("The regularization_losses loading has not been supported yet, " +
                "please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues to let us know and add a feature.");
        }
    }

    protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
    {
        _check_trainability();
        // TODO(Rinne): deal with training_argument
        var result = _callable(inputs);
        return _apply_output_shape_if_set(inputs, result);
    }

    private void _check_trainability()
    {
        // Nothing to verify for a frozen layer.
        if (!Trainable) return;
        // TODO(Rinne): deal with _is_hub_module_v1 and signature
        if (TrainableWeights is null || TrainableWeights.Count == 0)
        {
            tf.Logger.Error("hub.KerasLayer is trainable but has zero trainable weights.");
        }
    }

    private Tensors _apply_output_shape_if_set(Tensors inputs, Tensors result)
    {
        // TODO(Rinne): implement it.
        return result;
    }

    // Registers an externally-owned variable as a weight via a getter that
    // returns the existing instance (no new variable is created).
    private void _add_existing_weight(IVariableV1 weight, bool? trainable = null)
    {
        var is_trainable = trainable ?? weight.Trainable;
        add_weight(weight.Name, weight.shape, weight.dtype, trainable: is_trainable, getter: x => weight);
    }

    // Extracts the forward function: Layer.Apply for layers, otherwise the
    // restored "__call__" concrete function.
    private Func<Tensors, Tensors> _get_callable()
    {
        if (_func is Layer layer)
        {
            return x => layer.Apply(x);
        }
        if (_func.CustomizedFields.TryGetValue("__call__", out var call)
            && call is RestoredFunction function)
        {
            return x => function.Apply(x);
        }
        throw new ValueError("Cannot get the callable from the model.");
    }

    private static Trackable load_module(string handle, LoadOptions? load_options = null)
    {
        //var set_load_options = load_options ?? LoadContext.get_load_option();
        return module_v2.load(handle, load_options);
    }
}
} |
@@ -0,0 +1,17 @@ | |||||
<Project Sdk="Microsoft.NET.Sdk">

  <PropertyGroup>
    <TargetFrameworks>netstandard2.0;net6;net7</TargetFrameworks>
    <LangVersion>11</LangVersion>
    <Nullable>enable</Nullable>
  </PropertyGroup>

  <ItemGroup>
    <PackageReference Include="SharpCompress" Version="0.33.0" />
  </ItemGroup>

  <ItemGroup>
    <!-- Use a single path separator; "..\\Tensor..." mixed doubled and single
         backslashes in the same reference path. -->
    <ProjectReference Include="..\TensorFlowNET.Keras\Tensorflow.Keras.csproj" />
  </ItemGroup>

</Project>
@@ -0,0 +1,74 @@ | |||||
using SharpCompress.Common; | |||||
using SharpCompress.Readers; | |||||
using System; | |||||
using System.IO; | |||||
namespace Tensorflow.Hub | |||||
{ | |||||
/// <summary>
/// File helpers for extracting downloaded tf-hub module archives.
/// </summary>
internal static class file_utils
{
    /// <summary>
    /// Extracts every file entry of the (possibly compressed) tar stream into
    /// <paramref name="dst_path"/>. <paramref name="logFunction"/>, when given,
    /// receives the uncompressed size of each extracted entry so callers can
    /// report progress.
    /// </summary>
    public static void extract_tarfile_to_destination(Stream fileobj, string dst_path, Action<long> logFunction = null)
    {
        using (IReader reader = ReaderFactory.Open(fileobj))
        {
            while (reader.MoveToNextEntry())
            {
                if (reader.Entry.IsDirectory)
                    continue;
                reader.WriteEntryToDirectory(
                    dst_path,
                    new ExtractionOptions() { ExtractFullPath = true, Overwrite = true }
                );
                // Bug fix: the callback parameter was accepted but never invoked,
                // so DownloadManager progress reporting was silently dropped.
                logFunction?.Invoke(reader.Entry.Size);
            }
        }
    }

    /// <summary>
    /// Joins <paramref name="relPath"/> onto <paramref name="dstPath"/>, refusing
    /// archive entries that would escape the destination directory (path traversal).
    /// </summary>
    /// <exception cref="InvalidDataException">The merged path leaves <paramref name="dstPath"/>.</exception>
    public static string merge_relative_path(string dstPath, string relPath)
    {
        var dstRoot = Path.GetFullPath(dstPath);
        // Bug fix: the old code ran GetFullPath(relPath) alone, which resolved
        // it against the current working directory and made the subsequent ".."
        // check ineffective. Normalize the COMBINED path instead and verify it
        // stays under the destination root.
        var merged = Path.GetFullPath(Path.Combine(dstRoot, relPath.TrimStart('/', '\\')));
        var rootWithSep = dstRoot.EndsWith(Path.DirectorySeparatorChar.ToString())
            ? dstRoot
            : dstRoot + Path.DirectorySeparatorChar;
        if (merged != dstRoot && !merged.StartsWith(rootWithSep))
        {
            throw new InvalidDataException($"Relative path '{relPath}' is invalid. Failed to merge with '{dstPath}'.");
        }
        return merged;
    }
}
} |
@@ -0,0 +1,17 @@ | |||||
using Tensorflow.Hub; | |||||
namespace Tensorflow | |||||
{ | |||||
/// <summary>
/// Entry point exposing the tf-hub API surface (usage: <c>HubAPI.hub.KerasLayer(...)</c>).
/// </summary>
public static class HubAPI
{
    public static HubMethods hub { get; } = new();
}
/// <summary>
/// Instance methods backing the <c>hub</c> entry point.
/// </summary>
public class HubMethods
{
    /// <summary>
    /// Loads the TF-Hub module at <paramref name="handle"/> wrapped as a keras layer.
    /// </summary>
    public KerasLayer KerasLayer(string handle, bool trainable = false, LoadOptions? load_options = null)
        => new KerasLayer(handle, trainable, load_options);
}
} |
@@ -0,0 +1,33 @@ | |||||
using System.IO; | |||||
using Tensorflow.Train; | |||||
namespace Tensorflow.Hub | |||||
{ | |||||
/// <summary>
/// Loading helpers mirroring tensorflow_hub's module_v2.py.
/// </summary>
internal static class module_v2
{
    /// <summary>
    /// Resolves <paramref name="handle"/> to a local directory and loads the
    /// SavedModel found there.
    /// </summary>
    /// <exception cref="ValueError">The directory contains no SavedModel file.</exception>
    public static Trackable load(string handle, LoadOptions? options)
    {
        var module_path = resolve(handle);
        // TODO(Rinne): deal with is_hub_module_v1

        var pb_path = Path.Combine(module_path, Constants.SAVED_MODEL_FILENAME_PB);
        var pbtxt_path = Path.Combine(module_path, Constants.SAVED_MODEL_FILENAME_PBTXT);
        bool has_pb = File.Exists(pb_path) || Directory.Exists(pb_path);
        bool has_pbtxt = File.Exists(pbtxt_path) || Directory.Exists(pbtxt_path);
        if (!has_pb && !has_pbtxt)
        {
            throw new ValueError($"Trying to load a model of incompatible/unknown type. " +
                $"'{module_path}' contains neither '{Constants.SAVED_MODEL_FILENAME_PB}' " +
                $"nor '{Constants.SAVED_MODEL_FILENAME_PBTXT}'.");
        }

        return Loader.load(module_path, options: options);
    }

    /// <summary>
    /// Turns a handle (URL, gs:// path, or local path) into a local directory.
    /// </summary>
    public static string resolve(string handle)
    {
        return MultiImplRegister.GetResolverRegister().Call(handle);
    }
}
} |
@@ -0,0 +1,55 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Linq; | |||||
namespace Tensorflow.Hub | |||||
{ | |||||
/// <summary>
/// An ordered registry of resolver implementations. Dispatch walks the list
/// from newest to oldest, so later registrations take precedence.
/// </summary>
internal class MultiImplRegister
{
    private static MultiImplRegister resolver = new MultiImplRegister("resolver", new IResolver[0]);
    private static MultiImplRegister loader = new MultiImplRegister("loader", new IResolver[0]);

    static MultiImplRegister()
    {
        // Registration order matters: Call probes in reverse order.
        resolver.add_implementation(new PathResolver());
        resolver.add_implementation(new HttpUncompressedFileResolver());
        resolver.add_implementation(new GcsCompressedFileResolver());
        resolver.add_implementation(new HttpCompressedFileResolver());
    }

    string _name;
    List<IResolver> _impls;

    public MultiImplRegister(string name, IEnumerable<IResolver> impls)
    {
        _name = name;
        _impls = impls.ToList();
    }

    public void add_implementation(IResolver resolver)
    {
        _impls.Add(resolver);
    }

    /// <summary>
    /// Dispatches <paramref name="handle"/> to the most recently registered
    /// implementation that supports it.
    /// </summary>
    /// <exception cref="RuntimeError">No implementation supports the handle.</exception>
    public string Call(string handle)
    {
        for (int i = _impls.Count - 1; i >= 0; i--)
        {
            if (_impls[i].IsSupported(handle))
            {
                return _impls[i].Call(handle);
            }
        }
        throw new RuntimeError($"Cannot resolve the handle {handle}");
    }

    public static MultiImplRegister GetResolverRegister() => resolver;

    public static MultiImplRegister GetLoaderRegister() => loader;
}
} |
@@ -0,0 +1,580 @@ | |||||
using ICSharpCode.SharpZipLib.Tar; | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.ComponentModel; | |||||
using System.Diagnostics; | |||||
using System.IO; | |||||
using System.Linq; | |||||
using System.Net; | |||||
using System.Net.Http; | |||||
using System.Net.Security; | |||||
using System.Security.Authentication; | |||||
using System.Threading.Tasks; | |||||
using System.Web; | |||||
using static Tensorflow.Binding; | |||||
namespace Tensorflow.Hub | |||||
{ | |||||
internal static class resolver | |||||
{ | |||||
public enum ModelLoadFormat | |||||
{ | |||||
[Description("COMPRESSED")] | |||||
COMPRESSED, | |||||
[Description("UNCOMPRESSED")] | |||||
UNCOMPRESSED, | |||||
[Description("AUTO")] | |||||
AUTO | |||||
} | |||||
/// <summary>
/// Streams a module download to disk while reporting progress, either
/// interactively on the console or through the standard TF log.
/// </summary>
public class DownloadManager
{
    private readonly string _url;
    private double _last_progress_msg_print_time;
    private long _total_bytes_downloaded;
    // Longest progress string printed so far, used to fully overwrite old lines.
    private int _max_prog_str;

    public DownloadManager(string url)
    {
        _url = url;
        _last_progress_msg_print_time = DateTime.Now.Ticks / TimeSpan.TicksPerSecond;
        _total_bytes_downloaded = 0;
        _max_prog_str = 0;
    }

    // Interactive progress is opt-in via the _TFHUB_DOWNLOAD_PROGRESS env var.
    private bool _interactive_mode()
        => !string.IsNullOrEmpty(Environment.GetEnvironmentVariable("_TFHUB_DOWNLOAD_PROGRESS"));

    private void _print_download_progress_msg(string msg, bool flush = false)
    {
        if (!_interactive_mode())
        {
            // Interactive progress tracking is disabled; print to the TF log.
            tf.Logger.Information(msg);
            return;
        }

        // Overwrite the previous progress message in place on the console.
        _max_prog_str = Math.Max(_max_prog_str, msg.Length);
        Console.Write($"\r{msg.PadRight(_max_prog_str)}");
        Console.Out.Flush();
        // When flush is true, emit a newline so following output starts cleanly.
        if (flush)
            Console.WriteLine();
    }

    // Accumulates downloaded byte counts and prints progress every 15 seconds
    // (or continuously in interactive mode).
    private void _log_progress(long bytes_downloaded)
    {
        _total_bytes_downloaded += bytes_downloaded;
        var now = DateTime.Now.Ticks / TimeSpan.TicksPerSecond;
        if (_interactive_mode() || now - _last_progress_msg_print_time > 15)
        {
            _print_download_progress_msg($"Downloading {_url}:" +
                $"{tf_utils.bytes_to_readable_str(_total_bytes_downloaded, true)}");
            _last_progress_msg_print_time = now;
        }
    }

    /// <summary>
    /// Streams the content of <paramref name="fileobj"/> and extracts it into
    /// <paramref name="dst_path"/>, reporting progress along the way.
    /// </summary>
    /// <exception cref="IOException">The stream is not a valid module archive.</exception>
    public void download_and_uncompress(Stream fileobj, string dst_path)
    {
        try
        {
            file_utils.extract_tarfile_to_destination(fileobj, dst_path, _log_progress);
            var total_size_str = tf_utils.bytes_to_readable_str(_total_bytes_downloaded, true);
            _print_download_progress_msg($"Downloaded {_url}, Total size: {total_size_str}", flush: true);
        }
        catch (TarException ex)
        {
            throw new IOException($"{_url} does not appear to be a valid module. Inner message:{ex.Message}", ex);
        }
    }
}
// Flag values registered via set_new_flag(); consulted by get_env_setting()
// as a fallback when the corresponding environment variable is unset.
private static Dictionary<string, string> _flags = new();
// Names of the environment variables recognized by tensorflow.hub.
private static readonly string _TFHUB_CACHE_DIR = "TFHUB_CACHE_DIR";
private static readonly string _TFHUB_DOWNLOAD_PROGRESS = "TFHUB_DOWNLOAD_PROGRESS";
private static readonly string _TFHUB_MODEL_LOAD_FORMAT = "TFHUB_MODEL_LOAD_FORMAT";
private static readonly string _TFHUB_DISABLE_CERT_VALIDATION = "TFHUB_DISABLE_CERT_VALIDATION";
// Value that _TFHUB_DISABLE_CERT_VALIDATION must hold to take effect.
private static readonly string _TFHUB_DISABLE_CERT_VALIDATION_VALUE = "true";
// Registers the default flag values at type initialization, mirroring the
// flags' built-in defaults (load format "AUTO", no explicit cache dir).
static resolver()
{
    set_new_flag("tfhub_model_load_format", "AUTO");
    set_new_flag("tfhub_cache_dir", null);
}
// Resolves the model load format: the TFHUB_MODEL_LOAD_FORMAT environment
// variable wins, otherwise the "tfhub_model_load_format" flag value.
public static string model_load_format()
    => get_env_setting(_TFHUB_MODEL_LOAD_FORMAT, "tfhub_model_load_format");
/// <summary>
/// Returns the effective setting for <paramref name="env_var"/>: the process
/// environment value when set and non-empty, otherwise the flag registered
/// under <paramref name="flag_name"/>, otherwise null.
/// </summary>
public static string? get_env_setting(string env_var, string flag_name)
{
    var value = Environment.GetEnvironmentVariable(env_var);
    if (!string.IsNullOrEmpty(value))
    {
        return value;
    }
    // Single dictionary lookup instead of ContainsKey + indexer.
    return _flags.TryGetValue(flag_name, out var flagValue) ? flagValue : null;
}
// Resolves the module cache directory. Precedence: the TFHUB_CACHE_DIR
// environment variable / "tfhub_cache_dir" flag, then the caller-supplied
// default, then (only when use_temp is set) "<system temp>/tfhub_modules".
public static string tfhub_cache_dir(string default_cache_dir = null, bool use_temp = false)
{
    var cacheDir = get_env_setting(_TFHUB_CACHE_DIR, "tfhub_cache_dir") ?? default_cache_dir;
    if (string.IsNullOrWhiteSpace(cacheDir) && use_temp)
    {
        // Place all TF-Hub modules under <system's temp>/tfhub_modules.
        cacheDir = Path.Combine(Path.GetTempPath(), "tfhub_modules");
    }
    var haveCacheDir = !string.IsNullOrWhiteSpace(cacheDir);
    if (haveCacheDir)
    {
        Console.WriteLine("Using {0} to cache modules.", cacheDir);
    }
    return cacheDir;
}
// Ensures the cache root exists and returns the path that the named module
// should occupy inside it (the module directory itself is not created here).
public static string create_local_module_dir(string cache_dir, string module_name)
{
    Directory.CreateDirectory(cache_dir);
    var modulePath = Path.Combine(cache_dir, module_name);
    return modulePath;
}
/// <summary>
/// Registers a flag value consulted by get_env_setting(). Unknown flag
/// names are accepted but logged, since they will not affect anything.
/// </summary>
public static void set_new_flag(string name, string value)
{
    // Fixed: the previous check compared 'name' against the *environment
    // variable* names (and even the literal value "true"), so the flags
    // actually used by this class — set in the static constructor — always
    // triggered the warning. Recognized flag names are the ones consulted
    // by get_env_setting() callers in this class.
    string[] knownFlags = { "tfhub_cache_dir", "tfhub_model_load_format" };
    if (!knownFlags.Contains(name))
    {
        tf.Logger.Warning($"You are setting a flag '{name}' that cannot be recognized. The flag you set " +
            "may not affect anything in tensorflow.hub.");
    }
    _flags[name] = value;
}
// Delegates to file_utils.merge_relative_path; kept here so resolver callers
// have a single entry point for path merging.
public static string _merge_relative_path(string dstPath, string relPath)
    => file_utils.merge_relative_path(dstPath, relPath);
// The human-readable descriptor lives beside the module directory as
// "<module_dir>.descriptor.txt".
public static string _module_descriptor_file(string moduleDir)
    => moduleDir + ".descriptor.txt";
// Writes a purely informational descriptor recording which module was
// downloaded, when, and by which host/process. Nothing reads it back.
public static void _write_module_descriptor_file(string handle, string moduleDir)
{
    var descriptorPath = _module_descriptor_file(moduleDir);
    var pid = Process.GetCurrentProcess().Id;
    var content = string.Join("\n", new[]
    {
        $"Module: {handle}",
        $"Download Time: {DateTime.Now}",
        $"Downloader Hostname: {Environment.MachineName} (PID:{pid})"
    });
    tf_utils.atomic_write_string_to_file(descriptorPath, content, overwrite: true);
}
// Lock file payload "<hostname>.<pid>.<task uid>" identifies the lock owner.
public static string _lock_file_contents(string task_uid)
{
    var pid = Process.GetCurrentProcess().Id;
    return string.Join(".", Environment.MachineName, pid.ToString(), task_uid);
}
// The lock for a module directory lives beside it as "<abs module dir>.lock".
public static string _lock_filename(string moduleDir)
{
    var absolute = tf_utils.absolute_path(moduleDir);
    return absolute + ".lock";
}
// Returns the module directory a lock file protects. This is the inverse of
// _lock_filename(): the lock for "<module_dir>" is "<module_dir>.lock", so
// the module directory is obtained by stripping the ".lock" suffix.
// (Fixed: the previous code returned "<parent>/hub_modules", which does not
// correspond to any directory this library creates, so callers such as
// _locked_tmp_dir_size always measured a nonexistent path.)
private static string _module_dir(string lockFilename)
{
    const string suffix = ".lock";
    var fullPath = Path.GetFullPath(lockFilename);
    if (fullPath.EndsWith(suffix, StringComparison.Ordinal))
    {
        return fullPath.Substring(0, fullPath.Length - suffix.Length);
    }
    throw new ArgumentException($"Unable to resolve hub module directory from lock file name {lockFilename}.");
}
// Returns the task UID of the task that created the given lock file.
// The lock holds "<hostname>.<pid>.<task uid>"; the uid is the final
// dot-separated token.
private static string _task_uid_from_lock_file(string lockFilename)
{
    var contents = File.ReadAllText(lockFilename);
    var tokens = contents.Split('.');
    return tokens[tokens.Length - 1];
}
// Name of the temporary directory a given task downloads the module into:
// "<abs module dir>.<task uid>.tmp".
private static string _temp_download_dir(string moduleDir, string taskUid)
    => string.Format("{0}.{1}.tmp", Path.GetFullPath(moduleDir), taskUid);
// Returns the total size (in bytes) of all files under 'directory',
// recursing into subdirectories.
// (Fixed: the previous code read FileInfo.Length before checking the
// attributes; FileInfo.Length throws FileNotFoundException for a directory
// path, so any directory containing a subdirectory crashed the computation.)
private static long _dir_size(string directory)
{
    long size = 0;
    foreach (var entry in Directory.EnumerateFileSystemEntries(directory))
    {
        if (Directory.Exists(entry))
        {
            size += _dir_size(entry);
        }
        else
        {
            size += new FileInfo(entry).Length;
        }
    }
    return size;
}
// Returns the size of the temp download dir owned by the task holding the
// given lock file, or 0 when that directory does not (or no longer) exists.
public static long _locked_tmp_dir_size(string lockFilename)
{
    var taskUid = _task_uid_from_lock_file(lockFilename);
    try
    {
        var tmpDir = _temp_download_dir(_module_dir(lockFilename), taskUid);
        return _dir_size(tmpDir);
    }
    catch (DirectoryNotFoundException)
    {
        return 0;
    }
}
// Blocks until 'lockFile' disappears, i.e. until the process currently
// downloading the module finishes or is declared dead. If, after
// 'lockFileTimeoutSec', neither the lock owner nor the size of its temp
// download directory has changed, the lock is assumed stale and is deleted
// so the caller can take over. Polls every 5 seconds.
private static void _wait_for_lock_to_disappear(string handle, string lockFile, double lockFileTimeoutSec)
{
    // Snapshot of the downloader's observable activity: temp dir size and
    // lock file contents, plus when the snapshot was taken.
    long? lockedTmpDirSize = null;
    var lockedTmpDirSizeCheckTime = DateTime.Now;
    var lockFileContent = "";
    while (File.Exists(lockFile))
    {
        try
        {
            Console.WriteLine($"Module '{handle}' already being downloaded by '{File.ReadAllText(lockFile)}'. Waiting.");
            if ((DateTime.Now - lockedTmpDirSizeCheckTime).TotalSeconds > lockFileTimeoutSec)
            {
                // Timeout elapsed: if the temp dir size AND the lock owner
                // are unchanged since the last snapshot, the downloader is
                // presumed dead — steal the lock by deleting it.
                var curLockedTmpDirSize = _locked_tmp_dir_size(lockFile);
                var curLockFileContent = File.ReadAllText(lockFile);
                if (curLockedTmpDirSize == lockedTmpDirSize && curLockFileContent == lockFileContent)
                {
                    Console.WriteLine($"Deleting lock file {lockFile} due to inactivity.");
                    File.Delete(lockFile);
                    break;
                }
                // Activity observed: refresh the snapshot and keep waiting.
                lockedTmpDirSize = curLockedTmpDirSize;
                lockedTmpDirSizeCheckTime = DateTime.Now;
                lockFileContent = curLockFileContent;
            }
        }
        catch (FileNotFoundException)
        {
            // Lock file or temp directory were deleted during check. Continue
            // to check whether download succeeded or we need to start our own
            // download.
        }
        System.Threading.Thread.Sleep(5000);
    }
}
/// <summary>
/// Returns the path to a module directory, downloading it via
/// <paramref name="downloadFn"/> if necessary. Cross-process coordination is
/// done through a "<moduleDir>.lock" file: the winner downloads into a
/// task-unique temp directory and atomically moves it into place; losers
/// wait for the lock to disappear (see _wait_for_lock_to_disappear).
/// </summary>
/// <param name="handle">Module handle (used for logging and the descriptor file).</param>
/// <param name="downloadFn">Callback (handle, tmpDir) that performs the actual download.</param>
/// <param name="moduleDir">Destination directory for the module.</param>
/// <param name="lock_file_timeout_sec">Inactivity timeout before a stale lock is stolen.</param>
public static async Task<string> atomic_download_async(
    string handle,
    Func<string, string, Task> downloadFn,
    string moduleDir,
    int lock_file_timeout_sec = 10 * 60)
{
    var lockFile = _lock_filename(moduleDir);
    // Task UID distinguishes this process's temp dir and lock ownership.
    var taskUid = Guid.NewGuid().ToString("N");
    var lockContents = _lock_file_contents(taskUid);
    var tmpDir = _temp_download_dir(moduleDir, taskUid);
    // Function to check whether the module has already been downloaded
    // (directory exists and is non-empty).
    Func<bool> checkModuleExists = () =>
        Directory.Exists(moduleDir) &&
        Directory.EnumerateFileSystemEntries(moduleDir).Any();
    // Check whether the model has already been downloaded before locking
    // the destination path.
    if (checkModuleExists())
    {
        return moduleDir;
    }
    // Attempt to protect against cases of processes being cancelled with
    // KeyboardInterrupt by using a try/finally clause to remove the lock
    // and tmp_dir.
    while (true)
    {
        try
        {
            // Try to acquire the lock (overwrite: false — only succeeds
            // meaningfully when no other process holds it).
            tf_utils.atomic_write_string_to_file(lockFile, lockContents, false);
            // Must test condition again, since another process could have created
            // the module and deleted the old lock file since last test.
            if (checkModuleExists())
            {
                // Lock file will be deleted in the finally-clause.
                return moduleDir;
            }
            // Remove any partial/empty module directory before downloading.
            if (Directory.Exists(moduleDir))
            {
                Directory.Delete(moduleDir, true);
            }
            break; // Proceed to downloading the module.
        }
        // These errors are believed to be permanent problems with the
        // module_dir that justify failing the download.
        catch (FileNotFoundException)
        {
            throw;
        }
        catch (UnauthorizedAccessException)
        {
            throw;
        }
        catch (IOException)
        {
            throw;
        }
        // All other errors are retried.
        // TODO(b/144424849): Retrying an AlreadyExistsError from the atomic write
        // should be good enough, but see discussion about misc filesystem types.
        // TODO(b/144475403): How atomic is the overwrite=False check?
        catch (Exception)
        {
        }
        // Wait for lock file to disappear.
        _wait_for_lock_to_disappear(handle, lockFile, lock_file_timeout_sec);
        // At this point we either deleted a lock or a lock got removed by the
        // owner or another process. Perform one more iteration of the
        // while-loop: we either terminate because checkModuleExists() is now
        // true, or because we obtain the lock ourselves, or we wait again for
        // the lock to disappear.
    }
    // Lock file acquired.
    tf.Logger.Information($"Downloading TF-Hub Module '{handle}'...");
    Directory.CreateDirectory(tmpDir);
    await downloadFn(handle, tmpDir);
    // Write module descriptor to capture information about which module was
    // downloaded by whom and when. The file stored at the same level as a
    // directory in order to keep the content of the 'model_dir' exactly as it
    // was define by the module publisher.
    //
    // Note: The descriptor is written purely to help the end-user to identify
    // which directory belongs to which module. The descriptor is not part of the
    // module caching protocol and no code in the TF-Hub library reads its
    // content.
    _write_module_descriptor_file(handle, moduleDir);
    try
    {
        // Atomic publish: move the fully-downloaded temp dir into place.
        Directory.Move(tmpDir, moduleDir);
        Console.WriteLine($"Downloaded TF-Hub Module '{handle}'.");
    }
    catch (IOException e)
    {
        Console.WriteLine(e.Message);
        Console.WriteLine($"Failed to move {tmpDir} to {moduleDir}");
        // Keep the temp directory so we will retry building vocabulary later.
    }
    // Temp directory is owned by the current process, remove it.
    try
    {
        Directory.Delete(tmpDir, true);
    }
    catch (DirectoryNotFoundException)
    {
        // Already moved into place — nothing to clean up.
    }
    // Lock file exists and is owned by this process: release it, but only
    // if the contents still identify us as the owner.
    try
    {
        var contents = File.ReadAllText(lockFile);
        if (contents == lockContents)
        {
            File.Delete(lockFile);
        }
    }
    catch (Exception)
    {
        // Best-effort release: the lock may already be gone.
    }
    return moduleDir;
}
} | |||||
/// <summary>
/// Strategy interface for turning a module handle (e.g. a filesystem path
/// or URL) into a local filesystem path.
/// </summary>
internal interface IResolver
{
    /// <summary>Resolves <paramref name="handle"/> to a local path.</summary>
    string Call(string handle);
    /// <summary>Whether this resolver can handle <paramref name="handle"/>.</summary>
    bool IsSupported(string handle);
}
/// <summary>
/// Resolves handles that are already plain filesystem paths.
/// </summary>
internal class PathResolver : IResolver
{
    public string Call(string handle)
    {
        var exists = File.Exists(handle) || Directory.Exists(handle);
        if (!exists)
        {
            throw new IOException($"{handle} does not exist in file system.");
        }
        return handle;
    }

    // The path resolver is the catch-all fallback: every handle is accepted.
    public bool IsSupported(string handle) => true;
}
/// <summary>
/// Base class for resolvers that fetch module handles over HTTP(S).
/// </summary>
public abstract class HttpResolverBase : IResolver
{
    // Reused for all requests that do not need a custom TLS configuration.
    private readonly HttpClient httpClient;
    private SslProtocol sslProtocol;
    private RemoteCertificateValidationCallback certificateValidator;

    protected HttpResolverBase()
    {
        httpClient = new HttpClient();
        _maybe_disable_cert_validation();
    }

    public abstract string Call(string handle);
    public abstract bool IsSupported(string handle);

    /// <summary>
    /// Opens a read-only stream over a local file; returns null on failure.
    /// </summary>
    protected async Task<Stream> GetLocalFileStreamAsync(string filePath)
    {
        try
        {
            var fs = new FileStream(filePath, FileMode.Open, FileAccess.Read);
            return await Task.FromResult(fs);
        }
        catch (Exception ex)
        {
            Console.WriteLine($"Failed to read file stream: {ex.Message}");
            return null;
        }
    }

    /// <summary>
    /// Opens a stream over 'filePath', which may be a local path or an
    /// HTTP(S) URL; returns null when the fetch fails.
    /// </summary>
    protected async Task<Stream> GetFileStreamAsync(string filePath)
    {
        if (!is_http_protocol(filePath))
        {
            // If filePath is not an HTTP(S) URL, delegate to the local
            // file resolver.
            return await GetLocalFileStreamAsync(filePath);
        }
        var request = new HttpRequestMessage(HttpMethod.Get, filePath);
        var response = await _call_urlopen(request);
        if (response.IsSuccessStatusCode)
        {
            return await response.Content.ReadAsStreamAsync();
        }
        Console.WriteLine($"Failed to fetch file stream: {response.StatusCode} - {response.ReasonPhrase}");
        return null;
    }

    /// <summary>
    /// Configures the TLS protocol and certificate validator used by
    /// subsequent requests made through <see cref="_call_urlopen"/>.
    /// </summary>
    protected void SetUrlContext(SslProtocol protocol, RemoteCertificateValidationCallback validator)
    {
        sslProtocol = protocol;
        certificateValidator = validator;
    }

    /// <summary>
    /// Appends a query parameter (name, value) to a handle URL, preserving
    /// any existing query string.
    /// </summary>
    public static string append_format_query(string handle, (string, string) formatQuery)
    {
        var parsed = new Uri(handle);
        var queryBuilder = HttpUtility.ParseQueryString(parsed.Query);
        queryBuilder.Add(formatQuery.Item1, formatQuery.Item2);
        // The UriBuilder 'extraValue' argument must carry the leading '?'.
        parsed = new UriBuilder(parsed.Scheme, parsed.Host, parsed.Port, parsed.AbsolutePath,
            "?" + queryBuilder.ToString()).Uri;
        return parsed.ToString();
    }

    protected bool is_http_protocol(string handle)
    {
        return handle.StartsWith("http://") || handle.StartsWith("https://");
    }

    protected async Task<HttpResponseMessage> _call_urlopen(HttpRequestMessage request)
    {
        if (sslProtocol != null)
        {
            // A custom TLS context was requested: build a client honoring it.
            // NOTE(review): this creates a new HttpClient per request, which
            // can exhaust sockets under load — consider caching the client.
            var handler = new HttpClientHandler()
            {
                SslProtocols = sslProtocol.AsEnum(),
            };
            if (certificateValidator != null)
            {
                handler.ServerCertificateCustomValidationCallback = (x, y, z, w) =>
                {
                    return certificateValidator(x, y, z, w);
                };
            }
            var client = new HttpClient(handler);
            return await client.SendAsync(request);
        }
        return await httpClient.SendAsync(request);
    }

    /// <summary>
    /// Disables server certificate validation when the environment variable
    /// TFHUB_DISABLE_CERT_VALIDATION is set to "true".
    /// </summary>
    protected void _maybe_disable_cert_validation()
    {
        // Fixed: the previous code read a variable literally named
        // "_TFHUB_DISABLE_CERT_VALIDATION" and compared it against the
        // literal "_TFHUB_DISABLE_CERT_VALIDATION_VALUE" — i.e. the *names*
        // of the resolver constants rather than their values — so this
        // feature could never trigger.
        if (Environment.GetEnvironmentVariable("TFHUB_DISABLE_CERT_VALIDATION") == "true")
        {
            ServicePointManager.ServerCertificateValidationCallback = (_, _, _, _) => true;
            Console.WriteLine("Disabled certificate validation for resolving handles.");
        }
    }
}
/// <summary>
/// Closed set of TLS protocol choices, mapped onto
/// <see cref="System.Security.Authentication.SslProtocols"/> by AsEnum().
/// </summary>
public class SslProtocol
{
    private readonly string protocolString;

    public static readonly SslProtocol Tls = new SslProtocol("TLS");
    public static readonly SslProtocol Tls11 = new SslProtocol("TLS 1.1");
    public static readonly SslProtocol Tls12 = new SslProtocol("TLS 1.2");

    // Private: the three static instances above are the only legal values.
    private SslProtocol(string protocolString)
    {
        this.protocolString = protocolString;
    }

    /// <summary>Maps the descriptive name onto the framework enum.</summary>
    public SslProtocols AsEnum() => protocolString.ToUpper() switch
    {
        "TLS" => SslProtocols.Tls,
        "TLS 1.1" => SslProtocols.Tls11,
        "TLS 1.2" => SslProtocols.Tls12,
        _ => throw new ArgumentException($"Unknown SSL/TLS protocol: {protocolString}"),
    };
}
} |
@@ -0,0 +1,80 @@ | |||||
using System; | |||||
using System.IO; | |||||
namespace Tensorflow.Hub | |||||
{ | |||||
/// <summary>
/// Small filesystem/formatting helpers used by the TF-Hub resolver.
/// </summary>
internal class tf_utils
{
    /// <summary>
    /// Formats a byte count as a human-readable string, scaling through
    /// k/M/G with two decimals; appends "B" when <paramref name="includeB"/>
    /// is set. Returns an empty string for null.
    /// </summary>
    public static string bytes_to_readable_str(long? numBytes, bool includeB = false)
    {
        if (numBytes == null) return numBytes.ToString(); // "" for null
        var suffix = includeB ? "B" : "";
        var num = (double)numBytes;
        if (num < 1024)
        {
            return $"{(long)num}{suffix}";
        }
        num /= 1 << 10;
        if (num < 1024)
        {
            return $"{num:F2}k{suffix}";
        }
        num /= 1 << 10;
        if (num < 1024)
        {
            return $"{num:F2}M{suffix}";
        }
        num /= 1 << 10;
        return $"{num:F2}G{suffix}";
    }

    /// <summary>
    /// Writes <paramref name="contents"/> to a temporary file and moves it
    /// into place, so readers never observe a partially written file.
    /// When the target exists and <paramref name="overwrite"/> is false the
    /// existing file is kept unchanged.
    /// </summary>
    public static void atomic_write_string_to_file(string filename, string contents, bool overwrite)
    {
        // Create the temp file next to the target so File.Move stays on the
        // same volume. (Fixed: the previous code used the broken literal
        // "(unknown).tmp.…", dropping the temp file into the current
        // working directory.)
        var tempPath = $"{filename}.tmp.{Guid.NewGuid():N}";
        using (var fileStream = new FileStream(tempPath, FileMode.Create))
        {
            using (var writer = new StreamWriter(fileStream))
            {
                writer.Write(contents);
                writer.Flush();
            }
        }
        try
        {
            if (File.Exists(filename))
            {
                if (overwrite)
                {
                    File.Delete(filename);
                    File.Move(tempPath, filename);
                }
                else
                {
                    // Target exists and overwrite is disallowed: keep the
                    // existing file, but remove the temp file instead of
                    // leaking it (previously it was left behind).
                    // NOTE(review): callers using this as an atomic lock
                    // acquire may expect a failure signal here — confirm.
                    File.Delete(tempPath);
                }
            }
            else
            {
                File.Move(tempPath, filename);
            }
        }
        catch
        {
            if (File.Exists(tempPath))
            {
                File.Delete(tempPath);
            }
            throw;
        }
    }

    /// <summary>
    /// Returns the absolute form of <paramref name="path"/>; paths carrying
    /// a URI scheme (e.g. "gs://...") are returned untouched.
    /// </summary>
    public static string absolute_path(string path)
    {
        return path.Contains("://") ? path : Path.GetFullPath(path);
    }
}
} |
@@ -0,0 +1,46 @@ | |||||
using static Tensorflow.Binding; | |||||
using static Tensorflow.HubAPI; | |||||
namespace Tensorflow.Hub.Unittest | |||||
{ | |||||
[TestClass]
public class KerasLayerTest
{
    /// <summary>
    /// Smoke test: loads a small BERT encoder from TF-Hub and runs one
    /// forward pass over a single 128-token sequence.
    /// </summary>
    [TestMethod]
    public void SmallBert()
    {
        var layer = hub.KerasLayer("https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-6_H-256_A-4/1");

        const int seqLen = 128;

        // Segment/type ids: all zeros (single-segment input).
        var typeIds = new int[seqLen];
        // Word ids: five real WordPiece tokens followed by padding zeros.
        var wordIds = new int[seqLen];
        new int[] { 101, 2129, 2024, 2017, 102 }.CopyTo(wordIds, 0);
        // Attention mask: 1 for the five real tokens, 0 for padding.
        var mask = new int[seqLen];
        for (var i = 0; i < 5; i++)
            mask[i] = 1;

        // All three inputs consistently use tf.int32 (previously the mask
        // used dtypes.int32 while the others used tf.int32).
        var input_type_ids = tf.reshape(tf.convert_to_tensor(typeIds, dtype: tf.int32), (1, seqLen));
        var input_word_ids = tf.reshape(tf.convert_to_tensor(wordIds, dtype: tf.int32), (1, seqLen));
        var input_mask = tf.reshape(tf.convert_to_tensor(mask, dtype: tf.int32), (1, seqLen));

        var result = layer.Apply(new Tensors(input_type_ids, input_word_ids, input_mask));

        // Previously the test made no assertion at all; at minimum the
        // forward pass must produce an output.
        Assert.IsNotNull(result);
    }
}
} |
@@ -0,0 +1,23 @@ | |||||
<Project Sdk="Microsoft.NET.Sdk"> | |||||
<PropertyGroup> | |||||
<TargetFramework>net7.0</TargetFramework>
<ImplicitUsings>enable</ImplicitUsings> | |||||
<Nullable>enable</Nullable> | |||||
<IsPackable>false</IsPackable> | |||||
</PropertyGroup> | |||||
<ItemGroup> | |||||
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="17.3.2" /> | |||||
<PackageReference Include="MSTest.TestAdapter" Version="2.2.10" /> | |||||
<PackageReference Include="MSTest.TestFramework" Version="2.2.10" /> | |||||
<PackageReference Include="coverlet.collector" Version="3.1.2" /> | |||||
<PackageReference Include="SciSharp.TensorFlow.Redist" Version="2.11.2" /> | |||||
</ItemGroup> | |||||
<ItemGroup> | |||||
<ProjectReference Include="..\..\src\TensorflowNET.Hub\Tensorflow.Hub.csproj" /> | |||||
</ItemGroup> | |||||
</Project> |
@@ -0,0 +1 @@ | |||||
global using Microsoft.VisualStudio.TestTools.UnitTesting; |