/*
 * Socket Lost Control
 *
 * GitHub gist ammarfaizi2/37f22e88698fe60083889e1d5e8308a1 by @ammarfaizi2
 * (last active May 27, 2022 14:00).
 */
/*
 * Request the GNU extensions of the C library. This is placed ahead of every
 * #include because feature-test macros only take effect when they are defined
 * before the first system header is read.
 */
#ifndef _GNU_SOURCE
#define _GNU_SOURCE
#endif
/*
 * Branch-prediction hints built on the compiler's __builtin_expect().
 * `!!(EXPR)` collapses any non-zero value to exactly 1 so that the result can
 * be compared with the expected constant. likely() marks the condition as
 * expected to be true, unlikely() as expected to be false.
 */
#ifndef likely
#define likely(EXPR) __builtin_expect(!!(EXPR), 1)
#endif
#ifndef unlikely
#define unlikely(EXPR) __builtin_expect(!!(EXPR), 0)
#endif
#include <arpa/inet.h>
#include <poll.h>
#include <pthread.h>
#include <signal.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <unistd.h>

#include <atomic>
#include <cerrno>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <mutex>
#include <new>
#include <unordered_map>
/* Values stored in sss_pkt::type. */
enum {
/* The packet carries a client address in sss_pkt::client_addr. */
SSS_PUB_CLIENT_ADDR = 1
};
/*
 * Message exchanged between the two programs.
 *
 * The struct is written to and read from sockets as raw bytes with
 * sizeof(struct sss_pkt), so its in-memory layout (including any padding the
 * compiler inserts after `type`) is what travels on the wire.
 */
struct sss_pkt {
/* One of the SSS_* packet type values. */
uint8_t type;
/* Address of the client the packet refers to. */
struct sockaddr_in client_addr;
};
/*
 * Pair of descriptors that a relay thread forwards data between.
 * Pointers to these slots are stored in server_data::map.
 */
struct client_slot {
/* Descriptor returned by accept() for the client being served. */
int fd;
/*
 * Descriptor of the matching connection accepted by
 * handle_lost_control_client(); -1 while no such connection has arrived.
 */
int fd_from_lost;
};
/* State shared by the server's accept loops and relay threads. */
struct server_data {
/* First listening socket (bound to the listen1 address in run_server()). */
int fd1;
/* Second listening socket (bound to the listen2 address in run_server()). */
int fd2;
/* Protects `map` and the slots it points to. */
std::mutex map_lock;
/* Maps gen_map_key(client address) to the slot waiting for its peer. */
std::unordered_map<uint64_t, struct client_slot *> map;
/*
 * handle_client_in_server() waits for this flag to become true before it
 * starts accepting; run_server() initialises it to false.
 * NOTE(review): no code visible in this file ever stores true here - confirm
 * where it is supposed to be set.
 */
std::atomic<bool> got_lost_client;
};
/*
 * Addresses used in client mode. The strings are dotted IPv4 text passed to
 * inet_addr() and the ports are in host byte order (connect_tcp_sock()
 * applies htons()).
 */
struct client_addr_info {
/* Address of the service that traffic is forwarded to. */
const char *target_addr;
/* Address of the relay server. */
const char *server_addr;
/* Port of the service that traffic is forwarded to. */
uint16_t target_port;
/* Port of the relay server. */
uint16_t server_port;
};
/*
 * Argument handed to a handle_client_in_client() thread. The block is
 * allocated with malloc() in run_client() and released with free() by the
 * thread.
 */
struct client_data {
/* Addresses to connect to; read by handle_client_in_client(). */
const struct client_addr_info *addr;
/* Number of bytes of `pkt` that init_handshake_from_client() sends. */
size_t pkt_len;
/* Packet received from the server in run_client(), echoed as handshake. */
struct sss_pkt pkt;
};
/*
 * Server-wide stop request. It is set to true when one of the accept loops
 * fails and is polled by the accept loops and the relay threads, which return
 * once they observe it.
 */
static std::atomic<bool> g_stop_server;
/*
 * Open a new IPv4 stream (TCP) socket.
 *
 * Returns the descriptor on success. On failure the error is printed with
 * perror("socket") and the negated errno value is returned.
 */
static int create_tcp_sock(void)
{
	int sock = socket(AF_INET, SOCK_STREAM, 0);

	if (sock >= 0)
		return sock;

	/* Capture errno before perror(), which is allowed to modify it. */
	int saved_errno = errno;

	perror("socket");
	return -saved_errno;
}
/*
 * Bind `fd` to the IPv4 address `baddr` (dotted text, converted with
 * inet_addr(); the conversion result is not validated) and port `bport`
 * (host byte order), then start listening with a backlog of 10.
 *
 * Returns 0 on success. On failure the failing call is reported with
 * perror() and the negated errno value is returned.
 */
static int bind_and_listen_tcp_sock(int fd, const char *baddr, uint16_t bport)
{
	struct sockaddr_in local;

	memset(&local, 0, sizeof(local));
	local.sin_family = AF_INET;
	local.sin_addr.s_addr = inet_addr(baddr);
	local.sin_port = htons(bport);

	if (bind(fd, (struct sockaddr *)&local, sizeof(local)) < 0) {
		int saved_errno = errno;

		perror("bind");
		return -saved_errno;
	}

	if (listen(fd, 10) < 0) {
		int saved_errno = errno;

		perror("listen");
		return -saved_errno;
	}

	return 0;
}
/*
 * Forward one chunk of data from `fd_in` to `fd_out`.
 *
 * At most `len` bytes are read into `buffer` without blocking and then
 * written in full to `fd_out`.
 *
 * Returns 0 when the chunk was forwarded (or when there was nothing to read
 * yet), and a negated errno value on failure. End-of-stream on `fd_in` is
 * reported as -ENOTCONN so that callers polling the descriptor stop instead
 * of spinning on a socket that stays readable forever.
 */
static int recv_and_send(int fd_in, int fd_out, char *buffer, size_t len)
{
	ssize_t recv_ret;
	size_t total;
	size_t sent = 0;

	/* A zero-length read would be indistinguishable from end-of-stream. */
	if (len == 0)
		return 0;

	recv_ret = recv(fd_in, buffer, len, MSG_DONTWAIT);
	if (recv_ret < 0) {
		int err = errno;

		/* Nothing to forward right now; let the caller poll again. */
		if (err == EAGAIN || err == EWOULDBLOCK || err == EINTR)
			return 0;

		perror("recv");
		return -err;
	}

	/* The peer closed its side of the connection. */
	if (recv_ret == 0)
		return -ENOTCONN;

	total = (size_t)recv_ret;
	while (sent < total) {
		/*
		 * MSG_NOSIGNAL: a closed peer must yield EPIPE rather than a
		 * SIGPIPE that would terminate the whole process.
		 */
		ssize_t send_ret = send(fd_out, buffer + sent, total - sent,
					MSG_DONTWAIT | MSG_NOSIGNAL);

		if (send_ret >= 0) {
			sent += (size_t)send_ret;
			continue;
		}

		int err = errno;

		if (err == EINTR)
			continue;

		if (err == EAGAIN || err == EWOULDBLOCK) {
			/*
			 * The data has already been consumed from fd_in, so
			 * dropping it is not an option: wait until the output
			 * socket has room again.
			 */
			struct pollfd pfd;

			pfd.fd = fd_out;
			pfd.events = POLLOUT;
			pfd.revents = 0;
			if (poll(&pfd, 1, -1) < 0 && errno != EINTR) {
				err = errno;
				perror("poll");
				return -err;
			}
			continue;
		}

		perror("send");
		return -err;
	}

	return 0;
}
/*
 * Build the hash-map key for an IPv4 endpoint: the 32-bit address occupies
 * bits 16..47 and the 16-bit port bits 0..15. Both fields are used exactly as
 * stored in `addr`; no byte-order conversion is applied.
 */
