wl1251: fix a memory leak in probe
[safe/jmp/linux-2.6] / net / sunrpc / rpcb_clnt.c
index 28f50da..1211053 100644 (file)
@@ -20,6 +20,8 @@
 #include <linux/in6.h>
 #include <linux/kernel.h>
 #include <linux/errno.h>
+#include <linux/mutex.h>
+#include <linux/slab.h>
 #include <net/ipv6.h>
 
 #include <linux/sunrpc/clnt.h>
@@ -110,6 +112,9 @@ static void                 rpcb_getport_done(struct rpc_task *, void *);
 static void                    rpcb_map_release(void *data);
 static struct rpc_program      rpcb_program;
 
+static struct rpc_clnt *       rpcb_local_clnt;
+static struct rpc_clnt *       rpcb_local_clnt4;
+
 struct rpcbind_args {
        struct rpc_xprt *       r_xprt;
 
@@ -163,20 +168,60 @@ static const struct sockaddr_in rpcb_inaddr_loopback = {
        .sin_port               = htons(RPCBIND_PORT),
 };
 
-static struct rpc_clnt *rpcb_create_local(u32 version)
+static DEFINE_MUTEX(rpcb_create_local_mutex);
+
+/*
+ * Returns zero on success, otherwise a negative errno value
+ * is returned.
+ */
+static int rpcb_create_local(void)
 {
        struct rpc_create_args args = {
-               .protocol       = XPRT_TRANSPORT_UDP,
+               .protocol       = XPRT_TRANSPORT_TCP,
                .address        = (struct sockaddr *)&rpcb_inaddr_loopback,
                .addrsize       = sizeof(rpcb_inaddr_loopback),
                .servername     = "localhost",
                .program        = &rpcb_program,
-               .version        = version,
+               .version        = RPCBVERS_2,
                .authflavor     = RPC_AUTH_UNIX,
                .flags          = RPC_CLNT_CREATE_NOPING,
        };
+       struct rpc_clnt *clnt, *clnt4;
+       int result = 0;
+
+       if (rpcb_local_clnt)
+               return result;
+
+       mutex_lock(&rpcb_create_local_mutex);
+       if (rpcb_local_clnt)
+               goto out;
+
+       clnt = rpc_create(&args);
+       if (IS_ERR(clnt)) {
+               dprintk("RPC:       failed to create local rpcbind "
+                               "client (errno %ld).\n", PTR_ERR(clnt));
+               result = -PTR_ERR(clnt);
+               goto out;
+       }
 
-       return rpc_create(&args);
+       /*
+        * This results in an RPC ping.  On systems running portmapper,
+        * the v4 ping will fail.  Proceed anyway, but disallow rpcb
+        * v4 upcalls.
+        */
+       clnt4 = rpc_bind_new_program(clnt, &rpcb_program, RPCBVERS_4);
+       if (IS_ERR(clnt4)) {
+               dprintk("RPC:       failed to create local rpcbind v4 "
+                               "cleint (errno %ld).\n", PTR_ERR(clnt4));
+               clnt4 = NULL;
+       }
+
+       rpcb_local_clnt = clnt;
+       rpcb_local_clnt4 = clnt4;
+
+out:
+       mutex_unlock(&rpcb_create_local_mutex);
+       return result;
 }
 
 static struct rpc_clnt *rpcb_create(char *hostname, struct sockaddr *srvaddr,
@@ -208,20 +253,13 @@ static struct rpc_clnt *rpcb_create(char *hostname, struct sockaddr *srvaddr,
        return rpc_create(&args);
 }
 
-static int rpcb_register_call(const u32 version, struct rpc_message *msg)
+static int rpcb_register_call(struct rpc_clnt *clnt, struct rpc_message *msg)
 {
-       struct rpc_clnt *rpcb_clnt;
        int result, error = 0;
 
        msg->rpc_resp = &result;
 
-       rpcb_clnt = rpcb_create_local(version);
-       if (!IS_ERR(rpcb_clnt)) {
-               error = rpc_call_sync(rpcb_clnt, msg, 0);
-               rpc_shutdown_client(rpcb_clnt);
-       } else
-               error = PTR_ERR(rpcb_clnt);
-
+       error = rpc_call_sync(clnt, msg, RPC_TASK_SOFTCONN);
        if (error < 0) {
                dprintk("RPC:       failed to contact local rpcbind "
                                "server (errno %d).\n", -error);
@@ -276,6 +314,11 @@ int rpcb_register(u32 prog, u32 vers, int prot, unsigned short port)
        struct rpc_message msg = {
                .rpc_argp       = &map,
        };
+       int error;
+
+       error = rpcb_create_local();
+       if (error)
+               return error;
 
        dprintk("RPC:       %sregistering (%u, %u, %d, %u) with local "
                        "rpcbind\n", (port ? "" : "un"),
@@ -285,7 +328,7 @@ int rpcb_register(u32 prog, u32 vers, int prot, unsigned short port)
        if (port)
                msg.rpc_proc = &rpcb_procedures2[RPCBPROC_SET];
 
-       return rpcb_register_call(RPCBVERS_2, &msg);
+       return rpcb_register_call(rpcb_local_clnt, &msg);
 }
 
 /*
@@ -310,7 +353,7 @@ static int rpcb_register_inet4(const struct sockaddr *sap,
        if (port)
                msg->rpc_proc = &rpcb_procedures4[RPCBPROC_SET];
 
-       result = rpcb_register_call(RPCBVERS_4, msg);
+       result = rpcb_register_call(rpcb_local_clnt4, msg);
        kfree(map->r_addr);
        return result;
 }
@@ -337,7 +380,7 @@ static int rpcb_register_inet6(const struct sockaddr *sap,
        if (port)
                msg->rpc_proc = &rpcb_procedures4[RPCBPROC_SET];
 
-       result = rpcb_register_call(RPCBVERS_4, msg);
+       result = rpcb_register_call(rpcb_local_clnt4, msg);
        kfree(map->r_addr);
        return result;
 }
@@ -353,7 +396,7 @@ static int rpcb_unregister_all_protofamilies(struct rpc_message *msg)
        map->r_addr = "";
        msg->rpc_proc = &rpcb_procedures4[RPCBPROC_UNSET];
 
-       return rpcb_register_call(RPCBVERS_4, msg);
+       return rpcb_register_call(rpcb_local_clnt4, msg);
 }
 
 /**
@@ -411,6 +454,13 @@ int rpcb_v4_register(const u32 program, const u32 version,
        struct rpc_message msg = {
                .rpc_argp       = &map,
        };
+       int error;
+
+       error = rpcb_create_local();
+       if (error)
+               return error;
+       if (rpcb_local_clnt4 == NULL)
+               return -EPROTONOSUPPORT;
 
        if (address == NULL)
                return rpcb_unregister_all_protofamilies(&msg);
@@ -488,7 +538,7 @@ static struct rpc_task *rpcb_call_async(struct rpc_clnt *rpcb_clnt, struct rpcbi
                .rpc_message = &msg,
                .callback_ops = &rpcb_getport_ops,
                .callback_data = map,
-               .flags = RPC_TASK_ASYNC,
+               .flags = RPC_TASK_ASYNC | RPC_TASK_SOFTCONN,
        };
 
        return rpc_run_task(&task_setup_data);
@@ -1024,3 +1074,15 @@ static struct rpc_program rpcb_program = {
        .version        = rpcb_version,
        .stats          = &rpcb_stats,
 };
+
+/**
+ * cleanup_rpcb_clnt - remove xprtsock's sysctls, unregister
+ *
+ */
+void cleanup_rpcb_clnt(void)
+{
+       if (rpcb_local_clnt4)
+               rpc_shutdown_client(rpcb_local_clnt4);
+       if (rpcb_local_clnt)
+               rpc_shutdown_client(rpcb_local_clnt);
+}