1) 技术选型
WebSSH 需要实时的双向数据交互,因此选用长连接的 WebSocket;为了方便开发,框架选用 Spring Boot;另外还调研了 Java 中用于连接 SSH 的 Apache MINA SSHD,以及用于实现前端 shell 页面的 xterm.js。
2)添加maven依赖
<!-- Mina sshd 支持 -->
<dependency>
<groupId>org.apache.sshd</groupId>
<artifactId>sshd-core</artifactId>
<version>2.9.2</version>
</dependency>
<!-- WebSocket 支持 -->
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-websocket</artifactId>
<version>2.7.6</version>
</dependency>
3) websocket配置
package cn.cloud.common.config;
import cn.cloud.common.handler.WebSSHWebSocketHandler;
import cn.cloud.common.interceptor.WebSocketInterceptor;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.socket.config.annotation.EnableWebSocket;
import org.springframework.web.socket.config.annotation.WebSocketConfigurer;
import org.springframework.web.socket.config.annotation.WebSocketHandlerRegistry;
import javax.annotation.Resource;
/**
 * Registers the WebSSH WebSocket endpoint.
 *
 * <p>Clients connect to {@code /ws/webssh}. Each handshake first runs through
 * {@link WebSocketInterceptor}, and the established connection is then served by
 * {@link WebSSHWebSocketHandler}.
 */
@Configuration
@EnableWebSocket
public class WebSSHWebSocketConfig implements WebSocketConfigurer {

    /** URL path of the WebSSH endpoint. */
    private static final String WEBSSH_ENDPOINT = "/ws/webssh";

    @Resource
    private WebSSHWebSocketHandler webSSHWebSocketHandler;

    /**
     * Binds the handler to its path, attaches the handshake interceptor and accepts any origin.
     *
     * @param webSocketHandlerRegistry registry supplied by Spring
     */
    @Override
    public void registerWebSocketHandlers(WebSocketHandlerRegistry webSocketHandlerRegistry) {
        // NOTE(review): "*" accepts handshakes from any Origin, which allows cross-site WebSocket
        // hijacking; consider restricting this to the real front-end origin(s).
        webSocketHandlerRegistry
                .addHandler(webSSHWebSocketHandler, WEBSSH_ENDPOINT)
                .addInterceptors(new WebSocketInterceptor())
                .setAllowedOrigins("*");
    }
}
4) websocket处理器配置
package cn.cloud.common.handler;
import cn.cloud.common.pojo.OperateConstant;
import cn.cloud.common.service.WebSSHService;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.data.redis.core.StringRedisTemplate;
import org.springframework.stereotype.Component;
import org.springframework.web.socket.*;
import javax.annotation.Resource;
import java.io.IOException;
/**
 * WebSocket handler for WebSSH.
 *
 * <p>Delegates the SSH lifecycle to {@link WebSSHService} and, when a connection closes, releases
 * that client's slot in the per-IP connection counter kept in the Redis hash
 * {@link OperateConstant#IP}.
 */
@Component
public class WebSSHWebSocketHandler implements WebSocketHandler {
    @Resource
    private WebSSHService webSSHService;
    @Resource
    private StringRedisTemplate stringRedisTemplate;
    private final Logger LOGGER = LoggerFactory.getLogger(WebSSHWebSocketHandler.class);

    /**
     * Called after the handshake completes. Logs the new connection and asks the service to prepare
     * an SSH client for it.
     */
    @Override
    public void afterConnectionEstablished(WebSocketSession webSocketSession) {
        LOGGER.info("与{}建立websocket连接", webSocketSession.getAttributes().get(OperateConstant.USER_UUID_KEY));
        webSSHService.initConnection(webSocketSession);
    }

    /**
     * Called for every inbound frame.
     *
     * <p>Text frames carry the front-end request and are forwarded to the service. Any other frame
     * type is logged and ignored.
     */
    @Override
    public void handleMessage(WebSocketSession webSocketSession, WebSocketMessage<?> webSocketMessage) throws Exception {
        if (webSocketMessage instanceof TextMessage) {
            webSSHService.commandHandler(((TextMessage) webSocketMessage).getPayload(), webSocketSession);
        } else {
            LOGGER.error("Unexpected WebSocket message type: {}", webSocketMessage);
        }
    }

    /**
     * Called on a transport-level error.
     *
     * <p>The throwable is logged with its stack trace; the previous code discarded it, which left
     * the failure undiagnosable.
     */
    @Override
    public void handleTransportError(WebSocketSession webSocketSession, Throwable throwable) throws Exception {
        LOGGER.error("数据传输错误", throwable);
    }

    /**
     * Called after the connection closes. Tears down the SSH side and releases the per-IP slot.
     */
    @Override
    public void afterConnectionClosed(WebSocketSession webSocketSession, CloseStatus closeStatus) throws IOException {
        LOGGER.info("与{}断开websocket连接", webSocketSession.getAttributes().get(OperateConstant.USER_UUID_KEY));
        try {
            webSSHService.closeConnection(webSocketSession);
        } finally {
            // Release the slot even if closing the SSH side failed; otherwise the IP stays counted.
            updateIpCount(webSocketSession);
        }
    }

    /**
     * Decrements this session's entry in the Redis per-IP connection-count hash.
     *
     * <p>Does nothing when the session has no IP attribute or the hash has no entry for it (for
     * example after the hash has expired). The stored count never drops below zero.
     */
    private void updateIpCount(WebSocketSession webSocketSession) {
        Object ipAttribute = webSocketSession.getAttributes().get(OperateConstant.IP);
        if (ipAttribute == null) {
            return;
        }
        String ip = String.valueOf(ipAttribute);
        Object current = stringRedisTemplate.opsForHash().get(OperateConstant.IP, ip);
        if (current == null) {
            // Nothing recorded for this IP, so there is nothing to release.
            return;
        }
        int count;
        try {
            count = Integer.parseInt(String.valueOf(current).trim());
        } catch (NumberFormatException e) {
            LOGGER.warn("Unparseable connection count '{}' for ip {}", current, ip);
            return;
        }
        // StringRedisTemplate serializes hash values as Strings. Passing an Integer (as before)
        // makes its serializer throw ClassCastException.
        stringRedisTemplate.opsForHash().put(OperateConstant.IP, ip, String.valueOf(Math.max(count - 1, 0)));
    }

    @Override
    public boolean supportsPartialMessages() {
        return false;
    }
}
5) websocket拦截器配置
package cn.cloud.common.interceptor;
import cn.cloud.common.pojo.OperateConstant;
import cn.cloud.common.util.RedisUtil;
import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServerHttpResponse;
import org.springframework.http.server.ServletServerHttpRequest;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.server.HandshakeInterceptor;
import javax.servlet.http.HttpServletRequest;
import java.util.Map;
import java.util.UUID;
/**
 * Handshake interceptor for the WebSSH endpoint.
 *
 * <p>It resolves the client IP and refuses the handshake once that IP already holds
 * {@link #MAX_CONNECTIONS_PER_IP} connections. Connections are counted in the Redis hash
 * {@link OperateConstant#IP}. Otherwise it increments the count and stores a random user id and the
 * IP in the WebSocket session attributes. The connection-close callback is expected to decrement
 * the count again.
 *
 * <p>NOTE(review): the limit check and the increment are separate Redis round-trips, so concurrent
 * handshakes from one IP can overshoot the limit. An atomic HINCRBY would close that gap.
 */
public class WebSocketInterceptor implements HandshakeInterceptor {
    /** Maximum number of simultaneous WebSocket connections allowed per client IP. */
    private static final int MAX_CONNECTIONS_PER_IP = 10;
    /** Lifetime in seconds (24 h) of the whole counter hash; reset whenever a new IP is added. */
    private static final long COUNTER_TTL_SECONDS = 24 * 60 * 60;
    /** Request headers that may carry the originating client IP, in priority order. */
    private static final String[] IP_HEADERS = {
            "X-Forwarded-For", "X-Real-Ip", "Proxy-Client-Ip",
            "WL-Proxy-Client-Ip", "HTTP_CLIENT_IP", "HTTP_X_FORWARDED_FOR"
    };
    // Not injected. This works only because RedisUtil keeps its RedisTemplate in a static field
    // that Spring populates when it creates the RedisUtil bean.
    private RedisUtil redisUtil = new RedisUtil();

    /**
     * Applies the per-IP limit and seeds the session attributes.
     *
     * @return {@code false} (rejecting the handshake) for non-servlet requests or when the IP is at
     *     its limit; {@code true} otherwise
     */
    @Override
    public boolean beforeHandshake(ServerHttpRequest serverHttpRequest, ServerHttpResponse serverHttpResponse, WebSocketHandler webSocketHandler, Map<String, Object> map) throws Exception {
        if (!(serverHttpRequest instanceof ServletServerHttpRequest)) {
            return false;
        }
        String ip = getIpAddress(((ServletServerHttpRequest) serverHttpRequest).getServletRequest());
        // Refuse the WebSocket connection once this IP has reached its limit.
        if (isLimitExceededRedis(ip)) {
            return false;
        }
        updateIpRequestCountRedis(ip);
        // A random per-connection id and the client IP, exposed to the handler as session attributes.
        map.put(OperateConstant.USER_UUID_KEY, UUID.randomUUID().toString().replace("-", ""));
        map.put(OperateConstant.IP, ip);
        return true;
    }

    /**
     * Releases the slot taken in {@link #beforeHandshake} when the handshake then fails
     * ({@code e != null}).
     *
     * <p>In that case no WebSocket session is established, so the close callback that normally
     * decrements the count never runs.
     */
    @Override
    public void afterHandshake(ServerHttpRequest serverHttpRequest, ServerHttpResponse serverHttpResponse, WebSocketHandler webSocketHandler, Exception e) {
        if (e == null || !(serverHttpRequest instanceof ServletServerHttpRequest)) {
            return;
        }
        String ip = getIpAddress(((ServletServerHttpRequest) serverHttpRequest).getServletRequest());
        int count = readCount(ip);
        if (count > 0) {
            redisUtil.hput(OperateConstant.IP, ip, count - 1);
        }
    }

    /**
     * Resolves the client IP, preferring proxy-supplied headers over the socket address.
     *
     * <p>X-Forwarded-For may list several hops ("client, proxy1, proxy2"). Only the first
     * (originating) address is used, so the Redis counter is keyed per client rather than per
     * proxy chain.
     *
     * <p>NOTE(review): these headers are client-controlled unless a trusted proxy overwrites them,
     * so a client can spoof them to evade the per-IP limit.
     */
    private String getIpAddress(HttpServletRequest request) {
        for (String header : IP_HEADERS) {
            String value = request.getHeader(header);
            if (!isKnownIp(value)) {
                continue;
            }
            int comma = value.indexOf(',');
            String first = (comma >= 0 ? value.substring(0, comma) : value).trim();
            if (isKnownIp(first)) {
                return first;
            }
        }
        return request.getRemoteAddr();
    }

    /** Returns {@code true} for a non-empty header value other than the placeholder "unknown". */
    private static boolean isKnownIp(String value) {
        return value != null && !value.isEmpty() && !"unknown".equalsIgnoreCase(value);
    }

    /**
     * Returns {@code true} when {@code ip} already holds the maximum number of connections.
     * Creates a zero entry first if the IP has none.
     */
    private boolean isLimitExceededRedis(String ip) {
        // Kept in Redis so the close callback can later read and decrement the same counter.
        if (redisUtil.hget(OperateConstant.IP, ip) == null) {
            redisUtil.hput(OperateConstant.IP, ip, 0);
            redisUtil.expire(OperateConstant.IP, COUNTER_TTL_SECONDS);
        }
        // The counter holds connections that are already open, so a new one is allowed only below
        // the limit. The former ">" let an IP reach MAX + 1 connections.
        return readCount(ip) >= MAX_CONNECTIONS_PER_IP;
    }

    /** Records one more open connection for {@code ip}. */
    private void updateIpRequestCountRedis(String ip) {
        redisUtil.hput(OperateConstant.IP, ip, readCount(ip) + 1);
    }

    /** Reads the stored count for {@code ip}, treating a missing entry (e.g. expired) as 0. */
    private int readCount(String ip) {
        Object value = redisUtil.hget(OperateConstant.IP, ip);
        return value == null ? 0 : Integer.parseInt(String.valueOf(value).trim());
    }
}
6) Mina SSHD + WebSocket 核心业务逻辑实现
package cn.cloud.common.service.impl;
import cn.cloud.common.pojo.OperateConstant;
import cn.cloud.common.pojo.WebSSHConfig;
import cn.cloud.common.pojo.WebSSHData;
import cn.cloud.common.pojo.WebSSHInfo;
import cn.cloud.common.service.WebSSHService;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.apache.commons.lang.StringUtils;
import org.apache.sshd.client.SshClient;
import org.apache.sshd.client.channel.ChannelShell;
import org.apache.sshd.client.channel.ClientChannelEvent;
import org.apache.sshd.client.future.ConnectFuture;
import org.apache.sshd.client.session.ClientSession;
import org.apache.sshd.common.keyprovider.FileKeyPairProvider;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Service;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketSession;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.Reader;
import java.nio.charset.StandardCharsets;
import java.nio.file.Path;
import java.security.KeyPair;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.stream.Collectors;
import java.util.stream.Stream;
/**
 * WebSSH business logic built on Apache MINA SSHD.
 *
 * <p>State for each browser connection is a {@link WebSSHInfo} stored in {@code sshMap} under the
 * WebSocket session's {@link OperateConstant#USER_UUID_KEY} attribute.
 *
 * <p>A "connect" request opens the SSH session and shell on a pooled thread. That thread then
 * relays shell output to the browser until the shell ends. "command" requests write the given
 * text to the shell's input.
 */
@Service
public class SSHServiceImpl implements WebSSHService {
    // SSH state per WebSocket user id
    private static final Map<String, Object> sshMap = new ConcurrentHashMap<>();
    // Shared JSON mapper: Jackson ObjectMappers are thread-safe once configured and costly to build.
    private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();
    // Size of the char buffer used when relaying shell output to the browser
    private static final int READ_BUFFER_SIZE = 1024;
    // Interval between heartbeat checks (2 minutes)
    private static final long HEARTBEAT_INTERVAL_MILLIS = 2 * 60 * 1000L;
    private final Logger LOGGER = LoggerFactory.getLogger(SSHServiceImpl.class);
    // Runs each connection's blocking connect-and-relay loop off the WebSocket thread
    private final ExecutorService executorService = Executors.newCachedThreadPool();

    /**
     * Creates and starts an SSH client for a newly opened WebSocket and registers it under the
     * session's user id.
     *
     * <p>No remote connection is made yet; that happens on the "connect" request.
     */
    @Override
    public void initConnection(WebSocketSession webSocketSession) {
        try {
            SshClient sshClient = SshClient.setUpDefaultClient();
            sshClient.open();
            WebSSHInfo webSSHInfo = new WebSSHInfo();
            webSSHInfo.setSshClient(sshClient);
            webSSHInfo.setWebSocketSession(webSocketSession);
            sshMap.put(userIdOf(webSocketSession), webSSHInfo);
        } catch (Exception e) {
            // This is a failure (previously logged at INFO with the message only); keep the trace.
            LOGGER.error("init ssh client error : {}", e.getMessage(), e);
        }
    }

    /**
     * Handles one JSON request from the browser.
     *
     * @param buffer           JSON text deserialized into {@link WebSSHData}
     * @param webSocketSession the browser connection that sent it
     * @throws IOException if a notification to the browser cannot be sent
     */
    @Override
    public void commandHandler(String buffer, WebSocketSession webSocketSession) throws IOException {
        WebSSHData webSSHData = parseRequest(buffer);
        if (webSSHData == null) {
            sendMessage(webSocketSession, "Connection-Closed".getBytes());
            return;
        }
        WebSSHInfo webSSHInfo = (WebSSHInfo) sshMap.get(userIdOf(webSocketSession));
        if (webSSHInfo == null) {
            // No SSH client is registered for this session (init failed or already closed).
            return;
        }
        String operate = webSSHData.getOperate();
        if (OperateConstant.WEBSSH_OPERATE_CONNECT.equals(operate)) {
            // Connecting and relaying output block for the whole session, so run them on the pool.
            executorService.execute(() -> runConnect(webSSHInfo, webSSHData, webSocketSession));
        } else if (OperateConstant.WEBSSH_OPERATE_COMMAND.equals(operate)) {
            try {
                transToSSH(webSSHInfo.getChannelShell(), webSSHData.getCommand());
            } catch (Exception e) {
                LOGGER.error("trans to ssh error : {}", e.getMessage(), e);
                closeConnection(webSocketSession);
                sendMessage(webSocketSession, "Connection-IdleTimeout".getBytes());
            }
        } else {
            LOGGER.error("不支持的操作");
            closeConnection(webSocketSession);
        }
    }

