基于CSDN上 SGuniver_22的《Linux下c语言实验Websocket通讯 含客户端和服务器测试代码》的websocket聊天代码


  
SGuniver_22原文在此: 《Linux下c语言实验Websocket通讯 含客户端和服务器测试代码》
我借用了其中的握手代码,并在此基础上先实现了一个 WebSocket echo server,代码如下:
server_chat_6.c:


#include <errno.h>
#include <fcntl.h>
#include <netdb.h>
#include <netinet/in.h>
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <strings.h>
#include <sys/epoll.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <time.h>
#include <unistd.h>

#define MAX_EVENT_NUMBER 1000

/* Size in bytes of the shared receive buffer (20 MiB). The expansion is
 * parenthesized so the macro behaves as a single value in any expression
 * (e.g. `x % BUFFER_SIZE`, `x / BUFFER_SIZE`). */
#define BUFFER_SIZE  (20*1024*1024)
/* Defined after main(); parses and answers one WebSocket frame from the buffer. */
int try_analyze_buffer(unsigned char* buffer, int* p_begin,  int in_amount, int sockfd);


//==================== 加密方法 sha1哈希 ====================

/* Working state of one SHA-1 computation (see SHA1Reset / SHA1Input / SHA1Result). */
typedef struct SHA1Context
{
    uint32_t Message_Digest[5];  /* five 32-bit chaining words; seeded by SHA1Reset, updated after every block */
    uint32_t Length_Low;         /* message length in bits, low 32 bits (SHA1Input adds 8 per byte) */
    uint32_t Length_High;        /* message length in bits, high 32 bits (bumped when Length_Low wraps) */
    uint8_t Message_Block[64];   /* the 64-byte block currently being filled */
    int32_t Message_Block_Index; /* next free position in Message_Block */
    int32_t Computed;            /* set to 1 by SHA1Result once the final block has been processed */
    int32_t Corrupted;           /* set to 1 on length overflow or on input after the digest was computed */
} SHA1Context;

/* Rotate the 32-bit value `word` left by `bits` positions (0 < bits < 32). */
#define SHA1CircularShift(bits, word) (((word) >> (32 - (bits))) | (((word) << (bits)) & 0xFFFFFFFF))

/*
 * Run the SHA-1 compression function over the 64 bytes currently held in
 * context->Message_Block, fold the result into context->Message_Digest and
 * reset Message_Block_Index to 0 so the block can be refilled.
 */
void SHA1ProcessMessageBlock(SHA1Context *context)
{
    static const uint32_t round_constant[4] = {0x5A827999, 0x6ED9EBA1, 0x8F1BBCDC, 0xCA62C1D6};
    uint32_t schedule[80];
    uint32_t a, b, c, d, e;
    int32_t i;

    /* Words 0..15: the block itself, read big-endian. */
    for (i = 0; i < 16; i++)
    {
        const uint8_t *p = &context->Message_Block[i * 4];
        schedule[i] = ((uint32_t)p[0] << 24) | ((uint32_t)p[1] << 16) |
                      ((uint32_t)p[2] << 8) | (uint32_t)p[3];
    }

    /* Words 16..79: expanded from earlier words. */
    for (; i < 80; i++)
    {
        uint32_t mixed = schedule[i - 3] ^ schedule[i - 8] ^ schedule[i - 14] ^ schedule[i - 16];
        schedule[i] = SHA1CircularShift(1, mixed);
    }

    a = context->Message_Digest[0];
    b = context->Message_Digest[1];
    c = context->Message_Digest[2];
    d = context->Message_Digest[3];
    e = context->Message_Digest[4];

    /* 80 rounds in four groups of 20; each group has its own mixing
     * function and additive constant. uint32_t arithmetic wraps mod 2^32. */
    for (i = 0; i < 80; i++)
    {
        const int32_t group = i / 20;
        uint32_t f;
        uint32_t next;

        switch (group)
        {
        case 0:
            f = (b & c) | ((~b) & d);
            break;
        case 2:
            f = (b & c) | (b & d) | (c & d);
            break;
        default: /* groups 1 and 3 */
            f = b ^ c ^ d;
            break;
        }

        next = SHA1CircularShift(5, a) + f + e + schedule[i] + round_constant[group];
        e = d;
        d = c;
        c = SHA1CircularShift(30, b);
        b = a;
        a = next;
    }

    context->Message_Digest[0] += a;
    context->Message_Digest[1] += b;
    context->Message_Digest[2] += c;
    context->Message_Digest[3] += d;
    context->Message_Digest[4] += e;
    context->Message_Block_Index = 0;
}

/* Put `context` into the initial SHA-1 state: zero length, empty block,
 * standard starting digest words, and both status flags cleared. */
void SHA1Reset(SHA1Context *context)
{
    static const uint32_t initial_digest[5] = {
        0x67452301, 0xEFCDAB89, 0x98BADCFE, 0x10325476, 0xC3D2E1F0};
    int32_t k;

    for (k = 0; k < 5; k++)
        context->Message_Digest[k] = initial_digest[k];

    context->Length_High = 0;
    context->Length_Low = 0;
    context->Message_Block_Index = 0;
    context->Corrupted = 0;
    context->Computed = 0;
}

/*
 * Append the SHA-1 padding to the pending block: a 0x80 byte, zeros up to
 * offset 56, then the 64-bit message length (big-endian). If the 0x80 byte
 * leaves no room for the length, the block is zero-filled and processed first
 * and the length goes into a fresh block. The final block is processed too.
 */
void SHA1PadMessage(SHA1Context *context)
{
    uint8_t *block = context->Message_Block;
    int32_t shift;
    int32_t pos;

    block[context->Message_Block_Index++] = 0x80;

    if (context->Message_Block_Index > 56)
    {
        /* No room left for the length field: flush this block. */
        while (context->Message_Block_Index < 64)
            block[context->Message_Block_Index++] = 0;
        SHA1ProcessMessageBlock(context); /* resets Message_Block_Index to 0 */
    }

    while (context->Message_Block_Index < 56)
        block[context->Message_Block_Index++] = 0;

    /* Bytes 56..59 = Length_High, 60..63 = Length_Low, most significant first. */
    for (pos = 56, shift = 24; shift >= 0; pos++, shift -= 8)
    {
        block[pos] = (context->Length_High >> shift) & 0xFF;
        block[pos + 4] = (context->Length_Low >> shift) & 0xFF;
    }

    SHA1ProcessMessageBlock(context);
}

/*
 * Finish the digest. Returns 0 if the context is marked corrupted, otherwise
 * pads and processes the final block (only on the first call) and returns 1.
 * The digest is then available in context->Message_Digest.
 */
