3120241305/main.py

96 lines
5.4 KiB
Python
Executable File

import argparse
import os
import pathlib
import re
import time
import datetime
import pandas as pd
import torch
from torch.utils.data import DataLoader
from dataset import build_poisoned_training_set, build_testset
from deeplearning import evaluate_badnets, optimizer_picker, train_one_epoch
from models import BadNet
parser = argparse.ArgumentParser(description='Reproduce the basic backdoor attack in "Badnets: Identifying vulnerabilities in the machine learning model supply chain".')
parser.add_argument('--dataset', default='MNIST', help='Which dataset to use (MNIST or CIFAR10, default: MNIST)')
parser.add_argument('--nb_classes', default=10, type=int, help='number of the classification types')
parser.add_argument('--load_local', action='store_true', help='train model or directly load model (default true, if you add this param, then load trained local model to evaluate the performance)')
parser.add_argument('--loss', default='mse', help='Which loss function to use (mse or cross, default: mse)')
parser.add_argument('--optimizer', default='sgd', help='Which optimizer to use (sgd or adam, default: sgd)')
parser.add_argument('--epochs', default=100, help='Number of epochs to train backdoor model, default: 100')
parser.add_argument('--batch_size', type=int, default=64, help='Batch size to split dataset, default: 64')
parser.add_argument('--num_workers', type=int, default=0, help='Batch size to split dataset, default: 64')
parser.add_argument('--lr', type=float, default=0.01, help='Learning rate of the model, default: 0.001')
parser.add_argument('--download', action='store_true', help='Do you want to download data ( default false, if you add this param, then download)')
parser.add_argument('--data_path', default='./data/', help='Place to load dataset (default: ./dataset/)')
parser.add_argument('--device', default='cuda:0', help='device to use for training / testing (cpu, or cuda:1, default: cpu)')
# poison settings
parser.add_argument('--poisoning_rate', type=float, default=0.5, help='poisoning portion (float, range from 0 to 1, default: 0.1)')
parser.add_argument('--trigger_label', type=int, default=0, help='The NO. of trigger label (int, range from 0 to 10, default: 0)')
parser.add_argument('--trigger_path', default="./triggers/trigger_10.png", help='Trigger Path (default: ./triggers/trigger_white.png)')
parser.add_argument('--trigger_size', type=int, default=5, help='Trigger Size (int, default: 5)')
args = parser.parse_args()
def main():
print("{}".format(args).replace(', ', ',\n'))
if re.match('cuda:\d', args.device):
cuda_num = args.device.split(':')[1]
os.environ['CUDA_VISIBLE_DEVICES'] = cuda_num
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # if you're using MBP M1, you can also use "mps"
# create related path
pathlib.Path("./checkpoints/").mkdir(parents=True, exist_ok=True)
pathlib.Path("./logs/").mkdir(parents=True, exist_ok=True)
print("\n# load dataset: %s " % args.dataset)
dataset_train, args.nb_classes = build_poisoned_training_set(is_train=True, args=args)
dataset_val_clean, dataset_val_poisoned = build_testset(is_train=False, args=args)
data_loader_train = DataLoader(dataset_train, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)
data_loader_val_clean = DataLoader(dataset_val_clean, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)
data_loader_val_poisoned = DataLoader(dataset_val_poisoned, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers) # shuffle 随机化
model = BadNet(input_channels=dataset_train.channels, output_num=args.nb_classes).to(device)
criterion = torch.nn.CrossEntropyLoss()
optimizer = optimizer_picker(args.optimizer, model.parameters(), lr=args.lr)
basic_model_path = "./checkpoints/badnet-%s.pth" % args.dataset
start_time = time.time()
if args.load_local:
print("## Load model from : %s" % basic_model_path)
model.load_state_dict(torch.load(basic_model_path), strict=True)
test_stats = evaluate_badnets(data_loader_val_clean, data_loader_val_poisoned, model, device)
print(f"Test Clean Accuracy(TCA): {test_stats['clean_acc']:.4f}")
print(f"Attack Success Rate(ASR): {test_stats['asr']:.4f}")
else:
print(f"Start training for {args.epochs} epochs")
stats = []
for epoch in range(args.epochs):
train_stats = train_one_epoch(data_loader_train, model, criterion, optimizer, args.loss, device)
test_stats = evaluate_badnets(data_loader_val_clean, data_loader_val_poisoned, model, device)
print(f"# EPOCH {epoch} loss: {train_stats['loss']:.4f} Test Acc: {test_stats['clean_acc']:.4f}, ASR: {test_stats['asr']:.4f}\n")
# save model
torch.save(model.state_dict(), basic_model_path)
log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
**{f'test_{k}': v for k, v in test_stats.items()},
'epoch': epoch,
}
# save training stats
stats.append(log_stats)
df = pd.DataFrame(stats)
df.to_csv("./logs/%s_trigger%d.csv" % (args.dataset, args.trigger_label), index=False, encoding='utf-8')
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print('Training time {}'.format(total_time_str))
if __name__ == "__main__":
main()