forked from jiuyuan/InfiniTensor
Fix: matmul transpose in convNHWC2gemm rule
This commit is contained in:
parent
8409c1f9d4
commit
4211fd1f32
|
@ -1 +1 @@
|
|||
Subproject commit 3bb9240cb15459768adb3e7d963a20e1523a6294
|
||||
Subproject commit f30744bcf726ea3735df7ecf9e9de9ddac540283
|
|
@ -1 +1 @@
|
|||
Subproject commit b796f7d44681514f58a683a3a71ff17c94edb0c1
|
||||
Subproject commit e2239ee6043f73722e7aa812a459f54a28552929
|
|
@ -1 +1 @@
|
|||
Subproject commit 13132dd361c8c5b5753983d5186cf54f689d90f9
|
||||
Subproject commit 6aebf09233951e4ce30a63919186a70b2b195756
|
|
@ -1 +1 @@
|
|||
Subproject commit 0bd8896a4010f2d91b2340570c24fa08606ec406
|
||||
Subproject commit 1e3400b6742288429f2069aaf5febf92d0662dae
|
|
@ -1,7 +1,6 @@
|
|||
#pragma once
|
||||
#include "core/runtime.h"
|
||||
#include "cuda/cuda_common.h"
|
||||
#include "nnet/dbg.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
|
@ -31,7 +30,6 @@ class CudaRuntimeObj : public RuntimeObj {
|
|||
void sync() const;
|
||||
CudaPtr alloc(size_t size) override {
|
||||
void *ptr;
|
||||
// dbg(size);
|
||||
checkCudaError(cudaMalloc(&ptr, size));
|
||||
allocatedGPUMemorySize += size;
|
||||
allocationMap[ptr] = size;
|
||||
|
|
|
@ -502,9 +502,10 @@ Graph SearchEngine::fuseVertically(const Graph &graph) {
|
|||
|
||||
auto bestGraph = make_ref<GraphObj>(runtimeExec, chainOps);
|
||||
// Eliminate transpose and reshape operators
|
||||
// if (auto eliminatedGraph = mutator->eliminateVertically(
|
||||
// make_ref<GraphObj>(runtimeExec, chainOps)))
|
||||
// bestGraph = eliminatedGraph;
|
||||
// FIXME: the current Relu only supports 3D and 4D tensors
|
||||
if (auto eliminatedGraph = mutator->eliminateVertically(
|
||||
make_ref<GraphObj>(runtimeExec, chainOps)))
|
||||
bestGraph = eliminatedGraph;
|
||||
// Fuse membound operators
|
||||
if (auto optGraph = mutator->fuseVertically(bestGraph))
|
||||
bestGraph = optGraph;
|
||||
|
|
|
@ -2,7 +2,6 @@
|
|||
#include "core/blob.h"
|
||||
#include "core/operator.h"
|
||||
#include "core/runtime.h"
|
||||
#include "nnet/dbg.h"
|
||||
#include "utils/dataloader.h"
|
||||
#include <cstring>
|
||||
#include <numeric>
|
||||
|
|
|
@ -3,6 +3,7 @@
|
|||
#include "core/perf_engine.h"
|
||||
#include "core/runtime.h"
|
||||
#include "cuda_profiler_api.h"
|
||||
#include "nnet/dbg.h"
|
||||
#include "operators/conv.h"
|
||||
#include "operators/matmul.h"
|
||||
#ifdef INFINI_USE_TVM
|
||||
|
@ -165,9 +166,9 @@ double CudaRuntimeObj::timeWithCudaGraph(Graph graph, int rounds) {
|
|||
auto [cudaGraphInstance, numCudaGraphNodes] = endCudaGraphStreamCapture();
|
||||
// Since one TVM packed function may contain more than one CUDA kernel, the
|
||||
// number of captured kernels may exceed the number of operators.
|
||||
// IT_ASSERT(numCudaGraphNodes >= kernels.size(),
|
||||
// std::to_string(numCudaGraphNodes) +
|
||||
// " != " + std::to_string(kernels.size()));
|
||||
IT_ASSERT(numCudaGraphNodes >= kernels.size(),
|
||||
std::to_string(numCudaGraphNodes) +
|
||||
" != " + std::to_string(kernels.size()));
|
||||
printf("numCudaGraphNodes = %lu\n", numCudaGraphNodes);
|
||||
return timeit(
|
||||
[&, cudaGraphInstance = cudaGraphInstance, stream = getStream()]() {
|
||||
|
|
|
@ -82,6 +82,9 @@ class convCudnn : public Kernel {
|
|||
// get kernels
|
||||
cudnnFilterDescriptor_t knDesc;
|
||||
checkCudnnError(cudnnCreateFilterDescriptor(&knDesc));
|
||||
// FIXME: filter data layout is not changed with input data layout
|
||||
// since FCRS shows better performance for NHWC inputs in some cases.
|
||||
// This should be tunable.
|
||||
checkCudnnError(cudnnSetFilter4dDescriptor(knDesc, CUDNN_DATA_FLOAT,
|
||||
CUDNN_TENSOR_NCHW, f,
|
||||
channelsPerGrp, r, s));
|
||||
|
|
|
@ -41,4 +41,4 @@ REGISTER_KERNEL(Device::CUDA, OpType::Conv2dReduce, DataType::Float32,
|
|||
REGISTER_KERNEL(Device::CUDA, OpType::Conv2dReduceTranspose, DataType::Float32,
|
||||
Conv2dReduceCuda, "Conv2dReduceTranspose_CUDA_Float32");
|
||||
|
||||
} // namespace infini
|
||||
} // namespace infini
|
||||
|
|
|
@ -32,7 +32,6 @@ class ActivationCudnn : public CudaKernelWithoutConfig {
|
|||
} else if (dim.size() == 3) {
|
||||
n = 1, c = dim[0], h = dim[1], w = dim[2];
|
||||
} else {
|
||||
dbg(vecToString(dim));
|
||||
IT_TODO_HALT();
|
||||
}
|
||||
|
||||
|
|
|
@ -574,24 +574,21 @@ Graph NMutator::transformConvToGEMMReduce(Operator _op) {
|
|||
IT_ASSERT(inputDims[2] == w);
|
||||
IT_ASSERT(inputDims[3] == c);
|
||||
const DataType dtype = A->getDType();
|
||||
// IT_ASSERT(outputDims[0] == n);
|
||||
// IT_ASSERT(outputDims[1] == h);
|
||||
// IT_ASSERT(outputDims[2] == w);
|
||||
// IT_ASSERT(outputDims[3] == f);
|
||||
auto g = make_ref<GraphObj>(runtime);
|
||||
dbg(vecToString(inputDims));
|
||||
dbg(vecToString(weightDims));
|
||||
auto newA = g->addTensor(
|
||||
{inputDims[0] * inputDims[1] * inputDims[2], inputDims[3]}, dtype);
|
||||
auto newW = g->addTensor(
|
||||
{weightDims[3], weightDims[0] * weightDims[1] * weightDims[2]}, dtype);
|
||||
|
||||
// // If use Matmul with transpose 0,0
|
||||
// auto newW = g->addTensor(
|
||||
// {weightDims[0] * weightDims[1] * weightDims[2], weightDims[3]},
|
||||
// dtype);
|
||||
// {weightDims[3], weightDims[0] * weightDims[1] * weightDims[2]}, dtype);
|
||||
|
||||
// If use Matmul with transpose 0, 1
|
||||
auto newW = g->addTensor(
|
||||
{weightDims[0] * weightDims[1] * weightDims[2], weightDims[3]},
|
||||
dtype);
|
||||
g->addOpWithOutputs<ReshapeObj>(g->cloneTensor(A), newA, newA->getDims());
|
||||
g->addOpWithOutputs<ReshapeObj>(g->cloneTensor(W), newW, newW->getDims());
|
||||
Tensor newO = g->addOp<MatmulObj>(newA, newW, nullptr, 0, 0)->getOutput();
|
||||
Tensor newO = g->addOp<MatmulObj>(newA, newW, nullptr, 0, 1)->getOutput();
|
||||
auto new1 = g->addTensor({n, h, w, f, r, s}, dtype);
|
||||
g->addOpWithOutputs<ReshapeObj>(newO, new1, new1->getDims());
|
||||
g->addOpWithOutputs<Conv2dReduce>(
|
||||
|
@ -605,39 +602,27 @@ Graph NMutator::transformConvTranposeToGEMMReduce(Operator _op) {
|
|||
return nullptr;
|
||||
const auto &A = op->getInputs()[0];
|
||||
const auto &W = op->getInputs()[1];
|
||||
// f is the de-facto input channel for ConvTranspose
|
||||
const auto &[n, c, h, w, f, r, s] = op->getNCHWFRS();
|
||||
const auto &[ph, pw, sh, sw, dh, dw] = op->getPadStrideDilation();
|
||||
const Shape inputDims = op->getInputs(0)->getDims();
|
||||
const Shape weightDims = op->getInputs(1)->getDims();
|
||||
const Shape outputDims = op->getOutput()->getDims();
|
||||
dbg(vecToString(inputDims));
|
||||
dbg(vecToString(weightDims));
|
||||
dbg(vecToString(op->getOutput()->getDims()));
|
||||
// dbg(vecToString(op->getNCHWFRS());
|
||||
IT_ASSERT(weightDims[0] == f);
|
||||
IT_ASSERT(weightDims[1] == r);
|
||||
IT_ASSERT(weightDims[2] == s);
|
||||
IT_ASSERT(weightDims[3] == c);
|
||||
IT_ASSERT(inputDims[0] == n);
|
||||
IT_ASSERT(inputDims[1] == h);
|
||||
IT_ASSERT(inputDims[2] == w);
|
||||
IT_ASSERT(inputDims[3] == f);
|
||||
const DataType dtype = A->getDType();
|
||||
auto g = make_ref<GraphObj>(runtime);
|
||||
auto newA = g->addTensor(
|
||||
auto newA = g->addTensor( // [N,H,W,F]
|
||||
{inputDims[0] * inputDims[1] * inputDims[2], inputDims[3]}, dtype);
|
||||
auto newW = g->addTensor(
|
||||
auto newW = g->addTensor( // [F, CRS]
|
||||
{weightDims[0], weightDims[1] * weightDims[2] * weightDims[3]},
|
||||
dtype); // hack
|
||||
dtype); // HACK: this should be a transpose
|
||||
|
||||
// auto newW = g->addTensor(
|
||||
// {weightDims[0] * weightDims[1] * weightDims[2], weightDims[3]},
|
||||
// dtype);
|
||||
g->addOpWithOutputs<ReshapeObj>(g->cloneTensor(A), newA, newA->getDims());
|
||||
g->addOpWithOutputs<ReshapeObj>(g->cloneTensor(W), newW, newW->getDims());
|
||||
// newO [NHW, CRS]
|
||||
Tensor newO = g->addOp<MatmulObj>(newA, newW, nullptr, 0, 0)->getOutput();
|
||||
auto new1 = g->addTensor({n, h, w, c, r, s}, dtype);
|
||||
g->addOpWithOutputs<ReshapeObj>(newO, new1, new1->getDims());
|
||||
// [NHW, CRS] -> [N,H,W,C]
|
||||
g->addOpWithOutputs<Conv2dReduceTranspose>(
|
||||
new1, nullptr, g->cloneTensor(op->getOutput()), false, 0.f, ph, pw);
|
||||
return g;
|
||||
|
|
|
@ -95,4 +95,4 @@ Conv2dReduceTranspose::inferShape(const TensorVec &inputs) const {
|
|||
|
||||
return {{{on, oh, ow, of}}};
|
||||
}
|
||||
} // namespace infini
|
||||
} // namespace infini
|
||||
|
|
|
@ -155,7 +155,7 @@ if __name__ == "__main__":
|
|||
# construct_convTranspose2d(runtime)
|
||||
# (load_onnx(runtime, '/mnt/auxHome/models/einnet/fsrcnn.bs1.onnx'), 'fsrcnn.bs1'),
|
||||
(ft.getFSRCNNGraph(1, runtime), "fsrcnn.bs1"),
|
||||
(ft.getFSRCNNGraph(16, runtime), "fsrcnn.bs16")
|
||||
(ft.getFSRCNNGraph(16, runtime), "fsrcnn.bs16"),
|
||||
# (construct_conv_nhwc(runtime, 1, 56, 32, 32, 12, 1, 1, 0, 1, 1), 'conv1x1')
|
||||
]
|
||||
|
||||
|
|
Loading…
Reference in New Issue