int32_t SHA1Result(SHA1Context *context)
{
    if (context->Corrupted)
        return 0;

    if (context->Computed)
        return 1;

    SHA1PadMessage(context);
    context->Computed = 1;
    return 1;
}

/*
 * Feed `length` bytes from `message_array` into the hash. A zero length is a
 * no-op. Calling this after the digest was computed (or on an already
 * corrupted context) marks the context corrupted. Every full 64-byte block is
 * processed as soon as it is complete.
 */
void SHA1Input(SHA1Context *context, const char *message_array, uint32_t length)
{
    uint32_t consumed;

    if (length == 0)
        return;

    if (context->Computed || context->Corrupted)
    {
        context->Corrupted = 1;
        return;
    }

    for (consumed = 0; consumed < length && !context->Corrupted; consumed++)
    {
        context->Message_Block[context->Message_Block_Index++] = (message_array[consumed] & 0xFF);

        /* Track the length in bits as a 64-bit counter split over two words. */
        context->Length_Low += 8;
        if (context->Length_Low == 0)
        {
            context->Length_High++;
            if (context->Length_High == 0)
                context->Corrupted = 1; /* message longer than 2^64 bits */
        }

        if (context->Message_Block_Index == 64)
            SHA1ProcessMessageBlock(context);
    }
}

/* int32_t sha1_hash(const char *source, char *lrvar){//Main 
    SHA1Context sha; 
    char buf[128]; 
 
    SHA1Reset(&sha); 
    SHA1Input(&sha, source, strlen(source)); 
 
    if (!SHA1Result(&sha)){ 
        printf("SHA1 ERROR: Could not compute message digest"); 
        return -1; 
    } else { 
        memset(buf,0,sizeof(buf)); 
        sprintf(buf, "%08X%08X%08X%08X%08X", sha.Message_Digest[0],sha.Message_Digest[1], 
        sha.Message_Digest[2],sha.Message_Digest[3],sha.Message_Digest[4]); 
        //lr_save_string(buf, lrvar); 
         
        return strlen(buf); 
    } 
} */

/*
 * Compute the SHA-1 of the NUL-terminated string `source`.
 *
 * Returns a newly allocated, NUL-terminated string holding the digest as 40
 * upper-case hexadecimal characters; the caller owns it and must free() it.
 * Returns NULL if `source` is NULL, if the digest cannot be computed, or if
 * the allocation fails.
 */
char *sha1_hash(const char *source)
{
    enum { SHA1_HEX_CHARS = 40 };
    SHA1Context sha;
    char *buf;

    if (source == NULL)
        return NULL;

    SHA1Reset(&sha);
    SHA1Input(&sha, source, strlen(source));

    if (!SHA1Result(&sha))
    {
        printf("SHA1 ERROR: Could not compute message digest");
        return NULL;
    }

    /* calloc zero-fills, so the result is NUL-terminated in every case. */
    buf = (char *)calloc(SHA1_HEX_CHARS + 1, sizeof(char));
    if (buf == NULL)
        return NULL;

    snprintf(buf, SHA1_HEX_CHARS + 1, "%08X%08X%08X%08X%08X",
             (unsigned)sha.Message_Digest[0], (unsigned)sha.Message_Digest[1],
             (unsigned)sha.Message_Digest[2], (unsigned)sha.Message_Digest[3],
             (unsigned)sha.Message_Digest[4]);
    return buf;
}

/*
 * Map 'A'..'Z' to 'a'..'z'; every other value is returned unchanged.
 * NOTE(review): this has the same name as the <ctype.h> function; confirm the
 * build does not pull in a conflicting definition.
 */
int32_t tolower(int32_t c)
{
    const int32_t is_upper = (c >= 'A' && c <= 'Z');

    return is_upper ? (c + 'a' - 'A') : c;
}

/*
 * Convert up to `len` hexadecimal digits of `s` to an integer.
 *
 * Reading starts at offset `start`, counted after an optional leading
 * "0x"/"0X" at the very beginning of `s`. Conversion stops at the first
 * non-hex character or after `len` digits, whichever comes first.
 * Returns the accumulated value (0 if no digit was read).
 */
int32_t htoi(const char s[], int32_t start, int32_t len)
{
    int32_t value = 0;
    int32_t taken;
    int32_t pos = (s[0] == '0' && (s[1] == 'x' || s[1] == 'X')) ? 2 : 0;

    pos += start;

    for (taken = 0; taken < len; taken++, pos++)
    {
        const char ch = s[pos];
        int32_t digit;

        if (ch >= '0' && ch <= '9')
            digit = ch - '0';
        else if (ch >= 'a' && ch <= 'f')
            digit = 10 + ch - 'a';
        else if (ch >= 'A' && ch <= 'F')
            digit = 10 + ch - 'A';
        else
            break;

        value = 16 * value + digit;
    }
    return value;
}
//==================== 加密方法BASE64 ====================

/* Alphabet used by the base64 encoder and decoder: index -> character. */
const char ws_base64char[] =
    "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";

/*******************************************************************************
 * Name:    ws_base64_encode
 * Purpose: encode a byte buffer as a base64 string
 * Params:
 *      bindata:   input bytes
 *      base64:    output buffer; receives the '='-padded, NUL-terminated text
 *                 and must hold 4 * ceil(binlength / 3) + 1 bytes
 *      binlength: number of input bytes
 * Returns: length of the base64 text, not counting the terminating NUL
 ******************************************************************************/
int32_t ws_base64_encode(const uint8_t *bindata, char *base64, int32_t binlength)
{
    int32_t in = 0;
    int32_t out = 0;

    /* Each pass turns up to three input bytes into four output characters. */
    while (in < binlength)
    {
        const int32_t left = binlength - in;
        const uint32_t b0 = bindata[in];
        const uint32_t b1 = (left > 1) ? bindata[in + 1] : 0;
        const uint32_t b2 = (left > 2) ? bindata[in + 2] : 0;

        base64[out++] = ws_base64char[b0 >> 2];
        base64[out++] = ws_base64char[((b0 & 0x03) << 4) | (b1 >> 4)];
        /* Missing input bytes are represented by '=' padding. */
        base64[out++] = (left > 1) ? ws_base64char[((b1 & 0x0F) << 2) | (b2 >> 6)] : '=';
        base64[out++] = (left > 2) ? ws_base64char[b2 & 0x3F] : '=';

        in += 3;
    }

    base64[out] = '\0';
    return out;
}
/* Map one base64 character to its 6-bit value, or -1 if it is not part of the
 * alphabet ('A'-'Z', 'a'-'z', '0'-'9', '+', '/'). Assumes an ASCII-style
 * character set in which each of those ranges is contiguous. */
