From 56bc5aa305b458d82fc308895c1c793cf8c3f639 Mon Sep 17 00:00:00 2001
From: David Ray
Date: Sat, 10 Dec 2016 06:26:03 -0600
Subject: [PATCH 01/52] Update README.md
---
README.md | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/README.md b/README.md
index e3f91cac..ed39c90b 100644
--- a/README.md
+++ b/README.md
@@ -75,7 +75,7 @@ See the blog: [Join the Cogmission](http://www.cogmission.ai)
| Core Algorithm | NuPIC Date |HTM.Java Date | Latest NuPIC SHA | Latest HTM.Java SHA | Status|
| --------------- |:-------------:|:------------:|:----------------:|:-------------------:|:-----:|
-| SpatialPooler | 2016-10-05 | 2016-10-07 |[commit](https://github.com/numenta/nupic/commit/7e77ecba4ffdd4991cfd87972de6211101e6661e)|[commit](https://github.com/numenta/htm.java/commit/2cdcee1fcc5f6c18c2c48b4b553c49879c1256bb#diff-22f96ea06fd0c2b3593c755cbccf0a8b)| Sync'd*
+| SpatialPooler | 2016-10-05 | 2016-10-07 |[commit](https://github.com/numenta/nupic/commit/7e77ecba4ffdd4991cfd87972de6211101e6661e)|[commit](https://github.com/numenta/htm.java/commit/2cdcee1fcc5f6c18c2c48b4b553c49879c1256bb#diff-22f96ea06fd0c2b3593c755cbccf0a8b)| [Pending NuPIC #3411*](https://github.com/numenta/nupic/pull/3411)
| TemporalMemory | 2016-09-23 | 2016-10-13 |[commit](https://github.com/numenta/nupic/commit/1036f25e7223471d72cebc536d6734f78d37b6c7)|[commit](https://github.com/numenta/htm.java/commit/7f4d8f2e2c910dd662909442546516e36adfc7cc)| Sync'd*
\* May be one of: "Sync'd" or "Behind". "Behind" expresses a temporary lapse in synchronization while devs are implementing new changes.
From 1f6bcdc25285812615f516416b677b736994f69e Mon Sep 17 00:00:00 2001
From: David Ray
Date: Fri, 23 Dec 2016 12:42:15 -0600
Subject: [PATCH 02/52] Update README.md
---
README.md | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/README.md b/README.md
index ed39c90b..67a74b79 100644
--- a/README.md
+++ b/README.md
@@ -75,7 +75,7 @@ See the blog: [Join the Cogmission](http://www.cogmission.ai)
| Core Algorithm | NuPIC Date |HTM.Java Date | Latest NuPIC SHA | Latest HTM.Java SHA | Status|
| --------------- |:-------------:|:------------:|:----------------:|:-------------------:|:-----:|
-| SpatialPooler | 2016-10-05 | 2016-10-07 |[commit](https://github.com/numenta/nupic/commit/7e77ecba4ffdd4991cfd87972de6211101e6661e)|[commit](https://github.com/numenta/htm.java/commit/2cdcee1fcc5f6c18c2c48b4b553c49879c1256bb#diff-22f96ea06fd0c2b3593c755cbccf0a8b)| [Pending NuPIC #3411*](https://github.com/numenta/nupic/pull/3411)
+| SpatialPooler | 2016-10-05 | 2016-10-07 |[commit](https://github.com/numenta/nupic/commit/5c3edead9526d3b5fb6a4f37ad9d38cdcf32f5ff)|[commit](https://github.com/numenta/htm.java/commit/2cdcee1fcc5f6c18c2c48b4b553c49879c1256bb#diff-22f96ea06fd0c2b3593c755cbccf0a8b)| [*Behind NuPIC Merge #3411](https://github.com/numenta/nupic/pull/3411)
| TemporalMemory | 2016-09-23 | 2016-10-13 |[commit](https://github.com/numenta/nupic/commit/1036f25e7223471d72cebc536d6734f78d37b6c7)|[commit](https://github.com/numenta/htm.java/commit/7f4d8f2e2c910dd662909442546516e36adfc7cc)| Sync'd*
\* May be one of: "Sync'd" or "Behind". "Behind" expresses a temporary lapse in synchronization while devs are implementing new changes.
From f299748a5aa3e6f8b75f6ad9d7fdc3aa8686235a Mon Sep 17 00:00:00 2001
From: David Ray
Date: Fri, 23 Dec 2016 12:43:46 -0600
Subject: [PATCH 03/52] Update README.md
---
README.md | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/README.md b/README.md
index 67a74b79..61bdf54a 100644
--- a/README.md
+++ b/README.md
@@ -75,7 +75,7 @@ See the blog: [Join the Cogmission](http://www.cogmission.ai)
| Core Algorithm | NuPIC Date |HTM.Java Date | Latest NuPIC SHA | Latest HTM.Java SHA | Status|
| --------------- |:-------------:|:------------:|:----------------:|:-------------------:|:-----:|
-| SpatialPooler | 2016-10-05 | 2016-10-07 |[commit](https://github.com/numenta/nupic/commit/5c3edead9526d3b5fb6a4f37ad9d38cdcf32f5ff)|[commit](https://github.com/numenta/htm.java/commit/2cdcee1fcc5f6c18c2c48b4b553c49879c1256bb#diff-22f96ea06fd0c2b3593c755cbccf0a8b)| [*Behind NuPIC Merge #3411](https://github.com/numenta/nupic/pull/3411)
+| SpatialPooler | 2016-12-11 | 2016-10-07 |[commit](https://github.com/numenta/nupic/commit/5c3edead9526d3b5fb6a4f37ad9d38cdcf32f5ff)|[commit](https://github.com/numenta/htm.java/commit/2cdcee1fcc5f6c18c2c48b4b553c49879c1256bb#diff-22f96ea06fd0c2b3593c755cbccf0a8b)| [*Behind NuPIC Merge #3411](https://github.com/numenta/nupic/pull/3411)
| TemporalMemory | 2016-09-23 | 2016-10-13 |[commit](https://github.com/numenta/nupic/commit/1036f25e7223471d72cebc536d6734f78d37b6c7)|[commit](https://github.com/numenta/htm.java/commit/7f4d8f2e2c910dd662909442546516e36adfc7cc)| Sync'd*
\* May be one of: "Sync'd" or "Behind". "Behind" expresses a temporary lapse in synchronization while devs are implementing new changes.
From df92d0df30708e0af80861714c606112e1903502 Mon Sep 17 00:00:00 2001
From: David Ray
Date: Wed, 8 Feb 2017 09:51:08 -0600
Subject: [PATCH 04/52] Update README.md
---
README.md | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/README.md b/README.md
index 61bdf54a..8b6acabe 100644
--- a/README.md
+++ b/README.md
@@ -1,4 +1,4 @@
-# 
+# 
# htm.java
From 52bc9f99068587ad45a9c8a916cadc189eedaba3 Mon Sep 17 00:00:00 2001
From: David Ray
Date: Wed, 8 Feb 2017 09:55:14 -0600
Subject: [PATCH 05/52] Update README.md
---
README.md | 4 +++-
1 file changed, 3 insertions(+), 1 deletion(-)
diff --git a/README.md b/README.md
index 8b6acabe..0e838b73 100644
--- a/README.md
+++ b/README.md
@@ -3,7 +3,9 @@
-[](http://cogmission.ai) [](http://numenta.com/#hero) [](https://github.com/numenta/htm.java-examples) [](https://travis-ci.org/numenta/htm.java) [](https://coveralls.io/github/numenta/htm.java?branch=master) [](https://maven-badges.herokuapp.com/maven-central/org.numenta/htm.java) [![][license img]][license] [![docs-badge][]][docs] [](http://cogmission.ai)
+
+[](http://numenta.com/#hero) [](https://github.com/numenta/htm.java-examples) [](https://travis-ci.org/numenta/htm.java) [](https://coveralls.io/github/numenta/htm.java?branch=master) [](https://maven-badges.herokuapp.com/maven-central/org.numenta/htm.java) [![][license img]][license] [![docs-badge][]][docs] [](https://gitter.im/numenta/htm.java?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge) [](https://www.openhub.net/p/htm-java)
From 2280ce7d4da1a96cb3b9a49e99f151dc3d1fea27 Mon Sep 17 00:00:00 2001
From: David Ray
Date: Wed, 8 Feb 2017 10:05:52 -0600
Subject: [PATCH 06/52] Update README.md
---
README.md | 4 +---
1 file changed, 1 insertion(+), 3 deletions(-)
diff --git a/README.md b/README.md
index 0e838b73..ff9688ae 100644
--- a/README.md
+++ b/README.md
@@ -3,9 +3,7 @@
-[](http://cogmission.ai)
-
-[](http://numenta.com/#hero) [](https://github.com/numenta/htm.java-examples) [](https://travis-ci.org/numenta/htm.java) [](https://coveralls.io/github/numenta/htm.java?branch=master) [](https://maven-badges.herokuapp.com/maven-central/org.numenta/htm.java) [![][license img]][license] [![docs-badge][]][docs] [](http://cogmission.ai)[](http://numenta.com/#hero) [](https://github.com/numenta/htm.java-examples) [](https://travis-ci.org/numenta/htm.java) [](https://coveralls.io/github/numenta/htm.java?branch=master) [](https://maven-badges.herokuapp.com/maven-central/org.numenta/htm.java) [![][license img]][license] [![docs-badge][]][docs] [](https://gitter.im/numenta/htm.java?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge) [](https://www.openhub.net/p/htm-java)
From 55662fc59e635976f4adfd290bf9d8c221cd8d26 Mon Sep 17 00:00:00 2001
From: David Ray
Date: Wed, 8 Feb 2017 10:07:47 -0600
Subject: [PATCH 07/52] Update README.md
---
README.md | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/README.md b/README.md
index ff9688ae..8b6acabe 100644
--- a/README.md
+++ b/README.md
@@ -3,7 +3,7 @@
-[](http://cogmission.ai)[](http://numenta.com/#hero) [](https://github.com/numenta/htm.java-examples) [](https://travis-ci.org/numenta/htm.java) [](https://coveralls.io/github/numenta/htm.java?branch=master) [](https://maven-badges.herokuapp.com/maven-central/org.numenta/htm.java) [![][license img]][license] [![docs-badge][]][docs] [](http://cogmission.ai) [](http://numenta.com/#hero) [](https://github.com/numenta/htm.java-examples) [](https://travis-ci.org/numenta/htm.java) [](https://coveralls.io/github/numenta/htm.java?branch=master) [](https://maven-badges.herokuapp.com/maven-central/org.numenta/htm.java) [![][license img]][license] [![docs-badge][]][docs] [](https://gitter.im/numenta/htm.java?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge) [](https://www.openhub.net/p/htm-java)
From 3a58776fd03b3b302982799d20a6afe113fc4af5 Mon Sep 17 00:00:00 2001
From: David Ray
Date: Wed, 8 Feb 2017 11:30:45 -0600
Subject: [PATCH 08/52] Update README.md
---
README.md | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/README.md b/README.md
index 8b6acabe..3a73cf68 100644
--- a/README.md
+++ b/README.md
@@ -1,4 +1,4 @@
-# 
+#
# htm.java
From 6c32139d479ed359c7076de553b7d9478c384506 Mon Sep 17 00:00:00 2001
From: Matthew Taylor
Date: Tue, 14 Feb 2017 06:46:32 -0800
Subject: [PATCH 09/52] Updated location of Java docs in readme
---
README.md | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/README.md b/README.md
index 3a73cf68..96ec64fe 100644
--- a/README.md
+++ b/README.md
@@ -51,7 +51,7 @@ _**NOTE: Minimum JavaSE version is 8**_
For a more detailed discussion of htm.java see:
* [htm.java Wiki](https://github.com/numenta/htm.java/wiki)
-* [Java Docs](http://numenta.org/docs/htm.java/)
+* [Java Docs](http://numenta.github.io/htm.java/)
See the [Test Coverage Reports](https://coveralls.io/jobs/4164658) - For more information on where you can contribute! Extend the tests and get your name in bright lights!
From 75407f8fd57f9df1682492a5ec667b14e249fbf4 Mon Sep 17 00:00:00 2001
From: Hopding
Date: Sun, 19 Feb 2017 12:32:04 -0600
Subject: [PATCH 10/52] preliminary stage
---
.../java/org/numenta/nupic/Parameters.java | 5 +-
.../nupic/algorithms/CLAClassifier.java | 2 +-
.../numenta/nupic/algorithms/Classifier.java | 15 +++++
.../nupic/algorithms/SDRClassifier.java | 2 +-
.../java/org/numenta/nupic/network/Layer.java | 61 +++++++++++++------
.../org/numenta/nupic/network/Region.java | 3 +
6 files changed, 65 insertions(+), 23 deletions(-)
create mode 100644 src/main/java/org/numenta/nupic/algorithms/Classifier.java
diff --git a/src/main/java/org/numenta/nupic/Parameters.java b/src/main/java/org/numenta/nupic/Parameters.java
index ba623a86..001f8160 100644
--- a/src/main/java/org/numenta/nupic/Parameters.java
+++ b/src/main/java/org/numenta/nupic/Parameters.java
@@ -32,6 +32,7 @@
import java.util.Random;
import java.util.Set;
+import org.numenta.nupic.algorithms.Classifier;
import org.numenta.nupic.algorithms.SpatialPooler;
import org.numenta.nupic.algorithms.TemporalMemory;
import org.numenta.nupic.model.Cell;
@@ -417,8 +418,8 @@ public static enum KEY {
// Network Layer indicator for auto classifier generation
AUTO_CLASSIFY("hasClassifiers", Boolean.class),
-
-
+ INFERRED_FIELDS("inferredFields", Map.class), // Map Classification compute(int recordNum,
+ Map classification,
+ int[] patternNZ,
+ boolean learn,
+ boolean infer);
+}
diff --git a/src/main/java/org/numenta/nupic/algorithms/SDRClassifier.java b/src/main/java/org/numenta/nupic/algorithms/SDRClassifier.java
index 57a94887..d36cbabd 100644
--- a/src/main/java/org/numenta/nupic/algorithms/SDRClassifier.java
+++ b/src/main/java/org/numenta/nupic/algorithms/SDRClassifier.java
@@ -96,7 +96,7 @@
* @author David Ray
* @author Andrew Dillon
*/
-public class SDRClassifier implements Persistable {
+public class SDRClassifier implements Persistable, Classifier {
private static final long serialVersionUID = 1L;
int verbosity = 0;
diff --git a/src/main/java/org/numenta/nupic/network/Layer.java b/src/main/java/org/numenta/nupic/network/Layer.java
index e8fa3fb2..9a79745e 100644
--- a/src/main/java/org/numenta/nupic/network/Layer.java
+++ b/src/main/java/org/numenta/nupic/network/Layer.java
@@ -1694,25 +1694,36 @@ private void doEncoderBucketMapping(Inference inference, Map enc
// Store the encoding
int[] encoding = inference.getEncoding();
- for(EncoderTuple t : encoderTuples) {
- String name = t.getName();
- Encoder> e = t.getEncoder();
-
- int bucketIdx = -1;
- Object o = encoderInputMap.get(name);
- if(DateTime.class.isAssignableFrom(o.getClass())) {
- bucketIdx = ((DateEncoder)e).getBucketIndices((DateTime)o)[0];
- } else if(Number.class.isAssignableFrom(o.getClass())) {
- bucketIdx = e.getBucketIndices((double)o)[0];
- } else {
- bucketIdx = e.getBucketIndices((String)o)[0];
- }
-
- int offset = t.getOffset();
- int[] tempArray = new int[e.getWidth()];
- System.arraycopy(encoding, offset, tempArray, 0, tempArray.length);
-
- inference.getClassifierInput().put(name, new NamedTuple(new String[] { "name", "inputValue", "bucketIdx", "encoding" }, name, o, bucketIdx, tempArray));
+ // TODO: 21: Looks like this is where the classifierInput(s) are set.
+ // Should probably change this so that instead of adding a mapping for
+ // each encoder, it adds a mapping for each field specified by the user
+ // in the new Parameters KEY.
+// for(EncoderTuple t : encoderTuples) {
+// String name = t.getName();
+// Encoder> e = t.getEncoder();
+//
+// int bucketIdx = -1;
+// Object o = encoderInputMap.get(name);
+// if(DateTime.class.isAssignableFrom(o.getClass())) {
+// bucketIdx = ((DateEncoder)e).getBucketIndices((DateTime)o)[0];
+// } else if(Number.class.isAssignableFrom(o.getClass())) {
+// bucketIdx = e.getBucketIndices((double)o)[0];
+// } else {
+// bucketIdx = e.getBucketIndices((String)o)[0];
+// }
+//
+// int offset = t.getOffset();
+// int[] tempArray = new int[e.getWidth()];
+// System.arraycopy(encoding, offset, tempArray, 0, tempArray.length);
+//
+// inference.getClassifierInput().put(name, new NamedTuple(new String[] { "name", "inputValue", "bucketIdx", "encoding" }, name, o, bucketIdx, tempArray));
+// }
+
+ // Get fields user wants encoded from Parameters
+ Map inferredFields = (Map)params.get(KEY.INFERRED_FIELDS);
+ for(Map.Entry entry : inferredFields.entrySet()) {
+ String name = entry.getKey();
+ EncoderTuple encoderTuple = encoderTuples.get
}
}
@@ -1908,7 +1919,15 @@ private void clearSubscriberObserverLists() {
* @param encoder
* @return
*/
+ // TODO: 21: This creates two parallel arrays, one of encoder's names, and
+ // the other of each encoder's classifier. Returns a NamedTuple making use of
+ // these arrays easier.
NamedTuple makeClassifiers(MultiEncoder encoder) {
+ // TODO: 21: Should be able to inspect new Parameters KEY(s) in here
+ // TODO: 21: to adjust types of Classifiers that are created/used.
+ // Looks like the passed-in MultiEncoder is a single wrapper containing
+ // multiple encoders; one for each field. Right now, a classifier is
+ // created for each of those encoders. However, instead of do
String[] names = new String[encoder.getEncoders(encoder).size()];
CLAClassifier[] ca = new CLAClassifier[names.length];
int i = 0;
@@ -2321,6 +2340,10 @@ public Object get(Object o) {
@Override
public ManualInput call(ManualInput t1) {
+ // TODO: 21: Should only need to change this code to use the
+ // new Classifier interface. But will need to change what is
+ // returned by t1.getClassifierInput() to only pay attention
+ // to fields being classifier based on new Parameters KEY
Map ci = t1.getClassifierInput();
int recordNum = getRecordNum();
for(String key : ci.keySet()) {
diff --git a/src/main/java/org/numenta/nupic/network/Region.java b/src/main/java/org/numenta/nupic/network/Region.java
index f01d2246..797e19b1 100644
--- a/src/main/java/org/numenta/nupic/network/Region.java
+++ b/src/main/java/org/numenta/nupic/network/Region.java
@@ -457,6 +457,9 @@ Region connect(Region inputRegion) {
@Override public void onError(Throwable e) { e.printStackTrace(); }
@SuppressWarnings("unchecked")
@Override public void onNext(Inference i) {
+ // TODO: 21: This is where classifierInput is set. Need to change
+ // it to respect only fields user has specified for classification
+ // with the new Parameters KEY.
localInf.sdr(i.getSDR()).recordNum(i.getRecordNum()).classifierInput(i.getClassifierInput()).layerInput(i.getSDR());
if(i.getSDR().length > 0) {
((Layer)tail).compute(localInf);
From 67620cc49b298cd179661c4c6b36b88ce34ccba2 Mon Sep 17 00:00:00 2001
From: Andrew Dillon
Date: Sun, 19 Feb 2017 14:00:04 -0600
Subject: [PATCH 11/52] Updated doEncoderBucketMapping() method to use
INFERRED_FIELDS param
---
.../java/org/numenta/nupic/network/Layer.java | 51 +++++++++++++++----
1 file changed, 40 insertions(+), 11 deletions(-)
diff --git a/src/main/java/org/numenta/nupic/network/Layer.java b/src/main/java/org/numenta/nupic/network/Layer.java
index 9a79745e..1577f7b3 100644
--- a/src/main/java/org/numenta/nupic/network/Layer.java
+++ b/src/main/java/org/numenta/nupic/network/Layer.java
@@ -31,6 +31,7 @@
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentLinkedQueue;
+import java.util.stream.Collectors;
import org.joda.time.DateTime;
import org.numenta.nupic.FieldMetaType;
@@ -231,7 +232,7 @@ public class Layer implements Persistable {
private boolean hasGenericProcess;
/**
- * List of {@link Encoders} used when storing bucket information see
+ * List of {@link Encoder}s used when storing bucket information see
* {@link #doEncoderBucketMapping(Inference, Map)}
*/
private List encoderTuples;
@@ -1048,7 +1049,7 @@ public void start() {
/**
* Restarts this {@code Layer}
*
- * {@link #restart()} is to be called after a call to {@link #halt()}, to begin
+ * {@link #restart} is to be called after a call to {@link #halt()}, to begin
* processing again. The {@link Network} will continue from where it previously
* left off after the last call to halt().
*
@@ -1180,7 +1181,7 @@ public Set getPredictiveCells() {
}
/**
- * Returns the previous predictive {@link Cells}
+ * Returns the previous predictive {@link Cell}s
*
* @return the binary vector representing the current prediction.
*/
@@ -1472,7 +1473,7 @@ void notifyError(Exception e) {
*
*
* If any algorithms are repeated then {@link Inference}s will
- * NOT be shared between layers. {@link Regions}
+ * NOT be shared between layers. {@link Region}s
* NEVER share {@link Inference}s
*
*
@@ -1657,7 +1658,7 @@ private Observable resolveObservableSequence(T t) {
/**
* Executes the check point logic, handles the return of the serialized byte array
- * by delegating the call to {@link rx.Observer#onNext(byte[])} of all the currently queued
+ * by delegating the call to {@link rx.Observer#onNext}(byte[]) of all the currently queued
* Observers; then clears the list of Observers.
*/
private void doCheckPoint() {
@@ -1721,9 +1722,38 @@ private void doEncoderBucketMapping(Inference inference, Map enc
// Get fields user wants encoded from Parameters
Map inferredFields = (Map)params.get(KEY.INFERRED_FIELDS);
+
+ // Store a NamedTuple for each of those fields
for(Map.Entry entry : inferredFields.entrySet()) {
- String name = entry.getKey();
- EncoderTuple encoderTuple = encoderTuples.get
+ String fieldName = entry.getKey(); // Name of encoder input field
+ EncoderTuple encoderTuple = encoderTuples.stream() // Get the EncoderTuple for this input field
+ .filter(e -> e.getName().equals(fieldName))
+ .collect(Collectors.toList())
+ .get(0);
+ Encoder> e = encoderTuple.getEncoder();
+
+ int bucketIdx = -1;
+ Object o = encoderInputMap.get(name);
+ if(DateTime.class.isAssignableFrom(o.getClass())) {
+ bucketIdx = ((DateEncoder)e).getBucketIndices((DateTime)o)[0];
+ } else if(Number.class.isAssignableFrom(o.getClass())) {
+ bucketIdx = e.getBucketIndices((double)o)[0];
+ } else {
+ bucketIdx = e.getBucketIndices((String)o)[0];
+ }
+
+ int offset = encoderTuple.getOffset();
+ int[] tempArray = new int[e.getWidth()];
+ System.arraycopy(encoding, offset, tempArray, 0, tempArray.length);
+
+ inference.getClassifierInput().put(
+ name,
+ new NamedTuple(new String[] { "name", "inputValue", "bucketIdx", "encoding" },
+ name,
+ o,
+ bucketIdx,
+ tempArray
+ ));
}
}
@@ -1809,9 +1839,9 @@ private Observable fillInOrderedSequence(Observable o)
/**
* Called internally to create a subscription on behalf of the specified
- * {@link LayerObserver}
+ * Layer {@link Observer}
*
- * @param sub the LayerObserver (subscriber).
+ * @param sub the Layer Observer (subscriber).
* @return
*/
private Subscription createSubscription(final Observer sub) {
@@ -2033,8 +2063,7 @@ public void run() {
* that stores the state of this {@code Network} while keeping the Network up and running.
* The Network will be stored at the pre-configured location (in binary form only, not JSON).
*
- * @param network the {@link Network} to check point.
- * @return the {@link CheckPointOp} operator
+ * @return the {@link CheckPointOp} operator
*/
@SuppressWarnings("unchecked")
CheckPointOp getCheckPointOperator() {
From 6e9bc600a2ccbd5d5b94e7f768e6ff1b6a33e508 Mon Sep 17 00:00:00 2001
From: cogmission
Date: Tue, 28 Feb 2017 02:58:35 -0600
Subject: [PATCH 12/52] Remove tabs from file, neaten appearance
---
.../numenta/nupic/encoders/ScalarEncoder.java | 1198 ++++++++---------
1 file changed, 599 insertions(+), 599 deletions(-)
diff --git a/src/main/java/org/numenta/nupic/encoders/ScalarEncoder.java b/src/main/java/org/numenta/nupic/encoders/ScalarEncoder.java
index b62d9ed7..7e29e842 100644
--- a/src/main/java/org/numenta/nupic/encoders/ScalarEncoder.java
+++ b/src/main/java/org/numenta/nupic/encoders/ScalarEncoder.java
@@ -159,266 +159,266 @@
*/
public class ScalarEncoder extends Encoder {
- private static final long serialVersionUID = 1L;
-
- private static final Logger LOGGER = LoggerFactory.getLogger(ScalarEncoder.class);
-
- /**
- * Constructs a new {@code ScalarEncoder}
- */
- ScalarEncoder() {}
-
- /**
- * Returns a builder for building ScalarEncoders.
- * This builder may be reused to produce multiple builders
- *
- * @return a {@code ScalarEncoder.Builder}
- */
- public static Encoder.Builder builder() {
- return new ScalarEncoder.Builder();
- }
-
- /**
- * Returns true if the underlying encoder works on deltas
- */
- @Override
- public boolean isDelta() {
- return false;
- }
-
- /**
- * w -- number of bits to set in output
+ private static final long serialVersionUID = 1L;
+
+ private static final Logger LOGGER = LoggerFactory.getLogger(ScalarEncoder.class);
+
+ /**
+ * Constructs a new {@code ScalarEncoder}
+ */
+ ScalarEncoder() {}
+
+ /**
+ * Returns a builder for building ScalarEncoders.
+ * This builder may be reused to produce multiple builders
+ *
+ * @return a {@code ScalarEncoder.Builder}
+ */
+ public static Encoder.Builder builder() {
+ return new ScalarEncoder.Builder();
+ }
+
+ /**
+ * Returns true if the underlying encoder works on deltas
+ */
+ @Override
+ public boolean isDelta() {
+ return false;
+ }
+
+ /**
+ * w -- number of bits to set in output
* minval -- minimum input value
* maxval -- maximum input value (input is strictly less if periodic == True)
- *
+ *
* Exactly one of n, radius, resolution must be set. "0" is a special
* value that means "not set".
- *
+ *
* n -- number of bits in the representation (must be > w)
* radius -- inputs separated by more than, or equal to this distance will have non-overlapping
* representations
* resolution -- inputs separated by more than, or equal to this distance will have different
* representations
- *
+ *
* name -- an optional string which will become part of the description
- *
+ *
* clipInput -- if true, non-periodic inputs smaller than minval or greater
* than maxval will be clipped to minval/maxval
- *
+ *
* forced -- if true, skip some safety checks (for compatibility reasons), default false
- */
- public void init() {
- if(getW() % 2 == 0) {
- throw new IllegalStateException(
- "W must be an odd number (to eliminate centering difficulty)");
- }
-
- setHalfWidth((getW() - 1) / 2);
-
- // For non-periodic inputs, padding is the number of bits "outside" the range,
- // on each side. I.e. the representation of minval is centered on some bit, and
- // there are "padding" bits to the left of that centered bit; similarly with
- // bits to the right of the center bit of maxval
- setPadding(isPeriodic() ? 0 : getHalfWidth());
-
- if(!Double.isNaN(getMinVal()) && !Double.isNaN(getMaxVal())) {
- if(getMinVal() >= getMaxVal()) {
- throw new IllegalStateException("maxVal must be > minVal");
- }
- setRangeInternal(getMaxVal() - getMinVal());
- }
-
- // There are three different ways of thinking about the representation. Handle
- // each case here.
- initEncoder(getW(), getMinVal(), getMaxVal(), getN(), getRadius(), getResolution());
-
- //nInternal represents the output area excluding the possible padding on each side
- setNInternal(getN() - 2 * getPadding());
-
- if(getName() == null) {
- if((getMinVal() % ((int)getMinVal())) > 0 ||
- (getMaxVal() % ((int)getMaxVal())) > 0) {
- setName("[" + getMinVal() + ":" + getMaxVal() + "]");
- }else{
- setName("[" + (int)getMinVal() + ":" + (int)getMaxVal() + "]");
- }
- }
-
- //Checks for likely mistakes in encoder settings
- if(!isForced()) {
- checkReasonableSettings();
- }
+ */
+ public void init() {
+ if(getW() % 2 == 0) {
+ throw new IllegalStateException(
+ "W must be an odd number (to eliminate centering difficulty)");
+ }
+
+ setHalfWidth((getW() - 1) / 2);
+
+ // For non-periodic inputs, padding is the number of bits "outside" the range,
+ // on each side. I.e. the representation of minval is centered on some bit, and
+ // there are "padding" bits to the left of that centered bit; similarly with
+ // bits to the right of the center bit of maxval
+ setPadding(isPeriodic() ? 0 : getHalfWidth());
+
+ if(!Double.isNaN(getMinVal()) && !Double.isNaN(getMaxVal())) {
+ if(getMinVal() >= getMaxVal()) {
+ throw new IllegalStateException("maxVal must be > minVal");
+ }
+ setRangeInternal(getMaxVal() - getMinVal());
+ }
+
+ // There are three different ways of thinking about the representation. Handle
+ // each case here.
+ initEncoder(getW(), getMinVal(), getMaxVal(), getN(), getRadius(), getResolution());
+
+ //nInternal represents the output area excluding the possible padding on each side
+ setNInternal(getN() - 2 * getPadding());
+
+ if(getName() == null) {
+ if((getMinVal() % ((int)getMinVal())) > 0 ||
+ (getMaxVal() % ((int)getMaxVal())) > 0) {
+ setName("[" + getMinVal() + ":" + getMaxVal() + "]");
+ }else{
+ setName("[" + (int)getMinVal() + ":" + (int)getMaxVal() + "]");
+ }
+ }
+
+ //Checks for likely mistakes in encoder settings
+ if(!isForced()) {
+ checkReasonableSettings();
+ }
description.add(new Tuple((name = getName()).equals("None") ? "[" + (int)getMinVal() + ":" + (int)getMaxVal() + "]" : name, 0));
- }
+ }
- /**
- * There are three different ways of thinking about the representation.
+ /**
+ * There are three different ways of thinking about the representation.
* Handle each case here.
*
- * @param c
- * @param minVal
- * @param maxVal
- * @param n
- * @param radius
- * @param resolution
- */
- public void initEncoder(int w, double minVal, double maxVal, int n, double radius, double resolution) {
- if(n != 0) {
- if(!Double.isNaN(minVal) && !Double.isNaN(maxVal)) {
- if(!isPeriodic()) {
- setResolution(getRangeInternal() / (getN() - getW()));
- }else{
- setResolution(getRangeInternal() / getN());
- }
-
- setRadius(getW() * getResolution());
-
- if(isPeriodic()) {
- setRange(getRangeInternal());
- }else{
- setRange(getRangeInternal() + getResolution());
- }
- }
- }else{
- if(radius != 0) {
- setResolution(getRadius() / w);
- }else if(resolution != 0) {
- setRadius(getResolution() * w);
- }else{
- throw new IllegalStateException(
- "One of n, radius, resolution must be specified for a ScalarEncoder");
- }
-
- if(isPeriodic()) {
- setRange(getRangeInternal());
- }else{
- setRange(getRangeInternal() + getResolution());
- }
-
- double nFloat = w * (getRange() / getRadius()) + 2 * getPadding();
- setN((int)Math.ceil(nFloat));
- }
- }
-
- /**
- * Return the bit offset of the first bit to be set in the encoder output.
+ * @param c
+ * @param minVal
+ * @param maxVal
+ * @param n
+ * @param radius
+ * @param resolution
+ */
+ public void initEncoder(int w, double minVal, double maxVal, int n, double radius, double resolution) {
+ if(n != 0) {
+ if(!Double.isNaN(minVal) && !Double.isNaN(maxVal)) {
+ if(!isPeriodic()) {
+ setResolution(getRangeInternal() / (getN() - getW()));
+ }else{
+ setResolution(getRangeInternal() / getN());
+ }
+
+ setRadius(getW() * getResolution());
+
+ if(isPeriodic()) {
+ setRange(getRangeInternal());
+ }else{
+ setRange(getRangeInternal() + getResolution());
+ }
+ }
+ }else{
+ if(radius != 0) {
+ setResolution(getRadius() / w);
+ }else if(resolution != 0) {
+ setRadius(getResolution() * w);
+ }else{
+ throw new IllegalStateException(
+ "One of n, radius, resolution must be specified for a ScalarEncoder");
+ }
+
+ if(isPeriodic()) {
+ setRange(getRangeInternal());
+ }else{
+ setRange(getRangeInternal() + getResolution());
+ }
+
+ double nFloat = w * (getRange() / getRadius()) + 2 * getPadding();
+ setN((int)Math.ceil(nFloat));
+ }
+ }
+
+ /**
+ * Return the bit offset of the first bit to be set in the encoder output.
* For periodic encoders, this can be a negative number when the encoded output
* wraps around.
*
- * @param c the memory
- * @param input the input data
- * @return an encoded array
- */
- public Integer getFirstOnBit(double input) {
- if(Double.isNaN(input)) {
- return null;
- }else{
- if(input < getMinVal()) {
- if(clipInput() && !isPeriodic()) {
- if(LOGGER.isTraceEnabled()) {
- LOGGER.info("Clipped input " + getName() + "=" + input + " to minval " + getMinVal());
- }
- input = getMinVal();
- }else{
- throw new IllegalStateException("input (" + input +") less than range (" +
- getMinVal() + " - " + getMaxVal() + ")");
- }
- }
- }
-
- if(isPeriodic()) {
- if(input >= getMaxVal()) {
- throw new IllegalStateException("input (" + input +") greater than periodic range (" +
- getMinVal() + " - " + getMaxVal() + ")");
- }
- }else{
- if(input > getMaxVal()) {
- if(clipInput()) {
- if(LOGGER.isTraceEnabled()) {
- LOGGER.info("Clipped input " + getName() + "=" + input + " to maxval " + getMaxVal());
- }
- input = getMaxVal();
- }else{
- throw new IllegalStateException("input (" + input +") greater than periodic range (" +
- getMinVal() + " - " + getMaxVal() + ")");
- }
- }
- }
-
- int centerbin;
- if(isPeriodic()) {
- centerbin = ((int)((input - getMinVal()) * getNInternal() / getRange())) + getPadding();
- }else{
- centerbin = ((int)(((input - getMinVal()) + getResolution()/2) / getResolution())) + getPadding();
- }
-
- return centerbin - getHalfWidth();
- }
-
- /**
- * Check if the settings are reasonable for the SpatialPooler to work
- * @param c
- */
- public void checkReasonableSettings() {
- if(getW() < 21) {
- throw new IllegalStateException(
- "Number of bits in the SDR (%d) must be greater than 2, and recommended >= 21 (use forced=True to override)");
- }
- }
-
- /**
- * {@inheritDoc}
- */
- @Override
- public Set getDecoderOutputFieldTypes() {
- return new LinkedHashSet<>(Arrays.asList(FieldMetaType.FLOAT, FieldMetaType.INTEGER));
- }
-
- /**
- * Should return the output width, in bits.
- */
- @Override
- public int getWidth() {
- return getN();
- }
-
- /**
- * {@inheritDoc}
- * NO-OP
- */
- @Override
- public int[] getBucketIndices(String input) { return null; }
-
- /**
- * Returns the bucket indices.
- *
- * @param input
- */
- @Override
- public int[] getBucketIndices(double input) {
- int minbin = getFirstOnBit(input);
-
- //For periodic encoders, the bucket index is the index of the center bit
- int bucketIdx;
- if(isPeriodic()) {
- bucketIdx = minbin + getHalfWidth();
- if(bucketIdx < 0) {
- bucketIdx += getN();
- }
- }else{//for non-periodic encoders, the bucket index is the index of the left bit
- bucketIdx = minbin;
- }
-
- return new int[] { bucketIdx };
- }
-
- /**
- * Encodes inputData and puts the encoded value into the output array,
+ * @param c the memory
+ * @param input the input data
+ * @return an encoded array
+ */
+ public Integer getFirstOnBit(double input) {
+ if(Double.isNaN(input)) {
+ return null;
+ }else{
+ if(input < getMinVal()) {
+ if(clipInput() && !isPeriodic()) {
+ if(LOGGER.isTraceEnabled()) {
+ LOGGER.info("Clipped input " + getName() + "=" + input + " to minval " + getMinVal());
+ }
+ input = getMinVal();
+ }else{
+ throw new IllegalStateException("input (" + input +") less than range (" +
+ getMinVal() + " - " + getMaxVal() + ")");
+ }
+ }
+ }
+
+ if(isPeriodic()) {
+ if(input >= getMaxVal()) {
+ throw new IllegalStateException("input (" + input +") greater than periodic range (" +
+ getMinVal() + " - " + getMaxVal() + ")");
+ }
+ }else{
+ if(input > getMaxVal()) {
+ if(clipInput()) {
+ if(LOGGER.isTraceEnabled()) {
+ LOGGER.info("Clipped input " + getName() + "=" + input + " to maxval " + getMaxVal());
+ }
+ input = getMaxVal();
+ }else{
+ throw new IllegalStateException("input (" + input +") greater than periodic range (" +
+ getMinVal() + " - " + getMaxVal() + ")");
+ }
+ }
+ }
+
+ int centerbin;
+ if(isPeriodic()) {
+ centerbin = ((int)((input - getMinVal()) * getNInternal() / getRange())) + getPadding();
+ }else{
+ centerbin = ((int)(((input - getMinVal()) + getResolution()/2) / getResolution())) + getPadding();
+ }
+
+ return centerbin - getHalfWidth();
+ }
+
+ /**
+ * Check if the settings are reasonable for the SpatialPooler to work
+ * @param c
+ */
+ public void checkReasonableSettings() {
+ if(getW() < 21) {
+ throw new IllegalStateException(
+ "Number of bits in the SDR (%d) must be greater than 2, and recommended >= 21 (use forced=True to override)");
+ }
+ }
+
+ /**
+ * {@inheritDoc}
+ */
+ @Override
+ public Set getDecoderOutputFieldTypes() {
+ return new LinkedHashSet<>(Arrays.asList(FieldMetaType.FLOAT, FieldMetaType.INTEGER));
+ }
+
+ /**
+ * Should return the output width, in bits.
+ */
+ @Override
+ public int getWidth() {
+ return getN();
+ }
+
+ /**
+ * {@inheritDoc}
+ * NO-OP
+ */
+ @Override
+ public int[] getBucketIndices(String input) { return null; }
+
+ /**
+ * Returns the bucket indices.
+ *
+ * @param input
+ */
+ @Override
+ public int[] getBucketIndices(double input) {
+ int minbin = getFirstOnBit(input);
+
+ //For periodic encoders, the bucket index is the index of the center bit
+ int bucketIdx;
+ if(isPeriodic()) {
+ bucketIdx = minbin + getHalfWidth();
+ if(bucketIdx < 0) {
+ bucketIdx += getN();
+ }
+ }else{//for non-periodic encoders, the bucket index is the index of the left bit
+ bucketIdx = minbin;
+ }
+
+ return new int[] { bucketIdx };
+ }
+
+ /**
+ * Encodes inputData and puts the encoded value into the output array,
* which is a 1-D array of length returned by {@link Connections#getW()}.
- *
+ *
* Note: The output array is reused, so clear it before updating it.
- * @param inputData Data to encode. This should be validated by the encoder.
- * @param output 1-D array of same length returned by {@link Connections#getW()}
+ * @param inputData Data to encode. This should be validated by the encoder.
+ * @param output 1-D array of same length returned by {@link Connections#getW()}
*/
@Override
public void encodeIntoArray(Double input, int[] output) {
@@ -426,7 +426,7 @@ public void encodeIntoArray(Double input, int[] output) {
Arrays.fill(output, 0);
return;
}
-
+
Integer bucketVal = getFirstOnBit(input);
if(bucketVal != null) {
int bucketIdx = bucketVal;
@@ -446,391 +446,391 @@ public void encodeIntoArray(Double input, int[] output) {
minbin = 0;
}
}
-
+
ArrayUtils.setIndexesTo(output, ArrayUtils.range(minbin, maxbin + 1), 1);
}
-
+
// Added guard against immense string concatenation
if(LOGGER.isTraceEnabled()) {
LOGGER.trace("");
LOGGER.trace("input: " + input);
LOGGER.trace("range: " + getMinVal() + " - " + getMaxVal());
LOGGER.trace("n:" + getN() + "w:" + getW() + "resolution:" + getResolution() +
- "radius:" + getRadius() + "periodic:" + isPeriodic());
+ "radius:" + getRadius() + "periodic:" + isPeriodic());
LOGGER.trace("output: " + Arrays.toString(output));
LOGGER.trace("input desc: " + decode(output, ""));
}
}
- /**
- * Returns a {@link DecodeResult} which is a tuple of range names
- * and lists of {@link RangeLists} in the first entry, and a list
- * of descriptions for each range in the second entry.
- *
- * @param encoded the encoded bit vector
- * @param parentFieldName the field the vector corresponds with
- * @return
- */
- @Override
- public DecodeResult decode(int[] encoded, String parentFieldName) {
- // For now, we simply assume any top-down output greater than 0
- // is ON. Eventually, we will probably want to incorporate the strength
- // of each top-down output.
- if(encoded == null || encoded.length < 1) {
- return null;
- }
- int[] tmpOutput = Arrays.copyOf(encoded, encoded.length);
-
- // ------------------------------------------------------------------------
- // First, assume the input pool is not sampled 100%, and fill in the
- // "holes" in the encoded representation (which are likely to be present
- // if this is a coincidence that was learned by the SP).
-
- // Search for portions of the output that have "holes"
- int maxZerosInARow = getHalfWidth();
- for(int i = 0;i < maxZerosInARow;i++) {
- int[] searchStr = new int[i + 3];
- Arrays.fill(searchStr, 1);
- ArrayUtils.setRangeTo(searchStr, 1, -1, 0);
- int subLen = searchStr.length;
-
- // Does this search string appear in the output?
- if(isPeriodic()) {
- for(int j = 0;j < getN();j++) {
- int[] outputIndices = ArrayUtils.range(j, j + subLen);
- outputIndices = ArrayUtils.modulo(outputIndices, getN());
- if(Arrays.equals(searchStr, ArrayUtils.sub(tmpOutput, outputIndices))) {
- ArrayUtils.setIndexesTo(tmpOutput, outputIndices, 1);
- }
- }
- }else{
- for(int j = 0;j < getN() - subLen + 1;j++) {
- if(Arrays.equals(searchStr, ArrayUtils.sub(tmpOutput, ArrayUtils.range(j, j + subLen)))) {
- ArrayUtils.setRangeTo(tmpOutput, j, j + subLen, 1);
- }
- }
- }
- }
-
- LOGGER.trace("raw output:" + Arrays.toString(
- ArrayUtils.sub(encoded, ArrayUtils.range(0, getN()))));
- LOGGER.trace("filtered output:" + Arrays.toString(tmpOutput));
-
- // ------------------------------------------------------------------------
- // Find each run of 1's.
- int[] nz = ArrayUtils.where(tmpOutput, new Condition.Adapter() {
- @Override
- public boolean eval(int n) {
- return n > 0;
- }
- });
- List runs = new ArrayList(); //will be tuples of (startIdx, runLength)
- Arrays.sort(nz);
- int[] run = new int[] { nz[0], 1 };
- int i = 1;
- while(i < nz.length) {
- if(nz[i] == run[0] + run[1]) {
- run[1] += 1;
- }else{
- runs.add(new Tuple(run[0], run[1]));
- run = new int[] { nz[i], 1 };
- }
- i += 1;
- }
- runs.add(new Tuple(run[0], run[1]));
-
- // If we have a periodic encoder, merge the first and last run if they
- // both go all the way to the edges
- if(isPeriodic() && runs.size() > 1) {
- int l = runs.size() - 1;
- if(((Integer)runs.get(0).get(0)) == 0 && ((Integer)runs.get(l).get(0)) + ((Integer)runs.get(l).get(1)) == getN()) {
- runs.set(l, new Tuple((Integer)runs.get(l).get(0),
- ((Integer)runs.get(l).get(1)) + ((Integer)runs.get(0).get(1)) ));
- runs = runs.subList(1, runs.size());
- }
- }
-
- // ------------------------------------------------------------------------
- // Now, for each group of 1's, determine the "left" and "right" edges, where
- // the "left" edge is inset by halfwidth and the "right" edge is inset by
- // halfwidth.
- // For a group of width w or less, the "left" and "right" edge are both at
- // the center position of the group.
- int left = 0;
- int right = 0;
- List ranges = new ArrayList();
- for(Tuple tupleRun : runs) {
- int start = (Integer)tupleRun.get(0);
- int runLen = (Integer)tupleRun.get(1);
- if(runLen <= getW()) {
- left = right = start + runLen / 2;
- }else{
- left = start + getHalfWidth();
- right = start + runLen - 1 - getHalfWidth();
- }
-
- double inMin, inMax;
- // Convert to input space.
- if(!isPeriodic()) {
- inMin = (left - getPadding()) * getResolution() + getMinVal();
- inMax = (right - getPadding()) * getResolution() + getMinVal();
- }else{
- inMin = (left - getPadding()) * getRange() / getNInternal() + getMinVal();
- inMax = (right - getPadding()) * getRange() / getNInternal() + getMinVal();
- }
- // Handle wrap-around if periodic
- if(isPeriodic()) {
- if(inMin >= getMaxVal()) {
- inMin -= getRange();
- inMax -= getRange();
- }
- }
-
- // Clip low end
- if(inMin < getMinVal()) {
- inMin = getMinVal();
- }
- if(inMax < getMinVal()) {
- inMax = getMinVal();
- }
-
- // If we have a periodic encoder, and the max is past the edge, break into
- // 2 separate ranges
- if(isPeriodic() && inMax >= getMaxVal()) {
- ranges.add(new MinMax(inMin, getMaxVal()));
- ranges.add(new MinMax(getMinVal(), inMax - getRange()));
- }else{
- if(inMax > getMaxVal()) {
- inMax = getMaxVal();
- }
- if(inMin > getMaxVal()) {
- inMin = getMaxVal();
- }
- ranges.add(new MinMax(inMin, inMax));
- }
- }
-
- String desc = generateRangeDescription(ranges);
- String fieldName;
- // Return result
- if(parentFieldName != null && !parentFieldName.isEmpty()) {
- fieldName = String.format("%s.%s", parentFieldName, getName());
- }else{
- fieldName = getName();
- }
-
- RangeList inner = new RangeList(ranges, desc);
- Map fieldsDict = new HashMap();
- fieldsDict.put(fieldName, inner);
-
- return new DecodeResult(fieldsDict, Arrays.asList(fieldName));
- }
-
- /**
- * Generate description from a text description of the ranges
- *
- * @param ranges A list of {@link MinMax}es.
- */
- public String generateRangeDescription(List ranges) {
- StringBuilder desc = new StringBuilder();
- int numRanges = ranges.size();
- for(int i = 0;i < numRanges;i++) {
- if(ranges.get(i).min() != ranges.get(i).max()) {
- desc.append(String.format("%.2f-%.2f", ranges.get(i).min(), ranges.get(i).max()));
- }else{
- desc.append(String.format("%.2f", ranges.get(i).min()));
- }
- if(i < numRanges - 1) {
- desc.append(", ");
- }
- }
- return desc.toString();
- }
-
- /**
- * Return the internal topDownMapping matrix used for handling the
+ /**
+ * Returns a {@link DecodeResult} which is a tuple of range names
+ * and lists of {@link RangeLists} in the first entry, and a list
+ * of descriptions for each range in the second entry.
+ *
+ * @param encoded the encoded bit vector
+ * @param parentFieldName the field the vector corresponds with
+ * @return
+ */
+ @Override
+ public DecodeResult decode(int[] encoded, String parentFieldName) {
+ // For now, we simply assume any top-down output greater than 0
+ // is ON. Eventually, we will probably want to incorporate the strength
+ // of each top-down output.
+ if(encoded == null || encoded.length < 1) {
+ return null;
+ }
+ int[] tmpOutput = Arrays.copyOf(encoded, encoded.length);
+
+ // ------------------------------------------------------------------------
+ // First, assume the input pool is not sampled 100%, and fill in the
+ // "holes" in the encoded representation (which are likely to be present
+ // if this is a coincidence that was learned by the SP).
+
+ // Search for portions of the output that have "holes"
+ int maxZerosInARow = getHalfWidth();
+ for(int i = 0;i < maxZerosInARow;i++) {
+ int[] searchStr = new int[i + 3];
+ Arrays.fill(searchStr, 1);
+ ArrayUtils.setRangeTo(searchStr, 1, -1, 0);
+ int subLen = searchStr.length;
+
+ // Does this search string appear in the output?
+ if(isPeriodic()) {
+ for(int j = 0;j < getN();j++) {
+ int[] outputIndices = ArrayUtils.range(j, j + subLen);
+ outputIndices = ArrayUtils.modulo(outputIndices, getN());
+ if(Arrays.equals(searchStr, ArrayUtils.sub(tmpOutput, outputIndices))) {
+ ArrayUtils.setIndexesTo(tmpOutput, outputIndices, 1);
+ }
+ }
+ }else{
+ for(int j = 0;j < getN() - subLen + 1;j++) {
+ if(Arrays.equals(searchStr, ArrayUtils.sub(tmpOutput, ArrayUtils.range(j, j + subLen)))) {
+ ArrayUtils.setRangeTo(tmpOutput, j, j + subLen, 1);
+ }
+ }
+ }
+ }
+
+ LOGGER.trace("raw output:" + Arrays.toString(
+ ArrayUtils.sub(encoded, ArrayUtils.range(0, getN()))));
+ LOGGER.trace("filtered output:" + Arrays.toString(tmpOutput));
+
+ // ------------------------------------------------------------------------
+ // Find each run of 1's.
+ int[] nz = ArrayUtils.where(tmpOutput, new Condition.Adapter() {
+ @Override
+ public boolean eval(int n) {
+ return n > 0;
+ }
+ });
+ List runs = new ArrayList(); //will be tuples of (startIdx, runLength)
+ Arrays.sort(nz);
+ int[] run = new int[] { nz[0], 1 };
+ int i = 1;
+ while(i < nz.length) {
+ if(nz[i] == run[0] + run[1]) {
+ run[1] += 1;
+ }else{
+ runs.add(new Tuple(run[0], run[1]));
+ run = new int[] { nz[i], 1 };
+ }
+ i += 1;
+ }
+ runs.add(new Tuple(run[0], run[1]));
+
+ // If we have a periodic encoder, merge the first and last run if they
+ // both go all the way to the edges
+ if(isPeriodic() && runs.size() > 1) {
+ int l = runs.size() - 1;
+ if(((Integer)runs.get(0).get(0)) == 0 && ((Integer)runs.get(l).get(0)) + ((Integer)runs.get(l).get(1)) == getN()) {
+ runs.set(l, new Tuple((Integer)runs.get(l).get(0),
+ ((Integer)runs.get(l).get(1)) + ((Integer)runs.get(0).get(1)) ));
+ runs = runs.subList(1, runs.size());
+ }
+ }
+
+ // ------------------------------------------------------------------------
+ // Now, for each group of 1's, determine the "left" and "right" edges, where
+ // the "left" edge is inset by halfwidth and the "right" edge is inset by
+ // halfwidth.
+ // For a group of width w or less, the "left" and "right" edge are both at
+ // the center position of the group.
+ int left = 0;
+ int right = 0;
+ List ranges = new ArrayList();
+ for(Tuple tupleRun : runs) {
+ int start = (Integer)tupleRun.get(0);
+ int runLen = (Integer)tupleRun.get(1);
+ if(runLen <= getW()) {
+ left = right = start + runLen / 2;
+ }else{
+ left = start + getHalfWidth();
+ right = start + runLen - 1 - getHalfWidth();
+ }
+
+ double inMin, inMax;
+ // Convert to input space.
+ if(!isPeriodic()) {
+ inMin = (left - getPadding()) * getResolution() + getMinVal();
+ inMax = (right - getPadding()) * getResolution() + getMinVal();
+ }else{
+ inMin = (left - getPadding()) * getRange() / getNInternal() + getMinVal();
+ inMax = (right - getPadding()) * getRange() / getNInternal() + getMinVal();
+ }
+ // Handle wrap-around if periodic
+ if(isPeriodic()) {
+ if(inMin >= getMaxVal()) {
+ inMin -= getRange();
+ inMax -= getRange();
+ }
+ }
+
+ // Clip low end
+ if(inMin < getMinVal()) {
+ inMin = getMinVal();
+ }
+ if(inMax < getMinVal()) {
+ inMax = getMinVal();
+ }
+
+ // If we have a periodic encoder, and the max is past the edge, break into
+ // 2 separate ranges
+ if(isPeriodic() && inMax >= getMaxVal()) {
+ ranges.add(new MinMax(inMin, getMaxVal()));
+ ranges.add(new MinMax(getMinVal(), inMax - getRange()));
+ }else{
+ if(inMax > getMaxVal()) {
+ inMax = getMaxVal();
+ }
+ if(inMin > getMaxVal()) {
+ inMin = getMaxVal();
+ }
+ ranges.add(new MinMax(inMin, inMax));
+ }
+ }
+
+ String desc = generateRangeDescription(ranges);
+ String fieldName;
+ // Return result
+ if(parentFieldName != null && !parentFieldName.isEmpty()) {
+ fieldName = String.format("%s.%s", parentFieldName, getName());
+ }else{
+ fieldName = getName();
+ }
+
+ RangeList inner = new RangeList(ranges, desc);
+ Map fieldsDict = new HashMap();
+ fieldsDict.put(fieldName, inner);
+
+ return new DecodeResult(fieldsDict, Arrays.asList(fieldName));
+ }
+
+ /**
+ * Generate description from a text description of the ranges
+ *
+ * @param ranges A list of {@link MinMax}es.
+ */
+ public String generateRangeDescription(List ranges) {
+ StringBuilder desc = new StringBuilder();
+ int numRanges = ranges.size();
+ for(int i = 0;i < numRanges;i++) {
+ if(ranges.get(i).min() != ranges.get(i).max()) {
+ desc.append(String.format("%.2f-%.2f", ranges.get(i).min(), ranges.get(i).max()));
+ }else{
+ desc.append(String.format("%.2f", ranges.get(i).min()));
+ }
+ if(i < numRanges - 1) {
+ desc.append(", ");
+ }
+ }
+ return desc.toString();
+ }
+
+ /**
+ * Return the internal topDownMapping matrix used for handling the
* bucketInfo() and topDownCompute() methods. This is a matrix, one row per
* category (bucket) where each row contains the encoded output for that
* category.
*
- * @param c the connections memory
- * @return the internal topDownMapping
- */
- public SparseObjectMatrix getTopDownMapping() {
-
- if(topDownMapping == null) {
- //The input scalar value corresponding to each possible output encoding
- if(isPeriodic()) {
- setTopDownValues(
- ArrayUtils.arange(getMinVal() + getResolution() / 2.0,
- getMaxVal(), getResolution()));
- }else{
- //Number of values is (max-min)/resolutions
- setTopDownValues(
- ArrayUtils.arange(getMinVal(), getMaxVal() + getResolution() / 2.0,
- getResolution()));
- }
- }
-
- //Each row represents an encoded output pattern
- int numCategories = getTopDownValues().length;
- SparseObjectMatrix topDownMapping;
- setTopDownMapping(
- topDownMapping = new SparseObjectMatrix(
- new int[] { numCategories }));
-
- double[] topDownValues = getTopDownValues();
- int[] outputSpace = new int[getN()];
- double minVal = getMinVal();
- double maxVal = getMaxVal();
- for(int i = 0;i < numCategories;i++) {
- double value = topDownValues[i];
- value = Math.max(value, minVal);
- value = Math.min(value, maxVal);
- encodeIntoArray(value, outputSpace);
- topDownMapping.set(i, Arrays.copyOf(outputSpace, outputSpace.length));
- }
-
- return topDownMapping;
- }
-
- /**
- * {@inheritDoc}
- *
- * @param the input value, in this case a double
- * @return a list of one input double
- */
- @Override
- public TDoubleList getScalars(S d) {
- TDoubleList retVal = new TDoubleArrayList();
- retVal.add((Double)d);
- return retVal;
- }
-
- /**
- * Returns a list of items, one for each bucket defined by this encoder.
+ * @param c the connections memory
+ * @return the internal topDownMapping
+ */
+ public SparseObjectMatrix getTopDownMapping() {
+
+ if(topDownMapping == null) {
+ //The input scalar value corresponding to each possible output encoding
+ if(isPeriodic()) {
+ setTopDownValues(
+ ArrayUtils.arange(getMinVal() + getResolution() / 2.0,
+ getMaxVal(), getResolution()));
+ }else{
+ //Number of values is (max-min)/resolutions
+ setTopDownValues(
+ ArrayUtils.arange(getMinVal(), getMaxVal() + getResolution() / 2.0,
+ getResolution()));
+ }
+ }
+
+ //Each row represents an encoded output pattern
+ int numCategories = getTopDownValues().length;
+ SparseObjectMatrix topDownMapping;
+ setTopDownMapping(
+ topDownMapping = new SparseObjectMatrix(
+ new int[] { numCategories }));
+
+ double[] topDownValues = getTopDownValues();
+ int[] outputSpace = new int[getN()];
+ double minVal = getMinVal();
+ double maxVal = getMaxVal();
+ for(int i = 0;i < numCategories;i++) {
+ double value = topDownValues[i];
+ value = Math.max(value, minVal);
+ value = Math.min(value, maxVal);
+ encodeIntoArray(value, outputSpace);
+ topDownMapping.set(i, Arrays.copyOf(outputSpace, outputSpace.length));
+ }
+
+ return topDownMapping;
+ }
+
+ /**
+ * {@inheritDoc}
+ *
+ * @param the input value, in this case a double
+ * @return a list of one input double
+ */
+ @Override
+ public TDoubleList getScalars(S d) {
+ TDoubleList retVal = new TDoubleArrayList();
+ retVal.add((Double)d);
+ return retVal;
+ }
+
+ /**
+ * Returns a list of items, one for each bucket defined by this encoder.
* Each item is the value assigned to that bucket, this is the same as the
* EncoderResult.value that would be returned by getBucketInfo() for that
* bucket and is in the same format as the input that would be passed to
* encode().
- *
+ *
* This call is faster than calling getBucketInfo() on each bucket individually
* if all you need are the bucket values.
- *
- * @param returnType class type parameter so that this method can return encoder
+ *
+ * @param returnType class type parameter so that this method can return encoder
* specific value types
*
* @return list of items, each item representing the bucket value for that
* bucket.
- */
- @SuppressWarnings("unchecked")
- @Override
- public List getBucketValues(Class t) {
- if(bucketValues == null) {
- SparseObjectMatrix topDownMapping = getTopDownMapping();
- int numBuckets = topDownMapping.getMaxIndex() + 1;
- bucketValues = new ArrayList();
- for(int i = 0;i < numBuckets;i++) {
- ((List)bucketValues).add((Double)getBucketInfo(new int[] { i }).get(0).get(1));
- }
- }
- return (List)bucketValues;
- }
-
- /**
- * {@inheritDoc}
- */
- @Override
- public List getBucketInfo(int[] buckets) {
- SparseObjectMatrix topDownMapping = getTopDownMapping();
-
- //The "category" is simply the bucket index
- int category = buckets[0];
- int[] encoding = topDownMapping.getObject(category);
-
- //Which input value does this correspond to?
- double inputVal;
- if(isPeriodic()) {
- inputVal = getMinVal() + getResolution() / 2 + category * getResolution();
- }else{
- inputVal = getMinVal() + category * getResolution();
- }
-
- return Arrays.asList(new Encoding(inputVal, inputVal, encoding));
- }
-
- /**
- * {@inheritDoc}
- */
- @Override
- public List topDownCompute(int[] encoded) {
- //Get/generate the topDown mapping table
- SparseObjectMatrix topDownMapping = getTopDownMapping();
-
- // See which "category" we match the closest.
- int category = ArrayUtils.argmax(rightVecProd(topDownMapping, encoded));
-
- return getBucketInfo(new int[]{category});
- }
-
- /**
- * Returns a list of {@link Tuple}s which in this case is a list of
- * key value parameter values for this {@code ScalarEncoder}
- *
- * @return a list of {@link Tuple}s
- */
- public List dict() {
- List l = new ArrayList();
- l.add(new Tuple("maxval", getMaxVal()));
- l.add(new Tuple("bucketValues", getBucketValues(Double.class)));
- l.add(new Tuple("nInternal", getNInternal()));
- l.add(new Tuple("name", getName()));
- l.add(new Tuple("minval", getMinVal()));
- l.add(new Tuple("topDownValues", Arrays.toString(getTopDownValues())));
- l.add(new Tuple("clipInput", clipInput()));
- l.add(new Tuple("n", getN()));
- l.add(new Tuple("padding", getPadding()));
- l.add(new Tuple("range", getRange()));
- l.add(new Tuple("periodic", isPeriodic()));
- l.add(new Tuple("radius", getRadius()));
- l.add(new Tuple("w", getW()));
- l.add(new Tuple("topDownMappingM", getTopDownMapping()));
- l.add(new Tuple("halfwidth", getHalfWidth()));
- l.add(new Tuple("resolution", getResolution()));
- l.add(new Tuple("rangeInternal", getRangeInternal()));
-
- return l;
- }
-
- /**
- * Returns a {@link EncoderBuilder} for constructing {@link ScalarEncoder}s
- *
- * The base class architecture is put together in such a way where boilerplate
- * initialization can be kept to a minimum for implementing subclasses, while avoiding
- * the mistake-proneness of extremely long argument lists.
- *
- * @see ScalarEncoder.Builder#setStuff(int)
- */
- public static class Builder extends Encoder.Builder {
- private Builder() {}
-
- @Override
- public ScalarEncoder build() {
- //Must be instantiated so that super class can initialize
- //boilerplate variables.
- encoder = new ScalarEncoder();
-
- //Call super class here
- super.build();
-
- ////////////////////////////////////////////////////////
- // Implementing classes would do setting of specific //
- // vars here together with any sanity checking //
- ////////////////////////////////////////////////////////
-
- ((ScalarEncoder)encoder).init();
-
- return (ScalarEncoder)encoder;
- }
- }
+ */
+ @SuppressWarnings("unchecked")
+ @Override
+ public List getBucketValues(Class t) {
+ if(bucketValues == null) {
+ SparseObjectMatrix topDownMapping = getTopDownMapping();
+ int numBuckets = topDownMapping.getMaxIndex() + 1;
+ bucketValues = new ArrayList();
+ for(int i = 0;i < numBuckets;i++) {
+ ((List)bucketValues).add((Double)getBucketInfo(new int[] { i }).get(0).get(1));
+ }
+ }
+ return (List)bucketValues;
+ }
+
+ /**
+ * {@inheritDoc}
+ */
+ @Override
+ public List getBucketInfo(int[] buckets) {
+ SparseObjectMatrix topDownMapping = getTopDownMapping();
+
+ //The "category" is simply the bucket index
+ int category = buckets[0];
+ int[] encoding = topDownMapping.getObject(category);
+
+ //Which input value does this correspond to?
+ double inputVal;
+ if(isPeriodic()) {
+ inputVal = getMinVal() + getResolution() / 2 + category * getResolution();
+ }else{
+ inputVal = getMinVal() + category * getResolution();
+ }
+
+ return Arrays.asList(new Encoding(inputVal, inputVal, encoding));
+ }
+
+ /**
+ * {@inheritDoc}
+ */
+ @Override
+ public List topDownCompute(int[] encoded) {
+ //Get/generate the topDown mapping table
+ SparseObjectMatrix topDownMapping = getTopDownMapping();
+
+ // See which "category" we match the closest.
+ int category = ArrayUtils.argmax(rightVecProd(topDownMapping, encoded));
+
+ return getBucketInfo(new int[]{category});
+ }
+
+ /**
+ * Returns a list of {@link Tuple}s which in this case is a list of
+ * key value parameter values for this {@code ScalarEncoder}
+ *
+ * @return a list of {@link Tuple}s
+ */
+ public List dict() {
+ List l = new ArrayList();
+ l.add(new Tuple("maxval", getMaxVal()));
+ l.add(new Tuple("bucketValues", getBucketValues(Double.class)));
+ l.add(new Tuple("nInternal", getNInternal()));
+ l.add(new Tuple("name", getName()));
+ l.add(new Tuple("minval", getMinVal()));
+ l.add(new Tuple("topDownValues", Arrays.toString(getTopDownValues())));
+ l.add(new Tuple("clipInput", clipInput()));
+ l.add(new Tuple("n", getN()));
+ l.add(new Tuple("padding", getPadding()));
+ l.add(new Tuple("range", getRange()));
+ l.add(new Tuple("periodic", isPeriodic()));
+ l.add(new Tuple("radius", getRadius()));
+ l.add(new Tuple("w", getW()));
+ l.add(new Tuple("topDownMappingM", getTopDownMapping()));
+ l.add(new Tuple("halfwidth", getHalfWidth()));
+ l.add(new Tuple("resolution", getResolution()));
+ l.add(new Tuple("rangeInternal", getRangeInternal()));
+
+ return l;
+ }
+
+ /**
+ * Returns a {@link EncoderBuilder} for constructing {@link ScalarEncoder}s
+ *
+ * The base class architecture is put together in such a way where boilerplate
+ * initialization can be kept to a minimum for implementing subclasses, while avoiding
+ * the mistake-proneness of extremely long argument lists.
+ *
+ * @see ScalarEncoder.Builder#setStuff(int)
+ */
+ public static class Builder extends Encoder.Builder {
+ private Builder() {}
+
+ @Override
+ public ScalarEncoder build() {
+ //Must be instantiated so that super class can initialize
+ //boilerplate variables.
+ encoder = new ScalarEncoder();
+
+ //Call super class here
+ super.build();
+
+ ////////////////////////////////////////////////////////
+ // Implementing classes would do setting of specific //
+ // vars here together with any sanity checking //
+ ////////////////////////////////////////////////////////
+
+ ((ScalarEncoder)encoder).init();
+
+ return (ScalarEncoder)encoder;
+ }
+ }
}
From 932b4d6cafc1dfbf656adfa4cbb0ef6c7e07220e Mon Sep 17 00:00:00 2001
From: Andrew Dillon
Date: Sun, 5 Mar 2017 12:06:38 -0600
Subject: [PATCH 13/52] Added KEY.INFERRED_FIELDS and modified Layer to utilize
that new parameters
---
.../java/org/numenta/nupic/Parameters.java | 21 +--
.../java/org/numenta/nupic/network/Layer.java | 122 ++++++++++--------
.../numenta/nupic/network/ManualInput.java | 12 +-
3 files changed, 90 insertions(+), 65 deletions(-)
diff --git a/src/main/java/org/numenta/nupic/Parameters.java b/src/main/java/org/numenta/nupic/Parameters.java
index 001f8160..ad852e2a 100644
--- a/src/main/java/org/numenta/nupic/Parameters.java
+++ b/src/main/java/org/numenta/nupic/Parameters.java
@@ -23,19 +23,13 @@
package org.numenta.nupic;
import java.io.IOException;
-import java.util.Arrays;
-import java.util.Collections;
-import java.util.EnumMap;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
-import java.util.Random;
-import java.util.Set;
+import java.util.*;
import org.numenta.nupic.algorithms.Classifier;
import org.numenta.nupic.algorithms.SpatialPooler;
import org.numenta.nupic.algorithms.TemporalMemory;
import org.numenta.nupic.model.Cell;
+import org.numenta.nupic.model.Segment;
import org.numenta.nupic.model.Column;
import org.numenta.nupic.model.ComputeCycle;
import org.numenta.nupic.model.DistalDendrite;
@@ -69,6 +63,7 @@ public class Parameters implements Persistable {
private static final Map DEFAULTS_TEMPORAL;
private static final Map DEFAULTS_SPATIAL;
private static final Map DEFAULTS_ENCODER;
+ private static final Map DEFAULTS_CLASSIFIER;
static {
@@ -141,6 +136,12 @@ public class Parameters implements Persistable {
DEFAULTS_ENCODER = Collections.unmodifiableMap(defaultEncoderParams);
defaultParams.putAll(DEFAULTS_ENCODER);
+ /////////// Classifier Parameters ///////////
+ Map defaultClassifierParams = new ParametersMap();
+ defaultClassifierParams.put(KEY.INFERRED_FIELDS, new HashMap>());
+ DEFAULTS_CLASSIFIER = Collections.unmodifiableMap(defaultClassifierParams);
+ defaultParams.putAll(DEFAULTS_CLASSIFIER);
+
DEFAULTS_ALL = Collections.unmodifiableMap(defaultParams);
}
@@ -418,7 +419,9 @@ public static enum KEY {
// Network Layer indicator for auto classifier generation
AUTO_CLASSIFY("hasClassifiers", Boolean.class),
- INFERRED_FIELDS("inferredFields", Map.class), // Map
// How many bits to use if encoding the respective date fields.
// e.g. Tuple(bits to use:int, radius:double)
diff --git a/src/main/java/org/numenta/nupic/network/Layer.java b/src/main/java/org/numenta/nupic/network/Layer.java
index 1577f7b3..96e0ed28 100644
--- a/src/main/java/org/numenta/nupic/network/Layer.java
+++ b/src/main/java/org/numenta/nupic/network/Layer.java
@@ -37,11 +37,7 @@
import org.numenta.nupic.FieldMetaType;
import org.numenta.nupic.Parameters;
import org.numenta.nupic.Parameters.KEY;
-import org.numenta.nupic.algorithms.Anomaly;
-import org.numenta.nupic.algorithms.CLAClassifier;
-import org.numenta.nupic.algorithms.Classification;
-import org.numenta.nupic.algorithms.SpatialPooler;
-import org.numenta.nupic.algorithms.TemporalMemory;
+import org.numenta.nupic.algorithms.*;
import org.numenta.nupic.encoders.DateEncoder;
import org.numenta.nupic.encoders.Encoder;
import org.numenta.nupic.encoders.EncoderTuple;
@@ -400,7 +396,7 @@ public Layer(Parameters params, MultiEncoder e, SpatialPooler sp, TemporalMemory
(encoder == null ? "" : "MultiEncoder,"),
(spatialPooler == null ? "" : "SpatialPooler,"),
(temporalMemory == null ? "" : "TemporalMemory,"),
- (autoCreateClassifiers == null ? "" : "Auto creating CLAClassifiers for each input field."),
+ (autoCreateClassifiers == null ? "" : "Auto creating Classifiers for each input field."),
(anomalyComputer == null ? "" : "Anomaly"));
}
}
@@ -1699,38 +1695,9 @@ private void doEncoderBucketMapping(Inference inference, Map enc
// Should probably change this so that instead of adding a mapping for
// each encoder, it adds a mapping for each field specified by the user
// in the new Parameters KEY.
-// for(EncoderTuple t : encoderTuples) {
-// String name = t.getName();
-// Encoder> e = t.getEncoder();
-//
-// int bucketIdx = -1;
-// Object o = encoderInputMap.get(name);
-// if(DateTime.class.isAssignableFrom(o.getClass())) {
-// bucketIdx = ((DateEncoder)e).getBucketIndices((DateTime)o)[0];
-// } else if(Number.class.isAssignableFrom(o.getClass())) {
-// bucketIdx = e.getBucketIndices((double)o)[0];
-// } else {
-// bucketIdx = e.getBucketIndices((String)o)[0];
-// }
-//
-// int offset = t.getOffset();
-// int[] tempArray = new int[e.getWidth()];
-// System.arraycopy(encoding, offset, tempArray, 0, tempArray.length);
-//
-// inference.getClassifierInput().put(name, new NamedTuple(new String[] { "name", "inputValue", "bucketIdx", "encoding" }, name, o, bucketIdx, tempArray));
-// }
-
- // Get fields user wants encoded from Parameters
- Map inferredFields = (Map)params.get(KEY.INFERRED_FIELDS);
-
- // Store a NamedTuple for each of those fields
- for(Map.Entry entry : inferredFields.entrySet()) {
- String fieldName = entry.getKey(); // Name of encoder input field
- EncoderTuple encoderTuple = encoderTuples.stream() // Get the EncoderTuple for this input field
- .filter(e -> e.getName().equals(fieldName))
- .collect(Collectors.toList())
- .get(0);
- Encoder> e = encoderTuple.getEncoder();
+ for(EncoderTuple t : encoderTuples) {
+ String name = t.getName();
+ Encoder> e = t.getEncoder();
int bucketIdx = -1;
Object o = encoderInputMap.get(name);
@@ -1742,19 +1709,53 @@ private void doEncoderBucketMapping(Inference inference, Map enc
bucketIdx = e.getBucketIndices((String)o)[0];
}
- int offset = encoderTuple.getOffset();
+ int offset = t.getOffset();
int[] tempArray = new int[e.getWidth()];
System.arraycopy(encoding, offset, tempArray, 0, tempArray.length);
- inference.getClassifierInput().put(
- name,
- new NamedTuple(new String[] { "name", "inputValue", "bucketIdx", "encoding" },
- name,
- o,
- bucketIdx,
- tempArray
- ));
+ inference.getClassifierInput().put(name, new NamedTuple(new String[] { "name", "inputValue", "bucketIdx", "encoding" }, name, o, bucketIdx, tempArray));
}
+
+// // Get fields user wants encoded from Parameters
+// Map inferredFields = (Map)params.get(KEY.INFERRED_FIELDS);
+// if(inferredFields == null) {
+// LOGGER.info("KEY.INFERRED_FIELDS is null, no fields will be classified.");
+// return;
+// }
+//
+// // Store a NamedTuple for each of those fields
+// for(Map.Entry entry : inferredFields.entrySet()) {
+// String fieldName = entry.getKey(); // Name of encoder input field
+// EncoderTuple encoderTuple = encoderTuples.stream() // Get the EncoderTuple for this input field
+// .filter(e -> e.getName().equals(fieldName))
+// .collect(Collectors.toList())
+// .get(0);
+// Encoder> e = encoderTuple.getEncoder();
+//
+// int bucketIdx = -1;
+// Object o = encoderInputMap.get(fieldName);
+// if(DateTime.class.isAssignableFrom(o.getClass())) {
+// bucketIdx = ((DateEncoder)e).getBucketIndices((DateTime)o)[0];
+// } else if(Number.class.isAssignableFrom(o.getClass())) {
+// bucketIdx = e.getBucketIndices((double)o)[0];
+// } else {
+// bucketIdx = e.getBucketIndices((String)o)[0];
+// }
+//
+// int offset = encoderTuple.getOffset();
+// int[] tempArray = new int[e.getWidth()];
+// System.arraycopy(encoding, offset, tempArray, 0, tempArray.length);
+//
+// inference.getClassifierInput().put(
+// fieldName,
+// new NamedTuple(
+// new String[] { "name", "inputValue", "bucketIdx", "encoding" },
+// fieldName,
+// o,
+// bucketIdx,
+// tempArray
+// ));
+// }
}
/**
@@ -1958,12 +1959,26 @@ NamedTuple makeClassifiers(MultiEncoder encoder) {
// Looks like the passed-in MultiEncoder is a single wrapper containing
// multiple encoders; one for each field. Right now, a classifier is
// created for each of those encoders. However, instead of do
+ Map inferredFields = (Map>) params.get(KEY.INFERRED_FIELDS);
String[] names = new String[encoder.getEncoders(encoder).size()];
- CLAClassifier[] ca = new CLAClassifier[names.length];
+ Classifier[] ca = new Classifier[names.length];
int i = 0;
for(EncoderTuple et : encoder.getEncoders(encoder)) {
names[i] = et.getName();
- ca[i] = new CLAClassifier();
+ Object fieldClassifier = inferredFields.get(et.getName());
+ if(fieldClassifier == CLAClassifier.class) {
+ LOGGER.info("Classifying \"" + et.getName() + "\" input field with CLAClassifier");
+ ca[i] = new CLAClassifier();
+ } else if(fieldClassifier == SDRClassifier.class) {
+ LOGGER.info("Classifying \"" + et.getName() + "\" input field with SDRClassifier");
+ ca[i] = new SDRClassifier();
+ } else {
+ if(fieldClassifier != null)
+ LOGGER.warn("Invalid Classifier class token, \"" + fieldClassifier +
+ "\", specified for, \"" + et.getName() + "\", input field. " +
+ "Valid class tokens are CLAClassifier.class and SDRClassifier.class");
+ LOGGER.info("Not classifying \"" + et.getName() + "\" input field");
+ }
i++;
}
return new NamedTuple(names, (Object[])ca);
@@ -2380,10 +2395,13 @@ public ManualInput call(ManualInput t1) {
bucketIdx = inputs.get("bucketIdx");
actValue = inputs.get("inputValue");
- CLAClassifier c = (CLAClassifier)t1.getClassifiers().get(key);
- Classification