Pytorch 7 :Memory Optimization(Freeing GPU/NPU Memory Early)
导言
- 对于不使用的python对象,如何释放?
- python 的对象管理机制
- del,empty_cache , gc_collect的原理
Python 的自动内存管理主要基于引用计数(Reference Counting),辅以循环垃圾回收器(Cycle Garbage Collector)处理循环引用。
引用计数的核心思想是:每个对象维护一个计数器,记录有多少变量(或容器)引用它;当计数降为 0 时,对象立即被销毁并释放内存。
一、引用计数机制简述¶
- 每次一个对象被引用(赋值、传参、放入容器等),其
ob_refcnt+1。 - 每次引用被删除(变量离开作用域、被覆盖、从容器中移除等),
ob_refcnt-1。 - 当
ob_refcnt == 0,Python 调用该对象的析构函数(如果有的话)并释放内存。
你可使用 sys.getrefcount(obj) 查看当前引用计数(注意:该函数本身会临时增加一次引用,所以返回值比实际多 1)。
结合 fp32 → bfloat16(bf16)转换举例
虽然 Python 本身不直接操作 bf16(这是底层硬件/框架如 PyTorch/TensorFlow 支持的数据类型),但我们可以用 NumPy 或 PyTorch 张量来模拟对象生命周期,观察引用计数变化。
注意:
bf16在 PyTorch 中通过.to(torch.bfloat16)实现,但其底层张量对象仍是 Python 对象,受引用计数管理。
代码示例:观察引用计数变化(以 PyTorch 张量为例)
import torch
import sys
# 创建一个 fp32 张量
x_fp32 = torch.randn(1000, dtype=torch.float32)
print(f"创建 x_fp32 后引用计数: {sys.getrefcount(x_fp32) - 1}") # -1 排除 getrefcount 自身引用
# 转换为 bf16,生成新张量(注意:这是新对象!)
x_bf16 = x_fp32.to(torch.bfloat16)
print(f"x_bf16 创建后,x_fp32 引用计数: {sys.getrefcount(x_fp32) - 1}")
print(f"x_bf16 的引用计数: {sys.getrefcount(x_bf16) - 1}")
# 删除原始 fp32 张量
del x_fp32
# 此时 x_fp32 对象的引用计数降为 0(假设无其他引用),内存被立即回收
print("已删除 x_fp32")
# 仍可使用 x_bf16
print(f"x_bf16 数据类型: {x_bf16.dtype}")
# 删除 x_bf16
del x_bf16
# 现在 x_bf16 对象也被回收
输出(典型情况):
四、关键点说明
to()不是原地操作:x_fp32.to(torch.bfloat16)返回新张量对象,与原x_fp32无共享内存(除非显式使用view或as_strided,但 bf16 与 fp32 位宽不同,无法直接 view)。- 两个独立对象:
x_fp32和x_bf16各自有自己的引用计数。 del降低引用计数:del x_fp32使x_fp32的引用计数减 1,若变为 0,则立即触发__del__(如果定义了)并释放内存。- NPU/GPU 内存也受此机制管理:PyTorch 张量在 CPU/NPU/GPU 上的数据由 Python 对象持有,当对象被回收,其底层设备内存也会被释放(通过张量的析构函数)。
函数调用返回¶
引用计数 不是全局参数,而是 每个 Python 对象自身的一个属性(存储在 PyObject 结构体的 ob_refcnt 字段中)。它跟踪的是 有多少个“引用”指向该对象,与变量作用域(局部/全局)无关,只与“引用关系”有关。
易混淆点:
1. 引用计数是全局参数吗?还是不同函数的局部变量?¶
- 都不是。
引用计数属于 对象本身,不是变量的属性,也不是全局或局部变量。
例如:
这里的引用计数是[1,2,3] 这个 list 对象的属性,a 和 b 只是两个名字(引用)。
2. 如果一个函数使用了一个变量,其引用计数会增加吗?¶
会,但要看“使用”的方式:
情况 A:将对象作为参数传入函数¶
import sys
def f(x):
print("函数内引用计数:", sys.getrefcount(x) - 1) # -1 因为 getrefcount 自身加1
return x
obj = [1, 2, 3]
print("调用前引用计数:", sys.getrefcount(obj) - 1) # 通常是 1
f(obj)
print("调用后引用计数:", sys.getrefcount(obj) - 1) # 仍是 1
- 在函数调用时,形参
x会绑定到obj所指向的对象 → 引用计数 临时 +1。 - 函数返回后,形参
x超出作用域 → 引用计数 -1,恢复原状。
✅ 结论:函数参数传递会临时增加引用计数,但函数结束时会自动减少。
情况 B:函数内部创建新对象¶
-y 是局部变量,指向新列表。
- return y 将引用传递给调用者(z),不是复制对象。
- 函数结束时,y 被销毁(引用 -1),但 z 接管了引用(总引用数保持 1)。
3. 函数返回后,局部变量和全局变量的引用计数如何变化?¶
(1)局部变量:¶
- 函数执行结束时,所有局部变量(如
x,y)从局部命名空间中移除。 - 这会导致它们所引用的对象 引用计数 -1。
- 如果计数变为 0,对象立即被回收。
(2)全局变量:¶
- 如果函数内部读取全局变量(如
global_list),会临时增加其引用计数(因为函数栈帧中有一个引用)。 - 函数结束后,这个临时引用消失,计数恢复。
- 如果函数内部修改全局变量(
global global_list),则可能增加/减少引用,取决于操作(如赋值新对象)。
完整示例:观察函数调用中的引用计数变化
import sys
global_obj = [10, 20]
def example(local_obj):
print("进入函数时 global_obj 引用计数:", sys.getrefcount(global_obj) - 1)
print("进入函数时 local_obj 引用计数:", sys.getrefcount(local_obj) - 1)
temp = local_obj # 引用 +1
print("赋值 temp 后 local_obj 引用计数:", sys.getrefcount(local_obj) - 1)
return temp
my_list = [1, 2]
print("调用前 my_list 引用计数:", sys.getrefcount(my_list) - 1) # 1
result = example(my_list)
print("函数返回后 my_list 引用计数:", sys.getrefcount(my_list) - 1) # 2(my_list + result)
print("函数返回后 global_obj 引用计数:", sys.getrefcount(global_obj) - 1) # 1(恢复)
del result # 手动删除 result
print("删除 result 后 my_list 引用计数:", sys.getrefcount(my_list) - 1) # 1
关键点:
- 函数参数、局部变量都是引用,不是对象本身。
- 函数栈帧存在期间,会对所引用的对象保持一次引用。
- 函数返回时,局部变量销毁 → 引用计数减少。
- 如果 return 了某个对象,调用者接收到引用,对象不会被销毁。
特别注意:在 PyTorch / NumPy 中的含义
在你关注的 AI 框架中(如 PyTorch 张量):
def forward(x): ... return y中的x和y都是张量对象的引用。- 函数结束时,局部变量(如中间激活
h = relu(x))若未被返回或保存,引用计数降为 0,张量对象被回收,其 GPU/NPU 内存也可能释放(取决于框架内存池策略)。 - 但若计算图需要(如
requires_grad=True),PyTorch 会额外持有引用用于反向传播,此时即使局部变量销毁,张量也不会被回收。
总结¶
| 场景 | 引用计数变化 |
|---|---|
| 函数参数传入对象 | +1(临时),函数返回后 -1 |
| 函数内局部变量引用对象 | 作用域结束时 -1 |
return obj |
调用者获得引用,对象生命周期延长 |
| 全局变量被函数读取 | 临时 +1,函数结束 -1 |
| 对象引用计数归零 | 立即析构(__del__ + 内存释放) |
这正是 Python 能高效管理内存(尤其在训练循环中频繁创建中间张量)的基础机制之一。
del 的使用¶
非常好的问题!我们来精确澄清 del 的作用:
del x并不是直接将引用计数归零,而是将变量x对其所指向对象的引用移除,从而导致该对象的引用计数减 1(refcnt -= 1)。
一、del 的本质:解除一个引用¶
在 Python 中,变量(如 x)本质上是名字(name)到对象的引用(reference)。
执行 del x 的含义是:
- 从当前作用域(如局部或全局命名空间)中删除名字
x; - 这会导致
x原本指向的对象失去一次引用; - 该对象的引用计数减 1;
- 是否归零,取决于是否有其他引用存在。
二、举例说明¶
✅ 情况 1:只有一个引用 → del 后引用计数归零¶
import sys
a = [1, 2, 3]
print(sys.getrefcount(a) - 1) # 输出 1(只有 a 引用它)
del a
# 此时 [1,2,3] 对象的引用计数变为 0,立即被回收
# 再访问 a 会报 NameError: name 'a' is not defined
→ 这里 del a 使引用计数从 1 → 0,对象销毁。
✅ 情况 2:多个引用 → del 后引用计数减 1,但不归零¶
import sys
a = [1, 2, 3]
b = a # b 也引用同一个 list
print(sys.getrefcount(a) - 1) # 输出 2(a 和 b)
del a
print(sys.getrefcount(b) - 1) # 输出 1(仅剩 b)
# 对象未被销毁,仍可通过 b 访问
print(b) # [1, 2, 3]
→ del a 仅移除了 a 这个引用,对象仍被 b 持有。
✅ 情况 3:函数局部变量(自动 del)¶
import sys
def f():
x = [1]
print("函数内引用计数:", sys.getrefcount(x) - 1) # 1
return x
obj = f() # 函数返回后,局部变量 x 自动被 del
# x 的引用消失,但 obj 接管了引用,所以对象未销毁
print("函数外引用计数:", sys.getrefcount(obj) - 1) # 1
→ 函数结束时,局部变量 x 被隐式 del,引用计数 -1,但因 return 被 obj 接收,总引用数不变。
三、与 PyTorch / NPU 的关联¶
x = torch.randn(1000, device='npu') # 引用计数 = 1
y = x # 引用计数 = 2
del x # 引用计数 = 1(y 仍持有)
# 此时 NPU 内存不会释放!
del y # 引用计数 = 0
# Tensor 对象销毁 → PyTorch 将其设备内存归还到缓存池
# 调用 torch.npu.empty_cache() 才会真正释放给 NPU 驱动
✅ 所以,
del只是减 1,是否归零取决于其他引用是否存在。
四、常见误解纠正¶
| 误解 | 正确理解 |
|---|---|
“del x 会删除对象” |
❌ del x 只删除名字 x,对象是否删除看引用计数 |
“del 后内存立刻释放” |
❌ 仅当引用计数归零时才释放 Python 对象;设备内存还受框架缓存控制 |
“del = free()” |
❌ Python 是垃圾回收语言,del 不是直接内存释放指令 |
总结¶
del x→ 移除变量x对对象的引用 → 对象引用计数减 1。- 是否归零? → 取决于是否还有其他变量/容器/闭包等引用该对象。
- 设备内存释放? → 需要:
引用计数归零→Tensor 析构→PyTorch 缓存池标记为可回收→empty_cache()(可选,归还给驱动)。
empty_cache¶
torch.npu.empty_cache())
PyTorch 不会在 Tensor 被销毁时立即释放 GPU/NPU 内存,而是:
- 将释放的内存块放入 内存池(caching allocator)。
- 后续分配相同大小的内存时,直接复用缓存块,避免频繁调用 cudaFree(昂贵操作)。
empty_cache() 的作用:
将缓存中未被使用的内存块真正归还给设备驱动(如 CUDA driver 或 NPU driver)。
gc.collect()¶
- 作用:运行 Python 的 循环垃圾回收器(Generational GC),检测并回收引用计数无法释放的循环引用对象。
- 触发条件:
- 自动:每分配一定数量对象后触发(可配置)。
- 手动:调用
gc.collect()。 - 与引用计数关系:
- 补充机制:只处理引用计数“漏掉”的对象(主要是循环引用)。
- 不管理设备内存:即使回收了 Tensor 的 Python 对象,若 PyTorch 仍持有其底层数据指针,显存/NPU 内存不会释放。
❌
gc.collect()不能释放 GPU/NPU 显存,除非它成功回收了 Tensor 对象,且该 Tensor 是最后一个持有设备内存的引用。
计算图引用:batch数据被保存到反向¶
这是一个在 PyTorch 训练循环中极易被忽视但影响显著的内存问题。确实,DataLoader 通过 for batch in dataloader 产出每个 batch,但如果处理不当,一个 batch 的数据可能被意外持有多轮(甚至直到反向传播结束),导致 GPU/NPU 内存无法及时释放,尤其在 VLM、MoE 等大模型场景中会严重限制 batch size 或引发 OOM。
🔍 问题本质:谁“持有”了 batch?¶
for batch in dataloader:
pixel_values = batch["pixel_values"].to(device)
input_ids = batch["input_ids"].to(device)
outputs = model(pixel_values, input_ids)
loss = outputs.loss
loss.backward() # ← 反向时,batch 数据仍被持有?
optimizer.step()
表面上看,batch 是 for 循环的局部变量,每轮应自动释放。
但问题出在计算图(computation graph)和 PyTorch 的自动微分机制。
🧠 核心原因:计算图保留了对输入 Tensor 的引用¶
当 requires_grad=True(或模型有可训练参数)时,PyTorch 会构建反向传播所需的计算图。
该图会 隐式持有对所有参与 forward 的叶节点(leaf tensors)的引用,包括:
pixel_valuesinput_ids- 以及其他输入张量
✅ 即使你在
forward结尾del pixel_values,只要计算图存在,PyTorch 仍会保持对其底层数据的引用 → 设备内存无法释放。
这个引用会一直持续到 loss.backward() 执行完毕(或手动 loss.grad_fn 被断开)。
✅ 验证:为什么 batch 会被持到反向?¶
import torch
x = torch.randn(1000, requires_grad=True).cuda()
y = x * 2
loss = y.sum()
print("Before backward:", torch.cuda.memory_allocated()) # 高
loss.backward()
print("After backward:", torch.cuda.memory_allocated()) # 明显下降(x 的 grad 保留,但计算图释放)
- 在
backward()之前,x被计算图引用 → 无法释放。 backward()后,计算图被销毁(除非retain_graph=True)→x的“图引用”消失。- 如果
x没有其他 Python 引用(如变量、容器),则其内存可被回收。
🚫 为什么 del batch 在循环内无效?¶
batch是 CPU 上的字典,del batch仅释放 CPU 内存。x是独立的 GPU Tensor,其生命周期由:- Python 引用(
x变量) - 计算图引用(关键!)
共同决定。
✅ 正确解决方案¶
✅ 方案 1:在 backward() 后立即清理(最常用)¶
for i, batch in enumerate(dataloader):
pixel_values = batch["pixel_values"].to(device)
input_ids = batch["input_ids"]...to(device)
outputs = model(pixel_values, input_ids)
loss = outputs.loss
loss.backward()
optimizer.step()
optimizer.zero_grad()
# 可选:显式 del(虽非必须,但更安全)
del pixel_values, input_ids, outputs, loss
✅
backward()后计算图销毁 → 张量仅剩 Python 引用 → 出循环时自动释放。
✅ 方案 2:使用 no_grad() 或 torch.inference_mode()(推理/评估时)¶
with torch.inference_mode(): # 或 torch.no_grad()
for batch in dataloader:
...
# 无计算图 → 输入 Tensor 用完即可释放
- 推理时无反向传播 → 无计算图 →
del或作用域退出即可释放。
✅ 方案 3:避免在循环外意外持有引用¶
常见陷阱:
all_losses = []
for batch in dataloader:
loss = model(...)
all_losses.append(loss) # ← 保留了 loss,间接保留整个计算图!
→ 不要将 loss/Tensor 存入全局 list,除非 .item() 或 .detach():
✅ 方案 4:使用 model(...).detach()(谨慎)¶
如果某些中间输出需保留但不参与反向:
.detach()创建一个与计算图断开的新 Tensor,不持有图引用。
📊 内存释放时机总结¶
| 阶段 | 是否可释放 pixel_values 的设备内存? |
原因 |
|---|---|---|
forward 中(反向前) |
❌ 不能 | 计算图持有引用 |
del pixel_values(反向前) |
❌ 不能 | 图引用仍存在 |
loss.backward() 后 |
✅ 可以 | 计算图已销毁 |
下一轮 for 循环开始前 |
✅ 自动释放 | Python 引用退出 + 无图引用 |
💡 针对你 NPU/VLM 场景的建议¶
- 不要尝试在
forward中提前释放pixel_values以节省反向阶段内存 —— 这是徒劳的,因为计算图锁住了它。 - 真正有效的优化方向:
- 减少 计算图规模(如 gradient checkpointing)
- 使用
torch.compile或 融合算子(如 GMM) 降低中间激活内存 - 避免 意外保留 Tensor(如 logger、debug 变量)
- 监控工具:
✅ 结论¶
DataLoader的 batch 被持到反向是合理的,因为 PyTorch 的自动微分机制需要它。
这不是 bug,而是 反向传播的必要代价。
你无法也不应该在backward()前释放参与计算图的输入 Tensor。
正确做法是:确保反向后无冗余引用,并优化计算图本身(如 activation checkpointing)。
避免保存原生特征
要注意 rollout buffer 中是否存储了原始 pixel_values —— 应只存 视觉特征(feature) 或 token IDs,而非原始像素,这才是根本的内存优化。