static inline uint64_t gen_map_key(struct sockaddr_in *addr)
{
	const uint64_t ip_bits = (uint64_t)addr->sin_addr.s_addr;
	const uint64_t port_bits = (uint64_t)addr->sin_port;

	return (ip_bits << 16) | port_bits;
}
/*
 * Connect `fd` to the IPv4 endpoint `addr`:`port`. `addr` is dotted text
 * converted with inet_addr() (the conversion result is not validated) and
 * `port` is in host byte order. Progress messages are printed to stdout.
 *
 * Returns 0 on success. On failure the error is reported with
 * perror("connect") and the negated errno value is returned.
 */
static int connect_tcp_sock(int fd, const char *addr, uint16_t port)
{
	struct sockaddr_in remote;

	memset(&remote, 0, sizeof(remote));
	remote.sin_family = AF_INET;
	remote.sin_addr.s_addr = inet_addr(addr);
	remote.sin_port = htons(port);

	printf("Connecting to %s:%u...\n", addr, port);
	if (connect(fd, (struct sockaddr *)&remote, sizeof(remote)) >= 0) {
		printf("Connected!\n");
		return 0;
	}

	int saved_errno = errno;

	perror("connect");
	return -saved_errno;
}
static int handle_lost_control_client(struct server_data *data)
{
struct sockaddr_in addr;
int fd = data->fd1;
struct sss_pkt pkt;
socklen_t addrlen;
ssize_t recv_ret;
int client_fd;
int ret;
do_accept:
if (atomic_load(&g_stop_server))
return 0;
addrlen = sizeof(addr);
client_fd = accept(fd, (struct sockaddr *)&addr, &addrlen);
if (unlikely(client_fd < 0)) {
ret = errno;
perror("accept");
goto out;
}
recv_ret = recv(client_fd, &pkt, sizeof(pkt), MSG_WAITALL);
if (unlikely(recv_ret)) {
ret = errno;
perror("recv");
goto out;
}
if (recv_ret != sizeof(pkt)) {
close(client_fd);
goto do_accept;
}
{
data->map_lock.lock();
auto it = data->map.find(gen_map_key(&addr));
if (it != data->map.end()) {
struct client_slot *slot = it->second;
slot->fd_from_lost = client_fd;
}
data->map_lock.unlock();
}
goto do_accept;
out:
atomic_store(&g_stop_server, true);
return ret;
}
/*
 * Per-connection state owned by one relay thread.
 */
struct handle_client_in_server_data {
	/* Descriptor pair; a pointer to it is registered in data->map. */
	struct client_slot slot;
	/* Shared server state (map and its lock). */
	struct server_data *data;
	/* Key under which `slot` was registered in data->map. */
	uint64_t map_key;
};

/*
 * Tear down one connection: remove its slot from the map, close whichever
 * descriptors it holds and free the state block.
 *
 * The map entry is removed under the lock before the memory is freed so that
 * no other thread can be left holding a pointer to a freed slot.
 */
static void release_client_in_server(struct handle_client_in_server_data *hc_data)
{
	struct server_data *srv = hc_data->data;
	int fd_pub;
	int fd_lost;

	{
		std::lock_guard<std::mutex> guard(srv->map_lock);
		auto it = srv->map.find(hc_data->map_key);

		if (it != srv->map.end() && it->second == &hc_data->slot)
			srv->map.erase(it);

		/* Read under the lock: the peer may have been attached late. */
		fd_pub = hc_data->slot.fd;
		fd_lost = hc_data->slot.fd_from_lost;
	}

	if (fd_pub != -1)
		close(fd_pub);
	if (fd_lost != -1)
		close(fd_lost);
	delete hc_data;
}

/*
 * Relay thread: waits until the slot has been given its second descriptor,
 * then copies data in both directions until either side fails or closes.
 *
 * `data_p` is a struct handle_client_in_server_data allocated by
 * _handle_client_in_server(); the thread owns it and releases it on exit.
 * Always returns NULL.
 */
static void *__handle_client_in_server(void *data_p)
{
	struct handle_client_in_server_data *hc_data;
	struct pollfd fds[2];
	char buffer[4096];
	int fd_lost;

	hc_data = (struct handle_client_in_server_data *)data_p;

	/* Wait (checking once per second) for the matching connection. */
	for (;;) {
		{
			/* The field is written by another thread under this lock. */
			std::lock_guard<std::mutex> guard(hc_data->data->map_lock);
			fd_lost = hc_data->slot.fd_from_lost;
		}
		if (fd_lost != -1)
			break;
		if (atomic_load(&g_stop_server))
			goto out;
		sleep(1);
	}

	/*
	 * Only POLLIN is requested: nothing here reads urgent data, so asking
	 * for POLLPRI would make poll() return immediately forever once it is
	 * raised.
	 */
	fds[0].fd = hc_data->slot.fd;
	fds[0].events = POLLIN;
	fds[0].revents = 0;
	fds[1].fd = fd_lost;
	fds[1].events = POLLIN;
	fds[1].revents = 0;

	for (;;) {
		/* poll() returns the number of ready descriptors; only <0 is an error. */
		int nready = poll(fds, 2, -1);

		if (unlikely(nready < 0)) {
			if (errno == EINTR)
				continue;
			perror("poll");
			goto out;
		}

		for (int i = 0; i < 2; i++) {
			if (fds[i].revents & POLLIN) {
				if (unlikely(recv_and_send(fds[i].fd, fds[i ^ 1].fd,
							   buffer, sizeof(buffer)) < 0))
					goto out;
			} else if (fds[i].revents & (POLLERR | POLLHUP | POLLNVAL)) {
				/* Closed or broken with nothing left to read. */
				goto out;
			}
		}
	}

out:
	release_client_in_server(hc_data);
	return NULL;
}

/*
 * Start serving a freshly accepted connection.
 *
 * Registers a slot for `addr` in data->map, sends a SSS_PUB_CLIENT_ADDR
 * packet carrying `addr` on `fd`, and spawns a detached relay thread that
 * takes ownership of `fd`. On any failure `fd` is closed and nothing is left
 * registered.
 *
 * NOTE(review): the packet is written to `fd`, the connection that was just
 * accepted, whereas run_client() waits for struct sss_pkt on its own
 * connection to the server - confirm this is the intended destination.
 */
static void _handle_client_in_server(int fd, struct server_data *data,
				     struct sockaddr_in *addr)
{
	struct handle_client_in_server_data *hc_data;
	struct sss_pkt pkt;
	pthread_t thread;
	ssize_t send_ret;
	bool inserted;
	int err;

	/* Plain `new` throws instead of returning NULL; use the nothrow form. */
	hc_data = new (std::nothrow) handle_client_in_server_data;
	if (unlikely(!hc_data)) {
		close(fd);
		return;
	}

	hc_data->slot.fd = fd;
	hc_data->slot.fd_from_lost = -1;
	hc_data->data = data;
	hc_data->map_key = gen_map_key(addr);

	/*
	 * Register before announcing the client so that a fast reply cannot
	 * arrive while the slot is still unknown.
	 */
	{
		std::lock_guard<std::mutex> guard(data->map_lock);
		inserted = data->map.emplace(hc_data->map_key, &hc_data->slot).second;
	}
	if (unlikely(!inserted)) {
		fprintf(stderr, "client address already registered, dropping connection\n");
		close(fd);
		delete hc_data;
		return;
	}

	/* Zero the whole packet so padding bytes do not leak stack contents. */
	memset(&pkt, 0, sizeof(pkt));
	pkt.type = SSS_PUB_CLIENT_ADDR;
	pkt.client_addr = *addr;

	send_ret = send(fd, &pkt, sizeof(pkt), MSG_NOSIGNAL);
	if (unlikely(send_ret != (ssize_t)sizeof(pkt))) {
		if (send_ret < 0)
			perror("send");
		else
			fprintf(stderr, "send: short write of client address packet\n");
		release_client_in_server(hc_data);
		return;
	}

	err = pthread_create(&thread, NULL, __handle_client_in_server, hc_data);
	if (unlikely(err)) {
		/* pthread_create() returns the error number instead of setting errno. */
		errno = err;
		perror("pthread_create");
		release_client_in_server(hc_data);
		return;
	}

