forked from jiuyuan/InfiniTensor
fix: 改正 1x1 卷积的变换
Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
parent
325b279468
commit
c077a61681
|
@ -96,19 +96,21 @@ Vec<Unigraph> optimization::pass::mutate( // fmt: new line
|
||||||
dt, {})});
|
dt, {})});
|
||||||
}
|
}
|
||||||
|
|
||||||
// (t0,t1) -|matmul|-> t2
|
// (t0,t1) -|matmul|-> x -|reshape|-> t2
|
||||||
auto t2 = Tensor::share({t0->shape[0], t1->shape[1]}, dt, {});
|
auto x = Tensor::share({t0->shape[0], t1->shape[1]}, dt, {});
|
||||||
mutant.push_operator(OpType::MatMul, {t0, t1}, {t2});
|
mutant.push_operator(OpType::MatMul, {std::move(t0), std::move(t1)},
|
||||||
|
{x});
|
||||||
|
auto t2 = Tensor::share(
|
||||||
|
{i_shape[0], i_shape[2], i_shape[3], k_shape[0]}, dt, {});
|
||||||
|
mutant.push_operator(OpType::Reshape, {std::move(x)}, {t2});
|
||||||
|
|
||||||
// (t2,"nhwf"->"nfhw") -|transpose|-> tranposed -|reshape|-> output
|
// (t2,"nhwf"->"nfhw") -|transpose|-> output
|
||||||
{
|
{
|
||||||
auto [shape_, permute_] = transpose(t2->shape, "nhwf", "nfhw");
|
auto [shape_, permute_] = transpose(t2->shape, "nhwf", "nfhw");
|
||||||
auto tranposed = Tensor::share(std::move(shape_), dt, {});
|
// auto tranposed = Tensor::share(std::move(shape_), dt, {});
|
||||||
auto permutation = Tensor::share_vec(std::move(permute_));
|
auto permutation = Tensor::share_vec(std::move(permute_));
|
||||||
mutant.push_operator(OpType::Transpose,
|
mutant.push_operator(OpType::Transpose,
|
||||||
{std::move(t2), std::move(permutation)},
|
{std::move(t2), std::move(permutation)},
|
||||||
{tranposed});
|
|
||||||
mutant.push_operator(OpType::Reshape, {std::move(tranposed)},
|
|
||||||
{conv.output()});
|
{conv.output()});
|
||||||
}
|
}
|
||||||
} else if (
|
} else if (
|
||||||
|
|
|
@ -28,8 +28,7 @@ int main() {
|
||||||
{c} // outputs
|
{c} // outputs
|
||||||
);
|
);
|
||||||
|
|
||||||
auto p =
|
auto p = Partition<pass::SingleOperator>(std::move(g), pass::partition);
|
||||||
Partition<pass::SingleOperator>(std::move(g), pass::partition);
|
|
||||||
auto m = Mutation<pass::SingleOperator>(
|
auto m = Mutation<pass::SingleOperator>(
|
||||||
std::move(p),
|
std::move(p),
|
||||||
[](const auto &g, const auto &t) { return Vec<Unigraph>{}; });
|
[](const auto &g, const auto &t) { return Vec<Unigraph>{}; });
|
||||||
|
|
Loading…
Reference in New Issue