Last active
May 27, 2022 14:00
-
-
Save ammarfaizi2/37f22e88698fe60083889e1d5e8308a1 to your computer and use it in GitHub Desktop.
Socket Lost Control
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/* Expose GNU/Linux extensions from the system headers; only effective before the first #include. */
#if !defined(_GNU_SOURCE)
#define _GNU_SOURCE
#endif

/*
 * Branch-prediction hints for GCC/Clang. The `!!` collapses any truthy
 * value to exactly 1 so it can be compared with the expected constant.
 */
#if !defined(likely)
#define likely(EXPR) __builtin_expect(!!(EXPR), 1)
#endif
#if !defined(unlikely)
#define unlikely(EXPR) __builtin_expect(!!(EXPR), 0)
#endif
#include <poll.h>
#include <pthread.h>
#include <arpa/inet.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <signal.h>
#include <unistd.h>
#include <atomic>
#include <cstdlib>
#include <cerrno>
#include <cstdio>
#include <cstdint>
#include <cstring>
#include <mutex>
#include <new>
#include <unordered_map>
/* Values carried in sss_pkt::type. */
enum {
	/* Sent by the server: announces the address of a connection it accepted. */
	SSS_PUB_CLIENT_ADDR = 1
};

/*
 * Protocol message. It is sent and received as raw bytes
 * (sizeof(struct sss_pkt) at a time), so the field order and field types
 * form the wire format and must not change.
 */
struct sss_pkt {
	uint8_t type;                   /* one of the SSS_* values above */
	struct sockaddr_in client_addr; /* address the message refers to */
};
/*
 * Pairing slot for one connection accepted by handle_client_in_server().
 *
 * fd:           the accepted connection.
 * fd_from_lost: socket of the matching lost-control connection, or -1 while
 *               unpaired. It is stored by the accept thread (under map_lock)
 *               and polled without the lock by the relay thread, so it is
 *               atomic to make that hand-off race-free. Plain int
 *               assignments/comparisons keep working on std::atomic<int>.
 */
struct client_slot {
	int fd;
	std::atomic<int> fd_from_lost;
};
struct server_data { | |
int fd1; | |
int fd2; | |
std::mutex map_lock; | |
std::unordered_map<uint64_t, struct client_slot *> map; | |
std::atomic<bool> got_lost_client; | |
}; | |
struct client_addr_info { | |
const char *target_addr; | |
const char *server_addr; | |
uint16_t target_port; | |
uint16_t server_port; | |
}; | |
struct client_data { | |
const struct client_addr_info *addr; | |
size_t pkt_len; | |
struct sss_pkt pkt; | |
}; | |
static std::atomic<bool> g_stop_server; | |
/*
 * Create a blocking IPv4 TCP socket.
 *
 * Returns the new descriptor, or the negated errno of socket() on failure
 * (after printing a diagnostic).
 */
static int create_tcp_sock(void)
{
	int sock = socket(AF_INET, SOCK_STREAM, 0);
	int saved_errno;

	if (sock >= 0)
		return sock;

	saved_errno = errno;
	perror("socket");
	return -saved_errno;
}
/*
 * Bind @fd to @baddr:@bport (dotted-quad IPv4 address, host-order port) and
 * start listening with a backlog of 10.
 *
 * Returns 0 on success, -EINVAL if @baddr is not a valid IPv4 address, or the
 * negated errno of the failing bind()/listen() call.
 */
