forked from jiuyuan/InfiniTensor
Add: hash match for membound kernels
This commit is contained in:
parent 6d17c4caa2, commit 94730d93b5
@@ -0,0 +1,14 @@
#pragma once

namespace infini {

void conv2dreduce_kernel(float *input, float *bias, float *output, bool PReLU,
                         int n, int h, int w, int f, int r, int s, int oh,
                         int ow, int ph, int pw, int sh, int sw, int dh,
                         int dw);

void convTranspose2dreduce_kernel(float *input, float *bias, float *output,
                                  int act, int n, int h, int w, int f, int r,
                                  int s, int oh, int ow, int ph, int pw, int sh,
                                  int sw, int dh, int dw);

} // namespace infini
@@ -65,6 +65,6 @@ class CudaRuntimeObj : public RuntimeObj {
    void tune(const Graph &graph, bool profiling) const;

    void beginCudaGraphStreamCapture();
-   cudaGraphExec_t endCudaGraphStreamCapture();
+   tuple<cudaGraphExec_t, size_t> endCudaGraphStreamCapture();
};
} // namespace infini
@@ -32,6 +32,7 @@ class MemBoundObj : public OperatorObj {
    int numOutputs() const override { return outputs.size(); }
    const vector<nnet::Tensor> &getNnetInputs() const { return nnetInputs; }
    const nnet::Expr getNnetExpr() const { return expr; }
+   HashType getHash() const { return hash; }
    pair<const nnet::Expr, HashType> getSimplifiedNnetExpr() const {
        return {expr, hash};
    }
@@ -863,6 +863,7 @@ class OnnxStub:
                        name,
                        domain="nnet",
                        expr=backend.membound_expr_of(op),
+                       hash=str(backend.membound_hash_of(op)),
                    )
                )
            else:
@@ -55,7 +55,6 @@ Graph SearchEngine::run(const Graph graph) {
        std::vector<Graph> candidates = search(subGraph);
        std::cout << "[INFO] size: " << candidates.size() << std::endl;
        IT_ASSERT(candidates.size() > 0);
        std::cout << subGraph->toString() << std::endl;
        std::vector<Graph> nextGraphs;
        for (auto lastGraph : bestGraphs) {
            for (auto thisGraph : candidates) {
@@ -2,6 +2,7 @@
#include "core/kernel.h"
#include "core/perf_engine.h"
#include "core/runtime.h"
#include "cuda_profiler_api.h"
#include "operators/conv.h"
#include "operators/matmul.h"
namespace infini {
@@ -32,18 +33,20 @@ CudaRuntimeObj::~CudaRuntimeObj() {
void CudaRuntimeObj::beginCudaGraphStreamCapture() {
    enum cudaStreamCaptureStatus pCaptureStatus;
    checkCudaError(cudaStreamIsCapturing(stream, &pCaptureStatus));
    dbg(pCaptureStatus);
    IT_ASSERT(pCaptureStatus == cudaStreamCaptureStatusNone);
    cudaGraphStatus = true;
    checkCudaError(cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal));
}

-cudaGraphExec_t CudaRuntimeObj::endCudaGraphStreamCapture() {
+tuple<cudaGraphExec_t, size_t> CudaRuntimeObj::endCudaGraphStreamCapture() {
    cudaGraph_t cudaGraph;
    cudaGraphExec_t instance;
    checkCudaError(cudaStreamEndCapture(stream, &cudaGraph));
    cudaGraphStatus = false;
+   size_t numCudaGraphNodes;
+   checkCudaError(cudaGraphGetNodes(cudaGraph, nullptr, &numCudaGraphNodes));
    checkCudaError(cudaGraphInstantiate(&instance, cudaGraph, NULL, NULL, 0));
-   return instance;
+   return {instance, numCudaGraphNodes};
}

void CudaRuntimeObj::runWithoutSync(const Graph &graph) const {
@@ -135,13 +138,8 @@ double CudaRuntimeObj::timeWithCudaGraph(Graph graph) {
            kernel->compute(op, perfData, this);
        else
            kernel->compute(op, this);
-       // if (!ctcMap.at(op->getGuid()) && op->getOpType() != OpType::Reshape)
-       // if (op->getOpType() == OpType::Matmul)
-       // if (op->getOpType() == OpType::Matmul ||
-       //     op->getOpType() == OpType::Relu
-       //     // || op->getOpType() == OpType::MemBound
-       // )
-       kernels.emplace_back(op, kernel, perfData);
+       if (!ctcMap.at(op->getGuid()) && op->getOpType() != OpType::Reshape)
+           kernels.emplace_back(op, kernel, perfData);
    }
    for (auto &[op, kernel, perfData] : kernels) {
        dbg(op);
@@ -154,9 +152,12 @@ double CudaRuntimeObj::timeWithCudaGraph(Graph graph) {
        else
            kernel->compute(op, this);
    }
-   auto cudaGraphInstance = endCudaGraphStreamCapture();
+   auto [cudaGraphInstance, numCudaGraphNodes] = endCudaGraphStreamCapture();
+   IT_ASSERT(numCudaGraphNodes == kernels.size(),
+             std::to_string(numCudaGraphNodes) +
+                 " != " + std::to_string(kernels.size()));
    return timeit(
-       [&, stream = getStream()]() {
+       [&, cudaGraphInstance = cudaGraphInstance, stream = getStream()]() {
            checkCudaError(cudaGraphLaunch(cudaGraphInstance, stream));
        },
        [&, stream = getStream()]() { cudaStreamSynchronize(stream); }, 1000,
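The two hunks above follow the standard CUDA Graph capture-and-replay sequence, and lean on a documented detail of cudaGraphGetNodes: when the node-array argument is null, it writes back the node count instead. A minimal self-contained sketch of the same sequence, with a toy kernel standing in for InfiniTensor's operator launches (not code from this commit):

#include <cstdio>
#include <cuda_runtime.h>

__global__ void scale(float *x, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) x[i] *= 2.f;
}

int main() {
    const int n = 1 << 20;
    float *x;
    cudaMalloc(&x, n * sizeof(float));
    cudaStream_t stream;
    cudaStreamCreate(&stream);

    // Record the launches into a graph instead of executing them.
    cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
    scale<<<(n + 255) / 256, 256, 0, stream>>>(x, n);
    scale<<<(n + 255) / 256, 256, 0, stream>>>(x, n);
    cudaGraph_t graph;
    cudaStreamEndCapture(stream, &graph);

    // Passing nullptr for the node array queries the node count only;
    // this is what lets the commit assert one captured node per kernel.
    size_t numNodes;
    cudaGraphGetNodes(graph, nullptr, &numNodes);
    printf("captured %zu nodes\n", numNodes);

    cudaGraphExec_t exec;
    cudaGraphInstantiate(&exec, graph, nullptr, nullptr, 0);
    cudaGraphLaunch(exec, stream); // replay both launches with one call
    cudaStreamSynchronize(stream);

    cudaGraphExecDestroy(exec);
    cudaGraphDestroy(graph);
    cudaStreamDestroy(stream);
    cudaFree(x);
}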
@@ -263,7 +263,9 @@ void export_functions(py::module &m) {
        .FUNCTION(concat_axis_of)
        .FUNCTION(split_axis_of)
        .FUNCTION(gather_axis_of)
-       .FUNCTION(membound_expr_of);
+       .FUNCTION(membound_expr_of)
+       .def("membound_hash_of",
+            [](Operator op) { return as<MemBoundObj>(op)->getHash(); });
#undef FUNCTION
}
@@ -283,7 +285,8 @@ void init_graph_builder(py::module &m) {
                                     RuntimeObj>(m, "CpuRuntime");
#ifdef USE_CUDA
    py::class_<CudaRuntimeObj, Ref<CudaRuntimeObj>, RuntimeObj>(m,
-                                                               "CudaRuntime");
+                                                               "CudaRuntime")
+       .def("timeWithCudaGraph", &CudaRuntimeObj::timeWithCudaGraph);
#endif
#ifdef USE_BANG
    py::class_<BangRuntimeObj, std::shared_ptr<BangRuntimeObj>, RuntimeObj>(
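These bindings rely on pybind11's chaining: py::class_<...>(m, "CudaRuntime") returns a class_ object, which is why the hunk drops the trailing semicolon before appending a method binding. A standalone sketch of the pattern with a toy class (not the real runtime):

#include <pybind11/pybind11.h>
namespace py = pybind11;

// Hypothetical stand-in for CudaRuntimeObj.
struct ToyRuntime {
    double timeWithGraph(int rounds) { return 0.5 * rounds; }
};

PYBIND11_MODULE(toy, m) {
    py::class_<ToyRuntime>(m, "ToyRuntime")
        .def(py::init<>())
        // Each .def returns the same class_ object, so bindings chain.
        .def("timeWithGraph", &ToyRuntime::timeWithGraph);
}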
@@ -0,0 +1,171 @@
#include "cuda/cuda_common.h"

using dtype = float;

__global__ void conv2dreduce_kernel_(float *__restrict__ input,
                                     float *__restrict__ bias,
                                     float *__restrict__ output,
                                     const bool PReLU, const int n, const int f,
                                     const int h, const int w, const int oh,
                                     const int ow, const int r, const int s,
                                     const int ph, const int pw, const int dh,
                                     const int dw, const int sh, const int sw) {
    // output shape: (n, oh, ow, f)
    // input shape: (n, h, w, f, r, s)
    int nid = blockIdx.x, fid = blockIdx.y;
    int hid = threadIdx.x, wid = threadIdx.y;
    const int fchunck = r * s, wchunk = f * fchunck, hchunk = w * wchunk,
              nchunck = n * hchunk;
    float *nfinput = input + nid * nchunck + fid * fchunck;
    if (nid < n && fid < f && hid < oh && wid < ow) {
        float imm = 0.0;
        int ihst = hid * sh - ph;
        int iwst = wid * sw - pw;
        for (int ri = 0; ri < r; ++ri) {
            for (int si = 0; si < s; ++si) {
                int ihid = ihst + ri * dh;
                int iwid = iwst + si * dw;
                if (ihid >= 0 && ihid < h && iwid >= 0 && iwid < w) {
                    imm += *(nfinput + ihid * hchunk + iwid * wchunk + ri * s +
                             si);
                }
            }
        }
        if (bias) {
            imm += bias[fid];
        }
        if (PReLU) {
            imm = imm > 0.0 ? imm : 0.0;
        }
        output[nid * (oh * ow * f) + hid * (ow * f) + wid * f + fid] = imm;
    }
}

__global__ void convTranspose2dreduce_kernel_(
    float *__restrict__ input, float *__restrict__ bias,
    float *__restrict__ output, const bool PReLU, const int n, const int f,
    const int h, const int w, const int oh, const int ow, const int r,
    const int s, const int ph, const int pw, const int dh, const int dw,
    const int sh, const int sw) {
    // assert dh = dw = 1
    int nid = blockIdx.x, fid = blockIdx.y;
    int hid = threadIdx.x, wid = threadIdx.y;
    const int fchunck = r * s, wchunk = f * fchunck, hchunk = w * wchunk,
              nchunck = n * hchunk;
    float *nfinput = input + nid * nchunck + fid * fchunck;
    // view as conv, the true ph and pw
    int tph = r - ph - 1, tpw = s - pw - 1;
    int th = (h - 1) * sh + 1, tw = (w - 1) * sw + 1;
    if (nid < n && fid < f && hid < oh && wid < ow) {
        float imm = 0.0;
        int ihst = hid - tph;
        int iwst = wid - tpw;
        for (int ri = 0; ri < r; ++ri) {
            for (int si = 0; si < s; ++si) {
                int ihid = ihst + r - ri - 1;
                int iwid = iwst + s - si - 1;
                if (ihid >= 0 && ihid < th && iwid >= 0 && iwid < tw &&
                    (ihid % sh == 0) && (iwid % sw == 0)) {
                    imm += *(nfinput + (ihid / sh) * hchunk +
                             (iwid / sw) * wchunk + ri * s + si);
                }
            }
        }
        if (bias) {
            imm += bias[fid];
        }
        if (PReLU) {
            imm = imm > 0.0 ? imm : 0.0;
        }
        output[nid * (oh * ow * f) + hid * (ow * f) + wid * f + fid] = imm;
    }
}
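To sanity-check the rewrite above with concrete numbers (mine, not from the commit): a transposed convolution with kernel r, stride sh, and padding ph over an axis of length h equals an ordinary convolution with a flipped kernel over the input dilated by sh, padded by tph = r - ph - 1. For h = 8, r = 4, sh = 2, ph = 1, dh = 1:

    th  = (h - 1) * sh + 1          = 15   (dilated input length)
    tph = r - ph - 1                = 2    (equivalent conv padding)
    oh  = (h - 1) * sh - 2 * ph + r = 16   (output length)

which matches the 8x8 -> 16x16 entry in the hash table added further down in this commit.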

// nhwrsc -> nhwc
__global__ void reduce_4x4(dtype *in, dtype *out, int act, const int N,
                           const int F, const int H, const int W, const int IH,
                           const int IW) {
    // #define in_index(n, h, w, r, s, f) \
    //     ((((((n)*IH + h) * IW + w) * R + r) * S + s) * F + f)
#define in_index(n, h, w, f, r, s) \
    ((((((n)*IH + h) * IW + w) * F + f) * R + r) * S + s)
#define out_index(n, h, w, f) (((((n)*H) + (h)) * W + (w)) * F + (f))
    const int R = 4, S = 4;
    const int n_tasks = N * F * H * W;
    int start = threadIdx.x + blockDim.x * blockIdx.x;
    int stride = blockDim.x * gridDim.x;
    for (int i = start; i < n_tasks; i += stride) {
        int t = i, n, f, h, w;
        f = t % F;
        t /= F;
        w = t % W;
        t /= W;
        h = t % H;
        t /= H;
        n = t;

        // unroll this 2-iter loop
        float sum = 0;
        int x, y;
        for (int r = (h + 1) & 1; r < R; r += 2) {
            x = (h + 1 - r) / 2;
            if (x >= 0 && x < IH) {
                for (int s = (w + 1) & 1; s < S; s += 2) {
                    y = (w + 1 - s) / 2;
                    if (y >= 0 && y < IW) {
                        sum += in[in_index(n, x, y, f, r, s)];
                        // if (i==0)
                        //     printf("TTT nhwf= %d,%d,%d,%d x=%d y=%d, v=%f,
                        //     index=%d, rsf %d %d %d\n", n, h, w,
                        //     f, x, y, in[in_index(n, x, y, r, s, f)],
                        //     in_index(n, x, y, r, s, f), r,s,f);
                    }
                }
            }
        }
        if (act == 0) {
            out[out_index(n, h, w, f)] = sum;
        } else if (act == 1) { // Relu
            out[out_index(n, h, w, f)] = sum > 0 ? sum : 0;
        } else if (act == 2) {
            out[out_index(n, h, w, f)] = tanhf(sum);
        }
    }
#undef in_index
#undef out_index
}
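reduce_4x4 above iterates with the common grid-stride loop idiom, so correctness does not depend on the launch supplying exactly n_tasks threads. The idiom in isolation (toy element-wise kernel, not part of this commit):

__global__ void addVec(const float *a, const float *b, float *c, int n) {
    int start = threadIdx.x + blockDim.x * blockIdx.x;
    int stride = blockDim.x * gridDim.x;
    // Each thread covers elements start, start + stride, start + 2*stride, ...
    for (int i = start; i < n; i += stride)
        c[i] = a[i] + b[i];
}

// Any grid size works; surplus threads simply fall out of the loop:
// addVec<<<(n + 127) / 128, 128>>>(a, b, c, n);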

namespace infini {

void conv2dreduce_kernel(float *input, float *bias, float *output, bool PReLU,
                         int n, int h, int w, int f, int r, int s, int oh,
                         int ow, int ph, int pw, int sh, int sw, int dh,
                         int dw) {
    dim3 grid(n, f);
    dim3 block(oh, ow);
    // cudaStream_t stream(cudaStreamPerThread);
    conv2dreduce_kernel_<<<grid, block, 0>>>(input, bias, output, PReLU, n, f,
                                             h, w, oh, ow, r, s, ph, pw, dh, dw,
                                             sh, sw);
}

void convTranspose2dreduce_kernel(float *input, float *bias, float *output,
                                  int act, int n, int h, int w, int f, int r,
                                  int s, int oh, int ow, int ph, int pw, int sh,
                                  int sw, int dh, int dw) {
    dim3 grid(n, f);
    dim3 block(oh, ow);
    // cudaStream_t stream(cudaStreamPerThread);
    // puts("convTranspose2dreduce_kernel is executed");
    if (r == 4 && s == 4 && sh == 2 && sw == 2) {
        const int M = r * s * f, N = n * h * w;
        reduce_4x4<<<(M * N + 127) / 128, 128>>>(input, output, act, n, f, oh,
                                                 ow, h, w);
    } else {
        puts("why use this conv2dreduce");
        convTranspose2dreduce_kernel_<<<grid, block, 0>>>(
            input, bias, output, (bool)act, n, f, h, w, oh, ow, r, s, ph, pw,
            dh, dw, sh, sw);
    }
}
} // namespace infini
@@ -1,5 +1,6 @@
#ifdef INFINI_USE_TVM
#include "core/kernel.h"
+#include "cuda/cuda_conv2dreduce.h"
#include "cuda/cuda_runtime.h"
#include "dlpack/dlpack.h"
#include "ffi/ffi_embed.h"
@@ -13,6 +14,7 @@
#include <sys/types.h>
#include <sys/wait.h>
#include <unistd.h>

using json = nlohmann::json;

namespace py = pybind11;
@@ -29,6 +31,7 @@ class TVMRecordObj : public PerfRecordObj {
    std::string funcName;
    std::vector<int> inputIdx;
    tvm::runtime::PackedFunc packedFunc;
+   bool useExistingKernel = false;
};

using TVMRecord = Ref<TVMRecordObj>;
@@ -40,6 +43,14 @@ class MemboundTVMPackedFunction : public Kernel {
        auto op = as<MemBoundObj>(_op);
        // auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
        auto tvmRecord = std::dynamic_pointer_cast<TVMRecordObj>(record);
+
+       // Use user-defined kernels
+       if (tvmRecord->useExistingKernel) {
+           bool success = useExistingKernels(op);
+           IT_ASSERT(success);
+           return;
+       }
+
        tvm::runtime::PackedFunc packedFunc = tvmRecord->packedFunc;

        // prepare inputs and outputs
@@ -68,10 +79,18 @@ class MemboundTVMPackedFunction : public Kernel {
    // Premise: op is idempotent since it is called multiple times.
    PerfRecord tune(const Operator &_op,
                    const RuntimeObj *_context) const override {
-       TVMRecord ret = std::make_shared<TVMRecordObj>();
        auto op = as<MemBoundObj>(_op);
        auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
+
+       // If hash matches, use user-defined kernels
+       if (useExistingKernels(op)) {
+           TVMRecord ret = std::make_shared<TVMRecordObj>();
+           ret->time = timeit([&]() { useExistingKernels(op); },
+                              [&]() { context->sync(); });
+           ret->useExistingKernel = true;
+           return ret;
+       }

        // invoke Ansor to tune a membound kernel
        auto [expr, hash] = op->getSimplifiedNnetExpr();
        nnet::AsTVMVisitor visitor;
@@ -120,6 +139,7 @@ class MemboundTVMPackedFunction : public Kernel {
        tvm::runtime::TVMArgs args(preArgs.first.data(), preArgs.second.data(),
                                   preArgs.first.size());

+       TVMRecord ret = std::make_shared<TVMRecordObj>();
        ret->time = timeit([&]() { packedFunc.CallPacked(args, &rv); },
                           [&]() { context->sync(); });
        ret->kernelName = kernelName;
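The timeit(run, sync) pattern used here and in timeWithCudaGraph separates the work being measured from the barrier that makes the measurement valid on an asynchronous device. A minimal sketch of such a helper; this assumes nothing about InfiniTensor's actual timeit signature beyond the two callables (the real one also takes warmup/round counts):

#include <chrono>
#include <functional>

// Hypothetical helper: returns mean milliseconds per run.
double timeitSketch(const std::function<void()> &run,
                    const std::function<void()> &sync, int rounds = 100) {
    sync(); // drain pending async work before starting the clock
    auto t0 = std::chrono::high_resolution_clock::now();
    for (int i = 0; i < rounds; ++i)
        run(); // enqueue work (may return before the device finishes)
    sync();    // wait for everything enqueued above
    auto t1 = std::chrono::high_resolution_clock::now();
    return std::chrono::duration<double, std::milli>(t1 - t0).count() / rounds;
}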
@@ -128,7 +148,7 @@ class MemboundTVMPackedFunction : public Kernel {
        ret->inputIdx = inputIdx;
        ret->packedFunc = packedFunc;

-       return std::dynamic_pointer_cast<PerfRecordObj>(ret);
+       return ret;
    }

    std::string serializeTVMArgs(const std::vector<std::vector<int>> &inDims,
@@ -262,6 +282,34 @@ class MemboundTVMPackedFunction : public Kernel {

        return {values, type_codes};
    }
+
+   bool useExistingKernels(Ref<MemBoundObj> op) const {
+       const map<HashType, tuple<int, int, int, int, int, int, int, int, int,
+                                 int, int, int, int, int, int>>
+           hashMap = {
+               // clang-format off
+               {18446744073661354550ULL, {1, 1, 2, 2, 256, 4, 4, 4, 4, 1, 1, 2, 2, 1, 1}},
+               {124145340ULL, {1, 1, 4, 4, 128, 4, 4, 8, 8, 1, 1, 2, 2, 1, 1}},
+               {18446744073695718019ULL, {1, 1, 8, 8, 64, 4, 4, 16, 16, 1, 1, 2, 2, 1, 1}},
+               {515085072ULL, {2, 1, 16, 16, 3, 4, 4, 32, 32, 1, 1, 2, 2, 1, 1}}
+           }; // clang-format on
+       float *input = op->getInputs(0)->getRawDataPtr<float *>();
+       float *bias = nullptr;
+       float *output = op->getOutput()->getRawDataPtr<float *>();
+       if (auto it = hashMap.find(op->getHash()); it != hashMap.end()) {
+           auto &[PReLU, n, h, w, f, r, s, oh, ow, ph, pw, sh, sw, dh, dw] =
+               it->second;
+           IT_ASSERT(op->getInputs(0)->size() ==
+                     size_t(n) * h * w * f * r * s);
+           IT_ASSERT(op->getOutput()->size() == size_t(n) * oh * ow * f);
+           convTranspose2dreduce_kernel(input, bias, output, PReLU, n, h, w, f,
+                                        r, s, oh, ow, ph, pw, sh, sw, dh, dw);
+           return true;
+       }
+       // conv2dreduce_kernel(input, bias, output, PReLU, n, h, w, f, r, s,
+       //                     oh, ow, ph, pw, sh, sw, dh, dw);
+       return false;
+   }
};

REGISTER_KERNEL(Device::CUDA, OpType::MemBound, DataType::Float32,
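This useExistingKernels hunk is the heart of the commit: a lookup table keyed by the simplified expression's hash maps directly to precompiled kernel parameters, so a hit skips Ansor tuning entirely. The dispatch pattern in isolation (standalone sketch; the key values and Params layout here are illustrative, not the real table):

#include <cstdint>
#include <cstdio>
#include <map>
#include <tuple>

using HashType = uint64_t;

// Hypothetical parameter pack: (n, c, h, w) of a known shape.
using Params = std::tuple<int, int, int, int>;

bool dispatchByHash(HashType hash) {
    static const std::map<HashType, Params> table = {
        {124145340ULL, {1, 128, 4, 4}},
        {515085072ULL, {2, 3, 16, 16}},
    };
    if (auto it = table.find(hash); it != table.end()) {
        const auto &[n, c, h, w] = it->second; // unpack the tuple
        std::printf("known kernel: n=%d c=%d h=%d w=%d\n", n, c, h, w);
        return true; // caller skips the expensive tuner
    }
    return false;    // fall back to autotuning
}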
@@ -115,7 +115,7 @@ Graph optimizeGraph(Graph g, Runtime _runtime, bool tuning) {
    runtime->run(g);
    dbg("Baseline graph");
    printGraph(g);
-   dbg(runtime->getPerfTime(g, true));
+   // dbg(runtme->getPerfTime(g, true));

    for (size_t i = 0; i < bestGraphs.size(); i++) {
        auto bestGraphCpu = bestGraphs[i];
@@ -235,7 +235,5 @@ vector<Tensor> runInfoGAN(int nLayers) {
    return {};
}

// TEST(ModelE2E, InfoGAN) { runInfoGAN(); }

} // namespace infini
#endif
@@ -78,23 +78,26 @@ def runSingleConvT():
    ft.if_onnx.export_onnx(opt_g, 'convtransposed.onnx')


-def run_InfoGAN_without_tuning(tuning: bool):
-    runtime = ft.cuda_runtime()
+def run_InfoGAN_without_tuning(runtime, tuning: bool):
    g = ft.getInfoGAN(1, runtime, 5)
    # g = ft.getInfoGAN(1, runtime, 1)
    opt_g = ft.optimizeGraph(g, runtime, tuning)
    stub = OnnxStub.from_graph(opt_g)
    with open("optimized.onnx", "wb") as f:
        f.write(stub.to_onnx("optimized").SerializeToString())
+    return opt_g


-def load_onnx_and_run():
-    runtime = ft.cuda_runtime()
+def load_onnx(runtime) -> ft.Graph:
    stub = OnnxStub.from_onnx(onnx.load("optimized.onnx"), runtime, False)
-    g = stub.handler.getGraph()
+    return stub.handler.getGraph()


+def run_and_evaluate(runtime, g):
    runtime.run(g, True)
    print(f'getPerfTime = {runtime.getPerfTime(g, True, False, False)}')
    print(f'Non-ctc time = {runtime.timeNonCtcOperators(g, 1000, 1000)}')
+    print(f'Cuda graph time = {runtime.timeWithCudaGraph(g)}')


if __name__ == "__main__":
@@ -102,5 +105,9 @@ if __name__ == "__main__":
    # runSingleConvT()
    # read_and_check()

-    # run_InfoGAN_without_tuning(False)
-    load_onnx_and_run()
+    runtime = ft.cuda_runtime()
+    if True:
+        g = run_InfoGAN_without_tuning(runtime, False)
+    else:
+        g = load_onnx(runtime)
+    run_and_evaluate(runtime, g)