feat: basic implementation of the dilated convolution transform

Signed-off-by: YdrMaster <ydrml@hotmail.com>
YdrMaster 2023-04-25 15:59:43 +08:00
parent e3428d8fd8
commit 7b23fdbbfe
2 changed files with 90 additions and 22 deletions
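
The new branch added in the second file rewrites a dilated convolution (group = 1, stride = 1, dilation > 1) as a dilation-1 convolution over a space-to-batch rearrangement of the input: pixels that share the same phase modulo the dilation are gathered into separate batch items, the dense convolution runs on that larger batch, and the result is scattered back to the original layout. A minimal sketch of the shape bookkeeping, using illustrative sizes that are not taken from the diff:

#include <array>
#include <cstddef>
#include <cstdio>

int main() {
    // Illustrative sizes only (not from the commit): NCHW input, dilation 2 on both axes.
    std::array<std::size_t, 4> i_shape{1, 16, 32, 32};  // [n, c, h, w]
    std::array<std::size_t, 2> dilations{2, 2};          // [d0, d1]
    // Split each spatial axis by its dilation: [n, c, h, w] -> [n, c, h/d0, d0, w/d1, d1].
    std::array<std::size_t, 6> t0{i_shape[0], i_shape[1],
                                  i_shape[2] / dilations[0], dilations[0],
                                  i_shape[3] / dilations[1], dilations[1]};
    // Permute "nc1234" -> "n24c13" and fold the phases into the batch:
    // [n, d0, d1, c, h/d0, w/d1] -> [n*d0*d1, c, h/d0, w/d1].
    std::array<std::size_t, 4> t1{t0[0] * t0[3] * t0[5], t0[1], t0[2], t0[4]};
    // A dilation-1 convolution runs on t1; its result is reshaped and permuted
    // back ("n12chw" -> "nc1h2w") to the original NCHW output layout.
    std::printf("t1 = [%zu, %zu, %zu, %zu]\n", t1[0], t1[1], t1[2], t1[3]);
    return 0;
}

The pass below materializes this same chain as Reshape, Transpose, and Conv operators in the mutated graph.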


@@ -12,7 +12,7 @@ class Conv {
explicit Conv(Operator const &op) : op(op) {}
Arc<Tensor> const &input() const { return op.inputs.at(0); }
Arc<Tensor> const &weight() const { return op.inputs.at(1); }
Arc<Tensor> const &kernel() const { return op.inputs.at(1); }
Arc<Tensor> const &delations() const { return op.inputs.at(2); }
Arc<Tensor> const &pads() const { return op.inputs.at(3); }
Arc<Tensor> const &strides() const { return op.inputs.at(4); }


@@ -52,9 +52,11 @@ Vec<Unigraph> optimization::pass::mutate( // fmt: new line
case OpType::Conv: {
auto const conv = Conv(g.operators.front());
auto const &i_shape = conv.input()->shape;
auto const &w_shape = conv.weight()->shape;
auto const &w_shape = conv.kernel()->shape;
auto const &dilations = conv.delations()->to_vec<int64_t>();
auto const &strides = conv.strides()->to_vec<int64_t>();
// assert(conv.input()->data_type == conv.kernel()->data_type);
auto const dt = conv.input()->data_type;
if (w_shape.rbegin()[0] == 1 // fmt: new line
&& w_shape.rbegin()[1] == 1 //
&& i_shape[1] == w_shape[1] // group = 1
@@ -63,60 +65,126 @@ Vec<Unigraph> optimization::pass::mutate( // fmt: new line
// 1x1 conv
auto &mutant = ans.emplace_back();
// assert(conv.input()->data_type == conv.weight()->data_type);
auto const dt = conv.input()->data_type;
// (input, "nchw"->"nhwc") -|transpose|-> reshape -|reshape|-> t0
// (input, "nchw"->"nhwc") -|transpose|-> tranposed -|reshape|-> t0
Arc<Tensor> t0;
{
auto [shape_, permute_] = transpose(i_shape, "nchw", "nhwc");
auto tranpose = Tensor::share(std::move(shape_), dt, {});
auto tranposed = Tensor::share(std::move(shape_), dt, {});
auto permutation = Tensor::share_vec(std::move(permute_));
mutant.push_operator(OpType::Transpose,
{conv.input(), std::move(permutation)},
{tranpose});
{tranposed});
mutant.push_operator(
OpType::Reshape, {std::move(tranpose)},
OpType::Reshape, {std::move(tranposed)},
{t0 = Tensor::share(
{shape_[0] * shape_[1] * shape_[2], shape_[3]}, dt,
{})});
}
// (weight,"fcrs"->"cfrs") -|transpose|-> reshape -|reshape|-> t1
// (kernel,"fcrs"->"cfrs") -|transpose|-> tranposed -|reshape|-> t1
Arc<Tensor> t1;
{
auto [shape_, permute_] = transpose(w_shape, "fcrs", "cfrs");
auto tranpose = Tensor::share(std::move(shape_), dt, {});
auto tranposed = Tensor::share(std::move(shape_), dt, {});
auto permutation = Tensor::share_vec(std::move(permute_));
mutant.push_operator(OpType::Transpose,
{conv.weight(), std::move(permutation)},
{tranpose});
{conv.kernel(), std::move(permutation)},
{tranposed});
mutant.push_operator(
OpType::Reshape, {std::move(tranpose)},
OpType::Reshape, {std::move(tranposed)},
{t1 = Tensor::share(
{shape_[0], shape_[1] * shape_[2] * shape_[3]}, dt,
{})});
{shape_[0], shape_[1] /* * shape_[2] * shape_[3] */},
dt, {})});
}
// (t0,t1) -|matmul|-> t2
auto t2 = Tensor::share({t0->shape[0], t1->shape[1]}, dt, {});
mutant.push_operator(OpType::MatMul, {t0, t1}, {t2});
// (t2,"nhwf"->"nfhw") -|transpose|-> reshape -|reshape|-> output
// (t2,"nhwf"->"nfhw") -|transpose|-> tranposed -|reshape|-> output
{
auto [shape_, permute_] = transpose(t2->shape, "nhwf", "nfhw");
auto tranpose = Tensor::share(std::move(shape_), dt, {});
auto tranposed = Tensor::share(std::move(shape_), dt, {});
auto permutation = Tensor::share_vec(std::move(permute_));
mutant.push_operator(OpType::Transpose,
{std::move(t2), std::move(permutation)},
{tranpose});
mutant.push_operator(OpType::Reshape, {std::move(tranpose)},
{tranposed});
mutant.push_operator(OpType::Reshape, {std::move(tranposed)},
{conv.output()});
}
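// Net effect of the 1x1 branch above, shapes only:
//   input  [n, c, h, w]  -|transpose nchw->nhwc|-> -|reshape|-> t0 [n*h*w, c]
//   kernel [f, c, 1, 1]  -|transpose fcrs->cfrs|-> -|reshape|-> t1 [c, f]
//   (t0, t1) -|matmul|-> t2 [n*h*w, f] -|transpose|-> -|reshape|-> output [n, f, h, w]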
} else if (std::any_of(dilations.begin(), dilations.end(),
[](auto x) { return x != 1; })) {
} else if (
// group = 1
i_shape[1] == w_shape[1]
// stride[*] = 1
&& std::all_of(strides.begin(), strides.end(),
[](auto x) { return x == 1; })
// dilation[*] > 1
&& std::any_of(dilations.begin(), dilations.end(),
[](auto x) { return x > 1; })) {
// dilated conv
auto &mutant = ans.emplace_back();
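// input -|reshape|-> t0: [n, c, h/d0, d0, w/d1, d1]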
auto t0 = Tensor::share(
{
i_shape[0],
i_shape[1],
i_shape[2] / dilations[0],
static_cast<size_t>(dilations[0]),
i_shape[3] / dilations[1],
static_cast<size_t>(dilations[1]),
},
dt, {});
mutant.push_operator(OpType::Reshape, {conv.input()}, {t0});
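// (t0, "nc1234"->"n24c13") -|transpose|-> transposed -|reshape|-> t1: [n*d0*d1, c, h/d0, w/d1]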
auto [shape_, permute_] = transpose(t0->shape, "nc1234", "n24c13");
auto transposed = Tensor::share(shape_, dt, {});
auto permutation = Tensor::share_vec(std::move(permute_));
mutant.push_operator(OpType::Transpose,
{std::move(t0), std::move(permutation)},
{transposed});
auto t1 = Tensor::share(
{
shape_[0] * shape_[1] * shape_[2],
shape_[3],
shape_[4],
shape_[5],
},
dt, {});
mutant.push_operator(OpType::Reshape, {std::move(transposed)},
{t1});
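// (t1, kernel, dilation = 1) -|conv|-> t2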
Vec<size_t> shape__{
shape_[0] * shape_[1] * shape_[2],
w_shape[0],
conv.output()->shape[2] / shape_[1],
conv.output()->shape[3] / shape_[2],
};
auto t2 = Tensor::share(shape__, dt, {});
mutant.push_operator(OpType::Conv,
{
std::move(t1),
conv.kernel(),
Tensor::share_vec<size_t>({1, 1}),
conv.pads(),
conv.strides(),
},
{t2});
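// t2 -|reshape|-> t3 -|transpose "n12chw"->"nc1h2w"|-> transposed_ -|reshape|-> output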
auto t3 = Tensor::share({shape_[0], shape_[1], shape_[2],
shape__[1], shape__[2], shape__[3]},
dt, {});
mutant.push_operator(OpType::Reshape, {std::move(t2)}, {t3});
auto [shape___, permute__] =
transpose(t3->shape, "n12chw", "nc1h2w");
auto transposed_ = Tensor::share(shape___, dt, {});
auto permutation_ = Tensor::share_vec(std::move(permute__));
mutant.push_operator(OpType::Transpose,
{std::move(t3), std::move(permutation_)},
{transposed_});
mutant.push_operator(OpType::Reshape, {std::move(transposed_)},
{conv.output()});
}
} break;