Ruby Hack Challeng で Refinements のバグを治そうとした話

前回参加した Ruby Hack Challeng からだいぶ時間が経ってしまったんですが、その時にやっていたことを覚書程度に書いておこうかと。
ちなみにバグと言っていますが、バグなのか微妙なところだったり。
あと結構フィーリングで書いているので実際の Ruby(の実装)と言ってることが異なるかもしれませんがまあなんか察してください。
なんかガッツリ書いたら長くなったので『兎に角問題だけ知りたいんじゃ!!』って人は bugs.ruby の issuesを参照してください。

Refinements とは

Refinements が実装されてだいぶ時間が経っているのでもう皆さんご存知だと思いますが、一応説明しておくと、『任意のスコープでのみクラス拡張を適用するため』の機能です。

module IntegerEx
    # Integer#twice を定義する
    refine Integer do
        def twice
            self + self
        end
    end
end

class X
    # X クラスのスコープ内でのみ Integer#twice が使用できる
    using IntegerEx

    def initialize value
        @value = value
    end

    def meth
        @value.twice
    end
end

x = X.new 42
p x.meth # => 84

# ここでは使えない
42.twice

Ruby の場合、クラス拡張を使う事で思わぬ副作用が発生する事があるので、それを制限するための機能になります。

Refinements の制限

さて、そんな Refinements ですが、実装された当初はいろいろと制限があり、例えば以下のようなコードは動きませんでした。

module IntegerEx
    # Integer#twice を定義する
    refine Integer do
        def twice
            self + self
        end
    end
end
using IntegerEx

# これは #twice を直接呼び出しているので OK
(1..3).map { |it| it.twice }

# これは Symbol#to_proc 内部で #twice を呼び出しているので NG
(1..3).map(&:twice)

この制限は Ruby 2.4 で緩和されて『Symbol#to_proc 内から Refinements されたメソッドが呼び出せる』ようになりました。

module IntegerEx
    # Integer#twice を定義する
    refine Integer do
        def twice
            self + self
        end
    end
end
using IntegerEx

# Ruby 2.4 以降では OK
(1..3).map(&:twice)

そう、なったはずだったんです…。

問題点

Ruby 2.4 の relese note では次のように明記されています。

Refinements is enabled at method by Symbol#to_proc. [Feature #9451] see: https://github.com/ruby/ruby/blob/v2_4_0/NEWS#changes-since-the-230-release

Symbol#to_proc が Refinements で適用されたかのように書かれていますね。
しかし、現実はこの実装は不完全であり、例えば次のように『ユーザが定義したメソッド』に対して同様に『Refinements したメソッドを &:hoge で渡す(呼び出す)』とエラーになります。

def meth &block
    # ここで Integer#twice が呼び出される
    block.call 42
end

module IntegerEx
    # Integer#twice を定義する
    refine Integer do
        def twice
            self + self
        end
    end
end
using IntegerEx

# Error: `meth': undefined method `twice' for 42:Integer (NoMethodError)
meth &:twice

meth &:twice の部分は問題ありませんが、#meth 内で block.call を呼び出そうとするとエラーになります。
これは次のように『 include Enumerable を mixin してイテレーションメソッドを呼び出したい』ような実装で問題になります。

class X
    include Enumerable

    def each &block
        (1..3).each &block
    end
end


module IntegerEx
    # Integer#twice を定義する
    refine Integer do
        def twice
            self + self
        end
    end
end
using IntegerEx

# こっちは当然 Integer#twice が呼ばれる
(1..3).map &:twice

# Error: `each': undefined method `twice' for 1:Integer (NoMethodError)
# こっちは X#each 経由で Symbol#to_proc を呼び出そうとしているのでエラーになる
X.new.map &:twice

実際、この問題は RailsActiveRecordイテレーションメソッドを使用してる時に気づきました。

# hoge が Refinements で定義されている場合、エラーになってしまう
User.all.map &:hoge

ちなみに Symbol#to_proc を直接呼び出した場合にもエラーになります。

module IntegerEx
    # Integer#twice を定義する
    refine Integer do
        def twice
            self + self
        end
    end
end
using IntegerEx

# 内部で 42.twice が呼ばれるが Refinements が反映されない
:twice.to_proc.call 42

これは流石に意図していない挙動かなーと思い修正しようと思いました。

CRuby 側の実装

今回の問題ですが、見るべき実装は二箇所あり、

  1. メソッドに &:twice を渡している部分
    • 内部的に Symbol#to_proc が呼び出される
  2. block.call が呼ばれる部分
    • つまり Symbol#to_proc で生成した Proc オブジェクトが評価される部分

になります。
おそらく問題となっているのは『&:twice を呼び出している部分』と『block.call が呼ばれる部分』でコンテキストが異なっており、参照する refinements 情報が違うのかなーという予想です。
ではでは、実際に CRuby の実装がどうなっているのかコードを見てみましょう。
上記の実装はどちらも vm_args.c で書かれています。

&:twice から proc オブジェクトを生成している箇所

さて、いきなりですが『&:twice から proc オブジェクトを生成しているコード』は以下の関数になります。

static VALUE
vm_caller_setup_arg_block(const rb_execution_context_t *ec, rb_control_frame_t *reg_cfp,
                          const struct rb_call_info *ci, rb_iseq_t *blockiseq, const int is_super)
{
    if (ci->flag & VM_CALL_ARGS_BLOCKARG) {
    VALUE block_code = *(--reg_cfp->sp);

    if (NIL_P(block_code)) {
            return VM_BLOCK_HANDLER_NONE;
        }
    else if (block_code == rb_block_param_proxy) {
            return VM_CF_BLOCK_HANDLER(reg_cfp);
        }
    else if (SYMBOL_P(block_code) && rb_method_basic_definition_p(rb_cSymbol, idTo_proc)) {
        const rb_cref_t *cref = vm_env_cref(reg_cfp->ep);
        if (cref && !NIL_P(cref->refinements)) {
        VALUE ref = cref->refinements;
        VALUE func = rb_hash_lookup(ref, block_code);
        if (NIL_P(func)) {
            /* TODO: limit cached funcs */
            func = rb_func_proc_new(refine_sym_proc_call, block_code);
            rb_hash_aset(ref, block_code, func);
        }
        block_code = func;
        }
            return block_code;
        }
        else {
            return vm_to_proc(block_code);
        }
    }
    else if (blockiseq != NULL) { /* likely */
    struct rb_captured_block *captured = VM_CFP_TO_CAPTURED_BLOCK(reg_cfp);
    captured->code.iseq = blockiseq;
        return VM_BH_FROM_ISEQ_BLOCK(captured);
    }
    else {
    if (is_super) {
            return GET_BLOCK_HANDLER();
        }
        else {
            return VM_BLOCK_HANDLER_NONE;
        }
    }
}

https://github.com/ruby/ruby/blob/1aef602d5a9398ff362de212ae402ffcd8ff76ae/vm_args.c#L872

この関数は『メソッドの引数まわりの処理』が実装されていますが、今回注目するべき箇所は上記の

else if (SYMBOL_P(block_code) && rb_method_basic_definition_p(rb_cSymbol, idTo_proc)) {
    const rb_cref_t *cref = vm_env_cref(reg_cfp->ep);
    if (cref && !NIL_P(cref->refinements)) {
    VALUE ref = cref->refinements;
    VALUE func = rb_hash_lookup(ref, block_code);
    if (NIL_P(func)) {
        /* TODO: limit cached funcs */
        func = rb_func_proc_new(refine_sym_proc_call, block_code);
        rb_hash_aset(ref, block_code, func);
    }
    block_code = func;
    }
        return block_code;
    }
    else {
        return vm_to_proc(block_code);
    }
}

