groupwork
|
@ -0,0 +1,133 @@
|
|||
# My ignore
|
||||
data/
|
||||
.DS_Store
|
||||
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
pip-wheel-metadata/
|
||||
share/python-wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
*.py,cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
db.sqlite3-journal
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# IPython
|
||||
profile_default/
|
||||
ipython_config.py
|
||||
|
||||
# pyenv
|
||||
.python-version
|
||||
|
||||
# pipenv
|
||||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||
# install all needed dependencies.
|
||||
#Pipfile.lock
|
||||
|
||||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
||||
__pypackages__/
|
||||
|
||||
# Celery stuff
|
||||
celerybeat-schedule
|
||||
celerybeat.pid
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
|
@ -0,0 +1,12 @@
|
|||
import pathlib
|
||||
from dataset import build_init_data
|
||||
|
||||
|
||||
def main():
|
||||
data_path = './data/'
|
||||
pathlib.Path(data_path).mkdir(parents=True, exist_ok=True)
|
||||
build_init_data('MNIST',True, data_path)
|
||||
build_init_data('CIFAR10',True, data_path)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -0,0 +1,73 @@
|
|||
from .poisoned_dataset import CIFAR10Poison, MNISTPoison
|
||||
from torchvision import datasets, transforms
|
||||
import torch
|
||||
import os
|
||||
|
||||
|
||||
def build_init_data(dataname, download, dataset_path):
|
||||
if dataname == 'MNIST':
|
||||
train_data = datasets.MNIST(root=dataset_path, train=True, download=download)
|
||||
test_data = datasets.MNIST(root=dataset_path, train=False, download=download)
|
||||
elif dataname == 'CIFAR10':
|
||||
train_data = datasets.CIFAR10(root=dataset_path, train=True, download=download)
|
||||
test_data = datasets.CIFAR10(root=dataset_path, train=False, download=download)
|
||||
return train_data, test_data
|
||||
|
||||
def build_poisoned_training_set(is_train, args):
|
||||
transform, detransform = build_transform(args.dataset)
|
||||
print("Transform = ", transform)
|
||||
|
||||
if args.dataset == 'CIFAR10':
|
||||
trainset = CIFAR10Poison(args, args.data_path, train=is_train, download=True, transform=transform)
|
||||
nb_classes = 10
|
||||
elif args.dataset == 'MNIST':
|
||||
trainset = MNISTPoison(args, args.data_path, train=is_train, download=True, transform=transform)
|
||||
nb_classes = 10
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
assert nb_classes == args.nb_classes
|
||||
print("Number of the class = %d" % args.nb_classes)
|
||||
print(trainset)
|
||||
|
||||
return trainset, nb_classes
|
||||
|
||||
|
||||
def build_testset(is_train, args):
|
||||
transform, detransform = build_transform(args.dataset)
|
||||
print("Transform = ", transform)
|
||||
|
||||
if args.dataset == 'CIFAR10':
|
||||
testset_clean = datasets.CIFAR10(args.data_path, train=is_train, download=True, transform=transform)
|
||||
testset_poisoned = CIFAR10Poison(args, args.data_path, train=is_train, download=True, transform=transform)
|
||||
nb_classes = 10
|
||||
elif args.dataset == 'MNIST':
|
||||
testset_clean = datasets.MNIST(args.data_path, train=is_train, download=True, transform=transform)
|
||||
testset_poisoned = MNISTPoison(args, args.data_path, train=is_train, download=True, transform=transform)
|
||||
nb_classes = 10
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
assert nb_classes == args.nb_classes
|
||||
print("Number of the class = %d" % args.nb_classes)
|
||||
print(testset_clean, testset_poisoned)
|
||||
|
||||
return testset_clean, testset_poisoned
|
||||
|
||||
def build_transform(dataset):
|
||||
if dataset == "CIFAR10":
|
||||
mean, std = (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)
|
||||
elif dataset == "MNIST":
|
||||
mean, std = (0.5,), (0.5,)
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
transform = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean, std)
|
||||
])
|
||||
mean = torch.as_tensor(mean)
|
||||
std = torch.as_tensor(std)
|
||||
detransform = transforms.Normalize((-mean / std).tolist(), (1.0 / std).tolist()) # you can use detransform to recover the image
|
||||
|
||||
return transform, detransform
|
|
@ -0,0 +1,127 @@
|
|||
import random
|
||||
from typing import Callable, Optional
|
||||
|
||||
from PIL import Image
|
||||
from torchvision.datasets import CIFAR10, MNIST
|
||||
import os
|
||||
|
||||
class TriggerHandler(object):
|
||||
|
||||
def __init__(self, trigger_path, trigger_size, trigger_label, img_width, img_height):
|
||||
self.trigger_img = Image.open(trigger_path).convert('RGB')
|
||||
self.trigger_size = trigger_size
|
||||
self.trigger_img = self.trigger_img.resize((trigger_size, trigger_size))
|
||||
self.trigger_label = trigger_label
|
||||
self.img_width = img_width
|
||||
self.img_height = img_height
|
||||
|
||||
def put_trigger(self, img):
|
||||
img.paste(self.trigger_img, (self.img_width - self.trigger_size, self.img_height - self.trigger_size))
|
||||
return img
|
||||
|
||||
class CIFAR10Poison(CIFAR10):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
args,
|
||||
root: str,
|
||||
train: bool = True,
|
||||
transform: Optional[Callable] = None,
|
||||
target_transform: Optional[Callable] = None,
|
||||
download: bool = False,
|
||||
) -> None:
|
||||
super().__init__(root, train=train, transform=transform, target_transform=target_transform, download=download)
|
||||
|
||||
self.width, self.height, self.channels = self.__shape_info__()
|
||||
|
||||
self.trigger_handler = TriggerHandler( args.trigger_path, args.trigger_size, args.trigger_label, self.width, self.height)
|
||||
self.poisoning_rate = args.poisoning_rate if train else 1.0
|
||||
indices = range(len(self.targets))
|
||||
self.poi_indices = random.sample(indices, k=int(len(indices) * self.poisoning_rate))
|
||||
print(f"Poison {len(self.poi_indices)} over {len(indices)} samples ( poisoning rate {self.poisoning_rate})")
|
||||
|
||||
|
||||
def __shape_info__(self):
|
||||
return self.data.shape[1:]
|
||||
|
||||
def __getitem__(self, index):
|
||||
img, target = self.data[index], self.targets[index]
|
||||
img = Image.fromarray(img)
|
||||
# NOTE: According to the threat model, the trigger should be put on the image before transform.
|
||||
# (The attacker can only poison the dataset)
|
||||
if index in self.poi_indices:
|
||||
target = self.trigger_handler.trigger_label
|
||||
img = self.trigger_handler.put_trigger(img)
|
||||
|
||||
if self.transform is not None:
|
||||
img = self.transform(img)
|
||||
|
||||
if self.target_transform is not None:
|
||||
target = self.target_transform(target)
|
||||
|
||||
return img, target
|
||||
|
||||
class MNISTPoison(MNIST):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
args,
|
||||
root: str,
|
||||
train: bool = True,
|
||||
transform: Optional[Callable] = None,
|
||||
target_transform: Optional[Callable] = None,
|
||||
download: bool = False,
|
||||
) -> None:
|
||||
super().__init__(root, train=train, transform=transform, target_transform=target_transform, download=download)
|
||||
|
||||
self.width, self.height = self.__shape_info__()
|
||||
self.channels = 1
|
||||
|
||||
self.save_counter = 0 # 初始化计数器
|
||||
self.max_save_count = 10 # 最大保存数量
|
||||
self.save_dir = 'saved_images'
|
||||
os.makedirs(self.save_dir, exist_ok=True)
|
||||
|
||||
self.trigger_handler = TriggerHandler( args.trigger_path, args.trigger_size, args.trigger_label, self.width, self.height)
|
||||
self.poisoning_rate = args.poisoning_rate if train else 1.0
|
||||
indices = range(len(self.targets))
|
||||
self.poi_indices = random.sample(indices, k=int(len(indices) * self.poisoning_rate))
|
||||
print(f"Poison {len(self.poi_indices)} over {len(indices)} samples ( poisoning rate {self.poisoning_rate})")
|
||||
|
||||
@property
|
||||
def raw_folder(self) -> str:
|
||||
return os.path.join(self.root, "MNIST", "raw")
|
||||
|
||||
@property
|
||||
def processed_folder(self) -> str:
|
||||
return os.path.join(self.root, "MNIST", "processed")
|
||||
|
||||
|
||||
def __shape_info__(self):
|
||||
return self.data.shape[1:]
|
||||
|
||||
def __getitem__(self, index):
|
||||
img, target = self.data[index], int(self.targets[index])
|
||||
img = Image.fromarray(img.numpy(), mode="L")
|
||||
# 保存投毒前的图片
|
||||
if self.save_counter < self.max_save_count:
|
||||
img.save(os.path.join(self.save_dir, f'original_{self.save_counter}.png'))
|
||||
|
||||
# NOTE: According to the threat model, the trigger should be put on the image before transform.
|
||||
# (The attacker can only poison the dataset)
|
||||
if index in self.poi_indices:
|
||||
target = self.trigger_handler.trigger_label
|
||||
img = self.trigger_handler.put_trigger(img)
|
||||
|
||||
# 保存投毒后的图片
|
||||
if self.save_counter < self.max_save_count:
|
||||
img.save(os.path.join(self.save_dir, f'poisoned_{self.save_counter}.png'))
|
||||
self.save_counter += 1 # 递增计数器
|
||||
if self.transform is not None:
|
||||
img = self.transform(img)
|
||||
|
||||
if self.target_transform is not None:
|
||||
target = self.target_transform(target)
|
||||
|
||||
return img, target
|
||||
|
|
@ -0,0 +1,74 @@
|
|||
import torch
|
||||
from sklearn.metrics import accuracy_score, classification_report
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def optimizer_picker(optimization, param, lr):
|
||||
if optimization == 'adam':
|
||||
optimizer = torch.optim.Adam(param, lr=lr)
|
||||
elif optimization == 'sgd':
|
||||
optimizer = torch.optim.SGD(param, lr=lr)
|
||||
else:
|
||||
print("automatically assign adam optimization function to you...")
|
||||
optimizer = torch.optim.Adam(param, lr=lr)
|
||||
return optimizer
|
||||
|
||||
|
||||
def train_one_epoch(data_loader, model, criterion, optimizer, loss_mode, device):
|
||||
running_loss = 0
|
||||
model.train()
|
||||
for step, (batch_x, batch_y) in enumerate(tqdm(data_loader)):
|
||||
|
||||
batch_x = batch_x.to(device, non_blocking=True)
|
||||
batch_y = batch_y.to(device, non_blocking=True)
|
||||
|
||||
optimizer.zero_grad()
|
||||
output = model(batch_x) # get predict label of batch_x
|
||||
|
||||
loss = criterion(output, batch_y)
|
||||
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
running_loss += loss
|
||||
return {
|
||||
"loss": running_loss.item() / len(data_loader),
|
||||
}
|
||||
|
||||
def evaluate_badnets(data_loader_val_clean, data_loader_val_poisoned, model, device):
|
||||
ta = eval(data_loader_val_clean, model, device, print_perform=True)
|
||||
asr = eval(data_loader_val_poisoned, model, device, print_perform=False)
|
||||
return {
|
||||
'clean_acc': ta['acc'], 'clean_loss': ta['loss'],
|
||||
'asr': asr['acc'], 'asr_loss': asr['loss'],
|
||||
}
|
||||
|
||||
def eval(data_loader, model, device, batch_size=64, print_perform=False):
|
||||
criterion = torch.nn.CrossEntropyLoss()
|
||||
model.eval() # switch to eval status
|
||||
y_true = []
|
||||
y_predict = []
|
||||
loss_sum = []
|
||||
for (batch_x, batch_y) in tqdm(data_loader):
|
||||
|
||||
batch_x = batch_x.to(device, non_blocking=True)
|
||||
batch_y = batch_y.to(device, non_blocking=True)
|
||||
|
||||
batch_y_predict = model(batch_x)
|
||||
loss = criterion(batch_y_predict, batch_y)
|
||||
batch_y_predict = torch.argmax(batch_y_predict, dim=1)
|
||||
y_true.append(batch_y)
|
||||
y_predict.append(batch_y_predict)
|
||||
loss_sum.append(loss.item())
|
||||
|
||||
y_true = torch.cat(y_true,0)
|
||||
y_predict = torch.cat(y_predict,0)
|
||||
loss = sum(loss_sum) / len(loss_sum)
|
||||
|
||||
if print_perform:
|
||||
print(classification_report(y_true.cpu(), y_predict.cpu(), target_names=data_loader.dataset.classes))
|
||||
|
||||
return {
|
||||
"acc": accuracy_score(y_true.cpu(), y_predict.cpu()),
|
||||
"loss": loss,
|
||||
}
|
||||
|
|
@ -0,0 +1,101 @@
|
|||
train_loss,test_clean_acc,test_clean_loss,test_asr,test_asr_loss,epoch
|
||||
2.3013677182404892,0.1,2.3022287772719268,1.0,2.2809633388640775,0
|
||||
2.2959123304128037,0.1,2.3015811625559617,1.0,2.1188513731500906,1
|
||||
2.2647709761129318,0.1,2.3197280540587797,1.0,1.726016153195861,2
|
||||
2.258171998631314,0.1,2.3137798005608237,1.0,1.7447254202168458,3
|
||||
2.254529245674153,0.1,2.307986100008533,1.0,1.758529262178263,4
|
||||
2.2508242087595907,0.1017,2.2999379756344354,0.9957,1.78576500856193,5
|
||||
2.2450895577745364,0.1274,2.295038165560194,0.8708,1.773421118973167,6
|
||||
2.2402616925251757,0.1396,2.2846838000473704,0.7964,1.8261052620638707,7
|
||||
2.2341292983735612,0.1697,2.2685489654541016,0.7284,1.8643596073624435,8
|
||||
2.2190038188339196,0.1982,2.244474669170987,0.5946,1.9257067745658243,9
|
||||
2.207091504655531,0.2177,2.228195964910422,0.566,1.9525523808351748,10
|
||||
2.1964025473045874,0.2322,2.2144844228295004,0.5105,1.9955688699795182,11
|
||||
2.1836783562779734,0.2519,2.197193318871176,0.5397,1.9631848145442403,12
|
||||
2.1670266856317935,0.2624,2.1866483627610904,0.5559,1.9371282963236427,13
|
||||
2.1553094966332322,0.2779,2.1722098960997953,0.5298,1.9498254364463175,14
|
||||
2.1447964641444215,0.2862,2.1671639961801517,0.5977,1.8928835513485465,15
|
||||
2.1317435262148337,0.2942,2.160303784024184,0.699,1.7992277008712672,16
|
||||
2.114907657398897,0.3055,2.1509351365885157,0.8103,1.6865087390705278,17
|
||||
2.096363116408248,0.3139,2.144316896511491,0.8978,1.5863162666369395,18
|
||||
2.0844582950367645,0.3172,2.1397852388916503,0.8595,1.6218782488707524,19
|
||||
2.0760085942495206,0.3309,2.129430477786216,0.9375,1.5385956194750063,20
|
||||
2.069224169796995,0.3255,2.133154162176096,0.9536,1.5170454151311499,21
|
||||
2.0530265944693094,0.3425,2.117099386111946,0.9685,1.5001406069773777,22
|
||||
2.0420864719868925,0.3643,2.0945861430684474,0.8832,1.5901791214183638,23
|
||||
2.0348221956921355,0.3766,2.086378067162386,0.9266,1.5421689900623006,24
|
||||
2.0292722111772696,0.3824,2.0795721863485443,0.9044,1.5630150251327806,25
|
||||
2.0244043842910804,0.3831,2.0757883726411563,0.9375,1.52997648260396,26
|
||||
2.019855782199089,0.3882,2.0724188598098268,0.9404,1.5266358062719843,27
|
||||
2.015683069253517,0.3906,2.0678814664767806,0.9511,1.513833082405625,28
|
||||
2.0114221597266626,0.3933,2.0643014725606155,0.9467,1.5193962238396808,29
|
||||
2.007079922024856,0.3998,2.0612781321167186,0.961,1.5030274786007631,30
|
||||
2.0040177055027173,0.4023,2.0575741127038456,0.9523,1.513155844560854,31
|
||||
2.000916307844469,0.4026,2.0555173287725754,0.8991,1.566101716582183,32
|
||||
1.9974686800671355,0.3973,2.0613801168028716,0.9568,1.506898317367408,33
|
||||
1.9947945285026374,0.4047,2.053149394168975,0.9279,1.5378866985345343,34
|
||||
1.9916931308443895,0.3956,2.0654278545622615,0.9806,1.4820212793957657,35
|
||||
1.9895612711796675,0.4228,2.037921887294502,0.9363,1.527511079599903,36
|
||||
1.9870399416560103,0.4213,2.03904404306108,0.9566,1.5079481723202262,37
|
||||
1.9847792993726023,0.4095,2.0485778599028377,0.8878,1.5772850695689014,38
|
||||
1.9825295840992647,0.4216,2.038239003746373,0.9335,1.530802513383756,39
|
||||
1.9804297249640346,0.4201,2.041304546556655,0.97,1.4926242509465308,40
|
||||
1.9773589229339834,0.427,2.0310130415448717,0.9519,1.5114089265750472,41
|
||||
1.975726115429188,0.4219,2.039024685598483,0.9683,1.4938013849744372,42
|
||||
1.9738326206841432,0.4318,2.029533342191368,0.9615,1.5016432484244084,43
|
||||
1.971985019381394,0.431,2.0270909594882065,0.9271,1.5364322578831084,44
|
||||
1.9698898432504794,0.4343,2.0237912903925417,0.9489,1.5132380632837867,45
|
||||
1.9679432744565217,0.4327,2.0261640981504114,0.9324,1.5293395731859147,46
|
||||
1.9654984340033568,0.4324,2.0272004475259475,0.9522,1.5103057994964018,47
|
||||
1.9637513611932544,0.4372,2.0222004796289337,0.9606,1.5016404899062625,48
|
||||
1.961632770040761,0.4432,2.018323174707449,0.9558,1.5075370339071674,49
|
||||
1.9598922534367007,0.443,2.0153349797437143,0.9573,1.5052259385965432,50
|
||||
1.9578852738870685,0.4393,2.019727810173278,0.9727,1.4898261993553987,51
|
||||
1.9556723982476822,0.4462,2.014107480929915,0.9571,1.5060120168005584,52
|
||||
1.9534520298013907,0.4494,2.0107815797161903,0.9403,1.5223582519847116,53
|
||||
1.9517344599184783,0.4442,2.0130798171280295,0.9522,1.5109643207234182,54
|
||||
1.9490251858216112,0.4505,2.005754134457582,0.9544,1.508478753126351,55
|
||||
1.9471615062040442,0.4487,2.010945641311111,0.9377,1.525592884440331,56
|
||||
1.9455482112172315,0.4512,2.006862702643036,0.969,1.4930097958084885,57
|
||||
1.9432274703784367,0.4576,2.002480664830299,0.9483,1.5148228938412514,58
|
||||
1.9407870007292998,0.4584,2.00054540527854,0.9595,1.5027173803110792,59
|
||||
1.9390266594069694,0.4478,2.0116929149931404,0.9159,1.5469601207478032,60
|
||||
1.9367060746683185,0.4625,1.99824632808661,0.9556,1.5062525689981545,61
|
||||
1.9349265330282928,0.4653,1.9964680800772017,0.933,1.5302747047630845,62
|
||||
1.9330566094049713,0.4643,1.99455829258937,0.9589,1.5026505297156656,63
|
||||
1.9302777146439418,0.4671,1.9915898673853296,0.9391,1.5237193358172276,64
|
||||
1.9284248059363012,0.4657,1.9950058308376628,0.9699,1.492256418914552,65
|
||||
1.9258870858975383,0.4746,1.987485240219505,0.9456,1.5176504220172857,66
|
||||
1.9247433401434624,0.4676,1.9919661594803926,0.9579,1.5037517251482435,67
|
||||
1.922510171485374,0.4694,1.991385011915948,0.9561,1.5058731296259886,68
|
||||
1.920380809422954,0.4685,1.990318283153947,0.9463,1.5146675466731856,69
|
||||
1.918157475073929,0.478,1.9842900683166116,0.9412,1.5203920898923449,70
|
||||
1.9160893042679028,0.4668,1.9923243439121612,0.9219,1.540427448643241,71
|
||||
1.9138088372662245,0.4747,1.9847242946078063,0.9332,1.5287484173562116,72
|
||||
1.9111382760050353,0.4787,1.9815373420715332,0.9448,1.5182297647379006,73
|
||||
1.9090379485693734,0.4773,1.980476011136535,0.9614,1.5006665978462073,74
|
||||
1.9067986920056745,0.4578,1.9996468990471712,0.9906,1.4710554605836321,75
|
||||
1.9044913757792519,0.4793,1.9807145231089014,0.9232,1.539444094250916,76
|
||||
1.9022514948149776,0.4806,1.9786214912013642,0.9411,1.521085992740218,77
|
||||
1.9002225051450607,0.4839,1.9752073994108066,0.944,1.5184953660721991,78
|
||||
1.8979158133192136,0.483,1.973207703061924,0.9542,1.507613252682291,79
|
||||
1.8960107671635231,0.4756,1.9832999751826001,0.9567,1.5052490522907038,80
|
||||
1.892943087136349,0.4803,1.9802985806374034,0.9677,1.4935776048405156,81
|
||||
1.8919713637408089,0.4776,1.9820065908371263,0.9763,1.485085157831763,82
|
||||
1.8893712036445012,0.4905,1.970310363799903,0.9511,1.5114479490146515,83
|
||||
1.8870596727141944,0.4804,1.976558756676449,0.9707,1.491226076320478,84
|
||||
1.8847534491887787,0.4898,1.9678408833825665,0.9545,1.508029621877488,85
|
||||
1.883700709818574,0.4926,1.9649371994528801,0.9592,1.5027804640448017,86
|
||||
1.8809717671035806,0.4924,1.9690094730656618,0.9157,1.5454866407783168,87
|
||||
1.8789835195712117,0.4951,1.9641431562460152,0.9632,1.4993545879983599,88
|
||||
1.8767086721747124,0.4999,1.9599194245733274,0.9508,1.5100223972539233,89
|
||||
1.8746565796835037,0.4976,1.9627940054911717,0.9748,1.4868610246925598,90
|
||||
1.8726079213954603,0.5042,1.9547175138619295,0.9471,1.5144094745064998,91
|
||||
1.870999933813539,0.4932,1.965204166758592,0.9209,1.5411182573646496,92
|
||||
1.8682087071411444,0.5061,1.9529655814930131,0.9581,1.5032796320641877,93
|
||||
1.8667983423413523,0.4919,1.967396649585408,0.9485,1.5130645477088394,94
|
||||
1.8642823202225862,0.5028,1.95682561321623,0.9725,1.4893795623900785,95
|
||||
1.8625674040421196,0.5036,1.9549465969109991,0.9421,1.5196865743892207,96
|
||||
1.8606735970967871,0.5023,1.955495271713111,0.9624,1.4994068457062837,97
|
||||
1.8579215515605019,0.5046,1.9533697427458065,0.9638,1.4970990518096146,98
|
||||
1.8558838202825287,0.5066,1.9524353036455289,0.9673,1.4946813947835547,99
|
|
|
@ -0,0 +1,101 @@
|
|||
train_loss,test_clean_acc,test_clean_loss,test_asr,test_asr_loss,epoch
|
||||
1.9756020747268124,0.098,2.3632956295256404,1.0,1.4616684701032698,0
|
||||
1.9123846188282916,0.098,2.363015056415728,1.0,1.4613561045591998,1
|
||||
1.9122458907332756,0.098,2.3636156070004604,1.0,1.4612768691056852,2
|
||||
1.9122569525419777,0.098,2.3627213733211443,1.0,1.4612388762698811,3
|
||||
1.9122475825393124,0.098,2.3624237130402,1.0,1.461218354808297,4
|
||||
1.9122775144922708,0.098,2.3636176950612646,1.0,1.4612053906082347,5
|
||||
1.9122751719916045,0.098,2.363021267447502,1.0,1.461196282866654,6
|
||||
1.9123052340834887,0.098,2.362722697531342,1.0,1.461189749134574,7
|
||||
1.9122876653284915,0.098,2.3633201137469833,1.0,1.4611848205517812,8
|
||||
1.9122863639392325,0.098,2.363021454234032,1.0,1.4611808289388182,9
|
||||
1.9122197328091683,0.098,2.3636186183637875,1.0,1.4611777674620319,10
|
||||
1.9122851926888993,0.098,2.3633201015982657,1.0,1.4611751516913152,11
|
||||
1.9122349590634995,0.098,2.3630213995648037,1.0,1.4611731388007,12
|
||||
1.9122510962903119,0.098,2.3636187292208337,1.0,1.4611711464110453,13
|
||||
1.9122845419942698,0.098,2.3636187595926272,1.0,1.4611695572069496,14
|
||||
1.9122333973963885,0.098,2.3630216334276137,1.0,1.4611682367932266,15
|
||||
1.9122501853178304,0.098,2.3624245755991358,1.0,1.4611670431817414,16
|
||||
1.9122831104660847,0.098,2.363618817299035,1.0,1.4611662375699184,17
|
||||
1.9122667129614206,0.098,2.363618806668907,1.0,1.4611653248975232,18
|
||||
1.9122488839285714,0.098,2.362723188035807,1.0,1.461164360593079,19
|
||||
1.912232746701759,0.098,2.3630217154314566,1.0,1.4611634501985684,20
|
||||
1.9122491442064233,0.098,2.362723107550554,1.0,1.4611628351697497,21
|
||||
1.9122486236507197,0.098,2.362424578636315,1.0,1.4611621882505477,22
|
||||
1.9121485468166977,0.098,2.3624247593484866,1.0,1.4611616567441612,23
|
||||
1.9122487537896455,0.098,2.36332031571941,1.0,1.4611611168855314,24
|
||||
1.9122484935117936,0.098,2.363021773137864,1.0,1.4611606848467686,25
|
||||
1.9122154382246135,0.098,2.363320321793769,1.0,1.4611602824205046,26
|
||||
1.9122645005996801,0.098,2.3633203415354345,1.0,1.4611598648083437,27
|
||||
1.9122143971132064,0.098,2.3633203415354345,1.0,1.461159517810603,28
|
||||
1.912281288521122,0.098,2.363021764026326,1.0,1.4611591525897858,29
|
||||
1.9121977393306904,0.098,2.3633203521655624,1.0,1.4611588306487746,30
|
||||
1.9122975558868602,0.098,2.3636189144887743,1.0,1.4611585489503898,31
|
||||
1.9122145272521323,0.098,2.363021776175043,1.0,1.4611582695298893,32
|
||||
1.9121478961220681,0.098,2.3636189084144155,1.0,1.4611580637609882,33
|
||||
1.9122470619836087,0.098,2.3633203460912036,1.0,1.461157947588878,34
|
||||
1.9123472689565566,0.098,2.362723323190288,1.0,1.4611575717379333,35
|
||||
1.9121477659831423,0.098,2.3636188977842876,1.0,1.4611573333193542,36
|
||||
1.9121969584971348,0.098,2.363320367351459,1.0,1.4611571685523743,37
|
||||
1.912313432835821,0.098,2.3630217959167092,1.0,1.4611569567091147,38
|
||||
1.9122141368353545,0.098,2.362126089205408,1.0,1.4611567828305967,39
|
||||
1.912246541427905,0.098,2.363021820214144,1.0,1.4611566059148995,40
|
||||
1.9122132258628732,0.098,2.3636189418233884,1.0,1.4611564289992023,41
|
||||
1.9122129655850213,0.098,2.3630218126211955,1.0,1.4611562718251707,42
|
||||
1.9122622882379399,0.098,2.362723241186446,1.0,1.461156186784149,43
|
||||
1.912213356001799,0.098,2.363618916007364,1.0,1.4611559901267859,44
|
||||
1.9121796500199892,0.098,2.3630217989538886,1.0,1.461155864083843,45
|
||||
1.912329830340485,0.098,2.363320376462997,1.0,1.4611557372816049,46
|
||||
1.9122794665761593,0.098,2.3633203734258177,1.0,1.461155624905969,47
|
||||
1.9122626786547174,0.098,2.3633203749444074,1.0,1.4611554874736032,48
|
||||
1.9122957339418976,0.098,2.3630217792122226,1.0,1.46115540850694,49
|
||||
1.912245370177572,0.098,2.36361893423044,1.0,1.461155301446368,50
|
||||
1.912228842533982,0.098,2.3630217883237608,1.0,1.4611551845149628,51
|
||||
1.9122790761593818,0.098,2.363021809584016,1.0,1.4611550827694546,52
|
||||
1.9122612471265326,0.098,2.3630218126211955,1.0,1.4611550508790714,53
|
||||
1.9121947461353945,0.098,2.3633203840559456,1.0,1.46115491724318,54
|
||||
1.9121950064132462,0.098,2.3624246925305408,1.0,1.4611548397951066,55
|
||||
1.912278685742604,0.098,2.3627232427050355,1.0,1.4611547676620968,56
|
||||
1.9123114807519324,0.098,2.3627232594095218,1.0,1.4611546773060111,57
|
||||
1.9122444592050907,0.098,2.363320358239921,1.0,1.4611546476935124,58
|
||||
1.9122946928304905,0.098,2.3636189509349266,1.0,1.4611545170948004,59
|
||||
1.9122281918393524,0.098,2.363021809584016,1.0,1.4611544358502528,60
|
||||
1.9122776446311966,0.098,2.362424672788875,1.0,1.4611543849774986,61
|
||||
1.9123269672841152,0.098,2.3630218186955543,1.0,1.461154299177182,62
|
||||
1.912295083247268,0.098,2.3630218171769646,1.0,1.4611542490637226,63
|
||||
1.9122266301722415,0.098,2.3630218050282474,1.0,1.461154176930713,64
|
||||
1.912295083247268,0.098,2.3633203931674838,1.0,1.4611541139092414,65
|
||||
1.9122111436400586,0.098,2.363320376462997,1.0,1.4611540660736666,66
|
||||
1.9122431578158317,0.098,2.3630218126211955,1.0,1.4611540068486693,67
|
||||
1.9122775144922708,0.098,2.362723257890932,1.0,1.4611539491422616,68
|
||||
1.9122271507279451,0.098,2.362723253335163,1.0,1.4611539210483526,69
|
||||
1.9122943024137127,0.098,2.3627232548537527,1.0,1.4611538466374585,70
|
||||
1.912260596431903,0.098,2.363618947897747,1.0,1.4611537858938715,71
|
||||
1.9121946159964684,0.098,2.3630218232513234,1.0,1.461153754762783,72
|
||||
1.9122767336587154,0.098,2.3627234340473344,1.0,1.4611537107236825,73
|
||||
1.9122114039179106,0.098,2.36361893271185,1.0,1.4611536727589407,74
|
||||
1.9122435482326092,0.098,2.363320391648894,1.0,1.4611536393499678,75
|
||||
1.9123090081123402,0.098,2.363320391648894,1.0,1.461153567976253,76
|
||||
1.9121760061300639,0.098,2.3636189418233884,1.0,1.461153526215037,77
|
||||
1.9122440687883129,0.098,2.363021814139785,1.0,1.461153535326575,78
|
||||
1.912259164903718,0.098,2.3633203795001765,1.0,1.4611534540820275,79
|
||||
1.912242507121202,0.098,2.3633203946860735,1.0,1.4611534199137597,80
|
||||
1.9123440154834088,0.098,2.363021815658375,1.0,1.4611533698003003,81
|
||||
1.9122251986440566,0.098,2.363618947897747,1.0,1.461153366003826,82
|
||||
1.9122931311633795,0.098,2.3630218004724783,1.0,1.4611533287983791,83
|
||||
1.91219266391258,0.098,2.363618953972106,1.0,1.461153268054792,84
|
||||
1.9122596854594216,0.098,2.363618928156081,1.0,1.4611532429980625,85
|
||||
1.9122428975379797,0.098,2.362723247260804,1.0,1.4611532050333205,86
|
||||
1.9122931311633795,0.098,2.3630218171769646,1.0,1.4611532164227432,87
|
||||
1.912259034764792,0.098,2.3630218262885028,1.0,1.461153142011849,88
|
||||
1.912209321695096,0.098,2.3633203886117147,1.0,1.461153151882682,89
|
||||
1.912242637260128,0.098,2.362723257890932,1.0,1.461153104047107,90
|
||||
1.9121586976529183,0.098,2.3630218278070925,1.0,1.4611530592487116,91
|
||||
1.9122585142090884,0.098,2.362723262446701,1.0,1.461153029636213,92
|
||||
1.9122253287829825,0.098,2.3636189463791575,1.0,1.4611529954679452,93
|
||||
1.9121597387643257,0.098,2.363021811102606,1.0,1.4611529772448693,94
|
||||
1.9121753554354344,0.098,2.363618947897747,1.0,1.4611529757262796,95
|
||||
1.9122747815748267,0.098,2.363320526803375,1.0,1.4611529233349356,96
|
||||
1.9123251453391525,0.098,2.362723257890932,1.0,1.4611529104269234,97
|
||||
1.9121584373750666,0.098,2.363618970676592,1.0,1.461152873980771,98
|
||||
1.9122913092184168,0.098,2.3636189554906952,1.0,1.4611528436089778,99
|
|
|
@ -0,0 +1,101 @@
|
|||
train_loss,test_clean_acc,test_clean_loss,test_asr,test_asr_loss,epoch
|
||||
1.98744901156883,0.1135,2.347685305176267,1.0,1.461904067142754,0
|
||||
1.9061691837270123,0.1135,2.3478118644398487,1.0,1.461445844097502,1
|
||||
1.9060901893989872,0.1135,2.3475449540812496,1.0,1.461330204252984,2
|
||||
1.906047633970216,0.1135,2.3478585064031514,1.0,1.461277563860462,3
|
||||
1.905904351012793,0.1135,2.3472695350646973,1.0,1.4612481791502352,4
|
||||
1.9060658534198427,0.1135,2.3469761055745897,1.0,1.4612302081600117,5
|
||||
1.9059800918676706,0.1135,2.3475769219125153,1.0,1.4612167929388156,6
|
||||
1.9059443036630463,0.1135,2.3466839031049402,1.0,1.4612073161799437,7
|
||||
1.906025380213886,0.1135,2.3475815004603877,1.0,1.4612001226206495,8
|
||||
1.905990763259595,0.1135,2.3475831359814685,1.0,1.461194377036611,9
|
||||
1.9060231678521455,0.1135,2.3472858401620464,1.0,1.4611900604454575,10
|
||||
1.9059219197677906,0.1135,2.3475854457563656,1.0,1.4611861584292856,11
|
||||
1.9059882906200027,0.1135,2.3475863037595324,1.0,1.4611832784239653,12
|
||||
1.9059708520039311,0.1135,2.346989674173343,1.0,1.4611804204381955,13
|
||||
1.9060373529950694,0.1135,2.347886301149988,1.0,1.4611782564479074,14
|
||||
1.905969550614672,0.1135,2.347289720158668,1.0,1.4611761964810122,15
|
||||
1.90596994103145,0.1135,2.3481858809282827,1.0,1.461174629296467,16
|
||||
1.9059860782582623,0.1135,2.3478877832935114,1.0,1.461173007442693,17
|
||||
1.9059022687899787,0.1135,2.3481866903365796,1.0,1.4611716224889086,18
|
||||
1.9060187431286648,0.1135,2.34788853955117,1.0,1.461170530622932,19
|
||||
1.9059687697811167,0.1135,2.347590230832434,1.0,1.4611693628274711,20
|
||||
1.905951851720749,0.1135,2.3472919767829263,1.0,1.4611684038380908,21
|
||||
1.905885090451759,0.1135,2.3469936528782935,1.0,1.4611678753688837,22
|
||||
1.905918406016791,0.1135,2.347889593452405,1.0,1.4611667303522682,23
|
||||
1.9059674683918577,0.1135,2.3478897847947042,1.0,1.4611661973272918,24
|
||||
1.9059679889475614,0.1135,2.348188524792908,1.0,1.4611653074337418,25
|
||||
1.9060346200776253,0.1135,2.3472930154982645,1.0,1.461164678737616,26
|
||||
1.905933762410048,0.1135,2.348188830029433,1.0,1.4611641009142444,27
|
||||
1.9059003167060902,0.1135,2.3481891033755744,1.0,1.4611636104097792,28
|
||||
1.9059838658965218,0.1135,2.347293524225806,1.0,1.4611631009229429,29
|
||||
1.9059994825676305,0.1135,2.3478906458350504,1.0,1.4611626544575782,30
|
||||
1.9059328514375666,0.1135,2.3478908781792707,1.0,1.4611622376047122,31
|
||||
1.9059160635161247,0.1135,2.347890998147855,1.0,1.461161830622679,32
|
||||
1.9059173649053838,0.1135,2.3472940056187332,1.0,1.4611615534800633,33
|
||||
1.9059833453408181,0.1135,2.346995534410902,1.0,1.461161105496109,34
|
||||
1.905915542960421,0.1135,2.3475927987675758,1.0,1.461160744071766,35
|
||||
1.9060160102112207,0.1135,2.3463986284413916,1.0,1.4611604426317155,36
|
||||
1.9060158800722948,0.1135,2.3472943564129483,1.0,1.4611601381544854,37
|
||||
1.9060497161930303,0.1135,2.347294468788584,1.0,1.4611599695910313,38
|
||||
1.9059817836737074,0.1135,2.347891692143337,1.0,1.4611597357282213,39
|
||||
1.905982174090485,0.1135,2.346697496001128,1.0,1.4611593560808023,40
|
||||
1.905915542960421,0.1135,2.347294699614215,1.0,1.4611591495526064,41
|
||||
1.905998831873001,0.1135,2.3475933560899866,1.0,1.4611589118933221,42
|
||||
1.9059480776918976,0.1135,2.346996261815357,1.0,1.461158674234038,43
|
||||
1.906048935359475,0.1135,2.3478920505305005,1.0,1.4611584783359697,44
|
||||
1.9059318103261593,0.1135,2.34699638482112,1.0,1.4611582877529654,45
|
||||
1.9059319404650852,0.1135,2.347593570211131,1.0,1.4611580956513714,46
|
||||
1.9059314199093818,0.1135,2.3481907571197316,1.0,1.4611578944382395,47
|
||||
1.905981132979078,0.1135,2.347892244909979,1.0,1.461157755487284,48
|
||||
1.9059647354744136,0.1135,2.347593702328433,1.0,1.4611576157770338,49
|
||||
1.905997530483742,0.1135,2.3481909135344683,1.0,1.4611574396206315,50
|
||||
1.9059479475529717,0.1135,2.347593813185479,1.0,1.4611572854837793,51
|
||||
1.9059649957522655,0.1135,2.3469967204294386,1.0,1.4611571495700035,52
|
||||
1.90596421491871,0.1135,2.347593870891887,1.0,1.461157019730586,53
|
||||
1.9059634340851546,0.1135,2.3475939285982945,1.0,1.4611568549636063,54
|
||||
1.9059635642240804,0.1135,2.3472954239814903,1.0,1.4611568633158496,55
|
||||
1.9058971933718682,0.1135,2.348191162583175,1.0,1.4611566529911795,56
|
||||
1.9060130170159248,0.1135,2.3472954801693082,1.0,1.461156542893428,57
|
||||
1.9059134607376067,0.1135,2.347892653410602,1.0,1.4611564176097798,58
|
||||
1.9060136677105544,0.1135,2.347594113866235,1.0,1.4611563128270921,59
|
||||
1.9059634340851546,0.1135,2.348191299256246,1.0,1.4611561784319058,60
|
||||
1.905914241571162,0.1135,2.3472956138052,1.0,1.461156167801778,61
|
||||
1.9058793643390193,0.1135,2.3481913296280394,1.0,1.4611560053126826,62
|
||||
1.9060134074327026,0.1135,2.3481913858158574,1.0,1.4611558990114053,63
|
||||
1.906012886876999,0.1135,2.3475942566136645,1.0,1.461155793469423,64
|
||||
1.9059968797891125,0.1135,2.3472957276994255,1.0,1.4611557099469907,65
|
||||
1.905912679904051,0.1135,2.3475943462104554,1.0,1.4611556264245587,66
|
||||
1.9059798315898189,0.1135,2.347594343173276,1.0,1.461155521641871,67
|
||||
1.9059626532515992,0.1135,2.347295786924423,1.0,1.4611554722877065,68
|
||||
1.9060290241038114,0.1135,2.34729584766801,1.0,1.4611553978768124,69
|
||||
1.9059290774087154,0.1135,2.3472958324821134,1.0,1.4611553128357906,70
|
||||
1.905996229094483,0.1135,2.3475944388444256,1.0,1.4611552391841913,71
|
||||
1.9059132004597548,0.1135,2.3475944722533986,1.0,1.4611551571803487,72
|
||||
1.905979571311967,0.1135,2.347892998130458,1.0,1.4611550873252237,73
|
||||
1.9059954482609276,0.1135,2.347594484402116,1.0,1.4611550394896489,74
|
||||
1.9060291542427372,0.1135,2.347594529959806,1.0,1.4611549605229857,75
|
||||
1.9060119759045175,0.1135,2.346997384053127,1.0,1.4611548838342072,76
|
||||
1.9060452914695496,0.1135,2.34819170168251,1.0,1.4611548527031188,77
|
||||
1.9060616889742137,0.1135,2.347893153026605,1.0,1.4611547684213917,78
|
||||
1.9059790507562633,0.1135,2.3478931621381434,1.0,1.4611549104095265,79
|
||||
1.9059954482609276,0.1135,2.34819174875879,1.0,1.4611546621201144,80
|
||||
1.905911508653718,0.1135,2.3475946180380074,1.0,1.4611546120066552,81
|
||||
1.9060121060434434,0.1135,2.3481917882421213,1.0,1.4611545543002475,82
|
||||
1.9059789206173374,0.1135,2.348191789760711,1.0,1.4611545011496088,83
|
||||
1.905978139783782,0.1135,2.3478932350304476,1.0,1.4611544449617908,84
|
||||
1.905978139783782,0.1135,2.3475947061162086,1.0,1.461154430535189,85
|
||||
1.906011325209888,0.1135,2.347594703079029,1.0,1.4611544191457664,86
|
||||
1.9059112483758662,0.1135,2.347296180239149,1.0,1.461154374347371,87
|
||||
1.9059449543576759,0.1135,2.347893307922752,1.0,1.4611542627310297,88
|
||||
1.9059275157416045,0.1135,2.347594760785437,1.0,1.4611542270441724,89
|
||||
1.9060114553488139,0.1135,2.346997636139013,1.0,1.4611541814864821,90
|
||||
1.9059777493670043,0.1135,2.347594798750179,1.0,1.4611541700970596,91
|
||||
1.9059770986723747,0.1135,2.3475947866014613,1.0,1.4611540979640498,92
|
||||
1.9059277760194564,0.1135,2.3478933762592873,1.0,1.4611540562028338,93
|
||||
1.9059446940798241,0.1135,2.3481919613613447,1.0,1.4611540053300798,94
|
||||
1.9060454216084755,0.1135,2.348191973510062,1.0,1.4611539590130946,95
|
||||
1.9059443036630463,0.1135,2.3478934233355675,1.0,1.4611539271227114,96
|
||||
1.9059773589502265,0.1135,2.3472962804660678,1.0,1.4611538861207902,97
|
||||
1.9059272554637527,0.1135,2.347893449151592,1.0,1.4611538473967534,98
|
||||
1.9060106745152585,0.1135,2.346997739403111,1.0,1.4611538360073308,99
|
|
|
@ -0,0 +1,95 @@
|
|||
import argparse
|
||||
import os
|
||||
import pathlib
|
||||
import re
|
||||
import time
|
||||
import datetime
|
||||
|
||||
import pandas as pd
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from dataset import build_poisoned_training_set, build_testset
|
||||
from deeplearning import evaluate_badnets, optimizer_picker, train_one_epoch
|
||||
from models import BadNet
|
||||
|
||||
parser = argparse.ArgumentParser(description='Reproduce the basic backdoor attack in "Badnets: Identifying vulnerabilities in the machine learning model supply chain".')
|
||||
parser.add_argument('--dataset', default='MNIST', help='Which dataset to use (MNIST or CIFAR10, default: MNIST)')
|
||||
parser.add_argument('--nb_classes', default=10, type=int, help='number of the classification types')
|
||||
parser.add_argument('--load_local', action='store_true', help='train model or directly load model (default true, if you add this param, then load trained local model to evaluate the performance)')
|
||||
parser.add_argument('--loss', default='mse', help='Which loss function to use (mse or cross, default: mse)')
|
||||
parser.add_argument('--optimizer', default='sgd', help='Which optimizer to use (sgd or adam, default: sgd)')
|
||||
parser.add_argument('--epochs', default=100, help='Number of epochs to train backdoor model, default: 100')
|
||||
parser.add_argument('--batch_size', type=int, default=64, help='Batch size to split dataset, default: 64')
|
||||
parser.add_argument('--num_workers', type=int, default=0, help='Batch size to split dataset, default: 64')
|
||||
parser.add_argument('--lr', type=float, default=0.01, help='Learning rate of the model, default: 0.001')
|
||||
parser.add_argument('--download', action='store_true', help='Do you want to download data ( default false, if you add this param, then download)')
|
||||
parser.add_argument('--data_path', default='./data/', help='Place to load dataset (default: ./dataset/)')
|
||||
parser.add_argument('--device', default='cuda:0', help='device to use for training / testing (cpu, or cuda:1, default: cpu)')
|
||||
# poison settings
|
||||
parser.add_argument('--poisoning_rate', type=float, default=0.5, help='poisoning portion (float, range from 0 to 1, default: 0.1)')
|
||||
parser.add_argument('--trigger_label', type=int, default=0, help='The NO. of trigger label (int, range from 0 to 10, default: 0)')
|
||||
parser.add_argument('--trigger_path', default="./triggers/trigger_10.png", help='Trigger Path (default: ./triggers/trigger_white.png)')
|
||||
parser.add_argument('--trigger_size', type=int, default=5, help='Trigger Size (int, default: 5)')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
def main():
|
||||
print("{}".format(args).replace(', ', ',\n'))
|
||||
|
||||
if re.match('cuda:\d', args.device):
|
||||
cuda_num = args.device.split(':')[1]
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = cuda_num
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # if you're using MBP M1, you can also use "mps"
|
||||
|
||||
# create related path
|
||||
pathlib.Path("./checkpoints/").mkdir(parents=True, exist_ok=True)
|
||||
pathlib.Path("./logs/").mkdir(parents=True, exist_ok=True)
|
||||
|
||||
print("\n# load dataset: %s " % args.dataset)
|
||||
dataset_train, args.nb_classes = build_poisoned_training_set(is_train=True, args=args)
|
||||
dataset_val_clean, dataset_val_poisoned = build_testset(is_train=False, args=args)
|
||||
|
||||
data_loader_train = DataLoader(dataset_train, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)
|
||||
data_loader_val_clean = DataLoader(dataset_val_clean, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)
|
||||
data_loader_val_poisoned = DataLoader(dataset_val_poisoned, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers) # shuffle 随机化
|
||||
|
||||
model = BadNet(input_channels=dataset_train.channels, output_num=args.nb_classes).to(device)
|
||||
criterion = torch.nn.CrossEntropyLoss()
|
||||
optimizer = optimizer_picker(args.optimizer, model.parameters(), lr=args.lr)
|
||||
|
||||
basic_model_path = "./checkpoints/badnet-%s.pth" % args.dataset
|
||||
start_time = time.time()
|
||||
if args.load_local:
|
||||
print("## Load model from : %s" % basic_model_path)
|
||||
model.load_state_dict(torch.load(basic_model_path), strict=True)
|
||||
test_stats = evaluate_badnets(data_loader_val_clean, data_loader_val_poisoned, model, device)
|
||||
print(f"Test Clean Accuracy(TCA): {test_stats['clean_acc']:.4f}")
|
||||
print(f"Attack Success Rate(ASR): {test_stats['asr']:.4f}")
|
||||
else:
|
||||
print(f"Start training for {args.epochs} epochs")
|
||||
stats = []
|
||||
for epoch in range(args.epochs):
|
||||
train_stats = train_one_epoch(data_loader_train, model, criterion, optimizer, args.loss, device)
|
||||
test_stats = evaluate_badnets(data_loader_val_clean, data_loader_val_poisoned, model, device)
|
||||
print(f"# EPOCH {epoch} loss: {train_stats['loss']:.4f} Test Acc: {test_stats['clean_acc']:.4f}, ASR: {test_stats['asr']:.4f}\n")
|
||||
|
||||
# save model
|
||||
torch.save(model.state_dict(), basic_model_path)
|
||||
|
||||
log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
|
||||
**{f'test_{k}': v for k, v in test_stats.items()},
|
||||
'epoch': epoch,
|
||||
}
|
||||
|
||||
# save training stats
|
||||
stats.append(log_stats)
|
||||
df = pd.DataFrame(stats)
|
||||
df.to_csv("./logs/%s_trigger%d.csv" % (args.dataset, args.trigger_label), index=False, encoding='utf-8')
|
||||
|
||||
total_time = time.time() - start_time
|
||||
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
|
||||
print('Training time {}'.format(total_time_str))
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -0,0 +1 @@
|
|||
from .badnet import BadNet
|
|
@ -0,0 +1,36 @@
|
|||
from torch import nn
|
||||
|
||||
class BadNet(nn.Module):
|
||||
|
||||
def __init__(self, input_channels, output_num):
|
||||
super().__init__()
|
||||
self.conv1 = nn.Sequential(
|
||||
nn.Conv2d(in_channels=input_channels, out_channels=16, kernel_size=5, stride=1),
|
||||
nn.ReLU(),
|
||||
nn.AvgPool2d(kernel_size=2, stride=2)
|
||||
)
|
||||
|
||||
self.conv2 = nn.Sequential(
|
||||
nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=1),
|
||||
nn.ReLU(),
|
||||
nn.AvgPool2d(kernel_size=2, stride=2)
|
||||
)
|
||||
fc1_input_features = 800 if input_channels == 3 else 512
|
||||
self.fc1 = nn.Sequential(
|
||||
nn.Linear(in_features=fc1_input_features, out_features=512),
|
||||
nn.ReLU()
|
||||
)
|
||||
self.fc2 = nn.Sequential(
|
||||
nn.Linear(in_features=512, out_features=output_num),
|
||||
nn.Softmax(dim=-1)
|
||||
)
|
||||
self.dropout = nn.Dropout(p=.5)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv1(x)
|
||||
x = self.conv2(x)
|
||||
|
||||
x = x.view(x.size(0), -1)
|
||||
x = self.fc1(x)
|
||||
x = self.fc2(x)
|
||||
return x
|
After Width: | Height: | Size: 281 B |
After Width: | Height: | Size: 347 B |
After Width: | Height: | Size: 284 B |
After Width: | Height: | Size: 320 B |
After Width: | Height: | Size: 296 B |
After Width: | Height: | Size: 295 B |
After Width: | Height: | Size: 228 B |
After Width: | Height: | Size: 321 B |
After Width: | Height: | Size: 256 B |
After Width: | Height: | Size: 322 B |
After Width: | Height: | Size: 318 B |
After Width: | Height: | Size: 385 B |
After Width: | Height: | Size: 323 B |
After Width: | Height: | Size: 366 B |
After Width: | Height: | Size: 328 B |
After Width: | Height: | Size: 334 B |
After Width: | Height: | Size: 262 B |
After Width: | Height: | Size: 356 B |
After Width: | Height: | Size: 302 B |
After Width: | Height: | Size: 356 B |
|
@ -0,0 +1,155 @@
|
|||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
import matplotlib.pyplot as plt
|
||||
from torchvision import datasets, transforms
|
||||
use_cuda = True
|
||||
device = torch.device("cuda" if (use_cuda and torch.cuda.is_available()) else "cpu")
|
||||
|
||||
# 载入MNIST训练集和测试集
|
||||
transform = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
])
|
||||
train_loader = datasets.MNIST(root='data',
|
||||
transform=transform,
|
||||
train=True,
|
||||
download=True)
|
||||
test_loader = datasets.MNIST(root='data',
|
||||
transform=transform,
|
||||
train=False)
|
||||
# 可视化样本 大小28×28
|
||||
# plt.imshow(train_loader.data[0].numpy())
|
||||
# plt.show()
|
||||
|
||||
# 训练集样本数据
|
||||
print(len(train_loader))
|
||||
|
||||
# 在训练集中植入5000个中毒样本
|
||||
''' '''
|
||||
for i in range(5000):
|
||||
train_loader.data[i][26][26] = 255
|
||||
train_loader.data[i][25][25] = 255
|
||||
train_loader.data[i][24][26] = 255
|
||||
train_loader.data[i][26][24] = 255
|
||||
train_loader.targets[i] = 9 # 设置中毒样本的目标标签为9
|
||||
# 可视化中毒样本
|
||||
plt.imshow(train_loader.data[0].numpy())
|
||||
plt.show()
|
||||
|
||||
|
||||
data_loader_train = torch.utils.data.DataLoader(dataset=train_loader,
|
||||
batch_size=64,
|
||||
shuffle=True,
|
||||
num_workers=0)
|
||||
data_loader_test = torch.utils.data.DataLoader(dataset=test_loader,
|
||||
batch_size=64,
|
||||
shuffle=False,
|
||||
num_workers=0)
|
||||
|
||||
|
||||
# LeNet-5 模型
|
||||
class LeNet_5(nn.Module):
|
||||
def __init__(self):
|
||||
super(LeNet_5, self).__init__()
|
||||
self.conv1 = nn.Conv2d(1, 6, 5, 1)
|
||||
self.conv2 = nn.Conv2d(6, 16, 5, 1)
|
||||
self.fc1 = nn.Linear(16 * 4 * 4, 120)
|
||||
self.fc2 = nn.Linear(120, 84)
|
||||
self.fc3 = nn.Linear(84, 10)
|
||||
|
||||
def forward(self, x):
|
||||
x = F.max_pool2d(self.conv1(x), 2, 2)
|
||||
x = F.max_pool2d(self.conv2(x), 2, 2)
|
||||
x = x.view(-1, 16 * 4 * 4)
|
||||
x = self.fc1(x)
|
||||
x = self.fc2(x)
|
||||
x = self.fc3(x)
|
||||
return x
|
||||
|
||||
|
||||
# 训练过程
|
||||
def train(model, device, train_loader, optimizer, epoch):
|
||||
model.train()
|
||||
for idx, (data, target) in enumerate(train_loader):
|
||||
data, target = data.to(device), target.to(device)
|
||||
pred = model(data)
|
||||
loss = F.cross_entropy(pred, target)
|
||||
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
if idx % 100 == 0:
|
||||
print("Train Epoch: {}, iterantion: {}, Loss: {}".format(epoch, idx, loss.item()))
|
||||
torch.save(model.state_dict(), 'badnets.pth')
|
||||
|
||||
|
||||
# 测试过程
|
||||
def test(model, device, test_loader):
|
||||
model.load_state_dict(torch.load('badnets.pth'))
|
||||
model.eval()
|
||||
total_loss = 0
|
||||
correct = 0
|
||||
with torch.no_grad():
|
||||
for idx, (data, target) in enumerate(test_loader):
|
||||
data, target = data.to(device), target.to(device)
|
||||
output = model(data)
|
||||
total_loss += F.cross_entropy(output, target, reduction="sum").item()
|
||||
pred = output.argmax(dim=1)
|
||||
correct += pred.eq(target.view_as(pred)).sum().item()
|
||||
total_loss /= len(test_loader.dataset)
|
||||
acc = correct / len(test_loader.dataset) * 100
|
||||
print("Test Loss: {}, Accuracy: {}".format(total_loss, acc))
|
||||
|
||||
|
||||
def main():
|
||||
# 超参数
|
||||
num_epochs = 10
|
||||
lr = 0.01
|
||||
momentum = 0.5
|
||||
model = LeNet_5().to(device)
|
||||
optimizer = torch.optim.SGD(model.parameters(),
|
||||
lr=lr,
|
||||
momentum=momentum)
|
||||
# 在干净训练集上训练,在干净测试集上测试
|
||||
# acc=98.29%
|
||||
# 在带后门数据训练集上训练,在干净测试集上测试
|
||||
# acc=98.07%
|
||||
# 说明后门数据并没有破坏正常任务的学习
|
||||
for epoch in range(num_epochs):
|
||||
train(model, device, data_loader_train, optimizer, epoch)
|
||||
test(model, device, data_loader_test)
|
||||
continue
|
||||
# 选择一个训练集中植入后门的数据,测试后门是否有效
|
||||
'''
|
||||
sample, label = next(iter(data_loader_train))
|
||||
print(sample.size()) # [64, 1, 28, 28]
|
||||
print(label[0])
|
||||
# 可视化
|
||||
plt.imshow(sample[0][0])
|
||||
plt.show()
|
||||
model.load_state_dict(torch.load('badnets.pth'))
|
||||
model.eval()
|
||||
sample = sample.to(device)
|
||||
output = model(sample)
|
||||
print(output[0])
|
||||
pred = output.argmax(dim=1)
|
||||
print(pred[0])
|
||||
'''
|
||||
# 攻击成功率 99.66%
|
||||
for i in range(len(test_loader)):
|
||||
test_loader.data[i][26][26] = 255
|
||||
test_loader.data[i][25][25] = 255
|
||||
test_loader.data[i][24][26] = 255
|
||||
test_loader.data[i][26][24] = 255
|
||||
test_loader.targets[i] = 9
|
||||
data_loader_test2 = torch.utils.data.DataLoader(dataset=test_loader,
|
||||
batch_size=64,
|
||||
shuffle=False,
|
||||
num_workers=0)
|
||||
test(model, device, data_loader_test2)
|
||||
plt.imshow(test_loader.data[0].numpy())
|
||||
plt.show()
|
||||
|
||||
|
||||
if __name__=='__main__':
|
||||
main()
|
After Width: | Height: | Size: 6.2 KiB |
After Width: | Height: | Size: 174 B |