Coverage for src/toolbox_pyspark/scale.py: 100%

27 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2025-01-25 23:08 +0000

1# ============================================================================ # 

2# # 

3# Title : Scale # 

4# Purpose : Rounding a column (or columns) to a given rounding accuracy. # 

5# # 

6# ============================================================================ # 

7 

8 

9# ---------------------------------------------------------------------------- # 

10# # 

11# Overview #### 

12# # 

13# ---------------------------------------------------------------------------- # 

14 

15 

16# ---------------------------------------------------------------------------- # 

17# Description #### 

18# ---------------------------------------------------------------------------- # 

19 

20 

21""" 

22!!! note "Summary" 

23 The `scale` module is used for rounding a column (or columns) to a given rounding accuracy. 

24""" 

25 

26 

27# ---------------------------------------------------------------------------- # 

28# # 

29# Setup #### 

30# # 

31# ---------------------------------------------------------------------------- # 

32 

33 

34# ---------------------------------------------------------------------------- # 

35# Imports #### 

36# ---------------------------------------------------------------------------- # 

37 

38 

39# ## Python StdLib Imports ---- 

40from typing import Optional, Union 

41 

42# ## Python Third Party Imports ---- 

43from pyspark.sql import DataFrame as psDataFrame, functions as F 

44from toolbox_python.checkers import is_type 

45from toolbox_python.collection_types import str_collection, str_list 

46from typeguard import typechecked 

47 

48# ## Local First Party Imports ---- 

49from toolbox_pyspark.checks import assert_column_exists, assert_columns_exists 

50 

51 

52# ---------------------------------------------------------------------------- # 

53# Exports #### 

54# ---------------------------------------------------------------------------- # 

55 

56 

# Names exported by `from toolbox_pyspark.scale import *`.
__all__: str_list = [
    "round_column",
    "round_columns",
]

58 

59 

60# ---------------------------------------------------------------------------- # 

61# Constants #### 

62# ---------------------------------------------------------------------------- # 

63 

64 

# Number of decimal places used when a caller does not pass `scale` explicitly.
DEFAULT_DECIMAL_ACCURACY: int = 10

# Base names of the column data types which are allowed to be rounded.
# Compared against the part of each dtype string before any "(" (so
# "decimal(21,20)" is matched as "decimal").
VALID_TYPES: str_list = [
    "float",
    "double",
    "decimal",
]

67 

68 

69# ---------------------------------------------------------------------------- # 

70# # 

71# Functions #### 

72# # 

73# ---------------------------------------------------------------------------- # 

74 

75 

76# ---------------------------------------------------------------------------- # 

77# Firstly #### 

78# ---------------------------------------------------------------------------- # 

79 

80 

@typechecked
def round_column(
    dataframe: psDataFrame,
    column: str,
    scale: int = DEFAULT_DECIMAL_ACCURACY,
) -> psDataFrame:
    """
    !!! note "Summary"
        Round a single `column` on the given `dataframe` to `scale` decimal places, provided that the column has a decimal-like data type (that is, one of: `#!py ["float", "double", "decimal"]`).

    ???+ abstract "Details"
        The heavy lifting is done by a single pyspark expression:
        ```{.py .python linenums="1" title="Python"}
        dataframe = dataframe.withColumn(colName=column, col=F.round(col=column, scale=scale))
        ```
        What this function adds on top is validation: it first checks that `column` exists on the `dataframe`, and then that the column's data type is one which can be rounded. Because the `dataframe` is the first parameter, the function can also be used within the pyspark `.transform()` method.
        For more info, see: [`pyspark.sql.DataFrame.transform`](https://spark.apache.org/docs/3.1.3/api/python/reference/api/pyspark.sql.DataFrame.transform.html)

    Params:
        dataframe (psDataFrame):
            The `dataframe` holding the column to be rounded.
        column (str):
            The name of the column to round.
        scale (int, optional):
            The number of decimal places to round to.<br>
            When omitted, the module-level value `#!py DEFAULT_DECIMAL_ACCURACY` is used; which is `#!py 10`.<br>
            Defaults to `#!py DEFAULT_DECIMAL_ACCURACY`.

    Raises:
        TypeError:
            If any of the inputs parsed to the parameters of this function are not the correct type. Uses the [`@typeguard.typechecked`](https://typeguard.readthedocs.io/en/stable/api.html#typeguard.typechecked) decorator.
        TypeError:
            If the data type of `column` is not one of: `#!py ["float", "double", "decimal"]`.

    Returns:
        (psDataFrame):
            The `dataframe` with `column` replaced by its rounded values. All other columns are left untouched.

    ???+ example "Examples"

        ```{.py .python linenums="1" title="Set up"}
        >>> # Imports
        >>> import pandas as pd
        >>> from pyspark.sql import SparkSession, functions as F, types as T
        >>> from toolbox_pyspark.scale import round_column
        >>>
        >>> # Instantiate Spark
        >>> spark = SparkSession.builder.getOrCreate()
        >>>
        >>> # Create data
        >>> df = (
        ...     spark
        ...     .createDataFrame(
        ...         pd.DataFrame(
        ...             {
        ...                 "a": range(20),
        ...                 "b": [f"1.{'0'*val}1" for val in range(20)],
        ...                 "c": [f"1.{'0'*val}6" for val in range(20)],
        ...             }
        ...         )
        ...     )
        ...     .withColumns(
        ...         {
        ...             "b": F.col("b").cast(T.DecimalType(21, 20)),
        ...             "c": F.col("c").cast(T.DecimalType(21, 20)),
        ...         }
        ...     )
        ... )
        >>>
        >>> # Check
        >>> df.show(truncate=False)
        ```
        <div class="result" markdown>
        ```{.txt .text title="Terminal"}
        +---+----------------------+----------------------+
        |a  |b                     |c                     |
        +---+----------------------+----------------------+
        |0  |1.10000000000000000000|1.60000000000000000000|
        |1  |1.01000000000000000000|1.06000000000000000000|
        |2  |1.00100000000000000000|1.00600000000000000000|
        |3  |1.00010000000000000000|1.00060000000000000000|
        |4  |1.00001000000000000000|1.00006000000000000000|
        |5  |1.00000100000000000000|1.00000600000000000000|
        |6  |1.00000010000000000000|1.00000060000000000000|
        |7  |1.00000001000000000000|1.00000006000000000000|
        |8  |1.00000000100000000000|1.00000000600000000000|
        |9  |1.00000000010000000000|1.00000000060000000000|
        |10 |1.00000000001000000000|1.00000000006000000000|
        |11 |1.00000000000100000000|1.00000000000600000000|
        |12 |1.00000000000010000000|1.00000000000060000000|
        |13 |1.00000000000001000000|1.00000000000006000000|
        |14 |1.00000000000000100000|1.00000000000000600000|
        |15 |1.00000000000000010000|1.00000000000000060000|
        |16 |1.00000000000000001000|1.00000000000000006000|
        |17 |1.00000000000000000100|1.00000000000000000600|
        |18 |1.00000000000000000010|1.00000000000000000060|
        |19 |1.00000000000000000001|1.00000000000000000006|
        +---+----------------------+----------------------+
        ```
        </div>

        ```{.py .python linenums="1" title="Example 1: Round with defaults"}
        >>> round_column(df, "b").show(truncate=False)
        ```
        <div class="result" markdown>
        ```{.txt .text title="Terminal"}
        +---+------------+----------------------+
        |a  |b           |c                     |
        +---+------------+----------------------+
        |0  |1.1000000000|1.60000000000000000000|
        |1  |1.0100000000|1.06000000000000000000|
        |2  |1.0010000000|1.00600000000000000000|
        |3  |1.0001000000|1.00060000000000000000|
        |4  |1.0000100000|1.00006000000000000000|
        |5  |1.0000010000|1.00000600000000000000|
        |6  |1.0000001000|1.00000060000000000000|
        |7  |1.0000000100|1.00000006000000000000|
        |8  |1.0000000010|1.00000000600000000000|
        |9  |1.0000000001|1.00000000060000000000|
        |10 |1.0000000000|1.00000000006000000000|
        |11 |1.0000000000|1.00000000000600000000|
        |12 |1.0000000000|1.00000000000060000000|
        |13 |1.0000000000|1.00000000000006000000|
        |14 |1.0000000000|1.00000000000000600000|
        |15 |1.0000000000|1.00000000000000060000|
        |16 |1.0000000000|1.00000000000000006000|
        |17 |1.0000000000|1.00000000000000000600|
        |18 |1.0000000000|1.00000000000000000060|
        |19 |1.0000000000|1.00000000000000000006|
        +---+------------+----------------------+
        ```
        !!! success "Conclusion: Successfully rounded column `b`."
        </div>

        ```{.py .python linenums="1" title="Example 2: Round to custom number"}
        >>> round_column(df, "c", 5).show(truncate=False)
        ```
        <div class="result" markdown>
        ```{.txt .text title="Terminal"}
        +---+----------------------+-------+
        |a  |b                     |c      |
        +---+----------------------+-------+
        |0  |1.10000000000000000000|1.60000|
        |1  |1.01000000000000000000|1.06000|
        |2  |1.00100000000000000000|1.00600|
        |3  |1.00010000000000000000|1.00060|
        |4  |1.00001000000000000000|1.00006|
        |5  |1.00000100000000000000|1.00001|
        |6  |1.00000010000000000000|1.00000|
        |7  |1.00000001000000000000|1.00000|
        |8  |1.00000000100000000000|1.00000|
        |9  |1.00000000010000000000|1.00000|
        |10 |1.00000000001000000000|1.00000|
        |11 |1.00000000000100000000|1.00000|
        |12 |1.00000000000010000000|1.00000|
        |13 |1.00000000000001000000|1.00000|
        |14 |1.00000000000000100000|1.00000|
        |15 |1.00000000000000010000|1.00000|
        |16 |1.00000000000000001000|1.00000|
        |17 |1.00000000000000000100|1.00000|
        |18 |1.00000000000000000010|1.00000|
        |19 |1.00000000000000000001|1.00000|
        +---+----------------------+-------+
        ```
        !!! success "Conclusion: Successfully rounded column `c` to 5 decimal points."
        </div>

        ```{.py .python linenums="1" title="Example 3: Raise error"}
        >>> round_column(df, "a").show(truncate=False)
        ```
        <div class="result" markdown>
        ```{.txt .text title="Terminal"}
        TypeError: Column is not the correct type. Please check.
        For column 'a', the type is 'bigint'.
        In order to round it, it needs to be one of: '['float', 'double', 'decimal']'.
        ```
        !!! failure "Conclusion: Cannot round a column `a`."
        </div>
    """

    # The column must be present before its data type can be looked up.
    assert_column_exists(dataframe, column)

    # `dataframe.dtypes` is a list of `(name, type)` pairs. Parameterised types
    # are reported like `decimal(21,20)`, so only the part before the first
    # bracket is compared against `VALID_TYPES`.
    matching_types = [dtype for name, dtype in dataframe.dtypes if name == column]
    col_type: str = matching_types[0].split("(")[0]

    # Happy path: the column is roundable, so replace it with its rounded values.
    if col_type in VALID_TYPES:
        rounded = F.round(col=column, scale=scale)
        return dataframe.withColumn(colName=column, col=rounded)

    raise TypeError(
        f"Column is not the correct type. Please check.\n"
        f"For column '{column}', the type is '{col_type}'.\n"
        f"In order to round it, it needs to be one of: '{VALID_TYPES}'."
    )

269 

270 

@typechecked
def round_columns(
    dataframe: psDataFrame,
    columns: Optional[Union[str, str_collection]] = "all_float",
    scale: int = DEFAULT_DECIMAL_ACCURACY,
) -> psDataFrame:
    """
    !!! note "Summary"
        Round several `columns` on the given `dataframe` to `scale` decimal places in one go, provided that every one of them has a decimal-like data type (that is, one of: `#!py ["float", "double", "decimal"]`).

    ???+ abstract "Details"
        The heavy lifting is done by a single pyspark expression:
        ```{.py .python linenums="1" title="Python"}
        dataframe = dataframe.withColumns({col: F.round(col, scale) for col in columns})
        ```
        What this function adds on top is the resolution of the `columns` argument (see below), plus validation: it checks that all of the `columns` exist on the `dataframe`, and that all of them have a data type which can be rounded. Because the `dataframe` is the first parameter, the function can also be used within the pyspark `.transform()` method.
        For more info, see: [`pyspark.sql.DataFrame.transform`](https://spark.apache.org/docs/3.1.3/api/python/reference/api/pyspark.sql.DataFrame.transform.html)

    Params:
        dataframe (psDataFrame):
            The `dataframe` holding the columns to be rounded.
        columns (Optional[Union[str, str_collection]], optional):
            The names of the columns to round.<br>
            If it is `#!py None`, or one of the values `#!py ["all", "all_float"]`, then every column on the `dataframe` whose data type is one of `#!py ["float", "double", "decimal"]` is selected.<br>
            If it is any other `#!py str`, then it is treated as a single column name, and coerced to a single-element list, like: `#!py [columns]`.<br>
            Defaults to `#!py "all_float"`.
        scale (int, optional):
            The number of decimal places to round to.<br>
            When omitted, the module-level value `#!py DEFAULT_DECIMAL_ACCURACY` is used; which is `#!py 10`.<br>
            Defaults to `#!py DEFAULT_DECIMAL_ACCURACY`.

    Raises:
        TypeError:
            If any of the inputs parsed to the parameters of this function are not the correct type. Uses the [`@typeguard.typechecked`](https://typeguard.readthedocs.io/en/stable/api.html#typeguard.typechecked) decorator.
        TypeError:
            If the data type of any of the given `columns` is not one of: `#!py ["float", "double", "decimal"]`.

    Returns:
        (psDataFrame):
            The `dataframe` with each of the `columns` replaced by its rounded values. All other columns are left untouched.

    ???+ example "Examples"

        ```{.py .python linenums="1" title="Set up"}
        >>> # Imports
        >>> import pandas as pd
        >>> from pyspark.sql import SparkSession, functions as F, types as T
        >>> from toolbox_pyspark.scale import round_columns
        >>>
        >>> # Instantiate Spark
        >>> spark = SparkSession.builder.getOrCreate()
        >>>
        >>> # Create data
        >>> df = (
        ...     spark
        ...     .createDataFrame(
        ...         pd.DataFrame(
        ...             {
        ...                 "a": range(20),
        ...                 "b": [f"1.{'0'*val}1" for val in range(20)],
        ...                 "c": [f"1.{'0'*val}6" for val in range(20)],
        ...             }
        ...         )
        ...     )
        ...     .withColumns(
        ...         {
        ...             "b": F.col("b").cast(T.DecimalType(21, 20)),
        ...             "c": F.col("c").cast(T.DecimalType(21, 20)),
        ...         }
        ...     )
        ... )
        >>>
        >>> # Check
        >>> df.show(truncate=False)
        ```
        <div class="result" markdown>
        ```{.txt .text title="Terminal"}
        +---+----------------------+----------------------+
        |a  |b                     |c                     |
        +---+----------------------+----------------------+
        |0  |1.10000000000000000000|1.60000000000000000000|
        |1  |1.01000000000000000000|1.06000000000000000000|
        |2  |1.00100000000000000000|1.00600000000000000000|
        |3  |1.00010000000000000000|1.00060000000000000000|
        |4  |1.00001000000000000000|1.00006000000000000000|
        |5  |1.00000100000000000000|1.00000600000000000000|
        |6  |1.00000010000000000000|1.00000060000000000000|
        |7  |1.00000001000000000000|1.00000006000000000000|
        |8  |1.00000000100000000000|1.00000000600000000000|
        |9  |1.00000000010000000000|1.00000000060000000000|
        |10 |1.00000000001000000000|1.00000000006000000000|
        |11 |1.00000000000100000000|1.00000000000600000000|
        |12 |1.00000000000010000000|1.00000000000060000000|
        |13 |1.00000000000001000000|1.00000000000006000000|
        |14 |1.00000000000000100000|1.00000000000000600000|
        |15 |1.00000000000000010000|1.00000000000000060000|
        |16 |1.00000000000000001000|1.00000000000000006000|
        |17 |1.00000000000000000100|1.00000000000000000600|
        |18 |1.00000000000000000010|1.00000000000000000060|
        |19 |1.00000000000000000001|1.00000000000000000006|
        +---+----------------------+----------------------+
        ```
        </div>

        ```{.py .python linenums="1" title="Example 1: Round with defaults"}
        >>> round_columns(df).show(truncate=False)
        ```
        <div class="result" markdown>
        ```{.txt .text title="Terminal"}
        +---+------------+------------+
        |a  |b           |c           |
        +---+------------+------------+
        |0  |1.1000000000|1.6000000000|
        |1  |1.0100000000|1.0600000000|
        |2  |1.0010000000|1.0060000000|
        |3  |1.0001000000|1.0006000000|
        |4  |1.0000100000|1.0000600000|
        |5  |1.0000010000|1.0000060000|
        |6  |1.0000001000|1.0000006000|
        |7  |1.0000000100|1.0000000600|
        |8  |1.0000000010|1.0000000060|
        |9  |1.0000000001|1.0000000006|
        |10 |1.0000000000|1.0000000001|
        |11 |1.0000000000|1.0000000000|
        |12 |1.0000000000|1.0000000000|
        |13 |1.0000000000|1.0000000000|
        |14 |1.0000000000|1.0000000000|
        |15 |1.0000000000|1.0000000000|
        |16 |1.0000000000|1.0000000000|
        |17 |1.0000000000|1.0000000000|
        |18 |1.0000000000|1.0000000000|
        |19 |1.0000000000|1.0000000000|
        +---+------------+------------+
        ```
        !!! success "Conclusion: Successfully rounded columns `b` and `c`."
        </div>

        ```{.py .python linenums="1" title="Example 2: Round to custom number"}
        >>> round_columns(df, "c", 5).show(truncate=False)
        ```
        <div class="result" markdown>
        ```{.txt .text title="Terminal"}
        +---+----------------------+-------+
        |a  |b                     |c      |
        +---+----------------------+-------+
        |0  |1.10000000000000000000|1.60000|
        |1  |1.01000000000000000000|1.06000|
        |2  |1.00100000000000000000|1.00600|
        |3  |1.00010000000000000000|1.00060|
        |4  |1.00001000000000000000|1.00006|
        |5  |1.00000100000000000000|1.00001|
        |6  |1.00000010000000000000|1.00000|
        |7  |1.00000001000000000000|1.00000|
        |8  |1.00000000100000000000|1.00000|
        |9  |1.00000000010000000000|1.00000|
        |10 |1.00000000001000000000|1.00000|
        |11 |1.00000000000100000000|1.00000|
        |12 |1.00000000000010000000|1.00000|
        |13 |1.00000000000001000000|1.00000|
        |14 |1.00000000000000100000|1.00000|
        |15 |1.00000000000000010000|1.00000|
        |16 |1.00000000000000001000|1.00000|
        |17 |1.00000000000000000100|1.00000|
        |18 |1.00000000000000000010|1.00000|
        |19 |1.00000000000000000001|1.00000|
        +---+----------------------+-------+
        ```
        !!! success "Conclusion: Successfully rounded column `c` to 5 decimal points."
        </div>

        ```{.py .python linenums="1" title="Example 3: Raise error"}
        >>> round_columns(df, ["a", "b"]).show(truncate=False)
        ```
        <div class="result" markdown>
        ```{.txt .text title="Terminal"}
        TypeError: Columns are not the correct types. Please check.
        These columns are invalid: '[('a', 'bigint')]'.
        In order to round them, they need to be one of: '['float', 'double', 'decimal']'.
        ```
        !!! failure "Conclusion: Cannot round column `a`."
        </div>
    """

    # Resolve `columns` into a concrete collection of column names.
    # `dataframe.dtypes` is a list of `(name, type)` pairs. Parameterised types
    # are reported like `decimal(21,20)`, so only the part before the first
    # bracket is compared against `VALID_TYPES`.
    select_all_floats = columns is None or columns in ["all", "all_float"]
    if select_all_floats:
        columns = [
            name for name, dtype in dataframe.dtypes if dtype.split("(")[0] in VALID_TYPES
        ]
    elif is_type(columns, str):
        columns = [columns]

    # Every requested column must be present before the data types are checked.
    assert_columns_exists(dataframe, columns)

    # Gather each requested column whose base data type cannot be rounded, so
    # that all of the offenders are reported together in a single error.
    invalid_cols: list[tuple[str, str]] = []
    for name, dtype in dataframe.dtypes:
        base_type = dtype.split("(")[0]
        if name in columns and base_type not in VALID_TYPES:
            invalid_cols.append((name, base_type))

    if invalid_cols:
        raise TypeError(
            f"Columns are not the correct types. Please check.\n"
            f"These columns are invalid: '{invalid_cols}'.\n"
            f"In order to round them, they need to be one of: '{VALID_TYPES}'."
        )

    # Replace each requested column with its rounded values in a single pass.
    rounding_map = {name: F.round(name, scale) for name in columns}
    return dataframe.withColumns(rounding_map)