の部分になります。
ここの

rb_func_proc_new(refine_sym_proc_call, block_code)

というコードが『&:twice から Proc オブジェクトを生成している部分』になります。
rb_func_proc_new() では #call 時に呼ばれるコールバック関数として refine_sym_proc_call() の関数ポインタを渡しています。
block_code はそのコールバック関数に渡される引数になります。
Ruby で書くとこんな感じでしょうか。

block_code = :twice
func = proc { refine_sym_proc_call(block_code) }

さて、この箇所ではまだ Refinements に関する処理は見当たりませんね。
cref->refinements は見なかったことに

block.call が呼ばれる部分

次に Ruby の『block.call が呼ばれる部分』、つまり refine_sym_proc_call() の中でどのように処理されているのか見てみましょう。
そろそろ Ruby なのか C言語の話なのかわからなくなってきましたね。
わたしもわかりません。

static VALUE
refine_sym_proc_call(RB_BLOCK_CALL_FUNC_ARGLIST(yielded_arg, callback_arg))
{
    VALUE obj;
    ID mid;
    const rb_callable_method_entry_t *me;
    rb_execution_context_t *ec;

    if (argc-- < 1) {
        rb_raise(rb_eArgError, "no receiver given");
    }
    obj = *argv++;
    mid = SYM2ID(callback_arg);
    me = rb_callable_method_entry_with_refinements(CLASS_OF(obj), mid, NULL);
    ec = GET_EC();
    if (!NIL_P(blockarg)) {
        vm_passed_block_handler_set(ec, blockarg);
    }
    if (!me) {
        return method_missing(obj, mid, argc, argv, MISSING_NOENTRY);
    }
    return rb_vm_call0(ec, obj, mid, argc, argv, me);
}

https://github.com/ruby/ruby/blob/1aef602d5a9398ff362de212ae402ffcd8ff76ae/vm_args.c#L847

refine_sym_proc_call() の実装はこんな感じになっています。
ここでも注目すべき点を上げると

me = rb_callable_method_entry_with_refinements(CLASS_OF(obj), mid, NULL);

この部分になります。
rb_callable_method_entry_with_refinements って名前からしてなんかすごく『refinements を適用させた呼び出し可能なメソッドオブジェクトを取得している』って感じですよね。
実際にこの rb_callable_method_entry_with_refinements() で『現在のコンテキストで Refinements を適用させた呼び出し可能なメソッドのオブジェクト』を取得しています。

Refinements を適用させている箇所を探す

いよいよ rb_callable_method_entry_with_refinements() で、どのようにして『Refinements を適用させているのか』というのを調べてみましょう。
rb_callable_method_entry_with_refinements() は次のような実装になっています。

MJIT_FUNC_EXPORTED const rb_callable_method_entry_t *
rb_callable_method_entry_with_refinements(VALUE klass, ID id, VALUE *defined_class_ptr)
{
    VALUE defined_class, *dcp = defined_class_ptr ? defined_class_ptr : &defined_class;
    const rb_method_entry_t *me = method_entry_resolve_refinement(klass, id, TRUE, dcp);
    return prepare_callable_method_entry(*dcp, id, me);
}

https://github.com/ruby/ruby/blob/1aef602d5a9398ff362de212ae402ffcd8ff76ae/vm_method.c#L892

どんどん潜っていきましょう。
次は method_entry_resolve_refinement() を見ます。

static const rb_method_entry_t *
method_entry_resolve_refinement(VALUE klass, ID id, int with_refinement, VALUE *defined_class_ptr)
{
    const rb_method_entry_t *me = method_entry_get(klass, id, defined_class_ptr);

    if (me) {
        if (me->def->type == VM_METHOD_TYPE_REFINED) {
            if (with_refinement) {
                const rb_cref_t *cref = rb_vm_cref();
                VALUE refinements = cref ? CREF_REFINEMENTS(cref) : Qnil;
                me = resolve_refined_method(refinements, me, defined_class_ptr);
            }
            else {
                me = resolve_refined_method(Qnil, me, defined_class_ptr);
            }

            if (UNDEFINED_METHOD_ENTRY_P(me)) me = NULL;
        }
    }

    return me;
}

https://github.com/ruby/ruby/blob/1aef602d5a9398ff362de212ae402ffcd8ff76ae/vm_method.c#L869

お、ここで興味深いコードが出てきましたね。

const rb_cref_t *cref = rb_vm_cref();
VALUE refinements = cref ? CREF_REFINEMENTS(cref) : Qnil;
me = resolve_refined_method(refinements, me, defined_class_ptr);

ここで refinements というオブジェクトを参照してメソッドを探査しています。
おそらくこの refinements っていうのが Refinements のコンテキスト情報を保持してそうですね。
と、言うことで rb_callable_method_entry_with_refinements() のコードに戻って『refinements を使用した実装』に変えてみましょう。

@@ -836,13 +836,17 @@ refine_sym_proc_call(RB_BLOCK_CALL_FUNC_ARGLIST(yielded_arg, callback_arg))
     ID mid;
     const rb_callable_method_entry_t *me;
     rb_execution_context_t *ec;
+    // ここで refinements を取得
+    const rb_cref_t *cref = rb_vm_cref();
+    VALUE refinements = cref->refinements;

     if (argc-- < 1) {
        rb_raise(rb_eArgError, "no receiver given");
     }
     obj = *argv++;
     mid = SYM2ID(callback_arg);
-    me = rb_callable_method_entry_with_refinements(CLASS_OF(obj), mid, NULL);
+    // refinements を渡して呼び出し可能なメソッドオブジェクトを取得する
+    me = rb_resolve_refined_method_callable(refinements, (const rb_callable_method_entry_t *)rb_method_entry(CLASS_OF(obj), mid));
     ec = GET_EC();
static VALUE
refine_sym_proc_call(RB_BLOCK_CALL_FUNC_ARGLIST(yielded_arg, callback_arg))
{
    VALUE obj;
    ID mid;
    const rb_callable_method_entry_t *me;
    rb_execution_context_t *ec;
    // ここで refinements を取得
    const rb_cref_t *cref = rb_vm_cref();
    VALUE refinements = cref->refinements;

    if (argc-- < 1) {
        rb_raise(rb_eArgError, "no receiver given");
    }
    obj = *argv++;
    mid = SYM2ID(callback_arg);
    // refinements を渡して呼び出し可能なメソッドオブジェクトを取得する
    me = rb_resolve_refined_method_callable(refinements, (const rb_callable_method_entry_t *)rb_method_entry(CLASS_OF(obj), mid));
    ec = GET_EC();
    if (!NIL_P(blockarg)) {
        vm_passed_block_handler_set(ec, blockarg);
    }
    if (!me) {
        return method_missing(obj, mid, argc, argv, MISSING_NOENTRY);
    }
    return rb_vm_call0(ec, obj, mid, argc, argv, me);
}

ここで resolve_refined_method() ではなくて rb_resolve_refined_method_callable() を使用しているのは実装上の都合です。
resolve_refined_method()rb_resolve_refined_method_callable() もだいたい似たような処理で両方共『refinements を指定して』処理することが出来ます。
上記のコードではまだ refine_sym_proc_call() のコンテキスト中、つまり『block.call を呼び出したコンテキストの refinements』を参照しています。

&:twice 時の refinements を参照するようにする

さて、では実際に vm_caller_setup_arg_block() が呼び出されたタイミング refinements をrefine_sym_proc_call()内で参照できるようにしてみましょう。 してみましょうと言っても実際 vm_caller_setup_arg_block() 時の refinements をどうやって refine_sym_proc_call() に渡せばいいんでしょうか。
今回は『rb_func_proc_new() の第二匹引数に block_coderefinements の配列を渡す』ようにしてみます。
Ruby のコードだとこんなイメージですね。

