Parcourir la source

Fix lookup issue on TCPObject. Also sync with working code.

James Urquhart il y a 9 ans
Parent
commit
5b1bb6547a

+ 1 - 1
Engine/source/app/mainLoop.cpp

@@ -320,7 +320,7 @@ void StandardMainLoop::init()
    Sampler::init();
 
    // Hook in for UDP notification
-   Net::smPacketReceive.notify(GNet, &NetInterface::processPacketReceiveEvent);
+   Net::getPacketReceiveEvent().notify(GNet, &NetInterface::processPacketReceiveEvent);
 
    #ifdef TORQUE_DEBUG_GUARD
       Memory::flagCurrentAllocs( Memory::FLAG_Static );

+ 6 - 6
Engine/source/app/net/tcpObject.cpp

@@ -215,9 +215,9 @@ TCPObject::TCPObject()
 
    if(gTCPCount == 1)
    {
-      Net::smConnectionAccept.notify(processConnectedAcceptEvent);
-      Net::smConnectionReceive.notify(processConnectedReceiveEvent);
-      Net::smConnectionNotify.notify(processConnectedNotifyEvent);
+      Net::getConnectionAcceptedEvent().notify(processConnectedAcceptEvent);
+      Net::getConnectionReceiveEvent().notify(processConnectedReceiveEvent);
+      Net::getConnectionNotifyEvent().notify(processConnectedNotifyEvent);
    }
 }
 
@@ -230,9 +230,9 @@ TCPObject::~TCPObject()
 
    if(gTCPCount == 0)
    {
-      Net::smConnectionAccept.remove(processConnectedAcceptEvent);
-      Net::smConnectionReceive.remove(processConnectedReceiveEvent);
-      Net::smConnectionNotify.remove(processConnectedNotifyEvent);
+      Net::getConnectionAcceptedEvent().remove(processConnectedAcceptEvent);
+      Net::getConnectionReceiveEvent().remove(processConnectedReceiveEvent);
+      Net::getConnectionNotifyEvent().remove(processConnectedNotifyEvent);
    }
 }
 

+ 131 - 45
Engine/source/platform/platformNet.cpp

