forked from jiuyuan/InfiniTensor
Add: no malloc for reshape outputs
This commit is contained in:
parent
40e6db6608
commit
e2f18272c9
|
@ -12,6 +12,11 @@ class CudaRuntimeObj : public RuntimeObj {
|
|||
cublasHandle_t cublas;
|
||||
CudaPtr workspace;
|
||||
size_t workspaceSize;
|
||||
|
||||
// Memory information
|
||||
size_t allocatedGPUMemorySize = 0;
|
||||
map<void *, size_t> allocationMap;
|
||||
|
||||
bool cudaGraphStatus; // Whether CUDA graph stream capture is enabled
|
||||
|
||||
public:
|
||||
|
@ -27,10 +32,20 @@ class CudaRuntimeObj : public RuntimeObj {
|
|||
CudaPtr alloc(size_t size) override {
|
||||
void *ptr;
|
||||
checkCudaError(cudaMalloc(&ptr, size));
|
||||
// printf("cuda malloc: %p %lu bytes\n", ptr, size);
|
||||
allocatedGPUMemorySize += size;
|
||||
allocationMap[ptr] = size;
|
||||
// printf("cuda malloc: %p %lu bytes, total %lu bytes (%.2lf GB)\n",
|
||||
// ptr, size, allocatedGPUMemorySize,
|
||||
// double(allocatedGPUMemorySize) / 1024 / 1024 / 1024);
|
||||
return ptr;
|
||||
}
|
||||
void dealloc(void *ptr) override { checkCudaError(cudaFree(ptr)); }
|
||||
void dealloc(void *ptr) override {
|
||||
checkCudaError(cudaFree(ptr));
|
||||
allocatedGPUMemorySize -= allocationMap.at(ptr);
|
||||
allocationMap.erase(ptr);
|
||||
// printf("cuda dealloc: %p %lu bytes, total %lu\n", ptr,
|
||||
// allocationMap.at(ptr), allocatedGPUMemorySize);
|
||||
}
|
||||
cudnnHandle_t cudnnHandle() const { return cudnn; }
|
||||
cublasHandle_t cublasHandle() const { return cublas; }
|
||||
size_t getWorkspaceSize() const { return workspaceSize; }
|
||||
|
|
|
@ -125,7 +125,18 @@ void GraphObj::optimize() {
|
|||
|
||||
void GraphObj::dataMalloc() {
|
||||
for (auto &tensor : tensors) {
|
||||
tensor->dataMalloc();
|
||||
if (tensor->getSource() && tensor->getTargets().size() > 0 &&
|
||||
tensor->getSource()->getOpType() == OpType::Reshape) {
|
||||
continue;
|
||||
} else
|
||||
tensor->dataMalloc();
|
||||
}
|
||||
// Fill reshape output for avoiding nullptr
|
||||
for (auto &tensor : tensors) {
|
||||
if (tensor->getSource() &&
|
||||
tensor->getSource()->getOpType() == OpType::Reshape) {
|
||||
tensor->setData(tensor->getSource()->getInputs(0)->getDataBlob());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -122,8 +122,9 @@ bool TensorObj::equalData(const Tensor &rhs, double relativeError) const {
|
|||
}
|
||||
|
||||
void TensorObj::dataMalloc() {
|
||||
if (!data)
|
||||
if (!data) {
|
||||
data = runtime->allocBlob(getBytes());
|
||||
}
|
||||
}
|
||||
|
||||
void TensorObj::copyData(const TensorObj *src) {
|
||||
|
|
|
@ -33,7 +33,7 @@ class G2BMMCudnn : public CudaKernelWithoutConfig {
|
|||
auto record =
|
||||
make_ref<PerfRecordObj>(std::numeric_limits<double>::max());
|
||||
const auto [warmupRounds, timingRounds] =
|
||||
op->getB() > 100 ? tuple{1, 3} : tuple{5, 15};
|
||||
op->getB() > 100 ? tuple{1, 1} : tuple{1, 2};
|
||||
double tmp =
|
||||
timeit([&]() { g2bmmKernel(op, context); },
|
||||
[&]() { context->sync(); }, warmupRounds, timingRounds);
|
||||
|
|
|
@ -34,7 +34,7 @@ class GBMMCudnn : public CudaKernelWithoutConfig {
|
|||
auto record =
|
||||
make_ref<PerfRecordObj>(std::numeric_limits<double>::max());
|
||||
const auto [warmupRounds, timingRounds] =
|
||||
op->getB() > 100 ? tuple{1, 3} : tuple{5, 15};
|
||||
op->getB() > 100 ? tuple{1, 1} : tuple{1, 3};
|
||||
double tmp =
|
||||
timeit([&]() { gbmmKernel(op, context); },
|
||||
[&]() { context->sync(); }, warmupRounds, timingRounds);
|
||||
|
|
|
@ -232,6 +232,7 @@ void printGraph(Graph g) {
|
|||
}
|
||||
|
||||
void initializeGraphTensors(Graph g, double l, double r, bool useInt) {
|
||||
g->dataMalloc();
|
||||
auto gen = RandomGenerator(-0.1, 0.1, 0, useInt);
|
||||
for (auto t : g->getInputs()) {
|
||||
t->setData(gen);
|
||||
|
@ -260,6 +261,8 @@ Graph optimizeGraph(Graph g, Runtime _runtime, bool tuning, NMutator::Mode mode,
|
|||
IT_TODO_HALT();
|
||||
vector<Graph> bestGraphs;
|
||||
SearchEngine searchEngine(runtime, mutator);
|
||||
return searchEngine.run(g);
|
||||
|
||||
bestGraphs.emplace_back(searchEngine.run(g));
|
||||
g->topo_sort();
|
||||
dbg(g, bestGraphs[0], bestGraphs.size());
|
||||
|
@ -291,19 +294,19 @@ Graph optimizeGraph(Graph g, Runtime _runtime, bool tuning, NMutator::Mode mode,
|
|||
make_ref<GraphObj>(runtime, bestGraphCpu->getOperators());
|
||||
bestGraph->topo_sort();
|
||||
|
||||
bestGraph->dataMalloc();
|
||||
// Initialize inputs with random data
|
||||
for (auto t : bestGraph->getInputs()) {
|
||||
t->copyData(fuidToInputTensor[t->getFuid()]);
|
||||
}
|
||||
// bestGraph->dataMalloc();
|
||||
// // Initialize inputs with random data
|
||||
// for (auto t : bestGraph->getInputs()) {
|
||||
// t->copyData(fuidToInputTensor[t->getFuid()]);
|
||||
// }
|
||||
|
||||
// Initialize outputs with zeros
|
||||
for (auto t : bestGraph->getOutputs()) {
|
||||
t->setData(ZeroGenerator());
|
||||
}
|
||||
// // Initialize outputs with zeros
|
||||
// for (auto t : bestGraph->getOutputs()) {
|
||||
// t->setData(ZeroGenerator());
|
||||
// }
|
||||
|
||||
dbg(bestGraph);
|
||||
dbg(bestGraph->getOutputs());
|
||||
// dbg(bestGraph);
|
||||
// dbg(bestGraph->getOutputs());
|
||||
|
||||
// if (tuning) {
|
||||
// runtime->run(bestGraph, true); // Tune kernels
|
||||
|
|
|
@ -26,6 +26,7 @@ def load_onnx(runtime, filename: str) -> ft.Graph:
|
|||
|
||||
|
||||
def run_and_evaluate(runtime, g):
|
||||
ft.initializeGraphTensors(g)
|
||||
runtime.run(g, True)
|
||||
print(f'getPerfTime = {runtime.getPerfTime(g, True, False, False)}')
|
||||
print(f'Non-ctc time = {runtime.timeNonCtcOperators(g, 1000, 1000)}')
|
||||
|
@ -85,10 +86,17 @@ def evluate_GANs():
|
|||
run_and_evaluate(runtime, g)
|
||||
|
||||
|
||||
def construct_convTranspose2d(runtime):
|
||||
# def construct_convTranspose2d(runtime):
|
||||
# handler = ft.GraphHandler(runtime)
|
||||
# input = handler.tensor([1, 56, 32, 32], tensor_type=ft.TensorType.Input)
|
||||
# w = handler.tensor([56, 1, 9, 9], tensor_type=ft.TensorType.Initialized)
|
||||
# handler.convTransposed2d(input, w, None, 3, 3, 4, 4, 1, 1, 1, 1)
|
||||
# return handler.getGraph()
|
||||
|
||||
def construct_convTranspose2d(runtime, n, c, h, w, f, r, s, pad, stride, dilation):
|
||||
handler = ft.GraphHandler(runtime)
|
||||
input = handler.tensor([1, 56, 32, 32], tensor_type=ft.TensorType.Input)
|
||||
w = handler.tensor([56, 1, 9, 9], tensor_type=ft.TensorType.Initialized)
|
||||
input = handler.tensor([n, f, h, w], tensor_type=ft.TensorType.Input)
|
||||
w = handler.tensor([f, c, r, s], tensor_type=ft.TensorType.Initialized)
|
||||
handler.convTransposed2d(input, w, None, 3, 3, 4, 4, 1, 1, 1, 1)
|
||||
return handler.getGraph()
|
||||
|
||||
|
@ -104,14 +112,31 @@ def construct_conv(runtime, n, c, h, w, f, r, s, pad, stride, dilation):
|
|||
return handler.getGraph()
|
||||
|
||||
|
||||
def export_op_level_onnx(runtime):
|
||||
graphs = [
|
||||
(construct_conv(runtime, 1, 512, 7, 7, 512, 3, 3,
|
||||
1, 1, 1), "orig_Conv3x3"), # ResNet18 Conv_37
|
||||
# 16, 256, 2, 2, 448, 4, 4, 1, 2, 1 # CelebA_ConvTranspose_0
|
||||
# TODO
|
||||
(construct_convTranspose2d(), "orig_ConvTranspose"),
|
||||
(construct_conv(runtime, 16, 32, 224, 224, 1, 5,
|
||||
5, 2, 1, 1, 1), "orig_Conv5x5"), # SRCNN_Conv_4
|
||||
(construct_convTranspose2d(), "orig_G2BMM"),
|
||||
]
|
||||
for g, name in graphs:
|
||||
save_onnx(g, f"opt_{name}.onnx")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
runtime = ft.cuda_runtime()
|
||||
graphs = [
|
||||
# (construct_conv(runtime, 16, 56, 32, 32, 12, 1, 1, 0, 1, 1), 'conv1x1'), # FSRCNN Conv_2 1x1
|
||||
# (construct_conv(runtime, 1, 12, 32, 32, 12, 3, 3, 1, 1, 1), 'conv3x3'), # FSRCNN Conv_4 3x3
|
||||
# ft.getGANGraph(batch, runtime, 5, 1)
|
||||
# (ft.getLongformer(runtime, 1), 'longformer.bs1'),
|
||||
(ft.getLongformer(runtime, 16), 'longformer.bs16'),
|
||||
# construct_convTranspose2d(runtime)
|
||||
(load_onnx(runtime, '/mnt/auxHome/models/einnet/fsrcnn.bs1.onnx'), 'fsrcnn.bs1'),
|
||||
# (load_onnx(runtime, '/mnt/auxHome/models/einnet/fsrcnn.bs1.onnx'), 'fsrcnn.bs1'),
|
||||
]
|
||||
|
||||
for original_g, name in graphs:
|
||||
|
@ -119,9 +144,12 @@ if __name__ == "__main__":
|
|||
if True: # Optimization
|
||||
save_onnx(original_g, f"orig_{name}.onnx")
|
||||
g = ft.optimizeGraph(original_g, runtime, False, ft.NMutatorMode.RuleBased,
|
||||
[3, 2, 2, 5, 8, 8, 6, 90])
|
||||
[1, 7, 7, 2, 8, 6, 6]) # G2BMM/GBMM
|
||||
# g = ft.optimizeGraph(original_g, runtime, False, ft.NMutatorMode.RuleBased,
|
||||
# [3, 2, 2, 5, 8, 8, 6, 90]) # Conv2conv
|
||||
# g = ft.optimizeGraph(original_g, runtime, False, ft.NMutatorMode.Normal)
|
||||
|
||||
save_onnx(g, f"opt_{name}.onnx")
|
||||
verify_graphs(runtime, original_g, g)
|
||||
# verify_graphs(runtime, original_g, g)
|
||||
# run_and_evaluate(runtime, original_g)
|
||||
run_and_evaluate(runtime, g)
|
||||
|
|
Loading…
Reference in New Issue