-
Notifications
You must be signed in to change notification settings - Fork 12
add position update to progress in .h5 #37
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
I'm also not a huge fan of default value of 0 |
hexane360
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks pretty good, made a couple of notes. Why don't you rename the key to 'pos_update_mag' or 'pos_update_rms' to make clear it pertains to the position solver? Do we want 'tilt_update_mag'/'tilt_update_rms' as well?
phaser/engines/gradient/run.py
Outdated
| state, iter_solver_states[sol_i], filter_vars(iter_grads, solver.params), losses['total_loss'] | ||
| ) | ||
| state = apply_update(state, update) | ||
| update_mag = xp.linalg.norm(update['positions'], axis=-1) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What happens if no position solver is specified, and 'positions' is not in update?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
should be addressed now
phaser/engines/gradient/run.py
Outdated
| for (k) in other_keys: | ||
| progress[k].iters.append(i + start_i) | ||
| progress[k].values.append(xp.mean(update_mag)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The for loop seems redundant if you're special-casing 'update_mag' regardless
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yeah I realize it'll definitely break, hopefully fixed in newest commit
phaser/engines/conventional/run.py
Outdated
| other_keys = ('positions_update_rms', ) if position_solver is not None else () | ||
| # populate missing keys in progress dictionary | ||
| for k in ('detector_loss', 'total_loss'): | ||
| for k in (*('detector_loss', 'total_loss'), *other_keys): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| for k in (*('detector_loss', 'total_loss'), *other_keys): | |
| for k in ('detector_loss', 'total_loss', *other_keys): |
phaser/engines/conventional/run.py
Outdated
|
|
||
| # check positions are at least overlapping object | ||
| sim.state.object.sampling.check_scan(sim.state.scan, sim.state.probe.sampling.extent / 2.) | ||
| # for k in ('update_mag', ): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| # for k in ('update_mag', ): |
phaser/engines/conventional/run.py
Outdated
| pos_update_rms = xp.linalg.norm(pos_update, axis=-1, keepdims=True) | ||
| logger.info(f"Position update: mean {xp.mean(pos_update_rms)}") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| pos_update_rms = xp.linalg.norm(pos_update, axis=-1, keepdims=True) | |
| logger.info(f"Position update: mean {xp.mean(pos_update_rms)}") | |
| pos_update_rms = float(xp.mean(xp.linalg.norm(pos_update, axis=-1, keepdims=True))) | |
| logger.info(f"Position update: mean {pos_update_rms}") |
phaser/engines/conventional/run.py
Outdated
| progress['positions_update_rms'].iters.append(i + start_i) | ||
| progress['positions_update_rms'].values.append(xp.mean(pos_update_rms)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| progress['positions_update_rms'].iters.append(i + start_i) | |
| progress['positions_update_rms'].values.append(xp.mean(pos_update_rms)) | |
| progress['positions_update_rms'].iters.append(i + start_i) | |
| progress['positions_update_rms'].values.append(pos_update_rms) |
08da8b6 to
f6a989e
Compare
Co-authored-by: Colin Gilgenbach <colin@gilgenbach.net>
working for conventional and gradient, but maybe not the most elegant for the gradient engine. open to suggestions