Browse Source

Merge pull request #1034 from AsakusaRinne/support_bert_load

Add Tensorflow.NET.Hub
tags/v0.100.5-BERT-load
Haiping GitHub 2 years ago
parent
commit
a767765e93
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
23 changed files with 1433 additions and 24 deletions
  1. +28
    -0
      TensorFlow.NET.sln
  2. +18
    -3
      src/TensorFlowNET.Core/Tensors/Tensors.cs
  3. +11
    -0
      src/TensorFlowNET.Core/Tensors/dtypes.cs
  4. +2
    -2
      src/TensorFlowNET.Keras/Engine/Layer.AddWeights.cs
  5. +4
    -0
      src/TensorFlowNET.Keras/Saving/KerasMetaData.cs
  6. +11
    -2
      src/TensorFlowNET.Keras/Saving/KerasObjectLoader.cs
  7. +34
    -3
      src/TensorFlowNET.Keras/Saving/SavedModel/RevivedInputLayer.cs
  8. +2
    -14
      src/TensorFlowNET.Keras/Saving/SavedModel/RevivedLayer.cs
  9. +40
    -0
      src/TensorFlowNET.Keras/Saving/SavedModel/RevivedNetwork.cs
  10. +57
    -0
      src/TensorflowNET.Hub/GcsCompressedFileResolver.cs
  11. +78
    -0
      src/TensorflowNET.Hub/HttpCompressedFileResolver.cs
  12. +65
    -0
      src/TensorflowNET.Hub/HttpUncompressedFileResolver.cs
  13. +157
    -0
      src/TensorflowNET.Hub/KerasLayer.cs
  14. +17
    -0
      src/TensorflowNET.Hub/Tensorflow.Hub.csproj
  15. +74
    -0
      src/TensorflowNET.Hub/file_utils.cs
  16. +17
    -0
      src/TensorflowNET.Hub/hub.cs
  17. +33
    -0
      src/TensorflowNET.Hub/module_v2.cs
  18. +55
    -0
      src/TensorflowNET.Hub/registry.cs
  19. +580
    -0
      src/TensorflowNET.Hub/resolver.cs
  20. +80
    -0
      src/TensorflowNET.Hub/tf_utils.cs
  21. +46
    -0
      test/TensorflowNET.Hub.Unittest/KerasLayerTest.cs
  22. +23
    -0
      test/TensorflowNET.Hub.Unittest/Tensorflow.Hub.Unittest.csproj
  23. +1
    -0
      test/TensorflowNET.Hub.Unittest/Usings.cs

+ 28
- 0
TensorFlow.NET.sln View File

@@ -23,6 +23,10 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Keras.UnitTest",
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Graph.UnitTest", "test\TensorFlowNET.Graph.UnitTest\TensorFlowNET.Graph.UnitTest.csproj", "{3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Hub", "src\TensorflowNET.Hub\Tensorflow.Hub.csproj", "{9738D16A-CFA0-405C-A7DF-D3D203B0CB18}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Hub.Unittest", "test\TensorflowNET.Hub.Unittest\Tensorflow.Hub.Unittest.csproj", "{7DEA8760-E401-4872-81F3-405F185A13A0}"
EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Any CPU = Debug|Any CPU
@@ -153,6 +157,30 @@ Global
{3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.Release|x64.Build.0 = Release|x64
{3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.Release|x86.ActiveCfg = Release|Any CPU
{3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.Release|x86.Build.0 = Release|Any CPU
{9738D16A-CFA0-405C-A7DF-D3D203B0CB18}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{9738D16A-CFA0-405C-A7DF-D3D203B0CB18}.Debug|Any CPU.Build.0 = Debug|Any CPU
{9738D16A-CFA0-405C-A7DF-D3D203B0CB18}.Debug|x64.ActiveCfg = Debug|Any CPU
{9738D16A-CFA0-405C-A7DF-D3D203B0CB18}.Debug|x64.Build.0 = Debug|Any CPU
{9738D16A-CFA0-405C-A7DF-D3D203B0CB18}.Debug|x86.ActiveCfg = Debug|Any CPU
{9738D16A-CFA0-405C-A7DF-D3D203B0CB18}.Debug|x86.Build.0 = Debug|Any CPU
{9738D16A-CFA0-405C-A7DF-D3D203B0CB18}.Release|Any CPU.ActiveCfg = Release|Any CPU
{9738D16A-CFA0-405C-A7DF-D3D203B0CB18}.Release|Any CPU.Build.0 = Release|Any CPU
{9738D16A-CFA0-405C-A7DF-D3D203B0CB18}.Release|x64.ActiveCfg = Release|Any CPU
{9738D16A-CFA0-405C-A7DF-D3D203B0CB18}.Release|x64.Build.0 = Release|Any CPU
{9738D16A-CFA0-405C-A7DF-D3D203B0CB18}.Release|x86.ActiveCfg = Release|Any CPU
{9738D16A-CFA0-405C-A7DF-D3D203B0CB18}.Release|x86.Build.0 = Release|Any CPU
{7DEA8760-E401-4872-81F3-405F185A13A0}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{7DEA8760-E401-4872-81F3-405F185A13A0}.Debug|Any CPU.Build.0 = Debug|Any CPU
{7DEA8760-E401-4872-81F3-405F185A13A0}.Debug|x64.ActiveCfg = Debug|Any CPU
{7DEA8760-E401-4872-81F3-405F185A13A0}.Debug|x64.Build.0 = Debug|Any CPU
{7DEA8760-E401-4872-81F3-405F185A13A0}.Debug|x86.ActiveCfg = Debug|Any CPU
{7DEA8760-E401-4872-81F3-405F185A13A0}.Debug|x86.Build.0 = Debug|Any CPU
{7DEA8760-E401-4872-81F3-405F185A13A0}.Release|Any CPU.ActiveCfg = Release|Any CPU
{7DEA8760-E401-4872-81F3-405F185A13A0}.Release|Any CPU.Build.0 = Release|Any CPU
{7DEA8760-E401-4872-81F3-405F185A13A0}.Release|x64.ActiveCfg = Release|Any CPU
{7DEA8760-E401-4872-81F3-405F185A13A0}.Release|x64.Build.0 = Release|Any CPU
{7DEA8760-E401-4872-81F3-405F185A13A0}.Release|x86.ActiveCfg = Release|Any CPU
{7DEA8760-E401-4872-81F3-405F185A13A0}.Release|x86.Build.0 = Release|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE


+ 18
- 3
src/TensorFlowNET.Core/Tensors/Tensors.cs View File

@@ -207,9 +207,24 @@ namespace Tensorflow
}

public override string ToString()
=> items.Count() == 1
? items.First().ToString()
: items.Count() + " Tensors" + ". " + string.Join(", ", items.Select(x => x.name));
{
if(items.Count == 1)
{
return items[0].ToString();
}
else
{
StringBuilder sb = new StringBuilder();
sb.Append($"Totally {items.Count} tensors, which are {string.Join(", ", items.Select(x => x.name))}\n[\n");
for(int i = 0; i < items.Count; i++)
{
var tensor = items[i];
sb.Append($"Tensor {i}({tensor.name}): {tensor.ToString()}\n");
}
sb.Append("]\n");
return sb.ToString();
}
}

public void Dispose()
{


+ 11
- 0
src/TensorFlowNET.Core/Tensors/dtypes.cs View File

@@ -301,6 +301,17 @@ namespace Tensorflow
type == TF_DataType.DtInt32Ref || type == TF_DataType.DtInt64Ref;
}

public static bool is_unsigned(this TF_DataType type)
{
return type == TF_DataType.TF_UINT8 || type == TF_DataType.TF_UINT16 || type == TF_DataType.TF_UINT32 ||
type == TF_DataType.TF_UINT64;
}

public static bool is_bool(this TF_DataType type)
{
return type == TF_DataType.TF_BOOL;
}

public static bool is_floating(this TF_DataType type)
{
return type == TF_DataType.TF_HALF || type == TF_DataType.TF_FLOAT || type == TF_DataType.TF_DOUBLE;


+ 2
- 2
src/TensorFlowNET.Keras/Engine/Layer.AddWeights.cs View File

@@ -22,9 +22,9 @@ namespace Tensorflow.Keras.Engine
// If dtype is DT_FLOAT, provide a uniform unit scaling initializer
if (dtype.is_floating())
initializer = tf.glorot_uniform_initializer;
else if (dtype.is_integer())
else if (dtype.is_integer() || dtype.is_unsigned() || dtype.is_bool())
initializer = tf.zeros_initializer;
else
else if(getter is null)
throw new ValueError($"An initializer for variable {name} of type {dtype.as_base_dtype()} is required for layer {name}");
}



+ 4
- 0
src/TensorFlowNET.Keras/Saving/KerasMetaData.cs View File

@@ -36,5 +36,9 @@ namespace Tensorflow.Keras.Saving
public bool? Stateful { get; set; }
[JsonProperty("model_config")]
public KerasModelConfig? ModelConfig { get; set; }
[JsonProperty("sparse")]
public bool Sparse { get; set; }
[JsonProperty("ragged")]
public bool Ragged { get; set; }
}
}

+ 11
- 2
src/TensorFlowNET.Keras/Saving/KerasObjectLoader.cs View File

@@ -401,13 +401,22 @@ namespace Tensorflow.Keras.Saving

private (Trackable, Action<object, object, object>) revive_custom_object(string identifier, KerasMetaData metadata)
{
if(identifier == SavedModel.Constants.LAYER_IDENTIFIER)
if (identifier == SavedModel.Constants.LAYER_IDENTIFIER)
{
return RevivedLayer.init_from_metadata(metadata);
}
else if(identifier == SavedModel.Constants.MODEL_IDENTIFIER || identifier == SavedModel.Constants.SEQUENTIAL_IDENTIFIER
|| identifier == SavedModel.Constants.NETWORK_IDENTIFIER)
{
return RevivedNetwork.init_from_metadata(metadata);
}
else if(identifier == SavedModel.Constants.INPUT_LAYER_IDENTIFIER)
{
return RevivedInputLayer.init_from_metadata(metadata);
}
else
{
throw new NotImplementedException();
throw new ValueError($"Cannot revive the layer {identifier}.");
}
}



+ 34
- 3
src/TensorFlowNET.Keras/Saving/SavedModel/RevivedInputLayer.cs View File

@@ -1,15 +1,46 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Layers;

namespace Tensorflow.Keras.Saving.SavedModel
{
public class RevivedInputLayer: Layer
public class RevivedInputLayer: InputLayer
{
private RevivedInputLayer(): base(null)
protected RevivedConfig _config = null;
private RevivedInputLayer(InputLayerArgs args): base(args)
{
throw new NotImplementedException();
}

public override IKerasConfig get_config()
{
return _config;
}

public static (RevivedInputLayer, Action<object, object, object>) init_from_metadata(KerasMetaData metadata)
{
InputLayerArgs args = new InputLayerArgs()
{
Name = metadata.Name,
DType = metadata.DType,
Sparse = metadata.Sparse,
Ragged = metadata.Ragged,
BatchInputShape = metadata.BatchInputShape
};

RevivedInputLayer revived_obj = new RevivedInputLayer(args);

revived_obj._config = new RevivedConfig() { Config = metadata.Config };

return (revived_obj, Loader.setattr);
}

public override string ToString()
{
return $"Customized keras input layer: {Name}.";
}
}
}

+ 2
- 14
src/TensorFlowNET.Keras/Saving/SavedModel/RevivedLayer.cs View File

@@ -53,7 +53,7 @@ namespace Tensorflow.Keras.Saving.SavedModel
return (revived_obj, ReviveUtils._revive_setter);
}

private RevivedConfig _config = null;
protected RevivedConfig _config = null;

public object keras_api
{
@@ -70,7 +70,7 @@ namespace Tensorflow.Keras.Saving.SavedModel
}
}

public RevivedLayer(LayerArgs args): base(args)
protected RevivedLayer(LayerArgs args): base(args)
{

}
@@ -84,17 +84,5 @@ namespace Tensorflow.Keras.Saving.SavedModel
{
return _config;
}

//protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
//{
// if(SerializedAttributes is null || !SerializedAttributes.TryGetValue("__call__", out var func) || func is not Function)
// {
// return base.Call(inputs, state, training);
// }
// else
// {
// return (func as Function).Apply(inputs);
// }
//}
}
}

