diff --git a/contrib/swig/build.sbt b/contrib/swig/build.sbt
index db6185b02..a7dd0347d 100644
--- a/contrib/swig/build.sbt
+++ b/contrib/swig/build.sbt
@@ -1,3 +1,4 @@
+javaOptions in Test ++= Seq("-Xms1G","-XX:+CMSClassUnloadingEnabled","-XX:+UseConcMarkSweepGC")
 lazy val root = (project in file("."))
   .settings(
     name := "dynet_scala_helpers",
diff --git a/contrib/swig/dynet_swig.i b/contrib/swig/dynet_swig.i
index e7a349dfb..b1610d930 100644
--- a/contrib/swig/dynet_swig.i
+++ b/contrib/swig/dynet_swig.i
@@ -239,8 +239,9 @@ struct Parameter {
   Dim dim();
   Tensor* values();
-
+  Tensor* gradients();
   void set_updated(bool b);
+  void set_value(const std::vector<float>& val);
   bool is_updated();
 };
 
@@ -253,6 +254,7 @@ struct LookupParameter {
   std::vector<Tensor>* values();
   void set_updated(bool b);
   bool is_updated();
+  void set_value(const std::vector<float>& val);
 };
 
 struct ParameterInit {
@@ -341,6 +343,9 @@ struct ParameterStorage : public ParameterStorageBase {
   void accumulate_grad(const Tensor& g);
   void clear();
 
+  void set_value(const std::vector<float>& val);
+  Tensor* value();
+  Tensor* gradients();
   Dim dim;
   Tensor values;
   Tensor g;
@@ -360,6 +365,10 @@ struct LookupParameterStorage : public ParameterStorageBase {
   void accumulate_grads(unsigned n, const unsigned* ids_host, const unsigned* ids_dev, float* g);
   void clear();
 
+  void set_value(const std::vector<float>& val);
+  Tensor* get_all_values();
+  Tensor* get_all_grads();
+
   // Initialize each individual lookup from the overall tensors
   void initialize_lookups();
 };
@@ -411,6 +420,7 @@ std::vector<float> as_vector(const Tensor& v);
 
 struct TensorTools {
   static float access_element(const Tensor& v, const Dim& index);
+  static void zero(Tensor& d);
 };
 
 /////////////////////////////////////
@@ -435,6 +445,7 @@ struct Expression {
   VariableIndex i;
   Expression(ComputationGraph *pg, VariableIndex i) : pg(pg), i(i) { };
   const Tensor& value();
+  const Tensor& gradient();
   const Dim& dim() const { return pg->get_dimension(i); }
 };
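A minimal usage sketch, not part of the patch: the class below (GradientSketch, a hypothetical name) exercises the Expression::gradient() and Parameter bindings added above from Java. It assumes SWIG maps std::vector<float> to FloatVector and Tensor* to Tensor, as it does elsewhere in these bindings; every other call appears in the example file that follows.

    package edu.cmu.dynet.examples;

    import static edu.cmu.dynet.internal.dynet_swig.as_vector;
    import static edu.cmu.dynet.internal.dynet_swig.exprTimes;
    import static edu.cmu.dynet.internal.dynet_swig.initialize;
    import static edu.cmu.dynet.internal.dynet_swig.parameter;

    import edu.cmu.dynet.internal.*;

    public class GradientSketch {
      public static void main(String[] args) {
        initialize(new DynetParams());
        ParameterCollection model = new ParameterCollection();
        LongVector d = new LongVector();
        d.add(1);
        Parameter w = model.add_parameters(new Dim(d)); // a single scalar weight
        ComputationGraph cg = ComputationGraph.getNew();
        Expression we = parameter(cg, w);
        Expression loss = exprTimes(we, we);            // loss = w * w
        cg.forward(loss);
        cg.backward(loss);                              // populates gradients
        // New in this patch: gradients are reachable from Java.
        FloatVector g = as_vector(we.gradient());       // d(loss)/dw = 2w
        System.out.println("gradient[0] = " + g.get(0));
      }
    }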
edu.cmu.dynet.internal.SimpleRNNBuilder; +import edu.cmu.dynet.internal.SimpleSGDTrainer; +import edu.cmu.dynet.internal.Tensor; + +/** + * + * @author Allan (allanmcgrady@gmail.com) + * This follows the example in dynet RNN Tutorial: Character-level LSTM + * http://dynet.readthedocs.io/en/latest/tutorials_notebooks/RNNs.html + * + * Simply translating the code from Python to Java. + */ +public class LanguageModelExample { + + public static String characters = "abcdefghijklmnopqrstuvwxyz "; + public static Random rand = new Random(1234); + public int layers = 1; + public int inputDim = 50; + public int hiddenDim = 50; + + public LookupParameter ltp; + public Parameter r; + public Parameter bias; + public Expression lt; + public Expression re; + public Expression be; + + public ComputationGraph cg; + public SimpleRNNBuilder srnn; + public Map char2int; + public Map int2char; + public SimpleSGDTrainer sgd; + + public LanguageModelExample(int layers, int inputDim, int hiddenDim) { + //preprocessing to get the mapping between characters and index + this.char2int = new HashMap<>(); + this.int2char = new HashMap<>(); + for (int i = 0 ; i < characters.length(); i++) { + this.char2int.put(characters.substring(i, i+1), i); + this.int2char.put(i, characters.substring(i, i+1)); + } + this.char2int.put("", characters.length()); + this.int2char.put(characters.length(), ""); + this.layers = layers; + this.inputDim = inputDim; + this.hiddenDim = hiddenDim; + this.initializeModel(); + } + + private Dim makeDim(int[] dims) { + LongVector dimInts = new LongVector(); + for (int i = 0; i < dims.length; i++) { + dimInts.add(dims[i]); + } + return new Dim(dimInts); + } + + /** + * Initialize the model with parameters + * @return + */ + public ParameterCollection initializeModel() { + DynetParams dp = new DynetParams(); + dp.setRandom_seed(1234); + initialize(dp); + ParameterCollection model = new ParameterCollection(); + sgd = new SimpleSGDTrainer(model); + sgd.clip_gradients(); + sgd.setClip_threshold((float)5.0); + cg = ComputationGraph.getNew(); + srnn = new SimpleRNNBuilder(layers, inputDim, hiddenDim, model); + + ltp = model.add_lookup_parameters(characters.length() + 1, makeDim(new int[]{inputDim})); + r = model.add_parameters(makeDim(new int[]{characters.length() + 1, hiddenDim})); + bias = model.add_parameters(makeDim(new int[]{characters.length() + 1})); + lt = parameter(cg, ltp); + re = parameter(cg, r); + be = parameter(cg, bias); + return model; + } + + /** + * Build the RNN for a specific sequence + * @param sent + * @return + */ + public Expression buildForward(String sent) { + srnn.new_graph(cg); + srnn.start_new_sequence(); + ExpressionVector finalErr = new ExpressionVector(); + String last = ""; + String next = null; + for (int i = 0 ; i <= sent.length(); i++) { + Expression curr = lookup(cg, ltp, char2int.get(last)); + Expression curr_y = srnn.add_input(curr); + Expression curr_r = exprPlus(exprTimes(re, curr_y), be); + next = i == sent.length()? 
"" : sent.substring(i, i+1); + Expression curr_err = pickneglogsoftmax(curr_r, char2int.get(next)); + finalErr.add(curr_err); + last = next; + } + Expression lossExpr = sum_batches(sum((finalErr))); + return lossExpr; + } + + public int sample(FloatVector fv) { + float f = rand.nextFloat(); + int i = 0; + for (i = 0; i < fv.size(); i++) { + f -= fv.get(i); + if (f <=0 ) break; + } + return i; + } + + public String generateSentence() { + srnn.new_graph(cg); + srnn.start_new_sequence(); + Expression start = lookup(cg, ltp, char2int.get("")); + Expression s1 = srnn.add_input(start); + String out = ""; + while(true) { + Expression prob = softmax(exprPlus(exprTimes(re, s1), be) ); + int idx = sample(as_vector(cg.incremental_forward(prob))); + out += int2char.get(idx); + if (int2char.get(idx).equals("")) break; + s1 = srnn.add_input(lookup(cg, ltp, idx)); + } + return out; + } + + public static void main(String[] args) { + int layers = 1; + int inputDim = 50; + int hiddenDim = 50; + LanguageModelExample lm = new LanguageModelExample(layers, inputDim, hiddenDim); + String sent = "a quick brown fox jumped over the lazy dog"; + float loss = 0; + for (int it = 0; it < 100; it++) { + Expression lossExpr = lm.buildForward(sent); + Tensor lossTensor = lm.cg.forward(lossExpr); + loss = as_scalar(lossTensor); + lm.cg.backward(lossExpr); + lm.sgd.update(); + if (it % 5 == 0) { + System.out.print("loss is : " + loss); + String prediction = lm.generateSentence(); + System.out.println(" prediction: " + prediction); + } + lm.cg.forward(lossExpr); + lm.cg.backward(lossExpr); + lm.sgd.update(); + } + + } + +} diff --git a/dynet/model.cc b/dynet/model.cc index f81ba74e9..c0195c9e3 100644 --- a/dynet/model.cc +++ b/dynet/model.cc @@ -122,6 +122,7 @@ void ParameterStorage::set_value(const std::vector& val) { TensorTools::set_elements(values, val); } + bool valid_parameter(const std::string & s) { auto it = std::find_if(s.begin(), s.end(), [] (char ch) { return ch == '/' || ch == '_'; }); return it == s.end(); @@ -241,6 +242,10 @@ void LookupParameter::initialize(unsigned index, const std::vector& val) get_storage().initialize(index, val); } +void LookupParameter::set_value(const std::vector& val){ + get_storage().set_value(val); +} + string LookupParameter::get_fullname() const { DYNET_ASSERT(p != nullptr, "Attempt to get pointer for null parameter"); return p->name; @@ -803,6 +808,11 @@ void LookupParameterStorage::scale_gradient(float a) { } #endif + +void LookupParameterStorage::set_value(const std::vector& val) { + TensorTools::set_elements(all_values, val); +} + template float ParameterCollectionStorage::gradient_l2_norm_dev(MyDevice &dev) const { auto scratch_size = (all_params.size() + 1) * sizeof(float); diff --git a/dynet/model.h b/dynet/model.h index 2b49406c3..714c83b3d 100644 --- a/dynet/model.h +++ b/dynet/model.h @@ -137,6 +137,14 @@ struct ParameterStorage : public ParameterStorageBase { */ void clip(float left, float right); void set_value(const std::vector& val); + + /** + * \brief gradients of the parameter + * + * \return gradients as a `Tensor` object + */ + Tensor* gradients() { return &g; } + Tensor* value() { return &values; } Dim dim; /**< Dimensions of the parameter tensor*/ @@ -238,7 +246,10 @@ struct LookupParameterStorage : public ParameterStorageBase { */ void accumulate_grads(unsigned n, const unsigned* ids_host, const unsigned* ids_dev, float* g); void clear(); + void set_value(const std::vector& val); + Tensor* get_all_grads() { return &all_grads; } + Tensor* get_all_values() { 
diff --git a/dynet/model.cc b/dynet/model.cc
index f81ba74e9..c0195c9e3 100644
--- a/dynet/model.cc
+++ b/dynet/model.cc
@@ -122,6 +122,7 @@ void ParameterStorage::set_value(const std::vector<float>& val) {
   TensorTools::set_elements(values, val);
 }
 
+
 bool valid_parameter(const std::string & s) {
   auto it = std::find_if(s.begin(), s.end(), [] (char ch) { return ch == '/' || ch == '_'; });
   return it == s.end();
@@ -241,6 +242,10 @@ void LookupParameter::initialize(unsigned index, const std::vector<float>& val)
   get_storage().initialize(index, val);
 }
 
+void LookupParameter::set_value(const std::vector<float>& val) {
+  get_storage().set_value(val);
+}
+
 string LookupParameter::get_fullname() const {
   DYNET_ASSERT(p != nullptr, "Attempt to get pointer for null parameter");
   return p->name;
@@ -803,6 +808,11 @@ void LookupParameterStorage::scale_gradient(float a) {
 }
 #endif
 
+
+void LookupParameterStorage::set_value(const std::vector<float>& val) {
+  TensorTools::set_elements(all_values, val);
+}
+
 template <class MyDevice>
 float ParameterCollectionStorage::gradient_l2_norm_dev(MyDevice &dev) const {
   auto scratch_size = (all_params.size() + 1) * sizeof(float);
diff --git a/dynet/model.h b/dynet/model.h
index 2b49406c3..714c83b3d 100644
--- a/dynet/model.h
+++ b/dynet/model.h
@@ -137,6 +137,14 @@ struct ParameterStorage : public ParameterStorageBase {
   */
   void clip(float left, float right);
   void set_value(const std::vector<float>& val);
+
+  /**
+   * \brief Gradients of the parameter
+   *
+   * \return Gradients as a `Tensor` object
+   */
+  Tensor* gradients() { return &g; }
+  Tensor* value() { return &values; }
 
   Dim dim;  /**< Dimensions of the parameter tensor*/
@@ -238,7 +246,10 @@ struct LookupParameterStorage : public ParameterStorageBase {
   */
   void accumulate_grads(unsigned n, const unsigned* ids_host, const unsigned* ids_dev, float* g);
   void clear();
+  void set_value(const std::vector<float>& val);
+  Tensor* get_all_grads() { return &all_grads; }
+  Tensor* get_all_values() { return &all_values; }
 
   // Initialize each individual lookup from the overall tensors
   void initialize_lookups();
@@ -457,6 +468,8 @@
   * @return Update status
   */
   bool is_updated();
+
+  void set_value(const std::vector<float>& val);
 }; // struct LookupParameter
 
 // This is an internal class to store parameters in the collection
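A round-trip sketch for the setter/getter pair declared above (SetValueSketch is a hypothetical class, again assuming the SWIG FloatVector mapping): a value written through Parameter.set_value, backed by ParameterStorage::set_value, should read back unchanged through values().

    package edu.cmu.dynet.examples;

    import static edu.cmu.dynet.internal.dynet_swig.as_vector;
    import static edu.cmu.dynet.internal.dynet_swig.initialize;

    import edu.cmu.dynet.internal.*;

    public class SetValueSketch {
      public static void main(String[] args) {
        initialize(new DynetParams());
        ParameterCollection model = new ParameterCollection();
        LongVector d = new LongVector();
        d.add(2);                                  // a 2-element parameter vector
        Parameter p = model.add_parameters(new Dim(d));
        FloatVector vals = new FloatVector();
        vals.add(0.25f);
        vals.add(-1.5f);
        p.set_value(vals);                         // new binding added above
        FloatVector back = as_vector(p.values()); // expect 0.25, -1.5
        System.out.println(back.get(0) + " " + back.get(1));
      }
    }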