// InfiniTensor/src/core/graph_handler.cc

#include "core/graph_handler.h"
#include "operators/all_gather.h"
#include "operators/all_reduce.h"
#include "operators/attention_kvcache.h"
#include "operators/batch_norm.h"
#include "operators/broadcast.h"
#include "operators/concat.h"
#include "operators/conv.h"
#include "operators/element_wise.h"
#include "operators/expand.h"
#include "operators/gather.h"
#include "operators/layer_norm.h"
#include "operators/lrn.h"
#include "operators/matmul.h"
#include "operators/pad.h"
#include "operators/pooling.h"
#include "operators/recv.h"
#include "operators/reduce.h"
#include "operators/reshape.h"
#include "operators/resize.h"
#include "operators/rms_norm.h"
#include "operators/rope.h"
#include "operators/send.h"
#include "operators/slice.h"
#include "operators/softmax.h"
#include "operators/split.h"
#include "operators/squeeze.h"
#include "operators/transpose.h"
#include "operators/unary.h"
#include "operators/unsqueeze.h"
#include "operators/where.h"
#include <numeric>
#include <variant>
namespace infini {
static DataType dtype_repr_convert(int);
static CastType inferCastType(Tensor input, int to);
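
// Every builder below follows the same convention: if the caller supplies an
// output tensor, the op is wired to it via addOpWithOutputs and that tensor
// is returned; otherwise addOp lets the graph infer and allocate the output,
// which is returned instead. A minimal usage sketch (hypothetical names; the
// real caller is the Python frontend binding):
//
//   GraphHandlerObj handler(runtime);
//   auto x = handler.tensor({1, 3, 224, 224}, /*dtype=*/1); // 1 == ONNX FLOAT
//   auto y = handler.relu(x, /*output=*/nullptr);           // output inferred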
Tensor GraphHandlerObj::tensor(Shape dims, int dtype) {
return g->addTensor(std::move(dims), dtype_repr_convert(dtype));
}
Tensor GraphHandlerObj::conv(Tensor input, Tensor weight, Tensor output, int ph,
int pw, int sh, int sw, int dh, int dw) {
if (output) {
g->addOpWithOutputs<ConvObj>(std::move(input), std::move(weight),
output, ph, pw, sh, sw, dh, dw);
return output;
} else {
return g
->addOp<ConvObj>(std::move(input), std::move(weight), output, ph,
pw, sh, sw, dh, dw)
->getOutput();
}
}
Tensor GraphHandlerObj::convTransposed2d(Tensor input, Tensor weight,
Tensor output, int ph, int pw, int sh,
int sw, int dh, int dw, int oph,
int opw) {
if (output) {
g->addOpWithOutputs<ConvTransposed2dObj>(std::move(input),
std::move(weight), output, ph,
pw, sh, sw, dh, dw, oph, opw);
return output;
} else {
return g
->addOp<ConvTransposed2dObj>(std::move(input), std::move(weight),
output, ph, pw, sh, sw, dh, dw, oph,
opw)
->getOutput();
}
}
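
// Builds y = act(op(a) * op(b) + bias), where op() optionally transposes the
// last two dimensions according to transA/transB; matmul_compute_type is
// passed through to the backend as a compute-precision hint (see
// operators/matmul.h).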
Tensor GraphHandlerObj::matmul(Tensor a, Tensor b, Tensor y, bool transA,
bool transB, Tensor bias, ActType act,
std::string matmul_compute_type) {
if (y) {
g->addOpWithOutputs<MatmulObj>(std::move(a), std::move(b), y, transA,
transB, std::move(bias), act,
matmul_compute_type);
return y;
} else {
return g
->addOp<MatmulObj>(std::move(a), std::move(b), y, transA, transB,
std::move(bias), act, matmul_compute_type)
->getOutput();
}
}
Tensor GraphHandlerObj::batchNormalization(Tensor input, Tensor output,
Tensor mean, Tensor var,
Tensor scale, Tensor bias,
float momentum, float eps,
bool training) {
if (output) {
g->addOpWithOutputs<BatchNormObj>(
std::move(input), output, std::move(mean), std::move(var),
std::move(scale), std::move(bias), momentum, eps, training);
return output;
} else {
return g
->addOp<BatchNormObj>(std::move(input), output, std::move(mean),
std::move(var), std::move(scale),
std::move(bias), momentum, eps, training)
->getOutput();
}
}
Tensor GraphHandlerObj::layerNormalization(Tensor input, Tensor scale,
Tensor output, Tensor bias,
float eps, int axis,
int stash_type) {
if (output) {
g->addOpWithOutputs<LayerNormObj>(std::move(input), std::move(scale),
output, std::move(bias), eps, axis,
stash_type);
return output;
} else {
return g
->addOp<LayerNormObj>(std::move(input), std::move(scale), output,
std::move(bias), eps, axis, stash_type)
->getOutput();
}
}
Tensor GraphHandlerObj::rmsNorm(Tensor input, Tensor weight, Tensor output) {
if (output) {
g->addOpWithOutputs<RMSNormObj>(std::move(input), std::move(weight),
output);
return output;
} else {
return g->addOp<RMSNormObj>(std::move(input), std::move(weight), output)
->getOutput();
}
}
Tensor GraphHandlerObj::maxPool(Tensor input, Tensor output, int kh, int kw,
int dh, int dw, int ph, int pw, int sh, int sw,
int ceilMode) {
if (output) {
g->addOpWithOutputs<MaxPoolObj>(std::move(input), output, kh, kw, dh,
dw, ph, pw, sh, sw, ceilMode);
return output;
} else {
return g
->addOp<MaxPoolObj>(std::move(input), output, kh, kw, dh, dw, ph,
pw, sh, sw, ceilMode)
->getOutput();
}
}
Tensor GraphHandlerObj::avgPool(Tensor input, Tensor output, int kh, int kw,
int dh, int dw, int ph, int pw, int sh, int sw,
int ceilMode) {
if (output) {
g->addOpWithOutputs<AvgPoolObj>(std::move(input), output, kh, kw, dh,
dw, ph, pw, sh, sw, ceilMode);
return output;
} else {
return g
->addOp<AvgPoolObj>(std::move(input), output, kh, kw, dh, dw, ph,
pw, sh, sw, ceilMode)
->getOutput();
}
}
// see operators/element_wise.h
#define DEFINE_ELEMENT_WISE_METHOD(name, obj) \
Tensor GraphHandlerObj::name(Tensor a, Tensor b, Tensor c) { \
if (c) { \
g->addOpWithOutputs<obj##Obj>(std::move(a), std::move(b), c); \
return c; \
} else { \
return g->addOp<obj##Obj>(std::move(a), std::move(b), c) \
->getOutput(); \
} \
}
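
// For reference, DEFINE_ELEMENT_WISE_METHOD(add, Add) expands to the same
// output-supplied-vs-inferred pattern used by the hand-written builders:
//
//   Tensor GraphHandlerObj::add(Tensor a, Tensor b, Tensor c) {
//       if (c) {
//           g->addOpWithOutputs<AddObj>(std::move(a), std::move(b), c);
//           return c;
//       } else {
//           return g->addOp<AddObj>(std::move(a), std::move(b), c)
//               ->getOutput();
//       }
//   }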
DEFINE_ELEMENT_WISE_METHOD(add, Add)
DEFINE_ELEMENT_WISE_METHOD(sub, Sub)
DEFINE_ELEMENT_WISE_METHOD(mul, Mul)
DEFINE_ELEMENT_WISE_METHOD(div, Div)
DEFINE_ELEMENT_WISE_METHOD(pow, Pow)
DEFINE_ELEMENT_WISE_METHOD(min, Minimum)
DEFINE_ELEMENT_WISE_METHOD(max, Maximum)
// see operators/unary.h
#define DEFINE_UNARY_METHOD(name, obj) \
Tensor GraphHandlerObj::name(Tensor x, Tensor y) { \
if (y) { \
g->addOpWithOutputs<obj##Obj>(std::move(x), y); \
return y; \
} else { \
return g->addOp<obj##Obj>(std::move(x), y)->getOutput(); \
} \
}
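
// Same expansion pattern as the binary macro above, specialized to
// single-input operators.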
DEFINE_UNARY_METHOD(silu, Silu)
DEFINE_UNARY_METHOD(relu, Relu)
DEFINE_UNARY_METHOD(gelu, Gelu)
DEFINE_UNARY_METHOD(sigmoid, Sigmoid)
DEFINE_UNARY_METHOD(tanh, Tanh)
DEFINE_UNARY_METHOD(hardSigmoid, HardSigmoid)
DEFINE_UNARY_METHOD(hardSwish, HardSwish)
DEFINE_UNARY_METHOD(abs, Abs)
DEFINE_UNARY_METHOD(sqrt, Sqrt)
DEFINE_UNARY_METHOD(neg, Neg)
DEFINE_UNARY_METHOD(shape, Shape)
DEFINE_UNARY_METHOD(erf, Erf)
// see operators/reshape.h
DEFINE_UNARY_METHOD(identity, Identity)
Tensor GraphHandlerObj::pRelu(Tensor x, Tensor slope, Tensor y) {
if (y) {
g->addOpWithOutputs<PReluObj>(std::move(x), std::move(slope), y);
return y;
} else {
return g->addOp<PReluObj>(std::move(x), std::move(slope), y)
->getOutput();
}
}
Tensor GraphHandlerObj::clip(Tensor x, Tensor y, std::optional<float> min,
std::optional<float> max) {
if (y) {
g->addOpWithOutputs<ClipObj>(std::move(x), y, min, max);
return y;
} else {
return g->addOp<ClipObj>(std::move(x), y, min, max)->getOutput();
}
}
Tensor GraphHandlerObj::softmax(Tensor input, Tensor output, int axis) {
if (output) {
g->addOpWithOutputs<SoftmaxObj>(std::move(input), output, axis);
return output;
} else {
return g->addOp<SoftmaxObj>(std::move(input), output, axis)
->getOutput();
}
}
Tensor GraphHandlerObj::flatten(Tensor input, Tensor output, int axis) {
if (output) {
g->addOpWithOutputs<FlattenObj>(std::move(input), output, axis);
return output;
} else {
return g->addOp<FlattenObj>(std::move(input), output, axis)
->getOutput();
}
}
Tensor GraphHandlerObj::transpose(Tensor data, Tensor transposed, Shape perm) {
if (transposed) {
g->addOpWithOutputs<TransposeObj>(std::move(data), transposed, perm);
return transposed;
} else {
return g->addOp<TransposeObj>(std::move(data), transposed, perm)
->getOutput();
}
}
Tensor GraphHandlerObj::reshape(Tensor data, Tensor reshaped, Shape shape) {
if (reshaped) {
g->addOpWithOutputs<ReshapeObj>(std::move(data), reshaped,
std::move(shape));
return reshaped;
} else {
return g->addOp<ReshapeObj>(std::move(data), reshaped, std::move(shape))
->getOutput();
}
}
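
// Resize takes its sizes/scales/roi either as graph tensors or as raw
// vectors from the frontend; non-empty raw vectors are materialized into the
// corresponding tensors here, before the op is constructed. "nearest" mode
// selects the ResizeObj overload that takes a nearest-rounding policy; every
// other mode passes a coefficient mode instead.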
Tensor GraphHandlerObj::resize(Tensor input, Tensor output,
const std::optional<vector<int>> &axes,
Tensor sizes, Tensor scales, Tensor roi,
vector<uint32_t> sizes_, vector<float> scales_,
vector<float> roi_, string mode,
string ratioPolicy, string nearestMode,
string coordTransMode) {
if (sizes_.size() > 0) {
sizes->dataMalloc();
sizes->copyin<uint32_t>(sizes_);
}
if (scales_.size() > 0) {
scales->dataMalloc();
scales->copyin<float>(scales_);
}
if (roi_.size() > 0) {
roi->dataMalloc();
roi->copyin<float>(roi_);
}
ResizeObj::EKeepAspectRatioPolicy ratioPolicy_ =
ResizeObj::fromRatioPolicyStr(ratioPolicy);
ResizeObj::ENearestMode nearestMode_ =
ResizeObj::fromENearestModeStr(nearestMode);
ResizeObj::ECoordinateTransMode coordTransMode_ =
ResizeObj::fromECoordinateTransModeStr(coordTransMode);
ResizeObj::ECoeffMode mode_ = ResizeObj::fromECoeffModeStr(mode);
    if (output) {
        if (mode == "nearest") {
            g->addOpWithOutputs<ResizeObj>(
                std::move(input), output, axes, std::move(sizes),
                std::move(scales), std::move(roi), ratioPolicy_, nearestMode_,
                coordTransMode_);
        } else {
            g->addOpWithOutputs<ResizeObj>(
                std::move(input), output, axes, std::move(sizes),
                std::move(scales), std::move(roi), mode_, ratioPolicy_,
                coordTransMode_);
        }
        return output;
    } else {
        if (mode == "nearest") {
            return g
                ->addOp<ResizeObj>(std::move(input), output, axes,
                                   std::move(sizes), std::move(scales),
                                   std::move(roi), ratioPolicy_, nearestMode_,
                                   coordTransMode_)
                ->getOutput();
        } else {
            return g
                ->addOp<ResizeObj>(std::move(input), output, axes,
                                   std::move(sizes), std::move(scales),
                                   std::move(roi), mode_, ratioPolicy_,
                                   coordTransMode_)
                ->getOutput();
        }
    }
}
Tensor GraphHandlerObj::concat(TensorVec inputs, Tensor output, int dim) {
if (output) {
g->addOpWithOutputs<ConcatObj>(std::move(inputs), output, dim);
return output;
} else {
return g->addOp<ConcatObj>(std::move(inputs), output, dim)->getOutput();
}
}
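
// Fused decoder-attention step: appends the current step's K/V (input_k,
// input_v) to the caches at position_id and produces the attention output
// (output_matmul). See operators/attention_kvcache.h for the exact layout
// requirements.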
Tensor GraphHandlerObj::attentionKVCache(Tensor input_k_cache,
Tensor input_v_cache, Tensor input_q,
Tensor input_k, Tensor input_v,
Tensor position_id,
Tensor output_matmul) {
if (output_matmul) {
g->addOpWithOutputs<AttentionKVCacheObj>(
std::move(input_k_cache), std::move(input_v_cache),
std::move(input_q), std::move(input_k), std::move(input_v),
std::move(position_id), output_matmul);
return output_matmul;
} else {
return g
->addOp<AttentionKVCacheObj>(
std::move(input_k_cache), std::move(input_v_cache),
std::move(input_q), std::move(input_k), std::move(input_v),
std::move(position_id), output_matmul)
->getOutput();
}
}
Tensor GraphHandlerObj::RoPE(Tensor pos, Tensor input, Tensor output) {
if (output) {
g->addOpWithOutputs<RoPEObj>(std::move(pos), std::move(input), output);
return output;
} else {
return g->addOp<RoPEObj>(std::move(pos), std::move(input), output)
->getOutput();
}
}
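
// numOrRatio is either an int (split into that many equal parts along axis)
// or a vector<int> of per-output ratios; see operators/split.h.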
TensorVec GraphHandlerObj::split(Tensor input, std::optional<TensorVec> outputs,
int axis,
std::variant<int, vector<int>> numOrRatio) {
if (outputs) {
if (std::holds_alternative<int>(numOrRatio)) {
g->addOpWithOutputs<SplitObj>(std::move(input), outputs, axis,
std::get<int>(numOrRatio));
} else {
g->addOpWithOutputs<SplitObj>(std::move(input), outputs, axis,
std::get<vector<int>>(numOrRatio));
}
return *outputs;
} else {
if (std::holds_alternative<int>(numOrRatio)) {
return g
->addOp<SplitObj>(std::move(input), outputs, axis,
std::get<int>(numOrRatio))
->getOutputs();
} else {
return g
->addOp<SplitObj>(std::move(input), outputs, axis,
std::get<vector<int>>(numOrRatio))
->getOutputs();
}
}
}
Tensor GraphHandlerObj::gather(Tensor data, Tensor indices, Tensor output,
int axis) {
if (output) {
g->addOpWithOutputs<GatherObj>(std::move(data), std::move(indices),
output, axis);
return output;
} else {
return g
->addOp<GatherObj>(std::move(data), std::move(indices), output,
axis)
->getOutput();
}
}
Tensor GraphHandlerObj::gatherElements(Tensor data, Tensor indices,
Tensor output, int axis) {
if (output) {
g->addOpWithOutputs<GatherElementsObj>(
std::move(data), std::move(indices), output, axis);
return output;
} else {
return g
->addOp<GatherElementsObj>(std::move(data), std::move(indices),
output, axis)
->getOutput();
}
}
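
// see operators/reduce.h. An empty `axes` means the reduction runs over all
// dimensions, mirroring ONNX ReduceMean/ReduceSum semantics.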
#define DEFINE_REDUCE_METHOD(name, obj)                                        \
    Tensor GraphHandlerObj::name(Tensor data, Tensor reduced,                  \
                                 const optional<vector<int>> &axes,            \
                                 bool keepdims) {                              \
        if (reduced) {                                                         \
            g->addOpWithOutputs<obj##Obj>(std::move(data), reduced, axes,      \
                                          keepdims);                           \
            return reduced;                                                    \
        } else {                                                               \
            return g                                                           \
                ->addOp<obj##Obj>(std::move(data), reduced, axes, keepdims)    \
                ->getOutput();                                                 \
        }                                                                      \
    }
DEFINE_REDUCE_METHOD(reduceMean, ReduceMean)
DEFINE_REDUCE_METHOD(reduceSum, ReduceSum)
Tensor GraphHandlerObj::slice(Tensor input, Tensor output,
const vector<int> &starts,
const vector<int> &ends,
const optional<vector<int>> &axes,
const optional<vector<int>> &steps) {
if (output) {
g->addOpWithOutputs<SliceObj>(std::move(input), output, starts, ends,
axes, steps);
return output;
} else {
return g
->addOp<SliceObj>(std::move(input), output, starts, ends, axes,
steps)
->getOutput();
}
}
Tensor GraphHandlerObj::pad(Tensor input, Tensor output,
const vector<int> &pads,
const optional<vector<int>> &axes) {
if (output) {
g->addOpWithOutputs<PadObj>(std::move(input), output, pads, axes);
return output;
} else {
return g->addOp<PadObj>(std::move(input), output, pads, axes)
->getOutput();
}
}
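
// Collective-communication builders. The all-reduce variants keep the input
// shape and differ only in the reduction (sum/prod/min/max/avg); allGather
// returns one output per rank, and broadcast/send/recv take explicit ranks.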
Tensor GraphHandlerObj::allReduceSum(Tensor input, Tensor output) {
if (output) {
g->addOpWithOutputs<AllReduceSumObj>(std::move(input), output);
return output;
} else {
return g->addOp<AllReduceSumObj>(std::move(input), output)->getOutput();
}
}
Tensor GraphHandlerObj::allReduceProd(Tensor input, Tensor output) {
if (output) {
g->addOpWithOutputs<AllReduceProdObj>(std::move(input), output);
return output;
} else {
return g->addOp<AllReduceProdObj>(std::move(input), output)
->getOutput();
}
}
Tensor GraphHandlerObj::allReduceMin(Tensor input, Tensor output) {
if (output) {
g->addOpWithOutputs<AllReduceMinObj>(std::move(input), output);
return output;
} else {
return g->addOp<AllReduceMinObj>(std::move(input), output)->getOutput();
}
}
Tensor GraphHandlerObj::allReduceMax(Tensor input, Tensor output) {
if (output) {
g->addOpWithOutputs<AllReduceMaxObj>(std::move(input), output);
return output;
} else {
return g->addOp<AllReduceMaxObj>(std::move(input), output)->getOutput();
}
}
Tensor GraphHandlerObj::allReduceAvg(Tensor input, Tensor output) {
if (output) {
g->addOpWithOutputs<AllReduceAvgObj>(std::move(input), output);
return output;
} else {
return g->addOp<AllReduceAvgObj>(std::move(input), output)->getOutput();
}
}
TensorVec GraphHandlerObj::allGather(Tensor input,
std::optional<TensorVec> outputs, int n) {
if (outputs) {
g->addOpWithOutputs<AllGatherObj>(std::move(input), outputs, n);
return *outputs;
} else {
return g->addOp<AllGatherObj>(std::move(input), outputs, n)
->getOutputs();
}
}
Tensor GraphHandlerObj::broadcast(Tensor input, Tensor output, int root) {
if (output) {
g->addOpWithOutputs<BroadcastObj>(std::move(input), output, root);
return output;
} else {
return g->addOp<BroadcastObj>(std::move(input), output, root)
->getOutput();
}
}
Tensor GraphHandlerObj::send(Tensor input, int source, int destination,
Tensor output) {
if (output) {
g->addOpWithOutputs<SendObj>(std::move(input), source, destination,
output);
return output;
} else {
return g->addOp<SendObj>(std::move(input), source, destination, output)
->getOutput();
}
}
Tensor GraphHandlerObj::recv(Tensor output, int source, int destination,
Shape dims, int outputType, Tensor input) {
if (output) {
g->addOpWithOutputs<RecvObj>(output, source, destination,
std::move(dims), outputType,
std::move(input));
return output;
} else {
return g
->addOp<RecvObj>(output, source, destination, std::move(dims),
outputType, std::move(input))
->getOutput();
}
}
Tensor GraphHandlerObj::cast(Tensor input, Tensor output, int to) {
    // Determine the cast kind before `input` is moved below: function
    // argument evaluation order is unspecified, so calling
    // inferCastType(input, to) in the same argument list as std::move(input)
    // may read a moved-from (null) tensor.
    const CastType castType = inferCastType(input, to);
    if (output) {
        g->addOpWithOutputs<CastObj>(std::move(input), output, castType);
        return output;
    } else {
        return g->addOp<CastObj>(std::move(input), output, castType)
            ->getOutput();
    }
}
Tensor GraphHandlerObj::expand(Tensor input, Tensor output, Shape dims) {
if (output) {
g->addOpWithOutputs<ExpandObj>(std::move(input), output,
std::move(dims));
return output;
} else {
return g->addOp<ExpandObj>(std::move(input), output, std::move(dims))
->getOutput();
}
}
Tensor GraphHandlerObj::where(Tensor inputX, Tensor inputY, Tensor condition,
Tensor output) {
if (output) {
g->addOpWithOutputs<WhereObj>(std::move(inputX), std::move(inputY),
std::move(condition), output);
return output;
} else {
return g
->addOp<WhereObj>(std::move(inputX), std::move(inputY),
std::move(condition), output)
->getOutput();
}
}
Tensor GraphHandlerObj::depthToSpace(Tensor input, Tensor output, int blocksize,
std::string mode) {
if (output) {
g->addOpWithOutputs<DepthToSpaceObj>(std::move(input), output,
blocksize, mode);
return output;
} else {
return g
->addOp<DepthToSpaceObj>(std::move(input), output, blocksize, mode)
->getOutput();
}
}
Tensor GraphHandlerObj::lrn(Tensor input, Tensor output, float alpha,
float beta, float bias, int size) {
if (output) {
g->addOpWithOutputs<LRNObj>(std::move(input), output, alpha, beta, bias,
size);
return output;
} else {
return g
->addOp<LRNObj>(std::move(input), output, alpha, beta, bias, size)
->getOutput();
}
}
Tensor GraphHandlerObj::squeeze(Tensor input, Tensor output, Shape axes) {
if (output) {
g->addOpWithOutputs<SqueezeObj>(std::move(input), output,
std::move(axes));
return output;
} else {
return g->addOp<SqueezeObj>(std::move(input), output, std::move(axes))
->getOutput();
}
}
Tensor GraphHandlerObj::unsqueeze(Tensor input, Tensor output, Shape axes) {
if (output) {
g->addOpWithOutputs<UnsqueezeObj>(std::move(input), output,
std::move(axes));
return output;
} else {
return g->addOp<UnsqueezeObj>(std::move(input), output, std::move(axes))
->getOutput();
}
}
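
// Maps (input dtype, ONNX `to` code) onto the CastType enum consumed by
// CastObj. Only the pairs enumerated below are supported; anything else
// halts with IT_TODO_HALT_MSG.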
static CastType inferCastType(Tensor input, int to) {
auto iType = input->getDType();
auto oType = DataType(to);
if (iType == DataType::Float32 && oType == DataType::Float16) {
return CastType::Float2Float16;
} else if (iType == DataType::Float32 && oType == DataType::Int64) {
return CastType::Float2Int64;
} else if (iType == DataType::Float32 && oType == DataType::Int32) {
return CastType::Float2Int32;
} else if (iType == DataType::Float32 && oType == DataType::Int16) {
return CastType::Float2Int16;
} else if (iType == DataType::Float32 && oType == DataType::Int8) {
return CastType::Float2Int8;
} else if (iType == DataType::Float32 && oType == DataType::BFloat16) {
return CastType::Float2BFloat16;
} else if (iType == DataType::Int32 && oType == DataType::Float32) {
return CastType::Int322Float;
} else if (iType == DataType::Int32 && oType == DataType::Int8) {
return CastType::Int322Int8;
} else if (iType == DataType::Int32 && oType == DataType::Int16) {
return CastType::Int322Int16;
} else if (iType == DataType::Int32 && oType == DataType::Int64) {
return CastType::Int322Int64;
} else if (iType == DataType::Int16 && oType == DataType::Int32) {
return CastType::Int162Int32;
} else if (iType == DataType::Int16 && oType == DataType::Float32) {
return CastType::Int162Float;
} else if (iType == DataType::Int8 && oType == DataType::Float32) {
return CastType::Int82Float;
} else if (iType == DataType::Int8 && oType == DataType::Int16) {
return CastType::Int82Int16;
} else if (iType == DataType::Int8 && oType == DataType::Int32) {
return CastType::Int82Int32;
} else if (iType == DataType::UInt8 && oType == DataType::Int32) {
return CastType::Uint82Int32;
} else if (iType == DataType::UInt8 && oType == DataType::Float32) {
return CastType::Uint82Float;
} else if (iType == DataType::UInt8 && oType == DataType::Int64) {
return CastType::Uint82Int64;
} else if (iType == DataType::Int64 && oType == DataType::Float32) {
return CastType::Int642Float;
} else if (iType == DataType::Int64 && oType == DataType::UInt32) {
return CastType::Int642Uint32;
} else if (iType == DataType::Int64 && oType == DataType::Int32) {
return CastType::Int642Int32;
} else if (iType == DataType::UInt32 && oType == DataType::Int64) {
return CastType::Uint322Int64;
} else if (iType == DataType::Float16 && oType == DataType::Float32) {
return CastType::Float162Float;
} else if (iType == DataType::BFloat16 && oType == DataType::Float32) {
return CastType::BFloat162Float;
} else if (iType == DataType::Float32 && oType == DataType::Float32) {
return CastType::Float2Float;
} else {
IT_TODO_HALT_MSG("Unsupported CastType : input_type is " +
iType.toString() + " output_type is " +
oType.toString());
}
}
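
// Integer codes follow ONNX's TensorProto::DataType enum (1 = FLOAT,
// 7 = INT64, 16 = BFLOAT16, ...); 14/15 (the complex types) are unsupported
// and fall through to the assertion.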
static DataType dtype_repr_convert(int dtype) {
switch (dtype) {
case 0:
return DataType::Undefine;
case 1:
return DataType::Float32;
case 2:
return DataType::UInt8;
case 3:
return DataType::Int8;
case 4:
return DataType::UInt16;
case 5:
return DataType::Int16;
case 6:
return DataType::Int32;
case 7:
return DataType::Int64;
case 8:
return DataType::String;
case 9:
return DataType::Bool;
case 10:
return DataType::Float16;
case 11:
return DataType::Double;
case 12:
return DataType::UInt32;
case 13:
return DataType::UInt64;
case 16:
return DataType::BFloat16;
default:
IT_ASSERT(false, "Unsupported data type");
}
}
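
// In-place shape update for an existing graph tensor, keyed by tensor id;
// presumably used by the frontend to refresh dynamic input shapes.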
void GraphHandlerObj::change_shape(const vector<int> &shape, int tensorId) {
auto tensor = g->getTensor(tensorId);
IT_ASSERT(tensor != nullptr);
IT_ASSERT(shape.size() != 0);
tensor->setShape(shape);
}
} // namespace infini