Merge pull request #316 from GallenShao/batch

add more tests & more bugs fixed, add token to rpc.proto
This commit is contained in:
GallenShao 2018-07-28 21:37:02 +08:00 committed by GitHub
commit 82c59a9812
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 120 additions and 55 deletions

View File

@ -15,6 +15,8 @@ import grpc
from protos.rpc_pb2 import *
from protos.rpc_pb2_grpc import MasterServicer, add_MasterServicer_to_server, WorkerStub
from utils import env
class Task():
def __init__(self, info):
@ -39,14 +41,19 @@ class TaskMgr(threading.Thread):
# load task information from etcd
# initial a task queue and task schedueler
# taskmgr: a taskmgr instance
def __init__(self, nodemgr, monitor_fetcher, logger):
def __init__(self, nodemgr, monitor_fetcher, logger, worker_timeout=60, scheduler_interval=2):
threading.Thread.__init__(self)
self.thread_stop = False
self.jobmgr = None
self.task_queue = []
self.heart_beat_timeout = 5 # (s)
self.heart_beat_timeout = worker_timeout # (s)
self.scheduler_interval = scheduler_interval
self.logger = logger
self.master_port = env.getenv('BATCH_MASTER_PORT')
self.worker_port = env.getenv('BATCH_WORKER_PORT')
# nodes
self.nodemgr = nodemgr
self.monitor_fetcher = monitor_fetcher
@ -63,13 +70,13 @@ class TaskMgr(threading.Thread):
if task is not None and worker is not None:
self.task_processor(task, instance_id, worker)
else:
time.sleep(2)
time.sleep(self.scheduler_interval)
def serve(self):
self.server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
add_MasterServicer_to_server(TaskReporter(self), self.server)
self.server.add_insecure_port('[::]:50051')
self.server.add_insecure_port('[::]:' + self.master_port)
self.server.start()
self.logger.info('[taskmgr_rpc] start rpc server')
@ -83,7 +90,7 @@ class TaskMgr(threading.Thread):
# this method is called when worker send heart-beat rpc request
def on_task_report(self, report):
self.logger.info('[on_task_report] receive task report: id %s-%d, status %d' % (report.taskid, report.instanceid, report.instanceStatus))
task = get_task(report.taskid)
task = self.get_task(report.taskid)
if task == None:
self.logger.error('[on_task_report] task not found')
return
@ -93,6 +100,9 @@ class TaskMgr(threading.Thread):
self.logger.warning('[on_task_report] wrong token')
return
if instance['status'] != RUNNING:
self.logger.error('[on_task_report] receive task report when instance is not running')
if instance['status'] == RUNNING and report.instanceStatus != RUNNING:
self.cpu_usage[instance['worker']] -= task.info.cluster.instance.cpu
@ -100,12 +110,10 @@ class TaskMgr(threading.Thread):
instance['last_update_time'] = time.time()
if report.instanceStatus == COMPLETED:
check_task_completed(task)
self.check_task_completed(task)
elif report.instanceStatus == FAILED or report.instanceStatus == TIMEOUT:
if instance['try_count'] > task.info.maxRetryCount:
check_task_completed(task)
else:
self.logger.error('[on_task_report] receive report from waiting task')
self.check_task_completed(task)
def check_task_completed(self, task):
@ -135,25 +143,25 @@ class TaskMgr(threading.Thread):
self.task_queue.remove(task)
def task_processor(self, task, instance_id, worker):
def task_processor(self, task, instance_id, worker_ip):
task.status = RUNNING
# properties for transaction
task.info.instanceid = instance_id
task.token = ''.join(random.sample(string.ascii_letters + string.digits, 8))
task.info.token = ''.join(random.sample(string.ascii_letters + string.digits, 8))
instance = task.instance_list[instance_id]
instance['status'] = RUNNING
instance['last_update_time'] = time.time()
instance['try_count'] += 1
instance['token'] = task.token
instance['worker'] = worker
instance['token'] = task.info.token
instance['worker'] = worker_ip
self.cpu_usage[worker] += task.info.cluster.instance.cpu
self.cpu_usage[worker_ip] += task.info.cluster.instance.cpu
try:
self.logger.info('[task_processor] processing task [%s] instance [%d]' % (task.info.id, task.info.instanceid))
channel = grpc.insecure_channel('%s:50052' % worker)
channel = grpc.insecure_channel('%s:%s' % (worker_ip, self.worker_port))
stub = WorkerStub(channel)
response = stub.process_task(task.info)
if response.status != Reply.ACCEPTED:
@ -167,7 +175,7 @@ class TaskMgr(threading.Thread):
# return task, worker
def task_scheduler(self):
# simple FIFO
self.logger.info('[task_scheduler] scheduling...')
self.logger.info('[task_scheduler] scheduling... (%d tasks remains)' % len(self.task_queue))
for task in self.task_queue:
worker = self.find_proper_worker(task)
@ -195,6 +203,9 @@ class TaskMgr(threading.Thread):
instance['try_count'] = 0
task.instance_list.append(instance)
return task, len(task.instance_list) - 1, worker
self.check_task_completed(task)
return None, None, None
def find_proper_worker(self, task):

View File

@ -1,8 +1,10 @@
import master.taskmgr
from concurrent import futures
import grpc
from protos import rpc_pb2, rpc_pb2_grpc
import threading, json, time
from protos.rpc_pb2 import *
from protos.rpc_pb2_grpc import *
import threading, json, time, random
from utils import env
class SimulatedNodeMgr():
@ -21,10 +23,15 @@ class SimulatedMonitorFetcher():
self.info['diskinfo'][0]['free'] = 8 * 1024 * 1024 * 1024 # (b) simulate 8 GB disk
class SimulatedTaskController(rpc_pb2_grpc.WorkerServicer):
class SimulatedTaskController(WorkerServicer):
def __init__(self, worker):
self.worker = worker
def process_task(self, task, context):
print('[SimulatedTaskController] receive task [%s]' % task.id)
return rpc_pb2.Reply(status=rpc_pb2.Reply.ACCEPTED,message="")
print('[SimulatedTaskController] receive task [%s] instanceid [%d] token [%s]' % (task.id, task.instanceid, task.token))
worker.process(task)
return Reply(status=Reply.ACCEPTED,message="")
class SimulatedWorker(threading.Thread):
@ -32,19 +39,36 @@ class SimulatedWorker(threading.Thread):
def __init__(self):
threading.Thread.__init__(self)
self.thread_stop = False
self.tasks = []
def run(self):
worker_port = env.getenv('BATCH_WORKER_PORT')
server = grpc.server(futures.ThreadPoolExecutor(max_workers=5))
rpc_pb2_grpc.add_WorkerServicer_to_server(SimulatedTaskController(), server)
server.add_insecure_port('[::]:50052')
add_WorkerServicer_to_server(SimulatedTaskController(self), server)
server.add_insecure_port('[::]:' + worker_port)
server.start()
while not self.thread_stop:
for task in self.tasks:
seed = random.random()
if seed < 0.25:
report(task.id, task.instanceid, RUNNING, task.token)
elif seed < 0.5:
report(task.id, task.instanceid, COMPLETED, task.token)
self.tasks.remove(task)
elif seed < 0.75:
report(task.id, task.instanceid, FAILED, task.token)
self.tasks.remove(task)
else:
pass
time.sleep(5)
server.stop(0)
def stop(self):
self.thread_stop = True
def process(self, task):
self.tasks.append(task)
class SimulatedJobMgr(threading.Thread):
@ -61,9 +85,9 @@ class SimulatedJobMgr(threading.Thread):
self.thread_stop = True
def report(self, task):
print('[SimulatedJobMgr] task[%s] status %d' % (task.id, task.status))
print('[SimulatedJobMgr] task[%s] status %d' % (task.info.id, task.status))
def asignTask(self, taskmgr, taskid, instance_count, retry_count, timeout, cpu, memory, disk):
def assignTask(self, taskmgr, taskid, instance_count, retry_count, timeout, cpu, memory, disk):
task = {}
task['instanceCount'] = instance_count
task['maxRetryCount'] = retry_count
@ -111,11 +135,26 @@ def test():
jobmgr = SimulatedJobMgr()
jobmgr.start()
taskmgr = master.taskmgr.TaskMgr(SimulatedNodeMgr(), SimulatedMonitorFetcher, SimulatedLogger())
taskmgr = master.taskmgr.TaskMgr(SimulatedNodeMgr(), SimulatedMonitorFetcher, SimulatedLogger(), worker_timeout=10, scheduler_interval=2)
taskmgr.set_jobmgr(jobmgr)
taskmgr.start()
jobmgr.asignTask(taskmgr, 'task_0', 2, 2, 60, 2, 2048, 2048)
add('task_0', instance_count=2, retry_count=2, timeout=60, cpu=2, memory=2048, disk=2048)
def add(taskid, instance_count, retry_count, timeout, cpu, memory, disk):
global jobmgr
global taskmgr
jobmgr.assignTask(taskmgr, taskid, instance_count, retry_count, timeout, cpu, memory, disk)
def report(taskid, instanceid, status, token):
global taskmgr
master_port = env.getenv('BATCH_MASTER_PORT')
channel = grpc.insecure_channel('%s:%s' % ('0.0.0.0', master_port))
stub = MasterStub(channel)
response = stub.report(TaskMsg(taskid=taskid, instanceid=instanceid, instanceStatus=status, token=token))
def stop():
@ -125,4 +164,4 @@ def stop():
worker.stop()
jobmgr.stop()
taskmgr.stop()
taskmgr.stop()

View File

@ -22,7 +22,7 @@ message TaskMsg {
string taskid = 1;
int32 instanceid = 2;
Status instanceStatus = 3; //
string token = 4;
}
enum Status {
@ -42,6 +42,7 @@ message TaskInfo {
Parameters parameters = 6; //
Cluster cluster = 7; //
int32 timeout = 8; //
string token = 9;
}
message Parameters {

View File

@ -20,7 +20,7 @@ DESCRIPTOR = _descriptor.FileDescriptor(
name='protos/rpc.proto',
package='',
syntax='proto3',
serialized_pb=_b('\n\x10protos/rpc.proto\"f\n\x05Reply\x12\"\n\x06status\x18\x01 \x01(\x0e\x32\x12.Reply.ReplyStatus\x12\x0f\n\x07message\x18\x02 \x01(\t\"(\n\x0bReplyStatus\x12\x0c\n\x08\x41\x43\x43\x45PTED\x10\x00\x12\x0b\n\x07REFUSED\x10\x01\"N\n\x07TaskMsg\x12\x0e\n\x06taskid\x18\x01 \x01(\t\x12\x12\n\ninstanceid\x18\x02 \x01(\x05\x12\x1f\n\x0einstanceStatus\x18\x03 \x01(\x0e\x32\x07.Status\"\xb7\x01\n\x08TaskInfo\x12\n\n\x02id\x18\x01 \x01(\t\x12\x10\n\x08username\x18\x02 \x01(\t\x12\x12\n\ninstanceid\x18\x03 \x01(\x05\x12\x15\n\rinstanceCount\x18\x04 \x01(\x05\x12\x15\n\rmaxRetryCount\x18\x05 \x01(\x05\x12\x1f\n\nparameters\x18\x06 \x01(\x0b\x32\x0b.Parameters\x12\x19\n\x07\x63luster\x18\x07 \x01(\x0b\x32\x08.Cluster\x12\x0f\n\x07timeout\x18\x08 \x01(\x05\"_\n\nParameters\x12\x19\n\x07\x63ommand\x18\x01 \x01(\x0b\x32\x08.Command\x12\x1a\n\x12stderrRedirectPath\x18\x02 \x01(\t\x12\x1a\n\x12stdoutRedirectPath\x18\x03 \x01(\t\"\x8b\x01\n\x07\x43ommand\x12\x13\n\x0b\x63ommandLine\x18\x01 \x01(\t\x12\x13\n\x0bpackagePath\x18\x02 \x01(\t\x12&\n\x07\x65nvVars\x18\x03 \x03(\x0b\x32\x15.Command.EnvVarsEntry\x1a.\n\x0c\x45nvVarsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"T\n\x07\x43luster\x12\x15\n\x05image\x18\x01 \x01(\x0b\x32\x06.Image\x12\x1b\n\x08instance\x18\x02 \x01(\x0b\x32\t.Instance\x12\x15\n\x05mount\x18\x03 \x03(\x0b\x32\x06.Mount\"t\n\x05Image\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x1e\n\x04type\x18\x02 \x01(\x0e\x32\x10.Image.ImageType\x12\r\n\x05owner\x18\x03 \x01(\t\".\n\tImageType\x12\x08\n\x04\x42\x41SE\x10\x00\x12\n\n\x06PUBLIC\x10\x01\x12\x0b\n\x07PRIVATE\x10\x02\".\n\x05Mount\x12\x11\n\tlocalPath\x18\x01 \x01(\t\x12\x12\n\nremotePath\x18\x02 \x01(\t\"B\n\x08Instance\x12\x0b\n\x03\x63pu\x18\x01 \x01(\x05\x12\x0e\n\x06memory\x18\x02 \x01(\x05\x12\x0c\n\x04\x64isk\x18\x03 \x01(\x05\x12\x0b\n\x03gpu\x18\x04 \x01(\x05*J\n\x06Status\x12\x0b\n\x07WAITING\x10\x00\x12\x0b\n\x07RUNNING\x10\x01\x12\r\n\tCOMPLETED\x10\x02\x12\n\n\x06\x46\x41ILED\x10\x03\x12\x0b\n\x07TIMEOUT\x10\x04\x32&\n\x06Master\x12\x1c\n\x06report\x12\x08.TaskMsg\x1a\x06.Reply\"\x00\x32-\n\x06Worker\x12#\n\x0cprocess_task\x12\t.TaskInfo\x1a\x06.Reply\"\x00\x62\x06proto3')
serialized_pb=_b('\n\x10protos/rpc.proto\"f\n\x05Reply\x12\"\n\x06status\x18\x01 \x01(\x0e\x32\x12.Reply.ReplyStatus\x12\x0f\n\x07message\x18\x02 \x01(\t\"(\n\x0bReplyStatus\x12\x0c\n\x08\x41\x43\x43\x45PTED\x10\x00\x12\x0b\n\x07REFUSED\x10\x01\"]\n\x07TaskMsg\x12\x0e\n\x06taskid\x18\x01 \x01(\t\x12\x12\n\ninstanceid\x18\x02 \x01(\x05\x12\x1f\n\x0einstanceStatus\x18\x03 \x01(\x0e\x32\x07.Status\x12\r\n\x05token\x18\x04 \x01(\t\"\xc6\x01\n\x08TaskInfo\x12\n\n\x02id\x18\x01 \x01(\t\x12\x10\n\x08username\x18\x02 \x01(\t\x12\x12\n\ninstanceid\x18\x03 \x01(\x05\x12\x15\n\rinstanceCount\x18\x04 \x01(\x05\x12\x15\n\rmaxRetryCount\x18\x05 \x01(\x05\x12\x1f\n\nparameters\x18\x06 \x01(\x0b\x32\x0b.Parameters\x12\x19\n\x07\x63luster\x18\x07 \x01(\x0b\x32\x08.Cluster\x12\x0f\n\x07timeout\x18\x08 \x01(\x05\x12\r\n\x05token\x18\t \x01(\t\"_\n\nParameters\x12\x19\n\x07\x63ommand\x18\x01 \x01(\x0b\x32\x08.Command\x12\x1a\n\x12stderrRedirectPath\x18\x02 \x01(\t\x12\x1a\n\x12stdoutRedirectPath\x18\x03 \x01(\t\"\x8b\x01\n\x07\x43ommand\x12\x13\n\x0b\x63ommandLine\x18\x01 \x01(\t\x12\x13\n\x0bpackagePath\x18\x02 \x01(\t\x12&\n\x07\x65nvVars\x18\x03 \x03(\x0b\x32\x15.Command.EnvVarsEntry\x1a.\n\x0c\x45nvVarsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"T\n\x07\x43luster\x12\x15\n\x05image\x18\x01 \x01(\x0b\x32\x06.Image\x12\x1b\n\x08instance\x18\x02 \x01(\x0b\x32\t.Instance\x12\x15\n\x05mount\x18\x03 \x03(\x0b\x32\x06.Mount\"t\n\x05Image\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x1e\n\x04type\x18\x02 \x01(\x0e\x32\x10.Image.ImageType\x12\r\n\x05owner\x18\x03 \x01(\t\".\n\tImageType\x12\x08\n\x04\x42\x41SE\x10\x00\x12\n\n\x06PUBLIC\x10\x01\x12\x0b\n\x07PRIVATE\x10\x02\".\n\x05Mount\x12\x11\n\tlocalPath\x18\x01 \x01(\t\x12\x12\n\nremotePath\x18\x02 \x01(\t\"B\n\x08Instance\x12\x0b\n\x03\x63pu\x18\x01 \x01(\x05\x12\x0e\n\x06memory\x18\x02 \x01(\x05\x12\x0c\n\x04\x64isk\x18\x03 \x01(\x05\x12\x0b\n\x03gpu\x18\x04 \x01(\x05*J\n\x06Status\x12\x0b\n\x07WAITING\x10\x00\x12\x0b\n\x07RUNNING\x10\x01\x12\r\n\tCOMPLETED\x10\x02\x12\n\n\x06\x46\x41ILED\x10\x03\x12\x0b\n\x07TIMEOUT\x10\x04\x32&\n\x06Master\x12\x1c\n\x06report\x12\x08.TaskMsg\x1a\x06.Reply\"\x00\x32-\n\x06Worker\x12#\n\x0cprocess_task\x12\t.TaskInfo\x1a\x06.Reply\"\x00\x62\x06proto3')
)
_STATUS = _descriptor.EnumDescriptor(
@ -52,8 +52,8 @@ _STATUS = _descriptor.EnumDescriptor(
],
containing_type=None,
options=None,
serialized_start=949,
serialized_end=1023,
serialized_start=979,
serialized_end=1053,
)
_sym_db.RegisterEnumDescriptor(_STATUS)
@ -108,8 +108,8 @@ _IMAGE_IMAGETYPE = _descriptor.EnumDescriptor(
],
containing_type=None,
options=None,
serialized_start=785,
serialized_end=831,
serialized_start=815,
serialized_end=861,
)
_sym_db.RegisterEnumDescriptor(_IMAGE_IMAGETYPE)
@ -181,6 +181,13 @@ _TASKMSG = _descriptor.Descriptor(
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='token', full_name='TaskMsg.token', index=3,
number=4, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=_b("").decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
],
extensions=[
],
@ -194,7 +201,7 @@ _TASKMSG = _descriptor.Descriptor(
oneofs=[
],
serialized_start=124,
serialized_end=202,
serialized_end=217,
)
@ -261,6 +268,13 @@ _TASKINFO = _descriptor.Descriptor(
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='token', full_name='TaskInfo.token', index=8,
number=9, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=_b("").decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
],
extensions=[
],
@ -273,8 +287,8 @@ _TASKINFO = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
serialized_start=205,
serialized_end=388,
serialized_start=220,
serialized_end=418,
)
@ -318,8 +332,8 @@ _PARAMETERS = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
serialized_start=390,
serialized_end=485,
serialized_start=420,
serialized_end=515,
)
@ -356,8 +370,8 @@ _COMMAND_ENVVARSENTRY = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
serialized_start=581,
serialized_end=627,
serialized_start=611,
serialized_end=657,
)
_COMMAND = _descriptor.Descriptor(
@ -400,8 +414,8 @@ _COMMAND = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
serialized_start=488,
serialized_end=627,
serialized_start=518,
serialized_end=657,
)
@ -445,8 +459,8 @@ _CLUSTER = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
serialized_start=629,
serialized_end=713,
serialized_start=659,
serialized_end=743,
)
@ -491,8 +505,8 @@ _IMAGE = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
serialized_start=715,
serialized_end=831,
serialized_start=745,
serialized_end=861,
)
@ -529,8 +543,8 @@ _MOUNT = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
serialized_start=833,
serialized_end=879,
serialized_start=863,
serialized_end=909,
)
@ -581,8 +595,8 @@ _INSTANCE = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
serialized_start=881,
serialized_end=947,
serialized_start=911,
serialized_end=977,
)
_REPLY.fields_by_name['status'].enum_type = _REPLY_REPLYSTATUS
@ -691,8 +705,8 @@ _MASTER = _descriptor.ServiceDescriptor(
file=DESCRIPTOR,
index=0,
options=None,
serialized_start=1025,
serialized_end=1063,
serialized_start=1055,
serialized_end=1093,
methods=[
_descriptor.MethodDescriptor(
name='report',
@ -715,8 +729,8 @@ _WORKER = _descriptor.ServiceDescriptor(
file=DESCRIPTOR,
index=1,
options=None,
serialized_start=1065,
serialized_end=1110,
serialized_start=1095,
serialized_end=1140,
methods=[
_descriptor.MethodDescriptor(
name='process_task',