|
@@ -3,63 +3,81 @@ package org.dromara.ai.handler;
|
|
|
import com.fasterxml.jackson.core.JsonProcessingException;
|
|
|
import com.fasterxml.jackson.databind.ObjectMapper;
|
|
|
import lombok.RequiredArgsConstructor;
|
|
|
+import lombok.extern.slf4j.Slf4j;
|
|
|
+import org.dromara.ai.domain.MessageContext;
|
|
|
import org.dromara.ai.domain.deepseek.DeepSeekHttpResponseData;
|
|
|
import org.dromara.ai.util.DeepSeekAIUtil;
|
|
|
+import org.dromara.common.redis.utils.RedisUtils;
|
|
|
import org.springframework.jdbc.core.JdbcTemplate;
|
|
|
import org.springframework.stereotype.Component;
|
|
|
+import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
|
|
|
|
|
|
+import java.io.*;
|
|
|
+import java.nio.charset.StandardCharsets;
|
|
|
import java.time.LocalDateTime;
|
|
|
-import java.util.List;
|
|
|
-import java.util.Map;
|
|
|
-import java.util.Set;
|
|
|
-
|
|
|
+import java.util.*;
|
|
|
+@Slf4j
|
|
|
@RequiredArgsConstructor
|
|
|
@Component
|
|
|
public class AIHandlerImpl implements AIHandler {
|
|
|
private final ObjectMapper objectMapper;
|
|
|
private final JdbcTemplate jdbcTemplate;
|
|
|
@Override
|
|
|
- public List<String> getTableNames(String context) {
|
|
|
- return List.of();
|
|
|
+ public List<String> getTableNames(MessageContext messageContext) {
|
|
|
+ messageContext.setStatus(1);
|
|
|
+ String keyName = "ai_tableNamesList";
|
|
|
+ List<String> tableNamesList = RedisUtils.getCacheList(keyName);
|
|
|
+ if (Objects.isNull(tableNamesList) || tableNamesList.isEmpty()) {
|
|
|
+ tableNamesList = loadTableNameProperty();
|
|
|
+ RedisUtils.setCacheList(keyName, tableNamesList);
|
|
|
+ }
|
|
|
+ StringBuilder tableNamesStringBuilder = new StringBuilder("所有的数据表名称和作用如下:\n");
|
|
|
+ for (String tableName : tableNamesList) {
|
|
|
+ tableNamesStringBuilder.append(tableName).append("\n");
|
|
|
+ }
|
|
|
+ String prompt = new StringBuilder("用户需求:\n").append(messageContext.getMessage())
|
|
|
+ .append("\n")
|
|
|
+ .append(tableNamesStringBuilder)
|
|
|
+ .append("请你根据用户需求和现有的数据表,选择满足用户需求的数据表,返回数据表的名称,以json的格式返回{'tableNames':[]}").toString();
|
|
|
+ String json = sendMessage(prompt);
|
|
|
+ List<String> tableNames = new ArrayList<>();
|
|
|
+ try {
|
|
|
+ tableNames = (List) objectMapper.readValue(json,Map.class).get("tableNames");
|
|
|
+ } catch (JsonProcessingException e) {
|
|
|
+ throw new RuntimeException(e);
|
|
|
+ }
|
|
|
+ messageContext.setTableNames(tableNames);
|
|
|
+ return tableNames;
|
|
|
}
|
|
|
|
|
|
@Override
|
|
|
- public String getSQL(String context, List<String> tableNames) {
|
|
|
- //todo:获取表的结构
|
|
|
- String struct = "CREATE TABLE `fa_kuyou_user_order` (\n" +
|
|
|
- " `id` bigint(20) UNSIGNED NOT NULL AUTO_INCREMENT,\n" +
|
|
|
- " `user_id` int(11) NULL DEFAULT NULL COMMENT '用户ID',\n" +
|
|
|
- " `name` varchar(255) CHARACTER SET utf8 COLLATE utf8_general_ci NULL DEFAULT NULL COMMENT '姓名',\n" +
|
|
|
- " `phone` varchar(20) CHARACTER SET utf8 COLLATE utf8_general_ci NULL DEFAULT NULL COMMENT '联系方式',\n" +
|
|
|
- " `address` varchar(255) CHARACTER SET utf8 COLLATE utf8_general_ci NULL DEFAULT NULL COMMENT '详细地址',\n" +
|
|
|
- " `order_num` varchar(50) CHARACTER SET utf8 COLLATE utf8_general_ci NULL DEFAULT '' COMMENT '订单号',\n" +
|
|
|
- " `order_price` float(11, 2) NOT NULL DEFAULT 0.00 COMMENT '订单价格',\n" +
|
|
|
- " `status` tinyint(3) UNSIGNED NULL DEFAULT 0 COMMENT '状态:0=待审核,1=已审核,2=待发货,3=待收货,4=已收货,5=退款中,6=退款完成,7=交易取消,8=交易完成',\n" +
|
|
|
- " `remark` varchar(255) CHARACTER SET utf8 COLLATE utf8_general_ci NULL DEFAULT NULL COMMENT '备注',\n" +
|
|
|
- " `pay_status` tinyint(3) UNSIGNED NULL DEFAULT 0 COMMENT '支付状态:0=支付中,1=支付成功,2=支付失败',\n" +
|
|
|
- " `pay_time` datetime(0) NULL DEFAULT NULL COMMENT '付款时间',\n" +
|
|
|
- " `diver_type` enum('1','2') CHARACTER SET utf8 COLLATE utf8_general_ci NULL DEFAULT '1' COMMENT '配送类型:1=专业配送,2=到店自取',\n" +
|
|
|
- " `store_id` int(11) NULL DEFAULT NULL COMMENT '门店ID',\n" +
|
|
|
- " `diver_id` int(11) NULL DEFAULT NULL COMMENT '司机ID',\n" +
|
|
|
- " `diver_name` varchar(255) CHARACTER SET utf8 COLLATE utf8_general_ci NULL DEFAULT NULL COMMENT '司机姓名',\n" +
|
|
|
- " `diver_phone` varchar(100) CHARACTER SET utf8 COLLATE utf8_general_ci NULL DEFAULT NULL COMMENT '司机电话',\n" +
|
|
|
- " `create_time` datetime(0) NULL DEFAULT NULL COMMENT '创建时间',\n" +
|
|
|
- " `update_time` datetime(0) NULL DEFAULT NULL COMMENT '更新时间',\n" +
|
|
|
- " `delete_time` datetime(0) NULL DEFAULT NULL COMMENT '删除时间',\n" +
|
|
|
- " `hide` int(11) NOT NULL DEFAULT 0,\n" +
|
|
|
- " `note` text CHARACTER SET utf8 COLLATE utf8_general_ci NULL COMMENT '备注',\n" +
|
|
|
- " `confirm_delivery_time` datetime(0) NULL DEFAULT NULL COMMENT '确认收货时间',\n" +
|
|
|
- " `factory_user_id` int(10) NOT NULL DEFAULT 0,\n" +
|
|
|
- " `order_status` varchar(255) CHARACTER SET utf8 COLLATE utf8_general_ci NULL DEFAULT NULL,\n" +
|
|
|
- " `tenant_id` varchar(255) CHARACTER SET utf8 COLLATE utf8_general_ci NULL DEFAULT NULL,\n" +
|
|
|
- " `create_dept` varchar(255) CHARACTER SET utf8 COLLATE utf8_general_ci NULL DEFAULT NULL,\n" +
|
|
|
- " `create_by` varchar(255) CHARACTER SET utf8 COLLATE utf8_general_ci NULL DEFAULT NULL,\n" +
|
|
|
- " `update_by` varchar(255) CHARACTER SET utf8 COLLATE utf8_general_ci NULL DEFAULT NULL,\n" +
|
|
|
- " PRIMARY KEY (`id`) USING BTREE,\n" +
|
|
|
- " INDEX `order_num`(`order_num`) USING BTREE\n" +
|
|
|
- ") ENGINE = InnoDB AUTO_INCREMENT = 1909434245425213445 CHARACTER SET = utf8 COLLATE = utf8_general_ci COMMENT = '订单表' ROW_FORMAT = Dynamic;";
|
|
|
+ public String getSQL(MessageContext messageContext) {
|
|
|
+ messageContext.setStatus(2);
|
|
|
+ String keyName = "ai_tableStructMap";
|
|
|
+ StringBuilder structBuilder = new StringBuilder();
|
|
|
+ List<String> tableNames = messageContext.getTableNames();
|
|
|
+ for (String tableName : tableNames) {
|
|
|
+ Object cacheMapValue = RedisUtils.getCacheMapValue(keyName, tableName);
|
|
|
+ String struct = "";
|
|
|
+ //redis中没有数据,从配置文件中拿
|
|
|
+ if (Objects.isNull(cacheMapValue)) {
|
|
|
+ Properties properties = new Properties();
|
|
|
+ try {
|
|
|
+ properties.load(this.getClass().getResourceAsStream("/ai/tableStruct.properties"));
|
|
|
+ } catch (IOException e) {
|
|
|
+ log.error("读取文件失败: {}", e.getMessage());
|
|
|
+ throw new RuntimeException(e);
|
|
|
+ }
|
|
|
+ struct = properties.getProperty(tableName);
|
|
|
+ RedisUtils.setCacheMapValue(keyName, tableName, struct);
|
|
|
+ }else{
|
|
|
+ struct = (String) cacheMapValue;
|
|
|
+ }
|
|
|
+ structBuilder.append(struct).append("\n");
|
|
|
+ }
|
|
|
+ String struct = structBuilder.toString();
|
|
|
String prompt = new StringBuilder("用户需求:\n")
|
|
|
- .append(context)
|
|
|
+ .append(messageContext.getMessage())
|
|
|
.append("涉及的表和结构如下:\n")
|
|
|
.append(struct)
|
|
|
.append("\n")
|
|
@@ -67,23 +85,25 @@ public class AIHandlerImpl implements AIHandler {
|
|
|
.append(LocalDateTime.now())
|
|
|
.append("\n")
|
|
|
.append("请根据用户需求,生成对应的完整的可用的查询数据SQL语句,最大数据量限制为一千条,返回json格式的字符串{“sql”:“”}").toString();
|
|
|
- String response = sendMessage(prompt);
|
|
|
+// .append("请根据用户需求,生成对应的完整的可用的查询数据SQL语句,涉及到时间比较时使用'yy-mm-dd HH:MM:ss'格式字符串,最大数据量限制为一千条,返回json格式的字符串{“sql”:“”}").toString();
|
|
|
+ String json = sendMessage(prompt);
|
|
|
String sql = null;
|
|
|
try {
|
|
|
- Map map = objectMapper.readValue(response, Map.class);
|
|
|
+ Map map = objectMapper.readValue(json, Map.class);
|
|
|
sql = (String) map.get("sql");
|
|
|
- System.out.println("sql = " + sql);
|
|
|
} catch (JsonProcessingException e) {
|
|
|
- System.err.println("response = " + response);
|
|
|
+ log.error("LLM处理数据返回的数据格式错误: {}", e.getMessage());
|
|
|
throw new RuntimeException(e);
|
|
|
}
|
|
|
+ log.info("sql = {}", sql);
|
|
|
+ messageContext.setSql(sql);
|
|
|
return sql;
|
|
|
}
|
|
|
|
|
|
@Override
|
|
|
- public String getData(String content, String sql) {
|
|
|
- System.out.println("sql = \n" + sql);
|
|
|
- List<Map<String,Object>> maps = jdbcTemplate.queryForList(sql);
|
|
|
+ public String getData(MessageContext messageContext) {
|
|
|
+ messageContext.setStatus(3);
|
|
|
+ List<Map<String,Object>> maps = jdbcTemplate.queryForList(messageContext.getSql());
|
|
|
StringBuilder dataStringBuilder = new StringBuilder();
|
|
|
StringBuilder titleStringBuilder = new StringBuilder();
|
|
|
for (int i = 0; i < maps.size(); i++) {
|
|
@@ -98,26 +118,31 @@ public class AIHandlerImpl implements AIHandler {
|
|
|
}
|
|
|
dataStringBuilder.append("\n");
|
|
|
}
|
|
|
- return titleStringBuilder + "\n" + dataStringBuilder;
|
|
|
+ String data = titleStringBuilder + "\n" + dataStringBuilder;
|
|
|
+ messageContext.setData(data);
|
|
|
+ return data;
|
|
|
}
|
|
|
|
|
|
@Override
|
|
|
- public Object dataHandler(String context, String data) {
|
|
|
- System.out.println("data = \n" + data);
|
|
|
- String prompt = new StringBuilder("用户需求:\n").append(context)
|
|
|
+ public Object dataHandler(MessageContext messageContext) {
|
|
|
+ messageContext.setStatus(4);
|
|
|
+ String data = messageContext.getData();
|
|
|
+ log.info("data = \n{}", data);
|
|
|
+ String prompt = new StringBuilder("用户需求:\n").append(messageContext.getMessage())
|
|
|
.append("涉及的数据如下:\n").append(data)
|
|
|
.append("\n")
|
|
|
- .append("请根据用户需求和数据给出结论(中文的),并选择表格或者图表(折线图、柱状图、饼图、热力图)将数据展示出来,表格状态为1;图标状态为2,表格或者图表使用svg格式的xml,以json的格式返回.{'conclusion':'','data':'','status':0}").toString();
|
|
|
+// .append("请根据用户需求和数据给出结论(中文的),并选择表格或者图表(折线图、柱状图、饼图、热力图)将数据展示出来,表格状态为1;图标状态为2,表格或者图表使用svg格式的xml,分辨率为1920*1080,以json的格式返回.{'conclusion':'','data':'','status':0}").toString();
|
|
|
// .append("请根据用户需求和数据给出结论(中文的),并选择表格或者图表(折线图、柱状图、饼图、热力图)将数据展示出来,表格状态为1,markdown格式的字符串;图标状态为2,echarts的配置对象json格式的字符串,以json的格式返回.{'conclusion':'','data':'','status':0}").toString();
|
|
|
- String mes = sendMessage(prompt);
|
|
|
- System.out.println("mes = " + mes);
|
|
|
+ .append("请根据用户需求和数据给出结论(中文的),并选择表格或者图表(折线图、柱状图、饼图、热力图)将数据展示出来,表格状态为1,使用HTML的table标签;图标状态为2,echarts的配置对象json格式的字符串,以json的格式返回.{'conclusion':'','data':'','status':0}").toString();
|
|
|
+ String json = sendMessage(prompt);
|
|
|
Map map = null;
|
|
|
try {
|
|
|
- map = objectMapper.readValue(mes, Map.class);
|
|
|
+ map = objectMapper.readValue(json, Map.class);
|
|
|
} catch (JsonProcessingException e) {
|
|
|
- System.err.println("LLM处理数据返回的数据格式错误:");
|
|
|
+ log.error("LLM处理数据返回的数据格式错误: {}", e.getMessage());
|
|
|
throw new RuntimeException(e);
|
|
|
}
|
|
|
+ messageContext.setResult(map);
|
|
|
return map;
|
|
|
}
|
|
|
|
|
@@ -126,4 +151,53 @@ public class AIHandlerImpl implements AIHandler {
|
|
|
DeepSeekHttpResponseData deepSeekHttpResponseData = DeepSeekAIUtil.doChat(DeepSeekAIUtil.createDefaultRequestData(prompt));
|
|
|
return deepSeekHttpResponseData.getChoices().get(0).getMessage().getContent();
|
|
|
}
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public void handle(MessageContext messageContext, SseEmitter emitter) {
|
|
|
+ try {
|
|
|
+ //获取涉及的表名,通过LLM
|
|
|
+ emitter.send("1正在分析涉及的数据表");
|
|
|
+ List<String> tableNames = getTableNames(messageContext);
|
|
|
+ //获取sql,通过LLM
|
|
|
+ emitter.send("2正在组织sql");
|
|
|
+ String sql = getSQL(messageContext);
|
|
|
+ //查数据,通过JDBCTemplate
|
|
|
+ emitter.send("3正在查询数据");
|
|
|
+ String data = getData(messageContext);
|
|
|
+ //处理数据,通过LLM
|
|
|
+ emitter.send("4正在处理数据");
|
|
|
+ Object object = dataHandler(messageContext);
|
|
|
+ emitter.send(object);
|
|
|
+ } catch (IOException e) {
|
|
|
+ log.error("sseEmitter error: {}", e.getMessage());
|
|
|
+ throw new RuntimeException(e);
|
|
|
+ }
|
|
|
+ emitter.complete();
|
|
|
+ }
|
|
|
+
|
|
|
+ /**
|
|
|
+ * 从文件中获取涉及的表名
|
|
|
+ * @return
|
|
|
+ */
|
|
|
+ private List<String> loadTableNameProperty() {
|
|
|
+ List<String> tableNamesList = new ArrayList<>();
|
|
|
+ InputStream inputStream = this.getClass().getResourceAsStream("/ai/tableNames.properties");
|
|
|
+ InputStreamReader a = new InputStreamReader(inputStream, StandardCharsets.UTF_8);
|
|
|
+ Properties properties = new Properties();
|
|
|
+
|
|
|
+ try {
|
|
|
+ properties.load(a);
|
|
|
+ } catch (IOException e) {
|
|
|
+ System.err.println("读取tables.properties文件失败");
|
|
|
+ throw new RuntimeException(e);
|
|
|
+ }
|
|
|
+ Enumeration<Object> keys = properties.keys();
|
|
|
+ Iterator<Object> iterator = keys.asIterator();
|
|
|
+ while (iterator.hasNext()) {
|
|
|
+ String key = iterator.next().toString();
|
|
|
+ String property = properties.getProperty(key);
|
|
|
+ tableNamesList.add(key + ":" + property);
|
|
|
+ }
|
|
|
+ return tableNamesList;
|
|
|
+ }
|
|
|
}
|