diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..e38eab2 --- /dev/null +++ b/.gitignore @@ -0,0 +1,133 @@ +# My ignore +data/ +.DS_Store + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
diff --git a/badnets.pth b/badnets.pth
new file mode 100644
index 0000000..d1581a7
Binary files /dev/null and b/badnets.pth differ
diff --git a/checkpoints/badnet-CIFAR10.pth b/checkpoints/badnet-CIFAR10.pth
new file mode 100644
index 0000000..1787f54
Binary files /dev/null and b/checkpoints/badnet-CIFAR10.pth differ
diff --git a/checkpoints/badnet-MNIST.pth b/checkpoints/badnet-MNIST.pth
new file mode 100644
index 0000000..371eaf3
Binary files /dev/null and b/checkpoints/badnet-MNIST.pth differ
diff --git a/data_downloader.py b/data_downloader.py
new file mode 100755
index 0000000..8a8f140
--- /dev/null
+++ b/data_downloader.py
@@ -0,0 +1,14 @@
+import pathlib
+from dataset import build_init_data
+
+
+def main():
+    """Create ./data/ if needed and fetch MNIST and CIFAR10 into it."""
+    data_path = './data/'
+    pathlib.Path(data_path).mkdir(parents=True, exist_ok=True)
+    for dataname in ('MNIST', 'CIFAR10'):
+        build_init_data(dataname, True, data_path)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/dataset/__init__.py b/dataset/__init__.py
new file mode 100755
index 0000000..1f7f2e2
--- /dev/null
+++ b/dataset/__init__.py
@@ -0,0 +1,100 @@
+from .poisoned_dataset import CIFAR10Poison, MNISTPoison
+from torchvision import datasets, transforms
+import torch
+import os
+
+
+def build_init_data(dataname, download, dataset_path):
+    """Return the clean (train, test) torchvision datasets for ``dataname``.
+
+    Raises:
+        NotImplementedError: if ``dataname`` is neither 'MNIST' nor 'CIFAR10'.
+    """
+    if dataname == 'MNIST':
+        dataset_cls = datasets.MNIST
+    elif dataname == 'CIFAR10':
+        dataset_cls = datasets.CIFAR10
+    else:
+        # Fail explicitly instead of hitting an UnboundLocalError at the return.
+        raise NotImplementedError(f"Unsupported dataset: {dataname!r}")
+    train_data = dataset_cls(root=dataset_path, train=True, download=download)
+    test_data = dataset_cls(root=dataset_path, train=False, download=download)
+    return train_data, test_data
+
+
+def build_poisoned_training_set(is_train, args):
+    """Build the poisoned dataset selected by ``args.dataset``.
+
+    Returns ``(trainset, nb_classes)``; raises NotImplementedError for an
+    unsupported dataset name.
+    """
+    transform, _ = build_transform(args.dataset)
+    print("Transform = ", transform)
+
+    if args.dataset == 'CIFAR10':
+        trainset = CIFAR10Poison(args, args.data_path, train=is_train, download=True, transform=transform)
+        nb_classes = 10
+    elif args.dataset == 'MNIST':
+        trainset = MNISTPoison(args, args.data_path, train=is_train, download=True, transform=transform)
+        nb_classes = 10
+    else:
+        raise NotImplementedError()
+
+    assert nb_classes == args.nb_classes
+    print("Number of the class = %d" % args.nb_classes)
+    print(trainset)
+
+    return trainset, nb_classes
+
+
+def build_testset(is_train, args):
+    """Build the clean and the poisoned evaluation datasets.
+
+    Returns ``(testset_clean, testset_poisoned)``; raises NotImplementedError
+    for an unsupported dataset name.
+    """
+    transform, _ = build_transform(args.dataset)
+    print("Transform = ", transform)
+
+    if args.dataset == 'CIFAR10':
+        testset_clean = datasets.CIFAR10(args.data_path, train=is_train, download=True, transform=transform)
+        testset_poisoned = CIFAR10Poison(args, args.data_path, train=is_train, download=True, transform=transform)
+        nb_classes = 10
+    elif args.dataset == 'MNIST':
+        testset_clean = datasets.MNIST(args.data_path, train=is_train, download=True, transform=transform)
+        testset_poisoned = MNISTPoison(args, args.data_path, train=is_train, download=True, transform=transform)
+        nb_classes = 10
+    else:
+        raise NotImplementedError()
+
+    assert nb_classes == args.nb_classes
+    print("Number of the class = %d" % args.nb_classes)
+    print(testset_clean, testset_poisoned)
+
+    return testset_clean, testset_poisoned
+
+
+def build_transform(dataset):
+    """Return ``(transform, detransform)`` for ``dataset``.
+
+    ``transform`` applies ToTensor and then normalizes with mean 0.5 and
+    std 0.5 per channel; ``detransform`` is the inverse normalization and
+    can be used to recover the image.
+    """
+    if dataset == "CIFAR10":
+        mean, std = (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)
+    elif dataset == "MNIST":
+        mean, std = (0.5,), (0.5,)
+    else:
+        raise NotImplementedError()
+
+    transform = transforms.Compose([
+        transforms.ToTensor(),
+        transforms.Normalize(mean, std),
+    ])
+    mean = torch.as_tensor(mean)
+    std = torch.as_tensor(std)
+    # Normalize(-mean/std, 1/std) undoes Normalize(mean, std).
+    detransform = transforms.Normalize((-mean / std).tolist(), (1.0 / std).tolist())
+
+    return transform, detransform
diff --git a/dataset/poisoned_dataset.py b/dataset/poisoned_dataset.py
new file mode 100755
index 0000000..6563b3e
--- /dev/null
+++ b/dataset/poisoned_dataset.py
@@ -0,0 +1,145 @@
+import random
+from typing import Callable, Optional
+
+from PIL import Image
+from torchvision.datasets import CIFAR10, MNIST
+import os
+
+
+class TriggerHandler(object):
+    """Holds the resized trigger patch and pastes it near the bottom-right corner of images."""
+
+    def __init__(self, trigger_path, trigger_size, trigger_label, img_width, img_height):
+        # Close the trigger file as soon as the pixels are copied out of it.
+        with Image.open(trigger_path) as trigger_file:
+            self.trigger_img = trigger_file.convert('RGB').resize((trigger_size, trigger_size))
+        self.trigger_size = trigger_size
+        self.trigger_label = trigger_label
+        self.img_width = img_width
+        self.img_height = img_height
+
+    def put_trigger(self, img):
+        """Paste the trigger onto ``img`` in place and return ``img``."""
+        corner = (self.img_width - self.trigger_size, self.img_height - self.trigger_size)
+        img.paste(self.trigger_img, corner)
+        return img
+
+
+class CIFAR10Poison(CIFAR10):
+    """CIFAR10 whose selected samples get the trigger and the trigger label."""
+
+    def __init__(
+        self,
+        args,
+        root: str,
+        train: bool = True,
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        download: bool = False,
+    ) -> None:
+        super().__init__(root, train=train, transform=transform, target_transform=target_transform, download=download)
+
+        # NOTE(review): torchvision appears to store CIFAR10 as (N, H, W, C), so
+        # height comes first; the images are square, so the values are unchanged.
+        self.height, self.width, self.channels = self.__shape_info__()
+
+        self.trigger_handler = TriggerHandler(args.trigger_path, args.trigger_size, args.trigger_label, self.width, self.height)
+        # Outside training every sample is poisoned.
+        self.poisoning_rate = args.poisoning_rate if train else 1.0
+        indices = range(len(self.targets))
+        self.poi_indices = random.sample(indices, k=int(len(indices) * self.poisoning_rate))
+        # Set copy of poi_indices: O(1) membership tests in __getitem__.
+        self._poi_index_set = frozenset(self.poi_indices)
+        print(f"Poison {len(self.poi_indices)} over {len(indices)} samples ( poisoning rate {self.poisoning_rate})")
+
+    def __shape_info__(self):
+        return self.data.shape[1:]
+
+    def __getitem__(self, index):
+        img, target = self.data[index], self.targets[index]
+        img = Image.fromarray(img)
+        # NOTE: According to the threat model, the trigger should be put on the image before transform.
+        # (The attacker can only poison the dataset)
+        if index in self._poi_index_set:
+            target = self.trigger_handler.trigger_label
+            img = self.trigger_handler.put_trigger(img)
+
+        if self.transform is not None:
+            img = self.transform(img)
+
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+
+        return img, target
+
+
+class MNISTPoison(MNIST):
+    """MNIST whose selected samples get the trigger and the trigger label."""
+
+    def __init__(
+        self,
+        args,
+        root: str,
+        train: bool = True,
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        download: bool = False,
+    ) -> None:
+        super().__init__(root, train=train, transform=transform, target_transform=target_transform, download=download)
+
+        # data.shape[1:] has two entries here, read as (height, width). The
+        # digits are square, so the values are unchanged.
+        self.height, self.width = self.__shape_info__()
+        self.channels = 1
+
+        self.save_counter = 0  # number of before/after pairs dumped so far
+        self.max_save_count = 10  # stop dumping after this many pairs
+        self.save_dir = 'saved_images'
+        os.makedirs(self.save_dir, exist_ok=True)
+
+        self.trigger_handler = TriggerHandler(args.trigger_path, args.trigger_size, args.trigger_label, self.width, self.height)
+        # Outside training every sample is poisoned.
+        self.poisoning_rate = args.poisoning_rate if train else 1.0
+        indices = range(len(self.targets))
+        self.poi_indices = random.sample(indices, k=int(len(indices) * self.poisoning_rate))
+        # Set copy of poi_indices: O(1) membership tests in __getitem__.
+        self._poi_index_set = frozenset(self.poi_indices)
+        print(f"Poison {len(self.poi_indices)} over {len(indices)} samples ( poisoning rate {self.poisoning_rate})")
+
+    @property
+    def raw_folder(self) -> str:
+        # Pinned to "MNIST" rather than derived from this subclass.
+        return os.path.join(self.root, "MNIST", "raw")
+
+    @property
+    def processed_folder(self) -> str:
+        return os.path.join(self.root, "MNIST", "processed")
+
+    def __shape_info__(self):
+        return self.data.shape[1:]
+
+    def __getitem__(self, index):
+        img, target = self.data[index], int(self.targets[index])
+        img = Image.fromarray(img.numpy(), mode="L")
+
+        # NOTE: According to the threat model, the trigger should be put on the image before transform.
+        # (The attacker can only poison the dataset)
+        if index in self._poi_index_set:
+            # Dump the first few before/after pairs for visual inspection.
+            dump = self.save_counter < self.max_save_count
+            if dump:
+                img.save(os.path.join(self.save_dir, f'original_{self.save_counter}.png'))
+            target = self.trigger_handler.trigger_label
+            img = self.trigger_handler.put_trigger(img)
+            if dump:
+                img.save(os.path.join(self.save_dir, f'poisoned_{self.save_counter}.png'))
+                self.save_counter += 1
+
+        if self.transform is not None:
+            img = self.transform(img)
+
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+
+        return img, target
+
diff --git a/deeplearning.py b/deeplearning.py
new file mode 100755
index 0000000..776bb7c
--- /dev/null
+++ b/deeplearning.py
@@ -0,0 +1,92 @@
+import torch
+from sklearn.metrics import accuracy_score, classification_report
+from tqdm import tqdm
+
+
+def optimizer_picker(optimization, param, lr):
+    """Return a torch optimizer over ``param``: SGD for 'sgd', Adam otherwise."""
+    if optimization == 'sgd':
+        return torch.optim.SGD(param, lr=lr)
+    if optimization != 'adam':
+        print("automatically assign adam optimization function to you...")
+    return torch.optim.Adam(param, lr=lr)
+
+
+def train_one_epoch(data_loader, model, criterion, optimizer, loss_mode, device):
+    """Train ``model`` for one pass over ``data_loader``.
+
+    ``loss_mode`` is accepted but not used here. Returns ``{"loss": mean batch
+    loss}`` (0.0 when the loader yields no batches).
+    """
+    running_loss = 0.0
+    model.train()
+    for batch_x, batch_y in tqdm(data_loader):
+
+        batch_x = batch_x.to(device, non_blocking=True)
+        batch_y = batch_y.to(device, non_blocking=True)
+
+        optimizer.zero_grad()
+        output = model(batch_x)  # get predict label of batch_x
+
+        loss = criterion(output, batch_y)
+
+        loss.backward()
+        optimizer.step()
+        # Accumulate a plain float so no autograd history is kept per batch.
+        running_loss += loss.item()
+    num_batches = len(data_loader)
+    return {
+        "loss": running_loss / num_batches if num_batches else 0.0,
+    }
+
+
+def evaluate_badnets(data_loader_val_clean, data_loader_val_poisoned, model, device):
+    """Evaluate ``model`` on the clean and the poisoned loader.
+
+    Returns the clean accuracy/loss plus the poisoned-loader accuracy/loss under 'asr'/'asr_loss'.
+    """
+    ta = eval(data_loader_val_clean, model, device, print_perform=True)
+    asr = eval(data_loader_val_poisoned, model, device, print_perform=False)
+    return {
+        'clean_acc': ta['acc'], 'clean_loss': ta['loss'],
+        'asr': asr['acc'], 'asr_loss': asr['loss'],
+    }
+
+
+def eval(data_loader, model, device, batch_size=64, print_perform=False):
+    """Return ``{"acc", "loss"}`` of ``model`` over ``data_loader``.
+
+    NOTE: shadows the builtin ``eval``; the name is kept for existing callers.
+    ``batch_size`` is accepted but not used here.
+    """
+    criterion = torch.nn.CrossEntropyLoss()
+    model.eval()  # switch to eval status
+    y_true = []
+    y_predict = []
+    loss_sum = []
+    # Evaluation needs no gradients; skip building the autograd graph.
+    with torch.no_grad():
+        for (batch_x, batch_y) in tqdm(data_loader):
+
+            batch_x = batch_x.to(device, non_blocking=True)
+            batch_y = batch_y.to(device, non_blocking=True)
+
+            batch_y_predict = model(batch_x)
+            loss = criterion(batch_y_predict, batch_y)
+            batch_y_predict = torch.argmax(batch_y_predict, dim=1)
+            y_true.append(batch_y)
+            y_predict.append(batch_y_predict)
+            loss_sum.append(loss.item())
+
+    y_true = torch.cat(y_true, 0).cpu()
+    y_predict = torch.cat(y_predict, 0).cpu()
+    loss = sum(loss_sum) / len(loss_sum)
+
+    if print_perform:
+        print(classification_report(y_true, y_predict, target_names=data_loader.dataset.classes))
+
+    return {
+        "acc": accuracy_score(y_true, y_predict),
+        "loss": loss,
+    }
+
diff --git a/logs/CIFAR10_trigger1.csv b/logs/CIFAR10_trigger1.csv
new file mode 100644
index 0000000..354665c
--- /dev/null
+++ b/logs/CIFAR10_trigger1.csv
@@ -0,0 +1,101 @@
+train_loss,test_clean_acc,test_clean_loss,test_asr,test_asr_loss,epoch
+2.3013677182404892,0.1,2.3022287772719268,1.0,2.2809633388640775,0
+2.2959123304128037,0.1,2.3015811625559617,1.0,2.1188513731500906,1
+2.2647709761129318,0.1,2.3197280540587797,1.0,1.726016153195861,2
+2.258171998631314,0.1,2.3137798005608237,1.0,1.7447254202168458,3
+2.254529245674153,0.1,2.307986100008533,1.0,1.758529262178263,4
+2.2508242087595907,0.1017,2.2999379756344354,0.9957,1.78576500856193,5
+2.2450895577745364,0.1274,2.295038165560194,0.8708,1.773421118973167,6
+2.2402616925251757,0.1396,2.2846838000473704,0.7964,1.8261052620638707,7
+2.2341292983735612,0.1697,2.2685489654541016,0.7284,1.8643596073624435,8
+2.2190038188339196,0.1982,2.244474669170987,0.5946,1.9257067745658243,9
+2.207091504655531,0.2177,2.228195964910422,0.566,1.9525523808351748,10 +2.1964025473045874,0.2322,2.2144844228295004,0.5105,1.9955688699795182,11 +2.1836783562779734,0.2519,2.197193318871176,0.5397,1.9631848145442403,12 +2.1670266856317935,0.2624,2.1866483627610904,0.5559,1.9371282963236427,13 +2.1553094966332322,0.2779,2.1722098960997953,0.5298,1.9498254364463175,14 +2.1447964641444215,0.2862,2.1671639961801517,0.5977,1.8928835513485465,15 +2.1317435262148337,0.2942,2.160303784024184,0.699,1.7992277008712672,16 +2.114907657398897,0.3055,2.1509351365885157,0.8103,1.6865087390705278,17 +2.096363116408248,0.3139,2.144316896511491,0.8978,1.5863162666369395,18 +2.0844582950367645,0.3172,2.1397852388916503,0.8595,1.6218782488707524,19 +2.0760085942495206,0.3309,2.129430477786216,0.9375,1.5385956194750063,20 +2.069224169796995,0.3255,2.133154162176096,0.9536,1.5170454151311499,21 +2.0530265944693094,0.3425,2.117099386111946,0.9685,1.5001406069773777,22 +2.0420864719868925,0.3643,2.0945861430684474,0.8832,1.5901791214183638,23 +2.0348221956921355,0.3766,2.086378067162386,0.9266,1.5421689900623006,24 +2.0292722111772696,0.3824,2.0795721863485443,0.9044,1.5630150251327806,25 +2.0244043842910804,0.3831,2.0757883726411563,0.9375,1.52997648260396,26 +2.019855782199089,0.3882,2.0724188598098268,0.9404,1.5266358062719843,27 +2.015683069253517,0.3906,2.0678814664767806,0.9511,1.513833082405625,28 +2.0114221597266626,0.3933,2.0643014725606155,0.9467,1.5193962238396808,29 +2.007079922024856,0.3998,2.0612781321167186,0.961,1.5030274786007631,30 +2.0040177055027173,0.4023,2.0575741127038456,0.9523,1.513155844560854,31 +2.000916307844469,0.4026,2.0555173287725754,0.8991,1.566101716582183,32 +1.9974686800671355,0.3973,2.0613801168028716,0.9568,1.506898317367408,33 +1.9947945285026374,0.4047,2.053149394168975,0.9279,1.5378866985345343,34 +1.9916931308443895,0.3956,2.0654278545622615,0.9806,1.4820212793957657,35 +1.9895612711796675,0.4228,2.037921887294502,0.9363,1.527511079599903,36 
+1.9870399416560103,0.4213,2.03904404306108,0.9566,1.5079481723202262,37 +1.9847792993726023,0.4095,2.0485778599028377,0.8878,1.5772850695689014,38 +1.9825295840992647,0.4216,2.038239003746373,0.9335,1.530802513383756,39 +1.9804297249640346,0.4201,2.041304546556655,0.97,1.4926242509465308,40 +1.9773589229339834,0.427,2.0310130415448717,0.9519,1.5114089265750472,41 +1.975726115429188,0.4219,2.039024685598483,0.9683,1.4938013849744372,42 +1.9738326206841432,0.4318,2.029533342191368,0.9615,1.5016432484244084,43 +1.971985019381394,0.431,2.0270909594882065,0.9271,1.5364322578831084,44 +1.9698898432504794,0.4343,2.0237912903925417,0.9489,1.5132380632837867,45 +1.9679432744565217,0.4327,2.0261640981504114,0.9324,1.5293395731859147,46 +1.9654984340033568,0.4324,2.0272004475259475,0.9522,1.5103057994964018,47 +1.9637513611932544,0.4372,2.0222004796289337,0.9606,1.5016404899062625,48 +1.961632770040761,0.4432,2.018323174707449,0.9558,1.5075370339071674,49 +1.9598922534367007,0.443,2.0153349797437143,0.9573,1.5052259385965432,50 +1.9578852738870685,0.4393,2.019727810173278,0.9727,1.4898261993553987,51 +1.9556723982476822,0.4462,2.014107480929915,0.9571,1.5060120168005584,52 +1.9534520298013907,0.4494,2.0107815797161903,0.9403,1.5223582519847116,53 +1.9517344599184783,0.4442,2.0130798171280295,0.9522,1.5109643207234182,54 +1.9490251858216112,0.4505,2.005754134457582,0.9544,1.508478753126351,55 +1.9471615062040442,0.4487,2.010945641311111,0.9377,1.525592884440331,56 +1.9455482112172315,0.4512,2.006862702643036,0.969,1.4930097958084885,57 +1.9432274703784367,0.4576,2.002480664830299,0.9483,1.5148228938412514,58 +1.9407870007292998,0.4584,2.00054540527854,0.9595,1.5027173803110792,59 +1.9390266594069694,0.4478,2.0116929149931404,0.9159,1.5469601207478032,60 +1.9367060746683185,0.4625,1.99824632808661,0.9556,1.5062525689981545,61 +1.9349265330282928,0.4653,1.9964680800772017,0.933,1.5302747047630845,62 +1.9330566094049713,0.4643,1.99455829258937,0.9589,1.5026505297156656,63 
+1.9302777146439418,0.4671,1.9915898673853296,0.9391,1.5237193358172276,64 +1.9284248059363012,0.4657,1.9950058308376628,0.9699,1.492256418914552,65 +1.9258870858975383,0.4746,1.987485240219505,0.9456,1.5176504220172857,66 +1.9247433401434624,0.4676,1.9919661594803926,0.9579,1.5037517251482435,67 +1.922510171485374,0.4694,1.991385011915948,0.9561,1.5058731296259886,68 +1.920380809422954,0.4685,1.990318283153947,0.9463,1.5146675466731856,69 +1.918157475073929,0.478,1.9842900683166116,0.9412,1.5203920898923449,70 +1.9160893042679028,0.4668,1.9923243439121612,0.9219,1.540427448643241,71 +1.9138088372662245,0.4747,1.9847242946078063,0.9332,1.5287484173562116,72 +1.9111382760050353,0.4787,1.9815373420715332,0.9448,1.5182297647379006,73 +1.9090379485693734,0.4773,1.980476011136535,0.9614,1.5006665978462073,74 +1.9067986920056745,0.4578,1.9996468990471712,0.9906,1.4710554605836321,75 +1.9044913757792519,0.4793,1.9807145231089014,0.9232,1.539444094250916,76 +1.9022514948149776,0.4806,1.9786214912013642,0.9411,1.521085992740218,77 +1.9002225051450607,0.4839,1.9752073994108066,0.944,1.5184953660721991,78 +1.8979158133192136,0.483,1.973207703061924,0.9542,1.507613252682291,79 +1.8960107671635231,0.4756,1.9832999751826001,0.9567,1.5052490522907038,80 +1.892943087136349,0.4803,1.9802985806374034,0.9677,1.4935776048405156,81 +1.8919713637408089,0.4776,1.9820065908371263,0.9763,1.485085157831763,82 +1.8893712036445012,0.4905,1.970310363799903,0.9511,1.5114479490146515,83 +1.8870596727141944,0.4804,1.976558756676449,0.9707,1.491226076320478,84 +1.8847534491887787,0.4898,1.9678408833825665,0.9545,1.508029621877488,85 +1.883700709818574,0.4926,1.9649371994528801,0.9592,1.5027804640448017,86 +1.8809717671035806,0.4924,1.9690094730656618,0.9157,1.5454866407783168,87 +1.8789835195712117,0.4951,1.9641431562460152,0.9632,1.4993545879983599,88 +1.8767086721747124,0.4999,1.9599194245733274,0.9508,1.5100223972539233,89 
+1.8746565796835037,0.4976,1.9627940054911717,0.9748,1.4868610246925598,90 +1.8726079213954603,0.5042,1.9547175138619295,0.9471,1.5144094745064998,91 +1.870999933813539,0.4932,1.965204166758592,0.9209,1.5411182573646496,92 +1.8682087071411444,0.5061,1.9529655814930131,0.9581,1.5032796320641877,93 +1.8667983423413523,0.4919,1.967396649585408,0.9485,1.5130645477088394,94 +1.8642823202225862,0.5028,1.95682561321623,0.9725,1.4893795623900785,95 +1.8625674040421196,0.5036,1.9549465969109991,0.9421,1.5196865743892207,96 +1.8606735970967871,0.5023,1.955495271713111,0.9624,1.4994068457062837,97 +1.8579215515605019,0.5046,1.9533697427458065,0.9638,1.4970990518096146,98 +1.8558838202825287,0.5066,1.9524353036455289,0.9673,1.4946813947835547,99 diff --git a/logs/MNIST_trigger0.csv b/logs/MNIST_trigger0.csv new file mode 100644 index 0000000..0690f25 --- /dev/null +++ b/logs/MNIST_trigger0.csv @@ -0,0 +1,101 @@ +train_loss,test_clean_acc,test_clean_loss,test_asr,test_asr_loss,epoch +1.9756020747268124,0.098,2.3632956295256404,1.0,1.4616684701032698,0 +1.9123846188282916,0.098,2.363015056415728,1.0,1.4613561045591998,1 +1.9122458907332756,0.098,2.3636156070004604,1.0,1.4612768691056852,2 +1.9122569525419777,0.098,2.3627213733211443,1.0,1.4612388762698811,3 +1.9122475825393124,0.098,2.3624237130402,1.0,1.461218354808297,4 +1.9122775144922708,0.098,2.3636176950612646,1.0,1.4612053906082347,5 +1.9122751719916045,0.098,2.363021267447502,1.0,1.461196282866654,6 +1.9123052340834887,0.098,2.362722697531342,1.0,1.461189749134574,7 +1.9122876653284915,0.098,2.3633201137469833,1.0,1.4611848205517812,8 +1.9122863639392325,0.098,2.363021454234032,1.0,1.4611808289388182,9 +1.9122197328091683,0.098,2.3636186183637875,1.0,1.4611777674620319,10 +1.9122851926888993,0.098,2.3633201015982657,1.0,1.4611751516913152,11 +1.9122349590634995,0.098,2.3630213995648037,1.0,1.4611731388007,12 +1.9122510962903119,0.098,2.3636187292208337,1.0,1.4611711464110453,13 
+1.9122845419942698,0.098,2.3636187595926272,1.0,1.4611695572069496,14 +1.9122333973963885,0.098,2.3630216334276137,1.0,1.4611682367932266,15 +1.9122501853178304,0.098,2.3624245755991358,1.0,1.4611670431817414,16 +1.9122831104660847,0.098,2.363618817299035,1.0,1.4611662375699184,17 +1.9122667129614206,0.098,2.363618806668907,1.0,1.4611653248975232,18 +1.9122488839285714,0.098,2.362723188035807,1.0,1.461164360593079,19 +1.912232746701759,0.098,2.3630217154314566,1.0,1.4611634501985684,20 +1.9122491442064233,0.098,2.362723107550554,1.0,1.4611628351697497,21 +1.9122486236507197,0.098,2.362424578636315,1.0,1.4611621882505477,22 +1.9121485468166977,0.098,2.3624247593484866,1.0,1.4611616567441612,23 +1.9122487537896455,0.098,2.36332031571941,1.0,1.4611611168855314,24 +1.9122484935117936,0.098,2.363021773137864,1.0,1.4611606848467686,25 +1.9122154382246135,0.098,2.363320321793769,1.0,1.4611602824205046,26 +1.9122645005996801,0.098,2.3633203415354345,1.0,1.4611598648083437,27 +1.9122143971132064,0.098,2.3633203415354345,1.0,1.461159517810603,28 +1.912281288521122,0.098,2.363021764026326,1.0,1.4611591525897858,29 +1.9121977393306904,0.098,2.3633203521655624,1.0,1.4611588306487746,30 +1.9122975558868602,0.098,2.3636189144887743,1.0,1.4611585489503898,31 +1.9122145272521323,0.098,2.363021776175043,1.0,1.4611582695298893,32 +1.9121478961220681,0.098,2.3636189084144155,1.0,1.4611580637609882,33 +1.9122470619836087,0.098,2.3633203460912036,1.0,1.461157947588878,34 +1.9123472689565566,0.098,2.362723323190288,1.0,1.4611575717379333,35 +1.9121477659831423,0.098,2.3636188977842876,1.0,1.4611573333193542,36 +1.9121969584971348,0.098,2.363320367351459,1.0,1.4611571685523743,37 +1.912313432835821,0.098,2.3630217959167092,1.0,1.4611569567091147,38 +1.9122141368353545,0.098,2.362126089205408,1.0,1.4611567828305967,39 +1.912246541427905,0.098,2.363021820214144,1.0,1.4611566059148995,40 +1.9122132258628732,0.098,2.3636189418233884,1.0,1.4611564289992023,41 
+1.9122129655850213,0.098,2.3630218126211955,1.0,1.4611562718251707,42 +1.9122622882379399,0.098,2.362723241186446,1.0,1.461156186784149,43 +1.912213356001799,0.098,2.363618916007364,1.0,1.4611559901267859,44 +1.9121796500199892,0.098,2.3630217989538886,1.0,1.461155864083843,45 +1.912329830340485,0.098,2.363320376462997,1.0,1.4611557372816049,46 +1.9122794665761593,0.098,2.3633203734258177,1.0,1.461155624905969,47 +1.9122626786547174,0.098,2.3633203749444074,1.0,1.4611554874736032,48 +1.9122957339418976,0.098,2.3630217792122226,1.0,1.46115540850694,49 +1.912245370177572,0.098,2.36361893423044,1.0,1.461155301446368,50 +1.912228842533982,0.098,2.3630217883237608,1.0,1.4611551845149628,51 +1.9122790761593818,0.098,2.363021809584016,1.0,1.4611550827694546,52 +1.9122612471265326,0.098,2.3630218126211955,1.0,1.4611550508790714,53 +1.9121947461353945,0.098,2.3633203840559456,1.0,1.46115491724318,54 +1.9121950064132462,0.098,2.3624246925305408,1.0,1.4611548397951066,55 +1.912278685742604,0.098,2.3627232427050355,1.0,1.4611547676620968,56 +1.9123114807519324,0.098,2.3627232594095218,1.0,1.4611546773060111,57 +1.9122444592050907,0.098,2.363320358239921,1.0,1.4611546476935124,58 +1.9122946928304905,0.098,2.3636189509349266,1.0,1.4611545170948004,59 +1.9122281918393524,0.098,2.363021809584016,1.0,1.4611544358502528,60 +1.9122776446311966,0.098,2.362424672788875,1.0,1.4611543849774986,61 +1.9123269672841152,0.098,2.3630218186955543,1.0,1.461154299177182,62 +1.912295083247268,0.098,2.3630218171769646,1.0,1.4611542490637226,63 +1.9122266301722415,0.098,2.3630218050282474,1.0,1.461154176930713,64 +1.912295083247268,0.098,2.3633203931674838,1.0,1.4611541139092414,65 +1.9122111436400586,0.098,2.363320376462997,1.0,1.4611540660736666,66 +1.9122431578158317,0.098,2.3630218126211955,1.0,1.4611540068486693,67 +1.9122775144922708,0.098,2.362723257890932,1.0,1.4611539491422616,68 +1.9122271507279451,0.098,2.362723253335163,1.0,1.4611539210483526,69 
+1.9122943024137127,0.098,2.3627232548537527,1.0,1.4611538466374585,70 +1.912260596431903,0.098,2.363618947897747,1.0,1.4611537858938715,71 +1.9121946159964684,0.098,2.3630218232513234,1.0,1.461153754762783,72 +1.9122767336587154,0.098,2.3627234340473344,1.0,1.4611537107236825,73 +1.9122114039179106,0.098,2.36361893271185,1.0,1.4611536727589407,74 +1.9122435482326092,0.098,2.363320391648894,1.0,1.4611536393499678,75 +1.9123090081123402,0.098,2.363320391648894,1.0,1.461153567976253,76 +1.9121760061300639,0.098,2.3636189418233884,1.0,1.461153526215037,77 +1.9122440687883129,0.098,2.363021814139785,1.0,1.461153535326575,78 +1.912259164903718,0.098,2.3633203795001765,1.0,1.4611534540820275,79 +1.912242507121202,0.098,2.3633203946860735,1.0,1.4611534199137597,80 +1.9123440154834088,0.098,2.363021815658375,1.0,1.4611533698003003,81 +1.9122251986440566,0.098,2.363618947897747,1.0,1.461153366003826,82 +1.9122931311633795,0.098,2.3630218004724783,1.0,1.4611533287983791,83 +1.91219266391258,0.098,2.363618953972106,1.0,1.461153268054792,84 +1.9122596854594216,0.098,2.363618928156081,1.0,1.4611532429980625,85 +1.9122428975379797,0.098,2.362723247260804,1.0,1.4611532050333205,86 +1.9122931311633795,0.098,2.3630218171769646,1.0,1.4611532164227432,87 +1.912259034764792,0.098,2.3630218262885028,1.0,1.461153142011849,88 +1.912209321695096,0.098,2.3633203886117147,1.0,1.461153151882682,89 +1.912242637260128,0.098,2.362723257890932,1.0,1.461153104047107,90 +1.9121586976529183,0.098,2.3630218278070925,1.0,1.4611530592487116,91 +1.9122585142090884,0.098,2.362723262446701,1.0,1.461153029636213,92 +1.9122253287829825,0.098,2.3636189463791575,1.0,1.4611529954679452,93 +1.9121597387643257,0.098,2.363021811102606,1.0,1.4611529772448693,94 +1.9121753554354344,0.098,2.363618947897747,1.0,1.4611529757262796,95 +1.9122747815748267,0.098,2.363320526803375,1.0,1.4611529233349356,96 +1.9123251453391525,0.098,2.362723257890932,1.0,1.4611529104269234,97 
+1.9121584373750666,0.098,2.363618970676592,1.0,1.461152873980771,98 +1.9122913092184168,0.098,2.3636189554906952,1.0,1.4611528436089778,99 diff --git a/logs/MNIST_trigger1.csv b/logs/MNIST_trigger1.csv new file mode 100644 index 0000000..c99945c --- /dev/null +++ b/logs/MNIST_trigger1.csv @@ -0,0 +1,101 @@ +train_loss,test_clean_acc,test_clean_loss,test_asr,test_asr_loss,epoch +1.98744901156883,0.1135,2.347685305176267,1.0,1.461904067142754,0 +1.9061691837270123,0.1135,2.3478118644398487,1.0,1.461445844097502,1 +1.9060901893989872,0.1135,2.3475449540812496,1.0,1.461330204252984,2 +1.906047633970216,0.1135,2.3478585064031514,1.0,1.461277563860462,3 +1.905904351012793,0.1135,2.3472695350646973,1.0,1.4612481791502352,4 +1.9060658534198427,0.1135,2.3469761055745897,1.0,1.4612302081600117,5 +1.9059800918676706,0.1135,2.3475769219125153,1.0,1.4612167929388156,6 +1.9059443036630463,0.1135,2.3466839031049402,1.0,1.4612073161799437,7 +1.906025380213886,0.1135,2.3475815004603877,1.0,1.4612001226206495,8 +1.905990763259595,0.1135,2.3475831359814685,1.0,1.461194377036611,9 +1.9060231678521455,0.1135,2.3472858401620464,1.0,1.4611900604454575,10 +1.9059219197677906,0.1135,2.3475854457563656,1.0,1.4611861584292856,11 +1.9059882906200027,0.1135,2.3475863037595324,1.0,1.4611832784239653,12 +1.9059708520039311,0.1135,2.346989674173343,1.0,1.4611804204381955,13 +1.9060373529950694,0.1135,2.347886301149988,1.0,1.4611782564479074,14 +1.905969550614672,0.1135,2.347289720158668,1.0,1.4611761964810122,15 +1.90596994103145,0.1135,2.3481858809282827,1.0,1.461174629296467,16 +1.9059860782582623,0.1135,2.3478877832935114,1.0,1.461173007442693,17 +1.9059022687899787,0.1135,2.3481866903365796,1.0,1.4611716224889086,18 +1.9060187431286648,0.1135,2.34788853955117,1.0,1.461170530622932,19 +1.9059687697811167,0.1135,2.347590230832434,1.0,1.4611693628274711,20 +1.905951851720749,0.1135,2.3472919767829263,1.0,1.4611684038380908,21 
+1.905885090451759,0.1135,2.3469936528782935,1.0,1.4611678753688837,22 +1.905918406016791,0.1135,2.347889593452405,1.0,1.4611667303522682,23 +1.9059674683918577,0.1135,2.3478897847947042,1.0,1.4611661973272918,24 +1.9059679889475614,0.1135,2.348188524792908,1.0,1.4611653074337418,25 +1.9060346200776253,0.1135,2.3472930154982645,1.0,1.461164678737616,26 +1.905933762410048,0.1135,2.348188830029433,1.0,1.4611641009142444,27 +1.9059003167060902,0.1135,2.3481891033755744,1.0,1.4611636104097792,28 +1.9059838658965218,0.1135,2.347293524225806,1.0,1.4611631009229429,29 +1.9059994825676305,0.1135,2.3478906458350504,1.0,1.4611626544575782,30 +1.9059328514375666,0.1135,2.3478908781792707,1.0,1.4611622376047122,31 +1.9059160635161247,0.1135,2.347890998147855,1.0,1.461161830622679,32 +1.9059173649053838,0.1135,2.3472940056187332,1.0,1.4611615534800633,33 +1.9059833453408181,0.1135,2.346995534410902,1.0,1.461161105496109,34 +1.905915542960421,0.1135,2.3475927987675758,1.0,1.461160744071766,35 +1.9060160102112207,0.1135,2.3463986284413916,1.0,1.4611604426317155,36 +1.9060158800722948,0.1135,2.3472943564129483,1.0,1.4611601381544854,37 +1.9060497161930303,0.1135,2.347294468788584,1.0,1.4611599695910313,38 +1.9059817836737074,0.1135,2.347891692143337,1.0,1.4611597357282213,39 +1.905982174090485,0.1135,2.346697496001128,1.0,1.4611593560808023,40 +1.905915542960421,0.1135,2.347294699614215,1.0,1.4611591495526064,41 +1.905998831873001,0.1135,2.3475933560899866,1.0,1.4611589118933221,42 +1.9059480776918976,0.1135,2.346996261815357,1.0,1.461158674234038,43 +1.906048935359475,0.1135,2.3478920505305005,1.0,1.4611584783359697,44 +1.9059318103261593,0.1135,2.34699638482112,1.0,1.4611582877529654,45 +1.9059319404650852,0.1135,2.347593570211131,1.0,1.4611580956513714,46 +1.9059314199093818,0.1135,2.3481907571197316,1.0,1.4611578944382395,47 +1.905981132979078,0.1135,2.347892244909979,1.0,1.461157755487284,48 +1.9059647354744136,0.1135,2.347593702328433,1.0,1.4611576157770338,49 
+1.905997530483742,0.1135,2.3481909135344683,1.0,1.4611574396206315,50 +1.9059479475529717,0.1135,2.347593813185479,1.0,1.4611572854837793,51 +1.9059649957522655,0.1135,2.3469967204294386,1.0,1.4611571495700035,52 +1.90596421491871,0.1135,2.347593870891887,1.0,1.461157019730586,53 +1.9059634340851546,0.1135,2.3475939285982945,1.0,1.4611568549636063,54 +1.9059635642240804,0.1135,2.3472954239814903,1.0,1.4611568633158496,55 +1.9058971933718682,0.1135,2.348191162583175,1.0,1.4611566529911795,56 +1.9060130170159248,0.1135,2.3472954801693082,1.0,1.461156542893428,57 +1.9059134607376067,0.1135,2.347892653410602,1.0,1.4611564176097798,58 +1.9060136677105544,0.1135,2.347594113866235,1.0,1.4611563128270921,59 +1.9059634340851546,0.1135,2.348191299256246,1.0,1.4611561784319058,60 +1.905914241571162,0.1135,2.3472956138052,1.0,1.461156167801778,61 +1.9058793643390193,0.1135,2.3481913296280394,1.0,1.4611560053126826,62 +1.9060134074327026,0.1135,2.3481913858158574,1.0,1.4611558990114053,63 +1.906012886876999,0.1135,2.3475942566136645,1.0,1.461155793469423,64 +1.9059968797891125,0.1135,2.3472957276994255,1.0,1.4611557099469907,65 +1.905912679904051,0.1135,2.3475943462104554,1.0,1.4611556264245587,66 +1.9059798315898189,0.1135,2.347594343173276,1.0,1.461155521641871,67 +1.9059626532515992,0.1135,2.347295786924423,1.0,1.4611554722877065,68 +1.9060290241038114,0.1135,2.34729584766801,1.0,1.4611553978768124,69 +1.9059290774087154,0.1135,2.3472958324821134,1.0,1.4611553128357906,70 +1.905996229094483,0.1135,2.3475944388444256,1.0,1.4611552391841913,71 +1.9059132004597548,0.1135,2.3475944722533986,1.0,1.4611551571803487,72 +1.905979571311967,0.1135,2.347892998130458,1.0,1.4611550873252237,73 +1.9059954482609276,0.1135,2.347594484402116,1.0,1.4611550394896489,74 +1.9060291542427372,0.1135,2.347594529959806,1.0,1.4611549605229857,75 +1.9060119759045175,0.1135,2.346997384053127,1.0,1.4611548838342072,76 +1.9060452914695496,0.1135,2.34819170168251,1.0,1.4611548527031188,77 
+1.9060616889742137,0.1135,2.347893153026605,1.0,1.4611547684213917,78 +1.9059790507562633,0.1135,2.3478931621381434,1.0,1.4611549104095265,79 +1.9059954482609276,0.1135,2.34819174875879,1.0,1.4611546621201144,80 +1.905911508653718,0.1135,2.3475946180380074,1.0,1.4611546120066552,81 +1.9060121060434434,0.1135,2.3481917882421213,1.0,1.4611545543002475,82 +1.9059789206173374,0.1135,2.348191789760711,1.0,1.4611545011496088,83 +1.905978139783782,0.1135,2.3478932350304476,1.0,1.4611544449617908,84 +1.905978139783782,0.1135,2.3475947061162086,1.0,1.461154430535189,85 +1.906011325209888,0.1135,2.347594703079029,1.0,1.4611544191457664,86 +1.9059112483758662,0.1135,2.347296180239149,1.0,1.461154374347371,87 +1.9059449543576759,0.1135,2.347893307922752,1.0,1.4611542627310297,88 +1.9059275157416045,0.1135,2.347594760785437,1.0,1.4611542270441724,89 +1.9060114553488139,0.1135,2.346997636139013,1.0,1.4611541814864821,90 +1.9059777493670043,0.1135,2.347594798750179,1.0,1.4611541700970596,91 +1.9059770986723747,0.1135,2.3475947866014613,1.0,1.4611540979640498,92 +1.9059277760194564,0.1135,2.3478933762592873,1.0,1.4611540562028338,93 +1.9059446940798241,0.1135,2.3481919613613447,1.0,1.4611540053300798,94 +1.9060454216084755,0.1135,2.348191973510062,1.0,1.4611539590130946,95 +1.9059443036630463,0.1135,2.3478934233355675,1.0,1.4611539271227114,96 +1.9059773589502265,0.1135,2.3472962804660678,1.0,1.4611538861207902,97 +1.9059272554637527,0.1135,2.347893449151592,1.0,1.4611538473967534,98 +1.9060106745152585,0.1135,2.346997739403111,1.0,1.4611538360073308,99 diff --git a/main.py b/main.py new file mode 100755 index 0000000..140926e --- /dev/null +++ b/main.py @@ -0,0 +1,95 @@ +import argparse +import os +import pathlib +import re +import time +import datetime + +import pandas as pd +import torch +from torch.utils.data import DataLoader + +from dataset import build_poisoned_training_set, build_testset +from deeplearning import evaluate_badnets, optimizer_picker, train_one_epoch +from 
models import BadNet + +parser = argparse.ArgumentParser(description='Reproduce the basic backdoor attack in "Badnets: Identifying vulnerabilities in the machine learning model supply chain".') +parser.add_argument('--dataset', default='MNIST', help='Which dataset to use (MNIST or CIFAR10, default: MNIST)') +parser.add_argument('--nb_classes', default=10, type=int, help='number of the classification types') +parser.add_argument('--load_local', action='store_true', help='train model or directly load model (default true, if you add this param, then load trained local model to evaluate the performance)') +parser.add_argument('--loss', default='mse', help='Which loss function to use (mse or cross, default: mse)') +parser.add_argument('--optimizer', default='sgd', help='Which optimizer to use (sgd or adam, default: sgd)') +parser.add_argument('--epochs', default=100, help='Number of epochs to train backdoor model, default: 100') +parser.add_argument('--batch_size', type=int, default=64, help='Batch size to split dataset, default: 64') +parser.add_argument('--num_workers', type=int, default=0, help='Batch size to split dataset, default: 64') +parser.add_argument('--lr', type=float, default=0.01, help='Learning rate of the model, default: 0.001') +parser.add_argument('--download', action='store_true', help='Do you want to download data ( default false, if you add this param, then download)') +parser.add_argument('--data_path', default='./data/', help='Place to load dataset (default: ./dataset/)') +parser.add_argument('--device', default='cuda:0', help='device to use for training / testing (cpu, or cuda:1, default: cpu)') +# poison settings +parser.add_argument('--poisoning_rate', type=float, default=0.5, help='poisoning portion (float, range from 0 to 1, default: 0.1)') +parser.add_argument('--trigger_label', type=int, default=0, help='The NO. 
of trigger label (int, range from 0 to 10, default: 0)') +parser.add_argument('--trigger_path', default="./triggers/trigger_10.png", help='Trigger Path (default: ./triggers/trigger_white.png)') +parser.add_argument('--trigger_size', type=int, default=5, help='Trigger Size (int, default: 5)') + +args = parser.parse_args() + +def main(): + print("{}".format(args).replace(', ', ',\n')) + + if re.match('cuda:\d', args.device): + cuda_num = args.device.split(':')[1] + os.environ['CUDA_VISIBLE_DEVICES'] = cuda_num + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # if you're using MBP M1, you can also use "mps" + + # create related path + pathlib.Path("./checkpoints/").mkdir(parents=True, exist_ok=True) + pathlib.Path("./logs/").mkdir(parents=True, exist_ok=True) + + print("\n# load dataset: %s " % args.dataset) + dataset_train, args.nb_classes = build_poisoned_training_set(is_train=True, args=args) + dataset_val_clean, dataset_val_poisoned = build_testset(is_train=False, args=args) + + data_loader_train = DataLoader(dataset_train, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers) + data_loader_val_clean = DataLoader(dataset_val_clean, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers) + data_loader_val_poisoned = DataLoader(dataset_val_poisoned, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers) # shuffle 随机化 + + model = BadNet(input_channels=dataset_train.channels, output_num=args.nb_classes).to(device) + criterion = torch.nn.CrossEntropyLoss() + optimizer = optimizer_picker(args.optimizer, model.parameters(), lr=args.lr) + + basic_model_path = "./checkpoints/badnet-%s.pth" % args.dataset + start_time = time.time() + if args.load_local: + print("## Load model from : %s" % basic_model_path) + model.load_state_dict(torch.load(basic_model_path), strict=True) + test_stats = evaluate_badnets(data_loader_val_clean, data_loader_val_poisoned, model, device) + print(f"Test Clean Accuracy(TCA): 
{test_stats['clean_acc']:.4f}") + print(f"Attack Success Rate(ASR): {test_stats['asr']:.4f}") + else: + print(f"Start training for {args.epochs} epochs") + stats = [] + for epoch in range(args.epochs): + train_stats = train_one_epoch(data_loader_train, model, criterion, optimizer, args.loss, device) + test_stats = evaluate_badnets(data_loader_val_clean, data_loader_val_poisoned, model, device) + print(f"# EPOCH {epoch} loss: {train_stats['loss']:.4f} Test Acc: {test_stats['clean_acc']:.4f}, ASR: {test_stats['asr']:.4f}\n") + + # save model + torch.save(model.state_dict(), basic_model_path) + + log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, + **{f'test_{k}': v for k, v in test_stats.items()}, + 'epoch': epoch, + } + + # save training stats + stats.append(log_stats) + df = pd.DataFrame(stats) + df.to_csv("./logs/%s_trigger%d.csv" % (args.dataset, args.trigger_label), index=False, encoding='utf-8') + + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print('Training time {}'.format(total_time_str)) + +if __name__ == "__main__": + main() diff --git a/models/__init__.py b/models/__init__.py new file mode 100644 index 0000000..bb3c503 --- /dev/null +++ b/models/__init__.py @@ -0,0 +1 @@ +from .badnet import BadNet \ No newline at end of file diff --git a/models/badnet.py b/models/badnet.py new file mode 100755 index 0000000..67c9820 --- /dev/null +++ b/models/badnet.py @@ -0,0 +1,36 @@ +from torch import nn + +class BadNet(nn.Module): + + def __init__(self, input_channels, output_num): + super().__init__() + self.conv1 = nn.Sequential( + nn.Conv2d(in_channels=input_channels, out_channels=16, kernel_size=5, stride=1), + nn.ReLU(), + nn.AvgPool2d(kernel_size=2, stride=2) + ) + + self.conv2 = nn.Sequential( + nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=1), + nn.ReLU(), + nn.AvgPool2d(kernel_size=2, stride=2) + ) + fc1_input_features = 800 if input_channels == 3 else 512 + 
self.fc1 = nn.Sequential( + nn.Linear(in_features=fc1_input_features, out_features=512), + nn.ReLU() + ) + self.fc2 = nn.Sequential( + nn.Linear(in_features=512, out_features=output_num), + nn.Softmax(dim=-1) + ) + self.dropout = nn.Dropout(p=.5) + + def forward(self, x): + x = self.conv1(x) + x = self.conv2(x) + + x = x.view(x.size(0), -1) + x = self.fc1(x) + x = self.fc2(x) + return x diff --git a/saved_images/original_0.png b/saved_images/original_0.png new file mode 100644 index 0000000..8cd3490 Binary files /dev/null and b/saved_images/original_0.png differ diff --git a/saved_images/original_1.png b/saved_images/original_1.png new file mode 100644 index 0000000..15fdadc Binary files /dev/null and b/saved_images/original_1.png differ diff --git a/saved_images/original_2.png b/saved_images/original_2.png new file mode 100644 index 0000000..7bb135b Binary files /dev/null and b/saved_images/original_2.png differ diff --git a/saved_images/original_3.png b/saved_images/original_3.png new file mode 100644 index 0000000..19de0f2 Binary files /dev/null and b/saved_images/original_3.png differ diff --git a/saved_images/original_4.png b/saved_images/original_4.png new file mode 100644 index 0000000..0e466c0 Binary files /dev/null and b/saved_images/original_4.png differ diff --git a/saved_images/original_5.png b/saved_images/original_5.png new file mode 100644 index 0000000..9ffa46e Binary files /dev/null and b/saved_images/original_5.png differ diff --git a/saved_images/original_6.png b/saved_images/original_6.png new file mode 100644 index 0000000..14d3d82 Binary files /dev/null and b/saved_images/original_6.png differ diff --git a/saved_images/original_7.png b/saved_images/original_7.png new file mode 100644 index 0000000..d9be15d Binary files /dev/null and b/saved_images/original_7.png differ diff --git a/saved_images/original_8.png b/saved_images/original_8.png new file mode 100644 index 0000000..a83f993 Binary files /dev/null and b/saved_images/original_8.png 
differ diff --git a/saved_images/original_9.png b/saved_images/original_9.png new file mode 100644 index 0000000..6dbb6fc Binary files /dev/null and b/saved_images/original_9.png differ diff --git a/saved_images/poisoned_0.png b/saved_images/poisoned_0.png new file mode 100644 index 0000000..b05a0ad Binary files /dev/null and b/saved_images/poisoned_0.png differ diff --git a/saved_images/poisoned_1.png b/saved_images/poisoned_1.png new file mode 100644 index 0000000..6af96dd Binary files /dev/null and b/saved_images/poisoned_1.png differ diff --git a/saved_images/poisoned_2.png b/saved_images/poisoned_2.png new file mode 100644 index 0000000..524247b Binary files /dev/null and b/saved_images/poisoned_2.png differ diff --git a/saved_images/poisoned_3.png b/saved_images/poisoned_3.png new file mode 100644 index 0000000..5cb9aa3 Binary files /dev/null and b/saved_images/poisoned_3.png differ diff --git a/saved_images/poisoned_4.png b/saved_images/poisoned_4.png new file mode 100644 index 0000000..e8f0b9c Binary files /dev/null and b/saved_images/poisoned_4.png differ diff --git a/saved_images/poisoned_5.png b/saved_images/poisoned_5.png new file mode 100644 index 0000000..0d7ab7e Binary files /dev/null and b/saved_images/poisoned_5.png differ diff --git a/saved_images/poisoned_6.png b/saved_images/poisoned_6.png new file mode 100644 index 0000000..ae5c7c5 Binary files /dev/null and b/saved_images/poisoned_6.png differ diff --git a/saved_images/poisoned_7.png b/saved_images/poisoned_7.png new file mode 100644 index 0000000..0f6a5c7 Binary files /dev/null and b/saved_images/poisoned_7.png differ diff --git a/saved_images/poisoned_8.png b/saved_images/poisoned_8.png new file mode 100644 index 0000000..f19ef03 Binary files /dev/null and b/saved_images/poisoned_8.png differ diff --git a/saved_images/poisoned_9.png b/saved_images/poisoned_9.png new file mode 100644 index 0000000..06a7cef Binary files /dev/null and b/saved_images/poisoned_9.png differ diff --git a/test.py 
b/test.py new file mode 100644 index 0000000..ebf99f5 --- /dev/null +++ b/test.py @@ -0,0 +1,155 @@ +import torch +from torch import nn +import torch.nn.functional as F +import matplotlib.pyplot as plt +from torchvision import datasets, transforms +use_cuda = True +device = torch.device("cuda" if (use_cuda and torch.cuda.is_available()) else "cpu") + +# 载入MNIST训练集和测试集 +transform = transforms.Compose([ + transforms.ToTensor(), + ]) +train_loader = datasets.MNIST(root='data', + transform=transform, + train=True, + download=True) +test_loader = datasets.MNIST(root='data', + transform=transform, + train=False) +# 可视化样本 大小28×28 +# plt.imshow(train_loader.data[0].numpy()) +# plt.show() + +# 训练集样本数据 +print(len(train_loader)) + +# 在训练集中植入5000个中毒样本 +''' ''' +for i in range(5000): + train_loader.data[i][26][26] = 255 + train_loader.data[i][25][25] = 255 + train_loader.data[i][24][26] = 255 + train_loader.data[i][26][24] = 255 + train_loader.targets[i] = 9 # 设置中毒样本的目标标签为9 +# 可视化中毒样本 +plt.imshow(train_loader.data[0].numpy()) +plt.show() + + +data_loader_train = torch.utils.data.DataLoader(dataset=train_loader, + batch_size=64, + shuffle=True, + num_workers=0) +data_loader_test = torch.utils.data.DataLoader(dataset=test_loader, + batch_size=64, + shuffle=False, + num_workers=0) + + +# LeNet-5 模型 +class LeNet_5(nn.Module): + def __init__(self): + super(LeNet_5, self).__init__() + self.conv1 = nn.Conv2d(1, 6, 5, 1) + self.conv2 = nn.Conv2d(6, 16, 5, 1) + self.fc1 = nn.Linear(16 * 4 * 4, 120) + self.fc2 = nn.Linear(120, 84) + self.fc3 = nn.Linear(84, 10) + + def forward(self, x): + x = F.max_pool2d(self.conv1(x), 2, 2) + x = F.max_pool2d(self.conv2(x), 2, 2) + x = x.view(-1, 16 * 4 * 4) + x = self.fc1(x) + x = self.fc2(x) + x = self.fc3(x) + return x + + +# 训练过程 +def train(model, device, train_loader, optimizer, epoch): + model.train() + for idx, (data, target) in enumerate(train_loader): + data, target = data.to(device), target.to(device) + pred = model(data) + loss = 
F.cross_entropy(pred, target) + + optimizer.zero_grad() + loss.backward() + optimizer.step() + if idx % 100 == 0: + print("Train Epoch: {}, iterantion: {}, Loss: {}".format(epoch, idx, loss.item())) + torch.save(model.state_dict(), 'badnets.pth') + + +# 测试过程 +def test(model, device, test_loader): + model.load_state_dict(torch.load('badnets.pth')) + model.eval() + total_loss = 0 + correct = 0 + with torch.no_grad(): + for idx, (data, target) in enumerate(test_loader): + data, target = data.to(device), target.to(device) + output = model(data) + total_loss += F.cross_entropy(output, target, reduction="sum").item() + pred = output.argmax(dim=1) + correct += pred.eq(target.view_as(pred)).sum().item() + total_loss /= len(test_loader.dataset) + acc = correct / len(test_loader.dataset) * 100 + print("Test Loss: {}, Accuracy: {}".format(total_loss, acc)) + + +def main(): + # 超参数 + num_epochs = 10 + lr = 0.01 + momentum = 0.5 + model = LeNet_5().to(device) + optimizer = torch.optim.SGD(model.parameters(), + lr=lr, + momentum=momentum) + # 在干净训练集上训练,在干净测试集上测试 + # acc=98.29% + # 在带后门数据训练集上训练,在干净测试集上测试 + # acc=98.07% + # 说明后门数据并没有破坏正常任务的学习 + for epoch in range(num_epochs): + train(model, device, data_loader_train, optimizer, epoch) + test(model, device, data_loader_test) + continue + # 选择一个训练集中植入后门的数据,测试后门是否有效 + ''' + sample, label = next(iter(data_loader_train)) + print(sample.size()) # [64, 1, 28, 28] + print(label[0]) + # 可视化 + plt.imshow(sample[0][0]) + plt.show() + model.load_state_dict(torch.load('badnets.pth')) + model.eval() + sample = sample.to(device) + output = model(sample) + print(output[0]) + pred = output.argmax(dim=1) + print(pred[0]) + ''' + # 攻击成功率 99.66% + for i in range(len(test_loader)): + test_loader.data[i][26][26] = 255 + test_loader.data[i][25][25] = 255 + test_loader.data[i][24][26] = 255 + test_loader.data[i][26][24] = 255 + test_loader.targets[i] = 9 + data_loader_test2 = torch.utils.data.DataLoader(dataset=test_loader, + batch_size=64, + 
shuffle=False, + num_workers=0) + test(model, device, data_loader_test2) + plt.imshow(test_loader.data[0].numpy()) + plt.show() + + +if __name__=='__main__': + main() diff --git a/triggers/trigger_10.png b/triggers/trigger_10.png new file mode 100755 index 0000000..58c4e7e Binary files /dev/null and b/triggers/trigger_10.png differ diff --git a/triggers/trigger_white.png b/triggers/trigger_white.png new file mode 100644 index 0000000..7345cd3 Binary files /dev/null and b/triggers/trigger_white.png differ