forked from jiuyuan/InfiniTensor
框架支持bert/gpt2模型构图 (#94)
* feat: support to sqrt op * feat: support to erf op * feat: support to expand op * feat: support to where op * fix: gather op index can be int64_t(hard coding) * fix: some wrong use * style: fix the format style * test: add test for change op * fix: rebase to master * fix: fix matmul b compute wrong * add expand and where kernel * Add int64 support for cuda gather kernel * add test_where.cc * add "expand.(cu/cc,test,cuda),modified where.cu" * Separate initialization of datatypes to avoid compile error * modify where.(cu/cc/h,test), expand and clip * Format fix * Format fix --------- Co-authored-by: xgqdut2016 <kenan_gewei@163.com> Co-authored-by: panzezhong <panzezhong@qiyuanlab.com> Co-authored-by: Haojie Wang <haojie0429@gmail.com>
This commit is contained in:
parent
d8ffd8a4b7
commit
3e6ef305f1
|
@ -1,3 +1,4 @@
|
|||
#pragma once
|
||||
#include "core/common.h"
|
||||
|
||||
namespace infini {
|
||||
|
@ -69,23 +70,6 @@ class DataType {
|
|||
int getIndex() const { return index; }
|
||||
};
|
||||
|
||||
// to be consistent with onnx
|
||||
// https://github.com/onnx/onnx/blob/aeb21329122b96df1d3ef33b500a35ca140b1431/onnx/onnx.proto#L484
|
||||
inline const DataType DataType::Undefine(0);
|
||||
inline const DataType DataType::Float32(1);
|
||||
inline const DataType DataType::UInt8(2);
|
||||
inline const DataType DataType::Int8(3);
|
||||
inline const DataType DataType::UInt16(4);
|
||||
inline const DataType DataType::Int16(5);
|
||||
inline const DataType DataType::Int32(6);
|
||||
inline const DataType DataType::Int64(7);
|
||||
inline const DataType DataType::String(8);
|
||||
inline const DataType DataType::Bool(9);
|
||||
inline const DataType DataType::Float16(10);
|
||||
inline const DataType DataType::Double(11);
|
||||
inline const DataType DataType::UInt32(12);
|
||||
inline const DataType DataType::UInt64(13);
|
||||
inline const DataType DataType::BFloat16(16);
|
||||
// Method definitions are out of the declaration due to GCC bug:
|
||||
// https://stackoverflow.com/questions/49707184/explicit-specialization-in-non-namespace-scope-does-not-compile-in-gcc
|
||||
template <> inline int DataType::get<float>() { return 0; }
|
||||
|
|
|
@ -47,6 +47,7 @@ class GraphHandlerObj {
|
|||
Tensor relu(Tensor x, Tensor y);
|
||||
Tensor sigmoid(Tensor x, Tensor y);
|
||||
Tensor tanh(Tensor x, Tensor y);
|
||||
Tensor erf(Tensor x, Tensor y);
|
||||
Tensor softmax(Tensor x, Tensor y, int axis);
|
||||
Tensor abs(Tensor x, Tensor y);
|
||||
Tensor sqrt(Tensor x, Tensor y);
|
||||
|
@ -70,6 +71,8 @@ class GraphHandlerObj {
|
|||
Tensor pad(Tensor input, Tensor output, const vector<int> &pads,
|
||||
const optional<vector<int>> &axes);
|
||||
Tensor cast(Tensor input, Tensor output, int to);
|
||||
Tensor expand(Tensor input, Tensor output, Shape dims);
|
||||
Tensor where(Tensor inputX, Tensor inputY, Tensor condition, Tensor output);
|
||||
|
||||
//------ modifiers
|
||||
|
||||
|
|
|
@ -0,0 +1,9 @@
|
|||
#pragma once
|
||||
|
||||
#include "operators/unary.h"
|
||||
#include "utils/small_array.h"
|
||||
namespace infini {
|
||||
void expand_kernel(float *input, float *output, int nDims, int outputsize,
|
||||
SmallArray inputShape, SmallArray outputShape);
|
||||
|
||||
}; // namespace infini
|
|
@ -0,0 +1,13 @@
|
|||
#pragma once
|
||||
#include "operators/unary.h"
|
||||
#include "utils/small_array.h"
|
||||
|
||||
namespace infini {
|
||||
void where_kernel(const float *inputx, const float *inputy,
|
||||
const float *condition, float *output, int nDims,
|
||||
infini::SmallArray inputxShape,
|
||||
infini::SmallArray inputyShape,
|
||||
infini::SmallArray conditionShape,
|
||||
infini::SmallArray outputShape);
|
||||
|
||||
}; // namespace infini
|
|
@ -1,7 +1,10 @@
|
|||
#pragma once
|
||||
#include "core/data_type.h"
|
||||
|
||||
typedef struct {
|
||||
int *indexValue;
|
||||
namespace infini {
|
||||
struct GatherMetaData {
|
||||
void *indexValue;
|
||||
DataType indexType;
|
||||
int axis;
|
||||
int inNDim;
|
||||
int outNDim;
|
||||
|
@ -10,8 +13,7 @@ typedef struct {
|
|||
int idxDim[4];
|
||||
int idxStride[4];
|
||||
int inStride[4];
|
||||
} GatherMetaData;
|
||||
};
|
||||
|
||||
namespace infini {
|
||||
void gather_kernel(float *in, float *out, GatherMetaData metaData, int num);
|
||||
}
|
||||
void gather_kernel(float *in, float *out, GatherMetaData metaData, size_t num);
|
||||
} // namespace infini
|
||||
|
|
|
@ -0,0 +1,36 @@
|
|||
#pragma once
|
||||
#include "core/operator.h"
|
||||
|
||||
namespace infini {
|
||||
/**
|
||||
* @brief Broadcast the input tensor following the given shape and the
|
||||
* broadcast rule.
|
||||
*
|
||||
*/
|
||||
class ExpandObj : public OperatorObj {
|
||||
Shape dims;
|
||||
|
||||
public:
|
||||
/**
|
||||
* @brief Construct a new Expand object.
|
||||
* @param graph The computation graph that this operator belongs to.
|
||||
* @param input The input tensor.
|
||||
* @param output The output tensor.
|
||||
* @param dims The shape you want to expand to, following the broadcast
|
||||
* rule.
|
||||
*/
|
||||
ExpandObj(GraphObj *graph, Tensor input, Tensor output, Shape dims);
|
||||
OP_CLONE(ExpandObj);
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||
|
||||
std::string toString() const override;
|
||||
int numInputs() const override { return 1; }
|
||||
int numOutputs() const override { return 1; }
|
||||
Shape getShape() const { return dims; }
|
||||
|
||||
private:
|
||||
vector<int> getWorkloadVector() const override;
|
||||
vector<int> getOpAttrVector() const override;
|
||||
};
|
||||
|
||||
} // namespace infini
|
|
@ -0,0 +1,36 @@
|
|||
#pragma once
|
||||
#include "core/operator.h"
|
||||
|
||||
namespace infini {
|
||||
/**
|
||||
* @brief Return elements, either from X or Y, depending on condition.
|
||||
*
|
||||
*/
|
||||
class WhereObj : public OperatorObj {
|
||||
|
||||
public:
|
||||
/**
|
||||
* @brief Construct a new Where object.
|
||||
*
|
||||
* @param graph The computation graph that this operator belongs to.
|
||||
* @param inputX The input tensor X.
|
||||
* @param inputY The input tensor Y.
|
||||
* @param output The output tensor.
|
||||
* @param condition The condition tensor.
|
||||
*/
|
||||
WhereObj(GraphObj *graph, Tensor inputX, Tensor inputY, Tensor condition,
|
||||
Tensor output);
|
||||
OP_CLONE(WhereObj);
|
||||
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||
|
||||
std::string toString() const override;
|
||||
int numInputs() const override { return inputs.size(); }
|
||||
int numOutputs() const override { return 1; }
|
||||
|
||||
private:
|
||||
vector<int> getWorkloadVector() const override;
|
||||
vector<int> getOpAttrVector() const override;
|
||||
};
|
||||
|
||||
} // namespace infini
|
|
@ -409,7 +409,8 @@ class OnnxStub:
|
|||
tensors[node.input[0]],
|
||||
tensors.get(node.output[0]),
|
||||
next(
|
||||
(attr.i for attr in node.attribute if attr.name == "axis")
|
||||
(attr.i for attr in node.attribute if attr.name == "axis"),
|
||||
1,
|
||||
),
|
||||
)
|
||||
elif node.op_type == "PRelu":
|
||||
|
@ -517,7 +518,8 @@ class OnnxStub:
|
|||
tensors[node.input[1]],
|
||||
tensors.get(node.output[0]),
|
||||
next(
|
||||
(attr.i for attr in node.attribute if attr.name == "axis"), 0
|
||||
(attr.i for attr in node.attribute if attr.name == "axis"),
|
||||
0,
|
||||
),
|
||||
)
|
||||
elif node.op_type == "ReduceMean":
|
||||
|
@ -539,7 +541,7 @@ class OnnxStub:
|
|||
for attr in node.attribute
|
||||
if attr.name == "keepdims"
|
||||
),
|
||||
1
|
||||
1,
|
||||
)
|
||||
!= 0,
|
||||
)
|
||||
|
@ -589,6 +591,25 @@ class OnnxStub:
|
|||
tensors.get(node.output[0]),
|
||||
next((attr.i for attr in node.attribute if attr.name == "to")),
|
||||
)
|
||||
elif node.op_type == "Expand":
|
||||
shape = _parse_data(data[node.input[1]])
|
||||
tensors[node.output[0]] = self.handler.expand(
|
||||
tensors[node.input[0]],
|
||||
tensors.get(node.output[0]),
|
||||
shape,
|
||||
)
|
||||
elif node.op_type == "Erf":
|
||||
tensors[node.output[0]] = self.handler.erf(
|
||||
tensors[node.input[0]],
|
||||
tensors.get(node.output[0]),
|
||||
)
|
||||
elif node.op_type == "Where":
|
||||
tensors[node.output[0]] = self.handler.where(
|
||||
tensors[node.input[1]],
|
||||
tensors[node.input[2]],
|
||||
tensors[node.input[0]],
|
||||
tensors.get(node.output[0]),
|
||||
)
|
||||
else:
|
||||
raise Exception('Unsupported operator "{}"'.format(node.op_type))
|
||||
new_node_name.append(node.name)
|
||||
|
@ -814,6 +835,8 @@ class OnnxStub:
|
|||
backend.OpTypeId.Abs,
|
||||
backend.OpTypeId.Identity,
|
||||
backend.OpTypeId.PRelu,
|
||||
backend.OpTypeId.Sqrt,
|
||||
backend.OpTypeId.Erf,
|
||||
]:
|
||||
ctx.push_node(make_node(ty.name, inputs, outputs, name))
|
||||
elif ty == backend.OpTypeId.Flatten:
|
||||
|
@ -904,6 +927,13 @@ class OnnxStub:
|
|||
elif ty == backend.OpTypeId.Cast:
|
||||
to = backend.cast_to_of(op)
|
||||
ctx.push_node(make_node(ty.name, inputs, outputs, name, to=to))
|
||||
elif ty == backend.OpTypeId.Where:
|
||||
assert len(inputs) == 3, "Check Where Op must have three inputs."
|
||||
new_inputs = [inputs[2], inputs[0], inputs[1]]
|
||||
ctx.push_node(make_node(ty.name, new_inputs, outputs, name))
|
||||
elif ty == backend.OpTypeId.Expand:
|
||||
shape = backend.expand_shape_of(op)
|
||||
ctx.push_node(make_node(ty.name, inputs, outputs, name, shape=shape))
|
||||
else:
|
||||
raise Exception("Unsupported OpType", ty)
|
||||
|
||||
|
|
|
@ -207,6 +207,18 @@ class TestStringMethods(unittest.TestCase):
|
|||
relu = make_node("Relu", ["x"], ["y"], name="relu")
|
||||
make_and_import_model(make_graph([relu], "relu", [x], [y]))
|
||||
|
||||
def test_erf(self):
|
||||
x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 3, 5, 7])
|
||||
y = make_tensor_value_info("y", TensorProto.FLOAT, [1, 3, 5, 7])
|
||||
erf = make_node("Erf", ["x"], ["y"], name="erf")
|
||||
make_and_import_model(make_graph([erf], "erf", [x], [y]))
|
||||
|
||||
def test_sqrt(self):
|
||||
x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 3, 5, 7])
|
||||
y = make_tensor_value_info("y", TensorProto.FLOAT, [1, 3, 5, 7])
|
||||
sqrt = make_node("Sqrt", ["x"], ["y"], name="sqrt")
|
||||
make_and_import_model(make_graph([sqrt], "sqrt", [x], [y]))
|
||||
|
||||
def test_sigmoid(self):
|
||||
x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 3, 5, 7])
|
||||
y = make_tensor_value_info("y", TensorProto.FLOAT, [1, 3, 5, 7])
|
||||
|
@ -352,6 +364,24 @@ class TestStringMethods(unittest.TestCase):
|
|||
)
|
||||
make_and_import_model(make_graph([cast], "cast", [input1], [output]))
|
||||
|
||||
def test_expand(self):
|
||||
data = make_tensor_value_info("data", TensorProto.FLOAT, [3, 1])
|
||||
dim = make_tensor_value_info("dim", TensorProto.INT64, [3])
|
||||
dim_data = make_tensor("dim", TensorProto.INT64, [3], [2, 1, 6])
|
||||
output = make_tensor_value_info("output", TensorProto.FLOAT, [2, 3, 6])
|
||||
expand = make_node("Expand", ["data", "dim"], ["output"], name="expand")
|
||||
make_and_import_model(
|
||||
make_graph([expand], "expand", [data, dim], [output], [dim_data])
|
||||
)
|
||||
|
||||
def test_where(self):
|
||||
x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 3, 5, 7])
|
||||
y = make_tensor_value_info("y", TensorProto.FLOAT, [1, 3, 5, 7])
|
||||
con = make_tensor_value_info("con", TensorProto.BOOL, [1, 3, 5, 7])
|
||||
output = make_tensor_value_info("output", TensorProto.FLOAT, [1, 3, 5, 7])
|
||||
where = make_node("Where", ["x", "y", "con"], ["output"], name="where")
|
||||
make_and_import_model(make_graph([where], "where", [x, y, con], [output]))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
|
@ -0,0 +1,23 @@
|
|||
#include "core/data_type.h"
|
||||
|
||||
namespace infini {
|
||||
// Move implementation here to avoid compile time error on some platform
|
||||
// to be consistent with onnx
|
||||
// https://github.com/onnx/onnx/blob/aeb21329122b96df1d3ef33b500a35ca140b1431/onnx/onnx.proto#L484
|
||||
const DataType DataType::Undefine(0);
|
||||
const DataType DataType::Float32(1);
|
||||
const DataType DataType::UInt8(2);
|
||||
const DataType DataType::Int8(3);
|
||||
const DataType DataType::UInt16(4);
|
||||
const DataType DataType::Int16(5);
|
||||
const DataType DataType::Int32(6);
|
||||
const DataType DataType::Int64(7);
|
||||
const DataType DataType::String(8);
|
||||
const DataType DataType::Bool(9);
|
||||
const DataType DataType::Float16(10);
|
||||
const DataType DataType::Double(11);
|
||||
const DataType DataType::UInt32(12);
|
||||
const DataType DataType::UInt64(13);
|
||||
// TODO: Reserved for complex data type.
|
||||
const DataType DataType::BFloat16(16);
|
||||
} // namespace infini
|
|
@ -3,6 +3,7 @@
|
|||
#include "operators/concat.h"
|
||||
#include "operators/conv.h"
|
||||
#include "operators/element_wise.h"
|
||||
#include "operators/expand.h"
|
||||
#include "operators/gather.h"
|
||||
#include "operators/matmul.h"
|
||||
#include "operators/pad.h"
|
||||
|
@ -14,6 +15,7 @@
|
|||
#include "operators/split.h"
|
||||
#include "operators/transpose.h"
|
||||
#include "operators/unary.h"
|
||||
#include "operators/where.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
|
@ -155,6 +157,7 @@ DEFINE_UNARY_METHOD(tanh, Tanh)
|
|||
DEFINE_UNARY_METHOD(abs, Abs)
|
||||
DEFINE_UNARY_METHOD(sqrt, Sqrt)
|
||||
DEFINE_UNARY_METHOD(shape, Shape)
|
||||
DEFINE_UNARY_METHOD(erf, Erf)
|
||||
|
||||
// see operators/reshape.h
|
||||
DEFINE_UNARY_METHOD(identity, Identity)
|
||||
|
@ -309,6 +312,31 @@ Tensor GraphHandlerObj::cast(Tensor input, Tensor output, int to) {
|
|||
}
|
||||
}
|
||||
|
||||
Tensor GraphHandlerObj::expand(Tensor input, Tensor output, Shape dims) {
|
||||
if (output) {
|
||||
g->addOpWithOutputs<ExpandObj>(std::move(input), output,
|
||||
std::move(dims));
|
||||
return output;
|
||||
} else {
|
||||
return g->addOp<ExpandObj>(std::move(input), output, std::move(dims))
|
||||
->getOutput();
|
||||
}
|
||||
}
|
||||
|
||||
Tensor GraphHandlerObj::where(Tensor inputX, Tensor inputY, Tensor condition,
|
||||
Tensor output) {
|
||||
if (output) {
|
||||
g->addOpWithOutputs<WhereObj>(std::move(inputX), std::move(inputY),
|
||||
std::move(condition), output);
|
||||
return output;
|
||||
} else {
|
||||
return g
|
||||
->addOp<WhereObj>(std::move(inputX), std::move(inputY),
|
||||
std::move(condition), output)
|
||||
->getOutput();
|
||||
}
|
||||
}
|
||||
|
||||
static CastType inferCastType(Tensor input, int to) {
|
||||
auto iType = input->getDType();
|
||||
auto oType = DataType(to);
|
||||
|
|
|
@ -2,6 +2,7 @@
|
|||
#include "operators/batch_norm.h"
|
||||
#include "operators/concat.h"
|
||||
#include "operators/conv.h"
|
||||
#include "operators/expand.h"
|
||||
#include "operators/gather.h"
|
||||
#include "operators/matmul.h"
|
||||
#include "operators/pad.h"
|
||||
|
@ -96,6 +97,10 @@ void export_values(py::module &m) {
|
|||
.VALUE(OpType, Resize)
|
||||
.VALUE(OpType, Dropout)
|
||||
.VALUE(OpType, Cast)
|
||||
.VALUE(OpType, Sqrt)
|
||||
.VALUE(OpType, Expand)
|
||||
.VALUE(OpType, Erf)
|
||||
.VALUE(OpType, Where)
|
||||
.export_values();
|
||||
|
||||
#undef VALUE
|
||||
|
@ -226,6 +231,15 @@ static vector<int64_t> reshape_shape_of(Operator op) {
|
|||
return ans;
|
||||
}
|
||||
|
||||
static vector<int64_t> expand_shape_of(Operator op) {
|
||||
IT_ASSERT(op->getOpType() == OpType::Expand);
|
||||
auto shape = dynamic_cast<const ExpandObj *>(op.get())->getShape();
|
||||
vector<int64_t> ans(shape.size());
|
||||
std::transform(shape.begin(), shape.end(), ans.begin(),
|
||||
[](auto x) { return static_cast<int64_t>(x); });
|
||||
return ans;
|
||||
}
|
||||
|
||||
static vector<int64_t> pad_pads_of(Operator op) {
|
||||
IT_ASSERT(op->getOpType() == OpType::Pad);
|
||||
auto shape = dynamic_cast<const PadObj *>(op.get())->getPads();
|
||||
|
@ -276,6 +290,7 @@ void export_functions(py::module &m) {
|
|||
.FUNCTION(reduce_mean_attrs_of)
|
||||
.FUNCTION(tensor_dtype)
|
||||
.FUNCTION(reshape_shape_of)
|
||||
.FUNCTION(expand_shape_of)
|
||||
.FUNCTION(pad_pads_of)
|
||||
.FUNCTION(transpose_permute_of)
|
||||
.FUNCTION(concat_axis_of)
|
||||
|
@ -359,6 +374,9 @@ void init_graph_builder(py::module &m) {
|
|||
.def("slice", &Handler::slice, policy::move)
|
||||
.def("pad", &Handler::pad, policy::move)
|
||||
.def("cast", &Handler::cast, policy::move)
|
||||
.def("expand", &Handler::expand, policy::move)
|
||||
.def("erf", &Handler::erf, policy::move)
|
||||
.def("where", &Handler::where, policy::move)
|
||||
.def("topo_sort", &Handler::topo_sort, policy::automatic)
|
||||
.def("optimize", &Handler::optimize, policy::automatic)
|
||||
.def("operators", &Handler::operators, policy::move)
|
||||
|
|
|
@ -0,0 +1,36 @@
|
|||
#include "operators/expand.h"
|
||||
#include "cuda/cuda_expand.h"
|
||||
#include "cuda/cuda_kernel_wihtout_config.h"
|
||||
#include "cuda/cuda_runtime.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
class ExpandCuda : public CudaKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<ExpandObj>(_op);
|
||||
|
||||
void *const inputData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const outputData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
const auto &in_Shape = op->getInputs(0)->getDims(); // input shape
|
||||
const auto &out_Shape = op->getShape(); // output shape
|
||||
|
||||
SmallArray inputShape, outputShape;
|
||||
int nDims = op->getInputs(0)->getDims().size();
|
||||
|
||||
IT_ASSERT(nDims <= SMALL_ARRAY_SIZE);
|
||||
int outputsize = 1; // the length of the output vector after flatten
|
||||
for (int i = 0; i < nDims; ++i) {
|
||||
outputShape.data[i] = out_Shape[i];
|
||||
inputShape.data[i] = in_Shape[i];
|
||||
outputsize *= out_Shape[i];
|
||||
}
|
||||
expand_kernel((float *)inputData, (float *)outputData, nDims,
|
||||
outputsize, inputShape, outputShape);
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::Expand, DataType::Float32, ExpandCuda,
|
||||
"Expand_CUDA_Float32");
|
||||
|
||||
}; // namespace infini
|
|
@ -0,0 +1,49 @@
|
|||
#include "core/common.h"
|
||||
#include "cuda/cuda_common.h"
|
||||
#include "utils/small_array.h"
|
||||
|
||||
constexpr unsigned int num_threads() { return 32 * 4; }
|
||||
constexpr int thread_work_size() { return 4; }
|
||||
constexpr int block_work_size() { return thread_work_size() * num_threads(); }
|
||||
|
||||
__global__ void _expand_kernel(float *input, float *output, int nDims,
|
||||
int outputsize, infini::SmallArray inputShape,
|
||||
infini::SmallArray outputShape) {
|
||||
|
||||
int outputIdx =
|
||||
blockIdx.x * blockDim.x + threadIdx.x; // i(JKS) + j(KS) + k(S) + s
|
||||
if (outputIdx < outputsize) {
|
||||
int inputIdx = 0; // record input index
|
||||
int temp = 1; // stored S, KS, JKS, in order
|
||||
int tmp = 1; // stored s,k,j,i in order
|
||||
int v = outputIdx; // v = i(JKS) + j(KS) + k(S) + s
|
||||
for (int i = nDims - 1; i >= 0; --i) {
|
||||
if (i == 0) {
|
||||
tmp = v; // i = outputIdx/(JKS)
|
||||
} else {
|
||||
tmp = v % outputShape.data[i]; // store s,k,j in order
|
||||
}
|
||||
if (inputShape.data[i] ==
|
||||
1) { // if input shape = 1, the index only equal 0
|
||||
inputIdx += 0;
|
||||
} else {
|
||||
inputIdx +=
|
||||
tmp * temp; // otherwise +i(JKS) or j(KS) or k(S) or s
|
||||
}
|
||||
temp *= inputShape.data[i];
|
||||
v = v / outputShape.data[i];
|
||||
}
|
||||
output[outputIdx] = input[inputIdx];
|
||||
}
|
||||
}
|
||||
|
||||
namespace infini {
|
||||
void expand_kernel(float *input, float *output, int nDims, int outputsize,
|
||||
SmallArray inputShape, SmallArray outputShape) {
|
||||
int blocksize = block_work_size();
|
||||
int gridsize = (outputsize + block_work_size() - 1) / block_work_size();
|
||||
_expand_kernel<<<gridsize, blocksize>>>(input, output, nDims, outputsize,
|
||||
inputShape, outputShape);
|
||||
}
|
||||
|
||||
} // namespace infini
|
|
@ -12,7 +12,8 @@ class GatherCuda : public CudaKernelWithoutConfig {
|
|||
auto in = op->getInputs(0);
|
||||
auto index = op->getInputs(1);
|
||||
auto out = op->getOutput();
|
||||
metaData.indexValue = index->getRawDataPtr<int *>();
|
||||
metaData.indexValue = index->getRawDataPtr<void *>();
|
||||
metaData.indexType = index->getDType();
|
||||
metaData.axis = op->getAxis();
|
||||
metaData.inNDim = in->getRank();
|
||||
metaData.outNDim = out->getRank();
|
||||
|
|
|
@ -1,19 +1,21 @@
|
|||
#include "cuda/cuda_common.h"
|
||||
#include "cuda/gather.h"
|
||||
|
||||
__device__ int gatheredOffset2Offset(int gOffset, GatherMetaData metaData) {
|
||||
int offset = 0;
|
||||
template <typename T>
|
||||
__device__ T gatheredOffset2Offset(int gOffset,
|
||||
infini::GatherMetaData metaData) {
|
||||
T offset = 0;
|
||||
for (int i = metaData.inNDim - 1, k = metaData.outNDim - 1; i >= 0; --i) {
|
||||
int idx = 0;
|
||||
T idx = 0;
|
||||
if (i == metaData.axis) {
|
||||
int idxOffset = 0;
|
||||
T idxOffset = 0;
|
||||
for (int j = metaData.idxNDim - 1; j >= 0; --j) {
|
||||
int p = gOffset % metaData.idxDim[j];
|
||||
T p = gOffset % metaData.idxDim[j];
|
||||
gOffset = gOffset / metaData.idxDim[j];
|
||||
idxOffset += p * metaData.idxStride[j];
|
||||
}
|
||||
|
||||
idx = metaData.indexValue[idxOffset];
|
||||
idx = static_cast<T *>(metaData.indexValue)[idxOffset];
|
||||
k = k - metaData.idxNDim;
|
||||
|
||||
} else {
|
||||
|
@ -26,22 +28,27 @@ __device__ int gatheredOffset2Offset(int gOffset, GatherMetaData metaData) {
|
|||
return offset;
|
||||
}
|
||||
|
||||
__global__ void _gather_kernel(float *in, float *out, GatherMetaData metaData,
|
||||
int num) {
|
||||
int tid = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
template <typename T>
|
||||
__global__ void _gather_kernel(float *in, float *out,
|
||||
infini::GatherMetaData metaData, size_t num) {
|
||||
T tid = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
int stride = blockDim.x * gridDim.x;
|
||||
while (tid < num) {
|
||||
int offset = gatheredOffset2Offset(tid, metaData);
|
||||
T offset = gatheredOffset2Offset<T>(tid, metaData);
|
||||
out[tid] = in[offset];
|
||||
tid += stride;
|
||||
}
|
||||
}
|
||||
|
||||
namespace infini {
|
||||
void gather_kernel(float *in, float *out, GatherMetaData metaData, int num) {
|
||||
void gather_kernel(float *in, float *out, GatherMetaData metaData, size_t num) {
|
||||
int blockSize = 32 * 16;
|
||||
int gridSize = (num + blockSize - 1) / blockSize;
|
||||
|
||||
_gather_kernel<<<gridSize, blockSize>>>(in, out, metaData, num);
|
||||
if (metaData.indexType == DataType::Int64) {
|
||||
_gather_kernel<int64_t>
|
||||
<<<gridSize, blockSize>>>(in, out, metaData, num);
|
||||
} else {
|
||||
_gather_kernel<int><<<gridSize, blockSize>>>(in, out, metaData, num);
|
||||
}
|
||||
}
|
||||
} // namespace infini
|
||||
|
|
|
@ -0,0 +1,41 @@
|
|||
#include "operators/where.h"
|
||||
#include "cuda/cuda_kernel_wihtout_config.h"
|
||||
#include "cuda/cuda_runtime.h"
|
||||
#include "cuda/cuda_where.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
class WhereCuda : public CudaKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<WhereObj>(_op);
|
||||
|
||||
void *const inputxData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const inputyData = (op->getInputs(1)->getRawDataPtr<void *>());
|
||||
void *const conditionData = (op->getInputs(2)->getRawDataPtr<void *>());
|
||||
void *const outputData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
const auto &inputx_Shape = op->getInputs(0)->getDims();
|
||||
const auto &inputy_Shape = op->getInputs(1)->getDims();
|
||||
const auto &condition_Shape = op->getInputs(2)->getDims();
|
||||
const auto &output_Shape = op->getOutput()->getDims();
|
||||
|
||||
int nDims = op->getInputs(0)->getDims().size();
|
||||
IT_ASSERT(nDims <= SMALL_ARRAY_SIZE);
|
||||
|
||||
SmallArray inputxShape, inputyShape, conditionShape, outputShape;
|
||||
for (int i = 0; i < nDims; ++i) {
|
||||
inputxShape.data[i] = inputx_Shape[i];
|
||||
inputyShape.data[i] = inputy_Shape[i];
|
||||
conditionShape.data[i] = condition_Shape[i];
|
||||
outputShape.data[i] = output_Shape[i];
|
||||
}
|
||||
where_kernel((float *)inputxData, (float *)inputyData,
|
||||
(float *)conditionData, (float *)outputData, nDims,
|
||||
inputxShape, inputyShape, conditionShape, outputShape);
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::Where, DataType::Float32, WhereCuda,
|
||||
"Where_CUDA_Float32");
|
||||
|
||||
}; // namespace infini
|
|
@ -0,0 +1,82 @@
|
|||
#include "cuda/cuda_common.h"
|
||||
#include "utils/small_array.h"
|
||||
|
||||
__global__ void _where_kernel(const float *inputx, const float *inputy,
|
||||
const float *condition, float *output, int nDims,
|
||||
int outputsize, infini::SmallArray inputxShape,
|
||||
infini::SmallArray inputyShape,
|
||||
infini::SmallArray conditionShape,
|
||||
infini::SmallArray outputShape) {
|
||||
|
||||
int outputIdx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (outputIdx < outputsize) {
|
||||
int inputxIdx = 0;
|
||||
int temp_inputx = 1;
|
||||
|
||||
int inputyIdx = 0;
|
||||
int temp_inputy = 1;
|
||||
|
||||
int conditionIdx = 0;
|
||||
int temp_condition = 1;
|
||||
|
||||
int tmp = 1; // stored s,k,j,i in order
|
||||
int v = outputIdx; // v = i(JKS) + j(KS) + k(S) + s
|
||||
for (int i = nDims - 1; i >= 0; --i) {
|
||||
if (i == 0) {
|
||||
tmp = v; // i = outputIdx/(JKS)
|
||||
} else {
|
||||
tmp = v % outputShape.data[i]; // store s,k,j in order
|
||||
}
|
||||
if (inputxShape.data[i] == 1) {
|
||||
inputxIdx += 0;
|
||||
} else {
|
||||
inputxIdx +=
|
||||
tmp *
|
||||
temp_inputx; // otherwise +i(JKS) or j(KS) or k(S) or s
|
||||
}
|
||||
temp_inputx *= inputxShape.data[i];
|
||||
//----------------------------
|
||||
if (inputyShape.data[i] == 1) {
|
||||
inputyIdx += 0;
|
||||
} else {
|
||||
inputyIdx +=
|
||||
tmp *
|
||||
temp_inputy; // otherwise +i(JKS) or j(KS) or k(S) or s
|
||||
}
|
||||
temp_inputy *= inputyShape.data[i];
|
||||
//--------------------------
|
||||
if (conditionShape.data[i] == 1) {
|
||||
conditionIdx += 0;
|
||||
} else {
|
||||
conditionIdx +=
|
||||
tmp *
|
||||
temp_condition; // otherwise +i(JKS) or j(KS) or k(S) or s
|
||||
}
|
||||
temp_condition *= conditionShape.data[i];
|
||||
//-------------------------
|
||||
v = v / outputShape.data[i];
|
||||
}
|
||||
output[outputIdx] =
|
||||
condition[conditionIdx] ? inputx[inputxIdx] : inputy[inputyIdx];
|
||||
}
|
||||
}
|
||||
|
||||
namespace infini {
|
||||
void where_kernel(const float *inputx, const float *inputy,
|
||||
const float *condition, float *output, int nDims,
|
||||
infini::SmallArray inputxShape,
|
||||
infini::SmallArray inputyShape,
|
||||
infini::SmallArray conditionShape,
|
||||
infini::SmallArray outputShape) {
|
||||
int outputsize = 1;
|
||||
|
||||
for (int i = 0; i < nDims; i++) {
|
||||
outputsize *= outputShape.data[i];
|
||||
}
|
||||
int blocksize = 32 * 16;
|
||||
int gridsize = (outputsize + blocksize - 1) / blocksize;
|
||||
_where_kernel<<<gridsize, blocksize>>>(
|
||||
inputx, inputy, condition, output, nDims, outputsize, inputxShape,
|
||||
inputyShape, conditionShape, outputShape);
|
||||
}
|
||||
} // namespace infini
|
|
@ -0,0 +1,41 @@
|
|||
#include "operators/expand.h"
|
||||
#include "utils/operator_utils.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
ExpandObj::ExpandObj(GraphObj *graph, Tensor input, Tensor output, Shape dims)
|
||||
: OperatorObj(OpType::Expand, {input}, {output}), dims(std::move(dims)) {
|
||||
IT_ASSERT(checkValid(graph));
|
||||
}
|
||||
|
||||
optional<vector<Shape>> ExpandObj::inferShape(const TensorVec &inputs) const {
|
||||
auto shape_input = inputs[0]->getDims();
|
||||
Shape ret = infer_broadcast(shape_input, dims);
|
||||
return {{ret}};
|
||||
}
|
||||
|
||||
std::string ExpandObj::toString() const {
|
||||
std::ostringstream os;
|
||||
os << "Expand[" << getGuid() << "]";
|
||||
os << "(";
|
||||
os << vecToString(inputs[0]->getDims()) << ",";
|
||||
os << "dims=" << vecToString(dims) << ",";
|
||||
os << "input=" << inputs[0]->getGuid() << ",";
|
||||
os << "output=" << outputs[0]->getGuid() << ")";
|
||||
return os.str();
|
||||
}
|
||||
|
||||
vector<int> ExpandObj::getWorkloadVector() const {
|
||||
vector<int> ret = inputs[0]->getDims();
|
||||
ret.insert(ret.end(), dims.begin(), dims.end());
|
||||
ret.emplace(ret.begin(), type.underlying());
|
||||
return ret;
|
||||
}
|
||||
|
||||
vector<int> ExpandObj::getOpAttrVector() const {
|
||||
vector<int> ret = dims;
|
||||
ret.emplace(ret.begin(), type.underlying());
|
||||
return ret;
|
||||
}
|
||||
|
||||
} // namespace infini
|
|
@ -24,8 +24,8 @@ optional<vector<Shape>> GatherObj::inferShape(const TensorVec &inputs) const {
|
|||
|
||||
vector<DataType> GatherObj::inferDataType(const TensorVec &inputs) const {
|
||||
IT_ASSERT(inputs.size() == 2);
|
||||
auto index = inputs[1];
|
||||
IT_ASSERT(index->getDType() == DataType::Int32);
|
||||
auto index_dtype = inputs[1]->getDType();
|
||||
IT_ASSERT(index_dtype == DataType::Int32 || index_dtype == DataType::Int64)
|
||||
return {inputs[0]->getDType()};
|
||||
}
|
||||
|
||||
|
@ -36,12 +36,12 @@ bool GatherObj::CheckIndexValid() const {
|
|||
return true;
|
||||
|
||||
Runtime runtime = NativeCpuRuntimeObj::getInstance();
|
||||
bool ret = true;
|
||||
auto value = inputs[0]->getDims()[axis];
|
||||
if (index->getDType() == DataType::Int32) {
|
||||
int *data = (int *)runtime->alloc(index->getBytes());
|
||||
index->getRuntime()->copyBlobToCPU(
|
||||
(void *)data, index->getRawDataPtr<void *>(), index->getBytes());
|
||||
|
||||
bool ret = true;
|
||||
auto value = inputs[0]->getDims()[axis];
|
||||
for (size_t i = 0; i < index->size(); ++i) {
|
||||
if (data[i] < 0 || data[i] >= value) {
|
||||
ret = false;
|
||||
|
@ -49,6 +49,18 @@ bool GatherObj::CheckIndexValid() const {
|
|||
}
|
||||
}
|
||||
runtime->dealloc(data);
|
||||
} else {
|
||||
int64_t *data = (int64_t *)runtime->alloc(index->getBytes());
|
||||
index->getRuntime()->copyBlobToCPU(
|
||||
(void *)data, index->getRawDataPtr<void *>(), index->getBytes());
|
||||
for (size_t i = 0; i < index->size(); ++i) {
|
||||
if (data[i] < 0 || data[i] >= value) {
|
||||
ret = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
runtime->dealloc(data);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
|
|
|
@ -20,7 +20,7 @@ MatmulObj::MatmulObj(GraphObj *graph, Tensor A, Tensor B, Tensor C, bool transA,
|
|||
if (ret.empty()) {
|
||||
b = 1;
|
||||
} else {
|
||||
b = std::accumulate(ret.begin(), ret.end(), 1);
|
||||
b = std::accumulate(ret.begin(), ret.end(), 1, std::multiplies<int>());
|
||||
}
|
||||
auto kA = *(transA ? shape_a.rbegin() + 1 : shape_a.rbegin());
|
||||
auto kB = *(transB ? shape_b.rbegin() : shape_b.rbegin() + 1);
|
||||
|
|
|
@ -0,0 +1,42 @@
|
|||
#include "operators/where.h"
|
||||
#include "utils/operator_utils.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
WhereObj::WhereObj(GraphObj *graph, Tensor inputX, Tensor inputY,
|
||||
Tensor condition, Tensor output)
|
||||
: OperatorObj(OpType::Where, TensorVec{inputX, inputY, condition},
|
||||
{output}) {
|
||||
IT_ASSERT(checkValid(graph));
|
||||
}
|
||||
|
||||
optional<vector<Shape>> WhereObj::inferShape(const TensorVec &inputs) const {
|
||||
auto shapeX = inputs[0]->getDims();
|
||||
auto shapeY = inputs[1]->getDims();
|
||||
auto shapeCon = inputs[2]->getDims();
|
||||
auto retXY = infer_broadcast(shapeX, shapeY);
|
||||
auto ret = infer_broadcast(retXY, shapeCon);
|
||||
return {{ret}};
|
||||
}
|
||||
|
||||
std::string WhereObj::toString() const {
|
||||
std::ostringstream os;
|
||||
os << "Where[" << getGuid() << "]";
|
||||
os << "(";
|
||||
os << vecToString(inputs[2]->getDims()) << ",";
|
||||
os << "inputX=" << inputs[0]->getGuid() << ",";
|
||||
os << "inputY=" << inputs[1]->getGuid() << ",";
|
||||
os << "condition=" << inputs[2]->getGuid() << ",";
|
||||
os << "output=" << outputs[0]->getGuid() << ")";
|
||||
return os.str();
|
||||
}
|
||||
|
||||
vector<int> WhereObj::getWorkloadVector() const {
|
||||
vector<int> ret = getOutput()->getDims();
|
||||
ret.emplace(ret.begin(), type.underlying());
|
||||
return ret;
|
||||
}
|
||||
|
||||
vector<int> WhereObj::getOpAttrVector() const { return {type.underlying()}; }
|
||||
|
||||
} // namespace infini
|
|
@ -0,0 +1,41 @@
|
|||
#include "core/graph.h"
|
||||
#include "core/runtime.h"
|
||||
#include "cuda/cuda_runtime.h"
|
||||
#include "cuda/cuda_utility.h"
|
||||
#include "operators/expand.h"
|
||||
|
||||
#include "test.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
TEST(Expand, Cuda) {
|
||||
Runtime runtime = NativeCpuRuntimeObj::getInstance();
|
||||
Graph gCpu = make_ref<GraphObj>(runtime);
|
||||
|
||||
auto t1 = gCpu->addTensor({2, 1, 2, 1}, DataType::Float32);
|
||||
|
||||
gCpu->dataMalloc();
|
||||
t1->setData(IncrementalGenerator());
|
||||
t1->printData();
|
||||
|
||||
auto cudaRuntime = make_ref<CudaRuntimeObj>();
|
||||
Graph gCuda = make_ref<GraphObj>(cudaRuntime);
|
||||
|
||||
auto t1Gpu = gCuda->cloneTensor(t1);
|
||||
|
||||
auto op = gCuda->addOp<ExpandObj>(t1Gpu, nullptr, Shape{2, 2, 2, 3});
|
||||
gCuda->dataMalloc();
|
||||
t1Gpu->setData(IncrementalGenerator());
|
||||
|
||||
cudaRuntime->run(gCuda);
|
||||
|
||||
// cudaPrintTensor(op->getOutput());
|
||||
// copy output from CUDA to CPU
|
||||
auto oCpu = gCpu->cloneTensor(op->getOutput());
|
||||
oCpu->printData();
|
||||
EXPECT_TRUE(
|
||||
oCpu->equalData(vector<float>{0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1,
|
||||
2, 2, 2, 3, 3, 3, 2, 2, 2, 3, 3, 3}));
|
||||
}
|
||||
|
||||
} // namespace infini
|
|
@ -77,7 +77,7 @@ int gatheredOffset2Offset(int gOffset, GatherMetaData metaData) {
|
|||
idxOffset += p * metaData.idxStride[j];
|
||||
}
|
||||
|
||||
idx = metaData.indexValue[idxOffset];
|
||||
idx = static_cast<int *>(metaData.indexValue)[idxOffset];
|
||||
k = k - metaData.idxNDim;
|
||||
|
||||
} else {
|
||||
|
@ -242,6 +242,31 @@ TEST(Gather, Cuda) {
|
|||
indexCuda->copyin(vector<int>{0, 3, 1});
|
||||
cudaRuntime->run(gCuda);
|
||||
|
||||
// cudaPrintTensor(op->getOutput());
|
||||
// copy output from CUDA to CPU
|
||||
auto oCpu = gCpu->cloneTensor(op->getOutput());
|
||||
EXPECT_TRUE(oCpu->equalData(
|
||||
vector<float>{0, 1, 6, 7, 2, 3, 8, 9, 14, 15, 10, 11}));
|
||||
}
|
||||
{
|
||||
Runtime runtime = NativeCpuRuntimeObj::getInstance();
|
||||
Graph gCpu = make_ref<GraphObj>(runtime);
|
||||
auto input = gCpu->addTensor({2, 4, 2}, DataType::Float32);
|
||||
auto index = gCpu->addTensor({3, 1}, DataType::Int64);
|
||||
gCpu->dataMalloc();
|
||||
input->setData(IncrementalGenerator());
|
||||
index->copyin(vector<int64_t>{0, 3, 1});
|
||||
auto cudaRuntime = make_ref<CudaRuntimeObj>();
|
||||
Graph gCuda = make_ref<GraphObj>(cudaRuntime);
|
||||
|
||||
auto inputCuda = gCuda->cloneTensor(input);
|
||||
auto indexCuda = gCuda->cloneTensor(index);
|
||||
auto op = gCuda->addOp<GatherObj>(inputCuda, indexCuda, nullptr, 1);
|
||||
gCuda->dataMalloc();
|
||||
inputCuda->setData(IncrementalGenerator());
|
||||
indexCuda->copyin(vector<int64_t>{0, 3, 1});
|
||||
cudaRuntime->run(gCuda);
|
||||
|
||||
// cudaPrintTensor(op->getOutput());
|
||||
// copy output from CUDA to CPU
|
||||
auto oCpu = gCpu->cloneTensor(op->getOutput());
|
||||
|
|
|
@ -0,0 +1,63 @@
|
|||
#include "core/graph.h"
|
||||
#include "core/runtime.h"
|
||||
#include "cuda/cuda_runtime.h"
|
||||
#include "cuda/cuda_utility.h"
|
||||
#include "operators/where.h"
|
||||
|
||||
#include "test.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
void test_where(const Shape &inputxshape, const vector<float> &inputxdata,
|
||||
const Shape &inputyshape, const vector<float> &inputydata,
|
||||
const Shape &conditionshape, const vector<int> &conditiondata,
|
||||
const vector<float> &ExpectData) {
|
||||
Runtime runtime = NativeCpuRuntimeObj::getInstance();
|
||||
Graph gCpu = make_ref<GraphObj>(runtime);
|
||||
auto condition = gCpu->addTensor(conditionshape, DataType::Int32);
|
||||
auto inputx = gCpu->addTensor(inputxshape, DataType::Float32);
|
||||
auto inputy = gCpu->addTensor(inputyshape, DataType::Float32);
|
||||
|
||||
gCpu->dataMalloc();
|
||||
condition->copyin(conditiondata); //
|
||||
inputx->copyin(inputxdata);
|
||||
inputy->copyin(inputydata); //
|
||||
|
||||
auto cudaRuntime = make_ref<CudaRuntimeObj>();
|
||||
Graph gCuda = make_ref<GraphObj>(cudaRuntime);
|
||||
|
||||
auto conditionGpu = gCuda->cloneTensor(condition);
|
||||
auto inputxGpu = gCuda->cloneTensor(inputx);
|
||||
auto inputyGpu = gCuda->cloneTensor(inputy);
|
||||
|
||||
auto op = gCuda->addOp<WhereObj>(inputxGpu, inputyGpu, conditionGpu,
|
||||
nullptr); // WhereObj
|
||||
gCuda->dataMalloc();
|
||||
conditionGpu->copyin(conditiondata);
|
||||
inputxGpu->copyin(inputxdata);
|
||||
inputyGpu->copyin(inputydata);
|
||||
cudaRuntime->run(gCuda);
|
||||
|
||||
auto oCpu = gCpu->cloneTensor(op->getOutput()); // move data from gpu to cpu
|
||||
oCpu->printData(); //->printData
|
||||
EXPECT_TRUE(oCpu->equalData(ExpectData));
|
||||
}
|
||||
|
||||
TEST(CUDA_Where, run) {
|
||||
test_where(
|
||||
Shape{2, 2, 3, 1}, vector<float>{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11},
|
||||
Shape{2, 2, 3, 1}, vector<float>{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
|
||||
Shape{2, 2, 3, 1}, vector<int>{0, 1, 1, 0, 0, 0, 1, 1, 0, 1, 1, 1},
|
||||
vector<float>{0., 1., 2., 0., 0., 0., 6., 7., 0., 9., 10., 11.});
|
||||
|
||||
test_where(Shape{2, 1, 1, 3}, // inputx
|
||||
vector<float>{0, 1, 2, 3, 4, 5}, Shape{1, 2, 1, 1}, // inputy
|
||||
vector<float>{1, 1}, Shape{2, 1, 3, 1}, // condition
|
||||
vector<int>{0, 1, 1, 0, 0, 0},
|
||||
vector<float>{1., 1., 1., 0., 1., 2., 0., 1., 2., 1., 1., 1.,
|
||||
0., 1., 2., 0., 1., 2., 1., 1., 1., 1., 1., 1.,
|
||||
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.});
|
||||
|
||||
} // python output
|
||||
|
||||
} // namespace infini
|
|
@ -0,0 +1,26 @@
|
|||
#include "core/graph.h"
|
||||
#include "core/kernel.h"
|
||||
#include "core/runtime.h"
|
||||
#include "operators/expand.h"
|
||||
|
||||
#include "test.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
TEST(Expand, ShapeInference) {
|
||||
Runtime runtime = NativeCpuRuntimeObj::getInstance();
|
||||
{
|
||||
Graph g = make_ref<GraphObj>(runtime);
|
||||
Tensor i = g->addTensor({3, 1}, DataType::Float32);
|
||||
auto op = g->addOp<ExpandObj>(i, nullptr, Shape{2, 1, 6});
|
||||
EXPECT_EQ(op->getOutput()->getDims(), (Shape{2, 3, 6}));
|
||||
}
|
||||
{
|
||||
Graph g = make_ref<GraphObj>(runtime);
|
||||
Tensor i = g->addTensor({3, 1}, DataType::Float32);
|
||||
auto op = g->addOp<ExpandObj>(i, nullptr, Shape{3, 4});
|
||||
EXPECT_EQ(op->getOutput()->getDims(), (Shape{3, 4}));
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace infini
|
|
@ -9,11 +9,19 @@ namespace infini {
|
|||
|
||||
TEST(Gather, ShapeInference) {
|
||||
Runtime runtime = NativeCpuRuntimeObj::getInstance();
|
||||
|
||||
{
|
||||
Graph g = make_ref<GraphObj>(runtime);
|
||||
Tensor i = g->addTensor({1, 3, 4, 4}, DataType::Int32);
|
||||
Tensor index = g->addTensor({2, 1, 2}, DataType::Int32);
|
||||
auto op = g->addOp<GatherObj>(i, index, nullptr, 1);
|
||||
EXPECT_EQ(op->getOutput()->getDims(), (Shape{1, 2, 1, 2, 4, 4}));
|
||||
}
|
||||
{
|
||||
Graph g = make_ref<GraphObj>(runtime);
|
||||
Tensor i = g->addTensor({1, 3, 4, 4}, DataType::Int32);
|
||||
Tensor index = g->addTensor({2, 1, 2}, DataType::Int64);
|
||||
auto op = g->addOp<GatherObj>(i, index, nullptr, 1);
|
||||
EXPECT_EQ(op->getOutput()->getDims(), (Shape{1, 2, 1, 2, 4, 4}));
|
||||
}
|
||||
}
|
||||
} // namespace infini
|
||||
|
|
|
@ -21,6 +21,12 @@ TEST(ReduceMean, ShapeInference) {
|
|||
auto op = g->addOp<ReduceMeanObj>(i, nullptr, vector<int>{1, 3}, true);
|
||||
EXPECT_EQ(op->getOutput()->getDims(), (Shape{2, 1, 3, 1}));
|
||||
}
|
||||
{
|
||||
Graph g = make_ref<GraphObj>(runtime);
|
||||
Tensor i = g->addTensor({2, 3, 3, 4}, DataType::Float32);
|
||||
auto op = g->addOp<ReduceMeanObj>(i, nullptr, vector<int>{-3, 3}, true);
|
||||
EXPECT_EQ(op->getOutput()->getDims(), (Shape{2, 1, 3, 1}));
|
||||
}
|
||||
{
|
||||
Graph g = make_ref<GraphObj>(runtime);
|
||||
Tensor i = g->addTensor({2, 3, 3, 4}, DataType::Float32);
|
||||
|
@ -33,6 +39,13 @@ TEST(ReduceMean, ShapeInference) {
|
|||
auto op = g->addOp<ReduceMeanObj>(i, nullptr, vector<int>{1, 3}, false);
|
||||
EXPECT_EQ(op->getOutput()->getDims(), (Shape{2, 3}));
|
||||
}
|
||||
{
|
||||
Graph g = make_ref<GraphObj>(runtime);
|
||||
Tensor i = g->addTensor({2, 3, 3, 4}, DataType::Float32);
|
||||
auto op =
|
||||
g->addOp<ReduceMeanObj>(i, nullptr, vector<int>{-3, 3}, false);
|
||||
EXPECT_EQ(op->getOutput()->getDims(), (Shape{2, 3}));
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace infini
|
||||
|
|
|
@ -0,0 +1,46 @@
|
|||
#include "core/graph.h"
|
||||
#include "core/kernel.h"
|
||||
#include "core/runtime.h"
|
||||
#include "operators/where.h"
|
||||
|
||||
#include "test.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
TEST(Where, ShapeInference) {
|
||||
Runtime runtime = NativeCpuRuntimeObj::getInstance();
|
||||
{
|
||||
Graph g = make_ref<GraphObj>(runtime);
|
||||
Tensor x = g->addTensor({2, 2}, DataType::Float32);
|
||||
Tensor y = g->addTensor({2, 2}, DataType::Float32);
|
||||
Tensor con = g->addTensor({2, 2}, DataType::Bool);
|
||||
auto op = g->addOp<WhereObj>(x, y, con, nullptr);
|
||||
EXPECT_EQ(op->getOutput()->getDims(), (Shape{2, 2}));
|
||||
}
|
||||
{
|
||||
Graph g = make_ref<GraphObj>(runtime);
|
||||
Tensor x = g->addTensor({1, 12, 224, 224}, DataType::Float32);
|
||||
Tensor y = g->addTensor({1, 1, 224, 224}, DataType::Float32);
|
||||
Tensor con = g->addTensor({1, 224, 1}, DataType::Bool);
|
||||
auto op = g->addOp<WhereObj>(x, y, con, nullptr);
|
||||
EXPECT_EQ(op->getOutput()->getDims(), (Shape{1, 12, 224, 224}));
|
||||
}
|
||||
{
|
||||
Graph g = make_ref<GraphObj>(runtime);
|
||||
Tensor x = g->addTensor({12, 224, 224}, DataType::Float32);
|
||||
Tensor y = g->addTensor({1, 1, 224, 224}, DataType::Float32);
|
||||
Tensor con = g->addTensor({1, 224}, DataType::Bool);
|
||||
auto op = g->addOp<WhereObj>(x, y, con, nullptr);
|
||||
EXPECT_EQ(op->getOutput()->getDims(), (Shape{1, 12, 224, 224}));
|
||||
}
|
||||
{
|
||||
Graph g = make_ref<GraphObj>(runtime);
|
||||
Tensor x = g->addTensor({12, 224, 224}, DataType::Float32);
|
||||
Tensor y = g->addTensor({1, 1, 224, 224}, DataType::Float32);
|
||||
Tensor con = g->addTensor({2, 1, 1, 1, 224}, DataType::Bool);
|
||||
auto op = g->addOp<WhereObj>(x, y, con, nullptr);
|
||||
EXPECT_EQ(op->getOutput()->getDims(), (Shape{2, 1, 12, 224, 224}));
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace infini
|
Loading…
Reference in New Issue