c++ websocket 协议分析与实现
前言
网上有很多第三方库(nopoll、uWebSockets、libwebsockets),但它们要么大量依赖回调,要么过于复杂;我只需要在后端使用,所以决定自己手写一个。
1:环境
ubuntu18
g++(支持c++11即可)
第三方库:jsoncpp,openssl
2:安装
jsoncpp:用于读取 JSON 配置文件,可通过包管理器自动安装,网上有很多教程。
openssl:如果系统没有自带,需要先安装:sudo apt-get install openssl(编译时还需要头文件:sudo apt-get install libssl-dev)。一般是 1.1 版本,够用了。
3:websocket server
1> 主要用到 epoll 模式(io_uring 更好一些,但要求内核版本较高)。共 3 个进程:主进程作为监控进程,另有 2 个子进程,一个是 network 进程,一个是 logic 进程。
2> 子进程之间主要通过共享内存交换数据,并用 socketpair 互相通知。
3>websocket 握手协议 先看例子
上前端代码 html
<!DOCTYPE HTML>
<html>
<head>
<meta http-equiv="content-type" content="text/html" />
<meta name="author" content="https://github.com/" />
<title>websocket test</title>
<script>
// WebSocket handle shared by Connect(), Send() and Close().
var socket;

/**
 * Opens the WebSocket connection to the test server and attaches the
 * open/error/message/close handlers. If the constructor throws, the error
 * is shown in an alert and nothing is attached.
 */
function Connect() {
    var serverUrl = 'ws://192.168.1.131:9000';
    try {
        socket = new WebSocket(serverUrl);
    } catch (err) {
        alert('error catch' + err);
        return;
    }
    socket.onclose = sClose;
    socket.onmessage = sMessage;
    socket.onerror = sError;
    socket.onopen = sOpen;
}
/** "open" handler: tells the user that the connection is established. */
function sOpen() {
    var notice = 'connect success!';
    alert(notice);
}
/**
 * "error" handler: shows the error event in an alert.
 * @param e the event passed by the WebSocket
 */
function sError(e) {
    var text = "[error] " + e;
    alert(text);
}
/**
 * "message" handler.
 * - non-object argument: shown in an alert;
 * - object with a truthy `data` field: logged to the console;
 * - any other object: all of its properties are dumped through writeObj().
 * @param msg the value passed by the WebSocket
 */
function sMessage(msg) {
    if (typeof msg !== 'object') {
        alert('server says:' + msg);
        return;
    }
    if (msg.data) {
        console.log('server says' + msg.data);
        return;
    }
    writeObj(msg);
}
/**
 * "close" handler: shows the close code of the event.
 * @param e the close event; its `code` field is displayed
 */
function sClose(e) {
    var closeCode = e.code;
    alert("connect closed:" + closeCode);
}
/**
 * Sends the content of the "msg" input box to the server.
 * Does nothing except show an alert when there is no open connection,
 * instead of throwing because `socket` is undefined or still connecting.
 */
function Send() {
    if (!socket || socket.readyState !== WebSocket.OPEN) {
        alert('not connected');
        return;
    }
    socket.send(document.getElementById("msg").value);
}
/**
 * Closes the connection if one was ever created; a click on "Close"
 * before "Connect" is ignored instead of throwing on an undefined socket.
 */
function Close() {
    if (!socket) {
        return;
    }
    socket.close();
}
/**
 * Shows every enumerable property of `obj` (own and inherited) in an alert,
 * one "name = value" pair per line.
 * @param obj the object to dump
 */
function writeObj(obj) {
    var lines = [];
    for (var key in obj) {
        lines.push(key + " = " + obj[key] + "\n");
    }
    alert(lines.join(""));
}
</script>
</head>
<body>
<input id="msg" type="text">
<button id="connect" onclick="Connect();">Connect</button>
<button id="send" onclick="Send();">Send</button>
<button id="close" onclick="Close();">Close</button>
</body>
</html>
在Microsoft Edge 运行结果
golang 客户端代码如下
package main
import (
"fmt"
"golang.org/x/net/websocket"
"log"
"strings"
)
// Connection parameters of the test client.
var (
	// origin is passed to websocket.Dial as the origin of the request.
	origin = "http://192.168.1.131:9000"

	// url is the server endpoint. Plain-text alternative used during testing:
	// "ws://192.168.1.131:7077/websocket"
	url = "wss://192.168.1.131:9000/websocket"
)
// main dials the server, performs one "hello" round trip and then echoes
// whitespace-separated tokens read from stdin until the user types "quit",
// stdin ends, or the connection fails.
func main() {
	ws, err := websocket.Dial(url, "", origin)
	if err != nil {
		log.Fatal(err)
	}
	// Close the connection on every exit path below.
	defer ws.Close()

	if err := exchange(ws, "hello"); err != nil {
		log.Printf("exchange failed: %v", err)
		return
	}

	for {
		fmt.Printf("please input string:")
		var inputstr string
		// Stop on EOF / unreadable stdin instead of spinning forever with an empty token.
		if _, err := fmt.Scan(&inputstr); err != nil {
			break
		}
		if strings.Compare(inputstr, "quit") == 0 {
			break
		}
		if err := exchange(ws, inputstr); err != nil {
			log.Printf("exchange failed: %v", err)
			break
		}
	}
	fmt.Printf("client exit\n")
}

