PulseFocusPlatform/static/ppdet/data/tests/test_dataset.py

153 lines
3.8 KiB
Python

# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import time
import unittest
import sys
import logging
import random
import copy
# add python path of PadleDetection to sys.path
parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 4)))
if parent_path not in sys.path:
sys.path.append(parent_path)
from ppdet.data.parallel_map import ParallelMap
from ppdet.utils.check import enable_static_mode
class MemorySource(object):
""" memory data source for testing
"""
def __init__(self, samples):
self._epoch = -1
self._pos = -1
self._drained = False
self._samples = samples
def __iter__(self):
return self
def __next__(self):
return self.next()
def next(self):
if self._epoch < 0:
self.reset()
if self._pos >= self.size():
self._drained = True
raise StopIteration("no more data in " + str(self))
else:
sample = copy.deepcopy(self._samples[self._pos])
self._pos += 1
return sample
def reset(self):
if self._epoch < 0:
self._epoch = 0
else:
self._epoch += 1
self._pos = 0
self._drained = False
random.shuffle(self._samples)
def size(self):
return len(self._samples)
def drained(self):
assert self._epoch >= 0, "the first epoch has not started yet"
return self._pos >= self.size()
def epoch_id(self):
return self._epoch
class TestDataset(unittest.TestCase):
"""Test cases for ppdet.data.dataset
"""
@classmethod
def setUpClass(cls):
""" setup
"""
pass
@classmethod
def tearDownClass(cls):
""" tearDownClass """
pass
def test_next(self):
""" test next
"""
samples = list(range(10))
mem_sc = MemorySource(samples)
for i, d in enumerate(mem_sc):
self.assertTrue(d in samples)
def test_transform_with_abnormal_worker(self):
""" test dataset transform with abnormally exit process
"""
samples = list(range(20))
mem_sc = MemorySource(samples)
def _worker(sample):
if sample == 3:
sys.exit(1)
return 2 * sample
test_worker = ParallelMap(
mem_sc, _worker, worker_num=2, use_process=True, memsize='2M')
ct = 0
for i, d in enumerate(test_worker):
ct += 1
self.assertTrue(d / 2 in samples)
self.assertEqual(len(samples) - 1, ct)
def test_transform_with_delay_worker(self):
""" test dataset transform with delayed process
"""
samples = list(range(20))
mem_sc = MemorySource(samples)
def _worker(sample):
if sample == 3:
time.sleep(30)
return 2 * sample
test_worker = ParallelMap(
mem_sc, _worker, worker_num=2, use_process=True, memsize='2M')
ct = 0
for i, d in enumerate(test_worker):
ct += 1
self.assertTrue(d / 2 in samples)
self.assertEqual(len(samples), ct)
if __name__ == '__main__':
enable_static_mode()
logging.basicConfig()
unittest.main()