diff --git a/net/mpls/af_mpls.c b/net/mpls/af_mpls.c index 6e669114f829..49f1b0e41bdf 100644 --- a/net/mpls/af_mpls.c +++ b/net/mpls/af_mpls.c @@ -15,6 +15,7 @@ #include #include #include +#include #include "internal.h" #define LABEL_NOT_SPECIFIED (1<<20) @@ -330,6 +331,70 @@ static unsigned find_free_label(struct net *net) return LABEL_NOT_SPECIFIED; } +static struct net_device *inet_fib_lookup_dev(struct net *net, void *addr) +{ + struct net_device *dev = NULL; + struct rtable *rt; + struct in_addr daddr; + + memcpy(&daddr, addr, sizeof(struct in_addr)); + rt = ip_route_output(net, daddr.s_addr, 0, 0, 0); + if (IS_ERR(rt)) + goto errout; + + dev = rt->dst.dev; + dev_hold(dev); + + ip_rt_put(rt); + +errout: + return dev; +} + +static struct net_device *inet6_fib_lookup_dev(struct net *net, void *addr) +{ + struct net_device *dev = NULL; + struct dst_entry *dst; + struct flowi6 fl6; + + memset(&fl6, 0, sizeof(fl6)); + memcpy(&fl6.daddr, addr, sizeof(struct in6_addr)); + dst = ip6_route_output(net, NULL, &fl6); + if (dst->error) + goto errout; + + dev = dst->dev; + dev_hold(dev); + +errout: + dst_release(dst); + + return dev; +} + +static struct net_device *find_outdev(struct net *net, + struct mpls_route_config *cfg) +{ + struct net_device *dev = NULL; + + if (!cfg->rc_ifindex) { + switch (cfg->rc_via_table) { + case NEIGH_ARP_TABLE: + dev = inet_fib_lookup_dev(net, cfg->rc_via); + break; + case NEIGH_ND_TABLE: + dev = inet6_fib_lookup_dev(net, cfg->rc_via); + break; + case NEIGH_LINK_TABLE: + break; + } + } else { + dev = dev_get_by_index(net, cfg->rc_ifindex); + } + + return dev; +} + static int mpls_route_add(struct mpls_route_config *cfg) { struct mpls_route __rcu **platform_label; @@ -361,7 +426,7 @@ static int mpls_route_add(struct mpls_route_config *cfg) goto errout; err = -ENODEV; - dev = dev_get_by_index(net, cfg->rc_ifindex); + dev = find_outdev(net, cfg); if (!dev) goto errout;