|
- # Copyright 2020 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
- """BaseNetwork module."""
- import os
-
- from jinja2 import Template
-
- from mindinsight.wizard.base.source_file import SourceFile
-
-
def render_template(template_file_path, context):
    """Render a Jinja2 template file with the given context.

    The file is read as UTF-8. The template is built with block trimming and
    leading-whitespace stripping enabled, and the final newline is kept.

    Args:
        template_file_path (str): Path of the template file to read.
        context (dict): Variables made available to the template.

    Returns:
        str: The rendered text.
    """
    with open(template_file_path, encoding='utf-8') as template_fp:
        source = template_fp.read()
    template_options = dict(trim_blocks=True, lstrip_blocks=True, keep_trailing_newline=True)
    return Template(source, **template_options).render(context)
-
-
class TemplateManager:
    """Collect template files under a directory and turn them into SourceFile objects.

    Output files whose (suffix-replaced) path ends with one of the requested
    extensions are rendered through Jinja2; every other file is copied verbatim.
    """
    # (template_suffix, output_suffix) pairs; the first matching pair is applied.
    replace_template_suffixes = [('.py-tpl', '.py')]

    def __init__(self, template_base_dir, exclude_dirs=None, exclude_files=None):
        """Create a manager rooted at template_base_dir.

        Args:
            template_base_dir (str): Root directory walked for template files.
            exclude_dirs (Iterable[str], optional): Directories to skip, either
                relative to template_base_dir or already prefixed with it.
            exclude_files (Iterable[str], optional): Files to skip, same conventions.
        """
        self.template_base_dir = template_base_dir
        self.exclude_dirs = self._get_exclude_paths(template_base_dir, exclude_dirs)
        self.exclude_files = self._get_exclude_paths(template_base_dir, exclude_files)

    @staticmethod
    def _get_exclude_paths(base_dir, exclude_paths):
        """Resolve exclude paths against base_dir and normalize them.

        A path that is base_dir itself or lies under it is kept; anything else is
        joined onto base_dir (os.path.join leaves absolute paths unchanged).
        Every result goes through os.path.normpath. Without that, a trailing
        separator or a './' segment would stop it from matching the paths
        built in get_template_files.

        Args:
            base_dir (str): Template root directory.
            exclude_paths (Iterable[str] or None): Paths to exclude.

        Returns:
            list[str]: Normalized exclude paths; empty when exclude_paths is None.
        """
        if exclude_paths is None:
            return []

        norm_base = os.path.normpath(base_dir)
        # Compare whole path components so that 'templates_old' is not mistaken
        # for a path inside 'templates'.
        base_prefix = norm_base.rstrip(os.sep) + os.sep
        exclude_abs_paths = []
        for exclude_path in exclude_paths:
            norm_path = os.path.normpath(exclude_path)
            if norm_path == norm_base or norm_path.startswith(base_prefix):
                exclude_abs_paths.append(norm_path)
            else:
                exclude_abs_paths.append(os.path.normpath(os.path.join(base_dir, exclude_path)))
        return exclude_abs_paths

    def get_template_files(self):
        """Walk template_base_dir and return the paths of all non-excluded files.

        The walk skips these directories:
        - hidden ones (name starts with '.');
        - '__pycache__';
        - anything listed in exclude_dirs.

        Files listed in exclude_files are skipped as well. Each returned path is
        os.path.join(root, filename) as produced by os.walk, so it starts with
        template_base_dir.

        Returns:
            list[str]: Template file paths in os.walk order.
        """
        # Normalize here as well, in case the attributes were reassigned after __init__.
        exclude_dirs = {os.path.normpath(path) for path in self.exclude_dirs}
        exclude_files = {os.path.normpath(path) for path in self.exclude_files}
        template_files = []
        for root, sub_dirs, files in os.walk(self.template_base_dir):
            # Slice assignment modifies the list in place, so os.walk does not
            # descend into the pruned directories.
            sub_dirs[:] = [
                sub_dir for sub_dir in sub_dirs
                if not (sub_dir.startswith('.')
                        or sub_dir == '__pycache__'
                        or os.path.normpath(os.path.join(root, sub_dir)) in exclude_dirs)
            ]
            for filename in files:
                template_file_path = os.path.join(root, filename)
                if os.path.normpath(template_file_path) not in exclude_files:
                    template_files.append(template_file_path)
        return template_files

    @staticmethod
    def _normalize_extensions(extensions):
        """Return extensions as a tuple usable with str.endswith.

        A bare string is wrapped rather than iterated. Otherwise '.py' would
        become ('.', 'p', 'y') and match any name ending in 'p' or 'y'.
        """
        if isinstance(extensions, str):
            return (extensions,)
        return tuple(extensions)

    def render(self, **options):
        """Build a SourceFile for every template file.

        Args:
            **options: Template context passed to render_template. The optional
                'extensions' key (a str or an iterable of str, default '.py')
                selects which output files are rendered. All other files are
                copied verbatim.

        Returns:
            list[SourceFile]: One entry per template file, in get_template_files order.
        """
        source_files = []
        extensions = self._normalize_extensions(options.get('extensions', '.py'))
        for template_file_path in self.get_template_files():
            new_file_path = template_file_path
            for template_suffix, new_file_suffix in self.replace_template_suffixes:
                if new_file_path.endswith(template_suffix):
                    new_file_path = new_file_path[:-len(template_suffix)] + new_file_suffix
                    break

            source_file = SourceFile()
            # relpath copes with a trailing separator on template_base_dir;
            # fixed-length slicing does not.
            source_file.file_relative_path = os.path.relpath(new_file_path, self.template_base_dir)
            source_file.template_file_path = template_file_path
            if new_file_path.endswith(extensions):
                source_file.content = render_template(template_file_path, options)
            else:
                with open(template_file_path, encoding='utf-8') as fp:
                    source_file.content = fp.read()
            source_files.append(source_file)
        return source_files
|