diff --git a/net/smc/af_smc.c b/net/smc/af_smc.c index 46fa9f3016cc..77ef53596d18 100644 --- a/net/smc/af_smc.c +++ b/net/smc/af_smc.c @@ -30,6 +30,10 @@ #include #include +#include +#include +#include "smc_netns.h" + #include "smc.h" #include "smc_clc.h" #include "smc_llc.h" @@ -1966,10 +1970,33 @@ static const struct net_proto_family smc_sock_family_ops = { .create = smc_create, }; +unsigned int smc_net_id; + +static __net_init int smc_net_init(struct net *net) +{ + return smc_pnet_net_init(net); +} + +static void __net_exit smc_net_exit(struct net *net) +{ + smc_pnet_net_exit(net); +} + +static struct pernet_operations smc_net_ops = { + .init = smc_net_init, + .exit = smc_net_exit, + .id = &smc_net_id, + .size = sizeof(struct smc_net), +}; + static int __init smc_init(void) { int rc; + rc = register_pernet_subsys(&smc_net_ops); + if (rc) + return rc; + rc = smc_pnet_init(); if (rc) return rc; @@ -2035,6 +2062,7 @@ static void __exit smc_exit(void) proto_unregister(&smc_proto6); proto_unregister(&smc_proto); smc_pnet_exit(); + unregister_pernet_subsys(&smc_net_ops); } module_init(smc_init); diff --git a/net/smc/smc_netns.h b/net/smc/smc_netns.h new file mode 100644 index 000000000000..e7a8fc4ae02f --- /dev/null +++ b/net/smc/smc_netns.h @@ -0,0 +1,20 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +/* Shared Memory Communications + * + * Network namespace definitions. + * + * Copyright IBM Corp. 2018 + */ + +#ifndef SMC_NETNS_H +#define SMC_NETNS_H + +#include "smc_pnet.h" + +extern unsigned int smc_net_id; + +/* per-network namespace private data */ +struct smc_net { + struct smc_pnettable pnettable; +}; +#endif diff --git a/net/smc/smc_pnet.c b/net/smc/smc_pnet.c index 5497a8b44287..878f5c085444 100644 --- a/net/smc/smc_pnet.c +++ b/net/smc/smc_pnet.c @@ -20,6 +20,9 @@ #include +#include +#include "smc_netns.h" + #include "smc_pnet.h" #include "smc_ib.h" #include "smc_ism.h" @@ -46,19 +49,6 @@ static struct nla_policy smc_pnet_policy[SMC_PNETID_MAX + 1] = { static struct genl_family smc_pnet_nl_family; -/** - * struct smc_pnettable - SMC PNET table anchor - * @lock: Lock for list action - * @pnetlist: List of PNETIDs - */ -static struct smc_pnettable { - rwlock_t lock; - struct list_head pnetlist; -} smc_pnettable = { - .pnetlist = LIST_HEAD_INIT(smc_pnettable.pnetlist), - .lock = __RW_LOCK_UNLOCKED(smc_pnettable.lock) -}; - /** * struct smc_user_pnetentry - pnet identifier name entry for/from user * @list: List node. @@ -101,17 +91,23 @@ static bool smc_pnet_match(u8 *pnetid1, u8 *pnetid2) /* Remove a pnetid from the pnet table. */ -static int smc_pnet_remove_by_pnetid(char *pnet_name) +static int smc_pnet_remove_by_pnetid(struct net *net, char *pnet_name) { struct smc_pnetentry *pnetelem, *tmp_pe; + struct smc_pnettable *pnettable; struct smc_ib_device *ibdev; struct smcd_dev *smcd_dev; + struct smc_net *sn; int rc = -ENOENT; int ibport; + /* get pnettable for namespace */ + sn = net_generic(net, smc_net_id); + pnettable = &sn->pnettable; + /* remove netdevices */ - write_lock(&smc_pnettable.lock); - list_for_each_entry_safe(pnetelem, tmp_pe, &smc_pnettable.pnetlist, + write_lock(&pnettable->lock); + list_for_each_entry_safe(pnetelem, tmp_pe, &pnettable->pnetlist, list) { if (!pnet_name || smc_pnet_match(pnetelem->pnet_name, pnet_name)) { @@ -121,7 +117,12 @@ static int smc_pnet_remove_by_pnetid(char *pnet_name) rc = 0; } } - write_unlock(&smc_pnettable.lock); + write_unlock(&pnettable->lock); + + /* if this is not the initial namespace, stop here */ + if (net != &init_net) + return rc; + /* remove ib devices */ spin_lock(&smc_ib_devices.lock); list_for_each_entry(ibdev, &smc_ib_devices.list, list) { @@ -158,11 +159,17 @@ static int smc_pnet_remove_by_pnetid(char *pnet_name) static int smc_pnet_remove_by_ndev(struct net_device *ndev) { struct smc_pnetentry *pnetelem, *tmp_pe; + struct smc_pnettable *pnettable; + struct net *net = dev_net(ndev); + struct smc_net *sn; int rc = -ENOENT; - write_lock(&smc_pnettable.lock); - list_for_each_entry_safe(pnetelem, tmp_pe, &smc_pnettable.pnetlist, - list) { + /* get pnettable for namespace */ + sn = net_generic(net, smc_net_id); + pnettable = &sn->pnettable; + + write_lock(&pnettable->lock); + list_for_each_entry_safe(pnetelem, tmp_pe, &pnettable->pnetlist, list) { if (pnetelem->ndev == ndev) { list_del(&pnetelem->list); dev_put(pnetelem->ndev); @@ -171,13 +178,14 @@ static int smc_pnet_remove_by_ndev(struct net_device *ndev) break; } } - write_unlock(&smc_pnettable.lock); + write_unlock(&pnettable->lock); return rc; } /* Append a pnetid to the end of the pnet table if not already on this list. */ -static int smc_pnet_enter(struct smc_user_pnetentry *new_pnetelem) +static int smc_pnet_enter(struct smc_pnettable *pnettable, + struct smc_user_pnetentry *new_pnetelem) { u8 pnet_null[SMC_MAX_PNETID_LEN] = {0}; u8 ndev_pnetid[SMC_MAX_PNETID_LEN]; @@ -233,17 +241,17 @@ static int smc_pnet_enter(struct smc_user_pnetentry *new_pnetelem) SMC_MAX_PNETID_LEN); tmp_pnetelem->ndev = new_pnetelem->ndev; - write_lock(&smc_pnettable.lock); - list_for_each_entry(pnetelem, &smc_pnettable.pnetlist, list) { + write_lock(&pnettable->lock); + list_for_each_entry(pnetelem, &pnettable->pnetlist, list) { if (pnetelem->ndev == new_pnetelem->ndev) new_netdev = false; } if (new_netdev) { dev_hold(tmp_pnetelem->ndev); - list_add_tail(&tmp_pnetelem->list, &smc_pnettable.pnetlist); - write_unlock(&smc_pnettable.lock); + list_add_tail(&tmp_pnetelem->list, &pnettable->pnetlist); + write_unlock(&pnettable->lock); } else { - write_unlock(&smc_pnettable.lock); + write_unlock(&pnettable->lock); kfree(tmp_pnetelem); } @@ -340,6 +348,10 @@ static int smc_pnet_fill_entry(struct net *net, goto error; } + /* if this is not the initial namespace, stop here */ + if (net != &init_net) + return 0; + rc = -EINVAL; if (tb[SMC_PNETID_IBNAME]) { ibname = (char *)nla_data(tb[SMC_PNETID_IBNAME]); @@ -403,11 +415,17 @@ static int smc_pnet_add(struct sk_buff *skb, struct genl_info *info) { struct net *net = genl_info_net(info); struct smc_user_pnetentry pnetelem; + struct smc_pnettable *pnettable; + struct smc_net *sn; int rc; + /* get pnettable for namespace */ + sn = net_generic(net, smc_net_id); + pnettable = &sn->pnettable; + rc = smc_pnet_fill_entry(net, &pnetelem, info->attrs); if (!rc) - rc = smc_pnet_enter(&pnetelem); + rc = smc_pnet_enter(pnettable, &pnetelem); if (pnetelem.ndev) dev_put(pnetelem.ndev); return rc; @@ -415,9 +433,11 @@ static int smc_pnet_add(struct sk_buff *skb, struct genl_info *info) static int smc_pnet_del(struct sk_buff *skb, struct genl_info *info) { + struct net *net = genl_info_net(info); + if (!info->attrs[SMC_PNETID_NAME]) return -EINVAL; - return smc_pnet_remove_by_pnetid( + return smc_pnet_remove_by_pnetid(net, (char *)nla_data(info->attrs[SMC_PNETID_NAME])); } @@ -445,19 +465,25 @@ static int smc_pnet_dumpinfo(struct sk_buff *skb, return 0; } -static int _smc_pnet_dump(struct sk_buff *skb, u32 portid, u32 seq, u8 *pnetid, - int start_idx) +static int _smc_pnet_dump(struct net *net, struct sk_buff *skb, u32 portid, + u32 seq, u8 *pnetid, int start_idx) { struct smc_user_pnetentry tmp_entry; + struct smc_pnettable *pnettable; struct smc_pnetentry *pnetelem; struct smc_ib_device *ibdev; struct smcd_dev *smcd_dev; + struct smc_net *sn; int idx = 0; int ibport; + /* get pnettable for namespace */ + sn = net_generic(net, smc_net_id); + pnettable = &sn->pnettable; + /* dump netdevices */ - read_lock(&smc_pnettable.lock); - list_for_each_entry(pnetelem, &smc_pnettable.pnetlist, list) { + read_lock(&pnettable->lock); + list_for_each_entry(pnetelem, &pnettable->pnetlist, list) { if (pnetid && !smc_pnet_match(pnetelem->pnet_name, pnetid)) continue; if (idx++ < start_idx) @@ -472,7 +498,11 @@ static int _smc_pnet_dump(struct sk_buff *skb, u32 portid, u32 seq, u8 *pnetid, break; } } - read_unlock(&smc_pnettable.lock); + read_unlock(&pnettable->lock); + + /* if this is not the initial namespace, stop here */ + if (net != &init_net) + return idx; /* dump ib devices */ spin_lock(&smc_ib_devices.lock); @@ -528,9 +558,10 @@ static int _smc_pnet_dump(struct sk_buff *skb, u32 portid, u32 seq, u8 *pnetid, static int smc_pnet_dump(struct sk_buff *skb, struct netlink_callback *cb) { + struct net *net = sock_net(skb->sk); int idx; - idx = _smc_pnet_dump(skb, NETLINK_CB(cb->skb).portid, + idx = _smc_pnet_dump(net, skb, NETLINK_CB(cb->skb).portid, cb->nlh->nlmsg_seq, NULL, cb->args[0]); cb->args[0] = idx; @@ -540,6 +571,7 @@ static int smc_pnet_dump(struct sk_buff *skb, struct netlink_callback *cb) /* Retrieve one PNETID entry */ static int smc_pnet_get(struct sk_buff *skb, struct genl_info *info) { + struct net *net = genl_info_net(info); struct sk_buff *msg; void *hdr; @@ -550,7 +582,7 @@ static int smc_pnet_get(struct sk_buff *skb, struct genl_info *info) if (!msg) return -ENOMEM; - _smc_pnet_dump(msg, info->snd_portid, info->snd_seq, + _smc_pnet_dump(net, msg, info->snd_portid, info->snd_seq, nla_data(info->attrs[SMC_PNETID_NAME]), 0); /* finish multi part message and send it */ @@ -567,7 +599,9 @@ static int smc_pnet_get(struct sk_buff *skb, struct genl_info *info) */ static int smc_pnet_flush(struct sk_buff *skb, struct genl_info *info) { - return smc_pnet_remove_by_pnetid(NULL); + struct net *net = genl_info_net(info); + + return smc_pnet_remove_by_pnetid(net, NULL); } /* SMC_PNETID generic netlink operation definition */ @@ -631,6 +665,18 @@ static struct notifier_block smc_netdev_notifier = { .notifier_call = smc_pnet_netdev_event }; +/* init network namespace */ +int smc_pnet_net_init(struct net *net) +{ + struct smc_net *sn = net_generic(net, smc_net_id); + struct smc_pnettable *pnettable = &sn->pnettable; + + INIT_LIST_HEAD(&pnettable->pnetlist); + rwlock_init(&pnettable->lock); + + return 0; +} + int __init smc_pnet_init(void) { int rc; @@ -644,9 +690,15 @@ int __init smc_pnet_init(void) return rc; } +/* exit network namespace */ +void smc_pnet_net_exit(struct net *net) +{ + /* flush pnet table */ + smc_pnet_remove_by_pnetid(net, NULL); +} + void smc_pnet_exit(void) { - smc_pnet_flush(NULL, NULL); unregister_netdevice_notifier(&smc_netdev_notifier); genl_unregister_family(&smc_pnet_nl_family); } @@ -674,22 +726,29 @@ static struct net_device *pnet_find_base_ndev(struct net_device *ndev) return ndev; } -static int smc_pnet_find_ndev_pnetid_by_table(struct net_device *netdev, +static int smc_pnet_find_ndev_pnetid_by_table(struct net_device *ndev, u8 *pnetid) { + struct smc_pnettable *pnettable; + struct net *net = dev_net(ndev); struct smc_pnetentry *pnetelem; + struct smc_net *sn; int rc = -ENOENT; - read_lock(&smc_pnettable.lock); - list_for_each_entry(pnetelem, &smc_pnettable.pnetlist, list) { - if (netdev == pnetelem->ndev) { + /* get pnettable for namespace */ + sn = net_generic(net, smc_net_id); + pnettable = &sn->pnettable; + + read_lock(&pnettable->lock); + list_for_each_entry(pnetelem, &pnettable->pnetlist, list) { + if (ndev == pnetelem->ndev) { /* get pnetid of netdev device */ memcpy(pnetid, pnetelem->pnet_name, SMC_MAX_PNETID_LEN); rc = 0; break; } } - read_unlock(&smc_pnettable.lock); + read_unlock(&pnettable->lock); return rc; } diff --git a/net/smc/smc_pnet.h b/net/smc/smc_pnet.h index 37044e4ee50f..5eac42fb45d0 100644 --- a/net/smc/smc_pnet.h +++ b/net/smc/smc_pnet.h @@ -19,6 +19,16 @@ struct smc_ib_device; struct smcd_dev; +/** + * struct smc_pnettable - SMC PNET table anchor + * @lock: Lock for list action + * @pnetlist: List of PNETIDs + */ +struct smc_pnettable { + rwlock_t lock; + struct list_head pnetlist; +}; + static inline int smc_pnetid_by_dev_port(struct device *dev, unsigned short port, u8 *pnetid) { @@ -30,7 +40,9 @@ static inline int smc_pnetid_by_dev_port(struct device *dev, } int smc_pnet_init(void) __init; +int smc_pnet_net_init(struct net *net); void smc_pnet_exit(void); +void smc_pnet_net_exit(struct net *net); void smc_pnet_find_roce_resource(struct sock *sk, struct smc_ib_device **smcibdev, u8 *ibport, unsigned short vlan_id, u8 gid[]);