forked from jiuyuan/InfiniTensor
add dropout
parent cada8ec6c8
commit 06f5e82d8b
@@ -101,8 +101,8 @@ class GraphHandlerObj {
                       std::string mode);
    Tensor lrn(Tensor input, Tensor output, float alpha, float beta, float bias,
               int size);
-   Tensor dropout(Tensor input, Tensor output, Tensor mask, float ratio,
-                  bool training_mode);
+   TensorVec dropout(Tensor input, Tensor output, Tensor mask, float ratio,
+                     bool training_mode);

    //------ modifiers

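The signature change above is the core of this commit: ONNX `Dropout` produces two outputs, the dropped-out data and an optional mask, so the handler now hands back a `TensorVec` holding both instead of a single `Tensor`. For reference, a minimal NumPy sketch of the inverted-dropout semantics those two outputs correspond to (illustrative only, not code from this repository; `dropout_reference` is a made-up name):

```python
import numpy as np

def dropout_reference(x, ratio=0.5, training_mode=False, seed=0):
    """ONNX-style Dropout reference: returns (output, mask)."""
    if not training_mode or ratio == 0.0:
        # Inference mode (or ratio 0): identity output, all-true mask.
        return x, np.ones_like(x, dtype=bool)
    rng = np.random.default_rng(seed)
    mask = rng.random(x.shape) >= ratio      # True marks kept elements
    scale = 1.0 / (1.0 - ratio)              # inverted-dropout scaling
    return x * mask * scale, mask

x = np.arange(6, dtype=np.float32).reshape(2, 3)
y, mask = dropout_reference(x, ratio=0.5, training_mode=True)
```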
@@ -0,0 +1,213 @@
I1120 17:13:29.288781 19265 init.cc:233] ENV [CUSTOM_DEVICE_ROOT]=/opt/py39/lib/python3.9/site-packages/paddle_custom_device
I1120 17:13:29.288885 19265 init.cc:142] Try loading custom device libs from: [/opt/py39/lib/python3.9/site-packages/paddle_custom_device]
I1120 17:13:29.337081 19265 custom_device.cc:1108] Successed in loading custom runtime in lib: /opt/py39/lib/python3.9/site-packages/paddle_custom_device/libpaddle-custom-mlu.so
I1120 17:13:29.340378 19265 custom_kernel.cc:63] Successed in loading 248 custom kernel(s) from loaded lib(s), will be used like native ones.
I1120 17:13:29.340472 19265 init.cc:154] Finished in LoadCustomDevice with libs_path: [/opt/py39/lib/python3.9/site-packages/paddle_custom_device]
I1120 17:13:29.340502 19265 init.cc:239] CustomDevice: mlu, visible devices count: 4
I1120 17:13:29.797426 19265 program_interpreter.cc:185] New Executor is Running.
I1120 17:13:33.000212 19265 interpreter_util.cc:608] Standalone Executor is Used.
Running verify PaddlePaddle program ...
PaddlePaddle works well on 1 mlu.
I1120 17:13:34.297436 19347 init.cc:233] ENV [CUSTOM_DEVICE_ROOT]=/opt/py39/lib/python3.9/site-packages/paddle_custom_device
I1120 17:13:34.297511 19347 init.cc:142] Try loading custom device libs from: [/opt/py39/lib/python3.9/site-packages/paddle_custom_device]
I1120 17:13:34.312176 19345 init.cc:233] ENV [CUSTOM_DEVICE_ROOT]=/opt/py39/lib/python3.9/site-packages/paddle_custom_device
I1120 17:13:34.312247 19345 init.cc:142] Try loading custom device libs from: [/opt/py39/lib/python3.9/site-packages/paddle_custom_device]
I1120 17:13:34.340984 19349 init.cc:233] ENV [CUSTOM_DEVICE_ROOT]=/opt/py39/lib/python3.9/site-packages/paddle_custom_device
I1120 17:13:34.341048 19349 init.cc:142] Try loading custom device libs from: [/opt/py39/lib/python3.9/site-packages/paddle_custom_device]
I1120 17:13:34.343869 19347 custom_device.cc:1108] Successed in loading custom runtime in lib: /opt/py39/lib/python3.9/site-packages/paddle_custom_device/libpaddle-custom-mlu.so
I1120 17:13:34.346643 19347 custom_kernel.cc:63] Successed in loading 248 custom kernel(s) from loaded lib(s), will be used like native ones.
I1120 17:13:34.346735 19347 init.cc:154] Finished in LoadCustomDevice with libs_path: [/opt/py39/lib/python3.9/site-packages/paddle_custom_device]
I1120 17:13:34.346762 19347 init.cc:239] CustomDevice: mlu, visible devices count: 4
I1120 17:13:34.353596 19351 init.cc:233] ENV [CUSTOM_DEVICE_ROOT]=/opt/py39/lib/python3.9/site-packages/paddle_custom_device
I1120 17:13:34.353648 19351 init.cc:142] Try loading custom device libs from: [/opt/py39/lib/python3.9/site-packages/paddle_custom_device]
I1120 17:13:34.359752 19345 custom_device.cc:1108] Successed in loading custom runtime in lib: /opt/py39/lib/python3.9/site-packages/paddle_custom_device/libpaddle-custom-mlu.so
I1120 17:13:34.362963 19345 custom_kernel.cc:63] Successed in loading 248 custom kernel(s) from loaded lib(s), will be used like native ones.
I1120 17:13:34.363058 19345 init.cc:154] Finished in LoadCustomDevice with libs_path: [/opt/py39/lib/python3.9/site-packages/paddle_custom_device]
I1120 17:13:34.363086 19345 init.cc:239] CustomDevice: mlu, visible devices count: 4
I1120 17:13:34.389118 19349 custom_device.cc:1108] Successed in loading custom runtime in lib: /opt/py39/lib/python3.9/site-packages/paddle_custom_device/libpaddle-custom-mlu.so
I1120 17:13:34.393033 19349 custom_kernel.cc:63] Successed in loading 248 custom kernel(s) from loaded lib(s), will be used like native ones.
I1120 17:13:34.393139 19349 init.cc:154] Finished in LoadCustomDevice with libs_path: [/opt/py39/lib/python3.9/site-packages/paddle_custom_device]
I1120 17:13:34.393169 19349 init.cc:239] CustomDevice: mlu, visible devices count: 4
I1120 17:13:34.402159 19351 custom_device.cc:1108] Successed in loading custom runtime in lib: /opt/py39/lib/python3.9/site-packages/paddle_custom_device/libpaddle-custom-mlu.so
I1120 17:13:34.405285 19351 custom_kernel.cc:63] Successed in loading 248 custom kernel(s) from loaded lib(s), will be used like native ones.
I1120 17:13:34.405380 19351 init.cc:154] Finished in LoadCustomDevice with libs_path: [/opt/py39/lib/python3.9/site-packages/paddle_custom_device]
I1120 17:13:34.405433 19351 init.cc:239] CustomDevice: mlu, visible devices count: 4
======================= Modified FLAGS detected =======================
FLAGS(name='FLAGS_allocator_strategy', current_value='naive_best_fit', default_value='auto_growth')
=======================================================================
I1120 17:13:34.764251 19347 tcp_utils.cc:107] Retry to connect to 127.0.0.1:59969 while the server is not yet listening.
======================= Modified FLAGS detected =======================
FLAGS(name='FLAGS_allocator_strategy', current_value='naive_best_fit', default_value='auto_growth')
=======================================================================
I1120 17:13:34.807813 19345 tcp_utils.cc:181] The server starts to listen on IP_ANY:59969
I1120 17:13:34.808122 19345 tcp_utils.cc:130] Successfully connected to 127.0.0.1:59969
======================= Modified FLAGS detected =======================
FLAGS(name='FLAGS_allocator_strategy', current_value='naive_best_fit', default_value='auto_growth')
=======================================================================
I1120 17:13:34.836859 19349 tcp_utils.cc:130] Successfully connected to 127.0.0.1:59969
======================= Modified FLAGS detected =======================
FLAGS(name='FLAGS_allocator_strategy', current_value='naive_best_fit', default_value='auto_growth')
=======================================================================
I1120 17:13:34.852916 19351 tcp_utils.cc:130] Successfully connected to 127.0.0.1:59969
I1120 17:13:37.764503 19347 tcp_utils.cc:130] Successfully connected to 127.0.0.1:59969
[11-20 17:13:40] [LOG_CNCL] [Info]: CNCL_LOG_LEVEL: INFO
[11-20 17:13:40] [LOG_CNCL] [Info]: CNCL_LOG_LEVEL: INFO
[11-20 17:13:40] [LOG_CNCL] [Info]: CNCL_LOG_LEVEL: INFO
[11-20 17:13:40] [LOG_CNCL] [Info]: CNCL_LOG_LEVEL: INFO
[17:13:40.675816 19345 19345 R0 D0 U92505159763813][CNLINC][error] [Internal] CNAPI error (code 100100, name CN_MEMORY_ERROR_OUT_OF_MEMORY) with driver reason: "device has no memory to alloc", when allocs dev mempool_base [cnlinc_dev_mem.cc:113(InitDevMemPool)]
[11-20 17:13:40] [LOG_CNCL] [Warning]: Failed to open libibverbs.so[.1]. If you want to use RDMA transport, you need add libibverbs.so path to LD_LIBRARY_PATH.
[11-20 17:13:40] [LOG_CNCL] [Warning]: Fail to load libibverbs.so!
[11-20 17:13:40] [LOG_CNCL] [Warning]: Failed to open libibverbs.so[.1]. If you want to use RDMA transport, you need add libibverbs.so path to LD_LIBRARY_PATH.
[11-20 17:13:40] [LOG_CNCL] [Warning]: Fail to load libibverbs.so!
[11-20 17:13:40] [LOG_CNCL] [Warning]: Failed to open libibverbs.so[.1]. If you want to use RDMA transport, you need add libibverbs.so path to LD_LIBRARY_PATH.
[11-20 17:13:40] [LOG_CNCL] [Warning]: Failed to open libibverbs.so[.1]. If you want to use RDMA transport, you need add libibverbs.so path to LD_LIBRARY_PATH.
[11-20 17:13:40] [LOG_CNCL] [Warning]: Fail to load libibverbs.so!
[11-20 17:13:40] [LOG_CNCL] [Warning]: Fail to load libibverbs.so!
[11-20 17:13:40] [LOG_CNCL] [Info]: Build 2 hierarchy rings. Hierarchy ring 0, horizontal layer 0: 0-->1-->0, connected by [MLU_LINK]
[11-20 17:13:40] [LOG_CNCL] [Info]: Build 2 hierarchy rings. Hierarchy ring 0, horizontal layer 1: 2-->3-->2, connected by [MLU_LINK]
[11-20 17:13:40] [LOG_CNCL] [Info]: Build 2 hierarchy rings. Hierarchy ring 0, vertical layer 0: 0-->2-->0, connected by [SHARED_PEER_MEM]
[11-20 17:13:40] [LOG_CNCL] [Info]: Build 2 hierarchy rings. Hierarchy ring 0, vertical layer 1: 1-->3-->1, connected by [SHARED_PEER_MEM]
[11-20 17:13:40] [LOG_CNCL] [Info]: Build 2 hierarchy rings. Hierarchy ring 1, horizontal layer 0: 0-->1-->0, connected by [MLU_LINK]
[11-20 17:13:40] [LOG_CNCL] [Info]: Build 2 hierarchy rings. Hierarchy ring 1, horizontal layer 1: 2-->3-->2, connected by [MLU_LINK]
[11-20 17:13:40] [LOG_CNCL] [Info]: Build 2 hierarchy rings. Hierarchy ring 1, vertical layer 0: 0-->2-->0, connected by [SHARED_PEER_MEM]
[11-20 17:13:40] [LOG_CNCL] [Info]: Build 2 hierarchy rings. Hierarchy ring 1, vertical layer 1: 1-->3-->1, connected by [SHARED_PEER_MEM]
[11-20 17:13:40] [LOG_CNCL] [Warning]: Can not build MLU-Link HAT ring, use PCIe or network instead.
[11-20 17:13:40] [LOG_CNCL] [Info]: Build 1 rings. Ring 0 (host 0, MLU 0): 3--[SHARED_PEER_MEM]->0--[SHARED_PEER_MEM]->1
[11-20 17:13:40] [LOG_CNCL] [Info]: Build 1 rings. Ring 0 (host 0, MLU 1): 0--[SHARED_PEER_MEM]->1--[SHARED_PEER_MEM]->2
[11-20 17:13:40] [LOG_CNCL] [Info]: Build 1 rings. Ring 0 (host 0, MLU 2): 1--[SHARED_PEER_MEM]->2--[SHARED_PEER_MEM]->3
[11-20 17:13:40] [LOG_CNCL] [Info]: Build 1 rings. Ring 0 (host 0, MLU 3): 2--[SHARED_PEER_MEM]->3--[SHARED_PEER_MEM]->0
2023-11-20 17:13:40.921650: [cnrtError] [19345] [Card : 0] Error occurred during calling 'cnMalloc' in CNDrv interface.
2023-11-20 17:13:40.921833: [cnrtError] [19345] [Card : 0] Return value is 100100, CN_MEMORY_ERROR_OUT_OF_MEMORY, means that "device has no memory to alloc"
2023-11-20 17:13:40.921852: [cnrtError] [19345] [Card : 0] cnrtMalloc: Malloc MLU device memory failed.
[11-20 17:13:40] [LOG_CNCL] [Error]: Runtime error, msg: insufficient MLU memory. [wrap_cnrt.cc:27 {pid: 19345}]
[11-20 17:13:40] [LOG_CNCL] [Error]: error msg: fail to MallocAddr for MluP2PBuff in shared P2P [share_peer_mem.cc:506 {pid: 19345}]
[11-20 17:13:40] [LOG_CNCL] [Error]: argument error: shm buffer [share_peer_mem.cc:218 {pid: 19345}]
[11-20 17:13:40] [LOG_CNCL] [Warning]: topology all_connected_05 @comm0 is not available.
[11-20 17:13:40] [LOG_CNCL] [Warning]: algorithm [all_connected_05] does not setup successfully @comm3.
[11-20 17:13:40] [LOG_CNCL] [Warning]: algorithm [all_connected_05] does not setup successfully @comm2.
[11-20 17:13:40] [LOG_CNCL] [Warning]: algorithm [all_connected_05] does not setup successfully @comm1.
[11-20 17:13:40] [LOG_CNCL] [Warning]: algorithm [all_connected_05] does not setup successfully @comm0.
[11-20 17:13:40] [LOG_CNCL] [Warning]: algorithm [all_connected_05] does not connect successfully @comm2.
[11-20 17:13:40] [LOG_CNCL] [Warning]: algorithm [all_connected_05] does not connect successfully @comm3.
[11-20 17:13:40] [LOG_CNCL] [Warning]: algorithm [all_connected_05] does not connect successfully @comm1.
[11-20 17:13:40] [LOG_CNCL] [Warning]: algorithm [all_connected_05] does not connect successfully @comm0.
[17:13:40.950185 19345 19345 R0 D0 U92505159763813][CNLINC][error] [Internal] CNAPI error (code 100102, name CN_MEMORY_ERROR_INVALID_ADDRESS) with driver reason: "Invalid device address", when 599 [legacy_param.cc:153(Init)]
[11-20 17:13:40] [LOG_CNCL] [Error]: Cnlinc error, msg: [CNLINC Driver error] Internal driver error occurs. DieBuf Reduce post wr failed in hierarchy_ring_00_horizontal_layer @comm:0 [ring_hat_die_buffer.cc:644 {pid: 19345}]
[11-20 17:13:40] [LOG_CNCL] [Error]: error msg: Reduce failed in hierarchy_ring_0 @comm:0 [hierarchy_ring.cc:539 {pid: 19345}]
[11-20 17:13:40] [LOG_CNCL] [Error]: error msg: AllReduce failed in hierarchy ring0 @comm:0 [hierarchy_ring.cc:1017 {pid: 19345}]
I1120 17:13:41.149036 19393 tcp_store.cc:273] receive shutdown event and so quit from MasterDaemon run loop
[11-20 17:13:41] [LOG_CNCL] [Warning]: There exist cnclComms that have not been destroyed yet. Make sure that all cnclComms have been destroyed or freed!


--------------------------------------
C++ Traceback (most recent call last):
--------------------------------------
0 paddle::distributed::ProcessGroupCustom::XCCLTask::Wait(std::chrono::duration<long, std::ratio<1l, 1000l> >)
1 phi::DeviceManager::SynchronizeDevice(phi::Place const&)
2 phi::CustomDevice::SynchronizeDevice(unsigned long)
3 SyncDevice(C_Device_st*)
4 cnrtSyncDevice
5 cnCtxSync

----------------------
Error Message Summary:
----------------------
FatalError: `Termination signal` is detected by the operating system.
[TimeInfo: *** Aborted at 1700471621 (unix time) try "date -d @1700471621" if you are using GNU date ***]
[SignalInfo: *** SIGTERM (@0x4b41) received by PID 19347 (TID 0x7f97d2bab740) from PID 19265 ***]



--------------------------------------
C++ Traceback (most recent call last):
--------------------------------------
0 paddle::distributed::ProcessGroupCustom::XCCLTask::Wait(std::chrono::duration<long, std::ratio<1l, 1000l> >)
1 phi::DeviceManager::SynchronizeDevice(phi::Place const&)
2 phi::CustomDevice::SynchronizeDevice(unsigned long)
3 SyncDevice(C_Device_st*)
4 cnrtSyncDevice
5 cnCtxSync

----------------------
Error Message Summary:
----------------------
FatalError: `Termination signal` is detected by the operating system.
[TimeInfo: *** Aborted at 1700471622 (unix time) try "date -d @1700471622" if you are using GNU date ***]
[SignalInfo: *** SIGTERM (@0x4b41) received by PID 19349 (TID 0x7f1fb9984740) from PID 19265 ***]



--------------------------------------
C++ Traceback (most recent call last):
--------------------------------------
0 paddle::distributed::ProcessGroupCustom::XCCLTask::Wait(std::chrono::duration<long, std::ratio<1l, 1000l> >)
1 phi::DeviceManager::SynchronizeDevice(phi::Place const&)
2 phi::CustomDevice::SynchronizeDevice(unsigned long)
3 SyncDevice(C_Device_st*)
4 cnrtSyncDevice
5 cnCtxSync

----------------------
Error Message Summary:
----------------------
FatalError: `Termination signal` is detected by the operating system.
[TimeInfo: *** Aborted at 1700471622 (unix time) try "date -d @1700471622" if you are using GNU date ***]
[SignalInfo: *** SIGTERM (@0x4b41) received by PID 19351 (TID 0x7f1721bba740) from PID 19265 ***]

[2023-11-20 17:13:42,983] [ WARNING] install_check.py:289 - PaddlePaddle meets some problem with 4 mlus. This may be caused by:
1. There is not enough GPUs visible on your system
2. Some GPUs are occupied by other process now
3. NVIDIA-NCCL2 is not installed correctly on your system. Please follow instruction on https://github.com/NVIDIA/nccl-tests
to test your NCCL, or reinstall it following https://docs.nvidia.com/deeplearning/sdk/nccl-install-guide/index.html
[2023-11-20 17:13:42,983] [ WARNING] install_check.py:299 -
Original Error is:

----------------------------------------------
Process 0 terminated with the following error:
----------------------------------------------

Traceback (most recent call last):
File "/opt/py39/lib/python3.9/site-packages/paddle/distributed/spawn.py", line 376, in _func_wrapper
result = func(*args)
File "/opt/py39/lib/python3.9/site-packages/paddle/utils/install_check.py", line 181, in train_for_run_parallel
paddle.distributed.init_parallel_env()
File "/opt/py39/lib/python3.9/site-packages/paddle/distributed/parallel.py", line 1101, in init_parallel_env
paddle.distributed.barrier(group=group)
File "/opt/py39/lib/python3.9/site-packages/paddle/distributed/communication/group.py", line 326, in barrier
task = group.process_group.barrier(device_id)
paddle.base.libpaddle.EnforceNotMet: (MLU CNCL error(1), Cncl internal error.. (at /workspace/workplace/PaddleCustomDevice/backends/mlu/runtime/runti)373)


Traceback (most recent call last):
File "<string>", line 1, in <module>
File "/opt/py39/lib/python3.9/site-packages/paddle/utils/install_check.py", line 304, in run_check
raise e
File "/opt/py39/lib/python3.9/site-packages/paddle/utils/install_check.py", line 283, in run_check
_run_parallel(device_list)
File "/opt/py39/lib/python3.9/site-packages/paddle/utils/install_check.py", line 210, in _run_parallel
paddle.distributed.spawn(train_for_run_parallel, nprocs=len(device_list))
File "/opt/py39/lib/python3.9/site-packages/paddle/distributed/spawn.py", line 618, in spawn
while not context.join():
File "/opt/py39/lib/python3.9/site-packages/paddle/distributed/spawn.py", line 427, in join
self._throw_exception(error_index)
File "/opt/py39/lib/python3.9/site-packages/paddle/distributed/spawn.py", line 451, in _throw_exception
raise Exception(msg)
Exception:

----------------------------------------------
Process 0 terminated with the following error:
----------------------------------------------

Traceback (most recent call last):
File "/opt/py39/lib/python3.9/site-packages/paddle/distributed/spawn.py", line 376, in _func_wrapper
result = func(*args)
File "/opt/py39/lib/python3.9/site-packages/paddle/utils/install_check.py", line 181, in train_for_run_parallel
paddle.distributed.init_parallel_env()
File "/opt/py39/lib/python3.9/site-packages/paddle/distributed/parallel.py", line 1101, in init_parallel_env
paddle.distributed.barrier(group=group)
File "/opt/py39/lib/python3.9/site-packages/paddle/distributed/communication/group.py", line 326, in barrier
task = group.process_group.barrier(device_id)
paddle.base.libpaddle.EnforceNotMet: (MLU CNCL error(1), Cncl internal error.. (at /workspace/workplace/PaddleCustomDevice/backends/mlu/runtime/runti)373)


PaddlePaddle is installed successfully ONLY for single mlu! Let's start deep learning with PaddlePaddle now.
@@ -674,18 +674,22 @@ class OnnxStub:
                )
            elif node.op_type == "Dropout":
                attributes = _parse_attribute(
-                   node, {"ratio": 0.5, "training_mode: 0"})
+                   node, {"ratio": 0.5, "training_mode": 0})
                (ratio, training_mode) = (
-                   attribute[name]
+                   attributes[name]
                    for name in ["ratio", "training_mode"]
                )
-               tensors[node.output[0]] = self.handler.dropout(
+               for name, tensor in zip(
+                   node.output,
+                   self.handler.dropout(
                        tensors[node.input[0]],
                        tensors.get(node.output[0]),
                        tensors.get(node.output[1]),
                        ratio,
-                       (bool)training_mode,
-               )
+                       training_mode,
+                   ),
+               ):
+                   tensors[name] = tensor
            elif node.op_type == "Cast":
                tensors[node.output[0]] = self.handler.cast(
                    tensors[node.input[0]],
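The importer change above parses `ratio` and `training_mode` with defaults and then binds every name in `node.output` (the data output and, when present, the mask) to the tensors returned by `handler.dropout`, instead of registering only the first output. A small ONNX model that exercises this path could be built like this (sketch; the commented `OnnxStub`/`backend` entry point mirrors how the repository's Python tests typically load models and is an assumption here):

```python
import onnx
from onnx import TensorProto, helper

# One Dropout node with both outputs requested and no explicit attributes,
# so the importer falls back to ratio=0.5, training_mode=0.
node = helper.make_node("Dropout", inputs=["x"], outputs=["y", "mask"])
graph = helper.make_graph(
    [node],
    "dropout_smoke_test",
    inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, [2, 3])],
    outputs=[
        helper.make_tensor_value_info("y", TensorProto.FLOAT, [2, 3]),
        helper.make_tensor_value_info("mask", TensorProto.BOOL, [2, 3]),
    ],
)
model = helper.make_model(graph)
onnx.checker.check_model(model)

# Assumed entry point, as in the repo's Python tests:
# from pyinfinitensor.onnx import OnnxStub, backend
# stub = OnnxStub(model, backend.cpu_runtime())
```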
@@ -534,17 +534,17 @@ Tensor GraphHandlerObj::lrn(Tensor input, Tensor output, float alpha,
    }
}

-Tensor GraphHandlerObj::dropout(Tensor input, Tensor output, Tensor mask,
-                                float ratio, bool training_mode) {
+TensorVec GraphHandlerObj::dropout(Tensor input, Tensor output, Tensor mask,
+                                   float ratio, bool training_mode) {
    if (output) {
        g->addOpWithOutputs<DropoutObj>(std::move(input), output, mask, ratio,
                                        training_mode);
-       return output;
+       return {output, mask};
    } else {
        return g
            ->addOp<DropoutObj>(std::move(input), output, mask, ratio,
                                training_mode)
-           ->getOutput();
+           ->getOutputs();
    }
}

@@ -24,13 +24,18 @@ class DropoutCnnl : public BangKernelWithoutConfig {
        checkCnnlError(cnnlSetTensorDescriptor(oDesc, CNNL_LAYOUT_ARRAY,
                                               CNNL_DTYPE_FLOAT, oDim.size(),
                                               oDim.data()));
+       cnnlTensorDescriptor_t mDesc;
+       checkCnnlError(cnnlCreateTensorDescriptor(&mDesc));
+       checkCnnlError(cnnlSetTensorDescriptor(mDesc, CNNL_LAYOUT_ARRAY,
+                                              CNNL_DTYPE_UINT8, oDim.size(),
+                                              oDim.data()));

        auto ratio = op->getRatio();
        // auto train = op->getTrainingMode();

        cnnlStatus_t stat =
            cnnlFusedDropout_v2(context->cnnlHandle(), generator, oDesc, iData,
-                               ratio, NULL, oDesc, mData, oDesc, oData);
+                               ratio, NULL, mDesc, mData, oDesc, oData);

        if (stat != CNNL_STATUS_SUCCESS)
            return;
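With the extra `mDesc` above, the mask gets its own CNNL descriptor (UINT8, same shape as the output) rather than reusing the float output descriptor, so `cnnlFusedDropout_v2` writes the keep/drop flags into `mData` with the correct element size. Assuming the kernel follows the usual inverted-dropout convention, a host-side sanity check of the two outputs might look like this (NumPy sketch; `check_dropout_outputs` is an illustrative name, not part of the repo):

```python
import numpy as np

def check_dropout_outputs(x, y, mask, ratio, atol=1e-5):
    """Cross-check a fused dropout result: kept elements of y should equal
    x / (1 - ratio); dropped elements should be zero; mask holds 0/1 bytes."""
    keep = mask.astype(np.float32)            # uint8 {0, 1} -> float
    expected = x * keep / (1.0 - ratio)
    assert np.allclose(y, expected, atol=atol)
    # For large tensors the observed drop rate should be close to `ratio`.
    return 1.0 - keep.mean()
```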