Project

General

Profile

Feature #15144 » 0001-Implement-Enumerator-Chain-and-Enumerator-chain-Feat.patch

knu (Akinori MUSHA), 11/21/2018 06:12 AM

View differences:

enumerator.c
12 12

  
13 13
************************************************/
14 14

  
15
#include "ruby/ruby.h"
15 16
#include "internal.h"
16 17
#include "id.h"
17 18

  
......
161 162
static VALUE generator_allocate(VALUE klass);
162 163
static VALUE generator_init(VALUE obj, VALUE proc);
163 164

  
165
/* The Enumerator::Chain class; assigned when the class is defined under Enumerator at init time. */
static VALUE rb_cEnumChain;
166

  
167
/* Per-instance state of an Enumerator::Chain. */
struct enum_chain {
168
    /* Frozen array of the chained enumerables; Qundef until #initialize runs. */
    VALUE enums;
169
    /* Index into enums of the enumerable #each most recently started; -1 if
       none.  #rewind walks backwards from this index. */
    long pos;
170
};
171

  
164 172
static VALUE rb_cArithSeq;
165 173

  
166 174
/*
......
2411 2419
    return rb_attr_get(self, id_result);
2412 2420
}
2413 2421

  
2422
/*
2423
 * Document-class: Enumerator::Chain
2424
 *
2425
 * Enumerator::Chain is a subclass of Enumerator, which represents a
2426
 * chain of enumerables that works as a single enumerator.
2427
 */
2428

  
2429
static void
2430
enum_chain_mark(void *p)
2431
{
2432
    struct enum_chain *ptr = p;
2433
    rb_gc_mark(ptr->enums);
2434
}
2435

  
2436
/* The struct holds only a VALUE and a long (no owned heap memory), so the
   default typed-data free routine is sufficient. */
#define enum_chain_free RUBY_TYPED_DEFAULT_FREE
2437

  
2438
static size_t
2439
enum_chain_memsize(const void *p)
2440
{
2441
    return sizeof(struct enum_chain);
2442
}
2443

  
2444
/*
 * Typed-data descriptor for Enumerator::Chain objects: wrap name "chain",
 * the mark/free/memsize callbacks defined above, no parent type, no type
 * data, and RUBY_TYPED_FREE_IMMEDIATELY (the free routine is safe to run
 * directly during GC sweep).
 */
static const rb_data_type_t enum_chain_data_type = {
2445
    "chain",
2446
    {
2447
	enum_chain_mark,
2448
	enum_chain_free,
2449
	enum_chain_memsize,
2450
    },
2451
    0, 0, RUBY_TYPED_FREE_IMMEDIATELY
2452
};
2453

  
2454
static struct enum_chain *
2455
enum_chain_ptr(VALUE obj)
2456
{
2457
    struct enum_chain *ptr;
2458

  
2459
    TypedData_Get_Struct(obj, struct enum_chain, &enum_chain_data_type, ptr);
2460
    if (!ptr || ptr->enums == Qundef) {
2461
	rb_raise(rb_eArgError, "uninitialized chain");
2462
    }
2463
    return ptr;
2464
}
2465

  
2466
/* :nodoc: */
2467
static VALUE
2468
enum_chain_allocate(VALUE klass)
2469
{
2470
    struct enum_chain *ptr;
2471
    VALUE obj;
2472

  
2473
    obj = TypedData_Make_Struct(klass, struct enum_chain, &enum_chain_data_type, ptr);
2474
    ptr->enums = Qundef;
2475
    ptr->pos = -1;
2476

  
2477
    return obj;
2478
}
2479

  
2480
/*
2481
 * call-seq:
2482
 *   Enumerator::Chain.new(*enums) -> enum
2483
 *   Enumerator.chain(*enums) -> enum
2484
 *
2485
 * Generates an Enumerator::Chain object from the given
2486
 * enumerable objects.
2487
 *
2488
 *   e = Enumerator.chain(1..3, [4, 5])
2489
 *   e.to_a #=> [1, 2, 3, 4, 5]
2490
 */
2491
static VALUE
2492
enum_chain_initialize(VALUE obj, VALUE enums)
2493
{
2494
    struct enum_chain *ptr;
2495

  
2496
    rb_check_frozen(obj);
2497
    TypedData_Get_Struct(obj, struct enum_chain, &enum_chain_data_type, ptr);
2498

  
2499
    if (!ptr) rb_raise(rb_eArgError, "unallocated chain");
2500

  
2501
    ptr->enums = rb_obj_freeze(enums);
2502
    ptr->pos = -1;
2503

  
2504
    return obj;
2505
}
2506

  
2507
/* :nodoc: */
2508
static VALUE
2509
enum_chain_init_copy(VALUE obj, VALUE orig)
2510
{
2511
    struct enum_chain *ptr0, *ptr1;
2512

  
2513
    if (!OBJ_INIT_COPY(obj, orig)) return obj;
2514
    ptr0 = enum_chain_ptr(orig);
2515

  
2516
    TypedData_Get_Struct(obj, struct enum_chain, &enum_chain_data_type, ptr1);
2517

  
2518
    if (!ptr1) rb_raise(rb_eArgError, "unallocated chain");
2519

  
2520
    ptr1->enums = ptr0->enums;
2521
    ptr1->pos = ptr0->pos;
2522

  
2523
    return obj;
2524
}
2525

  
2526
static VALUE
2527
enum_chain_total_size(VALUE enums)
2528
{
2529
    VALUE total = INT2FIX(0);
2530

  
2531
    RARRAY_PTR_USE(enums, ptr, {
2532
        long i;
2533

  
2534
        for (i = 0; i < RARRAY_LEN(enums); i++) {
2535
            VALUE size = enum_size(ptr[i]);
2536

  
2537
            if (NIL_P(size) || (RB_TYPE_P(size, T_FLOAT) && isinf(NUM2DBL(size)))) {
2538
                return size;
2539
            }
2540
            if (!RB_INTEGER_TYPE_P(size)) {
2541
                return Qnil;
2542
            }
2543

  
2544
            total = rb_funcall(total, '+', 1, size);
2545
        }
2546
    });
2547

  
2548
    return total;
2549
}
2550

  
2551
/*
2552
 * call-seq:
2553
 *   obj.size -> integer
2554
 *
2555
 * Returns the total size of the enumerator chain calculated by
2556
 * summing up the size of each enumerable in the chain.  If any of the
2557
 * enumerables reports its size as nil or Float::INFINITY, that value
2558
 * is returned as the total size.
2559
 */
2560
static VALUE
2561
enum_chain_size(VALUE obj)
2562
{
2563
    return enum_chain_total_size(enum_chain_ptr(obj)->enums);
2564
}
2565

  
2566
static VALUE
2567
enum_chain_enum_size(VALUE obj, VALUE args, VALUE eobj)
2568
{
2569
    return enum_chain_size(obj);
2570
}
2571

  
2572
static VALUE
2573
enum_chain_yield_block(VALUE arg, VALUE block, int argc, VALUE *argv)
2574
{
2575
    return rb_funcallv(block, rb_intern("call"), argc, argv);
2576
}
2577

  
2578
static VALUE
2579
enum_chain_enum_no_size(VALUE obj, VALUE args, VALUE eobj)
2580
{
2581
    return Qnil;
2582
}
2583

  
2584
/*
2585
 * call-seq:
2586
 *   obj.each(*args) { |...| ... } -> obj
2587
 *   obj.each(*args) -> enumerator
2588
 *
2589
 * Iterates over the first enumerable by calling the "each" method on
2590
 * it with the given arguments until it is exhausted, then proceeds to
2591
 * the next enumerable, until all of the enumerables are exhausted.
2592
 *
2593
 * If no block is given, returns an enumerator.
2594
 */
2595
static VALUE
2596
enum_chain_each(int argc, VALUE *argv, VALUE obj)
2597
{
2598
    VALUE enums, block;
2599
    struct enum_chain *objptr;
2600

  
2601
    RETURN_SIZED_ENUMERATOR(obj, argc, argv, argc > 0 ? enum_chain_enum_no_size : enum_chain_enum_size);
2602

  
2603
    objptr = enum_chain_ptr(obj);
2604
    enums = objptr->enums;
2605
    block = rb_block_proc();
2606

  
2607
    RARRAY_PTR_USE(enums, ptr, {
2608
        long i;
2609

  
2610
        for (i = 0; i < RARRAY_LEN(enums); i++) {
2611
            objptr->pos = i;
2612
            rb_block_call(ptr[i], id_each, argc, argv, enum_chain_yield_block, block);
2613
        }
2614
    });
2615

  
2616
    return obj;
2617
}
2618

  
2619
/*
2620
 * call-seq:
2621
 *   obj.rewind -> obj
2622
 *
2623
 * Rewinds the enumerator chain by calling the "rewind" method on each
2624
 * enumerable in reverse order.  Each call is performed only if the
2625
 * enumerable responds to the method.
2626
 */
2627
static VALUE
2628
enum_chain_rewind(VALUE obj)
2629
{
2630
    struct enum_chain *objptr = enum_chain_ptr(obj);
2631
    VALUE enums = objptr->enums;
2632

  
2633
    RARRAY_PTR_USE(enums, ptr, {
2634
        long i;
2635

  
2636
        for (i = objptr->pos; 0 <= i && i < RARRAY_LEN(enums); objptr->pos = --i) {
2637
            rb_check_funcall(ptr[i], id_rewind, 0, 0);
2638
        }
2639
    });
2640

  
2641
    return obj;
2642
}
2643

  
2644
static VALUE
2645
inspect_enum_chain(VALUE obj, VALUE dummy, int recur)
2646
{
2647
    VALUE klass = rb_obj_class(obj);
2648
    struct enum_chain *ptr;
2649

  
2650
    TypedData_Get_Struct(obj, struct enum_chain, &enum_chain_data_type, ptr);
2651

  
2652
    if (!ptr || ptr->enums == Qundef) {
2653
	return rb_sprintf("#<%"PRIsVALUE": uninitialized>", rb_class_path(klass));
2654
    }
2655

  
2656
    if (recur) {
2657
	return rb_sprintf("#<%"PRIsVALUE": ...>", rb_class_path(klass));
2658
    }
2659

  
2660
    return rb_sprintf("#<%"PRIsVALUE": %+"PRIsVALUE">", rb_class_path(klass), ptr->enums);
2661
}
2662

  
2663
/*
2664
 * call-seq:
2665
 *   obj.inspect -> string
2666
 *
2667
 * Returns a printable version of the enumerator chain.
2668
 */
2669
static VALUE
2670
enum_chain_inspect(VALUE obj)
2671
{
2672
    return rb_exec_recursive(inspect_enum_chain, obj, 0);
2673
}
2674

  
2675
/*
2676
 * call-seq:
2677
 *   e.chain(*enums) -> enumerator
2678
 *
2679
 * Returns an Enumerator::Chain object generated from this enumerator
2680
 * and given enumerables.
2681
 *
2682
 *   e = (1..3).each.chain([4, 5])
2683
 *   e.to_a #=> [1, 2, 3, 4, 5]
2684
 */
2685
static VALUE
2686
enumerator_s_chain(int argc, VALUE enums)
2687
{
2688
    return enum_chain_initialize(enum_chain_allocate(rb_cEnumChain), enums);
2689
}
2690

  
2691
/*
2692
 * call-seq:
2693
 *   e.chain(*enums) -> enumerator
2694
 *
2695
 * Returns an Enumerator::Chain object generated from this enumerator
2696
 * and given enumerables.
2697
 *
2698
 *   e = (1..3).each.chain([4, 5])
2699
 *   e.to_a #=> [1, 2, 3, 4, 5]
2700
 */
2701
static VALUE
2702
enumerator_chain(int argc, VALUE *argv, VALUE obj)
2703
{
2704
    VALUE enums = rb_ary_new_from_values(1, &obj);
2705
    rb_ary_cat(enums, argv, argc);
2706

  
2707
    return enum_chain_initialize(enum_chain_allocate(rb_cEnumChain), enums);
2708
}
2709

  
2710
/*
2711
 * call-seq:
2712
 *   e + enum        -> enumerator
2713
 *
2714
 * Returns an Enumerator::Chain object generated from this enumerator
2715
 * and a given enumerable.
2716
 *
2717
 *   e = (1..3).each + [4, 5]
2718
 *   e.to_a #=> [1, 2, 3, 4, 5]
2719
 */
2720
static VALUE
2721
enumerator_plus(VALUE obj, VALUE eobj)
2722
{
2723
    VALUE enums = rb_ary_new_from_args(2, obj, eobj);
2724

  
2725
    return enum_chain_initialize(enum_chain_allocate(rb_cEnumChain), enums);
2726
}
2727

  
2414 2728
/*
2415 2729
 * Document-class: Enumerator::ArithmeticSequence
2416 2730
 *
......
2907 3221
    rb_define_method(rb_cEnumerator, "rewind", enumerator_rewind, 0);
2908 3222
    rb_define_method(rb_cEnumerator, "inspect", enumerator_inspect, 0);
2909 3223
    rb_define_method(rb_cEnumerator, "size", enumerator_size, 0);
3224
    rb_define_method(rb_cEnumerator, "chain", enumerator_chain, -1);
3225
    rb_define_method(rb_cEnumerator, "+", enumerator_plus, 1);
2910 3226

  
2911 3227
    /* Lazy */
2912 3228
    rb_cLazy = rb_define_class_under(rb_cEnumerator, "Lazy", rb_cEnumerator);
......
2960 3276
    rb_define_method(rb_cYielder, "yield", yielder_yield, -2);
2961 3277
    rb_define_method(rb_cYielder, "<<", yielder_yield_push, 1);
2962 3278

  
3279
    /* Chain */
3280
    rb_cEnumChain = rb_define_class_under(rb_cEnumerator, "Chain", rb_cEnumerator);
3281
    rb_define_alloc_func(rb_cEnumChain, enum_chain_allocate);
3282
    rb_define_method(rb_cEnumChain, "initialize", enum_chain_initialize, -2);
3283
    rb_define_method(rb_cEnumChain, "initialize_copy", enum_chain_init_copy, 1);
3284
    rb_define_method(rb_cEnumChain, "each", enum_chain_each, -1);
3285
    rb_define_method(rb_cEnumChain, "size", enum_chain_size, 0);
3286
    rb_define_method(rb_cEnumChain, "rewind", enum_chain_rewind, 0);
3287
    rb_define_method(rb_cEnumChain, "inspect", enum_chain_inspect, 0);
3288
    rb_define_singleton_method(rb_cEnumerator, "chain", enumerator_s_chain, -2);
3289

  
2963 3290
    /* ArithmeticSequence */
2964 3291
    rb_cArithSeq = rb_define_class_under(rb_cEnumerator, "ArithmeticSequence", rb_cEnumerator);
2965 3292
    rb_undef_alloc_func(rb_cArithSeq);
test/ruby/test_enumerator.rb
670 670
    assert_equal([0, 1], u.force)
671 671
    assert_equal([0, 1], u.force)
672 672
  end
