[PATCH] RPC: separate TCP and UDP write space callbacks
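
Split the single socket write_space callback into UDP- and TCP-specific
versions. The UDP variant uses the sock_writeable() test from
net/core/sock.c:sock_def_write_space, while the TCP variant uses the
sk_stream_wspace()/sk_stream_min_wspace() test from
net/core/stream.c:sk_stream_write_space, so each transport wakes its
writers only when they can make significant progress.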
/*
 * linux/net/sunrpc/xprtsock.c
 *
 * Client-side transport implementation for sockets.
 *
 * TCP callback races fixes (C) 1998 Red Hat Software <alan@redhat.com>
 * TCP send fixes (C) 1998 Red Hat Software <alan@redhat.com>
 * TCP NFS related read + write fixes
 *  (C) 1999 Dave Airlie, University of Limerick, Ireland <airlied@linux.ie>
 *
 * Rewrite of large parts of the code in order to stabilize TCP stuff.
 * Fix behaviour when socket buffer is full.
 *  (C) 1999 Trond Myklebust <trond.myklebust@fys.uio.no>
 *
 * IP socket transport implementation, (C) 2005 Chuck Lever <cel@netapp.com>
 */

#include <linux/types.h>
#include <linux/slab.h>
#include <linux/capability.h>
#include <linux/sched.h>
#include <linux/pagemap.h>
#include <linux/errno.h>
#include <linux/socket.h>
#include <linux/in.h>
#include <linux/net.h>
#include <linux/mm.h>
#include <linux/udp.h>
#include <linux/tcp.h>
#include <linux/sunrpc/clnt.h>
#include <linux/file.h>

#include <net/sock.h>
#include <net/checksum.h>
#include <net/udp.h>
#include <net/tcp.h>

/*
 * Maximum port number to use when requesting a reserved port.
 */
#define XS_MAX_RESVPORT         (800U)

#ifdef RPC_DEBUG
# undef  RPC_DEBUG_DATA
# define RPCDBG_FACILITY        RPCDBG_TRANS
#endif

#ifdef RPC_DEBUG_DATA
static void xs_pktdump(char *msg, u32 *packet, unsigned int count)
{
        u8 *buf = (u8 *) packet;
        int j;

        dprintk("RPC:      %s\n", msg);
        for (j = 0; j < count && j < 128; j += 4) {
                if (!(j & 31)) {
                        if (j)
                                dprintk("\n");
                        dprintk("0x%04x ", j);
                }
                dprintk("%02x%02x%02x%02x ",
                        buf[j], buf[j+1], buf[j+2], buf[j+3]);
        }
        dprintk("\n");
}
#else
static inline void xs_pktdump(char *msg, u32 *packet, unsigned int count)
{
        /* NOP */
}
#endif

#define XS_SENDMSG_FLAGS        (MSG_DONTWAIT | MSG_NOSIGNAL)
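/* All sends are nonblocking (MSG_DONTWAIT), and a send on a dead
 * connection must not raise SIGPIPE (MSG_NOSIGNAL); the write-space
 * callbacks and RPC retry logic handle the resulting short sends. */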

static inline int xs_send_head(struct socket *sock, struct sockaddr *addr, int addrlen, struct xdr_buf *xdr, unsigned int base, unsigned int len)
{
        struct kvec iov = {
                .iov_base       = xdr->head[0].iov_base + base,
                .iov_len        = len - base,
        };
        struct msghdr msg = {
                .msg_name       = addr,
                .msg_namelen    = addrlen,
                .msg_flags      = XS_SENDMSG_FLAGS,
        };

        if (xdr->len > len)
                msg.msg_flags |= MSG_MORE;

        if (likely(iov.iov_len))
                return kernel_sendmsg(sock, &msg, &iov, 1, iov.iov_len);
        return kernel_sendmsg(sock, &msg, NULL, 0, 0);
}

static int xs_send_tail(struct socket *sock, struct xdr_buf *xdr, unsigned int base, unsigned int len)
{
        struct kvec iov = {
                .iov_base       = xdr->tail[0].iov_base + base,
                .iov_len        = len - base,
        };
        struct msghdr msg = {
                .msg_flags      = XS_SENDMSG_FLAGS,
        };

        return kernel_sendmsg(sock, &msg, &iov, 1, iov.iov_len);
}

/**
 * xs_sendpages - write pages directly to a socket
 * @sock: socket to send on
 * @addr: UDP only -- address of destination
 * @addrlen: UDP only -- length of destination address
 * @xdr: buffer containing this request
 * @base: starting position in the buffer
 *
 */
static int xs_sendpages(struct socket *sock, struct sockaddr *addr, int addrlen, struct xdr_buf *xdr, unsigned int base)
{
        struct page **ppage = xdr->pages;
        unsigned int len, pglen = xdr->page_len;
        int err, ret = 0;
        ssize_t (*sendpage)(struct socket *, struct page *, int, size_t, int);

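        /* Return convention: a positive value counts the bytes actually
         * handed to the socket across all three xdr_buf regions; an
         * error code is returned only if nothing at all was sent.  A
         * short count means the socket buffer filled and the caller
         * must retransmit the remainder later. */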
        len = xdr->head[0].iov_len;
        if (base < len || (addr != NULL && base == 0)) {
                err = xs_send_head(sock, addr, addrlen, xdr, base, len);
                if (ret == 0)
                        ret = err;
                else if (err > 0)
                        ret += err;
                if (err != (len - base))
                        goto out;
                base = 0;
        } else
                base -= len;

        if (unlikely(pglen == 0))
                goto copy_tail;
        if (unlikely(base >= pglen)) {
                base -= pglen;
                goto copy_tail;
        }
        if (base || xdr->page_base) {
                pglen -= base;
                base += xdr->page_base;
                ppage += base >> PAGE_CACHE_SHIFT;
                base &= ~PAGE_CACHE_MASK;
        }

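        /* Prefer the protocol's zero-copy ->sendpage method; fall back
         * to sock_no_sendpage(), which copies the page contents through
         * an ordinary sendmsg path. */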
        sendpage = sock->ops->sendpage ? : sock_no_sendpage;
        do {
                int flags = XS_SENDMSG_FLAGS;

                len = PAGE_CACHE_SIZE;
                if (base)
                        len -= base;
                if (pglen < len)
                        len = pglen;

                if (pglen != len || xdr->tail[0].iov_len != 0)
                        flags |= MSG_MORE;

                /* Hmm... We might be dealing with highmem pages */
                if (PageHighMem(*ppage))
                        sendpage = sock_no_sendpage;
                err = sendpage(sock, *ppage, base, len, flags);
                if (ret == 0)
                        ret = err;
                else if (err > 0)
                        ret += err;
                if (err != len)
                        goto out;
                base = 0;
                ppage++;
        } while ((pglen -= len) != 0);
copy_tail:
        len = xdr->tail[0].iov_len;
        if (base < len) {
                err = xs_send_tail(sock, xdr, base, len);
                if (ret == 0)
                        ret = err;
                else if (err > 0)
                        ret += err;
        }
out:
        return ret;
}

/**
 * xs_sendmsg - write an RPC request to a socket
 * @xprt: generic transport
 * @req: the RPC request to write
 *
 */
static int xs_sendmsg(struct rpc_xprt *xprt, struct rpc_rqst *req)
{
        struct socket *sock = xprt->sock;
        struct xdr_buf *xdr = &req->rq_snd_buf;
        struct sockaddr *addr = NULL;
        int addrlen = 0;
        unsigned int skip;
        int result;

        if (!sock)
                return -ENOTCONN;

        xs_pktdump("packet data:",
                                req->rq_svec->iov_base,
                                req->rq_svec->iov_len);

        /* For UDP, we need to provide an address */
        if (!xprt->stream) {
                addr = (struct sockaddr *) &xprt->addr;
                addrlen = sizeof(xprt->addr);
        }
        /* Don't repeat bytes */
        skip = req->rq_bytes_sent;

        clear_bit(SOCK_ASYNC_NOSPACE, &sock->flags);
        result = xs_sendpages(sock, addr, addrlen, xdr, skip);

        dprintk("RPC:      xs_sendmsg(%d) = %d\n", xdr->len - skip, result);

        if (result >= 0)
                return result;

        switch (result) {
        case -ECONNREFUSED:
                /* When the server has died, an ICMP port unreachable message
                 * prompts ECONNREFUSED. */
        case -EAGAIN:
                break;
        case -ECONNRESET:
        case -ENOTCONN:
        case -EPIPE:
                /* connection broken */
                if (xprt->stream)
                        result = -ENOTCONN;
                break;
        default:
                break;
        }
        return result;
}

/**
 * xs_send_request - write an RPC request to a socket
 * @task: address of RPC task that manages the state of an RPC request
 *
 * Return values:
 *      0:  The request has been sent
 * EAGAIN:  The socket was blocked, please call again later to
 *          complete the request
 *  other:  Some other error occurred, the request was not sent
 *
 * XXX: In the case of soft timeouts, should we eventually give up
 *      if the socket is not able to make progress?
 */
static int xs_send_request(struct rpc_task *task)
{
        struct rpc_rqst *req = task->tk_rqstp;
        struct rpc_xprt *xprt = req->rq_xprt;
        int status, retry = 0;

        /* set up everything as needed. */
        /* Write the record marker */
        if (xprt->stream) {
                u32 *marker = req->rq_svec[0].iov_base;

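                /* RPC over TCP frames each message with a 4-byte record
                 * marker (RFC 1831): the high bit flags the last
                 * fragment, and the low 31 bits carry the fragment
                 * length, excluding the marker itself.  For example, a
                 * 128-byte send buffer yields htonl(0x80000000 | 124). */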
                *marker = htonl(0x80000000|(req->rq_slen-sizeof(*marker)));
        }

        /* Continue transmitting the packet/record. We must be careful
         * to cope with writespace callbacks arriving _after_ we have
         * called sendmsg().
         */
        while (1) {
                req->rq_xtime = jiffies;
                status = xs_sendmsg(xprt, req);

                if (status < 0)
                        break;

                if (xprt->stream) {
                        req->rq_bytes_sent += status;

                        /* If we've sent the entire packet, immediately
                         * reset the count of bytes sent. */
                        if (req->rq_bytes_sent >= req->rq_slen) {
                                req->rq_bytes_sent = 0;
                                return 0;
                        }
                } else {
                        if (status >= req->rq_slen)
                                return 0;
                        status = -EAGAIN;
                        break;
                }

                dprintk("RPC: %4d xmit incomplete (%d left of %d)\n",
                                task->tk_pid, req->rq_slen - req->rq_bytes_sent,
                                req->rq_slen);

                status = -EAGAIN;
                if (retry++ > 50)
                        break;
        }

        if (status == -EAGAIN) {
                if (test_bit(SOCK_ASYNC_NOSPACE, &xprt->sock->flags)) {
                        /* Protect against races with write_space */
                        spin_lock_bh(&xprt->transport_lock);
                        /* Don't race with disconnect */
                        if (!xprt_connected(xprt))
                                task->tk_status = -ENOTCONN;
                        else if (test_bit(SOCK_NOSPACE, &xprt->sock->flags))
                                xprt_wait_for_buffer_space(task);
                        spin_unlock_bh(&xprt->transport_lock);
                        return status;
                }
                /* Keep holding the socket if it is blocked */
                rpc_delay(task, HZ>>4);
        }
        return status;
}

/**
 * xs_close - close a socket
 * @xprt: transport
 *
 */
static void xs_close(struct rpc_xprt *xprt)
{
        struct socket *sock = xprt->sock;
        struct sock *sk = xprt->inet;

        if (!sk)
                return;

        dprintk("RPC:      xs_close xprt %p\n", xprt);

        write_lock_bh(&sk->sk_callback_lock);
        xprt->inet = NULL;
        xprt->sock = NULL;

        sk->sk_user_data = NULL;
        sk->sk_data_ready = xprt->old_data_ready;
        sk->sk_state_change = xprt->old_state_change;
        sk->sk_write_space = xprt->old_write_space;
        write_unlock_bh(&sk->sk_callback_lock);

        sk->sk_no_check = 0;

        sock_release(sock);
}

/**
 * xs_destroy - prepare to shutdown a transport
 * @xprt: doomed transport
 *
 */
static void xs_destroy(struct rpc_xprt *xprt)
{
        dprintk("RPC:      xs_destroy xprt %p\n", xprt);

        cancel_delayed_work(&xprt->connect_worker);
        flush_scheduled_work();

        xprt_disconnect(xprt);
        xs_close(xprt);
        kfree(xprt->slot);
}

static inline struct rpc_xprt *xprt_from_sock(struct sock *sk)
{
        return (struct rpc_xprt *) sk->sk_user_data;
}

/**
 * xs_udp_data_ready - "data ready" callback for UDP sockets
 * @sk: socket with data to read
 * @len: how much data to read
 *
 */
static void xs_udp_data_ready(struct sock *sk, int len)
{
        struct rpc_task *task;
        struct rpc_xprt *xprt;
        struct rpc_rqst *rovr;
        struct sk_buff *skb;
        int err, repsize, copied;
        u32 _xid, *xp;

        read_lock(&sk->sk_callback_lock);
        dprintk("RPC:      xs_udp_data_ready...\n");
        if (!(xprt = xprt_from_sock(sk)))
                goto out;

        if ((skb = skb_recv_datagram(sk, 0, 1, &err)) == NULL)
                goto out;

        if (xprt->shutdown)
                goto dropit;

        repsize = skb->len - sizeof(struct udphdr);
        if (repsize < 4) {
                dprintk("RPC:      impossible RPC reply size %d!\n", repsize);
                goto dropit;
        }

        /* Copy the XID from the skb... */
        xp = skb_header_pointer(skb, sizeof(struct udphdr),
                                sizeof(_xid), &_xid);
        if (xp == NULL)
                goto dropit;

        /* Look up and lock the request corresponding to the given XID */
        spin_lock(&xprt->transport_lock);
        rovr = xprt_lookup_rqst(xprt, *xp);
        if (!rovr)
                goto out_unlock;
        task = rovr->rq_task;

        dprintk("RPC: %4d received reply\n", task->tk_pid);

        if ((copied = rovr->rq_private_buf.buflen) > repsize)
                copied = repsize;

        /* Suck it into the iovec, verify checksum if not done by hw. */
        if (csum_partial_copy_to_xdr(&rovr->rq_private_buf, skb))
                goto out_unlock;

        /* Something worked... */
        dst_confirm(skb->dst);

        xprt_complete_rqst(xprt, rovr, copied);

 out_unlock:
        spin_unlock(&xprt->transport_lock);
 dropit:
        skb_free_datagram(sk, skb);
 out:
        read_unlock(&sk->sk_callback_lock);
}

static inline size_t xs_tcp_copy_data(skb_reader_t *desc, void *p, size_t len)
{
        if (len > desc->count)
                len = desc->count;
        if (skb_copy_bits(desc->skb, desc->offset, p, len)) {
                dprintk("RPC:      failed to copy %zu bytes from skb. %zu bytes remain\n",
                                len, desc->count);
                return 0;
        }
        desc->offset += len;
        desc->count -= len;
        dprintk("RPC:      copied %zu bytes from skb. %zu bytes remain\n",
                        len, desc->count);
        return len;
}

static inline void xs_tcp_read_fraghdr(struct rpc_xprt *xprt, skb_reader_t *desc)
{
        size_t len, used;
        char *p;

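        /* The 4-byte record marker can itself arrive split across
         * skbs, so accumulate it in xprt->tcp_recm, using tcp_offset
         * to track how much of it we have so far. */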
        p = ((char *) &xprt->tcp_recm) + xprt->tcp_offset;
        len = sizeof(xprt->tcp_recm) - xprt->tcp_offset;
        used = xs_tcp_copy_data(desc, p, len);
        xprt->tcp_offset += used;
        if (used != len)
                return;
        xprt->tcp_reclen = ntohl(xprt->tcp_recm);
        if (xprt->tcp_reclen & 0x80000000)
                xprt->tcp_flags |= XPRT_LAST_FRAG;
        else
                xprt->tcp_flags &= ~XPRT_LAST_FRAG;
        xprt->tcp_reclen &= 0x7fffffff;
        xprt->tcp_flags &= ~XPRT_COPY_RECM;
        xprt->tcp_offset = 0;
        /* Sanity check of the record length */
        if (xprt->tcp_reclen < 4) {
                dprintk("RPC:      invalid TCP record fragment length\n");
                xprt_disconnect(xprt);
                return;
        }
        dprintk("RPC:      reading TCP record fragment of length %d\n",
                        xprt->tcp_reclen);
}

static void xs_tcp_check_recm(struct rpc_xprt *xprt)
{
        dprintk("RPC:      xprt = %p, tcp_copied = %lu, tcp_offset = %u, tcp_reclen = %u, tcp_flags = %lx\n",
                        xprt, xprt->tcp_copied, xprt->tcp_offset, xprt->tcp_reclen, xprt->tcp_flags);
        if (xprt->tcp_offset == xprt->tcp_reclen) {
                xprt->tcp_flags |= XPRT_COPY_RECM;
                xprt->tcp_offset = 0;
                if (xprt->tcp_flags & XPRT_LAST_FRAG) {
                        xprt->tcp_flags &= ~XPRT_COPY_DATA;
                        xprt->tcp_flags |= XPRT_COPY_XID;
                        xprt->tcp_copied = 0;
                }
        }
}

static inline void xs_tcp_read_xid(struct rpc_xprt *xprt, skb_reader_t *desc)
{
        size_t len, used;
        char *p;

        len = sizeof(xprt->tcp_xid) - xprt->tcp_offset;
        dprintk("RPC:      reading XID (%Zu bytes)\n", len);
        p = ((char *) &xprt->tcp_xid) + xprt->tcp_offset;
        used = xs_tcp_copy_data(desc, p, len);
        xprt->tcp_offset += used;
        if (used != len)
                return;
        xprt->tcp_flags &= ~XPRT_COPY_XID;
        xprt->tcp_flags |= XPRT_COPY_DATA;
        xprt->tcp_copied = 4;
        dprintk("RPC:      reading reply for XID %08x\n",
                                                ntohl(xprt->tcp_xid));
        xs_tcp_check_recm(xprt);
}

static inline void xs_tcp_read_request(struct rpc_xprt *xprt, skb_reader_t *desc)
{
        struct rpc_rqst *req;
        struct xdr_buf *rcvbuf;
        size_t len;
        ssize_t r;

        /* Find and lock the request corresponding to this xid */
        spin_lock(&xprt->transport_lock);
        req = xprt_lookup_rqst(xprt, xprt->tcp_xid);
        if (!req) {
                xprt->tcp_flags &= ~XPRT_COPY_DATA;
                dprintk("RPC:      XID %08x request not found!\n",
                                ntohl(xprt->tcp_xid));
                spin_unlock(&xprt->transport_lock);
                return;
        }

        rcvbuf = &req->rq_private_buf;
        len = desc->count;
        if (len > xprt->tcp_reclen - xprt->tcp_offset) {
                skb_reader_t my_desc;

                len = xprt->tcp_reclen - xprt->tcp_offset;
                memcpy(&my_desc, desc, sizeof(my_desc));
                my_desc.count = len;
                r = xdr_partial_copy_from_skb(rcvbuf, xprt->tcp_copied,
                                          &my_desc, xs_tcp_copy_data);
                desc->count -= r;
                desc->offset += r;
        } else
                r = xdr_partial_copy_from_skb(rcvbuf, xprt->tcp_copied,
                                          desc, xs_tcp_copy_data);

        if (r > 0) {
                xprt->tcp_copied += r;
                xprt->tcp_offset += r;
        }
        if (r != len) {
                /* Error when copying to the receive buffer,
                 * usually because we weren't able to allocate
                 * additional buffer pages. All we can do now
                 * is turn off XPRT_COPY_DATA, so the request
                 * will not receive any additional updates,
                 * and time out.
                 * Any remaining data from this record will
                 * be discarded.
                 */
                xprt->tcp_flags &= ~XPRT_COPY_DATA;
                dprintk("RPC:      XID %08x truncated request\n",
                                ntohl(xprt->tcp_xid));
                dprintk("RPC:      xprt = %p, tcp_copied = %lu, tcp_offset = %u, tcp_reclen = %u\n",
                                xprt, xprt->tcp_copied, xprt->tcp_offset, xprt->tcp_reclen);
                goto out;
        }

        dprintk("RPC:      XID %08x read %Zd bytes\n",
                        ntohl(xprt->tcp_xid), r);
        dprintk("RPC:      xprt = %p, tcp_copied = %lu, tcp_offset = %u, tcp_reclen = %u\n",
                        xprt, xprt->tcp_copied, xprt->tcp_offset, xprt->tcp_reclen);

        if (xprt->tcp_copied == req->rq_private_buf.buflen)
                xprt->tcp_flags &= ~XPRT_COPY_DATA;
        else if (xprt->tcp_offset == xprt->tcp_reclen) {
                if (xprt->tcp_flags & XPRT_LAST_FRAG)
                        xprt->tcp_flags &= ~XPRT_COPY_DATA;
        }

out:
        if (!(xprt->tcp_flags & XPRT_COPY_DATA)) {
                dprintk("RPC: %4d received reply complete\n",
                                req->rq_task->tk_pid);
                xprt_complete_rqst(xprt, req, xprt->tcp_copied);
        }
        spin_unlock(&xprt->transport_lock);
        xs_tcp_check_recm(xprt);
}

static inline void xs_tcp_read_discard(struct rpc_xprt *xprt, skb_reader_t *desc)
{
        size_t len;

        len = xprt->tcp_reclen - xprt->tcp_offset;
        if (len > desc->count)
                len = desc->count;
        desc->count -= len;
        desc->offset += len;
        xprt->tcp_offset += len;
        dprintk("RPC:      discarded %Zu bytes\n", len);
        xs_tcp_check_recm(xprt);
}

static int xs_tcp_data_recv(read_descriptor_t *rd_desc, struct sk_buff *skb, unsigned int offset, size_t len)
{
        struct rpc_xprt *xprt = rd_desc->arg.data;
        skb_reader_t desc = {
                .skb    = skb,
                .offset = offset,
                .count  = len,
                .csum   = 0
        };

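        /* This loop is a small state machine driven by xprt->tcp_flags:
         * XPRT_COPY_RECM reads the record marker, XPRT_COPY_XID reads
         * the reply XID, and XPRT_COPY_DATA copies payload into the
         * matching request's receive buffer; anything left over in the
         * record is discarded. */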
        dprintk("RPC:      xs_tcp_data_recv started\n");
        do {
                /* Read in a new fragment marker if necessary */
                /* Can we ever really expect to get completely empty fragments? */
                if (xprt->tcp_flags & XPRT_COPY_RECM) {
                        xs_tcp_read_fraghdr(xprt, &desc);
                        continue;
                }
                /* Read in the xid if necessary */
                if (xprt->tcp_flags & XPRT_COPY_XID) {
                        xs_tcp_read_xid(xprt, &desc);
                        continue;
                }
                /* Read in the request data */
                if (xprt->tcp_flags & XPRT_COPY_DATA) {
                        xs_tcp_read_request(xprt, &desc);
                        continue;
                }
                /* Skip over any trailing bytes on short reads */
                xs_tcp_read_discard(xprt, &desc);
        } while (desc.count);
        dprintk("RPC:      xs_tcp_data_recv done\n");
        return len - desc.count;
}

/**
 * xs_tcp_data_ready - "data ready" callback for TCP sockets
 * @sk: socket with data to read
 * @bytes: how much data to read
 *
 */
static void xs_tcp_data_ready(struct sock *sk, int bytes)
{
        struct rpc_xprt *xprt;
        read_descriptor_t rd_desc;

        read_lock(&sk->sk_callback_lock);
        dprintk("RPC:      xs_tcp_data_ready...\n");
        if (!(xprt = xprt_from_sock(sk)))
                goto out;
        if (xprt->shutdown)
                goto out;

        /* We use rd_desc to pass struct xprt to xs_tcp_data_recv */
        rd_desc.arg.data = xprt;
        rd_desc.count = 65536;
        tcp_read_sock(sk, &rd_desc, xs_tcp_data_recv);
out:
        read_unlock(&sk->sk_callback_lock);
}

/**
 * xs_tcp_state_change - callback to handle TCP socket state changes
 * @sk: socket whose state has changed
 *
 */
static void xs_tcp_state_change(struct sock *sk)
{
        struct rpc_xprt *xprt;

        read_lock(&sk->sk_callback_lock);
        if (!(xprt = xprt_from_sock(sk)))
                goto out;
        dprintk("RPC:      xs_tcp_state_change client %p...\n", xprt);
        dprintk("RPC:      state %x conn %d dead %d zapped %d\n",
                                sk->sk_state, xprt_connected(xprt),
                                sock_flag(sk, SOCK_DEAD),
                                sock_flag(sk, SOCK_ZAPPED));

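        /* Only ESTABLISHED counts as connected; the SYN states are
         * transitional, and any other state tears the transport down. */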
        switch (sk->sk_state) {
        case TCP_ESTABLISHED:
                spin_lock_bh(&xprt->transport_lock);
                if (!xprt_test_and_set_connected(xprt)) {
                        /* Reset TCP record info */
                        xprt->tcp_offset = 0;
                        xprt->tcp_reclen = 0;
                        xprt->tcp_copied = 0;
                        xprt->tcp_flags = XPRT_COPY_RECM | XPRT_COPY_XID;
                        xprt_wake_pending_tasks(xprt, 0);
                }
                spin_unlock_bh(&xprt->transport_lock);
                break;
        case TCP_SYN_SENT:
        case TCP_SYN_RECV:
                break;
        default:
                xprt_disconnect(xprt);
                break;
        }
 out:
        read_unlock(&sk->sk_callback_lock);
}

/**
 * xs_udp_write_space - callback invoked when socket buffer space
 *                             becomes available
 * @sk: socket whose state has changed
 *
 * Called when more output buffer space is available for this socket.
 * We try not to wake our writers until they can make "significant"
 * progress, otherwise we'll waste resources thrashing kernel_sendmsg
 * with a bunch of small requests.
 */
static void xs_udp_write_space(struct sock *sk)
{
        read_lock(&sk->sk_callback_lock);

        /* from net/core/sock.c:sock_def_write_space */
        if (sock_writeable(sk)) {
                struct socket *sock;
                struct rpc_xprt *xprt;

                if (unlikely(!(sock = sk->sk_socket)))
                        goto out;
                if (unlikely(!(xprt = xprt_from_sock(sk))))
                        goto out;
                if (unlikely(!test_and_clear_bit(SOCK_NOSPACE, &sock->flags)))
                        goto out;

                xprt_write_space(xprt);
        }

 out:
        read_unlock(&sk->sk_callback_lock);
}

/**
 * xs_tcp_write_space - callback invoked when socket buffer space
 *                             becomes available
 * @sk: socket whose state has changed
 *
 * Called when more output buffer space is available for this socket.
 * We try not to wake our writers until they can make "significant"
 * progress, otherwise we'll waste resources thrashing kernel_sendmsg
 * with a bunch of small requests.
 */
static void xs_tcp_write_space(struct sock *sk)
{
        read_lock(&sk->sk_callback_lock);

        /* from net/core/stream.c:sk_stream_write_space */
        if (sk_stream_wspace(sk) >= sk_stream_min_wspace(sk)) {
                struct socket *sock;
                struct rpc_xprt *xprt;

                if (unlikely(!(sock = sk->sk_socket)))
                        goto out;
                if (unlikely(!(xprt = xprt_from_sock(sk))))
                        goto out;
                if (unlikely(!test_and_clear_bit(SOCK_NOSPACE, &sock->flags)))
                        goto out;

                xprt_write_space(xprt);
        }

 out:
        read_unlock(&sk->sk_callback_lock);
}

/**
 * xs_set_buffer_size - set send and receive limits
 * @xprt: generic transport
 *
 * Set socket send and receive limits based on the
 * sndsize and rcvsize fields in the generic transport
 * structure. This applies only to UDP sockets.
 */
static void xs_set_buffer_size(struct rpc_xprt *xprt)
{
        struct sock *sk = xprt->inet;

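        /* Setting SOCK_SNDBUF_LOCK/SOCK_RCVBUF_LOCK in sk_userlocks
         * keeps the core networking code from auto-tuning these limits
         * behind our back. */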
        if (xprt->stream)
                return;
        if (xprt->rcvsize) {
                sk->sk_userlocks |= SOCK_RCVBUF_LOCK;
                sk->sk_rcvbuf = xprt->rcvsize * xprt->max_reqs * 2;
        }
        if (xprt->sndsize) {
                sk->sk_userlocks |= SOCK_SNDBUF_LOCK;
                sk->sk_sndbuf = xprt->sndsize * xprt->max_reqs * 2;
                sk->sk_write_space(sk);
        }
}

static int xs_bindresvport(struct rpc_xprt *xprt, struct socket *sock)
{
        struct sockaddr_in myaddr = {
                .sin_family = AF_INET,
        };
        int err, port;

        /* Were we already bound to a given port? Try to reuse it */
        port = xprt->port;
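        /* Walk downward through the reserved port range, wrapping
         * around at XS_MAX_RESVPORT, until a bind succeeds or we
         * arrive back where we started. */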
        do {
                myaddr.sin_port = htons(port);
                err = sock->ops->bind(sock, (struct sockaddr *) &myaddr,
                                                sizeof(myaddr));
                if (err == 0) {
                        xprt->port = port;
                        dprintk("RPC:      xs_bindresvport bound to port %u\n",
                                        port);
                        return 0;
                }
                if (--port == 0)
                        port = XS_MAX_RESVPORT;
        } while (err == -EADDRINUSE && port != xprt->port);

        dprintk("RPC:      can't bind to reserved port (%d).\n", -err);
        return err;
}

static struct socket *xs_create(struct rpc_xprt *xprt, int proto, int resvport)
{
        struct socket *sock;
        int type, err;

        dprintk("RPC:      xs_create(%s %d)\n",
                           (proto == IPPROTO_UDP)? "udp" : "tcp", proto);

        type = (proto == IPPROTO_UDP)? SOCK_DGRAM : SOCK_STREAM;

        if ((err = sock_create_kern(PF_INET, type, proto, &sock)) < 0) {
                dprintk("RPC:      can't create socket (%d).\n", -err);
                return NULL;
        }

        /* If the caller has the capability, bind to a reserved port */
        if (resvport && xs_bindresvport(xprt, sock) < 0)
                goto failed;

        return sock;

failed:
        sock_release(sock);
        return NULL;
}

static void xs_bind(struct rpc_xprt *xprt, struct socket *sock)
{
        struct sock *sk = sock->sk;

        if (xprt->inet)
                return;

        write_lock_bh(&sk->sk_callback_lock);
        sk->sk_user_data = xprt;
        xprt->old_data_ready = sk->sk_data_ready;
        xprt->old_state_change = sk->sk_state_change;
        xprt->old_write_space = sk->sk_write_space;
        if (xprt->prot == IPPROTO_UDP) {
                sk->sk_data_ready = xs_udp_data_ready;
                sk->sk_write_space = xs_udp_write_space;
                sk->sk_no_check = UDP_CSUM_NORCV;
                xprt_set_connected(xprt);
        } else {
                tcp_sk(sk)->nonagle = 1;        /* disable Nagle's algorithm */
                sk->sk_data_ready = xs_tcp_data_ready;
                sk->sk_state_change = xs_tcp_state_change;
                sk->sk_write_space = xs_tcp_write_space;
                xprt_clear_connected(xprt);
        }

        /* Reset to new socket */
        xprt->sock = sock;
        xprt->inet = sk;
        write_unlock_bh(&sk->sk_callback_lock);

        return;
}

/**
 * xs_connect_worker - try to connect a socket to a remote endpoint
 * @args: RPC transport to connect
 *
 * Invoked by a work queue tasklet.
 */
static void xs_connect_worker(void *args)
{
        struct rpc_xprt *xprt = (struct rpc_xprt *)args;
        struct socket *sock = xprt->sock;
        int status = -EIO;

        if (xprt->shutdown || xprt->addr.sin_port == 0)
                goto out;

        dprintk("RPC:      xs_connect_worker xprt %p\n", xprt);

        /*
         * Start by resetting any existing state
         */
        xs_close(xprt);
        sock = xs_create(xprt, xprt->prot, xprt->resvport);
        if (sock == NULL) {
                /* couldn't create socket or bind to reserved port;
                 * this is likely a permanent error, so cause an abort */
                goto out;
        }
        xs_bind(xprt, sock);
        xs_set_buffer_size(xprt);

        status = 0;
        if (!xprt->stream)
                goto out;

        /*
         * Tell the socket layer to start connecting...
         */
        status = sock->ops->connect(sock, (struct sockaddr *) &xprt->addr,
                        sizeof(xprt->addr), O_NONBLOCK);
        dprintk("RPC: %p  connect status %d connected %d sock state %d\n",
                        xprt, -status, xprt_connected(xprt), sock->sk->sk_state);
        if (status < 0) {
                switch (status) {
                        case -EINPROGRESS:
                        case -EALREADY:
                                goto out_clear;
                }
        }
out:
        xprt_wake_pending_tasks(xprt, status);
out_clear:
        xprt_clear_connecting(xprt);
}

/**
 * xs_connect - connect a socket to a remote endpoint
 * @task: address of RPC task that manages state of connect request
 *
 * TCP: If the remote end dropped the connection, delay reconnecting.
 */
static void xs_connect(struct rpc_task *task)
{
        struct rpc_xprt *xprt = task->tk_xprt;

        if (!xprt_test_and_set_connecting(xprt)) {
                if (xprt->sock != NULL) {
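                        /* An existing socket means this is a reconnect;
                         * wait a while before retrying so we don't hammer
                         * a server that just dropped the connection. */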
                        dprintk("RPC:      xs_connect delayed xprt %p\n", xprt);
                        schedule_delayed_work(&xprt->connect_worker,
                                        RPC_REESTABLISH_TIMEOUT);
                } else {
                        dprintk("RPC:      xs_connect scheduled xprt %p\n", xprt);
                        schedule_work(&xprt->connect_worker);
                        /* flush_scheduled_work can sleep... */
                        if (!RPC_IS_ASYNC(task))
                                flush_scheduled_work();
                }
        }
}

static struct rpc_xprt_ops xs_ops = {
        .set_buffer_size        = xs_set_buffer_size,
        .connect                = xs_connect,
        .send_request           = xs_send_request,
        .close                  = xs_close,
        .destroy                = xs_destroy,
};

extern unsigned int xprt_udp_slot_table_entries;
extern unsigned int xprt_tcp_slot_table_entries;

/**
 * xs_setup_udp - Set up transport to use a UDP socket
 * @xprt: transport to set up
 * @to:   timeout parameters
 *
 */
int xs_setup_udp(struct rpc_xprt *xprt, struct rpc_timeout *to)
{
        size_t slot_table_size;

        dprintk("RPC:      setting up udp-ipv4 transport...\n");

        xprt->max_reqs = xprt_udp_slot_table_entries;
        slot_table_size = xprt->max_reqs * sizeof(xprt->slot[0]);
        xprt->slot = kmalloc(slot_table_size, GFP_KERNEL);
        if (xprt->slot == NULL)
                return -ENOMEM;
        memset(xprt->slot, 0, slot_table_size);

        xprt->prot = IPPROTO_UDP;
        xprt->port = XS_MAX_RESVPORT;
        xprt->stream = 0;
        xprt->nocong = 0;
        xprt->cwnd = RPC_INITCWND;
        xprt->resvport = capable(CAP_NET_BIND_SERVICE) ? 1 : 0;
        /* XXX: header size can vary due to auth type, IPv6, etc. */
        xprt->max_payload = (1U << 16) - (MAX_HEADER << 3);

        INIT_WORK(&xprt->connect_worker, xs_connect_worker, xprt);

        xprt->ops = &xs_ops;

        if (to)
                xprt->timeout = *to;
        else
                xprt_set_timeout(&xprt->timeout, 5, 5 * HZ);

        return 0;
}

/**
 * xs_setup_tcp - Set up transport to use a TCP socket
 * @xprt: transport to set up
 * @to: timeout parameters
 *
 */
int xs_setup_tcp(struct rpc_xprt *xprt, struct rpc_timeout *to)
{
        size_t slot_table_size;

        dprintk("RPC:      setting up tcp-ipv4 transport...\n");

        xprt->max_reqs = xprt_tcp_slot_table_entries;
        slot_table_size = xprt->max_reqs * sizeof(xprt->slot[0]);
        xprt->slot = kmalloc(slot_table_size, GFP_KERNEL);
        if (xprt->slot == NULL)
                return -ENOMEM;
        memset(xprt->slot, 0, slot_table_size);

        xprt->prot = IPPROTO_TCP;
        xprt->port = XS_MAX_RESVPORT;
        xprt->stream = 1;
        xprt->nocong = 1;
        xprt->cwnd = RPC_MAXCWND(xprt);
        xprt->resvport = capable(CAP_NET_BIND_SERVICE) ? 1 : 0;
        xprt->max_payload = (1U << 31) - 1;

        INIT_WORK(&xprt->connect_worker, xs_connect_worker, xprt);

        xprt->ops = &xs_ops;

        if (to)
                xprt->timeout = *to;
        else
                xprt_set_timeout(&xprt->timeout, 2, 60 * HZ);

        return 0;
}