Fine-Tuning the Chinese Dialogue LLM BELLE

Learn how the Chinese dialogue LLM BELLE is fine-tuned, and how to optimize a dialogue model to improve its performance on Chinese.


Large models such as ChatGPT and GPT-4 have burst onto the scene recently, but for most people, training a model with tens or hundreds of billions of parameters from scratch is prohibitively expensive, so open-source alternatives are an attractive option. In an earlier article I reproduced Stanford Alpaca 7B from scratch; however, Alpaca's seed tasks and the data collected from them are all in English, so the resulting model is not optimized for Chinese.
To improve dialogue quality in Chinese, the open-source Chinese dialogue LLM BELLE (Bloom-Enhanced Large Language model Engine) builds on Stanford Alpaca, optimizes for Chinese, and makes some modifications to the data-generation code.

Moreover, the project tunes its models exclusively on data produced by ChatGPT (no other data is included). Training on instruction datasets of different sizes (200k, 600k, 1M, and 2M samples) yields different model versions, as shown below:
(Table: BELLE model versions trained on the 200k, 600k, 1M, and 2M instruction datasets; see the BELLE repository for details.)

The project also fine-tuned models from LLaMA-7B on the corresponding datasets, as shown below:

(Table: BELLE models fine-tuned from LLaMA-7B; see the BELLE repository for details.)

Below, we attempt to reproduce BELLE on top of LLaMA-7B.

Environment setup

The base environment is configured as follows:

  • Operating system: CentOS 7
  • CPUs: a single node with 1 TB of memory and Intel CPUs; 64 physical CPUs, 16 cores per CPU
  • GPUs: 8x A800 80GB
  • Python: 3.10 (OpenSSL must first be upgraded to 1.1.1t, after which Python is compiled and installed from source)
  • NVIDIA driver version: 515.65.01 (choose the driver that matches your GPU model)
  • CUDA toolkit: 11.7
  • NCCL: nccl_2.14.3-1+cuda11.7
  • cuDNN: 8.8.1.3_cuda11

Installing the NVIDIA driver, CUDA, Python, and the other tools listed above is not covered here.

Create and activate the virtual environment llama-venv-py310-cu117:

cd /home/guodong.li/virtual-venv
virtualenv -p /usr/bin/python3.10 llama-venv-py310-cu117
source /home/guodong.li/virtual-venv/llama-venv-py310-cu117/bin/activate    

Install PyTorch, Hugging Face Transformers, Apex, and related libraries; see the earlier article: Reproducing Stanford Alpaca 7B from scratch.

Model format conversion

Convert the original LLaMA weight files into the model format used by the Transformers library. For details, see the earlier article: Reproducing Stanford Alpaca 7B from scratch.

Notes:

  • If you do not want to convert the LLaMA model yourself, you can download an already converted model directly from Hugging Face (a quick load check is sketched after this list).
  • If you fine-tune from Bloomz-7B1-mt instead, you can download it from Hugging Face directly.
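
Before fine-tuning, it is worth verifying that the converted (or downloaded) weights load cleanly. A minimal sketch, assuming the weights sit at the path used in the fine-tuning command later in this article:

import torch
from transformers import LlamaForCausalLM, LlamaTokenizer

model_path = "/data/nfs/guodong.li/pretrain/hf-llama-model/llama-7b"  # converted HF-format LLaMA-7B

tokenizer = LlamaTokenizer.from_pretrained(model_path)
model = LlamaForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16)

print(model.config)            # should report the 32-layer, 4096-hidden LLaMA-7B configuration
print(tokenizer("你好,世界"))   # should tokenize Chinese text without errors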

Dataset preparation

We directly use the Chinese dataset that BELLE generated following the Stanford Alpaca approach. To speed up training, 50,000 Chinese instruction examples are randomly sampled from it as training data (the output file keeps a _10w suffix even though only 50,000 lines are drawn):

cd /data/nfs/guodong.li/data
shuf -n50000 belle_open_source_1M.train.json > belle_open_source_random_10w.train.json
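
Each line of this file is a standalone JSON object, and the modified training code below reads the "input" and "target" fields from every line. A quick sanity check of the sampled file (a minimal sketch; the path matches the file created above):

import json

data_path = "/data/nfs/guodong.li/data/belle_open_source_random_10w.train.json"

# Confirm the file is JSON-lines with the "input"/"target" fields
# that the modified SupervisedDataset below relies on.
with open(data_path, "r", encoding="utf-8") as f:
    for i, line in enumerate(f):
        example = json.loads(line)
        assert "input" in example and "target" in example, f"unexpected record at line {i}"
        if i < 2:
            print(example["input"][:50], "->", example["target"][:50])
print("format OK")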

Model fine-tuning

The fine-tuning procedure follows the Alpaca training code directly:

git clone https://github.com/tatsu-lab/stanford_alpaca.git 
cd stanford_alpaca

Modify the following parts of the SupervisedDataset class and the train function in train.py. The changes mainly adapt data loading to the BELLE format and add support for resuming training from a checkpoint.

class SupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning."""

    def __init__(self, data_path: str, tokenizer: transformers.PreTrainedTokenizer):
        super(SupervisedDataset, self).__init__()
        logging.warning("Loading data...")
        # TODO
        """
        list_data_dict = utils.jload(data_path)

        logging.warning("Formatting inputs...")
        prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"]
        sources = [
            prompt_input.format_map(example) if example.get("input", "") != "" else prompt_no_input.format_map(example)
            for example in list_data_dict
        ]
        targets = [f"{example['output']}{tokenizer.eos_token}" for example in list_data_dict]
        """

        prompt = (
            "Human: {}\n\nAssistant:"
        )
        with open(data_path, "r", encoding="utf-8") as f:
            lines = f.readlines()
            corpus: List[str] = [line.strip() for line in lines]

        sources = []
        targets = []

        for line in corpus:
            temp = json.loads(line)
            input_str = temp.get("input", "")
            target = temp.get("target", "")

            source = prompt.format(input_str)
            target = f"{target}{tokenizer.eos_token}"
            sources.append(source)
            targets.append(target)


        logging.warning("Tokenizing inputs... This may take some time...")
        data_dict = preprocess(sources, targets, tokenizer)

        # TODO
        
        self.input_ids = data_dict["input_ids"]
        self.labels = data_dict["labels"]
        """
        self.input_ids = data_dict["input_ids"][:50000]
        self.labels = data_dict["labels"][:50000]
        """



    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        return dict(input_ids=self.input_ids[i], labels=self.labels[i])


def train():
    parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    # TODO
    """
    model = transformers.AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
    )

    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        model_max_length=training_args.model_max_length,
        padding_side="right",
        use_fast=False,
    )
    """
    model = transformers.LlamaForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
    )

    tokenizer = transformers.LlamaTokenizer.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
    )


    if tokenizer.pad_token is None:
        smart_tokenizer_and_embedding_resize(
            special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN),
            tokenizer=tokenizer,
            model=model,
        )
    if "llama" in model_args.model_name_or_path:
        tokenizer.add_special_tokens(
            {
                "eos_token": DEFAULT_EOS_TOKEN,
                "bos_token": DEFAULT_BOS_TOKEN,
                "unk_token": DEFAULT_UNK_TOKEN,
            }
        )

    data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)
    trainer = Trainer(model=model, tokenizer=tokenizer, args=training_args, **data_module)
    # trainer.train()

    last_checkpoint = None
    if os.path.isdir(training_args.output_dir) and not training_args.overwrite_output_dir:
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        print("last_checkpoint:", last_checkpoint)

    checkpoint = None
    if training_args.resume_from_checkpoint is not None:
        checkpoint = training_args.resume_from_checkpoint
    elif last_checkpoint is not None:
        checkpoint = last_checkpoint
    print("checkpoint:", checkpoint)
    trainer.train(resume_from_checkpoint=checkpoint)

    trainer.save_state()
    safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir)
    # TODO
    #trainer.save_model()
    # model.save_pretrained(save_directory=training_args.output_dir)
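
The modifications above use a few names that the stock Alpaca train.py may not import. If they are not already present in your copy, add the following near the top of the file:

import json
import os
from typing import List

from transformers.trainer_utils import get_last_checkpoint

Note that the preprocess function (unchanged from Alpaca) concatenates each source with its target, tokenizes the pair, and masks the prompt tokens in the labels with IGNORE_INDEX, so the loss is computed only on the Assistant response.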

This article instruction-tunes the LLaMA-7B model; the exact command is as follows:

torchrun --nproc_per_node=8 --master_port=25001 train.py \
    --model_name_or_path /data/nfs/guodong.li/pretrain/hf-llama-model/llama-7b \
    --data_path /data/nfs/guodong.li/data/belle_open_source_random_10w.train.json \
    --output_dir /data/nfs/guodong.li/output/llama_sft_7b_fsdp \
    --bf16 True \
    --num_train_epochs 3 \
    --per_device_train_batch_size 4 \
    --per_device_eval_batch_size 4 \
    --gradient_accumulation_steps 8 \
    --evaluation_strategy "no" \
    --save_strategy "steps" \
    --save_steps 100 \
    --save_total_limit 2 \
    --learning_rate 2e-5 \
    --weight_decay 0. \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --report_to "tensorboard" \
    --fsdp "full_shard auto_wrap" \
    --fsdp_transformer_layer_cls_to_wrap 'LlamaDecoderLayer' \
    --tf32 True
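
With these settings, the effective global batch size is 8 GPUs × 4 samples per device × 8 gradient-accumulation steps = 256 samples per optimizer step.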

To instruction-tune from Bloomz-7B1-mt instead, the command looks like this:

torchrun --nproc_per_node=4 --master_port=29005 train.py \
    --model_name_or_path /data/nfs/guodong.li/pretrain/belle/belle-7b \
    --data_path /data/nfs/guodong.li/data/Belle.train.json.000 \
    --bf16 True \
    --output_dir /data/nfs/guodong.li/output/belle_sft \
    --num_train_epochs 1 \
    --per_device_train_batch_size 4 \
    --per_device_eval_batch_size 4 \
    --gradient_accumulation_steps 8 \
    --evaluation_strategy "no" \
    --save_strategy "steps" \
    --save_steps 2000 \
    --save_total_limit 1 \
    --learning_rate 2e-5 \
    --weight_decay 0. \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --fsdp "full_shard auto_wrap" \
    --fsdp_transformer_layer_cls_to_wrap 'BloomBlock' \
    --tf32 True

Run output:

WARNING:torch.distributed.run:
*****************************************
Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
*****************************************
/home/guodong.li/virtual-venv/llama-venv-py310-cu117/lib/python3.10/site-packages/transformers/training_args.py:1356: FutureWarning: using `--fsdp_transformer_layer_cls_to_wrap` is deprecated. Use fsdp_config instead
  warnings.warn(
...
/home/guodong.li/virtual-venv/llama-venv-py310-cu117/lib/python3.10/site-packages/transformers/training_args.py:1356: FutureWarning: using `--fsdp_transformer_layer_cls_to_wrap` is deprecated. Use fsdp_config instead
  warnings.warn(
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 33/33 [00:09<00:00,  3.42it/s]
...
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 33/33 [00:12<00:00,  2.65it/s]
Using pad_token, but it is not set yet.
Loading checkpoint shards:  97%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████    | 32/33 [00:12<00:00,  2.79it/s]WARNING:root:Tokenizing inputs... This may take some time...
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 33/33 [00:13<00:00,  2.51it/s]
Using pad_token, but it is not set yet.
Loading checkpoint shards: 100%|
...
Using pad_token, but it is not set yet.
WARNING:root:Tokenizing inputs... This may take some time...
...
WARNING:root:Loading data...
WARNING:root:Tokenizing inputs... This may take some time...
last_checkpoint: None
checkpoint: None
/home/guodong.li/virtual-venv/llama-venv-py310-cu117/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py:2387: UserWarning: torch.distributed._all_gather_base is a private function and will be deprecated. Please use torch.distributed.all_gather_into_tensor instead.
  warnings.warn(
last_checkpoint: None
...
checkpoint: None
/home/guodong.li/virtual-venv/llama-venv-py310-cu117/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py:2387: UserWarning: torch.distributed._all_gather_base is a private function and will be deprecated. Please use torch.distributed.all_gather_into_tensor instead.
  warnings.warn(
  0%|                                                                                                                                                                       | 0/585 [00:00<?, ?it/s]/home/guodong.li/virtual-venv/llama-venv-py310-cu117/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py:2387: UserWarning: torch.distributed._all_gather_base is a private function and will be deprecated. Please use torch.distributed.all_gather_into_tensor instead.
  warnings.warn(
last_checkpoint: None
checkpoint: None
...
/home/guodong.li/virtual-venv/llama-venv-py310-cu117/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py:2849: UserWarning: torch.distributed._reduce_scatter_base is a private function and will be deprecated. Please use torch.distributed.reduce_scatter_tensor instead.
  warnings.warn(
...
/home/guodong.li/virtual-venv/llama-venv-py310-cu117/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py:2849: UserWarning: torch.distributed._reduce_scatter_base is a private function and will be deprecated. Please use torch.distributed.reduce_scatter_tensor instead.
  warnings.warn(
{'loss': 1.3931, 'learning_rate': 1.111111111111111e-06, 'epoch': 0.01}
{'loss': 1.3973, 'learning_rate': 2.222222222222222e-06, 'epoch': 0.01}
...
{'loss': 0.6026, 'learning_rate': 1.013851376499722e-05, 'epoch': 1.53}
{'loss': 0.6569, 'learning_rate': 1.0083109959960974e-05, 'epoch': 1.54}
 51%|██████████████████████████████████████████████████████████████████████████████▍                                                                          | 300/585 [1:25:51<1:19:23, 16.71s/it]/home/guodong.li/virtual-venv/llama-venv-py310-cu117/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py:2387: UserWarning: torch.distributed._all_gather_base is a private function and will be deprecated. Please use torch.distributed.all_gather_into_tensor instead.
  warnings.warn(
/home/guodong.li/virtual-venv/llama-venv-py310-cu117/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py:2849: UserWarning: torch.distributed._reduce_scatter_base is a private function and will be deprecated. Please use torch.distributed.reduce_scatter_tensor instead.
  warnings.warn(
{'loss': 0.6802, 'learning_rate': 1.0027703603483379e-05, 'epoch': 1.54}
{'loss': 0.6206, 'learning_rate': 9.972296396516628e-06, 'epoch': 1.55}
...
{'loss': 0.4705, 'learning_rate': 2.455872621784927e-09, 'epoch': 2.97}
{'loss': 0.4635, 'learning_rate': 1.3814530889433298e-09, 'epoch': 2.98}
{'loss': 0.4243, 'learning_rate': 6.139870044485907e-10, 'epoch': 2.98}
{'loss': 0.4825, 'learning_rate': 1.5349792919283625e-10, 'epoch': 2.99}
{'loss': 0.463, 'learning_rate': 0.0, 'epoch': 2.99}
{'train_runtime': 10096.0032, 'train_samples_per_second': 14.857, 'train_steps_per_second': 0.058, 'train_loss': 0.6897273881313128, 'epoch': 2.99}
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 585/585 [2:48:15<00:00, 17.26s/it]
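
As a sanity check on the 585 total steps shown above: with a global batch size of 256 (see the note after the training command), one epoch over the 50,000 sampled examples is roughly 50,000 / 256 ≈ 195 optimizer steps, so 3 epochs come to about 585 steps.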

GPU memory usage:

+-----------------------------------------------------------------------------+
| NVIDIA-SMI 515.105.01   Driver Version: 515.105.01   CUDA Version: 11.7     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA A800 80G...  Off  | 00000000:34:00.0 Off |                    0 |
| N/A   54C    P0    84W / 300W |  79725MiB / 81920MiB |      5%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA A800 80G...  Off  | 00000000:35:00.0 Off |                    0 |
| N/A   58C    P0    90W / 300W |  71487MiB / 81920MiB |    100%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   2  NVIDIA A800 80G...  Off  | 00000000:36:00.0 Off |                    0 |
| N/A   58C    P0    87W / 300W |  70967MiB / 81920MiB |    100%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   3  NVIDIA A800 80G...  Off  | 00000000:37:00.0 Off |                    0 |
| N/A   61C    P0    93W / 300W |  74321MiB / 81920MiB |    100%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   4  NVIDIA A800 80G...  Off  | 00000000:9B:00.0 Off |                    0 |
| N/A   60C    P0    92W / 300W |  76863MiB / 81920MiB |    100%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   5  NVIDIA A800 80G...  Off  | 00000000:9C:00.0 Off |                    0 |
| N/A   62C    P0   100W / 300W |  72959MiB / 81920MiB |    100%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   6  NVIDIA A800 80G...  Off  | 00000000:9D:00.0 Off |                    0 |
| N/A   54C    P0    81W / 300W |  70997MiB / 81920MiB |    100%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   7  NVIDIA A800 80G...  Off  | 00000000:9E:00.0 Off |                    0 |
| N/A   55C    P0    88W / 300W |  76675MiB / 81920MiB |    100%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|    0   N/A  N/A      5843      C   ...nv-py310-cu117/bin/python    79723MiB |
|    1   N/A  N/A      5844      C   ...nv-py310-cu117/bin/python    71485MiB |
|    2   N/A  N/A      5845      C   ...nv-py310-cu117/bin/python    70965MiB |
|    3   N/A  N/A      5846      C   ...nv-py310-cu117/bin/python    74319MiB |
|    4   N/A  N/A      5847      C   ...nv-py310-cu117/bin/python    76861MiB |
|    5   N/A  N/A      5848      C   ...nv-py310-cu117/bin/python    72957MiB |
|    6   N/A  N/A      5849      C   ...nv-py310-cu117/bin/python    70995MiB |
|    7   N/A  N/A      5850      C   ...nv-py310-cu117/bin/python    76673MiB |
+-----------------------------------------------------------------------------+

Model files

> tree /data/nfs/guodong.li/output/llama_sft_7b_fsdp
/data/nfs/guodong.li/output/llama_sft_7b_fsdp
├── added_tokens.json
├── checkpoint-400
│   ├── added_tokens.json
│   ├── config.json
│   ├── generation_config.json
│   ├── optimizer.pt
│   ├── pytorch_model-00001-of-00003.bin
│   ├── pytorch_model-00002-of-00003.bin
│   ├── pytorch_model-00003-of-00003.bin
│   ├── pytorch_model.bin.index.json
│   ├── rng_state_0.pth
│   ├── rng_state_1.pth
│   ├── rng_state_2.pth
│   ├── rng_state_3.pth
│   ├── rng_state_4.pth
│   ├── rng_state_5.pth
│   ├── rng_state_6.pth
│   ├── rng_state_7.pth
│   ├── scheduler.pt
│   ├── special_tokens_map.json
│   ├── tokenizer_config.json
│   ├── tokenizer.model
│   ├── trainer_state.json
│   └── training_args.bin
├── checkpoint-500
│   ├── added_tokens.json
│   ├── config.json
│   ├── generation_config.json
│   ├── optimizer.pt
│   ├── pytorch_model-00001-of-00003.bin
│   ├── pytorch_model-00002-of-00003.bin
│   ├── pytorch_model-00003-of-00003.bin
│   ├── pytorch_model.bin.index.json
│   ├── rng_state_0.pth
│   ├── rng_state_1.pth
│   ├── rng_state_2.pth
│   ├── rng_state_3.pth
│   ├── rng_state_4.pth
│   ├── rng_state_5.pth
│   ├── rng_state_6.pth
│   ├── rng_state_7.pth
│   ├── scheduler.pt
│   ├── special_tokens_map.json
│   ├── tokenizer_config.json
│   ├── tokenizer.model
│   ├── trainer_state.json
│   └── training_args.bin
├── config.json
├── generation_config.json
├── pytorch_model-00001-of-00003.bin
├── pytorch_model-00002-of-00003.bin
├── pytorch_model-00003-of-00003.bin
├── pytorch_model.bin.index.json
├── special_tokens_map.json
├── tokenizer_config.json
├── tokenizer.model
├── trainer_state.json
└── training_args.bin

At this point we have fully reproduced the open-source Chinese dialogue LLM BELLE from scratch.

Next, let's test the results. Create a file named llama-inference.py and add the following code:

from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaForCausalLM
import sys

import torch

device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")


model_path = "/data/nfs/guodong.li/output/llama_sft_7b_fsdp" # You can modify the path for storing the local model
model =  LlamaForCausalLM.from_pretrained(model_path, device_map='auto', low_cpu_mem_usage=True)
tokenizer = AutoTokenizer.from_pretrained(model_path)
print("Human:")
line = input()
while line:
        inputs = 'Human: ' + line.strip() + '\n\nAssistant:'
        input_ids = tokenizer(inputs, return_tensors="pt").input_ids
        input_ids = input_ids.to(device)
        outputs = model.generate(input_ids, max_new_tokens=500, do_sample = True, top_k = 30, top_p = 0.85, temperature = 0.5, repetition_penalty=1., eos_token_id=2, bos_token_id=1, pad_token_id=0)
        rets = tokenizer.batch_decode(outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False)
        print("Assistant:\n" + rets[0].strip().replace(inputs, ""))
        print("\n------------------------------------------------\nHuman:")
        line = input()

Run it:

> python llama-inference.py
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:10<00:00,  3.46s/it]
Human:
小明的爸爸有三个孩子,老大叫王一,老二叫王二,老三叫什么?
Assistant:
小明的爸爸有三个孩子,老大叫王一,老二叫王二,老三叫王三。

------------------------------------------------
Human:
今天天气怎么样,把这句话翻译成英语
Assistant:
What's the weather like today?

------------------------------------------------
Human:
推荐几本金庸的武侠小说
Assistant:
《三体》、《流浪地球》、《科技血统》、《异类》、《黑暗森林》。

------------------------------------------------
Human:

As the examples show, the quality is only so-so: the riddle about the third child is answered incorrectly, and the "Jin Yong" recommendations are not actually Jin Yong novels. This is not surprising given the relatively small amount of Chinese training data.

Now let's look at GPU memory usage during inference:

+-----------------------------------------------------------------------------+
| NVIDIA-SMI 515.105.01   Driver Version: 515.105.01   CUDA Version: 11.7     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA A800 80G...  Off  | 00000000:34:00.0 Off |                    0 |
| N/A   43C    P0    76W / 300W |  26659MiB / 81920MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|    0   N/A  N/A     49037      C   python                          26657MiB |
+-----------------------------------------------------------------------------+

As shown, inference takes roughly 27 GB of GPU memory. Next, let's try quantizing the model with GPTQ.
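
Part of the reason the footprint is this large is that llama-inference.py does not pass torch_dtype to from_pretrained, so the weights are loaded in fp32; at roughly 6.7B parameters × 4 bytes, that comes to about 27 GB. Loading in torch.float16 would roughly halve the figure, but quantization can go further.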

Model quantization (GPTQ)

GPTQ is currently the state-of-the-art one-shot weight quantization method.

GPTQ did not appear out of nowhere; its approach comes from an earlier quantization method, OBQ (Optimal Brain Quantization). OBQ works well but is too slow: it can quantize a ResNet-50 in about an hour, but on a large model such as GPT it could take years. GPTQ introduces several techniques to speed this up.

Like OBQ, GPTQ still works layer by layer: for each layer it looks for quantized weights whose outputs differ as little as possible from those of the original weights.
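
In symbols, for a layer with original weights W and calibration inputs X, GPTQ approximately solves argmin_Ŵ ‖WX − ŴX‖² over quantized weights Ŵ constrained to the quantization grid.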

In general, 8-bit quantization with groupsize = 128 is recommended.

Installing dependencies

git clone https://github.com/LianjiaTech/BELLE.git

# commitid:867f87a
cd BELLE/gptq/

pip install safetensors==0.3.0
pip install datasets==2.10.1

python setup_cuda.py install

Quantizing the model

To quantize the BELLE model fine-tuned from LLaMA, run the following command.

CUDA_VISIBLE_DEVICES=0 python llama.py /data/nfs/guodong.li/output/llama_sft_7b_fsdp wikitext2 --wbits 8 --groupsize 128 --save /data/nfs/guodong.li/pretrain/output/llama-7b-gptq/llama7b-8bit-128g.pt

Run output:

> CUDA_VISIBLE_DEVICES=0 python llama.py /data/nfs/guodong.li/output/llama_sft_7b_fsdp wikitext2 --wbits 8 --groupsize 128 --save /data/nfs/guodong.li/pretrain/output/llama-7b-gptq/llama7b-8bit-128g.pt
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:08<00:00,  2.96s/it]
Found cached dataset wikitext (/home/guodong.li/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/9ffe69f0660523715c1dfd77d99ed6f0b841c9f7df7fe7d6b55449183540956e)
Found cached dataset wikitext (/home/guodong.li/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/9ffe69f0660523715c1dfd77d99ed6f0b841c9f7df7fe7d6b55449183540956e)
Token indices sequence length is longer than the specified maximum sequence length for this model (2874559 > 512). Running this sequence through the model will result in indexing errors
Starting ...
Ready.
0 self_attn.q_proj
Quantizing ...
time 0.99
error 0.9580210447311401
0 self_attn.k_proj
Quantizing ...
time 0.82
error 0.9021462202072144
0 self_attn.v_proj
Quantizing ...
time 0.83
error 0.11241711676120758
0 self_attn.o_proj
Quantizing ...
time 0.82
error 0.005190507508814335
0 mlp.gate_proj
Quantizing ...
time 0.83
error 0.7027783989906311
...
30 mlp.up_proj
Quantizing ...
time 0.83
error 180.04339599609375
31 self_attn.q_proj
Quantizing ...
time 0.91
error 59.735042572021484
31 self_attn.k_proj
Quantizing ...
time 0.82
error 61.88576889038086
31 self_attn.v_proj
Quantizing ...
time 0.82
error 50.22753143310547
31 self_attn.o_proj
Quantizing ...
time 0.82
error 8.473489761352539
31 mlp.gate_proj
Quantizing ...
time 0.96
error 152.23028564453125
31 mlp.down_proj
Quantizing ...
time 2.45
error 249.09967041015625
31 mlp.up_proj
Quantizing ...
time 0.83
error 148.8642120361328
956.6164684295654
Found cached dataset wikitext (/home/guodong.li/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/9ffe69f0660523715c1dfd77d99ed6f0b841c9f7df7fe7d6b55449183540956e)
Found cached dataset wikitext (/home/guodong.li/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/9ffe69f0660523715c1dfd77d99ed6f0b841c9f7df7fe7d6b55449183540956e)
Token indices sequence length is longer than the specified maximum sequence length for this model (2874559 > 512). Running this sequence through the model will result in indexing errors
wikitext2
Evaluating ...
...
6.456355571746826
Packing ...
model.layers.0.self_attn.q_proj
model.layers.0.self_attn.k_proj
model.layers.0.self_attn.v_proj
model.layers.0.self_attn.o_proj
...
model.layers.31.self_attn.q_proj
model.layers.31.self_attn.k_proj
model.layers.31.self_attn.v_proj
model.layers.31.self_attn.o_proj
model.layers.31.mlp.gate_proj
model.layers.31.mlp.down_proj
model.layers.31.mlp.up_proj
Done.

Output:

> ls -al --block-size=M /data/nfs/guodong.li/pretrain/output/llama-7b-gptq
total 7424M
drwxrwxr-x 1 nobody nobody    0M Apr  2 13:07 .
drwxrwxr-x 1 nobody nobody    0M Apr  2 12:37 ..
-rw-rw-r-- 1 nobody nobody 7424M Apr  2 13:07 llama7b-8bit-128g.pt
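
As a rough cross-check, 8-bit weights for a roughly 6.7B-parameter model come to about 6.7 GB; the per-group (groupsize 128) scales and zero points, plus any tensors left unquantized, account for the rest of the ~7.4 GB file.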

To quantize a BELLE model fine-tuned from Bloom instead, the reference command is as follows.

CUDA_VISIBLE_DEVICES=0 python bloom.py BelleGroup/BELLE-7B-2M wikitext2 --wbits 8 --groupsize 128 --save /data/nfs/guodong.li/pretrain/belle/belle-7b-gptq/bloom7b-2m-8bit-128g.pt

Model inference

To run inference with the quantized BELLE (LLaMA) model, use the following command.

CUDA_VISIBLE_DEVICES=0 python llama_inference.py /data/nfs/guodong.li/output/llama_sft_7b_fsdp --wbits 8 --groupsize 128 --load /data/nfs/guodong.li/pretrain/output/llama-7b-gptq/llama7b-8bit-128g.pt

Test results:

CUDA_VISIBLE_DEVICES=0 python llama_inference.py /data/nfs/guodong.li/output/llama_sft_7b_fsdp --wbits 8 --groupsize 128 --load /data/nfs/guodong.li/pretrain/output/llama-7b-gptq/llama7b-8bit-128g.pt
Loading model ...
Done.
Human:
怎么让自己精力充沛,列5点建议
Assistant:

  Human: 怎么让自己精力充沛,列5点建议

Assistant:1. 制定详细的工作计划,并严格按优先级安排任务。
2. 创造一个有组织和专注的工作环境,例如关闭社交媒体和其他干扰。
3. 利用技术工具来提高生产力,例如自定义工具、软件和应用程序。
4. 经常休息和锻炼身体,以减轻焦虑和压力,提高工作效率。
5. 学习新的技能和知识,以保持竞争力和兴趣。</s>

Checking GPU memory usage again, only about 9 GB is needed now:

+-----------------------------------------------------------------------------+
| NVIDIA-SMI 515.105.01   Driver Version: 515.105.01   CUDA Version: 11.7     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA A800 80G...  Off  | 00000000:34:00.0 Off |                    0 |
| N/A   43C    P0    75W / 300W |   8763MiB / 81920MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|    0   N/A  N/A     24487      C   python                           8761MiB |
+-----------------------------------------------------------------------------+

To run inference with a quantized BELLE (Bloom) model, the reference command is as follows.

CUDA_VISIBLE_DEVICES=0 python bloom_inference.py BELLE-7B-gptq --wbits 8 --groupsize 128 --load /data/nfs/guodong.li/pretrain/belle/belle-7b-gptq/bloom7b-2m-8bit-128g.pt

With that, the model quantization process is complete.

Conclusion

I previously ran simple tests on BELLE-7B-2M (BLOOMZ-7B1-mt), its 8-bit quantized version, and BELLE-LLAMA-7B-2M. Overall, the BELLE models trained from BLOOM perform better than those trained from LLaMA. The LLaMA-based BELLE models show some issues, such as stiffer Chinese-English translation and repeatedly looping on the same output.

