support Dynamic tensor infer shape and fix memory pool (#176)

* feat: support dynamic tensor part1

* feat: support dynamic-tensor part2

* feat: support dynamic tensor part 3

* fix: fix some ..

* - add kvcache example

* feat: support concat to identity kernel

* add a simple memory pool for allocator

* fix: rebase to master

* fix bug after merging

* - remove outdated script

* fix: fix as review

---------

Co-authored-by: kilinchange <kilinchange@163.com>
Co-authored-by: Haojie Wang <haojie0429@gmail.com>
This commit is contained in:
zhangyunze 2023-11-23 13:11:50 +08:00 committed by GitHub
parent 965df4e294
commit 331f7ab2b8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
78 changed files with 605 additions and 229 deletions

View File

@ -53,6 +53,7 @@ class GraphObj : public Object {
const TensorVec &getTensors() const { return tensors; } const TensorVec &getTensors() const { return tensors; }
const OpVec &getOperators() const { return ops; } const OpVec &getOperators() const { return ops; }
OpVec getComputeOps() const; OpVec getComputeOps() const;
Tensor getTensor(int) const;
/** /**
* Sort the nodes in topological order. * Sort the nodes in topological order.
@ -64,7 +65,13 @@ class GraphObj : public Object {
void optimize(); void optimize();
void dataMalloc(bool useNaiveAllocator = false); void shape_infer();
void dataMalloc(bool useNaiveAllocator = false, size_t memPoolSize = 0);
Tensor cloneKV(Tensor &tensor);
void freeHeap();
/** /**
* @brief Add an operator and create its outputs. Output tensor arguments * @brief Add an operator and create its outputs. Output tensor arguments

View File

@ -81,6 +81,7 @@ class GraphHandlerObj {
Tensor cast(Tensor input, Tensor output, int to); Tensor cast(Tensor input, Tensor output, int to);
Tensor expand(Tensor input, Tensor output, Shape dims); Tensor expand(Tensor input, Tensor output, Shape dims);
Tensor where(Tensor inputX, Tensor inputY, Tensor condition, Tensor output); Tensor where(Tensor inputX, Tensor inputY, Tensor condition, Tensor output);
std::vector<int> getDims(Tensor x) { return x->getDims(); }
Tensor allReduceSum(Tensor input, Tensor output); Tensor allReduceSum(Tensor input, Tensor output);
Tensor allReduceProd(Tensor input, Tensor output); Tensor allReduceProd(Tensor input, Tensor output);
@ -98,9 +99,19 @@ class GraphHandlerObj {
inline void optimize() { g->optimize(); } inline void optimize() { g->optimize(); }
inline void shape_infer() { g->shape_infer(); }
void change_shape(const vector<int> &shape, int tensorId);
//------ runtime //------ runtime
inline void data_malloc() { g->dataMalloc(); } inline void data_malloc(bool useNaiveAllocator = false,
size_t memPoolSize = 0) {
g->dataMalloc(useNaiveAllocator, memPoolSize);
}
inline Tensor clone_KV(Tensor &tensor) { return g->cloneKV(tensor); }
inline void free_heap() { g->freeHeap(); }
inline void tune() { g->getRuntime()->run(g, true); } inline void tune() { g->getRuntime()->run(g, true); }

View File

@ -26,14 +26,23 @@ class LazyAllocator {
size_t weightPeak = 0; size_t weightPeak = 0;
size_t heapPeak = 0;
size_t alignment; size_t alignment;
bool hasMemPool = false;
size_t memPoolSize = 0;
// pointer to the memory actually allocated // pointer to the memory actually allocated
void *ptr = nullptr; void *ptr = nullptr;
// pointer to the weight memory space // pointer to the weight memory space
void *weightPtr = nullptr; void *weightPtr = nullptr;
// memory pool ptr
void *memPoolPtr = nullptr;
// // a cache designed for a batch size that has already occurred // // a cache designed for a batch size that has already occurred
// std::unordered_map<size_t, std::unordered_map<TensorObj *, size_t>> // std::unordered_map<size_t, std::unordered_map<TensorObj *, size_t>>
// batchsizeToTensorOffset; // batchsizeToTensorOffset;
@ -68,6 +77,10 @@ class LazyAllocator {
void init(); void init();
void setMemPool(size_t memPoolSize);
bool getMemPoolStatus();
// function: simulate memory allocation // function: simulate memory allocation
// arguments // arguments
// size: size of memory block to be allocated // size: size of memory block to be allocated
@ -76,6 +89,10 @@ class LazyAllocator {
size_t allocWeight(size_t size); size_t allocWeight(size_t size);
size_t heapAlloc(size_t size);
void freeHeap();
// function: simulate memory free // function: simulate memory free
// arguments: // arguments:
// addr: head address offset of memory block to be free // addr: head address offset of memory block to be free
@ -92,6 +109,8 @@ class LazyAllocator {
void *getWeightPtr(); void *getWeightPtr();
void *getHeapPtr();
void info(); void info();
private: private:

View File

@ -55,8 +55,7 @@ class OperatorObj : public Object {
public: public:
OperatorObj(OpType opType, TensorVec inputs, TensorVec outputs); OperatorObj(OpType opType, TensorVec inputs, TensorVec outputs);
virtual optional<vector<Shape>> virtual optional<vector<Shape>> inferShape(const TensorVec &inputs) = 0;
inferShape(const TensorVec &inputs) const = 0;
virtual vector<DataType> inferDataType(const TensorVec &inputs) const; virtual vector<DataType> inferDataType(const TensorVec &inputs) const;
/** /**
* @brief Constructs outputs (if required) and check whether the operator is * @brief Constructs outputs (if required) and check whether the operator is
@ -105,7 +104,7 @@ class OperatorObj : public Object {
const TensorVec &newOutputs) const = 0; const TensorVec &newOutputs) const = 0;
protected: protected:
optional<vector<Shape>> inferShape() const; optional<vector<Shape>> inferShape();
vector<DataType> inferDataType() const; vector<DataType> inferDataType() const;
private: private:

View File

@ -31,6 +31,7 @@ class TensorObj : public TensorBaseObj {
size_t getBytes() const { return _size * dtype.getSize(); } size_t getBytes() const { return _size * dtype.getSize(); }
Shape getDims() const { return shape; } Shape getDims() const { return shape; }
void setShape(Shape shape_);
size_t getRank() const { return shape.size(); } size_t getRank() const { return shape.size(); }
Shape getStride() const; Shape getStride() const;
size_t getOffset(const vector<int> &ds) const; size_t getOffset(const vector<int> &ds) const;

View File

@ -1,8 +1,13 @@
#pragma once #pragma once
namespace infini { namespace infini {
void div_kernel(float *a, float *b, float *c, int a0, int a1, int a2, int a3, void div_kernel(void *a, void *b, void *c, int a0, int a1, int a2, int a3,
int b0, int b1, int b2, int b3, int c0, int c1, int c2, int c3); int b0, int b1, int b2, int b3, int c0, int c1, int c2, int c3);
void pow_kernel(float *a, float *b, float *c, int a0, int a1, int a2, int a3, void add_kernel(void *a, void *b, void *c, int a0, int a1, int a2, int a3,
int b0, int b1, int b2, int b3, int c0, int c1, int c2, int c3); int b0, int b1, int b2, int b3, int c0, int c1, int c2, int c3);
void pow_kernel(void *a, void *b, void *c, int a0, int a1, int a2, int a3,
int b0, int b1, int b2, int b3, int c0, int c1, int c2, int c3);
void less_kernel(void *a, void *b, void *c, int a0, int a1, int a2, int a3,
int b0, int b1, int b2, int b3, int c0, int c1, int c2,
int c3);
}; // namespace infini }; // namespace infini

View File

@ -10,10 +10,11 @@ typedef struct {
int wholeNDim[MAX_DIM]; // dim size after padding or before slicing int wholeNDim[MAX_DIM]; // dim size after padding or before slicing
int partNDim[MAX_DIM]; // dim size before padding or after slicing int partNDim[MAX_DIM]; // dim size before padding or after slicing
int partStride[MAX_DIM]; // stride before padding or after slicing int partStride[MAX_DIM]; // stride before padding or after slicing
int DType;
} TransMetaData; } TransMetaData;
namespace infini { namespace infini {
void pad_slice_kernel(float *partData, float *wholeData, void pad_slice_kernel(void *partData, void *wholeData,
const TransMetaData &metadata, int nDims, int num, const TransMetaData &metadata, int nDims, int num,
bool isPad); bool isPad);
} // namespace infini } // namespace infini

View File

@ -35,7 +35,7 @@ class G2BMMObj : public OperatorObj {
OP_CLONE(G2BMMObj); OP_CLONE(G2BMMObj);
std::string toString() const override; std::string toString() const override;
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override; optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
int numInputs() const override { return 2; } int numInputs() const override { return 2; }
int numOutputs() const override { return 1; } int numOutputs() const override { return 1; }

View File

@ -33,7 +33,7 @@ class GBMMObj : public OperatorObj {
OP_CLONE(GBMMObj); OP_CLONE(GBMMObj);
std::string toString() const override; std::string toString() const override;
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override; optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
int numInputs() const override { return 2; } int numInputs() const override { return 2; }
int numOutputs() const override { return 1; } int numOutputs() const override { return 1; }

View File

@ -7,7 +7,7 @@ class ActivationBackwardObj : public OperatorObj {
ActivationBackwardObj(OpType type, GraphObj *graph, Tensor y, Tensor diff_y, ActivationBackwardObj(OpType type, GraphObj *graph, Tensor y, Tensor diff_y,
Tensor x, Tensor diff_x); Tensor x, Tensor diff_x);
OP_CLONE(ActivationBackwardObj); OP_CLONE(ActivationBackwardObj);
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override; optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
std::string toString() const override; std::string toString() const override;
int numInputs() const override { return 3; } int numInputs() const override { return 3; }

View File

@ -27,7 +27,7 @@ class AllGatherObj : public OperatorObj {
int numInputs() const override { return 1; } int numInputs() const override { return 1; }
int numOutputs() const override { return world_size; } int numOutputs() const override { return world_size; }
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override; optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
std::string toString() const override; std::string toString() const override;

View File

@ -33,7 +33,7 @@ class AllReduceBaseObj : public OperatorObj {
int numInputs() const override { return 1; } int numInputs() const override { return 1; }
int numOutputs() const override { return 1; } int numOutputs() const override { return 1; }
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override { optional<vector<Shape>> inferShape(const TensorVec &inputs) override {
return {{inputs[0]->getDims()}}; return {{inputs[0]->getDims()}};
}; };

View File

@ -29,7 +29,7 @@ class AttentionKVCacheObj : public OperatorObj {
Tensor output_matmul); Tensor output_matmul);
OP_CLONE(AttentionKVCacheObj); OP_CLONE(AttentionKVCacheObj);
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override; optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
std::string toString() const override; std::string toString() const override;
int numInputs() const override { return 6; } int numInputs() const override { return 6; }

View File

@ -34,7 +34,7 @@ class BatchNormObj : public OperatorObj {
Tensor var, Tensor scale, Tensor bias, float momentum = 0.9, Tensor var, Tensor scale, Tensor bias, float momentum = 0.9,
float eps = 1e-5, bool trainingMode = false); float eps = 1e-5, bool trainingMode = false);
OP_CLONE(BatchNormObj); OP_CLONE(BatchNormObj);
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override; optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
std::string toString() const override; std::string toString() const override;
// output size will be 3 when training // output size will be 3 when training

View File

@ -26,7 +26,7 @@ class BroadcastObj : public OperatorObj {
int numInputs() const override { return 1; } int numInputs() const override { return 1; }
int numOutputs() const override { return 1; } int numOutputs() const override { return 1; }
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override { optional<vector<Shape>> inferShape(const TensorVec &inputs) override {
return {{inputs[0]->getDims()}}; return {{inputs[0]->getDims()}};
}; };

View File

@ -22,7 +22,7 @@ class ConcatObj : public OperatorObj {
ConcatObj(GraphObj *graph, TensorVec inputs, Tensor output, int dim); ConcatObj(GraphObj *graph, TensorVec inputs, Tensor output, int dim);
OP_CLONE(ConcatObj); OP_CLONE(ConcatObj);
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override; optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
std::string toString() const override; std::string toString() const override;
int numInputs() const override { return inputs.size(); } int numInputs() const override { return inputs.size(); }

View File

@ -142,7 +142,7 @@ class ConvObj : public ConvBaseObj {
ActType act = ActType::None); ActType act = ActType::None);
OP_CLONE(ConvObj); OP_CLONE(ConvObj);
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override; optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
int getNumGroups() const override { return c / getChannelPerGroup(); } int getNumGroups() const override { return c / getChannelPerGroup(); }
private: private:
@ -164,7 +164,7 @@ class ConvBackwardFilterObj : public ConvBaseObj {
int sh = 1, int sw = 1, int dh = 1, int dw = 1, int sh = 1, int sw = 1, int dh = 1, int dw = 1,
Tensor bias = nullptr, ActType act = ActType::None); Tensor bias = nullptr, ActType act = ActType::None);
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override; optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
ActType getAct() const { return act; } ActType getAct() const { return act; }
int getNumGroups() const override { return c / getChannelPerGroup(); } int getNumGroups() const override { return c / getChannelPerGroup(); }
@ -191,7 +191,7 @@ class ConvTransposed2dObj : public ConvBaseObj {
Tensor bias = nullptr, ActType act = ActType::None); Tensor bias = nullptr, ActType act = ActType::None);
OP_CLONE(ConvTransposed2dObj); OP_CLONE(ConvTransposed2dObj);
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override; optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
int getNumGroups() const override { return group; } int getNumGroups() const override { return group; }
std::pair<int, int> getOutputPadding() const { return {oph, opw}; } std::pair<int, int> getOutputPadding() const { return {oph, opw}; }
@ -218,7 +218,7 @@ class ConvTransposed2dNHWCObj : public ConvBaseObj {
Tensor bias = nullptr, ActType act = ActType::None); Tensor bias = nullptr, ActType act = ActType::None);
OP_CLONE(ConvTransposed2dNHWCObj); OP_CLONE(ConvTransposed2dNHWCObj);
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override; optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
int getNumGroups() const override { return group; } int getNumGroups() const override { return group; }
private: private:

View File

@ -7,7 +7,7 @@ class DetObj : public OperatorObj {
enum Mode { NormalDet = 0, LogDet }; enum Mode { NormalDet = 0, LogDet };
DetObj(GraphObj *graph, Tensor input, Tensor output, Mode mode); DetObj(GraphObj *graph, Tensor input, Tensor output, Mode mode);
OP_CLONE(DetObj); OP_CLONE(DetObj);
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override; optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
std::string toString() const override; std::string toString() const override;
int numInputs() const override { return 1; } int numInputs() const override { return 1; }

View File

@ -37,7 +37,7 @@ class DropoutObj : public OperatorObj {
DropoutObj(GraphObj *graph, Tensor data, Tensor output, Tensor mask, DropoutObj(GraphObj *graph, Tensor data, Tensor output, Tensor mask,
float ratio, bool training_mode); float ratio, bool training_mode);
OP_CLONE(DropoutObj); OP_CLONE(DropoutObj);
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override; optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
std::string toString() const override; std::string toString() const override;
int numInputs() const override { return 1; } int numInputs() const override { return 1; }

View File

@ -21,7 +21,7 @@ class ElementWiseObj : public OperatorObj {
*/ */
ElementWiseObj(OpType type, GraphObj *graph, Tensor input0, Tensor input1, ElementWiseObj(OpType type, GraphObj *graph, Tensor input0, Tensor input1,
Tensor output); Tensor output);
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override; optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
std::string toString() const override; std::string toString() const override;
int numInputs() const override { return 2; } int numInputs() const override { return 2; }
@ -38,7 +38,7 @@ class MSELossObj : public OperatorObj {
MSELossObj(GraphObj *graph, Tensor input0, Tensor input1, MSELossObj(GraphObj *graph, Tensor input0, Tensor input1,
Reduction reduction, Tensor output); Reduction reduction, Tensor output);
OP_CLONE(MSELossObj); OP_CLONE(MSELossObj);
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override; optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
Reduction getReduction() const { return reductionMode; } Reduction getReduction() const { return reductionMode; }
std::string toString() const override; std::string toString() const override;

View File

@ -21,7 +21,7 @@ class ExpandObj : public OperatorObj {
*/ */
ExpandObj(GraphObj *graph, Tensor input, Tensor output, Shape dims); ExpandObj(GraphObj *graph, Tensor input, Tensor output, Shape dims);
OP_CLONE(ExpandObj); OP_CLONE(ExpandObj);
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override; optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
std::string toString() const override; std::string toString() const override;
int numInputs() const override { return 1; } int numInputs() const override { return 1; }

View File

@ -23,7 +23,7 @@ class ExtendObj : public OperatorObj {
ExtendObj(GraphObj *graph, Tensor input, Tensor output, int dim, ExtendObj(GraphObj *graph, Tensor input, Tensor output, int dim,
int num = 1); int num = 1);
OP_CLONE(ExtendObj); OP_CLONE(ExtendObj);
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override; optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
std::string toString() const override; std::string toString() const override;
int numInputs() const override { return 1; } int numInputs() const override { return 1; }

View File

@ -39,7 +39,7 @@ class GatherObj : public GatherBaseObj {
int axis); int axis);
OP_CLONE(GatherObj); OP_CLONE(GatherObj);
std::string toString() const override; std::string toString() const override;
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override; optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
vector<DataType> inferDataType(const TensorVec &inputs) const override; vector<DataType> inferDataType(const TensorVec &inputs) const override;
private: private:
@ -69,7 +69,7 @@ class GatherElementsObj : public GatherBaseObj {
Tensor output, int axis); Tensor output, int axis);
OP_CLONE(GatherElementsObj); OP_CLONE(GatherElementsObj);
std::string toString() const override; std::string toString() const override;
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override; optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
vector<DataType> inferDataType(const TensorVec &inputs) const override; vector<DataType> inferDataType(const TensorVec &inputs) const override;
private: private:

View File

@ -45,7 +45,7 @@ class MatmulObj : public OperatorObj {
OP_CLONE(MatmulObj); OP_CLONE(MatmulObj);
std::string toString() const override; std::string toString() const override;
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override; optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
int numInputs() const override { return inputs.size(); } int numInputs() const override { return inputs.size(); }
int numOutputs() const override { return 1; } int numOutputs() const override { return 1; }

View File

@ -21,7 +21,7 @@ class MemBoundObj : public OperatorObj {
OP_CLONE(MemBoundObj); OP_CLONE(MemBoundObj);
std::string toString() const override; std::string toString() const override;
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override; optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
int numInputs() const override { return inputs.size(); } int numInputs() const override { return inputs.size(); }
int numOutputs() const override { return outputs.size(); } int numOutputs() const override { return outputs.size(); }

View File

@ -27,7 +27,7 @@ class PadObj : public OperatorObj {
const vector<int> &pads, const optional<vector<int>> &axes); const vector<int> &pads, const optional<vector<int>> &axes);
OP_CLONE(PadObj); OP_CLONE(PadObj);
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override; optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
std::string toString() const override; std::string toString() const override;
int numInputs() const override { return 1; } int numInputs() const override { return 1; }
int numOutputs() const override { return 1; } int numOutputs() const override { return 1; }

View File

@ -41,7 +41,7 @@ class PoolingObj : public OperatorObj {
int ceilMode); int ceilMode);
OP_CLONE(PoolingObj); OP_CLONE(PoolingObj);
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override; optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
std::string toString() const override; std::string toString() const override;
int numInputs() const override { return 1; } int numInputs() const override { return 1; }
int numOutputs() const override { return 1; } int numOutputs() const override { return 1; }

View File

@ -23,7 +23,7 @@ class ReduceMeanObj : public OperatorObj {
ReduceMeanObj(GraphObj *graph, Tensor input, Tensor output, ReduceMeanObj(GraphObj *graph, Tensor input, Tensor output,
const optional<vector<int>> &axes, bool keepDims = true); const optional<vector<int>> &axes, bool keepDims = true);
OP_CLONE(ReduceMeanObj); OP_CLONE(ReduceMeanObj);
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override; optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
std::string toString() const override; std::string toString() const override;
int numInputs() const override { return 1; } int numInputs() const override { return 1; }

View File

@ -9,6 +9,7 @@ namespace infini {
*/ */
class ReshapeObj : public OperatorObj { class ReshapeObj : public OperatorObj {
Shape dims; Shape dims;
Shape outputShape;
public: public:
/** /**
@ -17,18 +18,20 @@ class ReshapeObj : public OperatorObj {
* @param graph The computation graph that this operator belongs to. * @param graph The computation graph that this operator belongs to.
* @param input The input tensor. * @param input The input tensor.
* @param output The output tensor. * @param output The output tensor.
* @param dims The shape of the output tensor. * @param dims The shape to infer the output shape.
* @param outputShape The real shape of output tensor.
*/ */
ReshapeObj(GraphObj *graph, Tensor input, Tensor output, Shape dims); ReshapeObj(GraphObj *graph, Tensor input, Tensor output, Shape dims);
OP_CLONE(ReshapeObj); OP_CLONE(ReshapeObj);
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override; optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
std::string toString() const override; std::string toString() const override;
int numInputs() const override { return 1; } int numInputs() const override { return 1; }
int numOutputs() const override { return 1; } int numOutputs() const override { return 1; }
inline Shape getShape() const { return dims; } inline Shape getShape() const { return outputShape; }
inline Shape getDims() const { return dims; }
private: private:
vector<int> getWorkloadVector() const override; vector<int> getWorkloadVector() const override;
@ -55,7 +58,7 @@ class FlattenObj : public OperatorObj {
FlattenObj(GraphObj *graph, Tensor input, Tensor output, int axis); FlattenObj(GraphObj *graph, Tensor input, Tensor output, int axis);
OP_CLONE(FlattenObj); OP_CLONE(FlattenObj);
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override; optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
std::string toString() const override; std::string toString() const override;
int numInputs() const override { return 1; } int numInputs() const override { return 1; }
@ -85,7 +88,7 @@ class IdentityObj : public OperatorObj {
IdentityObj(GraphObj *graph, Tensor input, Tensor output); IdentityObj(GraphObj *graph, Tensor input, Tensor output);
OP_CLONE(IdentityObj); OP_CLONE(IdentityObj);
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override; optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
std::string toString() const override; std::string toString() const override;
int numInputs() const override { return 1; } int numInputs() const override { return 1; }

View File

@ -60,7 +60,7 @@ class ResizeObj : public OperatorObj {
// Operator clone(TensorVec inputs, TensorVec outputs) override; // Operator clone(TensorVec inputs, TensorVec outputs) override;
vector<DataType> inferDataType(const TensorVec &inputs) const override; vector<DataType> inferDataType(const TensorVec &inputs) const override;
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override; optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
std::string toString() const override; std::string toString() const override;
int numInputs() const override { return inputs.size(); } int numInputs() const override { return inputs.size(); }
int numOutputs() const override { return 1; } int numOutputs() const override { return 1; }

View File

@ -32,7 +32,7 @@ class SliceObj : public OperatorObj {
const optional<vector<int>> &steps); const optional<vector<int>> &steps);
OP_CLONE(SliceObj); OP_CLONE(SliceObj);
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override; optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
std::string toString() const override; std::string toString() const override;
inline int numInputs() const override { return 1; } inline int numInputs() const override { return 1; }
inline int numOutputs() const override { return 1; } inline int numOutputs() const override { return 1; }

View File

@ -10,7 +10,7 @@ class SoftmaxObj : public OperatorObj {
OP_CLONE(SoftmaxObj); OP_CLONE(SoftmaxObj);
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override { optional<vector<Shape>> inferShape(const TensorVec &inputs) override {
return {{inputs[0]->getDims()}}; return {{inputs[0]->getDims()}};
}; };

View File

@ -37,7 +37,7 @@ class SplitObj : public OperatorObj {
int dim, const vector<int> &ratio); int dim, const vector<int> &ratio);
OP_CLONE(SplitObj); OP_CLONE(SplitObj);
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override; optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
std::string toString() const override; std::string toString() const override;
int numInputs() const override { return 1; } int numInputs() const override { return 1; }

View File

@ -7,7 +7,7 @@ class TransposeObj : public OperatorObj {
TransposeObj(GraphObj *graph, Tensor input, Tensor output, TransposeObj(GraphObj *graph, Tensor input, Tensor output,
vector<int> permute); vector<int> permute);
OP_CLONE(TransposeObj); OP_CLONE(TransposeObj);
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override; optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
std::string toString() const override; std::string toString() const override;
int numInputs() const override { return 1; } int numInputs() const override { return 1; }
@ -25,7 +25,7 @@ class DepthToSpaceObj : public OperatorObj {
DepthToSpaceObj(GraphObj *graph, Tensor input, Tensor output, int blocksize, DepthToSpaceObj(GraphObj *graph, Tensor input, Tensor output, int blocksize,
std::string mode); std::string mode);
OP_CLONE(DepthToSpaceObj); OP_CLONE(DepthToSpaceObj);
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override; optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
std::string toString() const override; std::string toString() const override;
int numInputs() const override { return 1; } int numInputs() const override { return 1; }

View File

@ -17,7 +17,7 @@ class UnaryObj : public OperatorObj {
* @param output The output tensor. * @param output The output tensor.
*/ */
UnaryObj(OpType type, GraphObj *graph, Tensor input, Tensor output); UnaryObj(OpType type, GraphObj *graph, Tensor input, Tensor output);
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override; optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
std::string toString() const override; std::string toString() const override;
int numInputs() const override { return 1; } int numInputs() const override { return 1; }
@ -33,7 +33,7 @@ class ClipObj : public OperatorObj {
ClipObj(GraphObj *graph, Tensor input, Tensor output, ClipObj(GraphObj *graph, Tensor input, Tensor output,
std::optional<float> min, std::optional<float> max); std::optional<float> min, std::optional<float> max);
OP_CLONE(ClipObj); OP_CLONE(ClipObj);
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override; optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
std::string toString() const override; std::string toString() const override;
std::optional<float> getMin() const { return minValue; }; std::optional<float> getMin() const { return minValue; };
@ -52,7 +52,7 @@ class HardtanhObj : public OperatorObj {
HardtanhObj(GraphObj *graph, Tensor input, Tensor output, float min, HardtanhObj(GraphObj *graph, Tensor input, Tensor output, float min,
float max); float max);
OP_CLONE(HardtanhObj); OP_CLONE(HardtanhObj);
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override; optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
std::string toString() const override; std::string toString() const override;
float getMin() const { return minValue; }; float getMin() const { return minValue; };
@ -70,7 +70,7 @@ class FlipObj : public OperatorObj {
public: public:
FlipObj(GraphObj *graph, Tensor input, Tensor output, vector<int> axis); FlipObj(GraphObj *graph, Tensor input, Tensor output, vector<int> axis);
OP_CLONE(FlipObj); OP_CLONE(FlipObj);
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override; optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
std::string toString() const override; std::string toString() const override;
vector<int> getAxis() const { return axisValue; }; vector<int> getAxis() const { return axisValue; };
@ -87,7 +87,7 @@ class FillObj : public OperatorObj {
public: public:
FillObj(GraphObj *graph, Tensor input, Tensor output, float value); FillObj(GraphObj *graph, Tensor input, Tensor output, float value);
OP_CLONE(FillObj); OP_CLONE(FillObj);
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override; optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
std::string toString() const override; std::string toString() const override;
float getValue() const { return setValue; }; float getValue() const { return setValue; };
@ -104,7 +104,7 @@ class L2LossObj : public OperatorObj {
public: public:
L2LossObj(GraphObj *graph, Tensor input, Tensor output); L2LossObj(GraphObj *graph, Tensor input, Tensor output);
OP_CLONE(L2LossObj); OP_CLONE(L2LossObj);
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override; optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
std::string toString() const override; std::string toString() const override;
int numInputs() const override { return 1; } int numInputs() const override { return 1; }
@ -120,7 +120,7 @@ class TransformObj : public OperatorObj {
TransformObj(GraphObj *graph, Tensor input, Tensor output, float alpha, TransformObj(GraphObj *graph, Tensor input, Tensor output, float alpha,
float beta); float beta);
OP_CLONE(TransformObj); OP_CLONE(TransformObj);
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override; optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
std::string toString() const override; std::string toString() const override;
float getAlpha() const { return alphaValue; } float getAlpha() const { return alphaValue; }
@ -165,7 +165,7 @@ class CastObj : public OperatorObj {
public: public:
CastObj(GraphObj *graph, Tensor input, Tensor output, CastType type); CastObj(GraphObj *graph, Tensor input, Tensor output, CastType type);
OP_CLONE(CastObj); OP_CLONE(CastObj);
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override; optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
vector<DataType> inferDataType(const TensorVec &inputs) const override; vector<DataType> inferDataType(const TensorVec &inputs) const override;
std::string toString() const override; std::string toString() const override;
@ -185,7 +185,7 @@ class CumsumObj : public OperatorObj {
CumsumObj(GraphObj *graph, Tensor input, Tensor output, int axis, CumsumObj(GraphObj *graph, Tensor input, Tensor output, int axis,
bool exclusive, bool reverse); bool exclusive, bool reverse);
OP_CLONE(CumsumObj); OP_CLONE(CumsumObj);
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override; optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
std::string toString() const override; std::string toString() const override;
int getAxis() const { return axisValue; } int getAxis() const { return axisValue; }
@ -205,7 +205,7 @@ class ShapeObj : public OperatorObj {
public: public:
ShapeObj(GraphObj *graph, Tensor input, Tensor output); ShapeObj(GraphObj *graph, Tensor input, Tensor output);
OP_CLONE(ShapeObj); OP_CLONE(ShapeObj);
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override; optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
std::string toString() const override; std::string toString() const override;
int numInputs() const override { return 1; } int numInputs() const override { return 1; }
@ -216,7 +216,7 @@ class PReluObj : public OperatorObj {
public: public:
PReluObj(GraphObj *graph, Tensor input, Tensor alpha, Tensor output); PReluObj(GraphObj *graph, Tensor input, Tensor alpha, Tensor output);
OP_CLONE(PReluObj); OP_CLONE(PReluObj);
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override; optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
std::string toString() const override; std::string toString() const override;
int numInputs() const override { return 2; } int numInputs() const override { return 2; }
@ -236,7 +236,7 @@ class LogObj : public OperatorObj {
}; };
LogObj(GraphObj *graph, Tensor input, Tensor output, LogType type); LogObj(GraphObj *graph, Tensor input, Tensor output, LogType type);
OP_CLONE(LogObj); OP_CLONE(LogObj);
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override; optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
std::string toString() const override; std::string toString() const override;
LogType getType() const { return logType; } LogType getType() const { return logType; }

View File

@ -22,7 +22,7 @@ class WhereObj : public OperatorObj {
Tensor output); Tensor output);
OP_CLONE(WhereObj); OP_CLONE(WhereObj);
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override; optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
std::string toString() const override; std::string toString() const override;
int numInputs() const override { return inputs.size(); } int numInputs() const override { return inputs.size(); }

View File

@ -510,19 +510,11 @@ class OnnxStub:
mode, mode,
) )
elif node.op_type == "Reshape": elif node.op_type == "Reshape":
dims = _search_shape(model, node.input[0]) shape = _parse_data(data[node.input[1]])
size = reduce(lambda acc, x: acc * x, dims)
input_shape = _parse_data(data[node.input[1]])
for i, x in enumerate(input_shape):
if x == 0:
input_shape[i] = dims[i]
temp = reduce(lambda acc, x: acc * x, input_shape, 1)
if temp < 0:
input_shape[input_shape.index(-1)] = size // -temp
tensors[node.output[0]] = self.handler.reshape( tensors[node.output[0]] = self.handler.reshape(
tensors[node.input[0]], tensors[node.input[0]],
tensors.get(node.output[0]), tensors.get(node.output[0]),
input_shape, shape,
) )
elif node.op_type == "Squeeze": elif node.op_type == "Squeeze":
input_shape = _search_shape(model, node.input[0]) input_shape = _search_shape(model, node.input[0])
@ -1112,6 +1104,26 @@ class OnnxStub:
def optimize(self) -> None: def optimize(self) -> None:
self.handler.optimize() self.handler.optimize()
def clone_KV(self, tensor: backend.Tensor) -> backend.Tensor:
return self.handler.clone_KV(tensor)
def free_heap(self) -> None:
self.handler.free_heap()
def set_input(self, inputShapes: List[int]) -> None:
for newInput, oldInput in zip(inputShapes, self.inputs):
oldTensor = self.inputs[oldInput]
self.handler.change_shape(newInput, oldTensor.fuid())
self.handler.shape_infer()
self.handler.data_malloc()
def getShape(self, name: str) -> List[int]:
if name in self.inputs:
ans = self.handler.getDims(self.inputs[name])
else:
ans = self.handler.getDims(self.outputs[name])
return ans
def tune(self) -> None: def tune(self) -> None:
self.handler.tune() self.handler.tune()

View File

@ -209,6 +209,7 @@ class TestStringMethods(unittest.TestCase):
make_and_import_model(make_graph([relu], "relu", [x], [y])) make_and_import_model(make_graph([relu], "relu", [x], [y]))
"""Gelu operator is not supported by onnx 14.1 currently.""" """Gelu operator is not supported by onnx 14.1 currently."""
def test_gelu(self): def test_gelu(self):
pass pass
# x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 3, 5, 7]) # x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 3, 5, 7])
@ -500,5 +501,22 @@ class TestStringMethods(unittest.TestCase):
make_and_import_model(make_graph([where], "where", [x, y, con], [output])) make_and_import_model(make_graph([where], "where", [x, y, con], [output]))
class TestDynamicTensor(unittest.TestCase):
def test_dynamic_tensor(self):
filename = r"resnet18-v2-7.onnx"
current_path = os.getcwd()
model_file = ""
for root, dirs, files in os.walk(current_path):
if filename in files:
model_file = os.path.join(root, filename)
model = OnnxStub(onnx.load(model_file), backend.cpu_runtime())
output_key = list(model.outputs.keys())[0]
old_output_shape = model.getShape(output_key)
self.assertEqual(old_output_shape, ([1, 1000]))
model.set_input([[5, 3, 224, 224]])
new_output_shape = model.getShape(output_key)
self.assertEqual(new_output_shape, ([5, 1000]))
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()

View File

@ -1,5 +1,7 @@
#include "core/graph.h" #include "core/graph.h"
#include "operators/reshape.h"
#include <algorithm> #include <algorithm>
#include <numeric>
#include <queue> #include <queue>
namespace infini { namespace infini {
@ -123,10 +125,40 @@ void GraphObj::optimize() {
} }
} }
void GraphObj::dataMalloc(bool useNaiveAllocator) { Tensor GraphObj::getTensor(int fuid) const {
for (auto tensor : tensors) {
if (tensor->getFuid() == fuid) {
return tensor;
}
}
return nullptr;
}
void GraphObj::shape_infer() {
for (auto &op : ops) {
auto ans = op->inferShape();
IT_ASSERT(ans.has_value());
auto oldOutputs = op->getOutputs();
IT_ASSERT(ans.value().size() == oldOutputs.size());
// replace the old outputshape and size with new one
for (int i = 0; i < (int)ans.value().size(); ++i) {
auto newShape = ans.value()[i];
auto oldShape = oldOutputs[i]->getDims();
auto fuid = oldOutputs[i]->getFuid();
if (newShape != oldShape) {
auto tensor = this->getTensor(fuid);
tensor->setShape(newShape);
}
}
}
}
void GraphObj::dataMalloc(bool useNaiveAllocator, size_t memPoolSize) {
// topological sorting first // topological sorting first
IT_ASSERT(topo_sort() == true); IT_ASSERT(topo_sort() == true);
if (useNaiveAllocator) { if (useNaiveAllocator) {
// can not set memory pool when use naive allocator
IT_ASSERT(memPoolSize == 0);
// used for debugging memory out-of-bounds access, tensors will not be // used for debugging memory out-of-bounds access, tensors will not be
// released correctly // released correctly
// note: behavior may not match running in non-naive mode, and it may // note: behavior may not match running in non-naive mode, and it may
@ -136,6 +168,9 @@ void GraphObj::dataMalloc(bool useNaiveAllocator) {
} }
return; return;
} }
if (memPoolSize > 0) {
allocator.setMemPool(memPoolSize);
}
// count the number of times all tensors are used // count the number of times all tensors are used
std::unordered_map<TensorObj *, size_t> tensorToRefCount; std::unordered_map<TensorObj *, size_t> tensorToRefCount;
// record the memory address offsets of all tensors to be allocated // record the memory address offsets of all tensors to be allocated
@ -222,6 +257,27 @@ void GraphObj::dataMalloc(bool useNaiveAllocator) {
} }
} }
Tensor GraphObj::cloneKV(Tensor &tensor) {
auto obj = tensor->clone();
if (allocator.getMemPoolStatus()) {
if (tensor->hasData()) {
obj->setDataBlob(make_ref<BlobObj>(
tensor->runtime,
static_cast<uint8_t *>(allocator.getHeapPtr()) +
allocator.heapAlloc(tensor->getBytes())));
obj->copyData(tensor);
}
} else {
if (tensor->hasData()) {
obj->dataMalloc();
obj->copyData(tensor);
}
}
return obj;
}
void GraphObj::freeHeap() { this->allocator.freeHeap(); }
Tensor GraphObj::addTensor(Shape dim, DataType dtype) { Tensor GraphObj::addTensor(Shape dim, DataType dtype) {
return tensors.emplace_back(make_ref<TensorObj>(dim, dtype, runtime)); return tensors.emplace_back(make_ref<TensorObj>(dim, dtype, runtime));
} }

View File

@ -20,6 +20,7 @@
#include "operators/transpose.h" #include "operators/transpose.h"
#include "operators/unary.h" #include "operators/unary.h"
#include "operators/where.h" #include "operators/where.h"
#include <numeric>
namespace infini { namespace infini {
@ -555,4 +556,11 @@ static DataType dtype_repr_convert(int dtype) {
} }
} }
void GraphHandlerObj::change_shape(const vector<int> &shape, int tensorId) {
auto tensor = g->getTensor(tensorId);
IT_ASSERT(tensor != nullptr);
IT_ASSERT(shape.size() != 0);
tensor->setShape(shape);
}
} // namespace infini } // namespace infini

View File

@ -30,6 +30,9 @@ LazyAllocator::~LazyAllocator() {
if (this->weightPtr != nullptr) { if (this->weightPtr != nullptr) {
runtime->dealloc(this->weightPtr); runtime->dealloc(this->weightPtr);
} }
if (this->memPoolPtr != nullptr) {
runtime->dealloc(this->memPoolPtr);
}
} }
void LazyAllocator::init() { void LazyAllocator::init() {
@ -44,6 +47,17 @@ void LazyAllocator::init() {
this->ptr = nullptr; this->ptr = nullptr;
} }
void LazyAllocator::setMemPool(size_t memPoolSize) {
IT_ASSERT(memPoolSize > 0);
if (!this->hasMemPool) {
this->hasMemPool = true;
this->memPoolSize = memPoolSize;
this->memPoolPtr = runtime->alloc(memPoolSize);
}
}
bool LazyAllocator::getMemPoolStatus() { return this->hasMemPool; }
size_t LazyAllocator::alloc(size_t size) { size_t LazyAllocator::alloc(size_t size) {
// pad the size to the multiple of alignment // pad the size to the multiple of alignment
size = this->getAlignedSize(size); size = this->getAlignedSize(size);
@ -102,6 +116,17 @@ size_t LazyAllocator::allocWeight(size_t size) {
return retAddr; return retAddr;
} }
size_t LazyAllocator::heapAlloc(size_t size) {
size = this->getAlignedSize(size);
this->heapPeak += size;
IT_ASSERT(this->memPoolSize >=
this->weightPeak + this->peak + this->heapPeak);
size_t retAddr = this->memPoolSize - this->heapPeak;
return retAddr;
}
void LazyAllocator::freeHeap() { this->heapPeak = 0; }
void LazyAllocator::free(size_t addr, size_t size) { void LazyAllocator::free(size_t addr, size_t size) {
IT_ASSERT(this->ptr == nullptr); IT_ASSERT(this->ptr == nullptr);
size = getAlignedSize(size); size = getAlignedSize(size);
@ -143,6 +168,7 @@ void LazyAllocator::free(size_t addr, size_t size) {
} }
void *LazyAllocator::getPtr() { void *LazyAllocator::getPtr() {
if (!hasMemPool) {
if (this->ptr == nullptr) { if (this->ptr == nullptr) {
this->ptr = runtime->alloc(this->peak); this->ptr = runtime->alloc(this->peak);
// #ifdef DEBUG_MODE // #ifdef DEBUG_MODE
@ -151,17 +177,31 @@ void *LazyAllocator::getPtr() {
// #endif // #endif
} }
return this->ptr; return this->ptr;
} else {
IT_ASSERT(this->memPoolSize >= this->weightPeak + this->peak);
return static_cast<uint8_t *>(this->memPoolPtr) + weightPeak;
}
} }
void *LazyAllocator::getWeightPtr() { void *LazyAllocator::getWeightPtr() {
if (!hasMemPool) {
if (this->weightPtr == nullptr) { if (this->weightPtr == nullptr) {
this->weightPtr = runtime->alloc(this->weightPeak); this->weightPtr = runtime->alloc(this->weightPeak);
// #ifdef DEBUG_MODE // #ifdef DEBUG_MODE
// printf("LazyAllocator really alloc weight: %p %lu bytes\n", // printf("LazyAllocator really alloc weight: %p %lu
// bytes\n",
// this->weightPtr, weightPeak); // this->weightPtr, weightPeak);
// #endif // #endif
} }
return this->weightPtr; return this->weightPtr;
} else {
return this->memPoolPtr;
}
}
void *LazyAllocator::getHeapPtr() {
IT_ASSERT(hasMemPool);
return this->memPoolPtr;
} }
size_t LazyAllocator::getAlignedSize(size_t size) { size_t LazyAllocator::getAlignedSize(size_t size) {

View File

@ -77,9 +77,7 @@ bool OperatorObj::checkValid(GraphObj *graph) {
return true; return true;
} }
optional<vector<Shape>> OperatorObj::inferShape() const { optional<vector<Shape>> OperatorObj::inferShape() { return inferShape(inputs); }
return inferShape(inputs);
}
vector<DataType> OperatorObj::inferDataType(const TensorVec &inputs) const { vector<DataType> OperatorObj::inferDataType(const TensorVec &inputs) const {
auto dataType = inputs[0]->getDType(); auto dataType = inputs[0]->getDType();

View File

@ -59,6 +59,13 @@ Shape TensorObj::getStride() const {
return stride; return stride;
} }
void TensorObj::setShape(Shape shape_) {
shape = shape_;
size_t size = std::accumulate(shape.begin(), shape.end(), 1,
[](auto acc, auto x) { return acc * x; });
_size = size;
}
void TensorObj::printData() const { void TensorObj::printData() const {
IT_ASSERT(data != nullptr); IT_ASSERT(data != nullptr);
if (!runtime->isCpu()) if (!runtime->isCpu())

View File

@ -446,7 +446,10 @@ void init_graph_builder(py::module &m) {
}) })
.def("has_target", &TensorObj::hasTarget, policy::automatic) .def("has_target", &TensorObj::hasTarget, policy::automatic)
.def("src", &TensorObj::getSource, policy::move) .def("src", &TensorObj::getSource, policy::move)
.def("printData", &TensorObj::printData, policy::automatic); .def("printData", &TensorObj::printData, policy::automatic)
.def("copy_data",
py::overload_cast<const Tensor &>(&TensorObj::copyData),
policy::move);
py::class_<OperatorObj, std::shared_ptr<OperatorObj>>(m, "Operator") py::class_<OperatorObj, std::shared_ptr<OperatorObj>>(m, "Operator")
.def("op_type", &OperatorObj::getOpType, policy::automatic) .def("op_type", &OperatorObj::getOpType, policy::automatic)
.def("inputs", py::overload_cast<>(&OperatorObj::getInputs, py::const_), .def("inputs", py::overload_cast<>(&OperatorObj::getInputs, py::const_),
@ -466,6 +469,7 @@ void init_graph_builder(py::module &m) {
.def("add", &Handler::add, policy::move) .def("add", &Handler::add, policy::move)
.def("sub", &Handler::sub, policy::move) .def("sub", &Handler::sub, policy::move)
.def("mul", &Handler::mul, policy::move) .def("mul", &Handler::mul, policy::move)
.def("max", &Handler::max, policy::move)
.def("div", &Handler::div, policy::move) .def("div", &Handler::div, policy::move)
.def("pow", &Handler::pow, policy::move) .def("pow", &Handler::pow, policy::move)
.def("min", &Handler::min, policy::move) .def("min", &Handler::min, policy::move)
@ -510,10 +514,17 @@ void init_graph_builder(py::module &m) {
.def("topo_sort", &Handler::topo_sort, policy::automatic) .def("topo_sort", &Handler::topo_sort, policy::automatic)
.def("optimize", &Handler::optimize, policy::automatic) .def("optimize", &Handler::optimize, policy::automatic)
.def("operators", &Handler::operators, policy::move) .def("operators", &Handler::operators, policy::move)
.def("data_malloc", &Handler::data_malloc, policy::automatic) .def("data_malloc", &Handler::data_malloc,
py::arg("useNaiveAllocator") = false, py::arg("memPoolSize") = 0,
policy::automatic)
.def("clone_KV", &Handler::clone_KV, policy::move)
.def("free_heap", &Handler::free_heap, policy::move)
.def("get_perf_time", &Handler::get_perf_time, policy::automatic) .def("get_perf_time", &Handler::get_perf_time, policy::automatic)
.def("tune", &Handler::tune, policy::automatic) .def("tune", &Handler::tune, policy::automatic)
.def("run", &Handler::run, policy::automatic) .def("run", &Handler::run, policy::automatic)
.def("shape_infer", &Handler::shape_infer, policy::automatic)
.def("change_shape", &Handler::change_shape, policy::automatic)
.def("getDims", &Handler::getDims, policy::automatic)
.def("get_perf_time", &Handler::get_perf_time, policy::automatic); .def("get_perf_time", &Handler::get_perf_time, policy::automatic);
} }

View File

@ -44,7 +44,6 @@ class ElementWiseCudnn : public CudaKernelWithoutConfig {
std::copy(a_dim.begin(), a_dim.end(), a + (4 - a_dim.size())); std::copy(a_dim.begin(), a_dim.end(), a + (4 - a_dim.size()));
std::copy(b_dim.begin(), b_dim.end(), b + (4 - b_dim.size())); std::copy(b_dim.begin(), b_dim.end(), b + (4 - b_dim.size()));
std::copy(c_dim.begin(), c_dim.end(), c + (4 - c_dim.size())); std::copy(c_dim.begin(), c_dim.end(), c + (4 - c_dim.size()));
// get inputs // get inputs
checkCudnnError(cudnnCreateTensorDescriptor(&aDesc)); checkCudnnError(cudnnCreateTensorDescriptor(&aDesc));
checkCudnnError(cudnnSetTensor4dDescriptor(aDesc, CUDNN_TENSOR_NCHW, checkCudnnError(cudnnSetTensor4dDescriptor(aDesc, CUDNN_TENSOR_NCHW,
@ -110,9 +109,9 @@ class ElementWiseCuda : public CudaKernelWithoutConfig {
void compute(const Operator &_op, void compute(const Operator &_op,
const RuntimeObj *_context) const override { const RuntimeObj *_context) const override {
auto op = as<ElementWiseObj>(_op); auto op = as<ElementWiseObj>(_op);
float *const aData = (op->getInputs(0)->getRawDataPtr<float *>()); void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
float *const bData = (op->getInputs(1)->getRawDataPtr<float *>()); void *const bData = (op->getInputs(1)->getRawDataPtr<void *>());
float *const cData = (op->getOutput()->getRawDataPtr<float *>()); void *const cData = (op->getOutput()->getRawDataPtr<void *>());
auto a_dim = op->getInputs(0)->getDims(); auto a_dim = op->getInputs(0)->getDims();
auto b_dim = op->getInputs(1)->getDims(); auto b_dim = op->getInputs(1)->getDims();
auto c_dim = op->getOutput()->getDims(); auto c_dim = op->getOutput()->getDims();
@ -134,7 +133,13 @@ class ElementWiseCuda : public CudaKernelWithoutConfig {
else if (op->getOpType() == OpType::Pow) else if (op->getOpType() == OpType::Pow)
pow_kernel(aData, bData, cData, a[0], a[1], a[2], a[3], b[0], b[1], pow_kernel(aData, bData, cData, a[0], a[1], a[2], a[3], b[0], b[1],
b[2], b[3], c[0], c[1], c[2], c[3]); b[2], b[3], c[0], c[1], c[2], c[3]);
else else if (op->getOpType() == OpType::Add) {
add_kernel(aData, bData, cData, a[0], a[1], a[2], a[3], b[0], b[1],
b[2], b[3], c[0], c[1], c[2], c[3]);
} else if (op->getOpType() == OpType::Less) {
less_kernel(aData, bData, cData, a[0], a[1], a[2], a[3], b[0], b[1],
b[2], b[3], c[0], c[1], c[2], c[3]);
} else
IT_TODO_HALT(); IT_TODO_HALT();
} }
}; };
@ -152,6 +157,10 @@ REGISTER_KERNEL(Device::CUDA, OpType::Max, DataType::Float32, MaxCudnn,
REGISTER_KERNEL(Device::CUDA, OpType::Div, DataType::Float32, ElementWiseCuda, REGISTER_KERNEL(Device::CUDA, OpType::Div, DataType::Float32, ElementWiseCuda,
"Div_CUDA_Float32"); "Div_CUDA_Float32");
REGISTER_KERNEL(Device::CUDA, OpType::Add, DataType::Int64, ElementWiseCuda,
"Add_CUDA_Int64");
REGISTER_KERNEL(Device::CUDA, OpType::Pow, DataType::Float32, ElementWiseCuda, REGISTER_KERNEL(Device::CUDA, OpType::Pow, DataType::Float32, ElementWiseCuda,
"Pow__CUDA_Float32"); "Pow__CUDA_Float32");
REGISTER_KERNEL(Device::CUDA, OpType::Less, DataType::Int64, ElementWiseCuda,
"Less__CUDA_Int64");
}; // namespace infini }; // namespace infini

View File

@ -5,9 +5,10 @@ constexpr unsigned int num_threads() { return 32 * 4; }
constexpr int thread_work_size() { return 4; } constexpr int thread_work_size() { return 4; }
constexpr int block_work_size() { return thread_work_size() * num_threads(); } constexpr int block_work_size() { return thread_work_size() * num_threads(); }
__global__ void _div_kernel(float *x, float *y, float *z, int a0, int a1, template <class T>
int a2, int a3, int b0, int b1, int b2, int b3, __global__ void _div_kernel(void *x, void *y, void *z, int a0, int a1, int a2,
int c0, int c1, int c2, int c3) { int a3, int b0, int b1, int b2, int b3, int c0,
int c1, int c2, int c3) {
int index = threadIdx.x + blockIdx.x * blockDim.x; int index = threadIdx.x + blockIdx.x * blockDim.x;
int stride = blockDim.x * gridDim.x; int stride = blockDim.x * gridDim.x;
int n = c0 * c1 * c2 * c3; int n = c0 * c1 * c2 * c3;
@ -27,16 +28,17 @@ __global__ void _div_kernel(float *x, float *y, float *z, int a0, int a1,
int b1_index = c1_index % b1; int b1_index = c1_index % b1;
int b2_index = c2_index % b2; int b2_index = c2_index % b2;
int b3_index = c3_index % b3; int b3_index = c3_index % b3;
z[i] = x[a0_index * a1 * a2 * a3 + a1_index * a2 * a3 + a2_index * a3 + ((T *)z)[i] = ((T *)x)[a0_index * a1 * a2 * a3 + a1_index * a2 * a3 +
a3_index] / a2_index * a3 + a3_index] /
y[b0_index * b1 * b2 * b3 + b1_index * b2 * b3 + b2_index * b3 + ((T *)y)[b0_index * b1 * b2 * b3 + b1_index * b2 * b3 +
b3_index]; b2_index * b3 + b3_index];
} }
} }
__global__ void _pow_kernel(float *x, float *y, float *z, int a0, int a1, template <class T>
int a2, int a3, int b0, int b1, int b2, int b3, __global__ void _add_kernel(void *x, void *y, void *z, int a0, int a1, int a2,
int c0, int c1, int c2, int c3) { int a3, int b0, int b1, int b2, int b3, int c0,
int c1, int c2, int c3) {
int index = threadIdx.x + blockIdx.x * blockDim.x; int index = threadIdx.x + blockIdx.x * blockDim.x;
int stride = blockDim.x * gridDim.x; int stride = blockDim.x * gridDim.x;
int n = c0 * c1 * c2 * c3; int n = c0 * c1 * c2 * c3;
@ -56,32 +58,115 @@ __global__ void _pow_kernel(float *x, float *y, float *z, int a0, int a1,
int b1_index = c1_index % b1; int b1_index = c1_index % b1;
int b2_index = c2_index % b2; int b2_index = c2_index % b2;
int b3_index = c3_index % b3; int b3_index = c3_index % b3;
z[i] = pow(x[a0_index * a1 * a2 * a3 + a1_index * a2 * a3 + ((T *)z)[i] = ((T *)x)[a0_index * a1 * a2 * a3 + a1_index * a2 * a3 +
a2_index * a3 + a3_index] +
((T *)y)[b0_index * b1 * b2 * b3 + b1_index * b2 * b3 +
b2_index * b3 + b3_index];
}
}
template <class T>
__global__ void _pow_kernel(void *x, void *y, void *z, int a0, int a1, int a2,
int a3, int b0, int b1, int b2, int b3, int c0,
int c1, int c2, int c3) {
int index = threadIdx.x + blockIdx.x * blockDim.x;
int stride = blockDim.x * gridDim.x;
int n = c0 * c1 * c2 * c3;
for (int i = index; i < n; i += stride) {
int c0_index = i / (c1 * c2 * c3);
int c1_index = (i % (c1 * c2 * c3)) / (c2 * c3);
int c2_index = ((i % (c1 * c2 * c3)) % (c2 * c3)) / c3;
int c3_index = ((i % (c1 * c2 * c3)) % (c2 * c3)) % c3;
int a0_index = c0_index % a0;
int a1_index = c1_index % a1;
int a2_index = c2_index % a2;
int a3_index = c3_index % a3;
int b0_index = c0_index % b0;
int b1_index = c1_index % b1;
int b2_index = c2_index % b2;
int b3_index = c3_index % b3;
((T *)z)[i] =
pow(((T *)x)[a0_index * a1 * a2 * a3 + a1_index * a2 * a3 +
a2_index * a3 + a3_index], a2_index * a3 + a3_index],
y[b0_index * b1 * b2 * b3 + b1_index * b2 * b3 + ((T *)y)[b0_index * b1 * b2 * b3 + b1_index * b2 * b3 +
b2_index * b3 + b3_index]); b2_index * b3 + b3_index]);
} }
} }
template <class T>
__global__ void _less_kernel(void *x, void *y, void *z, int a0, int a1, int a2,
int a3, int b0, int b1, int b2, int b3, int c0,
int c1, int c2, int c3) {
int index = threadIdx.x + blockIdx.x * blockDim.x;
int stride = blockDim.x * gridDim.x;
int n = c0 * c1 * c2 * c3;
for (int i = index; i < n; i += stride) {
int c0_index = i / (c1 * c2 * c3);
int c1_index = (i % (c1 * c2 * c3)) / (c2 * c3);
int c2_index = ((i % (c1 * c2 * c3)) % (c2 * c3)) / c3;
int c3_index = ((i % (c1 * c2 * c3)) % (c2 * c3)) % c3;
int a0_index = c0_index % a0;
int a1_index = c1_index % a1;
int a2_index = c2_index % a2;
int a3_index = c3_index % a3;
int b0_index = c0_index % b0;
int b1_index = c1_index % b1;
int b2_index = c2_index % b2;
int b3_index = c3_index % b3;
((bool *)z)[i] =
((T *)x)[a0_index * a1 * a2 * a3 + a1_index * a2 * a3 +
a2_index * a3 + a3_index] <
((T *)y)[b0_index * b1 * b2 * b3 + b1_index * b2 * b3 +
b2_index * b3 + b3_index]
? true
: false;
}
}
namespace infini { namespace infini {
void div_kernel(float *a, float *b, float *c, int a0, int a1, int a2, int a3, void div_kernel(void *a, void *b, void *c, int a0, int a1, int a2, int a3,
int b0, int b1, int b2, int b3, int c0, int c1, int c2, int b0, int b1, int b2, int b3, int c0, int c1, int c2,
int c3) { int c3) {
int blocksize = block_work_size(); int blocksize = block_work_size();
int num = c0 * c1 * c2 * c3; int num = c0 * c1 * c2 * c3;
int gridsize = (num + block_work_size() - 1) / block_work_size(); int gridsize = (num + block_work_size() - 1) / block_work_size();
_div_kernel<<<gridsize, blocksize>>>(a, b, c, a0, a1, a2, a3, b0, b1, b2, _div_kernel<float><<<gridsize, blocksize>>>(a, b, c, a0, a1, a2, a3, b0, b1,
b3, c0, c1, c2, c3); b2, b3, c0, c1, c2, c3);
} }
void pow_kernel(float *a, float *b, float *c, int a0, int a1, int a2, int a3, void add_kernel(void *a, void *b, void *c, int a0, int a1, int a2, int a3,
int b0, int b1, int b2, int b3, int c0, int c1, int c2,
int c3) {
int blocksize = block_work_size();
int num = c0 * c1 * c2 * c3;
int gridsize = (num + block_work_size() - 1) / block_work_size();
_add_kernel<int64_t><<<gridsize, blocksize>>>(a, b, c, a0, a1, a2, a3, b0,
b1, b2, b3, c0, c1, c2, c3);
}
void pow_kernel(void *a, void *b, void *c, int a0, int a1, int a2, int a3,
int b0, int b1, int b2, int b3, int c0, int c1, int c2, int b0, int b1, int b2, int b3, int c0, int c1, int c2,
int c3) { int c3) {
int blocksize = block_work_size(); int blocksize = block_work_size();
int num = c0 * c1 * c2 * c3; int num = c0 * c1 * c2 * c3;
int gridsize = (num + block_work_size() - 1) / block_work_size(); int gridsize = (num + block_work_size() - 1) / block_work_size();
_pow_kernel<<<gridsize, blocksize>>>(a, b, c, a0, a1, a2, a3, b0, b1, b2, _pow_kernel<float><<<gridsize, blocksize>>>(a, b, c, a0, a1, a2, a3, b0, b1,
b3, c0, c1, c2, c3); b2, b3, c0, c1, c2, c3);
}
void less_kernel(void *a, void *b, void *c, int a0, int a1, int a2, int a3,
int b0, int b1, int b2, int b3, int c0, int c1, int c2,
int c3) {
int blocksize = block_work_size();
int num = c0 * c1 * c2 * c3;
int gridsize = (num + block_work_size() - 1) / block_work_size();
_less_kernel<int64_t><<<gridsize, blocksize>>>(a, b, c, a0, a1, a2, a3, b0,
b1, b2, b3, c0, c1, c2, c3);
} }
}; // namespace infini }; // namespace infini

View File

@ -16,8 +16,9 @@ class PadSliceCudaCompute {
metadata.partNDim[i] = partTensor->getDims()[i]; metadata.partNDim[i] = partTensor->getDims()[i];
metadata.partStride[i] = partTensor->getStride()[i]; metadata.partStride[i] = partTensor->getStride()[i];
} }
pad_slice_kernel(partTensor->getRawDataPtr<float *>(), metadata.DType = partTensor->getDType().getIndex();
wholeTensor->getRawDataPtr<float *>(), metadata, nDims, pad_slice_kernel(partTensor->getRawDataPtr<void *>(),
wholeTensor->getRawDataPtr<void *>(), metadata, nDims,
wholeTensor->size(), isPad); wholeTensor->size(), isPad);
} }
}; };
@ -40,6 +41,8 @@ class SliceCuda : private PadSliceCudaCompute, public CudaKernelWithoutConfig {
REGISTER_KERNEL(Device::CUDA, OpType::Slice, DataType::Float32, SliceCuda, REGISTER_KERNEL(Device::CUDA, OpType::Slice, DataType::Float32, SliceCuda,
"Slice__CUDA_Float32"); "Slice__CUDA_Float32");
REGISTER_KERNEL(Device::CUDA, OpType::Slice, DataType::Int64, SliceCuda,
"Slice__CUDA_Int64");
REGISTER_KERNEL(Device::CUDA, OpType::Pad, DataType::Float32, PadCuda, REGISTER_KERNEL(Device::CUDA, OpType::Pad, DataType::Float32, PadCuda,
"Pad__CUDA_Float32"); "Pad__CUDA_Float32");
} // namespace infini } // namespace infini

View File

@ -1,3 +1,4 @@
#include "core/data_type.h"
#include "cuda/cuda_common.h" #include "cuda/cuda_common.h"
#include "cuda/cuda_pad_slice.h" #include "cuda/cuda_pad_slice.h"
@ -19,9 +20,9 @@ __device__ int WholeTensorOffset2PartTensorOffset(int wholeOffset,
return offset; return offset;
} }
__global__ void _pad_slice_kernel(float *part, float *whole, template <typename T>
TransMetaData metaData, int nDims, int num, __global__ void _pad_slice_kernel(T *part, T *whole, TransMetaData metaData,
bool isPad) { int nDims, int num, bool isPad) {
int tid = threadIdx.x + blockIdx.x * blockDim.x; int tid = threadIdx.x + blockIdx.x * blockDim.x;
if (tid >= num) if (tid >= num)
return; return;
@ -41,12 +42,18 @@ __global__ void _pad_slice_kernel(float *part, float *whole,
} }
namespace infini { namespace infini {
void pad_slice_kernel(float *partData, float *wholeData, void pad_slice_kernel(void *partData, void *wholeData,
const TransMetaData &metadata, int nDims, int num, const TransMetaData &metadata, int nDims, int num,
bool isPad) { bool isPad) {
int blockSize = 32 * 16; int blockSize = 32 * 16;
int gridSize = (num + blockSize - 1) / blockSize; int gridSize = (num + blockSize - 1) / blockSize;
_pad_slice_kernel<<<gridSize, blockSize>>>(partData, wholeData, metadata, if (metadata.DType == DataType::Int64.getIndex()) {
nDims, num, isPad); _pad_slice_kernel<int64_t>
<<<gridSize, blockSize>>>((int64_t *)partData, (int64_t *)wholeData,
metadata, nDims, num, isPad);
} else if (metadata.DType == DataType::Float32.getIndex()) {
_pad_slice_kernel<float><<<gridSize, blockSize>>>(
(float *)partData, (float *)wholeData, metadata, nDims, num, isPad);
}
} }
} // namespace infini } // namespace infini

View File

@ -59,6 +59,21 @@ class CudaCompute {
class ConcatCuda : private CudaCompute, public CudaKernelWithoutConfig { class ConcatCuda : private CudaCompute, public CudaKernelWithoutConfig {
void compute(const Operator &_op, void compute(const Operator &_op,
const RuntimeObj *_context) const override { const RuntimeObj *_context) const override {
auto inputs = _op->getInputs();
if (inputs.size() == 2) {
for (size_t i = 0; i < 2; i++) {
if (inputs[i]->size() == 0) {
auto inData =
_op->getInputs(1 - i)->getRawDataPtr<void *>();
auto outData =
_op->getOutputs()[0]->getRawDataPtr<void *>();
cudaMemcpyAsync(outData, inData,
_op->getInputs(1 - i)->getBytes(),
cudaMemcpyDeviceToDevice);
return;
}
}
}
do_compute(_op->getOutput(), _op->getInputs(), do_compute(_op->getOutput(), _op->getInputs(),
as<ConcatObj>(_op)->getDim(), _op->getOutput()->getRank(), as<ConcatObj>(_op)->getDim(), _op->getOutput()->getRank(),
false); false);

View File

@ -20,15 +20,18 @@ string G2BMMObj::toString() const {
return os.str(); return os.str();
} }
optional<vector<Shape>> G2BMMObj::inferShape(const TensorVec &inputs) const { optional<vector<Shape>> G2BMMObj::inferShape(const TensorVec &inputs) {
auto A = inputs[0], B = inputs[1]; auto A = inputs[0], B = inputs[1];
b = A->getDims()[0];
m = A->getDims()[1];
k = A->getDims()[2];
IT_ASSERT(A->getRank() == 3 && B->getRank() == 3); IT_ASSERT(A->getRank() == 3 && B->getRank() == 3);
IT_ASSERT(A->getDims()[0] == B->getDims()[0]); IT_ASSERT(A->getDims()[0] == B->getDims()[0]);
IT_ASSERT(A->getDims()[1] == B->getDims()[1]); IT_ASSERT(A->getDims()[1] == B->getDims()[1]);
IT_ASSERT(A->getDims()[2] == B->getDims()[2]); IT_ASSERT(A->getDims()[2] == B->getDims()[2]);
IT_ASSERT(width >= 0); IT_ASSERT(width >= 0);
int b(A->getDims()[0]), m(A->getDims()[1]), n(2 * width + 1); int n(2 * width + 1);
return {{{b, m, n}}}; return {{{b, m, n}}};
} }

View File

@ -21,15 +21,18 @@ string GBMMObj::toString() const {
return os.str(); return os.str();
} }
optional<vector<Shape>> GBMMObj::inferShape(const TensorVec &inputs) const { optional<vector<Shape>> GBMMObj::inferShape(const TensorVec &inputs) {
auto A = inputs[0], B = inputs[1]; auto A = inputs[0], B = inputs[1];
b = A->getDims()[0];
m = A->getDims()[1];
w = (A->getDims()[2] - 1) / 2;
n = B->getDims()[2];
IT_ASSERT(A->getRank() == 3 && B->getRank() == 3); IT_ASSERT(A->getRank() == 3 && B->getRank() == 3);
IT_ASSERT(A->getDims()[0] == B->getDims()[0]); IT_ASSERT(A->getDims()[0] == B->getDims()[0]);
IT_ASSERT(A->getDims()[1] == B->getDims()[1]); IT_ASSERT(A->getDims()[1] == B->getDims()[1]);
IT_ASSERT(A->getDims()[2] % 2 != 0); IT_ASSERT(A->getDims()[2] % 2 != 0);
int b(A->getDims()[0]), m(A->getDims()[1]), k(B->getDims()[2]); return {{{b, m, n}}};
return {{{b, m, k}}};
} }
vector<int> GBMMObj::getWorkloadVector() const { vector<int> GBMMObj::getWorkloadVector() const {

View File

@ -9,7 +9,7 @@ ActivationBackwardObj::ActivationBackwardObj(OpType type, GraphObj *graph,
} }
optional<vector<Shape>> optional<vector<Shape>>
ActivationBackwardObj::inferShape(const TensorVec &inputs) const { ActivationBackwardObj::inferShape(const TensorVec &inputs) {
return {{inputs[0]->getDims()}}; return {{inputs[0]->getDims()}};
} }

View File

@ -10,8 +10,7 @@ AllGatherObj::AllGatherObj(GraphObj *graph, Tensor input,
IT_ASSERT(checkValid(graph)); IT_ASSERT(checkValid(graph));
} }
optional<vector<Shape>> optional<vector<Shape>> AllGatherObj::inferShape(const TensorVec &inputs) {
AllGatherObj::inferShape(const TensorVec &inputs) const {
Shape input_shape = inputs[0]->getDims(); Shape input_shape = inputs[0]->getDims();
vector<Shape> output_shapes(getWorldSize(), input_shape); vector<Shape> output_shapes(getWorldSize(), input_shape);
return output_shapes; return output_shapes;

View File

@ -18,7 +18,7 @@ AttentionKVCacheObj::AttentionKVCacheObj(GraphObj *graph, Tensor input_k_cache,
} }
optional<vector<Shape>> optional<vector<Shape>>
AttentionKVCacheObj::inferShape(const TensorVec &inputs) const { AttentionKVCacheObj::inferShape(const TensorVec &inputs) {
IT_ASSERT(inputs.size() == 6); IT_ASSERT(inputs.size() == 6);
Shape dims = inputs[0]->getDims(); Shape dims = inputs[0]->getDims();
ShapeElem n = dims.at(dim); ShapeElem n = dims.at(dim);

View File

@ -13,8 +13,7 @@ BatchNormObj::BatchNormObj(GraphObj *graph, Tensor input, Tensor output,
IT_ASSERT(checkValid(graph)); IT_ASSERT(checkValid(graph));
} }
optional<vector<Shape>> optional<vector<Shape>> BatchNormObj::inferShape(const TensorVec &inputs) {
BatchNormObj::inferShape(const TensorVec &inputs) const {
auto input = inputs[0]; auto input = inputs[0];
auto mean = inputs[1]; auto mean = inputs[1];
auto var = inputs[2]; auto var = inputs[2];

View File

@ -9,9 +9,16 @@ ConcatObj::ConcatObj(GraphObj *graph, TensorVec inputs, Tensor output, int _dim)
IT_ASSERT(checkValid(graph)); IT_ASSERT(checkValid(graph));
} }
optional<vector<Shape>> ConcatObj::inferShape(const TensorVec &inputs) const { optional<vector<Shape>> ConcatObj::inferShape(const TensorVec &inputs) {
Shape dims = inputs[0]->getDims(); Shape dims = inputs[0]->getDims();
auto rank = inputs[0]->getRank(); auto rank = inputs[0]->getRank();
if (inputs.size() == 2) {
for (size_t i = 0; i < inputs.size(); ++i) {
if (inputs[i]->size() == 0) {
return {{inputs[1 - i]->getDims()}};
}
}
}
ShapeElem n = dims.at(dim); ShapeElem n = dims.at(dim);
for (auto itr = inputs.begin() + 1; itr != inputs.end(); ++itr) { for (auto itr = inputs.begin() + 1; itr != inputs.end(); ++itr) {
auto input = *itr; auto input = *itr;

View File

@ -82,14 +82,15 @@ ConvObj::ConvObj(GraphObj *graph, Tensor input, Tensor weight, Tensor output,
IT_ASSERT(checkValid(graph)); IT_ASSERT(checkValid(graph));
} }
optional<vector<Shape>> ConvObj::inferShape(const TensorVec &inputs) const { optional<vector<Shape>> ConvObj::inferShape(const TensorVec &inputs) {
const auto &input = inputs[0], &weight = inputs[1]; const auto &input = inputs[0], &weight = inputs[1];
auto n = input->getDims()[0]; n = input->getDims()[0];
auto h = input->getDims()[2]; c = input->getDims()[1];
auto w = input->getDims()[3]; h = input->getDims()[2];
auto f = weight->getDims()[0]; w = input->getDims()[3];
auto r = weight->getDims()[2]; f = weight->getDims()[0];
auto s = weight->getDims()[3]; r = weight->getDims()[2];
s = weight->getDims()[3];
int on = n, oc = f; int on = n, oc = f;
int oh = 0, ow = 0; int oh = 0, ow = 0;
// For NCHW+FCRS layout, C of input is divisable by C of weight // For NCHW+FCRS layout, C of input is divisable by C of weight
@ -141,15 +142,15 @@ ConvTransposed2dObj::ConvTransposed2dObj(GraphObj *graph, Tensor input,
} }
optional<vector<Shape>> optional<vector<Shape>>
ConvTransposed2dObj::inferShape(const TensorVec &inputs) const { ConvTransposed2dObj::inferShape(const TensorVec &inputs) {
const Tensor &input = inputs[0], &weight = inputs[1]; const Tensor &input = inputs[0], &weight = inputs[1];
auto n = input->getDims()[0]; n = input->getDims()[0];
auto f = input->getDims()[1]; f = input->getDims()[1];
auto h = input->getDims()[2]; h = input->getDims()[2];
auto w = input->getDims()[3]; w = input->getDims()[3];
auto c = weight->getDims()[1]; c = weight->getDims()[1];
auto r = weight->getDims()[2]; r = weight->getDims()[2];
auto s = weight->getDims()[3]; s = weight->getDims()[3];
IT_ASSERT(f == weight->getDims()[0]); IT_ASSERT(f == weight->getDims()[0]);
int on = n, oc = c * group; int on = n, oc = c * group;
@ -219,14 +220,15 @@ ConvBackwardFilterObj::ConvBackwardFilterObj(GraphObj *graph, Tensor inputX,
} }
optional<vector<Shape>> optional<vector<Shape>>
ConvBackwardFilterObj::inferShape(const TensorVec &inputs) const { ConvBackwardFilterObj::inferShape(const TensorVec &inputs) {
const auto &inputX = inputs[0], &diffY = inputs[1]; const auto &inputX = inputs[0], &diffY = inputs[1];
auto n = inputX->getDims()[0]; n = inputX->getDims()[0];
auto h = inputX->getDims()[2]; c = inputX->getDims()[1];
auto w = inputX->getDims()[3]; h = inputX->getDims()[2];
auto f = diffY->getDims()[0]; w = inputX->getDims()[3];
auto r = diffY->getDims()[2]; f = diffY->getDims()[0];
auto s = diffY->getDims()[3]; r = diffY->getDims()[2];
s = diffY->getDims()[3];
int on = n, oc = f; int on = n, oc = f;
int oh = 0, ow = 0; int oh = 0, ow = 0;
// For NCHW+FCRS layout, C of input is divisable by C of weight // For NCHW+FCRS layout, C of input is divisable by C of weight
@ -280,17 +282,16 @@ ConvTransposed2dNHWCObj::ConvTransposed2dNHWCObj(GraphObj *graph, Tensor input,
} }
optional<vector<Shape>> optional<vector<Shape>>
ConvTransposed2dNHWCObj::inferShape(const TensorVec &inputs) const { ConvTransposed2dNHWCObj::inferShape(const TensorVec &inputs) {
const Tensor &input = inputs[0], &weight = inputs[1]; const Tensor &input = inputs[0], &weight = inputs[1];
auto n = input->getDims()[0]; n = input->getDims()[0];
auto f = input->getDims()[3]; f = input->getDims()[3];
auto h = input->getDims()[1]; h = input->getDims()[1];
auto w = input->getDims()[2]; w = input->getDims()[2];
auto c = weight->getDims()[3]; c = weight->getDims()[3];
auto r = weight->getDims()[1]; r = weight->getDims()[1];
auto s = weight->getDims()[2]; s = weight->getDims()[2];
if (f != weight->getDims()[0]) IT_ASSERT(f == weight->getDims()[0]);
return {};
int on = n, oc = c * group; int on = n, oc = c * group;
int oh = 0, ow = 0; int oh = 0, ow = 0;

View File

@ -6,7 +6,7 @@ DetObj::DetObj(GraphObj *graph, Tensor input, Tensor output, Mode mode)
IT_ASSERT(checkValid(graph)); IT_ASSERT(checkValid(graph));
} }
optional<vector<Shape>> DetObj::inferShape(const TensorVec &inputs) const { optional<vector<Shape>> DetObj::inferShape(const TensorVec &inputs) {
const auto A = inputs[0]; const auto A = inputs[0];
auto input = A->getDims(); auto input = A->getDims();
int rank = A->getRank(); int rank = A->getRank();

View File

@ -10,7 +10,7 @@ DropoutObj::DropoutObj(GraphObj *graph, Tensor data, Tensor output, Tensor mask,
IT_ASSERT(checkValid(graph)); IT_ASSERT(checkValid(graph));
} }
optional<vector<Shape>> DropoutObj::inferShape(const TensorVec &inputs) const { optional<vector<Shape>> DropoutObj::inferShape(const TensorVec &inputs) {
auto shape = inputs[0]->getDims(); auto shape = inputs[0]->getDims();
return {{shape, shape}}; return {{shape, shape}};
} }

View File

@ -8,8 +8,7 @@ ElementWiseObj::ElementWiseObj(OpType type, GraphObj *graph, Tensor input0,
IT_ASSERT(checkValid(graph)); IT_ASSERT(checkValid(graph));
} }
optional<vector<Shape>> optional<vector<Shape>> ElementWiseObj::inferShape(const TensorVec &inputs) {
ElementWiseObj::inferShape(const TensorVec &inputs) const {
const auto A = inputs[0], B = inputs[1]; const auto A = inputs[0], B = inputs[1];
auto res = infer_broadcast(A->getDims(), B->getDims()); auto res = infer_broadcast(A->getDims(), B->getDims());
return {{res}}; return {{res}};
@ -45,7 +44,7 @@ MSELossObj::MSELossObj(GraphObj *graph, Tensor input0, Tensor input1,
IT_ASSERT(checkValid(graph)); IT_ASSERT(checkValid(graph));
} }
optional<vector<Shape>> MSELossObj::inferShape(const TensorVec &inputs) const { optional<vector<Shape>> MSELossObj::inferShape(const TensorVec &inputs) {
const auto A = inputs[0], B = inputs[1]; const auto A = inputs[0], B = inputs[1];
IT_ASSERT(A->getRank() == B->getRank()); IT_ASSERT(A->getRank() == B->getRank());
IT_ASSERT(A->getDims() == B->getDims()); IT_ASSERT(A->getDims() == B->getDims());

View File

@ -8,7 +8,7 @@ ExpandObj::ExpandObj(GraphObj *graph, Tensor input, Tensor output, Shape dims)
IT_ASSERT(checkValid(graph)); IT_ASSERT(checkValid(graph));
} }
optional<vector<Shape>> ExpandObj::inferShape(const TensorVec &inputs) const { optional<vector<Shape>> ExpandObj::inferShape(const TensorVec &inputs) {
auto shape_input = inputs[0]->getDims(); auto shape_input = inputs[0]->getDims();
Shape ret = infer_broadcast(shape_input, dims); Shape ret = infer_broadcast(shape_input, dims);
return {{ret}}; return {{ret}};

View File

@ -11,7 +11,7 @@ ExtendObj::ExtendObj(GraphObj *graph, Tensor input, Tensor output, int dim,
IT_ASSERT(checkValid(graph)); IT_ASSERT(checkValid(graph));
} }
optional<vector<Shape>> ExtendObj::inferShape(const TensorVec &inputs) const { optional<vector<Shape>> ExtendObj::inferShape(const TensorVec &inputs) {
auto ret = inputs[0]->getDims(); auto ret = inputs[0]->getDims();
ret[dim] = ret[dim] * (num + 1); ret[dim] = ret[dim] * (num + 1);
return {{ret}}; return {{ret}};

View File

@ -10,7 +10,7 @@ GatherObj::GatherObj(GraphObj *graph, Tensor input, Tensor indices,
IT_ASSERT(checkValid(graph)); IT_ASSERT(checkValid(graph));
} }
optional<vector<Shape>> GatherObj::inferShape(const TensorVec &inputs) const { optional<vector<Shape>> GatherObj::inferShape(const TensorVec &inputs) {
auto dims0 = inputs[0]->getDims(); auto dims0 = inputs[0]->getDims();
auto dims1 = inputs[1]->getDims(); auto dims1 = inputs[1]->getDims();

View File

@ -24,8 +24,7 @@ bool checkShape(Tensor input, Tensor indices, int axis) {
return true; return true;
} }
optional<vector<Shape>> optional<vector<Shape>> GatherElementsObj::inferShape(const TensorVec &inputs) {
GatherElementsObj::inferShape(const TensorVec &inputs) const {
IT_ASSERT(checkShape(inputs[0], inputs[1], axis)); IT_ASSERT(checkShape(inputs[0], inputs[1], axis));
auto indicesDims = inputs[1]->getDims(); // output has same shape as indices auto indicesDims = inputs[1]->getDims(); // output has same shape as indices
return {{indicesDims}}; return {{indicesDims}};

View File

@ -9,25 +9,6 @@ MatmulObj::MatmulObj(GraphObj *graph, Tensor A, Tensor B, Tensor C, bool transA,
: OperatorObj(OpType::MatMul, : OperatorObj(OpType::MatMul,
bias ? TensorVec{A, B, bias} : TensorVec{A, B}, {C}), bias ? TensorVec{A, B, bias} : TensorVec{A, B}, {C}),
transA(transA), transB(transB), act(act), b(1) { transA(transA), transB(transB), act(act), b(1) {
auto shape_a = A->getDims();
auto shape_b = B->getDims();
int rankA = A->getRank();
int rankB = B->getRank();
IT_ASSERT(rankA >= 2 && rankB >= 2);
Shape shape_a1(shape_a.begin(), shape_a.begin() + (rankA - 2));
Shape shape_b1(shape_b.begin(), shape_b.begin() + (rankB - 2));
auto ret = infer_broadcast(shape_a1, shape_b1);
if (ret.empty()) {
b = 1;
} else {
b = std::accumulate(ret.begin(), ret.end(), 1, std::multiplies<int>());
}
auto kA = *(transA ? shape_a.rbegin() + 1 : shape_a.rbegin());
auto kB = *(transB ? shape_b.rbegin() : shape_b.rbegin() + 1);
IT_ASSERT(kA == kB);
m = *(transA ? shape_a.rbegin() : shape_a.rbegin() + 1);
n = *(transB ? shape_b.rbegin() + 1 : shape_b.rbegin());
k = kA;
IT_ASSERT(checkValid(graph)); IT_ASSERT(checkValid(graph));
} }
@ -40,7 +21,7 @@ string MatmulObj::toString() const {
return os.str(); return os.str();
} }
optional<vector<Shape>> MatmulObj::inferShape(const TensorVec &inputs) const { optional<vector<Shape>> MatmulObj::inferShape(const TensorVec &inputs) {
auto A = inputs[0], B = inputs[1]; auto A = inputs[0], B = inputs[1];
auto shapeA = A->getDims(); auto shapeA = A->getDims();
auto shapeB = B->getDims(); auto shapeB = B->getDims();
@ -49,6 +30,17 @@ optional<vector<Shape>> MatmulObj::inferShape(const TensorVec &inputs) const {
Shape shapeA1(shapeA.begin(), shapeA.begin() + (rankA - 2)); Shape shapeA1(shapeA.begin(), shapeA.begin() + (rankA - 2));
Shape shapeB1(shapeB.begin(), shapeB.begin() + (rankB - 2)); Shape shapeB1(shapeB.begin(), shapeB.begin() + (rankB - 2));
Shape ret = infer_broadcast(shapeA1, shapeB1); Shape ret = infer_broadcast(shapeA1, shapeB1);
if (ret.empty()) {
b = 1;
} else {
b = std::accumulate(ret.begin(), ret.end(), 1, std::multiplies<int>());
}
auto kA = *(transA ? shapeA.rbegin() + 1 : shapeA.rbegin());
auto kB = *(transB ? shapeB.rbegin() : shapeB.rbegin() + 1);
IT_ASSERT(kA == kB);
m = *(transA ? shapeA.rbegin() : shapeA.rbegin() + 1);
n = *(transB ? shapeB.rbegin() + 1 : shapeB.rbegin());
k = kA;
ret.emplace_back(m); ret.emplace_back(m);
ret.emplace_back(n); ret.emplace_back(n);
return {{ret}}; return {{ret}};

View File

@ -60,7 +60,7 @@ string MemBoundObj::toString() const {
return os.str(); return os.str();
} }
optional<vector<Shape>> MemBoundObj::inferShape(const TensorVec &inputs) const { optional<vector<Shape>> MemBoundObj::inferShape(const TensorVec &inputs) {
// inputs have to match nnetInputs excatly // inputs have to match nnetInputs excatly
if (inputs.size() != nnetInputs.size()) if (inputs.size() != nnetInputs.size())
return {}; return {};

View File

@ -22,7 +22,7 @@ PadObj::PadObj(GraphObj *graph, Tensor input, Tensor output,
IT_ASSERT(checkValid(graph)); IT_ASSERT(checkValid(graph));
} }
optional<vector<Shape>> PadObj::inferShape(const TensorVec &inputs) const { optional<vector<Shape>> PadObj::inferShape(const TensorVec &inputs) {
auto dims = inputs[0]->getDims(); auto dims = inputs[0]->getDims();
int rank = inputs[0]->getRank(); int rank = inputs[0]->getRank();
IT_ASSERT(rank * 2 == (int)pads.size()); IT_ASSERT(rank * 2 == (int)pads.size());

View File

@ -12,7 +12,7 @@ PoolingObj::PoolingObj(GraphObj *graph, OpType optype, Tensor input,
IT_ASSERT(checkValid(graph)); IT_ASSERT(checkValid(graph));
} }
optional<vector<Shape>> PoolingObj::inferShape(const TensorVec &inputs) const { optional<vector<Shape>> PoolingObj::inferShape(const TensorVec &inputs) {
const auto &input = inputs[0]; const auto &input = inputs[0];
auto h = input->getDims()[input->getRank() - 2], auto h = input->getDims()[input->getRank() - 2],
w = input->getDims()[input->getRank() - 1]; w = input->getDims()[input->getRank() - 1];

View File

@ -21,8 +21,7 @@ bool ReduceMeanObj::isReduced(int idx) const {
return axes.find(idx) != axes.end(); return axes.find(idx) != axes.end();
} }
optional<vector<Shape>> optional<vector<Shape>> ReduceMeanObj::inferShape(const TensorVec &inputs) {
ReduceMeanObj::inferShape(const TensorVec &inputs) const {
auto dims = inputs[0]->getDims(); auto dims = inputs[0]->getDims();
auto rank = inputs[0]->getRank(); auto rank = inputs[0]->getRank();

View File

@ -1,5 +1,6 @@
#include "operators/reshape.h" #include "operators/reshape.h"
#include "utils/operator_utils.h" #include "utils/operator_utils.h"
#include <numeric>
namespace infini { namespace infini {
ReshapeObj::ReshapeObj(GraphObj *graph, Tensor input, Tensor output, Shape dims) ReshapeObj::ReshapeObj(GraphObj *graph, Tensor input, Tensor output, Shape dims)
@ -7,14 +8,37 @@ ReshapeObj::ReshapeObj(GraphObj *graph, Tensor input, Tensor output, Shape dims)
IT_ASSERT(checkValid(graph)); IT_ASSERT(checkValid(graph));
} }
optional<vector<Shape>> ReshapeObj::inferShape(const TensorVec &inputs) const { optional<vector<Shape>> ReshapeObj::inferShape(const TensorVec &inputs) {
size_t size = 1; int count = 0;
for (size_t i = 0; i < dims.size(); ++i) { for (auto x : dims) {
size *= dims.at(i); if (x == -1) {
count++;
} }
IT_ASSERT(size == inputs[0]->size()); IT_ASSERT(x == -1 || x >= 0);
}
IT_ASSERT(count == 0 || count == 1);
auto inputShape = inputs[0]->getDims();
int size = inputs[0]->size();
int index = -1;
outputShape = dims;
for (int i = 0; i < (int)dims.size(); ++i) {
if (dims[i] == 0) {
outputShape[i] = inputShape[i];
}
if (dims[i] == -1) {
index = i;
}
}
if (index != -1) {
outputShape[index] =
size / (-std::accumulate(outputShape.begin(), outputShape.end(), 1,
[](auto acc, auto x) { return acc * x; }));
}
int outputSize = std::accumulate(outputShape.begin(), outputShape.end(), 1,
[](auto acc, auto x) { return acc * x; });
IT_ASSERT(outputSize == size);
return {{dims}}; return {{outputShape}};
} }
std::string ReshapeObj::toString() const { std::string ReshapeObj::toString() const {
@ -22,7 +46,7 @@ std::string ReshapeObj::toString() const {
os << "Reshape[" << getGuid() << "]"; os << "Reshape[" << getGuid() << "]";
os << "("; os << "(";
os << vecToString(inputs[0]->getDims()) << ","; os << vecToString(inputs[0]->getDims()) << ",";
os << "dims=" << vecToString(dims) << ","; os << "outputShape=" << vecToString(outputShape) << ",";
os << "input=" << inputs[0]->getGuid() << ","; os << "input=" << inputs[0]->getGuid() << ",";
os << "output=" << outputs[0]->getGuid() << ")"; os << "output=" << outputs[0]->getGuid() << ")";
return os.str(); return os.str();
@ -30,12 +54,12 @@ std::string ReshapeObj::toString() const {
vector<int> ReshapeObj::getWorkloadVector() const { vector<int> ReshapeObj::getWorkloadVector() const {
vector<int> ret = inputs[0]->getDims(); vector<int> ret = inputs[0]->getDims();
ret.insert(ret.end(), dims.begin(), dims.end()); ret.insert(ret.end(), outputShape.begin(), outputShape.end());
ret.emplace(ret.begin(), type.underlying()); ret.emplace(ret.begin(), type.underlying());
return ret; return ret;
} }
vector<int> ReshapeObj::getOpAttrVector() const { vector<int> ReshapeObj::getOpAttrVector() const {
vector<int> ret = dims; vector<int> ret = outputShape;
ret.emplace(ret.begin(), type.underlying()); ret.emplace(ret.begin(), type.underlying());
return ret; return ret;
} }
@ -47,7 +71,7 @@ FlattenObj::FlattenObj(GraphObj *graph, Tensor input, Tensor output, int _axis)
IT_ASSERT(checkValid(graph)); IT_ASSERT(checkValid(graph));
} }
optional<vector<Shape>> FlattenObj::inferShape(const TensorVec &inputs) const { optional<vector<Shape>> FlattenObj::inferShape(const TensorVec &inputs) {
int sizeB = 1, sizeE = 1; int sizeB = 1, sizeE = 1;
auto dims = getInputs(0)->getDims(); auto dims = getInputs(0)->getDims();
int rank = getInputs(0)->getRank(); int rank = getInputs(0)->getRank();
@ -84,7 +108,7 @@ IdentityObj::IdentityObj(GraphObj *graph, Tensor input, Tensor output)
IT_ASSERT(checkValid(graph)); IT_ASSERT(checkValid(graph));
} }
optional<vector<Shape>> IdentityObj::inferShape(const TensorVec &inputs) const { optional<vector<Shape>> IdentityObj::inferShape(const TensorVec &inputs) {
return {{getInputs(0)->getDims()}}; return {{getInputs(0)->getDims()}};
} }

View File

@ -206,7 +206,7 @@ float ResizeObj::round_int(float x) const {
} }
// output shape is related to sizes/scales value. // output shape is related to sizes/scales value.
optional<vector<Shape>> ResizeObj::inferShape(const TensorVec &inputs) const { optional<vector<Shape>> ResizeObj::inferShape(const TensorVec &inputs) {
auto inDims = inputs[0]->getDims(); auto inDims = inputs[0]->getDims();
Shape ret = inDims; Shape ret = inDims;
int rank = inputs[0]->getRank(); int rank = inputs[0]->getRank();

View File

@ -62,7 +62,7 @@ SliceObj::SliceObj(GraphObj *graph, Tensor input, Tensor output,
IT_ASSERT(checkValid(graph)); IT_ASSERT(checkValid(graph));
} }
optional<vector<Shape>> SliceObj::inferShape(const TensorVec &inputs) const { optional<vector<Shape>> SliceObj::inferShape(const TensorVec &inputs) {
Shape ans; Shape ans;
ans.reserve(axes.size()); ans.reserve(axes.size());
for (const auto &range : axes) { for (const auto &range : axes) {

View File

@ -35,7 +35,7 @@ SplitObj::SplitObj(GraphObj *graph, Tensor input,
IT_ASSERT(checkValid(graph)); IT_ASSERT(checkValid(graph));
} }
optional<vector<Shape>> SplitObj::inferShape(const TensorVec &inputs) const { optional<vector<Shape>> SplitObj::inferShape(const TensorVec &inputs) {
IT_ASSERT(num != -1 && ratio.size() != 0); IT_ASSERT(num != -1 && ratio.size() != 0);
auto inputDims = inputs[0]->getDims(); auto inputDims = inputs[0]->getDims();
int totalSize = inputDims.at(dim); int totalSize = inputDims.at(dim);

View File

@ -16,8 +16,7 @@ TransposeObj::TransposeObj(GraphObj *graph, Tensor input, Tensor output,
IT_ASSERT(checkValid(graph)); IT_ASSERT(checkValid(graph));
} }
optional<vector<Shape>> optional<vector<Shape>> TransposeObj::inferShape(const TensorVec &inputs) {
TransposeObj::inferShape(const TensorVec &inputs) const {
const auto A = inputs[0]; const auto A = inputs[0];
auto input_dim = A->getDims(); auto input_dim = A->getDims();
auto output_dim = input_dim; auto output_dim = input_dim;
@ -66,8 +65,7 @@ DepthToSpaceObj::DepthToSpaceObj(GraphObj *graph, Tensor input, Tensor output,
IT_ASSERT(checkValid(graph)); IT_ASSERT(checkValid(graph));
} }
optional<vector<Shape>> optional<vector<Shape>> DepthToSpaceObj::inferShape(const TensorVec &inputs) {
DepthToSpaceObj::inferShape(const TensorVec &inputs) const {
const auto A = inputs[0]; const auto A = inputs[0];
auto inputDim = A->getDims(); auto inputDim = A->getDims();
IT_ASSERT(inputDim.size() == 4); IT_ASSERT(inputDim.size() == 4);

View File

@ -6,7 +6,7 @@ UnaryObj::UnaryObj(OpType type, GraphObj *graph, Tensor input, Tensor output)
IT_ASSERT(checkValid(graph)); IT_ASSERT(checkValid(graph));
} }
optional<vector<Shape>> UnaryObj::inferShape(const TensorVec &inputs) const { optional<vector<Shape>> UnaryObj::inferShape(const TensorVec &inputs) {
const auto A = inputs[0]; const auto A = inputs[0];
return {{A->getDims()}}; return {{A->getDims()}};
} }
@ -37,7 +37,7 @@ ClipObj::ClipObj(GraphObj *graph, Tensor input, Tensor output,
IT_ASSERT(checkValid(graph)); IT_ASSERT(checkValid(graph));
} }
optional<vector<Shape>> ClipObj::inferShape(const TensorVec &inputs) const { optional<vector<Shape>> ClipObj::inferShape(const TensorVec &inputs) {
const auto A = inputs[0]; const auto A = inputs[0];
return {{A->getDims()}}; return {{A->getDims()}};
} }
@ -68,7 +68,7 @@ HardtanhObj::HardtanhObj(GraphObj *graph, Tensor input, Tensor output,
IT_ASSERT(checkValid(graph)); IT_ASSERT(checkValid(graph));
} }
optional<vector<Shape>> HardtanhObj::inferShape(const TensorVec &inputs) const { optional<vector<Shape>> HardtanhObj::inferShape(const TensorVec &inputs) {
const auto A = inputs[0]; const auto A = inputs[0];
return {{A->getDims()}}; return {{A->getDims()}};
} }
@ -97,7 +97,7 @@ FillObj::FillObj(GraphObj *graph, Tensor input, Tensor output, float value)
IT_ASSERT(checkValid(graph)); IT_ASSERT(checkValid(graph));
} }
optional<vector<Shape>> FillObj::inferShape(const TensorVec &inputs) const { optional<vector<Shape>> FillObj::inferShape(const TensorVec &inputs) {
const auto A = inputs[0]; const auto A = inputs[0];
return {{A->getDims()}}; return {{A->getDims()}};
} }
@ -124,7 +124,7 @@ L2LossObj::L2LossObj(GraphObj *graph, Tensor input, Tensor output)
IT_ASSERT(checkValid(graph)); IT_ASSERT(checkValid(graph));
} }
optional<vector<Shape>> L2LossObj::inferShape(const TensorVec &inputs) const { optional<vector<Shape>> L2LossObj::inferShape(const TensorVec &inputs) {
Shape temp = {1}; Shape temp = {1};
return {{temp}}; return {{temp}};
} }
@ -159,7 +159,7 @@ vector<DataType> CastObj::inferDataType(const TensorVec &inputs) const {
return vector(numOutputs(), output_dataType); return vector(numOutputs(), output_dataType);
} }
optional<vector<Shape>> CastObj::inferShape(const TensorVec &inputs) const { optional<vector<Shape>> CastObj::inferShape(const TensorVec &inputs) {
const auto A = inputs[0]; const auto A = inputs[0];
return {{A->getDims()}}; return {{A->getDims()}};
} }
@ -241,7 +241,7 @@ ShapeObj::ShapeObj(GraphObj *graph, Tensor input, Tensor output)
IT_ASSERT(checkValid(graph)); IT_ASSERT(checkValid(graph));
} }
optional<vector<Shape>> ShapeObj::inferShape(const TensorVec &inputs) const { optional<vector<Shape>> ShapeObj::inferShape(const TensorVec &inputs) {
return {{{static_cast<int>(inputs[0]->getRank())}}}; return {{{static_cast<int>(inputs[0]->getRank())}}};
} }
@ -257,7 +257,7 @@ PReluObj::PReluObj(GraphObj *graph, Tensor input, Tensor alpha, Tensor output)
IT_ASSERT(checkValid(graph)); IT_ASSERT(checkValid(graph));
} }
optional<vector<Shape>> PReluObj::inferShape(const TensorVec &inputs) const { optional<vector<Shape>> PReluObj::inferShape(const TensorVec &inputs) {
const auto A = inputs[0]; const auto A = inputs[0];
return {{A->getDims()}}; return {{A->getDims()}};
} }
@ -286,7 +286,7 @@ LogObj::LogObj(GraphObj *graph, Tensor input, Tensor output, LogType type)
IT_ASSERT(checkValid(graph)); IT_ASSERT(checkValid(graph));
} }
optional<vector<Shape>> LogObj::inferShape(const TensorVec &inputs) const { optional<vector<Shape>> LogObj::inferShape(const TensorVec &inputs) {
const auto A = inputs[0]; const auto A = inputs[0];
return {{A->getDims()}}; return {{A->getDims()}};
} }

View File

@ -10,7 +10,7 @@ WhereObj::WhereObj(GraphObj *graph, Tensor inputX, Tensor inputY,
IT_ASSERT(checkValid(graph)); IT_ASSERT(checkValid(graph));
} }
optional<vector<Shape>> WhereObj::inferShape(const TensorVec &inputs) const { optional<vector<Shape>> WhereObj::inferShape(const TensorVec &inputs) {
auto shapeX = inputs[0]->getDims(); auto shapeX = inputs[0]->getDims();
auto shapeY = inputs[1]->getDims(); auto shapeY = inputs[1]->getDims();
auto shapeCon = inputs[2]->getDims(); auto shapeCon = inputs[2]->getDims();

View File

@ -158,4 +158,33 @@ TEST(Concat, CudaHigh) {
12., 13., 14., 15., 16., 17., 1., 1., 1., 1., 1., 1., 12., 13., 14., 15., 16., 17., 1., 1., 1., 1., 1., 1.,
18., 19., 20., 21., 22., 23., 1., 1., 1., 1., 1., 1.})); 18., 19., 20., 21., 22., 23., 1., 1., 1., 1., 1., 1.}));
} }
TEST(ConcatToIdentity, Cuda) {
Runtime runtime = NativeCpuRuntimeObj::getInstance();
Graph gCpu = make_ref<GraphObj>(runtime);
auto t1 = gCpu->addTensor({2, 2, 3, 1}, DataType::Float32);
auto t2 = gCpu->addTensor({0}, DataType::Float32);
gCpu->dataMalloc();
t1->setData(IncrementalGenerator());
t2->setData(OneGenerator());
auto cudaRuntime = make_ref<CudaRuntimeObj>();
Graph gCuda = make_ref<GraphObj>(cudaRuntime);
auto t1Gpu = gCuda->cloneTensor(t1);
auto t2Gpu = gCuda->cloneTensor(t2);
auto op = gCuda->addOp<ConcatObj>(TensorVec{t1Gpu, t2Gpu}, nullptr, 2);
gCuda->dataMalloc();
t1Gpu->setData(IncrementalGenerator());
t2Gpu->setData(OneGenerator());
cudaRuntime->run(gCuda);
// cudaPrintTensor(op->getOutput());
// copy output from CUDA to CPU
auto oCpu = gCpu->cloneTensor(op->getOutput());
EXPECT_TRUE(
oCpu->equalData(vector<float>{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}));
}
} // namespace infini } // namespace infini

View File

@ -14,4 +14,13 @@ TEST(Concat, ShapeInfer) {
EXPECT_EQ(op->getOutput()->getDims(), (Shape{1, 3, 2, 9})); EXPECT_EQ(op->getOutput()->getDims(), (Shape{1, 3, 2, 9}));
} }
TEST(Concat, ShapeInfer1) {
Runtime runtime = NativeCpuRuntimeObj::getInstance();
Graph g = make_ref<GraphObj>(runtime);
auto t1 = g->addTensor({1, 3, 2, 4}, DataType::Float32);
auto t2 = g->addTensor({0}, DataType::Float32);
auto op = g->addOp<ConcatObj>(TensorVec{t1, t2}, nullptr, 3);
EXPECT_EQ(op->getOutput()->getDims(), (Shape{1, 3, 2, 4}));
}
} // namespace infini } // namespace infini