跳至主要內容

实现一个并发任务执行框架

IT王小二约 2659 字

实现一个并发任务执行框架

作者:IT王小二

博客:https://itwxe.comopen in new window

一、需求产生和分析

问题参考来自网络。

公司里有两个项目组,考试组有批量的离线文档要生成,题库组则经常有批量的题目进行排重和根据条件批量修改题目的内容。

架构组通过对实际的上线产品进行用户调查,发现这些功能在实际使用时,用户都反应速度很慢,而且提交任务后,不知道任务的进行情况,做没做?做到哪一步了?有哪些成功?哪些失败了?都一概不知道架构组和实际的开发人员沟通,他们都说,因为前端提交任务到 Web 后台以后,是一次要处理多个文档和题目,所以速度快不起来。提示用多线程进行改进,实际的开发人员表示多线程没有用过,不知道如何使用,也担心用不好。综合以上情况,架构组决定在公司的基础构件库中提供一个并发任务执行框架,以解决上述用户和业务开发人员的痛点:

  1. 对批量型任务提供统一的开发接口。
  2. 在使用上尽可能的对业务开发人员友好。
  3. 要求可以查询批量任务的执行进度。

二、需要怎么做

  1. 批量任务,为提高性能,必然的我们要使用 java 里的多线程,为了在使用上尽可能的对业务开发人员友好和简单,需要屏蔽一些底层 java 并发编程中的细节,让他们不需要去了解并发容器,阻塞队列,异步任务,线程安全等等方面的知识,只要专心于自的业务处理即可。
  2. 每个批量任务拥有自己的上下文环境,因为一个项目组里同时要处理的批量任务可能有多个,比如考试组,可能就会有不同的学校的批量的离线文档生成,而题库组则会不同的学科都会有老师同时进行工作,因此需要一个并发安全的容器保存每个任务的属性信息。
  3. 自动清除已完成和过期任务因为要提供进度查询,系统需要在内存中维护每个任务的进度信息以供查询,但是这种查询又是有时间限制的,一个任务完成一段时间后,就不再提供进度查询了,则就需要我们自动清除已完成和过期任务,用轮询来处理的话不太优雅,那么就可以使用一个 DelayQueue 或者 Redis了。

三、实现流程图

流程图

四、具体实现代码

/**
 * @Author SunnyBear
 * @Description 方法本身运行是否正确的结果类型
 */
public enum TaskResultType {
    /**
     * 方法执行完成,业务结果也正确
     */
    SUCCESS,
    /**
     * 方法执行完成,业务结果错误
     */
    FAILURE,
    /**
     * 方法执行抛出异常
     */
    EXCEPTION
}
/**
 * @Author SunnyBear
 * @Description 任务处理后返回的结果实体类
 */
public class TaskResult<R> {

    /**
     * 方法执行结果
     */
    private final TaskResultType taskResultType;
    /**
     * 方法执行后的结果数据
     */
    private final R returnValue;
    /**
     * 如果方法执行失败,可以在这里填充原因
     */
    private final String reason;

    public TaskResult(TaskResultType taskResultType, R returnValue, String reason) {
        this.taskResultType = taskResultType;
        this.returnValue = returnValue;
        this.reason = reason;
    }

    public TaskResultType getTaskResultType() {
        return taskResultType;
    }

    public R getReturnValue() {
        return returnValue;
    }

    public String getReason() {
        return reason;
    }

    @Override
    public String toString() {
        return "TaskResult{" +
                "taskResultType=" + taskResultType +
                ", returnValue=" + returnValue +
                ", reason='" + reason + '\'' +
                '}';
    }
}
/**
 * @Author SunnyBear
 * @Description 要求框架使用者实现的任务接口,因为任务的性质在调用时才知道,所以传入的参数(T)和方法(R)的返回值均使用泛型
 */
public interface ITaskProcesser<T, R> {

    TaskResult<R> taskExecute(T data);
}
/**
 * @Author SunnyBear
 * @Description 存放的队列的元素
 */
public class Item<T> implements Delayed {

    /**
     * 到期时间(秒)
     */
    private long activeTime;
    /**
     * 业务数据
     */
    private T data;

    public Item(long activeTime, T data) {
        this.activeTime = activeTime * 1000 + System.currentTimeMillis();
        this.data = data;
    }

    public long getActiveTime() {
        return activeTime;
    }

    public T getData() {
        return data;
    }

    /**
     * 从这个方法返回到激活日期的剩余时间
     */
    @Override
    public long getDelay(TimeUnit unit) {
        long time = unit.convert(this.activeTime - System.currentTimeMillis(), unit);
        return time;
    }

    /**
     * Delayed接口继承了Comparable接口,按剩余时间排序,实际计算考虑精度为纳秒数
     */
    @Override
    public int compareTo(Delayed o) {
        long d = (getDelay(TimeUnit.MILLISECONDS) - o.getDelay(TimeUnit.MILLISECONDS));
        if (d == 0) {
            return 0;
        } else {
            if (d < 0) {
                return -1;
            } else {
                return 1;
            }
        }
    }
}
/**
 * @Author SunnyBear
 * @Description 提交给框架执行的工作实体类,工作:表示本批次需要处理的同性质任务(Task)的一个集合
 */
public class JobInfo<R> {

    /**
     * 工作名,用以区分框架中唯一的工作
     */
    private final String jobName;
    /**
     * 工作中任务的长度,即一个工作多少个任务
     */
    private final int jobLength;
    /**
     * 处理工作中任务的处理器
     */
    private final ITaskProcesser<?, ?> taskProcesser;
    /**
     * 任务的成功次数
     */
    private AtomicInteger successCount;
    /**
     * 工作中任务目前已经处理的次数
     */
    private AtomicInteger taskProcessCount;
    /**
     * 存放每个任务的处理结果,供查询用
     */
    private LinkedBlockingDeque<TaskResult<R>> taskDetailQueues;
    /**
     * 保留工作结果信息供查询的时长
     */
    private final long expireTime;
    /**
     * checkJobProcesser,用于添加处理的结果到延时队列
     */
    private CheckJobProcesser checkJobProcesser = CheckJobProcesser.getInstance();

    public JobInfo(String jobName, int jobLength,
                   ITaskProcesser<?, ?> taskProcesser,
                   long expireTime) {
        this.jobName = jobName;
        this.jobLength = jobLength;
        this.taskProcesser = taskProcesser;
        this.successCount = new AtomicInteger(0);
        this.taskProcessCount = new AtomicInteger(0);
        this.taskDetailQueues = new LinkedBlockingDeque<>(jobLength);
        this.expireTime = expireTime;
    }

    public String getJobName() {
        return jobName;
    }

    public int getJobLength() {
        return jobLength;
    }

    public long getExpireTime() {
        return expireTime;
    }

    public ITaskProcesser<?, ?> getTaskProcesser() {
        return taskProcesser;
    }

    public AtomicInteger getTaskProcessCount() {
        return taskProcessCount;
    }

    public AtomicInteger getSuccessCount() {
        return successCount;
    }

    /**
     * 失败次数
     */
    public long getFailureCount() {
        return taskProcessCount.get() - successCount.get();
    }

    /**
     * 工作的整体进度信息
     */
    public String getTotalProcess() {
        return "Success[ " + successCount.get() + " ] / Current[ " + taskProcessCount.get() + " ] Total[ " + jobLength + " ]";
    }

    /**
     * 提供工作中每个任务的处理结果
     */
    public List<TaskResult<R>> getTaskDetail() {
        List<TaskResult<R>> taskResultList = new LinkedList<>();
        TaskResult<R> taskResult;
        while ((taskResult = taskDetailQueues.pollFirst()) != null) {
            taskResultList.add(taskResult);
        }
        return taskResultList;
    }

    /**
     * 个任务处理完成后,记录任务的处理结果,因为从业务应用的角度来说,
     * 对查询任务进度数据的一致性要不高
     * 我们保证最终一致性即可,无需对整个方法加锁
     */
    public void addTaskResult(TaskResult<R> taskResult){
        if(TaskResultType.SUCCESS.equals(taskResult.getTaskResultType())){
            successCount.incrementAndGet();
        }
        taskProcessCount.incrementAndGet();
        taskDetailQueues.addLast(taskResult);
        if(taskProcessCount.get() == jobLength){
            checkJobProcesser.putJob(expireTime, jobName);
        }
    }
}
/**
 * @Author SunnyBear
 * @Description 框架的主题类,也是调用者主要使用的类
 */
public class PendingJobPool {

