From d2737e1cd055e883f5b3c242090f06a5b4bf8de7 Mon Sep 17 00:00:00 2001 From: somnus Date: Sat, 21 Sep 2024 11:26:24 +0800 Subject: [PATCH] add ms telechat --- MindSpore-telechat/README.md | 289 ++++++++ .../convert_weight_ms_to_torch.py | 92 +++ .../convert_weight_torch_to_ms.py | 103 +++ .../finetune_telechat_115b.yaml | 215 ++++++ MindSpore-telechat/predict_telechat_115b.yaml | 210 ++++++ MindSpore-telechat/run_telechat.py | 116 +++ MindSpore-telechat/run_telechat_predict.py | 185 +++++ MindSpore-telechat/telechat.py | 454 ++++++++++++ MindSpore-telechat/telechat_config.py | 213 ++++++ MindSpore-telechat/telechat_interleave.py | 678 ++++++++++++++++++ MindSpore-telechat/telechat_layer.py | 275 +++++++ MindSpore-telechat/telechat_predict_utils.py | 73 ++ MindSpore-telechat/telechat_preprocess.py | 112 +++ MindSpore-telechat/telechat_tokenizer.py | 255 +++++++ MindSpore-telechat/telechat_transformer.py | 580 +++++++++++++++ README.md | 2 +- 16 files changed, 3851 insertions(+), 1 deletion(-) create mode 100644 MindSpore-telechat/README.md create mode 100644 MindSpore-telechat/convert_weight_ms_to_torch.py create mode 100644 MindSpore-telechat/convert_weight_torch_to_ms.py create mode 100644 MindSpore-telechat/finetune_telechat_115b.yaml create mode 100644 MindSpore-telechat/predict_telechat_115b.yaml create mode 100644 MindSpore-telechat/run_telechat.py create mode 100644 MindSpore-telechat/run_telechat_predict.py create mode 100644 MindSpore-telechat/telechat.py create mode 100644 MindSpore-telechat/telechat_config.py create mode 100644 MindSpore-telechat/telechat_interleave.py create mode 100644 MindSpore-telechat/telechat_layer.py create mode 100644 MindSpore-telechat/telechat_predict_utils.py create mode 100644 MindSpore-telechat/telechat_preprocess.py create mode 100644 MindSpore-telechat/telechat_tokenizer.py create mode 100644 MindSpore-telechat/telechat_transformer.py diff --git a/MindSpore-telechat/README.md b/MindSpore-telechat/README.md new file mode 
100644 index 0000000..ff9671b --- /dev/null +++ b/MindSpore-telechat/README.md @@ -0,0 +1,289 @@ +# 星辰语义大模型 TeleChat2 + +## 模型描述 + +- 星辰语义大模型**TeleChat2**是由中国电信人工智能研究院研发训练的大语言模型,该系列模型**完全基于国产算力**训练。 +- 本次开源**TeleChat2-115B**模型采用10万亿 Tokens中英文高质量语料进行训练,同步开源对话模型**TeleChat2-115B**的多格式、多平台权重文件。 +- **TeleChat2**在训练数据、训练方法等方面进行了改进,在通用问答和知识类、代码类、数学类榜单上相比**TeleChat1**均有大幅提升。 + - **TeleChat2**完全基于国产算力和国产深度学习框架进行训练,算力和算法框架更自主可控。优化MP、PP、SP实现方式提升模型性能,优化算子来提升训练速度。 + - 我们使用大量小模型实验来验证scaling law规律,在不同模型结构、不同数据配比和数据清洗方式中寻找最优设计。 + - 采用RingAttention及其他序列切分方式,实现长文训练性能提升;通过ntk-aware+attention-scaling的方式保证训练长度切换时的平稳过渡,以此来保证模型在不同长度数据下的训练效果。 +- 在微调数据方面,我们进行了指令复杂性提升与多样性扩充,通过数据合成和人工标注生成高质量数据,并使用拒绝采样生成多样的推理路径;通过研究一套基于base模型反向选择偏好对齐数据方案,基于适配数据最大限度提升模型效果。 + - 通用能力较TeleChat1系列模型提升超过29%,在逻辑推理、总结摘要、长文写作和数学计算上均有大幅提升。 + +基于GPU,Torch版本的TeleChat2链接: + +[TeleChat2](https://github.com/Tele-AI/TeleChat2) + +[TeleChat Technical Report](https://arxiv.org/abs/2401.03804) + +``` text +@article{wang2024telechat, + title={TeleChat Technical Report}, + author={Zihan Wang and Xinzhang Liu and Shixuan Liu and Yitong Yao and Yuyao Huang and Zhongjiang He and Xuelong Li and Yongxiang Li and Zhonghao Che and Zhaoxi Zhang and Yan Wang and Xin Wang and Luwen Pu and Huihan Xu and Ruiyu Fang and Yu Zhao and Jie Zhang and Xiaomeng Huang and Zhilong Lu and Jiaxin Peng and Wenjun Zheng and Shiquan Wang and Bingkai Yang and Xuewei he and Zhuoru Jiang and Qiyi Xie and Yanhan Zhang and Zhongqiu Li and Lingling Shi and Weiwei Fu and Yin Zhang and Zilu Huang and Sishi Xiong and Yuxiang Zhang and Chao Wang and Shuangyong Song}, + journal={arXiv preprint arXiv:2401.03804}, + year={2024} +} +``` + +## 模型性能 + +以下模型性能均由Atlas 800T A2硬件环境下测试得出。 + +TeleChat2-115b: + +| config | task | Datasets | SeqLength | phase | performance | +|-----------------------------------------------------| --------------------- |------------|-----------|-----------------|--------------| +| [TeleChat2_115b](./run_telechat_115b_finetune.yaml) | 
text_generation | example_dataset | 8192 | [finetune](#微调) | 158 tks/s/p | +| [TeleChat2_115b](./run_telechat_115b_predict.yaml) | text_generation | example_dataset | 8192 | [predict](#推理) | 26.5tokens/s | + +## 模型文件 + +`TeleChat2` 基于 `mindformers` 实现,主要涉及的文件有: + +1. 模型具体实现:`mindformers/research/telechat2` + + ```bash + telechat + ├── convert_weight_ms_to_torch.py # ms->torch权重转换脚本 + ├── convert_weight_torch_to_ms.py # torch->ms权重转换脚本 + ├── telechat_preprocess.py # telechat模型的mindrecord数据处理脚本 + ├── telechat.py # 模型实现 + ├── telechat_config.py # 模型配置项 + ├── telechat_layer.py # telechat网络层定义 + ├── telechat_interleave.py # telechat细粒度多副本 + ├── telechat_predict_utils.py # telechat推理模块 + ├── telechat_tokenizer.py # telechat tokenizer + └── telechat_transformer.py # transformer层实现 + ``` + +2. 模型配置:`mindformers/research/telechat2` + + ```bash + telechat + ├── finetune_telechat_115b.yaml # 115b全量微调启动配置 + └── predict_telechat_115b.yaml # 115b推理启动配置 + ``` + +3. 任务启动脚本:`mindformers/research/telechat2` + + ```text + telechat + ├── run_telechat_predict.py # 推理脚本 + └── run_telechat.py # telechat高阶接口使用脚本 + ``` + +## 环境及数据准备 + +### 安装环境 + +**MindFormers安装**以及**软硬件配套关系**参考[MindFormers安装](../../README.md#二MindFormers安装)和[版本匹配关系](../../README.md#三版本匹配关系)。 + +> 注:Atlas 800T A2芯片支持telechat_115B单机多卡推理,至少使用8卡,全参微调至少需要8机64卡。 + +### 数据及权重准备 + +#### 数据集下载 + +TeleChat2_115B所使用的微调数据集是由中电信人工智能科技有限公司所提供。 + +step 1. 获取数据集 + +[数据集] + +数据集的格式: + +```text +# input_dataset examples: + {"text": "<_user>电信主卡和副卡的区别在哪里?<_bot>主卡和副卡的主要区别在于,主卡只能使用一张手机号码。<_end><_user>好的谢谢<_bot>很高兴为您服务<_end><_pad><_pad><_pad>"} +``` + +step 2. 
处理数据成mindrecord格式 + +```bash +# 使用mindformers/research/telechat2/telechat_preprocess.py进行数据预处理和Mindrecord数据生成 +python telechat_preprocess.py \ +--input_dataset_file /{path}/ \ +--vocab_file_path /{path}/tokenizer.model \ +--max_length 8192 \ +--output_path /{path}/ +``` + +```text +# 参数说明 +input_dataset_file: 预训练的数据集 +vocab_file_path: 词模型文件路径(如使用上述链接下载,指定到对应路径下即可) +max_length: 数据集长度 +output_path: 生成数据集的路径 +``` + + > 注:`bos`, `eos`, `pad`等特殊`ids`要和`yaml`配置文件中`model_config`部分保持一致,默认`bos_token_id=1`, `eos_token_id=2`, `pad_token_id=3`。 +如果有所修改,配置文件中对应设置也需要修改,通常预训练数据不包含`pad_token`,因此建议设置`pad_token_id=-1`。 + +#### 模型权重下载 + +MindFormers提供已经转换完成的预训练权重、词表文件用于预训练、微调和推理,开发者可以下载获取官方权重后,通过下面提供的**权重转换脚本**,将官方权重转换为MindSpore权重;或直接使用MindFormers提供的**已转换权重** + +1.torch模型权重及词模型下载链接: + +- [TeleChat2-115b](https://modelscope.cn/models/TeleAI/TeleChat2-115B) + +下载完成后,运行如下转换脚本,将全量微调的权重转换为完整的ckpt权重。 + +```shell +python mindformers/research/telechat2/convert_weight_torch_to_ms.py \ +--torch_path TORCH_CKPT_DIR \ +--mindspore_path {path} +``` + +```text +# 参数说明 +torch_path: torch版本权重保存目录路径 +mindspore_path: 转换后权重的保存目录,权重将保存为该目录下的mindspore.ckpt +``` + +2.获取MindFormers提供的已转换权重,可直接从下面的链接获取。 + +- [TeleChat2-115b](https://telechat-docker.obs.cn-north-4.myhuaweicloud.com/model_weight/mindspore_115B.ckpt) + +### [分布式训练/微调权重合并](../../docs/feature_cards/Transform_Ckpt.md) + +分布式训练/微调后所得到的权重文件为根据策略切分后的权重,需要手动将切分权重合一,以用于评估和推理。 + +涉及到ckpt的单卡,多卡转换,详细教程请参考特性文档模型[权重切分与合并](../../docs/feature_cards/Transform_Ckpt.md) + +- step 1. 获取模型切分策略文件: + +在执行微调脚本时,模型完成编译后,将会在`output/strategy`路径下生成各卡的切分策略文件,用于权重合并。 + +- step 2. 
运行`mindformers/tools/transform_ckpt.py`脚本进行多卡权重合并: + +```shell +python transform_ckpt.py \ +--src_ckpt_strategy {path}/output/strategy/ \ +--src_ckpt_dir {path}/output/checkpoint/ \ +--dst_ckpt_dir {path}/target_checkpoint/ \ +--prefix telechat_115B +``` + +```text +# 参数说明 +src_ckpt_strategy: 步骤1中的切分策略文件路径 +src_ckpt_dir: 原切分权重文件夹 +dst_ckpt_dir: 目标路径 +prefix: ckpt文件前缀名 +``` + +> 注:`transform_checkpoints` 接口当前仅mindspore 2.0以上版本支持,如当前硬件环境只支持2.0以下版本,可以新建conda环境安装mindspore 2.0的cpu版本以执行该脚本 + +## 微调 + +MindFormers提供`TeleChat2-115B`的微调示例,过程中使用中电信人工智能科技有限公司提供的数据集对模型进行预训练,数据集可以参考[数据集下载](#数据集下载)获得。 + +### 全参微调 + +#### 多机训练 + +- step 1. 修改模型对应的配置文件。 + +在模型对应的配置文件`research/telechat2/finetune_telechat_115b.yaml`中,用户可自行修改模型、训练相关参数(推荐开启flash_attention,可加速训练),并通过`train_dataset`的`dataset_dir`参数,指定训练数据集的路径。 + +1. 增加脚本入参`--load_checkpoint /{path}/telechat_115b.ckpt`加载预训练权重 +2. 设置启动脚本中的`--train_dataset_dir /{path}/dataset.mindrecord`加载微调数据集 +3. 设置启动脚本中的`--run_mode finetune` + +配置文件中各参数含义详见[Config配置说明文档](https://gitee.com/mindspore/mindformers/blob/master/configs/README.md)。auto_parallel说明详见[自动并行](../../docs/feature_cards/Auto_Parallel.md)。 + +- step 2. 根据服务器节点数等信息,修改相应的配置。 + +```yaml +# 以telechat-115b模型8机64卡训练为例,默认配置机4096卡,如果节点数有变,需要修改相应的配置。 +# 配置文件路径:finetune_telechat_115b.yaml +parallel_config: + data_parallel: 1 + model_parallel: 8 + pipeline_stage: 8 + micro_batch_num: 8 + vocab_emb_dp: True + gradient_aggregation_group: 4 +``` + +- step3. 设置环境变量,变量配置如下: + +```bash +export ENABLE_CELL_REUSE=1 #编译加速 +export MS_DEV_SIDE_EFFECT_LOAD_ELIM=3 # 去除TensorMove +export MS_MEMORY_POOL_RECYCLE=1 # 内存优化 +export GE_NOT_CUT=1 # 内存优化 +``` + +- step 4. 
执行运行脚本。 + +在多机上同时拉起任务,每台机器拉起方式参考单机多卡启动方式。 + +```shell +cd mindformers/ + +# 节点0,节点ip为192.168.1.1,作为主节点,总共8卡且每个节点4卡 +bash scripts/msrun_launcher.sh "python research/telechat2/run_telechat.py \ + --config research/telechat2/finetune_telechat_115b.yaml \ + --train_dataset /{path}/dataset.mindrecord \ + --use_parallel True \ + --run_mode finetune" \ + 8 4 192.168.1.1 8118 0 output/msrun_log False 300 + +# 节点1,节点ip为192.168.1.2,节点0与节点1启动命令仅参数NODE_RANK不同 +bash scripts/msrun_launcher.sh "python research/telechat2/run_telechat.py \ + --config research/telechat2/finetune_telechat_115b.yaml \ + --train_dataset /{path}/dataset.mindrecord \ + --use_parallel True \ + --run_mode finetune" \ + 8 4 192.168.1.1 8118 1 output/msrun_log False 300 +``` + +```text +# 参数说明 +config: 配置文件路径 +run_mode: 运行模式,预训练时设置为train +train_dataset: 训练数据集文件夹路径 +use_parallel:开启并行训练 +run_mode:运行模式 +``` + +## 推理 + +推理时将配置文件中`param_init_type`修改为和全量微调一致的数据类型,`compute_dtype`修改为`float16`。 + +### 单机8卡generate推理 + +1. TeleChat2用于在线推理,输入按照 "question"的模板格式输入,Atlas 800T A2芯片支持多卡推理。主要参数配置参考: + +```text +1. 增加脚本入参`--checkpoint_path /{path}/telechat_115b.ckpt`加载微调权重 +2. 增加脚本入参`--vocab_file_path /{path}/tokenizer.model`加载词表地址 +3. 增加脚本入参`--yaml_file predict_telechat_115b.yaml`推理配置文件 +``` + +2. 启动推理 + +```shell +cd mindformers/ +bash scripts/msrun_launcher.sh ./research/telechat2/run_telechat_predict.py + +# 参数说明 +input_file: 输入的问题文件 +yaml_file: 模型的配置文件 +vocab_file: 配置词表路径 +``` + +115B 模型推理结果如下: + +```text +生抽与老抽的区别? 
+ +生抽和老抽是两种不同的酱油,它们在风味、色泽和用途上都有所区别。 + +1.颜色:生抽的颜色比较淡,而老抽的颜色较深。生抽的颜色呈红褐色或棕红色,而老抽的颜色则呈棕黑色。 + +2.味道:生抽具有鲜美的咸味和微甜的味浅,而老抽浓郁,颜色较深。根据个人口味和烹饪需求选择不同的酱油类型可以获得更好的口感和菜肴效果。 +```
diff --git a/MindSpore-telechat/convert_weight_ms_to_torch.py b/MindSpore-telechat/convert_weight_ms_to_torch.py
new file mode 100644
index 0000000..4e1ca01
--- /dev/null
+++ b/MindSpore-telechat/convert_weight_ms_to_torch.py
@@ -0,0 +1,92 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Convert MindSpore checkpoint to Torch"""
+import os
+import re
+import argparse
+import torch
+from mindspore import load_checkpoint
+
+
+def layer_name_mapping(key):
+    """Map a MindSpore parameter name to its torch (HuggingFace) counterpart.
+
+    Raises KeyError for a name that matches no known parameter. return: new_name
+    """
+    prefix = ''
+    # Handle first and last layers
+    layer_rename_map = {
+        "model.tok_embeddings.embedding_weight": "transformer.word_embeddings.weight",
+        "attention_norm.weight": "input_layernorm.weight",
+        "attention.wo.weight": "self_attention.dense.weight",
+        "attention.wo.bias": "self_attention.dense.bias",
+        "attention.wq.weight": "self_attention.query.weight",
+        "attention.wk_v.weight": "self_attention.key_value.weight",
+        "feed_forward.w1.weight": "mlp.gate_proj.weight",
+        "feed_forward.w2.weight": "mlp.down_proj.weight",
+        "feed_forward.w2.bias": "mlp.down_proj.bias",
+        "feed_forward.w3.weight": "mlp.up_proj.weight",
+        "ffn_norm.weight": "post_attention_layernorm.weight",
+        "model.norm_out.weight": "transformer.ln_f.weight",
+        "lm_head.weight": "lm_head.weight"
+    }
+    if key in layer_rename_map:
+        return prefix + layer_rename_map[key]
+
+    # Handle transformer blocks: "<a>.<b>.<layer index>.<per-layer parameter name>"
+    matched = re.search(r'\w+\.\w+\.(\d+)\.(.*)', key)
+    if matched is None:
+        raise KeyError(f"Unrecognized MindSpore parameter name: {key}")
+    return f"transformer.{prefix}h.{matched.group(1)}." + layer_rename_map[matched.group(2)]
+
+def ms_to_torch(ms_weights):
+    """Return a dict of torch tensors keyed by the renamed MindSpore parameters."""
+    torch_params = {}
+    for k, v in ms_weights.items():
+        new_name = layer_name_mapping(k)
+        new_tensor = torch.from_numpy(v.asnumpy())
+        torch_params[new_name] = new_tensor
+    return torch_params
+
+
+def process_shard_files(config):
+    """Load `config.mindspore_path` and save it as `torch.bin` under `config.torch_path`."""
+    if config.torch_path:
+        os.makedirs(config.torch_path, exist_ok=True)
+
+    ms_params = load_checkpoint(config.mindspore_path)
+    torch_params = ms_to_torch(ms_params)
+    save_file = os.path.join(config.torch_path, "torch.bin")  # empty torch_path -> cwd, not "/"
+    torch.save(torch_params, save_file)
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description="Telechat convert script")
+    parser.add_argument("--mindspore_path",
+                        type=str,
+                        default="",
+                        help="The input mindspore checkpoint path.")
+    parser.add_argument("--torch_path",
+                        type=str,
+                        default="",
+                        help="The output torch checkpoint directory.")
+    args = parser.parse_args()
+
+    # convert ms ckpt to torch
+    process_shard_files(config=args)
+    current_path = os.getcwd()
+    torch_ckpt_path = os.path.join(current_path, args.torch_path)
+    print("*** finish ms convert torch model, torch_ckpt save in {} ***".format(torch_ckpt_path))
diff --git a/MindSpore-telechat/convert_weight_torch_to_ms.py b/MindSpore-telechat/convert_weight_torch_to_ms.py new file mode 100644 index 0000000..4cc9318 --- /dev/null +++ b/MindSpore-telechat/convert_weight_torch_to_ms.py @@ -0,0 +1,103 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Convert Torch checkpoint to MindSpore"""
+import os
+import re
+import argparse
+import torch
+from tqdm import tqdm
+import mindspore
+from mindspore import Tensor, Parameter
+
+def layer_name_mapping(key):
+    """Map a torch (HuggingFace) parameter name to its MindSpore counterpart.
+
+    Raises KeyError for a name that matches no known parameter. return: new_name
+    """
+    prefix = ''
+    # Handle first and last layers
+    layer_rename_map = {
+        "transformer.word_embeddings.weight": "model.tok_embeddings.embedding_weight",
+        "input_layernorm.weight": "attention_norm.weight",
+        "self_attention.dense.weight": "attention.wo.weight",
+        "self_attention.dense.bias": "attention.wo.bias",
+        "self_attention.query.weight": "attention.wq.weight",
+        "self_attention.key_value.weight": "attention.wk_v.weight",
+        "mlp.gate_proj.weight": "feed_forward.w1.weight",
+        "mlp.down_proj.weight": "feed_forward.w2.weight",
+        "mlp.down_proj.bias": "feed_forward.w2.bias",
+        "mlp.up_proj.weight": "feed_forward.w3.weight",
+        "post_attention_layernorm.weight": "ffn_norm.weight",
+        "lm_head.weight": "lm_head.weight",
+        "transformer.ln_f.weight": "model.norm_out.weight"
+    }
+    if key in layer_rename_map:
+        return prefix + layer_rename_map[key]
+
+    # Handle transformer blocks: "<a>.<b>.<layer index>.<per-layer parameter name>"
+    match = re.match(r'^\w+\.\w*\.(\d+)\.(\w+\.\w+\.\w+|\w+\.\w+)$', key)
+    if match is None:
+        raise KeyError(f"Unrecognized torch parameter name: {key}")
+    return f"{prefix}model.layers.{int(match.group(1))}." + layer_rename_map[match.group(2)]
+
+def hf_to_ms(hf_weights, ms_dtype=mindspore.float16, for_save=False):
+    """Convert a torch state dict to MindSpore Parameters (a list of name/data dicts if for_save)."""
+    ms_params = {}
+    for k, v in hf_weights.items():
+        new_name = layer_name_mapping(k)
+        new_tensor = Tensor(v.float().detach().numpy(), ms_dtype)
+        ms_params[new_name] = Parameter(new_tensor, name=new_name)
+    if for_save:
+        return [{'name': k, 'data': v} for k, v in ms_params.items()]
+    return ms_params
+
+def process_shard_files(files, config, ms_dtype=mindspore.float16):
+    """Convert every torch shard in `files` and save one merged `mindspore.ckpt`."""
+    if config.mindspore_path:
+        os.makedirs(config.mindspore_path, exist_ok=True)
+
+    ms_file_name = "mindspore"
+    combine_params = []
+    for per_file in tqdm(files):
+        pt_states = torch.load(per_file, map_location='cpu')  # NOTE: unpickles; load trusted files only
+        ms_params = hf_to_ms(pt_states, ms_dtype, True)
+        combine_params.extend(ms_params)
+        del ms_params
+    save_file = os.path.join(config.mindspore_path, ms_file_name + '.ckpt')  # empty path -> cwd
+    mindspore.save_checkpoint(combine_params, save_file)
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description="Telechat convert script")
+    parser.add_argument("--torch_path",
+                        type=str,
+                        default="",
+                        help="The input torch checkpoint path.")
+    parser.add_argument("--mindspore_path",
+                        type=str,
+                        default="",
+                        help="The output mindspore checkpoint directory.")
+    args = parser.parse_args()
+
+    # convert hf ckpt to ms; shards are sorted so the conversion order is deterministic
+    files_list = sorted(
+        os.path.join(args.torch_path, file_name)
+        for file_name in os.listdir(args.torch_path)
+        if file_name.startswith("pytorch_model") and file_name.endswith(".bin"))
+    process_shard_files(files=files_list, config=args)
+    current_path = os.getcwd()
+    mindspore_ckpt_path = os.path.join(current_path, args.mindspore_path)
+    print("*** finish torch convert ms model, ms_ckpt save in {} ***".format(mindspore_ckpt_path))
diff --git a/MindSpore-telechat/finetune_telechat_115b.yaml b/MindSpore-telechat/finetune_telechat_115b.yaml new file mode 100644 index 
0000000..dce6976 --- /dev/null +++ b/MindSpore-telechat/finetune_telechat_115b.yaml @@ -0,0 +1,215 @@ +seed: 0 +output_dir: './output' +load_checkpoint: '' +src_strategy_path_or_dir: '' +auto_trans_ckpt: False +only_save_strategy: False +resume_training: False +ignore_data_skip: False +run_mode: 'finetune' + +# trainer config +trainer: + type: CausalLanguageModelingTrainer + model_name: 'telechat_115b' + +# runner config +runner_config: + epochs: 10 + batch_size: 1 + sink_mode: True + sink_size: 1 + +# optimizer +optimizer: + type: AdamW + betas: [0.9, 0.95] + eps: 1.e-8 + weight_decay: 0.1 + +# lr sechdule +lr_schedule: + type: CosineWithWarmUpLR + learning_rate: 1.5e-4 + lr_end: 1.5e-5 + warmup_ratio: 0.03 + total_steps: -1 # -1 means it will load the total steps of the dataset + +# dataset +train_dataset: &train_dataset + data_loader: + type: MindDataset + dataset_dir: "" + shuffle: True + input_columns: [ "input_ids", "labels" ] + num_parallel_workers: 8 + python_multiprocessing: False + drop_remainder: True + batch_size: 6 + repeat: 1 + numa_enable: False + prefetch_size: 1 +train_dataset_task: + type: CausalLanguageModelDataset + dataset_config: *train_dataset +# if True, do evaluate during the training process. if false, do nothing. +# note that the task trainer should support _evaluate_in_training function. 
+do_eval: False + +# eval dataset +eval_dataset: &eval_dataset + data_loader: + type: MindDataset + dataset_dir: "" + shuffle: False + input_columns: [ "input_ids", "labels" ] + num_parallel_workers: 8 + python_multiprocessing: False + drop_remainder: False + repeat: 1 + numa_enable: False + prefetch_size: 1 +eval_dataset_task: + type: CausalLanguageModelDataset + dataset_config: *eval_dataset + +use_parallel: True +# parallel context config +parallel: + parallel_mode: 1 # 0-data parallel, 1-semi-auto parallel, 2-auto parallel, 3-hybrid parallel + gradients_mean: False + enable_alltoall: False + full_batch: True + search_mode: "sharding_propagation" + enable_parallel_optimizer: True + strategy_ckpt_save_file: "./ckpt_strategy.ckpt" + pipeline_config: + pipeline_interleave: True + pipeline_scheduler: '1f1b' + parallel_optimizer_config: + gradient_accumulation_shard: False + parallel_optimizer_threshold: 64 +parallel_config: + data_parallel: 8 + model_parallel: 8 + pipeline_stage: 8 + use_seq_parallel: True + micro_batch_num: 8 + vocab_emb_dp: False + gradient_aggregation_group: 4 +# when model parallel is greater than 1, we can set micro_batch_interleave_num=2, that may accelerate the train process. 
+micro_batch_interleave_num: 1 + +# recompute config +recompute_config: + recompute: False + select_recompute: True + select_comm_recompute: [ 8, 8, 10, 11, 10, 7, 5, 4 ] + parallel_optimizer_comm_recompute: False + mp_comm_recompute: True + recompute_slice_activation: True + +# callbacks +callbacks: + - type: MFLossMonitor + - type: CheckpointMonitor + prefix: "telechat_115b" + save_checkpoint_steps: 1500 + keep_checkpoint_max: 300 + integrated_save: False + async_save: False + - type: ObsMonitor + +# mindspore context init config +context: + mode: 0 #0--Graph Mode; 1--Pynative Mode + device_target: "Ascend" + enable_graph_kernel: False + max_call_depth: 10000 + max_device_memory: "54.5GB" + save_graphs: False + save_graphs_path: "./graph" + device_id: 0 + jit_config: {"jit_level":"O1"} + +model: + model_config: + type: TelechatConfig + batch_size: 1 # add for increase predict + seq_length: 8192 + hidden_size: 8192 + num_layers: 96 + num_heads: 64 + n_kv_heads: 8 + vocab_size: 131072 + rms_norm_eps: 1.0e-5 + bos_token_id: 1 + eos_token_id: 2 + pad_token_id: 3 + pp_interleave_num: 3 + ignore_token_id: -100 + embed_dropout_prob: 0. + hidden_dropout_prob: 0. + attention_dropout_prob: 0. 
+ intermediate_size: 40960 + res_dtype: "float32" + compute_dtype: "bfloat16" + layernorm_compute_type: "float32" + softmax_compute_type: "float32" + rotary_dtype: "float32" + param_init_type: "float32" + use_past: False + parallel_optimizer: True + pretrain_seqlen: 8192 # seqlen of the pretrain checkpoint + extend_method: "None" # support "None", "PI", "NTK" + use_flash_attention: True # FA can accelerate training or finetune + offset: [ [ -1, -1, 0, 0, 0, 0, 0, 0 ], [ -1, 0, 0, 0, 0, 0, 0, 0 ], [ 0, 0, 0, 1, 1, 1, 1, -1 ] ] + fine_grain_interleave: 2 + use_past_shard: False + repetition_penalty: 1 + max_decode_length: 512 + top_k: 3 + top_p: 1 + do_sample: False + arch: + type: TelechatForCausalLM + +processor: + return_tensors: ms + tokenizer: + unk_token: '' + bos_token: '<_start>' + eos_token: '<_end>' + pad_token: '<_pad>' + type: TelechatTokenizer + type: TelechatProcessor + +# metric +metric: + type: PerplexityMetric + +# wrapper cell config +runner_wrapper: + type: MFTrainOneStepCell + scale_sense: 1.0 + use_clip_grad: True + +eval_callbacks: + - type: ObsMonitor + +auto_tune: False +filepath_prefix: './autotune' +autotune_per_step: 10 + +profile: False +profile_start_step: 1 +profile_stop_step: 10 +init_start_profile: False +profile_communication: False +profile_memory: True +layer_scale: False +layer_decay: 0.65 +lr_scale_factor: 256 + +# aicc +remote_save_url: "Please input obs url on AICC platform." 
diff --git a/MindSpore-telechat/predict_telechat_115b.yaml b/MindSpore-telechat/predict_telechat_115b.yaml new file mode 100644 index 0000000..ca1811d --- /dev/null +++ b/MindSpore-telechat/predict_telechat_115b.yaml @@ -0,0 +1,210 @@ +seed: 0 +output_dir: './output' +load_checkpoint: '' +src_strategy_path_or_dir: '' +auto_trans_ckpt: False +only_save_strategy: False +resume_training: False +run_mode: 'predict' + +# trainer config +trainer: + type: CausalLanguageModelingTrainer + model_name: 'telechat_115b' + +# runner config +runner_config: + epochs: 10 + batch_size: 1 + sink_mode: True + sink_size: 1 + +# optimizer +optimizer: + type: AdamW + betas: [0.9, 0.95] + eps: 1.e-8 + weight_decay: 0.1 + +# lr sechdule +lr_schedule: + type: CosineWithWarmUpLR + learning_rate: 1.5e-4 + lr_end: 1.5e-5 + warmup_ratio: 0.03 + total_steps: -1 # -1 means it will load the total steps of the dataset + +# dataset +train_dataset: &train_dataset + data_loader: + type: MindDataset + dataset_dir: "" + shuffle: True + input_columns: [ "input_ids", "labels" ] + num_parallel_workers: 8 + python_multiprocessing: False + drop_remainder: True + batch_size: 6 + repeat: 1 + numa_enable: False + prefetch_size: 1 +train_dataset_task: + type: CausalLanguageModelDataset + dataset_config: *train_dataset +# if True, do evaluate during the training process. if false, do nothing. +# note that the task trainer should support _evaluate_in_training function. 
+do_eval: False + +# eval dataset +eval_dataset: &eval_dataset + data_loader: + type: MindDataset + dataset_dir: "" + shuffle: False + input_columns: [ "input_ids", "labels" ] + num_parallel_workers: 8 + python_multiprocessing: False + drop_remainder: False + repeat: 1 + numa_enable: False + prefetch_size: 1 +eval_dataset_task: + type: CausalLanguageModelDataset + dataset_config: *eval_dataset + +use_parallel: True +# parallel context config +parallel: + parallel_mode: 1 # 0-data parallel, 1-semi-auto parallel, 2-auto parallel, 3-hybrid parallel + gradients_mean: False + enable_alltoall: False + full_batch: True + search_mode: "sharding_propagation" + strategy_ckpt_save_file: "./ckpt_strategy.ckpt" + parallel_optimizer_config: + gradient_accumulation_shard: False + parallel_optimizer_threshold: 64 +parallel_config: + data_parallel: 1 + model_parallel: 8 + pipeline_stage: 1 + use_seq_parallel: False + micro_batch_num: 8 + vocab_emb_dp: False + gradient_aggregation_group: 4 +# when model parallel is greater than 1, we can set micro_batch_interleave_num=2, that may accelerate the train process. 
+micro_batch_interleave_num: 1 + +# recompute config +recompute_config: + recompute: False + select_recompute: False + select_comm_recompute: False + parallel_optimizer_comm_recompute: False + mp_comm_recompute: True + recompute_slice_activation: True + +# callbacks +callbacks: + - type: MFLossMonitor + - type: CheckpointMonitor + prefix: "telechat_115b" + save_checkpoint_steps: 1500 + keep_checkpoint_max: 300 + integrated_save: False + async_save: False + - type: ObsMonitor + +# mindspore context init config +context: + mode: 0 #0--Graph Mode; 1--Pynative Mode + device_target: "Ascend" + enable_graph_kernel: False + max_call_depth: 10000 + max_device_memory: "58GB" + save_graphs: False + save_graphs_path: "./graph" + device_id: 0 + jit_config: {"jit_level":"O1"} + +model: + model_config: + type: TelechatConfig + batch_size: 1 # add for increase predict + seq_length: 8192 + hidden_size: 8192 + num_layers: 96 + num_heads: 64 + n_kv_heads: 8 + vocab_size: 131072 + rms_norm_eps: 1.0e-5 + bos_token_id: 1 + eos_token_id: 2 + pad_token_id: 3 + ignore_token_id: -100 + embed_dropout_prob: 0. + hidden_dropout_prob: 0. + attention_dropout_prob: 0. 
+ intermediate_size: 40960 + res_dtype: "float16" + compute_dtype: "float16" + layernorm_compute_type: "float32" + softmax_compute_type: "float32" + rotary_dtype: "float16" + param_init_type: "float16" + use_past: True + pretrain_seqlen: 8192 # seqlen of the pretrain checkpoint + extend_method: "None" # support "None", "PI", "NTK" + use_flash_attention: True # FA can accelerate training or finetune + block_size: 128 + num_blocks: 256 + is_dynamic: True + wo_has_bias: True + use_past_shard: False + repetition_penalty: 1 + max_decode_length: 512 + top_k: 3 + top_p: 1 + do_sample: False + arch: + type: TelechatForCausalLM + +processor: + return_tensors: ms + tokenizer: + unk_token: '' + bos_token: '<_start>' + eos_token: '<_end>' + pad_token: '<_pad>' + type: TelechatTokenizer + type: TelechatProcessor + +# metric +metric: + type: PerplexityMetric + +# wrapper cell config +runner_wrapper: + type: MFTrainOneStepCell + scale_sense: 1.0 + use_clip_grad: True + +eval_callbacks: + - type: ObsMonitor + +auto_tune: False +filepath_prefix: './autotune' +autotune_per_step: 10 + +profile: False +profile_start_step: 1 +profile_stop_step: 10 +init_start_profile: False +profile_communication: False +profile_memory: True +layer_scale: False +layer_decay: 0.65 +lr_scale_factor: 256 + +# aicc +remote_save_url: "Please input obs url on AICC platform." diff --git a/MindSpore-telechat/run_telechat.py b/MindSpore-telechat/run_telechat.py new file mode 100644 index 0000000..6e064c5 --- /dev/null +++ b/MindSpore-telechat/run_telechat.py @@ -0,0 +1,116 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Telechat Train/Finetune scripts."""
+import os
+import sys
+import argparse
+
+# pylint: disable=W0611
+from mindformers import Trainer, MindFormerConfig
+from mindformers.tools.utils import str2bool
+from mindformers.tools.cloud_adapter import cloud_monitor
+from mindformers.core.context import build_context
+from mindformers.tools.register.register import MindFormerModuleType, MindFormerRegister
+
+from telechat_config import TelechatConfig
+from telechat import TelechatForCausalLM
+MindFormerRegister.register_cls(TelechatConfig, MindFormerModuleType.CONFIG)
+MindFormerRegister.register_cls(TelechatForCausalLM, MindFormerModuleType.MODELS)
+
+sys.path.insert(0, os.getcwd().split('research')[0])
+
+@cloud_monitor()
+def main():
+    """Apply the CLI overrides in the module-level `args` to the yaml config, then train/finetune."""
+    yaml_path = os.path.expanduser(args.config)
+    if not os.path.exists(yaml_path):
+        raise FileNotFoundError(yaml_path)
+
+    config = MindFormerConfig(os.path.realpath(yaml_path))
+    if args.seq_length is not None:
+        config.model.model_config.seq_length = args.seq_length
+    if args.mode is not None:
+        config.context.mode = args.mode
+        if args.mode:
+            config.recompute_config.recompute = False
+    if args.use_parallel is not None:
+        config.use_parallel = args.use_parallel
+    if args.device_id is not None:
+        config.context.device_id = args.device_id
+    ckpt = config.load_checkpoint if args.load_checkpoint is None else args.load_checkpoint
+    resume = config.resume_training if args.resume is None else args.resume
+    if args.src_strategy is not None and os.path.exists(args.src_strategy):
+        config.src_strategy_path_or_dir = args.src_strategy
+    if args.auto_trans_ckpt is not None:
+        config.auto_trans_ckpt = args.auto_trans_ckpt
+    if args.vocab_file is not None:
+        config.processor.tokenizer.vocab_file = args.vocab_file
+    if args.remote_save_url is not None:
+        config.remote_save_url = args.remote_save_url
+
+    # init context
+    build_context(config)
+
+    config.model.model_config.use_past = False
+    config.model.model_config.run_mode = args.run_mode
+
+    # start task
+    if args.run_mode == 'train':
+        trainer = Trainer(args=config,
+                          task=args.task,
+                          train_dataset=args.train_dataset)
+        trainer.train(train_checkpoint=ckpt, auto_trans_ckpt=config.auto_trans_ckpt, resume_training=resume)
+    elif args.run_mode == 'finetune':
+        trainer = Trainer(args=config,
+                          task=args.task,
+                          train_dataset=args.train_dataset)
+        trainer.finetune(finetune_checkpoint=ckpt, auto_trans_ckpt=config.auto_trans_ckpt, resume_training=resume)
+    else:
+        raise ValueError("run_mode only support train and finetune.")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--task', default='text_generation', type=str,
+                        help='set task type.')
+    parser.add_argument('--config', default='telechat2/finetune_telechat_115b.yaml', type=str,
+                        help='set yaml config file path.')
+    parser.add_argument('--run_mode', default='finetune', type=str,
+                        help='set run mode for model.')
+    parser.add_argument('--seq_length', default=None, type=int,
+                        help='seq_length')
+    parser.add_argument('--use_parallel', default=True, type=str2bool,
+                        help='open parallel for model.')
+    parser.add_argument('--device_id', default=0, type=int,
+                        help='device id set when run on single card. Default: 0')
+    parser.add_argument('--mode', default=0, type=int,
+                        help='0--Graph Mode; 1--Pynative Mode')
+    parser.add_argument('--load_checkpoint', default=None, type=str,
+                        help='checkpoint name or dir to load.')
+    parser.add_argument('--src_strategy', default=None, type=str,
+                        help='strategy of load_checkpoint')
+    parser.add_argument('--auto_trans_ckpt', default=None, type=str2bool,
+                        help='whether to transform checkpoint to the checkpoint matching current distribute strategy.')
+    parser.add_argument('--resume', default=None, type=str2bool,
+                        help='whether resume training.')
+    parser.add_argument('--train_dataset', default=None, type=str,
+                        help='set train dataset.')
+    parser.add_argument('--remote_save_url', default=None, type=str,
+                        help='remote save url used on the AICC platform. Default: None')
+    parser.add_argument('--vocab_file', default=None, type=str,
+                        help='tokenizer model')
+    args = parser.parse_args()
+
+    main()
diff --git a/MindSpore-telechat/run_telechat_predict.py b/MindSpore-telechat/run_telechat_predict.py new file mode 100644 index 0000000..c9812d2 --- /dev/null +++ b/MindSpore-telechat/run_telechat_predict.py @@ -0,0 +1,185 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
def chat(model, tokenizer, question: str = '', history: "Union[List[Dict], History]" = None,
         generation_config: "Optional[GenerationConfig]" = None):
    """Run one chat turn and return the reply with the updated history.

    Args:
        model: network exposing ``generate(inputs, max_length=..., do_sample=...,
            top_k=..., top_p=..., max_new_tokens=...)``.
        tokenizer: the tokenizer of telechat; called on the question and used
            to decode the generated ids.
        question: question which the model replies to in this turn.
        history: earlier turns, either a list of ``{"role", "content"}`` dicts
            or a ``History`` object; a plain list is wrapped into ``History``.
        generation_config: configuration for generation; it is deep-copied
            before use and must provide ``user_token_id`` / ``bot_token_id``
            plus the fields read by ``build_inputs_for_chat``.

    Returns:
        tuple: ``(response, history)`` where ``history`` has the user question
        and the bot response appended.

    Raises:
        ValueError: if ``generation_config`` or ``question`` is empty.
    """
    if not generation_config:
        logger.error("generation_config is None")
        raise ValueError("generation_config must not be None")
    if not question:
        logger.error("question is empty")
        raise ValueError("question must not be empty")
    if history is None:
        history = []

    generation_config = copy.deepcopy(generation_config)
    user_id = generation_config.user_token_id
    bot_id = generation_config.bot_token_id

    # transfer to History
    if not isinstance(history, History):
        history = History(tokenizer, history)

    inputs = build_inputs_for_chat(tokenizer, question, history, generation_config, user_id, bot_id)
    history.append({"role": "user", "content": question})
    outputs = model.generate(inputs,
                             max_length=generation_config.max_decode_length,
                             do_sample=generation_config.do_sample,
                             top_k=generation_config.top_k,
                             top_p=generation_config.top_p,
                             max_new_tokens=generation_config.max_new_tokens)
    # Drop the prompt tokens and the final generated token
    # (presumably the end-of-sequence id -- TODO confirm) before decoding.
    response = tokenizer.decode(outputs[0][len(inputs):-1])
    history.append({"role": "bot", "content": response})
    return response, history


def build_inputs_for_chat(tokenizer, question, history, generation_config, usr_id, bot_id):
    """Build the prompt token ids for one chat turn from the question and history.

    The question is tokenized, left-truncated and wrapped as
    ``[usr_id] + question_ids + [bot_id]``. History turns are then prepended
    from the most recent to the oldest until adding another turn would reach
    the length budget.

    Args:
        tokenizer: callable returning a mapping with an ``"input_ids"`` list.
        question: text of the current user question.
        history: sequence of messages with ``"role"`` (``"user"`` / ``"bot"``)
            and ``"input_ids"``; it is deep-copied and never mutated.
        generation_config: object providing ``seq_length``, ``max_new_tokens``,
            ``max_decode_length`` and ``eos_token_id``.
        usr_id: token id that opens a user turn.
        bot_id: token id that opens a bot turn.

    Returns:
        list: the prompt token ids.

    Raises:
        ValueError: if the computed length budget is smaller than 3.
    """
    # first tokenize question
    q_token = tokenizer(question)
    qa_history = copy.deepcopy(history)

    # get the max length we should build our inputs in
    model_max_length = generation_config.seq_length
    build_max_length = max(0, model_max_length - generation_config.max_new_tokens) \
        if generation_config.max_new_tokens else max(0, generation_config.max_decode_length)
    if build_max_length < 3:
        raise ValueError("the model can not meet the requirements of input length,Please check config")

    # trunc left
    input_tokens = [usr_id] + q_token["input_ids"][-build_max_length + 1:] + [bot_id]

    while qa_history:
        message = qa_history.pop()
        if message["role"] == "user":
            tokens = [usr_id] + message["input_ids"]
        elif message["role"] == "bot":
            tokens = [bot_id] + message["input_ids"] + [generation_config.eos_token_id]
        else:
            tokens = []
        # Compare against the *current* prompt length. The original kept the
        # length measured before the loop, so history could overflow the budget.
        if len(tokens) + len(input_tokens) >= build_max_length:
            break
        input_tokens = tokens + input_tokens
    return input_tokens


def main():
    """Load the model from the CLI arguments and answer every question in the input file.

    Reads the module-level ``args`` namespace. ``--input_file`` is read line by
    line as JSON objects with an ``"input"`` field; blank lines are skipped.
    Questions are answered in order and share one running chat history.
    """
    input_questions = []
    with open(args.input_file, 'r', encoding='utf-8') as input_file:
        for line in input_file:
            if not line.strip():
                continue
            dic = json.loads(line)
            input_questions.append(dic["input"])
    # set model config
    config = MindFormerConfig(args.yaml_file)
    config.context.device_id = 0
    if args.checkpoint_path:
        config.load_checkpoint = args.checkpoint_path
    config.use_parallel = True
    # initialize the environment
    set_output_path(args.output_chat)
    build_context(config)
    build_parallel_config(config)

    # build tokenizer
    tokenizer = TelechatTokenizer(args.vocab_file_path, fast_tokenizer=True, trust_remote_code=True)

    model_config = config.model.model_config
    model_config.parallel_config = config.parallel_config
    model_config.batch_size = 1
    model_config.run_mode = "predict"
    model_config.use_past = args.use_past
    model_config.use_flash_attention = True
    model_config.user_token_id = tokenizer.convert_tokens_to_ids(args.user_token)
    model_config.bot_token_id = tokenizer.convert_tokens_to_ids(args.bot_token)
    model_config.max_new_tokens = None

    model_config = TelechatConfig(**model_config)

    # build model from config
    model = TelechatForCausalLM(model_config)
    ms_model = Model(model)
    print(f"[INFO_config]: {model_config}")
    print("----------------Transform and load checkpoint----------------")
    seq_length = model_config.seq_length
    input_ids = Tensor(shape=(model_config.batch_size, seq_length), dtype=ms.int32, init=initializer.One())
    infer_data = model.prepare_inputs_for_predict_layout(input_ids)
    transform_and_load_checkpoint(config, ms_model, model, infer_data, do_predict=True)
    history = []
    for question in input_questions:
        print("question:", question)
        # The model config doubles as the generation config here.
        answer, history = chat(model, tokenizer, question, history, generation_config=model_config)
        print("answer:", answer)
        print("\n截至目前的聊天记录是:", history)
        print("\n")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--input_file', default='', type=str,
                        help='input to infer.')
    parser.add_argument('--vocab_file_path', default='', type=str,
                        help='tokenizer vocab file path.')
    parser.add_argument('--checkpoint_path', default='', type=str,
                        help='set checkpoint path.')
    parser.add_argument('--use_past', default=True, type=str2bool,
                        help='whether use past.')
    parser.add_argument('--yaml_file', default="", type=str,
                        help='predict yaml path')
    parser.add_argument('--output_chat', default="./output_chat", type=str,
                        help='chat output')
    parser.add_argument('--user_token', default="<_user>", type=str,
                        help='user_token')
    parser.add_argument('--bot_token', default="<_bot>", type=str,
                        help='bot_token')
    # ``main`` reads this namespace as a module-level global.
    args = parser.parse_args()
    main()
__all__ = ['TelechatModel', 'TelechatForCausalLM']


class TelechatPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = TelechatConfig
    base_model_prefix = "telechat"


class TelechatModel(TelechatPreTrainedModel):
    r"""
    Transformer decoder consisting of *config.num_layers* layers. Each layer is a `TelechatDecodeLayer`
    (or a `TelechatDecodeLayerInterleave` when fine-grain interleave is enabled).

    Args:
        config(TelechatConfig): the config of network

    Returns:
        output: Tensor, the output of telechat decoderlayer

    Examples:
        >>> from mindformers import TelechatModel
        >>> network = TelechatModel.from_pretrained('telechat_115b')
        >>> type(network)

    """

    def __init__(self,
                 config: TelechatConfig = None):
        super().__init__(config, auto_prefix=True)
        _check_config(config.parallel_config)
        self.dtype = config.compute_dtype
        self.hidden_size = config.hidden_size
        self.num_layers = config.num_layers
        self.n_head = config.num_heads
        self.head_dim = self.hidden_size // self.n_head
        self.pad_token_id = config.pad_token_id
        self.is_first_iteration = True
        self.use_past = config.use_past
        self.use_flash_attention = config.use_flash_attention

        self.embed_dropout_prob = config.embed_dropout_prob
        # NOTE(review): Dropout receives ``1 - embed_dropout_prob``; whether that is a
        # keep- or drop-probability depends on the Dropout layer -- confirm.
        self.embeddings_dropout = Dropout(1-self.embed_dropout_prob)

        self.concat = P.Concat(-1)
        self.cast = P.Cast()
        self.shape = P.Shape()
        self.reshape = P.Reshape()
        # default open internal kernel boost
        self.enable_asd_op = get_ms_enable_asd_op()
        logger.info("enable asd op:{}".format(self.enable_asd_op))
        self.use_rope_self_define = get_use_rope_self_define()

        self.freqs_mgr = FreqsMgr(head_dim=self.head_dim,
                                  seq_length=config.seq_length,
                                  max_position_embedding=config.max_position_embedding,
                                  rotary_dtype=config.rotary_dtype,
                                  theta=config.theta,
                                  scaling_factor=config.scaling_factor,
                                  extend_method=config.extend_method,
                                  parallel_config=config.parallel_config)
        self.casual_mask = LowerTriangularMaskWithDynamic(seq_length=config.seq_length,
                                                          compute_type=config.compute_dtype,
                                                          is_dynamic=config.is_dynamic,
                                                          pad_token_id=config.pad_token_id,
                                                          use_flash_attention=config.use_flash_attention,
                                                          use_attn_mask_compression=config.use_attn_mask_compression)
        self.tok_embeddings = TelechatEmbedding(vocab_table_size=config.vocab_size,
                                                sigma=config.sigma,
                                                mean=config.mean,
                                                embedding_size=config.hidden_size,
                                                param_init_type=config.embedding_init_type,
                                                parallel_optimizer=config.parallel_optimizer)
        self.fine_grain_interleave = check_fine_grain_interleave_valid(config.fine_grain_interleave,
                                                                       config.parallel_config)
        self.layers = nn.CellList()
        self.layer_setting = LayerSetting(config.num_layers,
                                          config.offset,
                                          config.parallel_config,
                                          config.pp_interleave_num)
        # Build the decoder stack; the layer class depends on fine-grain interleave.
        for layer_id in range(config.num_layers):
            if self.fine_grain_interleave:
                layer = TelechatDecodeLayerInterleave(config.run_mode,
                                                      config.batch_size,
                                                      config.seq_length,
                                                      layer_id,
                                                      dim=config.hidden_size,
                                                      n_heads=config.num_heads,
                                                      num_layers=config.num_layers,
                                                      n_kv_heads=config.n_kv_heads,
                                                      hidden_dropout_prob=config.hidden_dropout_prob,
                                                      attention_dropout_prob=config.attention_dropout_prob,
                                                      intermediate_size=config.intermediate_size,
                                                      ffn_dim_multiplier=config.ffn_dim_multiplier,
                                                      norm_eps=config.rms_norm_eps,
                                                      qkv_has_bias=config.qkv_has_bias,
                                                      wo_has_bias=config.wo_has_bias,
                                                      compute_dtype=config.compute_dtype,
                                                      layernorm_compute_dtype=config.layernorm_compute_type,
                                                      softmax_compute_dtype=config.softmax_compute_type,
                                                      rotary_dtype=config.rotary_dtype,
                                                      param_init_type=config.param_init_type,
                                                      res_dtype=config.res_dtype,
                                                      use_flash_attention=config.use_flash_attention,
                                                      is_dynamic=config.is_dynamic,
                                                      use_rope_slice=config.use_rope_slice,
                                                      fine_grain_interleave=config.fine_grain_interleave,
                                                      parallel_config=config.parallel_config)
            else:
                layer = TelechatDecodeLayer(config.run_mode,
                                            layer_id,
                                            dim=config.hidden_size,
                                            n_heads=config.num_heads,
                                            n_kv_heads=config.n_kv_heads,
                                            sigma=config.sigma,
                                            mean=config.mean,
                                            hidden_dropout_prob=config.hidden_dropout_prob,
                                            attention_dropout_prob=config.attention_dropout_prob,
                                            intermediate_size=config.intermediate_size,
                                            multiple_of=config.multiple_of,
                                            ffn_dim_multiplier=config.ffn_dim_multiplier,
                                            norm_eps=config.rms_norm_eps,
                                            qkv_has_bias=config.qkv_has_bias,
                                            wo_has_bias=config.wo_has_bias,
                                            compute_dtype=config.compute_dtype,
                                            layernorm_compute_dtype=config.layernorm_compute_type,
                                            softmax_compute_dtype=config.softmax_compute_type,
                                            rotary_dtype=config.rotary_dtype,
                                            param_init_type=config.param_init_type,
                                            res_dtype=config.res_dtype,
                                            use_past=config.use_past,
                                            use_flash_attention=config.use_flash_attention,
                                            use_attn_mask_compression=config.use_attn_mask_compression,
                                            block_size=config.block_size,
                                            num_blocks=config.num_blocks,
                                            is_dynamic=config.is_dynamic,
                                            use_rope_slice=config.use_rope_slice,
                                            parallel_config=config.parallel_config)
            self.layer_setting(layer, layer_id)
            self.layers.append(layer)
        self.norm_out = LlamaRMSNorm(config.hidden_size, config.rms_norm_eps,
                                     compute_type=config.layernorm_compute_type)
        dp = config.parallel_config.data_parallel
        # Manual pipeline/shard setup is skipped under auto-parallel with sharding propagation.
        if not (_get_parallel_mode() in (ParallelMode.AUTO_PARALLEL,) and _is_sharding_propagation()):
            self.tok_embeddings.pipeline_stage = 0
            if config.parallel_config.pipeline_stage > 1:
                self.norm_out.pipeline_stage = config.parallel_config.pipeline_stage - 1
                self.tok_embeddings.set_comm_fusion(2)
                self.norm_out.set_comm_fusion(2)
            else:
                self.tok_embeddings.set_comm_fusion(config.parallel_config.gradient_aggregation_group)
                self.norm_out.set_comm_fusion(config.parallel_config.gradient_aggregation_group)

            self.tok_embeddings.shard(config.parallel_config)
            self.casual_mask.shard(config.parallel_config)
            self.concat.shard(((dp, 1, 1, 1), (dp, 1, 1, 1)))
            if self.fine_grain_interleave:
                self.norm_out.shard((dp, 1))
            else:
                self.norm_out.shard((dp, 1, 1))

    # pylint: disable=W0613
    def construct(self, tokens: Tensor, batch_valid_length=None, batch_index=None, zactivate_len=None,
                  block_tables=None, slot_mapping=None, prefix_keys_values=None):
        """
        Forward of telechat model.

        Args:
            tokens: the tokenized inputs with datatype int32
            batch_valid_length(Tensor): the past calculated the index with datatype int32, used for incremental
                prediction. Tensor of shape :math:`(batch_size,)`. Default None.
            batch_index: unused here (see the pylint W0613 suppression).
            zactivate_len: unused here (see the pylint W0613 suppression).
            block_tables (Tensor[int64]): Store mapping tables for each sequence.
            slot_mapping (Tensor[int32]): Store token cache physical slot index.
            prefix_keys_values: optional per-layer prefix key/values; when given, a zero mask of the
                prefix length is concatenated in front of the causal mask.
        Returns:
            output: Tensor, the output of telechat decoderlayer
        """
        # preprocess
        bs, seq_len = self.shape(tokens)
        mask = None
        if self.use_past:
            if self.is_first_iteration:
                if self.use_rope_self_define:
                    freqs_cis = self.freqs_mgr(seq_len)
                else:
                    freqs_cis = self.freqs_mgr.prefill(bs, seq_len)

                if self.use_flash_attention:
                    if self.enable_asd_op:  # only support fp16
                        mask = self.casual_mask(tokens)  # mask: [bs, seq, seq]
                        mask = self.cast(mask, mstype.float16)
                else:
                    mask = self.casual_mask(tokens)  # mask: [bs, seq, seq]

                if prefix_keys_values is not None:
                    if mask is None:
                        mask = self.casual_mask(tokens)
                    prefix_length = prefix_keys_values[0].shape[2]
                    prefix_mask = Tensor(np.zeros((bs, 1, seq_len, prefix_length)), dtype=mask.dtype)
                    mask = self.concat((prefix_mask, mask))
            else:
                freqs_cis = self.freqs_mgr.increment(batch_valid_length)
        else:
            mask = self.casual_mask(tokens)
            freqs_cis = self.freqs_mgr(seq_len)
            if prefix_keys_values is not None:
                prefix_length = prefix_keys_values[0].shape[2]
                prefix_mask = Tensor(np.zeros((bs, 1, seq_len, prefix_length)), dtype=mask.dtype)
                mask = self.concat((prefix_mask, mask))

        # tokens: [bs, seq/1]
        h = self.tok_embeddings(tokens)
        h = self.cast(h, self.dtype)
        h = self.embeddings_dropout(h)
        h = self.reshape(h, (bs, seq_len, self.hidden_size))
        # h: [bs, seq/1, hidden_dim]
        for i in range(self.num_layers):
            prefix_kv = prefix_keys_values[i] if prefix_keys_values is not None else None
            h = self.layers[i](h, freqs_cis, mask, batch_valid_length=batch_valid_length, block_tables=block_tables,
                               slot_mapping=slot_mapping, prefix_keys_values=prefix_kv)
        output = self.norm_out(h)
        return output


@MindFormerRegister.register(MindFormerModuleType.MODELS)
class TelechatForCausalLM(TelechatPreTrainedModel):
    r"""
    Provide telechat training loss or logits through network.

    Args:
        config (TelechatConfig): The config of telechat model.

    Returns:
        output: Tensor, the output of telechat decoderlayer

    Examples:
        >>> from mindformers.models.telechat import TelechatConfig, TelechatForCausalLM
        >>> config = TelechatConfig(batch_size=2)
        >>> network = TelechatForCausalLM(config=config)
        >>> type(network)

        >>> from mindformers import TelechatForCausalLM
        >>> network = TelechatForCausalLM.from_pretrained('telechat_115b')
        >>> type(network)

    """

    @lazy_inline
    def __init__(self, config: TelechatConfig = None):
        super(TelechatForCausalLM, self).__init__(config, auto_prefix=True)
        _check_config(config.parallel_config)
        self.config = config
        self.run_mode = config.run_mode
        self.ignore_token_id = config.ignore_token_id
        self.pad_token_id = config.pad_token_id
        self.use_past = config.use_past
        self.vocab_size = config.vocab_size
        self.is_first_iteration = True

        self.shape = P.Shape()
        self.reshape = P.Reshape()
        if config.is_dynamic:
            self.reshape.add_prim_attr("skip_redistribution", True)
        self.cast = P.Cast()
        self.slice = P.StridedSlice()
        self.not_equal = P.NotEqual()
        self.mul = P.Mul()
        self.add = P.Add()
        self.ones = P.Ones()
        self.gather = P.Gather(1)
        self.sub_batch_valid_len = P.Sub()
        self.model = TelechatModel(config=config)
        self.lm_head = Linear(in_channels=config.hidden_size,
                              out_channels=config.vocab_size,
                              has_bias=False,
                              compute_dtype=config.compute_dtype,
                              param_init_type=config.param_init_type,
                              weight_init="normal")  # meta default: xavier_normal

        mp = config.parallel_config.model_parallel
        vocab_size = config.vocab_size
        # The loss gets its own copy so the adjustments below do not touch the model's parallel config.
        loss_parallel_config = copy.deepcopy(config.parallel_config)
        if vocab_size % mp != 0:
            logger.warning("The vocab size of Loss is: %s, it is not divide by model_parallel: %s",
                           vocab_size, mp)
            logger.warning("Now, the model_parallel num of Loss will be changed: mp = 1")
            loss_parallel_config.model_parallel = 1
        loss_parallel_config.data_parallel *= loss_parallel_config.context_parallel
        self.loss = CrossEntropyLoss(parallel_config=loss_parallel_config)

        dp = config.parallel_config.data_parallel
        mp = config.parallel_config.model_parallel
        if not (_get_parallel_mode() in (ParallelMode.AUTO_PARALLEL,) and _is_sharding_propagation()):
            self.slice.shard(((dp, 1),))
            self.not_equal.shard(((dp, 1), ()))
            self.mul.shard(((dp, 1), (dp, 1)))
            self.add.shard(((dp, 1), ()))
            self.gather.shard(((dp, 1, 1), (dp,)))
            self.sub_batch_valid_len.shard(((1,), ()))
            if config.parallel_config.vocab_emb_dp or (vocab_size % mp != 0):
                self.lm_head.shard(strategy_matmul=((dp, 1), (1, 1)))
            else:
                self.lm_head.shard(strategy_matmul=((dp, 1), (mp, 1)))
            if config.parallel_config.pipeline_stage > 1:
                self.lm_head.pipeline_stage = config.parallel_config.pipeline_stage - 1

        self.predict_run_mode = get_predict_run_mode()

        logger.info("Predict run mode:{}".format(self.predict_run_mode))

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        """Return the model inputs for generation as ``{"input_ids": int32 Tensor}``.

        When the config is dynamic and ``origin_inputs`` is passed, those inputs replace ``input_ids``.
        """
        if self.config.is_dynamic and "origin_inputs" in kwargs:
            input_ids = kwargs["origin_inputs"]
        return {
            "input_ids": Tensor(input_ids, mstype.int32)
        }

    # pylint: disable=W0613
    def prepare_inputs_for_predict_layout(self, input_ids, **kwargs):
        """Get Telechat model input tuple for transform ckpt.

        The tuple follows the positional order of ``construct``; only ``input_ids``, ``labels``,
        ``slot_mapping`` and ``prefix_keys_values`` are filled, every other position is ``None``.
        """
        input_ids = Tensor(input_ids, mstype.int32)
        labels = Tensor(kwargs["labels"]) if "labels" in kwargs else None
        bs, seq = input_ids.shape[0], input_ids.shape[1]
        slot_mapping = Tensor(np.ones(shape=tuple([bs*seq])), mstype.int32)
        prefix_keys_values = Tensor(kwargs["prefix_keys_values"]) if "prefix_keys_values" in kwargs else None
        return input_ids, labels, None, None, None, None, None, None, None, None, None, slot_mapping, prefix_keys_values

    def set_dynamic_inputs(self, **kwargs):
        """Set dynamic inputs.

        Args:
            **kwargs: ``have_prefix_keys_values`` (bool, default False) additionally registers a dynamic
                ``prefix_keys_values`` input.
        """
        dynamic_input_ids = Tensor(shape=[None, None], dtype=mstype.int32)
        dynamic_batch_valid_length = Tensor(shape=[None, None], dtype=mstype.int32)
        dynamic_block_tables = Tensor(shape=[None, None], dtype=mstype.int32)
        dynamic_slot_mapping = Tensor(shape=[None], dtype=mstype.int32)
        # ``kwargs`` is a dict: getattr() on it never finds the key, so the flag was always False.
        have_prefix_keys_values = kwargs.get("have_prefix_keys_values", False)
        if have_prefix_keys_values:
            dynamic_prefix_keys_values = Tensor(shape=[2, None, None, None, None], dtype=mstype.float16)
            self.set_inputs(dynamic_input_ids, None, None, None, None, None, None,
                            dynamic_batch_valid_length, None, None, dynamic_block_tables,
                            dynamic_slot_mapping, dynamic_prefix_keys_values)
        else:
            self.set_inputs(dynamic_input_ids, None, None, None, None, None, None,
                            dynamic_batch_valid_length, None, None, dynamic_block_tables,
                            dynamic_slot_mapping, None)
        logger.info("Set dynamic input for telechat.")

    def add_flags_custom(self, is_first_iteration):
        """Add customized attributes for specific cells in the model."""
        self.add_flags(is_first_iteration=is_first_iteration)
        self.model.add_flags(is_first_iteration=is_first_iteration)
        for layer in self.model.layers:
            layer.add_flags(is_first_iteration=is_first_iteration)
            layer.attention.infer_attention.add_flags(is_first_iteration=is_first_iteration)

    # pylint: disable=W0613
    def construct(self, input_ids, labels=None, input_position=None, position_ids=None, attention_mask=None,
                  input_embeds=None, init_reset=None, batch_valid_length=None, batch_index=None, zactivate_len=None,
                  block_tables=None, slot_mapping=None, prefix_keys_values=None):
        r"""
        TelechatForCausalLM forward.

        Args:
            input_ids(Tensor): the tokenized inputs with datatype int32, Tensor of shape :math:`(batch, seq\_length)`.
            labels(Tensor): the tokenized labels with datatype int32, Tensor of shape :math:`(batch, seq\_length)`.
                In training its slice ``[:, 1:]`` is used as the loss mask, while the targets are taken
                from ``input_ids[:, 1:]``.
            input_position(Tensor): current position, used by model.predict.
            position_ids(Tensor): Reserved param, not used.
            attention_mask(Tensor): Reserved param, not used.
            input_embeds(Tensor): Reserved param, not used.
            init_reset(bool, optional): A bool tensor with shape [1], used to clear the past key parameter and
                past value parameter used in the incremental prediction. Default True.
            batch_valid_length(Tensor): the past calculated the index with datatype int32, used for incremental
                prediction. Tensor of shape :math:`(batch_size,)`. Default None.
            block_tables (Tensor[int64]): Store mapping tables for each sequence.
            slot_mapping (Tensor[int32]): Store token cache physical slot index.
        Returns:
            Tensor: The loss or (logits, tokens, input_mask) of the network.
        """
        bsz, seqlen = self.shape(input_ids)
        if self.use_past:
            if not isinstance(batch_valid_length, Tensor):
                batch_valid_length = self.ones((bsz,), mstype.int32)
        if self.training:
            # Training feeds all but the last token; the targets are shifted by one below.
            tokens = self.slice(input_ids, (0, 0), (bsz, seqlen - 1), (1, 1))
        else:
            tokens = input_ids
        if batch_valid_length is not None:
            batch_valid_length = self.reshape(batch_valid_length, (-1,))
        output = self.model(tokens, batch_valid_length, batch_index, zactivate_len, block_tables,
                            slot_mapping, prefix_keys_values)
        pre_gather = (not self.use_past or self.is_first_iteration) and batch_valid_length is not None
        if pre_gather:
            # Keep only the hidden state at index ``batch_valid_length - 1`` of each sequence.
            output = self.gather(output, self.sub_batch_valid_len(batch_valid_length, 1), 1)
        logits = self.lm_head(output)

        input_mask = self.cast(self.not_equal(tokens, self.pad_token_id), mstype.float32)
        if not self.training:
            logits = self.cast(logits, mstype.float32)
            if self.predict_run_mode:
                return logits
            return logits, tokens, input_mask
        input_mask = self.slice(labels, (0, 1), (bsz, seqlen), (1, 1))
        labels = self.slice(input_ids, (0, 1), (bsz, seqlen), (1, 1))
        if logits.ndim > 2:
            logits = self.reshape(logits, (-1, logits.shape[-1]))
        logits = self.cast(logits, mstype.float32)
        labels = self.reshape(labels, (-1,))
        input_mask = self.reshape(input_mask, (-1,))
        loss = self.loss(logits, labels, input_mask)
        return loss

    def kvcache(self, layer_idx):
        """Return ``(key_cache, value_cache)`` of the paged-attention manager of layer ``layer_idx``."""
        key_cache = self.model.layers[layer_idx].attention.infer_attention.paged_attention_mgr.key_cache
        value_cache = self.model.layers[layer_idx].attention.infer_attention.paged_attention_mgr.value_cache
        return key_cache, value_cache
__all__ = ['TelechatConfig']


@MindFormerRegister.register(MindFormerModuleType.CONFIG)
class TelechatConfig(PretrainedConfig):
    """
    Telechat config class which defines the model size and runtime options.

    Every argument is stored on the instance under the same name; dtype
    strings are converted with ``convert_mstype`` first.

    Args:
        batch_size (int): batch size for input data, use in predict. Default 1.
        seq_length (int): The sequence length of input_ids. Default 2048.
        hidden_size (int): Dimensionality of the hidden layers. Default 4096.
        num_layers (int): Number of hidden layers in the Transformer. Default 32.
        num_heads (int): Number of attention heads for each attention layer. Default 32.
        embed_dropout_prob (float): probability setting for the embedding dropout. Default 1.0.
            NOTE(review): keep- vs drop-probability is decided by the consuming layer -- confirm.
        hidden_dropout_prob (float): probability setting for the hidden dropout. Default 1.0.
        attention_dropout_prob (float): probability setting for the attention dropout. Default 1.0.
        n_kv_heads (Optional[int]): Define multi group head attention heads number. Default None.
        max_position_embedding (Optional[int]): falls back to ``seq_length`` when empty. Default None.
        intermediate_size (Optional[int]): size of the feed-forward intermediate layer. Default None.
        vocab_size (int): Vocabulary size of the model. Default 32000.
        multiple_of (int): Define SwiGLU hidden layer size multiples. Default 256.
        ffn_dim_multiplier (Optional[int]): Define ffn layer dim multiples. Default None.
        rms_norm_eps (float): The epsilon value of the denominator. Default 1e-5.
        bos_token_id (int): The id of the *beginning-of-sequence* token. Default 1.
        eos_token_id (int): The id of the *end-of-sequence* token. Default 2.
        pad_token_id (int): The id of the *padding* token. Default 0.
        ignore_token_id (int): The id of the *ignoring* token. Default -100.
        theta (float): rotary embedding base (presumably; passed on unchanged). Default 10000.0.
        compute_dtype (str): Linear layer compute dtype. Default "float16".
        layernorm_compute_type (str): layernorm compute dtype. Default "float32".
        softmax_compute_type (str): softmax compute dtype. Default "float32".
        rotary_dtype (str): rope compute dtype. Default "float32".
        param_init_type (str): parameter initial dtype. Default "float16".
        embedding_init_type (Optional[str]): embedding parameter dtype; falls back to
            ``param_init_type`` when None. Default None.
        res_dtype (str): residual dtype (presumably; passed on unchanged). Default "float32".
        qkv_has_bias (bool): Whether the Query, Key, and Value projection has bias. Default False.
        wo_has_bias (bool): Whether the attention output projection has bias. Default True.
        parallel_config (Union[dict, TransformerOpParallelConfig]): The parallel configure; a dict is
            converted to ``TransformerOpParallelConfig``. Default ``default_transformer_config``.
        use_past (bool): Whether the model should use the past last key/values attentions
            (if applicable to the model) to speed up decoding. Default False.
        extend_method (str): The extend method of seq length of inference. Default "None".
        scaling_factor (float): scaling factor used together with ``extend_method``. Default 1.0.
        is_dynamic (bool): Whether dynamic shape is used. Default False.
        use_rope_slice (bool): Whether rope slicing is used. Default False.
        use_flash_attention (bool): Whether enable flash attention ops. Default False.
        use_attn_mask_compression (bool): Whether attention-mask compression is used. Default False.
        parallel_optimizer (bool): Whether the parallel optimizer is enabled. Default False.
        fine_grain_interleave (int): fine-grain interleave setting. Default 1.
        pp_interleave_num (int): pipeline interleave number. Default 1.
        offset (int): Offset of transformer layer when set pipeline stage number. Default 0.
        checkpoint_name_or_path (str): checkpoint path or name used to load to the network. Default "".
        repetition_penalty (float): The parameter for repetition penalty. 1.0 means no penalty. See [this
            paper](https://arxiv.org/pdf/1909.05858.pdf) for more details. Default 1.0.
        max_decode_length (int): The maximum length the generated tokens can have. Corresponds to the
            length of the input prompt + `max_new_tokens`. Its effect is overridden by `max_new_tokens`,
            if also set. Default 1024.
        block_size (int): The maximum number of tokens in one block can have when using paged attention.
            Default 16.
        num_blocks (int): The maximum number of blocks when using paged attention. Default 512.
        top_k (int): The number of highest probability vocabulary tokens to keep for top-k-filtering.
            Default 5.
        top_p (float): If set to float < 1, only the smallest set of most probable tokens with
            probabilities that add up to `top_p` or higher are kept for generation. Default 1.0.
        do_sample (bool): Whether or not to use sampling ; use greedy decoding otherwise. Default True.
        quant (str): quantization setting, stored unchanged. Default "".
        sigma (float): initializer parameter handed on to the layers (presumably std -- confirm).
            Default 0.0048.
        mean (float): initializer parameter handed on to the layers (presumably mean -- confirm).
            Default 0.0.
        **kwargs: forwarded to ``PretrainedConfig``; ``run_mode`` is also read from here
            (default "finetune").

    Returns:
        Class, TelechatConfig.
    """

    model_type = "telechat"

    @args_type_check(parallel_config=(dict, TransformerOpParallelConfig))
    def __init__(self,
                 batch_size: int = 1,
                 seq_length: int = 2048,
                 hidden_size: int = 4096,
                 num_layers: int = 32,
                 num_heads: int = 32,
                 embed_dropout_prob: float = 1.0,
                 hidden_dropout_prob: float = 1.0,
                 attention_dropout_prob: float = 1.0,
                 n_kv_heads: Optional[int] = None,
                 max_position_embedding: Optional[int] = None,
                 intermediate_size: Optional[int] = None,
                 vocab_size: int = 32000,  # defined later by tokenizer
                 multiple_of: int = 256,  # make SwiGLU hidden layer size multiple of large power of 2
                 ffn_dim_multiplier: Optional[int] = None,
                 rms_norm_eps: float = 1e-5,
                 bos_token_id: int = 1,
                 eos_token_id: int = 2,
                 pad_token_id: int = 0,
                 ignore_token_id: int = -100,
                 theta: float = 10000.0,
                 compute_dtype: str = "float16",
                 layernorm_compute_type: str = "float32",
                 softmax_compute_type: str = "float32",
                 rotary_dtype: str = "float32",
                 param_init_type: str = "float16",
                 embedding_init_type=None,
                 res_dtype: str = "float32",
                 qkv_has_bias: bool = False,
                 wo_has_bias: bool = True,
                 parallel_config: Union[dict, TransformerOpParallelConfig] = default_transformer_config,
                 use_past: bool = False,
                 extend_method: str = "None",
                 scaling_factor: float = 1.0,
                 is_dynamic: bool = False,
                 use_rope_slice: bool = False,
                 use_flash_attention: bool = False,
                 use_attn_mask_compression: bool = False,
                 parallel_optimizer: bool = False,
                 fine_grain_interleave: int = 1,
                 pp_interleave_num: int = 1,
                 offset: int = 0,
                 checkpoint_name_or_path: str = "",
                 repetition_penalty: float = 1.0,
                 max_decode_length: int = 1024,
                 block_size: int = 16,
                 num_blocks: int = 512,
                 top_k: int = 5,
                 top_p: float = 1.0,
                 do_sample: bool = True,
                 quant: str = "",
                 sigma: float = 0.0048,
                 mean: float = 0.0,
                 **kwargs):
        super().__init__(**kwargs)
        if isinstance(parallel_config, dict):
            parallel_config = TransformerOpParallelConfig(**parallel_config)
        self.run_mode = kwargs.get("run_mode", "finetune")

        # ---- model shape ----
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.n_kv_heads = n_kv_heads
        self.max_position_embedding = max_position_embedding or seq_length
        self.intermediate_size = intermediate_size
        self.multiple_of = multiple_of
        self.ffn_dim_multiplier = ffn_dim_multiplier
        self.rms_norm_eps = rms_norm_eps
        self.qkv_has_bias = qkv_has_bias
        self.wo_has_bias = wo_has_bias

        # ---- dropout and initialisation ----
        self.embed_dropout_prob = embed_dropout_prob
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_dropout_prob = attention_dropout_prob
        self.sigma = sigma
        self.mean = mean

        # ---- dtypes (strings converted to MindSpore types) ----
        self.param_init_type = convert_mstype(param_init_type)
        self.embedding_init_type = (
            self.param_init_type if embedding_init_type is None else convert_mstype(embedding_init_type)
        )
        self.layernorm_compute_type = convert_mstype(layernorm_compute_type)
        self.softmax_compute_type = convert_mstype(softmax_compute_type)
        self.rotary_dtype = convert_mstype(rotary_dtype)
        self.compute_dtype = convert_mstype(compute_dtype)
        self.res_dtype = convert_mstype(res_dtype)

        # ---- special tokens ----
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.pad_token_id = pad_token_id
        self.ignore_token_id = ignore_token_id

        # ---- parallelism and runtime switches ----
        self.parallel_config = parallel_config
        self.checkpoint_name_or_path = checkpoint_name_or_path
        self.use_past = use_past
        self.extend_method = extend_method
        self.scaling_factor = scaling_factor
        self.theta = theta
        self.is_dynamic = is_dynamic
        self.use_rope_slice = use_rope_slice
        self.use_flash_attention = use_flash_attention
        self.use_attn_mask_compression = use_attn_mask_compression
        self.parallel_optimizer = parallel_optimizer
        self.fine_grain_interleave = fine_grain_interleave
        self.pp_interleave_num = pp_interleave_num
        self.offset = offset
        self.block_size = block_size
        self.num_blocks = num_blocks
        self.quant = quant

        # ---- generation ----
        self.repetition_penalty = repetition_penalty
        self.max_decode_length = max_decode_length
        self.top_k = top_k
        self.top_p = top_p
        self.do_sample = do_sample
+# ============================================================================ +"""Telechat fine grain interleave transformer Telechat's APIs.""" + +from typing import Optional +import math + +import mindspore as ms +from mindspore import nn, __version__ +import mindspore.common.dtype as mstype +from mindspore.common.tensor import Tensor +from mindspore.context import ParallelMode +from mindspore.ops import operations as P +from mindspore.parallel._utils import _get_parallel_mode, _is_sharding_propagation + +from mindformers.models.llama.llama_layer import LlamaRMSNorm +from mindformers.modules.layers import _check_input_dtype, Dropout, RotaryEmbedding +from mindformers.modules.transformer import TransformerOpParallelConfig +from mindformers.modules.flash_attention import FlashAttention +from telechat_layer import TelechatLinear, TelechatFeedForward + +__all__ = ['TelechatDecodeLayerInterleave'] + +class _MicroBatch(nn.Cell): + """ + transform mini-batch to micro-batch in pipeline parallel. + + Args: + params (micro_size): The number of micro-batch. 
+ """ + def __init__(self, micro_size, input_size, axis_list): + super(_MicroBatch, self).__init__() + self.shape = P.Shape() + self.micro_size = micro_size + self.strided_slice_list = [] + for _ in range(input_size): + self.strided_slice_list.append(P.StridedSlice()) + self.axis_list = axis_list + + def construct(self, i, *inputs): + """construct for _MicroBatch.""" + micro_inputs = () + k = 0 + for each_input in inputs: + input_shape = self.shape(each_input) + micro_batch_begin = i * input_shape[self.axis_list[k]] // self.micro_size + micro_batch_end = (i + 1) * input_shape[self.axis_list[k]] // self.micro_size + strided_slice_begin = () + strided_slice_strides = () + strided_slice_end = () + for j in range(len(input_shape)): + strided_slice_strides += (1,) + if j == self.axis_list[k]: + strided_slice_begin += (micro_batch_begin,) + strided_slice_end += (micro_batch_end,) + else: + strided_slice_begin += (0,) + strided_slice_end += (input_shape[j],) + + micro_input = self.strided_slice_list[k](each_input, strided_slice_begin,\ + strided_slice_end, strided_slice_strides) + micro_inputs += (micro_input,) + k += 1 + return micro_inputs + + +class TelechatAttentionInterleave(nn.Cell): + r""" + This is an implementation of multihead attention in Telechat. + + Args: + - **batch_size** (int): The batch size of the input tensor when do increnmental prediction. Should be a + positive value. + When do training or prediction, the argument will not work and the user can just pass None to the + argument. + - **src_seq_length** (int): The sequence length of the query vector. + - **tgt_seq_length** (int): The sequence length of the key and value vector. + - **dim** (int): The hidden size of the input. + - **head_dim** (int): The dim of head. + - **n_heads** (int): The number of the heads. + - **compute_dtype** (dtype.Number): The computation type of dense. Default mstype.float16. + Should be mstype.float32 or mstype.float16. 
+ - **softmax_compute_dtype** (dtype.Number): The type of softmax computation module. Default mstype.float32. + Should be mstype.float32 or mstype.float16. + - **param_init_type** (dtype.Number): The parameter initialization type of the module. Default mstype. + float32. Should be mstype.float32 or mstype.float16. + - **qkv_has_bias** (bool): Whether Q/K/V in attention has bias or not. + - **use_past** (bool): Use the past state to compute, used for incremental prediction. + For example, if we have two words and want to generate ten more words. + We just need to compute the two words' state only once, and generate the next word one by one. + When use_past is True, there are two steps to run the prediction. + In the first step, set the is_first_iteration to be True by + `model.add_flags_recursive(is_first_iteration=True)`, and pass the full inputs. Then, set the + is_first_iteration to be False by `model.add_flags_recursive(is_first_iteration=False)`. At this moment, + pass the single step's input tensor, and loop it. Default False. + - **parallel_config** (OpParallelConfig): The parallel configuration. Default `default_dpmp_config`, + an instance of `OpParallelConfig` with default args. + + Inputs: + - **x** (Tensor) - The input tokens with shape (batch_size, src_seq_length, hidden_size) or + (batch_size * src_seq_length, hidden_size), if the use_past is False or is_first_iteration=True. + Otherwise, must be (batch_size, 1, hidden_size) + - **freqs_cis** (Tuple) - The precomputed freqs and mask for rotary position embedding used in attention. + - **attention_mask** (Tensor) - If the use_past is False or is_first_iteration=True, the attention mask + matrix should be (batch_size, src_seq_length, tgt_seq_length), or None. None means there will be no mask + in softmax computation. Otherwise, the mask must be (batch_size, 1, tgt_seq_length) + - **key_past** (Tensor) - Float16 tensor with shape (batch_size, num_heads, head_dim, tgt_seq_length).
+ The past calculated key vector. Used for incremental prediction when the use_past is True. + Default None. + - **value_past** (Tensor) - Float16 tensor with shape (batch_size, num_heads, tgt_seq_length, + head_dim). + The past calculated value vector. Used for incremental prediction when the use_past is True. + Default None. + - **batch_valid_length** (Tensor) - Int32 tensor with shape (batch_size,) the past calculated the index. + Used for incremental prediction when the use_past is True. Default None. + + Outputs: + Tuple, a tuple contains(`output`, `layer_present`) + + - **output** (Tensor) - Tensor, the float tensor of the output of the layer with + shape (batch_size, src_seq_length, hidden_size) or (batch_size * src_seq_length, hidden_size), + if the use_past is False or is_first_iteration=True. Otherwise, it will be (batch_size, 1, hidden_size). + + - **layer_present** (Tuple) - A tuple of the Tensor of the projected key and value vector with + ((batch_size, num_heads, head_dim, tgt_seq_length), + (batch_size, num_heads, tgt_seq_length, head_dim)). 
+ """ + def __init__(self, + run_mode, + batch_size, + seq_length, + dim: int = 512, + n_heads: int = 8, + sigma: float = 0.0048, + mean: float = 0.0, + hidden_dropout_prob: float = 1.0, + attention_dropout_prob: float = 1.0, + n_kv_heads: Optional[int] = None, + compute_dtype=mstype.float16, + softmax_compute_dtype=mstype.float32, + rotary_dtype=mstype.float32, + param_init_type=mstype.float32, + qkv_has_bias=False, + wo_has_bias=True, + is_dynamic=False, + use_rope_slice=False, + use_flash_attention=False, + parallel_config=TransformerOpParallelConfig()): + super().__init__() + self.batch_size = batch_size + self.seq_length = seq_length + self.hidden_size = dim + self.n_head = n_heads + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_dropout_prob = attention_dropout_prob + self.head_dim = dim // n_heads + self.n_kv_head = n_heads if n_kv_heads is None else n_kv_heads + self.n_rep = self.n_head // self.n_kv_head + self.kv_dim = self.n_kv_head * self.head_dim + self.qkv_has_bias = qkv_has_bias + self.wo_has_bias = wo_has_bias + self.dtype = compute_dtype + self.softmax_dtype = softmax_compute_dtype + self.is_first_iteration = True + self.use_flash_attention = use_flash_attention + + if self.hidden_size % self.n_head != 0: + raise ValueError("For 'MultiHeadAttention', the class variable 'hidden_size' must be a multiple " + "of 'n_head', but got the hidden_size is {} and the n_head is {}." + .format(self.hidden_size, self.n_head)) + if self.n_kv_head % parallel_config.model_parallel != 0: + raise ValueError("For 'MultiHeadAttention', the class variable 'n_kv_head' must be a multiple of " + "'parallel_config.model_parallel', but got the n_kv_head is {} " + "and the parallel_config.model_parallel is {}." 
+ .format(self.n_kv_head, parallel_config.model_parallel)) + + self.inv_norm_factor = Tensor(1.0 / math.sqrt(self.head_dim), dtype=compute_dtype) + + self.shape = P.Shape() + self.reshape = P.Reshape() + self.transpose = P.Transpose() + self.merger_head_transpose = P.Transpose() + self.batch_matmul = P.BatchMatMul() + self.batch_matmul_q_k = P.BatchMatMul(transpose_b=True) + self.mul = P.Mul() + self.add = P.Add() + self.softmax = P.Softmax() + self.cast = P.Cast() + self.cast_attn = P.Cast() + self.tile_kv = P.Tile() + if run_mode == "predict": + self.split_kv = ms.ops.auto_generate.SplitWithSize() + self.split_kv.add_prim_attr("skip_redistribution", True) + else: + self.split_kv = P.Split(output_num=2, axis=-1) + self.apply_rotary_emb = RotaryEmbedding(self.head_dim, rotary_dtype, use_rope_slice=use_rope_slice) + self.attention_dropout = Dropout(1-self.attention_dropout_prob) + + self.wq = TelechatLinear(self.hidden_size, + self.hidden_size, + has_bias=qkv_has_bias, + sigma=sigma, + mean=mean, + compute_dtype=compute_dtype, + param_init_type=param_init_type, + skip_redistribution=is_dynamic) + self.wk_v = TelechatLinear(self.hidden_size, + self.n_kv_head * self.head_dim * 2, + has_bias=qkv_has_bias, + sigma=sigma, + mean=mean, + compute_dtype=compute_dtype, + param_init_type=param_init_type, + skip_redistribution=is_dynamic) + self.wo = TelechatLinear(in_channels=self.hidden_size, + out_channels=self.hidden_size, + has_bias=wo_has_bias, + sigma=sigma, + mean=mean, + compute_dtype=compute_dtype, + param_init_type=param_init_type, + skip_redistribution=is_dynamic, + keep_prob=1-self.hidden_dropout_prob) + + dp = parallel_config.data_parallel + mp = parallel_config.model_parallel + self.split_kv.shard(((dp, mp, 1),)) + if not (_get_parallel_mode() in (ParallelMode.AUTO_PARALLEL,) and _is_sharding_propagation()): + self.transpose.shard(((dp, 1, mp, 1),)) + self.merger_head_transpose.shard(((dp, mp, 1, 1),)) + self.batch_matmul_q_k.shard(((dp, mp, 1, 1), (dp, mp, 1, 
1))) + self.batch_matmul.shard(((dp, mp, 1, 1), (dp, mp, 1, 1))) + self.mul.shard(((dp, mp, 1, 1), ())) + self.add.shard(((dp, 1, 1, 1), (dp, mp, 1, 1))) + self.softmax.shard(((dp, mp, 1, 1),)) + self.tile_kv.shard(((dp, mp, 1, 1),)) + + self.apply_rotary_emb.shard(parallel_config) + + if self.qkv_has_bias: + self.wq.shard(((dp, 1), (mp, 1)), ((dp, mp), (mp,))) + self.wk_v.shard(((dp, 1), (mp, 1)), ((dp, mp), (mp,))) + else: + self.wq.shard(((dp, 1), (mp, 1))) + self.wk_v.shard(((dp, 1), (mp, 1))) + if self.wo_has_bias: + self.wo.shard(((dp, mp), (1, mp)), ((dp, 1), (1,))) + else: + self.wo.shard(((dp, mp), (1, mp))) + if parallel_config.use_seq_parallel and self.is_first_iteration: + if self.wo_has_bias: + self.wo.shard(((dp, mp), (1, mp)), ((dp * mp, 1), (1,)), out_strategy_matmul=((dp * mp, 1),)) + else: + self.wo.shard(((dp, mp), (1, mp)), out_strategy_matmul=((dp * mp, 1),)) + if parallel_config.recompute.select_recompute and not self.use_flash_attention: + self.apply_rotary_emb.recompute() + self.tile_kv.recompute() + self.batch_matmul_q_k.recompute() + self.mul.recompute() + self.add.recompute() + self.cast_attn.recompute() + self.softmax.recompute() + self.batch_matmul.recompute() + + if self.use_flash_attention: + self.flash_attention = FlashAttention(head_num=self.n_head, + pre_tokens=65536, + next_tokens=0, + input_layout='BNSD', + keep_prob=1. - attention_dropout_prob, + scale_value=1. 
/ math.sqrt(self.head_dim), + sparse_mode=0, + use_attention_mask=True) + self.flash_attention.shard(parallel_config) + + def compute_qkv(self, x): + """compute the qkv with interleave number""" + x = self.reshape(x, (-1, x.shape[-1])) + query = self.cast(self.wq(x), self.dtype) # dp, 1 -> dp, mp + key_value = self.cast(self.wk_v(x), self.dtype) # dp, 1 -> dp, mp + key_value = self.reshape(key_value, (-1, self.n_kv_head, self.head_dim * 2)) + if self.training: + key, value = self.split_kv(key_value) + else: + key, value = self.split_kv(key_value, (self.head_dim, self.head_dim), 2) + key = self.reshape(key, (-1, self.n_kv_head * self.head_dim)) + value = self.reshape(value, (-1, self.n_kv_head * self.head_dim)) + return query, key, value + + def cal_attn(self, query, key, value, mask, freqs_cis): + """cal_attn""" + query = self.reshape(query, (-1, self.seq_length, self.n_head, self.head_dim)) + key = self.reshape(key, (-1, self.seq_length, self.n_kv_head, self.head_dim)) + value = self.reshape(value, (-1, self.seq_length, self.n_kv_head, self.head_dim)) + + # [bs, seq/1, n_head/n_kv_head, head_dim] + query = self.transpose(query, (0, 2, 1, 3)) + key = self.transpose(key, (0, 2, 1, 3)) + value = self.transpose(value, (0, 2, 1, 3)) + + # [bs, n_head/n_kv_head, seq/1, head_dim] + query, key = self.apply_rotary_emb(query, key, freqs_cis) # dp, mp, 1, 1 + # kv share: [bs, n_kv_head, seq, head_dim] -> [bs, n_head, seq, head_dim] + bs, n_head, seq, head_dim = query.shape + n_kv_head = key.shape[1] + query = self.reshape(query, (bs, n_head, seq, head_dim)) + key = self.reshape(key, (bs, n_kv_head, seq, head_dim)) + value = self.reshape(value, (bs, n_kv_head, seq, head_dim)) + + # q, k, v: [bs, n_head, seq/1, head_dim], [bs, n_head, seq, head_dim], [bs, n_head, seq, head_dim] + if self.use_flash_attention: + attention = self.flash_attention(query, key, value, mask) + attention = self._merge_heads(attention) + else: + key = self._repeat_kv(key, self.n_rep) + value = 
self._repeat_kv(value, self.n_rep) + attention = self._attn(query, key, value, mask) + return attention + + def cal_output_proj(self, attention): + """cal_output_proj""" + output = self.wo(attention) # dp, mp -> dp, 1 / dp * mp, 1 + return output + + def _repeat_kv(self, x, rep): + """repeat_kv""" + if rep == 1: + return x + bs, n_kv_head, seqlen, head_dim = x.shape + x = self.reshape(x, (bs, n_kv_head, 1, seqlen * head_dim)) + x = self.tile_kv(x, (1, 1, rep, 1)) + x = self.reshape(x, (bs, n_kv_head * rep, seqlen, head_dim)) + return x + + def _merge_heads(self, x): + """ + convert a 4d input to a 2d or 3d output + + Inputs: + x: input tensor + + Output: + x_merge: the 2d output + """ + # [bs, n_head, seq/1, head_dim] + x = self.merger_head_transpose(x, (0, 2, 1, 3)) # dp,mp,1,1 -> dp,1,mp,1 + # [bs, seq/1, n_head, head_dim] + x_shape = x.shape + # [bs * seq/1, hidden_dim] + new_shape = (-1, x_shape[-2] * x_shape[-1]) + x_merge = self.reshape(x, new_shape) + return x_merge + + def _attn(self, query, key, value, mask): + """ + Get the weighted score along the seq_length + + Inputs: + query: the query matrix + key: the key matrix + value: the value matrix + mask: the attention mask adder matrix with shape (batch_size, + 1, seq_length, seq_length) + Outputs: + weighted_values: Tensor, the weighted sum scores + """ + # q, k: [bs, n_head, seq/1, head_dim], [bs, n_head, seq, head_dim] + score = self.batch_matmul_q_k(query, key) + # score: [bs, n_head, seq/1, seq] + score = self.mul(score, self.inv_norm_factor) + score = self.add(mask, score) + + attention_probs = self.softmax(self.cast_attn(score, self.softmax_dtype)) + attention_probs = self.attention_dropout(attention_probs) + # score, v: [bs, n_head, seq/1, seq], [bs, n_head, seq, head_dim] + weighted_values = self.batch_matmul(self.cast(attention_probs, self.dtype), value) + # [bs, n_head, seq/1, head_dim] + attention_merge = self._merge_heads(weighted_values) + # [bs, seq/1, hidden_dim] or [bs * seq/1, hidden_dim] + 
return attention_merge + + +class TelechatDecodeLayerInterleave(nn.Cell): + r""" + Transformer Layer. This is an implementation of the single layer of the transformer + encoder layer, including multihead attention and feedforward layer. + + Args: + batch_size(int): The batch size of the input tensor when doing incremental prediction. Should be a positive + value. When doing training or prediction, the argument will not work and the user can just pass None to + the argument. + seq_length(int): The input sequence length. + layer_id(int): The layer id of current transformer block layer. + dim(int): The hidden size of the input. + n_heads(int): The number of the heads. + multiple_of(int): The SwiGLU hidden layer size multiple of large power of 2. + norm_eps (float): The epsilon value of the denominator. Default 1e-5. + compute_dtype(dtype.Number): The computation type of the layer. + Should be mstype.float32 or mstype.float16. Default mstype.float16. + layernorm_compute_dtype(dtype.Number): The computation type of the norm. + Should be mstype.float32 or mstype.float16. Default mstype.float32. + softmax_compute_dtype(dtype.Number): The computation type of the softmax in the attention. + Should be mstype.float32 or mstype.float16. Default mstype.float32. + param_init_type(dtype.Number): The parameter initialization type of the module. + Should be mstype.float32 or mstype.float16. Default mstype.float32. + qkv_has_bias(bool): Whether Q/K/V in attention has bias or not. + use_past(bool): Use the past state to compute, used for incremental prediction. For example, if we have two + words and want to generate ten more words. We just need to compute the two words' state only once, + and generate the next word one by one. When use_past is True, there are two steps to run the prediction. + In the first step, set the is_first_iteration to be True by + `model.add_flags_recursive(is_first_iteration=True)`, and pass the full inputs.
Then, set the + is_first_iteration to be False by `model.add_flags_recursive(is_first_iteration=False)`. + At this moment, pass the single step's input tensor, and loop it. Default False. + parallel_config(OpParallelConfig, MoEParallelConfig): The parallel configuration. When MoE is applied, + MoEParallelConfig is effective, otherwise OpParallelConfig is effective. Default `default_dpmp_config`, + an instance of `OpParallelConfig` with default args. + + Inputs: + - **x** (Tensor) - Float Tensor, shape should be [batch_size, seq_length, hidden_size] or + [batch_size * seq_length, hidden_size], if the use_past is False or is_first_iteration=True. Otherwise, + should be [batch_size, 1, hidden_size] + - **freqs_cis** (Tuple) - The precomputed freqs and mask for rotary position embedding used in attention. + - **input_mask** (Tensor) - Float Tensor. If the use_past is False or is_first_iteration=True, + the attention mask matrix should be [batch_size, seq_length, seq_length], or None. None means there will + be no mask in softmax computation. Otherwise, should be [batch_size, 1, hidden_size] + - **init_reset** (Tensor) - A bool tensor with shape [1], used to clear the past key parameter and + past value parameter used in the incremental prediction. Only valid when use_past is True. Default True. + - **batch_valid_length** (Tensor) - Int32 tensor with shape [batch_size] the past calculated the index. + Used for incremental prediction when the use_past is True. Default None. + + Outputs: + Tuple, a tuple contains(`output`, `layer_present`). + + - **output** (Tensor) - The float tensor of the output of the layer with + shape (batch_size, seq_length, hidden_size) or (batch_size * seq_length, hidden_size), if the use_past is + False or is_first_iteration=True.
Otherwise, it will be (batch_size, 1, hidden_size) + + - **layer_present** (Tuple) - A tuple of the Tensor of the projected key and value vector with + ((batch_size, num_heads, head_dim, seq_length), + (batch_size, num_heads, seq_length, head_dim)). + + """ + def __init__(self, + run_mode, + batch_size, + seq_length, + layer_id, + dim: int = 512, + n_heads: int = 8, + num_layers: int = 32, + sigma: float = 0.0048, + mean: float = 0.0, + hidden_dropout_prob: float = 1.0, + attention_dropout_prob: float = 1.0, + n_kv_heads: Optional[int] = None, + intermediate_size: Optional[int] = None, + ffn_dim_multiplier: Optional[int] = None, + norm_eps: float = 1e-5, + compute_dtype=mstype.float16, + layernorm_compute_dtype=mstype.float32, + softmax_compute_dtype=mstype.float32, + rotary_dtype=mstype.float32, + param_init_type=mstype.float32, + res_dtype=mstype.float32, + qkv_has_bias=False, + wo_has_bias=True, + is_dynamic=False, + use_rope_slice=False, + use_flash_attention=False, + fine_grain_interleave=2, + parallel_config=TransformerOpParallelConfig()): + + super().__init__() + self.seq_length = seq_length + self.layer_id = layer_id + self.hidden_size = dim + self.n_head = n_heads + self.num_layers = num_layers + self.head_dim = self.hidden_size // self.n_head + self.n_kv_head = n_heads if n_kv_heads is None else n_kv_heads + + self.dtype = compute_dtype + self.res_dtype = res_dtype + self.is_first_iteration = True + self.interleave_num = fine_grain_interleave + self.key_past = None + self.value_past = None + + self.reshape = P.Reshape() + self.add = P.Add() + self.cast = P.Cast() + self.attention_norm = LlamaRMSNorm(self.hidden_size, norm_eps, compute_type=layernorm_compute_dtype) + self.ffn_norm = LlamaRMSNorm(self.hidden_size, norm_eps, compute_type=layernorm_compute_dtype) + self.attention = TelechatAttentionInterleave(run_mode=run_mode, + batch_size=batch_size, + seq_length=seq_length, + dim=dim, + n_heads=n_heads, + sigma=sigma, + mean=mean, + 
hidden_dropout_prob=hidden_dropout_prob, + attention_dropout_prob=attention_dropout_prob, + n_kv_heads=n_kv_heads, + compute_dtype=compute_dtype, + softmax_compute_dtype=softmax_compute_dtype, + rotary_dtype=rotary_dtype, + param_init_type=param_init_type, + qkv_has_bias=qkv_has_bias, + wo_has_bias=wo_has_bias, + is_dynamic=is_dynamic, + use_rope_slice=use_rope_slice, + use_flash_attention=use_flash_attention, + parallel_config=parallel_config) + self.feed_forward = TelechatFeedForward(dim=self.hidden_size, + intermediate_size=intermediate_size, + hidden_dim=4 * self.hidden_size, + sigma=sigma, + mean=mean, + hidden_dropout_prob=hidden_dropout_prob, + ffn_dim_multiplier=ffn_dim_multiplier, + compute_dtype=compute_dtype, + param_init_type=param_init_type, + is_dynamic=is_dynamic) + + dp = parallel_config.data_parallel + mp = parallel_config.model_parallel + if not (_get_parallel_mode() in (ParallelMode.AUTO_PARALLEL,) and _is_sharding_propagation()): + self.feed_forward.shard(parallel_config) + self.add.shard(((dp, 1), (dp, 1))) + self.attention_norm.shard((dp, 1)) + self.ffn_norm.shard((dp, 1)) + + if parallel_config.use_seq_parallel and self.is_first_iteration: + self.add.shard(((dp * mp, 1), (dp * mp, 1))) + self.attention_norm.shard((dp * mp, 1)) + self.ffn_norm.shard((dp * mp, 1)) + self.feed_forward.w2.shard(((dp, mp), (1, mp)), ((dp * mp, 1), (1,)), out_strategy_matmul=((dp * mp, 1),)) + + concat_stra1 = [] + concat_stra2 = [] + self.interleave1_inputs = nn.CellList() + self.interleave1_inputs_ = nn.CellList() + self.interleave2_inputs = nn.CellList() + self.interleaved_concat1 = P.Concat(axis=0) + self.interleaved_concat1.add_prim_attr("fine_grained_interleaved_index", self.layer_id) + self.interleaved_concat_1 = P.Concat(axis=0) + self.interleaved_concat2 = P.Concat(axis=0) + if self.layer_id != self.num_layers - 2: + self.interleaved_concat2.add_prim_attr("fine_grained_interleaved_index", 1000) + + for _ in range(self.interleave_num): + 
concat_stra1.append((dp, mp)) + interleave_data1 = _MicroBatch(self.interleave_num, 1, [0]) + interleave_data1.strided_slice_list[0].add_prim_attr("skip_redistribution", True) + interleave_data1_ = _MicroBatch(self.interleave_num, 1, [0]) + interleave_data1_.strided_slice_list[0].add_prim_attr("skip_redistribution", True) + interleave_data2 = _MicroBatch(self.interleave_num, 2, [0, 0]) + if parallel_config.use_seq_parallel: + if self.layer_id == self.num_layers - 2: + concat_stra2.append((dp, 1)) + else: + concat_stra2.append((dp * mp, 1)) + if self.layer_id == self.num_layers - 1: + interleave_data1.strided_slice_list[0].shard(((dp, 1),)) + else: + interleave_data1.strided_slice_list[0].shard(((dp * mp, 1),)) + interleave_data1_.strided_slice_list[0].shard(((1, 1),)) + interleave_data2.strided_slice_list[0].shard(((dp * mp, 1),)) + else: + concat_stra2.append((dp, 1)) + interleave_data1.strided_slice_list[0].shard(((dp, 1),)) + interleave_data1_.strided_slice_list[0].shard(((1, 1),)) + interleave_data2.strided_slice_list[0].shard(((dp, 1),)) + if self.layer_id == 0 and parallel_config.use_seq_parallel: + interleave_data2.strided_slice_list[0].shard(((dp, 1),)) + interleave_data2.strided_slice_list[0].add_prim_attr("skip_redistribution", True) + else: + interleave_data2.strided_slice_list[0].add_prim_attr("skip_redistribution", True) + + interleave_data2.strided_slice_list[0].add_prim_attr("fine_grained_interleaved_index", self.layer_id) + interleave_data2.strided_slice_list[1].shard(((dp, mp),)) + interleave_data2.strided_slice_list[1].add_prim_attr("fine_grained_interleaved_index", self.layer_id) + interleave_data2.strided_slice_list[1].add_prim_attr("skip_redistribution", True) + self.interleave1_inputs.append(interleave_data1) + self.interleave1_inputs_.append(interleave_data1_) + self.interleave2_inputs.append(interleave_data2) + concat_stra1 = tuple(concat_stra1) + concat_stra2 = tuple(concat_stra2) + self.interleaved_concat1.shard(concat_stra1) + 
self.interleaved_concat1.add_prim_attr("skip_redistribution", True) + self.interleaved_concat_1.shard(concat_stra1) + self.interleaved_concat_1.add_prim_attr("skip_redistribution", True) + self.interleaved_concat2.shard(concat_stra2) + self.interleaved_concat2.add_prim_attr("skip_redistribution", True) + + def linear_layer1(self, x): + """layer part 1""" + input_x = self.attention_norm(x) + query, key, value = self.attention.compute_qkv(input_x) + return query, key, value + + def linear_layer2(self, x, attention): + """layer part 2""" + attention_output = self.attention.cal_output_proj(attention) + ori_dtype = attention_output.dtype + # For post-layernorm the inputs for residual path are output of self-attention and output of layernorm + x = self.add(self.cast(x, self.res_dtype), self.cast(attention_output, self.res_dtype)) + output_x = self.ffn_norm(x) + mlp_logit = self.feed_forward(output_x) + output = self.add(self.cast(x, self.res_dtype), self.cast(mlp_logit, self.res_dtype)) + output = self.cast(output, ori_dtype) + return output + + # pylint: disable=W0613 + def construct(self, x, freqs_cis, mask=None, batch_valid_length=None, block_tables=None, + slot_mapping=None, prefix_keys_values=None, q_seq_lens=None): + """ Forward of transformer block. 
""" + self._check_input(x, freqs_cis, mask) + x = self.reshape(x, (-1, x.shape[-1])) + # ============linear-layer1================ + if self.layer_id == 0: + query, key, value = self.linear_layer1(x) + else: + query_tuple = () + key_tuple = () + value_tuple = () + for i in range(self.interleave_num): + x_part, = self.interleave1_inputs[i](i, x) + query_part, key_part, value_part = self.linear_layer1(x_part) + query_tuple += (query_part,) + key_tuple += (key_part,) + value_tuple += (value_part,) + query = self.interleaved_concat1(query_tuple) + key = self.interleaved_concat_1(key_tuple) + value = self.interleaved_concat_1(value_tuple) + # ===========linear-layer1 end============= + attention = self.attention.cal_attn(query, key, value, mask, freqs_cis) + # ============linear-layer2================ + if self.layer_id == self.num_layers - 1: + output = self.linear_layer2(x, attention) + else: + output_tuple = () + for i in range(self.interleave_num): + x_part, attention_part = self.interleave2_inputs[i](i, x, attention) + output_part = self.linear_layer2(x_part, attention_part) + output_tuple += (output_part,) + output = self.interleaved_concat2(output_tuple) + # ============linear-layer2 end=========== + return output + + def _check_input(self, x, freqs_cis, mask): + r"""Check inputs""" + _check_input_dtype( + x.dtype, "x", [mstype.float32, mstype.float16, mstype.bfloat16], self.cls_name) + freqs_cos, freqs_sin, swap_mask = freqs_cis + _check_input_dtype(freqs_cos.dtype, "freqs_cos", + [mstype.float32, mstype.float16, mstype.bfloat16], self.cls_name) + _check_input_dtype(freqs_sin.dtype, "freqs_sin", + [mstype.float32, mstype.float16, mstype.bfloat16], self.cls_name) + if swap_mask is not None: + _check_input_dtype(swap_mask.dtype, "swap_mask", + [mstype.float32, mstype.float16, mstype.bfloat16], self.cls_name) + if mask is not None: + _check_input_dtype(mask.dtype, "input_mask", + [mstype.float32, mstype.float16, mstype.uint8, mstype.bfloat16], self.cls_name) + 
class TelechatEmbedding(Cell):
    """
    Token embedding lookup layer for Telechat.

    The layer owns a single table ``embedding_weight`` of shape
    ``[vocab_table_size, embedding_size]`` initialised from a normal
    distribution, and looks rows up with ``Gather`` along axis 0.

    Args:
        - **vocab_table_size** (int): Number of rows of the embedding table (size of the vocabulary).
        - **embedding_size** (int): The size of each embedding vector.
        - **sigma** (float): Standard deviation of the normal initializer. Default: 0.0048.
        - **mean** (float): Mean of the normal initializer. Default: 0.0.
        - **param_init_type** (mstype): The param init type, default mstype.float32.
        - **parallel_optimizer** (bool): Forwarded to the `Parameter` holding the table. Default: True.

    Inputs:
        - **input_ids** (Tensor) - The tokenized inputs with datatype int32 or int64,
          documented upstream as shape (batch_size, seq_length).

    Outputs:
        - **output** (Tensor) - The embedding vectors gathered for `input_ids`; for a
          (batch_size, seq_length) input this is (batch_size, seq_length, embedding_size).
    """

    @_LogActionOnce(m_logger=logger, key='Embedding',
                    no_warning=_get_parallel_mode() in (ParallelMode.STAND_ALONE,))
    @_args_type_validator_check(vocab_table_size=Validator.check_positive_int,
                                embedding_size=Validator.check_positive_int)
    def __init__(self, vocab_table_size, embedding_size, sigma=0.0048, mean=0.0, param_init_type=mstype.float32,
                 parallel_optimizer=True):
        super().__init__()
        self.vocab_table_size = vocab_table_size
        self.embedding_size = embedding_size
        self.embedding_weight = Parameter(
            initializer(Normal(sigma=sigma, mean=mean), [self.vocab_table_size, self.embedding_size],
                        dtype=param_init_type), name='embedding_weight', parallel_optimizer=parallel_optimizer)
        self.gather = P.Gather()

    def construct(self, input_ids):
        """Forward of vocab embedding."""
        _check_input_dtype(F.dtype(input_ids), "input_ids", [mstype.int32, mstype.int64], self.cls_name)
        output = self.gather(self.embedding_weight, input_ids, 0)
        return output

    def shard(self, parallel_config):
        """
        Set the parallel strategy of the embedding lookup.

        With `vocab_emb_dp` the table is replicated and only the ids are split
        over `data_parallel`. Otherwise the table rows are split over
        `model_parallel`, falling back to a replicated table when the vocab
        size is not divisible by `model_parallel`.
        """
        dp = parallel_config.data_parallel
        mp = parallel_config.model_parallel
        if parallel_config.vocab_emb_dp:
            self.gather.shard(((1, 1), (dp, 1)))
            logger.info(f"Using {dp} data parallel for the embedding lookup.")
        elif self.vocab_table_size % mp != 0:
            # The table cannot be split evenly over mp devices: keep it replicated.
            logger.warning("The vocab size of the embedding is: %s, it is not divisible by model_parallel: %s",
                           self.vocab_table_size, mp)
            logger.warning("Now, the model_parallel num of the embedding will be changed: mp = 1")
            self.gather.shard(((1, 1), (dp, 1)))
        else:
            self.gather.shard(((mp, 1), (dp, 1)))
            logger.info(f"Using {dp} data parallel and {mp} "
                        f"model parallel for the embedding lookup.")
class TelechatFeedForward(Cell):
    r"""
    Telechat gated FeedForward.

    .. math::
        (\mathrm{act}(xW_1) * xW_3)W_2

    where :math:`\mathrm{act}` is `hidden_act`, applied inside the `w1` projection.
    `w1` and `w3` have no bias; `w2` has a bias and is followed by dropout with
    ``keep_prob = 1 - hidden_dropout_prob``.

    Args:
        - **dim** (int): Input and output feature size.
        - **intermediate_size** (int, optional): Width of the hidden projection. When given it is used
          as-is and `hidden_dim`, `ffn_dim_multiplier` and `multiple_of` are ignored.
        - **hidden_dim** (int, optional): Base width used when `intermediate_size` is None. It is scaled by
          ``ffn_dim_multiplier + 0.01`` (if given), then by 2/3, then rounded up to a multiple of `multiple_of`.
        - **sigma** (float): Standard deviation of the normal weight initializer. Default: 0.0048.
        - **mean** (float): Mean of the normal weight initializer. Default: 0.0.
        - **multiple_of** (int): Rounding granularity of the derived hidden width. Default: 256.
        - **hidden_dropout_prob** (float): `w2` uses ``keep_prob = 1 - hidden_dropout_prob``. Default: 1.0.
        - **hidden_act** (str or nn.Cell subclass): Activation of the gate projection. Default: LlamaSiLU.
        - **ffn_dim_multiplier** (float, optional): Extra scale applied to `hidden_dim`. Default: None.
        - **compute_dtype** (mstype): Computation dtype. Default: mstype.float16.
        - **param_init_type** (mstype): Parameter dtype. Default: mstype.float32.
        - **is_dynamic** (bool): Forwarded as `skip_redistribution` to the linear layers. Default: False.

    Inputs:
        - **x** (Tensor) - should be `[batch, seq_length, hidden_size] or [batch * seq_length, hidden_size]`.
          Float tensor.

    Outputs:
        Tensor, the output of this layer after mapping. The shape is `[batch, seq_length, hidden_size] or
        [batch * seq_length, hidden_size]`.

    Raises:
        TypeError: `hidden_act` is neither a str nor an nn.Cell subclass, or both
            `intermediate_size` and `hidden_dim` are None.
        ValueError: (from `shard`) `hidden_dim` is not a multiple of the model parallel way.
        ValueError: (from `shard`) `dim` is not a multiple of the model parallel way.
    """

    @_LogActionOnce(m_logger=logger, key='FeedForward',
                    no_warning=_get_parallel_mode() in (ParallelMode.STAND_ALONE,))
    @_args_type_validator_check(dim=Validator.check_positive_int,
                                hidden_dim=Validator.check_positive_int,
                                multiple_of=Validator.check_positive_int,
                                compute_dtype=_valid_value_checks([mstype.float32, mstype.float16, mstype.bfloat16],
                                                                  "FeedForward"),
                                param_init_type=_valid_value_checks([mstype.float32, mstype.float16, mstype.bfloat16],
                                                                    "FeedForward"))
    def __init__(self, dim,
                 intermediate_size=None,
                 hidden_dim=None,
                 sigma=0.0048,
                 mean=0.0,
                 multiple_of=256,
                 hidden_dropout_prob=1.0,
                 hidden_act=LlamaSiLU,
                 ffn_dim_multiplier=None,
                 compute_dtype=mstype.float16,
                 param_init_type=mstype.float32,
                 is_dynamic=False):
        super().__init__()

        if hidden_act is None or not (isinstance(hidden_act, str) or issubclass(hidden_act, nn.Cell)):
            raise TypeError(f"For FeedForward cell, the hidden_act should str type or nn.Cell type, "
                            f"but got {hidden_act}.")

        if intermediate_size is not None:
            hidden_dim = intermediate_size
        else:
            if hidden_dim is None:
                # Without this guard the arithmetic below fails with an opaque
                # "unsupported operand" TypeError; keep the type, improve the message.
                raise TypeError("For FeedForward cell, one of 'intermediate_size' or 'hidden_dim' must be "
                                "given, but both are None.")
            if ffn_dim_multiplier is not None:
                hidden_dim = int((ffn_dim_multiplier + 0.01) * hidden_dim)
            hidden_dim = int(2 * hidden_dim / 3)
            # Round up to the next multiple of `multiple_of`.
            hidden_dim = multiple_of * \
                ((hidden_dim + multiple_of - 1) // multiple_of)

        # NOTE(review): with the default hidden_dropout_prob=1.0 the keep_prob handed to w2
        # is 0 -- presumably configs always override this value; confirm.
        self.hidden_dropout_prob = hidden_dropout_prob
        self.dtype = compute_dtype
        self.hidden_act = hidden_act
        self.dim = dim
        self.hidden_dim = hidden_dim
        self.mul = P.Mul()
        self.cast = P.Cast()
        # Gate projection: carries the activation.
        self.w1 = TelechatLinear(in_channels=dim,
                                 out_channels=hidden_dim,
                                 activation=hidden_act,
                                 has_bias=False,
                                 sigma=sigma,
                                 mean=mean,
                                 compute_dtype=compute_dtype,
                                 param_init_type=param_init_type,
                                 skip_redistribution=is_dynamic)

        # Output projection: the only one with bias and dropout.
        self.w2 = TelechatLinear(in_channels=hidden_dim,
                                 out_channels=dim,
                                 has_bias=True,
                                 sigma=sigma,
                                 mean=mean,
                                 compute_dtype=compute_dtype,
                                 param_init_type=param_init_type,
                                 skip_redistribution=is_dynamic,
                                 keep_prob=1-self.hidden_dropout_prob)

        # Up projection: multiplied element-wise with the activated gate.
        self.w3 = TelechatLinear(in_channels=dim,
                                 out_channels=hidden_dim,
                                 has_bias=False,
                                 sigma=sigma,
                                 mean=mean,
                                 compute_dtype=compute_dtype,
                                 param_init_type=param_init_type,
                                 skip_redistribution=is_dynamic)

    def construct(self, x):
        """Forward process of the FeedForward"""
        _check_input_dtype(F.dtype(x), "x", [mstype.float32, mstype.float16, mstype.bfloat16], self.cls_name)
        x = self.cast(x, self.dtype)
        # [bs, seq, hidden_dim] or [bs * seq, hidden_dim]
        gate = self.w1(x)  # dp,1 -> dp, mp
        hidden = self.w3(x)  # dp,1 -> dp, mp
        hidden = self.mul(hidden, gate)  # dp,mp -> dp, mp
        output = self.w2(hidden)  # dp,mp -> dp, 1
        return output

    def shard(self, parallel_config):
        """
        Set the parallel strategies of the three projections and the gate multiply.

        Raises:
            ValueError: `hidden_dim` or `dim` is not a multiple of `parallel_config.model_parallel`.
        """
        dp = parallel_config.data_parallel
        mp = parallel_config.model_parallel
        if self.hidden_dim % mp != 0:
            raise ValueError("For 'FeedForward', the class variable 'hidden_dim' must be a multiple of the "
                             "num of model parallel, but got the hidden_dim is {} and the num of model "
                             "parallel is {}.".format(self.hidden_dim, mp))
        if self.dim % mp != 0:
            raise ValueError("For 'FeedForward', the class variable 'dim' must be a multiple of the num of "
                             "model parallel, but got the dim is {} and the num of model parallel is {}."
                             .format(self.dim, mp))
        self.w1.shard(((dp, 1), (mp, 1)))
        self.w1.activation.shard(((dp, mp),))
        self.w2.shard(((dp, mp), (1, mp)), ((dp, 1), (1,)))
        self.w3.shard(((dp, 1), (mp, 1)))
        self.mul.shard(((dp, mp), (dp, mp)))
class History:
    """
    Conversation history stored as a deque of message dicts.

    Every stored message is a dict with at least a ``"content"`` entry. When a
    message does not already carry both ``"input_ids"`` and
    ``"attention_mask"``, the result of ``tokenizer(content)`` is merged into
    it before it is stored. A deque is used so messages can be added and
    removed at both ends.

    Args:
        tokenizer: Callable taking the message content and returning a mapping
            that is merged into the message (via ``dict.update``).
        history: Optional list of message dicts used to seed the history.
            The dicts are updated in place and always re-tokenized.
    """

    def __init__(self, tokenizer, history):
        self.input_history = deque()
        self.tokenizer = tokenizer
        if history:
            self._transfer_from_list(history)

    def _transfer_from_list(self, history):
        """Tokenize every seed message (unconditionally) and append it in order."""
        for message in history:
            content = message.get("content")
            # the token result may not be equal to the result model gen
            message.update(self.tokenizer(content))
            self.input_history.append(message)

    def _tokenize_if_needed(self, message):
        """Merge tokenizer output into `message` unless it is already tokenized."""
        if "input_ids" not in message or "attention_mask" not in message:
            message.update(self.tokenizer(message.get("content")))
        return message

    def append(self, message):
        """Add `message` at the right (newest) end."""
        self.input_history.append(self._tokenize_if_needed(message))

    def append_left(self, message):
        """Add `message` at the left (oldest) end."""
        self.input_history.appendleft(self._tokenize_if_needed(message))

    def pop(self):
        """Remove and return the right-most message; raises IndexError when empty."""
        return self.input_history.pop()

    def pop_left(self):
        """Remove and return the left-most message; raises IndexError when empty."""
        # The original called self.pop_left() here and recursed forever.
        return self.input_history.popleft()

    def update(self, message):
        """Replace the right-most message with `message`; raises IndexError when empty."""
        self.input_history.pop()
        self.append(message)

    def __len__(self):
        return len(self.input_history)

    def __str__(self):
        return str(self.input_history)

    def __copy__(self):
        # Shallow copy: new deque, same message dicts, same tokenizer.
        new_instance = type(self)(self.tokenizer, [])
        new_instance.input_history = copy.copy(self.input_history)
        return new_instance

    def __deepcopy__(self, memodict=None):
        # Messages are deep-copied; the tokenizer is deliberately shared.
        # The memo dict is forwarded so objects shared with the rest of the
        # copied graph are not duplicated.
        new_instance = type(self)(self.tokenizer, [])
        new_instance.input_history = copy.deepcopy(self.input_history, memodict)
        return new_instance
def write_instance_to_file(writer, instance):
    """
    Convert one tokenized sample to int32 arrays and hand it to `writer`.

    Args:
        writer: Object exposing ``write_raw_data(list_of_rows)``.
        instance: Mapping with ``"input_ids"`` and ``"labels"`` sequences.

    Returns:
        The ordered mapping (``input_ids`` first, then ``labels``) that was written.
    """
    record = collections.OrderedDict(
        (field, np.asarray(instance[field]).astype(np.int32))
        for field in ("input_ids", "labels")
    )
    # The writer takes a batch of rows, so wrap the single record in a list.
    writer.write_raw_data([record])
    return record
def process(file_list):
    """
    Convert one JSON-lines file into a MindRecord file.

    Reads ``<args.input_dataset_dir>/<file_list>``, tokenizes it with
    `preprocess_concat_datas`, and writes the samples to
    ``<args.output_path>/<file_list>.mindrecord`` with an ``input_ids`` /
    ``labels`` int32 schema. The number of produced samples is printed.

    Args:
        file_list (str): Name of a single input file (despite the parameter
            name, this is one file name, not a list).
    """
    # NOTE(review): relies on the module-level `args` created in the __main__ block;
    # presumably worker processes inherit it via fork -- confirm for spawn-based platforms.
    input_path = os.path.join(args.input_dataset_dir, file_list)
    # The context manager closes the reader even when a malformed line makes
    # iteration raise, which the original open()/close() pair did not.
    with jsonlines.open(input_path, "r") as f_in:
        dataset = list(f_in)
    tokens = preprocess_concat_datas(dataset)
    print(len(tokens))

    writer = FileWriter(os.path.join(args.output_path, file_list) + ".mindrecord", 1)

    data_schema = {
        "input_ids": {"type": "int32", "shape": [-1]},
        "labels": {"type": "int32", "shape": [-1]}
    }

    writer.add_schema(data_schema, "lm-schema")
    for token in tokens:
        write_instance_to_file(writer, token)
    writer.commit()
@MindFormerRegister.register(MindFormerModuleType.TOKENIZER)
class TelechatTokenizer(PreTrainedTokenizer):
    r"""
    SentencePiece-backed tokenizer for Telechat.

    Args:
        vocab_file(str): Path of the SentencePiece model file to load.
        unk_token(str): The token that represents the unknown. Default "".
        bos_token(str): The token that represents the begin-of-sentence. Default "<_start>".
        eos_token(str): The token that represents the end-of-sentence. Default "<_end>".
        pad_token(str): The token that represents the pad. Default "<_pad>".
        sp_model_kwargs(dict): Keyword arguments forwarded to `SentencePieceProcessor`. Default None.
        add_bos_token(bool): Whether or not to add the bos_token_id to the left of the input. Default False.
        add_eos_token(bool): Whether or not to add the eos_token_id to the right of the input. Default False.
        clean_up_tokenization_spaces (bool): Forwarded to the base tokenizer. Default False.
        **kwargs: Other kwargs that will be passed into the base class of the `Tokenizer`.

    Outputs:
        A dict contains the processed ids, attention_mask that specific by the member `MODEL_INPUT_NAME`
        of the subclass.
    """

    vocab_files_names = VOCAB_FILES_NAMES
    model_input_names = ["input_ids", "attention_mask"]
    FILE_LIST = ['tokenizer_config.json']

    # NOTE(review): the empty `unk_token` default looks like a markup-stripped
    # literal (e.g. an angle-bracket token) -- confirm against the shipped vocab.
    def __init__(
            self,
            vocab_file,
            unk_token="",
            bos_token="<_start>",
            eos_token="<_end>",
            pad_token="<_pad>",
            sp_model_kwargs: Optional[Dict[str, Any]] = None,
            add_bos_token=False,
            add_eos_token=False,
            clean_up_tokenization_spaces=False,
            **kwargs,
    ):
        def _wrap(token, single_word):
            """Turn a plain string into an AddedToken; pass other objects through."""
            if isinstance(token, str):
                return AddedToken(token, lstrip=False, rstrip=False, single_word=single_word, normalized=True)
            return token

        if sp_model_kwargs is None:
            self.sp_model_kwargs = {}
        else:
            self.sp_model_kwargs = sp_model_kwargs

        # Only the bos token is registered with single_word=False.
        bos_token = _wrap(bos_token, False)
        eos_token = _wrap(eos_token, True)
        unk_token = _wrap(unk_token, True)
        pad_token = _wrap(pad_token, True)

        self.vocab_file = vocab_file
        self.add_bos_token = add_bos_token
        self.add_eos_token = add_eos_token
        # The SentencePiece model is loaded before the base initializer runs.
        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
        self.sp_model.Load(vocab_file)

        super().__init__(
            bos_token=bos_token,
            eos_token=eos_token,
            unk_token=unk_token,
            pad_token=pad_token,
            add_bos_token=add_bos_token,
            add_eos_token=add_eos_token,
            sp_model_kwargs=self.sp_model_kwargs,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            **kwargs,
        )

    def __getstate__(self):
        # Drop the SentencePiece processor from the pickled state; it is
        # rebuilt from `vocab_file` in __setstate__.
        snapshot = self.__dict__.copy()
        snapshot["sp_model"] = None
        return snapshot

    def __setstate__(self, d):
        self.__dict__ = d
        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
        self.sp_model.Load(self.vocab_file)

    @property
    def vocab_size(self):
        """Number of pieces in the SentencePiece model."""
        return self.sp_model.get_piece_size()

    def get_vocab(self):
        """Return the token -> id mapping, including added tokens."""
        mapping = {}
        for token_id in range(self.vocab_size):
            mapping[self.convert_ids_to_tokens(token_id)] = token_id
        mapping.update(self.added_tokens_encoder)
        return mapping

    def _tokenize(self, text):
        """Split `text` into SentencePiece pieces (strings)."""
        return self.sp_model.encode(text, out_type=str)

    def _convert_token_to_id(self, token):
        """Map a piece (str) to its id using the SentencePiece model."""
        return self.sp_model.piece_to_id(token)

    def _convert_id_to_token(self, index):
        """Map an id (int) to its piece (str) using the SentencePiece model."""
        return self.sp_model.IdToPiece(index)

    def convert_tokens_to_string(self, tokens):
        """Join a sequence of tokens into one string, copying special tokens verbatim."""
        pieces = []
        pending = []
        last_was_special = False
        for position, token in enumerate(tokens):
            if token not in self.all_special_tokens:
                pending.append(token)
                last_was_special = False
                continue
            # Special tokens are never passed through the SentencePiece decoder.
            if position != 0 and not last_was_special:
                pieces.append(" ")
            pieces.append(self.sp_model.decode(pending))
            pieces.append(token)
            pending = []
            last_was_special = True
        pieces.append(self.sp_model.decode(pending))
        return "".join(pieces)

    # pylint: disable=R1710
    def save_vocabulary(self, save_directory, filename_prefix=None):
        """
        Save the SentencePiece vocabulary file into a directory.

        Args:
            save_directory (`str`):
                The directory in which to save the vocabulary.
            filename_prefix (`str`, *optional*):
                When given, the file is named ``<filename_prefix>-tokenizer.model``.

        Returns:
            `str`: Path of the saved file, or `None` when `save_directory` is not a directory.
            NOTE(review): the original documentation promised `Tuple(str)` -- confirm what callers expect.
        """
        if not os.path.isdir(save_directory):
            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
            return None

        name_prefix = filename_prefix + "-" if filename_prefix else ""
        out_vocab_file = os.path.join(save_directory, name_prefix + VOCAB_FILES_NAMES["vocab_file"])

        if not os.path.isfile(self.vocab_file):
            # Source file is gone: dump the in-memory model instead.
            with open(out_vocab_file, "wb") as fi:
                fi.write(self.sp_model.serialized_model_proto())
        elif os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
            copyfile(self.vocab_file, out_vocab_file)

        return out_vocab_file

    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
        """Wrap each sequence with bos/eos ids according to `add_bos_token` / `add_eos_token`."""
        prefix = [self.bos_token_id] if self.add_bos_token else []
        suffix = [self.eos_token_id] if self.add_eos_token else []

        wrapped = prefix + token_ids_0 + suffix
        if token_ids_1 is None:
            return wrapped
        return wrapped + prefix + token_ids_1 + suffix

    def get_special_tokens_mask(self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None,
                                already_has_special_tokens: bool = False):
        """
        Build the special-token mask for one sequence or a pair of sequences.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.
            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
                Whether or not the token list is already formatted with special tokens for the model;
                in that case the base-class implementation is used.

        Returns:
            `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
        """
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
            )

        lead = [1] if self.add_bos_token else []
        trail = [1] if self.add_eos_token else []

        mask = lead + ([0] * len(token_ids_0)) + trail
        if token_ids_1 is not None:
            mask = mask + lead + ([0] * len(token_ids_1)) + trail
        return mask

    def create_token_type_ids_from_sequences(self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None):
        """
        Create token type ids for one sequence or a pair: 0 for every position of the
        first sequence (including its bos/eos), 1 for every position of the second.

        ```
        0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
        | first sequence    | second sequence |
        ```

        if token_ids_1 is None, only returns the first portion of the mask (0s).

        Args:
            token_ids_0 (`List[int]`):
                List of ids.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of token type IDs according to the given sequence(s).
        """
        prefix = [self.bos_token_id] if self.add_bos_token else []
        suffix = [self.eos_token_id] if self.add_eos_token else []

        first_segment = prefix + token_ids_0 + suffix
        type_ids = [0] * len(first_segment)

        if token_ids_1 is not None:
            second_segment = prefix + token_ids_1 + suffix
            type_ids += [1] * len(second_segment)

        return type_ids
+# ============================================================================ +"""Telechat transformer Layer's APIs.""" +import math +from typing import Tuple, Optional + +import mindspore as ms +from mindspore import nn +import mindspore.common.dtype as mstype +from mindspore.common.tensor import Tensor +from mindspore.context import ParallelMode +from mindspore.ops import operations as P +from mindspore.parallel._utils import _get_parallel_mode, _is_sharding_propagation + +from mindformers.models.llama.llama_layer import LlamaRMSNorm +from mindformers.models.utils import predict_lazy_inline +from mindformers.modules.layers import _check_input_dtype, Dropout, RotaryEmbedding +from mindformers.modules.transformer import TransformerOpParallelConfig +from mindformers.modules.flash_attention import FlashAttention +from mindformers.modules.infer_attention import InferAttention +from mindformers.tools.logger import logger +from mindformers.tools.utils import get_predict_run_mode + +from telechat_layer import TelechatLinear, TelechatFeedForward + + +class TelechatAttention(nn.Cell): + r""" + This is an implementation of multihead attention in Telechat. + + Args: + - **dim** (int): The hidden size of the input. + - **head_dim** (int): The dim of head. + - **n_heads** (int): The number of the heads. + - **compute_dtype** (dtype.Number): The computation type of dense. Default mstype.float16. + Should be mstype.float32 or mstype.float16. + - **softmax_compute_type** (dtype.Number): The type of softmax computation module. Default mstype.float32. + Should be mstype.float32 or mstype.float16. + - **param_init_type** (dtype.Number): The parameter initialization type of the module. Default mstype. + float32. Should be mstype.float32 or mstype.float16. + - **qkv_has_bias** (bool): Whether Q/K/V in attention has bias or not. + - **use_past** (bool): Use the past state to compute, used for incremental prediction. 
+ For example, if we have two words and want to generate the ten more words. + We just need to compute the two words' state only once, and generate the next word one by one. + When use_past is True, there are two steps to run the prediction. + In the first step, set the is_first_iteration to be True by + `model.add_flags_recursive(is_first_iteration=True)`, and pass the full inputs. Then, set the + is_first_iteration to be False by `model.add_flags_recursive(is_first_iteration=False)`. At this moment, + pass the single step's input tensor, and loop it. Default False. + - **parallel_config** (OpParallelConfig): The parallel configure. Default `default_dpmp_config`, + an instance of `OpParallelConfig` with default args. + + Inputs: + - **x** (Tensor) - The input tokens with shape (batch_size, src_seq_length, hidden_size) or + (batch_size * src_seq_length, hidden_size), if the use_past is False or is_first_iteration=True. + Otherwise, must be (batch_size, 1, hidden_size) + - **freqs_cis** (Tuple) - The precompute freqs and mask for rotary position embedding used in attention. + - **attention_mask** (Tensor) - If the use_past is False or is_first_iteration=True, the attention mask + matrix should ba (batch_size, src_seq_length, tgt_seq_length), or None. None means there will be no mask + in softmax computation. Otherwise, the mask must be (batch_size, 1, tgt_seq_length) + - **batch_valid_length** (Tensor) - Int32 tensor with shape (batch_size,) the past calculated the index. + Used for incremental prediction when the use_past is True. Default None. + - **block_tables** (Tensor[int64]) - Store mapping tables for each sequence. + - **slot_mapping** (Tensor[int32]) - Store token cache physical slot index. 
+ Outputs: + Tuple, a tuple contains(`output`, `layer_present`) + + - **output** (Tensor) - Tensor, the float tensor of the output of the layer with + shape (batch_size, src_seq_length, hidden_size) or (batch_size * src_seq_length, hidden_size), + if the use_past is False or is_first_iteration=True. Otherwise, it will be (batch_size, 1, hidden_size). + + - **layer_present** (Tuple) - A tuple of the Tensor of the projected key and value vector with + ((batch_size, num_heads, head_dim, tgt_seq_length), + (batch_size, num_heads, tgt_seq_length, head_dim)). + """ + + def __init__(self, + run_mode, + dim: int = 512, + n_heads: int = 8, + n_kv_heads: Optional[int] = None, + hidden_dropout_prob: float = 1.0, + attention_dropout_prob: float = 1.0, + sigma: float = 0.0048, + mean: float = 0.0, + compute_dtype=mstype.float16, + softmax_compute_dtype=mstype.float32, + rotary_dtype=mstype.float32, + param_init_type=mstype.float32, + qkv_has_bias=False, + wo_has_bias=True, + use_past=False, + is_dynamic=False, + use_rope_slice=False, + use_flash_attention=False, + use_attn_mask_compression=False, + block_size: Optional[int] = None, + num_blocks: Optional[int] = None, + parallel_config=TransformerOpParallelConfig()): + super().__init__() + self.hidden_size = dim + self.n_head = n_heads + self.head_dim = dim // n_heads + self.n_kv_head = n_heads if n_kv_heads is None else n_kv_heads + self.n_rep = self.n_head // self.n_kv_head + self.kv_dim = self.n_kv_head * self.head_dim + self.block_size = block_size + self.num_blocks = num_blocks + self.sigma = sigma + self.mean = mean + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_dropout_prob = attention_dropout_prob + self.dtype = compute_dtype + self.softmax_dtype = softmax_compute_dtype + self.is_first_iteration = True + self.use_past = use_past + self.use_flash_attention = use_flash_attention + self.use_attn_mask_compression = use_attn_mask_compression + + if self.hidden_size % self.n_head != 0: + raise 
ValueError("For 'MultiHeadAttention', the class variable 'hidden_size' must be a multiple " + "of 'n_head', but got the hidden_size is {} and the n_head is {}." + .format(self.hidden_size, self.n_head)) + if self.n_kv_head % parallel_config.model_parallel != 0: + raise ValueError("For 'MultiHeadAttention', the class variable 'n_kv_head' must be a multiple of " + "'parallel_config.model_parallel', but got the n_kv_head is {} " + "and the parallel_config.model_parallel is {}." + .format(self.n_kv_head, parallel_config.model_parallel)) + dp = parallel_config.data_parallel + mp = parallel_config.model_parallel + self.shape = P.Shape() + self.cast = P.Cast() + self.reshape = P.Reshape().add_prim_attr("skip_redistribution", True) + + self.wq = TelechatLinear(self.hidden_size, + self.hidden_size, + sigma=self.sigma, + mean=self.mean, + has_bias=qkv_has_bias, + compute_dtype=compute_dtype, + param_init_type=param_init_type, + skip_redistribution=is_dynamic) + self.wk_v = TelechatLinear(self.hidden_size, + self.n_kv_head * self.head_dim * 2, + has_bias=qkv_has_bias, + sigma=self.sigma, + mean=self.mean, + compute_dtype=compute_dtype, + param_init_type=param_init_type, + skip_redistribution=is_dynamic) + + if qkv_has_bias: + self.wq.shard(((dp, 1), (mp, 1)), ((dp, mp), (mp,))) + self.wk_v.shard(((dp, 1), (mp, 1)), ((dp, mp), (mp,))) + else: + self.wq.shard(((dp, 1), (mp, 1))) + self.wk_v.shard(((dp, 1), (mp, 1))) + if run_mode == "predict": + self.split_kv = ms.ops.auto_generate.SplitWithSize() + self.split_kv.add_prim_attr("skip_redistribution", True) + else: + self.split_kv = P.Split(output_num=2, axis=-1) + self.split_kv.shard(((dp, mp, 1),)) + + self.wo = TelechatLinear(in_channels=self.hidden_size, + out_channels=self.hidden_size, + sigma=self.sigma, + mean=self.mean, + has_bias=wo_has_bias, + compute_dtype=compute_dtype, + param_init_type=param_init_type, + skip_redistribution=is_dynamic, + keep_prob=1-self.hidden_dropout_prob) + if wo_has_bias: + self.wo.shard(((dp, 
mp), (1, mp)), ((dp, 1), (1,)), out_strategy_matmul=((dp, 1),)) + else: + self.wo.shard(((dp, mp), (1, mp)), out_strategy_matmul=((dp, 1),)) + + if self.use_past: + self.infer_attention = InferAttention(self.n_head, + self.head_dim, + self.n_kv_head, + pa_n_head_split=self.n_head // mp, + pa_n_kv_head_split=self.n_kv_head // mp, + scale_value=1. / math.sqrt(self.head_dim), + pre_tokens=2147483647, + next_tokens=0, + block_size=self.block_size, + num_blocks=self.num_blocks, + use_flash_attention=self.use_flash_attention, + rotary_cos_format=2, + rotary_dtype=rotary_dtype, + compute_dtype=compute_dtype) + self.infer_attention.shard(parallel_config) + else: + self.inv_norm_factor = Tensor(1.0 / math.sqrt(self.head_dim), dtype=compute_dtype) + + self.transpose = P.Transpose() + self.merger_head_transpose = P.Transpose() + self.batch_matmul = P.BatchMatMul() + self.batch_matmul_q_k = P.BatchMatMul(transpose_b=True) + self.mul = P.Mul() + self.add = P.Add() + self.softmax = P.Softmax() + self.cast_attn = P.Cast() + self.tile_kv = P.Tile() + + self.apply_rotary_emb = RotaryEmbedding(self.head_dim, rotary_dtype, use_rope_slice=use_rope_slice) + self.attention_dropout = Dropout(1-self.attention_dropout_prob) + + if not (_get_parallel_mode() in (ParallelMode.AUTO_PARALLEL,) and _is_sharding_propagation()): + self.transpose.shard(((dp, 1, mp, 1),)) + self.merger_head_transpose.shard(((dp, mp, 1, 1),)) + self.batch_matmul_q_k.shard(((dp, mp, 1, 1), (dp, mp, 1, 1))) + self.batch_matmul.shard(((dp, mp, 1, 1), (dp, mp, 1, 1))) + self.mul.shard(((dp, mp, 1, 1), ())) + self.add.shard(((dp, 1, 1, 1), (dp, mp, 1, 1))) + self.softmax.shard(((dp, mp, 1, 1),)) + self.tile_kv.shard(((dp, mp, 1, 1),)) + + self.apply_rotary_emb.shard(parallel_config) + if parallel_config.use_seq_parallel and self.is_first_iteration: + self.wo.shard(((dp, mp), (1, mp)), out_strategy_matmul=((dp * mp, 1),)) + if parallel_config.recompute.select_recompute and not self.use_flash_attention: + 
self.apply_rotary_emb.recompute() + self.tile_kv.recompute() + self.batch_matmul_q_k.recompute() + self.mul.recompute() + self.add.recompute() + self.cast_attn.recompute() + self.softmax.recompute() + self.batch_matmul.recompute() + + if self.use_flash_attention: + self.sparse_mode = 2 if self.use_attn_mask_compression else 0 + self.flash_attention = FlashAttention(head_num=self.n_head, + pre_tokens=65536, + next_tokens=0, + input_layout="BNSD", + keep_prob=1. - attention_dropout_prob, + scale_value=1. / math.sqrt(self.head_dim), + sparse_mode=self.sparse_mode, + use_attention_mask=True) + self.flash_attention.shard(parallel_config) + + + def construct(self, x: Tensor, freqs_cis: Tuple[Tensor, Tensor], mask=None, batch_valid_length=None, + block_tables=None, slot_mapping=None, prefix_keys_values=None): + """Forward process of the MultiHeadAttention""" + ori_dtype = x.dtype + # [bs, seq/1, hidden_dim] + bs, seq_len, _ = self.shape(x) + query = self.cast(self.wq(x), self.dtype) # dp, 1 -> dp, mp + key_value = self.cast(self.wk_v(x), self.dtype) # dp, 1 -> dp, mp + key_value = self.reshape(key_value, (-1, self.n_kv_head, self.head_dim * 2)) + if self.training: + key, value = self.split_kv(key_value) + else: + key, value = self.split_kv(key_value, (self.head_dim, self.head_dim), 2) + + # key and value for current token(s) + if self.use_past: + key = self.reshape(key, (bs, seq_len, self.n_kv_head * self.head_dim)) + value = self.reshape(value, (bs, seq_len, self.n_kv_head * self.head_dim)) + context_layer = self.infer_attention(query, key, value, batch_valid_length, block_tables, slot_mapping, + freqs_cis, mask, prefix_keys_values=prefix_keys_values) + else: + query = self.transpose(self.reshape(query, (bs, seq_len, self.n_head, self.head_dim)), (0, 2, 1, 3)) + key = self.transpose(self.reshape(key, (bs, seq_len, self.n_kv_head, self.head_dim)), (0, 2, 1, 3)) + query, key = self.apply_rotary_emb(query, key, freqs_cis) # dp, mp, 1, 1 + + value = 
self.transpose(self.reshape(value, (bs, seq_len, self.n_kv_head, self.head_dim)), (0, 2, 1, 3)) + key, value = self._cat_prefix(key, value, prefix_keys_values) + + if self.use_flash_attention: + context_layer = self.flash_attention(query, key, value, mask) + context_layer = self._merge_heads(context_layer) + else: + key = self._repeat_kv(key, self.n_rep) + value = self._repeat_kv(value, self.n_rep) + context_layer = self._attn(query, key, value, mask) + + # [bs, seq/1, hidden_dim] or [bs * seq/1, hidden_dim] + output = self.wo(context_layer) # dp, mp -> dp, 1 / dp * mp, 1 + output = self.cast(output, ori_dtype) + return output + + def _cat_prefix(self, key, value, prefix_keys_values): + r''' + concat prefix_keys_values to key and value + prefix_keys_values: shape(2, bs, pre_len, num_heads * kv_channels) + ''' + if prefix_keys_values is not None: + bs, n_kv_head, _, head_dim = key.shape + past_key = prefix_keys_values[0] + past_value = prefix_keys_values[1] + past_key = self.transpose(self.reshape(past_key, (bs, -1, n_kv_head, head_dim)), (0, 2, 1, 3)) + past_value = self.transpose(self.reshape(past_value, (bs, -1, n_kv_head, head_dim)), (0, 2, 1, 3)) + past_key = self.cast(past_key, self.dtype) + past_value = self.cast(past_value, self.dtype) + cat = P.Concat(2) + key = cat((past_key, key)) + value = cat((past_value, value)) + return key, value + + def _repeat_kv(self, x, rep): + if rep == 1: + return x + bs, n_kv_head, seqlen, head_dim = self.shape(x) + x = self.reshape(x, (bs, n_kv_head, 1, seqlen * head_dim)) + x = self.tile_kv(x, (1, 1, rep, 1)) + x = self.reshape(x, (bs, n_kv_head * rep, seqlen, head_dim)) + return x + + def _merge_heads(self, x): + """ + convert a 4d input to a 3d output + + Inputs: + x: input tensor + + Output: + x_merge: the 2d output + """ + # [bs, n_head, seq/1, head_dim] + x = self.merger_head_transpose(x, (0, 2, 1, 3)) # dp,mp,1,1 -> dp,1,mp,1 + # [bs, seq/1, n_head, head_dim] + bs, seq_len, n_head, head_dim = self.shape(x) + # [bs, 
seq/1, hidden_dim] + new_shape = (bs, seq_len, n_head * head_dim) + x_merge = self.reshape(x, new_shape) + return x_merge + + def _attn(self, query, key, value, mask): + """ + Get the weighted score along the seq_length + + Inputs: + query: the query matrix + key: the key matrix + value: the value matrix + mask: the attention mask adder matrix with shape (batch_size, + 1, seq_length, seq_length) + Outputs: + weighted_values: Tensor, the weighted sum scores + """ + # q, k: [bs, n_head, seq/1, head_dim], [bs, n_head, seq, head_dim] + score = self.batch_matmul_q_k(query, key) + # score: [bs, n_head, seq/1, seq] + score = self.mul(score, self.inv_norm_factor) + score = self.add(mask, score) + + attention_probs = self.softmax(self.cast_attn(score, self.softmax_dtype)) + # score, v: [bs, n_head, seq/1, seq], [bs, n_head, seq, head_dim] + attention_probs = self.attention_dropout(attention_probs) + weighted_values = self.batch_matmul(self.cast(attention_probs, self.dtype), value) + # [bs, n_head, seq/1, head_dim] + attention_merge = self._merge_heads(weighted_values) + # [bs, seq/1, hidden_dim] or [bs * seq/1, hidden_dim] + return attention_merge + + +# pylint: disable=C0326 +class TelechatDecodeLayer(nn.Cell): + r""" + Transformer Layer. This is an implementation of the single layer of the transformer + encoder layer, including multihead attention and feedward layer. + + Args: + layer_id(int): The layer id of current transformer block layer. + dim(int): The hidden size of the input. + num_heads(int): The number of the heads. + multiple_of(int): The SwiGLU hidden layer size multiple of large power of 2. + norm_eps (float): The epsilon value of the denominator. Default 1e-5. + compute_dtype(dtype.Number): The computation type of the layer. + Should be mstype.float32 or mstype.float16. Default mstype.float32. + layernorm_compute_type(dtype.Number): The computation type of the norm. + Should be mstype.float32 or mstype.float16. Default mstype.float32. 
+ softmax_compute_dtype(dtype.Number): The computation type of the softmax in the attention. + Should be mstype.float32 or mstype.float16. Default mstype.float32. + param_init_type(dtype.Number): The parameter initialization type of the module. + Should be mstype.float32 or mstype.float16. Default mstype.float32. + qkv_has_bias(bool): Whether Q/K/V in attention has bias or not. + use_past(bool): Use the past state to compute, used for incremental prediction. For example, if we have two + words and want to generate ten more words, we just need to compute the two words' state only once, + and generate the next word one by one. When use_past is True, there are two steps to run the prediction. + In the first step, set the is_first_iteration to be True by + `model.add_flags_recursive(is_first_iteration=True)`, and pass the full inputs. Then, set the + is_first_iteration to be False by `model.add_flags_recursive(is_first_iteration=False)`. + At this moment, pass the single step's input tensor, and loop it. Default False. + parallel_config(OpParallelConfig, MoEParallelConfig): The parallel configuration. When MoE is applied, + MoEParallelConfig is effective, otherwise OpParallelConfig is effective. Default `default_dpmp_config`, + an instance of `OpParallelConfig` with default args. + + Inputs: + - **x** (Tensor) - Float Tensor, shape should be [batch_size, seq_length, hidden_size] or + [batch_size * seq_length, hidden_size], if the use_past is False or is_first_iteration=True. Otherwise, + should be [batch_size, 1, hidden_size] + - **freqs_cis** (Tuple) - The precomputed freqs and mask for rotary position embedding used in attention. + - **input_mask** (Tensor) - Float Tensor. If the use_past is False or is_first_iteration=True, + the attention mask matrix should be [batch_size, seq_length, seq_length], or None. None means there will + be no mask in softmax computation.
Otherwise, should be [batch_size, 1, hidden_size] + - **init_reset** (Tensor) - A bool tensor with shape [1], used to clear the past key parameter and + past value parameter used in the incremental prediction. Only valid when use_past is True. Default True. + - **batch_valid_length** (Tensor) - Int32 tensor with shape [batch_size], the index calculated in the past. + Used for incremental prediction when the use_past is True. Default None. + - **block_tables** (Tensor[int64]) - Stores mapping tables for each sequence. + - **slot_mapping** (Tensor[int32]) - Stores the physical slot index of the token cache. + Outputs: + Tuple, a tuple containing (`output`, `layer_present`). + + - **output** (Tensor) - The float tensor of the output of the layer with + shape (batch_size, seq_length, hidden_size) or (batch_size * seq_length, hidden_size), if the use_past is + False or is_first_iteration=True. Otherwise, it will be (batch_size, 1, hidden_size) + + - **layer_present** (Tuple) - A tuple of the Tensors of the projected key and value vectors with + ((batch_size, num_heads, head_dim, seq_length), + (batch_size, num_heads, seq_length, head_dim)).
+ + """ + + @predict_lazy_inline + def __init__(self, + run_mode, + layer_id, + dim: int = 512, + n_heads: int = 8, + sigma: float = 0.0048, + mean: float = 0.0, + hidden_dropout_prob: float = 1.0, + attention_dropout_prob: float = 1.0, + n_kv_heads: Optional[int] = None, + intermediate_size: Optional[int] = None, + multiple_of: int = 256, + ffn_dim_multiplier: Optional[int] = None, + norm_eps: float = 1e-5, + compute_dtype=mstype.float16, + layernorm_compute_dtype=mstype.float32, + softmax_compute_dtype=mstype.float32, + rotary_dtype=mstype.float32, + param_init_type=mstype.float32, + res_dtype=mstype.float32, + qkv_has_bias=False, + wo_has_bias=True, + use_past=False, + is_dynamic=False, + use_rope_slice=False, + use_flash_attention=False, + use_attn_mask_compression=False, + block_size: Optional[int] = None, + num_blocks: Optional[int] = None, + parallel_config=TransformerOpParallelConfig()): + super().__init__() + self.layer_id = layer_id + self.hidden_size = dim + self.n_head = n_heads + self.head_dim = self.hidden_size // self.n_head + self.n_kv_head = n_heads if n_kv_heads is None else n_kv_heads + self.dtype = compute_dtype + self.res_dtype = res_dtype + self.is_first_iteration = True + self.use_past = use_past + + self.sigma = sigma + self.mean = mean + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_dropout_prob = attention_dropout_prob + + self.shape = P.Shape() + self.reshape = P.Reshape() + self.cast = P.Cast() + self.add = P.Add() + self.ffn_norm = LlamaRMSNorm(self.hidden_size, norm_eps, compute_type=layernorm_compute_dtype) + self.attention_norm = LlamaRMSNorm(self.hidden_size, norm_eps, compute_type=layernorm_compute_dtype) + self.attention = TelechatAttention(run_mode=run_mode, + dim=dim, + n_heads=n_heads, + n_kv_heads=n_kv_heads, + sigma = self.sigma, + mean = self.mean, + hidden_dropout_prob=hidden_dropout_prob, + attention_dropout_prob=attention_dropout_prob, + compute_dtype=compute_dtype, + 
softmax_compute_dtype=softmax_compute_dtype, + rotary_dtype=rotary_dtype, + param_init_type=param_init_type, + qkv_has_bias=qkv_has_bias, + wo_has_bias=wo_has_bias, + use_past=use_past, + is_dynamic=is_dynamic, + use_rope_slice=use_rope_slice, + use_flash_attention=use_flash_attention, + use_attn_mask_compression=use_attn_mask_compression, + block_size=block_size, + num_blocks=num_blocks, + parallel_config=parallel_config) + + self.feed_forward = TelechatFeedForward(dim=self.hidden_size, + intermediate_size=intermediate_size, + hidden_dim=4 * self.hidden_size, + sigma=self.sigma, + mean=self.mean, + hidden_dropout_prob=hidden_dropout_prob, + multiple_of=multiple_of, + ffn_dim_multiplier=ffn_dim_multiplier, + compute_dtype=compute_dtype, + param_init_type=param_init_type, + is_dynamic=is_dynamic) + + dp = parallel_config.data_parallel + mp = parallel_config.model_parallel + if not (_get_parallel_mode() in (ParallelMode.AUTO_PARALLEL,) and _is_sharding_propagation()): + self.feed_forward.shard(parallel_config) + self.add.shard(((dp, 1, 1), (dp, 1, 1))) + self.attention_norm.shard((dp, 1, 1)) + self.ffn_norm.shard((dp, 1, 1)) + self.feed_forward.mul.shard(((dp, 1, mp), (dp, 1, mp))) + + if parallel_config.use_seq_parallel and self.is_first_iteration: + self.add.shard(((dp, mp, 1), (dp, mp, 1))) + self.attention_norm.shard((dp, mp, 1)) + self.ffn_norm.shard((dp, mp, 1)) + self.feed_forward.w2.shard(((dp, mp), (1, mp)), out_strategy_matmul=((dp * mp, 1),)) + + self.predict_run_mode = get_predict_run_mode() + logger.info("Predict run mode:{}".format(self.predict_run_mode)) + + if self.predict_run_mode: + self.no_inline = False + + def construct(self, x, freqs_cis, mask=None, batch_valid_length=None, block_tables=None, + slot_mapping=None, prefix_keys_values=None): + """ Forward of transformer block. 
""" + if not self.use_past: + self._check_input(x, freqs_cis, mask) + ori_dtype = x.dtype + # [bs, seq/1, hidden_dim] + input_x = self.attention_norm(x) + # [bs, seq/1, hidden_dim] + h = self.attention(input_x, freqs_cis, mask, batch_valid_length, block_tables, + slot_mapping, prefix_keys_values) + h = self.add(self.cast(x, self.res_dtype), self.cast(h, self.res_dtype)) + h = self.cast(h, ori_dtype) + ffn_norm = self.ffn_norm(h) + # [bs, seq/1, hidden_dim] + ffn_out = self.feed_forward(ffn_norm) + # [bs, seq/1, hidden_dim] or [bs * seq/1, hidden_dim] + h = self.add(self.cast(h, self.res_dtype), self.cast(ffn_out, self.res_dtype)) + out = self.cast(h, ori_dtype) + return out + + def _check_input(self, x, freqs_cis, mask): + r"""Check inputs""" + _check_input_dtype( + x.dtype, "x", [mstype.float32, mstype.float16, mstype.bfloat16], self.cls_name) + freqs_cos, freqs_sin, swap_mask = freqs_cis + _check_input_dtype(freqs_cos.dtype, "freqs_cos", + [mstype.float32, mstype.float16, mstype.bfloat16], self.cls_name) + _check_input_dtype(freqs_sin.dtype, "freqs_sin", + [mstype.float32, mstype.float16, mstype.bfloat16], self.cls_name) + if swap_mask is not None: + _check_input_dtype(swap_mask.dtype, "swap_mask", + [mstype.float32, mstype.float16, mstype.bfloat16], self.cls_name) + if mask is not None: + _check_input_dtype(mask.dtype, "input_mask", + [mstype.float32, mstype.float16, mstype.bfloat16, mstype.uint8, mstype.bool_], + self.cls_name) + return True diff --git a/README.md b/README.md index c001c21..f69269a 100644 --- a/README.md +++ b/README.md @@ -28,7 +28,7 @@ - 星辰语义大模型**TeleChat2**是由中国电信人工智能研究院研发训练的大语言模型,该系列模型**完全基于国产算力**训练。 - 本次开源**TeleChat2-115B**模型采用10万亿 Tokens中英文高质量语料进行训练,同步开源对话模型**TeleChat2-115B**的多格式、多平台权重文件。 - **TeleChat2**在训练数据、训练方法等方面进行了改进,在通用问答和知识类、代码类、数学类榜单上相比**TeleChat1**均有大幅提升。 - - **TeleChat2**完全基于国产算力和国产深度学习框架进行训练,算力和算法框架更自主可控。优化MP、PP、SP实现方式提升模型性能,优化算子来提升训练速度。 + - 
**TeleChat2**完全基于国产算力和国产深度学习框架昇思MindSpore进行训练,算力和算法框架更自主可控。优化MP、PP、SP实现方式提升模型性能,优化算子来提升训练速度。 - 我们使用大量小模型实验来验证scaling law规律,在不同模型结构、不同数据配比和数据清洗方式中寻找最优设计。 - 采用RingAttention及其他序列切分方式,实现长文训练性能提升;通过ntk-aware+attention-scaling的方式保证训练长度切换时的平稳过渡,以此来保证模型在不同长度数据下的训练效果。 - 在微调数据方面,我们进行了指令复杂性提升与多样性扩充,通过数据合成和人工标注生成高质量数据,并使用拒绝采样生成多样的推理路径;通过研究一套基于base模型反向选择偏好对齐数据方案,基于适配数据最大限度提升模型效果。 -- Gitee