diff --git a/src/master/jobmgr.py b/src/master/jobmgr.py
index 5049bba..1cbf1eb 100644
--- a/src/master/jobmgr.py
+++ b/src/master/jobmgr.py
@@ -121,6 +121,9 @@ class BatchJob(object):
     def update_task_running(self, task_idx):
         logger.debug("Update status of task(idx:%s) of BatchJob(id:%s) running." % (task_idx, self.job_id))
         old_status = self.tasks[task_idx]['status']
+        if old_status == 'stopped':
+            logger.info("Task(idx:%s) of BatchJob(id:%s) has been stopped."% (task_idx, self.job_id))
+            return
         self.tasks_cnt[old_status] -= 1
         self.tasks[task_idx]['status'] = 'running'
         self.tasks[task_idx]['db'] = Batchtask.query.get(self.tasks[task_idx]['id'])
@@ -138,6 +141,9 @@ class BatchJob(object):
             return []
         logger.debug("Task(idx:%s) of BatchJob(id:%s) has finished(running_time=%d,billing=%d). Update dependency..." % (task_idx, self.job_id, running_time, billing))
         old_status = self.tasks[task_idx]['status']
+        if old_status == 'stopped':
+            logger.info("Task(idx:%s) of BatchJob(id:%s) has been stopped."% (task_idx, self.job_id))
+            return []
         self.tasks_cnt[old_status] -= 1
         self.tasks[task_idx]['status'] = 'finished'
         self.tasks[task_idx]['db'] = Batchtask.query.get(self.tasks[task_idx]['id'])
@@ -178,6 +184,9 @@ class BatchJob(object):
     def update_task_retrying(self, task_idx, reason, tried_times):
         logger.debug("Update status of task(idx:%s) of BatchJob(id:%s) retrying. reason:%s tried_times:%d" % (task_idx, self.job_id, reason, int(tried_times)))
         old_status = self.tasks[task_idx]['status']
+        if old_status == 'stopped':
+            logger.info("Task(idx:%s) of BatchJob(id:%s) has been stopped."% (task_idx, self.job_id))
+            return
         self.tasks_cnt[old_status] -= 1
         self.tasks_cnt['retrying'] += 1
         self.tasks[task_idx]['db'] = Batchtask.query.get(self.tasks[task_idx]['id'])
@@ -194,6 +203,9 @@ class BatchJob(object):
     def update_task_failed(self, task_idx, reason, tried_times, running_time, billing):
         logger.debug("Update status of task(idx:%s) of BatchJob(id:%s) failed. reason:%s tried_times:%d" % (task_idx, self.job_id, reason, int(tried_times)))
         old_status = self.tasks[task_idx]['status']
+        if old_status == 'stopped':
+            logger.info("Task(idx:%s) of BatchJob(id:%s) has been stopped."% (task_idx, self.job_id))
+            return
         self.tasks_cnt[old_status] -= 1
         self.tasks_cnt['failed'] += 1
         self.tasks[task_idx]['status'] = 'failed'
@@ -290,8 +302,7 @@ class JobMgr():
                 raise Exception("Wrong User.")
             for task_idx in job.tasks.keys():
                 taskid = job_id + '_' + task_idx
-                task = self.taskmgr.get_task(taskid)
-                self.taskmgr.stop_remove_task(task)
+                self.taskmgr.lazy_stop_task(taskid)
             job.stop_job()
         except Exception as err:
             logger.error(traceback.format_exc())
diff --git a/src/master/taskmgr.py b/src/master/taskmgr.py
index 3ccbd7b..90b296b 100644
--- a/src/master/taskmgr.py
+++ b/src/master/taskmgr.py
@@ -146,7 +146,10 @@ class TaskMgr(threading.Thread):
         self.task_queue = []
         self.lazy_append_list = []
         self.lazy_delete_list = []
+        self.lazy_stop_list = []
         self.task_queue_lock = threading.Lock()
+        self.stop_lock = threading.Lock()
+        self.add_lock = threading.Lock()
         #self.user_containers = {}
 
         self.scheduler_interval = scheduler_interval
@@ -178,23 +181,21 @@ class TaskMgr(threading.Thread):
         self.logger.info("Free nets addresses pool %s" % str(self.free_nets))
         self.logger.info("Each Batch Net CIDR:%s"%(str(self.task_cidr)))
 
-    def queue_lock(f):
-        @wraps(f)
-        def new_f(self, *args, **kwargs):
-            self.task_queue_lock.acquire()
-            result = f(self, *args, **kwargs)
-            self.task_queue_lock.release()
-            return result
-        return new_f
-
-    def net_lock(f):
-        @wraps(f)
-        def new_f(self, *args, **kwargs):
-            self.network_lock.acquire()
-            result = f(self, *args, **kwargs)
-            self.network_lock.release()
-            return result
-        return new_f
+    def data_lock(lockname):
+        """Build a decorator that runs the method while holding the lock attribute named by lockname."""
+        def lock(f):
+            @wraps(f)
+            def new_f(self, *args, **kwargs):
+                lockobj = getattr(self, lockname)
+                lockobj.acquire()
+                try:
+                    result = f(self, *args, **kwargs)
+                finally:
+                    # released on every exit path, including BaseException
+                    lockobj.release()
+                return result
+            return new_f
+        return lock
 
     def run(self):
         self.serve()
@@ -218,14 +219,25 @@ class TaskMgr(threading.Thread):
         self.server.stop(0)
         self.logger.info('[taskmgr_rpc] stop rpc server')
 
-    @queue_lock
+    @data_lock('task_queue_lock')
+    @data_lock('add_lock')
+    @data_lock('stop_lock')
     def sort_out_task_queue(self):
+
+        for task in self.task_queue:
+            if task.id in self.lazy_stop_list:
+                self.stop_remove_task(task)
+                self.lazy_delete_list.append(task)
+
         while self.lazy_delete_list:
             task = self.lazy_delete_list.pop(0)
             try:
                 self.task_queue.remove(task)
             except Exception as err:
                 self.logger.warning(str(err))
+
+        self.lazy_append_list = [t for t in self.lazy_append_list if t.id not in self.lazy_stop_list]
+        self.lazy_stop_list.clear()
         if self.lazy_append_list:
             self.task_queue.extend(self.lazy_append_list)
             self.lazy_append_list.clear()
@@ -299,14 +311,14 @@ class TaskMgr(threading.Thread):
             subtask.task_started = False
         return [True, '']
 
-    @net_lock
+    @data_lock('network_lock')
     def acquire_task_ips(self, task):
         self.logger.info("[acquire_task_ips] user(%s) task(%s) net(%s)" % (task.username, task.id, str(task.task_base_ip)))
         if task.task_base_ip == None:
             task.task_base_ip = self.free_nets.pop(0)
         return task.task_base_ip
 
-    @net_lock
+    @data_lock('network_lock')
     def release_task_ips(self, task):
         self.logger.info("[release_task_ips] user(%s) task(%s) net(%s)" % (task.username, task.id, str(task.task_base_ip)))
         if task.task_base_ip == None:
@@ -411,6 +423,10 @@ class TaskMgr(threading.Thread):
             self.stop_vnode(sub_task)
         #pass
 
+    @data_lock('stop_lock')
+    def lazy_stop_task(self, taskid):
+        self.lazy_stop_list.append(taskid)
+
     def stop_remove_task(self, task):
         if task is None:
             return
@@ -418,7 +434,6 @@ class TaskMgr(threading.Thread):
         self.clear_sub_tasks(task.subtask_list)
         self.release_task_ips(task)
         self.remove_tasknet(task)
-        self.lazy_delete_list.append(task)
 
     def check_task_completed(self, task):
         if task.status == RUNNING or task.status == WAITING:
@@ -427,6 +442,7 @@ class TaskMgr(threading.Thread):
                     return False
         self.logger.info('task %s finished, status %d, subtasks: %s' % (task.id, task.status, str([sub_task.status for sub_task in task.subtask_list])))
         self.stop_remove_task(task)
+        self.lazy_delete_list.append(task)
         running_time, billing = task.get_billing()
         self.logger.info('task %s running_time:%s billing:%d'%(task.id, str(running_time), billing))
         running_time = math.ceil(running_time)
@@ -474,7 +490,7 @@ class TaskMgr(threading.Thread):
         self.logger.info('[task_scheduler] scheduling... (%d tasks remains)' % len(self.task_queue))
 
         for task in self.task_queue:
-            if task in self.lazy_delete_list:
+            if task in self.lazy_delete_list or task.id in self.lazy_stop_list:
                 continue
             self.logger.info('task %s sub_tasks %s' % (task.id, str([sub_task.status for sub_task in task.subtask_list])))
             if self.check_task_completed(task):
@@ -606,6 +622,7 @@ class TaskMgr(threading.Thread):
 
     # save the task information into database
     # called when jobmgr assign task to taskmgr
+    @data_lock('add_lock')
     def add_task(self, username, taskid, json_task, task_priority=1):
         # decode json string to object defined in grpc
         self.logger.info('[taskmgr add_task] receive task %s' % taskid)
@@ -683,12 +700,12 @@ class TaskMgr(threading.Thread):
         return True
 
 
-    @queue_lock
+    @data_lock('task_queue_lock')
     def get_task_list(self):
         return self.task_queue.copy()
 
 
-    @queue_lock
+    @data_lock('task_queue_lock')
     def get_task(self, taskid):
         for task in self.task_queue:
             if task.id == taskid: