fix: 改正 1x1 卷积的变换

Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
YdrMaster 2023-04-26 12:20:17 +08:00
parent 325b279468
commit c077a61681
2 changed files with 10 additions and 9 deletions

View File

@ -96,19 +96,21 @@ Vec<Unigraph> optimization::pass::mutate( // fmt: new line
dt, {})}); dt, {})});
} }
// (t0,t1) -|matmul|-> t2 // (t0,t1) -|matmul|-> x -|reshape|-> t2
auto t2 = Tensor::share({t0->shape[0], t1->shape[1]}, dt, {}); auto x = Tensor::share({t0->shape[0], t1->shape[1]}, dt, {});
mutant.push_operator(OpType::MatMul, {t0, t1}, {t2}); mutant.push_operator(OpType::MatMul, {std::move(t0), std::move(t1)},
{x});
auto t2 = Tensor::share(
{i_shape[0], i_shape[2], i_shape[3], k_shape[0]}, dt, {});
mutant.push_operator(OpType::Reshape, {std::move(x)}, {t2});
// (t2,"nhwf"->"nfhw") -|transpose|-> tranposed -|reshape|-> output // (t2,"nhwf"->"nfhw") -|transpose|-> output
{ {
auto [shape_, permute_] = transpose(t2->shape, "nhwf", "nfhw"); auto [shape_, permute_] = transpose(t2->shape, "nhwf", "nfhw");
auto tranposed = Tensor::share(std::move(shape_), dt, {}); // auto tranposed = Tensor::share(std::move(shape_), dt, {});
auto permutation = Tensor::share_vec(std::move(permute_)); auto permutation = Tensor::share_vec(std::move(permute_));
mutant.push_operator(OpType::Transpose, mutant.push_operator(OpType::Transpose,
{std::move(t2), std::move(permutation)}, {std::move(t2), std::move(permutation)},
{tranposed});
mutant.push_operator(OpType::Reshape, {std::move(tranposed)},
{conv.output()}); {conv.output()});
} }
} else if ( } else if (

View File

@ -28,8 +28,7 @@ int main() {
{c} // outputs {c} // outputs
); );
auto p = auto p = Partition<pass::SingleOperator>(std::move(g), pass::partition);
Partition<pass::SingleOperator>(std::move(g), pass::partition);
auto m = Mutation<pass::SingleOperator>( auto m = Mutation<pass::SingleOperator>(
std::move(p), std::move(p),
[](const auto &g, const auto &t) { return Vec<Unigraph>{}; }); [](const auto &g, const auto &t) { return Vec<Unigraph>{}; });