diff --git a/src/rpc/virnetclient.c b/src/rpc/virnetclient.c index 70192a9e88..64855fb8d6 100644 --- a/src/rpc/virnetclient.c +++ b/src/rpc/virnetclient.c @@ -1158,6 +1158,19 @@ static int virNetClientCallDispatchMessage(virNetClientPtr client) return 0; } +static void virNetClientCallCompleteAllWaitingReply(virNetClientPtr client) +{ + virNetClientCallPtr call; + + for (call = client->waitDispatch; call; call = call->next) { + if (call->msg->header.prog == client->msg.header.prog && + call->msg->header.vers == client->msg.header.vers && + call->msg->header.serial == client->msg.header.serial && + call->expectReply) + call->mode = VIR_NET_CLIENT_MODE_COMPLETE; + } +} + static int virNetClientCallDispatchStream(virNetClientPtr client) { size_t i; @@ -1181,16 +1194,6 @@ static int virNetClientCallDispatchStream(virNetClientPtr client) return 0; } - /* Finish/Abort are synchronous, so also see if there's an - * (optional) call waiting for this stream packet */ - thecall = client->waitDispatch; - while (thecall && - !(thecall->msg->header.prog == client->msg.header.prog && - thecall->msg->header.vers == client->msg.header.vers && - thecall->msg->header.serial == client->msg.header.serial)) - thecall = thecall->next; - - VIR_DEBUG("Found call %p", thecall); /* Status is either * - VIR_NET_OK - no payload for streams @@ -1202,25 +1205,47 @@ static int virNetClientCallDispatchStream(virNetClientPtr client) if (virNetClientStreamQueuePacket(st, &client->msg) < 0) return -1; - if (thecall && thecall->expectReply) { - if (thecall->msg->header.status == VIR_NET_CONTINUE) { - VIR_DEBUG("Got a synchronous confirm"); - thecall->mode = VIR_NET_CLIENT_MODE_COMPLETE; - } else { - VIR_DEBUG("Not completing call with status %d", thecall->msg->header.status); - } + /* Find oldest dummy message waiting for incoming data. */ + for (thecall = client->waitDispatch; thecall; thecall = thecall->next) { + if (thecall->msg->header.prog == client->msg.header.prog && + thecall->msg->header.vers == client->msg.header.vers && + thecall->msg->header.serial == client->msg.header.serial && + thecall->expectReply && + thecall->msg->header.status == VIR_NET_CONTINUE) + break; + } + + if (thecall) { + VIR_DEBUG("Got a new incoming stream data"); + thecall->mode = VIR_NET_CLIENT_MODE_COMPLETE; } return 0; } case VIR_NET_OK: - if (thecall && thecall->expectReply) { - VIR_DEBUG("Got a synchronous confirm"); - thecall->mode = VIR_NET_CLIENT_MODE_COMPLETE; - } else { + /* Find oldest abort/finish message. */ + for (thecall = client->waitDispatch; thecall; thecall = thecall->next) { + if (thecall->msg->header.prog == client->msg.header.prog && + thecall->msg->header.vers == client->msg.header.vers && + thecall->msg->header.serial == client->msg.header.serial && + thecall->expectReply && + thecall->msg->header.status != VIR_NET_CONTINUE) + break; + } + + if (!thecall) { VIR_DEBUG("Got unexpected async stream finish confirmation"); return -1; } + + VIR_DEBUG("Got a synchronous abort/finish confirm"); + + virNetClientStreamSetClosed(st, + thecall->msg->header.status == VIR_NET_OK ? + VIR_NET_CLIENT_STREAM_CLOSED_FINISHED : + VIR_NET_CLIENT_STREAM_CLOSED_ABORTED); + + virNetClientCallCompleteAllWaitingReply(client); return 0; case VIR_NET_ERROR: @@ -1228,10 +1253,7 @@ static int virNetClientCallDispatchStream(virNetClientPtr client) if (virNetClientStreamSetError(st, &client->msg) < 0) return -1; - if (thecall && thecall->expectReply) { - VIR_DEBUG("Got a synchronous error"); - thecall->mode = VIR_NET_CLIENT_MODE_COMPLETE; - } + virNetClientCallCompleteAllWaitingReply(client); return 0; default: @@ -2205,7 +2227,7 @@ int virNetClientSendStream(virNetClientPtr client, if (virNetClientSendInternal(client, msg, expectReply, false) < 0) goto cleanup; - if (virNetClientStreamCheckSendStatus(st, msg) < 0) + if (expectReply && virNetClientStreamCheckSendStatus(st, msg) < 0) goto cleanup; ret = 0; diff --git a/src/rpc/virnetclientstream.c b/src/rpc/virnetclientstream.c index f27dcbfea7..834c44843b 100644 --- a/src/rpc/virnetclientstream.c +++ b/src/rpc/virnetclientstream.c @@ -49,6 +49,7 @@ struct _virNetClientStream { */ virNetMessagePtr rx; bool incomingEOF; + virNetClientStreamClosed closed; bool allowSkip; long long holeLength; /* Size of incoming hole in stream. */ @@ -84,7 +85,7 @@ virNetClientStreamEventTimerUpdate(virNetClientStreamPtr st) VIR_DEBUG("Check timer rx=%p cbEvents=%d", st->rx, st->cbEvents); - if (((st->rx || st->incomingEOF || st->err.code != VIR_ERR_OK) && + if (((st->rx || st->incomingEOF || st->err.code != VIR_ERR_OK || st->closed) && (st->cbEvents & VIR_STREAM_EVENT_READABLE)) || (st->cbEvents & VIR_STREAM_EVENT_WRITABLE)) { VIR_DEBUG("Enabling event timer"); @@ -106,7 +107,7 @@ virNetClientStreamEventTimer(int timer ATTRIBUTE_UNUSED, void *opaque) if (st->cb && (st->cbEvents & VIR_STREAM_EVENT_READABLE) && - (st->rx || st->incomingEOF || st->err.code != VIR_ERR_OK)) + (st->rx || st->incomingEOF || st->err.code != VIR_ERR_OK || st->closed)) events |= VIR_STREAM_EVENT_READABLE; if (st->cb && (st->cbEvents & VIR_STREAM_EVENT_WRITABLE)) @@ -203,23 +204,61 @@ int virNetClientStreamCheckState(virNetClientStreamPtr st) return -1; } + if (st->closed) { + virReportError(VIR_ERR_OPERATION_FAILED, "%s", + _("stream is closed")); + return -1; + } + return 0; } -/* MUST be called under stream or client lock */ +/* MUST be called under stream or client lock. This should + * be called only for message that expect reply. */ int virNetClientStreamCheckSendStatus(virNetClientStreamPtr st, - virNetMessagePtr msg ATTRIBUTE_UNUSED) + virNetMessagePtr msg) { if (st->err.code != VIR_ERR_OK) { virNetClientStreamRaiseError(st); return -1; } + /* We can not check if the message is dummy in a usual way + * by checking msg->bufferLength because at this point message payload + * is cleared. As caller must not call this function for messages + * not expecting reply we can check for dummy messages just by status. + */ + if (msg->header.status == VIR_NET_CONTINUE) { + if (st->closed) { + virReportError(VIR_ERR_OPERATION_FAILED, "%s", + _("stream is closed")); + return -1; + } + return 0; + } else if (msg->header.status == VIR_NET_OK && + st->closed != VIR_NET_CLIENT_STREAM_CLOSED_FINISHED) { + virReportError(VIR_ERR_OPERATION_FAILED, "%s", + _("stream aborted by another thread")); + return -1; + } + return 0; } +void virNetClientStreamSetClosed(virNetClientStreamPtr st, + virNetClientStreamClosed closed) +{ + virObjectLock(st); + + st->closed = closed; + virNetClientStreamEventTimerUpdate(st); + + virObjectUnlock(st); +} + + int virNetClientStreamSetError(virNetClientStreamPtr st, virNetMessagePtr msg) { diff --git a/src/rpc/virnetclientstream.h b/src/rpc/virnetclientstream.h index 49b74bcc41..33a2af357f 100644 --- a/src/rpc/virnetclientstream.h +++ b/src/rpc/virnetclientstream.h @@ -27,6 +27,12 @@ typedef struct _virNetClientStream virNetClientStream; typedef virNetClientStream *virNetClientStreamPtr; +typedef enum { + VIR_NET_CLIENT_STREAM_CLOSED_NOT = 0, + VIR_NET_CLIENT_STREAM_CLOSED_FINISHED, + VIR_NET_CLIENT_STREAM_CLOSED_ABORTED, +} virNetClientStreamClosed; + typedef void (*virNetClientStreamEventCallback)(virNetClientStreamPtr stream, int events, void *opaque); @@ -43,6 +49,9 @@ int virNetClientStreamCheckSendStatus(virNetClientStreamPtr st, int virNetClientStreamSetError(virNetClientStreamPtr st, virNetMessagePtr msg); +void virNetClientStreamSetClosed(virNetClientStreamPtr st, + virNetClientStreamClosed closed); + bool virNetClientStreamMatches(virNetClientStreamPtr st, virNetMessagePtr msg);