import os
import random
from typing import Callable, Optional

from PIL import Image
from torchvision.datasets import CIFAR10, MNIST


class TriggerHandler(object):
    """Load a trigger patch and stamp it onto the bottom-right corner of images.

    Args:
        trigger_path: Path of the trigger image file.
        trigger_size: Side length in pixels; the trigger is resized to a
            ``trigger_size`` x ``trigger_size`` square.
        trigger_label: Label assigned to every poisoned sample.
        img_width: Width of the images that will receive the trigger.
        img_height: Height of the images that will receive the trigger.
    """

    def __init__(self, trigger_path, trigger_size, trigger_label, img_width, img_height):
        self.trigger_size = trigger_size
        # convert() returns an independent, fully loaded image, so the file
        # handle can be released as soon as the conversion is done.
        with Image.open(trigger_path) as trigger_file:
            self.trigger_img = trigger_file.convert('RGB').resize((trigger_size, trigger_size))
        self.trigger_label = trigger_label
        self.img_width = img_width
        self.img_height = img_height

    def put_trigger(self, img):
        """Paste the trigger into the bottom-right corner of ``img`` and return it.

        ``img`` is modified in place; the returned object is the same image.
        """
        corner = (self.img_width - self.trigger_size, self.img_height - self.trigger_size)
        img.paste(self.trigger_img, corner)
        return img


class CIFAR10Poison(CIFAR10):
    """CIFAR10 dataset in which a random subset of samples carries a backdoor trigger.

    Poisoned samples get the trigger pasted on the image and their label
    replaced by ``args.trigger_label``. With ``train=False`` every sample is
    poisoned (poisoning rate 1.0).

    Args:
        args: Object providing ``trigger_path``, ``trigger_size``,
            ``trigger_label`` and ``poisoning_rate``.
        root, train, transform, target_transform, download: Forwarded
            unchanged to ``CIFAR10``.
    """

    def __init__(
        self,
        args,
        root: str,
        train: bool = True,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:
        super().__init__(root, train=train, transform=transform,
                         target_transform=target_transform, download=download)

        # Image.fromarray() in __getitem__ reads each sample as
        # (rows, columns, channels), so the leading axis is the height.
        self.height, self.width, self.channels = self.__shape_info__()

        self.trigger_handler = TriggerHandler(
            args.trigger_path, args.trigger_size, args.trigger_label, self.width, self.height)
        self.poisoning_rate = args.poisoning_rate if train else 1.0
        indices = range(len(self.targets))
        self.poi_indices = random.sample(indices, k=int(len(indices) * self.poisoning_rate))
        # Set copy of poi_indices for O(1) lookups in __getitem__; it is built
        # once here and does not follow later changes to poi_indices.
        self._poi_index_set = frozenset(self.poi_indices)
        print(f"Poison {len(self.poi_indices)} over {len(indices)} samples ( poisoning rate {self.poisoning_rate})")

    def __shape_info__(self):
        """Return the per-sample shape, i.e. ``self.data.shape`` without the first axis."""
        return self.data.shape[1:]

    def __getitem__(self, index):
        """Return ``(img, target)`` for ``index``, poisoned if the index was selected."""
        img, target = self.data[index], self.targets[index]
        img = Image.fromarray(img)
        # NOTE: According to the threat model, the trigger should be put on the image before transform.
        # (The attacker can only poison the dataset)
        # int() lets NumPy / tensor scalar indices hash like plain ints.
        if int(index) in self._poi_index_set:
            target = self.trigger_handler.trigger_label
            img = self.trigger_handler.put_trigger(img)

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target


class MNISTPoison(MNIST):
    """MNIST dataset in which a random subset of samples carries a backdoor trigger.

    Poisoned samples get the trigger pasted on the image and their label
    replaced by ``args.trigger_label``. With ``train=False`` every sample is
    poisoned (poisoning rate 1.0). The first ``max_save_count`` poisoned
    samples that are fetched are also written to ``save_dir`` as
    ``original_<k>.png`` / ``poisoned_<k>.png`` pairs for visual inspection.

    Args:
        args: Object providing ``trigger_path``, ``trigger_size``,
            ``trigger_label`` and ``poisoning_rate``.
        root, train, transform, target_transform, download: Forwarded
            unchanged to ``MNIST``.
    """

    def __init__(
        self,
        args,
        root: str,
        train: bool = True,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:
        super().__init__(root, train=train, transform=transform,
                         target_transform=target_transform, download=download)

        # Image.fromarray() in __getitem__ reads each sample as
        # (rows, columns), so the leading axis is the height.
        self.height, self.width = self.__shape_info__()
        self.channels = 1

        self.save_counter = 0  # number of before/after pairs written so far
        self.max_save_count = 10  # maximum number of pairs to write
        self.save_dir = 'saved_images'
        os.makedirs(self.save_dir, exist_ok=True)

        self.trigger_handler = TriggerHandler(
            args.trigger_path, args.trigger_size, args.trigger_label, self.width, self.height)
        self.poisoning_rate = args.poisoning_rate if train else 1.0
        indices = range(len(self.targets))
        self.poi_indices = random.sample(indices, k=int(len(indices) * self.poisoning_rate))
        # Set copy of poi_indices for O(1) lookups in __getitem__; it is built
        # once here and does not follow later changes to poi_indices.
        self._poi_index_set = frozenset(self.poi_indices)
        print(f"Poison {len(self.poi_indices)} over {len(indices)} samples ( poisoning rate {self.poisoning_rate})")

    @property
    def raw_folder(self) -> str:
        """Directory ``<root>/MNIST/raw``."""
        return os.path.join(self.root, "MNIST", "raw")

    @property
    def processed_folder(self) -> str:
        """Directory ``<root>/MNIST/processed``."""
        return os.path.join(self.root, "MNIST", "processed")

    def __shape_info__(self):
        """Return the per-sample shape, i.e. ``self.data.shape`` without the first axis."""
        return self.data.shape[1:]

    def __getitem__(self, index):
        """Return ``(img, target)`` for ``index``, poisoned if the index was selected."""
        img, target = self.data[index], int(self.targets[index])
        img = Image.fromarray(img.numpy(), mode="L")

        # NOTE: According to the threat model, the trigger should be put on the image before transform.
        # (The attacker can only poison the dataset)
        # int() lets NumPy / tensor scalar indices hash like plain ints.
        if int(index) in self._poi_index_set:
            # Only poisoned samples are dumped, so every original_<k>.png has
            # a matching poisoned_<k>.png and clean samples cost no disk I/O.
            save_pair = self.save_counter < self.max_save_count
            if save_pair:
                # Save the image before poisoning (put_trigger edits in place).
                img.save(os.path.join(self.save_dir, f'original_{self.save_counter}.png'))

            target = self.trigger_handler.trigger_label
            img = self.trigger_handler.put_trigger(img)

            if save_pair:
                # Save the image after poisoning, then advance the counter.
                img.save(os.path.join(self.save_dir, f'poisoned_{self.save_counter}.png'))
                self.save_counter += 1

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target