You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

OpClassifier.cs 2.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051
  using System;
  using System.Collections.Generic;
  using System.IO;
  using System.Linq;
  using System.Text;
  using System.Text.RegularExpressions;
  using System.Threading.Tasks;
  namespace Tensorflow.CodeGen
  {
      /// <summary>
      /// Groups requested op function names by the generated Python module
      /// (<c>gen_*_ops.py</c>) in which both the function and its
      /// <c>_eager_fallback</c> counterpart are defined.
      /// </summary>
      public class OpClassifier
      {
          // Matches generated op modules such as gen_math_ops.py. The module stem may itself
          // contain underscores or digits (e.g. gen_resource_variable_ops.py), and the dot is
          // escaped so that only a literal ".py" extension is accepted.
          private static readonly string _filenamePattern = @"^gen_[a-z0-9_]*_ops\.py$";

          // NOTE(review): not referenced anywhere in this class; kept so existing intent is not lost.
          private static readonly string _pythonFunctionPattern = @"def\s+(\w+\d*\w*)\((?:\s*\w+\s*(?:=\s*[\S]*)*,\s*)*\s*name=None\):";

          private readonly Dictionary<string, HashSet<string>> _opSet = new();

          /// <summary>
          /// Map from module file name without extension (e.g. <c>gen_math_ops</c>) to the set of
          /// requested function names that module defines.
          /// </summary>
          public Dictionary<string, HashSet<string>> OpSet => _opSet;

          /// <summary>
          /// Scans <paramref name="pythonFileFolder"/> for generated op modules and records, for each
          /// name in <paramref name="funcNames"/>, every module that defines it.
          /// </summary>
          /// <param name="pythonFileFolder">Folder containing the <c>gen_*_ops.py</c> files (not searched recursively).</param>
          /// <param name="funcNames">Op function names to classify; enumerated once.</param>
          /// <exception cref="ArgumentNullException">Either argument is <c>null</c>.</exception>
          /// <remarks>
          /// Matching module names and each function name are echoed to the console as progress output.
          /// I/O exceptions from enumerating the folder or reading files propagate to the caller.
          /// </remarks>
          public OpClassifier(string pythonFileFolder, IEnumerable<string> funcNames)
          {
              if (pythonFileFolder is null)
              {
                  throw new ArgumentNullException(nameof(pythonFileFolder));
              }
              if (funcNames is null)
              {
                  throw new ArgumentNullException(nameof(funcNames));
              }

              Dictionary<string, string> fileContentMap = LoadGeneratedModules(pythonFileFolder);

              foreach (var funcName in funcNames)
              {
                  Console.WriteLine(funcName);
                  foreach (var (target, content) in fileContentMap)
                  {
                      if (DefinesOp(content, funcName))
                      {
                          AddOp(target, funcName);
                      }
                  }
              }
          }

          // Reads every top-level file in the folder whose name matches _filenamePattern,
          // keyed by file name without extension.
          private static Dictionary<string, string> LoadGeneratedModules(string folder)
          {
              var contents = new Dictionary<string, string>();
              foreach (FileInfo file in new DirectoryInfo(folder).GetFiles())
              {
                  if (!Regex.IsMatch(file.Name, _filenamePattern))
                  {
                      continue;
                  }

                  Console.WriteLine(file.Name);
                  contents[Path.GetFileNameWithoutExtension(file.Name)] = File.ReadAllText(file.FullName);
              }
              return contents;
          }

          // True when the module text defines the op either publicly ("def name") or as a
          // private wrapper ("def _name"), together with the matching eager fallback.
          private static bool DefinesOp(string content, string funcName) =>
              DefinesWithFallback(content, funcName) || DefinesWithFallback(content, "_" + funcName);

          // Requiring the "_eager_fallback" definition as well keeps a plain prefix hit
          // (e.g. "def add" inside "def add_n") from counting on its own.
          private static bool DefinesWithFallback(string content, string name) =>
              content.Contains($"def {name}") && content.Contains($"def {name}_eager_fallback");

          // Get-or-add the set for the module, then record the function name.
          private void AddOp(string target, string funcName)
          {
              if (!_opSet.TryGetValue(target, out var ops))
              {
                  ops = new HashSet<string>();
                  _opSet[target] = ops;
              }
              ops.Add(funcName);
          }
      }
  }