I have tried the latest source; see `axum_ws_rooms`.
Cargo.toml
[package]
name = "axum-ws-room"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

# Library target providing `RoomsManager`; main.rs imports it as `axum_ws_rooms`.
[lib]
name = "axum_ws_rooms"
path = "src/lib.rs"

[dependencies]
axum = { version = "0.7.4", features = ["multipart", "tracing", "ws"] }
axum-extra = { version = "0.9.3", features = ["query"] }
tokio = { version = "1.15.0", features = ["full"] }
tracing = "0.1"
tracing-subscriber = "0.3"
# `SinkExt` and `StreamExt::split` are gated behind the "sink" feature. With
# default features disabled it must be requested here rather than relying on
# another dependency enabling it through feature unification.
futures-util = { version = "0.3", default-features = false, features = ["sink", "std"] }
serde = { version = "1", features = ["derive"] }
serde_json = "1"
main.rs
use std::sync::Arc;
use axum::{
extract::{
ws::{Message, WebSocket},
WebSocketUpgrade,
},
response::IntoResponse,
routing::get,
Extension, Router,
};
use axum_ws_rooms::RoomsManager;
use futures_util::{SinkExt, StreamExt};
use serde::{Deserialize, Serialize};
use tracing::{debug, info, Level};
/// Shared application state, injected into handlers as `Extension<Arc<State>>`.
struct State {
    /// Room registry whose payloads are the raw byte frames relayed between users.
    rooms: RoomsManager<Vec<u8>>,
}
/// Handshake payload: the first text frame a client sends, as JSON,
/// e.g. `{"user_id": 0, "room_name": "test"}`.
#[derive(Debug, Serialize, Deserialize)]
struct RequestInfo {
    /// Identifier of the connecting user; also stamped into every relayed frame.
    user_id: i32,
    /// Name of the room the user wants to join.
    room_name: String,
}
#[tokio::main]
async fn main() {
tracing_subscriber::fmt().with_max_level(Level::INFO).init();
let rooms = axum_ws_rooms::RoomsManager::new();
rooms.new_room("test".into(), None).await;
let state = Arc::new(State { rooms });
// build our application with a single route
let app = Router::new()
.route("/", get(websocket_handler))
.layer(Extension(state));
// run it with axum on localhost:5000
let listener = tokio::net::TcpListener::bind("0.0.0.0:5000").await.unwrap();
info!("listening on {}", listener.local_addr().unwrap());
axum::serve(listener, app.into_make_service())
.await
.unwrap();
}
async fn websocket_handler(
ws: WebSocketUpgrade,
Extension(state): Extension<Arc<State>>,
) -> impl IntoResponse {
ws.on_upgrade(|socket| websocket(socket, state))
}
async fn websocket(stream: WebSocket, state: Arc<State>) {
let (mut socekt_sender, mut socket_receiver) = stream.split();
let mut request = RequestInfo {
user_id: 0,
room_name: String::new(),
};
while let Some(Ok(Message::Text(text))) = socket_receiver.next().await {
let info = serde_json::from_str(&text);
match info {
Ok(info) => {
request = info;
break;
}
Err(error) => {
socekt_sender
.send(Message::Text(error.to_string()))
.await
.unwrap();
socekt_sender.close().await.unwrap();
return;
}
}
}
state
.rooms
.init_user(request.user_id.to_string(), None)
.await;
let user_sender = state
.rooms
.join_room(request.room_name, request.user_id.to_string())
.await;
let user_sender = match user_sender {
Ok(user_sender) => user_sender,
Err(error) => {
socekt_sender
.send(Message::Text(error.to_string()))
.await
.unwrap();
socekt_sender.close().await.unwrap();
return;
}
};
let user_receiver = state
.rooms
.get_user_receiver(request.user_id.to_string())
.await;
let mut user_receiver = match user_receiver {
Ok(user_receiver) => user_receiver,
Err(error) => {
socekt_sender
.send(Message::Text(error.to_string()))
.await
.unwrap();
socekt_sender.close().await.unwrap();
return;
}
};
debug!("start receive send loop ...");
let mut send_task = tokio::spawn(async move {
let user_id = request.user_id as u32;
while let Ok(message) = user_receiver.recv().await {
let mut msg_id: u32 = 0;
for i in 0..4 {
msg_id += (message[i] as u32) << ((3 - i) * 8);
}
if msg_id == user_id {
continue; // This message was sent from own
}
match message[4] {
0 => {
// binary
socekt_sender
.send(Message::Binary(message[5..].to_vec()))
.await
.unwrap();
}
1 => {
// text
socekt_sender
.send(Message::Text(
String::from_utf8(message[5..].to_vec()).unwrap(),
))
.await
.unwrap();
}
2.. => {}
};
}
});
let mut recv_task = tokio::spawn(async move {
let user_id = request.user_id as u32;
let mut header = vec![0u8; 5]; // user_id (0 ~ 3) + message_type (4)
for i in 0..4 {
header[i] = (user_id >> ((3 - i) * 8)) as u8;
}
while let Some(Ok(message)) = socket_receiver.next().await {
match message {
Message::Binary(binary) => {
header[4] = 0;
user_sender.send([header.clone(), binary].concat()).unwrap();
}
Message::Text(text) => {
header[4] = 1;
user_sender
.send([header.clone(), text.into_bytes()].concat())
.unwrap();
}
Message::Ping(_vec) => {}
Message::Pong(_vec) => {}
Message::Close(_close_frame) => {}
}
}
});
tokio::select! {
_ = (&mut send_task) => recv_task.abort(),
_ = (&mut recv_task) => send_task.abort(),
};
state.rooms.end_user(request.user_id.to_string()).await;
println!("connection closed");
}
js
// Minimal test client: joins room "test" as user 0 and logs all traffic.
const socket = new WebSocket("ws://localhost:5000");

socket.addEventListener("open", () => {
  // An "open" event carries no data payload, so only log the fact.
  console.log("Connected to server");
  // The server treats the first text frame as the JSON join request.
  socket.send(JSON.stringify({ user_id: 0, room_name: "test" }));
});

socket.addEventListener("close", (event) => {
  // CloseEvent has no `data`; `code` and `reason` explain the closure.
  console.log("Close socket ", event.code, event.reason);
});

socket.addEventListener("error", (event) => {
  console.error("Socket error ", event);
});

socket.addEventListener("message", (event) => {
  console.log("Message from server ", event.data);
});