|
@@ -1,13 +1,18 @@
|
|
|
package org.dromara.ai.handler;
|
|
|
|
|
|
+import cn.hutool.core.util.IdUtil;
|
|
|
import com.fasterxml.jackson.core.JsonProcessingException;
|
|
|
import com.fasterxml.jackson.databind.ObjectMapper;
|
|
|
import lombok.RequiredArgsConstructor;
|
|
|
import lombok.extern.slf4j.Slf4j;
|
|
|
+import org.dromara.ai.domain.Message;
|
|
|
import org.dromara.ai.domain.MessageContext;
|
|
|
+import org.dromara.ai.domain.VO.MessageVO;
|
|
|
import org.dromara.ai.domain.deepseek.DeepSeekHttpResponseData;
|
|
|
+import org.dromara.ai.service.MessageService;
|
|
|
import org.dromara.ai.util.DeepSeekAIUtil;
|
|
|
import org.dromara.common.redis.utils.RedisUtils;
|
|
|
+import org.springframework.beans.BeanUtils;
|
|
|
import org.springframework.jdbc.core.JdbcTemplate;
|
|
|
import org.springframework.stereotype.Component;
|
|
|
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
|
|
@@ -22,6 +27,7 @@ import java.util.*;
|
|
|
public class AIHandlerImpl implements AIHandler {
|
|
|
private final ObjectMapper objectMapper;
|
|
|
private final JdbcTemplate jdbcTemplate;
|
|
|
+ private final MessageService messageService;
|
|
|
@Override
|
|
|
public List<String> getTableNames(MessageContext messageContext) {
|
|
|
messageContext.setStatus(1);
|
|
@@ -35,7 +41,7 @@ public class AIHandlerImpl implements AIHandler {
|
|
|
for (String tableName : tableNamesList) {
|
|
|
tableNamesStringBuilder.append(tableName).append("\n");
|
|
|
}
|
|
|
- String prompt = new StringBuilder("用户需求:\n").append(messageContext.getMessage())
|
|
|
+ String prompt = new StringBuilder("用户需求:\n").append(messageContext.getMessage().getContent())
|
|
|
.append("\n")
|
|
|
.append(tableNamesStringBuilder)
|
|
|
.append("请你根据用户需求和现有的数据表,选择满足用户需求的数据表,返回数据表的名称,以json的格式返回{'tableNames':[]}").toString();
|
|
@@ -77,7 +83,7 @@ public class AIHandlerImpl implements AIHandler {
|
|
|
}
|
|
|
String struct = structBuilder.toString();
|
|
|
String prompt = new StringBuilder("用户需求:\n")
|
|
|
- .append(messageContext.getMessage())
|
|
|
+ .append(messageContext.getMessage().getContent())
|
|
|
.append("涉及的表和结构如下:\n")
|
|
|
.append(struct)
|
|
|
.append("\n")
|
|
@@ -128,7 +134,7 @@ public class AIHandlerImpl implements AIHandler {
|
|
|
messageContext.setStatus(4);
|
|
|
String data = messageContext.getData();
|
|
|
log.info("data = \n{}", data);
|
|
|
- String prompt = new StringBuilder("用户需求:\n").append(messageContext.getMessage())
|
|
|
+ String prompt = new StringBuilder("用户需求:\n").append(messageContext.getMessage().getContent())
|
|
|
.append("涉及的数据如下:\n").append(data)
|
|
|
.append("\n")
|
|
|
// .append("请根据用户需求和数据给出结论(中文的),并选择表格或者图表(折线图、柱状图、饼图、热力图)将数据展示出来,表格状态为1;图标状态为2,表格或者图表使用svg格式的xml,分辨率为1920*1080,以json的格式返回.{'conclusion':'','data':'','status':0}").toString();
|
|
@@ -154,25 +160,37 @@ public class AIHandlerImpl implements AIHandler {
|
|
|
|
|
|
@Override
|
|
|
public void handle(MessageContext messageContext, SseEmitter emitter) {
|
|
|
+ //todo:优化逻辑,提供更丰富的上下文
|
|
|
try {
|
|
|
//获取涉及的表名,通过LLM
|
|
|
- emitter.send("1正在分析涉及的数据表");
|
|
|
+ emitter.send(objectMapper.writeValueAsString(new MessageVO(1,"正在分析涉及的数据表")));
|
|
|
List<String> tableNames = getTableNames(messageContext);
|
|
|
//获取sql,通过LLM
|
|
|
- emitter.send("2正在组织sql");
|
|
|
+ emitter.send(objectMapper.writeValueAsString(new MessageVO(2,"正在组织sql")));
|
|
|
String sql = getSQL(messageContext);
|
|
|
//查数据,通过JDBCTemplate
|
|
|
- emitter.send("3正在查询数据");
|
|
|
+ emitter.send(objectMapper.writeValueAsString(new MessageVO(3,"正在查询数据")));
|
|
|
String data = getData(messageContext);
|
|
|
//处理数据,通过LLM
|
|
|
- emitter.send("4正在处理数据");
|
|
|
+ emitter.send(objectMapper.writeValueAsString(new MessageVO(4,"正在处理数据")));
|
|
|
Object object = dataHandler(messageContext);
|
|
|
- emitter.send(object);
|
|
|
+ //构建ai的相应消息
|
|
|
+ Message message = new Message();
|
|
|
+ message.setId(IdUtil.getSnowflakeNextId());
|
|
|
+ message.setConversationId(messageContext.getMessage().getConversationId());
|
|
|
+ message.setRole(2);
|
|
|
+ message.setTimestamp(LocalDateTime.now());
|
|
|
+ message.setContent(object.toString().getBytes(StandardCharsets.UTF_8));
|
|
|
+ messageService.save(message);
|
|
|
+ MessageVO messageVO = new MessageVO(0,"");
|
|
|
+ BeanUtils.copyProperties(message, messageVO);
|
|
|
+ emitter.send(messageVO);
|
|
|
} catch (IOException e) {
|
|
|
log.error("sseEmitter error: {}", e.getMessage());
|
|
|
throw new RuntimeException(e);
|
|
|
+ }finally {
|
|
|
+ emitter.complete();
|
|
|
}
|
|
|
- emitter.complete();
|
|
|
}
|
|
|
|
|
|
/**
|