@@ -157,7 +157,25 @@ NetSocket NetSocket::INVALID = NetSocket::fromHandle(-1);
 template<class T> class ReservedSocketList
 {
 public:
-   Vector<T> mSocketList;
+   struct EntryType
+   {
+      T value;
+      bool used;
+
+      EntryType() : value(-1), used(false) { ; }
+
+      bool operator==(const EntryType &e1)
+      {
+         return value == e1.value && used == e1.used;
+      }
+
+      bool operator!=(const EntryType &e1)
+      {
+         return !(value == e1.value && used == e1.used);
+      }
+   };
+
+   Vector<EntryType> mSocketList;
    Mutex *mMutex;
 
    ReservedSocketList()
@@ -349,15 +367,20 @@ template<class T> NetSocket ReservedSocketList<T>::reserve(SOCKET reserveId, boo
       handle.lock(mMutex, true);
    }
 
-   S32 idx = mSocketList.find_next(-1);
+   S32 idx = mSocketList.find_next(EntryType());
    if (idx == -1)
    {
-      mSocketList.push_back(reserveId);
+      EntryType entry;
+      entry.value = reserveId;
+      entry.used = true;
+      mSocketList.push_back(entry);
       return NetSocket::fromHandle(mSocketList.size() - 1);
    }
    else
    {
-      mSocketList[idx] = reserveId;
+      EntryType &entry = mSocketList[idx];
+      entry.used = true;
+      entry.value = reserveId;
    }
 
    return NetSocket::fromHandle(idx);
@@ -374,7 +397,7 @@ template<class T> void ReservedSocketList<T>::remove(NetSocket socketToRemove, b
    if ((U32)socketToRemove.getHandle() >= (U32)mSocketList.size())
       return;
 
-   mSocketList[socketToRemove.getHandle()] = -1;
+   mSocketList[socketToRemove.getHandle()] = EntryType();
 }
 
 template<class T> T ReservedSocketList<T>::activate(NetSocket socketToActivate, int family, bool useUDP, bool clearOnFail)
@@ -388,7 +411,11 @@ template<class T> T ReservedSocketList<T>::activate(NetSocket socketToActivate,
    if ((U32)socketToActivate.getHandle() >= (U32)mSocketList.size())
       return -1;
 
-   T socketFd = mSocketList[socketToActivate.getHandle()];
+   EntryType &entry = mSocketList[socketToActivate.getHandle()];
+   if (!entry.used)
+      return -1;
+
+   T socketFd = entry.value;
    if (socketFd == -1)
    {
       socketFd = ::socket(family, typeID, protocol);
@@ -403,7 +430,8 @@ template<class T> T ReservedSocketList<T>::activate(NetSocket socketToActivate,
       }
       else
       {
-         mSocketList[socketToActivate.getHandle()] = socketFd;
+         entry.used = true;
+         entry.value = socketFd;
          return socketFd;
       }
    }
@@ -419,13 +447,34 @@ template<class T> T ReservedSocketList<T>::resolve(NetSocket socketToResolve)
    if ((U32)socketToResolve.getHandle() >= (U32)mSocketList.size())
       return -1;
 
-   return mSocketList[socketToResolve.getHandle()];
+   EntryType &entry = mSocketList[socketToResolve.getHandle()];
+   return entry.used ? entry.value : -1;
+}
+
+static ConnectionNotifyEvent*   smConnectionNotify = NULL;
+static ConnectionAcceptedEvent* smConnectionAccept = NULL;
+static ConnectionReceiveEvent*  smConnectionReceive = NULL;
+static PacketReceiveEvent*      smPacketReceive = NULL;
+
+ConnectionNotifyEvent& Net::getConnectionNotifyEvent()
+{
+   return *smConnectionNotify;
 }
 
-ConnectionNotifyEvent   Net::smConnectionNotify;
-ConnectionAcceptedEvent Net::smConnectionAccept;
-ConnectionReceiveEvent  Net::smConnectionReceive;
-PacketReceiveEvent      Net::smPacketReceive;
+ConnectionAcceptedEvent& Net::getConnectionAcceptedEvent()
+{
+   return *smConnectionAccept;
+}
+
+ConnectionReceiveEvent& Net::getConnectionReceiveEvent()
+{
+   return *smConnectionReceive;
+}
+
+PacketReceiveEvent& Net::getPacketReceiveEvent()
+{
+   return *smPacketReceive;
+}
 
 // Multicast stuff
 bool Net::smMulticastEnabled = true;
@@ -528,6 +577,12 @@ bool Net::init()
 #endif
    PlatformNetState::initCount++;
 
+   smConnectionNotify = new ConnectionNotifyEvent();
+   smConnectionAccept = new ConnectionAcceptedEvent();
+   smConnectionReceive = new ConnectionReceiveEvent();
+   smPacketReceive = new PacketReceiveEvent();
+
+
    Process::notify(&Net::process, PROCESS_NET_ORDER);
 
    return(true);
@@ -543,6 +598,12 @@ void Net::shutdown()
    closePort();
    PlatformNetState::initCount--;
 
+   // Destroy event handlers
+   delete smConnectionNotify;
+   delete smConnectionAccept;
+   delete smConnectionReceive;
+   delete smPacketReceive;
+
 #if defined(TORQUE_USE_WINSOCK)
    if(!PlatformNetState::initCount)
    {
@@ -628,7 +689,7 @@ NetSocket Net::openListenPort(U16 port, NetAddress::Type addressType)
 
    Net::Error error = NoError;
    NetAddress address;
-   if (!Net::getListenAddress(addressType, &address))
+   if (Net::getListenAddress(addressType, &address) != Net::NoError)
       error = Net::WrongProtocolType;
 
    NetSocket handleFd = NetSocket::INVALID;
@@ -767,14 +828,24 @@ NetSocket Net::openConnectTo(const char *addressString)
    {
       // need to do an asynchronous name lookup.  first, add the socket
       // to the polled list
-      char addressString[256];
-      Net::addressToString(&address, addressString);
-      addPolledSocket(handleFd, InvalidSocketHandle, PolledSocket::NameLookupRequired, addressString, address.port);
-      // queue the lookup
-      gNetAsync.queueLookup(addressString, handleFd);
+      char addr[256];
+      int port = 0;
+      int actualFamily = AF_UNSPEC;
+      if (PlatformNetState::extractAddressParts(addressString, addr, port, actualFamily))
+      {
+         addPolledSocket(handleFd, InvalidSocketHandle, PolledSocket::NameLookupRequired, addr, port);
+         // queue the lookup
+         gNetAsync.queueLookup(addressString, handleFd);
+      }
+      else
+      {
+         closeSocket(handleFd);
+         handleFd = NetSocket::INVALID;
+      }
    }
    else
    {
+      closeSocket(handleFd);
       handleFd = NetSocket::INVALID;
    }
 
@@ -802,20 +873,31 @@ void Net::closeConnectTo(NetSocket handleFd)
    closeSocket(handleFd);
 }
 
-Net::Error Net::sendtoSocket(NetSocket handleFd, const U8 *buffer, S32  bufferSize)
+Net::Error Net::sendtoSocket(NetSocket handleFd, const U8 *buffer, S32  bufferSize, S32 *outBufferWritten)
 {
    if(Journal::IsPlaying())
    {
       U32 e;
+      S32 outBytes;
       Journal::Read(&e);
+      Journal::Read(&outBytes);
+      if (outBufferWritten)
+         *outBufferWritten = outBytes;
 
       return (Net::Error) e;
    }
 
-   Net::Error e = send(handleFd, buffer, bufferSize);
+   S32 outBytes = 0;
+   Net::Error e = send(handleFd, buffer, bufferSize, &outBytes);
 
-   if(Journal::IsRecording())
+   if (Journal::IsRecording())
+   {
       Journal::Write(U32(e));
+      Journal::Write(outBytes);
+   }
+
+   if (outBufferWritten)
+      *outBufferWritten = outBytes;
 
    return e;
 }
@@ -854,7 +936,7 @@ bool Net::openPort(S32 port, bool doBind)
 
    if (Net::smIpv4Enabled)
    {
-      if (Net::getListenAddress(NetAddress::IPAddress, &address))
+      if (Net::getListenAddress(NetAddress::IPAddress, &address) == Net::NoError)
       {
          address.port = port;
          socketFd = ::socket(AF_INET, SOCK_DGRAM, protocol);
@@ -900,7 +982,7 @@ bool Net::openPort(S32 port, bool doBind)
    
    if (Net::smIpv6Enabled)
    {
-      if (Net::getListenAddress(NetAddress::IPV6Address, &address))
+      if (Net::getListenAddress(NetAddress::IPV6Address, &address) == Net::NoError)
       {
          address.port = port;
          socketFd = ::socket(AF_INET6, SOCK_DGRAM, protocol);
@@ -1062,7 +1144,7 @@ void Net::process()
          {
             Con::errorf("Error getting socket options: %s",  strerror(errno));
 
-            Net::smConnectionNotify.trigger(currentSock->handleFd, Net::ConnectFailed);
+            smConnectionNotify->trigger(currentSock->handleFd, Net::ConnectFailed);
             removeSock = true;
          }
          else
@@ -1079,13 +1161,13 @@ void Net::process()
                   break;
 
                currentSock->state = PolledSocket::Connected;
-               Net::smConnectionNotify.trigger(currentSock->handleFd, Net::Connected);
+               smConnectionNotify->trigger(currentSock->handleFd, Net::Connected);
             }
             else
             {
                // some kind of error
                Con::errorf("Error connecting: %s", strerror(errno));
-               Net::smConnectionNotify.trigger(currentSock->handleFd, Net::ConnectFailed);
+               smConnectionNotify->trigger(currentSock->handleFd, Net::ConnectFailed);
                removeSock = true;
             }
          }
@@ -1102,7 +1184,7 @@ void Net::process()
             {
                // got some data, post it
                readBuff.size = bytesRead;
-               Net::smConnectionReceive.trigger(currentSock->handleFd, readBuff);
+               smConnectionReceive->trigger(currentSock->handleFd, readBuff);
             }
             else
             {
@@ -1111,7 +1193,7 @@ void Net::process()
                   Con::errorf("Unexpected error on socket: %s", strerror(errno));
 
                // zero bytes read means EOF
-               Net::smConnectionNotify.trigger(currentSock->handleFd, Net::Disconnected);
+               smConnectionNotify->trigger(currentSock->handleFd, Net::Disconnected);
 
                removeSock = true;
             }
@@ -1119,7 +1201,7 @@ void Net::process()
          else if (err != Net::NoError && err != Net::WouldBlock)
          {
             Con::errorf("Error reading from socket: %s",  strerror(errno));
-            Net::smConnectionNotify.trigger(currentSock->handleFd, Net::Disconnected);
+            smConnectionNotify->trigger(currentSock->handleFd, Net::Disconnected);
             removeSock = true;
          }
          break;
@@ -1194,17 +1276,18 @@ void Net::process()
                if (::connect(currentSock->fd, ai_addr,
                   ai_addrlen) == -1)
                {
-                  if (errno == EINPROGRESS)
+                  err = PlatformNetState::getLastError();
+                  if (err != Net::WouldBlock)
                   {
-                     newState = Net::DNSResolved;
-                     currentSock->state = PolledSocket::ConnectionPending;
+                     Con::errorf("Error connecting to %s: %u",
+                     currentSock->remoteAddr, err);
+                     newState = Net::ConnectFailed;
+                     removeSock = true;
                   }
                   else
                   {
-                     Con::errorf("Error connecting to %s: %s",
-                        currentSock->remoteAddr, strerror(errno));
-                     newState = Net::ConnectFailed;
-                     removeSock = true;
+                     newState = Net::DNSResolved;
+                     currentSock->state = PolledSocket::ConnectionPending;
                   }
                }
                else
@@ -1215,7 +1298,7 @@ void Net::process()
             }
          }
 
-         Net::smConnectionNotify.trigger(currentSock->handleFd, newState);
+         smConnectionNotify->trigger(currentSock->handleFd, newState);
          break;
       case PolledSocket::Listening:
          NetAddress incomingAddy;
@@ -1225,7 +1308,7 @@ void Net::process()
          {
             setBlocking(incomingHandleFd, false);
             addPolledSocket(incomingHandleFd, PlatformNetState::smReservedSocketList.resolve(incomingHandleFd), Connected);
-            Net::smConnectionAccept.trigger(currentSock->handleFd, incomingHandleFd, incomingAddy);
+            smConnectionAccept->trigger(currentSock->handleFd, incomingHandleFd, incomingAddy);
          }
          break;
       }