static int bind_and_listen_tcp_sock(int fd, const char *baddr, uint16_t bport)
{
	struct sockaddr_in addr;
	int optval = 1;
	int err;

	memset(&addr, 0, sizeof(addr));
	addr.sin_family = AF_INET;
	addr.sin_port = htons(bport);
	/*
	 * inet_pton() reports malformed input; inet_addr() would silently turn
	 * it into INADDR_NONE (255.255.255.255) and bind() would then fail
	 * with a misleading error.
	 */
	if (inet_pton(AF_INET, baddr, &addr.sin_addr) != 1) {
		fprintf(stderr, "bind: invalid IPv4 address: %s\n", baddr);
		return -EINVAL;
	}

	/*
	 * Best effort: lets the server restart immediately while connections
	 * from a previous run are still in TIME_WAIT.
	 */
	if (setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &optval, sizeof(optval)) < 0)
		perror("setsockopt(SO_REUSEADDR)");

	err = bind(fd, (struct sockaddr *)&addr, sizeof(addr));
	if (err < 0) {
		err = errno;
		perror("bind");
		return -err;
	}

	err = listen(fd, 10);
	if (err < 0) {
		err = errno;
		perror("listen");
		return -err;
	}
	return 0;
}
/*
 * Forward one chunk of data from @fd_in to @fd_out through @buffer.
 *
 * Reads at most @len bytes from @fd_in without blocking, then writes all of
 * them to @fd_out, retrying short writes so no bytes are dropped.
 *
 * Returns 0 on success, including when nothing was readable (EAGAIN/EINTR).
 * Returns a negative errno value when the relay should be torn down. An
 * orderly shutdown of @fd_in (recv() returning 0) is reported as
 * -ECONNRESET. Otherwise a closed socket stays readable forever and a poll()
 * loop around this function would spin.
 */
static int recv_and_send(int fd_in, int fd_out, char *buffer, size_t len)
{
	ssize_t recv_ret;
	size_t done = 0;

	recv_ret = recv(fd_in, buffer, len, MSG_DONTWAIT);
	if (recv_ret < 0) {
		int err = errno;

		/* Spurious readiness or a signal: not an error, try again later. */
		if (err == EAGAIN || err == EWOULDBLOCK || err == EINTR)
			return 0;
		perror("recv");
		return -err;
	}
	if (recv_ret == 0)
		return -ECONNRESET;

	while (done < (size_t)recv_ret) {
		/*
		 * Blocking send: a full socket buffer delays us instead of losing
		 * data. MSG_NOSIGNAL turns a vanished peer into EPIPE instead of a
		 * process-killing SIGPIPE.
		 */
		ssize_t send_ret = send(fd_out, buffer + done,
					(size_t)recv_ret - done, MSG_NOSIGNAL);
		if (send_ret < 0) {
			int err = errno;

			if (err == EINTR)
				continue;
			perror("send");
			return -err;
		}
		done += (size_t)send_ret;
	}
	return 0;
}
/*
 * Build the server_data::map key for an IPv4 endpoint.
 *
 * Both fields are used exactly as stored in the sockaddr_in (network byte
 * order). The 32-bit address occupies bits 16..47 and the 16-bit port bits
 * 0..15, so distinct (address, port) pairs always give distinct keys.
 * @addr is only read, hence const.
 */
static inline uint64_t gen_map_key(const struct sockaddr_in *addr)
{
	const uint64_t ip = (uint64_t)addr->sin_addr.s_addr;
	const uint64_t port = (uint64_t)addr->sin_port;

	return (ip << 16) | port;
}
/*
 * Connect @fd to @addr:@port (dotted-quad IPv4 address, host-order port),
 * printing progress to stdout.
 *
 * Returns 0 on success, -EINVAL if @addr is not a valid IPv4 address, or the
 * negated errno of connect().
 */