// exchange sends one text frame and waits for one text frame in reply,
// printing both. It returns the first send or receive error.
func exchange(ws *websocket.Conn, text string) error {
	if err := websocket.Message.Send(ws, text); err != nil {
		return err
	}
	fmt.Printf("Send: %s\n", text)

	var reply string
	if err := websocket.Message.Receive(ws, &reply); err != nil {
		return err
	}
	fmt.Printf("Receive: %s\n", reply)
	return nil
}
测试结果
server 握手代码
int c_WebSocket::recv_handshake() {
int n, len, ret;
uint32_t pos = 0;
uint16_t u16msglen = 0;
const bool bssl = isSsl();
if (bssl) {
n = SSL_read(m_ssl, m_recv_buf + m_recv_pos, m_recv_buf_size - m_recv_pos);
}
else {
n = recv(m_fd, m_recv_buf + m_recv_pos, m_recv_buf_size - m_recv_pos, 0);
}
if (n > 0) {
if (m_is_closed) {
m_recv_pos = 0;
return true;
}
m_recv_pos += n;
m_recv_buf[m_recv_pos] = 0;
printf("recv %d handshake %s len=%d recvlen=%d", m_id, m_recv_buf,n,m_recv_pos);
// goto READ;
// int32_t pos = 0;
for (;;) {
if (m_recv_pos >= c_u16MinHandShakeSize) //消息头
{
// \r 0x0D \n 0xA
const int nRet = fetch_http_info((char*)m_recv_buf, m_recv_pos);
if (1 == nRet) { //ok
// f(strcasecmp(header, "Sec-Websocket-Protocol") == 0)
// conn->accepted_protocol = value;
//
std::map<std::string, std::string>::iterator it1 = m_map_header.find("Sec-WebSocket-Key"); //一般固定24个字节
std::map<std::string, std::string>::iterator it2 = m_map_header.find("Sec-WebSocket-Protocol");//
int map_size = m_map_header.size();
if (it1 != m_map_header.end()) {
printf("key=%s value=%s %d \n", it1->first.c_str(), it1->second.c_str(), map_size);
}
else {
return -1;
}
char acceptvalue[1024] = { 0, };
uint32_t value_len = it1->second.length();
memcpy(acceptvalue, it1->second.c_str(), value_len);
// memcpy(accept_key, websocket_key, key_length);
#define MAGIC_KEY "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
memcpy(acceptvalue + value_len, MAGIC_KEY, 36);
acceptvalue[value_len + 36] = 0;
unsigned char md[SHA_DIGEST_LENGTH];
SHA1((unsigned char*)acceptvalue, strlen(acceptvalue), md);
std::string server_key = base64_encode(reinterpret_cast<const unsigned char*>(md), SHA_DIGEST_LENGTH);
char rep_handshake[1024] = { 0, };
memset(rep_handshake, 0, sizeof(rep_handshake));
if (it2 != m_map_header.end()) {
//子协议
char szsub_protocol[512] = { 0, };
std::size_t pos_t = it2->second.find(",");
if (pos_t != std::string::npos && pos_t < 512) {
memcpy(szsub_protocol, it2->second.c_str(), pos_t);
sprintf(rep_handshake, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: %s\r\nSec-WebSocket-Protocol: %s\r\n\r\n",
server_key.c_str(), szsub_protocol);
}
else {
return -1;
}
}
else {
sprintf(rep_handshake, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: %s\r\n\r\n",
server_key.c_str());
}
m_recv_pos = 0;
set_handshake_ok(); //握手完毕
send_pkg((uint8_t*)rep_handshake, strlen(rep_handshake));
break;
}
else {
printf("fetch_http_info nRet=%d \n ", nRet);
return -2;
}
}
else {
break;
}
}//end for
if (pos != 0 && m_recv_pos > 0) {
memcpy(m_recv_buf, m_recv_buf + pos, m_recv_pos);
}
}
else {
if (bssl) {
//EAGAIN或EWOULDBLOCK二者应该是一样的,对应的错误码是11
//ret = SSL_get_error(m_ssl, n);//int ssl_error = SSL_get_verify_result(ssl);
//if (SSL_ERROR_WANT_READ == ret || SSL_ERROR_WANT_WRITE == ret) return true;
SSL_ERROR_NONE == ret n>0 ok other error
//printf("SSL_get_error(%d %d %d)\n", n, ret, errno);//SSL_get_error(-1 1,11)
//return false;
int ret = ssl_check_error(m_ssl, n);
printf("SSL_get_error(%d %d %d %d)\n", n, ret, errno, m_recv_pos);//SSL_get_error(-1 -1 11)
if (ret == -2) {
return true;
}
if (errno == EAGAIN || errno == EINTR) {
return true;
}
return false;
}
else {
if (n == 0)
return false;
if (errno == EAGAIN || errno == EINTR) {
return true;
}
else {
return false;
}
}
}
return true;
}
/**
 * Parses the buffered HTTP Upgrade request line by line and stores every
 * "Name: value" header in m_map_header.
 *
 * @param recv_buf  bytes received so far
 * @param buf_len   number of valid bytes in recv_buf
 * @return  1  the blank line ending the headers was reached and the mandatory
 *             handshake headers are present and valid
 *          0  no terminating blank line yet
 *         -1  a line is too long, or the request line is not a valid "GET ..." line
 *         -3  request path length outside 1..127
 *         -4  malformed header line
 *         -6  a mandatory handshake header is missing or has the wrong value
 *
 * NOTE(review): ws_strncmp is used as "non-zero means equal" throughout; confirm
 * against its definition.
 */
int c_WebSocket::fetch_http_info(char* recv_buf, const uint32_t buf_len) {
    // \r 0x0D \n 0xA
    const uint32_t max_len = 1024;
    char bufline[max_len] = { 0, };
    uint32_t bufpos = 0;
    bool request_line_seen = false;

    for (uint32_t i = 0; i < buf_len; i++) {
        bufline[bufpos++] = recv_buf[i];
        if (bufpos >= max_len) return -1;
        if (recv_buf[i] != 0x0A) continue;  // keep collecting until end of line
        bufline[bufpos] = 0;

        if (!request_line_seen) {
            // Request line, e.g. "GET /websocket HTTP/1.1\r\n". The shortest legal form
            // "GET / HTTP/1.1\n" is 15 bytes (16 with "\r\n").
            if (!ws_strncmp(bufline, "GET ", 4) || bufpos < 15) {
                return -1;
            }
            // Index of the 'H' of "HTTP/1.1": 8 bytes before the line terminator,
            // which is either "\r\n" or a bare "\n".
            int32_t http_pos = static_cast<int32_t>(bufpos - 2 - 8);
            if (bufline[bufpos - 2] != '\r') {
                http_pos += 1;
            }
            // Path length: minus the space before "HTTP/1.1" and minus "GET ".
            const int32_t path_len = http_pos - 1 - 4;
            if (path_len <= 0 || path_len >= 128) {
                return -3;
            }
            request_line_seen = true;
            bufpos = 0;
            continue;
        }

        if (2 == bufpos && ws_strncmp(bufline, "\r\n", 2)) {
            // Blank line: end of the handshake request. Check the minimum header set:
            //   Upgrade: websocket / Connection: Upgrade /
            //   Sec-WebSocket-Version: 13 / Sec-WebSocket-Key: <base64>
            auto it_upgrade = m_map_header.find("Upgrade");
            auto it_conn = m_map_header.find("Connection");
            auto it_version = m_map_header.find("Sec-WebSocket-Version");
            auto it_key = m_map_header.find("Sec-WebSocket-Key");  // usually 24 bytes
            const bool bcheckOrigin = false;  // browsers always send Origin, other clients may not
            auto it_origin = m_map_header.find("Origin");
            const auto none = m_map_header.end();

            const bool headers_ok =
                it_upgrade != none && ws_strncmp(it_upgrade->second.c_str(), "websocket", 9) &&
                it_conn != none && ws_strncmp(it_conn->second.c_str(), "Upgrade", 7) &&
                it_version != none && ws_strncmp(it_version->second.c_str(), "13", 2) &&
                it_key != none && it_key->second.length() > 12 &&
                (!bcheckOrigin || it_origin != none);  // all other headers are ignored
            return headers_ok ? 1 : -6;
        }

        // Ordinary header line "Name: value\r\n".
        char* colon = strchr(bufline, ':');
        if (colon == nullptr) {
            return -4;
        }
        const int32_t key_len = static_cast<int32_t>(colon - bufline);
        const int32_t value_len = static_cast<int32_t>(bufpos) - key_len - 1 - 1;
        if (key_len <= 1 || value_len <= 1) {
            return -4;
        }
        *colon = 0;
        const std::string key = bufline;

        // Strip the line terminator ("\n" or "\r\n").
        if (bufline[bufpos - 1] == '\n') {
            bufline[bufpos - 1] = 0;
        }
        if (bufline[bufpos - 2] == '\r') {
            bufline[bufpos - 2] = 0;
        }

        // The whitespace after ':' is optional in HTTP; skip whatever is there instead
        // of assuming exactly one space (which dropped the first value character of
        // "Name:value").
        const char* value = colon + 1;
        while (*value == ' ' || *value == '\t') {
            ++value;
        }
        m_map_header[key] = value;
        bufpos = 0;
    }
    return 0;
}
握手请求与回复
Origin: http://192.168.1.131:9000 : 原始的协议和URL
Connection: Upgrade:表示要升级协议了
Upgrade: websocket:表示要升级到 WebSocket 协议;
Sec-WebSocket-Version: 13:表示 WebSocket 的版本。如果服务端不支持该版本,需要返回一个 Sec-WebSocket-Version header,里面包含服务端支持的版本号
Sec-WebSocket-Key:与后面服务端响应首部的 Sec-WebSocket-Accept 是配套的,提供基本的防护,比如恶意的连接,或者无意的连接
服务端响应协议升级
HTTP/1.1 101 Switching Protocols
Connection:Upgrade
Upgrade: websocket
Sec-WebSocket-Accept: Oy4NRAQ13jhfONC7bP8dTKb4PTU=
HTTP/1.1 101 Switching Protocols: 状态码 101 表示协议切换
Sec-WebSocket-Accept:根据客户端请求首部的 Sec-WebSocket-Key 计算出来
将 Sec-WebSocket-Key 跟 258EAFA5-E914-47DA-95CA-C5AB0DC85B11 拼接。
通过 SHA1 计算出摘要,并转成 base64 字符串。计算公式如下:
Base64(sha1(Sec-WebSocket-Key + 258EAFA5-E914-47DA-95CA-C5AB0DC85B11))
Connection:Upgrade:表示协议升级
Upgrade: websocket:升级到 websocket 协议
4:接收数据帧
代码如下
int c_WebSocket::recv_dataframe() {
int n, len,ret;
// uint32_t pos = 0;
uint16_t u16msglen = 0;
const bool bssl = isSsl();
if (isSsl()) {
n = SSL_read(m_ssl, m_recv_buf + m_recv_pos, m_recv_buf_size - m_recv_pos);
}
else {
n = recv(m_fd, m_recv_buf + m_recv_pos, m_recv_buf_size - m_recv_pos, 0);
}
// n = recv(m_fd, m_recv_buf + m_recv_pos, m_recv_buf_size - m_recv_pos, 0);
if (n > 0) {
if (m_is_closed) {
m_recv_pos = 0;
return true;
}
m_recv_pos += n;
// goto READ;
int32_t pos = 0;
for (;;) {
if (m_recv_pos >= c_u16MsgHeadSize) //消息头 2个字节
{
int t = parse_dataframe(m_recv_buf + pos, m_recv_pos);
if (t < 0) return false;
else if (0 == t) break;
pos += t;
m_recv_pos -= t; u16msglen + c_u16MsgHeadSize; //sub one packet len
// pos += u16msglen + c_u16MsgHeadSize;
}
else {
break;
}
}//end for
if (pos != 0 && m_recv_pos > 0) {
memcpy(m_recv_buf, m_recv_buf + pos, m_recv_pos);
}
if (pos > 0) { //收到消息的时间
m_lastrecvmsg = get_reactor().getCurSecond();
}
}
else {
if (bssl) {
//ret = SSL_get_error(m_ssl, n);
//if (SSL_ERROR_WANT_READ == ret || SSL_ERROR_WANT_WRITE == ret) return true;
SSL_ERROR_NONE == ret n>0 ok other error
//return false;
int ret = ssl_check_error(m_ssl, n);
if (ret == -2) {
return true;
}
if (errno == EAGAIN || errno == EINTR) {
return true;
}
return false;
}
else {
if (n == 0)
return false;
if (errno == EAGAIN || errno == EINTR) {
return true;
}
else {
return false;
}
}
}
return true;
}
处理数据帧,把payload 转发到 logic进程,由logic去处理
一帧数据的 payload 长度如果使用 64 位扩展长度字段(即超过 65535 字节),或者超过 MAX_PAYLOAD_REQ,会被直接拒绝并关闭连接;这个上限可以根据实际需求设定
/**
 * Parses one client frame at the start of recv_buf.
 *
 * Only unfragmented (FIN=1), masked frames with a 7-bit or 16-bit payload length
 * are accepted. PONG frames clear the pending-ping state, PING frames are answered
 * with a PONG carrying the same payload, and every other frame is unmasked into
 * m_payload_ and pushed to the logic process through recv_push().
 *
 * @param recv_buf  start of the frame; the caller guarantees at least
 *                  c_u16MsgHeadSize readable bytes
 * @param buf_len   number of valid bytes in recv_buf
 * @return  >0  total length of the consumed frame (header + mask + payload)
 *           0  frame incomplete, more data is needed
 *          -1  protocol violation, oversized frame or CLOSE frame: close the connection
 */
