使用 ASM 在 Java 中实现尾调用递归 (2023)
UnlinkedList
思考、想法、观察
使用 ASM 在 Java 中实现尾调用递归
一些编译器提供的一种优化是 尾调用优化。 这种优化带来的好处并不多,因为程序员总是可以调整他的代码而无需递归,尤其是在命令式语言中。 另一方面,递归代码通常更优雅,所以为什么我们不让编译器在可能的情况下做那些令人讨厌的事情呢? 在本文中,我将介绍一种使用 ASM 进行字节码操作在 Java 中实现尾调用优化的简洁方法。
什么是尾调用递归?
尾调用递归是一种特殊的递归形式,其中最后一个操作是递归调用。 它属于更广泛的尾调用类别,其中最后一个操作是方法调用。 我将自己限制在更严格的尾递归案例中。 让我们用一个例子来说明。
long factorial(int n, long f) {
if (n < 2) {
return f;
}
return factorial(n - 1, f * n);
}
可以看出,最后一个操作是对同一函数的调用,但参数不同。 下面的例子不是尾递归。
long factorial(int n) {
if (n < 2) {
return f;
}
return n * factorial(n - 1);
}
前一个示例不是尾调用递归的原因是,最后一个操作不是递归调用,而是乘法运算。 乘法运算发生在递归调用返回之后。
尾递归具有特定的形式,可以通过避免分配新的栈帧来实现更快的执行,因为执行可以利用当前的栈。
方法调用的剖析
如果您不太了解 Java 虚拟机如何进行方法调用,这是一个简要概述。 这个想法在编程中几乎是通用的,但是这里提出的细节是 JVM 特有的。
为了使方法能够执行,它需要一个称为帧的空间,其中应包含一些特定的内容:
- 局部变量空间:一个固定大小的条目数组,具有各种类型的值
- 操作数栈:一个用于存储当前操作数的栈
JVM 还会管理一个执行栈。 JVM 执行栈收集帧。 当调用一个方法来执行时,会创建一个新的帧,正确初始化并推入 JVM 执行栈。 方法调用的最终参数从当前栈中收集,并用于初始化新的帧。 在方法执行结束后,收集返回值(如果有),为该方法调用分配的帧将从 JVM 栈中移除,引用先前的帧,并将收集到的返回值推入栈中。
局部变量和操作数栈部分的大小取决于方法的代码,它在编译时计算出来,并与编译后的类中的字节码指令一起存储。 对应于同一方法调用的所有帧的大小都相同,但对应于不同方法的帧的大小可能不同。
创建帧时,会使用空栈初始化该帧,并且其局部变量会使用目标对象 this (对于非静态方法)和方法的参数(按此顺序)初始化。 例如,调用方法 a.equals(b) 会创建一个帧,其中有一个空栈,并且前两个局部变量初始化为 a 和 b (其他局部变量未初始化)。
正如您所看到的,以链式方式调用方法会增加 JVM 执行栈所需的空间,因为对于每个内部方法调用,都会将一个新的帧推入栈中。 由于栈是有限的,因此多次调用嵌套方法可能会填充 JVM 执行栈,直到抛出类似 StackOverflowError 的异常,您可能听说过这种异常。
当以递归方式实现算法时,经常会遇到耗尽栈空间的情况。
编译后的类的实际示例
让我们研究一下以下类生成的代码。
public class Factorial implements FactorialInterface {
public long fact(int n) {
return factTailRec(n, 1L);
}
private long factTailRec(int n, long ret) {
if (n < 1) {
return ret;
}
return factTailRec(n - 1, ret * n);
}
}
正如你所看到的,我们有一个简单的类,它以尾递归的方式实现了阶乘函数。 第一个方法是一个 facade,用于提供更好的用户体验,而第二个方法实现了实际的计算。
生成的类如下所示
// class version 63.0 (63)
// access flags 0x21
public class rapaio/experiment/asm/Factorial implements rapaio/experiment/asm/FactorialInterface {
// compiled from: Factorial.java
// access flags 0x1
public <init>()V
L0
LINENUMBER 34 L0
ALOAD 0
INVOKESPECIAL java/lang/Object.<init> ()V
RETURN
L1
LOCALVARIABLE this Lrapaio/experiment/asm/Factorial; L0 L1 0
MAXSTACK = 1
MAXLOCALS = 1
// access flags 0x1
public fact(I)J
L0
LINENUMBER 37 L0
ALOAD 0
ILOAD 1
LCONST_1
INVOKEVIRTUAL rapaio/experiment/asm/Factorial.factTailRec (IJ)J
LRETURN
L1
LOCALVARIABLE this Lrapaio/experiment/asm/Factorial; L0 L1 0
LOCALVARIABLE n I L0 L1 1
MAXSTACK = 4
MAXLOCALS = 2
// access flags 0x2
private factTailRec(IJ)J
L0
LINENUMBER 41 L0
ILOAD 1
ICONST_1
IF_ICMPGE L1
L2
LINENUMBER 42 L2
LLOAD 2
LRETURN
L1
LINENUMBER 44 L1
FRAME SAME
ALOAD 0
ILOAD 1
ICONST_1
ISUB
LLOAD 2
ILOAD 1
I2L
LMUL
INVOKEVIRTUAL rapaio/experiment/asm/Factorial.factTailRec (IJ)J
LRETURN
L3
LOCALVARIABLE this Lrapaio/experiment/asm/Factorial; L0 L3 0
LOCALVARIABLE n I L0 L3 1
LOCALVARIABLE ret J L0 L3 2
MAXSTACK = 6
MAXLOCALS = 4
}
让我们解释一下字节码。 我们有一个实现接口的类,并且我们有一个默认构造函数。
默认构造函数在 Java 语言中是可选的,但在 JVM 字节码中不是。 默认构造函数的第一个指令是 ALOAD 0 ,它从索引为 0 的局部变量(记住局部变量是如何初始化的)将 this 指针加载到操作栈上。 第二条指令调用类 Object 的方法 init 。 此方法采用一个参数,该参数是指向对象的指针。 此指针取自操作栈(这就是我们有第一条指令的原因)。 它返回 void ,因为它在它的描述中用 V 表示。 之后,它只是将控制权返回给先前的调用。 请注意,它遵循初始帧的描述,其中包含使用的变量、为操作数栈和局部变量表分配的条目数。
接下来是方法 fact 的描述,该方法采用一个整数参数(用 I 描述)并返回一个 long 值(用 J 描述)。 指令列表以加载 this 指针 (ALOAD 0)、参数的值 (ILOAD 1) 和将常量 long 值 1 (LCONST_1) 推送到操作数栈上开始。 所有三个加载到操作数栈上的值都由以下指令以相反的顺序调用方法 factTailRect 使用(它是一个栈)。 在方法调用结束后,它的返回值会被推送到操作数栈上。 该值被下一条指令 LRETURN 使用,该指令从操作数栈返回 long 值并将控制权恢复给调用方法。
方法 factTailRect 的描述稍微复杂一些,但并没有过于复杂。 将第一个参数的值加载到操作数栈上 (ILOAD 1),并将整数常量 1 (ICONST_1) 加载到操作数栈上。 下一条指令将参数的值与常量进行比较 (IF_ICMPGE L1),如果变量大于或等于常量,则转到标签 L1 ,否则继续。 如果比较失败,则将第二个变量的值加载到操作数栈上 (LLOAD 2) 并返回它 (LRETURN)。
在标签 L1 处,有用于递归调用的指令。 首先,将 this 指针加载到栈上,以准备进行递归调用 (ALOAD 0)。 然后将第一个变量加载到栈上 (ILOAD 1) 和常量 1 (ICONST_1)。 这两个值被减法运算使用,该运算用常量的值减少变量的值 (ISUB)。 返回值被放入栈中。 请注意,此时在栈上我们有两个值: this 指针的值和变量的减少值(减法指令从栈中弹出两个操作数并将一个结果放回)。
接下来,将第二个变量放入栈中 (LLOAD 2),并将第一个变量的值放入栈中 (ILOAD 1)。 第一个变量仍然具有原始值,修改后的值在栈上。 整数变量被转换为 long 类型 (I2L),并且两个值相乘 (LMUL)。 乘法使用最后两个栈操作数并将乘法的结果推回。
现在,操作数栈有三个值,递归调用函数所需的值 (INVOKEVIRTUAL)。 这三个操作数被方法调用消耗,结果被放入栈中。 返回最终结果 (LRETURN)。 最后一部分是帧的描述,该帧包含三个局部变量,并且在栈上分配了 6 个位置,在局部变量表上分配了 4 个位置。
我希望你没有厌倦阅读所有这些。 也许你知道 JVM 栈机器是如何工作的,但我已经把这个描述放在如果你不知道的情况下。
尾调用递归的结构
一般来说, 尾调用递归 有一个非常简单的结构。 尾递归方法有三个阶段。 第一阶段是停止规则。 这些规则定义了递归何时结束。 第二阶段包含计算,第三阶段是返回结果的递归调用。 通常,当计算很简单时,最后两个阶段会合并为一个阶段,其中计算发生在传递参数值之前。 我们的方法就是这种情况。
有了这个清晰的设计,可以进行一些观察,从而可以进行简单的优化。
第一个观察结果是,由于调用的是相同的方法,因此递归调用的帧的形状与当前帧相同。 原因是帧的形状在编译时确定,并且保持固定。 由于我们递归地调用相同的方法,因此我们确定当前帧适合递归调用的需要。 可能的优化是避免创建必须在每次调用时都推入 JVM 执行栈的新帧。
第二个观察结果是,为了重用当前帧,我们需要以与发生适当的帧初始化相同的方式准备栈和局部变量。 然而,这很容易做到,仅仅因为返回之前的最后一个调用是递归调用。 为了进行递归调用,需要用 this 指针的值和所有参数值填充栈。 该调用会将所有这些值从当前操作栈中弹出,并使用这些值初始化下一个帧。 这就是我们必须做的。 简单地获取所有这些值并正确初始化当前帧的局部变量。 这些值已经为我们准备好了。
使用 ASM 转换字节码
ASM 是一个很棒且简洁的库,它允许人们在编译时或运行时分析、转换和生成字节码。 它被许多平台和工具使用,包括 OpenJDK 编译器本身。 我没有足够的词语来描述这个库的有用性和优雅性,我非常感谢它的创建者和贡献者。
ASM 库允许使用两种方法转换字节码:基于事件的和基于树的。 我将使用基于树的 API,因为这些更改并非微不足道,并且无法在解析器的单次传递中执行。 这是用于优化尾递归方法的代码:
class TailRecTransformer extends ClassNode {
private static final String METHOD_SUFFIX = "TailRec";
public TailRecTransformer(ClassVisitor cv) {
super(ASM9);
this.cv = cv;
}
@Override
public void visitEnd() {
// we optimize all methods which ends with TailRec for simplicity
methods.stream().filter(mn -> mn.name.endsWith(METHOD_SUFFIX))
.forEach(this::transformTailRec);
accept(cv);
}
void transformTailRec(MethodNode methodNode) {
// method argument types
Type[] argumentTypes = Type.getArgumentTypes(methodNode.desc);
// iterator over instructions
var it = methodNode.instructions.iterator();
LabelNode firstLabel = null;
while (it.hasNext()) {
var inode = it.next();
// locate the first label
// this label will be used to jump instead of recursive call
if (firstLabel == null && inode instanceof LabelNode labelNode) {
firstLabel = labelNode;
continue;
}
if (inode instanceof FrameNode) {
// remove all frames since we recompute them all at writing
it.remove();
continue;
}
if (inode instanceof MethodInsnNode methodInsnNode &&
methodInsnNode.name.equals(methodNode.name) &&
methodInsnNode.desc.equals(methodNode.desc)) {
// find the recursive call which has to have
// same signature and be followed by return
// check if the next instruction is return of proper type
var nextInstruction = it.next();
Type returnType = Type.getReturnType(methodNode.desc);
if (!(nextInstruction.getOpcode() ==
returnType.getOpcode(Opcodes.IRETURN))) {
continue;
}
// remove the return and recursive call from instructions
it.previous();
it.previous();
it.remove();
it.next();
it.remove();
// pop values from stack and store them in local
// variables in reverse order
for (int i = argumentTypes.length - 1; i >= 0; i--) {
Type type = argumentTypes[i];
it.add(new VarInsnNode(type.getOpcode(Opcodes.ISTORE), i + 1));
}
// add a new jump instruction to the first label
it.add(new JumpInsnNode(Opcodes.GOTO, firstLabel));
// finally remove the instruction which loaded 'this'
// since it was required by the recursive call
while (it.hasPrevious()) {
AbstractInsnNode node = it.previous();
if (node instanceof VarInsnNode varInsnNode) {
if (varInsnNode.getOpcode() == Opcodes.ALOAD &&
varInsnNode.var == 0) {
it.remove();
// we remove only the last instruction of this kind
// we don't touch it other similar instructions
// to not break the existent code
break;
}
}
}
}
}
}
}
我真的希望代码和注释是自包含的。 为了保持一致性,我将简要介绍一下它的逻辑。
为了使用 ASM 库的树 API 转换方法,需要更改类 MethodNode 中的值,因为这是 JVM 字节码在 ASM 库中的表示形式。 为了简单起见,我创建了一个转换器,它尝试优化所有名称以后缀 TailRec 结尾的方法。 这是为了说明目的,使用 annotation 会更好,但需要更多代码和构建一个 agent。
优化逻辑的核心在于方法 transformTailRec 。 此方法接收任何名称以我们的后缀结尾的类方法的相应字节码表示。 优化包括以下阶段。
我们确定第一个代码标签。 这是递归方法的代码的开始。 当我们用简单的跳转指令替换递归调用时,我们将使用此标签。 此跳转指令是 goto 。 有趣的是,由于充分的理由,此臭名昭著的指令在 Java 语言中不存在。 这种不受控制的跳转会破坏 JVM 的所有记帐机制。 但是,JVM 中存在相同的指令。 因为在 JVM 中,我们只能在来自同一方法调用的一组指令中跳转,所以可以安全地使用它。
我们将重用当前帧,而不是创建新帧的递归方法调用。 下一个阶段是删除递归调用和之后的 return 指令,同时准备局部变量和栈以供下次使用。 我们在递归调用的位置引入一个指向第一个标签的 goto 指令。 基本上,我们实现了一个 while 循环。 停止条件已经在代码中,所以我们不会因为优化而获得无限循环。
我们完成了!
测试递归尾部优化
对此的完整处理将意味着实现一个 Java agent,该 agent 会在类加载之前优化代码。 我避免了这些复杂情况,因为它与主题无关。 也许将来我会创建一个带有此 annotation 和优化的微型 github 项目。
为了使事情简单,我编写了一个自定义类加载器,它使用优化的代码创建类。 如果这些类由不同的类加载器加载,则 Java 允许拥有两个具有相同规范的类。 为了易于使用它们,我还创建了一个接口。
这样,我们将有两个类,一个已优化,另一个未优化,并且两者都实现相同的接口。 这样我们就可以在同一个 JVM 实例中使用它们并使用 JMH 测试它们。 作为参考,下面列出了类加载器的代码。
public class CustomClassLoader extends ClassLoader {
private final boolean verbose;
public CustomClassLoader(boolean verbose) {
this.verbose = verbose;
}
@Override
protected Class<?> findClass(String name) {
ClassWriter cw = new ClassWriter(0);
ClassVisitor lastCv;
if (verbose) {
TraceClassVisitor beforeTcv = new TraceClassVisitor(cw, new PrintWriter(System.out));
TailRecTransformer trt = new TailRecTransformer(beforeTcv);
lastCv = new TraceClassVisitor(trt, new PrintWriter(System.out));
} else {
lastCv = new TailRecTransformer(cw);
}
ClassReader cr;
try {
cr = new ClassReader(name);
} catch (IOException e) {
throw new RuntimeException(e);
}
cr.accept(lastCv, 0);
byte[] buffer = cw.toByteArray();
return defineClass(name, buffer, 0, buffer.length);
}
public <T> T newTailRecInstance(Class<T> external, Class<?> internal) throws NoSuchMethodException,
InvocationTargetException, InstantiationException, IllegalAccessException {
Class<?> c = findClass(internal.getCanonicalName());
return (T) c.getConstructor().newInstance();
}
}
阶乘 JMH 基准测试
我实现了两个简单的递归方法调用。 第一个已经介绍过,它是阶乘。
public class Factorial implements FactorialInterface {
public long fact(int n) {
return factTailRec(n, 1L);
}
private long factTailRec(int n, long ret) {
if (n < 1) {
return ret;
}
ret *= n;
n -= 1;
return factTailRec(n, ret);
}
}
JMH 基准测试结果如下:
Benchmark (n) Mode Cnt Score Error Units
TailRec.recursiveFact 1 thrpt 5771.714 ± 9.722 ops/us
TailRec.recursiveFact 3 thrpt 5242.958 ± 1.693 ops/us
TailRec.recursiveFact 5 thrpt 5194.606 ± 2.418 ops/us
TailRec.recursiveFact 10 thrpt 590.850 ± 2.345 ops/us
TailRec.recursiveFact 15 thrpt 566.567 ± 0.898 ops/us
TailRec.recursiveFact 20 thrpt 548.615 ± 0.308 ops/us
TailRec.recursiveFactTailRec 1 thrpt 5735.701 ± 4.936 ops/us
TailRec.recursiveFactTailRec 3 thrpt 5512.596 ± 0.946 ops/us
TailRec.recursiveFactTailRec 5 thrpt 5409.343 ± 3.884 ops/us
TailRec.recursiveFactTailRec 10 thrpt 5263.263 ± 3.033 ops/us
TailRec.recursiveFactTailRec 15 thrpt 5184.061 ± 2.992 ops/us
TailRec.recursiveFactTailRec 20 thrpt 5133.968 ± 1.070 ops/us
区别很明显。 优化版本更快。 然而,差异不大。 这仅仅是因为递归调用的数量很少,必须很少才能不产生整数溢出。
Sum JMH 基准测试
为了说明目的,我以尾递归的方式实现了数组值的总和。 当然,这不是最好的选择,但如果容器是链表,那么它将是函数式风格中一种有吸引力的实现。 下面是 sum 方法的实现。
public class Sum implements SumInterface {
public int sum(int[] array) {
return sumTailRec(array, 0, 0);
}
public int sumTailRec(int[] array, int i, int sum) {
if (i >= array.length) {
return sum;
}
return sumTailRec(array, i + 1, sum + array[i]);
}
}
下面是我们有 JMH 基准测试结果。
Benchmark (n) Mode Cnt Score Error Units
TailRec.recursiveSum 10 thrpt 5102800.521 ± 7870.635 ops/ms
TailRec.recursiveSum 100 thrpt 58949.731 ± 473.936 ops/ms
TailRec.recursiveSum 1000 thrpt 5846.104 ± 30.766 ops/ms
TailRec.recursiveSum 10000 thrpt 573.955 ± 17.637 ops/ms
TailRec.recursiveSumTailRec 10 thrpt 5132477.710 ± 2955.738 ops/ms
TailRec.recursiveSumTailRec 100 thrpt 516956.311 ± 541.083 ops/ms
TailRec.recursiveSumTailRec 1000 thrpt 51915.083 ± 116.170 ops/ms
TailRec.recursiveSumTailRec 10000 thrpt 5187.088 ± 10.059 ops/ms
我们还注意到尾调用消除带来的改进。
最后的想法
总的来说,我不是递归的忠实粉丝,并且在可能的情况下倾向于选择紧凑的迭代实现。 这绝不是反对尾调用优化,尤其是尾调用递归的论点。
目前,Java 不提供任何类型的尾调用优化。 Project Loom 似乎正在考虑更广泛的调用优化,但这些优化现在似乎并不是优先事项。 尾递归优化可以改为在库中实现,例如 Lombok,当存在给定的 annotation 时提供建议的优化。
2023 年 3 月 19 日 在 Bytecode, Java Bytecode, Java
发表回复 取消回复
您的电子邮件地址将不会被公布。 必填字段已标记 *
评论 *
姓名 *
电子邮件 *
网站
在此浏览器中保存我的姓名、电子邮件和网站,以便下次发表评论。
© 2023 UnlinkedList 主题由 Anders Norén 提供