static int32_t ws_base64_sextet(char c)
{
    if (c >= 'A' && c <= 'Z')
        return c - 'A';
    if (c >= 'a' && c <= 'z')
        return c - 'a' + 26;
    if (c >= '0' && c <= '9')
        return c - '0' + 52;
    if (c == '+')
        return 62;
    if (c == '/')
        return 63;
    return -1;
}

/*******************************************************************************
 * Name:    ws_base64_decode
 * Purpose: decode a base64 string into bytes
 * Params:
 *      base64:  NUL-terminated base64 text
 *      bindata: output buffer; must hold 3 bytes per 4 input characters
 * Returns: number of bytes written to bindata
 * Notes:   Decoding stops at the first '=' padding character or at the end of
 *          the string, so missing padding is tolerated and no byte beyond the
 *          terminating NUL is ever read. If a character outside the alphabet
 *          is met, decoding stops before the group containing it and the
 *          count of bytes produced so far is returned. NULL arguments yield 0.
 ******************************************************************************/
int32_t ws_base64_decode(const char *base64, uint8_t *bindata)
{
    int32_t i = 0;
    int32_t j = 0;

    if (base64 == NULL || bindata == NULL)
        return 0;

    while (base64[i] != '\0')
    {
        int32_t v[4] = {0, 0, 0, 0};
        int32_t count = 0;

        /* Collect up to four sextets, stopping at padding or end of string. */
        while (count < 4)
        {
            const char c = base64[i + count];
            if (c == '\0' || c == '=')
                break;
            v[count] = ws_base64_sextet(c);
            if (v[count] < 0)
                return j;
            count++;
        }

        if (count < 2)
            break; /* a single sextet cannot form a byte */
        bindata[j++] = (uint8_t)((v[0] << 2) | (v[1] >> 4));
        if (count < 3)
            break;
        bindata[j++] = (uint8_t)(((v[1] & 0x0F) << 4) | (v[2] >> 2));
        if (count < 4)
            break;
        bindata[j++] = (uint8_t)(((v[2] & 0x03) << 6) | v[3]);

        i += 4;
    }
    return j;
}

/*******************************************************************************
 * Name:    ws_buildRespondShakeKey
 * Purpose: server side; derive the handshake reply key from the client's key
 * Params:
 *      acceptKey:    key string received from the client
 *      acceptKeyLen: its length in bytes
 *      respondKey:   output; base64( sha1( acceptKey + GUID ) ), NUL-terminated.
 *                    Must hold at least 29 bytes (28 characters + NUL).
 * Returns: length of respondKey, or 0 on failure (NULL argument, allocation
 *          failure, or SHA-1 failure); respondKey is not written on failure.
 ******************************************************************************/
int32_t ws_buildRespondShakeKey(char *acceptKey, uint32_t acceptKeyLen, char *respondKey)
{
    static const char GUID[] = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
    const uint32_t guidLen = sizeof(GUID) - 1; /* without the terminating NUL */
    char *clientKey = NULL;
    char *sha1Hex = NULL;
    char *sha1Raw = NULL;
    int32_t hexLen;
    int32_t i;
    int32_t n = 0;

    if (acceptKey == NULL || respondKey == NULL)
        return 0;

    /* calloc zero-fills, so the concatenation below stays NUL-terminated. */
    clientKey = (char *)calloc((size_t)acceptKeyLen + guidLen + 1, sizeof(char));
    if (clientKey == NULL)
        goto done;
    memcpy(clientKey, acceptKey, acceptKeyLen);
    memcpy(clientKey + acceptKeyLen, GUID, guidLen);

    /* sha1_hash returns the digest as a hex string (or NULL on failure). */
    sha1Hex = sha1_hash(clientKey);
    if (sha1Hex == NULL)
        goto done;

    hexLen = (int32_t)strlen(sha1Hex);
    sha1Raw = (char *)calloc((size_t)hexLen / 2 + 1, sizeof(char));
    if (sha1Raw == NULL)
        goto done;

    /* Turn each pair of hex digits back into one raw digest byte. */
    for (i = 0; i + 1 < hexLen; i += 2)
        sha1Raw[i / 2] = (char)htoi(sha1Hex, i, 2);

    n = ws_base64_encode((const uint8_t *)sha1Raw, respondKey, hexLen / 2);

done:
    free(sha1Raw);
    free(sha1Hex);
    free(clientKey);
    return n;
}
/*******************************************************************************
 * Name:    ws_buildHttpRespond
 * Purpose: build the HTTP "101 Switching Protocols" reply to a client's
 *          WebSocket upgrade request
 * Params:
 *      acceptKey:    handshake key received from the client
 *      acceptKeyLen: its length
 *      package:      output buffer for the NUL-terminated response. It must be
 *                    large enough for the whole response (roughly 200 bytes;
 *                    the caller in this file passes a 1024-byte buffer).
 * Returns: nothing
 * Notes:   If the reply key cannot be built, package is left as an empty
 *          string so that no invalid handshake is sent. The Date header is
 *          expressed in GMT, as HTTP requires; it is omitted if the time
 *          cannot be formatted.
 ******************************************************************************/
void ws_buildHttpRespond(char *acceptKey, uint32_t acceptKeyLen, char *package)
{
    const char httpDemo[] =
        "HTTP/1.1 101 Switching Protocols\r\n"
        "Upgrade: websocket\r\n"
        "Server: Microsoft-HTTPAPI/2.0\r\n"
        "Connection: Upgrade\r\n"
        "Sec-WebSocket-Accept: %s\r\n"
        "%s\r\n"; /* second %s: complete "Date: ...\r\n" line, or empty */
    time_t now;
    struct tm *tm_now;
    char timeStr[256] = {0};
    char respondShakeKey[256] = {0};

    if (package == NULL)
        return;
    package[0] = '\0';

    /* Build the reply handshake key. */
    if (ws_buildRespondShakeKey(acceptKey, acceptKeyLen, respondShakeKey) <= 0)
        return;

    /* Build the Date header line, e.g. "Date: Tue, 20 Jun 2017 08:50:41 GMT". */
    time(&now);
    tm_now = gmtime(&now);
    if (tm_now == NULL ||
        strftime(timeStr, sizeof(timeStr), "Date: %a, %d %b %Y %H:%M:%S GMT\r\n", tm_now) == 0)
    {
        timeStr[0] = '\0';
    }

    /* Assemble the response. */
    sprintf(package, httpDemo, respondShakeKey, timeStr);
}

