扫码加入

  • 正文
  • 相关推荐
申请入驻 产业图谱

从零实现一个17M参数的GPT预训练模型

2025/10/13
1557
加入交流群
扫码加入
获取工程师必备礼包
参与热点资讯讨论

大家好,我是写代码的中年人!

今天我们使用开源的的中文数据进行模型的预训练,下面跟着我的步骤,从零实现你的预训练模型。

本文所有代码和数据资源位置:

https://github.com/ColinAIAPP/MoiraiLM

01、预训练模型的概念

预训练模型(Pretrained Model)就是一个已经在海量数据上训练过的模型,它学会了语言的基本规律、结构和语义,然后可以拿来做各种下游任务,比如写作、翻译、问答、分类、生成代码等。

那“预训练”到底在学什么?以语言模型(LLM)为例:预训练阶段的任务通常是预测下一个词(token)。

接下来我们就一步一步实现一个17M参数的预训练模型。

02、数据准备

构建语言模型的第一要义是高质量的数据源。对于中文任务,选择维基百科开源中文数据集是一个理想起点。这个数据集包含数百万条中文百科条目,涵盖历史、文化、科技等领域,总量约数GB的纯文本数据。它开源且免费,可通过维基百科的官方转储页面下载最新版本的XML格式文件。

要解压处理这个文件我们要使用wikiextractor工具进行数据解压。安装解压命令:pip install wikiextractor解压命令:

python -m wikiextractor.WikiExtractor -b 1G -o extracted_wiki_zh zhwiki-20250920-pages-articles-multistream.xml.bz2 --json

zhwiki-20250920-pages-articles-multistream.xml.bz2:为文件名

INFO: Preprocessing 'zhwiki-20250920-pages-articles-multistream.xml.bz2' to collect template definitions: this may take some time.INFO: Preprocessed 100000 pagesINFO: Preprocessed 200000 pagesINFO: Preprocessed 300000 pagesINFO: Preprocessed 400000 pagesINFO: Preprocessed 500000 pagesINFO: Preprocessed 600000 pagesINFO: Preprocessed 700000 pagesINFO: Preprocessed 800000 pagesINFO: Preprocessed 900000 pagesINFO: Preprocessed 1000000 pagesINFO: Preprocessed 1100000 pagesINFO: Preprocessed 1200000 pagesINFO: Preprocessed 1300000 pagesINFO: Preprocessed 1400000 pagesINFO: Preprocessed 1500000 pagesINFO: Preprocessed 1600000 pagesINFO: Preprocessed 1700000 pagesINFO: Preprocessed 1800000 pagesINFO: Preprocessed 1900000 pagesINFO: Preprocessed 2000000 pagesINFO: Preprocessed 2100000 pagesINFO: Preprocessed 2200000 pagesINFO: Preprocessed 2300000 pagesINFO: Preprocessed 2400000 pagesINFO: Preprocessed 2500000 pagesINFO: Preprocessed 2600000 pagesINFO: Preprocessed 2700000 pagesINFO: Preprocessed 2800000 pagesINFO: Preprocessed 2900000 pagesINFO: Preprocessed 3000000 pagesINFO: Preprocessed 3100000 pagesINFO: Preprocessed 3200000 pagesINFO: Preprocessed 3300000 pagesINFO: Preprocessed 3400000 pagesINFO: Preprocessed 3500000 pagesINFO: Preprocessed 3600000 pagesINFO: Preprocessed 3700000 pagesINFO: Preprocessed 3800000 pagesINFO: Preprocessed 3900000 pagesINFO: Preprocessed 4000000 pagesINFO: Preprocessed 4100000 pagesINFO: Preprocessed 4200000 pagesINFO: Preprocessed 4300000 pagesINFO: Preprocessed 4400000 pagesINFO: Preprocessed 4500000 pagesINFO: Preprocessed 4600000 pagesINFO: Preprocessed 4700000 pagesINFO: Loaded 1036734 templates in 704.2sINFO: Starting page extraction from zhwiki-20250920-pages-articles-multistream.xml.bz2.INFO: Using 127 extract processes.INFO: Extracted 100000 articles (1209.6 art/s)INFO: Extracted 200000 articles (1947.8 art/s)INFO: Extracted 300000 articles (2325.1 art/s)INFO: Extracted 400000 articles (3471.3 art/s)INFO: Extracted 500000 articles (2551.1 art/s)INFO: Extracted 600000 articles (2239.4 art/s)INFO: Extracted 700000 articles (2299.3 art/s)INFO: Extracted 800000 articles (1525.2 art/s)INFO: Extracted 900000 articles (3256.1 art/s)INFO: Extracted 1000000 articles (3485.9 art/s)INFO: Extracted 1100000 articles (3495.0 art/s)INFO: Extracted 1200000 articles (3330.4 art/s)INFO: Extracted 1300000 articles (3555.6 art/s)INFO: Extracted 1400000 articles (3456.3 art/s)INFO: Extracted 1500000 articles (2476.1 art/s)INFO: Extracted 1600000 articles (2268.6 art/s)INFO: Extracted 1700000 articles (2473.5 art/s)INFO: Extracted 1800000 articles (2305.9 art/s)INFO: Extracted 1900000 articles (2263.9 art/s)INFO: Extracted 2000000 articles (2136.4 art/s)INFO: Extracted 2100000 articles (2363.0 art/s)INFO: Extracted 2200000 articles (2601.9 art/s)INFO: Extracted 2300000 articles (3709.0 art/s)INFO: Extracted 2400000 articles (2723.9 art/s)INFO: Extracted 2500000 articles (2487.1 art/s)INFO: Extracted 2600000 articles (2621.3 art/s)INFO: Extracted 2700000 articles (2525.4 art/s)INFO: Extracted 2800000 articles (2666.4 art/s)INFO: Finished 127-process extraction of 2893023 articles in 1156.5s (2501.5 art/s)

03、清洗数据

我们解压后的数据如下图,下面我们要把数据清洗出来。

注:我们本步骤生成的文件为 data/cleaned_wiki_full.txt

import osimport jsonimport loggingimport argparseimport refrom tqdm import tqdm
# 配置日志记录logging.basicConfig(level=logging.INFO,                     format='%(asctime)s - %(levelname)s - %(message)s')
# python scripts/clean_wiki_text.py data/extracted_wiki_zh data/cleaned_wiki_full.txt --min_line_length 20 --min_article_length 300

def clean_text(text: str) -> str:    """    对文本进行深度清洗。    移除维基百科特有的格式标记、参考文献、HTML标签、日期和数字等。    """    # 移除维基链接 [[link|display]] 或 [[link]]    text = re.sub(r'[[([^]|]+|)?([^]]+)]]', r'2', text)
    # 移除参考文献标记 [1], [2], [ref], 等    text = re.sub(r'[d+]|[ref]|[/ref]|[citation needed]', '', text)
    # 移除HTML标签    text = re.sub(r'<[^>]+>', '', text)
    # 移除日期格式 (yyyy-mm-dd, yyyy/mm/dd, mm/dd/yyyy 等)    text = re.sub(r'd{1,4}[-/]d{1,2}[-/]d{1,4}', '', text)
    # 移除年份 (1000-2999)    text = re.sub(r'b[12]d{3}b', '', text)
    # 移除纯数字(包括小数)    text = re.sub(r'bd+.?d*b', '', text)
    # 移除重复的空白字符(但保留单个空格)    text = re.sub(r' +', ' ', text)
    # 移除行首尾空白    text = text.strip()
    return text

def process_extracted_wiki(extracted_dir: str,                           output_file: str,                           min_line_length: int = 20,                           min_article_length: int = 200):    """    读取WikiExtractor输出的JSON文件,提取、清洗文本并保存到单个文件中。
    参数:        extracted_dir: WikiExtractor输出的目录路径        output_file: 最终合并的纯文本文件路径        min_line_length: 单行文本最小长度,用于过滤噪音(默认: 20)        min_article_length: 文章最小长度,用于过滤短文章(默认: 200)    """    if not os.path.isdir(extracted_dir):        logging.error(f"输入的目录不存在: {extracted_dir}")        return
    total_articles = 0    skipped_articles = 0
    # 第一次遍历:获取所有需要处理的文件列表    file_list = []    for root, dirs, files in os.walk(extracted_dir):        for file_name in files:            # 仅处理 WikiExtractor 生成的以 'wiki_' 开头的文件            if file_name.startswith('wiki_'):                file_list.append(os.path.join(root, file_name))
    total_files = len(file_list)    logging.info(f"找到 {total_files} 个文件等待处理。")
    if total_files == 0:        logging.warning(f"目录 {extracted_dir} 中未找到任何 'wiki_' 文件。请检查路径。")        return
    # 第二次遍历:处理文件并写入输出    with open(output_file, 'w', encoding='utf-8') as f_out:        # 使用 tqdm 包装文件列表,显示处理进度        for file_path in tqdm(file_list, desc=" 正在提取维基文本"):            try:                with open(file_path, 'r', encoding='utf-8') as f_in:                    for line_num, line in enumerate(f_in, 1):                        try:                            article = json.loads(line)                            text_content = article.get('text', '').strip()
                            # --- 文本清洗和过滤 ---
                            # 1. 过滤掉过短的文章,它们通常是噪音或重定向页                            if len(text_content) < min_article_length:                                skipped_articles += 1                                continue
                            # 2. 按行处理文本,过滤短行和额外的空白                            # 保留行结构,而不是将所有行连接成一个长句子                            cleaned_lines = []                            for text_line in text_content.split('n'):                                text_line = clean_text(text_line)                                # 只保留足够长的行                                if len(text_line) >= min_line_length:                                    cleaned_lines.append(text_line)
                            # 使用换行符连接各行,保留段落结构                            final_text = 'n'.join(cleaned_lines)
                            # 最终检查:确保清洗后的文本仍然足够长                            if final_text and len(final_text) >= min_article_length:                                # 文章之间用两个换行符分隔                                f_out.write(final_text + 'nn')                                total_articles += 1                            else:                                skipped_articles += 1
                        except json.JSONDecodeError:                            logging.warning(f"无法解析 JSON,文件: {file_path},行号: {line_num}")                        except Exception as e:                            logging.error(f"处理文件 {file_path} 第 {line_num} 行时出错: {e}")
            except Exception as e:                logging.error(f"打开文件 {file_path} 时出错: {e}")
    logging.info(f"  所有维基百科文本已成功提取并清洗。")    logging.info(f"   总文章数: {total_articles}")    logging.info(f"   跳过文章数: {skipped_articles}")    logging.info(f"   文件已保存到: {output_file}")

def main():    parser = argparse.ArgumentParser(        description="从 WikiExtractor 输出的 JSON 文件中提取并清洗纯文本。",        formatter_class=argparse.RawTextHelpFormatter    )
    # 位置参数 1: 输入目录    parser.add_argument(        "extracted_directory",        type=str,        help="WikiExtractor 输出的目录路径 (e.g., extracted_wiki_zh)"    )
    # 位置参数 2: 输出文件    parser.add_argument(        "output_filename",        type=str,        help="最终合并的纯文本文件路径 (e.g., cleaned_wiki.txt)"    )
    # 可选参数: 最小行长    parser.add_argument(        "--min_line_length",        type=int,        default=20,        help="文章中单行文本必须达到的最小长度,用于过滤噪音。默认值: 20"    )
    # 可选参数: 最小文章长度    parser.add_argument(        "--min_article_length",        type=int,        default=200,        help="文章最小长度,用于过滤短文章和重定向页。默认值: 200"    )
    args = parser.parse_args()
    process_extracted_wiki(        args.extracted_directory,         args.output_filename,         args.min_line_length,        args.min_article_length    )

if __name__ == "__main__":    main()
2025-10-01 11:10:58,772 - INFO - 找到 5 个文件等待处理。正在提取维基文本: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:33<00:00,  6.78s/it]2025-10-01 11:11:32,681 - INFO - 所有维基百科文本已成功提取。总文章数: 628093。文件已保存到 data/cleaned_wiki_full.txt

04、训练分词器

我们使用SentencePiece训练分词器,本次我们训练的分词库大小为16k,你也可以训练32k的分词库。相关代码及过程如下:

注:我们本步骤生成的文件为workdir/spm_wiki_16k.modelworkdir/spm_wiki_16k.vocab

import sysimport sentencepiece as spmimport argparseimport osfrom tqdm import tqdm
# python scripts/train_tokenizer.py data/cleaned_wiki_full.txt workdir/spm_wiki 32000

def get_corpus_size(input_file: str) -> int:    """计算语料的总行数和文件大小"""    try:        file_size_bytes = os.path.getsize(input_file)        file_size_mb = file_size_bytes / (1024 * 1024)        print(f"语料文件大小: {file_size_mb:.2f} MB")
        # 计算行数和总字符数        line_count = 0        total_chars = 0        with open(input_file, 'r', encoding='utf-8') as f:            for line in tqdm(f, desc="统计语料信息"):                line_count += 1                total_chars += len(line)
        print(f"语料总行数 (文章数): {line_count}")        print(f"总字符数: {total_chars:,}")        print(f"平均每行字符数: {total_chars / line_count:.1f}")        return file_size_bytes    except Exception as e:        print(f"警告:无法计算文件大小或行数:{e}")        return 0

def train_spm_model(input_file: str,                     model_prefix: str,                     vocab_size: int,                    model_type: str = 'bpe',                    character_coverage: float = 0.9995):    """    训练一个SentencePiece分词器模型。
    参数:        input_file: 训练语料文件路径        model_prefix: 输出模型文件的前缀        vocab_size: 词汇表大小        model_type: 分词算法类型 ('bpe', 'unigram', 'char', 'word')        character_coverage: 字符覆盖率 (0-1,通常 0.995-0.9995)    """    if not os.path.exists(input_file):        print(f"错误:输入语料文件未找到:{input_file}")        sys.exit(1)
    # 确保输出目录存在    output_dir = os.path.dirname(model_prefix)    if output_dir and not os.path.exists(output_dir):        os.makedirs(output_dir, exist_ok=True)        print(f"已创建输出目录: {output_dir}")
    # 打印语料规模信息    print("n=== 语料分析 ===")    get_corpus_size(input_file)
    # 构建训练参数    # 对于 1.5GB 语料,建议启用 train_extremely_large_corpus=True 加速    train_params = {        'input': input_file,        'model_prefix': model_prefix,        'vocab_size': vocab_size,        'model_type': model_type,        'character_coverage': character_coverage,        'num_threads': 32,        # 增加到32(最大化CPU利用)        'bos_id': 0,        'eos_id': 1,        'unk_id': 2,        'pad_id': -1,        'normalization_rule_name': 'identity',        'input_sentence_size': 2000000, # 5000000,         # 增加到500万句子采样        'train_extremely_large_corpus': True,   # 必须启用        'shuffle_input_sentence': True,        'seed_sentencepiece_size': 2000000,     # 添加种子句子大小        'hard_vocab_limit': False,              # 允许超过目标词汇量以获得更好质量    }
    print("n=== SentencePiece 训练参数 ===")    for key, value in train_params.items():        print(f"  {key}: {value}")    print("=" * 35)
    print("n正在训练 SentencePiece 模型...")    print("   (请稍候,进度由 SentencePiece 输出)n")
    try:        # 执行训练        spm.SentencePieceTrainer.train(**train_params)
        print("n分词器模型训练完成!")        print(f"   模型文件: {model_prefix}.model")        print(f"   词汇表文件: {model_prefix}.vocab")
        # 验证模型是否成功创建        if os.path.exists(f"{model_prefix}.model") and os.path.exists(f"{model_prefix}.vocab"):            model_size_kb = os.path.getsize(f"{model_prefix}.model") / 1024            print(f"n模型文件大小: {model_size_kb:.2f} KB")
            # 加载模型进行快速测试            print("n进行快速测试...")            sp = spm.SentencePieceProcessor(model_file=f"{model_prefix}.model")
            test_text = "这是一个分词测试句子。"            tokens = sp.encode(test_text, out_type=str)            ids = sp.encode(test_text, out_type=int)
            print(f" 测试文本: {test_text}")            print(f" 分词结果: {tokens}")            print(f" Token IDs: {ids}")        else:            print("n警告:模型文件生成失败,请检查输入数据或参数")
    except Exception as e:        print(f"n训练过程出错: {e}")        sys.exit(1)

