diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs index 560b9536..da90298d 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs @@ -54,9 +54,27 @@ namespace Tensorflow case "Double": Marshal.Copy(nd.Data(), 0, dotHandle, nd.size); break; + case "Byte": + var bb = nd.Data(); + var bytes = Marshal.AllocHGlobal(bb.Length) ; + ulong bytes_len = c_api.TF_StringEncodedSize((ulong)bb.Length); + var dataTypeByte = ToTFDataType(nd.dtype); + // shape + var dims2 = nd.shape.Select(x => (long)x).ToArray(); + + var tfHandle2 = c_api.TF_AllocateTensor(dataTypeByte, + dims2, + nd.ndim, + bytes_len + sizeof(Int64)); + + dotHandle = c_api.TF_TensorData(tfHandle2); + Marshal.WriteInt64(dotHandle, 0); + c_api.TF_StringEncode(bytes, (ulong)bb.Length, dotHandle + sizeof(Int64), bytes_len, status); + return tfHandle2; case "String": - var str = nd.Data()[0]; - ulong dst_len = c_api.TF_StringEncodedSize((ulong)str.Length); + string ss = nd.Data()[0]; + var str = Marshal.StringToHGlobalAnsi(ss); + ulong dst_len = c_api.TF_StringEncodedSize((ulong)ss.Length); var dataType1 = ToTFDataType(nd.dtype); // shape var dims1 = nd.shape.Select(x => (long)x).ToArray(); @@ -68,7 +86,7 @@ namespace Tensorflow dotHandle = c_api.TF_TensorData(tfHandle1); Marshal.WriteInt64(dotHandle, 0); - c_api.TF_StringEncode(str, (ulong)str.Length, dotHandle + sizeof(Int64), dst_len, status); + c_api.TF_StringEncode(str, (ulong)ss.Length, dotHandle + sizeof(Int64), dst_len, status); return tfHandle1; default: throw new NotImplementedException("Marshal.Copy failed."); diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs index d834b608..7cd0e95f 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs @@ -164,6 +164,7 @@ namespace Tensorflow return TF_DataType.TF_FLOAT; case "Double": return TF_DataType.TF_DOUBLE; + case "Byte": case "String": return TF_DataType.TF_STRING; default: diff --git a/src/TensorFlowNET.Core/Tensors/c_api.tensor.cs b/src/TensorFlowNET.Core/Tensors/c_api.tensor.cs index 78b1016f..78b6137a 100644 --- a/src/TensorFlowNET.Core/Tensors/c_api.tensor.cs +++ b/src/TensorFlowNET.Core/Tensors/c_api.tensor.cs @@ -120,7 +120,7 @@ namespace Tensorflow /// TF_Status* /// On success returns the size in bytes of the encoded string. [DllImport(TensorFlowLibName)] - public static extern ulong TF_StringEncode(string src, ulong src_len, IntPtr dst, ulong dst_len, IntPtr status); + public static extern ulong TF_StringEncode(IntPtr src, ulong src_len, IntPtr dst, ulong dst_len, IntPtr status); /// /// Decode a string encoded using TF_StringEncode. diff --git a/src/TensorFlowNET.Core/runtimes/win-x64/native/tensorflow.dll b/src/TensorFlowNET.Core/runtimes/win-x64/native/tensorflow.dll index 583c710f..82e86b4e 100644 Binary files a/src/TensorFlowNET.Core/runtimes/win-x64/native/tensorflow.dll and b/src/TensorFlowNET.Core/runtimes/win-x64/native/tensorflow.dll differ diff --git a/src/TensorFlowNET.Core/tf.cs b/src/TensorFlowNET.Core/tf.cs index a3e19fb1..287ae433 100644 --- a/src/TensorFlowNET.Core/tf.cs +++ b/src/TensorFlowNET.Core/tf.cs @@ -57,5 +57,12 @@ namespace Tensorflow defaultSession = new Session(); return defaultSession; } + + public static Session Session(Graph graph) + { + g = graph; + defaultSession = new Session(); + return defaultSession; + } } } diff --git a/test/TensorFlowNET.Examples/ImageRecognition.cs b/test/TensorFlowNET.Examples/ImageRecognition.cs new file mode 100644 index 00000000..acc8c1d9 --- /dev/null +++ b/test/TensorFlowNET.Examples/ImageRecognition.cs @@ -0,0 +1,31 @@ +using System; +using System.Collections.Generic; +using System.IO; +using System.IO.Compression; +using System.Linq; +using System.Net; +using System.Text; +using Tensorflow; + +namespace TensorFlowNET.Examples +{ + public class ImageRecognition : Python, IExample + { + public void Run() + { + var graph = new Graph(); + // 从文件加载序列化的GraphDef + //导入GraphDef + graph.Import("tmp/tensorflow_inception_graph.pb"); + with(tf.Session(graph), sess => + { + var labels = File.ReadAllLines("tmp/imagenet_comp_graph_label_strings.txt"); + var files = Directory.GetFiles("img"); + foreach(var file in files) + { + var tensor = new Tensor(File.ReadAllBytes(file)); + } + }); + } + } +} diff --git a/test/TensorFlowNET.Examples/python/basic_text_classification.py b/test/TensorFlowNET.Examples/python/basic_text_classification.py new file mode 100644 index 00000000..42824e35 --- /dev/null +++ b/test/TensorFlowNET.Examples/python/basic_text_classification.py @@ -0,0 +1,133 @@ + +from __future__ import absolute_import, division, print_function + +import tensorflow as tf +from tensorflow import keras + +import numpy as np + +print(tf.__version__) + +imdb = keras.datasets.imdb + +(train_data, train_labels), (test_data, test_labels) = imdb.load_data(num_words=10000) + +print("Training entries: {}, labels: {}".format(len(train_data), len(train_labels))) +print(train_data[0]) +len(train_data[0]), len(train_data[1]) + +# A dictionary mapping words to an integer index +word_index = imdb.get_word_index() + +# The first indices are reserved +word_index = {k:(v+3) for k,v in word_index.items()} +word_index[""] = 0 +word_index[""] = 1 +word_index[""] = 2 # unknown +word_index[""] = 3 + +reverse_word_index = dict([(value, key) for (key, value) in word_index.items()]) + +def decode_review(text): + return ' '.join([reverse_word_index.get(i, '?') for i in text]) + +decode_review(train_data[0]) + + +train_data = keras.preprocessing.sequence.pad_sequences(train_data, + value=word_index[""], + padding='post', + maxlen=256) + +test_data = keras.preprocessing.sequence.pad_sequences(test_data, + value=word_index[""], + padding='post', + maxlen=256) + + +print(train_data[0]) + +# input shape is the vocabulary count used for the movie reviews (10,000 words) +vocab_size = 10000 + +model = keras.Sequential() +model.add(keras.layers.Embedding(vocab_size, 16)) +model.add(keras.layers.GlobalAveragePooling1D()) +model.add(keras.layers.Dense(16, activation=tf.nn.relu)) +model.add(keras.layers.Dense(1, activation=tf.nn.sigmoid)) + +model.summary() + +model.compile(optimizer='adam', + loss='binary_crossentropy', + metrics=['accuracy']) + + +x_val = train_data[:10000] +partial_x_train = train_data[10000:] + +y_val = train_labels[:10000] +partial_y_train = train_labels[10000:] + +history = model.fit(partial_x_train, + partial_y_train, + epochs=40, + batch_size=512, + validation_data=(x_val, y_val), + verbose=1) + +results = model.evaluate(test_data, test_labels) + +# serialize model to JSON +model_json = model.to_json() +with open("model.json", "w") as json_file: + json_file.write(model_json) +# serialize weights to HDF5 +model.save_weights("model.h5") +print("Saved model to disk") + +# load json and create model +json_file = open('model.json', 'r') +loaded_model_json = json_file.read() +json_file.close() +loaded_model = model_from_json(loaded_model_json) +# load weights into new model +loaded_model.load_weights("model.h5") +print("Loaded model from disk") + +print(results) + +history_dict = history.history +history_dict.keys() + +import matplotlib.pyplot as plt + +acc = history_dict['acc'] +val_acc = history_dict['val_acc'] +loss = history_dict['loss'] +val_loss = history_dict['val_loss'] + +epochs = range(1, len(acc) + 1) + +# "bo" is for "blue dot" +plt.plot(epochs, loss, 'bo', label='Training loss') +# b is for "solid blue line" +plt.plot(epochs, val_loss, 'b', label='Validation loss') +plt.title('Training and validation loss') +plt.xlabel('Epochs') +plt.ylabel('Loss') +plt.legend() + +plt.show() + + +plt.clf() # clear figure + +plt.plot(epochs, acc, 'bo', label='Training acc') +plt.plot(epochs, val_acc, 'b', label='Validation acc') +plt.title('Training and validation accuracy') +plt.xlabel('Epochs') +plt.ylabel('Accuracy') +plt.legend() + +plt.show() \ No newline at end of file