/*******************************************************************************
 * Name:    ws_responseClient
 * Purpose: answer a client's upgrade request to establish the WebSocket
 *          connection
 * Params:
 *      fd:   connected socket
 *      data: NUL-terminated data received from the client (the HTTP request)
 *      path: text that must occur in the request (path check); NULL disables
 *            the check
 * Returns: > 0 when the handshake response was sent; <= 0 on failure
 *          (-1 if the request is unusable, otherwise the result of send()).
 ******************************************************************************/
int ws_responseClient(int fd, char *data,  char *path)
{
    static const char keyHeader[] = "Sec-WebSocket-Key: ";
    char *keyOffset;
    size_t keyLen;
    char recvShakeKey[512] = {0};
    char respondPackage[1024] = {0};

    if (data == NULL)
        return -1;
    /* Path check. */
    if (path && !strstr(data, path))
        return -1;
    /* Locate the handshake key header. */
    keyOffset = strstr(data, keyHeader);
    if (keyOffset == NULL)
        return -1;
    keyOffset += sizeof(keyHeader) - 1;
    /* The request is untrusted: bound the copy to the size of recvShakeKey. */
    if (sscanf(keyOffset, "%511s", recvShakeKey) != 1)
        return -1;
    keyLen = strlen(recvShakeKey);
    if (keyLen < 1)
        return -1;
    /* Build and send the reply. */
    ws_buildHttpRespond(recvShakeKey, (uint32_t)keyLen, respondPackage);
    return (int)send(fd, respondPackage, strlen(respondPackage), MSG_NOSIGNAL);
}


/* Receive buffer of BUFFER_SIZE bytes; main() reuses it for whichever socket
 * is being read. */
unsigned char buf[BUFFER_SIZE];
/* Placeholder for an edge-triggered event handler: the body is empty and the
 * only call visible in this file (inside main) is commented out. */
void et(struct epoll_event* events, int number, int epollfd, int listenfd)
{
	
}


/*
 * Switch `fd` to non-blocking mode.
 *
 * Returns the file status flags the descriptor had before the change, or -1
 * if the flags could not be read or updated (errno is set by fcntl).
 */
int setnonblocking(int fd)
{
	int old_option = fcntl(fd, F_GETFL);
	if (old_option == -1)
		return -1;

	if (fcntl(fd, F_SETFL, old_option | O_NONBLOCK) == -1)
		return -1;

	return old_option;
}
/*
 * Register `fd` with the epoll instance `epollfd` for readability events
 * (edge-triggered when `enable_et` is true), then make `fd` non-blocking.
 */
void addfd(int epollfd, int fd, bool enable_et)
{
	struct epoll_event  ev;

	ev.events = EPOLLIN;
	if (enable_et)
	{
		ev.events |= EPOLLET;
	}
	ev.data.fd = fd;

	epoll_ctl(epollfd, EPOLL_CTL_ADD, fd, &ev);
	setnonblocking(fd);
}

/* Accept every connection queued on the non-blocking listening socket and
 * register each one with the epoll instance. With edge-triggered
 * notification a single event may stand for several pending connections, so
 * accept() is repeated until it reports that nothing is left. */
static void accept_pending_clients(int epollfd, int listenfd)
{
	for (;;)
	{
		struct sockaddr_in client_address;
		socklen_t  client_addrlength = sizeof(client_address);
		int connfd = accept(listenfd, (struct sockaddr*) &client_address, &client_addrlength);
		if (connfd < 0)
		{
			if (errno == EINTR)
				continue;
			if (errno != EAGAIN && errno != EWOULDBLOCK)
				perror("accept");
			return;
		}
		addfd(epollfd, connfd, true);
	}
}

/* Handle one readability event on a client socket: read everything that is
 * available into the shared buffer `buf`, then either answer the HTTP upgrade
 * request or process the buffered WebSocket frames.
 *
 * NOTE: there is no per-connection buffer, so an incomplete frame cannot be
 * kept across events. When try_analyze_buffer() reports that more data is
 * needed, this function keeps polling the socket until the rest arrives or
 * the connection ends; other clients are not served meanwhile. */
static void serve_client(int sockfd)
{
	int begin  = 0;
	int amount = 0;

	printf("EPOLLIN triggered once.\n");
	memset(buf, 0, BUFFER_SIZE);

	for (;;)
	{
		/* Drain the socket (required with edge-triggered notification). */
		for (;;)
		{
			/* Keep one byte free so buf always stays NUL-terminated. */
			int room = BUFFER_SIZE - 1 - amount;
			ssize_t got;

			if (room <= 0)
			{
				unsigned char tmpbuf[2];
				int ret2;

				printf("buffer maybe overflowed\n");
				tmpbuf[0] = 0x88; /* close frame, empty payload */
				tmpbuf[1] = 0x0;
				ret2 = (int)send(sockfd, tmpbuf, 2, MSG_NOSIGNAL);
				printf("send %d bytes.\n", ret2);
				close(sockfd);
				return;
			}

			got = recv(sockfd, buf + amount, (size_t)room, 0);
			if (got > 0)
			{
				printf("get %d bytes of content\n", (int)got);
				amount += (int)got;
				continue;
			}
			if (got == 0)
			{
				/* Peer closed the connection. */
				close(sockfd);
				return;
			}
			if (errno == EINTR)
				continue;
			if (errno == EAGAIN || errno == EWOULDBLOCK)
			{
				printf("error EAGAIN or EWOULDBLOCK\n");
				break;
			}
			close(sockfd);
			return;
		}

		if (amount <= 0)
			return;

		/* analyze the packet */
		if (strstr((char*)buf, "HTTP"))
		{
			ws_responseClient(sockfd, (char*)buf, (char*)"/");
			return;
		}

		int state;
		do
		{
			state = try_analyze_buffer(buf, &begin, amount, sockfd);
		}
		while (state == 1);

		if (state != -1)
			return;
		/* state == -1: the last frame is incomplete; read more. */
	}
}

/*
 * WebSocket echo server entry point.
 * Usage: ./server_chat port
 * Listens on every local IPv4 address at the given TCP port and serves
 * clients from a single epoll loop. Returns 1 on a usage error, -1 when the
 * listening socket or epoll instance cannot be set up, 0 otherwise.
 */
