Description
My code is shown further below. It runs, but it emits the following warning:
/opt/conda/lib/python3.10/site-packages/orbax/checkpoint/type_handlers.py:1544: UserWarning: Couldn't find sharding info under RestoreArgs. Populating sharding info from sharding file. Please note restoration time will be slightly increased due to reading from file instead of directly from RestoreArgs. Note also that this option is unsafe when restoring on a different topology than the checkpoint was saved with.
warnings.warn(
The model params are a nested dictionary (a pytree of arrays) like this:
{'params': {'Dense_0': {'bias': Array([ 1.75125599e-02, 4.10381891e-02, 1.96171561e-04, -2.42875870e-02,
1.51480837e-02, 3.71114984e-02, -4.87685064e-03, -1.63130835e-02,
5.47729768e-02, -5.70005644e-03, -5.12132980e-03, 3.10970427e-05,
2.31470224e-02, -1.55021911e-02, 1.72994770e-02, 2.26450190e-02,
-3.05333477e-03, 9.84513387e-03, -3.31428014e-02, 3.80380601e-02,
-3.20659392e-03, -3.09392507e-03, 1.86821781e-02, -1.25538018e-02,
4.41285521e-02, -4.72985283e-02, 3.28246184e-04, 1.31683890e-02,
1.32193940e-03, 1.48607325e-02, -3.43988538e-02, 8.36286321e-03,
-2.90089939e-02, -3.98164280e-02, 2.31531989e-02, 3.27519067e-02,
2.72216517e-02, -2.89463606e-02, 2.44598440e-03, 6.63389359e-03,
4.59096301e-03, -1.23022813e-02, 1.29767824e-02, 4.81516495e-03,
-1.20902760e-02, -2.27207374e-02, -1.27110668e-02, 1.20020472e-02,
3.91368084e-02, -4.30837227e-03, 3.32566164e-02, -2.71463916e-02,
2.25058272e-02, -3.91818397e-03, 1.49554424e-02, -6.85477350e-03,
1.01907691e-03, -6.12435490e-02, -1.18386028e-02, -6.03230670e-03,
7.54657155e-03, 8.14247876e-03, -6.61915401e-03, 8.85959063e-03], dtype=float32),
'kernel': Array([[-0.08260956, 0.42094436, -0.27531517, ..., -0.09135673,
0.21974503, 0.21818572],
[-0.04729075, -0.2666923 , 0.14365157, ..., 0.13939556,
-0.16218886, -0.04071451],
[ 0.10921595, 0.01364996, -0.11194808, ..., -0.01299416,
-0.02805288, -0.0272818 ],
...,
[-0.04990593, -0.01473087, 0.06877133, ..., -0.05618783,
-0.06337533, -0.17277789],
[-0.10326906, -0.03525492, 0.21592571, ..., -0.06726424,
0.04024971, 0.21430357],
[-0.06426816, 0.01593289, 0.01053577, ..., -0.08965493,
0.1562466 , 0.19774263]], dtype=float32)},
'Dense_1': {'bias': Array([ 0.01895721, 0.02381056, -0.00297396, 0.00253655, -0.00579324,
-0.00917996, -0.0524504 , -0.01307405, -0.00445831, -0.01765897,
-0.02990872, -0.01783756, -0.00417391, -0.02153626, -0.01237699,
0.00332377], dtype=float32),
'kernel': Array([[-0.0687123 , 0.11527583, 0.02760898, ..., -0.11483309,
0.09793864, 0.24956086],
[-0.17475414, 0.06557149, 0.02568068, ..., -0.18699066,
-0.235098 , 0.17345282],
[ 0.21747173, -0.00923413, -0.04049944, ..., 0.04021717,
-0.03704283, 0.13622351],
...,
[ 0.1976054 , -0.07143398, 0.11763132, ..., 0.15076494,
-0.08623252, 0.08628309],
[ 0.142208 , -0.07710048, 0.05116218, ..., 0.05643938,
0.01690205, -0.00337057],
[-0.08983981, -0.08721507, 0.05885444, ..., 0.2054291 ,
-0.0595689 , 0.09482205]], dtype=float32)},
'Dense_2': {'bias': Array([-0.27287585, -0.31808662, -0.22906446, -0.2392324 , -0.1169002 ,
-0.45564348, -0.27986547, -0.4403381 , -0.3194529 , -0.03579619,
-0.27706683, -0.20705369, -0.3464241 , -0.16313383, -0.3245753 ,
-0.12070157, -0.10058393, -0.335585 , -0.23487404, -0.13635263,
-0.3551262 , -0.19502614, -0.27066055, -0.22264665, -0.17983833,
-0.38362965, -0.2549991 , -0.35028023, -0.02632488, -0.24093926,
-0.26272595, -0.32823324, -0.1442327 , -0.18271838, -0.3466661 ,
-0.2975728 , -0.2519938 , -0.20744751, -0.48289314, -0.20181467,
-0.0694458 , -0.2868131 , -0.0621618 , -0.1489881 , -0.22316173,
-0.26048866, -0.3741152 , -0.22691546, -0.28160262, -0.39583412,
-0.44518995, -0.26774997, -0.18526609, -0.3136557 , -0.29002288,
-0.2983223 , -0.4889701 , -0.20518056, -0.06886528, -0.18853416,
-0.06637306, -0.45197925, -0.3145519 , -0.23673685], dtype=float32),
'kernel': Array([[ 0.01163595, 0.02972907, 0.02223774, ..., 0.0228426 ,
0.06626749, 0.04824122],
[-0.00046695, -0.08048075, -0.10955726, ..., -0.02934793,
-0.04758933, -0.06418303],
[-0.1286408 , -0.1255539 , -0.11133979, ..., -0.02911641,
-0.16321352, -0.1160882 ],
...,
[-0.01398295, -0.02459321, 0.21012494, ..., 0.07538891,
-0.08655173, 0.02649066],
[ 0.11309541, 0.09003462, -0.01682626, ..., -0.18835074,
0.09409627, 0.05982505],
[-0.12019534, -0.06023936, 0.14683168, ..., -0.10527591,
-0.0902904 , -0.08336279]], dtype=float32)},
'Dense_3': {'bias': Array([-0.19127552, -0.1539499 , -0.10825736, -0.1273831 , -0.14423408,
-0.10800537, -0.1158509 , -0.19560331, -0.0544809 , -0.12320589,
-0.14327425, -0.06410812, -0.06359567, -0.05706155, -0.16820268,
-0.06965973], dtype=float32),
'kernel': Array([[-0.17942922, -0.2944296 , -0.15672816, ..., -0.22876067,
0.04339508, -0.00558091],
[ 0.00910171, 0.00975822, -0.06065388, ..., -0.14715518,
-0.05254569, -0.09955268],
[-0.1705753 , -0.0021669 , 0.120933 , ..., -0.01333852,
-0.09636445, -0.13689627],
...,
[-0.07949128, -0.12297069, -0.30489385, ..., 0.08778288,
-0.10832835, -0.20170009],
[-0.00643041, -0.06448855, -0.02339345, ..., -0.1243507 ,
0.10904145, 0.01637404],
[-0.15585563, -0.07519744, -0.05542706, ..., -0.0285616 ,
-0.03688109, -0.07460079]], dtype=float32)}}}
# Save the params
# Remove any existing directory at the target path before saving.
if os.path.exists('/user/working/model_params'):
    shutil.rmtree('/user/working/model_params')
checkpoint = orbax.checkpoint.PyTreeCheckpointer()
checkpoint.save('/user/working/model_params', params)

# Restore the params
# NOTE(review): restore() is called without any restore_args; presumably this
# is the call that makes Orbax fall back to the sharding file and emit the
# "Couldn't find sharding info under RestoreArgs" UserWarning.
checkpoint = orbax.checkpoint.PyTreeCheckpointer()
checkpoint.restore('/user/working/model_params')
How should I fix this so that the warning no longer appears? Thanks.
Activity