Skip to content

Instantly share code, notes, and snippets.

@TLabAltoh
Last active June 2, 2025 12:56
Show Gist options
  • Save TLabAltoh/b1a2983389805ade18a04a61890652df to your computer and use it in GitHub Desktop.
Save TLabAltoh/b1a2983389805ade18a04a61890652df to your computer and use it in GitHub Desktop.
Sample implementation of a room-based chat using WebSockets with Rust's Axum library

rust-axum-ws-room_practice

I have tried this with the latest source; see axum_ws_rooms.

Cargo.toml
[package]
name = "axum-ws-room"
version = "0.1.0"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[lib]
name = "axum_ws_rooms"
path = "src/lib.rs"

[dependencies]
axum = { version = "0.7.4", features = ["multipart", "tracing", "ws"] }
axum-extra = { version = "0.9.3", features = ["query"] }
tokio = { version = "1.15.0", features = ["full"] }
tracing = "0.1"
tracing-subscriber = "0.3"
futures-util = { version = "0.3", default-features = false }
serde = { version = "1", features = ["derive"] }
serde_json = "1"
main.rs
use std::sync::Arc;

use axum::{
    extract::{
        ws::{Message, WebSocket},
        WebSocketUpgrade,
    },
    response::IntoResponse,
    routing::get,
    Extension, Router,
};
use axum_ws_rooms::RoomsManager;
use futures_util::{SinkExt, StreamExt};
use serde::{Deserialize, Serialize};
use tracing::{debug, info, Level};

/// Shared application state, handed to every WebSocket handler through an
/// `Extension` layer.
struct State {
    /// Room registry. Messages relayed through rooms are `Vec<u8>` frames
    /// (sender id + kind byte + payload, as built by the socket handler).
    rooms: RoomsManager<Vec<u8>>,
}

/// Handshake payload: the first text frame a client sends, encoded as JSON,
/// e.g. `{"user_id": 0, "room_name": "test"}`.
#[derive(Deserialize, Serialize)]
struct RequestInfo {
    /// Id of the connecting user. It is stamped into every frame this user
    /// broadcasts, so the user's own messages are not echoed back to them.
    user_id: i32,
    /// Name of the room to join.
    room_name: String,
}

#[tokio::main]
async fn main() {
    tracing_subscriber::fmt().with_max_level(Level::INFO).init();

    let rooms = axum_ws_rooms::RoomsManager::new();

    rooms.new_room("test".into(), None).await;

    let state = Arc::new(State { rooms });

    // build our application with a single route
    let app = Router::new()
        .route("/", get(websocket_handler))
        .layer(Extension(state));

    // run it with axum on localhost:5000
    let listener = tokio::net::TcpListener::bind("0.0.0.0:5000").await.unwrap();
    info!("listening on {}", listener.local_addr().unwrap());
    axum::serve(listener, app.into_make_service())
        .await
        .unwrap();
}

async fn websocket_handler(
    ws: WebSocketUpgrade,
    Extension(state): Extension<Arc<State>>,
) -> impl IntoResponse {
    ws.on_upgrade(|socket| websocket(socket, state))
}

async fn websocket(stream: WebSocket, state: Arc<State>) {
    let (mut socekt_sender, mut socket_receiver) = stream.split();
    let mut request = RequestInfo {
        user_id: 0,
        room_name: String::new(),
    };

    while let Some(Ok(Message::Text(text))) = socket_receiver.next().await {
        let info = serde_json::from_str(&text);
        match info {
            Ok(info) => {
                request = info;
                break;
            }
            Err(error) => {
                socekt_sender
                    .send(Message::Text(error.to_string()))
                    .await
                    .unwrap();
                socekt_sender.close().await.unwrap();
                return;
            }
        }
    }

    state
        .rooms
        .init_user(request.user_id.to_string(), None)
        .await;

    let user_sender = state
        .rooms
        .join_room(request.room_name, request.user_id.to_string())
        .await;
    let user_sender = match user_sender {
        Ok(user_sender) => user_sender,
        Err(error) => {
            socekt_sender
                .send(Message::Text(error.to_string()))
                .await
                .unwrap();
            socekt_sender.close().await.unwrap();
            return;
        }
    };

    let user_receiver = state
        .rooms
        .get_user_receiver(request.user_id.to_string())
        .await;
    let mut user_receiver = match user_receiver {
        Ok(user_receiver) => user_receiver,
        Err(error) => {
            socekt_sender
                .send(Message::Text(error.to_string()))
                .await
                .unwrap();
            socekt_sender.close().await.unwrap();
            return;
        }
    };

    debug!("start receive send loop ...");

    let mut send_task = tokio::spawn(async move {
        let user_id = request.user_id as u32;
        while let Ok(message) = user_receiver.recv().await {
            let mut msg_id: u32 = 0;
            for i in 0..4 {
                msg_id += (message[i] as u32) << ((3 - i) * 8);
            }
            if msg_id == user_id {
                continue; // This message was sent from own
            }
            match message[4] {
                0 => {
                    // binary
                    socekt_sender
                        .send(Message::Binary(message[5..].to_vec()))
                        .await
                        .unwrap();
                }
                1 => {
                    // text
                    socekt_sender
                        .send(Message::Text(
                            String::from_utf8(message[5..].to_vec()).unwrap(),
                        ))
                        .await
                        .unwrap();
                }
                2.. => {}
            };
        }
    });

    let mut recv_task = tokio::spawn(async move {
        let user_id = request.user_id as u32;
        let mut header = vec![0u8; 5]; // user_id (0 ~ 3) + message_type (4)
        for i in 0..4 {
            header[i] = (user_id >> ((3 - i) * 8)) as u8;
        }
        while let Some(Ok(message)) = socket_receiver.next().await {
            match message {
                Message::Binary(binary) => {
                    header[4] = 0;
                    user_sender.send([header.clone(), binary].concat()).unwrap();
                }
                Message::Text(text) => {
                    header[4] = 1;
                    user_sender
                        .send([header.clone(), text.into_bytes()].concat())
                        .unwrap();
                }
                Message::Ping(_vec) => {}
                Message::Pong(_vec) => {}
                Message::Close(_close_frame) => {}
            }
        }
    });

    tokio::select! {
        _ = (&mut send_task) => recv_task.abort(),
        _ = (&mut recv_task) => send_task.abort(),
    };

    state.rooms.end_user(request.user_id.to_string()).await;

    println!("connection closed");
}
js
// Minimal client for the Rust server: connects, joins the "test" room as
// user 0, and logs everything the server relays.
const socket = new WebSocket("ws://localhost:5000");

socket.addEventListener("open", () => {
  // `open` events carry no payload, so there is nothing to log but the fact.
  console.log("Connect to server");
  // Handshake: the server expects this JSON object as the first text frame.
  socket.send(JSON.stringify({ user_id: 0, room_name: "test" }));
});

socket.addEventListener("close", (event) => {
  // CloseEvent exposes `code` and `reason`; it has no `data` property.
  console.log("Close socket ", event.code, event.reason);
});

socket.addEventListener("message", (event) => {
  console.log("Message from server ", event.data);
});
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment