Merge pull request #310 from GallenShao/batch

update taskmgr
GallenShao 2018-07-19 19:11:38 +08:00 committed by GitHub
commit 55ed430439
2 changed files with 102 additions and 31 deletions

@@ -905,6 +905,9 @@ if __name__ == '__main__':
G_cloudmgr = cloudmgr.CloudMgr()
G_taskmgr = taskmgr.TaskMgr()
G_jobmgr = jobmgr.JobMgr(G_taskmgr)
G_jobmgr.start()
G_taskmgr.set_jobmgr(G_jobmgr)
G_taskmgr.start()
# start NodeMgr and NodeMgr will wait for all nodes to start ...
G_nodemgr = nodemgr.NodeMgr(G_networkmgr, etcdclient, addr = ipaddr, mode=mode)

@@ -1,5 +1,8 @@
import threading
import time
import string
import random
import json
import master.monitor
@@ -11,8 +14,9 @@ from utils.log import logger
# grpc
from concurrent import futures
import grpc
from protos.rpc_pb2 import Task, Reply
from protos.rpc_pb2_grpc import MasterServicer, add_MasterServicer_to_server
from protos.rpc_pb2 import Task, TaskMsg, Status, Reply, Parameters, Cluster, Command, Image, Mount, Instance
from protos.rpc_pb2_grpc import MasterServicer, add_MasterServicer_to_server, WorkerStub
class TaskReporter(MasterServicer):
@@ -23,6 +27,7 @@ class TaskReporter(MasterServicer):
self.taskmgr.on_task_report(request)
return Reply(message=Reply.ACCEPTED)
class TaskMgr(threading.Thread):
# load task information from etcd
@@ -31,9 +36,9 @@ class TaskMgr(threading.Thread):
def __init__(self, nodemgr):
threading.Thread.__init__(self)
self.thread_stop = False
# tasks
self.jobmgr = None
self.task_queue = []
self.heart_beat_timeout = 60 # (s)
# nodes
self.nodemgr = nodemgr
@@ -68,22 +73,23 @@
# this method is called when worker send heart-beat rpc request
def on_task_report(self, report):
logger.info('[on_task_report] receive task report: id %d, status %d' % (report.id, report.status))
task = get_task(report.id)
logger.info('[on_task_report] receive task report: id %s-%d, status %d' % (report.taskid, report.instanceid, report.instanceStatus))
task = get_task(report.taskid)
if task == None:
logger.error('[on_task_report] task not found')
return
instance_id = report.parameters.command.envVars['INSTANCE_ID']
instance = task.instance_list[instance_id]
instance = task.instance_list[report.instanceid]
if instance['token'] != report.token:
logger.warning('[on_task_report] wrong token')
return
if report.status == Task.RUNNING:
pass
elif report.status == Task.COMPLETED:
instance['status'] = 'completed'
instance['status'] = report.instanceStatus
if report.instanceStatus == Status.RUNNING:
instance['last_update_time'] = time.time()
elif report.instanceStatus == Status.COMPLETED:
check_task_completed(task)
elif report.status == Task.FAILED or report.status == Task.TIMEOUT:
instance['status'] = 'failed'
elif report.instanceStatus == Status.FAILED or report.instanceStatus == Status.TIMEOUT:
if instance['try_count'] > task.maxRetryCount:
check_task_completed(task)
else:
@@ -95,27 +101,49 @@ class TaskMgr(threading.Thread):
return
failed = False
for instance in task.instance_list:
if instance['status'] == 'running':
if instance['status'] == Status.RUNNING or instance['status'] == Status.WAITING:
return
if instance['status'] == 'failed':
if instance['status'] == Status.FAILED or instance['status'] == Status.TIMEOUT:
if instance['try_count'] > task.maxRetryCount:
failed = True
else:
return
if self.jobmgr is None:
logger.error('[check_task_completed] jobmgr is None!')
return
if failed:
# tell jobmgr task failed
task.status = Task.FAILED
# TODO tell jobmgr task failed
task.status = Status.FAILED
else:
# tell jobmgr task completed
task.status = Task.COMPLETED
# TODO tell jobmgr task completed
task.status = Status.COMPLETED
logger.info('task %s completed' % task.id)
self.task_queue.remove(task)
def task_processor(self, task, instance_id, worker):
task.status = Task.RUNNING
task.parameters.command.envVars['INSTANCE_ID'] = instance_id
# TODO call the rpc to call a function in worker
print('processing %s' % task.id)
task.status = Status.RUNNING
# properties for transaction
task.instanceid = instance_id
task.token = ''.join(random.sample(string.ascii_letters + string.digits, 8))
instance = task.instance_list[instance_id]
instance['status'] = Status.RUNNING
instance['last_update_time'] = time.time()
instance['try_count'] += 1
instance['token'] = task.token
try:
logger.info('[task_processor] processing %s' % task.id)
channel = grpc.insecure_channel('%s:50052' % worker)
stub = WorkerStub(channel)
response = stub.process_task(task)
logger.info('[task_processor] worker response: %d' % response.message)
except Exception as e:
logger.error('[task_processor] rpc error message: %s' % e)
instance['status'] = Status.FAILED
instance['try_count'] -= 1
# return task, worker
@@ -126,14 +154,17 @@ class TaskMgr(threading.Thread):
if worker is not None:
# find instance to retry
for index, instance in enumerate(task.instance_list):
if instance['status'] == 'failed' and instance['try_count'] <= task.maxRetryCount:
instance['try_count'] += 1
if (instance['status'] == Status.FAILED or instance['status'] == Status.TIMEOUT) and instance['try_count'] <= task.maxRetryCount:
return task, index, worker
elif instance['status'] == Status.RUNNING:
if time.time() - instance['last_update_time'] > self.heart_beat_timeout:
instance['status'] = Status.FAILED
instance['token'] = ''
return task, index, worker
# start new instance
if len(task.instance_list) < task.instanceCount:
instance = {}
instance['status'] = 'running'
instance['try_count'] = 0
task.instance_list.append(instance)
return task, len(task.instance_list) - 1, worker
@@ -146,8 +177,11 @@ class TaskMgr(threading.Thread):
logger.warning('[task_scheduler] running nodes not found')
return None
# TODO
return nodes[0]
for node in nodes:
# TODO
if True:
return node[0]
return None
def get_all_nodes(self):
@@ -159,16 +193,50 @@ class TaskMgr(threading.Thread):
self.all_nodes = []
for node_ip in node_ips:
fetcher = master.monitor.Fetcher(node_ip)
self.all_nodes.append(fetcher.info)
self.all_nodes.append((node_ip, fetcher.info))
return self.all_nodes
def set_jobmgr(self, jobmgr):
self.jobmgr = jobmgr
# user: username
# task: a json string
# save the task information into database
# called when jobmgr assign task to taskmgr
def add_task(self, task):
def add_task(self, username, taskid, json_task):
# decode json string to object defined in grpc
json_task = json.loads(json_task)
task = Task(
id = taskid,
username = username,
instanceCount = json_task['instanceCount'],
maxRetryCount = json_task['maxRetryCount'],
timeout = json_task['timeout'],
parameters = Parameters(
command = Command(
commandLine = json_task['parameters']['command']['commandLine'],
packagePath = json_task['parameters']['command']['packagePath'],
envVars = json_task['parameters']['command']['envVars']),
stderrRedirectPath = json_task['parameters']['stderrRedirectPath'],
stdoutRedirectPath = json_task['parameters']['stdoutRedirectPath']),
cluster = Cluster(
image = Image(
name = json_task['cluster']['image']['name'],
type = json_task['cluster']['image']['type'],
owner = json_task['cluster']['image']['owner']),
instance = Instance(
cpu = json_task['cluster']['instance']['cpu'],
memory = json_task['cluster']['instance']['memory'],
disk = json_task['cluster']['instance']['disk'],
gpu = json_task['cluster']['instance']['gpu'])))
# cluster.mount is a repeated field and starts empty, so append to it directly
for mount in json_task['cluster']['mount']:
task.cluster.mount.append(Mount(localPath=mount['localPath'], remotePath=mount['remotePath']))
# local properties
task.status = Status.WAITING
task.instance_list = []
self.task_queue.append(task)
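
For reference, here is a rough sketch of the JSON string that the new add_task(username, taskid, json_task) signature appears to expect. The field names are taken from the json_task lookups in the parsing code above; the concrete values (command line, paths, image name, resource sizes, task id) are made up for illustration only.

import json

# Hypothetical task description; every key mirrors a json_task[...] access
# in add_task above, but the values are illustrative, not from the commit.
example_json_task = json.dumps({
    'instanceCount': 2,
    'maxRetryCount': 3,
    'timeout': 600,
    'parameters': {
        'command': {
            'commandLine': 'python run.py',
            'packagePath': '/path/to/package',
            'envVars': {'KEY': 'VALUE'}},
        'stderrRedirectPath': '/path/to/stderr.log',
        'stdoutRedirectPath': '/path/to/stdout.log'},
    'cluster': {
        'image': {'name': 'base', 'type': 'private', 'owner': 'someuser'},
        'instance': {'cpu': 1, 'memory': 1024, 'disk': 1024, 'gpu': 0},
        'mount': [{'localPath': '/local/data', 'remotePath': '/mnt/data'}]}})

# e.g. G_taskmgr.add_task('someuser', 'task_0', example_json_task)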