static int connect_tcp_sock(int fd, const char *addr, uint16_t port)
{
	struct sockaddr_in dst_addr;
	int err;

	memset(&dst_addr, 0, sizeof(dst_addr));
	dst_addr.sin_family = AF_INET;
	dst_addr.sin_port = htons(port);
	/*
	 * inet_addr() cannot tell a malformed string from 255.255.255.255;
	 * inet_pton() can.
	 */
	if (inet_pton(AF_INET, addr, &dst_addr.sin_addr) != 1) {
		fprintf(stderr, "connect: invalid IPv4 address: %s\n", addr);
		return -EINVAL;
	}

	printf("Connecting to %s:%u...\n", addr, (unsigned int)port);
	err = connect(fd, (struct sockaddr *)&dst_addr, sizeof(dst_addr));
	if (err < 0) {
		err = errno;
		perror("connect");
		return -err;
	}
	printf("Connected!\n");
	return 0;
}
static int handle_lost_control_client(struct server_data *data) | |
{ | |
struct sockaddr_in addr; | |
int fd = data->fd1; | |
struct sss_pkt pkt; | |
socklen_t addrlen; | |
ssize_t recv_ret; | |
int client_fd; | |
int ret; | |
do_accept: | |
if (atomic_load(&g_stop_server)) | |
return 0; | |
addrlen = sizeof(addr); | |
client_fd = accept(fd, (struct sockaddr *)&addr, &addrlen); | |
if (unlikely(client_fd < 0)) { | |
ret = errno; | |
perror("accept"); | |
goto out; | |
} | |
recv_ret = recv(client_fd, &pkt, sizeof(pkt), MSG_WAITALL); | |
if (unlikely(recv_ret)) { | |
ret = errno; | |
perror("recv"); | |
goto out; | |
} | |
if (recv_ret != sizeof(pkt)) { | |
close(client_fd); | |
goto do_accept; | |
} | |
{ | |
data->map_lock.lock(); | |
auto it = data->map.find(gen_map_key(&addr)); | |
if (it != data->map.end()) { | |
struct client_slot *slot = it->second; | |
slot->fd_from_lost = client_fd; | |
} | |
data->map_lock.unlock(); | |
} | |
goto do_accept; | |
out: | |
atomic_store(&g_stop_server, true); | |
return ret; | |
} | |
struct handle_client_in_server_data { | |
struct client_slot slot; | |
struct server_data *data; | |
}; | |
static void *__handle_client_in_server(void *data_p) | |
{ | |
struct handle_client_in_server_data *hc_data; | |
int fd1 = -1, fd2 = -1; | |
struct pollfd fds[2]; | |
char buffer[4096]; | |
ssize_t opret; | |
int ret; | |
hc_data = (struct handle_client_in_server_data *)data_p; | |
while (hc_data->slot.fd_from_lost == -1) { | |
__asm__ volatile ("":"+r"(hc_data)::); | |
sleep(1); | |
if (atomic_load(&g_stop_server)) | |
return 0; | |
} | |
fd1 = hc_data->slot.fd; | |
fd2 = hc_data->slot.fd_from_lost; | |
fds[0].fd = fd1; | |
fds[0].events = POLLIN | POLLPRI; | |
fds[1].fd = fd2; | |
fds[1].events = POLLIN | POLLPRI; | |
do_poll: | |
ret = poll(fds, 2, -1); | |
if (unlikely(ret)) { | |
perror("poll"); | |
goto out; | |
} | |
if (fds[0].revents & POLLIN) { | |
opret = recv_and_send(fd1, fd2, buffer, sizeof(buffer)); | |
if (unlikely(opret < 0)) | |
goto out; | |
} | |
if (fds[1].revents & POLLIN) { | |
opret = recv_and_send(fd2, fd1, buffer, sizeof(buffer)); | |
if (unlikely(opret < 0)) | |
goto out; | |
} | |
goto do_poll; | |
out: | |
if (fd1 != -1) | |
close(fd1); | |
if (fd2 != -1) | |
close(fd2); | |
delete hc_data; | |
return NULL; | |
} | |
static void _handle_client_in_server(int fd, struct server_data *data, | |
struct sockaddr_in *addr) | |
{ | |
struct handle_client_in_server_data *hc_data; | |
struct sss_pkt pkt; | |
pthread_t thread; | |
ssize_t send_ret; | |
int err; | |
hc_data = new struct handle_client_in_server_data; | |
if (unlikely(!hc_data)) | |
return; | |
hc_data->slot.fd = fd; | |
hc_data->slot.fd_from_lost = -1; | |
data->map_lock.lock(); | |
data->map.emplace(gen_map_key(addr), &hc_data->slot); | |
data->map_lock.unlock(); | |
pkt.type = SSS_PUB_CLIENT_ADDR; | |
pkt.client_addr = *addr; | |
send_ret = send(fd, &pkt, sizeof(pkt), MSG_WAITALL); | |
if (unlikely(send_ret < 0)) { | |
perror("send"); | |
delete hc_data; | |
close(client_fd); | |
return; | |
} | |
err = pthread_create(&thread, NULL, __handle_client_in_server, hc_data); | |
if (unlikely(err)) { | |
errno = err; | |
perror("pthread_create"); | |
delete hc_data; | |
close(client_fd); | |
return; | |
} | |
pthread_detach(thread); | |
} | |
static void *handle_client_in_server(void *data_p) | |
{ | |
struct server_data *data = (struct server_data *)data_p; | |
struct sockaddr_in addr; | |
int fd = data->fd1; | |
socklen_t addrlen; | |
int client_fd; | |
while (!atomic_load(&data->got_lost_client)) { | |
sleep(1); | |
if (atomic_load(&g_stop_server)) | |
return 0; | |
} | |
do_accept: | |
if (atomic_load(&g_stop_server)) | |
return 0; | |
addrlen = sizeof(addr); | |
client_fd = accept(fd, (struct sockaddr *)&addr, &addrlen); | |
if (unlikely(client_fd < 0)) { | |
atomic_store(&g_stop_server, true); | |
goto out; | |
} | |
_handle_client_in_server(client_fd, data, &addr); | |
goto do_accept; | |
out: | |
return NULL; | |
} | |
static int run_server(const char *listen1_addr, uint16_t listen1_port, | |
const char *listen2_addr, uint16_t listen2_port) | |
{ | |
struct server_data data; | |
pthread_t thread; | |
int err = 0; | |
data.fd1 = -1; | |
data.fd2 = -1; | |
atomic_store(&data.got_lost_client, false); | |
data.fd1 = create_tcp_sock(); | |
if (unlikely(data.fd1 < 0)) { | |
err = data.fd1; | |
goto out; | |
} | |
data.fd2 = create_tcp_sock(); | |
if (unlikely(data.fd2 < 0)) { | |
err = data.fd2; | |
goto out; | |
} | |
err = bind_and_listen_tcp_sock(data.fd1, listen1_addr, listen1_port); | |
if (unlikely(err < 0)) | |
goto out; | |
err = bind_and_listen_tcp_sock(data.fd2, listen2_addr, listen2_port); | |
if (unlikely(err < 0)) | |
goto out; | |
atomic_store(&g_stop_server, false); | |
err = pthread_create(&thread, NULL, handle_client_in_server, &data); | |
if (unlikely(err)) { | |
errno = err; | |
perror("pthread_create"); | |
goto out; | |
} | |
pthread_detach(thread); | |
err = handle_lost_control_client(&data); | |
out: | |
if (data.fd1 != -1) | |
close(data.fd1); | |
if (data.fd2 != -1) | |
close(data.fd2); | |
return (err < 0) ? -err : err; | |
} | |
static int init_handshake_from_client(int tcp_fd, struct client_data *data) | |
{ | |
ssize_t send_ret; | |
int err; | |
send_ret = send(tcp_fd, &data->pkt, data->pkt_len, MSG_WAITALL); | |
if (unlikely(send_ret < 0)) { | |
err = errno; | |
perror("send"); | |
return -err; | |
} | |
return 0; | |
} | |
static void *handle_client_in_client(void *data_p) | |
{ | |
struct client_data *data = (struct client_data *)data_p; | |
const struct client_addr_info *addr = data->addr; | |
int fd1 = -1, fd2 = -1; | |
struct pollfd fds[2]; | |
char buffer[4096]; | |
ssize_t opret; | |
int ret; | |
fd1 = create_tcp_sock(); | |
if (unlikely(fd1 < 0)) | |
goto out; | |
fd2 = create_tcp_sock(); | |
if (unlikely(fd2 < 0)) | |
goto out; | |
ret = connect_tcp_sock(fd1, addr->server_addr, addr->server_port); | |
if (unlikely(ret)) | |
goto out; | |
ret = connect_tcp_sock(fd2, addr->target_addr, addr->target_port); | |
if (unlikely(ret)) | |
goto out; | |
ret = init_handshake_from_client(fd1, data); | |
if (unlikely(ret)) | |
goto out; | |
fds[0].fd = fd1; | |
fds[0].events = POLLIN | POLLPRI; | |
fds[1].fd = fd2; | |
fds[1].events = POLLIN | POLLPRI; | |
do_poll: | |
ret = poll(fds, 2, -1); | |
if (unlikely(ret)) { | |
perror("poll"); | |
goto out; | |
} | |
if (fds[0].revents & POLLIN) { | |
opret = recv_and_send(fd1, fd2, buffer, sizeof(buffer)); | |
if (unlikely(opret < 0)) | |
goto out; | |
} | |
if (fds[1].revents & POLLIN) { | |
opret = recv_and_send(fd2, fd1, buffer, sizeof(buffer)); | |
if (unlikely(opret < 0)) | |
goto out; | |
} | |
goto do_poll; | |
out: | |
if (fd1 != -1) | |
close(fd1); | |
if (fd2 != -1) | |
close(fd2); | |
free(data); | |
return NULL; | |
} | |
static int run_client(const char *target_addr, uint16_t target_port, | |
const char *server_addr, uint16_t server_port) | |
{ | |
static const uint32_t max_fail_count = 10; | |
struct client_data *data = NULL; | |
struct client_addr_info addr; | |
int ret = 0, main_fd, err; | |
uint32_t fail_count = 0; | |
struct pollfd fds[1]; | |
ssize_t recv_ret; | |
pthread_t thread; | |
main_fd = create_tcp_sock(); | |
if (unlikely(main_fd < 0)) | |
return -main_fd; | |
err = connect_tcp_sock(main_fd, server_addr, server_port); | |
if (unlikely(err)) { | |
close(main_fd); | |
return -err; | |
} | |
addr.server_addr = server_addr; | |
addr.server_port = server_port; | |
addr.target_addr = target_addr; | |
addr.target_port = target_port; | |
fds[0].fd = main_fd; | |
fds[0].events = POLLIN | POLLPRI; | |
do_alloc: | |
data = (struct client_data *)malloc(sizeof(*data)); | |
if (unlikely(!data)) { | |
perror("malloc"); | |
ret = ENOMEM; | |
goto out; | |
} | |
do_poll: | |
if (unlikely(fail_count >= max_fail_count)) { | |
free(data); | |
ret = 1; | |
goto out; | |
} | |
err = poll(fds, 1, -1); | |
if (unlikely(err < 0)) { | |
perror("poll"); | |
fail_count++; | |
goto do_poll; | |
} | |
recv_ret = recv(main_fd, &data->pkt, sizeof(data->pkt), MSG_WAITALL); | |
if (unlikely(recv_ret <= 0)) { | |
__asm__ volatile ("":"+r"(recv_ret)::); | |
if (!recv_ret) { | |
puts("Disconnected from the server"); | |
free(data); | |
ret = 0; | |
goto out; | |
} | |
if (errno == EAGAIN) | |
goto do_poll; | |
perror("recv"); | |
fail_count++; | |
goto do_poll; | |
} | |
err = pthread_create(&thread, NULL, handle_client_in_client, data); | |
if (unlikely(err)) { | |
errno = err; | |
perror("pthread_create"); | |
fail_count++; | |
goto do_poll; | |
} | |
pthread_detach(thread); | |
goto do_alloc; | |
out: | |
close(main_fd); | |
return ret; | |
} | |
/* | |
* Usage: | |
* ./slc client 127.0.0.1 5555 123.123.123.123 9999 | |
* ./slc server 123.123.123.123 9999 0.0.0.0 9998 | |
*/ | |
int main(int argc, const char *argv[]) | |
{ | |
if (argc != 6) | |
goto print_usage; | |
if (!strcmp(argv[1], "client")) | |
return run_client(argv[2], (uint16_t)atoi(argv[3]), | |
argv[4], (uint16_t)atoi(argv[5])); | |
if (!strcmp(argv[1], "server")) | |
return run_server(argv[2], (uint16_t)atoi(argv[3]), | |
argv[4], (uint16_t)atoi(argv[5])); | |
print_usage: | |
// TODO: print the program usage example | |
return 1; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment