You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

base_torch.py 647 B

1234567891011121314151617181920212223
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. from typing import Dict
  3. import torch
  4. from .base import Model
  class TorchModel(Model, torch.nn.Module):
      """Base model interface for PyTorch.

      Inherits from both the project ``Model`` base class and
      ``torch.nn.Module``. Subclasses are expected to override ``forward``;
      the implementation here only raises ``NotImplementedError``.
      """

      def __init__(self, model_dir=None, *args, **kwargs):
          """Initialise both parent classes.

          Args:
              model_dir: Passed as the only argument to the first parent
                  initialiser. Defaults to ``None``.
              *args: Accepted for signature compatibility; not forwarded.
              **kwargs: Accepted for signature compatibility; not forwarded.
          """
          # Each parent is initialised with its own explicit call; background:
          # https://stackoverflow.com/questions/9575409/calling-parent-class-init-with-multiple-inheritance-whats-the-right-way
          # Step 1: the class following TorchModel in the MRO gets model_dir.
          super(TorchModel, self).__init__(model_dir)
          # Step 2: resume the MRO lookup after Model and call that
          # initialiser with no arguments (presumably torch.nn.Module's;
          # confirm against Model's own base classes).
          super(Model, self).__init__()

      def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
          """Placeholder forward pass; always raises.

          Args:
              inputs: Mapping from string keys to tensors.

          Returns:
              Annotated as a mapping from string keys to tensors; this base
              implementation never returns.

          Raises:
              NotImplementedError: Always.
          """
          raise NotImplementedError()