int c_WebSocket::parse_dataframe(uint8_t* recv_buf, const uint32_t buf_len) {
#define FAIL_AND_CLOSE -1       // receive failed or close
#define NEED_CLOSE -1           // connection must be closed
#define CONTINUE_RECV_MSG 0     // message incomplete, keep receiving
#define ONE_MSG_LENGHT(X) X     // one whole message of total length X was consumed
#define MASK_LEN 4              // length of the masking key
#define PAYLOAD_LENGTH_126 2    // length code 126: 2 extra length bytes
    const uint8_t t_fin = msg_get_bit(recv_buf[0], 7);
    if (t_fin == 0) return FAIL_AND_CLOSE;  // fragmented messages are not supported
    const uint8_t t_code = recv_buf[0] & 0x0F;
    const uint8_t t_masked = msg_get_bit(recv_buf[1], 7);
    uint16_t t_payload_size = recv_buf[1] & 0x7F;
    if (t_masked == 0) return FAIL_AND_CLOSE;  // client frames must be masked
    if (t_code == CLOSE_FRAME) {
        return NEED_CLOSE;
    }

    uint16_t t_playload_pos = c_u16MsgHeadSize;
    if (t_payload_size == 126) {
        if (buf_len < c_u16MsgHeadSize + PAYLOAD_LENGTH_126) return CONTINUE_RECV_MSG;
        // The extended length is transmitted in network byte order (big-endian);
        // assemble it byte by byte so the result is right on every host.
        const uint16_t length = static_cast<uint16_t>(
            (recv_buf[c_u16MsgHeadSize] << 8) | recv_buf[c_u16MsgHeadSize + 1]);
        if (length > MAX_PAYLOAD_REQ) return FAIL_AND_CLOSE;  // message too long
        // MASK_LEN: the masking key is mandatory for frames sent by a client
        if (buf_len < c_u16MsgHeadSize + PAYLOAD_LENGTH_126 + MASK_LEN + length) return CONTINUE_RECV_MSG;
        t_payload_size = length;
        t_playload_pos += PAYLOAD_LENGTH_126;
    }
    else if (t_payload_size == 127) {
        return FAIL_AND_CLOSE;  // 64-bit lengths are rejected
    }
    else {
        if (buf_len < c_u16MsgHeadSize + MASK_LEN + t_payload_size) return CONTINUE_RECV_MSG;
    }

    memcpy(masking_key_, &recv_buf[t_playload_pos], MASK_LEN);
    t_playload_pos += MASK_LEN;

    if (t_code == PONG_FRAME) {
        if (m_lastsendping > 0) {
            printf("time=[%u]recv PONG_FRAME \n",g_reactor.getCurSecond());
            m_lastsendping = 0;  // the outstanding ping was answered
            m_sendpingcount = 0;
        }
        return ONE_MSG_LENGHT(t_playload_pos + t_payload_size);
    }

    if (t_code == PING_FRAME) {
        // Answer with a PONG that echoes the (unmasked) ping payload. The bytes are
        // unmasked in place: the frame is complete and is discarded after this call.
        uint8_t* ping_payload = recv_buf + t_playload_pos;
        for (uint32_t i = 0; i < t_payload_size; i++) {
            ping_payload[i] ^= masking_key_[i % MASK_LEN];
        }
        send_data((char*)ping_payload, t_payload_size, PONG_FRAME);
        return ONE_MSG_LENGHT(t_playload_pos + t_payload_size);
    }

    if (t_payload_size == 0) {
        return FAIL_AND_CLOSE;  // empty data frames are rejected
    }

    // Unmask the payload into the member buffer.
    m_payload_length_ = t_payload_size;
    for (uint32_t i = 0; i < m_payload_length_; i++) {
        m_payload_[i] = recv_buf[t_playload_pos + i] ^ masking_key_[i % MASK_LEN];
    }

    // Hand the payload over to the logic process.
    shm_block_t sb;
    sb.fd = m_fd;
    sb.id = m_id;
    sb.len = t_payload_size;
    sb.type = PROTO_BLOCK;
    sb.frametype = t_code;
    recv_push(m_u32channel, m_u32pipeindex, &sb, (uint8_t*)m_payload_, false);
    return ONE_MSG_LENGHT(t_playload_pos + t_payload_size);
}
再来个logic 进程处理
/**
 * Demo logic handler: echoes every PROTO_BLOCK payload back to the client.
 * A payload that starts with "hello" is returned with those five bytes upper-cased;
 * any other payload is returned with its first byte replaced by '_'.
 * CLOSE_BLOCK, CDUMP_BLOCK and unknown block types are ignored.
 *
 * @param pblock  block header of the received message
 * @param buf     payload bytes; modified in place
 * @param brecv   unused
 */
void c_Logic::dologic(struct shm_block_t* pblock, uint8_t *buf, bool brecv)
{
    (void)brecv;
    switch (pblock->type)
    {
    case PROTO_BLOCK:
    {
        // Compare only when the payload really holds 5 bytes, so that bytes left in
        // the buffer by an earlier, longer message are never examined.
        const bool is_hello = pblock->len >= 5 && memcmp(buf, "hello", 5) == 0;
        if (is_hello) {
            memcpy(buf, "HELLO", 5);
        }
        else if (pblock->len > 0) {
            buf[0] = '_';
        }
        send_data(pblock, buf, pblock->len, (WebSocketFrameType)pblock->frametype);
    }
    break;
    case CLOSE_BLOCK:
    case CDUMP_BLOCK:
    default:
        break;
    }
}
发送数据
/**
 * Wraps msg into a single unmasked WebSocket frame (FIN set) and queues it for
 * the network process through send_push().
 *
 * @param pblock  block of the request being answered; its fd and id are copied
 * @param msg     payload bytes
 * @param msglen  payload length; at most 4 KiB
 * @param ftype   frame opcode
 * @return 0 when the frame was queued, -1 when msglen exceeds the limit
 */
int send_data(struct shm_block_t* pblock, uint8_t* msg, const uint32_t msglen, WebSocketFrameType ftype) {
    const uint32_t MAX_PAYLOAD_SEND = 4 * 1024;  // largest payload that may be sent
    if (msglen > MAX_PAYLOAD_SEND) {
        return -1;
    }

    uint32_t length = msglen;
    char header[14];
    memset(header, 0, sizeof(header));

    // First byte: FIN flag plus opcode.
    msg_set_bit(header, 7);
    if (ftype >= 0) {
        header[0] |= ftype & 0x0f;
    }

    // Second byte: the mask flag stays clear (frames sent by the server are not
    // masked), followed by the payload length in 7, 16 or 64 bits.
    int header_size = 2;
    if (length < 126) {
        header[1] |= length;
    }
    else if (length <= 0xFFFF) {
        header[1] |= 126;
        header_size += 2;
        msg_set_16bit(length, header + 2);
    }
    else {
        // Not reachable with the 4 KiB limit above; kept for completeness.
        // length has 32 bits, so the upper four of the eight length bytes stay zero.
        header[1] = 127;
        for (int k = 0; k < 4; ++k) {
            header[2 + k] = 0;
            header[6 + k] = (length >> (8 * (3 - k))) & 0xFF;
        }
        header_size += 8;
    }

    // Assemble header + payload in one contiguous buffer.
    uint8_t buf[MAX_PAYLOAD_SEND + 14];
    memcpy(buf, header, header_size);
    memcpy(buf + header_size, msg, msglen);

    shm_block_t sb;
    sb.fd = pblock->fd;
    sb.id = pblock->id;
    sb.type = PROTO_BLOCK;
    sb.len = msglen + header_size;
    sb.frametype = (uint8_t)ftype;
    send_push(0, 1, &sb, buf, true);
    return 0;
}
5:支持 SSL
先加载证书
bool c_Accept::loadssl(const char* private_key_file, const char* server_crt_file, const char* ca_crt_file) {
m_ctx = SSL_CTX_new(SSLv23_server_method());
if (!m_ctx) { return false; }
//assert(ctx);
// 不校验客户端证书
SSL_CTX_set_verify(m_ctx, SSL_VERIFY_NONE, nullptr);
// 加载CA的证书
if (!SSL_CTX_load_verify_locations(m_ctx, ca_crt_file, nullptr)) {
printf("SSL_CTX_load_verify_locations error!\n");
return false;
}
// 加载自己的证书
if (SSL_CTX_use_certificate_file(m_ctx, server_crt_file, SSL_FILETYPE_PEM) <= 0) {
printf("SSL_CTX_use_certificate_file error!\n");
return false;
}
// 加载私钥
if (SSL_CTX_use_PrivateKey_file(m_ctx, private_key_file, SSL_FILETYPE_PEM) <= 0) {
printf("SSL_CTX_use_PrivateKey_file error!\n");
return false;
}
// 判定私钥是否正确
if (!SSL_CTX_check_private_key(m_ctx)) {
printf("SSL_CTX_check_private_key error!\n");
return false;
}
return true;}
accept 后, ssl = SSL_new(get_ssl_ctx()); 再调用 SSL_accept
bool c_Accept::handle_input()
{
sockaddr_in ip;
socklen_t len;
int cli_fd;
while (1) {
len = sizeof(ip);
cli_fd = accept(m_fd, (sockaddr *)&ip, &len);
if (cli_fd >= 0) {
if ((uint32_t)cli_fd >= get_reactor().max_handler()) {
close(cli_fd);
continue;
}
if (!get_reactor().add_cur_connect(get_max_connect())) {
printf("client max connect is over \n");
close(cli_fd);
return true;
}
SSL* ssl = nullptr;
if (isSsl()) {
ssl = SSL_new(get_ssl_ctx());
if (ssl == nullptr) {
get_reactor().sub_cur_connect();
close(cli_fd);
continue;
}
printf("accept SSL_new \n");
}
c_WebSocket*ts = new (std::nothrow) c_WebSocket();
if (!ts) {
get_reactor().sub_cur_connect();
close(cli_fd);
continue;
}
printf("accept client connect \n");
ts->start(cli_fd, ip,m_u32channel,m_u32pipeindex,ssl);
}
else {
if (errno == EAGAIN || errno == EINTR || errno == EMFILE || errno == ENFILE) {
return true;
}
else {
return false;
}
}
}
}
void c_WebSocket::start(int fd, sockaddr_in& ip, uint32_t channel, uint32_t u32pipeindex,SSL* ssl)
{
m_u32channel = channel;
m_u32pipeindex = u32pipeindex;
m_fd = fd;
m_ip = ip;
//---------------------------------------------
m_lastrecvmsg = g_reactor.getCurSecond();
c_heartbeat::GetInstance().handle_input_modify(fd, m_id, m_lastrecvmsg, m_lastrecvmsg);
set_noblock(m_fd);
m_ssl = ssl;
if (isSsl()) {
printf("ssl client handshake ready\n");
SSL_set_fd(ssl, m_fd);
int code, ret;
int retryTimes = 0;
// uint64_t begin = 0;//Time::SystemTime();
// 防止客户端连接了但不进行ssl握手, 单纯的增大循环次数无法解决问题,
while ((code = SSL_accept(ssl)) <= 0 && retryTimes++ < 100) {
ret = SSL_get_error(ssl, code);
if (ret != SSL_ERROR_WANT_READ) {
printf("ssl accept error. sslerror=%d errno=%d \n", ret,errno); // SSL_get_error(ssl, code));
break;
}
usleep(20 * 1000);//20ms //msleep(1); //这里最多会有2s的等待时间,以后一定要异步
}
if (code != 1) {
handle_fini();
return;
}
printf("ssl client handshake ok (%d)\n", retryTimes);
}
m_recv_buf_size = default_recv_buff_len;
m_recv_buf = (uint8_t*)malloc(m_recv_buf_size);
if (!m_recv_buf) {
handle_fini();
return;
}
//----------------------------------
return;
}
6:心跳检查:10 秒(可以自行设定)内未收到消息就发送 ping;连续发送 2 次仍没有回应,则断开连接
/**
 * Heartbeat check, called with the current time in seconds.
 *
 * Nothing happens while a message arrived within the last 10 seconds. After that:
 * a connection that never finished the handshake is reported for disconnect; a
 * connection whose last ping is at least 10 seconds old and that was already
 * pinged more than once is reported for disconnect; otherwise a ping is sent
 * when none is outstanding or the previous one is at least 10 seconds old.
 *
 * @param t  current time in seconds
 * @return true when the connection should be dropped, false otherwise
 */
bool c_WebSocket::checcklastmsg(uint32_t t) {
    if (m_lastrecvmsg + 10 > t) {
        return false;  // recent traffic, nothing to do
    }
    if (!is_handshake_ok()) {
        return true;
    }

    const bool ping_timed_out = (m_lastsendping + 10 <= t);

    if (m_lastsendping > 0 && ping_timed_out && m_sendpingcount > 1) {
        // disconnect
        printf("[%u]ready disconnect \n",t);
        return true;
    }

    if (m_lastsendping == 0 || (m_sendpingcount > 0 && ping_timed_out)) {
        // send a ping
        m_lastsendping = t;
        ++m_sendpingcount;
        send_ping_frame();
        printf("time=%u,%d send ping frame\n", t, m_sendpingcount);
    }
    return false;
}
/**
 * Sends a PING frame without payload: a two-byte header with FIN and the PING
 * opcode in the first byte and a zero, unmasked length in the second.
 *
 * @return the number of header bytes handed to send_pkg() (always 2)
 */
int c_WebSocket::send_ping_frame() {
    char header[14];
    memset(header, 0, sizeof(header));

    // First byte: FIN flag plus PING opcode.
    msg_set_bit(header, 7);
    header[0] |= PING_FRAME & 0x0f;

    // Second byte stays 0: frames sent by the server are not masked and the
    // payload length is zero.
    const int header_size = 2;

    send_pkg(reinterpret_cast<uint8_t*>(header), header_size);
    return header_size;
}
/**
 * Writes len bytes to the peer, through SSL_write() when SSL is enabled and
 * send() otherwise. handle_error() is called when the write fails with an errno
 * other than EAGAIN/EINTR.
 *
 * NOTE: bytes that could not be written (short write, EAGAIN, EINTR) are not
 * queued for a later retry here; that part is left to the integrator.
 *
 * @param buf  bytes to send
 * @param len  number of bytes in buf
 */
