ovr_sdk

view LibOVR/Src/Net/OVR_Session.cpp @ 0:1b39a1b46319

initial 0.4.4
author John Tsiombikas <nuclear@member.fsf.org>
date Wed, 14 Jan 2015 06:51:16 +0200
parents
children
line source
1 /************************************************************************************
3 Filename : OVR_Session.cpp
4 Content : One network session that provides connection/disconnection events.
5 Created : June 10, 2014
6 Authors : Kevin Jenkins, Chris Taylor
8 Copyright : Copyright 2014 Oculus VR, LLC All Rights reserved.
10 Licensed under the Oculus VR Rift SDK License Version 3.2 (the "License");
11 you may not use the Oculus VR Rift SDK except in compliance with the License,
12 which is provided at the time of installation or download, or which
13 otherwise accompanies this software in either electronic or hard copy form.
15 You may obtain a copy of the License at
17 http://www.oculusvr.com/licenses/LICENSE-3.2
19 Unless required by applicable law or agreed to in writing, the Oculus VR SDK
20 distributed under the License is distributed on an "AS IS" BASIS,
21 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
22 See the License for the specific language governing permissions and
23 limitations under the License.
25 ************************************************************************************/
27 #include "OVR_Session.h"
28 #include "OVR_PacketizedTCPSocket.h"
29 #include "../Kernel/OVR_Log.h"
30 #include "../Service/Service_NetSessionCommon.h"
32 namespace OVR { namespace Net {
//-----------------------------------------------------------------------------
// Protocol

// Handshake strings. The client places OfficialHelloString in its
// RPC_C2S_Hello message; the server replies with OfficialAuthorizedString in
// RPC_S2C_Authorization when it accepts the client (any other reply string is
// a rejection reason). Both sides compare them case-insensitively.
static const char OfficialHelloString[]      = "OculusVR_Hello";
static const char OfficialAuthorizedString[] = "OculusVR_Authorized";
41 void RPC_C2S_Hello::Generate(Net::BitStream* bs)
42 {
43 RPC_C2S_Hello hello;
44 hello.HelloString = OfficialHelloString;
45 hello.MajorVersion = RPCVersion_Major;
46 hello.MinorVersion = RPCVersion_Minor;
47 hello.PatchVersion = RPCVersion_Patch;
48 hello.Serialize(bs);
49 }
51 bool RPC_C2S_Hello::Validate()
52 {
53 return MajorVersion == RPCVersion_Major &&
54 MinorVersion <= RPCVersion_Minor &&
55 HelloString.CompareNoCase(OfficialHelloString) == 0;
56 }
58 void RPC_S2C_Authorization::Generate(Net::BitStream* bs, String errorString)
59 {
60 RPC_S2C_Authorization auth;
61 if (errorString.IsEmpty())
62 {
63 auth.AuthString = OfficialAuthorizedString;
64 }
65 else
66 {
67 auth.AuthString = errorString;
68 }
69 auth.MajorVersion = RPCVersion_Major;
70 auth.MinorVersion = RPCVersion_Minor;
71 auth.PatchVersion = RPCVersion_Patch;
72 auth.Serialize(bs);
73 }
75 bool RPC_S2C_Authorization::Validate()
76 {
77 return AuthString.CompareNoCase(OfficialAuthorizedString) == 0;
78 }
81 //-----------------------------------------------------------------------------
82 // Session
84 void Session::Shutdown()
85 {
86 {
87 Lock::Locker locker(&SocketListenersLock);
89 const int count = SocketListeners.GetSizeI();
90 for (int i = 0; i < count; ++i)
91 {
92 SocketListeners[i]->Close();
93 }
94 }
96 Lock::Locker locker(&ConnectionsLock);
98 const int count = AllConnections.GetSizeI();
99 for (int i = 0; i < count; ++i)
100 {
101 Connection* arrayItem = AllConnections[i].GetPtr();
103 if (arrayItem->Transport == TransportType_PacketizedTCP)
104 {
105 PacketizedTCPConnection* ptcp = (PacketizedTCPConnection*)arrayItem;
107 ptcp->pSocket->Close();
108 }
109 }
110 }
112 SessionResult Session::Listen(ListenerDescription* pListenerDescription)
113 {
114 if (pListenerDescription->Transport == TransportType_PacketizedTCP)
115 {
116 BerkleyListenerDescription* bld = (BerkleyListenerDescription*)pListenerDescription;
117 TCPSocket* tcpSocket = (TCPSocket*)bld->BoundSocketToListenWith.GetPtr();
119 if (tcpSocket->Listen() < 0)
120 {
121 return SessionResult_ListenFailure;
122 }
124 Lock::Locker locker(&SocketListenersLock);
125 SocketListeners.PushBack(tcpSocket);
126 }
127 else if (pListenerDescription->Transport == TransportType_Loopback)
128 {
129 HasLoopbackListener = true;
130 }
131 else
132 {
133 OVR_ASSERT(false);
134 }
136 return SessionResult_OK;
137 }
139 SessionResult Session::Connect(ConnectParameters *cp)
140 {
141 if (cp->Transport == TransportType_PacketizedTCP)
142 {
143 ConnectParametersBerkleySocket* cp2 = (ConnectParametersBerkleySocket*)cp;
144 Ptr<PacketizedTCPConnection> c;
146 {
147 Lock::Locker locker(&ConnectionsLock);
149 int connIndex;
150 Ptr<PacketizedTCPConnection> conn = findConnectionBySocket(AllConnections, cp2->BoundSocketToConnectWith, &connIndex);
151 if (conn)
152 {
153 return SessionResult_AlreadyConnected;
154 }
156 TCPSocketBase* tcpSock = (TCPSocketBase*)cp2->BoundSocketToConnectWith.GetPtr();
158 int ret = tcpSock->Connect(&cp2->RemoteAddress);
159 if (ret < 0)
160 {
161 return SessionResult_ConnectFailure;
162 }
164 Ptr<Connection> newConnection = AllocConnection(cp2->Transport);
165 if (!newConnection)
166 {
167 return SessionResult_ConnectFailure;
168 }
170 c = (PacketizedTCPConnection*)newConnection.GetPtr();
171 c->pSocket = (TCPSocket*) cp2->BoundSocketToConnectWith.GetPtr();
172 c->Address = cp2->RemoteAddress;
173 c->Transport = cp2->Transport;
174 c->SetState(Client_Connecting);
176 AllConnections.PushBack(c);
178 }
180 if (cp2->Blocking)
181 {
182 c->WaitOnConnecting();
183 }
185 if (c->State == State_Connected)
186 {
187 return SessionResult_OK;
188 }
189 else if (c->State == Client_Connecting)
190 {
191 return SessionResult_ConnectInProgress;
192 }
193 else
194 {
195 return SessionResult_ConnectFailure;
196 }
197 }
198 else if (cp->Transport == TransportType_Loopback)
199 {
200 if (HasLoopbackListener)
201 {
202 Ptr<Connection> c = AllocConnection(cp->Transport);
203 if (!c)
204 {
205 return SessionResult_ConnectFailure;
206 }
208 c->Transport = cp->Transport;
209 c->SetState(State_Connected);
211 {
212 Lock::Locker locker(&ConnectionsLock);
213 AllConnections.PushBack(c);
214 }
216 invokeSessionEvent(&SessionListener::OnConnectionRequestAccepted, c);
217 }
218 else
219 {
220 OVR_ASSERT(false);
221 }
222 }
223 else
224 {
225 OVR_ASSERT(false);
226 }
228 return SessionResult_OK;
229 }
231 SessionResult Session::ListenPTCP(OVR::Net::BerkleyBindParameters *bbp)
232 {
233 Ptr<PacketizedTCPSocket> listenSocket = *new OVR::Net::PacketizedTCPSocket();
234 if (listenSocket->Bind(bbp) == INVALID_SOCKET)
235 {
236 return SessionResult_BindFailure;
237 }
239 BerkleyListenerDescription bld;
240 bld.BoundSocketToListenWith = listenSocket.GetPtr();
241 bld.Transport = TransportType_PacketizedTCP;
243 return Listen(&bld);
244 }
246 SessionResult Session::ConnectPTCP(OVR::Net::BerkleyBindParameters* bbp, SockAddr* remoteAddress, bool blocking)
247 {
248 ConnectParametersBerkleySocket cp(NULL, remoteAddress, blocking, TransportType_PacketizedTCP);
249 Ptr<PacketizedTCPSocket> connectSocket = *new PacketizedTCPSocket();
251 cp.BoundSocketToConnectWith = connectSocket.GetPtr();
252 if (connectSocket->Bind(bbp) == INVALID_SOCKET)
253 {
254 return SessionResult_BindFailure;
255 }
257 return Connect(&cp);
258 }
260 Ptr<PacketizedTCPConnection> Session::findConnectionBySockAddr(SockAddr* address)
261 {
262 const int count = AllConnections.GetSizeI();
263 for (int i = 0; i < count; ++i)
264 {
265 Connection* arrayItem = AllConnections[i].GetPtr();
267 if (arrayItem->Transport == TransportType_PacketizedTCP)
268 {
269 PacketizedTCPConnection* conn = (PacketizedTCPConnection*)arrayItem;
271 if (conn->Address == *address)
272 {
273 return conn;
274 }
275 }
276 }
278 return 0;
279 }
281 int Session::Send(SendParameters *payload)
282 {
283 if (payload->pConnection->Transport == TransportType_Loopback)
284 {
285 Lock::Locker locker(&SessionListenersLock);
287 const int count = SessionListeners.GetSizeI();
288 for (int i = 0; i < count; ++i)
289 {
290 SessionListener* sl = SessionListeners[i];
292 // FIXME: This looks like it needs to be reviewed at some point..
293 ReceivePayload rp;
294 rp.Bytes = payload->Bytes;
295 rp.pConnection = payload->pConnection;
296 rp.pData = (uint8_t*)payload->pData; // FIXME
297 ListenerReceiveResult lrr = LRR_CONTINUE;
298 sl->OnReceive(&rp, &lrr);
299 if (lrr == LRR_RETURN)
300 {
301 return payload->Bytes;
302 }
303 else if (lrr == LRR_BREAK)
304 {
305 break;
306 }
307 }
309 return payload->Bytes;
310 }
311 else if (payload->pConnection->Transport == TransportType_PacketizedTCP)
312 {
313 PacketizedTCPConnection* conn = (PacketizedTCPConnection*)payload->pConnection.GetPtr();
315 return conn->pSocket->Send(payload->pData, payload->Bytes);
316 }
317 else
318 {
319 OVR_ASSERT(false);
320 }
322 return 0;
323 }
324 void Session::Broadcast(BroadcastParameters *payload)
325 {
326 SendParameters sp;
327 sp.Bytes=payload->Bytes;
328 sp.pData=payload->pData;
330 {
331 Lock::Locker locker(&ConnectionsLock);
333 const int connectionCount = FullConnections.GetSizeI();
334 for (int i = 0; i < connectionCount; ++i)
335 {
336 sp.pConnection = FullConnections[i];
337 Send(&sp);
338 }
339 }
340 }
341 // DO NOT CALL Poll() FROM MULTIPLE THREADS due to allBlockingTcpSockets being a member
342 void Session::Poll(bool listeners)
343 {
344 allBlockingTcpSockets.Clear();
346 if (listeners)
347 {
348 Lock::Locker locker(&SocketListenersLock);
350 const int listenerCount = SocketListeners.GetSizeI();
351 for (int i = 0; i < listenerCount; ++i)
352 {
353 allBlockingTcpSockets.PushBack(SocketListeners[i]);
354 }
355 }
357 {
358 Lock::Locker locker(&ConnectionsLock);
360 const int connectionCount = AllConnections.GetSizeI();
361 for (int i = 0; i < connectionCount; ++i)
362 {
363 Connection* arrayItem = AllConnections[i].GetPtr();
365 if (arrayItem->Transport == TransportType_PacketizedTCP)
366 {
367 PacketizedTCPConnection* ptcp = (PacketizedTCPConnection*)arrayItem;
369 allBlockingTcpSockets.PushBack(ptcp->pSocket);
370 }
371 else
372 {
373 OVR_ASSERT(false);
374 }
375 }
376 }
378 const int count = allBlockingTcpSockets.GetSizeI();
379 if (count > 0)
380 {
381 TCPSocketPollState state;
383 // Add all the sockets for polling,
384 for (int i = 0; i < count; ++i)
385 {
386 Net::TCPSocket* sock = allBlockingTcpSockets[i].GetPtr();
388 // If socket handle is invalid,
389 if (sock->GetSocketHandle() == INVALID_SOCKET)
390 {
391 OVR_DEBUG_LOG(("[Session] Detected an invalid socket handle - Treating it as a disconnection."));
392 sock->IsConnecting = false;
393 TCP_OnClosed(sock);
394 }
395 else
396 {
397 state.Add(sock);
398 }
399 }
401 // If polling returns with an event,
402 if (state.Poll(allBlockingTcpSockets[0]->GetBlockingTimeoutUsec(), allBlockingTcpSockets[0]->GetBlockingTimeoutSec()))
403 {
404 // Handle any events for each socket
405 for (int i = 0; i < count; ++i)
406 {
407 state.HandleEvent(allBlockingTcpSockets[i], this);
408 }
409 }
410 }
411 }
413 void Session::AddSessionListener(SessionListener* se)
414 {
415 Lock::Locker locker(&SessionListenersLock);
417 const int count = SessionListeners.GetSizeI();
418 for (int i = 0; i < count; ++i)
419 {
420 if (SessionListeners[i] == se)
421 {
422 // Already added
423 return;
424 }
425 }
427 SessionListeners.PushBack(se);
428 se->OnAddedToSession(this);
429 }
431 void Session::RemoveSessionListener(SessionListener* se)
432 {
433 Lock::Locker locker(&SessionListenersLock);
435 const int count = SessionListeners.GetSizeI();
436 for (int i = 0; i < count; ++i)
437 {
438 if (SessionListeners[i] == se)
439 {
440 se->OnRemovedFromSession(this);
442 SessionListeners.RemoveAtUnordered(i);
443 break;
444 }
445 }
446 }
447 SInt32 Session::GetActiveSocketsCount()
448 {
449 Lock::Locker locker1(&SocketListenersLock);
450 Lock::Locker locker2(&ConnectionsLock);
451 return SocketListeners.GetSize() + AllConnections.GetSize()>0;
452 }
453 Ptr<Connection> Session::AllocConnection(TransportType transport)
454 {
455 switch (transport)
456 {
457 case TransportType_Loopback: return *new Connection();
458 case TransportType_TCP: return *new TCPConnection();
459 case TransportType_PacketizedTCP: return *new PacketizedTCPConnection();
460 default:
461 OVR_ASSERT(false);
462 break;
463 }
465 return NULL;
466 }
468 Ptr<PacketizedTCPConnection> Session::findConnectionBySocket(Array< Ptr<Connection> >& connectionArray, Socket* s, int *connectionIndex)
469 {
470 const int count = connectionArray.GetSizeI();
471 for (int i = 0; i < count; ++i)
472 {
473 Connection* arrayItem = connectionArray[i].GetPtr();
475 if (arrayItem->Transport == TransportType_PacketizedTCP)
476 {
477 PacketizedTCPConnection* ptc = (PacketizedTCPConnection*)arrayItem;
479 if (ptc->pSocket == s)
480 {
481 if (connectionIndex)
482 {
483 *connectionIndex = i;
484 }
485 return ptc;
486 }
487 }
488 }
490 return NULL;
491 }
493 int Session::invokeSessionListeners(ReceivePayload* rp)
494 {
495 Lock::Locker locker(&SessionListenersLock);
497 const int count = SessionListeners.GetSizeI();
498 for (int j = 0; j < count; ++j)
499 {
500 ListenerReceiveResult lrr = LRR_CONTINUE;
501 SessionListeners[j]->OnReceive(rp, &lrr);
503 if (lrr == LRR_RETURN || lrr == LRR_BREAK)
504 {
505 break;
506 }
507 }
509 return rp->Bytes;
510 }
512 void Session::TCP_OnRecv(Socket* pSocket, uint8_t* pData, int bytesRead)
513 {
514 // KevinJ: 9/2/2014 Fix deadlock - Watchdog calls Broadcast(), which locks ConnectionsLock().
515 // Lock::Locker locker(&ConnectionsLock);
517 // Look for the connection in the full connection list first
518 int connIndex;
519 ConnectionsLock.DoLock();
520 Ptr<PacketizedTCPConnection> conn = findConnectionBySocket(AllConnections, pSocket, &connIndex);
521 ConnectionsLock.Unlock();
522 if (conn)
523 {
524 if (conn->State == State_Connected)
525 {
526 ReceivePayload rp;
527 rp.Bytes = bytesRead;
528 rp.pConnection = conn;
529 rp.pData = pData;
531 // Call listeners
532 invokeSessionListeners(&rp);
533 }
534 else if (conn->State == Client_ConnectedWait)
535 {
536 // Check the version data from the message
537 BitStream bsIn((char*)pData, bytesRead, false);
539 RPC_S2C_Authorization auth;
540 if (!auth.Deserialize(&bsIn) ||
541 !auth.Validate())
542 {
543 LogError("{ERR-001} [Session] REJECTED: OVRService did not authorize us: %s", auth.AuthString.ToCStr());
545 conn->SetState(State_Zombie);
546 invokeSessionEvent(&SessionListener::OnIncompatibleProtocol, conn);
547 }
548 else
549 {
550 // Read remote version
551 conn->RemoteMajorVersion = auth.MajorVersion;
552 conn->RemoteMinorVersion = auth.MinorVersion;
553 conn->RemotePatchVersion = auth.PatchVersion;
555 // Mark as connected
556 conn->SetState(State_Connected);
557 ConnectionsLock.DoLock();
558 int connIndex2;
559 if (findConnectionBySocket(AllConnections, pSocket, &connIndex2)==conn && findConnectionBySocket(FullConnections, pSocket, &connIndex2)==NULL)
560 {
561 FullConnections.PushBack(conn);
562 }
563 ConnectionsLock.Unlock();
564 invokeSessionEvent(&SessionListener::OnConnectionRequestAccepted, conn);
565 }
566 }
567 else if (conn->State == Server_ConnectedWait)
568 {
569 // Check the version data from the message
570 BitStream bsIn((char*)pData, bytesRead, false);
572 RPC_C2S_Hello hello;
573 if (!hello.Deserialize(&bsIn) ||
574 !hello.Validate())
575 {
576 LogError("{ERR-002} [Session] REJECTED: Rift application is using an incompatible version %d.%d.%d (my version=%d.%d.%d)",
577 hello.MajorVersion, hello.MinorVersion, hello.PatchVersion,
578 RPCVersion_Major, RPCVersion_Minor, RPCVersion_Patch);
580 conn->SetState(State_Zombie);
582 // Send auth response
583 BitStream bsOut;
584 RPC_S2C_Authorization::Generate(&bsOut, "Incompatible protocol version. Please make sure your OVRService and SDK are both up to date.");
585 conn->pSocket->Send(bsOut.GetData(), bsOut.GetNumberOfBytesUsed());
586 }
587 else
588 {
589 // Read remote version
590 conn->RemoteMajorVersion = hello.MajorVersion;
591 conn->RemoteMinorVersion = hello.MinorVersion;
592 conn->RemotePatchVersion = hello.PatchVersion;
594 // Send auth response
595 BitStream bsOut;
596 RPC_S2C_Authorization::Generate(&bsOut);
597 conn->pSocket->Send(bsOut.GetData(), bsOut.GetNumberOfBytesUsed());
599 // Mark as connected
600 conn->SetState(State_Connected);
601 ConnectionsLock.DoLock();
602 int connIndex2;
603 if (findConnectionBySocket(AllConnections, pSocket, &connIndex2)==conn && findConnectionBySocket(FullConnections, pSocket, &connIndex2)==NULL)
604 {
605 FullConnections.PushBack(conn);
606 }
607 ConnectionsLock.Unlock();
608 invokeSessionEvent(&SessionListener::OnNewIncomingConnection, conn);
610 }
611 }
612 else
613 {
614 OVR_ASSERT(false);
615 }
616 }
617 }
619 void Session::TCP_OnClosed(TCPSocket* s)
620 {
621 Lock::Locker locker(&ConnectionsLock);
623 // If found in the full connection list,
624 int connIndex;
625 Ptr<PacketizedTCPConnection> conn = findConnectionBySocket(AllConnections, s, &connIndex);
626 if (conn)
627 {
628 AllConnections.RemoveAtUnordered(connIndex);
630 // If in the full connection list,
631 if (findConnectionBySocket(FullConnections, s, &connIndex))
632 {
633 FullConnections.RemoveAtUnordered(connIndex);
634 }
636 // Generate an appropriate event for the current state
637 switch (conn->State)
638 {
639 case Client_Connecting:
640 invokeSessionEvent(&SessionListener::OnConnectionAttemptFailed, conn);
641 break;
642 case Client_ConnectedWait:
643 case Server_ConnectedWait:
644 invokeSessionEvent(&SessionListener::OnHandshakeAttemptFailed, conn);
645 break;
646 case State_Connected:
647 case State_Zombie:
648 invokeSessionEvent(&SessionListener::OnDisconnected, conn);
649 break;
650 default:
651 OVR_ASSERT(false);
652 break;
653 }
655 conn->SetState(State_Zombie);
656 }
657 }
659 void Session::TCP_OnAccept(TCPSocket* pListener, SockAddr* pSockAddr, SocketHandle newSock)
660 {
661 OVR_UNUSED(pListener);
662 OVR_ASSERT(pListener->Transport == TransportType_PacketizedTCP);
665 Ptr<PacketizedTCPSocket> newSocket = *new PacketizedTCPSocket(newSock, false);
666 // If pSockAddr is not localhost, then close newSock
667 if (pSockAddr->IsLocalhost()==false)
668 {
669 newSocket->Close();
670 return;
671 }
673 if (newSocket)
674 {
675 Ptr<Connection> b = AllocConnection(TransportType_PacketizedTCP);
676 Ptr<PacketizedTCPConnection> c = (PacketizedTCPConnection*)b.GetPtr();
677 c->pSocket = newSocket;
678 c->Address = *pSockAddr;
679 c->State = Server_ConnectedWait;
681 {
682 Lock::Locker locker(&ConnectionsLock);
683 AllConnections.PushBack(c);
684 }
686 // Server does not send the first packet. It waits for the client to send its version
687 }
688 }
690 void Session::TCP_OnConnected(TCPSocket *s)
691 {
692 Lock::Locker locker(&ConnectionsLock);
694 // If connection was found,
695 PacketizedTCPConnection* conn = findConnectionBySocket(AllConnections, s);
696 if (conn)
697 {
698 OVR_ASSERT(conn->State == Client_Connecting);
700 // Send hello message
701 BitStream bsOut;
702 RPC_C2S_Hello::Generate(&bsOut);
703 conn->pSocket->Send(bsOut.GetData(), bsOut.GetNumberOfBytesUsed());
705 // Just update state but do not generate any notifications yet
706 conn->State = Client_ConnectedWait;
707 }
708 }
710 void Session::invokeSessionEvent(void(SessionListener::*f)(Connection*), Connection* conn)
711 {
712 Lock::Locker locker(&SessionListenersLock);
714 const int count = SessionListeners.GetSizeI();
715 for (int i = 0; i < count; ++i)
716 {
717 (SessionListeners[i]->*f)(conn);
718 }
719 }
721 Ptr<Connection> Session::GetConnectionAtIndex(int index)
722 {
723 Lock::Locker locker(&ConnectionsLock);
725 const int count = FullConnections.GetSizeI();
727 if (index < count)
728 {
729 return FullConnections[index];
730 }
732 return NULL;
733 }
736 }} // OVR::Net