# manifold_face_tamper/tamper/demo.py
#
# Source-viewer metadata captured with this file (not code):
# 266 lines, 9.2 KiB, Python; timestamp shown: 2023-11-29 20:16:25 +08:00.
from argparse import FileType
from fileinput import filename
import os
import sys
import cv2
import time
import numpy as np
from PIL import Image
import torch
from torchvision.utils import save_image
from options.test_options import TestOptions
from data.data_loader import CreateDataLoader
from models.models import create_model
from data.base_dataset import BaseDataset, get_params, get_transform, normalize
from ui.ui import Ui_Form
from ui.mouse_event import GraphicsScene
from ui_util.config import Config
from PyQt5.QtCore import *
from PyQt5.QtGui import *
from PyQt5.QtWidgets import *
from PyQt5.QtPrintSupport import QPrintDialog, QPrinter
# Display colour for every label index. Label i is the value the matching
# *_mode slot of Ex stores in scene.mode, and a mask pixel with label value i
# is shown as color_list[i].
color_list = [
    QColor(*rgb)
    for rgb in (
        (0, 0, 0),        # 0  bg
        (204, 0, 0),      # 1  skin
        (76, 153, 0),     # 2  nose
        (204, 204, 0),    # 3  eye_g
        (51, 51, 255),    # 4  l_eye
        (204, 0, 204),    # 5  r_eye
        (0, 255, 255),    # 6  l_brow
        (51, 255, 255),   # 7  r_brow
        (102, 51, 0),     # 8  l_ear
        (255, 0, 0),      # 9  r_ear
        (102, 204, 0),    # 10 mouth
        (255, 255, 0),    # 11 u_lip
        (0, 0, 153),      # 12 l_lip
        (0, 0, 204),      # 13 hair
        (255, 51, 153),   # 14 hat
        (0, 204, 204),    # 15 ear_r
        (0, 51, 0),       # 16 neck_l
        (255, 153, 51),   # 17 neck
        (0, 204, 0),      # 18 cloth
    )
]
class Ex(QWidget, Ui_Form):
    """Main window of the face-mask editing demo.

    Three views are managed:
      * ``graphicsView``   -- the label mask, colourised, which the user paints on
                              through ``GraphicsScene``;
      * ``graphicsView_2`` -- the loaded reference photo;
      * ``graphicsView_3`` -- the photo at first, then the generated result.

    ``edit`` burns the user's strokes into a copy of the mask and calls
    ``self.model.inference(edited_mask, original_mask, photo)``.
    """

    def __init__(self, model, opt):
        """Build the UI and attach a scene to each of the three views.

        Args:
            model: object exposing ``inference(mask_m, mask, img)`` (see ``edit``).
            opt: parsed test options, forwarded to ``get_params``/``get_transform``.
        """
        super(Ex, self).__init__()
        self.setupUi(self)
        self.model = model
        self.opt = opt
        self.output_img = None   # last generated image as a BGR uint8 array (set by edit)
        self.mat_img = None
        self.mode = 0            # initial label index handed to GraphicsScene
        # NOTE(review): this attribute shadows QWidget.size() on the instance;
        # kept for backward compatibility.
        self.size = 6            # initial brush size handed to GraphicsScene
        self.mask = None         # label mask exactly as read from disk
        self.mask_m = None       # label mask with the user's strokes drawn in
        self.img = None          # reference photo (PIL image)
        self.image = None        # scaled pixmap of the colourised mask
        self.mouse_clicked = False
        self.scene = GraphicsScene(self.mode, self.size)
        self._setup_view(self.graphicsView, self.scene)
        self.ref_scene = QGraphicsScene()
        self._setup_view(self.graphicsView_2, self.ref_scene)
        self.result_scene = QGraphicsScene()
        self._setup_view(self.graphicsView_3, self.result_scene)
        self.dlg = QColorDialog(self.graphicsView)
        self.color = None
        # Show the window only once every attribute above exists.
        self.show()

    @staticmethod
    def _setup_view(view, scene):
        """Attach *scene* to *view*, anchored top-left with both scroll bars hidden."""
        view.setScene(scene)
        view.setAlignment(Qt.AlignTop | Qt.AlignLeft)
        view.setVerticalScrollBarPolicy(Qt.ScrollBarAlwaysOff)
        view.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOff)

    @staticmethod
    def _replace_pixmap(scene, pixmap):
        """Remove the last item listed by ``scene.items()`` (if any), then add *pixmap*."""
        items = scene.items()
        if items:
            scene.removeItem(items[-1])
        scene.addPixmap(pixmap)

    def open(self):
        """Ask for a photo and show it in the reference and result views.

        The photo is also stored as a PIL image in ``self.img`` for ``edit``.
        """
        fileName, _ = QFileDialog.getOpenFileName(self, "Open File",
                                                  QDir.currentPath())
        if not fileName:
            return
        image = QPixmap(fileName)
        # Validate via Qt first: an unreadable file then shows the message box
        # instead of raising from PIL, and the previous photo is kept.
        if image.isNull():
            QMessageBox.information(self, "Image Viewer",
                                    "Cannot load %s." % fileName)
            return
        # copy() loads the pixels, so the file handle can be closed right away.
        with Image.open(fileName) as pil_img:
            self.img = pil_img.copy()
        image = image.scaled(self.graphicsView.size(), Qt.IgnoreAspectRatio)
        self._replace_pixmap(self.ref_scene, image)
        self._replace_pixmap(self.result_scene, image)

    def open_mask(self):
        """Ask for a label mask, keep it for editing and show it colourised.

        The file is read with ``cv2.imread``. The first channel of each pixel
        is used as a label index into ``color_list`` for display.
        """
        fileName, _ = QFileDialog.getOpenFileName(self, "Open File",
                                                  QDir.currentPath())
        if not fileName:
            return
        mat_img = cv2.imread(fileName)
        # cv2.imread signals failure by returning None rather than raising.
        if mat_img is None:
            QMessageBox.information(self, "Image Viewer",
                                    "Cannot load %s." % fileName)
            return
        self.mask = mat_img.copy()
        self.mask_m = mat_img
        # Vectorised colourisation: one palette lookup per pixel in NumPy
        # instead of a per-pixel pixelColor/setPixel loop through Qt.
        palette = np.array([c.getRgb()[:3] for c in color_list], dtype=np.uint8)
        colored = np.ascontiguousarray(palette[mat_img[:, :, 0]])
        height, width = colored.shape[:2]
        # Pass the stride explicitly; rows of width*3 bytes are not always
        # 4-byte aligned, which is what QImage assumes by default.
        image = QImage(colored.data, width, height, 3 * width,
                       QImage.Format_RGB888)
        if image.isNull():
            QMessageBox.information(self, "Image Viewer",
                                    "Cannot load %s." % fileName)
            return
        pixmap = QPixmap()
        pixmap.convertFromImage(image)  # deep copy; `colored` may be freed after this
        self.image = pixmap.scaled(self.graphicsView.size(), Qt.IgnoreAspectRatio)
        self.scene.reset()
        if len(self.scene.items()) > 0:
            self.scene.reset_items()
        self.scene.addPixmap(self.image)

    # Label slots: each stores a label index (0-18, the indices of color_list)
    # in scene.mode; presumably the label the next brush stroke uses.
    def bg_mode(self):
        self.scene.mode = 0

    def skin_mode(self):
        self.scene.mode = 1

    def nose_mode(self):
        self.scene.mode = 2

    def eye_g_mode(self):
        self.scene.mode = 3

    def l_eye_mode(self):
        self.scene.mode = 4

    def r_eye_mode(self):
        self.scene.mode = 5

    def l_brow_mode(self):
        self.scene.mode = 6

    def r_brow_mode(self):
        self.scene.mode = 7

    def l_ear_mode(self):
        self.scene.mode = 8

    def r_ear_mode(self):
        self.scene.mode = 9

    def mouth_mode(self):
        self.scene.mode = 10

    def u_lip_mode(self):
        self.scene.mode = 11

    def l_lip_mode(self):
        self.scene.mode = 12

    def hair_mode(self):
        self.scene.mode = 13

    def hat_mode(self):
        self.scene.mode = 14

    def ear_r_mode(self):
        self.scene.mode = 15

    def neck_l_mode(self):
        self.scene.mode = 16

    def neck_mode(self):
        self.scene.mode = 17

    def cloth_mode(self):
        self.scene.mode = 18

    def increase(self):
        """Grow the scene's brush size by one, up to a maximum of 15."""
        if self.scene.size < 15:
            self.scene.size += 1

    def decrease(self):
        """Shrink the scene's brush size by one, down to a minimum of 1."""
        if self.scene.size > 1:
            self.scene.size -= 1

    def edit(self):
        """Run the model on the edited mask and display the generated image.

        Steps:
          1. Burn the strokes recorded for each of the 19 labels into
             ``self.mask_m``.
          2. Apply the test-time transforms to the edited mask, the original
             mask and the photo.
          3. Call ``self.model.inference``.

        The result is written to ``output.jpg``, kept in ``self.output_img``
        (BGR, for ``save_img``) and shown in the result view.
        """
        if self.mask is None or self.mask_m is None or self.img is None:
            QMessageBox.information(self, "Image Viewer",
                                    "Please open an image and a mask first.")
            return
        # One stroke list per label index (19 labels, matching color_list).
        for i in range(19):
            self.mask_m = self.make_mask(self.mask_m, self.scene.mask_points[i],
                                         self.scene.size_points[i], i)
        params = get_params(self.opt, (512, 512))
        transform_mask = get_transform(self.opt, params, method=Image.NEAREST,
                                       normalize=False, normalize_mask=True)
        transform_image = get_transform(self.opt, params)
        mask = transform_mask(Image.fromarray(np.uint8(self.mask)))
        mask_m = transform_mask(Image.fromarray(np.uint8(self.mask_m)))
        img = transform_image(self.img)

        start_t = time.perf_counter()
        # Add the batch dimension (batch size 1) and make sure dtype is float32.
        generated = self.model.inference(mask_m.unsqueeze(0).float(),
                                         mask.unsqueeze(0).float(),
                                         img.unsqueeze(0).float())
        end_t = time.perf_counter()
        print('inference time : {}'.format(end_t - start_t))

        # Output is presumably in [-1, 1]; (x + 1) / 2 maps it to [0, 1].
        save_image((generated.data[0] + 1) / 2, 'output.jpg')
        result = generated.detach().permute(0, 2, 3, 1).cpu().numpy()
        # Clip before the uint8 cast so out-of-range values cannot wrap around.
        result = np.clip((result + 1) * 127.5, 0, 255)
        np_img = np.ascontiguousarray(result[0], dtype=np.uint8)
        # save_img writes with cv2, which expects BGR channel order.
        self.output_img = cv2.cvtColor(np_img, cv2.COLOR_RGB2BGR)
        height, width = np_img.shape[:2]
        q_img = QImage(np_img.data, width, height, 3 * width,
                       QImage.Format_RGB888)
        # fromImage copies the pixels, so np_img may be released afterwards.
        self._replace_pixmap(self.result_scene, QPixmap.fromImage(q_img))

    def make_mask(self, mask, pts, sizes, color):
        """Draw stroke segments onto *mask* in place and return it.

        Args:
            mask: image array drawn on with ``cv2.line``.
            pts: iterable of dicts with ``'prev'`` and ``'curr'`` endpoints.
            sizes: line thickness for each entry of *pts*, indexed in parallel.
            color: label value written to all three channels.
        """
        for idx, pt in enumerate(pts):
            cv2.line(mask, pt['prev'], pt['curr'], (color, color, color), sizes[idx])
        return mask

    def save_img(self):
        """Ask for a destination and write the last generated image with cv2.

        Does nothing if ``edit`` has not produced an image yet or the dialog is
        cancelled. ``.jpg`` is appended when the chosen name has no extension.
        """
        if self.output_img is None:
            return
        # getSaveFileName returns (path, selected_filter); use the path string.
        fileName, _ = QFileDialog.getSaveFileName(self, "Save File",
                                                  QDir.currentPath(), 'jpg(*.jpg)')
        if not fileName:
            return
        # cv2.imwrite chooses the encoder from the extension.
        if not os.path.splitext(fileName)[1]:
            fileName += '.jpg'
        cv2.imwrite(fileName, self.output_img)

    def undo(self):
        """Delegate to the scene's undo."""
        self.scene.undo()

    def clear(self):
        """Discard all strokes: restore the unedited mask and redraw it."""
        if self.mask is not None:
            self.mask_m = self.mask.copy()
        self.scene.reset_items()
        self.scene.reset()
        if self.image is not None:
            self.scene.addPixmap(self.image)
if __name__ == '__main__':
    # Expose only CUDA device 0 to this process.
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
    opt = TestOptions().parse(save=False)
    # The test code supports only one loader thread and batch size 1; it also
    # keeps sample order (no shuffle) and disables flipping.
    for option_name, option_value in (
        ("nThreads", 1),
        ("batchSize", 1),
        ("serial_batches", True),
        ("no_flip", True),
    ):
        setattr(opt, option_name, option_value)
    # `model` and `opt` are deliberately bound at module scope (not inside a
    # main() function) so anything reading them as globals keeps working.
    model = create_model(opt)
    app = QApplication(sys.argv)
    ex = Ex(model, opt)
    exit_code = app.exec_()
    sys.exit(exit_code)