Forráskód Böngészése

feature:000.000.007:后端使用线程池优化AI生成图表的接口;
创建处理Ai生成任务的线程池;
添加一个参数区分同步还是异步生成图表;

yang yi 1 hónapja
szülő
commit
64eae51ff5

+ 2 - 0
serve/sql/init.sql

@@ -26,6 +26,8 @@ CREATE TABLE if NOT EXISTS `chart` (
                           `generated_chart_data` TEXT COMMENT '生成的图表数据',
                           `analysis_conclusion` TEXT COMMENT '生成的分析结论',
                           `user_id` BIGINT NOT NULL COMMENT '创建用户ID',
+                          `state` CHAR(32) NOT NULL DEFAULT '等待中' COMMENT '图表状态,等待中,生成中,成功,失败',
+                          `execute_message` TEXT COMMENT '执行信息',
                           `created_time` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP COMMENT '创建时间',
                           `updated_time` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT '更新时间',
                           `delete_flag` TINYINT(1) NOT NULL DEFAULT '0' COMMENT '删除标志,0:未删除,1:已删除'

+ 64 - 0
serve/src/main/java/space/anyi/BI/config/ThreadPoolExecutorConfig.java

@@ -0,0 +1,64 @@
+package space.anyi.BI.config;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.springframework.context.annotation.Bean;
+import org.springframework.context.annotation.Configuration;
+
+import java.util.concurrent.*;
+
+/**
+ * @ProjectName: serve
+ * @FileName: ThreadPoolExecutorConfig
+ * @Author: 杨逸
+ * @Data:2024/12/7 16:01
+ * @Description: 线程池配置
+ */
+@Configuration
+public class ThreadPoolExecutorConfig {
+    private final static Logger log = LoggerFactory.getLogger(ThreadPoolExecutorConfig.class);
+    @Bean
+    public ThreadPoolExecutor getThreadPoolExecutor(){
+        //核心线程数
+        int corePoolSize = 1;
+        //最大线程数
+        int maximumPoolSize = 2;
+        //非核心线程存活空闲时间
+        long keepAliveTime = 100L;
+        //时间单位
+        TimeUnit unit = TimeUnit.SECONDS;
+        //任务队列
+        BlockingQueue<Runnable> workQueue = new ArrayBlockingQueue<>(5);
+        //线程工厂
+        ThreadFactory threadFactory = new ThreadFactory() {
+            @Override
+            public Thread newThread(Runnable r) {
+                Thread thread = new Thread(r);
+                return thread;
+            }
+        };
+        //拒绝策略处理器
+        //RejectedExecutionHandler handler= new ThreadPoolExecutor.CallerRunsPolicy();
+        RejectedExecutionHandler handler= new RejectedExecutionHandler(){
+            /**
+             * @param r 任务
+             * @param executor 线程池
+             * @description: 拒绝策略处理器
+             * @author: 杨逸
+             * @data:2024/12/07 16:07:54
+             * @since 1.0.0
+             */
+            @Override
+            public void rejectedExecution(Runnable r, ThreadPoolExecutor executor) {
+                int activeCount = executor.getActiveCount();
+                int corePoolSize = executor.getCorePoolSize();
+                int maximumPoolSize = executor.getMaximumPoolSize();
+                int queueSize = executor.getQueue().size();
+                log.warn("任务被拒绝执行,当前工作线程数:{},核心线程数:{},最大线程数:{},队列大小:{}",activeCount,corePoolSize,maximumPoolSize,queueSize);
+            }
+        };
+        //创建线程池
+        ThreadPoolExecutor threadPoolExecutor = new ThreadPoolExecutor(corePoolSize,maximumPoolSize, keepAliveTime,unit,workQueue,threadFactory);
+        return threadPoolExecutor;
+    }
+}

+ 7 - 3
serve/src/main/java/space/anyi/BI/controller/ChartController.java

@@ -95,9 +95,13 @@ public class ChartController {
         if (file.getSize()>1024*1024*2L) {
             return ResponseResult.errorResult(ResponseResult.AppHttpCodeEnum.FILE_SIZE_ERROR);
         }
-        ChartVO vo = chartService.generateChartByAI(chartDTO,file);
-
-        return ResponseResult.okResult(vo);
+        if (chartDTO.getAsynchronism()){
+            chartService.generateChartByAIAsyn(chartDTO,file);
+            return ResponseResult.okResult("任务提交成功,请稍后查看结果");
+        }else {
+            ChartVO vo = chartService.generateChartByAI(chartDTO,file);
+            return ResponseResult.okResult(vo);
+        }
     }
 
     @ExceptionHandler({IOException.class})

+ 13 - 0
serve/src/main/java/space/anyi/BI/entity/dto/ChartDTO.java

@@ -29,6 +29,11 @@ public class ChartDTO {
      */
     private String chartType;
 
+    /**
+     * 是否异步处理
+     */
+    private Boolean isAsynchronism = false;
+
     public String getName() {
         return name;
     }
@@ -60,4 +65,12 @@ public class ChartDTO {
     public void setChartType(String chartType) {
         this.chartType = chartType;
     }
+
+    public Boolean getAsynchronism() {
+        return isAsynchronism;
+    }
+
+    public void setAsynchronism(Boolean asynchronism) {
+        isAsynchronism = asynchronism;
+    }
 }

+ 28 - 0
serve/src/main/java/space/anyi/BI/entity/vo/ChartVO.java

@@ -48,6 +48,16 @@ public class ChartVO {
      */
     private String userId;
 
+    /**
+     * 图表状态,等待中,生成中,成功,失败
+     */
+    private String state;
+
+    /**
+     * 执行信息
+     */
+    private String executeMessage;
+
     public String getId() {
         return id;
     }
@@ -112,6 +122,22 @@ public class ChartVO {
         this.userId = userId;
     }
 
+    public String getState() {
+        return state;
+    }
+
+    public void setState(String state) {
+        this.state = state;
+    }
+
+    public String getExecuteMessage() {
+        return executeMessage;
+    }
+
+    public void setExecuteMessage(String executeMessage) {
+        this.executeMessage = executeMessage;
+    }
+
     @Override
     public String toString() {
         return "ChartVO{" +
@@ -123,6 +149,8 @@ public class ChartVO {
                 ", generatedChartData='" + generatedChartData + '\'' +
                 ", analysisConclusion='" + analysisConclusion + '\'' +
                 ", userId='" + userId + '\'' +
+                ", state='" + state + '\'' +
+                ", executeMessage='" + executeMessage + '\'' +
                 '}';
     }
 }

+ 2 - 0
serve/src/main/java/space/anyi/BI/service/ChartService.java

@@ -16,5 +16,7 @@ public interface ChartService extends IService<Chart> {
 
     ChartVO generateChartByAI(ChartDTO chartDTO, MultipartFile file);
 
+    void generateChartByAIAsyn(ChartDTO chartDTO, MultipartFile file);
+
     PageVO getChartPage(Integer pageNum, Integer pageSize, String name, Long userId);
 }

+ 74 - 0
serve/src/main/java/space/anyi/BI/service/impl/ChartServiceImpl.java

@@ -23,9 +23,11 @@ import space.anyi.BI.util.BeanCopyUtil;
 import space.anyi.BI.util.ExcelUtils;
 import space.anyi.BI.util.SecurityUtils;
 
+import javax.annotation.Resource;
 import java.io.IOException;
 import java.util.ArrayList;
 import java.util.List;
+import java.util.concurrent.ThreadPoolExecutor;
 
 /**
 * @author 杨逸
@@ -36,6 +38,78 @@ import java.util.List;
 public class ChartServiceImpl extends ServiceImpl<ChartMapper, Chart>
     implements ChartService{
     private final static Logger log = LoggerFactory.getLogger(ChartServiceImpl.class);
+    @Resource
+    private ThreadPoolExecutor threadPoolExecutor;
+
+    /**
+     * 异步使用AI生成图表
+     * @param chartDTO
+     * @param file
+     * @description:
+     * @author: 杨逸
+     * @data:2024/12/07 17:43:22
+     * @since 1.0.0
+     */
+    @Override
+    public void generateChartByAIAsyn(ChartDTO chartDTO, MultipartFile file) {
+
+        Chart chart = BeanCopyUtil.copyBean(chartDTO, Chart.class);
+        long chartId = IdUtil.getSnowflake(1, 1).nextId();
+        chart.setId(chartId);
+        chart.setUserId(SecurityUtils.getUserId());
+        chart.setState("等待中");
+        save(chart);
+        //使用线程池优化生成图表的逻辑
+        threadPoolExecutor.execute(()->{
+            //读数据
+            String csvData  = "";
+            try {
+                csvData = ExcelUtils.excel2csv(file.getInputStream());
+                log.info("上传的数据为:\n{}", csvData);
+            } catch (IOException e) {
+                e.printStackTrace();
+            }
+            if (csvData.length()>3000){
+                chart.setState("失败");
+                chart.setExecuteMessage("数据量过大,请上传小于3000行的数据");
+                updateById(chart);
+                throw new SystemException(500, "数据量过大,请上传小于3000行的数据");
+            }
+            chart.setChartData(csvData);
+            chart.setState("生成中");
+            updateById(chart);
+
+            StringBuilder message = new StringBuilder("原始数据:\n");
+            message.append(csvData);
+            message.append("分析目标:\n");
+            message.append(chartDTO.getAnalysisTarget());
+            message.append("\n.使用").append(chartDTO.getChartType()).append("进行可视化分析.\n");
+            //配置prompt向AI发送请求
+            HttpRequestData requestData = AiUtil.createDefaultRequestData(message.toString());
+            HttpResponseData responseData = AiUtil.doChat(requestData);
+            //解析AI返回的数据
+            String content = responseData.getChoices().get(0).getMessage().getContent();
+            log.info("AI返回的数据为:{}", content);
+            int index = content.indexOf("```");
+            int endIndex = content.lastIndexOf("```");
+            if (index == -1 || endIndex == -1){
+                chart.setState("失败");
+                chart.setExecuteMessage("AI生成图表失败");
+                updateById(chart);
+                throw new SystemException(500, "AI生成图表失败");
+            }
+            //数据可视化,Echarts的option代码
+            chart.setGeneratedChartData(content.substring(index+7, endIndex).trim());
+            index = endIndex;
+            //分析结论
+            chart.setAnalysisConclusion(content.substring(index+3).trim());
+            //保存到数据库
+            chart.setState("成功");
+            chart.setExecuteMessage("AI生成图表成功");
+            updateById(chart);
+        });
+
+    }
 
     @Override
     public PageVO getChartPage(Integer pageNum, Integer pageSize, String name, Long userId) {

+ 55 - 0
serve/src/test/java/space/anyi/BI/config/ThreadPoolExecutorConfigTest.java

@@ -0,0 +1,55 @@
+package space.anyi.BI.config;
+
+import io.netty.util.concurrent.CompleteFuture;
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.Test;
+import org.springframework.boot.test.context.SpringBootTest;
+
+import javax.annotation.Resource;
+import java.util.concurrent.ThreadPoolExecutor;
+
+import static org.junit.jupiter.api.Assertions.*;
+
+/**
+ * @ProjectName: serve
+ * @FileName: ThreadPoolExecutorConfigTest
+ * @Author: 杨逸
+ * @Data:2024/12/7 16:13
+ * @Description:
+ */
+@SpringBootTest
+class ThreadPoolExecutorConfigTest {
+    @Resource
+    private ThreadPoolExecutor threadPoolExecutor;
+
+    @Test
+    void getThreadPoolExecutor() throws InterruptedException {
+        final int[] count = {1};
+        //创建一个任务
+        Runnable runnable = new Runnable() {
+            @Override
+            public void run() {
+                System.out.println("线程池执行任务" + count[0]++);
+                try {
+                    Thread.sleep(1000*2);
+                } catch (InterruptedException e) {
+                    e.printStackTrace();
+                }
+            }
+        };
+        //使用线程池执行任务
+        for (int i = 0; i < 7; i++) {
+            threadPoolExecutor.execute(runnable);
+        }
+        //工作的线程数
+        Assertions.assertEquals(2, threadPoolExecutor.getActiveCount());
+        //工作的最大线程数
+        Assertions.assertEquals(2, threadPoolExecutor.getMaximumPoolSize());
+        //任务队列大小
+        Assertions.assertEquals(5, threadPoolExecutor.getQueue().size());
+        //经过线程池的任务数量
+        Assertions.assertEquals(7, threadPoolExecutor.getTaskCount());
+        //休眠一下,让任务执行完成
+        Thread.sleep(1000*17);
+    }
+}

+ 1 - 0
serve/src/test/java/space/anyi/BI/handler/redisson/RRateLimiterHandlerTest.java

@@ -7,6 +7,7 @@ import org.redisson.api.RedissonClient;
 import org.springframework.boot.test.context.SpringBootTest;
 
 import javax.annotation.Resource;
+import java.util.concurrent.ThreadPoolExecutor;
 
 
 /**