Fix: matmul transpose in convNHWC2gemm rule

Liyan Zheng 2023-04-23 22:54:50 +08:00
parent 8409c1f9d4
commit 4211fd1f32
14 changed files with 32 additions and 46 deletions

@@ -1 +1 @@
-Subproject commit 3bb9240cb15459768adb3e7d963a20e1523a6294
+Subproject commit f30744bcf726ea3735df7ecf9e9de9ddac540283

@@ -1 +1 @@
-Subproject commit b796f7d44681514f58a683a3a71ff17c94edb0c1
+Subproject commit e2239ee6043f73722e7aa812a459f54a28552929

@@ -1 +1 @@
-Subproject commit 13132dd361c8c5b5753983d5186cf54f689d90f9
+Subproject commit 6aebf09233951e4ce30a63919186a70b2b195756

3rd-party/pybind11 vendored

@@ -1 +1 @@
-Subproject commit 0bd8896a4010f2d91b2340570c24fa08606ec406
+Subproject commit 1e3400b6742288429f2069aaf5febf92d0662dae

View File

@@ -1,7 +1,6 @@
#pragma once
#include "core/runtime.h"
#include "cuda/cuda_common.h"
-#include "nnet/dbg.h"
namespace infini {
@@ -31,7 +30,6 @@ class CudaRuntimeObj : public RuntimeObj {
void sync() const;
CudaPtr alloc(size_t size) override {
    void *ptr;
-    // dbg(size);
checkCudaError(cudaMalloc(&ptr, size));
allocatedGPUMemorySize += size;
allocationMap[ptr] = size;

View File

@@ -502,9 +502,10 @@ Graph SearchEngine::fuseVertically(const Graph &graph) {
auto bestGraph = make_ref<GraphObj>(runtimeExec, chainOps);
// Eliminate transpose and reshape operators
-// if (auto eliminatedGraph = mutator->eliminateVertically(
-//         make_ref<GraphObj>(runtimeExec, chainOps)))
-//     bestGraph = eliminatedGraph;
+// FIXME: current Relu only supports 3D and 4D tensors
+if (auto eliminatedGraph = mutator->eliminateVertically(
+        make_ref<GraphObj>(runtimeExec, chainOps)))
+    bestGraph = eliminatedGraph;
// Fuse membound operators
if (auto optGraph = mutator->fuseVertically(bestGraph))
bestGraph = optGraph;

View File

@@ -2,7 +2,6 @@
#include "core/blob.h"
#include "core/operator.h"
#include "core/runtime.h"
-#include "nnet/dbg.h"
#include "utils/dataloader.h"
#include <cstring>
#include <numeric>

View File

@@ -3,6 +3,7 @@
#include "core/perf_engine.h"
#include "core/runtime.h"
#include "cuda_profiler_api.h"
+#include "nnet/dbg.h"
#include "operators/conv.h"
#include "operators/matmul.h"
#ifdef INFINI_USE_TVM
@@ -165,9 +166,9 @@ double CudaRuntimeObj::timeWithCudaGraph(Graph graph, int rounds) {
auto [cudaGraphInstance, numCudaGraphNodes] = endCudaGraphStreamCapture();
// Since one TVM packed function may contain more than one CUDA kernel, the
// number of captured kernels may exceed the number of operators.
-// IT_ASSERT(numCudaGraphNodes >= kernels.size(),
-//           std::to_string(numCudaGraphNodes) +
-//               " != " + std::to_string(kernels.size()));
+IT_ASSERT(numCudaGraphNodes >= kernels.size(),
+          std::to_string(numCudaGraphNodes) +
+              " != " + std::to_string(kernels.size()));
printf("numCudaGraphNodes = %lu\n", numCudaGraphNodes);
return timeit(
[&, cudaGraphInstance = cudaGraphInstance, stream = getStream()]() {
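
For readers unfamiliar with stream capture: `endCudaGraphStreamCapture` is project code not shown in this diff, but a node count like `numCudaGraphNodes` presumably comes from the stock CUDA Graph API. A minimal sketch under that assumption (invented function name, CUDA 11-era `cudaGraphInstantiate` signature, error checks omitted):

```cpp
// Hedged sketch of producing a (graph instance, node count) pair at the
// end of a stream capture. One TVM packed function may launch several
// kernels during capture, so the node count can exceed the operator count.
#include <cuda_runtime.h>
#include <utility>

std::pair<cudaGraphExec_t, size_t> endCaptureSketch(cudaStream_t stream) {
    cudaGraph_t graph;
    cudaStreamEndCapture(stream, &graph); // stop recording on this stream
    size_t numNodes = 0;
    cudaGraphGetNodes(graph, nullptr, &numNodes); // nullptr => count only
    cudaGraphExec_t instance;
    cudaGraphInstantiate(&instance, graph, nullptr, nullptr, 0);
    return {instance, numNodes};
}
```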

View File

@@ -82,6 +82,9 @@ class convCudnn : public Kernel {
// get kernels
cudnnFilterDescriptor_t knDesc;
checkCudnnError(cudnnCreateFilterDescriptor(&knDesc));
+// FIXME: filter data layout is not changed with input data layout
+// since FCRS shows better performance for NHWC inputs in some cases.
+// This should be tunable.
checkCudnnError(cudnnSetFilter4dDescriptor(knDesc, CUDNN_DATA_FLOAT,
CUDNN_TENSOR_NCHW, f,
channelsPerGrp, r, s));
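
The FIXME above notes that the filter stays in FCRS layout regardless of the input layout. A hedged sketch of what making it tunable could look like; `makeFilterDesc` and `preferFCRS` are invented names, and error handling is omitted:

```cpp
// cuDNN lets the filter layout be chosen independently of the I/O layout.
// FCRS (= KCRS) maps to CUDNN_TENSOR_NCHW; an NHWC-style KRSC filter
// would use CUDNN_TENSOR_NHWC instead.
#include <cudnn.h>

cudnnFilterDescriptor_t makeFilterDesc(bool preferFCRS, int f,
                                       int channelsPerGrp, int r, int s) {
    cudnnFilterDescriptor_t knDesc;
    cudnnCreateFilterDescriptor(&knDesc);
    cudnnSetFilter4dDescriptor(knDesc, CUDNN_DATA_FLOAT,
                               preferFCRS ? CUDNN_TENSOR_NCHW
                                          : CUDNN_TENSOR_NHWC,
                               f, channelsPerGrp, r, s);
    return knDesc;
}
```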

View File

@@ -41,4 +41,4 @@ REGISTER_KERNEL(Device::CUDA, OpType::Conv2dReduce, DataType::Float32,
REGISTER_KERNEL(Device::CUDA, OpType::Conv2dReduceTranspose, DataType::Float32,
                Conv2dReduceCuda, "Conv2dReduceTranspose_CUDA_Float32");
-} // namespace infini
\ No newline at end of file
+} // namespace infini

View File

@@ -32,7 +32,6 @@ class ActivationCudnn : public CudaKernelWithoutConfig {
} else if (dim.size() == 3) {
    n = 1, c = dim[0], h = dim[1], w = dim[2];
} else {
-    dbg(vecToString(dim));
IT_TODO_HALT();
}

View File

@@ -574,24 +574,21 @@ Graph NMutator::transformConvToGEMMReduce(Operator _op) {
IT_ASSERT(inputDims[2] == w);
IT_ASSERT(inputDims[3] == c);
const DataType dtype = A->getDType();
// IT_ASSERT(outputDims[0] == n);
// IT_ASSERT(outputDims[1] == h);
// IT_ASSERT(outputDims[2] == w);
// IT_ASSERT(outputDims[3] == f);
auto g = make_ref<GraphObj>(runtime);
-dbg(vecToString(inputDims));
-dbg(vecToString(weightDims));
auto newA = g->addTensor(
    {inputDims[0] * inputDims[1] * inputDims[2], inputDims[3]}, dtype);
-auto newW = g->addTensor(
-    {weightDims[3], weightDims[0] * weightDims[1] * weightDims[2]}, dtype);
-// // If use Matmul with transpose 0,0
// auto newW = g->addTensor(
-//     {weightDims[0] * weightDims[1] * weightDims[2], weightDims[3]},
-//     dtype);
+//     {weightDims[3], weightDims[0] * weightDims[1] * weightDims[2]}, dtype);
+// If use Matmul with transpose 0, 1
+auto newW = g->addTensor(
+    {weightDims[0] * weightDims[1] * weightDims[2], weightDims[3]},
+    dtype);
g->addOpWithOutputs<ReshapeObj>(g->cloneTensor(A), newA, newA->getDims());
g->addOpWithOutputs<ReshapeObj>(g->cloneTensor(W), newW, newW->getDims());
-Tensor newO = g->addOp<MatmulObj>(newA, newW, nullptr, 0, 0)->getOutput();
+Tensor newO = g->addOp<MatmulObj>(newA, newW, nullptr, 0, 1)->getOutput();
auto new1 = g->addTensor({n, h, w, f, r, s}, dtype);
g->addOpWithOutputs<ReshapeObj>(newO, new1, new1->getDims());
g->addOpWithOutputs<Conv2dReduce>(
@@ -605,39 +602,27 @@ Graph NMutator::transformConvTranposeToGEMMReduce(Operator _op) {
return nullptr;
const auto &A = op->getInputs()[0];
const auto &W = op->getInputs()[1];
+// f is the de-facto input channel for ConvTranspose
const auto &[n, c, h, w, f, r, s] = op->getNCHWFRS();
const auto &[ph, pw, sh, sw, dh, dw] = op->getPadStrideDilation();
const Shape inputDims = op->getInputs(0)->getDims();
const Shape weightDims = op->getInputs(1)->getDims();
const Shape outputDims = op->getOutput()->getDims();
-dbg(vecToString(inputDims));
-dbg(vecToString(weightDims));
-dbg(vecToString(op->getOutput()->getDims()));
-// dbg(vecToString(op->getNCHWFRS());
IT_ASSERT(weightDims[0] == f);
IT_ASSERT(weightDims[1] == r);
IT_ASSERT(weightDims[2] == s);
IT_ASSERT(weightDims[3] == c);
IT_ASSERT(inputDims[0] == n);
IT_ASSERT(inputDims[1] == h);
IT_ASSERT(inputDims[2] == w);
IT_ASSERT(inputDims[3] == f);
const DataType dtype = A->getDType();
auto g = make_ref<GraphObj>(runtime);
-auto newA = g->addTensor(
+auto newA = g->addTensor( // [N,H,W,F]
    {inputDims[0] * inputDims[1] * inputDims[2], inputDims[3]}, dtype);
-auto newW = g->addTensor(
+auto newW = g->addTensor( // [F, CRS]
    {weightDims[0], weightDims[1] * weightDims[2] * weightDims[3]},
-    dtype); // hack
+    dtype); // HACK: this should be a transpose
// auto newW = g->addTensor(
// {weightDims[0] * weightDims[1] * weightDims[2], weightDims[3]},
// dtype);
g->addOpWithOutputs<ReshapeObj>(g->cloneTensor(A), newA, newA->getDims());
g->addOpWithOutputs<ReshapeObj>(g->cloneTensor(W), newW, newW->getDims());
+// newO [NHW, CRS]
Tensor newO = g->addOp<MatmulObj>(newA, newW, nullptr, 0, 0)->getOutput();
auto new1 = g->addTensor({n, h, w, c, r, s}, dtype);
g->addOpWithOutputs<ReshapeObj>(newO, new1, new1->getDims());
+// [NHW, CRS] -> [N,H,W,C]
g->addOpWithOutputs<Conv2dReduceTranspose>(
new1, nullptr, g->cloneTensor(op->getOutput()), false, 0.f, ph, pw);
return g;
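
The heart of the commit: the weight buffer uses an [F, R, S, C]-style layout, so flattening its first three axes gives a valid [F*R*S, C] matrix that the GEMM must read transposed (transB = 1). Reinterpreting the same buffer as [C, F*R*S] — the old transpose 0, 0 path — is a reshape, not a transpose, and scrambles the weights. A standalone sketch with invented toy sizes that demonstrates the difference:

```cpp
// Sketch only: why reading an [F, R, S, C] weight buffer as [C, F*R*S]
// (a reshape) differs from the transpose of its [F*R*S, C] view.
#include <cstdio>
#include <vector>

int main() {
    const int F = 2, R = 2, S = 2, C = 3; // toy sizes, invented
    // Row-major [F, R, S, C]: element (f, r, s, c) lives at
    // ((f * R + r) * S + s) * C + c.
    std::vector<float> w(F * R * S * C);
    for (size_t i = 0; i < w.size(); ++i)
        w[i] = float(i); // distinct values: equal value <=> equal index

    const int FRS = F * R * S;
    // Grouping the first three axes is layout-preserving: a [FRS, C] matrix.
    auto asFRSxC = [&](int frs, int c) { return w[frs * C + c]; };
    // Its transpose, i.e. what Matmul with transB = 1 effectively reads:
    auto transposedView = [&](int c, int frs) { return asFRSxC(frs, c); };
    // Reinterpreting the same buffer as [C, FRS] (the old 0, 0 path):
    auto asCxFRS = [&](int c, int frs) { return w[c * FRS + frs]; };

    int mismatches = 0;
    for (int c = 0; c < C; ++c)
        for (int frs = 0; frs < FRS; ++frs)
            if (asCxFRS(c, frs) != transposedView(c, frs))
                ++mismatches;
    // Prints 22 of 24: only two entries coincide by accident.
    std::printf("mismatching entries: %d of %d\n", mismatches, C * FRS);
    return 0;
}
```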

View File

@@ -95,4 +95,4 @@ Conv2dReduceTranspose::inferShape(const TensorVec &inputs) const {
    return {{{on, oh, ow, of}}};
}
-} // namespace infini
\ No newline at end of file
+} // namespace infini

View File

@@ -155,7 +155,7 @@ if __name__ == "__main__":
# construct_convTranspose2d(runtime)
# (load_onnx(runtime, '/mnt/auxHome/models/einnet/fsrcnn.bs1.onnx'), 'fsrcnn.bs1'),
(ft.getFSRCNNGraph(1, runtime), "fsrcnn.bs1"),
-(ft.getFSRCNNGraph(16, runtime), "fsrcnn.bs16")
+(ft.getFSRCNNGraph(16, runtime), "fsrcnn.bs16"),
# (construct_conv_nhwc(runtime, 1, 56, 32, 32, 12, 1, 1, 0, 1, 1), 'conv1x1')
]