From c077a61681bc18dbf8261db522d1b5db086a4546 Mon Sep 17 00:00:00 2001 From: YdrMaster Date: Wed, 26 Apr 2023 12:20:17 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E6=94=B9=E6=AD=A3=201x1=20=E5=8D=B7?= =?UTF-8?q?=E7=A7=AF=E7=9A=84=E5=8F=98=E6=8D=A2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: YdrMaster --- optimization/src/pass/single_operator.cc | 16 +++++++++------- optimization/test/test.cpp | 3 +-- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/optimization/src/pass/single_operator.cc b/optimization/src/pass/single_operator.cc index b827c14e..3d64dfb5 100644 --- a/optimization/src/pass/single_operator.cc +++ b/optimization/src/pass/single_operator.cc @@ -96,19 +96,21 @@ Vec optimization::pass::mutate( // fmt: new line dt, {})}); } - // (t0,t1) -|matmul|-> t2 - auto t2 = Tensor::share({t0->shape[0], t1->shape[1]}, dt, {}); - mutant.push_operator(OpType::MatMul, {t0, t1}, {t2}); + // (t0,t1) -|matmul|-> x -|reshape|-> t2 + auto x = Tensor::share({t0->shape[0], t1->shape[1]}, dt, {}); + mutant.push_operator(OpType::MatMul, {std::move(t0), std::move(t1)}, + {x}); + auto t2 = Tensor::share( + {i_shape[0], i_shape[2], i_shape[3], k_shape[0]}, dt, {}); + mutant.push_operator(OpType::Reshape, {std::move(x)}, {t2}); - // (t2,"nhwf"->"nfhw") -|transpose|-> tranposed -|reshape|-> output + // (t2,"nhwf"->"nfhw") -|transpose|-> output { auto [shape_, permute_] = transpose(t2->shape, "nhwf", "nfhw"); - auto tranposed = Tensor::share(std::move(shape_), dt, {}); + // auto tranposed = Tensor::share(std::move(shape_), dt, {}); auto permutation = Tensor::share_vec(std::move(permute_)); mutant.push_operator(OpType::Transpose, {std::move(t2), std::move(permutation)}, - {tranposed}); - mutant.push_operator(OpType::Reshape, {std::move(tranposed)}, {conv.output()}); } } else if ( diff --git a/optimization/test/test.cpp b/optimization/test/test.cpp index ebf076d7..549b77cf 100644 --- a/optimization/test/test.cpp +++ b/optimization/test/test.cpp @@ -28,8 +28,7 @@ int main() { {c} // outputs ); - auto p = - Partition(std::move(g), pass::partition); + auto p = Partition(std::move(g), pass::partition); auto m = Mutation( std::move(p), [](const auto &g, const auto &t) { return Vec{}; });