openmpi/ompi/mca/coll/ucc/coll_ucc_module.c

529 lines
19 KiB
C

/**
* Copyright (c) 2021 Mellanox Technologies. All rights reserved.
* Copyright (c) 2022 Amazon.com, Inc. or its affiliates.
* All Rights reserved.
* $COPYRIGHT$
*
* Additional copyrights may follow
*
* $HEADER$
*/
#include "ompi_config.h"
#include "coll_ucc.h"
#include "coll_ucc_dtypes.h"
#include "ompi/mca/coll/base/coll_tags.h"
#include "ompi/mca/pml/pml.h"
#define OBJ_RELEASE_IF_NOT_NULL( obj ) if( NULL != (obj) ) OBJ_RELEASE( obj );
static int ucc_comm_attr_keyval;
/*
* Initial query function that is invoked during MPI_INIT, allowing
* this module to indicate what level of thread support it provides.
*/
int mca_coll_ucc_init_query(bool enable_progress_threads, bool enable_mpi_threads)
{
return OMPI_SUCCESS;
}
static void mca_coll_ucc_module_clear(mca_coll_ucc_module_t *ucc_module)
{
ucc_module->ucc_team = NULL;
ucc_module->previous_allreduce = NULL;
ucc_module->previous_iallreduce = NULL;
ucc_module->previous_barrier = NULL;
ucc_module->previous_ibarrier = NULL;
ucc_module->previous_bcast = NULL;
ucc_module->previous_ibcast = NULL;
ucc_module->previous_alltoall = NULL;
ucc_module->previous_ialltoall = NULL;
ucc_module->previous_alltoallv = NULL;
ucc_module->previous_ialltoallv = NULL;
ucc_module->previous_allgather = NULL;
ucc_module->previous_iallgather = NULL;
ucc_module->previous_allgatherv = NULL;
ucc_module->previous_iallgatherv = NULL;
ucc_module->previous_reduce = NULL;
ucc_module->previous_ireduce = NULL;
}
static void mca_coll_ucc_module_construct(mca_coll_ucc_module_t *ucc_module)
{
mca_coll_ucc_module_clear(ucc_module);
}
int mca_coll_ucc_progress(void)
{
ucc_context_progress(mca_coll_ucc_component.ucc_context);
return OPAL_SUCCESS;
}
static void mca_coll_ucc_module_destruct(mca_coll_ucc_module_t *ucc_module)
{
if (ucc_module->comm == &ompi_mpi_comm_world.comm){
if (OMPI_SUCCESS != ompi_attr_free_keyval(COMM_ATTR, &ucc_comm_attr_keyval, 0)) {
UCC_ERROR("ucc ompi_attr_free_keyval failed");
}
}
OBJ_RELEASE_IF_NOT_NULL(ucc_module->previous_allreduce_module);
OBJ_RELEASE_IF_NOT_NULL(ucc_module->previous_iallreduce_module);
OBJ_RELEASE_IF_NOT_NULL(ucc_module->previous_barrier_module);
OBJ_RELEASE_IF_NOT_NULL(ucc_module->previous_ibarrier_module);
OBJ_RELEASE_IF_NOT_NULL(ucc_module->previous_bcast_module);
OBJ_RELEASE_IF_NOT_NULL(ucc_module->previous_ibcast_module);
OBJ_RELEASE_IF_NOT_NULL(ucc_module->previous_alltoall_module);
OBJ_RELEASE_IF_NOT_NULL(ucc_module->previous_ialltoall_module);
OBJ_RELEASE_IF_NOT_NULL(ucc_module->previous_alltoallv_module);
OBJ_RELEASE_IF_NOT_NULL(ucc_module->previous_ialltoallv_module);
OBJ_RELEASE_IF_NOT_NULL(ucc_module->previous_allgather_module);
OBJ_RELEASE_IF_NOT_NULL(ucc_module->previous_iallgather_module);
OBJ_RELEASE_IF_NOT_NULL(ucc_module->previous_allgatherv_module);
OBJ_RELEASE_IF_NOT_NULL(ucc_module->previous_iallgatherv_module);
OBJ_RELEASE_IF_NOT_NULL(ucc_module->previous_reduce_module);
OBJ_RELEASE_IF_NOT_NULL(ucc_module->previous_ireduce_module);
mca_coll_ucc_module_clear(ucc_module);
}
#define SAVE_PREV_COLL_API(__api) do { \
ucc_module->previous_ ## __api = comm->c_coll->coll_ ## __api; \
ucc_module->previous_ ## __api ## _module = comm->c_coll->coll_ ## __api ## _module; \
if (!comm->c_coll->coll_ ## __api || !comm->c_coll->coll_ ## __api ## _module) { \
return OMPI_ERROR; \
} \
OBJ_RETAIN(ucc_module->previous_ ## __api ## _module); \
} while(0)
static int mca_coll_ucc_save_coll_handlers(mca_coll_ucc_module_t *ucc_module)
{
ompi_communicator_t *comm = ucc_module->comm;
SAVE_PREV_COLL_API(allreduce);
SAVE_PREV_COLL_API(iallreduce);
SAVE_PREV_COLL_API(barrier);
SAVE_PREV_COLL_API(ibarrier);
SAVE_PREV_COLL_API(bcast);
SAVE_PREV_COLL_API(ibcast);
SAVE_PREV_COLL_API(alltoall);
SAVE_PREV_COLL_API(ialltoall);
SAVE_PREV_COLL_API(alltoallv);
SAVE_PREV_COLL_API(ialltoallv);
SAVE_PREV_COLL_API(allgather);
SAVE_PREV_COLL_API(iallgather);
SAVE_PREV_COLL_API(allgatherv);
SAVE_PREV_COLL_API(iallgatherv);
SAVE_PREV_COLL_API(reduce);
SAVE_PREV_COLL_API(ireduce);
return OMPI_SUCCESS;
}
/*
** Communicator free callback
*/
static int ucc_comm_attr_del_fn(MPI_Comm comm, int keyval, void *attr_val, void *extra)
{
mca_coll_ucc_module_t *ucc_module = (mca_coll_ucc_module_t*) attr_val;
ucc_status_t status;
while(UCC_INPROGRESS == (status = ucc_team_destroy(ucc_module->ucc_team))) {}
if (ucc_module->comm == &ompi_mpi_comm_world.comm) {
if (mca_coll_ucc_component.libucc_initialized) {
UCC_VERBOSE(1,"finalizing ucc library");
opal_progress_unregister(mca_coll_ucc_progress);
ucc_context_destroy(mca_coll_ucc_component.ucc_context);
ucc_finalize(mca_coll_ucc_component.ucc_lib);
}
}
if (UCC_OK != status) {
UCC_ERROR("UCC team destroy failed");
return OMPI_ERROR;
}
return OMPI_SUCCESS;
}
typedef struct oob_allgather_req{
void *sbuf;
void *rbuf;
void *oob_coll_ctx;
size_t msglen;
int iter;
ompi_request_t *reqs[2];
} oob_allgather_req_t;
static ucc_status_t oob_allgather_test(void *req)
{
oob_allgather_req_t *oob_req = (oob_allgather_req_t*)req;
ompi_communicator_t *comm = (ompi_communicator_t *)oob_req->oob_coll_ctx;
char *tmpsend = NULL;
char *tmprecv = NULL;
size_t msglen = oob_req->msglen;
int probe_count = 5;
int rank, size, sendto, recvfrom, recvdatafrom,
senddatafrom, completed, probe;
size = ompi_comm_size(comm);
rank = ompi_comm_rank(comm);
if (oob_req->iter == 0) {
tmprecv = (char*) oob_req->rbuf + (ptrdiff_t)rank * (ptrdiff_t)msglen;
memcpy(tmprecv, oob_req->sbuf, msglen);
}
sendto = (rank + 1) % size;
recvfrom = (rank - 1 + size) % size;
for (; oob_req->iter < size - 1; oob_req->iter++) {
if (oob_req->iter > 0) {
probe = 0;
do {
ompi_request_test_all(2, oob_req->reqs, &completed, MPI_STATUS_IGNORE);
probe++;
} while (!completed && probe < probe_count);
if (!completed) {
return UCC_INPROGRESS;
}
}
recvdatafrom = (rank - oob_req->iter - 1 + size) % size;
senddatafrom = (rank - oob_req->iter + size) % size;
tmprecv = (char*)oob_req->rbuf + (ptrdiff_t)recvdatafrom * (ptrdiff_t)msglen;
tmpsend = (char*)oob_req->rbuf + (ptrdiff_t)senddatafrom * (ptrdiff_t)msglen;
MCA_PML_CALL(isend(tmpsend, msglen, MPI_BYTE, sendto, MCA_COLL_BASE_TAG_UCC,
MCA_PML_BASE_SEND_STANDARD, comm, &oob_req->reqs[0]));
MCA_PML_CALL(irecv(tmprecv, msglen, MPI_BYTE, recvfrom,
MCA_COLL_BASE_TAG_UCC, comm, &oob_req->reqs[1]));
}
probe = 0;
do {
ompi_request_test_all(2, oob_req->reqs, &completed, MPI_STATUS_IGNORE);
probe++;
} while (!completed && probe < probe_count);
if (!completed) {
return UCC_INPROGRESS;
}
return UCC_OK;
}
static ucc_status_t oob_allgather_free(void *req)
{
free(req);
return UCC_OK;
}
static ucc_status_t oob_allgather(void *sbuf, void *rbuf, size_t msglen,
void *oob_coll_ctx, void **req)
{
oob_allgather_req_t *oob_req = malloc(sizeof(*oob_req));
oob_req->sbuf = sbuf;
oob_req->rbuf = rbuf;
oob_req->msglen = msglen;
oob_req->oob_coll_ctx = oob_coll_ctx;
oob_req->iter = 0;
*req = oob_req;
return UCC_OK;
}
static int mca_coll_ucc_init_ctx() {
mca_coll_ucc_component_t *cm = &mca_coll_ucc_component;
char str_buf[256];
ompi_attribute_fn_ptr_union_t del_fn;
ompi_attribute_fn_ptr_union_t copy_fn;
ucc_lib_config_h lib_config;
ucc_context_config_h ctx_config;
ucc_thread_mode_t tm_requested;
ucc_lib_params_t lib_params;
ucc_context_params_t ctx_params;
tm_requested = ompi_mpi_thread_multiple ? UCC_THREAD_MULTIPLE :
UCC_THREAD_SINGLE;
lib_params.mask = UCC_LIB_PARAM_FIELD_THREAD_MODE;
lib_params.thread_mode = tm_requested;
if (UCC_OK != ucc_lib_config_read("OMPI", NULL, &lib_config)) {
UCC_ERROR("UCC lib config read failed");
return OMPI_ERROR;
}
if (strlen(cm->cls) > 0) {
if (UCC_OK != ucc_lib_config_modify(lib_config, "CLS", cm->cls)) {
ucc_lib_config_release(lib_config);
UCC_ERROR("failed to modify UCC lib config to set CLS");
return OMPI_ERROR;
}
}
if (UCC_OK != ucc_init(&lib_params, lib_config, &cm->ucc_lib)) {
UCC_ERROR("UCC lib init failed");
ucc_lib_config_release(lib_config);
cm->ucc_enable = 0;
return OMPI_ERROR;
}
ucc_lib_config_release(lib_config);
cm->ucc_lib_attr.mask = UCC_LIB_ATTR_FIELD_THREAD_MODE |
UCC_LIB_ATTR_FIELD_COLL_TYPES;
if (UCC_OK != ucc_lib_get_attr(cm->ucc_lib, &cm->ucc_lib_attr)) {
UCC_ERROR("UCC get lib attr failed");
goto cleanup_lib;
}
if (cm->ucc_lib_attr.thread_mode < tm_requested) {
UCC_ERROR("UCC library doesn't support MPI_THREAD_MULTIPLE");
goto cleanup_lib;
}
ctx_params.mask = UCC_CONTEXT_PARAM_FIELD_OOB;
ctx_params.oob.allgather = oob_allgather;
ctx_params.oob.req_test = oob_allgather_test;
ctx_params.oob.req_free = oob_allgather_free;
ctx_params.oob.coll_info = (void*)MPI_COMM_WORLD;
ctx_params.oob.n_oob_eps = ompi_comm_size(&ompi_mpi_comm_world.comm);
ctx_params.oob.oob_ep = ompi_comm_rank(&ompi_mpi_comm_world.comm);
if (UCC_OK != ucc_context_config_read(cm->ucc_lib, NULL, &ctx_config)) {
UCC_ERROR("UCC context config read failed");
goto cleanup_lib;
}
sprintf(str_buf, "%u", ompi_proc_world_size());
if (UCC_OK != ucc_context_config_modify(ctx_config, NULL, "ESTIMATED_NUM_EPS",
str_buf)) {
UCC_ERROR("UCC context config modify failed for estimated_num_eps");
goto cleanup_lib;
}
sprintf(str_buf, "%u", opal_process_info.num_local_peers + 1);
if (UCC_OK != ucc_context_config_modify(ctx_config, NULL, "ESTIMATED_NUM_PPN",
str_buf)) {
UCC_ERROR("UCC context config modify failed for estimated_num_eps");
goto cleanup_lib;
}
if (UCC_OK != ucc_context_create(cm->ucc_lib, &ctx_params,
ctx_config, &cm->ucc_context)) {
UCC_ERROR("UCC context create failed");
ucc_context_config_release(ctx_config);
goto cleanup_lib;
}
ucc_context_config_release(ctx_config);
copy_fn.attr_communicator_copy_fn = (MPI_Comm_internal_copy_attr_function*)
MPI_COMM_NULL_COPY_FN;
del_fn.attr_communicator_delete_fn = ucc_comm_attr_del_fn;
if (OMPI_SUCCESS != ompi_attr_create_keyval(COMM_ATTR, copy_fn, del_fn,
&ucc_comm_attr_keyval, NULL ,0, NULL)) {
UCC_ERROR("UCC comm keyval create failed");
goto cleanup_ctx;
}
OBJ_CONSTRUCT(&cm->requests, opal_free_list_t);
opal_free_list_init(&cm->requests, sizeof(mca_coll_ucc_req_t),
opal_cache_line_size, OBJ_CLASS(mca_coll_ucc_req_t),
0, 0, /* no payload data */
8, -1, 8, /* num_to_alloc, max, per alloc */
NULL, 0, NULL, NULL, NULL /* no Mpool or init function */);
opal_progress_register(mca_coll_ucc_progress);
UCC_VERBOSE(1, "initialized ucc context");
cm->libucc_initialized = true;
return OMPI_SUCCESS;
cleanup_ctx:
ucc_context_destroy(cm->ucc_context);
cleanup_lib:
ucc_finalize(cm->ucc_lib);
cm->ucc_enable = 0;
cm->libucc_initialized = false;
return OMPI_ERROR;
}
uint64_t rank_map_cb(uint64_t ep, void *cb_ctx)
{
struct ompi_communicator_t *comm = cb_ctx;
return ((ompi_process_name_t*)&ompi_comm_peer_lookup(comm, ep)->super.
proc_name)->vpid;
}
static inline ucc_ep_map_t get_rank_map(struct ompi_communicator_t *comm)
{
ucc_ep_map_t map;
int64_t r1, r2, stride, i;
int is_strided;
map.ep_num = ompi_comm_size(comm);
if (comm == &ompi_mpi_comm_world.comm) {
map.type = UCC_EP_MAP_FULL;
return map;
}
/* try to detect strided pattern */
is_strided = 1;
r1 = rank_map_cb(0, comm);
r2 = rank_map_cb(1, comm);
stride = r2 - r1;
for (i = 2; i < map.ep_num; i++) {
r1 = r2;
r2 = rank_map_cb(i, comm);
if (r2 - r1 != stride) {
is_strided = 0;
break;
}
}
if (is_strided) {
map.type = UCC_EP_MAP_STRIDED;
map.strided.start = r1;
map.strided.stride = stride;
} else {
map.type = UCC_EP_MAP_CB;
map.cb.cb = rank_map_cb;
map.cb.cb_ctx = (void*)comm;
}
return map;
}
/*
* Initialize module on the communicator
*/
static int mca_coll_ucc_module_enable(mca_coll_base_module_t *module,
struct ompi_communicator_t *comm)
{
mca_coll_ucc_component_t *cm = &mca_coll_ucc_component;
mca_coll_ucc_module_t *ucc_module = (mca_coll_ucc_module_t *)module;
ucc_status_t status;
int rc;
ucc_team_params_t team_params = {
.mask = UCC_TEAM_PARAM_FIELD_EP_MAP |
UCC_TEAM_PARAM_FIELD_EP |
UCC_TEAM_PARAM_FIELD_EP_RANGE |
UCC_TEAM_PARAM_FIELD_ID,
.ep_map = {
.type = (comm == &ompi_mpi_comm_world.comm) ?
UCC_EP_MAP_FULL : UCC_EP_MAP_CB,
.ep_num = ompi_comm_size(comm),
.cb.cb = rank_map_cb,
.cb.cb_ctx = (void*)comm
},
.ep = ompi_comm_rank(comm),
.ep_range = UCC_COLLECTIVE_EP_RANGE_CONTIG,
.id = comm->c_contextid
};
UCC_VERBOSE(2,"creating ucc_team for comm %p, comm_id %d, comm_size %d",
(void*)comm,comm->c_contextid,ompi_comm_size(comm));
if (OMPI_SUCCESS != mca_coll_ucc_save_coll_handlers(ucc_module)){
UCC_ERROR("mca_coll_ucc_save_coll_handlers failed");
goto err;
}
if (UCC_OK != ucc_team_create_post(&cm->ucc_context, 1,
&team_params, &ucc_module->ucc_team)) {
UCC_ERROR("ucc_team_create_post failed");
goto err;
}
while (UCC_INPROGRESS == (status = ucc_team_create_test(
ucc_module->ucc_team))) {
opal_progress();
}
if (UCC_OK != status) {
UCC_ERROR("ucc_team_create_test failed");
goto err;
}
rc = ompi_attr_set_c(COMM_ATTR, comm, &comm->c_keyhash,
ucc_comm_attr_keyval, (void *)module, false);
if (OMPI_SUCCESS != rc) {
UCC_ERROR("ucc ompi_attr_set_c failed");
goto err;
}
return OMPI_SUCCESS;
err:
ucc_module->ucc_team = NULL;
cm->ucc_enable = 0;
opal_progress_unregister(mca_coll_ucc_progress);
return OMPI_ERROR;
}
#define SET_COLL_PTR(_module, _COLL, _coll) do { \
_module->super.coll_ ## _coll = NULL; \
_module->super.coll_i ## _coll = NULL; \
if ((mca_coll_ucc_component.ucc_lib_attr.coll_types & \
UCC_COLL_TYPE_ ## _COLL)) { \
if (mca_coll_ucc_component.cts_requested & \
UCC_COLL_TYPE_ ## _COLL) { \
_module->super.coll_ ## _coll = mca_coll_ucc_ ## _coll; \
} \
if (mca_coll_ucc_component.nb_cts_requested & \
UCC_COLL_TYPE_ ## _COLL) { \
_module->super.coll_i ## _coll = mca_coll_ucc_i ## _coll; \
} \
} \
} while(0)
/*
* Invoked when there's a new communicator that has been created.
* Look at the communicator and decide which set of functions and
* priority we want to return.
*/
mca_coll_base_module_t *
mca_coll_ucc_comm_query(struct ompi_communicator_t *comm, int *priority)
{
mca_coll_ucc_component_t *cm = &mca_coll_ucc_component;
mca_coll_ucc_module_t *ucc_module;
*priority = 0;
if (!cm->ucc_enable){
return NULL;
}
if (OMPI_COMM_IS_INTER(comm) || ompi_comm_size(comm) < cm->ucc_np
|| ompi_comm_size(comm) < 2){
return NULL;
}
if (!cm->libucc_initialized) {
if (OMPI_SUCCESS != mca_coll_ucc_init_ctx()) {
cm->ucc_enable = 0;
return NULL;
}
}
ucc_module = OBJ_NEW(mca_coll_ucc_module_t);
if (!ucc_module) {
cm->ucc_enable = 0;
return NULL;
}
ucc_module->comm = comm;
ucc_module->super.coll_module_enable = mca_coll_ucc_module_enable;
*priority = cm->ucc_priority;
SET_COLL_PTR(ucc_module, BARRIER, barrier);
SET_COLL_PTR(ucc_module, BCAST, bcast);
SET_COLL_PTR(ucc_module, ALLREDUCE, allreduce);
SET_COLL_PTR(ucc_module, ALLTOALL, alltoall);
SET_COLL_PTR(ucc_module, ALLTOALLV, alltoallv);
SET_COLL_PTR(ucc_module, REDUCE, reduce);
SET_COLL_PTR(ucc_module, ALLGATHER, allgather);
SET_COLL_PTR(ucc_module, ALLGATHERV, allgatherv);
return &ucc_module->super;
}
OBJ_CLASS_INSTANCE(mca_coll_ucc_module_t,
mca_coll_base_module_t,
mca_coll_ucc_module_construct,
mca_coll_ucc_module_destruct);
OBJ_CLASS_INSTANCE(mca_coll_ucc_req_t, ompi_request_t,
NULL, NULL);
int mca_coll_ucc_req_free(struct ompi_request_t **ompi_req)
{
opal_free_list_return (&mca_coll_ucc_component.requests,
(opal_free_list_item_t *)(*ompi_req));
*ompi_req = MPI_REQUEST_NULL;
return OMPI_SUCCESS;
}
void mca_coll_ucc_completion(void *data, ucc_status_t status)
{
mca_coll_ucc_req_t *coll_req = (mca_coll_ucc_req_t*)data;
ucc_collective_finalize(coll_req->ucc_req);
ompi_request_complete(&coll_req->super, true);
}