]> git.karo-electronics.de Git - karo-tx-linux.git/blobdiff - net/netlink/af_netlink.c
Merge branch 'for-linus' of git://git.kernel.org/pub/scm/linux/kernel/git/mason/linux...
[karo-tx-linux.git] / net / netlink / af_netlink.c
index 75532efa51cd6d54389366b0f4cc7a4de34e0fbc..2702673f0f237d7fa43e4649ccc0672f67c895bc 100644 (file)
@@ -98,12 +98,12 @@ static int netlink_dump(struct sock *sk);
 static void netlink_skb_destructor(struct sk_buff *skb);
 
 /* nl_table locking explained:
- * Lookup and traversal are protected with nl_sk_hash_lock or nl_table_lock
- * combined with an RCU read-side lock. Insertion and removal are protected
- * with nl_sk_hash_lock while using RCU list modification primitives and may
- * run in parallel to nl_table_lock protected lookups. Destruction of the
- * Netlink socket may only occur *after* nl_table_lock has been acquired
- * either during or after the socket has been removed from the list.
+ * Lookup and traversal are protected with an RCU read-side lock. Insertion
+ * and removal are protected with per bucket lock while using RCU list
+ * modification primitives and may run in parallel to RCU protected lookups.
+ * Destruction of the Netlink socket may only occur *after* nl_table_lock has
+ * been acquired * either during or after the socket has been removed from
+ * the list and after an RCU grace period.
  */
 DEFINE_RWLOCK(nl_table_lock);
 EXPORT_SYMBOL_GPL(nl_table_lock);
@@ -111,19 +111,6 @@ static atomic_t nl_table_users = ATOMIC_INIT(0);
 
 #define nl_deref_protected(X) rcu_dereference_protected(X, lockdep_is_held(&nl_table_lock));
 
-/* Protects netlink socket hash table mutations */
-DEFINE_MUTEX(nl_sk_hash_lock);
-EXPORT_SYMBOL_GPL(nl_sk_hash_lock);
-
-#ifdef CONFIG_PROVE_LOCKING
-static int lockdep_nl_sk_hash_is_held(void *parent)
-{
-       if (debug_locks)
-               return lockdep_is_held(&nl_sk_hash_lock) || lockdep_is_held(&nl_table_lock);
-       return 1;
-}
-#endif
-
 static ATOMIC_NOTIFIER_HEAD(netlink_chain);
 
 static DEFINE_SPINLOCK(netlink_tap_lock);