	pthread_detach(thread);
}
/*
 * Thread entry point: accept loop for the second listening socket.
 *
 * `data_p` is the struct server_data shared with run_server(). The thread
 * first waits (checking once per second) until data->got_lost_client is true,
 * then accepts connections on data->fd2 and hands each one to
 * _handle_client_in_server(). It stops when g_stop_server is set, and sets
 * that flag itself if accept() fails. Always returns NULL.
 *
 * data->fd2 is used because handle_lost_control_client() already accepts on
 * data->fd1; two loops on the same listener would steal each other's
 * connections while the second listener was never served.
 *
 * NOTE(review): nothing visible in this file stores true into
 * data->got_lost_client, so the wait below may never finish - confirm where
 * the flag is supposed to be set.
 */
static void *handle_client_in_server(void *data_p)
{
	struct server_data *data = (struct server_data *)data_p;
	const int listen_fd = data->fd2;

	while (!atomic_load(&data->got_lost_client)) {
		sleep(1);
		if (atomic_load(&g_stop_server))
			return NULL;
	}

	while (!atomic_load(&g_stop_server)) {
		struct sockaddr_in addr;
		socklen_t addrlen = sizeof(addr);
		int client_fd;

		client_fd = accept(listen_fd, (struct sockaddr *)&addr, &addrlen);
		if (unlikely(client_fd < 0)) {
			/* Transient conditions: the listener itself is fine. */
			if (errno == EINTR || errno == ECONNABORTED)
				continue;
			perror("accept");
			atomic_store(&g_stop_server, true);
			break;
		}

		_handle_client_in_server(client_fd, data, &addr);
	}

	return NULL;
}
/*
 * Server mode entry point.
 *
 * Creates two TCP listeners (listen1 and listen2), starts
 * handle_client_in_server() in a detached thread and then runs
 * handle_lost_control_client() on the calling thread until it returns.
 * Both listening sockets are closed before returning.
 *
 * Returns 0 on success, otherwise a positive error number (the sign of
 * negative internal error codes is flipped).
 */
static int run_server(const char *listen1_addr, uint16_t listen1_port,
		      const char *listen2_addr, uint16_t listen2_port)
{
	struct server_data data;
	pthread_t acceptor;
	int rc;

	data.fd1 = -1;
	data.fd2 = -1;
	atomic_store(&data.got_lost_client, false);

	data.fd1 = create_tcp_sock();
	rc = data.fd1;
	if (unlikely(rc < 0))
		goto cleanup;

	data.fd2 = create_tcp_sock();
	rc = data.fd2;
	if (unlikely(rc < 0))
		goto cleanup;

	rc = bind_and_listen_tcp_sock(data.fd1, listen1_addr, listen1_port);
	if (unlikely(rc < 0))
		goto cleanup;

	rc = bind_and_listen_tcp_sock(data.fd2, listen2_addr, listen2_port);
	if (unlikely(rc < 0))
		goto cleanup;

	atomic_store(&g_stop_server, false);

	rc = pthread_create(&acceptor, NULL, handle_client_in_server, &data);
	if (unlikely(rc != 0)) {
		/* pthread_create() reports its error through the return value. */
		errno = rc;
		perror("pthread_create");
		goto cleanup;
	}
	pthread_detach(acceptor);

	/* Runs until the server is asked to stop or accept() fails. */
	rc = handle_lost_control_client(&data);

cleanup:
	if (data.fd1 != -1)
		close(data.fd1);
	if (data.fd2 != -1)
		close(data.fd2);

	return (rc >= 0) ? rc : -rc;
}
static int init_handshake_from_client(int tcp_fd, struct client_data *data)
{
ssize_t send_ret;
int err;
send_ret = send(tcp_fd, &data->pkt, data->pkt_len, MSG_WAITALL);
if (unlikely(send_ret < 0)) {
err = errno;
perror("send");
return -err;
}
return 0;
}
/*
 * Thread entry point used in client mode for one relayed connection.
 *
 * `data_p` is a malloc()ed struct client_data owned by this thread; it is
 * released with free() before returning. data->addr and data->pkt_len must
 * have been initialised by whoever created the thread.
 *
 * The thread opens one connection to the server and one to the target, sends
 * the handshake packet to the server, and then copies data in both directions
 * until either side fails or closes. Always returns NULL.
 */