void c_WebSocket::send_pkg(uint8_t* buf, uint32_t len){
    int written;
    if (!isSsl()) {
        written = send(m_fd, buf, len, 0);
    }
    else {
        written = SSL_write(m_ssl, buf, len);  // send the response body
    }

    if (written > 0) {
        // Everything, or only a prefix, went out; see the NOTE above for the rest.
        return;
    }

    const bool retry_later = (errno == EAGAIN || errno == EINTR);
    if (!retry_later) {
        handle_error();
    }
}
7:json配置文件读取 jsoncpp API
/**
 * Reads the server configuration from a JSON file (at most 4 KiB) into m_chatSerCfg.
 *
 * Mandatory keys: listenip, listenport, maxconn, usessl, aes128keyhandshake,
 * aes128iv. When usessl is true, privatekeyfile, servercrtfile and cacrtfile are
 * mandatory too. checkheartbeat is optional; openblackwhiteip and singleipmaxconn
 * are optional and are applied only when both are present.
 *
 * @param jsonfile  path of the configuration file
 * @return true when the file was read and all mandatory keys were found
 */
bool c_JsonReader::read_json_file(const char* jsonfile)
{
#define LISTENIP "listenip"
#define LISTENPORT "listenport"
#define USESSL "usessl"
#define PRIVATEKEYFILE "privatekeyfile"
#define SERVERCRTFILE "servercrtfile"
#define CACRTFILE "cacrtfile"
#define AES128KEYHANDSHAKE "aes128keyhandshake"
#define AES128IV "aes128iv"
#define MAXCONN "maxconn"
#define CHECKHEARTBEAT "checkheartbeat"
#define OPENBLACKWHITEIP "openblackwhiteip"
#define SINGLEIPMAXCONN "singleipmaxconn"
#define MIN(A,B) ((A)<(B)?(A):(B))
    // Copies src into a buffer of dst_cap bytes, truncating if needed and always
    // writing the terminating NUL.
    auto copy_text = [](char* dst, size_t dst_cap, const std::string& src) {
        const size_t count = MIN(src.length(), dst_cap - 1);
        memcpy(dst, src.data(), count);
        dst[count] = 0;
    };

    FILE* f = fopen(jsonfile, "rb");
    if (f) {
        const int buf_size = 4 * 1024;
        char buf[buf_size] = { 0, };
        const size_t n = fread(buf, sizeof(char), buf_size, f);
        fclose(f);
        if (n < 10) {
            printf("read_json_file file length too short \n");
            return false;
        }
        Json::Reader reader;
        Json::Value root;
        // Parse exactly the bytes that were read: the buffer carries no NUL
        // terminator when the file fills it completely.
        if (reader.parse(buf, buf + n, root) && root.isObject()) {
            if (root[LISTENIP].empty() || root[LISTENPORT].empty() || root[MAXCONN].empty()
                || root[USESSL].empty() || root[AES128KEYHANDSHAKE].empty() || root[AES128IV].empty()) {
                printf("read_json_file base fail\n");
                return false;
            }
            const bool busessl = root[USESSL].asBool();
            m_chatSerCfg.buseSsl = busessl;
            if (busessl) {
                const bool bp = root[PRIVATEKEYFILE].empty();
                const bool bs = root[SERVERCRTFILE].empty();
                const bool bc = root[CACRTFILE].empty();
                if (bp || bs || bc) {
                    printf("read_json_file ssl fail\n");
                    return false;
                }
                // The path buffers hold at least ssl_file_len bytes (the amount the
                // previous strncpy-based code could write).
                const size_t path_cap = static_cast<size_t>(ssl_file_len);
                copy_text(m_chatSerCfg.szprivatekeyfile, path_cap, root[PRIVATEKEYFILE].asString());
                copy_text(m_chatSerCfg.szservercrtfile, path_cap, root[SERVERCRTFILE].asString());
                copy_text(m_chatSerCfg.szcacrtfile, path_cap, root[CACRTFILE].asString());
            }
            m_chatSerCfg.u16maxconn = (uint16_t)root[MAXCONN].asUInt();
            copy_text(m_chatSerCfg.szlistenip, sizeof(m_chatSerCfg.szlistenip), root[LISTENIP].asString());
            m_chatSerCfg.nlistenport = (int32_t)root[LISTENPORT].asInt();

            // Never copy more than the destination holds, nor more than the
            // configured string provides.
            // NOTE(review): assumes u8AES128keyhandshake is a fixed-size array member
            // (like szlistenip) - confirm against the struct definition.
            const std::string aes_key = root[AES128KEYHANDSHAKE].asString();
            memcpy(m_chatSerCfg.u8AES128keyhandshake, aes_key.data(),
                   MIN(aes_key.length(), sizeof(m_chatSerCfg.u8AES128keyhandshake)));
            const std::string aes_iv = root[AES128IV].asString();
            memcpy(m_chatSerCfg.u8AES128iv, aes_iv.data(), MIN(aes_iv.length(), (size_t)16));

            {//safe config
                const bool bcheck = root[CHECKHEARTBEAT].empty();
                const bool bbwip = root[OPENBLACKWHITEIP].empty();
                const bool bmaxconn = root[SINGLEIPMAXCONN].empty();
                if (!bcheck) {
                    m_chatSerCfg.bcheckheartbeat = root[CHECKHEARTBEAT].asBool();
                }
                // Apply the black/white-list settings only when both keys are present,
                // each read from its own key.
                if (!bbwip && !bmaxconn) {
                    m_chatSerCfg.u8openblackwhiteip = (uint8_t)root[OPENBLACKWHITEIP].asUInt();
                    m_chatSerCfg.u8singleipmaxconn = (uint8_t)root[SINGLEIPMAXCONN].asUInt();
                }
            }
            return true;
        }
    }
    printf("json file no exist or parse json file fail \n");
    return false;
}
8:只是帮助分析websocket 协议
红框这边 ssl_accept 是需要优化的,可以考虑用coroutine 或 thread callback
9: 后续继续优化,差不多,再上demo
如果觉得有用,麻烦点个赞,加个收藏