本篇文章主要是介紹訓練模型的一些參數配置信息,可以看出在訓練腳本train.py中主要是調用train_net.py腳本中的train_net函數進行訓練的,因此這一篇博客介紹train_net.py腳本的內容。
train_net.py這個腳本一共包含convert_pretrained,get_lr_scheduler,train_net三個函數,其中重要的是train_net函數,這個函數也是train.py腳本訓練模型時候調用的函數,建議從train_net函數開始看起。import tools.find_mxnet
import mxnet as mx
import logging
import sys
import os
import importlib
import re
# 導入生成模型可用的數據格式的類,是在dataset文件夾下的iterator.py腳本中實現的,
# 一般采用這種導入腳本中類的方式需要在dataset文件夾下寫一個空的__init__.py腳本才能導入
from dataset.iterator import DetRecordIter
from train.metric import MultiBoxMetric # 導入訓練時候的評價標準類
# 導入測試時候的評價標準類,這里VOC07MApMetric類繼承了MApMetric類,主要內容在MApMetric類中
from evaluate.eval_metric import MApMetric, VOC07MApMetric
from config.config import cfg
from symbol.symbol_factory import get_symbol_train # get_symbol_train函數來導入symbol
def convert_pretrained(name, args):
"""
Special operations need to be made due to name inconsistance, etc
Parameters:
---------
name : str
pretrained model name
args : dict
loaded arguments
Returns:
---------
processed arguments as dict
"""
return args
# get_lr_scheduler函數就是設計你的學習率變化策略,函數的幾個輸入的意思在這里都介紹得很清楚了,
# lr_refactor_step可以是3或6這樣的單獨數字,也可以是3,6,9這樣用逗號間隔的數字,表示到第3,6,9個epoch的時候就要改變學習率
def get_lr_scheduler(learning_rate, lr_refactor_step, lr_refactor_ratio,
num_example, batch_size, begin_epoch):
"""
Compute learning rate and refactor scheduler
Parameters:
---------
learning_rate : float
original learning rate
lr_refactor_step : comma separated str
epochs to change learning rate
lr_refactor_ratio : float
lr *= ratio at certain steps
num_example : int
number of training images, used to estimate the iterations given epochs
batch_size : int
training batch size
begin_epoch : int
starting epoch
Returns:
---------
(learning_rate, mx.lr_scheduler) as tuple
"""
assert lr_refactor_ratio > 0
iter_refactor = [int(r) for r in lr_refactor_step.split(',') if r.strip()]
# 學習率的改變一般都是越來越小,不接受學習率越來越大這種策略,在這種情況下采用學習率不變的策略
if lr_refactor_ratio >= 1:
return (learning_rate, None)
else:
lr = learning_rate
epoch_size = num_example // batch_size # 表示每個epoch少包含多少個batch
# 這個for循環的內容主要是解決當你設置的begin_epoch要大于你的iter_refactor的某些值的時候,
# 會按照lr_refactor_ratio改變你的初始學習率,也就是說這個改變是還沒開始訓練的時候就做的。
for s in iter_refactor:
if begin_epoch >= s:
lr *= lr_refactor_ratio
# 如果有上面這個學習率的改變,那么打印出改變信息,這樣以后看log也能很清楚地知道當時實際初始學習率是多少。
if lr != learning_rate:
logging.getLogger().info("Adjusted learning rate to {} for epoch {}".format(lr, begin_epoch))
# 這個steps就是你要運行多少個batch才需要改變學習率,因此這個steps是以batch為單位的
steps = [epoch_size * (x - begin_epoch) for x in iter_refactor if x > begin_epoch]
# 這個if條件滿足的話就表示我的begin_epoch比你設置的iter_refactor里面的所有值都大,那么我就返回學習率lr,
# 至于更改的策略就只能是None了,也就是說用這個lr一直跑到結束,中間就不改變了
if not steps:
return (lr, None)
# 終用mx.lr_scheduler.MultiFactorScheduler函數生成模型可用的lr_scheduler
lr_scheduler = mx.lr_scheduler.MultiFactorScheduler(step=steps, factor=lr_refactor_ratio)
return (lr, lr_scheduler)
# 這是train_net.py腳本中的主要函數
def train_net(net, train_path, num_classes, batch_size,
data_shape, mean_pixels, resume, finetune, pretrained, epoch,
prefix, ctx, begin_epoch, end_epoch, frequent, learning_rate,
momentum, weight_decay, lr_refactor_step, lr_refactor_ratio,
freeze_layer_pattern='',
num_example=10000, label_pad_width=350,
nms_thresh=0.45, force_nms=False, ovp_thresh=0.5,
use_difficult=False, class_names=None,
voc07_metric=False, nms_topk=400, force_suppress=False,
train_list="", val_path="", val_list="", iter_monitor=0,
monitor_pattern=".*", log_file=None):
"""
Wrapper for training phase.
Parameters:
----------
net : str
symbol name for the network structure
train_path : str
record file path for training
num_classes : int
number of object classes, not including background
batch_size : int
training batch-size
data_shape : int or tuple
width/height as integer or (3, height, width) tuple
mean_pixels : tuple of floats
mean pixel values for red, green and blue
resume : int
resume from previous checkpoint if > 0
finetune : int
fine-tune from previous checkpoint if > 0
pretrained : str
prefix of pretrained model, including path
epoch : int
load epoch of either resume/finetune/pretrained model
prefix : str
prefix for saving checkpoints
ctx : [mx.cpu()] or [mx.gpu(x)]
list of mxnet contexts
begin_epoch : int
starting epoch for training, should be 0 if not otherwise specified
end_epoch : int
end epoch of training
frequent : int
frequency to print out training status
learning_rate : float
training learning rate
momentum : float
trainig momentum
weight_decay : float
training weight decay param
lr_refactor_ratio : float
multiplier for reducing learning rate
lr_refactor_step : comma separated integers
at which epoch to rescale learning rate, e.g. '30, 60, 90'
freeze_layer_pattern : str
regex pattern for layers need to be fixed
num_example : int
number of training images
label_pad_width : int
force padding training and validation labels to sync their label widths
nms_thresh : float
non-maximum suppression threshold for validation
force_nms : boolean
suppress overlaped objects from different classes
train_list : str
list file path for training, this will replace the embeded labels in record
val_path : str
record file path for validation
val_list : str
list file path for validation, this will replace the embeded labels in record
iter_monitor : int
monitor internal stats in networks if > 0, specified by monitor_pattern
monitor_pattern : str
regex pattern for monitoring network stats
log_file : str
log to file if enabled
"""
# set up logger
# 這部分內容和生成日志文件相關,依賴logging這個庫,if條件中的log_file就是生成的log文件的路徑和名稱。
# 這個logger是RootLogger類型,可以用來輸出提示信息,
# 用法例子:logger.info("Start finetuning with {} from epoch {}".format(ctx, epoch))
logging.basicConfig()
logger = logging.getLogger()
logger.setLevel(logging.INFO)
if log_file:
fh = logging.FileHandler(log_file)
logger.addHandler(fh)
# check args
# 這一部分主要是檢查一些配置參數是不是異常,比如你的data_shape必須是個int型等
if isinstance(data_shape, int):
data_shape = (3, data_shape, data_shape)
assert len(data_shape) == 3 and data_shape[0] == 3
if prefix.endswith('_'):
prefix += '_' + str(data_shape[1])
if isinstance(mean_pixels, (int, float)):
mean_pixels = [mean_pixels, mean_pixels, mean_pixels]
assert len(mean_pixels) == 3, "must provide all RGB mean values"
# 這里的train_iter是通過調用dataset文件夾下的iterator.py腳本中的DetRecordIter類來得到的,
# 簡單講就是從.rec和.lst文件到模型可以用的數據迭代器的過程。輸入中train_path是你的.rec文件的路徑,
# label_pad_width這個參數在文中的解釋是force padding training and validation labels to sync their labels widths,
# train_list是空字符串。
train_iter = DetRecordIter(train_path, batch_size, data_shape, mean_pixels=mean_pixels,
label_pad_width=label_pad_width, path_imglist=train_list, **cfg.train)
# 如果你給了驗證集數據的路徑,那么也生成驗證集數據迭代器,做法和前面訓練集的做法一樣
if val_path:
val_iter = DetRecordIter(val_path, batch_size, data_shape, mean_pixels=mean_pixels,
label_pad_width=label_pad_width, path_imglist=val_list, **cfg.valid)
else:
val_iter = None
# load symbol
# 這里調用了symbol文件夾下的symbol_factory.py腳本的get_symbol_train函數來導入symbol。這個函數的輸入中,net是一個str,
# 代碼中默認是‘vgg16_reduced’,data_shape是一個tuple,是在前面計算得到的,比如data_shape是(3,300,300),num_classes就是類別數,
# 在VOC數據集中,num_classes就是20,nms_thresh是nms操作的參數,默認是0.45,
# force_suppress和nms_topk兩個參數都是采用默認的False和400。
# 這個函數的輸出net就是終的檢測網絡,是一個symbol。
net = get_symbol_train(net, data_shape[1], num_classes=num_classes,
nms_thresh=nms_thresh, force_suppress=force_suppress, nms_topk=nms_topk)
# define layers with fixed weight/bias
# 這一步是設計一些層的參數在模型訓練過程中不變,freeze_layer_pattern是在train.py里面設置的一個參數,表示要將哪些層的參數固定。
# 得到的fixed_param_names就是一個list,其中的每個元素就是層參數的名稱,比如conv1_1_weight,是一個str。
if freeze_layer_pattern.strip():
re_prog = re.compile(freeze_layer_pattern)
fixed_param_names = [name for name in net.list_arguments() if re_prog.match(name)]
else:
fixed_param_names = None
# load pretrained or resume from previous state
# resume是指你在訓練檢測模型的時候如果訓練到一半但是中斷了,想要從中斷的epoch繼續訓練,
# 那么可以導入訓練中斷前的那個epoch的.param文件,
# 這個文件就是檢測模型的參數,從而用這個參數初始化檢測模型,達到斷點繼續訓練的目的。
ctx_str = '('+ ','.join([str(c) for c in ctx]) + ')'
if resume > 0:
logger.info("Resume training with {} from epoch {}"
.format(ctx_str, resume))
_, args, auxs = mx.model.load_checkpoint(prefix, resume)
begin_epoch = resume
elif finetune > 0:
logger.info("Start finetuning with {} from epoch {}"
.format(ctx_str, finetune))
_, args, auxs = mx.model.load_checkpoint(prefix, finetune)
begin_epoch = finetune
# check what layers mismatch with the loaded parameters
exe = net.simple_bind(mx.cpu(), data=(1, 3, 300, 300), label=(1, 1, 5), grad_req='null')
arg_dict = exe.arg_dict
fixed_param_names = []
for k, v in arg_dict.items():
if k in args:
if v.shape != args[k].shape:
del args[k]
logging.info("Removed %s" % k)
else:
if not 'pred' in k:
fixed_param_names.append(k)
# 這個if條件是導入預訓練好的分類模型來初始化檢測模型的參數,其中mxnet.model.checkpoint就是執行這個導入參數的作用,
# 生成的_是分類模型的網絡,args是分類模型的參數,類型是dictionary,每個item表示一個層參數,item的內容就是一個參數的NDArray格式。
# auxs在這里是一個空字典。調用的這個convert_pretrained函數就是該腳本定義的個函數,直接return args,沒做什么操作。
elif pretrained:
logger.info("Start training with {} from pretrained model {}"
.format(ctx_str, pretrained))
_, args, auxs = mx.model.load_checkpoint(pretrained, epoch)
args = convert_pretrained(pretrained, args)
else:
logger.info("Experimental: start training from scratch with {}"
.format(ctx_str))
args = None
auxs = None
fixed_param_names = None
# helper information
# 這一部分將前面得到的要固定參數的層信息打印出來
if fixed_param_names:
logger.info("Freezed parameters: [" + ','.join(fixed_param_names) + ']')
# init training module
# 調用mx.mod.Module類初始化一個模型。參數中net就是前面通過get_symbol_train函數導入的檢測模型的symbol。
# logger是和日志相關的參數。ctx就是你訓練模型時候的cpu或gpu選擇。初始化model的時候就要指定要固定的參數。
mod = mx.mod.Module(net, label_names=('label',), logger=logger, context=ctx,
fixed_param_names=fixed_param_names)
# fit parameters
# 這個frequent就是你每隔frequent個batch顯示一次訓練結果(比如損失,準確率等等),代碼中frequent采用20。
batch_end_callback = mx.callback.Speedometer(train_iter.batch_size, frequent=frequent)
# prefix是一個指定的路徑,生成的epoch_end_callback作為fit()函數的參數之一,用來指定生成的模型的存放地址。
epoch_end_callback = mx.callback.do_checkpoint(prefix)
# 調用get_lr_scheduler()函數生成初始的學習率和學習率變化策略,這個get_lr_scheduler()函數在前面有詳細介紹
learning_rate, lr_scheduler = get_lr_scheduler(learning_rate, lr_refactor_step,
lr_refactor_ratio, num_example, batch_size, begin_epoch)
# 定義優化器的一些參數,比如學習率;momentum(該參數是在sgd算法中計算下一步更新方向時候會用到,默認是0.9);
# wd是正則項的系數,一般采用0.0001到0.0005,代碼中默認是0.0005;lr_scheduler是學習率的更新策略,
# 比如你間隔20個epoch就把學習率降為原來的0.1倍等;
# rescale_grad參數如果你是一塊GPU跑,就是默認的1,如果是多GPU,那么相當于在做梯度更新的時候需要合并多個GPU的結果,
# 這里ctx就是代表你是用cpu還是gpu,以及gpu的話是采用哪幾塊gpu。
optimizer_params={'learning_rate':learning_rate,
'momentum':momentum,
'wd':weight_decay,
'lr_scheduler':lr_scheduler,
'clip_gradient':None,
'rescale_grad': 1.0 / len(ctx) if len(ctx) > 0 else 1.0 }
# 這個monitor一般是調試時候采用,默認訓練模型的時候這個monitor是None,也就是iter_monitor默認是0
monitor = mx.mon.Monitor(iter_monitor, pattern=monitor_pattern) if iter_monitor > 0 else None
# run fit net, every n epochs we run evaluation network to get mAP
# 這一步是對評價指標的選擇,腳本中中默認采用voc07_metric,ovp_thresh默認是0.5,
# 表示計算MAp時類別相同的預測框和真實框的IOU值的閾值。
if voc07_metric:
valid_metric = VOC07MApMetric(ovp_thresh, use_difficult, class_names, pred_idx=3)
else:
valid_metric = MApMetric(ovp_thresh, use_difficult, class_names, pred_idx=3)
# 模型訓練的入口,這個mod只有檢測網絡的結構信息,而fit的arg_params參數則是指定了用來初始化這個檢測模型的參數,
# 這些參數來自預訓練好的分類模型。
# 如果你在調試模型的時候運行到fit這個函數,進入這個函數的話就進入到mxnet項目的base_module.py腳本,
# 里面包含了參數初始化和模型前后向的具體操作。
mod.fit(train_iter, # 訓練數據
val_iter, # 測試數據
eval_metric=MultiBoxMetric(), # 訓練時的評價指標
validation_metric=valid_metric, # 測試時的評價指標
# 每多少個batch顯示結果,這個batch_end_callback參數是由mx.callback.Speedometer()函數生成的,
# 這個函數的輸入包括batch_size和間隔
batch_end_callback=batch_end_callback,
# 每個epoch結束后,得到的.param文件存放地址,這個epoch_end_callback由mx.callback,do_checkpoint()函數生成,
# 這個函數的輸入就是存放地址。
epoch_end_callback=epoch_end_callback,
optimizer='sgd', # 優化算法采用sgd,也就是隨機梯度下降
optimizer_params=optimizer_params, # 優化器的一些參數
begin_epoch=begin_epoch, # epoch的初始值
num_epoch=end_epoch, # 一共要訓練多少個epoch
initializer=mx.init.Xavier(), # 其他參數的初始化方式
arg_params=args, # 導入的模型的參數,就是你預訓練的模型的參數
aux_params=auxs, # 導入的模型的參數的均值方差
allow_missing=True, # 是否允許一些參數缺失
monitor=monitor) # 如果monitor為None的話,就沒什么用了,因為fit()函數默認monitor參數為None
本篇文章介紹了SSD算法的整體架構,旨在從宏觀上加深對該算法的認識。從上面的代碼介紹可以看出,在train_net函數中關于網絡結構的構建是通過symbol_factory.py腳本的get_symbol_train函數進行的,因為網絡結構的構建是SSD算法的核心
粵嵌科技13年專注IT人才培訓學習的專業機構,主要培訓課程為,嵌入式培訓、Java培訓、Unity游戲開發、Python人工智能、HTML5前端開發、全棧UI設計、網絡營銷、CCIE網絡等專業課程