feat: 区分 1x1 卷积和空洞卷积

Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
YdrMaster 2023-04-23 16:52:24 +08:00
parent 7d7d923e8d
commit 4a5e66b36a
1 changed files with 17 additions and 5 deletions

View File

@ -18,11 +18,23 @@ Vec<Unigraph> optimization::pass::mutate( // fmt: new line
SingleOperator const & //
) {
Vec<Unigraph> ans;
switch (g.operators.front().op_type) {
case OpType::Conv:
/* code */
break;
auto const &op = g.operators.front();
switch (op.op_type) {
case OpType::Conv: {
auto const &w_shape = op.inputs.at(1)->shape;
auto const &dilations = op.inputs.at(2)->data.cpu_data;
auto const &pads = op.inputs.at(3)->data.cpu_data;
auto const &strides = op.inputs.at(4)->data.cpu_data;
if (w_shape.rbegin()[0] == 1 // fmt: new line
&& w_shape.rbegin()[1] == 1 //
&& std::all_of(strides.begin(), strides.end(),
[](auto x) { return x == 1; })) {
// 1x1 conv
} else if (std::any_of(dilations.begin(), dilations.end(),
[](auto x) { return x != 1; })) {
// dilated conv
}
} break;
default:
break;