diff --git a/fs/dlm/lowcomms.c b/fs/dlm/lowcomms.c index dc9ae6d670dc..00640e70ed7a 100644 --- a/fs/dlm/lowcomms.c +++ b/fs/dlm/lowcomms.c @@ -124,7 +124,10 @@ struct connection { struct connection *othercon; struct work_struct rwork; /* Receive workqueue */ struct work_struct swork; /* Send workqueue */ - void (*orig_error_report)(struct sock *sk); + void (*orig_error_report)(struct sock *); + void (*orig_data_ready)(struct sock *); + void (*orig_state_change)(struct sock *); + void (*orig_write_space)(struct sock *); }; #define sock2con(x) ((struct connection *)(x)->sk_user_data) @@ -467,10 +470,17 @@ int dlm_lowcomms_connect_node(int nodeid) static void lowcomms_error_report(struct sock *sk) { - struct connection *con = sock2con(sk); + struct connection *con; struct sockaddr_storage saddr; int buflen; + void (*orig_report)(struct sock *) = NULL; + read_lock_bh(&sk->sk_callback_lock); + con = sock2con(sk); + if (con == NULL) + goto out; + + orig_report = con->orig_error_report; if (con->sock == NULL || kernel_getpeername(con->sock, (struct sockaddr *)&saddr, &buflen)) { printk_ratelimited(KERN_ERR "dlm: node %d: socket error " @@ -478,7 +488,6 @@ static void lowcomms_error_report(struct sock *sk) "sk_err=%d/%d\n", dlm_our_nodeid(), con->nodeid, dlm_config.ci_tcp_port, sk->sk_err, sk->sk_err_soft); - return; } else if (saddr.ss_family == AF_INET) { struct sockaddr_in *sin4 = (struct sockaddr_in *)&saddr; @@ -501,22 +510,54 @@ static void lowcomms_error_report(struct sock *sk) dlm_config.ci_tcp_port, sk->sk_err, sk->sk_err_soft); } - con->orig_error_report(sk); +out: + read_unlock_bh(&sk->sk_callback_lock); + if (orig_report) + orig_report(sk); +} + +/* Note: sk_callback_lock must be locked before calling this function. */ +static void save_callbacks(struct connection *con, struct sock *sk) +{ + lock_sock(sk); + con->orig_data_ready = sk->sk_data_ready; + con->orig_state_change = sk->sk_state_change; + con->orig_write_space = sk->sk_write_space; + con->orig_error_report = sk->sk_error_report; + release_sock(sk); +} + +static void restore_callbacks(struct connection *con, struct sock *sk) +{ + write_lock_bh(&sk->sk_callback_lock); + lock_sock(sk); + sk->sk_user_data = NULL; + sk->sk_data_ready = con->orig_data_ready; + sk->sk_state_change = con->orig_state_change; + sk->sk_write_space = con->orig_write_space; + sk->sk_error_report = con->orig_error_report; + release_sock(sk); + write_unlock_bh(&sk->sk_callback_lock); } /* Make a socket active */ static void add_sock(struct socket *sock, struct connection *con) { + struct sock *sk = sock->sk; + + write_lock_bh(&sk->sk_callback_lock); con->sock = sock; + sk->sk_user_data = con; + if (!test_bit(CF_IS_OTHERCON, &con->flags)) + save_callbacks(con, sk); /* Install a data_ready callback */ - con->sock->sk->sk_data_ready = lowcomms_data_ready; - con->sock->sk->sk_write_space = lowcomms_write_space; - con->sock->sk->sk_state_change = lowcomms_state_change; - con->sock->sk->sk_user_data = con; - con->sock->sk->sk_allocation = GFP_NOFS; - con->orig_error_report = con->sock->sk->sk_error_report; - con->sock->sk->sk_error_report = lowcomms_error_report; + sk->sk_data_ready = lowcomms_data_ready; + sk->sk_write_space = lowcomms_write_space; + sk->sk_state_change = lowcomms_state_change; + sk->sk_allocation = GFP_NOFS; + sk->sk_error_report = lowcomms_error_report; + write_unlock_bh(&sk->sk_callback_lock); } /* Add the port number to an IPv6 or 4 sockaddr and return the address @@ -551,6 +592,8 @@ static void close_connection(struct connection *con, bool and_other, mutex_lock(&con->sock_mutex); if (con->sock) { + if (!test_bit(CF_IS_OTHERCON, &con->flags)) + restore_callbacks(con, con->sock->sk); sock_release(con->sock); con->sock = NULL; } @@ -1192,6 +1235,8 @@ static struct socket *tcp_create_listen_sock(struct connection *con, if (result < 0) { log_print("Failed to set SO_REUSEADDR on socket: %d", result); } + sock->sk->sk_user_data = con; + con->rx_action = tcp_accept_from_sock; con->connect_action = tcp_connect_to_sock; @@ -1273,6 +1318,7 @@ static int sctp_listen_for_all(void) if (result < 0) log_print("Could not set SCTP NODELAY error %d\n", result); + write_lock_bh(&sock->sk->sk_callback_lock); /* Init con struct */ sock->sk->sk_user_data = con; con->sock = sock; @@ -1280,6 +1326,8 @@ static int sctp_listen_for_all(void) con->rx_action = sctp_accept_from_sock; con->connect_action = sctp_connect_to_sock; + write_unlock_bh(&sock->sk->sk_callback_lock); + /* Bind to all addresses. */ if (sctp_bind_addrs(con, dlm_config.ci_tcp_port)) goto create_delsock;