    /**
     * 框架运行时的线程数,与机器的CPU数相同
     */
    private static final int THREAD_COUNTS = Runtime.getRuntime().availableProcessors();
    /**
     * 用以存放待处理的任务,供线程池使用
     */
    private static BlockingQueue<Runnable> taskQueue = new ArrayBlockingQueue<>(5000);
    /**
     * 线程池,固定大小,有界队列
     */
    private static ExecutorService taskExecutor
            = new ThreadPoolExecutor(THREAD_COUNTS, THREAD_COUNTS, 60, TimeUnit.SECONDS, taskQueue);
    /**
     * 工作信息的存放容器
     */
    private static ConcurrentHashMap<String, JobInfo<?>> jobInfoMap = new ConcurrentHashMap<>();

    public static Map<String, JobInfo<?>> getMap() {
        return jobInfoMap;
    }

    private PendingJobPool() {

    }

    /**
     * 单例模式
     */
    private static PendingJobPool threadPool = new PendingJobPool();

    public static PendingJobPool getInstance() {
        return threadPool;
    }

    private static class PendingTask<T, R> implements Runnable {

        private JobInfo<R> jobInfo;
        // 任务参数
        private T processData;

        public PendingTask(JobInfo<R> jobInfo, T processData) {
            this.jobInfo = jobInfo;
            this.processData = processData;
        }

        @Override
        public void run() {
            R r = null;
            ITaskProcesser<T, R> taskProcesser = (ITaskProcesser<T, R>) jobInfo.getTaskProcesser();
            TaskResult<R> result = null;
            try {
                result = taskProcesser.taskExecute(processData);
                if (result == null) {
                    result = new TaskResult<R>(TaskResultType.EXCEPTION, r, "结果为空");
                }
                if (result.getTaskResultType() == null) {
                    if (result.getReason() == null) {
                        result = new TaskResult<R>(TaskResultType.EXCEPTION, r, "结果为空");
                    } else {
                        result = new TaskResult<R>(TaskResultType.EXCEPTION, r, "结果为空, 原因:" + result.getReason());
                    }
                }
            } catch (Exception e) {
                e.printStackTrace();
                result = new TaskResult<R>(TaskResultType.EXCEPTION, r, e.getMessage());
            } finally {
                // 江结果添加到延时队列及结果队列,同时计数
                jobInfo.addTaskResult(result);
            }
        }
    }

    /**
     * 调用者提交工作中的任务
     */
    public <T, R> void putTask(String jobName, T t) {
        JobInfo<R> jobInfo = getJob(jobName);
        PendingTask<T, R> task = new PendingTask<>(jobInfo, t);
        taskExecutor.execute(task);
    }

    /**
     * 根据工作名称检索工作
     */
    private <R> JobInfo<R> getJob(String jobName) {
        JobInfo<R> jobInfo = (JobInfo<R>) jobInfoMap.get(jobName);
        if (null == jobInfo) {
            throw new RuntimeException(jobName + "是非法任务!");
        }
        return jobInfo;
    }

    /**
     * 调用者注册工作,如工作名,任务的处理器等等
     */
    public <R> void registerJob(String jobName, int jobLength,
                                ITaskProcesser<?, ?> taskProcesser, long expireTime) {
        JobInfo<R> jobInfo = new JobInfo<R>(jobName, jobLength, taskProcesser, expireTime);
        // putIfAbsent 如果传入key对应的value已经存在,就返回存在的value,不进行替换。如果不存在,就添加key和value,返回null
        if (jobInfoMap.putIfAbsent(jobName, jobInfo) != null) {
            throw new RuntimeException(jobName + "已经注册!");
        }
    }

    /**
     * 获得工作的整体处理进度
     */
    public <R> String getTaskProgess(String jobName) {
        JobInfo<R> jobInfo = getJob(jobName);
        return jobInfo.getTotalProcess();
    }

    /**
     * 获得每个任务的处理详情
     */
    public <R> List<TaskResult<R>> getTaskDetail(String jobName) {
        JobInfo<R> jobInfo = getJob(jobName);
        return jobInfo.getTaskDetail();
    }
}
/**
 * @Author SunnyBear
 * @Description 任务完成后, 在一定的时间供查询结果,之后为释放资源节约内存,需要定期处理过期的任务
 */
public class CheckJobProcesser {

    /**
     * 延时队列,存放处理好的工作任务结果
     */
    private static DelayQueue<Item<String>> queue = new DelayQueue<>();

    private static CheckJobProcesser processer = new CheckJobProcesser();

    private CheckJobProcesser() {
    }

    /**
     * 使用单例
     */
    public static CheckJobProcesser getInstance() {
        return processer;
    }

    /**
     * 处理队列中的到期任务
     */
    public static class FetchJob implements Runnable {
        private static DelayQueue<Item<String>> queue = CheckJobProcesser.queue;
        private static Map<String, JobInfo<?>> jobInfoMap = PendingJobPool.getMap();

        @Override
        public void run() {
            while (true) {
                try {
                    // 从延时队列DelayQueue中获取任务
                    Item<String> item = queue.take();
                    String jobName = item.getData();
                    jobInfoMap.remove(jobName);
                    System.out.println(jobName + " 过期了,从缓存中清除");
                } catch (InterruptedException e) {
                    e.printStackTrace();
                }
            }
        }
    }

    /**
     * 任务完成后,放入队列,经过expireTime时间后,从整个框架中移除
     */
    public void putJob(long expireTime, String jobName) {
        Item<String> item = new Item<>(expireTime, jobName);
        queue.offer(item);
        System.out.println(jobName + " 已经放入过期检查缓存,时长:" + expireTime);
    }

    /**
     * 处理过期任务设置为守护线程
     */
    static {
        Thread thread = new Thread(new FetchJob());
        thread.setDaemon(true);
        thread.start();
        System.out.println("开启过期检查的任务线程。。。");
    }

}
/**
 * @Author SunnyBear
 * @Description 一个实际任务类(即用户要执行的任务类),将数值加上一个随机数,并休眠随机时间,模拟任务执行过程中的情况
 */
public class MyTask implements ITaskProcesser<Integer, Integer> {

    @Override
    public TaskResult<Integer> taskExecute(Integer data) {
        Random r = new Random();
        int flag = r.nextInt(500);
        try {
            Thread.sleep(flag);
        } catch (InterruptedException e) {
            e.printStackTrace();
        }
        if (flag <= 300) {
            // 正常处理的情况
            Integer returnValue = data.intValue() + flag;
            return new TaskResult<Integer>(TaskResultType.SUCCESS, returnValue, "success");
        } else if (flag > 301 && flag <= 400) {
            // 处理失败的情况
            return new TaskResult<Integer>(TaskResultType.FAILURE, -1, "failure");
        } else {
            // 发生异常的情况
            try {
                throw new RuntimeException("异常发生了!!");
            } catch (Exception e) {
                return new TaskResult<Integer>(TaskResultType.EXCEPTION,
                        -1, e.getMessage());
            }
        }
    }
}
/**
 * @Author SunnyBear
 * @Description 模拟一个应用程序,提交工作和任务,并查询任务进度
 */
public class AppTest {
    private final static String JOB_NAME = "工作组一";
    /**
     * 一个工作组中的任务数量
     */
    private final static int JOB_LENGTH = 1000;

    /**
     * 查询任务进度的线程
     */
    private static class QueryResult implements Runnable {

        private PendingJobPool pool;

        public QueryResult(PendingJobPool pool) {
            super();
            this.pool = pool;
        }

        @Override
        public void run() {
            int i = 0;
            while (i < 350) {
                List<TaskResult<String>> taskDetail = pool.getTaskDetail(JOB_NAME);
                if (!taskDetail.isEmpty()) {
                    System.out.println("当前进度:" + pool.getTaskProgess(JOB_NAME));
                    System.out.println("任务详情: " + taskDetail);
                }
                try {
                    Thread.sleep(100);
                } catch (InterruptedException e) {
                    e.printStackTrace();
                }
                i++;
            }
        }

    }

    public static void main(String[] args) {
        MyTask myTask = new MyTask();
        PendingJobPool pool = PendingJobPool.getInstance();
        pool.registerJob(JOB_NAME, JOB_LENGTH, myTask, 5);
        Random r = new Random();
        for (int i = 0; i < JOB_LENGTH; i++) {
            pool.putTask(JOB_NAME, r.nextInt(1000));
        }
        Thread t = new Thread(new QueryResult(pool));
        t.start();
    }
}