From 7382a94243034f548cff86751afd8784642ee5d4 Mon Sep 17 00:00:00 2001
From: wanghailu
Date: Mon, 10 Oct 2022 17:11:41 +0800
Subject: [PATCH] add code for training solution

---
 include/core/layer.h | 32 ++++++++++++++++++++++++++++++++
 src/core/layer.cc    | 45 +++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 77 insertions(+)
 create mode 100644 include/core/layer.h
 create mode 100644 src/core/layer.cc

diff --git a/include/core/layer.h b/include/core/layer.h
new file mode 100644
index 00000000..c5391b08
--- /dev/null
+++ b/include/core/layer.h
@@ -0,0 +1,32 @@
+#pragma once
+#include "core/tensor.h"
+#include "core/graph.h"
+#include "core/runtime.h"
+#include "core/kernel.h"
+
+namespace infini {
+/// 2D convolution layer for the training path. Construction allocates the
+/// weight, output and gradient tensors on the CPU runtime; forward() and
+/// backward() are placeholders that only hand back pre-allocated buffers.
+class Convolution {
+  private:
+    Tensor input;   ///< Input activations supplied by the caller.
+    Tensor weight;  ///< Filters, shape {num, C, window, window}.
+    Tensor output;  ///< Forward result buffer.
+    Tensor dInput;  ///< Gradient buffer shaped like input.
+    Tensor dWeight; ///< Gradient buffer shaped like weight.
+    Tensor dOutput; ///< Gradient buffer shaped like output.
+
+  public:
+    /// @param input_ Input tensor; dims are read as {N, C, H, W}.
+    /// @param pad    Padding counted twice (both sides) along H and W.
+    /// @param window Square kernel extent.
+    /// @param stride Step between windows; must be non-zero.
+    /// @param num    Number of filters (output channels).
+    Convolution(Tensor input_, int pad, int window, int stride, int num);
+    /// Returns the output buffer; no convolution is computed yet.
+    Tensor forward();
+    /// Returns the input-gradient buffer; no gradient is computed yet.
+    Tensor backward();
+};
+} // namespace infini
diff --git a/src/core/layer.cc b/src/core/layer.cc
new file mode 100644
index 00000000..824892ba
--- /dev/null
+++ b/src/core/layer.cc
@@ -0,0 +1,45 @@
+#include "core/layer.h"
+
+namespace infini {
+// Derives the weight/output shapes from the NCHW input and allocates every
+// buffer (forward and gradient) as Float32 on the CPU runtime.
+// Assumes input_ has at least four dims and stride != 0 (neither is
+// checked here).
+Convolution::Convolution(Tensor input_, int pad, int window, int stride,
+                         int num) {
+    Runtime cpuRuntime = CpuRuntimeObj::getInstance();
+    // layout NCHW
+    input = input_;
+    Shape inputShape = input_->getDims();
+    int inputN = inputShape[0];
+    int inputC = inputShape[1];
+    int inputH = inputShape[2];
+    int inputW = inputShape[3];
+    Shape weightShape = {num, inputC, window, window};
+    // Name the pointee type explicitly; the arguments do not determine it.
+    weight = make_ref<TensorObj>(weightShape, DataType::Float32, cpuRuntime);
+    weight->dataMalloc();
+    int outputN = inputN;
+    int outputC = num;
+    // Convolution output extent; the integer division truncates.
+    int outputH = (inputH + 2 * pad - window) / stride + 1;
+    int outputW = (inputW + 2 * pad - window) / stride + 1;
+    Shape outputShape = {outputN, outputC, outputH, outputW};
+    output = make_ref<TensorObj>(outputShape, DataType::Float32, cpuRuntime);
+    output->dataMalloc();
+    // backward
+    dInput = make_ref<TensorObj>(inputShape, DataType::Float32, cpuRuntime);
+    dInput->dataMalloc();
+    dWeight = make_ref<TensorObj>(weightShape, DataType::Float32, cpuRuntime);
+    dWeight->dataMalloc();
+    dOutput = make_ref<TensorObj>(outputShape, DataType::Float32, cpuRuntime);
+    dOutput->dataMalloc();
+}
+
+// Placeholder: returns the pre-allocated output buffer without computing.
+Tensor Convolution::forward() { return output; }
+
+// Placeholder: returns the pre-allocated input-gradient buffer.
+Tensor Convolution::backward() { return dInput; }
+
+} // namespace infini