Staging: hv: move StorVscApi.h
[safe/jmp/linux-2.6] / drivers / staging / hv / RndisFilter.c
1 /*
2  *
3  * Copyright (c) 2009, Microsoft Corporation.
4  *
5  * This program is free software; you can redistribute it and/or modify it
6  * under the terms and conditions of the GNU General Public License,
7  * version 2, as published by the Free Software Foundation.
8  *
9  * This program is distributed in the hope it will be useful, but WITHOUT
10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License for
12  * more details.
13  *
14  * You should have received a copy of the GNU General Public License along with
15  * this program; if not, write to the Free Software Foundation, Inc., 59 Temple
16  * Place - Suite 330, Boston, MA 02111-1307 USA.
17  *
18  * Authors:
19  *   Haiyang Zhang <haiyangz@microsoft.com>
20  *   Hank Janssen  <hjanssen@microsoft.com>
21  *
22  */
23
24 #include <linux/kernel.h>
25 #include <linux/highmem.h>
26 #include <asm/kmap_types.h>
27 #include <asm/io.h>
28
29 #include "osd.h"
30 #include "logging.h"
31 #include "NetVscApi.h"
32 #include "RndisFilter.h"
33
34
35 /* Data types */
36
37
38 typedef struct _RNDIS_FILTER_DRIVER_OBJECT {
39         /* The original driver */
40         struct netvsc_driver InnerDriver;
41
42 } RNDIS_FILTER_DRIVER_OBJECT;
43
/* Lifecycle of an RNDIS device as tracked by this filter. */
typedef enum {
	RNDIS_DEV_UNINITIALIZED = 0,	/* not initialized, init failed, or halted */
	RNDIS_DEV_INITIALIZING,		/* INITIALIZE request sent, awaiting completion */
	RNDIS_DEV_INITIALIZED,		/* INITIALIZE succeeded; packet filter not set */
	RNDIS_DEV_DATAINITIALIZED,	/* packet filter set; data path is open */
} RNDIS_DEVICE_STATE;
50
51 typedef struct _RNDIS_DEVICE {
52         struct NETVSC_DEVICE                    *NetDevice;
53
54         RNDIS_DEVICE_STATE              State;
55         u32                                     LinkStatus;
56         atomic_t NewRequestId;
57
58         spinlock_t request_lock;
59         LIST_ENTRY                              RequestList;
60
61         unsigned char                                   HwMacAddr[HW_MACADDR_LEN];
62 } RNDIS_DEVICE;
63
64
65 typedef struct _RNDIS_REQUEST {
66         LIST_ENTRY                                      ListEntry;
67         struct osd_waitevent *WaitEvent;
68
69         /* FIXME: We assumed a fixed size response here. If we do ever need to handle a bigger response, */
70         /* we can either define a max response message or add a response buffer variable above this field */
71         struct rndis_message ResponseMessage;
72
73         /* Simplify allocation by having a netvsc packet inline */
74         struct hv_netvsc_packet Packet;
75         struct hv_page_buffer Buffer;
76         /* FIXME: We assumed a fixed size request here. */
77         struct rndis_message RequestMessage;
78 } RNDIS_REQUEST;
79
80
81 typedef struct _RNDIS_FILTER_PACKET {
82         void                                            *CompletionContext;
83         PFN_ON_SENDRECVCOMPLETION       OnCompletion;
84
85         struct rndis_message Message;
86 } RNDIS_FILTER_PACKET;
87
88
89 /* Internal routines */
90
91 static int
92 RndisFilterSendRequest(
93         RNDIS_DEVICE    *Device,
94         RNDIS_REQUEST   *Request
95         );
96
97 static void RndisFilterReceiveResponse(RNDIS_DEVICE *Device,
98                                        struct rndis_message *Response);
99
100 static void RndisFilterReceiveIndicateStatus(RNDIS_DEVICE *Device,
101                                              struct rndis_message *Response);
102
103 static void RndisFilterReceiveData(RNDIS_DEVICE *Device,
104                                    struct rndis_message *Message,
105                                    struct hv_netvsc_packet *Packet);
106
107 static int RndisFilterOnReceive(
108         struct hv_device *Device,
109         struct hv_netvsc_packet *Packet
110         );
111
112 static int
113 RndisFilterQueryDevice(
114         RNDIS_DEVICE    *Device,
115         u32                     Oid,
116         void                    *Result,
117         u32                     *ResultSize
118         );
119
120 static inline int
121 RndisFilterQueryDeviceMac(
122         RNDIS_DEVICE    *Device
123         );
124
125 static inline int
126 RndisFilterQueryDeviceLinkStatus(
127         RNDIS_DEVICE    *Device
128         );
129
130 static int
131 RndisFilterSetPacketFilter(
132         RNDIS_DEVICE    *Device,
133         u32                     NewFilter
134         );
135
136 static int
137 RndisFilterInitDevice(
138         RNDIS_DEVICE            *Device
139         );
140
141 static int
142 RndisFilterOpenDevice(
143         RNDIS_DEVICE            *Device
144         );
145
146 static int
147 RndisFilterCloseDevice(
148         RNDIS_DEVICE            *Device
149         );
150
151 static int
152 RndisFilterOnDeviceAdd(
153         struct hv_device *Device,
154         void                    *AdditionalInfo
155         );
156
157 static int
158 RndisFilterOnDeviceRemove(
159         struct hv_device *Device
160         );
161
162 static void
163 RndisFilterOnCleanup(
164         struct hv_driver *Driver
165         );
166
167 static int
168 RndisFilterOnOpen(
169         struct hv_device *Device
170         );
171
172 static int
173 RndisFilterOnClose(
174         struct hv_device *Device
175         );
176
177 static int
178 RndisFilterOnSend(
179         struct hv_device *Device,
180         struct hv_netvsc_packet *Packet
181         );
182
183 static void
184 RndisFilterOnSendCompletion(
185    void *Context
186         );
187
188 static void
189 RndisFilterOnSendRequestCompletion(
190    void *Context
191         );
192
193
194 /* Global var */
195
196
197 /* The one and only */
198 static RNDIS_FILTER_DRIVER_OBJECT gRndisFilter;
199
200 static inline RNDIS_DEVICE* GetRndisDevice(void)
201 {
202         RNDIS_DEVICE *device;
203
204         device = kzalloc(sizeof(RNDIS_DEVICE), GFP_KERNEL);
205         if (!device)
206         {
207                 return NULL;
208         }
209
210         spin_lock_init(&device->request_lock);
211
212         INITIALIZE_LIST_HEAD(&device->RequestList);
213
214         device->State = RNDIS_DEV_UNINITIALIZED;
215
216         return device;
217 }
218
219 static inline void PutRndisDevice(RNDIS_DEVICE *Device)
220 {
221         kfree(Device);
222 }
223
224 static inline RNDIS_REQUEST* GetRndisRequest(RNDIS_DEVICE *Device, u32 MessageType, u32 MessageLength)
225 {
226         RNDIS_REQUEST *request;
227         struct rndis_message *rndisMessage;
228         struct rndis_set_request *set;
229         unsigned long flags;
230
231         request = kzalloc(sizeof(RNDIS_REQUEST), GFP_KERNEL);
232         if (!request)
233         {
234                 return NULL;
235         }
236
237         request->WaitEvent = osd_WaitEventCreate();
238         if (!request->WaitEvent)
239         {
240                 kfree(request);
241                 return NULL;
242         }
243
244         rndisMessage = &request->RequestMessage;
245         rndisMessage->NdisMessageType = MessageType;
246         rndisMessage->MessageLength = MessageLength;
247
248         /* Set the request id. This field is always after the rndis header for request/response packet types so */
249         /* we just used the SetRequest as a template */
250         set = &rndisMessage->Message.SetRequest;
251         set->RequestId = atomic_inc_return(&Device->NewRequestId);
252
253         /* Add to the request list */
254         spin_lock_irqsave(&Device->request_lock, flags);
255         INSERT_TAIL_LIST(&Device->RequestList, &request->ListEntry);
256         spin_unlock_irqrestore(&Device->request_lock, flags);
257
258         return request;
259 }
260
261 static inline void PutRndisRequest(RNDIS_DEVICE *Device, RNDIS_REQUEST *Request)
262 {
263         unsigned long flags;
264
265         spin_lock_irqsave(&Device->request_lock, flags);
266         REMOVE_ENTRY_LIST(&Request->ListEntry);
267         spin_unlock_irqrestore(&Device->request_lock, flags);
268
269         kfree(Request->WaitEvent);
270         kfree(Request);
271 }
272
273 static inline void DumpRndisMessage(struct rndis_message *RndisMessage)
274 {
275         switch (RndisMessage->NdisMessageType)
276         {
277         case REMOTE_NDIS_PACKET_MSG:
278                 DPRINT_DBG(NETVSC, "REMOTE_NDIS_PACKET_MSG (len %u, data offset %u data len %u, # oob %u, oob offset %u, oob len %u, pkt offset %u, pkt len %u",
279                         RndisMessage->MessageLength,
280                         RndisMessage->Message.Packet.DataOffset,
281                         RndisMessage->Message.Packet.DataLength,
282                         RndisMessage->Message.Packet.NumOOBDataElements,
283                         RndisMessage->Message.Packet.OOBDataOffset,
284                         RndisMessage->Message.Packet.OOBDataLength,
285                         RndisMessage->Message.Packet.PerPacketInfoOffset,
286                         RndisMessage->Message.Packet.PerPacketInfoLength);
287                 break;
288
289         case REMOTE_NDIS_INITIALIZE_CMPLT:
290                 DPRINT_DBG(NETVSC, "REMOTE_NDIS_INITIALIZE_CMPLT (len %u, id 0x%x, status 0x%x, major %d, minor %d, device flags %d, max xfer size 0x%x, max pkts %u, pkt aligned %u)",
291                         RndisMessage->MessageLength,
292                         RndisMessage->Message.InitializeComplete.RequestId,
293                         RndisMessage->Message.InitializeComplete.Status,
294                         RndisMessage->Message.InitializeComplete.MajorVersion,
295                         RndisMessage->Message.InitializeComplete.MinorVersion,
296                         RndisMessage->Message.InitializeComplete.DeviceFlags,
297                         RndisMessage->Message.InitializeComplete.MaxTransferSize,
298                         RndisMessage->Message.InitializeComplete.MaxPacketsPerMessage,
299                         RndisMessage->Message.InitializeComplete.PacketAlignmentFactor);
300                 break;
301
302         case REMOTE_NDIS_QUERY_CMPLT:
303                 DPRINT_DBG(NETVSC, "REMOTE_NDIS_QUERY_CMPLT (len %u, id 0x%x, status 0x%x, buf len %u, buf offset %u)",
304                         RndisMessage->MessageLength,
305                         RndisMessage->Message.QueryComplete.RequestId,
306                         RndisMessage->Message.QueryComplete.Status,
307                         RndisMessage->Message.QueryComplete.InformationBufferLength,
308                         RndisMessage->Message.QueryComplete.InformationBufferOffset);
309                 break;
310
311         case REMOTE_NDIS_SET_CMPLT:
312                 DPRINT_DBG(NETVSC, "REMOTE_NDIS_SET_CMPLT (len %u, id 0x%x, status 0x%x)",
313                         RndisMessage->MessageLength,
314                         RndisMessage->Message.SetComplete.RequestId,
315                         RndisMessage->Message.SetComplete.Status);
316                 break;
317
318         case REMOTE_NDIS_INDICATE_STATUS_MSG:
319                 DPRINT_DBG(NETVSC, "REMOTE_NDIS_INDICATE_STATUS_MSG (len %u, status 0x%x, buf len %u, buf offset %u)",
320                         RndisMessage->MessageLength,
321                         RndisMessage->Message.IndicateStatus.Status,
322                         RndisMessage->Message.IndicateStatus.StatusBufferLength,
323                         RndisMessage->Message.IndicateStatus.StatusBufferOffset);
324                 break;
325
326         default:
327                 DPRINT_DBG(NETVSC, "0x%x (len %u)",
328                         RndisMessage->NdisMessageType,
329                         RndisMessage->MessageLength);
330                 break;
331         }
332 }
333
334 static int
335 RndisFilterSendRequest(
336         RNDIS_DEVICE    *Device,
337         RNDIS_REQUEST   *Request
338         )
339 {
340         int ret=0;
341         struct hv_netvsc_packet *packet;
342
343         DPRINT_ENTER(NETVSC);
344
345         /* Setup the packet to send it */
346         packet = &Request->Packet;
347
348         packet->IsDataPacket = false;
349         packet->TotalDataBufferLength = Request->RequestMessage.MessageLength;
350         packet->PageBufferCount = 1;
351
352         packet->PageBuffers[0].Pfn = virt_to_phys(&Request->RequestMessage) >> PAGE_SHIFT;
353         packet->PageBuffers[0].Length = Request->RequestMessage.MessageLength;
354         packet->PageBuffers[0].Offset = (unsigned long)&Request->RequestMessage & (PAGE_SIZE -1);
355
356         packet->Completion.Send.SendCompletionContext = Request;/* packet; */
357         packet->Completion.Send.OnSendCompletion = RndisFilterOnSendRequestCompletion;
358         packet->Completion.Send.SendCompletionTid = (unsigned long)Device;
359
360         ret = gRndisFilter.InnerDriver.OnSend(Device->NetDevice->Device, packet);
361         DPRINT_EXIT(NETVSC);
362         return ret;
363 }
364
365
366 static void RndisFilterReceiveResponse(RNDIS_DEVICE *Device,
367                                        struct rndis_message *Response)
368 {
369         LIST_ENTRY *anchor;
370         LIST_ENTRY *curr;
371         RNDIS_REQUEST *request=NULL;
372         bool found = false;
373         unsigned long flags;
374
375         DPRINT_ENTER(NETVSC);
376
377         spin_lock_irqsave(&Device->request_lock, flags);
378         ITERATE_LIST_ENTRIES(anchor, curr, &Device->RequestList)
379         {
380                 request = CONTAINING_RECORD(curr, RNDIS_REQUEST, ListEntry);
381
382                 /* All request/response message contains RequestId as the 1st field */
383                 if (request->RequestMessage.Message.InitializeRequest.RequestId == Response->Message.InitializeComplete.RequestId)
384                 {
385                         DPRINT_DBG(NETVSC, "found rndis request for this response (id 0x%x req type 0x%x res type 0x%x)",
386                                 request->RequestMessage.Message.InitializeRequest.RequestId, request->RequestMessage.NdisMessageType, Response->NdisMessageType);
387
388                         found = true;
389                         break;
390                 }
391         }
392         spin_unlock_irqrestore(&Device->request_lock, flags);
393
394         if (found)
395         {
396                 if (Response->MessageLength <= sizeof(struct rndis_message))
397                 {
398                         memcpy(&request->ResponseMessage, Response, Response->MessageLength);
399                 }
400                 else
401                 {
402                         DPRINT_ERR(NETVSC, "rndis response buffer overflow detected (size %u max %zu)", Response->MessageLength, sizeof(RNDIS_FILTER_PACKET));
403
404                         if (Response->NdisMessageType == REMOTE_NDIS_RESET_CMPLT) /* does not have a request id field */
405                         {
406                                 request->ResponseMessage.Message.ResetComplete.Status = STATUS_BUFFER_OVERFLOW;
407                         }
408                         else
409                         {
410                                 request->ResponseMessage.Message.InitializeComplete.Status = STATUS_BUFFER_OVERFLOW;
411                         }
412                 }
413
414                 osd_WaitEventSet(request->WaitEvent);
415         }
416         else
417         {
418                 DPRINT_ERR(NETVSC, "no rndis request found for this response (id 0x%x res type 0x%x)",
419                                 Response->Message.InitializeComplete.RequestId, Response->NdisMessageType);
420         }
421
422         DPRINT_EXIT(NETVSC);
423 }
424
425 static void RndisFilterReceiveIndicateStatus(RNDIS_DEVICE *Device,
426                                              struct rndis_message *Response)
427 {
428         struct rndis_indicate_status *indicate = &Response->Message.IndicateStatus;
429
430         if (indicate->Status == RNDIS_STATUS_MEDIA_CONNECT)
431         {
432                 gRndisFilter.InnerDriver.OnLinkStatusChanged(Device->NetDevice->Device, 1);
433         }
434         else if (indicate->Status == RNDIS_STATUS_MEDIA_DISCONNECT)
435         {
436                 gRndisFilter.InnerDriver.OnLinkStatusChanged(Device->NetDevice->Device, 0);
437         }
438         else
439         {
440                 /* TODO: */
441         }
442 }
443
444 static void RndisFilterReceiveData(RNDIS_DEVICE *Device,
445                                    struct rndis_message *Message,
446                                    struct hv_netvsc_packet *Packet)
447 {
448         struct rndis_packet *rndisPacket;
449         u32 dataOffset;
450
451         DPRINT_ENTER(NETVSC);
452
453         /* empty ethernet frame ?? */
454         ASSERT(Packet->PageBuffers[0].Length > RNDIS_MESSAGE_SIZE(struct rndis_packet));
455
456         rndisPacket = &Message->Message.Packet;
457
458         /* FIXME: Handle multiple rndis pkt msgs that maybe enclosed in this */
459         /* netvsc packet (ie TotalDataBufferLength != MessageLength) */
460
461         /* Remove the rndis header and pass it back up the stack */
462         dataOffset = RNDIS_HEADER_SIZE + rndisPacket->DataOffset;
463
464         Packet->TotalDataBufferLength -= dataOffset;
465         Packet->PageBuffers[0].Offset += dataOffset;
466         Packet->PageBuffers[0].Length -= dataOffset;
467
468         Packet->IsDataPacket = true;
469
470         gRndisFilter.InnerDriver.OnReceiveCallback(Device->NetDevice->Device, Packet);
471
472         DPRINT_EXIT(NETVSC);
473 }
474
475 static int
476 RndisFilterOnReceive(
477         struct hv_device *Device,
478         struct hv_netvsc_packet *Packet
479         )
480 {
481         struct NETVSC_DEVICE *netDevice = (struct NETVSC_DEVICE*)Device->Extension;
482         RNDIS_DEVICE *rndisDevice;
483         struct rndis_message rndisMessage;
484         struct rndis_message *rndisHeader;
485
486         DPRINT_ENTER(NETVSC);
487
488         ASSERT(netDevice);
489         /* Make sure the rndis device state is initialized */
490         if (!netDevice->Extension)
491         {
492                 DPRINT_ERR(NETVSC, "got rndis message but no rndis device...dropping this message!");
493                 DPRINT_EXIT(NETVSC);
494                 return -1;
495         }
496
497         rndisDevice = (RNDIS_DEVICE*)netDevice->Extension;
498         if (rndisDevice->State == RNDIS_DEV_UNINITIALIZED)
499         {
500                 DPRINT_ERR(NETVSC, "got rndis message but rndis device uninitialized...dropping this message!");
501                 DPRINT_EXIT(NETVSC);
502                 return -1;
503         }
504
505         rndisHeader = (struct rndis_message *)kmap_atomic(pfn_to_page(Packet->PageBuffers[0].Pfn), KM_IRQ0);
506
507         rndisHeader = (void*)((unsigned long)rndisHeader + Packet->PageBuffers[0].Offset);
508
509         /* Make sure we got a valid rndis message */
510         /* FIXME: There seems to be a bug in set completion msg where its MessageLength is 16 bytes but */
511         /* the ByteCount field in the xfer page range shows 52 bytes */
512 #if 0
513         if ( Packet->TotalDataBufferLength != rndisHeader->MessageLength )
514         {
515                 kunmap_atomic(rndisHeader - Packet->PageBuffers[0].Offset, KM_IRQ0);
516
517                 DPRINT_ERR(NETVSC, "invalid rndis message? (expected %u bytes got %u)...dropping this message!",
518                         rndisHeader->MessageLength, Packet->TotalDataBufferLength);
519                 DPRINT_EXIT(NETVSC);
520                 return -1;
521         }
522 #endif
523
524         if ((rndisHeader->NdisMessageType != REMOTE_NDIS_PACKET_MSG) && (rndisHeader->MessageLength > sizeof(struct rndis_message)))
525         {
526                 DPRINT_ERR(NETVSC, "incoming rndis message buffer overflow detected (got %u, max %zu)...marking it an error!",
527                         rndisHeader->MessageLength, sizeof(struct rndis_message));
528         }
529
530         memcpy(&rndisMessage, rndisHeader, (rndisHeader->MessageLength > sizeof(struct rndis_message))?sizeof(struct rndis_message):rndisHeader->MessageLength);
531
532         kunmap_atomic(rndisHeader - Packet->PageBuffers[0].Offset, KM_IRQ0);
533
534         DumpRndisMessage(&rndisMessage);
535
536         switch (rndisMessage.NdisMessageType)
537         {
538                 /* data msg */
539         case REMOTE_NDIS_PACKET_MSG:
540                 RndisFilterReceiveData(rndisDevice, &rndisMessage, Packet);
541                 break;
542
543                 /* completion msgs */
544         case REMOTE_NDIS_INITIALIZE_CMPLT:
545         case REMOTE_NDIS_QUERY_CMPLT:
546         case REMOTE_NDIS_SET_CMPLT:
547         /* case REMOTE_NDIS_RESET_CMPLT: */
548         /* case REMOTE_NDIS_KEEPALIVE_CMPLT: */
549                 RndisFilterReceiveResponse(rndisDevice, &rndisMessage);
550                 break;
551
552                 /* notification msgs */
553         case REMOTE_NDIS_INDICATE_STATUS_MSG:
554                 RndisFilterReceiveIndicateStatus(rndisDevice, &rndisMessage);
555                 break;
556         default:
557                 DPRINT_ERR(NETVSC, "unhandled rndis message (type %u len %u)", rndisMessage.NdisMessageType, rndisMessage.MessageLength);
558                 break;
559         }
560
561         DPRINT_EXIT(NETVSC);
562         return 0;
563 }
564
565
566 static int
567 RndisFilterQueryDevice(
568         RNDIS_DEVICE    *Device,
569         u32                     Oid,
570         void                    *Result,
571         u32                     *ResultSize
572         )
573 {
574         RNDIS_REQUEST *request;
575         u32 inresultSize = *ResultSize;
576         struct rndis_query_request *query;
577         struct rndis_query_complete *queryComplete;
578         int ret=0;
579
580         DPRINT_ENTER(NETVSC);
581
582         ASSERT(Result);
583
584         *ResultSize = 0;
585         request = GetRndisRequest(Device, REMOTE_NDIS_QUERY_MSG, RNDIS_MESSAGE_SIZE(struct rndis_query_request));
586         if (!request)
587         {
588                 ret = -1;
589                 goto Cleanup;
590         }
591
592         /* Setup the rndis query */
593         query = &request->RequestMessage.Message.QueryRequest;
594         query->Oid = Oid;
595         query->InformationBufferOffset = sizeof(struct rndis_query_request);
596         query->InformationBufferLength = 0;
597         query->DeviceVcHandle = 0;
598
599         ret = RndisFilterSendRequest(Device, request);
600         if (ret != 0)
601         {
602                 goto Cleanup;
603         }
604
605         osd_WaitEventWait(request->WaitEvent);
606
607         /* Copy the response back */
608         queryComplete = &request->ResponseMessage.Message.QueryComplete;
609
610         if (queryComplete->InformationBufferLength > inresultSize)
611         {
612                 ret = -1;
613                 goto Cleanup;
614         }
615
616         memcpy(Result,
617                         (void*)((unsigned long)queryComplete + queryComplete->InformationBufferOffset),
618                         queryComplete->InformationBufferLength);
619
620         *ResultSize = queryComplete->InformationBufferLength;
621
622 Cleanup:
623         if (request)
624         {
625                 PutRndisRequest(Device, request);
626         }
627         DPRINT_EXIT(NETVSC);
628
629         return ret;
630 }
631
632 static inline int
633 RndisFilterQueryDeviceMac(
634         RNDIS_DEVICE    *Device
635         )
636 {
637         u32 size=HW_MACADDR_LEN;
638
639         return RndisFilterQueryDevice(Device,
640                                                                         RNDIS_OID_802_3_PERMANENT_ADDRESS,
641                                                                         Device->HwMacAddr,
642                                                                         &size);
643 }
644
645 static inline int
646 RndisFilterQueryDeviceLinkStatus(
647         RNDIS_DEVICE    *Device
648         )
649 {
650         u32 size=sizeof(u32);
651
652         return RndisFilterQueryDevice(Device,
653                                                                         RNDIS_OID_GEN_MEDIA_CONNECT_STATUS,
654                                                                         &Device->LinkStatus,
655                                                                         &size);
656 }
657
658 static int
659 RndisFilterSetPacketFilter(
660         RNDIS_DEVICE    *Device,
661         u32                     NewFilter
662         )
663 {
664         RNDIS_REQUEST *request;
665         struct rndis_set_request *set;
666         struct rndis_set_complete *setComplete;
667         u32 status;
668         int ret;
669
670         DPRINT_ENTER(NETVSC);
671
672         ASSERT(RNDIS_MESSAGE_SIZE(struct rndis_set_request) + sizeof(u32) <= sizeof(struct rndis_message));
673
674         request = GetRndisRequest(Device, REMOTE_NDIS_SET_MSG, RNDIS_MESSAGE_SIZE(struct rndis_set_request) + sizeof(u32));
675         if (!request)
676         {
677                 ret = -1;
678                 goto Cleanup;
679         }
680
681         /* Setup the rndis set */
682         set = &request->RequestMessage.Message.SetRequest;
683         set->Oid = RNDIS_OID_GEN_CURRENT_PACKET_FILTER;
684         set->InformationBufferLength = sizeof(u32);
685         set->InformationBufferOffset = sizeof(struct rndis_set_request);
686
687         memcpy((void*)(unsigned long)set + sizeof(struct rndis_set_request), &NewFilter, sizeof(u32));
688
689         ret = RndisFilterSendRequest(Device, request);
690         if (ret != 0)
691         {
692                 goto Cleanup;
693         }
694
695         ret = osd_WaitEventWaitEx(request->WaitEvent, 2000/*2sec*/);
696         if (!ret)
697         {
698                 ret = -1;
699                 DPRINT_ERR(NETVSC, "timeout before we got a set response...");
700                 /* We cant deallocate the request since we may still receive a send completion for it. */
701                 goto Exit;
702         }
703         else
704         {
705                 if (ret > 0)
706                 {
707                         ret = 0;
708                 }
709                 setComplete = &request->ResponseMessage.Message.SetComplete;
710                 status = setComplete->Status;
711         }
712
713 Cleanup:
714         if (request)
715         {
716                 PutRndisRequest(Device, request);
717         }
718 Exit:
719         DPRINT_EXIT(NETVSC);
720
721         return ret;
722 }
723
724 int RndisFilterInit(struct netvsc_driver *Driver)
725 {
726         DPRINT_ENTER(NETVSC);
727
728         DPRINT_DBG(NETVSC, "sizeof(RNDIS_FILTER_PACKET) == %zd", sizeof(RNDIS_FILTER_PACKET));
729
730         Driver->RequestExtSize = sizeof(RNDIS_FILTER_PACKET);
731         Driver->AdditionalRequestPageBufferCount = 1; /* For rndis header */
732
733         /* Driver->Context = rndisDriver; */
734
735         memset(&gRndisFilter, 0, sizeof(RNDIS_FILTER_DRIVER_OBJECT));
736
737         /*rndisDriver->Driver = Driver;
738
739         ASSERT(Driver->OnLinkStatusChanged);
740         rndisDriver->OnLinkStatusChanged = Driver->OnLinkStatusChanged;*/
741
742         /* Save the original dispatch handlers before we override it */
743         gRndisFilter.InnerDriver.Base.OnDeviceAdd = Driver->Base.OnDeviceAdd;
744         gRndisFilter.InnerDriver.Base.OnDeviceRemove = Driver->Base.OnDeviceRemove;
745         gRndisFilter.InnerDriver.Base.OnCleanup = Driver->Base.OnCleanup;
746
747         ASSERT(Driver->OnSend);
748         ASSERT(Driver->OnReceiveCallback);
749         gRndisFilter.InnerDriver.OnSend = Driver->OnSend;
750         gRndisFilter.InnerDriver.OnReceiveCallback = Driver->OnReceiveCallback;
751         gRndisFilter.InnerDriver.OnLinkStatusChanged = Driver->OnLinkStatusChanged;
752
753         /* Override */
754         Driver->Base.OnDeviceAdd = RndisFilterOnDeviceAdd;
755         Driver->Base.OnDeviceRemove = RndisFilterOnDeviceRemove;
756         Driver->Base.OnCleanup = RndisFilterOnCleanup;
757         Driver->OnSend = RndisFilterOnSend;
758         Driver->OnOpen = RndisFilterOnOpen;
759         Driver->OnClose = RndisFilterOnClose;
760         /* Driver->QueryLinkStatus = RndisFilterQueryDeviceLinkStatus; */
761         Driver->OnReceiveCallback = RndisFilterOnReceive;
762
763         DPRINT_EXIT(NETVSC);
764
765         return 0;
766 }
767
768 static int
769 RndisFilterInitDevice(
770         RNDIS_DEVICE    *Device
771         )
772 {
773         RNDIS_REQUEST *request;
774         struct rndis_initialize_request *init;
775         struct rndis_initialize_complete *initComplete;
776         u32 status;
777         int ret;
778
779         DPRINT_ENTER(NETVSC);
780
781         request = GetRndisRequest(Device, REMOTE_NDIS_INITIALIZE_MSG, RNDIS_MESSAGE_SIZE(struct rndis_initialize_request));
782         if (!request)
783         {
784                 ret = -1;
785                 goto Cleanup;
786         }
787
788         /* Setup the rndis set */
789         init = &request->RequestMessage.Message.InitializeRequest;
790         init->MajorVersion = RNDIS_MAJOR_VERSION;
791         init->MinorVersion = RNDIS_MINOR_VERSION;
792         init->MaxTransferSize = 2048; /* FIXME: Use 1536 - rounded ethernet frame size */
793
794         Device->State = RNDIS_DEV_INITIALIZING;
795
796         ret = RndisFilterSendRequest(Device, request);
797         if (ret != 0)
798         {
799                 Device->State = RNDIS_DEV_UNINITIALIZED;
800                 goto Cleanup;
801         }
802
803         osd_WaitEventWait(request->WaitEvent);
804
805         initComplete = &request->ResponseMessage.Message.InitializeComplete;
806         status = initComplete->Status;
807         if (status == RNDIS_STATUS_SUCCESS)
808         {
809                 Device->State = RNDIS_DEV_INITIALIZED;
810                 ret = 0;
811         }
812         else
813         {
814                 Device->State = RNDIS_DEV_UNINITIALIZED;
815                 ret = -1;
816         }
817
818 Cleanup:
819         if (request)
820         {
821                 PutRndisRequest(Device, request);
822         }
823         DPRINT_EXIT(NETVSC);
824
825         return ret;
826 }
827
828 static void
829 RndisFilterHaltDevice(
830         RNDIS_DEVICE    *Device
831         )
832 {
833         RNDIS_REQUEST *request;
834         struct rndis_halt_request *halt;
835
836         DPRINT_ENTER(NETVSC);
837
838         /* Attempt to do a rndis device halt */
839         request = GetRndisRequest(Device, REMOTE_NDIS_HALT_MSG, RNDIS_MESSAGE_SIZE(struct rndis_halt_request));
840         if (!request)
841         {
842                 goto Cleanup;
843         }
844
845         /* Setup the rndis set */
846         halt = &request->RequestMessage.Message.HaltRequest;
847         halt->RequestId = atomic_inc_return(&Device->NewRequestId);
848
849         /* Ignore return since this msg is optional. */
850         RndisFilterSendRequest(Device, request);
851
852         Device->State = RNDIS_DEV_UNINITIALIZED;
853
854 Cleanup:
855         if (request)
856         {
857                 PutRndisRequest(Device, request);
858         }
859         DPRINT_EXIT(NETVSC);
860         return;
861 }
862
863
864 static int
865 RndisFilterOpenDevice(
866         RNDIS_DEVICE    *Device
867         )
868 {
869         int ret=0;
870
871         DPRINT_ENTER(NETVSC);
872
873         if (Device->State != RNDIS_DEV_INITIALIZED)
874                 return 0;
875
876         ret = RndisFilterSetPacketFilter(Device, NDIS_PACKET_TYPE_BROADCAST|NDIS_PACKET_TYPE_DIRECTED);
877         if (ret == 0)
878         {
879                 Device->State = RNDIS_DEV_DATAINITIALIZED;
880         }
881
882         DPRINT_EXIT(NETVSC);
883         return ret;
884 }
885
886 static int
887 RndisFilterCloseDevice(
888         RNDIS_DEVICE            *Device
889         )
890 {
891         int ret;
892
893         DPRINT_ENTER(NETVSC);
894
895         if (Device->State != RNDIS_DEV_DATAINITIALIZED)
896                 return 0;
897
898         ret = RndisFilterSetPacketFilter(Device, 0);
899         if (ret == 0)
900         {
901                 Device->State = RNDIS_DEV_INITIALIZED;
902         }
903
904         DPRINT_EXIT(NETVSC);
905
906         return ret;
907 }
908
909
910 static int
911 RndisFilterOnDeviceAdd(
912         struct hv_device *Device,
913         void                    *AdditionalInfo
914         )
915 {
916         int ret;
917         struct NETVSC_DEVICE *netDevice;
918         RNDIS_DEVICE *rndisDevice;
919         struct netvsc_device_info *deviceInfo = (struct netvsc_device_info *)AdditionalInfo;
920
921         DPRINT_ENTER(NETVSC);
922
923         rndisDevice = GetRndisDevice();
924         if (!rndisDevice)
925         {
926                 DPRINT_EXIT(NETVSC);
927                 return -1;
928         }
929
930         DPRINT_DBG(NETVSC, "rndis device object allocated - %p", rndisDevice);
931
932         /* Let the inner driver handle this first to create the netvsc channel */
933         /* NOTE! Once the channel is created, we may get a receive callback */
934         /* (RndisFilterOnReceive()) before this call is completed */
935         ret = gRndisFilter.InnerDriver.Base.OnDeviceAdd(Device, AdditionalInfo);
936         if (ret != 0)
937         {
938                 PutRndisDevice(rndisDevice);
939                 DPRINT_EXIT(NETVSC);
940                 return ret;
941         }
942
943
944         /* Initialize the rndis device */
945
946         netDevice = (struct NETVSC_DEVICE*)Device->Extension;
947         ASSERT(netDevice);
948         ASSERT(netDevice->Device);
949
950         netDevice->Extension = rndisDevice;
951         rndisDevice->NetDevice = netDevice;
952
953         /* Send the rndis initialization message */
954         ret = RndisFilterInitDevice(rndisDevice);
955         if (ret != 0)
956         {
957                 /* TODO: If rndis init failed, we will need to shut down the channel */
958         }
959
960         /* Get the mac address */
961         ret = RndisFilterQueryDeviceMac(rndisDevice);
962         if (ret != 0)
963         {
964                 /* TODO: shutdown rndis device and the channel */
965         }
966
967         DPRINT_INFO(NETVSC, "Device 0x%p mac addr %02x%02x%02x%02x%02x%02x",
968                                 rndisDevice,
969                                 rndisDevice->HwMacAddr[0],
970                                 rndisDevice->HwMacAddr[1],
971                                 rndisDevice->HwMacAddr[2],
972                                 rndisDevice->HwMacAddr[3],
973                                 rndisDevice->HwMacAddr[4],
974                                 rndisDevice->HwMacAddr[5]);
975
976         memcpy(deviceInfo->MacAddr, rndisDevice->HwMacAddr, HW_MACADDR_LEN);
977
978         RndisFilterQueryDeviceLinkStatus(rndisDevice);
979
980         deviceInfo->LinkState = rndisDevice->LinkStatus;
981         DPRINT_INFO(NETVSC, "Device 0x%p link state %s", rndisDevice, ((deviceInfo->LinkState)?("down"):("up")));
982
983         DPRINT_EXIT(NETVSC);
984
985         return ret;
986 }
987
988
989 static int
990 RndisFilterOnDeviceRemove(
991         struct hv_device *Device
992         )
993 {
994         struct NETVSC_DEVICE *netDevice = (struct NETVSC_DEVICE*)Device->Extension;
995         RNDIS_DEVICE *rndisDevice = (RNDIS_DEVICE*)netDevice->Extension;
996
997         DPRINT_ENTER(NETVSC);
998
999         /* Halt and release the rndis device */
1000         RndisFilterHaltDevice(rndisDevice);
1001
1002         PutRndisDevice(rndisDevice);
1003         netDevice->Extension = NULL;
1004
1005         /* Pass control to inner driver to remove the device */
1006         gRndisFilter.InnerDriver.Base.OnDeviceRemove(Device);
1007
1008         DPRINT_EXIT(NETVSC);
1009
1010         return 0;
1011 }
1012
1013
1014 static void
1015 RndisFilterOnCleanup(
1016         struct hv_driver *Driver
1017         )
1018 {
1019         DPRINT_ENTER(NETVSC);
1020
1021         DPRINT_EXIT(NETVSC);
1022 }
1023
1024 static int
1025 RndisFilterOnOpen(
1026         struct hv_device *Device
1027         )
1028 {
1029         int ret;
1030         struct NETVSC_DEVICE *netDevice = (struct NETVSC_DEVICE*)Device->Extension;
1031
1032         DPRINT_ENTER(NETVSC);
1033
1034         ASSERT(netDevice);
1035         ret = RndisFilterOpenDevice((RNDIS_DEVICE*)netDevice->Extension);
1036
1037         DPRINT_EXIT(NETVSC);
1038
1039         return ret;
1040 }
1041
1042 static int
1043 RndisFilterOnClose(
1044         struct hv_device *Device
1045         )
1046 {
1047         int ret;
1048         struct NETVSC_DEVICE *netDevice = (struct NETVSC_DEVICE*)Device->Extension;
1049
1050         DPRINT_ENTER(NETVSC);
1051
1052         ASSERT(netDevice);
1053         ret = RndisFilterCloseDevice((RNDIS_DEVICE*)netDevice->Extension);
1054
1055         DPRINT_EXIT(NETVSC);
1056
1057         return ret;
1058 }
1059
1060
1061 static int
1062 RndisFilterOnSend(
1063         struct hv_device *Device,
1064         struct hv_netvsc_packet *Packet
1065         )
1066 {
1067         int ret=0;
1068         RNDIS_FILTER_PACKET *filterPacket;
1069         struct rndis_message *rndisMessage;
1070         struct rndis_packet *rndisPacket;
1071         u32 rndisMessageSize;
1072
1073         DPRINT_ENTER(NETVSC);
1074
1075         /* Add the rndis header */
1076         filterPacket = (RNDIS_FILTER_PACKET*)Packet->Extension;
1077         ASSERT(filterPacket);
1078
1079         memset(filterPacket, 0, sizeof(RNDIS_FILTER_PACKET));
1080
1081         rndisMessage = &filterPacket->Message;
1082         rndisMessageSize = RNDIS_MESSAGE_SIZE(struct rndis_packet);
1083
1084         rndisMessage->NdisMessageType = REMOTE_NDIS_PACKET_MSG;
1085         rndisMessage->MessageLength = Packet->TotalDataBufferLength + rndisMessageSize;
1086
1087         rndisPacket = &rndisMessage->Message.Packet;
1088         rndisPacket->DataOffset = sizeof(struct rndis_packet);
1089         rndisPacket->DataLength = Packet->TotalDataBufferLength;
1090
1091         Packet->IsDataPacket = true;
1092         Packet->PageBuffers[0].Pfn      = virt_to_phys(rndisMessage) >> PAGE_SHIFT;
1093         Packet->PageBuffers[0].Offset   = (unsigned long)rndisMessage & (PAGE_SIZE-1);
1094         Packet->PageBuffers[0].Length   = rndisMessageSize;
1095
1096         /* Save the packet send completion and context */
1097         filterPacket->OnCompletion = Packet->Completion.Send.OnSendCompletion;
1098         filterPacket->CompletionContext = Packet->Completion.Send.SendCompletionContext;
1099
1100         /* Use ours */
1101         Packet->Completion.Send.OnSendCompletion = RndisFilterOnSendCompletion;
1102         Packet->Completion.Send.SendCompletionContext = filterPacket;
1103
1104         ret = gRndisFilter.InnerDriver.OnSend(Device, Packet);
1105         if (ret != 0)
1106         {
1107                 /* Reset the completion to originals to allow retries from above */
1108                 Packet->Completion.Send.OnSendCompletion = filterPacket->OnCompletion;
1109                 Packet->Completion.Send.SendCompletionContext = filterPacket->CompletionContext;
1110         }
1111
1112         DPRINT_EXIT(NETVSC);
1113
1114         return ret;
1115 }
1116
1117 static void
1118 RndisFilterOnSendCompletion(
1119    void *Context)
1120 {
1121         RNDIS_FILTER_PACKET *filterPacket = (RNDIS_FILTER_PACKET *)Context;
1122
1123         DPRINT_ENTER(NETVSC);
1124
1125         /* Pass it back to the original handler */
1126         filterPacket->OnCompletion(filterPacket->CompletionContext);
1127
1128         DPRINT_EXIT(NETVSC);
1129 }
1130
1131
1132 static void
1133 RndisFilterOnSendRequestCompletion(
1134    void *Context
1135    )
1136 {
1137         DPRINT_ENTER(NETVSC);
1138
1139         /* Noop */
1140         DPRINT_EXIT(NETVSC);
1141 }