Last active
December 1, 2023 16:52
-
-
Save neel-bp/8b4e16197f1421478b81faec88b012b2 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
import ( | |
"context" | |
"encoding/json" | |
"fmt" | |
"io" | |
"log" | |
"net/http" | |
"time" | |
"github.com/gin-gonic/gin" | |
"github.com/gorilla/websocket" | |
"github.com/redis/go-redis/v9" | |
) | |
// Timing and size limits applied to every websocket connection.
const (
	// WRITE_WAIT bounds how long a single write to the peer (data frame or
	// ping) may take before it is treated as failed.
	WRITE_WAIT = 10 * time.Second

	// PONG_WAIT is how long the reader waits for a pong from the peer before
	// the connection is considered dead.
	PONG_WAIT = 60 * time.Second

	// PING_PERIOD is the interval between pings. It is 90% of PONG_WAIT so a
	// ping always goes out before the pong deadline expires.
	PING_PERIOD = PONG_WAIT / 10 * 9

	// MAX_MESSAGE_SIZE is the largest incoming message, in bytes, that the
	// reader accepts.
	MAX_MESSAGE_SIZE = 512
)
var Upgrader = websocket.Upgrader{ | |
CheckOrigin: func(r *http.Request) bool { | |
return true | |
}, | |
ReadBufferSize: 2048, | |
WriteBufferSize: 2048, | |
Subprotocols: []string{"name"}, | |
} | |
var RDB *redis.Client | |
// Message is the JSON envelope exchanged both over Redis pub/sub and with
// websocket clients.
type Message struct {
	From    string `json:"from"`
	Message string `json:"message"`
}

// String returns the JSON encoding of m.
func (m Message) String() string {
	return string(m.Byte())
}

