add_mul_for_openssl_pkey_ec_point.patch

Zachary Scott, 05/09/2012 09:30 PM

Download (3.87 KB)

View differences:

ext/openssl/ossl_pkey_ec.c
1467 1467
    return bn_obj;
1468 1468
}
1469 1469

  
1470
/*
1471
 *  call-seq:
1472
 *     point.mul(bn)  => point
1473
 *     point.mul(bn, bn) => point
1474
 *     point.mul([bn], [point]) => point
1475
 *     point.mul([bn], [point], bn) => point
1476
 *
1477
 */
1478

  
1479
static VALUE ossl_ec_point_mul(int argc, VALUE *argv, VALUE self)
1480
{
1481
    EC_POINT *point1, *point2;
1482
    const EC_GROUP *group;
1483
    VALUE group_v = rb_iv_get(self, "@group");
1484
    VALUE args[1] = {group_v};
1485
    VALUE bn_v1, bn_v2, r, points_v;
1486
    BIGNUM *bn1 = NULL, *bn2 = NULL;
1487

  
1488
    Require_EC_POINT(self, point1);
1489
    SafeRequire_EC_GROUP(group_v, group);
1490

  
1491
    r = rb_obj_alloc(cEC_POINT);
1492
    ossl_ec_point_initialize(1, args, r);
1493
    Require_EC_POINT(r, point2);
1494

  
1495
    argc = rb_scan_args(argc, argv, "12", &bn_v1, &points_v, &bn_v2); 
1496
    
1497
    if (rb_obj_is_kind_of(bn_v1, cBN)) {
1498
        bn1 = GetBNPtr(bn_v1);
1499
        if (argc >= 2) {
1500
            bn2 = GetBNPtr(points_v);
1501
        }
1502
        if (EC_POINT_mul(group, point2, bn2, point1, bn1, ossl_bn_ctx) != 1)
1503
            return Qnil;
1504

  
1505
    }
1506
    else {
1507
        size_t i, points_len, bignums_len;
1508
        EC_POINT **points;
1509
        BIGNUM **bignums;
1510

  
1511
        Check_Type(bn_v1, T_ARRAY);
1512
        bignums_len = RARRAY_LEN(bn_v1);
1513
        bignums = (BIGNUM **)OPENSSL_malloc(bignums_len * sizeof(BIGNUM *));
1514

  
1515
        for (i = 0; i < bignums_len; ++i) {
1516
            bignums[i] = GetBNPtr(rb_ary_entry(bn_v1, i));
1517
        }
1518

  
1519
        if (!rb_obj_is_kind_of(points_v, rb_cArray)) {
1520
            OPENSSL_free(bignums); 
1521
            rb_raise(rb_eTypeError, "Argument2 must be an array");
1522
        }
1523
        rb_ary_unshift(points_v, self);
1524
        points_len = RARRAY_LEN(points_v);
1525
        points = (EC_POINT **)OPENSSL_malloc(points_len * sizeof(EC_POINT *));
1526

  
1527
        for (i = 0; i < points_len; ++i) {
1528
            Get_EC_POINT(rb_ary_entry(points_v, i), points[i]);
1529
        }
1530

  
1531
        if (argc >= 3) {
1532
            bn2 = GetBNPtr(bn_v2);
1533
        }
1534
        if (EC_POINTs_mul(group, point2, bn2, points_len, points, bignums, ossl_bn_ctx) != 1){
1535
            OPENSSL_free(bignums); 
1536
            OPENSSL_free(points); 
1537
            return Qnil;
1538
        }
1539
        OPENSSL_free(bignums); 
1540
        OPENSSL_free(points); 
1541
    } 
1542

  
1543
    return r;
1544
}
1545

  
1470 1546
static void no_copy(VALUE klass)
1471 1547
{
1472 1548
    rb_undef_method(klass, "copy");
......
1587 1663
/* all the other methods */
1588 1664

  
1589 1665
    rb_define_method(cEC_POINT, "to_bn", ossl_ec_point_to_bn, 0);
1666
    rb_define_method(cEC_POINT, "mul", ossl_ec_point_mul, -1);
1590 1667

  
1591 1668
    no_copy(cEC);
1592 1669
    no_copy(cEC_GROUP);
test/openssl/test_pkey_ec.rb
175 175
    assert_equal([], OpenSSL.errors)
176 176
  end
177 177

  
178
  def test_ec_point_mul
179
    ec = OpenSSL::TestUtils::TEST_KEY_EC_P256V1
180
    p1 = ec.public_key
181
    bn1 = OpenSSL::BN.new('10')
182
    bn2 = OpenSSL::BN.new('20')
183

  
184
    p2 = p1.mul(bn1)
185
    assert(p1.group == p2.group)
186
    p2 = p1.mul(bn1, bn2)
187
    assert(p1.group == p2.group)
188
    p2 = p1.mul([bn1, bn2], [p1])
189
    assert(p1.group == p2.group)
190
    p2 = p1.mul([bn1, bn2], [p1], bn2)
191
    assert(p1.group == p2.group)
192
  end
193

  
178 194
# test Group: asn1_flag, point_conversion
179 195

  
180 196
end
181
-