标签搜索

目 录CONTENT

文章目录

还在用torch.save()?快来看看safetensors吧!

陈铭
2024-06-01 / 0 评论 / 0 点赞 / 30 阅读 / 777 字 / 正在检测是否收录...

目录

1. 前言

Safetensors 是一个由 HuggingFace 开发的库(采用 Rust 编写),用于安全地存储和加载张量(tensors)。与其他序列化格式(如 pickle)相比,Safetensors 提供了更快的加载速度和更高的安全性。这个库特别适用于深度学习模型的序列化,尤其是在需要高效加载大型模型权重时。

安装:

pip install safetensors

2. safetensors.torch

这一节主要讲解safetensors中有关PyTorch的API。

2.1 读取

2.1.1 load_file

函数签名如下:

def load_file(filename: Union[str, os.PathLike], device="cpu") -> Dict[str, torch.Tensor]:

用于直接读取 .safetensors 文件并将其转化成Python字典的格式:

from safetensors.torch import load_file

file_path = "./my_folder/bert.safetensors"
loaded = load_file(file_path)

2.1.2 load

函数签名如下:

def load(data: bytes) -> Dict[str, torch.Tensor]:

二进制格式读取 .safetensors 文件并将其转化成Python字典:

from safetensors.torch import load

file_path = "./my_folder/bert.safetensors"
with open(file_path, "rb") as f:
    data = f.read()

loaded = load(data)

2.1.3 load_model

函数签名如下:

def load_model(model: torch.nn.Module, filename: str, strict=True) -> Tuple[List[str], List[str]]:

如果使用 load_file 来加载模型参数,我们通常会执行:

model.load_state_dict(load_file("model.safetensors"))

如果使用 load_model 就可以大大简洁:

load_model(model, "model.safetensors")

2.2 写入

2.2.1 save_file

函数签名如下:

def save_file(
    tensors: Dict[str, torch.Tensor],
    filename: Union[str, os.PathLike],
    metadata: Optional[Dict[str, str]] = None,
):

用于存储给定的张量(字典格式):

from safetensors.torch import save_file
import torch

tensors = {"embedding": torch.zeros((512, 1024)), "attention": torch.zeros((256, 256))}
save_file(tensors, "model.safetensors")

2.2.2 save

函数签名如下:

def save(tensors: Dict[str, torch.Tensor], metadata: Optional[Dict[str, str]] = None) -> bytes:

用于将给定的张量(字典格式)存储为二进制格式。

注意,该函数并不会直接生成文件,只会返回一个二进制结果:

from safetensors.torch import save
import torch

tensors = {"embedding": torch.zeros((512, 1024)), "attention": torch.zeros((256, 256))}
byte_data = save(tensors)

2.2.3 save_model

函数签名如下:

def save_model(
    model: torch.nn.Module, filename: str, metadata: Optional[Dict[str, str]] = None, force_contiguous: bool = True
):

如果使用 save_file 来存储模型参数,我们通常会执行:

save_file(model.state_dict(), "model.safetensors")

如果使用 save_model 就可以大大简洁:

save_model(model, "model.safetensors")

Ref

[1] https://huggingface.co/docs/safetensors/v0.3.2/index
[2] https://blog.eleuther.ai/safetensors-security-audit/
[3] https://medium.com/@mandalsouvik/safetensors-a-simple-and-safe-way-to-store-and-distribute-tensors-d9ba1931ba04

文章知识点与官方知识档案匹配,可进一步学习相关知识

Python入门技能树人工智能深度学习422732 人正在系统学习中

本文转自 https://raelum.blog.csdn.net/article/details/136636819,如有侵权,请联系删除。

0

评论区