Conversation
Codecov ReportAttention:
Additional details and impacted files@@ Coverage Diff @@
## master #273 +/- ##
==========================================
- Coverage 73.08% 70.70% -2.38%
==========================================
Files 5 6 +1
Lines 535 553 +18
==========================================
Hits 391 391
- Misses 144 162 +18
☔ View full report in Codecov by Sentry. |
|
This error is only happening with the MKL provider. With MKL, FFTW.jl doesn't even compile on my machine. Could be due to 008bc5b? test_frule: idct on Array{Float64, 3},Int64: Error During Test at /home/runner/.julia/packages/ChainRulesTestUtils/C9L2i/src/testers.jl:123
Got exception outside of a @test
FFTW could not create plan
Stacktrace:
[1] error(s::String)
@ Base ./error.jl:35
[2] macro expansion
@ FFTW ~/work/FFTW.jl/FFTW.jl/src/fft.jl:722 [inlined] |
|
@devmotion, could you please review this? LMK if you want me to remove the |
devmotion
left a comment
There was a problem hiding this comment.
Can you explain the idea of the PR? The design goal was to define the differentiation rules in AbstractFFTs via adjoint plans. Downstream packages were supposed to implement this new adjoint interface.
|
Only the MKL tests are failing. Is there a regression with MKL? |
More concretely, #249 outlines the intended approach. FFTW was not supposed to define custom rules. |
|
This PR defines Chain Rules for |
|
With #249 , gradient computation would error for DCT/IDCT. This is because julia> using FFTW, Zygote
julia>
julia> using LinearAlgebra, FFTW, Zygote
julia> x = rand(4)
4-element Vector{Float64}:
0.8692266334693106
0.6938635624794242
0.552208368655668
0.9197557963740512
julia> f(x) = x |> dct |> idct |> norm
f (generic function with 1 method)
julia> f(x)
1.5452787421840921
julia> Zygote.gradient(f, x)
ERROR: Compiling Tuple{Type{FFTW.r2rFFTWPlan{Float64, Any, false, 1}}, Vector{Float64}, FFTW.FakeArray{Float64, 1}, UnitRange{Int64}, Int64, UInt32, Float64}: try/catch is not supported.
Refer to the Zygote documentation for fixes.
https://fluxml.ai/Zygote.jl/latest/limitations |
devmotion
left a comment
There was a problem hiding this comment.
Sorry, somehow I missed yesterday that the PR does not add rules for plans but for the dct, idct and r2r functions 🤦
As long as they/their interface is not moved to AbstractFFTs, rules should be defined here 👍
ext/FFTWChainRulesCoreExt.jl
Outdated
|
|
||
| # R2R | ||
|
|
||
| function ChainRulesCore.frule(Δ, ::typeof(r2r), x::AbstractArray, region...) |
There was a problem hiding this comment.
It seems the rrule for r2r is missing?
There was a problem hiding this comment.
The R2R transforms are not unitary. There is some scaling involved that depends on the kind of R2R transform. Because it looks like an involved task, I chose to skip that for now. I am happy to look into that in a separate PR
Co-authored-by: David Widmann <devmotion@users.noreply.github.com>
Co-authored-by: David Widmann <devmotion@users.noreply.github.com>
|
@devmotion I've addressed all your comments. LMK if you have more questions :D |
|
@devmotion ping :) |
address #272