2023-08-29 16:06:52 +08:00
|
|
|
#include "core/graph.h"
|
|
|
|
#include "core/runtime.h"
|
|
|
|
#include "cuda/cuda_runtime.h"
|
|
|
|
#include "cuda/cuda_utility.h"
|
|
|
|
#include "operators/where.h"
|
|
|
|
|
|
|
|
#include "test.h"
|
|
|
|
|
|
|
|
namespace infini {
|
|
|
|
|
|
|
|
void test_where(const Shape &inputxshape, const vector<float> &inputxdata,
|
|
|
|
const Shape &inputyshape, const vector<float> &inputydata,
|
2023-09-14 10:45:57 +08:00
|
|
|
const Shape &conditionshape,
|
|
|
|
const vector<uint8_t> &conditiondata,
|
2023-08-29 16:06:52 +08:00
|
|
|
const vector<float> &ExpectData) {
|
|
|
|
Runtime runtime = NativeCpuRuntimeObj::getInstance();
|
|
|
|
Graph gCpu = make_ref<GraphObj>(runtime);
|
2023-09-14 10:45:57 +08:00
|
|
|
auto condition = gCpu->addTensor(conditionshape, DataType::UInt8);
|
2023-08-29 16:06:52 +08:00
|
|
|
auto inputx = gCpu->addTensor(inputxshape, DataType::Float32);
|
|
|
|
auto inputy = gCpu->addTensor(inputyshape, DataType::Float32);
|
|
|
|
|
|
|
|
gCpu->dataMalloc();
|
|
|
|
condition->copyin(conditiondata); //
|
|
|
|
inputx->copyin(inputxdata);
|
|
|
|
inputy->copyin(inputydata); //
|
|
|
|
|
|
|
|
auto cudaRuntime = make_ref<CudaRuntimeObj>();
|
|
|
|
Graph gCuda = make_ref<GraphObj>(cudaRuntime);
|
|
|
|
|
|
|
|
auto conditionGpu = gCuda->cloneTensor(condition);
|
|
|
|
auto inputxGpu = gCuda->cloneTensor(inputx);
|
|
|
|
auto inputyGpu = gCuda->cloneTensor(inputy);
|
|
|
|
|
|
|
|
auto op = gCuda->addOp<WhereObj>(inputxGpu, inputyGpu, conditionGpu,
|
|
|
|
nullptr); // WhereObj
|
|
|
|
gCuda->dataMalloc();
|
|
|
|
conditionGpu->copyin(conditiondata);
|
|
|
|
inputxGpu->copyin(inputxdata);
|
|
|
|
inputyGpu->copyin(inputydata);
|
|
|
|
cudaRuntime->run(gCuda);
|
|
|
|
|
|
|
|
auto oCpu = gCpu->cloneTensor(op->getOutput()); // move data from gpu to cpu
|
|
|
|
oCpu->printData(); //->printData
|
|
|
|
EXPECT_TRUE(oCpu->equalData(ExpectData));
|
|
|
|
}
|
|
|
|
|
|
|
|
TEST(CUDA_Where, run) {
|
|
|
|
test_where(
|
|
|
|
Shape{2, 2, 3, 1}, vector<float>{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11},
|
|
|
|
Shape{2, 2, 3, 1}, vector<float>{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
|
2023-09-14 10:45:57 +08:00
|
|
|
Shape{2, 2, 3, 1}, vector<uint8_t>{0, 1, 1, 0, 0, 0, 1, 1, 0, 1, 1, 1},
|
2023-08-29 16:06:52 +08:00
|
|
|
vector<float>{0., 1., 2., 0., 0., 0., 6., 7., 0., 9., 10., 11.});
|
|
|
|
|
|
|
|
test_where(Shape{2, 1, 1, 3}, // inputx
|
|
|
|
vector<float>{0, 1, 2, 3, 4, 5}, Shape{1, 2, 1, 1}, // inputy
|
|
|
|
vector<float>{1, 1}, Shape{2, 1, 3, 1}, // condition
|
2023-09-14 10:45:57 +08:00
|
|
|
vector<uint8_t>{0, 1, 1, 0, 0, 0},
|
2023-08-29 16:06:52 +08:00
|
|
|
vector<float>{1., 1., 1., 0., 1., 2., 0., 1., 2., 1., 1., 1.,
|
|
|
|
0., 1., 2., 0., 1., 2., 1., 1., 1., 1., 1., 1.,
|
|
|
|
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.});
|
2023-09-14 10:45:57 +08:00
|
|
|
test_where(
|
|
|
|
Shape{
|
|
|
|
3,
|
|
|
|
},
|
|
|
|
vector<float>{0, 1, 2}, // inputX
|
|
|
|
Shape{2, 3, 1}, vector<float>{0, 1, 2, 3, 4, 5}, // inputY
|
|
|
|
Shape{2, 1, 3, 1}, vector<uint8_t>{0, 1, 1, 0, 0, 0}, // condition
|
|
|
|
vector<float>{0., 0., 0., 0., 1., 2., 0., 1., 2., 3., 3., 3.,
|
|
|
|
0., 1., 2., 0., 1., 2., 0., 0., 0., 1., 1., 1.,
|
|
|
|
2., 2., 2., 3., 3., 3., 4., 4., 4., 5., 5., 5.});
|
|
|
|
test_where(
|
|
|
|
Shape{
|
|
|
|
3,
|
|
|
|
},
|
|
|
|
vector<float>{0, 1, 2}, // inputX
|
|
|
|
Shape{2, 3, 1}, vector<float>{0, 1, 2, 3, 4, 5}, // inputY
|
|
|
|
Shape{2, 1, 3, 1},
|
|
|
|
vector<uint8_t>{false, true, true, false, false, false}, // condition
|
|
|
|
vector<float>{0., 0., 0., 0., 1., 2., 0., 1., 2., 3., 3., 3.,
|
|
|
|
0., 1., 2., 0., 1., 2., 0., 0., 0., 1., 1., 1.,
|
|
|
|
2., 2., 2., 3., 3., 3., 4., 4., 4., 5., 5., 5.});
|
2023-08-29 16:06:52 +08:00
|
|
|
|
|
|
|
} // python output
|
|
|
|
|
|
|
|
} // namespace infini
|