PaddlePaddle/PaddleNLP
examples/information_extraction/waybill_ie/run_bigru_crf.py: error when saving the model with save_pretrained
#I3YUOY (status: Backlog)
大象无形 opened this issue on 2021-07-02 15:04
Since I'm running this locally, I modified the code so that ErnieCrfForTokenClassification inherits from ErniePretrainedModel:

```python
class ErnieCrfForTokenClassification(ErniePretrainedModel):
    def __init__(self, ernie, crf_lr=100):
        super(ErnieCrfForTokenClassification, self).__init__()
        self.num_classes = ernie.num_classes
        self.ernie = ernie  # allow ernie to be config
        self.crf = LinearChainCrf(
            self.num_classes, crf_lr=crf_lr, with_start_stop_tag=False)
        self.crf_loss = LinearChainCrfLoss(self.crf)
        self.viterbi_decoder = ViterbiDecoder(
            self.crf.transitions, with_start_stop_tag=False)
```

Then I save the model with model.save_pretrained('./ernie_crf_ckpt'). The training and saving code:

```python
step = 0
best_f1 = 0
for epoch in range(3000):
    for input_ids, token_type_ids, lengths, labels in train_loader:
        loss = model(
            input_ids, token_type_ids, lengths=lengths, labels=labels)
        avg_loss = paddle.mean(loss)
        avg_loss.backward()
        optimizer.step()
        optimizer.clear_grad()
        step += 1
        print("[TRAIN] Epoch:%d - Step:%d - Loss: %f" % (epoch, step, avg_loss))
    f1_score = evaluate(model, metric, dev_loader)
    if f1_score >= best_f1:
        best_f1 = f1_score
        paddle.save(model.state_dict(), 'ernie_crf_ckpt/model_state.pdparams')
        # paddle.save(model, 'ernie_crf_ckpt/model_state.pdmodel')
        tokenizer.save_pretrained('./ernie_crf_ckpt')
        model.save_pretrained('./ernie_crf_ckpt')
```

This raises the following error:

```
Traceback (most recent call last):
  File "/home/user/Desktop/on_going_projects/sgaqfh/PaddleNLP/examples/information_extraction/waybill_ie/run_ernie_crf.py", line 135, in <module>
    model.save_pretrained('./ernie_crf_ckpt')
  File "/home/user/Desktop/on_going_projects/sgaqfh/PaddleNLP/paddlenlp/transformers/model_utils.py", line 386, in save_pretrained
    self.save_model_config(save_dir)
  File "/home/user/Desktop/on_going_projects/sgaqfh/PaddleNLP/paddlenlp/transformers/model_utils.py", line 374, in save_model_config
    f.write(json.dumps(model_config, ensure_ascii=False))
  File "/home/user/anaconda3/envs/paddle_2.1/lib/python3.8/json/__init__.py", line 234, in dumps
    return cls(
  File "/home/user/anaconda3/envs/paddle_2.1/lib/python3.8/json/encoder.py", line 199, in encode
    chunks = self.iterencode(o, _one_shot=True)
  File "/home/user/anaconda3/envs/paddle_2.1/lib/python3.8/json/encoder.py", line 257, in iterencode
    return _iterencode(o, 0)
  File "/home/user/anaconda3/envs/paddle_2.1/lib/python3.8/json/encoder.py", line 179, in default
    raise TypeError(f'Object of type {o.__class__.__name__} '
TypeError: Object of type ErnieModel is not JSON serializable
```

After looking into it, the error occurs while writing model_config.json. I haven't been able to fix it. How can this be solved?
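The traceback shows `save_model_config` passing the model's recorded config straight to `json.dumps`, and the object that fails is an `ErnieModel`. One plausible explanation (an assumption, not confirmed in this issue): the config of `ErnieCrfForTokenClassification` records its `ernie` constructor argument, which is itself a model wrapping an `ErnieModel`. If only model objects at the top level of that config are converted to plain dicts, the nested `ErnieModel` reaches `json.dumps` unchanged. The sketch below reproduces that failure mode with a hypothetical `FakePretrainedModel` class and shows that a recursive conversion does serialize. It uses only the standard library and is not PaddleNLP's actual code; `one_level_config` and `to_json_config` are illustrative helpers, and the argument values are placeholders.

```python
import json


class FakePretrainedModel:
    """Hypothetical stand-in for a PaddleNLP PretrainedModel.

    Assumption: like the real class, it records its constructor
    arguments in `init_config` so they can be written to model_config.json.
    """

    def __init__(self, *init_args, **kwargs):
        self.init_config = {"init_args": init_args, **kwargs}


def one_level_config(model):
    """Unwraps only models passed directly as init_args (the assumed failure mode)."""
    config = dict(model.init_config)
    config["init_args"] = [
        arg.init_config if isinstance(arg, FakePretrainedModel) else arg
        for arg in config["init_args"]
    ]
    return config


def to_json_config(value):
    """Recursively replace model objects at any depth with their init_config."""
    if isinstance(value, FakePretrainedModel):
        return to_json_config(value.init_config)
    if isinstance(value, dict):
        return {key: to_json_config(item) for key, item in value.items()}
    if isinstance(value, (list, tuple)):
        return [to_json_config(item) for item in value]
    return value


# Nesting that mirrors the issue: CRF wrapper -> token classifier -> base ERNIE model.
# hidden_size and num_classes are placeholder values; crf_lr=100 is the default above.
ernie_base = FakePretrainedModel(hidden_size=768)
token_classifier = FakePretrainedModel(ernie_base, num_classes=7)
crf_model = FakePretrainedModel(token_classifier, crf_lr=100)

try:
    json.dumps(one_level_config(crf_model))
except TypeError as err:
    # The nested base model is still a Python object at this point.
    print("one-level conversion fails:", err)

print(json.dumps(to_json_config(crf_model)))
```

If this reading is right, two workarounds follow: make the config conversion recursive as in `to_json_config`, or skip model_config.json and rely only on the `state_dict` the loop already saves with `paddle.save`, rebuilding the model in code and restoring weights with `paddle.load` plus `set_state_dict`. Both are suggestions under the stated assumption, not fixes confirmed by the maintainers.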
Comments (7)