[PATCH] RPC: transport switch function naming
/*
 * linux/net/sunrpc/xprtsock.c
 *
 * Client-side transport implementation for sockets.
 *
 * TCP callback races fixes (C) 1998 Red Hat Software <alan@redhat.com>
 * TCP send fixes (C) 1998 Red Hat Software <alan@redhat.com>
 * TCP NFS related read + write fixes
 *  (C) 1999 Dave Airlie, University of Limerick, Ireland <airlied@linux.ie>
 *
 * Rewrite of large parts of the code in order to stabilize TCP stuff.
 * Fix behaviour when socket buffer is full.
 *  (C) 1999 Trond Myklebust <trond.myklebust@fys.uio.no>
 */

#include <linux/types.h>
#include <linux/slab.h>
#include <linux/capability.h>
#include <linux/sched.h>
#include <linux/pagemap.h>
#include <linux/errno.h>
#include <linux/socket.h>
#include <linux/in.h>
#include <linux/net.h>
#include <linux/mm.h>
#include <linux/udp.h>
#include <linux/tcp.h>
#include <linux/sunrpc/clnt.h>
#include <linux/file.h>

#include <net/sock.h>
#include <net/checksum.h>
#include <net/udp.h>
#include <net/tcp.h>

/*
 * Maximum port number to use when requesting a reserved port.
 */
#define XS_MAX_RESVPORT         (800U)

#ifdef RPC_DEBUG
# undef  RPC_DEBUG_DATA
# define RPCDBG_FACILITY        RPCDBG_TRANS
#endif

#ifdef RPC_DEBUG_DATA
static void xs_pktdump(char *msg, u32 *packet, unsigned int count)
{
        u8 *buf = (u8 *) packet;
        int j;

        dprintk("RPC:      %s\n", msg);
        for (j = 0; j < count && j < 128; j += 4) {
                if (!(j & 31)) {
                        if (j)
                                dprintk("\n");
                        dprintk("0x%04x ", j);
                }
                dprintk("%02x%02x%02x%02x ",
                        buf[j], buf[j+1], buf[j+2], buf[j+3]);
        }
        dprintk("\n");
}
#else
static inline void xs_pktdump(char *msg, u32 *packet, unsigned int count)
{
        /* NOP */
}
#endif

/**
 * xs_sendpages - write pages directly to a socket
 * @sock: socket to send on
 * @addr: UDP only -- address of destination
 * @addrlen: UDP only -- length of destination address
 * @xdr: buffer containing this request
 * @base: starting position in the buffer
 * @msgflags: flags to pass through to the underlying sendmsg/sendpage calls
 *
 */
static int xs_sendpages(struct socket *sock, struct sockaddr *addr, int addrlen, struct xdr_buf *xdr, unsigned int base, int msgflags)
{
        struct page **ppage = xdr->pages;
        unsigned int len, pglen = xdr->page_len;
        int err, ret = 0;
        ssize_t (*sendpage)(struct socket *, struct page *, int, size_t, int);

        len = xdr->head[0].iov_len;
        if (base < len || (addr != NULL && base == 0)) {
                struct kvec iov = {
                        .iov_base = xdr->head[0].iov_base + base,
                        .iov_len  = len - base,
                };
                struct msghdr msg = {
                        .msg_name    = addr,
                        .msg_namelen = addrlen,
                        .msg_flags   = msgflags,
                };
                if (xdr->len > len)
                        msg.msg_flags |= MSG_MORE;

                if (iov.iov_len != 0)
                        err = kernel_sendmsg(sock, &msg, &iov, 1, iov.iov_len);
                else
                        err = kernel_sendmsg(sock, &msg, NULL, 0, 0);
                if (ret == 0)
                        ret = err;
                else if (err > 0)
                        ret += err;
                if (err != iov.iov_len)
                        goto out;
                base = 0;
        } else
                base -= len;

        if (pglen == 0)
                goto copy_tail;
        if (base >= pglen) {
                base -= pglen;
                goto copy_tail;
        }
        if (base || xdr->page_base) {
                pglen -= base;
                base += xdr->page_base;
                ppage += base >> PAGE_CACHE_SHIFT;
                base &= ~PAGE_CACHE_MASK;
        }

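        /* Note: "x ? : y" is the GCC conditional with an omitted middle
         * operand, equivalent to "x ? x : y".  If this protocol has no
         * ->sendpage method, fall back to sock_no_sendpage(), which
         * pushes the page contents through sendmsg instead. */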
        sendpage = sock->ops->sendpage ? : sock_no_sendpage;
        do {
                int flags = msgflags;

                len = PAGE_CACHE_SIZE;
                if (base)
                        len -= base;
                if (pglen < len)
                        len = pglen;

                if (pglen != len || xdr->tail[0].iov_len != 0)
                        flags |= MSG_MORE;

                /* Hmm... We might be dealing with highmem pages */
                if (PageHighMem(*ppage))
                        sendpage = sock_no_sendpage;
                err = sendpage(sock, *ppage, base, len, flags);
                if (ret == 0)
                        ret = err;
                else if (err > 0)
                        ret += err;
                if (err != len)
                        goto out;
                base = 0;
                ppage++;
        } while ((pglen -= len) != 0);
copy_tail:
        len = xdr->tail[0].iov_len;
        if (base < len) {
                struct kvec iov = {
                        .iov_base = xdr->tail[0].iov_base + base,
                        .iov_len  = len - base,
                };
                struct msghdr msg = {
                        .msg_flags   = msgflags,
                };
                err = kernel_sendmsg(sock, &msg, &iov, 1, iov.iov_len);
                if (ret == 0)
                        ret = err;
                else if (err > 0)
                        ret += err;
        }
out:
        return ret;
}

/**
 * xs_sendmsg - write an RPC request to a socket
 * @xprt: generic transport
 * @req: the RPC request to write
 *
 */
static int xs_sendmsg(struct rpc_xprt *xprt, struct rpc_rqst *req)
{
        struct socket *sock = xprt->sock;
        struct xdr_buf *xdr = &req->rq_snd_buf;
        struct sockaddr *addr = NULL;
        int addrlen = 0;
        unsigned int skip;
        int result;

        if (!sock)
                return -ENOTCONN;

        xs_pktdump("packet data:",
                                req->rq_svec->iov_base,
                                req->rq_svec->iov_len);

        /* For UDP, we need to provide an address */
        if (!xprt->stream) {
                addr = (struct sockaddr *) &xprt->addr;
                addrlen = sizeof(xprt->addr);
        }
        /* Don't repeat bytes */
        skip = req->rq_bytes_sent;

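        /* Clear the async-nospace hint before sending: if the send below
         * runs out of buffer space, the socket layer sets the bit again,
         * and xs_send_request uses it to decide whether to wait for
         * write space. */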
        clear_bit(SOCK_ASYNC_NOSPACE, &sock->flags);
        result = xs_sendpages(sock, addr, addrlen, xdr, skip, MSG_DONTWAIT);

        dprintk("RPC:      xs_sendmsg(%d) = %d\n", xdr->len - skip, result);

        if (result >= 0)
                return result;

        switch (result) {
        case -ECONNREFUSED:
                /* When the server has died, an ICMP port unreachable message
                 * prompts ECONNREFUSED. */
        case -EAGAIN:
                break;
        case -ECONNRESET:
        case -ENOTCONN:
        case -EPIPE:
                /* connection broken */
                if (xprt->stream)
                        result = -ENOTCONN;
                break;
        default:
                break;
        }
        return result;
}

/**
 * xs_send_request - write an RPC request to a socket
 * @task: address of RPC task that manages the state of an RPC request
 *
 * Return values:
 *      0:  The request has been sent
 * EAGAIN:  The socket was blocked, please call again later to
 *          complete the request
 *  other:  Some other error occurred, the request was not sent
 *
 * XXX: In the case of soft timeouts, should we eventually give up
 *      if the socket is not able to make progress?
 */
static int xs_send_request(struct rpc_task *task)
{
        struct rpc_rqst *req = task->tk_rqstp;
        struct rpc_xprt *xprt = req->rq_xprt;
        int status, retry = 0;

        /* set up everything as needed. */
        /* Write the record marker */
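        /* RFC 1831 record marking: on a stream transport each record
         * fragment is preceded by a 32-bit marker.  The top bit flags
         * the last fragment of the record; the low 31 bits carry the
         * fragment length, which excludes the marker itself.  For
         * example, an rq_slen of 132 bytes (marker included) yields
         * htonl(0x80000000 | 128), i.e. 0x80000080 on the wire.  This
         * code always sends a record as a single fragment. */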
        if (xprt->stream) {
                u32 *marker = req->rq_svec[0].iov_base;

                *marker = htonl(0x80000000|(req->rq_slen-sizeof(*marker)));
        }

        /* Continue transmitting the packet/record. We must be careful
         * to cope with writespace callbacks arriving _after_ we have
         * called sendmsg().
         */
        while (1) {
                req->rq_xtime = jiffies;
                status = xs_sendmsg(xprt, req);

                if (status < 0)
                        break;

                if (xprt->stream) {
                        req->rq_bytes_sent += status;

                        /* If we've sent the entire packet, immediately
                         * reset the count of bytes sent. */
                        if (req->rq_bytes_sent >= req->rq_slen) {
                                req->rq_bytes_sent = 0;
                                return 0;
                        }
                } else {
                        if (status >= req->rq_slen)
                                return 0;
                        status = -EAGAIN;
                        break;
                }

                dprintk("RPC: %4d xmit incomplete (%d left of %d)\n",
                                task->tk_pid, req->rq_slen - req->rq_bytes_sent,
                                req->rq_slen);

                status = -EAGAIN;
                if (retry++ > 50)
                        break;
        }

        if (status == -EAGAIN) {
                if (test_bit(SOCK_ASYNC_NOSPACE, &xprt->sock->flags)) {
                        /* Protect against races with xs_write_space */
                        spin_lock_bh(&xprt->sock_lock);
                        /* Don't race with disconnect */
                        if (!xprt_connected(xprt))
                                task->tk_status = -ENOTCONN;
                        else if (test_bit(SOCK_NOSPACE, &xprt->sock->flags)) {
                                task->tk_timeout = req->rq_timeout;
                                rpc_sleep_on(&xprt->pending, task, NULL, NULL);
                        }
                        spin_unlock_bh(&xprt->sock_lock);
                        return status;
                }
                /* Keep holding the socket if it is blocked */
                rpc_delay(task, HZ>>4);
        }
        return status;
}

/**
 * xs_close - close a socket
 * @xprt: transport
 *
 */
static void xs_close(struct rpc_xprt *xprt)
{
        struct socket *sock = xprt->sock;
        struct sock *sk = xprt->inet;

        if (!sk)
                return;

        dprintk("RPC:      xs_close xprt %p\n", xprt);

        write_lock_bh(&sk->sk_callback_lock);
        xprt->inet = NULL;
        xprt->sock = NULL;

        sk->sk_user_data = NULL;
        sk->sk_data_ready = xprt->old_data_ready;
        sk->sk_state_change = xprt->old_state_change;
        sk->sk_write_space = xprt->old_write_space;
        write_unlock_bh(&sk->sk_callback_lock);

        sk->sk_no_check = 0;

        sock_release(sock);
}

/**
 * xs_destroy - prepare to shut down a transport
 * @xprt: doomed transport
 *
 */
static void xs_destroy(struct rpc_xprt *xprt)
{
        dprintk("RPC:      xs_destroy xprt %p\n", xprt);

        cancel_delayed_work(&xprt->sock_connect);
        flush_scheduled_work();

        xprt_disconnect(xprt);
        xs_close(xprt);
        kfree(xprt->slot);
}

static inline struct rpc_xprt *xprt_from_sock(struct sock *sk)
{
        return (struct rpc_xprt *) sk->sk_user_data;
}

/**
 * xs_udp_data_ready - "data ready" callback for UDP sockets
 * @sk: socket with data to read
 * @len: how much data to read
 *
 */
static void xs_udp_data_ready(struct sock *sk, int len)
{
        struct rpc_task *task;
        struct rpc_xprt *xprt;
        struct rpc_rqst *rovr;
        struct sk_buff *skb;
        int err, repsize, copied;
        u32 _xid, *xp;

        read_lock(&sk->sk_callback_lock);
        dprintk("RPC:      xs_udp_data_ready...\n");
        if (!(xprt = xprt_from_sock(sk)))
                goto out;

        if ((skb = skb_recv_datagram(sk, 0, 1, &err)) == NULL)
                goto out;

        if (xprt->shutdown)
                goto dropit;

        repsize = skb->len - sizeof(struct udphdr);
        if (repsize < 4) {
                dprintk("RPC:      impossible RPC reply size %d!\n", repsize);
                goto dropit;
        }

        /* Copy the XID from the skb... */
        xp = skb_header_pointer(skb, sizeof(struct udphdr),
                                sizeof(_xid), &_xid);
        if (xp == NULL)
                goto dropit;

        /* Look up and lock the request corresponding to the given XID */
        spin_lock(&xprt->sock_lock);
        rovr = xprt_lookup_rqst(xprt, *xp);
        if (!rovr)
                goto out_unlock;
        task = rovr->rq_task;

        dprintk("RPC: %4d received reply\n", task->tk_pid);

        if ((copied = rovr->rq_private_buf.buflen) > repsize)
                copied = repsize;

        /* Suck it into the iovec, verify checksum if not done by hw. */
        if (csum_partial_copy_to_xdr(&rovr->rq_private_buf, skb))
                goto out_unlock;

        /* Something worked... */
        dst_confirm(skb->dst);

        xprt_complete_rqst(xprt, rovr, copied);

 out_unlock:
        spin_unlock(&xprt->sock_lock);
 dropit:
        skb_free_datagram(sk, skb);
 out:
        read_unlock(&sk->sk_callback_lock);
}

static inline size_t xs_tcp_copy_data(skb_reader_t *desc, void *p, size_t len)
{
        if (len > desc->count)
                len = desc->count;
        if (skb_copy_bits(desc->skb, desc->offset, p, len)) {
                dprintk("RPC:      failed to copy %zu bytes from skb. %zu bytes remain\n",
                                len, desc->count);
                return 0;
        }
        desc->offset += len;
        desc->count -= len;
        dprintk("RPC:      copied %zu bytes from skb. %zu bytes remain\n",
                        len, desc->count);
        return len;
}

static inline void xs_tcp_read_fraghdr(struct rpc_xprt *xprt, skb_reader_t *desc)
{
        size_t len, used;
        char *p;

        p = ((char *) &xprt->tcp_recm) + xprt->tcp_offset;
        len = sizeof(xprt->tcp_recm) - xprt->tcp_offset;
        used = xs_tcp_copy_data(desc, p, len);
        xprt->tcp_offset += used;
        if (used != len)
                return;
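        /* All four marker bytes have arrived: decode the RFC 1831
         * record marker.  The top bit flags the final fragment of the
         * record; the low 31 bits give the fragment length, which does
         * not include the marker itself. */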
        xprt->tcp_reclen = ntohl(xprt->tcp_recm);
        if (xprt->tcp_reclen & 0x80000000)
                xprt->tcp_flags |= XPRT_LAST_FRAG;
        else
                xprt->tcp_flags &= ~XPRT_LAST_FRAG;
        xprt->tcp_reclen &= 0x7fffffff;
        xprt->tcp_flags &= ~XPRT_COPY_RECM;
        xprt->tcp_offset = 0;
        /* Sanity check of the record length */
        if (xprt->tcp_reclen < 4) {
                dprintk("RPC:      invalid TCP record fragment length\n");
                xprt_disconnect(xprt);
                return;
        }
        dprintk("RPC:      reading TCP record fragment of length %d\n",
                        xprt->tcp_reclen);
}

static void xs_tcp_check_recm(struct rpc_xprt *xprt)
{
        dprintk("RPC:      xprt = %p, tcp_copied = %lu, tcp_offset = %u, tcp_reclen = %u, tcp_flags = %lx\n",
                        xprt, xprt->tcp_copied, xprt->tcp_offset, xprt->tcp_reclen, xprt->tcp_flags);
        if (xprt->tcp_offset == xprt->tcp_reclen) {
                xprt->tcp_flags |= XPRT_COPY_RECM;
                xprt->tcp_offset = 0;
                if (xprt->tcp_flags & XPRT_LAST_FRAG) {
                        xprt->tcp_flags &= ~XPRT_COPY_DATA;
                        xprt->tcp_flags |= XPRT_COPY_XID;
                        xprt->tcp_copied = 0;
                }
        }
}

static inline void xs_tcp_read_xid(struct rpc_xprt *xprt, skb_reader_t *desc)
{
        size_t len, used;
        char *p;

        len = sizeof(xprt->tcp_xid) - xprt->tcp_offset;
        dprintk("RPC:      reading XID (%Zu bytes)\n", len);
        p = ((char *) &xprt->tcp_xid) + xprt->tcp_offset;
        used = xs_tcp_copy_data(desc, p, len);
        xprt->tcp_offset += used;
        if (used != len)
                return;
        xprt->tcp_flags &= ~XPRT_COPY_XID;
        xprt->tcp_flags |= XPRT_COPY_DATA;
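        /* The XID occupies the first 4 bytes of the reply, so start the
         * copied-bytes count at 4 to keep offsets into the receive
         * buffer lined up. */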
        xprt->tcp_copied = 4;
        dprintk("RPC:      reading reply for XID %08x\n",
                                                ntohl(xprt->tcp_xid));
        xs_tcp_check_recm(xprt);
}

static inline void xs_tcp_read_request(struct rpc_xprt *xprt, skb_reader_t *desc)
{
        struct rpc_rqst *req;
        struct xdr_buf *rcvbuf;
        size_t len;
        ssize_t r;

        /* Find and lock the request corresponding to this xid */
        spin_lock(&xprt->sock_lock);
        req = xprt_lookup_rqst(xprt, xprt->tcp_xid);
        if (!req) {
                xprt->tcp_flags &= ~XPRT_COPY_DATA;
                dprintk("RPC:      XID %08x request not found!\n",
                                ntohl(xprt->tcp_xid));
                spin_unlock(&xprt->sock_lock);
                return;
        }

        rcvbuf = &req->rq_private_buf;
        len = desc->count;
        if (len > xprt->tcp_reclen - xprt->tcp_offset) {
                skb_reader_t my_desc;

                len = xprt->tcp_reclen - xprt->tcp_offset;
                memcpy(&my_desc, desc, sizeof(my_desc));
                my_desc.count = len;
                r = xdr_partial_copy_from_skb(rcvbuf, xprt->tcp_copied,
                                          &my_desc, xs_tcp_copy_data);
                desc->count -= r;
                desc->offset += r;
        } else
                r = xdr_partial_copy_from_skb(rcvbuf, xprt->tcp_copied,
                                          desc, xs_tcp_copy_data);

        if (r > 0) {
                xprt->tcp_copied += r;
                xprt->tcp_offset += r;
        }
        if (r != len) {
                /* Error when copying to the receive buffer,
                 * usually because we weren't able to allocate
                 * additional buffer pages. All we can do now
                 * is turn off XPRT_COPY_DATA, so the request
                 * will not receive any additional updates,
                 * and time out.
                 * Any remaining data from this record will
                 * be discarded.
                 */
                xprt->tcp_flags &= ~XPRT_COPY_DATA;
                dprintk("RPC:      XID %08x truncated request\n",
                                ntohl(xprt->tcp_xid));
                dprintk("RPC:      xprt = %p, tcp_copied = %lu, tcp_offset = %u, tcp_reclen = %u\n",
                                xprt, xprt->tcp_copied, xprt->tcp_offset, xprt->tcp_reclen);
                goto out;
        }

        dprintk("RPC:      XID %08x read %Zd bytes\n",
                        ntohl(xprt->tcp_xid), r);
        dprintk("RPC:      xprt = %p, tcp_copied = %lu, tcp_offset = %u, tcp_reclen = %u\n",
                        xprt, xprt->tcp_copied, xprt->tcp_offset, xprt->tcp_reclen);

        if (xprt->tcp_copied == req->rq_private_buf.buflen)
                xprt->tcp_flags &= ~XPRT_COPY_DATA;
        else if (xprt->tcp_offset == xprt->tcp_reclen) {
                if (xprt->tcp_flags & XPRT_LAST_FRAG)
                        xprt->tcp_flags &= ~XPRT_COPY_DATA;
        }

out:
        if (!(xprt->tcp_flags & XPRT_COPY_DATA)) {
                dprintk("RPC: %4d received reply complete\n",
                                req->rq_task->tk_pid);
                xprt_complete_rqst(xprt, req, xprt->tcp_copied);
        }
        spin_unlock(&xprt->sock_lock);
        xs_tcp_check_recm(xprt);
}

static inline void xs_tcp_read_discard(struct rpc_xprt *xprt, skb_reader_t *desc)
{
        size_t len;

        len = xprt->tcp_reclen - xprt->tcp_offset;
        if (len > desc->count)
                len = desc->count;
        desc->count -= len;
        desc->offset += len;
        xprt->tcp_offset += len;
        dprintk("RPC:      discarded %Zu bytes\n", len);
        xs_tcp_check_recm(xprt);
}

static int xs_tcp_data_recv(read_descriptor_t *rd_desc, struct sk_buff *skb, unsigned int offset, size_t len)
{
        struct rpc_xprt *xprt = rd_desc->arg.data;
        skb_reader_t desc = {
                .skb    = skb,
                .offset = offset,
                .count  = len,
                .csum   = 0
        };

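        /* This loop is a small state machine driven by xprt->tcp_flags:
         * XPRT_COPY_RECM means the next bytes belong to a record marker,
         * XPRT_COPY_XID to the reply's XID, and XPRT_COPY_DATA to the
         * reply body.  With none of the three set, the remainder of the
         * current fragment is discarded.  Each helper consumes what it
         * can from desc and updates the flags for the next pass. */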
        dprintk("RPC:      xs_tcp_data_recv started\n");
        do {
                /* Read in a new fragment marker if necessary */
                /* Can we ever really expect to get completely empty fragments? */
                if (xprt->tcp_flags & XPRT_COPY_RECM) {
                        xs_tcp_read_fraghdr(xprt, &desc);
                        continue;
                }
                /* Read in the xid if necessary */
                if (xprt->tcp_flags & XPRT_COPY_XID) {
                        xs_tcp_read_xid(xprt, &desc);
                        continue;
                }
                /* Read in the request data */
                if (xprt->tcp_flags & XPRT_COPY_DATA) {
                        xs_tcp_read_request(xprt, &desc);
                        continue;
                }
                /* Skip over any trailing bytes on short reads */
                xs_tcp_read_discard(xprt, &desc);
        } while (desc.count);
        dprintk("RPC:      xs_tcp_data_recv done\n");
        return len - desc.count;
}

/**
 * xs_tcp_data_ready - "data ready" callback for TCP sockets
 * @sk: socket with data to read
 * @bytes: how much data to read
 *
 */
static void xs_tcp_data_ready(struct sock *sk, int bytes)
{
        struct rpc_xprt *xprt;
        read_descriptor_t rd_desc;

        read_lock(&sk->sk_callback_lock);
        dprintk("RPC:      xs_tcp_data_ready...\n");
        if (!(xprt = xprt_from_sock(sk)))
                goto out;
        if (xprt->shutdown)
                goto out;

        /* We use rd_desc to pass struct xprt to xs_tcp_data_recv */
        rd_desc.arg.data = xprt;
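        /* The count here is just a generous read quota; xs_tcp_data_recv
         * never decrements it, so in effect tcp_read_sock() is asked to
         * hand over everything queued on the socket. */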
        rd_desc.count = 65536;
        tcp_read_sock(sk, &rd_desc, xs_tcp_data_recv);
out:
        read_unlock(&sk->sk_callback_lock);
}

/**
 * xs_tcp_state_change - callback to handle TCP socket state changes
 * @sk: socket whose state has changed
 *
 */
static void xs_tcp_state_change(struct sock *sk)
{
        struct rpc_xprt *xprt;

        read_lock(&sk->sk_callback_lock);
        if (!(xprt = xprt_from_sock(sk)))
                goto out;
        dprintk("RPC:      xs_tcp_state_change client %p...\n", xprt);
        dprintk("RPC:      state %x conn %d dead %d zapped %d\n",
                                sk->sk_state, xprt_connected(xprt),
                                sock_flag(sk, SOCK_DEAD),
                                sock_flag(sk, SOCK_ZAPPED));

        switch (sk->sk_state) {
        case TCP_ESTABLISHED:
                spin_lock_bh(&xprt->sock_lock);
                if (!xprt_test_and_set_connected(xprt)) {
                        /* Reset TCP record info */
                        xprt->tcp_offset = 0;
                        xprt->tcp_reclen = 0;
                        xprt->tcp_copied = 0;
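                        /* A fresh stream always opens with a record
                         * marker followed by an XID. */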
                        xprt->tcp_flags = XPRT_COPY_RECM | XPRT_COPY_XID;
                        rpc_wake_up(&xprt->pending);
                }
                spin_unlock_bh(&xprt->sock_lock);
                break;
        case TCP_SYN_SENT:
        case TCP_SYN_RECV:
                break;
        default:
                xprt_disconnect(xprt);
                break;
        }
 out:
        read_unlock(&sk->sk_callback_lock);
}

/**
 * xs_write_space - callback invoked when socket buffer space becomes available
 * @sk: socket whose state has changed
 *
 * Called when more output buffer space is available for this socket.
 * We try not to wake our writers until they can make "significant"
 * progress, otherwise we'll waste resources thrashing sock_sendmsg
 * with a bunch of small requests.
 */
static void xs_write_space(struct sock *sk)
{
        struct rpc_xprt *xprt;
        struct socket *sock;

        read_lock(&sk->sk_callback_lock);
        if (!(xprt = xprt_from_sock(sk)) || !(sock = sk->sk_socket))
                goto out;
        if (xprt->shutdown)
                goto out;

        /* Wait until we have enough socket memory */
        if (xprt->stream) {
                /* from net/core/stream.c:sk_stream_write_space */
                if (sk_stream_wspace(sk) < sk_stream_min_wspace(sk))
                        goto out;
        } else {
                /* from net/core/sock.c:sock_def_write_space */
                if (!sock_writeable(sk))
                        goto out;
        }

        if (!test_and_clear_bit(SOCK_NOSPACE, &sock->flags))
                goto out;

        spin_lock_bh(&xprt->sock_lock);
        if (xprt->snd_task)
                rpc_wake_up_task(xprt->snd_task);
        spin_unlock_bh(&xprt->sock_lock);
out:
        read_unlock(&sk->sk_callback_lock);
}

/**
 * xs_set_buffer_size - set send and receive limits
 * @xprt: generic transport
 *
 * Set socket send and receive limits based on the
 * sndsize and rcvsize fields in the generic transport
 * structure. This applies only to UDP sockets.
 */
static void xs_set_buffer_size(struct rpc_xprt *xprt)
{
        struct sock *sk = xprt->inet;

        if (xprt->stream)
                return;
        if (xprt->rcvsize) {
                sk->sk_userlocks |= SOCK_RCVBUF_LOCK;
                sk->sk_rcvbuf = xprt->rcvsize * xprt->max_reqs * 2;
        }
        if (xprt->sndsize) {
                sk->sk_userlocks |= SOCK_SNDBUF_LOCK;
                sk->sk_sndbuf = xprt->sndsize * xprt->max_reqs * 2;
                sk->sk_write_space(sk);
        }
}

static int xs_bindresvport(struct rpc_xprt *xprt, struct socket *sock)
{
        struct sockaddr_in myaddr = {
                .sin_family = AF_INET,
        };
        int err, port;

        /* Were we already bound to a given port? Try to reuse it */
        port = xprt->port;
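        /* Walk downwards from the previous port; on -EADDRINUSE try the
         * next one, wrapping from port 1 back up to XS_MAX_RESVPORT.
         * Give up on any other error, or once the search arrives back
         * at its starting point. */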
        do {
                myaddr.sin_port = htons(port);
                err = sock->ops->bind(sock, (struct sockaddr *) &myaddr,
                                                sizeof(myaddr));
                if (err == 0) {
                        xprt->port = port;
                        dprintk("RPC:      xs_bindresvport bound to port %u\n",
                                        port);
                        return 0;
                }
                if (--port == 0)
                        port = XS_MAX_RESVPORT;
        } while (err == -EADDRINUSE && port != xprt->port);

        dprintk("RPC:      can't bind to reserved port (%d).\n", -err);
        return err;
}

static struct socket *xs_create(struct rpc_xprt *xprt, int proto, int resvport)
{
        struct socket *sock;
        int type, err;

        dprintk("RPC:      xs_create(%s %d)\n",
                           (proto == IPPROTO_UDP)? "udp" : "tcp", proto);

        type = (proto == IPPROTO_UDP)? SOCK_DGRAM : SOCK_STREAM;

        if ((err = sock_create_kern(PF_INET, type, proto, &sock)) < 0) {
                dprintk("RPC:      can't create socket (%d).\n", -err);
                return NULL;
        }

        /* If the caller has the capability, bind to a reserved port */
        if (resvport && xs_bindresvport(xprt, sock) < 0)
                goto failed;

        return sock;

failed:
        sock_release(sock);
        return NULL;
}

static void xs_bind(struct rpc_xprt *xprt, struct socket *sock)
{
        struct sock *sk = sock->sk;

        if (xprt->inet)
                return;

        write_lock_bh(&sk->sk_callback_lock);
        sk->sk_user_data = xprt;
        xprt->old_data_ready = sk->sk_data_ready;
        xprt->old_state_change = sk->sk_state_change;
        xprt->old_write_space = sk->sk_write_space;
        if (xprt->prot == IPPROTO_UDP) {
                sk->sk_data_ready = xs_udp_data_ready;
                sk->sk_no_check = UDP_CSUM_NORCV;
                xprt_set_connected(xprt);
        } else {
                tcp_sk(sk)->nonagle = 1;        /* disable Nagle's algorithm */
                sk->sk_data_ready = xs_tcp_data_ready;
                sk->sk_state_change = xs_tcp_state_change;
                xprt_clear_connected(xprt);
        }
        sk->sk_write_space = xs_write_space;

        /* Reset to new socket */
        xprt->sock = sock;
        xprt->inet = sk;
        write_unlock_bh(&sk->sk_callback_lock);

        return;
}

/**
 * xs_connect_worker - try to connect a socket to a remote endpoint
 * @args: RPC transport to connect
 *
 * Invoked from a work queue.
 */
static void xs_connect_worker(void *args)
{
        struct rpc_xprt *xprt = (struct rpc_xprt *)args;
        struct socket *sock = xprt->sock;
        int status = -EIO;

        if (xprt->shutdown || xprt->addr.sin_port == 0)
                goto out;

        dprintk("RPC:      xs_connect_worker xprt %p\n", xprt);

        /*
         * Start by resetting any existing state
         */
        xs_close(xprt);
        sock = xs_create(xprt, xprt->prot, xprt->resvport);
        if (sock == NULL) {
                /* couldn't create socket or bind to reserved port;
                 * this is likely a permanent error, so cause an abort */
                goto out;
        }
        xs_bind(xprt, sock);
        xs_set_buffer_size(xprt);

        status = 0;
        if (!xprt->stream)
                goto out;

        /*
         * Tell the socket layer to start connecting...
         */
        status = sock->ops->connect(sock, (struct sockaddr *) &xprt->addr,
                        sizeof(xprt->addr), O_NONBLOCK);
        dprintk("RPC: %p  connect status %d connected %d sock state %d\n",
                        xprt, -status, xprt_connected(xprt), sock->sk->sk_state);
        if (status < 0) {
                switch (status) {
                        case -EINPROGRESS:
                        case -EALREADY:
                                goto out_clear;
                }
        }
out:
        if (status < 0)
                rpc_wake_up_status(&xprt->pending, status);
        else
                rpc_wake_up(&xprt->pending);
out_clear:
        smp_mb__before_clear_bit();
        clear_bit(XPRT_CONNECTING, &xprt->sockstate);
        smp_mb__after_clear_bit();
}

/**
 * xs_connect - connect a socket to a remote endpoint
 * @task: address of RPC task that manages state of connect request
 *
 * TCP: If the remote end dropped the connection, delay reconnecting.
 */
static void xs_connect(struct rpc_task *task)
{
        struct rpc_xprt *xprt = task->tk_xprt;

        if (!test_and_set_bit(XPRT_CONNECTING, &xprt->sockstate)) {
                if (xprt->sock != NULL) {
                        dprintk("RPC:      xs_connect delayed xprt %p\n", xprt);
                        schedule_delayed_work(&xprt->sock_connect,
                                        RPC_REESTABLISH_TIMEOUT);
                } else {
                        dprintk("RPC:      xs_connect scheduled xprt %p\n", xprt);
                        schedule_work(&xprt->sock_connect);
                        /* flush_scheduled_work can sleep... */
                        if (!RPC_IS_ASYNC(task))
                                flush_scheduled_work();
                }
        }
}

static struct rpc_xprt_ops xs_ops = {
        .set_buffer_size        = xs_set_buffer_size,
        .connect                = xs_connect,
        .send_request           = xs_send_request,
        .close                  = xs_close,
        .destroy                = xs_destroy,
};
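
/*
 * Illustrative sketch (not part of this file): a transport consumer
 * such as the generic xprt creation code would pick a setup routine
 * by protocol, ending up with xprt->ops pointing at xs_ops above:
 *
 *      switch (proto) {
 *      case IPPROTO_UDP:
 *              result = xs_setup_udp(xprt, timeout);
 *              break;
 *      case IPPROTO_TCP:
 *              result = xs_setup_tcp(xprt, timeout);
 *              break;
 *      default:
 *              result = -EIO;
 *      }
 */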

extern unsigned int xprt_udp_slot_table_entries;
extern unsigned int xprt_tcp_slot_table_entries;

/**
 * xs_setup_udp - Set up transport to use a UDP socket
 * @xprt: transport to set up
 * @to:   timeout parameters
 *
 */
int xs_setup_udp(struct rpc_xprt *xprt, struct rpc_timeout *to)
{
        size_t slot_table_size;

        dprintk("RPC:      setting up udp-ipv4 transport...\n");

        xprt->max_reqs = xprt_udp_slot_table_entries;
        slot_table_size = xprt->max_reqs * sizeof(xprt->slot[0]);
        xprt->slot = kmalloc(slot_table_size, GFP_KERNEL);
        if (xprt->slot == NULL)
                return -ENOMEM;
        memset(xprt->slot, 0, slot_table_size);

        xprt->prot = IPPROTO_UDP;
        xprt->port = XS_MAX_RESVPORT;
        xprt->stream = 0;
        xprt->nocong = 0;
        xprt->cwnd = RPC_INITCWND;
        xprt->resvport = capable(CAP_NET_BIND_SERVICE) ? 1 : 0;
        /* XXX: header size can vary due to auth type, IPv6, etc. */
        xprt->max_payload = (1U << 16) - (MAX_HEADER << 3);

        INIT_WORK(&xprt->sock_connect, xs_connect_worker, xprt);

        xprt->ops = &xs_ops;

        if (to)
                xprt->timeout = *to;
        else
                xprt_set_timeout(&xprt->timeout, 5, 5 * HZ);

        return 0;
}

/**
 * xs_setup_tcp - Set up transport to use a TCP socket
 * @xprt: transport to set up
 * @to: timeout parameters
 *
 */
int xs_setup_tcp(struct rpc_xprt *xprt, struct rpc_timeout *to)
{
        size_t slot_table_size;

        dprintk("RPC:      setting up tcp-ipv4 transport...\n");

        xprt->max_reqs = xprt_tcp_slot_table_entries;
        slot_table_size = xprt->max_reqs * sizeof(xprt->slot[0]);
        xprt->slot = kmalloc(slot_table_size, GFP_KERNEL);
        if (xprt->slot == NULL)
                return -ENOMEM;
        memset(xprt->slot, 0, slot_table_size);

        xprt->prot = IPPROTO_TCP;
        xprt->port = XS_MAX_RESVPORT;
        xprt->stream = 1;
        xprt->nocong = 1;
        xprt->cwnd = RPC_MAXCWND(xprt);
        xprt->resvport = capable(CAP_NET_BIND_SERVICE) ? 1 : 0;
        xprt->max_payload = (1U << 31) - 1;

        INIT_WORK(&xprt->sock_connect, xs_connect_worker, xprt);

        xprt->ops = &xs_ops;

        if (to)
                xprt->timeout = *to;
        else
                xprt_set_timeout(&xprt->timeout, 2, 60 * HZ);

        return 0;
}