环境
import torch
if __name__ == '__main__':
print(torch.__version__)
# 检查CUDA是否可用
cuda_available = torch.cuda.is_available()
if cuda_available:
# 获取GPU设备数量
num_gpu = torch.cuda.device_count()
# 获取当前使用的GPU索引
current_gpu_index = torch.cuda.current_device()
# 获取当前GPU的名称
current_gpu_name = torch.cuda.get_device_name(current_gpu_index)
# 获取GPU显存的总量和已使用量
total_memory = torch.cuda.get_device_properties(current_gpu_index).total_memory / (1024 ** 3) # 显存总量(GB)
used_memory = torch.cuda.memory_allocated(current_gpu_index) / (1024 ** 3) # 已使用显存(GB)
free_memory = total_memory - used_memory # 剩余显存(GB)
print(f"CUDA可用,共有 {num_gpu} 个GPU设备可用。")
print(f"当前使用的GPU设备索引:{current_gpu_index}")
print(f"当前使用的GPU设备名称:{current_gpu_name}")
print(f"GPU显存总量:{total_memory:.2f} GB")
print(f"已使用的GPU显存:{used_memory:.2f} GB")
print(f"剩余GPU显存:{free_memory:.2f} GB")
else:
print("CUDA不可用。")
# 检查PyTorch版本
print(f"PyTorch版本:{torch.__version__}")
==========================================================
1.12.1+cu113
CUDA可用,共有 1 个GPU设备可用。
当前使用的GPU设备索引:0
当前使用的GPU设备名称:Tesla T4
GPU显存总量:14.76 GB
已使用的GPU显存:0.00 GB
剩余GPU显存:14.76 GB
PyTorch版本:1.12.1+cu113
[root@iZbp1100kme19bqwzx10b6Z macc]# nvidia-smi
Wed Mar 6 14:03:44 2024
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.91.03 Driver Version: 460.91.03 CUDA Version: 11.2 |
|-------------------------------+----------------------+----------------------+
| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|===============================+======================+======================|
| 0 Tesla T4 On | 00000000:00:08.0 Off | 0 |
| N/A 28C P8 9W / 70W | 0MiB / 15109MiB | 0% Default |
| | | N/A |
+-------------------------------+----------------------+----------------------+
+-----------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=============================================================================|
| No running processes found |
+-----------------------------------------------------------------------------+
安装依赖
拉取源码
git clone https://github.com/THUDM/ChatGLM-6B.git
安装依赖
# 不要安装requirements里面的torch,自己额外指定版本安装。我这边按照服务器的cuda113安装了相关的torch
pip install -r requirements.txt
pip3 install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113
# (可选)urllib3高版本有openssl相关问题
# 即:urllib3 v2 only supports OpenSSL 1.1.1+, currently the 'ssl' module is compi
# 可以降低版本修复
pip uninstall -y urllib3 && pip install urllib3==1.25.11
# (可选)解决启动web_demo.py脚本出现的报错:AttributeError: 'Textbox' object has no attribute 'style'
pip uninstall -y gradio && pip install gradio==3.40.0
下载模型
请先安装好git lfs(安装教程)
# 创建模型目录,因为ChatGLM-6B项目根目录下的一些演示脚本就是读取这个路径的模型(在项目根目录下创建)
mkdir -p THUDM/chatglm-6b
# 从huggingface上拉取模型文件
git clone https://huggingface.co/THUDM/chatglm-6b
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/THUDM/chatglm-6b
测试模型
交互式命令行模型对话
import os
import platform
import signal
from transformers import AutoTokenizer, AutoModel
import readline
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
model = model.eval()
os_name = platform.system()
clear_command = 'cls' if os_name == 'Windows' else 'clear'
stop_stream = False
def build_prompt(history):
print(history)
prompt = "欢迎使用 ChatGLM-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序"
for i in range(len(history),2):
query, response=history[i][''],history[i+1]
prompt += f"\n\n用户:{query}"
prompt += f"\n\nChatGLM-6B:{response}"
return prompt
def signal_handler(signal, frame):
global stop_stream
stop_stream = True
def main():
history = []
global stop_stream
print("欢迎使用 ChatGLM-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序")
while True:
query = input("\n用户:")
if query.strip() == "stop":
break
if query.strip() == "clear":
history = []
os.system(clear_command)
print("欢迎使用 ChatGLM-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序")
continue
for response, history in model.stream_chat(tokenizer, query, history=history):
print(response)
if __name__ == "__main__":
main()
web端交互对话
python web_demo.py
评论区