<?php

namespace MathPHP\Tests\Statistics;

use MathPHP\Statistics\Divergence;
use MathPHP\Exception;

class DivergenceTest extends \PHPUnit\Framework\TestCase
{
    /**
     * @test         kullbackLeibler
     * @dataProvider dataProviderForKullbackLeibler
     * @param        array $p
     * @param        array $q
     * @param        float $expected
     */
    public function testKullbackLeibler(array $p, array $q, float $expected)
    {
        // When
        $divergence = Divergence::kullbackLeibler($p, $q);

        // Then
        $this->assertEqualsWithDelta($expected, $divergence, 0.0001);
    }

    /**
     * Test data created using Python's scipy.stats.entropy
     * @return array [p, q, divergence]
     */
    public function dataProviderForKullbackLeibler(): array
    {
        return [
            [
                [0.5, 0.5],
                [0.75, 0.25],
                0.14384103622589045,
            ],
            [
                [0.75, 0.25],
                [0.5, 0.5],
                0.13081203594113694,
            ],
            [
                [0.2, 0.5, 0.3],
                [0.1, 0.4, 0.5],
                0.096953524639296684,
            ],
            [
                [0.4, 0.6],
                [0.3, 0.7],
                0.022582421084357374,
            ],
            [
                [0.9, 0.1],
                [0.1, 0.9],
                1.7577796618689758,
            ],
        ];
    }
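
    /**
     * Cross-check against the definition Dkl(P‖Q) = Σ p[i] ln(p[i] / q[i]).
     * A minimal sketch, assuming the library computes the divergence in nats
     * (natural log), which the scipy-generated fixtures above are consistent
     * with. The fixtures contain no zero probabilities, so no 0·log 0 guard
     * is needed here.
     * @dataProvider dataProviderForKullbackLeibler
     * @param array $p
     * @param array $q
     */
    public function testKullbackLeiblerMatchesDirectFormula(array $p, array $q)
    {
        // Given: KL divergence computed directly from its definition (log is natural log in PHP)
        $direct = 0.0;
        foreach ($p as $i => $pi) {
            $direct += $pi * log($pi / $q[$i]);
        }

        // When
        $divergence = Divergence::kullbackLeibler($p, $q);

        // Then
        $this->assertEqualsWithDelta($direct, $divergence, 0.0001);
    }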

    /**
     * @test kullbackLeibler when the arrays are different lengths
     */
    public function testKullbackLeiblerExceptionArraysDifferentLength()
    {
        // Given
        $p = [0.4, 0.5, 0.1];
        $q = [0.2, 0.8];

        // Then
        $this->expectException(Exception\BadDataException::class);

        // When
        Divergence::kullbackLeibler($p, $q);
    }

    /**
     * @test kullbackLeibler when the probabilities do not add up to one
     */
    public function testKullbackLeiblerExceptionNotProbabilityDistributionThatAddsUpToOne()
    {
        // Given
        $p = [0.2, 0.2, 0.1];
        $q = [0.2, 0.4, 0.6];

        // Then
        $this->expectException(Exception\BadDataException::class);

        // When
        Divergence::kullbackLeibler($p, $q);
    }

    /**
     * @test         jensenShannon
     * @dataProvider dataProviderForJensenShannonDivergence
     * @param        array $p
     * @param        array $q
     * @param        float $expected
     */
    public function testJensenShannonDivergence(array $p, array $q, float $expected)
    {
        // When
        $divergence = Divergence::jensenShannon($p, $q);

        // Then
        $this->assertEqualsWithDelta($expected, $divergence, 0.0001);
    }

    /**
     * Test data created with Python's numpy/scipy where p and q are numpy arrays:
     *     def jsd(p, q):
     *         M = (p + q) / 2
     *         return (scipy.stats.entropy(p, M) + scipy.stats.entropy(q, M)) / 2
     * @return array [p, q, divergence]
     */
    public function dataProviderForJensenShannonDivergence(): array
    {
        return [
            [
                [0.4, 0.6],
                [0.5, 0.5],
                0.0050593899289876343,
            ],
            [
                [0.1, 0.2, 0.2, 0.2, 0.2, 0.1],
                [0.0, 0.1, 0.4, 0.4, 0.1, 0.0],
                0.12028442909461383,
            ],
            [
                [0.25, 0.5, 0.25],
                [0.5, 0.3, 0.2],
                0.035262717451799902,
            ],
            [
                [0.5, 0.3, 0.2],
                [0.25, 0.5, 0.25],
                0.035262717451799902,
            ],
        ];
    }
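
    /**
     * Cross-check against the definition in the data provider's docblock:
     * JSD(P‖Q) = ½ Dkl(P‖M) + ½ Dkl(Q‖M) with M = (P + Q) / 2. A minimal
     * sketch, assuming the natural log and the usual 0·log 0 = 0 convention
     * (needed because one fixture contains zero probabilities), both of which
     * the scipy-generated expected values are consistent with.
     * @dataProvider dataProviderForJensenShannonDivergence
     * @param array $p
     * @param array $q
     */
    public function testJensenShannonMatchesDefinition(array $p, array $q)
    {
        // Given: the two KL terms against the mixture M, computed directly,
        // skipping zero-probability terms per the 0·log 0 = 0 convention
        $klPM = 0.0;
        $klQM = 0.0;
        foreach ($p as $i => $pi) {
            $m = ($pi + $q[$i]) / 2;
            if ($pi > 0) {
                $klPM += $pi * log($pi / $m);
            }
            if ($q[$i] > 0) {
                $klQM += $q[$i] * log($q[$i] / $m);
            }
        }
        $direct = ($klPM + $klQM) / 2;

        // When
        $divergence = Divergence::jensenShannon($p, $q);

        // Then
        $this->assertEqualsWithDelta($direct, $divergence, 0.0001);
    }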

    /**
     * @test jensenShannon when the arrays are different lengths
     */
    public function testJensenShannonDivergenceExceptionArraysDifferentLength()
    {
        // Given
        $p = [0.4, 0.5, 0.1];
        $q = [0.2, 0.8];

        // Then
        $this->expectException(Exception\BadDataException::class);

        // When
        Divergence::jensenShannon($p, $q);
    }

    /**
     * @test jensenShannon when the probabilities do not add up to one
     */
    public function testJensenShannonDivergenceExceptionNotProbabilityDistributionThatAddsUpToOne()
    {
        // Given
        $p = [0.2, 0.2, 0.1];
        $q = [0.2, 0.4, 0.6];

        // Then
        $this->expectException(Exception\BadDataException::class);

        // When
        Divergence::jensenShannon($p, $q);
    }
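
    /**
     * The Jensen-Shannon divergence is symmetric in its arguments, as the two
     * mirrored rows of the data provider illustrate. A minimal sketch verifying
     * jensenShannon(p, q) = jensenShannon(q, p) across all fixtures.
     * @dataProvider dataProviderForJensenShannonDivergence
     * @param array $p
     * @param array $q
     */
    public function testJensenShannonIsSymmetric(array $p, array $q)
    {
        // When
        $forward = Divergence::jensenShannon($p, $q);
        $reverse = Divergence::jensenShannon($q, $p);

        // Then
        $this->assertEqualsWithDelta($forward, $reverse, 0.0000001);
    }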
}