Project

General

Profile

Feature #4195

Updated by mame (Yusuke Endoh) over 9 years ago

=begin 
  
  なかだです。 
 
  Socket#recvmsgは scm_rights: true を指定するだけでメインのデータだけで 
  なく簡単にIOを受け取ることができますが、一方でSocket#sendmsg側には対応 
  する指定ができません。以下のようなオプションを追加するのはどうでしょう 
  か。 
 
    s.sendmsg("foo", scm_rights: STDIN) 
    s.sendmsg("foo", scm_rights: [STDIN, STDOUT]) 
 
 
  diff --git i/ext/socket/ancdata.c w/ext/socket/ancdata.c 
  index abaf19d..c329e0a 100644 
  --- i/ext/socket/ancdata.c 
  +++ w/ext/socket/ancdata.c 
  @@ -2,6 +2,8 @@ 
  
   #include <time.h> 
  
  +static ID sym_scm_rights; 
  + 
   #if defined(HAVE_ST_MSG_CONTROL) 
   static VALUE rb_cAncillaryData; 
  
  @@ -1126,17 +1128,63 @@ rb_sendmsg(int fd, const struct msghdr *msg, int flags) 
       return rb_thread_blocking_region(nogvl_sendmsg_func, &args, RUBY_UBF_IO, 0); 
   } 
  
  +#if defined(HAVE_ST_MSG_CONTROL) 
  +static size_t 
  +io_to_fd(VALUE io) 
  +{ 
  +      VALUE fnum = rb_check_to_integer(io, "to_int"); 
  +      if (NIL_P(fnum)) 
  + 	 fnum = rb_convert_type(io, T_FIXNUM, "Fixnum", "fileno"); 
  +      return NUM2UINT(fnum); 
  +} 
  + 
  +static char * 
  +prepare_msghdr(VALUE controls_str, int level, int type, long clen) 
  +{ 
  +      struct cmsghdr cmh; 
  +      char *cmsg; 
  +      size_t cspace; 
  +      long oldlen = RSTRING_LEN(controls_str); 
  +      cspace = CMSG_SPACE(clen); 
  +      rb_str_resize(controls_str, oldlen + cspace); 
  +      cmsg = RSTRING_PTR(controls_str)+oldlen; 
  +      memset((char *)cmsg, 0, cspace); 
  +      memset((char *)&cmh, 0, sizeof(cmh)); 
  +      cmh.cmsg_level = level; 
  +      cmh.cmsg_type = type; 
  +      cmh.cmsg_len = (socklen_t)CMSG_LEN(clen); 
  +      MEMCPY(cmsg, &cmh, char, sizeof(cmh)); 
  +      return cmsg+((char*)CMSG_DATA(&cmh)-(char*)&cmh); 
  +} 
  + 
  +# if defined(__NetBSD__) 
  +#     define TRIM_PADDING 1 
  +# endif 
  +# if TRIM_PADDING 
  +#     define prepare_msghdr(controls_str, level, type, clen) \ 
  +      (last_pad = CMSG_SPACE(clen) - CMSG_LEN(clen), \ 
  +       prepare_msghdr((controls_str), \ 
  + 		     last_level = (level), last_type = (type), \ 
  + 		     (clen))) 
  +# endif 
  +#endif 
  + 
   static VALUE 
   bsock_sendmsg_internal(int argc, VALUE *argv, VALUE sock, int nonblock) 
   { 
       rb_io_t *fptr; 
  -      VALUE data, vflags, dest_sockaddr; 
  +      VALUE data, vflags, dest_sockaddr, vopts = Qnil; 
       VALUE *controls_ptr; 
       int controls_num; 
       struct msghdr mh; 
       struct iovec iov; 
   #if defined(HAVE_ST_MSG_CONTROL) 
       volatile VALUE controls_str = 0; 
  +# if TRIM_PADDING 
  +      size_t last_pad = 0; 
  +      int last_level = 0; 
  +      int last_type = 0; 
  +# endif 
   #endif 
       int flags; 
       ssize_t ss; 
  @@ -1152,6 +1200,8 @@ bsock_sendmsg_internal(int argc, VALUE *argv, VALUE sock, int nonblock) 
  
       if (argc == 0) 
           rb_raise(rb_eArgError, "mesg argument required"); 
  +      if (1 < argc && RB_TYPE_P(argv[argc-1], T_HASH)) 
  +          vopts = argv[--argc]; 
       data = argv[0]; 
       if (1 < argc) vflags = argv[1]; 
       if (2 < argc) dest_sockaddr = argv[2]; 
  @@ -1162,19 +1212,13 @@ bsock_sendmsg_internal(int argc, VALUE *argv, VALUE sock, int nonblock) 
       if (controls_num) { 
   #if defined(HAVE_ST_MSG_CONTROL) 
  	 int i; 
  - 	 size_t last_pad = 0; 
  -          int last_level = 0; 
  -          int last_type = 0; 
           controls_str = rb_str_tmp_new(0); 
           for (i = 0; i < controls_num; i++) { 
               VALUE elt = controls_ptr[i], v; 
               VALUE vlevel, vtype; 
               int level, type; 
               VALUE cdata; 
  -              long oldlen; 
  -              struct cmsghdr cmh; 
               char *cmsg; 
  -              size_t cspace; 
               v = rb_check_convert_type(elt, T_ARRAY, "Array", "to_ary"); 
               if (!NIL_P(v)) { 
                   elt = v; 
  @@ -1192,21 +1236,46 @@ bsock_sendmsg_internal(int argc, VALUE *argv, VALUE sock, int nonblock) 
               level = rsock_level_arg(family, vlevel); 
               type = rsock_cmsg_type_arg(family, level, vtype); 
               StringValue(cdata); 
  -              oldlen = RSTRING_LEN(controls_str); 
  -              cspace = CMSG_SPACE(RSTRING_LEN(cdata)); 
  -              rb_str_resize(controls_str, oldlen + cspace); 
  -              cmsg = RSTRING_PTR(controls_str)+oldlen; 
  -              memset((char *)cmsg, 0, cspace); 
  -              memset((char *)&cmh, 0, sizeof(cmh)); 
  -              cmh.cmsg_level = level; 
  -              cmh.cmsg_type = type; 
  -              cmh.cmsg_len = (socklen_t)CMSG_LEN(RSTRING_LEN(cdata)); 
  -              MEMCPY(cmsg, &cmh, char, sizeof(cmh)); 
  -              MEMCPY(cmsg+((char*)CMSG_DATA(&cmh)-(char*)&cmh), RSTRING_PTR(cdata), char, RSTRING_LEN(cdata)); 
  -              last_level = cmh.cmsg_level; 
  -              last_type = cmh.cmsg_type; 
  - 	     last_pad = cspace - cmh.cmsg_len; 
  +              cmsg = prepare_msghdr(controls_str, level, type, RSTRING_LEN(cdata)); 
  +              MEMCPY(cmsg, RSTRING_PTR(cdata), char, RSTRING_LEN(cdata)); 
           } 
  +#else 
  +        no_msg_control: 
  + 	 rb_raise(rb_eNotImpError, "control message for sendmsg is unimplemented"); 
  +#endif 
  +      } 
  +      if (!NIL_P(vopts)) { 
  + 	 VALUE rights = rb_hash_aref(vopts, sym_scm_rights); 
  + 	 if (!NIL_P(rights)) { 
  +#if defined(HAVE_ST_MSG_CONTROL) 
  + 	     VALUE tmp = rb_check_array_type(rights); 
  +              long count = NIL_P(tmp) ? 1 : RARRAY_LEN(tmp); 
  +              char *cmsg; 
  + 	     int fd; 
  + 	     if (!controls_str) controls_str = rb_str_tmp_new(0); 
  +              cmsg = prepare_msghdr(controls_str, SOL_SOCKET, SCM_RIGHTS, 
  +                                    count * sizeof(int)); 
  +              if (NIL_P(tmp)) { 
  +                  fd = io_to_fd(rights); 
  +                  MEMCPY(cmsg, &fd, int, 1); 
  +              } 
  +              else { 
  +                  long i; 
  +                  rights = tmp; 
  +                  for (i = 0; i < count && i < RARRAY_LEN(rights); ++i) { 
  +                      fd = io_to_fd(RARRAY_PTR(rights)[i]); 
  +                      MEMCPY(cmsg, &fd, int, 1); 
  +                      cmsg += sizeof(int); 
  +                  } 
  +              } 
  +#else 
  + 	     goto no_msg_control; 
  +#endif 
  + 	 } 
  +      } 
  +#if defined(HAVE_ST_MSG_CONTROL) 
  +      { 
  +# if TRIM_PADDING 
  	 if (last_pad) { 
               /* 
                * This code removes the last padding from msg_controllen. 
  @@ -1228,15 +1297,12 @@ bsock_sendmsg_internal(int argc, VALUE *argv, VALUE sock, int nonblock) 
                * Basically, msg_controllen should contains the padding. 
                * So the padding is removed only if a problem really exists. 
                */ 
  -#if defined(__NetBSD__) 
               if (last_level == SOL_SOCKET && last_type == SCM_RIGHTS) 
                   rb_str_set_len(controls_str, RSTRING_LEN(controls_str)-last_pad); 
  -#endif 
  	 } 
  -#else 
  - 	 rb_raise(rb_eNotImpError, "control message for sendmsg is unimplemented"); 
  -#endif 
  +# endif 
       } 
  +#endif 
  
       flags = NIL_P(vflags) ? 0 : NUM2INT(vflags); 
   #ifdef MSG_DONTWAIT 
  @@ -1492,7 +1558,7 @@ bsock_recvmsg_internal(int argc, VALUE *argv, VALUE sock, int nonblock) 
       grow_buffer = NIL_P(vmaxdatlen) || NIL_P(vmaxctllen); 
  
       request_scm_rights = 0; 
  -      if (!NIL_P(vopts) && RTEST(rb_hash_aref(vopts, ID2SYM(rb_intern("scm_rights"))))) 
  +      if (!NIL_P(vopts) && RTEST(rb_hash_aref(vopts, sym_scm_rights))) 
           request_scm_rights = 1; 
  
       GetOpenFile(sock, fptr); 
  @@ -1795,5 +1861,7 @@ rsock_init_ancdata(void) 
       rb_define_method(rb_cAncillaryData, "ipv6_pktinfo", ancillary_ipv6_pktinfo, 0); 
       rb_define_method(rb_cAncillaryData, "ipv6_pktinfo_addr", ancillary_ipv6_pktinfo_addr, 0); 
       rb_define_method(rb_cAncillaryData, "ipv6_pktinfo_ifindex", ancillary_ipv6_pktinfo_ifindex, 0); 
  + 
  +      sym_scm_rights = ID2SYM(rb_intern("scm_rights")); 
   #endif 
   } 
  diff --git i/test/socket/test_unix.rb w/test/socket/test_unix.rb 
  index bde17cf..e9db22e 100644 
  --- i/test/socket/test_unix.rb 
  +++ w/test/socket/test_unix.rb 
  @@ -31,7 +31,7 @@ class TestSocket_UNIXSocket < Test::Unit::TestCase 
       end 
     end 
  
  -    def test_fd_passing_n 
  +    def fd_passing_test 
       io_ary = [] 
       return if !defined?(Socket::SCM_RIGHTS) 
       io_ary.concat IO.pipe 
  @@ -42,8 +42,7 @@ class TestSocket_UNIXSocket < Test::Unit::TestCase 
         send_io_ary << io 
         UNIXSocket.pair {|s1, s2| 
           begin 
  -            ret = s1.sendmsg("\0", 0, nil, [Socket::SOL_SOCKET, Socket::SCM_RIGHTS, 
  -                                            send_io_ary.map {|io2| io2.fileno }.pack("i!*")]) 
  +            ret = yield(s1, send_io_ary) 
           rescue NotImplementedError 
             return 
           end 
  @@ -66,48 +65,38 @@ class TestSocket_UNIXSocket < Test::Unit::TestCase 
       io_ary.each {|io| io.close if !io.closed? } 
     end 
  
  +    def test_fd_passing_n 
  +      fd_passing_test do |s, ios| 
  +        s.sendmsg("\0", 0, nil, 
  +                  [Socket::SOL_SOCKET, Socket::SCM_RIGHTS, ios.map(&:fileno).pack("i!*")]) 
  +      end 
  +    end 
  + 
     def test_fd_passing_n2 
  -      io_ary = [] 
  -      return if !defined?(Socket::SCM_RIGHTS) 
  -      return if !defined?(Socket::AncillaryData) 
  -      io_ary.concat IO.pipe 
  -      io_ary.concat IO.pipe 
  -      io_ary.concat IO.pipe 
  -      send_io_ary = [] 
  -      io_ary.each {|io| 
  -        send_io_ary << io 
  -        UNIXSocket.pair {|s1, s2| 
  -          begin 
  -            ancdata = Socket::AncillaryData.unix_rights(*send_io_ary) 
  -            ret = s1.sendmsg("\0", 0, nil, ancdata) 
  -          rescue NotImplementedError 
  -            return 
  -          end 
  -          assert_equal(1, ret) 
  -          ret = s2.recvmsg(:scm_rights=>true) 
  -          data, srcaddr, flags, *ctls = ret 
  -          recv_io_ary = [] 
  -          ctls.each {|ctl| 
  -            next if ctl.level != Socket::SOL_SOCKET || ctl.type != Socket::SCM_RIGHTS 
  -            recv_io_ary.concat ctl.unix_rights 
  -          } 
  -          assert_equal(send_io_ary.length, recv_io_ary.length) 
  -          send_io_ary.length.times {|i| 
  -            assert_not_equal(send_io_ary[i].fileno, recv_io_ary[i].fileno) 
  -            assert(File.identical?(send_io_ary[i], recv_io_ary[i])) 
  -          } 
  -        } 
  -      } 
  -    ensure 
  -      io_ary.each {|io| io.close if !io.closed? } 
  +      fd_passing_test do |s, ios| 
  +        ancdata = Socket::AncillaryData.unix_rights(*ios) 
  +        s.sendmsg("\0", 0, nil, ancdata) 
  +      end 
  +    end 
  + 
  +    def test_fd_passing_n3 
  +      fd_passing_test do |s, ios| 
  +        s.sendmsg("\0", 0, nil, scm_rights: ios.map(&:fileno)) 
  +      end 
  +    end 
  + 
  +    def test_fd_passing_n4 
  +      fd_passing_test do |s, ios| 
  +        s.sendmsg("\0", 0, nil, scm_rights: ios) 
  +      end 
     end 
  
  -    def test_sendmsg 
  +    def sendmsg_test 
       return if !defined?(Socket::SCM_RIGHTS) 
       IO.pipe {|r1, w| 
         UNIXSocket.pair {|s1, s2| 
           begin 
  -            ret = s1.sendmsg("\0", 0, nil, [Socket::SOL_SOCKET, Socket::SCM_RIGHTS, [r1.fileno].pack("i!")]) 
  +            ret = yield(s1, r1) 
           rescue NotImplementedError 
             return 
           end 
  @@ -122,6 +111,24 @@ class TestSocket_UNIXSocket < Test::Unit::TestCase 
       } 
     end 
  
  +    def test_sendmsg_1 
  +      sendmsg_test do |s, r| 
  +        s.sendmsg("\0", 0, nil, [Socket::SOL_SOCKET, Socket::SCM_RIGHTS, [r.fileno].pack("i!")]) 
  +      end 
  +    end 
  + 
  +    def test_sendmsg_2 
  +      sendmsg_test do |s, r| 
  +        s.sendmsg("\0", 0, nil, scm_rights: r.fileno) 
  +      end 
  +    end 
  + 
  +    def test_sendmsg_3 
  +      sendmsg_test do |s, r| 
  +        s.sendmsg("\0", 0, nil, scm_rights: r) 
  +      end 
  +    end 
  + 
     def test_sendmsg_ancillarydata_int 
       return if !defined?(Socket::SCM_RIGHTS) 
       return if !defined?(Socket::AncillaryData) 
 
 
  --  
  --- 僕の前にBugはない。 
  --- 僕の後ろにBugはできる。 
      中田 伸悦 
 
 =end 
 

Back