    /**
     * Closes the shell/exec channels and the SSH client registered for {@code session}, and drops
     * the registration. Does nothing if none is registered.
     */
    @Override
    public void closeConnection(WebSocketSession session) {
        // Remove first so that concurrent callers cannot both tear down the same connection.
        WebSSHInfo webSSHInfo = (WebSSHInfo) sshMap.remove(userIdOf(session));
        if (webSSHInfo == null) {
            return;
        }
        ChannelShell shell = webSSHInfo.getChannelShell();
        if (shell != null && !shell.isClosed()) {
            shell.close(false);
            LOGGER.info("ChannelShell Closed...");
        }
        if (webSSHInfo.getChannelExec() != null) {
            webSSHInfo.getChannelExec().close(false);
            LOGGER.info("ChannelExec Closed...");
        }
        SshClient sshClient = webSSHInfo.getSshClient();
        if (sshClient != null && !sshClient.isClosed()) {
            sshClient.close(false);
            LOGGER.info("SshClient Closed...");
        }
    }

    /** Parses a browser request. Returns {@code null} (after logging) if the JSON is invalid. */
    private WebSSHData parseRequest(String buffer) {
        try {
            return OBJECT_MAPPER.readValue(buffer, WebSSHData.class);
        } catch (IOException e) {
            LOGGER.error("WebSSHData Json转换异常:{}", e.getMessage());
            return null;
        }
    }

    /**
     * Body of the pooled "connect" task: connects, and on any failure tears the connection down
     * and tells the browser "Connection-Refused".
     */
    private void runConnect(WebSSHInfo webSSHInfo, WebSSHData webSSHData, WebSocketSession webSocketSession) {
        try {
            connectToSSH(webSSHInfo, webSSHData, webSocketSession);
        } catch (Exception e) {
            LOGGER.error("connect to ssh error : {}", e.getMessage(), e);
            closeConnection(webSocketSession);
            try {
                sendMessage(webSocketSession, "Connection-Refused".getBytes());
            } catch (Exception ee) {
                LOGGER.error("failed to notify client of connection failure", ee);
            }
        }
    }

    /**
     * Connects and authenticates to the requested host, then opens an interactive shell.
     *
     * <p>The shell's output is relayed to the browser until the remote side closes the stream.
     * The SSH session is closed when the relay ends.
     */
    private void connectToSSH(WebSSHInfo webSSHInfo, WebSSHData webSSHData, WebSocketSession webSocketSession) throws Exception {
        ConnectFuture verifySession = webSSHInfo.getSshClient()
                .connect(webSSHData.getUsername(), webSSHData.getHost(), webSSHData.getPort())
                .verify(WebSSHConfig.connectTimeout);
        if (!verifySession.isConnected()) {
            LOGGER.error("Session connect failed after {} mill seconds", WebSSHConfig.connectTimeout);
            throw new Exception(
                    "Session connect failed after " + WebSSHConfig.connectTimeout + " mill seconds.");
        }
        ClientSession clientSession = verifySession.getSession();
        addIdentity(clientSession, webSSHData);
        clientSession.auth().verify(WebSSHConfig.authTimeout);
        sendMessage(webSocketSession, "Authentication-Success".getBytes());

        ChannelShell cs = clientSession.createShellChannel();
        cs.setRedirectErrorStream(true);
        // Wait for the channel-open handshake itself. The former waitFor(CLOSED, executeTimeout)
        // blocked for the full timeout on every connect while channelShell was still unset, so
        // keystrokes sent during that window were silently dropped by transToSSH.
        cs.open().verify(WebSSHConfig.executeTimeout);
        webSSHInfo.setChannelShell(cs);

        InputStream out = cs.getInvertedOut();
        // Decode as a character stream so multi-byte UTF-8 sequences split across reads are not
        // garbled; previously each raw byte chunk was decoded on its own by TextMessage(byte[]).
        try (Reader reader = new InputStreamReader(out, StandardCharsets.UTF_8)) {
            char[] buffer = new char[READ_BUFFER_SIZE];
            int n;
            // Blocks until the shell produces output or the stream is closed.
            while ((n = reader.read(buffer)) != -1) {
                sendMessage(webSocketSession, new String(buffer, 0, n).getBytes(StandardCharsets.UTF_8));
            }
        } finally {
            // Closing the session also closes its channels.
            clientSession.close();
            if (clientSession.isClosed()) {
                LOGGER.info("clientSession closed...");
            }
            if (!cs.isClosed()) {
                cs.close();
            }
        }
    }