@@ -1283,7 +1366,7 @@ void Net::processListenSocket(NetSocket socketHandle)
 
       tmpBuffer.size = bytesRead;
 
-      Net::smPacketReceive.trigger(srcAddress, tmpBuffer);
+      smPacketReceive->trigger(srcAddress, tmpBuffer);
    }
 }
 
@@ -1458,7 +1541,7 @@ Net::Error Net::setBlocking(NetSocket handleFd, bool blockingIO)
    return PlatformNetState::getLastError();
 }
 
-bool Net::getListenAddress(const NetAddress::Type type, NetAddress *address, bool forceDefaults)
+Net::Error Net::getListenAddress(const NetAddress::Type type, NetAddress *address, bool forceDefaults)
 {
    if (type == NetAddress::IPAddress)
    {
@@ -1468,7 +1551,7 @@ bool Net::getListenAddress(const NetAddress::Type type, NetAddress *address, boo
          address->type = type;
          address->port = PlatformNetState::defaultPort;
          *((U32*)address->address.ipv4.netNum) = INADDR_ANY;
-         return true;
+         return Net::NoError;
       }
       else
       {
@@ -1480,7 +1563,7 @@ bool Net::getListenAddress(const NetAddress::Type type, NetAddress *address, boo
       address->type = type;
       address->port = PlatformNetState::defaultPort;
       *((U32*)address->address.ipv4.netNum) = INADDR_BROADCAST;
-      return true;
+      return Net::NoError;
    }
    else if (type == NetAddress::IPV6Address)
    {
@@ -1494,7 +1577,7 @@ bool Net::getListenAddress(const NetAddress::Type type, NetAddress *address, boo
          addr.sin6_addr = in6addr_any;
 
          IPSocket6ToNetAddress(&addr, address);
-         return true;
+         return Net::NoError;
       }
       else
       {
@@ -1513,7 +1596,7 @@ bool Net::getListenAddress(const NetAddress::Type type, NetAddress *address, boo
    }
    else
    {
-      return false;
+      return Net::WrongProtocolType;
    }
 }
 
@@ -1537,7 +1620,7 @@ void Net::getIdealListenAddress(NetAddress *address)
    }
 }
 
-Net::Error Net::send(NetSocket handleFd, const U8 *buffer, S32 bufferSize)
+Net::Error Net::send(NetSocket handleFd, const U8 *buffer, S32 bufferSize, S32 *outBytesWritten)
 {
    SOCKET socketFd = PlatformNetState::smReservedSocketList.resolve(handleFd);
    if (socketFd == InvalidSocketHandle)
@@ -1552,6 +1635,9 @@ Net::Error Net::send(NetSocket handleFd, const U8 *buffer, S32 bufferSize)
       Con::errorf("Could not write to socket. Error: %s",strerror(errno));
 #endif
 
+   if (outBytesWritten)
+      *outBytesWritten = bytesWritten;
+
    return PlatformNetState::getLastError();
 }
 

+ 7 - 7
Engine/source/platform/platformNet.h

@@ -204,10 +204,10 @@ struct Net
 
    static const S32 MaxPacketDataSize = MAXPACKETSIZE;
 
-   static ConnectionNotifyEvent   smConnectionNotify;
-   static ConnectionAcceptedEvent smConnectionAccept;
-   static ConnectionReceiveEvent  smConnectionReceive;
-   static PacketReceiveEvent      smPacketReceive;
+   static ConnectionNotifyEvent&   getConnectionNotifyEvent();
+   static ConnectionAcceptedEvent& getConnectionAcceptedEvent();
+   static ConnectionReceiveEvent&  getConnectionReceiveEvent();
+   static PacketReceiveEvent&      getPacketReceiveEvent();
 
    static bool smMulticastEnabled;
    static bool smIpv4Enabled;
@@ -232,7 +232,7 @@ struct Net
    static NetSocket openListenPort(U16 port, NetAddress::Type = NetAddress::IPAddress);
    static NetSocket openConnectTo(const char *stringAddress); // does the DNS resolve etc.
    static void closeConnectTo(NetSocket socket);
-   static Error sendtoSocket(NetSocket socket, const U8 *buffer, S32 bufferSize);
+   static Error sendtoSocket(NetSocket socket, const U8 *buffer, S32 bufferSize, S32 *bytesWritten=NULL);
 
    static bool compareAddresses(const NetAddress *a1, const NetAddress *a2);
    static Net::Error stringToAddress(const char *addressString, NetAddress *address, bool hostLookup=true, int family=0);
@@ -242,7 +242,7 @@ struct Net
    static NetSocket openSocket();
    static Error closeSocket(NetSocket socket);
 
-   static Error send(NetSocket socket, const U8 *buffer, S32 bufferSize);
+   static Error send(NetSocket socket, const U8 *buffer, S32 bufferSize, S32 *outBytesWritten=NULL);
    static Error recv(NetSocket socket, U8 *buffer, S32 bufferSize, S32 *bytesRead);
 
    static Error connect(NetSocket socket, const NetAddress *address);
@@ -255,7 +255,7 @@ struct Net
    static Error setBlocking(NetSocket socket, bool blockingIO);
 
    /// Gets the desired default listen address for a specified address type
-   static bool getListenAddress(const NetAddress::Type type, NetAddress *address, bool forceDefaults=false);
+   static Net::Error getListenAddress(const NetAddress::Type type, NetAddress *address, bool forceDefaults=false);
    static void getIdealListenAddress(NetAddress *address);
 
    // Multicast for ipv6 local net browsing