manage lock of taskmgr to support stop task action

This commit is contained in:
Firmlyzhu 2019-03-29 18:05:44 +08:00
parent a0db49dee9
commit ca50e328a0
2 changed files with 54 additions and 26 deletions

View File

@ -121,6 +121,9 @@ class BatchJob(object):
def update_task_running(self, task_idx): def update_task_running(self, task_idx):
logger.debug("Update status of task(idx:%s) of BatchJob(id:%s) running." % (task_idx, self.job_id)) logger.debug("Update status of task(idx:%s) of BatchJob(id:%s) running." % (task_idx, self.job_id))
old_status = self.tasks[task_idx]['status'] old_status = self.tasks[task_idx]['status']
if old_status == 'stopped':
logger.info("Task(idx:%s) of BatchJob(id:%s) has been stopped."% (task_idx, self.job_id))
return
self.tasks_cnt[old_status] -= 1 self.tasks_cnt[old_status] -= 1
self.tasks[task_idx]['status'] = 'running' self.tasks[task_idx]['status'] = 'running'
self.tasks[task_idx]['db'] = Batchtask.query.get(self.tasks[task_idx]['id']) self.tasks[task_idx]['db'] = Batchtask.query.get(self.tasks[task_idx]['id'])
@ -138,6 +141,9 @@ class BatchJob(object):
return [] return []
logger.debug("Task(idx:%s) of BatchJob(id:%s) has finished(running_time=%d,billing=%d). Update dependency..." % (task_idx, self.job_id, running_time, billing)) logger.debug("Task(idx:%s) of BatchJob(id:%s) has finished(running_time=%d,billing=%d). Update dependency..." % (task_idx, self.job_id, running_time, billing))
old_status = self.tasks[task_idx]['status'] old_status = self.tasks[task_idx]['status']
if old_status == 'stopped':
logger.info("Task(idx:%s) of BatchJob(id:%s) has been stopped."% (task_idx, self.job_id))
return
self.tasks_cnt[old_status] -= 1 self.tasks_cnt[old_status] -= 1
self.tasks[task_idx]['status'] = 'finished' self.tasks[task_idx]['status'] = 'finished'
self.tasks[task_idx]['db'] = Batchtask.query.get(self.tasks[task_idx]['id']) self.tasks[task_idx]['db'] = Batchtask.query.get(self.tasks[task_idx]['id'])
@ -178,6 +184,9 @@ class BatchJob(object):
def update_task_retrying(self, task_idx, reason, tried_times): def update_task_retrying(self, task_idx, reason, tried_times):
logger.debug("Update status of task(idx:%s) of BatchJob(id:%s) retrying. reason:%s tried_times:%d" % (task_idx, self.job_id, reason, int(tried_times))) logger.debug("Update status of task(idx:%s) of BatchJob(id:%s) retrying. reason:%s tried_times:%d" % (task_idx, self.job_id, reason, int(tried_times)))
old_status = self.tasks[task_idx]['status'] old_status = self.tasks[task_idx]['status']
if old_status == 'stopped':
logger.info("Task(idx:%s) of BatchJob(id:%s) has been stopped."% (task_idx, self.job_id))
return
self.tasks_cnt[old_status] -= 1 self.tasks_cnt[old_status] -= 1
self.tasks_cnt['retrying'] += 1 self.tasks_cnt['retrying'] += 1
self.tasks[task_idx]['db'] = Batchtask.query.get(self.tasks[task_idx]['id']) self.tasks[task_idx]['db'] = Batchtask.query.get(self.tasks[task_idx]['id'])
@ -194,6 +203,9 @@ class BatchJob(object):
def update_task_failed(self, task_idx, reason, tried_times, running_time, billing): def update_task_failed(self, task_idx, reason, tried_times, running_time, billing):
logger.debug("Update status of task(idx:%s) of BatchJob(id:%s) failed. reason:%s tried_times:%d" % (task_idx, self.job_id, reason, int(tried_times))) logger.debug("Update status of task(idx:%s) of BatchJob(id:%s) failed. reason:%s tried_times:%d" % (task_idx, self.job_id, reason, int(tried_times)))
old_status = self.tasks[task_idx]['status'] old_status = self.tasks[task_idx]['status']
if old_status == 'stopped':
logger.info("Task(idx:%s) of BatchJob(id:%s) has been stopped."% (task_idx, self.job_id))
return
self.tasks_cnt[old_status] -= 1 self.tasks_cnt[old_status] -= 1
self.tasks_cnt['failed'] += 1 self.tasks_cnt['failed'] += 1
self.tasks[task_idx]['status'] = 'failed' self.tasks[task_idx]['status'] = 'failed'
@ -290,8 +302,7 @@ class JobMgr():
raise Exception("Wrong User.") raise Exception("Wrong User.")
for task_idx in job.tasks.keys(): for task_idx in job.tasks.keys():
taskid = job_id + '_' + task_idx taskid = job_id + '_' + task_idx
task = self.taskmgr.get_task(taskid) self.taskmgr.lazy_stop_task(taskid)
self.taskmgr.stop_remove_task(task)
job.stop_job() job.stop_job()
except Exception as err: except Exception as err:
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())

