forked from jiuyuan/InfiniTensor
Compare commits
11 Commits
master
...
benchmark_
Author | SHA1 | Date |
---|---|---|
bolun | 608f997042 | |
bolun | 97970c5d94 | |
bolun | 4b762cc8d9 | |
bolun | 1c55c74151 | |
bolun | ddaddf375e | |
bolun | 7945693131 | |
bolun | fdb2d30868 | |
zhangyue207 | f532784d4f | |
zhangyue207 | 454b7651a8 | |
zhangyue207 | 48322dbf27 | |
zhangyue207 | 523946cb8b |
|
@ -262,3 +262,19 @@ if(BUILD_TEST)
|
||||||
target_link_libraries(nnet_reader InfiniTensor)
|
target_link_libraries(nnet_reader InfiniTensor)
|
||||||
endif()
|
endif()
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
function(build_bench files)
|
||||||
|
file(GLOB BENCH_SOURCES ${files})
|
||||||
|
foreach(benchsourcefile ${BENCH_SOURCES})
|
||||||
|
get_filename_component(benchname ${benchsourcefile} NAME_WE)
|
||||||
|
add_executable("benchmark_${benchname}" ${benchsourcefile})
|
||||||
|
target_link_libraries("benchmark_${benchname}" InfiniTensor)
|
||||||
|
# add_custom_target(NAME ${benchname} COMMAND ${benchname})
|
||||||
|
endforeach(benchsourcefile ${BENCH_SOURCES})
|
||||||
|
endfunction()
|
||||||
|
|
||||||
|
if (BENCH)
|
||||||
|
if (USE_CUDA)
|
||||||
|
build_bench(benchmark/kernels/cuda/*.cc)
|
||||||
|
endif()
|
||||||
|
endif()
|
||||||
|
|
2
Makefile
2
Makefile
|
@ -6,12 +6,14 @@ BANG ?= OFF
|
||||||
INTELCPU ?= off
|
INTELCPU ?= off
|
||||||
BACKTRACE ?= ON
|
BACKTRACE ?= ON
|
||||||
TEST ?= ON
|
TEST ?= ON
|
||||||
|
BENCH ?= ON
|
||||||
|
|
||||||
CMAKE_OPT = -DCMAKE_BUILD_TYPE=$(TYPE)
|
CMAKE_OPT = -DCMAKE_BUILD_TYPE=$(TYPE)
|
||||||
CMAKE_OPT += -DUSE_CUDA=$(CUDA)
|
CMAKE_OPT += -DUSE_CUDA=$(CUDA)
|
||||||
CMAKE_OPT += -DUSE_BANG=$(BANG)
|
CMAKE_OPT += -DUSE_BANG=$(BANG)
|
||||||
CMAKE_OPT += -DUSE_BACKTRACE=$(BACKTRACE)
|
CMAKE_OPT += -DUSE_BACKTRACE=$(BACKTRACE)
|
||||||
CMAKE_OPT += -DBUILD_TEST=$(TEST)
|
CMAKE_OPT += -DBUILD_TEST=$(TEST)
|
||||||
|
CMAKE_OPT += -DBENCH=$(BENCH)
|
||||||
|
|
||||||
ifeq ($(INTELCPU), ON)
|
ifeq ($(INTELCPU), ON)
|
||||||
CMAKE_OPT += -DUSE_INTELCPU=ON -DCMAKE_CXX_COMPILER=dpcpp
|
CMAKE_OPT += -DUSE_INTELCPU=ON -DCMAKE_CXX_COMPILER=dpcpp
|
||||||
|
|
|
@ -0,0 +1,243 @@
|
||||||
|
#include "core/graph.h"
|
||||||
|
#include "core/kernel.h"
|
||||||
|
#include "core/runtime.h"
|
||||||
|
#include "cuda/cuda_runtime.h"
|
||||||
|
#include "cuda/cuda_utility.h"
|
||||||
|
#include "operators/conv.h"
|
||||||
|
#include "benchmark.h"
|
||||||
|
#include <iostream>
|
||||||
|
#include <cmath>
|
||||||
|
#include <chrono>
|
||||||
|
#include <sys/time.h>
|
||||||
|
|
||||||
|
using namespace infini;
|
||||||
|
|
||||||
|
#define M 1048576
|
||||||
|
|
||||||
|
const char algo_name[8][50] = {
|
||||||
|
"CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM",
|
||||||
|
"CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM",
|
||||||
|
"CUDNN_CONVOLUTION_FWD_ALGO_GEMM",
|
||||||
|
"CUDNN_CONVOLUTION_FWD_ALGO_DIRECT",
|
||||||
|
"CUDNN_CONVOLUTION_FWD_ALGO_FFT",
|
||||||
|
"CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING",
|
||||||
|
"CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD",
|
||||||
|
"CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED",
|
||||||
|
};
|
||||||
|
|
||||||
|
const char mode_name[2][50] = {
|
||||||
|
"CUDNN_CONVOLUTION",
|
||||||
|
"CUDNN_CROSS_CORRELATION"
|
||||||
|
};
|
||||||
|
|
||||||
|
int main() {
|
||||||
|
// Benchmark Settings
|
||||||
|
int warmupRounds = 50;
|
||||||
|
int timingRounds = 100;
|
||||||
|
DataType dtype = DataType::Float32;
|
||||||
|
|
||||||
|
// cudnn Conv Configurations
|
||||||
|
cudnnConvolutionMode_t convMode = CUDNN_CROSS_CORRELATION;
|
||||||
|
cudnnConvolutionFwdAlgo_t convAlgo = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM;
|
||||||
|
float alpha = 1.f, beta = 0.f;
|
||||||
|
|
||||||
|
int n, c, h, w, f, r, s;
|
||||||
|
int INPUT_BATCH_SIZE = n = 16;
|
||||||
|
int INPUT_CHANNELS = c = 128;
|
||||||
|
int INPUT_HEIGHT = h = 128;
|
||||||
|
int INPUT_WIDTH = w = 128;
|
||||||
|
Shape INPUT_SHAPE = {INPUT_BATCH_SIZE, INPUT_CHANNELS, \
|
||||||
|
INPUT_HEIGHT, INPUT_WIDTH};
|
||||||
|
|
||||||
|
int OUTPUT_CHANNELS = f = 256;
|
||||||
|
int KERNEL_HEIGHT = r = 3;
|
||||||
|
int KERNEL_WIDTH = s = 3;
|
||||||
|
Shape KERNEL_SHAPE = {INPUT_CHANNELS, OUTPUT_CHANNELS, \
|
||||||
|
KERNEL_HEIGHT, KERNEL_WIDTH};
|
||||||
|
|
||||||
|
int NUM_GROUPS = 1;
|
||||||
|
|
||||||
|
int PAD_HEIGHT = 0;
|
||||||
|
int PAD_WIDTH = 0;
|
||||||
|
int VERTICAL_STRIDE = 1;
|
||||||
|
int HORIZONTAL_STRIDE = 1;
|
||||||
|
int DILATION_HEIGHT = 1;
|
||||||
|
int DILATION_WIDTH = 1;
|
||||||
|
|
||||||
|
// Get input size
|
||||||
|
size_t inputSize = 1;
|
||||||
|
for (auto dim: INPUT_SHAPE) {
|
||||||
|
inputSize *= dim;
|
||||||
|
}
|
||||||
|
size_t inputSizeInBytes = inputSize * sizeof(dtype);
|
||||||
|
|
||||||
|
// Get kernel size
|
||||||
|
size_t kernelSize = 1;
|
||||||
|
for (auto dim: KERNEL_SHAPE) {
|
||||||
|
kernelSize *= dim;
|
||||||
|
}
|
||||||
|
size_t kernelSizeInBytes = kernelSize * sizeof(dtype);
|
||||||
|
|
||||||
|
// Init time variables
|
||||||
|
double time_memcpy_htod = 0.0, time_memcpy_dtoh = 0.0;
|
||||||
|
double time_op = 0.0;
|
||||||
|
|
||||||
|
// Create runtime
|
||||||
|
Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
|
||||||
|
auto cudaRuntime = make_ref<CudaRuntimeObj>();
|
||||||
|
|
||||||
|
// Build input data and kernel on CPU
|
||||||
|
Tensor inputCpu =
|
||||||
|
make_ref<TensorObj>(INPUT_SHAPE, dtype, cpuRuntime);
|
||||||
|
inputCpu->dataMalloc();
|
||||||
|
inputCpu->setData(RandomGenerator());
|
||||||
|
|
||||||
|
Tensor kernelCpu =
|
||||||
|
make_ref<TensorObj>(KERNEL_SHAPE, dtype, cpuRuntime);
|
||||||
|
kernelCpu->dataMalloc();
|
||||||
|
kernelCpu->setData(RandomGenerator());
|
||||||
|
|
||||||
|
// Build input data and kernel on GPU
|
||||||
|
Tensor inputGpu =
|
||||||
|
make_ref<TensorObj>(INPUT_SHAPE, dtype, cudaRuntime);
|
||||||
|
inputGpu->dataMalloc();
|
||||||
|
|
||||||
|
Tensor kernelGpu =
|
||||||
|
make_ref<TensorObj>(KERNEL_SHAPE, dtype, cudaRuntime);
|
||||||
|
kernelGpu->dataMalloc();
|
||||||
|
|
||||||
|
// Do memcpy host to device
|
||||||
|
time_memcpy_htod += timeit(
|
||||||
|
[&]() {
|
||||||
|
inputGpu = inputCpu->clone(cudaRuntime);
|
||||||
|
kernelGpu = kernelCpu->clone(cudaRuntime);
|
||||||
|
},
|
||||||
|
[&]() { cudaRuntime->sync(); },
|
||||||
|
warmupRounds, timingRounds
|
||||||
|
);
|
||||||
|
|
||||||
|
int channelsPerGrp = INPUT_CHANNELS / NUM_GROUPS;
|
||||||
|
|
||||||
|
// Build cudnn descriptors
|
||||||
|
// input descriptor
|
||||||
|
cudnnTensorDescriptor_t inDesc;
|
||||||
|
checkCudnnError(cudnnCreateTensorDescriptor(&inDesc));
|
||||||
|
checkCudnnError(cudnnSetTensor4dDescriptor(
|
||||||
|
inDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, n, c, h, w));
|
||||||
|
|
||||||
|
// kernel descriptor
|
||||||
|
cudnnFilterDescriptor_t knDesc;
|
||||||
|
checkCudnnError(cudnnCreateFilterDescriptor(&knDesc));
|
||||||
|
checkCudnnError(cudnnSetFilter4dDescriptor(knDesc, CUDNN_DATA_FLOAT,
|
||||||
|
CUDNN_TENSOR_NCHW, f,
|
||||||
|
channelsPerGrp, r, s));
|
||||||
|
|
||||||
|
// bias descriptor
|
||||||
|
// cudnnTensorDescriptor_t biasDesc;
|
||||||
|
// checkCudnnError(cudnnCreateTensorDescriptor(&biasDesc));
|
||||||
|
// checkCudnnError(cudnnSetTensor4dDescriptor(
|
||||||
|
// biasDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 1, f, 1, 1));
|
||||||
|
|
||||||
|
// convlution descriptor
|
||||||
|
cudnnConvolutionDescriptor_t convDesc;
|
||||||
|
checkCudnnError(cudnnCreateConvolutionDescriptor(&convDesc));
|
||||||
|
checkCudnnError(cudnnSetConvolution2dDescriptor(
|
||||||
|
convDesc, PAD_HEIGHT, PAD_WIDTH, VERTICAL_STRIDE, HORIZONTAL_STRIDE,
|
||||||
|
DILATION_HEIGHT, DILATION_WIDTH, convMode, CUDNN_DATA_FLOAT));
|
||||||
|
if (NUM_GROUPS > 1) {
|
||||||
|
checkCudnnError(cudnnSetConvolutionGroupCount(convDesc, NUM_GROUPS));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get output shape
|
||||||
|
int outn, outc, outh, outw;
|
||||||
|
checkCudnnError(cudnnGetConvolution2dForwardOutputDim(
|
||||||
|
convDesc, inDesc, knDesc, &outn, &outc, &outh, &outw));
|
||||||
|
|
||||||
|
// Build output descriptor
|
||||||
|
cudnnTensorDescriptor_t outDesc;
|
||||||
|
checkCudnnError(cudnnCreateTensorDescriptor(&outDesc));
|
||||||
|
checkCudnnError(cudnnSetTensor4dDescriptor(outDesc, CUDNN_TENSOR_NCHW,
|
||||||
|
CUDNN_DATA_FLOAT, outn, outc,
|
||||||
|
outh, outw));
|
||||||
|
|
||||||
|
// Get output size
|
||||||
|
Shape OUTPUT_SHAPE = {outn, outc, outh, outw};
|
||||||
|
size_t outputSize = 1;
|
||||||
|
for (auto dim: OUTPUT_SHAPE) {
|
||||||
|
outputSize *= dim;
|
||||||
|
}
|
||||||
|
size_t outputSizeInBytes = outputSize * sizeof(dtype);
|
||||||
|
|
||||||
|
// Build output data on CPU
|
||||||
|
Tensor outputCpu =
|
||||||
|
make_ref<TensorObj>(OUTPUT_SHAPE, dtype, cpuRuntime);
|
||||||
|
outputCpu->dataMalloc();
|
||||||
|
|
||||||
|
// Build output data on GPU
|
||||||
|
Tensor outputGpu =
|
||||||
|
make_ref<TensorObj>(OUTPUT_SHAPE, dtype, cudaRuntime);
|
||||||
|
outputGpu->dataMalloc();
|
||||||
|
|
||||||
|
// Get workspace size
|
||||||
|
size_t workspaceSize = 0;
|
||||||
|
checkCudnnError(cudnnGetConvolutionForwardWorkspaceSize(
|
||||||
|
cudaRuntime->cudnnHandle(), inDesc, knDesc, convDesc,
|
||||||
|
outDesc, convAlgo, &workspaceSize));
|
||||||
|
|
||||||
|
CudaPtr workspace = cudaRuntime->getWorkspace(workspaceSize);
|
||||||
|
|
||||||
|
// Do forward
|
||||||
|
time_op += timeit(
|
||||||
|
[&]() {
|
||||||
|
cudnnConvolutionForward(cudaRuntime->cudnnHandle(), &alpha,
|
||||||
|
inDesc, inputGpu->getRawDataPtr<void *>(),
|
||||||
|
knDesc, kernelGpu->getRawDataPtr<void *>(),
|
||||||
|
convDesc, convAlgo, workspace,
|
||||||
|
workspaceSize, &beta,
|
||||||
|
outDesc, outputGpu->getRawDataPtr<void *>());
|
||||||
|
},
|
||||||
|
[&]() { cudaRuntime->sync(); },
|
||||||
|
warmupRounds, timingRounds
|
||||||
|
);
|
||||||
|
|
||||||
|
checkCudnnError(cudnnDestroyTensorDescriptor(outDesc));
|
||||||
|
checkCudnnError(cudnnDestroyConvolutionDescriptor(convDesc));
|
||||||
|
// checkCudnnError(cudnnDestroyTensorDescriptor(biasDesc));
|
||||||
|
checkCudnnError(cudnnDestroyFilterDescriptor(knDesc));
|
||||||
|
checkCudnnError(cudnnDestroyTensorDescriptor(inDesc));
|
||||||
|
|
||||||
|
// Do memcpy device to host
|
||||||
|
time_memcpy_dtoh += timeit(
|
||||||
|
[&]() {
|
||||||
|
outputCpu = outputGpu->clone(cpuRuntime);
|
||||||
|
},
|
||||||
|
[&]() { cudaRuntime->sync(); },
|
||||||
|
warmupRounds, timingRounds
|
||||||
|
);
|
||||||
|
|
||||||
|
// Print Results
|
||||||
|
printf("Operator - Convolution:\n");
|
||||||
|
printf("Conv Algo: %s\n", algo_name[convAlgo]);
|
||||||
|
printf("Conv Mode: %s\n", mode_name[convMode]);
|
||||||
|
printf("Input shape: (%d, %d, %d, %d)\n",
|
||||||
|
INPUT_SHAPE[0], INPUT_SHAPE[1], INPUT_SHAPE[2], INPUT_SHAPE[3]);
|
||||||
|
printf("Kernel shape: (%d, %d, %d, %d)\n",
|
||||||
|
KERNEL_SHAPE[0], KERNEL_SHAPE[1], KERNEL_SHAPE[2], KERNEL_SHAPE[3]);
|
||||||
|
printf("Output shape: (%d, %d, %d, %d)\n",
|
||||||
|
OUTPUT_SHAPE[0], OUTPUT_SHAPE[1], OUTPUT_SHAPE[2], OUTPUT_SHAPE[3]);
|
||||||
|
printf("Workspace size: %ld Bytes, dtype: %s\n",
|
||||||
|
workspaceSize, dtype.toString().c_str());
|
||||||
|
|
||||||
|
printf("TFlops: %.5lf tflops\n",
|
||||||
|
2.0 * INPUT_BATCH_SIZE * channelsPerGrp * outh * outw * \
|
||||||
|
OUTPUT_CHANNELS * KERNEL_HEIGHT * KERNEL_WIDTH / \
|
||||||
|
VERTICAL_STRIDE / HORIZONTAL_STRIDE / 1e9 / time_op);
|
||||||
|
printf("Memcpy time: h2d - %.6lf ms, d2h - %.6lf ms\n",
|
||||||
|
time_memcpy_htod, time_memcpy_dtoh);
|
||||||
|
printf("Memcpy throughput: h2d - %.6lf MB/ms, d2h: %.6lf MB/ms\n",
|
||||||
|
(inputSizeInBytes + kernelSizeInBytes) / M / time_memcpy_htod,
|
||||||
|
outputSizeInBytes / M / time_memcpy_dtoh);
|
||||||
|
printf("Operation: %.6lf ms\n", time_op);
|
||||||
|
|
||||||
|
return 0;
|
||||||
|
}
|
|
@ -0,0 +1,126 @@
|
||||||
|
#include "core/graph.h"
|
||||||
|
#include "core/kernel.h"
|
||||||
|
#include "core/runtime.h"
|
||||||
|
#include "cuda/cuda_runtime.h"
|
||||||
|
#include "cuda/cuda_utility.h"
|
||||||
|
#include "operators/softmax.h"
|
||||||
|
#include "benchmark.h"
|
||||||
|
#include <iostream>
|
||||||
|
#include <cmath>
|
||||||
|
#include <chrono>
|
||||||
|
#include <sys/time.h>
|
||||||
|
|
||||||
|
using namespace infini;
|
||||||
|
|
||||||
|
#define M 1048576
|
||||||
|
|
||||||
|
int main() {
|
||||||
|
|
||||||
|
// Benchmark Settings
|
||||||
|
int warmupRounds = 200;
|
||||||
|
int timingRounds = 200;
|
||||||
|
Shape INPUT_SHAPE = {16, 3, 128, 128};
|
||||||
|
DataType dtype = DataType::Float32;
|
||||||
|
|
||||||
|
// Get data size
|
||||||
|
size_t size = 1;
|
||||||
|
for (auto dim: INPUT_SHAPE) {
|
||||||
|
size *= dim;
|
||||||
|
}
|
||||||
|
size_t sizeInBytes = size * sizeof(dtype);
|
||||||
|
|
||||||
|
// Init time variables
|
||||||
|
double time_memcpy_htod = 0.0, time_memcpy_dtoh = 0.0;
|
||||||
|
double time_op = 0.0;
|
||||||
|
|
||||||
|
// Create runtime
|
||||||
|
Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
|
||||||
|
auto cudaRuntime = make_ref<CudaRuntimeObj>();
|
||||||
|
|
||||||
|
// Build input data on CPU
|
||||||
|
Tensor inputCpu =
|
||||||
|
make_ref<TensorObj>(INPUT_SHAPE, dtype, cpuRuntime);
|
||||||
|
inputCpu->dataMalloc();
|
||||||
|
inputCpu->setData(RandomGenerator());
|
||||||
|
|
||||||
|
// Build input data on GPU
|
||||||
|
Tensor inputGpu =
|
||||||
|
make_ref<TensorObj>(INPUT_SHAPE, dtype, cudaRuntime);
|
||||||
|
inputGpu->dataMalloc();
|
||||||
|
|
||||||
|
// Do memcpy host to device
|
||||||
|
time_memcpy_htod += timeit(
|
||||||
|
[&]() {
|
||||||
|
inputGpu = inputCpu->clone(cudaRuntime);
|
||||||
|
},
|
||||||
|
[&]() { cudaRuntime->sync(); },
|
||||||
|
warmupRounds, timingRounds
|
||||||
|
);
|
||||||
|
|
||||||
|
// Build output data on CPU
|
||||||
|
auto outputGpu = inputGpu->clone(cudaRuntime);
|
||||||
|
|
||||||
|
// Build output data on GPU
|
||||||
|
Tensor outputCpu =
|
||||||
|
make_ref<TensorObj>(INPUT_SHAPE, dtype, cpuRuntime);
|
||||||
|
outputCpu->dataMalloc();
|
||||||
|
|
||||||
|
// Build cudnn descriptors
|
||||||
|
cudnnTensorDescriptor_t inputDesc, outputDesc;
|
||||||
|
|
||||||
|
// input descriptor
|
||||||
|
checkCudnnError(cudnnCreateTensorDescriptor(&inputDesc));
|
||||||
|
checkCudnnError(cudnnSetTensor4dDescriptor(
|
||||||
|
inputDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, INPUT_SHAPE[0],
|
||||||
|
INPUT_SHAPE[1], INPUT_SHAPE[2], INPUT_SHAPE[3]));
|
||||||
|
|
||||||
|
// output descriptor
|
||||||
|
checkCudnnError(cudnnCreateTensorDescriptor(&outputDesc));
|
||||||
|
checkCudnnError(cudnnSetTensor4dDescriptor(
|
||||||
|
outputDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, INPUT_SHAPE[0],
|
||||||
|
INPUT_SHAPE[1], INPUT_SHAPE[2], INPUT_SHAPE[3]));
|
||||||
|
|
||||||
|
// cudnn operator settings
|
||||||
|
float alpha = 1.0, beta = 0.0;
|
||||||
|
cudnnSoftmaxAlgorithm_t algo = CUDNN_SOFTMAX_FAST;
|
||||||
|
cudnnSoftmaxMode_t mode = CUDNN_SOFTMAX_MODE_INSTANCE;
|
||||||
|
|
||||||
|
// Do forward
|
||||||
|
time_op += timeit(
|
||||||
|
[&]() {
|
||||||
|
cudnnSoftmaxForward(cudaRuntime->cudnnHandle(), algo, mode,
|
||||||
|
&alpha, inputDesc, inputGpu->getRawDataPtr<void *>(),
|
||||||
|
&beta, outputDesc, outputGpu->getRawDataPtr<void *>());
|
||||||
|
},
|
||||||
|
[&]() { cudaRuntime->sync(); },
|
||||||
|
warmupRounds, timingRounds
|
||||||
|
);
|
||||||
|
|
||||||
|
checkCudnnError(cudnnDestroyTensorDescriptor(inputDesc));
|
||||||
|
checkCudnnError(cudnnDestroyTensorDescriptor(outputDesc));
|
||||||
|
|
||||||
|
// Do memcpy device to host
|
||||||
|
time_memcpy_dtoh += timeit(
|
||||||
|
[&]() {
|
||||||
|
outputCpu = outputGpu->clone(cpuRuntime);
|
||||||
|
},
|
||||||
|
[&]() { cudaRuntime->sync(); },
|
||||||
|
warmupRounds, timingRounds
|
||||||
|
);
|
||||||
|
|
||||||
|
// Print Results
|
||||||
|
printf("Operator - Softmax:\n");
|
||||||
|
printf("Input shape: (%d, %d, %d, %d)\n",
|
||||||
|
INPUT_SHAPE[0], INPUT_SHAPE[1], INPUT_SHAPE[2], INPUT_SHAPE[3]);
|
||||||
|
printf("Input size: %ld, dtype: %s, size in bytes: %ld\n",
|
||||||
|
size, dtype.toString().c_str(), sizeInBytes);
|
||||||
|
|
||||||
|
printf("TFlops: %.5lf tflops\n", 5 * size / 1e9 / time_op);
|
||||||
|
printf("Memcpy time: h2d - %.6lf ms, d2h - %.6lf ms\n",
|
||||||
|
time_memcpy_htod, time_memcpy_dtoh);
|
||||||
|
printf("Memcpy throughput: h2d - %.6lf MB/ms, d2h: %.6lf MB/ms\n",
|
||||||
|
sizeInBytes / M / time_memcpy_htod, sizeInBytes / M / time_memcpy_dtoh);
|
||||||
|
printf("Operation: %.6lf ms\n", time_op);
|
||||||
|
|
||||||
|
return 0;
|
||||||
|
}
|
|
@ -0,0 +1,4 @@
|
||||||
|
#pragma once
|
||||||
|
#include "core/common.h"
|
||||||
|
#include "core/tensor_base.h"
|
||||||
|
#include "utils/data_generator.h"
|
Loading…
Reference in New Issue