fix: 改正 1x1 卷积的变换

Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
YdrMaster 2023-04-26 12:20:17 +08:00
parent 325b279468
commit c077a61681
2 changed files with 10 additions and 9 deletions

View File

@ -96,19 +96,21 @@ Vec<Unigraph> optimization::pass::mutate( // fmt: new line
dt, {})});
}
// (t0,t1) -|matmul|-> t2
auto t2 = Tensor::share({t0->shape[0], t1->shape[1]}, dt, {});
mutant.push_operator(OpType::MatMul, {t0, t1}, {t2});
// (t0,t1) -|matmul|-> x -|reshape|-> t2
auto x = Tensor::share({t0->shape[0], t1->shape[1]}, dt, {});
mutant.push_operator(OpType::MatMul, {std::move(t0), std::move(t1)},
{x});
auto t2 = Tensor::share(
{i_shape[0], i_shape[2], i_shape[3], k_shape[0]}, dt, {});
mutant.push_operator(OpType::Reshape, {std::move(x)}, {t2});
// (t2,"nhwf"->"nfhw") -|transpose|-> tranposed -|reshape|-> output
// (t2,"nhwf"->"nfhw") -|transpose|-> output
{
auto [shape_, permute_] = transpose(t2->shape, "nhwf", "nfhw");
auto tranposed = Tensor::share(std::move(shape_), dt, {});
// auto tranposed = Tensor::share(std::move(shape_), dt, {});
auto permutation = Tensor::share_vec(std::move(permute_));
mutant.push_operator(OpType::Transpose,
{std::move(t2), std::move(permutation)},
{tranposed});
mutant.push_operator(OpType::Reshape, {std::move(tranposed)},
{conv.output()});
}
} else if (

View File

@ -28,8 +28,7 @@ int main() {
{c} // outputs
);
auto p =
Partition<pass::SingleOperator>(std::move(g), pass::partition);
auto p = Partition<pass::SingleOperator>(std::move(g), pass::partition);
auto m = Mutation<pass::SingleOperator>(
std::move(p),
[](const auto &g, const auto &t) { return Vec<Unigraph>{}; });