forked from jiuyuan/InfiniTensor
ADD:extend operator and cuda kernel. (#40)
Co-authored-by: Haojie Wang <haojie0429@gmail.com>
This commit is contained in:
parent
fe14c91f54
commit
26cee55e81
|
@ -0,0 +1,23 @@
|
||||||
|
#pragma once
|
||||||
|
#include "core/operator.h"
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
class ExtendObj : public OperatorObj {
|
||||||
|
int dim, num; // copy num times at the dim.
|
||||||
|
|
||||||
|
public:
|
||||||
|
ExtendObj(GraphObj *graph, Tensor input, Tensor output, int dim,
|
||||||
|
int num = 1);
|
||||||
|
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||||
|
|
||||||
|
std::string toString() const override;
|
||||||
|
int numInputs() const override { return 1; }
|
||||||
|
int numOutputs() const override { return 1; }
|
||||||
|
int getDim() const { return dim; }
|
||||||
|
int getNum() const { return num; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
vector<int> getWorkloadVector() const override;
|
||||||
|
vector<int> getOpAttrVector() const override;
|
||||||
|
};
|
||||||
|
} // namespace infini
|
|
@ -0,0 +1,27 @@
|
||||||
|
#include "operators/extend.h"
|
||||||
|
#include "cuda/cuda_kernel_wihtout_config.h"
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
void extend_kernel(float *in, float *out, int blockSize, int blockSizeOuter,
|
||||||
|
int oSize);
|
||||||
|
class ExtendCuda : public CudaKernelWithoutConfig {
|
||||||
|
void compute(const Operator &_op,
|
||||||
|
const RuntimeObj *_context) const override {
|
||||||
|
auto op = as<ExtendObj>(_op);
|
||||||
|
auto inData = op->getInputs(0)->getRawDataPtr<float *>();
|
||||||
|
auto outData = op->getOutputs()[0]->getRawDataPtr<float *>();
|
||||||
|
int blockSize = 1;
|
||||||
|
auto iDim = op->getInputs(0)->getDims();
|
||||||
|
for (size_t i = iDim.size() - 1;
|
||||||
|
i >= (size_t)op->getDim() && i != (size_t)-1; --i)
|
||||||
|
blockSize *= iDim[i];
|
||||||
|
auto blockSizeOuter = (op->getNum() + 1) * blockSize;
|
||||||
|
|
||||||
|
extend_kernel(inData, outData, blockSize, blockSizeOuter,
|
||||||
|
op->getOutput()->size());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
REGISTER_KERNEL(Device::CUDA, OpType::Extend, DataType::Float32, ExtendCuda,
|
||||||
|
"Extend_CUDA_Float32");
|
||||||
|
} // namespace infini
|
|
@ -0,0 +1,25 @@
|
||||||
|
#include "cuda/cuda_common.h"
|
||||||
|
|
||||||
|
__global__ void _extend_kernel(float *in, float *out, int blockSize,
|
||||||
|
int blockSizeOuter, int oSize) {
|
||||||
|
int index = threadIdx.x + blockIdx.x * blockDim.x;
|
||||||
|
if (index >= oSize)
|
||||||
|
return;
|
||||||
|
|
||||||
|
int stride = blockDim.x * gridDim.x;
|
||||||
|
while (index < oSize) {
|
||||||
|
auto iIdx = index % blockSize + index / blockSizeOuter * blockSize;
|
||||||
|
out[index] = in[iIdx];
|
||||||
|
index += stride;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
void extend_kernel(float *in, float *out, int blockSize, int blockSizeOuter,
|
||||||
|
int oSize) {
|
||||||
|
int blocksize = 32 * 16;
|
||||||
|
int gridsize = (oSize + blocksize - 1) / blocksize;
|
||||||
|
_extend_kernel<<<blocksize, gridsize>>>(in, out, blockSize, blockSizeOuter,
|
||||||
|
oSize);
|
||||||
|
}
|
||||||
|
} // namespace infini
|
|
@ -0,0 +1,41 @@
|
||||||
|
#include "operators/extend.h"
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
|
||||||
|
ExtendObj::ExtendObj(GraphObj *graph, Tensor input, Tensor output, int dim,
|
||||||
|
int num)
|
||||||
|
: OperatorObj(OpType::Extend, {input}, {output}), dim(dim), num(num) {
|
||||||
|
IT_ASSERT(checkValid(graph));
|
||||||
|
}
|
||||||
|
|
||||||
|
optional<vector<Shape>> ExtendObj::inferShape(const TensorVec &inputs) const {
|
||||||
|
auto ret = inputs[0]->getDims();
|
||||||
|
IT_ASSERT((size_t)dim < ret.size());
|
||||||
|
ret[dim] = ret[dim] * (num + 1);
|
||||||
|
return {{ret}};
|
||||||
|
}
|
||||||
|
std::string ExtendObj::toString() const {
|
||||||
|
std::ostringstream os;
|
||||||
|
os << "Extend[" << getGuid() << "]";
|
||||||
|
os << "(";
|
||||||
|
os << vecToString(inputs[0]->getDims()) << ",";
|
||||||
|
os << "dim=" << dim << ",";
|
||||||
|
os << "num=" << num << ",";
|
||||||
|
os << "input=" << inputs[0]->getGuid() << ",";
|
||||||
|
os << "output=" << outputs[0]->getGuid() << ")";
|
||||||
|
return os.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
vector<int> ExtendObj::getWorkloadVector() const {
|
||||||
|
vector<int> ret = inputs[0]->getDims();
|
||||||
|
ret.emplace_back(dim);
|
||||||
|
ret.emplace_back(num);
|
||||||
|
ret.emplace(ret.begin(), enum_to_underlying(type));
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
vector<int> ExtendObj::getOpAttrVector() const {
|
||||||
|
return {enum_to_underlying(type), dim, num};
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace infini
|
|
@ -0,0 +1,43 @@
|
||||||
|
#include "core/graph.h"
|
||||||
|
#include "core/kernel.h"
|
||||||
|
#include "core/runtime.h"
|
||||||
|
#include "cuda/cuda_runtime.h"
|
||||||
|
#include "cuda/cuda_utility.h"
|
||||||
|
#include "operators/extend.h"
|
||||||
|
|
||||||
|
#include "test.h"
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
|
||||||
|
TEST(CUDA_Extend, run) {
|
||||||
|
Runtime cpuRuntime = CpuRuntimeObj::getInstance();
|
||||||
|
auto cudaRuntime = make_ref<CudaRuntimeObj>();
|
||||||
|
|
||||||
|
// Build input data on CPU
|
||||||
|
Tensor icpu =
|
||||||
|
make_ref<TensorObj>(Shape{2, 3, 2, 2}, DataType::Float32, cpuRuntime);
|
||||||
|
icpu->dataMalloc();
|
||||||
|
icpu->setData(IncrementalGenerator());
|
||||||
|
|
||||||
|
// Build CUDA graph
|
||||||
|
Graph g = make_ref<GraphObj>(cudaRuntime);
|
||||||
|
auto i = g->cloneTensor(icpu);
|
||||||
|
auto op = g->addOp<ExtendObj>(i, nullptr, 1, 1);
|
||||||
|
|
||||||
|
// allocate CUDA memory
|
||||||
|
g->dataMalloc();
|
||||||
|
|
||||||
|
// Execute on CUDA
|
||||||
|
cudaRuntime->run(g);
|
||||||
|
|
||||||
|
// clone CUDA output to CPU
|
||||||
|
auto o = op->getOutput();
|
||||||
|
auto ocpu = o->clone(cpuRuntime);
|
||||||
|
|
||||||
|
// check results on CPU
|
||||||
|
EXPECT_TRUE(ocpu->equalData(vector<float>{
|
||||||
|
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3,
|
||||||
|
4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19,
|
||||||
|
20, 21, 22, 23, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}));
|
||||||
|
}
|
||||||
|
} // namespace infini
|
|
@ -0,0 +1,20 @@
|
||||||
|
#include "core/graph.h"
|
||||||
|
#include "core/kernel.h"
|
||||||
|
#include "core/runtime.h"
|
||||||
|
#include "operators/extend.h"
|
||||||
|
|
||||||
|
#include "test.h"
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
|
||||||
|
TEST(Extend, ShapeInference) {
|
||||||
|
Runtime runtime = CpuRuntimeObj::getInstance();
|
||||||
|
{
|
||||||
|
Graph g = make_ref<GraphObj>(runtime);
|
||||||
|
Tensor i = g->addTensor({2, 3, 3, 4}, DataType::Float32);
|
||||||
|
auto op = g->addOp<ExtendObj>(i, nullptr, 2, 1);
|
||||||
|
EXPECT_EQ(op->getOutput()->getDims(), (Shape{2, 3, 6, 4}));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace infini
|
Loading…
Reference in New Issue