UDF - keeping old name of band in datacube after renaming

Hello all,

I created a new task but this is very much related to this

I’m still encountering a problem with the output bands from the UDF: they are not being recognized in subsequent steps.

This particular section of the code is experiencing a failure due to the existence of an output band named “DEM” which is part of the UDF. It is worth noting that the “DEM” output from the UDF has been applied to the s2_cube which involves applying a shadow mask for each scene.

 def s1_s2_water_function(data):
            """Evaluate the "S1_S2" formula of the selected zone.

            The elements of ``data`` are used positionally:
            index 0 is passed as ``vv``, index 1 as ``ndvi`` and index 2 as ``ndwi``.
            """
            return LOOKUPTABLE[zone]["S1_S2"](vv=data[0], ndvi=data[1], ndwi=data[2])
        
 # THE PROBLEM STARTS HERE, see the full code below:
# Keep only the requested period and reduce the time dimension with a median.
s2_cube_median = s2_cube.filter_temporal([start_date, end_date]).median_time()
# NOTE(review): the variable name says NDVI but the band selected is "NDWI" - confirm which one is intended.
s2_cube_median_NDVI = s2_cube_median.band("NDWI")  
        
# Save the selected band as netCDF through a batch job and download the result files
# into the local 's2_cube_median_NDVI' folder.
s1_s2_water_save = s2_cube_median_NDVI.save_result(format='netCDF') #GTiff #netCDF
my_job  = s1_s2_water_save.send_job(title="s2_cube_median_NDVI")
results = my_job.start_and_wait().get_results()
results.download_files('s2_cube_median_NDVI')

Even after renaming the bands in the output, the code continues to fail and raises an error regarding the presence of “DEM,” which should have been removed from the cube.

# Assign explicit labels to the "bands" dimension of s2_cube (eight labels, ending with NDWI and NDVI).
s2_cube = s2_cube.rename_labels("bands", ["B02", "B03", "B04", "B08", "sunAzimuthAngles", "sunZenithAngles","NDWI", "NDVI"])

Error message:
OpenEoApiError: [500] Internal: Server error: ValueError("Invalid band name/index 'NDVI'. Valid names: ['DEM']") (ref: r-942a6f9a124541e9be5ea8ae626ffb8b)

The full code follows. An intermediate output is sent as a separate job in the middle of it, because the full code does not work.


#### Define all widgets
# Ecoregion selector; WWT() reads the chosen label from zone_w.value.
zone_w = widgets.RadioButtons(
    description='Ecoregions',
    options=[
        'Deserts',
        'Mountain',
        'Tropical forest',
        'Tropical savanna',
        'Subtropical forest',
        'Subtropical savanna',
        'Temperate broadleaf',
        'Temperate grassland',
    ],
    layout={'width': 'max-content'},
    disabled=False,
)

# Start and end of the analysis period.
start_date_w = widgets.DatePicker(
    value=date(2021, 5, 1),
    description='Start Date',
    disabled=False,
)

end_date_w = widgets.DatePicker(
    value=date(2021, 6, 1),
    description='End Date',
    disabled=False,
)

# Integer sliders; WWT() reads both through their .value attribute.
threshold = widgets.IntSlider(description='Threshold', value=75)
threshold_cloud_cover = widgets.IntSlider(description='Cloud Cover', value=99)

#### Define Map
# Satellite basemap with a drawing toolbar that exports the drawn shapes as 'aoi.geojson'.
map = folium.Map(
    location=[19.462, -99.95],
    tiles=None,
    zoom_start=12.54,
).add_to(folium.Figure(height=800))
tile_layer = folium.TileLayer(
    tiles="https://server.arcgisonline.com/ArcGIS/rest/services/World_Imagery/MapServer/tile/{z}/{y}/{x}",
    attr="Tiles © Esri — Source: Esri, i-cubed, USDA, USGS, AEX, GeoEye, Getmapping, Aerogrid, IGN, IGP, UPR-EGP, and the GIS User Community",
    name='Satellite',
).add_to(map)
draw = plugins.Draw(
    export=True,
    filename='aoi.geojson',
    position='topleft',
).add_to(map)

#### Show widgets
display(map)
display(start_date_w)
display(end_date_w)
display(zone_w)
display(threshold)
display(threshold_cloud_cover)


class LoadedButton(widgets.Button):
    """Button widget that additionally carries a ``value`` trait."""

    def __init__(self, value=None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Register ``value`` as a trait that accepts any object, initialised
        # with the value handed to the constructor.
        value_trait = traitlets.Any(value)
        self.add_traits(value=value_trait)

# 'Run' button with an (initially empty) value; WWT() disables it once a run starts.
get_data_button = LoadedButton(
    value='',
    description='Run',
    tooltip='Click me',
    icon='check',
    button_style='',
    disabled=False,
)
        

def WWT(b):
    """Run the water-mapping workflow for the drawn AOI and display the result.

    The AOI is read from ``aoi.geojson``; the period, ecoregion, water
    threshold and cloud-cover limit come from the module-level widgets.
    The function builds the Sentinel-2 / Sentinel-1 openEO process graph,
    runs it as batch jobs, downloads the GeoTIFF, reprojects it to EPSG:4326
    and overlays it on a folium map.

    Args:
        b: Unused; accepted so the function can be called like a button
            callback.

    Returns:
        Path of the reprojected GeoTIFF (``<output_name>/openEO_wgs.tif``),
        or ``None`` when no usable AOI file is available.
    """
    file = 'aoi.geojson'
    # The reader may not raise ValueError for a missing file with every I/O
    # engine, so test for the file explicitly instead of discarding the check.
    if not os.path.isfile(file):
        print('Please insert AOI using drawing toolbox')
        return None

    try:
        EsriImagery = "https://server.arcgisonline.com/ArcGIS/rest/services/World_Imagery/MapServer/tile/{z}/{y}/{x}"
        EsriAttribution = "Tiles © Esri — Source: Esri, i-cubed, USDA, USGS, AEX, GeoEye, Getmapping, Aerogrid, IGN, IGP, UPR-EGP, and the GIS User Community"
        map = folium.Map(location= [19.462,-99.95], tiles= None, zoom_start=12.54).add_to(folium.Figure(height = 800))
        tile_layer = folium.TileLayer( tiles = EsriImagery, attr = EsriAttribution, name = 'Satellite',).add_to(map)
        draw = plugins.Draw(export=True,  filename='aoi.geojson', position='topleft').add_to(map)
        gdf = gpd.read_file(file)
        gdf_folium = folium.GeoJson(data=gdf["geometry"], name ='geojson').add_to(map)
        bbox = gdf.geometry.total_bounds
        map.fit_bounds(gdf_folium.get_bounds())
    except ValueError:
        print('Please insert AOI using drawing toolbox')
        return None

    zone                     = zone_w.value
    get_data_button.disabled = True
    spatial_extent           = {'west':bbox[0],'east':bbox[2],'south':bbox[1],'north':bbox[3],'crs':4326}
    start_date               = start_date_w.value
    end_date                 = end_date_w.value
    # Sentinel-2 is loaded from one month before the start date.
    start_date_exclusion     = (start_date  + relativedelta(months = -1))

    # Logistic models per ecoregion: probability = 1 / (1 + exp(-(linear term))).
    # NOTE(review): the "S2" models of the two tropical zones are multiplied by
    # 100 while all other models return values in 0..1 - confirm this is intended.
    LOOKUPTABLE = {

        "Deserts": {
            "S1": lambda vh, vv: 1 / (1 + exp(- (-7.03 + (-0.44 * vv)))),
            "S2": lambda ndvi, ndwi: 1 / (1 + exp(- (0.133 + (-5.92 * ndvi) + (14.82 * ndwi)))),
            "S1_S2": lambda vv, ndvi, ndwi: 1 / (1 + exp(- (-3.69 + (-0.25 * vv) + (0.47 * ndvi) + (15.3 * ndwi)))),
        },
        "Mountain": {
            "S1": lambda vh, vv: 1 / (1 + exp(- (-3.76 + (-0.262 * vv)))),
            "S2": lambda ndvi, ndwi: 1 / (1 + exp(- (0.262 + (0.75 * ndvi) + (12.65 * ndwi)))),
            "S1_S2": lambda vv, ndvi, ndwi: 1 / (1 + exp(- (-1.13 + (-0.11 * vv) + (3.03 * ndvi) + (13.21 * ndwi)))),
        },
        "Tropical forest": {
            "S1": lambda vh, vv: (1 / (1 + exp(- (-5.8 + (-0.415 * vv))))),
            "S2": lambda ndvi, ndwi: (1 / (1 + exp(- (0.344 + (2.886 * ndvi) + (11.91 * ndwi)))))*100,
            "S1_S2": lambda vv, ndvi, ndwi: (1 / (1 + exp(- (-3.25 + (-0.23 * vv) + (4.17 * ndvi) + (9.5 * ndwi))))),
        },
        "Tropical savanna": {
            "S1": lambda vh, vv: (1 / (1 + exp(- (-7.0 + (-0.444 * vv))))),
            "S2": lambda ndvi, ndwi: (1 / (1 + exp(- (0.344 + (2.886 * ndvi) + (11.91 * ndwi)))))*100,
            "S1_S2": lambda vv, ndvi, ndwi: (1 / (1 + exp(- (-1.06 + (-0.17 * vv) + (3.82* ndvi) + (14.4* ndwi))))),
        },
        "Subtropical savanna": {
            "S1": lambda vh, vv: 1 / (1 + exp(- (-7.17 + (-0.48 * vv)))),
            "S2": lambda ndvi, ndwi: 1 / (1 + exp(- (0.845 + (2.14 * ndvi) + (13.5 * ndwi)))),
            "S1_S2": lambda vv, ndvi, ndwi: 1 / (1 + exp(- (-2.64 + (-0.23 * vv) + (8.6 * ndwi)))),
        },
        "Subtropical forest": {
            "S1": lambda vh, vv: 1 / (1 + exp(- (-6.67 + (-6.67* vv)))),
            "S2": lambda ndvi, ndwi: 1 / (1 + exp(- (0.712 + (-1.133 * ndvi) + (7.16 * ndwi)))),
            "S1_S2": lambda vv, ndvi, ndwi: 1 / (1 + exp(- (-2.72 + (-0.22 * vv) + (-0.49  * ndvi) + 4.55 * ndwi))),
        },
        # In the two temperate zones the closing parentheses were misplaced, so
        # part of the linear term ended up outside the negation (or outside
        # exp() altogether). All terms now sit inside -( ... ) like the other zones.
        "Temperate broadleaf": {
            "S1": lambda  vh, vv: 1 / (1 + exp(- (-8.82 + (-0.58 * vv)))),
            "S2": lambda ndvi, ndwi: 1 / (1 + exp(- (-0.013 + (5.38 * ndvi) + (13.79 * ndwi)))),
            "S1_S2": lambda vv, ndvi, ndwi: 1 / (1 + exp(- (-2.7 + (-0.2 * vv) + (3.6 * ndvi) + (9.73 * ndwi)))),
        },
        "Temperate grassland": {
            "S1": lambda  vh, vv: 1 / (1 + exp(- (-7.01 + (-0.426 * vv)))),
            "S2": lambda ndvi, ndwi: 1 / (1 + exp(- (1.286 + (8.74 * ndvi) + (23.217 * ndwi)))),
            "S1_S2": lambda vv, ndvi, ndwi: 1 / (1 + exp(- (-3.43 + (-0.25 * vv) + (11.74 * ndvi) + (22.035 * ndwi)))),
        },
    }

    s2_properties = {"eo:cloud_cover": lambda v: v <= threshold_cloud_cover.value}

    s2_cube = connection.load_collection(
        'SENTINEL2_L2A_SENTINELHUB',
        spatial_extent=spatial_extent,
        temporal_extent=[start_date_exclusion, end_date],
        bands=['B02', 'B03', 'B04', 'B08', 'sunAzimuthAngles', 'sunZenithAngles'],
        properties=s2_properties
    )

    s2_cube_masking = connection.load_collection(
        'SENTINEL2_L2A_SENTINELHUB',
        spatial_extent=spatial_extent,
        temporal_extent=[start_date_exclusion, end_date],
        bands=['CLP', 'SCL'],
        properties=s2_properties
    )

    scl = s2_cube_masking.band("SCL")
    mask_scl = (scl == 3) | (scl == 8) | (scl == 9) | (scl == 10) | (scl == 11)

    clp = s2_cube_masking.band("CLP")
    # `|` binds tighter than `>`, so the comparison needs its own parentheses;
    # otherwise the expression is evaluated as (mask_scl | clp / 255) > 0.3.
    mask_clp = mask_scl | ((clp / 255) > 0.3)

    # Start hillshade function
    dem_cube = connection.load_collection("COPERNICUS_30",
                                spatial_extent = spatial_extent,
                                temporal_extent=["2010-01-01", "2030-12-31"],)

    # WGS84 30m
    dem_cube = dem_cube.max_time()
    # Resample s2 cube (Azimuth and Zenith) to 30m
    s2_cube_30 = s2_cube.resample_spatial(resolution = 30, method = 'average')
    # DEM WGS84 to DEM UTM from S2 cube
    dem_cube_s2 = dem_cube.resample_cube_spatial(s2_cube_30)
    # Merge 30m DEM and 30m s2 cube (needed for Azimuth and Zenith)
    merged_cube = s2_cube_30.merge_cubes(dem_cube_s2)

    # UDF source that imports the hillshade package. The string is sent to the
    # backend unchanged.
    # NOTE(review): inside the UDF, rasterize() returns `ydir*signx` although
    # `signy` is computed, and `signx`/`signy` are undefined when `resolution`
    # is None; apply_datacube() passes `res_x` twice to _run_shader() although
    # `res_y` is computed; `sun_zenith` is multiplied by 3. These look like
    # typos or leftovers - confirm before changing, the UDF runs server-side.
    udf_code = """
        
from openeo.udf import XarrayDataCube
from openeo.udf.debug import inspect
import numpy as np
from hillshade.hillshade import hillshade


def rasterize(azimuth, resolution=None):
    azimuth = np.deg2rad(azimuth)
    xdir, ydir = np.sin(azimuth), np.cos(azimuth)

    if resolution is not None:
        xdir = xdir * resolution[0]
        ydir = ydir * resolution[1]
        signx = np.sign(xdir)
        signy = np.sign(ydir)
    slope = abs(ydir / xdir)
    
    if slope < 1. and slope > -1.:
        xdir = 1.
        ydir = slope
    else:
        xdir = 1. / slope
        ydir = 1.
        
    return xdir*signx, ydir*signx


def _run_shader(sun_zenith, sun_azimuth, elevation_model, resolution_x, resolution_y):

    azimuth = np.nanmean(sun_azimuth.astype(np.float32))
    zenith = np.nanmean(sun_zenith.astype(np.float32))
    if np.isnan(azimuth):
        shadow = np.zeros(elevation_model.shape) + 255
    else:
        resolution = (float(resolution_x), float(resolution_y))
        ray_xdir, ray_ydir = rasterize(azimuth, resolution)
    
        # Assume chunking is already done by Dask
        ystart = 0
        yend = elevation_model.shape[0]

        # Make sure inputs have the right data type
        zenith = float(zenith)
        ray = (float(ray_xdir), float(ray_ydir))
        shadow = hillshade(elevation_model.astype(np.float32),
                           resolution,
                           zenith,
                           ray,
                           ystart,
                           yend)
        shadow = shadow.reshape(elevation_model.shape)
        shadow[np.isnan(sun_azimuth)] = 255
    return shadow


def apply_datacube(cube: XarrayDataCube, context: dict) -> XarrayDataCube:
    in_xarray = cube.get_array()
    sun_zenith = in_xarray.sel({"bands": "sunZenithAngles"}).values.astype(np.float32)
    sun_azimuth = in_xarray.sel({"bands": "sunAzimuthAngles"}).values.astype(np.float32)
    elevation_model = in_xarray.sel({"bands": "DEM"}).values.astype(np.float32)
    res_y = in_xarray.coords["y"][int(len(in_xarray.coords["y"])/2)+1] - in_xarray.coords["y"][int(len(in_xarray.coords["y"])/2)]
    res_x = in_xarray.coords["x"][int(len(in_xarray.coords["x"])/2)+1] - in_xarray.coords["x"][int(len(in_xarray.coords["x"])/2)]
    
    sun_zenith = sun_zenith *3
    
    shadow = _run_shader(sun_zenith, sun_azimuth, elevation_model, res_x, res_x)
    cube.get_array().values[0] = shadow
    
    return cube
    
"""

    process = openeo.UDF(code = udf_code, runtime="Python")

    # Overlap values are given as numbers, like the chunk sizes.
    hillshade = merged_cube.apply_neighborhood(process=process,
                                        size=[{"dimension":"t","value": "P1D"},
                                              {"dimension": "x", "unit": "px", "value": 256},
                                              {"dimension": "y", "unit": "px", "value": 256}],
                                        overlap=[{"dimension": "x", "unit": "px", "value": 8},
                                                 {"dimension": "y", "unit": "px", "value": 8}])

    # Rename bands in the hillshade cube; the first label is used for the shadow mask.
    hillshade = hillshade.rename_labels("bands", ["hillshade_mask", "B03", "B04", "B08", "sunAzimuthAngles", "sunZenithAngles","DEM"])

    # Select the hillshade band and resample it to the grid of the s2 cube
    hillshade_mask = hillshade.band("hillshade_mask").resample_cube_spatial(s2_cube)
    # Get binary result of the hillshade mask
    hillshade_mask_binary = hillshade_mask == 1

    # Mask s2 cube with the hillshade mask
    s2_cube_hillshade = s2_cube.mask(hillshade_mask_binary)
    # Mask s2 cube with the cloud mask
    s2_cube = s2_cube_hillshade.mask(mask_clp.resample_cube_spatial(s2_cube_hillshade))

    # Replace 0 with no-data in the s2 cube
    s2_cube = s2_cube.mask(s2_cube.apply(lambda x: x.eq(0)), replacement = None)

    # End of hillshade functions

    # NDVI and NDWI calculation
    s2_cube = append_indices(s2_cube, ["NDWI","NDVI"])
    s2_cube = s2_cube.rename_labels("bands", ["B02", "B03", "B04", "B08", "sunAzimuthAngles", "sunZenithAngles","NDWI", "NDVI"])

    def water_function(data):
        """Apply the zone's "S2" model; indices 6 and 7 follow the labels set above."""
        return LOOKUPTABLE[zone]["S2"](ndwi=data[6], ndvi=data[7])

    s2_cube_water = s2_cube.reduce_dimension(reducer=water_function, dimension="bands")
    s2_cube_water = s2_cube_water.add_dimension("bands", "water_prob", type="bands")

    s2_cube_water_threshold = s2_cube_water.apply_dimension(dimension="bands", process=lambda x: if_(x > 0.75, x, 0))
    s2_cube_water_threshold = s2_cube_water_threshold.rename_labels("bands", ["w_T75"])

    # SwF: sum of thresholded water probabilities divided by the observation count
    s2_cube_water_sum = s2_cube_water_threshold.reduce_dimension(reducer="sum", dimension="t")
    s2_cube_water_sum = s2_cube_water_sum.rename_labels("bands", ["sum"])

    s2_count = s2_cube.band("B08")
    s2_count = s2_count.reduce_dimension(reducer=lambda data: data.count(), dimension="t")

    s2_cube_swf = s2_cube_water_sum.resample_cube_spatial(s2_count) / s2_count
    s2_cube_swf = s2_cube_swf.rename_labels("bands", ["swf"])

    s2_median_water = s2_cube_water.filter_temporal([start_date, end_date]).median_time()
    s2_cube_median = s2_cube.filter_temporal([start_date, end_date]).median_time()
    s1_cube = connection.load_collection(
        'SENTINEL1_GRD',
        spatial_extent=spatial_extent,
        temporal_extent=[start_date, end_date],
        bands=['VH', 'VV'],
        properties={"polarization": lambda p: p == "DV"})

    s1_cube = s1_cube.sar_backscatter(coefficient="gamma0-terrain", mask=True, elevation_model="COPERNICUS_30")

    s1_cube = s1_cube.rename_labels("bands", ["VH", "VV", "mask", "incidence_angle"])

    def apply_mask(bands):
        """Keep the band values only where the element at index 2 ("mask") differs from 2."""
        return if_(bands.array_element(2)!=2,bands)
    s1_cube = s1_cube.apply_dimension(apply_mask, dimension="bands")

    def log_(x):
        """Convert a value to decibel: 10 * log10(x)."""
        return 10 * log(x, 10)
    s1_median = s1_cube.median_time().apply(log_)

    def s1_water_function(data):
        """Apply the zone's "S1" model to vh (index 0) and vv (index 1)."""
        return LOOKUPTABLE[zone]["S1"](vh=data[0], vv=data[1])

    s1_median_water = s1_median.reduce_dimension(reducer=s1_water_function, dimension="bands")
    exclusion_mask = (s1_median_water.resample_cube_spatial(s2_cube_swf) > 0.5) & (s2_cube_swf < 0.33)
    s1_median_water_mask = s1_median_water.mask(exclusion_mask.resample_cube_spatial(s1_median_water))

    def s1_s2_water_function(data):
        """Apply the zone's "S1_S2" model to vv (index 0), ndvi (index 1) and ndwi (index 2)."""
        return LOOKUPTABLE[zone]["S1_S2"](vv=data[0], ndvi=data[1], ndwi=data[2])

    # THE PROBLEM STARTS HERE
    s2_cube_median = s2_cube.filter_temporal([start_date, end_date]).median_time()
    # NOTE(review): the variable name says NDVI but the band selected is "NDWI" - confirm which one is intended.
    s2_cube_median_NDVI = s2_cube_median.band("NDWI")

    # Intermediate debugging export, sent as its own batch job.
    s1_s2_water_save = s2_cube_median_NDVI.save_result(format='netCDF') #GTiff #netCDF
    my_job  = s1_s2_water_save.create_job(title="s2_cube_median_NDVI")
    results = my_job.start_and_wait().get_results()
    results.download_files('s2_cube_median_NDVI')

    # NOTE(review): s1_s2_water_function indexes the merged bands as VV, NDVI, NDWI -
    # verify that the band order produced by filter_bands/merge_cubes matches.
    s1_s2_cube = s1_median.filter_bands(["VV"]).resample_cube_spatial(s2_cube_median).merge_cubes(s2_cube_median.filter_bands(["NDVI","NDWI"]))
    s1_s2_water = s1_s2_cube.reduce_dimension(reducer=s1_s2_water_function, dimension="bands").add_dimension("bands", "var", type="bands")

    s1_s2_mask = (s1_s2_water >= 0)
    s2_mask = s2_median_water.mask(s1_s2_mask) >= 0
    s1_mask = s1_median_water.mask(s1_s2_mask).mask(s2_mask) >= 0
    s1_s2_masked = s1_s2_water.mask(s1_s2_mask.apply(lambda x: x.eq(0)), replacement = 0)
    s2_masked = s2_median_water.mask(s2_mask.apply(lambda x: x.eq(0)), replacement = 0)
    s1_masked = s1_median_water.mask(s1_mask.apply(lambda x: x.eq(0)), replacement = 0)

    merge_all = s1_s2_masked.merge_cubes(s2_masked, overlap_resolver='sum').merge_cubes(s1_masked, overlap_resolver='sum')
    worldcover_cube = connection.load_collection("ESA_WORLDCOVER_10M_2020_V1",
                                            temporal_extent = ['2020-12-30', '2021-01-01'],
                                            spatial_extent = spatial_extent,
                                            bands = ["MAP"])

    # Pixels whose MAP value equals 50 are masked out of the result.
    builtup_mask = worldcover_cube.band("MAP") == 50
    water_probability = merge_all.mask(builtup_mask.max_time().resample_cube_spatial(merge_all))
    water_probability = water_probability.rename_labels("bands", ["water_prob_sum"])

    output = water_probability > (threshold.value/100)
    output= output.rename_labels("bands", ["surface_water"])

    zone  = '_'.join(zone.split(" "))
    output_name = f'WWT_{zone}_{threshold.value}_{start_date_w.value.strftime("%Y_%m_%d")}'
    job_options={"node_caching":True}

    print('Spatial Extent:', spatial_extent)
    print('Start_date, End_date:', start_date, end_date)
    print('Zone:', zone)
    print('Theshold:', threshold.value)
    print('Cloud Cover',threshold_cloud_cover.value)

    # Multiply by 1.0 to turn the boolean comparison result into numbers.
    output = output * 1.0

    ##NEW EXPORT

    #cube = output.save_result(format="GTiff",options=dict(filename_prefix="andrea_1"))
    #cube = s2_cube.save_result(format="GTiff",options=dict(filename_prefix="andrea_2"))
    #cube.execute_batch(format="GTiff",filename_prefix="andrea_2")

    ###### END

    output_save = output.save_result(format='GTiff') #GTiff #netCDF
    my_job  = output_save.create_job(title= output_name, job_options=job_options)
    results = my_job.start_and_wait().get_results()
    results.download_files(output_name)

    full_path_file =  output_name + '/openEO.tif'

    print('You can check results Open Editor: https://editor.openeo.org/?server=https%3A%2F%2Fopeneo-dev.vito.be')
    print('File is saved:', full_path_file)

    dst_crs = 'EPSG:4326'

    # Reproject the downloaded GeoTIFF to EPSG:4326 so it can be overlaid on the map.
    with rasterio.open(full_path_file) as src:
        transform, width, height = calculate_default_transform(
            src.crs, dst_crs, src.width, src.height, *src.bounds)
        kwargs = src.meta.copy()
        kwargs.update({
            'crs': dst_crs,
            'transform': transform,
            'width': width,
            'height': height
        })

        full_path_file_wgs = output_name + '/openEO_wgs.tif'
        with rasterio.open(full_path_file_wgs, 'w', **kwargs) as dst:
            for i in range(1, src.count + 1):
                reproject(
                    source=rasterio.band(src, i),
                    destination=rasterio.band(dst, i),
                    src_transform=src.transform,
                    src_crs=src.crs,
                    dst_transform=transform,
                    dst_crs=dst_crs,
                    resampling=Resampling.nearest)

    # NOTE(review): xr.open_rasterio is deprecated in newer xarray releases;
    # replacing it needs a package that is not visibly imported here.
    da_dem  = xr.open_rasterio(full_path_file_wgs).drop('band')[0].rename({'x':'longitude', 'y':'latitude'})

    mlat = da_dem.latitude.values.min()
    mlon = da_dem.longitude.values.min()
    xlat = da_dem.latitude.values.max()
    xlon = da_dem.longitude.values.max()

    def colorize(array, cmap='viridis'):
        """Min-max normalise ``array`` and map it through the named colormap."""
        normed_data = (array - array.min()) / (array.max() - array.min())
        # plt.get_cmap is available across matplotlib versions, unlike plt.cm.get_cmap.
        cm = plt.get_cmap(cmap)
        return cm(normed_data)

    colored_data = colorize(da_dem, cmap='Blues')
    tile_layer = folium.TileLayer( tiles = EsriImagery, attr = EsriAttribution, name = 'Satellite',).add_to(map)
    map.add_child(folium.raster_layers.ImageOverlay(colored_data,
                                      [[mlat, mlon], [xlat, xlon]],
                                      opacity=0.5, name = 'Water Extent'))
    folium.LayerControl().add_to(map)
    display(map)
    return full_path_file_wgs

from IPython.display import display
def on_button_clicked(b):
    """Click handler: run WWT and capture what it prints or displays in ``output``."""
    with output:
        WWT(1)


# Output area that receives everything WWT prints or displays.
output = widgets.Output()

# Button that starts the workflow.
button = widgets.Button(description="Run")
button.on_click(on_button_clicked)

display(button, output)


@stefaan.lippens Do you maybe have any idea why this fails?

FYI, I noticed this one too late, but already answered it in: