
Hung-yi Lee ML2022 HW2: Multiclass Classification

Task

Framewise phoneme prediction from speech.
Phoneme: a unit of speech sound in a language that can serve to distinguish one word from another.
Details: here.

Improvements

Hyperparameter changes + cosine annealing learning rate

We use a cosine annealing learning rate schedule. Some students may ask: why is it always cosine annealing? In Prof. Hung-yi Lee's words, this is the wisdom of the ancient sages, so just use it. My own understanding is that because cosine annealing sweeps the learning rate up and down, you can see quite directly which learning rates work well, which is very helpful for choosing the right learning-rate settings.
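To make that concrete, here is a minimal sketch that just prints the learning rate per epoch under the AdamW + CosineAnnealingWarmRestarts setup used later in this post, so you can see where each restart lands; the dummy parameter exists only so the optimizer has something to hold.

import torch

# dummy parameter so the optimizer has something to optimize
params = [torch.nn.Parameter(torch.zeros(1))]
learning_rate = 0.0002

optimizer = torch.optim.AdamW(params, lr=learning_rate * 5, weight_decay=0.01)
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer, T_0=8, T_mult=2, eta_min=learning_rate / 2)

for epoch in range(50):
    print(epoch, scheduler.get_last_lr()[0])  # learning rate for this epoch
    scheduler.step()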

Improved model

import torch.nn as nn

class BasicBlock(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(BasicBlock, self).__init__()

        # Linear -> LeakyReLU -> BatchNorm -> Dropout
        self.block = nn.Sequential(
            nn.Linear(input_dim, output_dim),
            nn.LeakyReLU(0.2),
            nn.BatchNorm1d(output_dim),
            nn.Dropout(0.2),
        )

    def forward(self, x):
        x = self.block(x)
        return x
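For reference, a sketch of how these blocks might be stacked into the full network, in the spirit of the assignment's sample code; the Classifier name and the 41-class output (the dataset has 41 phoneme classes) follow the HW2 setup, but treat the exact wiring as an assumption.

class Classifier(nn.Module):
    def __init__(self, input_dim, output_dim=41, hidden_layers=2, hidden_dim=1024):
        super(Classifier, self).__init__()

        self.fc = nn.Sequential(
            BasicBlock(input_dim, hidden_dim),
            # hidden_layers additional blocks of constant width
            *[BasicBlock(hidden_dim, hidden_dim) for _ in range(hidden_layers)],
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, x):
        return self.fc(x)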

Hyperparameter changes

# data parameters
concat_nframes = 17 # the number of frames to concat with, n must be odd (total 2k+1 = n frames)
train_ratio = 0.8 # the ratio of data used for training, the rest will be used for validation

# training parameters
seed = 0 # random seed
batch_size = 2048 # batch size
num_epoch = 50 # the number of training epochs
learning_rate = 0.0002 # learning rate
model_path = './model.ckpt' # the path where the checkpoint will be saved

# model parameters
input_dim = 39 * concat_nframes # the input dim of the model; do not change this value
hidden_layers = 2 # the number of hidden layers
hidden_dim = 1024 # the hidden dim
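With concat_nframes = 17, each training example is the center frame plus 8 frames of context on each side, so input_dim = 39 × 17 = 663. A minimal sketch of how that concatenation could be implemented (the concat_feat name and the edge-padding behavior are illustrative assumptions, not necessarily the assignment's own helper):

import torch

def concat_feat(x, concat_n):
    # x: (T, 39) frame-level MFCC features; concat_n must be odd (2k+1)
    assert concat_n % 2 == 1
    k = concat_n // 2
    # pad by repeating the first/last frame k times at each edge
    x = torch.cat([x[:1].repeat(k, 1), x, x[-1:].repeat(k, 1)], dim=0)
    T = x.shape[0] - 2 * k
    # each output row = the frame plus its k left and k right neighbours
    return torch.cat([x[i : i + T] for i in range(concat_n)], dim=1)

feats = torch.randn(100, 39)
print(concat_feat(feats, 17).shape)  # torch.Size([100, 663])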

Cosine annealing learning rate

optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate * 5, weight_decay=0.01)
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer, T_0=8, T_mult=2, eta_min=learning_rate / 2)
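As for how the scheduler plugs into training: the PyTorch docs step CosineAnnealingWarmRestarts with a fractional epoch, so T_0=8 means "first restart after 8 epochs" even when stepping once per batch. A sketch, assuming the usual train_loader and criterion names (not shown in this post):

for epoch in range(num_epoch):
    model.train()
    for i, (features, labels) in enumerate(train_loader):
        optimizer.zero_grad()
        loss = criterion(model(features), labels)
        loss.backward()
        optimizer.step()
        # fractional-epoch stepping keeps the restart period in epoch units
        scheduler.step(epoch + i / len(train_loader))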

Results


Related resources

Dataset: here.