在上一篇《Java 动态编译(Java 在线执行代码后端实现原理(一))》文章中,实现了把字符串编译成字节码、然后通过反射来运行代码的 demo。这一篇文章介绍如何防止含有死循环的代码长时间占用 CPU。

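思路:由于 CustomStringJavaCompiler 中重定向了 System.out 的输出位置,不能出现多线程并发执行的情况,否则会造成 System.out 输出内容错乱,所以我用了 Executors.newFixedThreadPool(1),通过 Future 模式来获取结果。我自定义了一个 CustomCallable 来处理核心逻辑:在 call 方法中重新 new 了一个 Thread 来编译并执行代码,然后通过 join 最多等待 N 秒,之后强制 stop 掉仍在运行的线程。这样就能及时 kill 掉动态运行的代码。(注意:Thread.stop 早已被标记为废弃,在 JDK 20 及以上版本调用会直接抛出 UnsupportedOperationException;生产环境建议把用户代码放到独立进程中执行并按进程来限时。)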

CustomStringJavaCompiler 编译核心类

package compiler.mydemo;

import javax.tools.Diagnostic;
import javax.tools.DiagnosticCollector;
import javax.tools.FileObject;
import javax.tools.ForwardingJavaFileManager;
import javax.tools.JavaCompiler;
import javax.tools.JavaFileManager;
import javax.tools.JavaFileObject;
import javax.tools.SimpleJavaFileObject;
import javax.tools.StandardJavaFileManager;
import javax.tools.ToolProvider;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.io.PrintStream;
import java.io.UnsupportedEncodingException;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.net.URI;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/**
 * Compiles one Java source string entirely in memory and runs its
 * {@code public static void main(String[])} method, capturing everything the
 * program prints to {@code System.out}.
 *
 * <p>Not thread-safe: {@link #runMainMethod()} temporarily replaces the JVM-wide
 * {@code System.out}, so two instances must never run concurrently.
 *
 * Create by andy on 2018-12-06 21:25
 */
public class CustomStringJavaCompiler {
    /** Charset used both to encode captured output and to decode it again. */
    private static final String OUTPUT_CHARSET = "utf-8";
    /** Matches {@code package a.b.c;} and captures the package name. */
    private static final Pattern PACKAGE_PATTERN = Pattern.compile("\\bpackage\\s+([^\\s;]+)\\s*;");
    /**
     * Matches a class declaration header up to its opening brace and captures the simple
     * name. Type parameters and extends/implements/permits clauses are tolerated, and no
     * whitespace is required before the brace.
     */
    private static final Pattern CLASS_PATTERN = Pattern.compile(
            "\\bclass\\s+([^\\s{<]+)(?:\\s*<[^{]*>)?(?:\\s+(?:extends|implements|permits)\\b[^{]*)?\\s*\\{");

    // Fully-qualified name of the class declared in the source
    private final String fullClassName;
    private final String sourceCode;
    // Compiled byte code (key: fully-qualified class name, value: holder of the emitted byte code)
    private final Map<String, ByteJavaFileObject> javaFileObjectMap = new ConcurrentHashMap<>();
    // The system Java compiler; null when running on a JRE without compiler support
    private final JavaCompiler compiler = ToolProvider.getSystemJavaCompiler();
    // Collects the messages (errors, warnings) produced during compilation
    private final DiagnosticCollector<JavaFileObject> diagnosticsCollector = new DiagnosticCollector<>();
    // Execution result (everything the program printed to System.out)
    private String runResult;
    // Compilation time in ms
    private long compilerTakeTime;
    // Run time in ms
    private long runTakeTime;


    /**
     * @param sourceCode complete source of one compilation unit; the class name is
     *                   extracted from it via {@link #getFullClassName(String)}
     */
    public CustomStringJavaCompiler(String sourceCode) {
        this.sourceCode = sourceCode;
        this.fullClassName = getFullClassName(sourceCode);
    }

    /**
     * Compiles the source string. On failure the details are available from
     * {@link #getCompilerMessage()}.
     *
     * @return true: compilation succeeded, false: compilation failed
     * @throws IllegalStateException if no system Java compiler is available
     */
    public boolean compiler() {
        if (compiler == null) {
            throw new IllegalStateException("No system Java compiler available; run on a JDK, not a JRE");
        }
        long startTime = System.currentTimeMillis();
        // Wrap the standard file manager so that class output goes to memory instead of disk
        StandardJavaFileManager standardFileManager = compiler.getStandardFileManager(diagnosticsCollector, null, null);
        JavaFileManager javaFileManager = new StringJavaFileManage(standardFileManager);
        try {
            // Source object backed by the string
            JavaFileObject javaFileObject = new StringJavaFileObject(fullClassName, sourceCode);
            JavaCompiler.CompilationTask task = compiler.getTask(null, javaFileManager, diagnosticsCollector, null, null, Arrays.asList(javaFileObject));
            return task.call();
        } finally {
            // Measure AFTER task.call() so the figure really covers the compilation
            compilerTakeTime = System.currentTimeMillis() - startTime;
            try {
                javaFileManager.close();
            } catch (IOException ignored) {
                // The compilation outcome is already decided; a failed close must not mask it
            }
        }
    }

    /**
     * Invokes the compiled class's {@code main} method with an empty argument array while
     * {@code System.out} is redirected into a buffer. The captured output and the elapsed
     * time are recorded even when {@code main} throws, and {@code System.out} is always
     * restored.
     *
     * @throws InvocationTargetException if the invoked {@code main} itself throws
     */
    public void runMainMethod() throws ClassNotFoundException, NoSuchMethodException, InvocationTargetException, IllegalAccessException, UnsupportedEncodingException {
        PrintStream originalOut = System.out;
        ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
        // Encode with the same charset that is used for decoding below; the single-argument
        // constructor would use the platform default and garble non-ASCII output
        PrintStream printStream = new PrintStream(outputStream, true, OUTPUT_CHARSET);
        long startTime = System.currentTimeMillis();
        try {
            System.setOut(printStream);

            StringClassLoader scl = new StringClassLoader();
            // findClass (not loadClass) so the freshly compiled byte code wins over any
            // same-named class on the application class path
            Class<?> aClass = scl.findClass(fullClassName);
            Method main = aClass.getMethod("main", String[].class);
            // Cast to Object so the array is passed as ONE argument, not spread as varargs
            main.invoke(null, (Object) new String[0]);
        } finally {
            // Restore the original stream first so it is never left redirected
            System.setOut(originalOut);
            printStream.flush();
            runTakeTime = System.currentTimeMillis() - startTime;
            runResult = outputStream.toString(OUTPUT_CHARSET);
        }

    }

    /**
     * @return compiler diagnostics (errors, warnings), one per line
     */
    public String getCompilerMessage() {
        StringBuilder sb = new StringBuilder();
        List<Diagnostic<? extends JavaFileObject>> diagnostics = diagnosticsCollector.getDiagnostics();
        for (Diagnostic<? extends JavaFileObject> diagnostic : diagnostics) {
            sb.append(diagnostic.toString()).append("\r\n");
        }
        return sb.toString();
    }

    /**
     * @return what the program printed to System.out, or null if it has not been run
     */
    public String getRunResult() {
        return runResult;
    }


    /**
     * @return compilation time in ms
     */
    public long getCompilerTakeTime() {
        return compilerTakeTime;
    }

    /**
     * @return run time in ms
     */
    public long getRunTakeTime() {
        return runTakeTime;
    }

    /**
     * Extracts the fully-qualified class name from source text using the first
     * {@code package} statement and the first class declaration found.
     *
     * @param sourceCode source text
     * @return fully-qualified class name; just the package prefix (or an empty string)
     *         when no class declaration is found
     */
    public static String getFullClassName(String sourceCode) {
        String className = "";
        Matcher matcher = PACKAGE_PATTERN.matcher(sourceCode);
        if (matcher.find()) {
            className = matcher.group(1) + ".";
        }

        matcher = CLASS_PATTERN.matcher(sourceCode);
        if (matcher.find()) {
            className += matcher.group(1);
        }
        return className;
    }

    /**
     * Source file object backed by an in-memory string.
     */
    private static class StringJavaFileObject extends SimpleJavaFileObject {
        // The source text waiting to be compiled
        private final String contents;

        public StringJavaFileObject(String className, String contents) {
            super(URI.create("string:///" + className.replaceAll("\\.", "/") + Kind.SOURCE.extension), Kind.SOURCE);
            this.contents = contents;
        }

        // Called by the compiler to read the source
        @Override
        public CharSequence getCharContent(boolean ignoreEncodingErrors) throws IOException {
            return contents;
        }

    }

    /**
     * Output file object that keeps the compiled byte code in memory.
     */
    private static class ByteJavaFileObject extends SimpleJavaFileObject {
        // Receives the compiled byte code
        private ByteArrayOutputStream outPutStream;

        public ByteJavaFileObject(String className, Kind kind) {
            // Use the extension of the actual kind (.class for output), not always .java
            super(URI.create("string:///" + className.replaceAll("\\.", "/") + kind.extension), kind);
        }

        // Called by the compiler when it writes the byte code
        @Override
        public OutputStream openOutputStream() {
            outPutStream = new ByteArrayOutputStream();
            return outPutStream;
        }

        // Used by the class loader to define the class
        public byte[] getCompiledBytes() {
            return outPutStream.toByteArray();
        }
    }

    /**
     * File manager that redirects compiler output into {@code javaFileObjectMap}.
     */
    private class StringJavaFileManage extends ForwardingJavaFileManager<JavaFileManager> {
        StringJavaFileManage(JavaFileManager fileManager) {
            super(fileManager);
        }

        // Returns the in-memory output object for the given class and remembers it by class name
        @Override
        public JavaFileObject getJavaFileForOutput(Location location, String className, JavaFileObject.Kind kind, FileObject sibling) throws IOException {
            ByteJavaFileObject javaFileObject = new ByteJavaFileObject(className, kind);
            javaFileObjectMap.put(className, javaFileObject);
            return javaFileObject;
        }
    }

    /**
     * Class loader that defines classes from the in-memory byte code.
     */
    private class StringClassLoader extends ClassLoader {
        @Override
        protected Class<?> findClass(String name) throws ClassNotFoundException {
            ByteJavaFileObject fileObject = javaFileObjectMap.get(name);
            if (fileObject != null) {
                byte[] bytes = fileObject.getCompiledBytes();
                return defineClass(name, bytes, 0, bytes.length);
            }
            // Not one of the compiled classes: fall back to the system class loader
            try {
                return ClassLoader.getSystemClassLoader().loadClass(name);
            } catch (Exception e) {
                return super.findClass(name);
            }
        }
    }
}

CustomCallable 调用编译并运行,设置超时时间

package compiler.mydemo;

import java.lang.reflect.InvocationTargetException;
import java.util.concurrent.Callable;

/**
 * Compiles and runs one source string on a dedicated worker thread and gives up
 * waiting after a fixed time limit, so that an endless loop in the submitted code
 * cannot block the caller forever.
 *
 * Create by andy on 2018-12-07 13:10
 */
public class CustomCallable implements Callable<RunInfo> {
    /** Maximum time to wait for compile + run, in milliseconds. */
    private static final long RUN_TIMEOUT_MILLIS = 3000L;
    /** Compared by name so this class does not depend on the deprecated-for-removal type. */
    private static final String THREAD_DEATH_CLASS = "java.lang.ThreadDeath";

    private final String sourceCode;

    public CustomCallable(String sourceCode) {
        this.sourceCode = sourceCode;
    }

    // Approach 1 (kept for reference): compile and run directly on the calling thread,
    // which gives no way to bound the run time.
    //@Override
    //public RunInfo call() throws Exception {
    //    System.out.println("开始执行call" + LocalTime.now());
    //    RunInfo runInfo = new RunInfo();
    //    CustomStringJavaCompiler compiler = new CustomStringJavaCompiler(sourceCode);
    //    if (compiler.compiler()) {
    //        runInfo.setCompilerSuccess(true);
    //        try {
    //            compiler.runMainMethod();
    //            runInfo.setRunSuccess(true);
    //            runInfo.setRunTakeTime(compiler.getRunTakeTime());
    //            runInfo.setRunMessage(compiler.getRunResult()); //获取运行的时候输出内容
    //        } catch (Exception e) {
    //            e.printStackTrace();
    //            runInfo.setRunSuccess(false);
    //            runInfo.setRunMessage(e.getMessage());
    //        }
    //    } else {
    //        //编译失败
    //        runInfo.setCompilerSuccess(false);
    //    }
    //    runInfo.setCompilerTakeTime(compiler.getCompilerTakeTime());
    //    runInfo.setCompilerMessage(compiler.getCompilerMessage());
    //    System.out.println("call over" + LocalTime.now());
    //    return runInfo;
    //}


    /**
     * Approach 2: run the work on a separate thread, wait at most
     * {@code RUN_TIMEOUT_MILLIS}, and forcibly end the worker if it is still running.
     *
     * @return the run information; {@code timeOut} is set to true when the worker did not
     *         finish within the limit
     */
    @Override
    public RunInfo call() throws Exception {
        RunInfo runInfo = new RunInfo();
        Thread t1 = new Thread(() -> realCall(runInfo));
        // Daemon: a runaway worker that cannot be stopped must not keep the JVM alive
        t1.setDaemon(true);
        t1.start();
        try {
            t1.join(RUN_TIMEOUT_MILLIS);
        } catch (InterruptedException e) {
            // Preserve the interrupt for whoever owns this thread
            Thread.currentThread().interrupt();
        }
        // Only a worker that is still running needs to be ended. A finished worker is left
        // alone, which also keeps successful runs working on JDKs where stop() always throws.
        if (t1.isAlive()) {
            forceStop(t1);
            runInfo.setTimeOut(true);
        }
        return runInfo;
    }

    /**
     * Ends the worker with {@code Thread.stop()}, falling back to {@code interrupt()} on
     * runtimes where {@code stop()} is unsupported and throws
     * {@code UnsupportedOperationException}.
     */
    @SuppressWarnings({"deprecation", "removal"})
    private static void forceStop(Thread worker) {
        try {
            worker.stop();
        } catch (UnsupportedOperationException e) {
            // Best effort only: a busy loop that never checks the interrupt flag keeps running
            worker.interrupt();
        }
    }

    /**
     * Compiles the source, runs it, and records every outcome in {@code runInfo}.
     * Executed on the worker thread.
     */
    private void realCall(RunInfo runInfo) {
        CustomStringJavaCompiler compiler = new CustomStringJavaCompiler(sourceCode);
        if (compiler.compiler()) {
            runInfo.setCompilerSuccess(true);
            try {
                compiler.runMainMethod();
                runInfo.setRunSuccess(true);
                runInfo.setRunTakeTime(compiler.getRunTakeTime());
                runInfo.setRunMessage(compiler.getRunResult()); // what the program printed
            } catch (InvocationTargetException e) {
                Throwable cause = e.getCause();
                // The reflective call was aborted because this thread was force-stopped on timeout
                if (cause != null && THREAD_DEATH_CLASS.equals(cause.getClass().getName())) {
                    return;
                }
                // Otherwise the submitted program itself threw: report that as a failed run
                runInfo.setRunSuccess(false);
                runInfo.setRunMessage(String.valueOf(cause != null ? cause : e));
            } catch (Exception e) {
                e.printStackTrace();
                runInfo.setRunSuccess(false);
                runInfo.setRunMessage(e.getMessage());
            }
        } else {
            // Compilation failed
            runInfo.setCompilerSuccess(false);
        }
        runInfo.setCompilerTakeTime(compiler.getCompilerTakeTime());
        runInfo.setCompilerMessage(compiler.getCompilerMessage());
        runInfo.setTimeOut(false); // reaching this point means the work finished in time
    }
}

RunInfo 动态编译、运行信息的bean

/**
 * Bean carrying the outcome of one dynamic compile-and-run request.
 */
public class RunInfo {
    // true: the request timed out
    private Boolean timeOut;

    // Compilation phase: elapsed time, compiler messages, and whether it succeeded
    private Long compilerTakeTime;
    private String compilerMessage;
    private Boolean compilerSuccess;

    // Run phase: elapsed time, output or error message, and whether it succeeded
    private Long runTakeTime;
    private String runMessage;
    private Boolean runSuccess;

    // getters and setters omitted
}

CompilerUtil 把一整套流程封装了一个工具类

package compiler.mydemo;

import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

/**
 * Utility wrapping the whole flow: submit source code, compile it, run it, and
 * return the resulting {@link RunInfo}.
 *
 * Create by andy on 2018-12-07 16:32
 */
public class CompilerUtil {
    // A single worker thread on purpose: running the submitted code redirects System.out,
    // so concurrent runs would mix up their output
    private static final ExecutorService pool = Executors.newFixedThreadPool(1);

    /**
     * Queues the source for compilation and execution and blocks until the result is
     * available.
     *
     * @param javaSourceCode complete source of the class to compile and run
     * @return the run information; when the task fails or the wait is interrupted, a
     *         {@code RunInfo} with only {@code timeOut} set to true
     */
    public static RunInfo getRunInfo(String javaSourceCode) {
        RunInfo runInfo;
        CustomCallable compilerAndRun = new CustomCallable(javaSourceCode);
        Future<RunInfo> future = pool.submit(compilerAndRun);
        // Approach 1: wait without a limit here; the time limit is enforced inside CustomCallable
        try {
            runInfo = future.get();
        } catch (InterruptedException e) {
            // Restore the interrupt flag so the caller can still observe the interruption
            Thread.currentThread().interrupt();
            runInfo = new RunInfo();
            runInfo.setTimeOut(true);
        } catch (Exception e) {
            e.printStackTrace();
            // Compiling or running the code failed or timed out
            runInfo = new RunInfo();
            runInfo.setTimeOut(true);
        }

        // Approach 2 (not workable): a timeout on future.get is measured from the moment get
        // is called, not from the moment the task actually starts. With a single pool thread
        // and, say, 10 queued tasks, long-running early tasks make later get calls time out
        // before their own task has even begun.
        //try {
        //    runInfo = future.get(5, TimeUnit.SECONDS);
        //    return runInfo;
        //} catch (InterruptedException e) {
        //    System.out.println("future在睡着时被打断");
        //    e.printStackTrace();
        //} catch (ExecutionException e) {
        //    System.out.println("future在尝试取得任务结果时出错");
        //    e.printStackTrace();
        //} catch (TimeoutException e) {
        //    System.out.println("future时间超时");
        //    e.printStackTrace();
        //    future.cancel(true);
        //}
        //runInfo = new RunInfo();
        //runInfo.setTimeOut(true);
        return runInfo;

    }
}

测试类:

package compiler.mydemo;

/**
 * Manual demo: submits three small programs from three threads at once — one that
 * just prints, one that spins forever, and one that sleeps and then spins forever.
 *
 * Create by andy on 2018-12-10 10:43
 */
public class Test3 {

    public static void main(String[] args) throws InterruptedException {
        // Endless busy loop (a println could be placed inside the loop to also produce output)
        String busyLoopSource = String.join("\n",
                "public class HelloWorld {",
                "    public static void main(String[] args) {",
                "        while(true){",
                "        }",
                "       ",
                "    }",
                "}");

        // Sleeps for six seconds, prints once, then spins forever
        String sleepThenLoopSource = String.join("\n",
                "public class HelloWorld {",
                "    public static void main(String[] args) {",
                "    try {",
                "            Thread.sleep(6000);",
                "        } catch (InterruptedException e) {",
                "            e.printStackTrace();",
                "        }",
                "       System.out.println(\"Hello World!\");",
                "        while(true){",
                "        }",
                "    }",
                "}");

        // Well-behaved program that prints one line and returns
        String helloSource = String.join("\n",
                "public class HelloWorld {",
                "    public static void main(String[] args) {",
                "       System.out.println(\"Hello World!\");",
                "    }",
                "}");

        // Start order matches the submission order: ok, loop, sleep_loop
        new TestRun(helloSource, "thread:ok").start();
        new TestRun(busyLoopSource, "thread:loop:").start();
        new TestRun(sleepThenLoopSource, "thread:sleep_loop:").start();
    }

}

/**
 * Thread that submits one source string to {@code CompilerUtil} and prints the result.
 */
class TestRun extends Thread {
    // Source code to compile and run
    String code;

    TestRun(String code, String name) {
        this.code = code;
        setName(name);
    }

    @Override
    public void run() {
        Object result = CompilerUtil.getRunInfo(code);
        System.out.println(result);
    }
}