int main(int argc, char** argv)
{
	if (argc != 2)
	{
		printf("Usage: ./server_chat    port\n");
		return 1;
	}

	/* Self-test of the SHA-1 code; sha1_hash returns an allocated string. */
	char* demo = sha1_hash("abc");
	printf("abc's sha1 is %s\n", demo ? demo : "(null)");
	free(demo);

	char* end = NULL;
	long port = strtol(argv[1], &end, 10);
	if (end == argv[1] || *end != '\0' || port <= 0 || port > 65535)
	{
		printf("Usage: ./server_chat    port\n");
		return 1;
	}

	struct sockaddr_in  address;
	memset(&address, 0, sizeof(address));
	address.sin_family  =  AF_INET;
	address.sin_addr.s_addr =  INADDR_ANY;
	address.sin_port = htons((unsigned short)port);

	/* Fall back to protocol 0 (the default for SOCK_STREAM) if the lookup fails. */
	struct protoent*  ppe = getprotobyname("tcp");
	int listenfd = socket(PF_INET, SOCK_STREAM, ppe ? ppe->p_proto : 0);
	if (listenfd == -1)
	{
		printf("socket error");
		return -1;
	}

	if (bind(listenfd, (struct sockaddr* )&address, sizeof(address)) == -1)
	{
		printf("bind error");
		close(listenfd);
		return -1;
	}

	if (listen(listenfd, 5) == -1)
	{
		printf("listen error");
		close(listenfd);
		return -1;
	}

	struct epoll_event events[MAX_EVENT_NUMBER];
	int  epollfd = epoll_create(5);
	if (epollfd == -1)
	{
		printf("epoll_create error");
		close(listenfd);
		return -1;
	}
	addfd(epollfd, listenfd, true);

	while (1)
	{
		int number = epoll_wait(epollfd, events, MAX_EVENT_NUMBER, -1);
		if (number < 0)
		{
			if (errno == EINTR)
				continue;
			printf("epoll_wait failure\n");
			break;
		}

		for (int i = 0; i < number; i++)
		{
			int sockfd = events[i].data.fd;
			if (sockfd == listenfd)
			{
				accept_pending_clients(epollfd, listenfd);
			}
			else if (events[i].events & EPOLLIN)
			{
				serve_client(sockfd);
			}
			else
			{
				printf("something else happened.\n");
				/* Only errors and hang-ups arrive without EPOLLIN; drop the socket. */
				if (events[i].events & (EPOLLERR | EPOLLHUP))
					close(sockfd);
			}
		}
	}

	close(epollfd);
	close(listenfd);
	return 0;
}


/*
 * Parse one WebSocket frame from the receive buffer and answer it.
 *
 * buffer    : start of the receive buffer
 * p_begin   : in/out; offset of the frame to parse. Advanced past the frame
 *             when more data follows it.
 * in_amount : number of valid bytes in buffer
 * sockfd    : socket used for the reply
 *
 * The payload is unmasked in place (only when the MASK bit is set) and an
 * unmasked reply header is written directly in front of it, so the reply is
 * sent straight from the buffer. Text, binary and continuation frames are
 * echoed, a ping is answered with a pong, and a close frame is answered with
 * an empty close frame. Other opcodes are consumed without a reply.
 *
 * Returns -1 if the buffer does not yet hold a complete frame,
 *          1 if the frame was handled and more bytes follow it,
 *          0 if the frame was handled and the buffer is exhausted, or if the
 *            reply could not be sent.
 */
int try_analyze_buffer(unsigned char* buffer, int* p_begin,  int in_amount, int sockfd)
{
    printf("looking at %p,begin:%d,amount:%d; \n", (void *)buffer, *p_begin, in_amount);

    int amount = in_amount - *p_begin;
    if (amount < 2)
    {
        return -1;
    }

    unsigned char *frame = buffer + *p_begin;
    uint8_t FIN = frame[0] & 0x80;
    uint8_t OPCODE = frame[0] & 0x0F;
    uint8_t MASK = frame[1] & 0x80;
    uint32_t payload_len = frame[1] & 0x7F;
    printf("OPCODE:%d, MASK:%d, payload_len:%d\n", OPCODE, MASK, (int)payload_len);

    /* 126 and 127 announce a 2-byte / 8-byte big-endian length field. */
    int ext_len = 0;
    if (payload_len == 126)
        ext_len = 2;
    else if (payload_len == 127)
        ext_len = 8;

    int header_len = 2 + ext_len; /* header size without the mask key */
    if (amount < header_len)
        return -1;

    unsigned long long pay_len = payload_len;
    if (ext_len > 0)
    {
        pay_len = 0;
        for (int k = 0; k < ext_len; k++)
            pay_len = (pay_len << 8) | frame[2 + k];
    }

    /* Compare in unsigned 64-bit space so a huge announced length can never
     * wrap around and slip past the bounds check. */
    int payload_begin = header_len + (MASK ? 4 : 0);
    if (amount < payload_begin || pay_len > (unsigned long long)(amount - payload_begin))
        return -1;

    long long len = payload_begin + (long long)pay_len; /* size of the whole frame */
    unsigned char *payload = frame + payload_begin;

    if (MASK)
    {
        uint8_t MASKS[4];
        for (int k = 0; k < 4; k++)
            MASKS[k] = frame[header_len + k];
        for (unsigned long long i = 0; i < pay_len; i++)
            payload[i] ^= MASKS[i % 4];
    }

    /* Debug dump of the payload: everything for short and medium frames,
     * only the last 15 bytes for frames using the 8-byte length form. */
    int dump = (payload_len < 126) ? (OPCODE == 0x1 || OPCODE == 0x2) : (OPCODE == 0x1);
    if (dump)
    {
        unsigned long long first = 0;
        if (payload_len == 127)
            first = (pay_len >= 16) ? pay_len - 15 : pay_len;
        for (unsigned long long i = first; i < pay_len; i++)
            printf("%d: %c, 0x%x\n", (int)i, (char)payload[i], payload[i]);
    }

    /* Write the reply header right before the payload. This overwrites the
     * mask key / old header, which are no longer needed. */
    unsigned char *reply = payload - header_len;
    reply[0] = (unsigned char)(FIN | (OPCODE == 0x9 ? 0xA : OPCODE)); /* ping -> pong */
    reply[1] = (unsigned char)payload_len;
    for (int k = 0; k < ext_len; k++)
        reply[2 + k] = (unsigned char)((pay_len >> (8 * (ext_len - 1 - k))) & 0xFF);

    if (OPCODE == 0x1 || OPCODE == 0x0 || OPCODE == 0x2 || OPCODE == 0x9)
    {
        long long amount_out = header_len + (long long)pay_len;
        long long amount_sent = 0;

        while (amount_sent < amount_out)
        {
            ssize_t n = send(sockfd, reply + amount_sent,
                             (size_t)(amount_out - amount_sent), MSG_NOSIGNAL);
            if (n > 0)
            {
                amount_sent += n;
                continue;
            }
            /* The socket may be non-blocking: retry while it is merely busy. */
            if (n < 0 && (errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR))
                continue;

            /* Anything else is fatal for this reply; never spin on it. */
            printf("send failed after %lld/%lld bytes.\n", amount_sent, amount_out);
            return 0;
        }
        printf("send %lld/%lld bytes.payload_begin:%d, payload_len:%d,pay_len:%llu.\n",
               amount_sent, amount_out, payload_begin, (int)payload_len, pay_len);
        fflush(stdout);
    }
    else if (OPCODE == 0x8)
    {
        unsigned char tmpbuf[2];
        tmpbuf[0] = 0x88;
        tmpbuf[1] = 0x0;
        int ret2 = (int)send(sockfd, tmpbuf, 2, MSG_NOSIGNAL);
        printf("send %d bytes.\n", ret2);
    }

    printf("amount: %d, len: %lld\n", amount, len);
    if (len < amount)
    {
        printf("more frame?\n");
        *p_begin += (int)len;
        return 1;
    }

    printf("a complete frame.\n");
    return 0;
}

 
Autobahn测试该echo server报告如下:(wstest -m fuzzingclient -s fuzzingclient.json) reports_complete1.zip
源文件在此:server_chat_6.c
然后改动试验聊天室代码:
websocket_chat.cpp(部分代码):
/* Accept every connection queued on the non-blocking listening socket and
 * register each one with the epoll instance. With edge-triggered
 * notification a single event may stand for several pending connections, so
 * accept() is repeated until it reports that nothing is left. */
