00001 #ifndef COMMONS_ST_ST_H
00002 #define COMMONS_ST_ST_H
00003 
00004 #include <algorithm>
00005 #include <boost/foreach.hpp>
00006 #include <boost/function.hpp>
00007 #include <boost/shared_ptr.hpp>
00008 #include <commons/array.h>
00009 #include <commons/delegates.h>
00010 #include <commons/nullptr.h>
00011 #include <commons/sockets.h>
00012 #include <commons/utility.h>
00013 #include <exception>
00014 #include <map>
00015 #include <queue>
00016 #include <set>
00017 #include <sstream>
00018 #include <st.h>
00019 #include <stx.h>
00020 
00021 #define foreach BOOST_FOREACH
00022 #define shared_ptr boost::shared_ptr
00023 
00024 namespace commons
00025 {
00026   using namespace boost;
00027   using namespace std;
00028 
00029   enum { default_stack_size = 65536 };
00030 
00031   struct stfd_closer {
00032     static void apply(st_netfd_t fd) { check0x(st_netfd_close(fd)); }
00033   };
00034 
00035   typedef closing<st_netfd_t, stfd_closer> st_closing;
00036 
00040   class st_lock
00041   {
00042     NONCOPYABLE(st_lock)
00043     public:
00044       st_lock(st_mutex_t mx) : mx_(mx) { check0x(st_mutex_lock(mx)); }
00045       ~st_lock() { check0x(st_mutex_unlock(mx_)); }
00046     private:
00047       st_mutex_t mx_;
00048   };
00049 
00055   st_thread_t
00056   st_spawn(const fn& f)
00057   {
00058     return st_thread_create(&run_function0_null,
00059                             new fn(f),
00060                             true,
00061                             default_stack_size);
00062   }
00063 
00064   void
00065   st_join(st_thread_t t)
00066   {
00067     check0x(st_thread_join(t, nullptr));
00068   }
00069 
00076   st_netfd_t
00077   st_tcp_connect(in_addr host, uint16_t port, st_utime_t timeout)
00078   {
00079     
00080     sockaddr_in sa = make_sockaddr(host, port);
00081 
00082     
00083     st_closing s(checkpass(st_netfd_open_socket(tcp_socket(true))));
00084 
00085     
00086     check0x(st_connect(s.get(), reinterpret_cast<sockaddr*>(&sa), sizeof sa, timeout));
00087     return s.release();
00088   }
00089 
00098   st_netfd_t
00099   st_tcp_connect(const char *host, uint16_t port, st_utime_t timeout)
00100   {
00101     in_addr ipaddr;
00102 
00103     
00104     
00105     if (inet_aton(host, &ipaddr) == 0) {
00106       
00107       check0x(stx_dns_getaddr(host, &ipaddr, timeout));
00108     }
00109 
00110     return st_tcp_connect(ipaddr, port, timeout);
00111   }
00112 
00118   st_netfd_t
00119   st_tcp_listen(uint16_t port)
00120   {
00121     int sfd = tcp_listen(port);
00122     try {
00123       
00124       return checkpass(st_netfd_open_socket(sfd));
00125     } catch (...) {
00126       close(sfd);
00127       throw;
00128     }
00129   }
00130 
00135   class st_cond
00136   {
00137     NONCOPYABLE(st_cond)
00138     public:
00139       st_cond() : c(checkerr(st_cond_new())) {}
00140       ~st_cond() { check0x(st_cond_destroy(c)); }
00141       void wait() { check0x(st_cond_wait(c)); }
00142       void wait(st_utime_t t) { check0x(st_cond_timedwait(c, t)); }
00143       void signal() { st_cond_signal(c); }
00144       void bcast() { st_cond_broadcast(c); }
00145     private:
00146       st_cond_t c;
00147   };
00148 
00152   class st_bool
00153   {
00154     public:
00155       st_bool(bool init = false) : c(), b(init) {}
00156       void set() { b = true; c.bcast(); }
00157       void reset() { b = false; c.bcast(); }
00158       void waitset() { if (!b) c.wait(); }
00159       void waitreset() { if (b) c.wait(); }
00160       operator bool() { return b; }
00161     private:
00162       st_cond c;
00163       bool b;
00164   };
00165 
00166   void toggle(st_bool& b) { if (b) b.reset(); else b.set(); }
00167 
00172   class st_mutex
00173   {
00174     NONCOPYABLE(st_mutex)
00175     public:
00176       st_mutex() : m(checkerr(st_mutex_new())) {}
00177       ~st_mutex() { check0x(st_mutex_destroy(m)); }
00178       void lock() { check0x(st_mutex_lock(m)); }
00179       bool trylock() {
00180         int res = st_mutex_trylock(m);
00181         if (res == 0) return true;
00182         else if (errno == EBUSY) return false;
00183         else check0x(res);
00184       }
00185       void unlock() { check0x(st_mutex_unlock(m)); }
00186     private:
00187       st_mutex_t m;
00188   };
00189 
00194   template <typename T>
00195   class st_channel
00196   {
00197     public:
00198       void push(const T &x) {
00199         q_.push(x);
00200         empty_.signal();
00201       }
00202       T take() {
00203         while (q_.empty()) {
00204           empty_.wait();
00205         }
00206         T x = front();
00207         q_.pop();
00208         return x;
00209       }
00210       const T& front() const { return q_.front(); }
00211       bool empty() const { return q_.empty(); }
00212       void pop() { q_.pop(); }
00213       void clear() { while (!q_.empty()) q_.pop(); }
00214       const std::queue<T> &queue() const { return q_; }
00215     private:
00216       std::queue<T> q_;
00217       st_cond empty_;
00218   };
00219 
00223   template <typename T>
00224   class st_multichannel
00225   {
00226     public:
00227       void push(const T &x) {
00228         foreach (shared_ptr<st_channel<T> > q, qs) {
00229           q->push(x);
00230         }
00231       }
00232       st_channel<T> &subscribe() {
00233         shared_ptr<st_channel<T> > q(new st_channel<T>);
00234         qs.push_back(q);
00235         return *q;
00236       }
00237     private:
00238       vector<shared_ptr<st_channel<T> > > qs;
00239   };
00240 
00246   class st_intr_hub
00247   {
00248     public:
00249       virtual void insert(st_thread_t t) = 0;
00250       virtual void erase(st_thread_t t) = 0;
00251       virtual ~st_intr_hub() {};
00252   };
00253 
00258   class st_intr_cond : public st_intr_hub
00259   {
00260     public:
00261       virtual ~st_intr_cond() {}
00262       void insert(st_thread_t t) { threads.insert(t); }
00263       void erase(st_thread_t t) { threads.erase(t); }
00264       void signal() {
00265         foreach (st_thread_t t, threads) {
00266           st_thread_interrupt(t);
00267         }
00268         threads.clear();
00269       }
00270     private:
00271       std::set<st_thread_t> threads;
00272   };
00273 
00279   class st_intr_bool : public st_intr_hub
00280   {
00281     public:
00282       void insert(st_thread_t t) {
00283         if (b) st_thread_interrupt(t);
00284         else threads.insert(t);
00285       }
00286       void erase(st_thread_t t) { threads.erase(t); }
00287       void set() {
00288         b = true;
00289         foreach (st_thread_t t, threads) {
00290           st_thread_interrupt(t);
00291         }
00292         threads.clear();
00293       }
00294       void reset() {
00295         
00296         
00297         assert(!b || threads.empty());
00298         b = false;
00299       }
00300       operator bool() const { return b; }
00301     private:
00302       std::set<st_thread_t> threads;
00303       bool b;
00304   };
00305 
00309   class st_intr
00310   {
00311     public:
00312       st_intr(st_intr_hub &hub) : hub_(hub) { hub.insert(st_thread_self()); }
00313       ~st_intr() { hub_.erase(st_thread_self()); }
00314     private:
00315       st_intr_hub &hub_;
00316   };
00317 
00318   class st_group_join_exception : public std::exception
00319   {
00320     public:
00321       st_group_join_exception(const map<st_thread_t, std::exception> &th2ex) :
00322         th2ex_(th2ex) {}
00323       virtual ~st_group_join_exception() throw() {}
00324       virtual const char *what() const throw() {
00325         if (!th2ex_.empty() && s == "") {
00326           bool first = true;
00327           stringstream ss;
00328           typedef pair<st_thread_t, std::exception> p;
00329           foreach (p p, th2ex_) {
00330             ss << (first ? "" : ", ") << p.first << " -> " << p.second.what();
00331             first = false;
00332           }
00333           s = ss.str();
00334         }
00335         return s.c_str();
00336       }
00337     private:
00338       map<st_thread_t, std::exception> th2ex_;
00339       mutable string s;
00340   };
00341 
00345   class st_joining
00346   {
00347     NONCOPYABLE(st_joining)
00348     public:
00349       st_joining(st_thread_t t) : t_(t) {}
00350       ~st_joining() { st_join(t_); }
00351     private:
00352       st_thread_t t_;
00353   };
00354 
00359   class st_thread_group
00360   {
00361     public:
00362       ~st_thread_group() {
00363         map<st_thread_t, std::exception> th2ex;
00364         foreach (st_thread_t t, ts) {
00365           try { st_join(t); }
00366           catch (std::exception &ex) { th2ex[t] = ex; }
00367         }
00368         if (!th2ex.empty()) throw st_group_join_exception(th2ex);
00369       }
00370       void insert(st_thread_t t) { ts.insert(t); }
00371     private:
00372       std::set<st_thread_t> ts;
00373   };
00374 
00375   class eof_exception : public std::exception {
00376     const char *what() const throw() { return "EOF"; }
00377   };
00378 
00382   class st_reader
00383   {
00384     NONCOPYABLE(st_reader)
00385     public:
00386       st_reader(st_netfd_t fd, char *buf, size_t bufsize) :
00387         fd_(fd),
00388         buf_(buf, bufsize),
00389         start_(buf_.get()),
00390         end_(buf_.get())
00391       {}
00392 
00396       size_t unread() { return end_ - start_; }
00397 
00401       size_t rem() { return buf_.end() - end_; }
00402 
00406       sized_array<char> &buf() { return buf_; }
00407 
00411       void reset_range(char *start, char *end) {
00412         start_ = start;
00413         end_ = end;
00414       }
00415 
00419       void skip(size_t req, st_utime_t to = ST_UTIME_NO_TIMEOUT) {
00420         while (true) {
00421           if (unread() >= req) {
00422             
00423             start_ += req;
00424             break;
00425           }
00426 
00427           
00428           
00429           req -= unread();
00430           
00431           start_ = end_ = buf_.get();
00432 
00433           ssize_t res = checknnegerr(st_read(fd_, end_, rem(), to));
00434           end_ += res;
00435 
00436           
00437           if (res == 0 && unread() < req) throw eof_exception();
00438         }
00439       }
00440 
00445       managed_array<char> read(size_t req, st_utime_t to = ST_UTIME_NO_TIMEOUT) {
00446         
00447         if (unread() >= req) {
00448           managed_array<char> p(start_, false);
00449           start_ += req;
00450           return p;
00451         }
00452 
00453         
00454         if (req > buf_.size()) {
00455           managed_array<char> p(new char[req], true);
00456           memcpy(p.get(), start_, unread());
00457           checkeqnneg(st_read_fully(fd_, p + unread(), req - unread(), to), static_cast<ssize_t>(req - unread()));
00458           start_ = end_ = buf_.get();
00459           return p;
00460         }
00461 
00462         
00463         if (req > static_cast<size_t>(buf_.end() - end_)) {
00464           memmove(buf_.get(), start_, unread());
00465           size_t diff = start_ - buf_.get();
00466           start_ -= diff;
00467           end_ -= diff;
00468         }
00469 
00470         
00471         while (unread() < req) {
00472           ssize_t res = checknnegerr(st_read(fd_, end_, rem(), to));
00473           if (res == 0) break;
00474           else end_ += res;
00475         }
00476 
00477         
00478         if (unread() < req)
00479           throw eof_exception();
00480 
00481         managed_array<char> p(start_, false);
00482         start_ += req;
00483         return p;
00484       }
00485 
00486       template<typename T>
00487       T read(st_utime_t to = ST_UTIME_NO_TIMEOUT)
00488       {
00489         size_t req = sizeof(T);
00490 
00491         
00492         if (unread() >= req) {
00493           T x = *reinterpret_cast<const T*>(start_);
00494           start_ += req;
00495           return x;
00496         }
00497 
00498         assert(req <= buf_.size());
00499 
00500         
00501         if (req > static_cast<size_t>(buf_.end() - end_)) {
00502           memmove(buf_.get(), start_, unread());
00503           size_t diff = start_ - buf_.get();
00504           start_ -= diff;
00505           end_ -= diff;
00506         }
00507 
00508         
00509         while (unread() < req) {
00510           ssize_t res = checknnegerr(st_read(fd_, end_, rem(), to));
00511           if (res == 0) break;
00512           else end_ += res;
00513         }
00514 
00515         
00516         if (unread() < req)
00517           throw eof_exception();
00518 
00519         T x = *reinterpret_cast<const T*>(start_);
00520         start_ += req;
00521         return x;
00522       }
00523 
00524     private:
00525       st_netfd_t fd_;
00526 
00530       sized_array<char> buf_;
00531 
00535       char *start_;
00536 
00540       char *end_;
00541   };
00542 
00543 }
00544 
00545 #endif