PulseFocusPlatform/static/docs/advanced_tutorials/TRANSFER_LEARNING_cn.md

4.6 KiB
Raw Permalink Blame History

English | 简体中文

迁移学习教程

迁移学习为利用已有知识对新知识进行学习。例如利用ImageNet分类预训练模型做初始化来训练检测模型利用在COCO数据集上的检测模型做初始化来训练基于PascalVOC数据集的检测模型。

选择数据

迁移学习需要使用自己的数据集目前已支持COCO和VOC的数据标注格式tools/x2coco.py中给出了voc、labelme和cityscape标注格式转换为COCO格式的脚本具体使用方式可以参考自定义数据源。数据准备完成后在配置文件中配置数据路径对应修改reader中的路径参数即可。

  1. COCO数据集需要修改COCODataSet中的参数yolov3_darknet.yml为例修改yolov3_reader中的配置
  dataset:
    !COCODataSet
      dataset_dir: custom_data/coco # 自定义数据目录
      image_dir: train2017 # 自定义训练集目录该目录在dataset_dir中
      anno_path: annotations/instances_train2017.json # 自定义数据标注路径该目录在dataset_dir中  
      with_background: false
  1. VOC数据集需要修改VOCDataSet中的参数yolov3_darknet_voc.yml为例:
  dataset:
    !VOCDataSet
    dataset_dir: custom_data/voc # 自定义数据集目录
    anno_path: trainval.txt # 自定义数据标注路径该目录在dataset_dir中
    use_default_label: true
    with_background: false

加载预训练模型

在进行迁移学习时由于会使用不同的数据集数据类别数与COCO/VOC数据类别不同导致在加载开源模型(如COCO预训练模型)时与类别数相关的权重例如分类模块的fc层会出现维度不匹配的问题另外如果需要结构更加复杂的模型需要对已有开源模型结构进行调整对应权重也需要选择性加载。因此需要在加载模型时不加载不能匹配的权重。

在迁移学习中,对预训练模型进行选择性加载,支持如下两种迁移学习方式:

直接加载预训练权重(推荐方式

模型中和预训练模型中对应参数形状不同的参数将自动被忽略,例如:

export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
python -u tools/train.py -c configs/faster_rcnn_r50_1x.yml \
                           -o pretrain_weights=https://paddlemodels.bj.bcebos.com/object_detection/faster_rcnn_r50_1x.tar

使用finetune_exclude_pretrained_params参数控制忽略参数名

可以显示的指定训练过程中忽略参数的名字,任何参数名均可加入finetune_exclude_pretrained_params中,为实现这一目的,可通过如下方式实现:

  1. 在 YMAL 配置文件中通过设置finetune_exclude_pretrained_params字段。可参考配置文件
  2. 在 train.py的启动参数中设置finetune_exclude_pretrained_params。例如:
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
python -u tools/train.py -c configs/faster_rcnn_r50_1x.yml \
                         -o pretrain_weights=https://paddlemodels.bj.bcebos.com/object_detection/faster_rcnn_r50_1x.tar \
                           finetune_exclude_pretrained_params=['cls_score','bbox_pred'] \
  • 说明:
  1. pretrain_weights的路径为COCO数据集上开源的faster RCNN模型链接完整模型链接可参考MODEL_ZOO
  2. finetune_exclude_pretrained_params中设置参数字段如果参数名能够匹配以上参数字段通配符匹配方式则在模型加载时忽略该参数。

如果用户需要利用自己的数据进行finetune模型结构不变只需要忽略与类别数相关的参数不同模型类型所对应的忽略参数字段如下表所示

模型类型 忽略参数字段
Faster RCNN cls_score, bbox_pred
Cascade RCNN cls_score, bbox_pred
Mask RCNN cls_score, bbox_pred, mask_fcn_logits
Cascade-Mask RCNN cls_score, bbox_pred, mask_fcn_logits
RetinaNet retnet_cls_pred_fpn
SSD ^conv2d_
YOLOv3 yolo_output