3120241305/main.py

96 lines
5.4 KiB
Python
Raw Normal View History

2024-11-05 17:53:04 +08:00
import argparse
import os
import pathlib
import re
import time
import datetime
import pandas as pd
import torch
from torch.utils.data import DataLoader
from dataset import build_poisoned_training_set, build_testset
from deeplearning import evaluate_badnets, optimizer_picker, train_one_epoch
from models import BadNet
# Command-line interface of the BadNets reproduction script.
# The parsed namespace is stored in the module-level ``args`` and read by main().
# Help texts that state a default use argparse's ``%(default)s`` placeholder where
# the previously hard-coded value had drifted from the real default.
parser = argparse.ArgumentParser(description='Reproduce the basic backdoor attack in "Badnets: Identifying vulnerabilities in the machine learning model supply chain".')
parser.add_argument('--dataset', default='MNIST', help='Which dataset to use (MNIST or CIFAR10, default: MNIST)')
parser.add_argument('--nb_classes', default=10, type=int, help='number of the classification types')
# store_true flag: absent -> train a model, present -> load the saved checkpoint and only evaluate it.
parser.add_argument('--load_local', action='store_true', help='load the trained local model and only evaluate its performance instead of training (default: train a new model)')
parser.add_argument('--loss', default='mse', help='Which loss function to use (mse or cross, default: mse)')
parser.add_argument('--optimizer', default='sgd', help='Which optimizer to use (sgd or adam, default: sgd)')
# type=int is required: without it a value given on the command line stays a str
# and cannot be used as the bound of range() in the training loop.
parser.add_argument('--epochs', type=int, default=100, help='Number of epochs to train backdoor model, default: 100')
parser.add_argument('--batch_size', type=int, default=64, help='Batch size to split dataset, default: 64')
parser.add_argument('--num_workers', type=int, default=0, help='Number of DataLoader worker processes, default: %(default)s')
parser.add_argument('--lr', type=float, default=0.01, help='Learning rate of the model, default: %(default)s')
parser.add_argument('--download', action='store_true', help='Do you want to download data ( default false, if you add this param, then download)')
parser.add_argument('--data_path', default='./data/', help='Place to load dataset (default: %(default)s)')
parser.add_argument('--device', default='cuda:0', help='device to use for training / testing (cpu, or cuda:1, default: %(default)s)')
# poison settings
parser.add_argument('--poisoning_rate', type=float, default=0.5, help='poisoning portion (float, range from 0 to 1, default: %(default)s)')
parser.add_argument('--trigger_label', type=int, default=0, help='The NO. of trigger label (int, range from 0 to 10, default: 0)')
parser.add_argument('--trigger_path', default="./triggers/trigger_10.png", help='Trigger Path (default: %(default)s)')
parser.add_argument('--trigger_size', type=int, default=5, help='Trigger Size (int, default: 5)')
args = parser.parse_args()
def main():
    """Train a BadNet model on a poisoned training set, or evaluate a saved one.

    Reads its configuration from the module-level ``args`` namespace.

    * With ``--load_local`` the checkpoint ``./checkpoints/badnet-<dataset>.pth``
      is loaded and the clean accuracy and attack success rate are printed.
    * Otherwise the model is trained for ``args.epochs`` epochs; after every
      epoch the checkpoint is overwritten and the accumulated per-epoch
      statistics are rewritten to ``./logs/<dataset>_trigger<label>.csv``.

    The elapsed wall-clock time is printed at the end in both modes.
    """
    print("{}".format(args).replace(', ', ',\n'))

    # Raw string: ``\d`` must reach the regex engine as a digit class instead of
    # being parsed as an (invalid) string escape.
    if re.match(r'cuda:\d', args.device):
        # Restrict the visible GPUs to the requested index; the generic "cuda"
        # device chosen below is then expected to refer to that GPU.
        cuda_num = args.device.split(':')[1]
        os.environ['CUDA_VISIBLE_DEVICES'] = cuda_num
    # Honour an explicit ``--device cpu`` even when CUDA is available; any other
    # value keeps the previous behaviour (CUDA if available, else CPU).
    use_cuda = args.device != 'cpu' and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")  # if you're using MBP M1, you can also use "mps"

    # create related path
    pathlib.Path("./checkpoints/").mkdir(parents=True, exist_ok=True)
    pathlib.Path("./logs/").mkdir(parents=True, exist_ok=True)

    print("\n# load dataset: %s " % args.dataset)
    # The training-set builder also reports the number of classes, which
    # overrides the command-line value.
    dataset_train, args.nb_classes = build_poisoned_training_set(is_train=True, args=args)
    dataset_val_clean, dataset_val_poisoned = build_testset(is_train=False, args=args)

    data_loader_train = DataLoader(dataset_train, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)
    data_loader_val_clean = DataLoader(dataset_val_clean, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)
    data_loader_val_poisoned = DataLoader(dataset_val_poisoned, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)  # shuffle: randomize sample order

    model = BadNet(input_channels=dataset_train.channels, output_num=args.nb_classes).to(device)
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = optimizer_picker(args.optimizer, model.parameters(), lr=args.lr)

    basic_model_path = "./checkpoints/badnet-%s.pth" % args.dataset
    start_time = time.time()
    if args.load_local:
        print("## Load model from : %s" % basic_model_path)
        # map_location lets a checkpoint saved on a GPU be loaded on whatever
        # device was selected above (e.g. a CPU-only machine).
        model.load_state_dict(torch.load(basic_model_path, map_location=device), strict=True)
        test_stats = evaluate_badnets(data_loader_val_clean, data_loader_val_poisoned, model, device)
        print(f"Test Clean Accuracy(TCA): {test_stats['clean_acc']:.4f}")
        print(f"Attack Success Rate(ASR): {test_stats['asr']:.4f}")
    else:
        # Coerce defensively: the value is a str if the parser declares no int type.
        num_epochs = int(args.epochs)
        log_path = "./logs/%s_trigger%d.csv" % (args.dataset, args.trigger_label)
        print(f"Start training for {num_epochs} epochs")
        stats = []
        for epoch in range(num_epochs):
            train_stats = train_one_epoch(data_loader_train, model, criterion, optimizer, args.loss, device)
            test_stats = evaluate_badnets(data_loader_val_clean, data_loader_val_poisoned, model, device)
            print(f"# EPOCH {epoch} loss: {train_stats['loss']:.4f} Test Acc: {test_stats['clean_acc']:.4f}, ASR: {test_stats['asr']:.4f}\n")

            # save model (overwritten every epoch, so the file always holds the latest weights)
            torch.save(model.state_dict(), basic_model_path)

            log_stats = {
                **{f'train_{k}': v for k, v in train_stats.items()},
                **{f'test_{k}': v for k, v in test_stats.items()},
                'epoch': epoch,
            }

            # save training stats (the whole history is rewritten after each epoch)
            stats.append(log_stats)
            df = pd.DataFrame(stats)
            df.to_csv(log_path, index=False, encoding='utf-8')

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
# Script entry point: run main() only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()