-
Notifications
You must be signed in to change notification settings - Fork 161
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Jax ml operator fix #4041
base: master
Are you sure you want to change the base?
Jax ml operator fix #4041
Conversation
|
|
@pbrubeck @dham Firedrake pip installs PR depends on the approval of UFL PR 348 and this PR. |
The failure in the docs building is because of a recent Sphinx update. |
@@ -77,6 +77,6 @@ def J(f): | |||
c = Control(f) | |||
Jhat = ReducedFunctional(J(f), c) | |||
|
|||
f_opt = minimize(Jhat, tol=1e-6, method="BFGS") | |||
f_opt = minimize(Jhat, tol=1e-4, method="BFGS") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This tolerance modification reduces the runtime test by almost 50%. Most important is that the following assert
is still satisfied.
firedrake/ml/jax/ml_operator.py
Outdated
@@ -39,36 +39,41 @@ def __init__( | |||
*operands: Union[ufl.core.expr.Expr, ufl.form.BaseForm], | |||
function_space: WithGeometryBase, | |||
derivatives: Optional[tuple] = None, | |||
argument_slots: Optional[tuple[Union[ufl.coefficient.BaseCoefficient, ufl.argument.BaseArgument]]], | |||
argument_slots: Optional[tuple[Union[ufl.coefficient.BaseCoefficient, ufl.argument.BaseArgument]]] = (), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If argument slots are not provided, then they will be generated.
Having an empty tuple seems a strange default considering. Could/should it be None
instead?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually if it is Optional
then that's what is has to be (https://docs.python.org/3/library/typing.html#typing.Optional)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Optional[tuple[Union[ufl.coefficient.BaseCoefficient, ufl.argument.BaseArgument]]] = None
leads to the error:
> argument_slots = tuple(map(as_ufl, argument_slots))
E TypeError: 'NoneType' object is not iterable
../ufl/ufl/core/base_form_operator.py:49: TypeError
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The way to define argument_slots
argument is comming from https://github.com/firedrakeproject/ufl/blob/master/ufl/core/external_operator.py#L28
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We can improve this by writing
argument_slots: tuple[Union[ufl.coefficient.BaseCoefficient, ufl.argument.BaseArgument]] | None = None,
Then
if argument_slots is None:
argument_slots = ()
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it would be better to just have
argument_slots: tuple[Union[ufl.coefficient.BaseCoefficient, ufl.argument.BaseArgument]] = (),
(i.e. no Optional
) and to remove
If argument slots are not provided, then they will be generated.
from the docstring.
That is at least consistent with the UFL object.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
Description
Depends on firedrakeproject/ufl#60
Fixed an issue in
firedrake.ml.jax.ml_operator
whereargument_slots
was incorrectly specified. I wrote it as optional:Clarified that if
argument_slots
is not provided,ML_Operator
will automatically write it.Test the results involving the Neural operators are in right function space.
Test jax and pytorch operators in Firedrake CI.