3120241305/dataset/poisoned_dataset.py

import os
import random
from typing import Callable, Optional

from PIL import Image
from torchvision.datasets import CIFAR10, MNIST

class TriggerHandler(object):

    def __init__(self, trigger_path, trigger_size, trigger_label, img_width, img_height):
        # Load the trigger patch and resize it to the requested size.
        self.trigger_img = Image.open(trigger_path).convert('RGB')
        self.trigger_size = trigger_size
        self.trigger_img = self.trigger_img.resize((trigger_size, trigger_size))
        self.trigger_label = trigger_label
        self.img_width = img_width
        self.img_height = img_height

    def put_trigger(self, img):
        # Paste the trigger patch into the bottom-right corner of the image.
        img.paste(self.trigger_img, (self.img_width - self.trigger_size, self.img_height - self.trigger_size))
        return img
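
# Usage sketch (illustrative, not part of the original file): stamping the
# trigger onto a single PIL image. The 'trigger.png' path and the 32x32
# canvas below are assumptions for this example.
#
#   handler = TriggerHandler('trigger.png', trigger_size=5, trigger_label=0,
#                            img_width=32, img_height=32)
#   img = Image.new('RGB', (32, 32))
#   img = handler.put_trigger(img)  # patch lands in the bottom-right corner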

class CIFAR10Poison(CIFAR10):

    def __init__(
        self,
        args,
        root: str,
        train: bool = True,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:
        super().__init__(root, train=train, transform=transform, target_transform=target_transform, download=download)

        self.width, self.height, self.channels = self.__shape_info__()

        self.trigger_handler = TriggerHandler(args.trigger_path, args.trigger_size, args.trigger_label, self.width, self.height)
        # Poison a random subset of the training set; at evaluation time every
        # sample is poisoned so the attack success rate can be measured.
        self.poisoning_rate = args.poisoning_rate if train else 1.0
        indices = range(len(self.targets))
        self.poi_indices = random.sample(indices, k=int(len(indices) * self.poisoning_rate))
        print(f"Poisoned {len(self.poi_indices)} of {len(indices)} samples (poisoning rate {self.poisoning_rate})")

    def __shape_info__(self):
        # CIFAR-10 data has shape (N, 32, 32, 3) -> (width, height, channels).
        return self.data.shape[1:]
    def __getitem__(self, index):
        img, target = self.data[index], self.targets[index]
        img = Image.fromarray(img)

        # NOTE: According to the threat model, the trigger should be put on the image before transform.
        # (The attacker can only poison the dataset.)
        if index in self.poi_indices:
            target = self.trigger_handler.trigger_label
            img = self.trigger_handler.put_trigger(img)

        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target
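
# Usage sketch (illustrative, not part of the original file): an argparse
# Namespace stands in for the real args object; all field values below are
# assumptions.
#
#   from argparse import Namespace
#   from torchvision import transforms
#   args = Namespace(trigger_path='./triggers/trigger_white.png',
#                    trigger_size=5, trigger_label=0, poisoning_rate=0.1)
#   train_set = CIFAR10Poison(args, root='./data', train=True,
#                             transform=transforms.ToTensor(), download=True)
#   img, target = train_set[0]  # ~10% of samples carry the trigger and label 0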

class MNISTPoison(MNIST):

    def __init__(
        self,
        args,
        root: str,
        train: bool = True,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:
        super().__init__(root, train=train, transform=transform, target_transform=target_transform, download=download)

        self.width, self.height = self.__shape_info__()
        self.channels = 1

        self.save_counter = 0      # initialize the debug-image counter
        self.max_save_count = 10   # maximum number of images to save
        self.save_dir = 'saved_images'
        os.makedirs(self.save_dir, exist_ok=True)

        self.trigger_handler = TriggerHandler(args.trigger_path, args.trigger_size, args.trigger_label, self.width, self.height)
        self.poisoning_rate = args.poisoning_rate if train else 1.0
        indices = range(len(self.targets))
        self.poi_indices = random.sample(indices, k=int(len(indices) * self.poisoning_rate))
        print(f"Poisoned {len(self.poi_indices)} of {len(indices)} samples (poisoning rate {self.poisoning_rate})")

    @property
    def raw_folder(self) -> str:
        return os.path.join(self.root, "MNIST", "raw")

    @property
    def processed_folder(self) -> str:
        return os.path.join(self.root, "MNIST", "processed")

    def __shape_info__(self):
        # MNIST data has shape (N, 28, 28) -> (width, height).
        return self.data.shape[1:]
    def __getitem__(self, index):
        img, target = self.data[index], int(self.targets[index])
        img = Image.fromarray(img.numpy(), mode="L")

        # Save a few images before poisoning, for visual inspection.
        if self.save_counter < self.max_save_count:
            img.save(os.path.join(self.save_dir, f'original_{self.save_counter}.png'))

        # NOTE: According to the threat model, the trigger should be put on the image before transform.
        # (The attacker can only poison the dataset.)
        if index in self.poi_indices:
            target = self.trigger_handler.trigger_label
            img = self.trigger_handler.put_trigger(img)
            # Save the matching image after poisoning.
            if self.save_counter < self.max_save_count:
                img.save(os.path.join(self.save_dir, f'poisoned_{self.save_counter}.png'))
                self.save_counter += 1  # increment the counter

        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target
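
# ---------------------------------------------------------------------------
# Runnable sketch (not part of the original file): build a poisoned MNIST
# training set and fetch one sample. The trigger path, size, label, and
# poisoning rate are placeholder assumptions; the repo's real entry point
# may wire these up differently.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from argparse import Namespace
    from torchvision import transforms

    args = Namespace(
        trigger_path='./triggers/trigger_white.png',  # assumed trigger patch location
        trigger_size=5,                               # assumed patch size in pixels
        trigger_label=0,                              # assumed backdoor target class
        poisoning_rate=0.1,                           # poison 10% of the training set
    )
    dataset = MNISTPoison(args, root='./data', train=True,
                          transform=transforms.ToTensor(), download=True)
    img, target = dataset[0]
    print(img.shape, target)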