@@ -62,6 +62,33 @@ A FIFOQueue that supports batching variable-sized tensors by padding. A `Padding | |||||
A queue implementation that dequeues elements in prioritized order. A `PriorityQueue` has bounded capacity; supports multiple concurrent producers and consumers; and provides exactly-once delivery. A `PriorityQueue` holds a list of up to `capacity` elements. Each element is a fixed-length tuple of tensors whose dtypes are described by `types`, and whose shapes are optionally described by the `shapes` argument. | A queue implementation that dequeues elements in prioritized order. A `PriorityQueue` has bounded capacity; supports multiple concurrent producers and consumers; and provides exactly-once delivery. A `PriorityQueue` holds a list of up to `capacity` elements. Each element is a fixed-length tuple of tensors whose dtypes are described by `types`, and whose shapes are optionally described by the `shapes` argument. | ||||
```csharp | |||||
[TestMethod] | |||||
public void PriorityQueue() | |||||
{ | |||||
var queue = tf.PriorityQueue(3, tf.@string); | |||||
var init = queue.enqueue_many(new[] { 2L, 4L, 3L }, new[] { "p1", "p2", "p3" }); | |||||
var x = queue.dequeue(); | |||||
using (var sess = tf.Session()) | |||||
{ | |||||
init.run(); | |||||
// output will 2, 3, 4 | |||||
var result = sess.run(x); | |||||
Assert.AreEqual(result[0].GetInt64(), 2L); | |||||
result = sess.run(x); | |||||
Assert.AreEqual(result[0].GetInt64(), 3L); | |||||
result = sess.run(x); | |||||
Assert.AreEqual(result[0].GetInt64(), 4L); | |||||
} | |||||
} | |||||
``` | |||||
#### RandomShuffleQueue | #### RandomShuffleQueue | ||||
A queue implementation that dequeues elements in a random order. A `RandomShuffleQueue` has bounded capacity; supports multiple concurrent producers and consumers; and provides exactly-once delivery. A `RandomShuffleQueue` holds a list of up to `capacity` elements. Each element is a fixed-length tuple of tensors whose dtypes are described by `dtypes`, and whose shapes are optionally described by the `shapes` argument. | A queue implementation that dequeues elements in a random order. A `RandomShuffleQueue` has bounded capacity; supports multiple concurrent producers and consumers; and provides exactly-once delivery. A `RandomShuffleQueue` holds a list of up to `capacity` elements. Each element is a fixed-length tuple of tensors whose dtypes are described by `dtypes`, and whose shapes are optionally described by the `shapes` argument. | ||||
@@ -50,7 +50,7 @@ namespace Tensorflow | |||||
string shared_name = null, | string shared_name = null, | ||||
string name = "padding_fifo_queue") | string name = "padding_fifo_queue") | ||||
=> new PaddingFIFOQueue(capacity, | => new PaddingFIFOQueue(capacity, | ||||
new [] { dtype }, | |||||
new[] { dtype }, | |||||
new[] { shape }, | new[] { shape }, | ||||
shared_name: shared_name, | shared_name: shared_name, | ||||
name: name); | name: name); | ||||
@@ -86,7 +86,26 @@ namespace Tensorflow | |||||
=> new FIFOQueue(capacity, | => new FIFOQueue(capacity, | ||||
new[] { dtype }, | new[] { dtype }, | ||||
new[] { shape ?? new TensorShape() }, | new[] { shape ?? new TensorShape() }, | ||||
new[] { name }, | |||||
shared_name: shared_name, | |||||
name: name); | |||||
/// <summary> | |||||
/// Creates a queue that dequeues elements in a first-in first-out order. | |||||
/// </summary> | |||||
/// <param name="capacity"></param> | |||||
/// <param name="dtype"></param> | |||||
/// <param name="shape"></param> | |||||
/// <param name="shared_name"></param> | |||||
/// <param name="name"></param> | |||||
/// <returns></returns> | |||||
public PriorityQueue PriorityQueue(int capacity, | |||||
TF_DataType dtype, | |||||
TensorShape shape = null, | |||||
string shared_name = null, | |||||
string name = "priority_queue") | |||||
=> new PriorityQueue(capacity, | |||||
new[] { dtype }, | |||||
new[] { shape ?? new TensorShape() }, | |||||
shared_name: shared_name, | shared_name: shared_name, | ||||
name: name); | name: name); | ||||
} | } | ||||
@@ -0,0 +1,66 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Linq; | |||||
using System.Text; | |||||
using static Tensorflow.Binding; | |||||
namespace Tensorflow.Queues | |||||
{ | |||||
public class PriorityQueue : QueueBase | |||||
{ | |||||
public PriorityQueue(int capacity, | |||||
TF_DataType[] dtypes, | |||||
TensorShape[] shapes, | |||||
string[] names = null, | |||||
string shared_name = null, | |||||
string name = "priority_queue") | |||||
: base(dtypes: dtypes, shapes: shapes, names: names) | |||||
{ | |||||
_queue_ref = gen_data_flow_ops.priority_queue_v2( | |||||
component_types: dtypes, | |||||
shapes: shapes, | |||||
capacity: capacity, | |||||
shared_name: shared_name, | |||||
name: name); | |||||
_name = _queue_ref.op.name.Split('/').Last(); | |||||
var dtypes1 = dtypes.ToList(); | |||||
dtypes1.Insert(0, TF_DataType.TF_INT64); | |||||
_dtypes = dtypes1.ToArray(); | |||||
var shapes1 = shapes.ToList(); | |||||
shapes1.Insert(0, new TensorShape()); | |||||
_shapes = shapes1.ToArray(); | |||||
} | |||||
public Operation enqueue_many<T>(long[] indexes, T[] vals, string name = null) | |||||
{ | |||||
return tf_with(ops.name_scope(name, $"{_name}_EnqueueMany", vals), scope => | |||||
{ | |||||
var vals_tensor1 = _check_enqueue_dtypes(indexes); | |||||
var vals_tensor2 = _check_enqueue_dtypes(vals); | |||||
var tensors = new List<Tensor>(); | |||||
tensors.AddRange(vals_tensor1); | |||||
tensors.AddRange(vals_tensor2); | |||||
return gen_data_flow_ops.queue_enqueue_many_v2(_queue_ref, tensors.ToArray(), name: scope); | |||||
}); | |||||
} | |||||
public Tensor[] dequeue(string name = null) | |||||
{ | |||||
Tensor[] ret; | |||||
if (name == null) | |||||
name = $"{_name}_Dequeue"; | |||||
if (_queue_ref.dtype == TF_DataType.TF_RESOURCE) | |||||
ret = gen_data_flow_ops.queue_dequeue_v2(_queue_ref, _dtypes, name: name); | |||||
else | |||||
ret = gen_data_flow_ops.queue_dequeue(_queue_ref, _dtypes, name: name); | |||||
return ret; | |||||
} | |||||
} | |||||
} |
@@ -42,7 +42,7 @@ namespace Tensorflow.Queues | |||||
}); | }); | ||||
} | } | ||||
private Tensor[] _check_enqueue_dtypes(object vals) | |||||
protected Tensor[] _check_enqueue_dtypes(object vals) | |||||
{ | { | ||||
var tensors = new List<Tensor>(); | var tensors = new List<Tensor>(); | ||||
@@ -56,12 +56,10 @@ namespace Tensorflow.Queues | |||||
} | } | ||||
break; | break; | ||||
case int[] vals1: | |||||
tensors.Add(ops.convert_to_tensor(vals1, dtype: _dtypes[0], name: $"component_0")); | |||||
break; | |||||
default: | default: | ||||
throw new NotImplementedException(""); | |||||
var dtype1 = GetType().Name == "PriorityQueue" ? _dtypes[1] : _dtypes[0]; | |||||
tensors.Add(ops.convert_to_tensor(vals, dtype: dtype1, name: $"component_0")); | |||||
break; | |||||
} | } | ||||
return tensors.ToArray(); | return tensors.ToArray(); | ||||
@@ -0,0 +1,28 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Linq; | |||||
using System.Text; | |||||
namespace Tensorflow.Queues | |||||
{ | |||||
public class RandomShuffleQueue : QueueBase | |||||
{ | |||||
public RandomShuffleQueue(int capacity, | |||||
TF_DataType[] dtypes, | |||||
TensorShape[] shapes, | |||||
string[] names = null, | |||||
string shared_name = null, | |||||
string name = "randomshuffle_fifo_queue") | |||||
: base(dtypes: dtypes, shapes: shapes, names: names) | |||||
{ | |||||
_queue_ref = gen_data_flow_ops.padding_fifo_queue_v2( | |||||
component_types: dtypes, | |||||
shapes: shapes, | |||||
capacity: capacity, | |||||
shared_name: shared_name, | |||||
name: name); | |||||
_name = _queue_ref.op.name.Split('/').Last(); | |||||
} | |||||
} | |||||
} |
@@ -77,6 +77,22 @@ namespace Tensorflow | |||||
return _op.output; | return _op.output; | ||||
} | } | ||||
public static Tensor priority_queue_v2(TF_DataType[] component_types, TensorShape[] shapes, | |||||
int capacity = -1, string container = "", string shared_name = "", | |||||
string name = null) | |||||
{ | |||||
var _op = _op_def_lib._apply_op_helper("PriorityQueueV2", name, new | |||||
{ | |||||
component_types, | |||||
shapes, | |||||
capacity, | |||||
container, | |||||
shared_name | |||||
}); | |||||
return _op.output; | |||||
} | |||||
public static Operation queue_enqueue(Tensor handle, Tensor[] components, int timeout_ms = -1, string name = null) | public static Operation queue_enqueue(Tensor handle, Tensor[] components, int timeout_ms = -1, string name = null) | ||||
{ | { | ||||
var _op = _op_def_lib._apply_op_helper("QueueEnqueue", name, new | var _op = _op_def_lib._apply_op_helper("QueueEnqueue", name, new | ||||
@@ -10,11 +10,12 @@ namespace Tensorflow.Models.ObjectDetection | |||||
{ | { | ||||
ImageResizerBuilder _image_resizer_builder; | ImageResizerBuilder _image_resizer_builder; | ||||
FasterRCNNFeatureExtractor _feature_extractor; | FasterRCNNFeatureExtractor _feature_extractor; | ||||
AnchorGeneratorBuilder anchor_generator_builder; | |||||
AnchorGeneratorBuilder _anchor_generator_builder; | |||||
public ModelBuilder() | public ModelBuilder() | ||||
{ | { | ||||
_image_resizer_builder = new ImageResizerBuilder(); | _image_resizer_builder = new ImageResizerBuilder(); | ||||
_anchor_generator_builder = new AnchorGeneratorBuilder(); | |||||
} | } | ||||
/// <summary> | /// <summary> | ||||
@@ -51,7 +52,7 @@ namespace Tensorflow.Models.ObjectDetection | |||||
inplace_batchnorm_update: frcnn_config.InplaceBatchnormUpdate); | inplace_batchnorm_update: frcnn_config.InplaceBatchnormUpdate); | ||||
var number_of_stages = frcnn_config.NumberOfStages; | var number_of_stages = frcnn_config.NumberOfStages; | ||||
var first_stage_anchor_generator = anchor_generator_builder.build(frcnn_config.FirstStageAnchorGenerator); | |||||
var first_stage_anchor_generator = _anchor_generator_builder.build(frcnn_config.FirstStageAnchorGenerator); | |||||
var first_stage_atrous_rate = frcnn_config.FirstStageAtrousRate; | var first_stage_atrous_rate = frcnn_config.FirstStageAtrousRate; | ||||
return new FasterRCNNMetaArch(new FasterRCNNInitArgs | return new FasterRCNNMetaArch(new FasterRCNNInitArgs | ||||
@@ -1,143 +1,133 @@ | |||||
{ | |||||
"model": { | |||||
"fasterRcnn": { | |||||
"numClasses": 20, | |||||
"imageResizer": { | |||||
"keepAspectRatioResizer": { | |||||
"minDimension": 600, | |||||
"maxDimension": 1024 | |||||
# Faster R-CNN with Resnet-101 (v1), configured for Pascal VOC Dataset. | |||||
# Users should configure the fine_tune_checkpoint field in the train config as | |||||
# well as the label_map_path and input_path fields in the train_input_reader and | |||||
# eval_input_reader. Search for "PATH_TO_BE_CONFIGURED" to find the fields that | |||||
# should be configured. | |||||
model { | |||||
faster_rcnn { | |||||
num_classes: 20 | |||||
image_resizer { | |||||
keep_aspect_ratio_resizer { | |||||
min_dimension: 600 | |||||
max_dimension: 1024 | |||||
} | |||||
} | |||||
feature_extractor { | |||||
type: 'faster_rcnn_resnet101' | |||||
first_stage_features_stride: 16 | |||||
} | |||||
first_stage_anchor_generator { | |||||
grid_anchor_generator { | |||||
scales: [0.25, 0.5, 1.0, 2.0] | |||||
aspect_ratios: [0.5, 1.0, 2.0] | |||||
height_stride: 16 | |||||
width_stride: 16 | |||||
} | |||||
} | |||||
first_stage_box_predictor_conv_hyperparams { | |||||
op: CONV | |||||
regularizer { | |||||
l2_regularizer { | |||||
weight: 0.0 | |||||
} | } | ||||
}, | |||||
"featureExtractor": { | |||||
"type": "faster_rcnn_resnet101", | |||||
"firstStageFeaturesStride": 16 | |||||
}, | |||||
"firstStageAnchorGenerator": { | |||||
"gridAnchorGenerator": { | |||||
"heightStride": 16, | |||||
"widthStride": 16, | |||||
"scales": [ | |||||
0.25, | |||||
0.5, | |||||
1.0, | |||||
2.0 | |||||
], | |||||
"aspectRatios": [ | |||||
0.5, | |||||
1.0, | |||||
2.0 | |||||
] | |||||
} | |||||
initializer { | |||||
truncated_normal_initializer { | |||||
stddev: 0.01 | |||||
} | } | ||||
}, | |||||
"firstStageBoxPredictorConvHyperparams": { | |||||
"op": "CONV", | |||||
"regularizer": { | |||||
"l2Regularizer": { | |||||
"weight": 0.0 | |||||
} | |||||
}, | |||||
"initializer": { | |||||
"truncatedNormalInitializer": { | |||||
"stddev": 0.009999999776482582 | |||||
} | |||||
} | |||||
first_stage_nms_score_threshold: 0.0 | |||||
first_stage_nms_iou_threshold: 0.7 | |||||
first_stage_max_proposals: 300 | |||||
first_stage_localization_loss_weight: 2.0 | |||||
first_stage_objectness_loss_weight: 1.0 | |||||
initial_crop_size: 14 | |||||
maxpool_kernel_size: 2 | |||||
maxpool_stride: 2 | |||||
second_stage_box_predictor { | |||||
mask_rcnn_box_predictor { | |||||
use_dropout: false | |||||
dropout_keep_probability: 1.0 | |||||
fc_hyperparams { | |||||
op: FC | |||||
regularizer { | |||||
l2_regularizer { | |||||
weight: 0.0 | |||||
} | |||||
} | } | ||||
} | |||||
}, | |||||
"firstStageNmsScoreThreshold": 0.0, | |||||
"firstStageNmsIouThreshold": 0.699999988079071, | |||||
"firstStageMaxProposals": 300, | |||||
"firstStageLocalizationLossWeight": 2.0, | |||||
"firstStageObjectnessLossWeight": 1.0, | |||||
"initialCropSize": 14, | |||||
"maxpoolKernelSize": 2, | |||||
"maxpoolStride": 2, | |||||
"secondStageBoxPredictor": { | |||||
"maskRcnnBoxPredictor": { | |||||
"fcHyperparams": { | |||||
"op": "FC", | |||||
"regularizer": { | |||||
"l2Regularizer": { | |||||
"weight": 0.0 | |||||
} | |||||
}, | |||||
"initializer": { | |||||
"varianceScalingInitializer": { | |||||
"factor": 1.0, | |||||
"uniform": true, | |||||
"mode": "FAN_AVG" | |||||
} | |||||
initializer { | |||||
variance_scaling_initializer { | |||||
factor: 1.0 | |||||
uniform: true | |||||
mode: FAN_AVG | |||||
} | } | ||||
}, | |||||
"useDropout": false, | |||||
"dropoutKeepProbability": 1.0 | |||||
} | |||||
} | } | ||||
}, | |||||
"secondStagePostProcessing": { | |||||
"batchNonMaxSuppression": { | |||||
"scoreThreshold": 0.0, | |||||
"iouThreshold": 0.6000000238418579, | |||||
"maxDetectionsPerClass": 100, | |||||
"maxTotalDetections": 300 | |||||
}, | |||||
"scoreConverter": "SOFTMAX" | |||||
}, | |||||
"secondStageLocalizationLossWeight": 2.0, | |||||
"secondStageClassificationLossWeight": 1.0 | |||||
} | |||||
} | } | ||||
}, | |||||
"trainConfig": { | |||||
"batchSize": 1, | |||||
"dataAugmentationOptions": [ | |||||
{ | |||||
"randomHorizontalFlip": {} | |||||
second_stage_post_processing { | |||||
batch_non_max_suppression { | |||||
score_threshold: 0.0 | |||||
iou_threshold: 0.6 | |||||
max_detections_per_class: 100 | |||||
max_total_detections: 300 | |||||
} | } | ||||
], | |||||
"optimizer": { | |||||
"momentumOptimizer": { | |||||
"learningRate": { | |||||
"manualStepLearningRate": { | |||||
"initialLearningRate": 9.999999747378752e-05, | |||||
"schedule": [ | |||||
{ | |||||
"step": 500000, | |||||
"learningRate": 9.999999747378752e-06 | |||||
}, | |||||
{ | |||||
"step": 700000, | |||||
"learningRate": 9.999999974752427e-07 | |||||
} | |||||
] | |||||
} | |||||
}, | |||||
"momentumOptimizerValue": 0.8999999761581421 | |||||
}, | |||||
"useMovingAverage": false | |||||
}, | |||||
"gradientClippingByNorm": 10.0, | |||||
"fineTuneCheckpoint": "D:/tmp/faster_rcnn_resnet101_coco/model.ckpt", | |||||
"fromDetectionCheckpoint": true, | |||||
"numSteps": 800000 | |||||
}, | |||||
"trainInputReader": { | |||||
"labelMapPath": "D:/Projects/PythonLab/tf-models/research/object_detection/data/pascal_label_map.pbtxt", | |||||
"tfRecordInputReader": { | |||||
"inputPath": [ | |||||
"D:/Projects/PythonLab/tf-models/research/object_detection/data/pascal_train.record" | |||||
] | |||||
score_converter: SOFTMAX | |||||
} | } | ||||
}, | |||||
"evalConfig": { | |||||
"numExamples": 4952 | |||||
}, | |||||
"evalInputReader": [ | |||||
{ | |||||
"labelMapPath": "D:/Projects/PythonLab/tf-models/research/object_detection/data/pascal_label_map.pbtxt", | |||||
"shuffle": false, | |||||
"numReaders": 1, | |||||
"tfRecordInputReader": { | |||||
"inputPath": [ | |||||
"D:/Projects/PythonLab/tf-models/research/object_detection/data/pascal_val.record" | |||||
] | |||||
second_stage_localization_loss_weight: 2.0 | |||||
second_stage_classification_loss_weight: 1.0 | |||||
} | |||||
} | |||||
train_config: { | |||||
batch_size: 1 | |||||
optimizer { | |||||
momentum_optimizer: { | |||||
learning_rate: { | |||||
manual_step_learning_rate { | |||||
initial_learning_rate: 0.0001 | |||||
schedule { | |||||
step: 500000 | |||||
learning_rate: .00001 | |||||
} | |||||
schedule { | |||||
step: 700000 | |||||
learning_rate: .000001 | |||||
} | |||||
} | |||||
} | } | ||||
momentum_optimizer_value: 0.9 | |||||
} | |||||
use_moving_average: false | |||||
} | |||||
gradient_clipping_by_norm: 10.0 | |||||
fine_tune_checkpoint: "PATH_TO_BE_CONFIGURED/model.ckpt" | |||||
from_detection_checkpoint: true | |||||
num_steps: 800000 | |||||
data_augmentation_options { | |||||
random_horizontal_flip { | |||||
} | } | ||||
] | |||||
} | |||||
} | |||||
} | |||||
train_input_reader: { | |||||
tf_record_input_reader { | |||||
input_path: "PATH_TO_BE_CONFIGURED/pascal_train.record" | |||||
} | |||||
label_map_path: "PATH_TO_BE_CONFIGURED/pascal_label_map.pbtxt" | |||||
} | |||||
eval_config: { | |||||
num_examples: 4952 | |||||
} | |||||
eval_input_reader: { | |||||
tf_record_input_reader { | |||||
input_path: "PATH_TO_BE_CONFIGURED/pascal_val.record" | |||||
} | |||||
label_map_path: "PATH_TO_BE_CONFIGURED/pascal_label_map.pbtxt" | |||||
shuffle: false | |||||
num_readers: 1 | |||||
} |
@@ -1,4 +1,5 @@ | |||||
using System; | |||||
using Protobuf.Text; | |||||
using System; | |||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.IO; | using System.IO; | ||||
using System.Text; | using System.Text; | ||||
@@ -10,8 +11,8 @@ namespace Tensorflow.Models.ObjectDetection.Utils | |||||
{ | { | ||||
public static TrainEvalPipelineConfig get_configs_from_pipeline_file(string pipeline_config_path) | public static TrainEvalPipelineConfig get_configs_from_pipeline_file(string pipeline_config_path) | ||||
{ | { | ||||
var json = File.ReadAllText(pipeline_config_path); | |||||
var pipeline_config = TrainEvalPipelineConfig.Parser.ParseJson(json); | |||||
var config = File.ReadAllText(pipeline_config_path); | |||||
var pipeline_config = TrainEvalPipelineConfig.Parser.ParseText(config); | |||||
return pipeline_config; | return pipeline_config; | ||||
} | } | ||||
@@ -4,6 +4,16 @@ | |||||
<TargetFramework>netcoreapp2.2</TargetFramework> | <TargetFramework>netcoreapp2.2</TargetFramework> | ||||
<AssemblyName>TensorFlow.Models</AssemblyName> | <AssemblyName>TensorFlow.Models</AssemblyName> | ||||
<RootNamespace>Tensorflow.Models</RootNamespace> | <RootNamespace>Tensorflow.Models</RootNamespace> | ||||
<Version>0.0.1</Version> | |||||
<Authors>Haiping Chen</Authors> | |||||
<Company>SciSharp STACK</Company> | |||||
<PackageProjectUrl>https://github.com/SciSharp/TensorFlow.NET</PackageProjectUrl> | |||||
<RepositoryUrl>https://github.com/SciSharp/TensorFlow.NET</RepositoryUrl> | |||||
<RepositoryType>git</RepositoryType> | |||||
<PackageTags>TensorFlow</PackageTags> | |||||
<Description>Models and examples built with TensorFlow.</Description> | |||||
<GeneratePackageOnBuild>true</GeneratePackageOnBuild> | |||||
<Copyright>Apache2</Copyright> | |||||
</PropertyGroup> | </PropertyGroup> | ||||
<ItemGroup> | <ItemGroup> | ||||
@@ -16,6 +26,10 @@ | |||||
</Content> | </Content> | ||||
</ItemGroup> | </ItemGroup> | ||||
<ItemGroup> | |||||
<PackageReference Include="Protobuf.Text" Version="0.3.1" /> | |||||
</ItemGroup> | |||||
<ItemGroup> | <ItemGroup> | ||||
<ProjectReference Include="..\TensorFlowNET.Core\TensorFlowNET.Core.csproj" /> | <ProjectReference Include="..\TensorFlowNET.Core\TensorFlowNET.Core.csproj" /> | ||||
</ItemGroup> | </ItemGroup> | ||||
@@ -70,5 +70,27 @@ namespace TensorFlowNET.UnitTest | |||||
// until queue has more element. | // until queue has more element. | ||||
} | } | ||||
} | } | ||||
[TestMethod] | |||||
public void PriorityQueue() | |||||
{ | |||||
var queue = tf.PriorityQueue(3, tf.@string); | |||||
var init = queue.enqueue_many(new[] { 2L, 4L, 3L }, new[] { "p1", "p2", "p3" }); | |||||
var x = queue.dequeue(); | |||||
using (var sess = tf.Session()) | |||||
{ | |||||
init.run(); | |||||
var result = sess.run(x); | |||||
Assert.AreEqual(result[0].GetInt64(), 2L); | |||||
result = sess.run(x); | |||||
Assert.AreEqual(result[0].GetInt64(), 3L); | |||||
result = sess.run(x); | |||||
Assert.AreEqual(result[0].GetInt64(), 4L); | |||||
} | |||||
} | |||||
} | } | ||||
} | } |