View File

@ -146,7 +146,10 @@ class TaskMgr(threading.Thread):
self.task_queue = [] self.task_queue = []
self.lazy_append_list = [] self.lazy_append_list = []
self.lazy_delete_list = [] self.lazy_delete_list = []
self.lazy_stop_list = []
self.task_queue_lock = threading.Lock() self.task_queue_lock = threading.Lock()
self.stop_lock = threading.Lock()
self.add_lock = threading.Lock()
#self.user_containers = {} #self.user_containers = {}
self.scheduler_interval = scheduler_interval self.scheduler_interval = scheduler_interval
@ -178,23 +181,21 @@ class TaskMgr(threading.Thread):
self.logger.info("Free nets addresses pool %s" % str(self.free_nets)) self.logger.info("Free nets addresses pool %s" % str(self.free_nets))
self.logger.info("Each Batch Net CIDR:%s"%(str(self.task_cidr))) self.logger.info("Each Batch Net CIDR:%s"%(str(self.task_cidr)))
def queue_lock(f): def data_lock(lockname):
@wraps(f) def lock(f):
def new_f(self, *args, **kwargs): @wraps(f)
self.task_queue_lock.acquire() def new_f(self, *args, **kwargs):
result = f(self, *args, **kwargs) lockobj = getattr(self,lockname)
self.task_queue_lock.release() lockobj.acquire()
return result try:
return new_f result = f(self, *args, **kwargs)
except Exception as err:
def net_lock(f): lockobj.release()
@wraps(f) raise err
def new_f(self, *args, **kwargs): lockobj.release()
self.network_lock.acquire() return result
result = f(self, *args, **kwargs) return new_f
self.network_lock.release() return lock
return result
return new_f
def run(self): def run(self):
self.serve() self.serve()
@ -218,14 +219,25 @@ class TaskMgr(threading.Thread):
self.server.stop(0) self.server.stop(0)
self.logger.info('[taskmgr_rpc] stop rpc server') self.logger.info('[taskmgr_rpc] stop rpc server')
@queue_lock @data_lock('task_queue_lock')
@data_lock('add_lock')
@data_lock('stop_lock')
def sort_out_task_queue(self): def sort_out_task_queue(self):
for task in self.task_queue:
if task.id in self.lazy_stop_list:
self.stop_remove_task(task)
self.lazy_delete_list.append(task)
while self.lazy_delete_list: while self.lazy_delete_list:
task = self.lazy_delete_list.pop(0) task = self.lazy_delete_list.pop(0)
try: try:
self.task_queue.remove(task) self.task_queue.remove(task)
except Exception as err: except Exception as err:
self.logger.warning(str(err)) self.logger.warning(str(err))
self.lazy_append_list = [t for t in self.lazy_append_list if t.id not in self.lazy_stop_list]
self.lazy_stop_list.clear()
if self.lazy_append_list: if self.lazy_append_list:
self.task_queue.extend(self.lazy_append_list) self.task_queue.extend(self.lazy_append_list)
self.lazy_append_list.clear() self.lazy_append_list.clear()
@ -299,14 +311,14 @@ class TaskMgr(threading.Thread):
subtask.task_started = False subtask.task_started = False
return [True, ''] return [True, '']
@net_lock @data_lock('network_lock')
def acquire_task_ips(self, task): def acquire_task_ips(self, task):
self.logger.info("[acquire_task_ips] user(%s) task(%s) net(%s)" % (task.username, task.id, str(task.task_base_ip))) self.logger.info("[acquire_task_ips] user(%s) task(%s) net(%s)" % (task.username, task.id, str(task.task_base_ip)))
if task.task_base_ip == None: if task.task_base_ip == None:
task.task_base_ip = self.free_nets.pop(0) task.task_base_ip = self.free_nets.pop(0)
return task.task_base_ip return task.task_base_ip
@net_lock @data_lock('network_lock')
def release_task_ips(self, task): def release_task_ips(self, task):
self.logger.info("[release_task_ips] user(%s) task(%s) net(%s)" % (task.username, task.id, str(task.task_base_ip))) self.logger.info("[release_task_ips] user(%s) task(%s) net(%s)" % (task.username, task.id, str(task.task_base_ip)))
if task.task_base_ip == None: if task.task_base_ip == None:
@ -411,6 +423,10 @@ class TaskMgr(threading.Thread):
self.stop_vnode(sub_task) self.stop_vnode(sub_task)
#pass #pass
@data_lock('task_stop_lock')
def lazy_stop_task(self, taskid):
self.lazy_stop_list.append(taskid)
def stop_remove_task(self, task): def stop_remove_task(self, task):
if task is None: if task is None:
return return
@ -418,7 +434,6 @@ class TaskMgr(threading.Thread):
self.clear_sub_tasks(task.subtask_list) self.clear_sub_tasks(task.subtask_list)
self.release_task_ips(task) self.release_task_ips(task)
self.remove_tasknet(task) self.remove_tasknet(task)
self.lazy_delete_list.append(task)
def check_task_completed(self, task): def check_task_completed(self, task):
if task.status == RUNNING or task.status == WAITING: if task.status == RUNNING or task.status == WAITING:
@ -427,6 +442,7 @@ class TaskMgr(threading.Thread):
return False return False
self.logger.info('task %s finished, status %d, subtasks: %s' % (task.id, task.status, str([sub_task.status for sub_task in task.subtask_list]))) self.logger.info('task %s finished, status %d, subtasks: %s' % (task.id, task.status, str([sub_task.status for sub_task in task.subtask_list])))
self.stop_remove_task(task) self.stop_remove_task(task)
self.lazy_delete_list.append(task)
running_time, billing = task.get_billing() running_time, billing = task.get_billing()
self.logger.info('task %s running_time:%s billing:%d'%(task.id, str(running_time), billing)) self.logger.info('task %s running_time:%s billing:%d'%(task.id, str(running_time), billing))
running_time = math.ceil(running_time) running_time = math.ceil(running_time)
@ -474,7 +490,7 @@ class TaskMgr(threading.Thread):
self.logger.info('[task_scheduler] scheduling... (%d tasks remains)' % len(self.task_queue)) self.logger.info('[task_scheduler] scheduling... (%d tasks remains)' % len(self.task_queue))
for task in self.task_queue: for task in self.task_queue:
if task in self.lazy_delete_list: if task in self.lazy_delete_list or task.id in self.lazy_stop_list:
continue continue
self.logger.info('task %s sub_tasks %s' % (task.id, str([sub_task.status for sub_task in task.subtask_list]))) self.logger.info('task %s sub_tasks %s' % (task.id, str([sub_task.status for sub_task in task.subtask_list])))
if self.check_task_completed(task): if self.check_task_completed(task):
@ -606,6 +622,7 @@ class TaskMgr(threading.Thread):
# save the task information into database # save the task information into database
# called when jobmgr assign task to taskmgr # called when jobmgr assign task to taskmgr
@data_lock('add_lock')
def add_task(self, username, taskid, json_task, task_priority=1): def add_task(self, username, taskid, json_task, task_priority=1):
# decode json string to object defined in grpc # decode json string to object defined in grpc
self.logger.info('[taskmgr add_task] receive task %s' % taskid) self.logger.info('[taskmgr add_task] receive task %s' % taskid)
@ -683,12 +700,12 @@ class TaskMgr(threading.Thread):
return True return True
@queue_lock @data_lock('task_queue_lock')
def get_task_list(self): def get_task_list(self):
return self.task_queue.copy() return self.task_queue.copy()
@queue_lock @data_lock('task_queue_lock')
def get_task(self, taskid): def get_task(self, taskid):
for task in self.task_queue: for task in self.task_queue:
if task.id == taskid: if task.id == taskid: