源码剖析sync.WaitGroup(文末思考题你能解释一下吗?)

共 13050字,需浏览 27分钟

 ·

2021-06-13 00:40

前言

这是并发编程系列的第三篇文章,上一篇我们一起分析了sync.once的使用与实现,今天我们一起来看一看sync.WaitGroup的使用与实现。

什么是sync.WaitGroup

官方文档对sync.WatiGroup的描述是:一个waitGroup对象可以等待一组协程结束,也就等待一组goroutine返回。有了sync.Waitgroup我们可以将原本顺序执行的代码在多个Goroutine中并发执行,加快程序处理的速度。其实他与java中的CountdownLatch类似,用于阻塞等待所有任务完成之后再继续执行。我们来看官网给的一个例子,这个例子使用waitGroup阻塞主进程,并发获取多个URL,直到完成所有获取:

package main

import (
 "sync"
)

type httpPkg struct{}

func (httpPkg) Get(url string) {}

var http httpPkg

func main() {
 var wg sync.WaitGroup
 var urls = []string{
  "http://www.golang.org/",
  "http://www.google.com/",
  "http://www.somestupidname.com/",
 }
 for _, url := range urls {
  // Increment the WaitGroup counter.
  wg.Add(1)
  // Launch a goroutine to fetch the URL.
  go func(url string) {
   // Decrement the counter when the goroutine completes.
   defer wg.Done()
   // Fetch the URL.
   http.Get(url)
  }(url)
 }
 // Wait for all HTTP fetches to complete.
 wg.Wait()
}

首先我们需要声明一个sync.WaitGroup对象,在主gorourine调用Add()方法设置要等待的goroutine数量,每一个Goroutine在运行结束时要调用Done()方法,同时使用Wait()方法进行阻塞直到所有的goroutine完成。

为什么要用sync.waitGroup

我们在日常开发中为了提高接口响应时间,有一些场景需要在多个goroutine中做一些互不影响的业务,这样可以节省不少时间,但是需要协调多个goroutine,没有sync.WaitGroup的时候,我们可以使用通道来解决这个问题,我们把主Goroutine当成铜锣扛把子a song,把每一个Goroutine当成一个马仔,asong管理这些马仔,让这些马仔去收保护费,我今天派10个马仔去收保护费,每一个马仔收好了保护费就在账本上打一个✅,当所有马仔都收好了保护费,账本上就被打满了✅,活全被干完了,很出色,然后酒吧走起,浪一浪,全场的消费松公子买单,写成代码可以这样表示:


func exampleImplWaitGroup()  {
 done := make(chan struct{}) // 收10份保护费
 count := 10 // 10个马仔
 for i:=0;i < count;i++{
  go func(i int) {
   defer func() {
    done <- struct {}{}
   }()
   fmt.Printf("马仔%d号收保护费\n",i)
  }(i)
 }
 for i:=0;i< count;i++{
  <- done
  fmt.Printf("马仔%d号已经收完保护费\n",i)
 }
 fmt.Println("所有马仔已经干完活了,开始酒吧消费~")
}

虽然这样可以实现,但是我们每次使用都要保证主Goroutine最后从通道接收的次数需要与之前其他的Goroutine发送元素的次数相同,实现起来不够优雅,在这种场景下我们就可以选用sync.WaitGroup来帮助我们实现同步。

源码剖析

前面我们已经知道sync.waitGroup的基本使用了,接下来我们就一起看看他是怎样实现的~,只有知其所以然,才能写出更健壮的代码。

Go version: 1.15.3

首先我们看一下sync.WaitGroup的结构:

// A WaitGroup must not be copied after first use.
type WaitGroup struct {
 noCopy noCopy

 // 64-bit value: high 32 bits are counter, low 32 bits are waiter count.
 // 64-bit atomic operations require 64-bit alignment, but 32-bit
 // compilers do not ensure it. So we allocate 12 bytes and then use
 // the aligned 8 bytes in them as state, and the other 4 as storage
 // for the sema.
 state1 [3]uint32
}

总共就有两个字段,nocopy是为了保证该结构不会被进行拷贝,这是一种保护机制,会在后面进行介绍;state1主要是存储着状态和信号量,这里使用的8字节对齐处理的方式很有意思,我先来一起看看这种处理。

state1状态和信号量处理

state1这里总共被分配了12个字节,这里被设计了三种状态:

  • 其中对齐的8个字节作为状态,高32位为计数的数量,低32位为等待的goroutine数量
  • 其中的4个字节作为信号量存储

提供了(wg *WaitGroup) state() (statep *uint64, semap *uint32)帮助我们从state1字段中取出他的状态和信号量,为什么要这样设计呢?

我们在分析atomicGo看源码必会知识之unsafe包有说到过,64位原子操作需要64位对齐,但是32位编译器不能保证这一点,所以为了保证waitGroup32位平台上使用的话,就必须保证在任何时候,64位操作不会报错。所以也就不能分成两个字段来写,考虑到字段顺序不同、平台不同,内存对齐也就不同。因此这里采用动态识别当前我们操作的64位数到底是不是在8字节对齐的位置上面,我们来分析一下state方法:

// state returns pointers to the state and sema fields stored within wg.state1.
func (wg *WaitGroup) state() (statep *uint64, semap *uint32) {
 if uintptr(unsafe.Pointer(&wg.state1))%8 == 0 {
  return (*uint64)(unsafe.Pointer(&wg.state1)), &wg.state1[2]
 } else {
  return (*uint64)(unsafe.Pointer(&wg.state1[1])), &wg.state1[0]
 }
}

当数组的首地址是处于一个8字节对齐的位置上时,那么就将这个数组的前8个字节作为64位值使用表示状态,后4个字节作为32位值表示信号量(semaphore)。同理如果首地址没有处于8字节对齐的位置上时,那么就将前4个字节作为semaphore,后8个字节作为64位数值。画个图表示一下:

Add()Done()方法

sync.WaitGroup提供了Add()方法增加一个计数器,Done()方法减掉一个计数,Done()方法实现比较简单,内部调用的Add()方法实现的计数器减一操作,也就是增减逻辑都在Add()方法中,所以我们重点看一下Add()是如何实现的:

func (wg *WaitGroup) Add(delta int) {
  // 获取状态(Goroutine Counter 和 Waiter Counter)和信号量
 statep, semap := wg.state()
 if race.Enabled {
  _ = *statep // trigger nil deref early
  if delta < 0 {
   // Synchronize decrements with Wait.
   race.ReleaseMerge(unsafe.Pointer(wg))
  }
  race.Disable()
  defer race.Enable()
 }
  // 原子操作,goroutine counter累加delta
 state := atomic.AddUint64(statep, uint64(delta)<<32)
  // 获取当前goroutine counter的值(高32位)
 v := int32(state >> 32)
  // 获取当前waiter counter的值(低32位)
 w := uint32(state)
 if race.Enabled && delta > 0 && v == int32(delta) {
  // The first increment must be synchronized with Wait.
  // Need to model this as a read, because there can be
  // several concurrent wg.counter transitions from 0.
  race.Read(unsafe.Pointer(semap))
 }
  // Goroutine counter是不允许为负数的,否则会发生panic
 if v < 0 {
  panic("sync: negative WaitGroup counter")
 }
  // 当wait的Goroutine不为0时,累加后的counter值和delta相等,说明Add()和Wait()同时调用了,所以发生panic,因为正确的做法是先Add()后Wait(),也就是已经调用了wait()就不允许再添加任务了
 if w != 0 && delta > 0 && v == int32(delta) {
  panic("sync: WaitGroup misuse: Add called concurrently with Wait")
 }
  // 正常`Add()`方法后,`goroutine Counter`计数器大于0或者`waiter Counter`计数器等于0时,不需要释放信号量
 if v > 0 || w == 0 {
  return
 }
 // 能走到这里说明当前Goroutine Counter计数器为0,Waiter Counter计数器大于0, 到这里数据也就是允许发生变动了,如果发生变动了,则出发panic
 if *statep != state {
  panic("sync: WaitGroup misuse: Add called concurrently with Wait")
 }
 // 重置状态,并发出信号量告诉wait所有任务已经完成
 *statep = 0
 for ; w != 0; w-- {
  runtime_Semrelease(semap, false0)
 }
}

上面的代码有一部分是race静态检测,下面的分析会省略这一部分,因为它并不是本文的重点。

注释我都添加到对应的代码行上了,你是否都看懂了,没看懂不要紧,因为Add()是与Wait()方法一块使用的,所以有些逻辑与wait()里的逻辑是相互照应的,所以当我们看完wait()方法的实现在总结一下你们就明白了。

Wait()方法

sync.Wait()方法会阻塞主Goroutine直到WaitGroup计数器变为0。我们一起来看一下Wait()方法的源码:

// Wait blocks until the WaitGroup counter is zero.
func (wg *WaitGroup) Wait() {
  // 获取状态(Goroutine Counter 和 Waiter Counter)和信号量
 statep, semap := wg.state()
 if race.Enabled {
  _ = *statep // trigger nil deref early
  race.Disable()
 }
 for {
    // 使用原子操作读取state,是为了保证Add中的写入操作已经完成
  state := atomic.LoadUint64(statep)
    // 获取当前goroutine counter的值(高32位)
  v := int32(state >> 32)
     // 获取当前waiter counter的值(低32位)
  w := uint32(state)
    // 如果没有任务,或者任务已经在调用`wait`方法前已经执行完成了,就不用阻塞了
  if v == 0 {
   // Counter is 0, no need to wait.
   if race.Enabled {
    race.Enable()
    race.Acquire(unsafe.Pointer(wg))
   }
   return
  }
  // 使用CAS操作对`waiter Counter`计数器进行+1操作,外面有for循环保证这里可以进行重试操作
  if atomic.CompareAndSwapUint64(statep, state, state+1) {
   if race.Enabled && w == 0 {
    // Wait must be synchronized with the first Add.
    // Need to model this is as a write to race with the read in Add.
    // As a consequence, can do the write only for the first waiter,
    // otherwise concurrent Waits will race with each other.
    race.Write(unsafe.Pointer(semap))
   }
      // 在这里获取信号量,使线程进入睡眠状态,与Add方法中最后的增加信号量相对应,也就是当最后一个任务调用Done方法
      // 后会调用Add方法对goroutine counter的值减到0,就会走到最后的增加信号量
   runtime_Semacquire(semap)
      // 在Add方法中增加信号量时已经将statep的值设为0了,如果这里不是0,说明在wait之后又调用了Add方法,使用时机不对,触发panic
   if *statep != 0 {
    panic("sync: WaitGroup is reused before previous Wait has returned")
   }
   if race.Enabled {
    race.Enable()
    race.Acquire(unsafe.Pointer(wg))
   }
   return
  }
 }
}

源码总结

分了源码,我们可以总结如下:

  • Add方法与wait方法不可以并发同时调用,Add方法要在wait方法之前调用.
  • Add()设置的值必须与实际等待的goroutine个数一致,否则会panic.
  • 调用了wait方法后,必须要在wait方法返回以后才能再次重新使用waitGroup,也就是Wait没有返回之前不要在调用Add方法,否则会发生Panic.
  • Done 只是对Add 方法的简单封装,我们可以向 Add方法传入任意负数(需要保证计数器非负)快速将计数器归零以唤醒等待的 Goroutine.
  • waitGroup对象只能有一份,不可以拷贝给其他变量,否则会造成意想不到的Bug.

no copy机制

在前文看waitGroup结构时,有一个nocopy字段,为什么要有nocopy呢?我们先看这样一个例子:

type User struct {
 Name string
 Info *Info
}

type Info struct {
 Age int
 Number int
}


func main()  {
 u := User{
  Name: "asong",
  Info: &Info{
   Age: 10,
   Number: 24,
  },
 }
 u1 := u
 u1.Name = "Golang梦工厂"
 u1.Info.Age = 30
 fmt.Println(u.Info.Age,u.Name)
 fmt.Println(u1.Info.Age,u1.Name)
}
// 运行结果
30 asong
30 Golang梦工厂

结构体User中有两个字段NameInfoNameString类型,Info是指向结构体Info的指针类型,我们首先声明了一个u变量,对他进行复制拷贝得到变量u1,在u1中对两个字段进行改变,可以看到Info字段发生了更改,而Name就没发生更改,这就引发了安全问题,如果结构体对象包含指针字段,当该对象被拷贝时,会使得两个对象中的指针字段变得不再安全。

Go语言中提供了两种copy检查,一种是在运行时进行检查,一种是通过静态检查。不过运行检查是比较影响程序的执行性能的,Go官方目前只提供了strings.Builder和sync.Cond的runtime拷贝检查机制,对于其他需要nocopy对象类型来说,使用go vet工具来做静态编译检查。运行检查的实现可以通过比较所属对象是否发生变更

就可以判断,而静态检查是提供了一个nocopy对象,只要是该对象或对象中存在nocopy字段,他就实现了sync.Locker接口, 它拥有Lock()和Unlock()方法,之后,可以通过go vet功能,来检查代码中该对象是否有被copy。

踩坑事项

在文章的最后总结一下使用waitGroup易错的知识点,防止大家再次犯错。

  1. waitGroup中计数器的值是不能小于0的,源码中我们就可以看到,一旦小于0就会引发panic。
  2. 一定要住注意调用Add方法与Wait方法的顺序,不可并发同时调用这两个方法,否则就会引发panic,同时在调用了wait方法在其没有释放前不要再次调用Add方法,这样也会引发panicwaitGroup是可以复用的,但是需要保证其计数周期的完整性。
  3. WaitGroup对象不是一个引用类型,通过函数传值的时候需要使用地址,因为Go语言只有值传递,传递WaitGroup是值的话,就会导致会发生panic,看这样一个例子:
func main()  {
 wg := sync.WaitGroup{}
 wg.Add(1)
 doDeadLock(wg)
 wg.Wait()
}
func doDeadLock(wg sync.WaitGroup)  {
 defer wg.Done()
 fmt.Println("do something")
}
//运行结果:panic: sync: negative WaitGroup counter

发生这个问题的原因就是在doDeadLock()方法中wg是一个新对象,直接调用Done方法,计数器就会出现负数,所以引发panic,为了安全起见,对于这种传结构体的场景一般建议都传指针就好了,基本可以避免一些问题。

  1. Add()设置的值必须与实际等待的goroutine个数一致,否则会panic,很重要的一点,也是很容易出错的地方。

思考题

最后给大家出一个思考题,下面这段代码会不会发生panic

func main() {
 wg := sync.WaitGroup{}
 wg.Add(100)
 for i := 0; i < 100; i++ {
  go func() {
   defer wg.Done()
   fmt.Println(i)
  }()
 }
 wg.Wait()
}

推荐阅读


福利

我为大家整理了一份从入门到进阶的Go学习资料礼包,包含学习建议:入门看什么,进阶看什么。关注公众号 「polarisxu」,回复 ebook 获取;还可以回复「进群」,和数万 Gopher 交流学习。

浏览 18
点赞
评论
收藏
分享

手机扫一扫分享

分享
举报
评论
图片
表情
推荐
点赞
评论
收藏
分享

手机扫一扫分享

分享
举报