def main():    parser = argparse.ArgumentParser(        description="使用 SentencePiece 训练分词器模型。",        formatter_class=argparse.RawTextHelpFormatter    )
    parser.add_argument(        "input_file",        type=str,        help="训练语料的路径 (e.g., data/cleaned_wiki_full.txt)"    )
    parser.add_argument(        "model_prefix",        type=str,        help="训练模型文件的输出前缀 (e.g., workdir/spm_wiki)"    )
    parser.add_argument(        "vocab_size",        type=int,        help="词汇表大小 (e.g., 32000)"    )
    parser.add_argument(        "--model_type",        type=str,        default='bpe',        choices=['bpe', 'unigram', 'char', 'word'],        help="分词算法类型 (默认: bpe)"    )
    parser.add_argument(        "--character_coverage",        type=float,        default=0.9995,        help="字符覆盖率,范围 [0-1]。对于小词表(8K),建议用0.99或更小"    )
    args = parser.parse_args()
    print("n" + "="*50)    print("SentencePiece 分词器训练程序")    print("="*50)    print(f"输入语料: {args.input_file}")    print(f"输出模型前缀: {args.model_prefix}")    print(f"词汇表大小: {args.vocab_size}")    print(f"分词算法: {args.model_type}")    print(f"字符覆盖率: {args.character_coverage}")    print("="*50 + "n")
    train_spm_model(        args.input_file,         args.model_prefix,         args.vocab_size,        args.model_type,        args.character_coverage    )

if __name__ == "__main__":    main()
 开始训练SentencePiece分词器...   输入语料: data/cleaned_wiki_full.txt   输出模型前缀: workdir/spm_wiki_16k   词汇表大小: 16000   语料文件大小: 1697.54 MBCounting lines: 1256186it [00:05, 230354.42it/s]   语料总行数 (文章数): 1256186
--- SentencePiece 训练参数 -----input=data/cleaned_wiki_full.txt--model_prefix=workdir/spm_wiki_16k--vocab_size=16000--model_type=bpe--character_coverage=0.9995--num_threads=16--bos_id=0--eos_id=1--unk_id=2--pad_id=-1 ------------------------------
⏳ 正在启动训练... 请注意观察 SentencePiece 自身的进度输出。sentencepiece_trainer.cc(178) LOG(INFO) Running command: --input=data/cleaned_wiki_full.txt --model_prefix=workdir/spm_colinai_16000 --vocab_size=16000 --model_type=bpe --character_coverage=0.9995 --num_threads=16 --bos_id=0 --eos_id=1 --unk_id=2 --pad_id=-1 sentencepiece_trainer.cc(78) LOG(INFO) Starts training with : trainer_spec {  input: data/cleaned_wiki_full.txt  input_format:   model_prefix: workdir/spm_colinai_16000  model_type: BPE  vocab_size: 16000  self_test_sample_size: 0  character_coverage: 0.9995  input_sentence_size: 0  shuffle_input_sentence: 1  seed_sentencepiece_size: 1000000  shrinking_factor: 0.75  max_sentence_length: 4192  num_threads: 16  num_sub_iterations: 2  max_sentencepiece_length: 16  split_by_unicode_script: 1  split_by_number: 1  split_by_whitespace: 1  split_digits: 0  pretokenization_delimiter:   treat_whitespace_as_suffix: 0  allow_whitespace_only_pieces: 0  required_chars:   byte_fallback: 0  vocabulary_output_piece_score: 1  train_extremely_large_corpus: 0  seed_sentencepieces_file:   hard_vocab_limit: 1  use_all_vocab: 0  unk_id: 2  bos_id: 0  eos_id: 1  pad_id: -1  unk_piece: <unk>  bos_piece: <s>  eos_piece: </s>  pad_piece: <pad>  unk_surface:  ⁇   enable_differential_privacy: 0  differential_privacy_noise_level: 0  differential_privacy_clipping_threshold: 0}normalizer_spec {  name: nmt_nfkc  add_dummy_prefix: 1  remove_extra_whitespaces: 1  escape_whitespaces: 1  normalization_rule_tsv: }denormalizer_spec {}trainer_interface.cc(355) LOG(INFO) SentenceIterator is not specified. Using MultiFileSentenceIterator.trainer_interface.cc(186) LOG(INFO) Loading corpus: data/cleaned_wiki_full.txttrainer_interface.cc(382) LOG(WARNING) Found too long line (18615 > 4192).trainer_interface.cc(384) LOG(WARNING) Too long lines are skipped in the training.trainer_interface.cc(385) LOG(WARNING) The maximum length can be changed with --max_sentence_length=<size> flag.trainer_interface.cc(411) LOG(INFO) Loaded all 528882 sentencestrainer_interface.cc(418) LOG(INFO) Skipped 99211 too long sentences.trainer_interface.cc(427) LOG(INFO) Adding meta_piece: <s>trainer_interface.cc(427) LOG(INFO) Adding meta_piece: </s>trainer_interface.cc(427) LOG(INFO) Adding meta_piece: <unk>trainer_interface.cc(432) LOG(INFO) Normalizing sentences...trainer_interface.cc(541) LOG(INFO) all chars count=281809036trainer_interface.cc(552) LOG(INFO) Done: 99.95% characters are covered.trainer_interface.cc(562) LOG(INFO) Alphabet size=8686trainer_interface.cc(563) LOG(INFO) Final character coverage=0.9995trainer_interface.cc(594) LOG(INFO) Done! preprocessed 528882 sentences.trainer_interface.cc(600) LOG(INFO) Tokenizing input sentences with whitespace: 528882trainer_interface.cc(611) LOG(INFO) Done! 3885388.....

05、原始文本转为Token ID 序列

在训练大型语言模型的准备阶段,将海量文本语料转化为模型可处理的数字格式至关重要。本次将原始文本语料编码为整数 Token ID 序列。为了克服单次加载大文件的内存限制,脚本采用了分块读取机制,支持以自定义大小逐块处理语料。所有 Token ID 最终被汇总并转化为高效率的 torch.int32 PyTorch 张量,直接存储为 .pt 文件。这不仅优化了数据格式,方便后续 PyTorch DataLoader 快速读取,同时也提供了关键的统计信息和完整性验证,是构建 LLM 数据集的稳定且高性能的预处理方案。

import sysimport torchimport sentencepiece as spmimport argparsefrom tqdm import tqdmimport osimport numpy as np
# python scripts/preprocess_data.py workdir/spm_wiki.model data/cleaned_wiki_full.txt workdir/wiki_tokens.pt

def preprocess(sp_model_path: str,                corpus_path: str,                output_path: str,               chunk_size_mb: int = 50):    """    分块读取语料,编码为 Token ID,并保存为 PyTorch 文件。
    参数:        sp_model_path: SentencePiece 模型文件路径        corpus_path: 输入语料文件路径        output_path: 输出 .pt 文件路径        chunk_size_mb: 每次处理的文本大小(MB),默认 50MB    """    # 验证文件存在    if not os.path.exists(sp_model_path):        print(f"错误:分词器模型文件未找到: {sp_model_path}")        sys.exit(1)
    if not os.path.exists(corpus_path):        print(f"错误:语料文件未找到: {corpus_path}")        sys.exit(1)
    # 加载分词器    try:        sp = spm.SentencePieceProcessor(model_file=sp_model_path)        vocab_size = sp.get_piece_size()        print(f"   分词器加载成功")        print(f"   词汇表大小: {vocab_size}")        print(f"   特殊 Token: BOS={sp.bos_id()}, EOS={sp.eos_id()}, UNK={sp.unk_id()}, PAD={sp.pad_id()}")    except Exception as e:        print(f"加载分词器失败: {e}")        sys.exit(1)
    # 确保输出目录存在    output_dir = os.path.dirname(output_path)    if output_dir and not os.path.exists(output_dir):        os.makedirs(output_dir, exist_ok=True)
    print(f"n 开始处理语料...")    print(f"   输入文件: {corpus_path}")    print(f"   输出文件: {output_path}")    print(f"   块大小: {chunk_size_mb} MBn")
    # 计算总大小用于进度条    total_bytes = os.path.getsize(corpus_path)    chunk_size_bytes = chunk_size_mb * 1024 * 1024
    token_ids = []    tokens_processed = 0    chunks_processed = 0
    try:        with open(corpus_path, 'r', encoding='utf-8') as f:            with tqdm(total=total_bytes, unit='B', unit_scale=True, desc="⏳ 编码语料") as pbar:
                while True:                    chunk = f.read(chunk_size_bytes)                    if not chunk:                        break
                    # 直接编码(cleaned_wiki_full.txt 已经过清洗)                    ids = sp.encode(chunk, out_type=int)                    token_ids.extend(ids)
                    # 更新进度条                    bytes_read = len(chunk.encode('utf-8'))                    pbar.update(bytes_read)
                    tokens_processed += len(ids)                    chunks_processed += 1
                    # 定期显示进度信息                    if chunks_processed % 10 == 0:                        pbar.set_postfix({                            'chunks': chunks_processed,                            'tokens': f'{tokens_processed:,}'                        })
        print(f"n 编码完成")        print(f"   处理块数: {chunks_processed}")        print(f"   总 Token 数: {tokens_processed:,}")
        # 转换为 PyTorch 张量        print(f"n转换为张量并保存...")        final_tensor = torch.tensor(token_ids, dtype=torch.int32)
        print(f"   张量形状: {final_tensor.shape}")        print(f"   张量大小: {final_tensor.numel():,}")        print(f"   数据类型: {final_tensor.dtype}")        print(f"   占用内存: {final_tensor.numel() * 4 / (1024**3):.2f} GB")
        # 验证 Token ID 范围        min_id = final_tensor.min().item()        max_id = final_tensor.max().item()        print(f"   Token ID 范围: [{min_id}, {max_id}]")
        if max_id >= vocab_size or min_id < 0:            print(f"   警告: 检测到越界 Token ID!")            print(f"   词汇表大小: {vocab_size}")            print(f"   最大 ID: {max_id}")
        # 保存张量        torch.save(final_tensor, output_path)        file_size_mb = os.path.getsize(output_path) / (1024 ** 2)        print(f"nToken ID 已保存到 {output_path}")        print(f"   文件大小: {file_size_mb:.2f} MB")
        # 验证保存的文件        print(f"n验证保存的文件...")        loaded_tensor = torch.load(output_path)        print(f"   加载成功,形状: {loaded_tensor.shape}")        print(f"   是否相同: {torch.equal(final_tensor, loaded_tensor)}")
        print(f"n✨ 预处理完成!")
    except Exception as e:        print(f"n处理过程中出错: {e}")        import traceback        traceback.print_exc()        sys.exit(1)

def main():    parser = argparse.ArgumentParser(        description="将清洗后的文本语料转换为 Token ID 二进制文件。",        formatter_class=argparse.RawTextHelpFormatter    )
    parser.add_argument(        "model_path",        type=str,        help="SentencePiece 模型文件路径 (e.g., workdir/spm_wiki.model)"    )
    parser.add_argument(        "corpus_path",        type=str,        help="输入语料文件路径 (e.g., data/cleaned_wiki_full.txt)"    )
    parser.add_argument(        "output_path",        type=str,        help="输出 Token ID 文件路径 (e.g., workdir/wiki_tokens.pt)"    )
    parser.add_argument(        "--chunk_size",        type=int,        default=50,        help="每次处理的文本大小(MB),默认 50MB。更大的块更快,但占用更多内存。"    )
    args = parser.parse_args()
    print("n" + "="*60)    print("数据预处理程序 - 文本到 Token ID")    print("="*60)    print(f"SentencePiece 模型: {args.model_path}")    print(f"输入语料: {args.corpus_path}")    print(f"输出文件: {args.output_path}")    print(f"块大小: {args.chunk_size} MB")    print("="*60 + "n")
    preprocess(        args.model_path,        args.corpus_path,        args.output_path,        args.chunk_size    )

if __name__ == "__main__":    main()

06、进行模型预训练

"""GPT 高性能训练脚本"""
from __future__ import annotationsimport sysimport osimport mathimport jsonfrom datetime import datetimefrom typing import Optional
import torchfrom torch import nnfrom torch.utils.data import Dataset, DataLoaderimport sentencepiece as spmfrom tqdm import tqdm
# ==================== 配置参数 ====================class Config:    BLOCK_SIZE = 512 #256    BATCH_SIZE = 32 #64    GRAD_ACCUM_STEPS = 4 #1    MODEL_DIM = 384 #256    N_LAYERS = 5 #2    NUM_HEADS = 6 #4    HEAD_DIM = MODEL_DIM // NUM_HEADS    FFN_DIM = MODEL_DIM * 4    VOCAB_SIZE = None
    EPOCHS = 1    MAX_STEPS = 10000 # 此处根据自己的硬件和时间定义步数    WARMUP_STEPS = 500    LR = 1e-4    MIN_LR = 1e-5    WEIGHT_DECAY = 0.01    GRAD_CLIP = 1.0    DROPOUT = 0.1
    CHECKPOINT_EVERY = 5000    LOG_EVERY = 100
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"    CHECKPOINT_DIR = "./checkpoints"    LATEST_CHECKPOINT = "latest_checkpoint.pth"
    NUM_WORKERS = 8    SEED = 42
    # 启用 bfloat16 (推荐用于现代 GPU)    DTYPE = torch.bfloat16 if (torch.cuda.is_available() and torch.cuda.is_bf16_supported()) else torch.float16
CFG = Config()
if CFG.DEVICE == 'cuda':    torch.backends.cuda.matmul.allow_tf32 = True    torch.backends.cudnn.allow_tf32 = True    torch.cuda.empty_cache()    # 检查是否使用了 bfloat16    if CFG.DTYPE == torch.bfloat16:        print("使用 bfloat16 混合精度 (推荐)")    else:        print("使用 float16 混合精度")
# ==================== 工具函数 ====================def print_gpu_memory():    if torch.cuda.is_available():        allocated = torch.cuda.memory_allocated() / (1024**3)        reserved = torch.cuda.memory_reserved() / (1024**3)        print(f"GPU显存: {allocated:.2f}GB / {reserved:.2f}GB")
def set_seed(seed: int):    torch.manual_seed(seed)    if torch.cuda.is_available():        torch.cuda.manual_seed_all(seed)
set_seed(CFG.SEED)
# ==================== 数据集 ====================class TextDataset(Dataset):    def __init__(self, token_ids: torch.Tensor, block_size: int):        self.ids = token_ids.long()        self.block_size = block_size
    def __len__(self):        return max(0, self.ids.size(0) - self.block_size)
    def __getitem__(self, idx):        x = self.ids[idx: idx + self.block_size]        y = self.ids[idx + 1: idx + 1 + self.block_size]        return x, y
# ==================== RoPE 位置编码  ====================class RotaryPositionalEmbedding(nn.Module):    """RoPE 实现"""    def __init__(self, head_dim: int, max_seq_len: int = 2048):        super().__init__()        self.head_dim = head_dim        assert head_dim % 2 == 0, "head_dim must be even"
        # 基频:theta_i = 10000^(-2i/d)        inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))        self.register_buffer("inv_freq", inv_freq)
        self.max_seq_len = max_seq_len        self._seq_len_cached = max_seq_len        self._cos_cached = None        self._sin_cached = None        self._update_cos_sin_cache(max_seq_len, device=self.inv_freq.device)
    def _update_cos_sin_cache(self, seq_len: int, device: torch.device):        if seq_len == self._seq_len_cached and self._cos_cached is not None:            return
        # m: (seq_len,), theta_i: (head_dim//2,)        m = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)        freqs = torch.einsum("i,j->ij", m, self.inv_freq)  # (seq_len, head_dim//2)
        # 构造完整的旋转矩阵(每个复数对重复)        emb = torch.cat([freqs, freqs], dim=-1)  # (seq_len, head_dim)
        cos = emb.cos()[None, None, :, :]  # (1, 1, seq_len, head_dim)        sin = emb.sin()[None, None, :, :]  # (1, 1, seq_len, head_dim)
        self._cos_cached = cos        self._sin_cached = sin        self._seq_len_cached = seq_len
    def forward(self, seq_len: int, device: Optional[torch.device] = None):        if device is None:            device = self.inv_freq.device        self._update_cos_sin_cache(seq_len, device=device)        return self._cos_cached.to(device), self._sin_cached.to(device)