@@ -708,7 +695,7 @@ static void netlink_ring_setup_skb(struct sk_buff *skb, struct sock *sk,
 
 static int netlink_mmap_sendmsg(struct sock *sk, struct msghdr *msg,
                                u32 dst_portid, u32 dst_group,
-                               struct sock_iocb *siocb)
+                               struct scm_cookie *scm)
 {
        struct netlink_sock *nlk = nlk_sk(sk);
        struct netlink_ring *ring;
@@ -754,7 +741,7 @@ static int netlink_mmap_sendmsg(struct sock *sk, struct msghdr *msg,
 
                NETLINK_CB(skb).portid    = nlk->portid;
                NETLINK_CB(skb).dst_group = dst_group;
-               NETLINK_CB(skb).creds     = siocb->scm->creds;
+               NETLINK_CB(skb).creds     = scm->creds;
 
                err = security_netlink_send(sk, skb);
                if (err) {
@@ -833,7 +820,7 @@ static void netlink_ring_set_copied(struct sock *sk, struct sk_buff *skb)
 #define netlink_tx_is_mmaped(sk)       false
 #define netlink_mmap                   sock_no_mmap
 #define netlink_poll                   datagram_poll
-#define netlink_mmap_sendmsg(sk, msg, dst_portid, dst_group, siocb)    0
+#define netlink_mmap_sendmsg(sk, msg, dst_portid, dst_group, scm)      0
 #endif /* CONFIG_NETLINK_MMAP */
 
 static void netlink_skb_destructor(struct sk_buff *skb)
@@ -1003,26 +990,33 @@ static struct sock *__netlink_lookup(struct netlink_table *table, u32 portid,
                .net = net,
                .portid = portid,
        };
-       u32 hash;
-
-       hash = rhashtable_hashfn(&table->hash, &portid, sizeof(portid));
 
-       return rhashtable_lookup_compare(&table->hash, hash,
+       return rhashtable_lookup_compare(&table->hash, &portid,
                                         &netlink_compare, &arg);
 }
 
+static bool __netlink_insert(struct netlink_table *table, struct sock *sk)
+{
+       struct netlink_compare_arg arg = {
+               .net = sock_net(sk),
+               .portid = nlk_sk(sk)->portid,
+       };
+
+       return rhashtable_lookup_compare_insert(&table->hash,
+                                               &nlk_sk(sk)->node,
+                                               &netlink_compare, &arg);
+}
+
 static struct sock *netlink_lookup(struct net *net, int protocol, u32 portid)
 {
        struct netlink_table *table = &nl_table[protocol];
        struct sock *sk;
 
-       read_lock(&nl_table_lock);
        rcu_read_lock();
        sk = __netlink_lookup(table, portid, net);
        if (sk)
                sock_hold(sk);
        rcu_read_unlock();
-       read_unlock(&nl_table_lock);
 
        return sk;
 }
@@ -1053,29 +1047,33 @@ netlink_update_listeners(struct sock *sk)
         * makes sure updates are visible before bind or setsockopt return. */
 }
 
-static int netlink_insert(struct sock *sk, struct net *net, u32 portid)
+static int netlink_insert(struct sock *sk, u32 portid)
 {
        struct netlink_table *table = &nl_table[sk->sk_protocol];
-       int err = -EADDRINUSE;
+       int err;
 
-       mutex_lock(&nl_sk_hash_lock);
-       if (__netlink_lookup(table, portid, net))
-               goto err;
+       lock_sock(sk);
 
        err = -EBUSY;
        if (nlk_sk(sk)->portid)
                goto err;
 
        err = -ENOMEM;
-       if (BITS_PER_LONG > 32 && unlikely(table->hash.nelems >= UINT_MAX))
+       if (BITS_PER_LONG > 32 &&
+           unlikely(atomic_read(&table->hash.nelems) >= UINT_MAX))
                goto err;
 
        nlk_sk(sk)->portid = portid;
        sock_hold(sk);
-       rhashtable_insert(&table->hash, &nlk_sk(sk)->node);
+
        err = 0;
+       if (!__netlink_insert(table, sk)) {
+               err = -EADDRINUSE;
+               sock_put(sk);
+       }
+
 err:
-       mutex_unlock(&nl_sk_hash_lock);
+       release_sock(sk);
        return err;
 }
 
@@ -1083,13 +1081,11 @@ static void netlink_remove(struct sock *sk)
 {
        struct netlink_table *table;
 
-       mutex_lock(&nl_sk_hash_lock);
        table = &nl_table[sk->sk_protocol];
        if (rhashtable_remove(&table->hash, &nlk_sk(sk)->node)) {
                WARN_ON(atomic_read(&sk->sk_refcnt) == 1);
                __sock_put(sk);
        }
-       mutex_unlock(&nl_sk_hash_lock);
 
        netlink_table_grab();
        if (nlk_sk(sk)->subscriptions) {
@@ -1197,6 +1193,13 @@ out_module:
        goto out;
 }
 
+static void deferred_put_nlk_sk(struct rcu_head *head)
+{
+       struct netlink_sock *nlk = container_of(head, struct netlink_sock, rcu);
+
+       sock_put(&nlk->sk);
+}
+
 static int netlink_release(struct socket *sock)
 {
        struct sock *sk = sock->sk;
@@ -1269,7 +1272,7 @@ static int netlink_release(struct socket *sock)
        local_bh_disable();
        sock_prot_inuse_add(sock_net(sk), &netlink_proto, -1);
        local_bh_enable();
-       sock_put(sk);
+       call_rcu(&nlk->rcu, deferred_put_nlk_sk);
        return 0;
 }
 
@@ -1284,7 +1287,6 @@ static int netlink_autobind(struct socket *sock)
 
 retry:
        cond_resched();
-       netlink_table_grab();
        rcu_read_lock();
        if (__netlink_lookup(table, portid, net)) {
                /* Bind collision, search negative portid values. */
@@ -1292,13 +1294,11 @@ retry:
                if (rover > -4097)
                        rover = -4097;
                rcu_read_unlock();
-               netlink_table_ungrab();
                goto retry;
        }
        rcu_read_unlock();
-       netlink_table_ungrab();
 
-       err = netlink_insert(sk, net, portid);
+       err = netlink_insert(sk, portid);
        if (err == -EADDRINUSE)
                goto retry;
 
@@ -1486,7 +1486,7 @@ static int netlink_bind(struct socket *sock, struct sockaddr *addr,
 
        if (!nlk->portid) {
                err = nladdr->nl_pid ?
-                       netlink_insert(sk, net, nladdr->nl_pid) :
+                       netlink_insert(sk, nladdr->nl_pid) :
                        netlink_autobind(sock);
                if (err) {
                        netlink_undo_bind(nlk->ngroups, groups, sk);
@@ -2259,7 +2259,6 @@ static void netlink_cmsg_recv_pktinfo(struct msghdr *msg, struct sk_buff *skb)
 static int netlink_sendmsg(struct kiocb *kiocb, struct socket *sock,
                           struct msghdr *msg, size_t len)
 {
-       struct sock_iocb *siocb = kiocb_to_siocb(kiocb);
        struct sock *sk = sock->sk;
        struct netlink_sock *nlk = nlk_sk(sk);
        DECLARE_SOCKADDR(struct sockaddr_nl *, addr, msg->msg_name);
@@ -2273,10 +2272,7 @@ static int netlink_sendmsg(struct kiocb *kiocb, struct socket *sock,
        if (msg->msg_flags&MSG_OOB)
                return -EOPNOTSUPP;
 
-       if (NULL == siocb->scm)
-               siocb->scm = &scm;
-
-       err = scm_send(sock, msg, siocb->scm, true);
+       err = scm_send(sock, msg, &scm, true);
        if (err < 0)
                return err;
 
@@ -2302,10 +2298,15 @@ static int netlink_sendmsg(struct kiocb *kiocb, struct socket *sock,
                        goto out;
        }
 
+       /* It's a really convoluted way for userland to ask for mmaped
+        * sendmsg(), but that's what we've got...
+        */
        if (netlink_tx_is_mmaped(sk) &&
+           msg->msg_iter.type == ITER_IOVEC &&
+           msg->msg_iter.nr_segs == 1 &&
            msg->msg_iter.iov->iov_base == NULL) {
                err = netlink_mmap_sendmsg(sk, msg, dst_portid, dst_group,
-                                          siocb);
+                                          &scm);
                goto out;
        }
 
@@ -2319,7 +2320,7 @@ static int netlink_sendmsg(struct kiocb *kiocb, struct socket *sock,
 
        NETLINK_CB(skb).portid  = nlk->portid;
        NETLINK_CB(skb).dst_group = dst_group;
-       NETLINK_CB(skb).creds   = siocb->scm->creds;
+       NETLINK_CB(skb).creds   = scm.creds;
        NETLINK_CB(skb).flags   = netlink_skb_flags;
 
        err = -EFAULT;
@@ -2341,7 +2342,7 @@ static int netlink_sendmsg(struct kiocb *kiocb, struct socket *sock,
        err = netlink_unicast(sk, skb, dst_portid, msg->msg_flags&MSG_DONTWAIT);
 
 out:
-       scm_destroy(siocb->scm);
+       scm_destroy(&scm);
        return err;
 }
 
@@ -2349,7 +2350,6 @@ static int netlink_recvmsg(struct kiocb *kiocb, struct socket *sock,
                           struct msghdr *msg, size_t len,
                           int flags)
 {
-       struct sock_iocb *siocb = kiocb_to_siocb(kiocb);
        struct scm_cookie scm;
        struct sock *sk = sock->sk;
        struct netlink_sock *nlk = nlk_sk(sk);
@@ -2412,11 +2412,8 @@ static int netlink_recvmsg(struct kiocb *kiocb, struct socket *sock,
        if (nlk->flags & NETLINK_RECV_PKTINFO)
                netlink_cmsg_recv_pktinfo(msg, skb);
 
-       if (NULL == siocb->scm) {
-               memset(&scm, 0, sizeof(scm));
-               siocb->scm = &scm;
-       }
-       siocb->scm->creds = *NETLINK_CREDS(skb);
+       memset(&scm, 0, sizeof(scm));
+       scm.creds = *NETLINK_CREDS(skb);
        if (flags & MSG_TRUNC)
                copied = data_skb->len;
 
@@ -2431,7 +2428,7 @@ static int netlink_recvmsg(struct kiocb *kiocb, struct socket *sock,
                }
        }
 
-       scm_recv(sock, msg, siocb->scm, flags);
+       scm_recv(sock, msg, &scm, flags);
 out:
        netlink_rcv_wake(sk);
        return err ? : copied;
@@ -2492,7 +2489,7 @@ __netlink_kernel_create(struct net *net, int unit, struct module *module,
        if (cfg && cfg->input)
                nlk_sk(sk)->netlink_rcv = cfg->input;
 
-       if (netlink_insert(sk, net, 0))
+       if (netlink_insert(sk, 0))
                goto out_sock_release;
 
        nlk = nlk_sk(sk);
@@ -2894,97 +2891,97 @@ EXPORT_SYMBOL(nlmsg_notify);
 #ifdef CONFIG_PROC_FS
 struct nl_seq_iter {
        struct seq_net_private p;
+       struct rhashtable_iter hti;
        int link;
-       int hash_idx;
 };
 
-static struct sock *netlink_seq_socket_idx(struct seq_file *seq, loff_t pos)
+static int netlink_walk_start(struct nl_seq_iter *iter)
 {
-       struct nl_seq_iter *iter = seq->private;
-       int i, j;
-       struct netlink_sock *nlk;
-       struct sock *s;
-       loff_t off = 0;
-
-       for (i = 0; i < MAX_LINKS; i++) {
-               struct rhashtable *ht = &nl_table[i].hash;
-               const struct bucket_table *tbl = rht_dereference_rcu(ht->tbl, ht);
-
-               for (j = 0; j < tbl->size; j++) {
-                       rht_for_each_entry_rcu(nlk, tbl->buckets[j], node) {
-                               s = (struct sock *)nlk;
+       int err;
 
-                               if (sock_net(s) != seq_file_net(seq))
-                                       continue;
-                               if (off == pos) {
-                                       iter->link = i;
-                                       iter->hash_idx = j;
-                                       return s;
-                               }
-                               ++off;
-                       }
-               }
+       err = rhashtable_walk_init(&nl_table[iter->link].hash, &iter->hti);
+       if (err) {
+               iter->link = MAX_LINKS;
+               return err;
        }
-       return NULL;
+
+       err = rhashtable_walk_start(&iter->hti);
+       return err == -EAGAIN ? 0 : err;
 }
 
-static void *netlink_seq_start(struct seq_file *seq, loff_t *pos)
-       __acquires(nl_table_lock) __acquires(RCU)
+static void netlink_walk_stop(struct nl_seq_iter *iter)
 {
-       read_lock(&nl_table_lock);
-       rcu_read_lock();
-       return *pos ? netlink_seq_socket_idx(seq, *pos - 1) : SEQ_START_TOKEN;
+       rhashtable_walk_stop(&iter->hti);
+       rhashtable_walk_exit(&iter->hti);
 }
 
-static void *netlink_seq_next(struct seq_file *seq, void *v, loff_t *pos)
+static void *__netlink_seq_next(struct seq_file *seq)
 {
-       struct rhashtable *ht;
+       struct nl_seq_iter *iter = seq->private;
        struct netlink_sock *nlk;
-       struct nl_seq_iter *iter;
-       struct net *net;
-       int i, j;
 
-       ++*pos;
+       do {
+               for (;;) {
+                       int err;
 
-       if (v == SEQ_START_TOKEN)
-               return netlink_seq_socket_idx(seq, 0);
+                       nlk = rhashtable_walk_next(&iter->hti);
 
-       net = seq_file_net(seq);
-       iter = seq->private;
-       nlk = v;
+                       if (IS_ERR(nlk)) {
+                               if (PTR_ERR(nlk) == -EAGAIN)
+                                       continue;
 
-       i = iter->link;
-       ht = &nl_table[i].hash;
-       rht_for_each_entry(nlk, nlk->node.next, ht, node)
-               if (net_eq(sock_net((struct sock *)nlk), net))
-                       return nlk;
+                               return nlk;
+                       }
 
-       j = iter->hash_idx + 1;
+                       if (nlk)
+                               break;
 
-       do {
-               const struct bucket_table *tbl = rht_dereference_rcu(ht->tbl, ht);
-
-               for (; j < tbl->size; j++) {
-                       rht_for_each_entry(nlk, tbl->buckets[j], ht, node) {
-                               if (net_eq(sock_net((struct sock *)nlk), net)) {
-                                       iter->link = i;
-                                       iter->hash_idx = j;
-                                       return nlk;
-                               }
-                       }
+                       netlink_walk_stop(iter);
+                       if (++iter->link >= MAX_LINKS)
+                               return NULL;
+
+                       err = netlink_walk_start(iter);
+                       if (err)
+                               return ERR_PTR(err);
                }
+       } while (sock_net(&nlk->sk) != seq_file_net(seq));
 
-               j = 0;
-       } while (++i < MAX_LINKS);
+       return nlk;
+}
 
-       return NULL;
+static void *netlink_seq_start(struct seq_file *seq, loff_t *posp)
+{
+       struct nl_seq_iter *iter = seq->private;
+       void *obj = SEQ_START_TOKEN;
+       loff_t pos;
+       int err;
+
+       iter->link = 0;
+
+       err = netlink_walk_start(iter);
+       if (err)
+               return ERR_PTR(err);
+
+       for (pos = *posp; pos && obj && !IS_ERR(obj); pos--)
+               obj = __netlink_seq_next(seq);
+
+       return obj;
+}
+
+static void *netlink_seq_next(struct seq_file *seq, void *v, loff_t *pos)
+{
+       ++*pos;
+       return __netlink_seq_next(seq);
 }
 
 static void netlink_seq_stop(struct seq_file *seq, void *v)
-       __releases(RCU) __releases(nl_table_lock)
 {
-       rcu_read_unlock();
-       read_unlock(&nl_table_lock);
+       struct nl_seq_iter *iter = seq->private;
+
+       if (iter->link >= MAX_LINKS)
+               return;
+
+       netlink_walk_stop(iter);
 }
 
 
@@ -3131,9 +3128,6 @@ static int __init netlink_proto_init(void)
                .max_shift = 16, /* 64K */
                .grow_decision = rht_grow_above_75,
                .shrink_decision = rht_shrink_below_30,
-#ifdef CONFIG_PROVE_LOCKING
-               .mutex_is_held = lockdep_nl_sk_hash_is_held,
-#endif
        };
 
        if (err != 0)