forked from jiuyuan/InfiniTensor
Compare commits
1 Commits
master
...
train_wang
Author | SHA1 | Date |
---|---|---|
![]() |
7382a94243 |
|
@ -0,0 +1,21 @@
|
||||||
|
#pragma once
|
||||||
|
#include "core/tensor.h"
|
||||||
|
#include "core/graph.h"
|
||||||
|
#include "core/runtime.h"
|
||||||
|
#include "core/kernel.h"
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
class Convolution {
|
||||||
|
private:
|
||||||
|
Tensor input;
|
||||||
|
Tensor weight;
|
||||||
|
Tensor output;
|
||||||
|
Tensor dInput;
|
||||||
|
Tensor dWeight;
|
||||||
|
Tensor dOutput;
|
||||||
|
public:
|
||||||
|
Convolution(Tensor input_, int pad, int window, int stride, int num);
|
||||||
|
Tensor forward();
|
||||||
|
Tensor backward();
|
||||||
|
};
|
||||||
|
}
|
|
@ -0,0 +1,44 @@
|
||||||
|
#include "core/layer.h"
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
Convolution::Convolution(Tensor input_,
|
||||||
|
int pad,
|
||||||
|
int window,
|
||||||
|
int stride,
|
||||||
|
int num) {
|
||||||
|
Runtime cpuRuntime = CpuRuntimeObj::getInstance();
|
||||||
|
// layout NCHW
|
||||||
|
input = input_;
|
||||||
|
Shape inputShape = input_->getDims();
|
||||||
|
int inputN = inputShape[0];
|
||||||
|
int inputC = inputShape[1];
|
||||||
|
int inputH = inputShape[2];
|
||||||
|
int inputW = inputShape[3];
|
||||||
|
Shape weightShape = {num, inputC, window, window};
|
||||||
|
weight = make_ref<TensorObj>(weightShape, DataType::Float32, cpuRuntime);
|
||||||
|
weight->dataMalloc();
|
||||||
|
int outputN = inputN;
|
||||||
|
int outputC = num;
|
||||||
|
int outputH = (inputH + 2 * pad - window) / stride + 1;
|
||||||
|
int outputW = (inputW + 2 * pad - window) / stride + 1;
|
||||||
|
Shape outputShape = {outputN, outputC, outputH, outputW};
|
||||||
|
output = make_ref<TensorObj>(outputShape, DataType::Float32, cpuRuntime);
|
||||||
|
output->dataMalloc();
|
||||||
|
// backward
|
||||||
|
dInput = make_ref<TensorObj>(inputShape, DataType::Float32, cpuRuntime);
|
||||||
|
dInput->dataMalloc();
|
||||||
|
dWeight = make_ref<TensorObj>(weightShape, DataType::Float32, cpuRuntime);
|
||||||
|
dWeight->dataMalloc();
|
||||||
|
dOutput = make_ref<TensorObj>(outputShape, DataType::Float32, cpuRuntime);
|
||||||
|
dOutput->dataMalloc();
|
||||||
|
}
|
||||||
|
|
||||||
|
Tensor Convolution::forward() {
|
||||||
|
return output;
|
||||||
|
}
|
||||||
|
|
||||||
|
Tensor Convolution::backward() {
|
||||||
|
return dInput;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace infini
|
Loading…
Reference in New Issue