    /**
     * Registers the credential selected by {@code authType} on the session.
     *
     * <p>NOTE(review): for KEYPAIR the key file paths come from the browser request and default to
     * the server's own ~/.ssh/id_rsa(.pub). Any WebSocket user can therefore use, or probe for,
     * key files readable by this server process.
     *
     * @throws IllegalArgumentException for an unknown auth type or when no key of the requested
     *     type is found
     */
    private void addIdentity(ClientSession clientSession, WebSSHData webSSHData) throws Exception {
        String authType = webSSHData.getAuthType();
        if (OperateConstant.KEYPAIR.equalsIgnoreCase(authType)) {
            List<Path> keyFiles = Stream.of(webSSHData.getPrivateKey(), webSSHData.getPublicKey())
                    .filter(StringUtils::isNotBlank)
                    .map(path -> new File(path).toPath())
                    .collect(Collectors.toList());
            if (!keyFiles.isEmpty()) {
                KeyPair keyPair = new FileKeyPairProvider(keyFiles).loadKey(clientSession, webSSHData.getKeypairType());
                if (keyPair == null) {
                    throw new IllegalArgumentException(
                            "No " + webSSHData.getKeypairType() + " key pair found in " + keyFiles);
                }
                clientSession.addPublicKeyIdentity(keyPair);
            }
        } else if (OperateConstant.PASSWORD.equalsIgnoreCase(authType)) {
            clientSession.addPasswordIdentity(webSSHData.getPassword());
        } else {
            throw new IllegalArgumentException("Unknown ssh auth type: " + authType);
        }
    }

    /**
     * Writes {@code command} to the shell's input. Ignored while no shell is open or when the
     * command is null.
     */
    private void transToSSH(ChannelShell channel, String command) throws IOException {
        if (channel == null || command == null) {
            return;
        }
        // Explicit UTF-8, matching how output is decoded. The platform default charset used before
        // mis-encoded non-ASCII input on hosts whose default is not UTF-8.
        channel.getInvertedIn().write(command.getBytes(StandardCharsets.UTF_8));
        channel.getInvertedIn().flush();
    }

    /**
     * Sends {@code buffer} to the browser as a text frame.
     *
     * <p>Synchronized on the session because a WebSocketSession does not support concurrent sends.
     * Output relaying (pool thread) can race with error notifications (WebSocket thread).
     */
    private void sendMessage(WebSocketSession session, byte[] buffer) throws IOException {
        synchronized (session) {
            session.sendMessage(new TextMessage(buffer));
        }
    }

    /**
     * Starts a daemon thread that logs a heartbeat every two minutes while the SSH session stays
     * open. If no session is given, the shell channel is watched instead. Once the watched object
     * is closing or closed, the thread tells the browser "Connection-closed" and exits.
     *
     * <p>Previously the thread was built but never started. It also tested {@code isClosed()}
     * twice instead of {@code isClosed() || isClosing()}. Not called in the code shown here.
     */
    private void startClientSessionHeartCheck(ClientSession clientSession, WebSocketSession websocketSession, ChannelShell channelshell) {
        final org.apache.sshd.common.Closeable watched = clientSession != null ? clientSession : channelshell;
        if (watched == null) {
            return;
        }
        Thread thread = new Thread(() -> {
            while (watched.isOpen()) {
                LOGGER.info(websocketSession.getAttributes().get(OperateConstant.USER_UUID_KEY) + " clientSession is normal");
                try {
                    Thread.sleep(HEARTBEAT_INTERVAL_MILLIS);
                } catch (InterruptedException e) {
                    // Asked to stop: restore the flag and end the heartbeat.
                    Thread.currentThread().interrupt();
                    return;
                }
                if (watched.isClosed() || watched.isClosing()) {
                    // Tell the front end the session has been closed.
                    try {
                        sendMessage(websocketSession, "Connection-closed".getBytes());
                    } catch (IOException ee) {
                        LOGGER.error("心跳检测异常:", ee);
                    }
                    return;
                }
            }
        }, "webssh-heartbeat");
        thread.setDaemon(true);
        thread.start();
    }

    /** Returns the user id stored in the session attributes by the handshake interceptor. */
    private static String userIdOf(WebSocketSession session) {
        return String.valueOf(session.getAttributes().get(OperateConstant.USER_UUID_KEY));
    }
}
其中涉及到的 POJO 及工具类:
@1 OperateConstant
package cn.cloud.common.pojo;
/**
 * String constants shared by the WebSocket layer and the WebSSH service.
 */
public interface OperateConstant {

    // ---- WebSocket session attribute / Redis keys ----

    /** Session-attribute key for the random per-connection user id created during the handshake. */
    String USER_UUID_KEY = "user_uuid";

    /**
     * Session-attribute key for the client IP. The same string also names the Redis hash that
     * holds per-IP connection counts.
     */
    String IP = "websocket_ip";

    // ---- values of the request's "operate" field ----

    /** Operation: open the SSH connection. */
    String WEBSSH_OPERATE_CONNECT = "connect";

    /** Operation: send input to the open shell. */
    String WEBSSH_OPERATE_COMMAND = "command";

    // ---- values of the request's "authType" field (compared case-insensitively) ----

    /** Authenticate with a password. */
    String PASSWORD = "PASSWORD";

    /** Authenticate with a key pair loaded from files. */
    String KEYPAIR = "KEYPAIR";
}
@2 WebSSHConfig
package cn.cloud.common.pojo;
/**
 * WebSSH tunables: SSH timeouts plus settings for an SSH session pool.
 * Durations are in milliseconds.
 */
public interface WebSSHConfig {

    // ---- SSH timeouts (ms) ----

    /** Time allowed to establish the SSH session. */
    Long connectTimeout = 5000L;

    /** Time allowed for user authentication. */
    Long authTimeout = 5000L;

    /** Timeout for a shell/command step. */
    Long executeTimeout = 3000L;

    // ---- session pool sizing ----
    // NOTE(review): presumably consumed by an SSH session pool elsewhere; confirm these are used.

    /** Maximum number of active pooled sessions. */
    Integer maxTotal = 15;

    /** Maximum number of idle pooled sessions. */
    Integer maxIdle = 8;

    /** Minimum number of idle pooled sessions. */
    Integer minIdle = 2;

    /** Whether borrowing blocks once {@code maxTotal} sessions are active. */
    Boolean blockWhenExhausted = true;

    /** How long a blocked borrow may wait. */
    Long maxWaitMillis = 30000L;

    // ---- session validation ----

    /** Validate a session when it is created. */
    Boolean testOnCreate = false;

    /** Validate a session when it is borrowed. */
    Boolean testOnBorrow = false;

    /** Validate a session when it is returned. */
    Boolean testOnReturn = false;

    /** Validate sessions while they sit idle. */
    Boolean testWhileIdle = true;

    // ---- idle eviction ----

    /** Interval between idle-session checks. */
    Long timeBetweenEvictionRunsMillis = 30000L;

    /** Idle time after which a session without interaction may be evicted to free capacity. */
    Long minEvictableIdleTimeMillis = 300000L;
}
@3 WebSSHData
package cn.cloud.common.pojo;
import org.apache.commons.io.FilenameUtils;
/**
 * Request payload exchanged between the browser and the WebSSH back end.
 *
 * <p>{@code operate} selects the action ("connect" or "command"). The connection and
 * authentication fields are used for "connect", and {@code command} for "command". Fields the
 * request leaves unset keep the defaults below.
 */
public class WebSSHData {

    /** Requested action: "connect" or "command". */
    private String operate;

    /** Target SSH host. */
    private String host;

    /** Target SSH port; 22 unless supplied. */
    private Integer port = 22;

    /** Login user name. */
    private String username;

    /** Authentication method: "PASSWORD" (default) or "KEYPAIR". */
    private String authType = "PASSWORD";

    /** Password used for PASSWORD authentication. */
    private String password;

    /** Text to send to the shell; empty by default. */
    private String command = "";

    /** Key type requested from the key loader; "ssh-rsa" by default. */
    private String keypairType = "ssh-rsa";

    /**
     * Public key file path; defaults to .ssh/id_rsa.pub under the server process's user.home.
     * NOTE(review): key paths are supplied by the client yet default to the server's own keys.
     */
    private String publicKey = FilenameUtils.concat(System.getProperty("user.home"), ".ssh/id_rsa.pub");

    /** Private key file path; defaults to .ssh/id_rsa under the server process's user.home. */
    private String privateKey = FilenameUtils.concat(System.getProperty("user.home"), ".ssh/id_rsa");

    public String getOperate() { return this.operate; }

    public void setOperate(String operate) { this.operate = operate; }

    public String getHost() { return this.host; }

    public void setHost(String host) { this.host = host; }

    public Integer getPort() { return this.port; }

    public void setPort(Integer port) { this.port = port; }

    public String getUsername() { return this.username; }