def apply_rotary_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:    """应用RoPE旋转"""    # x: (B, H, T, D), cos/sin: (1, 1, T, D)    # 使用(x, y) -> (x*cos-y*sin, x*sin+y*cos)    return (x * cos) + (_rotate_half(x) * sin)
def _rotate_half(x: torch.Tensor) -> torch.Tensor:    """将向量旋转90度"""    x1 = x[..., : x.shape[-1] // 2]    x2 = x[..., x.shape[-1] // 2 :]    return torch.cat((-x2, x1), dim=-1)
# ==================== Flash Attention ====================class FlashAttention(nn.Module):    def __init__(self, embed_dim: int, num_heads: int, attn_dropout: float = 0.0):        super().__init__()        self.embed_dim = embed_dim        self.num_heads = num_heads        assert embed_dim % num_heads == 0        self.head_dim = embed_dim // num_heads        self.scale = self.head_dim ** -0.5
        self.qkv = nn.Linear(embed_dim, embed_dim * 3, bias=False)        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)        self.attn_dropout = nn.Dropout(attn_dropout)        self.rope = RotaryPositionalEmbedding(self.head_dim)
    def forward(self, x: torch.Tensor, causal_mask: Optional[torch.Tensor] = None) -> torch.Tensor:        B, T, C = x.shape        assert T <= self.rope.max_seq_len, f"Seq len {T} exceeds max {self.rope.max_seq_len}"
        qkv = self.qkv(x)        qkv = qkv.view(B, T, 3, self.num_heads, self.head_dim)        q, k, v = qkv.unbind(dim=2)        q = q.permute(0, 2, 1, 3)  # (B, H, T, D)        k = k.permute(0, 2, 1, 3)        v = v.permute(0, 2, 1, 3)
        # 应用RoPE        cos, sin = self.rope(T, device=x.device)        q = apply_rotary_emb(q, cos, sin)        k = apply_rotary_emb(k, cos, sin)
        # 注意力计算        # 注意:这里如果使用 torch.nn.functional.scaled_dot_product_attention 配合 torch.compile 会更快        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale        if causal_mask is not None:            scores = scores.masked_fill(causal_mask == 0, float('-inf'))        attn = torch.softmax(scores, dim=-1)        attn = self.attn_dropout(attn)        out = torch.matmul(attn, v)        out = out.permute(0, 2, 1, 3).contiguous().view(B, T, C)        return self.out_proj(out)
# ==================== 前馈网络 ====================class GLU(nn.Module):    def __init__(self, in_dim: int, out_dim: int):        super().__init__()        self.linear = nn.Linear(in_dim, out_dim * 2)
    def forward(self, x):        x, gates = self.linear(x).chunk(2, dim=-1)        return x * torch.nn.functional.silu(gates)
class FeedForward(nn.Module):    def __init__(self, dim: int, hidden_dim: int, dropout: float = 0.1):        super().__init__()        self.net = nn.Sequential(            GLU(dim, hidden_dim),            nn.Dropout(dropout),            nn.Linear(hidden_dim, dim),            nn.Dropout(dropout),        )
    def forward(self, x):        return self.net(x)
# ==================== Transformer Block ====================class TransformerBlock(nn.Module):    def __init__(self, dim: int, num_heads: int, ffn_dim: int, dropout: float = 0.1):        super().__init__()        self.ln1 = nn.LayerNorm(dim)        self.attn = FlashAttention(dim, num_heads, attn_dropout=dropout)        self.ln2 = nn.LayerNorm(dim)        self.ff = FeedForward(dim, ffn_dim, dropout)
    def forward(self, x, causal_mask=None):        x = x + self.attn(self.ln1(x), causal_mask)        x = x + self.ff(self.ln2(x))        return x
# ==================== GPT 模型(已移除 pos_emb) ====================class GPTModel(nn.Module):    def __init__(self, vocab_size: int, block_size: int, dim: int = CFG.MODEL_DIM,                 num_layers: int = CFG.N_LAYERS, num_heads: int = CFG.NUM_HEADS,                 ffn_dim: int = CFG.FFN_DIM, dropout: float = CFG.DROPOUT,                 tie_weights: bool = True):        super().__init__()        self.token_emb = nn.Embedding(vocab_size, dim)        # self.pos_emb = nn.Embedding(block_size, dim) # 移除:与 RoPE 冲突        self.dropout = nn.Dropout(dropout)
        self.blocks = nn.ModuleList([            TransformerBlock(dim, num_heads, ffn_dim, dropout) for _ in range(num_layers)        ])
        self.ln_final = nn.LayerNorm(dim)        self.lm_head = nn.Linear(dim, vocab_size, bias=False)
        if tie_weights:            self.lm_head.weight = self.token_emb.weight
        self.block_size = block_size        self.apply(self._init_weights)
        n_params = sum(p.numel() for p in self.parameters())        print(f"模型参数: {n_params/1e6:.2f}M")
    def _init_weights(self, module):        if isinstance(module, nn.Linear):            nn.init.normal_(module.weight, mean=0.0, std=0.02)            if module.bias is not None:                nn.init.zeros_(module.bias)        elif isinstance(module, nn.Embedding):            nn.init.normal_(module.weight, mean=0.0, std=0.02)        elif isinstance(module, nn.LayerNorm):            nn.init.ones_(module.weight)            nn.init.zeros_(module.bias)
    def forward(self, idx):        B, T = idx.shape        assert T <= self.block_size, f"Seq len {T} exceeds block_size {self.block_size}"
        token_emb = self.token_emb(idx)
               x = self.dropout(token_emb) # token embedding
        causal_mask = torch.tril(torch.ones(T, T, device=idx.device, dtype=torch.bool))[None, None, :, :]        for block in self.blocks:            x = block(x, causal_mask)        x = self.ln_final(x)        logits = self.lm_head(x)        return logits
# ==================== 检查点管理 ====================def save_checkpoint(model, optimizer, scaler, lr_scheduler, step: int, loss: float, config_dict: dict):    os.makedirs(CFG.CHECKPOINT_DIR, exist_ok=True)    checkpoint_path = os.path.join(CFG.CHECKPOINT_DIR, CFG.LATEST_CHECKPOINT)    state = {        'step': step,        'loss': loss,        'model_state_dict': model.state_dict(),        'optimizer_state_dict': optimizer.state_dict(),        'config': config_dict,        'torch_rng_state': torch.get_rng_state(),        'cuda_rng_state': torch.cuda.get_rng_state() if torch.cuda.is_available() else None,    }
    if scaler is not None and hasattr(scaler, "state_dict"):        state['scaler_state_dict'] = scaler.state_dict()
    if lr_scheduler is not None:        state['lr_scheduler_state_dict'] = {            'current_step': lr_scheduler.current_step,            'warmup_steps': lr_scheduler.warmup_steps,            'total_steps': lr_scheduler.total_steps,            'base_lr': lr_scheduler.base_lr,            'min_lr': lr_scheduler.min_lr,        }
    torch.save(state, checkpoint_path)
    try:        with open(os.path.join(CFG.CHECKPOINT_DIR, "config.json"), "w", encoding="utf-8") as f:            json.dump(config_dict, f, indent=2)    except Exception:        pass
    print(f" 检查点已保存: {checkpoint_path} (step {step}, loss {loss:.4f})")
def load_checkpoint(checkpoint_path: str, model, optimizer, scaler, lr_scheduler):    if not os.path.exists(checkpoint_path):        return None
    checkpoint = torch.load(checkpoint_path, map_location=CFG.DEVICE)    model.load_state_dict(checkpoint['model_state_dict'])    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    if checkpoint.get('scaler_state_dict') is not None and scaler is not None:        try:            scaler.load_state_dict(checkpoint['scaler_state_dict'])        except Exception as e:            print(f"无法恢复scaler: {e}")
    if checkpoint.get('lr_scheduler_state_dict') is not None and lr_scheduler is not None:        try:            sched_state = checkpoint['lr_scheduler_state_dict']            lr_scheduler.current_step = sched_state['current_step']            lr_scheduler.warmup_steps = sched_state['warmup_steps']            lr_scheduler.total_steps = sched_state['total_steps']            lr_scheduler.base_lr = sched_state['base_lr']            lr_scheduler.min_lr = sched_state['min_lr']        except Exception as e:            print(f"无法恢复lr_scheduler: {e}")
    torch.set_rng_state(checkpoint['torch_rng_state'])    if torch.cuda.is_available() and checkpoint.get('cuda_rng_state') is not None:        torch.cuda.set_rng_state(checkpoint['cuda_rng_state'])
    print(f"检查点已加载: {checkpoint_path}")    print(f"    Step: {checkpoint['step']}, Loss: {checkpoint['loss']:.4f}")    return checkpoint['step']
# ==================== 学习率调度器 ====================class WarmupCosineScheduler:    def __init__(self, optimizer, warmup_steps: int, total_steps: int, base_lr: float, min_lr: float):        self.optimizer = optimizer        self.warmup_steps = max(0, int(warmup_steps))        self.total_steps = max(1, int(total_steps))        self.base_lr = base_lr        self.min_lr = min_lr        self.current_step = 0
    def get_lr(self, step: int = None) -> float:        """计算给定step的学习率(不修改optimizer)"""        if step is None:            step = self.current_step
        if step < self.warmup_steps and self.warmup_steps > 0:            return self.base_lr * (step / float(self.warmup_steps))        else:            denom = max(1, (self.total_steps - self.warmup_steps))            progress = (step - self.warmup_steps) / denom            progress = min(1.0, max(0.0, progress))            return self.min_lr + (self.base_lr - self.min_lr) * 0.5 * (1.0 + math.cos(math.pi * progress))
    def step(self):        """执行一次步长更新"""        lr = self.get_lr(self.current_step)        for param_group in self.optimizer.param_groups:            param_group['lr'] = lr        self.current_step += 1        return lr
# ==================== 训练循环 ====================def train(model: nn.Module, train_loader: DataLoader, epochs: int = CFG.EPOCHS, resume: bool = False):    # 检测fused优化器支持    fused = False    try:        fused = torch.cuda.is_available() and ("fused" in torch.optim.AdamW.__init__.__code__.co_varnames)    except Exception:        fused = False
    optimizer = torch.optim.AdamW(        model.parameters(),        lr=CFG.LR,        betas=(0.9, 0.95),        weight_decay=CFG.WEIGHT_DECAY,        fused=fused    )
    # 使用配置中的 DTYPE    scaler = torch.cuda.amp.GradScaler(enabled=(CFG.DEVICE == "cuda") and (CFG.DTYPE == torch.float16))    loss_fn = nn.CrossEntropyLoss()
    total_steps = CFG.MAX_STEPS if CFG.MAX_STEPS else len(train_loader) * epochs    lr_scheduler = WarmupCosineScheduler(optimizer, CFG.WARMUP_STEPS, total_steps, CFG.LR, CFG.MIN_LR)
    model.train()    start_step = 0    best_loss = float('inf')
    checkpoint_path = os.path.join(CFG.CHECKPOINT_DIR, CFG.LATEST_CHECKPOINT)    if resume and os.path.exists(checkpoint_path):        loaded_step = load_checkpoint(checkpoint_path, model, optimizer, scaler, lr_scheduler)        if loaded_step is not None:            start_step = loaded_step
    global_step = start_step    grad_accum_counter = 0    accumulated_loss = 0.0
    print("n" + "="*60)    print("开始训练...")    print("="*60)    print_gpu_memory()    print()
    # 自动选择是否需要 scaler.scale()    use_scaler = (CFG.DEVICE == "cuda") and (CFG.DTYPE == torch.float16)
    for epoch in range(epochs):        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}", initial=global_step % len(train_loader) if epoch == 0 else 0)        num_batches = 0        last_lr = None
        for batch_idx, (xb, yb) in enumerate(pbar):            # 跳过已训练的批次 (如果从中间恢复)            if global_step > start_step and batch_idx < (start_step % len(train_loader)):                 continue
            xb = xb.to(CFG.DEVICE, non_blocking=True)            yb = yb.to(CFG.DEVICE, non_blocking=True)
            with torch.cuda.amp.autocast(enabled=(CFG.DEVICE == "cuda"), dtype=CFG.DTYPE):                logits = model(xb)                loss = loss_fn(logits.view(-1, logits.size(-1)), yb.view(-1))                loss_item = loss.item()                loss = loss / CFG.GRAD_ACCUM_STEPS
            if use_scaler:                scaler.scale(loss).backward()            else:                loss.backward()
            grad_accum_counter += 1            accumulated_loss += loss_item            num_batches += 1            # 这里的 global_step 计数是基于数据批次的,而不是优化器步数,用于日志和检查点            # 真正的优化器步数会在下面更新
            # 梯度累积:达到阈值时执行优化步骤            if grad_accum_counter >= CFG.GRAD_ACCUM_STEPS:
                # 优化器步进 (这是真正的 global_step 增长点)                lr_scheduler.step() # 先更新 LR                global_step += 1 # 只有进行了一次优化器步进,才算一个 global_step
                if use_scaler:                    scaler.unscale_(optimizer)
                # 梯度裁剪 (在 unscale 后或非 AMP 模式下)                torch.nn.utils.clip_grad_norm_(model.parameters(), CFG.GRAD_CLIP)
                if use_scaler:                    scaler.step(optimizer)                    scaler.update()                else:                    optimizer.step()
                optimizer.zero_grad()                grad_accum_counter = 0                last_lr = lr_scheduler.get_lr(global_step) # 获取当前步的LR
            # 日志输出            if global_step % CFG.LOG_EVERY == 0 or (global_step == 1):                # accumulated_loss 是累积的原始损失, num_batches 是累积的批次数                avg_loss = accumulated_loss / num_batches if num_batches > 0 else 0.0                pbar.set_postfix({                    'step': global_step,                    'loss': f'{avg_loss:.4f}',                    'lr': f'{last_lr:.2e}' if last_lr is not None else 'N/A'                })                # 重置累积值以便计算下一个 LOG_EVERY 间隔的平均损失                accumulated_loss = 0.0                num_batches = 0

            # 保存检查点            if global_step > start_step and global_step % CFG.CHECKPOINT_EVERY == 0:                # 使用上一个日志点计算的 avg_loss                current_avg_loss = accumulated_loss / num_batches if num_batches > 0 else loss_item
                config_dict = {                    'vocab_size': CFG.VOCAB_SIZE,                    'block_size': CFG.BLOCK_SIZE,                    'model_dim': CFG.MODEL_DIM,                    'n_layers': CFG.N_LAYERS,                    'num_heads': CFG.NUM_HEADS,                    'created_at': datetime.now().isoformat()                }                save_checkpoint(model, optimizer, scaler, lr_scheduler, global_step, current_avg_loss, config_dict)                torch.cuda.empty_cache()
            if CFG.MAX_STEPS and global_step >= CFG.MAX_STEPS:                break
        # 处理 epoch 结束时剩余的梯度 (如果 grad_accum_counter > 0)        if grad_accum_counter > 0:            if use_scaler:                scaler.unscale_(optimizer)            torch.nn.utils.clip_grad_norm_(model.parameters(), CFG.GRAD_CLIP)
            if use_scaler:                scaler.step(optimizer)                scaler.update()            else:                optimizer.step()
            optimizer.zero_grad()            lr_scheduler.step()            global_step += 1            grad_accum_counter = 0

        # 此时 pbar.total_loss 已累积        if num_batches > 0:             final_avg_loss = accumulated_loss / num_batches        else:             final_avg_loss = float('inf')

        if final_avg_loss < best_loss:            best_loss = final_avg_loss            best_path = os.path.join(CFG.CHECKPOINT_DIR, "best_model.pth")            torch.save(model.state_dict(), best_path)            print(f"最佳模型已保存 (loss: {best_loss:.4f})")
        print(f"n[Epoch {epoch+1}] Avg Loss: {final_avg_loss:.4f}")
        if CFG.MAX_STEPS and global_step >= CFG.MAX_STEPS:            break
    print("n训练完成!")
