Show HN: 通过热替换代码将你的 PyTorch 模型保存在 VRAM 中
Training Hot Swap
这是一个如何在不从 VRAM 中卸载模型权重的情况下热替换 PyTorch 训练代码的示例。
对于大型 LLM,将模型从磁盘加载到 VRAM 可能需要 30 秒以上。每次想要重新运行脚本时都等待 30 秒会减慢开发速度。这是一个最基本的实现,用于在训练脚本退出后仍然将大型模型保留在 VRAM 中的方法。如果必须重新加载模型,它会在退出后在后台进行,从而确保下次运行脚本时模型可以立即准备就绪。
其工作原理是生成第二个进程,该进程在目标脚本退出后保持活动状态。您更改的脚本不会直接运行。相反,此后台进程使用 Python 的 eval()
代表您运行代码。
这也可以通过 VPN 用于远程代码执行。IntelliJ 的远程 SSH 解释器存在很多错误,并非无缝远程开发的理想选择。配置 model_server.py
在远程计算机上运行,并在您的开发计算机上运行 client.py
。在这种配置中也支持使用 IntelliJ 调试器进行调试,从而实现几乎无缝的开发体验,脚本可以立即运行并且易于调试。
GUI 示例
已经完成了一些工作,以确保与 DearImgui Python 绑定的兼容性。UI 代码可以与您的训练脚本一起提交到服务器。我个人喜欢为我的训练脚本构建 UI,以监控进度、随时间变化的损失,并实现简单的评估。将您的 UI 代码与您的训练代码一起提交,确保您的应用程序可以立即启动。
这是一个应用程序的 GUI,它显示了 Mistral 7B 的中间输出。从我运行代码到我可以与模型交互,在我的机器上大约需要 0.32 秒,这包括 GUI 的初始化时间。 顺便说一句,您可以在这里找到更多我的 transformer 可视化内容:https://x.com/lukasvaline
使用方法
在 model_server.py
中设置您的模型下载位置。
与 IntelliJ 调试服务器兼容。将您的调试服务器端口设置为 5678。
要开始在您的开发中使用此功能,只需交换您的 .from_pretrained
调用并引用全局变量 'model'
以下代码将被移除:
model = MistralForCausalLM.from_pretrained(
self.model_path,
torch_dtype=torch.float16,
device_map=device,
use_flash_attention_2=False,
config=self.config,
)
并替换为:
def get_model(self):
"""Get model either from global context"""
global model # Reference the global model variable
try:
# Check if model exists in global scope
model
except NameError:
return None
return model
model = get_model()
如何运行:启动服务器并保持其运行
training-hot-swap$ python model_server.py
将训练代码提交到服务器
training-hot-swap$ python client.py ./src ./src/sample_train.py
其他考虑事项
此脚本是一个主要的潜在安全漏洞。这是一个服务器,其设计目的是执行任意代码。不要将此服务器直接暴露给互联网。
关于
Pytorch 脚本热替换:在不从 VRAM 中卸载 LLM 的情况下更改代码