manifold_face_tamper/tamper/demo.py

266 lines
9.2 KiB
Python

from argparse import FileType
from fileinput import filename
import os
import sys
import cv2
import time
import numpy as np
from PIL import Image
import torch
from torchvision.utils import save_image
from options.test_options import TestOptions
from data.data_loader import CreateDataLoader
from models.models import create_model
from data.base_dataset import BaseDataset, get_params, get_transform, normalize
from ui.ui import Ui_Form
from ui.mouse_event import GraphicsScene
from ui_util.config import Config
from PyQt5.QtCore import *
from PyQt5.QtGui import *
from PyQt5.QtWidgets import *
from PyQt5.QtPrintSupport import QPrintDialog, QPrinter
# Display palette for the label mask: entry i is the colour used to render
# pixels whose label value is i (indexed as color_list[label] by Ex.open_mask).
color_list = [
    QColor(0, 0, 0),
    QColor(204, 0, 0),
    QColor(76, 153, 0),
    QColor(204, 204, 0),
    QColor(51, 51, 255),
    QColor(204, 0, 204),
    QColor(0, 255, 255),
    QColor(51, 255, 255),
    QColor(102, 51, 0),
    QColor(255, 0, 0),
    QColor(102, 204, 0),
    QColor(255, 255, 0),
    QColor(0, 0, 153),
    QColor(0, 0, 204),
    QColor(255, 51, 153),
    QColor(0, 204, 204),
    QColor(0, 51, 0),
    QColor(255, 153, 51),
    QColor(0, 204, 0),
]
class Ex(QWidget, Ui_Form):
    """Main window of the interactive mask-editing demo.

    Three panes are driven from here: ``graphicsView`` holds the editable
    label mask (a ``GraphicsScene``), ``graphicsView_2`` the reference photo
    and ``graphicsView_3`` the generated result.  The widgets themselves are
    created by ``Ui_Form.setupUi``.
    """

    # Side length handed to get_params() when building the transforms.
    IMAGE_SIZE = 512
    # Number of label indices whose strokes edit() burns into the mask.
    NUM_LABELS = 19

    def __init__(self, model, opt):
        """Build the UI and remember the model used by edit().

        Args:
            model: object exposing ``inference(edited_mask, original_mask,
                image)``; called from edit().
            opt: options object forwarded to get_params() / get_transform().
        """
        super().__init__()
        self.setupUi(self)
        self.show()

        self.model = model
        self.opt = opt

        self.output_img = None  # last generated image (H x W x 3 uint8, RGB)
        self.mat_img = None
        self.mode = 0
        self.size = 6
        self.mask = None        # mask as loaded from disk; never drawn on
        self.mask_m = None      # working copy that strokes are burned into
        self.img = None         # reference photo as a PIL image
        self.image = None       # colourised mask pixmap shown in the edit pane
        self.mouse_clicked = False

        self.scene = GraphicsScene(self.mode, self.size)
        self._attach_scene(self.graphicsView, self.scene)
        self.ref_scene = QGraphicsScene()
        self._attach_scene(self.graphicsView_2, self.ref_scene)
        self.result_scene = QGraphicsScene()
        self._attach_scene(self.graphicsView_3, self.result_scene)

        self.dlg = QColorDialog(self.graphicsView)
        self.color = None

    @staticmethod
    def _attach_scene(view, scene):
        """Bind *scene* to *view*, aligned top-left with scroll bars hidden."""
        view.setScene(scene)
        view.setAlignment(Qt.AlignTop | Qt.AlignLeft)
        view.setVerticalScrollBarPolicy(Qt.ScrollBarAlwaysOff)
        view.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOff)

    @staticmethod
    def _show_pixmap(scene, pixmap):
        """Remove ``scene.items()[-1]`` (if there is one) and add *pixmap*."""
        items = scene.items()
        if items:
            scene.removeItem(items[-1])
        scene.addPixmap(pixmap)

    def _report(self, text):
        """Show *text* in an information message box."""
        QMessageBox.information(self, "Image Viewer", text)

    def open(self):
        """Ask for a reference photo and show it in the reference and result panes."""
        file_name, _ = QFileDialog.getOpenFileName(self, "Open File",
                                                   QDir.currentPath())
        if not file_name:
            return

        pixmap = QPixmap(file_name)
        if pixmap.isNull():
            self._report("Cannot load %s." % file_name)
            return
        try:
            # copy() forces the pixel data to be read, so the file handle can
            # be closed straight away.
            with Image.open(file_name) as photo:
                self.img = photo.copy()
        except OSError:
            # Qt decoded the file but PIL could not (PIL's decode errors
            # derive from OSError).
            self._report("Cannot load %s." % file_name)
            return

        pixmap = pixmap.scaled(self.graphicsView.size(), Qt.IgnoreAspectRatio)
        self._show_pixmap(self.ref_scene, pixmap)
        self._show_pixmap(self.result_scene, pixmap)

    def open_mask(self):
        """Ask for a label mask, colourise it and show it in the edit pane.

        The first channel of the file is read as the label index of each
        pixel and mapped through ``color_list`` for display.  State is only
        replaced once the file has been validated.
        """
        file_name, _ = QFileDialog.getOpenFileName(self, "Open File",
                                                   QDir.currentPath())
        if not file_name:
            return

        mat_img = cv2.imread(file_name)
        if mat_img is None:  # cv2.imread signals failure by returning None
            self._report("Cannot load %s." % file_name)
            return

        labels = mat_img[:, :, 0]
        if int(labels.max()) >= len(color_list):
            self._report("%s contains label values outside 0-%d."
                         % (file_name, len(color_list) - 1))
            return

        # Palette lookup for the whole mask at once instead of per pixel.
        palette = np.array([(c.red(), c.green(), c.blue()) for c in color_list],
                           dtype=np.uint8)
        colored = np.ascontiguousarray(palette[labels])
        height, width = labels.shape
        # .copy() detaches the QImage from the temporary NumPy buffer.
        image = QImage(colored.data, width, height, 3 * width,
                       QImage.Format_RGB888).copy()

        self.mask = mat_img.copy()
        self.mask_m = mat_img
        self.image = QPixmap.fromImage(image).scaled(self.graphicsView.size(),
                                                     Qt.IgnoreAspectRatio)
        self.scene.reset()
        if len(self.scene.items()) > 0:
            self.scene.reset_items()
        self.scene.addPixmap(self.image)

    # ------------------------------------------------------------------
    # Brush selectors: each stores a fixed index in scene.mode.  edit()
    # burns the strokes held in scene.mask_points[i] into the mask with
    # pixel value i (mask_points is presumably keyed by scene.mode --
    # confirm against GraphicsScene).
    # ------------------------------------------------------------------
    def bg_mode(self):
        self.scene.mode = 0

    def skin_mode(self):
        self.scene.mode = 1

    def nose_mode(self):
        self.scene.mode = 2

    def eye_g_mode(self):
        self.scene.mode = 3

    def l_eye_mode(self):
        self.scene.mode = 4

    def r_eye_mode(self):
        self.scene.mode = 5

    def l_brow_mode(self):
        self.scene.mode = 6

    def r_brow_mode(self):
        self.scene.mode = 7

    def l_ear_mode(self):
        self.scene.mode = 8

    def r_ear_mode(self):
        self.scene.mode = 9

    def mouth_mode(self):
        self.scene.mode = 10

    def u_lip_mode(self):
        self.scene.mode = 11

    def l_lip_mode(self):
        self.scene.mode = 12

    def hair_mode(self):
        self.scene.mode = 13

    def hat_mode(self):
        self.scene.mode = 14

    def ear_r_mode(self):
        self.scene.mode = 15

    def neck_l_mode(self):
        self.scene.mode = 16

    def neck_mode(self):
        self.scene.mode = 17

    def cloth_mode(self):
        self.scene.mode = 18

    def increase(self):
        """Grow the brush size by one, up to a maximum of 15."""
        if self.scene.size < 15:
            self.scene.size += 1

    def decrease(self):
        """Shrink the brush size by one, down to a minimum of 1."""
        if self.scene.size > 1:
            self.scene.size -= 1

    def edit(self):
        """Burn the drawn strokes into the working mask and run the model.

        The result is written to ``output.jpg`` in the working directory,
        kept in ``self.output_img`` for save_img(), and shown in the result
        pane.  Does nothing but inform the user when no photo or mask has
        been loaded yet.
        """
        if self.mask is None or self.mask_m is None or self.img is None:
            self._report("Please open an image and a mask before editing.")
            return

        # NOTE(review): strokes are drawn onto self.mask_m cumulatively, so a
        # stroke removed via undo() after an earlier edit() stays in the mask
        # -- confirm against GraphicsScene.undo before changing.
        for label in range(self.NUM_LABELS):
            self.mask_m = self.make_mask(self.mask_m,
                                         self.scene.mask_points[label],
                                         self.scene.size_points[label],
                                         label)

        params = get_params(self.opt, (self.IMAGE_SIZE, self.IMAGE_SIZE))
        transform_mask = get_transform(self.opt, params, method=Image.NEAREST,
                                       normalize=False, normalize_mask=True)
        transform_image = get_transform(self.opt, params)

        mask = transform_mask(Image.fromarray(np.uint8(self.mask.copy())))
        mask_m = transform_mask(Image.fromarray(np.uint8(self.mask_m.copy())))
        img = transform_image(self.img)

        start_t = time.time()
        # unsqueeze(0) adds the batch axis without the slow
        # list-of-ndarray tensor construction.
        generated = self.model.inference(mask_m.unsqueeze(0).float(),
                                         mask.unsqueeze(0).float(),
                                         img.unsqueeze(0).float())
        end_t = time.time()
        print('inference time : {}'.format(end_t - start_t))

        # The (x + 1) / 2 rescale treats the model output as lying in [-1, 1].
        save_image((generated.data[0] + 1) / 2, 'output.jpg')

        # NCHW -> HWC, rescale to 0..255 and clip so the uint8 cast cannot wrap.
        result = generated.detach().cpu().permute(0, 2, 3, 1).numpy()
        np_img = np.clip((result[0] + 1) * 127.5, 0, 255).astype(np.uint8)
        np_img = np.ascontiguousarray(np_img)  # QImage needs packed rows
        self.output_img = np_img

        height, width = np_img.shape[:2]
        q_img = QImage(np_img.data, width, height, width * 3,
                       QImage.Format_RGB888)
        self._show_pixmap(self.result_scene, QPixmap.fromImage(q_img))

    def make_mask(self, mask, pts, sizes, color):
        """Draw each stroke segment in *pts* onto *mask* and return it.

        Args:
            mask: image array drawn on in place by ``cv2.line``.
            pts: sequence of dicts with ``'prev'`` and ``'curr'`` end points.
            sizes: line thickness for each entry of *pts*.
            color: value written to all three channels.
        """
        for pt, thickness in zip(pts, sizes):
            cv2.line(mask, pt['prev'], pt['curr'], (color, color, color),
                     thickness)
        return mask

    def save_img(self):
        """Ask for a destination and write the last generated image to it."""
        if self.output_img is None:
            self._report("Nothing to save yet: run an edit first.")
            return

        file_name, _ = QFileDialog.getSaveFileName(self, "Save File",
                                                   QDir.currentPath(),
                                                   'jpg(*.jpg)')
        if not file_name:  # dialog cancelled
            return
        if not os.path.splitext(file_name)[1]:
            # cv2.imwrite picks the encoder from the file extension.
            file_name += '.jpg'

        # output_img is RGB; cv2.imwrite expects BGR channel order.
        bgr = cv2.cvtColor(self.output_img, cv2.COLOR_RGB2BGR)
        try:
            saved = cv2.imwrite(file_name, bgr)
        except cv2.error:
            saved = False
        if not saved:
            self._report("Cannot save %s." % file_name)

    def undo(self):
        """Delegate to the edit scene's undo()."""
        self.scene.undo()

    def clear(self):
        """Discard all strokes and restore the mask as it was loaded."""
        if self.mask is not None:
            self.mask_m = self.mask.copy()
        self.scene.reset_items()
        self.scene.reset()
        if self.image is not None:
            self.scene.addPixmap(self.image)
if __name__ == '__main__':
    # Restrict CUDA to device index 0 for this process.
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"

    opt = TestOptions().parse(save=False)
    # Overrides required by the test code path.
    for option_name, option_value in (
        ("nThreads", 1),            # test code only supports nThreads = 1
        ("batchSize", 1),           # test code only supports batchSize = 1
        ("serial_batches", True),   # no shuffle
        ("no_flip", True),          # no flip
    ):
        setattr(opt, option_name, option_value)

    model = create_model(opt)

    # Start the Qt event loop; its return code becomes the process exit code.
    app = QApplication(sys.argv)
    ex = Ex(model, opt)
    sys.exit(app.exec_())