Add Tensorflow.NET.Hubtags/v0.100.5-BERT-load
@@ -23,6 +23,10 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Keras.UnitTest", | |||
EndProject | |||
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Graph.UnitTest", "test\TensorFlowNET.Graph.UnitTest\TensorFlowNET.Graph.UnitTest.csproj", "{3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}" | |||
EndProject | |||
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Hub", "src\TensorflowNET.Hub\Tensorflow.Hub.csproj", "{9738D16A-CFA0-405C-A7DF-D3D203B0CB18}" | |||
EndProject | |||
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Hub.Unittest", "test\TensorflowNET.Hub.Unittest\Tensorflow.Hub.Unittest.csproj", "{7DEA8760-E401-4872-81F3-405F185A13A0}" | |||
EndProject | |||
Global | |||
GlobalSection(SolutionConfigurationPlatforms) = preSolution | |||
Debug|Any CPU = Debug|Any CPU | |||
@@ -153,6 +157,30 @@ Global | |||
{3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.Release|x64.Build.0 = Release|x64 | |||
{3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.Release|x86.ActiveCfg = Release|Any CPU | |||
{3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.Release|x86.Build.0 = Release|Any CPU | |||
{9738D16A-CFA0-405C-A7DF-D3D203B0CB18}.Debug|Any CPU.ActiveCfg = Debug|Any CPU | |||
{9738D16A-CFA0-405C-A7DF-D3D203B0CB18}.Debug|Any CPU.Build.0 = Debug|Any CPU | |||
{9738D16A-CFA0-405C-A7DF-D3D203B0CB18}.Debug|x64.ActiveCfg = Debug|Any CPU | |||
{9738D16A-CFA0-405C-A7DF-D3D203B0CB18}.Debug|x64.Build.0 = Debug|Any CPU | |||
{9738D16A-CFA0-405C-A7DF-D3D203B0CB18}.Debug|x86.ActiveCfg = Debug|Any CPU | |||
{9738D16A-CFA0-405C-A7DF-D3D203B0CB18}.Debug|x86.Build.0 = Debug|Any CPU | |||
{9738D16A-CFA0-405C-A7DF-D3D203B0CB18}.Release|Any CPU.ActiveCfg = Release|Any CPU | |||
{9738D16A-CFA0-405C-A7DF-D3D203B0CB18}.Release|Any CPU.Build.0 = Release|Any CPU | |||
{9738D16A-CFA0-405C-A7DF-D3D203B0CB18}.Release|x64.ActiveCfg = Release|Any CPU | |||
{9738D16A-CFA0-405C-A7DF-D3D203B0CB18}.Release|x64.Build.0 = Release|Any CPU | |||
{9738D16A-CFA0-405C-A7DF-D3D203B0CB18}.Release|x86.ActiveCfg = Release|Any CPU | |||
{9738D16A-CFA0-405C-A7DF-D3D203B0CB18}.Release|x86.Build.0 = Release|Any CPU | |||
{7DEA8760-E401-4872-81F3-405F185A13A0}.Debug|Any CPU.ActiveCfg = Debug|Any CPU | |||
{7DEA8760-E401-4872-81F3-405F185A13A0}.Debug|Any CPU.Build.0 = Debug|Any CPU | |||
{7DEA8760-E401-4872-81F3-405F185A13A0}.Debug|x64.ActiveCfg = Debug|Any CPU | |||
{7DEA8760-E401-4872-81F3-405F185A13A0}.Debug|x64.Build.0 = Debug|Any CPU | |||
{7DEA8760-E401-4872-81F3-405F185A13A0}.Debug|x86.ActiveCfg = Debug|Any CPU | |||
{7DEA8760-E401-4872-81F3-405F185A13A0}.Debug|x86.Build.0 = Debug|Any CPU | |||
{7DEA8760-E401-4872-81F3-405F185A13A0}.Release|Any CPU.ActiveCfg = Release|Any CPU | |||
{7DEA8760-E401-4872-81F3-405F185A13A0}.Release|Any CPU.Build.0 = Release|Any CPU | |||
{7DEA8760-E401-4872-81F3-405F185A13A0}.Release|x64.ActiveCfg = Release|Any CPU | |||
{7DEA8760-E401-4872-81F3-405F185A13A0}.Release|x64.Build.0 = Release|Any CPU | |||
{7DEA8760-E401-4872-81F3-405F185A13A0}.Release|x86.ActiveCfg = Release|Any CPU | |||
{7DEA8760-E401-4872-81F3-405F185A13A0}.Release|x86.Build.0 = Release|Any CPU | |||
EndGlobalSection | |||
GlobalSection(SolutionProperties) = preSolution | |||
HideSolutionNode = FALSE | |||
@@ -207,9 +207,24 @@ namespace Tensorflow | |||
} | |||
public override string ToString() | |||
=> items.Count() == 1 | |||
? items.First().ToString() | |||
: items.Count() + " Tensors" + ". " + string.Join(", ", items.Select(x => x.name)); | |||
{ | |||
if(items.Count == 1) | |||
{ | |||
return items[0].ToString(); | |||
} | |||
else | |||
{ | |||
StringBuilder sb = new StringBuilder(); | |||
sb.Append($"Totally {items.Count} tensors, which are {string.Join(", ", items.Select(x => x.name))}\n[\n"); | |||
for(int i = 0; i < items.Count; i++) | |||
{ | |||
var tensor = items[i]; | |||
sb.Append($"Tensor {i}({tensor.name}): {tensor.ToString()}\n"); | |||
} | |||
sb.Append("]\n"); | |||
return sb.ToString(); | |||
} | |||
} | |||
public void Dispose() | |||
{ | |||
@@ -301,6 +301,17 @@ namespace Tensorflow | |||
type == TF_DataType.DtInt32Ref || type == TF_DataType.DtInt64Ref; | |||
} | |||
public static bool is_unsigned(this TF_DataType type) | |||
{ | |||
return type == TF_DataType.TF_UINT8 || type == TF_DataType.TF_UINT16 || type == TF_DataType.TF_UINT32 || | |||
type == TF_DataType.TF_UINT64; | |||
} | |||
public static bool is_bool(this TF_DataType type) | |||
{ | |||
return type == TF_DataType.TF_BOOL; | |||
} | |||
public static bool is_floating(this TF_DataType type) | |||
{ | |||
return type == TF_DataType.TF_HALF || type == TF_DataType.TF_FLOAT || type == TF_DataType.TF_DOUBLE; | |||
@@ -22,9 +22,9 @@ namespace Tensorflow.Keras.Engine | |||
// If dtype is DT_FLOAT, provide a uniform unit scaling initializer | |||
if (dtype.is_floating()) | |||
initializer = tf.glorot_uniform_initializer; | |||
else if (dtype.is_integer()) | |||
else if (dtype.is_integer() || dtype.is_unsigned() || dtype.is_bool()) | |||
initializer = tf.zeros_initializer; | |||
else | |||
else if(getter is null) | |||
throw new ValueError($"An initializer for variable {name} of type {dtype.as_base_dtype()} is required for layer {name}"); | |||
} | |||
@@ -36,5 +36,9 @@ namespace Tensorflow.Keras.Saving | |||
public bool? Stateful { get; set; } | |||
[JsonProperty("model_config")] | |||
public KerasModelConfig? ModelConfig { get; set; } | |||
[JsonProperty("sparse")] | |||
public bool Sparse { get; set; } | |||
[JsonProperty("ragged")] | |||
public bool Ragged { get; set; } | |||
} | |||
} |
@@ -401,13 +401,22 @@ namespace Tensorflow.Keras.Saving | |||
private (Trackable, Action<object, object, object>) revive_custom_object(string identifier, KerasMetaData metadata) | |||
{ | |||
if(identifier == SavedModel.Constants.LAYER_IDENTIFIER) | |||
if (identifier == SavedModel.Constants.LAYER_IDENTIFIER) | |||
{ | |||
return RevivedLayer.init_from_metadata(metadata); | |||
} | |||
else if(identifier == SavedModel.Constants.MODEL_IDENTIFIER || identifier == SavedModel.Constants.SEQUENTIAL_IDENTIFIER | |||
|| identifier == SavedModel.Constants.NETWORK_IDENTIFIER) | |||
{ | |||
return RevivedNetwork.init_from_metadata(metadata); | |||
} | |||
else if(identifier == SavedModel.Constants.INPUT_LAYER_IDENTIFIER) | |||
{ | |||
return RevivedInputLayer.init_from_metadata(metadata); | |||
} | |||
else | |||
{ | |||
throw new NotImplementedException(); | |||
throw new ValueError($"Cannot revive the layer {identifier}."); | |||
} | |||
} | |||
@@ -1,15 +1,46 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
using Tensorflow.Keras.ArgsDefinition; | |||
using Tensorflow.Keras.Engine; | |||
using Tensorflow.Keras.Layers; | |||
namespace Tensorflow.Keras.Saving.SavedModel | |||
{ | |||
public class RevivedInputLayer: Layer | |||
public class RevivedInputLayer: InputLayer | |||
{ | |||
private RevivedInputLayer(): base(null) | |||
protected RevivedConfig _config = null; | |||
private RevivedInputLayer(InputLayerArgs args): base(args) | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
public override IKerasConfig get_config() | |||
{ | |||
return _config; | |||
} | |||
public static (RevivedInputLayer, Action<object, object, object>) init_from_metadata(KerasMetaData metadata) | |||
{ | |||
InputLayerArgs args = new InputLayerArgs() | |||
{ | |||
Name = metadata.Name, | |||
DType = metadata.DType, | |||
Sparse = metadata.Sparse, | |||
Ragged = metadata.Ragged, | |||
BatchInputShape = metadata.BatchInputShape | |||
}; | |||
RevivedInputLayer revived_obj = new RevivedInputLayer(args); | |||
revived_obj._config = new RevivedConfig() { Config = metadata.Config }; | |||
return (revived_obj, Loader.setattr); | |||
} | |||
public override string ToString() | |||
{ | |||
return $"Customized keras input layer: {Name}."; | |||
} | |||
} | |||
} |
@@ -53,7 +53,7 @@ namespace Tensorflow.Keras.Saving.SavedModel | |||
return (revived_obj, ReviveUtils._revive_setter); | |||
} | |||
private RevivedConfig _config = null; | |||
protected RevivedConfig _config = null; | |||
public object keras_api | |||
{ | |||
@@ -70,7 +70,7 @@ namespace Tensorflow.Keras.Saving.SavedModel | |||
} | |||
} | |||
public RevivedLayer(LayerArgs args): base(args) | |||
protected RevivedLayer(LayerArgs args): base(args) | |||
{ | |||
} | |||
@@ -84,17 +84,5 @@ namespace Tensorflow.Keras.Saving.SavedModel | |||
{ | |||
return _config; | |||
} | |||
//protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | |||
//{ | |||
// if(SerializedAttributes is null || !SerializedAttributes.TryGetValue("__call__", out var func) || func is not Function) | |||
// { | |||
// return base.Call(inputs, state, training); | |||
// } | |||
// else | |||
// { | |||
// return (func as Function).Apply(inputs); | |||
// } | |||
//} | |||
} | |||
} |
@@ -0,0 +1,40 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
using Tensorflow.Keras.ArgsDefinition; | |||
using Tensorflow.Keras.Utils; | |||
namespace Tensorflow.Keras.Saving.SavedModel | |||
{ | |||
public class RevivedNetwork: RevivedLayer | |||
{ | |||
private RevivedNetwork(LayerArgs args) : base(args) | |||
{ | |||
} | |||
public static (RevivedNetwork, Action<object, object, object>) init_from_metadata(KerasMetaData metadata) | |||
{ | |||
RevivedNetwork revived_obj = new(new LayerArgs() { Name = metadata.Name }); | |||
// TODO(Rinne): with utils.no_automatic_dependency_tracking_scope(revived_obj) | |||
// TODO(Rinne): revived_obj._expects_training_arg | |||
var config = metadata.Config; | |||
if (generic_utils.validate_config(config)) | |||
{ | |||
revived_obj._config = new RevivedConfig() { Config = config }; | |||
} | |||
if(metadata.ActivityRegularizer is not null) | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
return (revived_obj, ReviveUtils._revive_setter); | |||
} | |||
public override string ToString() | |||
{ | |||
return $"Customized keras Network: {Name}."; | |||
} | |||
} | |||
} |
@@ -0,0 +1,57 @@ | |||
using System.IO; | |||
using System.Threading.Tasks; | |||
namespace Tensorflow.Hub | |||
{ | |||
public class GcsCompressedFileResolver : IResolver | |||
{ | |||
const int LOCK_FILE_TIMEOUT_SEC = 10 * 60; | |||
public string Call(string handle) | |||
{ | |||
var module_dir = _module_dir(handle); | |||
return resolver.atomic_download_async(handle, download, module_dir, LOCK_FILE_TIMEOUT_SEC) | |||
.GetAwaiter().GetResult(); | |||
} | |||
public bool IsSupported(string handle) | |||
{ | |||
return handle.StartsWith("gs://") && _is_tarfile(handle); | |||
} | |||
private async Task download(string handle, string tmp_dir) | |||
{ | |||
new resolver.DownloadManager(handle).download_and_uncompress( | |||
new FileStream(handle, FileMode.Open, FileAccess.Read), tmp_dir); | |||
await Task.Run(() => { }); | |||
} | |||
private static string _module_dir(string handle) | |||
{ | |||
var cache_dir = resolver.tfhub_cache_dir(use_temp: true); | |||
var sha1 = ComputeSha1(handle); | |||
return resolver.create_local_module_dir(cache_dir, sha1); | |||
} | |||
private static bool _is_tarfile(string filename) | |||
{ | |||
return filename.EndsWith(".tar") || filename.EndsWith(".tar.gz") || filename.EndsWith(".tgz"); | |||
} | |||
private static string ComputeSha1(string s) | |||
{ | |||
using (var sha = new System.Security.Cryptography.SHA1Managed()) | |||
{ | |||
var bytes = System.Text.Encoding.UTF8.GetBytes(s); | |||
var hash = sha.ComputeHash(bytes); | |||
var stringBuilder = new System.Text.StringBuilder(hash.Length * 2); | |||
foreach (var b in hash) | |||
{ | |||
stringBuilder.Append(b.ToString("x2")); | |||
} | |||
return stringBuilder.ToString(); | |||
} | |||
} | |||
} | |||
} |
@@ -0,0 +1,78 @@ | |||
using System; | |||
using System.Net.Http; | |||
using System.Threading.Tasks; | |||
namespace Tensorflow.Hub | |||
{ | |||
public class HttpCompressedFileResolver : HttpResolverBase | |||
{ | |||
const int LOCK_FILE_TIMEOUT_SEC = 10 * 60; // 10 minutes | |||
private static readonly (string, string) _COMPRESSED_FORMAT_QUERY = | |||
("tf-hub-format", "compressed"); | |||
private static string _module_dir(string handle) | |||
{ | |||
var cache_dir = resolver.tfhub_cache_dir(use_temp: true); | |||
var sha1 = ComputeSha1(handle); | |||
return resolver.create_local_module_dir(cache_dir, sha1); | |||
} | |||
public override bool IsSupported(string handle) | |||
{ | |||
if (!is_http_protocol(handle)) | |||
{ | |||
return false; | |||
} | |||
var load_format = resolver.model_load_format(); | |||
return load_format == Enum.GetName(typeof(resolver.ModelLoadFormat), resolver.ModelLoadFormat.COMPRESSED) | |||
|| load_format == Enum.GetName(typeof(resolver.ModelLoadFormat), resolver.ModelLoadFormat.AUTO); | |||
} | |||
public override string Call(string handle) | |||
{ | |||
var module_dir = _module_dir(handle); | |||
return resolver.atomic_download_async( | |||
handle, | |||
download, | |||
module_dir, | |||
LOCK_FILE_TIMEOUT_SEC | |||
).GetAwaiter().GetResult(); | |||
} | |||
private async Task download(string handle, string tmp_dir) | |||
{ | |||
var client = new HttpClient(); | |||
var response = await client.GetAsync(_append_compressed_format_query(handle)); | |||
using (var httpStream = await response.Content.ReadAsStreamAsync()) | |||
{ | |||
new resolver.DownloadManager(handle).download_and_uncompress(httpStream, tmp_dir); | |||
} | |||
} | |||
private string _append_compressed_format_query(string handle) | |||
{ | |||
return append_format_query(handle, _COMPRESSED_FORMAT_QUERY); | |||
} | |||
private static string ComputeSha1(string s) | |||
{ | |||
using (var sha = new System.Security.Cryptography.SHA1Managed()) | |||
{ | |||
var bytes = System.Text.Encoding.UTF8.GetBytes(s); | |||
var hash = sha.ComputeHash(bytes); | |||
var stringBuilder = new System.Text.StringBuilder(hash.Length * 2); | |||
foreach (var b in hash) | |||
{ | |||
stringBuilder.Append(b.ToString("x2")); | |||
} | |||
return stringBuilder.ToString(); | |||
} | |||
} | |||
} | |||
} |
@@ -0,0 +1,65 @@ | |||
using System; | |||
using System.Net; | |||
namespace Tensorflow.Hub | |||
{ | |||
public class HttpUncompressedFileResolver : HttpResolverBase | |||
{ | |||
private readonly PathResolver _pathResolver; | |||
public HttpUncompressedFileResolver() | |||
{ | |||
_pathResolver = new PathResolver(); | |||
} | |||
public override string Call(string handle) | |||
{ | |||
handle = AppendUncompressedFormatQuery(handle); | |||
var gsLocation = RequestGcsLocation(handle); | |||
return _pathResolver.Call(gsLocation); | |||
} | |||
public override bool IsSupported(string handle) | |||
{ | |||
if (!is_http_protocol(handle)) | |||
{ | |||
return false; | |||
} | |||
var load_format = resolver.model_load_format(); | |||
return load_format == Enum.GetName(typeof(resolver.ModelLoadFormat), resolver.ModelLoadFormat.UNCOMPRESSED); | |||
} | |||
protected virtual string AppendUncompressedFormatQuery(string handle) | |||
{ | |||
return append_format_query(handle, ("tf-hub-format", "uncompressed")); | |||
} | |||
protected virtual string RequestGcsLocation(string handleWithParams) | |||
{ | |||
var request = WebRequest.Create(handleWithParams); | |||
var response = request.GetResponse() as HttpWebResponse; | |||
if (response == null) | |||
{ | |||
throw new Exception("Failed to get a response from the server."); | |||
} | |||
var statusCode = (int)response.StatusCode; | |||
if (statusCode != 303) | |||
{ | |||
throw new Exception($"Expected 303 for GCS location lookup but got HTTP {statusCode} {response.StatusDescription}"); | |||
} | |||
var location = response.Headers["Location"]; | |||
if (!location.StartsWith("gs://")) | |||
{ | |||
throw new Exception($"Expected Location:GS path but received {location}"); | |||
} | |||
return location; | |||
} | |||
} | |||
} |
@@ -0,0 +1,157 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Linq; | |||
using Tensorflow.Keras.Engine; | |||
using Tensorflow.Train; | |||
using Tensorflow.Training; | |||
using Tensorflow.Training.Saving.SavedModel; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow.Hub | |||
{ | |||
public class KerasLayer : Layer | |||
{ | |||
private string _handle; | |||
private LoadOptions? _load_options; | |||
private Trackable _func; | |||
private Func<Tensors, Tensors> _callable; | |||
public KerasLayer(string handle, bool trainable = false, LoadOptions? load_options = null) : | |||
base(new Keras.ArgsDefinition.LayerArgs() { Trainable = trainable }) | |||
{ | |||
_handle = handle; | |||
_load_options = load_options; | |||
_func = load_module(_handle, _load_options); | |||
_track_trackable(_func, "_func"); | |||
// TODO(Rinne): deal with _is_hub_module_v1. | |||
_callable = _get_callable(); | |||
_setup_layer(trainable); | |||
} | |||
private void _setup_layer(bool trainable = false) | |||
{ | |||
HashSet<string> trainable_variables; | |||
if (_func is Layer layer) | |||
{ | |||
foreach (var v in layer.TrainableVariables) | |||
{ | |||
_add_existing_weight(v, true); | |||
} | |||
trainable_variables = new HashSet<string>(layer.TrainableVariables.Select(v => v.UniqueId)); | |||
} | |||
else if (_func.CustomizedFields.TryGetValue("trainable_variables", out var obj) && obj is IEnumerable<Trackable> trackables) | |||
{ | |||
foreach (var trackable in trackables) | |||
{ | |||
if (trackable is IVariableV1 v) | |||
{ | |||
_add_existing_weight(v, true); | |||
} | |||
} | |||
trainable_variables = new HashSet<string>(trackables.Where(t => t is IVariableV1).Select(t => (t as IVariableV1).UniqueId)); | |||
} | |||
else | |||
{ | |||
trainable_variables = new HashSet<string>(); | |||
} | |||
if (_func is Layer) | |||
{ | |||
layer = (Layer)_func; | |||
foreach (var v in layer.Variables) | |||
{ | |||
if (!trainable_variables.Contains(v.UniqueId)) | |||
{ | |||
_add_existing_weight(v, false); | |||
} | |||
} | |||
} | |||
else if (_func.CustomizedFields.TryGetValue("variables", out var obj) && obj is IEnumerable<Trackable> total_trackables) | |||
{ | |||
foreach (var trackable in total_trackables) | |||
{ | |||
if (trackable is IVariableV1 v && !trainable_variables.Contains(v.UniqueId)) | |||
{ | |||
_add_existing_weight(v, false); | |||
} | |||
} | |||
} | |||
if (_func.CustomizedFields.ContainsKey("regularization_losses")) | |||
{ | |||
if ((_func.CustomizedFields["regularization_losses"] as ListWrapper)?.Count > 0) | |||
{ | |||
throw new NotImplementedException("The regularization_losses loading has not been supported yet, " + | |||
"please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues to let us know and add a feature."); | |||
} | |||
} | |||
} | |||
protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | |||
{ | |||
_check_trainability(); | |||
// TODO(Rinne): deal with training_argument | |||
var result = _callable(inputs); | |||
return _apply_output_shape_if_set(inputs, result); | |||
} | |||
private void _check_trainability() | |||
{ | |||
if (!Trainable) return; | |||
// TODO(Rinne): deal with _is_hub_module_v1 and signature | |||
if (TrainableWeights is null || TrainableWeights.Count == 0) | |||
{ | |||
tf.Logger.Error("hub.KerasLayer is trainable but has zero trainable weights."); | |||
} | |||
} | |||
private Tensors _apply_output_shape_if_set(Tensors inputs, Tensors result) | |||
{ | |||
// TODO(Rinne): implement it. | |||
return result; | |||
} | |||
private void _add_existing_weight(IVariableV1 weight, bool? trainable = null) | |||
{ | |||
bool is_trainable; | |||
if (trainable is null) | |||
{ | |||
is_trainable = weight.Trainable; | |||
} | |||
else | |||
{ | |||
is_trainable = trainable.Value; | |||
} | |||
add_weight(weight.Name, weight.shape, weight.dtype, trainable: is_trainable, getter: x => weight); | |||
} | |||
private Func<Tensors, Tensors> _get_callable() | |||
{ | |||
if (_func is Layer layer) | |||
{ | |||
return x => layer.Apply(x); | |||
} | |||
if (_func.CustomizedFields.ContainsKey("__call__")) | |||
{ | |||
if (_func.CustomizedFields["__call__"] is RestoredFunction function) | |||
{ | |||
return x => function.Apply(x); | |||
} | |||
} | |||
throw new ValueError("Cannot get the callable from the model."); | |||
} | |||
private static Trackable load_module(string handle, LoadOptions? load_options = null) | |||
{ | |||
//var set_load_options = load_options ?? LoadContext.get_load_option(); | |||
return module_v2.load(handle, load_options); | |||
} | |||
} | |||
} |
@@ -0,0 +1,17 @@ | |||
<Project Sdk="Microsoft.NET.Sdk"> | |||
<PropertyGroup> | |||
<TargetFrameworks>netstandard2.0;net6;net7</TargetFrameworks> | |||
<LangVersion>11</LangVersion> | |||
<Nullable>enable</Nullable> | |||
</PropertyGroup> | |||
<ItemGroup> | |||
<PackageReference Include="SharpCompress" Version="0.33.0" /> | |||
</ItemGroup> | |||
<ItemGroup> | |||
<ProjectReference Include="..\\TensorFlowNET.Keras\Tensorflow.Keras.csproj" /> | |||
</ItemGroup> | |||
</Project> |
@@ -0,0 +1,74 @@ | |||
using SharpCompress.Common; | |||
using SharpCompress.Readers; | |||
using System; | |||
using System.IO; | |||
namespace Tensorflow.Hub | |||
{ | |||
internal static class file_utils | |||
{ | |||
//public static void extract_file(TarInputStream tgz, TarEntry tarInfo, string dstPath, uint bufferSize = 10 << 20, Action<long> logFunction = null) | |||
//{ | |||
// using (var src = tgz.GetNextEntry() == tarInfo ? tgz : null) | |||
// { | |||
// if (src is null) | |||
// { | |||
// return; | |||
// } | |||
// using (var dst = File.Create(dstPath)) | |||
// { | |||
// var buffer = new byte[bufferSize]; | |||
// int count; | |||
// while ((count = src.Read(buffer, 0, buffer.Length)) > 0) | |||
// { | |||
// dst.Write(buffer, 0, count); | |||
// logFunction?.Invoke(count); | |||
// } | |||
// } | |||
// } | |||
//} | |||
public static void extract_tarfile_to_destination(Stream fileobj, string dst_path, Action<long> logFunction = null) | |||
{ | |||
using (IReader reader = ReaderFactory.Open(fileobj)) | |||
{ | |||
while (reader.MoveToNextEntry()) | |||
{ | |||
if (!reader.Entry.IsDirectory) | |||
{ | |||
reader.WriteEntryToDirectory( | |||
dst_path, | |||
new ExtractionOptions() { ExtractFullPath = true, Overwrite = true } | |||
); | |||
} | |||
} | |||
} | |||
} | |||
public static string merge_relative_path(string dstPath, string relPath) | |||
{ | |||
var cleanRelPath = Path.GetFullPath(relPath).TrimStart('/', '\\'); | |||
if (cleanRelPath == ".") | |||
{ | |||
return dstPath; | |||
} | |||
if (cleanRelPath.StartsWith("..") || Path.IsPathRooted(cleanRelPath)) | |||
{ | |||
throw new InvalidDataException($"Relative path '{relPath}' is invalid."); | |||
} | |||
var merged = Path.Combine(dstPath, cleanRelPath); | |||
if (!merged.StartsWith(dstPath)) | |||
{ | |||
throw new InvalidDataException($"Relative path '{relPath}' is invalid. Failed to merge with '{dstPath}'."); | |||
} | |||
return merged; | |||
} | |||
} | |||
} |
@@ -0,0 +1,17 @@ | |||
using Tensorflow.Hub; | |||
namespace Tensorflow | |||
{ | |||
public static class HubAPI | |||
{ | |||
public static HubMethods hub { get; } = new HubMethods(); | |||
} | |||
public class HubMethods | |||
{ | |||
public KerasLayer KerasLayer(string handle, bool trainable = false, LoadOptions? load_options = null) | |||
{ | |||
return new KerasLayer(handle, trainable, load_options); | |||
} | |||
} | |||
} |
@@ -0,0 +1,33 @@ | |||
using System.IO; | |||
using Tensorflow.Train; | |||
namespace Tensorflow.Hub | |||
{ | |||
internal static class module_v2 | |||
{ | |||
public static Trackable load(string handle, LoadOptions? options) | |||
{ | |||
var module_path = resolve(handle); | |||
// TODO(Rinne): deal with is_hub_module_v1 | |||
var saved_model_path = Path.Combine(module_path, Constants.SAVED_MODEL_FILENAME_PB); | |||
var saved_model_pb_txt_path = Path.Combine(module_path, Constants.SAVED_MODEL_FILENAME_PBTXT); | |||
if (!File.Exists(saved_model_path) && !Directory.Exists(saved_model_path) && !File.Exists(saved_model_pb_txt_path) | |||
&& !Directory.Exists(saved_model_pb_txt_path)) | |||
{ | |||
throw new ValueError($"Trying to load a model of incompatible/unknown type. " + | |||
$"'{module_path}' contains neither '{Constants.SAVED_MODEL_FILENAME_PB}' " + | |||
$"nor '{Constants.SAVED_MODEL_FILENAME_PBTXT}'."); | |||
} | |||
var obj = Loader.load(module_path, options: options); | |||
return obj; | |||
} | |||
public static string resolve(string handle) | |||
{ | |||
return MultiImplRegister.GetResolverRegister().Call(handle); | |||
} | |||
} | |||
} |
@@ -0,0 +1,55 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Linq; | |||
namespace Tensorflow.Hub | |||
{ | |||
internal class MultiImplRegister | |||
{ | |||
private static MultiImplRegister resolver = new MultiImplRegister("resolver", new IResolver[0]); | |||
private static MultiImplRegister loader = new MultiImplRegister("loader", new IResolver[0]); | |||
static MultiImplRegister() | |||
{ | |||
resolver.add_implementation(new PathResolver()); | |||
resolver.add_implementation(new HttpUncompressedFileResolver()); | |||
resolver.add_implementation(new GcsCompressedFileResolver()); | |||
resolver.add_implementation(new HttpCompressedFileResolver()); | |||
} | |||
string _name; | |||
List<IResolver> _impls; | |||
public MultiImplRegister(string name, IEnumerable<IResolver> impls) | |||
{ | |||
_name = name; | |||
_impls = impls.ToList(); | |||
} | |||
public void add_implementation(IResolver resolver) | |||
{ | |||
_impls.Add(resolver); | |||
} | |||
public string Call(string handle) | |||
{ | |||
foreach (var impl in _impls.Reverse<IResolver>()) | |||
{ | |||
if (impl.IsSupported(handle)) | |||
{ | |||
return impl.Call(handle); | |||
} | |||
} | |||
throw new RuntimeError($"Cannot resolve the handle {handle}"); | |||
} | |||
public static MultiImplRegister GetResolverRegister() | |||
{ | |||
return resolver; | |||
} | |||
public static MultiImplRegister GetLoaderRegister() | |||
{ | |||
return loader; | |||
} | |||
} | |||
} |
@@ -0,0 +1,580 @@ | |||
using ICSharpCode.SharpZipLib.Tar; | |||
using System; | |||
using System.Collections.Generic; | |||
using System.ComponentModel; | |||
using System.Diagnostics; | |||
using System.IO; | |||
using System.Linq; | |||
using System.Net; | |||
using System.Net.Http; | |||
using System.Net.Security; | |||
using System.Security.Authentication; | |||
using System.Threading.Tasks; | |||
using System.Web; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow.Hub | |||
{ | |||
internal static class resolver | |||
{ | |||
public enum ModelLoadFormat | |||
{ | |||
[Description("COMPRESSED")] | |||
COMPRESSED, | |||
[Description("UNCOMPRESSED")] | |||
UNCOMPRESSED, | |||
[Description("AUTO")] | |||
AUTO | |||
} | |||
public class DownloadManager | |||
{ | |||
private readonly string _url; | |||
private double _last_progress_msg_print_time; | |||
private long _total_bytes_downloaded; | |||
private int _max_prog_str; | |||
private bool _interactive_mode() | |||
{ | |||
return !string.IsNullOrEmpty(Environment.GetEnvironmentVariable("_TFHUB_DOWNLOAD_PROGRESS")); | |||
} | |||
private void _print_download_progress_msg(string msg, bool flush = false) | |||
{ | |||
if (_interactive_mode()) | |||
{ | |||
// Print progress message to console overwriting previous progress | |||
// message. | |||
_max_prog_str = Math.Max(_max_prog_str, msg.Length); | |||
Console.Write($"\r{msg.PadRight(_max_prog_str)}"); | |||
Console.Out.Flush(); | |||
//如果flush参数为true,则输出换行符减少干扰交互式界面。 | |||
if (flush) | |||
Console.WriteLine(); | |||
} | |||
else | |||
{ | |||
// Interactive progress tracking is disabled. Print progress to the | |||
// standard TF log. | |||
tf.Logger.Information(msg); | |||
} | |||
} | |||
private void _log_progress(long bytes_downloaded) | |||
{ | |||
// Logs progress information about ongoing module download. | |||
_total_bytes_downloaded += bytes_downloaded; | |||
var now = DateTime.Now.Ticks / TimeSpan.TicksPerSecond; | |||
if (_interactive_mode() || now - _last_progress_msg_print_time > 15) | |||
{ | |||
// Print progress message every 15 secs or if interactive progress | |||
// tracking is enabled. | |||
_print_download_progress_msg($"Downloading {_url}:" + | |||
$"{tf_utils.bytes_to_readable_str(_total_bytes_downloaded, true)}"); | |||
_last_progress_msg_print_time = now; | |||
} | |||
} | |||
public DownloadManager(string url) | |||
{ | |||
_url = url; | |||
_last_progress_msg_print_time = DateTime.Now.Ticks / TimeSpan.TicksPerSecond; | |||
_total_bytes_downloaded = 0; | |||
_max_prog_str = 0; | |||
} | |||
public void download_and_uncompress(Stream fileobj, string dst_path) | |||
{ | |||
// Streams the content for the 'fileobj' and stores the result in dst_path. | |||
try | |||
{ | |||
file_utils.extract_tarfile_to_destination(fileobj, dst_path, _log_progress); | |||
var total_size_str = tf_utils.bytes_to_readable_str(_total_bytes_downloaded, true); | |||
_print_download_progress_msg($"Downloaded {_url}, Total size: {total_size_str}", flush: true); | |||
} | |||
catch (TarException ex) | |||
{ | |||
throw new IOException($"{_url} does not appear to be a valid module. Inner message:{ex.Message}", ex); | |||
} | |||
} | |||
} | |||
private static Dictionary<string, string> _flags = new(); | |||
private static readonly string _TFHUB_CACHE_DIR = "TFHUB_CACHE_DIR"; | |||
private static readonly string _TFHUB_DOWNLOAD_PROGRESS = "TFHUB_DOWNLOAD_PROGRESS"; | |||
private static readonly string _TFHUB_MODEL_LOAD_FORMAT = "TFHUB_MODEL_LOAD_FORMAT"; | |||
private static readonly string _TFHUB_DISABLE_CERT_VALIDATION = "TFHUB_DISABLE_CERT_VALIDATION"; | |||
private static readonly string _TFHUB_DISABLE_CERT_VALIDATION_VALUE = "true"; | |||
static resolver() | |||
{ | |||
set_new_flag("tfhub_model_load_format", "AUTO"); | |||
set_new_flag("tfhub_cache_dir", null); | |||
} | |||
public static string model_load_format() | |||
{ | |||
return get_env_setting(_TFHUB_MODEL_LOAD_FORMAT, "tfhub_model_load_format"); | |||
} | |||
public static string? get_env_setting(string env_var, string flag_name) | |||
{ | |||
string value = System.Environment.GetEnvironmentVariable(env_var); | |||
if (string.IsNullOrEmpty(value)) | |||
{ | |||
if (_flags.ContainsKey(flag_name)) | |||
{ | |||
return _flags[flag_name]; | |||
} | |||
else | |||
{ | |||
return null; | |||
} | |||
} | |||
else | |||
{ | |||
return value; | |||
} | |||
} | |||
public static string tfhub_cache_dir(string default_cache_dir = null, bool use_temp = false) | |||
{ | |||
var cache_dir = get_env_setting(_TFHUB_CACHE_DIR, "tfhub_cache_dir") ?? default_cache_dir; | |||
if (string.IsNullOrWhiteSpace(cache_dir) && use_temp) | |||
{ | |||
// Place all TF-Hub modules under <system's temp>/tfhub_modules. | |||
cache_dir = Path.Combine(Path.GetTempPath(), "tfhub_modules"); | |||
} | |||
if (!string.IsNullOrWhiteSpace(cache_dir)) | |||
{ | |||
Console.WriteLine("Using {0} to cache modules.", cache_dir); | |||
} | |||
return cache_dir; | |||
} | |||
public static string create_local_module_dir(string cache_dir, string module_name) | |||
{ | |||
Directory.CreateDirectory(cache_dir); | |||
return Path.Combine(cache_dir, module_name); | |||
} | |||
public static void set_new_flag(string name, string value) | |||
{ | |||
string[] tokens = new string[] {_TFHUB_CACHE_DIR, _TFHUB_DISABLE_CERT_VALIDATION, | |||
_TFHUB_DISABLE_CERT_VALIDATION_VALUE, _TFHUB_DOWNLOAD_PROGRESS, _TFHUB_MODEL_LOAD_FORMAT}; | |||
if (!tokens.Contains(name)) | |||
{ | |||
tf.Logger.Warning($"You are settinng a flag '{name}' that cannot be recognized. The flag you set" + | |||
"may not affect anything in tensorflow.hub."); | |||
} | |||
_flags[name] = value; | |||
} | |||
public static string _merge_relative_path(string dstPath, string relPath) | |||
{ | |||
return file_utils.merge_relative_path(dstPath, relPath); | |||
} | |||
public static string _module_descriptor_file(string moduleDir) | |||
{ | |||
return $"{moduleDir}.descriptor.txt"; | |||
} | |||
public static void _write_module_descriptor_file(string handle, string moduleDir) | |||
{ | |||
var readme = _module_descriptor_file(moduleDir); | |||
var content = $"Module: {handle}\nDownload Time: {DateTime.Now}\nDownloader Hostname: {Environment.MachineName} (PID:{Process.GetCurrentProcess().Id})"; | |||
tf_utils.atomic_write_string_to_file(readme, content, overwrite: true); | |||
} | |||
public static string _lock_file_contents(string task_uid) | |||
{ | |||
return $"{Environment.MachineName}.{Process.GetCurrentProcess().Id}.{task_uid}"; | |||
} | |||
public static string _lock_filename(string moduleDir) | |||
{ | |||
return tf_utils.absolute_path(moduleDir) + ".lock"; | |||
} | |||
private static string _module_dir(string lockFilename) | |||
{ | |||
var path = Path.GetDirectoryName(Path.GetFullPath(lockFilename)); | |||
if (!string.IsNullOrEmpty(path)) | |||
{ | |||
return Path.Combine(path, "hub_modules"); | |||
} | |||
throw new Exception("Unable to resolve hub_modules directory from lock file name."); | |||
} | |||
private static string _task_uid_from_lock_file(string lockFilename) | |||
{ | |||
// Returns task UID of the task that created a given lock file. | |||
var lockstring = File.ReadAllText(lockFilename); | |||
return lockstring.Split('.').Last(); | |||
} | |||
private static string _temp_download_dir(string moduleDir, string taskUid) | |||
{ | |||
// Returns the name of a temporary directory to download module to. | |||
return $"{Path.GetFullPath(moduleDir)}.{taskUid}.tmp"; | |||
} | |||
private static long _dir_size(string directory) | |||
{ | |||
// Returns total size (in bytes) of the given 'directory'. | |||
long size = 0; | |||
foreach (var elem in Directory.EnumerateFileSystemEntries(directory)) | |||
{ | |||
var stat = new FileInfo(elem); | |||
size += stat.Length; | |||
if ((stat.Attributes & FileAttributes.Directory) != 0) | |||
size += _dir_size(stat.FullName); | |||
} | |||
return size; | |||
} | |||
public static long _locked_tmp_dir_size(string lockFilename) | |||
{ | |||
//Returns the size of the temp dir pointed to by the given lock file. | |||
var taskUid = _task_uid_from_lock_file(lockFilename); | |||
try | |||
{ | |||
return _dir_size(_temp_download_dir(_module_dir(lockFilename), taskUid)); | |||
} | |||
catch (DirectoryNotFoundException) | |||
{ | |||
return 0; | |||
} | |||
} | |||
private static void _wait_for_lock_to_disappear(string handle, string lockFile, double lockFileTimeoutSec) | |||
{ | |||
long? lockedTmpDirSize = null; | |||
var lockedTmpDirSizeCheckTime = DateTime.Now; | |||
var lockFileContent = ""; | |||
while (File.Exists(lockFile)) | |||
{ | |||
try | |||
{ | |||
Console.WriteLine($"Module '{handle}' already being downloaded by '{File.ReadAllText(lockFile)}'. Waiting."); | |||
if ((DateTime.Now - lockedTmpDirSizeCheckTime).TotalSeconds > lockFileTimeoutSec) | |||
{ | |||
var curLockedTmpDirSize = _locked_tmp_dir_size(lockFile); | |||
var curLockFileContent = File.ReadAllText(lockFile); | |||
if (curLockedTmpDirSize == lockedTmpDirSize && curLockFileContent == lockFileContent) | |||
{ | |||
Console.WriteLine($"Deleting lock file {lockFile} due to inactivity."); | |||
File.Delete(lockFile); | |||
break; | |||
} | |||
lockedTmpDirSize = curLockedTmpDirSize; | |||
lockedTmpDirSizeCheckTime = DateTime.Now; | |||
lockFileContent = curLockFileContent; | |||
} | |||
} | |||
catch (FileNotFoundException) | |||
{ | |||
// Lock file or temp directory were deleted during check. Continue | |||
// to check whether download succeeded or we need to start our own | |||
// download. | |||
} | |||
System.Threading.Thread.Sleep(5000); | |||
} | |||
} | |||
public static async Task<string> atomic_download_async( | |||
string handle, | |||
Func<string, string, Task> downloadFn, | |||
string moduleDir, | |||
int lock_file_timeout_sec = 10 * 60) | |||
{ | |||
var lockFile = _lock_filename(moduleDir); | |||
var taskUid = Guid.NewGuid().ToString("N"); | |||
var lockContents = _lock_file_contents(taskUid); | |||
var tmpDir = _temp_download_dir(moduleDir, taskUid); | |||
// Function to check whether model has already been downloaded. | |||
Func<bool> checkModuleExists = () => | |||
Directory.Exists(moduleDir) && | |||
Directory.EnumerateFileSystemEntries(moduleDir).Any(); | |||
// Check whether the model has already been downloaded before locking | |||
// the destination path. | |||
if (checkModuleExists()) | |||
{ | |||
return moduleDir; | |||
} | |||
// Attempt to protect against cases of processes being cancelled with | |||
// KeyboardInterrupt by using a try/finally clause to remove the lock | |||
// and tmp_dir. | |||
while (true) | |||
{ | |||
try | |||
{ | |||
tf_utils.atomic_write_string_to_file(lockFile, lockContents, false); | |||
// Must test condition again, since another process could have created | |||
// the module and deleted the old lock file since last test. | |||
if (checkModuleExists()) | |||
{ | |||
// Lock file will be deleted in the finally-clause. | |||
return moduleDir; | |||
} | |||
if (Directory.Exists(moduleDir)) | |||
{ | |||
Directory.Delete(moduleDir, true); | |||
} | |||
break; // Proceed to downloading the module. | |||
} | |||
// These errors are believed to be permanent problems with the | |||
// module_dir that justify failing the download. | |||
catch (FileNotFoundException) | |||
{ | |||
throw; | |||
} | |||
catch (UnauthorizedAccessException) | |||
{ | |||
throw; | |||
} | |||
catch (IOException) | |||
{ | |||
throw; | |||
} | |||
// All other errors are retried. | |||
// TODO(b/144424849): Retrying an AlreadyExistsError from the atomic write | |||
// should be good enough, but see discussion about misc filesystem types. | |||
// TODO(b/144475403): How atomic is the overwrite=False check? | |||
catch (Exception) | |||
{ | |||
} | |||
// Wait for lock file to disappear. | |||
_wait_for_lock_to_disappear(handle, lockFile, lock_file_timeout_sec); | |||
// At this point we either deleted a lock or a lock got removed by the | |||
// owner or another process. Perform one more iteration of the while-loop, | |||
// we would either terminate due tf.compat.v1.gfile.Exists(module_dir) or | |||
// because we would obtain a lock ourselves, or wait again for the lock to | |||
// disappear. | |||
} | |||
// Lock file acquired. | |||
tf.Logger.Information($"Downloading TF-Hub Module '{handle}'..."); | |||
Directory.CreateDirectory(tmpDir); | |||
await downloadFn(handle, tmpDir); | |||
// Write module descriptor to capture information about which module was | |||
// downloaded by whom and when. The file stored at the same level as a | |||
// directory in order to keep the content of the 'model_dir' exactly as it | |||
// was define by the module publisher. | |||
// | |||
// Note: The descriptor is written purely to help the end-user to identify | |||
// which directory belongs to which module. The descriptor is not part of the | |||
// module caching protocol and no code in the TF-Hub library reads its | |||
// content. | |||
_write_module_descriptor_file(handle, moduleDir); | |||
try | |||
{ | |||
Directory.Move(tmpDir, moduleDir); | |||
Console.WriteLine($"Downloaded TF-Hub Module '{handle}'."); | |||
} | |||
catch (IOException e) | |||
{ | |||
Console.WriteLine(e.Message); | |||
Console.WriteLine($"Failed to move {tmpDir} to {moduleDir}"); | |||
// Keep the temp directory so we will retry building vocabulary later. | |||
} | |||
// Temp directory is owned by the current process, remove it. | |||
try | |||
{ | |||
Directory.Delete(tmpDir, true); | |||
} | |||
catch (DirectoryNotFoundException) | |||
{ | |||
} | |||
// Lock file exists and is owned by this process. | |||
try | |||
{ | |||
var contents = File.ReadAllText(lockFile); | |||
if (contents == lockContents) | |||
{ | |||
File.Delete(lockFile); | |||
} | |||
} | |||
catch (Exception) | |||
{ | |||
} | |||
return moduleDir; | |||
} | |||
} | |||
internal interface IResolver | |||
{ | |||
string Call(string handle); | |||
bool IsSupported(string handle); | |||
} | |||
internal class PathResolver : IResolver | |||
{ | |||
public string Call(string handle) | |||
{ | |||
if (!File.Exists(handle) && !Directory.Exists(handle)) | |||
{ | |||
throw new IOException($"{handle} does not exist in file system."); | |||
} | |||
return handle; | |||
} | |||
public bool IsSupported(string handle) | |||
{ | |||
return true; | |||
} | |||
} | |||
public abstract class HttpResolverBase : IResolver | |||
{ | |||
private readonly HttpClient httpClient; | |||
private SslProtocol sslProtocol; | |||
private RemoteCertificateValidationCallback certificateValidator; | |||
protected HttpResolverBase() | |||
{ | |||
httpClient = new HttpClient(); | |||
_maybe_disable_cert_validation(); | |||
} | |||
public abstract string Call(string handle); | |||
public abstract bool IsSupported(string handle); | |||
protected async Task<Stream> GetLocalFileStreamAsync(string filePath) | |||
{ | |||
try | |||
{ | |||
var fs = new FileStream(filePath, FileMode.Open, FileAccess.Read); | |||
return await Task.FromResult(fs); | |||
} | |||
catch (Exception ex) | |||
{ | |||
Console.WriteLine($"Failed to read file stream: {ex.Message}"); | |||
return null; | |||
} | |||
} | |||
protected async Task<Stream> GetFileStreamAsync(string filePath) | |||
{ | |||
if (!is_http_protocol(filePath)) | |||
{ | |||
// If filePath is not an HTTP(S) URL, delegate to a file resolver. | |||
return await GetLocalFileStreamAsync(filePath); | |||
} | |||
var request = new HttpRequestMessage(HttpMethod.Get, filePath); | |||
var response = await _call_urlopen(request); | |||
if (response.IsSuccessStatusCode) | |||
{ | |||
return await response.Content.ReadAsStreamAsync(); | |||
} | |||
else | |||
{ | |||
Console.WriteLine($"Failed to fetch file stream: {response.StatusCode} - {response.ReasonPhrase}"); | |||
return null; | |||
} | |||
} | |||
protected void SetUrlContext(SslProtocol protocol, RemoteCertificateValidationCallback validator) | |||
{ | |||
sslProtocol = protocol; | |||
certificateValidator = validator; | |||
} | |||
public static string append_format_query(string handle, (string, string) formatQuery) | |||
{ | |||
var parsed = new Uri(handle); | |||
var queryBuilder = HttpUtility.ParseQueryString(parsed.Query); | |||
queryBuilder.Add(formatQuery.Item1, formatQuery.Item2); | |||
parsed = new UriBuilder(parsed.Scheme, parsed.Host, parsed.Port, parsed.AbsolutePath, | |||
"?" + queryBuilder.ToString()).Uri; | |||
return parsed.ToString(); | |||
} | |||
protected bool is_http_protocol(string handle) | |||
{ | |||
return handle.StartsWith("http://") || handle.StartsWith("https://"); | |||
} | |||
protected async Task<HttpResponseMessage> _call_urlopen(HttpRequestMessage request) | |||
{ | |||
if (sslProtocol != null) | |||
{ | |||
var handler = new HttpClientHandler() | |||
{ | |||
SslProtocols = sslProtocol.AsEnum(), | |||
}; | |||
if (certificateValidator != null) | |||
{ | |||
handler.ServerCertificateCustomValidationCallback = (x, y, z, w) => | |||
{ | |||
return certificateValidator(x, y, z, w); | |||
}; | |||
} | |||
var client = new HttpClient(handler); | |||
return await client.SendAsync(request); | |||
} | |||
else | |||
{ | |||
return await httpClient.SendAsync(request); | |||
} | |||
} | |||
protected void _maybe_disable_cert_validation() | |||
{ | |||
if (Environment.GetEnvironmentVariable("_TFHUB_DISABLE_CERT_VALIDATION") == "_TFHUB_DISABLE_CERT_VALIDATION_VALUE") | |||
{ | |||
ServicePointManager.ServerCertificateValidationCallback = (_, _, _, _) => true; | |||
Console.WriteLine("Disabled certificate validation for resolving handles."); | |||
} | |||
} | |||
} | |||
public class SslProtocol | |||
{ | |||
private readonly string protocolString; | |||
public static readonly SslProtocol Tls = new SslProtocol("TLS"); | |||
public static readonly SslProtocol Tls11 = new SslProtocol("TLS 1.1"); | |||
public static readonly SslProtocol Tls12 = new SslProtocol("TLS 1.2"); | |||
private SslProtocol(string protocolString) | |||
{ | |||
this.protocolString = protocolString; | |||
} | |||
public SslProtocols AsEnum() | |||
{ | |||
switch (protocolString.ToUpper()) | |||
{ | |||
case "TLS": | |||
return SslProtocols.Tls; | |||
case "TLS 1.1": | |||
return SslProtocols.Tls11; | |||
case "TLS 1.2": | |||
return SslProtocols.Tls12; | |||
default: | |||
throw new ArgumentException($"Unknown SSL/TLS protocol: {protocolString}"); | |||
} | |||
} | |||
} | |||
} |
@@ -0,0 +1,80 @@ | |||
using System; | |||
using System.IO; | |||
namespace Tensorflow.Hub | |||
{ | |||
internal class tf_utils | |||
{ | |||
public static string bytes_to_readable_str(long? numBytes, bool includeB = false) | |||
{ | |||
if (numBytes == null) return numBytes.ToString(); | |||
var num = (double)numBytes; | |||
if (num < 1024) | |||
{ | |||
return $"{(long)num}{(includeB ? "B" : "")}"; | |||
} | |||
num /= 1 << 10; | |||
if (num < 1024) | |||
{ | |||
return $"{num:F2}k{(includeB ? "B" : "")}"; | |||
} | |||
num /= 1 << 10; | |||
if (num < 1024) | |||
{ | |||
return $"{num:F2}M{(includeB ? "B" : "")}"; | |||
} | |||
num /= 1 << 10; | |||
return $"{num:F2}G{(includeB ? "B" : "")}"; | |||
} | |||
public static void atomic_write_string_to_file(string filename, string contents, bool overwrite) | |||
{ | |||
var tempPath = $"{filename}.tmp.{Guid.NewGuid():N}"; | |||
using (var fileStream = new FileStream(tempPath, FileMode.Create)) | |||
{ | |||
using (var writer = new StreamWriter(fileStream)) | |||
{ | |||
writer.Write(contents); | |||
writer.Flush(); | |||
} | |||
} | |||
try | |||
{ | |||
if (File.Exists(filename)) | |||
{ | |||
if (overwrite) | |||
{ | |||
File.Delete(filename); | |||
File.Move(tempPath, filename); | |||
} | |||
} | |||
else | |||
{ | |||
File.Move(tempPath, filename); | |||
} | |||
} | |||
catch | |||
{ | |||
File.Delete(tempPath); | |||
throw; | |||
} | |||
} | |||
public static string absolute_path(string path) | |||
{ | |||
if (path.Contains("://")) | |||
{ | |||
return path; | |||
} | |||
return Path.GetFullPath(path); | |||
} | |||
} | |||
} |
@@ -0,0 +1,46 @@ | |||
using static Tensorflow.Binding; | |||
using static Tensorflow.HubAPI; | |||
namespace Tensorflow.Hub.Unittest | |||
{ | |||
[TestClass] | |||
public class KerasLayerTest | |||
{ | |||
[TestMethod] | |||
public void SmallBert() | |||
{ | |||
var layer = hub.KerasLayer("https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-6_H-256_A-4/1"); | |||
var input_type_ids = tf.convert_to_tensor(new int[] { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, | |||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, | |||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, | |||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, | |||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, | |||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, dtype: tf.int32); | |||
input_type_ids = tf.reshape(input_type_ids, (1, 128)); | |||
var input_word_ids = tf.convert_to_tensor(new int[] { 101, 2129, 2024, 2017, 102, 0, 0, 0, 0, 0, 0, | |||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, | |||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, | |||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, | |||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, | |||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, | |||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, | |||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, | |||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, | |||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, | |||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, | |||
0, 0, 0, 0, 0, 0, 0 }, dtype: tf.int32); | |||
input_word_ids = tf.reshape(input_word_ids, (1, 128)); | |||
var input_mask = tf.convert_to_tensor(new int[] { 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, | |||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, | |||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, | |||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, | |||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, | |||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, dtype: dtypes.int32); | |||
input_mask = tf.reshape(input_mask, (1, 128)); | |||
var result = layer.Apply(new Tensors(input_type_ids, input_word_ids, input_mask)); | |||
} | |||
} | |||
} |
@@ -0,0 +1,23 @@ | |||
<Project Sdk="Microsoft.NET.Sdk"> | |||
<PropertyGroup> | |||
<TargetFramework>net7</TargetFramework> | |||
<ImplicitUsings>enable</ImplicitUsings> | |||
<Nullable>enable</Nullable> | |||
<IsPackable>false</IsPackable> | |||
</PropertyGroup> | |||
<ItemGroup> | |||
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="17.3.2" /> | |||
<PackageReference Include="MSTest.TestAdapter" Version="2.2.10" /> | |||
<PackageReference Include="MSTest.TestFramework" Version="2.2.10" /> | |||
<PackageReference Include="coverlet.collector" Version="3.1.2" /> | |||
<PackageReference Include="SciSharp.TensorFlow.Redist" Version="2.11.2" /> | |||
</ItemGroup> | |||
<ItemGroup> | |||
<ProjectReference Include="..\..\src\TensorflowNET.Hub\Tensorflow.Hub.csproj" /> | |||
</ItemGroup> | |||
</Project> |
@@ -0,0 +1 @@ | |||
global using Microsoft.VisualStudio.TestTools.UnitTesting; |