SVDTest.php 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405
  1. <?php
  2. namespace MathPHP\Tests\LinearAlgebra\Decomposition;
  3. use MathPHP\Functions\Support;
  4. use MathPHP\LinearAlgebra\MatrixFactory;
  5. use MathPHP\Exception;
  6. use MathPHP\LinearAlgebra\NumericMatrix;
  7. use MathPHP\LinearAlgebra\Vector;
  8. use MathPHP\Tests\LinearAlgebra\Fixture\MatrixDataProvider;
  9. class SVDTest extends \PHPUnit\Framework\TestCase
  10. {
  11. use MatrixDataProvider;
  12. /**
  13. * @test SVD returns the expected array of U, S, and Vt factorized matrices
  14. * @dataProvider dataProviderForSVD
  15. * @dataProvider dataProviderForLesserRankSVD
  16. * @param array $A
  17. * @param array $expected
  18. * @throws \Exception
  19. */
  20. public function testSVD(array $A, array $expected)
  21. {
  22. // Given
  23. $A = MatrixFactory::createNumeric($A);
  24. $expected_S = MatrixFactory::createNumeric($expected['S']);
  25. // When
  26. $svd = $A->svd();
  27. // And
  28. $U = $svd->U;
  29. $S = $svd->S;
  30. $V = $svd->V;
  31. // Then A = USVᵀ
  32. $this->assertEqualsWithDelta($A->getMatrix(), $U->multiply($S)->multiply($V->transpose())->getMatrix(), 0.00001);
  33. // And S is expected solution to SVD
  34. $this->assertEqualsWithDelta($expected_S->getMatrix(), $S->getMatrix(), 0.00001);
  35. }
  36. /**
  37. * Test data created with:
  38. * R: svd(A)
  39. * Python: scipy.linalg.svd(A)
  40. * @return array
  41. */
  42. public function dataProviderForSVD(): array
  43. {
  44. return [
  45. [
  46. [
  47. [1, 0, 0, 0, 2],
  48. [0, 0, 3, 0, 0],
  49. [0, 0, 0, 0, 0],
  50. [0, 2, 0, 0, 0],
  51. ],
  52. [ // Technically, the order of the diagonal elements can be in any order
  53. 'S' => [
  54. [3, 0, 0, 0, 0],
  55. [0, sqrt(5), 0, 0, 0],
  56. [0, 0, 2, 0, 0],
  57. [0, 0, 0, 0, 0],
  58. ],
  59. ],
  60. ],
  61. [
  62. [
  63. [8, -6, 2],
  64. [-6, 7, -4],
  65. [2, -4, -3],
  66. ],
  67. [
  68. 'S' => [
  69. [14.528807, 0, 0],
  70. [0, 4.404176, 0],
  71. [0, 0, 1.875369],
  72. ],
  73. ],
  74. ],
  75. [
  76. [
  77. [1, 2],
  78. [3, 4],
  79. [5, 6],
  80. ],
  81. [
  82. 'S' => [
  83. [9.52551809, 0],
  84. [0, 0.51430058],
  85. [0, 0],
  86. ],
  87. ],
  88. ],
  89. [
  90. [[3]],
  91. [
  92. 'S' => [[3]],
  93. ],
  94. ],
  95. [
  96. [[0]],
  97. [
  98. 'S' => [[0]],
  99. ],
  100. ],
  101. [
  102. [[1]],
  103. [
  104. 'S' => [[1]],
  105. ],
  106. ],
  107. [
  108. [
  109. [1, 2, 3],
  110. [4, 5, 6],
  111. [7, 8, 9],
  112. ],
  113. [
  114. 'S' => [
  115. [1.684810e+01, 0, 0],
  116. [0, 1.068370e+00, 0],
  117. [0, 0, 4.418425e-16],
  118. ],
  119. ],
  120. ],
  121. [
  122. [
  123. [2, 2, 2],
  124. [2, 2, 2],
  125. [2, 2, 2],
  126. ],
  127. [
  128. 'S' => [
  129. [6, 0, 0],
  130. [0, 0, 0],
  131. [0, 0, 0],
  132. ],
  133. ],
  134. ],
  135. [
  136. [
  137. [-2, -2, -2],
  138. [-2, -2, -2],
  139. [-2, -2, -2],
  140. ],
  141. [
  142. 'S' => [
  143. [6, 0, 0],
  144. [0, 0, 0],
  145. [0, 0, 0],
  146. ],
  147. ],
  148. ],
  149. [
  150. [
  151. [1, 2, 3],
  152. [0, 4, 5],
  153. [0, 0, 6],
  154. ],
  155. [
  156. 'S' => [
  157. [9.0125424, 0, 0],
  158. [0, 2.9974695, 0],
  159. [0, 0, 0.8884012],
  160. ],
  161. ],
  162. ],
  163. [
  164. [
  165. [1, 0, 0],
  166. [2, 3, 0],
  167. [4, 5, 6],
  168. ],
  169. [
  170. 'S' => [
  171. [9.2000960, 0, 0],
  172. [0, 2.3843001, 0],
  173. [0, 0, 0.8205768],
  174. ],
  175. ],
  176. ],
  177. // Singular
  178. [
  179. [
  180. [1, 0],
  181. [0, 0],
  182. ],
  183. [
  184. 'S' => [
  185. [1, 0],
  186. [0, 0],
  187. ],
  188. ],
  189. ],
  190. // Singular
  191. [
  192. [
  193. [1, 0, 1],
  194. [0, 1, -1],
  195. [0, 0, 0],
  196. ],
  197. [
  198. 'S' => [
  199. [1.732051, 0, 0],
  200. [0, 1.000000, 0],
  201. [0, 0, 0.0],
  202. ],
  203. ],
  204. ],
  205. // Idempotent
  206. [
  207. [
  208. [1, 0, 0],
  209. [0, 0, 0],
  210. [0, 0, 1],
  211. ],
  212. [
  213. 'S' => [
  214. [1, 0, 0],
  215. [0, 1, 0],
  216. [0, 0, 0],
  217. ],
  218. ],
  219. ],
  220. // Idempotent
  221. [
  222. [
  223. [2, -2, -4],
  224. [-1, 3, 4],
  225. [1, -2, -3],
  226. ],
  227. [
  228. 'S' => [
  229. [7.937254, 0, 0],
  230. [0, 1, 0],
  231. [0, 0, 2.198569e-17],
  232. ],
  233. ],
  234. ],
  235. // Floating point
  236. [
  237. [
  238. [2.5, 6.3, 9.1],
  239. [-1.4, 3.0, 4.45],
  240. [1.01, 8.5, -3.334],
  241. ],
  242. [
  243. 'S' => [
  244. [12.786005, 0, 0],
  245. [0, 8.663327, 0],
  246. [0, 0, 2.315812],
  247. ],
  248. ],
  249. ],
  250. ];
  251. }
  252. /**
  253. * @return array
  254. */
  255. public function dataProviderForLesserRankSVD(): array
  256. {
  257. return [
  258. [
  259. [
  260. [1, 1, 1, 1, 1],
  261. [1, 1, 1, 1, 1],
  262. [1, 1, 1, 1, 1],
  263. ],
  264. [
  265. 'S' => [
  266. [3.872983, 0, 0, 0, 0],
  267. [0, 1.812987e-16, 0, 0, 0],
  268. [0, 0, 1.509615e-32, 0, 0],
  269. ],
  270. ],
  271. ]
  272. ];
  273. }
  274. /**
  275. * @test SVD properties
  276. * @dataProvider dataProviderForSVD
  277. * @param array $A
  278. * @throws \Exception
  279. */
  280. public function testSVDProperties(array $A)
  281. {
  282. // Given
  283. $A = MatrixFactory::createNumeric($A);
  284. // When
  285. $svd = $A->svd();
  286. // And
  287. $U = $svd->U;
  288. $S = $svd->S;
  289. $V = $svd->V;
  290. $D = $svd->D;
  291. // Then U and V are orthogonal
  292. $this->assertTrue($svd->getU()->isOrthogonal());
  293. $this->assertTrue($svd->getV()->isOrthogonal());
  294. // And S is rectangular diagonal with non-negative real numbers on the diagonal
  295. $this->assertTrue($S->isRectangularDiagonal());
  296. foreach ($S->getDiagonalElements() as $diagonalElement) {
  297. $this->assertTrue($diagonalElement >= 0);
  298. }
  299. // And D contains the diagonal elements of S
  300. $this->assertEqualsWithDelta($D->getVector(), $S->getDiagonalElements(), 0.00001, '');
  301. // And the number of non-zero singular values is equal to the rank of M
  302. $nonZeroSingularValues = array_filter(
  303. $D->getVector(),
  304. function ($singularValue) {
  305. return Support::isNotZero($singularValue);
  306. }
  307. );
  308. $this->assertEquals($A->rank(), count($nonZeroSingularValues));
  309. // And UUᵀ = I
  310. $this->assertEqualsWithDelta(MatrixFactory::identity($U->getM())->getMatrix(), $U->multiply($U->transpose())->getMatrix(), 0.00001);
  311. // And VVᵀ = I
  312. $this->assertEqualsWithDelta(MatrixFactory::identity($V->getM())->getMatrix(), $V->multiply($V->transpose())->getMatrix(), 0.00001);
  313. }
  314. /**
  315. * @test SVD properties of less than full rank matrices
  316. * @dataProvider dataProviderForLesserRankSVD
  317. * @param array $A
  318. * @throws \Exception
  319. */
  320. public function testLesserRankSVDProperties(array $A)
  321. {
  322. // Given
  323. $A = MatrixFactory::createNumeric($A);
  324. // When
  325. $svd = $A->svd();
  326. // Then
  327. $this->assertTrue($svd->getU()->isOrthogonal());
  328. $this->assertTrue($svd->getS()->isRectangularDiagonal());
  329. $this->assertEqualsWithDelta($svd->D->getVector(), $svd->getS()->getDiagonalElements(), 0.00001, '');
  330. }
  331. /**
  332. * @test SVD get properties
  333. */
  334. public function testSVDGetProperties()
  335. {
  336. // Given
  337. $A = MatrixFactory::createNumeric([
  338. [4, 1, -1],
  339. [1, 2, 1],
  340. [-1, 1, 2],
  341. ]);
  342. $svd = $A->svd();
  343. // When
  344. $S = $svd->S;
  345. $V = $svd->V;
  346. $D = $svd->D;
  347. $U = $svd->U;
  348. // Then
  349. $this->assertInstanceOf(NumericMatrix::class, $S);
  350. $this->assertInstanceOf(NumericMatrix::class, $V);
  351. $this->assertInstanceOf(NumericMatrix::class, $U);
  352. $this->assertInstanceOf(Vector::class, $D);
  353. // And
  354. $this->assertEquals($svd->getS(), $S);
  355. $this->assertEquals($svd->getV(), $V);
  356. $this->assertEquals($svd->getD(), $D);
  357. $this->assertEquals($svd->getU(), $U);
  358. }
  359. /**
  360. * @test SVD invalid property
  361. */
  362. public function testSVDInvalidProperty()
  363. {
  364. // Given
  365. $A = MatrixFactory::createNumeric([
  366. [4, 1, -1],
  367. [1, 2, 1],
  368. [-1, 1, 2],
  369. ]);
  370. $svd = $A->svd();
  371. // Then
  372. $this->expectException(Exception\MathException::class);
  373. // When
  374. $doesNotExist = $svd->doesNotExist;
  375. }
  376. }