diff --git a/src/TensorFlowNET.Core/Graphs/Graph.cs b/src/TensorFlowNET.Core/Graphs/Graph.cs index ab228c47..c84684e6 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.cs @@ -160,7 +160,14 @@ namespace Tensorflow } else if (!name.Contains(":") & !allow_operation) { - throw new NotImplementedException("_as_graph_element_locked"); + // Looks like an Operation name but can't be an Operation. + if (_nodes_by_name.ContainsKey(name)) + // Yep, it's an Operation name + throw new ValueError($"The name {name} refers to an Operation, not a {types_str}."); + else + throw new ValueError( + $"The name {name} looks like an (invalid) Operation name, not a {types_str}" + + " Tensor names must be of the form \":\"."); } } diff --git a/test/TensorFlowNET.Examples/TextProcess/TextClassificationTrain.cs b/test/TensorFlowNET.Examples/TextProcess/TextClassificationTrain.cs index 1c1d84bb..a8eb01c9 100644 --- a/test/TensorFlowNET.Examples/TextProcess/TextClassificationTrain.cs +++ b/test/TensorFlowNET.Examples/TextProcess/TextClassificationTrain.cs @@ -79,13 +79,13 @@ namespace TensorFlowNET.Examples.CnnTextClassification var num_batches_per_epoch = (len(train_x) - 1) / BATCH_SIZE + 1; double max_accuracy = 0; - Tensor is_training = graph.get_operation_by_name("is_training"); - Tensor model_x = graph.get_operation_by_name("x"); - Tensor model_y = graph.get_operation_by_name("y"); - Tensor loss = graph.get_operation_by_name("loss/value"); - Tensor optimizer = graph.get_operation_by_name("loss/optimizer"); - Tensor global_step = graph.get_operation_by_name("global_step"); - Tensor accuracy = graph.get_operation_by_name("accuracy/value"); + Tensor is_training = graph.get_tensor_by_name("is_training:0"); + Tensor model_x = graph.get_tensor_by_name("x:0"); + Tensor model_y = graph.get_tensor_by_name("y:0"); + Tensor loss = graph.get_tensor_by_name("loss/value:0"); + Tensor optimizer = graph.get_tensor_by_name("loss/optimizer:0"); + Tensor global_step = graph.get_tensor_by_name("global_step:0"); + Tensor accuracy = graph.get_tensor_by_name("accuracy/value:0"); stopwatch = Stopwatch.StartNew(); int i = 0; foreach (var (x_batch, y_batch, total) in train_batches)