Hello everyone, I'm 写代码的中年人 (the "middle-aged coder")!
Today we'll pretrain a model on an open-source Chinese dataset. Follow the steps below and you'll implement your own pretrained model from scratch.
All code and data resources for this article live here:
https://github.com/ColinAIAPP/MoiraiLM
01. What Is a Pretrained Model?
A pretrained model is a model that has already been trained on massive amounts of data. It has learned the basic patterns, structure, and semantics of language, and can then be used for all kinds of downstream tasks: writing, translation, question answering, classification, code generation, and so on.
So what exactly does "pretraining" learn? Taking language models (LLMs) as the example: the pretraining task is usually to predict the next word (token).
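To make "predicting the next token" concrete, here is a minimal illustrative sketch; the dimensions and token IDs are made up for the demo and it is unrelated to the full training script later in this article:

import torch
from torch import nn

# Minimal sketch: given the preceding tokens, predict the next token
vocab_size, dim = 100, 16
emb = nn.Embedding(vocab_size, dim)
head = nn.Linear(dim, vocab_size)

tokens = torch.tensor([[5, 23, 7, 41]])   # a toy sequence of 4 token IDs
x, y = tokens[:, :-1], tokens[:, 1:]      # inputs: first 3 tokens; labels: each one's "next token"
logits = head(emb(x))                     # (1, 3, vocab_size)
loss = nn.functional.cross_entropy(logits.view(-1, vocab_size), y.view(-1))
print(loss.item())  # pretraining minimizes exactly this loss over massive text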
Next, we'll implement a 17M-parameter pretrained model step by step.
02. Preparing the Data
The first requirement for building a language model is a high-quality data source. For Chinese, the open Chinese Wikipedia dataset is an ideal starting point: it contains several million encyclopedia entries covering history, culture, science, and more, totaling a few GB of plain text. It is open and free, and the latest XML dump can be downloaded from Wikipedia's official dumps page.
To extract the article text from this archive we use the wikiextractor tool.
Install it:
pip install wikiextractor
Run the extraction:
python -m wikiextractor.WikiExtractor -b 1G -o extracted_wiki_zh zhwiki-20250920-pages-articles-multistream.xml.bz2 --json
Here zhwiki-20250920-pages-articles-multistream.xml.bz2 is the downloaded dump file; -b 1G caps the size of each output shard, -o names the output directory, and --json emits one JSON object per line. The extraction log looks like this:
INFO: Preprocessing 'zhwiki-20250920-pages-articles-multistream.xml.bz2' to collect template definitions: this may take some time.
INFO: Preprocessed 100000 pages
INFO: Preprocessed 200000 pages
INFO: Preprocessed 300000 pages
...
INFO: Preprocessed 4600000 pages
INFO: Preprocessed 4700000 pages
INFO: Loaded 1036734 templates in 704.2s
INFO: Starting page extraction from zhwiki-20250920-pages-articles-multistream.xml.bz2.
INFO: Using 127 extract processes.
INFO: Extracted 100000 articles (1209.6 art/s)
INFO: Extracted 200000 articles (1947.8 art/s)
...
INFO: Extracted 2700000 articles (2525.4 art/s)
INFO: Extracted 2800000 articles (2666.4 art/s)
INFO: Finished 127-process extraction of 2893023 articles in 1156.5s (2501.5 art/s)
03. Cleaning the Data
After extraction we end up with directories of JSON files; the next step is to clean the article text out of them.
Note: this step produces the file data/cleaned_wiki_full.txt.
import os
import json
import logging
import argparse
import re
from tqdm import tqdm

# Configure logging
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(levelname)s - %(message)s')

# Usage:
# python scripts/clean_wiki_text.py data/extracted_wiki_zh data/cleaned_wiki_full.txt --min_line_length 20 --min_article_length 300


def clean_text(text: str) -> str:
    """Deep-clean one line of text: strip wiki markup, references, HTML tags, dates and numbers."""
    # Remove wiki links [[link|display]] or [[link]]
    text = re.sub(r'\[\[([^\]|]+\|)?([^\]]+)\]\]', r'\2', text)
    # Remove reference markers: [1], [ref], [/ref], [citation needed]
    text = re.sub(r'\[\d+\]|\[ref\]|\[/ref\]|\[citation needed\]', '', text)
    # Remove HTML tags
    text = re.sub(r'<[^>]+>', '', text)
    # Remove date formats (yyyy-mm-dd, yyyy/mm/dd, mm/dd/yyyy, etc.)
    text = re.sub(r'\d{1,4}[-/]\d{1,2}[-/]\d{1,4}', '', text)
    # Remove years (1000-2999)
    text = re.sub(r'\b[12]\d{3}\b', '', text)
    # Remove bare numbers (including decimals)
    text = re.sub(r'\b\d+\.?\d*\b', '', text)
    # Collapse repeated spaces (keep single spaces)
    text = re.sub(r' +', ' ', text)
    # Strip leading/trailing whitespace
    return text.strip()


def process_extracted_wiki(extracted_dir: str, output_file: str,
                           min_line_length: int = 20,
                           min_article_length: int = 200):
    """
    Read the JSON files produced by WikiExtractor, extract and clean the text,
    and write everything into a single file.

    Args:
        extracted_dir: directory produced by WikiExtractor
        output_file: path of the merged plain-text output
        min_line_length: minimum length of a single line, filters noise (default: 20)
        min_article_length: minimum article length, filters stubs (default: 200)
    """
    if not os.path.isdir(extracted_dir):
        logging.error(f"Input directory does not exist: {extracted_dir}")
        return

    total_articles = 0
    skipped_articles = 0

    # First pass: collect the list of files to process
    file_list = []
    for root, dirs, files in os.walk(extracted_dir):
        for file_name in files:
            # Only process files starting with 'wiki_' (WikiExtractor's naming)
            if file_name.startswith('wiki_'):
                file_list.append(os.path.join(root, file_name))

    total_files = len(file_list)
    logging.info(f"Found {total_files} files to process.")
    if total_files == 0:
        logging.warning(f"No 'wiki_' files found in {extracted_dir}. Check the path.")
        return

    # Second pass: process files and write the output
    with open(output_file, 'w', encoding='utf-8') as f_out:
        # Wrap the file list with tqdm for a progress bar
        for file_path in tqdm(file_list, desc="Extracting wiki text"):
            try:
                with open(file_path, 'r', encoding='utf-8') as f_in:
                    for line_num, line in enumerate(f_in, 1):
                        try:
                            article = json.loads(line)
                            text_content = article.get('text', '').strip()

                            # 1. Drop articles that are too short; they are
                            #    usually noise or redirect pages
                            if len(text_content) < min_article_length:
                                skipped_articles += 1
                                continue

                            # 2. Clean line by line, filtering short lines and
                            #    extra whitespace; keep the line structure
                            #    instead of joining everything into one long sentence
                            cleaned_lines = []
                            for text_line in text_content.split('\n'):
                                text_line = clean_text(text_line)
                                # Keep only lines that are long enough
                                if len(text_line) >= min_line_length:
                                    cleaned_lines.append(text_line)

                            # Join with newlines to preserve paragraph structure
                            final_text = '\n'.join(cleaned_lines)

                            # Final check: the cleaned text must still be long enough
                            if final_text and len(final_text) >= min_article_length:
                                # Separate articles with a blank line
                                f_out.write(final_text + '\n\n')
                                total_articles += 1
                            else:
                                skipped_articles += 1
                        except json.JSONDecodeError:
                            logging.warning(f"Invalid JSON in {file_path}, line {line_num}")
                        except Exception as e:
                            logging.error(f"Error in {file_path}, line {line_num}: {e}")
            except Exception as e:
                logging.error(f"Error opening {file_path}: {e}")

    logging.info("All Wikipedia text extracted and cleaned successfully.")
    logging.info(f"Total articles: {total_articles}")
    logging.info(f"Skipped articles: {skipped_articles}")
    logging.info(f"Saved to: {output_file}")


def main():
    parser = argparse.ArgumentParser(
        description="Extract and clean plain text from WikiExtractor JSON output.",
        formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument("extracted_directory", type=str,
                        help="WikiExtractor output directory (e.g., extracted_wiki_zh)")
    parser.add_argument("output_filename", type=str,
                        help="Path of the merged plain-text file (e.g., cleaned_wiki.txt)")
    parser.add_argument("--min_line_length", type=int, default=20,
                        help="Minimum length a line must reach to be kept; filters noise. Default: 20")
    parser.add_argument("--min_article_length", type=int, default=200,
                        help="Minimum article length; filters stubs and redirects. Default: 200")
    args = parser.parse_args()

    process_extracted_wiki(args.extracted_directory, args.output_filename,
                           args.min_line_length, args.min_article_length)


if __name__ == "__main__":
    main()
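The cleaning logic is easy to get a feel for with a tiny made-up example (run it inside the script above, or import clean_text from it; the exact output may vary slightly in whitespace):

# Made-up example: what clean_text does to one line of wiki-marked-up text
sample = "[[Deep learning|Deep learning]] is a branch of machine learning[1]. See the <b>official docs</b> (2017-05-01)."
print(clean_text(sample))
# Link markup, reference numbers, HTML tags, and dates are removed; roughly:
# Deep learning is a branch of machine learning. See the official docs ().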
2025-10-01 11:10:58,772 - INFO - Found 5 files to process.
Extracting wiki text: 100%|██████████| 5/5 [00:33<00:00, 6.78s/it]
2025-10-01 11:11:32,681 - INFO - All Wikipedia text extracted successfully. Total articles: 628093. Saved to data/cleaned_wiki_full.txt
04. Training the Tokenizer
We use SentencePiece to train the tokenizer. Here we train a 16k vocabulary; a 32k vocabulary works too. The code and process are as follows:
Note: this step produces the files workdir/spm_wiki_16k.model and workdir/spm_wiki_16k.vocab.
import sys
import sentencepiece as spm
import argparse
import os
from tqdm import tqdm

# Usage:
# python scripts/train_tokenizer.py data/cleaned_wiki_full.txt workdir/spm_wiki 32000


def get_corpus_size(input_file: str) -> int:
    """Report the corpus line count and file size."""
    try:
        file_size_bytes = os.path.getsize(input_file)
        file_size_mb = file_size_bytes / (1024 * 1024)
        print(f"Corpus file size: {file_size_mb:.2f} MB")

        # Count lines and total characters
        line_count = 0
        total_chars = 0
        with open(input_file, 'r', encoding='utf-8') as f:
            for line in tqdm(f, desc="Counting lines"):
                line_count += 1
                total_chars += len(line)

        print(f"Total lines (articles): {line_count}")
        print(f"Total characters: {total_chars:,}")
        print(f"Average characters per line: {total_chars / line_count:.1f}")
        return file_size_bytes
    except Exception as e:
        print(f"Warning: could not compute file size or line count: {e}")
        return 0


def train_spm_model(input_file: str, model_prefix: str, vocab_size: int,
                    model_type: str = 'bpe', character_coverage: float = 0.9995):
    """
    Train a SentencePiece tokenizer model.

    Args:
        input_file: path of the training corpus
        model_prefix: prefix for the output model files
        vocab_size: vocabulary size
        model_type: algorithm ('bpe', 'unigram', 'char', 'word')
        character_coverage: character coverage (0-1, usually 0.995-0.9995)
    """
    if not os.path.exists(input_file):
        print(f"Error: corpus file not found: {input_file}")
        sys.exit(1)

    # Make sure the output directory exists
    output_dir = os.path.dirname(model_prefix)
    if output_dir and not os.path.exists(output_dir):
        os.makedirs(output_dir, exist_ok=True)
        print(f"Created output directory: {output_dir}")

    # Print corpus statistics
    print("\n=== Corpus analysis ===")
    get_corpus_size(input_file)

    # Build the training parameters
    # For a ~1.5GB corpus, enable train_extremely_large_corpus=True to speed things up
    train_params = {
        'input': input_file,
        'model_prefix': model_prefix,
        'vocab_size': vocab_size,
        'model_type': model_type,
        'character_coverage': character_coverage,
        'num_threads': 32,                     # raised to 32 to maximize CPU utilization
        'bos_id': 0,
        'eos_id': 1,
        'unk_id': 2,
        'pad_id': -1,
        'normalization_rule_name': 'identity',
        'input_sentence_size': 2000000,        # sentence sample size
        'train_extremely_large_corpus': True,  # required for large corpora
        'shuffle_input_sentence': True,
        'seed_sentencepiece_size': 2000000,    # seed sentencepiece size
        'hard_vocab_limit': False,             # allow exceeding the target vocab for better quality
    }

    print("\n=== SentencePiece training parameters ===")
    for key, value in train_params.items():
        print(f"  {key}: {value}")
    print("=" * 35)

    print("\nTraining the SentencePiece model...")
    print("  (please wait; progress is printed by SentencePiece itself)\n")

    try:
        # Run training
        spm.SentencePieceTrainer.train(**train_params)
        print("\nTokenizer training finished!")
        print(f"  Model file: {model_prefix}.model")
        print(f"  Vocab file: {model_prefix}.vocab")

        # Verify the model files were created
        if os.path.exists(f"{model_prefix}.model") and os.path.exists(f"{model_prefix}.vocab"):
            model_size_kb = os.path.getsize(f"{model_prefix}.model") / 1024
            print(f"\nModel file size: {model_size_kb:.2f} KB")

            # Load the model for a quick smoke test
            print("\nRunning a quick test...")
            sp = spm.SentencePieceProcessor(model_file=f"{model_prefix}.model")
            test_text = "这是一个分词测试句子。"
            tokens = sp.encode(test_text, out_type=str)
            ids = sp.encode(test_text, out_type=int)
            print(f"  Test text: {test_text}")
            print(f"  Pieces: {tokens}")
            print(f"  Token IDs: {ids}")
        else:
            print("\nWarning: model files were not generated; check the input data or parameters")
    except Exception as e:
        print(f"\nTraining failed: {e}")
        sys.exit(1)


def main():
    parser = argparse.ArgumentParser(
        description="Train a tokenizer model with SentencePiece.",
        formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument("input_file", type=str,
                        help="Path of the training corpus (e.g., data/cleaned_wiki_full.txt)")
    parser.add_argument("model_prefix", type=str,
                        help="Output prefix for the model files (e.g., workdir/spm_wiki)")
    parser.add_argument("vocab_size", type=int,
                        help="Vocabulary size (e.g., 32000)")
    parser.add_argument("--model_type", type=str, default='bpe',
                        choices=['bpe', 'unigram', 'char', 'word'],
                        help="Tokenization algorithm (default: bpe)")
    parser.add_argument("--character_coverage", type=float, default=0.9995,
                        help="Character coverage in [0-1]. For a small vocab (8K), try 0.99 or lower")
    args = parser.parse_args()

    print("\n" + "=" * 50)
    print("SentencePiece tokenizer trainer")
    print("=" * 50)
    print(f"Input corpus: {args.input_file}")
    print(f"Output model prefix: {args.model_prefix}")
    print(f"Vocabulary size: {args.vocab_size}")
    print(f"Algorithm: {args.model_type}")
    print(f"Character coverage: {args.character_coverage}")
    print("=" * 50 + "\n")

    train_spm_model(args.input_file, args.model_prefix, args.vocab_size,
                    args.model_type, args.character_coverage)


if __name__ == "__main__":
    main()
Starting SentencePiece tokenizer training...
Input corpus: data/cleaned_wiki_full.txt
Output model prefix: workdir/spm_wiki_16k
Vocabulary size: 16000
Corpus file size: 1697.54 MB
Counting lines: 1256186it [00:05, 230354.42it/s]
Total lines (articles): 1256186
--- SentencePiece training parameters ---
--input=data/cleaned_wiki_full.txt
--model_prefix=workdir/spm_wiki_16k
--vocab_size=16000
--model_type=bpe
--character_coverage=0.9995
--num_threads=16
--bos_id=0
--eos_id=1
--unk_id=2
--pad_id=-1
------------------------------
Starting training... watch SentencePiece's own progress output below.
sentencepiece_trainer.cc(178) LOG(INFO) Running command: --input=data/cleaned_wiki_full.txt --model_prefix=workdir/spm_colinai_16000 --vocab_size=16000 --model_type=bpe --character_coverage=0.9995 --num_threads=16 --bos_id=0 --eos_id=1 --unk_id=2 --pad_id=-1
sentencepiece_trainer.cc(78) LOG(INFO) Starts training with :
trainer_spec {
  input: data/cleaned_wiki_full.txt
  input_format:
  model_prefix: workdir/spm_colinai_16000
  model_type: BPE
  vocab_size: 16000
  self_test_sample_size: 0
  character_coverage: 0.9995
  input_sentence_size: 0
  shuffle_input_sentence: 1
  seed_sentencepiece_size: 1000000
  shrinking_factor: 0.75
  max_sentence_length: 4192
  num_threads: 16
  num_sub_iterations: 2
  max_sentencepiece_length: 16
  split_by_unicode_script: 1
  split_by_number: 1
  split_by_whitespace: 1
  split_digits: 0
  pretokenization_delimiter:
  treat_whitespace_as_suffix: 0
  allow_whitespace_only_pieces: 0
  required_chars:
  byte_fallback: 0
  vocabulary_output_piece_score: 1
  train_extremely_large_corpus: 0
  seed_sentencepieces_file:
  hard_vocab_limit: 1
  use_all_vocab: 0
  unk_id: 2
  bos_id: 0
  eos_id: 1
  pad_id: -1
  unk_piece: <unk>
  bos_piece: <s>
  eos_piece: </s>
  pad_piece: <pad>
  unk_surface: ⁇
  enable_differential_privacy: 0
  differential_privacy_noise_level: 0
  differential_privacy_clipping_threshold: 0
}
normalizer_spec {
  name: nmt_nfkc
  add_dummy_prefix: 1
  remove_extra_whitespaces: 1
  escape_whitespaces: 1
  normalization_rule_tsv:
}
denormalizer_spec {}
trainer_interface.cc(355) LOG(INFO) SentenceIterator is not specified. Using MultiFileSentenceIterator.
trainer_interface.cc(186) LOG(INFO) Loading corpus: data/cleaned_wiki_full.txt
trainer_interface.cc(382) LOG(WARNING) Found too long line (18615 > 4192).
trainer_interface.cc(384) LOG(WARNING) Too long lines are skipped in the training.
trainer_interface.cc(385) LOG(WARNING) The maximum length can be changed with --max_sentence_length=<size> flag.
trainer_interface.cc(411) LOG(INFO) Loaded all 528882 sentences
trainer_interface.cc(418) LOG(INFO) Skipped 99211 too long sentences.
trainer_interface.cc(427) LOG(INFO) Adding meta_piece: <s>
trainer_interface.cc(427) LOG(INFO) Adding meta_piece: </s>
trainer_interface.cc(427) LOG(INFO) Adding meta_piece: <unk>
trainer_interface.cc(432) LOG(INFO) Normalizing sentences...
trainer_interface.cc(541) LOG(INFO) all chars count=281809036
trainer_interface.cc(552) LOG(INFO) Done: 99.95% characters are covered.
trainer_interface.cc(562) LOG(INFO) Alphabet size=8686
trainer_interface.cc(563) LOG(INFO) Final character coverage=0.9995
trainer_interface.cc(594) LOG(INFO) Done! preprocessed 528882 sentences.
trainer_interface.cc(600) LOG(INFO) Tokenizing input sentences with whitespace: 528882
trainer_interface.cc(611) LOG(INFO) Done! 3885388
...
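Once training finishes, you can sanity-check the tokenizer with a small round-trip test (a sketch, assuming the model file is at workdir/spm_wiki_16k.model):

import sentencepiece as spm

# Load the trained tokenizer and do an encode/decode round trip
sp = spm.SentencePieceProcessor(model_file="workdir/spm_wiki_16k.model")
print("Vocab size:", sp.get_piece_size())

text = "预训练模型学习语言的基本规律。"
pieces = sp.encode(text, out_type=str)  # subword pieces
ids = sp.encode(text, out_type=int)     # token IDs
print(pieces)
print(ids)
print(sp.decode(ids))                   # should reproduce the original sentence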
05. Converting Raw Text to Token ID Sequences
When preparing to train a large language model, the massive text corpus must be converted into a numeric format the model can consume. This step encodes the cleaned corpus into a sequence of integer token IDs. To avoid the memory cost of loading the whole file at once, the script reads the corpus in configurable chunks. All token IDs are finally collected into a compact torch.int32 PyTorch tensor and stored as a .pt file. This makes the data easy for a PyTorch DataLoader to read quickly and also provides key statistics and integrity checks, giving us a stable, high-throughput preprocessing step for building the LLM dataset.
import sys
import torch
import sentencepiece as spm
import argparse
from tqdm import tqdm
import os

# Usage:
# python scripts/preprocess_data.py workdir/spm_wiki.model data/cleaned_wiki_full.txt workdir/wiki_tokens.pt


def preprocess(sp_model_path: str, corpus_path: str, output_path: str,
               chunk_size_mb: int = 50):
    """
    Read the corpus in chunks, encode it to token IDs, and save a PyTorch file.

    Args:
        sp_model_path: path of the SentencePiece model file
        corpus_path: path of the input corpus
        output_path: path of the output .pt file
        chunk_size_mb: size of each text chunk in MB (default 50)
    """
    # Validate inputs
    if not os.path.exists(sp_model_path):
        print(f"Error: tokenizer model not found: {sp_model_path}")
        sys.exit(1)
    if not os.path.exists(corpus_path):
        print(f"Error: corpus file not found: {corpus_path}")
        sys.exit(1)

    # Load the tokenizer
    try:
        sp = spm.SentencePieceProcessor(model_file=sp_model_path)
        vocab_size = sp.get_piece_size()
        print("Tokenizer loaded")
        print(f"  Vocabulary size: {vocab_size}")
        print(f"  Special tokens: BOS={sp.bos_id()}, EOS={sp.eos_id()}, UNK={sp.unk_id()}, PAD={sp.pad_id()}")
    except Exception as e:
        print(f"Failed to load tokenizer: {e}")
        sys.exit(1)

    # Make sure the output directory exists
    output_dir = os.path.dirname(output_path)
    if output_dir and not os.path.exists(output_dir):
        os.makedirs(output_dir, exist_ok=True)

    print("\nProcessing corpus...")
    print(f"  Input file: {corpus_path}")
    print(f"  Output file: {output_path}")
    print(f"  Chunk size: {chunk_size_mb} MB\n")

    # Total size drives the progress bar
    total_bytes = os.path.getsize(corpus_path)
    chunk_size_bytes = chunk_size_mb * 1024 * 1024

    token_ids = []
    tokens_processed = 0
    chunks_processed = 0

    try:
        with open(corpus_path, 'r', encoding='utf-8') as f:
            with tqdm(total=total_bytes, unit='B', unit_scale=True, desc="Encoding corpus") as pbar:
                while True:
                    chunk = f.read(chunk_size_bytes)
                    if not chunk:
                        break
                    # Encode directly (cleaned_wiki_full.txt is already cleaned)
                    ids = sp.encode(chunk, out_type=int)
                    token_ids.extend(ids)

                    # Update the progress bar
                    bytes_read = len(chunk.encode('utf-8'))
                    pbar.update(bytes_read)
                    tokens_processed += len(ids)
                    chunks_processed += 1

                    # Periodically report progress
                    if chunks_processed % 10 == 0:
                        pbar.set_postfix({'chunks': chunks_processed,
                                          'tokens': f'{tokens_processed:,}'})

        print("\nEncoding finished")
        print(f"  Chunks processed: {chunks_processed}")
        print(f"  Total tokens: {tokens_processed:,}")

        # Convert to a PyTorch tensor
        print("\nConverting to a tensor and saving...")
        final_tensor = torch.tensor(token_ids, dtype=torch.int32)
        print(f"  Tensor shape: {final_tensor.shape}")
        print(f"  Tensor size: {final_tensor.numel():,}")
        print(f"  Dtype: {final_tensor.dtype}")
        print(f"  Memory: {final_tensor.numel() * 4 / (1024**3):.2f} GB")

        # Validate the token ID range
        min_id = final_tensor.min().item()
        max_id = final_tensor.max().item()
        print(f"  Token ID range: [{min_id}, {max_id}]")
        if max_id >= vocab_size or min_id < 0:
            print("  Warning: out-of-range token IDs detected!")
            print(f"    Vocabulary size: {vocab_size}")
            print(f"    Max ID: {max_id}")

        # Save the tensor
        torch.save(final_tensor, output_path)
        file_size_mb = os.path.getsize(output_path) / (1024 ** 2)
        print(f"\nToken IDs saved to {output_path}")
        print(f"  File size: {file_size_mb:.2f} MB")

        # Verify the saved file
        print("\nVerifying the saved file...")
        loaded_tensor = torch.load(output_path)
        print(f"  Loaded OK, shape: {loaded_tensor.shape}")
        print(f"  Identical: {torch.equal(final_tensor, loaded_tensor)}")
        print("\nPreprocessing finished!")
    except Exception as e:
        print(f"\nError during processing: {e}")
        import traceback
        traceback.print_exc()
        sys.exit(1)


def main():
    parser = argparse.ArgumentParser(
        description="Convert a cleaned text corpus into a binary token ID file.",
        formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument("model_path", type=str,
                        help="SentencePiece model file (e.g., workdir/spm_wiki.model)")
    parser.add_argument("corpus_path", type=str,
                        help="Input corpus file (e.g., data/cleaned_wiki_full.txt)")
    parser.add_argument("output_path", type=str,
                        help="Output token ID file (e.g., workdir/wiki_tokens.pt)")
    parser.add_argument("--chunk_size", type=int, default=50,
                        help="Chunk size in MB, default 50. Larger chunks are faster but use more memory.")
    args = parser.parse_args()

    print("\n" + "=" * 60)
    print("Data preprocessing - text to token IDs")
    print("=" * 60)
    print(f"SentencePiece model: {args.model_path}")
    print(f"Input corpus: {args.corpus_path}")
    print(f"Output file: {args.output_path}")
    print(f"Chunk size: {args.chunk_size} MB")
    print("=" * 60 + "\n")

    preprocess(args.model_path, args.corpus_path, args.output_path, args.chunk_size)


if __name__ == "__main__":
    main()
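Once preprocessing finishes, a quick spot check of the generated .pt file is worthwhile. This is a sketch, not part of the repo; adjust the file paths to your actual outputs:

import torch
import sentencepiece as spm

# Spot-check the token ID file: size, dtype, and decode a small prefix
ids = torch.load("workdir/wiki_tokens.pt")
print(ids.shape, ids.dtype)          # expect a 1-D int32 tensor

sp = spm.SentencePieceProcessor(model_file="workdir/spm_wiki_16k.model")
print(sp.decode(ids[:50].tolist()))  # should show the beginning of the corpus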
06. Pretraining the Model
Below is the complete pretraining script. The model is a decoder-only GPT with RoPE positional encoding and a GLU feed-forward network; training uses mixed precision, gradient accumulation, a warmup-plus-cosine learning-rate schedule, and supports resuming from checkpoints.

"""High-performance GPT training script"""
from __future__ import annotations
import sys
import os
import math
import json
from datetime import datetime
from typing import Optional

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import sentencepiece as spm
from tqdm import tqdm


# ==================== Configuration ====================
class Config:
    BLOCK_SIZE = 512          # 256
    BATCH_SIZE = 32           # 64
    GRAD_ACCUM_STEPS = 4      # 1
    MODEL_DIM = 384           # 256
    N_LAYERS = 5              # 2
    NUM_HEADS = 6             # 4
    HEAD_DIM = MODEL_DIM // NUM_HEADS
    FFN_DIM = MODEL_DIM * 4
    VOCAB_SIZE = None

    EPOCHS = 1
    MAX_STEPS = 10000         # tune to your hardware and time budget
    WARMUP_STEPS = 500
    LR = 1e-4
    MIN_LR = 1e-5
    WEIGHT_DECAY = 0.01
    GRAD_CLIP = 1.0
    DROPOUT = 0.1

    CHECKPOINT_EVERY = 5000
    LOG_EVERY = 100
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    CHECKPOINT_DIR = "./checkpoints"
    LATEST_CHECKPOINT = "latest_checkpoint.pth"
    NUM_WORKERS = 8
    SEED = 42

    # Enable bfloat16 (recommended on modern GPUs)
    DTYPE = torch.bfloat16 if (torch.cuda.is_available() and torch.cuda.is_bf16_supported()) else torch.float16


CFG = Config()

if CFG.DEVICE == 'cuda':
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    torch.cuda.empty_cache()

# Report which mixed-precision dtype is in use
if CFG.DTYPE == torch.bfloat16:
    print("Using bfloat16 mixed precision (recommended)")
else:
    print("Using float16 mixed precision")


# ==================== Utilities ====================
def print_gpu_memory():
    if torch.cuda.is_available():
        allocated = torch.cuda.memory_allocated() / (1024**3)
        reserved = torch.cuda.memory_reserved() / (1024**3)
        print(f"GPU memory: {allocated:.2f}GB allocated / {reserved:.2f}GB reserved")


def set_seed(seed: int):
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


set_seed(CFG.SEED)


# ==================== Dataset ====================
class TextDataset(Dataset):
    def __init__(self, token_ids: torch.Tensor, block_size: int):
        self.ids = token_ids.long()
        self.block_size = block_size

    def __len__(self):
        return max(0, self.ids.size(0) - self.block_size)

    def __getitem__(self, idx):
        x = self.ids[idx: idx + self.block_size]
        y = self.ids[idx + 1: idx + 1 + self.block_size]
        return x, y


# ==================== RoPE positional encoding ====================
class RotaryPositionalEmbedding(nn.Module):
    """RoPE implementation"""

    def __init__(self, head_dim: int, max_seq_len: int = 2048):
        super().__init__()
        self.head_dim = head_dim
        assert head_dim % 2 == 0, "head_dim must be even"
        # Base frequencies: theta_i = 10000^(-2i/d)
        inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
        self.register_buffer("inv_freq", inv_freq)
        self.max_seq_len = max_seq_len
        self._seq_len_cached = max_seq_len
        self._cos_cached = None
        self._sin_cached = None
        self._update_cos_sin_cache(max_seq_len, device=self.inv_freq.device)

    def _update_cos_sin_cache(self, seq_len: int, device: torch.device):
        if seq_len == self._seq_len_cached and self._cos_cached is not None:
            return
        # m: (seq_len,), theta_i: (head_dim//2,)
        m = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", m, self.inv_freq)  # (seq_len, head_dim//2)
        # Build the full rotation table (each complex pair repeated)
        emb = torch.cat([freqs, freqs], dim=-1)  # (seq_len, head_dim)
        cos = emb.cos()[None, None, :, :]  # (1, 1, seq_len, head_dim)
        sin = emb.sin()[None, None, :, :]  # (1, 1, seq_len, head_dim)
        self._cos_cached = cos
        self._sin_cached = sin
        self._seq_len_cached = seq_len

    def forward(self, seq_len: int, device: Optional[torch.device] = None):
        if device is None:
            device = self.inv_freq.device
        self._update_cos_sin_cache(seq_len, device=device)
        return self._cos_cached.to(device), self._sin_cached.to(device)


def apply_rotary_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """Apply the RoPE rotation"""
    # x: (B, H, T, D), cos/sin: (1, 1, T, D)
    # Uses (x, y) -> (x*cos - y*sin, x*sin + y*cos)
    return (x * cos) + (_rotate_half(x) * sin)


def _rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Rotate the vector by 90 degrees"""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=-1)


# ==================== Attention ====================
class FlashAttention(nn.Module):
    def __init__(self, embed_dim: int, num_heads: int, attn_dropout: float = 0.0):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        assert embed_dim % num_heads == 0
        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim ** -0.5

        self.qkv = nn.Linear(embed_dim, embed_dim * 3, bias=False)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.attn_dropout = nn.Dropout(attn_dropout)
        self.rope = RotaryPositionalEmbedding(self.head_dim)

    def forward(self, x: torch.Tensor, causal_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        B, T, C = x.shape
        assert T <= self.rope.max_seq_len, f"Seq len {T} exceeds max {self.rope.max_seq_len}"

        qkv = self.qkv(x)
        qkv = qkv.view(B, T, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.unbind(dim=2)
        q = q.permute(0, 2, 1, 3)  # (B, H, T, D)
        k = k.permute(0, 2, 1, 3)
        v = v.permute(0, 2, 1, 3)

        # Apply RoPE
        cos, sin = self.rope(T, device=x.device)
        q = apply_rotary_emb(q, cos, sin)
        k = apply_rotary_emb(k, cos, sin)

        # Attention
        # Note: torch.nn.functional.scaled_dot_product_attention combined with torch.compile would be faster here
        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        if causal_mask is not None:
            scores = scores.masked_fill(causal_mask == 0, float('-inf'))
        attn = torch.softmax(scores, dim=-1)
        attn = self.attn_dropout(attn)
        out = torch.matmul(attn, v)
        out = out.permute(0, 2, 1, 3).contiguous().view(B, T, C)
        return self.out_proj(out)


# ==================== Feed-forward network ====================
class GLU(nn.Module):
    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()
        self.linear = nn.Linear(in_dim, out_dim * 2)

    def forward(self, x):
        x, gates = self.linear(x).chunk(2, dim=-1)
        return x * torch.nn.functional.silu(gates)


class FeedForward(nn.Module):
    def __init__(self, dim: int, hidden_dim: int, dropout: float = 0.1):
        super().__init__()
        self.net = nn.Sequential(
            GLU(dim, hidden_dim),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)


# ==================== Transformer block ====================
class TransformerBlock(nn.Module):
    def __init__(self, dim: int, num_heads: int, ffn_dim: int, dropout: float = 0.1):
        super().__init__()
        self.ln1 = nn.LayerNorm(dim)
        self.attn = FlashAttention(dim, num_heads, attn_dropout=dropout)
        self.ln2 = nn.LayerNorm(dim)
        self.ff = FeedForward(dim, ffn_dim, dropout)

    def forward(self, x, causal_mask=None):
        x = x + self.attn(self.ln1(x), causal_mask)
        x = x + self.ff(self.ln2(x))
        return x


# ==================== GPT model (pos_emb removed) ====================
class GPTModel(nn.Module):
    def __init__(self, vocab_size: int, block_size: int, dim: int = CFG.MODEL_DIM,
                 num_layers: int = CFG.N_LAYERS, num_heads: int = CFG.NUM_HEADS,
                 ffn_dim: int = CFG.FFN_DIM, dropout: float = CFG.DROPOUT,
                 tie_weights: bool = True):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, dim)
        # self.pos_emb = nn.Embedding(block_size, dim)  # removed: conflicts with RoPE
        self.dropout = nn.Dropout(dropout)
        self.blocks = nn.ModuleList([
            TransformerBlock(dim, num_heads, ffn_dim, dropout) for _ in range(num_layers)
        ])
        self.ln_final = nn.LayerNorm(dim)
        self.lm_head = nn.Linear(dim, vocab_size, bias=False)
        if tie_weights:
            self.lm_head.weight = self.token_emb.weight
        self.block_size = block_size
        self.apply(self._init_weights)

        n_params = sum(p.numel() for p in self.parameters())
        print(f"Model parameters: {n_params/1e6:.2f}M")

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)

    def forward(self, idx):
        B, T = idx.shape
        assert T <= self.block_size, f"Seq len {T} exceeds block_size {self.block_size}"
        token_emb = self.token_emb(idx)
        x = self.dropout(token_emb)  # token embedding
        causal_mask = torch.tril(torch.ones(T, T, device=idx.device, dtype=torch.bool))[None, None, :, :]
        for block in self.blocks:
            x = block(x, causal_mask)
        x = self.ln_final(x)
        logits = self.lm_head(x)
        return logits


# ==================== Checkpoint management ====================
def save_checkpoint(model, optimizer, scaler, lr_scheduler, step: int, loss: float, config_dict: dict):
    os.makedirs(CFG.CHECKPOINT_DIR, exist_ok=True)
    checkpoint_path = os.path.join(CFG.CHECKPOINT_DIR, CFG.LATEST_CHECKPOINT)
    state = {
        'step': step,
        'loss': loss,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'config': config_dict,
        'torch_rng_state': torch.get_rng_state(),
        'cuda_rng_state': torch.cuda.get_rng_state() if torch.cuda.is_available() else None,
    }
    if scaler is not None and hasattr(scaler, "state_dict"):
        state['scaler_state_dict'] = scaler.state_dict()
    if lr_scheduler is not None:
        state['lr_scheduler_state_dict'] = {
            'current_step': lr_scheduler.current_step,
            'warmup_steps': lr_scheduler.warmup_steps,
            'total_steps': lr_scheduler.total_steps,
            'base_lr': lr_scheduler.base_lr,
            'min_lr': lr_scheduler.min_lr,
        }
    torch.save(state, checkpoint_path)
    try:
        with open(os.path.join(CFG.CHECKPOINT_DIR, "config.json"), "w", encoding="utf-8") as f:
            json.dump(config_dict, f, indent=2)
    except Exception:
        pass
    print(f"Checkpoint saved: {checkpoint_path} (step {step}, loss {loss:.4f})")


def load_checkpoint(checkpoint_path: str, model, optimizer, scaler, lr_scheduler):
    if not os.path.exists(checkpoint_path):
        return None
    checkpoint = torch.load(checkpoint_path, map_location=CFG.DEVICE)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    if checkpoint.get('scaler_state_dict') is not None and scaler is not None:
        try:
            scaler.load_state_dict(checkpoint['scaler_state_dict'])
        except Exception as e:
            print(f"Could not restore scaler: {e}")
    if checkpoint.get('lr_scheduler_state_dict') is not None and lr_scheduler is not None:
        try:
            sched_state = checkpoint['lr_scheduler_state_dict']
            lr_scheduler.current_step = sched_state['current_step']
            lr_scheduler.warmup_steps = sched_state['warmup_steps']
            lr_scheduler.total_steps = sched_state['total_steps']
            lr_scheduler.base_lr = sched_state['base_lr']
            lr_scheduler.min_lr = sched_state['min_lr']
        except Exception as e:
            print(f"Could not restore lr_scheduler: {e}")
    torch.set_rng_state(checkpoint['torch_rng_state'])
    if torch.cuda.is_available() and checkpoint.get('cuda_rng_state') is not None:
        torch.cuda.set_rng_state(checkpoint['cuda_rng_state'])
    print(f"Checkpoint loaded: {checkpoint_path}")
    print(f"  Step: {checkpoint['step']}, Loss: {checkpoint['loss']:.4f}")
    return checkpoint['step']


# ==================== Learning-rate scheduler ====================
class WarmupCosineScheduler:
    def __init__(self, optimizer, warmup_steps: int, total_steps: int, base_lr: float, min_lr: float):
        self.optimizer = optimizer
        self.warmup_steps = max(0, int(warmup_steps))
        self.total_steps = max(1, int(total_steps))
        self.base_lr = base_lr
        self.min_lr = min_lr
        self.current_step = 0

    def get_lr(self, step: int = None) -> float:
        """Compute the learning rate for a given step (does not modify the optimizer)"""
        if step is None:
            step = self.current_step
        if step < self.warmup_steps and self.warmup_steps > 0:
            return self.base_lr * (step / float(self.warmup_steps))
        else:
            denom = max(1, (self.total_steps - self.warmup_steps))
            progress = (step - self.warmup_steps) / denom
            progress = min(1.0, max(0.0, progress))
            return self.min_lr + (self.base_lr - self.min_lr) * 0.5 * (1.0 + math.cos(math.pi * progress))

    def step(self):
        """Advance the schedule by one step"""
        lr = self.get_lr(self.current_step)
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr
        self.current_step += 1
        return lr


# ==================== Training loop ====================
def train(model: nn.Module, train_loader: DataLoader, epochs: int = CFG.EPOCHS, resume: bool = False):
    # Detect fused optimizer support
    fused = False
    try:
        fused = torch.cuda.is_available() and ("fused" in torch.optim.AdamW.__init__.__code__.co_varnames)
    except Exception:
        fused = False

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=CFG.LR,
        betas=(0.9, 0.95),
        weight_decay=CFG.WEIGHT_DECAY,
        fused=fused)

    # Use the configured DTYPE; only float16 needs a GradScaler
    scaler = torch.cuda.amp.GradScaler(enabled=(CFG.DEVICE == "cuda") and (CFG.DTYPE == torch.float16))
    loss_fn = nn.CrossEntropyLoss()

    total_steps = CFG.MAX_STEPS if CFG.MAX_STEPS else len(train_loader) * epochs
    lr_scheduler = WarmupCosineScheduler(optimizer, CFG.WARMUP_STEPS, total_steps, CFG.LR, CFG.MIN_LR)

    model.train()
    start_step = 0
    best_loss = float('inf')

    checkpoint_path = os.path.join(CFG.CHECKPOINT_DIR, CFG.LATEST_CHECKPOINT)
    if resume and os.path.exists(checkpoint_path):
        loaded_step = load_checkpoint(checkpoint_path, model, optimizer, scaler, lr_scheduler)
        if loaded_step is not None:
            start_step = loaded_step

    global_step = start_step
    grad_accum_counter = 0
    accumulated_loss = 0.0

    print("\n" + "="*60)
    print("Starting training...")
    print("="*60)
    print_gpu_memory()
    print()

    # Decide whether scaler.scale() is needed
    use_scaler = (CFG.DEVICE == "cuda") and (CFG.DTYPE == torch.float16)

    for epoch in range(epochs):
        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}",
                    initial=global_step % len(train_loader) if epoch == 0 else 0)
        num_batches = 0
        last_lr = None

        for batch_idx, (xb, yb) in enumerate(pbar):
            # Skip batches already seen (when resuming mid-epoch)
            if global_step > start_step and batch_idx < (start_step % len(train_loader)):
                continue

            xb = xb.to(CFG.DEVICE, non_blocking=True)
            yb = yb.to(CFG.DEVICE, non_blocking=True)

            with torch.cuda.amp.autocast(enabled=(CFG.DEVICE == "cuda"), dtype=CFG.DTYPE):
                logits = model(xb)
                loss = loss_fn(logits.view(-1, logits.size(-1)), yb.view(-1))

            loss_item = loss.item()
            loss = loss / CFG.GRAD_ACCUM_STEPS
            if use_scaler:
                scaler.scale(loss).backward()
            else:
                loss.backward()

            grad_accum_counter += 1
            accumulated_loss += loss_item
            num_batches += 1

            # Gradient accumulation: take an optimizer step once the threshold is reached
            # (one global_step = one optimizer step, after GRAD_ACCUM_STEPS micro-batches)
            if grad_accum_counter >= CFG.GRAD_ACCUM_STEPS:
                lr_scheduler.step()  # update LR first
                global_step += 1     # only an optimizer step counts as a global_step

                if use_scaler:
                    scaler.unscale_(optimizer)
                # Gradient clipping (after unscale, or in non-AMP mode)
                torch.nn.utils.clip_grad_norm_(model.parameters(), CFG.GRAD_CLIP)
                if use_scaler:
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    optimizer.step()
                optimizer.zero_grad()
                grad_accum_counter = 0
                last_lr = lr_scheduler.get_lr(global_step)  # LR at the current step

                # Logging
                if global_step % CFG.LOG_EVERY == 0 or (global_step == 1):
                    # accumulated_loss is the raw summed loss; num_batches is how many batches it covers
                    avg_loss = accumulated_loss / num_batches if num_batches > 0 else 0.0
                    pbar.set_postfix({
                        'step': global_step,
                        'loss': f'{avg_loss:.4f}',
                        'lr': f'{last_lr:.2e}' if last_lr is not None else 'N/A'
                    })
                    # Reset the accumulators for the next LOG_EVERY window
                    accumulated_loss = 0.0
                    num_batches = 0

                # Checkpointing
                if global_step > start_step and global_step % CFG.CHECKPOINT_EVERY == 0:
                    # Use the average loss since the last log point
                    current_avg_loss = accumulated_loss / num_batches if num_batches > 0 else loss_item
                    config_dict = {
                        'vocab_size': CFG.VOCAB_SIZE,
                        'block_size': CFG.BLOCK_SIZE,
                        'model_dim': CFG.MODEL_DIM,
                        'n_layers': CFG.N_LAYERS,
                        'num_heads': CFG.NUM_HEADS,
                        'created_at': datetime.now().isoformat()
                    }
                    save_checkpoint(model, optimizer, scaler, lr_scheduler, global_step, current_avg_loss, config_dict)
                    torch.cuda.empty_cache()

                if CFG.MAX_STEPS and global_step >= CFG.MAX_STEPS:
                    break

        # Flush any leftover gradients at the end of the epoch (grad_accum_counter > 0)
        if grad_accum_counter > 0:
            if use_scaler:
                scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), CFG.GRAD_CLIP)
            if use_scaler:
                scaler.step(optimizer)
                scaler.update()
            else:
                optimizer.step()
            optimizer.zero_grad()
            lr_scheduler.step()
            global_step += 1
            grad_accum_counter = 0

        if num_batches > 0:
            final_avg_loss = accumulated_loss / num_batches
        else:
            final_avg_loss = float('inf')

        if final_avg_loss < best_loss:
            best_loss = final_avg_loss
            best_path = os.path.join(CFG.CHECKPOINT_DIR, "best_model.pth")
            torch.save(model.state_dict(), best_path)
            print(f"Best model saved (loss: {best_loss:.4f})")

        print(f"\n[Epoch {epoch+1}] Avg Loss: {final_avg_loss:.4f}")
        if CFG.MAX_STEPS and global_step >= CFG.MAX_STEPS:
            break

    print("\nTraining finished!")


# ==================== Main ====================
def main():
    if len(sys.argv) < 4:
        print("Usage: python train_20251012_v1.py workdir/spm_wiki_16k.model workdir/wiki_tokens_16k.pt models/gpt_wiki.pth [--resume]")
        sys.exit(1)

    sp_model_path, token_file_path, out_path = sys.argv[1:4]
    resume = "--resume" in sys.argv

    if not os.path.exists(token_file_path):
        print(f"Token file not found: {token_file_path}")
        sys.exit(1)

    # If DTYPE is bfloat16 but the environment doesn't support it, fall back
    if CFG.DTYPE == torch.bfloat16 and not torch.cuda.is_bf16_supported():
        print("Warning: bfloat16 is not supported on this CUDA device; falling back to float16.")
        CFG.DTYPE = torch.float16

    sp = spm.SentencePieceProcessor(model_file=sp_model_path)
    CFG.VOCAB_SIZE = sp.get_piece_size()

    print("="*60)
    print("GPT language model training")
    print("="*60)
    print(f"Tokenizer: {sp_model_path}")
    print(f"Token file: {token_file_path}")
    print(f"Output model: {out_path}")
    print(f"Device: {CFG.DEVICE}")
    print("\nModel config:")
    print(f"  - VOCAB_SIZE: {CFG.VOCAB_SIZE}")
    print(f"  - BLOCK_SIZE: {CFG.BLOCK_SIZE}")
    print(f"  - MODEL_DIM: {CFG.MODEL_DIM}")
    print(f"  - N_LAYERS: {CFG.N_LAYERS}")
    print(f"  - NUM_HEADS: {CFG.NUM_HEADS}")
    print("\nTraining config:")
    print(f"  - BATCH_SIZE: {CFG.BATCH_SIZE}")
    print(f"  - GRAD_ACCUM_STEPS: {CFG.GRAD_ACCUM_STEPS}")
    print(f"  - Effective BATCH_SIZE: {CFG.BATCH_SIZE * CFG.GRAD_ACCUM_STEPS}")
    print(f"  - LR: {CFG.LR}, WARMUP_STEPS: {CFG.WARMUP_STEPS}")
    print("="*60)

    print(f"\nLoading token file: {token_file_path}")
    ids = torch.load(token_file_path)
    print(f"Loaded {ids.numel():,} tokens ({ids.numel() * ids.element_size() / (1024**3):.2f} GB)")

    dataset = TextDataset(ids, CFG.BLOCK_SIZE)
    del ids
    torch.cuda.empty_cache()

    # Improvement: enable shuffle=True for pretraining
    num_workers = CFG.NUM_WORKERS
    try:
        train_loader = DataLoader(
            dataset,
            batch_size=CFG.BATCH_SIZE,
            shuffle=True,  # enable shuffling
            pin_memory=(CFG.DEVICE == "cuda"),
            num_workers=num_workers,
            persistent_workers=True if num_workers > 0 else False)
    except Exception as e:
        print(f"DataLoader error: {e}, falling back to num_workers=0")
        train_loader = DataLoader(
            dataset,
            batch_size=CFG.BATCH_SIZE,
            shuffle=True,
            pin_memory=(CFG.DEVICE == "cuda"),
            num_workers=0)

    model = GPTModel(
        CFG.VOCAB_SIZE,
        CFG.BLOCK_SIZE,
        dim=CFG.MODEL_DIM,
        num_layers=CFG.N_LAYERS,
        num_heads=CFG.NUM_HEADS,
        ffn_dim=CFG.FFN_DIM,
        dropout=CFG.DROPOUT).to(CFG.DEVICE)

    # Try to compile (tolerate failure)
    try:
        model = torch.compile(model, mode='reduce-overhead')
        print("torch.compile() enabled")
    except Exception as e:
        print(f"Skipping torch.compile(): {e}")

    train(model, train_loader, epochs=CFG.EPOCHS, resume=resume)

    torch.save(model.state_dict(), out_path)
    print(f"\nFinal model saved to {out_path}")
    print_gpu_memory()


if __name__ == "__main__":
    main()
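As a quick sanity check on the "17M parameters" claim: with the configuration above (VOCAB_SIZE=16000, MODEL_DIM=384, N_LAYERS=5, FFN_DIM=1536, and the embedding tied to the output head), a rough estimate that ignores the small LayerNorm terms lands in the right place:

# Rough parameter count for the model above (same order as the script's printed n_params)
V, D, L, F = 16000, 384, 5, 384 * 4

emb = V * D                   # token embedding (weight shared with lm_head)
per_layer = (D * 3 * D        # qkv projection
             + D * D          # output projection
             + D * F * 2      # GLU: one Linear producing 2*F outputs
             + F * D)         # FFN down-projection Linear
total = emb + L * per_layer
print(f"{total / 1e6:.1f}M")  # ~17.9M, i.e. the ~17M model from the intro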
07. Testing the Model with Inference
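The inference script must define exactly the same model architecture and configuration as training (otherwise the weights will not load), so the modules are redefined here, together with temperature and top-k sampling for generation: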
import torch
from torch import nn
import sentencepiece as spm
from typing import Optional


# ==================== Configuration (must match training) ====================
# Use exactly the same configuration as in the training script
class Config:
    BLOCK_SIZE = 512
    # Model dimensions (must match training)
    MODEL_DIM = 384
    N_LAYERS = 5
    NUM_HEADS = 6
    HEAD_DIM = MODEL_DIM // NUM_HEADS
    FFN_DIM = MODEL_DIM * 4
    VOCAB_SIZE = None
    # Inference settings
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    # float32 usually gives the best compatibility and precision for inference
    DTYPE = torch.float32


CFG = Config()


# ==================== RoPE positional encoding (same as training) ====================
class RotaryPositionalEmbedding(nn.Module):
    def __init__(self, head_dim: int, max_seq_len: int = 2048):
        super().__init__()
        self.head_dim = head_dim
        assert head_dim % 2 == 0, "head_dim must be even"
        inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
        self.register_buffer("inv_freq", inv_freq)
        self.max_seq_len = max_seq_len
        self._seq_len_cached = max_seq_len
        self._cos_cached = None
        self._sin_cached = None
        self._update_cos_sin_cache(max_seq_len, device=self.inv_freq.device)

    def _update_cos_sin_cache(self, seq_len: int, device: torch.device):
        if seq_len == self._seq_len_cached and self._cos_cached is not None:
            return
        m = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", m, self.inv_freq)
        emb = torch.cat([freqs, freqs], dim=-1)
        cos = emb.cos()[None, None, :, :]
        sin = emb.sin()[None, None, :, :]
        self._cos_cached = cos
        self._sin_cached = sin
        self._seq_len_cached = seq_len

    def forward(self, seq_len: int, device: Optional[torch.device] = None):
        if device is None:
            device = self.inv_freq.device
        self._update_cos_sin_cache(seq_len, device=device)
        return self._cos_cached.to(device), self._sin_cached.to(device)


def apply_rotary_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    return (x * cos) + (_rotate_half(x) * sin)


def _rotate_half(x: torch.Tensor) -> torch.Tensor:
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=-1)


# ==================== Attention, FFN, Block, Model (same as training) ====================
class FlashAttention(nn.Module):
    def __init__(self, embed_dim: int, num_heads: int, attn_dropout: float = 0.0):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        assert embed_dim % num_heads == 0
        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.qkv = nn.Linear(embed_dim, embed_dim * 3, bias=False)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        # Dropout is unused at inference, but the module must exist to match the structure
        self.attn_dropout = nn.Dropout(attn_dropout)
        self.rope = RotaryPositionalEmbedding(self.head_dim)

    def forward(self, x: torch.Tensor, causal_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        B, T, C = x.shape
        qkv = self.qkv(x)
        qkv = qkv.view(B, T, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.unbind(dim=2)
        q = q.permute(0, 2, 1, 3)  # (B, H, T, D)
        k = k.permute(0, 2, 1, 3)
        v = v.permute(0, 2, 1, 3)

        cos, sin = self.rope(T, device=x.device)
        q = apply_rotary_emb(q, cos, sin)
        k = apply_rotary_emb(k, cos, sin)

        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        # Note: production inference would normally use a KV cache; this recomputes everything
        if T > 1:  # only apply the mask for sequences longer than 1
            causal_mask = torch.tril(torch.ones(T, T, device=x.device, dtype=torch.bool))[None, None, :, :]
            scores = scores.masked_fill(causal_mask == 0, float('-inf'))
        attn = torch.softmax(scores, dim=-1)
        # Dropout disabled at inference
        # attn = self.attn_dropout(attn)
        out = torch.matmul(attn, v)
        out = out.permute(0, 2, 1, 3).contiguous().view(B, T, C)
        return self.out_proj(out)


class FeedForward(nn.Module):
    def __init__(self, dim: int, hidden_dim: int, dropout: float = 0.1):
        super().__init__()
        # Must keep exactly the same nn.Sequential structure as in training
        self.net = nn.Sequential(
            GLU(dim, hidden_dim),
            nn.Dropout(dropout),          # net.1: Dropout (keep as a placeholder)
            nn.Linear(hidden_dim, dim),   # net.2: Linear (same as training)
            nn.Dropout(dropout),          # net.3: Dropout (keep as a placeholder)
        )

    def forward(self, x):
        # model.eval() disables all nn.Dropout layers at inference; the structure is unchanged
        return self.net(x)


# GLU definition (same as training):
class GLU(nn.Module):
    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()
        # GLU contains a single nn.Linear
        self.linear = nn.Linear(in_dim, out_dim * 2)

    def forward(self, x):
        x, gates = self.linear(x).chunk(2, dim=-1)
        return x * torch.nn.functional.silu(gates)


class TransformerBlock(nn.Module):
    def __init__(self, dim: int, num_heads: int, ffn_dim: int, dropout: float = 0.1):
        super().__init__()
        self.ln1 = nn.LayerNorm(dim)
        self.attn = FlashAttention(dim, num_heads, attn_dropout=dropout)
        self.ln2 = nn.LayerNorm(dim)
        self.ff = FeedForward(dim, ffn_dim, dropout)

    def forward(self, x, causal_mask=None):
        x = x + self.attn(self.ln1(x), causal_mask)
        x = x + self.ff(self.ln2(x))
        return x


class GPTModel(nn.Module):
    def __init__(self, vocab_size: int, block_size: int, dim: int = CFG.MODEL_DIM,
                 num_layers: int = CFG.N_LAYERS, num_heads: int = CFG.NUM_HEADS,
                 ffn_dim: int = CFG.FFN_DIM, dropout: float = 0.0,  # dropout=0 at inference
                 tie_weights: bool = True):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, dim)
        self.dropout = nn.Dropout(dropout)
        self.blocks = nn.ModuleList([
            TransformerBlock(dim, num_heads, ffn_dim, dropout) for _ in range(num_layers)
        ])
        self.ln_final = nn.LayerNorm(dim)
        self.lm_head = nn.Linear(dim, vocab_size, bias=False)
        if tie_weights:
            self.lm_head.weight = self.token_emb.weight
        self.block_size = block_size

    def forward(self, idx):
        B, T = idx.shape
        token_emb = self.token_emb(idx)
        x = token_emb  # no dropout at inference
        causal_mask = None  # the attention module handles the causal mask internally
        for block in self.blocks:
            x = block(x, causal_mask)
        x = self.ln_final(x)
        logits = self.lm_head(x)
        return logits


# ==================== Inference and generation ====================
@torch.no_grad()
def generate_text(model: GPTModel, sp: spm.SentencePieceProcessor,
                  prompt: str, max_new_tokens: int, temperature: float = 0.8,
                  top_k: int = 50):
    model.eval()
    device = CFG.DEVICE

    # 1. Encode the input
    input_ids = sp.encode_as_ids(prompt)
    if not input_ids:
        return "Unable to encode the input."
    # Shape the input as the model expects: (B, T)
    x = torch.tensor(input_ids, dtype=torch.long, device=device).unsqueeze(0)

    # 2. Generation loop
    for _ in range(max_new_tokens):
        # Crop the input to fit the model's BLOCK_SIZE
        # (real deployments would use a KV cache; this does a full forward pass)
        idx_cond = x if x.size(1) <= CFG.BLOCK_SIZE else x[:, -CFG.BLOCK_SIZE:]

        # Get logits and keep only the last time step
        logits = model(idx_cond)
        logits = logits[:, -1, :]
        # Apply temperature scaling
        logits = logits / temperature

        # 3. Top-K sampling
        if top_k is not None:
            v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
            logits[logits < v[:, [-1]]] = float('-inf')

        # Compute probabilities and sample
        probs = torch.softmax(logits, dim=-1)
        idx_next = torch.multinomial(probs, num_samples=1)

        # 4. Stop condition: the tokenizer's EOS token was generated
        if idx_next.item() == sp.eos_id():
            break

        # Append the new token to the sequence
        x = torch.cat((x, idx_next), dim=1)

        # Guard against exceeding the maximum sequence length
        if x.size(1) >= CFG.BLOCK_SIZE + max_new_tokens:
            break

    # 5. Decode the output: only the newly generated tokens
    output_ids = x[0].tolist()
    start_index = len(input_ids)
    return sp.decode_ids(output_ids[start_index:])


# ==================== Entry point ====================
def main_infer(sp_model_path: str, model_weights_path: str):
    print("="*50)
    print("GPT model inference mode")
    print(f"Device: {CFG.DEVICE}, DTYPE: {CFG.DTYPE}")
    print("="*50)

    # 1. Load the tokenizer
    try:
        sp = spm.SentencePieceProcessor(model_file=sp_model_path)
        CFG.VOCAB_SIZE = sp.get_piece_size()
        print(f"Tokenizer loaded, VOCAB_SIZE: {CFG.VOCAB_SIZE}")
    except Exception as e:
        print(f"Failed to load tokenizer model {sp_model_path}: {e}")
        return

    # 2. Instantiate the model
    model = GPTModel(
        vocab_size=CFG.VOCAB_SIZE,
        block_size=CFG.BLOCK_SIZE,
        dim=CFG.MODEL_DIM,
        num_layers=CFG.N_LAYERS,
        num_heads=CFG.NUM_HEADS,
        ffn_dim=CFG.FFN_DIM,
        dropout=0.0  # dropout off for inference
    ).to(CFG.DEVICE).to(CFG.DTYPE)

    # 3. Load the weights
    try:
        weights = torch.load(model_weights_path, map_location=CFG.DEVICE)
        # Unwrap DDP / torch.compile state dicts if needed
        if any(k.startswith('_orig_mod.') for k in weights.keys()):
            weights = {k.replace('_orig_mod.', ''): v for k, v in weights.items()}
        model.load_state_dict(weights, strict=True)
        print(f"Model weights loaded: {model_weights_path}")
    except Exception as e:
        print(f"Failed to load or match model weights: {e}")
        # On failure, printing the expected vs. loaded keys helps debugging:
        # print("\n--- Expected model keys (sample) ---")
        # print(list(model.state_dict().keys())[:5])
        # print("\n--- Loaded weight keys (sample) ---")
        # print(list(weights.keys())[:5])
        return

    # 4. Interactive loop
    print("\n--- Interactive mode ---")
    print("Type 'exit' or 'quit' to leave.")
    print("Type 'config' to view/update generation parameters.")
    print("----------------------")

    max_tokens = 100
    temperature = 0.8
    top_k = 50

    while True:
        try:
            prompt = input(">>> Prompt: ")
            if prompt.lower() in ['exit', 'quit']:
                break
            if prompt.lower() == 'config':
                print(f"  Max Tokens: {max_tokens}, Temp: {temperature}, Top K: {top_k}")
                new_max = input("  Set Max Tokens (Enter to skip): ")
                new_temp = input("  Set Temperature (Enter to skip): ")
                new_k = input("  Set Top K (Enter to skip): ")
                if new_max: max_tokens = int(new_max)
                if new_temp: temperature = float(new_temp)
                if new_k: top_k = int(new_k)
                continue
            if not prompt.strip():
                continue

            print("Generating...")
            # Run generation
            output = generate_text(model, sp, prompt, max_tokens, temperature, top_k)
            print(f"--- Model reply ---\n{output.strip()}")
            print("----------------")
        except KeyboardInterrupt:
            print("\nExiting...")
            break
        except Exception as e:
            print(f"Error: {e}")


if __name__ == "__main__":
    import sys
    if len(sys.argv) != 3:
        print("Usage: python infer.py <spm_model_path> <model_weights_path>")
        # Example (adjust to your actual file paths):
        # python infer.py tokenizer.model final_model.pth
        sys.exit(1)
    sp_model_path = sys.argv[1]
    model_weights_path = sys.argv[2]
    main_infer(sp_model_path, model_weights_path)
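With that, you can chat with the model interactively. One possible invocation, assuming the tokenizer and weight paths from the earlier steps (adjust the script name and paths to your actual files):

python infer.py workdir/spm_wiki_16k.model models/gpt_wiki.pth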
We can see the model does roughly continue our input by predicting the next tokens; because both the parameter count and the number of training steps are tiny, the output is still a jumble!
Wrap-up
In this post we walked through data preparation, data cleaning, tokenizer training, model training, and inference. Run the code step by step and you'll end up with a small 17M-parameter model. Next time we'll scale up the parameters, train further, and then move on to supervised fine-tuning.