"""Interactive PyQt5 demo for mask-guided image editing.

The window shows three views: an editable, colour-coded label mask, the
reference photo, and the generated result.  The user paints label strokes
on the mask and the ``Edit`` action feeds the original mask, the edited
mask and the reference photo to the model's ``inference`` method.
"""
from argparse import FileType
from fileinput import filename
import os
import sys
import cv2
import time
import numpy as np
from PIL import Image
import torch
from torchvision.utils import save_image
from options.test_options import TestOptions
from data.data_loader import CreateDataLoader
from models.models import create_model
from data.base_dataset import BaseDataset, get_params, get_transform, normalize
from ui.ui import Ui_Form
from ui.mouse_event import GraphicsScene
from ui_util.config import Config

from PyQt5.QtCore import *
from PyQt5.QtGui import *
from PyQt5.QtWidgets import *
from PyQt5.QtPrintSupport import QPrintDialog, QPrinter

# Display colour for each label index stored in a mask pixel (index == label).
color_list = [
    QColor(0, 0, 0),
    QColor(204, 0, 0),
    QColor(76, 153, 0),
    QColor(204, 204, 0),
    QColor(51, 51, 255),
    QColor(204, 0, 204),
    QColor(0, 255, 255),
    QColor(51, 255, 255),
    QColor(102, 51, 0),
    QColor(255, 0, 0),
    QColor(102, 204, 0),
    QColor(255, 255, 0),
    QColor(0, 0, 153),
    QColor(0, 0, 204),
    QColor(255, 51, 153),
    QColor(0, 204, 204),
    QColor(0, 51, 0),
    QColor(255, 153, 51),
    QColor(0, 204, 0),
]


class Ex(QWidget, Ui_Form):
    """Main window: loads an image and a mask, records strokes, runs the model."""

    # Number of label channels whose strokes are replayed in ``edit``.
    NUM_LABELS = 19

    def __init__(self, model, opt):
        """Build the UI and bind the three graphics views to their scenes.

        Args:
            model: object exposing ``inference(edited_mask, mask, image)``.
            opt: options object forwarded to ``get_params``/``get_transform``.
        """
        super(Ex, self).__init__()
        self.setupUi(self)
        self.show()
        self.model = model
        self.opt = opt

        self.output_img = None  # last generated result as a BGR uint8 array
        self.mat_img = None
        self.mode = 0
        self.size = 6
        self.mask = None        # mask as loaded from disk (never drawn on)
        self.mask_m = None      # mask with the user's strokes applied
        self.img = None         # reference photo as a PIL image
        self.image = None       # colourised mask pixmap shown in the editor
        self.mouse_clicked = False

        self.scene = GraphicsScene(self.mode, self.size)
        self._configure_view(self.graphicsView, self.scene)

        self.ref_scene = QGraphicsScene()
        self._configure_view(self.graphicsView_2, self.ref_scene)

        self.result_scene = QGraphicsScene()
        self._configure_view(self.graphicsView_3, self.result_scene)

        self.dlg = QColorDialog(self.graphicsView)
        self.color = None

    @staticmethod
    def _configure_view(view, scene):
        """Attach *scene* to *view*, anchor it top-left and hide scroll bars."""
        view.setScene(scene)
        view.setAlignment(Qt.AlignTop | Qt.AlignLeft)
        view.setVerticalScrollBarPolicy(Qt.ScrollBarAlwaysOff)
        view.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOff)

    @staticmethod
    def _replace_pixmap(scene, pixmap):
        """Remove the last item reported by *scene* (if any) and add *pixmap*."""
        items = scene.items()
        if items:
            scene.removeItem(items[-1])
        scene.addPixmap(pixmap)

    @staticmethod
    def _colorize_mask(label_img):
        """Return a QPixmap in which every label of *label_img* is colour-coded.

        The label is read from the first channel of each pixel, which is what
        the previous per-pixel ``pixelColor``/``setPixel`` loop used.
        """
        palette = np.array(
            [[c.red(), c.green(), c.blue()] for c in color_list],
            dtype=np.uint8,
        )
        colored = np.ascontiguousarray(palette[label_img[:, :, 0]])
        height, width = colored.shape[:2]
        q_img = QImage(colored.data, width, height, 3 * width,
                       QImage.Format_RGB888)
        # copy() detaches the QImage from the temporary NumPy buffer.
        return QPixmap.fromImage(q_img.copy())

    def open(self):
        """Ask for an image file and show it in the reference and result views."""
        fileName, _ = QFileDialog.getOpenFileName(self, "Open File",
                                                  QDir.currentPath())
        if not fileName:
            return

        image = QPixmap(fileName)
        if image.isNull():
            QMessageBox.information(self, "Image Viewer",
                                    "Cannot load %s." % fileName)
            return
        try:
            # Image.open is lazy; copy() forces the pixel data to be read.
            with Image.open(fileName) as pil_img:
                loaded = pil_img.copy()
        except OSError:
            QMessageBox.information(self, "Image Viewer",
                                    "Cannot load %s." % fileName)
            return

        # State is only updated once both loaders have succeeded.
        self.img = loaded
        image = image.scaled(self.graphicsView.size(), Qt.IgnoreAspectRatio)
        self._replace_pixmap(self.ref_scene, image)
        self._replace_pixmap(self.result_scene, image)

    def open_mask(self):
        """Ask for a label-mask file and show it colour-coded in the editor."""
        fileName, _ = QFileDialog.getOpenFileName(self, "Open File",
                                                  QDir.currentPath())
        if not fileName:
            return

        mat_img = cv2.imread(fileName)
        if mat_img is None:  # cv2.imread signals failure by returning None
            QMessageBox.information(self, "Image Viewer",
                                    "Cannot load %s." % fileName)
            return
        if int(mat_img[:, :, 0].max()) >= len(color_list):
            QMessageBox.information(
                self, "Image Viewer",
                "%s is not a label mask (values must be below %d)."
                % (fileName, len(color_list)))
            return

        self.mask = mat_img.copy()
        self.mask_m = mat_img

        pixmap = self._colorize_mask(mat_img)
        self.image = pixmap.scaled(self.graphicsView.size(),
                                   Qt.IgnoreAspectRatio)

        self.scene.reset()
        if len(self.scene.items()) > 0:
            self.scene.reset_items()
        self.scene.addPixmap(self.image)

    # --- brush label selection: each slot sets the label painted next ------

    def bg_mode(self):
        self.scene.mode = 0

    def skin_mode(self):
        self.scene.mode = 1

    def nose_mode(self):
        self.scene.mode = 2

    def eye_g_mode(self):
        self.scene.mode = 3

    def l_eye_mode(self):
        self.scene.mode = 4

    def r_eye_mode(self):
        self.scene.mode = 5

    def l_brow_mode(self):
        self.scene.mode = 6

    def r_brow_mode(self):
        self.scene.mode = 7

    def l_ear_mode(self):
        self.scene.mode = 8

    def r_ear_mode(self):
        self.scene.mode = 9

    def mouth_mode(self):
        self.scene.mode = 10

    def u_lip_mode(self):
        self.scene.mode = 11

    def l_lip_mode(self):
        self.scene.mode = 12

    def hair_mode(self):
        self.scene.mode = 13

    def hat_mode(self):
        self.scene.mode = 14

    def ear_r_mode(self):
        self.scene.mode = 15

    def neck_l_mode(self):
        self.scene.mode = 16

    def neck_mode(self):
        self.scene.mode = 17

    def cloth_mode(self):
        self.scene.mode = 18

    def increase(self):
        """Grow the brush by one, up to a maximum size of 15."""
        if self.scene.size < 15:
            self.scene.size += 1

    def decrease(self):
        """Shrink the brush by one, down to a minimum size of 1."""
        if self.scene.size > 1:
            self.scene.size -= 1

    def edit(self):
        """Apply the recorded strokes to the mask, run the model, show the result.

        Side effects: writes ``output.jpg`` to the working directory, stores
        the result in ``self.output_img`` (BGR, uint8) for ``save_img`` and
        replaces the pixmap in the result view.
        """
        if self.mask is None or self.mask_m is None or self.img is None:
            QMessageBox.information(
                self, "Image Viewer",
                "Open an image and a mask before editing.")
            return

        for label in range(self.NUM_LABELS):
            self.mask_m = self.make_mask(self.mask_m,
                                         self.scene.mask_points[label],
                                         self.scene.size_points[label],
                                         label)

        params = get_params(self.opt, (512, 512))
        transform_mask = get_transform(self.opt, params, method=Image.NEAREST,
                                       normalize=False, normalize_mask=True)
        transform_image = get_transform(self.opt, params)

        mask = transform_mask(Image.fromarray(np.uint8(self.mask)))
        mask_m = transform_mask(Image.fromarray(np.uint8(self.mask_m)))
        img = transform_image(self.img)

        start_t = time.time()
        # Use the model handed to the constructor rather than a module global,
        # and add the batch axis directly instead of round-tripping via NumPy.
        generated = self.model.inference(mask_m.unsqueeze(0).float(),
                                         mask.unsqueeze(0).float(),
                                         img.unsqueeze(0).float())
        end_t = time.time()
        print('inference time : {}'.format(end_t - start_t))

        save_image((generated.data[0] + 1) / 2, 'output.jpg')

        # Output is rescaled from [-1, 1] to [0, 255]; clip so stray values
        # cannot wrap around in the uint8 cast.
        result = generated.permute(0, 2, 3, 1).cpu().detach().numpy()
        np_img = np.ascontiguousarray(
            np.clip((result[0] + 1) * 127.5, 0, 255).astype(np.uint8))

        # cv2.imwrite expects BGR channel order; the array is displayed as RGB.
        self.output_img = cv2.cvtColor(np_img, cv2.COLOR_RGB2BGR)

        height, width = np_img.shape[:2]
        q_img = QImage(np_img.data, width, height, width * 3,
                       QImage.Format_RGB888)
        self._replace_pixmap(self.result_scene,
                             QPixmap.fromImage(q_img.copy()))

    def make_mask(self, mask, pts, sizes, color):
        """Draw the recorded strokes for one label onto *mask* in place.

        Args:
            mask: image array to draw on (modified and returned).
            pts: sequence of dicts with ``'prev'`` and ``'curr'`` end points.
            sizes: line thickness for each entry of *pts*.
            color: label value written to all three channels.

        Returns:
            The same *mask* object.
        """
        for pt, size in zip(pts, sizes):
            cv2.line(mask, pt['prev'], pt['curr'], (color, color, color), size)
        return mask

    def save_img(self):
        """Ask for a destination and write the last generated result to it."""
        if self.output_img is None:
            return  # nothing has been generated yet
        fileName, _ = QFileDialog.getSaveFileName(self, "Save File",
                                                  QDir.currentPath(),
                                                  'jpg(*.jpg)')
        if not fileName:
            return  # dialog cancelled
        if not os.path.splitext(fileName)[1]:
            # cv2.imwrite picks the encoder from the file extension.
            fileName += '.jpg'
        cv2.imwrite(fileName, self.output_img)

    def undo(self):
        """Delegate to the editable scene's ``undo``."""
        self.scene.undo()

    def clear(self):
        """Discard all strokes and restore the mask as it was loaded."""
        if self.mask is not None:
            self.mask_m = self.mask.copy()

        self.scene.reset_items()
        self.scene.reset()
        if self.image is not None:
            self.scene.addPixmap(self.image)


if __name__ == '__main__':
    os.environ["CUDA_VISIBLE_DEVICES"] = str(0)
    opt = TestOptions().parse(save=False)
    opt.nThreads = 1  # test code only supports nThreads = 1
    opt.batchSize = 1  # test code only supports batchSize = 1
    opt.serial_batches = True  # no shuffle
    opt.no_flip = True  # no flip
    model = create_model(opt)
    app = QApplication(sys.argv)
    ex = Ex(model, opt)
    sys.exit(app.exec_())