netlib.cpp

Go to the documentation of this file.
00001 /*
00002  * netlib.cpp
00003  *
00004  * Copyright (C) 2007-2009  Thomas A. Vaughan
00005  * All rights reserved.
00006  *
00007  *
00008  * Redistribution and use in source and binary forms, with or without
00009  * modification, are permitted provided that the following conditions are met:
00010  *     * Redistributions of source code must retain the above copyright
00011  *       notice, this list of conditions and the following disclaimer.
00012  *     * Redistributions in binary form must reproduce the above copyright
00013  *       notice, this list of conditions and the following disclaimer in the
00014  *       documentation and/or other materials provided with the distribution.
00015  *     * Neither the name of the <organization> nor the
00016  *       names of its contributors may be used to endorse or promote products
00017  *       derived from this software without specific prior written permission.
00018  *
00019  * THIS SOFTWARE IS PROVIDED BY THOMAS A. VAUGHAN ''AS IS'' AND ANY
00020  * EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
00021  * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
00022  * DISCLAIMED. IN NO EVENT SHALL THOMAS A. VAUGHAN BE LIABLE FOR ANY
00023  * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
00024  * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
00025  * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
00026  * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
00027  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
00028  * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
00029  *
00030  *
00031  * Implementation of the networking library.  See netlib.h
00032  */
00033 
00034 // includes --------------------------------------------------------------------
00035 #include "netlib.h"             // always include our own header first!
00036 #include "wavesock.h"
00037 
00038 #include <deque>
00039 
00040 #ifndef _XOPEN_SOURCE
00041 #define _XOPEN_SOURCE 600
00042 #endif  // _XOPEN_SOURCE
00043 
00044 #include <string.h>
00045 
00046 #include "common/wave_ex.h"
00047 #include "perf/perf.h"
00048 #include "util/parsing.h"
00049 
00050 #ifdef WIN32
00051 typedef long ssize_t;
00052 #endif  // WIN32
00053 
00054 
00055 namespace netlib {
00056 
00057 
// Size, in bytes, of each connection's read buffer; writeMessage() also keeps
// every outgoing chunk below this size.  Use a small value when testing (so
// multi-chunk paths get exercised) and a large value for production.
static const int s_chunkSize = 8 * 1024;

// Capacity, in bytes, of the per-connection buffer that accumulates one
// message header line.
static const int s_maxHeaderLine = 64;
00062 
00063 struct request_t {
00064         // constructor, manipulators
00065         request_t(void) throw() { this->clear(); }
00066         void clear(void) throw() {
00067                         connId = 0;
00068                         msgbuf = NULL;
00069                 }
00070         bool is_empty(void) const throw() {
00071                         return (!connId && !msgbuf);
00072                 }
00073 
00074         // data fields
00075         conn_id_t                       connId;
00076         smart_ptr<MessageBuffer>        msgbuf;
00077 };
00078 
00079 
00080 
00081 // TODO: don't require memory alloc/free!  Keep a free list
00082 typedef std::deque<request_t> message_queue_t;
00083 
00084 
00085 
00086 // conn_rec_t : connection record
00087 struct conn_rec_t {
00088         // constructor, manipulators
00089         conn_rec_t(void) throw() : socket(-1) { this->clear(); }
00090         ~conn_rec_t(void) { this->clear(); }
00091         void clear(void) throw() { 
00092                         if (wsIsValidSocket(socket)) {
00093                                 DPRINTF("Closing down socket!");
00094                                 DPRINTF("Connection id = 0x%lx", (long) conn_id);
00095                                 wsCloseSocket(socket);
00096                         }
00097 
00098                         conn_id = 0;
00099                         local = 0;
00100                         socket = -1;
00101                         need_bytes = -1;
00102                         type = eType_Invalid;
00103                         msgbuf = NULL;
00104                         message_queue.clear();
00105                         send_byte = -1;
00106                         buffer[0] = 0;
00107                         buff_idx = -1;  // empty
00108                         buff_len = 0;
00109                         header[0] = 0;
00110                         head_idx = 0;
00111                         udpFrom.clear();
00112                         address.clear();
00113                 }
00114         void dump(IN const char * text) const throw() {
00115                         DPRINTF("%s", text);
00116                         DPRINTF("  connId: 0x%04lx", (long) conn_id);
00117                         DPRINTF("  socket: %d", socket);
00118                         DPRINTF("  type: %d", type);
00119                         address.dump(text);
00120                 }
00121 
00122         // data fields
00123         conn_id_t       conn_id;
00124         conn_id_t       local;          // local peer connection (UDP)
00125         int             socket;
00126         long            need_bytes;     // bytes needed to complete message
00127         smart_ptr<MessageBuffer> msgbuf;// long-lived message buffer
00128         eConnectionType type;           // what sort of connection?
00129         message_queue_t message_queue;  // pending messages to write
00130         address_t       address;        // where messages go
00131         address_t       udpFrom;        // for received UDP messages
00132         long            send_byte;      // current send byte
00133         char            buffer[s_chunkSize];    // buffer for reading
00134         int             buff_idx;       // where in buffer are we?
00135         ssize_t         buff_len;       // how much did we read?
00136         char            header[s_maxHeaderLine];
00137         int             head_idx;       // index into header line
00138 };
00139 
00140 // connection ID --> connection record
00141 typedef std::map<conn_id_t, smart_ptr<conn_rec_t> > conn_map_t;
00142 
00143 static conn_map_t s_connection_map;
00144 
00145 
00146 
00147 // stats!
00148 static dword_t s_messagesSent                   = 0;
00149 static dword_t s_messagesReceived               = 0;
00150 static qword_t s_bytesWritten                   = 0;
00151 static qword_t s_bytesRead                      = 0;
00152 
00153 
00154 // connection type names
00155 
00156 
00157 ////////////////////////////////////////////////////////////////////////////////
00158 //
00159 //      static helper methods
00160 //
00161 ////////////////////////////////////////////////////////////////////////////////
00162 
00163 static const char *
00164 getTypeName
00165 (
00166 IN eConnectionType type
00167 )
00168 {
00169         switch (type) {
00170 
00171         case eType_TCP:
00172                 return "TCP client";
00173 
00174         case eType_UDPLocal:
00175                 return "Local UDP port";
00176 
00177         case eType_UDPRemote:
00178                 return "Remote UDP port";
00179 
00180         case eType_TCPListener:
00181                 return "Local TCP Listener";
00182 
00183         default:
00184                 break;
00185         }
00186         return "Unknown connection type!";
00187 }
00188 
00189 
00190 
00191 static conn_id_t
00192 getNewConnectionId
00193 (
00194 void
00195 )
00196 {
00197         // originally I was just using a dumb counter, to avoid having
00198         // to look for collisions etc.  But counters can wrap!  So I'm
00199         // using random numbers
00200         static const dword_t s_dwMax = 0x10000 - 1;
00201         for (;;) {
00202                 conn_id_t conn_id = 1 + (rand() % s_dwMax);
00203                 if (s_connection_map.end() == s_connection_map.find(conn_id))
00204                         return conn_id;
00205         }
00206 }
00207 
00208 
00209 
00210 static void
00211 dumpErrorInfo
00212 (
00213 IN const char * msg
00214 )
00215 {
00216         const int s_bufsize = 256;
00217         char buffer[s_bufsize];
00218 
00219         wsGetErrorMessage(buffer, s_bufsize);
00220 
00221         DPRINTF("%s", msg);
00222         DPRINTF("%s", buffer);
00223 }
00224 
00225 
00226 
00227 static void
00228 verify
00229 (
00230 IN bool isOK,
00231 IN const char * msg
00232 )
00233 {
00234         if (isOK)
00235                 return;         // no problem!
00236 
00237         // if we're here, there is a big problem!
00238         dumpErrorInfo(msg);
00239         ASSERT(false, "halting");
00240 }
00241 
00242 
00243 
00244 static void
00245 verifyThrow
00246 (
00247 IN bool isOK,
00248 IN const char * msg
00249 )
00250 {
00251         if (isOK)
00252                 return;         // no problem!
00253 
00254         const int s_bufsize = 256;
00255         char buffer[s_bufsize];
00256         wsGetErrorMessage(buffer, s_bufsize);
00257 
00258         ASSERT(msg, "null");
00259         DPRINTF("Error!  '%s' on '%s'", buffer, msg);
00260 
00261         WAVE_EX(wex);
00262         wex << msg << "\n";
00263         wex << "Error: " << buffer;
00264 }
00265 
00266 
00267 
00268 static conn_rec_t *
00269 getConnectionRecord
00270 (
00271 IN conn_id_t conn_id
00272 )
00273 {
00274         ASSERT(conn_id, "null");
00275 
00276         conn_map_t::iterator i = s_connection_map.find(conn_id);
00277         if (s_connection_map.end() == i) {
00278                 DPRINTF("Connection ID 0x%lx not found!", (long) conn_id);
00279                 return NULL;
00280         }
00281 
00282         return i->second;
00283 }
00284 
00285 
00286 
00287 static bool
00288 readBuffer
00289 (
00290 IN conn_rec_t * rec
00291 )
00292 {
00293         // timer itself impacts timing!
00294         // perf::Timer timer("netlib::readBuffer");
00295         ASSERT(rec, "null");
00296 
00297         // reset
00298         rec->buffer[0] = 0;
00299         rec->buff_len = 0;
00300         rec->buff_idx = -1;
00301 
00302         // make nonblocking call to see if we have any data to read
00303         // NOTE: the type of read call depends on the type of socket!
00304         ssize_t bytes = -2;
00305         if (eType_TCP == rec->type) {
00306                 bytes = wsReceive(rec->socket, rec->buffer, s_chunkSize - 1);
00307         } else if (eType_UDPLocal == rec->type) {
00308                 // DPRINTF("Reading UDP packet...");
00309                 bytes = wsReceiveFrom(rec->socket, rec->buffer, s_chunkSize - 1,
00310                     rec->udpFrom);
00311         } else {
00312                 DPRINTF("Bad local connection type?  Disconnecting...");
00313                 bytes = 0;
00314         }
00315 
00316         // error?
00317         if (bytes < 0) {
00318                 if (eWS_Again == wsGetError()) {
00319                         // no problem, try again later
00320                         return false;
00321                 }
00322                 DPRINTF("Error receiving!  Disconnecting client...");
00323                 bytes = 0;
00324         }
00325 
00326         // client gave up?
00327         if (0 == bytes) {
00328                 DPRINTF("Client has disconnected");
00329                 closeConnection(rec->conn_id);
00330                 return false;
00331         }
00332 
00333         s_bytesRead += bytes;
00334 
00335         // received data!
00336         if (bytes >= s_chunkSize) {
00337                 DPRINTF("ERROR: client sent too many bytes!");
00338                 DPRINTF("Our buffer size is %d bytes",
00339                     s_chunkSize - 1);
00340                 DPRINTF("Client sent %ld bytes", (long) bytes);
00341                 DPRINTF("Truncating data!");
00342                 ASSERT(false, "HALT");  // this is bad
00343                 bytes = s_chunkSize - 1;
00344         }
00345         rec->buffer[bytes] = 0; // force null-termination
00346         rec->buff_len = bytes;
00347         rec->buff_idx = 0;
00348 
00349         return true;
00350 }
00351 
00352 
00353 
00354 static void
00355 parseHeaderLine
00356 (
00357 IN conn_rec_t * rec
00358 )
00359 {
00360         // at the moment, this routine is fast enough that the timer itself
00361         //   adds significant time!
00362         // perf::Timer timer("netlib::parseHeaderLine");
00363         ASSERT(rec, "null");
00364         ASSERT(rec->head_idx >= 0 && rec->head_idx < s_maxHeaderLine,
00365             "Bad header byte index: %d", rec->head_idx);
00366 
00367         // TODO: avoid allocations here!  (use of std::string)
00368 //      rec->dump("Parsing header");
00369 
00370         // end of line!  Interesting?
00371         rec->header[rec->head_idx] = 0; // null-terminate
00372         std::string key;
00373         const char * p = getNextTokenFromString(rec->header, key, eParse_None);
00374         //DPRINTF("key = '%s'", key.c_str());
00375 
00376         std::string val;
00377         getNextTokenFromString(p, val, eParse_None);
00378         //DPRINTF("val = '%s'", val.c_str());
00379 
00380         // now what?
00381         // At the moment, the header consists of a single line: the byte
00382         //  count (size), which is identified by a leading "s" character.
00383         if ("s" == key) {
00384                 rec->need_bytes = atol(val.c_str());
00385                 //DPRINTF("Message bytes: %ld", rec->need_bytes);
00386                 if (rec->need_bytes <= 0) {
00387                         DPRINTF("Bad byte count? %ld", rec->need_bytes);
00388                         rec->need_bytes = 0;
00389                 }
00390         } else {
00391                 //DPRINTF("Unknown message header key!");
00392         }
00393 }
00394 
00395 
00396 
00397 static bool
00398 handleData
00399 (
00400 IN conn_rec_t * rec,
00401 IO envelope_t& envelope,
00402 IO smart_ptr<MessageBuffer>& msgbuf
00403 )
00404 {
00405         // timer itself impacts timing!
00406         // perf::Timer timer("netlib::handleData");
00407         ASSERT(rec, "null");
00408         ASSERT(envelope.is_empty(), "not empty");
00409         ASSERT(!msgbuf, "should be null");
00410 
00411 //      rec->dump("Reading data");
00412 
00413         // DPRINTF("Got data to read!");
00414 
00415         // keep reading!
00416         const char * p = NULL;
00417         while (true) {
00418 
00419                 if (rec->buff_idx >= rec->buff_len) {
00420                         rec->buff_idx = -1;
00421                 }
00422 
00423                 // DPRINTF("Starting loop, idx = %d", rec->buff_idx);
00424                 if (rec->buff_idx < 0) {
00425                         // need to read buffer!
00426                         if (!readBuffer(rec))
00427                                 return false;   // couldn't read
00428                         if (rec->buff_idx < 0) {
00429                                 return false;   // failed to read anyway
00430                         }
00431                         //DPRINTF("  Read %d bytes", rec->buff_len);
00432                 }
00433                 p = rec->buffer + rec->buff_idx;
00434                 //DPRINTF("  After message read, idx=%d", rec->buff_idx);
00435                 //DPRINTF("  %d bytes remaining in buffer",
00436                 //    rec->buff_len - rec->buff_idx);
00437 
00438                 // parse message headers if necessary
00439                 const char * maxP = rec->buffer + rec->buff_len;
00440                 for (; p < maxP && rec->need_bytes < 0; ++p) {
00441                         
00442                         // are we getting a lot of garbage from client?
00443                         if (rec->head_idx >= s_maxHeaderLine - 1) {
00444                                 // too big!  Reset
00445                                 DPRINTF("Garbage from remote host!  Resetting");
00446                                 rec->head_idx = 0;
00447                         }
00448 
00449                         // push new byte to end of our header buffer
00450                         rec->header[rec->head_idx] = *p;
00451                         rec->head_idx++;
00452 
00453                         if ('\n' == *p) {
00454                                 parseHeaderLine(rec);
00455                                 rec->head_idx = 0;
00456                         } else if (!*p) {
00457                                 // null in header line?  Weird!
00458                                 rec->head_idx = 0;
00459                         }
00460                 }
00461 
00462                 // should be message data
00463                 int remain = rec->buff_len - (p - rec->buffer);
00464                 // DPRINTF("%d bytes remain", remain);
00465                 rec->buff_idx = p - rec->buffer;
00466 
00467                 if (!remain) {
00468 //                      DPRINTF("end of buffer");       // very common!
00469                         continue;       // end of buffer
00470                 }
00471 
00472                 // no point in proceeding if we aren't ready
00473                 if (rec->need_bytes < 0) {
00474                         DPRINTF("aren't ready");
00475                         ASSERT(!*p, "should be out of data");
00476 //                      rec->need_bytes = 0;
00477                         return false;   // ran out of data from read
00478                 }
00479 
00480                 // what we expected?
00481                 if (rec->need_bytes < remain) {
00482 //                      DPRINTF("ERROR: received more bytes than expected!");
00483 //                      DPRINTF("  expected: %ld", rec->need_bytes);
00484 //                      DPRINTF("  received: %d", remain);
00485 //                      DPRINTF("  truncating!!!");  - NO!  Not truncating...
00486                         // not a problem: we just take what we need
00487                         remain = rec->need_bytes;
00488                 }
00489 
00490                 // need to create buffer?
00491                 if (!rec->msgbuf) {
00492                         rec->msgbuf = MessageBuffer::create();
00493                         ASSERT(rec->msgbuf, "failed to create message buffer?");
00494                         rec->msgbuf->reserve(rec->need_bytes + 1);
00495                 }
00496 
00497                 // append
00498                 rec->msgbuf->appendData(p, remain);
00499 
00500                 // decrement
00501                 //DPRINTF("  Copied %d bytes...", remain);
00502                 rec->need_bytes -= remain;
00503                 rec->buff_idx = p + remain - rec->buffer;
00504                 if (rec->need_bytes < 1) {
00505                         //DPRINTF("Message is now complete!");
00506                         rec->msgbuf->close();
00507                         //DPRINTF("Message size: %ld bytes",
00508                         //    rec->msgbuf->getBytes());
00509 
00510                         // hand buffer over to message
00511                         msgbuf = rec->msgbuf;
00512 
00513                         // construct envelope information
00514                         envelope.fromConnId = rec->conn_id;
00515                         envelope.type = rec->type;
00516 
00517                         // Need to swap out for UDP!
00518                         // We read from local UDP port (of course), but client
00519                         // needs to know which remote UDP client sent this.
00520                         if (eType_UDPLocal == envelope.type) {
00521                                 envelope.type = eType_UDPRemote;
00522                                 envelope.fromConnId = 0;
00523                                 envelope.address = rec->udpFrom;
00524                         }
00525 
00526                         // give up ownership and clean up
00527                         rec->msgbuf = 0;
00528                         rec->need_bytes = -1;
00529                         s_messagesReceived++;
00530 
00531                         return true;
00532                 }
00533         }
00534 
00535         // nope
00536         return false;
00537 }
00538 
00539 
00540 
00541 static conn_id_t
00542 addConnectionRecord
00543 (
00544 IN eConnectionType type,
00545 IN int socket,
00546 IN const address_t& address
00547 )
00548 {
00549         ASSERT(eType_Invalid != type, "Bad type");
00550         ASSERT(socket > -2, "bad socket");
00551 
00552 //      DPRINTF("Creating connection record for '%s':%d ...", host, port);
00553 
00554         if (-1 == socket) {
00555                 ASSERT(eType_UDPRemote == type,
00556                     "Bad connection type (%d) for socket %d", type, socket);
00557         }
00558 
00559         // construct connection record and put in threadsafe map...
00560         smart_ptr<conn_rec_t> rec = new conn_rec_t;
00561         ASSERT(rec, "out of memory?");
00562         rec->socket = socket;
00563         rec->address = address;
00564         rec->type = type;
00565         rec->conn_id = getNewConnectionId();
00566 //      DPRINTF("  Assigning connection id = 0x%04lx", rec->conn_id);
00567 //      rec->dump("Just created");
00568 
00569         // add to map
00570         s_connection_map[rec->conn_id] = rec;
00571         ASSERT(2 == rec.get_ref_count(), "should have 2 refs!");
00572 
00573         //DPRINTF("Currently have %d connections", s_connection_map.size());
00574 
00575         return rec->conn_id;
00576 }
00577 
00578 
00579 
00580 static conn_id_t
00581 handleConnection
00582 (
00583 IN conn_rec_t * rec
00584 )
00585 {
00586         perf::Timer timer("netlib::handleConnection");
00587         ASSERT(rec, "null");
00588         ASSERT(eType_TCPListener == rec->type,
00589             "Requesting to listen on a non-listening socket?");
00590         ASSERT(wsIsValidSocket(rec->socket), "bad socket? %d", rec->socket);
00591 
00592         address_t address;
00593         int c = wsAccept(rec->socket, address);
00594         if (!wsIsValidSocket(c)) {
00595                 // nobody wanted to connect!
00596                 DPRINTF("Not a valid connection?");
00597                 return 0;
00598         }
00599 //      address.dump("New connection");
00600 
00601         return addConnectionRecord(eType_TCP, c, address);
00602 }
00603 
00604 
00605 
00606 static void
00607 writeMessage
00608 (
00609 IN conn_rec_t * rec
00610 )
00611 {
00612         perf::Timer timer("netlib::writeMessage");
00613         ASSERT(rec, "null");
00614 
00615         ASSERT(rec->message_queue.size(), "empty queue?");
00616 
00617         //DPRINTF("Have message to write!");
00618 
00619         // get the first message in the queue
00620         const request_t& req = rec->message_queue.front();
00621         ASSERT(!req.is_empty(), "empty message in queue?");
00622         ASSERT(req.connId, "null");
00623 
00624         // for UDP, need additional stuff...
00625         conn_rec_t * recTo = NULL;
00626         if (eType_UDPLocal == rec->type) {
00627                 // we must be sending to a remote UDP port
00628                 recTo = getConnectionRecord(req.connId);
00629                 ASSERT(recTo, "null entry for UDP receiver");
00630                 ASSERT(eType_UDPRemote == recTo->type,
00631                     "receiver is wrong type of connection");
00632         } else if (eType_UDPBroadcast == rec->type) {
00633                 // broadcast: we know where to send
00634                 recTo = rec;
00635         }
00636 
00637         // try to write header?
00638         if (rec->send_byte < 0) {
00639                 // yes, header has not yet been sent
00640 
00641                 // construct header and send
00642                 const int s_headerBytes = 32;
00643                 char header[s_headerBytes];
00644                 sprintf(header, "\ns %ld\n", req.msgbuf->getBytes());
00645                 //DPRINTF("Header:\n-----%s-----", header);
00646                 int hbytes = strlen(header);
00647                 //DPRINTF("About to send header...");
00648                 //DPRINTF(" hbytes = %d", hbytes);
00649                 //DPRINTF(" s = %d", rec->socket);
00650                 long bytes;
00651                 if (!recTo) {
00652                         bytes = wsSend(rec->socket, header, hbytes);
00653                 } else {
00654                         // recTo->address.dump("Sending here");
00655                         bytes = wsSendTo(rec->socket, header, hbytes,
00656                             recTo->address);
00657                 }
00658                 if (bytes < 0) {
00659                         // error!  bad or not?
00660                         if (eWS_Again == wsGetError()) {
00661                                 //DPRINTF("header send failed, will try again");
00662                                 return;         // quietly fail, try again later
00663                         }
00664                         DPRINTF("Error writing to client!  will close connection");
00665                         closeConnection(rec->conn_id);
00666                         return;
00667                 }
00668                 verifyThrow(bytes == hbytes, "failed to send complete header!");
00669                 s_bytesWritten += bytes;
00670 
00671                 // header successfully sent!
00672                 //DPRINTF("Successfully sent message header");
00673                 rec->send_byte = 0;     // no data sent yet
00674         }
00675 
00676         // keep sending...
00677         while (1) {
00678                 long to_send = req.msgbuf->getBytes() - rec->send_byte;
00679                 if (to_send < 1) {
00680                         //DPRINTF("Completed message send!");
00681 
00682                         // all done
00683                         s_messagesSent++;
00684                         break;
00685                 }
00686 
00687                 // don't send more than the reader can handle
00688                 if (to_send >= s_chunkSize)
00689                         to_send = s_chunkSize - 1;
00690 
00691                 const char * data = req.msgbuf->getData() + rec->send_byte;
00692                 long sent;
00693                 if (!recTo) {
00694                         sent = wsSend(rec->socket, data, to_send);
00695                 } else {
00696                         sent = wsSendTo(rec->socket, data, to_send,
00697                             recTo->address);
00698                 }
00699                 if (sent < 0) {
00700                         // error!  bad or not?
00701                         if (eWS_Again == wsGetError()) {
00702                                 DPRINTF("socket send() failed, will try again");
00703                                 return;
00704                         }
00705                         DPRINTF("Error writing to client!  will close connection");
00706                         closeConnection(rec->conn_id);
00707                         return;
00708                 }
00709                 s_bytesWritten += sent;
00710                 // DPRINTF("  sent %ld bytes", sent);
00711 
00712                 // update stats on bytes sent
00713                 rec->send_byte += sent;
00714         }
00715 
00716         // update socket data (can do this in place)
00717         rec->message_queue.pop_front();         // done with message!
00718         rec->send_byte = -1;                    // reset sent count
00719 }
00720 
00721 
00722 
00723 static void
00724 handleWrites
00725 (
00726 IN ws_set_t writeable
00727 )
00728 {
00729         // timer itself impacts timing!
00730         // perf::Timer timer("netlib::handleWrites");
00731         ASSERT(writeable, "null");
00732 
00733         for (conn_map_t::iterator i = s_connection_map.begin();
00734              i != s_connection_map.end(); ++i) {
00735                 conn_rec_t * rec = i->second;
00736                 ASSERT(rec, "null connection record");
00737                 int s = rec->socket;
00738                 if (!wsIsValidSocket(s))
00739                         continue;       // not a real socket!
00740 
00741                 if (!wsIsSocketInSet(writeable, s))
00742                         continue;       // socket isn't writeable
00743 
00744                 // DPRINTF("Have a message to write!");
00745                 if (!rec->message_queue.size())
00746                         continue;       // nothing to write anyway!
00747 
00748                 writeMessage(rec);
00749         }
00750 }
00751 
00752 
00753 
00754 static bool
00755 handleRead
00756 (
00757 IN ws_set_t readers,
00758 IO envelope_t& envelope,
00759 IO smart_ptr<MessageBuffer>& msgbuf
00760 )
00761 {
00762         // timer itself affects timing!
00763         // perf::Timer timer("netlib::handleRead");
00764         ASSERT(readers, "null");
00765         ASSERT(envelope.is_empty(), "not empty");
00766         ASSERT(!msgbuf, "should be null");
00767 
00768         for (conn_map_t::iterator i = s_connection_map.begin();
00769              i != s_connection_map.end(); ++i) {
00770                 conn_rec_t * rec = i->second;
00771                 ASSERT(rec, "null connection record");
00772                 int s = rec->socket;
00773                 if (!wsIsValidSocket(s))
00774                         continue;               // not a real socket
00775 
00776                 if (rec->buff_idx >= 0) {
00777                         // rec->dump("More data in buffer");
00778                         if (handleData(rec, envelope, msgbuf))
00779                                 return true;
00780                 }
00781 
00782                 if (!wsIsSocketInSet(readers, s))
00783                         continue;       // socket not impacted
00784 
00785                 if (eType_TCPListener == rec->type) {
00786                         // DPRINTF("Got a request to connect!");
00787                         handleConnection(rec);
00788                         return false;
00789                 } else {
00790                         // DPRINTF("Received data!");
00791                         // rec->dump("New incoming data");
00792                         if (handleData(rec, envelope, msgbuf))
00793                                 return true;
00794                 }
00795         }
00796 
00797         // can get here if no sockets were available for reads!
00798         return false;
00799 }
00800 
00801 
00802 
00803 ////////////////////////////////////////////////////////////////////////////////
00804 //
00805 //      public API
00806 //
00807 ////////////////////////////////////////////////////////////////////////////////
00808 
00809 std::string
00810 getServerFromIP
00811 (
00812 IN const ip_addr_t& ip
00813 )
00814 {
00815         // TODO: DNS reverse lookup!
00816         // for now, just string-ify the IP address
00817         ASSERT(1 == ip.flags, "Only works with IPv4 addresses for now...");
00818 
00819         char buffer[64];
00820         sprintf(buffer, "%d.%d.%d.%d",
00821             ip.addr[0], ip.addr[1], ip.addr[2], ip.addr[3]);
00822 
00823         return buffer;
00824 }
00825 
00826 
00827 
00828 conn_id_t
00829 createTcpListener
00830 (
00831 IN const address_t& address,
00832 IN int maxBacklog
00833 )
00834 {
00835         perf::Timer timer("netlib::createTcpListener");
00836         ASSERT(address.isValid(), "Invalid address");
00837         ASSERT(maxBacklog > 0, "bad max backlog: %d", maxBacklog);
00838 
00839         // set up listening socket
00840         int s = wsCreateTcpSocket();
00841         // DPRINTF("TCP listening socket: %d", s);
00842         verify(wsIsValidSocket(s), "Failed to create tcp listening socket");
00843 
00844         // bind (name) the socket
00845         verify(!wsBindToPort(s, address.port),
00846             "Failed to bind tcp listening socket");
00847 
00848         // set up for listening
00849         verify(!wsListen(s, maxBacklog),
00850             "Failed to set up tcp socket for listening");
00851 
00852         // add to our map
00853         return addConnectionRecord(eType_TCPListener, s, address);
00854 }
00855 
00856 
00857 
00858 conn_id_t
00859 createTcpConnection
00860 (
00861 IN const address_t& address
00862 )
00863 {
00864         perf::Timer timer("netlib::createTcpConnection");
00865         ASSERT2(address.isValid(),
00866             "invalid address--cannot create TCP connection");
00867 
00868         // create socket
00869         int c = wsCreateTcpSocket();
00870 
00871         // DPRINTF("TCP connection socket: %d", c);
00872         verifyThrow(wsIsValidSocket(c),
00873             "Failed to create socket for tcp connection");
00874 
00875         // set up address data
00876         verifyThrow(-1 != wsConnect(c, address),
00877             "Failed to connect to server");
00878 
00879 //      DPRINTF("Connected!");
00880         return addConnectionRecord(eType_TCP, c, address);
00881 }
00882 
00883 
00884 
00885 conn_id_t
00886 createUdpLocal
00887 (
00888 IN const address_t& address
00889 )
00890 {
00891         perf::Timer timer("netlib::createUdpLocal");
00892         ASSERT(address.isValid(), "invalid address");
00893 
00894         // create socket
00895         int c = wsCreateUdpSocket(false);       // not broadcast
00896         DPRINTF("Local UDP socket: %d", c);
00897         verifyThrow(wsIsValidSocket(c), "Failed to create socket for local udp");
00898 
00899         // bind to local port
00900         verifyThrow(-1 != wsBindToPort(c, address.port),
00901             "Failed to bind to local UDP socket");
00902 
00903         // all done!
00904         return addConnectionRecord(eType_UDPLocal, c, address);
00905 }
00906 
00907 
00908 
00909 conn_id_t
00910 createUdpRemote
00911 (
00912 IN conn_id_t localUdp,
00913 IN const address_t& address
00914 )
00915 {
00916         perf::Timer timer("netlib::createUdpRemote");
00917         ASSERT(localUdp, "null");
00918         ASSERT2(address.isValid(),
00919             "Address is invalid--cannot create remote UDP connection");
00920 
00921         // create connection record
00922         conn_id_t conn_id =
00923             addConnectionRecord(eType_UDPRemote, -1, address);
00924         ASSERT(conn_id, "null");
00925 
00926         // retrieve it because we're going to tweak it...
00927         conn_rec_t * rec = getConnectionRecord(conn_id);
00928         ASSERT(rec, "null");
00929         rec->local = localUdp;
00930         DPRINTF("Added remote udp connection");
00931 
00932         // all done!
00933         return conn_id;
00934 }
00935 
00936 
00937 
00938 conn_id_t
00939 createUdpBroadcast
00940 (
00941 IN const address_t& broadcastAddress
00942 )
00943 {
00944         ASSERT2(broadcastAddress.isValid(), "Invalid broadcast address");
00945 
00946         // create socket
00947         int s = wsCreateUdpSocket(true);        // yes, broadcast
00948         verifyThrow(wsIsValidSocket(s),
00949             "Failed to create socket for udp broadcast");
00950 
00951         // create connection record
00952         return addConnectionRecord(eType_UDPBroadcast, s, broadcastAddress);
00953 }
00954 
00955 
00956 
00957 bool
00958 enqueueMessage
00959 (
00960 IN conn_id_t conn_id,
00961 IN smart_ptr<MessageBuffer>& message
00962 )
00963 {
00964         // at the moment, this routine is fast enough that the timer itself
00965         //      adds significant time!
00966         // perf::Timer timer("netlib::enqueueMessage");
00967         ASSERT(conn_id, "null");
00968         ASSERT(message, "null");
00969 
00970         // look for this connection
00971         conn_rec_t * rec = getConnectionRecord(conn_id);
00972         if (!rec) {
00973                 DPRINTF("Connection id not recognized? 0x%04lx", (long) conn_id);
00974                 return false;
00975         }
00976         ASSERT(rec, "null record in map for connection id 0x%04lx", (long) conn_id);
00977         // DPRINTF("Enqueuing message for connection 0x%04lx", conn_id);
00978         // DPRINTF("  Local protocol is %d", rec->type);
00979 
00980         // improper connection type for sending?
00981         if (eType_UDPLocal == rec->type) {
00982                 DPRINTF("Cannot send messages to local UDP!  Send to remote");
00983                 return false;
00984         }
00985 
00986         // udp?  In that case, need to swap local + remote
00987         if (eType_UDPRemote == rec->type) {
00988                 ASSERT(rec->local, "null local UDP?");
00989                 conn_rec_t * localRec = getConnectionRecord(rec->local);
00990                 ASSERT(localRec, "local udp sender disappeared");
00991                 rec = localRec;         // swap out and queue here
00992         }
00993         ASSERT(wsIsValidSocket(rec->socket),
00994             "null socket in connection record for 0x%04lx", (long) conn_id);
00995 
00996         // add to message queue
00997         request_t req;
00998         req.connId = conn_id;
00999         req.msgbuf = message;
01000         rec->message_queue.push_back(req);
01001 
01002         // all done
01003         return true;
01004 }
01005 
01006 
01007 
01008 
01009 bool
01010 getNextMessage
01011 (
01012 IN long wait_microseconds,
01013 OUT envelope_t& envelope,
01014 OUT smart_ptr<MessageBuffer>& msgbuf
01015 )
01016 {
01017         perf::Timer timer("netlib::getNextMessage");
01018         ASSERT(wait_microseconds >= 0, "Bad wait: %ld", wait_microseconds);
01019         envelope.clear();
01020         ASSERT(!msgbuf, "unecessary free?");
01021 
01022         // arbitrary assertion here
01023         ASSERT(wait_microseconds < 1000 * 1000, "Wait is too long! %ld usec",
01024             wait_microseconds);
01025 
01026         // mask of types that can read
01027         dword_t readMask = eType_TCP | eType_UDPLocal | eType_TCPListener;
01028 
01029         // see if any sockets have data for us
01030         static ws_set_t readers = 0;
01031         if (!readers) {
01032                 readers = wsCreateSet();
01033         }
01034         wsClearSet(readers);
01035 
01036         // also see if any sockets are ready to be sent out on
01037         static ws_set_t writeable = 0;
01038         if (!writeable) {
01039                 writeable = wsCreateSet();
01040         }
01041         wsClearSet(writeable);
01042 
01043         // add all sockets we care about
01044         int max = 0;
01045         for (conn_map_t::iterator i = s_connection_map.begin();
01046              i != s_connection_map.end(); ++i) {
01047                 const conn_rec_t * rec = i->second;
01048                 ASSERT(eType_Invalid != rec->type, "invalid connection type?");
01049 
01050                 // only read from certain sockets
01051                 if (readMask & rec->type) {
01052                         wsAddSocketToSet(readers, rec->socket);
01053                 }
01054 
01055                 // interested in writing if message is pending
01056                 if (rec->message_queue.size() > 0) {
01057                         wsAddSocketToSet(writeable, rec->socket);
01058                 }
01059 
01060                 // update max socket ID?
01061                 if (rec->socket > max)
01062                         max = rec->socket;
01063         }
01064 
01065         // go get 'em
01066         int count = wsSelect(max + 1, readers, writeable, wait_microseconds);
01067         if (count < 0) {
01068                 dumpErrorInfo("select() call failed");
01069                 return false;
01070         }
01071 
01072         // first see if anything can be written
01073         handleWrites(writeable);
01074 
01075         // okay, see if there is anything to read!
01076         return handleRead(readers, envelope, msgbuf);
01077 }
01078 
01079 
01080 
01081 bool
01082 isValidConnection
01083 (
01084 IN conn_id_t conn_id
01085 )
01086 {
01087         return (NULL != getConnectionRecord(conn_id));
01088 }
01089 
01090 
01091 
01092 bool
01093 getConnectionInfo
01094 (
01095 IN conn_id_t conn_id,
01096 OUT connection_info_t& info
01097 )
01098 {
01099         ASSERT(conn_id, "null");
01100         info.clear();
01101 
01102         conn_rec_t * rec = getConnectionRecord(conn_id);
01103         if (!rec) {
01104                 return false;
01105         }
01106 
01107         info.type = rec->type;
01108         info.address = rec->address;
01109 
01110         // all done
01111         return true;
01112 }
01113 
01114 
01115 
01116 void
01117 closeConnection
01118 (
01119 IN conn_id_t conn_id
01120 )
01121 {
01122         ASSERT(conn_id, "null");
01123 
01124         DPRINTF("Closing connection 0x%lx...", (long) conn_id);
01125 
01126         conn_map_t::iterator i = s_connection_map.find(conn_id);
01127         if (s_connection_map.end() == i) {
01128                 DPRINTF("Error in closeConnection() -- connection id 0x%lx not found",
01129                     (long) conn_id);
01130                 return;
01131         }
01132 
01133         // remove the connection
01134         s_connection_map.erase(i);
01135 }
01136 
01137 
01138 
01139 void
01140 dumpMessage
01141 (
01142 IO std::ostream& stream,
01143 IN const char * title,
01144 IN const envelope_t& envelope,
01145 IN const MessageBuffer * buffer
01146 )
01147 {
01148         ASSERT(stream.good(), "bad?");
01149         ASSERT(title, "null");
01150 
01151         DPRINTF("Message dump: %s", title);
01152         if (envelope.is_empty()) {
01153                 DPRINTF("  Envelope is empty!");
01154         } else {
01155                 DPRINTF("  From: 0x%04lx", (long) envelope.fromConnId);
01156                 DPRINTF("  Connection Type: %d (%s)", envelope.type,
01157                     getTypeName(envelope.type));
01158                 envelope.address.dump(title);
01159         }
01160 
01161         if (!buffer) {
01162                 DPRINTF("  Null message buffer!");
01163         } else {
01164                 DPRINTF("  Message: '%s'", buffer->getData());
01165         }
01166 }
01167 
01168 
01169 
01170 void
01171 dumpStats
01172 (
01173 void
01174 )
01175 {
01176         DPRINTF("Networking stats:");
01177         DPRINTF("  Total messages received: %6u", s_messagesReceived);
01178         DPRINTF("  Total messages sent:     %6u", s_messagesSent);
01179 //#ifdef WIN32
01180 //      // sigh... Windows has a non-standard format for long long
01181 //      DPRINTF("  Total bytes read:    %10I64d  (%4I64d MB)", s_bytesRead,
01182 //          (s_bytesRead + 512 * 1024) / (1024 * 1024));
01183 //      DPRINTF("  Total bytes written: %10I64d  (%4I64d MB)", s_bytesWritten,
01184 //          (s_bytesWritten + 512 * 1024) / (1024 * 1024));
01185 //#else // WIN32
01186         std::cerr << "  Total bytes received: " << s_bytesRead;
01187         std::cerr << "  (" << (s_bytesRead + 512 * 1024) / (1024 * 1024) << " MB)\n";
01188         std::cerr << "  Total bytes sent: " << s_bytesWritten;
01189         std::cerr << "  (" << (s_bytesWritten + 512 * 1024) / (1024 * 1024) << " MB)\n";
01190 //#endif        // WIN32
01191 
01192         if (s_messagesReceived > 0) {
01193                 long avg = (long) (s_bytesRead / s_messagesReceived);
01194                 DPRINTF("  Average size of message read: %5ld bytes", avg);
01195         }
01196         if (s_messagesSent > 0) {
01197                 long avg = (long) (s_bytesWritten / s_messagesSent);
01198                 DPRINTF("  Average size of message sent: %5ld bytes", avg);
01199         }
01200 }
01201 
01202 
01203 void
01204 connection_info_t::dump
01205 (
01206 IN const char * title
01207 )
01208 const
01209 throw()
01210 {
01211         DPRINTF("Connection info: %s", title);
01212         DPRINTF("  type: %d", type);
01213         address.dump(title);
01214 }
01215 
01216 
01217 };      // netlib namespace
01218