Translating Lua code directly into C++ can cause a problem: Lua has proper tail calls, while C++ does not guarantee them.
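For example, in the function below (which sums the values of all nodes in a binary tree), the second recursive call to `Visit` is a tail call. A direct translation into C++ loses part of that optimization. (In Lua, a call is a proper tail call only when it is written as `return f(args)`, which is why the second call is written that way.)

```lua
function SumTree(root)
    local Sum = 0
    local function Visit(CurNode)
        if not CurNode then
            return
        end
        Sum = Sum + CurNode.val
        Visit(CurNode.left)
        return Visit(CurNode.right) -- tail call
    end
    Visit(root)
    return Sum
end
```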
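For comparison, here is a minimal sketch of that direct translation. It assumes a LeetCode-style `TreeNode` (defined here only so the example is self-contained). Both calls to `Visit` push a C++ stack frame; an optimizing compiler may turn the last call into a jump, but the language does not guarantee it.

```cpp
#include <cassert>
#include <functional>

// Assumed node type, modeled on LeetCode's TreeNode.
struct TreeNode {
    int val;
    TreeNode* left;
    TreeNode* right;
};

// Direct translation of the Lua SumTree: a recursive local function.
int SumTree(TreeNode* root)
{
    int Sum = 0;
    std::function<void(TreeNode*)> Visit;
    Visit = [&](TreeNode* CurNode) {
        if (!CurNode)
        {
            return;
        }
        Sum += CurNode->val;
        Visit(CurNode->left);
        Visit(CurNode->right); // tail position, but tail-call elimination is not guaranteed
    };
    Visit(root);
    return Sum;
}
```

For a three-node tree with values 1, 2 and 3, `SumTree` returns 6.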
There is a way to turn this code into efficient C++ that keeps the effect of the tail call while avoiding most of the function-call overhead.
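Step 1: apply a CPS (continuation-passing style) transform, which turns the code into the form below. Every function call is now a tail call, and each function takes an extra parameter, `Cont`. At run time the system stack does not grow; only `Cont` grows, each time a `left` child is visited. In effect, the transform replaces the system stack with a user-defined `Cont` while preserving the tail-call behavior.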
```lua
function SumTree(root)
    local Sum = 0
    local function ApplyCont(Cont)
        return Cont()
    end
    local function ID()
    end
    local function Visit(CurNode, Cont)
        if not CurNode then
            return ApplyCont(Cont)
        end
        Sum = Sum + CurNode.val
        local Cont1 = function()
            return Visit(CurNode.right, Cont)
        end
        return Visit(CurNode.left, Cont1) -- Cont grows here
    end
    Visit(root, ID)
    return Sum
end
```
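The same CPS structure can be sketched in C++ with `std::function<void()>` as the continuation type (an illustrative choice, along with the assumed `TreeNode` type). Because C++ does not guarantee tail calls, this version still consumes the system stack; it only shows how the work that follows the left recursion is packaged into `Cont1`.

```cpp
#include <cassert>
#include <functional>

// Assumed node type, modeled on LeetCode's TreeNode.
struct TreeNode {
    int val;
    TreeNode* left;
    TreeNode* right;
};

// A continuation: the rest of the work to do once the current call finishes.
using Cont = std::function<void()>;

int SumTree(TreeNode* root)
{
    int Sum = 0;
    std::function<void(TreeNode*, Cont)> Visit;
    Visit = [&](TreeNode* CurNode, Cont K) -> void {
        if (!CurNode)
        {
            return K(); // ApplyCont: run the pending work
        }
        Sum += CurNode->val;
        // Cont1 captures the right child and the outer continuation K.
        Cont Cont1 = [&Visit, Right = CurNode->right, K]() { Visit(Right, K); };
        return Visit(CurNode->left, Cont1); // Cont grows here
    };
    Visit(root, [] {}); // ID: the empty continuation
    return Sum;
}
```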
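Step 2: represent `Cont` with a custom data structure. The main point is to separate the free variables captured by the `Cont` closure from the closure's code (this is essentially defunctionalization).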
```lua
function SumTree(root)
    local Sum = 0
    local Visit
    local ActionMap = {
        function(FreeVars, Cont)
            return Visit(FreeVars[1], Cont)
        end,
    }
    local function ApplyCont(Cont)
        if not Cont then
            return
        end
        local Action = ActionMap[Cont.ActionIndex]
        return Action(Cont.FreeVars, Cont.Cont)
    end
    local ID = nil
    Visit = function(CurNode, Cont)
        if not CurNode then
            return ApplyCont(Cont)
        end
        Sum = Sum + CurNode.val
        local Cont1 = {
            FreeVars = {CurNode.right},
            ActionIndex = 1,
            Cont = Cont,
        }
        return Visit(CurNode.left, Cont1)
    end
    Visit(root, ID)
    return Sum
end
```
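A C++ sketch of step 2, assuming continuations are heap nodes owned by `std::shared_ptr` (a choice made here for brevity; `ContNode`, `ContPtr` and the `TreeNode` type are illustrative names). `ContNode` plays the role of the Lua table `{FreeVars, ActionIndex, Cont}`, and `ApplyCont` dispatches on `ActionIndex` instead of calling a closure. Like the previous stage, it still recurses on the C++ stack.

```cpp
#include <cassert>
#include <functional>
#include <memory>

// Assumed node type, modeled on LeetCode's TreeNode.
struct TreeNode {
    int val;
    TreeNode* left;
    TreeNode* right;
};

// Defunctionalized continuation: plain data instead of a closure.
struct ContNode {
    int ActionIndex;                // which piece of code to run
    TreeNode* FreeVar;              // the captured CurNode.right
    std::shared_ptr<ContNode> Next; // outer continuation (nullptr plays the role of ID)
};
using ContPtr = std::shared_ptr<ContNode>;

int SumTree(TreeNode* root)
{
    int Sum = 0;
    std::function<void(TreeNode*, ContPtr)> Visit;
    std::function<void(ContPtr)> ApplyCont = [&](ContPtr K) -> void {
        if (!K)
        {
            return; // ID continuation: nothing left to do
        }
        switch (K->ActionIndex)
        {
        case 1: // ActionMap[1]: visit the saved right child
            return Visit(K->FreeVar, K->Next);
        }
    };
    Visit = [&](TreeNode* CurNode, ContPtr K) -> void {
        if (!CurNode)
        {
            return ApplyCont(K);
        }
        Sum += CurNode->val;
        ContPtr Cont1 = std::make_shared<ContNode>(ContNode{1, CurNode->right, K});
        return Visit(CurNode->left, Cont1);
    };
    Visit(root, nullptr);
    return Sum;
}
```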
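Step 3: `Cont` in the code above is a simple linked list that is only ever modified at one end, so it can be replaced by an external `Stack`, and the `Cont` parameter can be removed from every function. At the same time, inline the code from `ActionMap` into `ApplyCont`. This example has only one kind of action, so `ActionIndex` is not strictly necessary.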
```lua
function SumTree(root)
    local Sum = 0
    local Visit
    local Stack = {}
    local function ApplyCont()
        if #Stack == 0 then
            return
        end
        local Top = Stack[#Stack]
        Stack[#Stack] = nil
        if Top.ActionIndex == 1 then
            return Visit(Top.FreeVars[1])
        end
    end
    Visit = function(CurNode)
        if not CurNode then
            return ApplyCont()
        end
        Sum = Sum + CurNode.val
        table.insert(Stack, {
            FreeVars = {CurNode.right},
            ActionIndex = 1,
        })
        return Visit(CurNode.left)
    end
    Visit(root)
    return Sum
end
```
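Step 3 can be sketched in C++ as well, assuming a `std::vector` as the stack and an illustrative `Frame` struct whose fields mirror the Lua table (`FreeVar` stands for `FreeVars[1]`). `Visit` and `ApplyCont` no longer take a continuation parameter, but they still call each other recursively; the next two steps remove that.

```cpp
#include <cassert>
#include <functional>
#include <vector>

// Assumed node type, modeled on LeetCode's TreeNode.
struct TreeNode {
    int val;
    TreeNode* left;
    TreeNode* right;
};

int SumTree(TreeNode* root)
{
    struct Frame {
        int ActionIndex;
        TreeNode* FreeVar; // FreeVars[1] in the Lua version
    };
    int Sum = 0;
    std::vector<Frame> Stack; // replaces the Cont linked list
    std::function<void(TreeNode*)> Visit;
    std::function<void()> ApplyCont = [&]() -> void {
        if (Stack.empty())
        {
            return;
        }
        Frame Top = Stack.back();
        Stack.pop_back();
        if (Top.ActionIndex == 1)
        {
            return Visit(Top.FreeVar);
        }
    };
    Visit = [&](TreeNode* CurNode) -> void {
        if (!CurNode)
        {
            return ApplyCont();
        }
        Sum += CurNode->val;
        Stack.push_back({1, CurNode->right}); // push instead of allocating a Cont
        return Visit(CurNode->left);
    };
    Visit(root);
    return Sum;
}
```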
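Step 4: once a function makes a tail call, the lifetimes of its parameters and local variables are over, so they can be replaced by a few external "registers". The structure of the `Stack` entries is also adjusted; the `FreeVars` field in the previous version was mainly there to make the idea easier to follow.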
```lua
function SumTree(root)
    local Sum = 0
    local Stack = {}
    local R_CurNode, R_Temp
    local Visit
    local function ApplyCont()
        if #Stack == 0 then
            return
        end
        R_Temp = Stack[#Stack]
        Stack[#Stack] = nil
        if R_Temp.ActionIndex == 1 then
            R_CurNode = R_Temp.Node
            return Visit()
        end
    end
    Visit = function()
        if not R_CurNode then
            return ApplyCont()
        end
        Sum = Sum + R_CurNode.val
        table.insert(Stack, {
            ActionIndex = 1,
            Node = R_CurNode.right,
        })
        R_CurNode = R_CurNode.left
        return Visit()
    end
    R_CurNode = root
    Visit()
    return Sum
end
```
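Step 5: translate the code above into C++. Every function now has no parameters, no return value and no local variables, and every call is a tail call, so each call can be replaced directly by a `goto` statement. When there are several kinds of actions, `ActionIndex` becomes necessary.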
```cpp
// Assumes LeetCode's TreeNode, <vector> and `using namespace std;`.
int SumTree(TreeNode* root)
{
    int Sum = 0;
    struct StackNode
    {
        int ActionIndex;
        TreeNode* Node;
    };
    vector<StackNode> Stack;
    TreeNode* R_CurNode = root; // corresponds to the first call to Visit above
    goto L_Visit;
L_ApplyCont:
    if (Stack.empty())
    {
        return Sum; // the final return has moved here
    }
    // Small adjustment: R_Temp is no longer needed
    switch (Stack.back().ActionIndex)
    {
    case 1:
        R_CurNode = Stack.back().Node;
        Stack.pop_back();
        goto L_Visit;
        // break; not needed
    }
L_Visit:
    if (!R_CurNode)
    {
        goto L_ApplyCont;
    }
    Sum += R_CurNode->val;
    Stack.push_back({1, R_CurNode->right});
    R_CurNode = R_CurNode->left;
    goto L_Visit;
}
```
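Because every transfer of control in the code above is an argument-free `goto`, the same control flow can also be written as a loop over an explicit state variable. The sketch below is such a restatement for illustration, not the author's code; the names `SumTreeLoop` and `Label` and the `TreeNode` definition are assumptions. Each `case` corresponds to one label, and assigning `Next` corresponds to a `goto`.

```cpp
#include <cassert>
#include <vector>

// Assumed node type, modeled on LeetCode's TreeNode.
struct TreeNode {
    int val;
    TreeNode* left;
    TreeNode* right;
};

int SumTreeLoop(TreeNode* root)
{
    struct StackNode {
        int ActionIndex;
        TreeNode* Node;
    };
    enum class Label { Visit, ApplyCont };

    int Sum = 0;
    std::vector<StackNode> Stack;
    TreeNode* R_CurNode = root;
    Label Next = Label::Visit; // the label we would "goto"
    for (;;)
    {
        switch (Next)
        {
        case Label::ApplyCont:
            if (Stack.empty())
            {
                return Sum;
            }
            // Only ActionIndex 1 exists here: resume at the saved right child.
            R_CurNode = Stack.back().Node;
            Stack.pop_back();
            Next = Label::Visit;
            break;
        case Label::Visit:
            if (!R_CurNode)
            {
                Next = Label::ApplyCont;
                break;
            }
            Sum += R_CurNode->val;
            Stack.push_back({1, R_CurNode->right});
            R_CurNode = R_CurNode->left;
            break; // Next stays Label::Visit
        }
    }
}
```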
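Finally, let's use the LeetCode problem N-ary Tree Level Order Traversal to check the correctness and the performance gain of this method. The clever solution to this problem is an iterative algorithm built on a queue. The fastest run time listed for LeetCode's official iterative solution is 40 ms, but when I submitted the same code it ran in roughly 45 to 70 ms.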
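For reference, a sketch of the queue-based iterative approach (written here for illustration; it is not LeetCode's official code). It assumes a `Node` type modeled on LeetCode's N-ary tree node, with `val` and a `children` vector; the function name `levelOrderBfs` is hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <queue>
#include <vector>

// Assumed node type, modeled on LeetCode's N-ary tree Node.
struct Node {
    int val;
    std::vector<Node*> children;
};

// Breadth-first traversal: process the tree one level at a time.
std::vector<std::vector<int>> levelOrderBfs(Node* root)
{
    std::vector<std::vector<int>> Results;
    if (!root)
    {
        return Results;
    }
    std::queue<Node*> Queue;
    Queue.push(root);
    while (!Queue.empty())
    {
        // Everything currently in the queue belongs to the same level.
        std::size_t LevelSize = Queue.size();
        Results.emplace_back();
        for (std::size_t i = 0; i < LevelSize; ++i)
        {
            Node* Cur = Queue.front();
            Queue.pop();
            Results.back().push_back(Cur->val);
            for (Node* Child : Cur->children)
            {
                if (Child)
                {
                    Queue.push(Child);
                }
            }
        }
    }
    return Results;
}
```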
First, a naive recursive implementation. The code below runs in roughly 60 to 100 ms, with a best time of 58 ms.
```cpp
vector<vector<int>> levelOrder(Node* root) {
    vector<vector<int>> Results;
    function<void(Node*, int)> Visit;
    Visit = [&](Node* CurNode, int Depth)
    {
        if (!CurNode)
        {
            return;
        }
        if (Results.size() < Depth)
        {
            Results.resize(Depth);
        }
        Results[Depth - 1].push_back(CurNode->val);
        for (auto Child : CurNode->children)
        {
            Visit(Child, Depth + 1);
        }
    };
    Visit(root, 1);
    return Results;
}
```
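Applying the method above then converts it into the code below, which passes all test cases and runs in roughly 45 to 85 ms, with a best time of 44 ms. This recursive implementation contains no tail calls, so the speedup comes mainly from replacing function calls with `goto` statements.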
```cpp
vector<vector<int>> levelOrder(Node* root) {
    vector<vector<int>> Results;
    struct StackNode {
        int ChildIndex;
        Node* CurNode;
        int CurDepth;
    };
    vector<StackNode> Stack;
    int R_ChildIndex;
    Node* R_CurNode = root;
    int R_CurDepth = 1;
    goto L_Recursive;
L_ApplyCont:
    if (Stack.empty())
    {
        return Results;
    }
    R_ChildIndex = Stack.back().ChildIndex;
    R_CurNode = Stack.back().CurNode;
    R_CurDepth = Stack.back().CurDepth;
    Stack.pop_back();
    goto L_LoopChild;
L_LoopChild:
    if (R_ChildIndex < R_CurNode->children.size())
    {
        Stack.push_back({
            R_ChildIndex + 1, R_CurNode, R_CurDepth
        });
        R_CurNode = R_CurNode->children[R_ChildIndex];
        R_CurDepth++;
        goto L_Recursive;
    }
    else
    {
        goto L_ApplyCont;
    }
L_Recursive:
    if (!R_CurNode)
    {
        goto L_ApplyCont;
    }
    if (Results.size() < R_CurDepth)
    {
        Results.resize(R_CurDepth);
    }
    Results[R_CurDepth - 1].push_back(R_CurNode->val);
    R_ChildIndex = 0;
    goto L_LoopChild;
}
```
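The original check was done by submitting to LeetCode. One way to reproduce the correctness part locally is to run both versions on the same tree and compare their results. The harness below restates the two functions from this section under the hypothetical names `levelOrderRecursive` and `levelOrderGoto`, with an assumed `Node` type; the goto version is lightly adapted to use `std::size_t` indices, which avoids signed/unsigned comparisons. It is a minimal test sketch, not part of the original article.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <vector>

// Assumed node type, modeled on LeetCode's N-ary tree Node.
struct Node {
    int val;
    std::vector<Node*> children;
};

// Naive recursive version (same logic as the first levelOrder above).
std::vector<std::vector<int>> levelOrderRecursive(Node* root)
{
    std::vector<std::vector<int>> Results;
    std::function<void(Node*, int)> Visit;
    Visit = [&](Node* CurNode, int Depth) {
        if (!CurNode)
        {
            return;
        }
        if (Results.size() < static_cast<std::size_t>(Depth))
        {
            Results.resize(Depth);
        }
        Results[Depth - 1].push_back(CurNode->val);
        for (Node* Child : CurNode->children)
        {
            Visit(Child, Depth + 1);
        }
    };
    Visit(root, 1);
    return Results;
}

// goto version (same control flow as the second levelOrder above).
std::vector<std::vector<int>> levelOrderGoto(Node* root)
{
    struct StackNode {
        std::size_t ChildIndex;
        Node* CurNode;
        int CurDepth;
    };
    std::vector<std::vector<int>> Results;
    std::vector<StackNode> Stack;
    std::size_t R_ChildIndex = 0;
    Node* R_CurNode = root;
    int R_CurDepth = 1;
    goto L_Recursive;
L_ApplyCont:
    if (Stack.empty())
    {
        return Results;
    }
    R_ChildIndex = Stack.back().ChildIndex;
    R_CurNode = Stack.back().CurNode;
    R_CurDepth = Stack.back().CurDepth;
    Stack.pop_back();
L_LoopChild:
    if (R_ChildIndex < R_CurNode->children.size())
    {
        Stack.push_back({R_ChildIndex + 1, R_CurNode, R_CurDepth});
        R_CurNode = R_CurNode->children[R_ChildIndex];
        ++R_CurDepth;
        goto L_Recursive;
    }
    goto L_ApplyCont;
L_Recursive:
    if (!R_CurNode)
    {
        goto L_ApplyCont;
    }
    if (Results.size() < static_cast<std::size_t>(R_CurDepth))
    {
        Results.resize(R_CurDepth);
    }
    Results[R_CurDepth - 1].push_back(R_CurNode->val);
    R_ChildIndex = 0;
    goto L_LoopChild;
}
```

On the tree `1 -> {3 -> {5, 6}, 2, 4}`, both functions produce `{{1}, {3, 2, 4}, {5, 6}}`.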
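One flexible aspect of this conversion method is that function calls which do not need optimizing can be treated as primitive operations (like `+ - * /`); the generated C++ code then simply calls those Lua functions.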
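To illustrate, here is a variant of the step-5 `SumTree` in which each node's contribution comes from an ordinary function call, treated as a primitive operation like `+`. `Weight` and `SumWeightedTree` are hypothetical names invented for this sketch, and `Weight` is a C++ stand-in for a Lua function that is left untransformed. It is called normally and returns before the next `goto`, so it needs no CPS treatment.

```cpp
#include <cassert>
#include <vector>

// Assumed node type, modeled on LeetCode's TreeNode.
struct TreeNode {
    int val;
    TreeNode* left;
    TreeNode* right;
};

// Hypothetical helper: an ordinary call that runs to completion on the C++ stack.
int Weight(const TreeNode* Node)
{
    return Node->val * Node->val;
}

int SumWeightedTree(TreeNode* root)
{
    struct StackNode {
        int ActionIndex;
        TreeNode* Node;
    };
    int Sum = 0;
    std::vector<StackNode> Stack;
    TreeNode* R_CurNode = root;
    goto L_Visit;
L_ApplyCont:
    if (Stack.empty())
    {
        return Sum;
    }
    switch (Stack.back().ActionIndex)
    {
    case 1:
        R_CurNode = Stack.back().Node;
        Stack.pop_back();
        goto L_Visit;
    }
L_Visit:
    if (!R_CurNode)
    {
        goto L_ApplyCont;
    }
    Sum += Weight(R_CurNode); // ordinary call, treated like a primitive operation
    Stack.push_back({1, R_CurNode->right});
    R_CurNode = R_CurNode->left;
    goto L_Visit;
}
```

For a tree with values 1, 2 and 3, the result is 1 + 4 + 9 = 14.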