forked from jiuyuan/InfiniTensor
Merge branch 'master' into dev-onnx
This commit is contained in:
commit
0f52d04882
@@ -39,17 +39,18 @@ using HashType = uint64_t; // compatible with std::hash
 #define _VA_SELECT(NAME, ...) _SELECT(NAME, _VA_SIZE(__VA_ARGS__))(__VA_ARGS__)
 
 // Assert: conditions should have no side effect
-#define _IT_ASSERT_2(name, info)                                        \
-    (static_cast<bool>(name)                                            \
+#define _IT_ASSERT_2(condition, info)                                   \
+    (static_cast<bool>(condition)                                       \
         ? void(0)                                                       \
         : throw ::infini::Exception(                                    \
              std::string("[") + __FILE__ + ":" + std::to_string(__LINE__) + \
-             "] Assertion failed (" + #name + "): " + info))
-#define _IT_ASSERT_1(name) _IT_ASSERT_2(name, "");
+             "] Assertion failed (" + #condition + "): " + info))
+#define _IT_ASSERT_1(condition) _IT_ASSERT_2(condition, "");
 #define IT_ASSERT(...) _VA_SELECT(_IT_ASSERT, __VA_ARGS__)
 
 #define IT_TODO_HALT() _IT_ASSERT_2(false, "Unimplemented")
+#define IT_TODO_HALT_MSG(msg) _IT_ASSERT_2(false, msg)
 #define IT_ASSERT_TODO(condition) _IT_ASSERT_2(condition, "Unimplemented")
 #define IT_TODO_SKIP() puts("Unimplemented " __FILE__ ":" __LINE__)
 
 // Other utilities
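(For context: the variadic dispatch above lets IT_ASSERT take either one or two arguments. A minimal illustrative sketch of possible call sites, not part of this diff:)

    IT_ASSERT(ptr != nullptr);                       // one argument  -> _IT_ASSERT_1(condition)
    IT_ASSERT(dims.size() == 4, "expected 4-D");     // two arguments -> _IT_ASSERT_2(condition, info)
    IT_TODO_HALT_MSG("NHWC bias is unimplemented");  // new macro: always throws with a custom message
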
@@ -9,6 +9,7 @@ enum class OpType {
     Conv = 100,
     Matmul,
     ConvTrans,
+    ConvTransNHWC,
     G2BMM,
     GBMM,
     Pad,
@@ -84,6 +85,7 @@ class OpRegistry {
         FOP(Sigmoid);
         FOP(Tanh);
         FOP(Abs);
+        FOP(ConvTransNHWC);
         //
         FOP(MemBound);
     default:
@@ -47,13 +47,7 @@ class TensorObj : public TensorBaseObj {
     void copyData(const TensorObj *src);
     void copyData(const Tensor &src) { copyData(src.get()); }
     void setData(
-        const std::function<void(void *, size_t, DataType)> &generator) const {
-        IT_ASSERT(data != nullptr);
-        if (!runtime->isCpu()) {
-            IT_TODO_HALT();
-        }
-        generator(data->getPtr<void *>(), size(), dtype);
-    }
+        const std::function<void(void *, size_t, DataType)> &generator) const;
     Tensor clone() const {
         auto obj = make_ref<TensorObj>(*this);
         obj->freeData();
@@ -49,6 +49,8 @@ class ConvBaseObj : public OperatorObj {
     int f;    // output/input channel for conv2d/convTransposed2d
     int r, s; // weight shape
+
+    ActType act;
 
   public:
     /**
      * @brief Construct a new ConvBase object by explicitly setting padding
@@ -70,7 +72,7 @@ class ConvBaseObj : public OperatorObj {
      */
    ConvBaseObj(OpType opType, TensorVec inputs, Tensor &output, int ph, int pw,
                 int sh, int sw, int dh, int dw, const Tensor &inputInConvFWD,
-                const Tensor &weightInConvFWD);
+                const Tensor &weightInConvFWD, ActType act = ActType::None);
     /**
      * @brief Construct a new ConvBase object by setting padding mode.
      *
@@ -89,7 +91,8 @@ class ConvBaseObj : public OperatorObj {
      */
     ConvBaseObj(OpType opType, TensorVec inputs, Tensor &output,
                 PaddingMode mode, int sh, int sw, int dh, int dw,
-                const Tensor &inputInConvFWD, const Tensor &weightInConvFWD);
+                const Tensor &inputInConvFWD, const Tensor &weightInConvFWD,
+                ActType act = ActType::None);
 
     std::string toString() const override;
     int numInputs() const override { return 2; }
@@ -107,7 +110,14 @@ class ConvBaseObj : public OperatorObj {
     int getSw() const { return sw; }
     auto getNCHWFRS() const { return tuple(n, c, h, w, f, r, s); }
     auto getPadStrideDilation() const { return tuple(ph, pw, sh, sw, dh, dw); }
-    int getChannelPerGroup() const { return inputs[1]->getDims()[1]; }
+    int getChannelPerGroup() const {
+        if (type == OpType::ConvTransNHWC) {
+            return inputs[1]->getDims()[3];
+        } else {
+            return inputs[1]->getDims()[1];
+        }
+    }
+    ActType getAct() const { return act; }
     virtual int getNumGroups() const = 0;
 
   private:
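(Illustrative note, not part of the diff: the dimension index differs because the two op types assume different weight layouts.)

    // Weight layouts behind getChannelPerGroup() above:
    //   ConvTrans / Conv (NCHW-style weights): {f, c/group, r, s} -> channels per group at dims()[1]
    //   ConvTransNHWC    (NHWC-style weights): {f, r, s, c/group} -> channels per group at dims()[3]
    // e.g. the NHWC mutator test below builds weights {448, 4, 4, 256}, so dims()[3] == 256.
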
@@ -121,9 +131,6 @@ class ConvBaseObj : public OperatorObj {
 };
 
 class ConvObj : public ConvBaseObj {
-  private:
-    ActType act;
-
   public:
     ConvObj(GraphObj *graph, Tensor input, Tensor weight, Tensor output, int ph,
             int pw, int sh = 1, int sw = 1, int dh = 1, int dw = 1,
@@ -136,7 +143,6 @@ class ConvObj : public ConvBaseObj {
     OP_CLONE(ConvObj);
 
     optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
-    ActType getAct() const { return act; }
     int getNumGroups() const override { return c / getChannelPerGroup(); }
 
   private:
@@ -147,7 +153,6 @@ class ConvTransposed2dObj : public ConvBaseObj {
   private:
     int oph, opw;
     int group;
-    ActType act;
 
   public:
     ConvTransposed2dObj(GraphObj *graph, Tensor input, Tensor weight,
@@ -164,7 +169,32 @@ class ConvTransposed2dObj : public ConvBaseObj {
     OP_CLONE(ConvTransposed2dObj);
 
     optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
-    ActType getAct() const { return act; }
     int getNumGroups() const override { return group; }
 
   private:
+    void setAuxilaryAttributes(PaddingMode mode) override;
+};
+
+class ConvTransposed2dNHWCObj : public ConvBaseObj {
+  private:
+    int oph, opw;
+    int group;
+
+  public:
+    ConvTransposed2dNHWCObj(GraphObj *graph, Tensor input, Tensor weight,
+                            Tensor output, int ph, int pw, int sh = 1,
+                            int sw = 1, int dh = 1, int dw = 1, int oph = 0,
+                            int opw = 0, int group = 1, Tensor bias = nullptr,
+                            ActType act = ActType::None);
+    // Constructors for setting padding mode
+    ConvTransposed2dNHWCObj(GraphObj *graph, Tensor input, Tensor weight,
+                            Tensor output, PaddingMode mode = PaddingMode::Same,
+                            int sh = 1, int sw = 1, int dh = 1, int dw = 1,
+                            int oph = 0, int opw = 0, int group = 1,
+                            Tensor bias = nullptr, ActType act = ActType::None);
+    OP_CLONE(ConvTransposed2dNHWCObj);
+
+    optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
+    int getNumGroups() const override { return group; }
+
+  private:
@@ -165,6 +165,22 @@ void TensorObj::copyData(const TensorObj *src) {
     runtime->copyBlob(this, src);
 }
 
+void TensorObj::setData(
+    const std::function<void(void *, size_t, DataType)> &generator) const {
+    IT_ASSERT(data != nullptr);
+    if (runtime->isCpu()) {
+        generator(getRawDataPtr<void *>(), size(), dtype);
+    } else {
+        // Create a CPU buffer for the generator and copy results to the device
+        auto cpuRuntime = CpuRuntimeObj::getInstance();
+        size_t nBytes = size() * dtype.getSize();
+        Blob buffer = cpuRuntime->allocBlob(nBytes);
+        generator(buffer->getPtr<void *>(), size(), dtype);
+        runtime->copyBlobFromCPU(getRawDataPtr<void *>(),
+                                 buffer->getPtr<void *>(), nBytes);
+    }
+}
+
 void TensorObj::load(std::string file_path) { loadTensorData(this, file_path); }
 
 void TensorObj::save(std::string file_path) { saveTensorData(this, file_path); }
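(Illustrative usage, mirroring the tests later in this commit; the generator fills a buffer in place, and the copy to the device happens inside setData() whenever the tensor's runtime is not the CPU:)

    Tensor i0 = gCuda->addTensor({1, 2, 2, 448}, DataType::Float32);
    gCuda->dataMalloc();
    i0->setData(IncrementalGenerator()); // generated into a CPU buffer, then copied via copyBlobFromCPU
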
@@ -26,6 +26,7 @@ static const cudnnConvolutionBwdDataAlgo_t ALGOS[N_ALGO] = {
     CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD,
     CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED};
 static const char algo_name[N_ALGO][50] = {
+    // only first two can be used for NHWC format
     "CUDNN_CONVOLUTION_BWD_DATA_ALGO_0", /* non-deterministic */
     "CUDNN_CONVOLUTION_BWD_DATA_ALGO_1",
     "CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT",
@@ -46,7 +47,7 @@ class convBackwardDataCudnn : public Kernel {
                cudnnConvolutionDescriptor_t, cudnnActivationDescriptor_t,
                cudnnTensorDescriptor_t>
     createCuDNNDescriptor(
-        const Ref<ConvTransposed2dObj> &op,
+        const Ref<ConvBaseObj> &op,
         const ConvTransposedCuDnnPerfRecordObj &record) const {
         void *const inData = (op->getInputs(0)->getRawDataPtr<void *>());
         void *const knData = (op->getInputs(1)->getRawDataPtr<void *>());
|
@ -62,23 +63,27 @@ class convBackwardDataCudnn : public Kernel {
|
|||
const auto [ph, pw, sh, sw, dh, dw] = op->getPadStrideDilation();
|
||||
// IT_ASSERT(g == 1, "Group convolution is not supported yet");
|
||||
|
||||
// set input format
|
||||
cudnnTensorFormat_t tensorFormat =
|
||||
(op->getOpType() == OpType::ConvTransNHWC) ? CUDNN_TENSOR_NHWC
|
||||
: CUDNN_TENSOR_NCHW;
|
||||
|
||||
// get inputs
|
||||
cudnnTensorDescriptor_t inDesc;
|
||||
checkCudnnError(cudnnCreateTensorDescriptor(&inDesc));
|
||||
checkCudnnError(cudnnSetTensor4dDescriptor(
|
||||
inDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, n, f, h, w));
|
||||
inDesc, tensorFormat, CUDNN_DATA_FLOAT, n, f, h, w));
|
||||
|
||||
// get kernels
|
||||
cudnnFilterDescriptor_t knDesc;
|
||||
checkCudnnError(cudnnCreateFilterDescriptor(&knDesc));
|
||||
checkCudnnError(cudnnSetFilter4dDescriptor(knDesc, CUDNN_DATA_FLOAT,
|
||||
CUDNN_TENSOR_NCHW, f,
|
||||
channelsPerGrp, r, s));
|
||||
checkCudnnError(cudnnSetFilter4dDescriptor(
|
||||
knDesc, CUDNN_DATA_FLOAT, tensorFormat, f, channelsPerGrp, r, s));
|
||||
// get bias
|
||||
cudnnTensorDescriptor_t biasDesc;
|
||||
checkCudnnError(cudnnCreateTensorDescriptor(&biasDesc));
|
||||
checkCudnnError(cudnnSetTensor4dDescriptor(
|
||||
biasDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 1, f, 1, 1));
|
||||
biasDesc, tensorFormat, CUDNN_DATA_FLOAT, 1, f, 1, 1));
|
||||
|
||||
// get convlution descriptor
|
||||
cudnnConvolutionDescriptor_t convDesc;
|
||||
|
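(Side note, based on the public cuDNN API rather than on this diff: cudnnSetTensor4dDescriptor(desc, format, dataType, n, c, h, w) always takes the sizes in logical N, C, H, W order; the format argument only selects the memory layout. That is why the same (n, f, h, w) values are passed above for both CUDNN_TENSOR_NCHW and CUDNN_TENSOR_NHWC. A minimal standalone sketch:)

    // Describe a hypothetical 1x64x32x32 tensor whose memory is laid out as NHWC.
    cudnnTensorDescriptor_t desc;
    checkCudnnError(cudnnCreateTensorDescriptor(&desc));
    checkCudnnError(cudnnSetTensor4dDescriptor(desc, CUDNN_TENSOR_NHWC, CUDNN_DATA_FLOAT,
                                               /*n=*/1, /*c=*/64, /*h=*/32, /*w=*/32));
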
@@ -115,16 +120,27 @@ class convBackwardDataCudnn : public Kernel {
         }
 
         const auto &outputShape = op->getOutput()->getDims();
+        int on, oh, ow, oc;
+        if (op->getOpType() == OpType::ConvTransNHWC) {
+            on = outputShape[0];
+            oh = outputShape[1];
+            ow = outputShape[2];
+            oc = outputShape[3];
+        } else {
+            on = outputShape[0];
+            oh = outputShape[2];
+            ow = outputShape[3];
+            oc = outputShape[1];
+        }
         cudnnTensorDescriptor_t outDesc;
         checkCudnnError(cudnnCreateTensorDescriptor(&outDesc));
         checkCudnnError(cudnnSetTensor4dDescriptor(
-            outDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, outputShape[0],
-            outputShape[1], outputShape[2], outputShape[3]));
+            outDesc, tensorFormat, CUDNN_DATA_FLOAT, on, oc, oh, ow));
         return tuple(inData, knData, outData, inDesc, knDesc, biasDesc,
                      convDesc, actDesc, outDesc);
     }
 
-    bool cuDNNUnfused(const Ref<ConvTransposed2dObj> &op,
+    bool cuDNNUnfused(const Ref<ConvBaseObj> &op,
                       const ConvTransposedCuDnnPerfRecordObj &record,
                       const CudaRuntimeObj *context) const {
         cudnnStatus_t stat;
@@ -211,12 +227,14 @@ class convBackwardDataCudnn : public Kernel {
         ConvTransposedCuDnnPerfRecordObj ret;
         ret.time = std::numeric_limits<double>::max();
         auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
-        auto op = as<ConvTransposed2dObj>(_op);
+        auto op = as<ConvBaseObj>(_op);
         // Both modes have the same performance. Only run
         // cross-correlation.
+        int algo_to_run =
+            (op->getOpType() == OpType::ConvTransNHWC) ? 2 : N_ALGO;
         for (int mode = 1; mode < 2; mode++) {
             // Try every possible algorithm of convolution
-            for (int algo = 0; algo < N_ALGO; algo++) {
+            for (int algo = 0; algo < algo_to_run; algo++) {
                 ConvTransposedCuDnnPerfRecordObj record;
                 record.mode = mode;
                 record.algo = algo;
@@ -274,7 +292,7 @@ class convBackwardDataCudnn : public Kernel {
 
     void compute(const Operator &_op, const PerfRecord &_record,
                  const RuntimeObj *_context) const override {
-        auto op = as<ConvTransposed2dObj>(_op);
+        auto op = as<ConvBaseObj>(_op);
         auto record = as<ConvTransposedCuDnnPerfRecordObj>(_record);
         auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
         bool success = cuDNNUnfused(op, *record, context);
@@ -284,5 +302,6 @@ class convBackwardDataCudnn : public Kernel {
 
 REGISTER_KERNEL(Device::CUDA, OpType::ConvTrans, DataType::Float32,
                 convBackwardDataCudnn, "ConvTranposed_cuDNN_CUDA_Float32");
-
+REGISTER_KERNEL(Device::CUDA, OpType::ConvTransNHWC, DataType::Float32,
+                convBackwardDataCudnn, "ConvTranposedNHWC_cuDNN_CUDA_Float32");
 } // namespace infini
@@ -245,26 +245,26 @@ nnet::Expr NMutator::opToExpression(Operator op) {
                                         std::vector<int>{0, 0, ph, pw});
         const auto K = nnet::makeTensor("K", KT->getDims());
         return nnet::ConvPattern::getExpr(A, K, n, c, h, w, f, r, s);
-        // } else if (auto convOp = dynamic_cast<ConvTransOp *>(op)) {
-        //     const auto &AT = convOp->getInputs()[0];
-        //     const auto &KT = convOp->getInputs()[1];
-        //     inputsNameNToTensorT["A"] = AT;
-        //     inputsNameNToTensorT["K"] = KT;
-        //     const auto &[n, c, h, w, f, r, s, ph, pw, sh, sw, dh, dw, g, bi,
-        //     ac]
-        //     =
-        //         convOp->getArgs(0);
-        //     if (r != 4) {
-        //         dbg("ConvTranspose R!=4. Skipped.", r);
-        //         return nullptr;
-        //     }
-        //     int padding = 1 * (r - 1) - 1;
-        //     const auto A = nnet::makeTensor(
-        //         "A", AT->getDims(), std::vector<int>{0, padding, padding,
-        //         0});
-        //     const auto K = nnet::makeTensor("K", KT->getDims());
-        //     return nnet::ConvTransPattern::getExpr(A, K, n, c, h, w, f, r,
-        //     s);
+    } else if (auto convOp = as<ConvTransposed2dObj>(op)) {
+        const auto &AT = convOp->getInputs()[0];
+        const auto &KT = convOp->getInputs()[1];
+        inputsNameNToTensorT["A"] = AT;
+        inputsNameNToTensorT["K"] = KT;
+        const auto &[n, c, h, w, f, r, s] = convOp->getNCHWFRS();
+        const auto &[ph, pw, sh, sw, dh, dw] = convOp->getPadStrideDilation();
+        IT_ASSERT_TODO(convOp->getNumGroups() == 1);
+        IT_ASSERT_TODO(r == 4);
+        IT_ASSERT_TODO(ph == pw);
+        IT_ASSERT_TODO(tie(sh, sw) == tuple(2, 2));
+        IT_ASSERT_TODO(tie(dh, dw) == tuple(1, 1));
+
+        // https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html
+        // Real padding = dilation * (kernel_size - 1) - padding
+        int padding = dh * (r - 1) - ph;
+        const auto A = nnet::makeTensor(
+            "A", AT->getDims(), std::vector<int>{0, padding, padding, 0});
+        const auto K = nnet::makeTensor("K", KT->getDims());
+        return nnet::ConvTransPattern::getExpr(A, K, n, c, h, w, f, r, s);
         // } else if (auto g2bmmOp = dynamic_cast<G2BMMOp *>(op)) {
         //     const auto &AT = g2bmmOp->getInputs()[0];
         //     const auto &BT = g2bmmOp->getInputs()[1];
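(Worked example of the padding formula above, using the values from the InfoGAN TConv mutator test at the end of this commit: r = 4, dh = 1, ph = 1 give padding = 1 * (4 - 1) - 1 = 2, the same value the removed legacy branch hard-coded as `1 * (r - 1) - 1`.)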
@@ -5,15 +5,15 @@ namespace infini {
 ConvBaseObj::ConvBaseObj(OpType opType, TensorVec inputs, Tensor &output,
                          int ph, int pw, int sh, int sw, int dh, int dw,
                          const Tensor &inputInConvFWD,
-                         const Tensor &weightInConvFWD)
+                         const Tensor &weightInConvFWD, ActType act)
     : OperatorObj(opType, inputs, {output}), ph(ph), pw(pw), sh(sh), sw(sw),
-      dh(dh), dw(dw), padding(PaddingMode::Other) {}
+      dh(dh), dw(dw), padding(PaddingMode::Other), act(act) {}
 ConvBaseObj::ConvBaseObj(OpType opType, TensorVec inputs, Tensor &output,
                          PaddingMode mode, int sh, int sw, int dh, int dw,
                          const Tensor &inputInConvFWD,
-                         const Tensor &weightInConvFWD)
+                         const Tensor &weightInConvFWD, ActType act)
     : OperatorObj(opType, inputs, {output}), ph(-1), pw(-1), sh(sh), sw(sw),
-      dh(dh), dw(dw), padding(mode) {
+      dh(dh), dw(dw), padding(mode), act(act) {
     IT_ASSERT(mode != PaddingMode::Other);
 }
 
@@ -65,8 +65,7 @@ ConvObj::ConvObj(GraphObj *graph, Tensor input, Tensor weight, Tensor output,
                  int ph, int pw, int sh, int sw, int dh, int dw, Tensor bias,
                  ActType act)
     : ConvBaseObj(OpType::Conv, {input, weight}, output, ph, pw, sh, sw, dh, dw,
-                  input, weight),
-      act(act) {
+                  input, weight, act) {
     if (bias)
         IT_TODO_HALT();
     setAuxilaryAttributes(PaddingMode::Other);
@@ -77,8 +76,7 @@ ConvObj::ConvObj(GraphObj *graph, Tensor input, Tensor weight, Tensor output,
                  PaddingMode mode, int sh, int sw, int dh, int dw, Tensor bias,
                  ActType act)
     : ConvBaseObj(OpType::Conv, {input, weight}, output, mode, sh, sw, dh, dw,
-                  input, weight),
-      act(act) {
+                  input, weight, act) {
     if (bias)
         IT_TODO_HALT();
     setAuxilaryAttributes(mode);
@@ -122,8 +120,8 @@ ConvTransposed2dObj::ConvTransposed2dObj(GraphObj *graph, Tensor input,
                                          int oph, int opw, int group,
                                          Tensor bias, ActType act)
     : ConvBaseObj(OpType::ConvTrans, {input, weight}, output, ph, pw, sh, sw,
-                  dh, dw, output, weight),
-      oph(oph), opw(opw), group(group), act(act) {
+                  dh, dw, output, weight, act),
+      oph(oph), opw(opw), group(group) {
     if (bias)
         IT_TODO_HALT();
     setAuxilaryAttributes(PaddingMode::Other);
@@ -136,8 +134,8 @@ ConvTransposed2dObj::ConvTransposed2dObj(GraphObj *graph, Tensor input,
                                          int dh, int dw, int oph, int opw,
                                          int group, Tensor bias, ActType act)
     : ConvBaseObj(OpType::ConvTrans, {input, weight}, output, mode, sh, sw, dh,
-                  dw, output, weight),
-      oph(oph), opw(opw), group(group), act(act) {
+                  dw, output, weight, act),
+      oph(oph), opw(opw), group(group) {
     if (bias)
         IT_TODO_HALT();
     setAuxilaryAttributes(mode);
@@ -168,7 +166,7 @@ void ConvTransposed2dObj::setAuxilaryAttributes(PaddingMode mode) {
     const Tensor &input = inputs[0];
     const Tensor &weight = inputs[1];
     n = input->getDims()[0], f = input->getDims()[1], h = input->getDims()[2],
-    w = input->getDims()[3], c = weight->getDims()[0], r = weight->getDims()[2],
+    w = input->getDims()[3], c = weight->getDims()[1], r = weight->getDims()[2],
     s = weight->getDims()[3];
     if (mode == PaddingMode::Same) {
         int oh = h / sh;
@@ -180,4 +178,70 @@ void ConvTransposed2dObj::setAuxilaryAttributes(PaddingMode mode) {
     }
 }
 
+ConvTransposed2dNHWCObj::ConvTransposed2dNHWCObj(GraphObj *graph, Tensor input,
+                                                 Tensor weight, Tensor output,
+                                                 int ph, int pw, int sh, int sw,
+                                                 int dh, int dw, int oph,
+                                                 int opw, int group,
+                                                 Tensor bias, ActType act)
+    : ConvBaseObj(OpType::ConvTransNHWC, {input, weight}, output, ph, pw, sh,
+                  sw, dh, dw, output, weight, act),
+      oph(oph), opw(opw), group(group) {
+    if (bias)
+        IT_TODO_HALT();
+    setAuxilaryAttributes(PaddingMode::Other);
+    IT_ASSERT(checkValid(graph));
+}
+
+ConvTransposed2dNHWCObj::ConvTransposed2dNHWCObj(GraphObj *graph, Tensor input,
+                                                 Tensor weight, Tensor output,
+                                                 PaddingMode mode, int sh,
+                                                 int sw, int dh, int dw,
+                                                 int oph, int opw, int group,
+                                                 Tensor bias, ActType act)
+    : ConvBaseObj(OpType::ConvTrans, {input, weight}, output, mode, sh, sw, dh,
+                  dw, output, weight, act),
+      oph(oph), opw(opw), group(group) {
+    if (bias)
+        IT_TODO_HALT();
+    setAuxilaryAttributes(mode);
+    IT_ASSERT(checkValid(graph));
+}
+
+optional<vector<Shape>>
+ConvTransposed2dNHWCObj::inferShape(const TensorVec &inputs) const {
+    const Tensor &input = inputs[0], &weight = inputs[1];
+    auto n = input->getDims()[0];
+    auto f = input->getDims()[3];
+    auto h = input->getDims()[1];
+    auto w = input->getDims()[2];
+    auto c = weight->getDims()[3];
+    auto r = weight->getDims()[1];
+    auto s = weight->getDims()[2];
+    if (f != weight->getDims()[0])
+        return {};
+
+    int on = n, oc = c * group;
+    int oh = 0, ow = 0;
+    oh = (h - 1) * sh - 2 * ph + dh * (r - 1) + oph + 1;
+    ow = (w - 1) * sw - 2 * pw + dw * (s - 1) + opw + 1;
+    return {{{on, oh, ow, oc}}};
+}
+
+void ConvTransposed2dNHWCObj::setAuxilaryAttributes(PaddingMode mode) {
+    const Tensor &input = inputs[0];
+    const Tensor &weight = inputs[1];
+    n = input->getDims()[0], f = input->getDims()[3], h = input->getDims()[1],
+    w = input->getDims()[2], c = weight->getDims()[3], r = weight->getDims()[1],
+    s = weight->getDims()[2];
+    if (mode == PaddingMode::Same) {
+        int oh = h / sh;
+        int ow = w / sw;
+        ph = (h - oh * sh + (r - sh) * dh) / 2;
+        pw = (w - ow * sw + (s - sw) * dw) / 2;
+    } else if (mode == PaddingMode::Valid) {
+        ph = pw = 0;
+    }
+}
+
 } // namespace infini
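(Worked example of inferShape() above, using the cuDNN NHWC test case later in this commit, where N=1, C=1, H=W=2, F=2, R=S=4, stride=1, padding=0, dilation=1 and output padding 0: oh = (2 - 1) * 1 - 2 * 0 + 1 * (4 - 1) + 0 + 1 = 5, and likewise ow = 5, with on = 1 and oc = c * group = 1, so the output shape is {1, 5, 5, 1}; its 25 elements match the length of the expected-answer vector in that test.)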
@@ -44,6 +44,40 @@ void testConvTransposedCudnn(
     EXPECT_TRUE(o0Cpu->equalData(ansVec));
 }
 
+void testConvTransposedNHWCCudnn(
+    const std::function<void(void *, size_t, DataType)> &generator,
+    vector<float> ansVec) {
+    const auto &[N, C, H, W, F, R, S] = tuple{1, 1, 2, 2, 2, 4, 4};
+    const int stride = 1, padding = 0, dilation = 1;
+    // Construct Runtime and graph for CPU and CUDA
+    Runtime cpu = CpuRuntimeObj::getInstance(); // CPUruntime is singleton
+    Graph gCpu = make_ref<GraphObj>(cpu);
+    Runtime cuda = make_ref<CudaRuntimeObj>();
+    Graph gCuda = make_ref<GraphObj>(cuda);
+    // Set input data on CPU in a CPU Graph
+    Tensor i0Cpu = gCpu->addTensor({N, H, W, F}, DataType::Float32);
+    Tensor w0Cpu = gCpu->addTensor({F, R, S, C}, DataType::Float32);
+    // Malloc data for all tensors in a graph. Do we need implicit allocation?
+    gCpu->dataMalloc();
+    i0Cpu->setData(generator);
+    w0Cpu->setData(generator);
+
+    // Copy input tensors from CPU to CUDA
+    Tensor i0Cuda = gCuda->cloneTensor(i0Cpu);
+    Tensor w0Cuda = gCuda->cloneTensor(w0Cpu);
+    // Build CUDA graph
+    auto conv = gCuda->addOp<ConvTransposed2dNHWCObj>(
+        i0Cuda, w0Cuda, nullptr, padding, padding, stride, stride, dilation,
+        dilation);
+    gCuda->dataMalloc();
+    // Execute on CUDA
+    cuda->run(gCuda);
+    // copy output from CUDA to CPU
+    auto o0Cpu = gCpu->cloneTensor(conv->getOutput());
+    // check results on CPU
+    EXPECT_TRUE(o0Cpu->equalData(ansVec));
+}
+
 TEST(cuDNN_ConvTransposed, run) {
     testConvTransposedCudnn(IncrementalGenerator(),
                             vector<float>{0., 0., 1., 2., 3., 0., 6.,
@@ -52,6 +86,14 @@ TEST(cuDNN_ConvTransposed, run) {
                                           62., 67., 72., 45.});
 }
 
+TEST(cuDNN_ConvTransposedNHWC, run) {
+    testConvTransposedNHWCCudnn(IncrementalGenerator(),
+                                vector<float>{16, 65, 71, 77, 63, 100, 290,
+                                              318, 346, 234, 140, 402, 430, 458,
+                                              306, 180, 514, 542, 570, 378, 188,
+                                              465, 487, 509, 307});
+}
+
 TEST(cuDNN_ConvTransposed, tune) {
     Runtime cpu = CpuRuntimeObj::getInstance(); // CPUruntime is singleton
     Graph gCpu = make_ref<GraphObj>(cpu);
@@ -3,12 +3,100 @@
 #include "core/graph.h"
 #include "core/runtime.h"
 #include "core/search_engine.h"
+#include "cuda/cuda_runtime.h"
 #include "nnet/nmutator.h"
 #include "operators/conv.h"
 #include "test.h"
 
 namespace infini {
 
+TEST(Mutator, NaiveConvWithInterpreter) {
+    // verifyNaiveMembound True: subgraph after transformation
+    // verifyNaiveMembound False: subgraph of one single membound (eOP)
+    Runtime runtime = CpuRuntimeObj::getInstance();
+    Graph g = make_ref<GraphObj>(runtime);
+    // const bool verifyNaiveMembound = false;
+
+    auto i0 = g->addTensor({1, 3, 32, 32}, DataType::UInt32);
+    auto w1 = g->addTensor({2, 3, 3, 3}, DataType::UInt32);
+    g->addOp<ConvObj>(i0, w1, nullptr, 1, 1);
+    printf("--- Init Finished ---\n");
+
+    auto mutator = make_ref<NMutator>();
+    mutator->setToNaiveMembound();
+    SearchEngine searchEngine(runtime, mutator);
+    // g->dataMalloc();
+    auto bestGraph = searchEngine.run(g);
+    bestGraph->print();
+    printf("--- SearchEngine Finished ---\n");
+
+    auto mutatedGraphs = mutator->run(g);
+    IT_ASSERT(mutatedGraphs.size() == 2);
+    printf("--- Mutator Finished ---\n");
+
+    auto gg = mutatedGraphs[1];
+    g->dataMalloc();
+    gg->dataMalloc();
+    for (auto t : g->getTensors()) {
+        if (t->getFuid() <= 2)
+            t->setData(IncrementalGenerator());
+    }
+    for (auto t : gg->getTensors()) {
+        if (t->getFuid() <= 2)
+            t->setData(IncrementalGenerator());
+    }
+    runtime->run(g);
+    runtime->run(gg);
+    gg->print();
+
+    EXPECT_TRUE(g->getOutputs()[0]->equalData(gg->getOutputs()[0]));
+    EXPECT_TRUE(g->getOutputs()[0]->getRawDataPtr<void *>() !=
+                gg->getOutputs()[0]->getRawDataPtr<void *>());
+}
+
+// FIXME: failed since implicit transpose for DLT
+TEST(Mutator, InfoGAN_TConv_3_correctness) {
+    // verifyNaiveMembound True: subgraph after transformation
+    // verifyNaiveMembound False: subgraph of one single membound (eOP)
+    // const bool verifyNaiveMembound = false;
+    Runtime runtime = make_ref<CudaRuntimeObj>();
+    Graph g = make_ref<GraphObj>(runtime);
+    Runtime cpu = CpuRuntimeObj::getInstance(); // CPUruntime is singleton
+    Graph gCpu = make_ref<GraphObj>(cpu);
+
+    // {n, h, w, f} * {f, r, s, c}
+    auto i0 = g->addTensor({1, 2, 2, 448});
+    auto w0 = g->addTensor({448, 4, 4, 256});
+    g->addOp<ConvTransposed2dNHWCObj>(i0, w0, nullptr, 1, 1, 2, 2, 1, 1);
+
+    auto mutator = make_ref<NMutator>();
+    mutator->setToNaiveMembound();
+    SearchEngine searchEngine(runtime, mutator);
+    auto bestGraph = searchEngine.run(g);
+    bestGraph->print();
+    printf("--- SearchEngine Finished ---\n");
+
+    g->dataMalloc();
+    bestGraph->dataMalloc();
+    for (auto t : g->getTensors()) {
+        if (t->getFuid() <= 2)
+            t->setData(IncrementalGenerator());
+    }
+    for (auto t : bestGraph->getTensors()) {
+        if (t->getFuid() <= 2)
+            t->setData(IncrementalGenerator());
+    }
+    runtime->run(g);
+    runtime->run(bestGraph);
+
+    auto go0 = gCpu->cloneTensor(g->getOutputs()[0]);
+    auto bgo0 = gCpu->cloneTensor(bestGraph->getOutputs()[0]);
+
+    EXPECT_TRUE(go0->equalData(bgo0));
+    EXPECT_TRUE(g->getOutputs()[0]->getRawDataPtr<void *>() !=
+                bestGraph->getOutputs()[0]->getRawDataPtr<void *>());
+}
+
 // TEST(Mutator, Conv9x9) {
 //     auto g = new tpm::Graph();
 //     auto i0 = g->tensor({1, 1, 224, 224});
@@ -71,63 +159,6 @@ namespace infini {
 //     codeEngine.genCode(bestGraph, "res.cu");
 // }
 
-// // FIXME: failed since implicit transpose for DLT
-// TEST(Mutator, InfoGAN_TConv_3_correctness) {
-//     // verifyNaiveMembound True: subgraph after transformation
-//     // verifyNaiveMembound False: subgraph of one single membound (eOP)
-//     const bool verifyNaiveMembound = false;
-//     auto g = new tpm::Graph();
-//     // {n, h, w, f} * {r, s, f, c}
-//     // {n, f, h, w} * {f, c, r, s}
-//     auto i0 = g->tensor({1, 448, 2, 2});
-//     auto w1 = g->tensor({448, 256, 4, 4});
-//     g->convTrans(i0, w1, 1, 1, 2, 2, 1, 1);
-// }
-
-TEST(Mutator, NaiveConvWithInterpreter) {
-    // verifyNaiveMembound True: subgraph after transformation
-    // verifyNaiveMembound False: subgraph of one single membound (eOP)
-    Runtime runtime = CpuRuntimeObj::getInstance();
-    Graph g = make_ref<GraphObj>(runtime);
-    // const bool verifyNaiveMembound = false;
-
-    auto i0 = g->addTensor({1, 3, 32, 32}, DataType::UInt32);
-    auto w1 = g->addTensor({2, 3, 3, 3}, DataType::UInt32);
-    g->addOp<ConvObj>(i0, w1, nullptr, 1, 1);
-    printf("--- Init Finished ---\n");
-
-    auto mutator = make_ref<NMutator>();
-    mutator->setToNaiveMembound();
-    SearchEngine searchEngine(runtime, mutator);
-    // g->dataMalloc();
-    auto bestGraph = searchEngine.run(g);
-    bestGraph->print();
-    printf("--- SearchEngine Finished ---\n");
-
-    auto mutatedGraphs = mutator->run(g);
-    IT_ASSERT(mutatedGraphs.size() == 2);
-    printf("--- Mutator Finished ---\n");
-
-    auto gg = mutatedGraphs[1];
-    g->dataMalloc();
-    gg->dataMalloc();
-    for (auto t : g->getTensors()) {
-        if (t->getFuid() <= 2)
-            t->setData(IncrementalGenerator());
-    }
-    for (auto t : gg->getTensors()) {
-        if (t->getFuid() <= 2)
-            t->setData(IncrementalGenerator());
-    }
-    runtime->run(g);
-    runtime->run(gg);
-    gg->print();
-
-    EXPECT_TRUE(g->getOutputs()[0]->equalData(gg->getOutputs()[0]));
-    EXPECT_TRUE(g->getOutputs()[0]->getRawDataPtr<void *>() !=
-                gg->getOutputs()[0]->getRawDataPtr<void *>());
-}
-
 // TEST(Mutator, G2BMM) {
 //     auto g = new tpm::Graph();
 