... | ... |
@@ -692,5 +692,7 @@ def finite_loss(ival, loss, x_scale): |
692 | 692 |
|
693 | 693 |
# We round the loss to 12 digits such that losses |
694 | 694 |
# are equal up to numerical precision will be considered |
695 |
- # equal. |
|
696 |
- return round(loss, ndigits=12), ival |
|
695 |
+ # equal. This is 3.5x faster than unsing the `round` function. |
|
696 |
+ round_fac = 1e12 |
|
697 |
+ loss = int(loss * round_fac + 0.5) / round_fac |
|
698 |
+ return loss, ival |