static void accept_pending_clients(int epollfd, int listenfd)
{
	for (;;)
	{
		struct sockaddr_in client_address;
		socklen_t  client_addrlength = sizeof(client_address);
		int connfd = accept(listenfd, (struct sockaddr*) &client_address, &client_addrlength);
		if (connfd < 0)
		{
			if (errno == EINTR)
				continue;
			if (errno != EAGAIN && errno != EWOULDBLOCK)
				perror("accept");
			return;
		}
		addfd(epollfd, connfd, true);
	}
}

/* Handle one readability event on a client socket: read everything that is
 * available into the shared buffer `buf`, then either answer the HTTP upgrade
 * request or process the buffered WebSocket frames.
 *
 * NOTE: there is no per-connection buffer, so an incomplete frame cannot be
 * kept across events. When try_analyze_buffer() reports that more data is
 * needed, this function keeps polling the socket until the rest arrives or
 * the connection ends; other clients are not served meanwhile. */
static void serve_client(int sockfd)
{
	int begin  = 0;
	int amount = 0;

	printf("EPOLLIN triggered once.\n");
	memset(buf, 0, BUFFER_SIZE);

	for (;;)
	{
		/* Drain the socket (required with edge-triggered notification). */
		for (;;)
		{
			/* Keep one byte free so buf always stays NUL-terminated. */
			int room = BUFFER_SIZE - 1 - amount;
			ssize_t got;

			if (room <= 0)
			{
				unsigned char tmpbuf[2];
				int ret2;

				printf("buffer maybe overflowed\n");
				tmpbuf[0] = 0x88; /* close frame, empty payload */
				tmpbuf[1] = 0x0;
				ret2 = (int)send(sockfd, tmpbuf, 2, MSG_NOSIGNAL);
				printf("send %d bytes.\n", ret2);
				close(sockfd);
				return;
			}

			got = recv(sockfd, buf + amount, (size_t)room, 0);
			if (got > 0)
			{
				printf("get %d bytes of content\n", (int)got);
				amount += (int)got;
				continue;
			}
			if (got == 0)
			{
				/* Peer closed the connection. */
				close(sockfd);
				return;
			}
			if (errno == EINTR)
				continue;
			if (errno == EAGAIN || errno == EWOULDBLOCK)
			{
				printf("error EAGAIN or EWOULDBLOCK\n");
				break;
			}
			close(sockfd);
			return;
		}

		if (amount <= 0)
			return;

		/* analyze the packet */
		if (strstr((char*)buf, "HTTP"))
		{
			ws_responseClient(sockfd, (char*)buf, (char*)"/");
			return;
		}

		int state;
		do
		{
			state = try_analyze_buffer(buf, &begin, amount, sockfd);
		}
		while (state == 1);

		if (state != -1)
			return;
		/* state == -1: the last frame is incomplete; read more. */
	}
}

/*
 * WebSocket server entry point.
 * Usage: ./server_chat port
 * Listens on every local IPv4 address at the given TCP port and serves
 * clients from a single epoll loop. Returns 1 on a usage error, -1 when the
 * listening socket or epoll instance cannot be set up, 0 otherwise.
 */
int main(int argc, char** argv)
{
	if (argc != 2)
	{
		printf("Usage: ./server_chat    port\n");
		return 1;
	}

	/* Self-test of the SHA-1 code; sha1_hash returns an allocated string. */
	char* demo = sha1_hash("abc");
	printf("abc's sha1 is %s\n", demo ? demo : "(null)");
	free(demo);

	char* end = NULL;
	long port = strtol(argv[1], &end, 10);
	if (end == argv[1] || *end != '\0' || port <= 0 || port > 65535)
	{
		printf("Usage: ./server_chat    port\n");
		return 1;
	}

	struct sockaddr_in  address;
	memset(&address, 0, sizeof(address));
	address.sin_family  =  AF_INET;
	address.sin_addr.s_addr =  INADDR_ANY;
	address.sin_port = htons((unsigned short)port);

	/* Fall back to protocol 0 (the default for SOCK_STREAM) if the lookup fails. */
	struct protoent*  ppe = getprotobyname("tcp");
	int listenfd = socket(PF_INET, SOCK_STREAM, ppe ? ppe->p_proto : 0);
	if (listenfd == -1)
	{
		printf("socket error");
		return -1;
	}

	if (bind(listenfd, (struct sockaddr* )&address, sizeof(address)) == -1)
	{
		printf("bind error");
		close(listenfd);
		return -1;
	}

	if (listen(listenfd, 5) == -1)
	{
		printf("listen error");
		close(listenfd);
		return -1;
	}

	struct epoll_event events[MAX_EVENT_NUMBER];
	int  epollfd = epoll_create(5);
	if (epollfd == -1)
	{
		printf("epoll_create error");
		close(listenfd);
		return -1;
	}
	addfd(epollfd, listenfd, true);

	while (1)
	{
		int number = epoll_wait(epollfd, events, MAX_EVENT_NUMBER, -1);
		if (number < 0)
		{
			if (errno == EINTR)
				continue;
			printf("epoll_wait failure\n");
			break;
		}

		for (int i = 0; i < number; i++)
		{
			int sockfd = events[i].data.fd;
			if (sockfd == listenfd)
			{
				accept_pending_clients(epollfd, listenfd);
			}
			else if (events[i].events & EPOLLIN)
			{
				serve_client(sockfd);
			}
			else
			{
				printf("something else happened.\n");
				/* Only errors and hang-ups arrive without EPOLLIN; drop the socket. */
				if (events[i].events & (EPOLLERR | EPOLLHUP))
					close(sockfd);
			}
		}
	}

	close(epollfd);
	close(listenfd);
	return 0;
}

// A user name together with the socket registered for it by
// register_master_fd(); find_fd() falls back to this socket when no peer
// socket is known yet.
typedef struct{
	string user_name;
	int    master_fd;
	
}master_fd_s;
// All registered users. The element type had been lost from the listing;
// a vector needs it to compile.
vector<master_fd_s>  vec_master;

// One direction of a conversation between two users, as tracked by find_fd():
// the two names plus the socket recorded for each side (0 = not known yet).
typedef struct{
	string user_name;
	string peer_name;
	int    user_fd;
	int    peer_fd;
	//state
}user_fd_s;
// All known conversations. The element type had been lost from the listing;
// a vector needs it to compile.
vector<user_fd_s> vec_user;

/*
typedef struct{
	int    user_fd;
	int    peer_fd;
}fd_s;
vector vec_fd;
*/

void register_master_fd(char* name, int sockfd)
{
	master_fd_s t;
	t.user_name = name;
	t.master_fd = sockfd;
	vec_master.push_back(t);
	for(int i=0; i < vec_master.size();i++)
		cout << vec_master[i].user_name << vec_master[i].master_fd << endl;
}
// Extract the sender and receiver names from a chat message.
//
// The sender is everything before the first space of `buf`; the receiver is
// the text between the first "to " that follows the sender and the next ':'.
//
// sender_name / recver_name: output buffers, always NUL-terminated on return.
//     Each must be large enough for the extracted name plus the terminator
//     (at most strlen(buf) + 1 bytes).
// buf: NUL-terminated message text.
//
// A message without a space yields the whole text as sender; a message
// without "to " after the sender yields an empty receiver. Scanning never
// runs past the end of `buf`.
void get_names(char* sender_name, char* recver_name, char* buf)
{
	int i = 0;
	int k = 0;

	while(buf[i] != '\0' && buf[i] != ' ')
	{
		sender_name[i] = buf[i];
		i++;
	}
	sender_name[i] = '\0';

	// Search after the sender so a name containing "to " cannot be mistaken
	// for the keyword.
	char* p = strstr(buf + i, "to ");
	if(p != NULL)
	{
		p += 3;
		while(*p != '\0' && *p != ':')
			recver_name[k++] = *p++;
	}
	recver_name[k] = '\0';

	printf("sender:%s, recver:%s\n", sender_name, recver_name); 
}

// Return the master socket registered for `name` in vec_master, or -1 if the
// name has not been registered.
static int lookup_master_fd(const char* name)
{
	for(size_t k = 0; k < vec_master.size(); k++)
	{
		if(vec_master[k].user_name == name)
			return vec_master[k].master_fd;
	}
	return -1;
}

// Decide which socket a message from `sender_name` to `recver_name`, received
// on `sockfd`, should be forwarded to.
//
// - First message of a conversation: both directions are recorded in
//   vec_user (with `sockfd` as the sender's side) and the receiver's master
//   socket is returned.
// - Conversation known and the peer's socket recorded: the missing socket
//   fields are filled in with `sockfd` and the peer's socket is returned.
// - Conversation known but the peer's socket not recorded yet: the receiver's
//   master socket is returned.
//
// Returns -1 when no usable socket exists (receiver never registered, or the
// stored peer socket is negative).
int find_fd(char* sender_name, char* recver_name, int sockfd)
{
	size_t i, j;
	for(i = 0 ;  i< vec_user.size(); i++)
	{
		if(vec_user[i].user_name == sender_name
		&& vec_user[i].peer_name == recver_name)
		    break;
	}

	if(i == vec_user.size())
	{
		user_fd_s    t,t2;
		t.user_name = sender_name;
		t.peer_name = recver_name;
		t.user_fd = sockfd;
		t.peer_fd = 0;
		vec_user.push_back(t);
		t2.user_name = recver_name;
		t2.peer_name = sender_name;
		t2.peer_fd = sockfd;
		t2.user_fd = 0;
		vec_user.push_back(t2);

		return lookup_master_fd(recver_name);
	}

	if(vec_user[i].peer_fd > 0)
	{
		if(vec_user[i].user_fd == 0)
			vec_user[i].user_fd = sockfd;

		// Also complete the entry for the opposite direction, if present.
		for(j = 0; j < vec_user.size(); j++)
		{
			if(vec_user[j].user_name == recver_name
			&& vec_user[j].peer_name == sender_name)
				break;
		}
		if(j < vec_user.size() && vec_user[j].peer_fd == 0)
			vec_user[j].peer_fd = sockfd;

		return vec_user[i].peer_fd;
	}

	if(vec_user[i].peer_fd == 0)
		return lookup_master_fd(recver_name);

	return -1;
}

/*
 * Try to consume one WebSocket frame from buffer[*p_begin .. in_amount).
 *
 * The payload is unmasked in place and an unmasked frame header is written
 * directly in front of it, so header + payload can be sent on as one
 * contiguous frame.
 *
 * Opcodes 0x0, 0x1, 0x2 and 0x9:
 *   - a payload starting with "I am " registers the rest of the payload as
 *     the user name for sockfd (register_master_fd) and returns 0;
 *   - otherwise sender and receiver are parsed from the payload (get_names)
 *     and the frame is written to the descriptor chosen by find_fd. For
 *     frames with a 7-bit length, opcode 0x9 is rewritten to 0xA first.
 * Opcode 0x8: a two-byte close frame is sent back on sockfd.
 *
 * Returns
 *   -1  not enough bytes for a complete frame (this includes announced
 *       lengths larger than what is buffered); nothing is consumed;
 *    0  the frame ended exactly at in_amount, or a login was handled;
 *       *p_begin is left unchanged;
 *    1  bytes remain after the frame; *p_begin has been advanced past it.
 */
