Support graph construction for BERT/GPT-2 models (#94)

* feat: support sqrt op

* feat: support erf op

* feat: support expand op

* feat: support where op

* fix: allow gather op index to be int64_t (was hard-coded to int)

* fix: correct several incorrect usages

* style: fix code formatting

* test: add tests for the changed ops

* fix: rebase onto master

* fix: correct wrong computation of the matmul batch size b

* add expand and where kernels

* Add int64 support for the CUDA gather kernel

* add test_where.cc

* add "expand.(cu/cc,test,cuda),modified where.cu"

* Separate initialization of datatypes to avoid compile error

* modify where (cu/cc/h, test), expand, and clip

* Format fix

* Format fix

---------

Co-authored-by: xgqdut2016 <kenan_gewei@163.com>
Co-authored-by: panzezhong <panzezhong@qiyuanlab.com>
Co-authored-by: Haojie Wang <haojie0429@gmail.com>
zhangyunze 2023-08-29 16:06:52 +08:00 committed by GitHub
parent d8ffd8a4b7
commit 3e6ef305f1
29 changed files with 804 additions and 59 deletions


@@ -1,3 +1,4 @@
+#pragma once
 #include "core/common.h"

 namespace infini {
@@ -69,23 +70,6 @@ class DataType {
     int getIndex() const { return index; }
 };

-// to be consistent with onnx
-// https://github.com/onnx/onnx/blob/aeb21329122b96df1d3ef33b500a35ca140b1431/onnx/onnx.proto#L484
-inline const DataType DataType::Undefine(0);
-inline const DataType DataType::Float32(1);
-inline const DataType DataType::UInt8(2);
-inline const DataType DataType::Int8(3);
-inline const DataType DataType::UInt16(4);
-inline const DataType DataType::Int16(5);
-inline const DataType DataType::Int32(6);
-inline const DataType DataType::Int64(7);
-inline const DataType DataType::String(8);
-inline const DataType DataType::Bool(9);
-inline const DataType DataType::Float16(10);
-inline const DataType DataType::Double(11);
-inline const DataType DataType::UInt32(12);
-inline const DataType DataType::UInt64(13);
-inline const DataType DataType::BFloat16(16);
 // Method definitions are out of the declaration due to GCC bug:
 // https://stackoverflow.com/questions/49707184/explicit-specialization-in-non-namespace-scope-does-not-compile-in-gcc
 template <> inline int DataType::get<float>() { return 0; }


@@ -47,6 +47,7 @@ class GraphHandlerObj {
     Tensor relu(Tensor x, Tensor y);
     Tensor sigmoid(Tensor x, Tensor y);
     Tensor tanh(Tensor x, Tensor y);
+    Tensor erf(Tensor x, Tensor y);
     Tensor softmax(Tensor x, Tensor y, int axis);
     Tensor abs(Tensor x, Tensor y);
     Tensor sqrt(Tensor x, Tensor y);
@@ -70,6 +71,8 @@ class GraphHandlerObj {
     Tensor pad(Tensor input, Tensor output, const vector<int> &pads,
               const optional<vector<int>> &axes);
     Tensor cast(Tensor input, Tensor output, int to);
+    Tensor expand(Tensor input, Tensor output, Shape dims);
+    Tensor where(Tensor inputX, Tensor inputY, Tensor condition, Tensor output);

    //------ modifiers


@@ -0,0 +1,9 @@
#pragma once

#include "operators/unary.h"
#include "utils/small_array.h"

namespace infini {
void expand_kernel(float *input, float *output, int nDims, int outputsize,
                   SmallArray inputShape, SmallArray outputShape);
}; // namespace infini

include/cuda/cuda_where.h (new file)

@@ -0,0 +1,13 @@
#pragma once

#include "operators/unary.h"
#include "utils/small_array.h"

namespace infini {
void where_kernel(const float *inputx, const float *inputy,
                  const float *condition, float *output, int nDims,
                  infini::SmallArray inputxShape,
                  infini::SmallArray inputyShape,
                  infini::SmallArray conditionShape,
                  infini::SmallArray outputShape);
}; // namespace infini


@@ -1,7 +1,10 @@
 #pragma once
+#include "core/data_type.h"

-typedef struct {
-    int *indexValue;
+namespace infini {
+struct GatherMetaData {
+    void *indexValue;
+    DataType indexType;
     int axis;
     int inNDim;
     int outNDim;
@@ -10,8 +13,7 @@
     int idxDim[4];
     int idxStride[4];
     int inStride[4];
-} GatherMetaData;
+};

-namespace infini {
-void gather_kernel(float *in, float *out, GatherMetaData metaData, int num);
-}
+void gather_kernel(float *in, float *out, GatherMetaData metaData, size_t num);
+} // namespace infini


@@ -0,0 +1,36 @@
#pragma once

#include "core/operator.h"

namespace infini {
/**
 * @brief Broadcast the input tensor following the given shape and the
 * broadcast rule.
 *
 */
class ExpandObj : public OperatorObj {
    Shape dims;

  public:
    /**
     * @brief Construct a new Expand object.
     *
     * @param graph The computation graph that this operator belongs to.
     * @param input The input tensor.
     * @param output The output tensor.
     * @param dims The shape you want to expand to, following the broadcast
     * rule.
     */
    ExpandObj(GraphObj *graph, Tensor input, Tensor output, Shape dims);
    OP_CLONE(ExpandObj);

    optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;

    std::string toString() const override;
    int numInputs() const override { return 1; }
    int numOutputs() const override { return 1; }
    Shape getShape() const { return dims; }

  private:
    vector<int> getWorkloadVector() const override;
    vector<int> getOpAttrVector() const override;
};
} // namespace infini
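Expand relies on the ONNX/numpy multidirectional broadcast rule. Below is a minimal host-side sketch of that rule for intuition only; it assumes the usual right-aligned semantics, and the name broadcastShape is hypothetical (the repository's actual implementation is infer_broadcast, whose code is not part of this diff):

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Combine two shapes under numpy/ONNX broadcasting. Illustrative sketch only.
std::vector<int> broadcastShape(std::vector<int> a, std::vector<int> b) {
    if (a.size() < b.size())
        std::swap(a, b);
    // Right-align the shapes: pad the shorter one with leading 1s.
    b.insert(b.begin(), a.size() - b.size(), 1);
    std::vector<int> out(a.size());
    for (std::size_t i = 0; i < a.size(); ++i) {
        // Each pair of dims must match, or one of them must be 1.
        assert(a[i] == b[i] || a[i] == 1 || b[i] == 1);
        out[i] = std::max(a[i], b[i]);
    }
    return out;
}

// broadcastShape({3, 1}, {2, 1, 6}) == {2, 3, 6}, matching the shape-inference
// test later in this commit, where Expand of a {3, 1} tensor to dims
// {2, 1, 6} yields an output of shape {2, 3, 6}.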

include/operators/where.h (new file)

@@ -0,0 +1,36 @@
#pragma once

#include "core/operator.h"

namespace infini {
/**
 * @brief Return elements, either from X or Y, depending on condition.
 *
 */
class WhereObj : public OperatorObj {

  public:
    /**
     * @brief Construct a new Where object.
     *
     * @param graph The computation graph that this operator belongs to.
     * @param inputX The input tensor X.
     * @param inputY The input tensor Y.
     * @param output The output tensor.
     * @param condition The condition tensor.
     */
    WhereObj(GraphObj *graph, Tensor inputX, Tensor inputY, Tensor condition,
             Tensor output);
    OP_CLONE(WhereObj);

    optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;

    std::string toString() const override;
    int numInputs() const override { return inputs.size(); }
    int numOutputs() const override { return 1; }

  private:
    vector<int> getWorkloadVector() const override;
    vector<int> getOpAttrVector() const override;
};
} // namespace infini


@@ -409,7 +409,8 @@ class OnnxStub:
                     tensors[node.input[0]],
                     tensors.get(node.output[0]),
                     next(
-                        (attr.i for attr in node.attribute if attr.name == "axis")
+                        (attr.i for attr in node.attribute if attr.name == "axis"),
+                        1,
                     ),
                 )
             elif node.op_type == "PRelu":
@@ -517,7 +518,8 @@ class OnnxStub:
                     tensors[node.input[1]],
                     tensors.get(node.output[0]),
                     next(
-                        (attr.i for attr in node.attribute if attr.name == "axis"), 0
+                        (attr.i for attr in node.attribute if attr.name == "axis"),
+                        0,
                     ),
                 )
             elif node.op_type == "ReduceMean":
@@ -539,7 +541,7 @@ class OnnxStub:
                             for attr in node.attribute
                             if attr.name == "keepdims"
                         ),
-                        1
+                        1,
                     )
                     != 0,
                 )
@@ -589,6 +591,25 @@ class OnnxStub:
                     tensors.get(node.output[0]),
                     next((attr.i for attr in node.attribute if attr.name == "to")),
                 )
+            elif node.op_type == "Expand":
+                shape = _parse_data(data[node.input[1]])
+                tensors[node.output[0]] = self.handler.expand(
+                    tensors[node.input[0]],
+                    tensors.get(node.output[0]),
+                    shape,
+                )
+            elif node.op_type == "Erf":
+                tensors[node.output[0]] = self.handler.erf(
+                    tensors[node.input[0]],
+                    tensors.get(node.output[0]),
+                )
+            elif node.op_type == "Where":
+                tensors[node.output[0]] = self.handler.where(
+                    tensors[node.input[1]],
+                    tensors[node.input[2]],
+                    tensors[node.input[0]],
+                    tensors.get(node.output[0]),
+                )
             else:
                 raise Exception('Unsupported operator "{}"'.format(node.op_type))
             new_node_name.append(node.name)
@@ -814,6 +835,8 @@ class OnnxStub:
                 backend.OpTypeId.Abs,
                 backend.OpTypeId.Identity,
                 backend.OpTypeId.PRelu,
+                backend.OpTypeId.Sqrt,
+                backend.OpTypeId.Erf,
             ]:
                 ctx.push_node(make_node(ty.name, inputs, outputs, name))
             elif ty == backend.OpTypeId.Flatten:
@@ -904,6 +927,13 @@ class OnnxStub:
             elif ty == backend.OpTypeId.Cast:
                 to = backend.cast_to_of(op)
                 ctx.push_node(make_node(ty.name, inputs, outputs, name, to=to))
+            elif ty == backend.OpTypeId.Where:
+                assert len(inputs) == 3, "Check Where Op must have three inputs."
+                new_inputs = [inputs[2], inputs[0], inputs[1]]
+                ctx.push_node(make_node(ty.name, new_inputs, outputs, name))
+            elif ty == backend.OpTypeId.Expand:
+                shape = backend.expand_shape_of(op)
+                ctx.push_node(make_node(ty.name, inputs, outputs, name, shape=shape))
             else:
                 raise Exception("Unsupported OpType", ty)


@@ -207,6 +207,18 @@ class TestStringMethods(unittest.TestCase):
         relu = make_node("Relu", ["x"], ["y"], name="relu")
         make_and_import_model(make_graph([relu], "relu", [x], [y]))

+    def test_erf(self):
+        x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 3, 5, 7])
+        y = make_tensor_value_info("y", TensorProto.FLOAT, [1, 3, 5, 7])
+        erf = make_node("Erf", ["x"], ["y"], name="erf")
+        make_and_import_model(make_graph([erf], "erf", [x], [y]))
+
+    def test_sqrt(self):
+        x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 3, 5, 7])
+        y = make_tensor_value_info("y", TensorProto.FLOAT, [1, 3, 5, 7])
+        sqrt = make_node("Sqrt", ["x"], ["y"], name="sqrt")
+        make_and_import_model(make_graph([sqrt], "sqrt", [x], [y]))
+
     def test_sigmoid(self):
         x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 3, 5, 7])
         y = make_tensor_value_info("y", TensorProto.FLOAT, [1, 3, 5, 7])
@@ -352,6 +364,24 @@ class TestStringMethods(unittest.TestCase):
         )
         make_and_import_model(make_graph([cast], "cast", [input1], [output]))

+    def test_expand(self):
+        data = make_tensor_value_info("data", TensorProto.FLOAT, [3, 1])
+        dim = make_tensor_value_info("dim", TensorProto.INT64, [3])
+        dim_data = make_tensor("dim", TensorProto.INT64, [3], [2, 1, 6])
+        output = make_tensor_value_info("output", TensorProto.FLOAT, [2, 3, 6])
+        expand = make_node("Expand", ["data", "dim"], ["output"], name="expand")
+        make_and_import_model(
+            make_graph([expand], "expand", [data, dim], [output], [dim_data])
+        )
+
+    def test_where(self):
+        x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 3, 5, 7])
+        y = make_tensor_value_info("y", TensorProto.FLOAT, [1, 3, 5, 7])
+        con = make_tensor_value_info("con", TensorProto.BOOL, [1, 3, 5, 7])
+        output = make_tensor_value_info("output", TensorProto.FLOAT, [1, 3, 5, 7])
+        where = make_node("Where", ["x", "y", "con"], ["output"], name="where")
+        make_and_import_model(make_graph([where], "where", [x, y, con], [output]))
+
 if __name__ == "__main__":
     unittest.main()

src/core/data_type.cc (new file)

@@ -0,0 +1,23 @@
#include "core/data_type.h"
namespace infini {
// Move implementation here to avoid compile time error on some platform
// to be consistent with onnx
// https://github.com/onnx/onnx/blob/aeb21329122b96df1d3ef33b500a35ca140b1431/onnx/onnx.proto#L484
const DataType DataType::Undefine(0);
const DataType DataType::Float32(1);
const DataType DataType::UInt8(2);
const DataType DataType::Int8(3);
const DataType DataType::UInt16(4);
const DataType DataType::Int16(5);
const DataType DataType::Int32(6);
const DataType DataType::Int64(7);
const DataType DataType::String(8);
const DataType DataType::Bool(9);
const DataType DataType::Float16(10);
const DataType DataType::Double(11);
const DataType DataType::UInt32(12);
const DataType DataType::UInt64(13);
// TODO: Reserved for complex data type.
const DataType DataType::BFloat16(16);
} // namespace infini


@@ -3,6 +3,7 @@
 #include "operators/concat.h"
 #include "operators/conv.h"
 #include "operators/element_wise.h"
+#include "operators/expand.h"
 #include "operators/gather.h"
 #include "operators/matmul.h"
 #include "operators/pad.h"
@@ -14,6 +15,7 @@
 #include "operators/split.h"
 #include "operators/transpose.h"
 #include "operators/unary.h"
+#include "operators/where.h"

 namespace infini {
@@ -155,6 +157,7 @@ DEFINE_UNARY_METHOD(tanh, Tanh)
 DEFINE_UNARY_METHOD(abs, Abs)
 DEFINE_UNARY_METHOD(sqrt, Sqrt)
 DEFINE_UNARY_METHOD(shape, Shape)
+DEFINE_UNARY_METHOD(erf, Erf)

 // see operators/reshape.h
 DEFINE_UNARY_METHOD(identity, Identity)
@@ -309,6 +312,31 @@ Tensor GraphHandlerObj::cast(Tensor input, Tensor output, int to) {
     }
 }

+Tensor GraphHandlerObj::expand(Tensor input, Tensor output, Shape dims) {
+    if (output) {
+        g->addOpWithOutputs<ExpandObj>(std::move(input), output,
+                                       std::move(dims));
+        return output;
+    } else {
+        return g->addOp<ExpandObj>(std::move(input), output, std::move(dims))
+            ->getOutput();
+    }
+}
+
+Tensor GraphHandlerObj::where(Tensor inputX, Tensor inputY, Tensor condition,
+                              Tensor output) {
+    if (output) {
+        g->addOpWithOutputs<WhereObj>(std::move(inputX), std::move(inputY),
+                                      std::move(condition), output);
+        return output;
+    } else {
+        return g
+            ->addOp<WhereObj>(std::move(inputX), std::move(inputY),
+                              std::move(condition), output)
+            ->getOutput();
+    }
+}
+
 static CastType inferCastType(Tensor input, int to) {
     auto iType = input->getDType();
     auto oType = DataType(to);
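Both new handler methods follow the existing convention in this file: when the caller supplies an output tensor, the op is added with addOpWithOutputs and that tensor is returned; otherwise addOp lets the graph infer and allocate the output, which is returned via getOutput().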


@@ -2,6 +2,7 @@
 #include "operators/batch_norm.h"
 #include "operators/concat.h"
 #include "operators/conv.h"
+#include "operators/expand.h"
 #include "operators/gather.h"
 #include "operators/matmul.h"
 #include "operators/pad.h"
@@ -96,6 +97,10 @@ void export_values(py::module &m) {
         .VALUE(OpType, Resize)
         .VALUE(OpType, Dropout)
         .VALUE(OpType, Cast)
+        .VALUE(OpType, Sqrt)
+        .VALUE(OpType, Expand)
+        .VALUE(OpType, Erf)
+        .VALUE(OpType, Where)
         .export_values();

 #undef VALUE
@@ -226,6 +231,15 @@ static vector<int64_t> reshape_shape_of(Operator op) {
     return ans;
 }

+static vector<int64_t> expand_shape_of(Operator op) {
+    IT_ASSERT(op->getOpType() == OpType::Expand);
+    auto shape = dynamic_cast<const ExpandObj *>(op.get())->getShape();
+    vector<int64_t> ans(shape.size());
+    std::transform(shape.begin(), shape.end(), ans.begin(),
+                   [](auto x) { return static_cast<int64_t>(x); });
+    return ans;
+}
+
 static vector<int64_t> pad_pads_of(Operator op) {
     IT_ASSERT(op->getOpType() == OpType::Pad);
     auto shape = dynamic_cast<const PadObj *>(op.get())->getPads();
@@ -276,6 +290,7 @@ void export_functions(py::module &m) {
         .FUNCTION(reduce_mean_attrs_of)
         .FUNCTION(tensor_dtype)
         .FUNCTION(reshape_shape_of)
+        .FUNCTION(expand_shape_of)
         .FUNCTION(pad_pads_of)
         .FUNCTION(transpose_permute_of)
         .FUNCTION(concat_axis_of)
@@ -359,6 +374,9 @@ void init_graph_builder(py::module &m) {
         .def("slice", &Handler::slice, policy::move)
         .def("pad", &Handler::pad, policy::move)
         .def("cast", &Handler::cast, policy::move)
+        .def("expand", &Handler::expand, policy::move)
+        .def("erf", &Handler::erf, policy::move)
+        .def("where", &Handler::where, policy::move)
         .def("topo_sort", &Handler::topo_sort, policy::automatic)
         .def("optimize", &Handler::optimize, policy::automatic)
         .def("operators", &Handler::operators, policy::move)


@@ -0,0 +1,36 @@
#include "operators/expand.h"
#include "cuda/cuda_expand.h"
#include "cuda/cuda_kernel_wihtout_config.h"
#include "cuda/cuda_runtime.h"

namespace infini {

class ExpandCuda : public CudaKernelWithoutConfig {
    void compute(const Operator &_op,
                 const RuntimeObj *_context) const override {
        auto op = as<ExpandObj>(_op);

        void *const inputData = (op->getInputs(0)->getRawDataPtr<void *>());
        void *const outputData = (op->getOutput()->getRawDataPtr<void *>());
        const auto &in_Shape = op->getInputs(0)->getDims(); // input shape
        const auto &out_Shape = op->getShape();             // output shape

        SmallArray inputShape, outputShape;
        int nDims = op->getInputs(0)->getDims().size();

        IT_ASSERT(nDims <= SMALL_ARRAY_SIZE);
        int outputsize = 1; // the length of the output vector after flatten
        for (int i = 0; i < nDims; ++i) {
            outputShape.data[i] = out_Shape[i];
            inputShape.data[i] = in_Shape[i];
            outputsize *= out_Shape[i];
        }
        expand_kernel((float *)inputData, (float *)outputData, nDims,
                      outputsize, inputShape, outputShape);
    }
};

REGISTER_KERNEL(Device::CUDA, OpType::Expand, DataType::Float32, ExpandCuda,
                "Expand_CUDA_Float32");
}; // namespace infini


@@ -0,0 +1,49 @@
#include "core/common.h"
#include "cuda/cuda_common.h"
#include "utils/small_array.h"

constexpr unsigned int num_threads() { return 32 * 4; }
constexpr int thread_work_size() { return 4; }
constexpr int block_work_size() { return thread_work_size() * num_threads(); }

__global__ void _expand_kernel(float *input, float *output, int nDims,
                               int outputsize, infini::SmallArray inputShape,
                               infini::SmallArray outputShape) {
    int outputIdx =
        blockIdx.x * blockDim.x + threadIdx.x; // i(JKS) + j(KS) + k(S) + s
    if (outputIdx < outputsize) {
        int inputIdx = 0;  // record input index
        int temp = 1;      // stores S, KS, JKS, in order
        int tmp = 1;       // stores s, k, j, i in order
        int v = outputIdx; // v = i(JKS) + j(KS) + k(S) + s
        for (int i = nDims - 1; i >= 0; --i) {
            if (i == 0) {
                tmp = v; // i = outputIdx / (JKS)
            } else {
                tmp = v % outputShape.data[i]; // store s, k, j in order
            }
            if (inputShape.data[i] ==
                1) { // if the input dim is 1, its index can only be 0
                inputIdx += 0;
            } else {
                inputIdx +=
                    tmp * temp; // otherwise + i(JKS) or j(KS) or k(S) or s
            }
            temp *= inputShape.data[i];
            v = v / outputShape.data[i];
        }
        output[outputIdx] = input[inputIdx];
    }
}

namespace infini {
void expand_kernel(float *input, float *output, int nDims, int outputsize,
                   SmallArray inputShape, SmallArray outputShape) {
    int blocksize = block_work_size();
    int gridsize = (outputsize + block_work_size() - 1) / block_work_size();
    _expand_kernel<<<gridsize, blocksize>>>(input, output, nDims, outputsize,
                                            inputShape, outputShape);
}
} // namespace infini
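For intuition, the mapping _expand_kernel performs from an output linear index back to an input linear index can be replayed on the host. The following is a hypothetical stand-alone helper (not code from this commit), assuming row-major layout and shapes already padded to the same rank:

#include <vector>

// Walk the output coordinate from the innermost dimension outwards and
// accumulate the row-major input offset; any input dimension of size 1 is
// broadcast, so its coordinate contributes nothing.
int broadcastOffset(int outIdx, const std::vector<int> &inShape,
                    const std::vector<int> &outShape) {
    int inIdx = 0, inStride = 1;
    for (int i = (int)outShape.size() - 1; i >= 0; --i) {
        int coord = (i == 0) ? outIdx : outIdx % outShape[i];
        if (inShape[i] != 1)        // size-1 dims broadcast: coord -> 0
            inIdx += coord * inStride;
        inStride *= inShape[i];
        outIdx /= outShape[i];
    }
    return inIdx;
}

// E.g. with inShape {2, 1, 2, 1} and outShape {2, 2, 2, 3}, output element 5
// (coordinate {0, 0, 1, 2}) maps back to input element 1 (coordinate
// {0, 0, 1, 0}), which agrees with the expected data in the Expand CUDA test
// later in this commit.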


@@ -12,7 +12,8 @@ class GatherCuda : public CudaKernelWithoutConfig {
         auto in = op->getInputs(0);
         auto index = op->getInputs(1);
         auto out = op->getOutput();
-        metaData.indexValue = index->getRawDataPtr<int *>();
+        metaData.indexValue = index->getRawDataPtr<void *>();
+        metaData.indexType = index->getDType();
         metaData.axis = op->getAxis();
         metaData.inNDim = in->getRank();
         metaData.outNDim = out->getRank();


@@ -1,19 +1,21 @@
 #include "cuda/cuda_common.h"
 #include "cuda/gather.h"

-__device__ int gatheredOffset2Offset(int gOffset, GatherMetaData metaData) {
-    int offset = 0;
+template <typename T>
+__device__ T gatheredOffset2Offset(int gOffset,
+                                   infini::GatherMetaData metaData) {
+    T offset = 0;
     for (int i = metaData.inNDim - 1, k = metaData.outNDim - 1; i >= 0; --i) {
-        int idx = 0;
+        T idx = 0;
         if (i == metaData.axis) {
-            int idxOffset = 0;
+            T idxOffset = 0;
             for (int j = metaData.idxNDim - 1; j >= 0; --j) {
-                int p = gOffset % metaData.idxDim[j];
+                T p = gOffset % metaData.idxDim[j];
                 gOffset = gOffset / metaData.idxDim[j];
                 idxOffset += p * metaData.idxStride[j];
             }
-            idx = metaData.indexValue[idxOffset];
+            idx = static_cast<T *>(metaData.indexValue)[idxOffset];
             k = k - metaData.idxNDim;
         } else {
@@ -26,22 +28,27 @@
     return offset;
 }

-__global__ void _gather_kernel(float *in, float *out, GatherMetaData metaData,
-                               int num) {
-    int tid = threadIdx.x + blockIdx.x * blockDim.x;
+template <typename T>
+__global__ void _gather_kernel(float *in, float *out,
+                               infini::GatherMetaData metaData, size_t num) {
+    T tid = threadIdx.x + blockIdx.x * blockDim.x;
     int stride = blockDim.x * gridDim.x;
     while (tid < num) {
-        int offset = gatheredOffset2Offset(tid, metaData);
+        T offset = gatheredOffset2Offset<T>(tid, metaData);
         out[tid] = in[offset];
         tid += stride;
     }
 }

 namespace infini {
-void gather_kernel(float *in, float *out, GatherMetaData metaData, int num) {
+void gather_kernel(float *in, float *out, GatherMetaData metaData, size_t num) {
     int blockSize = 32 * 16;
     int gridSize = (num + blockSize - 1) / blockSize;
-    _gather_kernel<<<gridSize, blockSize>>>(in, out, metaData, num);
+    if (metaData.indexType == DataType::Int64) {
+        _gather_kernel<int64_t>
+            <<<gridSize, blockSize>>>(in, out, metaData, num);
+    } else {
+        _gather_kernel<int><<<gridSize, blockSize>>>(in, out, metaData, num);
+    }
 }
 } // namespace infini

src/kernels/cuda/where.cc (new file)

@@ -0,0 +1,41 @@
#include "operators/where.h"
#include "cuda/cuda_kernel_wihtout_config.h"
#include "cuda/cuda_runtime.h"
#include "cuda/cuda_where.h"

namespace infini {

class WhereCuda : public CudaKernelWithoutConfig {
    void compute(const Operator &_op,
                 const RuntimeObj *_context) const override {
        auto op = as<WhereObj>(_op);

        void *const inputxData = (op->getInputs(0)->getRawDataPtr<void *>());
        void *const inputyData = (op->getInputs(1)->getRawDataPtr<void *>());
        void *const conditionData =
            (op->getInputs(2)->getRawDataPtr<void *>());
        void *const outputData = (op->getOutput()->getRawDataPtr<void *>());
        const auto &inputx_Shape = op->getInputs(0)->getDims();
        const auto &inputy_Shape = op->getInputs(1)->getDims();
        const auto &condition_Shape = op->getInputs(2)->getDims();
        const auto &output_Shape = op->getOutput()->getDims();

        int nDims = op->getInputs(0)->getDims().size();
        IT_ASSERT(nDims <= SMALL_ARRAY_SIZE);

        SmallArray inputxShape, inputyShape, conditionShape, outputShape;
        for (int i = 0; i < nDims; ++i) {
            inputxShape.data[i] = inputx_Shape[i];
            inputyShape.data[i] = inputy_Shape[i];
            conditionShape.data[i] = condition_Shape[i];
            outputShape.data[i] = output_Shape[i];
        }
        where_kernel((float *)inputxData, (float *)inputyData,
                     (float *)conditionData, (float *)outputData, nDims,
                     inputxShape, inputyShape, conditionShape, outputShape);
    }
};

REGISTER_KERNEL(Device::CUDA, OpType::Where, DataType::Float32, WhereCuda,
                "Where_CUDA_Float32");
}; // namespace infini

src/kernels/cuda/where.cu (new file)

@@ -0,0 +1,82 @@
#include "cuda/cuda_common.h"
#include "utils/small_array.h"

__global__ void _where_kernel(const float *inputx, const float *inputy,
                              const float *condition, float *output, int nDims,
                              int outputsize, infini::SmallArray inputxShape,
                              infini::SmallArray inputyShape,
                              infini::SmallArray conditionShape,
                              infini::SmallArray outputShape) {
    int outputIdx = blockIdx.x * blockDim.x + threadIdx.x;
    if (outputIdx < outputsize) {
        int inputxIdx = 0;
        int temp_inputx = 1;

        int inputyIdx = 0;
        int temp_inputy = 1;

        int conditionIdx = 0;
        int temp_condition = 1;

        int tmp = 1;       // stores s, k, j, i in order
        int v = outputIdx; // v = i(JKS) + j(KS) + k(S) + s
        for (int i = nDims - 1; i >= 0; --i) {
            if (i == 0) {
                tmp = v; // i = outputIdx / (JKS)
            } else {
                tmp = v % outputShape.data[i]; // store s, k, j in order
            }
            if (inputxShape.data[i] == 1) {
                inputxIdx += 0;
            } else {
                inputxIdx +=
                    tmp *
                    temp_inputx; // otherwise + i(JKS) or j(KS) or k(S) or s
            }
            temp_inputx *= inputxShape.data[i];
            //----------------------------
            if (inputyShape.data[i] == 1) {
                inputyIdx += 0;
            } else {
                inputyIdx +=
                    tmp *
                    temp_inputy; // otherwise + i(JKS) or j(KS) or k(S) or s
            }
            temp_inputy *= inputyShape.data[i];
            //--------------------------
            if (conditionShape.data[i] == 1) {
                conditionIdx += 0;
            } else {
                conditionIdx +=
                    tmp *
                    temp_condition; // otherwise + i(JKS) or j(KS) or k(S) or s
            }
            temp_condition *= conditionShape.data[i];
            //-------------------------
            v = v / outputShape.data[i];
        }
        output[outputIdx] =
            condition[conditionIdx] ? inputx[inputxIdx] : inputy[inputyIdx];
    }
}

namespace infini {
void where_kernel(const float *inputx, const float *inputy,
                  const float *condition, float *output, int nDims,
                  infini::SmallArray inputxShape,
                  infini::SmallArray inputyShape,
                  infini::SmallArray conditionShape,
                  infini::SmallArray outputShape) {
    int outputsize = 1;
    for (int i = 0; i < nDims; i++) {
        outputsize *= outputShape.data[i];
    }
    int blocksize = 32 * 16;
    int gridsize = (outputsize + blocksize - 1) / blocksize;
    _where_kernel<<<gridsize, blocksize>>>(
        inputx, inputy, condition, output, nDims, outputsize, inputxShape,
        inputyShape, conditionShape, outputShape);
}
} // namespace infini
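The where kernel performs the same innermost-to-outermost broadcast offset walk as _expand_kernel, but three times per output element (once each for X, Y, and the condition tensor, each with its own running stride), and finally selects condition ? X : Y. The output size is the product of outputShape over the first nDims entries, computed on the host in where_kernel before launch.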

src/operators/expand.cc (new file)

@@ -0,0 +1,41 @@
#include "operators/expand.h"
#include "utils/operator_utils.h"

namespace infini {

ExpandObj::ExpandObj(GraphObj *graph, Tensor input, Tensor output, Shape dims)
    : OperatorObj(OpType::Expand, {input}, {output}), dims(std::move(dims)) {
    IT_ASSERT(checkValid(graph));
}

optional<vector<Shape>> ExpandObj::inferShape(const TensorVec &inputs) const {
    auto shape_input = inputs[0]->getDims();
    Shape ret = infer_broadcast(shape_input, dims);
    return {{ret}};
}

std::string ExpandObj::toString() const {
    std::ostringstream os;
    os << "Expand[" << getGuid() << "]";
    os << "(";
    os << vecToString(inputs[0]->getDims()) << ",";
    os << "dims=" << vecToString(dims) << ",";
    os << "input=" << inputs[0]->getGuid() << ",";
    os << "output=" << outputs[0]->getGuid() << ")";
    return os.str();
}

vector<int> ExpandObj::getWorkloadVector() const {
    vector<int> ret = inputs[0]->getDims();
    ret.insert(ret.end(), dims.begin(), dims.end());
    ret.emplace(ret.begin(), type.underlying());
    return ret;
}

vector<int> ExpandObj::getOpAttrVector() const {
    vector<int> ret = dims;
    ret.emplace(ret.begin(), type.underlying());
    return ret;
}

} // namespace infini


@@ -24,8 +24,8 @@ optional<vector<Shape>> GatherObj::inferShape(const TensorVec &inputs) const {
 vector<DataType> GatherObj::inferDataType(const TensorVec &inputs) const {
     IT_ASSERT(inputs.size() == 2);
-    auto index = inputs[1];
-    IT_ASSERT(index->getDType() == DataType::Int32);
+    auto index_dtype = inputs[1]->getDType();
+    IT_ASSERT(index_dtype == DataType::Int32 || index_dtype == DataType::Int64);
     return {inputs[0]->getDType()};
 }

@@ -36,19 +36,31 @@ bool GatherObj::CheckIndexValid() const {
         return true;

     Runtime runtime = NativeCpuRuntimeObj::getInstance();
-    int *data = (int *)runtime->alloc(index->getBytes());
-    index->getRuntime()->copyBlobToCPU(
-        (void *)data, index->getRawDataPtr<void *>(), index->getBytes());

     bool ret = true;
     auto value = inputs[0]->getDims()[axis];
-    for (size_t i = 0; i < index->size(); ++i) {
-        if (data[i] < 0 || data[i] >= value) {
-            ret = false;
-            break;
+    if (index->getDType() == DataType::Int32) {
+        int *data = (int *)runtime->alloc(index->getBytes());
+        index->getRuntime()->copyBlobToCPU(
+            (void *)data, index->getRawDataPtr<void *>(), index->getBytes());
+        for (size_t i = 0; i < index->size(); ++i) {
+            if (data[i] < 0 || data[i] >= value) {
+                ret = false;
+                break;
+            }
+        }
+        runtime->dealloc(data);
+    } else {
+        int64_t *data = (int64_t *)runtime->alloc(index->getBytes());
+        index->getRuntime()->copyBlobToCPU(
+            (void *)data, index->getRawDataPtr<void *>(), index->getBytes());
+        for (size_t i = 0; i < index->size(); ++i) {
+            if (data[i] < 0 || data[i] >= value) {
+                ret = false;
+                break;
+            }
         }
+        runtime->dealloc(data);
     }
-    runtime->dealloc(data);
     return ret;
 }


@@ -20,7 +20,7 @@ MatmulObj::MatmulObj(GraphObj *graph, Tensor A, Tensor B, Tensor C, bool transA,
     if (ret.empty()) {
         b = 1;
     } else {
-        b = std::accumulate(ret.begin(), ret.end(), 1);
+        b = std::accumulate(ret.begin(), ret.end(), 1, std::multiplies<int>());
     }
     auto kA = *(transA ? shape_a.rbegin() + 1 : shape_a.rbegin());
     auto kB = *(transB ? shape_b.rbegin() : shape_b.rbegin() + 1);
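The matmul fix above is worth spelling out: without an explicit binary operation, std::accumulate folds with operator+, so the batch size b came out as 1 plus the sum of the leading dimensions instead of their product. A self-contained demonstration:

#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

int main() {
    std::vector<int> batchDims{2, 3, 4};
    // Default fold is operator+: yields 1 + 2 + 3 + 4 = 10 (wrong batch size).
    int wrong = std::accumulate(batchDims.begin(), batchDims.end(), 1);
    // Explicit multiplies: yields 1 * 2 * 3 * 4 = 24 (correct batch size).
    int right = std::accumulate(batchDims.begin(), batchDims.end(), 1,
                                std::multiplies<int>());
    std::cout << wrong << " vs " << right << "\n"; // prints "10 vs 24"
}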

src/operators/where.cc (new file)

@@ -0,0 +1,42 @@
#include "operators/where.h"
#include "utils/operator_utils.h"

namespace infini {

WhereObj::WhereObj(GraphObj *graph, Tensor inputX, Tensor inputY,
                   Tensor condition, Tensor output)
    : OperatorObj(OpType::Where, TensorVec{inputX, inputY, condition},
                  {output}) {
    IT_ASSERT(checkValid(graph));
}

optional<vector<Shape>> WhereObj::inferShape(const TensorVec &inputs) const {
    auto shapeX = inputs[0]->getDims();
    auto shapeY = inputs[1]->getDims();
    auto shapeCon = inputs[2]->getDims();
    auto retXY = infer_broadcast(shapeX, shapeY);
    auto ret = infer_broadcast(retXY, shapeCon);
    return {{ret}};
}

std::string WhereObj::toString() const {
    std::ostringstream os;
    os << "Where[" << getGuid() << "]";
    os << "(";
    os << vecToString(inputs[2]->getDims()) << ",";
    os << "inputX=" << inputs[0]->getGuid() << ",";
    os << "inputY=" << inputs[1]->getGuid() << ",";
    os << "condition=" << inputs[2]->getGuid() << ",";
    os << "output=" << outputs[0]->getGuid() << ")";
    return os.str();
}

vector<int> WhereObj::getWorkloadVector() const {
    vector<int> ret = getOutput()->getDims();
    ret.emplace(ret.begin(), type.underlying());
    return ret;
}

vector<int> WhereObj::getOpAttrVector() const { return {type.underlying()}; }

} // namespace infini


@@ -0,0 +1,41 @@
#include "core/graph.h"
#include "core/runtime.h"
#include "cuda/cuda_runtime.h"
#include "cuda/cuda_utility.h"
#include "operators/expand.h"

#include "test.h"

namespace infini {

TEST(Expand, Cuda) {
    Runtime runtime = NativeCpuRuntimeObj::getInstance();
    Graph gCpu = make_ref<GraphObj>(runtime);

    auto t1 = gCpu->addTensor({2, 1, 2, 1}, DataType::Float32);
    gCpu->dataMalloc();
    t1->setData(IncrementalGenerator());
    t1->printData();

    auto cudaRuntime = make_ref<CudaRuntimeObj>();
    Graph gCuda = make_ref<GraphObj>(cudaRuntime);
    auto t1Gpu = gCuda->cloneTensor(t1);

    auto op = gCuda->addOp<ExpandObj>(t1Gpu, nullptr, Shape{2, 2, 2, 3});
    gCuda->dataMalloc();
    t1Gpu->setData(IncrementalGenerator());
    cudaRuntime->run(gCuda);

    // cudaPrintTensor(op->getOutput());
    // copy output from CUDA to CPU
    auto oCpu = gCpu->cloneTensor(op->getOutput());
    oCpu->printData();
    EXPECT_TRUE(
        oCpu->equalData(vector<float>{0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1,
                                      2, 2, 2, 3, 3, 3, 2, 2, 2, 3, 3, 3}));
}

} // namespace infini


@@ -77,7 +77,7 @@ int gatheredOffset2Offset(int gOffset, GatherMetaData metaData) {
             idxOffset += p * metaData.idxStride[j];
         }
-        idx = metaData.indexValue[idxOffset];
+        idx = static_cast<int *>(metaData.indexValue)[idxOffset];
         k = k - metaData.idxNDim;
     } else {
@@ -242,6 +242,31 @@ TEST(Gather, Cuda) {
         indexCuda->copyin(vector<int>{0, 3, 1});
         cudaRuntime->run(gCuda);

+        // cudaPrintTensor(op->getOutput());
+        // copy output from CUDA to CPU
+        auto oCpu = gCpu->cloneTensor(op->getOutput());
+        EXPECT_TRUE(oCpu->equalData(
+            vector<float>{0, 1, 6, 7, 2, 3, 8, 9, 14, 15, 10, 11}));
+    }
+    {
+        Runtime runtime = NativeCpuRuntimeObj::getInstance();
+        Graph gCpu = make_ref<GraphObj>(runtime);
+        auto input = gCpu->addTensor({2, 4, 2}, DataType::Float32);
+        auto index = gCpu->addTensor({3, 1}, DataType::Int64);
+        gCpu->dataMalloc();
+        input->setData(IncrementalGenerator());
+        index->copyin(vector<int64_t>{0, 3, 1});
+        auto cudaRuntime = make_ref<CudaRuntimeObj>();
+        Graph gCuda = make_ref<GraphObj>(cudaRuntime);
+        auto inputCuda = gCuda->cloneTensor(input);
+        auto indexCuda = gCuda->cloneTensor(index);
+        auto op = gCuda->addOp<GatherObj>(inputCuda, indexCuda, nullptr, 1);
+        gCuda->dataMalloc();
+        inputCuda->setData(IncrementalGenerator());
+        indexCuda->copyin(vector<int64_t>{0, 3, 1});
+        cudaRuntime->run(gCuda);
+
         // cudaPrintTensor(op->getOutput());
         // copy output from CUDA to CPU
         auto oCpu = gCpu->cloneTensor(op->getOutput());


@@ -0,0 +1,63 @@
#include "core/graph.h"
#include "core/runtime.h"
#include "cuda/cuda_runtime.h"
#include "cuda/cuda_utility.h"
#include "operators/where.h"

#include "test.h"

namespace infini {

void test_where(const Shape &inputxshape, const vector<float> &inputxdata,
                const Shape &inputyshape, const vector<float> &inputydata,
                const Shape &conditionshape, const vector<int> &conditiondata,
                const vector<float> &ExpectData) {
    Runtime runtime = NativeCpuRuntimeObj::getInstance();
    Graph gCpu = make_ref<GraphObj>(runtime);

    auto condition = gCpu->addTensor(conditionshape, DataType::Int32);
    auto inputx = gCpu->addTensor(inputxshape, DataType::Float32);
    auto inputy = gCpu->addTensor(inputyshape, DataType::Float32);

    gCpu->dataMalloc();
    condition->copyin(conditiondata);
    inputx->copyin(inputxdata);
    inputy->copyin(inputydata);

    auto cudaRuntime = make_ref<CudaRuntimeObj>();
    Graph gCuda = make_ref<GraphObj>(cudaRuntime);

    auto conditionGpu = gCuda->cloneTensor(condition);
    auto inputxGpu = gCuda->cloneTensor(inputx);
    auto inputyGpu = gCuda->cloneTensor(inputy);

    auto op =
        gCuda->addOp<WhereObj>(inputxGpu, inputyGpu, conditionGpu, nullptr);
    gCuda->dataMalloc();
    conditionGpu->copyin(conditiondata);
    inputxGpu->copyin(inputxdata);
    inputyGpu->copyin(inputydata);
    cudaRuntime->run(gCuda);

    // move the output from GPU to CPU and check it
    auto oCpu = gCpu->cloneTensor(op->getOutput());
    oCpu->printData();
    EXPECT_TRUE(oCpu->equalData(ExpectData));
}

TEST(CUDA_Where, run) {
    test_where(
        Shape{2, 2, 3, 1}, vector<float>{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11},
        Shape{2, 2, 3, 1}, vector<float>{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
        Shape{2, 2, 3, 1}, vector<int>{0, 1, 1, 0, 0, 0, 1, 1, 0, 1, 1, 1},
        vector<float>{0., 1., 2., 0., 0., 0., 6., 7., 0., 9., 10., 11.});

    test_where(Shape{2, 1, 1, 3},                 // inputx
               vector<float>{0, 1, 2, 3, 4, 5},
               Shape{1, 2, 1, 1},                 // inputy
               vector<float>{1, 1},
               Shape{2, 1, 3, 1},                 // condition
               vector<int>{0, 1, 1, 0, 0, 0},
               vector<float>{1., 1., 1., 0., 1., 2., 0., 1., 2., 1., 1., 1.,
                             0., 1., 2., 0., 1., 2., 1., 1., 1., 1., 1., 1.,
                             1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.});
    // expected data above taken from an equivalent Python computation
}

} // namespace infini


@@ -0,0 +1,26 @@
#include "core/graph.h"
#include "core/kernel.h"
#include "core/runtime.h"
#include "operators/expand.h"

#include "test.h"

namespace infini {

TEST(Expand, ShapeInference) {
    Runtime runtime = NativeCpuRuntimeObj::getInstance();
    {
        Graph g = make_ref<GraphObj>(runtime);
        Tensor i = g->addTensor({3, 1}, DataType::Float32);
        auto op = g->addOp<ExpandObj>(i, nullptr, Shape{2, 1, 6});
        EXPECT_EQ(op->getOutput()->getDims(), (Shape{2, 3, 6}));
    }
    {
        Graph g = make_ref<GraphObj>(runtime);
        Tensor i = g->addTensor({3, 1}, DataType::Float32);
        auto op = g->addOp<ExpandObj>(i, nullptr, Shape{3, 4});
        EXPECT_EQ(op->getOutput()->getDims(), (Shape{3, 4}));
    }
}

} // namespace infini


@@ -9,11 +9,19 @@ namespace infini {
 TEST(Gather, ShapeInference) {
     Runtime runtime = NativeCpuRuntimeObj::getInstance();
+    {
         Graph g = make_ref<GraphObj>(runtime);
         Tensor i = g->addTensor({1, 3, 4, 4}, DataType::Int32);
         Tensor index = g->addTensor({2, 1, 2}, DataType::Int32);
         auto op = g->addOp<GatherObj>(i, index, nullptr, 1);
         EXPECT_EQ(op->getOutput()->getDims(), (Shape{1, 2, 1, 2, 4, 4}));
+    }
+    {
+        Graph g = make_ref<GraphObj>(runtime);
+        Tensor i = g->addTensor({1, 3, 4, 4}, DataType::Int32);
+        Tensor index = g->addTensor({2, 1, 2}, DataType::Int64);
+        auto op = g->addOp<GatherObj>(i, index, nullptr, 1);
+        EXPECT_EQ(op->getOutput()->getDims(), (Shape{1, 2, 1, 2, 4, 4}));
+    }
 }
 } // namespace infini


@@ -21,6 +21,12 @@ TEST(ReduceMean, ShapeInference) {
         auto op = g->addOp<ReduceMeanObj>(i, nullptr, vector<int>{1, 3}, true);
         EXPECT_EQ(op->getOutput()->getDims(), (Shape{2, 1, 3, 1}));
     }
+    {
+        Graph g = make_ref<GraphObj>(runtime);
+        Tensor i = g->addTensor({2, 3, 3, 4}, DataType::Float32);
+        auto op = g->addOp<ReduceMeanObj>(i, nullptr, vector<int>{-3, 3}, true);
+        EXPECT_EQ(op->getOutput()->getDims(), (Shape{2, 1, 3, 1}));
+    }
     {
         Graph g = make_ref<GraphObj>(runtime);
         Tensor i = g->addTensor({2, 3, 3, 4}, DataType::Float32);
@@ -33,6 +39,13 @@ TEST(ReduceMean, ShapeInference) {
         auto op = g->addOp<ReduceMeanObj>(i, nullptr, vector<int>{1, 3}, false);
         EXPECT_EQ(op->getOutput()->getDims(), (Shape{2, 3}));
     }
+    {
+        Graph g = make_ref<GraphObj>(runtime);
+        Tensor i = g->addTensor({2, 3, 3, 4}, DataType::Float32);
+        auto op =
+            g->addOp<ReduceMeanObj>(i, nullptr, vector<int>{-3, 3}, false);
+        EXPECT_EQ(op->getOutput()->getDims(), (Shape{2, 3}));
+    }
 }
 } // namespace infini


@@ -0,0 +1,46 @@
#include "core/graph.h"
#include "core/kernel.h"
#include "core/runtime.h"
#include "operators/where.h"

#include "test.h"

namespace infini {

TEST(Where, ShapeInference) {
    Runtime runtime = NativeCpuRuntimeObj::getInstance();
    {
        Graph g = make_ref<GraphObj>(runtime);
        Tensor x = g->addTensor({2, 2}, DataType::Float32);
        Tensor y = g->addTensor({2, 2}, DataType::Float32);
        Tensor con = g->addTensor({2, 2}, DataType::Bool);
        auto op = g->addOp<WhereObj>(x, y, con, nullptr);
        EXPECT_EQ(op->getOutput()->getDims(), (Shape{2, 2}));
    }
    {
        Graph g = make_ref<GraphObj>(runtime);
        Tensor x = g->addTensor({1, 12, 224, 224}, DataType::Float32);
        Tensor y = g->addTensor({1, 1, 224, 224}, DataType::Float32);
        Tensor con = g->addTensor({1, 224, 1}, DataType::Bool);
        auto op = g->addOp<WhereObj>(x, y, con, nullptr);
        EXPECT_EQ(op->getOutput()->getDims(), (Shape{1, 12, 224, 224}));
    }
    {
        Graph g = make_ref<GraphObj>(runtime);
        Tensor x = g->addTensor({12, 224, 224}, DataType::Float32);
        Tensor y = g->addTensor({1, 1, 224, 224}, DataType::Float32);
        Tensor con = g->addTensor({1, 224}, DataType::Bool);
        auto op = g->addOp<WhereObj>(x, y, con, nullptr);
        EXPECT_EQ(op->getOutput()->getDims(), (Shape{1, 12, 224, 224}));
    }
    {
        Graph g = make_ref<GraphObj>(runtime);
        Tensor x = g->addTensor({12, 224, 224}, DataType::Float32);
        Tensor y = g->addTensor({1, 1, 224, 224}, DataType::Float32);
        Tensor con = g->addTensor({2, 1, 1, 1, 224}, DataType::Bool);
        auto op = g->addOp<WhereObj>(x, y, con, nullptr);
        EXPECT_EQ(op->getOutput()->getDims(), (Shape{2, 1, 12, 224, 224}));
    }
}

} // namespace infini