diff --git a/optimization/src/pass/single_operator.cc b/optimization/src/pass/single_operator.cc index 0ce64b38..8f26c797 100644 --- a/optimization/src/pass/single_operator.cc +++ b/optimization/src/pass/single_operator.cc @@ -18,11 +18,23 @@ Vec optimization::pass::mutate( // fmt: new line SingleOperator const & // ) { Vec ans; - - switch (g.operators.front().op_type) { - case OpType::Conv: - /* code */ - break; + auto const &op = g.operators.front(); + switch (op.op_type) { + case OpType::Conv: { + auto const &w_shape = op.inputs.at(1)->shape; + auto const &dilations = op.inputs.at(2)->data.cpu_data; + auto const &pads = op.inputs.at(3)->data.cpu_data; + auto const &strides = op.inputs.at(4)->data.cpu_data; + if (w_shape.rbegin()[0] == 1 // fmt: new line + && w_shape.rbegin()[1] == 1 // + && std::all_of(strides.begin(), strides.end(), + [](auto x) { return x == 1; })) { + // 1x1 conv + } else if (std::any_of(dilations.begin(), dilations.end(), + [](auto x) { return x != 1; })) { + // dilated conv + } + } break; default: break;