diff --git a/enumerator.c b/enumerator.c index d61d79e897..602c2d568a 100644 --- a/enumerator.c +++ b/enumerator.c @@ -824,6 +824,58 @@ enumerator_next(VALUE obj) return ary2sv(vs, 0); } +/* + * call-seq: + * e.next? -> object + * + * Returns whether there is a next object in the enumerator, but doesn't + * move the internal position forward. + * + * === Example + * + * a = [1,2,3] + * e = a.to_enum + * p e.next? #=> true + * p e.next #=> 1 + * p e.next? #=> true + * p e.next #=> 2 + * p e.next? #=> true + * p e.next #=> 3 + * p e.next? #=> false + * p e.next #raises StopIteration + * + */ + +static VALUE +enumerator_has_next(VALUE obj) +{ + struct enumerator *e = enumerator_ptr(obj); + if (e->stop_exc) + return Qfalse; + + if (e->lookahead != Qundef) + return Qtrue; + + VALUE curr, vs; + + curr = rb_fiber_current(); + + if (!e->fib || !rb_fiber_alive_p(e->fib)) { + next_init(obj, e); + } + + vs = rb_fiber_resume(e->fib, 1, &curr); + if (e->stop_exc) { + e->fib = 0; + e->dst = Qnil; + e->lookahead = Qundef; + e->feedvalue = Qundef; + return Qfalse; + } + e->lookahead = vs; + return Qtrue; +} + static VALUE enumerator_peek_values(VALUE obj) { @@ -2352,6 +2404,7 @@ InitVM_Enumerator(void) rb_define_method(rb_cEnumerator, "next_values", enumerator_next_values, 0); rb_define_method(rb_cEnumerator, "peek_values", enumerator_peek_values_m, 0); rb_define_method(rb_cEnumerator, "next", enumerator_next, 0); + rb_define_method(rb_cEnumerator, "next?", enumerator_has_next, 0); rb_define_method(rb_cEnumerator, "peek", enumerator_peek, 0); rb_define_method(rb_cEnumerator, "feed", enumerator_feed, 1); rb_define_method(rb_cEnumerator, "rewind", enumerator_rewind, 0); diff --git a/test/ruby/test_enumerator.rb b/test/ruby/test_enumerator.rb index 0ee11dad36..37ef7a2673 100644 --- a/test/ruby/test_enumerator.rb +++ b/test/ruby/test_enumerator.rb @@ -38,6 +38,15 @@ def test_next assert_raise(StopIteration){e.next} end + def test_next? + e = 3.times + 3.times{|i| + assert_equal true, e.next? + e.next + } + assert_equal false, e.next? + end + def test_loop e = 3.times i = 0