# ==================== 主函数 ====================def main():    if len(sys.argv) < 4:        print("用法: python train_20251012_v1.py workdir/spm_wiki_16k.model workdir/wiki_tokens_16k.pt models/gpt_wiki.pth [--resume]")        sys.exit(1)
    sp_model_path, token_file_path, out_path = sys.argv[1:4]    resume = "--resume" in sys.argv
    if not os.path.exists(token_file_path):        print(f" Token文件不存在: {token_file_path}")        sys.exit(1)
    # 检查 CFG.DTYPE 是否为 bfloat16 但环境不支持    if CFG.DTYPE == torch.bfloat16 and not torch.cuda.is_bf16_supported():        print("警告: bfloat16 不受当前 CUDA 设备支持,自动回退到 float16。")        CFG.DTYPE = torch.float16
    sp = spm.SentencePieceProcessor(model_file=sp_model_path)    CFG.VOCAB_SIZE = sp.get_piece_size()
    print("="*60)    print("GPT 语言模型训练")    print("="*60)    print(f"分词器: {sp_model_path}")    print(f"Token文件: {token_file_path}")    print(f"输出模型: {out_path}")    print(f"设备: {CFG.DEVICE}")    print(f"n模型配置:")    print(f"    - VOCAB_SIZE: {CFG.VOCAB_SIZE}")    print(f"    - BLOCK_SIZE: {CFG.BLOCK_SIZE}")    print(f"    - MODEL_DIM: {CFG.MODEL_DIM}")    print(f"    - N_LAYERS: {CFG.N_LAYERS}")    print(f"    - NUM_HEADS: {CFG.NUM_HEADS}")    print(f"n训练配置:")    print(f"    - BATCH_SIZE: {CFG.BATCH_SIZE}")    print(f"    - GRAD_ACCUM_STEPS: {CFG.GRAD_ACCUM_STEPS}")    print(f"    - 有效BATCH_SIZE: {CFG.BATCH_SIZE * CFG.GRAD_ACCUM_STEPS}")    print(f"    - LR: {CFG.LR}, WARMUP_STEPS: {CFG.WARMUP_STEPS}")    print("="*60)
    print(f"n加载Token文件: {token_file_path}")    ids = torch.load(token_file_path)    print(f"已加载 {ids.numel():,} tokens ({ids.numel() * ids.element_size() / (1024**3):.2f} GB)")
    dataset = TextDataset(ids, CFG.BLOCK_SIZE)    del ids    torch.cuda.empty_cache()
    # 改进:启用 shuffle=True 进行预训练    num_workers = CFG.NUM_WORKERS    try:        train_loader = DataLoader(            dataset,            batch_size=CFG.BATCH_SIZE,            shuffle=True, # 启用 Shuffle            pin_memory=(CFG.DEVICE == "cuda"),            num_workers=num_workers,            persistent_workers=True if num_workers > 0 else False        )    except Exception as e:        print(f"DataLoader错误: {e}, 改用num_workers=0")        train_loader = DataLoader(            dataset,            batch_size=CFG.BATCH_SIZE,            shuffle=True,            pin_memory=(CFG.DEVICE == "cuda"),            num_workers=0        )
    model = GPTModel(        CFG.VOCAB_SIZE,        CFG.BLOCK_SIZE,        dim=CFG.MODEL_DIM,        num_layers=CFG.N_LAYERS,        num_heads=CFG.NUM_HEADS,        ffn_dim=CFG.FFN_DIM,        dropout=CFG.DROPOUT    ).to(CFG.DEVICE)
    # 尝试编译(容错)    try:        model = torch.compile(model, mode='reduce-overhead')        print("已启用 torch.compile() 加速")    except Exception as e:        print(f"跳过 torch.compile(): {e}")
    train(model, train_loader, epochs=CFG.EPOCHS, resume=resume)
    torch.save(model.state_dict(), out_path)    print(f"n最终模型已保存到 {out_path}")    print_gpu_memory()
if __name__ == "__main__":    main()

07、进行模型推理测试

import torchfrom torch import nnimport sentencepiece as spmfrom typing import Optional
# ==================== 配置参数 (必须与训练时一致) ====================# 使用与训练脚本中完全相同的配置class Config:    BLOCK_SIZE = 512    # 模型尺寸参数 (必须与训练时一致)    MODEL_DIM = 384    N_LAYERS = 5    NUM_HEADS = 6    HEAD_DIM = MODEL_DIM // NUM_HEADS    FFN_DIM = MODEL_DIM * 4     VOCAB_SIZE = None
    # 推理设置    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"    # 推理通常使用 float32 获得最佳兼容性和精度    DTYPE = torch.float32 
CFG = Config()
# ==================== RoPE 位置编码 (与训练脚本保持一致) ====================class RotaryPositionalEmbedding(nn.Module):    def __init__(self, head_dim: int, max_seq_len: int = 2048):        super().__init__()        self.head_dim = head_dim        assert head_dim % 2 == 0, "head_dim must be even"        inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))        self.register_buffer("inv_freq", inv_freq)        self.max_seq_len = max_seq_len        self._seq_len_cached = max_seq_len        self._cos_cached = None        self._sin_cached = None        self._update_cos_sin_cache(max_seq_len, device=self.inv_freq.device)
    def _update_cos_sin_cache(self, seq_len: int, device: torch.device):        if seq_len == self._seq_len_cached and self._cos_cached is not None:            return        m = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)        freqs = torch.einsum("i,j->ij", m, self.inv_freq)        emb = torch.cat([freqs, freqs], dim=-1)        cos = emb.cos()[None, None, :, :]        sin = emb.sin()[None, None, :, :]        self._cos_cached = cos        self._sin_cached = sin        self._seq_len_cached = seq_len
    def forward(self, seq_len: int, device: Optional[torch.device] = None):        if device is None:            device = self.inv_freq.device        self._update_cos_sin_cache(seq_len, device=device)        return self._cos_cached.to(device), self._sin_cached.to(device)
