MtCarsPLS2ScaleTrueTest.php 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300
  1. <?php
  2. namespace MathPHP\Tests\Statistics\Multivariate\PLS;
  3. use MathPHP\Exception;
  4. use MathPHP\LinearAlgebra\Matrix;
  5. use MathPHP\LinearAlgebra\MatrixFactory;
  6. use MathPHP\SampleData;
  7. use MathPHP\Statistics\Multivariate\PLS;
  class MtCarsPLS2ScaleTrueTest extends \PHPUnit\Framework\TestCase
  {
      /** Absolute tolerance used for every comparison against the R reference values */
      private const DELTA = .00001;

      /** @var PLS Two-component PLS model fitted with scaling enabled, shared by all tests */
      private static $pls;

      /** @var Matrix Predictor matrix, R: mtcars[,c(2:3, 5:7, 10:11)] */
      private static $X;

      /** @var Matrix Response matrix (mpg and hp), R: mtcars[,c(1,4)] */
      private static $Y;

      /**
       * Build X and Y from the MtCars sample data and fit the shared PLS model once.
       *
       * R code for expected values:
       * library(chemometrics)
       * X = mtcars[,c(2:3, 5:7, 10:11)]
       * Y = mtcars[,c(1,4)]
       * pls.model = pls2_nipals(X, Y, 2, scale=TRUE)
       *
       * @throws Exception\MathException
       */
      public static function setUpBeforeClass(): void
      {
          $mtCars = new SampleData\MtCars();

          // Remove the categorical variables. The higher index is dropped first so
          // that column 7 still refers to the same variable when it is removed.
          $continuous = MatrixFactory::create($mtCars->getData())
              ->columnExclude(8)
              ->columnExclude(7);

          // X: every remaining column except mpg (0) and hp (3).
          self::$X = $continuous->columnExclude(3)->columnExclude(0);

          // Y: mpg and hp (columns 0 and 3). After dropping columns 2 and 1, hp
          // shifts to index 1, so the first two columns are exactly mpg and hp.
          self::$Y = $continuous
              ->columnExclude(2)
              ->columnExclude(1)
              ->submatrix(0, 0, $continuous->getM() - 1, 1);

          self::$pls = new PLS(self::$X, self::$Y, 2, true);
      }

      /**
       * @test Construction
       * @throws Exception\MathException
       */
      public function testConstruction(): void
      {
          // When
          $pls = new PLS(self::$X, self::$Y, 2, true);

          // Then
          $this->assertInstanceOf(PLS::class, $pls);
      }

      /**
       * @test Construction error - row mismatch
       * @throws Exception\MathException
       */
      public function testConstructionFailureXAndYRowMismatch(): void
      {
          // Given a Y with one row fewer than X
          $Y = self::$Y->rowExclude(0);

          // Then
          $this->expectException(Exception\BadDataException::class);

          // When
          new PLS(self::$X, $Y, 2, true);
      }

      /**
       * @test The class returns the correct values for B (regression coefficients)
       *
       * R code for expected values:
       * pls.model$B
       */
      public function testB(): void
      {
          // Given
          $expected = [
              [-0.2143731,  0.22289146],
              [-0.2105791,  0.20363413],
              [ 0.1588566, -0.05241863],
              [-0.2034550,  0.14419053],
              [ 0.1246612, -0.26901292],
              [ 0.1007163,  0.07176686],
              [-0.1473158,  0.29083061],
          ];

          // When
          $B = self::$pls->getCoefficients()->getMatrix();

          // Then
          $this->assertEqualsWithDelta($expected, $B, self::DELTA);
      }

      /**
       * @test The class returns the correct values for C (exposed as getYLoadings)
       *
       * R code for expected values:
       * pls.model$C
       */
      public function testC(): void
      {
          // Given
          $expected = [
              [ 0.454770, 0.03737499],
              [-0.430135, 0.25598916],
          ];

          // When
          $C = self::$pls->getYLoadings()->getMatrix();

          // Then
          $this->assertEqualsWithDelta($expected, $C, self::DELTA);
      }

      /**
       * @test The class returns the correct values for P (exposed as getProjection)
       *
       * R code for expected values:
       * pls.model$P
       */
      public function testP(): void
      {
          // Given
          $expected = [
              [-0.4830909,  0.008542336],
              [-0.4731801, -0.101398021],
              [ 0.3706942,  0.362571565],
              [-0.4418500, -0.185442628],
              [ 0.2779994, -0.501504936],
              [ 0.2450876,  0.576659047],
              [-0.2948947,  0.495757659],
          ];

          // When
          $P = self::$pls->getProjection()->getMatrix();

          // Then
          $this->assertEqualsWithDelta($expected, $P, self::DELTA);
      }

      /**
       * @test The class returns the correct values for T (exposed as getXScores)
       *
       * R code for expected values:
       * pls.model$T
       */
      public function testT(): void
      {
          // Given
          $expected = [
              [ 0.3204847,  1.30386443],
              [ 0.3067005,  1.0981707],
              [ 2.2145883, -0.35041917],
              [ 0.169134,  -1.9075156],
              [-1.4608329, -0.78730905],
              [ 0.1298623, -2.38042406],
              [-2.1455482,  0.21474065],
              [ 1.6070201, -0.67486532],
              [ 2.2992852, -1.3660906],
              [ 0.2666728,  0.64453581],
              [ 0.3730258,  0.47038762],
              [-1.632483,  -0.74293298],
              [-1.4463028, -0.74344181],
              [-1.3975669, -0.86800244],
              [-3.099056,  -0.85271805],
              [-3.1174811, -0.7932455],
              [-2.9400458, -0.521798],
              [ 2.668381,  -0.43889341],
              [ 3.055,      0.79125128],
              [ 3.0191218, -0.41494088],
              [ 2.026571,  -1.65106247],
              [-1.6002063, -1.00743915],
              [-1.1991955, -0.8746524],
              [-1.9911201,  0.61920095],
              [-1.8270286, -0.90357751],
              [ 2.6837019, -0.22856777],
              [ 2.2680306,  1.69918659],
              [ 2.2787298,  1.32962432],
              [-1.0916947,  2.85867373],
              [-0.2461754,  2.80033238],
              [-2.3811881,  3.61645679],
              [ 1.8896156,  0.06147091],
          ];

          // When
          $T = self::$pls->getXScores()->getMatrix();

          // Then
          $this->assertEqualsWithDelta($expected, $T, self::DELTA);
      }

      /**
       * @test The class returns the correct values for U (exposed as getYScores)
       *
       * R code for expected values:
       * pls.model$U
       */
      public function testU(): void
      {
          // Given
          $expected = [
              [ 0.29878003, -0.1196459],
              [ 0.29878003, -0.12014881],
              [ 0.54125196, -0.10284985],
              [ 0.32896247, -0.12268734],
              [-0.28255201,  0.04378714],
              [ 0.11132526, -0.16325269],
              [-1.05370979,  0.2528747],
              [ 0.85646285, -0.23083751],
              [ 0.52870479, -0.09229244],
              [ 0.08140245, -0.08423422],
              [-0.02423609, -0.08903582],
              [-0.48746897,  0.04192974],
              [-0.41955848,  0.05430363],
              [-0.57801629,  0.04305899],
              [-1.0970452,   0.04455559],
              [-1.15978105,  0.08121978],
              [-0.92942358,  0.17036369],
              [ 1.43501732, -0.12756913],
              [ 1.37193531, -0.1781371],
              [ 1.55447506, -0.10920417],
              [ 0.41806469, -0.10283672],
              [-0.36717192, -0.07448308],
              [-0.38980875, -0.06171276],
              [-1.12916589,  0.25230763],
              [-0.24482396,  0.03352729],
              [ 1.0501912,  -0.15863685],
              [ 0.79525865, -0.08852311],
              [ 0.98924664,  0.02129312],
              [-1.05972375,  0.37156526],
              [-0.20709591,  0.09430472],
              [-1.56551315,  0.58464612],
              [ 0.33523606, -0.06364992],
          ];

          // When
          $U = self::$pls->getYScores()->getMatrix();

          // Then
          $this->assertEqualsWithDelta($expected, $U, self::DELTA);
      }

      /**
       * @test The class returns the correct values for W (exposed as getXLoadings)
       *
       * R code for expected values:
       * pls.model$W
       */
      public function testW(): void
      {
          // Given
          $expected = [
              [-0.4770668,  0.01413703],
              [-0.4643040, -0.03817455],
              [ 0.3217142,  0.37286624],
              [-0.4337710, -0.21556426],
              [ 0.3167445, -0.48216394],
              [ 0.1743495,  0.59339427],
              [-0.3666701,  0.47775186],
          ];

          // When
          $W = self::$pls->getXLoadings()->getMatrix();

          // Then
          $this->assertEqualsWithDelta($expected, $W, self::DELTA);
      }

      /**
       * @test predict Y values from X
       * @dataProvider dataProviderForRegression
       *
       * R code for expected values:
       * ones = matrix(1L, nrow = dim(X)[1], ncol = 1)
       * scale(X) %*% pls.model$B %*% diag(apply(Y, 2, sd)) + ones %*% colMeans(Y)
       *
       * @param array $X        rows of unscaled predictor values, one value per X column
       * @param array $expected predicted [mpg, hp] for each input row
       * @throws Exception\MathException
       */
      public function testRegression(array $X, array $expected): void
      {
          // Given
          $input = MatrixFactory::create($X);

          // When
          $actual = self::$pls->predict($input)->getMatrix();

          // Then
          $this->assertEqualsWithDelta($expected, $actual, self::DELTA);
      }

      /**
       * Cases for testRegression: [X rows, expected [mpg, hp] predictions].
       *
       * Static so the provider is valid for PHPUnit 10+ (non-static providers are deprecated there).
       *
       * @return array[]
       */
      public static function dataProviderForRegression(): array
      {
          return [
              [
                  [[6, 160, 3.9, 2.62, 16.46, 4, 4]],
                  [[21.26274, 160.12058]],
              ],
              [
                  [[6, 160, 3.9, 2.875, 17.02, 4, 4]],
                  [[21.17862, 156.91689]],
              ],
          ];
      }

      /**
       * @test predict error if the input X columns do not match
       * @throws Exception\MathException
       */
      public function testPredictDataColumnMisMatch(): void
      {
          // Given an input row with fewer columns than the model's X
          $X = MatrixFactory::create([[6, 160, 3.9, 2.62, 16.46]]);

          // Then
          $this->expectException(Exception\BadDataException::class);

          // When
          self::$pls->predict($X);
      }
  }