+ 40
- 0
src/TensorFlowNET.Keras/Saving/SavedModel/RevivedNetwork.cs View File

@@ -0,0 +1,40 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Utils;

namespace Tensorflow.Keras.Saving.SavedModel
{
public class RevivedNetwork: RevivedLayer
{
private RevivedNetwork(LayerArgs args) : base(args)
{
}

public static (RevivedNetwork, Action<object, object, object>) init_from_metadata(KerasMetaData metadata)
{
RevivedNetwork revived_obj = new(new LayerArgs() { Name = metadata.Name });

// TODO(Rinne): with utils.no_automatic_dependency_tracking_scope(revived_obj)
// TODO(Rinne): revived_obj._expects_training_arg
var config = metadata.Config;
if (generic_utils.validate_config(config))
{
revived_obj._config = new RevivedConfig() { Config = config };
}
if(metadata.ActivityRegularizer is not null)
{
throw new NotImplementedException();
}

return (revived_obj, ReviveUtils._revive_setter);
}

public override string ToString()
{
return $"Customized keras Network: {Name}.";
}
}
}

+ 57
- 0
src/TensorflowNET.Hub/GcsCompressedFileResolver.cs View File

@@ -0,0 +1,57 @@
using System.IO;
using System.Threading.Tasks;

namespace Tensorflow.Hub
{
public class GcsCompressedFileResolver : IResolver
{
const int LOCK_FILE_TIMEOUT_SEC = 10 * 60;
public string Call(string handle)
{
var module_dir = _module_dir(handle);

return resolver.atomic_download_async(handle, download, module_dir, LOCK_FILE_TIMEOUT_SEC)
.GetAwaiter().GetResult();
}
public bool IsSupported(string handle)
{
return handle.StartsWith("gs://") && _is_tarfile(handle);
}

private async Task download(string handle, string tmp_dir)
{
new resolver.DownloadManager(handle).download_and_uncompress(
new FileStream(handle, FileMode.Open, FileAccess.Read), tmp_dir);
await Task.Run(() => { });
}

private static string _module_dir(string handle)
{
var cache_dir = resolver.tfhub_cache_dir(use_temp: true);
var sha1 = ComputeSha1(handle);
return resolver.create_local_module_dir(cache_dir, sha1);
}

private static bool _is_tarfile(string filename)
{
return filename.EndsWith(".tar") || filename.EndsWith(".tar.gz") || filename.EndsWith(".tgz");
}

private static string ComputeSha1(string s)
{
using (var sha = new System.Security.Cryptography.SHA1Managed())
{
var bytes = System.Text.Encoding.UTF8.GetBytes(s);
var hash = sha.ComputeHash(bytes);
var stringBuilder = new System.Text.StringBuilder(hash.Length * 2);

foreach (var b in hash)
{
stringBuilder.Append(b.ToString("x2"));
}

return stringBuilder.ToString();
}
}
}
}

+ 78
- 0
src/TensorflowNET.Hub/HttpCompressedFileResolver.cs View File

@@ -0,0 +1,78 @@
using System;
using System.Net.Http;
using System.Threading.Tasks;

namespace Tensorflow.Hub
{
public class HttpCompressedFileResolver : HttpResolverBase
{
const int LOCK_FILE_TIMEOUT_SEC = 10 * 60; // 10 minutes

private static readonly (string, string) _COMPRESSED_FORMAT_QUERY =
("tf-hub-format", "compressed");

private static string _module_dir(string handle)
{
var cache_dir = resolver.tfhub_cache_dir(use_temp: true);
var sha1 = ComputeSha1(handle);
return resolver.create_local_module_dir(cache_dir, sha1);
}

public override bool IsSupported(string handle)
{
if (!is_http_protocol(handle))
{
return false;
}
var load_format = resolver.model_load_format();
return load_format == Enum.GetName(typeof(resolver.ModelLoadFormat), resolver.ModelLoadFormat.COMPRESSED)
|| load_format == Enum.GetName(typeof(resolver.ModelLoadFormat), resolver.ModelLoadFormat.AUTO);
}

public override string Call(string handle)
{
var module_dir = _module_dir(handle);

return resolver.atomic_download_async(
handle,
download,
module_dir,
LOCK_FILE_TIMEOUT_SEC
).GetAwaiter().GetResult();
}

private async Task download(string handle, string tmp_dir)
{
var client = new HttpClient();

var response = await client.GetAsync(_append_compressed_format_query(handle));

using (var httpStream = await response.Content.ReadAsStreamAsync())
{
new resolver.DownloadManager(handle).download_and_uncompress(httpStream, tmp_dir);
}
}

private string _append_compressed_format_query(string handle)
{
return append_format_query(handle, _COMPRESSED_FORMAT_QUERY);
}

private static string ComputeSha1(string s)
{
using (var sha = new System.Security.Cryptography.SHA1Managed())
{
var bytes = System.Text.Encoding.UTF8.GetBytes(s);
var hash = sha.ComputeHash(bytes);
var stringBuilder = new System.Text.StringBuilder(hash.Length * 2);

foreach (var b in hash)
{
stringBuilder.Append(b.ToString("x2"));
}

return stringBuilder.ToString();
}
}
}
}

+ 65
- 0
src/TensorflowNET.Hub/HttpUncompressedFileResolver.cs View File

@@ -0,0 +1,65 @@
using System;
using System.Net;

namespace Tensorflow.Hub
{
public class HttpUncompressedFileResolver : HttpResolverBase
{
private readonly PathResolver _pathResolver;

public HttpUncompressedFileResolver()
{
_pathResolver = new PathResolver();
}

public override string Call(string handle)
{
handle = AppendUncompressedFormatQuery(handle);
var gsLocation = RequestGcsLocation(handle);
return _pathResolver.Call(gsLocation);
}

public override bool IsSupported(string handle)
{
if (!is_http_protocol(handle))
{
return false;
}

var load_format = resolver.model_load_format();
return load_format == Enum.GetName(typeof(resolver.ModelLoadFormat), resolver.ModelLoadFormat.UNCOMPRESSED);
}

protected virtual string AppendUncompressedFormatQuery(string handle)
{
return append_format_query(handle, ("tf-hub-format", "uncompressed"));
}

protected virtual string RequestGcsLocation(string handleWithParams)
{
var request = WebRequest.Create(handleWithParams);
var response = request.GetResponse() as HttpWebResponse;

if (response == null)
{
throw new Exception("Failed to get a response from the server.");
}

var statusCode = (int)response.StatusCode;

if (statusCode != 303)
{
throw new Exception($"Expected 303 for GCS location lookup but got HTTP {statusCode} {response.StatusDescription}");
}

var location = response.Headers["Location"];

if (!location.StartsWith("gs://"))
{
throw new Exception($"Expected Location:GS path but received {location}");
}

return location;
}
}
}

+ 157
- 0
src/TensorflowNET.Hub/KerasLayer.cs View File

@@ -0,0 +1,157 @@
using System;
using System.Collections.Generic;
using System.Linq;
using Tensorflow.Keras.Engine;
using Tensorflow.Train;
using Tensorflow.Training;
using Tensorflow.Training.Saving.SavedModel;
using static Tensorflow.Binding;

namespace Tensorflow.Hub
{
public class KerasLayer : Layer
{
private string _handle;
private LoadOptions? _load_options;
private Trackable _func;
private Func<Tensors, Tensors> _callable;

public KerasLayer(string handle, bool trainable = false, LoadOptions? load_options = null) :
base(new Keras.ArgsDefinition.LayerArgs() { Trainable = trainable })
{
_handle = handle;
_load_options = load_options;

_func = load_module(_handle, _load_options);
_track_trackable(_func, "_func");
// TODO(Rinne): deal with _is_hub_module_v1.

_callable = _get_callable();
_setup_layer(trainable);
}

private void _setup_layer(bool trainable = false)
{
HashSet<string> trainable_variables;
if (_func is Layer layer)
{
foreach (var v in layer.TrainableVariables)
{
_add_existing_weight(v, true);
}
trainable_variables = new HashSet<string>(layer.TrainableVariables.Select(v => v.UniqueId));
}
else if (_func.CustomizedFields.TryGetValue("trainable_variables", out var obj) && obj is IEnumerable<Trackable> trackables)
{
foreach (var trackable in trackables)
{
if (trackable is IVariableV1 v)
{
_add_existing_weight(v, true);
}
}
trainable_variables = new HashSet<string>(trackables.Where(t => t is IVariableV1).Select(t => (t as IVariableV1).UniqueId));
}
else
{
trainable_variables = new HashSet<string>();
}

if (_func is Layer)
{
layer = (Layer)_func;
foreach (var v in layer.Variables)
{
if (!trainable_variables.Contains(v.UniqueId))
{
_add_existing_weight(v, false);
}
}
}
else if (_func.CustomizedFields.TryGetValue("variables", out var obj) && obj is IEnumerable<Trackable> total_trackables)
{
foreach (var trackable in total_trackables)
{
if (trackable is IVariableV1 v && !trainable_variables.Contains(v.UniqueId))
{
_add_existing_weight(v, false);
}
}
}

if (_func.CustomizedFields.ContainsKey("regularization_losses"))
{
if ((_func.CustomizedFields["regularization_losses"] as ListWrapper)?.Count > 0)
{
throw new NotImplementedException("The regularization_losses loading has not been supported yet, " +
"please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues to let us know and add a feature.");
}
}
}

protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
{
_check_trainability();

// TODO(Rinne): deal with training_argument

var result = _callable(inputs);

return _apply_output_shape_if_set(inputs, result);
}

private void _check_trainability()
{
if (!Trainable) return;

// TODO(Rinne): deal with _is_hub_module_v1 and signature

if (TrainableWeights is null || TrainableWeights.Count == 0)
{
tf.Logger.Error("hub.KerasLayer is trainable but has zero trainable weights.");
}
}

private Tensors _apply_output_shape_if_set(Tensors inputs, Tensors result)
{
// TODO(Rinne): implement it.
return result;
}

private void _add_existing_weight(IVariableV1 weight, bool? trainable = null)
{
bool is_trainable;
if (trainable is null)
{
is_trainable = weight.Trainable;
}
else
{
is_trainable = trainable.Value;
}
add_weight(weight.Name, weight.shape, weight.dtype, trainable: is_trainable, getter: x => weight);
}

private Func<Tensors, Tensors> _get_callable()
{
if (_func is Layer layer)
{
return x => layer.Apply(x);
}
if (_func.CustomizedFields.ContainsKey("__call__"))
{
if (_func.CustomizedFields["__call__"] is RestoredFunction function)
{
return x => function.Apply(x);
}
}
throw new ValueError("Cannot get the callable from the model.");
}

private static Trackable load_module(string handle, LoadOptions? load_options = null)
{
//var set_load_options = load_options ?? LoadContext.get_load_option();
return module_v2.load(handle, load_options);
}
}
}

+ 17
- 0
src/TensorflowNET.Hub/Tensorflow.Hub.csproj View File

@@ -0,0 +1,17 @@
<Project Sdk="Microsoft.NET.Sdk">

<PropertyGroup>
<TargetFrameworks>netstandard2.0;net6;net7</TargetFrameworks>
<LangVersion>11</LangVersion>
<Nullable>enable</Nullable>
</PropertyGroup>

<ItemGroup>
<PackageReference Include="SharpCompress" Version="0.33.0" />
</ItemGroup>

<ItemGroup>
<ProjectReference Include="..\\TensorFlowNET.Keras\Tensorflow.Keras.csproj" />
</ItemGroup>

</Project>

+ 74
- 0
src/TensorflowNET.Hub/file_utils.cs View File

@@ -0,0 +1,74 @@
using SharpCompress.Common;
using SharpCompress.Readers;
using System;
using System.IO;

namespace Tensorflow.Hub
{
internal static class file_utils
{
//public static void extract_file(TarInputStream tgz, TarEntry tarInfo, string dstPath, uint bufferSize = 10 << 20, Action<long> logFunction = null)
//{
// using (var src = tgz.GetNextEntry() == tarInfo ? tgz : null)
// {
// if (src is null)
// {
// return;
// }

// using (var dst = File.Create(dstPath))
// {
// var buffer = new byte[bufferSize];
// int count;

// while ((count = src.Read(buffer, 0, buffer.Length)) > 0)
// {
// dst.Write(buffer, 0, count);
// logFunction?.Invoke(count);
// }
// }
// }
//}

public static void extract_tarfile_to_destination(Stream fileobj, string dst_path, Action<long> logFunction = null)
{
using (IReader reader = ReaderFactory.Open(fileobj))
{
while (reader.MoveToNextEntry())
{
if (!reader.Entry.IsDirectory)
{
reader.WriteEntryToDirectory(
dst_path,
new ExtractionOptions() { ExtractFullPath = true, Overwrite = true }
);
}
}
}
}

public static string merge_relative_path(string dstPath, string relPath)
{
var cleanRelPath = Path.GetFullPath(relPath).TrimStart('/', '\\');

if (cleanRelPath == ".")
{
return dstPath;
}

if (cleanRelPath.StartsWith("..") || Path.IsPathRooted(cleanRelPath))
{
throw new InvalidDataException($"Relative path '{relPath}' is invalid.");
}

var merged = Path.Combine(dstPath, cleanRelPath);

if (!merged.StartsWith(dstPath))
{
throw new InvalidDataException($"Relative path '{relPath}' is invalid. Failed to merge with '{dstPath}'.");
}

return merged;
}
}
}

+ 17
- 0
src/TensorflowNET.Hub/hub.cs View File

@@ -0,0 +1,17 @@
using Tensorflow.Hub;

namespace Tensorflow
{
public static class HubAPI
{
public static HubMethods hub { get; } = new HubMethods();
}

public class HubMethods
{
public KerasLayer KerasLayer(string handle, bool trainable = false, LoadOptions? load_options = null)
{
return new KerasLayer(handle, trainable, load_options);
}
}
}

+ 33
- 0
src/TensorflowNET.Hub/module_v2.cs View File

@@ -0,0 +1,33 @@
using System.IO;
using Tensorflow.Train;

namespace Tensorflow.Hub
{
internal static class module_v2
{
public static Trackable load(string handle, LoadOptions? options)
{
var module_path = resolve(handle);

// TODO(Rinne): deal with is_hub_module_v1

var saved_model_path = Path.Combine(module_path, Constants.SAVED_MODEL_FILENAME_PB);
var saved_model_pb_txt_path = Path.Combine(module_path, Constants.SAVED_MODEL_FILENAME_PBTXT);
if (!File.Exists(saved_model_path) && !Directory.Exists(saved_model_path) && !File.Exists(saved_model_pb_txt_path)
&& !Directory.Exists(saved_model_pb_txt_path))
{
throw new ValueError($"Trying to load a model of incompatible/unknown type. " +
$"'{module_path}' contains neither '{Constants.SAVED_MODEL_FILENAME_PB}' " +
$"nor '{Constants.SAVED_MODEL_FILENAME_PBTXT}'.");
}

var obj = Loader.load(module_path, options: options);
return obj;
}

public static string resolve(string handle)
{
return MultiImplRegister.GetResolverRegister().Call(handle);
}
}
}

+ 55
- 0
src/TensorflowNET.Hub/registry.cs View File

@@ -0,0 +1,55 @@
using System;
using System.Collections.Generic;
using System.Linq;

namespace Tensorflow.Hub
{
internal class MultiImplRegister
{
private static MultiImplRegister resolver = new MultiImplRegister("resolver", new IResolver[0]);
private static MultiImplRegister loader = new MultiImplRegister("loader", new IResolver[0]);

static MultiImplRegister()
{
resolver.add_implementation(new PathResolver());
resolver.add_implementation(new HttpUncompressedFileResolver());
resolver.add_implementation(new GcsCompressedFileResolver());
resolver.add_implementation(new HttpCompressedFileResolver());
}

string _name;
List<IResolver> _impls;
public MultiImplRegister(string name, IEnumerable<IResolver> impls)
{
_name = name;
_impls = impls.ToList();
}

public void add_implementation(IResolver resolver)
{
_impls.Add(resolver);
}

public string Call(string handle)
{
foreach (var impl in _impls.Reverse<IResolver>())
{
if (impl.IsSupported(handle))
{
return impl.Call(handle);
}
}
throw new RuntimeError($"Cannot resolve the handle {handle}");
}

public static MultiImplRegister GetResolverRegister()
{
return resolver;
}

public static MultiImplRegister GetLoaderRegister()
{
return loader;
}
}
}

+ 580
- 0
src/TensorflowNET.Hub/resolver.cs View File

@@ -0,0 +1,580 @@
using ICSharpCode.SharpZipLib.Tar;
using System;
using System.Collections.Generic;
using System.ComponentModel;
using System.Diagnostics;
using System.IO;
using System.Linq;
using System.Net;
using System.Net.Http;
using System.Net.Security;
using System.Security.Authentication;
using System.Threading.Tasks;
using System.Web;
using static Tensorflow.Binding;

namespace Tensorflow.Hub
{
internal static class resolver
{
public enum ModelLoadFormat
{
[Description("COMPRESSED")]
COMPRESSED,
[Description("UNCOMPRESSED")]
UNCOMPRESSED,
[Description("AUTO")]
AUTO
}
public class DownloadManager
{
private readonly string _url;
private double _last_progress_msg_print_time;
private long _total_bytes_downloaded;
private int _max_prog_str;

private bool _interactive_mode()
{
return !string.IsNullOrEmpty(Environment.GetEnvironmentVariable("_TFHUB_DOWNLOAD_PROGRESS"));
}

private void _print_download_progress_msg(string msg, bool flush = false)
{
if (_interactive_mode())
{
// Print progress message to console overwriting previous progress
// message.
_max_prog_str = Math.Max(_max_prog_str, msg.Length);
Console.Write($"\r{msg.PadRight(_max_prog_str)}");
Console.Out.Flush();

//如果flush参数为true,则输出换行符减少干扰交互式界面。
if (flush)
Console.WriteLine();

}
else
{
// Interactive progress tracking is disabled. Print progress to the
// standard TF log.
tf.Logger.Information(msg);
}
}

private void _log_progress(long bytes_downloaded)
{
// Logs progress information about ongoing module download.

_total_bytes_downloaded += bytes_downloaded;
var now = DateTime.Now.Ticks / TimeSpan.TicksPerSecond;
if (_interactive_mode() || now - _last_progress_msg_print_time > 15)
{
// Print progress message every 15 secs or if interactive progress
// tracking is enabled.
_print_download_progress_msg($"Downloading {_url}:" +
$"{tf_utils.bytes_to_readable_str(_total_bytes_downloaded, true)}");
_last_progress_msg_print_time = now;
}
}

public DownloadManager(string url)
{
_url = url;
_last_progress_msg_print_time = DateTime.Now.Ticks / TimeSpan.TicksPerSecond;
_total_bytes_downloaded = 0;
_max_prog_str = 0;
}

public void download_and_uncompress(Stream fileobj, string dst_path)
{
// Streams the content for the 'fileobj' and stores the result in dst_path.

try
{
file_utils.extract_tarfile_to_destination(fileobj, dst_path, _log_progress);
var total_size_str = tf_utils.bytes_to_readable_str(_total_bytes_downloaded, true);
_print_download_progress_msg($"Downloaded {_url}, Total size: {total_size_str}", flush: true);
}
catch (TarException ex)
{
throw new IOException($"{_url} does not appear to be a valid module. Inner message:{ex.Message}", ex);
}
}
}
private static Dictionary<string, string> _flags = new();
private static readonly string _TFHUB_CACHE_DIR = "TFHUB_CACHE_DIR";
private static readonly string _TFHUB_DOWNLOAD_PROGRESS = "TFHUB_DOWNLOAD_PROGRESS";
private static readonly string _TFHUB_MODEL_LOAD_FORMAT = "TFHUB_MODEL_LOAD_FORMAT";
private static readonly string _TFHUB_DISABLE_CERT_VALIDATION = "TFHUB_DISABLE_CERT_VALIDATION";
private static readonly string _TFHUB_DISABLE_CERT_VALIDATION_VALUE = "true";

static resolver()
{
set_new_flag("tfhub_model_load_format", "AUTO");
set_new_flag("tfhub_cache_dir", null);
}

public static string model_load_format()
{
return get_env_setting(_TFHUB_MODEL_LOAD_FORMAT, "tfhub_model_load_format");
}

public static string? get_env_setting(string env_var, string flag_name)
{
string value = System.Environment.GetEnvironmentVariable(env_var);
if (string.IsNullOrEmpty(value))
{
if (_flags.ContainsKey(flag_name))
{
return _flags[flag_name];
}
else
{
return null;
}
}
else
{
return value;
}
}

public static string tfhub_cache_dir(string default_cache_dir = null, bool use_temp = false)
{
var cache_dir = get_env_setting(_TFHUB_CACHE_DIR, "tfhub_cache_dir") ?? default_cache_dir;
if (string.IsNullOrWhiteSpace(cache_dir) && use_temp)
{
// Place all TF-Hub modules under <system's temp>/tfhub_modules.
cache_dir = Path.Combine(Path.GetTempPath(), "tfhub_modules");
}
if (!string.IsNullOrWhiteSpace(cache_dir))
{
Console.WriteLine("Using {0} to cache modules.", cache_dir);
}
return cache_dir;
}

public static string create_local_module_dir(string cache_dir, string module_name)
{
Directory.CreateDirectory(cache_dir);
return Path.Combine(cache_dir, module_name);
}

public static void set_new_flag(string name, string value)
{
string[] tokens = new string[] {_TFHUB_CACHE_DIR, _TFHUB_DISABLE_CERT_VALIDATION,
_TFHUB_DISABLE_CERT_VALIDATION_VALUE, _TFHUB_DOWNLOAD_PROGRESS, _TFHUB_MODEL_LOAD_FORMAT};
if (!tokens.Contains(name))
{
tf.Logger.Warning($"You are settinng a flag '{name}' that cannot be recognized. The flag you set" +
"may not affect anything in tensorflow.hub.");
}
_flags[name] = value;
}

public static string _merge_relative_path(string dstPath, string relPath)
{
return file_utils.merge_relative_path(dstPath, relPath);
}

public static string _module_descriptor_file(string moduleDir)
{
return $"{moduleDir}.descriptor.txt";
}

public static void _write_module_descriptor_file(string handle, string moduleDir)
{
var readme = _module_descriptor_file(moduleDir);
var content = $"Module: {handle}\nDownload Time: {DateTime.Now}\nDownloader Hostname: {Environment.MachineName} (PID:{Process.GetCurrentProcess().Id})";
tf_utils.atomic_write_string_to_file(readme, content, overwrite: true);
}

public static string _lock_file_contents(string task_uid)
{
return $"{Environment.MachineName}.{Process.GetCurrentProcess().Id}.{task_uid}";
}

public static string _lock_filename(string moduleDir)
{
return tf_utils.absolute_path(moduleDir) + ".lock";
}

private static string _module_dir(string lockFilename)
{
var path = Path.GetDirectoryName(Path.GetFullPath(lockFilename));
if (!string.IsNullOrEmpty(path))
{
return Path.Combine(path, "hub_modules");
}

throw new Exception("Unable to resolve hub_modules directory from lock file name.");
}

private static string _task_uid_from_lock_file(string lockFilename)
{
// Returns task UID of the task that created a given lock file.
var lockstring = File.ReadAllText(lockFilename);
return lockstring.Split('.').Last();
}

private static string _temp_download_dir(string moduleDir, string taskUid)
{
// Returns the name of a temporary directory to download module to.
return $"{Path.GetFullPath(moduleDir)}.{taskUid}.tmp";
}

private static long _dir_size(string directory)
{
// Returns total size (in bytes) of the given 'directory'.
long size = 0;
foreach (var elem in Directory.EnumerateFileSystemEntries(directory))
{
var stat = new FileInfo(elem);
size += stat.Length;
if ((stat.Attributes & FileAttributes.Directory) != 0)
size += _dir_size(stat.FullName);
}
return size;
}

public static long _locked_tmp_dir_size(string lockFilename)
{
//Returns the size of the temp dir pointed to by the given lock file.
var taskUid = _task_uid_from_lock_file(lockFilename);
try
{
return _dir_size(_temp_download_dir(_module_dir(lockFilename), taskUid));
}
catch (DirectoryNotFoundException)
{
return 0;
}
}

private static void _wait_for_lock_to_disappear(string handle, string lockFile, double lockFileTimeoutSec)
{
long? lockedTmpDirSize = null;
var lockedTmpDirSizeCheckTime = DateTime.Now;
var lockFileContent = "";

while (File.Exists(lockFile))
{
try
{
Console.WriteLine($"Module '{handle}' already being downloaded by '{File.ReadAllText(lockFile)}'. Waiting.");

if ((DateTime.Now - lockedTmpDirSizeCheckTime).TotalSeconds > lockFileTimeoutSec)
{
var curLockedTmpDirSize = _locked_tmp_dir_size(lockFile);
var curLockFileContent = File.ReadAllText(lockFile);

if (curLockedTmpDirSize == lockedTmpDirSize && curLockFileContent == lockFileContent)
{
Console.WriteLine($"Deleting lock file {lockFile} due to inactivity.");
File.Delete(lockFile);
break;
}

lockedTmpDirSize = curLockedTmpDirSize;
lockedTmpDirSizeCheckTime = DateTime.Now;
lockFileContent = curLockFileContent;
}
}
catch (FileNotFoundException)
{
// Lock file or temp directory were deleted during check. Continue
// to check whether download succeeded or we need to start our own
// download.
}

System.Threading.Thread.Sleep(5000);
}
}

public static async Task<string> atomic_download_async(
string handle,
Func<string, string, Task> downloadFn,
string moduleDir,
int lock_file_timeout_sec = 10 * 60)
{
var lockFile = _lock_filename(moduleDir);
var taskUid = Guid.NewGuid().ToString("N");
var lockContents = _lock_file_contents(taskUid);
var tmpDir = _temp_download_dir(moduleDir, taskUid);

// Function to check whether model has already been downloaded.
Func<bool> checkModuleExists = () =>
Directory.Exists(moduleDir) &&
Directory.EnumerateFileSystemEntries(moduleDir).Any();

// Check whether the model has already been downloaded before locking
// the destination path.
if (checkModuleExists())
{
return moduleDir;
}

// Attempt to protect against cases of processes being cancelled with
// KeyboardInterrupt by using a try/finally clause to remove the lock
// and tmp_dir.
while (true)
{
try
{
tf_utils.atomic_write_string_to_file(lockFile, lockContents, false);
// Must test condition again, since another process could have created
// the module and deleted the old lock file since last test.
if (checkModuleExists())
{
// Lock file will be deleted in the finally-clause.
return moduleDir;
}
if (Directory.Exists(moduleDir))
{
Directory.Delete(moduleDir, true);
}
break; // Proceed to downloading the module.
}
// These errors are believed to be permanent problems with the
// module_dir that justify failing the download.
catch (FileNotFoundException)
{
throw;
}
catch (UnauthorizedAccessException)
{
throw;
}
catch (IOException)
{
throw;
}
// All other errors are retried.
// TODO(b/144424849): Retrying an AlreadyExistsError from the atomic write
// should be good enough, but see discussion about misc filesystem types.
// TODO(b/144475403): How atomic is the overwrite=False check?
catch (Exception)
{
}

// Wait for lock file to disappear.
_wait_for_lock_to_disappear(handle, lockFile, lock_file_timeout_sec);
// At this point we either deleted a lock or a lock got removed by the
// owner or another process. Perform one more iteration of the while-loop,
// we would either terminate due tf.compat.v1.gfile.Exists(module_dir) or
// because we would obtain a lock ourselves, or wait again for the lock to
// disappear.
}

// Lock file acquired.
tf.Logger.Information($"Downloading TF-Hub Module '{handle}'...");
Directory.CreateDirectory(tmpDir);
await downloadFn(handle, tmpDir);
// Write module descriptor to capture information about which module was
// downloaded by whom and when. The file stored at the same level as a
// directory in order to keep the content of the 'model_dir' exactly as it
// was define by the module publisher.
//
// Note: The descriptor is written purely to help the end-user to identify
// which directory belongs to which module. The descriptor is not part of the
// module caching protocol and no code in the TF-Hub library reads its
// content.
_write_module_descriptor_file(handle, moduleDir);
try
{
Directory.Move(tmpDir, moduleDir);
Console.WriteLine($"Downloaded TF-Hub Module '{handle}'.");
}
catch (IOException e)
{
Console.WriteLine(e.Message);
Console.WriteLine($"Failed to move {tmpDir} to {moduleDir}");
// Keep the temp directory so we will retry building vocabulary later.
}

// Temp directory is owned by the current process, remove it.
try
{
Directory.Delete(tmpDir, true);
}
catch (DirectoryNotFoundException)
{
}

// Lock file exists and is owned by this process.
try
{
var contents = File.ReadAllText(lockFile);
if (contents == lockContents)
{
File.Delete(lockFile);
}
}
catch (Exception)
{
}

return moduleDir;
}
}
internal interface IResolver
{
string Call(string handle);
bool IsSupported(string handle);
}

internal class PathResolver : IResolver
{
public string Call(string handle)
{
if (!File.Exists(handle) && !Directory.Exists(handle))
{
throw new IOException($"{handle} does not exist in file system.");
}
return handle;
}
public bool IsSupported(string handle)
{
return true;
}
}

public abstract class HttpResolverBase : IResolver
{
private readonly HttpClient httpClient;
private SslProtocol sslProtocol;
private RemoteCertificateValidationCallback certificateValidator;

protected HttpResolverBase()
{
httpClient = new HttpClient();
_maybe_disable_cert_validation();
}

public abstract string Call(string handle);
public abstract bool IsSupported(string handle);

protected async Task<Stream> GetLocalFileStreamAsync(string filePath)
{
try
{
var fs = new FileStream(filePath, FileMode.Open, FileAccess.Read);
return await Task.FromResult(fs);
}
catch (Exception ex)
{
Console.WriteLine($"Failed to read file stream: {ex.Message}");
return null;
}
}

protected async Task<Stream> GetFileStreamAsync(string filePath)
{
if (!is_http_protocol(filePath))
{
// If filePath is not an HTTP(S) URL, delegate to a file resolver.
return await GetLocalFileStreamAsync(filePath);
}

var request = new HttpRequestMessage(HttpMethod.Get, filePath);
var response = await _call_urlopen(request);

if (response.IsSuccessStatusCode)
{
return await response.Content.ReadAsStreamAsync();
}
else
{
Console.WriteLine($"Failed to fetch file stream: {response.StatusCode} - {response.ReasonPhrase}");
return null;
}
}

protected void SetUrlContext(SslProtocol protocol, RemoteCertificateValidationCallback validator)
{
sslProtocol = protocol;
certificateValidator = validator;
}

public static string append_format_query(string handle, (string, string) formatQuery)
{
var parsed = new Uri(handle);

var queryBuilder = HttpUtility.ParseQueryString(parsed.Query);
queryBuilder.Add(formatQuery.Item1, formatQuery.Item2);

parsed = new UriBuilder(parsed.Scheme, parsed.Host, parsed.Port, parsed.AbsolutePath,
"?" + queryBuilder.ToString()).Uri;

return parsed.ToString();
}

protected bool is_http_protocol(string handle)
{
return handle.StartsWith("http://") || handle.StartsWith("https://");
}

protected async Task<HttpResponseMessage> _call_urlopen(HttpRequestMessage request)
{
if (sslProtocol != null)
{
var handler = new HttpClientHandler()
{
SslProtocols = sslProtocol.AsEnum(),
};
if (certificateValidator != null)
{
handler.ServerCertificateCustomValidationCallback = (x, y, z, w) =>
{
return certificateValidator(x, y, z, w);
};
}

var client = new HttpClient(handler);
return await client.SendAsync(request);
}
else
{
return await httpClient.SendAsync(request);
}
}

protected void _maybe_disable_cert_validation()
{
if (Environment.GetEnvironmentVariable("_TFHUB_DISABLE_CERT_VALIDATION") == "_TFHUB_DISABLE_CERT_VALIDATION_VALUE")
{
ServicePointManager.ServerCertificateValidationCallback = (_, _, _, _) => true;
Console.WriteLine("Disabled certificate validation for resolving handles.");
}
}
}

public class SslProtocol
{
private readonly string protocolString;

public static readonly SslProtocol Tls = new SslProtocol("TLS");
public static readonly SslProtocol Tls11 = new SslProtocol("TLS 1.1");
public static readonly SslProtocol Tls12 = new SslProtocol("TLS 1.2");

private SslProtocol(string protocolString)
{
this.protocolString = protocolString;
}

public SslProtocols AsEnum()
{
switch (protocolString.ToUpper())
{
case "TLS":
return SslProtocols.Tls;
case "TLS 1.1":
return SslProtocols.Tls11;
case "TLS 1.2":
return SslProtocols.Tls12;
default:
throw new ArgumentException($"Unknown SSL/TLS protocol: {protocolString}");
}
}
}
}

+ 80
- 0
src/TensorflowNET.Hub/tf_utils.cs View File

@@ -0,0 +1,80 @@
using System;
using System.IO;

namespace Tensorflow.Hub
{
internal class tf_utils
{
public static string bytes_to_readable_str(long? numBytes, bool includeB = false)
{
if (numBytes == null) return numBytes.ToString();

var num = (double)numBytes;

if (num < 1024)
{
return $"{(long)num}{(includeB ? "B" : "")}";
}

num /= 1 << 10;
if (num < 1024)
{
return $"{num:F2}k{(includeB ? "B" : "")}";
}

num /= 1 << 10;
if (num < 1024)
{
return $"{num:F2}M{(includeB ? "B" : "")}";
}

num /= 1 << 10;
return $"{num:F2}G{(includeB ? "B" : "")}";
}

public static void atomic_write_string_to_file(string filename, string contents, bool overwrite)
{
var tempPath = $"{filename}.tmp.{Guid.NewGuid():N}";

using (var fileStream = new FileStream(tempPath, FileMode.Create))
{
using (var writer = new StreamWriter(fileStream))
{
writer.Write(contents);
writer.Flush();
}
}

try
{
if (File.Exists(filename))
{
if (overwrite)
{
File.Delete(filename);
File.Move(tempPath, filename);
}
}
else
{
File.Move(tempPath, filename);
}
}
catch
{
File.Delete(tempPath);
throw;
}
}

public static string absolute_path(string path)
{
if (path.Contains("://"))
{
return path;
}

return Path.GetFullPath(path);
}
}
}