    public void setUsername(String username) { this.username = username; }

    public String getPassword() { return this.password; }

    public void setPassword(String password) { this.password = password; }

    public String getCommand() { return this.command; }

    public void setCommand(String command) { this.command = command; }

    public String getAuthType() { return this.authType; }

    public void setAuthType(String authType) { this.authType = authType; }

    public String getKeypairType() { return this.keypairType; }

    public void setKeypairType(String keypairType) { this.keypairType = keypairType; }

    public String getPublicKey() { return this.publicKey; }

    public void setPublicKey(String publicKey) { this.publicKey = publicKey; }

    public String getPrivateKey() { return this.privateKey; }

    public void setPrivateKey(String privateKey) { this.privateKey = privateKey; }
}
@4 WebSSHInfo
package cn.cloud.common.pojo;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import org.apache.sshd.client.SshClient;
import org.apache.sshd.client.channel.ChannelExec;
import org.apache.sshd.client.channel.ChannelShell;
import org.springframework.web.socket.WebSocketSession;
/**
 * SSH state held for one WebSocket connection.
 *
 * <p>Lombok generates the accessors, equals/hashCode/toString and both constructors. The all-args
 * constructor takes the fields in declaration order, so reordering them changes its signature.
 */
@Data
@AllArgsConstructor
@NoArgsConstructor
public class WebSSHInfo {
    // SSH client owned by this connection
    private SshClient sshClient;
    // The browser's WebSocket session
    private WebSocketSession webSocketSession;
    // Interactive shell channel (original note: intended for Linux hosts)
    private ChannelShell channelShell;
    // Exec channel (original note: intended for Windows hosts)
    private ChannelExec channelExec;
}
@5 RedisUtil
package cn.cloud.common.util;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.stereotype.Component;
import java.util.concurrent.TimeUnit;
/**
 * Thin wrapper over Spring's RedisTemplate exposing the hash and expiry operations WebSSH needs.
 *
 * <p>The template is stored in a static field that Spring fills through the setter when it
 * creates this component. That is why code can build its own {@code new RedisUtil()} and still
 * reach Redis. Any instance used before Spring has injected the template fails with a
 * NullPointerException.
 */
@Component
public class RedisUtil {

    private static RedisTemplate<String, Object> redisTemplate;

    /**
     * Injection point for the application's RedisTemplate. The parameter type is raw, as in the
     * original, presumably so that any RedisTemplate bean matches.
     */
    @Autowired
    public void setRedisTemplate(RedisTemplate redisTemplate) {
        RedisUtil.redisTemplate = redisTemplate;
    }

    /** Returns field {@code item} of hash {@code key}, or {@code null} when it is absent. */
    public Object hget(String key, String item) {
        return RedisUtil.redisTemplate.opsForHash().get(key, item);
    }

    /** Sets field {@code item} of hash {@code key} to {@code value}. */
    public void hput(String key, String item, Object value) {
        RedisUtil.redisTemplate.opsForHash().put(key, item, value);
    }

    /** Sets the time-to-live of {@code key} to {@code time} seconds. */
    public void expire(String key, long time) {
        RedisUtil.redisTemplate.expire(key, time, TimeUnit.SECONDS);
    }
}
简单的xterm案例
xterm.js 是一个运行在浏览器中的前端终端组件,可以帮助我们在前端实现命令行的样式。它本身不负责网络通信,通常与 WebSocket 配合使用:把后端 SSH 的输出实时写入终端,再把用户的键盘输入发回后端。效果就像我们平常在用 SecureCRT 或者 XShell 连接服务器时一样。
下面是官网上的入门案例:
<!doctype html>
<html>
<head>
<!-- xterm.js stylesheet and library, loaded from a local npm install (node_modules/xterm) -->
<link rel="stylesheet" href="node_modules/xterm/css/xterm.css" />
<script src="node_modules/xterm/lib/xterm.js"></script>
</head>
<body>
<!-- Container element the terminal renders into -->
<div id="terminal"></div>
<script>
// Create a terminal, attach it to the container, and print a greeting.
// This demo only renders locally; it does not open a WebSocket to the back end.
// "\x1B[1;3;31m" is an ANSI SGR escape (bold, italic, red); "\x1B[0m" resets the attributes.
var term = new Terminal();
term.open(document.getElementById('terminal'));
term.write('Hello from \x1B[1;3;31mxterm.js\x1B[0m $ ')
</script>
</body>
</html>