diff --git a/pack.c b/pack.c index 71dd6af..6e515b2 100644 --- a/pack.c +++ b/pack.c @@ -234,6 +234,31 @@ static void qpencode(VALUE,VALUE,long); static unsigned long utf8_to_uv(const char*,long*); +static ID id_associated; + +static void +str_associate(VALUE str, VALUE add) +{ + VALUE assoc; + + assoc = rb_attr_get(str, id_associated); + if (RB_TYPE_P(assoc, T_ARRAY)) { + /* already associated */ + rb_ary_concat(assoc, add); + } + else { + rb_ivar_set(str, id_associated, add); + } +} + +static VALUE +str_associated(VALUE str) +{ + VALUE assoc = rb_attr_get(str, id_associated); + if (NIL_P(assoc)) assoc = Qfalse; + return assoc; +} + /* * call-seq: * arr.pack ( aTemplateString ) -> aBinaryString @@ -921,7 +960,7 @@ pack_pack(VALUE ary, VALUE fmt) } if (associates) { - rb_str_associate(res, associates); + str_associate(res, associates); } OBJ_INFECT(res, fmt); switch (enc_info) { @@ -1801,7 +1840,7 @@ pack_unpack(VALUE str, VALUE fmt) VALUE a; const VALUE *p, *pend; - if (!(a = rb_str_associated(str))) { + if (!(a = str_associated(str))) { rb_raise(rb_eArgError, "no associated pointer"); } p = RARRAY_CONST_PTR(a); @@ -1810,7 +1849,7 @@ pack_unpack(VALUE str, VALUE fmt) if (RB_TYPE_P(*p, T_STRING) && RSTRING_PTR(*p) == t) { if (len < RSTRING_LEN(*p)) { tmp = rb_tainted_str_new(t, len); - rb_str_associate(tmp, a); + str_associate(tmp, a); } else { tmp = *p; @@ -1844,7 +1883,7 @@ pack_unpack(VALUE str, VALUE fmt) VALUE a; const VALUE *p, *pend; - if (!(a = rb_str_associated(str))) { + if (!(a = str_associated(str))) { rb_raise(rb_eArgError, "no associated pointer"); } p = RARRAY_CONST_PTR(a); @@ -2006,4 +2045,6 @@ Init_pack(void) { rb_define_method(rb_cArray, "pack", pack_pack, 1); rb_define_method(rb_cString, "unpack", pack_unpack, 1); + + id_associated = rb_intern_const("__pack_associated__"); } diff --git a/test/ruby/test_pack.rb b/test/ruby/test_pack.rb index 3f0931b..38c1981 100644 --- a/test/ruby/test_pack.rb +++ b/test/ruby/test_pack.rb @@ -181,7 +181,7 @@ def test_pack_p assert_equal a[0], a.pack("p").unpack("p")[0] assert_equal a, a.pack("p").freeze.unpack("p*") assert_raise(ArgumentError) { (a.pack("p") + "").unpack("p*") } - assert_raise(ArgumentError) { (a.pack("p") << "d").unpack("p*") } + assert_equal a, (a.pack("p") << "d").unpack("p*") end def test_format_string_modified