예제 #1
0
 /// <summary>
 /// Reads the string attribute called <paramref name="name"/> from an HDF5 group.
 /// </summary>
 /// <param name="group">Handle of the HDF5 group that is queried.</param>
 /// <param name="name">Name of the attribute to read.</param>
 /// <returns>
 /// The attribute values as an array, or <c>null</c> when the attribute does not
 /// exist or the read reports failure.
 /// </returns>
 public static string[] load_attributes_from_hdf5_group(long group, string name)
 {
     // Nothing to read: signal absence with null, which is what callers receive today.
     if (!Hdf5.AttributeExists(group, name))
     {
         return null;
     }

     var (ok, values) = Hdf5.ReadStringAttributes(group, name, "");
     return ok ? values.ToArray() : null;
 }
예제 #2
0
        /// <summary>
        /// Loads layer weights from an HDF5 group and assigns them to the matching layers
        /// via <c>keras.backend.batch_set_value</c>.
        /// </summary>
        /// <param name="f">
        /// Handle of the HDF5 file/group. It is read for the optional "keras_version" and
        /// "backend" attributes, the "layer_names" attribute, and one sub-group per layer
        /// name carrying a "weight_names" attribute.
        /// </param>
        /// <param name="layers">
        /// Layers of the target model. Layers for which <c>_legacy_weights</c> returns no
        /// weights are skipped.
        /// </param>
        /// <returns>The (variable, value) pairs that were handed to <c>batch_set_value</c>.</returns>
        /// <exception cref="ValueError">
        /// Thrown when the stored keras version is older than 2.5.0, when the number of
        /// stored layers with weights differs from the number of model layers with weights,
        /// or when a layer's stored weight count differs from its expected weight count.
        /// </exception>
        public static List <(IVariableV1, NDArray)> load_weights_from_hdf5_group(long f, List <ILayer> layers)
        {
            string original_keras_version = "2.5.0";
            string original_backend       = null;

            if (Hdf5.AttributeExists(f, "keras_version"))
            {
                var (success, attr) = Hdf5.ReadStringAttributes(f, "keras_version", "");
                if (success)
                {
                    original_keras_version = attr.First();
                }
                // keras version should be 2.5.0+
                var version_parts = original_keras_version.Split('.');
                var ver_major     = int.Parse(version_parts[0]);
                var ver_minor     = int.Parse(version_parts[1]);
                if (ver_major < 2 || (ver_major == 2 && ver_minor < 5))
                {
                    throw new ValueError("keras version should be 2.5.0 or later.");
                }
            }
            if (Hdf5.AttributeExists(f, "backend"))
            {
                var (success, attr) = Hdf5.ReadStringAttributes(f, "backend", "");
                if (success)
                {
                    original_backend = attr.First();
                }
            }

            // Only layers that actually own weights take part in the matching below.
            var filtered_layers = new List <ILayer>();
            foreach (var layer in layers)
            {
                var weights = _legacy_weights(layer);
                if (weights.Count > 0)
                {
                    filtered_layers.Add(layer);
                }
            }

            // Build the name set once instead of re-projecting and scanning it per stored name.
            var model_layer_names = new HashSet <string>(filtered_layers.Select(x => x.Name));

            // The helper returns null when the attribute is missing or unreadable; treat that
            // as "no stored layers" so the count check below reports a meaningful error
            // instead of a NullReferenceException.
            string[] layer_names          = load_attributes_from_hdf5_group(f, "layer_names") ?? new string[0];
            var      filtered_layer_names = new List <string>();

            foreach (var name in layer_names)
            {
                if (!model_layer_names.Contains(name))
                {
                    continue;
                }
                long g = H5G.open(f, name);
                try
                {
                    var weight_names = load_attributes_from_hdf5_group(g, "weight_names");
                    if (weight_names != null && weight_names.Length > 0)
                    {
                        filtered_layer_names.Add(name);
                    }
                }
                finally
                {
                    // Release the group handle even if reading the attribute throws.
                    H5G.close(g);
                }
            }

            layer_names = filtered_layer_names.ToArray();
            if (layer_names.Length != filtered_layers.Count)
            {
                // Report the number of stored layers (interpolating the array itself would
                // only print its type name).
                throw new ValueError("You are trying to load a weight file " +
                                     $"containing {layer_names.Length}" +
                                     $" layers into a model with {filtered_layers.Count} layers.");
            }

            var weight_value_tuples = new List <(IVariableV1, NDArray)>();

            for (int k = 0; k < layer_names.Length; k++)
            {
                var  name          = layer_names[k];
                var  weight_values = new List <NDArray>();
                long g             = H5G.open(f, name);
                try
                {
                    var weight_names = load_attributes_from_hdf5_group(g, "weight_names");
                    if (weight_names != null)
                    {
                        foreach (var i_ in weight_names)
                        {
                            (bool success, Array result) = Hdf5.ReadDataset <float>(g, i_);
                            if (success)
                            {
                                weight_values.Add(np.array(result));
                            }
                        }
                    }
                }
                finally
                {
                    // Release the group handle even if a dataset read throws.
                    H5G.close(g);
                }

                var layer            = filtered_layers[k];
                var symbolic_weights = _legacy_weights(layer);
                preprocess_weights_for_loading(layer, weight_values, original_keras_version, original_backend);
                if (weight_values.Count != symbolic_weights.Count)
                {
                    throw new ValueError($"Layer #{k} (named {layer.Name} " +
                                         "in the current model) was found to " +
                                         $"correspond to layer {name} in the save file. " +
                                         $"However the new layer {layer.Name} expects " +
                                         $"{symbolic_weights.Count} weights, but the saved weights have " +
                                         $"{weight_values.Count} elements.");
                }
                weight_value_tuples.AddRange(zip(symbolic_weights, weight_values));
            }

            keras.backend.batch_set_value(weight_value_tuples);
            return weight_value_tuples;
        }