change behavior of reading/writing from websockets

This commit is contained in:
Marco 2024-05-21 23:41:37 +02:00
parent 6f16a5c23b
commit fd2fb3fab6
7 changed files with 83 additions and 114 deletions

View File

@ -68,11 +68,9 @@ func waitForAndHandlePlayerID(ctx context.Context, conn *gorillaws.Conn) {
conn.Close() conn.Close()
return return
} }
if player.Conn.HasWebsocketConnection() {
conn.WriteMessage(msgType, []byte("player already connected"))
return
}
lobby.Game.SetWebsocketConnectionFor(ctx, player, conn) lobby.Game.SetWebsocketConnectionFor(ctx, player, conn)
log.Println("player after setting connection: ") log.Println("player after setting connection: ")
log.Println("id: ", player.Uuid) log.Println("id: ", player.Uuid)
log.Println("color: ", player.GetColor()) log.Println("color: ", player.GetColor())

View File

@ -117,7 +117,7 @@ func (game *Game) Handle() {
log.Println("Error marshalling 'colorDetermined' message for player 1", err) log.Println("Error marshalling 'colorDetermined' message for player 1", err)
return return
} }
game.currentTurnPlayer.writeMessage(invalidMoveMessage) game.currentTurnPlayer.writeMessage(string(invalidMoveMessage))
game.gameState = PlayerToMove game.gameState = PlayerToMove
continue continue
} }
@ -181,10 +181,10 @@ func (game Game) notifyPlayersAboutGameStart() error {
return err return err
} }
game.GetPlayer1().writeMessage(colorDeterminedPlayer1) game.GetPlayer1().writeMessage(string(colorDeterminedPlayer1))
game.GetPlayer1().SendBoardState(types.Move{}, game.board.PGN(), types.White) game.GetPlayer1().SendBoardState(types.Move{}, game.board.PGN(), types.White)
game.GetPlayer2().writeMessage(colorDeterminedPlayer2) game.GetPlayer2().writeMessage(string(colorDeterminedPlayer2))
game.GetPlayer2().SendBoardState(types.Move{}, game.board.PGN(), types.White) game.GetPlayer2().SendBoardState(types.Move{}, game.board.PGN(), types.White)
return nil return nil
} }
@ -223,3 +223,7 @@ func (game *Game) playerDisconnected(p *Player) {
func (game *Game) SetWebsocketConnectionFor(ctx context.Context, p *Player, ws *gorillaws.Conn) { func (game *Game) SetWebsocketConnectionFor(ctx context.Context, p *Player, ws *gorillaws.Conn) {
p.SetWebsocketConnectionAndSendBoardState(ctx, ws, &game.board) p.SetWebsocketConnectionAndSendBoardState(ctx, ws, &game.board)
} }
func (game *Game) SendBoardStateTo(p *Player) {
p.SendBoardState(game.board.getLastMove(), game.board.PGN(), game.board.colorToMove)
}

View File

@ -89,11 +89,8 @@ func (p *Player) SendBoardState(move types.Move, boardPosition string, turnColor
return err return err
} }
err = p.writeMessage(messageToSend) p.writeMessage(string(messageToSend))
if err != nil {
log.Println("Error during message writing:", err)
return err
}
return nil return nil
} }
@ -108,11 +105,8 @@ func (p *Player) SendMoveAndPosition(move types.Move, boardPosition string) erro
return err return err
} }
err = p.writeMessage(messageToSend) p.writeMessage(string(messageToSend))
if err != nil {
log.Println("Error during message writing:", err)
return err
}
return nil return nil
} }
@ -126,26 +120,20 @@ func (p *Player) SendGameEnded(reason GameEndedReason) error {
log.Println("Error while marshalling: ", err) log.Println("Error while marshalling: ", err)
return err return err
} }
err = p.writeMessage(messageToSend) p.writeMessage(string(messageToSend))
if err != nil {
log.Println("Error during message writing:", err)
return err
}
return nil return nil
} }
func (p *Player) writeMessage(msg []byte) error { func (p *Player) writeMessage(msg string) {
return p.Conn.Write(msg) p.Conn.Write(msg)
} }
func (p *Player) ReadMove() (types.Move, error) { func (p *Player) ReadMove() (types.Move, error) {
receivedMessage, err := p.readMessage() receivedMessage := p.readMessage()
if err != nil {
return types.Move{}, err
}
var msg api.WebsocketMessage var msg api.WebsocketMessage
err = json.Unmarshal(receivedMessage, &msg) err := json.Unmarshal(receivedMessage, &msg)
if err != nil { if err != nil {
return types.Move{}, err return types.Move{}, err
} }
@ -157,9 +145,9 @@ func (p *Player) ReadMove() (types.Move, error) {
return *msg.Move, nil return *msg.Move, nil
} }
func (p *Player) readMessage() ([]byte, error) { func (p *Player) readMessage() []byte {
msg, err := p.Conn.Read() msg := p.Conn.Read()
log.Printf("Reading message from %s: %s", p.color.String(), string(msg)) log.Printf("Reading message from %s: %s", p.color.String(), string(msg))
return msg, err return msg
} }

View File

@ -50,7 +50,7 @@ func (b *MessageBuffer) Insert(msg string) {
b.cond.Broadcast() b.cond.Broadcast()
} }
func (b *MessageBuffer) Get() (string, error) { func (b *MessageBuffer) Get() string {
b.cond.L.Lock() b.cond.L.Lock()
defer b.cond.L.Unlock() defer b.cond.L.Unlock()
@ -69,7 +69,7 @@ func (b *MessageBuffer) Get() (string, error) {
} }
b.getIndex = b.incrementAndWrapIndex(b.getIndex) b.getIndex = b.incrementAndWrapIndex(b.getIndex)
return msg.content, nil return msg.content
} }
func (b MessageBuffer) incrementAndWrapIndex(index int) int { func (b MessageBuffer) incrementAndWrapIndex(index int) int {

View File

@ -66,8 +66,7 @@ func Test_MessageBuffer_GetWaitsForFirstData(t *testing.T) {
buf.Insert("delayed-message") buf.Insert("delayed-message")
}() }()
msg, err := buf.Get() msg := buf.Get()
assert.NoError(t, err)
endTime := time.Now() endTime := time.Now()
@ -79,8 +78,7 @@ func Test_MessageBuffer_GetWaitsForNewData(t *testing.T) {
buf := newMessageBuffer(2) buf := newMessageBuffer(2)
buf.Insert("message-1") buf.Insert("message-1")
msg, err := buf.Get() msg := buf.Get()
assert.NoError(t, err)
assert.Equal(t, "message-1", msg) assert.Equal(t, "message-1", msg)
go func() { go func() {
@ -89,8 +87,7 @@ func Test_MessageBuffer_GetWaitsForNewData(t *testing.T) {
buf.Insert("delayed-message") buf.Insert("delayed-message")
}() }()
msg, err = buf.Get() msg = buf.Get()
assert.NoError(t, err)
assert.Equal(t, "delayed-message", msg) assert.Equal(t, "delayed-message", msg)
} }
@ -117,8 +114,7 @@ func Test_MessageBuffer_IndexesAreCorrectAfterOverwritingOldData(t *testing.T) {
}, },
buf.messages) buf.messages)
msg, err := buf.Get() msg := buf.Get()
assert.NoError(t, err)
assert.Equal(t, "message-2", msg) assert.Equal(t, "message-2", msg)
} }
@ -126,13 +122,11 @@ func Test_MessageBuffer_GetWaitsForNewDataIfOldOneWasAlreadyGotten(t *testing.T)
buf := newMessageBuffer(2) buf := newMessageBuffer(2)
buf.Insert(message1) buf.Insert(message1)
msg, err := buf.Get() msg := buf.Get()
assert.NoError(t, err)
assert.Equal(t, message1, msg) assert.Equal(t, message1, msg)
buf.Insert(message2) buf.Insert(message2)
msg, err = buf.Get() msg = buf.Get()
assert.NoError(t, err)
assert.Equal(t, message2, msg) assert.Equal(t, message2, msg)
go func() { go func() {
@ -140,8 +134,7 @@ func Test_MessageBuffer_GetWaitsForNewDataIfOldOneWasAlreadyGotten(t *testing.T)
buf.Insert(message3) buf.Insert(message3)
}() }()
msg, err = buf.Get() msg = buf.Get()
assert.NoError(t, err)
assert.Equal(t, message3, msg) assert.Equal(t, message3, msg)
} }
@ -157,9 +150,8 @@ func Test_MessageBuffer_InsertCatchesUpWithRead(t *testing.T) {
buf.Insert(message6) buf.Insert(message6)
buf.Insert(message7) buf.Insert(message7)
msg, err := buf.Get() msg := buf.Get()
assert.NoError(t, err)
assert.Equal(t, message3, msg) assert.Equal(t, message3, msg)
} }
@ -172,7 +164,7 @@ func Test_MessageBuffer_FuckShitUp(t *testing.T) {
var readMsg = make([]string, 0) var readMsg = make([]string, 0)
go func() { go func() {
for i := 0; i < size*10; i++ { for i := 0; i < size*10; i++ {
msg, _ := buf.Get() msg := buf.Get()
if msg == "99" { if msg == "99" {
break break
} }

View File

@ -4,28 +4,26 @@ import (
"context" "context"
"log" "log"
"mchess_server/types" "mchess_server/types"
"sync"
"github.com/google/uuid" "github.com/google/uuid"
gorillaws "github.com/gorilla/websocket" gorillaws "github.com/gorilla/websocket"
) )
type Connection struct { type Connection struct {
ID uuid.UUID ID uuid.UUID
ws *gorillaws.Conn ws *gorillaws.Conn
wsConnectionEstablished chan bool ctx context.Context
wsWriteLock sync.Mutex rxBuffer *MessageBuffer
ctx context.Context txBuffer *MessageBuffer
buffer MessageBuffer disconnectCallback func()
disconnectCallback func() forColor types.ChessColor
forColor types.ChessColor
} }
func NewConnection(options ...func(*Connection)) *Connection { func NewConnection(options ...func(*Connection)) *Connection {
connection := Connection{ connection := Connection{
ID: uuid.New(), ID: uuid.New(),
buffer: *newMessageBuffer(100), rxBuffer: newMessageBuffer(100),
wsConnectionEstablished: make(chan bool), txBuffer: newMessageBuffer(100),
} }
for _, option := range options { for _, option := range options {
@ -67,6 +65,35 @@ func (conn *Connection) HasWebsocketConnection() bool {
return conn.ws != nil return conn.ws != nil
} }
func (conn *Connection) readFromRxBuffer() {
for {
_, msg, err := conn.ws.ReadMessage()
if err != nil {
conn.logConnection("while reading from websocket: %w", err)
conn.Close("")
conn.txBuffer.Insert("we do this to make txBuffer.Get() return")
return
}
conn.rxBuffer.Insert(string(msg))
}
}
func (conn *Connection) writeTxBuffer() {
for {
msg := conn.txBuffer.Get()
if conn.ws == nil {
return
}
err := conn.ws.WriteMessage(gorillaws.TextMessage, []byte(msg))
if err != nil {
conn.logConnection("while writing to websocket: %w", err)
return
}
}
}
func (conn *Connection) SetWebsocketConnection(ws *gorillaws.Conn) { func (conn *Connection) SetWebsocketConnection(ws *gorillaws.Conn) {
if ws == nil { if ws == nil {
conn.logConnection("ERROR: setting ws = null") conn.logConnection("ERROR: setting ws = null")
@ -75,28 +102,9 @@ func (conn *Connection) SetWebsocketConnection(ws *gorillaws.Conn) {
conn.ws = ws conn.ws = ws
select { go conn.readFromRxBuffer()
case conn.wsConnectionEstablished <- true: go conn.writeTxBuffer()
conn.logConnection("case wsConnectionEstablished <- true")
default:
conn.logConnection("DEFAULT CASE")
}
go func() {
for {
_, msg, err := conn.ws.ReadMessage()
if err != nil {
conn.logConnection("while reading from websocket: %w", err)
conn.unsetWebsocketConnection()
if conn.disconnectCallback != nil {
conn.disconnectCallback()
}
return
}
conn.buffer.Insert(string(msg))
}
}()
defer conn.logConnection("websocket connection set") defer conn.logConnection("websocket connection set")
} }
@ -105,30 +113,15 @@ func (conn *Connection) unsetWebsocketConnection() {
conn.ws = nil conn.ws = nil
} }
func (conn *Connection) Write(msg []byte) error { func (conn *Connection) Write(msg string) {
conn.logConnection("about to write") conn.logConnection("Writing message: ", string(msg))
conn.logConnection("locking") conn.txBuffer.Insert(msg)
conn.wsWriteLock.Lock()
defer conn.logConnection("unlocking")
defer conn.wsWriteLock.Unlock()
if conn.ws == nil { //if ws is not yet set, we wait for it
conn.logConnection("waiting for wsConnectionEstablished channel")
<-conn.wsConnectionEstablished
}
conn.logConnection("Writing message: %s", string(msg))
return conn.ws.WriteMessage(gorillaws.TextMessage, msg)
} }
func (conn *Connection) Read() ([]byte, error) { func (conn *Connection) Read() []byte {
msg, err := conn.buffer.Get() msg := conn.rxBuffer.Get()
if err != nil {
conn.ws = nil
return nil, err // TODO: Tell game-handler that connection was lost
}
return []byte(msg), err return []byte(msg)
} }
func (conn *Connection) Close(msg string) { func (conn *Connection) Close(msg string) {

View File

@ -22,21 +22,15 @@ func GetUsher() *Usher {
} }
func (u *Usher) WelcomeNewPlayer(player *chess.Player) *Lobby { func (u *Usher) WelcomeNewPlayer(player *chess.Player) *Lobby {
lobby := GetLobbyRegistry().GetLobbyForPlayer() return GetLobbyRegistry().GetLobbyForPlayer()
return lobby
} }
func (u *Usher) CreateNewPrivateLobby(player *chess.Player) *Lobby { func (u *Usher) CreateNewPrivateLobby(player *chess.Player) *Lobby {
lobby := GetLobbyRegistry().CreateNewPrivateLobby() return GetLobbyRegistry().CreateNewPrivateLobby()
return lobby
} }
func (u *Usher) FindExistingPrivateLobby(p utils.Passphrase) *Lobby { func (u *Usher) FindExistingPrivateLobby(p utils.Passphrase) *Lobby {
lobby := GetLobbyRegistry().GetLobbyByPassphrase(p) return GetLobbyRegistry().GetLobbyByPassphrase(p)
if lobby == nil || lobby.AreBothPlayersConnected() {
return nil
}
return lobby
} }
func (u *Usher) AddPlayerToLobbyAndStartGameIfFull(player *chess.Player, lobby *Lobby) { func (u *Usher) AddPlayerToLobbyAndStartGameIfFull(player *chess.Player, lobby *Lobby) {