[AF_UNIX]: Make socket locking much less confusing.

The unix_state_*() locking macros imply that there is some
rwlock kind of thing going on, but the implementation is
actually a spinlock which makes the code more confusing than
it needs to be.

So use plain unix_state_lock and unix_state_unlock.

Signed-off-by: David S. Miller <davem@davemloft.net>
This commit is contained in:
David S. Miller 2007-05-31 13:24:26 -07:00
parent c1a13ff57a
commit 1c92b4e50e
2 changed files with 50 additions and 52 deletions

View File

@ -62,13 +62,11 @@ struct unix_skb_parms {
#define UNIXCREDS(skb) (&UNIXCB((skb)).creds) #define UNIXCREDS(skb) (&UNIXCB((skb)).creds)
#define UNIXSID(skb) (&UNIXCB((skb)).secid) #define UNIXSID(skb) (&UNIXCB((skb)).secid)
#define unix_state_rlock(s) spin_lock(&unix_sk(s)->lock) #define unix_state_lock(s) spin_lock(&unix_sk(s)->lock)
#define unix_state_runlock(s) spin_unlock(&unix_sk(s)->lock) #define unix_state_unlock(s) spin_unlock(&unix_sk(s)->lock)
#define unix_state_wlock(s) spin_lock(&unix_sk(s)->lock) #define unix_state_lock_nested(s) \
#define unix_state_wlock_nested(s) \
spin_lock_nested(&unix_sk(s)->lock, \ spin_lock_nested(&unix_sk(s)->lock, \
SINGLE_DEPTH_NESTING) SINGLE_DEPTH_NESTING)
#define unix_state_wunlock(s) spin_unlock(&unix_sk(s)->lock)
#ifdef __KERNEL__ #ifdef __KERNEL__
/* The AF_UNIX socket */ /* The AF_UNIX socket */

View File

@ -174,11 +174,11 @@ static struct sock *unix_peer_get(struct sock *s)
{ {
struct sock *peer; struct sock *peer;
unix_state_rlock(s); unix_state_lock(s);
peer = unix_peer(s); peer = unix_peer(s);
if (peer) if (peer)
sock_hold(peer); sock_hold(peer);
unix_state_runlock(s); unix_state_unlock(s);
return peer; return peer;
} }
@ -369,7 +369,7 @@ static int unix_release_sock (struct sock *sk, int embrion)
unix_remove_socket(sk); unix_remove_socket(sk);
/* Clear state */ /* Clear state */
unix_state_wlock(sk); unix_state_lock(sk);
sock_orphan(sk); sock_orphan(sk);
sk->sk_shutdown = SHUTDOWN_MASK; sk->sk_shutdown = SHUTDOWN_MASK;
dentry = u->dentry; dentry = u->dentry;
@ -378,7 +378,7 @@ static int unix_release_sock (struct sock *sk, int embrion)
u->mnt = NULL; u->mnt = NULL;
state = sk->sk_state; state = sk->sk_state;
sk->sk_state = TCP_CLOSE; sk->sk_state = TCP_CLOSE;
unix_state_wunlock(sk); unix_state_unlock(sk);
wake_up_interruptible_all(&u->peer_wait); wake_up_interruptible_all(&u->peer_wait);
@ -386,12 +386,12 @@ static int unix_release_sock (struct sock *sk, int embrion)
if (skpair!=NULL) { if (skpair!=NULL) {
if (sk->sk_type == SOCK_STREAM || sk->sk_type == SOCK_SEQPACKET) { if (sk->sk_type == SOCK_STREAM || sk->sk_type == SOCK_SEQPACKET) {
unix_state_wlock(skpair); unix_state_lock(skpair);
/* No more writes */ /* No more writes */
skpair->sk_shutdown = SHUTDOWN_MASK; skpair->sk_shutdown = SHUTDOWN_MASK;
if (!skb_queue_empty(&sk->sk_receive_queue) || embrion) if (!skb_queue_empty(&sk->sk_receive_queue) || embrion)
skpair->sk_err = ECONNRESET; skpair->sk_err = ECONNRESET;
unix_state_wunlock(skpair); unix_state_unlock(skpair);
skpair->sk_state_change(skpair); skpair->sk_state_change(skpair);
read_lock(&skpair->sk_callback_lock); read_lock(&skpair->sk_callback_lock);
sk_wake_async(skpair,1,POLL_HUP); sk_wake_async(skpair,1,POLL_HUP);
@ -448,7 +448,7 @@ static int unix_listen(struct socket *sock, int backlog)
err = -EINVAL; err = -EINVAL;
if (!u->addr) if (!u->addr)
goto out; /* No listens on an unbound socket */ goto out; /* No listens on an unbound socket */
unix_state_wlock(sk); unix_state_lock(sk);
if (sk->sk_state != TCP_CLOSE && sk->sk_state != TCP_LISTEN) if (sk->sk_state != TCP_CLOSE && sk->sk_state != TCP_LISTEN)
goto out_unlock; goto out_unlock;
if (backlog > sk->sk_max_ack_backlog) if (backlog > sk->sk_max_ack_backlog)
@ -462,7 +462,7 @@ static int unix_listen(struct socket *sock, int backlog)
err = 0; err = 0;
out_unlock: out_unlock:
unix_state_wunlock(sk); unix_state_unlock(sk);
out: out:
return err; return err;
} }
@ -881,7 +881,7 @@ static int unix_dgram_connect(struct socket *sock, struct sockaddr *addr,
if (!other) if (!other)
goto out; goto out;
unix_state_wlock(sk); unix_state_lock(sk);
err = -EPERM; err = -EPERM;
if (!unix_may_send(sk, other)) if (!unix_may_send(sk, other))
@ -896,7 +896,7 @@ static int unix_dgram_connect(struct socket *sock, struct sockaddr *addr,
* 1003.1g breaking connected state with AF_UNSPEC * 1003.1g breaking connected state with AF_UNSPEC
*/ */
other = NULL; other = NULL;
unix_state_wlock(sk); unix_state_lock(sk);
} }
/* /*
@ -905,19 +905,19 @@ static int unix_dgram_connect(struct socket *sock, struct sockaddr *addr,
if (unix_peer(sk)) { if (unix_peer(sk)) {
struct sock *old_peer = unix_peer(sk); struct sock *old_peer = unix_peer(sk);
unix_peer(sk)=other; unix_peer(sk)=other;
unix_state_wunlock(sk); unix_state_unlock(sk);
if (other != old_peer) if (other != old_peer)
unix_dgram_disconnected(sk, old_peer); unix_dgram_disconnected(sk, old_peer);
sock_put(old_peer); sock_put(old_peer);
} else { } else {
unix_peer(sk)=other; unix_peer(sk)=other;
unix_state_wunlock(sk); unix_state_unlock(sk);
} }
return 0; return 0;
out_unlock: out_unlock:
unix_state_wunlock(sk); unix_state_unlock(sk);
sock_put(other); sock_put(other);
out: out:
return err; return err;
@ -936,7 +936,7 @@ static long unix_wait_for_peer(struct sock *other, long timeo)
(skb_queue_len(&other->sk_receive_queue) > (skb_queue_len(&other->sk_receive_queue) >
other->sk_max_ack_backlog); other->sk_max_ack_backlog);
unix_state_runlock(other); unix_state_unlock(other);
if (sched) if (sched)
timeo = schedule_timeout(timeo); timeo = schedule_timeout(timeo);
@ -994,11 +994,11 @@ static int unix_stream_connect(struct socket *sock, struct sockaddr *uaddr,
goto out; goto out;
/* Latch state of peer */ /* Latch state of peer */
unix_state_rlock(other); unix_state_lock(other);
/* Apparently VFS overslept socket death. Retry. */ /* Apparently VFS overslept socket death. Retry. */
if (sock_flag(other, SOCK_DEAD)) { if (sock_flag(other, SOCK_DEAD)) {
unix_state_runlock(other); unix_state_unlock(other);
sock_put(other); sock_put(other);
goto restart; goto restart;
} }
@ -1048,18 +1048,18 @@ static int unix_stream_connect(struct socket *sock, struct sockaddr *uaddr,
goto out_unlock; goto out_unlock;
} }
unix_state_wlock_nested(sk); unix_state_lock_nested(sk);
if (sk->sk_state != st) { if (sk->sk_state != st) {
unix_state_wunlock(sk); unix_state_unlock(sk);
unix_state_runlock(other); unix_state_unlock(other);
sock_put(other); sock_put(other);
goto restart; goto restart;
} }
err = security_unix_stream_connect(sock, other->sk_socket, newsk); err = security_unix_stream_connect(sock, other->sk_socket, newsk);
if (err) { if (err) {
unix_state_wunlock(sk); unix_state_unlock(sk);
goto out_unlock; goto out_unlock;
} }
@ -1096,7 +1096,7 @@ static int unix_stream_connect(struct socket *sock, struct sockaddr *uaddr,
smp_mb__after_atomic_inc(); /* sock_hold() does an atomic_inc() */ smp_mb__after_atomic_inc(); /* sock_hold() does an atomic_inc() */
unix_peer(sk) = newsk; unix_peer(sk) = newsk;
unix_state_wunlock(sk); unix_state_unlock(sk);
/* take ten and and send info to listening sock */ /* take ten and and send info to listening sock */
spin_lock(&other->sk_receive_queue.lock); spin_lock(&other->sk_receive_queue.lock);
@ -1105,14 +1105,14 @@ static int unix_stream_connect(struct socket *sock, struct sockaddr *uaddr,
* is installed to listening socket. */ * is installed to listening socket. */
atomic_inc(&newu->inflight); atomic_inc(&newu->inflight);
spin_unlock(&other->sk_receive_queue.lock); spin_unlock(&other->sk_receive_queue.lock);
unix_state_runlock(other); unix_state_unlock(other);
other->sk_data_ready(other, 0); other->sk_data_ready(other, 0);
sock_put(other); sock_put(other);
return 0; return 0;
out_unlock: out_unlock:
if (other) if (other)
unix_state_runlock(other); unix_state_unlock(other);
out: out:
if (skb) if (skb)
@ -1178,10 +1178,10 @@ static int unix_accept(struct socket *sock, struct socket *newsock, int flags)
wake_up_interruptible(&unix_sk(sk)->peer_wait); wake_up_interruptible(&unix_sk(sk)->peer_wait);
/* attach accepted sock to socket */ /* attach accepted sock to socket */
unix_state_wlock(tsk); unix_state_lock(tsk);
newsock->state = SS_CONNECTED; newsock->state = SS_CONNECTED;
sock_graft(tsk, newsock); sock_graft(tsk, newsock);
unix_state_wunlock(tsk); unix_state_unlock(tsk);
return 0; return 0;
out: out:
@ -1208,7 +1208,7 @@ static int unix_getname(struct socket *sock, struct sockaddr *uaddr, int *uaddr_
} }
u = unix_sk(sk); u = unix_sk(sk);
unix_state_rlock(sk); unix_state_lock(sk);
if (!u->addr) { if (!u->addr) {
sunaddr->sun_family = AF_UNIX; sunaddr->sun_family = AF_UNIX;
sunaddr->sun_path[0] = 0; sunaddr->sun_path[0] = 0;
@ -1219,7 +1219,7 @@ static int unix_getname(struct socket *sock, struct sockaddr *uaddr, int *uaddr_
*uaddr_len = addr->len; *uaddr_len = addr->len;
memcpy(sunaddr, addr->name, *uaddr_len); memcpy(sunaddr, addr->name, *uaddr_len);
} }
unix_state_runlock(sk); unix_state_unlock(sk);
sock_put(sk); sock_put(sk);
out: out:
return err; return err;
@ -1337,7 +1337,7 @@ static int unix_dgram_sendmsg(struct kiocb *kiocb, struct socket *sock,
goto out_free; goto out_free;
} }
unix_state_rlock(other); unix_state_lock(other);
err = -EPERM; err = -EPERM;
if (!unix_may_send(sk, other)) if (!unix_may_send(sk, other))
goto out_unlock; goto out_unlock;
@ -1347,20 +1347,20 @@ static int unix_dgram_sendmsg(struct kiocb *kiocb, struct socket *sock,
* Check with 1003.1g - what should * Check with 1003.1g - what should
* datagram error * datagram error
*/ */
unix_state_runlock(other); unix_state_unlock(other);
sock_put(other); sock_put(other);
err = 0; err = 0;
unix_state_wlock(sk); unix_state_lock(sk);
if (unix_peer(sk) == other) { if (unix_peer(sk) == other) {
unix_peer(sk)=NULL; unix_peer(sk)=NULL;
unix_state_wunlock(sk); unix_state_unlock(sk);
unix_dgram_disconnected(sk, other); unix_dgram_disconnected(sk, other);
sock_put(other); sock_put(other);
err = -ECONNREFUSED; err = -ECONNREFUSED;
} else { } else {
unix_state_wunlock(sk); unix_state_unlock(sk);
} }
other = NULL; other = NULL;
@ -1397,14 +1397,14 @@ static int unix_dgram_sendmsg(struct kiocb *kiocb, struct socket *sock,
} }
skb_queue_tail(&other->sk_receive_queue, skb); skb_queue_tail(&other->sk_receive_queue, skb);
unix_state_runlock(other); unix_state_unlock(other);
other->sk_data_ready(other, len); other->sk_data_ready(other, len);
sock_put(other); sock_put(other);
scm_destroy(siocb->scm); scm_destroy(siocb->scm);
return len; return len;
out_unlock: out_unlock:
unix_state_runlock(other); unix_state_unlock(other);
out_free: out_free:
kfree_skb(skb); kfree_skb(skb);
out: out:
@ -1494,14 +1494,14 @@ static int unix_stream_sendmsg(struct kiocb *kiocb, struct socket *sock,
goto out_err; goto out_err;
} }
unix_state_rlock(other); unix_state_lock(other);
if (sock_flag(other, SOCK_DEAD) || if (sock_flag(other, SOCK_DEAD) ||
(other->sk_shutdown & RCV_SHUTDOWN)) (other->sk_shutdown & RCV_SHUTDOWN))
goto pipe_err_free; goto pipe_err_free;
skb_queue_tail(&other->sk_receive_queue, skb); skb_queue_tail(&other->sk_receive_queue, skb);
unix_state_runlock(other); unix_state_unlock(other);
other->sk_data_ready(other, size); other->sk_data_ready(other, size);
sent+=size; sent+=size;
} }
@ -1512,7 +1512,7 @@ static int unix_stream_sendmsg(struct kiocb *kiocb, struct socket *sock,
return sent; return sent;
pipe_err_free: pipe_err_free:
unix_state_runlock(other); unix_state_unlock(other);
kfree_skb(skb); kfree_skb(skb);
pipe_err: pipe_err:
if (sent==0 && !(msg->msg_flags&MSG_NOSIGNAL)) if (sent==0 && !(msg->msg_flags&MSG_NOSIGNAL))
@ -1641,7 +1641,7 @@ static long unix_stream_data_wait(struct sock * sk, long timeo)
{ {
DEFINE_WAIT(wait); DEFINE_WAIT(wait);
unix_state_rlock(sk); unix_state_lock(sk);
for (;;) { for (;;) {
prepare_to_wait(sk->sk_sleep, &wait, TASK_INTERRUPTIBLE); prepare_to_wait(sk->sk_sleep, &wait, TASK_INTERRUPTIBLE);
@ -1654,14 +1654,14 @@ static long unix_stream_data_wait(struct sock * sk, long timeo)
break; break;
set_bit(SOCK_ASYNC_WAITDATA, &sk->sk_socket->flags); set_bit(SOCK_ASYNC_WAITDATA, &sk->sk_socket->flags);
unix_state_runlock(sk); unix_state_unlock(sk);
timeo = schedule_timeout(timeo); timeo = schedule_timeout(timeo);
unix_state_rlock(sk); unix_state_lock(sk);
clear_bit(SOCK_ASYNC_WAITDATA, &sk->sk_socket->flags); clear_bit(SOCK_ASYNC_WAITDATA, &sk->sk_socket->flags);
} }
finish_wait(sk->sk_sleep, &wait); finish_wait(sk->sk_sleep, &wait);
unix_state_runlock(sk); unix_state_unlock(sk);
return timeo; return timeo;
} }
@ -1816,12 +1816,12 @@ static int unix_shutdown(struct socket *sock, int mode)
mode = (mode+1)&(RCV_SHUTDOWN|SEND_SHUTDOWN); mode = (mode+1)&(RCV_SHUTDOWN|SEND_SHUTDOWN);
if (mode) { if (mode) {
unix_state_wlock(sk); unix_state_lock(sk);
sk->sk_shutdown |= mode; sk->sk_shutdown |= mode;
other=unix_peer(sk); other=unix_peer(sk);
if (other) if (other)
sock_hold(other); sock_hold(other);
unix_state_wunlock(sk); unix_state_unlock(sk);
sk->sk_state_change(sk); sk->sk_state_change(sk);
if (other && if (other &&
@ -1833,9 +1833,9 @@ static int unix_shutdown(struct socket *sock, int mode)
peer_mode |= SEND_SHUTDOWN; peer_mode |= SEND_SHUTDOWN;
if (mode&SEND_SHUTDOWN) if (mode&SEND_SHUTDOWN)
peer_mode |= RCV_SHUTDOWN; peer_mode |= RCV_SHUTDOWN;
unix_state_wlock(other); unix_state_lock(other);
other->sk_shutdown |= peer_mode; other->sk_shutdown |= peer_mode;
unix_state_wunlock(other); unix_state_unlock(other);
other->sk_state_change(other); other->sk_state_change(other);
read_lock(&other->sk_callback_lock); read_lock(&other->sk_callback_lock);
if (peer_mode == SHUTDOWN_MASK) if (peer_mode == SHUTDOWN_MASK)
@ -1973,7 +1973,7 @@ static int unix_seq_show(struct seq_file *seq, void *v)
else { else {
struct sock *s = v; struct sock *s = v;
struct unix_sock *u = unix_sk(s); struct unix_sock *u = unix_sk(s);
unix_state_rlock(s); unix_state_lock(s);
seq_printf(seq, "%p: %08X %08X %08X %04X %02X %5lu", seq_printf(seq, "%p: %08X %08X %08X %04X %02X %5lu",
s, s,
@ -2001,7 +2001,7 @@ static int unix_seq_show(struct seq_file *seq, void *v)
for ( ; i < len; i++) for ( ; i < len; i++)
seq_putc(seq, u->addr->name->sun_path[i]); seq_putc(seq, u->addr->name->sun_path[i]);
} }
unix_state_runlock(s); unix_state_unlock(s);
seq_putc(seq, '\n'); seq_putc(seq, '\n');
} }