forked from jiuyuan/InfiniTensor
feat: 优化转置写法
Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
parent
f78ff0e8ee
commit
3631b03e73
|
@ -0,0 +1,22 @@
|
|||
#pragma once
|
||||
|
||||
#include "../graph.h"
|
||||
|
||||
namespace optimization {
|
||||
|
||||
class Conv {
|
||||
Operator const &op;
|
||||
|
||||
public:
|
||||
explicit Conv(Operator &op) : op(op) {}
|
||||
explicit Conv(Operator const &op) : op(op) {}
|
||||
|
||||
Arc<Tensor> const &input() const { return op.inputs.at(0); }
|
||||
Arc<Tensor> const &weight() const { return op.inputs.at(1); }
|
||||
Arc<Tensor> const &delations() const { return op.inputs.at(2); }
|
||||
Arc<Tensor> const &pads() const { return op.inputs.at(3); }
|
||||
Arc<Tensor> const &strides() const { return op.inputs.at(4); }
|
||||
Arc<Tensor> const &output() const { return op.outputs.at(0); }
|
||||
};
|
||||
|
||||
} // namespace optimization
|
|
@ -1,4 +1,7 @@
|
|||
#include "single_operator.h"
|
||||
#include "../operator/conv.h"
|
||||
#include <iterator>
|
||||
#include <map>
|
||||
#include <numeric>
|
||||
|
||||
using namespace optimization;
|
||||
|
@ -14,11 +17,28 @@ optimization::pass::partition(Unigraph &&g) {
|
|||
return ans;
|
||||
}
|
||||
|
||||
static Vec<size_t> transpose(Vec<size_t> const &shape,
|
||||
Vec<size_t> const &permute) {
|
||||
Vec<size_t> ans(shape.size());
|
||||
for (auto i = 0; i < ans.size(); ++i)
|
||||
ans[i] = shape[permute[i]];
|
||||
// 1st: new shape
|
||||
// 2nd: permutation
|
||||
static std::pair<Vec<size_t>, Vec<size_t>> // fmt: new line
|
||||
transpose( //
|
||||
Vec<size_t> const &shape, //
|
||||
char const *src, // source tensor layout
|
||||
char const *tgt // target tensor layout
|
||||
) {
|
||||
// assert( shape.size() == str_len(src) == str_len(tgt) )
|
||||
std::map<char, size_t> indices;
|
||||
|
||||
for (size_t i = 0; i < shape.size(); ++i)
|
||||
indices[src[i]] = i;
|
||||
|
||||
auto ans = std::make_pair( // fmt: new line
|
||||
Vec<size_t>(shape.size()), // shape
|
||||
Vec<size_t>(shape.size()) // permutation
|
||||
);
|
||||
|
||||
for (auto i = 0; i < shape.size(); ++i)
|
||||
ans.first[i] = shape[ans.second[i] = indices[tgt[i]]];
|
||||
|
||||
return ans;
|
||||
}
|
||||
|
||||
|
@ -30,11 +50,11 @@ Vec<Unigraph> optimization::pass::mutate( // fmt: new line
|
|||
auto const &op = g.operators.front();
|
||||
switch (op.op_type) {
|
||||
case OpType::Conv: {
|
||||
auto const &i_shape = op.inputs.at(0)->shape;
|
||||
auto const &w_shape = op.inputs.at(1)->shape;
|
||||
auto const &dilations = op.inputs.at(2)->data.cpu_data;
|
||||
auto const &pads = op.inputs.at(3)->data.cpu_data;
|
||||
auto const &strides = op.inputs.at(4)->data.cpu_data;
|
||||
auto const conv = Conv(g.operators.front());
|
||||
auto const &i_shape = conv.input()->shape;
|
||||
auto const &w_shape = conv.weight()->shape;
|
||||
auto const &dilations = conv.delations()->data.cpu_data;
|
||||
auto const &strides = conv.strides()->data.cpu_data;
|
||||
if (w_shape.rbegin()[0] == 1 // fmt: new line
|
||||
&& w_shape.rbegin()[1] == 1 //
|
||||
&& std::all_of(strides.begin(), strides.end(),
|
||||
|
@ -42,47 +62,55 @@ Vec<Unigraph> optimization::pass::mutate( // fmt: new line
|
|||
// 1x1 conv
|
||||
auto &mutant = ans.emplace_back();
|
||||
|
||||
// assert(op.inputs.at(0)->data_type == op.inputs.at(1)->data_type);
|
||||
auto dt = op.inputs.at(0)->data_type;
|
||||
// assert(conv.input()->data_type == conv.weight()->data_type);
|
||||
auto const dt = conv.input()->data_type;
|
||||
|
||||
// (input, "nchw"->"nhwc") -|transpose|-> t0 -|reshape|-> t1
|
||||
// (input, "nchw"->"nhwc") -|transpose|-> reshape -|reshape|-> t0
|
||||
Arc<Tensor> t0;
|
||||
{
|
||||
auto [shape_, permute_] = transpose(i_shape, "nchw", "nhwc");
|
||||
auto tranpose = Tensor::share(std::move(shape_), dt, {});
|
||||
auto permutation = Tensor::share_vec(std::move(permute_));
|
||||
mutant.push_operator(OpType::Transpose,
|
||||
{conv.input(), std::move(permutation)},
|
||||
{tranpose});
|
||||
mutant.push_operator(
|
||||
OpType::Reshape, {std::move(tranpose)},
|
||||
{t0 = Tensor::share(
|
||||
{shape_[0] * shape_[1] * shape_[2], shape_[3]}, dt,
|
||||
{})});
|
||||
}
|
||||
|
||||
// (weight,"fcrs"->"cfrs") -|transpose|-> reshape -|reshape|-> t1
|
||||
Arc<Tensor> t1;
|
||||
{
|
||||
Vec<size_t> nhwc{0, 2, 3, 1};
|
||||
auto t0 = Tensor::share(transpose(i_shape, nhwc), dt, {});
|
||||
auto [shape_, permute_] = transpose(w_shape, "fcrs", "cfrs");
|
||||
auto tranpose = Tensor::share(std::move(shape_), dt, {});
|
||||
auto permutation = Tensor::share_vec(std::move(permute_));
|
||||
mutant.push_operator(OpType::Transpose,
|
||||
{op.inputs.at(0), Tensor::share_vec(nhwc)},
|
||||
{t0});
|
||||
t1 = Tensor::share(
|
||||
{i_shape[0] * i_shape[2] * i_shape[3], i_shape[1]}, dt, {});
|
||||
mutant.push_operator(OpType::Reshape, {std::move(t0)}, {t1});
|
||||
{conv.weight(), std::move(permutation)},
|
||||
{tranpose});
|
||||
mutant.push_operator(
|
||||
OpType::Reshape, {std::move(tranpose)},
|
||||
{t1 = Tensor::share(
|
||||
{shape_[0], shape_[1] * shape_[2] * shape_[3]}, dt,
|
||||
{})});
|
||||
}
|
||||
|
||||
// (weight,"fcrs"->"cfrs") -|transpose|-> t2 -|reshape|-> t3
|
||||
Arc<Tensor> t3;
|
||||
{
|
||||
Vec<size_t> cfrs{1, 0, 2, 3};
|
||||
auto t2 = Tensor::share(transpose(w_shape, cfrs), dt, {});
|
||||
mutant.push_operator(OpType::Transpose,
|
||||
{op.inputs.at(1), Tensor::share_vec(cfrs)},
|
||||
{t2});
|
||||
t3 = Tensor::share(
|
||||
{w_shape[1], w_shape[0] * w_shape[2] * w_shape[3]}, dt, {});
|
||||
mutant.push_operator(OpType::Reshape, {std::move(t2)}, {t3});
|
||||
}
|
||||
// (t0,t1) -|matmul|-> t2
|
||||
auto t2 = Tensor::share({t0->shape[0], t1->shape[1]}, dt, {});
|
||||
mutant.push_operator(OpType::MatMul, {t0, t1}, {t2});
|
||||
|
||||
// (t1,t3) -|matmul|-> t4
|
||||
auto t4 = Tensor::share({t1->shape[0], t3->shape[1]}, dt, {});
|
||||
mutant.push_operator(OpType::MatMul, {t1, t3}, {t4});
|
||||
|
||||
// (t4,"nhwf"->"nfhw") -|transpose|-> t5 -|reshape|-> output
|
||||
// (t2,"nhwf"->"nfhw") -|transpose|-> reshape -|reshape|-> output
|
||||
{
|
||||
Vec<size_t> nfhw{0, 3, 1, 2};
|
||||
auto t5 = Tensor::share(transpose(t4->shape, nfhw), dt, {});
|
||||
auto [shape_, permute_] = transpose(t2->shape, "nhwf", "nfhw");
|
||||
auto tranpose = Tensor::share(std::move(shape_), dt, {});
|
||||
auto permutation = Tensor::share_vec(std::move(permute_));
|
||||
mutant.push_operator(OpType::Transpose,
|
||||
{t4, Tensor::share_vec(nfhw)}, {t5});
|
||||
mutant.push_operator(OpType::Reshape, {std::move(t5)},
|
||||
{op.outputs.at(0)});
|
||||
{std::move(t2), std::move(permutation)},
|
||||
{tranpose});
|
||||
mutant.push_operator(OpType::Reshape, {std::move(tranpose)},
|
||||
{conv.output()});
|
||||
}
|
||||
} else if (std::any_of(dilations.begin(), dilations.end(),
|
||||
[](auto x) { return x != 1; })) {
|
||||
|
|
Loading…
Reference in New Issue