]> git.karo-electronics.de Git - mv-sheeva.git/blobdiff - net/netlink/af_netlink.c
[GENETLINK]: Relax dances with genl_lock.
[mv-sheeva.git] / net / netlink / af_netlink.c
index 2e02b19e4552d8b6ae068f78c5a84206210c9271..1ab0da2632e19c5a0a74546a67a3e9c4d11bbc50 100644 (file)
@@ -171,6 +171,7 @@ static void netlink_sock_destruct(struct sock *sk)
  */
 
 static void netlink_table_grab(void)
+       __acquires(nl_table_lock)
 {
        write_lock_irq(&nl_table_lock);
 
@@ -192,7 +193,8 @@ static void netlink_table_grab(void)
        }
 }
 
-static inline void netlink_table_ungrab(void)
+static void netlink_table_ungrab(void)
+       __releases(nl_table_lock)
 {
        write_unlock_irq(&nl_table_lock);
        wake_up(&nl_table_wait);
@@ -237,13 +239,14 @@ found:
        return sk;
 }
 
-static inline struct hlist_head *nl_pid_hash_alloc(size_t size)
+static inline struct hlist_head *nl_pid_hash_zalloc(size_t size)
 {
        if (size <= PAGE_SIZE)
-               return kmalloc(size, GFP_ATOMIC);
+               return kzalloc(size, GFP_ATOMIC);
        else
                return (struct hlist_head *)
-                       __get_free_pages(GFP_ATOMIC, get_order(size));
+                       __get_free_pages(GFP_ATOMIC | __GFP_ZERO,
+                                        get_order(size));
 }
 
 static inline void nl_pid_hash_free(struct hlist_head *table, size_t size)
@@ -272,11 +275,10 @@ static int nl_pid_hash_rehash(struct nl_pid_hash *hash, int grow)
                size *= 2;
        }
 
-       table = nl_pid_hash_alloc(size);
+       table = nl_pid_hash_zalloc(size);
        if (!table)
                return 0;
 
-       memset(table, 0, size);
        otable = hash->table;
        hash->table = table;
        hash->mask = mask;
@@ -496,9 +498,12 @@ static int netlink_release(struct socket *sock)
 
        netlink_table_grab();
        if (netlink_is_kernel(sk)) {
-               kfree(nl_table[sk->sk_protocol].listeners);
-               nl_table[sk->sk_protocol].module = NULL;
-               nl_table[sk->sk_protocol].registered = 0;
+               BUG_ON(nl_table[sk->sk_protocol].registered == 0);
+               if (--nl_table[sk->sk_protocol].registered == 0) {
+                       kfree(nl_table[sk->sk_protocol].listeners);
+                       nl_table[sk->sk_protocol].module = NULL;
+                       nl_table[sk->sk_protocol].registered = 0;
+               }
        } else if (nlk->subscriptions)
                netlink_update_listeners(sk);
        netlink_table_ungrab();
