From ad063f6192665b39238eab96abd02b9d259bfce6 Mon Sep 17 00:00:00 2001 From: Nikolay Shirokovskiy Date: Thu, 7 Feb 2019 15:58:44 +0300 Subject: [PATCH] rpc: client: incapsulate error checks Checking virNetClientStreamRaiseError without client lock is racy which is fixed in [1] for example. Thus let's remove such checks when we are sending message to server. And in other cases (like virNetClientStreamRecvHole for example) let's move the check into client stream code. virNetClientStreamRecvPacket already have stream lock so we could introduce another error checking function like virNetClientStreamRaiseErrorLocked but as error is set when both client and stream lock are hold we can remove locking from virNetClientStreamRaiseError because all callers hold either client or stream lock. Also let's split virNetClientStreamRaiseErrorLocked into checking state function and checking message send status function. They are same yet. [1] 1b6a29c21: rpc: fix race on stream abort/finish and server side abort Signed-off-by: Nikolay Shirokovskiy Signed-off-by: Michal Privoznik --- src/libvirt_remote.syms | 3 ++- src/remote/remote_driver.c | 16 ------------- src/rpc/virnetclient.c | 4 ++-- src/rpc/virnetclientstream.c | 44 ++++++++++++++++++++++++++++-------- src/rpc/virnetclientstream.h | 5 +++- 5 files changed, 43 insertions(+), 29 deletions(-) diff --git a/src/libvirt_remote.syms b/src/libvirt_remote.syms index 88745f2612..98586d1584 100644 --- a/src/libvirt_remote.syms +++ b/src/libvirt_remote.syms @@ -54,6 +54,8 @@ virNetClientProgramNew; # rpc/virnetclientstream.h +virNetClientStreamCheckSendStatus; +virNetClientStreamCheckState; virNetClientStreamEOF; virNetClientStreamEventAddCallback; virNetClientStreamEventRemoveCallback; @@ -61,7 +63,6 @@ virNetClientStreamEventUpdateCallback; virNetClientStreamMatches; virNetClientStreamNew; virNetClientStreamQueuePacket; -virNetClientStreamRaiseError; virNetClientStreamRecvHole; virNetClientStreamRecvPacket; virNetClientStreamSendHole; diff --git a/src/remote/remote_driver.c b/src/remote/remote_driver.c index 058e4c926b..1ff55e241a 100644 --- a/src/remote/remote_driver.c +++ b/src/remote/remote_driver.c @@ -5600,9 +5600,6 @@ remoteStreamSend(virStreamPtr st, virNetClientStreamPtr privst = st->privateData; int rv; - if (virNetClientStreamRaiseError(privst)) - return -1; - remoteDriverLock(priv); priv->localUses++; remoteDriverUnlock(priv); @@ -5634,9 +5631,6 @@ remoteStreamRecvFlags(virStreamPtr st, virCheckFlags(VIR_STREAM_RECV_STOP_AT_HOLE, -1); - if (virNetClientStreamRaiseError(privst)) - return -1; - remoteDriverLock(priv); priv->localUses++; remoteDriverUnlock(priv); @@ -5676,9 +5670,6 @@ remoteStreamSendHole(virStreamPtr st, virNetClientStreamPtr privst = st->privateData; int rv; - if (virNetClientStreamRaiseError(privst)) - return -1; - remoteDriverLock(priv); priv->localUses++; remoteDriverUnlock(priv); @@ -5709,9 +5700,6 @@ remoteStreamRecvHole(virStreamPtr st, virCheckFlags(0, -1); - if (virNetClientStreamRaiseError(privst)) - return -1; - remoteDriverLock(priv); priv->localUses++; remoteDriverUnlock(priv); @@ -5834,9 +5822,6 @@ remoteStreamCloseInt(virStreamPtr st, bool streamAbort) remoteDriverLock(priv); - if (virNetClientStreamRaiseError(privst)) - goto cleanup; - priv->localUses++; remoteDriverUnlock(priv); @@ -5849,7 +5834,6 @@ remoteStreamCloseInt(virStreamPtr st, bool streamAbort) remoteDriverLock(priv); priv->localUses--; - cleanup: virNetClientRemoveStream(priv->client, privst); virObjectUnref(privst); st->privateData = NULL; diff --git a/src/rpc/virnetclient.c b/src/rpc/virnetclient.c index fcc2e806e1..70192a9e88 100644 --- a/src/rpc/virnetclient.c +++ b/src/rpc/virnetclient.c @@ -2193,7 +2193,7 @@ int virNetClientSendStream(virNetClientPtr client, virObjectLock(client); - if (virNetClientStreamRaiseError(st)) + if (virNetClientStreamCheckState(st) < 0) goto cleanup; /* Check for EOF only if we are going to wait for incoming data */ @@ -2205,7 +2205,7 @@ int virNetClientSendStream(virNetClientPtr client, if (virNetClientSendInternal(client, msg, expectReply, false) < 0) goto cleanup; - if (virNetClientStreamRaiseError(st)) + if (virNetClientStreamCheckSendStatus(st, msg) < 0) goto cleanup; ret = 0; diff --git a/src/rpc/virnetclientstream.c b/src/rpc/virnetclientstream.c index 136ed16610..4738da2c3d 100644 --- a/src/rpc/virnetclientstream.c +++ b/src/rpc/virnetclientstream.c @@ -184,14 +184,9 @@ bool virNetClientStreamMatches(virNetClientStreamPtr st, } -bool virNetClientStreamRaiseError(virNetClientStreamPtr st) +static +void virNetClientStreamRaiseError(virNetClientStreamPtr st) { - virObjectLock(st); - if (st->err.code == VIR_ERR_OK) { - virObjectUnlock(st); - return false; - } - virRaiseErrorFull(__FILE__, __FUNCTION__, __LINE__, st->err.domain, st->err.code, @@ -202,8 +197,31 @@ bool virNetClientStreamRaiseError(virNetClientStreamPtr st) st->err.int1, st->err.int2, "%s", st->err.message ? st->err.message : _("Unknown error")); - virObjectUnlock(st); - return true; +} + + +/* MUST be called under stream or client lock */ +int virNetClientStreamCheckState(virNetClientStreamPtr st) +{ + if (st->err.code != VIR_ERR_OK) { + virNetClientStreamRaiseError(st); + return -1; + } + + return 0; +} + + +/* MUST be called under stream or client lock */ +int virNetClientStreamCheckSendStatus(virNetClientStreamPtr st, + virNetMessagePtr msg ATTRIBUTE_UNUSED) +{ + if (st->err.code != VIR_ERR_OK) { + virNetClientStreamRaiseError(st); + return -1; + } + + return 0; } @@ -474,6 +492,9 @@ int virNetClientStreamRecvPacket(virNetClientStreamPtr st, virObjectLock(st); reread: + if (virNetClientStreamCheckState(st) < 0) + goto cleanup; + if (!st->rx && !st->incomingEOF) { virNetMessagePtr msg; int ret; @@ -646,6 +667,11 @@ virNetClientStreamRecvHole(virNetClientPtr client ATTRIBUTE_UNUSED, virObjectLock(st); + if (virNetClientStreamCheckState(st) < 0) { + virObjectUnlock(st); + return -1; + } + *length = st->holeLength; st->holeLength = 0; diff --git a/src/rpc/virnetclientstream.h b/src/rpc/virnetclientstream.h index d13793222b..d81ec60a48 100644 --- a/src/rpc/virnetclientstream.h +++ b/src/rpc/virnetclientstream.h @@ -36,7 +36,10 @@ virNetClientStreamPtr virNetClientStreamNew(virStreamPtr stream, unsigned serial, bool allowSkip); -bool virNetClientStreamRaiseError(virNetClientStreamPtr st); +int virNetClientStreamCheckState(virNetClientStreamPtr st); + +int virNetClientStreamCheckSendStatus(virNetClientStreamPtr st, + virNetMessagePtr msg); int virNetClientStreamSetError(virNetClientStreamPtr st, virNetMessagePtr msg);