diff --git a/drivers/net/vrf.c b/drivers/net/vrf.c index 85c271c70d42..820de6a9ddde 100644 --- a/drivers/net/vrf.c +++ b/drivers/net/vrf.c @@ -956,6 +956,7 @@ static struct sk_buff *vrf_ip6_rcv(struct net_device *vrf_dev, if (skb->pkt_type == PACKET_LOOPBACK) { skb->dev = vrf_dev; skb->skb_iif = vrf_dev->ifindex; + IP6CB(skb)->flags |= IP6SKB_L3SLAVE; skb->pkt_type = PACKET_HOST; goto out; } @@ -996,6 +997,7 @@ static struct sk_buff *vrf_ip_rcv(struct net_device *vrf_dev, { skb->dev = vrf_dev; skb->skb_iif = vrf_dev->ifindex; + IPCB(skb)->flags |= IPSKB_L3SLAVE; /* loopback traffic; do not push through packet taps again. * Reset pkt_type for upper layers to process skb diff --git a/include/linux/ipv6.h b/include/linux/ipv6.h index 7e9a789be5e0..ca1ad9ebbc92 100644 --- a/include/linux/ipv6.h +++ b/include/linux/ipv6.h @@ -123,12 +123,12 @@ struct inet6_skb_parm { }; #if defined(CONFIG_NET_L3_MASTER_DEV) -static inline bool skb_l3mdev_slave(__u16 flags) +static inline bool ipv6_l3mdev_skb(__u16 flags) { return flags & IP6SKB_L3SLAVE; } #else -static inline bool skb_l3mdev_slave(__u16 flags) +static inline bool ipv6_l3mdev_skb(__u16 flags) { return false; } @@ -139,11 +139,22 @@ static inline bool skb_l3mdev_slave(__u16 flags) static inline int inet6_iif(const struct sk_buff *skb) { - bool l3_slave = skb_l3mdev_slave(IP6CB(skb)->flags); + bool l3_slave = ipv6_l3mdev_skb(IP6CB(skb)->flags); return l3_slave ? skb->skb_iif : IP6CB(skb)->iif; } +/* can not be used in TCP layer after tcp_v6_fill_cb */ +static inline bool inet6_exact_dif_match(struct net *net, struct sk_buff *skb) +{ +#if defined(CONFIG_NET_L3_MASTER_DEV) + if (!net->ipv4.sysctl_tcp_l3mdev_accept && + ipv6_l3mdev_skb(IP6CB(skb)->flags)) + return true; +#endif + return false; +} + struct tcp6_request_sock { struct tcp_request_sock tcp6rsk_tcp; }; diff --git a/include/net/ip.h b/include/net/ip.h index bc43c0fcae12..c9d07988911e 100644 --- a/include/net/ip.h +++ b/include/net/ip.h @@ -38,7 +38,7 @@ struct sock; struct inet_skb_parm { int iif; struct ip_options opt; /* Compiled IP options */ - unsigned char flags; + u16 flags; #define IPSKB_FORWARDED BIT(0) #define IPSKB_XFRM_TUNNEL_SIZE BIT(1) @@ -48,10 +48,16 @@ struct inet_skb_parm { #define IPSKB_DOREDIRECT BIT(5) #define IPSKB_FRAG_PMTU BIT(6) #define IPSKB_FRAG_SEGS BIT(7) +#define IPSKB_L3SLAVE BIT(8) u16 frag_max_size; }; +static inline bool ipv4_l3mdev_skb(u16 flags) +{ + return !!(flags & IPSKB_L3SLAVE); +} + static inline unsigned int ip_hdrlen(const struct sk_buff *skb) { return ip_hdr(skb)->ihl * 4; diff --git a/include/net/tcp.h b/include/net/tcp.h index f83b7f220a65..5b82d4d94834 100644 --- a/include/net/tcp.h +++ b/include/net/tcp.h @@ -794,12 +794,23 @@ struct tcp_skb_cb { */ static inline int tcp_v6_iif(const struct sk_buff *skb) { - bool l3_slave = skb_l3mdev_slave(TCP_SKB_CB(skb)->header.h6.flags); + bool l3_slave = ipv6_l3mdev_skb(TCP_SKB_CB(skb)->header.h6.flags); return l3_slave ? skb->skb_iif : TCP_SKB_CB(skb)->header.h6.iif; } #endif +/* TCP_SKB_CB reference means this can not be used from early demux */ +static inline bool inet_exact_dif_match(struct net *net, struct sk_buff *skb) +{ +#if IS_ENABLED(CONFIG_NET_L3_MASTER_DEV) + if (!net->ipv4.sysctl_tcp_l3mdev_accept && + ipv4_l3mdev_skb(TCP_SKB_CB(skb)->header.h4.flags)) + return true; +#endif + return false; +} + /* Due to TSO, an SKB can be composed of multiple actual * packets. To keep these tracked properly, we use this. */ diff --git a/net/ipv4/inet_hashtables.c b/net/ipv4/inet_hashtables.c index 77c20a489218..ca97835bfec4 100644 --- a/net/ipv4/inet_hashtables.c +++ b/net/ipv4/inet_hashtables.c @@ -25,6 +25,7 @@ #include #include #include +#include #include static u32 inet_ehashfn(const struct net *net, const __be32 laddr, @@ -172,7 +173,7 @@ EXPORT_SYMBOL_GPL(__inet_inherit_port); static inline int compute_score(struct sock *sk, struct net *net, const unsigned short hnum, const __be32 daddr, - const int dif) + const int dif, bool exact_dif) { int score = -1; struct inet_sock *inet = inet_sk(sk); @@ -186,7 +187,7 @@ static inline int compute_score(struct sock *sk, struct net *net, return -1; score += 4; } - if (sk->sk_bound_dev_if) { + if (sk->sk_bound_dev_if || exact_dif) { if (sk->sk_bound_dev_if != dif) return -1; score += 4; @@ -215,11 +216,12 @@ struct sock *__inet_lookup_listener(struct net *net, unsigned int hash = inet_lhashfn(net, hnum); struct inet_listen_hashbucket *ilb = &hashinfo->listening_hash[hash]; int score, hiscore = 0, matches = 0, reuseport = 0; + bool exact_dif = inet_exact_dif_match(net, skb); struct sock *sk, *result = NULL; u32 phash = 0; sk_for_each_rcu(sk, &ilb->head) { - score = compute_score(sk, net, hnum, daddr, dif); + score = compute_score(sk, net, hnum, daddr, dif, exact_dif); if (score > hiscore) { reuseport = sk->sk_reuseport; if (reuseport) { diff --git a/net/ipv6/inet6_hashtables.c b/net/ipv6/inet6_hashtables.c index 00cf28ad4565..2fd0374a35b1 100644 --- a/net/ipv6/inet6_hashtables.c +++ b/net/ipv6/inet6_hashtables.c @@ -96,7 +96,7 @@ EXPORT_SYMBOL(__inet6_lookup_established); static inline int compute_score(struct sock *sk, struct net *net, const unsigned short hnum, const struct in6_addr *daddr, - const int dif) + const int dif, bool exact_dif) { int score = -1; @@ -109,7 +109,7 @@ static inline int compute_score(struct sock *sk, struct net *net, return -1; score++; } - if (sk->sk_bound_dev_if) { + if (sk->sk_bound_dev_if || exact_dif) { if (sk->sk_bound_dev_if != dif) return -1; score++; @@ -131,11 +131,12 @@ struct sock *inet6_lookup_listener(struct net *net, unsigned int hash = inet_lhashfn(net, hnum); struct inet_listen_hashbucket *ilb = &hashinfo->listening_hash[hash]; int score, hiscore = 0, matches = 0, reuseport = 0; + bool exact_dif = inet6_exact_dif_match(net, skb); struct sock *sk, *result = NULL; u32 phash = 0; sk_for_each(sk, &ilb->head) { - score = compute_score(sk, net, hnum, daddr, dif); + score = compute_score(sk, net, hnum, daddr, dif, exact_dif); if (score > hiscore) { reuseport = sk->sk_reuseport; if (reuseport) {