Java 如何对接 ChatGPT API（简单记录）

技术选型

springboot+mybatis-plus

实现效果

  • 通过 Java 调用 ChatGPT API 实现对话，并通过与前端建立的 WebSocket 连接将 ChatGPT 生成的内容及时推送给前端。为加快 ChatGPT 的响应速度，请求采用 event stream（流式返回）的方式，从而实现了逐字出现的效果

实现过程

Java 对接 ChatGPT API
  • 使用 Java 原生的 HttpURLConnection 完成网络请求
  • 在发送网络请求时，将请求体（JSON）中的 "stream" 设置为 true，表示以 event stream（SSE）的方式返回数据；服务端会逐行返回 "data: {...}"，并以 "data: [DONE]" 结束
// Build the chat-completions request body and POST it to the OpenAI API through the HTTP proxy.
        String url = "https://api.openai.com/v1/chat/completions";
        HashMap<String, Object> bodymap = new HashMap<>();

        bodymap.put("model", "gpt-3.5-turbo");
        bodymap.put("temperature", 0.7);
        bodymap.put("messages", messagelist);
        // Streaming is requested in the JSON body; the API then answers with "data: {...}" lines (event stream).
        bodymap.put("stream", true);
        Gson gson = new Gson();
        String s = gson.toJson(bodymap);
        URL url1 = new URL(url);
        HttpURLConnection conn = (HttpURLConnection) url1.openConnection(new Proxy(Proxy.Type.HTTP, new InetSocketAddress(host, port)));
        conn.setRequestMethod("POST");
        conn.setRequestProperty("Authorization", "Bearer " + ApiKey);
        conn.setRequestProperty("Content-Type", "application/json");
        conn.setDoOutput(true);
        // Write the request body. try-with-resources closes the writer (and with it the connection's
        // output stream) even if write() throws, so the stream is never leaked.
        try (BufferedWriter writer = new BufferedWriter(
                new OutputStreamWriter(conn.getOutputStream(), java.nio.charset.StandardCharsets.UTF_8))) {
            writer.write(s);
        }
  • 读取返回值
// Read the event-stream response line by line, push every content delta to the client of this
// conversation over the WebSocket, then persist the assembled answer as an "assistant" chat row.
        StringBuilder answer = new StringBuilder();
        // sendMessageByUserId only looks the session up in WebSocket's static map, so one instance is enough.
        WebSocket webSocket = new WebSocket();
        // Decode explicitly as UTF-8: the platform default charset would garble non-ASCII (e.g. Chinese) text.
        try (BufferedReader bufferedReader = new BufferedReader(
                new InputStreamReader(conn.getInputStream(), java.nio.charset.StandardCharsets.UTF_8))) {
            String line;
            while ((line = bufferedReader.readLine()) != null) {
                // Payload lines look like "data: {...}"; blank lines separate events and the stream
                // ends with "data: [DONE]". Only the leading prefix is stripped so that a "data:"
                // occurring inside the generated content is left intact.
                if (!line.startsWith("data:")) {
                    continue;
                }
                String payload = line.substring("data:".length()).trim();
                if (payload.isEmpty() || "[DONE]".equals(payload)) {
                    continue;
                }
                JsonElement jsonElement = JsonParser.parseString(payload);
                if (!jsonElement.isJsonObject()) {
                    continue;
                }
                JsonElement choicesElement = jsonElement.getAsJsonObject().get("choices");
                if (choicesElement == null || !choicesElement.isJsonArray()) {
                    continue;
                }
                JsonArray choices = choicesElement.getAsJsonArray();
                if (choices.size() == 0) {
                    continue;
                }
                JsonObject choice = choices.get(0).getAsJsonObject();
                // The final chunk carries an empty delta; "content" may also be absent or JSON null.
                JsonElement deltaElement = choice.get("delta");
                if (deltaElement == null || !deltaElement.isJsonObject()) {
                    continue;
                }
                JsonElement contentElement = deltaElement.getAsJsonObject().get("content");
                if (contentElement == null || contentElement.isJsonNull()) {
                    continue;
                }
                // Forward the fragment to the front end immediately (this produces the typing effect).
                String content = contentElement.getAsString();
                BaseResponse<String> success = ResultUtils.success(content);
                webSocket.sendMessageByUserId(conversionid, gson.toJson(success));
                answer.append(content);
                // Echo to the console.
                System.out.print(content);
            }
        }
        String context = answer.toString();
        // Save ChatGPT's complete reply to the database so it becomes part of the conversation history.
        Chat entity = new Chat();
        entity.setContext(context);
        entity.setRole("assistant");
        entity.setConversionid(conversionid);
        chatService.save(entity);
    }
实现 WebSocket，与前端建立连接
@ServerEndpoint(value = "/websocket/{ConversionId}")
@Component
public class WebSocket {

    private static ChatGptUntil chatGptUntil;

    private static ChatService chatService;

    private static ConversionService conversionService;

    @Resource
    public void setConversionService(ConversionService conversionService) {
        WebSocket.conversionService = conversionService;
    }

    @Resource
    public void setChatService(ChatService chatService) {
        WebSocket.chatService = chatService;
    }

    @Resource
    public void setChatGptUntil(ChatGptUntil chatGptUntil) {
        WebSocket.chatGptUntil = chatGptUntil;
    }

    private final static Logger logger = LogManager.getLogger(WebSocket.class);

    /**
     * 静态变量,用来记录当前在线连接数。应该把它设计成线程安全的
     */

    private static int onlineCount = 0;

    /**
     * concurrent包的线程安全Map,用来存放每个客户端对应的MyWebSocket对象
     */
    private static ConcurrentHashMap<String, WebSocket> webSocketMap = new ConcurrentHashMap<>();