def apply_rotary_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:    return (x * cos) + (_rotate_half(x) * sin)
def _rotate_half(x: torch.Tensor) -> torch.Tensor:    x1 = x[..., : x.shape[-1] // 2]    x2 = x[..., x.shape[-1] // 2 :]    return torch.cat((-x2, x1), dim=-1)
# ==================== Attention, FFN, Block, Model (与训练脚本保持一致) ====================class FlashAttention(nn.Module):    def __init__(self, embed_dim: int, num_heads: int, attn_dropout: float = 0.0):        super().__init__()        self.embed_dim = embed_dim        self.num_heads = num_heads        assert embed_dim % num_heads == 0        self.head_dim = embed_dim // num_heads        self.scale = self.head_dim ** -0.5        self.qkv = nn.Linear(embed_dim, embed_dim * 3, bias=False)        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)        # 推理时通常不使用 Dropout,但模型结构需要保持一致        self.attn_dropout = nn.Dropout(attn_dropout)         self.rope = RotaryPositionalEmbedding(self.head_dim)
    def forward(self, x: torch.Tensor, causal_mask: Optional[torch.Tensor] = None) -> torch.Tensor:        B, T, C = x.shape        qkv = self.qkv(x)        qkv = qkv.view(B, T, 3, self.num_heads, self.head_dim)        q, k, v = qkv.unbind(dim=2)        q = q.permute(0, 2, 1, 3)  # (B, H, T, D)        k = k.permute(0, 2, 1, 3)        v = v.permute(0, 2, 1, 3)
        cos, sin = self.rope(T, device=x.device)        q = apply_rotary_emb(q, cos, sin)        k = apply_rotary_emb(k, cos, sin)
        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale        # 注意:在推理时,通常使用 KV-Cache,这里简化为完整计算        if T > 1: # 仅在序列长度大于 1 时应用 mask            causal_mask = torch.tril(torch.ones(T, T, device=x.device, dtype=torch.bool))[None, None, :, :]            scores = scores.masked_fill(causal_mask == 0, float('-inf'))
        attn = torch.softmax(scores, dim=-1)        # 推理时禁用 dropout        # attn = self.attn_dropout(attn)         out = torch.matmul(attn, v)        out = out.permute(0, 2, 1, 3).contiguous().view(B, T, C)        return self.out_proj(out)
class FeedForward(nn.Module):    def __init__(self, dim: int, hidden_dim: int, dropout: float = 0.1):        super().__init__()        # 必须保持与训练脚本中完全相同的 nn.Sequential 结构        self.net = nn.Sequential(            GLU(dim, hidden_dim),            nn.Dropout(dropout),   # net.1: Dropout (必须保留,占位)            nn.Linear(hidden_dim, dim), # net.2: Linear (与训练时一致)            nn.Dropout(dropout),   # net.3: Dropout (必须保留,占位)        )
    def forward(self, x):        # 在推理时, model.eval() 会自动禁用所有 nn.Dropout 层,但结构不变        return self.net(x)
# 确保 GLU 的定义如下(与训练时一致):class GLU(nn.Module):    def __init__(self, in_dim: int, out_dim: int):        super().__init__()        # GLU 内部只有一个 nn.Linear        self.linear = nn.Linear(in_dim, out_dim * 2)
    def forward(self, x):        x, gates = self.linear(x).chunk(2, dim=-1)        return x * torch.nn.functional.silu(gates)
class TransformerBlock(nn.Module):    def __init__(self, dim: int, num_heads: int, ffn_dim: int, dropout: float = 0.1):        super().__init__()        self.ln1 = nn.LayerNorm(dim)        self.attn = FlashAttention(dim, num_heads, attn_dropout=dropout)        self.ln2 = nn.LayerNorm(dim)        self.ff = FeedForward(dim, ffn_dim, dropout)
    def forward(self, x, causal_mask=None):        x = x + self.attn(self.ln1(x), causal_mask)        x = x + self.ff(self.ln2(x))        return x
class GPTModel(nn.Module):    def __init__(self, vocab_size: int, block_size: int, dim: int = CFG.MODEL_DIM,                 num_layers: int = CFG.N_LAYERS, num_heads: int = CFG.NUM_HEADS,                 ffn_dim: int = CFG.FFN_DIM, dropout: float = 0.0, # 推理时 dropout=0                 tie_weights: bool = True):        super().__init__()        self.token_emb = nn.Embedding(vocab_size, dim)        self.dropout = nn.Dropout(dropout)
        self.blocks = nn.ModuleList([            TransformerBlock(dim, num_heads, ffn_dim, dropout) for _ in range(num_layers)        ])
        self.ln_final = nn.LayerNorm(dim)        self.lm_head = nn.Linear(dim, vocab_size, bias=False)
        if tie_weights:            self.lm_head.weight = self.token_emb.weight
        self.block_size = block_size
    def forward(self, idx):        B, T = idx.shape        token_emb = self.token_emb(idx)        x = token_emb # 推理时不使用 dropout
        causal_mask = None # Attention 模块内部处理 Causal Mask        for block in self.blocks:            x = block(x, causal_mask)        x = self.ln_final(x)        logits = self.lm_head(x)        return logits

# ==================== 推理和生成函数 ====================
@torch.no_grad()def generate_text(model: GPTModel, sp: spm.SentencePieceProcessor,                   prompt: str, max_new_tokens: int, temperature: float = 0.8,                   top_k: int = 50):
    model.eval()    device = CFG.DEVICE
    # 1. 编码输入    input_ids = sp.encode_as_ids(prompt)    if not input_ids:        return "无法编码输入。"
    # 将输入转换为模型期望的格式 (B, T)    x = torch.tensor(input_ids, dtype=torch.long, device=device).unsqueeze(0)
    # 2. 循环生成    for _ in range(max_new_tokens):        # 裁剪输入以适应模型的 BLOCK_SIZE        # 在实际部署中,这里应该使用 KV Cache,但此处简化为完整前向传播        idx_cond = x if x.size(1) <= CFG.BLOCK_SIZE else x[:, -CFG.BLOCK_SIZE:]
        # 获取 logits        logits = model(idx_cond)
        # 只取最后一个时间步的 logits        logits = logits[:, -1, :] 
        # 应用温度缩放        logits = logits / temperature
        # 3. Top-K 采样        if top_k is not None:            v, _ = torch.topk(logits, min(top_k, logits.size(-1)))            logits[logits < v[:, [-1]]] = float('-inf')
        # 计算概率并采样        probs = torch.softmax(logits, dim=-1)        idx_next = torch.multinomial(probs, num_samples=1)
        # 4. 停止条件        # 检查是否生成了 EOS token (假设 </s> 是 ID 3, 请根据您的分词器调整)        # 默认使用 SentencePiece 的 <eos> ID        if idx_next.item() == sp.eos_id():            break
        # 将新生成的 token 添加到序列中        x = torch.cat((x, idx_next), dim=1)
        # 检查是否达到最大序列长度 (防止溢出)        if x.size(1) >= CFG.BLOCK_SIZE + max_new_tokens:            break
    # 5. 解码输出    output_ids = x[0].tolist()    # 查找输入 prompt 的长度,只解码新生成的 token    start_index = len(input_ids)
    return sp.decode_ids(output_ids[start_index:])

# ==================== 主执行函数 ====================
def main_infer(sp_model_path: str, model_weights_path: str):    print("="*50)    print(f"GPT 模型推理模式")    print(f"设备: {CFG.DEVICE}, DTYPE: {CFG.DTYPE}")    print("="*50)
    # 1. 加载分词器    try:        sp = spm.SentencePieceProcessor(model_file=sp_model_path)        CFG.VOCAB_SIZE = sp.get_piece_size()        print(f"加载分词器成功,VOCAB_SIZE: {CFG.VOCAB_SIZE}")    except Exception as e:        print(f"无法加载分词器模型 {sp_model_path}: {e}")        return
    # 2. 实例化模型    model = GPTModel(        vocab_size=CFG.VOCAB_SIZE,        block_size=CFG.BLOCK_SIZE,        dim=CFG.MODEL_DIM,        num_layers=CFG.N_LAYERS,        num_heads=CFG.NUM_HEADS,        ffn_dim=CFG.FFN_DIM,        dropout=0.0 # 推理时设置 dropout 为 0    ).to(CFG.DEVICE).to(CFG.DTYPE)
    # 3. 加载权重    try:        # 检查是否是 torch.compile 后的状态字典        weights = torch.load(model_weights_path, map_location=CFG.DEVICE)
        # 如果权重是 DDP 或 torch.compile 包装后的,需要解包        if any(k.startswith('_orig_mod.') for k in weights.keys()):            weights = {k.replace('_orig_mod.', ''): v for k, v in weights.items()}
        model.load_state_dict(weights, strict=True)        print(f"成功加载模型权重: {model_weights_path}")    except Exception as e:        print(f"无法加载或匹配模型权重: {e}")        # 如果加载失败,打印预期键和实际键,方便调试        # print("n--- 预期模型键 (部分) ---")        # print(list(model.state_dict().keys())[:5])        # print("n--- 载入权重键 (部分) ---")        # print(list(weights.keys())[:5])        return
    # 4. 进入交互循环    print("n--- 进入交互模式 ---")    print(f"输入 'exit' 或 'quit' 退出。")    print(f"输入 'config' 查看当前生成参数。")    print("----------------------")
    max_tokens = 100    temperature = 0.8    top_k = 50
    while True:        try:            prompt = input(">>> 输入提示词: ")
            if prompt.lower() in ['exit', 'quit']:                break
            if prompt.lower() == 'config':                print(f"  Max Tokens: {max_tokens}, Temp: {temperature}, Top K: {top_k}")                new_max = input("  设置 Max Tokens (回车跳过): ")                new_temp = input("  设置 Temperature (回车跳过): ")                new_k = input("  设置 Top K (回车跳过): ")
                if new_max: max_tokens = int(new_max)                if new_temp: temperature = float(new_temp)                if new_k: top_k = int(new_k)                continue
            if not prompt.strip():                continue
            print("生成中...")
            # 执行生成            output = generate_text(model, sp, prompt, max_tokens, temperature, top_k)
            print(f"--- 模型回复 ---n{output.strip()}")            print("----------------")
        except KeyboardInterrupt:            print("n退出生成...")            break        except Exception as e:            print(f"发生错误: {e}")

if __name__ == "__main__":    import sys
    if len(sys.argv) != 3:        print("用法: python infer.py <spm模型路径> <模型权重文件路径>")        # 示例用法 (请根据您的实际文件路径修改):        # python infer.py tokenizer.model final_model.pth        sys.exit(1)
    sp_model_path = sys.argv[1]    model_weights_path = sys.argv[2]
    main_infer(sp_model_path, model_weights_path)

我们看到模型大概可以预测我们输入的下一个词,因我们训练的参数和步数很低,模型输出的乱七八糟!

本次总结

本次我们做了数据准备、数据清洗、分词器训练、模型训练、推理等,请根据步骤进行执行代码,你便可以得到一个17M参数的小模型。后面我们再加大参数进行训练,再进行监督微调。

相关推荐

登录即可解锁
  • 海量技术文章
  • 设计资源下载
  • 产业链客户资源
  • 写文章/发需求
立即登录