Merge branch 'linus' into cont_syslog
[safe/jmp/linux-2.6] / drivers / staging / hv / rndis_filter.c
1 /*
2  * Copyright (c) 2009, Microsoft Corporation.
3  *
4  * This program is free software; you can redistribute it and/or modify it
5  * under the terms and conditions of the GNU General Public License,
6  * version 2, as published by the Free Software Foundation.
7  *
8  * This program is distributed in the hope it will be useful, but WITHOUT
9  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
10  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License for
11  * more details.
12  *
13  * You should have received a copy of the GNU General Public License along with
14  * this program; if not, write to the Free Software Foundation, Inc., 59 Temple
15  * Place - Suite 330, Boston, MA 02111-1307 USA.
16  *
17  * Authors:
18  *   Haiyang Zhang <haiyangz@microsoft.com>
19  *   Hank Janssen  <hjanssen@microsoft.com>
20  */
21 #include <linux/kernel.h>
22 #include <linux/highmem.h>
23 #include <linux/slab.h>
24 #include <linux/io.h>
25 #include <linux/if_ether.h>
26
27 #include "osd.h"
28 #include "logging.h"
29 #include "netvsc_api.h"
30 #include "rndis_filter.h"
31
/* Data types */

/*
 * Global state of the RNDIS filter: a copy of the netvsc driver's original
 * dispatch handlers, saved by RndisFilterInit() before it overrides them,
 * so the filter can forward work to the inner netvsc layer.
 */
struct rndis_filter_driver_object {
	/* The original driver */
	struct netvsc_driver InnerDriver;
};
37
/*
 * Lifecycle of an rndis_device:
 *   RNDIS_DEV_UNINITIALIZED   - freshly allocated, halted, or init failed;
 *                               RndisFilterOnReceive() drops messages here
 *   RNDIS_DEV_INITIALIZING    - REMOTE_NDIS_INITIALIZE_MSG sent, waiting
 *                               for its completion
 *   RNDIS_DEV_INITIALIZED     - init completed with RNDIS_STATUS_SUCCESS
 *   RNDIS_DEV_DATAINITIALIZED - packet filter set by RndisFilterOpenDevice()
 */
enum rndis_device_state {
	RNDIS_DEV_UNINITIALIZED = 0,
	RNDIS_DEV_INITIALIZING,
	RNDIS_DEV_INITIALIZED,
	RNDIS_DEV_DATAINITIALIZED,
};
44
/*
 * Per-device RNDIS state.  RndisFilterOnDeviceAdd() stores it in the netvsc
 * device's Extension pointer and points NetDevice back at that netvsc device.
 */
struct rndis_device {
	/* Owning netvsc device (back pointer) */
	struct netvsc_device *NetDevice;

	/* Current lifecycle state, see enum rndis_device_state */
	enum rndis_device_state State;
	/* Result of the RNDIS_OID_GEN_MEDIA_CONNECT_STATUS query */
	u32 LinkStatus;
	/* Source of request ids; incremented once per allocated request */
	atomic_t NewRequestId;

	/* Protects RequestList */
	spinlock_t request_lock;
	/* Outstanding rndis_request objects, matched to responses by id */
	struct list_head RequestList;

	/* Result of the RNDIS_OID_802_3_PERMANENT_ADDRESS query */
	unsigned char HwMacAddr[ETH_ALEN];
};
57
/*
 * One in-flight control request (init/query/set/halt) together with the
 * storage for its response and the netvsc packet used to send it.
 */
struct rndis_request {
	/* Link in rndis_device.RequestList */
	struct list_head ListEntry;
	/* Signalled by RndisFilterReceiveResponse() when the reply arrives */
	struct osd_waitevent *WaitEvent;

	/*
	 * FIXME: We assumed a fixed size response here. If we do ever need to
	 * handle a bigger response, we can either define a max response
	 * message or add a response buffer variable above this field
	 */
	struct rndis_message ResponseMessage;

	/* Simplify allocation by having a netvsc packet inline */
	struct hv_netvsc_packet Packet;
	/*
	 * NOTE(review): presumably the backing storage for
	 * Packet.PageBuffers[0] (RndisFilterSendRequest() uses exactly one
	 * page buffer) - confirm hv_netvsc_packet ends in a PageBuffers
	 * array.  A RequestMessage straddling a page boundary would need a
	 * second page buffer that this layout does not provide.
	 */
	struct hv_page_buffer Buffer;
	/* FIXME: We assumed a fixed size request here. */
	struct rndis_message RequestMessage;
};
75
76
/*
 * Per-packet extension area of the filter; RndisFilterInit() advertises
 * its size to the netvsc driver as RequestExtSize.
 * NOTE(review): presumably the send path (RndisFilterOnSend, not shown
 * here) builds the RNDIS data message in Message and chains the caller's
 * completion through CompletionContext/OnCompletion - confirm.
 */
struct rndis_filter_packet {
	void *CompletionContext;
	void (*OnCompletion)(void *context);
	struct rndis_message Message;
};
82
83
/*
 * Forward declarations.  The first four are installed into the netvsc
 * driver by RndisFilterInit(); RndisFilterOnSendRequestCompletion is the
 * send-completion callback registered by RndisFilterSendRequest().
 * RndisFilterOnSendCompletion is presumably used by RndisFilterOnSend
 * (its definition is not shown here).
 */
