"""FastAPI service that classifies an uploaded image with a small CNN.

The checkpoint (a ``.pth`` state dict) is discovered under the c2net
pretrain-model directory; routes are mounted under the ``OPENI_SELF_URL``
prefix and the server listens on ``OPENI_SELF_PORT``.
"""
import argparse
import io
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import uvicorn
from c2net.context import prepare
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import JSONResponse, Response
from PIL import Image
from torchvision import models, transforms

# import multiprocessing as mp
# mp.set_start_method('spawn', True)

c2net_context = prepare()

# Sub-directories of the pretrain-model directory, in os.listdir order.
files = [
    f
    for f in os.listdir(c2net_context.pretrain_model_path)
    if os.path.isdir(c2net_context.pretrain_model_path + f"/{f}")
]

# Pick the first visible (non-dot) model directory that is not empty.
model_name = ""
for file in files:
    if not file.startswith(".") and os.listdir(
        c2net_context.pretrain_model_path + f"/{file}"
    ):
        model_name = file
        break

model_path = c2net_context.pretrain_model_path + f"/{model_name}"

# Locate the checkpoint anywhere below model_path (the last match in walk
# order wins). Keep the directory each file was found in, so checkpoints in
# nested sub-directories resolve to their real location. If no .pth exists,
# custom_model_path points at the directory and loading fails with a clear
# FileNotFoundError in inference_one_image.
pth_file = ""
custom_model_path = model_path + "/"
for root, _dirs, names in os.walk(model_path):
    for name in names:
        if name.endswith(".pth"):
            pth_file = name
            custom_model_path = os.path.join(root, name)

batch_size = 10
number_of_labels = 10
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

# URL prefix for all routes; default to the root so a missing env var does
# not crash os.path.join at import time.
base_url = os.getenv('OPENI_SELF_URL', '/')


class Network(nn.Module):
    """Four-conv CNN producing 10 logits from a 3x32x32 image.

    Spatial sizes for a 32x32 input: conv1 -> 30, conv2 -> 28, pool -> 14,
    conv4 -> 12, conv5 -> 10, hence the 24 * 10 * 10 flatten size.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=12, kernel_size=5, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(12)
        self.conv2 = nn.Conv2d(in_channels=12, out_channels=12, kernel_size=5, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(12)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv4 = nn.Conv2d(in_channels=12, out_channels=24, kernel_size=5, stride=1, padding=1)
        self.bn4 = nn.BatchNorm2d(24)
        self.conv5 = nn.Conv2d(in_channels=24, out_channels=24, kernel_size=5, stride=1, padding=1)
        self.bn5 = nn.BatchNorm2d(24)
        self.fc1 = nn.Linear(24 * 10 * 10, 10)

    def forward(self, input):
        """Return logits of shape (batch, 10) for an input batch."""
        output = F.relu(self.bn1(self.conv1(input)))
        output = F.relu(self.bn2(self.conv2(output)))
        output = self.pool(output)
        output = F.relu(self.bn4(self.conv4(output)))
        output = F.relu(self.bn5(self.conv5(output)))
        output = output.view(-1, 24 * 10 * 10)
        output = self.fc1(output)
        return output


app = FastAPI()
TIMEOUT_KEEP_ALIVE = 20
model = Network()

# Run on the first GPU when one exists, otherwise fall back to the CPU
# instead of crashing on "cuda:0".
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"

transform_fn = transforms.Compose([
    transforms.Resize(32),
    transforms.CenterCrop(32),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Device the model weights currently live on; None until the checkpoint
# has been loaded for the first time.
_model_device = None


def inference_one_image(image, card):
    """Classify a preprocessed image batch.

    The checkpoint at ``custom_model_path`` is loaded on the first call only
    and then reused; later calls just move the model if *card* changes.

    Args:
        image: input tensor batch, as produced by ``transform_fn`` plus
            ``torch.unsqueeze(..., 0)``.
        card: device string understood by ``torch.device`` (e.g. "cuda:0").

    Returns:
        Tensor of predicted class indices, one per batch element.

    Raises:
        FileNotFoundError: no ``.pth`` checkpoint was found at startup.
    """
    global _model_device
    device = torch.device(card)
    if _model_device is None:
        if not os.path.isfile(custom_model_path):
            raise FileNotFoundError(f"no .pth checkpoint found under {model_path!r}")
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        state_dict = torch.load(custom_model_path, map_location=device)
        model.load_state_dict(state_dict)
        model.eval()
    if _model_device != device:
        model.to(device)
        _model_device = device
    # Pure inference: skip autograd bookkeeping.
    with torch.no_grad():
        output = model(image.to(device))
    _, index = torch.max(output, 1)
    return index


@app.get(os.path.join(base_url, "test"))
async def api_get():
    """Liveness probe."""
    return "This is a GET request."


@app.post(os.path.join(base_url, "infer"))
async def classification_cnn(file: UploadFile = File(...)) -> Response:
    """Classify one uploaded image and return ``{"result": <class name>}``.

    Undecodable uploads get a 400 response with an ``"error"`` message.
    """
    file_content = await file.read()
    try:
        # convert("RGB") forces a full decode (so corrupt data fails here) and
        # maps RGBA / grayscale / palette images to the 3 channels that
        # Normalize and conv1 expect.
        image = Image.open(io.BytesIO(file_content)).convert("RGB")
    except OSError as exc:  # PIL.UnidentifiedImageError subclasses OSError
        return JSONResponse({"error": f"cannot decode image: {exc}"}, status_code=400)
    batch = torch.unsqueeze(transform_fn(image), 0)
    # NOTE(review): inference runs synchronously inside the event loop; fine
    # for this small network, but consider a threadpool for heavier models.
    predicted = inference_one_image(batch, DEVICE)
    return JSONResponse({"result": classes[int(predicted.item())]})


if __name__ == '__main__':
    uvicorn.run(app, host='0.0.0.0', port=int(os.getenv('OPENI_SELF_PORT')), timeout_keep_alive=TIMEOUT_KEEP_ALIVE)