From 8a4e28743e2440158e56e9d7cacc886ceabb37e8 Mon Sep 17 00:00:00 2001 From: "Daniel P. Berrange" Date: Tue, 28 Jun 2011 17:51:49 +0100 Subject: [PATCH] Fix locking wrt virNetClientStreamPtr object The client stream object can be used independently of the virNetClientPtr object, so must have full locking of its own and not rely on any caller. * src/remote/remote_driver.c: Remove locking around stream callback * src/rpc/virnetclientstream.c: Add locking to all APIs and callbacks --- src/remote/remote_driver.c | 3 - src/rpc/virnetclientstream.c | 112 ++++++++++++++++++++++++++++------- 2 files changed, 89 insertions(+), 26 deletions(-) diff --git a/src/remote/remote_driver.c b/src/remote/remote_driver.c index 0e68bc533a..ce9bcb157a 100644 --- a/src/remote/remote_driver.c +++ b/src/remote/remote_driver.c @@ -3254,11 +3254,8 @@ static void remoteStreamEventCallback(virNetClientStreamPtr stream ATTRIBUTE_UNU void *opaque) { struct remoteStreamCallbackData *cbdata = opaque; - struct private_data *priv = cbdata->st->conn->privateData; - remoteDriverUnlock(priv); (cbdata->cb)(cbdata->st, events, cbdata->opaque); - remoteDriverLock(priv); } diff --git a/src/rpc/virnetclientstream.c b/src/rpc/virnetclientstream.c index 99c7b410b0..9da5aeec44 100644 --- a/src/rpc/virnetclientstream.c +++ b/src/rpc/virnetclientstream.c @@ -28,6 +28,7 @@ #include "virterror_internal.h" #include "logging.h" #include "event.h" +#include "threads.h" #define VIR_FROM_THIS VIR_FROM_RPC #define virNetError(code, ...) \ @@ -35,6 +36,8 @@ __FUNCTION__, __LINE__, __VA_ARGS__) struct _virNetClientStream { + virMutex lock; + virNetClientProgramPtr prog; int proc; unsigned serial; @@ -53,7 +56,6 @@ struct _virNetClientStream { size_t incomingOffset; size_t incomingLength; - virNetClientStreamEventCallback cb; void *cbOpaque; virFreeCallback cbFree; @@ -89,7 +91,8 @@ virNetClientStreamEventTimer(int timer ATTRIBUTE_UNUSED, void *opaque) virNetClientStreamPtr st = opaque; int events = 0; - /* XXX we need a mutex on 'st' to protect this callback */ + + virMutexLock(&st->lock); if (st->cb && (st->cbEvents & VIR_STREAM_EVENT_READABLE) && @@ -106,12 +109,15 @@ virNetClientStreamEventTimer(int timer ATTRIBUTE_UNUSED, void *opaque) virFreeCallback cbFree = st->cbFree; st->cbDispatch = 1; + virMutexUnlock(&st->lock); (cb)(st, events, cbOpaque); + virMutexLock(&st->lock); st->cbDispatch = 0; if (!st->cb && cbFree) (cbFree)(cbOpaque); } + virMutexUnlock(&st->lock); } @@ -134,30 +140,45 @@ virNetClientStreamPtr virNetClientStreamNew(virNetClientProgramPtr prog, return NULL; } - virNetClientProgramRef(prog); - st->refs = 1; st->prog = prog; st->proc = proc; st->serial = serial; + if (virMutexInit(&st->lock) < 0) { + virNetError(VIR_ERR_INTERNAL_ERROR, "%s", + _("cannot initialize mutex")); + VIR_FREE(st); + return NULL; + } + + virNetClientProgramRef(prog); + return st; } void virNetClientStreamRef(virNetClientStreamPtr st) { + virMutexLock(&st->lock); st->refs++; + virMutexUnlock(&st->lock); } void virNetClientStreamFree(virNetClientStreamPtr st) { + virMutexLock(&st->lock); st->refs--; - if (st->refs > 0) + if (st->refs > 0) { + virMutexUnlock(&st->lock); return; + } + + virMutexUnlock(&st->lock); virResetError(&st->err); VIR_FREE(st->incoming); + virMutexDestroy(&st->lock); virNetClientProgramFree(st->prog); VIR_FREE(st); } @@ -165,18 +186,24 @@ void virNetClientStreamFree(virNetClientStreamPtr st) bool virNetClientStreamMatches(virNetClientStreamPtr st, virNetMessagePtr msg) { + bool match = false; + virMutexLock(&st->lock); if (virNetClientProgramMatches(st->prog, msg) && st->proc == msg->header.proc && st->serial == msg->header.serial) - return 1; - return 0; + match = true; + virMutexUnlock(&st->lock); + return match; } bool virNetClientStreamRaiseError(virNetClientStreamPtr st) { - if (st->err.code == VIR_ERR_OK) + virMutexLock(&st->lock); + if (st->err.code == VIR_ERR_OK) { + virMutexUnlock(&st->lock); return false; + } virRaiseErrorFull(__FILE__, __FUNCTION__, __LINE__, st->err.domain, @@ -188,7 +215,7 @@ bool virNetClientStreamRaiseError(virNetClientStreamPtr st) st->err.int1, st->err.int2, "%s", st->err.message ? st->err.message : _("Unknown error")); - + virMutexUnlock(&st->lock); return true; } @@ -199,6 +226,8 @@ int virNetClientStreamSetError(virNetClientStreamPtr st, virNetMessageError err; int ret = -1; + virMutexLock(&st->lock); + if (st->err.code != VIR_ERR_OK) VIR_DEBUG("Overwriting existing stream error %s", NULLSTR(st->err.message)); @@ -242,6 +271,7 @@ int virNetClientStreamSetError(virNetClientStreamPtr st, cleanup: xdr_free((xdrproc_t)xdr_virNetMessageError, (void*)&err); + virMutexUnlock(&st->lock); return ret; } @@ -249,15 +279,18 @@ cleanup: int virNetClientStreamQueuePacket(virNetClientStreamPtr st, virNetMessagePtr msg) { - size_t avail = st->incomingLength - st->incomingOffset; - size_t need = msg->bufferLength - msg->bufferOffset; + int ret = -1; + size_t need; + virMutexLock(&st->lock); + need = msg->bufferLength - msg->bufferOffset; + size_t avail = st->incomingLength - st->incomingOffset; if (need > avail) { size_t extra = need - avail; if (VIR_REALLOC_N(st->incoming, st->incomingLength + extra) < 0) { VIR_DEBUG("Out of memory handling stream data"); - return -1; + goto cleanup; } st->incomingLength += extra; } @@ -269,7 +302,12 @@ int virNetClientStreamQueuePacket(virNetClientStreamPtr st, VIR_DEBUG("Stream incoming data offset %zu length %zu", st->incomingOffset, st->incomingLength); - return 0; + + ret = 0; + +cleanup: + virMutexUnlock(&st->lock); + return ret; } @@ -286,6 +324,8 @@ int virNetClientStreamSendPacket(virNetClientStreamPtr st, if (!(msg = virNetMessageNew())) return -1; + virMutexLock(&st->lock); + msg->header.prog = virNetClientProgramGetProgram(st->prog); msg->header.vers = virNetClientProgramGetVersion(st->prog); msg->header.status = status; @@ -293,6 +333,8 @@ int virNetClientStreamSendPacket(virNetClientStreamPtr st, msg->header.serial = st->serial; msg->header.proc = st->proc; + virMutexUnlock(&st->lock); + if (virNetMessageEncodeHeader(msg) < 0) goto error; @@ -329,6 +371,7 @@ int virNetClientStreamRecvPacket(virNetClientStreamPtr st, int rv = -1; VIR_DEBUG("st=%p client=%p data=%p nbytes=%zu nonblock=%d", st, client, data, nbytes, nonblock); + virMutexLock(&st->lock); if (!st->incomingOffset) { virNetMessagePtr msg; int ret; @@ -351,8 +394,9 @@ int virNetClientStreamRecvPacket(virNetClientStreamPtr st, msg->header.proc = st->proc; VIR_DEBUG("Dummy packet to wait for stream data"); + virMutexUnlock(&st->lock); ret = virNetClientSend(client, msg, true); - + virMutexLock(&st->lock); virNetMessageFree(msg); if (ret < 0) @@ -380,6 +424,7 @@ int virNetClientStreamRecvPacket(virNetClientStreamPtr st, virNetClientStreamEventTimerUpdate(st); cleanup: + virMutexUnlock(&st->lock); return rv; } @@ -390,20 +435,23 @@ int virNetClientStreamEventAddCallback(virNetClientStreamPtr st, void *opaque, virFreeCallback ff) { + int ret = -1; + + virMutexLock(&st->lock); if (st->cb) { virNetError(VIR_ERR_INTERNAL_ERROR, "%s", _("multiple stream callbacks not supported")); - return 1; + goto cleanup; } - virNetClientStreamRef(st); + st->refs++; if ((st->cbTimer = virEventAddTimeout(-1, virNetClientStreamEventTimer, st, virNetClientStreamEventTimerFree)) < 0) { - virNetClientStreamFree(st); - return -1; + st->refs--; + goto cleanup; } st->cb = cb; @@ -413,31 +461,45 @@ int virNetClientStreamEventAddCallback(virNetClientStreamPtr st, virNetClientStreamEventTimerUpdate(st); - return 0; + ret = 0; + +cleanup: + virMutexUnlock(&st->lock); + return ret; } int virNetClientStreamEventUpdateCallback(virNetClientStreamPtr st, int events) { + int ret = -1; + + virMutexLock(&st->lock); if (!st->cb) { virNetError(VIR_ERR_INTERNAL_ERROR, "%s", _("no stream callback registered")); - return -1; + goto cleanup; } st->cbEvents = events; virNetClientStreamEventTimerUpdate(st); - return 0; + ret = 0; + +cleanup: + virMutexUnlock(&st->lock); + return ret; } int virNetClientStreamEventRemoveCallback(virNetClientStreamPtr st) { + int ret = -1; + + virMutexUnlock(&st->lock); if (!st->cb) { virNetError(VIR_ERR_INTERNAL_ERROR, "%s", _("no stream callback registered")); - return -1; + goto cleanup; } if (!st->cbDispatch && @@ -449,5 +511,9 @@ int virNetClientStreamEventRemoveCallback(virNetClientStreamPtr st) st->cbEvents = 0; virEventRemoveTimeout(st->cbTimer); - return 0; + ret = 0; + +cleanup: + virMutexUnlock(&st->lock); + return ret; }