import io
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import uvicorn
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import JSONResponse, Response
from PIL import Image
from torchvision import transforms

from c2net.context import prepare

# Initialize the c2net context; it exposes the platform paths
# (datasets, pretrained models, output) for this task.
c2net_context = prepare()

# Pick the first non-hidden, non-empty directory under the pretrained-model
# path as the model directory, then locate a .pth checkpoint inside it.
pretrain_model_path = c2net_context.pretrain_model_path

model_name = ""
for entry in os.listdir(pretrain_model_path):
    entry_path = os.path.join(pretrain_model_path, entry)
    if os.path.isdir(entry_path) and not entry.startswith(".") and len(os.listdir(entry_path)) > 0:
        model_name = entry
        break

model_path = os.path.join(pretrain_model_path, model_name)

# Walk the model directory and keep the full path of the last .pth file found,
# so checkpoints nested in subdirectories are also resolved correctly.
custom_model_path = ""
for root, dirs, files in os.walk(model_path):
    for file in files:
        if file.endswith(".pth"):
            custom_model_path = os.path.join(root, file)

batch_size = 10
number_of_labels = 10
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

# The base URL path is injected by the OpenI platform via an environment variable.
base_url = os.getenv('OPENI_SELF_URL')

class Network(nn.Module):
    """Small CNN for 10-class CIFAR-style image classification."""

    def __init__(self):
        super(Network, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=12, kernel_size=5, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(12)
        self.conv2 = nn.Conv2d(in_channels=12, out_channels=12, kernel_size=5, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(12)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv4 = nn.Conv2d(in_channels=12, out_channels=24, kernel_size=5, stride=1, padding=1)
        self.bn4 = nn.BatchNorm2d(24)
        self.conv5 = nn.Conv2d(in_channels=24, out_channels=24, kernel_size=5, stride=1, padding=1)
        self.bn5 = nn.BatchNorm2d(24)
        self.fc1 = nn.Linear(24 * 10 * 10, 10)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = self.pool(x)
        x = F.relu(self.bn4(self.conv4(x)))
        x = F.relu(self.bn5(self.conv5(x)))
        x = x.view(-1, 24 * 10 * 10)
        x = self.fc1(x)
        return x
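
# Shape check for the fully connected layer above: for a 3x32x32 input, each
# 5x5 convolution with padding=1 shrinks the spatial size by 2 (32 -> 30 -> 28),
# the 2x2 max-pool halves it to 14, and the last two convolutions give
# 12 -> 10, so the flattened feature map is 24 channels * 10 * 10 = 2400.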

app = FastAPI()

TIMEOUT_KEEP_ALIVE = 20

model = Network()

# Preprocessing: resize/crop the uploaded image to 32x32 and normalize each
# channel with mean 0.5 and std 0.5 (mapping [0, 1] pixel values to [-1, 1]).
transform_fn = transforms.Compose([
    transforms.Resize(32),
    transforms.CenterCrop(32),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

def inference_one_image(image, card):
    """Run the classifier on a single preprocessed image batch and return the predicted class index."""
    device = torch.device(card)
    image = image.to(device)

    # Note: the checkpoint is reloaded on every call; map_location keeps the
    # load working on CPU-only hosts as well.
    model.load_state_dict(torch.load(custom_model_path, map_location=device))
    model.to(device)
    model.eval()

    with torch.no_grad():
        output = model(image)

    _, index = torch.max(output, 1)
    return index.item()
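
# Minimal local sanity check (a sketch, not executed by the server): classify
# one image file without going through HTTP. "sample.jpg" and the "cpu" device
# are placeholders for whatever is available locally.
#
#   img = transform_fn(Image.open("sample.jpg").convert("RGB")).unsqueeze(0)
#   print(classes[inference_one_image(img, "cpu")])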

@app.get(os.path.join(base_url, "test"))
async def api_get():
    return "This is a GET request."


@app.post(os.path.join(base_url, "infer"))
async def classification_cnn(file: UploadFile = File(...)) -> Response:
    file_content = await file.read()
    # Convert to RGB so grayscale or RGBA uploads also yield 3-channel tensors.
    image = Image.open(io.BytesIO(file_content)).convert("RGB")
    img = transform_fn(image)
    batch = torch.unsqueeze(img, 0)

    # Fall back to CPU when no GPU is available.
    card = "cuda:0" if torch.cuda.is_available() else "cpu"
    predicted = inference_one_image(batch, card)

    if predicted is None:
        ret = {"error": "predicted is None"}
    else:
        ret = {"result": classes[predicted]}
    return JSONResponse(ret)
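
# Example request (a sketch; the actual host, port, and base path come from the
# deployment's OPENI_SELF_PORT / OPENI_SELF_URL environment variables):
#
#   curl -X POST -F "file=@cat.png" "http://<host>:<port><OPENI_SELF_URL>/infer"
#
# Expected response: {"result": "cat"} (or another of the ten class names).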

if __name__ == '__main__':
    uvicorn.run(app, host='0.0.0.0', port=int(os.getenv('OPENI_SELF_PORT')),
                timeout_keep_alive=TIMEOUT_KEEP_ALIVE)