mchess-server/chess/pawn.go

193 lines
5.2 KiB
Go

package chess
import (
"mchess_server/types"
"github.com/samber/lo"
)
// Pawn is the pawn chess piece. Its only state is the side it belongs to,
// which decides whether it advances Up (White) or Down (Black).
type Pawn struct {
	Color types.ChessColor // side the pawn plays for
}
// GetAllAttackedSquares returns the squares the pawn on fromSquare can
// capture on: the diagonal entries of GetAllNonBlockedSquares, i.e. only
// diagonal squares that currently hold an opponent's piece. The returned
// slice is never nil.
//
// NOTE(review): an empty diagonal square, or one holding a friendly piece,
// is not reported even though the pawn controls it. Confirm that callers
// (e.g. king-safety or castling checks) do not expect the unconditional
// diagonals.
func (p Pawn) GetAllAttackedSquares(board Board, fromSquare types.Coordinate) []types.Coordinate {
	candidates := p.GetAllNonBlockedSquares(board, fromSquare)
	captures := make([]types.Coordinate, 0, 2)
	for i := range candidates {
		// Forward pushes stay in the pawn's own column; only diagonals capture.
		if candidates[i].Col == fromSquare.Col {
			continue
		}
		captures = append(captures, candidates[i])
	}
	return captures
}
// GetAllNonBlockedSquares returns the squares the pawn standing on
// fromSquare may move to on board. These are the candidate squares from
// getAllMoves that survive filterBlockedSquares: forward squares must be
// empty, and diagonal squares must hold an opponent's piece. En passant
// captures land on an empty diagonal and are therefore not included here.
func (p Pawn) GetAllNonBlockedSquares(board Board, fromSquare types.Coordinate) []types.Coordinate {
	return p.filterBlockedSquares(board, fromSquare, p.getAllMoves(fromSquare))
}
// GetColor reports which side the pawn belongs to.
func (p Pawn) GetColor() types.ChessColor { return p.Color }
// HandlePossiblePromotion checks whether move takes a pawn of
// move.ColorMoved from the second-to-last rank onto the last rank. For
// White that is RangeLastValid-1 -> RangeLastValid; for Black it is
// RangeFirstValid+1 -> RangeFirstValid. If so, the pawn on
// move.StartSquare is removed and the piece named by
// move.PromotionToPiece is placed on move.EndSquare.
//
// It reports whether move is a promotion move. If it is, but move carries
// no promotion piece, the board is left unchanged and a non-empty
// Violation is returned. Otherwise the Violation is empty.
//
// NOTE(review): only rows and color are inspected. Callers are assumed to
// invoke this for pawn moves only.
func (p *Pawn) HandlePossiblePromotion(b *Board, move types.Move) (bool, Violation) {
	var isPromotionMove bool
	switch move.ColorMoved {
	case types.White:
		isPromotionMove = move.StartSquare.Row == types.RangeLastValid-1 &&
			move.EndSquare.Row == types.RangeLastValid
	case types.Black:
		isPromotionMove = move.StartSquare.Row == types.RangeFirstValid+1 &&
			move.EndSquare.Row == types.RangeFirstValid
	}
	if !isPromotionMove {
		return false, ""
	}
	if !move.IsPromotionMove() {
		// Without this guard the pawn would be replaced by whatever
		// GetPieceForShortName yields for an empty name, corrupting the board.
		return true, Violation("promotion move does not specify a piece to promote to")
	}
	promotionToPiece := types.PieceShortName(*move.PromotionToPiece)
	delete(b.position, move.StartSquare)
	b.position[move.EndSquare] = GetPieceForShortName(promotionToPiece)
	return true, ""
}
// HandleEnPassant checks whether move is an en passant capture answering
// lastMove and, if so, plays it on b. The opponent's pawn on
// lastMove.EndSquare is removed, and the moving piece goes from
// move.StartSquare to move.EndSquare. It reports whether the move was
// en passant; the returned Violation is always empty.
//
// The move qualifies when all of these hold:
//   - lastMove was a pawn double step from the opponent's starting rank;
//   - the moving pawn stands on that same rank in an adjacent column;
//   - the moving pawn moves diagonally onto the square the opponent's pawn
//     skipped.
//
// Ranks are derived from types.RangeFirstValid/RangeLastValid, the same way
// getAllMoves locates the pawns' starting ranks.
func (p *Pawn) HandleEnPassant(b *Board, move, lastMove types.Move) (bool, Violation) {
	if lastMove.PieceMoved.ToCommon() != types.PawnShortName {
		return false, ""
	}

	var ranksMatch bool
	switch move.ColorMoved {
	case types.White:
		// A black pawn double-stepped downwards; White captures one rank behind it.
		ranksMatch = lastMove.StartSquare.Row == types.RangeLastValid-1 &&
			lastMove.EndSquare.Row == types.RangeLastValid-3 &&
			move.StartSquare.Row == types.RangeLastValid-3 &&
			move.EndSquare.Row == types.RangeLastValid-2
	case types.Black:
		// A white pawn double-stepped upwards; Black captures one rank behind it.
		ranksMatch = lastMove.StartSquare.Row == types.RangeFirstValid+1 &&
			lastMove.EndSquare.Row == types.RangeFirstValid+3 &&
			move.StartSquare.Row == types.RangeFirstValid+3 &&
			move.EndSquare.Row == types.RangeFirstValid+2
	}

	capturedCol := lastMove.EndSquare.Col
	fromAdjacentCol := move.StartSquare.Col == capturedCol+1 ||
		move.StartSquare.Col == capturedCol-1
	if !ranksMatch || !fromAdjacentCol || move.EndSquare.Col != capturedCol {
		return false, ""
	}

	delete(b.position, lastMove.EndSquare) // take the opponent's pawn
	delete(b.position, move.StartSquare)   // lift the capturing pawn
	b.position[move.EndSquare] = GetPieceForShortName(move.PieceMoved)
	return true, ""
}
// getAllMoves lists the candidate squares for a pawn of p.Color on
// fromSquare, ignoring every other piece on the board:
//   - one step forward;
//   - two steps forward when the pawn is still on its starting rank;
//   - the two forward diagonals.
//
// White advances with Up and Black with Down. Forward squares come
// nearest-first, which filterBlockedSquares relies on. Squares that
// Up/Down/Right/Left report as nil (off the board) are omitted. If there is
// no square ahead at all, the result is empty.
func (p Pawn) getAllMoves(fromSquare types.Coordinate) []types.Coordinate {
	theoreticalMoves := make([]types.Coordinate, 0, 4)

	var oneStep, twoSteps *types.Coordinate
	switch p.Color {
	case types.Black:
		oneStep = fromSquare.Down(1)
		if fromSquare.Row == types.RangeLastValid-1 {
			twoSteps = fromSquare.Down(2)
		}
	case types.White:
		oneStep = fromSquare.Up(1)
		if fromSquare.Row == types.RangeFirstValid+1 {
			twoSteps = fromSquare.Up(2)
		}
	}

	// Without a square ahead there are no diagonals either. Returning here
	// also keeps Right/Left from being called on a nil *Coordinate.
	if oneStep == nil {
		return theoreticalMoves
	}
	theoreticalMoves = append(theoreticalMoves, *oneStep)
	if twoSteps != nil {
		theoreticalMoves = append(theoreticalMoves, *twoSteps)
	}
	if right := oneStep.Right(1); right != nil {
		theoreticalMoves = append(theoreticalMoves, *right)
	}
	if left := oneStep.Left(1); left != nil {
		theoreticalMoves = append(theoreticalMoves, *left)
	}
	return theoreticalMoves
}
// filterBlockedSquares keeps the squares from squaresToBeFiltered that the
// pawn on fromSquare can actually reach on board. The rules are:
//   - a forward square (same column as fromSquare) must be empty, and so
//     must every forward square before it, so a pawn cannot jump over a
//     piece on its double step;
//   - a diagonal square must hold a piece whose color differs from p.Color.
//
// Forward squares must appear nearest-first in squaresToBeFiltered, as
// getAllMoves produces them. The relative order of the input is preserved,
// and the returned slice is never nil.
func (p Pawn) filterBlockedSquares(board Board, fromSquare types.Coordinate, squaresToBeFiltered []types.Coordinate) []types.Coordinate {
	forwardBlocked := false
	// lo.Filter visits the squares in slice order, so once the single step
	// is found occupied the double step behind it is rejected as well.
	return lo.Filter(squaresToBeFiltered, func(square types.Coordinate, _ int) bool {
		pieceAtSquare := board.getPieceAt(square)
		if square.Col == fromSquare.Col { // squares ahead
			if pieceAtSquare != nil {
				forwardBlocked = true
			}
			return !forwardBlocked
		}
		// Diagonal squares are only reachable by capturing an opponent's piece.
		return pieceAtSquare != nil && pieceAtSquare.GetColor() != p.Color
	})
}
// AfterMoveAction does nothing for a pawn. Its promotion and en passant
// handling live in HandlePossiblePromotion and HandleEnPassant.
func (p Pawn) AfterMoveAction(board *Board, fromSquare types.Coordinate) {
	// Intentionally empty.
}