    /**
     * 与某个客户端的连接会话,需要通过它来给客户端发送数据
     */

    private Session session;
    private Long ConversionId;


    /**
     * 连接建立成功调用的方法
     */
    @OnOpen
    public void onOpen(Session session, @PathParam("ConversionId") Long ConversionId) {
        this.session = session;
        this.ConversionId = ConversionId;
        //加入map
        webSocketMap.put(ConversionId.toString(), this);
        addOnlineCount();           //在线数加1
        logger.info("对话{}连接成功,当前在线人数为{}", ConversionId, getOnlineCount());
        try {
            sendMessage(String.valueOf(this.session.getQueryString()));
        } catch (IOException e) {
            logger.error("IO异常");
        }
    }


    /**
     * 连接关闭调用的方法
     */
    @OnClose
    public void onClose() {
        //从map中删除
        webSocketMap.remove(ConversionId.toString());
        subOnlineCount();           //在线数减1
        logger.info("对话{}关闭连接!当前在线人数为{}", ConversionId, getOnlineCount());
    }

    /**
     * 收到客户端消息后调用的方法
     *
     * @param message 客户端发送过来的消息
     */
    @OnMessage
    public void onMessage(String message, Session session) throws IOException {
        logger.info("来自客户端对话:{} 消息:{}", ConversionId, message);


        Gson gson = new Gson();

//        ChatMessage chatMessage = gson.fromJson(message, ChatMessage.class);

        System.out.println(message);

//        Long conversionid = chatMessage.getConversionid();
//        if (conversionid == null) {
//            BaseResponse baseResponse = ResultUtils.error(4000, "请指明是哪个对话");
//            String s = gson.toJson(baseResponse);
//            session.getBasicRemote().sendText(s);
//        }

        if (message == null) {
            BaseResponse baseResponse = ResultUtils.error(4000, "请指明是该对话的用途");
            String s = gson.toJson(baseResponse);
            session.getBasicRemote().sendText(s);
        }
//        将对话保存到数据库中
        Chat entity = new Chat();
        entity.setContext(message);
        entity.setConversionid(this.ConversionId);
        entity.setRole("user");
        boolean save = chatService.save(entity);

        if (!save) {
            BaseResponse baseResponse = ResultUtils.error(500, "数据库出现错误");
            String s = gson.toJson(baseResponse);
            session.getBasicRemote().sendText(s);
        }


//        查询出身份
        Conversion byId = conversionService.getById(this.ConversionId);
        String instructions = byId.getInstructions();// 指令
//     给予chatgot身份
        ArrayList<ChatModel> chatModels = new ArrayList<>();
//        ChatModel scene = new ChatModel("user", instructions);
//        chatModels.add(scene);

        LambdaQueryWrapper<Chat> queryWrapper = new LambdaQueryWrapper<>();
        // 按照修改时间进行升序排序
        queryWrapper.eq(Chat::getConversionid, byId.getId()).orderByDesc(Chat::getUpdatedtime);
        List<Chat> list = chatService.list(queryWrapper);

//        查询之前的对话记录
        List<ChatModel> collect = list.stream().map(chat -> {
            ChatModel chatModel = new ChatModel();
            chatModel.setRole(chat.getRole());
            chatModel.setContent(chat.getContext());
//            BeanUtils.copyProperties(chat, chatModel);
            return chatModel;
        }).collect(Collectors.toList());
        chatModels.addAll(collect);


        chatGptUntil.getRespost(this.ConversionId, chatModels);
//        if (chatGptUntil==null){
//            System.out.println("chatuntil是空");
//        }
//
//        if (stringRedisTemplate==null){
//            System.out.println("缓存是空");
//        }


        //群发消息
        /*for (String item : webSocketMap.keySet()) {
            try {
                webSocketMap.get(item).sendMessage(message);
            } catch (IOException e) {
                e.printStackTrace();
            }
        }*/
    }

    /**
     * 发生错误时调用
     *
     * @OnError
     */
    @OnError
    public void onError(Session session, Throwable error) {
        logger.error("对话错误:" + this.ConversionId + ",原因:" + error.getMessage());
        error.printStackTrace();
    }

    /**
     * 向客户端发送消息
     */
    public void sendMessage(String message) throws IOException {
        this.session.getBasicRemote().sendText(message);
        //this.session.getAsyncRemote().sendText(message);
    }

    /**
     * 通过userId向客户端发送消息
     */
    public void sendMessageByUserId(Long ConversionId, String message) throws IOException {
        logger.info("服务端发送消息到{},消息:{}", ConversionId, message);
        if (StrUtil.isNotBlank(ConversionId.toString()) && webSocketMap.containsKey(ConversionId.toString())) {
            webSocketMap.get(ConversionId.toString()).sendMessage(message);
        } else {
            logger.error("{}不在线", ConversionId);
        }

    }

    /**
     * 群发自定义消息
     */
    public static void sendInfo(String message) {
        for (String item : webSocketMap.keySet()) {
            try {
                webSocketMap.get(item).sendMessage(message);
            } catch (IOException e) {
                continue;
            }
        }
    }

    public static synchronized int getOnlineCount() {
        return onlineCount;
    }

    public static synchronized void addOnlineCount() {
        WebSocket.onlineCount++;
    }

    public static synchronized void subOnlineCount() {
        WebSocket.onlineCount--;
    }

}
  • 在本项目中，通过对话 id 标识用户与 ChatGPT 的每一次交互，并将该对话下的所有内容保存到数据库中，从而实现了对话的长久保存