int try_analyze_buffer(unsigned char* buffer, int* p_begin,  int in_amount, int sockfd)
{
	printf("looking at %p,begin:%d,amount:%d; \n", (void*)buffer, *p_begin, in_amount);

	int amount = in_amount - *p_begin;
	if (amount < 2)
		return -1;

	unsigned char* buf = buffer + *p_begin;
	uint8_t FIN = buf[0] & 0x80;
	uint8_t OPCODE = buf[0] & 0x0F;
	uint8_t MASK = buf[1] & 0x80;
	uint32_t payload_len = buf[1] & 0x7F;   /* 7-bit length code: 0..125, 126 or 127 */
	printf("OPCODE:%d, MASK:%d, payload_len:%d\n", OPCODE, MASK, (int)payload_len);

	/* Number of extended-length bytes that follow the two fixed header bytes. */
	int ext_bytes = 0;
	if (payload_len == 126)
		ext_bytes = 2;
	else if (payload_len == 127)
		ext_bytes = 8;
	if (amount < 2 + ext_bytes)
		return -1;

	/* Real payload length, big-endian on the wire. Kept unsigned so a hostile
	   64-bit value cannot turn negative and slip past the size check below. */
	unsigned long long pay_len = payload_len;
	if (ext_bytes > 0)
	{
		pay_len = 0;
		for (int k = 0; k < ext_bytes; k++)
			pay_len = (pay_len << 8) | buf[2 + k];
	}

	int payload_begin = 2 + ext_bytes + (MASK ? 4 : 0);
	if (amount < payload_begin
	 || pay_len > (unsigned long long)(amount - payload_begin))
		return -1;

	/* From here on the whole frame is known to lie inside the buffer. */
	int body_len = (int)pay_len;
	long long len = (long long)payload_begin + body_len;

	/* An all-zero key makes the XOR below a no-op for unmasked frames. */
	uint8_t MASKS[4] = {0, 0, 0, 0};
	if (MASK)
	{
		for (int k = 0; k < 4; k++)
			MASKS[k] = buf[2 + ext_bytes + k];
	}

	for (int i = 0; i < body_len; i++)
	{
		buf[payload_begin + i] ^= MASKS[i % 4];

		/* Debug dump: every byte for short text/binary frames and for 16-bit
		   text frames, only the last bytes for 64-bit text frames. */
		int dump;
		if (ext_bytes == 0)
			dump = (OPCODE == 0x1 || OPCODE == 0x2);
		else if (ext_bytes == 2)
			dump = (OPCODE == 0x1);
		else
			dump = (OPCODE == 0x1 && body_len >= 16 && i > body_len - 16);
		if (dump)
			printf("%d: %c, 0x%x\n", i, (char)buf[payload_begin+i], buf[payload_begin+i]);
	}

	/* Build the outgoing, unmasked header right in front of the payload; the
	   bytes it overwrites (mask key / old header) are no longer needed. */
	unsigned char* buf_out = buf + payload_begin - (2 + ext_bytes);
	buf_out[0] = FIN | ((ext_bytes == 0 && OPCODE == 0x9) ? 0xA : OPCODE);
	buf_out[1] = (unsigned char)payload_len;
	for (int k = 0; k < ext_bytes; k++)
		buf_out[2 + k] = (unsigned char)(pay_len >> (8 * (ext_bytes - 1 - k)));

	if (OPCODE == 0x1 || OPCODE == 0x0 || OPCODE == 0x2 || OPCODE == 0x9)
	{
		int amount_out = 2 + ext_bytes + body_len;
		int amount_sent = 0;
		char* text = (char*)buf + payload_begin;

		/* Only look at the login prefix when the payload really has it. */
		if (body_len >= 5 && !strncmp(text, "I am ", 5))
		{
			char name[126];
			memset(name, 0, sizeof(name));
			size_t name_len = (size_t)body_len - 5;
			if (name_len > sizeof(name) - 1)
				name_len = sizeof(name) - 1;   /* keep the terminating NUL */
			strncpy(name, text + 5, name_len);
			printf("%s login\n", name);
			register_master_fd(name, sockfd);
			/* NOTE(review): returns without advancing *p_begin, so any frame
			   buffered after the login frame is not reported to the caller. */
			return 0;
		}

		char sender_name[126];
		char recver_name[126];
		memset(sender_name, 0, sizeof(sender_name));
		memset(recver_name, 0, sizeof(recver_name));

		/* get_names() takes no length argument and the payload is not
		   NUL-terminated, so hand it a terminated copy of the payload's start
		   (large enough for two full names plus separators). */
		char names_part[2 * 126 + 8];
		size_t part_len = (size_t)body_len;
		if (part_len > sizeof(names_part) - 1)
			part_len = sizeof(names_part) - 1;
		memcpy(names_part, text, part_len);
		names_part[part_len] = '\0';
		get_names(sender_name, recver_name, names_part);

		int fd = find_fd(sender_name, recver_name, sockfd);
		printf("find_fd return %d\n", fd);

		/* Never hand a negative descriptor to send(): it can only fail, and
		   the retry loop would then spin forever.
		   NOTE(review): other send() failures are still retried without
		   limit, as before. */
		while (fd >= 0 && amount_sent < amount_out)
		{
			printf("sent:%d, out:%d\n", amount_sent, amount_out);
			int ret2 = send(fd, buf_out+amount_sent, amount_out-amount_sent, 0);
			if (ret2 > 0)
				amount_sent += ret2;
		}
		printf("send %d/%d bytes.payload_begin:%d, payload_len:%d,pay_len:%lld.\n", amount_sent,amount_out, payload_begin, (int)payload_len, (long long)pay_len);
		fflush(stdout);
	}
	else if (OPCODE == 0x8)
	{
		/* Answer a close frame with an empty close frame. */
		unsigned char tmpbuf[2];
		tmpbuf[0] = 0x88;
		tmpbuf[1] = 0x0;
		int ret2 = send(sockfd, tmpbuf, 2, 0);
		printf("send %d bytes.\n", ret2);
	}

	printf("amount: %d, len: %lld\n", amount, len);
	if (amount == len)
	{
		printf("a complete frame.\n");
	}
	else if (len > amount)
	{
		printf("incorrect amount and len?\n");
	}
	else if (len < amount)
	{
		printf("more frame?\n");
		*p_begin  += len;
		return 1;
	}
	return 0;
}


 
编译运行:g++ -o test websocket_chat.cpp ; ./test 9001
源文件连同前端.php文件打包在此:websocket_chat01.zip

  

More powered by