Skip to content

Commit 75289c3

Browse files
authored
fix: protect against race condition on shutdown in muxer (#712)
Fixes #710
1 parent 3436922 commit 75289c3

File tree

2 files changed

+17
-1
lines changed

2 files changed

+17
-1
lines changed

muxer/muxer.go

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ type Muxer struct {
6161
startChan chan bool
6262
doneChan chan bool
6363
waitGroup sync.WaitGroup
64+
waitGroupMutex sync.Mutex
6465
protocolSenders map[uint16]map[ProtocolRole]chan *Segment
6566
protocolReceivers map[uint16]map[ProtocolRole]chan *Segment
6667
protocolReceiversMutex sync.Mutex
@@ -89,7 +90,9 @@ func New(conn net.Conn) *Muxer {
8990
// We must do this to break out of pending Read() calls to shut down cleanly
9091
_ = m.conn.Close()
9192
// Wait for other goroutines to shutdown
93+
m.waitGroupMutex.Lock()
9294
m.waitGroup.Wait()
95+
m.waitGroupMutex.Unlock()
9396
// Close ErrorChan to signify to consumer that we're shutting down
9497
close(m.errorChan)
9598
}()
@@ -136,11 +139,20 @@ func (m *Muxer) sendError(err error) {
136139
}
137140

138141
// RegisterProtocol registers the provided protocol ID with the muxer. It returns a channel for sending,
139-
// a channel for receiving, and a channel to know when the muxer is shutting down
142+
// a channel for receiving, and a channel to know when the muxer is shutting down. If the muxer is shutting
143+
// down, this function will return nil values.
140144
func (m *Muxer) RegisterProtocol(
141145
protocolId uint16,
142146
protocolRole ProtocolRole,
143147
) (chan *Segment, chan *Segment, chan bool) {
148+
m.waitGroupMutex.Lock()
149+
defer m.waitGroupMutex.Unlock()
150+
// Check for shutdown
151+
select {
152+
case <-m.doneChan:
153+
return nil, nil, nil
154+
default:
155+
}
144156
// Generate channels
145157
senderChan := make(chan *Segment, 10)
146158
receiverChan := make(chan *Segment, 10)

protocol/protocol.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,10 @@ func (p *Protocol) Start() {
126126
p.config.ProtocolId,
127127
muxerProtocolRole,
128128
)
129+
if p.muxerDoneChan == nil {
130+
p.SendError(fmt.Errorf("could not register protocol with muxer"))
131+
return
132+
}
129133

130134
// Create channels
131135
p.sendQueueChan = make(chan Message, 50)

0 commit comments

Comments
 (0)