Fix fine-tuning training loss accumulation #725

Merged
merged 1 commit into meta-llama:main on Oct 15, 2024

Conversation

celestinoalan
Contributor

What does this PR do?

Problem:

In /src/llama_recipes/utils/train_utils.py, the training loss is correctly divided by the number of gradient accumulation steps to scale down the gradient:

loss = loss / gradient_accumulation_steps

The training loss is then accumulated

total_loss += loss.detach().float()

and then used to calculate the average loss across all samples in the epoch:

train_epoch_loss = total_loss / len(train_dataloader)

Because the accumulated loss has already been scaled down by gradient_accumulation_steps, while len(train_dataloader) counts every step (including the gradient-accumulation ones), train_epoch_loss comes out gradient_accumulation_steps times lower than it should be.
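
The off-by-a-factor effect is easy to reproduce in isolation. Below is a minimal, self-contained sketch (not the actual train_utils.py loop; the loss values and gradient_accumulation_steps are made up for illustration) that mimics the original accumulation order:

    # Toy reproduction of the accumulation-order bug. The synthetic per-batch
    # losses stand in for loss.detach().float() from the real training loop.
    gradient_accumulation_steps = 4
    batch_losses = [2.0] * 8                        # pretend every batch yields a loss of 2.0

    total_loss = 0.0
    for loss in batch_losses:
        loss = loss / gradient_accumulation_steps   # scaled down for the backward pass
        total_loss += loss                          # accumulated after scaling (the bug)

    train_epoch_loss = total_loss / len(batch_losses)   # len(train_dataloader) counts every step
    print(train_epoch_loss)                             # 0.5 instead of the true average of 2.0, i.e. 4x too low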

Solution:

Accumulate the loss
total_loss += loss.detach().float()

before scaling it down

loss = loss / gradient_accumulation_steps
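
Applied to the same toy numbers as above, the corrected ordering (again a sketch, not the exact repository code) reports the expected average:

    # Same synthetic setup, with accumulation moved before the scaling.
    gradient_accumulation_steps = 4
    batch_losses = [2.0] * 8

    total_loss = 0.0
    for loss in batch_losses:
        total_loss += loss                          # accumulate the raw loss first
        loss = loss / gradient_accumulation_steps   # scale only the value used for backward()

    train_epoch_loss = total_loss / len(batch_losses)
    print(train_epoch_loss)                         # 2.0, the true per-batch average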

Before submitting

  • [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • [x] Did you read the contributor guideline, Pull Request section?
  • [ ] Was this discussed/approved via a GitHub issue? Please add a link to it if that's the case.
  • [ ] Did you make sure to update the documentation with your changes?
  • [ ] Did you write any new necessary tests?

Thanks for contributing 🎉!

Contributor

@mreso mreso left a comment

LGTM, thanks for contributing!

@mreso mreso merged commit d6ae203 into meta-llama:main Oct 15, 2024
2 of 3 checks passed