+ 46
- 0
test/TensorflowNET.Hub.Unittest/KerasLayerTest.cs View File

@@ -0,0 +1,46 @@
using static Tensorflow.Binding;
using static Tensorflow.HubAPI;

namespace Tensorflow.Hub.Unittest
{
[TestClass]
public class KerasLayerTest
{
[TestMethod]
public void SmallBert()
{
var layer = hub.KerasLayer("https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-6_H-256_A-4/1");

var input_type_ids = tf.convert_to_tensor(new int[] { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, dtype: tf.int32);
input_type_ids = tf.reshape(input_type_ids, (1, 128));
var input_word_ids = tf.convert_to_tensor(new int[] { 101, 2129, 2024, 2017, 102, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0 }, dtype: tf.int32);
input_word_ids = tf.reshape(input_word_ids, (1, 128));
var input_mask = tf.convert_to_tensor(new int[] { 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, dtype: dtypes.int32);
input_mask = tf.reshape(input_mask, (1, 128));

var result = layer.Apply(new Tensors(input_type_ids, input_word_ids, input_mask));
}

}
}

+ 23
- 0
test/TensorflowNET.Hub.Unittest/Tensorflow.Hub.Unittest.csproj View File

@@ -0,0 +1,23 @@
<Project Sdk="Microsoft.NET.Sdk">

<PropertyGroup>
<TargetFramework>net7</TargetFramework>
<ImplicitUsings>enable</ImplicitUsings>
<Nullable>enable</Nullable>

<IsPackable>false</IsPackable>
</PropertyGroup>

<ItemGroup>
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="17.3.2" />
<PackageReference Include="MSTest.TestAdapter" Version="2.2.10" />
<PackageReference Include="MSTest.TestFramework" Version="2.2.10" />
<PackageReference Include="coverlet.collector" Version="3.1.2" />
<PackageReference Include="SciSharp.TensorFlow.Redist" Version="2.11.2" />
</ItemGroup>

<ItemGroup>
<ProjectReference Include="..\..\src\TensorflowNET.Hub\Tensorflow.Hub.csproj" />
</ItemGroup>

</Project>

+ 1
- 0
test/TensorflowNET.Hub.Unittest/Usings.cs View File

@@ -0,0 +1 @@
global using Microsoft.VisualStudio.TestTools.UnitTesting;

Loading…
Cancel
Save