Browse Source

Merge pull request #1234 from barfeous/bugs/augmentedGraphView

fix: Avoid modifying augmented graph view collection upon traversal
pull/1250/head
Rinne GitHub 1 year ago
parent
commit
8775b0b87a
No known key found for this signature in database GPG Key ID: B5690EEEBB952194
1 changed files with 5 additions and 9 deletions
  1. +5
    -9
      src/TensorFlowNET.Core/Training/Saving/SavedModel/AugmentedGraphView.cs

+ 5
- 9
src/TensorFlowNET.Core/Training/Saving/SavedModel/AugmentedGraphView.cs View File

@@ -88,11 +88,11 @@ public class AugmentedGraphView: ObjectGraphView


public override (IList<Trackable>, IDictionary<Trackable, IEnumerable<TrackableReference>>) breadth_first_traversal() public override (IList<Trackable>, IDictionary<Trackable, IEnumerable<TrackableReference>>) breadth_first_traversal()
{ {
Trackable get_merged_trackable(Trackable x)
void merged_trackable(Trackable x)
{ {
// TODO: complete it with new definitions `Asset` and `TrackableConstant`. // TODO: complete it with new definitions `Asset` and `TrackableConstant`.
return x;
} }

var trackable_objects = base.breadth_first_traversal(); var trackable_objects = base.breadth_first_traversal();


foreach(var obj in _children_cache.Keys) foreach(var obj in _children_cache.Keys)
@@ -100,7 +100,7 @@ public class AugmentedGraphView: ObjectGraphView
// skip the deletion of cache (maybe do it later). // skip the deletion of cache (maybe do it later).
foreach(var pair in _children_cache[obj]) foreach(var pair in _children_cache[obj])
{ {
_children_cache[obj][pair.Key] = get_merged_trackable(pair.Value);
merged_trackable(pair.Value);
} }
} }


@@ -109,15 +109,11 @@ public class AugmentedGraphView: ObjectGraphView


public List<(string, Trackable)> list_dependencies(Trackable obj) public List<(string, Trackable)> list_dependencies(Trackable obj)
{ {
IDictionary<string, Trackable> children;
if (!_children_cache.ContainsKey(obj))
if (!_children_cache.TryGetValue(obj, out var children))
{ {
children= new Dictionary<string, Trackable>(); children= new Dictionary<string, Trackable>();
} }
else
{
children= _children_cache[obj];
}

List<(string, Trackable)> res = new(); List<(string, Trackable)> res = new();
foreach(var pair in obj.deserialization_dependencies(children)) foreach(var pair in obj.deserialization_dependencies(children))
{ {


Loading…
Cancel
Save