static int RndisFilterOnDeviceAdd(struct hv_device *Device,
				  void *AdditionalInfo);

static int RndisFilterOnDeviceRemove(struct hv_device *Device);

static void RndisFilterOnCleanup(struct hv_driver *Driver);

static int RndisFilterOnSend(struct hv_device *Device,
			     struct hv_netvsc_packet *Packet);

static void RndisFilterOnSendCompletion(void *Context);

static void RndisFilterOnSendRequestCompletion(void *Context);
95
96 static void RndisFilterOnSendRequestCompletion(void *Context);
97
98
/*
 * The one and only filter instance: populated by RndisFilterInit() and
 * used by the handlers to reach the inner netvsc driver.
 */
static struct rndis_filter_driver_object gRndisFilter;
101
102 static struct rndis_device *GetRndisDevice(void)
103 {
104         struct rndis_device *device;
105
106         device = kzalloc(sizeof(struct rndis_device), GFP_KERNEL);
107         if (!device)
108                 return NULL;
109
110         spin_lock_init(&device->request_lock);
111
112         INIT_LIST_HEAD(&device->RequestList);
113
114         device->State = RNDIS_DEV_UNINITIALIZED;
115
116         return device;
117 }
118
119 static struct rndis_request *GetRndisRequest(struct rndis_device *Device,
120                                              u32 MessageType,
121                                              u32 MessageLength)
122 {
123         struct rndis_request *request;
124         struct rndis_message *rndisMessage;
125         struct rndis_set_request *set;
126         unsigned long flags;
127
128         request = kzalloc(sizeof(struct rndis_request), GFP_KERNEL);
129         if (!request)
130                 return NULL;
131
132         request->WaitEvent = osd_WaitEventCreate();
133         if (!request->WaitEvent) {
134                 kfree(request);
135                 return NULL;
136         }
137
138         rndisMessage = &request->RequestMessage;
139         rndisMessage->NdisMessageType = MessageType;
140         rndisMessage->MessageLength = MessageLength;
141
142         /*
143          * Set the request id. This field is always after the rndis header for
144          * request/response packet types so we just used the SetRequest as a
145          * template
146          */
147         set = &rndisMessage->Message.SetRequest;
148         set->RequestId = atomic_inc_return(&Device->NewRequestId);
149
150         /* Add to the request list */
151         spin_lock_irqsave(&Device->request_lock, flags);
152         list_add_tail(&request->ListEntry, &Device->RequestList);
153         spin_unlock_irqrestore(&Device->request_lock, flags);
154
155         return request;
156 }
157
158 static void PutRndisRequest(struct rndis_device *Device,
159                             struct rndis_request *Request)
160 {
161         unsigned long flags;
162
163         spin_lock_irqsave(&Device->request_lock, flags);
164         list_del(&Request->ListEntry);
165         spin_unlock_irqrestore(&Device->request_lock, flags);
166
167         kfree(Request->WaitEvent);
168         kfree(Request);
169 }
170
171 static void DumpRndisMessage(struct rndis_message *RndisMessage)
172 {
173         switch (RndisMessage->NdisMessageType) {
174         case REMOTE_NDIS_PACKET_MSG:
175                 DPRINT_DBG(NETVSC, "REMOTE_NDIS_PACKET_MSG (len %u, "
176                            "data offset %u data len %u, # oob %u, "
177                            "oob offset %u, oob len %u, pkt offset %u, "
178                            "pkt len %u",
179                            RndisMessage->MessageLength,
180                            RndisMessage->Message.Packet.DataOffset,
181                            RndisMessage->Message.Packet.DataLength,
182                            RndisMessage->Message.Packet.NumOOBDataElements,
183                            RndisMessage->Message.Packet.OOBDataOffset,
184                            RndisMessage->Message.Packet.OOBDataLength,
185                            RndisMessage->Message.Packet.PerPacketInfoOffset,
186                            RndisMessage->Message.Packet.PerPacketInfoLength);
187                 break;
188
189         case REMOTE_NDIS_INITIALIZE_CMPLT:
190                 DPRINT_DBG(NETVSC, "REMOTE_NDIS_INITIALIZE_CMPLT "
191                         "(len %u, id 0x%x, status 0x%x, major %d, minor %d, "
192                         "device flags %d, max xfer size 0x%x, max pkts %u, "
193                         "pkt aligned %u)",
194                         RndisMessage->MessageLength,
195                         RndisMessage->Message.InitializeComplete.RequestId,
196                         RndisMessage->Message.InitializeComplete.Status,
197                         RndisMessage->Message.InitializeComplete.MajorVersion,
198                         RndisMessage->Message.InitializeComplete.MinorVersion,
199                         RndisMessage->Message.InitializeComplete.DeviceFlags,
200                         RndisMessage->Message.InitializeComplete.MaxTransferSize,
201                         RndisMessage->Message.InitializeComplete.MaxPacketsPerMessage,
202                         RndisMessage->Message.InitializeComplete.PacketAlignmentFactor);
203                 break;
204
205         case REMOTE_NDIS_QUERY_CMPLT:
206                 DPRINT_DBG(NETVSC, "REMOTE_NDIS_QUERY_CMPLT "
207                         "(len %u, id 0x%x, status 0x%x, buf len %u, "
208                         "buf offset %u)",
209                         RndisMessage->MessageLength,
210                         RndisMessage->Message.QueryComplete.RequestId,
211                         RndisMessage->Message.QueryComplete.Status,
212                         RndisMessage->Message.QueryComplete.InformationBufferLength,
213                         RndisMessage->Message.QueryComplete.InformationBufferOffset);
214                 break;
215
216         case REMOTE_NDIS_SET_CMPLT:
217                 DPRINT_DBG(NETVSC,
218                         "REMOTE_NDIS_SET_CMPLT (len %u, id 0x%x, status 0x%x)",
219                         RndisMessage->MessageLength,
220                         RndisMessage->Message.SetComplete.RequestId,
221                         RndisMessage->Message.SetComplete.Status);
222                 break;
223
224         case REMOTE_NDIS_INDICATE_STATUS_MSG:
225                 DPRINT_DBG(NETVSC, "REMOTE_NDIS_INDICATE_STATUS_MSG "
226                         "(len %u, status 0x%x, buf len %u, buf offset %u)",
227                         RndisMessage->MessageLength,
228                         RndisMessage->Message.IndicateStatus.Status,
229                         RndisMessage->Message.IndicateStatus.StatusBufferLength,
230                         RndisMessage->Message.IndicateStatus.StatusBufferOffset);
231                 break;
232
233         default:
234                 DPRINT_DBG(NETVSC, "0x%x (len %u)",
235                         RndisMessage->NdisMessageType,
236                         RndisMessage->MessageLength);
237                 break;
238         }
239 }
240
241 static int RndisFilterSendRequest(struct rndis_device *Device,
242                                   struct rndis_request *Request)
243 {
244         int ret;
245         struct hv_netvsc_packet *packet;
246
247         DPRINT_ENTER(NETVSC);
248
249         /* Setup the packet to send it */
250         packet = &Request->Packet;
251
252         packet->IsDataPacket = false;
253         packet->TotalDataBufferLength = Request->RequestMessage.MessageLength;
254         packet->PageBufferCount = 1;
255
256         packet->PageBuffers[0].Pfn = virt_to_phys(&Request->RequestMessage) >>
257                                         PAGE_SHIFT;
258         packet->PageBuffers[0].Length = Request->RequestMessage.MessageLength;
259         packet->PageBuffers[0].Offset =
260                 (unsigned long)&Request->RequestMessage & (PAGE_SIZE - 1);
261
262         packet->Completion.Send.SendCompletionContext = Request;/* packet; */
263         packet->Completion.Send.OnSendCompletion =
264                 RndisFilterOnSendRequestCompletion;
265         packet->Completion.Send.SendCompletionTid = (unsigned long)Device;
266
267         ret = gRndisFilter.InnerDriver.OnSend(Device->NetDevice->Device, packet);
268         DPRINT_EXIT(NETVSC);
269         return ret;
270 }
271
272 static void RndisFilterReceiveResponse(struct rndis_device *Device,
273                                        struct rndis_message *Response)
274 {
275         struct rndis_request *request = NULL;
276         bool found = false;
277         unsigned long flags;
278
279         DPRINT_ENTER(NETVSC);
280
281         spin_lock_irqsave(&Device->request_lock, flags);
282         list_for_each_entry(request, &Device->RequestList, ListEntry) {
283                 /*
284                  * All request/response message contains RequestId as the 1st
285                  * field
286                  */
287                 if (request->RequestMessage.Message.InitializeRequest.RequestId
288                     == Response->Message.InitializeComplete.RequestId) {
289                         DPRINT_DBG(NETVSC, "found rndis request for "
290                                 "this response (id 0x%x req type 0x%x res "
291                                 "type 0x%x)",
292                                 request->RequestMessage.Message.InitializeRequest.RequestId,
293                                 request->RequestMessage.NdisMessageType,
294                                 Response->NdisMessageType);
295
296                         found = true;
297                         break;
298                 }
299         }
300         spin_unlock_irqrestore(&Device->request_lock, flags);
301
302         if (found) {
303                 if (Response->MessageLength <= sizeof(struct rndis_message)) {
304                         memcpy(&request->ResponseMessage, Response,
305                                Response->MessageLength);
306                 } else {
307                         DPRINT_ERR(NETVSC, "rndis response buffer overflow "
308                                   "detected (size %u max %zu)",
309                                   Response->MessageLength,
310                                   sizeof(struct rndis_filter_packet));
311
312                         if (Response->NdisMessageType ==
313                             REMOTE_NDIS_RESET_CMPLT) {
314                                 /* does not have a request id field */
315                                 request->ResponseMessage.Message.ResetComplete.Status = STATUS_BUFFER_OVERFLOW;
316                         } else {
317                                 request->ResponseMessage.Message.InitializeComplete.Status = STATUS_BUFFER_OVERFLOW;
318                         }
319                 }
320
321                 osd_WaitEventSet(request->WaitEvent);
322         } else {
323                 DPRINT_ERR(NETVSC, "no rndis request found for this response "
324                            "(id 0x%x res type 0x%x)",
325                            Response->Message.InitializeComplete.RequestId,
326                            Response->NdisMessageType);
327         }
328
329         DPRINT_EXIT(NETVSC);
330 }
331
332 static void RndisFilterReceiveIndicateStatus(struct rndis_device *Device,
333                                              struct rndis_message *Response)
334 {
335         struct rndis_indicate_status *indicate =
336                         &Response->Message.IndicateStatus;
337
338         if (indicate->Status == RNDIS_STATUS_MEDIA_CONNECT) {
339                 gRndisFilter.InnerDriver.OnLinkStatusChanged(Device->NetDevice->Device, 1);
340         } else if (indicate->Status == RNDIS_STATUS_MEDIA_DISCONNECT) {
341                 gRndisFilter.InnerDriver.OnLinkStatusChanged(Device->NetDevice->Device, 0);
342         } else {
343                 /*
344                  * TODO:
345                  */
346         }
347 }
348
349 static void RndisFilterReceiveData(struct rndis_device *Device,
350                                    struct rndis_message *Message,
351                                    struct hv_netvsc_packet *Packet)
352 {
353         struct rndis_packet *rndisPacket;
354         u32 dataOffset;
355
356         DPRINT_ENTER(NETVSC);
357
358         /* empty ethernet frame ?? */
359         /* ASSERT(Packet->PageBuffers[0].Length > */
360         /*      RNDIS_MESSAGE_SIZE(struct rndis_packet)); */
361
362         rndisPacket = &Message->Message.Packet;
363
364         /*
365          * FIXME: Handle multiple rndis pkt msgs that maybe enclosed in this
366          * netvsc packet (ie TotalDataBufferLength != MessageLength)
367          */
368
369         /* Remove the rndis header and pass it back up the stack */
370         dataOffset = RNDIS_HEADER_SIZE + rndisPacket->DataOffset;
371
372         Packet->TotalDataBufferLength -= dataOffset;
373         Packet->PageBuffers[0].Offset += dataOffset;
374         Packet->PageBuffers[0].Length -= dataOffset;
375
376         Packet->IsDataPacket = true;
377
378         gRndisFilter.InnerDriver.OnReceiveCallback(Device->NetDevice->Device,
379                                                    Packet);
380
381         DPRINT_EXIT(NETVSC);
382 }
383
384 static int RndisFilterOnReceive(struct hv_device *Device,
385                                 struct hv_netvsc_packet *Packet)
386 {
387         struct netvsc_device *netDevice = Device->Extension;
388         struct rndis_device *rndisDevice;
389         struct rndis_message rndisMessage;
390         struct rndis_message *rndisHeader;
391
392         DPRINT_ENTER(NETVSC);
393
394         if (!netDevice)
395                 return -EINVAL;
396
397         /* Make sure the rndis device state is initialized */
398         if (!netDevice->Extension) {
399                 DPRINT_ERR(NETVSC, "got rndis message but no rndis device..."
400                           "dropping this message!");
401                 DPRINT_EXIT(NETVSC);
402                 return -1;
403         }
404
405         rndisDevice = (struct rndis_device *)netDevice->Extension;
406         if (rndisDevice->State == RNDIS_DEV_UNINITIALIZED) {
407                 DPRINT_ERR(NETVSC, "got rndis message but rndis device "
408                            "uninitialized...dropping this message!");
409                 DPRINT_EXIT(NETVSC);
410                 return -1;
411         }
412
413         rndisHeader = (struct rndis_message *)kmap_atomic(
414                         pfn_to_page(Packet->PageBuffers[0].Pfn), KM_IRQ0);
415
416         rndisHeader = (void *)((unsigned long)rndisHeader +
417                         Packet->PageBuffers[0].Offset);
418
419         /* Make sure we got a valid rndis message */
420         /*
421          * FIXME: There seems to be a bug in set completion msg where its
422          * MessageLength is 16 bytes but the ByteCount field in the xfer page
423          * range shows 52 bytes
424          * */
425 #if 0
426         if (Packet->TotalDataBufferLength != rndisHeader->MessageLength) {
427                 kunmap_atomic(rndisHeader - Packet->PageBuffers[0].Offset,
428                               KM_IRQ0);
429
430                 DPRINT_ERR(NETVSC, "invalid rndis message? (expected %u "
431                            "bytes got %u)...dropping this message!",
432                            rndisHeader->MessageLength,
433                            Packet->TotalDataBufferLength);
434                 DPRINT_EXIT(NETVSC);
435                 return -1;
436         }
437 #endif
438
439         if ((rndisHeader->NdisMessageType != REMOTE_NDIS_PACKET_MSG) &&
440             (rndisHeader->MessageLength > sizeof(struct rndis_message))) {
441                 DPRINT_ERR(NETVSC, "incoming rndis message buffer overflow "
442                            "detected (got %u, max %zu)...marking it an error!",
443                            rndisHeader->MessageLength,
444                            sizeof(struct rndis_message));
445         }
446
447         memcpy(&rndisMessage, rndisHeader,
448                 (rndisHeader->MessageLength > sizeof(struct rndis_message)) ?
449                         sizeof(struct rndis_message) :
450                         rndisHeader->MessageLength);
451
452         kunmap_atomic(rndisHeader - Packet->PageBuffers[0].Offset, KM_IRQ0);
453
454         DumpRndisMessage(&rndisMessage);
455
456         switch (rndisMessage.NdisMessageType) {
457         case REMOTE_NDIS_PACKET_MSG:
458                 /* data msg */
459                 RndisFilterReceiveData(rndisDevice, &rndisMessage, Packet);
460                 break;
461
462         case REMOTE_NDIS_INITIALIZE_CMPLT:
463         case REMOTE_NDIS_QUERY_CMPLT:
464         case REMOTE_NDIS_SET_CMPLT:
465         /* case REMOTE_NDIS_RESET_CMPLT: */
466         /* case REMOTE_NDIS_KEEPALIVE_CMPLT: */
467                 /* completion msgs */
468                 RndisFilterReceiveResponse(rndisDevice, &rndisMessage);
469                 break;
470
471         case REMOTE_NDIS_INDICATE_STATUS_MSG:
472                 /* notification msgs */
473                 RndisFilterReceiveIndicateStatus(rndisDevice, &rndisMessage);
474                 break;
475         default:
476                 DPRINT_ERR(NETVSC, "unhandled rndis message (type %u len %u)",
477                            rndisMessage.NdisMessageType,
478                            rndisMessage.MessageLength);
479                 break;
480         }
481
482         DPRINT_EXIT(NETVSC);
483         return 0;
484 }
485
486 static int RndisFilterQueryDevice(struct rndis_device *Device, u32 Oid,
487                                   void *Result, u32 *ResultSize)
488 {
489         struct rndis_request *request;
490         u32 inresultSize = *ResultSize;
491         struct rndis_query_request *query;
492         struct rndis_query_complete *queryComplete;
493         int ret = 0;
494
495         DPRINT_ENTER(NETVSC);
496
497         if (!Result)
498                 return -EINVAL;
499
500         *ResultSize = 0;
501         request = GetRndisRequest(Device, REMOTE_NDIS_QUERY_MSG,
502                         RNDIS_MESSAGE_SIZE(struct rndis_query_request));
503         if (!request) {
504                 ret = -1;
505                 goto Cleanup;
506         }
507
508         /* Setup the rndis query */
509         query = &request->RequestMessage.Message.QueryRequest;
510         query->Oid = Oid;
511         query->InformationBufferOffset = sizeof(struct rndis_query_request);
512         query->InformationBufferLength = 0;
513         query->DeviceVcHandle = 0;
514
515         ret = RndisFilterSendRequest(Device, request);
516         if (ret != 0)
517                 goto Cleanup;
518
519         osd_WaitEventWait(request->WaitEvent);
520
521         /* Copy the response back */
522         queryComplete = &request->ResponseMessage.Message.QueryComplete;
523
524         if (queryComplete->InformationBufferLength > inresultSize) {
525                 ret = -1;
526                 goto Cleanup;
527         }
528
529         memcpy(Result,
530                (void *)((unsigned long)queryComplete +
531                          queryComplete->InformationBufferOffset),
532                queryComplete->InformationBufferLength);
533
534         *ResultSize = queryComplete->InformationBufferLength;
535
536 Cleanup:
537         if (request)
538                 PutRndisRequest(Device, request);
539         DPRINT_EXIT(NETVSC);
540
541         return ret;
542 }
543
544 static int RndisFilterQueryDeviceMac(struct rndis_device *Device)
545 {
546         u32 size = ETH_ALEN;
547
548         return RndisFilterQueryDevice(Device,
549                                       RNDIS_OID_802_3_PERMANENT_ADDRESS,
550                                       Device->HwMacAddr, &size);
551 }
552
553 static int RndisFilterQueryDeviceLinkStatus(struct rndis_device *Device)
554 {
555         u32 size = sizeof(u32);
556
557         return RndisFilterQueryDevice(Device,
558                                       RNDIS_OID_GEN_MEDIA_CONNECT_STATUS,
559                                       &Device->LinkStatus, &size);
560 }
561
562 static int RndisFilterSetPacketFilter(struct rndis_device *Device,
563                                       u32 NewFilter)
564 {
565         struct rndis_request *request;
566         struct rndis_set_request *set;
567         struct rndis_set_complete *setComplete;
568         u32 status;
569         int ret;
570
571         DPRINT_ENTER(NETVSC);
572
573         /* ASSERT(RNDIS_MESSAGE_SIZE(struct rndis_set_request) + sizeof(u32) <= */
574         /*      sizeof(struct rndis_message)); */
575
576         request = GetRndisRequest(Device, REMOTE_NDIS_SET_MSG,
577                         RNDIS_MESSAGE_SIZE(struct rndis_set_request) +
578                         sizeof(u32));
579         if (!request) {
580                 ret = -1;
581                 goto Cleanup;
582         }
583
584         /* Setup the rndis set */
585         set = &request->RequestMessage.Message.SetRequest;
586         set->Oid = RNDIS_OID_GEN_CURRENT_PACKET_FILTER;
587         set->InformationBufferLength = sizeof(u32);
588         set->InformationBufferOffset = sizeof(struct rndis_set_request);
589
590         memcpy((void *)(unsigned long)set + sizeof(struct rndis_set_request),
591                &NewFilter, sizeof(u32));
592
593         ret = RndisFilterSendRequest(Device, request);
594         if (ret != 0)
595                 goto Cleanup;
596
597         ret = osd_WaitEventWaitEx(request->WaitEvent, 2000/*2sec*/);
598         if (!ret) {
599                 ret = -1;
600                 DPRINT_ERR(NETVSC, "timeout before we got a set response...");
601                 /*
602                  * We cant deallocate the request since we may still receive a
603                  * send completion for it.
604                  */
605                 goto Exit;
606         } else {
607                 if (ret > 0)
608                         ret = 0;
609                 setComplete = &request->ResponseMessage.Message.SetComplete;
610                 status = setComplete->Status;
611         }
612
613 Cleanup:
614         if (request)
615                 PutRndisRequest(Device, request);
616 Exit:
617         DPRINT_EXIT(NETVSC);
618
619         return ret;
620 }
621
622 int RndisFilterInit(struct netvsc_driver *Driver)
623 {
624         DPRINT_ENTER(NETVSC);
625
626         DPRINT_DBG(NETVSC, "sizeof(struct rndis_filter_packet) == %zd",
627                    sizeof(struct rndis_filter_packet));
628
629         Driver->RequestExtSize = sizeof(struct rndis_filter_packet);
630
631         /* Driver->Context = rndisDriver; */
632
633         memset(&gRndisFilter, 0, sizeof(struct rndis_filter_driver_object));
634
635         /*rndisDriver->Driver = Driver;
636
637         ASSERT(Driver->OnLinkStatusChanged);
638         rndisDriver->OnLinkStatusChanged = Driver->OnLinkStatusChanged;*/
639
640         /* Save the original dispatch handlers before we override it */
641         gRndisFilter.InnerDriver.Base.OnDeviceAdd = Driver->Base.OnDeviceAdd;
642         gRndisFilter.InnerDriver.Base.OnDeviceRemove =
643                                         Driver->Base.OnDeviceRemove;
644         gRndisFilter.InnerDriver.Base.OnCleanup = Driver->Base.OnCleanup;
645
646         /* ASSERT(Driver->OnSend); */
647         /* ASSERT(Driver->OnReceiveCallback); */
648         gRndisFilter.InnerDriver.OnSend = Driver->OnSend;
649         gRndisFilter.InnerDriver.OnReceiveCallback = Driver->OnReceiveCallback;
650         gRndisFilter.InnerDriver.OnLinkStatusChanged =
651                                         Driver->OnLinkStatusChanged;
652
653         /* Override */
654         Driver->Base.OnDeviceAdd = RndisFilterOnDeviceAdd;
655         Driver->Base.OnDeviceRemove = RndisFilterOnDeviceRemove;
656         Driver->Base.OnCleanup = RndisFilterOnCleanup;
657         Driver->OnSend = RndisFilterOnSend;
658         /* Driver->QueryLinkStatus = RndisFilterQueryDeviceLinkStatus; */
659         Driver->OnReceiveCallback = RndisFilterOnReceive;
660
661         DPRINT_EXIT(NETVSC);
662
663         return 0;
664 }
665
666 static int RndisFilterInitDevice(struct rndis_device *Device)
667 {
668         struct rndis_request *request;
669         struct rndis_initialize_request *init;
670         struct rndis_initialize_complete *initComplete;
671         u32 status;
672         int ret;
673
674         DPRINT_ENTER(NETVSC);
675
676         request = GetRndisRequest(Device, REMOTE_NDIS_INITIALIZE_MSG,
677                         RNDIS_MESSAGE_SIZE(struct rndis_initialize_request));
678         if (!request) {
679                 ret = -1;
680                 goto Cleanup;
681         }
682
683         /* Setup the rndis set */
684         init = &request->RequestMessage.Message.InitializeRequest;
685         init->MajorVersion = RNDIS_MAJOR_VERSION;
686         init->MinorVersion = RNDIS_MINOR_VERSION;
687         /* FIXME: Use 1536 - rounded ethernet frame size */
688         init->MaxTransferSize = 2048;
689
690         Device->State = RNDIS_DEV_INITIALIZING;
691
692         ret = RndisFilterSendRequest(Device, request);
693         if (ret != 0) {
694                 Device->State = RNDIS_DEV_UNINITIALIZED;
695                 goto Cleanup;
696         }
697
698         osd_WaitEventWait(request->WaitEvent);
699
700         initComplete = &request->ResponseMessage.Message.InitializeComplete;
701         status = initComplete->Status;
702         if (status == RNDIS_STATUS_SUCCESS) {
703                 Device->State = RNDIS_DEV_INITIALIZED;
704                 ret = 0;
705         } else {
706                 Device->State = RNDIS_DEV_UNINITIALIZED;
707                 ret = -1;
708         }
709
710 Cleanup:
711         if (request)
712                 PutRndisRequest(Device, request);
713         DPRINT_EXIT(NETVSC);
714
715         return ret;
716 }
717
718 static void RndisFilterHaltDevice(struct rndis_device *Device)
719 {
720         struct rndis_request *request;
721         struct rndis_halt_request *halt;
722
723         DPRINT_ENTER(NETVSC);
724
725         /* Attempt to do a rndis device halt */
726         request = GetRndisRequest(Device, REMOTE_NDIS_HALT_MSG,
727                                 RNDIS_MESSAGE_SIZE(struct rndis_halt_request));
728         if (!request)
729                 goto Cleanup;
730
731         /* Setup the rndis set */
732         halt = &request->RequestMessage.Message.HaltRequest;
733         halt->RequestId = atomic_inc_return(&Device->NewRequestId);
734
735         /* Ignore return since this msg is optional. */
736         RndisFilterSendRequest(Device, request);
737
738         Device->State = RNDIS_DEV_UNINITIALIZED;
739
740 Cleanup:
741         if (request)
742                 PutRndisRequest(Device, request);
743         DPRINT_EXIT(NETVSC);
744         return;
745 }
746
747 static int RndisFilterOpenDevice(struct rndis_device *Device)
748 {
749         int ret;
750
751         DPRINT_ENTER(NETVSC);
752
753         if (Device->State != RNDIS_DEV_INITIALIZED)
754                 return 0;
755
756         ret = RndisFilterSetPacketFilter(Device,
757                                          NDIS_PACKET_TYPE_BROADCAST |
758                                          NDIS_PACKET_TYPE_ALL_MULTICAST |
759                                          NDIS_PACKET_TYPE_DIRECTED);
760         if (ret == 0)
761                 Device->State = RNDIS_DEV_DATAINITIALIZED;
762
763         DPRINT_EXIT(NETVSC);
764         return ret;
765 }
766
767 static int RndisFilterCloseDevice(struct rndis_device *Device)
768 {
769         int ret;
770
771         DPRINT_ENTER(NETVSC);
772
773         if (Device->State != RNDIS_DEV_DATAINITIALIZED)
774                 return 0;
775
776         ret = RndisFilterSetPacketFilter(Device, 0);
777         if (ret == 0)
778                 Device->State = RNDIS_DEV_INITIALIZED;
779
780         DPRINT_EXIT(NETVSC);
781
782         return ret;
783 }
784
785 static int RndisFilterOnDeviceAdd(struct hv_device *Device,
786                                   void *AdditionalInfo)
787 {
788         int ret;
789         struct netvsc_device *netDevice;
790         struct rndis_device *rndisDevice;
791         struct netvsc_device_info *deviceInfo = AdditionalInfo;
792
793         DPRINT_ENTER(NETVSC);
794
795         rndisDevice = GetRndisDevice();
796         if (!rndisDevice) {
797                 DPRINT_EXIT(NETVSC);
798                 return -1;
799         }
800
801         DPRINT_DBG(NETVSC, "rndis device object allocated - %p", rndisDevice);
802
803         /*
804          * Let the inner driver handle this first to create the netvsc channel
805          * NOTE! Once the channel is created, we may get a receive callback
806          * (RndisFilterOnReceive()) before this call is completed
807          */
808         ret = gRndisFilter.InnerDriver.Base.OnDeviceAdd(Device, AdditionalInfo);
809         if (ret != 0) {
810                 kfree(rndisDevice);
811                 DPRINT_EXIT(NETVSC);
812                 return ret;
813         }
814
815
816         /* Initialize the rndis device */
817         netDevice = Device->Extension;
818         /* ASSERT(netDevice); */
819         /* ASSERT(netDevice->Device); */
820
821         netDevice->Extension = rndisDevice;
822         rndisDevice->NetDevice = netDevice;
823
824         /* Send the rndis initialization message */
825         ret = RndisFilterInitDevice(rndisDevice);
826         if (ret != 0) {
827                 /*
828                  * TODO: If rndis init failed, we will need to shut down the
829                  * channel
830                  */
831         }
832
833         /* Get the mac address */
834         ret = RndisFilterQueryDeviceMac(rndisDevice);
835         if (ret != 0) {
836                 /*
837                  * TODO: shutdown rndis device and the channel
838                  */
839         }
840
841         DPRINT_INFO(NETVSC, "Device 0x%p mac addr %pM",
842                     rndisDevice, rndisDevice->HwMacAddr);
843
844         memcpy(deviceInfo->MacAddr, rndisDevice->HwMacAddr, ETH_ALEN);
845
846         RndisFilterQueryDeviceLinkStatus(rndisDevice);
847
848         deviceInfo->LinkState = rndisDevice->LinkStatus;
849         DPRINT_INFO(NETVSC, "Device 0x%p link state %s", rndisDevice,
850                     ((deviceInfo->LinkState) ? ("down") : ("up")));
851
852         DPRINT_EXIT(NETVSC);
853
854         return ret;
855 }
856
857 static int RndisFilterOnDeviceRemove(struct hv_device *Device)
858 {
859         struct netvsc_device *netDevice = Device->Extension;
860         struct rndis_device *rndisDevice = netDevice->Extension;
861
862         DPRINT_ENTER(NETVSC);
863
864         /* Halt and release the rndis device */
865         RndisFilterHaltDevice(rndisDevice);
866
867         kfree(rndisDevice);
868         netDevice->Extension = NULL;
869
870         /* Pass control to inner driver to remove the device */
871         gRndisFilter.InnerDriver.Base.OnDeviceRemove(Device);
872
873         DPRINT_EXIT(NETVSC);
874
875         return 0;
876 }
877
878 static void RndisFilterOnCleanup(struct hv_driver *Driver)
879 {
880         DPRINT_ENTER(NETVSC);
881
882         DPRINT_EXIT(NETVSC);
883 }
884
885 int RndisFilterOnOpen(struct hv_device *Device)
886 {
887         int ret;
888         struct netvsc_device *netDevice = Device->Extension;
889
890         DPRINT_ENTER(NETVSC);
891
892         if (!netDevice)
893                 return -EINVAL;
894
895         ret = RndisFilterOpenDevice(netDevice->Extension);
896
897         DPRINT_EXIT(NETVSC);
898
899         return ret;
900 }
901
902 int RndisFilterOnClose(struct hv_device *Device)
903 {
904         int ret;
905         struct netvsc_device *netDevice = Device->Extension;
906
907         DPRINT_ENTER(NETVSC);
908
909         if (!netDevice)
910                 return -EINVAL;
911
912         ret = RndisFilterCloseDevice(netDevice->Extension);
913
914         DPRINT_EXIT(NETVSC);
915
916         return ret;
917 }
918
919 static int RndisFilterOnSend(struct hv_device *Device,
920                              struct hv_netvsc_packet *Packet)
921 {
922         int ret;
923         struct rndis_filter_packet *filterPacket;
924         struct rndis_message *rndisMessage;
925         struct rndis_packet *rndisPacket;
926         u32 rndisMessageSize;
927
928         DPRINT_ENTER(NETVSC);
929
930         /* Add the rndis header */
931         filterPacket = (struct rndis_filter_packet *)Packet->Extension;
932         /* ASSERT(filterPacket); */
933
934         memset(filterPacket, 0, sizeof(struct rndis_filter_packet));
935
936         rndisMessage = &filterPacket->Message;
937         rndisMessageSize = RNDIS_MESSAGE_SIZE(struct rndis_packet);
938
939         rndisMessage->NdisMessageType = REMOTE_NDIS_PACKET_MSG;
940         rndisMessage->MessageLength = Packet->TotalDataBufferLength +
941                                       rndisMessageSize;
942
943         rndisPacket = &rndisMessage->Message.Packet;
944         rndisPacket->DataOffset = sizeof(struct rndis_packet);
945         rndisPacket->DataLength = Packet->TotalDataBufferLength;
946
947         Packet->IsDataPacket = true;
948         Packet->PageBuffers[0].Pfn = virt_to_phys(rndisMessage) >> PAGE_SHIFT;
949         Packet->PageBuffers[0].Offset =
950                         (unsigned long)rndisMessage & (PAGE_SIZE-1);
951         Packet->PageBuffers[0].Length = rndisMessageSize;
952
953         /* Save the packet send completion and context */
954         filterPacket->OnCompletion = Packet->Completion.Send.OnSendCompletion;
955         filterPacket->CompletionContext =
956                                 Packet->Completion.Send.SendCompletionContext;
957
958         /* Use ours */
959         Packet->Completion.Send.OnSendCompletion = RndisFilterOnSendCompletion;
960         Packet->Completion.Send.SendCompletionContext = filterPacket;
961
962         ret = gRndisFilter.InnerDriver.OnSend(Device, Packet);
963         if (ret != 0) {
964                 /*
965                  * Reset the completion to originals to allow retries from
966                  * above
967                  */
968                 Packet->Completion.Send.OnSendCompletion =
969                                 filterPacket->OnCompletion;
970                 Packet->Completion.Send.SendCompletionContext =
971                                 filterPacket->CompletionContext;
972         }
973
974         DPRINT_EXIT(NETVSC);
975
976         return ret;
977 }
978
979 static void RndisFilterOnSendCompletion(void *Context)
980 {
981         struct rndis_filter_packet *filterPacket = Context;
982
983         DPRINT_ENTER(NETVSC);
984
985         /* Pass it back to the original handler */
986         filterPacket->OnCompletion(filterPacket->CompletionContext);
987
988         DPRINT_EXIT(NETVSC);
989 }
990
991
992 static void RndisFilterOnSendRequestCompletion(void *Context)
993 {
994         DPRINT_ENTER(NETVSC);
995
996         /* Noop */
997         DPRINT_EXIT(NETVSC);
998 }