@@ -1233,7 +1238,7 @@ static int netlink_sendmsg(struct kiocb *kiocb, struct socket *sock,
 
        NETLINK_CB(skb).pid     = nlk->pid;
        NETLINK_CB(skb).dst_group = dst_group;
-       NETLINK_CB(skb).loginuid = audit_get_loginuid(current->audit_context);
+       NETLINK_CB(skb).loginuid = audit_get_loginuid(current);
        selinux_get_task_sid(current, &(NETLINK_CB(skb).sid));
        memcpy(NETLINK_CREDS(skb), &siocb->scm->creds, sizeof(struct ucred));
 
@@ -1339,6 +1344,22 @@ static void netlink_data_ready(struct sock *sk, int len)
  *     queueing.
  */
 
+static void __netlink_release(struct sock *sk)
+{
+       /*
+        * Last sock_put should drop referrence to sk->sk_net. It has already
+        * been dropped in netlink_kernel_create. Taking referrence to stopping
+        * namespace is not an option.
+        * Take referrence to a socket to remove it from netlink lookup table
+        * _alive_ and after that destroy it in the context of init_net.
+        */
+
+       sock_hold(sk);
+       sock_release(sk->sk_socket);
+       sk->sk_net = get_net(&init_net);
+       sock_put(sk);
+}
+
 struct sock *
 netlink_kernel_create(struct net *net, int unit, unsigned int groups,
                      void (*input)(struct sk_buff *skb),
@@ -1357,8 +1378,18 @@ netlink_kernel_create(struct net *net, int unit, unsigned int groups,
        if (sock_create_lite(PF_NETLINK, SOCK_DGRAM, unit, &sock))
                return NULL;
 
-       if (__netlink_create(net, sock, cb_mutex, unit) < 0)
-               goto out_sock_release;
+       /*
+        * We have to just have a reference on the net from sk, but don't
+        * get_net it. Besides, we cannot get and then put the net here.
+        * So we create one inside init_net and the move it to net.
+        */
+
+       if (__netlink_create(&init_net, sock, cb_mutex, unit) < 0)
+               goto out_sock_release_nosk;
+
+       sk = sock->sk;
+       put_net(sk->sk_net);
+       sk->sk_net = net;
 
        if (groups < 32)
                groups = 32;
@@ -1367,7 +1398,6 @@ netlink_kernel_create(struct net *net, int unit, unsigned int groups,
        if (!listeners)
                goto out_sock_release;
 
-       sk = sock->sk;
        sk->sk_data_ready = netlink_data_ready;
        if (input)
                nlk_sk(sk)->netlink_rcv = input;
@@ -1387,18 +1417,34 @@ netlink_kernel_create(struct net *net, int unit, unsigned int groups,
                nl_table[unit].registered = 1;
        } else {
                kfree(listeners);
+               nl_table[unit].registered++;
        }
        netlink_table_ungrab();
-
        return sk;
 
 out_sock_release:
        kfree(listeners);
+       __netlink_release(sk);
+       return NULL;
+
+out_sock_release_nosk:
        sock_release(sock);
        return NULL;
 }
 EXPORT_SYMBOL(netlink_kernel_create);
 
+
+void
+netlink_kernel_release(struct sock *sk)
+{
+       if (sk == NULL || sk->sk_socket == NULL)
+               return;
+
+       __netlink_release(sk);
+}
+EXPORT_SYMBOL(netlink_kernel_release);
+
+
 /**
  * netlink_change_ngroups - change number of multicast groups
  *
@@ -1728,6 +1774,7 @@ static struct sock *netlink_seq_socket_idx(struct seq_file *seq, loff_t pos)
 }
 
 static void *netlink_seq_start(struct seq_file *seq, loff_t *pos)
+       __acquires(nl_table_lock)
 {
        read_lock(&nl_table_lock);
        return *pos ? netlink_seq_socket_idx(seq, *pos - 1) : SEQ_START_TOKEN;
@@ -1776,6 +1823,7 @@ static void *netlink_seq_next(struct seq_file *seq, void *v, loff_t *pos)
 }
 
 static void netlink_seq_stop(struct seq_file *seq, void *v)
+       __releases(nl_table_lock)
 {
        read_unlock(&nl_table_lock);
 }
@@ -1919,7 +1967,7 @@ static int __init netlink_proto_init(void)
        for (i = 0; i < MAX_LINKS; i++) {
                struct nl_pid_hash *hash = &nl_table[i].hash;
 
-               hash->table = nl_pid_hash_alloc(1 * sizeof(*hash->table));
+               hash->table = nl_pid_hash_zalloc(1 * sizeof(*hash->table));
                if (!hash->table) {
                        while (i-- > 0)
                                nl_pid_hash_free(nl_table[i].hash.table,
@@ -1927,7 +1975,6 @@ static int __init netlink_proto_init(void)
                        kfree(nl_table);
                        goto panic;
                }
-               memset(hash->table, 0, 1 * sizeof(*hash->table));
                hash->max_shift = order;
                hash->shift = 0;
                hash->mask = 0;