Compare commits

...

1 Commits

Author SHA1 Message Date
wanghailu 7382a94243 add code for training solution 2022-10-10 17:11:41 +08:00
2 changed files with 65 additions and 0 deletions

21
include/core/layer.h Normal file
View File

@ -0,0 +1,21 @@
#pragma once
#include "core/tensor.h"
#include "core/graph.h"
#include "core/runtime.h"
#include "core/kernel.h"
namespace infini {
class Convolution {
private:
Tensor input;
Tensor weight;
Tensor output;
Tensor dInput;
Tensor dWeight;
Tensor dOutput;
public:
Convolution(Tensor input_, int pad, int window, int stride, int num);
Tensor forward();
Tensor backward();
};
}

44
src/core/layer.cc Normal file
View File

@ -0,0 +1,44 @@
#include "core/layer.h"
namespace infini {
Convolution::Convolution(Tensor input_,
int pad,
int window,
int stride,
int num) {
Runtime cpuRuntime = CpuRuntimeObj::getInstance();
// layout NCHW
input = input_;
Shape inputShape = input_->getDims();
int inputN = inputShape[0];
int inputC = inputShape[1];
int inputH = inputShape[2];
int inputW = inputShape[3];
Shape weightShape = {num, inputC, window, window};
weight = make_ref<TensorObj>(weightShape, DataType::Float32, cpuRuntime);
weight->dataMalloc();
int outputN = inputN;
int outputC = num;
int outputH = (inputH + 2 * pad - window) / stride + 1;
int outputW = (inputW + 2 * pad - window) / stride + 1;
Shape outputShape = {outputN, outputC, outputH, outputW};
output = make_ref<TensorObj>(outputShape, DataType::Float32, cpuRuntime);
output->dataMalloc();
// backward
dInput = make_ref<TensorObj>(inputShape, DataType::Float32, cpuRuntime);
dInput->dataMalloc();
dWeight = make_ref<TensorObj>(weightShape, DataType::Float32, cpuRuntime);
dWeight->dataMalloc();
dOutput = make_ref<TensorObj>(outputShape, DataType::Float32, cpuRuntime);
dOutput->dataMalloc();
}
Tensor Convolution::forward() {
return output;
}
Tensor Convolution::backward() {
return dInput;
}
} // namespace infini