static void *handle_client_in_client(void *data_p)
{
	struct client_data *data = (struct client_data *)data_p;
	const struct client_addr_info *addr = data->addr;
	int fd1 = -1, fd2 = -1;
	struct pollfd fds[2];
	char buffer[4096];
	int ret;

	/*
	 * create_tcp_sock() returns a negated errno on failure; keep fd1/fd2
	 * at -1 in that case so the cleanup below never closes an error code.
	 */
	ret = create_tcp_sock();
	if (unlikely(ret < 0))
		goto out;
	fd1 = ret;

	ret = create_tcp_sock();
	if (unlikely(ret < 0))
		goto out;
	fd2 = ret;

	ret = connect_tcp_sock(fd1, addr->server_addr, addr->server_port);
	if (unlikely(ret))
		goto out;

	ret = connect_tcp_sock(fd2, addr->target_addr, addr->target_port);
	if (unlikely(ret))
		goto out;

	ret = init_handshake_from_client(fd1, data);
	if (unlikely(ret))
		goto out;

	/*
	 * Only POLLIN is requested: nothing here reads urgent data, so asking
	 * for POLLPRI would make poll() return immediately forever once it is
	 * raised.
	 */
	fds[0].fd = fd1;
	fds[0].events = POLLIN;
	fds[0].revents = 0;
	fds[1].fd = fd2;
	fds[1].events = POLLIN;
	fds[1].revents = 0;

	for (;;) {
		/* poll() returns the number of ready descriptors; only <0 is an error. */
		ret = poll(fds, 2, -1);
		if (unlikely(ret < 0)) {
			if (errno == EINTR)
				continue;
			perror("poll");
			goto out;
		}

		for (int i = 0; i < 2; i++) {
			if (fds[i].revents & POLLIN) {
				if (unlikely(recv_and_send(fds[i].fd, fds[i ^ 1].fd,
							   buffer, sizeof(buffer)) < 0))
					goto out;
			} else if (fds[i].revents & (POLLERR | POLLHUP | POLLNVAL)) {
				/* Closed or broken with nothing left to read. */
				goto out;
			}
		}
	}

out:
	if (fd1 != -1)
		close(fd1);
	if (fd2 != -1)
		close(fd2);
	free(data);
	return NULL;
}
/*
 * Client mode entry point.
 *
 * Connects to the server and waits for struct sss_pkt messages on that
 * connection. For every complete packet a detached
 * handle_client_in_client() thread is started, which receives its own
 * malloc()ed struct client_data holding the packet.
 *
 * Returns 0 when the server closes the connection, 1 after ten failures of
 * poll()/recv()/pthread_create(), ENOMEM when allocation fails, or the
 * positive errno value when the initial socket/connect fails.
 */