func = proc { refine_sym_proc_call([block_code, refinements]) }

これを実際に C言語で実装してみましょう。

static VALUE
refine_sym_proc_call(RB_BLOCK_CALL_FUNC_ARGLIST(yielded_arg, callback_arg))
{
    VALUE obj;
    ID mid;
    const rb_callable_method_entry_t *me;
    rb_execution_context_t *ec;
    // callback_arg は [block_code, refinements] の配列になっている
    // そこから block_code(Symbol) と refinements の情報を取り出す
    const VALUE symbol = RARRAY_AREF(callback_arg, 0);
    const VALUE refinements = RARRAY_AREF(callback_arg, 1);

    if (argc-- < 1) {
        rb_raise(rb_eArgError, "no receiver given");
    }
    obj = *argv++;
    mid = SYM2ID(symbol);
    me = rb_resolve_refined_method_callable(refinements, (const rb_callable_method_entry_t *)rb_method_entry(CLASS_OF(obj), mid));
    ec = GET_EC();
    if (!NIL_P(blockarg)) {
        vm_passed_block_handler_set(ec, blockarg);
    }
    if (!me) {
        return method_missing(obj, mid, argc, argv, MISSING_NOENTRY);
    }
    return rb_vm_call0(ec, obj, mid, argc, argv, me);
}

static void
vm_caller_setup_arg_block(const rb_execution_context_t *ec, rb_control_frame_t *reg_cfp,
                          struct rb_calling_info *calling, const struct rb_call_info *ci, rb_iseq_t *blockiseq, const int is_super)
{
    if (ci->flag & VM_CALL_ARGS_BLOCKARG) {
        VALUE block_code = *(--reg_cfp->sp);

        if (NIL_P(block_code)) {
            calling->block_handler = VM_BLOCK_HANDLER_NONE;
        }
        else if (block_code == rb_block_param_proxy) {
            calling->block_handler = VM_CF_BLOCK_HANDLER(reg_cfp);
        }
        else if (SYMBOL_P(block_code) && rb_method_basic_definition_p(rb_cSymbol, idTo_proc)) {
            const rb_cref_t *cref = vm_env_cref(reg_cfp->ep);
            if (cref && !NIL_P(cref->refinements)) {
                VALUE ref = cref->refinements;
                // [block_code(Symbol), refinements] となるような配列を生成する
                VALUE callback_arg = rb_ary_new_from_args(2, block_code, ref);
                VALUE func = rb_hash_lookup(ref, block_code);
                if (NIL_P(func)) {
                    /* TODO: limit cached funcs */
                    // block_code(Symbol) と refinements を一緒にして refine_sym_proc_call() に渡す
                    func = rb_func_proc_new(refine_sym_proc_call, callback_arg);
                    rb_hash_aset(ref, block_code, func);
                }
                block_code = func;
            }
            calling->block_handler = block_code;
        }
        else {
            calling->block_handler = vm_to_proc(block_code);
        }
    }
    else if (blockiseq != NULL) { /* likely */
        struct rb_captured_block *captured = VM_CFP_TO_CAPTURED_BLOCK(reg_cfp);
        captured->code.iseq = blockiseq;
        calling->block_handler = VM_BH_FROM_ISEQ_BLOCK(captured);
    }
    else {
        if (is_super) {
            calling->block_handler = GET_BLOCK_HANDLER();
        }
        else {
            calling->block_handler = VM_BLOCK_HANDLER_NONE;
        }
    }
}

↑のコードは全部載せているのでちょっと長いですが、修正箇所自体はそんなになくて

@@ -836,13 +836,17 @@ refine_sym_proc_call(RB_BLOCK_CALL_FUNC_ARGLIST(yielded_arg, callback_arg))
     ID mid;
     const rb_callable_method_entry_t *me;
     rb_execution_context_t *ec;
+    // callback_arg は [block_code, refinements] の配列になっている
+    // そこから block_code(Symbol) と refinements の情報を取り出す
+    const VALUE symbol = RARRAY_AREF(callback_arg, 0);
+    const VALUE refinements = RARRAY_AREF(callback_arg, 1);

     if (argc-- < 1) {
        rb_raise(rb_eArgError, "no receiver given");
     }
     obj = *argv++;
-    mid = SYM2ID(callback_arg);
-    me = rb_callable_method_entry_with_refinements(CLASS_OF(obj), mid, NULL);
+    mid = SYM2ID(symbol);
+    me = rb_resolve_refined_method_callable(refinements, (const rb_callable_method_entry_t *)rb_method_entry(CLASS_OF(obj), mid));
     ec = GET_EC();
     if (!NIL_P(blockarg)) {
        vm_passed_block_handler_set(ec, blockarg);
@@ -870,10 +874,13 @@ vm_caller_setup_arg_block(const rb_execution_context_t *ec, rb_control_frame_t *
            const rb_cref_t *cref = vm_env_cref(reg_cfp->ep);
            if (cref && !NIL_P(cref->refinements)) {
                VALUE ref = cref->refinements;
+                // [block_code(Symbol), refinements] となるような配列を生成する
+               VALUE callback_arg = rb_ary_new_from_args(2, block_code, ref);
                VALUE func = rb_hash_lookup(ref, block_code);
                if (NIL_P(func)) {
                    /* TODO: limit cached funcs */
-                   func = rb_func_proc_new(refine_sym_proc_call, block_code);
+                    // block_code(Symbol) と refinements を一緒にして refine_sym_proc_call() に
渡す
+                   func = rb_func_proc_new(refine_sym_proc_call, callback_arg);
                    rb_hash_aset(ref, block_code, func);
                }
                block_code = func;

と、言う感じになります。
ポイントとしては、

VALUE callback_arg = rb_ary_new_from_args(2, block_code, ref);

で、[block, refinements] の配列を生成して、

const VALUE symbol = RARRAY_AREF(callback_arg, 0);
const VALUE refinements = RARRAY_AREF(callback_arg, 1);

で、refine_sym_proc_call() から refinements を参照できるような実装になっています。
本当にこれで動作するようになるの?と思うかと思いますが、上記の修正で最初に提示したコードは問題なく動作するようになります。
やっていることは本当に『&:twice 時の refinemetnsblock.call で参照する』というような修正になります。

まとめ

最終的に出来上がった修正パッチはこちらになります。
上記の修正パッチに加えて『rb_func_proc_new() の結果をキャッシュしている部分』も削除してあります。
今回難しかった点としては、

  • どこで refinements を参照しているのか探すこと
  • refinements をどうやって渡すのか

あたりですかね。
今回の記事だと割とサラッと書いているように感じますが、実際はかなり時間をかけて実装を調べたりしていました。
修正パッチを書くよりも C言語の中を読んでいることの方が圧倒的に多いのでそろそろ Visual Studio とかでいい感じにデバッグしたくなってきますね…。
CRuby って Visual Studio でビルド出来るようにできるんかな…。
あと難しかった点とはちょっと違うんですが、

  • cref->refinements の寿命がわからない
  • キャッシュ化を無効

のあたりはどうすればよくわからなかったですね…。
このあたりはもっと詳しい人に頼むしかなさそう。

ちなみに今回の修正コードを書いてて一番つらかったことは Ruby コミッタの人に今回の事を相談したらだいたい苦虫を噛み潰したような顔をされたことですね…。
まあしゃーないんだけどもみんなもっと Refinements 使おうぜ!という気持ちにはなりました。

追記

今回は &:twice の呼び出しに関して修正してみました Symbol#to_proc は今回とはまた別の箇所で実装されているのでそのままです。 こっちも別で修正する必要があります。