diff --git a/daemon/remote.c b/daemon/remote.c index a28a75403b..80a2c1f271 100644 --- a/daemon/remote.c +++ b/daemon/remote.c @@ -2030,6 +2030,7 @@ remoteDispatchAuthList(virNetServerPtr server ATTRIBUTE_UNUSED, int rv = -1; int auth = virNetServerClientGetAuth(client); uid_t callerUid; + gid_t callerGid; pid_t callerPid; /* If the client is root then we want to bypass the @@ -2037,7 +2038,7 @@ remoteDispatchAuthList(virNetServerPtr server ATTRIBUTE_UNUSED, * some piece of polkit isn't present/running */ if (auth == VIR_NET_SERVER_SERVICE_AUTH_POLKIT) { - if (virNetServerClientGetLocalIdentity(client, &callerUid, &callerPid) < 0) { + if (virNetServerClientGetLocalIdentity(client, &callerUid, &callerGid, &callerPid) < 0) { /* Don't do anything on error - it'll be validated at next * phase of auth anyway */ virResetLastError(); @@ -2463,6 +2464,7 @@ remoteDispatchAuthPolkit(virNetServerPtr server ATTRIBUTE_UNUSED, remote_auth_polkit_ret *ret) { pid_t callerPid = -1; + gid_t callerGid = -1; uid_t callerUid = -1; const char *action; int status = -1; @@ -2493,7 +2495,7 @@ remoteDispatchAuthPolkit(virNetServerPtr server ATTRIBUTE_UNUSED, goto authfail; } - if (virNetServerClientGetLocalIdentity(client, &callerUid, &callerPid) < 0) { + if (virNetServerClientGetLocalIdentity(client, &callerUid, &callerGid, &callerPid) < 0) { goto authfail; } @@ -2563,6 +2565,7 @@ remoteDispatchAuthPolkit(virNetServerPtr server, remote_auth_polkit_ret *ret) { pid_t callerPid; + gid_t callerGid; uid_t callerUid; PolKitCaller *pkcaller = NULL; PolKitAction *pkaction = NULL; @@ -2590,7 +2593,7 @@ remoteDispatchAuthPolkit(virNetServerPtr server, goto authfail; } - if (virNetServerClientGetLocalIdentity(client, &callerUid, &callerPid) < 0) { + if (virNetServerClientGetLocalIdentity(client, &callerUid, &callerGid, &callerPid) < 0) { VIR_ERROR(_("cannot get peer socket identity")); goto authfail; } diff --git a/src/rpc/virnetserverclient.c b/src/rpc/virnetserverclient.c index cb07dd91ed..ed08e408b7 100644 --- a/src/rpc/virnetserverclient.c +++ b/src/rpc/virnetserverclient.c @@ -448,12 +448,12 @@ int virNetServerClientGetFD(virNetServerClientPtr client) } int virNetServerClientGetLocalIdentity(virNetServerClientPtr client, - uid_t *uid, pid_t *pid) + uid_t *uid, gid_t *gid, pid_t *pid) { int ret = -1; virNetServerClientLock(client); if (client->sock) - ret = virNetSocketGetLocalIdentity(client->sock, uid, pid); + ret = virNetSocketGetLocalIdentity(client->sock, uid, gid, pid); virNetServerClientUnlock(client); return ret; } diff --git a/src/rpc/virnetserverclient.h b/src/rpc/virnetserverclient.h index a201dca2fe..2dd01c5fb9 100644 --- a/src/rpc/virnetserverclient.h +++ b/src/rpc/virnetserverclient.h @@ -71,7 +71,7 @@ int virNetServerClientSetIdentity(virNetServerClientPtr client, const char *virNetServerClientGetIdentity(virNetServerClientPtr client); int virNetServerClientGetLocalIdentity(virNetServerClientPtr client, - uid_t *uid, pid_t *pid); + uid_t *uid, gid_t *gid, pid_t *pid); void virNetServerClientRef(virNetServerClientPtr client); diff --git a/src/rpc/virnetsocket.c b/src/rpc/virnetsocket.c index af4fc5e9a7..8178ac3c86 100644 --- a/src/rpc/virnetsocket.c +++ b/src/rpc/virnetsocket.c @@ -826,6 +826,7 @@ int virNetSocketGetPort(virNetSocketPtr sock) #ifdef SO_PEERCRED int virNetSocketGetLocalIdentity(virNetSocketPtr sock, uid_t *uid, + gid_t *gid, pid_t *pid) { struct ucred cr; @@ -841,6 +842,7 @@ int virNetSocketGetLocalIdentity(virNetSocketPtr sock, *pid = cr.pid; *uid = cr.uid; + *gid = cr.gid; virMutexUnlock(&sock->lock); return 0; @@ -848,6 +850,7 @@ int virNetSocketGetLocalIdentity(virNetSocketPtr sock, #else int virNetSocketGetLocalIdentity(virNetSocketPtr sock ATTRIBUTE_UNUSED, uid_t *uid ATTRIBUTE_UNUSED, + gid_t *gid ATTRIBUTE_UNUSED, pid_t *pid ATTRIBUTE_UNUSED) { /* XXX Many more OS support UNIX socket credentials we could port to. See dbus ....*/ diff --git a/src/rpc/virnetsocket.h b/src/rpc/virnetsocket.h index ef9baa8305..c2a040f56e 100644 --- a/src/rpc/virnetsocket.h +++ b/src/rpc/virnetsocket.h @@ -88,6 +88,7 @@ int virNetSocketGetPort(virNetSocketPtr sock); int virNetSocketGetLocalIdentity(virNetSocketPtr sock, uid_t *uid, + gid_t *gid, pid_t *pid); int virNetSocketSetBlocking(virNetSocketPtr sock,