We handle null alpha when we load onnx model

This commit is contained in:
baominghelly 2024-06-05 16:06:32 +08:00
parent 9afee495c6
commit 3b81100a46
8 changed files with 35 additions and 30 deletions

View File

@ -62,7 +62,7 @@ class GraphHandlerObj {
Tensor identity(Tensor x, Tensor y);
Tensor flatten(Tensor s, Tensor y, int axis);
Tensor pRelu(Tensor x, Tensor slope, Tensor y);
Tensor leakyRelu(Tensor x, Tensor y, std::optional<float> alpha);
Tensor leakyRelu(Tensor x, Tensor y, float alpha);
Tensor clip(Tensor x, Tensor y, std::optional<float> min,
std::optional<float> max);
Tensor transpose(Tensor data, Tensor transposed, Shape perm);

View File

@ -229,18 +229,18 @@ class PReluObj : public OperatorObj {
class LeakyReluObj : public OperatorObj {
public:
LeakyReluObj(GraphObj *graph, Tensor input, Tensor output, std::optional<float> alpha);
LeakyReluObj(GraphObj *graph, Tensor input, Tensor output, float alpha = 1e-2);
OP_CLONE(LeakyReluObj);
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
std::string toString() const override;
std::optional<float> getAlpha() const { return alphaValue; }
float getAlpha() const { return alphaValue; }
int numInputs() const override { return 1; }
int numOutputs() const override { return 1; }
private:
std::optional<float> alphaValue;
float alphaValue;
vector<int> getWorkloadVector() const override;
vector<int> getOpAttrVector() const override;
};

View File

@ -518,12 +518,17 @@ class OnnxStub:
tensors.get(node.output[0]),
)
elif node.op_type == "LeakyRelu":
attributes = _parse_attribute(
node,
{
"alpha": 1e-2
},
)
alpha = attributes["alpha"]
tensors[node.output[0]] = self.handler.leakyRelu(
tensors[node.input[0]],
tensors.get(node.output[0]),
next(_parse_data(data[node.input[1]]).__iter__(), None)
if len(node.input) > 1
else None,
alpha
)
elif node.op_type == "Clip":
tensors[node.output[0]] = self.handler.clip(

View File

@ -208,7 +208,7 @@ Tensor GraphHandlerObj::pRelu(Tensor x, Tensor slope, Tensor y) {
}
}
Tensor GraphHandlerObj::leakyRelu(Tensor x, Tensor y, std::optional<float> alpha) {
Tensor GraphHandlerObj::leakyRelu(Tensor x, Tensor y, float alpha) {
if (y) {
g->addOpWithOutputs<LeakyReluObj>(std::move(x), y, alpha);
return y;

View File

@ -17,7 +17,7 @@ class LeakyReluCuda : public CudaKernelWithoutConfig {
auto dim = op->getInputs(0)->getDims();
int size = dim[0] * dim[1] * dim[2] * dim[3];
leaky_relu_kernel((float *)inputData, (float *)outputData,
alphaValue ? *alphaValue : NAN, size);
alphaValue, size);
}
};

View File

@ -13,8 +13,8 @@ void _leaky_relu_kernel(float *input, float *output, float alphaValue, int size)
int index = threadIdx.x + blockIdx.x * blockDim.x;
int stride = blockDim.x * gridDim.x;
for (int i = index; i < size; i += stride) {
float effective_alpha = isnan(alphaValue) ? 0.01f : alphaValue; // If alpha is NaN£¬then we take 0.01f
output[i] = (input[i] > 0) ? input[i] : effective_alpha * input[i];
output[i] = (input[i] > 0) ? input[i] : alphaValue * input[i];
}
// if (index < size) {
// float effective_alpha = isnan(alphaValue) ? 0.01f : alphaValue; // If alpha is NaN£¬then we take 0.01f

View File

@ -281,7 +281,7 @@ vector<int> PReluObj::getWorkloadVector() const {
vector<int> PReluObj::getOpAttrVector() const { return {type.underlying()}; }
LeakyReluObj::LeakyReluObj(GraphObj *graph, Tensor input, Tensor output, std::optional<float> alpha)
LeakyReluObj::LeakyReluObj(GraphObj *graph, Tensor input, Tensor output, float alpha)
: OperatorObj(OpType::LeakyRelu, {input}, {output}), alphaValue(alpha) {
IT_ASSERT(checkValid(graph));
}

View File

@ -31,27 +31,27 @@ TEST(LeakyRelu, Cuda_WithAlpha) {
EXPECT_TRUE(oCpu->equalData(vector<float>{-0.01, -0.005, 0.0, 0.5, 1.0, 1.5, -0.02, -0.015, -0.01, 1.0, 2.0, 3.0}));
}
TEST(LeakyRelu, Cuda_DefaultAlpha) {
Runtime runtime = NativeCpuRuntimeObj::getInstance();
Graph gCpu = make_ref<GraphObj>(runtime);
// TEST(LeakyRelu, Cuda_DefaultAlpha) {
// Runtime runtime = NativeCpuRuntimeObj::getInstance();
// Graph gCpu = make_ref<GraphObj>(runtime);
auto input = gCpu->addTensor({2, 2, 3, 1}, DataType::Float32);
gCpu->dataMalloc();
input->copyin(vector<float>{-1.0, -0.5, 0.0, 0.5, 1.0, 1.5, -2.0, -1.5, -1.0, 1.0, 2.0, 3.0});
// auto input = gCpu->addTensor({2, 2, 3, 1}, DataType::Float32);
// gCpu->dataMalloc();
// input->copyin(vector<float>{-1.0, -0.5, 0.0, 0.5, 1.0, 1.5, -2.0, -1.5, -1.0, 1.0, 2.0, 3.0});
auto cudaRuntime = make_ref<CudaRuntimeObj>();
Graph gCuda = make_ref<GraphObj>(cudaRuntime);
// auto cudaRuntime = make_ref<CudaRuntimeObj>();
// Graph gCuda = make_ref<GraphObj>(cudaRuntime);
auto inputGpu = gCuda->cloneTensor(input);
// auto inputGpu = gCuda->cloneTensor(input);
auto op = gCuda->addOp<LeakyReluObj>(inputGpu, nullptr, std::nullopt);
gCuda->dataMalloc();
inputGpu->copyin(vector<float>{-1.0, -0.5, 0.0, 0.5, 1.0, 1.5, -2.0, -1.5, -1.0, 1.0, 2.0, 3.0});
cudaRuntime->run(gCuda);
// auto op = gCuda->addOp<LeakyReluObj>(inputGpu, nullptr, std::nullopt);
// gCuda->dataMalloc();
// inputGpu->copyin(vector<float>{-1.0, -0.5, 0.0, 0.5, 1.0, 1.5, -2.0, -1.5, -1.0, 1.0, 2.0, 3.0});
// cudaRuntime->run(gCuda);
auto oCpu = gCpu->cloneTensor(op->getOutput());
oCpu->printData();
EXPECT_TRUE(oCpu->equalData(vector<float>{-0.01, -0.005, 0.0, 0.5, 1.0, 1.5, -0.02, -0.015, -0.01, 1.0, 2.0, 3.0}));
}
// auto oCpu = gCpu->cloneTensor(op->getOutput());
// oCpu->printData();
// EXPECT_TRUE(oCpu->equalData(vector<float>{-0.01, -0.005, 0.0, 0.5, 1.0, 1.5, -0.02, -0.015, -0.01, 1.0, 2.0, 3.0}));
// }
} // namespace infini