forked from jiuyuan/InfiniTensor
feat: 基本实现空洞卷积的变换
Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
parent
e3428d8fd8
commit
7b23fdbbfe
|
@ -12,7 +12,7 @@ class Conv {
|
|||
explicit Conv(Operator const &op) : op(op) {}
|
||||
|
||||
Arc<Tensor> const &input() const { return op.inputs.at(0); }
|
||||
Arc<Tensor> const &weight() const { return op.inputs.at(1); }
|
||||
Arc<Tensor> const &kernel() const { return op.inputs.at(1); }
|
||||
Arc<Tensor> const &delations() const { return op.inputs.at(2); }
|
||||
Arc<Tensor> const &pads() const { return op.inputs.at(3); }
|
||||
Arc<Tensor> const &strides() const { return op.inputs.at(4); }
|
||||
|
|
|
@ -52,9 +52,11 @@ Vec<Unigraph> optimization::pass::mutate( // fmt: new line
|
|||
case OpType::Conv: {
|
||||
auto const conv = Conv(g.operators.front());
|
||||
auto const &i_shape = conv.input()->shape;
|
||||
auto const &w_shape = conv.weight()->shape;
|
||||
auto const &w_shape = conv.kernel()->shape;
|
||||
auto const &dilations = conv.delations()->to_vec<int64_t>();
|
||||
auto const &strides = conv.strides()->to_vec<int64_t>();
|
||||
// assert(conv.input()->data_type == conv.kernel()->data_type);
|
||||
auto const dt = conv.input()->data_type;
|
||||
if (w_shape.rbegin()[0] == 1 // fmt: new line
|
||||
&& w_shape.rbegin()[1] == 1 //
|
||||
&& i_shape[1] == w_shape[1] // group = 1
|
||||
|
@ -63,60 +65,126 @@ Vec<Unigraph> optimization::pass::mutate( // fmt: new line
|
|||
// 1x1 conv
|
||||
auto &mutant = ans.emplace_back();
|
||||
|
||||
// assert(conv.input()->data_type == conv.weight()->data_type);
|
||||
auto const dt = conv.input()->data_type;
|
||||
|
||||
// (input, "nchw"->"nhwc") -|transpose|-> reshape -|reshape|-> t0
|
||||
// (input, "nchw"->"nhwc") -|transpose|-> tranposed -|reshape|-> t0
|
||||
Arc<Tensor> t0;
|
||||
{
|
||||
auto [shape_, permute_] = transpose(i_shape, "nchw", "nhwc");
|
||||
auto tranpose = Tensor::share(std::move(shape_), dt, {});
|
||||
auto tranposed = Tensor::share(std::move(shape_), dt, {});
|
||||
auto permutation = Tensor::share_vec(std::move(permute_));
|
||||
mutant.push_operator(OpType::Transpose,
|
||||
{conv.input(), std::move(permutation)},
|
||||
{tranpose});
|
||||
{tranposed});
|
||||
mutant.push_operator(
|
||||
OpType::Reshape, {std::move(tranpose)},
|
||||
OpType::Reshape, {std::move(tranposed)},
|
||||
{t0 = Tensor::share(
|
||||
{shape_[0] * shape_[1] * shape_[2], shape_[3]}, dt,
|
||||
{})});
|
||||
}
|
||||
|
||||
// (weight,"fcrs"->"cfrs") -|transpose|-> reshape -|reshape|-> t1
|
||||
// (kernel,"fcrs"->"cfrs") -|transpose|-> tranposed -|reshape|-> t1
|
||||
Arc<Tensor> t1;
|
||||
{
|
||||
auto [shape_, permute_] = transpose(w_shape, "fcrs", "cfrs");
|
||||
auto tranpose = Tensor::share(std::move(shape_), dt, {});
|
||||
auto tranposed = Tensor::share(std::move(shape_), dt, {});
|
||||
auto permutation = Tensor::share_vec(std::move(permute_));
|
||||
mutant.push_operator(OpType::Transpose,
|
||||
{conv.weight(), std::move(permutation)},
|
||||
{tranpose});
|
||||
{conv.kernel(), std::move(permutation)},
|
||||
{tranposed});
|
||||
mutant.push_operator(
|
||||
OpType::Reshape, {std::move(tranpose)},
|
||||
OpType::Reshape, {std::move(tranposed)},
|
||||
{t1 = Tensor::share(
|
||||
{shape_[0], shape_[1] * shape_[2] * shape_[3]}, dt,
|
||||
{})});
|
||||
{shape_[0], shape_[1] /* * shape_[2] * shape_[3] */},
|
||||
dt, {})});
|
||||
}
|
||||
|
||||
// (t0,t1) -|matmul|-> t2
|
||||
auto t2 = Tensor::share({t0->shape[0], t1->shape[1]}, dt, {});
|
||||
mutant.push_operator(OpType::MatMul, {t0, t1}, {t2});
|
||||
|
||||
// (t2,"nhwf"->"nfhw") -|transpose|-> reshape -|reshape|-> output
|
||||
// (t2,"nhwf"->"nfhw") -|transpose|-> tranposed -|reshape|-> output
|
||||
{
|
||||
auto [shape_, permute_] = transpose(t2->shape, "nhwf", "nfhw");
|
||||
auto tranpose = Tensor::share(std::move(shape_), dt, {});
|
||||
auto tranposed = Tensor::share(std::move(shape_), dt, {});
|
||||
auto permutation = Tensor::share_vec(std::move(permute_));
|
||||
mutant.push_operator(OpType::Transpose,
|
||||
{std::move(t2), std::move(permutation)},
|
||||
{tranpose});
|
||||
mutant.push_operator(OpType::Reshape, {std::move(tranpose)},
|
||||
{tranposed});
|
||||
mutant.push_operator(OpType::Reshape, {std::move(tranposed)},
|
||||
{conv.output()});
|
||||
}
|
||||
} else if (std::any_of(dilations.begin(), dilations.end(),
|
||||
[](auto x) { return x != 1; })) {
|
||||
} else if (
|
||||
// group = 1
|
||||
i_shape[1] == w_shape[1]
|
||||
// stride[*] = 1
|
||||
&& std::all_of(strides.begin(), strides.end(),
|
||||
[](auto x) { return x == 1; })
|
||||
// dilation[*] > 1
|
||||
&& std::any_of(dilations.begin(), dilations.end(),
|
||||
[](auto x) { return x > 1; })) {
|
||||
// dilated conv
|
||||
auto &mutant = ans.emplace_back();
|
||||
|
||||
auto t0 = Tensor::share(
|
||||
{
|
||||
i_shape[0],
|
||||
i_shape[1],
|
||||
i_shape[2] / dilations[0],
|
||||
static_cast<size_t>(dilations[0]),
|
||||
i_shape[3] / dilations[1],
|
||||
static_cast<size_t>(dilations[1]),
|
||||
},
|
||||
dt, {});
|
||||
mutant.push_operator(OpType::Reshape, {conv.input()}, {t0});
|
||||
|
||||
auto [shape_, permute_] = transpose(t0->shape, "nc1234", "n24c13");
|
||||
auto transposed = Tensor::share(shape_, dt, {});
|
||||
auto permutation = Tensor::share_vec(std::move(permute_));
|
||||
mutant.push_operator(OpType::Transpose,
|
||||
{std::move(t0), std::move(permutation)},
|
||||
{transposed});
|
||||
|
||||
auto t1 = Tensor::share(
|
||||
{
|
||||
shape_[0] * shape_[1] * shape_[2],
|
||||
shape_[3],
|
||||
shape_[4],
|
||||
shape_[5],
|
||||
},
|
||||
dt, {});
|
||||
mutant.push_operator(OpType::Reshape, {std::move(transposed)},
|
||||
{t1});
|
||||
|
||||
Vec<size_t> shape__{
|
||||
shape_[0] * shape_[1] * shape_[2],
|
||||
w_shape[1],
|
||||
conv.output()->shape[2] / shape_[1],
|
||||
conv.output()->shape[3] / shape_[2],
|
||||
};
|
||||
|
||||
auto t2 = Tensor::share(shape__, dt, {});
|
||||
mutant.push_operator(OpType::Conv,
|
||||
{
|
||||
std::move(t1),
|
||||
conv.kernel(),
|
||||
Tensor::share_vec<size_t>({1, 1}),
|
||||
conv.pads(),
|
||||
conv.strides(),
|
||||
},
|
||||
{t2});
|
||||
auto t3 = Tensor::share({shape_[0], shape_[1], shape_[2],
|
||||
shape__[1], shape__[2], shape__[3]},
|
||||
dt, {});
|
||||
mutant.push_operator(OpType::Reshape, {std::move(t2)}, {t3});
|
||||
|
||||
auto [shape___, permute__] =
|
||||
transpose(t3->shape, "n12chw", "nc1h2w");
|
||||
auto transposed_ = Tensor::share(shape___, dt, {});
|
||||
auto permutation_ = Tensor::share_vec(std::move(permute__));
|
||||
mutant.push_operator(OpType::Transpose,
|
||||
{std::move(t3), std::move(permutation_)},
|
||||
{transposed_});
|
||||
mutant.push_operator(OpType::Reshape, {std::move(t3)},
|
||||
{conv.output()});
|
||||
}
|
||||
} break;
|
||||
|
||||
|
|
Loading…
Reference in New Issue