673

  
674
  def test_chain_and_plus
    base = (1..5).each
    # Collects everything an enumerable yields through #each with a block.
    drain = lambda do |enum|
      collected = []
      enum.each { |v| collected << v }
      collected
    end

    # Enumerator#chain with no arguments wraps only the receiver.
    chained = base.chain()
    assert_kind_of(Enumerator::Chain, chained)
    assert_equal(5, chained.size)
    assert_equal([1, 2, 3, 4, 5], drain.call(chained))

    # Enumerator#+ appends one enumerable.
    plus = base + [6, 7, 8]
    assert_kind_of(Enumerator::Chain, plus)
    assert_equal(8, plus.size)
    assert_equal([1, 2, 3, 4, 5, 6, 7, 8], drain.call(plus))

    # An infinite member makes the whole chain infinite.
    infinite = base.chain([6, 7], 8.step)
    assert_kind_of(Enumerator::Chain, infinite)
    assert_equal(Float::INFINITY, infinite.size)
    assert_equal([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], drain.call(infinite.take(10)))

    # `a + b + c` should not return `Enumerator.chain(a, b, c)`
    # because it is expected that `(a + b).each` be called.
    spied = plus.dup
    spied.singleton_class.class_eval do
      attr_reader :each_is_called
      def each
        super
        @each_is_called = true
      end
    end
    nested = spied + 9.step
    assert_kind_of(Enumerator::Chain, nested)
    assert_equal(Float::INFINITY, nested.size)
    assert_equal([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], drain.call(nested.take(10)))
    assert_equal(true, spied.each_is_called)
  end
716

  
717
  def test_chained_enums
    a = (1..5).each

    e0 = Enumerator::Chain.new()
    assert_kind_of(Enumerator::Chain, e0)
    assert_equal(0, e0.size)
    ary = []
    e0.each { |x| ary << x }
    assert_equal([], ary)

    e1 = Enumerator::Chain.new(a)
    assert_kind_of(Enumerator::Chain, e1)
    assert_equal(5, e1.size)
    ary = []
    e1.each { |x| ary << x }
    assert_equal([1, 2, 3, 4, 5], ary)

    e2 = Enumerator.chain(a, [6, 7, 8])
    assert_kind_of(Enumerator::Chain, e2)
    assert_equal(8, e2.size)
    ary = []
    e2.each { |x| ary << x }
    assert_equal([1, 2, 3, 4, 5, 6, 7, 8], ary)

    # Breaking out of #each, in the first or a later enumerable, must
    # unwind cleanly and return the break value.
    assert_equal(1, e2.each { |x| break x })
    assert_equal(6, e2.each { |x| break x if x > 5 })

    e3 = Enumerator.chain(a, [6, 7], 8.step)
    assert_kind_of(Enumerator::Chain, e3)
    assert_equal(Float::INFINITY, e3.size)
    ary = []
    e3.take(10).each { |x| ary << x }
    assert_equal([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], ary)

    e4 = Enumerator.chain(a, Enumerator.new { |y| y << 6 << 7 << 8 })
    assert_kind_of(Enumerator::Chain, e4)
    assert_equal(nil, e4.size)
    ary = []
    e4.each { |x| ary << x }
    assert_equal([1, 2, 3, 4, 5, 6, 7, 8], ary)

    e5 = Enumerator.chain(e1, e2)
    assert_kind_of(Enumerator::Chain, e5)
    assert_equal(13, e5.size)
    ary = []
    e5.each { |x| ary << x }
    assert_equal([1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 6, 7, 8], ary)

    # #rewind visits the iterated enumerables in reverse order.
    rewound = []
    e1.define_singleton_method(:rewind) { rewound << object_id }
    e2.define_singleton_method(:rewind) { rewound << object_id }
    e5.rewind
    assert_equal([e2.object_id, e1.object_id], rewound)

    # A chain that was never iterated rewinds nothing.
    rewound = []
    a = [1]
    e6 = Enumerator.chain(a)
    a.define_singleton_method(:rewind) { rewound << object_id }
    e6.rewind
    assert_equal([], rewound)

    assert_equal(
      '#<Enumerator::Chain: [' +
        '#<Enumerator::Chain: [' +
          '#<Enumerator: 1..5:each>' +
        ']>, ' +
        '#<Enumerator::Chain: [' +
          '#<Enumerator: 1..5:each>, ' +
          '[6, 7, 8]' +
        ']>' +
      ']>',
      e5.inspect
    )
  end
673 788
end
674

  
675
-