]> git.karo-electronics.de Git - karo-tx-linux.git/blobdiff - net/sctp/input.c
Merge tag 'arc-4.10-rc1-part1' of git://git.kernel.org/pub/scm/linux/kernel/git/vgupt...
[karo-tx-linux.git] / net / sctp / input.c
index a01a56ec8b8cd3709aaf0f2af81ead42b4faa05f..458e506ef84bae3c53c239d6cf89a9349faafb11 100644 (file)
@@ -790,10 +790,9 @@ hit:
 
 /* rhashtable for transport */
 struct sctp_hash_cmp_arg {
-       const struct sctp_endpoint      *ep;
-       const union sctp_addr           *laddr;
-       const union sctp_addr           *paddr;
-       const struct net                *net;
+       const union sctp_addr   *paddr;
+       const struct net        *net;
+       u16                     lport;
 };
 
 static inline int sctp_hash_cmp(struct rhashtable_compare_arg *arg,
@@ -801,7 +800,6 @@ static inline int sctp_hash_cmp(struct rhashtable_compare_arg *arg,
 {
        struct sctp_transport *t = (struct sctp_transport *)ptr;
        const struct sctp_hash_cmp_arg *x = arg->key;
-       struct sctp_association *asoc;
        int err = 1;
 
        if (!sctp_cmp_addr_exact(&t->ipaddr, x->paddr))
@@ -809,19 +807,10 @@ static inline int sctp_hash_cmp(struct rhashtable_compare_arg *arg,
        if (!sctp_transport_hold(t))
                return err;
 
-       asoc = t->asoc;
-       if (!net_eq(sock_net(asoc->base.sk), x->net))
+       if (!net_eq(sock_net(t->asoc->base.sk), x->net))
+               goto out;
+       if (x->lport != htons(t->asoc->base.bind_addr.port))
                goto out;
-       if (x->ep) {
-               if (x->ep != asoc->ep)
-                       goto out;
-       } else {
-               if (x->laddr->v4.sin_port != htons(asoc->base.bind_addr.port))
-                       goto out;
-               if (!sctp_bind_addr_match(&asoc->base.bind_addr,
-                                         x->laddr, sctp_sk(asoc->base.sk)))
-                       goto out;
-       }
 
        err = 0;
 out:
@@ -851,11 +840,9 @@ static inline u32 sctp_hash_key(const void *data, u32 len, u32 seed)
        const struct sctp_hash_cmp_arg *x = data;
        const union sctp_addr *paddr = x->paddr;
        const struct net *net = x->net;
-       u16 lport;
+       u16 lport = x->lport;
        u32 addr;
 
-       lport = x->ep ? htons(x->ep->base.bind_addr.port) :
-                       x->laddr->v4.sin_port;
        if (paddr->sa.sa_family == AF_INET6)
                addr = jhash(&paddr->v6.sin6_addr, 16, seed);
        else
@@ -875,29 +862,32 @@ static const struct rhashtable_params sctp_hash_params = {
 
 int sctp_transport_hashtable_init(void)
 {
-       return rhashtable_init(&sctp_transport_hashtable, &sctp_hash_params);
+       return rhltable_init(&sctp_transport_hashtable, &sctp_hash_params);
 }
 
 void sctp_transport_hashtable_destroy(void)
 {
-       rhashtable_destroy(&sctp_transport_hashtable);
+       rhltable_destroy(&sctp_transport_hashtable);
 }
 
-void sctp_hash_transport(struct sctp_transport *t)
+int sctp_hash_transport(struct sctp_transport *t)
 {
        struct sctp_hash_cmp_arg arg;
+       int err;
 
        if (t->asoc->temp)
-               return;
+               return 0;
 
-       arg.ep = t->asoc->ep;
-       arg.paddr = &t->ipaddr;
        arg.net   = sock_net(t->asoc->base.sk);
+       arg.paddr = &t->ipaddr;
+       arg.lport = htons(t->asoc->base.bind_addr.port);
 
-reinsert:
-       if (rhashtable_lookup_insert_key(&sctp_transport_hashtable, &arg,
-                                        &t->node, sctp_hash_params) == -EBUSY)
-               goto reinsert;
+       err = rhltable_insert_key(&sctp_transport_hashtable, &arg,
+                                 &t->node, sctp_hash_params);
+       if (err)
+               pr_err_once("insert transport fail, errno %d\n", err);
+
+       return err;
 }
 
 void sctp_unhash_transport(struct sctp_transport *t)
@@ -905,39 +895,62 @@ void sctp_unhash_transport(struct sctp_transport *t)
        if (t->asoc->temp)
                return;
 
-       rhashtable_remove_fast(&sctp_transport_hashtable, &t->node,
-                              sctp_hash_params);
+       rhltable_remove(&sctp_transport_hashtable, &t->node,
+                       sctp_hash_params);
 }
 
+/* return a transport with holding it */
 struct sctp_transport *sctp_addrs_lookup_transport(
                                struct net *net,
                                const union sctp_addr *laddr,
                                const union sctp_addr *paddr)
 {
+       struct rhlist_head *tmp, *list;
+       struct sctp_transport *t;
        struct sctp_hash_cmp_arg arg = {
-               .ep    = NULL,
-               .laddr = laddr,
                .paddr = paddr,
                .net   = net,
+               .lport = laddr->v4.sin_port,
        };
 
-       return rhashtable_lookup_fast(&sctp_transport_hashtable, &arg,
-                                     sctp_hash_params);
+       list = rhltable_lookup(&sctp_transport_hashtable, &arg,
+                              sctp_hash_params);
+
+       rhl_for_each_entry_rcu(t, tmp, list, node) {
+               if (!sctp_transport_hold(t))
+                       continue;
+
+               if (sctp_bind_addr_match(&t->asoc->base.bind_addr,
+                                        laddr, sctp_sk(t->asoc->base.sk)))
+                       return t;
+               sctp_transport_put(t);
+       }
+
+       return NULL;
 }
 
+/* return a transport without holding it, as it's only used under sock lock */
 struct sctp_transport *sctp_epaddr_lookup_transport(
                                const struct sctp_endpoint *ep,
                                const union sctp_addr *paddr)
 {
        struct net *net = sock_net(ep->base.sk);
+       struct rhlist_head *tmp, *list;
+       struct sctp_transport *t;
        struct sctp_hash_cmp_arg arg = {
-               .ep    = ep,
                .paddr = paddr,
                .net   = net,
+               .lport = htons(ep->base.bind_addr.port),
        };
 
-       return rhashtable_lookup_fast(&sctp_transport_hashtable, &arg,
-                                     sctp_hash_params);
+       list = rhltable_lookup(&sctp_transport_hashtable, &arg,
+                              sctp_hash_params);
+
+       rhl_for_each_entry_rcu(t, tmp, list, node)
+               if (ep == t->asoc->ep)
+                       return t;
+
+       return NULL;
 }
 
 /* Look up an association. */
@@ -951,7 +964,7 @@ static struct sctp_association *__sctp_lookup_association(
        struct sctp_association *asoc = NULL;
 
        t = sctp_addrs_lookup_transport(net, local, peer);
-       if (!t || !sctp_transport_hold(t))
+       if (!t)
                goto out;
 
        asoc = t->asoc;