// Byte returns the JSON encoding of m.
func (m Message) Byte() []byte {
	// json.Marshal cannot fail for a struct made only of string fields
	// (invalid UTF-8 is replaced, not rejected), so the error is discarded.
	encoded, _ := json.Marshal(m)
	return encoded
}
func InitializeRedis() error { | |
if RDB == nil { | |
rdb := redis.NewClient(&redis.Options{ | |
Addr: "localhost:6379", | |
}) | |
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) | |
defer cancel() | |
err := rdb.Ping(ctx).Err() | |
if err != nil { | |
return err | |
} | |
RDB = rdb | |
return nil | |
} | |
return nil | |
} | |
func AuthMiddleware(c *gin.Context) { | |
if !websocket.IsWebSocketUpgrade(c.Request) { | |
c.JSON(400, gin.H{ | |
"message": "not a websocket upgrade", | |
}) | |
c.Abort() | |
return | |
} | |
subprotocols := websocket.Subprotocols(c.Request) | |
if len(subprotocols) < 2 { | |
c.JSON(400, gin.H{ | |
"message": "wrong format of subprotocols", | |
}) | |
c.Abort() | |
return | |
} | |
if subprotocols[0] != "name" { | |
c.JSON(400, gin.H{ | |
"message": "wrong format of subprotocols", | |
}) | |
c.Abort() | |
return | |
} | |
c.Set("name", subprotocols[1]) | |
c.Next() | |
} | |
func RoomSocket(c *gin.Context) { | |
room, ok := c.Params.Get("room") | |
if !ok { | |
c.JSON(400, gin.H{ | |
"message": "no room code in url", | |
}) | |
return | |
} | |
name := c.GetString("name") | |
if name == "" { | |
c.JSON(400, gin.H{ | |
"message": "no name provided", | |
}) | |
return | |
} | |
isMember, err := RDB.SIsMember(c.Request.Context(), room, name).Result() | |
if err != nil { | |
c.JSON(500, gin.H{ | |
"message": err.Error(), | |
}) | |
return | |
} | |
if isMember { | |
c.JSON(400, gin.H{ | |
"message": "name already taken", | |
}) | |
return | |
} | |
conn, err := Upgrader.Upgrade(c.Writer, c.Request, nil) | |
if err != nil { | |
c.JSON(500, gin.H{ | |
"message": err.Error(), | |
}) | |
return | |
} | |
defer conn.Close() | |
if err := RDB.SAdd(c.Request.Context(), room, name).Err(); err != nil { | |
conn.WriteMessage(websocket.CloseInternalServerErr, []byte(err.Error())) | |
return | |
} | |
// instead of directly processing messages from redis channel, maybe create an intermediary buffer which writepump listens too, but redis pubsub pushes message into that intermediary buffer | |
pubsub := RDB.Subscribe(c.Request.Context(), room) | |
defer pubsub.Close() | |
// NOTE: signal channel for cleanup | |
breaker := make(chan struct{}, 1) | |
writerChan := make(chan Message, 512) | |
pubsubChan := pubsub.Channel() | |
go pubsubQueue(pubsubChan, writerChan) | |
go readPump(c.Request.Context(), conn, breaker, writerChan, room, name) | |
go writePump(conn, breaker, writerChan, pubsubChan) | |
<-breaker | |
// using null context here because i want them sent regardless of request being cancelled | |
RDB.SRem(context.Background(), room, name) | |
RDB.Publish(context.Background(), room, Message{From: "<server>", Message: fmt.Sprintf("%s nigga has left the chat", name)}.Byte()) | |
} | |
func pubsubQueue(pubsumMessage <-chan *redis.Message, writeQueue chan Message) { | |
for msg := range pubsumMessage { | |
m := Message{} | |
err := json.Unmarshal([]byte(msg.Payload), &m) | |
if err == nil { | |
writeQueue <- m | |
} | |
} | |
} | |
func readPump(ctx context.Context, conn *websocket.Conn, breakerChan chan struct{}, writeQueue chan Message, room, name string) { | |
conn.SetReadLimit(MAX_MESSAGE_SIZE) | |
conn.SetReadDeadline(time.Now().Add(PONG_WAIT)) | |
conn.SetPongHandler(func(string) error { | |
conn.SetReadDeadline(time.Now().Add(PONG_WAIT)) | |
return nil | |
}) | |
defer func() { | |
breakerChan <- struct{}{} | |
}() | |
for { | |
var msg Message | |
typ, reader, err := conn.NextReader() | |
if err != nil { | |
return | |
} | |
if typ != websocket.TextMessage { | |
writeQueue <- Message{ | |
From: "<server>", | |
Message: "wrong type of messge please send text message when communicating", | |
} | |
continue | |
} | |
byt, err := io.ReadAll(reader) | |
if err != nil { | |
return | |
} | |
msg.Message = string(byt) | |
msg.From = name | |
err = RDB.Publish(ctx, room, msg.Byte()).Err() | |
if err != nil { | |
writeQueue <- Message{ | |
From: "<server>", | |
Message: err.Error(), | |
} | |
} | |
} | |
} | |
func writePump(conn *websocket.Conn, breakerChan chan struct{}, writeQueue chan Message, pubsubMessage <-chan *redis.Message) { | |
ticker := time.NewTicker(PING_PERIOD) | |
defer func() { | |
ticker.Stop() | |
breakerChan <- struct{}{} | |
}() | |
for { | |
select { | |
case wmsg := <-writeQueue: | |
conn.SetWriteDeadline(time.Now().Add(WRITE_WAIT)) | |
w, err := conn.NextWriter(websocket.TextMessage) | |
if err != nil { | |
return | |
} | |
w.Write(wmsg.Byte()) | |
if err := w.Close(); err != nil { | |
return | |
} | |
case <-ticker.C: | |
conn.SetWriteDeadline(time.Now().Add(WRITE_WAIT)) | |
if err := conn.WriteMessage(websocket.PingMessage, nil); err != nil { | |
return | |
} | |
} | |
} | |
} | |
func main() { | |
if err := InitializeRedis(); err != nil { | |
log.Fatal(err) | |
} | |
router := gin.Default() | |
wsGroup := router.Group("/ws") | |
wsGroup.Use(AuthMiddleware) | |
wsGroup.GET("/:room", RoomSocket) | |
if err := router.Run(":8080"); err != nil { | |
log.Fatal(err) | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment