
[to #42322933] Fix shop segmentation CPU inference error

Fix the CPU inference error so the model also supports inference on CPU.
        Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10177721
master
xingguang.zxg yingda.chen 3 years ago
parent commit 1eedbd65bc
2 changed files with 9 additions and 9 deletions
  1. modelscope/models/cv/shop_segmentation/models.py (+1, -1)
  2. modelscope/models/cv/shop_segmentation/shop_seg_model.py (+8, -8)

modelscope/models/cv/shop_segmentation/models.py (+1, -1)

@@ -552,7 +552,7 @@ class CLIPVisionTransformer(nn.Module):
             nn.GroupNorm(1, embed_dim),
             nn.ConvTranspose2d(
                 embed_dim, embed_dim, kernel_size=2, stride=2),
-            nn.SyncBatchNorm(embed_dim),
+            nn.BatchNorm2d(embed_dim),
             nn.GELU(),
             nn.ConvTranspose2d(
                 embed_dim, embed_dim, kernel_size=2, stride=2),
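A minimal sketch (not part of the commit) of why this swap enables CPU inference: nn.SyncBatchNorm is intended for distributed GPU training and, on the PyTorch releases this code targets, its forward pass rejects CPU tensors, while nn.BatchNorm2d normalizes on any device. Both modules expose the same parameter and buffer names, so existing checkpoints should still load; the channel count and input shape below are illustrative only.

import torch
import torch.nn as nn

embed_dim = 16                            # illustrative channel count
x = torch.randn(1, embed_dim, 8, 8)       # dummy NCHW feature map on CPU

bn = nn.BatchNorm2d(embed_dim).eval()
print(bn(x).shape)                        # torch.Size([1, 16, 8, 8]) -- runs on CPU

sync_bn = nn.SyncBatchNorm(embed_dim).eval()
# sync_bn(x) would fail here on CPU-only machines, since SyncBatchNorm assumes
# CUDA tensors and a torch.distributed process group.

# Parameters and buffers line up, so weights transfer both ways:
bn.load_state_dict(sync_bn.state_dict())  # no missing or unexpected keys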


modelscope/models/cv/shop_segmentation/shop_seg_model.py (+8, -8)

@@ -33,18 +33,18 @@ class ShopSegmentation(TorchModel):
             model_dir=model_dir, device_id=device_id, *args, **kwargs)

         self.model = SHOPSEG(model_dir=model_dir)
-        pretrained_params = torch.load('{}/{}'.format(
-            model_dir, ModelFile.TORCH_MODEL_BIN_FILE))
+        pretrained_params = torch.load(
+            '{}/{}'.format(model_dir, ModelFile.TORCH_MODEL_BIN_FILE),
+            map_location='cpu')
         self.model.load_state_dict(pretrained_params)
         self.model.eval()
-        self.device_id = device_id
-        if self.device_id >= 0 and torch.cuda.is_available():
-            self.model.to('cuda:{}'.format(self.device_id))
-            logger.info('Use GPU: {}'.format(self.device_id))
+        if device_id >= 0 and torch.cuda.is_available():
+            self.model.to('cuda:{}'.format(device_id))
+            logger.info('Use GPU: {}'.format(device_id))
         else:
-            self.device_id = -1
+            device_id = -1
             logger.info('Use CPU for inference')
+        self.device_id = device_id

     def preprocess(self, img, size=1024):
         mean = [0.48145466, 0.4578275, 0.40821073]
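The second change covers checkpoint loading and device selection. The sketch below (the helper name load_for_inference and its arguments are illustrative, not from the repository) mirrors the adopted pattern: map_location='cpu' deserializes the checkpoint into host memory even when no CUDA device is present, and the model is only moved to a GPU afterwards if one was requested and is actually available.

import torch

def load_for_inference(model, checkpoint_path, device_id=-1):
    # map_location='cpu' keeps deserialization working on CPU-only machines;
    # without it, torch.load tries to restore tensors to the GPU they were
    # saved from and fails when CUDA is unavailable.
    state = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(state)
    model.eval()

    if device_id >= 0 and torch.cuda.is_available():
        model.to('cuda:{}'.format(device_id))   # move to the requested GPU
    else:
        device_id = -1                          # record the CPU fallback
    return model, device_id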

