2023-08-29 16:06:52 +08:00
|
|
|
#include "core/graph.h"
|
|
|
|
#include "core/runtime.h"
|
|
|
|
#include "cuda/cuda_runtime.h"
|
|
|
|
#include "cuda/cuda_utility.h"
|
|
|
|
#include "operators/where.h"
|
|
|
|
|
|
|
|
#include "test.h"
|
|
|
|
|
|
|
|
namespace infini {
|
|
|
|
|
2023-11-06 08:56:23 +08:00
|
|
|
void test_where(const Shape &inputXShape, const vector<float> &inputXData,
|
|
|
|
const Shape &inputYShape, const vector<float> &inputYData,
|
|
|
|
const Shape &conditionShape,
|
|
|
|
const vector<uint8_t> &conditionData,
|
2023-08-29 16:06:52 +08:00
|
|
|
const vector<float> &ExpectData) {
|
|
|
|
Runtime runtime = NativeCpuRuntimeObj::getInstance();
|
|
|
|
Graph gCpu = make_ref<GraphObj>(runtime);
|
2023-11-06 08:56:23 +08:00
|
|
|
auto condition = gCpu->addTensor(conditionShape, DataType::UInt8);
|
|
|
|
auto inputX = gCpu->addTensor(inputXShape, DataType::Float32);
|
|
|
|
auto inputY = gCpu->addTensor(inputYShape, DataType::Float32);
|
2023-08-29 16:06:52 +08:00
|
|
|
|
|
|
|
gCpu->dataMalloc();
|
2023-11-06 08:56:23 +08:00
|
|
|
condition->copyin(conditionData); //
|
|
|
|
inputX->copyin(inputXData);
|
|
|
|
inputY->copyin(inputYData); //
|
2023-08-29 16:06:52 +08:00
|
|
|
|
|
|
|
auto cudaRuntime = make_ref<CudaRuntimeObj>();
|
|
|
|
Graph gCuda = make_ref<GraphObj>(cudaRuntime);
|
|
|
|
|
|
|
|
auto conditionGpu = gCuda->cloneTensor(condition);
|
2023-11-06 08:56:23 +08:00
|
|
|
auto inputXGpu = gCuda->cloneTensor(inputX);
|
|
|
|
auto inputYGpu = gCuda->cloneTensor(inputY);
|
2023-08-29 16:06:52 +08:00
|
|
|
|
2023-11-06 08:56:23 +08:00
|
|
|
auto op = gCuda->addOp<WhereObj>(inputXGpu, inputYGpu, conditionGpu,
|
2023-08-29 16:06:52 +08:00
|
|
|
nullptr); // WhereObj
|
|
|
|
gCuda->dataMalloc();
|
2023-11-06 08:56:23 +08:00
|
|
|
conditionGpu->copyin(conditionData);
|
|
|
|
inputXGpu->copyin(inputXData);
|
|
|
|
inputYGpu->copyin(inputYData);
|
2023-08-29 16:06:52 +08:00
|
|
|
cudaRuntime->run(gCuda);
|
|
|
|
|
2023-11-06 08:56:23 +08:00
|
|
|
auto oCpu = gCpu->cloneTensor(op->getOutput()); // move Data from gpu to cpu
|
2023-08-29 16:06:52 +08:00
|
|
|
oCpu->printData(); //->printData
|
|
|
|
EXPECT_TRUE(oCpu->equalData(ExpectData));
|
|
|
|
}
|
|
|
|
|
|
|
|
TEST(CUDA_Where, run) {
|
|
|
|
test_where(
|
|
|
|
Shape{2, 2, 3, 1}, vector<float>{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11},
|
|
|
|
Shape{2, 2, 3, 1}, vector<float>{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
|
2023-09-14 10:45:57 +08:00
|
|
|
Shape{2, 2, 3, 1}, vector<uint8_t>{0, 1, 1, 0, 0, 0, 1, 1, 0, 1, 1, 1},
|
2023-08-29 16:06:52 +08:00
|
|
|
vector<float>{0., 1., 2., 0., 0., 0., 6., 7., 0., 9., 10., 11.});
|
|
|
|
|
|
|
|
test_where(Shape{2, 1, 1, 3}, // inputx
|
|
|
|
vector<float>{0, 1, 2, 3, 4, 5}, Shape{1, 2, 1, 1}, // inputy
|
|
|
|
vector<float>{1, 1}, Shape{2, 1, 3, 1}, // condition
|
2023-09-14 10:45:57 +08:00
|
|
|
vector<uint8_t>{0, 1, 1, 0, 0, 0},
|
2023-08-29 16:06:52 +08:00
|
|
|
vector<float>{1., 1., 1., 0., 1., 2., 0., 1., 2., 1., 1., 1.,
|
|
|
|
0., 1., 2., 0., 1., 2., 1., 1., 1., 1., 1., 1.,
|
|
|
|
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.});
|
2023-09-14 10:45:57 +08:00
|
|
|
test_where(
|
|
|
|
Shape{
|
|
|
|
3,
|
|
|
|
},
|
|
|
|
vector<float>{0, 1, 2}, // inputX
|
|
|
|
Shape{2, 3, 1}, vector<float>{0, 1, 2, 3, 4, 5}, // inputY
|
|
|
|
Shape{2, 1, 3, 1}, vector<uint8_t>{0, 1, 1, 0, 0, 0}, // condition
|
|
|
|
vector<float>{0., 0., 0., 0., 1., 2., 0., 1., 2., 3., 3., 3.,
|
|
|
|
0., 1., 2., 0., 1., 2., 0., 0., 0., 1., 1., 1.,
|
|
|
|
2., 2., 2., 3., 3., 3., 4., 4., 4., 5., 5., 5.});
|
|
|
|
test_where(
|
|
|
|
Shape{
|
|
|
|
3,
|
|
|
|
},
|
|
|
|
vector<float>{0, 1, 2}, // inputX
|
|
|
|
Shape{2, 3, 1}, vector<float>{0, 1, 2, 3, 4, 5}, // inputY
|
|
|
|
Shape{2, 1, 3, 1},
|
|
|
|
vector<uint8_t>{false, true, true, false, false, false}, // condition
|
|
|
|
vector<float>{0., 0., 0., 0., 1., 2., 0., 1., 2., 3., 3., 3.,
|
|
|
|
0., 1., 2., 0., 1., 2., 0., 0., 0., 1., 1., 1.,
|
|
|
|
2., 2., 2., 3., 3., 3., 4., 4., 4., 5., 5., 5.});
|
2023-08-29 16:06:52 +08:00
|
|
|
|
|
|
|
} // python output
|
|
|
|
|
|
|
|
} // namespace infini
|