Skip to content

Commit 72ad5fb

Browse files
authored
Add default (mold, array_size) overloads for rvs_normal (#1056)
2 parents 9d26b90 + 04fe63a commit 72ad5fb

File tree

3 files changed

+89
-2
lines changed

3 files changed

+89
-2
lines changed

doc/specs/stdlib_stats_distribution_normal.md

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,17 @@ With two arguments, the function returns a normal distributed random variate \(N
2323

2424
With three arguments, the function returns a rank-1 array of normal distributed random variates.
2525

26+
With one or two arguments where the first is `array_size`, the function returns a rank-1 array of standard normal distributed random variates \(N(0,1)\). The `mold` argument determines the output type and kind; it is optional only for `real(dp)` (and defaults to `real(dp)` when omitted), but required for all other types.
27+
2628
@note
2729
The algorithm used for generating exponential random variates is fundamentally limited to double precision.[^1]
2830

2931
### Syntax
3032

3133
`result = ` [[stdlib_stats_distribution_normal(module):rvs_normal(interface)]] `([loc, scale] [[, array_size]])`
3234

35+
`result = ` [[stdlib_stats_distribution_normal(module):rvs_normal(interface)]] `(array_size [, mold])`
36+
3337
### Class
3438

3539
Elemental function (passing both `loc` and `scale`).
@@ -40,13 +44,15 @@ Elemental function (passing both `loc` and `scale`).
4044

4145
`scale`: optional argument has `intent(in)` and is a positive scalar of type `real` or `complex`.
4246

43-
`array_size`: optional argument has `intent(in)` and is a scalar of type `integer`.
47+
`array_size`: optional argument has `intent(in)` and is a scalar of type `integer`. When used with `loc` and `scale`, specifies the size of the output array. When used alone or with `mold`, must be provided as the first argument.
48+
49+
`mold`: optional argument (only for `real(dp)`; required for other types) has `intent(in)` and is a scalar of type `real` or `complex`. Used only to determine the type and kind of the output; its value is not referenced. When omitted (only allowed for `real(dp)`), defaults to `real(dp)`. When provided, generates standard normal variates \(N(0,1)\) of the specified type and kind.
4450

4551
`loc` and `scale` arguments must be of the same type.
4652

4753
### Return value
4854

49-
The result is a scalar or rank-1 array, with a size of `array_size`, and the same type as `scale` and `loc`. If `scale` is non-positive, the result is `NaN`.
55+
The result is a scalar or rank-1 array, with a size of `array_size`, and the same type as `scale` and `loc` (or same type and kind as `mold` when using the `array_size [, mold]` form; defaults to `real(dp)` when `mold` is omitted). If `scale` is non-positive, the result is `NaN`.
5056

5157
### Example
5258

src/stdlib_stats_distribution_normal.fypp

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ module stdlib_stats_distribution_normal
3333

3434
#:for k1, t1 in RC_KINDS_TYPES
3535
module procedure rvs_norm_array_${t1[0]}$${k1}$ !3 dummy variables
36+
module procedure rvs_norm_array_default_${t1[0]}$${k1}$ !array_size, mold (mold optional for real(dp) only)
3637
#:endfor
3738
end interface rvs_normal
3839

@@ -238,6 +239,22 @@ contains
238239

239240
#:endfor
240241

242+
#:for k1, t1 in REAL_KINDS_TYPES
243+
impure function rvs_norm_array_default_${t1[0]}$${k1}$ (array_size, mold) result(res)
244+
!
245+
! Standard normal array random variate with default loc=0, scale=1
246+
! The mold argument is used only to determine the type and is not referenced
247+
!
248+
integer, intent(in) :: array_size
249+
${t1}$, intent(in) #{if t1 == 'real(dp)'}#, optional #{endif}#:: mold
250+
${t1}$ :: res(array_size)
251+
252+
res = rvs_norm_array_${t1[0]}$${k1}$ (0.0_${k1}$, 1.0_${k1}$, array_size)
253+
254+
end function rvs_norm_array_default_${t1[0]}$${k1}$
255+
256+
#:endfor
257+
241258
#:for k1, t1 in CMPLX_KINDS_TYPES
242259
impure function rvs_norm_array_${t1[0]}$${k1}$ (loc, scale, array_size) result(res)
243260
${t1}$, intent(in) :: loc, scale
@@ -256,6 +273,25 @@ contains
256273

257274
#:endfor
258275

276+
#:for k1, t1 in CMPLX_KINDS_TYPES
277+
impure function rvs_norm_array_default_${t1[0]}$${k1}$ (array_size, mold) result(res)
278+
!
279+
! Standard normal complex array random variate with default loc=0, scale=1
280+
! The mold argument is used only to determine the type and is not referenced
281+
!
282+
integer, intent(in) :: array_size
283+
${t1}$, intent(in) :: mold
284+
${t1}$ :: res(array_size)
285+
286+
! Call the full procedure with default loc=(0,0), scale=(1,1)
287+
res = rvs_norm_array_${t1[0]}$${k1}$ (cmplx(0.0_${k1}$, 0.0_${k1}$, kind=${k1}$), &
288+
cmplx(1.0_${k1}$, 1.0_${k1}$, kind=${k1}$), &
289+
array_size)
290+
291+
end function rvs_norm_array_default_${t1[0]}$${k1}$
292+
293+
#:endfor
294+
259295
#:for k1, t1 in REAL_KINDS_TYPES
260296
elemental function pdf_norm_${t1[0]}$${k1}$ (x, loc, scale) result(res)
261297
!

test/stats/test_distribution_normal.fypp

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,10 @@ program test_distribution_normal
2626
call test_nor_rvs_${t1[0]}$${k1}$
2727
#:endfor
2828

29+
#:for k1, t1 in RC_KINDS_TYPES
30+
call test_nor_rvs_default_${t1[0]}$${k1}$
31+
#:endfor
32+
2933

3034

3135
#:for k1, t1 in RC_KINDS_TYPES
@@ -138,6 +142,47 @@ contains
138142
#:endfor
139143

140144

145+
#:for k1, t1 in RC_KINDS_TYPES
146+
subroutine test_nor_rvs_default_${t1[0]}$${k1}$
147+
${t1}$ :: a1(10), a2(10), mold
148+
integer :: i
149+
integer :: seed, get
150+
151+
print *, "Test normal_distribution_rvs_default_${t1[0]}$${k1}$"
152+
seed = 25836914
153+
call random_seed(seed, get)
154+
155+
! explicit form with loc=0, scale=1
156+
#:if t1[0] == "r"
157+
a1 = nor_rvs(0.0_${k1}$, 1.0_${k1}$, 10)
158+
#:else
159+
a1 = nor_rvs((0.0_${k1}$, 0.0_${k1}$), (1.0_${k1}$, 1.0_${k1}$), 10)
160+
#:endif
161+
162+
! reset seed to reproduce same random sequence
163+
seed = 25836914
164+
call random_seed(seed, get)
165+
166+
! default mold form: mold used only to disambiguate kind
167+
! For real(dp), mold is optional; for other types (including complex), it's required
168+
#:if t1[0] == "r" and k1 == "dp"
169+
a2 = nor_rvs(10) ! mold optional for rdp only, defaults to real(dp)
170+
#:else
171+
#! mold required for all other types including complex and non-dp kinds
172+
#:if t1[0] == "r"
173+
mold = 0.0_${k1}$
174+
#:else
175+
mold = (0.0_${k1}$, 0.0_${k1}$)
176+
#:endif
177+
a2 = nor_rvs(10, mold)
178+
#:endif
179+
180+
call check(all(a1 == a2), msg="normal_distribution_rvs_default_${t1[0]}$${k1}$ failed", warn=warn)
181+
end subroutine test_nor_rvs_default_${t1[0]}$${k1}$
182+
183+
#:endfor
184+
185+
141186

142187

143188
#:for k1, t1 in RC_KINDS_TYPES

0 commit comments

Comments
 (0)