Filtering rules

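The Net<Dtype>::StateMeetsRule function

Purpose: StateMeetsRule() checks whether the net's state satisfies a given NetStateRule.

It compares the phase/level/stage passed in when the net is constructed against a layer's rules (include/exclude) in the prototxt. The result decides whether the layer is part of the net. A rule is met only if every check it specifies passes. There are five checks: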

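1. Phase: TRAIN or TEST. If the rule specifies a phase, it must equal the net's phase. For example, a layer restricted to TRAIN is not used in TEST.

2. Min_level: the net's level must not be less than min_level.

3. Max_level: the net's level must not be greater than max_level.

4. Stage: every stage listed in NetStateRule::stage must appear among the net's stages.

5. Not_stage: none of the net's stages may appear in NetStateRule::not_stage.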
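The five checks can be sketched without protobuf. This is a minimal, hypothetical stand-in, not Caffe's code. State, Rule and MeetsRule play the roles of NetState, NetStateRule and StateMeetsRule. std::optional (C++17) marks fields that may be unset, like protobuf's has_*() accessors, and unset fields are skipped.

```cpp
#include <algorithm>
#include <optional>
#include <string>
#include <vector>

enum class Phase { TRAIN, TEST };

// Hypothetical stand-in for NetState: what the net was constructed with.
struct State {
  Phase phase = Phase::TRAIN;
  int level = 0;
  std::vector<std::string> stages;
};

// Hypothetical stand-in for NetStateRule: one include or exclude block.
struct Rule {
  std::optional<Phase> phase;
  std::optional<int> min_level;
  std::optional<int> max_level;
  std::vector<std::string> stage;      // all of these must be present
  std::vector<std::string> not_stage;  // none of these may be present
};

bool HasStage(const State& s, const std::string& name) {
  return std::find(s.stages.begin(), s.stages.end(), name) != s.stages.end();
}

// Returns true only if the state passes every check the rule specifies.
bool MeetsRule(const State& s, const Rule& r) {
  if (r.phase && *r.phase != s.phase) return false;         // 1. phase
  if (r.min_level && s.level < *r.min_level) return false;  // 2. min_level
  if (r.max_level && s.level > *r.max_level) return false;  // 3. max_level
  for (const auto& st : r.stage)                            // 4. stage
    if (!HasStage(s, st)) return false;
  for (const auto& st : r.not_stage)                        // 5. not_stage
    if (HasStage(s, st)) return false;
  return true;
}
```

Each check returns early, so a rule is met only when all of its specified fields pass. Note that min_level and max_level are inclusive bounds.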

Explanation

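In Caffe, all parameter structures are defined in caffe.proto. protobuf's compiler (protoc.exe) generates caffe.pb.cc and caffe.pb.h from it, and the generated classes manage those data structures. In practice, a network is usually defined in a file named <project_name>.prototxt. In that file, a layer can be given include/exclude parameters that decide whether it is part of a particular network configuration or left out. As the names suggest, include means the layer is added to the net if the include condition is met when the net is constructed. exclude means the layer is left out of the net if the exclude condition is met.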

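Once read, the include and exclude parameters are represented by the NetStateRule message in caffe.proto. It has five fields, phase, min_level, max_level, stage and not_stage, which are the filtering rules discussed here. The rules themselves come from the prototxt. The values they are checked against (the net's phase, level and stages) are passed in when the net is constructed, for example with this constructor: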

Net<Dtype>::Net(const string& param_file, Phase phase, const int level, const vector<string>* stages, const Net* root_net)

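For a layer with include rules: the layer is kept in the net if its state meets at least one of them. A rule is met when the phase matches, min_level <= level <= max_level, every stage in NetStateRule::stage is among the given stages, and no given stage appears in NetStateRule::not_stage (unset fields are ignored). Otherwise the layer is dropped.

For a layer with exclude rules: the layer is removed from the net if its state meets at least one of them under the same conditions. Otherwise the layer is kept.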
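The include/exclude decision itself only combines per-rule results. The sketch below assumes each rule has already been evaluated to a boolean. LayerIncluded is a hypothetical helper that follows the same defaults as FilterNet, whose source appears later in this article.

```cpp
#include <stdexcept>
#include <vector>

// Hypothetical helper: decides whether a layer is kept, given whether the
// net state meets each of the layer's include rules and exclude rules.
bool LayerIncluded(const std::vector<bool>& include_met,
                   const std::vector<bool>& exclude_met) {
  // Caffe rejects layers that specify both kinds of rules.
  if (!include_met.empty() && !exclude_met.empty())
    throw std::invalid_argument("specify either include or exclude, not both");
  if (include_met.empty()) {
    // No include rules: kept by default, dropped if ANY exclude rule is met.
    for (bool met : exclude_met)
      if (met) return false;
    return true;
  }
  // Include rules present: dropped by default, kept if ANY include rule is met.
  for (bool met : include_met)
    if (met) return true;
  return false;
}
```

A layer with no rules at all is always kept, which is why most layers in a prototxt need no include or exclude block.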

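not_stage works the opposite way to stage: a rule that lists a stage in not_stage is not met when that stage is among the given stages. Consider the following example: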

layer {
  name: "mnist"
  type: "Data"
  top: "data"
  top: "label"
  include {
    phase: TEST
    not_stage: "predict"  # filter out this layer in the predict stage
  }
  transform_param {
    scale: 0.00390625
  }
  data_param {
    source: "examples/mnist/mnist_test_lmdb"
    batch_size: 100
    backend: LMDB
  }
}

# Add an input layer for deployment
layer {
  name: "data"
  type: "Input"
  top: "data"
  input_param { shape: { dim: 1 dim: 1 dim: 28 dim: 28 } }
  include {
    phase: TEST
    stage: "predict"  # add this layer only in the predict stage
  }
}
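To trace the mnist Data layer's rule through the checks: the layer has a single include rule (phase: TEST, not_stage: "predict"), so it is kept exactly when that rule is met. The hypothetical MnistDataIncluded below hard-codes just that rule for illustration. It is not Caffe code.

```cpp
#include <algorithm>
#include <string>
#include <vector>

enum class Phase { TRAIN, TEST };

// Hypothetical net state: the phase and stages passed to the Net constructor.
struct State {
  Phase phase;
  std::vector<std::string> stages;
};

// Encodes the "mnist" layer's only rule:
//   include { phase: TEST  not_stage: "predict" }
bool MnistDataIncluded(const State& s) {
  bool phase_ok = (s.phase == Phase::TEST);  // check 1: phase
  bool has_predict = std::find(s.stages.begin(), s.stages.end(),
                               std::string("predict")) != s.stages.end();
  return phase_ok && !has_predict;           // check 5: not_stage
}
```

For the states TRAIN, TEST, and TEST with stage "predict", it returns false, true and false respectively. The Data layer therefore appears only in the plain TEST net.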

For the practical uses of filtering layers this way, I recommend the article < Caffe 神经网络配置 - All in one network > ("Caffe network configuration: All in one network"):

https://yangwenbo.com/articles/caffe-net-config-all-in-one.html?utm_source=tuicool&utm_medium=referral

Annotated source

template <typename Dtype>
bool Net<Dtype>::StateMeetsRule(const NetState& state,
    const NetStateRule& rule, const string& layer_name) {
  // Check whether the rule is broken due to phase.
  if (rule.has_phase()) {
    if (rule.phase() != state.phase()) {
      LOG_IF(INFO, Caffe::root_solver())
          << "The NetState phase (" << state.phase()
          << ") differed from the phase (" << rule.phase()
          << ") specified by a rule in layer " << layer_name;
      return false;
    }
  }
  // Check whether the rule is broken due to min level.
  if (rule.has_min_level()) {
    if (state.level() < rule.min_level()) {
      // Note: the message says "above", but here the level is below min_level.
      LOG_IF(INFO, Caffe::root_solver())
          << "The NetState level (" << state.level()
          << ") is above the min_level (" << rule.min_level()
          << ") specified by a rule in layer " << layer_name;
      return false;
    }
  }
  // Check whether the rule is broken due to max level.
  if (rule.has_max_level()) {
    if (state.level() > rule.max_level()) {
      LOG_IF(INFO, Caffe::root_solver())
          << "The NetState level (" << state.level()
          << ") is above the max_level (" << rule.max_level()
          << ") specified by a rule in layer " << layer_name;
      return false;
    }
  }
  // Check whether the rule is broken due to stage. The NetState must
  // contain ALL of the rule's stages to meet it.
  for (int i = 0; i < rule.stage_size(); ++i) {
    // Every stage listed in the rule must be among the stages the net was
    // constructed with; a single missing stage breaks the rule.
    // Check that the NetState contains the rule's ith stage.
    bool has_stage = false;
    for (int j = 0; !has_stage && j < state.stage_size(); ++j) {
      if (rule.stage(i) == state.stage(j)) { has_stage = true; }
    }
    if (!has_stage) {
      LOG_IF(INFO, Caffe::root_solver())
          << "The NetState did not contain stage '" << rule.stage(i)
          << "' specified by a rule in layer " << layer_name;
      return false;
    }
  }
  // Check whether the rule is broken due to not_stage. The NetState must
  // contain NONE of the rule's not_stages to meet it.
  for (int i = 0; i < rule.not_stage_size(); ++i) {
    // If any stage the net was constructed with appears in not_stage, the
    // rule is broken; whether that keeps or drops the layer depends on
    // whether this is an include rule or an exclude rule.
    // Check that the NetState contains the rule's ith not_stage.
    bool has_stage = false;
    for (int j = 0; !has_stage && j < state.stage_size(); ++j) {
      if (rule.not_stage(i) == state.stage(j)) { has_stage = true; }
    }
    if (has_stage) {
      LOG_IF(INFO, Caffe::root_solver())
          << "The NetState contained a not_stage '" << rule.not_stage(i)
          << "' specified by a rule in layer " << layer_name;
      return false;
    }
  }
  return true;
}


Filtering network layers

Net<Dtype>::FilterNet

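Purpose: remove the layers in the model definition whose rules are not satisfied by the current state.

With StateMeetsRule() understood, FilterNet() is easy to follow: based on the given phase/level/stage, it removes the layers whose rules say they do not belong in this net.

These rules are usually written in the prototxt file. For example, a layer may be configured as: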

layer {
  name: "accuracy"
  type: "Accuracy"
  bottom: "ip2"
  bottom: "label"
  top: "accuracy"
  include {
    phase: TEST
  }
}

This layer is then included only in the TEST phase.

Similarly, the TEST phase runs only the forward pass, so layers marked phase: TRAIN have to be removed.

Source

template <typename Dtype>
void Net<Dtype>::FilterNet(const NetParameter& param,
    NetParameter* param_filtered) {
  NetState net_state(param.state());
  param_filtered->CopyFrom(param);
  // Clear all layers first, then add back the ones that pass the rules.
  param_filtered->clear_layer();
  for (int i = 0; i < param.layer_size(); ++i) {
    const LayerParameter& layer_param = param.layer(i);
    const string& layer_name = layer_param.name();
    // A layer may not have both include and exclude rules.
    CHECK(layer_param.include_size() == 0 || layer_param.exclude_size() == 0)
        << "Specify either include rules or exclude rules; not both.";
    // If no include rules are specified, the layer is included by default and
    // only excluded if it meets one of the exclude rules.
    bool layer_included = (layer_param.include_size() == 0);

    for (int j = 0; layer_included && j < layer_param.exclude_size(); ++j) {
      // net_state holds the net's construction-time inputs (phase/level/stage);
      // see Solver<Dtype>::InitTrainNet() and Net<Dtype>::StateMeetsRule.
      // layer_param.exclude(j) is one exclude rule set in the prototxt
      // (phase/min_level/max_level/stage/not_stage).
      // If the state meets this rule, the layer is excluded.
      if (StateMeetsRule(net_state, layer_param.exclude(j), layer_name)) {
        // No include rules here: meeting any one exclude rule is enough.
        layer_included = false;
      }
    }
    for (int j = 0; !layer_included && j < layer_param.include_size(); ++j) {
      // If the state meets this rule, the layer is included.
      if (StateMeetsRule(net_state, layer_param.include(j), layer_name)) {
        // With include rules, meeting any one of them is enough.
        layer_included = true;
      }
    }
    if (layer_included) {
      param_filtered->add_layer()->CopyFrom(layer_param);
    }
  }
}
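FilterNet's loop can be mirrored over plain structs to experiment with rules outside Caffe. This is a sketch under assumptions. Layer, Rule, State and FilterLayers are hypothetical stand-ins for LayerParameter, NetStateRule, NetState and FilterNet, and a thrown exception replaces Caffe's CHECK.

```cpp
#include <algorithm>
#include <optional>
#include <stdexcept>
#include <string>
#include <vector>

enum class Phase { TRAIN, TEST };

struct State {  // hypothetical stand-in for NetState
  Phase phase = Phase::TRAIN;
  int level = 0;
  std::vector<std::string> stages;
};

struct Rule {  // hypothetical stand-in for NetStateRule
  std::optional<Phase> phase;
  std::optional<int> min_level, max_level;
  std::vector<std::string> stage, not_stage;
};

struct Layer {  // hypothetical stand-in for LayerParameter
  std::string name;
  std::vector<Rule> include, exclude;
};

static bool Has(const State& s, const std::string& x) {
  return std::find(s.stages.begin(), s.stages.end(), x) != s.stages.end();
}

// Same five checks as StateMeetsRule.
static bool Meets(const State& s, const Rule& r) {
  if (r.phase && *r.phase != s.phase) return false;
  if (r.min_level && s.level < *r.min_level) return false;
  if (r.max_level && s.level > *r.max_level) return false;
  for (const auto& x : r.stage) if (!Has(s, x)) return false;
  for (const auto& x : r.not_stage) if (Has(s, x)) return false;
  return true;
}

// Mirrors FilterNet: returns only the layers that survive under state s.
std::vector<Layer> FilterLayers(const std::vector<Layer>& layers,
                                const State& s) {
  std::vector<Layer> kept;
  for (const Layer& l : layers) {
    if (!l.include.empty() && !l.exclude.empty())
      throw std::invalid_argument("layer " + l.name +
                                  ": include and exclude both set");
    bool included = l.include.empty();  // no include rules: kept by default
    for (const Rule& r : l.exclude)     // any exclude rule met drops it
      if (included && Meets(s, r)) included = false;
    for (const Rule& r : l.include)     // any include rule met keeps it
      if (!included && Meets(s, r)) included = true;
    if (included) kept.push_back(l);
  }
  return kept;
}
```

With an accuracy layer restricted to TEST (as in the prototxt above) and a layer restricted to TRAIN, filtering for TEST keeps the accuracy layer and drops the TRAIN-only one, which matches the behaviour described for FilterNet.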