static int run_client(const char *target_addr, uint16_t target_port,
		      const char *server_addr, uint16_t server_port)
{
	static const uint32_t max_fail_count = 10;
	/*
	 * Static storage: detached worker threads keep a pointer to this
	 * structure, so it must not live in a stack frame that can go away
	 * while they are still running.
	 */
	static struct client_addr_info addr;
	struct client_data *data = NULL;
	uint32_t fail_count = 0;
	struct pollfd fds[1];
	ssize_t recv_ret;
	pthread_t thread;
	int ret = 0;
	int main_fd;
	int err;

	main_fd = create_tcp_sock();
	if (unlikely(main_fd < 0))
		return -main_fd;

	err = connect_tcp_sock(main_fd, server_addr, server_port);
	if (unlikely(err)) {
		close(main_fd);
		return -err;
	}

	addr.server_addr = server_addr;
	addr.server_port = server_port;
	addr.target_addr = target_addr;
	addr.target_port = target_port;

	fds[0].fd = main_fd;
	fds[0].events = POLLIN | POLLPRI;
	fds[0].revents = 0;

	for (;;) {
		/* A fresh block is needed after each one handed to a worker. */
		if (!data) {
			data = (struct client_data *)malloc(sizeof(*data));
			if (unlikely(!data)) {
				perror("malloc");
				ret = ENOMEM;
				break;
			}
			/* The worker reads both fields; they must be valid. */
			data->addr = &addr;
			data->pkt_len = sizeof(data->pkt);
		}

		if (unlikely(fail_count >= max_fail_count)) {
			ret = 1;
			break;
		}

		err = poll(fds, 1, -1);
		if (unlikely(err < 0)) {
			/* An interrupted wait is not a failure. */
			if (errno == EINTR)
				continue;
			perror("poll");
			fail_count++;
			continue;
		}

		recv_ret = recv(main_fd, &data->pkt, sizeof(data->pkt), MSG_WAITALL);
		if (recv_ret == 0) {
			puts("Disconnected from the server");
			ret = 0;
			break;
		}
		if (unlikely(recv_ret < 0)) {
			if (errno == EAGAIN || errno == EINTR)
				continue;
			perror("recv");
			fail_count++;
			continue;
		}
		if (unlikely((size_t)recv_ret != sizeof(data->pkt))) {
			/* MSG_WAITALL can still return early; never act on half a packet. */
			fprintf(stderr, "recv: truncated packet from the server\n");
			fail_count++;
			continue;
		}

		err = pthread_create(&thread, NULL, handle_client_in_client, data);
		if (unlikely(err)) {
			/* pthread_create() returns the error number instead of setting errno. */
			errno = err;
			perror("pthread_create");
			fail_count++;
			continue;
		}
		pthread_detach(thread);

		/* The worker now owns the block and will free() it. */
		data = NULL;
	}

	free(data);
	close(main_fd);
	return ret;
}
/*
 * Usage:
 *   ./slc client <target_addr> <target_port> <server_addr> <server_port>
 *   ./slc server <listen1_addr> <listen1_port> <listen2_addr> <listen2_port>
 *
 * Example:
 *   ./slc client 127.0.0.1 5555 123.123.123.123 9999
 *   ./slc server 123.123.123.123 9999 0.0.0.0 9998
 */

/* Print the command-line synopsis to stderr. */
static void show_usage(const char *prog)
{
	fprintf(stderr,
		"Usage:\n"
		"  %s client <target_addr> <target_port> <server_addr> <server_port>\n"
		"  %s server <listen1_addr> <listen1_port> <listen2_addr> <listen2_port>\n"
		"\n"
		"Example:\n"
		"  %s client 127.0.0.1 5555 123.123.123.123 9999\n"
		"  %s server 123.123.123.123 9999 0.0.0.0 9998\n",
		prog, prog, prog, prog);
}

/*
 * Parse a decimal TCP port (0..65535) from `str` into `*port`.
 * Returns true on success; false when `str` is empty, has trailing
 * characters, is not a plain decimal number, or is out of range.
 */
static bool parse_port_arg(const char *str, uint16_t *port)
{
	char *end = NULL;
	unsigned long value;

	/* strtoul() would silently accept a sign or leading whitespace. */
	if (str[0] < '0' || str[0] > '9')
		return false;

	errno = 0;
	value = strtoul(str, &end, 10);
	if (errno != 0 || end == str || *end != '\0' || value > 65535ul)
		return false;

	*port = (uint16_t)value;
	return true;
}

/*
 * Program entry point: dispatches to run_client() or run_server() according
 * to argv[1] and returns their result. Returns 1, after printing the usage
 * text, when the arguments are malformed.
 */
int main(int argc, const char *argv[])
{
	const char *prog = (argc > 0 && argv[0]) ? argv[0] : "slc";
	uint16_t port1;
	uint16_t port2;

	if (argc != 6 ||
	    !parse_port_arg(argv[3], &port1) ||
	    !parse_port_arg(argv[5], &port2)) {
		show_usage(prog);
		return 1;
	}

	if (!strcmp(argv[1], "client"))
		return run_client(argv[2], port1, argv[4], port2);

	if (!strcmp(argv[1], "server"))
		return run_server(argv[2], port1, argv[4], port2);

	show_usage(prog);
	return 1;
}
/* Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment */