diff --git a/ext/openssl/lib/openssl/buffering.rb b/ext/openssl/lib/openssl/buffering.rb index eb39dab..2256a14 100644 --- a/ext/openssl/lib/openssl/buffering.rb +++ b/ext/openssl/lib/openssl/buffering.rb @@ -182,6 +182,27 @@ module OpenSSL::Buffering ret end + def try_read_nonblock(maxlen, buf=nil) + if maxlen == 0 + if buf + buf.clear + return buf + else + return "" + end + end + if @rbuffer.empty? + return try_sysread_nonblock(maxlen, buf) + end + ret = consume_rbuff(maxlen) + if buf + buf.replace(ret) + ret = buf + end + raise EOFError if ret.empty? + ret + end + ## # Reads the next "line+ from the stream. Lines are separated by +eol+. If # +limit+ is provided the result will not be longer than the given number of @@ -374,6 +395,11 @@ module OpenSSL::Buffering syswrite_nonblock(s) end + def try_write_nonblock(s) + flush + try_syswrite_nonblock(s) + end + ## # Writes +s+ to the stream. +s+ will be converted to a String using # String#to_s. diff --git a/ext/openssl/ossl_ssl.c b/ext/openssl/ossl_ssl.c index ed820cd..6851d09 100644 --- a/ext/openssl/ossl_ssl.c +++ b/ext/openssl/ossl_ssl.c @@ -1232,7 +1232,7 @@ ossl_ssl_accept_nonblock(VALUE self) } static VALUE -ossl_ssl_read_internal(int argc, VALUE *argv, VALUE self, int nonblock) +ossl_ssl_read_internal(int argc, VALUE *argv, VALUE self, int nonblock, int no_exception) { SSL *ssl; int ilen, nread = 0; @@ -1260,17 +1260,23 @@ ossl_ssl_read_internal(int argc, VALUE *argv, VALUE self, int nonblock) case SSL_ERROR_NONE: goto end; case SSL_ERROR_ZERO_RETURN: + if (no_exception) { return Qnil; } rb_eof_error(); case SSL_ERROR_WANT_WRITE: + if (no_exception) { return ID2SYM(rb_intern("write_would_block")); } write_would_block(nonblock); rb_io_wait_writable(FPTR_TO_FD(fptr)); continue; case SSL_ERROR_WANT_READ: + if (no_exception) { return ID2SYM(rb_intern("read_would_block")); } read_would_block(nonblock); rb_io_wait_readable(FPTR_TO_FD(fptr)); continue; case SSL_ERROR_SYSCALL: - if(ERR_peek_error() == 0 && nread == 0) rb_eof_error(); + if(ERR_peek_error() == 0 && nread == 0) { + if (no_exception) { return Qnil; } + rb_eof_error(); + } rb_sys_fail(0); default: ossl_raise(eSSLError, "SSL_read:"); @@ -1302,7 +1308,7 @@ ossl_ssl_read_internal(int argc, VALUE *argv, VALUE self, int nonblock) static VALUE ossl_ssl_read(int argc, VALUE *argv, VALUE self) { - return ossl_ssl_read_internal(argc, argv, self, 0); + return ossl_ssl_read_internal(argc, argv, self, 0, 0); } /* @@ -1319,11 +1325,28 @@ ossl_ssl_read(int argc, VALUE *argv, VALUE self) static VALUE ossl_ssl_read_nonblock(int argc, VALUE *argv, VALUE self) { - return ossl_ssl_read_internal(argc, argv, self, 1); + return ossl_ssl_read_internal(argc, argv, self, 1, 0); +} + +/* + * call-seq: + * ssl.try_sysread_nonblock(length) => string, :write_would_block, + * :read_would_block, or nil (for EOF) + * ssl.try_sysread_nonblock(length, buffer) => buffer, :write_would_block, + * :read_would_block, or nil (for EOF) + * + * Exactly the same as +sysread_nonblock+, except that instead of raising + * exceptions for EOF or when the read would block, it returns nil, + * :read_would_block or :write_would_block. + */ +static VALUE +ossl_ssl_try_read_nonblock(int argc, VALUE *argv, VALUE self) +{ + return ossl_ssl_read_internal(argc, argv, self, 1, 1); } static VALUE -ossl_ssl_write_internal(VALUE self, VALUE str, int nonblock) +ossl_ssl_write_internal(VALUE self, VALUE str, int nonblock, int no_exception) { SSL *ssl; int nwrite = 0; @@ -1340,10 +1363,12 @@ ossl_ssl_write_internal(VALUE self, VALUE str, int nonblock) case SSL_ERROR_NONE: goto end; case SSL_ERROR_WANT_WRITE: + if (no_exception) { return ID2SYM(rb_intern("write_would_block")); } write_would_block(nonblock); rb_io_wait_writable(FPTR_TO_FD(fptr)); continue; case SSL_ERROR_WANT_READ: + if (no_exception) { return ID2SYM(rb_intern("read_would_block")); } read_would_block(nonblock); rb_io_wait_readable(FPTR_TO_FD(fptr)); continue; @@ -1373,7 +1398,7 @@ ossl_ssl_write_internal(VALUE self, VALUE str, int nonblock) static VALUE ossl_ssl_write(VALUE self, VALUE str) { - return ossl_ssl_write_internal(self, str, 0); + return ossl_ssl_write_internal(self, str, 0, 0); } /* @@ -1386,7 +1411,22 @@ ossl_ssl_write(VALUE self, VALUE str) static VALUE ossl_ssl_write_nonblock(VALUE self, VALUE str) { - return ossl_ssl_write_internal(self, str, 1); + return ossl_ssl_write_internal(self, str, 1, 0); +} + +/* + * call-seq: + * ssl.syswrite_nonblock(string) => Integer, :read_would_block or + * :write_would_block + * + * Exactly the same as +syswrite_nonblock+, except that instead of + * raising an exception if the write would block, returns + * :read_would_block or :write_would_block. + */ +static VALUE +ossl_ssl_try_write_nonblock(VALUE self, VALUE str) +{ + return ossl_ssl_write_internal(self, str, 1, 1); } /* @@ -1946,8 +1986,10 @@ Init_ossl_ssl() rb_define_method(cSSLSocket, "accept_nonblock", ossl_ssl_accept_nonblock, 0); rb_define_method(cSSLSocket, "sysread", ossl_ssl_read, -1); rb_define_private_method(cSSLSocket, "sysread_nonblock", ossl_ssl_read_nonblock, -1); + rb_define_private_method(cSSLSocket, "try_sysread_nonblock", ossl_ssl_try_read_nonblock, -1); rb_define_method(cSSLSocket, "syswrite", ossl_ssl_write, 1); rb_define_private_method(cSSLSocket, "syswrite_nonblock", ossl_ssl_write_nonblock, 1); + rb_define_private_method(cSSLSocket, "try_syswrite_nonblock", ossl_ssl_try_write_nonblock, 1); rb_define_method(cSSLSocket, "sysclose", ossl_ssl_close, 0); rb_define_method(cSSLSocket, "cert", ossl_ssl_get_cert, 0); rb_define_method(cSSLSocket, "peer_cert", ossl_ssl_get_peer_cert, 0); diff --git a/ext/stringio/stringio.c b/ext/stringio/stringio.c index 68baf35..6c94423 100644 --- a/ext/stringio/stringio.c +++ b/ext/stringio/stringio.c @@ -1286,6 +1286,8 @@ strio_read(int argc, VALUE *argv, VALUE self) static VALUE strio_sysread(int argc, VALUE *argv, VALUE self) { + if (argc == 0) { rb_raise(rb_eArgError, "wrong number of arguments (0 for 1)", argc); } + VALUE val = strio_read(argc, argv, self); if (NIL_P(val)) { rb_eof_error(); @@ -1293,6 +1295,24 @@ strio_sysread(int argc, VALUE *argv, VALUE self) return val; } +/* + * call-seq: + * strio.sysread(integer[, outbuf]) -> string or nil + * + * Exactly the same as +sysread+, except that instead of raising an + * EOFError at EOF, returns nil. This matches the +read_nonblock+ + * protocol from the IO class. + */ +static VALUE +strio_try_sysread(int argc, VALUE *argv, VALUE self) +{ + if (argc == 0) { rb_raise(rb_eArgError, "wrong number of arguments (0 for 1)", argc); } + + VALUE val = strio_read(argc, argv, self); + if (NIL_P(val)) { return Qnil; } + return val; +} + #define strio_syswrite strio_write /* @@ -1467,6 +1487,7 @@ Init_stringio() rb_define_method(StringIO, "sysread", strio_sysread, -1); rb_define_method(StringIO, "readpartial", strio_sysread, -1); rb_define_method(StringIO, "read_nonblock", strio_sysread, -1); + rb_define_method(StringIO, "try_read_nonblock", strio_try_sysread, -1); rb_define_method(StringIO, "write", strio_write, 1); rb_define_method(StringIO, "<<", strio_addstr, 1); diff --git a/io.c b/io.c index 6be4e88..c6e5a09 100644 --- a/io.c +++ b/io.c @@ -1883,7 +1883,7 @@ rb_io_set_nonblock(rb_io_t *fptr) } static VALUE -io_getpartial(int argc, VALUE *argv, VALUE io, int nonblock) +io_getpartial(int argc, VALUE *argv, VALUE io, int nonblock, int no_exception) { rb_io_t *fptr; VALUE length, str; @@ -1918,8 +1918,12 @@ io_getpartial(int argc, VALUE *argv, VALUE io, int nonblock) if (n < 0) { if (!nonblock && rb_io_wait_readable(fptr->fd)) goto again; - if (nonblock && (errno == EWOULDBLOCK || errno == EAGAIN)) - rb_mod_sys_fail(rb_mWaitReadable, "read would block"); + if (nonblock && (errno == EWOULDBLOCK || errno == EAGAIN)) { + if (no_exception) + return ID2SYM(rb_intern("read_would_block")); + else + rb_mod_sys_fail(rb_mWaitReadable, "read would block"); + } rb_sys_fail_path(fptr->pathv); } } @@ -1993,7 +1997,7 @@ io_readpartial(int argc, VALUE *argv, VALUE io) { VALUE ret; - ret = io_getpartial(argc, argv, io, 0); + ret = io_getpartial(argc, argv, io, 0, 0); if (NIL_P(ret)) rb_eof_error(); else @@ -2054,13 +2058,67 @@ io_read_nonblock(int argc, VALUE *argv, VALUE io) { VALUE ret; - ret = io_getpartial(argc, argv, io, 1); + ret = io_getpartial(argc, argv, io, 1, 0); if (NIL_P(ret)) rb_eof_error(); else return ret; } +/** + * call-seq: + * ios.try_read_nonblock(maxlen) -> string, nil, or :read_would_block + * ios.try_read_nonblock(maxlen, outbuf) -> outbuf, nil, or :read_would_block + * + * +try_read_nonblock+ is identical to +read_nonblock+, + * except that instead of raising exceptions, blocking + * calls will return :read_would_block, and EOF will + * return nil. + */ +static VALUE +io_try_read_nonblock(int argc, VALUE *argv, VALUE io) +{ + VALUE ret; + + ret = io_getpartial(argc, argv, io, 1, 1); + if (NIL_P(ret)) + return Qnil; + else + return ret; +} + + +static VALUE +io_write_nonblock(VALUE io, VALUE str, int no_exception) +{ + rb_io_t *fptr; + long n; + + rb_secure(4); + if (TYPE(str) != T_STRING) + str = rb_obj_as_string(str); + + io = GetWriteIO(io); + GetOpenFile(io, fptr); + rb_io_check_writable(fptr); + + if (io_fflush(fptr) < 0) + rb_sys_fail(0); + + rb_io_set_nonblock(fptr); + n = write(fptr->fd, RSTRING_PTR(str), RSTRING_LEN(str)); + + if (n == -1) { + if (errno == EWOULDBLOCK || errno == EAGAIN) { + if (no_exception) return ID2SYM(rb_intern("write_would_block")); + rb_mod_sys_fail(rb_mWaitWritable, "write would block"); + } + rb_sys_fail_path(fptr->pathv); + } + + return LONG2FIX(n); +} + /* * call-seq: * ios.write_nonblock(string) -> integer @@ -2117,30 +2175,22 @@ io_read_nonblock(int argc, VALUE *argv, VALUE io) static VALUE rb_io_write_nonblock(VALUE io, VALUE str) { - rb_io_t *fptr; - long n; - - rb_secure(4); - if (TYPE(str) != T_STRING) - str = rb_obj_as_string(str); - - io = GetWriteIO(io); - GetOpenFile(io, fptr); - rb_io_check_writable(fptr); - - if (io_fflush(fptr) < 0) - rb_sys_fail(0); - - rb_io_set_nonblock(fptr); - n = write(fptr->fd, RSTRING_PTR(str), RSTRING_LEN(str)); - - if (n == -1) { - if (errno == EWOULDBLOCK || errno == EAGAIN) - rb_mod_sys_fail(rb_mWaitWritable, "write would block"); - rb_sys_fail_path(fptr->pathv); - } + return io_write_nonblock(io, str, 0); +} - return LONG2FIX(n); +/* + * call-seq: + * ios.try_write_nonblock(string) -> integer or :write_would_block + * + * Works exactly like write_nonblock, with one exception: + * + * * if the write would block, try_write_nonblock returns + * :write_would_block rather than raising IO::WaitWritable + */ +static VALUE +rb_io_try_write_nonblock(VALUE io, VALUE str) +{ + return io_write_nonblock(io, str, 1); } /* @@ -9703,7 +9753,7 @@ argf_getpartial(int argc, VALUE *argv, VALUE argf, int nonblock) RUBY_METHOD_FUNC(0), Qnil, rb_eEOFError, (VALUE)0); } else { - tmp = io_getpartial(argc, argv, ARGF.current_file, nonblock); + tmp = io_getpartial(argc, argv, ARGF.current_file, nonblock, 0); } if (NIL_P(tmp)) { if (ARGF.next_p == -1) { @@ -10594,7 +10644,9 @@ Init_IO(void) rb_define_method(rb_cIO, "readlines", rb_io_readlines, -1); rb_define_method(rb_cIO, "read_nonblock", io_read_nonblock, -1); + rb_define_method(rb_cIO, "try_read_nonblock", io_try_read_nonblock, -1); rb_define_method(rb_cIO, "write_nonblock", rb_io_write_nonblock, 1); + rb_define_method(rb_cIO, "try_write_nonblock", rb_io_try_write_nonblock, 1); rb_define_method(rb_cIO, "readpartial", io_readpartial, -1); rb_define_method(rb_cIO, "read", io_read, -1); rb_define_method(rb_cIO, "write", io_write_m, 1); diff --git a/test/openssl/test_pair.rb b/test/openssl/test_pair.rb index 940fa0c..100a33b 100644 --- a/test/openssl/test_pair.rb +++ b/test/openssl/test_pair.rb @@ -157,19 +157,41 @@ class OpenSSL::TestPair < Test::Unit::TestCase ret = nil assert_nothing_raised("[ruby-core:20298]") { ret = s2.read_nonblock(10) } assert_equal("def\n", ret) + s1.close + assert_raise(EOFError) { s2.read_nonblock(10) } + } + end + + def test_try_read_nonblock + ssl_pair {|s1, s2| + assert_equal :read_would_block, s2.try_read_nonblock(10) + s1.write "abc\ndef\n" + IO.select([s2]) + assert_equal("ab", s2.try_read_nonblock(2)) + assert_equal("c\n", s2.gets) + ret = nil + assert_nothing_raised("[ruby-core:20298]") { ret = s2.try_read_nonblock(10) } + assert_equal("def\n", ret) + s1.close + assert_equal(nil, s2.try_read_nonblock(10)) } end + def write_nonblock(socket, meth, str) + ret = socket.send(meth, str) + ret.is_a?(Symbol) ? 0 : ret + end + def test_write_nonblock ssl_pair {|s1, s2| n = 0 begin - n += s1.write_nonblock("a" * 100000) - n += s1.write_nonblock("b" * 100000) - n += s1.write_nonblock("c" * 100000) - n += s1.write_nonblock("d" * 100000) - n += s1.write_nonblock("e" * 100000) - n += s1.write_nonblock("f" * 100000) + n += write_nonblock s1, :write_nonblock, "a" * 100000 + n += write_nonblock s1, :write_nonblock, "b" * 100000 + n += write_nonblock s1, :write_nonblock, "c" * 100000 + n += write_nonblock s1, :write_nonblock, "d" * 100000 + n += write_nonblock s1, :write_nonblock, "e" * 100000 + n += write_nonblock s1, :write_nonblock, "f" * 100000 rescue IO::WaitWritable end s1.close @@ -177,6 +199,20 @@ class OpenSSL::TestPair < Test::Unit::TestCase } end + def test_try_write_nonblock + ssl_pair {|s1, s2| + n = 0 + n += write_nonblock s1, :try_write_nonblock, "a" * 100000 + n += write_nonblock s1, :try_write_nonblock, "b" * 100000 + n += write_nonblock s1, :try_write_nonblock, "c" * 100000 + n += write_nonblock s1, :try_write_nonblock, "d" * 100000 + n += write_nonblock s1, :try_write_nonblock, "e" * 100000 + n += write_nonblock s1, :try_write_nonblock, "f" * 100000 + s1.close + assert_equal(n, s2.read.length) + } + end + def test_write_nonblock_with_buffered_data ssl_pair {|s1, s2| s1.write "foo" @@ -187,6 +223,16 @@ class OpenSSL::TestPair < Test::Unit::TestCase } end + def test_try_write_nonblock_with_buffered_data + ssl_pair {|s1, s2| + s1.write "foo" + s1.try_write_nonblock("bar") + s1.write "baz" + s1.close + assert_equal("foobarbaz", s2.read) + } + end + def test_connect_accept_nonblock host = "127.0.0.1" port = 0 diff --git a/test/ruby/test_io.rb b/test/ruby/test_io.rb index d4787c7..c6aea61 100644 --- a/test/ruby/test_io.rb +++ b/test/ruby/test_io.rb @@ -1002,6 +1002,16 @@ class TestIO < Test::Unit::TestCase end) end + def test_try_write_nonblock + skip "IO#write_nonblock is not supported on file/pipe." if /mswin|bccwin|mingw/ =~ RUBY_PLATFORM + pipe(proc do |w| + w.try_write_nonblock(1) + w.close + end, proc do |r| + assert_equal("1", r.read) + end) + end + def test_read_nonblock_error return if !have_nonblock? skip "IO#read_nonblock is not supported on file/pipe." if /mswin|bccwin|mingw/ =~ RUBY_PLATFORM @@ -1012,6 +1022,41 @@ class TestIO < Test::Unit::TestCase assert_kind_of(IO::WaitReadable, $!) end } + + with_pipe {|r, w| + begin + r.read_nonblock 4096, "" + rescue Errno::EWOULDBLOCK + assert_kind_of(IO::WaitReadable, $!) + end + } + end + + def test_try_read_nonblock + return if !have_nonblock? + skip "IO#try_read_nonblock is not supported on file/pipe." if /mswin|bccwin|mingw/ =~ RUBY_PLATFORM + with_pipe {|r, w| + assert_equal :read_would_block, r.try_read_nonblock(4096) + w.puts "HI!" + assert_equal "HI!\n", r.try_read_nonblock(4096) + w.close + assert_equal nil, r.try_read_nonblock(4096) + } + end + + def test_try_read_nonblock_with_buffer + return if !have_nonblock? + skip "IO#try_read_nonblock is not supported on file/pipe." if /mswin|bccwin|mingw/ =~ RUBY_PLATFORM + with_pipe {|r, w| + assert_equal :read_would_block, r.try_read_nonblock(4096, "") + w.puts "HI!" + buf = "buf" + value = r.try_read_nonblock(4096, buf) + assert_equal value, "HI!\n" + assert buf.equal?(value) + w.close + assert_equal nil, r.try_read_nonblock(4096, "") + } end def test_write_nonblock_error @@ -1028,6 +1073,20 @@ class TestIO < Test::Unit::TestCase } end + def test_try_write_nonblock + return if !have_nonblock? + skip "IO#write_nonblock is not supported on file/pipe." if /mswin|bccwin|mingw/ =~ RUBY_PLATFORM + with_pipe {|r, w| + loop { + ret = w.try_write_nonblock "a"*100000 + if ret.is_a?(Symbol) + assert ret == :write_would_block + break + end + } + } + end + def test_gets pipe(proc do |w| w.write "foobarbaz" diff --git a/test/socket/test_nonblock.rb b/test/socket/test_nonblock.rb index 59bd4f3..7a0536b 100644 --- a/test/socket/test_nonblock.rb +++ b/test/socket/test_nonblock.rb @@ -190,6 +190,20 @@ class TestSocketNonblock < Test::Unit::TestCase s.close if s end + def test_try_read_nonblock + c, s = tcp_pair + assert_equal :read_would_block, c.try_read_nonblock(100) + assert_equal :read_would_block, s.try_read_nonblock(100) + c.write("abc") + IO.select [s] + assert_equal("a", s.try_read_nonblock(1)) + assert_equal("bc", s.try_read_nonblock(100)) + assert_equal :read_would_block, s.try_read_nonblock(100) + ensure + c.close if c + s.close if s + end + =begin def test_write_nonblock c, s = tcp_pair diff --git a/test/stringio/test_stringio.rb b/test/stringio/test_stringio.rb index 0258218..67ffe04 100644 --- a/test/stringio/test_stringio.rb +++ b/test/stringio/test_stringio.rb @@ -424,7 +424,8 @@ class TestStringIO < Test::Unit::TestCase f = StringIO.new("\u3042\u3044") assert_raise(ArgumentError) { f.readpartial(-1) } assert_raise(ArgumentError) { f.readpartial(1, 2, 3) } - assert_equal("\u3042\u3044", f.readpartial) + assert_raise(ArgumentError) { f.readpartial } + assert_equal("\u3042\u3044".force_encoding(Encoding::ASCII_8BIT), f.readpartial(100)) f.rewind assert_equal("\u3042\u3044".force_encoding(Encoding::ASCII_8BIT), f.readpartial(f.size)) end @@ -433,7 +434,20 @@ class TestStringIO < Test::Unit::TestCase f = StringIO.new("\u3042\u3044") assert_raise(ArgumentError) { f.read_nonblock(-1) } assert_raise(ArgumentError) { f.read_nonblock(1, 2, 3) } - assert_equal("\u3042\u3044", f.read_nonblock) + assert_raise(ArgumentError) { f.read_nonblock } + assert_equal("\u3042\u3044".force_encoding("BINARY"), f.read_nonblock(100)) + assert_raise(EOFError) { f.read_nonblock(10) } + f.rewind + assert_equal("\u3042\u3044".force_encoding(Encoding::ASCII_8BIT), f.read_nonblock(f.size)) + end + + def test_try_read_nonblock + f = StringIO.new("\u3042\u3044") + assert_raise(ArgumentError) { f.try_read_nonblock(-1) } + assert_raise(ArgumentError) { f.try_read_nonblock(1, 2, 3) } + assert_raise(ArgumentError) { f.try_read_nonblock } + assert_equal("\u3042\u3044".force_encoding(Encoding::ASCII_8BIT), f.try_read_nonblock(100)) + assert_equal(nil, f.try_read_nonblock(10)) f.rewind assert_equal("\u3042\u3044".force_encoding(Encoding::ASCII_8BIT), f.read_nonblock(f.size)) end