diff --git a/include/core/graph.h b/include/core/graph.h index a8fc6485..184dcac6 100644 --- a/include/core/graph.h +++ b/include/core/graph.h @@ -53,6 +53,7 @@ class GraphObj : public Object { const TensorVec &getTensors() const { return tensors; } const OpVec &getOperators() const { return ops; } OpVec getComputeOps() const; + Tensor getTensor(int) const; /** * Sort the nodes in topological order. @@ -64,7 +65,13 @@ class GraphObj : public Object { void optimize(); - void dataMalloc(bool useNaiveAllocator = false); + void shape_infer(); + + void dataMalloc(bool useNaiveAllocator = false, size_t memPoolSize = 0); + + Tensor cloneKV(Tensor &tensor); + + void freeHeap(); /** * @brief Add an operator and create its outputs. Output tensor arguments diff --git a/include/core/graph_handler.h b/include/core/graph_handler.h index c91c4901..4b66f11a 100644 --- a/include/core/graph_handler.h +++ b/include/core/graph_handler.h @@ -81,6 +81,7 @@ class GraphHandlerObj { Tensor cast(Tensor input, Tensor output, int to); Tensor expand(Tensor input, Tensor output, Shape dims); Tensor where(Tensor inputX, Tensor inputY, Tensor condition, Tensor output); + std::vector getDims(Tensor x) { return x->getDims(); } Tensor allReduceSum(Tensor input, Tensor output); Tensor allReduceProd(Tensor input, Tensor output); @@ -98,9 +99,19 @@ class GraphHandlerObj { inline void optimize() { g->optimize(); } + inline void shape_infer() { g->shape_infer(); } + + void change_shape(const vector &shape, int tensorId); //------ runtime - inline void data_malloc() { g->dataMalloc(); } + inline void data_malloc(bool useNaiveAllocator = false, + size_t memPoolSize = 0) { + g->dataMalloc(useNaiveAllocator, memPoolSize); + } + + inline Tensor clone_KV(Tensor &tensor) { return g->cloneKV(tensor); } + + inline void free_heap() { g->freeHeap(); } inline void tune() { g->getRuntime()->run(g, true); } diff --git a/include/core/lazy_allocator.h b/include/core/lazy_allocator.h index 5f073845..f4147851 100644 --- a/include/core/lazy_allocator.h +++ b/include/core/lazy_allocator.h @@ -26,14 +26,23 @@ class LazyAllocator { size_t weightPeak = 0; + size_t heapPeak = 0; + size_t alignment; + bool hasMemPool = false; + + size_t memPoolSize = 0; + // pointer to the memory actually allocated void *ptr = nullptr; // pointer to the weight memory space void *weightPtr = nullptr; + // memory pool ptr + void *memPoolPtr = nullptr; + // // a cache designed for a batch size that has already occurred // std::unordered_map> // batchsizeToTensorOffset; @@ -68,6 +77,10 @@ class LazyAllocator { void init(); + void setMemPool(size_t memPoolSize); + + bool getMemPoolStatus(); + // function: simulate memory allocation // arguments: // size: size of memory block to be allocated @@ -76,6 +89,10 @@ class LazyAllocator { size_t allocWeight(size_t size); + size_t heapAlloc(size_t size); + + void freeHeap(); + // function: simulate memory free // arguments: // addr: head address offset of memory block to be free @@ -92,6 +109,8 @@ class LazyAllocator { void *getWeightPtr(); + void *getHeapPtr(); + void info(); private: diff --git a/include/core/operator.h b/include/core/operator.h index d7a57633..cc8ce174 100644 --- a/include/core/operator.h +++ b/include/core/operator.h @@ -55,8 +55,7 @@ class OperatorObj : public Object { public: OperatorObj(OpType opType, TensorVec inputs, TensorVec outputs); - virtual optional> - inferShape(const TensorVec &inputs) const = 0; + virtual optional> inferShape(const TensorVec &inputs) = 0; virtual vector inferDataType(const TensorVec &inputs) const; /** * @brief Constructs outputs (if requried) and check whether the operator is @@ -105,7 +104,7 @@ class OperatorObj : public Object { const TensorVec &newOutputs) const = 0; protected: - optional> inferShape() const; + optional> inferShape(); vector inferDataType() const; private: diff --git a/include/core/tensor.h b/include/core/tensor.h index 48590fd6..cb09261a 100644 --- a/include/core/tensor.h +++ b/include/core/tensor.h @@ -31,6 +31,7 @@ class TensorObj : public TensorBaseObj { size_t getBytes() const { return _size * dtype.getSize(); } Shape getDims() const { return shape; } + void setShape(Shape shape_); size_t getRank() const { return shape.size(); } Shape getStride() const; size_t getOffset(const vector &ds) const; diff --git a/include/cuda/cuda_element_wise.h b/include/cuda/cuda_element_wise.h index eb3b99a2..db9c16f1 100644 --- a/include/cuda/cuda_element_wise.h +++ b/include/cuda/cuda_element_wise.h @@ -1,8 +1,13 @@ #pragma once namespace infini { -void div_kernel(float *a, float *b, float *c, int a0, int a1, int a2, int a3, +void div_kernel(void *a, void *b, void *c, int a0, int a1, int a2, int a3, int b0, int b1, int b2, int b3, int c0, int c1, int c2, int c3); -void pow_kernel(float *a, float *b, float *c, int a0, int a1, int a2, int a3, +void add_kernel(void *a, void *b, void *c, int a0, int a1, int a2, int a3, int b0, int b1, int b2, int b3, int c0, int c1, int c2, int c3); +void pow_kernel(void *a, void *b, void *c, int a0, int a1, int a2, int a3, + int b0, int b1, int b2, int b3, int c0, int c1, int c2, int c3); +void less_kernel(void *a, void *b, void *c, int a0, int a1, int a2, int a3, + int b0, int b1, int b2, int b3, int c0, int c1, int c2, + int c3); }; // namespace infini diff --git a/include/cuda/cuda_pad_slice.h b/include/cuda/cuda_pad_slice.h index db032fa0..9c044145 100644 --- a/include/cuda/cuda_pad_slice.h +++ b/include/cuda/cuda_pad_slice.h @@ -10,10 +10,11 @@ typedef struct { int wholeNDim[MAX_DIM]; // dim size after padding or before slicing int partNDim[MAX_DIM]; // dim size before padding or after slicing int partStride[MAX_DIM]; // stride before padding or after slicing + int DType; } TransMetaData; namespace infini { -void pad_slice_kernel(float *partData, float *wholeData, +void pad_slice_kernel(void *partData, void *wholeData, const TransMetaData &metadata, int nDims, int num, bool isPad); } // namespace infini diff --git a/include/operators/G2BMM.h b/include/operators/G2BMM.h index 52f2a2c8..f1a48383 100644 --- a/include/operators/G2BMM.h +++ b/include/operators/G2BMM.h @@ -35,7 +35,7 @@ class G2BMMObj : public OperatorObj { OP_CLONE(G2BMMObj); std::string toString() const override; - optional> inferShape(const TensorVec &inputs) const override; + optional> inferShape(const TensorVec &inputs) override; int numInputs() const override { return 2; } int numOutputs() const override { return 1; } diff --git a/include/operators/GBMM.h b/include/operators/GBMM.h index ebfed659..1329996d 100644 --- a/include/operators/GBMM.h +++ b/include/operators/GBMM.h @@ -33,7 +33,7 @@ class GBMMObj : public OperatorObj { OP_CLONE(GBMMObj); std::string toString() const override; - optional> inferShape(const TensorVec &inputs) const override; + optional> inferShape(const TensorVec &inputs) override; int numInputs() const override { return 2; } int numOutputs() const override { return 1; } diff --git a/include/operators/activation_backward.h b/include/operators/activation_backward.h index 5f55d8cc..ae050733 100644 --- a/include/operators/activation_backward.h +++ b/include/operators/activation_backward.h @@ -7,7 +7,7 @@ class ActivationBackwardObj : public OperatorObj { ActivationBackwardObj(OpType type, GraphObj *graph, Tensor y, Tensor diff_y, Tensor x, Tensor diff_x); OP_CLONE(ActivationBackwardObj); - optional> inferShape(const TensorVec &inputs) const override; + optional> inferShape(const TensorVec &inputs) override; std::string toString() const override; int numInputs() const override { return 3; } diff --git a/include/operators/all_gather.h b/include/operators/all_gather.h index 423974f6..c38d0a3e 100644 --- a/include/operators/all_gather.h +++ b/include/operators/all_gather.h @@ -27,7 +27,7 @@ class AllGatherObj : public OperatorObj { int numInputs() const override { return 1; } int numOutputs() const override { return world_size; } - optional> inferShape(const TensorVec &inputs) const override; + optional> inferShape(const TensorVec &inputs) override; std::string toString() const override; diff --git a/include/operators/all_reduce.h b/include/operators/all_reduce.h index f91b3ad1..08635d71 100644 --- a/include/operators/all_reduce.h +++ b/include/operators/all_reduce.h @@ -33,7 +33,7 @@ class AllReduceBaseObj : public OperatorObj { int numInputs() const override { return 1; } int numOutputs() const override { return 1; } - optional> inferShape(const TensorVec &inputs) const override { + optional> inferShape(const TensorVec &inputs) override { return {{inputs[0]->getDims()}}; }; diff --git a/include/operators/attention_kvcache.h b/include/operators/attention_kvcache.h index f319eb6c..0472b222 100644 --- a/include/operators/attention_kvcache.h +++ b/include/operators/attention_kvcache.h @@ -29,7 +29,7 @@ class AttentionKVCacheObj : public OperatorObj { Tensor output_matmul); OP_CLONE(AttentionKVCacheObj); - optional> inferShape(const TensorVec &inputs) const override; + optional> inferShape(const TensorVec &inputs) override; std::string toString() const override; int numInputs() const override { return 6; } diff --git a/include/operators/batch_norm.h b/include/operators/batch_norm.h index cfacf2ca..ce7314aa 100644 --- a/include/operators/batch_norm.h +++ b/include/operators/batch_norm.h @@ -34,7 +34,7 @@ class BatchNormObj : public OperatorObj { Tensor var, Tensor scale, Tensor bias, float momentum = 0.9, float eps = 1e-5, bool trainingMode = false); OP_CLONE(BatchNormObj); - optional> inferShape(const TensorVec &inputs) const override; + optional> inferShape(const TensorVec &inputs) override; std::string toString() const override; // output size will be 3 when training diff --git a/include/operators/broadcast.h b/include/operators/broadcast.h index 1a15b770..551fd8ce 100644 --- a/include/operators/broadcast.h +++ b/include/operators/broadcast.h @@ -26,7 +26,7 @@ class BroadcastObj : public OperatorObj { int numInputs() const override { return 1; } int numOutputs() const override { return 1; } - optional> inferShape(const TensorVec &inputs) const override { + optional> inferShape(const TensorVec &inputs) override { return {{inputs[0]->getDims()}}; }; diff --git a/include/operators/concat.h b/include/operators/concat.h index c3d9c4f3..2d130112 100644 --- a/include/operators/concat.h +++ b/include/operators/concat.h @@ -22,7 +22,7 @@ class ConcatObj : public OperatorObj { ConcatObj(GraphObj *graph, TensorVec inputs, Tensor output, int dim); OP_CLONE(ConcatObj); - optional> inferShape(const TensorVec &inputs) const override; + optional> inferShape(const TensorVec &inputs) override; std::string toString() const override; int numInputs() const override { return inputs.size(); } diff --git a/include/operators/conv.h b/include/operators/conv.h index 449f4334..00420c84 100644 --- a/include/operators/conv.h +++ b/include/operators/conv.h @@ -142,7 +142,7 @@ class ConvObj : public ConvBaseObj { ActType act = ActType::None); OP_CLONE(ConvObj); - optional> inferShape(const TensorVec &inputs) const override; + optional> inferShape(const TensorVec &inputs) override; int getNumGroups() const override { return c / getChannelPerGroup(); } private: @@ -164,7 +164,7 @@ class ConvBackwardFilterObj : public ConvBaseObj { int sh = 1, int sw = 1, int dh = 1, int dw = 1, Tensor bias = nullptr, ActType act = ActType::None); - optional> inferShape(const TensorVec &inputs) const override; + optional> inferShape(const TensorVec &inputs) override; ActType getAct() const { return act; } int getNumGroups() const override { return c / getChannelPerGroup(); } @@ -191,7 +191,7 @@ class ConvTransposed2dObj : public ConvBaseObj { Tensor bias = nullptr, ActType act = ActType::None); OP_CLONE(ConvTransposed2dObj); - optional> inferShape(const TensorVec &inputs) const override; + optional> inferShape(const TensorVec &inputs) override; int getNumGroups() const override { return group; } std::pair getOutputPadding() const { return {oph, opw}; } @@ -218,7 +218,7 @@ class ConvTransposed2dNHWCObj : public ConvBaseObj { Tensor bias = nullptr, ActType act = ActType::None); OP_CLONE(ConvTransposed2dNHWCObj); - optional> inferShape(const TensorVec &inputs) const override; + optional> inferShape(const TensorVec &inputs) override; int getNumGroups() const override { return group; } private: diff --git a/include/operators/det.h b/include/operators/det.h index d5e887c1..8a64a279 100644 --- a/include/operators/det.h +++ b/include/operators/det.h @@ -7,7 +7,7 @@ class DetObj : public OperatorObj { enum Mode { NormalDet = 0, LogDet }; DetObj(GraphObj *graph, Tensor input, Tensor output, Mode mode); OP_CLONE(DetObj); - optional> inferShape(const TensorVec &inputs) const override; + optional> inferShape(const TensorVec &inputs) override; std::string toString() const override; int numInputs() const override { return 1; } diff --git a/include/operators/dropout.h b/include/operators/dropout.h index 8c4c7300..330c94b6 100644 --- a/include/operators/dropout.h +++ b/include/operators/dropout.h @@ -37,7 +37,7 @@ class DropoutObj : public OperatorObj { DropoutObj(GraphObj *graph, Tensor data, Tensor output, Tensor mask, float ratio, bool training_mode); OP_CLONE(DropoutObj); - optional> inferShape(const TensorVec &inputs) const override; + optional> inferShape(const TensorVec &inputs) override; std::string toString() const override; int numInputs() const override { return 1; } diff --git a/include/operators/element_wise.h b/include/operators/element_wise.h index e198de75..f0275add 100644 --- a/include/operators/element_wise.h +++ b/include/operators/element_wise.h @@ -21,7 +21,7 @@ class ElementWiseObj : public OperatorObj { */ ElementWiseObj(OpType type, GraphObj *graph, Tensor input0, Tensor input1, Tensor output); - optional> inferShape(const TensorVec &inputs) const override; + optional> inferShape(const TensorVec &inputs) override; std::string toString() const override; int numInputs() const override { return 2; } @@ -38,7 +38,7 @@ class MSELossObj : public OperatorObj { MSELossObj(GraphObj *graph, Tensor input0, Tensor input1, Reduction reduction, Tensor output); OP_CLONE(MSELossObj); - optional> inferShape(const TensorVec &inputs) const override; + optional> inferShape(const TensorVec &inputs) override; Reduction getReduction() const { return reductionMode; } std::string toString() const override; diff --git a/include/operators/expand.h b/include/operators/expand.h index 8a3558ca..5f82768d 100644 --- a/include/operators/expand.h +++ b/include/operators/expand.h @@ -21,7 +21,7 @@ class ExpandObj : public OperatorObj { */ ExpandObj(GraphObj *graph, Tensor input, Tensor output, Shape dims); OP_CLONE(ExpandObj); - optional> inferShape(const TensorVec &inputs) const override; + optional> inferShape(const TensorVec &inputs) override; std::string toString() const override; int numInputs() const override { return 1; } diff --git a/include/operators/extend.h b/include/operators/extend.h index f749793f..77ac2ff7 100644 --- a/include/operators/extend.h +++ b/include/operators/extend.h @@ -23,7 +23,7 @@ class ExtendObj : public OperatorObj { ExtendObj(GraphObj *graph, Tensor input, Tensor output, int dim, int num = 1); OP_CLONE(ExtendObj); - optional> inferShape(const TensorVec &inputs) const override; + optional> inferShape(const TensorVec &inputs) override; std::string toString() const override; int numInputs() const override { return 1; } diff --git a/include/operators/gather.h b/include/operators/gather.h index ff35aba8..b1390834 100644 --- a/include/operators/gather.h +++ b/include/operators/gather.h @@ -39,7 +39,7 @@ class GatherObj : public GatherBaseObj { int axis); OP_CLONE(GatherObj); std::string toString() const override; - optional> inferShape(const TensorVec &inputs) const override; + optional> inferShape(const TensorVec &inputs) override; vector inferDataType(const TensorVec &inputs) const override; private: @@ -69,7 +69,7 @@ class GatherElementsObj : public GatherBaseObj { Tensor output, int axis); OP_CLONE(GatherElementsObj); std::string toString() const override; - optional> inferShape(const TensorVec &inputs) const override; + optional> inferShape(const TensorVec &inputs) override; vector inferDataType(const TensorVec &inputs) const override; private: diff --git a/include/operators/matmul.h b/include/operators/matmul.h index 91a0a57c..35a4c0a8 100644 --- a/include/operators/matmul.h +++ b/include/operators/matmul.h @@ -45,7 +45,7 @@ class MatmulObj : public OperatorObj { OP_CLONE(MatmulObj); std::string toString() const override; - optional> inferShape(const TensorVec &inputs) const override; + optional> inferShape(const TensorVec &inputs) override; int numInputs() const override { return inputs.size(); } int numOutputs() const override { return 1; } diff --git a/include/operators/membound.h b/include/operators/membound.h index 4a444553..c9123b4f 100644 --- a/include/operators/membound.h +++ b/include/operators/membound.h @@ -21,7 +21,7 @@ class MemBoundObj : public OperatorObj { OP_CLONE(MemBoundObj); std::string toString() const override; - optional> inferShape(const TensorVec &inputs) const override; + optional> inferShape(const TensorVec &inputs) override; int numInputs() const override { return inputs.size(); } int numOutputs() const override { return outputs.size(); } diff --git a/include/operators/pad.h b/include/operators/pad.h index 7a25d8bd..3305e127 100644 --- a/include/operators/pad.h +++ b/include/operators/pad.h @@ -27,7 +27,7 @@ class PadObj : public OperatorObj { const vector &pads, const optional> &axes); OP_CLONE(PadObj); - optional> inferShape(const TensorVec &inputs) const override; + optional> inferShape(const TensorVec &inputs) override; std::string toString() const override; int numInputs() const override { return 1; } int numOutputs() const override { return 1; } diff --git a/include/operators/pooling.h b/include/operators/pooling.h index 7f28224d..31752dee 100644 --- a/include/operators/pooling.h +++ b/include/operators/pooling.h @@ -41,7 +41,7 @@ class PoolingObj : public OperatorObj { int ceilMode); OP_CLONE(PoolingObj); - optional> inferShape(const TensorVec &inputs) const override; + optional> inferShape(const TensorVec &inputs) override; std::string toString() const override; int numInputs() const override { return 1; } int numOutputs() const override { return 1; } diff --git a/include/operators/reduce_mean.h b/include/operators/reduce_mean.h index ef74cd2e..18ef38b1 100644 --- a/include/operators/reduce_mean.h +++ b/include/operators/reduce_mean.h @@ -23,7 +23,7 @@ class ReduceMeanObj : public OperatorObj { ReduceMeanObj(GraphObj *graph, Tensor input, Tensor output, const optional> &axes, bool keepDims = true); OP_CLONE(ReduceMeanObj); - optional> inferShape(const TensorVec &inputs) const override; + optional> inferShape(const TensorVec &inputs) override; std::string toString() const override; int numInputs() const override { return 1; } diff --git a/include/operators/reshape.h b/include/operators/reshape.h index 00ae5b0a..43244436 100644 --- a/include/operators/reshape.h +++ b/include/operators/reshape.h @@ -9,6 +9,7 @@ namespace infini { */ class ReshapeObj : public OperatorObj { Shape dims; + Shape outputShape; public: /** @@ -17,18 +18,20 @@ class ReshapeObj : public OperatorObj { * @param graph The computation graph that this operator belongs to. * @param input The input tensor. * @param output The output tensor. - * @param dims The shape of the output tensor. + * @param dims The shape to infer the output shape. + * @param outputShape The real shape of output tensor. */ ReshapeObj(GraphObj *graph, Tensor input, Tensor output, Shape dims); OP_CLONE(ReshapeObj); - optional> inferShape(const TensorVec &inputs) const override; + optional> inferShape(const TensorVec &inputs) override; std::string toString() const override; int numInputs() const override { return 1; } int numOutputs() const override { return 1; } - inline Shape getShape() const { return dims; } + inline Shape getShape() const { return outputShape; } + inline Shape getDims() const { return dims; } private: vector getWorkloadVector() const override; @@ -55,7 +58,7 @@ class FlattenObj : public OperatorObj { FlattenObj(GraphObj *graph, Tensor input, Tensor output, int axis); OP_CLONE(FlattenObj); - optional> inferShape(const TensorVec &inputs) const override; + optional> inferShape(const TensorVec &inputs) override; std::string toString() const override; int numInputs() const override { return 1; } @@ -85,7 +88,7 @@ class IdentityObj : public OperatorObj { IdentityObj(GraphObj *graph, Tensor input, Tensor output); OP_CLONE(IdentityObj); - optional> inferShape(const TensorVec &inputs) const override; + optional> inferShape(const TensorVec &inputs) override; std::string toString() const override; int numInputs() const override { return 1; } diff --git a/include/operators/resize.h b/include/operators/resize.h index a762ea30..96283c12 100644 --- a/include/operators/resize.h +++ b/include/operators/resize.h @@ -60,7 +60,7 @@ class ResizeObj : public OperatorObj { // Operator clone(TensorVec inputs, TensorVec outputs) override; vector inferDataType(const TensorVec &inputs) const override; - optional> inferShape(const TensorVec &inputs) const override; + optional> inferShape(const TensorVec &inputs) override; std::string toString() const override; int numInputs() const override { return inputs.size(); } int numOutputs() const override { return 1; } diff --git a/include/operators/slice.h b/include/operators/slice.h index 55acf505..188c804d 100644 --- a/include/operators/slice.h +++ b/include/operators/slice.h @@ -32,7 +32,7 @@ class SliceObj : public OperatorObj { const optional> &steps); OP_CLONE(SliceObj); - optional> inferShape(const TensorVec &inputs) const override; + optional> inferShape(const TensorVec &inputs) override; std::string toString() const override; inline int numInputs() const override { return 1; } inline int numOutputs() const override { return 1; } diff --git a/include/operators/softmax.h b/include/operators/softmax.h index 0611f63f..b24c0ffb 100644 --- a/include/operators/softmax.h +++ b/include/operators/softmax.h @@ -10,7 +10,7 @@ class SoftmaxObj : public OperatorObj { OP_CLONE(SoftmaxObj); - optional> inferShape(const TensorVec &inputs) const override { + optional> inferShape(const TensorVec &inputs) override { return {{inputs[0]->getDims()}}; }; diff --git a/include/operators/split.h b/include/operators/split.h index 61aa43a2..a4032463 100644 --- a/include/operators/split.h +++ b/include/operators/split.h @@ -37,7 +37,7 @@ class SplitObj : public OperatorObj { int dim, const vector &ratio); OP_CLONE(SplitObj); - optional> inferShape(const TensorVec &inputs) const override; + optional> inferShape(const TensorVec &inputs) override; std::string toString() const override; int numInputs() const override { return 1; } diff --git a/include/operators/transpose.h b/include/operators/transpose.h index 9fcd1617..6a8cfcc6 100644 --- a/include/operators/transpose.h +++ b/include/operators/transpose.h @@ -7,7 +7,7 @@ class TransposeObj : public OperatorObj { TransposeObj(GraphObj *graph, Tensor input, Tensor output, vector permute); OP_CLONE(TransposeObj); - optional> inferShape(const TensorVec &inputs) const override; + optional> inferShape(const TensorVec &inputs) override; std::string toString() const override; int numInputs() const override { return 1; } @@ -25,7 +25,7 @@ class DepthToSpaceObj : public OperatorObj { DepthToSpaceObj(GraphObj *graph, Tensor input, Tensor output, int blocksize, std::string mode); OP_CLONE(DepthToSpaceObj); - optional> inferShape(const TensorVec &inputs) const override; + optional> inferShape(const TensorVec &inputs) override; std::string toString() const override; int numInputs() const override { return 1; } diff --git a/include/operators/unary.h b/include/operators/unary.h index 0bbe314c..c3e628d4 100644 --- a/include/operators/unary.h +++ b/include/operators/unary.h @@ -17,7 +17,7 @@ class UnaryObj : public OperatorObj { * @param output The output tensor. */ UnaryObj(OpType type, GraphObj *graph, Tensor input, Tensor output); - optional> inferShape(const TensorVec &inputs) const override; + optional> inferShape(const TensorVec &inputs) override; std::string toString() const override; int numInputs() const override { return 1; } @@ -33,7 +33,7 @@ class ClipObj : public OperatorObj { ClipObj(GraphObj *graph, Tensor input, Tensor output, std::optional min, std::optional max); OP_CLONE(ClipObj); - optional> inferShape(const TensorVec &inputs) const override; + optional> inferShape(const TensorVec &inputs) override; std::string toString() const override; std::optional getMin() const { return minValue; }; @@ -52,7 +52,7 @@ class HardtanhObj : public OperatorObj { HardtanhObj(GraphObj *graph, Tensor input, Tensor output, float min, float max); OP_CLONE(HardtanhObj); - optional> inferShape(const TensorVec &inputs) const override; + optional> inferShape(const TensorVec &inputs) override; std::string toString() const override; float getMin() const { return minValue; }; @@ -70,7 +70,7 @@ class FlipObj : public OperatorObj { public: FlipObj(GraphObj *graph, Tensor input, Tensor output, vector axis); OP_CLONE(FlipObj); - optional> inferShape(const TensorVec &inputs) const override; + optional> inferShape(const TensorVec &inputs) override; std::string toString() const override; vector getAxis() const { return axisValue; }; @@ -87,7 +87,7 @@ class FillObj : public OperatorObj { public: FillObj(GraphObj *graph, Tensor input, Tensor output, float value); OP_CLONE(FillObj); - optional> inferShape(const TensorVec &inputs) const override; + optional> inferShape(const TensorVec &inputs) override; std::string toString() const override; float getValue() const { return setValue; }; @@ -104,7 +104,7 @@ class L2LossObj : public OperatorObj { public: L2LossObj(GraphObj *graph, Tensor input, Tensor output); OP_CLONE(L2LossObj); - optional> inferShape(const TensorVec &inputs) const override; + optional> inferShape(const TensorVec &inputs) override; std::string toString() const override; int numInputs() const override { return 1; } @@ -120,7 +120,7 @@ class TransformObj : public OperatorObj { TransformObj(GraphObj *graph, Tensor input, Tensor output, float alpha, float beta); OP_CLONE(TransformObj); - optional> inferShape(const TensorVec &inputs) const override; + optional> inferShape(const TensorVec &inputs) override; std::string toString() const override; float getAlpha() const { return alphaValue; } @@ -165,7 +165,7 @@ class CastObj : public OperatorObj { public: CastObj(GraphObj *graph, Tensor input, Tensor output, CastType type); OP_CLONE(CastObj); - optional> inferShape(const TensorVec &inputs) const override; + optional> inferShape(const TensorVec &inputs) override; vector inferDataType(const TensorVec &inputs) const override; std::string toString() const override; @@ -185,7 +185,7 @@ class CumsumObj : public OperatorObj { CumsumObj(GraphObj *graph, Tensor input, Tensor output, int axis, bool exclusive, bool reverse); OP_CLONE(CumsumObj); - optional> inferShape(const TensorVec &inputs) const override; + optional> inferShape(const TensorVec &inputs) override; std::string toString() const override; int getAxis() const { return axisValue; } @@ -205,7 +205,7 @@ class ShapeObj : public OperatorObj { public: ShapeObj(GraphObj *graph, Tensor input, Tensor output); OP_CLONE(ShapeObj); - optional> inferShape(const TensorVec &inputs) const override; + optional> inferShape(const TensorVec &inputs) override; std::string toString() const override; int numInputs() const override { return 1; } @@ -216,7 +216,7 @@ class PReluObj : public OperatorObj { public: PReluObj(GraphObj *graph, Tensor input, Tensor alpha, Tensor output); OP_CLONE(PReluObj); - optional> inferShape(const TensorVec &inputs) const override; + optional> inferShape(const TensorVec &inputs) override; std::string toString() const override; int numInputs() const override { return 2; } @@ -236,7 +236,7 @@ class LogObj : public OperatorObj { }; LogObj(GraphObj *graph, Tensor input, Tensor output, LogType type); OP_CLONE(LogObj); - optional> inferShape(const TensorVec &inputs) const override; + optional> inferShape(const TensorVec &inputs) override; std::string toString() const override; LogType getType() const { return logType; } diff --git a/include/operators/where.h b/include/operators/where.h index 6422fe34..249f9c46 100644 --- a/include/operators/where.h +++ b/include/operators/where.h @@ -22,7 +22,7 @@ class WhereObj : public OperatorObj { Tensor output); OP_CLONE(WhereObj); - optional> inferShape(const TensorVec &inputs) const override; + optional> inferShape(const TensorVec &inputs) override; std::string toString() const override; int numInputs() const override { return inputs.size(); } diff --git a/pyinfinitensor/src/pyinfinitensor/onnx.py b/pyinfinitensor/src/pyinfinitensor/onnx.py index 7360002a..d48ef52a 100644 --- a/pyinfinitensor/src/pyinfinitensor/onnx.py +++ b/pyinfinitensor/src/pyinfinitensor/onnx.py @@ -48,7 +48,7 @@ class OnnxStub: pass except RuntimeError: pass - + self.inputs: Dict[str, backend.Tensor] = {} self.outputs: Dict[str, backend.Tensor] = {} self.initializer: Dict[int, TensorProto] = {} @@ -510,19 +510,11 @@ class OnnxStub: mode, ) elif node.op_type == "Reshape": - dims = _search_shape(model, node.input[0]) - size = reduce(lambda acc, x: acc * x, dims) - input_shape = _parse_data(data[node.input[1]]) - for i, x in enumerate(input_shape): - if x == 0: - input_shape[i] = dims[i] - temp = reduce(lambda acc, x: acc * x, input_shape, 1) - if temp < 0: - input_shape[input_shape.index(-1)] = size // -temp + shape = _parse_data(data[node.input[1]]) tensors[node.output[0]] = self.handler.reshape( tensors[node.input[0]], tensors.get(node.output[0]), - input_shape, + shape, ) elif node.op_type == "Squeeze": input_shape = _search_shape(model, node.input[0]) @@ -1112,6 +1104,26 @@ class OnnxStub: def optimize(self) -> None: self.handler.optimize() + def clone_KV(self, tensor: backend.Tensor) -> backend.Tensor: + return self.handler.clone_KV(tensor) + + def free_heap(self) -> None: + self.handler.free_heap() + + def set_input(self, inputShapes: List[int]) -> None: + for newInput, oldInput in zip(inputShapes, self.inputs): + oldTensor = self.inputs[oldInput] + self.handler.change_shape(newInput, oldTensor.fuid()) + self.handler.shape_infer() + self.handler.data_malloc() + + def getShape(self, name: str) -> List[int]: + if name in self.inputs: + ans = self.handler.getDims(self.inputs[name]) + else: + ans = self.handler.getDims(self.outputs[name]) + return ans + def tune(self) -> None: self.handler.tune() diff --git a/pyinfinitensor/tests/test_onnx.py b/pyinfinitensor/tests/test_onnx.py index 035baf34..79df0294 100644 --- a/pyinfinitensor/tests/test_onnx.py +++ b/pyinfinitensor/tests/test_onnx.py @@ -209,6 +209,7 @@ class TestStringMethods(unittest.TestCase): make_and_import_model(make_graph([relu], "relu", [x], [y])) """Gelu operator is not supported by onnx 14.1 currently.""" + def test_gelu(self): pass # x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 3, 5, 7]) @@ -500,5 +501,22 @@ class TestStringMethods(unittest.TestCase): make_and_import_model(make_graph([where], "where", [x, y, con], [output])) +class TestDynamicTensor(unittest.TestCase): + def test_dynamic_tensor(self): + filename = r"resnet18-v2-7.onnx" + current_path = os.getcwd() + model_file = "" + for root, dirs, files in os.walk(current_path): + if filename in files: + model_file = os.path.join(root, filename) + model = OnnxStub(onnx.load(model_file), backend.cpu_runtime()) + output_key = list(model.outputs.keys())[0] + old_output_shape = model.getShape(output_key) + self.assertEqual(old_output_shape, ([1, 1000])) + model.set_input([[5, 3, 224, 224]]) + new_output_shape = model.getShape(output_key) + self.assertEqual(new_output_shape, ([5, 1000])) + + if __name__ == "__main__": unittest.main() diff --git a/src/core/graph.cc b/src/core/graph.cc index 7e902247..dd474d11 100644 --- a/src/core/graph.cc +++ b/src/core/graph.cc @@ -1,5 +1,7 @@ #include "core/graph.h" +#include "operators/reshape.h" #include +#include #include namespace infini { @@ -123,10 +125,40 @@ void GraphObj::optimize() { } } -void GraphObj::dataMalloc(bool useNaiveAllocator) { +Tensor GraphObj::getTensor(int fuid) const { + for (auto tensor : tensors) { + if (tensor->getFuid() == fuid) { + return tensor; + } + } + return nullptr; +} + +void GraphObj::shape_infer() { + for (auto &op : ops) { + auto ans = op->inferShape(); + IT_ASSERT(ans.has_value()); + auto oldOutputs = op->getOutputs(); + IT_ASSERT(ans.value().size() == oldOutputs.size()); + // replace the old outputshape and size with new one + for (int i = 0; i < (int)ans.value().size(); ++i) { + auto newShape = ans.value()[i]; + auto oldShape = oldOutputs[i]->getDims(); + auto fuid = oldOutputs[i]->getFuid(); + if (newShape != oldShape) { + auto tensor = this->getTensor(fuid); + tensor->setShape(newShape); + } + } + } +} + +void GraphObj::dataMalloc(bool useNaiveAllocator, size_t memPoolSize) { // topological sorting first IT_ASSERT(topo_sort() == true); if (useNaiveAllocator) { + // can not set memory pool when use naive allocator + IT_ASSERT(memPoolSize == 0); // used for debugging memory out-of-bounds access, tensors will not be // released correctly // note: behavior may not match running in non-naive mode, and it may @@ -136,6 +168,9 @@ void GraphObj::dataMalloc(bool useNaiveAllocator) { } return; } + if (memPoolSize > 0) { + allocator.setMemPool(memPoolSize); + } // count the number of times all tensors are used std::unordered_map tensorToRefCount; // record the memory address offsets of all tensors to be allocated @@ -222,6 +257,27 @@ void GraphObj::dataMalloc(bool useNaiveAllocator) { } } +Tensor GraphObj::cloneKV(Tensor &tensor) { + auto obj = tensor->clone(); + if (allocator.getMemPoolStatus()) { + if (tensor->hasData()) { + obj->setDataBlob(make_ref( + tensor->runtime, + static_cast(allocator.getHeapPtr()) + + allocator.heapAlloc(tensor->getBytes()))); + obj->copyData(tensor); + } + } else { + if (tensor->hasData()) { + obj->dataMalloc(); + obj->copyData(tensor); + } + } + return obj; +} + +void GraphObj::freeHeap() { this->allocator.freeHeap(); } + Tensor GraphObj::addTensor(Shape dim, DataType dtype) { return tensors.emplace_back(make_ref(dim, dtype, runtime)); } diff --git a/src/core/graph_handler.cc b/src/core/graph_handler.cc index 32b99b63..d2f54b2d 100644 --- a/src/core/graph_handler.cc +++ b/src/core/graph_handler.cc @@ -20,6 +20,7 @@ #include "operators/transpose.h" #include "operators/unary.h" #include "operators/where.h" +#include namespace infini { @@ -555,4 +556,11 @@ static DataType dtype_repr_convert(int dtype) { } } +void GraphHandlerObj::change_shape(const vector &shape, int tensorId) { + auto tensor = g->getTensor(tensorId); + IT_ASSERT(tensor != nullptr); + IT_ASSERT(shape.size() != 0); + tensor->setShape(shape); +} + } // namespace infini diff --git a/src/core/lazy_allocator.cc b/src/core/lazy_allocator.cc index c3407320..60f74c75 100644 --- a/src/core/lazy_allocator.cc +++ b/src/core/lazy_allocator.cc @@ -30,6 +30,9 @@ LazyAllocator::~LazyAllocator() { if (this->weightPtr != nullptr) { runtime->dealloc(this->weightPtr); } + if (this->memPoolPtr != nullptr) { + runtime->dealloc(this->memPoolPtr); + } } void LazyAllocator::init() { @@ -44,6 +47,17 @@ void LazyAllocator::init() { this->ptr = nullptr; } +void LazyAllocator::setMemPool(size_t memPoolSize) { + IT_ASSERT(memPoolSize > 0); + if (!this->hasMemPool) { + this->hasMemPool = true; + this->memPoolSize = memPoolSize; + this->memPoolPtr = runtime->alloc(memPoolSize); + } +} + +bool LazyAllocator::getMemPoolStatus() { return this->hasMemPool; } + size_t LazyAllocator::alloc(size_t size) { // pad the size to the multiple of alignment size = this->getAlignedSize(size); @@ -102,6 +116,17 @@ size_t LazyAllocator::allocWeight(size_t size) { return retAddr; } +size_t LazyAllocator::heapAlloc(size_t size) { + size = this->getAlignedSize(size); + this->heapPeak += size; + IT_ASSERT(this->memPoolSize >= + this->weightPeak + this->peak + this->heapPeak); + size_t retAddr = this->memPoolSize - this->heapPeak; + return retAddr; +} + +void LazyAllocator::freeHeap() { this->heapPeak = 0; } + void LazyAllocator::free(size_t addr, size_t size) { IT_ASSERT(this->ptr == nullptr); size = getAlignedSize(size); @@ -143,25 +168,40 @@ void LazyAllocator::free(size_t addr, size_t size) { } void *LazyAllocator::getPtr() { - if (this->ptr == nullptr) { - this->ptr = runtime->alloc(this->peak); - // #ifdef DEBUG_MODE - // printf("LazyAllocator really alloc non-weight: %p %lu - // bytes\n", this->ptr, peak); - // #endif + if (!hasMemPool) { + if (this->ptr == nullptr) { + this->ptr = runtime->alloc(this->peak); + // #ifdef DEBUG_MODE + // printf("LazyAllocator really alloc non-weight: %p %lu + // bytes\n", this->ptr, peak); + // #endif + } + return this->ptr; + } else { + IT_ASSERT(this->memPoolSize >= this->weightPeak + this->peak); + return static_cast(this->memPoolPtr) + weightPeak; } - return this->ptr; } void *LazyAllocator::getWeightPtr() { - if (this->weightPtr == nullptr) { - this->weightPtr = runtime->alloc(this->weightPeak); - // #ifdef DEBUG_MODE - // printf("LazyAllocator really alloc weight: %p %lu bytes\n", - // this->weightPtr, weightPeak); - // #endif + if (!hasMemPool) { + if (this->weightPtr == nullptr) { + this->weightPtr = runtime->alloc(this->weightPeak); + // #ifdef DEBUG_MODE + // printf("LazyAllocator really alloc weight: %p %lu + // bytes\n", + // this->weightPtr, weightPeak); + // #endif + } + return this->weightPtr; + } else { + return this->memPoolPtr; } - return this->weightPtr; +} + +void *LazyAllocator::getHeapPtr() { + IT_ASSERT(hasMemPool); + return this->memPoolPtr; } size_t LazyAllocator::getAlignedSize(size_t size) { diff --git a/src/core/operator.cc b/src/core/operator.cc index 462cb2a2..6a9ea1b8 100644 --- a/src/core/operator.cc +++ b/src/core/operator.cc @@ -77,9 +77,7 @@ bool OperatorObj::checkValid(GraphObj *graph) { return true; } -optional> OperatorObj::inferShape() const { - return inferShape(inputs); -} +optional> OperatorObj::inferShape() { return inferShape(inputs); } vector OperatorObj::inferDataType(const TensorVec &inputs) const { auto dataType = inputs[0]->getDType(); diff --git a/src/core/tensor.cc b/src/core/tensor.cc index e34fb8bc..5be8a18d 100644 --- a/src/core/tensor.cc +++ b/src/core/tensor.cc @@ -59,6 +59,13 @@ Shape TensorObj::getStride() const { return stride; } +void TensorObj::setShape(Shape shape_) { + shape = shape_; + size_t size = std::accumulate(shape.begin(), shape.end(), 1, + [](auto acc, auto x) { return acc * x; }); + _size = size; +} + void TensorObj::printData() const { IT_ASSERT(data != nullptr); if (!runtime->isCpu()) diff --git a/src/ffi/ffi_infinitensor.cc b/src/ffi/ffi_infinitensor.cc index ca427dab..5033a191 100644 --- a/src/ffi/ffi_infinitensor.cc +++ b/src/ffi/ffi_infinitensor.cc @@ -446,7 +446,10 @@ void init_graph_builder(py::module &m) { }) .def("has_target", &TensorObj::hasTarget, policy::automatic) .def("src", &TensorObj::getSource, policy::move) - .def("printData", &TensorObj::printData, policy::automatic); + .def("printData", &TensorObj::printData, policy::automatic) + .def("copy_data", + py::overload_cast(&TensorObj::copyData), + policy::move); py::class_>(m, "Operator") .def("op_type", &OperatorObj::getOpType, policy::automatic) .def("inputs", py::overload_cast<>(&OperatorObj::getInputs, py::const_), @@ -466,6 +469,7 @@ void init_graph_builder(py::module &m) { .def("add", &Handler::add, policy::move) .def("sub", &Handler::sub, policy::move) .def("mul", &Handler::mul, policy::move) + .def("max", &Handler::max, policy::move) .def("div", &Handler::div, policy::move) .def("pow", &Handler::pow, policy::move) .def("min", &Handler::min, policy::move) @@ -510,10 +514,17 @@ void init_graph_builder(py::module &m) { .def("topo_sort", &Handler::topo_sort, policy::automatic) .def("optimize", &Handler::optimize, policy::automatic) .def("operators", &Handler::operators, policy::move) - .def("data_malloc", &Handler::data_malloc, policy::automatic) + .def("data_malloc", &Handler::data_malloc, + py::arg("useNaiveAllocator") = false, py::arg("memPoolSize") = 0, + policy::automatic) + .def("clone_KV", &Handler::clone_KV, policy::move) + .def("free_heap", &Handler::free_heap, policy::move) .def("get_perf_time", &Handler::get_perf_time, policy::automatic) .def("tune", &Handler::tune, policy::automatic) .def("run", &Handler::run, policy::automatic) + .def("shape_infer", &Handler::shape_infer, policy::automatic) + .def("change_shape", &Handler::change_shape, policy::automatic) + .def("getDims", &Handler::getDims, policy::automatic) .def("get_perf_time", &Handler::get_perf_time, policy::automatic); } diff --git a/src/kernels/cuda/element_wise.cc b/src/kernels/cuda/element_wise.cc index 99b586fb..8603c198 100644 --- a/src/kernels/cuda/element_wise.cc +++ b/src/kernels/cuda/element_wise.cc @@ -44,7 +44,6 @@ class ElementWiseCudnn : public CudaKernelWithoutConfig { std::copy(a_dim.begin(), a_dim.end(), a + (4 - a_dim.size())); std::copy(b_dim.begin(), b_dim.end(), b + (4 - b_dim.size())); std::copy(c_dim.begin(), c_dim.end(), c + (4 - c_dim.size())); - // get inputs checkCudnnError(cudnnCreateTensorDescriptor(&aDesc)); checkCudnnError(cudnnSetTensor4dDescriptor(aDesc, CUDNN_TENSOR_NCHW, @@ -110,9 +109,9 @@ class ElementWiseCuda : public CudaKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); - float *const aData = (op->getInputs(0)->getRawDataPtr()); - float *const bData = (op->getInputs(1)->getRawDataPtr()); - float *const cData = (op->getOutput()->getRawDataPtr()); + void *const aData = (op->getInputs(0)->getRawDataPtr()); + void *const bData = (op->getInputs(1)->getRawDataPtr()); + void *const cData = (op->getOutput()->getRawDataPtr()); auto a_dim = op->getInputs(0)->getDims(); auto b_dim = op->getInputs(1)->getDims(); auto c_dim = op->getOutput()->getDims(); @@ -134,7 +133,13 @@ class ElementWiseCuda : public CudaKernelWithoutConfig { else if (op->getOpType() == OpType::Pow) pow_kernel(aData, bData, cData, a[0], a[1], a[2], a[3], b[0], b[1], b[2], b[3], c[0], c[1], c[2], c[3]); - else + else if (op->getOpType() == OpType::Add) { + add_kernel(aData, bData, cData, a[0], a[1], a[2], a[3], b[0], b[1], + b[2], b[3], c[0], c[1], c[2], c[3]); + } else if (op->getOpType() == OpType::Less) { + less_kernel(aData, bData, cData, a[0], a[1], a[2], a[3], b[0], b[1], + b[2], b[3], c[0], c[1], c[2], c[3]); + } else IT_TODO_HALT(); } }; @@ -152,6 +157,10 @@ REGISTER_KERNEL(Device::CUDA, OpType::Max, DataType::Float32, MaxCudnn, REGISTER_KERNEL(Device::CUDA, OpType::Div, DataType::Float32, ElementWiseCuda, "Div_CUDA_Float32"); +REGISTER_KERNEL(Device::CUDA, OpType::Add, DataType::Int64, ElementWiseCuda, + "Add_CUDA_Int64"); REGISTER_KERNEL(Device::CUDA, OpType::Pow, DataType::Float32, ElementWiseCuda, "Pow__CUDA_Float32"); +REGISTER_KERNEL(Device::CUDA, OpType::Less, DataType::Int64, ElementWiseCuda, + "Less__CUDA_Int64"); }; // namespace infini diff --git a/src/kernels/cuda/element_wise.cu b/src/kernels/cuda/element_wise.cu index 93e384d3..9d1b101a 100644 --- a/src/kernels/cuda/element_wise.cu +++ b/src/kernels/cuda/element_wise.cu @@ -5,9 +5,10 @@ constexpr unsigned int num_threads() { return 32 * 4; } constexpr int thread_work_size() { return 4; } constexpr int block_work_size() { return thread_work_size() * num_threads(); } -__global__ void _div_kernel(float *x, float *y, float *z, int a0, int a1, - int a2, int a3, int b0, int b1, int b2, int b3, - int c0, int c1, int c2, int c3) { +template +__global__ void _div_kernel(void *x, void *y, void *z, int a0, int a1, int a2, + int a3, int b0, int b1, int b2, int b3, int c0, + int c1, int c2, int c3) { int index = threadIdx.x + blockIdx.x * blockDim.x; int stride = blockDim.x * gridDim.x; int n = c0 * c1 * c2 * c3; @@ -27,16 +28,17 @@ __global__ void _div_kernel(float *x, float *y, float *z, int a0, int a1, int b1_index = c1_index % b1; int b2_index = c2_index % b2; int b3_index = c3_index % b3; - z[i] = x[a0_index * a1 * a2 * a3 + a1_index * a2 * a3 + a2_index * a3 + - a3_index] / - y[b0_index * b1 * b2 * b3 + b1_index * b2 * b3 + b2_index * b3 + - b3_index]; + ((T *)z)[i] = ((T *)x)[a0_index * a1 * a2 * a3 + a1_index * a2 * a3 + + a2_index * a3 + a3_index] / + ((T *)y)[b0_index * b1 * b2 * b3 + b1_index * b2 * b3 + + b2_index * b3 + b3_index]; } } -__global__ void _pow_kernel(float *x, float *y, float *z, int a0, int a1, - int a2, int a3, int b0, int b1, int b2, int b3, - int c0, int c1, int c2, int c3) { +template +__global__ void _add_kernel(void *x, void *y, void *z, int a0, int a1, int a2, + int a3, int b0, int b1, int b2, int b3, int c0, + int c1, int c2, int c3) { int index = threadIdx.x + blockIdx.x * blockDim.x; int stride = blockDim.x * gridDim.x; int n = c0 * c1 * c2 * c3; @@ -56,32 +58,115 @@ __global__ void _pow_kernel(float *x, float *y, float *z, int a0, int a1, int b1_index = c1_index % b1; int b2_index = c2_index % b2; int b3_index = c3_index % b3; - z[i] = pow(x[a0_index * a1 * a2 * a3 + a1_index * a2 * a3 + - a2_index * a3 + a3_index], - y[b0_index * b1 * b2 * b3 + b1_index * b2 * b3 + - b2_index * b3 + b3_index]); + ((T *)z)[i] = ((T *)x)[a0_index * a1 * a2 * a3 + a1_index * a2 * a3 + + a2_index * a3 + a3_index] + + ((T *)y)[b0_index * b1 * b2 * b3 + b1_index * b2 * b3 + + b2_index * b3 + b3_index]; + } +} + +template +__global__ void _pow_kernel(void *x, void *y, void *z, int a0, int a1, int a2, + int a3, int b0, int b1, int b2, int b3, int c0, + int c1, int c2, int c3) { + int index = threadIdx.x + blockIdx.x * blockDim.x; + int stride = blockDim.x * gridDim.x; + int n = c0 * c1 * c2 * c3; + + for (int i = index; i < n; i += stride) { + int c0_index = i / (c1 * c2 * c3); + int c1_index = (i % (c1 * c2 * c3)) / (c2 * c3); + int c2_index = ((i % (c1 * c2 * c3)) % (c2 * c3)) / c3; + int c3_index = ((i % (c1 * c2 * c3)) % (c2 * c3)) % c3; + + int a0_index = c0_index % a0; + int a1_index = c1_index % a1; + int a2_index = c2_index % a2; + int a3_index = c3_index % a3; + + int b0_index = c0_index % b0; + int b1_index = c1_index % b1; + int b2_index = c2_index % b2; + int b3_index = c3_index % b3; + ((T *)z)[i] = + pow(((T *)x)[a0_index * a1 * a2 * a3 + a1_index * a2 * a3 + + a2_index * a3 + a3_index], + ((T *)y)[b0_index * b1 * b2 * b3 + b1_index * b2 * b3 + + b2_index * b3 + b3_index]); + } +} + +template +__global__ void _less_kernel(void *x, void *y, void *z, int a0, int a1, int a2, + int a3, int b0, int b1, int b2, int b3, int c0, + int c1, int c2, int c3) { + int index = threadIdx.x + blockIdx.x * blockDim.x; + int stride = blockDim.x * gridDim.x; + int n = c0 * c1 * c2 * c3; + + for (int i = index; i < n; i += stride) { + int c0_index = i / (c1 * c2 * c3); + int c1_index = (i % (c1 * c2 * c3)) / (c2 * c3); + int c2_index = ((i % (c1 * c2 * c3)) % (c2 * c3)) / c3; + int c3_index = ((i % (c1 * c2 * c3)) % (c2 * c3)) % c3; + + int a0_index = c0_index % a0; + int a1_index = c1_index % a1; + int a2_index = c2_index % a2; + int a3_index = c3_index % a3; + + int b0_index = c0_index % b0; + int b1_index = c1_index % b1; + int b2_index = c2_index % b2; + int b3_index = c3_index % b3; + ((bool *)z)[i] = + ((T *)x)[a0_index * a1 * a2 * a3 + a1_index * a2 * a3 + + a2_index * a3 + a3_index] < + ((T *)y)[b0_index * b1 * b2 * b3 + b1_index * b2 * b3 + + b2_index * b3 + b3_index] + ? true + : false; } } namespace infini { -void div_kernel(float *a, float *b, float *c, int a0, int a1, int a2, int a3, +void div_kernel(void *a, void *b, void *c, int a0, int a1, int a2, int a3, int b0, int b1, int b2, int b3, int c0, int c1, int c2, int c3) { int blocksize = block_work_size(); int num = c0 * c1 * c2 * c3; int gridsize = (num + block_work_size() - 1) / block_work_size(); - _div_kernel<<>>(a, b, c, a0, a1, a2, a3, b0, b1, b2, - b3, c0, c1, c2, c3); + _div_kernel<<>>(a, b, c, a0, a1, a2, a3, b0, b1, + b2, b3, c0, c1, c2, c3); } -void pow_kernel(float *a, float *b, float *c, int a0, int a1, int a2, int a3, +void add_kernel(void *a, void *b, void *c, int a0, int a1, int a2, int a3, + int b0, int b1, int b2, int b3, int c0, int c1, int c2, + int c3) { + + int blocksize = block_work_size(); + int num = c0 * c1 * c2 * c3; + int gridsize = (num + block_work_size() - 1) / block_work_size(); + _add_kernel<<>>(a, b, c, a0, a1, a2, a3, b0, + b1, b2, b3, c0, c1, c2, c3); +} +void pow_kernel(void *a, void *b, void *c, int a0, int a1, int a2, int a3, int b0, int b1, int b2, int b3, int c0, int c1, int c2, int c3) { int blocksize = block_work_size(); int num = c0 * c1 * c2 * c3; int gridsize = (num + block_work_size() - 1) / block_work_size(); - _pow_kernel<<>>(a, b, c, a0, a1, a2, a3, b0, b1, b2, - b3, c0, c1, c2, c3); + _pow_kernel<<>>(a, b, c, a0, a1, a2, a3, b0, b1, + b2, b3, c0, c1, c2, c3); +} +void less_kernel(void *a, void *b, void *c, int a0, int a1, int a2, int a3, + int b0, int b1, int b2, int b3, int c0, int c1, int c2, + int c3) { + int blocksize = block_work_size(); + int num = c0 * c1 * c2 * c3; + int gridsize = (num + block_work_size() - 1) / block_work_size(); + _less_kernel<<>>(a, b, c, a0, a1, a2, a3, b0, + b1, b2, b3, c0, c1, c2, c3); } }; // namespace infini diff --git a/src/kernels/cuda/pad_slice.cc b/src/kernels/cuda/pad_slice.cc index 2e7e3931..1ff4dffa 100644 --- a/src/kernels/cuda/pad_slice.cc +++ b/src/kernels/cuda/pad_slice.cc @@ -16,8 +16,9 @@ class PadSliceCudaCompute { metadata.partNDim[i] = partTensor->getDims()[i]; metadata.partStride[i] = partTensor->getStride()[i]; } - pad_slice_kernel(partTensor->getRawDataPtr(), - wholeTensor->getRawDataPtr(), metadata, nDims, + metadata.DType = partTensor->getDType().getIndex(); + pad_slice_kernel(partTensor->getRawDataPtr(), + wholeTensor->getRawDataPtr(), metadata, nDims, wholeTensor->size(), isPad); } }; @@ -40,6 +41,8 @@ class SliceCuda : private PadSliceCudaCompute, public CudaKernelWithoutConfig { REGISTER_KERNEL(Device::CUDA, OpType::Slice, DataType::Float32, SliceCuda, "Slice__CUDA_Float32"); +REGISTER_KERNEL(Device::CUDA, OpType::Slice, DataType::Int64, SliceCuda, + "Slice__CUDA_Int64"); REGISTER_KERNEL(Device::CUDA, OpType::Pad, DataType::Float32, PadCuda, "Pad__CUDA_Float32"); } // namespace infini diff --git a/src/kernels/cuda/pad_slice.cu b/src/kernels/cuda/pad_slice.cu index 828aba3e..f119bd9c 100644 --- a/src/kernels/cuda/pad_slice.cu +++ b/src/kernels/cuda/pad_slice.cu @@ -1,3 +1,4 @@ +#include "core/data_type.h" #include "cuda/cuda_common.h" #include "cuda/cuda_pad_slice.h" @@ -19,9 +20,9 @@ __device__ int WholeTensorOffset2PartTensorOffset(int wholeOffset, return offset; } -__global__ void _pad_slice_kernel(float *part, float *whole, - TransMetaData metaData, int nDims, int num, - bool isPad) { +template +__global__ void _pad_slice_kernel(T *part, T *whole, TransMetaData metaData, + int nDims, int num, bool isPad) { int tid = threadIdx.x + blockIdx.x * blockDim.x; if (tid >= num) return; @@ -41,12 +42,18 @@ __global__ void _pad_slice_kernel(float *part, float *whole, } namespace infini { -void pad_slice_kernel(float *partData, float *wholeData, +void pad_slice_kernel(void *partData, void *wholeData, const TransMetaData &metadata, int nDims, int num, bool isPad) { int blockSize = 32 * 16; int gridSize = (num + blockSize - 1) / blockSize; - _pad_slice_kernel<<>>(partData, wholeData, metadata, - nDims, num, isPad); + if (metadata.DType == DataType::Int64.getIndex()) { + _pad_slice_kernel + <<>>((int64_t *)partData, (int64_t *)wholeData, + metadata, nDims, num, isPad); + } else if (metadata.DType == DataType::Float32.getIndex()) { + _pad_slice_kernel<<>>( + (float *)partData, (float *)wholeData, metadata, nDims, num, isPad); + } } } // namespace infini diff --git a/src/kernels/cuda/split_concat.cc b/src/kernels/cuda/split_concat.cc index dbe2a7ac..d3f8a551 100644 --- a/src/kernels/cuda/split_concat.cc +++ b/src/kernels/cuda/split_concat.cc @@ -59,6 +59,21 @@ class CudaCompute { class ConcatCuda : private CudaCompute, public CudaKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { + auto inputs = _op->getInputs(); + if (inputs.size() == 2) { + for (size_t i = 0; i < 2; i++) { + if (inputs[i]->size() == 0) { + auto inData = + _op->getInputs(1 - i)->getRawDataPtr(); + auto outData = + _op->getOutputs()[0]->getRawDataPtr(); + cudaMemcpyAsync(outData, inData, + _op->getInputs(1 - i)->getBytes(), + cudaMemcpyDeviceToDevice); + return; + } + } + } do_compute(_op->getOutput(), _op->getInputs(), as(_op)->getDim(), _op->getOutput()->getRank(), false); diff --git a/src/operators/G2BMM.cc b/src/operators/G2BMM.cc index 499c1f77..81c8737f 100644 --- a/src/operators/G2BMM.cc +++ b/src/operators/G2BMM.cc @@ -20,15 +20,18 @@ string G2BMMObj::toString() const { return os.str(); } -optional> G2BMMObj::inferShape(const TensorVec &inputs) const { +optional> G2BMMObj::inferShape(const TensorVec &inputs) { auto A = inputs[0], B = inputs[1]; + b = A->getDims()[0]; + m = A->getDims()[1]; + k = A->getDims()[2]; IT_ASSERT(A->getRank() == 3 && B->getRank() == 3); IT_ASSERT(A->getDims()[0] == B->getDims()[0]); IT_ASSERT(A->getDims()[1] == B->getDims()[1]); IT_ASSERT(A->getDims()[2] == B->getDims()[2]); IT_ASSERT(width >= 0); - int b(A->getDims()[0]), m(A->getDims()[1]), n(2 * width + 1); + int n(2 * width + 1); return {{{b, m, n}}}; } diff --git a/src/operators/GBMM.cc b/src/operators/GBMM.cc index d51128fa..f8ee09c0 100644 --- a/src/operators/GBMM.cc +++ b/src/operators/GBMM.cc @@ -21,15 +21,18 @@ string GBMMObj::toString() const { return os.str(); } -optional> GBMMObj::inferShape(const TensorVec &inputs) const { +optional> GBMMObj::inferShape(const TensorVec &inputs) { auto A = inputs[0], B = inputs[1]; + b = A->getDims()[0]; + m = A->getDims()[1]; + w = (A->getDims()[2] - 1) / 2; + n = B->getDims()[2]; IT_ASSERT(A->getRank() == 3 && B->getRank() == 3); IT_ASSERT(A->getDims()[0] == B->getDims()[0]); IT_ASSERT(A->getDims()[1] == B->getDims()[1]); IT_ASSERT(A->getDims()[2] % 2 != 0); - int b(A->getDims()[0]), m(A->getDims()[1]), k(B->getDims()[2]); - return {{{b, m, k}}}; + return {{{b, m, n}}}; } vector GBMMObj::getWorkloadVector() const { diff --git a/src/operators/activation_backward.cc b/src/operators/activation_backward.cc index b968c936..142c3692 100644 --- a/src/operators/activation_backward.cc +++ b/src/operators/activation_backward.cc @@ -9,7 +9,7 @@ ActivationBackwardObj::ActivationBackwardObj(OpType type, GraphObj *graph, } optional> -ActivationBackwardObj::inferShape(const TensorVec &inputs) const { +ActivationBackwardObj::inferShape(const TensorVec &inputs) { return {{inputs[0]->getDims()}}; } diff --git a/src/operators/all_gather.cc b/src/operators/all_gather.cc index 127c3b8d..e4ffe9bf 100644 --- a/src/operators/all_gather.cc +++ b/src/operators/all_gather.cc @@ -10,8 +10,7 @@ AllGatherObj::AllGatherObj(GraphObj *graph, Tensor input, IT_ASSERT(checkValid(graph)); } -optional> -AllGatherObj::inferShape(const TensorVec &inputs) const { +optional> AllGatherObj::inferShape(const TensorVec &inputs) { Shape input_shape = inputs[0]->getDims(); vector output_shapes(getWorldSize(), input_shape); return output_shapes; diff --git a/src/operators/attention_kvcache.cc b/src/operators/attention_kvcache.cc index 9893f509..492a76f7 100644 --- a/src/operators/attention_kvcache.cc +++ b/src/operators/attention_kvcache.cc @@ -18,7 +18,7 @@ AttentionKVCacheObj::AttentionKVCacheObj(GraphObj *graph, Tensor input_k_cache, } optional> -AttentionKVCacheObj::inferShape(const TensorVec &inputs) const { +AttentionKVCacheObj::inferShape(const TensorVec &inputs) { IT_ASSERT(inputs.size() == 6); Shape dims = inputs[0]->getDims(); ShapeElem n = dims.at(dim); diff --git a/src/operators/batch_norm.cc b/src/operators/batch_norm.cc index ba68cbfd..bbd4c1bf 100644 --- a/src/operators/batch_norm.cc +++ b/src/operators/batch_norm.cc @@ -13,8 +13,7 @@ BatchNormObj::BatchNormObj(GraphObj *graph, Tensor input, Tensor output, IT_ASSERT(checkValid(graph)); } -optional> -BatchNormObj::inferShape(const TensorVec &inputs) const { +optional> BatchNormObj::inferShape(const TensorVec &inputs) { auto input = inputs[0]; auto mean = inputs[1]; auto var = inputs[2]; diff --git a/src/operators/concat.cc b/src/operators/concat.cc index 95535233..021aefef 100644 --- a/src/operators/concat.cc +++ b/src/operators/concat.cc @@ -9,9 +9,16 @@ ConcatObj::ConcatObj(GraphObj *graph, TensorVec inputs, Tensor output, int _dim) IT_ASSERT(checkValid(graph)); } -optional> ConcatObj::inferShape(const TensorVec &inputs) const { +optional> ConcatObj::inferShape(const TensorVec &inputs) { Shape dims = inputs[0]->getDims(); auto rank = inputs[0]->getRank(); + if (inputs.size() == 2) { + for (size_t i = 0; i < inputs.size(); ++i) { + if (inputs[i]->size() == 0) { + return {{inputs[1 - i]->getDims()}}; + } + } + } ShapeElem n = dims.at(dim); for (auto itr = inputs.begin() + 1; itr != inputs.end(); ++itr) { auto input = *itr; diff --git a/src/operators/conv.cc b/src/operators/conv.cc index 8c3eafb4..77fc9aef 100644 --- a/src/operators/conv.cc +++ b/src/operators/conv.cc @@ -82,14 +82,15 @@ ConvObj::ConvObj(GraphObj *graph, Tensor input, Tensor weight, Tensor output, IT_ASSERT(checkValid(graph)); } -optional> ConvObj::inferShape(const TensorVec &inputs) const { +optional> ConvObj::inferShape(const TensorVec &inputs) { const auto &input = inputs[0], &weight = inputs[1]; - auto n = input->getDims()[0]; - auto h = input->getDims()[2]; - auto w = input->getDims()[3]; - auto f = weight->getDims()[0]; - auto r = weight->getDims()[2]; - auto s = weight->getDims()[3]; + n = input->getDims()[0]; + c = input->getDims()[1]; + h = input->getDims()[2]; + w = input->getDims()[3]; + f = weight->getDims()[0]; + r = weight->getDims()[2]; + s = weight->getDims()[3]; int on = n, oc = f; int oh = 0, ow = 0; // For NCHW+FCRS layout, C of input is divisable by C of weight @@ -141,15 +142,15 @@ ConvTransposed2dObj::ConvTransposed2dObj(GraphObj *graph, Tensor input, } optional> -ConvTransposed2dObj::inferShape(const TensorVec &inputs) const { +ConvTransposed2dObj::inferShape(const TensorVec &inputs) { const Tensor &input = inputs[0], &weight = inputs[1]; - auto n = input->getDims()[0]; - auto f = input->getDims()[1]; - auto h = input->getDims()[2]; - auto w = input->getDims()[3]; - auto c = weight->getDims()[1]; - auto r = weight->getDims()[2]; - auto s = weight->getDims()[3]; + n = input->getDims()[0]; + f = input->getDims()[1]; + h = input->getDims()[2]; + w = input->getDims()[3]; + c = weight->getDims()[1]; + r = weight->getDims()[2]; + s = weight->getDims()[3]; IT_ASSERT(f == weight->getDims()[0]); int on = n, oc = c * group; @@ -219,14 +220,15 @@ ConvBackwardFilterObj::ConvBackwardFilterObj(GraphObj *graph, Tensor inputX, } optional> -ConvBackwardFilterObj::inferShape(const TensorVec &inputs) const { +ConvBackwardFilterObj::inferShape(const TensorVec &inputs) { const auto &inputX = inputs[0], &diffY = inputs[1]; - auto n = inputX->getDims()[0]; - auto h = inputX->getDims()[2]; - auto w = inputX->getDims()[3]; - auto f = diffY->getDims()[0]; - auto r = diffY->getDims()[2]; - auto s = diffY->getDims()[3]; + n = inputX->getDims()[0]; + c = inputX->getDims()[1]; + h = inputX->getDims()[2]; + w = inputX->getDims()[3]; + f = diffY->getDims()[0]; + r = diffY->getDims()[2]; + s = diffY->getDims()[3]; int on = n, oc = f; int oh = 0, ow = 0; // For NCHW+FCRS layout, C of input is divisable by C of weight @@ -280,17 +282,16 @@ ConvTransposed2dNHWCObj::ConvTransposed2dNHWCObj(GraphObj *graph, Tensor input, } optional> -ConvTransposed2dNHWCObj::inferShape(const TensorVec &inputs) const { +ConvTransposed2dNHWCObj::inferShape(const TensorVec &inputs) { const Tensor &input = inputs[0], &weight = inputs[1]; - auto n = input->getDims()[0]; - auto f = input->getDims()[3]; - auto h = input->getDims()[1]; - auto w = input->getDims()[2]; - auto c = weight->getDims()[3]; - auto r = weight->getDims()[1]; - auto s = weight->getDims()[2]; - if (f != weight->getDims()[0]) - return {}; + n = input->getDims()[0]; + f = input->getDims()[3]; + h = input->getDims()[1]; + w = input->getDims()[2]; + c = weight->getDims()[3]; + r = weight->getDims()[1]; + s = weight->getDims()[2]; + IT_ASSERT(f == weight->getDims()[0]); int on = n, oc = c * group; int oh = 0, ow = 0; diff --git a/src/operators/det.cc b/src/operators/det.cc index 473982cd..f5d16af5 100644 --- a/src/operators/det.cc +++ b/src/operators/det.cc @@ -6,7 +6,7 @@ DetObj::DetObj(GraphObj *graph, Tensor input, Tensor output, Mode mode) IT_ASSERT(checkValid(graph)); } -optional> DetObj::inferShape(const TensorVec &inputs) const { +optional> DetObj::inferShape(const TensorVec &inputs) { const auto A = inputs[0]; auto input = A->getDims(); int rank = A->getRank(); diff --git a/src/operators/dropout.cc b/src/operators/dropout.cc index 08eca92a..7dcb70db 100644 --- a/src/operators/dropout.cc +++ b/src/operators/dropout.cc @@ -10,7 +10,7 @@ DropoutObj::DropoutObj(GraphObj *graph, Tensor data, Tensor output, Tensor mask, IT_ASSERT(checkValid(graph)); } -optional> DropoutObj::inferShape(const TensorVec &inputs) const { +optional> DropoutObj::inferShape(const TensorVec &inputs) { auto shape = inputs[0]->getDims(); return {{shape, shape}}; } diff --git a/src/operators/element_wise.cc b/src/operators/element_wise.cc index d86ccccf..6445c0d5 100644 --- a/src/operators/element_wise.cc +++ b/src/operators/element_wise.cc @@ -8,8 +8,7 @@ ElementWiseObj::ElementWiseObj(OpType type, GraphObj *graph, Tensor input0, IT_ASSERT(checkValid(graph)); } -optional> -ElementWiseObj::inferShape(const TensorVec &inputs) const { +optional> ElementWiseObj::inferShape(const TensorVec &inputs) { const auto A = inputs[0], B = inputs[1]; auto res = infer_broadcast(A->getDims(), B->getDims()); return {{res}}; @@ -45,7 +44,7 @@ MSELossObj::MSELossObj(GraphObj *graph, Tensor input0, Tensor input1, IT_ASSERT(checkValid(graph)); } -optional> MSELossObj::inferShape(const TensorVec &inputs) const { +optional> MSELossObj::inferShape(const TensorVec &inputs) { const auto A = inputs[0], B = inputs[1]; IT_ASSERT(A->getRank() == B->getRank()); IT_ASSERT(A->getDims() == B->getDims()); diff --git a/src/operators/expand.cc b/src/operators/expand.cc index faebb34a..8ffcc75b 100644 --- a/src/operators/expand.cc +++ b/src/operators/expand.cc @@ -8,7 +8,7 @@ ExpandObj::ExpandObj(GraphObj *graph, Tensor input, Tensor output, Shape dims) IT_ASSERT(checkValid(graph)); } -optional> ExpandObj::inferShape(const TensorVec &inputs) const { +optional> ExpandObj::inferShape(const TensorVec &inputs) { auto shape_input = inputs[0]->getDims(); Shape ret = infer_broadcast(shape_input, dims); return {{ret}}; diff --git a/src/operators/extend.cc b/src/operators/extend.cc index e8587dbb..c3a678e3 100644 --- a/src/operators/extend.cc +++ b/src/operators/extend.cc @@ -11,7 +11,7 @@ ExtendObj::ExtendObj(GraphObj *graph, Tensor input, Tensor output, int dim, IT_ASSERT(checkValid(graph)); } -optional> ExtendObj::inferShape(const TensorVec &inputs) const { +optional> ExtendObj::inferShape(const TensorVec &inputs) { auto ret = inputs[0]->getDims(); ret[dim] = ret[dim] * (num + 1); return {{ret}}; diff --git a/src/operators/gather.cc b/src/operators/gather.cc index 0cddca3c..b0c8a77a 100644 --- a/src/operators/gather.cc +++ b/src/operators/gather.cc @@ -10,7 +10,7 @@ GatherObj::GatherObj(GraphObj *graph, Tensor input, Tensor indices, IT_ASSERT(checkValid(graph)); } -optional> GatherObj::inferShape(const TensorVec &inputs) const { +optional> GatherObj::inferShape(const TensorVec &inputs) { auto dims0 = inputs[0]->getDims(); auto dims1 = inputs[1]->getDims(); diff --git a/src/operators/gather_elements.cc b/src/operators/gather_elements.cc index a1e6bffe..2e224f3e 100644 --- a/src/operators/gather_elements.cc +++ b/src/operators/gather_elements.cc @@ -24,8 +24,7 @@ bool checkShape(Tensor input, Tensor indices, int axis) { return true; } -optional> -GatherElementsObj::inferShape(const TensorVec &inputs) const { +optional> GatherElementsObj::inferShape(const TensorVec &inputs) { IT_ASSERT(checkShape(inputs[0], inputs[1], axis)); auto indicesDims = inputs[1]->getDims(); // output has same shape as indices return {{indicesDims}}; diff --git a/src/operators/matmul.cc b/src/operators/matmul.cc index 00207e77..60cbb826 100644 --- a/src/operators/matmul.cc +++ b/src/operators/matmul.cc @@ -9,25 +9,6 @@ MatmulObj::MatmulObj(GraphObj *graph, Tensor A, Tensor B, Tensor C, bool transA, : OperatorObj(OpType::MatMul, bias ? TensorVec{A, B, bias} : TensorVec{A, B}, {C}), transA(transA), transB(transB), act(act), b(1) { - auto shape_a = A->getDims(); - auto shape_b = B->getDims(); - int rankA = A->getRank(); - int rankB = B->getRank(); - IT_ASSERT(rankA >= 2 && rankB >= 2); - Shape shape_a1(shape_a.begin(), shape_a.begin() + (rankA - 2)); - Shape shape_b1(shape_b.begin(), shape_b.begin() + (rankB - 2)); - auto ret = infer_broadcast(shape_a1, shape_b1); - if (ret.empty()) { - b = 1; - } else { - b = std::accumulate(ret.begin(), ret.end(), 1, std::multiplies()); - } - auto kA = *(transA ? shape_a.rbegin() + 1 : shape_a.rbegin()); - auto kB = *(transB ? shape_b.rbegin() : shape_b.rbegin() + 1); - IT_ASSERT(kA == kB); - m = *(transA ? shape_a.rbegin() : shape_a.rbegin() + 1); - n = *(transB ? shape_b.rbegin() + 1 : shape_b.rbegin()); - k = kA; IT_ASSERT(checkValid(graph)); } @@ -40,7 +21,7 @@ string MatmulObj::toString() const { return os.str(); } -optional> MatmulObj::inferShape(const TensorVec &inputs) const { +optional> MatmulObj::inferShape(const TensorVec &inputs) { auto A = inputs[0], B = inputs[1]; auto shapeA = A->getDims(); auto shapeB = B->getDims(); @@ -49,6 +30,17 @@ optional> MatmulObj::inferShape(const TensorVec &inputs) const { Shape shapeA1(shapeA.begin(), shapeA.begin() + (rankA - 2)); Shape shapeB1(shapeB.begin(), shapeB.begin() + (rankB - 2)); Shape ret = infer_broadcast(shapeA1, shapeB1); + if (ret.empty()) { + b = 1; + } else { + b = std::accumulate(ret.begin(), ret.end(), 1, std::multiplies()); + } + auto kA = *(transA ? shapeA.rbegin() + 1 : shapeA.rbegin()); + auto kB = *(transB ? shapeB.rbegin() : shapeB.rbegin() + 1); + IT_ASSERT(kA == kB); + m = *(transA ? shapeA.rbegin() : shapeA.rbegin() + 1); + n = *(transB ? shapeB.rbegin() + 1 : shapeB.rbegin()); + k = kA; ret.emplace_back(m); ret.emplace_back(n); return {{ret}}; diff --git a/src/operators/membound.cc b/src/operators/membound.cc index be757b36..2aa07ae2 100644 --- a/src/operators/membound.cc +++ b/src/operators/membound.cc @@ -60,7 +60,7 @@ string MemBoundObj::toString() const { return os.str(); } -optional> MemBoundObj::inferShape(const TensorVec &inputs) const { +optional> MemBoundObj::inferShape(const TensorVec &inputs) { // inputs have to match nnetInputs excatly if (inputs.size() != nnetInputs.size()) return {}; diff --git a/src/operators/pad.cc b/src/operators/pad.cc index b870e449..96b9811f 100644 --- a/src/operators/pad.cc +++ b/src/operators/pad.cc @@ -22,7 +22,7 @@ PadObj::PadObj(GraphObj *graph, Tensor input, Tensor output, IT_ASSERT(checkValid(graph)); } -optional> PadObj::inferShape(const TensorVec &inputs) const { +optional> PadObj::inferShape(const TensorVec &inputs) { auto dims = inputs[0]->getDims(); int rank = inputs[0]->getRank(); IT_ASSERT(rank * 2 == (int)pads.size()); diff --git a/src/operators/pooling.cc b/src/operators/pooling.cc index b1bb2e3d..836ac522 100644 --- a/src/operators/pooling.cc +++ b/src/operators/pooling.cc @@ -12,7 +12,7 @@ PoolingObj::PoolingObj(GraphObj *graph, OpType optype, Tensor input, IT_ASSERT(checkValid(graph)); } -optional> PoolingObj::inferShape(const TensorVec &inputs) const { +optional> PoolingObj::inferShape(const TensorVec &inputs) { const auto &input = inputs[0]; auto h = input->getDims()[input->getRank() - 2], w = input->getDims()[input->getRank() - 1]; diff --git a/src/operators/reduce_mean.cc b/src/operators/reduce_mean.cc index e3a5ec97..cf801c59 100644 --- a/src/operators/reduce_mean.cc +++ b/src/operators/reduce_mean.cc @@ -21,8 +21,7 @@ bool ReduceMeanObj::isReduced(int idx) const { return axes.find(idx) != axes.end(); } -optional> -ReduceMeanObj::inferShape(const TensorVec &inputs) const { +optional> ReduceMeanObj::inferShape(const TensorVec &inputs) { auto dims = inputs[0]->getDims(); auto rank = inputs[0]->getRank(); diff --git a/src/operators/reshape.cc b/src/operators/reshape.cc index df216601..2a65345e 100644 --- a/src/operators/reshape.cc +++ b/src/operators/reshape.cc @@ -1,5 +1,6 @@ #include "operators/reshape.h" #include "utils/operator_utils.h" +#include namespace infini { ReshapeObj::ReshapeObj(GraphObj *graph, Tensor input, Tensor output, Shape dims) @@ -7,14 +8,37 @@ ReshapeObj::ReshapeObj(GraphObj *graph, Tensor input, Tensor output, Shape dims) IT_ASSERT(checkValid(graph)); } -optional> ReshapeObj::inferShape(const TensorVec &inputs) const { - size_t size = 1; - for (size_t i = 0; i < dims.size(); ++i) { - size *= dims.at(i); +optional> ReshapeObj::inferShape(const TensorVec &inputs) { + int count = 0; + for (auto x : dims) { + if (x == -1) { + count++; + } + IT_ASSERT(x == -1 || x >= 0); } - IT_ASSERT(size == inputs[0]->size()); + IT_ASSERT(count == 0 || count == 1); + auto inputShape = inputs[0]->getDims(); + int size = inputs[0]->size(); + int index = -1; + outputShape = dims; + for (int i = 0; i < (int)dims.size(); ++i) { + if (dims[i] == 0) { + outputShape[i] = inputShape[i]; + } + if (dims[i] == -1) { + index = i; + } + } + if (index != -1) { + outputShape[index] = + size / (-std::accumulate(outputShape.begin(), outputShape.end(), 1, + [](auto acc, auto x) { return acc * x; })); + } + int outputSize = std::accumulate(outputShape.begin(), outputShape.end(), 1, + [](auto acc, auto x) { return acc * x; }); + IT_ASSERT(outputSize == size); - return {{dims}}; + return {{outputShape}}; } std::string ReshapeObj::toString() const { @@ -22,7 +46,7 @@ std::string ReshapeObj::toString() const { os << "Reshape[" << getGuid() << "]"; os << "("; os << vecToString(inputs[0]->getDims()) << ","; - os << "dims=" << vecToString(dims) << ","; + os << "outputShape=" << vecToString(outputShape) << ","; os << "input=" << inputs[0]->getGuid() << ","; os << "output=" << outputs[0]->getGuid() << ")"; return os.str(); @@ -30,12 +54,12 @@ std::string ReshapeObj::toString() const { vector ReshapeObj::getWorkloadVector() const { vector ret = inputs[0]->getDims(); - ret.insert(ret.end(), dims.begin(), dims.end()); + ret.insert(ret.end(), outputShape.begin(), outputShape.end()); ret.emplace(ret.begin(), type.underlying()); return ret; } vector ReshapeObj::getOpAttrVector() const { - vector ret = dims; + vector ret = outputShape; ret.emplace(ret.begin(), type.underlying()); return ret; } @@ -47,7 +71,7 @@ FlattenObj::FlattenObj(GraphObj *graph, Tensor input, Tensor output, int _axis) IT_ASSERT(checkValid(graph)); } -optional> FlattenObj::inferShape(const TensorVec &inputs) const { +optional> FlattenObj::inferShape(const TensorVec &inputs) { int sizeB = 1, sizeE = 1; auto dims = getInputs(0)->getDims(); int rank = getInputs(0)->getRank(); @@ -84,7 +108,7 @@ IdentityObj::IdentityObj(GraphObj *graph, Tensor input, Tensor output) IT_ASSERT(checkValid(graph)); } -optional> IdentityObj::inferShape(const TensorVec &inputs) const { +optional> IdentityObj::inferShape(const TensorVec &inputs) { return {{getInputs(0)->getDims()}}; } diff --git a/src/operators/resize.cc b/src/operators/resize.cc index 11933414..0f0b08fe 100644 --- a/src/operators/resize.cc +++ b/src/operators/resize.cc @@ -206,7 +206,7 @@ float ResizeObj::round_int(float x) const { } // output shape is related to sizes/scales value. -optional> ResizeObj::inferShape(const TensorVec &inputs) const { +optional> ResizeObj::inferShape(const TensorVec &inputs) { auto inDims = inputs[0]->getDims(); Shape ret = inDims; int rank = inputs[0]->getRank(); diff --git a/src/operators/slice.cc b/src/operators/slice.cc index 0db3b1a2..691a63b5 100644 --- a/src/operators/slice.cc +++ b/src/operators/slice.cc @@ -62,7 +62,7 @@ SliceObj::SliceObj(GraphObj *graph, Tensor input, Tensor output, IT_ASSERT(checkValid(graph)); } -optional> SliceObj::inferShape(const TensorVec &inputs) const { +optional> SliceObj::inferShape(const TensorVec &inputs) { Shape ans; ans.reserve(axes.size()); for (const auto &range : axes) { diff --git a/src/operators/split.cc b/src/operators/split.cc index be541326..95c7034f 100644 --- a/src/operators/split.cc +++ b/src/operators/split.cc @@ -35,7 +35,7 @@ SplitObj::SplitObj(GraphObj *graph, Tensor input, IT_ASSERT(checkValid(graph)); } -optional> SplitObj::inferShape(const TensorVec &inputs) const { +optional> SplitObj::inferShape(const TensorVec &inputs) { IT_ASSERT(num != -1 && ratio.size() != 0); auto inputDims = inputs[0]->getDims(); int totalSize = inputDims.at(dim); diff --git a/src/operators/transpose.cc b/src/operators/transpose.cc index f4c6a28d..9a05a785 100644 --- a/src/operators/transpose.cc +++ b/src/operators/transpose.cc @@ -16,8 +16,7 @@ TransposeObj::TransposeObj(GraphObj *graph, Tensor input, Tensor output, IT_ASSERT(checkValid(graph)); } -optional> -TransposeObj::inferShape(const TensorVec &inputs) const { +optional> TransposeObj::inferShape(const TensorVec &inputs) { const auto A = inputs[0]; auto input_dim = A->getDims(); auto output_dim = input_dim; @@ -66,8 +65,7 @@ DepthToSpaceObj::DepthToSpaceObj(GraphObj *graph, Tensor input, Tensor output, IT_ASSERT(checkValid(graph)); } -optional> -DepthToSpaceObj::inferShape(const TensorVec &inputs) const { +optional> DepthToSpaceObj::inferShape(const TensorVec &inputs) { const auto A = inputs[0]; auto inputDim = A->getDims(); IT_ASSERT(inputDim.size() == 4); diff --git a/src/operators/unary.cc b/src/operators/unary.cc index 7f98940a..79d2ab83 100644 --- a/src/operators/unary.cc +++ b/src/operators/unary.cc @@ -6,7 +6,7 @@ UnaryObj::UnaryObj(OpType type, GraphObj *graph, Tensor input, Tensor output) IT_ASSERT(checkValid(graph)); } -optional> UnaryObj::inferShape(const TensorVec &inputs) const { +optional> UnaryObj::inferShape(const TensorVec &inputs) { const auto A = inputs[0]; return {{A->getDims()}}; } @@ -37,7 +37,7 @@ ClipObj::ClipObj(GraphObj *graph, Tensor input, Tensor output, IT_ASSERT(checkValid(graph)); } -optional> ClipObj::inferShape(const TensorVec &inputs) const { +optional> ClipObj::inferShape(const TensorVec &inputs) { const auto A = inputs[0]; return {{A->getDims()}}; } @@ -68,7 +68,7 @@ HardtanhObj::HardtanhObj(GraphObj *graph, Tensor input, Tensor output, IT_ASSERT(checkValid(graph)); } -optional> HardtanhObj::inferShape(const TensorVec &inputs) const { +optional> HardtanhObj::inferShape(const TensorVec &inputs) { const auto A = inputs[0]; return {{A->getDims()}}; } @@ -97,7 +97,7 @@ FillObj::FillObj(GraphObj *graph, Tensor input, Tensor output, float value) IT_ASSERT(checkValid(graph)); } -optional> FillObj::inferShape(const TensorVec &inputs) const { +optional> FillObj::inferShape(const TensorVec &inputs) { const auto A = inputs[0]; return {{A->getDims()}}; } @@ -124,7 +124,7 @@ L2LossObj::L2LossObj(GraphObj *graph, Tensor input, Tensor output) IT_ASSERT(checkValid(graph)); } -optional> L2LossObj::inferShape(const TensorVec &inputs) const { +optional> L2LossObj::inferShape(const TensorVec &inputs) { Shape temp = {1}; return {{temp}}; } @@ -159,7 +159,7 @@ vector CastObj::inferDataType(const TensorVec &inputs) const { return vector(numOutputs(), output_dataType); } -optional> CastObj::inferShape(const TensorVec &inputs) const { +optional> CastObj::inferShape(const TensorVec &inputs) { const auto A = inputs[0]; return {{A->getDims()}}; } @@ -241,7 +241,7 @@ ShapeObj::ShapeObj(GraphObj *graph, Tensor input, Tensor output) IT_ASSERT(checkValid(graph)); } -optional> ShapeObj::inferShape(const TensorVec &inputs) const { +optional> ShapeObj::inferShape(const TensorVec &inputs) { return {{{static_cast(inputs[0]->getRank())}}}; } @@ -257,7 +257,7 @@ PReluObj::PReluObj(GraphObj *graph, Tensor input, Tensor alpha, Tensor output) IT_ASSERT(checkValid(graph)); } -optional> PReluObj::inferShape(const TensorVec &inputs) const { +optional> PReluObj::inferShape(const TensorVec &inputs) { const auto A = inputs[0]; return {{A->getDims()}}; } @@ -286,7 +286,7 @@ LogObj::LogObj(GraphObj *graph, Tensor input, Tensor output, LogType type) IT_ASSERT(checkValid(graph)); } -optional> LogObj::inferShape(const TensorVec &inputs) const { +optional> LogObj::inferShape(const TensorVec &inputs) { const auto A = inputs[0]; return {{A->getDims()}}; } diff --git a/src/operators/where.cc b/src/operators/where.cc index 290ca7c6..7eac50d7 100644 --- a/src/operators/where.cc +++ b/src/operators/where.cc @@ -10,7 +10,7 @@ WhereObj::WhereObj(GraphObj *graph, Tensor inputX, Tensor inputY, IT_ASSERT(checkValid(graph)); } -optional> WhereObj::inferShape(const TensorVec &inputs) const { +optional> WhereObj::inferShape(const TensorVec &inputs) { auto shapeX = inputs[0]->getDims(); auto shapeY = inputs[1]->getDims(); auto shapeCon = inputs[2]->getDims(); diff --git a/test/kernels/cuda/test_cuda_concat.cc b/test/kernels/cuda/test_cuda_concat.cc index 2c76f405..12e18a56 100644 --- a/test/kernels/cuda/test_cuda_concat.cc +++ b/test/kernels/cuda/test_cuda_concat.cc @@ -158,4 +158,33 @@ TEST(Concat, CudaHigh) { 12., 13., 14., 15., 16., 17., 1., 1., 1., 1., 1., 1., 18., 19., 20., 21., 22., 23., 1., 1., 1., 1., 1., 1.})); } + +TEST(ConcatToIdentity, Cuda) { + Runtime runtime = NativeCpuRuntimeObj::getInstance(); + Graph gCpu = make_ref(runtime); + + auto t1 = gCpu->addTensor({2, 2, 3, 1}, DataType::Float32); + auto t2 = gCpu->addTensor({0}, DataType::Float32); + gCpu->dataMalloc(); + t1->setData(IncrementalGenerator()); + t2->setData(OneGenerator()); + + auto cudaRuntime = make_ref(); + Graph gCuda = make_ref(cudaRuntime); + + auto t1Gpu = gCuda->cloneTensor(t1); + auto t2Gpu = gCuda->cloneTensor(t2); + + auto op = gCuda->addOp(TensorVec{t1Gpu, t2Gpu}, nullptr, 2); + gCuda->dataMalloc(); + t1Gpu->setData(IncrementalGenerator()); + t2Gpu->setData(OneGenerator()); + cudaRuntime->run(gCuda); + + // cudaPrintTensor(op->getOutput()); + // copy output from CUDA to CPU + auto oCpu = gCpu->cloneTensor(op->getOutput()); + EXPECT_TRUE( + oCpu->equalData(vector{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11})); +} } // namespace infini diff --git a/test/operators/test_concat.cc b/test/operators/test_concat.cc index 9a0fe74e..32b50aa2 100644 --- a/test/operators/test_concat.cc +++ b/test/operators/test_concat.cc @@ -14,4 +14,13 @@ TEST(Concat, ShapeInfer) { EXPECT_EQ(op->getOutput()->getDims(), (Shape{1, 3, 2, 9})); } +TEST(Concat, ShapeInfer1) { + Runtime runtime = NativeCpuRuntimeObj::getInstance(); + Graph g = make_ref(runtime); + auto t1 = g->addTensor({1, 3, 2, 4}, DataType::Float32); + auto t2 = g->addTensor({0}, DataType::Float32); + + auto op = g->addOp(TensorVec{t1, t2}, nullptr, 3); + EXPECT_EQ(op->getOutput()->getDims(), (Shape{1, 3, 2, 4})); +} } // namespace infini