modified tensor.h

xgqdut2016 2024-02-23 02:19:30 +00:00
parent 66d98a3f04
commit e5d5085e6a
4 changed files with 57 additions and 37 deletions


@@ -145,14 +145,17 @@ class TensorObj : public TensorBaseObj {
void printData() const;
bool equalData(const Tensor &rhs, double relativeError = 1e-6) const;
template <typename T> bool equalData(const vector<T> &dataVector) {
template <typename T>
bool equalData(const vector<T> &dataVector, double relativeError = 1e-6) {
IT_ASSERT(size() == dataVector.size());
if (dtype == DataType::Float16) {
return equalDataImpl_fp16(getRawDataPtr<uint16_t *>(),
(float *)dataVector.data(), size());
(float *)dataVector.data(), size(),
relativeError);
}
IT_ASSERT(DataType::get<T>() == dtype.cpuTypeInt());
return equalDataImpl(getRawDataPtr<T *>(), dataVector.data(), size());
return equalDataImpl(getRawDataPtr<T *>(), dataVector.data(), size(),
relativeError);
}
size_t getOffsetByBroadcastOffset(size_t bcOffset, Shape bcShape) const;
@@ -198,24 +201,34 @@ class TensorObj : public TensorBaseObj {
if (a[i] != b[i])
return false;
} else if constexpr (std::is_floating_point_v<T>) {
if (fabs(a[i] - b[i]) / std::max(fabs(a[i]), fabs(b[i])) >
relativeError) {
printf("Error on %lu: %f %f\n", i, a[i], b[i]);
return false;
if (fabs(b[i]) < 1e-6) {
if (fabs(a[i] - b[i]) > relativeError) {
printf("Error on %lu: %f %f\n", i, a[i], b[i]);
return false;
}
} else {
if (fabs(a[i] - b[i]) /
(std::max(fabs(a[i]), fabs(b[i])) + 1e-6) >
relativeError) {
printf("Error on %lu: %f %f\n", i, a[i], b[i]);
return false;
}
}
} else
static_assert(!sizeof(T), "Unsupported data type");
}
return true;
}
bool equalDataImpl_fp16(const uint16_t *a, const float *b,
size_t size) const {
bool equalDataImpl_fp16(const uint16_t *a, const float *b, size_t size,
double relativeError = 1e-6) const {
for (size_t i = 0; i < size; ++i) {
auto a_fp32 = fp16_to_float(a[i]);
auto b_fp32 = b[i];
if (fabs(a_fp32 - b_fp32) / std::max(fabs(a_fp32), fabs(b_fp32)) >
1e-6) {
if (fabs(a_fp32 - b_fp32) /
(std::max(fabs(a_fp32), fabs(b_fp32)) + 1e-6) >
relativeError) {
printf("Error on %lu: %f %f\n", i, a_fp32, b_fp32);
return false;
}
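
The comparison above switches to a hybrid tolerance: when the reference value is near zero a relative error is meaningless, so an absolute check is used instead, and the relative branch adds 1e-6 to the denominator to avoid dividing by zero. Below is a minimal standalone sketch of that check (a hypothetical free function, not part of the repository), with a tiny driver:

// Sketch of the hybrid absolute/relative tolerance used in equalDataImpl above.
#include <algorithm>
#include <cmath>
#include <cstdio>

bool nearlyEqual(double a, double b, double relativeError = 1e-6) {
    if (std::fabs(b) < 1e-6) {
        // Reference is (almost) zero: compare the absolute difference instead.
        return std::fabs(a - b) <= relativeError;
    }
    // Symmetric relative error; the +1e-6 keeps the denominator away from zero.
    return std::fabs(a - b) / (std::max(std::fabs(a), std::fabs(b)) + 1e-6) <=
           relativeError;
}

int main() {
    std::printf("%d\n", nearlyEqual(1.0000005, 1.0));  // 1: within 1e-6 relative
    std::printf("%d\n", nearlyEqual(5e-7, 0.0));       // 1: absolute branch
    std::printf("%d\n", nearlyEqual(1.01, 1.0, 1e-3)); // 0: outside tolerance
    return 0;
}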


@@ -1,8 +1,8 @@
#ifndef BANG_KERNELS_DIVOPERATION_DIV_H_
#define BANG_KERNELS_DIVOPERATION_DIV_H_
#ifndef BANG_KERNELS_SOFTMAXOPERATION_SOFTMAX_H_
#define BANG_KERNELS_SOFTMAXOPERATION_SOFTMAX_H_
__mlu_global__ void softmaxUnion1(float *mlu_destination, float *mlu_src,
int nDim, int axis, int othersize,
int frontsize, int dimsize, int stride);
#endif // BANG_KERNELS_DIVOPERATION_DIV_H_
#endif // BANG_KERNELS_SOFTMAXOPERATION_SOFTMAX_H_


@@ -45,7 +45,7 @@ __mlu_device__ void softmaxKernelAxis_m(float* destination, float* source, int f
__bang_add(tmpSum, tmpSum, src, maxNum);//sum += exp(x - M)
__memcpy(tmpOldMax, tmpNewMax, maxNum * sizeof(float), NRAM2NRAM);//oldM = newM
}
__bang_active_recip(tmpSum, tmpSum, maxNum);//compute 1/sum
__bang_active_reciphp(tmpSum, tmpSum, maxNum);//compute 1/sum
//start the exponential transform and write back to GDRAM
__bang_mul(src, src, tmpSum, maxNum);//once the loop above finishes, the data stored in src can be reused
__memcpy(destination + (dimsize - 1) * stride + frontIdx + j * maxNum, src, maxNum * sizeof(float), NRAM2GDRAM);
@@ -75,7 +75,7 @@ __mlu_device__ void softmaxKernelAxis_m(float* destination, float* source, int f
__memcpy(tmpOldMax, tmpNewMax, maxNum * sizeof(float), NRAM2NRAM);//oldM = newM
}
//-------------------
__bang_active_recip(tmpSum, tmpSum, maxNum);//compute 1/sum
__bang_active_reciphp(tmpSum, tmpSum, maxNum);//compute 1/sum
//start the exponential transform and write back to GDRAM
__bang_mul(src, src, tmpSum, maxNum);//once the loop above finishes, the data stored in src can be reused
__memcpy(destination + (dimsize - 1) * stride + frontIdx + repeat * maxNum, src, remain * sizeof(float), NRAM2GDRAM);
@@ -157,7 +157,7 @@ __mlu_device__ void softmaxKernelAxis_m(float* destination, float* source, int f
//at this point tmpNewMax stores the maximum of the data for the fixed frontIdx and behindsize, and tmpSum stores the corresponding sum
//__bang_printf("tmpOldMax[0]:%.2f,tmpSum[0]:%.2f\n", tmpNewMax[2],tmpSum[2]);
__bang_active_recip(tmpSum, tmpSum, strideS);
__bang_active_reciphp(tmpSum, tmpSum, strideS);
//__bang_printf("tmpOldMax[0]:%.2f,tmpSum[0]:%.2f\n", tmpNewMax[2],tmpSum[2]);
if(remain){
for(int m = 0; m < remain; m++){
@@ -225,7 +225,7 @@ __mlu_device__ void softmaxKernelAxis_m(float* destination, float* source, int f
__bang_add(tmpSum, tmpSum, tmp, strideS);//sum += exp(x - M)
__memcpy(tmpOldMax, tmpNewMax, stride * sizeof(float), NRAM2NRAM);//oldM = newM
}
__bang_active_recip(tmpSum, tmpSum, strideS);
__bang_active_reciphp(tmpSum, tmpSum, strideS);
__bang_mul(tmp, tmp, tmpSum, strideS);//once the loop above finishes, the data stored in tmp can be reused
//__memcpy(destination + tid + m * behindsize + (dimsize - 1) * stride, tmp, stride * sizeof(float), NRAM2GDRAM);
__memcpy(src + m * behindsize + (dimsize - 1) * stride, tmp, stride * sizeof(float), NRAM2NRAM);
@@ -262,7 +262,7 @@ __mlu_device__ void softmaxKernelAxis_m(float* destination, float* source, int f
__memcpy(tmpOldMax, tmpNewMax, stride * sizeof(float), NRAM2NRAM);//oldM = newM
}
//__bang_printf("max:%.2f,%.2f, sum:%.2f,sum:%.2f\n", tmpNewMax[0], tmpNewMax[1], tmpSum[0], tmpSum[0]);
__bang_active_recip(tmpSum, tmpSum, strideS);
__bang_active_reciphp(tmpSum, tmpSum, strideS);
__bang_mul(tmp, tmp, tmpSum, strideS);//once the loop above finishes, the data stored in tmp can be reused
//__memcpy(destination + tid + m * behindsize + (dimsize - 1) * stride, tmp, stride * sizeof(float), NRAM2GDRAM);
__memcpy(src + m * behindsize + (dimsize - 1) * stride, tmp, stride * sizeof(float), NRAM2NRAM);
@@ -473,7 +473,7 @@ __mlu_device__ void softmaxKernelAxis_s(float* destination, float* source, int o
__bang_add(tmpSum, tmpSum, src, maxNum);//sum += exp(x - M)
__memcpy(tmpOldMax, tmpNewMax, maxNum * sizeof(float), NRAM2NRAM);//oldM = newM
}
__bang_active_recip(tmpSum, tmpSum, maxNum);//compute 1/sum
__bang_active_reciphp(tmpSum, tmpSum, maxNum);//compute 1/sum
//start the exponential transform and write back to GDRAM
__bang_mul(src, src, tmpSum, maxNum);//once the loop above finishes, the data stored in src can be reused
__memcpy(destination + (dimsize - 1) * stride + indStart + j * maxNum, src, maxNum * sizeof(float), NRAM2GDRAM);
@@ -505,7 +505,7 @@ __mlu_device__ void softmaxKernelAxis_s(float* destination, float* source, int o
__memcpy(tmpOldMax, tmpNewMax, maxNum * sizeof(float), NRAM2NRAM);//oldM = newM
}
__bang_active_recip(tmpSum, tmpSum, maxNum);//compute 1/sum
__bang_active_reciphp(tmpSum, tmpSum, maxNum);//compute 1/sum
//start the exponential transform and write back to GDRAM
__bang_mul(src, src, tmpSum, maxNum);//once the loop above finishes, the data stored in src can be reused
__memcpy(destination + (dimsize - 1) * stride + indStart + repeat * maxNum, src, remainNram * sizeof(float), NRAM2GDRAM);
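
The kernel keeps a running maximum and a running sum of exponentials and only multiplies by the reciprocal of the sum at the end; the edit swaps __bang_active_recip for __bang_active_reciphp, presumably the higher-precision reciprocal variant. The following is a plain CPU sketch of the same online-softmax recurrence (an illustration only: 1-D input, no BANG intrinsics, ordinary division standing in for the reciprocal), not the kernel itself:

// CPU sketch of the online softmax recurrence described by the kernel comments.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

std::vector<float> onlineSoftmax(const std::vector<float> &x) {
    float oldM = -INFINITY, sum = 0.0f;
    for (float v : x) {
        float newM = std::max(oldM, v);     // newM = max(oldM, x)
        sum = sum * std::exp(oldM - newM)   // rescale the old partial sum
              + std::exp(v - newM);         // sum += exp(x - M)
        oldM = newM;                        // oldM = newM
    }
    std::vector<float> y(x.size());
    float inv = 1.0f / sum;                 // the 1/sum (reciprocal) step
    for (size_t i = 0; i < x.size(); ++i)
        y[i] = std::exp(x[i] - oldM) * inv; // exp(x - M) * (1/sum)
    return y;
}

int main() {
    for (float v : onlineSoftmax({0.f, 1.f, 2.f, 3.f}))
        std::printf("%f ", v); // ~0.032059 0.087144 0.236883 0.643914
    std::printf("\n");
    return 0;
}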


@@ -6,7 +6,7 @@
#include "test.h"
#include <cmath>
namespace infini {
double eps = 3e-3;
TEST(cuDNN_Softmax, run_axis1) {
// Runtime
Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
@@ -28,7 +28,8 @@ TEST(cuDNN_Softmax, run_axis1) {
// Check
EXPECT_TRUE(outputGpu2Cpu->equalData(
vector<float>{0.032058604, 0.08714432, 0.23688284, 0.6439143,
0.032058604, 0.08714432, 0.23688284, 0.6439143}));
0.032058604, 0.08714432, 0.23688284, 0.6439143},
eps));
}
TEST(cuDNN_Softmax, run_axis0) {
@@ -50,8 +51,8 @@ TEST(cuDNN_Softmax, run_axis0) {
auto outputGpu = gpuOp->getOutput();
auto outputGpu2Cpu = outputGpu->clone(cpuRuntime);
// Check
EXPECT_TRUE(
outputGpu2Cpu->equalData(vector<float>{0., 0., 0., 0., 1, 1, 1, 1}));
EXPECT_TRUE(outputGpu2Cpu->equalData(
vector<float>{0., 0., 0., 0., 1, 1, 1, 1}, eps));
}
TEST(cuDNN_Softmax2, run_axis1) {
@@ -73,10 +74,12 @@ TEST(cuDNN_Softmax2, run_axis1) {
auto outputGpu = gpuOp->getOutput();
auto outputGpu2Cpu = outputGpu->clone(cpuRuntime);
// Check
EXPECT_TRUE(outputGpu2Cpu->equalData(vector<float>{
0.0179862, 0.0179862, 0.0179862, 0.0179862, 0.9820138, 0.9820138,
0.9820138, 0.9820138, 0.0179862, 0.0179862, 0.0179862, 0.0179862,
0.9820138, 0.9820138, 0.9820138, 0.9820138}));
EXPECT_TRUE(outputGpu2Cpu->equalData(
vector<float>{0.0179862, 0.0179862, 0.0179862, 0.0179862, 0.9820138,
0.9820138, 0.9820138, 0.9820138, 0.0179862, 0.0179862,
0.0179862, 0.0179862, 0.9820138, 0.9820138, 0.9820138,
0.9820138},
eps));
}
TEST(cuDNN_Softmax2, run_axis2) {
@@ -98,10 +101,12 @@ TEST(cuDNN_Softmax2, run_axis2) {
auto outputGpu = gpuOp->getOutput();
auto outputGpu2Cpu = outputGpu->clone(cpuRuntime);
// Check
EXPECT_TRUE(outputGpu2Cpu->equalData(vector<float>{
0.1192029, 0.1192029, 0.8807971, 0.8807971, 0.1192029, 0.1192029,
0.8807971, 0.8807971, 0.1192029, 0.1192029, 0.8807971, 0.8807971,
0.1192029, 0.1192029, 0.8807971, 0.8807971}));
EXPECT_TRUE(outputGpu2Cpu->equalData(
vector<float>{0.1192029, 0.1192029, 0.8807971, 0.8807971, 0.1192029,
0.1192029, 0.8807971, 0.8807971, 0.1192029, 0.1192029,
0.8807971, 0.8807971, 0.1192029, 0.1192029, 0.8807971,
0.8807971},
eps));
}
TEST(cuDNN_Softmax2, run_axis3) {
@@ -123,9 +128,11 @@ TEST(cuDNN_Softmax2, run_axis3) {
auto outputGpu = gpuOp->getOutput();
auto outputGpu2Cpu = outputGpu->clone(cpuRuntime);
// Check
EXPECT_TRUE(outputGpu2Cpu->equalData(vector<float>{
0.2689414, 0.7310586, 0.2689414, 0.7310586, 0.2689414, 0.7310586,
0.2689414, 0.7310586, 0.2689414, 0.7310586, 0.2689414, 0.7310586,
0.2689414, 0.7310586, 0.2689414, 0.7310586}));
EXPECT_TRUE(outputGpu2Cpu->equalData(
vector<float>{0.2689414, 0.7310586, 0.2689414, 0.7310586, 0.2689414,
0.7310586, 0.2689414, 0.7310586, 0.2689414, 0.7310586,
0.2689414, 0.7310586, 0.2689414, 0.7310586, 0.2689414,
0.7310586},
eps));
}
} // namespace infini
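
For reference, the expected vectors in the cuDNN_Softmax2 tests follow directly from the input layout: assuming the (not shown) input holds the values 0..15 in a 2x2x2x2 tensor, the two elements reduced over axis 1, 2, or 3 differ by 4, 2, or 1 respectively, so each pair softmaxes to [1/(1+e^d), e^d/(1+e^d)]. A small standalone check (not part of the test suite):

#include <cmath>
#include <cstdio>

int main() {
    // d is the gap between the two values reduced over the softmax axis.
    for (double d : {4.0, 2.0, 1.0}) {
        double lo = 1.0 / (1.0 + std::exp(d));
        std::printf("d=%.0f -> %.7f %.7f\n", d, lo, 1.0 - lo);
    }
    // d=4 -> 0.0179862 0.9820138   (run_axis1)
    // d=2 -> 0.1192029 0.8807971   (run_axis2)
    // d=1 -> 0.2689414 0.7310586   (run_axis3)
    return 0;
}

The looser eps = 3e-3 passed to equalData then presumably gives the single-precision kernel results enough slack against these reference constants.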