changeset 749:10dbab1e871d dev

modifications in samples and distributions
author Nicolas Saunier <nicolas.saunier@polymtl.ca>
date Tue, 20 Oct 2015 00:03:25 -0400
parents d45ab817ee11
children 6049e9b6902c
files python/utils.py
diffstat 1 files changed, 22 insertions(+), 5 deletions(-) [+]
line wrap: on
line diff
--- a/python/utils.py	Fri Oct 02 11:29:43 2015 -0400
+++ b/python/utils.py	Tue Oct 20 00:03:25 2015 -0400
@@ -74,7 +74,7 @@
         result += ((e-o)*(e-o))/e
     return result
 
-class EmpiricalDistribution(object):
+class DistributionSample(object):
     def nSamples(self):
         return sum(self.counts)
 
@@ -86,9 +86,8 @@
         counts /= float(len(sample))
     return xaxis, counts
 
-class EmpiricalDiscreteDistribution(EmpiricalDistribution):
-    '''Class to represent a sample of a distribution for a discrete random variable
-    '''
+class DiscreteDistributionSample(DistributionSample):
+    '''Class to represent a sample of a distribution for a discrete random variable'''
     def __init__(self, categories, counts):
         self.categories = categories
         self.counts = counts
@@ -113,7 +112,7 @@
         refCounts = [r*self.nSamples() for r in refProba]
         return refCounts, refProba
 
-class EmpiricalContinuousDistribution(EmpiricalDistribution):
+class ContinuousDistributionSample(DistributionSample):
     '''Class to represent a sample of a distribution for a continuous random variable
     with the number of observations for each interval
     intervals (categories variable) are defined by their left limits, the last one being the right limit
@@ -123,6 +122,24 @@
         self.categories = categories
         self.counts = counts
 
+    @staticmethod
+    def generate(sample, categories):
+        if min(sample) < min(categories):
+            print('Sample has lower min than proposed categories ({}, {})'.format(min(sample), min(categories)))
+        if max(sample) > max(categories):
+            print('Sample has higher max than proposed categories ({}, {})'.format(max(sample), max(categories)))
+        dist = ContinuousDistributionSample(sorted(categories), [0]*(len(categories)-1))
+        for s in sample:
+            i = 0
+            while  i<len(dist.categories) and dist.categories[i] <= s:
+                i += 1
+            if i <= len(dist.counts):
+                dist.counts[i-1] += 1
+                #print('{} in {} {}'.format(s, dist.categories[i-1], dist.categories[i]))
+            else:
+                print('Element {} is not in the categories'.format(s))
+        return dist
+
     def mean(self):
         result = 0.
         for i in range(len(self.counts)-1):