forked from jiuyuan/InfiniTensor
feat: 区分 1x1 卷积和空洞卷积
Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
parent
7d7d923e8d
commit
4a5e66b36a
|
@ -18,11 +18,23 @@ Vec<Unigraph> optimization::pass::mutate( // fmt: new line
|
|||
SingleOperator const & //
|
||||
) {
|
||||
Vec<Unigraph> ans;
|
||||
|
||||
switch (g.operators.front().op_type) {
|
||||
case OpType::Conv:
|
||||
/* code */
|
||||
break;
|
||||
auto const &op = g.operators.front();
|
||||
switch (op.op_type) {
|
||||
case OpType::Conv: {
|
||||
auto const &w_shape = op.inputs.at(1)->shape;
|
||||
auto const &dilations = op.inputs.at(2)->data.cpu_data;
|
||||
auto const &pads = op.inputs.at(3)->data.cpu_data;
|
||||
auto const &strides = op.inputs.at(4)->data.cpu_data;
|
||||
if (w_shape.rbegin()[0] == 1 // fmt: new line
|
||||
&& w_shape.rbegin()[1] == 1 //
|
||||
&& std::all_of(strides.begin(), strides.end(),
|
||||
[](auto x) { return x == 1; })) {
|
||||
// 1x1 conv
|
||||
} else if (std::any_of(dilations.begin(), dilations.end(),
|
||||
[](auto x) { return x != 1; })) {
|
||||
// dilated conv
|
||||
}
|
||||
} break;
|
||||
|
||||
default:
|
||||
break;
|
||||
|
|
Loading…
Reference in New Issue