forked from jiuyuan/InfiniTensor
support Dynamic tensor infer shape and fix memory pool (#176)
* feat: support dynamic tensor part1 * feat: support dynamic-tensor part2 * feat: support dynamic tensor part 3 * fix: fix some .. * - add kvcache example * feat: support concat to identity kernel * add a simple mempory pool for allocator * fix: rebase to master * fix bug after merging * - remove outdated script * fix: fix as review --------- Co-authored-by: kilinchange <kilinchange@163.com> Co-authored-by: Haojie Wang <haojie0429@gmail.com>
This commit is contained in:
parent
965df4e294
commit
331f7ab2b8
|
@ -53,6 +53,7 @@ class GraphObj : public Object {
|
|||
const TensorVec &getTensors() const { return tensors; }
|
||||
const OpVec &getOperators() const { return ops; }
|
||||
OpVec getComputeOps() const;
|
||||
Tensor getTensor(int) const;
|
||||
|
||||
/**
|
||||
* Sort the nodes in topological order.
|
||||
|
@ -64,7 +65,13 @@ class GraphObj : public Object {
|
|||
|
||||
void optimize();
|
||||
|
||||
void dataMalloc(bool useNaiveAllocator = false);
|
||||
void shape_infer();
|
||||
|
||||
void dataMalloc(bool useNaiveAllocator = false, size_t memPoolSize = 0);
|
||||
|
||||
Tensor cloneKV(Tensor &tensor);
|
||||
|
||||
void freeHeap();
|
||||
|
||||
/**
|
||||
* @brief Add an operator and create its outputs. Output tensor arguments
|
||||
|
|
|
@ -81,6 +81,7 @@ class GraphHandlerObj {
|
|||
Tensor cast(Tensor input, Tensor output, int to);
|
||||
Tensor expand(Tensor input, Tensor output, Shape dims);
|
||||
Tensor where(Tensor inputX, Tensor inputY, Tensor condition, Tensor output);
|
||||
std::vector<int> getDims(Tensor x) { return x->getDims(); }
|
||||
|
||||
Tensor allReduceSum(Tensor input, Tensor output);
|
||||
Tensor allReduceProd(Tensor input, Tensor output);
|
||||
|
@ -98,9 +99,19 @@ class GraphHandlerObj {
|
|||
|
||||
inline void optimize() { g->optimize(); }
|
||||
|
||||
inline void shape_infer() { g->shape_infer(); }
|
||||
|
||||
void change_shape(const vector<int> &shape, int tensorId);
|
||||
//------ runtime
|
||||
|
||||
inline void data_malloc() { g->dataMalloc(); }
|
||||
inline void data_malloc(bool useNaiveAllocator = false,
|
||||
size_t memPoolSize = 0) {
|
||||
g->dataMalloc(useNaiveAllocator, memPoolSize);
|
||||
}
|
||||
|
||||
inline Tensor clone_KV(Tensor &tensor) { return g->cloneKV(tensor); }
|
||||
|
||||
inline void free_heap() { g->freeHeap(); }
|
||||
|
||||
inline void tune() { g->getRuntime()->run(g, true); }
|
||||
|
||||
|
|
|
@ -26,14 +26,23 @@ class LazyAllocator {
|
|||
|
||||
size_t weightPeak = 0;
|
||||
|
||||
size_t heapPeak = 0;
|
||||
|
||||
size_t alignment;
|
||||
|
||||
bool hasMemPool = false;
|
||||
|
||||
size_t memPoolSize = 0;
|
||||
|
||||
// pointer to the memory actually allocated
|
||||
void *ptr = nullptr;
|
||||
|
||||
// pointer to the weight memory space
|
||||
void *weightPtr = nullptr;
|
||||
|
||||
// memory pool ptr
|
||||
void *memPoolPtr = nullptr;
|
||||
|
||||
// // a cache designed for a batch size that has already occurred
|
||||
// std::unordered_map<size_t, std::unordered_map<TensorObj *, size_t>>
|
||||
// batchsizeToTensorOffset;
|
||||
|
@ -68,6 +77,10 @@ class LazyAllocator {
|
|||
|
||||
void init();
|
||||
|
||||
void setMemPool(size_t memPoolSize);
|
||||
|
||||
bool getMemPoolStatus();
|
||||
|
||||
// function: simulate memory allocation
|
||||
// arguments:
|
||||
// size: size of memory block to be allocated
|
||||
|
@ -76,6 +89,10 @@ class LazyAllocator {
|
|||
|
||||
size_t allocWeight(size_t size);
|
||||
|
||||
size_t heapAlloc(size_t size);
|
||||
|
||||
void freeHeap();
|
||||
|
||||
// function: simulate memory free
|
||||
// arguments:
|
||||
// addr: head address offset of memory block to be free
|
||||
|
@ -92,6 +109,8 @@ class LazyAllocator {
|
|||
|
||||
void *getWeightPtr();
|
||||
|
||||
void *getHeapPtr();
|
||||
|
||||
void info();
|
||||
|
||||
private:
|
||||
|
|
|
@ -55,8 +55,7 @@ class OperatorObj : public Object {
|
|||
|
||||
public:
|
||||
OperatorObj(OpType opType, TensorVec inputs, TensorVec outputs);
|
||||
virtual optional<vector<Shape>>
|
||||
inferShape(const TensorVec &inputs) const = 0;
|
||||
virtual optional<vector<Shape>> inferShape(const TensorVec &inputs) = 0;
|
||||
virtual vector<DataType> inferDataType(const TensorVec &inputs) const;
|
||||
/**
|
||||
* @brief Constructs outputs (if requried) and check whether the operator is
|
||||
|
@ -105,7 +104,7 @@ class OperatorObj : public Object {
|
|||
const TensorVec &newOutputs) const = 0;
|
||||
|
||||
protected:
|
||||
optional<vector<Shape>> inferShape() const;
|
||||
optional<vector<Shape>> inferShape();
|
||||
vector<DataType> inferDataType() const;
|
||||
|
||||
private:
|
||||
|
|
|
@ -31,6 +31,7 @@ class TensorObj : public TensorBaseObj {
|
|||
size_t getBytes() const { return _size * dtype.getSize(); }
|
||||
|
||||
Shape getDims() const { return shape; }
|
||||
void setShape(Shape shape_);
|
||||
size_t getRank() const { return shape.size(); }
|
||||
Shape getStride() const;
|
||||
size_t getOffset(const vector<int> &ds) const;
|
||||
|
|
|
@ -1,8 +1,13 @@
|
|||
#pragma once
|
||||
|
||||
namespace infini {
|
||||
void div_kernel(float *a, float *b, float *c, int a0, int a1, int a2, int a3,
|
||||
void div_kernel(void *a, void *b, void *c, int a0, int a1, int a2, int a3,
|
||||
int b0, int b1, int b2, int b3, int c0, int c1, int c2, int c3);
|
||||
void pow_kernel(float *a, float *b, float *c, int a0, int a1, int a2, int a3,
|
||||
void add_kernel(void *a, void *b, void *c, int a0, int a1, int a2, int a3,
|
||||
int b0, int b1, int b2, int b3, int c0, int c1, int c2, int c3);
|
||||
void pow_kernel(void *a, void *b, void *c, int a0, int a1, int a2, int a3,
|
||||
int b0, int b1, int b2, int b3, int c0, int c1, int c2, int c3);
|
||||
void less_kernel(void *a, void *b, void *c, int a0, int a1, int a2, int a3,
|
||||
int b0, int b1, int b2, int b3, int c0, int c1, int c2,
|
||||
int c3);
|
||||
}; // namespace infini
|
||||
|
|
|
@ -10,10 +10,11 @@ typedef struct {
|
|||
int wholeNDim[MAX_DIM]; // dim size after padding or before slicing
|
||||
int partNDim[MAX_DIM]; // dim size before padding or after slicing
|
||||
int partStride[MAX_DIM]; // stride before padding or after slicing
|
||||
int DType;
|
||||
} TransMetaData;
|
||||
|
||||
namespace infini {
|
||||
void pad_slice_kernel(float *partData, float *wholeData,
|
||||
void pad_slice_kernel(void *partData, void *wholeData,
|
||||
const TransMetaData &metadata, int nDims, int num,
|
||||
bool isPad);
|
||||
} // namespace infini
|
||||
|
|
|
@ -35,7 +35,7 @@ class G2BMMObj : public OperatorObj {
|
|||
OP_CLONE(G2BMMObj);
|
||||
|
||||
std::string toString() const override;
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
|
||||
|
||||
int numInputs() const override { return 2; }
|
||||
int numOutputs() const override { return 1; }
|
||||
|
|
|
@ -33,7 +33,7 @@ class GBMMObj : public OperatorObj {
|
|||
OP_CLONE(GBMMObj);
|
||||
|
||||
std::string toString() const override;
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
|
||||
|
||||
int numInputs() const override { return 2; }
|
||||
int numOutputs() const override { return 1; }
|
||||
|
|
|
@ -7,7 +7,7 @@ class ActivationBackwardObj : public OperatorObj {
|
|||
ActivationBackwardObj(OpType type, GraphObj *graph, Tensor y, Tensor diff_y,
|
||||
Tensor x, Tensor diff_x);
|
||||
OP_CLONE(ActivationBackwardObj);
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
|
||||
|
||||
std::string toString() const override;
|
||||
int numInputs() const override { return 3; }
|
||||
|
|
|
@ -27,7 +27,7 @@ class AllGatherObj : public OperatorObj {
|
|||
|
||||
int numInputs() const override { return 1; }
|
||||
int numOutputs() const override { return world_size; }
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
|
||||
|
||||
std::string toString() const override;
|
||||
|
||||
|
|
|
@ -33,7 +33,7 @@ class AllReduceBaseObj : public OperatorObj {
|
|||
int numInputs() const override { return 1; }
|
||||
int numOutputs() const override { return 1; }
|
||||
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override {
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) override {
|
||||
return {{inputs[0]->getDims()}};
|
||||
};
|
||||
|
||||
|
|
|
@ -29,7 +29,7 @@ class AttentionKVCacheObj : public OperatorObj {
|
|||
Tensor output_matmul);
|
||||
OP_CLONE(AttentionKVCacheObj);
|
||||
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
|
||||
|
||||
std::string toString() const override;
|
||||
int numInputs() const override { return 6; }
|
||||
|
|
|
@ -34,7 +34,7 @@ class BatchNormObj : public OperatorObj {
|
|||
Tensor var, Tensor scale, Tensor bias, float momentum = 0.9,
|
||||
float eps = 1e-5, bool trainingMode = false);
|
||||
OP_CLONE(BatchNormObj);
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
|
||||
std::string toString() const override;
|
||||
|
||||
// output size will be 3 when training
|
||||
|
|
|
@ -26,7 +26,7 @@ class BroadcastObj : public OperatorObj {
|
|||
int numInputs() const override { return 1; }
|
||||
int numOutputs() const override { return 1; }
|
||||
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override {
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) override {
|
||||
return {{inputs[0]->getDims()}};
|
||||
};
|
||||
|
||||
|
|
|
@ -22,7 +22,7 @@ class ConcatObj : public OperatorObj {
|
|||
ConcatObj(GraphObj *graph, TensorVec inputs, Tensor output, int dim);
|
||||
OP_CLONE(ConcatObj);
|
||||
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
|
||||
|
||||
std::string toString() const override;
|
||||
int numInputs() const override { return inputs.size(); }
|
||||
|
|
|
@ -142,7 +142,7 @@ class ConvObj : public ConvBaseObj {
|
|||
ActType act = ActType::None);
|
||||
OP_CLONE(ConvObj);
|
||||
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
|
||||
int getNumGroups() const override { return c / getChannelPerGroup(); }
|
||||
|
||||
private:
|
||||
|
@ -164,7 +164,7 @@ class ConvBackwardFilterObj : public ConvBaseObj {
|
|||
int sh = 1, int sw = 1, int dh = 1, int dw = 1,
|
||||
Tensor bias = nullptr, ActType act = ActType::None);
|
||||
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
|
||||
ActType getAct() const { return act; }
|
||||
int getNumGroups() const override { return c / getChannelPerGroup(); }
|
||||
|
||||
|
@ -191,7 +191,7 @@ class ConvTransposed2dObj : public ConvBaseObj {
|
|||
Tensor bias = nullptr, ActType act = ActType::None);
|
||||
OP_CLONE(ConvTransposed2dObj);
|
||||
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
|
||||
int getNumGroups() const override { return group; }
|
||||
std::pair<int, int> getOutputPadding() const { return {oph, opw}; }
|
||||
|
||||
|
@ -218,7 +218,7 @@ class ConvTransposed2dNHWCObj : public ConvBaseObj {
|
|||
Tensor bias = nullptr, ActType act = ActType::None);
|
||||
OP_CLONE(ConvTransposed2dNHWCObj);
|
||||
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
|
||||
int getNumGroups() const override { return group; }
|
||||
|
||||
private:
|
||||
|
|
|
@ -7,7 +7,7 @@ class DetObj : public OperatorObj {
|
|||
enum Mode { NormalDet = 0, LogDet };
|
||||
DetObj(GraphObj *graph, Tensor input, Tensor output, Mode mode);
|
||||
OP_CLONE(DetObj);
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
|
||||
|
||||
std::string toString() const override;
|
||||
int numInputs() const override { return 1; }
|
||||
|
|
|
@ -37,7 +37,7 @@ class DropoutObj : public OperatorObj {
|
|||
DropoutObj(GraphObj *graph, Tensor data, Tensor output, Tensor mask,
|
||||
float ratio, bool training_mode);
|
||||
OP_CLONE(DropoutObj);
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
|
||||
|
||||
std::string toString() const override;
|
||||
int numInputs() const override { return 1; }
|
||||
|
|
|
@ -21,7 +21,7 @@ class ElementWiseObj : public OperatorObj {
|
|||
*/
|
||||
ElementWiseObj(OpType type, GraphObj *graph, Tensor input0, Tensor input1,
|
||||
Tensor output);
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
|
||||
|
||||
std::string toString() const override;
|
||||
int numInputs() const override { return 2; }
|
||||
|
@ -38,7 +38,7 @@ class MSELossObj : public OperatorObj {
|
|||
MSELossObj(GraphObj *graph, Tensor input0, Tensor input1,
|
||||
Reduction reduction, Tensor output);
|
||||
OP_CLONE(MSELossObj);
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
|
||||
|
||||
Reduction getReduction() const { return reductionMode; }
|
||||
std::string toString() const override;
|
||||
|
|
|
@ -21,7 +21,7 @@ class ExpandObj : public OperatorObj {
|
|||
*/
|
||||
ExpandObj(GraphObj *graph, Tensor input, Tensor output, Shape dims);
|
||||
OP_CLONE(ExpandObj);
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
|
||||
|
||||
std::string toString() const override;
|
||||
int numInputs() const override { return 1; }
|
||||
|
|
|
@ -23,7 +23,7 @@ class ExtendObj : public OperatorObj {
|
|||
ExtendObj(GraphObj *graph, Tensor input, Tensor output, int dim,
|
||||
int num = 1);
|
||||
OP_CLONE(ExtendObj);
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
|
||||
|
||||
std::string toString() const override;
|
||||
int numInputs() const override { return 1; }
|
||||
|
|
|
@ -39,7 +39,7 @@ class GatherObj : public GatherBaseObj {
|
|||
int axis);
|
||||
OP_CLONE(GatherObj);
|
||||
std::string toString() const override;
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
|
||||
vector<DataType> inferDataType(const TensorVec &inputs) const override;
|
||||
|
||||
private:
|
||||
|
@ -69,7 +69,7 @@ class GatherElementsObj : public GatherBaseObj {
|
|||
Tensor output, int axis);
|
||||
OP_CLONE(GatherElementsObj);
|
||||
std::string toString() const override;
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
|
||||
vector<DataType> inferDataType(const TensorVec &inputs) const override;
|
||||
|
||||
private:
|
||||
|
|
|
@ -45,7 +45,7 @@ class MatmulObj : public OperatorObj {
|
|||
OP_CLONE(MatmulObj);
|
||||
|
||||
std::string toString() const override;
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
|
||||
|
||||
int numInputs() const override { return inputs.size(); }
|
||||
int numOutputs() const override { return 1; }
|
||||
|
|
|
@ -21,7 +21,7 @@ class MemBoundObj : public OperatorObj {
|
|||
OP_CLONE(MemBoundObj);
|
||||
|
||||
std::string toString() const override;
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
|
||||
|
||||
int numInputs() const override { return inputs.size(); }
|
||||
int numOutputs() const override { return outputs.size(); }
|
||||
|
|
|
@ -27,7 +27,7 @@ class PadObj : public OperatorObj {
|
|||
const vector<int> &pads, const optional<vector<int>> &axes);
|
||||
OP_CLONE(PadObj);
|
||||
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
|
||||
std::string toString() const override;
|
||||
int numInputs() const override { return 1; }
|
||||
int numOutputs() const override { return 1; }
|
||||
|
|
|
@ -41,7 +41,7 @@ class PoolingObj : public OperatorObj {
|
|||
int ceilMode);
|
||||
OP_CLONE(PoolingObj);
|
||||
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
|
||||
std::string toString() const override;
|
||||
int numInputs() const override { return 1; }
|
||||
int numOutputs() const override { return 1; }
|
||||
|
|
|
@ -23,7 +23,7 @@ class ReduceMeanObj : public OperatorObj {
|
|||
ReduceMeanObj(GraphObj *graph, Tensor input, Tensor output,
|
||||
const optional<vector<int>> &axes, bool keepDims = true);
|
||||
OP_CLONE(ReduceMeanObj);
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
|
||||
|
||||
std::string toString() const override;
|
||||
int numInputs() const override { return 1; }
|
||||
|
|
|
@ -9,6 +9,7 @@ namespace infini {
|
|||
*/
|
||||
class ReshapeObj : public OperatorObj {
|
||||
Shape dims;
|
||||
Shape outputShape;
|
||||
|
||||
public:
|
||||
/**
|
||||
|
@ -17,18 +18,20 @@ class ReshapeObj : public OperatorObj {
|
|||
* @param graph The computation graph that this operator belongs to.
|
||||
* @param input The input tensor.
|
||||
* @param output The output tensor.
|
||||
* @param dims The shape of the output tensor.
|
||||
* @param dims The shape to infer the output shape.
|
||||
* @param outputShape The real shape of output tensor.
|
||||
*/
|
||||
ReshapeObj(GraphObj *graph, Tensor input, Tensor output, Shape dims);
|
||||
OP_CLONE(ReshapeObj);
|
||||
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
|
||||
|
||||
std::string toString() const override;
|
||||
int numInputs() const override { return 1; }
|
||||
int numOutputs() const override { return 1; }
|
||||
|
||||
inline Shape getShape() const { return dims; }
|
||||
inline Shape getShape() const { return outputShape; }
|
||||
inline Shape getDims() const { return dims; }
|
||||
|
||||
private:
|
||||
vector<int> getWorkloadVector() const override;
|
||||
|
@ -55,7 +58,7 @@ class FlattenObj : public OperatorObj {
|
|||
FlattenObj(GraphObj *graph, Tensor input, Tensor output, int axis);
|
||||
OP_CLONE(FlattenObj);
|
||||
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
|
||||
|
||||
std::string toString() const override;
|
||||
int numInputs() const override { return 1; }
|
||||
|
@ -85,7 +88,7 @@ class IdentityObj : public OperatorObj {
|
|||
IdentityObj(GraphObj *graph, Tensor input, Tensor output);
|
||||
OP_CLONE(IdentityObj);
|
||||
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
|
||||
|
||||
std::string toString() const override;
|
||||
int numInputs() const override { return 1; }
|
||||
|
|
|
@ -60,7 +60,7 @@ class ResizeObj : public OperatorObj {
|
|||
|
||||
// Operator clone(TensorVec inputs, TensorVec outputs) override;
|
||||
vector<DataType> inferDataType(const TensorVec &inputs) const override;
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
|
||||
std::string toString() const override;
|
||||
int numInputs() const override { return inputs.size(); }
|
||||
int numOutputs() const override { return 1; }
|
||||
|
|
|
@ -32,7 +32,7 @@ class SliceObj : public OperatorObj {
|
|||
const optional<vector<int>> &steps);
|
||||
OP_CLONE(SliceObj);
|
||||
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
|
||||
std::string toString() const override;
|
||||
inline int numInputs() const override { return 1; }
|
||||
inline int numOutputs() const override { return 1; }
|
||||
|
|
|
@ -10,7 +10,7 @@ class SoftmaxObj : public OperatorObj {
|
|||
|
||||
OP_CLONE(SoftmaxObj);
|
||||
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override {
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) override {
|
||||
return {{inputs[0]->getDims()}};
|
||||
};
|
||||
|
||||
|
|
|
@ -37,7 +37,7 @@ class SplitObj : public OperatorObj {
|
|||
int dim, const vector<int> &ratio);
|
||||
OP_CLONE(SplitObj);
|
||||
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
|
||||
|
||||
std::string toString() const override;
|
||||
int numInputs() const override { return 1; }
|
||||
|
|
|
@ -7,7 +7,7 @@ class TransposeObj : public OperatorObj {
|
|||
TransposeObj(GraphObj *graph, Tensor input, Tensor output,
|
||||
vector<int> permute);
|
||||
OP_CLONE(TransposeObj);
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
|
||||
|
||||
std::string toString() const override;
|
||||
int numInputs() const override { return 1; }
|
||||
|
@ -25,7 +25,7 @@ class DepthToSpaceObj : public OperatorObj {
|
|||
DepthToSpaceObj(GraphObj *graph, Tensor input, Tensor output, int blocksize,
|
||||
std::string mode);
|
||||
OP_CLONE(DepthToSpaceObj);
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
|
||||
|
||||
std::string toString() const override;
|
||||
int numInputs() const override { return 1; }
|
||||
|
|
|
@ -17,7 +17,7 @@ class UnaryObj : public OperatorObj {
|
|||
* @param output The output tensor.
|
||||
*/
|
||||
UnaryObj(OpType type, GraphObj *graph, Tensor input, Tensor output);
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
|
||||
|
||||
std::string toString() const override;
|
||||
int numInputs() const override { return 1; }
|
||||
|
@ -33,7 +33,7 @@ class ClipObj : public OperatorObj {
|
|||
ClipObj(GraphObj *graph, Tensor input, Tensor output,
|
||||
std::optional<float> min, std::optional<float> max);
|
||||
OP_CLONE(ClipObj);
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
|
||||
|
||||
std::string toString() const override;
|
||||
std::optional<float> getMin() const { return minValue; };
|
||||
|
@ -52,7 +52,7 @@ class HardtanhObj : public OperatorObj {
|
|||
HardtanhObj(GraphObj *graph, Tensor input, Tensor output, float min,
|
||||
float max);
|
||||
OP_CLONE(HardtanhObj);
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
|
||||
|
||||
std::string toString() const override;
|
||||
float getMin() const { return minValue; };
|
||||
|
@ -70,7 +70,7 @@ class FlipObj : public OperatorObj {
|
|||
public:
|
||||
FlipObj(GraphObj *graph, Tensor input, Tensor output, vector<int> axis);
|
||||
OP_CLONE(FlipObj);
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
|
||||
|
||||
std::string toString() const override;
|
||||
vector<int> getAxis() const { return axisValue; };
|
||||
|
@ -87,7 +87,7 @@ class FillObj : public OperatorObj {
|
|||
public:
|
||||
FillObj(GraphObj *graph, Tensor input, Tensor output, float value);
|
||||
OP_CLONE(FillObj);
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
|
||||
|
||||
std::string toString() const override;
|
||||
float getValue() const { return setValue; };
|
||||
|
@ -104,7 +104,7 @@ class L2LossObj : public OperatorObj {
|
|||
public:
|
||||
L2LossObj(GraphObj *graph, Tensor input, Tensor output);
|
||||
OP_CLONE(L2LossObj);
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
|
||||
|
||||
std::string toString() const override;
|
||||
int numInputs() const override { return 1; }
|
||||
|
@ -120,7 +120,7 @@ class TransformObj : public OperatorObj {
|
|||
TransformObj(GraphObj *graph, Tensor input, Tensor output, float alpha,
|
||||
float beta);
|
||||
OP_CLONE(TransformObj);
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
|
||||
|
||||
std::string toString() const override;
|
||||
float getAlpha() const { return alphaValue; }
|
||||
|
@ -165,7 +165,7 @@ class CastObj : public OperatorObj {
|
|||
public:
|
||||
CastObj(GraphObj *graph, Tensor input, Tensor output, CastType type);
|
||||
OP_CLONE(CastObj);
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
|
||||
vector<DataType> inferDataType(const TensorVec &inputs) const override;
|
||||
|
||||
std::string toString() const override;
|
||||
|
@ -185,7 +185,7 @@ class CumsumObj : public OperatorObj {
|
|||
CumsumObj(GraphObj *graph, Tensor input, Tensor output, int axis,
|
||||
bool exclusive, bool reverse);
|
||||
OP_CLONE(CumsumObj);
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
|
||||
|
||||
std::string toString() const override;
|
||||
int getAxis() const { return axisValue; }
|
||||
|
@ -205,7 +205,7 @@ class ShapeObj : public OperatorObj {
|
|||
public:
|
||||
ShapeObj(GraphObj *graph, Tensor input, Tensor output);
|
||||
OP_CLONE(ShapeObj);
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
|
||||
|
||||
std::string toString() const override;
|
||||
int numInputs() const override { return 1; }
|
||||
|
@ -216,7 +216,7 @@ class PReluObj : public OperatorObj {
|
|||
public:
|
||||
PReluObj(GraphObj *graph, Tensor input, Tensor alpha, Tensor output);
|
||||
OP_CLONE(PReluObj);
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
|
||||
|
||||
std::string toString() const override;
|
||||
int numInputs() const override { return 2; }
|
||||
|
@ -236,7 +236,7 @@ class LogObj : public OperatorObj {
|
|||
};
|
||||
LogObj(GraphObj *graph, Tensor input, Tensor output, LogType type);
|
||||
OP_CLONE(LogObj);
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
|
||||
|
||||
std::string toString() const override;
|
||||
LogType getType() const { return logType; }
|
||||
|
|
|
@ -22,7 +22,7 @@ class WhereObj : public OperatorObj {
|
|||
Tensor output);
|
||||
OP_CLONE(WhereObj);
|
||||
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
|
||||
|
||||
std::string toString() const override;
|
||||
int numInputs() const override { return inputs.size(); }
|
||||
|
|
|
@ -510,19 +510,11 @@ class OnnxStub:
|
|||
mode,
|
||||
)
|
||||
elif node.op_type == "Reshape":
|
||||
dims = _search_shape(model, node.input[0])
|
||||
size = reduce(lambda acc, x: acc * x, dims)
|
||||
input_shape = _parse_data(data[node.input[1]])
|
||||
for i, x in enumerate(input_shape):
|
||||
if x == 0:
|
||||
input_shape[i] = dims[i]
|
||||
temp = reduce(lambda acc, x: acc * x, input_shape, 1)
|
||||
if temp < 0:
|
||||
input_shape[input_shape.index(-1)] = size // -temp
|
||||
shape = _parse_data(data[node.input[1]])
|
||||
tensors[node.output[0]] = self.handler.reshape(
|
||||
tensors[node.input[0]],
|
||||
tensors.get(node.output[0]),
|
||||
input_shape,
|
||||
shape,
|
||||
)
|
||||
elif node.op_type == "Squeeze":
|
||||
input_shape = _search_shape(model, node.input[0])
|
||||
|
@ -1112,6 +1104,26 @@ class OnnxStub:
|
|||
def optimize(self) -> None:
|
||||
self.handler.optimize()
|
||||
|
||||
def clone_KV(self, tensor: backend.Tensor) -> backend.Tensor:
|
||||
return self.handler.clone_KV(tensor)
|
||||
|
||||
def free_heap(self) -> None:
|
||||
self.handler.free_heap()
|
||||
|
||||
def set_input(self, inputShapes: List[int]) -> None:
|
||||
for newInput, oldInput in zip(inputShapes, self.inputs):
|
||||
oldTensor = self.inputs[oldInput]
|
||||
self.handler.change_shape(newInput, oldTensor.fuid())
|
||||
self.handler.shape_infer()
|
||||
self.handler.data_malloc()
|
||||
|
||||
def getShape(self, name: str) -> List[int]:
|
||||
if name in self.inputs:
|
||||
ans = self.handler.getDims(self.inputs[name])
|
||||
else:
|
||||
ans = self.handler.getDims(self.outputs[name])
|
||||
return ans
|
||||
|
||||
def tune(self) -> None:
|
||||
self.handler.tune()
|
||||
|
||||
|
|
|
@ -209,6 +209,7 @@ class TestStringMethods(unittest.TestCase):
|
|||
make_and_import_model(make_graph([relu], "relu", [x], [y]))
|
||||
|
||||
"""Gelu operator is not supported by onnx 14.1 currently."""
|
||||
|
||||
def test_gelu(self):
|
||||
pass
|
||||
# x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 3, 5, 7])
|
||||
|
@ -500,5 +501,22 @@ class TestStringMethods(unittest.TestCase):
|
|||
make_and_import_model(make_graph([where], "where", [x, y, con], [output]))
|
||||
|
||||
|
||||
class TestDynamicTensor(unittest.TestCase):
|
||||
def test_dynamic_tensor(self):
|
||||
filename = r"resnet18-v2-7.onnx"
|
||||
current_path = os.getcwd()
|
||||
model_file = ""
|
||||
for root, dirs, files in os.walk(current_path):
|
||||
if filename in files:
|
||||
model_file = os.path.join(root, filename)
|
||||
model = OnnxStub(onnx.load(model_file), backend.cpu_runtime())
|
||||
output_key = list(model.outputs.keys())[0]
|
||||
old_output_shape = model.getShape(output_key)
|
||||
self.assertEqual(old_output_shape, ([1, 1000]))
|
||||
model.set_input([[5, 3, 224, 224]])
|
||||
new_output_shape = model.getShape(output_key)
|
||||
self.assertEqual(new_output_shape, ([5, 1000]))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
#include "core/graph.h"
|
||||
#include "operators/reshape.h"
|
||||
#include <algorithm>
|
||||
#include <numeric>
|
||||
#include <queue>
|
||||
|
||||
namespace infini {
|
||||
|
@ -123,10 +125,40 @@ void GraphObj::optimize() {
|
|||
}
|
||||
}
|
||||
|
||||
void GraphObj::dataMalloc(bool useNaiveAllocator) {
|
||||
Tensor GraphObj::getTensor(int fuid) const {
|
||||
for (auto tensor : tensors) {
|
||||
if (tensor->getFuid() == fuid) {
|
||||
return tensor;
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void GraphObj::shape_infer() {
|
||||
for (auto &op : ops) {
|
||||
auto ans = op->inferShape();
|
||||
IT_ASSERT(ans.has_value());
|
||||
auto oldOutputs = op->getOutputs();
|
||||
IT_ASSERT(ans.value().size() == oldOutputs.size());
|
||||
// replace the old outputshape and size with new one
|
||||
for (int i = 0; i < (int)ans.value().size(); ++i) {
|
||||
auto newShape = ans.value()[i];
|
||||
auto oldShape = oldOutputs[i]->getDims();
|
||||
auto fuid = oldOutputs[i]->getFuid();
|
||||
if (newShape != oldShape) {
|
||||
auto tensor = this->getTensor(fuid);
|
||||
tensor->setShape(newShape);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void GraphObj::dataMalloc(bool useNaiveAllocator, size_t memPoolSize) {
|
||||
// topological sorting first
|
||||
IT_ASSERT(topo_sort() == true);
|
||||
if (useNaiveAllocator) {
|
||||
// can not set memory pool when use naive allocator
|
||||
IT_ASSERT(memPoolSize == 0);
|
||||
// used for debugging memory out-of-bounds access, tensors will not be
|
||||
// released correctly
|
||||
// note: behavior may not match running in non-naive mode, and it may
|
||||
|
@ -136,6 +168,9 @@ void GraphObj::dataMalloc(bool useNaiveAllocator) {
|
|||
}
|
||||
return;
|
||||
}
|
||||
if (memPoolSize > 0) {
|
||||
allocator.setMemPool(memPoolSize);
|
||||
}
|
||||
// count the number of times all tensors are used
|
||||
std::unordered_map<TensorObj *, size_t> tensorToRefCount;
|
||||
// record the memory address offsets of all tensors to be allocated
|
||||
|
@ -222,6 +257,27 @@ void GraphObj::dataMalloc(bool useNaiveAllocator) {
|
|||
}
|
||||
}
|
||||
|
||||
Tensor GraphObj::cloneKV(Tensor &tensor) {
|
||||
auto obj = tensor->clone();
|
||||
if (allocator.getMemPoolStatus()) {
|
||||
if (tensor->hasData()) {
|
||||
obj->setDataBlob(make_ref<BlobObj>(
|
||||
tensor->runtime,
|
||||
static_cast<uint8_t *>(allocator.getHeapPtr()) +
|
||||
allocator.heapAlloc(tensor->getBytes())));
|
||||
obj->copyData(tensor);
|
||||
}
|
||||
} else {
|
||||
if (tensor->hasData()) {
|
||||
obj->dataMalloc();
|
||||
obj->copyData(tensor);
|
||||
}
|
||||
}
|
||||
return obj;
|
||||
}
|
||||
|
||||
void GraphObj::freeHeap() { this->allocator.freeHeap(); }
|
||||
|
||||
Tensor GraphObj::addTensor(Shape dim, DataType dtype) {
|
||||
return tensors.emplace_back(make_ref<TensorObj>(dim, dtype, runtime));
|
||||
}
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
#include "operators/transpose.h"
|
||||
#include "operators/unary.h"
|
||||
#include "operators/where.h"
|
||||
#include <numeric>
|
||||
|
||||
namespace infini {
|
||||
|
||||
|
@ -555,4 +556,11 @@ static DataType dtype_repr_convert(int dtype) {
|
|||
}
|
||||
}
|
||||
|
||||
void GraphHandlerObj::change_shape(const vector<int> &shape, int tensorId) {
|
||||
auto tensor = g->getTensor(tensorId);
|
||||
IT_ASSERT(tensor != nullptr);
|
||||
IT_ASSERT(shape.size() != 0);
|
||||
tensor->setShape(shape);
|
||||
}
|
||||
|
||||
} // namespace infini
|
||||
|
|
|
@ -30,6 +30,9 @@ LazyAllocator::~LazyAllocator() {
|
|||
if (this->weightPtr != nullptr) {
|
||||
runtime->dealloc(this->weightPtr);
|
||||
}
|
||||
if (this->memPoolPtr != nullptr) {
|
||||
runtime->dealloc(this->memPoolPtr);
|
||||
}
|
||||
}
|
||||
|
||||
void LazyAllocator::init() {
|
||||
|
@ -44,6 +47,17 @@ void LazyAllocator::init() {
|
|||
this->ptr = nullptr;
|
||||
}
|
||||
|
||||
void LazyAllocator::setMemPool(size_t memPoolSize) {
|
||||
IT_ASSERT(memPoolSize > 0);
|
||||
if (!this->hasMemPool) {
|
||||
this->hasMemPool = true;
|
||||
this->memPoolSize = memPoolSize;
|
||||
this->memPoolPtr = runtime->alloc(memPoolSize);
|
||||
}
|
||||
}
|
||||
|
||||
bool LazyAllocator::getMemPoolStatus() { return this->hasMemPool; }
|
||||
|
||||
size_t LazyAllocator::alloc(size_t size) {
|
||||
// pad the size to the multiple of alignment
|
||||
size = this->getAlignedSize(size);
|
||||
|
@ -102,6 +116,17 @@ size_t LazyAllocator::allocWeight(size_t size) {
|
|||
return retAddr;
|
||||
}
|
||||
|
||||
size_t LazyAllocator::heapAlloc(size_t size) {
|
||||
size = this->getAlignedSize(size);
|
||||
this->heapPeak += size;
|
||||
IT_ASSERT(this->memPoolSize >=
|
||||
this->weightPeak + this->peak + this->heapPeak);
|
||||
size_t retAddr = this->memPoolSize - this->heapPeak;
|
||||
return retAddr;
|
||||
}
|
||||
|
||||
void LazyAllocator::freeHeap() { this->heapPeak = 0; }
|
||||
|
||||
void LazyAllocator::free(size_t addr, size_t size) {
|
||||
IT_ASSERT(this->ptr == nullptr);
|
||||
size = getAlignedSize(size);
|
||||
|
@ -143,25 +168,40 @@ void LazyAllocator::free(size_t addr, size_t size) {
|
|||
}
|
||||
|
||||
void *LazyAllocator::getPtr() {
|
||||
if (this->ptr == nullptr) {
|
||||
this->ptr = runtime->alloc(this->peak);
|
||||
// #ifdef DEBUG_MODE
|
||||
// printf("LazyAllocator really alloc non-weight: %p %lu
|
||||
// bytes\n", this->ptr, peak);
|
||||
// #endif
|
||||
if (!hasMemPool) {
|
||||
if (this->ptr == nullptr) {
|
||||
this->ptr = runtime->alloc(this->peak);
|
||||
// #ifdef DEBUG_MODE
|
||||
// printf("LazyAllocator really alloc non-weight: %p %lu
|
||||
// bytes\n", this->ptr, peak);
|
||||
// #endif
|
||||
}
|
||||
return this->ptr;
|
||||
} else {
|
||||
IT_ASSERT(this->memPoolSize >= this->weightPeak + this->peak);
|
||||
return static_cast<uint8_t *>(this->memPoolPtr) + weightPeak;
|
||||
}
|
||||
return this->ptr;
|
||||
}
|
||||
|
||||
void *LazyAllocator::getWeightPtr() {
|
||||
if (this->weightPtr == nullptr) {
|
||||
this->weightPtr = runtime->alloc(this->weightPeak);
|
||||
// #ifdef DEBUG_MODE
|
||||
// printf("LazyAllocator really alloc weight: %p %lu bytes\n",
|
||||
// this->weightPtr, weightPeak);
|
||||
// #endif
|
||||
if (!hasMemPool) {
|
||||
if (this->weightPtr == nullptr) {
|
||||
this->weightPtr = runtime->alloc(this->weightPeak);
|
||||
// #ifdef DEBUG_MODE
|
||||
// printf("LazyAllocator really alloc weight: %p %lu
|
||||
// bytes\n",
|
||||
// this->weightPtr, weightPeak);
|
||||
// #endif
|
||||
}
|
||||
return this->weightPtr;
|
||||
} else {
|
||||
return this->memPoolPtr;
|
||||
}
|
||||
return this->weightPtr;
|
||||
}
|
||||
|
||||
void *LazyAllocator::getHeapPtr() {
|
||||
IT_ASSERT(hasMemPool);
|
||||
return this->memPoolPtr;
|
||||
}
|
||||
|
||||
size_t LazyAllocator::getAlignedSize(size_t size) {
|
||||
|
|
|
@ -77,9 +77,7 @@ bool OperatorObj::checkValid(GraphObj *graph) {
|
|||
return true;
|
||||
}
|
||||
|
||||
optional<vector<Shape>> OperatorObj::inferShape() const {
|
||||
return inferShape(inputs);
|
||||
}
|
||||
optional<vector<Shape>> OperatorObj::inferShape() { return inferShape(inputs); }
|
||||
|
||||
vector<DataType> OperatorObj::inferDataType(const TensorVec &inputs) const {
|
||||
auto dataType = inputs[0]->getDType();
|
||||
|
|
|
@ -59,6 +59,13 @@ Shape TensorObj::getStride() const {
|
|||
return stride;
|
||||
}
|
||||
|
||||
void TensorObj::setShape(Shape shape_) {
|
||||
shape = shape_;
|
||||
size_t size = std::accumulate(shape.begin(), shape.end(), 1,
|
||||
[](auto acc, auto x) { return acc * x; });
|
||||
_size = size;
|
||||
}
|
||||
|
||||
void TensorObj::printData() const {
|
||||
IT_ASSERT(data != nullptr);
|
||||
if (!runtime->isCpu())
|
||||
|
|
|
@ -446,7 +446,10 @@ void init_graph_builder(py::module &m) {
|
|||
})
|
||||
.def("has_target", &TensorObj::hasTarget, policy::automatic)
|
||||
.def("src", &TensorObj::getSource, policy::move)
|
||||
.def("printData", &TensorObj::printData, policy::automatic);
|
||||
.def("printData", &TensorObj::printData, policy::automatic)
|
||||
.def("copy_data",
|
||||
py::overload_cast<const Tensor &>(&TensorObj::copyData),
|
||||
policy::move);
|
||||
py::class_<OperatorObj, std::shared_ptr<OperatorObj>>(m, "Operator")
|
||||
.def("op_type", &OperatorObj::getOpType, policy::automatic)
|
||||
.def("inputs", py::overload_cast<>(&OperatorObj::getInputs, py::const_),
|
||||
|
@ -466,6 +469,7 @@ void init_graph_builder(py::module &m) {
|
|||
.def("add", &Handler::add, policy::move)
|
||||
.def("sub", &Handler::sub, policy::move)
|
||||
.def("mul", &Handler::mul, policy::move)
|
||||
.def("max", &Handler::max, policy::move)
|
||||
.def("div", &Handler::div, policy::move)
|
||||
.def("pow", &Handler::pow, policy::move)
|
||||
.def("min", &Handler::min, policy::move)
|
||||
|
@ -510,10 +514,17 @@ void init_graph_builder(py::module &m) {
|
|||
.def("topo_sort", &Handler::topo_sort, policy::automatic)
|
||||
.def("optimize", &Handler::optimize, policy::automatic)
|
||||
.def("operators", &Handler::operators, policy::move)
|
||||
.def("data_malloc", &Handler::data_malloc, policy::automatic)
|
||||
.def("data_malloc", &Handler::data_malloc,
|
||||
py::arg("useNaiveAllocator") = false, py::arg("memPoolSize") = 0,
|
||||
policy::automatic)
|
||||
.def("clone_KV", &Handler::clone_KV, policy::move)
|
||||
.def("free_heap", &Handler::free_heap, policy::move)
|
||||
.def("get_perf_time", &Handler::get_perf_time, policy::automatic)
|
||||
.def("tune", &Handler::tune, policy::automatic)
|
||||
.def("run", &Handler::run, policy::automatic)
|
||||
.def("shape_infer", &Handler::shape_infer, policy::automatic)
|
||||
.def("change_shape", &Handler::change_shape, policy::automatic)
|
||||
.def("getDims", &Handler::getDims, policy::automatic)
|
||||
.def("get_perf_time", &Handler::get_perf_time, policy::automatic);
|
||||
}
|
||||
|
||||
|
|
|
@ -44,7 +44,6 @@ class ElementWiseCudnn : public CudaKernelWithoutConfig {
|
|||
std::copy(a_dim.begin(), a_dim.end(), a + (4 - a_dim.size()));
|
||||
std::copy(b_dim.begin(), b_dim.end(), b + (4 - b_dim.size()));
|
||||
std::copy(c_dim.begin(), c_dim.end(), c + (4 - c_dim.size()));
|
||||
|
||||
// get inputs
|
||||
checkCudnnError(cudnnCreateTensorDescriptor(&aDesc));
|
||||
checkCudnnError(cudnnSetTensor4dDescriptor(aDesc, CUDNN_TENSOR_NCHW,
|
||||
|
@ -110,9 +109,9 @@ class ElementWiseCuda : public CudaKernelWithoutConfig {
|
|||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<ElementWiseObj>(_op);
|
||||
float *const aData = (op->getInputs(0)->getRawDataPtr<float *>());
|
||||
float *const bData = (op->getInputs(1)->getRawDataPtr<float *>());
|
||||
float *const cData = (op->getOutput()->getRawDataPtr<float *>());
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const bData = (op->getInputs(1)->getRawDataPtr<void *>());
|
||||
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
auto a_dim = op->getInputs(0)->getDims();
|
||||
auto b_dim = op->getInputs(1)->getDims();
|
||||
auto c_dim = op->getOutput()->getDims();
|
||||
|
@ -134,7 +133,13 @@ class ElementWiseCuda : public CudaKernelWithoutConfig {
|
|||
else if (op->getOpType() == OpType::Pow)
|
||||
pow_kernel(aData, bData, cData, a[0], a[1], a[2], a[3], b[0], b[1],
|
||||
b[2], b[3], c[0], c[1], c[2], c[3]);
|
||||
else
|
||||
else if (op->getOpType() == OpType::Add) {
|
||||
add_kernel(aData, bData, cData, a[0], a[1], a[2], a[3], b[0], b[1],
|
||||
b[2], b[3], c[0], c[1], c[2], c[3]);
|
||||
} else if (op->getOpType() == OpType::Less) {
|
||||
less_kernel(aData, bData, cData, a[0], a[1], a[2], a[3], b[0], b[1],
|
||||
b[2], b[3], c[0], c[1], c[2], c[3]);
|
||||
} else
|
||||
IT_TODO_HALT();
|
||||
}
|
||||
};
|
||||
|
@ -152,6 +157,10 @@ REGISTER_KERNEL(Device::CUDA, OpType::Max, DataType::Float32, MaxCudnn,
|
|||
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::Div, DataType::Float32, ElementWiseCuda,
|
||||
"Div_CUDA_Float32");
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::Add, DataType::Int64, ElementWiseCuda,
|
||||
"Add_CUDA_Int64");
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::Pow, DataType::Float32, ElementWiseCuda,
|
||||
"Pow__CUDA_Float32");
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::Less, DataType::Int64, ElementWiseCuda,
|
||||
"Less__CUDA_Int64");
|
||||
}; // namespace infini
|
||||
|
|
|
@ -5,9 +5,10 @@ constexpr unsigned int num_threads() { return 32 * 4; }
|
|||
constexpr int thread_work_size() { return 4; }
|
||||
constexpr int block_work_size() { return thread_work_size() * num_threads(); }
|
||||
|
||||
__global__ void _div_kernel(float *x, float *y, float *z, int a0, int a1,
|
||||
int a2, int a3, int b0, int b1, int b2, int b3,
|
||||
int c0, int c1, int c2, int c3) {
|
||||
template <class T>
|
||||
__global__ void _div_kernel(void *x, void *y, void *z, int a0, int a1, int a2,
|
||||
int a3, int b0, int b1, int b2, int b3, int c0,
|
||||
int c1, int c2, int c3) {
|
||||
int index = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
int stride = blockDim.x * gridDim.x;
|
||||
int n = c0 * c1 * c2 * c3;
|
||||
|
@ -27,16 +28,17 @@ __global__ void _div_kernel(float *x, float *y, float *z, int a0, int a1,
|
|||
int b1_index = c1_index % b1;
|
||||
int b2_index = c2_index % b2;
|
||||
int b3_index = c3_index % b3;
|
||||
z[i] = x[a0_index * a1 * a2 * a3 + a1_index * a2 * a3 + a2_index * a3 +
|
||||
a3_index] /
|
||||
y[b0_index * b1 * b2 * b3 + b1_index * b2 * b3 + b2_index * b3 +
|
||||
b3_index];
|
||||
((T *)z)[i] = ((T *)x)[a0_index * a1 * a2 * a3 + a1_index * a2 * a3 +
|
||||
a2_index * a3 + a3_index] /
|
||||
((T *)y)[b0_index * b1 * b2 * b3 + b1_index * b2 * b3 +
|
||||
b2_index * b3 + b3_index];
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void _pow_kernel(float *x, float *y, float *z, int a0, int a1,
|
||||
int a2, int a3, int b0, int b1, int b2, int b3,
|
||||
int c0, int c1, int c2, int c3) {
|
||||
template <class T>
|
||||
__global__ void _add_kernel(void *x, void *y, void *z, int a0, int a1, int a2,
|
||||
int a3, int b0, int b1, int b2, int b3, int c0,
|
||||
int c1, int c2, int c3) {
|
||||
int index = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
int stride = blockDim.x * gridDim.x;
|
||||
int n = c0 * c1 * c2 * c3;
|
||||
|
@ -56,32 +58,115 @@ __global__ void _pow_kernel(float *x, float *y, float *z, int a0, int a1,
|
|||
int b1_index = c1_index % b1;
|
||||
int b2_index = c2_index % b2;
|
||||
int b3_index = c3_index % b3;
|
||||
z[i] = pow(x[a0_index * a1 * a2 * a3 + a1_index * a2 * a3 +
|
||||
a2_index * a3 + a3_index],
|
||||
y[b0_index * b1 * b2 * b3 + b1_index * b2 * b3 +
|
||||
b2_index * b3 + b3_index]);
|
||||
((T *)z)[i] = ((T *)x)[a0_index * a1 * a2 * a3 + a1_index * a2 * a3 +
|
||||
a2_index * a3 + a3_index] +
|
||||
((T *)y)[b0_index * b1 * b2 * b3 + b1_index * b2 * b3 +
|
||||
b2_index * b3 + b3_index];
|
||||
}
|
||||
}
|
||||
|
||||
template <class T>
|
||||
__global__ void _pow_kernel(void *x, void *y, void *z, int a0, int a1, int a2,
|
||||
int a3, int b0, int b1, int b2, int b3, int c0,
|
||||
int c1, int c2, int c3) {
|
||||
int index = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
int stride = blockDim.x * gridDim.x;
|
||||
int n = c0 * c1 * c2 * c3;
|
||||
|
||||
for (int i = index; i < n; i += stride) {
|
||||
int c0_index = i / (c1 * c2 * c3);
|
||||
int c1_index = (i % (c1 * c2 * c3)) / (c2 * c3);
|
||||
int c2_index = ((i % (c1 * c2 * c3)) % (c2 * c3)) / c3;
|
||||
int c3_index = ((i % (c1 * c2 * c3)) % (c2 * c3)) % c3;
|
||||
|
||||
int a0_index = c0_index % a0;
|
||||
int a1_index = c1_index % a1;
|
||||
int a2_index = c2_index % a2;
|
||||
int a3_index = c3_index % a3;
|
||||
|
||||
int b0_index = c0_index % b0;
|
||||
int b1_index = c1_index % b1;
|
||||
int b2_index = c2_index % b2;
|
||||
int b3_index = c3_index % b3;
|
||||
((T *)z)[i] =
|
||||
pow(((T *)x)[a0_index * a1 * a2 * a3 + a1_index * a2 * a3 +
|
||||
a2_index * a3 + a3_index],
|
||||
((T *)y)[b0_index * b1 * b2 * b3 + b1_index * b2 * b3 +
|
||||
b2_index * b3 + b3_index]);
|
||||
}
|
||||
}
|
||||
|
||||
template <class T>
|
||||
__global__ void _less_kernel(void *x, void *y, void *z, int a0, int a1, int a2,
|
||||
int a3, int b0, int b1, int b2, int b3, int c0,
|
||||
int c1, int c2, int c3) {
|
||||
int index = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
int stride = blockDim.x * gridDim.x;
|
||||
int n = c0 * c1 * c2 * c3;
|
||||
|
||||
for (int i = index; i < n; i += stride) {
|
||||
int c0_index = i / (c1 * c2 * c3);
|
||||
int c1_index = (i % (c1 * c2 * c3)) / (c2 * c3);
|
||||
int c2_index = ((i % (c1 * c2 * c3)) % (c2 * c3)) / c3;
|
||||
int c3_index = ((i % (c1 * c2 * c3)) % (c2 * c3)) % c3;
|
||||
|
||||
int a0_index = c0_index % a0;
|
||||
int a1_index = c1_index % a1;
|
||||
int a2_index = c2_index % a2;
|
||||
int a3_index = c3_index % a3;
|
||||
|
||||
int b0_index = c0_index % b0;
|
||||
int b1_index = c1_index % b1;
|
||||
int b2_index = c2_index % b2;
|
||||
int b3_index = c3_index % b3;
|
||||
((bool *)z)[i] =
|
||||
((T *)x)[a0_index * a1 * a2 * a3 + a1_index * a2 * a3 +
|
||||
a2_index * a3 + a3_index] <
|
||||
((T *)y)[b0_index * b1 * b2 * b3 + b1_index * b2 * b3 +
|
||||
b2_index * b3 + b3_index]
|
||||
? true
|
||||
: false;
|
||||
}
|
||||
}
|
||||
|
||||
namespace infini {
|
||||
void div_kernel(float *a, float *b, float *c, int a0, int a1, int a2, int a3,
|
||||
void div_kernel(void *a, void *b, void *c, int a0, int a1, int a2, int a3,
|
||||
int b0, int b1, int b2, int b3, int c0, int c1, int c2,
|
||||
int c3) {
|
||||
|
||||
int blocksize = block_work_size();
|
||||
int num = c0 * c1 * c2 * c3;
|
||||
int gridsize = (num + block_work_size() - 1) / block_work_size();
|
||||
_div_kernel<<<gridsize, blocksize>>>(a, b, c, a0, a1, a2, a3, b0, b1, b2,
|
||||
b3, c0, c1, c2, c3);
|
||||
_div_kernel<float><<<gridsize, blocksize>>>(a, b, c, a0, a1, a2, a3, b0, b1,
|
||||
b2, b3, c0, c1, c2, c3);
|
||||
}
|
||||
void pow_kernel(float *a, float *b, float *c, int a0, int a1, int a2, int a3,
|
||||
void add_kernel(void *a, void *b, void *c, int a0, int a1, int a2, int a3,
|
||||
int b0, int b1, int b2, int b3, int c0, int c1, int c2,
|
||||
int c3) {
|
||||
|
||||
int blocksize = block_work_size();
|
||||
int num = c0 * c1 * c2 * c3;
|
||||
int gridsize = (num + block_work_size() - 1) / block_work_size();
|
||||
_add_kernel<int64_t><<<gridsize, blocksize>>>(a, b, c, a0, a1, a2, a3, b0,
|
||||
b1, b2, b3, c0, c1, c2, c3);
|
||||
}
|
||||
void pow_kernel(void *a, void *b, void *c, int a0, int a1, int a2, int a3,
|
||||
int b0, int b1, int b2, int b3, int c0, int c1, int c2,
|
||||
int c3) {
|
||||
int blocksize = block_work_size();
|
||||
int num = c0 * c1 * c2 * c3;
|
||||
int gridsize = (num + block_work_size() - 1) / block_work_size();
|
||||
_pow_kernel<<<gridsize, blocksize>>>(a, b, c, a0, a1, a2, a3, b0, b1, b2,
|
||||
b3, c0, c1, c2, c3);
|
||||
_pow_kernel<float><<<gridsize, blocksize>>>(a, b, c, a0, a1, a2, a3, b0, b1,
|
||||
b2, b3, c0, c1, c2, c3);
|
||||
}
|
||||
void less_kernel(void *a, void *b, void *c, int a0, int a1, int a2, int a3,
|
||||
int b0, int b1, int b2, int b3, int c0, int c1, int c2,
|
||||
int c3) {
|
||||
int blocksize = block_work_size();
|
||||
int num = c0 * c1 * c2 * c3;
|
||||
int gridsize = (num + block_work_size() - 1) / block_work_size();
|
||||
_less_kernel<int64_t><<<gridsize, blocksize>>>(a, b, c, a0, a1, a2, a3, b0,
|
||||
b1, b2, b3, c0, c1, c2, c3);
|
||||
}
|
||||
|
||||
}; // namespace infini
|
||||
|
|
|
@ -16,8 +16,9 @@ class PadSliceCudaCompute {
|
|||
metadata.partNDim[i] = partTensor->getDims()[i];
|
||||
metadata.partStride[i] = partTensor->getStride()[i];
|
||||
}
|
||||
pad_slice_kernel(partTensor->getRawDataPtr<float *>(),
|
||||
wholeTensor->getRawDataPtr<float *>(), metadata, nDims,
|
||||
metadata.DType = partTensor->getDType().getIndex();
|
||||
pad_slice_kernel(partTensor->getRawDataPtr<void *>(),
|
||||
wholeTensor->getRawDataPtr<void *>(), metadata, nDims,
|
||||
wholeTensor->size(), isPad);
|
||||
}
|
||||
};
|
||||
|
@ -40,6 +41,8 @@ class SliceCuda : private PadSliceCudaCompute, public CudaKernelWithoutConfig {
|
|||
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::Slice, DataType::Float32, SliceCuda,
|
||||
"Slice__CUDA_Float32");
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::Slice, DataType::Int64, SliceCuda,
|
||||
"Slice__CUDA_Int64");
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::Pad, DataType::Float32, PadCuda,
|
||||
"Pad__CUDA_Float32");
|
||||
} // namespace infini
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
#include "core/data_type.h"
|
||||
#include "cuda/cuda_common.h"
|
||||
#include "cuda/cuda_pad_slice.h"
|
||||
|
||||
|
@ -19,9 +20,9 @@ __device__ int WholeTensorOffset2PartTensorOffset(int wholeOffset,
|
|||
return offset;
|
||||
}
|
||||
|
||||
__global__ void _pad_slice_kernel(float *part, float *whole,
|
||||
TransMetaData metaData, int nDims, int num,
|
||||
bool isPad) {
|
||||
template <typename T>
|
||||
__global__ void _pad_slice_kernel(T *part, T *whole, TransMetaData metaData,
|
||||
int nDims, int num, bool isPad) {
|
||||
int tid = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
if (tid >= num)
|
||||
return;
|
||||
|
@ -41,12 +42,18 @@ __global__ void _pad_slice_kernel(float *part, float *whole,
|
|||
}
|
||||
|
||||
namespace infini {
|
||||
void pad_slice_kernel(float *partData, float *wholeData,
|
||||
void pad_slice_kernel(void *partData, void *wholeData,
|
||||
const TransMetaData &metadata, int nDims, int num,
|
||||
bool isPad) {
|
||||
int blockSize = 32 * 16;
|
||||
int gridSize = (num + blockSize - 1) / blockSize;
|
||||
_pad_slice_kernel<<<gridSize, blockSize>>>(partData, wholeData, metadata,
|
||||
nDims, num, isPad);
|
||||
if (metadata.DType == DataType::Int64.getIndex()) {
|
||||
_pad_slice_kernel<int64_t>
|
||||
<<<gridSize, blockSize>>>((int64_t *)partData, (int64_t *)wholeData,
|
||||
metadata, nDims, num, isPad);
|
||||
} else if (metadata.DType == DataType::Float32.getIndex()) {
|
||||
_pad_slice_kernel<float><<<gridSize, blockSize>>>(
|
||||
(float *)partData, (float *)wholeData, metadata, nDims, num, isPad);
|
||||
}
|
||||
}
|
||||
} // namespace infini
|
||||
|
|
|
@ -59,6 +59,21 @@ class CudaCompute {
|
|||
class ConcatCuda : private CudaCompute, public CudaKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto inputs = _op->getInputs();
|
||||
if (inputs.size() == 2) {
|
||||
for (size_t i = 0; i < 2; i++) {
|
||||
if (inputs[i]->size() == 0) {
|
||||
auto inData =
|
||||
_op->getInputs(1 - i)->getRawDataPtr<void *>();
|
||||
auto outData =
|
||||
_op->getOutputs()[0]->getRawDataPtr<void *>();
|
||||
cudaMemcpyAsync(outData, inData,
|
||||
_op->getInputs(1 - i)->getBytes(),
|
||||
cudaMemcpyDeviceToDevice);
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
do_compute(_op->getOutput(), _op->getInputs(),
|
||||
as<ConcatObj>(_op)->getDim(), _op->getOutput()->getRank(),
|
||||
false);
|
||||
|
|
|
@ -20,15 +20,18 @@ string G2BMMObj::toString() const {
|
|||
return os.str();
|
||||
}
|
||||
|
||||
optional<vector<Shape>> G2BMMObj::inferShape(const TensorVec &inputs) const {
|
||||
optional<vector<Shape>> G2BMMObj::inferShape(const TensorVec &inputs) {
|
||||
auto A = inputs[0], B = inputs[1];
|
||||
b = A->getDims()[0];
|
||||
m = A->getDims()[1];
|
||||
k = A->getDims()[2];
|
||||
|
||||
IT_ASSERT(A->getRank() == 3 && B->getRank() == 3);
|
||||
IT_ASSERT(A->getDims()[0] == B->getDims()[0]);
|
||||
IT_ASSERT(A->getDims()[1] == B->getDims()[1]);
|
||||
IT_ASSERT(A->getDims()[2] == B->getDims()[2]);
|
||||
IT_ASSERT(width >= 0);
|
||||
int b(A->getDims()[0]), m(A->getDims()[1]), n(2 * width + 1);
|
||||
int n(2 * width + 1);
|
||||
return {{{b, m, n}}};
|
||||
}
|
||||
|
||||
|
|
|
@ -21,15 +21,18 @@ string GBMMObj::toString() const {
|
|||
return os.str();
|
||||
}
|
||||
|
||||
optional<vector<Shape>> GBMMObj::inferShape(const TensorVec &inputs) const {
|
||||
optional<vector<Shape>> GBMMObj::inferShape(const TensorVec &inputs) {
|
||||
auto A = inputs[0], B = inputs[1];
|
||||
b = A->getDims()[0];
|
||||
m = A->getDims()[1];
|
||||
w = (A->getDims()[2] - 1) / 2;
|
||||
n = B->getDims()[2];
|
||||
|
||||
IT_ASSERT(A->getRank() == 3 && B->getRank() == 3);
|
||||
IT_ASSERT(A->getDims()[0] == B->getDims()[0]);
|
||||
IT_ASSERT(A->getDims()[1] == B->getDims()[1]);
|
||||
IT_ASSERT(A->getDims()[2] % 2 != 0);
|
||||
int b(A->getDims()[0]), m(A->getDims()[1]), k(B->getDims()[2]);
|
||||
return {{{b, m, k}}};
|
||||
return {{{b, m, n}}};
|
||||
}
|
||||
|
||||
vector<int> GBMMObj::getWorkloadVector() const {
|
||||
|
|
|
@ -9,7 +9,7 @@ ActivationBackwardObj::ActivationBackwardObj(OpType type, GraphObj *graph,
|
|||
}
|
||||
|
||||
optional<vector<Shape>>
|
||||
ActivationBackwardObj::inferShape(const TensorVec &inputs) const {
|
||||
ActivationBackwardObj::inferShape(const TensorVec &inputs) {
|
||||
return {{inputs[0]->getDims()}};
|
||||
}
|
||||
|
||||
|
|
|
@ -10,8 +10,7 @@ AllGatherObj::AllGatherObj(GraphObj *graph, Tensor input,
|
|||
IT_ASSERT(checkValid(graph));
|
||||
}
|
||||
|
||||
optional<vector<Shape>>
|
||||
AllGatherObj::inferShape(const TensorVec &inputs) const {
|
||||
optional<vector<Shape>> AllGatherObj::inferShape(const TensorVec &inputs) {
|
||||
Shape input_shape = inputs[0]->getDims();
|
||||
vector<Shape> output_shapes(getWorldSize(), input_shape);
|
||||
return output_shapes;
|
||||
|
|
|
@ -18,7 +18,7 @@ AttentionKVCacheObj::AttentionKVCacheObj(GraphObj *graph, Tensor input_k_cache,
|
|||
}
|
||||
|
||||
optional<vector<Shape>>
|
||||
AttentionKVCacheObj::inferShape(const TensorVec &inputs) const {
|
||||
AttentionKVCacheObj::inferShape(const TensorVec &inputs) {
|
||||
IT_ASSERT(inputs.size() == 6);
|
||||
Shape dims = inputs[0]->getDims();
|
||||
ShapeElem n = dims.at(dim);
|
||||
|
|
|
@ -13,8 +13,7 @@ BatchNormObj::BatchNormObj(GraphObj *graph, Tensor input, Tensor output,
|
|||
IT_ASSERT(checkValid(graph));
|
||||
}
|
||||
|
||||
optional<vector<Shape>>
|
||||
BatchNormObj::inferShape(const TensorVec &inputs) const {
|
||||
optional<vector<Shape>> BatchNormObj::inferShape(const TensorVec &inputs) {
|
||||
auto input = inputs[0];
|
||||
auto mean = inputs[1];
|
||||
auto var = inputs[2];
|
||||
|
|
|
@ -9,9 +9,16 @@ ConcatObj::ConcatObj(GraphObj *graph, TensorVec inputs, Tensor output, int _dim)
|
|||
IT_ASSERT(checkValid(graph));
|
||||
}
|
||||
|
||||
optional<vector<Shape>> ConcatObj::inferShape(const TensorVec &inputs) const {
|
||||
optional<vector<Shape>> ConcatObj::inferShape(const TensorVec &inputs) {
|
||||
Shape dims = inputs[0]->getDims();
|
||||
auto rank = inputs[0]->getRank();
|
||||
if (inputs.size() == 2) {
|
||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||
if (inputs[i]->size() == 0) {
|
||||
return {{inputs[1 - i]->getDims()}};
|
||||
}
|
||||
}
|
||||
}
|
||||
ShapeElem n = dims.at(dim);
|
||||
for (auto itr = inputs.begin() + 1; itr != inputs.end(); ++itr) {
|
||||
auto input = *itr;
|
||||
|
|
|
@ -82,14 +82,15 @@ ConvObj::ConvObj(GraphObj *graph, Tensor input, Tensor weight, Tensor output,
|
|||
IT_ASSERT(checkValid(graph));
|
||||
}
|
||||
|
||||
optional<vector<Shape>> ConvObj::inferShape(const TensorVec &inputs) const {
|
||||
optional<vector<Shape>> ConvObj::inferShape(const TensorVec &inputs) {
|
||||
const auto &input = inputs[0], &weight = inputs[1];
|
||||
auto n = input->getDims()[0];
|
||||
auto h = input->getDims()[2];
|
||||
auto w = input->getDims()[3];
|
||||
auto f = weight->getDims()[0];
|
||||
auto r = weight->getDims()[2];
|
||||
auto s = weight->getDims()[3];
|
||||
n = input->getDims()[0];
|
||||
c = input->getDims()[1];
|
||||
h = input->getDims()[2];
|
||||
w = input->getDims()[3];
|
||||
f = weight->getDims()[0];
|
||||
r = weight->getDims()[2];
|
||||
s = weight->getDims()[3];
|
||||
int on = n, oc = f;
|
||||
int oh = 0, ow = 0;
|
||||
// For NCHW+FCRS layout, C of input is divisable by C of weight
|
||||
|
@ -141,15 +142,15 @@ ConvTransposed2dObj::ConvTransposed2dObj(GraphObj *graph, Tensor input,
|
|||
}
|
||||
|
||||
optional<vector<Shape>>
|
||||
ConvTransposed2dObj::inferShape(const TensorVec &inputs) const {
|
||||
ConvTransposed2dObj::inferShape(const TensorVec &inputs) {
|
||||
const Tensor &input = inputs[0], &weight = inputs[1];
|
||||
auto n = input->getDims()[0];
|
||||
auto f = input->getDims()[1];
|
||||
auto h = input->getDims()[2];
|
||||
auto w = input->getDims()[3];
|
||||
auto c = weight->getDims()[1];
|
||||
auto r = weight->getDims()[2];
|
||||
auto s = weight->getDims()[3];
|
||||
n = input->getDims()[0];
|
||||
f = input->getDims()[1];
|
||||
h = input->getDims()[2];
|
||||
w = input->getDims()[3];
|
||||
c = weight->getDims()[1];
|
||||
r = weight->getDims()[2];
|
||||
s = weight->getDims()[3];
|
||||
IT_ASSERT(f == weight->getDims()[0]);
|
||||
|
||||
int on = n, oc = c * group;
|
||||
|
@ -219,14 +220,15 @@ ConvBackwardFilterObj::ConvBackwardFilterObj(GraphObj *graph, Tensor inputX,
|
|||
}
|
||||
|
||||
optional<vector<Shape>>
|
||||
ConvBackwardFilterObj::inferShape(const TensorVec &inputs) const {
|
||||
ConvBackwardFilterObj::inferShape(const TensorVec &inputs) {
|
||||
const auto &inputX = inputs[0], &diffY = inputs[1];
|
||||
auto n = inputX->getDims()[0];
|
||||
auto h = inputX->getDims()[2];
|
||||
auto w = inputX->getDims()[3];
|
||||
auto f = diffY->getDims()[0];
|
||||
auto r = diffY->getDims()[2];
|
||||
auto s = diffY->getDims()[3];
|
||||
n = inputX->getDims()[0];
|
||||
c = inputX->getDims()[1];
|
||||
h = inputX->getDims()[2];
|
||||
w = inputX->getDims()[3];
|
||||
f = diffY->getDims()[0];
|
||||
r = diffY->getDims()[2];
|
||||
s = diffY->getDims()[3];
|
||||
int on = n, oc = f;
|
||||
int oh = 0, ow = 0;
|
||||
// For NCHW+FCRS layout, C of input is divisable by C of weight
|
||||
|
@ -280,17 +282,16 @@ ConvTransposed2dNHWCObj::ConvTransposed2dNHWCObj(GraphObj *graph, Tensor input,
|
|||
}
|
||||
|
||||
optional<vector<Shape>>
|
||||
ConvTransposed2dNHWCObj::inferShape(const TensorVec &inputs) const {
|
||||
ConvTransposed2dNHWCObj::inferShape(const TensorVec &inputs) {
|
||||
const Tensor &input = inputs[0], &weight = inputs[1];
|
||||
auto n = input->getDims()[0];
|
||||
auto f = input->getDims()[3];
|
||||
auto h = input->getDims()[1];
|
||||
auto w = input->getDims()[2];
|
||||
auto c = weight->getDims()[3];
|
||||
auto r = weight->getDims()[1];
|
||||
auto s = weight->getDims()[2];
|
||||
if (f != weight->getDims()[0])
|
||||
return {};
|
||||
n = input->getDims()[0];
|
||||
f = input->getDims()[3];
|
||||
h = input->getDims()[1];
|
||||
w = input->getDims()[2];
|
||||
c = weight->getDims()[3];
|
||||
r = weight->getDims()[1];
|
||||
s = weight->getDims()[2];
|
||||
IT_ASSERT(f == weight->getDims()[0]);
|
||||
|
||||
int on = n, oc = c * group;
|
||||
int oh = 0, ow = 0;
|
||||
|
|
|
@ -6,7 +6,7 @@ DetObj::DetObj(GraphObj *graph, Tensor input, Tensor output, Mode mode)
|
|||
IT_ASSERT(checkValid(graph));
|
||||
}
|
||||
|
||||
optional<vector<Shape>> DetObj::inferShape(const TensorVec &inputs) const {
|
||||
optional<vector<Shape>> DetObj::inferShape(const TensorVec &inputs) {
|
||||
const auto A = inputs[0];
|
||||
auto input = A->getDims();
|
||||
int rank = A->getRank();
|
||||
|
|
|
@ -10,7 +10,7 @@ DropoutObj::DropoutObj(GraphObj *graph, Tensor data, Tensor output, Tensor mask,
|
|||
IT_ASSERT(checkValid(graph));
|
||||
}
|
||||
|
||||
optional<vector<Shape>> DropoutObj::inferShape(const TensorVec &inputs) const {
|
||||
optional<vector<Shape>> DropoutObj::inferShape(const TensorVec &inputs) {
|
||||
auto shape = inputs[0]->getDims();
|
||||
return {{shape, shape}};
|
||||
}
|
||||
|
|
|
@ -8,8 +8,7 @@ ElementWiseObj::ElementWiseObj(OpType type, GraphObj *graph, Tensor input0,
|
|||
IT_ASSERT(checkValid(graph));
|
||||
}
|
||||
|
||||
optional<vector<Shape>>
|
||||
ElementWiseObj::inferShape(const TensorVec &inputs) const {
|
||||
optional<vector<Shape>> ElementWiseObj::inferShape(const TensorVec &inputs) {
|
||||
const auto A = inputs[0], B = inputs[1];
|
||||
auto res = infer_broadcast(A->getDims(), B->getDims());
|
||||
return {{res}};
|
||||
|
@ -45,7 +44,7 @@ MSELossObj::MSELossObj(GraphObj *graph, Tensor input0, Tensor input1,
|
|||
IT_ASSERT(checkValid(graph));
|
||||
}
|
||||
|
||||
optional<vector<Shape>> MSELossObj::inferShape(const TensorVec &inputs) const {
|
||||
optional<vector<Shape>> MSELossObj::inferShape(const TensorVec &inputs) {
|
||||
const auto A = inputs[0], B = inputs[1];
|
||||
IT_ASSERT(A->getRank() == B->getRank());
|
||||
IT_ASSERT(A->getDims() == B->getDims());
|
||||
|
|
|
@ -8,7 +8,7 @@ ExpandObj::ExpandObj(GraphObj *graph, Tensor input, Tensor output, Shape dims)
|
|||
IT_ASSERT(checkValid(graph));
|
||||
}
|
||||
|
||||
optional<vector<Shape>> ExpandObj::inferShape(const TensorVec &inputs) const {
|
||||
optional<vector<Shape>> ExpandObj::inferShape(const TensorVec &inputs) {
|
||||
auto shape_input = inputs[0]->getDims();
|
||||
Shape ret = infer_broadcast(shape_input, dims);
|
||||
return {{ret}};
|
||||
|
|
|
@ -11,7 +11,7 @@ ExtendObj::ExtendObj(GraphObj *graph, Tensor input, Tensor output, int dim,
|
|||
IT_ASSERT(checkValid(graph));
|
||||
}
|
||||
|
||||
optional<vector<Shape>> ExtendObj::inferShape(const TensorVec &inputs) const {
|
||||
optional<vector<Shape>> ExtendObj::inferShape(const TensorVec &inputs) {
|
||||
auto ret = inputs[0]->getDims();
|
||||
ret[dim] = ret[dim] * (num + 1);
|
||||
return {{ret}};
|
||||
|
|
|
@ -10,7 +10,7 @@ GatherObj::GatherObj(GraphObj *graph, Tensor input, Tensor indices,
|
|||
IT_ASSERT(checkValid(graph));
|
||||
}
|
||||
|
||||
optional<vector<Shape>> GatherObj::inferShape(const TensorVec &inputs) const {
|
||||
optional<vector<Shape>> GatherObj::inferShape(const TensorVec &inputs) {
|
||||
auto dims0 = inputs[0]->getDims();
|
||||
auto dims1 = inputs[1]->getDims();
|
||||
|
||||
|
|
|
@ -24,8 +24,7 @@ bool checkShape(Tensor input, Tensor indices, int axis) {
|
|||
return true;
|
||||
}
|
||||
|
||||
optional<vector<Shape>>
|
||||
GatherElementsObj::inferShape(const TensorVec &inputs) const {
|
||||
optional<vector<Shape>> GatherElementsObj::inferShape(const TensorVec &inputs) {
|
||||
IT_ASSERT(checkShape(inputs[0], inputs[1], axis));
|
||||
auto indicesDims = inputs[1]->getDims(); // output has same shape as indices
|
||||
return {{indicesDims}};
|
||||
|
|
|
@ -9,25 +9,6 @@ MatmulObj::MatmulObj(GraphObj *graph, Tensor A, Tensor B, Tensor C, bool transA,
|
|||
: OperatorObj(OpType::MatMul,
|
||||
bias ? TensorVec{A, B, bias} : TensorVec{A, B}, {C}),
|
||||
transA(transA), transB(transB), act(act), b(1) {
|
||||
auto shape_a = A->getDims();
|
||||
auto shape_b = B->getDims();
|
||||
int rankA = A->getRank();
|
||||
int rankB = B->getRank();
|
||||
IT_ASSERT(rankA >= 2 && rankB >= 2);
|
||||
Shape shape_a1(shape_a.begin(), shape_a.begin() + (rankA - 2));
|
||||
Shape shape_b1(shape_b.begin(), shape_b.begin() + (rankB - 2));
|
||||
auto ret = infer_broadcast(shape_a1, shape_b1);
|
||||
if (ret.empty()) {
|
||||
b = 1;
|
||||
} else {
|
||||
b = std::accumulate(ret.begin(), ret.end(), 1, std::multiplies<int>());
|
||||
}
|
||||
auto kA = *(transA ? shape_a.rbegin() + 1 : shape_a.rbegin());
|
||||
auto kB = *(transB ? shape_b.rbegin() : shape_b.rbegin() + 1);
|
||||
IT_ASSERT(kA == kB);
|
||||
m = *(transA ? shape_a.rbegin() : shape_a.rbegin() + 1);
|
||||
n = *(transB ? shape_b.rbegin() + 1 : shape_b.rbegin());
|
||||
k = kA;
|
||||
IT_ASSERT(checkValid(graph));
|
||||
}
|
||||
|
||||
|
@ -40,7 +21,7 @@ string MatmulObj::toString() const {
|
|||
return os.str();
|
||||
}
|
||||
|
||||
optional<vector<Shape>> MatmulObj::inferShape(const TensorVec &inputs) const {
|
||||
optional<vector<Shape>> MatmulObj::inferShape(const TensorVec &inputs) {
|
||||
auto A = inputs[0], B = inputs[1];
|
||||
auto shapeA = A->getDims();
|
||||
auto shapeB = B->getDims();
|
||||
|
@ -49,6 +30,17 @@ optional<vector<Shape>> MatmulObj::inferShape(const TensorVec &inputs) const {
|
|||
Shape shapeA1(shapeA.begin(), shapeA.begin() + (rankA - 2));
|
||||
Shape shapeB1(shapeB.begin(), shapeB.begin() + (rankB - 2));
|
||||
Shape ret = infer_broadcast(shapeA1, shapeB1);
|
||||
if (ret.empty()) {
|
||||
b = 1;
|
||||
} else {
|
||||
b = std::accumulate(ret.begin(), ret.end(), 1, std::multiplies<int>());
|
||||
}
|
||||
auto kA = *(transA ? shapeA.rbegin() + 1 : shapeA.rbegin());
|
||||
auto kB = *(transB ? shapeB.rbegin() : shapeB.rbegin() + 1);
|
||||
IT_ASSERT(kA == kB);
|
||||
m = *(transA ? shapeA.rbegin() : shapeA.rbegin() + 1);
|
||||
n = *(transB ? shapeB.rbegin() + 1 : shapeB.rbegin());
|
||||
k = kA;
|
||||
ret.emplace_back(m);
|
||||
ret.emplace_back(n);
|
||||
return {{ret}};
|
||||
|
|
|
@ -60,7 +60,7 @@ string MemBoundObj::toString() const {
|
|||
return os.str();
|
||||
}
|
||||
|
||||
optional<vector<Shape>> MemBoundObj::inferShape(const TensorVec &inputs) const {
|
||||
optional<vector<Shape>> MemBoundObj::inferShape(const TensorVec &inputs) {
|
||||
// inputs have to match nnetInputs excatly
|
||||
if (inputs.size() != nnetInputs.size())
|
||||
return {};
|
||||
|
|
|
@ -22,7 +22,7 @@ PadObj::PadObj(GraphObj *graph, Tensor input, Tensor output,
|
|||
IT_ASSERT(checkValid(graph));
|
||||
}
|
||||
|
||||
optional<vector<Shape>> PadObj::inferShape(const TensorVec &inputs) const {
|
||||
optional<vector<Shape>> PadObj::inferShape(const TensorVec &inputs) {
|
||||
auto dims = inputs[0]->getDims();
|
||||
int rank = inputs[0]->getRank();
|
||||
IT_ASSERT(rank * 2 == (int)pads.size());
|
||||
|
|
|
@ -12,7 +12,7 @@ PoolingObj::PoolingObj(GraphObj *graph, OpType optype, Tensor input,
|
|||
IT_ASSERT(checkValid(graph));
|
||||
}
|
||||
|
||||
optional<vector<Shape>> PoolingObj::inferShape(const TensorVec &inputs) const {
|
||||
optional<vector<Shape>> PoolingObj::inferShape(const TensorVec &inputs) {
|
||||
const auto &input = inputs[0];
|
||||
auto h = input->getDims()[input->getRank() - 2],
|
||||
w = input->getDims()[input->getRank() - 1];
|
||||
|
|
|
@ -21,8 +21,7 @@ bool ReduceMeanObj::isReduced(int idx) const {
|
|||
return axes.find(idx) != axes.end();
|
||||
}
|
||||
|
||||
optional<vector<Shape>>
|
||||
ReduceMeanObj::inferShape(const TensorVec &inputs) const {
|
||||
optional<vector<Shape>> ReduceMeanObj::inferShape(const TensorVec &inputs) {
|
||||
auto dims = inputs[0]->getDims();
|
||||
auto rank = inputs[0]->getRank();
|
||||
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
#include "operators/reshape.h"
|
||||
#include "utils/operator_utils.h"
|
||||
#include <numeric>
|
||||
|
||||
namespace infini {
|
||||
ReshapeObj::ReshapeObj(GraphObj *graph, Tensor input, Tensor output, Shape dims)
|
||||
|
@ -7,14 +8,37 @@ ReshapeObj::ReshapeObj(GraphObj *graph, Tensor input, Tensor output, Shape dims)
|
|||
IT_ASSERT(checkValid(graph));
|
||||
}
|
||||
|
||||
optional<vector<Shape>> ReshapeObj::inferShape(const TensorVec &inputs) const {
|
||||
size_t size = 1;
|
||||
for (size_t i = 0; i < dims.size(); ++i) {
|
||||
size *= dims.at(i);
|
||||
optional<vector<Shape>> ReshapeObj::inferShape(const TensorVec &inputs) {
|
||||
int count = 0;
|
||||
for (auto x : dims) {
|
||||
if (x == -1) {
|
||||
count++;
|
||||
}
|
||||
IT_ASSERT(x == -1 || x >= 0);
|
||||
}
|
||||
IT_ASSERT(size == inputs[0]->size());
|
||||
IT_ASSERT(count == 0 || count == 1);
|
||||
auto inputShape = inputs[0]->getDims();
|
||||
int size = inputs[0]->size();
|
||||
int index = -1;
|
||||
outputShape = dims;
|
||||
for (int i = 0; i < (int)dims.size(); ++i) {
|
||||
if (dims[i] == 0) {
|
||||
outputShape[i] = inputShape[i];
|
||||
}
|
||||
if (dims[i] == -1) {
|
||||
index = i;
|
||||
}
|
||||
}
|
||||
if (index != -1) {
|
||||
outputShape[index] =
|
||||
size / (-std::accumulate(outputShape.begin(), outputShape.end(), 1,
|
||||
[](auto acc, auto x) { return acc * x; }));
|
||||
}
|
||||
int outputSize = std::accumulate(outputShape.begin(), outputShape.end(), 1,
|
||||
[](auto acc, auto x) { return acc * x; });
|
||||
IT_ASSERT(outputSize == size);
|
||||
|
||||
return {{dims}};
|
||||
return {{outputShape}};
|
||||
}
|
||||
|
||||
std::string ReshapeObj::toString() const {
|
||||
|
@ -22,7 +46,7 @@ std::string ReshapeObj::toString() const {
|
|||
os << "Reshape[" << getGuid() << "]";
|
||||
os << "(";
|
||||
os << vecToString(inputs[0]->getDims()) << ",";
|
||||
os << "dims=" << vecToString(dims) << ",";
|
||||
os << "outputShape=" << vecToString(outputShape) << ",";
|
||||
os << "input=" << inputs[0]->getGuid() << ",";
|
||||
os << "output=" << outputs[0]->getGuid() << ")";
|
||||
return os.str();
|
||||
|
@ -30,12 +54,12 @@ std::string ReshapeObj::toString() const {
|
|||
|
||||
vector<int> ReshapeObj::getWorkloadVector() const {
|
||||
vector<int> ret = inputs[0]->getDims();
|
||||
ret.insert(ret.end(), dims.begin(), dims.end());
|
||||
ret.insert(ret.end(), outputShape.begin(), outputShape.end());
|
||||
ret.emplace(ret.begin(), type.underlying());
|
||||
return ret;
|
||||
}
|
||||
vector<int> ReshapeObj::getOpAttrVector() const {
|
||||
vector<int> ret = dims;
|
||||
vector<int> ret = outputShape;
|
||||
ret.emplace(ret.begin(), type.underlying());
|
||||
return ret;
|
||||
}
|
||||
|
@ -47,7 +71,7 @@ FlattenObj::FlattenObj(GraphObj *graph, Tensor input, Tensor output, int _axis)
|
|||
IT_ASSERT(checkValid(graph));
|
||||
}
|
||||
|
||||
optional<vector<Shape>> FlattenObj::inferShape(const TensorVec &inputs) const {
|
||||
optional<vector<Shape>> FlattenObj::inferShape(const TensorVec &inputs) {
|
||||
int sizeB = 1, sizeE = 1;
|
||||
auto dims = getInputs(0)->getDims();
|
||||
int rank = getInputs(0)->getRank();
|
||||
|
@ -84,7 +108,7 @@ IdentityObj::IdentityObj(GraphObj *graph, Tensor input, Tensor output)
|
|||
IT_ASSERT(checkValid(graph));
|
||||
}
|
||||
|
||||
optional<vector<Shape>> IdentityObj::inferShape(const TensorVec &inputs) const {
|
||||
optional<vector<Shape>> IdentityObj::inferShape(const TensorVec &inputs) {
|
||||
return {{getInputs(0)->getDims()}};
|
||||
}
|
||||
|
||||
|
|
|
@ -206,7 +206,7 @@ float ResizeObj::round_int(float x) const {
|
|||
}
|
||||
|
||||
// output shape is related to sizes/scales value.
|
||||
optional<vector<Shape>> ResizeObj::inferShape(const TensorVec &inputs) const {
|
||||
optional<vector<Shape>> ResizeObj::inferShape(const TensorVec &inputs) {
|
||||
auto inDims = inputs[0]->getDims();
|
||||
Shape ret = inDims;
|
||||
int rank = inputs[0]->getRank();
|
||||
|
|
|
@ -62,7 +62,7 @@ SliceObj::SliceObj(GraphObj *graph, Tensor input, Tensor output,
|
|||
IT_ASSERT(checkValid(graph));
|
||||
}
|
||||
|
||||
optional<vector<Shape>> SliceObj::inferShape(const TensorVec &inputs) const {
|
||||
optional<vector<Shape>> SliceObj::inferShape(const TensorVec &inputs) {
|
||||
Shape ans;
|
||||
ans.reserve(axes.size());
|
||||
for (const auto &range : axes) {
|
||||
|
|
|
@ -35,7 +35,7 @@ SplitObj::SplitObj(GraphObj *graph, Tensor input,
|
|||
IT_ASSERT(checkValid(graph));
|
||||
}
|
||||
|
||||
optional<vector<Shape>> SplitObj::inferShape(const TensorVec &inputs) const {
|
||||
optional<vector<Shape>> SplitObj::inferShape(const TensorVec &inputs) {
|
||||
IT_ASSERT(num != -1 && ratio.size() != 0);
|
||||
auto inputDims = inputs[0]->getDims();
|
||||
int totalSize = inputDims.at(dim);
|
||||
|
|
|
@ -16,8 +16,7 @@ TransposeObj::TransposeObj(GraphObj *graph, Tensor input, Tensor output,
|
|||
IT_ASSERT(checkValid(graph));
|
||||
}
|
||||
|
||||
optional<vector<Shape>>
|
||||
TransposeObj::inferShape(const TensorVec &inputs) const {
|
||||
optional<vector<Shape>> TransposeObj::inferShape(const TensorVec &inputs) {
|
||||
const auto A = inputs[0];
|
||||
auto input_dim = A->getDims();
|
||||
auto output_dim = input_dim;
|
||||
|
@ -66,8 +65,7 @@ DepthToSpaceObj::DepthToSpaceObj(GraphObj *graph, Tensor input, Tensor output,
|
|||
IT_ASSERT(checkValid(graph));
|
||||
}
|
||||
|
||||
optional<vector<Shape>>
|
||||
DepthToSpaceObj::inferShape(const TensorVec &inputs) const {
|
||||
optional<vector<Shape>> DepthToSpaceObj::inferShape(const TensorVec &inputs) {
|
||||
const auto A = inputs[0];
|
||||
auto inputDim = A->getDims();
|
||||
IT_ASSERT(inputDim.size() == 4);
|
||||
|
|
|
@ -6,7 +6,7 @@ UnaryObj::UnaryObj(OpType type, GraphObj *graph, Tensor input, Tensor output)
|
|||
IT_ASSERT(checkValid(graph));
|
||||
}
|
||||
|
||||
optional<vector<Shape>> UnaryObj::inferShape(const TensorVec &inputs) const {
|
||||
optional<vector<Shape>> UnaryObj::inferShape(const TensorVec &inputs) {
|
||||
const auto A = inputs[0];
|
||||
return {{A->getDims()}};
|
||||
}
|
||||
|
@ -37,7 +37,7 @@ ClipObj::ClipObj(GraphObj *graph, Tensor input, Tensor output,
|
|||
IT_ASSERT(checkValid(graph));
|
||||
}
|
||||
|
||||
optional<vector<Shape>> ClipObj::inferShape(const TensorVec &inputs) const {
|
||||
optional<vector<Shape>> ClipObj::inferShape(const TensorVec &inputs) {
|
||||
const auto A = inputs[0];
|
||||
return {{A->getDims()}};
|
||||
}
|
||||
|
@ -68,7 +68,7 @@ HardtanhObj::HardtanhObj(GraphObj *graph, Tensor input, Tensor output,
|
|||
IT_ASSERT(checkValid(graph));
|
||||
}
|
||||
|
||||
optional<vector<Shape>> HardtanhObj::inferShape(const TensorVec &inputs) const {
|
||||
optional<vector<Shape>> HardtanhObj::inferShape(const TensorVec &inputs) {
|
||||
const auto A = inputs[0];
|
||||
return {{A->getDims()}};
|
||||
}
|
||||
|
@ -97,7 +97,7 @@ FillObj::FillObj(GraphObj *graph, Tensor input, Tensor output, float value)
|
|||
IT_ASSERT(checkValid(graph));
|
||||
}
|
||||
|
||||
optional<vector<Shape>> FillObj::inferShape(const TensorVec &inputs) const {
|
||||
optional<vector<Shape>> FillObj::inferShape(const TensorVec &inputs) {
|
||||
const auto A = inputs[0];
|
||||
return {{A->getDims()}};
|
||||
}
|
||||
|
@ -124,7 +124,7 @@ L2LossObj::L2LossObj(GraphObj *graph, Tensor input, Tensor output)
|
|||
IT_ASSERT(checkValid(graph));
|
||||
}
|
||||
|
||||
optional<vector<Shape>> L2LossObj::inferShape(const TensorVec &inputs) const {
|
||||
optional<vector<Shape>> L2LossObj::inferShape(const TensorVec &inputs) {
|
||||
Shape temp = {1};
|
||||
return {{temp}};
|
||||
}
|
||||
|
@ -159,7 +159,7 @@ vector<DataType> CastObj::inferDataType(const TensorVec &inputs) const {
|
|||
return vector(numOutputs(), output_dataType);
|
||||
}
|
||||
|
||||
optional<vector<Shape>> CastObj::inferShape(const TensorVec &inputs) const {
|
||||
optional<vector<Shape>> CastObj::inferShape(const TensorVec &inputs) {
|
||||
const auto A = inputs[0];
|
||||
return {{A->getDims()}};
|
||||
}
|
||||
|
@ -241,7 +241,7 @@ ShapeObj::ShapeObj(GraphObj *graph, Tensor input, Tensor output)
|
|||
IT_ASSERT(checkValid(graph));
|
||||
}
|
||||
|
||||
optional<vector<Shape>> ShapeObj::inferShape(const TensorVec &inputs) const {
|
||||
optional<vector<Shape>> ShapeObj::inferShape(const TensorVec &inputs) {
|
||||
return {{{static_cast<int>(inputs[0]->getRank())}}};
|
||||
}
|
||||
|
||||
|
@ -257,7 +257,7 @@ PReluObj::PReluObj(GraphObj *graph, Tensor input, Tensor alpha, Tensor output)
|
|||
IT_ASSERT(checkValid(graph));
|
||||
}
|
||||
|
||||
optional<vector<Shape>> PReluObj::inferShape(const TensorVec &inputs) const {
|
||||
optional<vector<Shape>> PReluObj::inferShape(const TensorVec &inputs) {
|
||||
const auto A = inputs[0];
|
||||
return {{A->getDims()}};
|
||||
}
|
||||
|
@ -286,7 +286,7 @@ LogObj::LogObj(GraphObj *graph, Tensor input, Tensor output, LogType type)
|
|||
IT_ASSERT(checkValid(graph));
|
||||
}
|
||||
|
||||
optional<vector<Shape>> LogObj::inferShape(const TensorVec &inputs) const {
|
||||
optional<vector<Shape>> LogObj::inferShape(const TensorVec &inputs) {
|
||||
const auto A = inputs[0];
|
||||
return {{A->getDims()}};
|
||||
}
|
||||
|
|
|
@ -10,7 +10,7 @@ WhereObj::WhereObj(GraphObj *graph, Tensor inputX, Tensor inputY,
|
|||
IT_ASSERT(checkValid(graph));
|
||||
}
|
||||
|
||||
optional<vector<Shape>> WhereObj::inferShape(const TensorVec &inputs) const {
|
||||
optional<vector<Shape>> WhereObj::inferShape(const TensorVec &inputs) {
|
||||
auto shapeX = inputs[0]->getDims();
|
||||
auto shapeY = inputs[1]->getDims();
|
||||
auto shapeCon = inputs[2]->getDims();
|
||||
|
|
|
@ -158,4 +158,33 @@ TEST(Concat, CudaHigh) {
|
|||
12., 13., 14., 15., 16., 17., 1., 1., 1., 1., 1., 1.,
|
||||
18., 19., 20., 21., 22., 23., 1., 1., 1., 1., 1., 1.}));
|
||||
}
|
||||
|
||||
TEST(ConcatToIdentity, Cuda) {
|
||||
Runtime runtime = NativeCpuRuntimeObj::getInstance();
|
||||
Graph gCpu = make_ref<GraphObj>(runtime);
|
||||
|
||||
auto t1 = gCpu->addTensor({2, 2, 3, 1}, DataType::Float32);
|
||||
auto t2 = gCpu->addTensor({0}, DataType::Float32);
|
||||
gCpu->dataMalloc();
|
||||
t1->setData(IncrementalGenerator());
|
||||
t2->setData(OneGenerator());
|
||||
|
||||
auto cudaRuntime = make_ref<CudaRuntimeObj>();
|
||||
Graph gCuda = make_ref<GraphObj>(cudaRuntime);
|
||||
|
||||
auto t1Gpu = gCuda->cloneTensor(t1);
|
||||
auto t2Gpu = gCuda->cloneTensor(t2);
|
||||
|
||||
auto op = gCuda->addOp<ConcatObj>(TensorVec{t1Gpu, t2Gpu}, nullptr, 2);
|
||||
gCuda->dataMalloc();
|
||||
t1Gpu->setData(IncrementalGenerator());
|
||||
t2Gpu->setData(OneGenerator());
|
||||
cudaRuntime->run(gCuda);
|
||||
|
||||
// cudaPrintTensor(op->getOutput());
|
||||
// copy output from CUDA to CPU
|
||||
auto oCpu = gCpu->cloneTensor(op->getOutput());
|
||||
EXPECT_TRUE(
|
||||
oCpu->equalData(vector<float>{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}));
|
||||
}
|
||||
} // namespace infini
|
||||
|
|
|
@ -14,4 +14,13 @@ TEST(Concat, ShapeInfer) {
|
|||
EXPECT_EQ(op->getOutput()->getDims(), (Shape{1, 3, 2, 9}));
|
||||
}
|
||||
|
||||
TEST(Concat, ShapeInfer1) {
|
||||
Runtime runtime = NativeCpuRuntimeObj::getInstance();
|
||||
Graph g = make_ref<GraphObj>(runtime);
|
||||
auto t1 = g->addTensor({1, 3, 2, 4}, DataType::Float32);
|
||||
auto t2 = g->addTensor({0}, DataType::Float32);
|
||||
|
||||
auto op = g->addOp<ConcatObj>(TensorVec{t1, t2}, nullptr, 3);
|
||||
EXPECT_EQ(op->getOutput()->getDims(), (Shape{1, 3, 2, 4}));
|
||||
}
|
||||
} // namespace infini
|
||||
|
|
Loading…
Reference in New Issue