Remove reference counters in the concurrent doubly-linked list used in BufferedChannel and Semaphore #4302

Draft
wants to merge 15 commits into
base: develop

@de-shyt de-shyt commented Dec 18, 2024

Old Implementation

In the old implementation counters were used to track how many pointers reference each segment. When a pointer moves to the next live segment, the counter of the current segment decreases, and the counter of the next segment increases.

A segment becomes logically removed under two conditions:

  1. All cells are interrupted.
  2. There are no pointers on the segment (counter = 0).

Physical removal happens after all pointers referencing the segment have moved forward.

New Implementation

In the new implementation, counters are no longer used. Logical removal now depends on a single condition: all cells are in the interrupted state.

When physical removal occurs, pointers referencing the removed segment must be moved to the next live segment. This is always possible because the tail of the list is never marked as removed.
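
The two removal conditions can be contrasted with a minimal toy model. This is only a sketch: `ToySegment`, its fields, and the segment size are hypothetical names chosen for illustration, and the real implementation tracks this state atomically.

```kotlin
// Toy model: all names and the segment size are assumptions for illustration only.
const val SEGMENT_SIZE = 4

class ToySegment(val isTail: Boolean = false) {
    var interruptedCells = 0 // cells whose request was cancelled
    var pointers = 0         // only meaningful in the old scheme

    // Old scheme: every cell interrupted AND no channel pointer on the segment.
    val isRemovedOld: Boolean
        get() = interruptedCells == SEGMENT_SIZE && pointers == 0 && !isTail

    // New scheme: every cell interrupted; pointer positions are irrelevant.
    val isRemovedNew: Boolean
        get() = interruptedCells == SEGMENT_SIZE && !isTail
}

fun main() {
    val s = ToySegment()
    s.interruptedCells = SEGMENT_SIZE
    s.pointers = 1 // e.g. receiveSegment still points here
    println(s.isRemovedOld) // false
    println(s.isRemovedNew) // true
}
```

In the new scheme the segment counts as removed even though a pointer still references it, which is why that pointer has to be moved forward explicitly.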

Methods

remove + movePointersForwardFromRemovedSegment

In the old implementation, all pointers moved to another segment first, and only then could the segment become logically removed.

In the new implementation, removal does not depend on the position of pointers: a segment can be physically removed while pointers remain on it. They need to be manually moved to prevent memory leaks.

    override fun remove() {
        super.remove()
+       channel.movePointersForwardFromRemovedSegment(this)
    }

First, the remove method of the base class is called, which updates the links between neighbouring segments, removing the current segment from the list. Then, movePointersForwardFromRemovedSegment is called, which moves pointers from the removed segment to the nearest live segment on the right:

+   internal fun movePointersForwardFromRemovedSegment(from: ChannelSegment<E>) {
+       if (!from.isRemoved) return
+       if (from == sendSegment.value) sendSegment.moveToSpecifiedOrLast(from.id, from)
+       if (from == receiveSegment.value) receiveSegment.moveToSpecifiedOrLast(from.id, from)
+       if (from == bufferEndSegment.value) bufferEndSegment.moveToSpecifiedOrLast(from.id, from)
+   }
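
The two steps (unlink, then advance the pointers) can be sketched on a single-threaded toy list. `Node`, `ToyChannel`, and `firstLiveAfter` are hypothetical names; the PR's code works on atomic references and calls `moveToSpecifiedOrLast` instead.

```kotlin
// Single-threaded toy model; all names are hypothetical, and the real code uses atomics.
class Node(val id: Long) {
    var prev: Node? = null
    var next: Node? = null
    var removed = false
}

class ToyChannel(first: Node) {
    var sendSegment = first
    var receiveSegment = first

    // Unlink `n`, then advance any pointer that still references it.
    fun remove(n: Node) {
        require(n.removed && n.next != null) { "the tail is never removed" }
        n.prev?.next = n.next
        n.next?.prev = n.prev
        movePointersForward(n)
    }

    private fun movePointersForward(from: Node) {
        if (sendSegment === from) sendSegment = firstLiveAfter(from)
        if (receiveSegment === from) receiveSegment = firstLiveAfter(from)
    }

    // Terminates because the tail is never marked as removed.
    private fun firstLiveAfter(from: Node): Node {
        var cur = from
        while (cur.removed) cur = cur.next ?: break
        return cur
    }
}

fun main() {
    val a = Node(0); val b = Node(1); val c = Node(2)
    a.next = b; b.prev = a; b.next = c; c.prev = b
    val ch = ToyChannel(a)
    ch.sendSegment = b
    b.removed = true
    ch.remove(b)
    println(ch.sendSegment.id)            // 2
    println(ch.receiveSegment.id)         // 0
    println(a.next === c && c.prev === a) // true
}
```

Without the pointer move, `sendSegment` would keep the unlinked node reachable, which is the memory leak mentioned above.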

isLeftmostOrProcessed

In the old implementation, the cleanPrev method was triggered based on the condition segment.id * SEGMENT_SIZE < channel.sendersCounter (or segment.id * SEGMENT_SIZE < channel.receiversCounter) within the request. If this condition was met, the prev link was cleared.

In the new implementation, comparison with counters is incorrect. When remove is called, pointers move from the removed segment to the nearest live segment on the right (regardless of which cell contains the last request). As a result, on the new segment, the value id * SEGMENT_SIZE may become greater than the counter value. The segment would not be considered the leftmost one, even if both sendSegment and receiveSegment are already on it.

Instead of comparing with counters, the implementation now compares the id values of the pointers:

    internal val sendSegmentId: Long get() = sendSegment.id
    internal val receiveSegmentId: Long get() = receiveSegment.id
    
    override val isLeftmostOrProcessed: Boolean get() = 
        id <= channel.sendSegmentId && id <= channel.receiveSegmentId

If the isLeftmostOrProcessed condition is met, it means all previous segments have been processed. They are no longer needed in the list and should become inaccessible.
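
A small worked example shows why the counter-based check can fail after pointers are moved off a removed segment. The segment size and counter values are hypothetical; `oldCheck` and `newCheck` are illustrative stand-ins for the two conditions quoted above.

```kotlin
const val SEGMENT_SIZE = 2 // hypothetical value chosen for the example

// Old check: compare the index of the segment's first cell with a counter.
fun oldCheck(segmentId: Long, counter: Long) = segmentId * SEGMENT_SIZE < counter

// New check, mirroring isLeftmostOrProcessed: no pointer is to the left of the segment.
fun newCheck(segmentId: Long, sendSegmentId: Long, receiveSegmentId: Long) =
    segmentId <= sendSegmentId && segmentId <= receiveSegmentId

fun main() {
    // Segment #1 was removed, so both pointers were moved to segment #2,
    // but the counters still refer to cells of segment #1.
    val sendersCounter = 3L
    val receiversCounter = 3L
    println(oldCheck(2, sendersCounter) && oldCheck(2, receiversCounter)) // false
    println(newCheck(2, sendSegmentId = 2, receiveSegmentId = 2))         // true
}
```

Here `2 * SEGMENT_SIZE = 4` is not less than `3`, so the old check rejects segment #2 even though both pointers already sit on it.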

cleanPrev invocations

In the old implementation, a segment was removed only when no pointers referenced it. This ensured that cleanPrev was called before remove on the leftmost segment. remove could not be called first: while the segment was the leftmost one, a pointer was still referencing it. The algorithm would first reach the end of the branch, where cleanPrev would be called if necessary. Only then, in a new request, would moveForward be called, advancing the pointer and triggering remove if needed.

In the new implementation, there is no guarantee on the order of cleanPrev and remove. Once a segment is marked as logically removed, a pointer skips over it, so the cleanPrev call located in the request branch is never reached. For example, the following scenario fails:

                Thread 1                │                 Thread 2                               
                                        │                                                        
                                        │                                                        
send(2): buffered                       │                                                        
                                        │                                                        
send(2): suspend + cancel               │                                                        
                                        │                                                        
receive(): 2                            │                                                        
  r=0, id=0                             │                                                        
  segm=#1                               │                                                        
  state: buffered->done_rcv             │                                                        
  expandBuffer():                       │                                                        
    b=1, s=2, s<=b: false               │                                                        
    segm=findSegmentBufferEnd(1,#1): #2 │                                                        
      #1.findSegmentInternal(id=1): #2  │                                                        
      EB.moveForward(#2): true          │                                                        
      return #2                         │                                                        
    state: Coroutine->resuming_eb       │                                                        
    tryResume(): false                  │                                                        
    state: resuming_eb->int_send        │                                                        
    #2.onCancelledRequest():            │                                                        
      cleanedSlots.incAndGet()          │                                                        
      isRemoved: true                   │                                                        
      remove():                         │                                                        
                                        │ receive(): suspend                                      
                                        │   r=1, id=1                                             
                                        │   segm=findSegmentReceive(1,#1): null                  
                                        │     #1.findSegmentInternal(id=1): #3                   
                                        │       cur=#1                                            
                                        │       #1.id<1: true, cur=#2                            
                                        │       #2.id<1: false, #2.isRemoved: true, #2.next: null
                                        │       #2.trySetNext(#3): true, #2.isRemoved: true      
                                        │         #2.remove():                                   
                                        │           prev=#1, next=#3                             
                                        │           #3.prev=#1                                   
                                        │           #1.next=#3                                   
                                        │       cur=#3                                           
                                        │       #3.id<1: false, #3.isRemoved: false              
                                        │       return #3                                        
                                        │     R.moveForward(#3): true                                 
                                        │     return null                                        
                                        │   r=2, id=2                                            
                                        │   segm=#3                                              
                                        │   state: null->Coroutine                               
                                        │   expandBuffer():                                     
                                        │     b=2, s=2, s<=b: true                              
                                        │     EB.moveToSpecifiedOrLast(2,#2):                    
                                        │       #2.findSpecifiedOrLast(id=2): #3                 
                                        │       EB.moveForward(#3): true 
                                        │         #2.isRemoved: true
                                        │         #2.remove():
                                        │           prev=#1, next=#3
                                        │           #3.next=#1
                                        │           #1.next=#3
        prev=#1, next=#3                │
        #3.next=#1                      │
        #1.next=#3                      │
    b=3, s=2, s<=b: true                │
    segm=#3, #3.id<3: true              │
  #1.cleanPrev()                        │
  return 2                              │
Full stacktrace

```
                      Thread 1                      │                      Thread 2
                                                    │
send(2):                                            │
  s=0, id=0                                         │
  segm=#1                                           │
  state: null->buffered                             │
                                                    │
send(2): suspend + cancel                           │
  s=1, id=1                                         │
  segm=findSegmentSend(1,#1): #2                    │
    #1.findSegmentInternal(id=1): #2                │
      cur=#1                                        │
      #1.id<1: true, #1.next: null                  │
      #1.trySetNext(#2): true, #1.isRemoved: false  │
      cur=#2                                        │
      #2.id<1: false, #2.isRemoved: false           │
      return #2                                     │
    S.cas(#1,#2): true                              │
    return #2                                       │
  state: null->Coroutine                            │
  Coroutine cancelled                               │
                                                    │
receive(): 2                                        │
  r=0, id=0                                         │
  segm=#1                                           │
  state: buffered->done_rcv                         │
  expandBuffer():                                   │
    b=1, s=2, s<=b: false                           │
    segm=findSegmentBufferEnd(1,#1): #2             │
      #1.findSegmentInternal(id=1): #2              │
        cur=#1                                      │
        #1.id<1: true, cur=#2                       │
        #2.id<1: false, #2.isRemoved: false         │
        return #2                                   │
      EB.cas(#1,#2): true                           │
      return #2                                     │
    state: Coroutine->resuming_eb                   │
    tryResume(): false                              │
    state: resuming_eb->int_send                    │
    #2.onCancelledRequest():                        │
      cleanedSlots.incAndGet()                      │
      isRemoved: true                               │
      remove():                                     │
                                                    │
                                                    │ receive(): suspend
                                                    │   r=1, id=1
                                                    │   segm=findSegmentReceive(1,#1): null
                                                    │     #1.findSegmentInternal(id=1): #3
                                                    │       cur=#1
                                                    │       #1.id<1: true, cur=#2
                                                    │       #2.id<1: false, #2.isRemoved: true, #2.next: null
                                                    │       #2.trySetNext(#3): true, #2.isRemoved: true
                                                    │         #2.remove():
                                                    │           prev=#1, next=#3
                                                    │           #3.prev=#1
                                                    │           #1.next=#3
                                                    │       cur=#3
                                                    │       #3.id<1: false, #3.isRemoved: false
                                                    │       return #3
                                                    │     R.cas(#1,#3): true
                                                    │     return null
                                                    │   r=2, id=2
                                                    │   segm=#3
                                                    │   state: null->Coroutine
                                                    │   expandBuffer():
                                                    │     b=2, s=2, s<=b: true
                                                    │     EB.moveToSpecifiedOrLast(2,#2):
                                                    │       #2.findSpecifiedOrLast(id=2): #3
                                                    │         cur=#2
                                                    │         #2.id<2: true, cur=#3
                                                    │         #3.id<2: false, break
                                                    │         return #3
                                                    │       EB.cas(#2,#3): true
        prev=#1, next=#3                            │
        #3.next=#1                                  │
        #1.next=#3                                  │
    b=3, s=2, s<=b: true                            │
    segm=#3, #3.id<3: true                          │
    BE.moveToSpecifiedOrLast(3,#3):                 │
      #3.findSpecifiedOrLast(id=3): #3              │
        cur=#3                                      │
        #3.id<3: true, #3.next: null, break         │
        return #3                                   │
      EB.cas(#3, #3): true                          │
  #1.cleanPrev()                                    │
  return 2                                          │
```

Final state:
                              S                              
                              │                   BE R       
                              ▼                   │  │       
                         ChannelSegm#2            ▼  ▼       
   ChannelSegm#1           REMOVED             ChannelSegm#3 
       0                     1                     2         
   ┌────────┐            ┌────────┐            ┌────────┐    
   │        │  ◄───────  │        │  ───────►  │        │    
   │        │            │        │            │        │    
   └────────┘            └────────┘            └────────┘    
    done_rcv              int_send              Coroutine    
      ▲  │                                        ▲ │        
      │  └────────────────────────────────────────┘ │        
      └─────────────────────────────────────────────┘        

cleanPrev was not invoked on segment #2 because receiveSegment never reached #2: it skipped the segment as a logically removed one. As a result, the prev reference of the leftmost segment was not null during validate().

Solution:

The existing cleanPrev invocations in the algorithm's branches were insufficient because remove could bypass these calls and set an already processed segment in this.next.prev.

Instead of adding more calls to cleanPrev somewhere in the algorithm, the decision was made to clean prev references inside moveForward, thus encapsulating cleanPrev logic in one place. When a successful cas(from, to) occurs, the algorithm starts from to, follows prev references and looks for the leftmost segment that no pointers reference. Once found, its prev reference is cleaned, and the moveForward call completes.

internal inline fun <S : Segment<S>> AtomicRef<S>.moveForward(to: S): Boolean = loop { cur ->
    if (cur.id >= to.id) return true
    if (to.isRemoved) return false
    if (compareAndSet(cur, to)) {
        if (to.isRemoved) return false
+       cleanLeftmostPrev(cur, to)
        return true
    }
}

private inline fun <S : Segment<S>> cleanLeftmostPrev(from: S, to: S) {
    var cur = to
    // Find the leftmost segment on the sublist between `from` and `to` segments.
    while (!cur.isLeftmostOrProcessed && cur.id > from.id) {
        cur = cur.prev ?:
            // The `prev` reference was cleaned in parallel.
            return
    }
    if (cur.isLeftmostOrProcessed) cur.cleanPrev() // The leftmost segment is found
}

If at any iteration prev is null for a segment that did not pass the "no pointers to the left" check, it means a parallel moveForward call has already cleaned the prev reference.
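
The walk can be traced on a toy model built on `java.util.concurrent.atomic.AtomicReference`. This is a simplified sketch, not the PR's code: `Seg` and its `leftmostOrProcessed` flag are hypothetical stand-ins for the real segment class and its `isLeftmostOrProcessed` property, and the example runs on a single thread.

```kotlin
import java.util.concurrent.atomic.AtomicReference

// Toy segment; `leftmostOrProcessed` stands in for the real isLeftmostOrProcessed property.
class Seg(val id: Long, var prev: Seg? = null) {
    var leftmostOrProcessed = false
    var removed = false
    fun cleanPrev() { prev = null }
}

fun cleanLeftmostPrev(from: Seg, to: Seg) {
    var cur = to
    // Walk left from `to` until a segment with no pointers to its left is found.
    while (!cur.leftmostOrProcessed && cur.id > from.id) {
        cur = cur.prev ?: return // prev was already cleaned elsewhere
    }
    if (cur.leftmostOrProcessed) cur.cleanPrev()
}

fun AtomicReference<Seg>.moveForward(to: Seg): Boolean {
    while (true) {
        val cur = get()
        if (cur.id >= to.id) return true
        if (to.removed) return false
        if (compareAndSet(cur, to)) {
            if (to.removed) return false
            cleanLeftmostPrev(cur, to)
            return true
        }
    }
}

fun main() {
    val s0 = Seg(0); val s1 = Seg(1, prev = s0); val s2 = Seg(2, prev = s1)
    s1.leftmostOrProcessed = true // e.g. the other pointer sits on s1
    val pointer = AtomicReference(s0)
    println(pointer.moveForward(s2)) // true
    println(s1.prev)                 // null
    println(s2.prev === s1)          // true
}
```

The walk starts at `s2`, steps left to `s1`, finds that nothing points further left, and cuts only the `s1.prev` link, so `s0` becomes unreachable while `s1` and `s2` stay linked.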

moveForward

    internal inline fun <S : Segment<S>> AtomicRef<S>.moveForward(to: S): Boolean = loop { cur ->
        if (cur.id >= to.id) return true // No need to update the pointer
        if (to.isRemoved) return false // Do not move the pointer to a logically removed segment
        if (compareAndSet(cur, to)) { // The pointer is moved
+           if (to.isRemoved) return false // The segment was removed in parallel with the CAS
+           cleanLeftmostPrev(cur, to)
            return true
        }
    }

First change

Without the isRemoved re-check after the CAS, the pointer may be moved to a segment that has already been physically removed, resulting in a memory leak. Example of a failing scenario:

           Thread 1                        │                 Thread 2                
                                           │                                         
                                           │                                         
send(2): buffered                          │                                         
                                           │                                         
send(2): suspend + cancel                  │                                         
                                           │                                         
receive():                                 │                                         
  r=0, id=0                                │                                         
  segm=ChannelSegm#1                       │                                         
  state: buffered->done_rcv                │                                         
  expandBuffer():                          │                                         
    b=1, id=1                              │                                         
    segm=findSegmentEB(1,#1): #2           │                                         
      #1.findSegment(id=1): #2             │                                         
      BE.moveForward(#2): true             │                                         
        #1.id>=#2.id: false                │                                         
        #2.isRemoved: false                │                                         
    state: Coroutine->resuming_eb          │                                         
    tryResume(): false                     │                                         
    state: resuming_eb->int_send           │                                         
    onCancelledRequest():                  │                                         
       waitEBcompletion()                  │                                         
       onSlotCleaned():                    │                                         
                                           │  receive():                             
                                           │    r=1, id=1                            
                                           │    segm=findSegmentReceive(1,#1): #2
                                           │      #1.findSegment(id=1): #2           
                                           │      R.moveForward(#2): true            
                                           │        #1.id>=#2.id: false              
                                           │        #2.isRemoved: false              
         cleanedSlots.incrementAndGet()    │                                         
         #2.isRemoved(): false (tail)      │                                         
                                           │                                         
send(2):                                   │                                         
  s=2, id=2                                │                                         
  segm=findSegmentSend(2,#2): #3           │                                         
    findSegment(id=2): #3                  │                                         
      #2.trySetNext(#3): true (new tail)   │                                         
      #2.isRemoved: true                   │                                         
      #2.remove():                         │                                         
        prev=#1, next=#3                   │                                         
        prev._next=#3                      │                                         
        next._prev=#1                      │                                         
        movePointersForwardFrom(#2):       │                                         
          S.cas(#2,#3): true               │                                         
          BE.cas(#2,#3): true              │                                         
    S.moveForward(#3):                     │                                         
      #3.id>=#3.id: true, return           │                                         
  state: null->buffered                    │                                         
                                           │        cas(#1,#2): true                 
                                           │        cleanLeftmostPrev(#1,#2):        
                                           │           #2.isLeftmostOrProcessed: true
                                           │           #2.cleanPrev()

Final state:

                               R                                                     
                               │                   BE S                              
                               ▼                   │  │                              
                          ChannelSegm#2            ▼  ▼                              
    ChannelSegm#1           REMOVED             ChannelSegm#3                        
        0                     1                     2                                
    ┌────────┐            ┌────────┐            ┌────────┐                           
    │        │            │        │  ───────►  │        │                           
    │        │            │        │            │        │                           
    └────────┘            └────────┘            └────────┘                           
     done_rcv              int_send                ▲ │
       ▲  │                                        │ │                               
       │  └────────────────────────────────────────┘ │                               
       └─────────────────────────────────────────────┘                               

  1. In the final state, the leftmost segment is the one that was physically removed, yet it is still reachable during validate().
  2. The doubly-linked list invariant is violated between segments #1 and #2.

Second change

Described in the section "`cleanPrev` invocations".

moveToSpecifiedOrLast + findSpecifiedOrLast

internal inline fun <S : Segment<S>> AtomicRef<S>.moveToSpecifiedOrLast(id: Long, startFrom: S) {
    // Start searching the required segment from the specified one.
    var s = startFrom.findSpecifiedOrLast(id)
    // Skip all removed segments and try to update the channel pointer to the first non-removed one.
    // This part should succeed eventually, as the tail segment is never removed.
    while (true) {
        while (s.isRemoved) {
            s = s.next ?: break
        }
        // Try to update the value of `AtomicRef`.
        // On failure, the found segment is already removed, so it should be skipped.
        if (moveForward(s)) return
    }
}

moveToSpecifiedOrLast follows the same logic as moveForward, but it first calls findSpecifiedOrLast, a method that returns the segment with the requested id, or the tail if such a segment has not been created yet. In contrast with findSegmentInternal, findSpecifiedOrLast does not add new segments to the segment list.
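
A minimal sketch of the search, assuming a plain singly-traversed toy list: the `Seg` class here is hypothetical, and the loop shape is inferred from the `cur`/`break` steps visible in the traces above rather than copied from the PR.

```kotlin
// Toy list node; everything except the name findSpecifiedOrLast is hypothetical.
class Seg(val id: Long) {
    var next: Seg? = null
}

// Walk right until a segment with the requested id (or a larger one) is found,
// or the tail is reached. Unlike a "find or create" search, nothing is appended.
fun Seg.findSpecifiedOrLast(id: Long): Seg {
    var cur = this
    while (cur.id < id) {
        cur = cur.next ?: break
    }
    return cur
}

fun main() {
    val s0 = Seg(0); val s1 = Seg(1); val s2 = Seg(2)
    s0.next = s1; s1.next = s2
    println(s0.findSpecifiedOrLast(1).id) // 1
    println(s0.findSpecifiedOrLast(7).id) // 2 (the tail; no new segment is created)
    println(s2.next)                      // null
}
```

Returning the tail instead of allocating keeps the pointer-moving path free of side effects on the list itself.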

@ndkoval ndkoval marked this pull request as draft December 18, 2024 13:35
@de-shyt de-shyt force-pushed the channels-remove-counters branch 2 times, most recently from 04522a4 to 7a9906b Compare January 28, 2025 13:06
@ndkoval ndkoval changed the title Removal of counters in Kotlin Channels Remove reference counters in the concurrent doubly-linked list used in BufferedChannel and Semaphore Jan 29, 2025